diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index f4849619..49c8f69d 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -431,7 +431,9 @@ class QuerySet(object): self._cursor_obj.where(self._where_clause) # apply default ordering - if self._document._meta['ordering']: + if self._ordering: + self._cursor_obj.sort(self._ordering) + elif self._document._meta['ordering']: self.order_by(*self._document._meta['ordering']) if self._limit is not None: @@ -818,11 +820,7 @@ class QuerySet(object): """ self._loaded_fields = [] for field in fields: - if '.' in field: - raise InvalidQueryError('Subfields cannot be used as ' - 'arguments to QuerySet.only') - # Translate field name - field = QuerySet._lookup_field(self._document, field)[-1].db_field + field = ".".join(f.db_field for f in QuerySet._lookup_field(self._document, field.split('.'))) self._loaded_fields.append(field) # _cls is needed for polymorphism @@ -919,8 +917,7 @@ class QuerySet(object): # Convert value to proper value field = fields[-1] - if op in (None, 'set', 'unset', 'pop', 'push', 'pull', - 'addToSet'): + if op in (None, 'set', 'push', 'pull', 'addToSet'): value = field.prepare_query_value(op, value) elif op in ('pushAll', 'pullAll'): value = [field.prepare_query_value(op, v) for v in value] diff --git a/tests/queryset.py b/tests/queryset.py index 37fe7501..374fdb54 100644 --- a/tests/queryset.py +++ b/tests/queryset.py @@ -452,6 +452,51 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(obj.salary, employee.salary) self.assertEqual(obj.name, None) + def test_only_with_subfields(self): + class User(EmbeddedDocument): + name = StringField() + email = StringField() + + class Comment(EmbeddedDocument): + title = StringField() + text = StringField() + + class BlogPost(Document): + content = StringField() + author = EmbeddedDocumentField(User) + comments = ListField(EmbeddedDocumentField(Comment)) + + BlogPost.drop_collection() + + post = BlogPost(content='Had a good coffee today...') + post.author = User(name='Test User') + post.comments = [Comment(title='I aggree', text='Great post!'), Comment(title='Coffee', text='I hate coffee')] + post.save() + + obj = BlogPost.objects.only('author.name',).get() + self.assertEqual(obj.content, None) + self.assertEqual(obj.author.email, None) + self.assertEqual(obj.author.name, 'Test User') + self.assertEqual(obj.comments, []) + + obj = BlogPost.objects.only('content', 'comments.title',).get() + self.assertEqual(obj.content, 'Had a good coffee today...') + self.assertEqual(obj.author, None) + self.assertEqual(obj.comments[0].title, 'I aggree') + self.assertEqual(obj.comments[1].title, 'Coffee') + self.assertEqual(obj.comments[0].text, None) + self.assertEqual(obj.comments[1].text, None) + + obj = BlogPost.objects.only('comments',).get() + self.assertEqual(obj.content, None) + self.assertEqual(obj.author, None) + self.assertEqual(obj.comments[0].title, 'I aggree') + self.assertEqual(obj.comments[1].title, 'Coffee') + self.assertEqual(obj.comments[0].text, 'Great post!') + self.assertEqual(obj.comments[1].text, 'I hate coffee') + + BlogPost.drop_collection() + def test_find_embedded(self): """Ensure that an embedded document is properly returned from a query. """ @@ -733,6 +778,11 @@ class QuerySetTest(unittest.TestCase): BlogPost.objects.update_one(add_to_set__tags='unique') post.reload() self.assertEqual(post.tags.count('unique'), 1) + + self.assertNotEqual(post.hits, None) + BlogPost.objects.update_one(unset__hits=1) + post.reload() + self.assertEqual(post.hits, None) BlogPost.drop_collection() @@ -1395,6 +1445,45 @@ class QuerySetTest(unittest.TestCase): Post.drop_collection() + def test_order_then_filter(self): + """Ensure that ordering still works after filtering. + """ + class Number(Document): + n = IntField() + + Number.drop_collection() + + n2 = Number.objects.create(n=2) + n1 = Number.objects.create(n=1) + + self.assertEqual(list(Number.objects), [n2, n1]) + self.assertEqual(list(Number.objects.order_by('n')), [n1, n2]) + self.assertEqual(list(Number.objects.order_by('n').filter()), [n1, n2]) + + Number.drop_collection() + + def test_unset_reference(self): + class Comment(Document): + text = StringField() + + class Post(Document): + comment = ReferenceField(Comment) + + Comment.drop_collection() + Post.drop_collection() + + comment = Comment.objects.create(text='test') + post = Post.objects.create(comment=comment) + + self.assertEqual(post.comment, comment) + Post.objects.update(unset__comment=1) + post.reload() + self.assertEqual(post.comment, None) + + Comment.drop_collection() + Post.drop_collection() + + class QTest(unittest.TestCase): def test_empty_q(self):