Merge remote branch 'aleszoulek/dev' into dev
This commit is contained in:
commit
10c0b035ae
@ -268,6 +268,48 @@ class Q(QNode):
|
||||
return not bool(self.query)
|
||||
|
||||
|
||||
class QueryFieldList(object):
|
||||
"""Object that handles combinations of .only() and .exclude() calls"""
|
||||
ONLY = True
|
||||
EXCLUDE = False
|
||||
|
||||
def __init__(self, fields=[], direction=ONLY, always_include=[]):
|
||||
self.direction = direction
|
||||
self.fields = set(fields)
|
||||
self.always_include = set(always_include)
|
||||
|
||||
def as_dict(self):
|
||||
return dict((field, self.direction) for field in self.fields)
|
||||
|
||||
def __add__(self, f):
|
||||
if not self.fields:
|
||||
self.fields = f.fields
|
||||
self.direction = f.direction
|
||||
elif self.direction is self.ONLY and f.direction is self.ONLY:
|
||||
self.fields = self.fields.intersection(f.fields)
|
||||
elif self.direction is self.EXCLUDE and f.direction is self.EXCLUDE:
|
||||
self.fields = self.fields.union(f.fields)
|
||||
elif self.direction is self.ONLY and f.direction is self.EXCLUDE:
|
||||
self.fields -= f.fields
|
||||
elif self.direction is self.EXCLUDE and f.direction is self.ONLY:
|
||||
self.direction = self.ONLY
|
||||
self.fields = f.fields - self.fields
|
||||
|
||||
if self.always_include:
|
||||
if self.direction is self.ONLY and self.fields:
|
||||
self.fields = self.fields.union(self.always_include)
|
||||
else:
|
||||
self.fields -= self.always_include
|
||||
return self
|
||||
|
||||
def reset(self):
|
||||
self.fields = set([])
|
||||
self.direction = self.ONLY
|
||||
|
||||
def __nonzero__(self):
|
||||
return bool(self.fields)
|
||||
|
||||
|
||||
class QuerySet(object):
|
||||
"""A set of results returned from a query. Wraps a MongoDB cursor,
|
||||
providing :class:`~mongoengine.Document` objects as the results.
|
||||
@ -281,7 +323,7 @@ class QuerySet(object):
|
||||
self._query_obj = Q()
|
||||
self._initial_query = {}
|
||||
self._where_clause = None
|
||||
self._loaded_fields = []
|
||||
self._loaded_fields = QueryFieldList()
|
||||
self._ordering = []
|
||||
self._snapshot = False
|
||||
self._timeout = True
|
||||
@ -290,6 +332,7 @@ class QuerySet(object):
|
||||
# subclasses of the class being used
|
||||
if document._meta.get('allow_inheritance'):
|
||||
self._initial_query = {'_types': self._document._class_name}
|
||||
self._loaded_fields = QueryFieldList(always_include=['_cls'])
|
||||
self._cursor_obj = None
|
||||
self._limit = None
|
||||
self._skip = None
|
||||
@ -423,7 +466,7 @@ class QuerySet(object):
|
||||
'timeout': self._timeout,
|
||||
}
|
||||
if self._loaded_fields:
|
||||
cursor_args['fields'] = self._loaded_fields
|
||||
cursor_args['fields'] = self._loaded_fields.as_dict()
|
||||
self._cursor_obj = self._collection.find(self._query,
|
||||
**cursor_args)
|
||||
# Apply where clauses to cursor
|
||||
@ -818,15 +861,37 @@ class QuerySet(object):
|
||||
|
||||
.. versionadded:: 0.3
|
||||
"""
|
||||
self._loaded_fields = []
|
||||
fields = self._fields_to_dbfields(fields)
|
||||
self._loaded_fields += QueryFieldList(fields, direction=QueryFieldList.ONLY)
|
||||
return self
|
||||
|
||||
|
||||
def exclude(self, *fields):
|
||||
"""Opposite to .only(), exclude some document's fields. ::
|
||||
|
||||
post = BlogPost.objects(...).exclude("comments")
|
||||
|
||||
:param fields: fields to exclude
|
||||
"""
|
||||
fields = self._fields_to_dbfields(fields)
|
||||
self._loaded_fields += QueryFieldList(fields, direction=QueryFieldList.EXCLUDE)
|
||||
return self
|
||||
|
||||
def all_fields(self):
|
||||
"""Include all fields. Reset all previously calls of .only() and .exclude(). ::
|
||||
|
||||
post = BlogPost.objects(...).exclude("comments").only("title").all_fields()
|
||||
"""
|
||||
self._loaded_fields = QueryFieldList(always_include=self._loaded_fields.always_include)
|
||||
return self
|
||||
|
||||
def _fields_to_dbfields(self, fields):
|
||||
"""Translate fields paths to its db equivalents"""
|
||||
ret = []
|
||||
for field in fields:
|
||||
field = ".".join(f.db_field for f in QuerySet._lookup_field(self._document, field.split('.')))
|
||||
self._loaded_fields.append(field)
|
||||
|
||||
# _cls is needed for polymorphism
|
||||
if self._document._meta.get('allow_inheritance'):
|
||||
self._loaded_fields += ['_cls']
|
||||
return self
|
||||
ret.append(field)
|
||||
return ret
|
||||
|
||||
def order_by(self, *keys):
|
||||
"""Order the :class:`~mongoengine.queryset.QuerySet` by the keys. The
|
||||
|
@ -6,7 +6,7 @@ import pymongo
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from mongoengine.queryset import (QuerySet, MultipleObjectsReturned,
|
||||
DoesNotExist)
|
||||
DoesNotExist, QueryFieldList)
|
||||
from mongoengine import *
|
||||
|
||||
|
||||
@ -497,6 +497,104 @@ class QuerySetTest(unittest.TestCase):
|
||||
|
||||
BlogPost.drop_collection()
|
||||
|
||||
def test_exclude(self):
|
||||
class User(EmbeddedDocument):
|
||||
name = StringField()
|
||||
email = StringField()
|
||||
|
||||
class Comment(EmbeddedDocument):
|
||||
title = StringField()
|
||||
text = StringField()
|
||||
|
||||
class BlogPost(Document):
|
||||
content = StringField()
|
||||
author = EmbeddedDocumentField(User)
|
||||
comments = ListField(EmbeddedDocumentField(Comment))
|
||||
|
||||
BlogPost.drop_collection()
|
||||
|
||||
post = BlogPost(content='Had a good coffee today...')
|
||||
post.author = User(name='Test User')
|
||||
post.comments = [Comment(title='I aggree', text='Great post!'), Comment(title='Coffee', text='I hate coffee')]
|
||||
post.save()
|
||||
|
||||
obj = BlogPost.objects.exclude('author', 'comments.text').get()
|
||||
self.assertEqual(obj.author, None)
|
||||
self.assertEqual(obj.content, 'Had a good coffee today...')
|
||||
self.assertEqual(obj.comments[0].title, 'I aggree')
|
||||
self.assertEqual(obj.comments[0].text, None)
|
||||
|
||||
BlogPost.drop_collection()
|
||||
|
||||
def test_exclude_only_combining(self):
|
||||
class Attachment(EmbeddedDocument):
|
||||
name = StringField()
|
||||
content = StringField()
|
||||
|
||||
class Email(Document):
|
||||
sender = StringField()
|
||||
to = StringField()
|
||||
subject = StringField()
|
||||
body = StringField()
|
||||
content_type = StringField()
|
||||
attachments = ListField(EmbeddedDocumentField(Attachment))
|
||||
|
||||
Email.drop_collection()
|
||||
email = Email(sender='me', to='you', subject='From Russia with Love', body='Hello!', content_type='text/plain')
|
||||
email.attachments = [
|
||||
Attachment(name='file1.doc', content='ABC'),
|
||||
Attachment(name='file2.doc', content='XYZ'),
|
||||
]
|
||||
email.save()
|
||||
|
||||
obj = Email.objects.exclude('content_type').exclude('body').get()
|
||||
self.assertEqual(obj.sender, 'me')
|
||||
self.assertEqual(obj.to, 'you')
|
||||
self.assertEqual(obj.subject, 'From Russia with Love')
|
||||
self.assertEqual(obj.body, None)
|
||||
self.assertEqual(obj.content_type, None)
|
||||
|
||||
obj = Email.objects.only('sender', 'to').exclude('body', 'sender').get()
|
||||
self.assertEqual(obj.sender, None)
|
||||
self.assertEqual(obj.to, 'you')
|
||||
self.assertEqual(obj.subject, None)
|
||||
self.assertEqual(obj.body, None)
|
||||
self.assertEqual(obj.content_type, None)
|
||||
|
||||
obj = Email.objects.exclude('attachments.content').exclude('body').only('to', 'attachments.name').get()
|
||||
self.assertEqual(obj.attachments[0].name, 'file1.doc')
|
||||
self.assertEqual(obj.attachments[0].content, None)
|
||||
self.assertEqual(obj.sender, None)
|
||||
self.assertEqual(obj.to, 'you')
|
||||
self.assertEqual(obj.subject, None)
|
||||
self.assertEqual(obj.body, None)
|
||||
self.assertEqual(obj.content_type, None)
|
||||
|
||||
Email.drop_collection()
|
||||
|
||||
def test_all_fields(self):
|
||||
|
||||
class Email(Document):
|
||||
sender = StringField()
|
||||
to = StringField()
|
||||
subject = StringField()
|
||||
body = StringField()
|
||||
content_type = StringField()
|
||||
|
||||
Email.drop_collection()
|
||||
|
||||
email = Email(sender='me', to='you', subject='From Russia with Love', body='Hello!', content_type='text/plain')
|
||||
email.save()
|
||||
|
||||
obj = Email.objects.exclude('content_type', 'body').only('to', 'body').all_fields().get()
|
||||
self.assertEqual(obj.sender, 'me')
|
||||
self.assertEqual(obj.to, 'you')
|
||||
self.assertEqual(obj.subject, 'From Russia with Love')
|
||||
self.assertEqual(obj.body, 'Hello!')
|
||||
self.assertEqual(obj.content_type, 'text/plain')
|
||||
|
||||
Email.drop_collection()
|
||||
|
||||
def test_find_embedded(self):
|
||||
"""Ensure that an embedded document is properly returned from a query.
|
||||
"""
|
||||
@ -1634,6 +1732,62 @@ class QTest(unittest.TestCase):
|
||||
for condition in conditions:
|
||||
self.assertTrue(condition in query['$or'])
|
||||
|
||||
class QueryFieldListTest(unittest.TestCase):
|
||||
def test_empty(self):
|
||||
q = QueryFieldList()
|
||||
self.assertFalse(q)
|
||||
|
||||
q = QueryFieldList(always_include=['_cls'])
|
||||
self.assertFalse(q)
|
||||
|
||||
def test_include_include(self):
|
||||
q = QueryFieldList()
|
||||
q += QueryFieldList(fields=['a', 'b'], direction=QueryFieldList.ONLY)
|
||||
self.assertEqual(q.as_dict(), {'a': True, 'b': True})
|
||||
q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.ONLY)
|
||||
self.assertEqual(q.as_dict(), {'b': True})
|
||||
|
||||
def test_include_exclude(self):
|
||||
q = QueryFieldList()
|
||||
q += QueryFieldList(fields=['a', 'b'], direction=QueryFieldList.ONLY)
|
||||
self.assertEqual(q.as_dict(), {'a': True, 'b': True})
|
||||
q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.EXCLUDE)
|
||||
self.assertEqual(q.as_dict(), {'a': True})
|
||||
|
||||
def test_exclude_exclude(self):
|
||||
q = QueryFieldList()
|
||||
q += QueryFieldList(fields=['a', 'b'], direction=QueryFieldList.EXCLUDE)
|
||||
self.assertEqual(q.as_dict(), {'a': False, 'b': False})
|
||||
q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.EXCLUDE)
|
||||
self.assertEqual(q.as_dict(), {'a': False, 'b': False, 'c': False})
|
||||
|
||||
def test_exclude_include(self):
|
||||
q = QueryFieldList()
|
||||
q += QueryFieldList(fields=['a', 'b'], direction=QueryFieldList.EXCLUDE)
|
||||
self.assertEqual(q.as_dict(), {'a': False, 'b': False})
|
||||
q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.ONLY)
|
||||
self.assertEqual(q.as_dict(), {'c': True})
|
||||
|
||||
def test_always_include(self):
|
||||
q = QueryFieldList(always_include=['x', 'y'])
|
||||
q += QueryFieldList(fields=['a', 'b', 'x'], direction=QueryFieldList.EXCLUDE)
|
||||
q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.ONLY)
|
||||
self.assertEqual(q.as_dict(), {'x': True, 'y': True, 'c': True})
|
||||
|
||||
|
||||
def test_reset(self):
|
||||
q = QueryFieldList(always_include=['x', 'y'])
|
||||
q += QueryFieldList(fields=['a', 'b', 'x'], direction=QueryFieldList.EXCLUDE)
|
||||
q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.ONLY)
|
||||
self.assertEqual(q.as_dict(), {'x': True, 'y': True, 'c': True})
|
||||
q.reset()
|
||||
self.assertFalse(q)
|
||||
q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.ONLY)
|
||||
self.assertEqual(q.as_dict(), {'x': True, 'y': True, 'b': True, 'c': True})
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user