Compare commits
	
		
			13 Commits
		
	
	
		
			simpler-in
			...
			v0.10.8
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|  | c6c5f85abb | ||
|  | 7b860f7739 | ||
|  | e28804c03a | ||
|  | 1b9432824b | ||
|  | 25e0f12976 | ||
|  | f168682a68 | ||
|  | d25058a46d | ||
|  | 4d0c092d9f | ||
|  | 15714ef855 | ||
|  | eb743beaa3 | ||
|  | 0007535a46 | ||
|  | 8391af026c | ||
|  | 800f656dcf | 
							
								
								
									
										2
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -15,3 +15,5 @@ env/ | ||||
| .pydevproject | ||||
| tests/test_bugfix.py | ||||
| htmlcov/ | ||||
| venv | ||||
| venv3 | ||||
|   | ||||
| @@ -1,7 +1,7 @@ | ||||
| language: python | ||||
|  | ||||
| python: | ||||
| - '2.6' | ||||
| - '2.6'  # TODO remove in v0.11.0 | ||||
| - '2.7' | ||||
| - '3.3' | ||||
| - '3.4' | ||||
|   | ||||
| @@ -4,9 +4,19 @@ Changelog | ||||
|  | ||||
| Changes in 0.10.8 | ||||
| ================= | ||||
| - Added support for QuerySet.batch_size (#1426) | ||||
| - Fixed query set iteration within iteration #1427 | ||||
| - Fixed an issue where specifying a MongoDB URI host would override more information than it should #1421 | ||||
| - Added ability to filter the generic reference field by ObjectId and DBRef #1425 | ||||
| - Fixed delete cascade for models with a custom primary key field #1247 | ||||
| - Added ability to specify an authentication mechanism (e.g. X.509) #1333 | ||||
| - Added support for falsey primary keys (e.g. doc.pk = 0) #1354 | ||||
| - Fixed BaseQuerySet#sum/average for fields w/ explicit db_field #1417 | ||||
| - Fixed QuerySet#sum/average for fields w/ explicit db_field #1417 | ||||
| - Fixed filtering by embedded_doc=None #1422 | ||||
| - Added support for cursor.comment #1420 | ||||
| - Fixed doc.get_<field>_display #1419 | ||||
| - Fixed __repr__ method of the StrictDict #1424 | ||||
| - Added a deprecation warning for Python 2.6 | ||||
|  | ||||
| Changes in 0.10.7 | ||||
| ================= | ||||
|   | ||||
| @@ -438,7 +438,7 @@ class StrictDict(object): | ||||
|                 __slots__ = allowed_keys_tuple | ||||
|  | ||||
|                 def __repr__(self): | ||||
|                     return "{%s}" % ', '.join('"{0!s}": {0!r}'.format(k) for k in self.iterkeys()) | ||||
|                     return "{%s}" % ', '.join('"{0!s}": {1!r}'.format(k, v) for k, v in self.items()) | ||||
|  | ||||
|             cls._classes[allowed_keys] = SpecificStrictDict | ||||
|         return cls._classes[allowed_keys] | ||||
|   | ||||
| @@ -121,7 +121,7 @@ class BaseDocument(object): | ||||
|                 else: | ||||
|                     self._data[key] = value | ||||
|  | ||||
|         # Set any get_fieldname_display methods | ||||
|         # Set any get_<field>_display methods | ||||
|         self.__set_field_display() | ||||
|  | ||||
|         if self._dynamic: | ||||
| @@ -1005,19 +1005,18 @@ class BaseDocument(object): | ||||
|         return '.'.join(parts) | ||||
|  | ||||
|     def __set_field_display(self): | ||||
|         """Dynamically set the display value for a field with choices""" | ||||
|         for attr_name, field in self._fields.items(): | ||||
|             if field.choices: | ||||
|                 if self._dynamic: | ||||
|                     obj = self | ||||
|                 else: | ||||
|                     obj = type(self) | ||||
|                 setattr(obj, | ||||
|         """For each field that specifies choices, create a | ||||
|         get_<field>_display method. | ||||
|         """ | ||||
|         fields_with_choices = [(n, f) for n, f in self._fields.items() | ||||
|                                if f.choices] | ||||
|         for attr_name, field in fields_with_choices: | ||||
|             setattr(self, | ||||
|                     'get_%s_display' % attr_name, | ||||
|                     partial(self.__get_field_display, field=field)) | ||||
|  | ||||
|     def __get_field_display(self, field): | ||||
|         """Returns the display value for a choice field""" | ||||
|         """Return the display value for a choice field""" | ||||
|         value = getattr(self, field.name) | ||||
|         if field.choices and isinstance(field.choices[0], (list, tuple)): | ||||
|             return dict(field.choices).get(value, value) | ||||
|   | ||||
| @@ -25,7 +25,8 @@ _dbs = {} | ||||
|  | ||||
| def register_connection(alias, name=None, host=None, port=None, | ||||
|                         read_preference=READ_PREFERENCE, | ||||
|                         username=None, password=None, authentication_source=None, | ||||
|                         username=None, password=None, | ||||
|                         authentication_source=None, | ||||
|                         authentication_mechanism=None, | ||||
|                         **kwargs): | ||||
|     """Add a connection. | ||||
| @@ -70,20 +71,26 @@ def register_connection(alias, name=None, host=None, port=None, | ||||
|  | ||||
|     resolved_hosts = [] | ||||
|     for entity in conn_host: | ||||
|         # Handle uri style connections | ||||
|  | ||||
|         # Handle Mongomock | ||||
|         if entity.startswith('mongomock://'): | ||||
|             conn_settings['is_mock'] = True | ||||
|             # `mongomock://` is not a valid url prefix and must be replaced by `mongodb://` | ||||
|             resolved_hosts.append(entity.replace('mongomock://', 'mongodb://', 1)) | ||||
|  | ||||
|         # Handle URI style connections, only updating connection params which | ||||
|         # were explicitly specified in the URI. | ||||
|         elif '://' in entity: | ||||
|             uri_dict = uri_parser.parse_uri(entity) | ||||
|             resolved_hosts.append(entity) | ||||
|             conn_settings.update({ | ||||
|                 'name': uri_dict.get('database') or name, | ||||
|                 'username': uri_dict.get('username'), | ||||
|                 'password': uri_dict.get('password'), | ||||
|                 'read_preference': read_preference, | ||||
|             }) | ||||
|  | ||||
|             if uri_dict.get('database'): | ||||
|                 conn_settings['name'] = uri_dict.get('database') | ||||
|  | ||||
|             for param in ('read_preference', 'username', 'password'): | ||||
|                 if uri_dict.get(param): | ||||
|                     conn_settings[param] = uri_dict[param] | ||||
|  | ||||
|             uri_options = uri_dict['options'] | ||||
|             if 'replicaset' in uri_options: | ||||
|                 conn_settings['replicaSet'] = True | ||||
|   | ||||
| @@ -577,7 +577,7 @@ class EmbeddedDocumentField(BaseField): | ||||
|         return self.document_type._fields.get(member_name) | ||||
|  | ||||
|     def prepare_query_value(self, op, value): | ||||
|         if not isinstance(value, self.document_type): | ||||
|         if value is not None and not isinstance(value, self.document_type): | ||||
|             value = self.document_type._from_son(value) | ||||
|         super(EmbeddedDocumentField, self).prepare_query_value(op, value) | ||||
|         return self.to_mongo(value) | ||||
| @@ -1249,7 +1249,7 @@ class GenericReferenceField(BaseField): | ||||
|         if document is None: | ||||
|             return None | ||||
|  | ||||
|         if isinstance(document, (dict, SON)): | ||||
|         if isinstance(document, (dict, SON, ObjectId, DBRef)): | ||||
|             return document | ||||
|  | ||||
|         id_field_name = document.__class__._meta['id_field'] | ||||
|   | ||||
| @@ -1,9 +1,22 @@ | ||||
| """Helper functions and types to aid with Python 2.5 - 3 support.""" | ||||
| """Helper functions and types to aid with Python 2.6 - 3 support.""" | ||||
|  | ||||
| import sys | ||||
| import warnings | ||||
|  | ||||
| import pymongo | ||||
|  | ||||
|  | ||||
| # Show a deprecation warning for people using Python v2.6 | ||||
| # TODO remove in mongoengine v0.11.0 | ||||
| if sys.version_info[0] == 2 and sys.version_info[1] == 6: | ||||
|     warnings.warn( | ||||
|         'Python v2.6 support is deprecated and is going to be dropped ' | ||||
|         'entirely in the upcoming v0.11.0 release. Update your Python ' | ||||
|         'version if you want to have access to the latest features and ' | ||||
|         'bug fixes in MongoEngine.', | ||||
|         DeprecationWarning | ||||
|     ) | ||||
|  | ||||
| if pymongo.version_tuple[0] < 3: | ||||
|     IS_PYMONGO_3 = False | ||||
| else: | ||||
|   | ||||
| @@ -82,6 +82,7 @@ class BaseQuerySet(object): | ||||
|         self._limit = None | ||||
|         self._skip = None | ||||
|         self._hint = -1  # Using -1 as None is a valid value for hint | ||||
|         self._batch_size = None | ||||
|         self.only_fields = [] | ||||
|         self._max_time_ms = None | ||||
|  | ||||
| @@ -275,6 +276,8 @@ class BaseQuerySet(object): | ||||
|         except StopIteration: | ||||
|             return result | ||||
|  | ||||
|         # If we were able to retrieve the 2nd doc, rewind the cursor and | ||||
|         # raise the MultipleObjectsReturned exception. | ||||
|         queryset.rewind() | ||||
|         message = u'%d items returned, instead of 1' % queryset.count() | ||||
|         raise queryset._document.MultipleObjectsReturned(message) | ||||
| @@ -444,7 +447,7 @@ class BaseQuerySet(object): | ||||
|                 if doc._collection == document_cls._collection: | ||||
|                     for ref in queryset: | ||||
|                         cascade_refs.add(ref.id) | ||||
|                 ref_q = document_cls.objects(**{field_name + '__in': self, 'id__nin': cascade_refs}) | ||||
|                 ref_q = document_cls.objects(**{field_name + '__in': self, 'pk__nin': cascade_refs}) | ||||
|                 ref_q_count = ref_q.count() | ||||
|                 if ref_q_count > 0: | ||||
|                     ref_q.delete(write_concern=write_concern, cascade_refs=cascade_refs) | ||||
| @@ -781,6 +784,19 @@ class BaseQuerySet(object): | ||||
|         queryset._hint = index | ||||
|         return queryset | ||||
|  | ||||
|     def batch_size(self, size): | ||||
|         """Limit the number of documents returned in a single batch (each | ||||
|         batch requires a round trip to the server). | ||||
|  | ||||
|         See http://api.mongodb.com/python/current/api/pymongo/cursor.html#pymongo.cursor.Cursor.batch_size | ||||
|         for details. | ||||
|  | ||||
|         :param size: desired size of each batch. | ||||
|         """ | ||||
|         queryset = self.clone() | ||||
|         queryset._batch_size = size | ||||
|         return queryset | ||||
|  | ||||
|     def distinct(self, field): | ||||
|         """Return a list of distinct values for a given field. | ||||
|  | ||||
| @@ -933,6 +949,14 @@ class BaseQuerySet(object): | ||||
|         queryset._ordering = queryset._get_order_by(keys) | ||||
|         return queryset | ||||
|  | ||||
|     def comment(self, text): | ||||
|         """Add a comment to the query. | ||||
|  | ||||
|         See https://docs.mongodb.com/manual/reference/method/cursor.comment/#cursor.comment | ||||
|         for details. | ||||
|         """ | ||||
|         return self._chainable_method("comment", text) | ||||
|  | ||||
|     def explain(self, format=False): | ||||
|         """Return an explain plan record for the | ||||
|         :class:`~mongoengine.queryset.QuerySet`\ 's cursor. | ||||
| @@ -1459,6 +1483,9 @@ class BaseQuerySet(object): | ||||
|             if self._hint != -1: | ||||
|                 self._cursor_obj.hint(self._hint) | ||||
|  | ||||
|             if self._batch_size is not None: | ||||
|                 self._cursor_obj.batch_size(self._batch_size) | ||||
|  | ||||
|         return self._cursor_obj | ||||
|  | ||||
|     def __deepcopy__(self, memo): | ||||
|   | ||||
| @@ -30,6 +30,7 @@ class QuerySet(BaseQuerySet): | ||||
|         batch. Otherwise iterate the result_cache. | ||||
|         """ | ||||
|         self._iter = True | ||||
|  | ||||
|         if self._has_more: | ||||
|             return self._iter_results() | ||||
|  | ||||
| @@ -42,10 +43,12 @@ class QuerySet(BaseQuerySet): | ||||
|         """ | ||||
|         if self._len is not None: | ||||
|             return self._len | ||||
|  | ||||
|         # Populate the result cache with *all* of the docs in the cursor | ||||
|         if self._has_more: | ||||
|             # populate the cache | ||||
|             list(self._iter_results()) | ||||
|  | ||||
|         # Cache the length of the complete result cache and return it | ||||
|         self._len = len(self._result_cache) | ||||
|         return self._len | ||||
|  | ||||
| @@ -64,18 +67,33 @@ class QuerySet(BaseQuerySet): | ||||
|     def _iter_results(self): | ||||
|         """A generator for iterating over the result cache. | ||||
|  | ||||
|         Also populates the cache if there are more possible results to yield. | ||||
|         Raises StopIteration when there are no more results""" | ||||
|         Also populates the cache if there are more possible results to | ||||
|         yield. Raises StopIteration when there are no more results. | ||||
|         """ | ||||
|         if self._result_cache is None: | ||||
|             self._result_cache = [] | ||||
|  | ||||
|         pos = 0 | ||||
|         while True: | ||||
|             upper = len(self._result_cache) | ||||
|             while pos < upper: | ||||
|  | ||||
|             # For all positions lower than the length of the current result | ||||
|             # cache, serve the docs straight from the cache w/o hitting the | ||||
|             # database. | ||||
|             # XXX it's VERY important to compute the len within the `while` | ||||
|             # condition because the result cache might expand mid-iteration | ||||
|             # (e.g. if we call len(qs) inside a loop that iterates over the | ||||
|             # queryset). Fortunately len(list) is O(1) in Python, so this | ||||
|             # doesn't cause performance issues. | ||||
|             while pos < len(self._result_cache): | ||||
|                 yield self._result_cache[pos] | ||||
|                 pos += 1 | ||||
|  | ||||
|             # Raise StopIteration if we already established there were no more | ||||
|             # docs in the db cursor. | ||||
|             if not self._has_more: | ||||
|                 raise StopIteration | ||||
|  | ||||
|             # Otherwise, populate more of the cache and repeat. | ||||
|             if len(self._result_cache) <= pos: | ||||
|                 self._populate_cache() | ||||
|  | ||||
| @@ -86,11 +104,21 @@ class QuerySet(BaseQuerySet): | ||||
|         """ | ||||
|         if self._result_cache is None: | ||||
|             self._result_cache = [] | ||||
|         if self._has_more: | ||||
|  | ||||
|         # Skip populating the cache if we already established there are no | ||||
|         # more docs to pull from the database. | ||||
|         if not self._has_more: | ||||
|             return | ||||
|  | ||||
|         # Pull in ITER_CHUNK_SIZE docs from the database and store them in | ||||
|         # the result cache. | ||||
|         try: | ||||
|             for i in xrange(ITER_CHUNK_SIZE): | ||||
|                 self._result_cache.append(self.next()) | ||||
|         except StopIteration: | ||||
|             # Getting this exception means there are no more docs in the | ||||
|             # db cursor. Set _has_more to False so that we can use that | ||||
|             # information in other places. | ||||
|             self._has_more = False | ||||
|  | ||||
|     def count(self, with_limit_and_skip=False): | ||||
|   | ||||
| @@ -1,6 +1,7 @@ | ||||
| from collections import defaultdict | ||||
|  | ||||
| from bson import SON | ||||
| from bson import ObjectId, SON | ||||
| from bson.dbref import DBRef | ||||
| import pymongo | ||||
|  | ||||
| from mongoengine.base.fields import UPDATE_OPERATORS | ||||
| @@ -26,6 +27,7 @@ MATCH_OPERATORS = (COMPARISON_OPERATORS + GEO_OPERATORS + | ||||
|                    STRING_OPERATORS + CUSTOM_OPERATORS) | ||||
|  | ||||
|  | ||||
| # TODO make this less complex | ||||
| def query(_doc_cls=None, **kwargs): | ||||
|     """Transform a query from Django-style format to Mongo format. | ||||
|     """ | ||||
| @@ -62,6 +64,7 @@ def query(_doc_cls=None, **kwargs): | ||||
|             parts = [] | ||||
|  | ||||
|             CachedReferenceField = _import_class('CachedReferenceField') | ||||
|             GenericReferenceField = _import_class('GenericReferenceField') | ||||
|  | ||||
|             cleaned_fields = [] | ||||
|             for field in fields: | ||||
| @@ -101,6 +104,16 @@ def query(_doc_cls=None, **kwargs): | ||||
|                 # 'in', 'nin' and 'all' require a list of values | ||||
|                 value = [field.prepare_query_value(op, v) for v in value] | ||||
|  | ||||
|             # If we're querying a GenericReferenceField, we need to alter the | ||||
|             # key depending on the value: | ||||
|             # * If the value is a DBRef, the key should be "field_name._ref". | ||||
|             # * If the value is an ObjectId, the key should be "field_name._ref.$id". | ||||
|             if isinstance(field, GenericReferenceField): | ||||
|                 if isinstance(value, DBRef): | ||||
|                     parts[-1] += '._ref' | ||||
|                 elif isinstance(value, ObjectId): | ||||
|                     parts[-1] += '._ref.$id' | ||||
|  | ||||
|         # if op and op not in COMPARISON_OPERATORS: | ||||
|         if op: | ||||
|             if op in GEO_OPERATORS: | ||||
| @@ -128,11 +141,13 @@ def query(_doc_cls=None, **kwargs): | ||||
|  | ||||
|         for i, part in indices: | ||||
|             parts.insert(i, part) | ||||
|  | ||||
|         key = '.'.join(parts) | ||||
|  | ||||
|         if op is None or key not in mongo_query: | ||||
|             mongo_query[key] = value | ||||
|         elif key in mongo_query: | ||||
|             if key in mongo_query and isinstance(mongo_query[key], dict): | ||||
|             if isinstance(mongo_query[key], dict): | ||||
|                 mongo_query[key].update(value) | ||||
|                 # $max/minDistance needs to come last - convert to SON | ||||
|                 value_dict = mongo_query[key] | ||||
|   | ||||
| @@ -9,5 +9,5 @@ tests = tests | ||||
| [flake8] | ||||
| ignore=E501,F401,F403,F405,I201 | ||||
| exclude=build,dist,docs,venv,.tox,.eggs,tests | ||||
| max-complexity=42 | ||||
| max-complexity=45 | ||||
| application-import-names=mongoengine,tests | ||||
|   | ||||
| @@ -2,10 +2,8 @@ | ||||
| import unittest | ||||
| import sys | ||||
|  | ||||
| sys.path[0:0] = [""] | ||||
|  | ||||
| import pymongo | ||||
| from random import randint | ||||
|  | ||||
| from nose.plugins.skip import SkipTest | ||||
| from datetime import datetime | ||||
| @@ -17,11 +15,9 @@ __all__ = ("IndexesTest", ) | ||||
|  | ||||
|  | ||||
| class IndexesTest(unittest.TestCase): | ||||
|     _MAX_RAND = 10 ** 10 | ||||
|  | ||||
|     def setUp(self): | ||||
|         self.db_name = 'mongoenginetest_IndexesTest_' + str(randint(0, self._MAX_RAND)) | ||||
|         self.connection = connect(db=self.db_name) | ||||
|         self.connection = connect(db='mongoenginetest') | ||||
|         self.db = get_db() | ||||
|  | ||||
|         class Person(Document): | ||||
|   | ||||
| @@ -2810,6 +2810,38 @@ class FieldTest(unittest.TestCase): | ||||
|         Post.drop_collection() | ||||
|         User.drop_collection() | ||||
|  | ||||
|     def test_generic_reference_filter_by_dbref(self): | ||||
|         """Ensure we can search for a specific generic reference by | ||||
|         providing its ObjectId. | ||||
|         """ | ||||
|         class Doc(Document): | ||||
|             ref = GenericReferenceField() | ||||
|  | ||||
|         Doc.drop_collection() | ||||
|  | ||||
|         doc1 = Doc.objects.create() | ||||
|         doc2 = Doc.objects.create(ref=doc1) | ||||
|  | ||||
|         doc = Doc.objects.get(ref=DBRef('doc', doc1.pk)) | ||||
|         self.assertEqual(doc, doc2) | ||||
|  | ||||
|     def test_generic_reference_filter_by_objectid(self): | ||||
|         """Ensure we can search for a specific generic reference by | ||||
|         providing its DBRef. | ||||
|         """ | ||||
|         class Doc(Document): | ||||
|             ref = GenericReferenceField() | ||||
|  | ||||
|         Doc.drop_collection() | ||||
|  | ||||
|         doc1 = Doc.objects.create() | ||||
|         doc2 = Doc.objects.create(ref=doc1) | ||||
|  | ||||
|         self.assertTrue(isinstance(doc1.pk, ObjectId)) | ||||
|  | ||||
|         doc = Doc.objects.get(ref=doc1.pk) | ||||
|         self.assertEqual(doc, doc2) | ||||
|  | ||||
|     def test_binary_fields(self): | ||||
|         """Ensure that binary fields can be stored and retrieved. | ||||
|         """ | ||||
| @@ -3001,28 +3033,32 @@ class FieldTest(unittest.TestCase): | ||||
|                 ('S', 'Small'), ('M', 'Medium'), ('L', 'Large'), | ||||
|                 ('XL', 'Extra Large'), ('XXL', 'Extra Extra Large'))) | ||||
|             style = StringField(max_length=3, choices=( | ||||
|                 ('S', 'Small'), ('B', 'Baggy'), ('W', 'wide')), default='S') | ||||
|                 ('S', 'Small'), ('B', 'Baggy'), ('W', 'Wide')), default='W') | ||||
|  | ||||
|         Shirt.drop_collection() | ||||
|  | ||||
|         shirt = Shirt() | ||||
|         shirt1 = Shirt() | ||||
|         shirt2 = Shirt() | ||||
|  | ||||
|         self.assertEqual(shirt.get_size_display(), None) | ||||
|         self.assertEqual(shirt.get_style_display(), 'Small') | ||||
|         # Make sure get_<field>_display returns the default value (or None) | ||||
|         self.assertEqual(shirt1.get_size_display(), None) | ||||
|         self.assertEqual(shirt1.get_style_display(), 'Wide') | ||||
|  | ||||
|         shirt.size = "XXL" | ||||
|         shirt.style = "B" | ||||
|         self.assertEqual(shirt.get_size_display(), 'Extra Extra Large') | ||||
|         self.assertEqual(shirt.get_style_display(), 'Baggy') | ||||
|         shirt1.size = 'XXL' | ||||
|         shirt1.style = 'B' | ||||
|         shirt2.size = 'M' | ||||
|         shirt2.style = 'S' | ||||
|         self.assertEqual(shirt1.get_size_display(), 'Extra Extra Large') | ||||
|         self.assertEqual(shirt1.get_style_display(), 'Baggy') | ||||
|         self.assertEqual(shirt2.get_size_display(), 'Medium') | ||||
|         self.assertEqual(shirt2.get_style_display(), 'Small') | ||||
|  | ||||
|         # Set as Z - an invalid choice | ||||
|         shirt.size = "Z" | ||||
|         shirt.style = "Z" | ||||
|         self.assertEqual(shirt.get_size_display(), 'Z') | ||||
|         self.assertEqual(shirt.get_style_display(), 'Z') | ||||
|         self.assertRaises(ValidationError, shirt.validate) | ||||
|  | ||||
|         Shirt.drop_collection() | ||||
|         shirt1.size = 'Z' | ||||
|         shirt1.style = 'Z' | ||||
|         self.assertEqual(shirt1.get_size_display(), 'Z') | ||||
|         self.assertEqual(shirt1.get_style_display(), 'Z') | ||||
|         self.assertRaises(ValidationError, shirt1.validate) | ||||
|  | ||||
|     def test_simple_choices_validation(self): | ||||
|         """Ensure that value is in a container of allowed values. | ||||
|   | ||||
| @@ -337,9 +337,36 @@ class QuerySetTest(unittest.TestCase): | ||||
|         query = query.filter(boolfield=True) | ||||
|         self.assertEqual(query.count(), 1) | ||||
|  | ||||
|     def test_batch_size(self): | ||||
|         """Ensure that batch_size works.""" | ||||
|         class A(Document): | ||||
|             s = StringField() | ||||
|  | ||||
|         A.drop_collection() | ||||
|  | ||||
|         for i in range(100): | ||||
|             A.objects.create(s=str(i)) | ||||
|  | ||||
|         # test iterating over the result set | ||||
|         cnt = 0 | ||||
|         for a in A.objects.batch_size(10): | ||||
|             cnt += 1 | ||||
|         self.assertEqual(cnt, 100) | ||||
|  | ||||
|         # test chaining | ||||
|         qs = A.objects.all() | ||||
|         qs = qs.limit(10).batch_size(20).skip(91) | ||||
|         cnt = 0 | ||||
|         for a in qs: | ||||
|             cnt += 1 | ||||
|         self.assertEqual(cnt, 9) | ||||
|  | ||||
|         # test invalid batch size | ||||
|         qs = A.objects.batch_size(-1) | ||||
|         self.assertRaises(ValueError, lambda: list(qs)) | ||||
|  | ||||
|     def test_update_write_concern(self): | ||||
|         """Test that passing write_concern works""" | ||||
|  | ||||
|         self.Person.drop_collection() | ||||
|  | ||||
|         write_concern = {"fsync": True} | ||||
| @@ -1239,7 +1266,8 @@ class QuerySetTest(unittest.TestCase): | ||||
|             self.assertFalse('$orderby' in q.get_ops()[0]['query']) | ||||
|  | ||||
|     def test_find_embedded(self): | ||||
|         """Ensure that an embedded document is properly returned from a query. | ||||
|         """Ensure that an embedded document is properly returned from | ||||
|         a query. | ||||
|         """ | ||||
|         class User(EmbeddedDocument): | ||||
|             name = StringField() | ||||
| @@ -1250,16 +1278,31 @@ class QuerySetTest(unittest.TestCase): | ||||
|  | ||||
|         BlogPost.drop_collection() | ||||
|  | ||||
|         post = BlogPost(content='Had a good coffee today...') | ||||
|         post.author = User(name='Test User') | ||||
|         post.save() | ||||
|         BlogPost.objects.create( | ||||
|             author=User(name='Test User'), | ||||
|             content='Had a good coffee today...' | ||||
|         ) | ||||
|  | ||||
|         result = BlogPost.objects.first() | ||||
|         self.assertTrue(isinstance(result.author, User)) | ||||
|         self.assertEqual(result.author.name, 'Test User') | ||||
|  | ||||
|     def test_find_empty_embedded(self): | ||||
|         """Ensure that you can save and find an empty embedded document.""" | ||||
|         class User(EmbeddedDocument): | ||||
|             name = StringField() | ||||
|  | ||||
|         class BlogPost(Document): | ||||
|             content = StringField() | ||||
|             author = EmbeddedDocumentField(User) | ||||
|  | ||||
|         BlogPost.drop_collection() | ||||
|  | ||||
|         BlogPost.objects.create(content='Anonymous post...') | ||||
|  | ||||
|         result = BlogPost.objects.get(author=None) | ||||
|         self.assertEqual(result.author, None) | ||||
|  | ||||
|     def test_find_dict_item(self): | ||||
|         """Ensure that DictField items may be found. | ||||
|         """ | ||||
| @@ -2199,6 +2242,21 @@ class QuerySetTest(unittest.TestCase): | ||||
|             a.author.name for a in Author.objects.order_by('-author__age')] | ||||
|         self.assertEqual(names, ['User A', 'User B', 'User C']) | ||||
|  | ||||
|     def test_comment(self): | ||||
|         """Make sure adding a comment to the query works.""" | ||||
|         class User(Document): | ||||
|             age = IntField() | ||||
|  | ||||
|         with db_ops_tracker() as q: | ||||
|             adult = (User.objects.filter(age__gte=18) | ||||
|                 .comment('looking for an adult') | ||||
|                 .first()) | ||||
|             ops = q.get_ops() | ||||
|             self.assertEqual(len(ops), 1) | ||||
|             op = ops[0] | ||||
|             self.assertEqual(op['query']['$query'], {'age': {'$gte': 18}}) | ||||
|             self.assertEqual(op['query']['$comment'], 'looking for an adult') | ||||
|  | ||||
|     def test_map_reduce(self): | ||||
|         """Ensure map/reduce is both mapping and reducing. | ||||
|         """ | ||||
| @@ -4860,6 +4918,56 @@ class QuerySetTest(unittest.TestCase): | ||||
|  | ||||
|         self.assertEqual(1, Doc.objects(item__type__="axe").count()) | ||||
|  | ||||
|     def test_len_during_iteration(self): | ||||
|         """Tests that calling len on a queyset during iteration doesn't | ||||
|         stop paging. | ||||
|         """ | ||||
|         class Data(Document): | ||||
|             pass | ||||
|  | ||||
|         for i in xrange(300): | ||||
|             Data().save() | ||||
|  | ||||
|         records = Data.objects.limit(250) | ||||
|  | ||||
|         # This should pull all 250 docs from mongo and populate the result | ||||
|         # cache | ||||
|         len(records) | ||||
|  | ||||
|         # Assert that iterating over documents in the qs touches every | ||||
|         # document even if we call len(qs) midway through the iteration. | ||||
|         for i, r in enumerate(records): | ||||
|             if i == 58: | ||||
|                 len(records) | ||||
|         self.assertEqual(i, 249) | ||||
|  | ||||
|         # Assert the same behavior is true even if we didn't pre-populate the | ||||
|         # result cache. | ||||
|         records = Data.objects.limit(250) | ||||
|         for i, r in enumerate(records): | ||||
|             if i == 58: | ||||
|                 len(records) | ||||
|         self.assertEqual(i, 249) | ||||
|  | ||||
|     def test_iteration_within_iteration(self): | ||||
|         """You should be able to reliably iterate over all the documents | ||||
|         in a given queryset even if there are multiple iterations of it | ||||
|         happening at the same time. | ||||
|         """ | ||||
|         class Data(Document): | ||||
|             pass | ||||
|  | ||||
|         for i in xrange(300): | ||||
|             Data().save() | ||||
|  | ||||
|         qs = Data.objects.limit(250) | ||||
|         for i, doc in enumerate(qs): | ||||
|             for j, doc2 in enumerate(qs): | ||||
|                 pass | ||||
|  | ||||
|         self.assertEqual(i, 249) | ||||
|         self.assertEqual(j, 249) | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|     unittest.main() | ||||
|   | ||||
| @@ -174,19 +174,9 @@ class ConnectionTest(unittest.TestCase): | ||||
|         c.mongoenginetest.system.users.remove({}) | ||||
|  | ||||
|     def test_connect_uri_without_db(self): | ||||
|         """Ensure connect() method works properly with uri's without database_name | ||||
|         """Ensure connect() method works properly if the URI doesn't | ||||
|         include a database name. | ||||
|         """ | ||||
|         c = connect(db='mongoenginetest', alias='admin') | ||||
|         c.admin.system.users.remove({}) | ||||
|         c.mongoenginetest.system.users.remove({}) | ||||
|  | ||||
|         c.admin.add_user("admin", "password") | ||||
|         c.admin.authenticate("admin", "password") | ||||
|         c.mongoenginetest.add_user("username", "password") | ||||
|  | ||||
|         if not IS_PYMONGO_3: | ||||
|             self.assertRaises(ConnectionError, connect, "testdb_uri_bad", host='mongodb://test:password@localhost') | ||||
|  | ||||
|         connect("mongoenginetest", host='mongodb://localhost/') | ||||
|  | ||||
|         conn = get_connection() | ||||
| @@ -196,8 +186,31 @@ class ConnectionTest(unittest.TestCase): | ||||
|         self.assertTrue(isinstance(db, pymongo.database.Database)) | ||||
|         self.assertEqual(db.name, 'mongoenginetest') | ||||
|  | ||||
|         c.admin.system.users.remove({}) | ||||
|         c.mongoenginetest.system.users.remove({}) | ||||
|     def test_connect_uri_default_db(self): | ||||
|         """Ensure connect() defaults to the right database name if | ||||
|         the URI and the database_name don't explicitly specify it. | ||||
|         """ | ||||
|         connect(host='mongodb://localhost/') | ||||
|  | ||||
|         conn = get_connection() | ||||
|         self.assertTrue(isinstance(conn, pymongo.mongo_client.MongoClient)) | ||||
|  | ||||
|         db = get_db() | ||||
|         self.assertTrue(isinstance(db, pymongo.database.Database)) | ||||
|         self.assertEqual(db.name, 'test') | ||||
|  | ||||
|     def test_uri_without_credentials_doesnt_override_conn_settings(self): | ||||
|         """Ensure connect() uses the username & password params if the URI | ||||
|         doesn't explicitly specify them. | ||||
|         """ | ||||
|         c = connect(host='mongodb://localhost/mongoenginetest', | ||||
|                     username='user', | ||||
|                     password='pass') | ||||
|  | ||||
|         # OperationFailure means that mongoengine attempted authentication | ||||
|         # w/ the provided username/password and failed - that's the desired | ||||
|         # behavior. If the MongoDB URI would override the credentials | ||||
|         self.assertRaises(OperationFailure, get_db) | ||||
|  | ||||
|     def test_connect_uri_with_authsource(self): | ||||
|         """Ensure that the connect() method works well with | ||||
|   | ||||
| @@ -1,4 +1,5 @@ | ||||
| import unittest | ||||
|  | ||||
| from mongoengine.base.datastructures import StrictDict, SemiStrictDict | ||||
|  | ||||
|  | ||||
| @@ -13,6 +14,14 @@ class TestStrictDict(unittest.TestCase): | ||||
|         d = self.dtype(a=1, b=1, c=1) | ||||
|         self.assertEqual((d.a, d.b, d.c), (1, 1, 1)) | ||||
|  | ||||
|     def test_repr(self): | ||||
|         d = self.dtype(a=1, b=2, c=3) | ||||
|         self.assertEqual(repr(d), '{"a": 1, "b": 2, "c": 3}') | ||||
|  | ||||
|         # make sure quotes are escaped properly | ||||
|         d = self.dtype(a='"', b="'", c="") | ||||
|         self.assertEqual(repr(d), '{"a": \'"\', "b": "\'", "c": \'\'}') | ||||
|  | ||||
|     def test_init_fails_on_nonexisting_attrs(self): | ||||
|         self.assertRaises(AttributeError, lambda: self.dtype(a=1, b=2, d=3)) | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user