diff --git a/mongoengine/base.py b/mongoengine/base.py index 9a6b5f12..77c2d7d1 100644 --- a/mongoengine/base.py +++ b/mongoengine/base.py @@ -299,8 +299,10 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): new_class = super_new(cls, name, bases, attrs) # Provide a default queryset unless one has been manually provided - if not hasattr(new_class, 'objects'): - new_class.objects = QuerySetManager() + manager = attrs.get('objects', QuerySetManager()) + if hasattr(manager, 'queryset_class'): + meta['queryset_class'] = manager.queryset_class + new_class.objects = manager user_indexes = [QuerySet._build_index_spec(new_class, spec) for spec in meta['indexes']] + base_indexes diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index e6c93353..a494fbd0 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -428,8 +428,6 @@ class QuerySet(object): querying collection :param query: Django-style query keyword arguments """ - #if q_obj: - #self._where_clause = q_obj.as_js(self._document) query = Q(**query) if q_obj: query &= q_obj @@ -1322,8 +1320,11 @@ class QuerySet(object): class QuerySetManager(object): - def __init__(self, manager_func=None): - self._manager_func = manager_func + get_queryset = None + + def __init__(self, queryset_func=None): + if queryset_func: + self.get_queryset = queryset_func self._collections = {} def __get__(self, instance, owner): @@ -1367,11 +1368,11 @@ class QuerySetManager(object): # owner is the document that contains the QuerySetManager queryset_class = owner._meta['queryset_class'] or QuerySet queryset = queryset_class(owner, self._collections[(db, collection)]) - if self._manager_func: - if self._manager_func.func_code.co_argcount == 1: - queryset = self._manager_func(queryset) + if self.get_queryset: + if self.get_queryset.func_code.co_argcount == 1: + queryset = self.get_queryset(queryset) else: - queryset = self._manager_func(owner, queryset) + queryset = self.get_queryset(owner, queryset) return queryset diff --git a/tests/queryset.py b/tests/queryset.py index b0693f66..6a8e16d5 100644 --- a/tests/queryset.py +++ b/tests/queryset.py @@ -5,8 +5,9 @@ import unittest import pymongo from datetime import datetime, timedelta -from mongoengine.queryset import (QuerySet, MultipleObjectsReturned, - DoesNotExist, QueryFieldList) +from mongoengine.queryset import (QuerySet, QuerySetManager, + MultipleObjectsReturned, DoesNotExist, + QueryFieldList) from mongoengine import * @@ -1786,6 +1787,53 @@ class QuerySetTest(unittest.TestCase): Post.drop_collection() + def test_custom_querysets_set_manager_directly(self): + """Ensure that custom QuerySet classes may be used. + """ + + class CustomQuerySet(QuerySet): + def not_empty(self): + return len(self) > 0 + + class CustomQuerySetManager(QuerySetManager): + queryset_class = CustomQuerySet + + class Post(Document): + objects = CustomQuerySetManager() + + Post.drop_collection() + + self.assertTrue(isinstance(Post.objects, CustomQuerySet)) + self.assertFalse(Post.objects.not_empty()) + + Post().save() + self.assertTrue(Post.objects.not_empty()) + + Post.drop_collection() + + def test_custom_querysets_managers_directly(self): + """Ensure that custom QuerySet classes may be used. + """ + + class CustomQuerySetManager(QuerySetManager): + + @staticmethod + def get_queryset(doc_cls, queryset): + return queryset(is_published=True) + + class Post(Document): + is_published = BooleanField(default=False) + published = CustomQuerySetManager() + + Post.drop_collection() + + Post().save() + Post(is_published=True).save() + self.assertEquals(Post.objects.count(), 2) + self.assertEquals(Post.published.count(), 1) + + Post.drop_collection() + def test_call_after_limits_set(self): """Ensure that re-filtering after slicing works """