Added support for recursive embedded documents

This commit is contained in:
Harry Marr 2010-10-18 00:27:40 +01:00
parent e93c4c87d8
commit 0902b95764
3 changed files with 54 additions and 16 deletions

View File

@ -204,6 +204,9 @@ class DocumentMetaclass(type):
exc = subclass_exception('MultipleObjectsReturned', base_excs, module) exc = subclass_exception('MultipleObjectsReturned', base_excs, module)
new_class.add_to_class('MultipleObjectsReturned', exc) new_class.add_to_class('MultipleObjectsReturned', exc)
global _document_registry
_document_registry[name] = new_class
return new_class return new_class
def add_to_class(self, name, value): def add_to_class(self, name, value):
@ -216,8 +219,6 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
""" """
def __new__(cls, name, bases, attrs): def __new__(cls, name, bases, attrs):
global _document_registry
super_new = super(TopLevelDocumentMetaclass, cls).__new__ super_new = super(TopLevelDocumentMetaclass, cls).__new__
# Classes defined in this package are abstract and should not have # Classes defined in this package are abstract and should not have
# their own metadata with DB collection, etc. # their own metadata with DB collection, etc.
@ -322,8 +323,6 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
new_class._fields['id'] = ObjectIdField(db_field='_id') new_class._fields['id'] = ObjectIdField(db_field='_id')
new_class.id = new_class._fields['id'] new_class.id = new_class._fields['id']
_document_registry[name] = new_class
return new_class return new_class

View File

@ -233,33 +233,43 @@ class EmbeddedDocumentField(BaseField):
:class:`~mongoengine.EmbeddedDocument`. :class:`~mongoengine.EmbeddedDocument`.
""" """
def __init__(self, document, **kwargs): def __init__(self, document_type, **kwargs):
if not issubclass(document, EmbeddedDocument): if not isinstance(document_type, basestring):
raise ValidationError('Invalid embedded document class provided ' if not issubclass(document_type, EmbeddedDocument):
'to an EmbeddedDocumentField') raise ValidationError('Invalid embedded document class '
self.document = document 'provided to an EmbeddedDocumentField')
self.document_type_obj = document_type
super(EmbeddedDocumentField, self).__init__(**kwargs) super(EmbeddedDocumentField, self).__init__(**kwargs)
@property
def document_type(self):
if isinstance(self.document_type_obj, basestring):
if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT:
self.document_type_obj = self.owner_document
else:
self.document_type_obj = get_document(self.document_type_obj)
return self.document_type_obj
def to_python(self, value): def to_python(self, value):
if not isinstance(value, self.document): if not isinstance(value, self.document_type):
return self.document._from_son(value) return self.document_type._from_son(value)
return value return value
def to_mongo(self, value): def to_mongo(self, value):
return self.document.to_mongo(value) return self.document_type.to_mongo(value)
def validate(self, value): def validate(self, value):
"""Make sure that the document instance is an instance of the """Make sure that the document instance is an instance of the
EmbeddedDocument subclass provided when the document was defined. EmbeddedDocument subclass provided when the document was defined.
""" """
# Using isinstance also works for subclasses of self.document # Using isinstance also works for subclasses of self.document
if not isinstance(value, self.document): if not isinstance(value, self.document_type):
raise ValidationError('Invalid embedded document instance ' raise ValidationError('Invalid embedded document instance '
'provided to an EmbeddedDocumentField') 'provided to an EmbeddedDocumentField')
self.document.validate(value) self.document_type.validate(value)
def lookup_member(self, member_name): def lookup_member(self, member_name):
return self.document._fields.get(member_name) return self.document_type._fields.get(member_name)
def prepare_query_value(self, op, value): def prepare_query_value(self, op, value):
return self.to_mongo(value) return self.to_mongo(value)
@ -413,7 +423,6 @@ class ReferenceField(BaseField):
raise ValidationError('Argument to ReferenceField constructor ' raise ValidationError('Argument to ReferenceField constructor '
'must be a document class or a string') 'must be a document class or a string')
self.document_type_obj = document_type self.document_type_obj = document_type
self.document_obj = None
super(ReferenceField, self).__init__(**kwargs) super(ReferenceField, self).__init__(**kwargs)
@property @property

View File

@ -423,6 +423,36 @@ class FieldTest(unittest.TestCase):
self.assertEqual(peter.boss, bill) self.assertEqual(peter.boss, bill)
self.assertEqual(peter.friends, friends) self.assertEqual(peter.friends, friends)
def test_recursive_embedding(self):
"""Ensure that EmbeddedDocumentFields can contain their own documents.
"""
class Tree(Document):
name = StringField()
children = ListField(EmbeddedDocumentField('TreeNode'))
class TreeNode(EmbeddedDocument):
name = StringField()
children = ListField(EmbeddedDocumentField('self'))
tree = Tree(name="Tree")
first_child = TreeNode(name="Child 1")
tree.children.append(first_child)
second_child = TreeNode(name="Child 2")
first_child.children.append(second_child)
third_child = TreeNode(name="Child 3")
first_child.children.append(third_child)
tree.save()
tree_obj = Tree.objects.first()
self.assertEqual(len(tree.children), 1)
self.assertEqual(tree.children[0].name, first_child.name)
self.assertEqual(tree.children[0].children[0].name, second_child.name)
self.assertEqual(tree.children[0].children[1].name, third_child.name)
def test_undefined_reference(self): def test_undefined_reference(self):
"""Ensure that ReferenceFields may reference undefined Documents. """Ensure that ReferenceFields may reference undefined Documents.
""" """