From 58f877de1a6813b3eb9f875e3988ae6cb8acfafb Mon Sep 17 00:00:00 2001 From: Harry Marr Date: Sun, 28 Feb 2010 23:16:51 +0000 Subject: [PATCH] Added recursive / document name references --- mongoengine/fields.py | 13 ++++++++++++- tests/fields.py | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 1 deletion(-) diff --git a/mongoengine/fields.py b/mongoengine/fields.py index a4ee7c13..8456be2f 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -13,6 +13,8 @@ __all__ = ['StringField', 'IntField', 'FloatField', 'BooleanField', 'ObjectIdField', 'ReferenceField', 'ValidationError', 'DecimalField', 'URLField', 'GenericReferenceField'] +RECURSIVE_REFERENCE_CONSTANT = 'self' + class StringField(BaseField): """A unicode string field. @@ -334,10 +336,19 @@ class ReferenceField(BaseField): if not issubclass(document_type, (Document, basestring)): raise ValidationError('Argument to ReferenceField constructor ' 'must be a document class or a string') - self.document_type = document_type + self.document_type_obj = document_type self.document_obj = None super(ReferenceField, self).__init__(**kwargs) + @property + def document_type(self): + if isinstance(self.document_type_obj, basestring): + if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT: + self.document_type_obj = self.owner_document + else: + self.document_type_obj = get_document(self.document_type_obj) + return self.document_type_obj + def __get__(self, instance, owner): """Descriptor to allow lazy dereferencing. """ diff --git a/tests/fields.py b/tests/fields.py index 382cd455..94f65186 100644 --- a/tests/fields.py +++ b/tests/fields.py @@ -350,6 +350,40 @@ class FieldTest(unittest.TestCase): User.drop_collection() Group.drop_collection() + def test_recursive_reference(self): + """Ensure that ReferenceFields can reference their own documents. + """ + class Employee(Document): + name = StringField() + boss = ReferenceField('self') + + bill = Employee(name='Bill Lumbergh') + bill.save() + peter = Employee(name='Peter Gibbons', boss=bill) + peter.save() + + peter = Employee.objects.with_id(peter.id) + self.assertEqual(peter.boss, bill) + + def test_undefined_reference(self): + """Ensure that ReferenceFields may reference undefined Documents. + """ + class Product(Document): + name = StringField() + company = ReferenceField('Company') + + class Company(Document): + name = StringField() + + ten_gen = Company(name='10gen') + ten_gen.save() + mongodb = Product(name='MongoDB', company=ten_gen) + mongodb.save() + + obj = Product.objects(company=ten_gen).first() + self.assertEqual(obj, mongodb) + self.assertEqual(obj.company, ten_gen) + def test_reference_query_conversion(self): """Ensure that ReferenceFields can be queried using objects and values of the type of the primary key of the referenced object.