Merge pull request #1704 from touilleMan/lazyref-improvements
Improve LazyReferenceField and GenericLazyReferenceField with nested …
This commit is contained in:
		| @@ -13,6 +13,7 @@ from mongoengine import signals | ||||
| from mongoengine.base.common import get_document | ||||
| from mongoengine.base.datastructures import (BaseDict, BaseList, | ||||
|                                              EmbeddedDocumentList, | ||||
|                                              LazyReference, | ||||
|                                              StrictDict) | ||||
| from mongoengine.base.fields import ComplexBaseField | ||||
| from mongoengine.common import _import_class | ||||
| @@ -488,7 +489,7 @@ class BaseDocument(object): | ||||
|                 else: | ||||
|                     data = getattr(data, part, None) | ||||
|  | ||||
|                 if hasattr(data, '_changed_fields'): | ||||
|                 if not isinstance(data, LazyReference) and hasattr(data, '_changed_fields'): | ||||
|                     if getattr(data, '_is_document', False): | ||||
|                         continue | ||||
|  | ||||
|   | ||||
| @@ -3,6 +3,7 @@ import six | ||||
|  | ||||
| from mongoengine.base import (BaseDict, BaseList, EmbeddedDocumentList, | ||||
|                               TopLevelDocumentMetaclass, get_document) | ||||
| from mongoengine.base.datastructures import LazyReference | ||||
| from mongoengine.connection import get_db | ||||
| from mongoengine.document import Document, EmbeddedDocument | ||||
| from mongoengine.fields import DictField, ListField, MapField, ReferenceField | ||||
| @@ -99,7 +100,10 @@ class DeReference(object): | ||||
|             if isinstance(item, (Document, EmbeddedDocument)): | ||||
|                 for field_name, field in item._fields.iteritems(): | ||||
|                     v = item._data.get(field_name, None) | ||||
|                     if isinstance(v, DBRef): | ||||
|                     if isinstance(v, LazyReference): | ||||
|                         # LazyReference inherits DBRef but should not be dereferenced here ! | ||||
|                         continue | ||||
|                     elif isinstance(v, DBRef): | ||||
|                         reference_map.setdefault(field.document_type, set()).add(v.id) | ||||
|                     elif isinstance(v, (dict, SON)) and '_ref' in v: | ||||
|                         reference_map.setdefault(get_document(v['_cls']), set()).add(v['_ref'].id) | ||||
| @@ -110,6 +114,9 @@ class DeReference(object): | ||||
|                             if isinstance(field_cls, (Document, TopLevelDocumentMetaclass)): | ||||
|                                 key = field_cls | ||||
|                             reference_map.setdefault(key, set()).update(refs) | ||||
|             elif isinstance(item, LazyReference): | ||||
|                 # LazyReference inherits DBRef but should not be dereferenced here ! | ||||
|                 continue | ||||
|             elif isinstance(item, DBRef): | ||||
|                 reference_map.setdefault(item.collection, set()).add(item.id) | ||||
|             elif isinstance(item, (dict, SON)) and '_ref' in item: | ||||
|   | ||||
| @@ -28,6 +28,7 @@ except ImportError: | ||||
| from mongoengine.base import (BaseDocument, BaseField, ComplexBaseField, | ||||
|                               GeoJsonBaseField, LazyReference, ObjectIdField, | ||||
|                               get_document) | ||||
| from mongoengine.common import _import_class | ||||
| from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db | ||||
| from mongoengine.document import Document, EmbeddedDocument | ||||
| from mongoengine.errors import DoesNotExist, InvalidQueryError, ValidationError | ||||
| @@ -789,6 +790,17 @@ class ListField(ComplexBaseField): | ||||
|         kwargs.setdefault('default', lambda: []) | ||||
|         super(ListField, self).__init__(**kwargs) | ||||
|  | ||||
|     def __get__(self, instance, owner): | ||||
|         if instance is None: | ||||
|             # Document class being used rather than a document object | ||||
|             return self | ||||
|         value = instance._data.get(self.name) | ||||
|         LazyReferenceField = _import_class('LazyReferenceField') | ||||
|         GenericLazyReferenceField = _import_class('GenericLazyReferenceField') | ||||
|         if isinstance(self.field, (LazyReferenceField, GenericLazyReferenceField)) and value: | ||||
|             instance._data[self.name] = [self.field.build_lazyref(x) for x in value] | ||||
|         return super(ListField, self).__get__(instance, owner) | ||||
|  | ||||
|     def validate(self, value): | ||||
|         """Make sure that a list of valid fields is being used.""" | ||||
|         if (not isinstance(value, (list, tuple, QuerySet)) or | ||||
| @@ -2211,17 +2223,10 @@ class LazyReferenceField(BaseField): | ||||
|                 self.document_type_obj = get_document(self.document_type_obj) | ||||
|         return self.document_type_obj | ||||
|  | ||||
|     def __get__(self, instance, owner): | ||||
|         """Descriptor to allow lazy dereferencing.""" | ||||
|         if instance is None: | ||||
|             # Document class being used rather than a document object | ||||
|             return self | ||||
|  | ||||
|         value = instance._data.get(self.name) | ||||
|     def build_lazyref(self, value): | ||||
|         if isinstance(value, LazyReference): | ||||
|             if value.passthrough != self.passthrough: | ||||
|                 instance._data[self.name] = LazyReference( | ||||
|                     value.document_type, value.pk, passthrough=self.passthrough) | ||||
|                 value = LazyReference(value.document_type, value.pk, passthrough=self.passthrough) | ||||
|         elif value is not None: | ||||
|             if isinstance(value, self.document_type): | ||||
|                 value = LazyReference(self.document_type, value.pk, passthrough=self.passthrough) | ||||
| @@ -2230,6 +2235,16 @@ class LazyReferenceField(BaseField): | ||||
|             else: | ||||
|                 # value is the primary key of the referenced document | ||||
|                 value = LazyReference(self.document_type, value, passthrough=self.passthrough) | ||||
|         return value | ||||
|  | ||||
|     def __get__(self, instance, owner): | ||||
|         """Descriptor to allow lazy dereferencing.""" | ||||
|         if instance is None: | ||||
|             # Document class being used rather than a document object | ||||
|             return self | ||||
|  | ||||
|         value = self.build_lazyref(instance._data.get(self.name)) | ||||
|         if value: | ||||
|             instance._data[self.name] = value | ||||
|  | ||||
|         return super(LazyReferenceField, self).__get__(instance, owner) | ||||
| @@ -2254,7 +2269,7 @@ class LazyReferenceField(BaseField): | ||||
|  | ||||
|     def validate(self, value): | ||||
|         if isinstance(value, LazyReference): | ||||
|             if not issubclass(value.document_type, self.document_type): | ||||
|             if value.collection != self.document_type._get_collection_name(): | ||||
|                 self.error('Reference must be on a `%s` document.' % self.document_type) | ||||
|             pk = value.pk | ||||
|         elif isinstance(value, self.document_type): | ||||
| @@ -2314,23 +2329,26 @@ class GenericLazyReferenceField(GenericReferenceField): | ||||
|  | ||||
|     def _validate_choices(self, value): | ||||
|         if isinstance(value, LazyReference): | ||||
|             value = value.document_type | ||||
|             value = value.document_type._class_name | ||||
|         super(GenericLazyReferenceField, self)._validate_choices(value) | ||||
|  | ||||
|     def __get__(self, instance, owner): | ||||
|         if instance is None: | ||||
|             return self | ||||
|  | ||||
|         value = instance._data.get(self.name) | ||||
|     def build_lazyref(self, value): | ||||
|         if isinstance(value, LazyReference): | ||||
|             if value.passthrough != self.passthrough: | ||||
|                 instance._data[self.name] = LazyReference( | ||||
|                     value.document_type, value.pk, passthrough=self.passthrough) | ||||
|                 value = LazyReference(value.document_type, value.pk, passthrough=self.passthrough) | ||||
|         elif value is not None: | ||||
|             if isinstance(value, (dict, SON)): | ||||
|                 value = LazyReference(get_document(value['_cls']), value['_ref'].id, passthrough=self.passthrough) | ||||
|             elif isinstance(value, Document): | ||||
|                 value = LazyReference(type(value), value.pk, passthrough=self.passthrough) | ||||
|         return value | ||||
|  | ||||
|     def __get__(self, instance, owner): | ||||
|         if instance is None: | ||||
|             return self | ||||
|  | ||||
|         value = self.build_lazyref(instance._data.get(self.name)) | ||||
|         if value: | ||||
|             instance._data[self.name] = value | ||||
|  | ||||
|         return super(GenericLazyReferenceField, self).__get__(instance, owner) | ||||
| @@ -2348,7 +2366,7 @@ class GenericLazyReferenceField(GenericReferenceField): | ||||
|         if isinstance(document, LazyReference): | ||||
|             return SON(( | ||||
|                 ('_cls', document.document_type._class_name), | ||||
|                 ('_ref', document) | ||||
|                 ('_ref', DBRef(document.document_type._get_collection_name(), document.pk)) | ||||
|             )) | ||||
|         else: | ||||
|             return super(GenericLazyReferenceField, self).to_mongo(document) | ||||
|   | ||||
| @@ -1,7 +1,7 @@ | ||||
| [nosetests] | ||||
| verbosity=2 | ||||
| detailed-errors=1 | ||||
| tests=tests | ||||
| #tests=tests | ||||
| cover-package=mongoengine | ||||
|  | ||||
| [flake8] | ||||
|   | ||||
| @@ -4871,6 +4871,48 @@ class LazyReferenceFieldTest(MongoDBTestCase): | ||||
|         self.assertNotEqual(animal, other_animalref) | ||||
|         self.assertNotEqual(other_animalref, animal) | ||||
|  | ||||
|     def test_lazy_reference_embedded(self): | ||||
|         class Animal(Document): | ||||
|             name = StringField() | ||||
|             tag = StringField() | ||||
|  | ||||
|         class EmbeddedOcurrence(EmbeddedDocument): | ||||
|             in_list = ListField(LazyReferenceField(Animal)) | ||||
|             direct = LazyReferenceField(Animal) | ||||
|  | ||||
|         class Ocurrence(Document): | ||||
|             in_list = ListField(LazyReferenceField(Animal)) | ||||
|             in_embedded = EmbeddedDocumentField(EmbeddedOcurrence) | ||||
|             direct = LazyReferenceField(Animal) | ||||
|  | ||||
|         Animal.drop_collection() | ||||
|         Ocurrence.drop_collection() | ||||
|  | ||||
|         animal1 = Animal('doggo').save() | ||||
|         animal2 = Animal('cheeta').save() | ||||
|  | ||||
|         def check_fields_type(occ): | ||||
|             self.assertIsInstance(occ.direct, LazyReference) | ||||
|             for elem in occ.in_list: | ||||
|                 self.assertIsInstance(elem, LazyReference) | ||||
|             self.assertIsInstance(occ.in_embedded.direct, LazyReference) | ||||
|             for elem in occ.in_embedded.in_list: | ||||
|                 self.assertIsInstance(elem, LazyReference) | ||||
|  | ||||
|         occ = Ocurrence( | ||||
|             in_list=[animal1, animal2], | ||||
|             in_embedded={'in_list': [animal1, animal2], 'direct': animal1}, | ||||
|             direct=animal1 | ||||
|         ).save() | ||||
|         check_fields_type(occ) | ||||
|         occ.reload() | ||||
|         check_fields_type(occ) | ||||
|         occ.direct = animal1.id | ||||
|         occ.in_list = [animal1.id, animal2.id] | ||||
|         occ.in_embedded.direct = animal1.id | ||||
|         occ.in_embedded.in_list = [animal1.id, animal2.id] | ||||
|         check_fields_type(occ) | ||||
|  | ||||
|  | ||||
| class GenericLazyReferenceFieldTest(MongoDBTestCase): | ||||
|     def test_generic_lazy_reference_simple(self): | ||||
| @@ -5051,6 +5093,50 @@ class GenericLazyReferenceFieldTest(MongoDBTestCase): | ||||
|         p = Ocurrence.objects.get() | ||||
|         self.assertIs(p.animal, None) | ||||
|  | ||||
|     def test_generic_lazy_reference_embedded(self): | ||||
|         class Animal(Document): | ||||
|             name = StringField() | ||||
|             tag = StringField() | ||||
|  | ||||
|         class EmbeddedOcurrence(EmbeddedDocument): | ||||
|             in_list = ListField(GenericLazyReferenceField()) | ||||
|             direct = GenericLazyReferenceField() | ||||
|  | ||||
|         class Ocurrence(Document): | ||||
|             in_list = ListField(GenericLazyReferenceField()) | ||||
|             in_embedded = EmbeddedDocumentField(EmbeddedOcurrence) | ||||
|             direct = GenericLazyReferenceField() | ||||
|  | ||||
|         Animal.drop_collection() | ||||
|         Ocurrence.drop_collection() | ||||
|  | ||||
|         animal1 = Animal('doggo').save() | ||||
|         animal2 = Animal('cheeta').save() | ||||
|  | ||||
|         def check_fields_type(occ): | ||||
|             self.assertIsInstance(occ.direct, LazyReference) | ||||
|             for elem in occ.in_list: | ||||
|                 self.assertIsInstance(elem, LazyReference) | ||||
|             self.assertIsInstance(occ.in_embedded.direct, LazyReference) | ||||
|             for elem in occ.in_embedded.in_list: | ||||
|                 self.assertIsInstance(elem, LazyReference) | ||||
|  | ||||
|         occ = Ocurrence( | ||||
|             in_list=[animal1, animal2], | ||||
|             in_embedded={'in_list': [animal1, animal2], 'direct': animal1}, | ||||
|             direct=animal1 | ||||
|         ).save() | ||||
|         check_fields_type(occ) | ||||
|         occ.reload() | ||||
|         check_fields_type(occ) | ||||
|         animal1_ref = {'_cls': 'Animal', '_ref': DBRef(animal1._get_collection_name(), animal1.pk)} | ||||
|         animal2_ref = {'_cls': 'Animal', '_ref': DBRef(animal2._get_collection_name(), animal2.pk)} | ||||
|         occ.direct = animal1_ref | ||||
|         occ.in_list = [animal1_ref, animal2_ref] | ||||
|         occ.in_embedded.direct = animal1_ref | ||||
|         occ.in_embedded.in_list = [animal1_ref, animal2_ref] | ||||
|         check_fields_type(occ) | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|     unittest.main() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user