Ensures EmbeddedDocumentField does not accepts references to Document classes in its constructor
This commit is contained in:
parent
bf2de81873
commit
5dbee2a270
@ -645,9 +645,17 @@ class EmbeddedDocumentField(BaseField):
|
||||
def document_type(self):
|
||||
if isinstance(self.document_type_obj, six.string_types):
|
||||
if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT:
|
||||
self.document_type_obj = self.owner_document
|
||||
resolved_document_type = self.owner_document
|
||||
else:
|
||||
self.document_type_obj = get_document(self.document_type_obj)
|
||||
resolved_document_type = get_document(self.document_type_obj)
|
||||
|
||||
if not issubclass(resolved_document_type, EmbeddedDocument):
|
||||
# Due to the late resolution of the document_type
|
||||
# There is a chance that it won't be an EmbeddedDocument (#1661)
|
||||
self.error('Invalid embedded document class provided to an '
|
||||
'EmbeddedDocumentField')
|
||||
self.document_type_obj = resolved_document_type
|
||||
|
||||
return self.document_type_obj
|
||||
|
||||
def to_python(self, value):
|
||||
|
@ -2147,6 +2147,15 @@ class FieldTest(MongoDBTestCase):
|
||||
]))
|
||||
self.assertEqual(a.b.c.txt, 'hi')
|
||||
|
||||
def test_embedded_document_field_cant_reference_using_a_str_if_it_does_not_exist_yet(self):
|
||||
raise SkipTest("Using a string reference in an EmbeddedDocumentField does not work if the class isnt registerd yet")
|
||||
|
||||
class MyDoc2(Document):
|
||||
emb = EmbeddedDocumentField('MyDoc')
|
||||
|
||||
class MyDoc(EmbeddedDocument):
|
||||
name = StringField()
|
||||
|
||||
def test_embedded_document_validation(self):
|
||||
"""Ensure that invalid embedded documents cannot be assigned to
|
||||
embedded document fields.
|
||||
@ -4388,6 +4397,44 @@ class EmbeddedDocumentListFieldTestCase(MongoDBTestCase):
|
||||
self.assertEqual(custom_data['a'], CustomData.c_field.custom_data['a'])
|
||||
|
||||
|
||||
class TestEmbeddedDocumentField(MongoDBTestCase):
|
||||
def test___init___(self):
|
||||
class MyDoc(EmbeddedDocument):
|
||||
name = StringField()
|
||||
|
||||
field = EmbeddedDocumentField(MyDoc)
|
||||
self.assertEqual(field.document_type_obj, MyDoc)
|
||||
|
||||
field2 = EmbeddedDocumentField('MyDoc')
|
||||
self.assertEqual(field2.document_type_obj, 'MyDoc')
|
||||
|
||||
def test___init___throw_error_if_document_type_is_not_EmbeddedDocument(self):
|
||||
with self.assertRaises(ValidationError):
|
||||
EmbeddedDocumentField(dict)
|
||||
|
||||
def test_document_type_throw_error_if_not_EmbeddedDocument_subclass(self):
|
||||
|
||||
class MyDoc(Document):
|
||||
name = StringField()
|
||||
|
||||
emb = EmbeddedDocumentField('MyDoc')
|
||||
with self.assertRaises(ValidationError) as ctx:
|
||||
emb.document_type
|
||||
self.assertIn('Invalid embedded document class provided to an EmbeddedDocumentField', str(ctx.exception))
|
||||
|
||||
def test_embedded_document_field_only_allow_subclasses_of_embedded_document(self):
|
||||
# Relates to #1661
|
||||
class MyDoc(Document):
|
||||
name = StringField()
|
||||
|
||||
with self.assertRaises(ValidationError):
|
||||
class MyFailingDoc(Document):
|
||||
emb = EmbeddedDocumentField(MyDoc)
|
||||
|
||||
with self.assertRaises(ValidationError):
|
||||
class MyFailingdoc2(Document):
|
||||
emb = EmbeddedDocumentField('MyDoc')
|
||||
|
||||
class CachedReferenceFieldTest(MongoDBTestCase):
|
||||
|
||||
def test_cached_reference_field_get_and_save(self):
|
||||
|
@ -7,12 +7,12 @@ from mongoengine.connection import get_db, get_connection
|
||||
from mongoengine.python_support import IS_PYMONGO_3
|
||||
|
||||
|
||||
MONGO_TEST_DB = 'mongoenginetest'
|
||||
MONGO_TEST_DB = 'mongoenginetest' # standard name for the test database
|
||||
|
||||
|
||||
class MongoDBTestCase(unittest.TestCase):
|
||||
"""Base class for tests that need a mongodb connection
|
||||
db is being dropped automatically
|
||||
It ensures that the db is clean at the beginning and dropped at the end automatically
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@ -32,6 +32,7 @@ def get_mongodb_version():
|
||||
"""
|
||||
return tuple(get_connection().server_info()['versionArray'])
|
||||
|
||||
|
||||
def _decorated_with_ver_requirement(func, ver_tuple):
|
||||
"""Return a given function decorated with the version requirement
|
||||
for a particular MongoDB version tuple.
|
||||
@ -50,18 +51,21 @@ def _decorated_with_ver_requirement(func, ver_tuple):
|
||||
|
||||
return _inner
|
||||
|
||||
|
||||
def needs_mongodb_v26(func):
|
||||
"""Raise a SkipTest exception if we're working with MongoDB version
|
||||
lower than v2.6.
|
||||
"""
|
||||
return _decorated_with_ver_requirement(func, (2, 6))
|
||||
|
||||
|
||||
def needs_mongodb_v3(func):
|
||||
"""Raise a SkipTest exception if we're working with MongoDB version
|
||||
lower than v3.0.
|
||||
"""
|
||||
return _decorated_with_ver_requirement(func, (3, 0))
|
||||
|
||||
|
||||
def skip_pymongo3(f):
|
||||
"""Raise a SkipTest exception if we're running a test against
|
||||
PyMongo v3.x.
|
||||
|
Loading…
x
Reference in New Issue
Block a user