diff --git a/mongoengine/fields.py b/mongoengine/fields.py index e29f55ab..c20333a8 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -1,5 +1,6 @@ import datetime import decimal +import inspect import itertools import re import socket @@ -514,7 +515,7 @@ class BooleanField(BaseField): def to_python(self, value): try: value = bool(value) - except ValueError: + except (ValueError, TypeError): pass return value @@ -1028,17 +1029,6 @@ def key_not_string(d): return True -def key_has_dot_or_dollar(d): - """Helper function to recursively determine if any key in a - dictionary contains a dot or a dollar sign. - """ - for k, v in d.items(): - if ("." in k or k.startswith("$")) or ( - isinstance(v, dict) and key_has_dot_or_dollar(v) - ): - return True - - def key_starts_with_dollar(d): """Helper function to recursively determine if any key in a dictionary starts with a dollar @@ -1311,8 +1301,8 @@ class CachedReferenceField(BaseField): fields = [] # XXX ValidationError raised outside of the "validate" method. - if not isinstance(document_type, str) and not issubclass( - document_type, Document + if not isinstance(document_type, str) and not ( + inspect.isclass(document_type) and issubclass(document_type, Document) ): self.error( "Argument to CachedReferenceField constructor must be a" @@ -1642,7 +1632,7 @@ class EnumField(BaseField): "'choices' can't be set on EnumField, " "it is implicitly set as the enum class" ) - kwargs["choices"] = list(self._enum_cls) + kwargs["choices"] = list(self._enum_cls) # Implicit validator super().__init__(**kwargs) def __set__(self, instance, value): @@ -1659,13 +1649,6 @@ class EnumField(BaseField): return value.value return value - def validate(self, value): - if value and not isinstance(value, self._enum_cls): - try: - self._enum_cls(value) - except Exception as e: - self.error(str(e)) - def prepare_query_value(self, op, value): if value is None: return value diff --git a/tests/fields/test_boolean_field.py b/tests/fields/test_boolean_field.py index 737e0dbf..d82d149f 100644 --- a/tests/fields/test_boolean_field.py +++ b/tests/fields/test_boolean_field.py @@ -13,6 +13,17 @@ class TestBooleanField(MongoDBTestCase): person.save() assert get_as_pymongo(person) == {"_id": person.id, "admin": True} + def test_construction_does_not_fail_uncastable_value(self): + class BoolFail: + def __bool__(self): + return "bogus" + + class Person(Document): + admin = BooleanField() + + person = Person(admin=BoolFail()) + person.admin == "bogus" + def test_validation(self): """Ensure that invalid values cannot be assigned to boolean fields. diff --git a/tests/fields/test_cached_reference_field.py b/tests/fields/test_cached_reference_field.py index 7a96bc06..9d0b387b 100644 --- a/tests/fields/test_cached_reference_field.py +++ b/tests/fields/test_cached_reference_field.py @@ -2,11 +2,28 @@ from decimal import Decimal import pytest -from mongoengine import * +from mongoengine import ( + CachedReferenceField, + DecimalField, + Document, + EmbeddedDocument, + EmbeddedDocumentField, + InvalidDocumentError, + ListField, + ReferenceField, + StringField, + ValidationError, +) from tests.utils import MongoDBTestCase class TestCachedReferenceField(MongoDBTestCase): + def test_constructor_fail_bad_document_type(self): + with pytest.raises( + ValidationError, match="must be a document class or a string" + ): + CachedReferenceField(document_type=0) + def test_get_and_save(self): """ Tests #1047: CachedReferenceField creates DBRefs on to_python, diff --git a/tests/fields/test_decimal_field.py b/tests/fields/test_decimal_field.py index 519356e9..89a725a9 100644 --- a/tests/fields/test_decimal_field.py +++ b/tests/fields/test_decimal_field.py @@ -2,59 +2,11 @@ from decimal import Decimal import pytest -from mongoengine import * +from mongoengine import DecimalField, Document, ValidationError from tests.utils import MongoDBTestCase class TestDecimalField(MongoDBTestCase): - def test_validation(self): - """Ensure that invalid values cannot be assigned to decimal fields.""" - - class Person(Document): - height = DecimalField(min_value=Decimal("0.1"), max_value=Decimal("3.5")) - - Person.drop_collection() - - Person(height=Decimal("1.89")).save() - person = Person.objects.first() - assert person.height == Decimal("1.89") - - person.height = "2.0" - person.save() - person.height = 0.01 - with pytest.raises(ValidationError): - person.validate() - person.height = Decimal("0.01") - with pytest.raises(ValidationError): - person.validate() - person.height = Decimal("4.0") - with pytest.raises(ValidationError): - person.validate() - person.height = "something invalid" - with pytest.raises(ValidationError): - person.validate() - - person_2 = Person(height="something invalid") - with pytest.raises(ValidationError): - person_2.validate() - - def test_comparison(self): - class Person(Document): - money = DecimalField() - - Person.drop_collection() - - Person(money=6).save() - Person(money=7).save() - Person(money=8).save() - Person(money=10).save() - - assert 2 == Person.objects(money__gt=Decimal("7")).count() - assert 2 == Person.objects(money__gt=7).count() - assert 2 == Person.objects(money__gt="7").count() - - assert 3 == Person.objects(money__gte="7").count() - def test_storage(self): class Person(Document): float_value = DecimalField(precision=4) @@ -106,3 +58,63 @@ class TestDecimalField(MongoDBTestCase): for field_name in ["float_value", "string_value"]: actual = list(Person.objects().scalar(field_name)) assert expected == actual + + def test_save_none(self): + class Person(Document): + value = DecimalField() + + Person.drop_collection() + + person = Person(value=None) + assert person.value is None + person.save() + fetched_person = Person.objects.first() + fetched_person.value is None + + def test_validation(self): + """Ensure that invalid values cannot be assigned to decimal fields.""" + + class Person(Document): + height = DecimalField(min_value=Decimal("0.1"), max_value=Decimal("3.5")) + + Person.drop_collection() + + Person(height=Decimal("1.89")).save() + person = Person.objects.first() + assert person.height == Decimal("1.89") + + person.height = "2.0" + person.save() + person.height = 0.01 + with pytest.raises(ValidationError): + person.validate() + person.height = Decimal("0.01") + with pytest.raises(ValidationError): + person.validate() + person.height = Decimal("4.0") + with pytest.raises(ValidationError): + person.validate() + person.height = "something invalid" + with pytest.raises(ValidationError): + person.validate() + + person_2 = Person(height="something invalid") + with pytest.raises(ValidationError): + person_2.validate() + + def test_comparison(self): + class Person(Document): + money = DecimalField() + + Person.drop_collection() + + Person(money=6).save() + Person(money=7).save() + Person(money=8).save() + Person(money=10).save() + + assert 2 == Person.objects(money__gt=Decimal("7")).count() + assert 2 == Person.objects(money__gt=7).count() + assert 2 == Person.objects(money__gt="7").count() + + assert 3 == Person.objects(money__gte="7").count() diff --git a/tests/fields/test_enum_field.py b/tests/fields/test_enum_field.py index fc42487b..c101996c 100644 --- a/tests/fields/test_enum_field.py +++ b/tests/fields/test_enum_field.py @@ -3,7 +3,7 @@ from enum import Enum from bson import InvalidDocument import pytest -from mongoengine import * +from mongoengine import Document, EnumField, ValidationError from tests.utils import MongoDBTestCase, get_as_pymongo @@ -45,6 +45,11 @@ class TestStringEnumField(MongoDBTestCase): m.save() assert m.status == Status.DONE + m.status = "wrong" + assert m.status == "wrong" + with pytest.raises(ValidationError): + m.validate() + def test_set_default(self): class ModelWithDefault(Document): status = EnumField(Status, default=Status.DONE) diff --git a/tests/fields/test_fields.py b/tests/fields/test_fields.py index 266d0e9d..f7272619 100644 --- a/tests/fields/test_fields.py +++ b/tests/fields/test_fields.py @@ -374,34 +374,6 @@ class TestField(MongoDBTestCase): person.id = str(ObjectId()) person.validate() - def test_string_validation(self): - """Ensure that invalid values cannot be assigned to string fields.""" - - class Person(Document): - name = StringField(max_length=20) - userid = StringField(r"[0-9a-z_]+$") - - person = Person(name=34) - with pytest.raises(ValidationError): - person.validate() - - # Test regex validation on userid - person = Person(userid="test.User") - with pytest.raises(ValidationError): - person.validate() - - person.userid = "test_user" - assert person.userid == "test_user" - person.validate() - - # Test max length validation on name - person = Person(name="Name that is more than twenty characters") - with pytest.raises(ValidationError): - person.validate() - - person.name = "Shorter name" - person.validate() - def test_db_field_validation(self): """Ensure that db_field doesn't accept invalid values.""" diff --git a/tests/fields/test_long_field.py b/tests/fields/test_long_field.py index 67fd0bce..4a3041f2 100644 --- a/tests/fields/test_long_field.py +++ b/tests/fields/test_long_field.py @@ -4,10 +4,26 @@ import pytest from mongoengine import * from mongoengine.connection import get_db -from tests.utils import MongoDBTestCase +from tests.utils import MongoDBTestCase, get_as_pymongo class TestLongField(MongoDBTestCase): + def test_storage(self): + class Person(Document): + value = LongField() + + Person.drop_collection() + person = Person(value=5000) + person.save() + assert get_as_pymongo(person) == {"_id": person.id, "value": 5000} + + def test_construction_does_not_fail_with_invalid_value(self): + class Person(Document): + value = LongField() + + person = Person(value="not_an_int") + assert person.value == "not_an_int" + def test_long_field_is_considered_as_int64(self): """ Tests that long fields are stored as long in mongo, even if long @@ -30,19 +46,16 @@ class TestLongField(MongoDBTestCase): class TestDocument(Document): value = LongField(min_value=0, max_value=110) - doc = TestDocument() - doc.value = 50 - doc.validate() + TestDocument(value=50).validate() - doc.value = -1 with pytest.raises(ValidationError): - doc.validate() - doc.value = 120 + TestDocument(value=-1).validate() + with pytest.raises(ValidationError): - doc.validate() - doc.value = "ten" + TestDocument(value=120).validate() + with pytest.raises(ValidationError): - doc.validate() + TestDocument(value="ten").validate() def test_long_ne_operator(self): class TestDocument(Document): @@ -53,4 +66,5 @@ class TestLongField(MongoDBTestCase): TestDocument(long_fld=None).save() TestDocument(long_fld=1).save() - assert 1 == TestDocument.objects(long_fld__ne=None).count() + assert TestDocument.objects(long_fld__ne=None).count() == 1 + assert TestDocument.objects(long_fld__ne=1).count() == 1 diff --git a/tests/fields/test_string_field.py b/tests/fields/test_string_field.py new file mode 100644 index 00000000..6e1d77f2 --- /dev/null +++ b/tests/fields/test_string_field.py @@ -0,0 +1,43 @@ +import pytest + +from mongoengine import * +from tests.utils import MongoDBTestCase, get_as_pymongo + + +class TestStringField(MongoDBTestCase): + def test_storage(self): + class Person(Document): + name = StringField() + + Person.drop_collection() + person = Person(name="test123") + person.save() + assert get_as_pymongo(person) == {"_id": person.id, "name": "test123"} + + def test_validation(self): + class Person(Document): + name = StringField(max_length=20, min_length=2) + userid = StringField(r"[0-9a-z_]+$") + + with pytest.raises(ValidationError, match="only accepts string values"): + Person(name=34).validate() + + with pytest.raises(ValidationError, match="value is too short"): + Person(name="s").validate() + + # Test regex validation on userid + person = Person(userid="test.User") + with pytest.raises(ValidationError): + person.validate() + + person.userid = "test_user" + assert person.userid == "test_user" + person.validate() + + # Test max length validation on name + person = Person(name="Name that is more than twenty characters") + with pytest.raises(ValidationError): + person.validate() + + person = Person(name="a friendl name", userid="7a757668sqjdkqlsdkq") + person.validate()