From 803164a993b272dd72860cf5cb5ea1b2464a1aee Mon Sep 17 00:00:00 2001 From: Dan Crosta Date: Mon, 11 Jul 2011 08:08:49 -0400 Subject: [PATCH 1/3] add unique index on User.username --- mongoengine/django/auth.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/mongoengine/django/auth.py b/mongoengine/django/auth.py index 2711ee18..92424909 100644 --- a/mongoengine/django/auth.py +++ b/mongoengine/django/auth.py @@ -32,6 +32,12 @@ class User(Document): last_login = DateTimeField(default=datetime.datetime.now) date_joined = DateTimeField(default=datetime.datetime.now) + meta = { + 'indexes': [ + {'fields': ['username'], 'unique': True} + ] + } + def __unicode__(self): return self.username From 0fb629e24ccb6ce6f2b8c3bf92fd3c239ba5ef11 Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Mon, 11 Jul 2011 16:01:48 +0100 Subject: [PATCH 2/3] Added cascading deletes Also ensured that unsetting works when not the default value of a field --- docs/changelog.rst | 1 + mongoengine/base.py | 19 ++++++++--- mongoengine/document.py | 17 +++++++++- tests/document.py | 70 ++++++++++++++++++++++++++++++++++++----- 4 files changed, 94 insertions(+), 13 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index cad1b687..1b4842e7 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -5,6 +5,7 @@ Changelog Changes in dev ============== +- Added cascading saves - so changes to Referenced documents are saved on .save() - Added select_related() support - Added support for the positional operator - Updated geo index checking to be recursive and check in embedded documents diff --git a/mongoengine/base.py b/mongoengine/base.py index b83164aa..25b049a3 100644 --- a/mongoengine/base.py +++ b/mongoengine/base.py @@ -747,10 +747,21 @@ class BaseDocument(object): if '_id' in set_data: del(set_data['_id']) - for k,v in set_data.items(): - if not v: - del(set_data[k]) - unset_data[k] = 1 + # Determine if any changed items were actually unset. + for path, value in set_data.items(): + if value: + continue + + # If we've set a value that aint the default value save it. + if path in self._fields: + default = self._fields[path].default + if callable(default): + default = default() + if default != value: + continue + + del(set_data[path]) + unset_data[path] = 1 return set_data, unset_data @classmethod diff --git a/mongoengine/document.py b/mongoengine/document.py index c653c8fb..6ccda997 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -112,7 +112,7 @@ class Document(BaseDocument): self._collection = db[collection_name] return self._collection - def save(self, safe=True, force_insert=False, validate=True, write_options=None): + def save(self, safe=True, force_insert=False, validate=True, write_options=None, _refs=None): """Save the :class:`~mongoengine.Document` to the database. If the document already exists, it will be updated, otherwise it will be created. @@ -131,6 +131,8 @@ class Document(BaseDocument): For example, ``save(..., w=2, fsync=True)`` will wait until at least two servers have recorded the write and will force an fsync on each server being written to. """ + from fields import ReferenceField, GenericReferenceField + signals.pre_save.send(self.__class__, document=self) if validate: @@ -140,6 +142,7 @@ class Document(BaseDocument): write_options = {} doc = self.to_mongo() + created = '_id' not in doc try: collection = self.__class__.objects._collection @@ -154,6 +157,18 @@ class Document(BaseDocument): collection.update({'_id': object_id}, {"$set": updates}, upsert=True, safe=safe, **write_options) if removals: collection.update({'_id': object_id}, {"$unset": removals}, upsert=True, safe=safe, **write_options) + + # Save any references / generic references + _refs = _refs or [] + for name, cls in self._fields.items(): + if isinstance(cls, (ReferenceField, GenericReferenceField)): + ref = getattr(self, name) + if ref and str(ref) not in _refs: + _refs.append(str(ref)) + ref.save(safe=safe, force_insert=force_insert, + validate=validate, write_options=write_options, + _refs=_refs) + except pymongo.errors.OperationFailure, err: message = 'Could not save document (%s)' if u'duplicate key' in unicode(err): diff --git a/tests/document.py b/tests/document.py index 9498cfb2..81670eb0 100644 --- a/tests/document.py +++ b/tests/document.py @@ -1054,11 +1054,11 @@ class DocumentTest(unittest.TestCase): Person.drop_collection() - p1 = Person(name="Wilson Jr") + p1 = Person(name="Wilson Snr") p1.parent = None p1.save() - p2 = Person(name="Wilson Jr2") + p2 = Person(name="Wilson Jr") p2.parent = p1 p2.save() @@ -1071,6 +1071,51 @@ class DocumentTest(unittest.TestCase): p0.name = 'wpjunior' p0.save() + def test_save_cascades(self): + + class Person(Document): + name = StringField() + parent = ReferenceField('self') + + Person.drop_collection() + + p1 = Person(name="Wilson Snr") + p1.parent = None + p1.save() + + p2 = Person(name="Wilson Jr") + p2.parent = p1 + p2.save() + + p = Person.objects(name="Wilson Jr").get() + p.parent.name = "Daddy Wilson" + p.save() + + p1.reload() + self.assertEquals(p1.name, p.parent.name) + + def test_save_cascades_generically(self): + + class Person(Document): + name = StringField() + parent = GenericReferenceField() + + Person.drop_collection() + + p1 = Person(name="Wilson Snr") + p1.save() + + p2 = Person(name="Wilson Jr") + p2.parent = p1 + p2.save() + + p = Person.objects(name="Wilson Jr").get() + p.parent.name = "Daddy Wilson" + p.save() + + p1.reload() + self.assertEquals(p1.name, p.parent.name) + def test_update(self): """Ensure that an existing document is updated instead of be overwritten. """ @@ -1364,22 +1409,31 @@ class DocumentTest(unittest.TestCase): """Ensure save only sets / unsets changed fields """ + class User(self.Person): + active = BooleanField(default=True) + + + User.drop_collection() + # Create person object and save it to the database - person = self.Person(name='Test User', age=30) - person.save() - person.reload() + user = User(name='Test User', age=30, active=True) + user.save() + user.reload() + # Simulated Race condition same_person = self.Person.objects.get() + same_person.active = False + + user.age = 21 + user.save() - person.age = 21 same_person.name = 'User' - - person.save() same_person.save() person = self.Person.objects.get() self.assertEquals(person.name, 'User') self.assertEquals(person.age, 21) + self.assertEquals(person.active, False) def test_delete(self): """Ensure that document may be deleted using the delete method. From 1452d3fac5f2500cda0439a45294b9382c8c2d42 Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Mon, 11 Jul 2011 16:50:31 +0100 Subject: [PATCH 3/3] Fixed item_frequency methods to handle null values [fixes #216] --- mongoengine/queryset.py | 13 ++++++++----- tests/queryset.py | 22 ++++++++++++++++++++++ 2 files changed, 30 insertions(+), 5 deletions(-) diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index 6b110ff0..d533736b 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -1435,7 +1435,7 @@ class QuerySet(object): path = '{{~%(field)s}}'.split('.'); field = this; for (p in path) { field = field[path[p]]; } - if (field.constructor == Array) { + if (field && field.constructor == Array) { field.forEach(function(item) { emit(item, 1); }); @@ -1481,7 +1481,7 @@ class QuerySet(object): db[collection].find(query).forEach(function(doc) { field = doc; for (p in path) { field = field[path[p]]; } - if (field.constructor == Array) { + if (field && field.constructor == Array) { total += field.length; } else { total++; @@ -1497,7 +1497,7 @@ class QuerySet(object): db[collection].find(query).forEach(function(doc) { field = doc; for (p in path) { field = field[path[p]]; } - if (field.constructor == Array) { + if (field && field.constructor == Array) { field.forEach(function(item) { frequencies[item] = inc + (isNaN(frequencies[item]) ? 0: frequencies[item]); }); @@ -1509,8 +1509,11 @@ class QuerySet(object): return frequencies; } """ - - return self.exec_js(freq_func, field, normalize=normalize) + data = self.exec_js(freq_func, field, normalize=normalize) + if 'undefined' in data: + data[None] = data['undefined'] + del(data['undefined']) + return data def __repr__(self): limit = REPR_OUTPUT_SIZE + 1 diff --git a/tests/queryset.py b/tests/queryset.py index c0860b5c..e21db0fa 100644 --- a/tests/queryset.py +++ b/tests/queryset.py @@ -1781,6 +1781,28 @@ class QuerySetTest(unittest.TestCase): test_assertions(exec_js) test_assertions(map_reduce) + def test_item_frequencies_null_values(self): + + class Person(Document): + name = StringField() + city = StringField() + + Person.drop_collection() + + Person(name="Wilson Snr", city="CRB").save() + Person(name="Wilson Jr").save() + + freq = Person.objects.item_frequencies('city') + self.assertEquals(freq, {'CRB': 1.0, None: 1.0}) + freq = Person.objects.item_frequencies('city', normalize=True) + self.assertEquals(freq, {'CRB': 0.5, None: 0.5}) + + + freq = Person.objects.item_frequencies('city', map_reduce=True) + self.assertEquals(freq, {'CRB': 1.0, None: 1.0}) + freq = Person.objects.item_frequencies('city', normalize=True, map_reduce=True) + self.assertEquals(freq, {'CRB': 0.5, None: 0.5}) + def test_average(self): """Ensure that field can be averaged correctly. """