diff --git a/mongoengine/queryset/base.py b/mongoengine/queryset/base.py index 5f748b50..2702a46b 100644 --- a/mongoengine/queryset/base.py +++ b/mongoengine/queryset/base.py @@ -82,6 +82,7 @@ class BaseQuerySet(object): self._limit = None self._skip = None self._hint = -1 # Using -1 as None is a valid value for hint + self._batch_size = None self.only_fields = [] self._max_time_ms = None @@ -781,6 +782,16 @@ class BaseQuerySet(object): queryset._hint = index return queryset + def batch_size(self, size): + """Limit the number of documents returned in a single batch (each + batch requires a round trip to the server). + + :param size: desired size of each batch. + """ + queryset = self.clone() + queryset._batch_size = size + return queryset + def distinct(self, field): """Return a list of distinct values for a given field. @@ -1467,6 +1478,9 @@ class BaseQuerySet(object): if self._hint != -1: self._cursor_obj.hint(self._hint) + if self._batch_size is not None: + self._cursor_obj.batch_size(self._batch_size) + return self._cursor_obj def __deepcopy__(self, memo): diff --git a/tests/queryset/queryset.py b/tests/queryset/queryset.py index 88ae18aa..552e3ed7 100644 --- a/tests/queryset/queryset.py +++ b/tests/queryset/queryset.py @@ -337,6 +337,36 @@ class QuerySetTest(unittest.TestCase): query = query.filter(boolfield=True) self.assertEqual(query.count(), 1) + def test_batch_size(self): + """Ensure that batch_size works.""" + class A(Document): + s = StringField() + + A.drop_collection() + + for i in range(100): + A.objects.create(s=str(i)) + + # test iterating over the result set + cnt = 0 + for a in A.objects.batch_size(10): + cnt += 1 + self.assertEqual(cnt, 100) + + # test chaining + qs = A.objects.all() + qs = qs.limit(10).batch_size(20).skip(91) + cnt = 0 + for a in qs: + cnt += 1 + self.assertEqual(cnt, 9) + + # test invalid batch size + qs = A.objects.batch_size(-1) + self.assertRaises(ValueError, lambda: list(qs)) + qs = A.objects.filter(s='1').batch_size('not a number') + self.assertRaises(TypeError, lambda: [doc for doc in qs]) + def test_update_write_concern(self): """Test that passing write_concern works""" self.Person.drop_collection()