From 0902b957641e09861f0864e09501b174624577f0 Mon Sep 17 00:00:00 2001 From: Harry Marr Date: Mon, 18 Oct 2010 00:27:40 +0100 Subject: [PATCH] Added support for recursive embedded documents --- mongoengine/base.py | 7 +++---- mongoengine/fields.py | 33 +++++++++++++++++++++------------ tests/fields.py | 30 ++++++++++++++++++++++++++++++ 3 files changed, 54 insertions(+), 16 deletions(-) diff --git a/mongoengine/base.py b/mongoengine/base.py index 1f7ba1fe..2253e4a2 100644 --- a/mongoengine/base.py +++ b/mongoengine/base.py @@ -204,6 +204,9 @@ class DocumentMetaclass(type): exc = subclass_exception('MultipleObjectsReturned', base_excs, module) new_class.add_to_class('MultipleObjectsReturned', exc) + global _document_registry + _document_registry[name] = new_class + return new_class def add_to_class(self, name, value): @@ -216,8 +219,6 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): """ def __new__(cls, name, bases, attrs): - global _document_registry - super_new = super(TopLevelDocumentMetaclass, cls).__new__ # Classes defined in this package are abstract and should not have # their own metadata with DB collection, etc. @@ -322,8 +323,6 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): new_class._fields['id'] = ObjectIdField(db_field='_id') new_class.id = new_class._fields['id'] - _document_registry[name] = new_class - return new_class diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 62d2ef2f..63107f23 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -233,33 +233,43 @@ class EmbeddedDocumentField(BaseField): :class:`~mongoengine.EmbeddedDocument`. """ - def __init__(self, document, **kwargs): - if not issubclass(document, EmbeddedDocument): - raise ValidationError('Invalid embedded document class provided ' - 'to an EmbeddedDocumentField') - self.document = document + def __init__(self, document_type, **kwargs): + if not isinstance(document_type, basestring): + if not issubclass(document_type, EmbeddedDocument): + raise ValidationError('Invalid embedded document class ' + 'provided to an EmbeddedDocumentField') + self.document_type_obj = document_type super(EmbeddedDocumentField, self).__init__(**kwargs) + @property + def document_type(self): + if isinstance(self.document_type_obj, basestring): + if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT: + self.document_type_obj = self.owner_document + else: + self.document_type_obj = get_document(self.document_type_obj) + return self.document_type_obj + def to_python(self, value): - if not isinstance(value, self.document): - return self.document._from_son(value) + if not isinstance(value, self.document_type): + return self.document_type._from_son(value) return value def to_mongo(self, value): - return self.document.to_mongo(value) + return self.document_type.to_mongo(value) def validate(self, value): """Make sure that the document instance is an instance of the EmbeddedDocument subclass provided when the document was defined. """ # Using isinstance also works for subclasses of self.document - if not isinstance(value, self.document): + if not isinstance(value, self.document_type): raise ValidationError('Invalid embedded document instance ' 'provided to an EmbeddedDocumentField') - self.document.validate(value) + self.document_type.validate(value) def lookup_member(self, member_name): - return self.document._fields.get(member_name) + return self.document_type._fields.get(member_name) def prepare_query_value(self, op, value): return self.to_mongo(value) @@ -413,7 +423,6 @@ class ReferenceField(BaseField): raise ValidationError('Argument to ReferenceField constructor ' 'must be a document class or a string') self.document_type_obj = document_type - self.document_obj = None super(ReferenceField, self).__init__(**kwargs) @property diff --git a/tests/fields.py b/tests/fields.py index e30f843e..208b4643 100644 --- a/tests/fields.py +++ b/tests/fields.py @@ -423,6 +423,36 @@ class FieldTest(unittest.TestCase): self.assertEqual(peter.boss, bill) self.assertEqual(peter.friends, friends) + def test_recursive_embedding(self): + """Ensure that EmbeddedDocumentFields can contain their own documents. + """ + class Tree(Document): + name = StringField() + children = ListField(EmbeddedDocumentField('TreeNode')) + + class TreeNode(EmbeddedDocument): + name = StringField() + children = ListField(EmbeddedDocumentField('self')) + + tree = Tree(name="Tree") + + first_child = TreeNode(name="Child 1") + tree.children.append(first_child) + + second_child = TreeNode(name="Child 2") + first_child.children.append(second_child) + + third_child = TreeNode(name="Child 3") + first_child.children.append(third_child) + + tree.save() + + tree_obj = Tree.objects.first() + self.assertEqual(len(tree.children), 1) + self.assertEqual(tree.children[0].name, first_child.name) + self.assertEqual(tree.children[0].children[0].name, second_child.name) + self.assertEqual(tree.children[0].children[1].name, third_child.name) + def test_undefined_reference(self): """Ensure that ReferenceFields may reference undefined Documents. """