diff --git a/mongoengine/base.py b/mongoengine/base.py index 0c024be5..5188c310 100644 --- a/mongoengine/base.py +++ b/mongoengine/base.py @@ -298,8 +298,10 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): new_class = super_new(cls, name, bases, attrs) # Provide a default queryset unless one has been manually provided - if not hasattr(new_class, 'objects'): - new_class.objects = QuerySetManager() + manager = attrs.get('objects', QuerySetManager()) + if hasattr(manager, 'queryset_class'): + meta['queryset_class'] = manager.queryset_class + new_class.objects = manager user_indexes = [QuerySet._build_index_spec(new_class, spec) for spec in meta['indexes']] + base_indexes diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index 1d004d68..d7d349ad 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -428,8 +428,6 @@ class QuerySet(object): querying collection :param query: Django-style query keyword arguments """ - #if q_obj: - #self._where_clause = q_obj.as_js(self._document) query = Q(**query) if q_obj: query &= q_obj @@ -524,9 +522,14 @@ class QuerySet(object): fields = [] field = None for field_name in parts: + # Handle ListField indexing: if field_name.isdigit(): + try: + field = field.field + except AttributeError, err: + raise InvalidQueryError( + "Can't use index on unsubscriptable field (%s)" % err) fields.append(field_name) - field = field.field continue if field is None: # Look up first field from the document @@ -1075,11 +1078,13 @@ class QuerySet(object): # Switch field names to proper names [set in Field(name='foo')] fields = QuerySet._lookup_field(_doc_cls, parts) parts = [] + for field in fields: if isinstance(field, str): parts.append(field) else: parts.append(field.db_field) + # Convert value to proper value field = fields[-1] @@ -1315,8 +1320,11 @@ class QuerySet(object): class QuerySetManager(object): - def __init__(self, manager_func=None): - self._manager_func = manager_func + get_queryset = None + + def __init__(self, queryset_func=None): + if queryset_func: + self.get_queryset = queryset_func self._collections = {} def __get__(self, instance, owner): @@ -1360,11 +1368,11 @@ class QuerySetManager(object): # owner is the document that contains the QuerySetManager queryset_class = owner._meta['queryset_class'] or QuerySet queryset = queryset_class(owner, self._collections[(db, collection)]) - if self._manager_func: - if self._manager_func.func_code.co_argcount == 1: - queryset = self._manager_func(queryset) + if self.get_queryset: + if self.get_queryset.func_code.co_argcount == 1: + queryset = self.get_queryset(queryset) else: - queryset = self._manager_func(owner, queryset) + queryset = self.get_queryset(owner, queryset) return queryset diff --git a/tests/queryset.py b/tests/queryset.py index 5a2c46cb..6a8e16d5 100644 --- a/tests/queryset.py +++ b/tests/queryset.py @@ -5,8 +5,9 @@ import unittest import pymongo from datetime import datetime, timedelta -from mongoengine.queryset import (QuerySet, MultipleObjectsReturned, - DoesNotExist, QueryFieldList) +from mongoengine.queryset import (QuerySet, QuerySetManager, + MultipleObjectsReturned, DoesNotExist, + QueryFieldList) from mongoengine import * @@ -211,6 +212,55 @@ class QuerySetTest(unittest.TestCase): Blog.drop_collection() + def test_update_array_position(self): + """Ensure that updating by array position works. + + Check update() and update_one() can take syntax like: + set__posts__1__comments__1__name="testc" + Check that it only works for ListFields. + """ + class Comment(EmbeddedDocument): + name = StringField() + + class Post(EmbeddedDocument): + comments = ListField(EmbeddedDocumentField(Comment)) + + class Blog(Document): + tags = ListField(StringField()) + posts = ListField(EmbeddedDocumentField(Post)) + + Blog.drop_collection() + + comment1 = Comment(name='testa') + comment2 = Comment(name='testb') + post1 = Post(comments=[comment1, comment2]) + post2 = Post(comments=[comment2, comment2]) + blog1 = Blog.objects.create(posts=[post1, post2]) + blog2 = Blog.objects.create(posts=[post2, post1]) + + # Update all of the first comments of second posts of all blogs + blog = Blog.objects().update(set__posts__1__comments__0__name="testc") + testc_blogs = Blog.objects(posts__1__comments__0__name="testc") + self.assertEqual(len(testc_blogs), 2) + + Blog.drop_collection() + + blog1 = Blog.objects.create(posts=[post1, post2]) + blog2 = Blog.objects.create(posts=[post2, post1]) + + # Update only the first blog returned by the query + blog = Blog.objects().update_one( + set__posts__1__comments__1__name="testc") + testc_blogs = Blog.objects(posts__1__comments__1__name="testc") + self.assertEqual(len(testc_blogs), 1) + + # Check that using this indexing syntax on a non-list fails + def non_list_indexing(): + Blog.objects().update(set__posts__1__comments__0__name__1="asdf") + self.assertRaises(InvalidQueryError, non_list_indexing) + + Blog.drop_collection() + def test_get_or_create(self): """Ensure that ``get_or_create`` returns one result or creates a new document. @@ -1737,6 +1787,53 @@ class QuerySetTest(unittest.TestCase): Post.drop_collection() + def test_custom_querysets_set_manager_directly(self): + """Ensure that custom QuerySet classes may be used. + """ + + class CustomQuerySet(QuerySet): + def not_empty(self): + return len(self) > 0 + + class CustomQuerySetManager(QuerySetManager): + queryset_class = CustomQuerySet + + class Post(Document): + objects = CustomQuerySetManager() + + Post.drop_collection() + + self.assertTrue(isinstance(Post.objects, CustomQuerySet)) + self.assertFalse(Post.objects.not_empty()) + + Post().save() + self.assertTrue(Post.objects.not_empty()) + + Post.drop_collection() + + def test_custom_querysets_managers_directly(self): + """Ensure that custom QuerySet classes may be used. + """ + + class CustomQuerySetManager(QuerySetManager): + + @staticmethod + def get_queryset(doc_cls, queryset): + return queryset(is_published=True) + + class Post(Document): + is_published = BooleanField(default=False) + published = CustomQuerySetManager() + + Post.drop_collection() + + Post().save() + Post(is_published=True).save() + self.assertEquals(Post.objects.count(), 2) + self.assertEquals(Post.published.count(), 1) + + Post.drop_collection() + def test_call_after_limits_set(self): """Ensure that re-filtering after slicing works """