added Queryset.exclude() + tests
This commit is contained in:
parent
c6058fafed
commit
bda4776a18
@ -268,6 +268,48 @@ class Q(QNode):
|
|||||||
return not bool(self.query)
|
return not bool(self.query)
|
||||||
|
|
||||||
|
|
||||||
|
class QueryFieldList(object):
|
||||||
|
"""Object that handles combinations of .only() and .exclude() calls"""
|
||||||
|
ONLY = True
|
||||||
|
EXCLUDE = False
|
||||||
|
|
||||||
|
def __init__(self, fields=[], direction=ONLY, always_include=[]):
|
||||||
|
self.direction = direction
|
||||||
|
self.fields = set(fields)
|
||||||
|
self.always_include = set(always_include)
|
||||||
|
|
||||||
|
def as_dict(self):
|
||||||
|
return dict((field, self.direction) for field in self.fields)
|
||||||
|
|
||||||
|
def __add__(self, f):
|
||||||
|
if not self.fields:
|
||||||
|
self.fields = f.fields
|
||||||
|
self.direction = f.direction
|
||||||
|
elif self.direction is self.ONLY and f.direction is self.ONLY:
|
||||||
|
self.fields = self.fields.intersection(f.fields)
|
||||||
|
elif self.direction is self.EXCLUDE and f.direction is self.EXCLUDE:
|
||||||
|
self.fields = self.fields.union(f.fields)
|
||||||
|
elif self.direction is self.ONLY and f.direction is self.EXCLUDE:
|
||||||
|
self.fields -= f.fields
|
||||||
|
elif self.direction is self.EXCLUDE and f.direction is self.ONLY:
|
||||||
|
self.direction = self.ONLY
|
||||||
|
self.fields = f.fields - self.fields
|
||||||
|
|
||||||
|
if self.always_include:
|
||||||
|
if self.direction is self.ONLY and self.fields:
|
||||||
|
self.fields = self.fields.union(self.always_include)
|
||||||
|
else:
|
||||||
|
self.fields -= self.always_include
|
||||||
|
return self
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.fields = set([])
|
||||||
|
self.direction = self.ONLY
|
||||||
|
|
||||||
|
def __nonzero__(self):
|
||||||
|
return bool(self.fields)
|
||||||
|
|
||||||
|
|
||||||
class QuerySet(object):
|
class QuerySet(object):
|
||||||
"""A set of results returned from a query. Wraps a MongoDB cursor,
|
"""A set of results returned from a query. Wraps a MongoDB cursor,
|
||||||
providing :class:`~mongoengine.Document` objects as the results.
|
providing :class:`~mongoengine.Document` objects as the results.
|
||||||
@ -281,7 +323,7 @@ class QuerySet(object):
|
|||||||
self._query_obj = Q()
|
self._query_obj = Q()
|
||||||
self._initial_query = {}
|
self._initial_query = {}
|
||||||
self._where_clause = None
|
self._where_clause = None
|
||||||
self._loaded_fields = []
|
self._loaded_fields = QueryFieldList()
|
||||||
self._ordering = []
|
self._ordering = []
|
||||||
self._snapshot = False
|
self._snapshot = False
|
||||||
self._timeout = True
|
self._timeout = True
|
||||||
@ -290,6 +332,7 @@ class QuerySet(object):
|
|||||||
# subclasses of the class being used
|
# subclasses of the class being used
|
||||||
if document._meta.get('allow_inheritance'):
|
if document._meta.get('allow_inheritance'):
|
||||||
self._initial_query = {'_types': self._document._class_name}
|
self._initial_query = {'_types': self._document._class_name}
|
||||||
|
self._loaded_fields = QueryFieldList(always_include=['_cls'])
|
||||||
self._cursor_obj = None
|
self._cursor_obj = None
|
||||||
self._limit = None
|
self._limit = None
|
||||||
self._skip = None
|
self._skip = None
|
||||||
@ -423,7 +466,7 @@ class QuerySet(object):
|
|||||||
'timeout': self._timeout,
|
'timeout': self._timeout,
|
||||||
}
|
}
|
||||||
if self._loaded_fields:
|
if self._loaded_fields:
|
||||||
cursor_args['fields'] = self._loaded_fields
|
cursor_args['fields'] = self._loaded_fields.as_dict()
|
||||||
self._cursor_obj = self._collection.find(self._query,
|
self._cursor_obj = self._collection.find(self._query,
|
||||||
**cursor_args)
|
**cursor_args)
|
||||||
# Apply where clauses to cursor
|
# Apply where clauses to cursor
|
||||||
@ -818,15 +861,22 @@ class QuerySet(object):
|
|||||||
|
|
||||||
.. versionadded:: 0.3
|
.. versionadded:: 0.3
|
||||||
"""
|
"""
|
||||||
self._loaded_fields = []
|
fields = self._fields_to_dbfields(fields)
|
||||||
|
self._loaded_fields += QueryFieldList(fields, direction=QueryFieldList.ONLY)
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
def exclude(self, *fields):
|
||||||
|
fields = self._fields_to_dbfields(fields)
|
||||||
|
self._loaded_fields += QueryFieldList(fields, direction=QueryFieldList.EXCLUDE)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def _fields_to_dbfields(self, fields):
|
||||||
|
ret = []
|
||||||
for field in fields:
|
for field in fields:
|
||||||
field = ".".join(f.db_field for f in QuerySet._lookup_field(self._document, field.split('.')))
|
field = ".".join(f.db_field for f in QuerySet._lookup_field(self._document, field.split('.')))
|
||||||
self._loaded_fields.append(field)
|
ret.append(field)
|
||||||
|
return ret
|
||||||
# _cls is needed for polymorphism
|
|
||||||
if self._document._meta.get('allow_inheritance'):
|
|
||||||
self._loaded_fields += ['_cls']
|
|
||||||
return self
|
|
||||||
|
|
||||||
def order_by(self, *keys):
|
def order_by(self, *keys):
|
||||||
"""Order the :class:`~mongoengine.queryset.QuerySet` by the keys. The
|
"""Order the :class:`~mongoengine.queryset.QuerySet` by the keys. The
|
||||||
|
@ -6,7 +6,7 @@ import pymongo
|
|||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
from mongoengine.queryset import (QuerySet, MultipleObjectsReturned,
|
from mongoengine.queryset import (QuerySet, MultipleObjectsReturned,
|
||||||
DoesNotExist)
|
DoesNotExist, QueryFieldList)
|
||||||
from mongoengine import *
|
from mongoengine import *
|
||||||
|
|
||||||
|
|
||||||
@ -497,6 +497,81 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
|
|
||||||
BlogPost.drop_collection()
|
BlogPost.drop_collection()
|
||||||
|
|
||||||
|
def test_exclude(self):
|
||||||
|
class User(EmbeddedDocument):
|
||||||
|
name = StringField()
|
||||||
|
email = StringField()
|
||||||
|
|
||||||
|
class Comment(EmbeddedDocument):
|
||||||
|
title = StringField()
|
||||||
|
text = StringField()
|
||||||
|
|
||||||
|
class BlogPost(Document):
|
||||||
|
content = StringField()
|
||||||
|
author = EmbeddedDocumentField(User)
|
||||||
|
comments = ListField(EmbeddedDocumentField(Comment))
|
||||||
|
|
||||||
|
BlogPost.drop_collection()
|
||||||
|
|
||||||
|
post = BlogPost(content='Had a good coffee today...')
|
||||||
|
post.author = User(name='Test User')
|
||||||
|
post.comments = [Comment(title='I aggree', text='Great post!'), Comment(title='Coffee', text='I hate coffee')]
|
||||||
|
post.save()
|
||||||
|
|
||||||
|
obj = BlogPost.objects.exclude('author', 'comments.text').get()
|
||||||
|
self.assertEqual(obj.author, None)
|
||||||
|
self.assertEqual(obj.content, 'Had a good coffee today...')
|
||||||
|
self.assertEqual(obj.comments[0].title, 'I aggree')
|
||||||
|
self.assertEqual(obj.comments[0].text, None)
|
||||||
|
|
||||||
|
BlogPost.drop_collection()
|
||||||
|
|
||||||
|
def test_exclude_only_combining(self):
|
||||||
|
class Attachment(EmbeddedDocument):
|
||||||
|
name = StringField()
|
||||||
|
content = StringField()
|
||||||
|
|
||||||
|
class Email(Document):
|
||||||
|
sender = StringField()
|
||||||
|
to = StringField()
|
||||||
|
subject = StringField()
|
||||||
|
body = StringField()
|
||||||
|
content_type = StringField()
|
||||||
|
attachments = ListField(EmbeddedDocumentField(Attachment))
|
||||||
|
|
||||||
|
Email.drop_collection()
|
||||||
|
email = Email(sender='me', to='you', subject='From Russia with Love', body='Hello!', content_type='text/plain')
|
||||||
|
email.attachments = [
|
||||||
|
Attachment(name='file1.doc', content='ABC'),
|
||||||
|
Attachment(name='file2.doc', content='XYZ'),
|
||||||
|
]
|
||||||
|
email.save()
|
||||||
|
|
||||||
|
obj = Email.objects.exclude('content_type').exclude('body').get()
|
||||||
|
self.assertEqual(obj.sender, 'me')
|
||||||
|
self.assertEqual(obj.to, 'you')
|
||||||
|
self.assertEqual(obj.subject, 'From Russia with Love')
|
||||||
|
self.assertEqual(obj.body, None)
|
||||||
|
self.assertEqual(obj.content_type, None)
|
||||||
|
|
||||||
|
obj = Email.objects.only('sender', 'to').exclude('body', 'sender').get()
|
||||||
|
self.assertEqual(obj.sender, None)
|
||||||
|
self.assertEqual(obj.to, 'you')
|
||||||
|
self.assertEqual(obj.subject, None)
|
||||||
|
self.assertEqual(obj.body, None)
|
||||||
|
self.assertEqual(obj.content_type, None)
|
||||||
|
|
||||||
|
obj = Email.objects.exclude('attachments.content').exclude('body').only('to', 'attachments.name').get()
|
||||||
|
self.assertEqual(obj.attachments[0].name, 'file1.doc')
|
||||||
|
self.assertEqual(obj.attachments[0].content, None)
|
||||||
|
self.assertEqual(obj.sender, None)
|
||||||
|
self.assertEqual(obj.to, 'you')
|
||||||
|
self.assertEqual(obj.subject, None)
|
||||||
|
self.assertEqual(obj.body, None)
|
||||||
|
self.assertEqual(obj.content_type, None)
|
||||||
|
|
||||||
|
Email.drop_collection()
|
||||||
|
|
||||||
def test_find_embedded(self):
|
def test_find_embedded(self):
|
||||||
"""Ensure that an embedded document is properly returned from a query.
|
"""Ensure that an embedded document is properly returned from a query.
|
||||||
"""
|
"""
|
||||||
@ -1594,6 +1669,62 @@ class QTest(unittest.TestCase):
|
|||||||
for condition in conditions:
|
for condition in conditions:
|
||||||
self.assertTrue(condition in query['$or'])
|
self.assertTrue(condition in query['$or'])
|
||||||
|
|
||||||
|
class QueryFieldListTest(unittest.TestCase):
|
||||||
|
def test_empty(self):
|
||||||
|
q = QueryFieldList()
|
||||||
|
self.assertFalse(q)
|
||||||
|
|
||||||
|
q = QueryFieldList(always_include=['_cls'])
|
||||||
|
self.assertFalse(q)
|
||||||
|
|
||||||
|
def test_include_include(self):
|
||||||
|
q = QueryFieldList()
|
||||||
|
q += QueryFieldList(fields=['a', 'b'], direction=QueryFieldList.ONLY)
|
||||||
|
self.assertEqual(q.as_dict(), {'a': True, 'b': True})
|
||||||
|
q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.ONLY)
|
||||||
|
self.assertEqual(q.as_dict(), {'b': True})
|
||||||
|
|
||||||
|
def test_include_exclude(self):
|
||||||
|
q = QueryFieldList()
|
||||||
|
q += QueryFieldList(fields=['a', 'b'], direction=QueryFieldList.ONLY)
|
||||||
|
self.assertEqual(q.as_dict(), {'a': True, 'b': True})
|
||||||
|
q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.EXCLUDE)
|
||||||
|
self.assertEqual(q.as_dict(), {'a': True})
|
||||||
|
|
||||||
|
def test_exclude_exclude(self):
|
||||||
|
q = QueryFieldList()
|
||||||
|
q += QueryFieldList(fields=['a', 'b'], direction=QueryFieldList.EXCLUDE)
|
||||||
|
self.assertEqual(q.as_dict(), {'a': False, 'b': False})
|
||||||
|
q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.EXCLUDE)
|
||||||
|
self.assertEqual(q.as_dict(), {'a': False, 'b': False, 'c': False})
|
||||||
|
|
||||||
|
def test_exclude_include(self):
|
||||||
|
q = QueryFieldList()
|
||||||
|
q += QueryFieldList(fields=['a', 'b'], direction=QueryFieldList.EXCLUDE)
|
||||||
|
self.assertEqual(q.as_dict(), {'a': False, 'b': False})
|
||||||
|
q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.ONLY)
|
||||||
|
self.assertEqual(q.as_dict(), {'c': True})
|
||||||
|
|
||||||
|
def test_always_include(self):
|
||||||
|
q = QueryFieldList(always_include=['x', 'y'])
|
||||||
|
q += QueryFieldList(fields=['a', 'b', 'x'], direction=QueryFieldList.EXCLUDE)
|
||||||
|
q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.ONLY)
|
||||||
|
self.assertEqual(q.as_dict(), {'x': True, 'y': True, 'c': True})
|
||||||
|
|
||||||
|
|
||||||
|
def test_reset(self):
|
||||||
|
q = QueryFieldList(always_include=['x', 'y'])
|
||||||
|
q += QueryFieldList(fields=['a', 'b', 'x'], direction=QueryFieldList.EXCLUDE)
|
||||||
|
q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.ONLY)
|
||||||
|
self.assertEqual(q.as_dict(), {'x': True, 'y': True, 'c': True})
|
||||||
|
q.reset()
|
||||||
|
self.assertFalse(q)
|
||||||
|
q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.ONLY)
|
||||||
|
self.assertEqual(q.as_dict(), {'x': True, 'y': True, 'b': True, 'c': True})
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user