Added cascading deletes

Also ensured that unsetting works when not the default value of a field
This commit is contained in:
Ross Lawley 2011-07-11 16:01:48 +01:00
parent 147e33c3ca
commit 0fb629e24c
4 changed files with 94 additions and 13 deletions

View File

@ -5,6 +5,7 @@ Changelog
Changes in dev
==============
- Added cascading saves - so changes to Referenced documents are saved on .save()
- Added select_related() support
- Added support for the positional operator
- Updated geo index checking to be recursive and check in embedded documents

View File

@ -747,10 +747,21 @@ class BaseDocument(object):
if '_id' in set_data:
del(set_data['_id'])
for k,v in set_data.items():
if not v:
del(set_data[k])
unset_data[k] = 1
# Determine if any changed items were actually unset.
for path, value in set_data.items():
if value:
continue
# If we've set a value that aint the default value save it.
if path in self._fields:
default = self._fields[path].default
if callable(default):
default = default()
if default != value:
continue
del(set_data[path])
unset_data[path] = 1
return set_data, unset_data
@classmethod

View File

@ -112,7 +112,7 @@ class Document(BaseDocument):
self._collection = db[collection_name]
return self._collection
def save(self, safe=True, force_insert=False, validate=True, write_options=None):
def save(self, safe=True, force_insert=False, validate=True, write_options=None, _refs=None):
"""Save the :class:`~mongoengine.Document` to the database. If the
document already exists, it will be updated, otherwise it will be
created.
@ -131,6 +131,8 @@ class Document(BaseDocument):
For example, ``save(..., w=2, fsync=True)`` will wait until at least two servers
have recorded the write and will force an fsync on each server being written to.
"""
from fields import ReferenceField, GenericReferenceField
signals.pre_save.send(self.__class__, document=self)
if validate:
@ -140,6 +142,7 @@ class Document(BaseDocument):
write_options = {}
doc = self.to_mongo()
created = '_id' not in doc
try:
collection = self.__class__.objects._collection
@ -154,6 +157,18 @@ class Document(BaseDocument):
collection.update({'_id': object_id}, {"$set": updates}, upsert=True, safe=safe, **write_options)
if removals:
collection.update({'_id': object_id}, {"$unset": removals}, upsert=True, safe=safe, **write_options)
# Save any references / generic references
_refs = _refs or []
for name, cls in self._fields.items():
if isinstance(cls, (ReferenceField, GenericReferenceField)):
ref = getattr(self, name)
if ref and str(ref) not in _refs:
_refs.append(str(ref))
ref.save(safe=safe, force_insert=force_insert,
validate=validate, write_options=write_options,
_refs=_refs)
except pymongo.errors.OperationFailure, err:
message = 'Could not save document (%s)'
if u'duplicate key' in unicode(err):

View File

@ -1054,11 +1054,11 @@ class DocumentTest(unittest.TestCase):
Person.drop_collection()
p1 = Person(name="Wilson Jr")
p1 = Person(name="Wilson Snr")
p1.parent = None
p1.save()
p2 = Person(name="Wilson Jr2")
p2 = Person(name="Wilson Jr")
p2.parent = p1
p2.save()
@ -1071,6 +1071,51 @@ class DocumentTest(unittest.TestCase):
p0.name = 'wpjunior'
p0.save()
def test_save_cascades(self):
class Person(Document):
name = StringField()
parent = ReferenceField('self')
Person.drop_collection()
p1 = Person(name="Wilson Snr")
p1.parent = None
p1.save()
p2 = Person(name="Wilson Jr")
p2.parent = p1
p2.save()
p = Person.objects(name="Wilson Jr").get()
p.parent.name = "Daddy Wilson"
p.save()
p1.reload()
self.assertEquals(p1.name, p.parent.name)
def test_save_cascades_generically(self):
class Person(Document):
name = StringField()
parent = GenericReferenceField()
Person.drop_collection()
p1 = Person(name="Wilson Snr")
p1.save()
p2 = Person(name="Wilson Jr")
p2.parent = p1
p2.save()
p = Person.objects(name="Wilson Jr").get()
p.parent.name = "Daddy Wilson"
p.save()
p1.reload()
self.assertEquals(p1.name, p.parent.name)
def test_update(self):
"""Ensure that an existing document is updated instead of be overwritten.
"""
@ -1364,22 +1409,31 @@ class DocumentTest(unittest.TestCase):
"""Ensure save only sets / unsets changed fields
"""
class User(self.Person):
active = BooleanField(default=True)
User.drop_collection()
# Create person object and save it to the database
person = self.Person(name='Test User', age=30)
person.save()
person.reload()
user = User(name='Test User', age=30, active=True)
user.save()
user.reload()
# Simulated Race condition
same_person = self.Person.objects.get()
same_person.active = False
user.age = 21
user.save()
person.age = 21
same_person.name = 'User'
person.save()
same_person.save()
person = self.Person.objects.get()
self.assertEquals(person.name, 'User')
self.assertEquals(person.age, 21)
self.assertEquals(person.active, False)
def test_delete(self):
"""Ensure that document may be deleted using the delete method.