Merge pull request #1859 from bagerard/fix_EmbeddedDocumentField_init_with_Document

Ensures EmbeddedDocumentField does not accepts references to Document
This commit is contained in:
erdenezul 2018-09-03 17:22:08 +08:00 committed by GitHub
commit 3574e21e4f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 63 additions and 4 deletions

View File

@ -647,9 +647,17 @@ class EmbeddedDocumentField(BaseField):
def document_type(self): def document_type(self):
if isinstance(self.document_type_obj, six.string_types): if isinstance(self.document_type_obj, six.string_types):
if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT: if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT:
self.document_type_obj = self.owner_document resolved_document_type = self.owner_document
else: else:
self.document_type_obj = get_document(self.document_type_obj) resolved_document_type = get_document(self.document_type_obj)
if not issubclass(resolved_document_type, EmbeddedDocument):
# Due to the late resolution of the document_type
# There is a chance that it won't be an EmbeddedDocument (#1661)
self.error('Invalid embedded document class provided to an '
'EmbeddedDocumentField')
self.document_type_obj = resolved_document_type
return self.document_type_obj return self.document_type_obj
def to_python(self, value): def to_python(self, value):

View File

@ -2147,6 +2147,15 @@ class FieldTest(MongoDBTestCase):
])) ]))
self.assertEqual(a.b.c.txt, 'hi') self.assertEqual(a.b.c.txt, 'hi')
def test_embedded_document_field_cant_reference_using_a_str_if_it_does_not_exist_yet(self):
raise SkipTest("Using a string reference in an EmbeddedDocumentField does not work if the class isnt registerd yet")
class MyDoc2(Document):
emb = EmbeddedDocumentField('MyDoc')
class MyDoc(EmbeddedDocument):
name = StringField()
def test_embedded_document_validation(self): def test_embedded_document_validation(self):
"""Ensure that invalid embedded documents cannot be assigned to """Ensure that invalid embedded documents cannot be assigned to
embedded document fields. embedded document fields.
@ -4388,6 +4397,44 @@ class EmbeddedDocumentListFieldTestCase(MongoDBTestCase):
self.assertEqual(custom_data['a'], CustomData.c_field.custom_data['a']) self.assertEqual(custom_data['a'], CustomData.c_field.custom_data['a'])
class TestEmbeddedDocumentField(MongoDBTestCase):
def test___init___(self):
class MyDoc(EmbeddedDocument):
name = StringField()
field = EmbeddedDocumentField(MyDoc)
self.assertEqual(field.document_type_obj, MyDoc)
field2 = EmbeddedDocumentField('MyDoc')
self.assertEqual(field2.document_type_obj, 'MyDoc')
def test___init___throw_error_if_document_type_is_not_EmbeddedDocument(self):
with self.assertRaises(ValidationError):
EmbeddedDocumentField(dict)
def test_document_type_throw_error_if_not_EmbeddedDocument_subclass(self):
class MyDoc(Document):
name = StringField()
emb = EmbeddedDocumentField('MyDoc')
with self.assertRaises(ValidationError) as ctx:
emb.document_type
self.assertIn('Invalid embedded document class provided to an EmbeddedDocumentField', str(ctx.exception))
def test_embedded_document_field_only_allow_subclasses_of_embedded_document(self):
# Relates to #1661
class MyDoc(Document):
name = StringField()
with self.assertRaises(ValidationError):
class MyFailingDoc(Document):
emb = EmbeddedDocumentField(MyDoc)
with self.assertRaises(ValidationError):
class MyFailingdoc2(Document):
emb = EmbeddedDocumentField('MyDoc')
class CachedReferenceFieldTest(MongoDBTestCase): class CachedReferenceFieldTest(MongoDBTestCase):
def test_cached_reference_field_get_and_save(self): def test_cached_reference_field_get_and_save(self):

View File

@ -7,12 +7,12 @@ from mongoengine.connection import get_db, get_connection
from mongoengine.python_support import IS_PYMONGO_3 from mongoengine.python_support import IS_PYMONGO_3
MONGO_TEST_DB = 'mongoenginetest' MONGO_TEST_DB = 'mongoenginetest' # standard name for the test database
class MongoDBTestCase(unittest.TestCase): class MongoDBTestCase(unittest.TestCase):
"""Base class for tests that need a mongodb connection """Base class for tests that need a mongodb connection
db is being dropped automatically It ensures that the db is clean at the beginning and dropped at the end automatically
""" """
@classmethod @classmethod
@ -32,6 +32,7 @@ def get_mongodb_version():
""" """
return tuple(get_connection().server_info()['versionArray']) return tuple(get_connection().server_info()['versionArray'])
def _decorated_with_ver_requirement(func, ver_tuple): def _decorated_with_ver_requirement(func, ver_tuple):
"""Return a given function decorated with the version requirement """Return a given function decorated with the version requirement
for a particular MongoDB version tuple. for a particular MongoDB version tuple.
@ -50,18 +51,21 @@ def _decorated_with_ver_requirement(func, ver_tuple):
return _inner return _inner
def needs_mongodb_v26(func): def needs_mongodb_v26(func):
"""Raise a SkipTest exception if we're working with MongoDB version """Raise a SkipTest exception if we're working with MongoDB version
lower than v2.6. lower than v2.6.
""" """
return _decorated_with_ver_requirement(func, (2, 6)) return _decorated_with_ver_requirement(func, (2, 6))
def needs_mongodb_v3(func): def needs_mongodb_v3(func):
"""Raise a SkipTest exception if we're working with MongoDB version """Raise a SkipTest exception if we're working with MongoDB version
lower than v3.0. lower than v3.0.
""" """
return _decorated_with_ver_requirement(func, (3, 0)) return _decorated_with_ver_requirement(func, (3, 0))
def skip_pymongo3(f): def skip_pymongo3(f):
"""Raise a SkipTest exception if we're running a test against """Raise a SkipTest exception if we're running a test against
PyMongo v3.x. PyMongo v3.x.