From aabc18755c591fac355d9966bce5758de594c4bd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Wed, 7 Oct 2020 00:01:09 +0200 Subject: [PATCH 1/6] fix inconsistencies in ._changed_fields computation --- mongoengine/base/document.py | 42 +++++++++++++++++++++++++++++++++--- tests/document/test_delta.py | 34 ++++++++++++++++++++++++++--- tests/test_dereference.py | 4 ++-- 3 files changed, 72 insertions(+), 8 deletions(-) diff --git a/mongoengine/base/document.py b/mongoengine/base/document.py index e697fe40..55b40228 100644 --- a/mongoengine/base/document.py +++ b/mongoengine/base/document.py @@ -537,6 +537,9 @@ class BaseDocument: """Using _get_changed_fields iterate and remove any fields that are marked as changed. """ + ReferenceField = _import_class("ReferenceField") + GenericReferenceField = _import_class("GenericReferenceField") + for changed in self._get_changed_fields(): parts = changed.split(".") data = self @@ -549,7 +552,8 @@ class BaseDocument: elif isinstance(data, dict): data = data.get(part, None) else: - data = getattr(data, part, None) + field_name = data._reverse_db_field_map.get(part, part) + data = getattr(data, field_name, None) if not isinstance(data, LazyReference) and hasattr( data, "_changed_fields" @@ -558,10 +562,40 @@ class BaseDocument: continue data._changed_fields = [] + elif isinstance(data, (list, tuple, dict)): + if hasattr(data, "field") and isinstance( + data.field, (ReferenceField, GenericReferenceField) + ): + continue + BaseDocument._nestable_types_clear_changed_fields(data) self._changed_fields = [] - def _nestable_types_changed_fields(self, changed_fields, base_key, data): + @staticmethod + def _nestable_types_clear_changed_fields(data): + """Inspect nested data for changed fields + + :param data: data to inspect for changes + """ + Document = _import_class("Document") + + # Loop list / dict fields as they contain documents + # Determine the iterator to use + if not hasattr(data, "items"): + iterator = enumerate(data) + else: + iterator = data.items() + + for index_or_key, value in iterator: + if hasattr(value, "_get_changed_fields") and not isinstance( + value, Document + ): # don't follow references + value._clear_changed_fields() + elif isinstance(value, (list, tuple, dict)): + BaseDocument._nestable_types_clear_changed_fields(value) + + @staticmethod + def _nestable_types_changed_fields(changed_fields, base_key, data): """Inspect nested data for changed fields :param changed_fields: Previously collected changed fields @@ -586,7 +620,9 @@ class BaseDocument: changed = value._get_changed_fields() changed_fields += ["{}{}".format(item_key, k) for k in changed if k] elif isinstance(value, (list, tuple, dict)): - self._nestable_types_changed_fields(changed_fields, item_key, value) + BaseDocument._nestable_types_changed_fields( + changed_fields, item_key, value + ) def _get_changed_fields(self): """Return a list of all fields that have explicitly been changed. diff --git a/tests/document/test_delta.py b/tests/document/test_delta.py index 2324211b..27439bc2 100644 --- a/tests/document/test_delta.py +++ b/tests/document/test_delta.py @@ -537,6 +537,7 @@ class TestDelta(MongoDBTestCase): {}, ) doc.save() + assert doc._get_changed_fields() == [] doc = doc.reload(10) assert doc.embedded_field.list_field[0] == "1" @@ -767,9 +768,7 @@ class TestDelta(MongoDBTestCase): MyDoc.drop_collection() - mydoc = MyDoc( - name="testcase1", subs={"a": {"b": EmbeddedDoc(name="foo")}} - ).save() + MyDoc(name="testcase1", subs={"a": {"b": EmbeddedDoc(name="foo")}}).save() mydoc = MyDoc.objects.first() subdoc = mydoc.subs["a"]["b"] @@ -781,6 +780,35 @@ class TestDelta(MongoDBTestCase): mydoc._clear_changed_fields() assert [] == mydoc._get_changed_fields() + def test_nested_nested_fields_db_field_set__gets_mark_as_changed_and_cleaned(self): + class EmbeddedDoc(EmbeddedDocument): + name = StringField(db_field="db_name") + + class MyDoc(Document): + embed = EmbeddedDocumentField(EmbeddedDoc, db_field="db_embed") + name = StringField(db_field="db_name") + + MyDoc.drop_collection() + + MyDoc(name="testcase1", embed=EmbeddedDoc(name="foo")).save() + + mydoc = MyDoc.objects.first() + mydoc.embed.name = "foo1" + + assert mydoc.embed._get_changed_fields() == ["db_name"] + assert mydoc._get_changed_fields() == ["db_embed.db_name"] + + mydoc = MyDoc.objects.first() + embed = EmbeddedDoc(name="foo2") + embed.name = "bar" + mydoc.embed = embed + + assert embed._get_changed_fields() == ["db_name"] + assert mydoc._get_changed_fields() == ["db_embed"] + + mydoc._clear_changed_fields() + assert mydoc._get_changed_fields() == [] + def test_lower_level_mark_as_changed(self): class EmbeddedDoc(EmbeddedDocument): name = StringField() diff --git a/tests/test_dereference.py b/tests/test_dereference.py index 0f9f412c..8ba429f4 100644 --- a/tests/test_dereference.py +++ b/tests/test_dereference.py @@ -370,8 +370,7 @@ class FieldTest(unittest.TestCase): assert Post.objects.all()[0].user_lists == [[u1, u2], [u3]] def test_circular_reference(self): - """Ensure you can handle circular references - """ + """Ensure you can handle circular references""" class Relation(EmbeddedDocument): name = StringField() @@ -426,6 +425,7 @@ class FieldTest(unittest.TestCase): daughter.relations.append(mother) daughter.relations.append(daughter) + assert daughter._get_changed_fields() == ["relations"] daughter.save() assert "[, ]" == "%s" % Person.objects() From c9d53ca5d50e79bf1e8fb7afd012984b91ff5824 Mon Sep 17 00:00:00 2001 From: Mateusz Stankiewicz Date: Fri, 30 Oct 2020 13:06:37 +0100 Subject: [PATCH 2/6] Add EnumField --- AUTHORS | 1 + docs/changelog.rst | 1 + docs/guide/defining-documents.rst | 1 + mongoengine/fields.py | 52 +++++++++++++++ tests/fields/test_enum_field.py | 103 ++++++++++++++++++++++++++++++ 5 files changed, 158 insertions(+) create mode 100644 tests/fields/test_enum_field.py diff --git a/AUTHORS b/AUTHORS index 02e43955..10d04c68 100644 --- a/AUTHORS +++ b/AUTHORS @@ -257,3 +257,4 @@ that much better: * Matthew Simpson (https://github.com/mcsimps2) * Leonardo Domingues (https://github.com/leodmgs) * Agustin Barto (https://github.com/abarto) + * Stankiewicz Mateusz (https://github.com/mas15) diff --git a/docs/changelog.rst b/docs/changelog.rst index f616f4a6..f63edc61 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -13,6 +13,7 @@ Development - Fix the behavior of Doc.objects.limit(0) which should return all documents (similar to mongodb) #2311 - Bug fix in ListField when updating the first item, it was saving the whole list, instead of just replacing the first item (as it's usually done) #2392 +- Add EnumField: ``mongoengine.fields.EnumField`` Changes in 0.20.0 ================= diff --git a/docs/guide/defining-documents.rst b/docs/guide/defining-documents.rst index f5c70728..7fc20ba8 100644 --- a/docs/guide/defining-documents.rst +++ b/docs/guide/defining-documents.rst @@ -76,6 +76,7 @@ are as follows: * :class:`~mongoengine.fields.EmailField` * :class:`~mongoengine.fields.EmbeddedDocumentField` * :class:`~mongoengine.fields.EmbeddedDocumentListField` +* :class:`~mongoengine.fields.EnumField` * :class:`~mongoengine.fields.FileField` * :class:`~mongoengine.fields.FloatField` * :class:`~mongoengine.fields.GenericEmbeddedDocumentField` diff --git a/mongoengine/fields.py b/mongoengine/fields.py index c5926cbd..1ce66055 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -87,6 +87,7 @@ __all__ = ( "PolygonField", "SequenceField", "UUIDField", + "EnumField", "MultiPointField", "MultiLineStringField", "MultiPolygonField", @@ -1622,6 +1623,57 @@ class BinaryField(BaseField): return super().prepare_query_value(op, self.to_mongo(value)) +class EnumField(BaseField): + """ Enumeration Field. Values are stored underneath as strings. + Example usage: + .. code-block:: python + + class Status(Enum): + NEW = 'new' + DONE = 'done' + + class ModelWithEnum(Document): + status = EnumField(Status, default=Status.NEW) + + ModelWithEnum(status='done') + ModelWithEnum(status=Status.DONE) + + Enum fields can be searched using enum or its value: + .. code-block:: python + + ModelWithEnum.objects(status='new').count() + ModelWithEnum.objects(status=Status.NEW).count() + """ + def __init__(self, enum, **kwargs): + self._enum_cls = enum + kwargs["choices"] = list(self._enum_cls) + super().__init__(**kwargs) + + def __set__(self, instance, value): + if value is None or isinstance(value, self._enum_cls): + value = value # if it is proper enum or none then fine + else: # if it not, then try to create enum of it + value = self._enum_cls(value) + return super().__set__(instance, value) + + def to_mongo(self, value): + if isinstance(value, self._enum_cls): + return str(value.value) + return str(value) + + def validate(self, value): + if not isinstance(value, self._enum_cls): + self.error( + "EnumField only accepts instances of " + "(%s)" % self._enum_cls + ) + + def prepare_query_value(self, op, value): + if value is None: + return value + return super().prepare_query_value(op, self.to_mongo(value)) + + class GridFSError(Exception): pass diff --git a/tests/fields/test_enum_field.py b/tests/fields/test_enum_field.py new file mode 100644 index 00000000..00404aa5 --- /dev/null +++ b/tests/fields/test_enum_field.py @@ -0,0 +1,103 @@ +# -*- coding: utf-8 -*- +from enum import Enum + +import pytest + +from mongoengine import * +from tests.utils import MongoDBTestCase, get_as_pymongo + + +class Status(Enum): + NEW = 'new' + DONE = 'done' + + +class ModelWithEnum(Document): + status = EnumField(Status) + + +class Color(Enum): + RED = 1 + BLUE = 2 + + +class ModelWithColor(Document): + color = EnumField(Color, default=Color.RED) + + +class TestEnumField(MongoDBTestCase): + def test_storage(self): + model = ModelWithEnum(status=Status.NEW).save() + assert get_as_pymongo(model) == {"_id": model.id, "status": 'new'} + + def test_set_enum(self): + ModelWithEnum.drop_collection() + m = ModelWithEnum(status=Status.NEW).save() + assert ModelWithEnum.objects(status=Status.NEW).count() == 1 + assert ModelWithEnum.objects.first().status == Status.NEW + m.validate() + + def test_set_by_value(self): + ModelWithEnum.drop_collection() + ModelWithEnum(status='new').save() + assert ModelWithEnum.objects.first().status == Status.NEW + + def test_filter(self): + ModelWithEnum.drop_collection() + ModelWithEnum(status='new').save() + assert ModelWithEnum.objects(status='new').count() == 1 + assert ModelWithEnum.objects(status=Status.NEW).count() == 1 + assert ModelWithEnum.objects(status=Status.DONE).count() == 0 + + def test_change_value(self): + m = ModelWithEnum(status='new') + m.status = Status.DONE + m.validate() + assert m.status == Status.DONE + + def test_set_default(self): + class ModelWithDefault(Document): + status = EnumField(Status, default=Status.DONE) + + m = ModelWithDefault() + m.validate() + m.save() + assert m.status == Status.DONE + + def test_enum_with_int(self): + m = ModelWithColor() + m.validate() + m.save() + assert m.color == Color.RED + assert ModelWithColor.objects(color=Color.RED).count() == 1 + assert ModelWithColor.objects(color=1).count() == 1 + assert ModelWithColor.objects(color=2).count() == 0 + + def test_storage_enum_with_int(self): + model = ModelWithColor(color=Color.BLUE).save() + assert get_as_pymongo(model) == {"_id": model.id, "color": "2"} + + def test_enum_field_can_be_empty(self): + m = ModelWithEnum() + m.validate() + m.save() + assert m.status is None + assert ModelWithEnum.objects()[0].status is None + assert ModelWithEnum.objects(status=None).count() == 1 + + def test_cannot_create_model_with_wrong_enum_value(self): + with pytest.raises(ValueError): + ModelWithEnum(status='wrong_one') + + def test_cannot_create_model_with_wrong_enum_type(self): + with pytest.raises(ValueError): + ModelWithColor(color='wrong_type') + + def test_cannot_create_model_with_wrong_enum_value_int(self): + with pytest.raises(ValueError): + ModelWithColor(color=3) + + def test_cannot_set_wrong_enum_value(self): + m = ModelWithEnum(status='new') + with pytest.raises(ValueError): + m.status = 'wrong' From f4962fbc40060ba205cd2aae6cc1acd53295f4e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Fri, 30 Oct 2020 21:10:21 +0100 Subject: [PATCH 3/6] remove utf8 encoding declaration in test files as it's not needed/recommended --- docs/conf.py | 1 - tests/document/test_class_methods.py | 1 - tests/document/test_delta.py | 1 - tests/document/test_indexes.py | 1 - tests/document/test_inheritance.py | 1 - tests/document/test_instance.py | 1 - tests/document/test_validation.py | 1 - tests/fields/test_binary_field.py | 1 - tests/fields/test_boolean_field.py | 1 - tests/fields/test_cached_reference_field.py | 1 - tests/fields/test_complex_datetime_field.py | 1 - tests/fields/test_date_field.py | 1 - tests/fields/test_datetime_field.py | 1 - tests/fields/test_decimal_field.py | 1 - tests/fields/test_dict_field.py | 1 - tests/fields/test_email_field.py | 1 - tests/fields/test_embedded_document_field.py | 1 - tests/fields/test_fields.py | 1 - tests/fields/test_file_field.py | 1 - tests/fields/test_float_field.py | 1 - tests/fields/test_geo_fields.py | 1 - tests/fields/test_int_field.py | 1 - tests/fields/test_lazy_reference_field.py | 1 - tests/fields/test_map_field.py | 1 - tests/fields/test_reference_field.py | 1 - tests/fields/test_sequence_field.py | 2 -- tests/fields/test_url_field.py | 1 - tests/fields/test_uuid_field.py | 1 - tests/queryset/test_queryset.py | 2 -- tests/queryset/test_queryset_aggregation.py | 2 -- tests/test_dereference.py | 1 - tests/test_signals.py | 1 - 32 files changed, 35 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 48c8e859..fdb5b61d 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # MongoEngine documentation build configuration file, created by # sphinx-quickstart on Sun Nov 22 18:14:13 2009. diff --git a/tests/document/test_class_methods.py b/tests/document/test_class_methods.py index be883b2a..0276457c 100644 --- a/tests/document/test_class_methods.py +++ b/tests/document/test_class_methods.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import unittest from mongoengine import * diff --git a/tests/document/test_delta.py b/tests/document/test_delta.py index 94f89f99..ed3bb67b 100644 --- a/tests/document/test_delta.py +++ b/tests/document/test_delta.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import unittest from bson import SON diff --git a/tests/document/test_indexes.py b/tests/document/test_indexes.py index 45d1cd23..726eebe5 100644 --- a/tests/document/test_indexes.py +++ b/tests/document/test_indexes.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import unittest from datetime import datetime diff --git a/tests/document/test_inheritance.py b/tests/document/test_inheritance.py index 53a1489b..e7901f05 100644 --- a/tests/document/test_inheritance.py +++ b/tests/document/test_inheritance.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import unittest import warnings diff --git a/tests/document/test_instance.py b/tests/document/test_instance.py index 8d42d15b..9554659c 100644 --- a/tests/document/test_instance.py +++ b/tests/document/test_instance.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import os import pickle import unittest diff --git a/tests/document/test_validation.py b/tests/document/test_validation.py index 2439f283..c4228a96 100644 --- a/tests/document/test_validation.py +++ b/tests/document/test_validation.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import unittest from datetime import datetime diff --git a/tests/fields/test_binary_field.py b/tests/fields/test_binary_field.py index a9c0c7e5..4f7af325 100644 --- a/tests/fields/test_binary_field.py +++ b/tests/fields/test_binary_field.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import uuid from bson import Binary diff --git a/tests/fields/test_boolean_field.py b/tests/fields/test_boolean_field.py index 041f9f56..737e0dbf 100644 --- a/tests/fields/test_boolean_field.py +++ b/tests/fields/test_boolean_field.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import pytest from mongoengine import * diff --git a/tests/fields/test_cached_reference_field.py b/tests/fields/test_cached_reference_field.py index bb4c57d2..dd804b38 100644 --- a/tests/fields/test_cached_reference_field.py +++ b/tests/fields/test_cached_reference_field.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- from decimal import Decimal import pytest diff --git a/tests/fields/test_complex_datetime_field.py b/tests/fields/test_complex_datetime_field.py index d118ad23..d8ae9175 100644 --- a/tests/fields/test_complex_datetime_field.py +++ b/tests/fields/test_complex_datetime_field.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import datetime import itertools import math diff --git a/tests/fields/test_date_field.py b/tests/fields/test_date_field.py index 42a4b7f1..a98f222a 100644 --- a/tests/fields/test_date_field.py +++ b/tests/fields/test_date_field.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import datetime import pytest diff --git a/tests/fields/test_datetime_field.py b/tests/fields/test_datetime_field.py index 48936af7..0858548c 100644 --- a/tests/fields/test_datetime_field.py +++ b/tests/fields/test_datetime_field.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import datetime as dt import pytest diff --git a/tests/fields/test_decimal_field.py b/tests/fields/test_decimal_field.py index c531166f..a7cd09a2 100644 --- a/tests/fields/test_decimal_field.py +++ b/tests/fields/test_decimal_field.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- from decimal import Decimal import pytest diff --git a/tests/fields/test_dict_field.py b/tests/fields/test_dict_field.py index 12140916..05482107 100644 --- a/tests/fields/test_dict_field.py +++ b/tests/fields/test_dict_field.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- from bson import InvalidDocument import pytest diff --git a/tests/fields/test_email_field.py b/tests/fields/test_email_field.py index 5a58ede4..893180a4 100644 --- a/tests/fields/test_email_field.py +++ b/tests/fields/test_email_field.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import sys import pytest diff --git a/tests/fields/test_embedded_document_field.py b/tests/fields/test_embedded_document_field.py index 13ca9c0b..e116dc0d 100644 --- a/tests/fields/test_embedded_document_field.py +++ b/tests/fields/test_embedded_document_field.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import pytest from mongoengine import ( diff --git a/tests/fields/test_fields.py b/tests/fields/test_fields.py index fe349d1e..aa530ced 100644 --- a/tests/fields/test_fields.py +++ b/tests/fields/test_fields.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import datetime import unittest diff --git a/tests/fields/test_file_field.py b/tests/fields/test_file_field.py index cbac9b69..de10c987 100644 --- a/tests/fields/test_file_field.py +++ b/tests/fields/test_file_field.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import copy import os import tempfile diff --git a/tests/fields/test_float_field.py b/tests/fields/test_float_field.py index 839494a9..817dcfeb 100644 --- a/tests/fields/test_float_field.py +++ b/tests/fields/test_float_field.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import pytest from mongoengine import * diff --git a/tests/fields/test_geo_fields.py b/tests/fields/test_geo_fields.py index 7618b3a0..7960178e 100644 --- a/tests/fields/test_geo_fields.py +++ b/tests/fields/test_geo_fields.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import unittest from mongoengine import * diff --git a/tests/fields/test_int_field.py b/tests/fields/test_int_field.py index 1f9c5a77..529ae4db 100644 --- a/tests/fields/test_int_field.py +++ b/tests/fields/test_int_field.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import pytest from mongoengine import * diff --git a/tests/fields/test_lazy_reference_field.py b/tests/fields/test_lazy_reference_field.py index 50e60262..07d4d3a9 100644 --- a/tests/fields/test_lazy_reference_field.py +++ b/tests/fields/test_lazy_reference_field.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- from bson import DBRef, ObjectId import pytest diff --git a/tests/fields/test_map_field.py b/tests/fields/test_map_field.py index 8b8b1c46..ea60e34d 100644 --- a/tests/fields/test_map_field.py +++ b/tests/fields/test_map_field.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import datetime import pytest diff --git a/tests/fields/test_reference_field.py b/tests/fields/test_reference_field.py index 949eac67..24401ce0 100644 --- a/tests/fields/test_reference_field.py +++ b/tests/fields/test_reference_field.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- from bson import DBRef, SON import pytest diff --git a/tests/fields/test_sequence_field.py b/tests/fields/test_sequence_field.py index 81d648fd..b6b2917f 100644 --- a/tests/fields/test_sequence_field.py +++ b/tests/fields/test_sequence_field.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - from mongoengine import * from tests.utils import MongoDBTestCase diff --git a/tests/fields/test_url_field.py b/tests/fields/test_url_field.py index c449e467..50845b90 100644 --- a/tests/fields/test_url_field.py +++ b/tests/fields/test_url_field.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import pytest from mongoengine import * diff --git a/tests/fields/test_uuid_field.py b/tests/fields/test_uuid_field.py index 21b7a090..0f3d2d84 100644 --- a/tests/fields/test_uuid_field.py +++ b/tests/fields/test_uuid_field.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import uuid import pytest diff --git a/tests/queryset/test_queryset.py b/tests/queryset/test_queryset.py index 73c419b3..a2440302 100644 --- a/tests/queryset/test_queryset.py +++ b/tests/queryset/test_queryset.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - import datetime import unittest import uuid diff --git a/tests/queryset/test_queryset_aggregation.py b/tests/queryset/test_queryset_aggregation.py index 00e04a36..146501d0 100644 --- a/tests/queryset/test_queryset_aggregation.py +++ b/tests/queryset/test_queryset_aggregation.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - import unittest import warnings diff --git a/tests/test_dereference.py b/tests/test_dereference.py index 8ba429f4..c40cc0bd 100644 --- a/tests/test_dereference.py +++ b/tests/test_dereference.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import unittest from bson import DBRef, ObjectId diff --git a/tests/test_signals.py b/tests/test_signals.py index 64976e25..19707165 100644 --- a/tests/test_signals.py +++ b/tests/test_signals.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import unittest from mongoengine import * From 9e40f3ae8328e2f755a8c681c5ff892d4e462d90 Mon Sep 17 00:00:00 2001 From: Mateusz Stankiewicz Date: Sat, 31 Oct 2020 10:47:20 +0100 Subject: [PATCH 4/6] PR ammends --- mongoengine/fields.py | 38 ++++++--- tests/fields/test_enum_field.py | 146 ++++++++++++++++---------------- 2 files changed, 99 insertions(+), 85 deletions(-) diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 1ce66055..69277d06 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -848,8 +848,7 @@ class DynamicField(BaseField): Used by :class:`~mongoengine.DynamicDocument` to handle dynamic data""" def to_mongo(self, value, use_db_field=True, fields=None): - """Convert a Python type to a MongoDB compatible type. - """ + """Convert a Python type to a MongoDB compatible type.""" if isinstance(value, str): return value @@ -1624,7 +1623,7 @@ class BinaryField(BaseField): class EnumField(BaseField): - """ Enumeration Field. Values are stored underneath as strings. + """Enumeration Field. Values are stored underneath as strings. Example usage: .. code-block:: python @@ -1643,30 +1642,41 @@ class EnumField(BaseField): ModelWithEnum.objects(status='new').count() ModelWithEnum.objects(status=Status.NEW).count() + + Note that choices cannot be set explicitly, they are derived + from the provided enum class. """ + def __init__(self, enum, **kwargs): self._enum_cls = enum + if "choices" in kwargs: + raise ValueError( + "'choices' can't be set on EnumField, " + "it is implicitly set as the enum class" + ) kwargs["choices"] = list(self._enum_cls) super().__init__(**kwargs) def __set__(self, instance, value): - if value is None or isinstance(value, self._enum_cls): - value = value # if it is proper enum or none then fine - else: # if it not, then try to create enum of it - value = self._enum_cls(value) + is_legal_value = value is None or isinstance(value, self._enum_cls) + if not is_legal_value: + try: + value = self._enum_cls(value) + except Exception: + pass return super().__set__(instance, value) def to_mongo(self, value): if isinstance(value, self._enum_cls): - return str(value.value) - return str(value) + return value.value + return value def validate(self, value): - if not isinstance(value, self._enum_cls): - self.error( - "EnumField only accepts instances of " - "(%s)" % self._enum_cls - ) + if value and not isinstance(value, self._enum_cls): + try: + self._enum_cls(value) + except Exception as e: + self.error(str(e)) def prepare_query_value(self, op, value): if value is None: diff --git a/tests/fields/test_enum_field.py b/tests/fields/test_enum_field.py index 00404aa5..1f89b9bf 100644 --- a/tests/fields/test_enum_field.py +++ b/tests/fields/test_enum_field.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- from enum import Enum import pytest @@ -8,14 +7,72 @@ from tests.utils import MongoDBTestCase, get_as_pymongo class Status(Enum): - NEW = 'new' - DONE = 'done' - - + NEW = "new" + DONE = "done" + + class ModelWithEnum(Document): status = EnumField(Status) +class TestStringEnumField(MongoDBTestCase): + def test_storage(self): + model = ModelWithEnum(status=Status.NEW).save() + assert get_as_pymongo(model) == {"_id": model.id, "status": "new"} + + def test_set_enum(self): + ModelWithEnum.drop_collection() + ModelWithEnum(status=Status.NEW).save() + assert ModelWithEnum.objects(status=Status.NEW).count() == 1 + assert ModelWithEnum.objects.first().status == Status.NEW + + def test_set_by_value(self): + ModelWithEnum.drop_collection() + ModelWithEnum(status="new").save() + assert ModelWithEnum.objects.first().status == Status.NEW + + def test_filter(self): + ModelWithEnum.drop_collection() + ModelWithEnum(status="new").save() + assert ModelWithEnum.objects(status="new").count() == 1 + assert ModelWithEnum.objects(status=Status.NEW).count() == 1 + assert ModelWithEnum.objects(status=Status.DONE).count() == 0 + + def test_change_value(self): + m = ModelWithEnum(status="new") + m.status = Status.DONE + m.save() + assert m.status == Status.DONE + + def test_set_default(self): + class ModelWithDefault(Document): + status = EnumField(Status, default=Status.DONE) + + m = ModelWithDefault().save() + assert m.status == Status.DONE + + def test_enum_field_can_be_empty(self): + ModelWithEnum.drop_collection() + m = ModelWithEnum().save() + assert m.status is None + assert ModelWithEnum.objects()[0].status is None + assert ModelWithEnum.objects(status=None).count() == 1 + + def test_set_none_explicitly(self): + ModelWithEnum.drop_collection() + ModelWithEnum(status=None).save() + assert ModelWithEnum.objects.first().status is None + + def test_cannot_create_model_with_wrong_enum_value(self): + m = ModelWithEnum(status="wrong_one") + with pytest.raises(ValidationError): + m.validate() + + def test_user_is_informed_when_tries_to_set_choices(self): + with pytest.raises(ValueError, match="'choices' can't be set on EnumField"): + EnumField(Status, choices=["my", "custom", "options"]) + + class Color(Enum): RED = 1 BLUE = 2 @@ -25,79 +82,26 @@ class ModelWithColor(Document): color = EnumField(Color, default=Color.RED) -class TestEnumField(MongoDBTestCase): - def test_storage(self): - model = ModelWithEnum(status=Status.NEW).save() - assert get_as_pymongo(model) == {"_id": model.id, "status": 'new'} - - def test_set_enum(self): - ModelWithEnum.drop_collection() - m = ModelWithEnum(status=Status.NEW).save() - assert ModelWithEnum.objects(status=Status.NEW).count() == 1 - assert ModelWithEnum.objects.first().status == Status.NEW - m.validate() - - def test_set_by_value(self): - ModelWithEnum.drop_collection() - ModelWithEnum(status='new').save() - assert ModelWithEnum.objects.first().status == Status.NEW - - def test_filter(self): - ModelWithEnum.drop_collection() - ModelWithEnum(status='new').save() - assert ModelWithEnum.objects(status='new').count() == 1 - assert ModelWithEnum.objects(status=Status.NEW).count() == 1 - assert ModelWithEnum.objects(status=Status.DONE).count() == 0 - - def test_change_value(self): - m = ModelWithEnum(status='new') - m.status = Status.DONE - m.validate() - assert m.status == Status.DONE - - def test_set_default(self): - class ModelWithDefault(Document): - status = EnumField(Status, default=Status.DONE) - - m = ModelWithDefault() - m.validate() - m.save() - assert m.status == Status.DONE - +class TestIntEnumField(MongoDBTestCase): def test_enum_with_int(self): - m = ModelWithColor() - m.validate() - m.save() + ModelWithColor.drop_collection() + m = ModelWithColor().save() assert m.color == Color.RED assert ModelWithColor.objects(color=Color.RED).count() == 1 assert ModelWithColor.objects(color=1).count() == 1 assert ModelWithColor.objects(color=2).count() == 0 + def test_create_int_enum_by_value(self): + model = ModelWithColor(color=2).save() + assert model.color == Color.BLUE + def test_storage_enum_with_int(self): model = ModelWithColor(color=Color.BLUE).save() - assert get_as_pymongo(model) == {"_id": model.id, "color": "2"} + assert get_as_pymongo(model) == {"_id": model.id, "color": 2} - def test_enum_field_can_be_empty(self): - m = ModelWithEnum() - m.validate() - m.save() - assert m.status is None - assert ModelWithEnum.objects()[0].status is None - assert ModelWithEnum.objects(status=None).count() == 1 + def test_validate_model(self): + with pytest.raises(ValidationError, match="Value must be one of"): + ModelWithColor(color=3).validate() - def test_cannot_create_model_with_wrong_enum_value(self): - with pytest.raises(ValueError): - ModelWithEnum(status='wrong_one') - - def test_cannot_create_model_with_wrong_enum_type(self): - with pytest.raises(ValueError): - ModelWithColor(color='wrong_type') - - def test_cannot_create_model_with_wrong_enum_value_int(self): - with pytest.raises(ValueError): - ModelWithColor(color=3) - - def test_cannot_set_wrong_enum_value(self): - m = ModelWithEnum(status='new') - with pytest.raises(ValueError): - m.status = 'wrong' + with pytest.raises(ValidationError, match="Value must be one of"): + ModelWithColor(color="wrong_type").validate() From 8ef721342683bf8c4622b22fa31f231f9dae9706 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Sun, 1 Nov 2020 14:05:58 +0100 Subject: [PATCH 5/6] improve EnumField Doc and add quick test --- mongoengine/fields.py | 4 +++- tests/fields/test_enum_field.py | 15 +++++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 69277d06..8915d801 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -1623,7 +1623,9 @@ class BinaryField(BaseField): class EnumField(BaseField): - """Enumeration Field. Values are stored underneath as strings. + """Enumeration Field. Values are stored underneath as is, + so it will only work with simple types (str, int, etc) that + are bson encodable Example usage: .. code-block:: python diff --git a/tests/fields/test_enum_field.py b/tests/fields/test_enum_field.py index 1f89b9bf..384e9afd 100644 --- a/tests/fields/test_enum_field.py +++ b/tests/fields/test_enum_field.py @@ -1,5 +1,6 @@ from enum import Enum +from bson import InvalidDocument import pytest from mongoengine import * @@ -105,3 +106,17 @@ class TestIntEnumField(MongoDBTestCase): with pytest.raises(ValidationError, match="Value must be one of"): ModelWithColor(color="wrong_type").validate() + + +class TestFunkyEnumField(MongoDBTestCase): + def test_enum_incompatible_bson_type_fails_during_save(self): + class FunkyColor(Enum): + YELLOW = object() + + class ModelWithFunkyColor(Document): + color = EnumField(FunkyColor) + + m = ModelWithFunkyColor(color=FunkyColor.YELLOW) + + with pytest.raises(InvalidDocument, match="cannot encode object"): + m.save() From 94a7e813b1de820531603113e38fa95746010f63 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Sun, 1 Nov 2020 19:37:13 +0100 Subject: [PATCH 6/6] fix difference in test for certain version of pymongo --- tests/fields/test_enum_field.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/fields/test_enum_field.py b/tests/fields/test_enum_field.py index 384e9afd..fc42487b 100644 --- a/tests/fields/test_enum_field.py +++ b/tests/fields/test_enum_field.py @@ -118,5 +118,5 @@ class TestFunkyEnumField(MongoDBTestCase): m = ModelWithFunkyColor(color=FunkyColor.YELLOW) - with pytest.raises(InvalidDocument, match="cannot encode object"): + with pytest.raises(InvalidDocument, match="[cC]annot encode object"): m.save()