From 42a58dda57e9be18ca4c6ca104fb556f40514b41 Mon Sep 17 00:00:00 2001 From: Harry Marr Date: Fri, 8 Jan 2010 18:39:06 +0000 Subject: [PATCH] Added update() and update_one() with tests/docs --- docs/userguide.rst | 37 +++++++++++++++++- mongoengine/document.py | 7 ++-- mongoengine/queryset.py | 84 +++++++++++++++++++++++++++++++++++++++++ tests/document.py | 4 +- tests/queryset.py | 35 +++++++++++++++++ 5 files changed, 161 insertions(+), 6 deletions(-) diff --git a/docs/userguide.rst b/docs/userguide.rst index 331980e4..bba6ffbb 100644 --- a/docs/userguide.rst +++ b/docs/userguide.rst @@ -144,7 +144,7 @@ MongoEngine allows you to specify that a field should be unique across a collection by providing ``unique=True`` to a :class:`~mongoengine.Field`\ 's constructor. If you try to save a document that has the same value for a unique field as a document that is already in the database, a -:class:`~mongoengine.ValidationError` will be raised. You may also specify +:class:`~mongoengine.OperationError` will be raised. You may also specify multi-field uniqueness constraints by using :attr:`unique_with`, which may be either a single field name, or a list or tuple of field names:: @@ -454,3 +454,38 @@ would be generating "tag-clouds":: from operator import itemgetter top_tags = sorted(tag_freqs.items(), key=itemgetter(1), reverse=True)[:10] +Atomic updates +-------------- +Documents may be updated atomically by using the +:meth:`~mongoengine.queryset.QuerySet.update_one` and +:meth:`~mongoengine.queryset.QuerySet.update` methods on a +:meth:`~mongoengine.queryset.QuerySet`. There are several different "modifiers" +that you may use with these methods: + +* ``set`` -- set a particular value +* ``unset`` -- delete a particular value (since MongoDB v1.3+) +* ``inc`` -- increment a value by a given amount +* ``dec`` -- decrement a value by a given amount +* ``push`` -- append a value to a list +* ``push_all`` -- append several values to a list +* ``pull`` -- remove a value from a list +* ``pull_all`` -- remove several values from a list + +The syntax for atomic updates is similar to the querying syntax, but the +modifier comes before the field, not after it:: + + >>> post = BlogPost(title='Test', page_views=0, tags=['database']) + >>> post.save() + >>> BlogPost.objects(id=post.id).update_one(inc__page_views=1) + >>> post.reload() # the document has been changed, so we need to reload it + >>> post.page_views + 1 + >>> BlogPost.objects(id=post.id).update_one(set__title='Example Post') + >>> post.reload() + >>> post.title + 'Example Post' + >>> BlogPost.objects(id=post.id).update_one(push__tags='nosql') + >>> post.reload() + >>> post.tags + ['database', 'nosql'] + diff --git a/mongoengine/document.py b/mongoengine/document.py index daae230c..8cbca5dc 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -1,11 +1,12 @@ from base import (DocumentMetaclass, TopLevelDocumentMetaclass, BaseDocument, ValidationError) +from queryset import OperationError from connection import _get_db import pymongo -__all__ = ['Document', 'EmbeddedDocument', 'ValidationError'] +__all__ = ['Document', 'EmbeddedDocument', 'ValidationError', 'OperationError'] class EmbeddedDocument(BaseDocument): @@ -65,7 +66,7 @@ class Document(BaseDocument): try: object_id = self.__class__.objects._collection.save(doc, safe=safe) except pymongo.errors.OperationFailure, err: - raise ValidationError('Tried to safe duplicate unique keys (%s)' + raise OperationError('Tried to save duplicate unique keys (%s)' % str(err)) self.id = self._fields['id'].to_python(object_id) @@ -81,7 +82,7 @@ class Document(BaseDocument): """ obj = self.__class__.objects(id=self.id).first() for field in self._fields: - setattr(self, field, getattr(obj, field)) + setattr(self, field, obj[field]) def validate(self): """Ensure that all fields' values are valid and that required fields diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index cb8d826b..39cef283 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -10,6 +10,10 @@ class InvalidQueryError(Exception): pass +class OperationError(Exception): + pass + + class QuerySet(object): """A set of results returned from a query. Wraps a MongoDB cursor, providing :class:`~mongoengine.Document` objects as the results. @@ -254,6 +258,86 @@ class QuerySet(object): """ self._collection.remove(self._query) + @classmethod + def _transform_update(cls, _doc_cls=None, **update): + """Transform an update spec from Django-style format to Mongo format. + """ + operators = ['set', 'unset', 'inc', 'dec', 'push', 'push_all', 'pull', + 'pull_all'] + + mongo_update = {} + for key, value in update.items(): + parts = key.split('__') + # Check for an operator and transform to mongo-style if there is + op = None + if parts[0] in operators: + op = parts.pop(0) + # Convert Pythonic names to Mongo equivalents + if op in ('push_all', 'pull_all'): + op = op.replace('_all', 'All') + elif op == 'dec': + # Support decrement by flipping a positive value's sign + # and using 'inc' + op = 'inc' + if value > 0: + value = -value + + if _doc_cls: + # Switch field names to proper names [set in Field(name='foo')] + fields = QuerySet._lookup_field(_doc_cls, parts) + parts = [field.name for field in fields] + + # Convert value to proper value + field = fields[-1] + if op in (None, 'set', 'unset', 'push', 'pull'): + value = field.prepare_query_value(value) + elif op in ('pushAll', 'pullAll'): + value = [field.prepare_query_value(v) for v in value] + + key = '.'.join(parts) + + if op: + value = {key: value} + key = '$' + op + + if op is None or key not in mongo_update: + mongo_update[key] = value + elif key in mongo_update and isinstance(mongo_update[key], dict): + mongo_update[key].update(value) + + return mongo_update + + def update(self, safe_update=True, **update): + """Perform an atomic update on the fields matched by the query. + """ + if pymongo.version < '1.1.1': + raise OperationError('update() method requires PyMongo 1.1.1+') + + update = QuerySet._transform_update(self._document, **update) + try: + self._collection.update(self._query, update, safe=safe_update, + multi=True) + except pymongo.errors.OperationFailure, err: + if str(err) == 'multi not coded yet': + raise OperationError('update() method requires MongoDB 1.1.3+') + raise OperationError('Update failed (%s)' % str(err)) + + def update_one(self, safe_update=True, **update): + """Perform an atomic update on first field matched by the query. + """ + update = QuerySet._transform_update(self._document, **update) + try: + # Explicitly provide 'multi=False' to newer versions of PyMongo + # as the default may change to 'True' + if pymongo.version >= '1.1.1': + self._collection.update(self._query, update, safe=safe_update, + multi=False) + else: + # Older versions of PyMongo don't support 'multi' + self._collection.update(self._query, update, safe=safe_update) + except pymongo.errors.OperationFailure, e: + raise OperationError('Update failed [%s]' % str(e)) + def __iter__(self): return self diff --git a/tests/document.py b/tests/document.py index a2f74649..5448635c 100644 --- a/tests/document.py +++ b/tests/document.py @@ -262,7 +262,7 @@ class DocumentTest(unittest.TestCase): # Two posts with the same slug is not allowed post2 = BlogPost(title='test2', slug='test') - self.assertRaises(ValidationError, post2.save) + self.assertRaises(OperationError, post2.save) class Date(EmbeddedDocument): year = IntField(name='yr') @@ -283,7 +283,7 @@ class DocumentTest(unittest.TestCase): # Now there will be two docs with the same slug and the same day: fail post3 = BlogPost(title='test3', date=Date(year=2010), slug='test') - self.assertRaises(ValidationError, post3.save) + self.assertRaises(OperationError, post3.save) BlogPost.drop_collection() diff --git a/tests/queryset.py b/tests/queryset.py index 12e95934..698ada9c 100644 --- a/tests/queryset.py +++ b/tests/queryset.py @@ -201,6 +201,41 @@ class QuerySetTest(unittest.TestCase): self.Person.objects.delete() self.assertEqual(len(self.Person.objects), 0) + def test_update(self): + """Ensure that atomic updates work properly. + """ + class BlogPost(Document): + title = StringField() + hits = IntField() + tags = ListField(StringField()) + + BlogPost.drop_collection() + + post = BlogPost(name="Test Post", hits=5, tags=['test']) + post.save() + + BlogPost.objects.update(set__hits=10) + post.reload() + self.assertEqual(post.hits, 10) + + BlogPost.objects.update_one(inc__hits=1) + post.reload() + self.assertEqual(post.hits, 11) + + BlogPost.objects.update_one(dec__hits=1) + post.reload() + self.assertEqual(post.hits, 10) + + BlogPost.objects.update(push__tags='mongo') + post.reload() + self.assertTrue('mongo' in post.tags) + + BlogPost.objects.update_one(push_all__tags=['db', 'nosql']) + post.reload() + self.assertTrue('db' in post.tags and 'nosql' in post.tags) + + BlogPost.drop_collection() + def test_order_by(self): """Ensure that QuerySets may be ordered. """