From bda4776a18522704469d16c0c58b855f0cb434a9 Mon Sep 17 00:00:00 2001 From: Ales Zoulek Date: Wed, 3 Nov 2010 16:37:41 +0100 Subject: [PATCH] added Queryset.exclude() + tests --- mongoengine/queryset.py | 68 +++++++++++++++++--- tests/queryset.py | 133 +++++++++++++++++++++++++++++++++++++++- 2 files changed, 191 insertions(+), 10 deletions(-) diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index e46380b6..8ce5ec7a 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -268,6 +268,48 @@ class Q(QNode): return not bool(self.query) +class QueryFieldList(object): + """Object that handles combinations of .only() and .exclude() calls""" + ONLY = True + EXCLUDE = False + + def __init__(self, fields=[], direction=ONLY, always_include=[]): + self.direction = direction + self.fields = set(fields) + self.always_include = set(always_include) + + def as_dict(self): + return dict((field, self.direction) for field in self.fields) + + def __add__(self, f): + if not self.fields: + self.fields = f.fields + self.direction = f.direction + elif self.direction is self.ONLY and f.direction is self.ONLY: + self.fields = self.fields.intersection(f.fields) + elif self.direction is self.EXCLUDE and f.direction is self.EXCLUDE: + self.fields = self.fields.union(f.fields) + elif self.direction is self.ONLY and f.direction is self.EXCLUDE: + self.fields -= f.fields + elif self.direction is self.EXCLUDE and f.direction is self.ONLY: + self.direction = self.ONLY + self.fields = f.fields - self.fields + + if self.always_include: + if self.direction is self.ONLY and self.fields: + self.fields = self.fields.union(self.always_include) + else: + self.fields -= self.always_include + return self + + def reset(self): + self.fields = set([]) + self.direction = self.ONLY + + def __nonzero__(self): + return bool(self.fields) + + class QuerySet(object): """A set of results returned from a query. Wraps a MongoDB cursor, providing :class:`~mongoengine.Document` objects as the results. @@ -281,7 +323,7 @@ class QuerySet(object): self._query_obj = Q() self._initial_query = {} self._where_clause = None - self._loaded_fields = [] + self._loaded_fields = QueryFieldList() self._ordering = [] self._snapshot = False self._timeout = True @@ -290,6 +332,7 @@ class QuerySet(object): # subclasses of the class being used if document._meta.get('allow_inheritance'): self._initial_query = {'_types': self._document._class_name} + self._loaded_fields = QueryFieldList(always_include=['_cls']) self._cursor_obj = None self._limit = None self._skip = None @@ -423,7 +466,7 @@ class QuerySet(object): 'timeout': self._timeout, } if self._loaded_fields: - cursor_args['fields'] = self._loaded_fields + cursor_args['fields'] = self._loaded_fields.as_dict() self._cursor_obj = self._collection.find(self._query, **cursor_args) # Apply where clauses to cursor @@ -818,15 +861,22 @@ class QuerySet(object): .. versionadded:: 0.3 """ - self._loaded_fields = [] + fields = self._fields_to_dbfields(fields) + self._loaded_fields += QueryFieldList(fields, direction=QueryFieldList.ONLY) + return self + + + def exclude(self, *fields): + fields = self._fields_to_dbfields(fields) + self._loaded_fields += QueryFieldList(fields, direction=QueryFieldList.EXCLUDE) + return self + + def _fields_to_dbfields(self, fields): + ret = [] for field in fields: field = ".".join(f.db_field for f in QuerySet._lookup_field(self._document, field.split('.'))) - self._loaded_fields.append(field) - - # _cls is needed for polymorphism - if self._document._meta.get('allow_inheritance'): - self._loaded_fields += ['_cls'] - return self + ret.append(field) + return ret def order_by(self, *keys): """Order the :class:`~mongoengine.queryset.QuerySet` by the keys. The diff --git a/tests/queryset.py b/tests/queryset.py index 8b25524e..4e1302ee 100644 --- a/tests/queryset.py +++ b/tests/queryset.py @@ -6,7 +6,7 @@ import pymongo from datetime import datetime, timedelta from mongoengine.queryset import (QuerySet, MultipleObjectsReturned, - DoesNotExist) + DoesNotExist, QueryFieldList) from mongoengine import * @@ -497,6 +497,81 @@ class QuerySetTest(unittest.TestCase): BlogPost.drop_collection() + def test_exclude(self): + class User(EmbeddedDocument): + name = StringField() + email = StringField() + + class Comment(EmbeddedDocument): + title = StringField() + text = StringField() + + class BlogPost(Document): + content = StringField() + author = EmbeddedDocumentField(User) + comments = ListField(EmbeddedDocumentField(Comment)) + + BlogPost.drop_collection() + + post = BlogPost(content='Had a good coffee today...') + post.author = User(name='Test User') + post.comments = [Comment(title='I aggree', text='Great post!'), Comment(title='Coffee', text='I hate coffee')] + post.save() + + obj = BlogPost.objects.exclude('author', 'comments.text').get() + self.assertEqual(obj.author, None) + self.assertEqual(obj.content, 'Had a good coffee today...') + self.assertEqual(obj.comments[0].title, 'I aggree') + self.assertEqual(obj.comments[0].text, None) + + BlogPost.drop_collection() + + def test_exclude_only_combining(self): + class Attachment(EmbeddedDocument): + name = StringField() + content = StringField() + + class Email(Document): + sender = StringField() + to = StringField() + subject = StringField() + body = StringField() + content_type = StringField() + attachments = ListField(EmbeddedDocumentField(Attachment)) + + Email.drop_collection() + email = Email(sender='me', to='you', subject='From Russia with Love', body='Hello!', content_type='text/plain') + email.attachments = [ + Attachment(name='file1.doc', content='ABC'), + Attachment(name='file2.doc', content='XYZ'), + ] + email.save() + + obj = Email.objects.exclude('content_type').exclude('body').get() + self.assertEqual(obj.sender, 'me') + self.assertEqual(obj.to, 'you') + self.assertEqual(obj.subject, 'From Russia with Love') + self.assertEqual(obj.body, None) + self.assertEqual(obj.content_type, None) + + obj = Email.objects.only('sender', 'to').exclude('body', 'sender').get() + self.assertEqual(obj.sender, None) + self.assertEqual(obj.to, 'you') + self.assertEqual(obj.subject, None) + self.assertEqual(obj.body, None) + self.assertEqual(obj.content_type, None) + + obj = Email.objects.exclude('attachments.content').exclude('body').only('to', 'attachments.name').get() + self.assertEqual(obj.attachments[0].name, 'file1.doc') + self.assertEqual(obj.attachments[0].content, None) + self.assertEqual(obj.sender, None) + self.assertEqual(obj.to, 'you') + self.assertEqual(obj.subject, None) + self.assertEqual(obj.body, None) + self.assertEqual(obj.content_type, None) + + Email.drop_collection() + def test_find_embedded(self): """Ensure that an embedded document is properly returned from a query. """ @@ -1594,6 +1669,62 @@ class QTest(unittest.TestCase): for condition in conditions: self.assertTrue(condition in query['$or']) +class QueryFieldListTest(unittest.TestCase): + def test_empty(self): + q = QueryFieldList() + self.assertFalse(q) + + q = QueryFieldList(always_include=['_cls']) + self.assertFalse(q) + + def test_include_include(self): + q = QueryFieldList() + q += QueryFieldList(fields=['a', 'b'], direction=QueryFieldList.ONLY) + self.assertEqual(q.as_dict(), {'a': True, 'b': True}) + q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.ONLY) + self.assertEqual(q.as_dict(), {'b': True}) + + def test_include_exclude(self): + q = QueryFieldList() + q += QueryFieldList(fields=['a', 'b'], direction=QueryFieldList.ONLY) + self.assertEqual(q.as_dict(), {'a': True, 'b': True}) + q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.EXCLUDE) + self.assertEqual(q.as_dict(), {'a': True}) + + def test_exclude_exclude(self): + q = QueryFieldList() + q += QueryFieldList(fields=['a', 'b'], direction=QueryFieldList.EXCLUDE) + self.assertEqual(q.as_dict(), {'a': False, 'b': False}) + q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.EXCLUDE) + self.assertEqual(q.as_dict(), {'a': False, 'b': False, 'c': False}) + + def test_exclude_include(self): + q = QueryFieldList() + q += QueryFieldList(fields=['a', 'b'], direction=QueryFieldList.EXCLUDE) + self.assertEqual(q.as_dict(), {'a': False, 'b': False}) + q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.ONLY) + self.assertEqual(q.as_dict(), {'c': True}) + + def test_always_include(self): + q = QueryFieldList(always_include=['x', 'y']) + q += QueryFieldList(fields=['a', 'b', 'x'], direction=QueryFieldList.EXCLUDE) + q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.ONLY) + self.assertEqual(q.as_dict(), {'x': True, 'y': True, 'c': True}) + + + def test_reset(self): + q = QueryFieldList(always_include=['x', 'y']) + q += QueryFieldList(fields=['a', 'b', 'x'], direction=QueryFieldList.EXCLUDE) + q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.ONLY) + self.assertEqual(q.as_dict(), {'x': True, 'y': True, 'c': True}) + q.reset() + self.assertFalse(q) + q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.ONLY) + self.assertEqual(q.as_dict(), {'x': True, 'y': True, 'b': True, 'c': True}) + + + + if __name__ == '__main__': unittest.main()