Compare commits

...

7 Commits

Author SHA1 Message Date
Stefan Wojcik
70f1067470 more thorough and isolated unit tests 2017-02-08 00:05:19 -08:00
Stefan Wojcik
fba4477dd8 fix flake8 2017-02-07 23:51:38 -08:00
Stefan Wojcik
6ac7a8eec6 fix chaining of the limit, skip, hint, and batch_size 2017-02-07 23:43:46 -08:00
Stefan Wojcik
bad4e6846a Cleaner ordering code + allow clearing an explicit sort 2017-02-07 23:38:56 -08:00
Stefan Wojcik
0e77efb0c7 Fix BaseQuerySet.comment 2017-02-07 23:34:39 -08:00
Stefan Wojcik
8a6f36256e minor cleanup 2017-02-06 19:29:35 -07:00
Stefan Wojcik
515ffac246 unit test showing issues with skip & limit on an existing queryset 2017-02-06 18:55:15 -07:00
3 changed files with 220 additions and 85 deletions

View File

@ -86,6 +86,7 @@ class BaseQuerySet(object):
self._batch_size = None
self.only_fields = []
self._max_time_ms = None
self._comment = None
def __call__(self, q_obj=None, class_check=True, read_preference=None,
**query):
@ -706,39 +707,36 @@ class BaseQuerySet(object):
with switch_db(self._document, alias) as cls:
collection = cls._get_collection()
return self.clone_into(self.__class__(self._document, collection))
return self._clone_into(self.__class__(self._document, collection))
def clone(self):
"""Creates a copy of the current
:class:`~mongoengine.queryset.QuerySet`
"""Create a copy of the current queryset."""
return self._clone_into(self.__class__(self._document, self._collection_obj))
.. versionadded:: 0.5
def _clone_into(self, new_qs):
"""Copy all of the relevant properties of this queryset to
a new queryset (which has to be an instance of
:class:`~mongoengine.queryset.base.BaseQuerySet`).
"""
return self.clone_into(self.__class__(self._document, self._collection_obj))
def clone_into(self, cls):
"""Creates a copy of the current
:class:`~mongoengine.queryset.base.BaseQuerySet` into another child class
"""
if not isinstance(cls, BaseQuerySet):
if not isinstance(new_qs, BaseQuerySet):
raise OperationError(
'%s is not a subclass of BaseQuerySet' % cls.__name__)
'%s is not a subclass of BaseQuerySet' % new_qs.__name__)
copy_props = ('_mongo_query', '_initial_query', '_none', '_query_obj',
'_where_clause', '_loaded_fields', '_ordering', '_snapshot',
'_timeout', '_class_check', '_slave_okay', '_read_preference',
'_iter', '_scalar', '_as_pymongo', '_as_pymongo_coerce',
'_limit', '_skip', '_hint', '_auto_dereference',
'_search_text', 'only_fields', '_max_time_ms')
'_search_text', 'only_fields', '_max_time_ms', '_comment')
for prop in copy_props:
val = getattr(self, prop)
setattr(cls, prop, copy.copy(val))
setattr(new_qs, prop, copy.copy(val))
if self._cursor_obj:
cls._cursor_obj = self._cursor_obj.clone()
new_qs._cursor_obj = self._cursor_obj.clone()
return cls
return new_qs
def select_related(self, max_depth=1):
"""Handles dereferencing of :class:`~bson.dbref.DBRef` objects or
@ -760,7 +758,11 @@ class BaseQuerySet(object):
"""
queryset = self.clone()
queryset._limit = n if n != 0 else 1
# Return self to allow chaining
# If a cursor object has already been created, apply the limit to it.
if queryset._cursor_obj:
queryset._cursor_obj.limit(queryset._limit)
return queryset
def skip(self, n):
@ -771,6 +773,11 @@ class BaseQuerySet(object):
"""
queryset = self.clone()
queryset._skip = n
# If a cursor object has already been created, apply the skip to it.
if queryset._cursor_obj:
queryset._cursor_obj.skip(queryset._skip)
return queryset
def hint(self, index=None):
@ -788,6 +795,11 @@ class BaseQuerySet(object):
"""
queryset = self.clone()
queryset._hint = index
# If a cursor object has already been created, apply the hint to it.
if queryset._cursor_obj:
queryset._cursor_obj.hint(queryset._hint)
return queryset
def batch_size(self, size):
@ -801,6 +813,11 @@ class BaseQuerySet(object):
"""
queryset = self.clone()
queryset._batch_size = size
# If a cursor object has already been created, apply the batch size to it.
if queryset._cursor_obj:
queryset._cursor_obj.batch_size(queryset._batch_size)
return queryset
def distinct(self, field):
@ -972,13 +989,31 @@ class BaseQuerySet(object):
def order_by(self, *keys):
"""Order the :class:`~mongoengine.queryset.QuerySet` by the keys. The
order may be specified by prepending each of the keys by a + or a -.
Ascending order is assumed.
Ascending order is assumed. If no keys are passed, existing ordering
is cleared instead.
:param keys: fields to order the query results by; keys may be
prefixed with **+** or **-** to determine the ordering direction
"""
queryset = self.clone()
queryset._ordering = queryset._get_order_by(keys)
old_ordering = queryset._ordering
new_ordering = queryset._get_order_by(keys)
if queryset._cursor_obj:
# If a cursor object has already been created, apply the sort to it
if new_ordering:
queryset._cursor_obj.sort(new_ordering)
# If we're trying to clear a previous explicit ordering, we need
# to clear the cursor entirely (because PyMongo doesn't allow
# clearing an existing sort on a cursor).
elif old_ordering:
queryset._cursor_obj = None
queryset._ordering = new_ordering
return queryset
def comment(self, text):
@ -1424,10 +1459,13 @@ class BaseQuerySet(object):
raise StopIteration
raw_doc = self._cursor.next()
if self._as_pymongo:
return self._get_as_pymongo(raw_doc)
doc = self._document._from_son(raw_doc,
_auto_dereference=self._auto_dereference, only_fields=self.only_fields)
doc = self._document._from_son(
raw_doc, _auto_dereference=self._auto_dereference,
only_fields=self.only_fields)
if self._scalar:
return self._get_scalar(doc)
@ -1437,7 +1475,6 @@ class BaseQuerySet(object):
def rewind(self):
"""Rewind the cursor to its unevaluated state.
.. versionadded:: 0.3
"""
self._iter = False
@ -1487,43 +1524,54 @@ class BaseQuerySet(object):
@property
def _cursor(self):
if self._cursor_obj is None:
"""Return a PyMongo cursor object corresponding to this queryset."""
# In PyMongo 3+, we define the read preference on a collection
# level, not a cursor level. Thus, we need to get a cloned
# collection object using `with_options` first.
if IS_PYMONGO_3 and self._read_preference is not None:
self._cursor_obj = self._collection\
.with_options(read_preference=self._read_preference)\
.find(self._query, **self._cursor_args)
else:
self._cursor_obj = self._collection.find(self._query,
**self._cursor_args)
# Apply where clauses to cursor
if self._where_clause:
where_clause = self._sub_js_fields(self._where_clause)
self._cursor_obj.where(where_clause)
# If _cursor_obj already exists, return it immediately.
if self._cursor_obj is not None:
return self._cursor_obj
if self._ordering:
# Apply query ordering
self._cursor_obj.sort(self._ordering)
elif self._ordering is None and self._document._meta['ordering']:
# Otherwise, apply the ordering from the document model, unless
# it's been explicitly cleared via order_by with no arguments
order = self._get_order_by(self._document._meta['ordering'])
self._cursor_obj.sort(order)
# Create a new PyMongo cursor.
# XXX In PyMongo 3+, we define the read preference on a collection
# level, not a cursor level. Thus, we need to get a cloned collection
# object using `with_options` first.
if IS_PYMONGO_3 and self._read_preference is not None:
self._cursor_obj = self._collection\
.with_options(read_preference=self._read_preference)\
.find(self._query, **self._cursor_args)
else:
self._cursor_obj = self._collection.find(self._query,
**self._cursor_args)
# Apply "where" clauses to cursor
if self._where_clause:
where_clause = self._sub_js_fields(self._where_clause)
self._cursor_obj.where(where_clause)
if self._limit is not None:
self._cursor_obj.limit(self._limit)
# Apply ordering to the cursor.
# XXX self._ordering can be equal to:
# * None if we didn't explicitly call order_by on this queryset.
# * A list of PyMongo-style sorting tuples.
# * An empty list if we explicitly called order_by() without any
# arguments. This indicates that we want to clear the default
# ordering.
if self._ordering:
# explicit ordering
self._cursor_obj.sort(self._ordering)
elif self._ordering is None and self._document._meta['ordering']:
# default ordering
order = self._get_order_by(self._document._meta['ordering'])
self._cursor_obj.sort(order)
if self._skip is not None:
self._cursor_obj.skip(self._skip)
if self._limit is not None:
self._cursor_obj.limit(self._limit)
if self._hint != -1:
self._cursor_obj.hint(self._hint)
if self._skip is not None:
self._cursor_obj.skip(self._skip)
if self._batch_size is not None:
self._cursor_obj.batch_size(self._batch_size)
if self._hint != -1:
self._cursor_obj.hint(self._hint)
if self._batch_size is not None:
self._cursor_obj.batch_size(self._batch_size)
return self._cursor_obj
@ -1698,7 +1746,13 @@ class BaseQuerySet(object):
return ret
def _get_order_by(self, keys):
"""Creates a list of order by fields"""
"""Given a list of MongoEngine-style sort keys, return a list
of sorting tuples that can be applied to a PyMongo cursor. For
example:
>>> qs._get_order_by(['-last_name', 'first_name'])
[('last_name', -1), ('first_name', 1)]
"""
key_list = []
for key in keys:
if not key:
@ -1711,17 +1765,19 @@ class BaseQuerySet(object):
direction = pymongo.ASCENDING
if key[0] == '-':
direction = pymongo.DESCENDING
if key[0] in ('-', '+'):
key = key[1:]
key = key.replace('__', '.')
try:
key = self._document._translate_field_name(key)
except Exception:
# TODO this exception should be more specific
pass
key_list.append((key, direction))
if self._cursor_obj and key_list:
self._cursor_obj.sort(key_list)
return key_list
def _get_scalar(self, doc):
@ -1819,10 +1875,21 @@ class BaseQuerySet(object):
return code
def _chainable_method(self, method_name, val):
"""Call a particular method on the PyMongo cursor call a particular chainable method
with the provided value.
"""
queryset = self.clone()
method = getattr(queryset._cursor, method_name)
method(val)
# Get an existing cursor object or create a new one
cursor = queryset._cursor
# Find the requested method on the cursor and call it with the
# provided value
getattr(cursor, method_name)(val)
# Cache the value on the queryset._{method_name}
setattr(queryset, '_' + method_name, val)
return queryset
# Deprecated

View File

@ -136,13 +136,15 @@ class QuerySet(BaseQuerySet):
return self._len
def no_cache(self):
"""Convert to a non_caching queryset
"""Convert to a non-caching queryset
.. versionadded:: 0.8.3 Convert to non caching queryset
"""
if self._result_cache is not None:
raise OperationError('QuerySet already cached')
return self.clone_into(QuerySetNoCache(self._document, self._collection))
return self._clone_into(QuerySetNoCache(self._document,
self._collection))
class QuerySetNoCache(BaseQuerySet):
@ -153,7 +155,7 @@ class QuerySetNoCache(BaseQuerySet):
.. versionadded:: 0.8.3 Convert to caching queryset
"""
return self.clone_into(QuerySet(self._document, self._collection))
return self._clone_into(QuerySet(self._document, self._collection))
def __repr__(self):
"""Provides the string representation of the QuerySet

View File

@ -106,58 +106,111 @@ class QuerySetTest(unittest.TestCase):
list(BlogPost.objects(author2__name="test"))
def test_find(self):
"""Ensure that a query returns a valid set of results.
"""
self.Person(name="User A", age=20).save()
self.Person(name="User B", age=30).save()
"""Ensure that a query returns a valid set of results."""
user_a = self.Person.objects.create(name='User A', age=20)
user_b = self.Person.objects.create(name='User B', age=30)
# Find all people in the collection
people = self.Person.objects
self.assertEqual(people.count(), 2)
results = list(people)
self.assertTrue(isinstance(results[0], self.Person))
self.assertTrue(isinstance(results[0].id, (ObjectId, str, unicode)))
self.assertEqual(results[0].name, "User A")
self.assertEqual(results[0], user_a)
self.assertEqual(results[0].name, 'User A')
self.assertEqual(results[0].age, 20)
self.assertEqual(results[1].name, "User B")
self.assertEqual(results[1], user_b)
self.assertEqual(results[1].name, 'User B')
self.assertEqual(results[1].age, 30)
# Use a query to filter the people found to just person1
# Filter people by age
people = self.Person.objects(age=20)
self.assertEqual(people.count(), 1)
person = people.next()
self.assertEqual(person, user_a)
self.assertEqual(person.name, "User A")
self.assertEqual(person.age, 20)
# Test limit
def test_limit(self):
"""Ensure that QuerySet.limit works as expected."""
user_a = self.Person.objects.create(name='User A', age=20)
user_b = self.Person.objects.create(name='User B', age=30)
# Test limit on a new queryset
people = list(self.Person.objects.limit(1))
self.assertEqual(len(people), 1)
self.assertEqual(people[0].name, 'User A')
self.assertEqual(people[0], user_a)
# Test skip
# Test limit on an existing queryset
people = self.Person.objects
self.assertEqual(len(people), 2)
people2 = people.limit(1)
self.assertEqual(len(people), 2)
self.assertEqual(len(people2), 1)
self.assertEqual(people2[0], user_a)
# Test chaining of only after limit
person = self.Person.objects().limit(1).only('name').first()
self.assertEqual(person, user_a)
self.assertEqual(person.name, 'User A')
self.assertEqual(person.age, None)
def test_skip(self):
"""Ensure that QuerySet.skip works as expected."""
user_a = self.Person.objects.create(name='User A', age=20)
user_b = self.Person.objects.create(name='User B', age=30)
# Test skip on a new queryset
people = list(self.Person.objects.skip(1))
self.assertEqual(len(people), 1)
self.assertEqual(people[0].name, 'User B')
self.assertEqual(people[0], user_b)
person3 = self.Person(name="User C", age=40)
person3.save()
# Test skip on an existing queryset
people = self.Person.objects
self.assertEqual(len(people), 2)
people2 = people.skip(1)
self.assertEqual(len(people), 2)
self.assertEqual(len(people2), 1)
self.assertEqual(people2[0], user_b)
# Test chaining of only after skip
person = self.Person.objects().skip(1).only('name').first()
self.assertEqual(person, user_b)
self.assertEqual(person.name, 'User B')
self.assertEqual(person.age, None)
def test_slice(self):
"""Ensure slicing a queryset works as expected."""
user_a = self.Person.objects.create(name='User A', age=20)
user_b = self.Person.objects.create(name='User B', age=30)
user_c = self.Person.objects.create(name="User C", age=40)
# Test slice limit
people = list(self.Person.objects[:2])
self.assertEqual(len(people), 2)
self.assertEqual(people[0].name, 'User A')
self.assertEqual(people[1].name, 'User B')
self.assertEqual(people[0], user_a)
self.assertEqual(people[1], user_b)
# Test slice skip
people = list(self.Person.objects[1:])
self.assertEqual(len(people), 2)
self.assertEqual(people[0].name, 'User B')
self.assertEqual(people[1].name, 'User C')
self.assertEqual(people[0], user_b)
self.assertEqual(people[1], user_c)
# Test slice limit and skip
people = list(self.Person.objects[1:2])
self.assertEqual(len(people), 1)
self.assertEqual(people[0].name, 'User B')
self.assertEqual(people[0], user_b)
# Test slice limit and skip on an existing queryset
people = self.Person.objects
self.assertEqual(len(people), 3)
people2 = people[1:2]
self.assertEqual(len(people2), 1)
self.assertEqual(people2[0], user_b)
# Test slice limit and skip cursor reset
qs = self.Person.objects[1:2]
@ -168,6 +221,7 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(len(people), 1)
self.assertEqual(people[0].name, 'User B')
# Test empty slice
people = list(self.Person.objects[1:1])
self.assertEqual(len(people), 0)
@ -187,12 +241,6 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual("[<Person: Person object>, <Person: Person object>]",
"%s" % self.Person.objects[51:53])
# Test only after limit
self.assertEqual(self.Person.objects().limit(2).only('name')[0].age, None)
# Test only after skip
self.assertEqual(self.Person.objects().skip(2).only('name')[0].age, None)
def test_find_one(self):
"""Ensure that a query using find_one returns a valid result.
"""
@ -1226,6 +1274,7 @@ class QuerySetTest(unittest.TestCase):
BlogPost.drop_collection()
# default ordering should be used by default
with db_ops_tracker() as q:
BlogPost.objects.filter(title='whatever').first()
self.assertEqual(len(q.get_ops()), 1)
@ -1234,11 +1283,28 @@ class QuerySetTest(unittest.TestCase):
{'published_date': -1}
)
# calling order_by() should clear the default ordering
with db_ops_tracker() as q:
BlogPost.objects.filter(title='whatever').order_by().first()
self.assertEqual(len(q.get_ops()), 1)
self.assertFalse('$orderby' in q.get_ops()[0]['query'])
# calling an explicit order_by should use a specified sort
with db_ops_tracker() as q:
BlogPost.objects.filter(title='whatever').order_by('published_date').first()
self.assertEqual(len(q.get_ops()), 1)
self.assertEqual(
q.get_ops()[0]['query']['$orderby'],
{'published_date': 1}
)
# calling order_by() after an explicit sort should clear it
with db_ops_tracker() as q:
qs = BlogPost.objects.filter(title='whatever').order_by('published_date')
qs.order_by().first()
self.assertEqual(len(q.get_ops()), 1)
self.assertFalse('$orderby' in q.get_ops()[0]['query'])
def test_no_ordering_for_get(self):
""" Ensure that Doc.objects.get doesn't use any ordering.
"""