Merge remote branch 'upstream/dev' into dev

This commit is contained in:
Alistair Roche 2011-05-24 11:32:23 +01:00
commit 8427877bd2
3 changed files with 63 additions and 12 deletions

View File

@ -299,8 +299,10 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
new_class = super_new(cls, name, bases, attrs) new_class = super_new(cls, name, bases, attrs)
# Provide a default queryset unless one has been manually provided # Provide a default queryset unless one has been manually provided
if not hasattr(new_class, 'objects'): manager = attrs.get('objects', QuerySetManager())
new_class.objects = QuerySetManager() if hasattr(manager, 'queryset_class'):
meta['queryset_class'] = manager.queryset_class
new_class.objects = manager
user_indexes = [QuerySet._build_index_spec(new_class, spec) user_indexes = [QuerySet._build_index_spec(new_class, spec)
for spec in meta['indexes']] + base_indexes for spec in meta['indexes']] + base_indexes

View File

@ -428,8 +428,6 @@ class QuerySet(object):
querying collection querying collection
:param query: Django-style query keyword arguments :param query: Django-style query keyword arguments
""" """
#if q_obj:
#self._where_clause = q_obj.as_js(self._document)
query = Q(**query) query = Q(**query)
if q_obj: if q_obj:
query &= q_obj query &= q_obj
@ -1322,8 +1320,11 @@ class QuerySet(object):
class QuerySetManager(object): class QuerySetManager(object):
def __init__(self, manager_func=None): get_queryset = None
self._manager_func = manager_func
def __init__(self, queryset_func=None):
if queryset_func:
self.get_queryset = queryset_func
self._collections = {} self._collections = {}
def __get__(self, instance, owner): def __get__(self, instance, owner):
@ -1367,11 +1368,11 @@ class QuerySetManager(object):
# owner is the document that contains the QuerySetManager # owner is the document that contains the QuerySetManager
queryset_class = owner._meta['queryset_class'] or QuerySet queryset_class = owner._meta['queryset_class'] or QuerySet
queryset = queryset_class(owner, self._collections[(db, collection)]) queryset = queryset_class(owner, self._collections[(db, collection)])
if self._manager_func: if self.get_queryset:
if self._manager_func.func_code.co_argcount == 1: if self.get_queryset.func_code.co_argcount == 1:
queryset = self._manager_func(queryset) queryset = self.get_queryset(queryset)
else: else:
queryset = self._manager_func(owner, queryset) queryset = self.get_queryset(owner, queryset)
return queryset return queryset

View File

@ -5,8 +5,9 @@ import unittest
import pymongo import pymongo
from datetime import datetime, timedelta from datetime import datetime, timedelta
from mongoengine.queryset import (QuerySet, MultipleObjectsReturned, from mongoengine.queryset import (QuerySet, QuerySetManager,
DoesNotExist, QueryFieldList) MultipleObjectsReturned, DoesNotExist,
QueryFieldList)
from mongoengine import * from mongoengine import *
@ -1786,6 +1787,53 @@ class QuerySetTest(unittest.TestCase):
Post.drop_collection() Post.drop_collection()
def test_custom_querysets_set_manager_directly(self):
"""Ensure that custom QuerySet classes may be used.
"""
class CustomQuerySet(QuerySet):
def not_empty(self):
return len(self) > 0
class CustomQuerySetManager(QuerySetManager):
queryset_class = CustomQuerySet
class Post(Document):
objects = CustomQuerySetManager()
Post.drop_collection()
self.assertTrue(isinstance(Post.objects, CustomQuerySet))
self.assertFalse(Post.objects.not_empty())
Post().save()
self.assertTrue(Post.objects.not_empty())
Post.drop_collection()
def test_custom_querysets_managers_directly(self):
"""Ensure that custom QuerySet classes may be used.
"""
class CustomQuerySetManager(QuerySetManager):
@staticmethod
def get_queryset(doc_cls, queryset):
return queryset(is_published=True)
class Post(Document):
is_published = BooleanField(default=False)
published = CustomQuerySetManager()
Post.drop_collection()
Post().save()
Post(is_published=True).save()
self.assertEquals(Post.objects.count(), 2)
self.assertEquals(Post.published.count(), 1)
Post.drop_collection()
def test_call_after_limits_set(self): def test_call_after_limits_set(self):
"""Ensure that re-filtering after slicing works """Ensure that re-filtering after slicing works
""" """