Add Document.modify() method
This commit is contained in:
parent
aa28abd517
commit
4752f9aa37
@ -12,7 +12,7 @@ from mongoengine.common import _import_class
|
|||||||
from mongoengine.base import (DocumentMetaclass, TopLevelDocumentMetaclass,
|
from mongoengine.base import (DocumentMetaclass, TopLevelDocumentMetaclass,
|
||||||
BaseDocument, BaseDict, BaseList,
|
BaseDocument, BaseDict, BaseList,
|
||||||
ALLOW_INHERITANCE, get_document)
|
ALLOW_INHERITANCE, get_document)
|
||||||
from mongoengine.errors import ValidationError
|
from mongoengine.errors import ValidationError, InvalidQueryError, InvalidDocumentError
|
||||||
from mongoengine.queryset import (OperationError, NotUniqueError,
|
from mongoengine.queryset import (OperationError, NotUniqueError,
|
||||||
QuerySet, transform)
|
QuerySet, transform)
|
||||||
from mongoengine.connection import get_db, DEFAULT_CONNECTION_NAME
|
from mongoengine.connection import get_db, DEFAULT_CONNECTION_NAME
|
||||||
@ -192,6 +192,44 @@ class Document(BaseDocument):
|
|||||||
cls.ensure_indexes()
|
cls.ensure_indexes()
|
||||||
return cls._collection
|
return cls._collection
|
||||||
|
|
||||||
|
def modify(self, query={}, **update):
|
||||||
|
"""Perform an atomic update of the document in the database and reload
|
||||||
|
the document object using updated version.
|
||||||
|
|
||||||
|
Returns True if the document has been updated or False if the document
|
||||||
|
in the database doesn't match the query.
|
||||||
|
|
||||||
|
.. note:: All unsaved changes that has been made to the document are
|
||||||
|
rejected if the method returns True.
|
||||||
|
|
||||||
|
:param query: the update will be performed only if the document in the
|
||||||
|
database matches the query
|
||||||
|
:param update: Django-style update keyword arguments
|
||||||
|
"""
|
||||||
|
|
||||||
|
if self.pk is None:
|
||||||
|
raise InvalidDocumentError("The document does not have a primary key.")
|
||||||
|
|
||||||
|
id_field = self._meta["id_field"]
|
||||||
|
query = query.copy() if isinstance(query, dict) else query.to_query(self)
|
||||||
|
|
||||||
|
if id_field not in query:
|
||||||
|
query[id_field] = self.pk
|
||||||
|
elif query[id_field] != self.pk:
|
||||||
|
raise InvalidQueryError("Invalid document modify query: it must modify only this document.")
|
||||||
|
|
||||||
|
updated = self._qs(**query).modify(new=True, **update)
|
||||||
|
if updated is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
for field in self._fields_ordered:
|
||||||
|
setattr(self, field, self._reload(field, updated[field]))
|
||||||
|
|
||||||
|
self._changed_fields = updated._changed_fields
|
||||||
|
self._created = False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
def save(self, force_insert=False, validate=True, clean=True,
|
def save(self, force_insert=False, validate=True, clean=True,
|
||||||
write_concern=None, cascade=None, cascade_kwargs=None,
|
write_concern=None, cascade=None, cascade_kwargs=None,
|
||||||
_refs=None, save_condition=None, **kwargs):
|
_refs=None, save_condition=None, **kwargs):
|
||||||
|
@ -9,7 +9,7 @@ import unittest
|
|||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from bson import DBRef
|
from bson import DBRef, ObjectId
|
||||||
from tests.fixtures import (PickleEmbedded, PickleTest, PickleSignalsTest,
|
from tests.fixtures import (PickleEmbedded, PickleTest, PickleSignalsTest,
|
||||||
PickleDyanmicEmbedded, PickleDynamicTest)
|
PickleDyanmicEmbedded, PickleDynamicTest)
|
||||||
|
|
||||||
@ -34,15 +34,21 @@ class InstanceTest(unittest.TestCase):
|
|||||||
connect(db='mongoenginetest')
|
connect(db='mongoenginetest')
|
||||||
self.db = get_db()
|
self.db = get_db()
|
||||||
|
|
||||||
|
class Job(EmbeddedDocument):
|
||||||
|
name = StringField()
|
||||||
|
years = IntField()
|
||||||
|
|
||||||
class Person(Document):
|
class Person(Document):
|
||||||
name = StringField()
|
name = StringField()
|
||||||
age = IntField()
|
age = IntField()
|
||||||
|
job = EmbeddedDocumentField(Job)
|
||||||
|
|
||||||
non_field = True
|
non_field = True
|
||||||
|
|
||||||
meta = {"allow_inheritance": True}
|
meta = {"allow_inheritance": True}
|
||||||
|
|
||||||
self.Person = Person
|
self.Person = Person
|
||||||
|
self.Job = Job
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
for collection in self.db.collection_names():
|
for collection in self.db.collection_names():
|
||||||
@ -50,6 +56,9 @@ class InstanceTest(unittest.TestCase):
|
|||||||
continue
|
continue
|
||||||
self.db.drop_collection(collection)
|
self.db.drop_collection(collection)
|
||||||
|
|
||||||
|
def assertDbEqual(self, docs):
|
||||||
|
self.assertEqual(list(self.Person._get_collection().find().sort("id")), sorted(docs, key=lambda doc: doc["_id"]))
|
||||||
|
|
||||||
def test_capped_collection(self):
|
def test_capped_collection(self):
|
||||||
"""Ensure that capped collections work properly.
|
"""Ensure that capped collections work properly.
|
||||||
"""
|
"""
|
||||||
@ -452,7 +461,7 @@ class InstanceTest(unittest.TestCase):
|
|||||||
def test_dictionary_access(self):
|
def test_dictionary_access(self):
|
||||||
"""Ensure that dictionary-style field access works properly.
|
"""Ensure that dictionary-style field access works properly.
|
||||||
"""
|
"""
|
||||||
person = self.Person(name='Test User', age=30)
|
person = self.Person(name='Test User', age=30, job=self.Job())
|
||||||
self.assertEqual(person['name'], 'Test User')
|
self.assertEqual(person['name'], 'Test User')
|
||||||
|
|
||||||
self.assertRaises(KeyError, person.__getitem__, 'salary')
|
self.assertRaises(KeyError, person.__getitem__, 'salary')
|
||||||
@ -462,7 +471,7 @@ class InstanceTest(unittest.TestCase):
|
|||||||
self.assertEqual(person['name'], 'Another User')
|
self.assertEqual(person['name'], 'Another User')
|
||||||
|
|
||||||
# Length = length(assigned fields + id)
|
# Length = length(assigned fields + id)
|
||||||
self.assertEqual(len(person), 4)
|
self.assertEqual(len(person), 5)
|
||||||
|
|
||||||
self.assertTrue('age' in person)
|
self.assertTrue('age' in person)
|
||||||
person.age = None
|
person.age = None
|
||||||
@ -617,6 +626,61 @@ class InstanceTest(unittest.TestCase):
|
|||||||
t = TestDocument(doc=TestEmbeddedDocument(x=15, y=35, z=5))
|
t = TestDocument(doc=TestEmbeddedDocument(x=15, y=35, z=5))
|
||||||
t.save(clean=False)
|
t.save(clean=False)
|
||||||
|
|
||||||
|
def test_modify_empty(self):
|
||||||
|
doc = self.Person(id=ObjectId(), name="bob", age=10).save()
|
||||||
|
self.assertRaises(InvalidDocumentError, lambda: self.Person().modify(set__age=10))
|
||||||
|
self.assertDbEqual([dict(doc.to_mongo())])
|
||||||
|
|
||||||
|
def test_modify_invalid_query(self):
|
||||||
|
doc1 = self.Person(id=ObjectId(), name="bob", age=10).save()
|
||||||
|
doc2 = self.Person(id=ObjectId(), name="jim", age=20).save()
|
||||||
|
docs = [dict(doc1.to_mongo()), dict(doc2.to_mongo())]
|
||||||
|
|
||||||
|
self.assertRaises(InvalidQueryError, lambda:
|
||||||
|
doc1.modify(dict(id=doc2.id), set__value=20))
|
||||||
|
|
||||||
|
self.assertDbEqual(docs)
|
||||||
|
|
||||||
|
def test_modify_match_another_document(self):
|
||||||
|
doc1 = self.Person(id=ObjectId(), name="bob", age=10).save()
|
||||||
|
doc2 = self.Person(id=ObjectId(), name="jim", age=20).save()
|
||||||
|
docs = [dict(doc1.to_mongo()), dict(doc2.to_mongo())]
|
||||||
|
|
||||||
|
assert not doc1.modify(dict(name=doc2.name), set__age=100)
|
||||||
|
|
||||||
|
self.assertDbEqual(docs)
|
||||||
|
|
||||||
|
def test_modify_not_exists(self):
|
||||||
|
doc1 = self.Person(id=ObjectId(), name="bob", age=10).save()
|
||||||
|
doc2 = self.Person(id=ObjectId(), name="jim", age=20)
|
||||||
|
docs = [dict(doc1.to_mongo())]
|
||||||
|
|
||||||
|
assert not doc2.modify(dict(name=doc2.name), set__age=100)
|
||||||
|
|
||||||
|
self.assertDbEqual(docs)
|
||||||
|
|
||||||
|
def test_modify_update(self):
|
||||||
|
other_doc = self.Person(id=ObjectId(), name="bob", age=10).save()
|
||||||
|
doc = self.Person(id=ObjectId(), name="jim", age=20,
|
||||||
|
job=self.Job(name="10gen", years=3)).save()
|
||||||
|
|
||||||
|
doc_copy = doc._from_son(doc.to_mongo())
|
||||||
|
|
||||||
|
# these changes must go away
|
||||||
|
doc.name = "liza"
|
||||||
|
doc.job.name = "Google"
|
||||||
|
doc.job.years = 3
|
||||||
|
|
||||||
|
assert doc.modify(set__age=21, set__job__name="MongoDB", unset__job__years=True)
|
||||||
|
doc_copy.age = 21
|
||||||
|
doc_copy.job.name = "MongoDB"
|
||||||
|
del doc_copy.job.years
|
||||||
|
|
||||||
|
assert doc.to_json() == doc_copy.to_json()
|
||||||
|
assert doc._get_changed_fields() == []
|
||||||
|
|
||||||
|
self.assertDbEqual([dict(other_doc.to_mongo()), dict(doc.to_mongo())])
|
||||||
|
|
||||||
def test_save(self):
|
def test_save(self):
|
||||||
"""Ensure that a document may be saved in the database.
|
"""Ensure that a document may be saved in the database.
|
||||||
"""
|
"""
|
||||||
|
Loading…
x
Reference in New Issue
Block a user