diff --git a/mongoengine/base/document.py b/mongoengine/base/document.py index f8ab73d0..658d0c79 100644 --- a/mongoengine/base/document.py +++ b/mongoengine/base/document.py @@ -13,6 +13,7 @@ from mongoengine import signals from mongoengine.base.common import get_document from mongoengine.base.datastructures import (BaseDict, BaseList, EmbeddedDocumentList, + LazyReference, StrictDict) from mongoengine.base.fields import ComplexBaseField from mongoengine.common import _import_class @@ -488,7 +489,7 @@ class BaseDocument(object): else: data = getattr(data, part, None) - if hasattr(data, '_changed_fields'): + if not isinstance(data, LazyReference) and hasattr(data, '_changed_fields'): if getattr(data, '_is_document', False): continue diff --git a/mongoengine/dereference.py b/mongoengine/dereference.py index 59204d4d..7fe34e43 100644 --- a/mongoengine/dereference.py +++ b/mongoengine/dereference.py @@ -3,6 +3,7 @@ import six from mongoengine.base import (BaseDict, BaseList, EmbeddedDocumentList, TopLevelDocumentMetaclass, get_document) +from mongoengine.base.datastructures import LazyReference from mongoengine.connection import get_db from mongoengine.document import Document, EmbeddedDocument from mongoengine.fields import DictField, ListField, MapField, ReferenceField @@ -99,7 +100,10 @@ class DeReference(object): if isinstance(item, (Document, EmbeddedDocument)): for field_name, field in item._fields.iteritems(): v = item._data.get(field_name, None) - if isinstance(v, DBRef): + if isinstance(v, LazyReference): + # LazyReference inherits DBRef but should not be dereferenced here ! + continue + elif isinstance(v, DBRef): reference_map.setdefault(field.document_type, set()).add(v.id) elif isinstance(v, (dict, SON)) and '_ref' in v: reference_map.setdefault(get_document(v['_cls']), set()).add(v['_ref'].id) @@ -110,6 +114,9 @@ class DeReference(object): if isinstance(field_cls, (Document, TopLevelDocumentMetaclass)): key = field_cls reference_map.setdefault(key, set()).update(refs) + elif isinstance(item, LazyReference): + # LazyReference inherits DBRef but should not be dereferenced here ! + continue elif isinstance(item, DBRef): reference_map.setdefault(item.collection, set()).add(item.id) elif isinstance(item, (dict, SON)) and '_ref' in item: diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 6c4a06c9..8ca2b17f 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -28,6 +28,7 @@ except ImportError: from mongoengine.base import (BaseDocument, BaseField, ComplexBaseField, GeoJsonBaseField, LazyReference, ObjectIdField, get_document) +from mongoengine.common import _import_class from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db from mongoengine.document import Document, EmbeddedDocument from mongoengine.errors import DoesNotExist, InvalidQueryError, ValidationError @@ -789,6 +790,17 @@ class ListField(ComplexBaseField): kwargs.setdefault('default', lambda: []) super(ListField, self).__init__(**kwargs) + def __get__(self, instance, owner): + if instance is None: + # Document class being used rather than a document object + return self + value = instance._data.get(self.name) + LazyReferenceField = _import_class('LazyReferenceField') + GenericLazyReferenceField = _import_class('GenericLazyReferenceField') + if isinstance(self.field, (LazyReferenceField, GenericLazyReferenceField)) and value: + instance._data[self.name] = [self.field.build_lazyref(x) for x in value] + return super(ListField, self).__get__(instance, owner) + def validate(self, value): """Make sure that a list of valid fields is being used.""" if (not isinstance(value, (list, tuple, QuerySet)) or @@ -2211,17 +2223,10 @@ class LazyReferenceField(BaseField): self.document_type_obj = get_document(self.document_type_obj) return self.document_type_obj - def __get__(self, instance, owner): - """Descriptor to allow lazy dereferencing.""" - if instance is None: - # Document class being used rather than a document object - return self - - value = instance._data.get(self.name) + def build_lazyref(self, value): if isinstance(value, LazyReference): if value.passthrough != self.passthrough: - instance._data[self.name] = LazyReference( - value.document_type, value.pk, passthrough=self.passthrough) + value = LazyReference(value.document_type, value.pk, passthrough=self.passthrough) elif value is not None: if isinstance(value, self.document_type): value = LazyReference(self.document_type, value.pk, passthrough=self.passthrough) @@ -2230,6 +2235,16 @@ class LazyReferenceField(BaseField): else: # value is the primary key of the referenced document value = LazyReference(self.document_type, value, passthrough=self.passthrough) + return value + + def __get__(self, instance, owner): + """Descriptor to allow lazy dereferencing.""" + if instance is None: + # Document class being used rather than a document object + return self + + value = self.build_lazyref(instance._data.get(self.name)) + if value: instance._data[self.name] = value return super(LazyReferenceField, self).__get__(instance, owner) @@ -2254,7 +2269,7 @@ class LazyReferenceField(BaseField): def validate(self, value): if isinstance(value, LazyReference): - if not issubclass(value.document_type, self.document_type): + if value.collection != self.document_type._get_collection_name(): self.error('Reference must be on a `%s` document.' % self.document_type) pk = value.pk elif isinstance(value, self.document_type): @@ -2314,23 +2329,26 @@ class GenericLazyReferenceField(GenericReferenceField): def _validate_choices(self, value): if isinstance(value, LazyReference): - value = value.document_type + value = value.document_type._class_name super(GenericLazyReferenceField, self)._validate_choices(value) - def __get__(self, instance, owner): - if instance is None: - return self - - value = instance._data.get(self.name) + def build_lazyref(self, value): if isinstance(value, LazyReference): if value.passthrough != self.passthrough: - instance._data[self.name] = LazyReference( - value.document_type, value.pk, passthrough=self.passthrough) + value = LazyReference(value.document_type, value.pk, passthrough=self.passthrough) elif value is not None: if isinstance(value, (dict, SON)): value = LazyReference(get_document(value['_cls']), value['_ref'].id, passthrough=self.passthrough) elif isinstance(value, Document): value = LazyReference(type(value), value.pk, passthrough=self.passthrough) + return value + + def __get__(self, instance, owner): + if instance is None: + return self + + value = self.build_lazyref(instance._data.get(self.name)) + if value: instance._data[self.name] = value return super(GenericLazyReferenceField, self).__get__(instance, owner) @@ -2348,7 +2366,7 @@ class GenericLazyReferenceField(GenericReferenceField): if isinstance(document, LazyReference): return SON(( ('_cls', document.document_type._class_name), - ('_ref', document) + ('_ref', DBRef(document.document_type._get_collection_name(), document.pk)) )) else: return super(GenericLazyReferenceField, self).to_mongo(document) diff --git a/setup.cfg b/setup.cfg index 46edff3b..fd6192b8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,7 +1,7 @@ [nosetests] verbosity=2 detailed-errors=1 -tests=tests +#tests=tests cover-package=mongoengine [flake8] diff --git a/tests/fields/fields.py b/tests/fields/fields.py index 632f5404..ffee25e6 100644 --- a/tests/fields/fields.py +++ b/tests/fields/fields.py @@ -4871,6 +4871,48 @@ class LazyReferenceFieldTest(MongoDBTestCase): self.assertNotEqual(animal, other_animalref) self.assertNotEqual(other_animalref, animal) + def test_lazy_reference_embedded(self): + class Animal(Document): + name = StringField() + tag = StringField() + + class EmbeddedOcurrence(EmbeddedDocument): + in_list = ListField(LazyReferenceField(Animal)) + direct = LazyReferenceField(Animal) + + class Ocurrence(Document): + in_list = ListField(LazyReferenceField(Animal)) + in_embedded = EmbeddedDocumentField(EmbeddedOcurrence) + direct = LazyReferenceField(Animal) + + Animal.drop_collection() + Ocurrence.drop_collection() + + animal1 = Animal('doggo').save() + animal2 = Animal('cheeta').save() + + def check_fields_type(occ): + self.assertIsInstance(occ.direct, LazyReference) + for elem in occ.in_list: + self.assertIsInstance(elem, LazyReference) + self.assertIsInstance(occ.in_embedded.direct, LazyReference) + for elem in occ.in_embedded.in_list: + self.assertIsInstance(elem, LazyReference) + + occ = Ocurrence( + in_list=[animal1, animal2], + in_embedded={'in_list': [animal1, animal2], 'direct': animal1}, + direct=animal1 + ).save() + check_fields_type(occ) + occ.reload() + check_fields_type(occ) + occ.direct = animal1.id + occ.in_list = [animal1.id, animal2.id] + occ.in_embedded.direct = animal1.id + occ.in_embedded.in_list = [animal1.id, animal2.id] + check_fields_type(occ) + class GenericLazyReferenceFieldTest(MongoDBTestCase): def test_generic_lazy_reference_simple(self): @@ -5051,6 +5093,50 @@ class GenericLazyReferenceFieldTest(MongoDBTestCase): p = Ocurrence.objects.get() self.assertIs(p.animal, None) + def test_generic_lazy_reference_embedded(self): + class Animal(Document): + name = StringField() + tag = StringField() + + class EmbeddedOcurrence(EmbeddedDocument): + in_list = ListField(GenericLazyReferenceField()) + direct = GenericLazyReferenceField() + + class Ocurrence(Document): + in_list = ListField(GenericLazyReferenceField()) + in_embedded = EmbeddedDocumentField(EmbeddedOcurrence) + direct = GenericLazyReferenceField() + + Animal.drop_collection() + Ocurrence.drop_collection() + + animal1 = Animal('doggo').save() + animal2 = Animal('cheeta').save() + + def check_fields_type(occ): + self.assertIsInstance(occ.direct, LazyReference) + for elem in occ.in_list: + self.assertIsInstance(elem, LazyReference) + self.assertIsInstance(occ.in_embedded.direct, LazyReference) + for elem in occ.in_embedded.in_list: + self.assertIsInstance(elem, LazyReference) + + occ = Ocurrence( + in_list=[animal1, animal2], + in_embedded={'in_list': [animal1, animal2], 'direct': animal1}, + direct=animal1 + ).save() + check_fields_type(occ) + occ.reload() + check_fields_type(occ) + animal1_ref = {'_cls': 'Animal', '_ref': DBRef(animal1._get_collection_name(), animal1.pk)} + animal2_ref = {'_cls': 'Animal', '_ref': DBRef(animal2._get_collection_name(), animal2.pk)} + occ.direct = animal1_ref + occ.in_list = [animal1_ref, animal2_ref] + occ.in_embedded.direct = animal1_ref + occ.in_embedded.in_list = [animal1_ref, animal2_ref] + check_fields_type(occ) + if __name__ == '__main__': unittest.main()