From 7d90aa76ff7116269dea42f2c6629ea6b868b0de Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Tue, 6 Nov 2012 16:04:23 +0000 Subject: [PATCH] Add _instance to Embedded Documents Fixes MongoEngine/mongoengine#139 --- mongoengine/base/datastructures.py | 20 +++++++++++- mongoengine/base/fields.py | 4 +++ mongoengine/document.py | 2 ++ mongoengine/fields.py | 7 +++-- tests/document/instance.py | 50 ++++++++++++++++++++++++------ 5 files changed, 70 insertions(+), 13 deletions(-) diff --git a/mongoengine/base/datastructures.py b/mongoengine/base/datastructures.py index 9a7620e6..c750b5ba 100644 --- a/mongoengine/base/datastructures.py +++ b/mongoengine/base/datastructures.py @@ -1,4 +1,5 @@ import weakref +from mongoengine.common import _import_class __all__ = ("BaseDict", "BaseList") @@ -16,6 +17,14 @@ class BaseDict(dict): self._name = name return super(BaseDict, self).__init__(dict_items) + def __getitem__(self, *args, **kwargs): + value = super(BaseDict, self).__getitem__(*args, **kwargs) + + EmbeddedDocument = _import_class('EmbeddedDocument') + if isinstance(value, EmbeddedDocument) and value._instance is None: + value._instance = self._instance + return value + def __setitem__(self, *args, **kwargs): self._mark_as_changed() return super(BaseDict, self).__setitem__(*args, **kwargs) @@ -75,6 +84,14 @@ class BaseList(list): self._name = name return super(BaseList, self).__init__(list_items) + def __getitem__(self, *args, **kwargs): + value = super(BaseList, self).__getitem__(*args, **kwargs) + + EmbeddedDocument = _import_class('EmbeddedDocument') + if isinstance(value, EmbeddedDocument) and value._instance is None: + value._instance = self._instance + return value + def __setitem__(self, *args, **kwargs): self._mark_as_changed() return super(BaseList, self).__setitem__(*args, **kwargs) @@ -84,7 +101,8 @@ class BaseList(list): return super(BaseList, self).__delitem__(*args, **kwargs) def __getstate__(self): - self.observer = None + self.instance = None + self._dereferenced = False return self def __setstate__(self, state): diff --git a/mongoengine/base/fields.py b/mongoengine/base/fields.py index 00e040ca..fc1a0767 100644 --- a/mongoengine/base/fields.py +++ b/mongoengine/base/fields.py @@ -1,5 +1,6 @@ import operator import warnings +import weakref from bson import DBRef, ObjectId @@ -71,6 +72,9 @@ class BaseField(object): if callable(value): value = value() + EmbeddedDocument = _import_class('EmbeddedDocument') + if isinstance(value, EmbeddedDocument) and value._instance is None: + value._instance = weakref.proxy(instance) return value def __set__(self, instance, value): diff --git a/mongoengine/document.py b/mongoengine/document.py index 95dd6246..adbdcca2 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -40,6 +40,8 @@ class EmbeddedDocument(BaseDocument): my_metaclass = DocumentMetaclass __metaclass__ = DocumentMetaclass + _instance = None + def __init__(self, *args, **kwargs): super(EmbeddedDocument, self).__init__(*args, **kwargs) self._changed_fields = [] diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 15e1626f..94e11556 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -625,7 +625,8 @@ class SortedListField(ListField): def to_mongo(self, value): value = super(SortedListField, self).to_mongo(value) if self._ordering is not None: - return sorted(value, key=itemgetter(self._ordering), reverse=self._order_reverse) + return sorted(value, key=itemgetter(self._ordering), + reverse=self._order_reverse) return sorted(value, reverse=self._order_reverse) @@ -655,7 +656,9 @@ class DictField(ComplexBaseField): self.error('Only dictionaries may be used in a DictField') if any(k for k in value.keys() if not isinstance(k, basestring)): - self.error('Invalid dictionary key - documents must have only string keys') + msg = ("Invalid dictionary key - documents must " + "have only string keys") + self.error(msg) if any(('.' in k or '$' in k) for k in value.keys()): self.error('Invalid dictionary key name - keys may not contain "."' ' or "$" characters') diff --git a/tests/document/instance.py b/tests/document/instance.py index fcc43bad..48ddc10d 100644 --- a/tests/document/instance.py +++ b/tests/document/instance.py @@ -183,9 +183,6 @@ class InstanceTest(unittest.TestCase): self.assertEqual(list_stats, CompareStats.objects.first().stats) - - - def test_db_field_load(self): """Ensure we load data correctly """ @@ -214,24 +211,24 @@ class InstanceTest(unittest.TestCase): class Person(Document): name = StringField(required=True) - rank_ = EmbeddedDocumentField(Rank, required=False, db_field='rank') + rank_ = EmbeddedDocumentField(Rank, + required=False, + db_field='rank') @property def rank(self): - return self.rank_.title if self.rank_ is not None else "Private" + if self.rank_ is None: + return "Private" + return self.rank_.title Person.drop_collection() Person(name="Jack", rank_=Rank(title="Corporal")).save() - Person(name="Fred").save() self.assertEqual(Person.objects.get(name="Jack").rank, "Corporal") self.assertEqual(Person.objects.get(name="Fred").rank, "Private") - - - def test_custom_id_field(self): """Ensure that documents may be created with custom primary keys. """ @@ -247,7 +244,7 @@ class InstanceTest(unittest.TestCase): self.assertEqual(User._meta['id_field'], 'username') def create_invalid_user(): - User(name='test').save() # no primary key field + User(name='test').save() # no primary key field self.assertRaises(ValidationError, create_invalid_user) def define_invalid_user(): @@ -424,6 +421,36 @@ class InstanceTest(unittest.TestCase): self.assertTrue('content' in Comment._fields) self.assertFalse('id' in Comment._fields) + def test_embedded_document_instance(self): + """Ensure that embedded documents can reference parent instance + """ + class Embedded(EmbeddedDocument): + string = StringField() + + class Doc(Document): + embedded_field = EmbeddedDocumentField(Embedded) + + Doc.drop_collection() + Doc(embedded_field=Embedded(string="Hi")).save() + + doc = Doc.objects.get() + self.assertEqual(doc, doc.embedded_field._instance) + + def test_embedded_document_complex_instance(self): + """Ensure that embedded documents in complex fields can reference + parent instance""" + class Embedded(EmbeddedDocument): + string = StringField() + + class Doc(Document): + embedded_field = ListField(EmbeddedDocumentField(Embedded)) + + Doc.drop_collection() + Doc(embedded_field=[Embedded(string="Hi")]).save() + + doc = Doc.objects.get() + self.assertEqual(doc, doc.embedded_field[0]._instance) + def test_embedded_document_validation(self): """Ensure that embedded documents may be validated. """ @@ -442,6 +469,7 @@ class InstanceTest(unittest.TestCase): comment.date = datetime.now() comment.validate() + self.assertEqual(comment._instance, None) def test_embedded_db_field_validate(self): @@ -475,11 +503,13 @@ class InstanceTest(unittest.TestCase): self.assertEqual(person_obj['age'], 30) self.assertEqual(person_obj['_id'], person.id) # Test skipping validation on save + class Recipient(Document): email = EmailField(required=True) recipient = Recipient(email='root@localhost') self.assertRaises(ValidationError, recipient.save) + try: recipient.save(validate=False) except ValidationError: