Merge branch 'master' of https://github.com/MongoEngine/mongoengine into dax_py3
This commit is contained in:
		| @@ -513,6 +513,9 @@ If a dictionary is passed then the following options are available: | |||||||
|     Allows you to automatically expire data from a collection by setting the |     Allows you to automatically expire data from a collection by setting the | ||||||
|     time in seconds to expire the a field. |     time in seconds to expire the a field. | ||||||
|  |  | ||||||
|  | :attr:`name` (Optional) | ||||||
|  |     Allows you to specify a name for the index | ||||||
|  |  | ||||||
| .. note:: | .. note:: | ||||||
|  |  | ||||||
|     Inheritance adds extra fields indices see: :ref:`document-inheritance`. |     Inheritance adds extra fields indices see: :ref:`document-inheritance`. | ||||||
|   | |||||||
| @@ -57,7 +57,8 @@ document values for example:: | |||||||
|  |  | ||||||
|         def clean(self): |         def clean(self): | ||||||
|             """Ensures that only published essays have a `pub_date` and |             """Ensures that only published essays have a `pub_date` and | ||||||
|             automatically sets the pub_date if published and not set""" |             automatically sets `pub_date` if essay is published and `pub_date` | ||||||
|  |             is not set""" | ||||||
|             if self.status == 'Draft' and self.pub_date is not None: |             if self.status == 'Draft' and self.pub_date is not None: | ||||||
|                 msg = 'Draft entries should not have a publication date.' |                 msg = 'Draft entries should not have a publication date.' | ||||||
|                 raise ValidationError(msg) |                 raise ValidationError(msg) | ||||||
|   | |||||||
| @@ -104,6 +104,18 @@ def register_connection(alias, db=None, name=None, host=None, port=None, | |||||||
|                 conn_settings['authentication_source'] = uri_options['authsource'] |                 conn_settings['authentication_source'] = uri_options['authsource'] | ||||||
|             if 'authmechanism' in uri_options: |             if 'authmechanism' in uri_options: | ||||||
|                 conn_settings['authentication_mechanism'] = uri_options['authmechanism'] |                 conn_settings['authentication_mechanism'] = uri_options['authmechanism'] | ||||||
|  |             if IS_PYMONGO_3 and 'readpreference' in uri_options: | ||||||
|  |                 read_preferences = ( | ||||||
|  |                     ReadPreference.NEAREST, | ||||||
|  |                     ReadPreference.PRIMARY, | ||||||
|  |                     ReadPreference.PRIMARY_PREFERRED, | ||||||
|  |                     ReadPreference.SECONDARY, | ||||||
|  |                     ReadPreference.SECONDARY_PREFERRED) | ||||||
|  |                 read_pf_mode = uri_options['readpreference'].lower() | ||||||
|  |                 for preference in read_preferences: | ||||||
|  |                     if preference.name.lower() == read_pf_mode: | ||||||
|  |                         conn_settings['read_preference'] = preference | ||||||
|  |                         break | ||||||
|         else: |         else: | ||||||
|             resolved_hosts.append(entity) |             resolved_hosts.append(entity) | ||||||
|     conn_settings['host'] = resolved_hosts |     conn_settings['host'] = resolved_hosts | ||||||
|   | |||||||
| @@ -145,18 +145,17 @@ class no_sub_classes(object): | |||||||
|         :param cls: the class to turn querying sub classes on |         :param cls: the class to turn querying sub classes on | ||||||
|         """ |         """ | ||||||
|         self.cls = cls |         self.cls = cls | ||||||
|  |         self.cls_initial_subclasses = None | ||||||
|  |  | ||||||
|     def __enter__(self): |     def __enter__(self): | ||||||
|         """Change the objects default and _auto_dereference values.""" |         """Change the objects default and _auto_dereference values.""" | ||||||
|         self.cls._all_subclasses = self.cls._subclasses |         self.cls_initial_subclasses = self.cls._subclasses | ||||||
|         self.cls._subclasses = (self.cls,) |         self.cls._subclasses = (self.cls._class_name,) | ||||||
|         return self.cls |         return self.cls | ||||||
|  |  | ||||||
|     def __exit__(self, t, value, traceback): |     def __exit__(self, t, value, traceback): | ||||||
|         """Reset the default and _auto_dereference values.""" |         """Reset the default and _auto_dereference values.""" | ||||||
|         self.cls._subclasses = self.cls._all_subclasses |         self.cls._subclasses = self.cls_initial_subclasses | ||||||
|         delattr(self.cls, '_all_subclasses') |  | ||||||
|         return self.cls |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class query_counter(object): | class query_counter(object): | ||||||
| @@ -215,7 +214,7 @@ class query_counter(object): | |||||||
|         """Get the number of queries.""" |         """Get the number of queries.""" | ||||||
|         ignore_query = {'ns': {'$ne': '%s.system.indexes' % self.db.name}} |         ignore_query = {'ns': {'$ne': '%s.system.indexes' % self.db.name}} | ||||||
|         count = self.db.system.profile.find(ignore_query).count() - self.counter |         count = self.db.system.profile.find(ignore_query).count() - self.counter | ||||||
|         self.counter += 1 |         self.counter += 1       # Account for the query we just fired | ||||||
|         return count |         return count | ||||||
|  |  | ||||||
|  |  | ||||||
|   | |||||||
| @@ -364,7 +364,8 @@ class FloatField(BaseField): | |||||||
|  |  | ||||||
|  |  | ||||||
| class DecimalField(BaseField): | class DecimalField(BaseField): | ||||||
|     """Fixed-point decimal number field. |     """Fixed-point decimal number field. Stores the value as a float by default unless `force_string` is used. | ||||||
|  |     If using floats, beware of Decimal to float conversion (potential precision loss) | ||||||
|  |  | ||||||
|     .. versionchanged:: 0.8 |     .. versionchanged:: 0.8 | ||||||
|     .. versionadded:: 0.3 |     .. versionadded:: 0.3 | ||||||
| @@ -375,7 +376,9 @@ class DecimalField(BaseField): | |||||||
|         """ |         """ | ||||||
|         :param min_value: Validation rule for the minimum acceptable value. |         :param min_value: Validation rule for the minimum acceptable value. | ||||||
|         :param max_value: Validation rule for the maximum acceptable value. |         :param max_value: Validation rule for the maximum acceptable value. | ||||||
|         :param force_string: Store as a string. |         :param force_string: Store the value as a string (instead of a float). | ||||||
|  |          Be aware that this affects query sorting and operation like lte, gte (as string comparison is applied) | ||||||
|  |          and some query operator won't work (e.g: inc, dec) | ||||||
|         :param precision: Number of decimal places to store. |         :param precision: Number of decimal places to store. | ||||||
|         :param rounding: The rounding rule from the python decimal library: |         :param rounding: The rounding rule from the python decimal library: | ||||||
|  |  | ||||||
| @@ -647,9 +650,17 @@ class EmbeddedDocumentField(BaseField): | |||||||
|     def document_type(self): |     def document_type(self): | ||||||
|         if isinstance(self.document_type_obj, six.string_types): |         if isinstance(self.document_type_obj, six.string_types): | ||||||
|             if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT: |             if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT: | ||||||
|                 self.document_type_obj = self.owner_document |                 resolved_document_type = self.owner_document | ||||||
|             else: |             else: | ||||||
|                 self.document_type_obj = get_document(self.document_type_obj) |                 resolved_document_type = get_document(self.document_type_obj) | ||||||
|  |  | ||||||
|  |             if not issubclass(resolved_document_type, EmbeddedDocument): | ||||||
|  |                 # Due to the late resolution of the document_type | ||||||
|  |                 # There is a chance that it won't be an EmbeddedDocument (#1661) | ||||||
|  |                 self.error('Invalid embedded document class provided to an ' | ||||||
|  |                            'EmbeddedDocumentField') | ||||||
|  |             self.document_type_obj = resolved_document_type | ||||||
|  |  | ||||||
|         return self.document_type_obj |         return self.document_type_obj | ||||||
|  |  | ||||||
|     def to_python(self, value): |     def to_python(self, value): | ||||||
| @@ -1029,11 +1040,13 @@ class ReferenceField(BaseField): | |||||||
|  |  | ||||||
|     .. code-block:: python |     .. code-block:: python | ||||||
|  |  | ||||||
|         class Bar(Document): |         class Org(Document): | ||||||
|             content = StringField() |             owner = ReferenceField('User') | ||||||
|             foo = ReferenceField('Foo') |  | ||||||
|  |  | ||||||
|         Foo.register_delete_rule(Bar, 'foo', NULLIFY) |         class User(Document): | ||||||
|  |             org = ReferenceField('Org', reverse_delete_rule=CASCADE) | ||||||
|  |  | ||||||
|  |         User.register_delete_rule(Org, 'owner', DENY) | ||||||
|  |  | ||||||
|     .. versionchanged:: 0.5 added `reverse_delete_rule` |     .. versionchanged:: 0.5 added `reverse_delete_rule` | ||||||
|     """ |     """ | ||||||
|   | |||||||
| @@ -147,7 +147,7 @@ def query(_doc_cls=None, **kwargs): | |||||||
|         if op is None or key not in mongo_query: |         if op is None or key not in mongo_query: | ||||||
|             mongo_query[key] = value |             mongo_query[key] = value | ||||||
|         elif key in mongo_query: |         elif key in mongo_query: | ||||||
|             if isinstance(mongo_query[key], dict): |             if isinstance(mongo_query[key], dict) and isinstance(value, dict): | ||||||
|                 mongo_query[key].update(value) |                 mongo_query[key].update(value) | ||||||
|                 # $max/minDistance needs to come last - convert to SON |                 # $max/minDistance needs to come last - convert to SON | ||||||
|                 value_dict = mongo_query[key] |                 value_dict = mongo_query[key] | ||||||
| @@ -201,30 +201,37 @@ def update(_doc_cls=None, **update): | |||||||
|     format. |     format. | ||||||
|     """ |     """ | ||||||
|     mongo_update = {} |     mongo_update = {} | ||||||
|  |  | ||||||
|     for key, value in update.items(): |     for key, value in update.items(): | ||||||
|         if key == '__raw__': |         if key == '__raw__': | ||||||
|             mongo_update.update(value) |             mongo_update.update(value) | ||||||
|             continue |             continue | ||||||
|  |  | ||||||
|         parts = key.split('__') |         parts = key.split('__') | ||||||
|  |  | ||||||
|         # if there is no operator, default to 'set' |         # if there is no operator, default to 'set' | ||||||
|         if len(parts) < 3 and parts[0] not in UPDATE_OPERATORS: |         if len(parts) < 3 and parts[0] not in UPDATE_OPERATORS: | ||||||
|             parts.insert(0, 'set') |             parts.insert(0, 'set') | ||||||
|  |  | ||||||
|         # Check for an operator and transform to mongo-style if there is |         # Check for an operator and transform to mongo-style if there is | ||||||
|         op = None |         op = None | ||||||
|         if parts[0] in UPDATE_OPERATORS: |         if parts[0] in UPDATE_OPERATORS: | ||||||
|             op = parts.pop(0) |             op = parts.pop(0) | ||||||
|             # Convert Pythonic names to Mongo equivalents |             # Convert Pythonic names to Mongo equivalents | ||||||
|             if op in ('push_all', 'pull_all'): |             operator_map = { | ||||||
|                 op = op.replace('_all', 'All') |                 'push_all': 'pushAll', | ||||||
|             elif op == 'dec': |                 'pull_all': 'pullAll', | ||||||
|  |                 'dec': 'inc', | ||||||
|  |                 'add_to_set': 'addToSet', | ||||||
|  |                 'set_on_insert': 'setOnInsert' | ||||||
|  |             } | ||||||
|  |             if op == 'dec': | ||||||
|                 # Support decrement by flipping a positive value's sign |                 # Support decrement by flipping a positive value's sign | ||||||
|                 # and using 'inc' |                 # and using 'inc' | ||||||
|                 op = 'inc' |  | ||||||
|                 value = -value |                 value = -value | ||||||
|             elif op == 'add_to_set': |             # If the operator doesn't found from operator map, the op value | ||||||
|                 op = 'addToSet' |             # will stay unchanged | ||||||
|             elif op == 'set_on_insert': |             op = operator_map.get(op, op) | ||||||
|                 op = 'setOnInsert' |  | ||||||
|  |  | ||||||
|         match = None |         match = None | ||||||
|         if parts[-1] in COMPARISON_OPERATORS: |         if parts[-1] in COMPARISON_OPERATORS: | ||||||
| @@ -291,6 +298,8 @@ def update(_doc_cls=None, **update): | |||||||
|                     value = field.prepare_query_value(op, value) |                     value = field.prepare_query_value(op, value) | ||||||
|             elif op == 'unset': |             elif op == 'unset': | ||||||
|                 value = 1 |                 value = 1 | ||||||
|  |             elif op == 'inc': | ||||||
|  |                 value = field.prepare_query_value(op, value) | ||||||
|  |  | ||||||
|         if match: |         if match: | ||||||
|             match = '$' + match |             match = '$' + match | ||||||
|   | |||||||
| @@ -2147,6 +2147,15 @@ class FieldTest(MongoDBTestCase): | |||||||
|         ])) |         ])) | ||||||
|         self.assertEqual(a.b.c.txt, 'hi') |         self.assertEqual(a.b.c.txt, 'hi') | ||||||
|  |  | ||||||
|  |     def test_embedded_document_field_cant_reference_using_a_str_if_it_does_not_exist_yet(self): | ||||||
|  |         raise SkipTest("Using a string reference in an EmbeddedDocumentField does not work if the class isnt registerd yet") | ||||||
|  |  | ||||||
|  |         class MyDoc2(Document): | ||||||
|  |             emb = EmbeddedDocumentField('MyDoc') | ||||||
|  |  | ||||||
|  |         class MyDoc(EmbeddedDocument): | ||||||
|  |             name = StringField() | ||||||
|  |  | ||||||
|     def test_embedded_document_validation(self): |     def test_embedded_document_validation(self): | ||||||
|         """Ensure that invalid embedded documents cannot be assigned to |         """Ensure that invalid embedded documents cannot be assigned to | ||||||
|         embedded document fields. |         embedded document fields. | ||||||
| @@ -4388,6 +4397,44 @@ class EmbeddedDocumentListFieldTestCase(MongoDBTestCase): | |||||||
|         self.assertEqual(custom_data['a'], CustomData.c_field.custom_data['a']) |         self.assertEqual(custom_data['a'], CustomData.c_field.custom_data['a']) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class TestEmbeddedDocumentField(MongoDBTestCase): | ||||||
|  |     def test___init___(self): | ||||||
|  |         class MyDoc(EmbeddedDocument): | ||||||
|  |             name = StringField() | ||||||
|  |  | ||||||
|  |         field = EmbeddedDocumentField(MyDoc) | ||||||
|  |         self.assertEqual(field.document_type_obj, MyDoc) | ||||||
|  |  | ||||||
|  |         field2 = EmbeddedDocumentField('MyDoc') | ||||||
|  |         self.assertEqual(field2.document_type_obj, 'MyDoc') | ||||||
|  |  | ||||||
|  |     def test___init___throw_error_if_document_type_is_not_EmbeddedDocument(self): | ||||||
|  |         with self.assertRaises(ValidationError): | ||||||
|  |             EmbeddedDocumentField(dict) | ||||||
|  |  | ||||||
|  |     def test_document_type_throw_error_if_not_EmbeddedDocument_subclass(self): | ||||||
|  |  | ||||||
|  |         class MyDoc(Document): | ||||||
|  |             name = StringField() | ||||||
|  |  | ||||||
|  |         emb = EmbeddedDocumentField('MyDoc') | ||||||
|  |         with self.assertRaises(ValidationError) as ctx: | ||||||
|  |             emb.document_type | ||||||
|  |         self.assertIn('Invalid embedded document class provided to an EmbeddedDocumentField', str(ctx.exception)) | ||||||
|  |  | ||||||
|  |     def test_embedded_document_field_only_allow_subclasses_of_embedded_document(self): | ||||||
|  |         # Relates to #1661 | ||||||
|  |         class MyDoc(Document): | ||||||
|  |             name = StringField() | ||||||
|  |  | ||||||
|  |         with self.assertRaises(ValidationError): | ||||||
|  |             class MyFailingDoc(Document): | ||||||
|  |                 emb = EmbeddedDocumentField(MyDoc) | ||||||
|  |  | ||||||
|  |         with self.assertRaises(ValidationError): | ||||||
|  |             class MyFailingdoc2(Document): | ||||||
|  |                 emb = EmbeddedDocumentField('MyDoc') | ||||||
|  |  | ||||||
| class CachedReferenceFieldTest(MongoDBTestCase): | class CachedReferenceFieldTest(MongoDBTestCase): | ||||||
|  |  | ||||||
|     def test_cached_reference_field_get_and_save(self): |     def test_cached_reference_field_get_and_save(self): | ||||||
|   | |||||||
| @@ -3,6 +3,7 @@ | |||||||
| import datetime | import datetime | ||||||
| import unittest | import unittest | ||||||
| import uuid | import uuid | ||||||
|  | from decimal import Decimal | ||||||
|  |  | ||||||
| from bson import DBRef, ObjectId | from bson import DBRef, ObjectId | ||||||
| from nose.plugins.skip import SkipTest | from nose.plugins.skip import SkipTest | ||||||
| @@ -1202,6 +1203,14 @@ class QuerySetTest(unittest.TestCase): | |||||||
|         BlogPost.drop_collection() |         BlogPost.drop_collection() | ||||||
|         Blog.drop_collection() |         Blog.drop_collection() | ||||||
|  |  | ||||||
|  |     def test_filter_chaining_with_regex(self): | ||||||
|  |         person = self.Person(name='Guido van Rossum') | ||||||
|  |         person.save() | ||||||
|  |  | ||||||
|  |         people = self.Person.objects | ||||||
|  |         people = people.filter(name__startswith='Gui').filter(name__not__endswith='tum') | ||||||
|  |         self.assertEqual(people.count(), 1) | ||||||
|  |  | ||||||
|     def assertSequence(self, qs, expected): |     def assertSequence(self, qs, expected): | ||||||
|         qs = list(qs) |         qs = list(qs) | ||||||
|         expected = list(expected) |         expected = list(expected) | ||||||
| @@ -1851,21 +1860,16 @@ class QuerySetTest(unittest.TestCase): | |||||||
|         self.assertEqual( |         self.assertEqual( | ||||||
|             1, BlogPost.objects(author__in=["%s" % me.pk]).count()) |             1, BlogPost.objects(author__in=["%s" % me.pk]).count()) | ||||||
|  |  | ||||||
|     def test_update(self): |     def test_update_intfield_operator(self): | ||||||
|         """Ensure that atomic updates work properly. |  | ||||||
|         """ |  | ||||||
|         class BlogPost(Document): |         class BlogPost(Document): | ||||||
|             name = StringField() |  | ||||||
|             title = StringField() |  | ||||||
|             hits = IntField() |             hits = IntField() | ||||||
|             tags = ListField(StringField()) |  | ||||||
|  |  | ||||||
|         BlogPost.drop_collection() |         BlogPost.drop_collection() | ||||||
|  |  | ||||||
|         post = BlogPost(name="Test Post", hits=5, tags=['test']) |         post = BlogPost(hits=5) | ||||||
|         post.save() |         post.save() | ||||||
|  |  | ||||||
|         BlogPost.objects.update(set__hits=10) |         BlogPost.objects.update_one(set__hits=10) | ||||||
|         post.reload() |         post.reload() | ||||||
|         self.assertEqual(post.hits, 10) |         self.assertEqual(post.hits, 10) | ||||||
|  |  | ||||||
| @@ -1882,6 +1886,55 @@ class QuerySetTest(unittest.TestCase): | |||||||
|         post.reload() |         post.reload() | ||||||
|         self.assertEqual(post.hits, 11) |         self.assertEqual(post.hits, 11) | ||||||
|  |  | ||||||
|  |     def test_update_decimalfield_operator(self): | ||||||
|  |         class BlogPost(Document): | ||||||
|  |             review = DecimalField() | ||||||
|  |  | ||||||
|  |         BlogPost.drop_collection() | ||||||
|  |  | ||||||
|  |         post = BlogPost(review=3.5) | ||||||
|  |         post.save() | ||||||
|  |  | ||||||
|  |         BlogPost.objects.update_one(inc__review=0.1)             # test with floats | ||||||
|  |         post.reload() | ||||||
|  |         self.assertEqual(float(post.review), 3.6) | ||||||
|  |  | ||||||
|  |         BlogPost.objects.update_one(dec__review=0.1) | ||||||
|  |         post.reload() | ||||||
|  |         self.assertEqual(float(post.review), 3.5) | ||||||
|  |  | ||||||
|  |         BlogPost.objects.update_one(inc__review=Decimal(0.12))   # test with Decimal | ||||||
|  |         post.reload() | ||||||
|  |         self.assertEqual(float(post.review), 3.62) | ||||||
|  |  | ||||||
|  |         BlogPost.objects.update_one(dec__review=Decimal(0.12)) | ||||||
|  |         post.reload() | ||||||
|  |         self.assertEqual(float(post.review), 3.5) | ||||||
|  |  | ||||||
|  |     def test_update_decimalfield_operator_not_working_with_force_string(self): | ||||||
|  |         class BlogPost(Document): | ||||||
|  |             review = DecimalField(force_string=True) | ||||||
|  |  | ||||||
|  |         BlogPost.drop_collection() | ||||||
|  |  | ||||||
|  |         post = BlogPost(review=3.5) | ||||||
|  |         post.save() | ||||||
|  |  | ||||||
|  |         with self.assertRaises(OperationError): | ||||||
|  |             BlogPost.objects.update_one(inc__review=0.1)             # test with floats | ||||||
|  |  | ||||||
|  |     def test_update_listfield_operator(self): | ||||||
|  |         """Ensure that atomic updates work properly. | ||||||
|  |         """ | ||||||
|  |         class BlogPost(Document): | ||||||
|  |             tags = ListField(StringField()) | ||||||
|  |  | ||||||
|  |         BlogPost.drop_collection() | ||||||
|  |  | ||||||
|  |         post = BlogPost(tags=['test']) | ||||||
|  |         post.save() | ||||||
|  |  | ||||||
|  |         # ListField operator | ||||||
|         BlogPost.objects.update(push__tags='mongo') |         BlogPost.objects.update(push__tags='mongo') | ||||||
|         post.reload() |         post.reload() | ||||||
|         self.assertTrue('mongo' in post.tags) |         self.assertTrue('mongo' in post.tags) | ||||||
| @@ -1900,13 +1953,23 @@ class QuerySetTest(unittest.TestCase): | |||||||
|         post.reload() |         post.reload() | ||||||
|         self.assertEqual(post.tags.count('unique'), 1) |         self.assertEqual(post.tags.count('unique'), 1) | ||||||
|  |  | ||||||
|         self.assertNotEqual(post.hits, None) |         BlogPost.drop_collection() | ||||||
|         BlogPost.objects.update_one(unset__hits=1) |  | ||||||
|         post.reload() |     def test_update_unset(self): | ||||||
|         self.assertEqual(post.hits, None) |         class BlogPost(Document): | ||||||
|  |             title = StringField() | ||||||
|  |  | ||||||
|         BlogPost.drop_collection() |         BlogPost.drop_collection() | ||||||
|  |  | ||||||
|  |         post = BlogPost(title='garbage').save() | ||||||
|  |  | ||||||
|  |         self.assertNotEqual(post.title, None) | ||||||
|  |         BlogPost.objects.update_one(unset__title=1) | ||||||
|  |         post.reload() | ||||||
|  |         self.assertEqual(post.title, None) | ||||||
|  |         pymongo_doc = BlogPost.objects.as_pymongo().first() | ||||||
|  |         self.assertNotIn('title', pymongo_doc) | ||||||
|  |  | ||||||
|     @needs_mongodb_v26 |     @needs_mongodb_v26 | ||||||
|     def test_update_push_with_position(self): |     def test_update_push_with_position(self): | ||||||
|         """Ensure that the 'push' update with position works properly. |         """Ensure that the 'push' update with position works properly. | ||||||
|   | |||||||
| @@ -364,6 +364,12 @@ class ConnectionTest(unittest.TestCase): | |||||||
|         date_doc = DateDoc.objects.first() |         date_doc = DateDoc.objects.first() | ||||||
|         self.assertEqual(d, date_doc.the_date) |         self.assertEqual(d, date_doc.the_date) | ||||||
|  |  | ||||||
|  |     def test_read_preference_from_parse(self): | ||||||
|  |         if IS_PYMONGO_3: | ||||||
|  |             from pymongo import ReadPreference | ||||||
|  |             conn = connect(host="mongodb://a1.vpc,a2.vpc,a3.vpc/prod?readPreference=secondaryPreferred") | ||||||
|  |             self.assertEqual(conn.read_preference, ReadPreference.SECONDARY_PREFERRED) | ||||||
|  |  | ||||||
|     def test_multiple_connection_settings(self): |     def test_multiple_connection_settings(self): | ||||||
|         connect('mongoenginetest', alias='t1', host="localhost") |         connect('mongoenginetest', alias='t1', host="localhost") | ||||||
|  |  | ||||||
|   | |||||||
| @@ -140,8 +140,6 @@ class ContextManagersTest(unittest.TestCase): | |||||||
|     def test_no_sub_classes(self): |     def test_no_sub_classes(self): | ||||||
|         class A(Document): |         class A(Document): | ||||||
|             x = IntField() |             x = IntField() | ||||||
|             y = IntField() |  | ||||||
|  |  | ||||||
|             meta = {'allow_inheritance': True} |             meta = {'allow_inheritance': True} | ||||||
|  |  | ||||||
|         class B(A): |         class B(A): | ||||||
| @@ -152,29 +150,29 @@ class ContextManagersTest(unittest.TestCase): | |||||||
|  |  | ||||||
|         A.drop_collection() |         A.drop_collection() | ||||||
|  |  | ||||||
|         A(x=10, y=20).save() |         A(x=10).save() | ||||||
|         A(x=15, y=30).save() |         A(x=15).save() | ||||||
|         B(x=20, y=40).save() |         B(x=20).save() | ||||||
|         B(x=30, y=50).save() |         B(x=30).save() | ||||||
|         C(x=40, y=60).save() |         C(x=40).save() | ||||||
|  |  | ||||||
|         self.assertEqual(A.objects.count(), 5) |         self.assertEqual(A.objects.count(), 5) | ||||||
|         self.assertEqual(B.objects.count(), 3) |         self.assertEqual(B.objects.count(), 3) | ||||||
|         self.assertEqual(C.objects.count(), 1) |         self.assertEqual(C.objects.count(), 1) | ||||||
|  |  | ||||||
|         with no_sub_classes(A) as A: |         with no_sub_classes(A): | ||||||
|             self.assertEqual(A.objects.count(), 2) |             self.assertEqual(A.objects.count(), 2) | ||||||
|  |  | ||||||
|             for obj in A.objects: |             for obj in A.objects: | ||||||
|                 self.assertEqual(obj.__class__, A) |                 self.assertEqual(obj.__class__, A) | ||||||
|  |  | ||||||
|         with no_sub_classes(B) as B: |         with no_sub_classes(B): | ||||||
|             self.assertEqual(B.objects.count(), 2) |             self.assertEqual(B.objects.count(), 2) | ||||||
|  |  | ||||||
|             for obj in B.objects: |             for obj in B.objects: | ||||||
|                 self.assertEqual(obj.__class__, B) |                 self.assertEqual(obj.__class__, B) | ||||||
|  |  | ||||||
|         with no_sub_classes(C) as C: |         with no_sub_classes(C): | ||||||
|             self.assertEqual(C.objects.count(), 1) |             self.assertEqual(C.objects.count(), 1) | ||||||
|  |  | ||||||
|             for obj in C.objects: |             for obj in C.objects: | ||||||
| @@ -185,6 +183,32 @@ class ContextManagersTest(unittest.TestCase): | |||||||
|         self.assertEqual(B.objects.count(), 3) |         self.assertEqual(B.objects.count(), 3) | ||||||
|         self.assertEqual(C.objects.count(), 1) |         self.assertEqual(C.objects.count(), 1) | ||||||
|  |  | ||||||
|  |     def test_no_sub_classes_modification_to_document_class_are_temporary(self): | ||||||
|  |         class A(Document): | ||||||
|  |             x = IntField() | ||||||
|  |             meta = {'allow_inheritance': True} | ||||||
|  |  | ||||||
|  |         class B(A): | ||||||
|  |             z = IntField() | ||||||
|  |  | ||||||
|  |         self.assertEqual(A._subclasses, ('A', 'A.B')) | ||||||
|  |         with no_sub_classes(A): | ||||||
|  |             self.assertEqual(A._subclasses, ('A',)) | ||||||
|  |         self.assertEqual(A._subclasses, ('A', 'A.B')) | ||||||
|  |  | ||||||
|  |         self.assertEqual(B._subclasses, ('A.B',)) | ||||||
|  |         with no_sub_classes(B): | ||||||
|  |             self.assertEqual(B._subclasses, ('A.B',)) | ||||||
|  |         self.assertEqual(B._subclasses, ('A.B',)) | ||||||
|  |  | ||||||
|  |     def test_no_subclass_context_manager_does_not_swallow_exception(self): | ||||||
|  |         class User(Document): | ||||||
|  |             name = StringField() | ||||||
|  |  | ||||||
|  |         with self.assertRaises(TypeError): | ||||||
|  |             with no_sub_classes(User): | ||||||
|  |                 raise TypeError() | ||||||
|  |  | ||||||
|     def test_query_counter(self): |     def test_query_counter(self): | ||||||
|         connect('mongoenginetest') |         connect('mongoenginetest') | ||||||
|         db = get_db() |         db = get_db() | ||||||
|   | |||||||
| @@ -7,12 +7,12 @@ from mongoengine.connection import get_db, get_connection | |||||||
| from mongoengine.python_support import IS_PYMONGO_3 | from mongoengine.python_support import IS_PYMONGO_3 | ||||||
|  |  | ||||||
|  |  | ||||||
| MONGO_TEST_DB = 'mongoenginetest' | MONGO_TEST_DB = 'mongoenginetest'   # standard name for the test database | ||||||
|  |  | ||||||
|  |  | ||||||
| class MongoDBTestCase(unittest.TestCase): | class MongoDBTestCase(unittest.TestCase): | ||||||
|     """Base class for tests that need a mongodb connection |     """Base class for tests that need a mongodb connection | ||||||
|     db is being dropped automatically |     It ensures that the db is clean at the beginning and dropped at the end automatically | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
| @@ -32,6 +32,7 @@ def get_mongodb_version(): | |||||||
|     """ |     """ | ||||||
|     return tuple(get_connection().server_info()['versionArray']) |     return tuple(get_connection().server_info()['versionArray']) | ||||||
|  |  | ||||||
|  |  | ||||||
| def _decorated_with_ver_requirement(func, ver_tuple): | def _decorated_with_ver_requirement(func, ver_tuple): | ||||||
|     """Return a given function decorated with the version requirement |     """Return a given function decorated with the version requirement | ||||||
|     for a particular MongoDB version tuple. |     for a particular MongoDB version tuple. | ||||||
| @@ -50,18 +51,21 @@ def _decorated_with_ver_requirement(func, ver_tuple): | |||||||
|  |  | ||||||
|     return _inner |     return _inner | ||||||
|  |  | ||||||
|  |  | ||||||
| def needs_mongodb_v26(func): | def needs_mongodb_v26(func): | ||||||
|     """Raise a SkipTest exception if we're working with MongoDB version |     """Raise a SkipTest exception if we're working with MongoDB version | ||||||
|     lower than v2.6. |     lower than v2.6. | ||||||
|     """ |     """ | ||||||
|     return _decorated_with_ver_requirement(func, (2, 6)) |     return _decorated_with_ver_requirement(func, (2, 6)) | ||||||
|  |  | ||||||
|  |  | ||||||
| def needs_mongodb_v3(func): | def needs_mongodb_v3(func): | ||||||
|     """Raise a SkipTest exception if we're working with MongoDB version |     """Raise a SkipTest exception if we're working with MongoDB version | ||||||
|     lower than v3.0. |     lower than v3.0. | ||||||
|     """ |     """ | ||||||
|     return _decorated_with_ver_requirement(func, (3, 0)) |     return _decorated_with_ver_requirement(func, (3, 0)) | ||||||
|  |  | ||||||
|  |  | ||||||
| def skip_pymongo3(f): | def skip_pymongo3(f): | ||||||
|     """Raise a SkipTest exception if we're running a test against |     """Raise a SkipTest exception if we're running a test against | ||||||
|     PyMongo v3.x. |     PyMongo v3.x. | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user