Compare commits
	
		
			2 Commits
		
	
	
		
			v0.23.1
			...
			fix-iterat
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|  | 894678da39 | ||
|  | 0a66a4b8a9 | 
| @@ -275,6 +275,8 @@ class BaseQuerySet(object): | |||||||
|         except StopIteration: |         except StopIteration: | ||||||
|             return result |             return result | ||||||
|  |  | ||||||
|  |         # If we were able to retrieve the 2nd doc, rewind the cursor and | ||||||
|  |         # raise the MultipleObjectsReturned exception. | ||||||
|         queryset.rewind() |         queryset.rewind() | ||||||
|         message = u'%d items returned, instead of 1' % queryset.count() |         message = u'%d items returned, instead of 1' % queryset.count() | ||||||
|         raise queryset._document.MultipleObjectsReturned(message) |         raise queryset._document.MultipleObjectsReturned(message) | ||||||
|   | |||||||
| @@ -30,6 +30,7 @@ class QuerySet(BaseQuerySet): | |||||||
|         batch. Otherwise iterate the result_cache. |         batch. Otherwise iterate the result_cache. | ||||||
|         """ |         """ | ||||||
|         self._iter = True |         self._iter = True | ||||||
|  |  | ||||||
|         if self._has_more: |         if self._has_more: | ||||||
|             return self._iter_results() |             return self._iter_results() | ||||||
|  |  | ||||||
| @@ -42,10 +43,12 @@ class QuerySet(BaseQuerySet): | |||||||
|         """ |         """ | ||||||
|         if self._len is not None: |         if self._len is not None: | ||||||
|             return self._len |             return self._len | ||||||
|  |  | ||||||
|  |         # Populate the result cache with *all* of the docs in the cursor | ||||||
|         if self._has_more: |         if self._has_more: | ||||||
|             # populate the cache |  | ||||||
|             list(self._iter_results()) |             list(self._iter_results()) | ||||||
|  |  | ||||||
|  |         # Cache the length of the complete result cache and return it | ||||||
|         self._len = len(self._result_cache) |         self._len = len(self._result_cache) | ||||||
|         return self._len |         return self._len | ||||||
|  |  | ||||||
| @@ -64,18 +67,33 @@ class QuerySet(BaseQuerySet): | |||||||
|     def _iter_results(self): |     def _iter_results(self): | ||||||
|         """A generator for iterating over the result cache. |         """A generator for iterating over the result cache. | ||||||
|  |  | ||||||
|         Also populates the cache if there are more possible results to yield. |         Also populates the cache if there are more possible results to | ||||||
|         Raises StopIteration when there are no more results""" |         yield. Raises StopIteration when there are no more results. | ||||||
|  |         """ | ||||||
|         if self._result_cache is None: |         if self._result_cache is None: | ||||||
|             self._result_cache = [] |             self._result_cache = [] | ||||||
|  |  | ||||||
|         pos = 0 |         pos = 0 | ||||||
|         while True: |         while True: | ||||||
|             upper = len(self._result_cache) |  | ||||||
|             while pos < upper: |             # For all positions lower than the length of the current result | ||||||
|  |             # cache, serve the docs straight from the cache w/o hitting the | ||||||
|  |             # database. | ||||||
|  |             # XXX it's VERY important to compute the len within the `while` | ||||||
|  |             # condition because the result cache might expand mid-iteration | ||||||
|  |             # (e.g. if we call len(qs) inside a loop that iterates over the | ||||||
|  |             # queryset). Fortunately len(list) is O(1) in Python, so this | ||||||
|  |             # doesn't cause performance issues. | ||||||
|  |             while pos < len(self._result_cache): | ||||||
|                 yield self._result_cache[pos] |                 yield self._result_cache[pos] | ||||||
|                 pos += 1 |                 pos += 1 | ||||||
|  |  | ||||||
|  |             # Raise StopIteration if we already established there were no more | ||||||
|  |             # docs in the db cursor. | ||||||
|             if not self._has_more: |             if not self._has_more: | ||||||
|                 raise StopIteration |                 raise StopIteration | ||||||
|  |  | ||||||
|  |             # Otherwise, populate more of the cache and repeat. | ||||||
|             if len(self._result_cache) <= pos: |             if len(self._result_cache) <= pos: | ||||||
|                 self._populate_cache() |                 self._populate_cache() | ||||||
|  |  | ||||||
| @@ -86,11 +104,21 @@ class QuerySet(BaseQuerySet): | |||||||
|         """ |         """ | ||||||
|         if self._result_cache is None: |         if self._result_cache is None: | ||||||
|             self._result_cache = [] |             self._result_cache = [] | ||||||
|         if self._has_more: |  | ||||||
|  |         # Skip populating the cache if we already established there are no | ||||||
|  |         # more docs to pull from the database. | ||||||
|  |         if not self._has_more: | ||||||
|  |             return | ||||||
|  |  | ||||||
|  |         # Pull in ITER_CHUNK_SIZE docs from the database and store them in | ||||||
|  |         # the result cache. | ||||||
|         try: |         try: | ||||||
|             for i in xrange(ITER_CHUNK_SIZE): |             for i in xrange(ITER_CHUNK_SIZE): | ||||||
|                 self._result_cache.append(self.next()) |                 self._result_cache.append(self.next()) | ||||||
|         except StopIteration: |         except StopIteration: | ||||||
|  |             # Getting this exception means there are no more docs in the | ||||||
|  |             # db cursor. Set _has_more to False so that we can use that | ||||||
|  |             # information in other places. | ||||||
|             self._has_more = False |             self._has_more = False | ||||||
|  |  | ||||||
|     def count(self, with_limit_and_skip=False): |     def count(self, with_limit_and_skip=False): | ||||||
|   | |||||||
| @@ -4890,6 +4890,56 @@ class QuerySetTest(unittest.TestCase): | |||||||
|  |  | ||||||
|         self.assertEqual(1, Doc.objects(item__type__="axe").count()) |         self.assertEqual(1, Doc.objects(item__type__="axe").count()) | ||||||
|  |  | ||||||
|  |     def test_len_during_iteration(self): | ||||||
|  |         """Tests that calling len on a queryset during iteration doesn't | ||||||
|  |         stop paging. | ||||||
|  |         """ | ||||||
|  |         class Data(Document): | ||||||
|  |             pass | ||||||
|  |  | ||||||
|  |         for i in xrange(300): | ||||||
|  |             Data().save() | ||||||
|  |  | ||||||
|  |         records = Data.objects.limit(250) | ||||||
|  |  | ||||||
|  |         # This should pull all 250 docs from mongo and populate the result | ||||||
|  |         # cache | ||||||
|  |         len(records) | ||||||
|  |  | ||||||
|  |         # Assert that iterating over documents in the qs touches every | ||||||
|  |         # document even if we call len(qs) midway through the iteration. | ||||||
|  |         for i, r in enumerate(records): | ||||||
|  |             if i == 58: | ||||||
|  |                 len(records) | ||||||
|  |         self.assertEqual(i, 249) | ||||||
|  |  | ||||||
|  |         # Assert the same behavior is true even if we didn't pre-populate the | ||||||
|  |         # result cache. | ||||||
|  |         records = Data.objects.limit(250) | ||||||
|  |         for i, r in enumerate(records): | ||||||
|  |             if i == 58: | ||||||
|  |                 len(records) | ||||||
|  |         self.assertEqual(i, 249) | ||||||
|  |  | ||||||
|  |     def test_iteration_within_iteration(self): | ||||||
|  |         """You should be able to reliably iterate over all the documents | ||||||
|  |         in a given queryset even if there are multiple iterations of it | ||||||
|  |         happening at the same time. | ||||||
|  |         """ | ||||||
|  |         class Data(Document): | ||||||
|  |             pass | ||||||
|  |  | ||||||
|  |         for i in xrange(300): | ||||||
|  |             Data().save() | ||||||
|  |  | ||||||
|  |         qs = Data.objects.limit(250) | ||||||
|  |         for i, doc in enumerate(qs): | ||||||
|  |             for j, doc2 in enumerate(qs): | ||||||
|  |                 pass | ||||||
|  |  | ||||||
|  |         self.assertEqual(i, 249) | ||||||
|  |         self.assertEqual(j, 249) | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|     unittest.main() |     unittest.main() | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user