added Queryset.exclude() + tests
This commit is contained in:
		| @@ -268,6 +268,48 @@ class Q(QNode): | |||||||
|         return not bool(self.query) |         return not bool(self.query) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class QueryFieldList(object): | ||||||
|  |     """Object that handles combinations of .only() and .exclude() calls""" | ||||||
|  |     ONLY = True | ||||||
|  |     EXCLUDE = False | ||||||
|  |  | ||||||
|  |     def __init__(self, fields=[], direction=ONLY, always_include=[]): | ||||||
|  |         self.direction = direction | ||||||
|  |         self.fields = set(fields) | ||||||
|  |         self.always_include = set(always_include) | ||||||
|  |  | ||||||
|  |     def as_dict(self): | ||||||
|  |         return dict((field, self.direction) for field in self.fields) | ||||||
|  |  | ||||||
|  |     def __add__(self, f): | ||||||
|  |         if not self.fields: | ||||||
|  |             self.fields = f.fields | ||||||
|  |             self.direction = f.direction | ||||||
|  |         elif self.direction is self.ONLY and f.direction is self.ONLY: | ||||||
|  |             self.fields = self.fields.intersection(f.fields) | ||||||
|  |         elif self.direction is self.EXCLUDE and f.direction is self.EXCLUDE: | ||||||
|  |             self.fields = self.fields.union(f.fields) | ||||||
|  |         elif self.direction is self.ONLY and f.direction is self.EXCLUDE: | ||||||
|  |             self.fields -= f.fields | ||||||
|  |         elif self.direction is self.EXCLUDE and f.direction is self.ONLY: | ||||||
|  |             self.direction = self.ONLY | ||||||
|  |             self.fields = f.fields - self.fields | ||||||
|  |  | ||||||
|  |         if self.always_include: | ||||||
|  |             if self.direction is self.ONLY and self.fields: | ||||||
|  |                 self.fields = self.fields.union(self.always_include) | ||||||
|  |             else: | ||||||
|  |                 self.fields -= self.always_include | ||||||
|  |         return self | ||||||
|  |  | ||||||
|  |     def reset(self): | ||||||
|  |         self.fields = set([]) | ||||||
|  |         self.direction = self.ONLY | ||||||
|  |  | ||||||
|  |     def __nonzero__(self): | ||||||
|  |         return bool(self.fields) | ||||||
|  |  | ||||||
|  |  | ||||||
| class QuerySet(object): | class QuerySet(object): | ||||||
|     """A set of results returned from a query. Wraps a MongoDB cursor, |     """A set of results returned from a query. Wraps a MongoDB cursor, | ||||||
|     providing :class:`~mongoengine.Document` objects as the results. |     providing :class:`~mongoengine.Document` objects as the results. | ||||||
| @@ -281,7 +323,7 @@ class QuerySet(object): | |||||||
|         self._query_obj = Q() |         self._query_obj = Q() | ||||||
|         self._initial_query = {} |         self._initial_query = {} | ||||||
|         self._where_clause = None |         self._where_clause = None | ||||||
|         self._loaded_fields = [] |         self._loaded_fields = QueryFieldList() | ||||||
|         self._ordering = [] |         self._ordering = [] | ||||||
|         self._snapshot = False |         self._snapshot = False | ||||||
|         self._timeout = True |         self._timeout = True | ||||||
| @@ -290,6 +332,7 @@ class QuerySet(object): | |||||||
|         # subclasses of the class being used |         # subclasses of the class being used | ||||||
|         if document._meta.get('allow_inheritance'): |         if document._meta.get('allow_inheritance'): | ||||||
|             self._initial_query = {'_types': self._document._class_name} |             self._initial_query = {'_types': self._document._class_name} | ||||||
|  |             self._loaded_fields = QueryFieldList(always_include=['_cls']) | ||||||
|         self._cursor_obj = None |         self._cursor_obj = None | ||||||
|         self._limit = None |         self._limit = None | ||||||
|         self._skip = None |         self._skip = None | ||||||
| @@ -423,7 +466,7 @@ class QuerySet(object): | |||||||
|                 'timeout': self._timeout, |                 'timeout': self._timeout, | ||||||
|             } |             } | ||||||
|             if self._loaded_fields: |             if self._loaded_fields: | ||||||
|                 cursor_args['fields'] = self._loaded_fields |                 cursor_args['fields'] = self._loaded_fields.as_dict() | ||||||
|             self._cursor_obj = self._collection.find(self._query,  |             self._cursor_obj = self._collection.find(self._query,  | ||||||
|                                                      **cursor_args) |                                                      **cursor_args) | ||||||
|             # Apply where clauses to cursor |             # Apply where clauses to cursor | ||||||
| @@ -818,15 +861,22 @@ class QuerySet(object): | |||||||
|  |  | ||||||
|         .. versionadded:: 0.3 |         .. versionadded:: 0.3 | ||||||
|         """ |         """ | ||||||
|         self._loaded_fields = [] |         fields = self._fields_to_dbfields(fields) | ||||||
|  |         self._loaded_fields += QueryFieldList(fields, direction=QueryFieldList.ONLY) | ||||||
|  |         return self | ||||||
|  |  | ||||||
|  |  | ||||||
|  |     def exclude(self, *fields): | ||||||
|  |         fields = self._fields_to_dbfields(fields) | ||||||
|  |         self._loaded_fields += QueryFieldList(fields, direction=QueryFieldList.EXCLUDE) | ||||||
|  |         return self | ||||||
|  |  | ||||||
|  |     def _fields_to_dbfields(self, fields): | ||||||
|  |         ret = [] | ||||||
|         for field in fields: |         for field in fields: | ||||||
|             field = ".".join(f.db_field for f in QuerySet._lookup_field(self._document, field.split('.'))) |             field = ".".join(f.db_field for f in QuerySet._lookup_field(self._document, field.split('.'))) | ||||||
|             self._loaded_fields.append(field) |             ret.append(field) | ||||||
|  |         return ret | ||||||
|         # _cls is needed for polymorphism |  | ||||||
|         if self._document._meta.get('allow_inheritance'): |  | ||||||
|             self._loaded_fields += ['_cls'] |  | ||||||
|         return self |  | ||||||
|  |  | ||||||
|     def order_by(self, *keys): |     def order_by(self, *keys): | ||||||
|         """Order the :class:`~mongoengine.queryset.QuerySet` by the keys. The |         """Order the :class:`~mongoengine.queryset.QuerySet` by the keys. The | ||||||
|   | |||||||
| @@ -6,7 +6,7 @@ import pymongo | |||||||
| from datetime import datetime, timedelta | from datetime import datetime, timedelta | ||||||
|  |  | ||||||
| from mongoengine.queryset import (QuerySet, MultipleObjectsReturned, | from mongoengine.queryset import (QuerySet, MultipleObjectsReturned, | ||||||
|                                   DoesNotExist) |                                   DoesNotExist, QueryFieldList) | ||||||
| from mongoengine import * | from mongoengine import * | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -497,6 +497,81 @@ class QuerySetTest(unittest.TestCase): | |||||||
|  |  | ||||||
|         BlogPost.drop_collection() |         BlogPost.drop_collection() | ||||||
|  |  | ||||||
|  |     def test_exclude(self): | ||||||
|  |         class User(EmbeddedDocument): | ||||||
|  |             name = StringField() | ||||||
|  |             email = StringField() | ||||||
|  |  | ||||||
|  |         class Comment(EmbeddedDocument): | ||||||
|  |             title = StringField() | ||||||
|  |             text = StringField() | ||||||
|  |  | ||||||
|  |         class BlogPost(Document): | ||||||
|  |             content = StringField() | ||||||
|  |             author = EmbeddedDocumentField(User) | ||||||
|  |             comments = ListField(EmbeddedDocumentField(Comment)) | ||||||
|  |  | ||||||
|  |         BlogPost.drop_collection() | ||||||
|  |  | ||||||
|  |         post = BlogPost(content='Had a good coffee today...') | ||||||
|  |         post.author = User(name='Test User') | ||||||
|  |         post.comments = [Comment(title='I aggree', text='Great post!'), Comment(title='Coffee', text='I hate coffee')] | ||||||
|  |         post.save() | ||||||
|  |  | ||||||
|  |         obj = BlogPost.objects.exclude('author', 'comments.text').get() | ||||||
|  |         self.assertEqual(obj.author, None) | ||||||
|  |         self.assertEqual(obj.content, 'Had a good coffee today...') | ||||||
|  |         self.assertEqual(obj.comments[0].title, 'I aggree') | ||||||
|  |         self.assertEqual(obj.comments[0].text, None) | ||||||
|  |  | ||||||
|  |         BlogPost.drop_collection() | ||||||
|  |  | ||||||
|  |     def test_exclude_only_combining(self): | ||||||
|  |         class Attachment(EmbeddedDocument): | ||||||
|  |             name = StringField() | ||||||
|  |             content = StringField() | ||||||
|  |  | ||||||
|  |         class Email(Document): | ||||||
|  |             sender = StringField() | ||||||
|  |             to = StringField() | ||||||
|  |             subject = StringField() | ||||||
|  |             body = StringField() | ||||||
|  |             content_type = StringField() | ||||||
|  |             attachments = ListField(EmbeddedDocumentField(Attachment)) | ||||||
|  |  | ||||||
|  |         Email.drop_collection() | ||||||
|  |         email = Email(sender='me', to='you', subject='From Russia with Love', body='Hello!', content_type='text/plain') | ||||||
|  |         email.attachments = [ | ||||||
|  |             Attachment(name='file1.doc', content='ABC'), | ||||||
|  |             Attachment(name='file2.doc', content='XYZ'), | ||||||
|  |         ] | ||||||
|  |         email.save() | ||||||
|  |  | ||||||
|  |         obj = Email.objects.exclude('content_type').exclude('body').get() | ||||||
|  |         self.assertEqual(obj.sender, 'me') | ||||||
|  |         self.assertEqual(obj.to, 'you') | ||||||
|  |         self.assertEqual(obj.subject, 'From Russia with Love') | ||||||
|  |         self.assertEqual(obj.body, None) | ||||||
|  |         self.assertEqual(obj.content_type, None) | ||||||
|  |  | ||||||
|  |         obj = Email.objects.only('sender', 'to').exclude('body', 'sender').get() | ||||||
|  |         self.assertEqual(obj.sender, None) | ||||||
|  |         self.assertEqual(obj.to, 'you') | ||||||
|  |         self.assertEqual(obj.subject, None) | ||||||
|  |         self.assertEqual(obj.body, None) | ||||||
|  |         self.assertEqual(obj.content_type, None) | ||||||
|  |  | ||||||
|  |         obj = Email.objects.exclude('attachments.content').exclude('body').only('to', 'attachments.name').get() | ||||||
|  |         self.assertEqual(obj.attachments[0].name, 'file1.doc') | ||||||
|  |         self.assertEqual(obj.attachments[0].content, None) | ||||||
|  |         self.assertEqual(obj.sender, None) | ||||||
|  |         self.assertEqual(obj.to, 'you') | ||||||
|  |         self.assertEqual(obj.subject, None) | ||||||
|  |         self.assertEqual(obj.body, None) | ||||||
|  |         self.assertEqual(obj.content_type, None) | ||||||
|  |  | ||||||
|  |         Email.drop_collection() | ||||||
|  |  | ||||||
|     def test_find_embedded(self): |     def test_find_embedded(self): | ||||||
|         """Ensure that an embedded document is properly returned from a query. |         """Ensure that an embedded document is properly returned from a query. | ||||||
|         """ |         """ | ||||||
| @@ -1594,6 +1669,62 @@ class QTest(unittest.TestCase): | |||||||
|         for condition in conditions: |         for condition in conditions: | ||||||
|             self.assertTrue(condition in query['$or']) |             self.assertTrue(condition in query['$or']) | ||||||
|  |  | ||||||
|  | class QueryFieldListTest(unittest.TestCase): | ||||||
|  |     def test_empty(self): | ||||||
|  |         q = QueryFieldList() | ||||||
|  |         self.assertFalse(q) | ||||||
|  |  | ||||||
|  |         q = QueryFieldList(always_include=['_cls']) | ||||||
|  |         self.assertFalse(q) | ||||||
|  |  | ||||||
|  |     def test_include_include(self): | ||||||
|  |         q = QueryFieldList() | ||||||
|  |         q += QueryFieldList(fields=['a', 'b'], direction=QueryFieldList.ONLY) | ||||||
|  |         self.assertEqual(q.as_dict(), {'a': True, 'b': True}) | ||||||
|  |         q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.ONLY) | ||||||
|  |         self.assertEqual(q.as_dict(), {'b': True}) | ||||||
|  |  | ||||||
|  |     def test_include_exclude(self): | ||||||
|  |         q = QueryFieldList() | ||||||
|  |         q += QueryFieldList(fields=['a', 'b'], direction=QueryFieldList.ONLY) | ||||||
|  |         self.assertEqual(q.as_dict(), {'a': True, 'b': True}) | ||||||
|  |         q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.EXCLUDE) | ||||||
|  |         self.assertEqual(q.as_dict(), {'a': True}) | ||||||
|  |  | ||||||
|  |     def test_exclude_exclude(self): | ||||||
|  |         q = QueryFieldList() | ||||||
|  |         q += QueryFieldList(fields=['a', 'b'], direction=QueryFieldList.EXCLUDE) | ||||||
|  |         self.assertEqual(q.as_dict(), {'a': False, 'b': False}) | ||||||
|  |         q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.EXCLUDE) | ||||||
|  |         self.assertEqual(q.as_dict(), {'a': False, 'b': False, 'c': False}) | ||||||
|  |  | ||||||
|  |     def test_exclude_include(self): | ||||||
|  |         q = QueryFieldList() | ||||||
|  |         q += QueryFieldList(fields=['a', 'b'], direction=QueryFieldList.EXCLUDE) | ||||||
|  |         self.assertEqual(q.as_dict(), {'a': False, 'b': False}) | ||||||
|  |         q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.ONLY) | ||||||
|  |         self.assertEqual(q.as_dict(), {'c': True}) | ||||||
|  |  | ||||||
|  |     def test_always_include(self): | ||||||
|  |         q = QueryFieldList(always_include=['x', 'y']) | ||||||
|  |         q += QueryFieldList(fields=['a', 'b', 'x'], direction=QueryFieldList.EXCLUDE) | ||||||
|  |         q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.ONLY) | ||||||
|  |         self.assertEqual(q.as_dict(), {'x': True, 'y': True, 'c': True}) | ||||||
|  |  | ||||||
|  |  | ||||||
|  |     def test_reset(self): | ||||||
|  |         q = QueryFieldList(always_include=['x', 'y']) | ||||||
|  |         q += QueryFieldList(fields=['a', 'b', 'x'], direction=QueryFieldList.EXCLUDE) | ||||||
|  |         q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.ONLY) | ||||||
|  |         self.assertEqual(q.as_dict(), {'x': True, 'y': True, 'c': True}) | ||||||
|  |         q.reset() | ||||||
|  |         self.assertFalse(q) | ||||||
|  |         q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.ONLY) | ||||||
|  |         self.assertEqual(q.as_dict(), {'x': True, 'y': True, 'b': True, 'c': True}) | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|     unittest.main() |     unittest.main() | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user