Merge pull request #1334 from touilleMan/bug-892

Raise DoesNotExist when dereferencing unknown document
This commit is contained in:
Omer Katz 2016-11-17 11:31:15 +02:00 committed by GitHub
commit 936d2f1f47
2 changed files with 48 additions and 9 deletions

View File

@ -25,7 +25,7 @@ try:
except ImportError: except ImportError:
Int64 = long Int64 = long
from mongoengine.errors import ValidationError from mongoengine.errors import ValidationError, DoesNotExist
from mongoengine.python_support import (PY3, bin_type, txt_type, from mongoengine.python_support import (PY3, bin_type, txt_type,
str_types, StringIO) str_types, StringIO)
from base import (BaseField, ComplexBaseField, ObjectIdField, GeoJsonBaseField, from base import (BaseField, ComplexBaseField, ObjectIdField, GeoJsonBaseField,
@ -948,9 +948,11 @@ class ReferenceField(BaseField):
cls = get_document(value.cls) cls = get_document(value.cls)
else: else:
cls = self.document_type cls = self.document_type
value = cls._get_db().dereference(value) dereferenced = cls._get_db().dereference(value)
if value is not None: if dereferenced is None:
instance._data[self.name] = cls._from_son(value) raise DoesNotExist('Trying to dereference unknown document %s' % value)
else:
instance._data[self.name] = cls._from_son(dereferenced)
return super(ReferenceField, self).__get__(instance, owner) return super(ReferenceField, self).__get__(instance, owner)
@ -1094,9 +1096,11 @@ class CachedReferenceField(BaseField):
self._auto_dereference = instance._fields[self.name]._auto_dereference self._auto_dereference = instance._fields[self.name]._auto_dereference
# Dereference DBRefs # Dereference DBRefs
if self._auto_dereference and isinstance(value, DBRef): if self._auto_dereference and isinstance(value, DBRef):
value = self.document_type._get_db().dereference(value) dereferenced = self.document_type._get_db().dereference(value)
if value is not None: if dereferenced is None:
instance._data[self.name] = self.document_type._from_son(value) raise DoesNotExist('Trying to dereference unknown document %s' % value)
else:
instance._data[self.name] = self.document_type._from_son(dereferenced)
return super(CachedReferenceField, self).__get__(instance, owner) return super(CachedReferenceField, self).__get__(instance, owner)
@ -1214,7 +1218,11 @@ class GenericReferenceField(BaseField):
self._auto_dereference = instance._fields[self.name]._auto_dereference self._auto_dereference = instance._fields[self.name]._auto_dereference
if self._auto_dereference and isinstance(value, (dict, SON)): if self._auto_dereference and isinstance(value, (dict, SON)):
instance._data[self.name] = self.dereference(value) dereferenced = self.dereference(value)
if dereferenced is None:
raise DoesNotExist('Trying to dereference unknown document %s' % value)
else:
instance._data[self.name] = dereferenced
return super(GenericReferenceField, self).__get__(instance, owner) return super(GenericReferenceField, self).__get__(instance, owner)

View File

@ -31,7 +31,7 @@ from mongoengine import *
from mongoengine.connection import get_db from mongoengine.connection import get_db
from mongoengine.base import _document_registry from mongoengine.base import _document_registry
from mongoengine.base.datastructures import BaseDict, EmbeddedDocumentList from mongoengine.base.datastructures import BaseDict, EmbeddedDocumentList
from mongoengine.errors import NotRegistered from mongoengine.errors import NotRegistered, DoesNotExist
from mongoengine.python_support import PY3, b, bin_type from mongoengine.python_support import PY3, b, bin_type
__all__ = ("FieldTest", "EmbeddedDocumentListFieldTestCase") __all__ = ("FieldTest", "EmbeddedDocumentListFieldTestCase")
@ -1726,6 +1726,37 @@ class FieldTest(unittest.TestCase):
self.assertEqual(content, User.objects.first().groups[0].content) self.assertEqual(content, User.objects.first().groups[0].content)
def test_reference_miss(self):
"""Ensure an exception is raised when dereferencing unknow document
"""
class Foo(Document):
pass
class Bar(Document):
ref = ReferenceField(Foo)
generic_ref = GenericReferenceField()
Foo.drop_collection()
Bar.drop_collection()
foo = Foo().save()
bar = Bar(ref=foo, generic_ref=foo).save()
# Reference is no longer valid
foo.delete()
bar = Bar.objects.get()
self.assertRaises(DoesNotExist, lambda: getattr(bar, 'ref'))
self.assertRaises(DoesNotExist, lambda: getattr(bar, 'generic_ref'))
# When auto_dereference is disabled, there is no trouble returning DBRef
bar = Bar.objects.get()
expected = foo.to_dbref()
bar._fields['ref']._auto_dereference = False
self.assertEqual(bar.ref, expected)
bar._fields['generic_ref']._auto_dereference = False
self.assertEqual(bar.generic_ref, {'_ref': expected, '_cls': 'Foo'})
def test_reference_validation(self): def test_reference_validation(self):
"""Ensure that invalid docment objects cannot be assigned to reference """Ensure that invalid docment objects cannot be assigned to reference
fields. fields.