diff --git a/docs/changelog.rst b/docs/changelog.rst index b8467381..777cab8d 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -37,7 +37,7 @@ Changes in 0.9.X - DEV - Removing support for Django 1.4.x, pymongo 2.5.x, pymongo 2.6.x. - Removing support for Python < 2.6.6 - Fixed $maxDistance location for geoJSON $near queries with MongoDB 2.6+ #664 -- QuerySet.modify() method to provide find_and_modify() like behaviour #677 +- QuerySet.modify() and Document.modify() methods to provide find_and_modify() like behaviour #677 #773 - Added support for the using() method on a queryset #676 - PYPY support #673 - Connection pooling #674 diff --git a/docs/guide/querying.rst b/docs/guide/querying.rst index 77e35df0..bb8176fe 100644 --- a/docs/guide/querying.rst +++ b/docs/guide/querying.rst @@ -491,8 +491,11 @@ Documents may be updated atomically by using the :meth:`~mongoengine.queryset.QuerySet.update_one`, :meth:`~mongoengine.queryset.QuerySet.update` and :meth:`~mongoengine.queryset.QuerySet.modify` methods on a -:meth:`~mongoengine.queryset.QuerySet`. There are several different "modifiers" -that you may use with these methods: +:class:`~mongoengine.queryset.QuerySet` or +:meth:`~mongoengine.Document.modify` and +:meth:`~mongoengine.Document.save` (with :attr:`save_condition` argument) on a +:class:`~mongoengine.Document`. +There are several different "modifiers" that you may use with these methods: * ``set`` -- set a particular value * ``unset`` -- delete a particular value (since MongoDB v1.3+) diff --git a/mongoengine/document.py b/mongoengine/document.py index 34bbb9f6..2eab83ef 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -12,7 +12,7 @@ from mongoengine.common import _import_class from mongoengine.base import (DocumentMetaclass, TopLevelDocumentMetaclass, BaseDocument, BaseDict, BaseList, ALLOW_INHERITANCE, get_document) -from mongoengine.errors import ValidationError +from mongoengine.errors import ValidationError, InvalidQueryError, InvalidDocumentError from mongoengine.queryset import (OperationError, NotUniqueError, QuerySet, transform) from mongoengine.connection import get_db, DEFAULT_CONNECTION_NAME @@ -192,6 +192,44 @@ class Document(BaseDocument): cls.ensure_indexes() return cls._collection + def modify(self, query={}, **update): + """Perform an atomic update of the document in the database and reload + the document object using updated version. + + Returns True if the document has been updated or False if the document + in the database doesn't match the query. + + .. note:: All unsaved changes that has been made to the document are + rejected if the method returns True. + + :param query: the update will be performed only if the document in the + database matches the query + :param update: Django-style update keyword arguments + """ + + if self.pk is None: + raise InvalidDocumentError("The document does not have a primary key.") + + id_field = self._meta["id_field"] + query = query.copy() if isinstance(query, dict) else query.to_query(self) + + if id_field not in query: + query[id_field] = self.pk + elif query[id_field] != self.pk: + raise InvalidQueryError("Invalid document modify query: it must modify only this document.") + + updated = self._qs(**query).modify(new=True, **update) + if updated is None: + return False + + for field in self._fields_ordered: + setattr(self, field, self._reload(field, updated[field])) + + self._changed_fields = updated._changed_fields + self._created = False + + return True + def save(self, force_insert=False, validate=True, clean=True, write_concern=None, cascade=None, cascade_kwargs=None, _refs=None, save_condition=None, **kwargs): diff --git a/tests/document/instance.py b/tests/document/instance.py index c226b614..360d5385 100644 --- a/tests/document/instance.py +++ b/tests/document/instance.py @@ -9,7 +9,7 @@ import unittest import uuid from datetime import datetime -from bson import DBRef +from bson import DBRef, ObjectId from tests.fixtures import (PickleEmbedded, PickleTest, PickleSignalsTest, PickleDyanmicEmbedded, PickleDynamicTest) @@ -34,15 +34,21 @@ class InstanceTest(unittest.TestCase): connect(db='mongoenginetest') self.db = get_db() + class Job(EmbeddedDocument): + name = StringField() + years = IntField() + class Person(Document): name = StringField() age = IntField() + job = EmbeddedDocumentField(Job) non_field = True meta = {"allow_inheritance": True} self.Person = Person + self.Job = Job def tearDown(self): for collection in self.db.collection_names(): @@ -50,6 +56,9 @@ class InstanceTest(unittest.TestCase): continue self.db.drop_collection(collection) + def assertDbEqual(self, docs): + self.assertEqual(list(self.Person._get_collection().find().sort("id")), sorted(docs, key=lambda doc: doc["_id"])) + def test_capped_collection(self): """Ensure that capped collections work properly. """ @@ -452,7 +461,7 @@ class InstanceTest(unittest.TestCase): def test_dictionary_access(self): """Ensure that dictionary-style field access works properly. """ - person = self.Person(name='Test User', age=30) + person = self.Person(name='Test User', age=30, job=self.Job()) self.assertEqual(person['name'], 'Test User') self.assertRaises(KeyError, person.__getitem__, 'salary') @@ -462,7 +471,7 @@ class InstanceTest(unittest.TestCase): self.assertEqual(person['name'], 'Another User') # Length = length(assigned fields + id) - self.assertEqual(len(person), 4) + self.assertEqual(len(person), 5) self.assertTrue('age' in person) person.age = None @@ -617,6 +626,60 @@ class InstanceTest(unittest.TestCase): t = TestDocument(doc=TestEmbeddedDocument(x=15, y=35, z=5)) t.save(clean=False) + def test_modify_empty(self): + doc = self.Person(name="bob", age=10).save() + self.assertRaises(InvalidDocumentError, lambda: self.Person().modify(set__age=10)) + self.assertDbEqual([dict(doc.to_mongo())]) + + def test_modify_invalid_query(self): + doc1 = self.Person(name="bob", age=10).save() + doc2 = self.Person(name="jim", age=20).save() + docs = [dict(doc1.to_mongo()), dict(doc2.to_mongo())] + + self.assertRaises(InvalidQueryError, lambda: + doc1.modify(dict(id=doc2.id), set__value=20)) + + self.assertDbEqual(docs) + + def test_modify_match_another_document(self): + doc1 = self.Person(name="bob", age=10).save() + doc2 = self.Person(name="jim", age=20).save() + docs = [dict(doc1.to_mongo()), dict(doc2.to_mongo())] + + assert not doc1.modify(dict(name=doc2.name), set__age=100) + + self.assertDbEqual(docs) + + def test_modify_not_exists(self): + doc1 = self.Person(name="bob", age=10).save() + doc2 = self.Person(id=ObjectId(), name="jim", age=20) + docs = [dict(doc1.to_mongo())] + + assert not doc2.modify(dict(name=doc2.name), set__age=100) + + self.assertDbEqual(docs) + + def test_modify_update(self): + other_doc = self.Person(name="bob", age=10).save() + doc = self.Person(name="jim", age=20, job=self.Job(name="10gen", years=3)).save() + + doc_copy = doc._from_son(doc.to_mongo()) + + # these changes must go away + doc.name = "liza" + doc.job.name = "Google" + doc.job.years = 3 + + assert doc.modify(set__age=21, set__job__name="MongoDB", unset__job__years=True) + doc_copy.age = 21 + doc_copy.job.name = "MongoDB" + del doc_copy.job.years + + assert doc.to_json() == doc_copy.to_json() + assert doc._get_changed_fields() == [] + + self.assertDbEqual([dict(other_doc.to_mongo()), dict(doc.to_mongo())]) + def test_save(self): """Ensure that a document may be saved in the database. """