fix inc/dec operator with decimal
This commit is contained in:
parent
00bf6ac258
commit
ab08e67eaf
@ -364,7 +364,8 @@ class FloatField(BaseField):
|
|||||||
|
|
||||||
|
|
||||||
class DecimalField(BaseField):
|
class DecimalField(BaseField):
|
||||||
"""Fixed-point decimal number field.
|
"""Fixed-point decimal number field. Stores the value as a float by default unless `force_string` is used.
|
||||||
|
If using floats, beware of Decimal to float conversion (potential precision loss)
|
||||||
|
|
||||||
.. versionchanged:: 0.8
|
.. versionchanged:: 0.8
|
||||||
.. versionadded:: 0.3
|
.. versionadded:: 0.3
|
||||||
@ -375,7 +376,9 @@ class DecimalField(BaseField):
|
|||||||
"""
|
"""
|
||||||
:param min_value: Validation rule for the minimum acceptable value.
|
:param min_value: Validation rule for the minimum acceptable value.
|
||||||
:param max_value: Validation rule for the maximum acceptable value.
|
:param max_value: Validation rule for the maximum acceptable value.
|
||||||
:param force_string: Store as a string.
|
:param force_string: Store the value as a string (instead of a float).
|
||||||
|
Be aware that this affects query sorting and operation like lte, gte (as string comparison is applied)
|
||||||
|
and some query operator won't work (e.g: inc, dec)
|
||||||
:param precision: Number of decimal places to store.
|
:param precision: Number of decimal places to store.
|
||||||
:param rounding: The rounding rule from the python decimal library:
|
:param rounding: The rounding rule from the python decimal library:
|
||||||
|
|
||||||
|
@ -201,14 +201,18 @@ def update(_doc_cls=None, **update):
|
|||||||
format.
|
format.
|
||||||
"""
|
"""
|
||||||
mongo_update = {}
|
mongo_update = {}
|
||||||
|
|
||||||
for key, value in update.items():
|
for key, value in update.items():
|
||||||
if key == '__raw__':
|
if key == '__raw__':
|
||||||
mongo_update.update(value)
|
mongo_update.update(value)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
parts = key.split('__')
|
parts = key.split('__')
|
||||||
|
|
||||||
# if there is no operator, default to 'set'
|
# if there is no operator, default to 'set'
|
||||||
if len(parts) < 3 and parts[0] not in UPDATE_OPERATORS:
|
if len(parts) < 3 and parts[0] not in UPDATE_OPERATORS:
|
||||||
parts.insert(0, 'set')
|
parts.insert(0, 'set')
|
||||||
|
|
||||||
# Check for an operator and transform to mongo-style if there is
|
# Check for an operator and transform to mongo-style if there is
|
||||||
op = None
|
op = None
|
||||||
if parts[0] in UPDATE_OPERATORS:
|
if parts[0] in UPDATE_OPERATORS:
|
||||||
@ -294,6 +298,8 @@ def update(_doc_cls=None, **update):
|
|||||||
value = field.prepare_query_value(op, value)
|
value = field.prepare_query_value(op, value)
|
||||||
elif op == 'unset':
|
elif op == 'unset':
|
||||||
value = 1
|
value = 1
|
||||||
|
elif op == 'inc':
|
||||||
|
value = field.prepare_query_value(op, value)
|
||||||
|
|
||||||
if match:
|
if match:
|
||||||
match = '$' + match
|
match = '$' + match
|
||||||
|
@ -3,6 +3,7 @@
|
|||||||
import datetime
|
import datetime
|
||||||
import unittest
|
import unittest
|
||||||
import uuid
|
import uuid
|
||||||
|
from decimal import Decimal
|
||||||
|
|
||||||
from bson import DBRef, ObjectId
|
from bson import DBRef, ObjectId
|
||||||
from nose.plugins.skip import SkipTest
|
from nose.plugins.skip import SkipTest
|
||||||
@ -1851,21 +1852,16 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
1, BlogPost.objects(author__in=["%s" % me.pk]).count())
|
1, BlogPost.objects(author__in=["%s" % me.pk]).count())
|
||||||
|
|
||||||
def test_update(self):
|
def test_update_intfield_operator(self):
|
||||||
"""Ensure that atomic updates work properly.
|
|
||||||
"""
|
|
||||||
class BlogPost(Document):
|
class BlogPost(Document):
|
||||||
name = StringField()
|
|
||||||
title = StringField()
|
|
||||||
hits = IntField()
|
hits = IntField()
|
||||||
tags = ListField(StringField())
|
|
||||||
|
|
||||||
BlogPost.drop_collection()
|
BlogPost.drop_collection()
|
||||||
|
|
||||||
post = BlogPost(name="Test Post", hits=5, tags=['test'])
|
post = BlogPost(hits=5)
|
||||||
post.save()
|
post.save()
|
||||||
|
|
||||||
BlogPost.objects.update(set__hits=10)
|
BlogPost.objects.update_one(set__hits=10)
|
||||||
post.reload()
|
post.reload()
|
||||||
self.assertEqual(post.hits, 10)
|
self.assertEqual(post.hits, 10)
|
||||||
|
|
||||||
@ -1882,6 +1878,55 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
post.reload()
|
post.reload()
|
||||||
self.assertEqual(post.hits, 11)
|
self.assertEqual(post.hits, 11)
|
||||||
|
|
||||||
|
def test_update_decimalfield_operator(self):
|
||||||
|
class BlogPost(Document):
|
||||||
|
review = DecimalField()
|
||||||
|
|
||||||
|
BlogPost.drop_collection()
|
||||||
|
|
||||||
|
post = BlogPost(review=3.5)
|
||||||
|
post.save()
|
||||||
|
|
||||||
|
BlogPost.objects.update_one(inc__review=0.1) # test with floats
|
||||||
|
post.reload()
|
||||||
|
self.assertEqual(float(post.review), 3.6)
|
||||||
|
|
||||||
|
BlogPost.objects.update_one(dec__review=0.1)
|
||||||
|
post.reload()
|
||||||
|
self.assertEqual(float(post.review), 3.5)
|
||||||
|
|
||||||
|
BlogPost.objects.update_one(inc__review=Decimal(0.12)) # test with Decimal
|
||||||
|
post.reload()
|
||||||
|
self.assertEqual(float(post.review), 3.62)
|
||||||
|
|
||||||
|
BlogPost.objects.update_one(dec__review=Decimal(0.12))
|
||||||
|
post.reload()
|
||||||
|
self.assertEqual(float(post.review), 3.5)
|
||||||
|
|
||||||
|
def test_update_decimalfield_operator_not_working_with_force_string(self):
|
||||||
|
class BlogPost(Document):
|
||||||
|
review = DecimalField(force_string=True)
|
||||||
|
|
||||||
|
BlogPost.drop_collection()
|
||||||
|
|
||||||
|
post = BlogPost(review=3.5)
|
||||||
|
post.save()
|
||||||
|
|
||||||
|
with self.assertRaises(OperationError):
|
||||||
|
BlogPost.objects.update_one(inc__review=0.1) # test with floats
|
||||||
|
|
||||||
|
def test_update_listfield_operator(self):
|
||||||
|
"""Ensure that atomic updates work properly.
|
||||||
|
"""
|
||||||
|
class BlogPost(Document):
|
||||||
|
tags = ListField(StringField())
|
||||||
|
|
||||||
|
BlogPost.drop_collection()
|
||||||
|
|
||||||
|
post = BlogPost(tags=['test'])
|
||||||
|
post.save()
|
||||||
|
|
||||||
|
# ListField operator
|
||||||
BlogPost.objects.update(push__tags='mongo')
|
BlogPost.objects.update(push__tags='mongo')
|
||||||
post.reload()
|
post.reload()
|
||||||
self.assertTrue('mongo' in post.tags)
|
self.assertTrue('mongo' in post.tags)
|
||||||
@ -1900,13 +1945,23 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
post.reload()
|
post.reload()
|
||||||
self.assertEqual(post.tags.count('unique'), 1)
|
self.assertEqual(post.tags.count('unique'), 1)
|
||||||
|
|
||||||
self.assertNotEqual(post.hits, None)
|
BlogPost.drop_collection()
|
||||||
BlogPost.objects.update_one(unset__hits=1)
|
|
||||||
post.reload()
|
def test_update_unset(self):
|
||||||
self.assertEqual(post.hits, None)
|
class BlogPost(Document):
|
||||||
|
title = StringField()
|
||||||
|
|
||||||
BlogPost.drop_collection()
|
BlogPost.drop_collection()
|
||||||
|
|
||||||
|
post = BlogPost(title='garbage').save()
|
||||||
|
|
||||||
|
self.assertNotEqual(post.title, None)
|
||||||
|
BlogPost.objects.update_one(unset__title=1)
|
||||||
|
post.reload()
|
||||||
|
self.assertEqual(post.title, None)
|
||||||
|
pymongo_doc = BlogPost.objects.as_pymongo().first()
|
||||||
|
self.assertNotIn('title', pymongo_doc)
|
||||||
|
|
||||||
@needs_mongodb_v26
|
@needs_mongodb_v26
|
||||||
def test_update_push_with_position(self):
|
def test_update_push_with_position(self):
|
||||||
"""Ensure that the 'push' update with position works properly.
|
"""Ensure that the 'push' update with position works properly.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user