Updated QuerySet to allow more granular fields control.

Added a fields method and tests showing the retrival of subranges of
List Fields.

Refs #167
This commit is contained in:
Ross Lawley 2011-05-18 16:39:19 +01:00
parent 5d5a84dbcf
commit 371dbf009f
2 changed files with 90 additions and 32 deletions

View File

@ -8,6 +8,7 @@ import pymongo.objectid
import re import re
import copy import copy
import itertools import itertools
import operator
__all__ = ['queryset_manager', 'Q', 'InvalidQueryError', __all__ = ['queryset_manager', 'Q', 'InvalidQueryError',
'InvalidCollectionError', 'DO_NOTHING', 'NULLIFY', 'CASCADE', 'DENY'] 'InvalidCollectionError', 'DO_NOTHING', 'NULLIFY', 'CASCADE', 'DENY']
@ -280,30 +281,30 @@ class QueryFieldList(object):
ONLY = True ONLY = True
EXCLUDE = False EXCLUDE = False
def __init__(self, fields=[], direction=ONLY, always_include=[]): def __init__(self, fields=[], value=ONLY, always_include=[]):
self.direction = direction self.value = value
self.fields = set(fields) self.fields = set(fields)
self.always_include = set(always_include) self.always_include = set(always_include)
def as_dict(self): def as_dict(self):
return dict((field, self.direction) for field in self.fields) return dict((field, self.value) for field in self.fields)
def __add__(self, f): def __add__(self, f):
if not self.fields: if not self.fields:
self.fields = f.fields self.fields = f.fields
self.direction = f.direction self.value = f.value
elif self.direction is self.ONLY and f.direction is self.ONLY: elif self.value is self.ONLY and f.value is self.ONLY:
self.fields = self.fields.intersection(f.fields) self.fields = self.fields.intersection(f.fields)
elif self.direction is self.EXCLUDE and f.direction is self.EXCLUDE: elif self.value is self.EXCLUDE and f.value is self.EXCLUDE:
self.fields = self.fields.union(f.fields) self.fields = self.fields.union(f.fields)
elif self.direction is self.ONLY and f.direction is self.EXCLUDE: elif self.value is self.ONLY and f.value is self.EXCLUDE:
self.fields -= f.fields self.fields -= f.fields
elif self.direction is self.EXCLUDE and f.direction is self.ONLY: elif self.value is self.EXCLUDE and f.value is self.ONLY:
self.direction = self.ONLY self.value = self.ONLY
self.fields = f.fields - self.fields self.fields = f.fields - self.fields
if self.always_include: if self.always_include:
if self.direction is self.ONLY and self.fields: if self.value is self.ONLY and self.fields:
self.fields = self.fields.union(self.always_include) self.fields = self.fields.union(self.always_include)
else: else:
self.fields -= self.always_include self.fields -= self.always_include
@ -311,7 +312,7 @@ class QueryFieldList(object):
def reset(self): def reset(self):
self.fields = set([]) self.fields = set([])
self.direction = self.ONLY self.value = self.ONLY
def __nonzero__(self): def __nonzero__(self):
return bool(self.fields) return bool(self.fields)
@ -890,10 +891,8 @@ class QuerySet(object):
.. versionadded:: 0.3 .. versionadded:: 0.3
""" """
fields = self._fields_to_dbfields(fields) fields = dict([(f, QueryFieldList.ONLY) for f in fields])
self._loaded_fields += QueryFieldList(fields, direction=QueryFieldList.ONLY) return self.fields(**fields)
return self
def exclude(self, *fields): def exclude(self, *fields):
"""Opposite to .only(), exclude some document's fields. :: """Opposite to .only(), exclude some document's fields. ::
@ -902,8 +901,31 @@ class QuerySet(object):
:param fields: fields to exclude :param fields: fields to exclude
""" """
fields = self._fields_to_dbfields(fields) fields = dict([(f, QueryFieldList.EXCLUDE) for f in fields])
self._loaded_fields += QueryFieldList(fields, direction=QueryFieldList.EXCLUDE) return self.fields(**fields)
def fields(self, **kwargs):
"""Manipulate how you load this document's fields. Used by `.only()`
and `.exclude()` to manipulate which fields to retrieve. Fields also
allows for a greater level of control for example:
Retrieving a Subrange of Array Elements
---------------------------------------
You can use the $slice operator to retrieve a subrange of elements in
an array ::
post = BlogPost.objects(...).fields(comments={"$slice": 5}) // first 5 comments
:param kwargs: A dictionary identifying what to include
.. versionadded:: 0.5
"""
fields = sorted(kwargs.iteritems(), key=operator.itemgetter(1))
for value, group in itertools.groupby(fields, lambda x: x[1]):
fields = [field for field, value in group]
fields = self._fields_to_dbfields(fields)
self._loaded_fields += QueryFieldList(fields, value=value)
return self return self
def all_fields(self): def all_fields(self):
@ -1277,7 +1299,7 @@ class QuerySetManager(object):
# Create collection as a capped collection if specified # Create collection as a capped collection if specified
if owner._meta['max_size'] or owner._meta['max_documents']: if owner._meta['max_size'] or owner._meta['max_documents']:
# Get max document limit and max byte size from meta # Get max document limit and max byte size from meta
max_size = owner._meta['max_size'] or 10000000 # 10MB default max_size = owner._meta['max_size'] or 10000000 # 10MB default
max_documents = owner._meta['max_documents'] max_documents = owner._meta['max_documents']
if collection in db.collection_names(): if collection in db.collection_names():

View File

@ -597,6 +597,38 @@ class QuerySetTest(unittest.TestCase):
Email.drop_collection() Email.drop_collection()
def test_custom_fields(self):
"""Ensure that query slicing an array works.
"""
class Numbers(Document):
n = ListField(IntField())
Numbers.drop_collection()
numbers = Numbers(n=[0,1,2,3,4,5,-5,-4,-3,-2,-1])
numbers.save()
# first three
numbers = Numbers.objects.fields(n={"$slice": 3}).get()
self.assertEquals(numbers.n, [0, 1, 2])
# last three
numbers = Numbers.objects.fields(n={"$slice": -3}).get()
self.assertEquals(numbers.n, [-3, -2, -1])
# skip 2, limit 3
numbers = Numbers.objects.fields(n={"$slice": [2, 3]}).get()
self.assertEquals(numbers.n, [2, 3, 4])
# skip to fifth from last, limit 4
numbers = Numbers.objects.fields(n={"$slice": [-5, 4]}).get()
self.assertEquals(numbers.n, [-5, -4, -3, -2])
# skip to fifth from last, limit 10
numbers = Numbers.objects.fields(n={"$slice": [-5, 10]}).get()
self.assertEquals(numbers.n, [-5, -4, -3, -2, -1])
def test_find_embedded(self): def test_find_embedded(self):
"""Ensure that an embedded document is properly returned from a query. """Ensure that an embedded document is properly returned from a query.
""" """
@ -1931,49 +1963,53 @@ class QueryFieldListTest(unittest.TestCase):
def test_include_include(self): def test_include_include(self):
q = QueryFieldList() q = QueryFieldList()
q += QueryFieldList(fields=['a', 'b'], direction=QueryFieldList.ONLY) q += QueryFieldList(fields=['a', 'b'], value=QueryFieldList.ONLY)
self.assertEqual(q.as_dict(), {'a': True, 'b': True}) self.assertEqual(q.as_dict(), {'a': True, 'b': True})
q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.ONLY) q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY)
self.assertEqual(q.as_dict(), {'b': True}) self.assertEqual(q.as_dict(), {'b': True})
def test_include_exclude(self): def test_include_exclude(self):
q = QueryFieldList() q = QueryFieldList()
q += QueryFieldList(fields=['a', 'b'], direction=QueryFieldList.ONLY) q += QueryFieldList(fields=['a', 'b'], value=QueryFieldList.ONLY)
self.assertEqual(q.as_dict(), {'a': True, 'b': True}) self.assertEqual(q.as_dict(), {'a': True, 'b': True})
q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.EXCLUDE) q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.EXCLUDE)
self.assertEqual(q.as_dict(), {'a': True}) self.assertEqual(q.as_dict(), {'a': True})
def test_exclude_exclude(self): def test_exclude_exclude(self):
q = QueryFieldList() q = QueryFieldList()
q += QueryFieldList(fields=['a', 'b'], direction=QueryFieldList.EXCLUDE) q += QueryFieldList(fields=['a', 'b'], value=QueryFieldList.EXCLUDE)
self.assertEqual(q.as_dict(), {'a': False, 'b': False}) self.assertEqual(q.as_dict(), {'a': False, 'b': False})
q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.EXCLUDE) q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.EXCLUDE)
self.assertEqual(q.as_dict(), {'a': False, 'b': False, 'c': False}) self.assertEqual(q.as_dict(), {'a': False, 'b': False, 'c': False})
def test_exclude_include(self): def test_exclude_include(self):
q = QueryFieldList() q = QueryFieldList()
q += QueryFieldList(fields=['a', 'b'], direction=QueryFieldList.EXCLUDE) q += QueryFieldList(fields=['a', 'b'], value=QueryFieldList.EXCLUDE)
self.assertEqual(q.as_dict(), {'a': False, 'b': False}) self.assertEqual(q.as_dict(), {'a': False, 'b': False})
q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.ONLY) q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY)
self.assertEqual(q.as_dict(), {'c': True}) self.assertEqual(q.as_dict(), {'c': True})
def test_always_include(self): def test_always_include(self):
q = QueryFieldList(always_include=['x', 'y']) q = QueryFieldList(always_include=['x', 'y'])
q += QueryFieldList(fields=['a', 'b', 'x'], direction=QueryFieldList.EXCLUDE) q += QueryFieldList(fields=['a', 'b', 'x'], value=QueryFieldList.EXCLUDE)
q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.ONLY) q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY)
self.assertEqual(q.as_dict(), {'x': True, 'y': True, 'c': True}) self.assertEqual(q.as_dict(), {'x': True, 'y': True, 'c': True})
def test_reset(self): def test_reset(self):
q = QueryFieldList(always_include=['x', 'y']) q = QueryFieldList(always_include=['x', 'y'])
q += QueryFieldList(fields=['a', 'b', 'x'], direction=QueryFieldList.EXCLUDE) q += QueryFieldList(fields=['a', 'b', 'x'], value=QueryFieldList.EXCLUDE)
q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.ONLY) q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY)
self.assertEqual(q.as_dict(), {'x': True, 'y': True, 'c': True}) self.assertEqual(q.as_dict(), {'x': True, 'y': True, 'c': True})
q.reset() q.reset()
self.assertFalse(q) self.assertFalse(q)
q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.ONLY) q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY)
self.assertEqual(q.as_dict(), {'x': True, 'y': True, 'b': True, 'c': True}) self.assertEqual(q.as_dict(), {'x': True, 'y': True, 'b': True, 'c': True})
def test_using_a_slice(self):
q = QueryFieldList()
q += QueryFieldList(fields=['a'], value={"$slice": 5})
self.assertEqual(q.as_dict(), {'a': {"$slice": 5}})
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()