diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 9d189235..47e33d1d 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -25,7 +25,7 @@ try: except ImportError: Int64 = long -from mongoengine.errors import ValidationError +from mongoengine.errors import ValidationError, DoesNotExist from mongoengine.python_support import (PY3, bin_type, txt_type, str_types, StringIO) from base import (BaseField, ComplexBaseField, ObjectIdField, GeoJsonBaseField, @@ -948,9 +948,11 @@ class ReferenceField(BaseField): cls = get_document(value.cls) else: cls = self.document_type - value = cls._get_db().dereference(value) - if value is not None: - instance._data[self.name] = cls._from_son(value) + dereferenced = cls._get_db().dereference(value) + if dereferenced is None: + raise DoesNotExist('Trying to dereference unknown document %s' % value) + else: + instance._data[self.name] = cls._from_son(dereferenced) return super(ReferenceField, self).__get__(instance, owner) @@ -1094,9 +1096,11 @@ class CachedReferenceField(BaseField): self._auto_dereference = instance._fields[self.name]._auto_dereference # Dereference DBRefs if self._auto_dereference and isinstance(value, DBRef): - value = self.document_type._get_db().dereference(value) - if value is not None: - instance._data[self.name] = self.document_type._from_son(value) + dereferenced = self.document_type._get_db().dereference(value) + if dereferenced is None: + raise DoesNotExist('Trying to dereference unknown document %s' % value) + else: + instance._data[self.name] = self.document_type._from_son(dereferenced) return super(CachedReferenceField, self).__get__(instance, owner) @@ -1214,7 +1218,11 @@ class GenericReferenceField(BaseField): self._auto_dereference = instance._fields[self.name]._auto_dereference if self._auto_dereference and isinstance(value, (dict, SON)): - instance._data[self.name] = self.dereference(value) + dereferenced = self.dereference(value) + if dereferenced is None: + raise DoesNotExist('Trying to dereference unknown document %s' % value) + else: + instance._data[self.name] = dereferenced return super(GenericReferenceField, self).__get__(instance, owner) diff --git a/tests/fields/fields.py b/tests/fields/fields.py index 3b3fe857..36b9f4cd 100644 --- a/tests/fields/fields.py +++ b/tests/fields/fields.py @@ -31,7 +31,7 @@ from mongoengine import * from mongoengine.connection import get_db from mongoengine.base import _document_registry from mongoengine.base.datastructures import BaseDict, EmbeddedDocumentList -from mongoengine.errors import NotRegistered +from mongoengine.errors import NotRegistered, DoesNotExist from mongoengine.python_support import PY3, b, bin_type __all__ = ("FieldTest", "EmbeddedDocumentListFieldTestCase") @@ -1726,6 +1726,37 @@ class FieldTest(unittest.TestCase): self.assertEqual(content, User.objects.first().groups[0].content) + def test_reference_miss(self): + """Ensure an exception is raised when dereferencing unknow document + """ + + class Foo(Document): + pass + + class Bar(Document): + ref = ReferenceField(Foo) + generic_ref = GenericReferenceField() + + Foo.drop_collection() + Bar.drop_collection() + + foo = Foo().save() + bar = Bar(ref=foo, generic_ref=foo).save() + + # Reference is no longer valid + foo.delete() + bar = Bar.objects.get() + self.assertRaises(DoesNotExist, lambda: getattr(bar, 'ref')) + self.assertRaises(DoesNotExist, lambda: getattr(bar, 'generic_ref')) + + # When auto_dereference is disabled, there is no trouble returning DBRef + bar = Bar.objects.get() + expected = foo.to_dbref() + bar._fields['ref']._auto_dereference = False + self.assertEqual(bar.ref, expected) + bar._fields['generic_ref']._auto_dereference = False + self.assertEqual(bar.generic_ref, {'_ref': expected, '_cls': 'Foo'}) + def test_reference_validation(self): """Ensure that invalid docment objects cannot be assigned to reference fields.