Add Document.modify() method

This commit is contained in:
Dmitry Konishchev 2014-09-30 15:15:11 +04:00
parent aa28abd517
commit 4752f9aa37
2 changed files with 106 additions and 4 deletions

View File

@ -12,7 +12,7 @@ from mongoengine.common import _import_class
from mongoengine.base import (DocumentMetaclass, TopLevelDocumentMetaclass, from mongoengine.base import (DocumentMetaclass, TopLevelDocumentMetaclass,
BaseDocument, BaseDict, BaseList, BaseDocument, BaseDict, BaseList,
ALLOW_INHERITANCE, get_document) ALLOW_INHERITANCE, get_document)
from mongoengine.errors import ValidationError from mongoengine.errors import ValidationError, InvalidQueryError, InvalidDocumentError
from mongoengine.queryset import (OperationError, NotUniqueError, from mongoengine.queryset import (OperationError, NotUniqueError,
QuerySet, transform) QuerySet, transform)
from mongoengine.connection import get_db, DEFAULT_CONNECTION_NAME from mongoengine.connection import get_db, DEFAULT_CONNECTION_NAME
@ -192,6 +192,44 @@ class Document(BaseDocument):
cls.ensure_indexes() cls.ensure_indexes()
return cls._collection return cls._collection
def modify(self, query={}, **update):
"""Perform an atomic update of the document in the database and reload
the document object using updated version.
Returns True if the document has been updated or False if the document
in the database doesn't match the query.
.. note:: All unsaved changes that has been made to the document are
rejected if the method returns True.
:param query: the update will be performed only if the document in the
database matches the query
:param update: Django-style update keyword arguments
"""
if self.pk is None:
raise InvalidDocumentError("The document does not have a primary key.")
id_field = self._meta["id_field"]
query = query.copy() if isinstance(query, dict) else query.to_query(self)
if id_field not in query:
query[id_field] = self.pk
elif query[id_field] != self.pk:
raise InvalidQueryError("Invalid document modify query: it must modify only this document.")
updated = self._qs(**query).modify(new=True, **update)
if updated is None:
return False
for field in self._fields_ordered:
setattr(self, field, self._reload(field, updated[field]))
self._changed_fields = updated._changed_fields
self._created = False
return True
def save(self, force_insert=False, validate=True, clean=True, def save(self, force_insert=False, validate=True, clean=True,
write_concern=None, cascade=None, cascade_kwargs=None, write_concern=None, cascade=None, cascade_kwargs=None,
_refs=None, save_condition=None, **kwargs): _refs=None, save_condition=None, **kwargs):

View File

@ -9,7 +9,7 @@ import unittest
import uuid import uuid
from datetime import datetime from datetime import datetime
from bson import DBRef from bson import DBRef, ObjectId
from tests.fixtures import (PickleEmbedded, PickleTest, PickleSignalsTest, from tests.fixtures import (PickleEmbedded, PickleTest, PickleSignalsTest,
PickleDyanmicEmbedded, PickleDynamicTest) PickleDyanmicEmbedded, PickleDynamicTest)
@ -34,15 +34,21 @@ class InstanceTest(unittest.TestCase):
connect(db='mongoenginetest') connect(db='mongoenginetest')
self.db = get_db() self.db = get_db()
class Job(EmbeddedDocument):
name = StringField()
years = IntField()
class Person(Document): class Person(Document):
name = StringField() name = StringField()
age = IntField() age = IntField()
job = EmbeddedDocumentField(Job)
non_field = True non_field = True
meta = {"allow_inheritance": True} meta = {"allow_inheritance": True}
self.Person = Person self.Person = Person
self.Job = Job
def tearDown(self): def tearDown(self):
for collection in self.db.collection_names(): for collection in self.db.collection_names():
@ -50,6 +56,9 @@ class InstanceTest(unittest.TestCase):
continue continue
self.db.drop_collection(collection) self.db.drop_collection(collection)
def assertDbEqual(self, docs):
self.assertEqual(list(self.Person._get_collection().find().sort("id")), sorted(docs, key=lambda doc: doc["_id"]))
def test_capped_collection(self): def test_capped_collection(self):
"""Ensure that capped collections work properly. """Ensure that capped collections work properly.
""" """
@ -452,7 +461,7 @@ class InstanceTest(unittest.TestCase):
def test_dictionary_access(self): def test_dictionary_access(self):
"""Ensure that dictionary-style field access works properly. """Ensure that dictionary-style field access works properly.
""" """
person = self.Person(name='Test User', age=30) person = self.Person(name='Test User', age=30, job=self.Job())
self.assertEqual(person['name'], 'Test User') self.assertEqual(person['name'], 'Test User')
self.assertRaises(KeyError, person.__getitem__, 'salary') self.assertRaises(KeyError, person.__getitem__, 'salary')
@ -462,7 +471,7 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(person['name'], 'Another User') self.assertEqual(person['name'], 'Another User')
# Length = length(assigned fields + id) # Length = length(assigned fields + id)
self.assertEqual(len(person), 4) self.assertEqual(len(person), 5)
self.assertTrue('age' in person) self.assertTrue('age' in person)
person.age = None person.age = None
@ -617,6 +626,61 @@ class InstanceTest(unittest.TestCase):
t = TestDocument(doc=TestEmbeddedDocument(x=15, y=35, z=5)) t = TestDocument(doc=TestEmbeddedDocument(x=15, y=35, z=5))
t.save(clean=False) t.save(clean=False)
def test_modify_empty(self):
doc = self.Person(id=ObjectId(), name="bob", age=10).save()
self.assertRaises(InvalidDocumentError, lambda: self.Person().modify(set__age=10))
self.assertDbEqual([dict(doc.to_mongo())])
def test_modify_invalid_query(self):
doc1 = self.Person(id=ObjectId(), name="bob", age=10).save()
doc2 = self.Person(id=ObjectId(), name="jim", age=20).save()
docs = [dict(doc1.to_mongo()), dict(doc2.to_mongo())]
self.assertRaises(InvalidQueryError, lambda:
doc1.modify(dict(id=doc2.id), set__value=20))
self.assertDbEqual(docs)
def test_modify_match_another_document(self):
doc1 = self.Person(id=ObjectId(), name="bob", age=10).save()
doc2 = self.Person(id=ObjectId(), name="jim", age=20).save()
docs = [dict(doc1.to_mongo()), dict(doc2.to_mongo())]
assert not doc1.modify(dict(name=doc2.name), set__age=100)
self.assertDbEqual(docs)
def test_modify_not_exists(self):
doc1 = self.Person(id=ObjectId(), name="bob", age=10).save()
doc2 = self.Person(id=ObjectId(), name="jim", age=20)
docs = [dict(doc1.to_mongo())]
assert not doc2.modify(dict(name=doc2.name), set__age=100)
self.assertDbEqual(docs)
def test_modify_update(self):
other_doc = self.Person(id=ObjectId(), name="bob", age=10).save()
doc = self.Person(id=ObjectId(), name="jim", age=20,
job=self.Job(name="10gen", years=3)).save()
doc_copy = doc._from_son(doc.to_mongo())
# these changes must go away
doc.name = "liza"
doc.job.name = "Google"
doc.job.years = 3
assert doc.modify(set__age=21, set__job__name="MongoDB", unset__job__years=True)
doc_copy.age = 21
doc_copy.job.name = "MongoDB"
del doc_copy.job.years
assert doc.to_json() == doc_copy.to_json()
assert doc._get_changed_fields() == []
self.assertDbEqual([dict(other_doc.to_mongo()), dict(doc.to_mongo())])
def test_save(self): def test_save(self):
"""Ensure that a document may be saved in the database. """Ensure that a document may be saved in the database.
""" """