added Queryset.exclude() + tests

This commit is contained in:
Ales Zoulek 2010-11-03 16:37:41 +01:00
parent c6058fafed
commit bda4776a18
2 changed files with 191 additions and 10 deletions

View File

@ -268,6 +268,48 @@ class Q(QNode):
return not bool(self.query) return not bool(self.query)
class QueryFieldList(object):
"""Object that handles combinations of .only() and .exclude() calls"""
ONLY = True
EXCLUDE = False
def __init__(self, fields=[], direction=ONLY, always_include=[]):
self.direction = direction
self.fields = set(fields)
self.always_include = set(always_include)
def as_dict(self):
return dict((field, self.direction) for field in self.fields)
def __add__(self, f):
if not self.fields:
self.fields = f.fields
self.direction = f.direction
elif self.direction is self.ONLY and f.direction is self.ONLY:
self.fields = self.fields.intersection(f.fields)
elif self.direction is self.EXCLUDE and f.direction is self.EXCLUDE:
self.fields = self.fields.union(f.fields)
elif self.direction is self.ONLY and f.direction is self.EXCLUDE:
self.fields -= f.fields
elif self.direction is self.EXCLUDE and f.direction is self.ONLY:
self.direction = self.ONLY
self.fields = f.fields - self.fields
if self.always_include:
if self.direction is self.ONLY and self.fields:
self.fields = self.fields.union(self.always_include)
else:
self.fields -= self.always_include
return self
def reset(self):
self.fields = set([])
self.direction = self.ONLY
def __nonzero__(self):
return bool(self.fields)
class QuerySet(object): class QuerySet(object):
"""A set of results returned from a query. Wraps a MongoDB cursor, """A set of results returned from a query. Wraps a MongoDB cursor,
providing :class:`~mongoengine.Document` objects as the results. providing :class:`~mongoengine.Document` objects as the results.
@ -281,7 +323,7 @@ class QuerySet(object):
self._query_obj = Q() self._query_obj = Q()
self._initial_query = {} self._initial_query = {}
self._where_clause = None self._where_clause = None
self._loaded_fields = [] self._loaded_fields = QueryFieldList()
self._ordering = [] self._ordering = []
self._snapshot = False self._snapshot = False
self._timeout = True self._timeout = True
@ -290,6 +332,7 @@ class QuerySet(object):
# subclasses of the class being used # subclasses of the class being used
if document._meta.get('allow_inheritance'): if document._meta.get('allow_inheritance'):
self._initial_query = {'_types': self._document._class_name} self._initial_query = {'_types': self._document._class_name}
self._loaded_fields = QueryFieldList(always_include=['_cls'])
self._cursor_obj = None self._cursor_obj = None
self._limit = None self._limit = None
self._skip = None self._skip = None
@ -423,7 +466,7 @@ class QuerySet(object):
'timeout': self._timeout, 'timeout': self._timeout,
} }
if self._loaded_fields: if self._loaded_fields:
cursor_args['fields'] = self._loaded_fields cursor_args['fields'] = self._loaded_fields.as_dict()
self._cursor_obj = self._collection.find(self._query, self._cursor_obj = self._collection.find(self._query,
**cursor_args) **cursor_args)
# Apply where clauses to cursor # Apply where clauses to cursor
@ -818,15 +861,22 @@ class QuerySet(object):
.. versionadded:: 0.3 .. versionadded:: 0.3
""" """
self._loaded_fields = [] fields = self._fields_to_dbfields(fields)
self._loaded_fields += QueryFieldList(fields, direction=QueryFieldList.ONLY)
return self
def exclude(self, *fields):
fields = self._fields_to_dbfields(fields)
self._loaded_fields += QueryFieldList(fields, direction=QueryFieldList.EXCLUDE)
return self
def _fields_to_dbfields(self, fields):
ret = []
for field in fields: for field in fields:
field = ".".join(f.db_field for f in QuerySet._lookup_field(self._document, field.split('.'))) field = ".".join(f.db_field for f in QuerySet._lookup_field(self._document, field.split('.')))
self._loaded_fields.append(field) ret.append(field)
return ret
# _cls is needed for polymorphism
if self._document._meta.get('allow_inheritance'):
self._loaded_fields += ['_cls']
return self
def order_by(self, *keys): def order_by(self, *keys):
"""Order the :class:`~mongoengine.queryset.QuerySet` by the keys. The """Order the :class:`~mongoengine.queryset.QuerySet` by the keys. The

View File

@ -6,7 +6,7 @@ import pymongo
from datetime import datetime, timedelta from datetime import datetime, timedelta
from mongoengine.queryset import (QuerySet, MultipleObjectsReturned, from mongoengine.queryset import (QuerySet, MultipleObjectsReturned,
DoesNotExist) DoesNotExist, QueryFieldList)
from mongoengine import * from mongoengine import *
@ -497,6 +497,81 @@ class QuerySetTest(unittest.TestCase):
BlogPost.drop_collection() BlogPost.drop_collection()
def test_exclude(self):
class User(EmbeddedDocument):
name = StringField()
email = StringField()
class Comment(EmbeddedDocument):
title = StringField()
text = StringField()
class BlogPost(Document):
content = StringField()
author = EmbeddedDocumentField(User)
comments = ListField(EmbeddedDocumentField(Comment))
BlogPost.drop_collection()
post = BlogPost(content='Had a good coffee today...')
post.author = User(name='Test User')
post.comments = [Comment(title='I aggree', text='Great post!'), Comment(title='Coffee', text='I hate coffee')]
post.save()
obj = BlogPost.objects.exclude('author', 'comments.text').get()
self.assertEqual(obj.author, None)
self.assertEqual(obj.content, 'Had a good coffee today...')
self.assertEqual(obj.comments[0].title, 'I aggree')
self.assertEqual(obj.comments[0].text, None)
BlogPost.drop_collection()
def test_exclude_only_combining(self):
class Attachment(EmbeddedDocument):
name = StringField()
content = StringField()
class Email(Document):
sender = StringField()
to = StringField()
subject = StringField()
body = StringField()
content_type = StringField()
attachments = ListField(EmbeddedDocumentField(Attachment))
Email.drop_collection()
email = Email(sender='me', to='you', subject='From Russia with Love', body='Hello!', content_type='text/plain')
email.attachments = [
Attachment(name='file1.doc', content='ABC'),
Attachment(name='file2.doc', content='XYZ'),
]
email.save()
obj = Email.objects.exclude('content_type').exclude('body').get()
self.assertEqual(obj.sender, 'me')
self.assertEqual(obj.to, 'you')
self.assertEqual(obj.subject, 'From Russia with Love')
self.assertEqual(obj.body, None)
self.assertEqual(obj.content_type, None)
obj = Email.objects.only('sender', 'to').exclude('body', 'sender').get()
self.assertEqual(obj.sender, None)
self.assertEqual(obj.to, 'you')
self.assertEqual(obj.subject, None)
self.assertEqual(obj.body, None)
self.assertEqual(obj.content_type, None)
obj = Email.objects.exclude('attachments.content').exclude('body').only('to', 'attachments.name').get()
self.assertEqual(obj.attachments[0].name, 'file1.doc')
self.assertEqual(obj.attachments[0].content, None)
self.assertEqual(obj.sender, None)
self.assertEqual(obj.to, 'you')
self.assertEqual(obj.subject, None)
self.assertEqual(obj.body, None)
self.assertEqual(obj.content_type, None)
Email.drop_collection()
def test_find_embedded(self): def test_find_embedded(self):
"""Ensure that an embedded document is properly returned from a query. """Ensure that an embedded document is properly returned from a query.
""" """
@ -1594,6 +1669,62 @@ class QTest(unittest.TestCase):
for condition in conditions: for condition in conditions:
self.assertTrue(condition in query['$or']) self.assertTrue(condition in query['$or'])
class QueryFieldListTest(unittest.TestCase):
def test_empty(self):
q = QueryFieldList()
self.assertFalse(q)
q = QueryFieldList(always_include=['_cls'])
self.assertFalse(q)
def test_include_include(self):
q = QueryFieldList()
q += QueryFieldList(fields=['a', 'b'], direction=QueryFieldList.ONLY)
self.assertEqual(q.as_dict(), {'a': True, 'b': True})
q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.ONLY)
self.assertEqual(q.as_dict(), {'b': True})
def test_include_exclude(self):
q = QueryFieldList()
q += QueryFieldList(fields=['a', 'b'], direction=QueryFieldList.ONLY)
self.assertEqual(q.as_dict(), {'a': True, 'b': True})
q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.EXCLUDE)
self.assertEqual(q.as_dict(), {'a': True})
def test_exclude_exclude(self):
q = QueryFieldList()
q += QueryFieldList(fields=['a', 'b'], direction=QueryFieldList.EXCLUDE)
self.assertEqual(q.as_dict(), {'a': False, 'b': False})
q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.EXCLUDE)
self.assertEqual(q.as_dict(), {'a': False, 'b': False, 'c': False})
def test_exclude_include(self):
q = QueryFieldList()
q += QueryFieldList(fields=['a', 'b'], direction=QueryFieldList.EXCLUDE)
self.assertEqual(q.as_dict(), {'a': False, 'b': False})
q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.ONLY)
self.assertEqual(q.as_dict(), {'c': True})
def test_always_include(self):
q = QueryFieldList(always_include=['x', 'y'])
q += QueryFieldList(fields=['a', 'b', 'x'], direction=QueryFieldList.EXCLUDE)
q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.ONLY)
self.assertEqual(q.as_dict(), {'x': True, 'y': True, 'c': True})
def test_reset(self):
q = QueryFieldList(always_include=['x', 'y'])
q += QueryFieldList(fields=['a', 'b', 'x'], direction=QueryFieldList.EXCLUDE)
q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.ONLY)
self.assertEqual(q.as_dict(), {'x': True, 'y': True, 'c': True})
q.reset()
self.assertFalse(q)
q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.ONLY)
self.assertEqual(q.as_dict(), {'x': True, 'y': True, 'b': True, 'c': True})
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()