diff --git a/mongoengine/queryset/base.py b/mongoengine/queryset/base.py index 45da9a21..17ee989e 100644 --- a/mongoengine/queryset/base.py +++ b/mongoengine/queryset/base.py @@ -82,6 +82,7 @@ class BaseQuerySet(object): self._limit = None self._skip = None self._hint = -1 # Using -1 as None is a valid value for hint + self._batch_size = None self.only_fields = [] self._max_time_ms = None @@ -783,6 +784,19 @@ class BaseQuerySet(object): queryset._hint = index return queryset + def batch_size(self, size): + """Limit the number of documents returned in a single batch (each + batch requires a round trip to the server). + + See http://api.mongodb.com/python/current/api/pymongo/cursor.html#pymongo.cursor.Cursor.batch_size + for details. + + :param size: desired size of each batch. + """ + queryset = self.clone() + queryset._batch_size = size + return queryset + def distinct(self, field): """Return a list of distinct values for a given field. @@ -1469,6 +1483,9 @@ class BaseQuerySet(object): if self._hint != -1: self._cursor_obj.hint(self._hint) + if self._batch_size is not None: + self._cursor_obj.batch_size(self._batch_size) + return self._cursor_obj def __deepcopy__(self, memo): diff --git a/tests/queryset/queryset.py b/tests/queryset/queryset.py index 0fb78b6b..2c00838a 100644 --- a/tests/queryset/queryset.py +++ b/tests/queryset/queryset.py @@ -337,6 +337,34 @@ class QuerySetTest(unittest.TestCase): query = query.filter(boolfield=True) self.assertEqual(query.count(), 1) + def test_batch_size(self): + """Ensure that batch_size works.""" + class A(Document): + s = StringField() + + A.drop_collection() + + for i in range(100): + A.objects.create(s=str(i)) + + # test iterating over the result set + cnt = 0 + for a in A.objects.batch_size(10): + cnt += 1 + self.assertEqual(cnt, 100) + + # test chaining + qs = A.objects.all() + qs = qs.limit(10).batch_size(20).skip(91) + cnt = 0 + for a in qs: + cnt += 1 + self.assertEqual(cnt, 9) + + # test invalid batch size + qs = A.objects.batch_size(-1) + self.assertRaises(ValueError, lambda: list(qs)) + def test_update_write_concern(self): """Test that passing write_concern works""" self.Person.drop_collection()