diff --git a/docs/apireference.rst b/docs/apireference.rst index 69b1db03..049cc303 100644 --- a/docs/apireference.rst +++ b/docs/apireference.rst @@ -7,7 +7,6 @@ Connecting .. autofunction:: mongoengine.connect .. autofunction:: mongoengine.register_connection -.. autoclass:: mongoengine.SwitchDB Documents ========= @@ -35,6 +34,13 @@ Documents .. autoclass:: mongoengine.ValidationError :members: +Context Managers +================ + +.. autoclass:: mongoengine.context_managers.switch_db +.. autoclass:: mongoengine.context_managers.no_dereference +.. autoclass:: mongoengine.context_managers.query_counter + Querying ======== diff --git a/docs/changelog.rst b/docs/changelog.rst index 65e11034..bead6935 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -32,8 +32,9 @@ Changes in 0.8.X - Fixed inheritance and unique index creation (#140) - Fixed reverse delete rule with inheritance (#197) - Fixed validation for GenericReferences which havent been dereferenced -- Added SwitchDB context manager (#106) +- Added switch_db context manager (#106) - Added switch_db method to document instances (#106) +- Added no_dereference context manager (#82) Changes in 0.7.9 ================ diff --git a/docs/guide/connecting.rst b/docs/guide/connecting.rst index b39ccda4..ebd61a97 100644 --- a/docs/guide/connecting.rst +++ b/docs/guide/connecting.rst @@ -75,15 +75,15 @@ Switch Database Context Manager =============================== Sometimes you might want to switch the database to query against for a class. -The SwitchDB context manager allows you to change the database alias for a -class eg :: +The :class:`~mongoengine.context_managers.switch_db` context manager allows +you to change the database alias for a class eg :: - from mongoengine import SwitchDB + from mongoengine.context_managers import switch_db class User(Document): name = StringField() meta = {"db_alias": "user-db"} - with SwitchDB(User, 'archive-user-db') as User: + with switch_db(User, 'archive-user-db') as User: User(name="Ross").save() # Saves the 'archive-user-db' diff --git a/docs/guide/querying.rst b/docs/guide/querying.rst index 40e36e32..7ccf1432 100644 --- a/docs/guide/querying.rst +++ b/docs/guide/querying.rst @@ -93,7 +93,7 @@ may used with :class:`~mongoengine.GeoPointField`\ s: [(41.91,-87.69), (41.92,-87.68), (41.91,-87.65), (41.89,-87.65)]). .. note:: Requires Mongo Server 2.0 * ``max_distance`` -- can be added to your location queries to set a maximum -distance. + distance. Querying lists @@ -369,6 +369,22 @@ references to the depth of 1 level. If you have more complicated documents and want to dereference more of the object at once then increasing the :attr:`max_depth` will dereference more levels of the document. +Turning off dereferencing +------------------------- + +Sometimes for performance reasons you don't want to automatically dereference +data . To turn off all dereferencing you can use the +:class:`~mongoengine.context_managers.no_dereference` context manager:: + + with no_dereference(Post) as Post: + post = Post.objects.first() + assert(isinstance(post.author, ObjectId)) + +.. note:: + + :class:`~mongoengine.context_managers.no_dereference` only works on the + Default QuerySet manager. + Advanced queries ================ Sometimes calling a :class:`~mongoengine.queryset.QuerySet` object with keyword diff --git a/mongoengine/base/fields.py b/mongoengine/base/fields.py index a892fbd2..82981e25 100644 --- a/mongoengine/base/fields.py +++ b/mongoengine/base/fields.py @@ -23,6 +23,7 @@ class BaseField(object): name = None _geo_index = False _auto_gen = False # Call `generate` to generate a value + _auto_dereference = True # These track each time a Field instance is created. Used to retain order. # The auto_creation_counter is used for fields that MongoEngine implicitly @@ -163,9 +164,11 @@ class ComplexBaseField(BaseField): ReferenceField = _import_class('ReferenceField') GenericReferenceField = _import_class('GenericReferenceField') - dereference = self.field is None or isinstance(self.field, - (GenericReferenceField, ReferenceField)) - if not self._dereference and instance._initialised and dereference: + dereference = (self._auto_dereference and + (self.field is None or isinstance(self.field, + (GenericReferenceField, ReferenceField)))) + + if not self.__dereference and instance._initialised and dereference: instance._data[self.name] = self._dereference( instance._data.get(self.name), max_depth=1, instance=instance, name=self.name @@ -182,7 +185,8 @@ class ComplexBaseField(BaseField): value = BaseDict(value, instance, self.name) instance._data[self.name] = value - if (instance._initialised and isinstance(value, (BaseList, BaseDict)) + if (self._auto_dereference and instance._initialised and + isinstance(value, (BaseList, BaseDict)) and not value._dereferenced): value = self._dereference( value, max_depth=1, instance=instance, name=self.name diff --git a/mongoengine/common.py b/mongoengine/common.py index a8422c09..718ac0b2 100644 --- a/mongoengine/common.py +++ b/mongoengine/common.py @@ -11,7 +11,7 @@ def _import_class(cls_name): field_classes = ('DictField', 'DynamicField', 'EmbeddedDocumentField', 'FileField', 'GenericReferenceField', 'GenericEmbeddedDocumentField', 'GeoPointField', - 'ReferenceField', 'StringField') + 'ReferenceField', 'StringField', 'ComplexBaseField') queryset_classes = ('OperationError',) deref_classes = ('DeReference',) diff --git a/mongoengine/connection.py b/mongoengine/connection.py index 9f906a28..a47be446 100644 --- a/mongoengine/connection.py +++ b/mongoengine/connection.py @@ -3,7 +3,7 @@ from pymongo import Connection, ReplicaSetConnection, uri_parser __all__ = ['ConnectionError', 'connect', 'register_connection', - 'DEFAULT_CONNECTION_NAME', 'SwitchDB'] + 'DEFAULT_CONNECTION_NAME'] DEFAULT_CONNECTION_NAME = 'default' @@ -164,47 +164,6 @@ def connect(db, alias=DEFAULT_CONNECTION_NAME, **kwargs): return get_connection(alias) -class SwitchDB(object): - """ SwitchDB alias context manager. - - Example :: - - # Register connections - register_connection('default', 'mongoenginetest') - register_connection('testdb-1', 'mongoenginetest2') - - class Group(Document): - name = StringField() - - Group(name="test").save() # Saves in the default db - - with SwitchDB(Group, 'testdb-1') as Group: - Group(name="hello testdb!").save() # Saves in testdb-1 - - """ - - def __init__(self, cls, db_alias): - """ Construct the SwitchDB context manager - - :param cls: the class to change the registered db - :param db_alias: the name of the specific database to use - """ - self.cls = cls - self.collection = cls._get_collection() - self.db_alias = db_alias - self.ori_db_alias = cls._meta.get("db_alias", DEFAULT_CONNECTION_NAME) - - def __enter__(self): - """ change the db_alias and clear the cached collection """ - self.cls._meta["db_alias"] = self.db_alias - self.cls._collection = None - return self.cls - - def __exit__(self, t, value, traceback): - """ Reset the db_alias and collection """ - self.cls._meta["db_alias"] = self.ori_db_alias - self.cls._collection = self.collection - # Support old naming convention _get_connection = get_connection _get_db = get_db diff --git a/mongoengine/context_managers.py b/mongoengine/context_managers.py new file mode 100644 index 00000000..7255d51c --- /dev/null +++ b/mongoengine/context_managers.py @@ -0,0 +1,159 @@ +from mongoengine.common import _import_class +from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db +from mongoengine.queryset import OperationError, QuerySet + +__all__ = ("switch_db", "no_dereference", "query_counter") + + +class switch_db(object): + """ switch_db alias context manager. + + Example :: + + # Register connections + register_connection('default', 'mongoenginetest') + register_connection('testdb-1', 'mongoenginetest2') + + class Group(Document): + name = StringField() + + Group(name="test").save() # Saves in the default db + + with switch_db(Group, 'testdb-1') as Group: + Group(name="hello testdb!").save() # Saves in testdb-1 + + """ + + def __init__(self, cls, db_alias): + """ Construct the switch_db context manager + + :param cls: the class to change the registered db + :param db_alias: the name of the specific database to use + """ + self.cls = cls + self.collection = cls._get_collection() + self.db_alias = db_alias + self.ori_db_alias = cls._meta.get("db_alias", DEFAULT_CONNECTION_NAME) + + def __enter__(self): + """ change the db_alias and clear the cached collection """ + self.cls._meta["db_alias"] = self.db_alias + self.cls._collection = None + return self.cls + + def __exit__(self, t, value, traceback): + """ Reset the db_alias and collection """ + self.cls._meta["db_alias"] = self.ori_db_alias + self.cls._collection = self.collection + + +class no_dereference(object): + """ no_dereference context manager. + + Turns off all dereferencing in Documents:: + + with no_dereference(Group) as Group: + Group.objects.find() + + """ + + def __init__(self, cls): + """ Construct the no_dereference context manager. + + :param cls: the class to turn dereferencing off on + """ + self.cls = cls + + ReferenceField = _import_class('ReferenceField') + GenericReferenceField = _import_class('GenericReferenceField') + ComplexBaseField = _import_class('ComplexBaseField') + + self.deref_fields = [k for k, v in self.cls._fields.iteritems() + if isinstance(v, (ReferenceField, + GenericReferenceField, + ComplexBaseField))] + + def __enter__(self): + """ change the objects default and _auto_dereference values""" + if 'queryset_class' in self.cls._meta: + raise OperationError("no_dereference context manager only works on" + " default queryset classes") + objects = self.cls.__dict__['objects'] + objects.default = QuerySetNoDeRef + self.cls.objects = objects + for field in self.deref_fields: + self.cls._fields[field]._auto_dereference = False + return self.cls + + def __exit__(self, t, value, traceback): + """ Reset the default and _auto_dereference values""" + objects = self.cls.__dict__['objects'] + objects.default = QuerySet + self.cls.objects = objects + for field in self.deref_fields: + self.cls._fields[field]._auto_dereference = True + return self.cls + + +class QuerySetNoDeRef(QuerySet): + """Special no_dereference QuerySet""" + def __dereference(items, max_depth=1, instance=None, name=None): + return items + + +class query_counter(object): + """ Query_counter contextmanager to get the number of queries. """ + + def __init__(self): + """ Construct the query_counter. """ + self.counter = 0 + self.db = get_db() + + def __enter__(self): + """ On every with block we need to drop the profile collection. """ + self.db.set_profiling_level(0) + self.db.system.profile.drop() + self.db.set_profiling_level(2) + return self + + def __exit__(self, t, value, traceback): + """ Reset the profiling level. """ + self.db.set_profiling_level(0) + + def __eq__(self, value): + """ == Compare querycounter. """ + return value == self._get_count() + + def __ne__(self, value): + """ != Compare querycounter. """ + return not self.__eq__(value) + + def __lt__(self, value): + """ < Compare querycounter. """ + return self._get_count() < value + + def __le__(self, value): + """ <= Compare querycounter. """ + return self._get_count() <= value + + def __gt__(self, value): + """ > Compare querycounter. """ + return self._get_count() > value + + def __ge__(self, value): + """ >= Compare querycounter. """ + return self._get_count() >= value + + def __int__(self): + """ int representation. """ + return self._get_count() + + def __repr__(self): + """ repr query_counter as the number of queries. """ + return u"%s" % self._get_count() + + def _get_count(self): + """ Get the number of queries. """ + count = self.db.system.profile.find().count() - self.counter + self.counter += 1 + return count diff --git a/mongoengine/document.py b/mongoengine/document.py index 3bc4caed..9d4a1e6e 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -1,15 +1,17 @@ +from __future__ import with_statement import warnings import pymongo import re from bson.dbref import DBRef -from mongoengine import signals, queryset - -from base import (DocumentMetaclass, TopLevelDocumentMetaclass, BaseDocument, - BaseDict, BaseList, ALLOW_INHERITANCE, get_document) -from queryset import OperationError, NotUniqueError -from connection import get_db, DEFAULT_CONNECTION_NAME, SwitchDB +from mongoengine import signals +from mongoengine.base import (DocumentMetaclass, TopLevelDocumentMetaclass, + BaseDocument, BaseDict, BaseList, + ALLOW_INHERITANCE, get_document) +from mongoengine.queryset import OperationError, NotUniqueError +from mongoengine.connection import get_db, DEFAULT_CONNECTION_NAME +from mongoengine.context_managers import switch_db __all__ = ('Document', 'EmbeddedDocument', 'DynamicDocument', 'DynamicEmbeddedDocument', 'OperationError', @@ -381,11 +383,11 @@ class Document(BaseDocument): user.save() If you need to read from another database see - :class:`~mongoengine.SwitchDB` + :class:`~mongoengine.context_managers.switch_db` :param db_alias: The database alias to use for saving the document """ - with SwitchDB(self.__class__, db_alias) as cls: + with switch_db(self.__class__, db_alias) as cls: collection = cls._get_collection() db = cls._get_db self._get_collection = lambda: collection diff --git a/mongoengine/fields.py b/mongoengine/fields.py index f7817742..1ccdb650 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -779,7 +779,7 @@ class ReferenceField(BaseField): value = instance._data.get(self.name) # Dereference DBRefs - if isinstance(value, DBRef): + if self._auto_dereference and isinstance(value, DBRef): value = self.document_type._get_db().dereference(value) if value is not None: instance._data[self.name] = self.document_type._from_son(value) diff --git a/mongoengine/queryset/manager.py b/mongoengine/queryset/manager.py index d9f9992f..47c2143d 100644 --- a/mongoengine/queryset/manager.py +++ b/mongoengine/queryset/manager.py @@ -18,11 +18,11 @@ class QuerySetManager(object): """ get_queryset = None + default = QuerySet def __init__(self, queryset_func=None): if queryset_func: self.get_queryset = queryset_func - self._collections = {} def __get__(self, instance, owner): """Descriptor for instantiating a new QuerySet object when @@ -33,7 +33,7 @@ class QuerySetManager(object): return self # owner is the document that contains the QuerySetManager - queryset_class = owner._meta.get('queryset_class') or QuerySet + queryset_class = owner._meta.get('queryset_class', self.default) queryset = queryset_class(owner, owner._get_collection()) if self.get_queryset: arg_count = self.get_queryset.func_code.co_argcount diff --git a/mongoengine/queryset/queryset.py b/mongoengine/queryset/queryset.py index e6373700..a9ff6e73 100644 --- a/mongoengine/queryset/queryset.py +++ b/mongoengine/queryset/queryset.py @@ -109,7 +109,6 @@ class QuerySet(object): queryset._class_check = class_check return queryset - def __iter__(self): """Support iterator protocol""" self.rewind() diff --git a/mongoengine/tests.py b/mongoengine/tests.py deleted file mode 100644 index 68663772..00000000 --- a/mongoengine/tests.py +++ /dev/null @@ -1,59 +0,0 @@ -from mongoengine.connection import get_db - - -class query_counter(object): - """ Query_counter contextmanager to get the number of queries. """ - - def __init__(self): - """ Construct the query_counter. """ - self.counter = 0 - self.db = get_db() - - def __enter__(self): - """ On every with block we need to drop the profile collection. """ - self.db.set_profiling_level(0) - self.db.system.profile.drop() - self.db.set_profiling_level(2) - return self - - def __exit__(self, t, value, traceback): - """ Reset the profiling level. """ - self.db.set_profiling_level(0) - - def __eq__(self, value): - """ == Compare querycounter. """ - return value == self._get_count() - - def __ne__(self, value): - """ != Compare querycounter. """ - return not self.__eq__(value) - - def __lt__(self, value): - """ < Compare querycounter. """ - return self._get_count() < value - - def __le__(self, value): - """ <= Compare querycounter. """ - return self._get_count() <= value - - def __gt__(self, value): - """ > Compare querycounter. """ - return self._get_count() > value - - def __ge__(self, value): - """ >= Compare querycounter. """ - return self._get_count() >= value - - def __int__(self): - """ int representation. """ - return self._get_count() - - def __repr__(self): - """ repr query_counter as the number of queries. """ - return u"%s" % self._get_count() - - def _get_count(self): - """ Get the number of queries. """ - count = self.db.system.profile.find().count() - self.counter - self.counter += 1 - return count diff --git a/tests/queryset/queryset.py b/tests/queryset/queryset.py index b5b0b280..35940447 100644 --- a/tests/queryset/queryset.py +++ b/tests/queryset/queryset.py @@ -17,7 +17,7 @@ from bson import ObjectId from mongoengine import * from mongoengine.connection import get_connection from mongoengine.python_support import PY3 -from mongoengine.tests import query_counter +from mongoengine.context_managers import query_counter from mongoengine.queryset import (QuerySet, QuerySetManager, MultipleObjectsReturned, DoesNotExist, QueryFieldList, queryset_manager) diff --git a/tests/test_connection.py b/tests/test_connection.py index 7ff18a38..2a216fef 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1,3 +1,4 @@ +from __future__ import with_statement import datetime import pymongo import unittest @@ -8,6 +9,7 @@ from bson.tz_util import utc from mongoengine import * from mongoengine.connection import get_db, get_connection, ConnectionError +from mongoengine.context_managers import switch_db class ConnectionTest(unittest.TestCase): @@ -105,7 +107,7 @@ class ConnectionTest(unittest.TestCase): Group(name="hello - default").save() self.assertEqual(1, Group.objects.count()) - with SwitchDB(Group, 'testdb-1') as Group: + with switch_db(Group, 'testdb-1') as Group: self.assertEqual(0, Group.objects.count()) diff --git a/tests/test_dereference.py b/tests/test_dereference.py index f42482d1..8e4ffdd8 100644 --- a/tests/test_dereference.py +++ b/tests/test_dereference.py @@ -8,7 +8,7 @@ from bson import DBRef, ObjectId from mongoengine import * from mongoengine.connection import get_db -from mongoengine.tests import query_counter +from mongoengine.context_managers import query_counter, no_dereference class FieldTest(unittest.TestCase): @@ -1121,5 +1121,77 @@ class FieldTest(unittest.TestCase): self.assertEqual(q, 2) + def test_no_dereference_context_manager_object_id(self): + """Ensure that DBRef items in ListFields aren't dereferenced. + """ + class User(Document): + name = StringField() + + class Group(Document): + ref = ReferenceField(User, dbref=False) + generic = GenericReferenceField() + members = ListField(ReferenceField(User, dbref=False)) + + User.drop_collection() + Group.drop_collection() + + for i in xrange(1, 51): + User(name='user %s' % i).save() + + user = User.objects.first() + Group(ref=user, members=User.objects, generic=user).save() + + with no_dereference(Group) as NoDeRefGroup: + self.assertTrue(Group._fields['members']._auto_dereference) + self.assertFalse(NoDeRefGroup._fields['members']._auto_dereference) + + with no_dereference(Group) as Group: + group = Group.objects.first() + self.assertTrue(all([not isinstance(m, User) + for m in group.members])) + self.assertFalse(isinstance(group.ref, User)) + self.assertFalse(isinstance(group.generic, User)) + + self.assertTrue(all([isinstance(m, User) + for m in group.members])) + self.assertTrue(isinstance(group.ref, User)) + self.assertTrue(isinstance(group.generic, User)) + + def test_no_dereference_context_manager_dbref(self): + """Ensure that DBRef items in ListFields aren't dereferenced. + """ + class User(Document): + name = StringField() + + class Group(Document): + ref = ReferenceField(User, dbref=True) + generic = GenericReferenceField() + members = ListField(ReferenceField(User, dbref=True)) + + User.drop_collection() + Group.drop_collection() + + for i in xrange(1, 51): + User(name='user %s' % i).save() + + user = User.objects.first() + Group(ref=user, members=User.objects, generic=user).save() + + with no_dereference(Group) as NoDeRefGroup: + self.assertTrue(Group._fields['members']._auto_dereference) + self.assertFalse(NoDeRefGroup._fields['members']._auto_dereference) + + with no_dereference(Group) as Group: + group = Group.objects.first() + self.assertTrue(all([not isinstance(m, User) + for m in group.members])) + self.assertFalse(isinstance(group.ref, User)) + self.assertFalse(isinstance(group.generic, User)) + + self.assertTrue(all([isinstance(m, User) + for m in group.members])) + self.assertTrue(isinstance(group.ref, User)) + self.assertTrue(isinstance(group.generic, User)) + if __name__ == '__main__': unittest.main()