diff --git a/AUTHORS b/AUTHORS index aa044bd2..1271a8d9 100644 --- a/AUTHORS +++ b/AUTHORS @@ -253,3 +253,4 @@ that much better: * Gaurav Dadhania (https://github.com/GVRV) * Yurii Andrieiev (https://github.com/yandrieiev) * Filip Kucharczyk (https://github.com/Pacu2) + * Matthew Simpson (https://github.com/mcsimps2) diff --git a/mongoengine/fields.py b/mongoengine/fields.py index f8f527a3..c3d93740 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -41,6 +41,7 @@ from mongoengine.common import _import_class from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db from mongoengine.document import Document, EmbeddedDocument from mongoengine.errors import DoesNotExist, InvalidQueryError, ValidationError +from mongoengine.mongodb_support import MONGODB_36, get_mongodb_version from mongoengine.python_support import StringIO from mongoengine.queryset import DO_NOTHING from mongoengine.queryset.base import BaseQuerySet @@ -1051,6 +1052,15 @@ def key_has_dot_or_dollar(d): return True +def key_starts_with_dollar(d): + """Helper function to recursively determine if any key in a + dictionary starts with a dollar + """ + for k, v in d.items(): + if (k.startswith("$")) or (isinstance(v, dict) and key_starts_with_dollar(v)): + return True + + class DictField(ComplexBaseField): """A dictionary field that wraps a standard Python dictionary. This is similar to an embedded document, but the structure is not defined. @@ -1077,11 +1087,18 @@ class DictField(ComplexBaseField): if key_not_string(value): msg = "Invalid dictionary key - documents must have only string keys" self.error(msg) - if key_has_dot_or_dollar(value): + + curr_mongo_ver = get_mongodb_version() + + if curr_mongo_ver < MONGODB_36 and key_has_dot_or_dollar(value): self.error( 'Invalid dictionary key name - keys may not contain "."' ' or startswith "$" characters' ) + elif curr_mongo_ver >= MONGODB_36 and key_starts_with_dollar(value): + self.error( + 'Invalid dictionary key name - keys may not startswith "$" characters' + ) super(DictField, self).validate(value) def lookup_member(self, member_name): diff --git a/tests/fields/test_dict_field.py b/tests/fields/test_dict_field.py index e88128f9..44e628f6 100644 --- a/tests/fields/test_dict_field.py +++ b/tests/fields/test_dict_field.py @@ -3,6 +3,7 @@ import pytest from mongoengine import * from mongoengine.base import BaseDict +from mongoengine.mongodb_support import MONGODB_36, get_mongodb_version from tests.utils import MongoDBTestCase, get_as_pymongo @@ -43,11 +44,7 @@ class TestDictField(MongoDBTestCase): with pytest.raises(ValidationError): post.validate() - post.info = {"the.title": "test"} - with pytest.raises(ValidationError): - post.validate() - - post.info = {"nested": {"the.title": "test"}} + post.info = {"$title.test": "test"} with pytest.raises(ValidationError): post.validate() @@ -55,6 +52,20 @@ class TestDictField(MongoDBTestCase): with pytest.raises(ValidationError): post.validate() + post.info = {"nested": {"the.title": "test"}} + if get_mongodb_version() < MONGODB_36: + with pytest.raises(ValidationError): + post.validate() + else: + post.validate() + + post.info = {"dollar_and_dot": {"te$st.test": "test"}} + if get_mongodb_version() < MONGODB_36: + with pytest.raises(ValidationError): + post.validate() + else: + post.validate() + post.info = {"title": "test"} post.save()