Added no_dereference context manager (#82)

Reorganised the context_managers as well
This commit is contained in:
Ross Lawley 2013-01-23 19:05:44 +00:00
parent 4f70c27b56
commit 3a6dc77d36
16 changed files with 289 additions and 128 deletions

View File

@ -7,7 +7,6 @@ Connecting
.. autofunction:: mongoengine.connect
.. autofunction:: mongoengine.register_connection
.. autoclass:: mongoengine.SwitchDB
Documents
=========
@ -35,6 +34,13 @@ Documents
.. autoclass:: mongoengine.ValidationError
:members:
Context Managers
================
.. autoclass:: mongoengine.context_managers.switch_db
.. autoclass:: mongoengine.context_managers.no_dereference
.. autoclass:: mongoengine.context_managers.query_counter
Querying
========

View File

@ -32,8 +32,9 @@ Changes in 0.8.X
- Fixed inheritance and unique index creation (#140)
- Fixed reverse delete rule with inheritance (#197)
- Fixed validation for GenericReferences which havent been dereferenced
- Added SwitchDB context manager (#106)
- Added switch_db context manager (#106)
- Added switch_db method to document instances (#106)
- Added no_dereference context manager (#82)
Changes in 0.7.9
================

View File

@ -75,15 +75,15 @@ Switch Database Context Manager
===============================
Sometimes you might want to switch the database to query against for a class.
The SwitchDB context manager allows you to change the database alias for a
class eg ::
The :class:`~mongoengine.context_managers.switch_db` context manager allows
you to change the database alias for a class eg ::
from mongoengine import SwitchDB
from mongoengine.context_managers import switch_db
class User(Document):
name = StringField()
meta = {"db_alias": "user-db"}
with SwitchDB(User, 'archive-user-db') as User:
with switch_db(User, 'archive-user-db') as User:
User(name="Ross").save() # Saves the 'archive-user-db'

View File

@ -93,7 +93,7 @@ may used with :class:`~mongoengine.GeoPointField`\ s:
[(41.91,-87.69), (41.92,-87.68), (41.91,-87.65), (41.89,-87.65)]).
.. note:: Requires Mongo Server 2.0
* ``max_distance`` -- can be added to your location queries to set a maximum
distance.
distance.
Querying lists
@ -369,6 +369,22 @@ references to the depth of 1 level. If you have more complicated documents and
want to dereference more of the object at once then increasing the :attr:`max_depth`
will dereference more levels of the document.
Turning off dereferencing
-------------------------
Sometimes for performance reasons you don't want to automatically dereference
data . To turn off all dereferencing you can use the
:class:`~mongoengine.context_managers.no_dereference` context manager::
with no_dereference(Post) as Post:
post = Post.objects.first()
assert(isinstance(post.author, ObjectId))
.. note::
:class:`~mongoengine.context_managers.no_dereference` only works on the
Default QuerySet manager.
Advanced queries
================
Sometimes calling a :class:`~mongoengine.queryset.QuerySet` object with keyword

View File

@ -23,6 +23,7 @@ class BaseField(object):
name = None
_geo_index = False
_auto_gen = False # Call `generate` to generate a value
_auto_dereference = True
# These track each time a Field instance is created. Used to retain order.
# The auto_creation_counter is used for fields that MongoEngine implicitly
@ -163,9 +164,11 @@ class ComplexBaseField(BaseField):
ReferenceField = _import_class('ReferenceField')
GenericReferenceField = _import_class('GenericReferenceField')
dereference = self.field is None or isinstance(self.field,
(GenericReferenceField, ReferenceField))
if not self._dereference and instance._initialised and dereference:
dereference = (self._auto_dereference and
(self.field is None or isinstance(self.field,
(GenericReferenceField, ReferenceField))))
if not self.__dereference and instance._initialised and dereference:
instance._data[self.name] = self._dereference(
instance._data.get(self.name), max_depth=1, instance=instance,
name=self.name
@ -182,7 +185,8 @@ class ComplexBaseField(BaseField):
value = BaseDict(value, instance, self.name)
instance._data[self.name] = value
if (instance._initialised and isinstance(value, (BaseList, BaseDict))
if (self._auto_dereference and instance._initialised and
isinstance(value, (BaseList, BaseDict))
and not value._dereferenced):
value = self._dereference(
value, max_depth=1, instance=instance, name=self.name

View File

@ -11,7 +11,7 @@ def _import_class(cls_name):
field_classes = ('DictField', 'DynamicField', 'EmbeddedDocumentField',
'FileField', 'GenericReferenceField',
'GenericEmbeddedDocumentField', 'GeoPointField',
'ReferenceField', 'StringField')
'ReferenceField', 'StringField', 'ComplexBaseField')
queryset_classes = ('OperationError',)
deref_classes = ('DeReference',)

View File

@ -3,7 +3,7 @@ from pymongo import Connection, ReplicaSetConnection, uri_parser
__all__ = ['ConnectionError', 'connect', 'register_connection',
'DEFAULT_CONNECTION_NAME', 'SwitchDB']
'DEFAULT_CONNECTION_NAME']
DEFAULT_CONNECTION_NAME = 'default'
@ -164,47 +164,6 @@ def connect(db, alias=DEFAULT_CONNECTION_NAME, **kwargs):
return get_connection(alias)
class SwitchDB(object):
""" SwitchDB alias context manager.
Example ::
# Register connections
register_connection('default', 'mongoenginetest')
register_connection('testdb-1', 'mongoenginetest2')
class Group(Document):
name = StringField()
Group(name="test").save() # Saves in the default db
with SwitchDB(Group, 'testdb-1') as Group:
Group(name="hello testdb!").save() # Saves in testdb-1
"""
def __init__(self, cls, db_alias):
""" Construct the SwitchDB context manager
:param cls: the class to change the registered db
:param db_alias: the name of the specific database to use
"""
self.cls = cls
self.collection = cls._get_collection()
self.db_alias = db_alias
self.ori_db_alias = cls._meta.get("db_alias", DEFAULT_CONNECTION_NAME)
def __enter__(self):
""" change the db_alias and clear the cached collection """
self.cls._meta["db_alias"] = self.db_alias
self.cls._collection = None
return self.cls
def __exit__(self, t, value, traceback):
""" Reset the db_alias and collection """
self.cls._meta["db_alias"] = self.ori_db_alias
self.cls._collection = self.collection
# Support old naming convention
_get_connection = get_connection
_get_db = get_db

View File

@ -0,0 +1,159 @@
from mongoengine.common import _import_class
from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db
from mongoengine.queryset import OperationError, QuerySet
__all__ = ("switch_db", "no_dereference", "query_counter")
class switch_db(object):
""" switch_db alias context manager.
Example ::
# Register connections
register_connection('default', 'mongoenginetest')
register_connection('testdb-1', 'mongoenginetest2')
class Group(Document):
name = StringField()
Group(name="test").save() # Saves in the default db
with switch_db(Group, 'testdb-1') as Group:
Group(name="hello testdb!").save() # Saves in testdb-1
"""
def __init__(self, cls, db_alias):
""" Construct the switch_db context manager
:param cls: the class to change the registered db
:param db_alias: the name of the specific database to use
"""
self.cls = cls
self.collection = cls._get_collection()
self.db_alias = db_alias
self.ori_db_alias = cls._meta.get("db_alias", DEFAULT_CONNECTION_NAME)
def __enter__(self):
""" change the db_alias and clear the cached collection """
self.cls._meta["db_alias"] = self.db_alias
self.cls._collection = None
return self.cls
def __exit__(self, t, value, traceback):
""" Reset the db_alias and collection """
self.cls._meta["db_alias"] = self.ori_db_alias
self.cls._collection = self.collection
class no_dereference(object):
""" no_dereference context manager.
Turns off all dereferencing in Documents::
with no_dereference(Group) as Group:
Group.objects.find()
"""
def __init__(self, cls):
""" Construct the no_dereference context manager.
:param cls: the class to turn dereferencing off on
"""
self.cls = cls
ReferenceField = _import_class('ReferenceField')
GenericReferenceField = _import_class('GenericReferenceField')
ComplexBaseField = _import_class('ComplexBaseField')
self.deref_fields = [k for k, v in self.cls._fields.iteritems()
if isinstance(v, (ReferenceField,
GenericReferenceField,
ComplexBaseField))]
def __enter__(self):
""" change the objects default and _auto_dereference values"""
if 'queryset_class' in self.cls._meta:
raise OperationError("no_dereference context manager only works on"
" default queryset classes")
objects = self.cls.__dict__['objects']
objects.default = QuerySetNoDeRef
self.cls.objects = objects
for field in self.deref_fields:
self.cls._fields[field]._auto_dereference = False
return self.cls
def __exit__(self, t, value, traceback):
""" Reset the default and _auto_dereference values"""
objects = self.cls.__dict__['objects']
objects.default = QuerySet
self.cls.objects = objects
for field in self.deref_fields:
self.cls._fields[field]._auto_dereference = True
return self.cls
class QuerySetNoDeRef(QuerySet):
"""Special no_dereference QuerySet"""
def __dereference(items, max_depth=1, instance=None, name=None):
return items
class query_counter(object):
""" Query_counter contextmanager to get the number of queries. """
def __init__(self):
""" Construct the query_counter. """
self.counter = 0
self.db = get_db()
def __enter__(self):
""" On every with block we need to drop the profile collection. """
self.db.set_profiling_level(0)
self.db.system.profile.drop()
self.db.set_profiling_level(2)
return self
def __exit__(self, t, value, traceback):
""" Reset the profiling level. """
self.db.set_profiling_level(0)
def __eq__(self, value):
""" == Compare querycounter. """
return value == self._get_count()
def __ne__(self, value):
""" != Compare querycounter. """
return not self.__eq__(value)
def __lt__(self, value):
""" < Compare querycounter. """
return self._get_count() < value
def __le__(self, value):
""" <= Compare querycounter. """
return self._get_count() <= value
def __gt__(self, value):
""" > Compare querycounter. """
return self._get_count() > value
def __ge__(self, value):
""" >= Compare querycounter. """
return self._get_count() >= value
def __int__(self):
""" int representation. """
return self._get_count()
def __repr__(self):
""" repr query_counter as the number of queries. """
return u"%s" % self._get_count()
def _get_count(self):
""" Get the number of queries. """
count = self.db.system.profile.find().count() - self.counter
self.counter += 1
return count

View File

@ -1,15 +1,17 @@
from __future__ import with_statement
import warnings
import pymongo
import re
from bson.dbref import DBRef
from mongoengine import signals, queryset
from base import (DocumentMetaclass, TopLevelDocumentMetaclass, BaseDocument,
BaseDict, BaseList, ALLOW_INHERITANCE, get_document)
from queryset import OperationError, NotUniqueError
from connection import get_db, DEFAULT_CONNECTION_NAME, SwitchDB
from mongoengine import signals
from mongoengine.base import (DocumentMetaclass, TopLevelDocumentMetaclass,
BaseDocument, BaseDict, BaseList,
ALLOW_INHERITANCE, get_document)
from mongoengine.queryset import OperationError, NotUniqueError
from mongoengine.connection import get_db, DEFAULT_CONNECTION_NAME
from mongoengine.context_managers import switch_db
__all__ = ('Document', 'EmbeddedDocument', 'DynamicDocument',
'DynamicEmbeddedDocument', 'OperationError',
@ -381,11 +383,11 @@ class Document(BaseDocument):
user.save()
If you need to read from another database see
:class:`~mongoengine.SwitchDB`
:class:`~mongoengine.context_managers.switch_db`
:param db_alias: The database alias to use for saving the document
"""
with SwitchDB(self.__class__, db_alias) as cls:
with switch_db(self.__class__, db_alias) as cls:
collection = cls._get_collection()
db = cls._get_db
self._get_collection = lambda: collection

View File

@ -779,7 +779,7 @@ class ReferenceField(BaseField):
value = instance._data.get(self.name)
# Dereference DBRefs
if isinstance(value, DBRef):
if self._auto_dereference and isinstance(value, DBRef):
value = self.document_type._get_db().dereference(value)
if value is not None:
instance._data[self.name] = self.document_type._from_son(value)

View File

@ -18,11 +18,11 @@ class QuerySetManager(object):
"""
get_queryset = None
default = QuerySet
def __init__(self, queryset_func=None):
if queryset_func:
self.get_queryset = queryset_func
self._collections = {}
def __get__(self, instance, owner):
"""Descriptor for instantiating a new QuerySet object when
@ -33,7 +33,7 @@ class QuerySetManager(object):
return self
# owner is the document that contains the QuerySetManager
queryset_class = owner._meta.get('queryset_class') or QuerySet
queryset_class = owner._meta.get('queryset_class', self.default)
queryset = queryset_class(owner, owner._get_collection())
if self.get_queryset:
arg_count = self.get_queryset.func_code.co_argcount

View File

@ -109,7 +109,6 @@ class QuerySet(object):
queryset._class_check = class_check
return queryset
def __iter__(self):
"""Support iterator protocol"""
self.rewind()

View File

@ -1,59 +0,0 @@
from mongoengine.connection import get_db
class query_counter(object):
""" Query_counter contextmanager to get the number of queries. """
def __init__(self):
""" Construct the query_counter. """
self.counter = 0
self.db = get_db()
def __enter__(self):
""" On every with block we need to drop the profile collection. """
self.db.set_profiling_level(0)
self.db.system.profile.drop()
self.db.set_profiling_level(2)
return self
def __exit__(self, t, value, traceback):
""" Reset the profiling level. """
self.db.set_profiling_level(0)
def __eq__(self, value):
""" == Compare querycounter. """
return value == self._get_count()
def __ne__(self, value):
""" != Compare querycounter. """
return not self.__eq__(value)
def __lt__(self, value):
""" < Compare querycounter. """
return self._get_count() < value
def __le__(self, value):
""" <= Compare querycounter. """
return self._get_count() <= value
def __gt__(self, value):
""" > Compare querycounter. """
return self._get_count() > value
def __ge__(self, value):
""" >= Compare querycounter. """
return self._get_count() >= value
def __int__(self):
""" int representation. """
return self._get_count()
def __repr__(self):
""" repr query_counter as the number of queries. """
return u"%s" % self._get_count()
def _get_count(self):
""" Get the number of queries. """
count = self.db.system.profile.find().count() - self.counter
self.counter += 1
return count

View File

@ -17,7 +17,7 @@ from bson import ObjectId
from mongoengine import *
from mongoengine.connection import get_connection
from mongoengine.python_support import PY3
from mongoengine.tests import query_counter
from mongoengine.context_managers import query_counter
from mongoengine.queryset import (QuerySet, QuerySetManager,
MultipleObjectsReturned, DoesNotExist,
QueryFieldList, queryset_manager)

View File

@ -1,3 +1,4 @@
from __future__ import with_statement
import datetime
import pymongo
import unittest
@ -8,6 +9,7 @@ from bson.tz_util import utc
from mongoengine import *
from mongoengine.connection import get_db, get_connection, ConnectionError
from mongoengine.context_managers import switch_db
class ConnectionTest(unittest.TestCase):
@ -105,7 +107,7 @@ class ConnectionTest(unittest.TestCase):
Group(name="hello - default").save()
self.assertEqual(1, Group.objects.count())
with SwitchDB(Group, 'testdb-1') as Group:
with switch_db(Group, 'testdb-1') as Group:
self.assertEqual(0, Group.objects.count())

View File

@ -8,7 +8,7 @@ from bson import DBRef, ObjectId
from mongoengine import *
from mongoengine.connection import get_db
from mongoengine.tests import query_counter
from mongoengine.context_managers import query_counter, no_dereference
class FieldTest(unittest.TestCase):
@ -1121,5 +1121,77 @@ class FieldTest(unittest.TestCase):
self.assertEqual(q, 2)
def test_no_dereference_context_manager_object_id(self):
"""Ensure that DBRef items in ListFields aren't dereferenced.
"""
class User(Document):
name = StringField()
class Group(Document):
ref = ReferenceField(User, dbref=False)
generic = GenericReferenceField()
members = ListField(ReferenceField(User, dbref=False))
User.drop_collection()
Group.drop_collection()
for i in xrange(1, 51):
User(name='user %s' % i).save()
user = User.objects.first()
Group(ref=user, members=User.objects, generic=user).save()
with no_dereference(Group) as NoDeRefGroup:
self.assertTrue(Group._fields['members']._auto_dereference)
self.assertFalse(NoDeRefGroup._fields['members']._auto_dereference)
with no_dereference(Group) as Group:
group = Group.objects.first()
self.assertTrue(all([not isinstance(m, User)
for m in group.members]))
self.assertFalse(isinstance(group.ref, User))
self.assertFalse(isinstance(group.generic, User))
self.assertTrue(all([isinstance(m, User)
for m in group.members]))
self.assertTrue(isinstance(group.ref, User))
self.assertTrue(isinstance(group.generic, User))
def test_no_dereference_context_manager_dbref(self):
"""Ensure that DBRef items in ListFields aren't dereferenced.
"""
class User(Document):
name = StringField()
class Group(Document):
ref = ReferenceField(User, dbref=True)
generic = GenericReferenceField()
members = ListField(ReferenceField(User, dbref=True))
User.drop_collection()
Group.drop_collection()
for i in xrange(1, 51):
User(name='user %s' % i).save()
user = User.objects.first()
Group(ref=user, members=User.objects, generic=user).save()
with no_dereference(Group) as NoDeRefGroup:
self.assertTrue(Group._fields['members']._auto_dereference)
self.assertFalse(NoDeRefGroup._fields['members']._auto_dereference)
with no_dereference(Group) as Group:
group = Group.objects.first()
self.assertTrue(all([not isinstance(m, User)
for m in group.members]))
self.assertFalse(isinstance(group.ref, User))
self.assertFalse(isinstance(group.generic, User))
self.assertTrue(all([isinstance(m, User)
for m in group.members]))
self.assertTrue(isinstance(group.ref, User))
self.assertTrue(isinstance(group.generic, User))
if __name__ == '__main__':
unittest.main()