From 0fb629e24ccb6ce6f2b8c3bf92fd3c239ba5ef11 Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Mon, 11 Jul 2011 16:01:48 +0100 Subject: [PATCH] Added cascading deletes Also ensured that unsetting works when not the default value of a field --- docs/changelog.rst | 1 + mongoengine/base.py | 19 ++++++++--- mongoengine/document.py | 17 +++++++++- tests/document.py | 70 ++++++++++++++++++++++++++++++++++++----- 4 files changed, 94 insertions(+), 13 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index cad1b687..1b4842e7 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -5,6 +5,7 @@ Changelog Changes in dev ============== +- Added cascading saves - so changes to Referenced documents are saved on .save() - Added select_related() support - Added support for the positional operator - Updated geo index checking to be recursive and check in embedded documents diff --git a/mongoengine/base.py b/mongoengine/base.py index b83164aa..25b049a3 100644 --- a/mongoengine/base.py +++ b/mongoengine/base.py @@ -747,10 +747,21 @@ class BaseDocument(object): if '_id' in set_data: del(set_data['_id']) - for k,v in set_data.items(): - if not v: - del(set_data[k]) - unset_data[k] = 1 + # Determine if any changed items were actually unset. + for path, value in set_data.items(): + if value: + continue + + # If we've set a value that aint the default value save it. + if path in self._fields: + default = self._fields[path].default + if callable(default): + default = default() + if default != value: + continue + + del(set_data[path]) + unset_data[path] = 1 return set_data, unset_data @classmethod diff --git a/mongoengine/document.py b/mongoengine/document.py index c653c8fb..6ccda997 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -112,7 +112,7 @@ class Document(BaseDocument): self._collection = db[collection_name] return self._collection - def save(self, safe=True, force_insert=False, validate=True, write_options=None): + def save(self, safe=True, force_insert=False, validate=True, write_options=None, _refs=None): """Save the :class:`~mongoengine.Document` to the database. If the document already exists, it will be updated, otherwise it will be created. @@ -131,6 +131,8 @@ class Document(BaseDocument): For example, ``save(..., w=2, fsync=True)`` will wait until at least two servers have recorded the write and will force an fsync on each server being written to. """ + from fields import ReferenceField, GenericReferenceField + signals.pre_save.send(self.__class__, document=self) if validate: @@ -140,6 +142,7 @@ class Document(BaseDocument): write_options = {} doc = self.to_mongo() + created = '_id' not in doc try: collection = self.__class__.objects._collection @@ -154,6 +157,18 @@ class Document(BaseDocument): collection.update({'_id': object_id}, {"$set": updates}, upsert=True, safe=safe, **write_options) if removals: collection.update({'_id': object_id}, {"$unset": removals}, upsert=True, safe=safe, **write_options) + + # Save any references / generic references + _refs = _refs or [] + for name, cls in self._fields.items(): + if isinstance(cls, (ReferenceField, GenericReferenceField)): + ref = getattr(self, name) + if ref and str(ref) not in _refs: + _refs.append(str(ref)) + ref.save(safe=safe, force_insert=force_insert, + validate=validate, write_options=write_options, + _refs=_refs) + except pymongo.errors.OperationFailure, err: message = 'Could not save document (%s)' if u'duplicate key' in unicode(err): diff --git a/tests/document.py b/tests/document.py index 9498cfb2..81670eb0 100644 --- a/tests/document.py +++ b/tests/document.py @@ -1054,11 +1054,11 @@ class DocumentTest(unittest.TestCase): Person.drop_collection() - p1 = Person(name="Wilson Jr") + p1 = Person(name="Wilson Snr") p1.parent = None p1.save() - p2 = Person(name="Wilson Jr2") + p2 = Person(name="Wilson Jr") p2.parent = p1 p2.save() @@ -1071,6 +1071,51 @@ class DocumentTest(unittest.TestCase): p0.name = 'wpjunior' p0.save() + def test_save_cascades(self): + + class Person(Document): + name = StringField() + parent = ReferenceField('self') + + Person.drop_collection() + + p1 = Person(name="Wilson Snr") + p1.parent = None + p1.save() + + p2 = Person(name="Wilson Jr") + p2.parent = p1 + p2.save() + + p = Person.objects(name="Wilson Jr").get() + p.parent.name = "Daddy Wilson" + p.save() + + p1.reload() + self.assertEquals(p1.name, p.parent.name) + + def test_save_cascades_generically(self): + + class Person(Document): + name = StringField() + parent = GenericReferenceField() + + Person.drop_collection() + + p1 = Person(name="Wilson Snr") + p1.save() + + p2 = Person(name="Wilson Jr") + p2.parent = p1 + p2.save() + + p = Person.objects(name="Wilson Jr").get() + p.parent.name = "Daddy Wilson" + p.save() + + p1.reload() + self.assertEquals(p1.name, p.parent.name) + def test_update(self): """Ensure that an existing document is updated instead of be overwritten. """ @@ -1364,22 +1409,31 @@ class DocumentTest(unittest.TestCase): """Ensure save only sets / unsets changed fields """ + class User(self.Person): + active = BooleanField(default=True) + + + User.drop_collection() + # Create person object and save it to the database - person = self.Person(name='Test User', age=30) - person.save() - person.reload() + user = User(name='Test User', age=30, active=True) + user.save() + user.reload() + # Simulated Race condition same_person = self.Person.objects.get() + same_person.active = False + + user.age = 21 + user.save() - person.age = 21 same_person.name = 'User' - - person.save() same_person.save() person = self.Person.objects.get() self.assertEquals(person.name, 'User') self.assertEquals(person.age, 21) + self.assertEquals(person.active, False) def test_delete(self): """Ensure that document may be deleted using the delete method.