diff --git a/mongoengine/base.py b/mongoengine/base.py index 07f53c30..18ee9134 100644 --- a/mongoengine/base.py +++ b/mongoengine/base.py @@ -447,7 +447,6 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): # Subclassed documents inherit collection from superclass for base in bases: if hasattr(base, '_meta'): - if 'collection' in attrs.get('meta', {}) and not base._meta.get('abstract', False): import warnings msg = "Trying to set a collection on a subclass (%s)" % name @@ -465,14 +464,20 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): # Propagate 'allow_inheritance' if 'allow_inheritance' in base._meta: base_meta['allow_inheritance'] = base._meta['allow_inheritance'] + if 'queryset_class' in base._meta: + base_meta['queryset_class'] = base._meta['queryset_class'] + try: + base_meta['objects'] = base.__getattribute__(base, 'objects') + except AttributeError: + pass meta = { 'abstract': False, 'collection': collection, 'max_documents': None, 'max_size': None, - 'ordering': [], # default ordering applied at runtime - 'indexes': [], # indexes to be ensured at runtime + 'ordering': [], # default ordering applied at runtime + 'indexes': [], # indexes to be ensured at runtime 'id_field': id_field, 'index_background': False, 'index_drop_dups': False, @@ -496,7 +501,7 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): new_class._meta['collection'] = collection(new_class) # Provide a default queryset unless one has been manually provided - manager = attrs.get('objects', QuerySetManager()) + manager = attrs.get('objects', meta.get('objects', QuerySetManager())) if hasattr(manager, 'queryset_class'): meta['queryset_class'] = manager.queryset_class new_class.objects = manager diff --git a/tests/queryset.py b/tests/queryset.py index 51c95112..a21bae69 100644 --- a/tests/queryset.py +++ b/tests/queryset.py @@ -2318,6 +2318,56 @@ class QuerySetTest(unittest.TestCase): Post.drop_collection() + def test_custom_querysets_inherited(self): + """Ensure that custom QuerySet classes may be used. + """ + + class CustomQuerySet(QuerySet): + def not_empty(self): + return len(self) > 0 + + class Base(Document): + meta = {'abstract': True, 'queryset_class': CustomQuerySet} + + class Post(Base): + pass + + Post.drop_collection() + self.assertTrue(isinstance(Post.objects, CustomQuerySet)) + self.assertFalse(Post.objects.not_empty()) + + Post().save() + self.assertTrue(Post.objects.not_empty()) + + Post.drop_collection() + + def test_custom_querysets_inherited_direct(self): + """Ensure that custom QuerySet classes may be used. + """ + + class CustomQuerySet(QuerySet): + def not_empty(self): + return len(self) > 0 + + class CustomQuerySetManager(QuerySetManager): + queryset_class = CustomQuerySet + + class Base(Document): + meta = {'abstract': True} + objects = CustomQuerySetManager() + + class Post(Base): + pass + + Post.drop_collection() + self.assertTrue(isinstance(Post.objects, CustomQuerySet)) + self.assertFalse(Post.objects.not_empty()) + + Post().save() + self.assertTrue(Post.objects.not_empty()) + + Post.drop_collection() + def test_call_after_limits_set(self): """Ensure that re-filtering after slicing works """