Added support for recursive embedded documents
This commit is contained in:
		| @@ -204,6 +204,9 @@ class DocumentMetaclass(type): | |||||||
|         exc = subclass_exception('MultipleObjectsReturned', base_excs, module) |         exc = subclass_exception('MultipleObjectsReturned', base_excs, module) | ||||||
|         new_class.add_to_class('MultipleObjectsReturned', exc) |         new_class.add_to_class('MultipleObjectsReturned', exc) | ||||||
|      |      | ||||||
|  |         global _document_registry | ||||||
|  |         _document_registry[name] = new_class | ||||||
|  |  | ||||||
|         return new_class |         return new_class | ||||||
|          |          | ||||||
|     def add_to_class(self, name, value): |     def add_to_class(self, name, value): | ||||||
| @@ -216,8 +219,6 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): | |||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     def __new__(cls, name, bases, attrs): |     def __new__(cls, name, bases, attrs): | ||||||
|         global _document_registry |  | ||||||
|  |  | ||||||
|         super_new = super(TopLevelDocumentMetaclass, cls).__new__ |         super_new = super(TopLevelDocumentMetaclass, cls).__new__ | ||||||
|         # Classes defined in this package are abstract and should not have  |         # Classes defined in this package are abstract and should not have  | ||||||
|         # their own metadata with DB collection, etc. |         # their own metadata with DB collection, etc. | ||||||
| @@ -322,8 +323,6 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): | |||||||
|             new_class._fields['id'] = ObjectIdField(db_field='_id') |             new_class._fields['id'] = ObjectIdField(db_field='_id') | ||||||
|             new_class.id = new_class._fields['id'] |             new_class.id = new_class._fields['id'] | ||||||
|  |  | ||||||
|         _document_registry[name] = new_class |  | ||||||
|  |  | ||||||
|         return new_class |         return new_class | ||||||
|  |  | ||||||
|  |  | ||||||
|   | |||||||
| @@ -233,33 +233,43 @@ class EmbeddedDocumentField(BaseField): | |||||||
|     :class:`~mongoengine.EmbeddedDocument`. |     :class:`~mongoengine.EmbeddedDocument`. | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     def __init__(self, document, **kwargs): |     def __init__(self, document_type, **kwargs): | ||||||
|         if not issubclass(document, EmbeddedDocument): |         if not isinstance(document_type, basestring): | ||||||
|             raise ValidationError('Invalid embedded document class provided ' |             if not issubclass(document_type, EmbeddedDocument): | ||||||
|                                   'to an EmbeddedDocumentField') |                 raise ValidationError('Invalid embedded document class ' | ||||||
|         self.document = document |                                       'provided to an EmbeddedDocumentField') | ||||||
|  |         self.document_type_obj = document_type | ||||||
|         super(EmbeddedDocumentField, self).__init__(**kwargs) |         super(EmbeddedDocumentField, self).__init__(**kwargs) | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def document_type(self): | ||||||
|  |         if isinstance(self.document_type_obj, basestring): | ||||||
|  |             if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT: | ||||||
|  |                 self.document_type_obj = self.owner_document | ||||||
|  |             else: | ||||||
|  |                 self.document_type_obj = get_document(self.document_type_obj) | ||||||
|  |         return self.document_type_obj | ||||||
|  |  | ||||||
|     def to_python(self, value): |     def to_python(self, value): | ||||||
|         if not isinstance(value, self.document): |         if not isinstance(value, self.document_type): | ||||||
|             return self.document._from_son(value) |             return self.document_type._from_son(value) | ||||||
|         return value |         return value | ||||||
|  |  | ||||||
|     def to_mongo(self, value): |     def to_mongo(self, value): | ||||||
|         return self.document.to_mongo(value) |         return self.document_type.to_mongo(value) | ||||||
|  |  | ||||||
|     def validate(self, value): |     def validate(self, value): | ||||||
|         """Make sure that the document instance is an instance of the |         """Make sure that the document instance is an instance of the | ||||||
|         EmbeddedDocument subclass provided when the document was defined. |         EmbeddedDocument subclass provided when the document was defined. | ||||||
|         """ |         """ | ||||||
|         # Using isinstance also works for subclasses of self.document |         # Using isinstance also works for subclasses of self.document | ||||||
|         if not isinstance(value, self.document): |         if not isinstance(value, self.document_type): | ||||||
|             raise ValidationError('Invalid embedded document instance ' |             raise ValidationError('Invalid embedded document instance ' | ||||||
|                                   'provided to an EmbeddedDocumentField') |                                   'provided to an EmbeddedDocumentField') | ||||||
|         self.document.validate(value) |         self.document_type.validate(value) | ||||||
|  |  | ||||||
|     def lookup_member(self, member_name): |     def lookup_member(self, member_name): | ||||||
|         return self.document._fields.get(member_name) |         return self.document_type._fields.get(member_name) | ||||||
|  |  | ||||||
|     def prepare_query_value(self, op, value): |     def prepare_query_value(self, op, value): | ||||||
|         return self.to_mongo(value) |         return self.to_mongo(value) | ||||||
| @@ -413,7 +423,6 @@ class ReferenceField(BaseField): | |||||||
|                 raise ValidationError('Argument to ReferenceField constructor ' |                 raise ValidationError('Argument to ReferenceField constructor ' | ||||||
|                                       'must be a document class or a string') |                                       'must be a document class or a string') | ||||||
|         self.document_type_obj = document_type |         self.document_type_obj = document_type | ||||||
|         self.document_obj = None |  | ||||||
|         super(ReferenceField, self).__init__(**kwargs) |         super(ReferenceField, self).__init__(**kwargs) | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|   | |||||||
| @@ -423,6 +423,36 @@ class FieldTest(unittest.TestCase): | |||||||
|         self.assertEqual(peter.boss, bill) |         self.assertEqual(peter.boss, bill) | ||||||
|         self.assertEqual(peter.friends, friends) |         self.assertEqual(peter.friends, friends) | ||||||
|  |  | ||||||
|  |     def test_recursive_embedding(self): | ||||||
|  |         """Ensure that EmbeddedDocumentFields can contain their own documents. | ||||||
|  |         """ | ||||||
|  |         class Tree(Document): | ||||||
|  |             name = StringField() | ||||||
|  |             children = ListField(EmbeddedDocumentField('TreeNode')) | ||||||
|  |  | ||||||
|  |         class TreeNode(EmbeddedDocument): | ||||||
|  |             name = StringField() | ||||||
|  |             children = ListField(EmbeddedDocumentField('self')) | ||||||
|  |          | ||||||
|  |         tree = Tree(name="Tree") | ||||||
|  |  | ||||||
|  |         first_child = TreeNode(name="Child 1") | ||||||
|  |         tree.children.append(first_child) | ||||||
|  |  | ||||||
|  |         second_child = TreeNode(name="Child 2") | ||||||
|  |         first_child.children.append(second_child) | ||||||
|  |          | ||||||
|  |         third_child = TreeNode(name="Child 3") | ||||||
|  |         first_child.children.append(third_child) | ||||||
|  |  | ||||||
|  |         tree.save() | ||||||
|  |  | ||||||
|  |         tree_obj = Tree.objects.first() | ||||||
|  |         self.assertEqual(len(tree.children), 1) | ||||||
|  |         self.assertEqual(tree.children[0].name, first_child.name) | ||||||
|  |         self.assertEqual(tree.children[0].children[0].name, second_child.name) | ||||||
|  |         self.assertEqual(tree.children[0].children[1].name, third_child.name) | ||||||
|  |  | ||||||
|     def test_undefined_reference(self): |     def test_undefined_reference(self): | ||||||
|         """Ensure that ReferenceFields may reference undefined Documents. |         """Ensure that ReferenceFields may reference undefined Documents. | ||||||
|         """ |         """ | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user