Added no_dereference context manager (#82)
Reorganised the context_managers as well
This commit is contained in:
		| @@ -23,6 +23,7 @@ class BaseField(object): | ||||
|     name = None | ||||
|     _geo_index = False | ||||
|     _auto_gen = False  # Call `generate` to generate a value | ||||
|     _auto_dereference = True | ||||
|  | ||||
|     # These track each time a Field instance is created. Used to retain order. | ||||
|     # The auto_creation_counter is used for fields that MongoEngine implicitly | ||||
| @@ -163,9 +164,11 @@ class ComplexBaseField(BaseField): | ||||
|  | ||||
|         ReferenceField = _import_class('ReferenceField') | ||||
|         GenericReferenceField = _import_class('GenericReferenceField') | ||||
|         dereference = self.field is None or isinstance(self.field, | ||||
|             (GenericReferenceField, ReferenceField)) | ||||
|         if not self._dereference and instance._initialised and dereference: | ||||
|         dereference = (self._auto_dereference and | ||||
|                        (self.field is None or isinstance(self.field, | ||||
|                         (GenericReferenceField, ReferenceField)))) | ||||
|  | ||||
|         if not self.__dereference and instance._initialised and dereference: | ||||
|             instance._data[self.name] = self._dereference( | ||||
|                 instance._data.get(self.name), max_depth=1, instance=instance, | ||||
|                 name=self.name | ||||
| @@ -182,7 +185,8 @@ class ComplexBaseField(BaseField): | ||||
|             value = BaseDict(value, instance, self.name) | ||||
|             instance._data[self.name] = value | ||||
|  | ||||
|         if (instance._initialised and isinstance(value, (BaseList, BaseDict)) | ||||
|         if (self._auto_dereference and instance._initialised and | ||||
|             isinstance(value, (BaseList, BaseDict)) | ||||
|             and not value._dereferenced): | ||||
|             value = self._dereference( | ||||
|                 value, max_depth=1, instance=instance, name=self.name | ||||
|   | ||||
| @@ -11,7 +11,7 @@ def _import_class(cls_name): | ||||
|     field_classes = ('DictField', 'DynamicField', 'EmbeddedDocumentField', | ||||
|                      'FileField', 'GenericReferenceField', | ||||
|                      'GenericEmbeddedDocumentField', 'GeoPointField', | ||||
|                      'ReferenceField', 'StringField') | ||||
|                      'ReferenceField', 'StringField', 'ComplexBaseField') | ||||
|     queryset_classes = ('OperationError',) | ||||
|     deref_classes = ('DeReference',) | ||||
|  | ||||
|   | ||||
| @@ -3,7 +3,7 @@ from pymongo import Connection, ReplicaSetConnection, uri_parser | ||||
|  | ||||
|  | ||||
| __all__ = ['ConnectionError', 'connect', 'register_connection', | ||||
|            'DEFAULT_CONNECTION_NAME', 'SwitchDB'] | ||||
|            'DEFAULT_CONNECTION_NAME'] | ||||
|  | ||||
|  | ||||
| DEFAULT_CONNECTION_NAME = 'default' | ||||
| @@ -164,47 +164,6 @@ def connect(db, alias=DEFAULT_CONNECTION_NAME, **kwargs): | ||||
|     return get_connection(alias) | ||||
|  | ||||
|  | ||||
| class SwitchDB(object): | ||||
|     """ SwitchDB alias context manager. | ||||
|  | ||||
|     Example :: | ||||
|  | ||||
|         # Register connections | ||||
|         register_connection('default', 'mongoenginetest') | ||||
|         register_connection('testdb-1', 'mongoenginetest2') | ||||
|  | ||||
|         class Group(Document): | ||||
|             name = StringField() | ||||
|  | ||||
|         Group(name="test").save()  # Saves in the default db | ||||
|  | ||||
|         with SwitchDB(Group, 'testdb-1') as Group: | ||||
|             Group(name="hello testdb!").save()  # Saves in testdb-1 | ||||
|  | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, cls, db_alias): | ||||
|         """ Construct the SwitchDB context manager | ||||
|  | ||||
|         :param cls: the class to change the registered db | ||||
|         :param db_alias: the name of the specific database to use | ||||
|         """ | ||||
|         self.cls = cls | ||||
|         self.collection = cls._get_collection() | ||||
|         self.db_alias = db_alias | ||||
|         self.ori_db_alias = cls._meta.get("db_alias", DEFAULT_CONNECTION_NAME) | ||||
|  | ||||
|     def __enter__(self): | ||||
|         """ change the db_alias and clear the cached collection """ | ||||
|         self.cls._meta["db_alias"] = self.db_alias | ||||
|         self.cls._collection = None | ||||
|         return self.cls | ||||
|  | ||||
|     def __exit__(self, t, value, traceback): | ||||
|         """ Reset the db_alias and collection """ | ||||
|         self.cls._meta["db_alias"] = self.ori_db_alias | ||||
|         self.cls._collection = self.collection | ||||
|  | ||||
| # Support old naming convention | ||||
| _get_connection = get_connection | ||||
| _get_db = get_db | ||||
|   | ||||
							
								
								
									
										159
									
								
								mongoengine/context_managers.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										159
									
								
								mongoengine/context_managers.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,159 @@ | ||||
| from mongoengine.common import _import_class | ||||
| from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db | ||||
| from mongoengine.queryset import OperationError, QuerySet | ||||
|  | ||||
| __all__ = ("switch_db", "no_dereference", "query_counter") | ||||
|  | ||||
|  | ||||
| class switch_db(object): | ||||
|     """ switch_db alias context manager. | ||||
|  | ||||
|     Example :: | ||||
|  | ||||
|         # Register connections | ||||
|         register_connection('default', 'mongoenginetest') | ||||
|         register_connection('testdb-1', 'mongoenginetest2') | ||||
|  | ||||
|         class Group(Document): | ||||
|             name = StringField() | ||||
|  | ||||
|         Group(name="test").save()  # Saves in the default db | ||||
|  | ||||
|         with switch_db(Group, 'testdb-1') as Group: | ||||
|             Group(name="hello testdb!").save()  # Saves in testdb-1 | ||||
|  | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, cls, db_alias): | ||||
|         """ Construct the switch_db context manager | ||||
|  | ||||
|         :param cls: the class to change the registered db | ||||
|         :param db_alias: the name of the specific database to use | ||||
|         """ | ||||
|         self.cls = cls | ||||
|         self.collection = cls._get_collection() | ||||
|         self.db_alias = db_alias | ||||
|         self.ori_db_alias = cls._meta.get("db_alias", DEFAULT_CONNECTION_NAME) | ||||
|  | ||||
|     def __enter__(self): | ||||
|         """ change the db_alias and clear the cached collection """ | ||||
|         self.cls._meta["db_alias"] = self.db_alias | ||||
|         self.cls._collection = None | ||||
|         return self.cls | ||||
|  | ||||
|     def __exit__(self, t, value, traceback): | ||||
|         """ Reset the db_alias and collection """ | ||||
|         self.cls._meta["db_alias"] = self.ori_db_alias | ||||
|         self.cls._collection = self.collection | ||||
|  | ||||
|  | ||||
| class no_dereference(object): | ||||
|     """ no_dereference context manager. | ||||
|  | ||||
|     Turns off all dereferencing in Documents:: | ||||
|  | ||||
|         with no_dereference(Group) as Group: | ||||
|             Group.objects.find() | ||||
|  | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, cls): | ||||
|         """ Construct the no_dereference context manager. | ||||
|  | ||||
|         :param cls: the class to turn dereferencing off on | ||||
|         """ | ||||
|         self.cls = cls | ||||
|  | ||||
|         ReferenceField = _import_class('ReferenceField') | ||||
|         GenericReferenceField = _import_class('GenericReferenceField') | ||||
|         ComplexBaseField = _import_class('ComplexBaseField') | ||||
|  | ||||
|         self.deref_fields = [k for k, v in self.cls._fields.iteritems() | ||||
|                              if isinstance(v, (ReferenceField, | ||||
|                                                GenericReferenceField, | ||||
|                                                ComplexBaseField))] | ||||
|  | ||||
|     def __enter__(self): | ||||
|         """ change the objects default and _auto_dereference values""" | ||||
|         if 'queryset_class' in self.cls._meta: | ||||
|             raise OperationError("no_dereference context manager only works on" | ||||
|                                  " default queryset classes") | ||||
|         objects = self.cls.__dict__['objects'] | ||||
|         objects.default = QuerySetNoDeRef | ||||
|         self.cls.objects = objects | ||||
|         for field in self.deref_fields: | ||||
|             self.cls._fields[field]._auto_dereference = False | ||||
|         return self.cls | ||||
|  | ||||
|     def __exit__(self, t, value, traceback): | ||||
|         """ Reset the default and _auto_dereference values""" | ||||
|         objects = self.cls.__dict__['objects'] | ||||
|         objects.default = QuerySet | ||||
|         self.cls.objects = objects | ||||
|         for field in self.deref_fields: | ||||
|             self.cls._fields[field]._auto_dereference = True | ||||
|         return self.cls | ||||
|  | ||||
|  | ||||
| class QuerySetNoDeRef(QuerySet): | ||||
|     """Special no_dereference QuerySet""" | ||||
|     def __dereference(items, max_depth=1, instance=None, name=None): | ||||
|             return items | ||||
|  | ||||
|  | ||||
| class query_counter(object): | ||||
|     """ Query_counter contextmanager to get the number of queries. """ | ||||
|  | ||||
|     def __init__(self): | ||||
|         """ Construct the query_counter. """ | ||||
|         self.counter = 0 | ||||
|         self.db = get_db() | ||||
|  | ||||
|     def __enter__(self): | ||||
|         """ On every with block we need to drop the profile collection. """ | ||||
|         self.db.set_profiling_level(0) | ||||
|         self.db.system.profile.drop() | ||||
|         self.db.set_profiling_level(2) | ||||
|         return self | ||||
|  | ||||
|     def __exit__(self, t, value, traceback): | ||||
|         """ Reset the profiling level. """ | ||||
|         self.db.set_profiling_level(0) | ||||
|  | ||||
|     def __eq__(self, value): | ||||
|         """ == Compare querycounter. """ | ||||
|         return value == self._get_count() | ||||
|  | ||||
|     def __ne__(self, value): | ||||
|         """ != Compare querycounter. """ | ||||
|         return not self.__eq__(value) | ||||
|  | ||||
|     def __lt__(self, value): | ||||
|         """ < Compare querycounter. """ | ||||
|         return self._get_count() < value | ||||
|  | ||||
|     def __le__(self, value): | ||||
|         """ <= Compare querycounter. """ | ||||
|         return self._get_count() <= value | ||||
|  | ||||
|     def __gt__(self, value): | ||||
|         """ > Compare querycounter. """ | ||||
|         return self._get_count() > value | ||||
|  | ||||
|     def __ge__(self, value): | ||||
|         """ >= Compare querycounter. """ | ||||
|         return self._get_count() >= value | ||||
|  | ||||
|     def __int__(self): | ||||
|         """ int representation. """ | ||||
|         return self._get_count() | ||||
|  | ||||
|     def __repr__(self): | ||||
|         """ repr query_counter as the number of queries. """ | ||||
|         return u"%s" % self._get_count() | ||||
|  | ||||
|     def _get_count(self): | ||||
|         """ Get the number of queries. """ | ||||
|         count = self.db.system.profile.find().count() - self.counter | ||||
|         self.counter += 1 | ||||
|         return count | ||||
| @@ -1,15 +1,17 @@ | ||||
| from __future__ import with_statement | ||||
| import warnings | ||||
|  | ||||
| import pymongo | ||||
| import re | ||||
|  | ||||
| from bson.dbref import DBRef | ||||
| from mongoengine import signals, queryset | ||||
|  | ||||
| from base import (DocumentMetaclass, TopLevelDocumentMetaclass, BaseDocument, | ||||
|                   BaseDict, BaseList, ALLOW_INHERITANCE, get_document) | ||||
| from queryset import OperationError, NotUniqueError | ||||
| from connection import get_db, DEFAULT_CONNECTION_NAME, SwitchDB | ||||
| from mongoengine import signals | ||||
| from mongoengine.base import (DocumentMetaclass, TopLevelDocumentMetaclass, | ||||
|                               BaseDocument, BaseDict, BaseList, | ||||
|                               ALLOW_INHERITANCE, get_document) | ||||
| from mongoengine.queryset import OperationError, NotUniqueError | ||||
| from mongoengine.connection import get_db, DEFAULT_CONNECTION_NAME | ||||
| from mongoengine.context_managers import switch_db | ||||
|  | ||||
| __all__ = ('Document', 'EmbeddedDocument', 'DynamicDocument', | ||||
|            'DynamicEmbeddedDocument', 'OperationError', | ||||
| @@ -381,11 +383,11 @@ class Document(BaseDocument): | ||||
|             user.save() | ||||
|  | ||||
|         If you need to read from another database see | ||||
|         :class:`~mongoengine.SwitchDB` | ||||
|         :class:`~mongoengine.context_managers.switch_db` | ||||
|  | ||||
|         :param db_alias: The database alias to use for saving the document | ||||
|         """ | ||||
|         with SwitchDB(self.__class__, db_alias) as cls: | ||||
|         with switch_db(self.__class__, db_alias) as cls: | ||||
|             collection = cls._get_collection() | ||||
|             db = cls._get_db | ||||
|         self._get_collection = lambda: collection | ||||
|   | ||||
| @@ -779,7 +779,7 @@ class ReferenceField(BaseField): | ||||
|         value = instance._data.get(self.name) | ||||
|  | ||||
|         # Dereference DBRefs | ||||
|         if isinstance(value, DBRef): | ||||
|         if self._auto_dereference and isinstance(value, DBRef): | ||||
|             value = self.document_type._get_db().dereference(value) | ||||
|             if value is not None: | ||||
|                 instance._data[self.name] = self.document_type._from_son(value) | ||||
|   | ||||
| @@ -18,11 +18,11 @@ class QuerySetManager(object): | ||||
|     """ | ||||
|  | ||||
|     get_queryset = None | ||||
|     default = QuerySet | ||||
|  | ||||
|     def __init__(self, queryset_func=None): | ||||
|         if queryset_func: | ||||
|             self.get_queryset = queryset_func | ||||
|         self._collections = {} | ||||
|  | ||||
|     def __get__(self, instance, owner): | ||||
|         """Descriptor for instantiating a new QuerySet object when | ||||
| @@ -33,7 +33,7 @@ class QuerySetManager(object): | ||||
|             return self | ||||
|  | ||||
|         # owner is the document that contains the QuerySetManager | ||||
|         queryset_class = owner._meta.get('queryset_class') or QuerySet | ||||
|         queryset_class = owner._meta.get('queryset_class', self.default) | ||||
|         queryset = queryset_class(owner, owner._get_collection()) | ||||
|         if self.get_queryset: | ||||
|             arg_count = self.get_queryset.func_code.co_argcount | ||||
|   | ||||
| @@ -109,7 +109,6 @@ class QuerySet(object): | ||||
|         queryset._class_check = class_check | ||||
|         return queryset | ||||
|  | ||||
|  | ||||
|     def __iter__(self): | ||||
|         """Support iterator protocol""" | ||||
|         self.rewind() | ||||
|   | ||||
| @@ -1,59 +0,0 @@ | ||||
| from mongoengine.connection import get_db | ||||
|  | ||||
|  | ||||
| class query_counter(object): | ||||
|     """ Query_counter contextmanager to get the number of queries. """ | ||||
|  | ||||
|     def __init__(self): | ||||
|         """ Construct the query_counter. """ | ||||
|         self.counter = 0 | ||||
|         self.db = get_db() | ||||
|  | ||||
|     def __enter__(self): | ||||
|         """ On every with block we need to drop the profile collection. """ | ||||
|         self.db.set_profiling_level(0) | ||||
|         self.db.system.profile.drop() | ||||
|         self.db.set_profiling_level(2) | ||||
|         return self | ||||
|  | ||||
|     def __exit__(self, t, value, traceback): | ||||
|         """ Reset the profiling level. """ | ||||
|         self.db.set_profiling_level(0) | ||||
|  | ||||
|     def __eq__(self, value): | ||||
|         """ == Compare querycounter. """ | ||||
|         return value == self._get_count() | ||||
|  | ||||
|     def __ne__(self, value): | ||||
|         """ != Compare querycounter. """ | ||||
|         return not self.__eq__(value) | ||||
|  | ||||
|     def __lt__(self, value): | ||||
|         """ < Compare querycounter. """ | ||||
|         return self._get_count() < value | ||||
|  | ||||
|     def __le__(self, value): | ||||
|         """ <= Compare querycounter. """ | ||||
|         return self._get_count() <= value | ||||
|  | ||||
|     def __gt__(self, value): | ||||
|         """ > Compare querycounter. """ | ||||
|         return self._get_count() > value | ||||
|  | ||||
|     def __ge__(self, value): | ||||
|         """ >= Compare querycounter. """ | ||||
|         return self._get_count() >= value | ||||
|  | ||||
|     def __int__(self): | ||||
|         """ int representation. """ | ||||
|         return self._get_count() | ||||
|  | ||||
|     def __repr__(self): | ||||
|         """ repr query_counter as the number of queries. """ | ||||
|         return u"%s" % self._get_count() | ||||
|  | ||||
|     def _get_count(self): | ||||
|         """ Get the number of queries. """ | ||||
|         count = self.db.system.profile.find().count() - self.counter | ||||
|         self.counter += 1 | ||||
|         return count | ||||
		Reference in New Issue
	
	Block a user