diff --git a/docs/changelog.rst b/docs/changelog.rst index bce26c5f..c9cdda52 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -5,6 +5,7 @@ Changelog Changes in 0.8.2 ================ +- Improved cascading saves write performance (#361) - Fixed amibiguity and differing behaviour regarding field defaults (#349) - ImageFields now include PIL error messages if invalid error (#353) - Added lock when calling doc.Delete() for when signals have no sender (#350) diff --git a/mongoengine/base/metaclasses.py b/mongoengine/base/metaclasses.py index 444d9a25..651228da 100644 --- a/mongoengine/base/metaclasses.py +++ b/mongoengine/base/metaclasses.py @@ -97,6 +97,14 @@ class DocumentMetaclass(type): attrs['_reverse_db_field_map'] = dict( (v, k) for k, v in attrs['_db_field_map'].iteritems()) + # Set cascade flag if not set + if 'cascade' not in attrs['_meta']: + ReferenceField = _import_class('ReferenceField') + GenericReferenceField = _import_class('GenericReferenceField') + cascade = any([isinstance(x, (ReferenceField, GenericReferenceField)) + for x in doc_fields.values()]) + attrs['_meta']['cascade'] = cascade + # # Set document hierarchy # diff --git a/mongoengine/document.py b/mongoengine/document.py index 8e7ccc20..2bdecb77 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -8,6 +8,7 @@ from pymongo.read_preferences import ReadPreference from bson import ObjectId from bson.dbref import DBRef from mongoengine import signals +from mongoengine.common import _import_class from mongoengine.base import (DocumentMetaclass, TopLevelDocumentMetaclass, BaseDocument, BaseDict, BaseList, ALLOW_INHERITANCE, get_document) @@ -284,15 +285,17 @@ class Document(BaseDocument): def cascade_save(self, *args, **kwargs): """Recursively saves any references / generic references on an objects""" - import fields _refs = kwargs.get('_refs', []) or [] + ReferenceField = _import_class('ReferenceField') + GenericReferenceField = _import_class('GenericReferenceField') + for name, cls in self._fields.items(): - if not isinstance(cls, (fields.ReferenceField, - fields.GenericReferenceField)): + if not isinstance(cls, (ReferenceField, + GenericReferenceField)): continue - ref = getattr(self, name) + ref = self._data.get(name) if not ref or isinstance(ref, DBRef): continue diff --git a/tests/document/instance.py b/tests/document/instance.py index cdc6fe08..f4abd027 100644 --- a/tests/document/instance.py +++ b/tests/document/instance.py @@ -646,6 +646,22 @@ class InstanceTest(unittest.TestCase): self.assertEqual(b.picture, b.bar.picture, b.bar.bar.picture) + def test_setting_cascade(self): + + class ForcedCascade(Document): + meta = {'cascade': True} + + class Feed(Document): + name = StringField() + + class Subscription(Document): + name = StringField() + feed = ReferenceField(Feed) + + self.assertTrue(ForcedCascade._meta['cascade']) + self.assertTrue(Subscription._meta['cascade']) + self.assertFalse(Feed._meta['cascade']) + def test_save_cascades(self): class Person(Document): @@ -1018,6 +1034,96 @@ class InstanceTest(unittest.TestCase): self.assertEqual(person.age, 21) self.assertEqual(person.active, False) + def test_query_count_when_saving(self): + """Ensure references don't cause extra fetches when saving""" + class Organization(Document): + name = StringField() + + class User(Document): + name = StringField() + orgs = ListField(ReferenceField('Organization')) + + class Feed(Document): + name = StringField() + + class UserSubscription(Document): + name = StringField() + user = ReferenceField(User) + feed = ReferenceField(Feed) + + Organization.drop_collection() + User.drop_collection() + Feed.drop_collection() + UserSubscription.drop_collection() + + self.assertTrue(UserSubscription._meta['cascade']) + + o1 = Organization(name="o1").save() + o2 = Organization(name="o2").save() + + u1 = User(name="Ross", orgs=[o1, o2]).save() + f1 = Feed(name="MongoEngine").save() + + sub = UserSubscription(user=u1, feed=f1).save() + + user = User.objects.first() + # Even if stored as ObjectId's internally mongoengine uses DBRefs + # As ObjectId's aren't automatically derefenced + self.assertTrue(isinstance(user._data['orgs'][0], DBRef)) + self.assertTrue(isinstance(user.orgs[0], Organization)) + self.assertTrue(isinstance(user._data['orgs'][0], Organization)) + + # Changing a value + with query_counter() as q: + self.assertEqual(q, 0) + sub = UserSubscription.objects.first() + self.assertEqual(q, 1) + sub.name = "Test Sub" + sub.save() + self.assertEqual(q, 2) + + # Changing a value that will cascade + with query_counter() as q: + self.assertEqual(q, 0) + sub = UserSubscription.objects.first() + self.assertEqual(q, 1) + sub.user.name = "Test" + self.assertEqual(q, 2) + sub.save() + self.assertEqual(q, 3) + + # Changing a value and one that will cascade + with query_counter() as q: + self.assertEqual(q, 0) + sub = UserSubscription.objects.first() + sub.name = "Test Sub 2" + self.assertEqual(q, 1) + sub.user.name = "Test 2" + self.assertEqual(q, 2) + sub.save() + self.assertEqual(q, 4) # One for the UserSub and one for the User + + # Saving with just the refs + with query_counter() as q: + self.assertEqual(q, 0) + sub = UserSubscription(user=u1.pk, feed=f1.pk) + sub.validate() + self.assertEqual(q, 0) # Check no change + sub.save() + self.assertEqual(q, 1) + + # Saving new objects + with query_counter() as q: + self.assertEqual(q, 0) + user = User.objects.first() + self.assertEqual(q, 1) + feed = Feed.objects.first() + self.assertEqual(q, 2) + sub = UserSubscription(user=user, feed=feed) + self.assertEqual(q, 2) # Check no change + sub.save() + self.assertEqual(q, 3) + def test_set_unset_one_operation(self): """Ensure that $set and $unset actions are performed in the same operation.