import datetime import unittest import pytest from bson import SON, DBRef, ObjectId from mongoengine import ( BooleanField, ComplexDateTimeField, DateField, DateTimeField, DictField, Document, DoesNotExist, DynamicDocument, DynamicField, EmbeddedDocument, EmbeddedDocumentField, EmbeddedDocumentListField, FieldDoesNotExist, FloatField, GenericLazyReferenceField, GenericReferenceField, IntField, LazyReferenceField, ListField, MultipleObjectsReturned, NotRegistered, NotUniqueError, ObjectIdField, OperationError, ReferenceField, SortedListField, StringField, ValidationError, ) from mongoengine.base import ( BaseField, EmbeddedDocumentList, _document_registry, ) from mongoengine.errors import DeprecatedError from tests.utils import MongoDBTestCase class TestField(MongoDBTestCase): def test_default_values_nothing_set(self): """Ensure that default field values are used when creating a document. """ class Person(Document): name = StringField() age = IntField(default=30, required=False) userid = StringField(default=lambda: "test", required=True) created = DateTimeField(default=datetime.datetime.utcnow) day = DateField(default=datetime.date.today) person = Person(name="Ross") # Confirm saving now would store values data_to_be_saved = sorted(person.to_mongo().keys()) assert data_to_be_saved == ["age", "created", "day", "name", "userid"] assert person.validate() is None assert person.name == person.name assert person.age == person.age assert person.userid == person.userid assert person.created == person.created assert person.day == person.day assert person._data["name"] == person.name assert person._data["age"] == person.age assert person._data["userid"] == person.userid assert person._data["created"] == person.created assert person._data["day"] == person.day # Confirm introspection changes nothing data_to_be_saved = sorted(person.to_mongo().keys()) assert data_to_be_saved == ["age", "created", "day", "name", "userid"] def test_custom_field_validation_raise_deprecated_error_when_validation_return_something( self, ): # Covers introduction of a breaking change in the validation parameter (0.18) def _not_empty(z): return bool(z) class Person(Document): name = StringField(validation=_not_empty) Person.drop_collection() error = ( "validation argument for `name` must not return anything, " "it should raise a ValidationError if validation fails" ) with pytest.raises(DeprecatedError) as exc_info: Person(name="").validate() assert str(exc_info.value) == error with pytest.raises(DeprecatedError) as exc_info: Person(name="").save() assert str(exc_info.value) == error def test_custom_field_validation_raise_validation_error(self): def _not_empty(z): if not z: raise ValidationError("cantbeempty") class Person(Document): name = StringField(validation=_not_empty) Person.drop_collection() with pytest.raises(ValidationError) as exc_info: Person(name="").validate() assert "ValidationError (Person:None) (cantbeempty: ['name'])" == str( exc_info.value ) Person(name="garbage").validate() Person(name="garbage").save() def test_default_values_set_to_None(self): """Ensure that default field values are used even when we explcitly initialize the doc with None values. """ class Person(Document): name = StringField() age = IntField(default=30, required=False) userid = StringField(default=lambda: "test", required=True) created = DateTimeField(default=datetime.datetime.utcnow) # Trying setting values to None person = Person(name=None, age=None, userid=None, created=None) # Confirm saving now would store values data_to_be_saved = sorted(person.to_mongo().keys()) assert data_to_be_saved == ["age", "created", "userid"] assert person.validate() is None assert person.name == person.name assert person.age == person.age assert person.userid == person.userid assert person.created == person.created assert person._data["name"] == person.name assert person._data["age"] == person.age assert person._data["userid"] == person.userid assert person._data["created"] == person.created # Confirm introspection changes nothing data_to_be_saved = sorted(person.to_mongo().keys()) assert data_to_be_saved == ["age", "created", "userid"] def test_default_values_when_setting_to_None(self): """Ensure that default field values are used when creating a document. """ class Person(Document): name = StringField() age = IntField(default=30, required=False) userid = StringField(default=lambda: "test", required=True) created = DateTimeField(default=datetime.datetime.utcnow) person = Person() person.name = None person.age = None person.userid = None person.created = None # Confirm saving now would store values data_to_be_saved = sorted(person.to_mongo().keys()) assert data_to_be_saved == ["age", "created", "userid"] assert person.validate() is None assert person.name is None assert person.age == 30 assert person.userid == "test" assert isinstance(person.created, datetime.datetime) assert person._data["name"] == person.name assert person._data["age"] == person.age assert person._data["userid"] == person.userid assert person._data["created"] == person.created # Confirm introspection changes nothing data_to_be_saved = sorted(person.to_mongo().keys()) assert data_to_be_saved == ["age", "created", "userid"] def test_default_value_is_not_used_when_changing_value_to_empty_list_for_strict_doc( self, ): """List field with default can be set to the empty list (strict)""" # Issue #1733 class Doc(Document): x = ListField(IntField(), default=lambda: [42]) doc = Doc(x=[1]).save() doc.x = [] doc.save() reloaded = Doc.objects.get(id=doc.id) assert reloaded.x == [] def test_default_value_is_not_used_when_changing_value_to_empty_list_for_dyn_doc( self, ): """List field with default can be set to the empty list (dynamic)""" # Issue #1733 class Doc(DynamicDocument): x = ListField(IntField(), default=lambda: [42]) doc = Doc(x=[1]).save() doc.x = [] doc.y = 2 # Was triggering the bug doc.save() reloaded = Doc.objects.get(id=doc.id) assert reloaded.x == [] def test_default_values_when_deleting_value(self): """Ensure that default field values are used after non-default values are explicitly deleted. """ class Person(Document): name = StringField() age = IntField(default=30, required=False) userid = StringField(default=lambda: "test", required=True) created = DateTimeField(default=datetime.datetime.utcnow) person = Person( name="Ross", age=50, userid="different", created=datetime.datetime(2014, 6, 12), ) del person.name del person.age del person.userid del person.created data_to_be_saved = sorted(person.to_mongo().keys()) assert data_to_be_saved == ["age", "created", "userid"] assert person.validate() is None assert person.name is None assert person.age == 30 assert person.userid == "test" assert isinstance(person.created, datetime.datetime) assert person.created != datetime.datetime(2014, 6, 12) assert person._data["name"] == person.name assert person._data["age"] == person.age assert person._data["userid"] == person.userid assert person._data["created"] == person.created # Confirm introspection changes nothing data_to_be_saved = sorted(person.to_mongo().keys()) assert data_to_be_saved == ["age", "created", "userid"] def test_required_values(self): """Ensure that required field constraints are enforced.""" class Person(Document): name = StringField(required=True) age = IntField(required=True) userid = StringField() person = Person(name="Test User") with pytest.raises(ValidationError): person.validate() person = Person(age=30) with pytest.raises(ValidationError): person.validate() def test_not_required_handles_none_in_update(self): """Ensure that every fields should accept None if required is False. """ class HandleNoneFields(Document): str_fld = StringField() int_fld = IntField() flt_fld = FloatField() comp_dt_fld = ComplexDateTimeField() HandleNoneFields.drop_collection() doc = HandleNoneFields() doc.str_fld = "spam ham egg" doc.int_fld = 42 doc.flt_fld = 4.2 doc.com_dt_fld = datetime.datetime.utcnow() doc.save() res = HandleNoneFields.objects(id=doc.id).update( set__str_fld=None, set__int_fld=None, set__flt_fld=None, set__comp_dt_fld=None, ) assert res == 1 # Retrieve data from db and verify it. ret = HandleNoneFields.objects.all()[0] assert ret.str_fld is None assert ret.int_fld is None assert ret.flt_fld is None assert ret.comp_dt_fld is None def test_not_required_handles_none_from_database(self): """Ensure that every field can handle null values from the database. """ class HandleNoneFields(Document): str_fld = StringField(required=True) int_fld = IntField(required=True) flt_fld = FloatField(required=True) comp_dt_fld = ComplexDateTimeField(required=True) HandleNoneFields.drop_collection() doc = HandleNoneFields() doc.str_fld = "spam ham egg" doc.int_fld = 42 doc.flt_fld = 4.2 doc.comp_dt_fld = datetime.datetime.utcnow() doc.save() # Unset all the fields HandleNoneFields._get_collection().update_one( {"_id": doc.id}, {"$unset": {"str_fld": 1, "int_fld": 1, "flt_fld": 1, "comp_dt_fld": 1}}, ) # Retrieve data from db and verify it. ret = HandleNoneFields.objects.first() assert ret.str_fld is None assert ret.int_fld is None assert ret.flt_fld is None assert ret.comp_dt_fld is None # Retrieved object shouldn't pass validation when a re-save is # attempted. with pytest.raises(ValidationError): ret.validate() def test_default_id_validation_as_objectid(self): """Ensure that invalid values cannot be assigned to an ObjectIdField. """ class Person(Document): name = StringField() person = Person(name="Test User") assert person.id is None person.id = 47 with pytest.raises(ValidationError): person.validate() person.id = "abc" with pytest.raises(ValidationError): person.validate() person.id = str(ObjectId()) person.validate() def test_db_field_validation(self): """Ensure that db_field doesn't accept invalid values.""" # dot in the name with pytest.raises(ValueError): class User(Document): name = StringField(db_field="user.name") # name starting with $ with pytest.raises(ValueError): class UserX1(Document): name = StringField(db_field="$name") # name containing a null character with pytest.raises(ValueError): class UserX2(Document): name = StringField(db_field="name\0") def test_list_validation(self): """Ensure that a list field only accepts lists with valid elements.""" access_level_choices = ( ("a", "Administration"), ("b", "Manager"), ("c", "Staff"), ) class User(Document): pass class Comment(EmbeddedDocument): content = StringField() class BlogPost(Document): content = StringField() comments = ListField(EmbeddedDocumentField(Comment)) tags = ListField(StringField()) authors = ListField(ReferenceField(User)) authors_as_lazy = ListField(LazyReferenceField(User)) generic = ListField(GenericReferenceField()) generic_as_lazy = ListField(GenericLazyReferenceField()) access_list = ListField(choices=access_level_choices, display_sep=", ") User.drop_collection() BlogPost.drop_collection() post = BlogPost(content="Went for a walk today...") post.validate() post.tags = "fun" with pytest.raises(ValidationError): post.validate() post.tags = [1, 2] with pytest.raises(ValidationError): post.validate() post.tags = ["fun", "leisure"] post.validate() post.tags = ("fun", "leisure") post.validate() post.access_list = "a,b" with pytest.raises(ValidationError): post.validate() post.access_list = ["c", "d"] with pytest.raises(ValidationError): post.validate() post.access_list = ["a", "b"] post.validate() assert post.get_access_list_display() == "Administration, Manager" post.comments = ["a"] with pytest.raises(ValidationError): post.validate() post.comments = "yay" with pytest.raises(ValidationError): post.validate() comments = [Comment(content="Good for you"), Comment(content="Yay.")] post.comments = comments post.validate() post.authors = [Comment()] with pytest.raises(ValidationError): post.validate() post.authors = [User()] with pytest.raises(ValidationError): post.validate() user = User() user.save() post.authors = [user] post.validate() post.authors_as_lazy = [Comment()] with pytest.raises(ValidationError): post.validate() post.authors_as_lazy = [User()] with pytest.raises(ValidationError): post.validate() post.authors_as_lazy = [user] post.validate() post.generic = [1, 2] with pytest.raises(ValidationError): post.validate() post.generic = [User(), Comment()] with pytest.raises(ValidationError): post.validate() post.generic = [Comment()] with pytest.raises(ValidationError): post.validate() post.generic = [user] post.validate() post.generic_as_lazy = [1, 2] with pytest.raises(ValidationError): post.validate() post.generic_as_lazy = [User(), Comment()] with pytest.raises(ValidationError): post.validate() post.generic_as_lazy = [Comment()] with pytest.raises(ValidationError): post.validate() post.generic_as_lazy = [user] post.validate() def test_sorted_list_sorting(self): """Ensure that a sorted list field properly sorts values.""" class Comment(EmbeddedDocument): order = IntField() content = StringField() class BlogPost(Document): content = StringField() comments = SortedListField(EmbeddedDocumentField(Comment), ordering="order") tags = SortedListField(StringField()) BlogPost.drop_collection() post = BlogPost(content="Went for a walk today...") post.save() post.tags = ["leisure", "fun"] post.save() post.reload() assert post.tags == ["fun", "leisure"] comment1 = Comment(content="Good for you", order=1) comment2 = Comment(content="Yay.", order=0) comments = [comment1, comment2] post.comments = comments post.save() post.reload() assert post.comments[0].content == comment2.content assert post.comments[1].content == comment1.content post.comments[0].order = 2 post.save() post.reload() assert post.comments[0].content == comment1.content assert post.comments[1].content == comment2.content def test_reverse_list_sorting(self): """Ensure that a reverse sorted list field properly sorts values""" class Category(EmbeddedDocument): count = IntField() name = StringField() class CategoryList(Document): categories = SortedListField( EmbeddedDocumentField(Category), ordering="count", reverse=True ) name = StringField() CategoryList.drop_collection() catlist = CategoryList(name="Top categories") cat1 = Category(name="posts", count=10) cat2 = Category(name="food", count=100) cat3 = Category(name="drink", count=40) catlist.categories = [cat1, cat2, cat3] catlist.save() catlist.reload() assert catlist.categories[0].name == cat2.name assert catlist.categories[1].name == cat3.name assert catlist.categories[2].name == cat1.name def test_list_field(self): """Ensure that list types work as expected.""" class BlogPost(Document): info = ListField() BlogPost.drop_collection() post = BlogPost() post.info = "my post" with pytest.raises(ValidationError): post.validate() post.info = {"title": "test"} with pytest.raises(ValidationError): post.validate() post.info = ["test"] post.save() post = BlogPost() post.info = [{"test": "test"}] post.save() post = BlogPost() post.info = [{"test": 3}] post.save() assert BlogPost.objects.count() == 3 assert BlogPost.objects.filter(info__exact="test").count() == 1 assert BlogPost.objects.filter(info__0__test="test").count() == 1 # Confirm handles non strings or non existing keys assert BlogPost.objects.filter(info__0__test__exact="5").count() == 0 assert BlogPost.objects.filter(info__100__test__exact="test").count() == 0 # test queries by list post = BlogPost() post.info = ["1", "2"] post.save() post = BlogPost.objects(info=["1", "2"]).get() post.info += ["3", "4"] post.save() assert BlogPost.objects(info=["1", "2", "3", "4"]).count() == 1 post = BlogPost.objects(info=["1", "2", "3", "4"]).get() post.info *= 2 post.save() assert ( BlogPost.objects(info=["1", "2", "3", "4", "1", "2", "3", "4"]).count() == 1 ) def test_list_field_manipulative_operators(self): """Ensure that ListField works with standard list operators that manipulate the list.""" class BlogPost(Document): ref = StringField() info = ListField(StringField()) BlogPost.drop_collection() post = BlogPost() post.ref = "1234" post.info = ["0", "1", "2", "3", "4", "5"] post.save() def reset_post(): post.info = ["0", "1", "2", "3", "4", "5"] post.save() # '__add__(listB)' # listA+listB # operator.add(listA, listB) reset_post() temp = ["a", "b"] post.info = post.info + temp assert post.info == ["0", "1", "2", "3", "4", "5", "a", "b"] post.save() post.reload() assert post.info == ["0", "1", "2", "3", "4", "5", "a", "b"] # '__delitem__(index)' # aka 'del list[index]' # aka 'operator.delitem(list, index)' reset_post() del post.info[2] # del from middle ('2') assert post.info == ["0", "1", "3", "4", "5"] post.save() post.reload() assert post.info == ["0", "1", "3", "4", "5"] # '__delitem__(slice(i, j))' # aka 'del list[i:j]' # aka 'operator.delitem(list, slice(i,j))' reset_post() del post.info[1:3] # removes '1', '2' assert post.info == ["0", "3", "4", "5"] post.save() post.reload() assert post.info == ["0", "3", "4", "5"] # '__iadd__' # aka 'list += list' reset_post() temp = ["a", "b"] post.info += temp assert post.info == ["0", "1", "2", "3", "4", "5", "a", "b"] post.save() post.reload() assert post.info == ["0", "1", "2", "3", "4", "5", "a", "b"] # '__imul__' # aka 'list *= number' reset_post() post.info *= 2 assert post.info == ["0", "1", "2", "3", "4", "5", "0", "1", "2", "3", "4", "5"] post.save() post.reload() assert post.info == ["0", "1", "2", "3", "4", "5", "0", "1", "2", "3", "4", "5"] # '__mul__' # aka 'listA*listB' reset_post() post.info = post.info * 2 assert post.info == ["0", "1", "2", "3", "4", "5", "0", "1", "2", "3", "4", "5"] post.save() post.reload() assert post.info == ["0", "1", "2", "3", "4", "5", "0", "1", "2", "3", "4", "5"] # '__rmul__' # aka 'listB*listA' reset_post() post.info = 2 * post.info assert post.info == ["0", "1", "2", "3", "4", "5", "0", "1", "2", "3", "4", "5"] post.save() post.reload() assert post.info == ["0", "1", "2", "3", "4", "5", "0", "1", "2", "3", "4", "5"] # '__setitem__(index, value)' # aka 'list[index]=value' # aka 'setitem(list, value)' reset_post() post.info[4] = "a" assert post.info == ["0", "1", "2", "3", "a", "5"] post.save() post.reload() assert post.info == ["0", "1", "2", "3", "a", "5"] # __setitem__(index, value) with a negative index reset_post() post.info[-2] = "a" assert post.info == ["0", "1", "2", "3", "a", "5"] post.save() post.reload() assert post.info == ["0", "1", "2", "3", "a", "5"] # '__setitem__(slice(i, j), listB)' # aka 'listA[i:j] = listB' # aka 'setitem(listA, slice(i, j), listB)' reset_post() post.info[1:3] = ["h", "e", "l", "l", "o"] assert post.info == ["0", "h", "e", "l", "l", "o", "3", "4", "5"] post.save() post.reload() assert post.info == ["0", "h", "e", "l", "l", "o", "3", "4", "5"] # '__setitem__(slice(i, j), listB)' with negative i and j reset_post() post.info[-5:-3] = ["h", "e", "l", "l", "o"] assert post.info == ["0", "h", "e", "l", "l", "o", "3", "4", "5"] post.save() post.reload() assert post.info == ["0", "h", "e", "l", "l", "o", "3", "4", "5"] # negative # 'append' reset_post() post.info.append("h") assert post.info == ["0", "1", "2", "3", "4", "5", "h"] post.save() post.reload() assert post.info == ["0", "1", "2", "3", "4", "5", "h"] # 'extend' reset_post() post.info.extend(["h", "e", "l", "l", "o"]) assert post.info == ["0", "1", "2", "3", "4", "5", "h", "e", "l", "l", "o"] post.save() post.reload() assert post.info == ["0", "1", "2", "3", "4", "5", "h", "e", "l", "l", "o"] # 'insert' # 'pop' reset_post() x = post.info.pop(2) y = post.info.pop() assert post.info == ["0", "1", "3", "4"] assert x == "2" assert y == "5" post.save() post.reload() assert post.info == ["0", "1", "3", "4"] # 'remove' reset_post() post.info.remove("2") assert post.info == ["0", "1", "3", "4", "5"] post.save() post.reload() assert post.info == ["0", "1", "3", "4", "5"] # 'reverse' reset_post() post.info.reverse() assert post.info == ["5", "4", "3", "2", "1", "0"] post.save() post.reload() assert post.info == ["5", "4", "3", "2", "1", "0"] # 'sort': though this operator method does manipulate the list, it is # tested in the 'test_list_field_lexicograpic_operators' function def test_list_field_invalid_operators(self): class BlogPost(Document): ref = StringField() info = ListField(StringField()) post = BlogPost() post.ref = "1234" post.info = ["0", "1", "2", "3", "4", "5"] # '__hash__' # aka 'hash(list)' with pytest.raises(TypeError): hash(post.info) def test_list_field_lexicographic_operators(self): """Ensure that ListField works with standard list operators that do lexigraphic ordering. """ class BlogPost(Document): ref = StringField() text_info = ListField(StringField()) oid_info = ListField(ObjectIdField()) bool_info = ListField(BooleanField()) BlogPost.drop_collection() blogSmall = BlogPost(ref="small") blogSmall.text_info = ["a", "a", "a"] blogSmall.bool_info = [False, False] blogSmall.save() blogSmall.reload() blogLargeA = BlogPost(ref="big") blogLargeA.text_info = ["a", "z", "j"] blogLargeA.bool_info = [False, True] blogLargeA.save() blogLargeA.reload() blogLargeB = BlogPost(ref="big2") blogLargeB.text_info = ["a", "z", "j"] blogLargeB.oid_info = [ "54495ad94c934721ede76f90", "54495ad94c934721ede76d23", "54495ad94c934721ede76d00", ] blogLargeB.bool_info = [False, True] blogLargeB.save() blogLargeB.reload() # '__eq__' aka '==' assert blogLargeA.text_info == blogLargeB.text_info assert blogLargeA.bool_info == blogLargeB.bool_info # '__ge__' aka '>=' assert blogLargeA.text_info >= blogSmall.text_info assert blogLargeA.text_info >= blogLargeB.text_info assert blogLargeA.bool_info >= blogSmall.bool_info assert blogLargeA.bool_info >= blogLargeB.bool_info # '__gt__' aka '>' assert blogLargeA.text_info >= blogSmall.text_info assert blogLargeA.bool_info >= blogSmall.bool_info # '__le__' aka '<=' assert blogSmall.text_info <= blogLargeB.text_info assert blogLargeA.text_info <= blogLargeB.text_info assert blogSmall.bool_info <= blogLargeB.bool_info assert blogLargeA.bool_info <= blogLargeB.bool_info # '__lt__' aka '<' assert blogSmall.text_info < blogLargeB.text_info assert blogSmall.bool_info < blogLargeB.bool_info # '__ne__' aka '!=' assert blogSmall.text_info != blogLargeB.text_info assert blogSmall.bool_info != blogLargeB.bool_info # 'sort' blogLargeB.bool_info = [True, False, True, False] blogLargeB.text_info.sort() blogLargeB.oid_info.sort() blogLargeB.bool_info.sort() sorted_target_list = [ ObjectId("54495ad94c934721ede76d00"), ObjectId("54495ad94c934721ede76d23"), ObjectId("54495ad94c934721ede76f90"), ] assert blogLargeB.text_info == ["a", "j", "z"] assert blogLargeB.oid_info == sorted_target_list assert blogLargeB.bool_info == [False, False, True, True] blogLargeB.save() blogLargeB.reload() assert blogLargeB.text_info == ["a", "j", "z"] assert blogLargeB.oid_info == sorted_target_list assert blogLargeB.bool_info == [False, False, True, True] def test_list_assignment(self): """Ensure that list field element assignment and slicing work.""" class BlogPost(Document): info = ListField() BlogPost.drop_collection() post = BlogPost() post.info = ["e1", "e2", 3, "4", 5] post.save() post.info[0] = 1 post.save() post.reload() assert post.info[0] == 1 post.info[1:3] = ["n2", "n3"] post.save() post.reload() assert post.info == [1, "n2", "n3", "4", 5] post.info[-1] = "n5" post.save() post.reload() assert post.info == [1, "n2", "n3", "4", "n5"] post.info[-2] = 4 post.save() post.reload() assert post.info == [1, "n2", "n3", 4, "n5"] post.info[1:-1] = [2] post.save() post.reload() assert post.info == [1, 2, "n5"] post.info[:-1] = [1, "n2", "n3", 4] post.save() post.reload() assert post.info == [1, "n2", "n3", 4, "n5"] post.info[-4:3] = [2, 3] post.save() post.reload() assert post.info == [1, 2, 3, 4, "n5"] def test_list_field_passed_in_value(self): class Foo(Document): bars = ListField(ReferenceField("Bar")) class Bar(Document): text = StringField() bar = Bar(text="hi") bar.save() foo = Foo(bars=[]) foo.bars.append(bar) assert repr(foo.bars) == "[]" def test_list_field_strict(self): """Ensure that list field handles validation if provided a strict field type. """ class Simple(Document): mapping = ListField(field=IntField()) Simple.drop_collection() e = Simple() e.mapping = [1] e.save() # try creating an invalid mapping with pytest.raises(ValidationError): e.mapping = ["abc"] e.save() def test_list_field_max_length(self): """Ensure ListField's max_length is respected.""" class Foo(Document): items = ListField(IntField(), max_length=5) foo = Foo() for i in range(1, 7): foo.items.append(i) if i < 6: foo.save() else: with pytest.raises(ValidationError) as exc_info: foo.save() assert "List is too long" in str(exc_info.value) def test_list_field_max_length_set_operator(self): """Ensure ListField's max_length is respected for a "set" operator.""" class Foo(Document): items = ListField(IntField(), max_length=3) foo = Foo.objects.create(items=[1, 2, 3]) with pytest.raises(ValidationError) as exc_info: foo.modify(set__items=[1, 2, 3, 4]) assert "List is too long" in str(exc_info.value) def test_list_field_rejects_strings(self): """Strings aren't valid list field data types.""" class Simple(Document): mapping = ListField() Simple.drop_collection() e = Simple() e.mapping = "hello world" with pytest.raises(ValidationError): e.save() def test_complex_field_required(self): """Ensure required cant be None / Empty.""" class Simple(Document): mapping = ListField(required=True) Simple.drop_collection() e = Simple() e.mapping = [] with pytest.raises(ValidationError): e.save() class Simple(Document): mapping = DictField(required=True) Simple.drop_collection() e = Simple() e.mapping = {} with pytest.raises(ValidationError): e.save() def test_complex_field_same_value_not_changed(self): """If a complex field is set to the same value, it should not be marked as changed. """ class Simple(Document): mapping = ListField() Simple.drop_collection() e = Simple().save() e.mapping = [] assert e._changed_fields == [] class Simple(Document): mapping = DictField() Simple.drop_collection() e = Simple().save() e.mapping = {} assert e._changed_fields == [] def test_slice_marks_field_as_changed(self): class Simple(Document): widgets = ListField() simple = Simple(widgets=[1, 2, 3, 4]).save() simple.widgets[:3] = [] assert ["widgets"] == simple._changed_fields simple.save() simple = simple.reload() assert simple.widgets == [4] def test_del_slice_marks_field_as_changed(self): class Simple(Document): widgets = ListField() simple = Simple(widgets=[1, 2, 3, 4]).save() del simple.widgets[:3] assert ["widgets"] == simple._changed_fields simple.save() simple = simple.reload() assert simple.widgets == [4] def test_list_field_with_negative_indices(self): class Simple(Document): widgets = ListField() simple = Simple(widgets=[1, 2, 3, 4]).save() simple.widgets[-1] = 5 assert ["widgets.3"] == simple._changed_fields simple.save() simple = simple.reload() assert simple.widgets == [1, 2, 3, 5] def test_list_field_complex(self): """Ensure that the list fields can handle the complex types.""" class SettingBase(EmbeddedDocument): meta = {"allow_inheritance": True} class StringSetting(SettingBase): value = StringField() class IntegerSetting(SettingBase): value = IntField() class Simple(Document): mapping = ListField() Simple.drop_collection() e = Simple() e.mapping.append(StringSetting(value="foo")) e.mapping.append(IntegerSetting(value=42)) e.mapping.append( { "number": 1, "string": "Hi!", "float": 1.001, "complex": IntegerSetting(value=42), "list": [IntegerSetting(value=42), StringSetting(value="foo")], } ) e.save() e2 = Simple.objects.get(id=e.id) assert isinstance(e2.mapping[0], StringSetting) assert isinstance(e2.mapping[1], IntegerSetting) # Test querying assert Simple.objects.filter(mapping__1__value=42).count() == 1 assert Simple.objects.filter(mapping__2__number=1).count() == 1 assert Simple.objects.filter(mapping__2__complex__value=42).count() == 1 assert Simple.objects.filter(mapping__2__list__0__value=42).count() == 1 assert Simple.objects.filter(mapping__2__list__1__value="foo").count() == 1 # Confirm can update Simple.objects().update(set__mapping__1=IntegerSetting(value=10)) assert Simple.objects.filter(mapping__1__value=10).count() == 1 Simple.objects().update(set__mapping__2__list__1=StringSetting(value="Boo")) assert Simple.objects.filter(mapping__2__list__1__value="foo").count() == 0 assert Simple.objects.filter(mapping__2__list__1__value="Boo").count() == 1 def test_embedded_db_field(self): class Embedded(EmbeddedDocument): number = IntField(default=0, db_field="i") class Test(Document): embedded = EmbeddedDocumentField(Embedded, db_field="x") Test.drop_collection() test = Test() test.embedded = Embedded(number=1) test.save() Test.objects.update_one(inc__embedded__number=1) test = Test.objects.get() assert test.embedded.number == 2 doc = self.db.test.find_one() assert doc["x"]["i"] == 2 def test_double_embedded_db_field(self): """Make sure multiple layers of embedded docs resolve db fields properly and can be initialized using dicts. """ class C(EmbeddedDocument): txt = StringField() class B(EmbeddedDocument): c = EmbeddedDocumentField(C, db_field="fc") class A(Document): b = EmbeddedDocumentField(B, db_field="fb") a = A(b=B(c=C(txt="hi"))) a.validate() a = A(b={"c": {"txt": "hi"}}) a.validate() def test_double_embedded_db_field_from_son(self): """Make sure multiple layers of embedded docs resolve db fields from SON properly. """ class C(EmbeddedDocument): txt = StringField() class B(EmbeddedDocument): c = EmbeddedDocumentField(C, db_field="fc") class A(Document): b = EmbeddedDocumentField(B, db_field="fb") a = A._from_son(SON([("fb", SON([("fc", SON([("txt", "hi")]))]))])) assert a.b.c.txt == "hi" @pytest.mark.xfail( reason="Using a string reference in an EmbeddedDocumentField does not work if the class isnt registerd yet", raises=NotRegistered, ) def test_embedded_document_field_cant_reference_using_a_str_if_it_does_not_exist_yet( self, ): class MyDoc2(Document): emb = EmbeddedDocumentField("MyFunkyDoc123") class MyFunkyDoc123(EmbeddedDocument): name = StringField() def test_embedded_document_validation(self): """Ensure that invalid embedded documents cannot be assigned to embedded document fields. """ class Comment(EmbeddedDocument): content = StringField() class PersonPreferences(EmbeddedDocument): food = StringField(required=True) number = IntField() class Person(Document): name = StringField() preferences = EmbeddedDocumentField(PersonPreferences) Person.drop_collection() person = Person(name="Test User") person.preferences = "My Preferences" with pytest.raises(ValidationError): person.validate() # Check that only the right embedded doc works person.preferences = Comment(content="Nice blog post...") with pytest.raises(ValidationError): person.validate() # Check that the embedded doc is valid person.preferences = PersonPreferences() with pytest.raises(ValidationError): person.validate() person.preferences = PersonPreferences(food="Cheese", number=47) assert person.preferences.food == "Cheese" person.validate() def test_embedded_document_inheritance(self): """Ensure that subclasses of embedded documents may be provided to EmbeddedDocumentFields of the superclass' type. """ class User(EmbeddedDocument): name = StringField() meta = {"allow_inheritance": True} class PowerUser(User): power = IntField() class BlogPost(Document): content = StringField() author = EmbeddedDocumentField(User) BlogPost.drop_collection() post = BlogPost(content="What I did today...") post.author = PowerUser(name="Test User", power=47) post.save() assert 47 == BlogPost.objects.first().author.power def test_embedded_document_inheritance_with_list(self): """Ensure that nested list of subclassed embedded documents is handled correctly. """ class Group(EmbeddedDocument): name = StringField() content = ListField(StringField()) class Basedoc(Document): groups = ListField(EmbeddedDocumentField(Group)) meta = {"abstract": True} class User(Basedoc): doctype = StringField(require=True, default="userdata") User.drop_collection() content = ["la", "le", "lu"] group = Group(name="foo", content=content) foobar = User(groups=[group]) foobar.save() assert content == User.objects.first().groups[0].content def test_reference_miss(self): """Ensure an exception is raised when dereferencing an unknown document. """ class Foo(Document): pass class Bar(Document): ref = ReferenceField(Foo) generic_ref = GenericReferenceField() Foo.drop_collection() Bar.drop_collection() foo = Foo().save() bar = Bar(ref=foo, generic_ref=foo).save() # Reference is no longer valid foo.delete() bar = Bar.objects.get() with pytest.raises(DoesNotExist): bar.ref with pytest.raises(DoesNotExist): bar.generic_ref # When auto_dereference is disabled, there is no trouble returning DBRef bar = Bar.objects.get() expected = foo.to_dbref() bar._fields["ref"]._auto_dereference = False assert bar.ref == expected bar._fields["generic_ref"]._auto_dereference = False assert bar.generic_ref == {"_ref": expected, "_cls": "Foo"} def test_list_item_dereference(self): """Ensure that DBRef items in ListFields are dereferenced.""" class User(Document): name = StringField() class Group(Document): members = ListField(ReferenceField(User)) User.drop_collection() Group.drop_collection() user1 = User(name="user1") user1.save() user2 = User(name="user2") user2.save() group = Group(members=[user1, user2]) group.save() group_obj = Group.objects.first() assert group_obj.members[0].name == user1.name assert group_obj.members[1].name == user2.name def test_recursive_reference(self): """Ensure that ReferenceFields can reference their own documents.""" class Employee(Document): name = StringField() boss = ReferenceField("self") friends = ListField(ReferenceField("self")) Employee.drop_collection() bill = Employee(name="Bill Lumbergh") bill.save() michael = Employee(name="Michael Bolton") michael.save() samir = Employee(name="Samir Nagheenanajar") samir.save() friends = [michael, samir] peter = Employee(name="Peter Gibbons", boss=bill, friends=friends) peter.save() peter = Employee.objects.with_id(peter.id) assert peter.boss == bill assert peter.friends == friends def test_recursive_embedding(self): """Ensure that EmbeddedDocumentFields can contain their own documents.""" class TreeNode(EmbeddedDocument): name = StringField() children = ListField(EmbeddedDocumentField("self")) class Tree(Document): name = StringField() children = ListField(EmbeddedDocumentField("TreeNode")) Tree.drop_collection() tree = Tree(name="Tree") first_child = TreeNode(name="Child 1") tree.children.append(first_child) second_child = TreeNode(name="Child 2") first_child.children.append(second_child) tree.save() tree = Tree.objects.first() assert len(tree.children) == 1 assert len(tree.children[0].children) == 1 third_child = TreeNode(name="Child 3") tree.children[0].children.append(third_child) tree.save() assert len(tree.children) == 1 assert tree.children[0].name == first_child.name assert tree.children[0].children[0].name == second_child.name assert tree.children[0].children[1].name == third_child.name # Test updating tree.children[0].name = "I am Child 1" tree.children[0].children[0].name = "I am Child 2" tree.children[0].children[1].name = "I am Child 3" tree.save() assert tree.children[0].name == "I am Child 1" assert tree.children[0].children[0].name == "I am Child 2" assert tree.children[0].children[1].name == "I am Child 3" # Test removal assert len(tree.children[0].children) == 2 del tree.children[0].children[1] tree.save() assert len(tree.children[0].children) == 1 tree.children[0].children.pop(0) tree.save() assert len(tree.children[0].children) == 0 assert tree.children[0].children == [] tree.children[0].children.insert(0, third_child) tree.children[0].children.insert(0, second_child) tree.save() assert len(tree.children[0].children) == 2 assert tree.children[0].children[0].name == second_child.name assert tree.children[0].children[1].name == third_child.name def test_drop_abstract_document(self): """Ensure that an abstract document cannot be dropped given it has no underlying collection. """ class AbstractDoc(Document): name = StringField() meta = {"abstract": True} with pytest.raises(OperationError): AbstractDoc.drop_collection() def test_reference_class_with_abstract_parent(self): """Ensure that a class with an abstract parent can be referenced.""" class Sibling(Document): name = StringField() meta = {"abstract": True} class Sister(Sibling): pass class Brother(Sibling): sibling = ReferenceField(Sibling) Sister.drop_collection() Brother.drop_collection() sister = Sister(name="Alice") sister.save() brother = Brother(name="Bob", sibling=sister) brother.save() assert Brother.objects[0].sibling.name == sister.name def test_reference_abstract_class(self): """Ensure that an abstract class instance cannot be used in the reference of that abstract class. """ class Sibling(Document): name = StringField() meta = {"abstract": True} class Sister(Sibling): pass class Brother(Sibling): sibling = ReferenceField(Sibling) Sister.drop_collection() Brother.drop_collection() sister = Sibling(name="Alice") brother = Brother(name="Bob", sibling=sister) with pytest.raises(ValidationError): brother.save() def test_abstract_reference_base_type(self): """Ensure that an an abstract reference fails validation when given a Document that does not inherit from the abstract type. """ class Sibling(Document): name = StringField() meta = {"abstract": True} class Brother(Sibling): sibling = ReferenceField(Sibling) class Mother(Document): name = StringField() Brother.drop_collection() Mother.drop_collection() mother = Mother(name="Carol") mother.save() brother = Brother(name="Bob", sibling=mother) with pytest.raises(ValidationError): brother.save() def test_generic_reference(self): """Ensure that a GenericReferenceField properly dereferences items.""" class Link(Document): title = StringField() meta = {"allow_inheritance": False} class Post(Document): title = StringField() class Bookmark(Document): bookmark_object = GenericReferenceField() Link.drop_collection() Post.drop_collection() Bookmark.drop_collection() link_1 = Link(title="Pitchfork") link_1.save() post_1 = Post(title="Behind the Scenes of the Pavement Reunion") post_1.save() bm = Bookmark(bookmark_object=post_1) bm.save() bm = Bookmark.objects(bookmark_object=post_1).first() assert bm.bookmark_object == post_1 assert isinstance(bm.bookmark_object, Post) bm.bookmark_object = link_1 bm.save() bm = Bookmark.objects(bookmark_object=link_1).first() assert bm.bookmark_object == link_1 assert isinstance(bm.bookmark_object, Link) def test_generic_reference_list(self): """Ensure that a ListField properly dereferences generic references.""" class Link(Document): title = StringField() class Post(Document): title = StringField() class User(Document): bookmarks = ListField(GenericReferenceField()) Link.drop_collection() Post.drop_collection() User.drop_collection() link_1 = Link(title="Pitchfork") link_1.save() post_1 = Post(title="Behind the Scenes of the Pavement Reunion") post_1.save() user = User(bookmarks=[post_1, link_1]) user.save() user = User.objects(bookmarks__all=[post_1, link_1]).first() assert user.bookmarks[0] == post_1 assert user.bookmarks[1] == link_1 def test_generic_reference_document_not_registered(self): """Ensure dereferencing out of the document registry throws a `NotRegistered` error. """ class Link(Document): title = StringField() class User(Document): bookmarks = ListField(GenericReferenceField()) Link.drop_collection() User.drop_collection() link_1 = Link(title="Pitchfork") link_1.save() user = User(bookmarks=[link_1]) user.save() # Mimic User and Link definitions being in a different file # and the Link model not being imported in the User file. del _document_registry["Link"] user = User.objects.first() try: user.bookmarks raise AssertionError("Link was removed from the registry") except NotRegistered: pass def test_generic_reference_is_none(self): class Person(Document): name = StringField() city = GenericReferenceField() Person.drop_collection() Person(name="Wilson Jr").save() assert repr(Person.objects(city=None)) == "[]" def test_generic_reference_choices(self): """Ensure that a GenericReferenceField can handle choices.""" class Link(Document): title = StringField() class Post(Document): title = StringField() class Bookmark(Document): bookmark_object = GenericReferenceField(choices=(Post,)) Link.drop_collection() Post.drop_collection() Bookmark.drop_collection() link_1 = Link(title="Pitchfork") link_1.save() post_1 = Post(title="Behind the Scenes of the Pavement Reunion") post_1.save() bm = Bookmark(bookmark_object=link_1) with pytest.raises(ValidationError): bm.validate() bm = Bookmark(bookmark_object=post_1) bm.save() bm = Bookmark.objects.first() assert bm.bookmark_object == post_1 def test_generic_reference_string_choices(self): """Ensure that a GenericReferenceField can handle choices as strings""" class Link(Document): title = StringField() class Post(Document): title = StringField() class Bookmark(Document): bookmark_object = GenericReferenceField(choices=("Post", Link)) Link.drop_collection() Post.drop_collection() Bookmark.drop_collection() link_1 = Link(title="Pitchfork") link_1.save() post_1 = Post(title="Behind the Scenes of the Pavement Reunion") post_1.save() bm = Bookmark(bookmark_object=link_1) bm.save() bm = Bookmark(bookmark_object=post_1) bm.save() bm = Bookmark(bookmark_object=bm) with pytest.raises(ValidationError): bm.validate() def test_generic_reference_choices_no_dereference(self): """Ensure that a GenericReferenceField can handle choices on non-derefenreced (i.e. DBRef) elements """ class Post(Document): title = StringField() class Bookmark(Document): bookmark_object = GenericReferenceField(choices=(Post,)) other_field = StringField() Post.drop_collection() Bookmark.drop_collection() post_1 = Post(title="Behind the Scenes of the Pavement Reunion") post_1.save() bm = Bookmark(bookmark_object=post_1) bm.save() bm = Bookmark.objects.get(id=bm.id) # bookmark_object is now a DBRef bm.other_field = "dummy_change" bm.save() def test_generic_reference_list_choices(self): """Ensure that a ListField properly dereferences generic references and respects choices. """ class Link(Document): title = StringField() class Post(Document): title = StringField() class User(Document): bookmarks = ListField(GenericReferenceField(choices=(Post,))) Link.drop_collection() Post.drop_collection() User.drop_collection() link_1 = Link(title="Pitchfork") link_1.save() post_1 = Post(title="Behind the Scenes of the Pavement Reunion") post_1.save() user = User(bookmarks=[link_1]) with pytest.raises(ValidationError): user.validate() user = User(bookmarks=[post_1]) user.save() user = User.objects.first() assert user.bookmarks == [post_1] def test_generic_reference_list_item_modification(self): """Ensure that modifications of related documents (through generic reference) don't influence on querying""" class Post(Document): title = StringField() class User(Document): username = StringField() bookmarks = ListField(GenericReferenceField()) Post.drop_collection() User.drop_collection() post_1 = Post(title="Behind the Scenes of the Pavement Reunion") post_1.save() user = User(bookmarks=[post_1]) user.save() post_1.title = "Title was modified" user.username = "New username" user.save() user = User.objects(bookmarks__all=[post_1]).first() assert user is not None assert user.bookmarks[0] == post_1 def test_generic_reference_filter_by_dbref(self): """Ensure we can search for a specific generic reference by providing its ObjectId. """ class Doc(Document): ref = GenericReferenceField() Doc.drop_collection() doc1 = Doc.objects.create() doc2 = Doc.objects.create(ref=doc1) doc = Doc.objects.get(ref=DBRef("doc", doc1.pk)) assert doc == doc2 def test_generic_reference_is_not_tracked_in_parent_doc(self): """Ensure that modifications of related documents (through generic reference) don't influence the owner changed fields (#1934) """ class Doc1(Document): name = StringField() class Doc2(Document): ref = GenericReferenceField() refs = ListField(GenericReferenceField()) Doc1.drop_collection() Doc2.drop_collection() doc1 = Doc1(name="garbage1").save() doc11 = Doc1(name="garbage11").save() doc2 = Doc2(ref=doc1, refs=[doc11]).save() doc2.ref.name = "garbage2" assert doc2._get_changed_fields() == [] doc2.refs[0].name = "garbage3" assert doc2._get_changed_fields() == [] assert doc2._delta() == ({}, {}) def test_generic_reference_field(self): """Ensure we can search for a specific generic reference by providing its DBRef. """ class Doc(Document): ref = GenericReferenceField() Doc.drop_collection() doc1 = Doc.objects.create() doc2 = Doc.objects.create(ref=doc1) assert isinstance(doc1.pk, ObjectId) doc = Doc.objects.get(ref=doc1.pk) assert doc == doc2 def test_choices_allow_using_sets_as_choices(self): """Ensure that sets can be used when setting choices""" class Shirt(Document): size = StringField(choices={"M", "L"}) Shirt(size="M").validate() def test_choices_validation_allow_no_value(self): """Ensure that .validate passes and no value was provided for a field setup with choices """ class Shirt(Document): size = StringField(choices=("S", "M")) shirt = Shirt() shirt.validate() def test_choices_validation_accept_possible_value(self): """Ensure that value is in a container of allowed values.""" class Shirt(Document): size = StringField(choices=("S", "M")) shirt = Shirt(size="S") shirt.validate() def test_choices_validation_reject_unknown_value(self): """Ensure that unallowed value are rejected upon validation""" class Shirt(Document): size = StringField(choices=("S", "M")) shirt = Shirt(size="XS") with pytest.raises(ValidationError): shirt.validate() def test_choices_get_field_display(self): """Test dynamic helper for returning the display value of a choices field. """ class Shirt(Document): size = StringField( max_length=3, choices=( ("S", "Small"), ("M", "Medium"), ("L", "Large"), ("XL", "Extra Large"), ("XXL", "Extra Extra Large"), ), ) style = StringField( max_length=3, choices=(("S", "Small"), ("B", "Baggy"), ("W", "Wide")), default="W", ) Shirt.drop_collection() shirt1 = Shirt() shirt2 = Shirt() # Make sure get__display returns the default value (or None) assert shirt1.get_size_display() is None assert shirt1.get_style_display() == "Wide" shirt1.size = "XXL" shirt1.style = "B" shirt2.size = "M" shirt2.style = "S" assert shirt1.get_size_display() == "Extra Extra Large" assert shirt1.get_style_display() == "Baggy" assert shirt2.get_size_display() == "Medium" assert shirt2.get_style_display() == "Small" # Set as Z - an invalid choice shirt1.size = "Z" shirt1.style = "Z" assert shirt1.get_size_display() == "Z" assert shirt1.get_style_display() == "Z" with pytest.raises(ValidationError): shirt1.validate() def test_simple_choices_validation(self): """Ensure that value is in a container of allowed values.""" class Shirt(Document): size = StringField(max_length=3, choices=("S", "M", "L", "XL", "XXL")) Shirt.drop_collection() shirt = Shirt() shirt.validate() shirt.size = "S" shirt.validate() shirt.size = "XS" with pytest.raises(ValidationError): shirt.validate() def test_simple_choices_get_field_display(self): """Test dynamic helper for returning the display value of a choices field. """ class Shirt(Document): size = StringField(max_length=3, choices=("S", "M", "L", "XL", "XXL")) style = StringField( max_length=3, choices=("Small", "Baggy", "wide"), default="Small" ) Shirt.drop_collection() shirt = Shirt() assert shirt.get_size_display() is None assert shirt.get_style_display() == "Small" shirt.size = "XXL" shirt.style = "Baggy" assert shirt.get_size_display() == "XXL" assert shirt.get_style_display() == "Baggy" # Set as Z - an invalid choice shirt.size = "Z" shirt.style = "Z" assert shirt.get_size_display() == "Z" assert shirt.get_style_display() == "Z" with pytest.raises(ValidationError): shirt.validate() def test_simple_choices_validation_invalid_value(self): """Ensure that error messages are correct.""" SIZES = ("S", "M", "L", "XL", "XXL") COLORS = (("R", "Red"), ("B", "Blue")) SIZE_MESSAGE = "Value must be one of ('S', 'M', 'L', 'XL', 'XXL')" COLOR_MESSAGE = "Value must be one of ['R', 'B']" class Shirt(Document): size = StringField(max_length=3, choices=SIZES) color = StringField(max_length=1, choices=COLORS) Shirt.drop_collection() shirt = Shirt() shirt.validate() shirt.size = "S" shirt.color = "R" shirt.validate() shirt.size = "XS" shirt.color = "G" try: shirt.validate() except ValidationError as error: # get the validation rules error_dict = error.to_dict() assert error_dict["size"] == SIZE_MESSAGE assert error_dict["color"] == COLOR_MESSAGE def test_recursive_validation(self): """Ensure that a validation result to_dict is available.""" class Author(EmbeddedDocument): name = StringField(required=True) class Comment(EmbeddedDocument): author = EmbeddedDocumentField(Author, required=True) content = StringField(required=True) class Post(Document): title = StringField(required=True) comments = ListField(EmbeddedDocumentField(Comment)) bob = Author(name="Bob") post = Post(title="hello world") post.comments.append(Comment(content="hello", author=bob)) post.comments.append(Comment(author=bob)) with pytest.raises(ValidationError): post.validate() try: post.validate() except ValidationError as error: # ValidationError.errors property assert hasattr(error, "errors") assert isinstance(error.errors, dict) assert "comments" in error.errors assert 1 in error.errors["comments"] assert isinstance(error.errors["comments"][1]["content"], ValidationError) # ValidationError.schema property error_dict = error.to_dict() assert isinstance(error_dict, dict) assert "comments" in error_dict assert 1 in error_dict["comments"] assert "content" in error_dict["comments"][1] assert error_dict["comments"][1]["content"] == "Field is required" post.comments[1].content = "here we go" post.validate() def test_tuples_as_tuples(self): """Ensure that tuples remain tuples when they are inside a ComplexBaseField. """ class EnumField(BaseField): def __init__(self, **kwargs): super().__init__(**kwargs) def to_mongo(self, value): return value def to_python(self, value): return tuple(value) class TestDoc(Document): items = ListField(EnumField()) TestDoc.drop_collection() tuples = [(100, "Testing")] doc = TestDoc() doc.items = tuples doc.save() x = TestDoc.objects().get() assert x is not None assert len(x.items) == 1 assert tuple(x.items[0]) in tuples assert x.items[0] in tuples def test_dynamic_fields_class(self): class Doc2(Document): field_1 = StringField(db_field="f") class Doc(Document): my_id = IntField(primary_key=True) embed_me = DynamicField(db_field="e") field_x = StringField(db_field="x") Doc.drop_collection() Doc2.drop_collection() doc2 = Doc2(field_1="hello") doc = Doc(my_id=1, embed_me=doc2, field_x="x") with pytest.raises(OperationError): doc.save() doc2.save() doc.save() doc = Doc.objects.get() assert doc.embed_me.field_1 == "hello" def test_dynamic_fields_embedded_class(self): class Embed(EmbeddedDocument): field_1 = StringField(db_field="f") class Doc(Document): my_id = IntField(primary_key=True) embed_me = DynamicField(db_field="e") field_x = StringField(db_field="x") Doc.drop_collection() Doc(my_id=1, embed_me=Embed(field_1="hello"), field_x="x").save() doc = Doc.objects.get() assert doc.embed_me.field_1 == "hello" def test_dynamicfield_dump_document(self): """Ensure a DynamicField can handle another document's dump.""" class Doc(Document): field = DynamicField() class ToEmbed(Document): id = IntField(primary_key=True, default=1) recursive = DynamicField() class ToEmbedParent(Document): id = IntField(primary_key=True, default=1) recursive = DynamicField() meta = {"allow_inheritance": True} class ToEmbedChild(ToEmbedParent): pass to_embed_recursive = ToEmbed(id=1).save() to_embed = ToEmbed(id=2, recursive=to_embed_recursive).save() doc = Doc(field=to_embed) doc.save() assert isinstance(doc.field, ToEmbed) assert doc.field == to_embed # Same thing with a Document with a _cls field to_embed_recursive = ToEmbedChild(id=1).save() to_embed_child = ToEmbedChild(id=2, recursive=to_embed_recursive).save() doc = Doc(field=to_embed_child) doc.save() assert isinstance(doc.field, ToEmbedChild) assert doc.field == to_embed_child def test_cls_field(self): class Animal(Document): meta = {"allow_inheritance": True} class Fish(Animal): pass class Mammal(Animal): pass class Dog(Mammal): pass class Human(Mammal): pass Animal.objects.delete() Dog().save() Fish().save() Human().save() assert ( Animal.objects(_cls__in=["Animal.Mammal.Dog", "Animal.Fish"]).count() == 2 ) assert Animal.objects(_cls__in=["Animal.Fish.Guppy"]).count() == 0 def test_sparse_field(self): class Doc(Document): name = StringField(required=False, unique=True, sparse=True) # This would raise an exception in a non-sparse unique index Doc().save() Doc().save() def test_undefined_field_exception(self): """Tests if a `FieldDoesNotExist` exception is raised when trying to instantiate a document with a field that's not defined. """ class Doc(Document): foo = StringField() with pytest.raises(FieldDoesNotExist): Doc(bar="test") def test_undefined_field_exception_with_strict(self): """Tests if a `FieldDoesNotExist` exception is raised when trying to instantiate a document with a field that's not defined, even when strict is set to False. """ class Doc(Document): foo = StringField() meta = {"strict": False} with pytest.raises(FieldDoesNotExist): Doc(bar="test") def test_undefined_field_works_no_confusion_with_db_field(self): class Doc(Document): foo = StringField(db_field="bar") with pytest.raises(FieldDoesNotExist): Doc(bar="test") class TestEmbeddedDocumentListField(MongoDBTestCase): def setUp(self): """ Create two BlogPost entries in the database, each with several EmbeddedDocuments. """ class Comments(EmbeddedDocument): author = StringField() message = StringField() class BlogPost(Document): comments = EmbeddedDocumentListField(Comments) BlogPost.drop_collection() self.Comments = Comments self.BlogPost = BlogPost self.post1 = self.BlogPost( comments=[ self.Comments(author="user1", message="message1"), self.Comments(author="user2", message="message1"), ] ).save() self.post2 = self.BlogPost( comments=[ self.Comments(author="user2", message="message2"), self.Comments(author="user2", message="message3"), self.Comments(author="user3", message="message1"), ] ).save() def test_fails_upon_validate_if_provide_a_doc_instead_of_a_list_of_doc(self): # Relates to Issue #1464 comment = self.Comments(author="John") class Title(Document): content = StringField() # Test with an embeddedDocument instead of a list(embeddedDocument) # It's an edge case but it used to fail with a vague error, making it difficult to troubleshoot it post = self.BlogPost(comments=comment) with pytest.raises(ValidationError) as exc_info: post.validate() error_msg = str(exc_info.value) assert "'comments'" in error_msg assert "Only lists and tuples may be used in a list field" in error_msg # Test with a Document post = self.BlogPost(comments=Title(content="garbage")) with pytest.raises(ValidationError) as exc_info: post.validate() error_msg = str(exc_info.value) assert "'comments'" in error_msg assert "Only lists and tuples may be used in a list field" in error_msg def test_no_keyword_filter(self): """ Tests the filter method of a List of Embedded Documents with a no keyword. """ filtered = self.post1.comments.filter() # Ensure nothing was changed assert filtered == self.post1.comments def test_single_keyword_filter(self): """ Tests the filter method of a List of Embedded Documents with a single keyword. """ filtered = self.post1.comments.filter(author="user1") # Ensure only 1 entry was returned. assert len(filtered) == 1 # Ensure the entry returned is the correct entry. assert filtered[0].author == "user1" def test_multi_keyword_filter(self): """ Tests the filter method of a List of Embedded Documents with multiple keywords. """ filtered = self.post2.comments.filter(author="user2", message="message2") # Ensure only 1 entry was returned. assert len(filtered) == 1 # Ensure the entry returned is the correct entry. assert filtered[0].author == "user2" assert filtered[0].message == "message2" def test_chained_filter(self): """ Tests chained filter methods of a List of Embedded Documents """ filtered = self.post2.comments.filter(author="user2").filter(message="message2") # Ensure only 1 entry was returned. assert len(filtered) == 1 # Ensure the entry returned is the correct entry. assert filtered[0].author == "user2" assert filtered[0].message == "message2" def test_unknown_keyword_filter(self): """ Tests the filter method of a List of Embedded Documents when the keyword is not a known keyword. """ with pytest.raises(AttributeError): self.post2.comments.filter(year=2) def test_no_keyword_exclude(self): """ Tests the exclude method of a List of Embedded Documents with a no keyword. """ filtered = self.post1.comments.exclude() # Ensure everything was removed assert filtered == [] def test_single_keyword_exclude(self): """ Tests the exclude method of a List of Embedded Documents with a single keyword. """ excluded = self.post1.comments.exclude(author="user1") # Ensure only 1 entry was returned. assert len(excluded) == 1 # Ensure the entry returned is the correct entry. assert excluded[0].author == "user2" def test_multi_keyword_exclude(self): """ Tests the exclude method of a List of Embedded Documents with multiple keywords. """ excluded = self.post2.comments.exclude(author="user3", message="message1") # Ensure only 2 entries were returned. assert len(excluded) == 2 # Ensure the entries returned are the correct entries. assert excluded[0].author == "user2" assert excluded[1].author == "user2" def test_non_matching_exclude(self): """ Tests the exclude method of a List of Embedded Documents when the keyword does not match any entries. """ excluded = self.post2.comments.exclude(author="user4") # Ensure the 3 entries still exist. assert len(excluded) == 3 def test_unknown_keyword_exclude(self): """ Tests the exclude method of a List of Embedded Documents when the keyword is not a known keyword. """ with pytest.raises(AttributeError): self.post2.comments.exclude(year=2) def test_chained_filter_exclude(self): """ Tests the exclude method after a filter method of a List of Embedded Documents. """ excluded = self.post2.comments.filter(author="user2").exclude( message="message2" ) # Ensure only 1 entry was returned. assert len(excluded) == 1 # Ensure the entry returned is the correct entry. assert excluded[0].author == "user2" assert excluded[0].message == "message3" def test_count(self): """ Tests the count method of a List of Embedded Documents. """ assert self.post1.comments.count() == 2 assert self.post1.comments.count() == len(self.post1.comments) def test_filtered_count(self): """ Tests the filter + count method of a List of Embedded Documents. """ count = self.post1.comments.filter(author="user1").count() assert count == 1 def test_single_keyword_get(self): """ Tests the get method of a List of Embedded Documents using a single keyword. """ comment = self.post1.comments.get(author="user1") assert isinstance(comment, self.Comments) assert comment.author == "user1" def test_multi_keyword_get(self): """ Tests the get method of a List of Embedded Documents using multiple keywords. """ comment = self.post2.comments.get(author="user2", message="message2") assert isinstance(comment, self.Comments) assert comment.author == "user2" assert comment.message == "message2" def test_no_keyword_multiple_return_get(self): """ Tests the get method of a List of Embedded Documents without a keyword to return multiple documents. """ with pytest.raises(MultipleObjectsReturned): self.post1.comments.get() def test_keyword_multiple_return_get(self): """ Tests the get method of a List of Embedded Documents with a keyword to return multiple documents. """ with pytest.raises(MultipleObjectsReturned): self.post2.comments.get(author="user2") def test_unknown_keyword_get(self): """ Tests the get method of a List of Embedded Documents with an unknown keyword. """ with pytest.raises(AttributeError): self.post2.comments.get(year=2020) def test_no_result_get(self): """ Tests the get method of a List of Embedded Documents where get returns no results. """ with pytest.raises(DoesNotExist): self.post1.comments.get(author="user3") def test_first(self): """ Tests the first method of a List of Embedded Documents to ensure it returns the first comment. """ comment = self.post1.comments.first() # Ensure a Comment object was returned. assert isinstance(comment, self.Comments) assert comment == self.post1.comments[0] def test_create(self): """ Test the create method of a List of Embedded Documents. """ comment = self.post1.comments.create(author="user4", message="message1") self.post1.save() # Ensure the returned value is the comment object. assert isinstance(comment, self.Comments) assert comment.author == "user4" assert comment.message == "message1" # Ensure the new comment was actually saved to the database. assert comment in self.BlogPost.objects(comments__author="user4")[0].comments def test_filtered_create(self): """ Test the create method of a List of Embedded Documents chained to a call to the filter method. Filtering should have no effect on creation. """ comment = self.post1.comments.filter(author="user1").create( author="user4", message="message1" ) self.post1.save() # Ensure the returned value is the comment object. assert isinstance(comment, self.Comments) assert comment.author == "user4" assert comment.message == "message1" # Ensure the new comment was actually saved to the database. assert comment in self.BlogPost.objects(comments__author="user4")[0].comments def test_no_keyword_update(self): """ Tests the update method of a List of Embedded Documents with no keywords. """ original = list(self.post1.comments) number = self.post1.comments.update() self.post1.save() # Ensure that nothing was altered. assert original[0] in self.BlogPost.objects(id=self.post1.id)[0].comments assert original[1] in self.BlogPost.objects(id=self.post1.id)[0].comments # Ensure the method returned 0 as the number of entries # modified assert number == 0 def test_single_keyword_update(self): """ Tests the update method of a List of Embedded Documents with a single keyword. """ number = self.post1.comments.update(author="user4") self.post1.save() comments = self.BlogPost.objects(id=self.post1.id)[0].comments # Ensure that the database was updated properly. assert comments[0].author == "user4" assert comments[1].author == "user4" # Ensure the method returned 2 as the number of entries # modified assert number == 2 def test_unicode(self): """ Tests that unicode strings handled correctly """ post = self.BlogPost( comments=[ self.Comments(author="user1", message="сообщение"), self.Comments(author="user2", message="хабарлама"), ] ).save() assert post.comments.get(message="сообщение").author == "user1" def test_save(self): """ Tests the save method of a List of Embedded Documents. """ comments = self.post1.comments new_comment = self.Comments(author="user4") comments.append(new_comment) comments.save() # Ensure that the new comment has been added to the database. assert new_comment in self.BlogPost.objects(id=self.post1.id)[0].comments def test_delete(self): """ Tests the delete method of a List of Embedded Documents. """ number = self.post1.comments.delete() self.post1.save() # Ensure that all the comments under post1 were deleted in the # database. assert self.BlogPost.objects(id=self.post1.id)[0].comments == [] # Ensure that post1 comments were deleted from the list. assert self.post1.comments == [] # Ensure that comments still returned a EmbeddedDocumentList object. assert isinstance(self.post1.comments, EmbeddedDocumentList) # Ensure that the delete method returned 2 as the number of entries # deleted from the database assert number == 2 def test_empty_list_embedded_documents_with_unique_field(self): """ Tests that only one document with an empty list of embedded documents that have a unique field can be saved, but if the unique field is also sparse than multiple documents with an empty list can be saved. """ class EmbeddedWithUnique(EmbeddedDocument): number = IntField(unique=True) class A(Document): my_list = ListField(EmbeddedDocumentField(EmbeddedWithUnique)) A(my_list=[]).save() with pytest.raises(NotUniqueError): A(my_list=[]).save() class EmbeddedWithSparseUnique(EmbeddedDocument): number = IntField(unique=True, sparse=True) class B(Document): my_list = ListField(EmbeddedDocumentField(EmbeddedWithSparseUnique)) A.drop_collection() B.drop_collection() B(my_list=[]).save() B(my_list=[]).save() def test_filtered_delete(self): """ Tests the delete method of a List of Embedded Documents after the filter method has been called. """ comment = self.post1.comments[1] number = self.post1.comments.filter(author="user2").delete() self.post1.save() # Ensure that only the user2 comment was deleted. assert comment not in self.BlogPost.objects(id=self.post1.id)[0].comments assert len(self.BlogPost.objects(id=self.post1.id)[0].comments) == 1 # Ensure that the user2 comment no longer exists in the list. assert comment not in self.post1.comments assert len(self.post1.comments) == 1 # Ensure that the delete method returned 1 as the number of entries # deleted from the database assert number == 1 def test_custom_data(self): """ Tests that custom data is saved in the field object and doesn't interfere with the rest of field functionalities. """ custom_data = {"a": "a_value", "b": [1, 2]} class CustomData(Document): a_field = IntField() c_field = IntField(custom_data=custom_data) CustomData.drop_collection() a1 = CustomData(a_field=1, c_field=2).save() assert 2 == a1.c_field assert not hasattr(a1.c_field, "custom_data") assert hasattr(CustomData.c_field, "custom_data") assert custom_data["a"] == CustomData.c_field.custom_data["a"] if __name__ == "__main__": unittest.main()