Merge pull request #1872 from bagerard/inc_operator_with_decimal

fix inc/dec operator with DecimalField + improve its doc
This commit is contained in:
erdenezul 2018-09-04 21:07:04 +08:00 committed by GitHub
commit d6d19c4229
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 78 additions and 14 deletions

View File

@ -364,7 +364,8 @@ class FloatField(BaseField):
class DecimalField(BaseField): class DecimalField(BaseField):
"""Fixed-point decimal number field. """Fixed-point decimal number field. Stores the value as a float by default unless `force_string` is used.
If using floats, beware of Decimal to float conversion (potential precision loss)
.. versionchanged:: 0.8 .. versionchanged:: 0.8
.. versionadded:: 0.3 .. versionadded:: 0.3
@ -375,7 +376,9 @@ class DecimalField(BaseField):
""" """
:param min_value: Validation rule for the minimum acceptable value. :param min_value: Validation rule for the minimum acceptable value.
:param max_value: Validation rule for the maximum acceptable value. :param max_value: Validation rule for the maximum acceptable value.
:param force_string: Store as a string. :param force_string: Store the value as a string (instead of a float).
Be aware that this affects query sorting and operation like lte, gte (as string comparison is applied)
and some query operator won't work (e.g: inc, dec)
:param precision: Number of decimal places to store. :param precision: Number of decimal places to store.
:param rounding: The rounding rule from the python decimal library: :param rounding: The rounding rule from the python decimal library:

View File

@ -201,14 +201,18 @@ def update(_doc_cls=None, **update):
format. format.
""" """
mongo_update = {} mongo_update = {}
for key, value in update.items(): for key, value in update.items():
if key == '__raw__': if key == '__raw__':
mongo_update.update(value) mongo_update.update(value)
continue continue
parts = key.split('__') parts = key.split('__')
# if there is no operator, default to 'set' # if there is no operator, default to 'set'
if len(parts) < 3 and parts[0] not in UPDATE_OPERATORS: if len(parts) < 3 and parts[0] not in UPDATE_OPERATORS:
parts.insert(0, 'set') parts.insert(0, 'set')
# Check for an operator and transform to mongo-style if there is # Check for an operator and transform to mongo-style if there is
op = None op = None
if parts[0] in UPDATE_OPERATORS: if parts[0] in UPDATE_OPERATORS:
@ -294,6 +298,8 @@ def update(_doc_cls=None, **update):
value = field.prepare_query_value(op, value) value = field.prepare_query_value(op, value)
elif op == 'unset': elif op == 'unset':
value = 1 value = 1
elif op == 'inc':
value = field.prepare_query_value(op, value)
if match: if match:
match = '$' + match match = '$' + match

View File

@ -3,6 +3,7 @@
import datetime import datetime
import unittest import unittest
import uuid import uuid
from decimal import Decimal
from bson import DBRef, ObjectId from bson import DBRef, ObjectId
from nose.plugins.skip import SkipTest from nose.plugins.skip import SkipTest
@ -1851,21 +1852,16 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual( self.assertEqual(
1, BlogPost.objects(author__in=["%s" % me.pk]).count()) 1, BlogPost.objects(author__in=["%s" % me.pk]).count())
def test_update(self): def test_update_intfield_operator(self):
"""Ensure that atomic updates work properly.
"""
class BlogPost(Document): class BlogPost(Document):
name = StringField()
title = StringField()
hits = IntField() hits = IntField()
tags = ListField(StringField())
BlogPost.drop_collection() BlogPost.drop_collection()
post = BlogPost(name="Test Post", hits=5, tags=['test']) post = BlogPost(hits=5)
post.save() post.save()
BlogPost.objects.update(set__hits=10) BlogPost.objects.update_one(set__hits=10)
post.reload() post.reload()
self.assertEqual(post.hits, 10) self.assertEqual(post.hits, 10)
@ -1882,6 +1878,55 @@ class QuerySetTest(unittest.TestCase):
post.reload() post.reload()
self.assertEqual(post.hits, 11) self.assertEqual(post.hits, 11)
def test_update_decimalfield_operator(self):
class BlogPost(Document):
review = DecimalField()
BlogPost.drop_collection()
post = BlogPost(review=3.5)
post.save()
BlogPost.objects.update_one(inc__review=0.1) # test with floats
post.reload()
self.assertEqual(float(post.review), 3.6)
BlogPost.objects.update_one(dec__review=0.1)
post.reload()
self.assertEqual(float(post.review), 3.5)
BlogPost.objects.update_one(inc__review=Decimal(0.12)) # test with Decimal
post.reload()
self.assertEqual(float(post.review), 3.62)
BlogPost.objects.update_one(dec__review=Decimal(0.12))
post.reload()
self.assertEqual(float(post.review), 3.5)
def test_update_decimalfield_operator_not_working_with_force_string(self):
class BlogPost(Document):
review = DecimalField(force_string=True)
BlogPost.drop_collection()
post = BlogPost(review=3.5)
post.save()
with self.assertRaises(OperationError):
BlogPost.objects.update_one(inc__review=0.1) # test with floats
def test_update_listfield_operator(self):
"""Ensure that atomic updates work properly.
"""
class BlogPost(Document):
tags = ListField(StringField())
BlogPost.drop_collection()
post = BlogPost(tags=['test'])
post.save()
# ListField operator
BlogPost.objects.update(push__tags='mongo') BlogPost.objects.update(push__tags='mongo')
post.reload() post.reload()
self.assertTrue('mongo' in post.tags) self.assertTrue('mongo' in post.tags)
@ -1900,13 +1945,23 @@ class QuerySetTest(unittest.TestCase):
post.reload() post.reload()
self.assertEqual(post.tags.count('unique'), 1) self.assertEqual(post.tags.count('unique'), 1)
self.assertNotEqual(post.hits, None) BlogPost.drop_collection()
BlogPost.objects.update_one(unset__hits=1)
post.reload() def test_update_unset(self):
self.assertEqual(post.hits, None) class BlogPost(Document):
title = StringField()
BlogPost.drop_collection() BlogPost.drop_collection()
post = BlogPost(title='garbage').save()
self.assertNotEqual(post.title, None)
BlogPost.objects.update_one(unset__title=1)
post.reload()
self.assertEqual(post.title, None)
pymongo_doc = BlogPost.objects.as_pymongo().first()
self.assertNotIn('title', pymongo_doc)
@needs_mongodb_v26 @needs_mongodb_v26
def test_update_push_with_position(self): def test_update_push_with_position(self):
"""Ensure that the 'push' update with position works properly. """Ensure that the 'push' update with position works properly.