diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index 54d7643d..8469e715 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -548,7 +548,7 @@ class QuerySet(object): return '.'.join(parts) @classmethod - def _transform_query(cls, _doc_cls=None, **query): + def _transform_query(cls, _doc_cls=None, _field_operation=False, **query): """Transform a query from Django-style format to Mongo format. """ operators = ['ne', 'gt', 'gte', 'lt', 'lte', 'in', 'nin', 'mod', @@ -915,13 +915,26 @@ class QuerySet(object): You can use the $slice operator to retrieve a subrange of elements in an array :: - post = BlogPost.objects(...).fields(comments={"$slice": 5}) // first 5 comments + post = BlogPost.objects(...).fields(slice__comments=5) // first 5 comments :param kwargs: A dictionary identifying what to include .. versionadded:: 0.5 """ - fields = sorted(kwargs.iteritems(), key=operator.itemgetter(1)) + + # Check for an operator and transform to mongo-style if there is + operators = ["slice"] + cleaned_fields = [] + for key, value in kwargs.items(): + parts = key.split('__') + op = None + if parts[0] in operators: + op = parts.pop(0) + value = {'$' + op: value} + key = '.'.join(parts) + cleaned_fields.append((key, value)) + + fields = sorted(cleaned_fields, key=operator.itemgetter(1)) for value, group in itertools.groupby(fields, lambda x: x[1]): fields = [field for field, value in group] fields = self._fields_to_dbfields(fields) diff --git a/tests/queryset.py b/tests/queryset.py index 1961d7cf..e29a6d9d 100644 --- a/tests/queryset.py +++ b/tests/queryset.py @@ -597,10 +597,9 @@ class QuerySetTest(unittest.TestCase): Email.drop_collection() - def test_custom_fields(self): + def test_slicing_fields(self): """Ensure that query slicing an array works. """ - class Numbers(Document): n = ListField(IntField()) @@ -610,25 +609,69 @@ class QuerySetTest(unittest.TestCase): numbers.save() # first three - numbers = Numbers.objects.fields(n={"$slice": 3}).get() + numbers = Numbers.objects.fields(slice__n=3).get() self.assertEquals(numbers.n, [0, 1, 2]) # last three - numbers = Numbers.objects.fields(n={"$slice": -3}).get() + numbers = Numbers.objects.fields(slice__n=-3).get() self.assertEquals(numbers.n, [-3, -2, -1]) # skip 2, limit 3 - numbers = Numbers.objects.fields(n={"$slice": [2, 3]}).get() + numbers = Numbers.objects.fields(slice__n=[2, 3]).get() self.assertEquals(numbers.n, [2, 3, 4]) # skip to fifth from last, limit 4 - numbers = Numbers.objects.fields(n={"$slice": [-5, 4]}).get() + numbers = Numbers.objects.fields(slice__n=[-5, 4]).get() self.assertEquals(numbers.n, [-5, -4, -3, -2]) # skip to fifth from last, limit 10 + numbers = Numbers.objects.fields(slice__n=[-5, 10]).get() + self.assertEquals(numbers.n, [-5, -4, -3, -2, -1]) + + # skip to fifth from last, limit 10 dict method numbers = Numbers.objects.fields(n={"$slice": [-5, 10]}).get() self.assertEquals(numbers.n, [-5, -4, -3, -2, -1]) + def test_slicing_nested_fields(self): + """Ensure that query slicing an embedded array works. + """ + + class EmbeddedNumber(EmbeddedDocument): + n = ListField(IntField()) + + class Numbers(Document): + embedded = EmbeddedDocumentField(EmbeddedNumber) + + Numbers.drop_collection() + + numbers = Numbers() + numbers.embedded = EmbeddedNumber(n=[0,1,2,3,4,5,-5,-4,-3,-2,-1]) + numbers.save() + + # first three + numbers = Numbers.objects.fields(slice__embedded__n=3).get() + self.assertEquals(numbers.embedded.n, [0, 1, 2]) + + # last three + numbers = Numbers.objects.fields(slice__embedded__n=-3).get() + self.assertEquals(numbers.embedded.n, [-3, -2, -1]) + + # skip 2, limit 3 + numbers = Numbers.objects.fields(slice__embedded__n=[2, 3]).get() + self.assertEquals(numbers.embedded.n, [2, 3, 4]) + + # skip to fifth from last, limit 4 + numbers = Numbers.objects.fields(slice__embedded__n=[-5, 4]).get() + self.assertEquals(numbers.embedded.n, [-5, -4, -3, -2]) + + # skip to fifth from last, limit 10 + numbers = Numbers.objects.fields(slice__embedded__n=[-5, 10]).get() + self.assertEquals(numbers.embedded.n, [-5, -4, -3, -2, -1]) + + # skip to fifth from last, limit 10 dict method + numbers = Numbers.objects.fields(embedded__n={"$slice": [-5, 10]}).get() + self.assertEquals(numbers.embedded.n, [-5, -4, -3, -2, -1]) + def test_find_embedded(self): """Ensure that an embedded document is properly returned from a query. """