Added no_dereference context manager (#82)

Reorganised the context_managers as well
This commit is contained in:
Ross Lawley 2013-01-23 19:05:44 +00:00
parent 4f70c27b56
commit 3a6dc77d36
16 changed files with 289 additions and 128 deletions

View File

@ -7,7 +7,6 @@ Connecting
.. autofunction:: mongoengine.connect .. autofunction:: mongoengine.connect
.. autofunction:: mongoengine.register_connection .. autofunction:: mongoengine.register_connection
.. autoclass:: mongoengine.SwitchDB
Documents Documents
========= =========
@ -35,6 +34,13 @@ Documents
.. autoclass:: mongoengine.ValidationError .. autoclass:: mongoengine.ValidationError
:members: :members:
Context Managers
================
.. autoclass:: mongoengine.context_managers.switch_db
.. autoclass:: mongoengine.context_managers.no_dereference
.. autoclass:: mongoengine.context_managers.query_counter
Querying Querying
======== ========

View File

@ -32,8 +32,9 @@ Changes in 0.8.X
- Fixed inheritance and unique index creation (#140) - Fixed inheritance and unique index creation (#140)
- Fixed reverse delete rule with inheritance (#197) - Fixed reverse delete rule with inheritance (#197)
- Fixed validation for GenericReferences which havent been dereferenced - Fixed validation for GenericReferences which havent been dereferenced
- Added SwitchDB context manager (#106) - Added switch_db context manager (#106)
- Added switch_db method to document instances (#106) - Added switch_db method to document instances (#106)
- Added no_dereference context manager (#82)
Changes in 0.7.9 Changes in 0.7.9
================ ================

View File

@ -75,15 +75,15 @@ Switch Database Context Manager
=============================== ===============================
Sometimes you might want to switch the database to query against for a class. Sometimes you might want to switch the database to query against for a class.
The SwitchDB context manager allows you to change the database alias for a The :class:`~mongoengine.context_managers.switch_db` context manager allows
class eg :: you to change the database alias for a class eg ::
from mongoengine import SwitchDB from mongoengine.context_managers import switch_db
class User(Document): class User(Document):
name = StringField() name = StringField()
meta = {"db_alias": "user-db"} meta = {"db_alias": "user-db"}
with SwitchDB(User, 'archive-user-db') as User: with switch_db(User, 'archive-user-db') as User:
User(name="Ross").save() # Saves the 'archive-user-db' User(name="Ross").save() # Saves the 'archive-user-db'

View File

@ -369,6 +369,22 @@ references to the depth of 1 level. If you have more complicated documents and
want to dereference more of the object at once then increasing the :attr:`max_depth` want to dereference more of the object at once then increasing the :attr:`max_depth`
will dereference more levels of the document. will dereference more levels of the document.
Turning off dereferencing
-------------------------
Sometimes for performance reasons you don't want to automatically dereference
data . To turn off all dereferencing you can use the
:class:`~mongoengine.context_managers.no_dereference` context manager::
with no_dereference(Post) as Post:
post = Post.objects.first()
assert(isinstance(post.author, ObjectId))
.. note::
:class:`~mongoengine.context_managers.no_dereference` only works on the
Default QuerySet manager.
Advanced queries Advanced queries
================ ================
Sometimes calling a :class:`~mongoengine.queryset.QuerySet` object with keyword Sometimes calling a :class:`~mongoengine.queryset.QuerySet` object with keyword

View File

@ -23,6 +23,7 @@ class BaseField(object):
name = None name = None
_geo_index = False _geo_index = False
_auto_gen = False # Call `generate` to generate a value _auto_gen = False # Call `generate` to generate a value
_auto_dereference = True
# These track each time a Field instance is created. Used to retain order. # These track each time a Field instance is created. Used to retain order.
# The auto_creation_counter is used for fields that MongoEngine implicitly # The auto_creation_counter is used for fields that MongoEngine implicitly
@ -163,9 +164,11 @@ class ComplexBaseField(BaseField):
ReferenceField = _import_class('ReferenceField') ReferenceField = _import_class('ReferenceField')
GenericReferenceField = _import_class('GenericReferenceField') GenericReferenceField = _import_class('GenericReferenceField')
dereference = self.field is None or isinstance(self.field, dereference = (self._auto_dereference and
(GenericReferenceField, ReferenceField)) (self.field is None or isinstance(self.field,
if not self._dereference and instance._initialised and dereference: (GenericReferenceField, ReferenceField))))
if not self.__dereference and instance._initialised and dereference:
instance._data[self.name] = self._dereference( instance._data[self.name] = self._dereference(
instance._data.get(self.name), max_depth=1, instance=instance, instance._data.get(self.name), max_depth=1, instance=instance,
name=self.name name=self.name
@ -182,7 +185,8 @@ class ComplexBaseField(BaseField):
value = BaseDict(value, instance, self.name) value = BaseDict(value, instance, self.name)
instance._data[self.name] = value instance._data[self.name] = value
if (instance._initialised and isinstance(value, (BaseList, BaseDict)) if (self._auto_dereference and instance._initialised and
isinstance(value, (BaseList, BaseDict))
and not value._dereferenced): and not value._dereferenced):
value = self._dereference( value = self._dereference(
value, max_depth=1, instance=instance, name=self.name value, max_depth=1, instance=instance, name=self.name

View File

@ -11,7 +11,7 @@ def _import_class(cls_name):
field_classes = ('DictField', 'DynamicField', 'EmbeddedDocumentField', field_classes = ('DictField', 'DynamicField', 'EmbeddedDocumentField',
'FileField', 'GenericReferenceField', 'FileField', 'GenericReferenceField',
'GenericEmbeddedDocumentField', 'GeoPointField', 'GenericEmbeddedDocumentField', 'GeoPointField',
'ReferenceField', 'StringField') 'ReferenceField', 'StringField', 'ComplexBaseField')
queryset_classes = ('OperationError',) queryset_classes = ('OperationError',)
deref_classes = ('DeReference',) deref_classes = ('DeReference',)

View File

@ -3,7 +3,7 @@ from pymongo import Connection, ReplicaSetConnection, uri_parser
__all__ = ['ConnectionError', 'connect', 'register_connection', __all__ = ['ConnectionError', 'connect', 'register_connection',
'DEFAULT_CONNECTION_NAME', 'SwitchDB'] 'DEFAULT_CONNECTION_NAME']
DEFAULT_CONNECTION_NAME = 'default' DEFAULT_CONNECTION_NAME = 'default'
@ -164,47 +164,6 @@ def connect(db, alias=DEFAULT_CONNECTION_NAME, **kwargs):
return get_connection(alias) return get_connection(alias)
class SwitchDB(object):
""" SwitchDB alias context manager.
Example ::
# Register connections
register_connection('default', 'mongoenginetest')
register_connection('testdb-1', 'mongoenginetest2')
class Group(Document):
name = StringField()
Group(name="test").save() # Saves in the default db
with SwitchDB(Group, 'testdb-1') as Group:
Group(name="hello testdb!").save() # Saves in testdb-1
"""
def __init__(self, cls, db_alias):
""" Construct the SwitchDB context manager
:param cls: the class to change the registered db
:param db_alias: the name of the specific database to use
"""
self.cls = cls
self.collection = cls._get_collection()
self.db_alias = db_alias
self.ori_db_alias = cls._meta.get("db_alias", DEFAULT_CONNECTION_NAME)
def __enter__(self):
""" change the db_alias and clear the cached collection """
self.cls._meta["db_alias"] = self.db_alias
self.cls._collection = None
return self.cls
def __exit__(self, t, value, traceback):
""" Reset the db_alias and collection """
self.cls._meta["db_alias"] = self.ori_db_alias
self.cls._collection = self.collection
# Support old naming convention # Support old naming convention
_get_connection = get_connection _get_connection = get_connection
_get_db = get_db _get_db = get_db

View File

@ -0,0 +1,159 @@
from mongoengine.common import _import_class
from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db
from mongoengine.queryset import OperationError, QuerySet
__all__ = ("switch_db", "no_dereference", "query_counter")
class switch_db(object):
""" switch_db alias context manager.
Example ::
# Register connections
register_connection('default', 'mongoenginetest')
register_connection('testdb-1', 'mongoenginetest2')
class Group(Document):
name = StringField()
Group(name="test").save() # Saves in the default db
with switch_db(Group, 'testdb-1') as Group:
Group(name="hello testdb!").save() # Saves in testdb-1
"""
def __init__(self, cls, db_alias):
""" Construct the switch_db context manager
:param cls: the class to change the registered db
:param db_alias: the name of the specific database to use
"""
self.cls = cls
self.collection = cls._get_collection()
self.db_alias = db_alias
self.ori_db_alias = cls._meta.get("db_alias", DEFAULT_CONNECTION_NAME)
def __enter__(self):
""" change the db_alias and clear the cached collection """
self.cls._meta["db_alias"] = self.db_alias
self.cls._collection = None
return self.cls
def __exit__(self, t, value, traceback):
""" Reset the db_alias and collection """
self.cls._meta["db_alias"] = self.ori_db_alias
self.cls._collection = self.collection
class no_dereference(object):
""" no_dereference context manager.
Turns off all dereferencing in Documents::
with no_dereference(Group) as Group:
Group.objects.find()
"""
def __init__(self, cls):
""" Construct the no_dereference context manager.
:param cls: the class to turn dereferencing off on
"""
self.cls = cls
ReferenceField = _import_class('ReferenceField')
GenericReferenceField = _import_class('GenericReferenceField')
ComplexBaseField = _import_class('ComplexBaseField')
self.deref_fields = [k for k, v in self.cls._fields.iteritems()
if isinstance(v, (ReferenceField,
GenericReferenceField,
ComplexBaseField))]
def __enter__(self):
""" change the objects default and _auto_dereference values"""
if 'queryset_class' in self.cls._meta:
raise OperationError("no_dereference context manager only works on"
" default queryset classes")
objects = self.cls.__dict__['objects']
objects.default = QuerySetNoDeRef
self.cls.objects = objects
for field in self.deref_fields:
self.cls._fields[field]._auto_dereference = False
return self.cls
def __exit__(self, t, value, traceback):
""" Reset the default and _auto_dereference values"""
objects = self.cls.__dict__['objects']
objects.default = QuerySet
self.cls.objects = objects
for field in self.deref_fields:
self.cls._fields[field]._auto_dereference = True
return self.cls
class QuerySetNoDeRef(QuerySet):
"""Special no_dereference QuerySet"""
def __dereference(items, max_depth=1, instance=None, name=None):
return items
class query_counter(object):
""" Query_counter contextmanager to get the number of queries. """
def __init__(self):
""" Construct the query_counter. """
self.counter = 0
self.db = get_db()
def __enter__(self):
""" On every with block we need to drop the profile collection. """
self.db.set_profiling_level(0)
self.db.system.profile.drop()
self.db.set_profiling_level(2)
return self
def __exit__(self, t, value, traceback):
""" Reset the profiling level. """
self.db.set_profiling_level(0)
def __eq__(self, value):
""" == Compare querycounter. """
return value == self._get_count()
def __ne__(self, value):
""" != Compare querycounter. """
return not self.__eq__(value)
def __lt__(self, value):
""" < Compare querycounter. """
return self._get_count() < value
def __le__(self, value):
""" <= Compare querycounter. """
return self._get_count() <= value
def __gt__(self, value):
""" > Compare querycounter. """
return self._get_count() > value
def __ge__(self, value):
""" >= Compare querycounter. """
return self._get_count() >= value
def __int__(self):
""" int representation. """
return self._get_count()
def __repr__(self):
""" repr query_counter as the number of queries. """
return u"%s" % self._get_count()
def _get_count(self):
""" Get the number of queries. """
count = self.db.system.profile.find().count() - self.counter
self.counter += 1
return count

View File

@ -1,15 +1,17 @@
from __future__ import with_statement
import warnings import warnings
import pymongo import pymongo
import re import re
from bson.dbref import DBRef from bson.dbref import DBRef
from mongoengine import signals, queryset from mongoengine import signals
from mongoengine.base import (DocumentMetaclass, TopLevelDocumentMetaclass,
from base import (DocumentMetaclass, TopLevelDocumentMetaclass, BaseDocument, BaseDocument, BaseDict, BaseList,
BaseDict, BaseList, ALLOW_INHERITANCE, get_document) ALLOW_INHERITANCE, get_document)
from queryset import OperationError, NotUniqueError from mongoengine.queryset import OperationError, NotUniqueError
from connection import get_db, DEFAULT_CONNECTION_NAME, SwitchDB from mongoengine.connection import get_db, DEFAULT_CONNECTION_NAME
from mongoengine.context_managers import switch_db
__all__ = ('Document', 'EmbeddedDocument', 'DynamicDocument', __all__ = ('Document', 'EmbeddedDocument', 'DynamicDocument',
'DynamicEmbeddedDocument', 'OperationError', 'DynamicEmbeddedDocument', 'OperationError',
@ -381,11 +383,11 @@ class Document(BaseDocument):
user.save() user.save()
If you need to read from another database see If you need to read from another database see
:class:`~mongoengine.SwitchDB` :class:`~mongoengine.context_managers.switch_db`
:param db_alias: The database alias to use for saving the document :param db_alias: The database alias to use for saving the document
""" """
with SwitchDB(self.__class__, db_alias) as cls: with switch_db(self.__class__, db_alias) as cls:
collection = cls._get_collection() collection = cls._get_collection()
db = cls._get_db db = cls._get_db
self._get_collection = lambda: collection self._get_collection = lambda: collection

View File

@ -779,7 +779,7 @@ class ReferenceField(BaseField):
value = instance._data.get(self.name) value = instance._data.get(self.name)
# Dereference DBRefs # Dereference DBRefs
if isinstance(value, DBRef): if self._auto_dereference and isinstance(value, DBRef):
value = self.document_type._get_db().dereference(value) value = self.document_type._get_db().dereference(value)
if value is not None: if value is not None:
instance._data[self.name] = self.document_type._from_son(value) instance._data[self.name] = self.document_type._from_son(value)

View File

@ -18,11 +18,11 @@ class QuerySetManager(object):
""" """
get_queryset = None get_queryset = None
default = QuerySet
def __init__(self, queryset_func=None): def __init__(self, queryset_func=None):
if queryset_func: if queryset_func:
self.get_queryset = queryset_func self.get_queryset = queryset_func
self._collections = {}
def __get__(self, instance, owner): def __get__(self, instance, owner):
"""Descriptor for instantiating a new QuerySet object when """Descriptor for instantiating a new QuerySet object when
@ -33,7 +33,7 @@ class QuerySetManager(object):
return self return self
# owner is the document that contains the QuerySetManager # owner is the document that contains the QuerySetManager
queryset_class = owner._meta.get('queryset_class') or QuerySet queryset_class = owner._meta.get('queryset_class', self.default)
queryset = queryset_class(owner, owner._get_collection()) queryset = queryset_class(owner, owner._get_collection())
if self.get_queryset: if self.get_queryset:
arg_count = self.get_queryset.func_code.co_argcount arg_count = self.get_queryset.func_code.co_argcount

View File

@ -109,7 +109,6 @@ class QuerySet(object):
queryset._class_check = class_check queryset._class_check = class_check
return queryset return queryset
def __iter__(self): def __iter__(self):
"""Support iterator protocol""" """Support iterator protocol"""
self.rewind() self.rewind()

View File

@ -1,59 +0,0 @@
from mongoengine.connection import get_db
class query_counter(object):
""" Query_counter contextmanager to get the number of queries. """
def __init__(self):
""" Construct the query_counter. """
self.counter = 0
self.db = get_db()
def __enter__(self):
""" On every with block we need to drop the profile collection. """
self.db.set_profiling_level(0)
self.db.system.profile.drop()
self.db.set_profiling_level(2)
return self
def __exit__(self, t, value, traceback):
""" Reset the profiling level. """
self.db.set_profiling_level(0)
def __eq__(self, value):
""" == Compare querycounter. """
return value == self._get_count()
def __ne__(self, value):
""" != Compare querycounter. """
return not self.__eq__(value)
def __lt__(self, value):
""" < Compare querycounter. """
return self._get_count() < value
def __le__(self, value):
""" <= Compare querycounter. """
return self._get_count() <= value
def __gt__(self, value):
""" > Compare querycounter. """
return self._get_count() > value
def __ge__(self, value):
""" >= Compare querycounter. """
return self._get_count() >= value
def __int__(self):
""" int representation. """
return self._get_count()
def __repr__(self):
""" repr query_counter as the number of queries. """
return u"%s" % self._get_count()
def _get_count(self):
""" Get the number of queries. """
count = self.db.system.profile.find().count() - self.counter
self.counter += 1
return count

View File

@ -17,7 +17,7 @@ from bson import ObjectId
from mongoengine import * from mongoengine import *
from mongoengine.connection import get_connection from mongoengine.connection import get_connection
from mongoengine.python_support import PY3 from mongoengine.python_support import PY3
from mongoengine.tests import query_counter from mongoengine.context_managers import query_counter
from mongoengine.queryset import (QuerySet, QuerySetManager, from mongoengine.queryset import (QuerySet, QuerySetManager,
MultipleObjectsReturned, DoesNotExist, MultipleObjectsReturned, DoesNotExist,
QueryFieldList, queryset_manager) QueryFieldList, queryset_manager)

View File

@ -1,3 +1,4 @@
from __future__ import with_statement
import datetime import datetime
import pymongo import pymongo
import unittest import unittest
@ -8,6 +9,7 @@ from bson.tz_util import utc
from mongoengine import * from mongoengine import *
from mongoengine.connection import get_db, get_connection, ConnectionError from mongoengine.connection import get_db, get_connection, ConnectionError
from mongoengine.context_managers import switch_db
class ConnectionTest(unittest.TestCase): class ConnectionTest(unittest.TestCase):
@ -105,7 +107,7 @@ class ConnectionTest(unittest.TestCase):
Group(name="hello - default").save() Group(name="hello - default").save()
self.assertEqual(1, Group.objects.count()) self.assertEqual(1, Group.objects.count())
with SwitchDB(Group, 'testdb-1') as Group: with switch_db(Group, 'testdb-1') as Group:
self.assertEqual(0, Group.objects.count()) self.assertEqual(0, Group.objects.count())

View File

@ -8,7 +8,7 @@ from bson import DBRef, ObjectId
from mongoengine import * from mongoengine import *
from mongoengine.connection import get_db from mongoengine.connection import get_db
from mongoengine.tests import query_counter from mongoengine.context_managers import query_counter, no_dereference
class FieldTest(unittest.TestCase): class FieldTest(unittest.TestCase):
@ -1121,5 +1121,77 @@ class FieldTest(unittest.TestCase):
self.assertEqual(q, 2) self.assertEqual(q, 2)
def test_no_dereference_context_manager_object_id(self):
"""Ensure that DBRef items in ListFields aren't dereferenced.
"""
class User(Document):
name = StringField()
class Group(Document):
ref = ReferenceField(User, dbref=False)
generic = GenericReferenceField()
members = ListField(ReferenceField(User, dbref=False))
User.drop_collection()
Group.drop_collection()
for i in xrange(1, 51):
User(name='user %s' % i).save()
user = User.objects.first()
Group(ref=user, members=User.objects, generic=user).save()
with no_dereference(Group) as NoDeRefGroup:
self.assertTrue(Group._fields['members']._auto_dereference)
self.assertFalse(NoDeRefGroup._fields['members']._auto_dereference)
with no_dereference(Group) as Group:
group = Group.objects.first()
self.assertTrue(all([not isinstance(m, User)
for m in group.members]))
self.assertFalse(isinstance(group.ref, User))
self.assertFalse(isinstance(group.generic, User))
self.assertTrue(all([isinstance(m, User)
for m in group.members]))
self.assertTrue(isinstance(group.ref, User))
self.assertTrue(isinstance(group.generic, User))
def test_no_dereference_context_manager_dbref(self):
"""Ensure that DBRef items in ListFields aren't dereferenced.
"""
class User(Document):
name = StringField()
class Group(Document):
ref = ReferenceField(User, dbref=True)
generic = GenericReferenceField()
members = ListField(ReferenceField(User, dbref=True))
User.drop_collection()
Group.drop_collection()
for i in xrange(1, 51):
User(name='user %s' % i).save()
user = User.objects.first()
Group(ref=user, members=User.objects, generic=user).save()
with no_dereference(Group) as NoDeRefGroup:
self.assertTrue(Group._fields['members']._auto_dereference)
self.assertFalse(NoDeRefGroup._fields['members']._auto_dereference)
with no_dereference(Group) as Group:
group = Group.objects.first()
self.assertTrue(all([not isinstance(m, User)
for m in group.members]))
self.assertFalse(isinstance(group.ref, User))
self.assertFalse(isinstance(group.generic, User))
self.assertTrue(all([isinstance(m, User)
for m in group.members]))
self.assertTrue(isinstance(group.ref, User))
self.assertTrue(isinstance(group.generic, User))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()