From 437b11af9ac3f2a160868265337a22eb6fcb6ff5 Mon Sep 17 00:00:00 2001 From: Alex Xu Date: Mon, 10 Jul 2017 16:43:24 -0400 Subject: [PATCH 01/14] docs: use explicit register_delete_rule example The previous example of creating bi-directional delete rules was vague since the example defined only one class and the relationship between "Foo" and "Bar" wasn't clear. I added a more explicit example where the relationship between the two classes is explicit. --- mongoengine/fields.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 0d402712..0029d68b 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -971,11 +971,13 @@ class ReferenceField(BaseField): .. code-block:: python - class Bar(Document): - content = StringField() - foo = ReferenceField('Foo') + class Org(Document): + owner = ReferenceField('User') - Foo.register_delete_rule(Bar, 'foo', NULLIFY) + class User(Document): + org = ReferenceField('Org', reverse_delete_rule=CASCADE) + + User.register_delete_rule(Org, 'owner', DENY) .. versionchanged:: 0.5 added `reverse_delete_rule` """ From c4de879b207e4e7087b099459ff801b8897620f1 Mon Sep 17 00:00:00 2001 From: Paulo Matos Date: Fri, 11 Aug 2017 09:09:33 +0200 Subject: [PATCH 02/14] Clarify comment in validation example --- docs/guide/document-instances.rst | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/guide/document-instances.rst b/docs/guide/document-instances.rst index 0e9fcef6..64f17c08 100644 --- a/docs/guide/document-instances.rst +++ b/docs/guide/document-instances.rst @@ -57,7 +57,8 @@ document values for example:: def clean(self): """Ensures that only published essays have a `pub_date` and - automatically sets the pub_date if published and not set""" + automatically sets `pub_date` if essay is published and `pub_date` + is not set""" if self.status == 'Draft' and self.pub_date is not None: msg = 'Draft entries should not have a publication date.' raise ValidationError(msg) From 2f075be6f8857d297d9e939d9150417158d93d5a Mon Sep 17 00:00:00 2001 From: Erdenezul Batmunkh Date: Mon, 2 Oct 2017 22:46:27 +0800 Subject: [PATCH 03/14] parse read_preference from conn_host #1665 --- mongoengine/connection.py | 11 +++++++++++ tests/test_connection.py | 5 +++++ 2 files changed, 16 insertions(+) diff --git a/mongoengine/connection.py b/mongoengine/connection.py index 34ff4dc3..ef815343 100644 --- a/mongoengine/connection.py +++ b/mongoengine/connection.py @@ -103,6 +103,17 @@ def register_connection(alias, name=None, host=None, port=None, conn_settings['authentication_source'] = uri_options['authsource'] if 'authmechanism' in uri_options: conn_settings['authentication_mechanism'] = uri_options['authmechanism'] + if 'readpreference' in uri_options: + read_preferences = (ReadPreference.NEAREST, + ReadPreference.PRIMARY, + ReadPreference.PRIMARY_PREFERRED, + ReadPreference.SECONDARY, + ReadPreference.SECONDARY_PREFERRED) + read_pf_mode = uri_options['readpreference'] + for preference in read_preferences: + if preference.mode == read_pf_mode: + conn_settings['read_preference'] = preference + break else: resolved_hosts.append(entity) conn_settings['host'] = resolved_hosts diff --git a/tests/test_connection.py b/tests/test_connection.py index cdcf1377..f0c272e4 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -364,6 +364,11 @@ class ConnectionTest(unittest.TestCase): date_doc = DateDoc.objects.first() self.assertEqual(d, date_doc.the_date) + def test_read_preference_from_parse(self): + from pymongo import ReadPreference + conn = connect(host="mongodb://a1.vpc,a2.vpc,a3.vpc/prod?readPreference=secondaryPreferred") + self.assertEqual(conn.read_preference, ReadPreference.SECONDARY_PREFERRED) + def test_multiple_connection_settings(self): connect('mongoenginetest', alias='t1', host="localhost") From 416486c370cbe565128b04e9fadcc65e643d12d9 Mon Sep 17 00:00:00 2001 From: Erdenezul Batmunkh Date: Mon, 2 Oct 2017 23:13:25 +0800 Subject: [PATCH 04/14] use read_preference only pymongo3.x #1665 --- mongoengine/connection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mongoengine/connection.py b/mongoengine/connection.py index ef815343..419af6bc 100644 --- a/mongoengine/connection.py +++ b/mongoengine/connection.py @@ -103,7 +103,7 @@ def register_connection(alias, name=None, host=None, port=None, conn_settings['authentication_source'] = uri_options['authsource'] if 'authmechanism' in uri_options: conn_settings['authentication_mechanism'] = uri_options['authmechanism'] - if 'readpreference' in uri_options: + if IS_PYMONGO_3 and 'readpreference' in uri_options: read_preferences = (ReadPreference.NEAREST, ReadPreference.PRIMARY, ReadPreference.PRIMARY_PREFERRED, From 5c4ce8754e29d435a621a1780ef96c3a6bb60ebf Mon Sep 17 00:00:00 2001 From: Erdenezul Batmunkh Date: Mon, 2 Oct 2017 23:15:37 +0800 Subject: [PATCH 05/14] run tests only pymongo3 #1565 --- tests/test_connection.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/test_connection.py b/tests/test_connection.py index f0c272e4..f58b1a3e 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -365,9 +365,10 @@ class ConnectionTest(unittest.TestCase): self.assertEqual(d, date_doc.the_date) def test_read_preference_from_parse(self): - from pymongo import ReadPreference - conn = connect(host="mongodb://a1.vpc,a2.vpc,a3.vpc/prod?readPreference=secondaryPreferred") - self.assertEqual(conn.read_preference, ReadPreference.SECONDARY_PREFERRED) + if IS_PYMONGO_3: + from pymongo import ReadPreference + conn = connect(host="mongodb://a1.vpc,a2.vpc,a3.vpc/prod?readPreference=secondaryPreferred") + self.assertEqual(conn.read_preference, ReadPreference.SECONDARY_PREFERRED) def test_multiple_connection_settings(self): connect('mongoenginetest', alias='t1', host="localhost") From 6e2db1ced6a8621e35cf38f27993751fe5ea1e6b Mon Sep 17 00:00:00 2001 From: Erdenezul Date: Tue, 3 Oct 2017 09:23:17 +0800 Subject: [PATCH 06/14] read_preference from parse_uri only PYMONGO_3 #1665 --- mongoengine/connection.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/mongoengine/connection.py b/mongoengine/connection.py index 419af6bc..feba0b58 100644 --- a/mongoengine/connection.py +++ b/mongoengine/connection.py @@ -104,14 +104,15 @@ def register_connection(alias, name=None, host=None, port=None, if 'authmechanism' in uri_options: conn_settings['authentication_mechanism'] = uri_options['authmechanism'] if IS_PYMONGO_3 and 'readpreference' in uri_options: - read_preferences = (ReadPreference.NEAREST, - ReadPreference.PRIMARY, - ReadPreference.PRIMARY_PREFERRED, - ReadPreference.SECONDARY, - ReadPreference.SECONDARY_PREFERRED) - read_pf_mode = uri_options['readpreference'] + read_preferences = ( + ReadPreference.NEAREST, + ReadPreference.PRIMARY, + ReadPreference.PRIMARY_PREFERRED, + ReadPreference.SECONDARY, + ReadPreference.SECONDARY_PREFERRED) + read_pf_mode = uri_options['readpreference'].lower() for preference in read_preferences: - if preference.mode == read_pf_mode: + if preference.name.lower() == read_pf_mode: conn_settings['read_preference'] = preference break else: From 080226dd7289bb3973c46b76eba15174dd8c6977 Mon Sep 17 00:00:00 2001 From: Tal Yalon Date: Fri, 22 Jun 2018 14:16:17 +0300 Subject: [PATCH 07/14] Fix issue #1286 and #844.: when building a query set from filters that reference the same field several times, do not assume each value is a dict --- mongoengine/queryset/transform.py | 2 +- tests/queryset/queryset.py | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/mongoengine/queryset/transform.py b/mongoengine/queryset/transform.py index 5f777f41..f450c8a3 100644 --- a/mongoengine/queryset/transform.py +++ b/mongoengine/queryset/transform.py @@ -147,7 +147,7 @@ def query(_doc_cls=None, **kwargs): if op is None or key not in mongo_query: mongo_query[key] = value elif key in mongo_query: - if isinstance(mongo_query[key], dict): + if isinstance(mongo_query[key], dict) and isinstance(value, dict): mongo_query[key].update(value) # $max/minDistance needs to come last - convert to SON value_dict = mongo_query[key] diff --git a/tests/queryset/queryset.py b/tests/queryset/queryset.py index 497a0d23..9b1b3256 100644 --- a/tests/queryset/queryset.py +++ b/tests/queryset/queryset.py @@ -1202,6 +1202,14 @@ class QuerySetTest(unittest.TestCase): BlogPost.drop_collection() Blog.drop_collection() + def test_filter_chaining_with_regex(self): + person = self.Person(name='Guido van Rossum') + person.save() + + people = self.Person.objects + people = people.filter(name__startswith='Gui').filter(name__not__endswith='tum') + self.assertEqual(people.count(), 1) + def assertSequence(self, qs, expected): qs = list(qs) expected = list(expected) From 5dbee2a2708722a4e9835ac5738991ba394954c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Thu, 30 Aug 2018 16:03:16 +0200 Subject: [PATCH 08/14] Ensures EmbeddedDocumentField does not accepts references to Document classes in its constructor --- mongoengine/fields.py | 12 +++++++++-- tests/fields/fields.py | 47 ++++++++++++++++++++++++++++++++++++++++++ tests/utils.py | 8 +++++-- 3 files changed, 63 insertions(+), 4 deletions(-) diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 16f3185f..a54d3a52 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -645,9 +645,17 @@ class EmbeddedDocumentField(BaseField): def document_type(self): if isinstance(self.document_type_obj, six.string_types): if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT: - self.document_type_obj = self.owner_document + resolved_document_type = self.owner_document else: - self.document_type_obj = get_document(self.document_type_obj) + resolved_document_type = get_document(self.document_type_obj) + + if not issubclass(resolved_document_type, EmbeddedDocument): + # Due to the late resolution of the document_type + # There is a chance that it won't be an EmbeddedDocument (#1661) + self.error('Invalid embedded document class provided to an ' + 'EmbeddedDocumentField') + self.document_type_obj = resolved_document_type + return self.document_type_obj def to_python(self, value): diff --git a/tests/fields/fields.py b/tests/fields/fields.py index 7352d242..362acec4 100644 --- a/tests/fields/fields.py +++ b/tests/fields/fields.py @@ -2147,6 +2147,15 @@ class FieldTest(MongoDBTestCase): ])) self.assertEqual(a.b.c.txt, 'hi') + def test_embedded_document_field_cant_reference_using_a_str_if_it_does_not_exist_yet(self): + raise SkipTest("Using a string reference in an EmbeddedDocumentField does not work if the class isnt registerd yet") + + class MyDoc2(Document): + emb = EmbeddedDocumentField('MyDoc') + + class MyDoc(EmbeddedDocument): + name = StringField() + def test_embedded_document_validation(self): """Ensure that invalid embedded documents cannot be assigned to embedded document fields. @@ -4388,6 +4397,44 @@ class EmbeddedDocumentListFieldTestCase(MongoDBTestCase): self.assertEqual(custom_data['a'], CustomData.c_field.custom_data['a']) +class TestEmbeddedDocumentField(MongoDBTestCase): + def test___init___(self): + class MyDoc(EmbeddedDocument): + name = StringField() + + field = EmbeddedDocumentField(MyDoc) + self.assertEqual(field.document_type_obj, MyDoc) + + field2 = EmbeddedDocumentField('MyDoc') + self.assertEqual(field2.document_type_obj, 'MyDoc') + + def test___init___throw_error_if_document_type_is_not_EmbeddedDocument(self): + with self.assertRaises(ValidationError): + EmbeddedDocumentField(dict) + + def test_document_type_throw_error_if_not_EmbeddedDocument_subclass(self): + + class MyDoc(Document): + name = StringField() + + emb = EmbeddedDocumentField('MyDoc') + with self.assertRaises(ValidationError) as ctx: + emb.document_type + self.assertIn('Invalid embedded document class provided to an EmbeddedDocumentField', str(ctx.exception)) + + def test_embedded_document_field_only_allow_subclasses_of_embedded_document(self): + # Relates to #1661 + class MyDoc(Document): + name = StringField() + + with self.assertRaises(ValidationError): + class MyFailingDoc(Document): + emb = EmbeddedDocumentField(MyDoc) + + with self.assertRaises(ValidationError): + class MyFailingdoc2(Document): + emb = EmbeddedDocumentField('MyDoc') + class CachedReferenceFieldTest(MongoDBTestCase): def test_cached_reference_field_get_and_save(self): diff --git a/tests/utils.py b/tests/utils.py index 4566d864..acd318c5 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -7,12 +7,12 @@ from mongoengine.connection import get_db, get_connection from mongoengine.python_support import IS_PYMONGO_3 -MONGO_TEST_DB = 'mongoenginetest' +MONGO_TEST_DB = 'mongoenginetest' # standard name for the test database class MongoDBTestCase(unittest.TestCase): """Base class for tests that need a mongodb connection - db is being dropped automatically + It ensures that the db is clean at the beginning and dropped at the end automatically """ @classmethod @@ -32,6 +32,7 @@ def get_mongodb_version(): """ return tuple(get_connection().server_info()['versionArray']) + def _decorated_with_ver_requirement(func, ver_tuple): """Return a given function decorated with the version requirement for a particular MongoDB version tuple. @@ -50,18 +51,21 @@ def _decorated_with_ver_requirement(func, ver_tuple): return _inner + def needs_mongodb_v26(func): """Raise a SkipTest exception if we're working with MongoDB version lower than v2.6. """ return _decorated_with_ver_requirement(func, (2, 6)) + def needs_mongodb_v3(func): """Raise a SkipTest exception if we're working with MongoDB version lower than v3.0. """ return _decorated_with_ver_requirement(func, (3, 0)) + def skip_pymongo3(f): """Raise a SkipTest exception if we're running a test against PyMongo v3.x. From bd524d2e1e51f2c9ea4bc971003eeb34e8e7a090 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Thu, 30 Aug 2018 23:13:10 +0200 Subject: [PATCH 09/14] Documented that it is possible to specify a name when using a dict to define an index --- docs/guide/defining-documents.rst | 3 +++ mongoengine/context_managers.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/docs/guide/defining-documents.rst b/docs/guide/defining-documents.rst index 2a8d5418..366d12c7 100644 --- a/docs/guide/defining-documents.rst +++ b/docs/guide/defining-documents.rst @@ -513,6 +513,9 @@ If a dictionary is passed then the following options are available: Allows you to automatically expire data from a collection by setting the time in seconds to expire the a field. +:attr:`name` (Optional) + Allows you to specify a name for the index + .. note:: Inheritance adds extra fields indices see: :ref:`document-inheritance`. diff --git a/mongoengine/context_managers.py b/mongoengine/context_managers.py index ec2e9e8b..0343e163 100644 --- a/mongoengine/context_managers.py +++ b/mongoengine/context_managers.py @@ -215,7 +215,7 @@ class query_counter(object): """Get the number of queries.""" ignore_query = {'ns': {'$ne': '%s.system.indexes' % self.db.name}} count = self.db.system.profile.find(ignore_query).count() - self.counter - self.counter += 1 + self.counter += 1 # Account for the query we just fired return count From a7852a89cc5d8a9bf0f976c99cd42eacade4ef85 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Sat, 1 Sep 2018 23:30:50 +0200 Subject: [PATCH 10/14] Fixes 2 bugs in no_subclasses context mgr (__exit__ swallows exception + repair feature) --- mongoengine/context_managers.py | 9 +++---- tests/test_context_managers.py | 44 +++++++++++++++++++++++++-------- 2 files changed, 38 insertions(+), 15 deletions(-) diff --git a/mongoengine/context_managers.py b/mongoengine/context_managers.py index ec2e9e8b..cfc0cdd4 100644 --- a/mongoengine/context_managers.py +++ b/mongoengine/context_managers.py @@ -145,18 +145,17 @@ class no_sub_classes(object): :param cls: the class to turn querying sub classes on """ self.cls = cls + self.cls_initial_subclasses = None def __enter__(self): """Change the objects default and _auto_dereference values.""" - self.cls._all_subclasses = self.cls._subclasses - self.cls._subclasses = (self.cls,) + self.cls_initial_subclasses = self.cls._subclasses + self.cls._subclasses = (self.cls._class_name,) return self.cls def __exit__(self, t, value, traceback): """Reset the default and _auto_dereference values.""" - self.cls._subclasses = self.cls._all_subclasses - delattr(self.cls, '_all_subclasses') - return self.cls + self.cls._subclasses = self.cls_initial_subclasses class query_counter(object): diff --git a/tests/test_context_managers.py b/tests/test_context_managers.py index 0f6bf815..8c96016c 100644 --- a/tests/test_context_managers.py +++ b/tests/test_context_managers.py @@ -140,8 +140,6 @@ class ContextManagersTest(unittest.TestCase): def test_no_sub_classes(self): class A(Document): x = IntField() - y = IntField() - meta = {'allow_inheritance': True} class B(A): @@ -152,29 +150,29 @@ class ContextManagersTest(unittest.TestCase): A.drop_collection() - A(x=10, y=20).save() - A(x=15, y=30).save() - B(x=20, y=40).save() - B(x=30, y=50).save() - C(x=40, y=60).save() + A(x=10).save() + A(x=15).save() + B(x=20).save() + B(x=30).save() + C(x=40).save() self.assertEqual(A.objects.count(), 5) self.assertEqual(B.objects.count(), 3) self.assertEqual(C.objects.count(), 1) - with no_sub_classes(A) as A: + with no_sub_classes(A): self.assertEqual(A.objects.count(), 2) for obj in A.objects: self.assertEqual(obj.__class__, A) - with no_sub_classes(B) as B: + with no_sub_classes(B): self.assertEqual(B.objects.count(), 2) for obj in B.objects: self.assertEqual(obj.__class__, B) - with no_sub_classes(C) as C: + with no_sub_classes(C): self.assertEqual(C.objects.count(), 1) for obj in C.objects: @@ -185,6 +183,32 @@ class ContextManagersTest(unittest.TestCase): self.assertEqual(B.objects.count(), 3) self.assertEqual(C.objects.count(), 1) + def test_no_sub_classes_modification_to_document_class_are_temporary(self): + class A(Document): + x = IntField() + meta = {'allow_inheritance': True} + + class B(A): + z = IntField() + + self.assertEqual(A._subclasses, ('A', 'A.B')) + with no_sub_classes(A): + self.assertEqual(A._subclasses, ('A',)) + self.assertEqual(A._subclasses, ('A', 'A.B')) + + self.assertEqual(B._subclasses, ('A.B',)) + with no_sub_classes(B): + self.assertEqual(B._subclasses, ('A.B',)) + self.assertEqual(B._subclasses, ('A.B',)) + + def test_no_subclass_context_manager_does_not_swallow_exception(self): + class User(Document): + name = StringField() + + with self.assertRaises(TypeError): + with no_sub_classes(User): + raise TypeError() + def test_query_counter(self): connect('mongoenginetest') db = get_db() From 408274152baf75485f17eee9cc0550fd7bb82960 Mon Sep 17 00:00:00 2001 From: Erdenezul Batmunkh Date: Tue, 4 Sep 2018 20:24:34 +0800 Subject: [PATCH 11/14] reduce cycle complexity using logic map --- mongoengine/queryset/transform.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/mongoengine/queryset/transform.py b/mongoengine/queryset/transform.py index 5f777f41..555be6f9 100644 --- a/mongoengine/queryset/transform.py +++ b/mongoengine/queryset/transform.py @@ -214,17 +214,20 @@ def update(_doc_cls=None, **update): if parts[0] in UPDATE_OPERATORS: op = parts.pop(0) # Convert Pythonic names to Mongo equivalents - if op in ('push_all', 'pull_all'): - op = op.replace('_all', 'All') - elif op == 'dec': + operator_map = { + 'push_all': 'pushAll', + 'pull_all': 'pullAll', + 'dec': 'inc', + 'add_to_set': 'addToSet', + 'set_on_insert': 'setOnInsert' + } + # If operator doesn't found from operator map, op value will stay + # unchanged + op = operator_map.get(op, op) + if op == 'dec': # Support decrement by flipping a positive value's sign # and using 'inc' - op = 'inc' value = -value - elif op == 'add_to_set': - op = 'addToSet' - elif op == 'set_on_insert': - op = 'setOnInsert' match = None if parts[-1] in COMPARISON_OPERATORS: From e83b529f1c8cc032968797057d7791169fc9bba5 Mon Sep 17 00:00:00 2001 From: Erdenezul Batmunkh Date: Tue, 4 Sep 2018 20:38:42 +0800 Subject: [PATCH 12/14] flip value before changing op to inc --- mongoengine/queryset/transform.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mongoengine/queryset/transform.py b/mongoengine/queryset/transform.py index 555be6f9..a8670543 100644 --- a/mongoengine/queryset/transform.py +++ b/mongoengine/queryset/transform.py @@ -221,13 +221,13 @@ def update(_doc_cls=None, **update): 'add_to_set': 'addToSet', 'set_on_insert': 'setOnInsert' } - # If operator doesn't found from operator map, op value will stay - # unchanged - op = operator_map.get(op, op) if op == 'dec': # Support decrement by flipping a positive value's sign # and using 'inc' value = -value + # If operator doesn't found from operator map, op value will stay + # unchanged + op = operator_map.get(op, op) match = None if parts[-1] in COMPARISON_OPERATORS: From b65478e7d9c8f09d915a102a367e8da52ad6bdf4 Mon Sep 17 00:00:00 2001 From: Erdenezul Batmunkh Date: Tue, 4 Sep 2018 20:44:44 +0800 Subject: [PATCH 13/14] trigger ci --- mongoengine/queryset/transform.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mongoengine/queryset/transform.py b/mongoengine/queryset/transform.py index a8670543..6021d464 100644 --- a/mongoengine/queryset/transform.py +++ b/mongoengine/queryset/transform.py @@ -225,8 +225,8 @@ def update(_doc_cls=None, **update): # Support decrement by flipping a positive value's sign # and using 'inc' value = -value - # If operator doesn't found from operator map, op value will stay - # unchanged + # If the operator doesn't found from operator map, the op value + # will stay unchanged op = operator_map.get(op, op) match = None From ab08e67eaf3f6b809a58740c0cfbbb24e1a3ef0b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Thu, 30 Aug 2018 23:57:16 +0200 Subject: [PATCH 14/14] fix inc/dec operator with decimal --- mongoengine/fields.py | 7 ++- mongoengine/queryset/transform.py | 6 +++ tests/queryset/queryset.py | 79 ++++++++++++++++++++++++++----- 3 files changed, 78 insertions(+), 14 deletions(-) diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 89b901e7..d8eaec4e 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -364,7 +364,8 @@ class FloatField(BaseField): class DecimalField(BaseField): - """Fixed-point decimal number field. + """Fixed-point decimal number field. Stores the value as a float by default unless `force_string` is used. + If using floats, beware of Decimal to float conversion (potential precision loss) .. versionchanged:: 0.8 .. versionadded:: 0.3 @@ -375,7 +376,9 @@ class DecimalField(BaseField): """ :param min_value: Validation rule for the minimum acceptable value. :param max_value: Validation rule for the maximum acceptable value. - :param force_string: Store as a string. + :param force_string: Store the value as a string (instead of a float). + Be aware that this affects query sorting and operation like lte, gte (as string comparison is applied) + and some query operator won't work (e.g: inc, dec) :param precision: Number of decimal places to store. :param rounding: The rounding rule from the python decimal library: diff --git a/mongoengine/queryset/transform.py b/mongoengine/queryset/transform.py index 6021d464..25bd68e0 100644 --- a/mongoengine/queryset/transform.py +++ b/mongoengine/queryset/transform.py @@ -201,14 +201,18 @@ def update(_doc_cls=None, **update): format. """ mongo_update = {} + for key, value in update.items(): if key == '__raw__': mongo_update.update(value) continue + parts = key.split('__') + # if there is no operator, default to 'set' if len(parts) < 3 and parts[0] not in UPDATE_OPERATORS: parts.insert(0, 'set') + # Check for an operator and transform to mongo-style if there is op = None if parts[0] in UPDATE_OPERATORS: @@ -294,6 +298,8 @@ def update(_doc_cls=None, **update): value = field.prepare_query_value(op, value) elif op == 'unset': value = 1 + elif op == 'inc': + value = field.prepare_query_value(op, value) if match: match = '$' + match diff --git a/tests/queryset/queryset.py b/tests/queryset/queryset.py index a405e892..b0dd354d 100644 --- a/tests/queryset/queryset.py +++ b/tests/queryset/queryset.py @@ -3,6 +3,7 @@ import datetime import unittest import uuid +from decimal import Decimal from bson import DBRef, ObjectId from nose.plugins.skip import SkipTest @@ -1851,21 +1852,16 @@ class QuerySetTest(unittest.TestCase): self.assertEqual( 1, BlogPost.objects(author__in=["%s" % me.pk]).count()) - def test_update(self): - """Ensure that atomic updates work properly. - """ + def test_update_intfield_operator(self): class BlogPost(Document): - name = StringField() - title = StringField() hits = IntField() - tags = ListField(StringField()) BlogPost.drop_collection() - post = BlogPost(name="Test Post", hits=5, tags=['test']) + post = BlogPost(hits=5) post.save() - BlogPost.objects.update(set__hits=10) + BlogPost.objects.update_one(set__hits=10) post.reload() self.assertEqual(post.hits, 10) @@ -1882,6 +1878,55 @@ class QuerySetTest(unittest.TestCase): post.reload() self.assertEqual(post.hits, 11) + def test_update_decimalfield_operator(self): + class BlogPost(Document): + review = DecimalField() + + BlogPost.drop_collection() + + post = BlogPost(review=3.5) + post.save() + + BlogPost.objects.update_one(inc__review=0.1) # test with floats + post.reload() + self.assertEqual(float(post.review), 3.6) + + BlogPost.objects.update_one(dec__review=0.1) + post.reload() + self.assertEqual(float(post.review), 3.5) + + BlogPost.objects.update_one(inc__review=Decimal(0.12)) # test with Decimal + post.reload() + self.assertEqual(float(post.review), 3.62) + + BlogPost.objects.update_one(dec__review=Decimal(0.12)) + post.reload() + self.assertEqual(float(post.review), 3.5) + + def test_update_decimalfield_operator_not_working_with_force_string(self): + class BlogPost(Document): + review = DecimalField(force_string=True) + + BlogPost.drop_collection() + + post = BlogPost(review=3.5) + post.save() + + with self.assertRaises(OperationError): + BlogPost.objects.update_one(inc__review=0.1) # test with floats + + def test_update_listfield_operator(self): + """Ensure that atomic updates work properly. + """ + class BlogPost(Document): + tags = ListField(StringField()) + + BlogPost.drop_collection() + + post = BlogPost(tags=['test']) + post.save() + + # ListField operator BlogPost.objects.update(push__tags='mongo') post.reload() self.assertTrue('mongo' in post.tags) @@ -1900,13 +1945,23 @@ class QuerySetTest(unittest.TestCase): post.reload() self.assertEqual(post.tags.count('unique'), 1) - self.assertNotEqual(post.hits, None) - BlogPost.objects.update_one(unset__hits=1) - post.reload() - self.assertEqual(post.hits, None) + BlogPost.drop_collection() + + def test_update_unset(self): + class BlogPost(Document): + title = StringField() BlogPost.drop_collection() + post = BlogPost(title='garbage').save() + + self.assertNotEqual(post.title, None) + BlogPost.objects.update_one(unset__title=1) + post.reload() + self.assertEqual(post.title, None) + pymongo_doc = BlogPost.objects.as_pymongo().first() + self.assertNotIn('title', pymongo_doc) + @needs_mongodb_v26 def test_update_push_with_position(self): """Ensure that the 'push' update with position works properly.