Configurable cascading saves

Updated cascading save logic - can now add meta or pass
cascade to save().  Also Cleaned up reset changed fields logic
as well, so less looping

Refs: #370 #349
This commit is contained in:
Ross Lawley 2011-11-28 05:17:19 -08:00
parent 4607b08be5
commit e1bb453f32
2 changed files with 60 additions and 27 deletions

View File

@ -120,7 +120,8 @@ class Document(BaseDocument):
self._collection = db[collection_name]
return self._collection
def save(self, safe=True, force_insert=False, validate=True, write_options=None, _refs=None):
def save(self, safe=True, force_insert=False, validate=True, write_options=None,
cascade=None, _refs=None):
"""Save the :class:`~mongoengine.Document` to the database. If the
document already exists, it will be updated, otherwise it will be
created.
@ -138,14 +139,19 @@ class Document(BaseDocument):
which will be used as options for the resultant ``getLastError`` command.
For example, ``save(..., w=2, fsync=True)`` will wait until at least two servers
have recorded the write and will force an fsync on each server being written to.
:param cascade: Sets the flag for cascading saves. You can set a default by setting
"cascade" in the document __meta__
:param _refs: A list of processed references used in cascading saves
.. versionchanged:: 0.5
In existing documents it only saves changed fields using set / unset
Saves are cascaded and any :class:`~pymongo.dbref.DBRef` objects
that have changes are saved as well.
"""
from fields import ReferenceField, GenericReferenceField
.. versionchanged:: 0.6
Cascade saves are optional = defaults to True, if you want fine grain
control then you can turn off using document meta['cascade'] = False
"""
signals.pre_save.send(self.__class__, document=self)
if validate:
@ -173,16 +179,11 @@ class Document(BaseDocument):
if removals:
collection.update({'_id': object_id}, {"$unset": removals}, upsert=True, safe=safe, **write_options)
# Save any references / generic references
_refs = _refs or []
for name, cls in self._fields.items():
if isinstance(cls, (ReferenceField, GenericReferenceField)):
ref = getattr(self, name)
if ref and str(ref) not in _refs:
_refs.append(str(ref))
ref.save(safe=safe, force_insert=force_insert,
validate=validate, write_options=write_options,
_refs=_refs)
cascade = self._meta.get('cascade', True) if cascade is None else cascade
if cascade:
self.cascade_save(safe=safe, force_insert=force_insert,
validate=validate, write_options=write_options,
cascade=cascade, _refs=_refs)
except pymongo.errors.OperationFailure, err:
message = 'Could not save document (%s)'
@ -192,23 +193,26 @@ class Document(BaseDocument):
id_field = self._meta['id_field']
self[id_field] = self._fields[id_field].to_python(object_id)
def reset_changed_fields(doc, inspected=None):
"""Loop through and reset changed fields lists"""
inspected = inspected or []
inspected.append(doc)
if hasattr(doc, '_changed_fields'):
doc._changed_fields = []
for field_name in doc._fields:
field = getattr(doc, field_name)
if field not in inspected and hasattr(field, '_changed_fields'):
reset_changed_fields(field, inspected)
reset_changed_fields(self)
self._changed_fields = []
signals.post_save.send(self.__class__, document=self, created=creation_mode)
def cascade_save(self, *args, **kwargs):
"""Recursively saves any references / generic references on an object"""
from fields import ReferenceField, GenericReferenceField
_refs = kwargs.get('_refs', []) or []
for name, cls in self._fields.items():
if not isinstance(cls, (ReferenceField, GenericReferenceField)):
continue
ref = getattr(self, name)
if not ref:
continue
ref_id = "%s,%s" % (ref.__class__.__name__, str(ref._data))
if ref and ref_id not in _refs:
_refs.append(ref_id)
kwargs["_refs"] = _refs
ref.save(**kwargs)
ref._changed_fields = []
def update(self, **kwargs):
"""Performs an update on the :class:`~mongoengine.Document`
A convenience wrapper to :meth:`~mongoengine.QuerySet.update`.

View File

@ -1228,6 +1228,35 @@ class DocumentTest(unittest.TestCase):
p1.reload()
self.assertEquals(p1.name, p.parent.name)
def test_save_cascade_meta(self):
class Person(Document):
name = StringField()
parent = ReferenceField('self')
meta = {'cascade': False}
Person.drop_collection()
p1 = Person(name="Wilson Snr")
p1.parent = None
p1.save()
p2 = Person(name="Wilson Jr")
p2.parent = p1
p2.save()
p = Person.objects(name="Wilson Jr").get()
p.parent.name = "Daddy Wilson"
p.save()
p1.reload()
self.assertNotEquals(p1.name, p.parent.name)
p.save(cascade=True)
p1.reload()
self.assertEquals(p1.name, p.parent.name)
def test_save_cascades_generically(self):
class Person(Document):