From fd7f882011ce548efd7ae5fcb0f59fd38d38e98b Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Thu, 9 Jun 2011 16:09:06 +0100 Subject: [PATCH] Save no longer tramples over documents now sets or unsets explicit fields. Fixes #146, refs #18 Thanks @zhangcheng for the initial code --- docs/changelog.rst | 5 ++- mongoengine/base.py | 9 +++-- mongoengine/document.py | 10 +++++ setup.py | 2 +- tests/document.py | 84 +++++++++++++++++++++++++++++++++++++++++ 5 files changed, 104 insertions(+), 6 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 0bbb5b82..ecd7ef57 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -5,9 +5,10 @@ Changelog Changes in dev ============== +- Fixed saving so sets updated values rather than overwrites - Added ComplexDateTimeField - Handles datetimes correctly with microseconds -- Added ComplexBaseField - for improved flexibility and performance. -- Added get_FIELD_display() method for easy choice field displaying. +- Added ComplexBaseField - for improved flexibility and performance +- Added get_FIELD_display() method for easy choice field displaying - Added queryset.slave_okay(enabled) method - Updated queryset.timeout(enabled) and queryset.snapshot(enabled) to be chainable - Added insert method for bulk inserts diff --git a/mongoengine/base.py b/mongoengine/base.py index a22795c7..aed17bc3 100644 --- a/mongoengine/base.py +++ b/mongoengine/base.py @@ -92,6 +92,9 @@ class BaseField(object): """Descriptor for assigning a value to a field in a document. """ instance._data[self.name] = value + # If the field set is in the _present_fields list add it so we can track + if hasattr(instance, '_present_fields') and self.name not in instance._present_fields: + instance._present_fields.append(self.name) def to_python(self, value): """Convert a MongoDB-compatible type to a Python type. @@ -592,13 +595,14 @@ class BaseDocument(object): if field.choices: # dynamically adds a way to get the display value for a field with choices setattr(self, 'get_%s_display' % attr_name, partial(self._get_FIELD_display, field=field)) - # Use default value if present value = getattr(self, attr_name, None) setattr(self, attr_name, value) + # Assign initial values to instance for attr_name in values.keys(): try: - setattr(self, attr_name, values.pop(attr_name)) + value = values.pop(attr_name) + setattr(self, attr_name, value) except AttributeError: pass @@ -739,7 +743,6 @@ class BaseDocument(object): cls = subclasses[class_name] present_fields = data.keys() - for field_name, field in cls._fields.items(): if field.db_field in data: value = data[field.db_field] diff --git a/mongoengine/document.py b/mongoengine/document.py index cae8343d..e25bea06 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -95,6 +95,16 @@ class Document(BaseDocument): collection = self.__class__.objects._collection if force_insert: object_id = collection.insert(doc, safe=safe, **write_options) + elif '_id' in doc: + # Perform a set rather than a save - this will only save set fields + object_id = doc.pop('_id') + collection.update({'_id': object_id}, {"$set": doc}, upsert=True, safe=safe, **write_options) + + # Find and unset any fields explicitly set to None + if hasattr(self, '_present_fields'): + removals = dict([(k, 1) for k in self._present_fields if k not in doc and k != '_id']) + if removals: + collection.update({'_id': object_id}, {"$unset": removals}, upsert=True, safe=safe, **write_options) else: object_id = collection.save(doc, safe=safe, **write_options) except pymongo.errors.OperationFailure, err: diff --git a/setup.py b/setup.py index 1f65ae5d..37ec4375 100644 --- a/setup.py +++ b/setup.py @@ -45,6 +45,6 @@ setup(name='mongoengine', long_description=LONG_DESCRIPTION, platforms=['any'], classifiers=CLASSIFIERS, - install_requires=['pymongo', 'blinker', 'django>=1.3'], + install_requires=['pymongo', 'blinker', 'django==1.3'], test_suite='tests', ) diff --git a/tests/document.py b/tests/document.py index 14541469..f0af8f2d 100644 --- a/tests/document.py +++ b/tests/document.py @@ -789,6 +789,90 @@ class DocumentTest(unittest.TestCase): except ValidationError: self.fail() + def test_update(self): + """Ensure that an existing document is updated instead of be overwritten. + """ + # Create person object and save it to the database + person = self.Person(name='Test User', age=30) + person.save() + + # Create same person object, with same id, without age + same_person = self.Person(name='Test') + same_person.id = person.id + same_person.save() + + # Confirm only one object + self.assertEquals(self.Person.objects.count(), 1) + + # reload + person.reload() + same_person.reload() + + # Confirm the same + self.assertEqual(person, same_person) + self.assertEqual(person.name, same_person.name) + self.assertEqual(person.age, same_person.age) + + # Confirm the saved values + self.assertEqual(person.name, 'Test') + self.assertEqual(person.age, 30) + + # Test only / exclude only updates included fields + person = self.Person.objects.only('name').get() + person.name = 'User' + person.save() + + person.reload() + self.assertEqual(person.name, 'User') + self.assertEqual(person.age, 30) + + # test exclude only updates set fields + person = self.Person.objects.exclude('name').get() + person.age = 21 + person.save() + + person.reload() + self.assertEqual(person.name, 'User') + self.assertEqual(person.age, 21) + + # Test only / exclude can set non excluded / included fields + person = self.Person.objects.only('name').get() + person.name = 'Test' + person.age = 30 + person.save() + + person.reload() + self.assertEqual(person.name, 'Test') + self.assertEqual(person.age, 30) + + # test exclude only updates set fields + person = self.Person.objects.exclude('name').get() + person.name = 'User' + person.age = 21 + person.save() + + person.reload() + self.assertEqual(person.name, 'User') + self.assertEqual(person.age, 21) + + # Confirm does remove unrequired fields + person = self.Person.objects.exclude('name').get() + person.age = None + person.save() + + person.reload() + self.assertEqual(person.name, 'User') + self.assertEqual(person.age, None) + + person = self.Person.objects.get() + person.name = None + person.age = None + person.save() + + person.reload() + self.assertEqual(person.name, None) + self.assertEqual(person.age, None) + def test_delete(self): """Ensure that document may be deleted using the delete method. """