From adb60ef1acc8a8baf3476ce49a44c5ec73335fd2 Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Mon, 24 Sep 2012 18:45:02 +0000 Subject: [PATCH] Improved import cache --- mongoengine/base.py | 61 +++++++++++++++++++++++++++++++-------------- 1 file changed, 42 insertions(+), 19 deletions(-) diff --git a/mongoengine/base.py b/mongoengine/base.py index 6e4cd917..773c1d4c 100644 --- a/mongoengine/base.py +++ b/mongoengine/base.py @@ -235,7 +235,8 @@ class BaseField(object): pass def _validate(self, value): - from mongoengine import Document, EmbeddedDocument + Document = _import_class('Document') + EmbeddedDocument = _import_class('EmbeddedDocument') # check choices if self.choices: is_cls = isinstance(value, (Document, EmbeddedDocument)) @@ -283,7 +284,9 @@ class ComplexBaseField(BaseField): if instance is None: # Document class being used rather than a document object return self - from fields import GenericReferenceField, ReferenceField + + ReferenceField = _import_class('ReferenceField') + GenericReferenceField = _import_class('GenericReferenceField') dereference = self.field is None or isinstance(self.field, (GenericReferenceField, ReferenceField)) if not self._dereference and instance._initialised and dereference: @@ -310,6 +313,7 @@ class ComplexBaseField(BaseField): ) value._dereferenced = True instance._data[self.name] = value + return value def __set__(self, instance, value): @@ -321,7 +325,7 @@ class ComplexBaseField(BaseField): def to_python(self, value): """Convert a MongoDB-compatible type to a Python type. """ - from mongoengine import Document + Document = _import_class('Document') if isinstance(value, basestring): return value @@ -363,7 +367,7 @@ class ComplexBaseField(BaseField): def to_mongo(self, value): """Convert a Python type to a MongoDB-compatible type. """ - from mongoengine import Document + Document = _import_class("Document") if isinstance(value, basestring): return value @@ -399,7 +403,7 @@ class ComplexBaseField(BaseField): meta.get('allow_inheritance', ALLOW_INHERITANCE) == False) if allow_inheritance and not self.field: - from fields import GenericReferenceField + GenericReferenceField = _import_class("GenericReferenceField") value_dict[k] = GenericReferenceField().to_mongo(v) else: collection = v._get_collection_name() @@ -460,7 +464,7 @@ class ComplexBaseField(BaseField): @property def _dereference(self,): if not self.__dereference: - from dereference import DeReference + DeReference = _import_class("DeReference") self.__dereference = DeReference() # Cached return self.__dereference @@ -943,7 +947,7 @@ class BaseDocument(object): field = None if not hasattr(self, name) and not name.startswith('_'): - from fields import DynamicField + DynamicField = _import_class("DynamicField") field = DynamicField(db_field=name) field.name = name self._dynamic_fields[name] = field @@ -1121,7 +1125,8 @@ class BaseDocument(object): def _get_changed_fields(self, key='', inspected=None): """Returns a list of all fields that have explicitly been changed. """ - from mongoengine import EmbeddedDocument, DynamicEmbeddedDocument + EmbeddedDocument = _import_class("EmbeddedDocument") + DynamicEmbeddedDocument = _import_class("DynamicEmbeddedDocument") _changed_fields = [] _changed_fields += getattr(self, '_changed_fields', []) @@ -1252,7 +1257,9 @@ class BaseDocument(object): geo_indices = [] inspected.append(cls) - from fields import EmbeddedDocumentField, GeoPointField + EmbeddedDocumentField = _import_class("EmbeddedDocumentField") + GeoPointField = _import_class("GeoPointField") + for field in cls._fields.values(): if not isinstance(field, (EmbeddedDocumentField, GeoPointField)): continue @@ -1486,14 +1493,30 @@ def _import_class(cls_name): """Cached mechanism for imports""" if cls_name in _class_registry: return _class_registry.get(cls_name) - if cls_name == 'Document': - from mongoengine.document import Document as cls - elif cls_name == 'EmbeddedDocument': - from mongoengine.document import EmbeddedDocument as cls - elif cls_name == 'DictField': - from mongoengine.fields import DictField as cls - elif cls_name == 'OperationError': - from queryset import OperationError as cls - _class_registry[cls_name] = cls - return cls + doc_classes = ['Document', 'DynamicEmbeddedDocument', 'EmbeddedDocument'] + field_classes = ['DictField', 'DynamicField', 'EmbeddedDocumentField', + 'GenericReferenceField', 'GeoPointField', + 'ReferenceField'] + queryset_classes = ['OperationError'] + deref_classes = ['DeReference'] + + if cls_name in doc_classes: + from mongoengine import document as module + import_classes = doc_classes + elif cls_name in field_classes: + from mongoengine import fields as module + import_classes = field_classes + elif cls_name in queryset_classes: + from mongoengine import queryset as module + import_classes = queryset_classes + elif cls_name in deref_classes: + from mongoengine import dereference as module + import_classes = deref_classes + else: + raise ValueError('No import set for: ' % cls_name) + + for cls in import_classes: + _class_registry[cls] = getattr(module, cls) + + return _class_registry.get(cls_name)