From 8c9afbd278eac13afdb9e12e23ec0e324d56d539 Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Wed, 1 May 2013 19:40:49 +0000 Subject: [PATCH] Fix cloning of sliced querysets (#303) --- mongoengine/queryset/queryset.py | 14 +++----- tests/test_django.py | 60 +++++++++++++++++++++++++++++--- 2 files changed, 60 insertions(+), 14 deletions(-) diff --git a/mongoengine/queryset/queryset.py b/mongoengine/queryset/queryset.py index 1739f05e..c1c93781 100644 --- a/mongoengine/queryset/queryset.py +++ b/mongoengine/queryset/queryset.py @@ -72,7 +72,6 @@ class QuerySet(object): self._cursor_obj = None self._limit = None self._skip = None - self._slice = None self._hint = -1 # Using -1 as None is a valid value for hint def __call__(self, q_obj=None, class_check=True, slave_okay=False, @@ -127,8 +126,10 @@ class QuerySet(object): if isinstance(key, slice): try: queryset._cursor_obj = queryset._cursor[key] - queryset._slice = key queryset._skip, queryset._limit = key.start, key.stop + queryset._limit + if key.start and key.stop: + queryset._limit = key.stop - key.start except IndexError, err: # PyMongo raises an error if key.start == key.stop, catch it, # bin it, kill it. @@ -537,15 +538,9 @@ class QuerySet(object): val = getattr(self, prop) setattr(c, prop, copy.copy(val)) - if self._slice: - c._slice = self._slice - if self._cursor_obj: c._cursor_obj = self._cursor_obj.clone() - if self._slice: - c._cursor[self._slice] - return c def select_related(self, max_depth=1): @@ -571,7 +566,6 @@ class QuerySet(object): else: queryset._cursor.limit(n) queryset._limit = n - # Return self to allow chaining return queryset @@ -1155,7 +1149,7 @@ class QuerySet(object): self._cursor_obj.sort(order) if self._limit is not None: - self._cursor_obj.limit(self._limit - (self._skip or 0)) + self._cursor_obj.limit(self._limit) if self._skip is not None: self._cursor_obj.skip(self._skip) diff --git a/tests/test_django.py b/tests/test_django.py index e30fe3c9..f81213c3 100644 --- a/tests/test_django.py +++ b/tests/test_django.py @@ -150,22 +150,74 @@ class QuerySetTest(unittest.TestCase): # Try iterating the same queryset twice, nested, in a Django template. names = ['A', 'B', 'C', 'D'] - class User(Document): + class CustomUser(Document): name = StringField() def __unicode__(self): return self.name - User.drop_collection() + CustomUser.drop_collection() for name in names: - User(name=name).save() + CustomUser(name=name).save() - users = User.objects.all().order_by('name') + users = CustomUser.objects.all().order_by('name') template = Template("{% for user in users %}{{ user.name }}{% ifequal forloop.counter 2 %} {% for inner_user in users %}{{ inner_user.name }}{% endfor %} {% endifequal %}{% endfor %}") rendered = template.render(Context({'users': users})) self.assertEqual(rendered, 'AB ABCD CD') + def test_filter(self): + """Ensure that a queryset and filters work as expected + """ + + class Note(Document): + text = StringField() + + for i in xrange(1, 101): + Note(name="Note: %s" % i).save() + + # Check the count + self.assertEqual(Note.objects.count(), 100) + + # Get the first 10 and confirm + notes = Note.objects[:10] + self.assertEqual(notes.count(), 10) + + # Test djangos template filters + # self.assertEqual(length(notes), 10) + t = Template("{{ notes.count }}") + c = Context({"notes": notes}) + self.assertEqual(t.render(c), "10") + + # Test with skip + notes = Note.objects.skip(90) + self.assertEqual(notes.count(), 10) + + # Test djangos template filters + self.assertEqual(notes.count(), 10) + t = Template("{{ notes.count }}") + c = Context({"notes": notes}) + self.assertEqual(t.render(c), "10") + + # Test with limit + notes = Note.objects.skip(90) + self.assertEqual(notes.count(), 10) + + # Test djangos template filters + self.assertEqual(notes.count(), 10) + t = Template("{{ notes.count }}") + c = Context({"notes": notes}) + self.assertEqual(t.render(c), "10") + + # Test with skip and limit + notes = Note.objects.skip(10).limit(10) + + # Test djangos template filters + self.assertEqual(notes.count(), 10) + t = Template("{{ notes.count }}") + c = Context({"notes": notes}) + self.assertEqual(t.render(c), "10") + class MongoDBSessionTest(SessionTestsMixin, unittest.TestCase): backend = SessionStore