Added support for recursive embedded documents

This commit is contained in:
Harry Marr 2010-10-18 00:27:40 +01:00
parent e93c4c87d8
commit 0902b95764
3 changed files with 54 additions and 16 deletions

View File

@ -204,6 +204,9 @@ class DocumentMetaclass(type):
exc = subclass_exception('MultipleObjectsReturned', base_excs, module)
new_class.add_to_class('MultipleObjectsReturned', exc)
global _document_registry
_document_registry[name] = new_class
return new_class
def add_to_class(self, name, value):
@ -216,8 +219,6 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
"""
def __new__(cls, name, bases, attrs):
global _document_registry
super_new = super(TopLevelDocumentMetaclass, cls).__new__
# Classes defined in this package are abstract and should not have
# their own metadata with DB collection, etc.
@ -322,8 +323,6 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
new_class._fields['id'] = ObjectIdField(db_field='_id')
new_class.id = new_class._fields['id']
_document_registry[name] = new_class
return new_class

View File

@ -233,33 +233,43 @@ class EmbeddedDocumentField(BaseField):
:class:`~mongoengine.EmbeddedDocument`.
"""
def __init__(self, document, **kwargs):
if not issubclass(document, EmbeddedDocument):
raise ValidationError('Invalid embedded document class provided '
'to an EmbeddedDocumentField')
self.document = document
def __init__(self, document_type, **kwargs):
if not isinstance(document_type, basestring):
if not issubclass(document_type, EmbeddedDocument):
raise ValidationError('Invalid embedded document class '
'provided to an EmbeddedDocumentField')
self.document_type_obj = document_type
super(EmbeddedDocumentField, self).__init__(**kwargs)
@property
def document_type(self):
if isinstance(self.document_type_obj, basestring):
if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT:
self.document_type_obj = self.owner_document
else:
self.document_type_obj = get_document(self.document_type_obj)
return self.document_type_obj
def to_python(self, value):
if not isinstance(value, self.document):
return self.document._from_son(value)
if not isinstance(value, self.document_type):
return self.document_type._from_son(value)
return value
def to_mongo(self, value):
return self.document.to_mongo(value)
return self.document_type.to_mongo(value)
def validate(self, value):
"""Make sure that the document instance is an instance of the
EmbeddedDocument subclass provided when the document was defined.
"""
# Using isinstance also works for subclasses of self.document
if not isinstance(value, self.document):
if not isinstance(value, self.document_type):
raise ValidationError('Invalid embedded document instance '
'provided to an EmbeddedDocumentField')
self.document.validate(value)
self.document_type.validate(value)
def lookup_member(self, member_name):
return self.document._fields.get(member_name)
return self.document_type._fields.get(member_name)
def prepare_query_value(self, op, value):
return self.to_mongo(value)
@ -413,7 +423,6 @@ class ReferenceField(BaseField):
raise ValidationError('Argument to ReferenceField constructor '
'must be a document class or a string')
self.document_type_obj = document_type
self.document_obj = None
super(ReferenceField, self).__init__(**kwargs)
@property

View File

@ -423,6 +423,36 @@ class FieldTest(unittest.TestCase):
self.assertEqual(peter.boss, bill)
self.assertEqual(peter.friends, friends)
def test_recursive_embedding(self):
"""Ensure that EmbeddedDocumentFields can contain their own documents.
"""
class Tree(Document):
name = StringField()
children = ListField(EmbeddedDocumentField('TreeNode'))
class TreeNode(EmbeddedDocument):
name = StringField()
children = ListField(EmbeddedDocumentField('self'))
tree = Tree(name="Tree")
first_child = TreeNode(name="Child 1")
tree.children.append(first_child)
second_child = TreeNode(name="Child 2")
first_child.children.append(second_child)
third_child = TreeNode(name="Child 3")
first_child.children.append(third_child)
tree.save()
tree_obj = Tree.objects.first()
self.assertEqual(len(tree.children), 1)
self.assertEqual(tree.children[0].name, first_child.name)
self.assertEqual(tree.children[0].children[0].name, second_child.name)
self.assertEqual(tree.children[0].children[1].name, third_child.name)
def test_undefined_reference(self):
"""Ensure that ReferenceFields may reference undefined Documents.
"""