Merge branch 'feature/slicing_fields' into dev

This commit is contained in:
Ross Lawley 2011-05-20 14:18:48 +01:00
commit 04953351f1
2 changed files with 147 additions and 33 deletions

View File

@ -8,6 +8,7 @@ import pymongo.objectid
import re import re
import copy import copy
import itertools import itertools
import operator
__all__ = ['queryset_manager', 'Q', 'InvalidQueryError', __all__ = ['queryset_manager', 'Q', 'InvalidQueryError',
'InvalidCollectionError', 'DO_NOTHING', 'NULLIFY', 'CASCADE', 'DENY'] 'InvalidCollectionError', 'DO_NOTHING', 'NULLIFY', 'CASCADE', 'DENY']
@ -280,30 +281,30 @@ class QueryFieldList(object):
ONLY = True ONLY = True
EXCLUDE = False EXCLUDE = False
def __init__(self, fields=[], direction=ONLY, always_include=[]): def __init__(self, fields=[], value=ONLY, always_include=[]):
self.direction = direction self.value = value
self.fields = set(fields) self.fields = set(fields)
self.always_include = set(always_include) self.always_include = set(always_include)
def as_dict(self): def as_dict(self):
return dict((field, self.direction) for field in self.fields) return dict((field, self.value) for field in self.fields)
def __add__(self, f): def __add__(self, f):
if not self.fields: if not self.fields:
self.fields = f.fields self.fields = f.fields
self.direction = f.direction self.value = f.value
elif self.direction is self.ONLY and f.direction is self.ONLY: elif self.value is self.ONLY and f.value is self.ONLY:
self.fields = self.fields.intersection(f.fields) self.fields = self.fields.intersection(f.fields)
elif self.direction is self.EXCLUDE and f.direction is self.EXCLUDE: elif self.value is self.EXCLUDE and f.value is self.EXCLUDE:
self.fields = self.fields.union(f.fields) self.fields = self.fields.union(f.fields)
elif self.direction is self.ONLY and f.direction is self.EXCLUDE: elif self.value is self.ONLY and f.value is self.EXCLUDE:
self.fields -= f.fields self.fields -= f.fields
elif self.direction is self.EXCLUDE and f.direction is self.ONLY: elif self.value is self.EXCLUDE and f.value is self.ONLY:
self.direction = self.ONLY self.value = self.ONLY
self.fields = f.fields - self.fields self.fields = f.fields - self.fields
if self.always_include: if self.always_include:
if self.direction is self.ONLY and self.fields: if self.value is self.ONLY and self.fields:
self.fields = self.fields.union(self.always_include) self.fields = self.fields.union(self.always_include)
else: else:
self.fields -= self.always_include self.fields -= self.always_include
@ -311,7 +312,7 @@ class QueryFieldList(object):
def reset(self): def reset(self):
self.fields = set([]) self.fields = set([])
self.direction = self.ONLY self.value = self.ONLY
def __nonzero__(self): def __nonzero__(self):
return bool(self.fields) return bool(self.fields)
@ -547,7 +548,7 @@ class QuerySet(object):
return '.'.join(parts) return '.'.join(parts)
@classmethod @classmethod
def _transform_query(cls, _doc_cls=None, **query): def _transform_query(cls, _doc_cls=None, _field_operation=False, **query):
"""Transform a query from Django-style format to Mongo format. """Transform a query from Django-style format to Mongo format.
""" """
operators = ['ne', 'gt', 'gte', 'lt', 'lte', 'in', 'nin', 'mod', operators = ['ne', 'gt', 'gte', 'lt', 'lte', 'in', 'nin', 'mod',
@ -894,10 +895,8 @@ class QuerySet(object):
.. versionadded:: 0.3 .. versionadded:: 0.3
""" """
fields = self._fields_to_dbfields(fields) fields = dict([(f, QueryFieldList.ONLY) for f in fields])
self._loaded_fields += QueryFieldList(fields, direction=QueryFieldList.ONLY) return self.fields(**fields)
return self
def exclude(self, *fields): def exclude(self, *fields):
"""Opposite to .only(), exclude some document's fields. :: """Opposite to .only(), exclude some document's fields. ::
@ -906,8 +905,44 @@ class QuerySet(object):
:param fields: fields to exclude :param fields: fields to exclude
""" """
fields = self._fields_to_dbfields(fields) fields = dict([(f, QueryFieldList.EXCLUDE) for f in fields])
self._loaded_fields += QueryFieldList(fields, direction=QueryFieldList.EXCLUDE) return self.fields(**fields)
def fields(self, **kwargs):
"""Manipulate how you load this document's fields. Used by `.only()`
and `.exclude()` to manipulate which fields to retrieve. Fields also
allows for a greater level of control for example:
Retrieving a Subrange of Array Elements
---------------------------------------
You can use the $slice operator to retrieve a subrange of elements in
an array ::
post = BlogPost.objects(...).fields(slice__comments=5) // first 5 comments
:param kwargs: A dictionary identifying what to include
.. versionadded:: 0.5
"""
# Check for an operator and transform to mongo-style if there is
operators = ["slice"]
cleaned_fields = []
for key, value in kwargs.items():
parts = key.split('__')
op = None
if parts[0] in operators:
op = parts.pop(0)
value = {'$' + op: value}
key = '.'.join(parts)
cleaned_fields.append((key, value))
fields = sorted(cleaned_fields, key=operator.itemgetter(1))
for value, group in itertools.groupby(fields, lambda x: x[1]):
fields = [field for field, value in group]
fields = self._fields_to_dbfields(fields)
self._loaded_fields += QueryFieldList(fields, value=value)
return self return self
def all_fields(self): def all_fields(self):
@ -1291,7 +1326,7 @@ class QuerySetManager(object):
# Create collection as a capped collection if specified # Create collection as a capped collection if specified
if owner._meta['max_size'] or owner._meta['max_documents']: if owner._meta['max_size'] or owner._meta['max_documents']:
# Get max document limit and max byte size from meta # Get max document limit and max byte size from meta
max_size = owner._meta['max_size'] or 10000000 # 10MB default max_size = owner._meta['max_size'] or 10000000 # 10MB default
max_documents = owner._meta['max_documents'] max_documents = owner._meta['max_documents']
if collection in db.collection_names(): if collection in db.collection_names():

View File

@ -597,6 +597,81 @@ class QuerySetTest(unittest.TestCase):
Email.drop_collection() Email.drop_collection()
def test_slicing_fields(self):
"""Ensure that query slicing an array works.
"""
class Numbers(Document):
n = ListField(IntField())
Numbers.drop_collection()
numbers = Numbers(n=[0,1,2,3,4,5,-5,-4,-3,-2,-1])
numbers.save()
# first three
numbers = Numbers.objects.fields(slice__n=3).get()
self.assertEquals(numbers.n, [0, 1, 2])
# last three
numbers = Numbers.objects.fields(slice__n=-3).get()
self.assertEquals(numbers.n, [-3, -2, -1])
# skip 2, limit 3
numbers = Numbers.objects.fields(slice__n=[2, 3]).get()
self.assertEquals(numbers.n, [2, 3, 4])
# skip to fifth from last, limit 4
numbers = Numbers.objects.fields(slice__n=[-5, 4]).get()
self.assertEquals(numbers.n, [-5, -4, -3, -2])
# skip to fifth from last, limit 10
numbers = Numbers.objects.fields(slice__n=[-5, 10]).get()
self.assertEquals(numbers.n, [-5, -4, -3, -2, -1])
# skip to fifth from last, limit 10 dict method
numbers = Numbers.objects.fields(n={"$slice": [-5, 10]}).get()
self.assertEquals(numbers.n, [-5, -4, -3, -2, -1])
def test_slicing_nested_fields(self):
"""Ensure that query slicing an embedded array works.
"""
class EmbeddedNumber(EmbeddedDocument):
n = ListField(IntField())
class Numbers(Document):
embedded = EmbeddedDocumentField(EmbeddedNumber)
Numbers.drop_collection()
numbers = Numbers()
numbers.embedded = EmbeddedNumber(n=[0,1,2,3,4,5,-5,-4,-3,-2,-1])
numbers.save()
# first three
numbers = Numbers.objects.fields(slice__embedded__n=3).get()
self.assertEquals(numbers.embedded.n, [0, 1, 2])
# last three
numbers = Numbers.objects.fields(slice__embedded__n=-3).get()
self.assertEquals(numbers.embedded.n, [-3, -2, -1])
# skip 2, limit 3
numbers = Numbers.objects.fields(slice__embedded__n=[2, 3]).get()
self.assertEquals(numbers.embedded.n, [2, 3, 4])
# skip to fifth from last, limit 4
numbers = Numbers.objects.fields(slice__embedded__n=[-5, 4]).get()
self.assertEquals(numbers.embedded.n, [-5, -4, -3, -2])
# skip to fifth from last, limit 10
numbers = Numbers.objects.fields(slice__embedded__n=[-5, 10]).get()
self.assertEquals(numbers.embedded.n, [-5, -4, -3, -2, -1])
# skip to fifth from last, limit 10 dict method
numbers = Numbers.objects.fields(embedded__n={"$slice": [-5, 10]}).get()
self.assertEquals(numbers.embedded.n, [-5, -4, -3, -2, -1])
def test_find_embedded(self): def test_find_embedded(self):
"""Ensure that an embedded document is properly returned from a query. """Ensure that an embedded document is properly returned from a query.
""" """
@ -1951,49 +2026,53 @@ class QueryFieldListTest(unittest.TestCase):
def test_include_include(self): def test_include_include(self):
q = QueryFieldList() q = QueryFieldList()
q += QueryFieldList(fields=['a', 'b'], direction=QueryFieldList.ONLY) q += QueryFieldList(fields=['a', 'b'], value=QueryFieldList.ONLY)
self.assertEqual(q.as_dict(), {'a': True, 'b': True}) self.assertEqual(q.as_dict(), {'a': True, 'b': True})
q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.ONLY) q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY)
self.assertEqual(q.as_dict(), {'b': True}) self.assertEqual(q.as_dict(), {'b': True})
def test_include_exclude(self): def test_include_exclude(self):
q = QueryFieldList() q = QueryFieldList()
q += QueryFieldList(fields=['a', 'b'], direction=QueryFieldList.ONLY) q += QueryFieldList(fields=['a', 'b'], value=QueryFieldList.ONLY)
self.assertEqual(q.as_dict(), {'a': True, 'b': True}) self.assertEqual(q.as_dict(), {'a': True, 'b': True})
q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.EXCLUDE) q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.EXCLUDE)
self.assertEqual(q.as_dict(), {'a': True}) self.assertEqual(q.as_dict(), {'a': True})
def test_exclude_exclude(self): def test_exclude_exclude(self):
q = QueryFieldList() q = QueryFieldList()
q += QueryFieldList(fields=['a', 'b'], direction=QueryFieldList.EXCLUDE) q += QueryFieldList(fields=['a', 'b'], value=QueryFieldList.EXCLUDE)
self.assertEqual(q.as_dict(), {'a': False, 'b': False}) self.assertEqual(q.as_dict(), {'a': False, 'b': False})
q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.EXCLUDE) q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.EXCLUDE)
self.assertEqual(q.as_dict(), {'a': False, 'b': False, 'c': False}) self.assertEqual(q.as_dict(), {'a': False, 'b': False, 'c': False})
def test_exclude_include(self): def test_exclude_include(self):
q = QueryFieldList() q = QueryFieldList()
q += QueryFieldList(fields=['a', 'b'], direction=QueryFieldList.EXCLUDE) q += QueryFieldList(fields=['a', 'b'], value=QueryFieldList.EXCLUDE)
self.assertEqual(q.as_dict(), {'a': False, 'b': False}) self.assertEqual(q.as_dict(), {'a': False, 'b': False})
q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.ONLY) q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY)
self.assertEqual(q.as_dict(), {'c': True}) self.assertEqual(q.as_dict(), {'c': True})
def test_always_include(self): def test_always_include(self):
q = QueryFieldList(always_include=['x', 'y']) q = QueryFieldList(always_include=['x', 'y'])
q += QueryFieldList(fields=['a', 'b', 'x'], direction=QueryFieldList.EXCLUDE) q += QueryFieldList(fields=['a', 'b', 'x'], value=QueryFieldList.EXCLUDE)
q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.ONLY) q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY)
self.assertEqual(q.as_dict(), {'x': True, 'y': True, 'c': True}) self.assertEqual(q.as_dict(), {'x': True, 'y': True, 'c': True})
def test_reset(self): def test_reset(self):
q = QueryFieldList(always_include=['x', 'y']) q = QueryFieldList(always_include=['x', 'y'])
q += QueryFieldList(fields=['a', 'b', 'x'], direction=QueryFieldList.EXCLUDE) q += QueryFieldList(fields=['a', 'b', 'x'], value=QueryFieldList.EXCLUDE)
q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.ONLY) q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY)
self.assertEqual(q.as_dict(), {'x': True, 'y': True, 'c': True}) self.assertEqual(q.as_dict(), {'x': True, 'y': True, 'c': True})
q.reset() q.reset()
self.assertFalse(q) self.assertFalse(q)
q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.ONLY) q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY)
self.assertEqual(q.as_dict(), {'x': True, 'y': True, 'b': True, 'c': True}) self.assertEqual(q.as_dict(), {'x': True, 'y': True, 'b': True, 'c': True})
def test_using_a_slice(self):
q = QueryFieldList()
q += QueryFieldList(fields=['a'], value={"$slice": 5})
self.assertEqual(q.as_dict(), {'a': {"$slice": 5}})
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()