From 5b498bd8d6f161e81605a686fe927307b2e28078 Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Fri, 10 May 2013 15:05:16 +0000 Subject: [PATCH] Added no_sub_classes context manager and queryset helper (#312) --- docs/changelog.rst | 1 + mongoengine/context_managers.py | 36 ++++++++++++++++++++-- mongoengine/queryset/queryset.py | 4 +-- tests/queryset/queryset.py | 23 ++++++++++++-- tests/test_context_managers.py | 51 +++++++++++++++++++++++++++++++- 5 files changed, 108 insertions(+), 7 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 3b6813a0..c3e50e40 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -4,6 +4,7 @@ Changelog Changes in 0.8.0 ================ +- Added no_sub_classes context manager and queryset helper (#312) - Querysets now utilises a local cache - Changed __len__ behavour in the queryset (#247, #311) - Fixed querying string versions of ObjectIds issue with ReferenceField (#307) diff --git a/mongoengine/context_managers.py b/mongoengine/context_managers.py index 76d5fbfa..1280e117 100644 --- a/mongoengine/context_managers.py +++ b/mongoengine/context_managers.py @@ -1,8 +1,10 @@ from mongoengine.common import _import_class from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db -from mongoengine.queryset import OperationError, QuerySet +from mongoengine.queryset import QuerySet -__all__ = ("switch_db", "switch_collection", "no_dereference", "query_counter") + +__all__ = ("switch_db", "switch_collection", "no_dereference", + "no_sub_classes", "query_counter") class switch_db(object): @@ -130,6 +132,36 @@ class no_dereference(object): return self.cls +class no_sub_classes(object): + """ no_sub_classes context manager. + + Only returns instances of this class and no sub (inherited) classes:: + + with no_sub_classes(Group) as Group: + Group.objects.find() + + """ + + def __init__(self, cls): + """ Construct the no_sub_classes context manager. + + :param cls: the class to turn querying sub classes on + """ + self.cls = cls + + def __enter__(self): + """ change the objects default and _auto_dereference values""" + self.cls._all_subclasses = self.cls._subclasses + self.cls._subclasses = (self.cls,) + return self.cls + + def __exit__(self, t, value, traceback): + """ Reset the default and _auto_dereference values""" + self.cls._subclasses = self.cls._all_subclasses + delattr(self.cls, '_all_subclasses') + return self.cls + + class QuerySetNoDeRef(QuerySet): """Special no_dereference QuerySet""" def __dereference(items, max_depth=1, instance=None, name=None): diff --git a/mongoengine/queryset/queryset.py b/mongoengine/queryset/queryset.py index 4cf86d1e..5da62953 100644 --- a/mongoengine/queryset/queryset.py +++ b/mongoengine/queryset/queryset.py @@ -569,9 +569,9 @@ class QuerySet(object): queryset._none = True return queryset - def disable_inheritance(self): + def no_sub_classes(self): """ - Disable inheritance query, fetch only objects for the query class + Only return instances of this document and not any inherited documents """ if self._document._meta.get('allow_inheritance') is True: self._initial_query = {"_cls": self._document._class_name} diff --git a/tests/queryset/queryset.py b/tests/queryset/queryset.py index 27a418dc..bf237618 100644 --- a/tests/queryset/queryset.py +++ b/tests/queryset/queryset.py @@ -3322,7 +3322,7 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(q, 2) - def test_disable_inheritance_queryset(self): + def test_no_sub_classes(self): class A(Document): x = IntField() y = IntField() @@ -3332,15 +3332,34 @@ class QuerySetTest(unittest.TestCase): class B(A): z = IntField() + class C(B): + zz = IntField() + A.drop_collection() A(x=10, y=20).save() A(x=15, y=30).save() B(x=20, y=40).save() B(x=30, y=50).save() + C(x=40, y=60).save() - for obj in A.objects.disable_inheritance(): + self.assertEqual(A.objects.no_sub_classes().count(), 2) + self.assertEqual(A.objects.count(), 5) + + self.assertEqual(B.objects.no_sub_classes().count(), 2) + self.assertEqual(B.objects.count(), 3) + + self.assertEqual(C.objects.no_sub_classes().count(), 1) + self.assertEqual(C.objects.count(), 1) + + for obj in A.objects.no_sub_classes(): self.assertEqual(obj.__class__, A) + for obj in B.objects.no_sub_classes(): + self.assertEqual(obj.__class__, B) + + for obj in C.objects.no_sub_classes(): + self.assertEqual(obj.__class__, C) + if __name__ == '__main__': unittest.main() diff --git a/tests/test_context_managers.py b/tests/test_context_managers.py index f87d6383..c201a5fc 100644 --- a/tests/test_context_managers.py +++ b/tests/test_context_managers.py @@ -5,7 +5,8 @@ import unittest from mongoengine import * from mongoengine.connection import get_db from mongoengine.context_managers import (switch_db, switch_collection, - no_dereference, query_counter) + no_sub_classes, no_dereference, + query_counter) class ContextManagersTest(unittest.TestCase): @@ -138,6 +139,54 @@ class ContextManagersTest(unittest.TestCase): self.assertTrue(isinstance(group.ref, User)) self.assertTrue(isinstance(group.generic, User)) + def test_no_sub_classes(self): + class A(Document): + x = IntField() + y = IntField() + + meta = {'allow_inheritance': True} + + class B(A): + z = IntField() + + class C(B): + zz = IntField() + + A.drop_collection() + + A(x=10, y=20).save() + A(x=15, y=30).save() + B(x=20, y=40).save() + B(x=30, y=50).save() + C(x=40, y=60).save() + + self.assertEqual(A.objects.count(), 5) + self.assertEqual(B.objects.count(), 3) + self.assertEqual(C.objects.count(), 1) + + with no_sub_classes(A) as A: + self.assertEqual(A.objects.count(), 2) + + for obj in A.objects: + self.assertEqual(obj.__class__, A) + + with no_sub_classes(B) as B: + self.assertEqual(B.objects.count(), 2) + + for obj in B.objects: + self.assertEqual(obj.__class__, B) + + with no_sub_classes(C) as C: + self.assertEqual(C.objects.count(), 1) + + for obj in C.objects: + self.assertEqual(obj.__class__, C) + + # Confirm context manager exit correctly + self.assertEqual(A.objects.count(), 5) + self.assertEqual(B.objects.count(), 3) + self.assertEqual(C.objects.count(), 1) + def test_query_counter(self): connect('mongoenginetest') db = get_db()