diff --git a/mongoengine/queryset/base.py b/mongoengine/queryset/base.py index 5f748b50..45da9a21 100644 --- a/mongoengine/queryset/base.py +++ b/mongoengine/queryset/base.py @@ -275,6 +275,8 @@ class BaseQuerySet(object): except StopIteration: return result + # If we were able to retrieve the 2nd doc, rewind the cursor and + # raise the MultipleObjectsReturned exception. queryset.rewind() message = u'%d items returned, instead of 1' % queryset.count() raise queryset._document.MultipleObjectsReturned(message) diff --git a/mongoengine/queryset/queryset.py b/mongoengine/queryset/queryset.py index e68537d2..b185b340 100644 --- a/mongoengine/queryset/queryset.py +++ b/mongoengine/queryset/queryset.py @@ -27,9 +27,10 @@ class QuerySet(BaseQuerySet): in batches of ``ITER_CHUNK_SIZE``. If ``self._has_more`` the cursor hasn't been exhausted so cache then - batch. Otherwise iterate the result_cache. + batch. Otherwise iterate the result_cache. """ self._iter = True + if self._has_more: return self._iter_results() @@ -42,10 +43,12 @@ class QuerySet(BaseQuerySet): """ if self._len is not None: return self._len + + # Populate the result cache with *all* of the docs in the cursor if self._has_more: - # populate the cache list(self._iter_results()) + # Cache the length of the complete result cache and return it self._len = len(self._result_cache) return self._len @@ -64,18 +67,33 @@ class QuerySet(BaseQuerySet): def _iter_results(self): """A generator for iterating over the result cache. - Also populates the cache if there are more possible results to yield. - Raises StopIteration when there are no more results""" + Also populates the cache if there are more possible results to + yield. Raises StopIteration when there are no more results. 
+ """ if self._result_cache is None: self._result_cache = [] + pos = 0 while True: - upper = len(self._result_cache) - while pos < upper: + + # For all positions lower than the length of the current result + # cache, serve the docs straight from the cache w/o hitting the + # database. + # XXX it's VERY important to compute the len within the `while` + # condition because the result cache might expand mid-iteration + # (e.g. if we call len(qs) inside a loop that iterates over the + # queryset). Fortunately len(list) is O(1) in Python, so this + # doesn't cause performance issues. + while pos < len(self._result_cache): yield self._result_cache[pos] pos += 1 + + # Raise StopIteration if we already established there were no more + # docs in the db cursor. if not self._has_more: raise StopIteration + + # Otherwise, populate more of the cache and repeat. if len(self._result_cache) <= pos: self._populate_cache() @@ -86,12 +104,22 @@ class QuerySet(BaseQuerySet): """ if self._result_cache is None: self._result_cache = [] - if self._has_more: - try: - for i in xrange(ITER_CHUNK_SIZE): - self._result_cache.append(self.next()) - except StopIteration: - self._has_more = False + + # Skip populating the cache if we already established there are no + # more docs to pull from the database. + if not self._has_more: + return + + # Pull in ITER_CHUNK_SIZE docs from the database and store them in + # the result cache. + try: + for i in xrange(ITER_CHUNK_SIZE): + self._result_cache.append(self.next()) + except StopIteration: + # Getting this exception means there are no more docs in the + # db cursor. Set _has_more to False so that we can use that + # information in other places. + self._has_more = False def count(self, with_limit_and_skip=False): """Count the selected elements in the query. 
diff --git a/tests/queryset/queryset.py b/tests/queryset/queryset.py index 88ae18aa..0fb78b6b 100644 --- a/tests/queryset/queryset.py +++ b/tests/queryset/queryset.py @@ -4890,6 +4890,56 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(1, Doc.objects(item__type__="axe").count()) + def test_len_during_iteration(self): + """Tests that calling len on a queryset during iteration doesn't + stop paging. + """ + class Data(Document): + pass + + for i in xrange(300): + Data().save() + + records = Data.objects.limit(250) + + # This should pull all 250 docs from mongo and populate the result + # cache + len(records) + + # Assert that iterating over documents in the qs touches every + # document even if we call len(qs) midway through the iteration. + for i, r in enumerate(records): + if i == 58: + len(records) + self.assertEqual(i, 249) + + # Assert the same behavior is true even if we didn't pre-populate the + # result cache. + records = Data.objects.limit(250) + for i, r in enumerate(records): + if i == 58: + len(records) + self.assertEqual(i, 249) + + def test_iteration_within_iteration(self): + """You should be able to reliably iterate over all the documents + in a given queryset even if there are multiple iterations of it + happening at the same time. + """ + class Data(Document): + pass + + for i in xrange(300): + Data().save() + + qs = Data.objects.limit(250) + for i, doc in enumerate(qs): + for j, doc2 in enumerate(qs): + pass + + self.assertEqual(i, 249) + self.assertEqual(j, 249) + if __name__ == '__main__': unittest.main()