Added cascading deletes

Also ensured that unsetting works when not the default value of a field
This commit is contained in:
Ross Lawley 2011-07-11 16:01:48 +01:00
parent 147e33c3ca
commit 0fb629e24c
4 changed files with 94 additions and 13 deletions

View File

@ -5,6 +5,7 @@ Changelog
Changes in dev Changes in dev
============== ==============
- Added cascading saves - so changes to Referenced documents are saved on .save()
- Added select_related() support - Added select_related() support
- Added support for the positional operator - Added support for the positional operator
- Updated geo index checking to be recursive and check in embedded documents - Updated geo index checking to be recursive and check in embedded documents

View File

@ -747,10 +747,21 @@ class BaseDocument(object):
if '_id' in set_data: if '_id' in set_data:
del(set_data['_id']) del(set_data['_id'])
for k,v in set_data.items(): # Determine if any changed items were actually unset.
if not v: for path, value in set_data.items():
del(set_data[k]) if value:
unset_data[k] = 1 continue
# If we've set a value that aint the default value save it.
if path in self._fields:
default = self._fields[path].default
if callable(default):
default = default()
if default != value:
continue
del(set_data[path])
unset_data[path] = 1
return set_data, unset_data return set_data, unset_data
@classmethod @classmethod

View File

@ -112,7 +112,7 @@ class Document(BaseDocument):
self._collection = db[collection_name] self._collection = db[collection_name]
return self._collection return self._collection
def save(self, safe=True, force_insert=False, validate=True, write_options=None): def save(self, safe=True, force_insert=False, validate=True, write_options=None, _refs=None):
"""Save the :class:`~mongoengine.Document` to the database. If the """Save the :class:`~mongoengine.Document` to the database. If the
document already exists, it will be updated, otherwise it will be document already exists, it will be updated, otherwise it will be
created. created.
@ -131,6 +131,8 @@ class Document(BaseDocument):
For example, ``save(..., w=2, fsync=True)`` will wait until at least two servers For example, ``save(..., w=2, fsync=True)`` will wait until at least two servers
have recorded the write and will force an fsync on each server being written to. have recorded the write and will force an fsync on each server being written to.
""" """
from fields import ReferenceField, GenericReferenceField
signals.pre_save.send(self.__class__, document=self) signals.pre_save.send(self.__class__, document=self)
if validate: if validate:
@ -140,6 +142,7 @@ class Document(BaseDocument):
write_options = {} write_options = {}
doc = self.to_mongo() doc = self.to_mongo()
created = '_id' not in doc created = '_id' not in doc
try: try:
collection = self.__class__.objects._collection collection = self.__class__.objects._collection
@ -154,6 +157,18 @@ class Document(BaseDocument):
collection.update({'_id': object_id}, {"$set": updates}, upsert=True, safe=safe, **write_options) collection.update({'_id': object_id}, {"$set": updates}, upsert=True, safe=safe, **write_options)
if removals: if removals:
collection.update({'_id': object_id}, {"$unset": removals}, upsert=True, safe=safe, **write_options) collection.update({'_id': object_id}, {"$unset": removals}, upsert=True, safe=safe, **write_options)
# Save any references / generic references
_refs = _refs or []
for name, cls in self._fields.items():
if isinstance(cls, (ReferenceField, GenericReferenceField)):
ref = getattr(self, name)
if ref and str(ref) not in _refs:
_refs.append(str(ref))
ref.save(safe=safe, force_insert=force_insert,
validate=validate, write_options=write_options,
_refs=_refs)
except pymongo.errors.OperationFailure, err: except pymongo.errors.OperationFailure, err:
message = 'Could not save document (%s)' message = 'Could not save document (%s)'
if u'duplicate key' in unicode(err): if u'duplicate key' in unicode(err):

View File

@ -1054,11 +1054,11 @@ class DocumentTest(unittest.TestCase):
Person.drop_collection() Person.drop_collection()
p1 = Person(name="Wilson Jr") p1 = Person(name="Wilson Snr")
p1.parent = None p1.parent = None
p1.save() p1.save()
p2 = Person(name="Wilson Jr2") p2 = Person(name="Wilson Jr")
p2.parent = p1 p2.parent = p1
p2.save() p2.save()
@ -1071,6 +1071,51 @@ class DocumentTest(unittest.TestCase):
p0.name = 'wpjunior' p0.name = 'wpjunior'
p0.save() p0.save()
def test_save_cascades(self):
class Person(Document):
name = StringField()
parent = ReferenceField('self')
Person.drop_collection()
p1 = Person(name="Wilson Snr")
p1.parent = None
p1.save()
p2 = Person(name="Wilson Jr")
p2.parent = p1
p2.save()
p = Person.objects(name="Wilson Jr").get()
p.parent.name = "Daddy Wilson"
p.save()
p1.reload()
self.assertEquals(p1.name, p.parent.name)
def test_save_cascades_generically(self):
class Person(Document):
name = StringField()
parent = GenericReferenceField()
Person.drop_collection()
p1 = Person(name="Wilson Snr")
p1.save()
p2 = Person(name="Wilson Jr")
p2.parent = p1
p2.save()
p = Person.objects(name="Wilson Jr").get()
p.parent.name = "Daddy Wilson"
p.save()
p1.reload()
self.assertEquals(p1.name, p.parent.name)
def test_update(self): def test_update(self):
"""Ensure that an existing document is updated instead of be overwritten. """Ensure that an existing document is updated instead of be overwritten.
""" """
@ -1364,22 +1409,31 @@ class DocumentTest(unittest.TestCase):
"""Ensure save only sets / unsets changed fields """Ensure save only sets / unsets changed fields
""" """
class User(self.Person):
active = BooleanField(default=True)
User.drop_collection()
# Create person object and save it to the database # Create person object and save it to the database
person = self.Person(name='Test User', age=30) user = User(name='Test User', age=30, active=True)
person.save() user.save()
person.reload() user.reload()
# Simulated Race condition
same_person = self.Person.objects.get() same_person = self.Person.objects.get()
same_person.active = False
user.age = 21
user.save()
person.age = 21
same_person.name = 'User' same_person.name = 'User'
person.save()
same_person.save() same_person.save()
person = self.Person.objects.get() person = self.Person.objects.get()
self.assertEquals(person.name, 'User') self.assertEquals(person.name, 'User')
self.assertEquals(person.age, 21) self.assertEquals(person.age, 21)
self.assertEquals(person.active, False)
def test_delete(self): def test_delete(self):
"""Ensure that document may be deleted using the delete method. """Ensure that document may be deleted using the delete method.