From ab08e67eaf3f6b809a58740c0cfbbb24e1a3ef0b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Thu, 30 Aug 2018 23:57:16 +0200 Subject: [PATCH] fix inc/dec operator with decimal --- mongoengine/fields.py | 7 ++- mongoengine/queryset/transform.py | 6 +++ tests/queryset/queryset.py | 79 ++++++++++++++++++++++++++----- 3 files changed, 78 insertions(+), 14 deletions(-) diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 89b901e7..d8eaec4e 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -364,7 +364,8 @@ class FloatField(BaseField): class DecimalField(BaseField): - """Fixed-point decimal number field. + """Fixed-point decimal number field. Stores the value as a float by default unless `force_string` is used. + If using floats, beware of Decimal to float conversion (potential precision loss) .. versionchanged:: 0.8 .. versionadded:: 0.3 @@ -375,7 +376,9 @@ class DecimalField(BaseField): """ :param min_value: Validation rule for the minimum acceptable value. :param max_value: Validation rule for the maximum acceptable value. - :param force_string: Store as a string. + :param force_string: Store the value as a string (instead of a float). + Be aware that this affects query sorting and operation like lte, gte (as string comparison is applied) + and some query operator won't work (e.g: inc, dec) :param precision: Number of decimal places to store. :param rounding: The rounding rule from the python decimal library: diff --git a/mongoengine/queryset/transform.py b/mongoengine/queryset/transform.py index 6021d464..25bd68e0 100644 --- a/mongoengine/queryset/transform.py +++ b/mongoengine/queryset/transform.py @@ -201,14 +201,18 @@ def update(_doc_cls=None, **update): format. """ mongo_update = {} + for key, value in update.items(): if key == '__raw__': mongo_update.update(value) continue + parts = key.split('__') + # if there is no operator, default to 'set' if len(parts) < 3 and parts[0] not in UPDATE_OPERATORS: parts.insert(0, 'set') + # Check for an operator and transform to mongo-style if there is op = None if parts[0] in UPDATE_OPERATORS: @@ -294,6 +298,8 @@ def update(_doc_cls=None, **update): value = field.prepare_query_value(op, value) elif op == 'unset': value = 1 + elif op == 'inc': + value = field.prepare_query_value(op, value) if match: match = '$' + match diff --git a/tests/queryset/queryset.py b/tests/queryset/queryset.py index a405e892..b0dd354d 100644 --- a/tests/queryset/queryset.py +++ b/tests/queryset/queryset.py @@ -3,6 +3,7 @@ import datetime import unittest import uuid +from decimal import Decimal from bson import DBRef, ObjectId from nose.plugins.skip import SkipTest @@ -1851,21 +1852,16 @@ class QuerySetTest(unittest.TestCase): self.assertEqual( 1, BlogPost.objects(author__in=["%s" % me.pk]).count()) - def test_update(self): - """Ensure that atomic updates work properly. - """ + def test_update_intfield_operator(self): class BlogPost(Document): - name = StringField() - title = StringField() hits = IntField() - tags = ListField(StringField()) BlogPost.drop_collection() - post = BlogPost(name="Test Post", hits=5, tags=['test']) + post = BlogPost(hits=5) post.save() - BlogPost.objects.update(set__hits=10) + BlogPost.objects.update_one(set__hits=10) post.reload() self.assertEqual(post.hits, 10) @@ -1882,6 +1878,55 @@ class QuerySetTest(unittest.TestCase): post.reload() self.assertEqual(post.hits, 11) + def test_update_decimalfield_operator(self): + class BlogPost(Document): + review = DecimalField() + + BlogPost.drop_collection() + + post = BlogPost(review=3.5) + post.save() + + BlogPost.objects.update_one(inc__review=0.1) # test with floats + post.reload() + self.assertEqual(float(post.review), 3.6) + + BlogPost.objects.update_one(dec__review=0.1) + post.reload() + self.assertEqual(float(post.review), 3.5) + + BlogPost.objects.update_one(inc__review=Decimal(0.12)) # test with Decimal + post.reload() + self.assertEqual(float(post.review), 3.62) + + BlogPost.objects.update_one(dec__review=Decimal(0.12)) + post.reload() + self.assertEqual(float(post.review), 3.5) + + def test_update_decimalfield_operator_not_working_with_force_string(self): + class BlogPost(Document): + review = DecimalField(force_string=True) + + BlogPost.drop_collection() + + post = BlogPost(review=3.5) + post.save() + + with self.assertRaises(OperationError): + BlogPost.objects.update_one(inc__review=0.1) # test with floats + + def test_update_listfield_operator(self): + """Ensure that atomic updates work properly. + """ + class BlogPost(Document): + tags = ListField(StringField()) + + BlogPost.drop_collection() + + post = BlogPost(tags=['test']) + post.save() + + # ListField operator BlogPost.objects.update(push__tags='mongo') post.reload() self.assertTrue('mongo' in post.tags) @@ -1900,13 +1945,23 @@ class QuerySetTest(unittest.TestCase): post.reload() self.assertEqual(post.tags.count('unique'), 1) - self.assertNotEqual(post.hits, None) - BlogPost.objects.update_one(unset__hits=1) - post.reload() - self.assertEqual(post.hits, None) + BlogPost.drop_collection() + + def test_update_unset(self): + class BlogPost(Document): + title = StringField() BlogPost.drop_collection() + post = BlogPost(title='garbage').save() + + self.assertNotEqual(post.title, None) + BlogPost.objects.update_one(unset__title=1) + post.reload() + self.assertEqual(post.title, None) + pymongo_doc = BlogPost.objects.as_pymongo().first() + self.assertNotIn('title', pymongo_doc) + @needs_mongodb_v26 def test_update_push_with_position(self): """Ensure that the 'push' update with position works properly.