From 3ca2e953fb5ff6f5f5373cae9dce6b97609d2d3c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20W=C3=B3jcik?= Date: Thu, 9 Feb 2017 12:02:46 -0800 Subject: [PATCH] Fix limit/skip/hint/batch_size chaining (#1476) --- mongoengine/queryset/base.py | 183 +++++++++++++++++++++---------- mongoengine/queryset/queryset.py | 8 +- tests/queryset/queryset.py | 114 +++++++++++++++---- 3 files changed, 220 insertions(+), 85 deletions(-) diff --git a/mongoengine/queryset/base.py b/mongoengine/queryset/base.py index 098f198e..7e485686 100644 --- a/mongoengine/queryset/base.py +++ b/mongoengine/queryset/base.py @@ -86,6 +86,7 @@ class BaseQuerySet(object): self._batch_size = None self.only_fields = [] self._max_time_ms = None + self._comment = None def __call__(self, q_obj=None, class_check=True, read_preference=None, **query): @@ -706,39 +707,36 @@ class BaseQuerySet(object): with switch_db(self._document, alias) as cls: collection = cls._get_collection() - return self.clone_into(self.__class__(self._document, collection)) + return self._clone_into(self.__class__(self._document, collection)) def clone(self): - """Creates a copy of the current - :class:`~mongoengine.queryset.QuerySet` + """Create a copy of the current queryset.""" + return self._clone_into(self.__class__(self._document, self._collection_obj)) - .. versionadded:: 0.5 + def _clone_into(self, new_qs): + """Copy all of the relevant properties of this queryset to + a new queryset (which has to be an instance of + :class:`~mongoengine.queryset.base.BaseQuerySet`). 
""" - return self.clone_into(self.__class__(self._document, self._collection_obj)) - - def clone_into(self, cls): - """Creates a copy of the current - :class:`~mongoengine.queryset.base.BaseQuerySet` into another child class - """ - if not isinstance(cls, BaseQuerySet): + if not isinstance(new_qs, BaseQuerySet): raise OperationError( - '%s is not a subclass of BaseQuerySet' % cls.__name__) + '%s is not a subclass of BaseQuerySet' % new_qs.__name__) copy_props = ('_mongo_query', '_initial_query', '_none', '_query_obj', '_where_clause', '_loaded_fields', '_ordering', '_snapshot', '_timeout', '_class_check', '_slave_okay', '_read_preference', '_iter', '_scalar', '_as_pymongo', '_as_pymongo_coerce', '_limit', '_skip', '_hint', '_auto_dereference', - '_search_text', 'only_fields', '_max_time_ms') + '_search_text', 'only_fields', '_max_time_ms', '_comment') for prop in copy_props: val = getattr(self, prop) - setattr(cls, prop, copy.copy(val)) + setattr(new_qs, prop, copy.copy(val)) if self._cursor_obj: - cls._cursor_obj = self._cursor_obj.clone() + new_qs._cursor_obj = self._cursor_obj.clone() - return cls + return new_qs def select_related(self, max_depth=1): """Handles dereferencing of :class:`~bson.dbref.DBRef` objects or @@ -760,7 +758,11 @@ class BaseQuerySet(object): """ queryset = self.clone() queryset._limit = n if n != 0 else 1 - # Return self to allow chaining + + # If a cursor object has already been created, apply the limit to it. + if queryset._cursor_obj: + queryset._cursor_obj.limit(queryset._limit) + return queryset def skip(self, n): @@ -771,6 +773,11 @@ class BaseQuerySet(object): """ queryset = self.clone() queryset._skip = n + + # If a cursor object has already been created, apply the skip to it. 
+ if queryset._cursor_obj: + queryset._cursor_obj.skip(queryset._skip) + return queryset def hint(self, index=None): @@ -788,6 +795,11 @@ class BaseQuerySet(object): """ queryset = self.clone() queryset._hint = index + + # If a cursor object has already been created, apply the hint to it. + if queryset._cursor_obj: + queryset._cursor_obj.hint(queryset._hint) + return queryset def batch_size(self, size): @@ -801,6 +813,11 @@ class BaseQuerySet(object): """ queryset = self.clone() queryset._batch_size = size + + # If a cursor object has already been created, apply the batch size to it. + if queryset._cursor_obj: + queryset._cursor_obj.batch_size(queryset._batch_size) + return queryset def distinct(self, field): @@ -972,13 +989,31 @@ class BaseQuerySet(object): def order_by(self, *keys): """Order the :class:`~mongoengine.queryset.QuerySet` by the keys. The order may be specified by prepending each of the keys by a + or a -. - Ascending order is assumed. + Ascending order is assumed. If no keys are passed, existing ordering + is cleared instead. :param keys: fields to order the query results by; keys may be prefixed with **+** or **-** to determine the ordering direction """ queryset = self.clone() - queryset._ordering = queryset._get_order_by(keys) + + old_ordering = queryset._ordering + new_ordering = queryset._get_order_by(keys) + + if queryset._cursor_obj: + + # If a cursor object has already been created, apply the sort to it + if new_ordering: + queryset._cursor_obj.sort(new_ordering) + + # If we're trying to clear a previous explicit ordering, we need + # to clear the cursor entirely (because PyMongo doesn't allow + # clearing an existing sort on a cursor). 
+ elif old_ordering: + queryset._cursor_obj = None + + queryset._ordering = new_ordering + return queryset def comment(self, text): @@ -1424,10 +1459,13 @@ class BaseQuerySet(object): raise StopIteration raw_doc = self._cursor.next() + if self._as_pymongo: return self._get_as_pymongo(raw_doc) - doc = self._document._from_son(raw_doc, - _auto_dereference=self._auto_dereference, only_fields=self.only_fields) + + doc = self._document._from_son( + raw_doc, _auto_dereference=self._auto_dereference, + only_fields=self.only_fields) if self._scalar: return self._get_scalar(doc) @@ -1437,7 +1475,6 @@ class BaseQuerySet(object): def rewind(self): """Rewind the cursor to its unevaluated state. - .. versionadded:: 0.3 """ self._iter = False @@ -1487,43 +1524,54 @@ class BaseQuerySet(object): @property def _cursor(self): - if self._cursor_obj is None: + """Return a PyMongo cursor object corresponding to this queryset.""" - # In PyMongo 3+, we define the read preference on a collection - # level, not a cursor level. Thus, we need to get a cloned - # collection object using `with_options` first. - if IS_PYMONGO_3 and self._read_preference is not None: - self._cursor_obj = self._collection\ - .with_options(read_preference=self._read_preference)\ - .find(self._query, **self._cursor_args) - else: - self._cursor_obj = self._collection.find(self._query, - **self._cursor_args) - # Apply where clauses to cursor - if self._where_clause: - where_clause = self._sub_js_fields(self._where_clause) - self._cursor_obj.where(where_clause) + # If _cursor_obj already exists, return it immediately. 
+ if self._cursor_obj is not None: + return self._cursor_obj - if self._ordering: - # Apply query ordering - self._cursor_obj.sort(self._ordering) - elif self._ordering is None and self._document._meta['ordering']: - # Otherwise, apply the ordering from the document model, unless - # it's been explicitly cleared via order_by with no arguments - order = self._get_order_by(self._document._meta['ordering']) - self._cursor_obj.sort(order) + # Create a new PyMongo cursor. + # XXX In PyMongo 3+, we define the read preference on a collection + # level, not a cursor level. Thus, we need to get a cloned collection + # object using `with_options` first. + if IS_PYMONGO_3 and self._read_preference is not None: + self._cursor_obj = self._collection\ + .with_options(read_preference=self._read_preference)\ + .find(self._query, **self._cursor_args) + else: + self._cursor_obj = self._collection.find(self._query, + **self._cursor_args) + # Apply "where" clauses to cursor + if self._where_clause: + where_clause = self._sub_js_fields(self._where_clause) + self._cursor_obj.where(where_clause) - if self._limit is not None: - self._cursor_obj.limit(self._limit) + # Apply ordering to the cursor. + # XXX self._ordering can be equal to: + # * None if we didn't explicitly call order_by on this queryset. + # * A list of PyMongo-style sorting tuples. + # * An empty list if we explicitly called order_by() without any + # arguments. This indicates that we want to clear the default + # ordering. 
+ if self._ordering: + # explicit ordering + self._cursor_obj.sort(self._ordering) + elif self._ordering is None and self._document._meta['ordering']: + # default ordering + order = self._get_order_by(self._document._meta['ordering']) + self._cursor_obj.sort(order) - if self._skip is not None: - self._cursor_obj.skip(self._skip) + if self._limit is not None: + self._cursor_obj.limit(self._limit) - if self._hint != -1: - self._cursor_obj.hint(self._hint) + if self._skip is not None: + self._cursor_obj.skip(self._skip) - if self._batch_size is not None: - self._cursor_obj.batch_size(self._batch_size) + if self._hint != -1: + self._cursor_obj.hint(self._hint) + + if self._batch_size is not None: + self._cursor_obj.batch_size(self._batch_size) return self._cursor_obj @@ -1698,7 +1746,13 @@ class BaseQuerySet(object): return ret def _get_order_by(self, keys): - """Creates a list of order by fields""" + """Given a list of MongoEngine-style sort keys, return a list + of sorting tuples that can be applied to a PyMongo cursor. For + example: + + >>> qs._get_order_by(['-last_name', 'first_name']) + [('last_name', -1), ('first_name', 1)] + """ key_list = [] for key in keys: if not key: @@ -1711,17 +1765,19 @@ class BaseQuerySet(object): direction = pymongo.ASCENDING if key[0] == '-': direction = pymongo.DESCENDING + if key[0] in ('-', '+'): key = key[1:] + key = key.replace('__', '.') try: key = self._document._translate_field_name(key) except Exception: + # TODO this exception should be more specific pass + key_list.append((key, direction)) - if self._cursor_obj and key_list: - self._cursor_obj.sort(key_list) return key_list def _get_scalar(self, doc): @@ -1819,10 +1875,21 @@ class BaseQuerySet(object): return code def _chainable_method(self, method_name, val): + """Call a particular chainable method on the PyMongo cursor + with the provided value. 
+ """ queryset = self.clone() - method = getattr(queryset._cursor, method_name) - method(val) + + # Get an existing cursor object or create a new one + cursor = queryset._cursor + + # Find the requested method on the cursor and call it with the + # provided value + getattr(cursor, method_name)(val) + + # Cache the value on the queryset._{method_name} setattr(queryset, '_' + method_name, val) + return queryset # Deprecated diff --git a/mongoengine/queryset/queryset.py b/mongoengine/queryset/queryset.py index 9c1f24e1..b5d2765b 100644 --- a/mongoengine/queryset/queryset.py +++ b/mongoengine/queryset/queryset.py @@ -136,13 +136,15 @@ class QuerySet(BaseQuerySet): return self._len def no_cache(self): - """Convert to a non_caching queryset + """Convert to a non-caching queryset .. versionadded:: 0.8.3 Convert to non caching queryset """ if self._result_cache is not None: raise OperationError('QuerySet already cached') - return self.clone_into(QuerySetNoCache(self._document, self._collection)) + + return self._clone_into(QuerySetNoCache(self._document, + self._collection)) class QuerySetNoCache(BaseQuerySet): @@ -153,7 +155,7 @@ class QuerySetNoCache(BaseQuerySet): .. versionadded:: 0.8.3 Convert to caching queryset """ - return self.clone_into(QuerySet(self._document, self._collection)) + return self._clone_into(QuerySet(self._document, self._collection)) def __repr__(self): """Provides the string representation of the QuerySet diff --git a/tests/queryset/queryset.py b/tests/queryset/queryset.py index 2d5b5b0f..c54fa13d 100644 --- a/tests/queryset/queryset.py +++ b/tests/queryset/queryset.py @@ -106,58 +106,111 @@ class QuerySetTest(unittest.TestCase): list(BlogPost.objects(author2__name="test")) def test_find(self): - """Ensure that a query returns a valid set of results. 
- """ - self.Person(name="User A", age=20).save() - self.Person(name="User B", age=30).save() + """Ensure that a query returns a valid set of results.""" + user_a = self.Person.objects.create(name='User A', age=20) + user_b = self.Person.objects.create(name='User B', age=30) # Find all people in the collection people = self.Person.objects self.assertEqual(people.count(), 2) results = list(people) + self.assertTrue(isinstance(results[0], self.Person)) self.assertTrue(isinstance(results[0].id, (ObjectId, str, unicode))) - self.assertEqual(results[0].name, "User A") + + self.assertEqual(results[0], user_a) + self.assertEqual(results[0].name, 'User A') self.assertEqual(results[0].age, 20) - self.assertEqual(results[1].name, "User B") + + self.assertEqual(results[1], user_b) + self.assertEqual(results[1].name, 'User B') self.assertEqual(results[1].age, 30) - # Use a query to filter the people found to just person1 + # Filter people by age people = self.Person.objects(age=20) self.assertEqual(people.count(), 1) person = people.next() + self.assertEqual(person, user_a) self.assertEqual(person.name, "User A") self.assertEqual(person.age, 20) - # Test limit + def test_limit(self): + """Ensure that QuerySet.limit works as expected.""" + user_a = self.Person.objects.create(name='User A', age=20) + user_b = self.Person.objects.create(name='User B', age=30) + + # Test limit on a new queryset people = list(self.Person.objects.limit(1)) self.assertEqual(len(people), 1) - self.assertEqual(people[0].name, 'User A') + self.assertEqual(people[0], user_a) - # Test skip + # Test limit on an existing queryset + people = self.Person.objects + self.assertEqual(len(people), 2) + people2 = people.limit(1) + self.assertEqual(len(people), 2) + self.assertEqual(len(people2), 1) + self.assertEqual(people2[0], user_a) + + # Test chaining of only after limit + person = self.Person.objects().limit(1).only('name').first() + self.assertEqual(person, user_a) + self.assertEqual(person.name, 'User A') 
+ self.assertEqual(person.age, None) + + def test_skip(self): + """Ensure that QuerySet.skip works as expected.""" + user_a = self.Person.objects.create(name='User A', age=20) + user_b = self.Person.objects.create(name='User B', age=30) + + # Test skip on a new queryset people = list(self.Person.objects.skip(1)) self.assertEqual(len(people), 1) - self.assertEqual(people[0].name, 'User B') + self.assertEqual(people[0], user_b) - person3 = self.Person(name="User C", age=40) - person3.save() + # Test skip on an existing queryset + people = self.Person.objects + self.assertEqual(len(people), 2) + people2 = people.skip(1) + self.assertEqual(len(people), 2) + self.assertEqual(len(people2), 1) + self.assertEqual(people2[0], user_b) + + # Test chaining of only after skip + person = self.Person.objects().skip(1).only('name').first() + self.assertEqual(person, user_b) + self.assertEqual(person.name, 'User B') + self.assertEqual(person.age, None) + + def test_slice(self): + """Ensure slicing a queryset works as expected.""" + user_a = self.Person.objects.create(name='User A', age=20) + user_b = self.Person.objects.create(name='User B', age=30) + user_c = self.Person.objects.create(name="User C", age=40) # Test slice limit people = list(self.Person.objects[:2]) self.assertEqual(len(people), 2) - self.assertEqual(people[0].name, 'User A') - self.assertEqual(people[1].name, 'User B') + self.assertEqual(people[0], user_a) + self.assertEqual(people[1], user_b) # Test slice skip people = list(self.Person.objects[1:]) self.assertEqual(len(people), 2) - self.assertEqual(people[0].name, 'User B') - self.assertEqual(people[1].name, 'User C') + self.assertEqual(people[0], user_b) + self.assertEqual(people[1], user_c) # Test slice limit and skip people = list(self.Person.objects[1:2]) self.assertEqual(len(people), 1) - self.assertEqual(people[0].name, 'User B') + self.assertEqual(people[0], user_b) + + # Test slice limit and skip on an existing queryset + people = self.Person.objects + 
self.assertEqual(len(people), 3) + people2 = people[1:2] + self.assertEqual(len(people2), 1) + self.assertEqual(people2[0], user_b) # Test slice limit and skip cursor reset qs = self.Person.objects[1:2] @@ -168,6 +221,7 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(len(people), 1) self.assertEqual(people[0].name, 'User B') + # Test empty slice people = list(self.Person.objects[1:1]) self.assertEqual(len(people), 0) @@ -187,12 +241,6 @@ class QuerySetTest(unittest.TestCase): self.assertEqual("[, ]", "%s" % self.Person.objects[51:53]) - # Test only after limit - self.assertEqual(self.Person.objects().limit(2).only('name')[0].age, None) - - # Test only after skip - self.assertEqual(self.Person.objects().skip(2).only('name')[0].age, None) - def test_find_one(self): """Ensure that a query using find_one returns a valid result. """ @@ -1226,6 +1274,7 @@ class QuerySetTest(unittest.TestCase): BlogPost.drop_collection() + # default ordering should be used by default with db_ops_tracker() as q: BlogPost.objects.filter(title='whatever').first() self.assertEqual(len(q.get_ops()), 1) @@ -1234,11 +1283,28 @@ class QuerySetTest(unittest.TestCase): {'published_date': -1} ) + # calling order_by() should clear the default ordering with db_ops_tracker() as q: BlogPost.objects.filter(title='whatever').order_by().first() self.assertEqual(len(q.get_ops()), 1) self.assertFalse('$orderby' in q.get_ops()[0]['query']) + # calling an explicit order_by should use a specified sort + with db_ops_tracker() as q: + BlogPost.objects.filter(title='whatever').order_by('published_date').first() + self.assertEqual(len(q.get_ops()), 1) + self.assertEqual( + q.get_ops()[0]['query']['$orderby'], + {'published_date': 1} + ) + + # calling order_by() after an explicit sort should clear it + with db_ops_tracker() as q: + qs = BlogPost.objects.filter(title='whatever').order_by('published_date') + qs.order_by().first() + self.assertEqual(len(q.get_ops()), 1) + self.assertFalse('$orderby' in 
q.get_ops()[0]['query']) + def test_no_ordering_for_get(self): """ Ensure that Doc.objects.get doesn't use any ordering. """