Merge pull request #1334 from touilleMan/bug-892
Raise DoesNotExist when dereferencing unknown document
This commit is contained in:
commit
936d2f1f47
@ -25,7 +25,7 @@ try:
|
||||
except ImportError:
|
||||
Int64 = long
|
||||
|
||||
from mongoengine.errors import ValidationError
|
||||
from mongoengine.errors import ValidationError, DoesNotExist
|
||||
from mongoengine.python_support import (PY3, bin_type, txt_type,
|
||||
str_types, StringIO)
|
||||
from base import (BaseField, ComplexBaseField, ObjectIdField, GeoJsonBaseField,
|
||||
@ -948,9 +948,11 @@ class ReferenceField(BaseField):
|
||||
cls = get_document(value.cls)
|
||||
else:
|
||||
cls = self.document_type
|
||||
value = cls._get_db().dereference(value)
|
||||
if value is not None:
|
||||
instance._data[self.name] = cls._from_son(value)
|
||||
dereferenced = cls._get_db().dereference(value)
|
||||
if dereferenced is None:
|
||||
raise DoesNotExist('Trying to dereference unknown document %s' % value)
|
||||
else:
|
||||
instance._data[self.name] = cls._from_son(dereferenced)
|
||||
|
||||
return super(ReferenceField, self).__get__(instance, owner)
|
||||
|
||||
@ -1094,9 +1096,11 @@ class CachedReferenceField(BaseField):
|
||||
self._auto_dereference = instance._fields[self.name]._auto_dereference
|
||||
# Dereference DBRefs
|
||||
if self._auto_dereference and isinstance(value, DBRef):
|
||||
value = self.document_type._get_db().dereference(value)
|
||||
if value is not None:
|
||||
instance._data[self.name] = self.document_type._from_son(value)
|
||||
dereferenced = self.document_type._get_db().dereference(value)
|
||||
if dereferenced is None:
|
||||
raise DoesNotExist('Trying to dereference unknown document %s' % value)
|
||||
else:
|
||||
instance._data[self.name] = self.document_type._from_son(dereferenced)
|
||||
|
||||
return super(CachedReferenceField, self).__get__(instance, owner)
|
||||
|
||||
@ -1214,7 +1218,11 @@ class GenericReferenceField(BaseField):
|
||||
|
||||
self._auto_dereference = instance._fields[self.name]._auto_dereference
|
||||
if self._auto_dereference and isinstance(value, (dict, SON)):
|
||||
instance._data[self.name] = self.dereference(value)
|
||||
dereferenced = self.dereference(value)
|
||||
if dereferenced is None:
|
||||
raise DoesNotExist('Trying to dereference unknown document %s' % value)
|
||||
else:
|
||||
instance._data[self.name] = dereferenced
|
||||
|
||||
return super(GenericReferenceField, self).__get__(instance, owner)
|
||||
|
||||
|
@ -31,7 +31,7 @@ from mongoengine import *
|
||||
from mongoengine.connection import get_db
|
||||
from mongoengine.base import _document_registry
|
||||
from mongoengine.base.datastructures import BaseDict, EmbeddedDocumentList
|
||||
from mongoengine.errors import NotRegistered
|
||||
from mongoengine.errors import NotRegistered, DoesNotExist
|
||||
from mongoengine.python_support import PY3, b, bin_type
|
||||
|
||||
__all__ = ("FieldTest", "EmbeddedDocumentListFieldTestCase")
|
||||
@ -1726,6 +1726,37 @@ class FieldTest(unittest.TestCase):
|
||||
|
||||
self.assertEqual(content, User.objects.first().groups[0].content)
|
||||
|
||||
def test_reference_miss(self):
|
||||
"""Ensure an exception is raised when dereferencing unknow document
|
||||
"""
|
||||
|
||||
class Foo(Document):
|
||||
pass
|
||||
|
||||
class Bar(Document):
|
||||
ref = ReferenceField(Foo)
|
||||
generic_ref = GenericReferenceField()
|
||||
|
||||
Foo.drop_collection()
|
||||
Bar.drop_collection()
|
||||
|
||||
foo = Foo().save()
|
||||
bar = Bar(ref=foo, generic_ref=foo).save()
|
||||
|
||||
# Reference is no longer valid
|
||||
foo.delete()
|
||||
bar = Bar.objects.get()
|
||||
self.assertRaises(DoesNotExist, lambda: getattr(bar, 'ref'))
|
||||
self.assertRaises(DoesNotExist, lambda: getattr(bar, 'generic_ref'))
|
||||
|
||||
# When auto_dereference is disabled, there is no trouble returning DBRef
|
||||
bar = Bar.objects.get()
|
||||
expected = foo.to_dbref()
|
||||
bar._fields['ref']._auto_dereference = False
|
||||
self.assertEqual(bar.ref, expected)
|
||||
bar._fields['generic_ref']._auto_dereference = False
|
||||
self.assertEqual(bar.generic_ref, {'_ref': expected, '_cls': 'Foo'})
|
||||
|
||||
def test_reference_validation(self):
|
||||
"""Ensure that invalid docment objects cannot be assigned to reference
|
||||
fields.
|
||||
|
Loading…
x
Reference in New Issue
Block a user