From ac25f4b98bd8c4b6daad46faf1e8a163928d7bc2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Fri, 30 Aug 2019 16:13:30 +0300 Subject: [PATCH] ran unittest2pytest --- tests/all_warnings/test_warnings.py | 6 +- tests/document/test_class_methods.py | 133 +- tests/document/test_delta.py | 638 +++---- tests/document/test_dynamic.py | 193 +- tests/document/test_indexes.py | 285 ++- tests/document/test_inheritance.py | 256 ++- tests/document/test_instance.py | 1047 +++++----- tests/document/test_json_serialisation.py | 8 +- tests/document/test_validation.py | 76 +- tests/fields/test_binary_field.py | 40 +- tests/fields/test_boolean_field.py | 16 +- tests/fields/test_cached_reference_field.py | 179 +- tests/fields/test_complex_datetime_field.py | 46 +- tests/fields/test_date_field.py | 45 +- tests/fields/test_datetime_field.py | 71 +- tests/fields/test_decimal_field.py | 30 +- tests/fields/test_dict_field.py | 139 +- tests/fields/test_email_field.py | 37 +- tests/fields/test_embedded_document_field.py | 103 +- tests/fields/test_fields.py | 851 ++++----- tests/fields/test_file_field.py | 166 +- tests/fields/test_float_field.py | 20 +- tests/fields/test_geo_fields.py | 46 +- tests/fields/test_int_field.py | 14 +- tests/fields/test_lazy_reference_field.py | 118 +- tests/fields/test_long_field.py | 16 +- tests/fields/test_map_field.py | 31 +- tests/fields/test_reference_field.py | 46 +- tests/fields/test_sequence_field.py | 99 +- tests/fields/test_url_field.py | 15 +- tests/fields/test_uuid_field.py | 19 +- tests/queryset/test_field_list.py | 197 +- tests/queryset/test_geo.py | 205 +- tests/queryset/test_modify.py | 32 +- tests/queryset/test_pickable.py | 10 +- tests/queryset/test_queryset.py | 1784 +++++++++--------- tests/queryset/test_transform.py | 178 +- tests/queryset/test_visitor.py | 172 +- tests/test_common.py | 6 +- tests/test_connection.py | 254 ++- tests/test_context_managers.py | 139 +- tests/test_datastructures.py | 241 +-- tests/test_dereference.py | 386 ++-- tests/test_replicaset_connection.py | 2 +- tests/test_signals.py | 265 ++- tests/test_utils.py | 15 +- 46 files changed, 4247 insertions(+), 4428 deletions(-) diff --git a/tests/all_warnings/test_warnings.py b/tests/all_warnings/test_warnings.py index 67204617..a9910121 100644 --- a/tests/all_warnings/test_warnings.py +++ b/tests/all_warnings/test_warnings.py @@ -31,7 +31,5 @@ class TestAllWarnings(unittest.TestCase): meta = {"collection": "fail"} warning = self.warning_list[0] - self.assertEqual(SyntaxWarning, warning["category"]) - self.assertEqual( - "non_abstract_base", InheritedDocumentFailTest._get_collection_name() - ) + assert SyntaxWarning == warning["category"] + assert "non_abstract_base" == InheritedDocumentFailTest._get_collection_name() diff --git a/tests/document/test_class_methods.py b/tests/document/test_class_methods.py index c5df0843..98909d2f 100644 --- a/tests/document/test_class_methods.py +++ b/tests/document/test_class_methods.py @@ -29,43 +29,40 @@ class TestClassMethods(unittest.TestCase): def test_definition(self): """Ensure that document may be defined using fields. """ - self.assertEqual( - ["_cls", "age", "id", "name"], sorted(self.Person._fields.keys()) - ) - self.assertEqual( - ["IntField", "ObjectIdField", "StringField", "StringField"], - sorted([x.__class__.__name__ for x in self.Person._fields.values()]), + assert ["_cls", "age", "id", "name"] == sorted(self.Person._fields.keys()) + assert ["IntField", "ObjectIdField", "StringField", "StringField"] == sorted( + [x.__class__.__name__ for x in self.Person._fields.values()] ) def test_get_db(self): """Ensure that get_db returns the expected db. """ db = self.Person._get_db() - self.assertEqual(self.db, db) + assert self.db == db def test_get_collection_name(self): """Ensure that get_collection_name returns the expected collection name. """ collection_name = "person" - self.assertEqual(collection_name, self.Person._get_collection_name()) + assert collection_name == self.Person._get_collection_name() def test_get_collection(self): """Ensure that get_collection returns the expected collection. """ collection_name = "person" collection = self.Person._get_collection() - self.assertEqual(self.db[collection_name], collection) + assert self.db[collection_name] == collection def test_drop_collection(self): """Ensure that the collection may be dropped from the database. """ collection_name = "person" self.Person(name="Test").save() - self.assertIn(collection_name, list_collection_names(self.db)) + assert collection_name in list_collection_names(self.db) self.Person.drop_collection() - self.assertNotIn(collection_name, list_collection_names(self.db)) + assert collection_name not in list_collection_names(self.db) def test_register_delete_rule(self): """Ensure that register delete rule adds a delete rule to the document @@ -75,12 +72,10 @@ class TestClassMethods(unittest.TestCase): class Job(Document): employee = ReferenceField(self.Person) - self.assertEqual(self.Person._meta.get("delete_rules"), None) + assert self.Person._meta.get("delete_rules") == None self.Person.register_delete_rule(Job, "employee", NULLIFY) - self.assertEqual( - self.Person._meta["delete_rules"], {(Job, "employee"): NULLIFY} - ) + assert self.Person._meta["delete_rules"] == {(Job, "employee"): NULLIFY} def test_compare_indexes(self): """ Ensure that the indexes are properly created and that @@ -98,22 +93,22 @@ class TestClassMethods(unittest.TestCase): BlogPost.drop_collection() BlogPost.ensure_indexes() - self.assertEqual(BlogPost.compare_indexes(), {"missing": [], "extra": []}) + assert BlogPost.compare_indexes() == {"missing": [], "extra": []} BlogPost.ensure_index(["author", "description"]) - self.assertEqual( - BlogPost.compare_indexes(), - {"missing": [], "extra": [[("author", 1), ("description", 1)]]}, - ) + assert BlogPost.compare_indexes() == { + "missing": [], + "extra": [[("author", 1), ("description", 1)]], + } BlogPost._get_collection().drop_index("author_1_description_1") - self.assertEqual(BlogPost.compare_indexes(), {"missing": [], "extra": []}) + assert BlogPost.compare_indexes() == {"missing": [], "extra": []} BlogPost._get_collection().drop_index("author_1_title_1") - self.assertEqual( - BlogPost.compare_indexes(), - {"missing": [[("author", 1), ("title", 1)]], "extra": []}, - ) + assert BlogPost.compare_indexes() == { + "missing": [[("author", 1), ("title", 1)]], + "extra": [], + } def test_compare_indexes_inheritance(self): """ Ensure that the indexes are properly created and that @@ -138,22 +133,22 @@ class TestClassMethods(unittest.TestCase): BlogPost.ensure_indexes() BlogPostWithTags.ensure_indexes() - self.assertEqual(BlogPost.compare_indexes(), {"missing": [], "extra": []}) + assert BlogPost.compare_indexes() == {"missing": [], "extra": []} BlogPostWithTags.ensure_index(["author", "tag_list"]) - self.assertEqual( - BlogPost.compare_indexes(), - {"missing": [], "extra": [[("_cls", 1), ("author", 1), ("tag_list", 1)]]}, - ) + assert BlogPost.compare_indexes() == { + "missing": [], + "extra": [[("_cls", 1), ("author", 1), ("tag_list", 1)]], + } BlogPostWithTags._get_collection().drop_index("_cls_1_author_1_tag_list_1") - self.assertEqual(BlogPost.compare_indexes(), {"missing": [], "extra": []}) + assert BlogPost.compare_indexes() == {"missing": [], "extra": []} BlogPostWithTags._get_collection().drop_index("_cls_1_author_1_tags_1") - self.assertEqual( - BlogPost.compare_indexes(), - {"missing": [[("_cls", 1), ("author", 1), ("tags", 1)]], "extra": []}, - ) + assert BlogPost.compare_indexes() == { + "missing": [[("_cls", 1), ("author", 1), ("tags", 1)]], + "extra": [], + } def test_compare_indexes_multiple_subclasses(self): """ Ensure that compare_indexes behaves correctly if called from a @@ -182,13 +177,9 @@ class TestClassMethods(unittest.TestCase): BlogPostWithTags.ensure_indexes() BlogPostWithCustomField.ensure_indexes() - self.assertEqual(BlogPost.compare_indexes(), {"missing": [], "extra": []}) - self.assertEqual( - BlogPostWithTags.compare_indexes(), {"missing": [], "extra": []} - ) - self.assertEqual( - BlogPostWithCustomField.compare_indexes(), {"missing": [], "extra": []} - ) + assert BlogPost.compare_indexes() == {"missing": [], "extra": []} + assert BlogPostWithTags.compare_indexes() == {"missing": [], "extra": []} + assert BlogPostWithCustomField.compare_indexes() == {"missing": [], "extra": []} def test_compare_indexes_for_text_indexes(self): """ Ensure that compare_indexes behaves correctly for text indexes """ @@ -210,7 +201,7 @@ class TestClassMethods(unittest.TestCase): Doc.ensure_indexes() actual = Doc.compare_indexes() expected = {"missing": [], "extra": []} - self.assertEqual(actual, expected) + assert actual == expected def test_list_indexes_inheritance(self): """ ensure that all of the indexes are listed regardless of the super- @@ -240,19 +231,14 @@ class TestClassMethods(unittest.TestCase): BlogPostWithTags.ensure_indexes() BlogPostWithTagsAndExtraText.ensure_indexes() - self.assertEqual(BlogPost.list_indexes(), BlogPostWithTags.list_indexes()) - self.assertEqual( - BlogPost.list_indexes(), BlogPostWithTagsAndExtraText.list_indexes() - ) - self.assertEqual( - BlogPost.list_indexes(), - [ - [("_cls", 1), ("author", 1), ("tags", 1)], - [("_cls", 1), ("author", 1), ("tags", 1), ("extra_text", 1)], - [(u"_id", 1)], - [("_cls", 1)], - ], - ) + assert BlogPost.list_indexes() == BlogPostWithTags.list_indexes() + assert BlogPost.list_indexes() == BlogPostWithTagsAndExtraText.list_indexes() + assert BlogPost.list_indexes() == [ + [("_cls", 1), ("author", 1), ("tags", 1)], + [("_cls", 1), ("author", 1), ("tags", 1), ("extra_text", 1)], + [(u"_id", 1)], + [("_cls", 1)], + ] def test_register_delete_rule_inherited(self): class Vaccine(Document): @@ -271,8 +257,8 @@ class TestClassMethods(unittest.TestCase): class Cat(Animal): name = StringField(required=True) - self.assertEqual(Vaccine._meta["delete_rules"][(Animal, "vaccine_made")], PULL) - self.assertEqual(Vaccine._meta["delete_rules"][(Cat, "vaccine_made")], PULL) + assert Vaccine._meta["delete_rules"][(Animal, "vaccine_made")] == PULL + assert Vaccine._meta["delete_rules"][(Cat, "vaccine_made")] == PULL def test_collection_naming(self): """Ensure that a collection with a specified name may be used. @@ -281,19 +267,17 @@ class TestClassMethods(unittest.TestCase): class DefaultNamingTest(Document): pass - self.assertEqual( - "default_naming_test", DefaultNamingTest._get_collection_name() - ) + assert "default_naming_test" == DefaultNamingTest._get_collection_name() class CustomNamingTest(Document): meta = {"collection": "pimp_my_collection"} - self.assertEqual("pimp_my_collection", CustomNamingTest._get_collection_name()) + assert "pimp_my_collection" == CustomNamingTest._get_collection_name() class DynamicNamingTest(Document): meta = {"collection": lambda c: "DYNAMO"} - self.assertEqual("DYNAMO", DynamicNamingTest._get_collection_name()) + assert "DYNAMO" == DynamicNamingTest._get_collection_name() # Use Abstract class to handle backwards compatibility class BaseDocument(Document): @@ -302,14 +286,12 @@ class TestClassMethods(unittest.TestCase): class OldNamingConvention(BaseDocument): pass - self.assertEqual( - "oldnamingconvention", OldNamingConvention._get_collection_name() - ) + assert "oldnamingconvention" == OldNamingConvention._get_collection_name() class InheritedAbstractNamingTest(BaseDocument): meta = {"collection": "wibble"} - self.assertEqual("wibble", InheritedAbstractNamingTest._get_collection_name()) + assert "wibble" == InheritedAbstractNamingTest._get_collection_name() # Mixin tests class BaseMixin(object): @@ -318,8 +300,9 @@ class TestClassMethods(unittest.TestCase): class OldMixinNamingConvention(Document, BaseMixin): pass - self.assertEqual( - "oldmixinnamingconvention", OldMixinNamingConvention._get_collection_name() + assert ( + "oldmixinnamingconvention" + == OldMixinNamingConvention._get_collection_name() ) class BaseMixin(object): @@ -331,7 +314,7 @@ class TestClassMethods(unittest.TestCase): class MyDocument(BaseDocument): pass - self.assertEqual("basedocument", MyDocument._get_collection_name()) + assert "basedocument" == MyDocument._get_collection_name() def test_custom_collection_name_operations(self): """Ensure that a collection with a specified name is used as expected. @@ -343,16 +326,16 @@ class TestClassMethods(unittest.TestCase): meta = {"collection": collection_name} Person(name="Test User").save() - self.assertIn(collection_name, list_collection_names(self.db)) + assert collection_name in list_collection_names(self.db) user_obj = self.db[collection_name].find_one() - self.assertEqual(user_obj["name"], "Test User") + assert user_obj["name"] == "Test User" user_obj = Person.objects[0] - self.assertEqual(user_obj.name, "Test User") + assert user_obj.name == "Test User" Person.drop_collection() - self.assertNotIn(collection_name, list_collection_names(self.db)) + assert collection_name not in list_collection_names(self.db) def test_collection_name_and_primary(self): """Ensure that a collection with a specified name may be used. @@ -365,7 +348,7 @@ class TestClassMethods(unittest.TestCase): Person(name="Test User").save() user_obj = Person.objects.first() - self.assertEqual(user_obj.name, "Test User") + assert user_obj.name == "Test User" Person.drop_collection() diff --git a/tests/document/test_delta.py b/tests/document/test_delta.py index 632d9b3f..2324211b 100644 --- a/tests/document/test_delta.py +++ b/tests/document/test_delta.py @@ -41,40 +41,40 @@ class TestDelta(MongoDBTestCase): doc.save() doc = Doc.objects.first() - self.assertEqual(doc._get_changed_fields(), []) - self.assertEqual(doc._delta(), ({}, {})) + assert doc._get_changed_fields() == [] + assert doc._delta() == ({}, {}) doc.string_field = "hello" - self.assertEqual(doc._get_changed_fields(), ["string_field"]) - self.assertEqual(doc._delta(), ({"string_field": "hello"}, {})) + assert doc._get_changed_fields() == ["string_field"] + assert doc._delta() == ({"string_field": "hello"}, {}) doc._changed_fields = [] doc.int_field = 1 - self.assertEqual(doc._get_changed_fields(), ["int_field"]) - self.assertEqual(doc._delta(), ({"int_field": 1}, {})) + assert doc._get_changed_fields() == ["int_field"] + assert doc._delta() == ({"int_field": 1}, {}) doc._changed_fields = [] dict_value = {"hello": "world", "ping": "pong"} doc.dict_field = dict_value - self.assertEqual(doc._get_changed_fields(), ["dict_field"]) - self.assertEqual(doc._delta(), ({"dict_field": dict_value}, {})) + assert doc._get_changed_fields() == ["dict_field"] + assert doc._delta() == ({"dict_field": dict_value}, {}) doc._changed_fields = [] list_value = ["1", 2, {"hello": "world"}] doc.list_field = list_value - self.assertEqual(doc._get_changed_fields(), ["list_field"]) - self.assertEqual(doc._delta(), ({"list_field": list_value}, {})) + assert doc._get_changed_fields() == ["list_field"] + assert doc._delta() == ({"list_field": list_value}, {}) # Test unsetting doc._changed_fields = [] doc.dict_field = {} - self.assertEqual(doc._get_changed_fields(), ["dict_field"]) - self.assertEqual(doc._delta(), ({}, {"dict_field": 1})) + assert doc._get_changed_fields() == ["dict_field"] + assert doc._delta() == ({}, {"dict_field": 1}) doc._changed_fields = [] doc.list_field = [] - self.assertEqual(doc._get_changed_fields(), ["list_field"]) - self.assertEqual(doc._delta(), ({}, {"list_field": 1})) + assert doc._get_changed_fields() == ["list_field"] + assert doc._delta() == ({}, {"list_field": 1}) def test_delta_recursive(self): self.delta_recursive(Document, EmbeddedDocument) @@ -102,8 +102,8 @@ class TestDelta(MongoDBTestCase): doc.save() doc = Doc.objects.first() - self.assertEqual(doc._get_changed_fields(), []) - self.assertEqual(doc._delta(), ({}, {})) + assert doc._get_changed_fields() == [] + assert doc._delta() == ({}, {}) embedded_1 = Embedded() embedded_1.id = "010101" @@ -113,7 +113,7 @@ class TestDelta(MongoDBTestCase): embedded_1.list_field = ["1", 2, {"hello": "world"}] doc.embedded_field = embedded_1 - self.assertEqual(doc._get_changed_fields(), ["embedded_field"]) + assert doc._get_changed_fields() == ["embedded_field"] embedded_delta = { "id": "010101", @@ -122,27 +122,27 @@ class TestDelta(MongoDBTestCase): "dict_field": {"hello": "world"}, "list_field": ["1", 2, {"hello": "world"}], } - self.assertEqual(doc.embedded_field._delta(), (embedded_delta, {})) - self.assertEqual(doc._delta(), ({"embedded_field": embedded_delta}, {})) + assert doc.embedded_field._delta() == (embedded_delta, {}) + assert doc._delta() == ({"embedded_field": embedded_delta}, {}) doc.save() doc = doc.reload(10) doc.embedded_field.dict_field = {} - self.assertEqual(doc._get_changed_fields(), ["embedded_field.dict_field"]) - self.assertEqual(doc.embedded_field._delta(), ({}, {"dict_field": 1})) - self.assertEqual(doc._delta(), ({}, {"embedded_field.dict_field": 1})) + assert doc._get_changed_fields() == ["embedded_field.dict_field"] + assert doc.embedded_field._delta() == ({}, {"dict_field": 1}) + assert doc._delta() == ({}, {"embedded_field.dict_field": 1}) doc.save() doc = doc.reload(10) - self.assertEqual(doc.embedded_field.dict_field, {}) + assert doc.embedded_field.dict_field == {} doc.embedded_field.list_field = [] - self.assertEqual(doc._get_changed_fields(), ["embedded_field.list_field"]) - self.assertEqual(doc.embedded_field._delta(), ({}, {"list_field": 1})) - self.assertEqual(doc._delta(), ({}, {"embedded_field.list_field": 1})) + assert doc._get_changed_fields() == ["embedded_field.list_field"] + assert doc.embedded_field._delta() == ({}, {"list_field": 1}) + assert doc._delta() == ({}, {"embedded_field.list_field": 1}) doc.save() doc = doc.reload(10) - self.assertEqual(doc.embedded_field.list_field, []) + assert doc.embedded_field.list_field == [] embedded_2 = Embedded() embedded_2.string_field = "hello" @@ -151,148 +151,128 @@ class TestDelta(MongoDBTestCase): embedded_2.list_field = ["1", 2, {"hello": "world"}] doc.embedded_field.list_field = ["1", 2, embedded_2] - self.assertEqual(doc._get_changed_fields(), ["embedded_field.list_field"]) + assert doc._get_changed_fields() == ["embedded_field.list_field"] - self.assertEqual( - doc.embedded_field._delta(), - ( - { - "list_field": [ - "1", - 2, - { - "_cls": "Embedded", - "string_field": "hello", - "dict_field": {"hello": "world"}, - "int_field": 1, - "list_field": ["1", 2, {"hello": "world"}], - }, - ] - }, - {}, - ), + assert doc.embedded_field._delta() == ( + { + "list_field": [ + "1", + 2, + { + "_cls": "Embedded", + "string_field": "hello", + "dict_field": {"hello": "world"}, + "int_field": 1, + "list_field": ["1", 2, {"hello": "world"}], + }, + ] + }, + {}, ) - self.assertEqual( - doc._delta(), - ( - { - "embedded_field.list_field": [ - "1", - 2, - { - "_cls": "Embedded", - "string_field": "hello", - "dict_field": {"hello": "world"}, - "int_field": 1, - "list_field": ["1", 2, {"hello": "world"}], - }, - ] - }, - {}, - ), + assert doc._delta() == ( + { + "embedded_field.list_field": [ + "1", + 2, + { + "_cls": "Embedded", + "string_field": "hello", + "dict_field": {"hello": "world"}, + "int_field": 1, + "list_field": ["1", 2, {"hello": "world"}], + }, + ] + }, + {}, ) doc.save() doc = doc.reload(10) - self.assertEqual(doc.embedded_field.list_field[0], "1") - self.assertEqual(doc.embedded_field.list_field[1], 2) + assert doc.embedded_field.list_field[0] == "1" + assert doc.embedded_field.list_field[1] == 2 for k in doc.embedded_field.list_field[2]._fields: - self.assertEqual(doc.embedded_field.list_field[2][k], embedded_2[k]) + assert doc.embedded_field.list_field[2][k] == embedded_2[k] doc.embedded_field.list_field[2].string_field = "world" - self.assertEqual( - doc._get_changed_fields(), ["embedded_field.list_field.2.string_field"] + assert doc._get_changed_fields() == ["embedded_field.list_field.2.string_field"] + assert doc.embedded_field._delta() == ( + {"list_field.2.string_field": "world"}, + {}, ) - self.assertEqual( - doc.embedded_field._delta(), ({"list_field.2.string_field": "world"}, {}) - ) - self.assertEqual( - doc._delta(), ({"embedded_field.list_field.2.string_field": "world"}, {}) + assert doc._delta() == ( + {"embedded_field.list_field.2.string_field": "world"}, + {}, ) doc.save() doc = doc.reload(10) - self.assertEqual(doc.embedded_field.list_field[2].string_field, "world") + assert doc.embedded_field.list_field[2].string_field == "world" # Test multiple assignments doc.embedded_field.list_field[2].string_field = "hello world" doc.embedded_field.list_field[2] = doc.embedded_field.list_field[2] - self.assertEqual(doc._get_changed_fields(), ["embedded_field.list_field.2"]) - self.assertEqual( - doc.embedded_field._delta(), - ( - { - "list_field.2": { - "_cls": "Embedded", - "string_field": "hello world", - "int_field": 1, - "list_field": ["1", 2, {"hello": "world"}], - "dict_field": {"hello": "world"}, - } - }, - {}, - ), + assert doc._get_changed_fields() == ["embedded_field.list_field.2"] + assert doc.embedded_field._delta() == ( + { + "list_field.2": { + "_cls": "Embedded", + "string_field": "hello world", + "int_field": 1, + "list_field": ["1", 2, {"hello": "world"}], + "dict_field": {"hello": "world"}, + } + }, + {}, ) - self.assertEqual( - doc._delta(), - ( - { - "embedded_field.list_field.2": { - "_cls": "Embedded", - "string_field": "hello world", - "int_field": 1, - "list_field": ["1", 2, {"hello": "world"}], - "dict_field": {"hello": "world"}, - } - }, - {}, - ), + assert doc._delta() == ( + { + "embedded_field.list_field.2": { + "_cls": "Embedded", + "string_field": "hello world", + "int_field": 1, + "list_field": ["1", 2, {"hello": "world"}], + "dict_field": {"hello": "world"}, + } + }, + {}, ) doc.save() doc = doc.reload(10) - self.assertEqual(doc.embedded_field.list_field[2].string_field, "hello world") + assert doc.embedded_field.list_field[2].string_field == "hello world" # Test list native methods doc.embedded_field.list_field[2].list_field.pop(0) - self.assertEqual( - doc._delta(), - ({"embedded_field.list_field.2.list_field": [2, {"hello": "world"}]}, {}), + assert doc._delta() == ( + {"embedded_field.list_field.2.list_field": [2, {"hello": "world"}]}, + {}, ) doc.save() doc = doc.reload(10) doc.embedded_field.list_field[2].list_field.append(1) - self.assertEqual( - doc._delta(), - ( - {"embedded_field.list_field.2.list_field": [2, {"hello": "world"}, 1]}, - {}, - ), + assert doc._delta() == ( + {"embedded_field.list_field.2.list_field": [2, {"hello": "world"}, 1]}, + {}, ) doc.save() doc = doc.reload(10) - self.assertEqual( - doc.embedded_field.list_field[2].list_field, [2, {"hello": "world"}, 1] - ) + assert doc.embedded_field.list_field[2].list_field == [2, {"hello": "world"}, 1] doc.embedded_field.list_field[2].list_field.sort(key=str) doc.save() doc = doc.reload(10) - self.assertEqual( - doc.embedded_field.list_field[2].list_field, [1, 2, {"hello": "world"}] - ) + assert doc.embedded_field.list_field[2].list_field == [1, 2, {"hello": "world"}] del doc.embedded_field.list_field[2].list_field[2]["hello"] - self.assertEqual( - doc._delta(), ({}, {"embedded_field.list_field.2.list_field.2.hello": 1}) + assert doc._delta() == ( + {}, + {"embedded_field.list_field.2.list_field.2.hello": 1}, ) doc.save() doc = doc.reload(10) del doc.embedded_field.list_field[2].list_field - self.assertEqual( - doc._delta(), ({}, {"embedded_field.list_field.2.list_field": 1}) - ) + assert doc._delta() == ({}, {"embedded_field.list_field.2.list_field": 1}) doc.save() doc = doc.reload(10) @@ -302,12 +282,8 @@ class TestDelta(MongoDBTestCase): doc = doc.reload(10) doc.dict_field["Embedded"].string_field = "Hello World" - self.assertEqual( - doc._get_changed_fields(), ["dict_field.Embedded.string_field"] - ) - self.assertEqual( - doc._delta(), ({"dict_field.Embedded.string_field": "Hello World"}, {}) - ) + assert doc._get_changed_fields() == ["dict_field.Embedded.string_field"] + assert doc._delta() == ({"dict_field.Embedded.string_field": "Hello World"}, {}) def test_circular_reference_deltas(self): self.circular_reference_deltas(Document, Document) @@ -338,8 +314,8 @@ class TestDelta(MongoDBTestCase): p = Person.objects[0].select_related() o = Organization.objects.first() - self.assertEqual(p.owns[0], o) - self.assertEqual(o.owner, p) + assert p.owns[0] == o + assert o.owner == p def test_circular_reference_deltas_2(self): self.circular_reference_deltas_2(Document, Document) @@ -379,9 +355,9 @@ class TestDelta(MongoDBTestCase): e = Person.objects.get(name="employee") o = Organization.objects.first() - self.assertEqual(p.owns[0], o) - self.assertEqual(o.owner, p) - self.assertEqual(e.employer, o) + assert p.owns[0] == o + assert o.owner == p + assert e.employer == o return person, organization, employee @@ -401,40 +377,40 @@ class TestDelta(MongoDBTestCase): doc.save() doc = Doc.objects.first() - self.assertEqual(doc._get_changed_fields(), []) - self.assertEqual(doc._delta(), ({}, {})) + assert doc._get_changed_fields() == [] + assert doc._delta() == ({}, {}) doc.string_field = "hello" - self.assertEqual(doc._get_changed_fields(), ["db_string_field"]) - self.assertEqual(doc._delta(), ({"db_string_field": "hello"}, {})) + assert doc._get_changed_fields() == ["db_string_field"] + assert doc._delta() == ({"db_string_field": "hello"}, {}) doc._changed_fields = [] doc.int_field = 1 - self.assertEqual(doc._get_changed_fields(), ["db_int_field"]) - self.assertEqual(doc._delta(), ({"db_int_field": 1}, {})) + assert doc._get_changed_fields() == ["db_int_field"] + assert doc._delta() == ({"db_int_field": 1}, {}) doc._changed_fields = [] dict_value = {"hello": "world", "ping": "pong"} doc.dict_field = dict_value - self.assertEqual(doc._get_changed_fields(), ["db_dict_field"]) - self.assertEqual(doc._delta(), ({"db_dict_field": dict_value}, {})) + assert doc._get_changed_fields() == ["db_dict_field"] + assert doc._delta() == ({"db_dict_field": dict_value}, {}) doc._changed_fields = [] list_value = ["1", 2, {"hello": "world"}] doc.list_field = list_value - self.assertEqual(doc._get_changed_fields(), ["db_list_field"]) - self.assertEqual(doc._delta(), ({"db_list_field": list_value}, {})) + assert doc._get_changed_fields() == ["db_list_field"] + assert doc._delta() == ({"db_list_field": list_value}, {}) # Test unsetting doc._changed_fields = [] doc.dict_field = {} - self.assertEqual(doc._get_changed_fields(), ["db_dict_field"]) - self.assertEqual(doc._delta(), ({}, {"db_dict_field": 1})) + assert doc._get_changed_fields() == ["db_dict_field"] + assert doc._delta() == ({}, {"db_dict_field": 1}) doc._changed_fields = [] doc.list_field = [] - self.assertEqual(doc._get_changed_fields(), ["db_list_field"]) - self.assertEqual(doc._delta(), ({}, {"db_list_field": 1})) + assert doc._get_changed_fields() == ["db_list_field"] + assert doc._delta() == ({}, {"db_list_field": 1}) # Test it saves that data doc = Doc() @@ -447,10 +423,10 @@ class TestDelta(MongoDBTestCase): doc.save() doc = doc.reload(10) - self.assertEqual(doc.string_field, "hello") - self.assertEqual(doc.int_field, 1) - self.assertEqual(doc.dict_field, {"hello": "world"}) - self.assertEqual(doc.list_field, ["1", 2, {"hello": "world"}]) + assert doc.string_field == "hello" + assert doc.int_field == 1 + assert doc.dict_field == {"hello": "world"} + assert doc.list_field == ["1", 2, {"hello": "world"}] def test_delta_recursive_db_field(self): self.delta_recursive_db_field(Document, EmbeddedDocument) @@ -479,8 +455,8 @@ class TestDelta(MongoDBTestCase): doc.save() doc = Doc.objects.first() - self.assertEqual(doc._get_changed_fields(), []) - self.assertEqual(doc._delta(), ({}, {})) + assert doc._get_changed_fields() == [] + assert doc._delta() == ({}, {}) embedded_1 = Embedded() embedded_1.string_field = "hello" @@ -489,7 +465,7 @@ class TestDelta(MongoDBTestCase): embedded_1.list_field = ["1", 2, {"hello": "world"}] doc.embedded_field = embedded_1 - self.assertEqual(doc._get_changed_fields(), ["db_embedded_field"]) + assert doc._get_changed_fields() == ["db_embedded_field"] embedded_delta = { "db_string_field": "hello", @@ -497,27 +473,27 @@ class TestDelta(MongoDBTestCase): "db_dict_field": {"hello": "world"}, "db_list_field": ["1", 2, {"hello": "world"}], } - self.assertEqual(doc.embedded_field._delta(), (embedded_delta, {})) - self.assertEqual(doc._delta(), ({"db_embedded_field": embedded_delta}, {})) + assert doc.embedded_field._delta() == (embedded_delta, {}) + assert doc._delta() == ({"db_embedded_field": embedded_delta}, {}) doc.save() doc = doc.reload(10) doc.embedded_field.dict_field = {} - self.assertEqual(doc._get_changed_fields(), ["db_embedded_field.db_dict_field"]) - self.assertEqual(doc.embedded_field._delta(), ({}, {"db_dict_field": 1})) - self.assertEqual(doc._delta(), ({}, {"db_embedded_field.db_dict_field": 1})) + assert doc._get_changed_fields() == ["db_embedded_field.db_dict_field"] + assert doc.embedded_field._delta() == ({}, {"db_dict_field": 1}) + assert doc._delta() == ({}, {"db_embedded_field.db_dict_field": 1}) doc.save() doc = doc.reload(10) - self.assertEqual(doc.embedded_field.dict_field, {}) + assert doc.embedded_field.dict_field == {} doc.embedded_field.list_field = [] - self.assertEqual(doc._get_changed_fields(), ["db_embedded_field.db_list_field"]) - self.assertEqual(doc.embedded_field._delta(), ({}, {"db_list_field": 1})) - self.assertEqual(doc._delta(), ({}, {"db_embedded_field.db_list_field": 1})) + assert doc._get_changed_fields() == ["db_embedded_field.db_list_field"] + assert doc.embedded_field._delta() == ({}, {"db_list_field": 1}) + assert doc._delta() == ({}, {"db_embedded_field.db_list_field": 1}) doc.save() doc = doc.reload(10) - self.assertEqual(doc.embedded_field.list_field, []) + assert doc.embedded_field.list_field == [] embedded_2 = Embedded() embedded_2.string_field = "hello" @@ -526,166 +502,142 @@ class TestDelta(MongoDBTestCase): embedded_2.list_field = ["1", 2, {"hello": "world"}] doc.embedded_field.list_field = ["1", 2, embedded_2] - self.assertEqual(doc._get_changed_fields(), ["db_embedded_field.db_list_field"]) - self.assertEqual( - doc.embedded_field._delta(), - ( - { - "db_list_field": [ - "1", - 2, - { - "_cls": "Embedded", - "db_string_field": "hello", - "db_dict_field": {"hello": "world"}, - "db_int_field": 1, - "db_list_field": ["1", 2, {"hello": "world"}], - }, - ] - }, - {}, - ), + assert doc._get_changed_fields() == ["db_embedded_field.db_list_field"] + assert doc.embedded_field._delta() == ( + { + "db_list_field": [ + "1", + 2, + { + "_cls": "Embedded", + "db_string_field": "hello", + "db_dict_field": {"hello": "world"}, + "db_int_field": 1, + "db_list_field": ["1", 2, {"hello": "world"}], + }, + ] + }, + {}, ) - self.assertEqual( - doc._delta(), - ( - { - "db_embedded_field.db_list_field": [ - "1", - 2, - { - "_cls": "Embedded", - "db_string_field": "hello", - "db_dict_field": {"hello": "world"}, - "db_int_field": 1, - "db_list_field": ["1", 2, {"hello": "world"}], - }, - ] - }, - {}, - ), + assert doc._delta() == ( + { + "db_embedded_field.db_list_field": [ + "1", + 2, + { + "_cls": "Embedded", + "db_string_field": "hello", + "db_dict_field": {"hello": "world"}, + "db_int_field": 1, + "db_list_field": ["1", 2, {"hello": "world"}], + }, + ] + }, + {}, ) doc.save() doc = doc.reload(10) - self.assertEqual(doc.embedded_field.list_field[0], "1") - self.assertEqual(doc.embedded_field.list_field[1], 2) + assert doc.embedded_field.list_field[0] == "1" + assert doc.embedded_field.list_field[1] == 2 for k in doc.embedded_field.list_field[2]._fields: - self.assertEqual(doc.embedded_field.list_field[2][k], embedded_2[k]) + assert doc.embedded_field.list_field[2][k] == embedded_2[k] doc.embedded_field.list_field[2].string_field = "world" - self.assertEqual( - doc._get_changed_fields(), - ["db_embedded_field.db_list_field.2.db_string_field"], + assert doc._get_changed_fields() == [ + "db_embedded_field.db_list_field.2.db_string_field" + ] + assert doc.embedded_field._delta() == ( + {"db_list_field.2.db_string_field": "world"}, + {}, ) - self.assertEqual( - doc.embedded_field._delta(), - ({"db_list_field.2.db_string_field": "world"}, {}), - ) - self.assertEqual( - doc._delta(), - ({"db_embedded_field.db_list_field.2.db_string_field": "world"}, {}), + assert doc._delta() == ( + {"db_embedded_field.db_list_field.2.db_string_field": "world"}, + {}, ) doc.save() doc = doc.reload(10) - self.assertEqual(doc.embedded_field.list_field[2].string_field, "world") + assert doc.embedded_field.list_field[2].string_field == "world" # Test multiple assignments doc.embedded_field.list_field[2].string_field = "hello world" doc.embedded_field.list_field[2] = doc.embedded_field.list_field[2] - self.assertEqual( - doc._get_changed_fields(), ["db_embedded_field.db_list_field.2"] + assert doc._get_changed_fields() == ["db_embedded_field.db_list_field.2"] + assert doc.embedded_field._delta() == ( + { + "db_list_field.2": { + "_cls": "Embedded", + "db_string_field": "hello world", + "db_int_field": 1, + "db_list_field": ["1", 2, {"hello": "world"}], + "db_dict_field": {"hello": "world"}, + } + }, + {}, ) - self.assertEqual( - doc.embedded_field._delta(), - ( - { - "db_list_field.2": { - "_cls": "Embedded", - "db_string_field": "hello world", - "db_int_field": 1, - "db_list_field": ["1", 2, {"hello": "world"}], - "db_dict_field": {"hello": "world"}, - } - }, - {}, - ), - ) - self.assertEqual( - doc._delta(), - ( - { - "db_embedded_field.db_list_field.2": { - "_cls": "Embedded", - "db_string_field": "hello world", - "db_int_field": 1, - "db_list_field": ["1", 2, {"hello": "world"}], - "db_dict_field": {"hello": "world"}, - } - }, - {}, - ), + assert doc._delta() == ( + { + "db_embedded_field.db_list_field.2": { + "_cls": "Embedded", + "db_string_field": "hello world", + "db_int_field": 1, + "db_list_field": ["1", 2, {"hello": "world"}], + "db_dict_field": {"hello": "world"}, + } + }, + {}, ) doc.save() doc = doc.reload(10) - self.assertEqual(doc.embedded_field.list_field[2].string_field, "hello world") + assert doc.embedded_field.list_field[2].string_field == "hello world" # Test list native methods doc.embedded_field.list_field[2].list_field.pop(0) - self.assertEqual( - doc._delta(), - ( - { - "db_embedded_field.db_list_field.2.db_list_field": [ - 2, - {"hello": "world"}, - ] - }, - {}, - ), + assert doc._delta() == ( + { + "db_embedded_field.db_list_field.2.db_list_field": [ + 2, + {"hello": "world"}, + ] + }, + {}, ) doc.save() doc = doc.reload(10) doc.embedded_field.list_field[2].list_field.append(1) - self.assertEqual( - doc._delta(), - ( - { - "db_embedded_field.db_list_field.2.db_list_field": [ - 2, - {"hello": "world"}, - 1, - ] - }, - {}, - ), + assert doc._delta() == ( + { + "db_embedded_field.db_list_field.2.db_list_field": [ + 2, + {"hello": "world"}, + 1, + ] + }, + {}, ) doc.save() doc = doc.reload(10) - self.assertEqual( - doc.embedded_field.list_field[2].list_field, [2, {"hello": "world"}, 1] - ) + assert doc.embedded_field.list_field[2].list_field == [2, {"hello": "world"}, 1] doc.embedded_field.list_field[2].list_field.sort(key=str) doc.save() doc = doc.reload(10) - self.assertEqual( - doc.embedded_field.list_field[2].list_field, [1, 2, {"hello": "world"}] - ) + assert doc.embedded_field.list_field[2].list_field == [1, 2, {"hello": "world"}] del doc.embedded_field.list_field[2].list_field[2]["hello"] - self.assertEqual( - doc._delta(), - ({}, {"db_embedded_field.db_list_field.2.db_list_field.2.hello": 1}), + assert doc._delta() == ( + {}, + {"db_embedded_field.db_list_field.2.db_list_field.2.hello": 1}, ) doc.save() doc = doc.reload(10) del doc.embedded_field.list_field[2].list_field - self.assertEqual( - doc._delta(), ({}, {"db_embedded_field.db_list_field.2.db_list_field": 1}) + assert doc._delta() == ( + {}, + {"db_embedded_field.db_list_field.2.db_list_field": 1}, ) def test_delta_for_dynamic_documents(self): @@ -696,14 +648,16 @@ class TestDelta(MongoDBTestCase): Person.drop_collection() p = Person(name="James", age=34) - self.assertEqual( - p._delta(), (SON([("_cls", "Person"), ("name", "James"), ("age", 34)]), {}) + assert p._delta() == ( + SON([("_cls", "Person"), ("name", "James"), ("age", 34)]), + {}, ) p.doc = 123 del p.doc - self.assertEqual( - p._delta(), (SON([("_cls", "Person"), ("name", "James"), ("age", 34)]), {}) + assert p._delta() == ( + SON([("_cls", "Person"), ("name", "James"), ("age", 34)]), + {}, ) p = Person() @@ -712,18 +666,18 @@ class TestDelta(MongoDBTestCase): p.save() p.age = 24 - self.assertEqual(p.age, 24) - self.assertEqual(p._get_changed_fields(), ["age"]) - self.assertEqual(p._delta(), ({"age": 24}, {})) + assert p.age == 24 + assert p._get_changed_fields() == ["age"] + assert p._delta() == ({"age": 24}, {}) p = Person.objects(age=22).get() p.age = 24 - self.assertEqual(p.age, 24) - self.assertEqual(p._get_changed_fields(), ["age"]) - self.assertEqual(p._delta(), ({"age": 24}, {})) + assert p.age == 24 + assert p._get_changed_fields() == ["age"] + assert p._delta() == ({"age": 24}, {}) p.save() - self.assertEqual(1, Person.objects(age=24).count()) + assert 1 == Person.objects(age=24).count() def test_dynamic_delta(self): class Doc(DynamicDocument): @@ -734,40 +688,40 @@ class TestDelta(MongoDBTestCase): doc.save() doc = Doc.objects.first() - self.assertEqual(doc._get_changed_fields(), []) - self.assertEqual(doc._delta(), ({}, {})) + assert doc._get_changed_fields() == [] + assert doc._delta() == ({}, {}) doc.string_field = "hello" - self.assertEqual(doc._get_changed_fields(), ["string_field"]) - self.assertEqual(doc._delta(), ({"string_field": "hello"}, {})) + assert doc._get_changed_fields() == ["string_field"] + assert doc._delta() == ({"string_field": "hello"}, {}) doc._changed_fields = [] doc.int_field = 1 - self.assertEqual(doc._get_changed_fields(), ["int_field"]) - self.assertEqual(doc._delta(), ({"int_field": 1}, {})) + assert doc._get_changed_fields() == ["int_field"] + assert doc._delta() == ({"int_field": 1}, {}) doc._changed_fields = [] dict_value = {"hello": "world", "ping": "pong"} doc.dict_field = dict_value - self.assertEqual(doc._get_changed_fields(), ["dict_field"]) - self.assertEqual(doc._delta(), ({"dict_field": dict_value}, {})) + assert doc._get_changed_fields() == ["dict_field"] + assert doc._delta() == ({"dict_field": dict_value}, {}) doc._changed_fields = [] list_value = ["1", 2, {"hello": "world"}] doc.list_field = list_value - self.assertEqual(doc._get_changed_fields(), ["list_field"]) - self.assertEqual(doc._delta(), ({"list_field": list_value}, {})) + assert doc._get_changed_fields() == ["list_field"] + assert doc._delta() == ({"list_field": list_value}, {}) # Test unsetting doc._changed_fields = [] doc.dict_field = {} - self.assertEqual(doc._get_changed_fields(), ["dict_field"]) - self.assertEqual(doc._delta(), ({}, {"dict_field": 1})) + assert doc._get_changed_fields() == ["dict_field"] + assert doc._delta() == ({}, {"dict_field": 1}) doc._changed_fields = [] doc.list_field = [] - self.assertEqual(doc._get_changed_fields(), ["list_field"]) - self.assertEqual(doc._delta(), ({}, {"list_field": 1})) + assert doc._get_changed_fields() == ["list_field"] + assert doc._delta() == ({}, {"list_field": 1}) def test_delta_with_dbref_true(self): person, organization, employee = self.circular_reference_deltas_2( @@ -775,16 +729,16 @@ class TestDelta(MongoDBTestCase): ) employee.name = "test" - self.assertEqual(organization._get_changed_fields(), []) + assert organization._get_changed_fields() == [] updates, removals = organization._delta() - self.assertEqual({}, removals) - self.assertEqual({}, updates) + assert {} == removals + assert {} == updates organization.employees.append(person) updates, removals = organization._delta() - self.assertEqual({}, removals) - self.assertIn("employees", updates) + assert {} == removals + assert "employees" in updates def test_delta_with_dbref_false(self): person, organization, employee = self.circular_reference_deltas_2( @@ -792,16 +746,16 @@ class TestDelta(MongoDBTestCase): ) employee.name = "test" - self.assertEqual(organization._get_changed_fields(), []) + assert organization._get_changed_fields() == [] updates, removals = organization._delta() - self.assertEqual({}, removals) - self.assertEqual({}, updates) + assert {} == removals + assert {} == updates organization.employees.append(person) updates, removals = organization._delta() - self.assertEqual({}, removals) - self.assertIn("employees", updates) + assert {} == removals + assert "employees" in updates def test_nested_nested_fields_mark_as_changed(self): class EmbeddedDoc(EmbeddedDocument): @@ -821,11 +775,11 @@ class TestDelta(MongoDBTestCase): subdoc = mydoc.subs["a"]["b"] subdoc.name = "bar" - self.assertEqual(["name"], subdoc._get_changed_fields()) - self.assertEqual(["subs.a.b.name"], mydoc._get_changed_fields()) + assert ["name"] == subdoc._get_changed_fields() + assert ["subs.a.b.name"] == mydoc._get_changed_fields() mydoc._clear_changed_fields() - self.assertEqual([], mydoc._get_changed_fields()) + assert [] == mydoc._get_changed_fields() def test_lower_level_mark_as_changed(self): class EmbeddedDoc(EmbeddedDocument): @@ -840,17 +794,17 @@ class TestDelta(MongoDBTestCase): mydoc = MyDoc.objects.first() mydoc.subs["a"] = EmbeddedDoc() - self.assertEqual(["subs.a"], mydoc._get_changed_fields()) + assert ["subs.a"] == mydoc._get_changed_fields() subdoc = mydoc.subs["a"] subdoc.name = "bar" - self.assertEqual(["name"], subdoc._get_changed_fields()) - self.assertEqual(["subs.a"], mydoc._get_changed_fields()) + assert ["name"] == subdoc._get_changed_fields() + assert ["subs.a"] == mydoc._get_changed_fields() mydoc.save() mydoc._clear_changed_fields() - self.assertEqual([], mydoc._get_changed_fields()) + assert [] == mydoc._get_changed_fields() def test_upper_level_mark_as_changed(self): class EmbeddedDoc(EmbeddedDocument): @@ -867,15 +821,15 @@ class TestDelta(MongoDBTestCase): subdoc = mydoc.subs["a"] subdoc.name = "bar" - self.assertEqual(["name"], subdoc._get_changed_fields()) - self.assertEqual(["subs.a.name"], mydoc._get_changed_fields()) + assert ["name"] == subdoc._get_changed_fields() + assert ["subs.a.name"] == mydoc._get_changed_fields() mydoc.subs["a"] = EmbeddedDoc() - self.assertEqual(["subs.a"], mydoc._get_changed_fields()) + assert ["subs.a"] == mydoc._get_changed_fields() mydoc.save() mydoc._clear_changed_fields() - self.assertEqual([], mydoc._get_changed_fields()) + assert [] == mydoc._get_changed_fields() def test_referenced_object_changed_attributes(self): """Ensures that when you save a new reference to a field, the referenced object isn't altered""" @@ -902,22 +856,22 @@ class TestDelta(MongoDBTestCase): org1.reload() org2.reload() user.reload() - self.assertEqual(org1.name, "Org 1") - self.assertEqual(org2.name, "Org 2") - self.assertEqual(user.name, "Fred") + assert org1.name == "Org 1" + assert org2.name == "Org 2" + assert user.name == "Fred" user.name = "Harold" user.org = org2 org2.name = "New Org 2" - self.assertEqual(org2.name, "New Org 2") + assert org2.name == "New Org 2" user.save() org2.save() - self.assertEqual(org2.name, "New Org 2") + assert org2.name == "New Org 2" org2.reload() - self.assertEqual(org2.name, "New Org 2") + assert org2.name == "New Org 2" def test_delta_for_nested_map_fields(self): class UInfoDocument(Document): @@ -950,12 +904,12 @@ class TestDelta(MongoDBTestCase): d.users["007"]["rolist"].append(EmbeddedRole(type="oops")) d.users["007"]["info"] = uinfo delta = d._delta() - self.assertEqual(True, "users.007.roles.666" in delta[0]) - self.assertEqual(True, "users.007.rolist" in delta[0]) - self.assertEqual(True, "users.007.info" in delta[0]) - self.assertEqual("superadmin", delta[0]["users.007.roles.666"]["type"]) - self.assertEqual("oops", delta[0]["users.007.rolist"][0]["type"]) - self.assertEqual(uinfo.id, delta[0]["users.007.info"]) + assert True == ("users.007.roles.666" in delta[0]) + assert True == ("users.007.rolist" in delta[0]) + assert True == ("users.007.info" in delta[0]) + assert "superadmin" == delta[0]["users.007.roles.666"]["type"] + assert "oops" == delta[0]["users.007.rolist"][0]["type"] + assert uinfo.id == delta[0]["users.007.info"] if __name__ == "__main__": diff --git a/tests/document/test_dynamic.py b/tests/document/test_dynamic.py index 6b517d24..a6f46862 100644 --- a/tests/document/test_dynamic.py +++ b/tests/document/test_dynamic.py @@ -2,6 +2,7 @@ import unittest from mongoengine import * from tests.utils import MongoDBTestCase +import pytest __all__ = ("TestDynamicDocument",) @@ -25,15 +26,15 @@ class TestDynamicDocument(MongoDBTestCase): p.name = "James" p.age = 34 - self.assertEqual(p.to_mongo(), {"_cls": "Person", "name": "James", "age": 34}) - self.assertEqual(p.to_mongo().keys(), ["_cls", "name", "age"]) + assert p.to_mongo() == {"_cls": "Person", "name": "James", "age": 34} + assert p.to_mongo().keys() == ["_cls", "name", "age"] p.save() - self.assertEqual(p.to_mongo().keys(), ["_id", "_cls", "name", "age"]) + assert p.to_mongo().keys() == ["_id", "_cls", "name", "age"] - self.assertEqual(self.Person.objects.first().age, 34) + assert self.Person.objects.first().age == 34 # Confirm no changes to self.Person - self.assertFalse(hasattr(self.Person, "age")) + assert not hasattr(self.Person, "age") def test_change_scope_of_variable(self): """Test changing the scope of a dynamic field has no adverse effects""" @@ -47,7 +48,7 @@ class TestDynamicDocument(MongoDBTestCase): p.save() p = self.Person.objects.get() - self.assertEqual(p.misc, {"hello": "world"}) + assert p.misc == {"hello": "world"} def test_delete_dynamic_field(self): """Test deleting a dynamic field works""" @@ -62,19 +63,19 @@ class TestDynamicDocument(MongoDBTestCase): p.save() p = self.Person.objects.get() - self.assertEqual(p.misc, {"hello": "world"}) + assert p.misc == {"hello": "world"} collection = self.db[self.Person._get_collection_name()] obj = collection.find_one() - self.assertEqual(sorted(obj.keys()), ["_cls", "_id", "misc", "name"]) + assert sorted(obj.keys()) == ["_cls", "_id", "misc", "name"] del p.misc p.save() p = self.Person.objects.get() - self.assertFalse(hasattr(p, "misc")) + assert not hasattr(p, "misc") obj = collection.find_one() - self.assertEqual(sorted(obj.keys()), ["_cls", "_id", "name"]) + assert sorted(obj.keys()) == ["_cls", "_id", "name"] def test_reload_after_unsetting(self): p = self.Person() @@ -88,12 +89,12 @@ class TestDynamicDocument(MongoDBTestCase): p = self.Person.objects.create() p.update(age=1) - self.assertEqual(len(p._data), 3) - self.assertEqual(sorted(p._data.keys()), ["_cls", "id", "name"]) + assert len(p._data) == 3 + assert sorted(p._data.keys()) == ["_cls", "id", "name"] p.reload() - self.assertEqual(len(p._data), 4) - self.assertEqual(sorted(p._data.keys()), ["_cls", "age", "id", "name"]) + assert len(p._data) == 4 + assert sorted(p._data.keys()) == ["_cls", "age", "id", "name"] def test_fields_without_underscore(self): """Ensure we can query dynamic fields""" @@ -103,16 +104,18 @@ class TestDynamicDocument(MongoDBTestCase): p.save() raw_p = Person.objects.as_pymongo().get(id=p.id) - self.assertEqual(raw_p, {"_cls": u"Person", "_id": p.id, "name": u"Dean"}) + assert raw_p == {"_cls": u"Person", "_id": p.id, "name": u"Dean"} p.name = "OldDean" p.newattr = "garbage" p.save() raw_p = Person.objects.as_pymongo().get(id=p.id) - self.assertEqual( - raw_p, - {"_cls": u"Person", "_id": p.id, "name": "OldDean", "newattr": u"garbage"}, - ) + assert raw_p == { + "_cls": u"Person", + "_id": p.id, + "name": "OldDean", + "newattr": u"garbage", + } def test_fields_containing_underscore(self): """Ensure we can query dynamic fields""" @@ -127,14 +130,14 @@ class TestDynamicDocument(MongoDBTestCase): p.save() raw_p = WeirdPerson.objects.as_pymongo().get(id=p.id) - self.assertEqual(raw_p, {"_id": p.id, "_name": u"Dean", "name": u"Dean"}) + assert raw_p == {"_id": p.id, "_name": u"Dean", "name": u"Dean"} p.name = "OldDean" p._name = "NewDean" p._newattr1 = "garbage" # Unknown fields won't be added p.save() raw_p = WeirdPerson.objects.as_pymongo().get(id=p.id) - self.assertEqual(raw_p, {"_id": p.id, "_name": u"NewDean", "name": u"OldDean"}) + assert raw_p == {"_id": p.id, "_name": u"NewDean", "name": u"OldDean"} def test_dynamic_document_queries(self): """Ensure we can query dynamic fields""" @@ -143,10 +146,10 @@ class TestDynamicDocument(MongoDBTestCase): p.age = 22 p.save() - self.assertEqual(1, self.Person.objects(age=22).count()) + assert 1 == self.Person.objects(age=22).count() p = self.Person.objects(age=22) p = p.get() - self.assertEqual(22, p.age) + assert 22 == p.age def test_complex_dynamic_document_queries(self): class Person(DynamicDocument): @@ -166,8 +169,8 @@ class TestDynamicDocument(MongoDBTestCase): p2.age = 10 p2.save() - self.assertEqual(Person.objects(age__icontains="ten").count(), 2) - self.assertEqual(Person.objects(age__gte=10).count(), 1) + assert Person.objects(age__icontains="ten").count() == 2 + assert Person.objects(age__gte=10).count() == 1 def test_complex_data_lookups(self): """Ensure you can query dynamic document dynamic fields""" @@ -175,12 +178,12 @@ class TestDynamicDocument(MongoDBTestCase): p.misc = {"hello": "world"} p.save() - self.assertEqual(1, self.Person.objects(misc__hello="world").count()) + assert 1 == self.Person.objects(misc__hello="world").count() def test_three_level_complex_data_lookups(self): """Ensure you can query three level document dynamic fields""" self.Person.objects.create(misc={"hello": {"hello2": "world"}}) - self.assertEqual(1, self.Person.objects(misc__hello__hello2="world").count()) + assert 1 == self.Person.objects(misc__hello__hello2="world").count() def test_complex_embedded_document_validation(self): """Ensure embedded dynamic documents may be validated""" @@ -198,11 +201,13 @@ class TestDynamicDocument(MongoDBTestCase): embedded_doc_1.validate() embedded_doc_2 = Embedded(content="this is not a url") - self.assertRaises(ValidationError, embedded_doc_2.validate) + with pytest.raises(ValidationError): + embedded_doc_2.validate() doc.embedded_field_1 = embedded_doc_1 doc.embedded_field_2 = embedded_doc_2 - self.assertRaises(ValidationError, doc.validate) + with pytest.raises(ValidationError): + doc.validate() def test_inheritance(self): """Ensure that dynamic document plays nice with inheritance""" @@ -212,11 +217,9 @@ class TestDynamicDocument(MongoDBTestCase): Employee.drop_collection() - self.assertIn("name", Employee._fields) - self.assertIn("salary", Employee._fields) - self.assertEqual( - Employee._get_collection_name(), self.Person._get_collection_name() - ) + assert "name" in Employee._fields + assert "salary" in Employee._fields + assert Employee._get_collection_name() == self.Person._get_collection_name() joe_bloggs = Employee() joe_bloggs.name = "Joe Bloggs" @@ -224,11 +227,11 @@ class TestDynamicDocument(MongoDBTestCase): joe_bloggs.age = 20 joe_bloggs.save() - self.assertEqual(1, self.Person.objects(age=20).count()) - self.assertEqual(1, Employee.objects(age=20).count()) + assert 1 == self.Person.objects(age=20).count() + assert 1 == Employee.objects(age=20).count() joe_bloggs = self.Person.objects.first() - self.assertIsInstance(joe_bloggs, Employee) + assert isinstance(joe_bloggs, Employee) def test_embedded_dynamic_document(self): """Test dynamic embedded documents""" @@ -249,26 +252,23 @@ class TestDynamicDocument(MongoDBTestCase): embedded_1.list_field = ["1", 2, {"hello": "world"}] doc.embedded_field = embedded_1 - self.assertEqual( - doc.to_mongo(), - { - "embedded_field": { - "_cls": "Embedded", - "string_field": "hello", - "int_field": 1, - "dict_field": {"hello": "world"}, - "list_field": ["1", 2, {"hello": "world"}], - } - }, - ) + assert doc.to_mongo() == { + "embedded_field": { + "_cls": "Embedded", + "string_field": "hello", + "int_field": 1, + "dict_field": {"hello": "world"}, + "list_field": ["1", 2, {"hello": "world"}], + } + } doc.save() doc = Doc.objects.first() - self.assertEqual(doc.embedded_field.__class__, Embedded) - self.assertEqual(doc.embedded_field.string_field, "hello") - self.assertEqual(doc.embedded_field.int_field, 1) - self.assertEqual(doc.embedded_field.dict_field, {"hello": "world"}) - self.assertEqual(doc.embedded_field.list_field, ["1", 2, {"hello": "world"}]) + assert doc.embedded_field.__class__ == Embedded + assert doc.embedded_field.string_field == "hello" + assert doc.embedded_field.int_field == 1 + assert doc.embedded_field.dict_field == {"hello": "world"} + assert doc.embedded_field.list_field == ["1", 2, {"hello": "world"}] def test_complex_embedded_documents(self): """Test complex dynamic embedded documents setups""" @@ -296,44 +296,41 @@ class TestDynamicDocument(MongoDBTestCase): embedded_1.list_field = ["1", 2, embedded_2] doc.embedded_field = embedded_1 - self.assertEqual( - doc.to_mongo(), - { - "embedded_field": { - "_cls": "Embedded", - "string_field": "hello", - "int_field": 1, - "dict_field": {"hello": "world"}, - "list_field": [ - "1", - 2, - { - "_cls": "Embedded", - "string_field": "hello", - "int_field": 1, - "dict_field": {"hello": "world"}, - "list_field": ["1", 2, {"hello": "world"}], - }, - ], - } - }, - ) + assert doc.to_mongo() == { + "embedded_field": { + "_cls": "Embedded", + "string_field": "hello", + "int_field": 1, + "dict_field": {"hello": "world"}, + "list_field": [ + "1", + 2, + { + "_cls": "Embedded", + "string_field": "hello", + "int_field": 1, + "dict_field": {"hello": "world"}, + "list_field": ["1", 2, {"hello": "world"}], + }, + ], + } + } doc.save() doc = Doc.objects.first() - self.assertEqual(doc.embedded_field.__class__, Embedded) - self.assertEqual(doc.embedded_field.string_field, "hello") - self.assertEqual(doc.embedded_field.int_field, 1) - self.assertEqual(doc.embedded_field.dict_field, {"hello": "world"}) - self.assertEqual(doc.embedded_field.list_field[0], "1") - self.assertEqual(doc.embedded_field.list_field[1], 2) + assert doc.embedded_field.__class__ == Embedded + assert doc.embedded_field.string_field == "hello" + assert doc.embedded_field.int_field == 1 + assert doc.embedded_field.dict_field == {"hello": "world"} + assert doc.embedded_field.list_field[0] == "1" + assert doc.embedded_field.list_field[1] == 2 embedded_field = doc.embedded_field.list_field[2] - self.assertEqual(embedded_field.__class__, Embedded) - self.assertEqual(embedded_field.string_field, "hello") - self.assertEqual(embedded_field.int_field, 1) - self.assertEqual(embedded_field.dict_field, {"hello": "world"}) - self.assertEqual(embedded_field.list_field, ["1", 2, {"hello": "world"}]) + assert embedded_field.__class__ == Embedded + assert embedded_field.string_field == "hello" + assert embedded_field.int_field == 1 + assert embedded_field.dict_field == {"hello": "world"} + assert embedded_field.list_field == ["1", 2, {"hello": "world"}] def test_dynamic_and_embedded(self): """Ensure embedded documents play nicely""" @@ -352,18 +349,18 @@ class TestDynamicDocument(MongoDBTestCase): person.address.city = "Lundenne" person.save() - self.assertEqual(Person.objects.first().address.city, "Lundenne") + assert Person.objects.first().address.city == "Lundenne" person = Person.objects.first() person.address = Address(city="Londinium") person.save() - self.assertEqual(Person.objects.first().address.city, "Londinium") + assert Person.objects.first().address.city == "Londinium" person = Person.objects.first() person.age = 35 person.save() - self.assertEqual(Person.objects.first().age, 35) + assert Person.objects.first().age == 35 def test_dynamic_embedded_works_with_only(self): """Ensure custom fieldnames on a dynamic embedded document are found by qs.only()""" @@ -380,10 +377,10 @@ class TestDynamicDocument(MongoDBTestCase): name="Eric", address=Address(city="San Francisco", street_number="1337") ).save() - self.assertEqual(Person.objects.first().address.street_number, "1337") - self.assertEqual( - Person.objects.only("address__street_number").first().address.street_number, - "1337", + assert Person.objects.first().address.street_number == "1337" + assert ( + Person.objects.only("address__street_number").first().address.street_number + == "1337" ) def test_dynamic_and_embedded_dict_access(self): @@ -408,20 +405,20 @@ class TestDynamicDocument(MongoDBTestCase): person["address"]["city"] = "Lundenne" person.save() - self.assertEqual(Person.objects.first().address.city, "Lundenne") + assert Person.objects.first().address.city == "Lundenne" - self.assertEqual(Person.objects.first().phone, "555-1212") + assert Person.objects.first().phone == "555-1212" person = Person.objects.first() person.address = Address(city="Londinium") person.save() - self.assertEqual(Person.objects.first().address.city, "Londinium") + assert Person.objects.first().address.city == "Londinium" person = Person.objects.first() person["age"] = 35 person.save() - self.assertEqual(Person.objects.first().age, 35) + assert Person.objects.first().age == 35 if __name__ == "__main__": diff --git a/tests/document/test_indexes.py b/tests/document/test_indexes.py index 1b0304c4..cc1aae52 100644 --- a/tests/document/test_indexes.py +++ b/tests/document/test_indexes.py @@ -9,6 +9,7 @@ from six import iteritems from mongoengine import * from mongoengine.connection import get_db +import pytest class TestIndexes(unittest.TestCase): @@ -53,15 +54,15 @@ class TestIndexes(unittest.TestCase): {"fields": [("tags", 1)]}, {"fields": [("category", 1), ("addDate", -1)]}, ] - self.assertEqual(expected_specs, BlogPost._meta["index_specs"]) + assert expected_specs == BlogPost._meta["index_specs"] BlogPost.ensure_indexes() info = BlogPost.objects._collection.index_information() # _id, '-date', 'tags', ('cat', 'date') - self.assertEqual(len(info), 4) + assert len(info) == 4 info = [value["key"] for key, value in iteritems(info)] for expected in expected_specs: - self.assertIn(expected["fields"], info) + assert expected["fields"] in info def _index_test_inheritance(self, InheritFrom): class BlogPost(InheritFrom): @@ -78,7 +79,7 @@ class TestIndexes(unittest.TestCase): {"fields": [("_cls", 1), ("tags", 1)]}, {"fields": [("_cls", 1), ("category", 1), ("addDate", -1)]}, ] - self.assertEqual(expected_specs, BlogPost._meta["index_specs"]) + assert expected_specs == BlogPost._meta["index_specs"] BlogPost.ensure_indexes() info = BlogPost.objects._collection.index_information() @@ -86,17 +87,17 @@ class TestIndexes(unittest.TestCase): # NB: there is no index on _cls by itself, since # the indices on -date and tags will both contain # _cls as first element in the key - self.assertEqual(len(info), 4) + assert len(info) == 4 info = [value["key"] for key, value in iteritems(info)] for expected in expected_specs: - self.assertIn(expected["fields"], info) + assert expected["fields"] in info class ExtendedBlogPost(BlogPost): title = StringField() meta = {"indexes": ["title"]} expected_specs.append({"fields": [("_cls", 1), ("title", 1)]}) - self.assertEqual(expected_specs, ExtendedBlogPost._meta["index_specs"]) + assert expected_specs == ExtendedBlogPost._meta["index_specs"] BlogPost.drop_collection() @@ -104,7 +105,7 @@ class TestIndexes(unittest.TestCase): info = ExtendedBlogPost.objects._collection.index_information() info = [value["key"] for key, value in iteritems(info)] for expected in expected_specs: - self.assertIn(expected["fields"], info) + assert expected["fields"] in info def test_indexes_document_inheritance(self): """Ensure that indexes are used when meta[indexes] is specified for @@ -128,10 +129,8 @@ class TestIndexes(unittest.TestCase): class B(A): description = StringField() - self.assertEqual(A._meta["index_specs"], B._meta["index_specs"]) - self.assertEqual( - [{"fields": [("_cls", 1), ("title", 1)]}], A._meta["index_specs"] - ) + assert A._meta["index_specs"] == B._meta["index_specs"] + assert [{"fields": [("_cls", 1), ("title", 1)]}] == A._meta["index_specs"] def test_index_no_cls(self): """Ensure index specs are inhertited correctly""" @@ -144,11 +143,11 @@ class TestIndexes(unittest.TestCase): "index_cls": False, } - self.assertEqual([("title", 1)], A._meta["index_specs"][0]["fields"]) + assert [("title", 1)] == A._meta["index_specs"][0]["fields"] A._get_collection().drop_indexes() A.ensure_indexes() info = A._get_collection().index_information() - self.assertEqual(len(info.keys()), 2) + assert len(info.keys()) == 2 class B(A): c = StringField() @@ -158,8 +157,8 @@ class TestIndexes(unittest.TestCase): "allow_inheritance": True, } - self.assertEqual([("c", 1)], B._meta["index_specs"][1]["fields"]) - self.assertEqual([("_cls", 1), ("d", 1)], B._meta["index_specs"][2]["fields"]) + assert [("c", 1)] == B._meta["index_specs"][1]["fields"] + assert [("_cls", 1), ("d", 1)] == B._meta["index_specs"][2]["fields"] def test_build_index_spec_is_not_destructive(self): class MyDoc(Document): @@ -167,12 +166,12 @@ class TestIndexes(unittest.TestCase): meta = {"indexes": ["keywords"], "allow_inheritance": False} - self.assertEqual(MyDoc._meta["index_specs"], [{"fields": [("keywords", 1)]}]) + assert MyDoc._meta["index_specs"] == [{"fields": [("keywords", 1)]}] # Force index creation MyDoc.ensure_indexes() - self.assertEqual(MyDoc._meta["index_specs"], [{"fields": [("keywords", 1)]}]) + assert MyDoc._meta["index_specs"] == [{"fields": [("keywords", 1)]}] def test_embedded_document_index_meta(self): """Ensure that embedded document indexes are created explicitly @@ -187,7 +186,7 @@ class TestIndexes(unittest.TestCase): meta = {"indexes": ["rank.title"], "allow_inheritance": False} - self.assertEqual([{"fields": [("rank.title", 1)]}], Person._meta["index_specs"]) + assert [{"fields": [("rank.title", 1)]}] == Person._meta["index_specs"] Person.drop_collection() @@ -195,7 +194,7 @@ class TestIndexes(unittest.TestCase): list(Person.objects) info = Person.objects._collection.index_information() info = [value["key"] for key, value in iteritems(info)] - self.assertIn([("rank.title", 1)], info) + assert [("rank.title", 1)] in info def test_explicit_geo2d_index(self): """Ensure that geo2d indexes work when created via meta[indexes] @@ -205,14 +204,12 @@ class TestIndexes(unittest.TestCase): location = DictField() meta = {"allow_inheritance": True, "indexes": ["*location.point"]} - self.assertEqual( - [{"fields": [("location.point", "2d")]}], Place._meta["index_specs"] - ) + assert [{"fields": [("location.point", "2d")]}] == Place._meta["index_specs"] Place.ensure_indexes() info = Place._get_collection().index_information() info = [value["key"] for key, value in iteritems(info)] - self.assertIn([("location.point", "2d")], info) + assert [("location.point", "2d")] in info def test_explicit_geo2d_index_embedded(self): """Ensure that geo2d indexes work when created via meta[indexes] @@ -225,14 +222,14 @@ class TestIndexes(unittest.TestCase): current = DictField(field=EmbeddedDocumentField("EmbeddedLocation")) meta = {"allow_inheritance": True, "indexes": ["*current.location.point"]} - self.assertEqual( - [{"fields": [("current.location.point", "2d")]}], Place._meta["index_specs"] - ) + assert [{"fields": [("current.location.point", "2d")]}] == Place._meta[ + "index_specs" + ] Place.ensure_indexes() info = Place._get_collection().index_information() info = [value["key"] for key, value in iteritems(info)] - self.assertIn([("current.location.point", "2d")], info) + assert [("current.location.point", "2d")] in info def test_explicit_geosphere_index(self): """Ensure that geosphere indexes work when created via meta[indexes] @@ -242,14 +239,14 @@ class TestIndexes(unittest.TestCase): location = DictField() meta = {"allow_inheritance": True, "indexes": ["(location.point"]} - self.assertEqual( - [{"fields": [("location.point", "2dsphere")]}], Place._meta["index_specs"] - ) + assert [{"fields": [("location.point", "2dsphere")]}] == Place._meta[ + "index_specs" + ] Place.ensure_indexes() info = Place._get_collection().index_information() info = [value["key"] for key, value in iteritems(info)] - self.assertIn([("location.point", "2dsphere")], info) + assert [("location.point", "2dsphere")] in info def test_explicit_geohaystack_index(self): """Ensure that geohaystack indexes work when created via meta[indexes] @@ -264,15 +261,14 @@ class TestIndexes(unittest.TestCase): name = StringField() meta = {"indexes": [(")location.point", "name")]} - self.assertEqual( - [{"fields": [("location.point", "geoHaystack"), ("name", 1)]}], - Place._meta["index_specs"], - ) + assert [ + {"fields": [("location.point", "geoHaystack"), ("name", 1)]} + ] == Place._meta["index_specs"] Place.ensure_indexes() info = Place._get_collection().index_information() info = [value["key"] for key, value in iteritems(info)] - self.assertIn([("location.point", "geoHaystack")], info) + assert [("location.point", "geoHaystack")] in info def test_create_geohaystack_index(self): """Ensure that geohaystack indexes can be created @@ -285,7 +281,7 @@ class TestIndexes(unittest.TestCase): Place.create_index({"fields": (")location.point", "name")}, bucketSize=10) info = Place._get_collection().index_information() info = [value["key"] for key, value in iteritems(info)] - self.assertIn([("location.point", "geoHaystack"), ("name", 1)], info) + assert [("location.point", "geoHaystack"), ("name", 1)] in info def test_dictionary_indexes(self): """Ensure that indexes are used when meta[indexes] contains @@ -298,16 +294,15 @@ class TestIndexes(unittest.TestCase): tags = ListField(StringField()) meta = {"indexes": [{"fields": ["-date"], "unique": True, "sparse": True}]} - self.assertEqual( - [{"fields": [("addDate", -1)], "unique": True, "sparse": True}], - BlogPost._meta["index_specs"], - ) + assert [ + {"fields": [("addDate", -1)], "unique": True, "sparse": True} + ] == BlogPost._meta["index_specs"] BlogPost.drop_collection() info = BlogPost.objects._collection.index_information() # _id, '-date' - self.assertEqual(len(info), 2) + assert len(info) == 2 # Indexes are lazy so use list() to perform query list(BlogPost.objects) @@ -316,7 +311,7 @@ class TestIndexes(unittest.TestCase): (value["key"], value.get("unique", False), value.get("sparse", False)) for key, value in iteritems(info) ] - self.assertIn(([("addDate", -1)], True, True), info) + assert ([("addDate", -1)], True, True) in info BlogPost.drop_collection() @@ -338,11 +333,9 @@ class TestIndexes(unittest.TestCase): Person(name="test", user_guid="123").save() - self.assertEqual(1, Person.objects.count()) + assert 1 == Person.objects.count() info = Person.objects._collection.index_information() - self.assertEqual( - sorted(info.keys()), ["_cls_1_name_1", "_cls_1_user_guid_1", "_id_"] - ) + assert sorted(info.keys()) == ["_cls_1_name_1", "_cls_1_user_guid_1", "_id_"] def test_disable_index_creation(self): """Tests setting auto_create_index to False on the connection will @@ -365,13 +358,13 @@ class TestIndexes(unittest.TestCase): User(user_guid="123").save() MongoUser(user_guid="123").save() - self.assertEqual(2, User.objects.count()) + assert 2 == User.objects.count() info = User.objects._collection.index_information() - self.assertEqual(list(info.keys()), ["_id_"]) + assert list(info.keys()) == ["_id_"] User.ensure_indexes() info = User.objects._collection.index_information() - self.assertEqual(sorted(info.keys()), ["_cls_1_user_guid_1", "_id_"]) + assert sorted(info.keys()) == ["_cls_1_user_guid_1", "_id_"] def test_embedded_document_index(self): """Tests settings an index on an embedded document @@ -389,7 +382,7 @@ class TestIndexes(unittest.TestCase): BlogPost.drop_collection() info = BlogPost.objects._collection.index_information() - self.assertEqual(sorted(info.keys()), ["_id_", "date.yr_-1"]) + assert sorted(info.keys()) == ["_id_", "date.yr_-1"] def test_list_embedded_document_index(self): """Ensure list embedded documents can be indexed @@ -408,7 +401,7 @@ class TestIndexes(unittest.TestCase): info = BlogPost.objects._collection.index_information() # we don't use _cls in with list fields by default - self.assertEqual(sorted(info.keys()), ["_id_", "tags.tag_1"]) + assert sorted(info.keys()) == ["_id_", "tags.tag_1"] post1 = BlogPost( title="Embedded Indexes tests in place", @@ -426,7 +419,7 @@ class TestIndexes(unittest.TestCase): RecursiveDocument.ensure_indexes() info = RecursiveDocument._get_collection().index_information() - self.assertEqual(sorted(info.keys()), ["_cls_1", "_id_"]) + assert sorted(info.keys()) == ["_cls_1", "_id_"] def test_covered_index(self): """Ensure that covered indexes can be used @@ -446,46 +439,45 @@ class TestIndexes(unittest.TestCase): # Need to be explicit about covered indexes as mongoDB doesn't know if # the documents returned might have more keys in that here. query_plan = Test.objects(id=obj.id).exclude("a").explain() - self.assertEqual( + assert ( query_plan.get("queryPlanner") .get("winningPlan") .get("inputStage") - .get("stage"), - "IDHACK", + .get("stage") + == "IDHACK" ) query_plan = Test.objects(id=obj.id).only("id").explain() - self.assertEqual( + assert ( query_plan.get("queryPlanner") .get("winningPlan") .get("inputStage") - .get("stage"), - "IDHACK", + .get("stage") + == "IDHACK" ) query_plan = Test.objects(a=1).only("a").exclude("id").explain() - self.assertEqual( + assert ( query_plan.get("queryPlanner") .get("winningPlan") .get("inputStage") - .get("stage"), - "IXSCAN", + .get("stage") + == "IXSCAN" ) - self.assertEqual( - query_plan.get("queryPlanner").get("winningPlan").get("stage"), "PROJECTION" + assert ( + query_plan.get("queryPlanner").get("winningPlan").get("stage") + == "PROJECTION" ) query_plan = Test.objects(a=1).explain() - self.assertEqual( + assert ( query_plan.get("queryPlanner") .get("winningPlan") .get("inputStage") - .get("stage"), - "IXSCAN", - ) - self.assertEqual( - query_plan.get("queryPlanner").get("winningPlan").get("stage"), "FETCH" + .get("stage") + == "IXSCAN" ) + assert query_plan.get("queryPlanner").get("winningPlan").get("stage") == "FETCH" def test_index_on_id(self): class BlogPost(Document): @@ -498,9 +490,7 @@ class TestIndexes(unittest.TestCase): BlogPost.drop_collection() indexes = BlogPost.objects._collection.index_information() - self.assertEqual( - indexes["categories_1__id_1"]["key"], [("categories", 1), ("_id", 1)] - ) + assert indexes["categories_1__id_1"]["key"] == [("categories", 1), ("_id", 1)] def test_hint(self): TAGS_INDEX_NAME = "tags_1" @@ -516,25 +506,25 @@ class TestIndexes(unittest.TestCase): BlogPost(tags=tags).save() # Hinting by shape should work. - self.assertEqual(BlogPost.objects.hint([("tags", 1)]).count(), 10) + assert BlogPost.objects.hint([("tags", 1)]).count() == 10 # Hinting by index name should work. - self.assertEqual(BlogPost.objects.hint(TAGS_INDEX_NAME).count(), 10) + assert BlogPost.objects.hint(TAGS_INDEX_NAME).count() == 10 # Clearing the hint should work fine. - self.assertEqual(BlogPost.objects.hint().count(), 10) - self.assertEqual(BlogPost.objects.hint([("ZZ", 1)]).hint().count(), 10) + assert BlogPost.objects.hint().count() == 10 + assert BlogPost.objects.hint([("ZZ", 1)]).hint().count() == 10 # Hinting on a non-existent index shape should fail. - with self.assertRaises(OperationFailure): + with pytest.raises(OperationFailure): BlogPost.objects.hint([("ZZ", 1)]).count() # Hinting on a non-existent index name should fail. - with self.assertRaises(OperationFailure): + with pytest.raises(OperationFailure): BlogPost.objects.hint("Bad Name").count() # Invalid shape argument (missing list brackets) should fail. - with self.assertRaises(ValueError): + with pytest.raises(ValueError): BlogPost.objects.hint(("tags", 1)).count() def test_collation(self): @@ -588,11 +578,14 @@ class TestIndexes(unittest.TestCase): # Two posts with the same slug is not allowed post2 = BlogPost(title="test2", slug="test") - self.assertRaises(NotUniqueError, post2.save) - self.assertRaises(NotUniqueError, BlogPost.objects.insert, post2) + with pytest.raises(NotUniqueError): + post2.save() + with pytest.raises(NotUniqueError): + BlogPost.objects.insert(post2) # Ensure backwards compatibility for errors - self.assertRaises(OperationError, post2.save) + with pytest.raises(OperationError): + post2.save() def test_primary_key_unique_not_working(self): """Relates to #1445""" @@ -602,23 +595,21 @@ class TestIndexes(unittest.TestCase): Blog.drop_collection() - with self.assertRaises(OperationFailure) as ctx_err: + with pytest.raises(OperationFailure) as ctx_err: Blog(id="garbage").save() # One of the errors below should happen. Which one depends on the # PyMongo version and dict order. err_msg = str(ctx_err.exception) - self.assertTrue( - any( - [ - "The field 'unique' is not valid for an _id index specification" - in err_msg, - "The field 'background' is not valid for an _id index specification" - in err_msg, - "The field 'sparse' is not valid for an _id index specification" - in err_msg, - ] - ) + assert any( + [ + "The field 'unique' is not valid for an _id index specification" + in err_msg, + "The field 'background' is not valid for an _id index specification" + in err_msg, + "The field 'sparse' is not valid for an _id index specification" + in err_msg, + ] ) def test_unique_with(self): @@ -644,7 +635,8 @@ class TestIndexes(unittest.TestCase): # Now there will be two docs with the same slug and the same day: fail post3 = BlogPost(title="test3", date=Date(year=2010), slug="test") - self.assertRaises(OperationError, post3.save) + with pytest.raises(OperationError): + post3.save() def test_unique_embedded_document(self): """Ensure that uniqueness constraints are applied to fields on embedded documents. @@ -669,7 +661,8 @@ class TestIndexes(unittest.TestCase): # Now there will be two docs with the same sub.slug post3 = BlogPost(title="test3", sub=SubDocument(year=2010, slug="test")) - self.assertRaises(NotUniqueError, post3.save) + with pytest.raises(NotUniqueError): + post3.save() def test_unique_embedded_document_in_list(self): """ @@ -699,7 +692,8 @@ class TestIndexes(unittest.TestCase): post2 = BlogPost(title="test2", subs=[SubDocument(year=2014, slug="conflict")]) - self.assertRaises(NotUniqueError, post2.save) + with pytest.raises(NotUniqueError): + post2.save() def test_unique_embedded_document_in_sorted_list(self): """ @@ -729,12 +723,13 @@ class TestIndexes(unittest.TestCase): # confirm that the unique index is created indexes = BlogPost._get_collection().index_information() - self.assertIn("subs.slug_1", indexes) - self.assertTrue(indexes["subs.slug_1"]["unique"]) + assert "subs.slug_1" in indexes + assert indexes["subs.slug_1"]["unique"] post2 = BlogPost(title="test2", subs=[SubDocument(year=2014, slug="conflict")]) - self.assertRaises(NotUniqueError, post2.save) + with pytest.raises(NotUniqueError): + post2.save() def test_unique_embedded_document_in_embedded_document_list(self): """ @@ -764,12 +759,13 @@ class TestIndexes(unittest.TestCase): # confirm that the unique index is created indexes = BlogPost._get_collection().index_information() - self.assertIn("subs.slug_1", indexes) - self.assertTrue(indexes["subs.slug_1"]["unique"]) + assert "subs.slug_1" in indexes + assert indexes["subs.slug_1"]["unique"] post2 = BlogPost(title="test2", subs=[SubDocument(year=2014, slug="conflict")]) - self.assertRaises(NotUniqueError, post2.save) + with pytest.raises(NotUniqueError): + post2.save() def test_unique_with_embedded_document_and_embedded_unique(self): """Ensure that uniqueness constraints are applied to fields on @@ -795,11 +791,13 @@ class TestIndexes(unittest.TestCase): # Now there will be two docs with the same sub.slug post3 = BlogPost(title="test3", sub=SubDocument(year=2010, slug="test")) - self.assertRaises(NotUniqueError, post3.save) + with pytest.raises(NotUniqueError): + post3.save() # Now there will be two docs with the same title and year post3 = BlogPost(title="test1", sub=SubDocument(year=2009, slug="test-1")) - self.assertRaises(NotUniqueError, post3.save) + with pytest.raises(NotUniqueError): + post3.save() def test_ttl_indexes(self): class Log(Document): @@ -811,7 +809,7 @@ class TestIndexes(unittest.TestCase): # Indexes are lazy so use list() to perform query list(Log.objects) info = Log.objects._collection.index_information() - self.assertEqual(3600, info["created_1"]["expireAfterSeconds"]) + assert 3600 == info["created_1"]["expireAfterSeconds"] def test_index_drop_dups_silently_ignored(self): class Customer(Document): @@ -839,14 +837,14 @@ class TestIndexes(unittest.TestCase): cust.save() cust_dupe = Customer(cust_id=1) - with self.assertRaises(NotUniqueError): + with pytest.raises(NotUniqueError): cust_dupe.save() cust = Customer(cust_id=2) cust.save() # duplicate key on update - with self.assertRaises(NotUniqueError): + with pytest.raises(NotUniqueError): cust.cust_id = 1 cust.save() @@ -867,8 +865,8 @@ class TestIndexes(unittest.TestCase): user = User(name="huangz", password="secret2") user.save() - self.assertEqual(User.objects.count(), 1) - self.assertEqual(User.objects.get().password, "secret2") + assert User.objects.count() == 1 + assert User.objects.get().password == "secret2" def test_unique_and_primary_create(self): """Create a new record with a duplicate primary key @@ -882,11 +880,11 @@ class TestIndexes(unittest.TestCase): User.drop_collection() User.objects.create(name="huangz", password="secret") - with self.assertRaises(NotUniqueError): + with pytest.raises(NotUniqueError): User.objects.create(name="huangz", password="secret2") - self.assertEqual(User.objects.count(), 1) - self.assertEqual(User.objects.get().password, "secret") + assert User.objects.count() == 1 + assert User.objects.get().password == "secret" def test_index_with_pk(self): """Ensure you can use `pk` as part of a query""" @@ -910,7 +908,7 @@ class TestIndexes(unittest.TestCase): info = BlogPost.objects._collection.index_information() info = [value["key"] for key, value in iteritems(info)] index_item = [("_id", 1), ("comments.comment_id", 1)] - self.assertIn(index_item, info) + assert index_item in info def test_compound_key_embedded(self): class CompoundKey(EmbeddedDocument): @@ -924,10 +922,8 @@ class TestIndexes(unittest.TestCase): my_key = CompoundKey(name="n", term="ok") report = ReportEmbedded(text="OK", key=my_key).save() - self.assertEqual( - {"text": "OK", "_id": {"term": "ok", "name": "n"}}, report.to_mongo() - ) - self.assertEqual(report, ReportEmbedded.objects.get(pk=my_key)) + assert {"text": "OK", "_id": {"term": "ok", "name": "n"}} == report.to_mongo() + assert report == ReportEmbedded.objects.get(pk=my_key) def test_compound_key_dictfield(self): class ReportDictField(Document): @@ -937,15 +933,13 @@ class TestIndexes(unittest.TestCase): my_key = {"name": "n", "term": "ok"} report = ReportDictField(text="OK", key=my_key).save() - self.assertEqual( - {"text": "OK", "_id": {"term": "ok", "name": "n"}}, report.to_mongo() - ) + assert {"text": "OK", "_id": {"term": "ok", "name": "n"}} == report.to_mongo() # We can't directly call ReportDictField.objects.get(pk=my_key), # because dicts are unordered, and if the order in MongoDB is # different than the one in `my_key`, this test will fail. - self.assertEqual(report, ReportDictField.objects.get(pk__name=my_key["name"])) - self.assertEqual(report, ReportDictField.objects.get(pk__term=my_key["term"])) + assert report == ReportDictField.objects.get(pk__name=my_key["name"]) + assert report == ReportDictField.objects.get(pk__term=my_key["term"]) def test_string_indexes(self): class MyDoc(Document): @@ -954,8 +948,8 @@ class TestIndexes(unittest.TestCase): info = MyDoc.objects._collection.index_information() info = [value["key"] for key, value in iteritems(info)] - self.assertIn([("provider_ids.foo", 1)], info) - self.assertIn([("provider_ids.bar", 1)], info) + assert [("provider_ids.foo", 1)] in info + assert [("provider_ids.bar", 1)] in info def test_sparse_compound_indexes(self): class MyDoc(Document): @@ -967,11 +961,10 @@ class TestIndexes(unittest.TestCase): } info = MyDoc.objects._collection.index_information() - self.assertEqual( - [("provider_ids.foo", 1), ("provider_ids.bar", 1)], - info["provider_ids.foo_1_provider_ids.bar_1"]["key"], - ) - self.assertTrue(info["provider_ids.foo_1_provider_ids.bar_1"]["sparse"]) + assert [("provider_ids.foo", 1), ("provider_ids.bar", 1)] == info[ + "provider_ids.foo_1_provider_ids.bar_1" + ]["key"] + assert info["provider_ids.foo_1_provider_ids.bar_1"]["sparse"] def test_text_indexes(self): class Book(Document): @@ -979,9 +972,9 @@ class TestIndexes(unittest.TestCase): meta = {"indexes": ["$title"]} indexes = Book.objects._collection.index_information() - self.assertIn("title_text", indexes) + assert "title_text" in indexes key = indexes["title_text"]["key"] - self.assertIn(("_fts", "text"), key) + assert ("_fts", "text") in key def test_hashed_indexes(self): class Book(Document): @@ -989,8 +982,8 @@ class TestIndexes(unittest.TestCase): meta = {"indexes": ["#ref_id"]} indexes = Book.objects._collection.index_information() - self.assertIn("ref_id_hashed", indexes) - self.assertIn(("ref_id", "hashed"), indexes["ref_id_hashed"]["key"]) + assert "ref_id_hashed" in indexes + assert ("ref_id", "hashed") in indexes["ref_id_hashed"]["key"] def test_indexes_after_database_drop(self): """ @@ -1027,7 +1020,8 @@ class TestIndexes(unittest.TestCase): # Create Post #2 post2 = BlogPost(title="test2", slug="test") - self.assertRaises(NotUniqueError, post2.save) + with pytest.raises(NotUniqueError): + post2.save() finally: # Drop the temporary database at the end connection.drop_database("tempdatabase") @@ -1074,15 +1068,12 @@ class TestIndexes(unittest.TestCase): "dropDups" ] # drop the index dropDups - it is deprecated in MongoDB 3+ - self.assertEqual( - index_info, - { - "txt_1": {"key": [("txt", 1)], "background": False}, - "_id_": {"key": [("_id", 1)]}, - "txt2_1": {"key": [("txt2", 1)], "background": False}, - "_cls_1": {"key": [("_cls", 1)], "background": False}, - }, - ) + assert index_info == { + "txt_1": {"key": [("txt", 1)], "background": False}, + "_id_": {"key": [("_id", 1)]}, + "txt2_1": {"key": [("txt2", 1)], "background": False}, + "_cls_1": {"key": [("_cls", 1)], "background": False}, + } def test_compound_index_underscore_cls_not_overwritten(self): """ @@ -1105,7 +1096,7 @@ class TestIndexes(unittest.TestCase): TestDoc.ensure_indexes() index_info = TestDoc._get_collection().index_information() - self.assertIn("shard_1_1__cls_1_txt_1_1", index_info) + assert "shard_1_1__cls_1_txt_1_1" in index_info if __name__ == "__main__": diff --git a/tests/document/test_inheritance.py b/tests/document/test_inheritance.py index 4bb46e58..6a913b3e 100644 --- a/tests/document/test_inheritance.py +++ b/tests/document/test_inheritance.py @@ -17,6 +17,7 @@ from mongoengine import ( from mongoengine.pymongo_support import list_collection_names from tests.fixtures import Base from tests.utils import MongoDBTestCase +import pytest class TestInheritance(MongoDBTestCase): @@ -37,12 +38,12 @@ class TestInheritance(MongoDBTestCase): meta = {"allow_inheritance": True} test_doc = DataDoc(name="test", embed=EmbedData(data="data")) - self.assertEqual(test_doc._cls, "DataDoc") - self.assertEqual(test_doc.embed._cls, "EmbedData") + assert test_doc._cls == "DataDoc" + assert test_doc.embed._cls == "EmbedData" test_doc.save() saved_doc = DataDoc.objects.with_id(test_doc.id) - self.assertEqual(test_doc._cls, saved_doc._cls) - self.assertEqual(test_doc.embed._cls, saved_doc.embed._cls) + assert test_doc._cls == saved_doc._cls + assert test_doc.embed._cls == saved_doc.embed._cls test_doc.delete() def test_superclasses(self): @@ -67,12 +68,12 @@ class TestInheritance(MongoDBTestCase): class Human(Mammal): pass - self.assertEqual(Animal._superclasses, ()) - self.assertEqual(Fish._superclasses, ("Animal",)) - self.assertEqual(Guppy._superclasses, ("Animal", "Animal.Fish")) - self.assertEqual(Mammal._superclasses, ("Animal",)) - self.assertEqual(Dog._superclasses, ("Animal", "Animal.Mammal")) - self.assertEqual(Human._superclasses, ("Animal", "Animal.Mammal")) + assert Animal._superclasses == () + assert Fish._superclasses == ("Animal",) + assert Guppy._superclasses == ("Animal", "Animal.Fish") + assert Mammal._superclasses == ("Animal",) + assert Dog._superclasses == ("Animal", "Animal.Mammal") + assert Human._superclasses == ("Animal", "Animal.Mammal") def test_external_superclasses(self): """Ensure that the correct list of super classes is assembled when @@ -97,18 +98,12 @@ class TestInheritance(MongoDBTestCase): class Human(Mammal): pass - self.assertEqual(Animal._superclasses, ("Base",)) - self.assertEqual(Fish._superclasses, ("Base", "Base.Animal")) - self.assertEqual( - Guppy._superclasses, ("Base", "Base.Animal", "Base.Animal.Fish") - ) - self.assertEqual(Mammal._superclasses, ("Base", "Base.Animal")) - self.assertEqual( - Dog._superclasses, ("Base", "Base.Animal", "Base.Animal.Mammal") - ) - self.assertEqual( - Human._superclasses, ("Base", "Base.Animal", "Base.Animal.Mammal") - ) + assert Animal._superclasses == ("Base",) + assert Fish._superclasses == ("Base", "Base.Animal") + assert Guppy._superclasses == ("Base", "Base.Animal", "Base.Animal.Fish") + assert Mammal._superclasses == ("Base", "Base.Animal") + assert Dog._superclasses == ("Base", "Base.Animal", "Base.Animal.Mammal") + assert Human._superclasses == ("Base", "Base.Animal", "Base.Animal.Mammal") def test_subclasses(self): """Ensure that the correct list of _subclasses (subclasses) is @@ -133,24 +128,22 @@ class TestInheritance(MongoDBTestCase): class Human(Mammal): pass - self.assertEqual( - Animal._subclasses, - ( - "Animal", - "Animal.Fish", - "Animal.Fish.Guppy", - "Animal.Mammal", - "Animal.Mammal.Dog", - "Animal.Mammal.Human", - ), + assert Animal._subclasses == ( + "Animal", + "Animal.Fish", + "Animal.Fish.Guppy", + "Animal.Mammal", + "Animal.Mammal.Dog", + "Animal.Mammal.Human", ) - self.assertEqual(Fish._subclasses, ("Animal.Fish", "Animal.Fish.Guppy")) - self.assertEqual(Guppy._subclasses, ("Animal.Fish.Guppy",)) - self.assertEqual( - Mammal._subclasses, - ("Animal.Mammal", "Animal.Mammal.Dog", "Animal.Mammal.Human"), + assert Fish._subclasses == ("Animal.Fish", "Animal.Fish.Guppy") + assert Guppy._subclasses == ("Animal.Fish.Guppy",) + assert Mammal._subclasses == ( + "Animal.Mammal", + "Animal.Mammal.Dog", + "Animal.Mammal.Human", ) - self.assertEqual(Human._subclasses, ("Animal.Mammal.Human",)) + assert Human._subclasses == ("Animal.Mammal.Human",) def test_external_subclasses(self): """Ensure that the correct list of _subclasses (subclasses) is @@ -175,30 +168,22 @@ class TestInheritance(MongoDBTestCase): class Human(Mammal): pass - self.assertEqual( - Animal._subclasses, - ( - "Base.Animal", - "Base.Animal.Fish", - "Base.Animal.Fish.Guppy", - "Base.Animal.Mammal", - "Base.Animal.Mammal.Dog", - "Base.Animal.Mammal.Human", - ), + assert Animal._subclasses == ( + "Base.Animal", + "Base.Animal.Fish", + "Base.Animal.Fish.Guppy", + "Base.Animal.Mammal", + "Base.Animal.Mammal.Dog", + "Base.Animal.Mammal.Human", ) - self.assertEqual( - Fish._subclasses, ("Base.Animal.Fish", "Base.Animal.Fish.Guppy") + assert Fish._subclasses == ("Base.Animal.Fish", "Base.Animal.Fish.Guppy") + assert Guppy._subclasses == ("Base.Animal.Fish.Guppy",) + assert Mammal._subclasses == ( + "Base.Animal.Mammal", + "Base.Animal.Mammal.Dog", + "Base.Animal.Mammal.Human", ) - self.assertEqual(Guppy._subclasses, ("Base.Animal.Fish.Guppy",)) - self.assertEqual( - Mammal._subclasses, - ( - "Base.Animal.Mammal", - "Base.Animal.Mammal.Dog", - "Base.Animal.Mammal.Human", - ), - ) - self.assertEqual(Human._subclasses, ("Base.Animal.Mammal.Human",)) + assert Human._subclasses == ("Base.Animal.Mammal.Human",) def test_dynamic_declarations(self): """Test that declaring an extra class updates meta data""" @@ -206,33 +191,31 @@ class TestInheritance(MongoDBTestCase): class Animal(Document): meta = {"allow_inheritance": True} - self.assertEqual(Animal._superclasses, ()) - self.assertEqual(Animal._subclasses, ("Animal",)) + assert Animal._superclasses == () + assert Animal._subclasses == ("Animal",) # Test dynamically adding a class changes the meta data class Fish(Animal): pass - self.assertEqual(Animal._superclasses, ()) - self.assertEqual(Animal._subclasses, ("Animal", "Animal.Fish")) + assert Animal._superclasses == () + assert Animal._subclasses == ("Animal", "Animal.Fish") - self.assertEqual(Fish._superclasses, ("Animal",)) - self.assertEqual(Fish._subclasses, ("Animal.Fish",)) + assert Fish._superclasses == ("Animal",) + assert Fish._subclasses == ("Animal.Fish",) # Test dynamically adding an inherited class changes the meta data class Pike(Fish): pass - self.assertEqual(Animal._superclasses, ()) - self.assertEqual( - Animal._subclasses, ("Animal", "Animal.Fish", "Animal.Fish.Pike") - ) + assert Animal._superclasses == () + assert Animal._subclasses == ("Animal", "Animal.Fish", "Animal.Fish.Pike") - self.assertEqual(Fish._superclasses, ("Animal",)) - self.assertEqual(Fish._subclasses, ("Animal.Fish", "Animal.Fish.Pike")) + assert Fish._superclasses == ("Animal",) + assert Fish._subclasses == ("Animal.Fish", "Animal.Fish.Pike") - self.assertEqual(Pike._superclasses, ("Animal", "Animal.Fish")) - self.assertEqual(Pike._subclasses, ("Animal.Fish.Pike",)) + assert Pike._superclasses == ("Animal", "Animal.Fish") + assert Pike._subclasses == ("Animal.Fish.Pike",) def test_inheritance_meta_data(self): """Ensure that document may inherit fields from a superclass document. @@ -247,10 +230,10 @@ class TestInheritance(MongoDBTestCase): class Employee(Person): salary = IntField() - self.assertEqual( - ["_cls", "age", "id", "name", "salary"], sorted(Employee._fields.keys()) + assert ["_cls", "age", "id", "name", "salary"] == sorted( + Employee._fields.keys() ) - self.assertEqual(Employee._get_collection_name(), Person._get_collection_name()) + assert Employee._get_collection_name() == Person._get_collection_name() def test_inheritance_to_mongo_keys(self): """Ensure that document may inherit fields from a superclass document. @@ -265,17 +248,17 @@ class TestInheritance(MongoDBTestCase): class Employee(Person): salary = IntField() - self.assertEqual( - ["_cls", "age", "id", "name", "salary"], sorted(Employee._fields.keys()) + assert ["_cls", "age", "id", "name", "salary"] == sorted( + Employee._fields.keys() ) - self.assertEqual( - Person(name="Bob", age=35).to_mongo().keys(), ["_cls", "name", "age"] - ) - self.assertEqual( - Employee(name="Bob", age=35, salary=0).to_mongo().keys(), - ["_cls", "name", "age", "salary"], - ) - self.assertEqual(Employee._get_collection_name(), Person._get_collection_name()) + assert Person(name="Bob", age=35).to_mongo().keys() == ["_cls", "name", "age"] + assert Employee(name="Bob", age=35, salary=0).to_mongo().keys() == [ + "_cls", + "name", + "age", + "salary", + ] + assert Employee._get_collection_name() == Person._get_collection_name() def test_indexes_and_multiple_inheritance(self): """ Ensure that all of the indexes are created for a document with @@ -301,13 +284,10 @@ class TestInheritance(MongoDBTestCase): C.ensure_indexes() - self.assertEqual( - sorted( - [idx["key"] for idx in C._get_collection().index_information().values()] - ), - sorted( - [[(u"_cls", 1), (u"b", 1)], [(u"_id", 1)], [(u"_cls", 1), (u"a", 1)]] - ), + assert sorted( + [idx["key"] for idx in C._get_collection().index_information().values()] + ) == sorted( + [[(u"_cls", 1), (u"b", 1)], [(u"_id", 1)], [(u"_cls", 1), (u"a", 1)]] ) def test_polymorphic_queries(self): @@ -338,13 +318,13 @@ class TestInheritance(MongoDBTestCase): Human().save() classes = [obj.__class__ for obj in Animal.objects] - self.assertEqual(classes, [Animal, Fish, Mammal, Dog, Human]) + assert classes == [Animal, Fish, Mammal, Dog, Human] classes = [obj.__class__ for obj in Mammal.objects] - self.assertEqual(classes, [Mammal, Dog, Human]) + assert classes == [Mammal, Dog, Human] classes = [obj.__class__ for obj in Human.objects] - self.assertEqual(classes, [Human]) + assert classes == [Human] def test_allow_inheritance(self): """Ensure that inheritance is disabled by default on simple @@ -355,20 +335,20 @@ class TestInheritance(MongoDBTestCase): name = StringField() # can't inherit because Animal didn't explicitly allow inheritance - with self.assertRaises(ValueError) as cm: + with pytest.raises(ValueError) as cm: class Dog(Animal): pass - self.assertIn("Document Animal may not be subclassed", str(cm.exception)) + assert "Document Animal may not be subclassed" in str(cm.exception) # Check that _cls etc aren't present on simple documents dog = Animal(name="dog").save() - self.assertEqual(dog.to_mongo().keys(), ["_id", "name"]) + assert dog.to_mongo().keys() == ["_id", "name"] collection = self.db[Animal._get_collection_name()] obj = collection.find_one() - self.assertNotIn("_cls", obj) + assert "_cls" not in obj def test_cant_turn_off_inheritance_on_subclass(self): """Ensure if inheritance is on in a subclass you cant turn it off. @@ -378,14 +358,14 @@ class TestInheritance(MongoDBTestCase): name = StringField() meta = {"allow_inheritance": True} - with self.assertRaises(ValueError) as cm: + with pytest.raises(ValueError) as cm: class Mammal(Animal): meta = {"allow_inheritance": False} - self.assertEqual( - str(cm.exception), - 'Only direct subclasses of Document may set "allow_inheritance" to False', + assert ( + str(cm.exception) + == 'Only direct subclasses of Document may set "allow_inheritance" to False' ) def test_allow_inheritance_abstract_document(self): @@ -399,14 +379,14 @@ class TestInheritance(MongoDBTestCase): class Animal(FinalDocument): name = StringField() - with self.assertRaises(ValueError): + with pytest.raises(ValueError): class Mammal(Animal): pass # Check that _cls isn't present in simple documents doc = Animal(name="dog") - self.assertNotIn("_cls", doc.to_mongo()) + assert "_cls" not in doc.to_mongo() def test_using_abstract_class_in_reference_field(self): # Ensures no regression of #1920 @@ -452,10 +432,10 @@ class TestInheritance(MongoDBTestCase): name = StringField() berlin = EuropeanCity(name="Berlin", continent="Europe") - self.assertEqual(len(berlin._db_field_map), len(berlin._fields_ordered)) - self.assertEqual(len(berlin._reverse_db_field_map), len(berlin._fields_ordered)) - self.assertEqual(len(berlin._fields_ordered), 3) - self.assertEqual(berlin._fields_ordered[0], "id") + assert len(berlin._db_field_map) == len(berlin._fields_ordered) + assert len(berlin._reverse_db_field_map) == len(berlin._fields_ordered) + assert len(berlin._fields_ordered) == 3 + assert berlin._fields_ordered[0] == "id" def test_auto_id_not_set_if_specific_in_parent_class(self): class City(Document): @@ -467,10 +447,10 @@ class TestInheritance(MongoDBTestCase): name = StringField() berlin = EuropeanCity(name="Berlin", continent="Europe") - self.assertEqual(len(berlin._db_field_map), len(berlin._fields_ordered)) - self.assertEqual(len(berlin._reverse_db_field_map), len(berlin._fields_ordered)) - self.assertEqual(len(berlin._fields_ordered), 3) - self.assertEqual(berlin._fields_ordered[0], "city_id") + assert len(berlin._db_field_map) == len(berlin._fields_ordered) + assert len(berlin._reverse_db_field_map) == len(berlin._fields_ordered) + assert len(berlin._fields_ordered) == 3 + assert berlin._fields_ordered[0] == "city_id" def test_auto_id_vs_non_pk_id_field(self): class City(Document): @@ -482,12 +462,12 @@ class TestInheritance(MongoDBTestCase): name = StringField() berlin = EuropeanCity(name="Berlin", continent="Europe") - self.assertEqual(len(berlin._db_field_map), len(berlin._fields_ordered)) - self.assertEqual(len(berlin._reverse_db_field_map), len(berlin._fields_ordered)) - self.assertEqual(len(berlin._fields_ordered), 4) - self.assertEqual(berlin._fields_ordered[0], "auto_id_0") + assert len(berlin._db_field_map) == len(berlin._fields_ordered) + assert len(berlin._reverse_db_field_map) == len(berlin._fields_ordered) + assert len(berlin._fields_ordered) == 4 + assert berlin._fields_ordered[0] == "auto_id_0" berlin.save() - self.assertEqual(berlin.pk, berlin.auto_id_0) + assert berlin.pk == berlin.auto_id_0 def test_abstract_document_creation_does_not_fail(self): class City(Document): @@ -495,9 +475,9 @@ class TestInheritance(MongoDBTestCase): meta = {"abstract": True, "allow_inheritance": False} city = City(continent="asia") - self.assertEqual(None, city.pk) + assert None == city.pk # TODO: expected error? Shouldn't we create a new error type? - with self.assertRaises(KeyError): + with pytest.raises(KeyError): setattr(city, "pk", 1) def test_allow_inheritance_embedded_document(self): @@ -506,20 +486,20 @@ class TestInheritance(MongoDBTestCase): class Comment(EmbeddedDocument): content = StringField() - with self.assertRaises(ValueError): + with pytest.raises(ValueError): class SpecialComment(Comment): pass doc = Comment(content="test") - self.assertNotIn("_cls", doc.to_mongo()) + assert "_cls" not in doc.to_mongo() class Comment(EmbeddedDocument): content = StringField() meta = {"allow_inheritance": True} doc = Comment(content="test") - self.assertIn("_cls", doc.to_mongo()) + assert "_cls" in doc.to_mongo() def test_document_inheritance(self): """Ensure mutliple inheritance of abstract documents @@ -537,7 +517,7 @@ class TestInheritance(MongoDBTestCase): pass except Exception: - self.assertTrue(False, "Couldn't create MyDocument class") + assert False, "Couldn't create MyDocument class" def test_abstract_documents(self): """Ensure that a document superclass can be marked as abstract @@ -574,20 +554,20 @@ class TestInheritance(MongoDBTestCase): for k, v in iteritems(defaults): for cls in [Animal, Fish, Guppy]: - self.assertEqual(cls._meta[k], v) + assert cls._meta[k] == v - self.assertNotIn("collection", Animal._meta) - self.assertNotIn("collection", Mammal._meta) + assert "collection" not in Animal._meta + assert "collection" not in Mammal._meta - self.assertEqual(Animal._get_collection_name(), None) - self.assertEqual(Mammal._get_collection_name(), None) + assert Animal._get_collection_name() == None + assert Mammal._get_collection_name() == None - self.assertEqual(Fish._get_collection_name(), "fish") - self.assertEqual(Guppy._get_collection_name(), "fish") - self.assertEqual(Human._get_collection_name(), "human") + assert Fish._get_collection_name() == "fish" + assert Guppy._get_collection_name() == "fish" + assert Human._get_collection_name() == "human" # ensure that a subclass of a non-abstract class can't be abstract - with self.assertRaises(ValueError): + with pytest.raises(ValueError): class EvilHuman(Human): evil = BooleanField(default=True) @@ -601,7 +581,7 @@ class TestInheritance(MongoDBTestCase): class B(A): pass - self.assertFalse(B._meta["abstract"]) + assert not B._meta["abstract"] def test_inherited_collections(self): """Ensure that subclassed documents don't override parents' @@ -647,8 +627,8 @@ class TestInheritance(MongoDBTestCase): real_person = Drinker(drink=beer) real_person.save() - self.assertEqual(Drinker.objects[0].drink.name, red_bull.name) - self.assertEqual(Drinker.objects[1].drink.name, beer.name) + assert Drinker.objects[0].drink.name == red_bull.name + assert Drinker.objects[1].drink.name == beer.name if __name__ == "__main__": diff --git a/tests/document/test_instance.py b/tests/document/test_instance.py index 203e2cce..01dc492b 100644 --- a/tests/document/test_instance.py +++ b/tests/document/test_instance.py @@ -36,6 +36,7 @@ from tests.fixtures import ( PickleTest, ) from tests.utils import MongoDBTestCase, get_as_pymongo +import pytest TEST_IMAGE_PATH = os.path.join(os.path.dirname(__file__), "../fields/mongoengine.png") @@ -63,18 +64,17 @@ class TestInstance(MongoDBTestCase): self.db.drop_collection(collection) def assertDbEqual(self, docs): - self.assertEqual( - list(self.Person._get_collection().find().sort("id")), - sorted(docs, key=lambda doc: doc["_id"]), + assert list(self.Person._get_collection().find().sort("id")) == sorted( + docs, key=lambda doc: doc["_id"] ) def assertHasInstance(self, field, instance): - self.assertTrue(hasattr(field, "_instance")) - self.assertTrue(field._instance is not None) + assert hasattr(field, "_instance") + assert field._instance is not None if isinstance(field._instance, weakref.ProxyType): - self.assertTrue(field._instance.__eq__(instance)) + assert field._instance.__eq__(instance) else: - self.assertEqual(field._instance, instance) + assert field._instance == instance def test_capped_collection(self): """Ensure that capped collections work properly.""" @@ -89,16 +89,16 @@ class TestInstance(MongoDBTestCase): for _ in range(10): Log().save() - self.assertEqual(Log.objects.count(), 10) + assert Log.objects.count() == 10 # Check that extra documents don't increase the size Log().save() - self.assertEqual(Log.objects.count(), 10) + assert Log.objects.count() == 10 options = Log.objects._collection.options() - self.assertEqual(options["capped"], True) - self.assertEqual(options["max"], 10) - self.assertEqual(options["size"], 4096) + assert options["capped"] == True + assert options["max"] == 10 + assert options["size"] == 4096 # Check that the document cannot be redefined with different options class Log(Document): @@ -106,7 +106,7 @@ class TestInstance(MongoDBTestCase): meta = {"max_documents": 11} # Accessing Document.objects creates the collection - with self.assertRaises(InvalidCollectionError): + with pytest.raises(InvalidCollectionError): Log.objects def test_capped_collection_default(self): @@ -122,9 +122,9 @@ class TestInstance(MongoDBTestCase): Log().save() options = Log.objects._collection.options() - self.assertEqual(options["capped"], True) - self.assertEqual(options["max"], 10) - self.assertEqual(options["size"], 10 * 2 ** 20) + assert options["capped"] == True + assert options["max"] == 10 + assert options["size"] == 10 * 2 ** 20 # Check that the document with default value can be recreated class Log(Document): @@ -150,8 +150,8 @@ class TestInstance(MongoDBTestCase): Log().save() options = Log.objects._collection.options() - self.assertEqual(options["capped"], True) - self.assertTrue(options["size"] >= 10000) + assert options["capped"] == True + assert options["size"] >= 10000 # Check that the document with odd max_size value can be recreated class Log(Document): @@ -173,7 +173,7 @@ class TestInstance(MongoDBTestCase): doc = Article(title=u"привет мир") - self.assertEqual("", repr(doc)) + assert "" == repr(doc) def test_repr_none(self): """Ensure None values are handled correctly.""" @@ -185,11 +185,11 @@ class TestInstance(MongoDBTestCase): return None doc = Article(title=u"привет мир") - self.assertEqual("", repr(doc)) + assert "" == repr(doc) def test_queryset_resurrects_dropped_collection(self): self.Person.drop_collection() - self.assertEqual([], list(self.Person.objects())) + assert [] == list(self.Person.objects()) # Ensure works correctly with inhertited classes class Actor(self.Person): @@ -197,7 +197,7 @@ class TestInstance(MongoDBTestCase): Actor.objects() self.Person.drop_collection() - self.assertEqual([], list(Actor.objects())) + assert [] == list(Actor.objects()) def test_polymorphic_references(self): """Ensure that the correct subclasses are returned from a query @@ -237,7 +237,7 @@ class TestInstance(MongoDBTestCase): zoo.reload() classes = [a.__class__ for a in Zoo.objects.first().animals] - self.assertEqual(classes, [Animal, Fish, Mammal, Dog, Human]) + assert classes == [Animal, Fish, Mammal, Dog, Human] Zoo.drop_collection() @@ -250,7 +250,7 @@ class TestInstance(MongoDBTestCase): zoo.reload() classes = [a.__class__ for a in Zoo.objects.first().animals] - self.assertEqual(classes, [Animal, Fish, Mammal, Dog, Human]) + assert classes == [Animal, Fish, Mammal, Dog, Human] def test_reference_inheritance(self): class Stats(Document): @@ -275,7 +275,7 @@ class TestInstance(MongoDBTestCase): cmp_stats = CompareStats(stats=list_stats) cmp_stats.save() - self.assertEqual(list_stats, CompareStats.objects.first().stats) + assert list_stats == CompareStats.objects.first().stats def test_db_field_load(self): """Ensure we load data correctly from the right db field.""" @@ -294,8 +294,8 @@ class TestInstance(MongoDBTestCase): Person(name="Fred").save() - self.assertEqual(Person.objects.get(name="Jack").rank, "Corporal") - self.assertEqual(Person.objects.get(name="Fred").rank, "Private") + assert Person.objects.get(name="Jack").rank == "Corporal" + assert Person.objects.get(name="Fred").rank == "Private" def test_db_embedded_doc_field_load(self): """Ensure we load embedded document data correctly.""" @@ -318,8 +318,8 @@ class TestInstance(MongoDBTestCase): Person(name="Jack", rank_=Rank(title="Corporal")).save() Person(name="Fred").save() - self.assertEqual(Person.objects.get(name="Jack").rank, "Corporal") - self.assertEqual(Person.objects.get(name="Fred").rank, "Private") + assert Person.objects.get(name="Jack").rank == "Corporal" + assert Person.objects.get(name="Fred").rank == "Private" def test_custom_id_field(self): """Ensure that documents may be created with custom primary keys.""" @@ -332,15 +332,15 @@ class TestInstance(MongoDBTestCase): User.drop_collection() - self.assertEqual(User._fields["username"].db_field, "_id") - self.assertEqual(User._meta["id_field"], "username") + assert User._fields["username"].db_field == "_id" + assert User._meta["id_field"] == "username" User.objects.create(username="test", name="test user") user = User.objects.first() - self.assertEqual(user.id, "test") - self.assertEqual(user.pk, "test") + assert user.id == "test" + assert user.pk == "test" user_dict = User.objects._collection.find_one() - self.assertEqual(user_dict["_id"], "test") + assert user_dict["_id"] == "test" def test_change_custom_id_field_in_subclass(self): """Subclasses cannot override which field is the primary key.""" @@ -350,13 +350,13 @@ class TestInstance(MongoDBTestCase): name = StringField() meta = {"allow_inheritance": True} - with self.assertRaises(ValueError) as e: + with pytest.raises(ValueError) as e: class EmailUser(User): email = StringField(primary_key=True) exc = e.exception - self.assertEqual(str(exc), "Cannot override primary key field") + assert str(exc) == "Cannot override primary key field" def test_custom_id_field_is_required(self): """Ensure the custom primary key field is required.""" @@ -365,10 +365,10 @@ class TestInstance(MongoDBTestCase): username = StringField(primary_key=True) name = StringField() - with self.assertRaises(ValidationError) as e: + with pytest.raises(ValidationError) as e: User(name="test").save() exc = e.exception - self.assertTrue("Field is required: ['username']" in str(exc)) + assert "Field is required: ['username']" in str(exc) def test_document_not_registered(self): class Place(Document): @@ -388,7 +388,7 @@ class TestInstance(MongoDBTestCase): # and the NicePlace model not being imported in at query time. del _document_registry["Place.NicePlace"] - with self.assertRaises(NotRegistered): + with pytest.raises(NotRegistered): list(Place.objects.all()) def test_document_registry_regressions(self): @@ -401,26 +401,27 @@ class TestInstance(MongoDBTestCase): Location.drop_collection() - self.assertEqual(Area, get_document("Area")) - self.assertEqual(Area, get_document("Location.Area")) + assert Area == get_document("Area") + assert Area == get_document("Location.Area") def test_creation(self): """Ensure that document may be created using keyword arguments.""" person = self.Person(name="Test User", age=30) - self.assertEqual(person.name, "Test User") - self.assertEqual(person.age, 30) + assert person.name == "Test User" + assert person.age == 30 def test_to_dbref(self): """Ensure that you can get a dbref of a document.""" person = self.Person(name="Test User", age=30) - self.assertRaises(OperationError, person.to_dbref) + with pytest.raises(OperationError): + person.to_dbref() person.save() person.to_dbref() def test_key_like_attribute_access(self): person = self.Person(age=30) - self.assertEqual(person["age"], 30) - with self.assertRaises(KeyError): + assert person["age"] == 30 + with pytest.raises(KeyError): person["unknown_attr"] def test_save_abstract_document(self): @@ -430,7 +431,7 @@ class TestInstance(MongoDBTestCase): name = StringField() meta = {"abstract": True} - with self.assertRaises(InvalidDocumentError): + with pytest.raises(InvalidDocumentError): Doc(name="aaa").save() def test_reload(self): @@ -443,20 +444,20 @@ class TestInstance(MongoDBTestCase): person_obj.age = 21 person_obj.save() - self.assertEqual(person.name, "Test User") - self.assertEqual(person.age, 20) + assert person.name == "Test User" + assert person.age == 20 person.reload("age") - self.assertEqual(person.name, "Test User") - self.assertEqual(person.age, 21) + assert person.name == "Test User" + assert person.age == 21 person.reload() - self.assertEqual(person.name, "Mr Test User") - self.assertEqual(person.age, 21) + assert person.name == "Mr Test User" + assert person.age == 21 person.reload() - self.assertEqual(person.name, "Mr Test User") - self.assertEqual(person.age, 21) + assert person.name == "Mr Test User" + assert person.age == 21 def test_reload_sharded(self): class Animal(Document): @@ -471,9 +472,10 @@ class TestInstance(MongoDBTestCase): with query_counter() as q: doc.reload() query_op = q.db.system.profile.find({"ns": "mongoenginetest.animal"})[0] - self.assertEqual( - set(query_op[CMD_QUERY_KEY]["filter"].keys()), {"_id", "superphylum"} - ) + assert set(query_op[CMD_QUERY_KEY]["filter"].keys()) == { + "_id", + "superphylum", + } def test_reload_sharded_with_db_field(self): class Person(Document): @@ -488,9 +490,7 @@ class TestInstance(MongoDBTestCase): with query_counter() as q: doc.reload() query_op = q.db.system.profile.find({"ns": "mongoenginetest.person"})[0] - self.assertEqual( - set(query_op[CMD_QUERY_KEY]["filter"].keys()), {"_id", "country"} - ) + assert set(query_op[CMD_QUERY_KEY]["filter"].keys()) == {"_id", "country"} def test_reload_sharded_nested(self): class SuperPhylum(EmbeddedDocument): @@ -526,15 +526,11 @@ class TestInstance(MongoDBTestCase): doc.name = "Cat" doc.save() query_op = q.db.system.profile.find({"ns": "mongoenginetest.animal"})[0] - self.assertEqual(query_op["op"], "update") + assert query_op["op"] == "update" if mongo_db <= MONGODB_34: - self.assertEqual( - set(query_op["query"].keys()), set(["_id", "is_mammal"]) - ) + assert set(query_op["query"].keys()) == set(["_id", "is_mammal"]) else: - self.assertEqual( - set(query_op["command"]["q"].keys()), set(["_id", "is_mammal"]) - ) + assert set(query_op["command"]["q"].keys()) == set(["_id", "is_mammal"]) Animal.drop_collection() @@ -551,12 +547,12 @@ class TestInstance(MongoDBTestCase): user.name = "John" user.number = 2 - self.assertEqual(user._get_changed_fields(), ["name", "number"]) + assert user._get_changed_fields() == ["name", "number"] user.reload("number") - self.assertEqual(user._get_changed_fields(), ["name"]) + assert user._get_changed_fields() == ["name"] user.save() user.reload() - self.assertEqual(user.name, "John") + assert user.name == "John" def test_reload_referencing(self): """Ensures reloading updates weakrefs correctly.""" @@ -587,47 +583,44 @@ class TestInstance(MongoDBTestCase): doc.embedded_field.list_field.append(1) doc.embedded_field.dict_field["woot"] = "woot" - self.assertEqual( - doc._get_changed_fields(), - [ - "list_field", - "dict_field.woot", - "embedded_field.list_field", - "embedded_field.dict_field.woot", - ], - ) + assert doc._get_changed_fields() == [ + "list_field", + "dict_field.woot", + "embedded_field.list_field", + "embedded_field.dict_field.woot", + ] doc.save() - self.assertEqual(len(doc.list_field), 4) + assert len(doc.list_field) == 4 doc = doc.reload(10) - self.assertEqual(doc._get_changed_fields(), []) - self.assertEqual(len(doc.list_field), 4) - self.assertEqual(len(doc.dict_field), 2) - self.assertEqual(len(doc.embedded_field.list_field), 4) - self.assertEqual(len(doc.embedded_field.dict_field), 2) + assert doc._get_changed_fields() == [] + assert len(doc.list_field) == 4 + assert len(doc.dict_field) == 2 + assert len(doc.embedded_field.list_field) == 4 + assert len(doc.embedded_field.dict_field) == 2 doc.list_field.append(1) doc.save() doc.dict_field["extra"] = 1 doc = doc.reload(10, "list_field") - self.assertEqual(doc._get_changed_fields(), ["dict_field.extra"]) - self.assertEqual(len(doc.list_field), 5) - self.assertEqual(len(doc.dict_field), 3) - self.assertEqual(len(doc.embedded_field.list_field), 4) - self.assertEqual(len(doc.embedded_field.dict_field), 2) + assert doc._get_changed_fields() == ["dict_field.extra"] + assert len(doc.list_field) == 5 + assert len(doc.dict_field) == 3 + assert len(doc.embedded_field.list_field) == 4 + assert len(doc.embedded_field.dict_field) == 2 def test_reload_doesnt_exist(self): class Foo(Document): pass f = Foo() - with self.assertRaises(Foo.DoesNotExist): + with pytest.raises(Foo.DoesNotExist): f.reload() f.save() f.delete() - with self.assertRaises(Foo.DoesNotExist): + with pytest.raises(Foo.DoesNotExist): f.reload() def test_reload_of_non_strict_with_special_field_name(self): @@ -646,27 +639,29 @@ class TestInstance(MongoDBTestCase): post = Post.objects.first() post.reload() - self.assertEqual(post.title, "Items eclipse") - self.assertEqual(post.items, ["more lorem", "even more ipsum"]) + assert post.title == "Items eclipse" + assert post.items == ["more lorem", "even more ipsum"] def test_dictionary_access(self): """Ensure that dictionary-style field access works properly.""" person = self.Person(name="Test User", age=30, job=self.Job()) - self.assertEqual(person["name"], "Test User") + assert person["name"] == "Test User" - self.assertRaises(KeyError, person.__getitem__, "salary") - self.assertRaises(KeyError, person.__setitem__, "salary", 50) + with pytest.raises(KeyError): + person.__getitem__("salary") + with pytest.raises(KeyError): + person.__setitem__("salary", 50) person["name"] = "Another User" - self.assertEqual(person["name"], "Another User") + assert person["name"] == "Another User" # Length = length(assigned fields + id) - self.assertEqual(len(person), 5) + assert len(person) == 5 - self.assertIn("age", person) + assert "age" in person person.age = None - self.assertNotIn("age", person) - self.assertNotIn("nationality", person) + assert "age" not in person + assert "nationality" not in person def test_embedded_document_to_mongo(self): class Person(EmbeddedDocument): @@ -678,20 +673,20 @@ class TestInstance(MongoDBTestCase): class Employee(Person): salary = IntField() - self.assertEqual( - Person(name="Bob", age=35).to_mongo().keys(), ["_cls", "name", "age"] - ) - self.assertEqual( - Employee(name="Bob", age=35, salary=0).to_mongo().keys(), - ["_cls", "name", "age", "salary"], - ) + assert Person(name="Bob", age=35).to_mongo().keys() == ["_cls", "name", "age"] + assert Employee(name="Bob", age=35, salary=0).to_mongo().keys() == [ + "_cls", + "name", + "age", + "salary", + ] def test_embedded_document_to_mongo_id(self): class SubDoc(EmbeddedDocument): id = StringField(required=True) sub_doc = SubDoc(id="abc") - self.assertEqual(sub_doc.to_mongo().keys(), ["id"]) + assert sub_doc.to_mongo().keys() == ["id"] def test_embedded_document(self): """Ensure that embedded documents are set up correctly.""" @@ -699,8 +694,8 @@ class TestInstance(MongoDBTestCase): class Comment(EmbeddedDocument): content = StringField() - self.assertIn("content", Comment._fields) - self.assertNotIn("id", Comment._fields) + assert "content" in Comment._fields + assert "id" not in Comment._fields def test_embedded_document_instance(self): """Ensure that embedded documents can reference parent instance.""" @@ -753,7 +748,7 @@ class TestInstance(MongoDBTestCase): .to_mongo(use_db_field=False) .to_dict() ) - self.assertEqual(d["embedded_field"], [{"string": "Hi"}]) + assert d["embedded_field"] == [{"string": "Hi"}] def test_instance_is_set_on_setattr(self): class Email(EmbeddedDocument): @@ -796,7 +791,7 @@ class TestInstance(MongoDBTestCase): def clean(self): raise CustomError() - with self.assertRaises(CustomError): + with pytest.raises(CustomError): TestDocument().save() TestDocument().save(clean=False) @@ -816,10 +811,10 @@ class TestInstance(MongoDBTestCase): BlogPost.drop_collection() post = BlogPost(content="unchecked").save() - self.assertEqual(post.content, "checked") + assert post.content == "checked" # Make sure pre_save_post_validation changes makes it to the db raw_doc = get_as_pymongo(post) - self.assertEqual(raw_doc, {"content": "checked", "_id": post.id}) + assert raw_doc == {"content": "checked", "_id": post.id} # Important to disconnect as it could cause some assertions in test_signals # to fail (due to the garbage collection timing of this signal) @@ -840,17 +835,17 @@ class TestInstance(MongoDBTestCase): # Ensure clean=False prevent call to clean t = TestDocument(status="published") t.save(clean=False) - self.assertEqual(t.status, "published") - self.assertEqual(t.cleaned, False) + assert t.status == "published" + assert t.cleaned == False t = TestDocument(status="published") - self.assertEqual(t.cleaned, False) + assert t.cleaned == False t.save(clean=True) - self.assertEqual(t.status, "published") - self.assertEqual(t.cleaned, True) + assert t.status == "published" + assert t.cleaned == True raw_doc = get_as_pymongo(t) # Make sure clean changes makes it to the db - self.assertEqual(raw_doc, {"status": "published", "cleaned": True, "_id": t.id}) + assert raw_doc == {"status": "published", "cleaned": True, "_id": t.id} def test_document_embedded_clean(self): class TestEmbeddedDocument(EmbeddedDocument): @@ -875,15 +870,15 @@ class TestInstance(MongoDBTestCase): t = TestDocument(doc=TestEmbeddedDocument(x=10, y=25, z=15)) - with self.assertRaises(ValidationError) as cm: + with pytest.raises(ValidationError) as cm: t.save() expected_msg = "Value of z != x + y" - self.assertIn(expected_msg, cm.exception.message) - self.assertEqual(cm.exception.to_dict(), {"doc": {"__all__": expected_msg}}) + assert expected_msg in cm.exception.message + assert cm.exception.to_dict() == {"doc": {"__all__": expected_msg}} t = TestDocument(doc=TestEmbeddedDocument(x=10, y=25)).save() - self.assertEqual(t.doc.z, 35) + assert t.doc.z == 35 # Asserts not raises t = TestDocument(doc=TestEmbeddedDocument(x=15, y=35, z=5)) @@ -892,7 +887,7 @@ class TestInstance(MongoDBTestCase): def test_modify_empty(self): doc = self.Person(name="bob", age=10).save() - with self.assertRaises(InvalidDocumentError): + with pytest.raises(InvalidDocumentError): self.Person().modify(set__age=10) self.assertDbEqual([dict(doc.to_mongo())]) @@ -902,7 +897,7 @@ class TestInstance(MongoDBTestCase): doc2 = self.Person(name="jim", age=20).save() docs = [dict(doc1.to_mongo()), dict(doc2.to_mongo())] - with self.assertRaises(InvalidQueryError): + with pytest.raises(InvalidQueryError): doc1.modify({"id": doc2.id}, set__value=20) self.assertDbEqual(docs) @@ -913,7 +908,7 @@ class TestInstance(MongoDBTestCase): docs = [dict(doc1.to_mongo()), dict(doc2.to_mongo())] n_modified = doc1.modify({"name": doc2.name}, set__age=100) - self.assertEqual(n_modified, 0) + assert n_modified == 0 self.assertDbEqual(docs) @@ -923,7 +918,7 @@ class TestInstance(MongoDBTestCase): docs = [dict(doc1.to_mongo())] n_modified = doc2.modify({"name": doc2.name}, set__age=100) - self.assertEqual(n_modified, 0) + assert n_modified == 0 self.assertDbEqual(docs) @@ -943,13 +938,13 @@ class TestInstance(MongoDBTestCase): n_modified = doc.modify( set__age=21, set__job__name="MongoDB", unset__job__years=True ) - self.assertEqual(n_modified, 1) + assert n_modified == 1 doc_copy.age = 21 doc_copy.job.name = "MongoDB" del doc_copy.job.years - self.assertEqual(doc.to_json(), doc_copy.to_json()) - self.assertEqual(doc._get_changed_fields(), []) + assert doc.to_json() == doc_copy.to_json() + assert doc._get_changed_fields() == [] self.assertDbEqual([dict(other_doc.to_mongo()), dict(doc.to_mongo())]) @@ -965,27 +960,25 @@ class TestInstance(MongoDBTestCase): tags=["python"], content=Content(keywords=["ipsum"]) ) - self.assertEqual(post.tags, ["python"]) + assert post.tags == ["python"] post.modify(push__tags__0=["code", "mongo"]) - self.assertEqual(post.tags, ["code", "mongo", "python"]) + assert post.tags == ["code", "mongo", "python"] # Assert same order of the list items is maintained in the db - self.assertEqual( - BlogPost._get_collection().find_one({"_id": post.pk})["tags"], - ["code", "mongo", "python"], - ) + assert BlogPost._get_collection().find_one({"_id": post.pk})["tags"] == [ + "code", + "mongo", + "python", + ] - self.assertEqual(post.content.keywords, ["ipsum"]) + assert post.content.keywords == ["ipsum"] post.modify(push__content__keywords__0=["lorem"]) - self.assertEqual(post.content.keywords, ["lorem", "ipsum"]) + assert post.content.keywords == ["lorem", "ipsum"] # Assert same order of the list items is maintained in the db - self.assertEqual( - BlogPost._get_collection().find_one({"_id": post.pk})["content"][ - "keywords" - ], - ["lorem", "ipsum"], - ) + assert BlogPost._get_collection().find_one({"_id": post.pk})["content"][ + "keywords" + ] == ["lorem", "ipsum"] def test_save(self): """Ensure that a document may be saved in the database.""" @@ -996,28 +989,30 @@ class TestInstance(MongoDBTestCase): # Ensure that the object is in the database raw_doc = get_as_pymongo(person) - self.assertEqual( - raw_doc, - {"_cls": "Person", "name": "Test User", "age": 30, "_id": person.id}, - ) + assert raw_doc == { + "_cls": "Person", + "name": "Test User", + "age": 30, + "_id": person.id, + } def test_save_skip_validation(self): class Recipient(Document): email = EmailField(required=True) recipient = Recipient(email="not-an-email") - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): recipient.save() recipient.save(validate=False) raw_doc = get_as_pymongo(recipient) - self.assertEqual(raw_doc, {"email": "not-an-email", "_id": recipient.id}) + assert raw_doc == {"email": "not-an-email", "_id": recipient.id} def test_save_with_bad_id(self): class Clown(Document): id = IntField(primary_key=True) - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): Clown(id="not_an_int").save() def test_save_to_a_value_that_equates_to_false(self): @@ -1037,7 +1032,7 @@ class TestInstance(MongoDBTestCase): user.save() user.reload() - self.assertEqual(user.thing.count, 0) + assert user.thing.count == 0 def test_save_max_recursion_not_hit(self): class Person(Document): @@ -1085,7 +1080,7 @@ class TestInstance(MongoDBTestCase): b.name = "world" b.save() - self.assertEqual(b.picture, b.bar.picture, b.bar.bar.picture) + assert b.picture == b.bar.picture, b.bar.bar.picture def test_save_cascades(self): class Person(Document): @@ -1107,7 +1102,7 @@ class TestInstance(MongoDBTestCase): p.save(cascade=True) p1.reload() - self.assertEqual(p1.name, p.parent.name) + assert p1.name == p.parent.name def test_save_cascade_kwargs(self): class Person(Document): @@ -1127,7 +1122,7 @@ class TestInstance(MongoDBTestCase): p1.reload() p2.reload() - self.assertEqual(p1.name, p2.parent.name) + assert p1.name == p2.parent.name def test_save_cascade_meta_false(self): class Person(Document): @@ -1151,11 +1146,11 @@ class TestInstance(MongoDBTestCase): p.save() p1.reload() - self.assertNotEqual(p1.name, p.parent.name) + assert p1.name != p.parent.name p.save(cascade=True) p1.reload() - self.assertEqual(p1.name, p.parent.name) + assert p1.name == p.parent.name def test_save_cascade_meta_true(self): class Person(Document): @@ -1179,7 +1174,7 @@ class TestInstance(MongoDBTestCase): p.save() p1.reload() - self.assertNotEqual(p1.name, p.parent.name) + assert p1.name != p.parent.name def test_save_cascades_generically(self): class Person(Document): @@ -1200,11 +1195,11 @@ class TestInstance(MongoDBTestCase): p.save() p1.reload() - self.assertNotEqual(p1.name, p.parent.name) + assert p1.name != p.parent.name p.save(cascade=True) p1.reload() - self.assertEqual(p1.name, p.parent.name) + assert p1.name == p.parent.name def test_save_atomicity_condition(self): class Widget(Document): @@ -1226,64 +1221,61 @@ class TestInstance(MongoDBTestCase): # ignore save_condition on new record creation w1.save(save_condition={"save_id": UUID(42)}) w1.reload() - self.assertFalse(w1.toggle) - self.assertEqual(w1.save_id, UUID(1)) - self.assertEqual(w1.count, 0) + assert not w1.toggle + assert w1.save_id == UUID(1) + assert w1.count == 0 # mismatch in save_condition prevents save and raise exception flip(w1) - self.assertTrue(w1.toggle) - self.assertEqual(w1.count, 1) - self.assertRaises( - SaveConditionError, w1.save, save_condition={"save_id": UUID(42)} - ) + assert w1.toggle + assert w1.count == 1 + with pytest.raises(SaveConditionError): + w1.save(save_condition={"save_id": UUID(42)}) w1.reload() - self.assertFalse(w1.toggle) - self.assertEqual(w1.count, 0) + assert not w1.toggle + assert w1.count == 0 # matched save_condition allows save flip(w1) - self.assertTrue(w1.toggle) - self.assertEqual(w1.count, 1) + assert w1.toggle + assert w1.count == 1 w1.save(save_condition={"save_id": UUID(1)}) w1.reload() - self.assertTrue(w1.toggle) - self.assertEqual(w1.count, 1) + assert w1.toggle + assert w1.count == 1 # save_condition can be used to ensure atomic read & updates # i.e., prevent interleaved reads and writes from separate contexts w2 = Widget.objects.get() - self.assertEqual(w1, w2) + assert w1 == w2 old_id = w1.save_id flip(w1) w1.save_id = UUID(2) w1.save(save_condition={"save_id": old_id}) w1.reload() - self.assertFalse(w1.toggle) - self.assertEqual(w1.count, 2) + assert not w1.toggle + assert w1.count == 2 flip(w2) flip(w2) - self.assertRaises( - SaveConditionError, w2.save, save_condition={"save_id": old_id} - ) + with pytest.raises(SaveConditionError): + w2.save(save_condition={"save_id": old_id}) w2.reload() - self.assertFalse(w2.toggle) - self.assertEqual(w2.count, 2) + assert not w2.toggle + assert w2.count == 2 # save_condition uses mongoengine-style operator syntax flip(w1) w1.save(save_condition={"count__lt": w1.count}) w1.reload() - self.assertTrue(w1.toggle) - self.assertEqual(w1.count, 3) + assert w1.toggle + assert w1.count == 3 flip(w1) - self.assertRaises( - SaveConditionError, w1.save, save_condition={"count__gte": w1.count} - ) + with pytest.raises(SaveConditionError): + w1.save(save_condition={"count__gte": w1.count}) w1.reload() - self.assertTrue(w1.toggle) - self.assertEqual(w1.count, 3) + assert w1.toggle + assert w1.count == 3 def test_save_update_selectively(self): class WildBoy(Document): @@ -1303,8 +1295,8 @@ class TestInstance(MongoDBTestCase): boy2.save() fresh_boy = WildBoy.objects().first() - self.assertEqual(fresh_boy.age, 99) - self.assertEqual(fresh_boy.name, "Bob") + assert fresh_boy.age == 99 + assert fresh_boy.name == "Bob" def test_save_update_selectively_with_custom_pk(self): # Prevents regression of #2082 @@ -1326,8 +1318,8 @@ class TestInstance(MongoDBTestCase): boy2.save() fresh_boy = WildBoy.objects().first() - self.assertEqual(fresh_boy.age, 99) - self.assertEqual(fresh_boy.name, "Bob") + assert fresh_boy.age == 99 + assert fresh_boy.name == "Bob" def test_update(self): """Ensure that an existing document is updated instead of be @@ -1343,20 +1335,20 @@ class TestInstance(MongoDBTestCase): same_person.save() # Confirm only one object - self.assertEqual(self.Person.objects.count(), 1) + assert self.Person.objects.count() == 1 # reload person.reload() same_person.reload() # Confirm the same - self.assertEqual(person, same_person) - self.assertEqual(person.name, same_person.name) - self.assertEqual(person.age, same_person.age) + assert person == same_person + assert person.name == same_person.name + assert person.age == same_person.age # Confirm the saved values - self.assertEqual(person.name, "Test") - self.assertEqual(person.age, 30) + assert person.name == "Test" + assert person.age == 30 # Test only / exclude only updates included fields person = self.Person.objects.only("name").get() @@ -1364,8 +1356,8 @@ class TestInstance(MongoDBTestCase): person.save() person.reload() - self.assertEqual(person.name, "User") - self.assertEqual(person.age, 30) + assert person.name == "User" + assert person.age == 30 # test exclude only updates set fields person = self.Person.objects.exclude("name").get() @@ -1373,8 +1365,8 @@ class TestInstance(MongoDBTestCase): person.save() person.reload() - self.assertEqual(person.name, "User") - self.assertEqual(person.age, 21) + assert person.name == "User" + assert person.age == 21 # Test only / exclude can set non excluded / included fields person = self.Person.objects.only("name").get() @@ -1383,8 +1375,8 @@ class TestInstance(MongoDBTestCase): person.save() person.reload() - self.assertEqual(person.name, "Test") - self.assertEqual(person.age, 30) + assert person.name == "Test" + assert person.age == 30 # test exclude only updates set fields person = self.Person.objects.exclude("name").get() @@ -1393,8 +1385,8 @@ class TestInstance(MongoDBTestCase): person.save() person.reload() - self.assertEqual(person.name, "User") - self.assertEqual(person.age, 21) + assert person.name == "User" + assert person.age == 21 # Confirm does remove unrequired fields person = self.Person.objects.exclude("name").get() @@ -1402,8 +1394,8 @@ class TestInstance(MongoDBTestCase): person.save() person.reload() - self.assertEqual(person.name, "User") - self.assertEqual(person.age, None) + assert person.name == "User" + assert person.age == None person = self.Person.objects.get() person.name = None @@ -1411,20 +1403,20 @@ class TestInstance(MongoDBTestCase): person.save() person.reload() - self.assertEqual(person.name, None) - self.assertEqual(person.age, None) + assert person.name == None + assert person.age == None def test_update_rename_operator(self): """Test the $rename operator.""" coll = self.Person._get_collection() doc = self.Person(name="John").save() raw_doc = coll.find_one({"_id": doc.pk}) - self.assertEqual(set(raw_doc.keys()), set(["_id", "_cls", "name"])) + assert set(raw_doc.keys()) == set(["_id", "_cls", "name"]) doc.update(rename__name="first_name") raw_doc = coll.find_one({"_id": doc.pk}) - self.assertEqual(set(raw_doc.keys()), set(["_id", "_cls", "first_name"])) - self.assertEqual(raw_doc["first_name"], "John") + assert set(raw_doc.keys()) == set(["_id", "_cls", "first_name"]) + assert raw_doc["first_name"] == "John" def test_inserts_if_you_set_the_pk(self): p1 = self.Person(name="p1", id=bson.ObjectId()).save() @@ -1432,7 +1424,7 @@ class TestInstance(MongoDBTestCase): p2.id = bson.ObjectId() p2.save() - self.assertEqual(2, self.Person.objects.count()) + assert 2 == self.Person.objects.count() def test_can_save_if_not_included(self): class EmbeddedDoc(EmbeddedDocument): @@ -1480,13 +1472,13 @@ class TestInstance(MongoDBTestCase): my_doc.save() my_doc = Doc.objects.get(string_field="string") - self.assertEqual(my_doc.string_field, "string") - self.assertEqual(my_doc.int_field, 1) + assert my_doc.string_field == "string" + assert my_doc.int_field == 1 def test_document_update(self): # try updating a non-saved document - with self.assertRaises(OperationError): + with pytest.raises(OperationError): person = self.Person(name="dcrosta") person.update(set__name="Dan Crosta") @@ -1497,10 +1489,10 @@ class TestInstance(MongoDBTestCase): author.reload() p1 = self.Person.objects.first() - self.assertEqual(p1.name, author.name) + assert p1.name == author.name # try sending an empty update - with self.assertRaises(OperationError): + with pytest.raises(OperationError): person = self.Person.objects.first() person.update() @@ -1509,7 +1501,7 @@ class TestInstance(MongoDBTestCase): person = self.Person.objects.first() person.update(name="Dan") person.reload() - self.assertEqual("Dan", person.name) + assert "Dan" == person.name def test_update_unique_field(self): class Doc(Document): @@ -1518,7 +1510,7 @@ class TestInstance(MongoDBTestCase): doc1 = Doc(name="first").save() doc2 = Doc(name="second").save() - with self.assertRaises(NotUniqueError): + with pytest.raises(NotUniqueError): doc2.update(set__name=doc1.name) def test_embedded_update(self): @@ -1540,7 +1532,7 @@ class TestInstance(MongoDBTestCase): site.save() site = Site.objects.first() - self.assertEqual(site.page.log_message, "Error: Dummy message") + assert site.page.log_message == "Error: Dummy message" def test_update_list_field(self): """Test update on `ListField` with $pull + $in. @@ -1558,7 +1550,7 @@ class TestInstance(MongoDBTestCase): doc.update(pull__foo__in=["a", "c"]) doc = Doc.objects.first() - self.assertEqual(doc.foo, ["b"]) + assert doc.foo == ["b"] def test_embedded_update_db_field(self): """Test update on `EmbeddedDocumentField` fields when db_field @@ -1584,7 +1576,7 @@ class TestInstance(MongoDBTestCase): site.save() site = Site.objects.first() - self.assertEqual(site.page.log_message, "Error: Dummy message") + assert site.page.log_message == "Error: Dummy message" def test_save_only_changed_fields(self): """Ensure save only sets / unsets changed fields.""" @@ -1610,9 +1602,9 @@ class TestInstance(MongoDBTestCase): same_person.save() person = self.Person.objects.get() - self.assertEqual(person.name, "User") - self.assertEqual(person.age, 21) - self.assertEqual(person.active, False) + assert person.name == "User" + assert person.age == 21 + assert person.active == False def test__get_changed_fields_same_ids_reference_field_does_not_enters_infinite_loop_embedded_doc( self, @@ -1626,7 +1618,7 @@ class TestInstance(MongoDBTestCase): emb = EmbeddedChildModel(id={"1": [1]}) changed_fields = ParentModel(child=emb)._get_changed_fields() - self.assertEqual(changed_fields, []) + assert changed_fields == [] def test__get_changed_fields_same_ids_reference_field_does_not_enters_infinite_loop_different_doc( self, @@ -1647,8 +1639,8 @@ class TestInstance(MongoDBTestCase): message = Message(id=1, author=user).save() message.author.name = "tutu" - self.assertEqual(message._get_changed_fields(), []) - self.assertEqual(user._get_changed_fields(), ["name"]) + assert message._get_changed_fields() == [] + assert user._get_changed_fields() == ["name"] def test__get_changed_fields_same_ids_embedded(self): # Refers to Issue #1768 @@ -1667,11 +1659,11 @@ class TestInstance(MongoDBTestCase): message = Message(id=1, author=user).save() message.author.name = "tutu" - self.assertEqual(message._get_changed_fields(), ["author.name"]) + assert message._get_changed_fields() == ["author.name"] message.save() message_fetched = Message.objects.with_id(message.id) - self.assertEqual(message_fetched.author.name, "tutu") + assert message_fetched.author.name == "tutu" def test_query_count_when_saving(self): """Ensure references don't cause extra fetches when saving""" @@ -1707,65 +1699,65 @@ class TestInstance(MongoDBTestCase): user = User.objects.first() # Even if stored as ObjectId's internally mongoengine uses DBRefs # As ObjectId's aren't automatically derefenced - self.assertIsInstance(user._data["orgs"][0], DBRef) - self.assertIsInstance(user.orgs[0], Organization) - self.assertIsInstance(user._data["orgs"][0], Organization) + assert isinstance(user._data["orgs"][0], DBRef) + assert isinstance(user.orgs[0], Organization) + assert isinstance(user._data["orgs"][0], Organization) # Changing a value with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 sub = UserSubscription.objects.first() - self.assertEqual(q, 1) + assert q == 1 sub.name = "Test Sub" sub.save() - self.assertEqual(q, 2) + assert q == 2 # Changing a value that will cascade with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 sub = UserSubscription.objects.first() - self.assertEqual(q, 1) + assert q == 1 sub.user.name = "Test" - self.assertEqual(q, 2) + assert q == 2 sub.save(cascade=True) - self.assertEqual(q, 3) + assert q == 3 # Changing a value and one that will cascade with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 sub = UserSubscription.objects.first() sub.name = "Test Sub 2" - self.assertEqual(q, 1) + assert q == 1 sub.user.name = "Test 2" - self.assertEqual(q, 2) + assert q == 2 sub.save(cascade=True) - self.assertEqual(q, 4) # One for the UserSub and one for the User + assert q == 4 # One for the UserSub and one for the User # Saving with just the refs with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 sub = UserSubscription(user=u1.pk, feed=f1.pk) - self.assertEqual(q, 0) + assert q == 0 sub.save() - self.assertEqual(q, 1) + assert q == 1 # Saving with just the refs on a ListField with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 User(name="Bob", orgs=[o1.pk, o2.pk]).save() - self.assertEqual(q, 1) + assert q == 1 # Saving new objects with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 user = User.objects.first() - self.assertEqual(q, 1) + assert q == 1 feed = Feed.objects.first() - self.assertEqual(q, 2) + assert q == 2 sub = UserSubscription(user=user, feed=feed) - self.assertEqual(q, 2) # Check no change + assert q == 2 # Check no change sub.save() - self.assertEqual(q, 3) + assert q == 3 def test_set_unset_one_operation(self): """Ensure that $set and $unset actions are performed in the @@ -1781,14 +1773,14 @@ class TestInstance(MongoDBTestCase): # write an entity with a single prop foo = FooBar(foo="foo").save() - self.assertEqual(foo.foo, "foo") + assert foo.foo == "foo" del foo.foo foo.bar = "bar" with query_counter() as q: - self.assertEqual(0, q) + assert 0 == q foo.save() - self.assertEqual(1, q) + assert 1 == q def test_save_only_changed_fields_recursive(self): """Ensure save only sets / unsets changed fields.""" @@ -1810,34 +1802,34 @@ class TestInstance(MongoDBTestCase): person.reload() person = self.Person.objects.get() - self.assertTrue(person.comments[0].published) + assert person.comments[0].published person.comments[0].published = False person.save() person = self.Person.objects.get() - self.assertFalse(person.comments[0].published) + assert not person.comments[0].published # Simple dict w person.comments_dict["first_post"] = Comment() person.save() person = self.Person.objects.get() - self.assertTrue(person.comments_dict["first_post"].published) + assert person.comments_dict["first_post"].published person.comments_dict["first_post"].published = False person.save() person = self.Person.objects.get() - self.assertFalse(person.comments_dict["first_post"].published) + assert not person.comments_dict["first_post"].published def test_delete(self): """Ensure that document may be deleted using the delete method.""" person = self.Person(name="Test User", age=30) person.save() - self.assertEqual(self.Person.objects.count(), 1) + assert self.Person.objects.count() == 1 person.delete() - self.assertEqual(self.Person.objects.count(), 0) + assert self.Person.objects.count() == 0 def test_save_custom_id(self): """Ensure that a document may be saved with a custom _id.""" @@ -1849,7 +1841,7 @@ class TestInstance(MongoDBTestCase): # Ensure that the object is in the database with the correct _id collection = self.db[self.Person._get_collection_name()] person_obj = collection.find_one({"name": "Test User"}) - self.assertEqual(str(person_obj["_id"]), "497ce96f395f2f052a494fd4") + assert str(person_obj["_id"]) == "497ce96f395f2f052a494fd4" def test_save_custom_pk(self): """Ensure that a document may be saved with a custom _id using @@ -1862,7 +1854,7 @@ class TestInstance(MongoDBTestCase): # Ensure that the object is in the database with the correct _id collection = self.db[self.Person._get_collection_name()] person_obj = collection.find_one({"name": "Test User"}) - self.assertEqual(str(person_obj["_id"]), "497ce96f395f2f052a494fd4") + assert str(person_obj["_id"]) == "497ce96f395f2f052a494fd4" def test_save_list(self): """Ensure that a list field may be properly saved.""" @@ -1885,9 +1877,9 @@ class TestInstance(MongoDBTestCase): collection = self.db[BlogPost._get_collection_name()] post_obj = collection.find_one() - self.assertEqual(post_obj["tags"], tags) + assert post_obj["tags"] == tags for comment_obj, comment in zip(post_obj["comments"], comments): - self.assertEqual(comment_obj["content"], comment["content"]) + assert comment_obj["content"] == comment["content"] def test_list_search_by_embedded(self): class User(Document): @@ -1944,9 +1936,9 @@ class TestInstance(MongoDBTestCase): p4 = Page(comments=[Comment(user=u2, comment="Heavy Metal song")]) p4.save() - self.assertEqual([p1, p2], list(Page.objects.filter(comments__user=u1))) - self.assertEqual([p1, p2, p4], list(Page.objects.filter(comments__user=u2))) - self.assertEqual([p1, p3], list(Page.objects.filter(comments__user=u3))) + assert [p1, p2] == list(Page.objects.filter(comments__user=u1)) + assert [p1, p2, p4] == list(Page.objects.filter(comments__user=u2)) + assert [p1, p3] == list(Page.objects.filter(comments__user=u3)) def test_save_embedded_document(self): """Ensure that a document with an embedded document field may @@ -1968,11 +1960,11 @@ class TestInstance(MongoDBTestCase): # Ensure that the object is in the database collection = self.db[self.Person._get_collection_name()] employee_obj = collection.find_one({"name": "Test Employee"}) - self.assertEqual(employee_obj["name"], "Test Employee") - self.assertEqual(employee_obj["age"], 50) + assert employee_obj["name"] == "Test Employee" + assert employee_obj["age"] == 50 # Ensure that the 'details' embedded object saved correctly - self.assertEqual(employee_obj["details"]["position"], "Developer") + assert employee_obj["details"]["position"] == "Developer" def test_embedded_update_after_save(self): """Test update of `EmbeddedDocumentField` attached to a newly @@ -1994,7 +1986,7 @@ class TestInstance(MongoDBTestCase): site.save() site = Site.objects.first() - self.assertEqual(site.page.log_message, "Error: Dummy message") + assert site.page.log_message == "Error: Dummy message" def test_updating_an_embedded_document(self): """Ensure that a document with an embedded document field may @@ -2019,18 +2011,18 @@ class TestInstance(MongoDBTestCase): promoted_employee.save() promoted_employee.reload() - self.assertEqual(promoted_employee.name, "Test Employee") - self.assertEqual(promoted_employee.age, 50) + assert promoted_employee.name == "Test Employee" + assert promoted_employee.age == 50 # Ensure that the 'details' embedded object saved correctly - self.assertEqual(promoted_employee.details.position, "Senior Developer") + assert promoted_employee.details.position == "Senior Developer" # Test removal promoted_employee.details = None promoted_employee.save() promoted_employee.reload() - self.assertEqual(promoted_employee.details, None) + assert promoted_employee.details == None def test_object_mixins(self): class NameMixin(object): @@ -2039,12 +2031,12 @@ class TestInstance(MongoDBTestCase): class Foo(EmbeddedDocument, NameMixin): quantity = IntField() - self.assertEqual(["name", "quantity"], sorted(Foo._fields.keys())) + assert ["name", "quantity"] == sorted(Foo._fields.keys()) class Bar(Document, NameMixin): widgets = StringField() - self.assertEqual(["id", "name", "widgets"], sorted(Bar._fields.keys())) + assert ["id", "name", "widgets"] == sorted(Bar._fields.keys()) def test_mixin_inheritance(self): class BaseMixIn(object): @@ -2064,10 +2056,10 @@ class TestInstance(MongoDBTestCase): t = TestDoc.objects.first() - self.assertEqual(t.age, 19) - self.assertEqual(t.comment, "great!") - self.assertEqual(t.data, "test") - self.assertEqual(t.count, 12) + assert t.age == 19 + assert t.comment == "great!" + assert t.data == "test" + assert t.count == 12 def test_save_reference(self): """Ensure that a document reference field may be saved in the @@ -2092,22 +2084,22 @@ class TestInstance(MongoDBTestCase): post_obj = BlogPost.objects.first() # Test laziness - self.assertIsInstance(post_obj._data["author"], bson.DBRef) - self.assertIsInstance(post_obj.author, self.Person) - self.assertEqual(post_obj.author.name, "Test User") + assert isinstance(post_obj._data["author"], bson.DBRef) + assert isinstance(post_obj.author, self.Person) + assert post_obj.author.name == "Test User" # Ensure that the dereferenced object may be changed and saved post_obj.author.age = 25 post_obj.author.save() author = list(self.Person.objects(name="Test User"))[-1] - self.assertEqual(author.age, 25) + assert author.age == 25 def test_duplicate_db_fields_raise_invalid_document_error(self): """Ensure a InvalidDocumentError is thrown if duplicate fields declare the same db_field. """ - with self.assertRaises(InvalidDocumentError): + with pytest.raises(InvalidDocumentError): class Foo(Document): name = StringField() @@ -2125,7 +2117,7 @@ class TestInstance(MongoDBTestCase): forms = ListField(StringField(), default=list) occurs = ListField(EmbeddedDocumentField(Occurrence), default=list) - with self.assertRaises(InvalidDocumentError): + with pytest.raises(InvalidDocumentError): Word._from_son( { "stem": [1, 2, 3], @@ -2136,7 +2128,7 @@ class TestInstance(MongoDBTestCase): ) # Tests for issue #1438: https://github.com/MongoEngine/mongoengine/issues/1438 - with self.assertRaises(ValueError): + with pytest.raises(ValueError): Word._from_son("this is not a valid SON dict") def test_reverse_delete_rule_cascade_and_nullify(self): @@ -2165,12 +2157,12 @@ class TestInstance(MongoDBTestCase): reviewer.delete() # No effect on the BlogPost - self.assertEqual(BlogPost.objects.count(), 1) - self.assertEqual(BlogPost.objects.get().reviewer, None) + assert BlogPost.objects.count() == 1 + assert BlogPost.objects.get().reviewer == None # Delete the Person, which should lead to deletion of the BlogPost, too author.delete() - self.assertEqual(BlogPost.objects.count(), 0) + assert BlogPost.objects.count() == 0 def test_reverse_delete_rule_pull(self): """Ensure that a referenced document is also deleted with @@ -2189,7 +2181,7 @@ class TestInstance(MongoDBTestCase): parent_record.save() child_record.delete() - self.assertEqual(Record.objects(name="parent").get().children, []) + assert Record.objects(name="parent").get().children == [] def test_reverse_delete_rule_with_custom_id_field(self): """Ensure that a referenced document with custom primary key @@ -2211,11 +2203,11 @@ class TestInstance(MongoDBTestCase): book = Book(author=user, reviewer=reviewer).save() reviewer.delete() - self.assertEqual(Book.objects.count(), 1) - self.assertEqual(Book.objects.get().reviewer, None) + assert Book.objects.count() == 1 + assert Book.objects.get().reviewer == None user.delete() - self.assertEqual(Book.objects.count(), 0) + assert Book.objects.count() == 0 def test_reverse_delete_rule_with_shared_id_among_collections(self): """Ensure that cascade delete rule doesn't mix id among @@ -2239,16 +2231,16 @@ class TestInstance(MongoDBTestCase): user_2.delete() # Deleting user_2 should also delete book_1 but not book_2 - self.assertEqual(Book.objects.count(), 1) - self.assertEqual(Book.objects.get(), book_2) + assert Book.objects.count() == 1 + assert Book.objects.get() == book_2 user_3 = User(id=3).save() book_3 = Book(id=3, author=user_3).save() user_3.delete() # Deleting user_3 should also delete book_3 - self.assertEqual(Book.objects.count(), 1) - self.assertEqual(Book.objects.get(), book_2) + assert Book.objects.count() == 1 + assert Book.objects.get() == book_2 def test_reverse_delete_rule_with_document_inheritance(self): """Ensure that a referenced document is also deleted upon @@ -2278,12 +2270,12 @@ class TestInstance(MongoDBTestCase): post.save() reviewer.delete() - self.assertEqual(BlogPost.objects.count(), 1) - self.assertEqual(BlogPost.objects.get().reviewer, None) + assert BlogPost.objects.count() == 1 + assert BlogPost.objects.get().reviewer == None # Delete the Writer should lead to deletion of the BlogPost author.delete() - self.assertEqual(BlogPost.objects.count(), 0) + assert BlogPost.objects.count() == 0 def test_reverse_delete_rule_cascade_and_nullify_complex_field(self): """Ensure that a referenced document is also deleted upon @@ -2315,12 +2307,12 @@ class TestInstance(MongoDBTestCase): # Deleting the reviewer should have no effect on the BlogPost reviewer.delete() - self.assertEqual(BlogPost.objects.count(), 1) - self.assertEqual(BlogPost.objects.get().reviewers, []) + assert BlogPost.objects.count() == 1 + assert BlogPost.objects.get().reviewers == [] # Delete the Person, which should lead to deletion of the BlogPost, too author.delete() - self.assertEqual(BlogPost.objects.count(), 0) + assert BlogPost.objects.count() == 0 def test_reverse_delete_rule_cascade_triggers_pre_delete_signal(self): """Ensure the pre_delete signal is triggered upon a cascading @@ -2357,7 +2349,7 @@ class TestInstance(MongoDBTestCase): # the pre-delete signal should have decremented the editor's queue editor = Editor.objects(name="Max P.").get() - self.assertEqual(editor.review_queue, 0) + assert editor.review_queue == 0 def test_two_way_reverse_delete_rule(self): """Ensure that Bi-Directional relationships work with @@ -2389,11 +2381,11 @@ class TestInstance(MongoDBTestCase): f.delete() - self.assertEqual(Bar.objects.count(), 1) # No effect on the BlogPost - self.assertEqual(Bar.objects.get().foo, None) + assert Bar.objects.count() == 1 # No effect on the BlogPost + assert Bar.objects.get().foo == None def test_invalid_reverse_delete_rule_raise_errors(self): - with self.assertRaises(InvalidDocumentError): + with pytest.raises(InvalidDocumentError): class Blog(Document): content = StringField() @@ -2404,7 +2396,7 @@ class TestInstance(MongoDBTestCase): field=ReferenceField(self.Person, reverse_delete_rule=NULLIFY) ) - with self.assertRaises(InvalidDocumentError): + with pytest.raises(InvalidDocumentError): class Parents(EmbeddedDocument): father = ReferenceField("Person", reverse_delete_rule=DENY) @@ -2441,7 +2433,7 @@ class TestInstance(MongoDBTestCase): # Delete the Person, which should lead to deletion of the BlogPost, # and, recursively to the Comment, too author.delete() - self.assertEqual(Comment.objects.count(), 0) + assert Comment.objects.count() == 0 def test_reverse_delete_rule_deny(self): """Ensure that a document cannot be referenced if there are @@ -2463,19 +2455,18 @@ class TestInstance(MongoDBTestCase): post.save() # Delete the Person should be denied - self.assertRaises(OperationError, author.delete) # Should raise denied error - self.assertEqual( - BlogPost.objects.count(), 1 - ) # No objects may have been deleted - self.assertEqual(self.Person.objects.count(), 1) + with pytest.raises(OperationError): + author.delete() # Should raise denied error + assert BlogPost.objects.count() == 1 # No objects may have been deleted + assert self.Person.objects.count() == 1 # Other users, that don't have BlogPosts must be removable, like normal author = self.Person(name="Another User") author.save() - self.assertEqual(self.Person.objects.count(), 2) + assert self.Person.objects.count() == 2 author.delete() - self.assertEqual(self.Person.objects.count(), 1) + assert self.Person.objects.count() == 1 def subclasses_and_unique_keys_works(self): class A(Document): @@ -2491,8 +2482,8 @@ class TestInstance(MongoDBTestCase): A().save() B(foo=True).save() - self.assertEqual(A.objects.count(), 2) - self.assertEqual(B.objects.count(), 1) + assert A.objects.count() == 2 + assert B.objects.count() == 1 def test_document_hash(self): """Test document in list, dict, set.""" @@ -2518,12 +2509,12 @@ class TestInstance(MongoDBTestCase): # Make sure docs are properly identified in a list (__eq__ is used # for the comparison). all_user_list = list(User.objects.all()) - self.assertIn(u1, all_user_list) - self.assertIn(u2, all_user_list) - self.assertIn(u3, all_user_list) - self.assertNotIn(u4, all_user_list) # New object - self.assertNotIn(b1, all_user_list) # Other object - self.assertNotIn(b2, all_user_list) # Other object + assert u1 in all_user_list + assert u2 in all_user_list + assert u3 in all_user_list + assert u4 not in all_user_list # New object + assert b1 not in all_user_list # Other object + assert b2 not in all_user_list # Other object # Make sure docs can be used as keys in a dict (__hash__ is used # for hashing the docs). @@ -2531,27 +2522,27 @@ class TestInstance(MongoDBTestCase): for u in User.objects.all(): all_user_dic[u] = "OK" - self.assertEqual(all_user_dic.get(u1, False), "OK") - self.assertEqual(all_user_dic.get(u2, False), "OK") - self.assertEqual(all_user_dic.get(u3, False), "OK") - self.assertEqual(all_user_dic.get(u4, False), False) # New object - self.assertEqual(all_user_dic.get(b1, False), False) # Other object - self.assertEqual(all_user_dic.get(b2, False), False) # Other object + assert all_user_dic.get(u1, False) == "OK" + assert all_user_dic.get(u2, False) == "OK" + assert all_user_dic.get(u3, False) == "OK" + assert all_user_dic.get(u4, False) == False # New object + assert all_user_dic.get(b1, False) == False # Other object + assert all_user_dic.get(b2, False) == False # Other object # Make sure docs are properly identified in a set (__hash__ is used # for hashing the docs). all_user_set = set(User.objects.all()) - self.assertIn(u1, all_user_set) - self.assertNotIn(u4, all_user_set) - self.assertNotIn(b1, all_user_list) - self.assertNotIn(b2, all_user_list) + assert u1 in all_user_set + assert u4 not in all_user_set + assert b1 not in all_user_list + assert b2 not in all_user_list # Make sure duplicate docs aren't accepted in the set - self.assertEqual(len(all_user_set), 3) + assert len(all_user_set) == 3 all_user_set.add(u1) all_user_set.add(u2) all_user_set.add(u3) - self.assertEqual(len(all_user_set), 3) + assert len(all_user_set) == 3 def test_picklable(self): pickle_doc = PickleTest(number=1, string="One", lists=["1", "2"]) @@ -2564,21 +2555,21 @@ class TestInstance(MongoDBTestCase): pickled_doc = pickle.dumps(pickle_doc) resurrected = pickle.loads(pickled_doc) - self.assertEqual(resurrected, pickle_doc) + assert resurrected == pickle_doc # Test pickling changed data pickle_doc.lists.append("3") pickled_doc = pickle.dumps(pickle_doc) resurrected = pickle.loads(pickled_doc) - self.assertEqual(resurrected, pickle_doc) + assert resurrected == pickle_doc resurrected.string = "Two" resurrected.save() pickle_doc = PickleTest.objects.first() - self.assertEqual(resurrected, pickle_doc) - self.assertEqual(pickle_doc.string, "Two") - self.assertEqual(pickle_doc.lists, ["1", "2", "3"]) + assert resurrected == pickle_doc + assert pickle_doc.string == "Two" + assert pickle_doc.lists == ["1", "2", "3"] def test_regular_document_pickle(self): pickle_doc = PickleTest(number=1, string="One", lists=["1", "2"]) @@ -2594,11 +2585,12 @@ class TestInstance(MongoDBTestCase): fixtures.PickleTest = fixtures.NewDocumentPickleTest resurrected = pickle.loads(pickled_doc) - self.assertEqual(resurrected.__class__, fixtures.NewDocumentPickleTest) - self.assertEqual( - resurrected._fields_ordered, fixtures.NewDocumentPickleTest._fields_ordered + assert resurrected.__class__ == fixtures.NewDocumentPickleTest + assert ( + resurrected._fields_ordered + == fixtures.NewDocumentPickleTest._fields_ordered ) - self.assertNotEqual(resurrected._fields_ordered, pickle_doc._fields_ordered) + assert resurrected._fields_ordered != pickle_doc._fields_ordered # The local PickleTest is still a ref to the original fixtures.PickleTest = PickleTest @@ -2617,19 +2609,17 @@ class TestInstance(MongoDBTestCase): pickled_doc = pickle.dumps(pickle_doc) resurrected = pickle.loads(pickled_doc) - self.assertEqual(resurrected, pickle_doc) - self.assertEqual(resurrected._fields_ordered, pickle_doc._fields_ordered) - self.assertEqual( - resurrected._dynamic_fields.keys(), pickle_doc._dynamic_fields.keys() - ) + assert resurrected == pickle_doc + assert resurrected._fields_ordered == pickle_doc._fields_ordered + assert resurrected._dynamic_fields.keys() == pickle_doc._dynamic_fields.keys() - self.assertEqual(resurrected.embedded, pickle_doc.embedded) - self.assertEqual( - resurrected.embedded._fields_ordered, pickle_doc.embedded._fields_ordered + assert resurrected.embedded == pickle_doc.embedded + assert ( + resurrected.embedded._fields_ordered == pickle_doc.embedded._fields_ordered ) - self.assertEqual( - resurrected.embedded._dynamic_fields.keys(), - pickle_doc.embedded._dynamic_fields.keys(), + assert ( + resurrected.embedded._dynamic_fields.keys() + == pickle_doc.embedded._dynamic_fields.keys() ) def test_picklable_on_signals(self): @@ -2642,7 +2632,7 @@ class TestInstance(MongoDBTestCase): """Test creating a field with a field name that would override the "validate" method. """ - with self.assertRaises(InvalidDocumentError): + with pytest.raises(InvalidDocumentError): class Blog(Document): validate = DictField() @@ -2659,7 +2649,7 @@ class TestInstance(MongoDBTestCase): a = A() a.save() a.reload() - self.assertEqual(a.b.field1, "field1") + assert a.b.field1 == "field1" class C(EmbeddedDocument): c_field = StringField(default="cfield") @@ -2676,7 +2666,7 @@ class TestInstance(MongoDBTestCase): a.save() a.reload() - self.assertEqual(a.b.field2.c_field, "new value") + assert a.b.field2.c_field == "new value" def test_can_save_false_values(self): """Ensures you can save False values on save.""" @@ -2692,7 +2682,7 @@ class TestInstance(MongoDBTestCase): d.archived = False d.save() - self.assertEqual(Doc.objects(archived=False).count(), 1) + assert Doc.objects(archived=False).count() == 1 def test_can_save_false_values_dynamic(self): """Ensures you can save False values on dynamic docs.""" @@ -2707,7 +2697,7 @@ class TestInstance(MongoDBTestCase): d.archived = False d.save() - self.assertEqual(Doc.objects(archived=False).count(), 1) + assert Doc.objects(archived=False).count() == 1 def test_do_not_save_unchanged_references(self): """Ensures cascading saves dont auto update""" @@ -2768,8 +2758,8 @@ class TestInstance(MongoDBTestCase): hp = Book.objects.create(name="Harry Potter") # Selects - self.assertEqual(User.objects.first(), bob) - self.assertEqual(Book.objects.first(), hp) + assert User.objects.first() == bob + assert Book.objects.first() == hp # DeReference class AuthorBooks(Document): @@ -2783,27 +2773,23 @@ class TestInstance(MongoDBTestCase): ab = AuthorBooks.objects.create(author=bob, book=hp) # select - self.assertEqual(AuthorBooks.objects.first(), ab) - self.assertEqual(AuthorBooks.objects.first().book, hp) - self.assertEqual(AuthorBooks.objects.first().author, bob) - self.assertEqual(AuthorBooks.objects.filter(author=bob).first(), ab) - self.assertEqual(AuthorBooks.objects.filter(book=hp).first(), ab) + assert AuthorBooks.objects.first() == ab + assert AuthorBooks.objects.first().book == hp + assert AuthorBooks.objects.first().author == bob + assert AuthorBooks.objects.filter(author=bob).first() == ab + assert AuthorBooks.objects.filter(book=hp).first() == ab # DB Alias - self.assertEqual(User._get_db(), get_db("testdb-1")) - self.assertEqual(Book._get_db(), get_db("testdb-2")) - self.assertEqual(AuthorBooks._get_db(), get_db("testdb-3")) + assert User._get_db() == get_db("testdb-1") + assert Book._get_db() == get_db("testdb-2") + assert AuthorBooks._get_db() == get_db("testdb-3") # Collections - self.assertEqual( - User._get_collection(), get_db("testdb-1")[User._get_collection_name()] - ) - self.assertEqual( - Book._get_collection(), get_db("testdb-2")[Book._get_collection_name()] - ) - self.assertEqual( - AuthorBooks._get_collection(), - get_db("testdb-3")[AuthorBooks._get_collection_name()], + assert User._get_collection() == get_db("testdb-1")[User._get_collection_name()] + assert Book._get_collection() == get_db("testdb-2")[Book._get_collection_name()] + assert ( + AuthorBooks._get_collection() + == get_db("testdb-3")[AuthorBooks._get_collection_name()] ) def test_db_alias_overrides(self): @@ -2826,9 +2812,9 @@ class TestInstance(MongoDBTestCase): A.objects.all() - self.assertEqual("testdb-2", B._meta.get("db_alias")) - self.assertEqual("mongoenginetest", A._get_collection().database.name) - self.assertEqual("mongoenginetest2", B._get_collection().database.name) + assert "testdb-2" == B._meta.get("db_alias") + assert "mongoenginetest" == A._get_collection().database.name + assert "mongoenginetest2" == B._get_collection().database.name def test_db_alias_propagates(self): """db_alias propagates?""" @@ -2841,7 +2827,7 @@ class TestInstance(MongoDBTestCase): class B(A): pass - self.assertEqual("testdb-1", B._meta.get("db_alias")) + assert "testdb-1" == B._meta.get("db_alias") def test_db_ref_usage(self): """DB Ref usage in dict_fields.""" @@ -2898,11 +2884,9 @@ class TestInstance(MongoDBTestCase): Book.objects.create(name="9", author=jon, extra={"a": peter.to_dbref()}) # Checks - self.assertEqual( - ",".join([str(b) for b in Book.objects.all()]), "1,2,3,4,5,6,7,8,9" - ) + assert ",".join([str(b) for b in Book.objects.all()]) == "1,2,3,4,5,6,7,8,9" # bob related books - self.assertEqual( + assert ( ",".join( [ str(b) @@ -2910,12 +2894,12 @@ class TestInstance(MongoDBTestCase): Q(extra__a=bob) | Q(author=bob) | Q(extra__b=bob) ) ] - ), - "1,2,3,4", + ) + == "1,2,3,4" ) # Susan & Karl related books - self.assertEqual( + assert ( ",".join( [ str(b) @@ -2925,12 +2909,12 @@ class TestInstance(MongoDBTestCase): | Q(extra__b__all=[karl.to_dbref(), susan.to_dbref()]) ) ] - ), - "1", + ) + == "1" ) # $Where - self.assertEqual( + assert ( u",".join( [ str(b) @@ -2943,8 +2927,8 @@ class TestInstance(MongoDBTestCase): } ) ] - ), - "1,2", + ) + == "1,2" ) def test_switch_db_instance(self): @@ -2958,7 +2942,7 @@ class TestInstance(MongoDBTestCase): Group.drop_collection() Group(name="hello - default").save() - self.assertEqual(1, Group.objects.count()) + assert 1 == Group.objects.count() group = Group.objects.first() group.switch_db("testdb-1") @@ -2967,10 +2951,10 @@ class TestInstance(MongoDBTestCase): with switch_db(Group, "testdb-1") as Group: group = Group.objects.first() - self.assertEqual("hello - testdb!", group.name) + assert "hello - testdb!" == group.name group = Group.objects.first() - self.assertEqual("hello - default", group.name) + assert "hello - default" == group.name # Slightly contrived now - perform an update # Only works as they have the same object_id @@ -2979,12 +2963,12 @@ class TestInstance(MongoDBTestCase): with switch_db(Group, "testdb-1") as Group: group = Group.objects.first() - self.assertEqual("hello - update", group.name) + assert "hello - update" == group.name Group.drop_collection() - self.assertEqual(0, Group.objects.count()) + assert 0 == Group.objects.count() group = Group.objects.first() - self.assertEqual("hello - default", group.name) + assert "hello - default" == group.name # Totally contrived now - perform a delete # Only works as they have the same object_id @@ -2992,10 +2976,10 @@ class TestInstance(MongoDBTestCase): group.delete() with switch_db(Group, "testdb-1") as Group: - self.assertEqual(0, Group.objects.count()) + assert 0 == Group.objects.count() group = Group.objects.first() - self.assertEqual("hello - default", group.name) + assert "hello - default" == group.name def test_load_undefined_fields(self): class User(Document): @@ -3007,7 +2991,8 @@ class TestInstance(MongoDBTestCase): {"name": "John", "foo": "Bar", "data": [1, 2, 3]} ) - self.assertRaises(FieldDoesNotExist, User.objects.first) + with pytest.raises(FieldDoesNotExist): + User.objects.first() def test_load_undefined_fields_with_strict_false(self): class User(Document): @@ -3022,11 +3007,11 @@ class TestInstance(MongoDBTestCase): ) user = User.objects.first() - self.assertEqual(user.name, "John") - self.assertFalse(hasattr(user, "foo")) - self.assertEqual(user._data["foo"], "Bar") - self.assertFalse(hasattr(user, "data")) - self.assertEqual(user._data["data"], [1, 2, 3]) + assert user.name == "John" + assert not hasattr(user, "foo") + assert user._data["foo"] == "Bar" + assert not hasattr(user, "data") + assert user._data["data"] == [1, 2, 3] def test_load_undefined_fields_on_embedded_document(self): class Thing(EmbeddedDocument): @@ -3045,7 +3030,8 @@ class TestInstance(MongoDBTestCase): } ) - self.assertRaises(FieldDoesNotExist, User.objects.first) + with pytest.raises(FieldDoesNotExist): + User.objects.first() def test_load_undefined_fields_on_embedded_document_with_strict_false_on_doc(self): class Thing(EmbeddedDocument): @@ -3066,7 +3052,8 @@ class TestInstance(MongoDBTestCase): } ) - self.assertRaises(FieldDoesNotExist, User.objects.first) + with pytest.raises(FieldDoesNotExist): + User.objects.first() def test_load_undefined_fields_on_embedded_document_with_strict_false(self): class Thing(EmbeddedDocument): @@ -3088,12 +3075,12 @@ class TestInstance(MongoDBTestCase): ) user = User.objects.first() - self.assertEqual(user.name, "John") - self.assertEqual(user.thing.name, "My thing") - self.assertFalse(hasattr(user.thing, "foo")) - self.assertEqual(user.thing._data["foo"], "Bar") - self.assertFalse(hasattr(user.thing, "data")) - self.assertEqual(user.thing._data["data"], [1, 2, 3]) + assert user.name == "John" + assert user.thing.name == "My thing" + assert not hasattr(user.thing, "foo") + assert user.thing._data["foo"] == "Bar" + assert not hasattr(user.thing, "data") + assert user.thing._data["data"] == [1, 2, 3] def test_spaces_in_keys(self): class Embedded(DynamicEmbeddedDocument): @@ -3108,7 +3095,7 @@ class TestInstance(MongoDBTestCase): doc.save() one = Doc.objects.filter(**{"hello world": 1}).count() - self.assertEqual(1, one) + assert 1 == one def test_shard_key(self): class LogEntry(Document): @@ -3123,13 +3110,13 @@ class TestInstance(MongoDBTestCase): log.machine = "Localhost" log.save() - self.assertTrue(log.id is not None) + assert log.id is not None log.log = "Saving" log.save() # try to change the shard key - with self.assertRaises(OperationError): + with pytest.raises(OperationError): log.machine = "127.0.0.1" def test_shard_key_in_embedded_document(self): @@ -3145,13 +3132,13 @@ class TestInstance(MongoDBTestCase): bar_doc = Bar(foo=foo_doc, bar="world") bar_doc.save() - self.assertTrue(bar_doc.id is not None) + assert bar_doc.id is not None bar_doc.bar = "baz" bar_doc.save() # try to change the shard key - with self.assertRaises(OperationError): + with pytest.raises(OperationError): bar_doc.foo.foo = "something" bar_doc.save() @@ -3168,13 +3155,13 @@ class TestInstance(MongoDBTestCase): log.machine = "Localhost" log.save() - self.assertTrue(log.id is not None) + assert log.id is not None log.log = "Saving" log.save() # try to change the shard key - with self.assertRaises(OperationError): + with pytest.raises(OperationError): log.machine = "127.0.0.1" def test_kwargs_simple(self): @@ -3191,8 +3178,8 @@ class TestInstance(MongoDBTestCase): classic_doc = Doc(doc_name="my doc", doc=Embedded(name="embedded doc")) dict_doc = Doc(**{"doc_name": "my doc", "doc": {"name": "embedded doc"}}) - self.assertEqual(classic_doc, dict_doc) - self.assertEqual(classic_doc._data, dict_doc._data) + assert classic_doc == dict_doc + assert classic_doc._data == dict_doc._data def test_kwargs_complex(self): class Embedded(EmbeddedDocument): @@ -3216,48 +3203,48 @@ class TestInstance(MongoDBTestCase): } ) - self.assertEqual(classic_doc, dict_doc) - self.assertEqual(classic_doc._data, dict_doc._data) + assert classic_doc == dict_doc + assert classic_doc._data == dict_doc._data def test_positional_creation(self): """Document cannot be instantiated using positional arguments.""" - with self.assertRaises(TypeError) as e: + with pytest.raises(TypeError) as e: person = self.Person("Test User", 42) expected_msg = ( "Instantiating a document with positional arguments is not " "supported. Please use `field_name=value` keyword arguments." ) - self.assertEqual(str(e.exception), expected_msg) + assert str(e.exception) == expected_msg def test_mixed_creation(self): """Document cannot be instantiated using mixed arguments.""" - with self.assertRaises(TypeError) as e: + with pytest.raises(TypeError) as e: person = self.Person("Test User", age=42) expected_msg = ( "Instantiating a document with positional arguments is not " "supported. Please use `field_name=value` keyword arguments." ) - self.assertEqual(str(e.exception), expected_msg) + assert str(e.exception) == expected_msg def test_positional_creation_embedded(self): """Embedded document cannot be created using positional arguments.""" - with self.assertRaises(TypeError) as e: + with pytest.raises(TypeError) as e: job = self.Job("Test Job", 4) expected_msg = ( "Instantiating a document with positional arguments is not " "supported. Please use `field_name=value` keyword arguments." ) - self.assertEqual(str(e.exception), expected_msg) + assert str(e.exception) == expected_msg def test_mixed_creation_embedded(self): """Embedded document cannot be created using mixed arguments.""" - with self.assertRaises(TypeError) as e: + with pytest.raises(TypeError) as e: job = self.Job("Test Job", years=4) expected_msg = ( "Instantiating a document with positional arguments is not " "supported. Please use `field_name=value` keyword arguments." ) - self.assertEqual(str(e.exception), expected_msg) + assert str(e.exception) == expected_msg def test_data_contains_id_field(self): """Ensure that asking for _data returns 'id'.""" @@ -3269,8 +3256,8 @@ class TestInstance(MongoDBTestCase): Person(name="Harry Potter").save() person = Person.objects.first() - self.assertIn("id", person._data.keys()) - self.assertEqual(person._data.get("id"), person.id) + assert "id" in person._data.keys() + assert person._data.get("id") == person.id def test_complex_nesting_document_and_embedded_document(self): class Macro(EmbeddedDocument): @@ -3310,8 +3297,8 @@ class TestInstance(MongoDBTestCase): system.save() system = NodesSystem.objects.first() - self.assertEqual( - "UNDEFINED", system.nodes["node"].parameters["param"].macros["test"].value + assert ( + "UNDEFINED" == system.nodes["node"].parameters["param"].macros["test"].value ) def test_embedded_document_equality(self): @@ -3328,9 +3315,9 @@ class TestInstance(MongoDBTestCase): f1 = Embedded._from_son(e.to_mongo()) f2 = Embedded._from_son(e.to_mongo()) - self.assertEqual(f1, f2) + assert f1 == f2 f1.ref # Dereferences lazily - self.assertEqual(f1, f2) + assert f1 == f2 def test_dbref_equality(self): class Test2(Document): @@ -3361,36 +3348,36 @@ class TestInstance(MongoDBTestCase): dbref2 = f._data["test2"] obj2 = f.test2 - self.assertIsInstance(dbref2, DBRef) - self.assertIsInstance(obj2, Test2) - self.assertEqual(obj2.id, dbref2.id) - self.assertEqual(obj2, dbref2) - self.assertEqual(dbref2, obj2) + assert isinstance(dbref2, DBRef) + assert isinstance(obj2, Test2) + assert obj2.id == dbref2.id + assert obj2 == dbref2 + assert dbref2 == obj2 dbref3 = f._data["test3"] obj3 = f.test3 - self.assertIsInstance(dbref3, DBRef) - self.assertIsInstance(obj3, Test3) - self.assertEqual(obj3.id, dbref3.id) - self.assertEqual(obj3, dbref3) - self.assertEqual(dbref3, obj3) + assert isinstance(dbref3, DBRef) + assert isinstance(obj3, Test3) + assert obj3.id == dbref3.id + assert obj3 == dbref3 + assert dbref3 == obj3 - self.assertEqual(obj2.id, obj3.id) - self.assertEqual(dbref2.id, dbref3.id) - self.assertNotEqual(dbref2, dbref3) - self.assertNotEqual(dbref3, dbref2) - self.assertNotEqual(dbref2, dbref3) - self.assertNotEqual(dbref3, dbref2) + assert obj2.id == obj3.id + assert dbref2.id == dbref3.id + assert dbref2 != dbref3 + assert dbref3 != dbref2 + assert dbref2 != dbref3 + assert dbref3 != dbref2 - self.assertNotEqual(obj2, dbref3) - self.assertNotEqual(dbref3, obj2) - self.assertNotEqual(obj2, dbref3) - self.assertNotEqual(dbref3, obj2) + assert obj2 != dbref3 + assert dbref3 != obj2 + assert obj2 != dbref3 + assert dbref3 != obj2 - self.assertNotEqual(obj3, dbref2) - self.assertNotEqual(dbref2, obj3) - self.assertNotEqual(obj3, dbref2) - self.assertNotEqual(dbref2, obj3) + assert obj3 != dbref2 + assert dbref2 != obj3 + assert obj3 != dbref2 + assert dbref2 != obj3 def test_default_values(self): class Person(Document): @@ -3405,7 +3392,7 @@ class TestInstance(MongoDBTestCase): p2.name = "alon2" p2.save() p3 = Person.objects().only("created_on")[0] - self.assertEqual(orig_created_on, p3.created_on) + assert orig_created_on == p3.created_on class Person(Document): created_on = DateTimeField(default=lambda: datetime.utcnow()) @@ -3414,10 +3401,10 @@ class TestInstance(MongoDBTestCase): p4 = Person.objects()[0] p4.save() - self.assertEqual(p4.height, 189) + assert p4.height == 189 # However the default will not be fixed in DB - self.assertEqual(Person.objects(height=189).count(), 0) + assert Person.objects(height=189).count() == 0 # alter DB for the new default coll = Person._get_collection() @@ -3425,7 +3412,7 @@ class TestInstance(MongoDBTestCase): if "height" not in person: coll.update_one({"_id": person["_id"]}, {"$set": {"height": 189}}) - self.assertEqual(Person.objects(height=189).count(), 1) + assert Person.objects(height=189).count() == 1 def test_shard_key_mutability_after_from_json(self): """Ensure that a document ID can be modified after from_json. @@ -3445,11 +3432,11 @@ class TestInstance(MongoDBTestCase): meta = {"shard_key": ("id", "name")} p = Person.from_json('{"name": "name", "age": 27}', created=True) - self.assertEqual(p._created, True) + assert p._created == True p.name = "new name" p.id = "12345" - self.assertEqual(p.name, "new name") - self.assertEqual(p.id, "12345") + assert p.name == "new name" + assert p.id == "12345" def test_shard_key_mutability_after_from_son(self): """Ensure that a document ID can be modified after _from_son. @@ -3463,11 +3450,11 @@ class TestInstance(MongoDBTestCase): meta = {"shard_key": ("id", "name")} p = Person._from_son({"name": "name", "age": 27}, created=True) - self.assertEqual(p._created, True) + assert p._created == True p.name = "new name" p.id = "12345" - self.assertEqual(p.name, "new name") - self.assertEqual(p.id, "12345") + assert p.name == "new name" + assert p.id == "12345" def test_from_json_created_false_without_an_id(self): class Person(Document): @@ -3476,14 +3463,14 @@ class TestInstance(MongoDBTestCase): Person.objects.delete() p = Person.from_json('{"name": "name"}', created=False) - self.assertEqual(p._created, False) - self.assertEqual(p.id, None) + assert p._created == False + assert p.id == None # Make sure the document is subsequently persisted correctly. p.save() - self.assertTrue(p.id is not None) + assert p.id is not None saved_p = Person.objects.get(id=p.id) - self.assertEqual(saved_p.name, "name") + assert saved_p.name == "name" def test_from_json_created_false_with_an_id(self): """See https://github.com/mongoengine/mongoengine/issues/1854""" @@ -3496,13 +3483,13 @@ class TestInstance(MongoDBTestCase): p = Person.from_json( '{"_id": "5b85a8b04ec5dc2da388296e", "name": "name"}', created=False ) - self.assertEqual(p._created, False) - self.assertEqual(p._changed_fields, []) - self.assertEqual(p.name, "name") - self.assertEqual(p.id, ObjectId("5b85a8b04ec5dc2da388296e")) + assert p._created == False + assert p._changed_fields == [] + assert p.name == "name" + assert p.id == ObjectId("5b85a8b04ec5dc2da388296e") p.save() - with self.assertRaises(DoesNotExist): + with pytest.raises(DoesNotExist): # Since the object is considered as already persisted (thanks to # `created=False` and an existing ID), and we haven't changed any # fields (i.e. `_changed_fields` is empty), the document is @@ -3510,12 +3497,12 @@ class TestInstance(MongoDBTestCase): # nothing. Person.objects.get(id=p.id) - self.assertFalse(p._created) + assert not p._created p.name = "a new name" - self.assertEqual(p._changed_fields, ["name"]) + assert p._changed_fields == ["name"] p.save() saved_p = Person.objects.get(id=p.id) - self.assertEqual(saved_p.name, p.name) + assert saved_p.name == p.name def test_from_json_created_true_with_an_id(self): class Person(Document): @@ -3526,15 +3513,15 @@ class TestInstance(MongoDBTestCase): p = Person.from_json( '{"_id": "5b85a8b04ec5dc2da388296e", "name": "name"}', created=True ) - self.assertTrue(p._created) - self.assertEqual(p._changed_fields, []) - self.assertEqual(p.name, "name") - self.assertEqual(p.id, ObjectId("5b85a8b04ec5dc2da388296e")) + assert p._created + assert p._changed_fields == [] + assert p.name == "name" + assert p.id == ObjectId("5b85a8b04ec5dc2da388296e") p.save() saved_p = Person.objects.get(id=p.id) - self.assertEqual(saved_p, p) - self.assertEqual(saved_p.name, "name") + assert saved_p == p + assert saved_p.name == "name" def test_null_field(self): # 734 @@ -3553,13 +3540,13 @@ class TestInstance(MongoDBTestCase): u_from_db = User.objects.get(name="user") u_from_db.height = None u_from_db.save() - self.assertEqual(u_from_db.height, None) + assert u_from_db.height == None # 864 - self.assertEqual(u_from_db.str_fld, None) - self.assertEqual(u_from_db.int_fld, None) - self.assertEqual(u_from_db.flt_fld, None) - self.assertEqual(u_from_db.dt_fld, None) - self.assertEqual(u_from_db.cdt_fld, None) + assert u_from_db.str_fld == None + assert u_from_db.int_fld == None + assert u_from_db.flt_fld == None + assert u_from_db.dt_fld == None + assert u_from_db.cdt_fld == None # 735 User.objects.delete() @@ -3567,7 +3554,7 @@ class TestInstance(MongoDBTestCase): u.save() User.objects(name="user").update_one(set__height=None, upsert=True) u_from_db = User.objects.get(name="user") - self.assertEqual(u_from_db.height, None) + assert u_from_db.height == None def test_not_saved_eq(self): """Ensure we can compare documents not saved. @@ -3578,8 +3565,8 @@ class TestInstance(MongoDBTestCase): p = Person() p1 = Person() - self.assertNotEqual(p, p1) - self.assertEqual(p, p) + assert p != p1 + assert p == p def test_list_iter(self): # 914 @@ -3592,10 +3579,10 @@ class TestInstance(MongoDBTestCase): A.objects.delete() A(l=[B(v="1"), B(v="2"), B(v="3")]).save() a = A.objects.get() - self.assertEqual(a.l._instance, a) + assert a.l._instance == a for idx, b in enumerate(a.l): - self.assertEqual(b._instance, a) - self.assertEqual(idx, 2) + assert b._instance == a + assert idx == 2 def test_falsey_pk(self): """Ensure that we can create and update a document with Falsey PK.""" @@ -3625,7 +3612,7 @@ class TestInstance(MongoDBTestCase): blog.update(push__tags__0=["mongodb", "code"]) blog.reload() - self.assertEqual(blog.tags, ["mongodb", "code", "python"]) + assert blog.tags == ["mongodb", "code", "python"] def test_push_nested_list(self): """Ensure that push update works in nested list""" @@ -3637,7 +3624,7 @@ class TestInstance(MongoDBTestCase): blog = BlogPost(slug="test").save() blog.update(push__tags=["value1", 123]) blog.reload() - self.assertEqual(blog.tags, [["value1", 123]]) + assert blog.tags == [["value1", 123]] def test_accessing_objects_with_indexes_error(self): insert_result = self.db.company.insert_many( @@ -3653,7 +3640,7 @@ class TestInstance(MongoDBTestCase): company = ReferenceField(Company) # Ensure index creation exception aren't swallowed (#1688) - with self.assertRaises(DuplicateKeyError): + with pytest.raises(DuplicateKeyError): User.objects().select_related() @@ -3663,10 +3650,10 @@ class ObjectKeyTestCase(MongoDBTestCase): title = StringField() book = Book(title="Whatever") - self.assertEqual(book._object_key, {"pk": None}) + assert book._object_key == {"pk": None} book.pk = ObjectId() - self.assertEqual(book._object_key, {"pk": book.pk}) + assert book._object_key == {"pk": book.pk} def test_object_key_with_custom_primary_key(self): class Book(Document): @@ -3674,10 +3661,10 @@ class ObjectKeyTestCase(MongoDBTestCase): title = StringField() book = Book(title="Sapiens") - self.assertEqual(book._object_key, {"pk": None}) + assert book._object_key == {"pk": None} book = Book(pk="0062316117") - self.assertEqual(book._object_key, {"pk": "0062316117"}) + assert book._object_key == {"pk": "0062316117"} def test_object_key_in_a_sharded_collection(self): class Book(Document): @@ -3685,9 +3672,9 @@ class ObjectKeyTestCase(MongoDBTestCase): meta = {"shard_key": ("pk", "title")} book = Book() - self.assertEqual(book._object_key, {"pk": None, "title": None}) + assert book._object_key == {"pk": None, "title": None} book = Book(pk=ObjectId(), title="Sapiens") - self.assertEqual(book._object_key, {"pk": book.pk, "title": "Sapiens"}) + assert book._object_key == {"pk": book.pk, "title": "Sapiens"} def test_object_key_with_custom_db_field(self): class Book(Document): @@ -3695,7 +3682,7 @@ class ObjectKeyTestCase(MongoDBTestCase): meta = {"shard_key": ("pk", "author")} book = Book(pk=ObjectId(), author="Author") - self.assertEqual(book._object_key, {"pk": book.pk, "author": "Author"}) + assert book._object_key == {"pk": book.pk, "author": "Author"} def test_object_key_with_nested_shard_key(self): class Author(EmbeddedDocument): @@ -3706,7 +3693,7 @@ class ObjectKeyTestCase(MongoDBTestCase): meta = {"shard_key": ("pk", "author.name")} book = Book(pk=ObjectId(), author=Author(name="Author")) - self.assertEqual(book._object_key, {"pk": book.pk, "author__name": "Author"}) + assert book._object_key == {"pk": book.pk, "author__name": "Author"} if __name__ == "__main__": diff --git a/tests/document/test_json_serialisation.py b/tests/document/test_json_serialisation.py index 26a4a6c1..593d34f8 100644 --- a/tests/document/test_json_serialisation.py +++ b/tests/document/test_json_serialisation.py @@ -32,7 +32,7 @@ class TestJson(MongoDBTestCase): expected_json = """{"embedded":{"string":"Inner Hello"},"string":"Hello"}""" - self.assertEqual(doc_json, expected_json) + assert doc_json == expected_json def test_json_simple(self): class Embedded(EmbeddedDocument): @@ -52,9 +52,9 @@ class TestJson(MongoDBTestCase): doc_json = doc.to_json(sort_keys=True, separators=(",", ":")) expected_json = """{"embedded_field":{"string":"Hi"},"string":"Hi"}""" - self.assertEqual(doc_json, expected_json) + assert doc_json == expected_json - self.assertEqual(doc, Doc.from_json(doc.to_json())) + assert doc == Doc.from_json(doc.to_json()) def test_json_complex(self): class EmbeddedDoc(EmbeddedDocument): @@ -99,7 +99,7 @@ class TestJson(MongoDBTestCase): return json.loads(self.to_json()) == json.loads(other.to_json()) doc = Doc() - self.assertEqual(doc, Doc.from_json(doc.to_json())) + assert doc == Doc.from_json(doc.to_json()) if __name__ == "__main__": diff --git a/tests/document/test_validation.py b/tests/document/test_validation.py index 7449dd33..80601994 100644 --- a/tests/document/test_validation.py +++ b/tests/document/test_validation.py @@ -4,6 +4,7 @@ from datetime import datetime from mongoengine import * from tests.utils import MongoDBTestCase +import pytest class TestValidatorError(MongoDBTestCase): @@ -11,12 +12,12 @@ class TestValidatorError(MongoDBTestCase): """Ensure a ValidationError handles error to_dict correctly. """ error = ValidationError("root") - self.assertEqual(error.to_dict(), {}) + assert error.to_dict() == {} # 1st level error schema error.errors = {"1st": ValidationError("bad 1st")} - self.assertIn("1st", error.to_dict()) - self.assertEqual(error.to_dict()["1st"], "bad 1st") + assert "1st" in error.to_dict() + assert error.to_dict()["1st"] == "bad 1st" # 2nd level error schema error.errors = { @@ -24,10 +25,10 @@ class TestValidatorError(MongoDBTestCase): "bad 1st", errors={"2nd": ValidationError("bad 2nd")} ) } - self.assertIn("1st", error.to_dict()) - self.assertIsInstance(error.to_dict()["1st"], dict) - self.assertIn("2nd", error.to_dict()["1st"]) - self.assertEqual(error.to_dict()["1st"]["2nd"], "bad 2nd") + assert "1st" in error.to_dict() + assert isinstance(error.to_dict()["1st"], dict) + assert "2nd" in error.to_dict()["1st"] + assert error.to_dict()["1st"]["2nd"] == "bad 2nd" # moar levels error.errors = { @@ -45,13 +46,13 @@ class TestValidatorError(MongoDBTestCase): }, ) } - self.assertIn("1st", error.to_dict()) - self.assertIn("2nd", error.to_dict()["1st"]) - self.assertIn("3rd", error.to_dict()["1st"]["2nd"]) - self.assertIn("4th", error.to_dict()["1st"]["2nd"]["3rd"]) - self.assertEqual(error.to_dict()["1st"]["2nd"]["3rd"]["4th"], "Inception") + assert "1st" in error.to_dict() + assert "2nd" in error.to_dict()["1st"] + assert "3rd" in error.to_dict()["1st"]["2nd"] + assert "4th" in error.to_dict()["1st"]["2nd"]["3rd"] + assert error.to_dict()["1st"]["2nd"]["3rd"]["4th"] == "Inception" - self.assertEqual(error.message, "root(2nd.3rd.4th.Inception: ['1st'])") + assert error.message == "root(2nd.3rd.4th.Inception: ['1st'])" def test_model_validation(self): class User(Document): @@ -61,19 +62,19 @@ class TestValidatorError(MongoDBTestCase): try: User().validate() except ValidationError as e: - self.assertIn("User:None", e.message) - self.assertEqual( - e.to_dict(), - {"username": "Field is required", "name": "Field is required"}, - ) + assert "User:None" in e.message + assert e.to_dict() == { + "username": "Field is required", + "name": "Field is required", + } user = User(username="RossC0", name="Ross").save() user.name = None try: user.save() except ValidationError as e: - self.assertIn("User:RossC0", e.message) - self.assertEqual(e.to_dict(), {"name": "Field is required"}) + assert "User:RossC0" in e.message + assert e.to_dict() == {"name": "Field is required"} def test_fields_rewrite(self): class BasePerson(Document): @@ -85,7 +86,8 @@ class TestValidatorError(MongoDBTestCase): name = StringField(required=True) p = Person(age=15) - self.assertRaises(ValidationError, p.validate) + with pytest.raises(ValidationError): + p.validate() def test_embedded_document_validation(self): """Ensure that embedded documents may be validated. @@ -96,17 +98,19 @@ class TestValidatorError(MongoDBTestCase): content = StringField(required=True) comment = Comment() - self.assertRaises(ValidationError, comment.validate) + with pytest.raises(ValidationError): + comment.validate() comment.content = "test" comment.validate() comment.date = 4 - self.assertRaises(ValidationError, comment.validate) + with pytest.raises(ValidationError): + comment.validate() comment.date = datetime.now() comment.validate() - self.assertEqual(comment._instance, None) + assert comment._instance == None def test_embedded_db_field_validate(self): class SubDoc(EmbeddedDocument): @@ -119,10 +123,8 @@ class TestValidatorError(MongoDBTestCase): try: Doc(id="bad").validate() except ValidationError as e: - self.assertIn("SubDoc:None", e.message) - self.assertEqual( - e.to_dict(), {"e": {"val": "OK could not be converted to int"}} - ) + assert "SubDoc:None" in e.message + assert e.to_dict() == {"e": {"val": "OK could not be converted to int"}} Doc.drop_collection() @@ -130,18 +132,16 @@ class TestValidatorError(MongoDBTestCase): doc = Doc.objects.first() keys = doc._data.keys() - self.assertEqual(2, len(keys)) - self.assertIn("e", keys) - self.assertIn("id", keys) + assert 2 == len(keys) + assert "e" in keys + assert "id" in keys doc.e.val = "OK" try: doc.save() except ValidationError as e: - self.assertIn("Doc:test", e.message) - self.assertEqual( - e.to_dict(), {"e": {"val": "OK could not be converted to int"}} - ) + assert "Doc:test" in e.message + assert e.to_dict() == {"e": {"val": "OK could not be converted to int"}} def test_embedded_weakref(self): class SubDoc(EmbeddedDocument): @@ -157,14 +157,16 @@ class TestValidatorError(MongoDBTestCase): s = SubDoc() - self.assertRaises(ValidationError, s.validate) + with pytest.raises(ValidationError): + s.validate() d1.e = s d2.e = s del d1 - self.assertRaises(ValidationError, d2.validate) + with pytest.raises(ValidationError): + d2.validate() def test_parent_reference_in_child_document(self): """ diff --git a/tests/fields/test_binary_field.py b/tests/fields/test_binary_field.py index 719df922..86ee2654 100644 --- a/tests/fields/test_binary_field.py +++ b/tests/fields/test_binary_field.py @@ -7,6 +7,7 @@ import six from mongoengine import * from tests.utils import MongoDBTestCase +import pytest BIN_VALUE = six.b( "\xa9\xf3\x8d(\xd7\x03\x84\xb4k[\x0f\xe3\xa2\x19\x85p[J\xa3\xd2>\xde\xe6\x87\xb1\x7f\xc6\xe6\xd9r\x18\xf5" @@ -31,8 +32,8 @@ class TestBinaryField(MongoDBTestCase): attachment.save() attachment_1 = Attachment.objects().first() - self.assertEqual(MIME_TYPE, attachment_1.content_type) - self.assertEqual(BLOB, six.binary_type(attachment_1.blob)) + assert MIME_TYPE == attachment_1.content_type + assert BLOB == six.binary_type(attachment_1.blob) def test_validation_succeeds(self): """Ensure that valid values can be assigned to binary fields. @@ -45,13 +46,15 @@ class TestBinaryField(MongoDBTestCase): blob = BinaryField(max_bytes=4) attachment_required = AttachmentRequired() - self.assertRaises(ValidationError, attachment_required.validate) + with pytest.raises(ValidationError): + attachment_required.validate() attachment_required.blob = Binary(six.b("\xe6\x00\xc4\xff\x07")) attachment_required.validate() _5_BYTES = six.b("\xe6\x00\xc4\xff\x07") _4_BYTES = six.b("\xe6\x00\xc4\xff") - self.assertRaises(ValidationError, AttachmentSizeLimit(blob=_5_BYTES).validate) + with pytest.raises(ValidationError): + AttachmentSizeLimit(blob=_5_BYTES).validate() AttachmentSizeLimit(blob=_4_BYTES).validate() def test_validation_fails(self): @@ -61,7 +64,8 @@ class TestBinaryField(MongoDBTestCase): blob = BinaryField() for invalid_data in (2, u"Im_a_unicode", ["some_str"]): - self.assertRaises(ValidationError, Attachment(blob=invalid_data).validate) + with pytest.raises(ValidationError): + Attachment(blob=invalid_data).validate() def test__primary(self): class Attachment(Document): @@ -70,10 +74,10 @@ class TestBinaryField(MongoDBTestCase): Attachment.drop_collection() binary_id = uuid.uuid4().bytes att = Attachment(id=binary_id).save() - self.assertEqual(1, Attachment.objects.count()) - self.assertEqual(1, Attachment.objects.filter(id=att.id).count()) + assert 1 == Attachment.objects.count() + assert 1 == Attachment.objects.filter(id=att.id).count() att.delete() - self.assertEqual(0, Attachment.objects.count()) + assert 0 == Attachment.objects.count() def test_primary_filter_by_binary_pk_as_str(self): class Attachment(Document): @@ -82,9 +86,9 @@ class TestBinaryField(MongoDBTestCase): Attachment.drop_collection() binary_id = uuid.uuid4().bytes att = Attachment(id=binary_id).save() - self.assertEqual(1, Attachment.objects.filter(id=binary_id).count()) + assert 1 == Attachment.objects.filter(id=binary_id).count() att.delete() - self.assertEqual(0, Attachment.objects.count()) + assert 0 == Attachment.objects.count() def test_match_querying_with_bytes(self): class MyDocument(Document): @@ -94,7 +98,7 @@ class TestBinaryField(MongoDBTestCase): doc = MyDocument(bin_field=BIN_VALUE).save() matched_doc = MyDocument.objects(bin_field=BIN_VALUE).first() - self.assertEqual(matched_doc.id, doc.id) + assert matched_doc.id == doc.id def test_match_querying_with_binary(self): class MyDocument(Document): @@ -105,7 +109,7 @@ class TestBinaryField(MongoDBTestCase): doc = MyDocument(bin_field=BIN_VALUE).save() matched_doc = MyDocument.objects(bin_field=Binary(BIN_VALUE)).first() - self.assertEqual(matched_doc.id, doc.id) + assert matched_doc.id == doc.id def test_modify_operation__set(self): """Ensures no regression of bug #1127""" @@ -119,11 +123,11 @@ class TestBinaryField(MongoDBTestCase): doc = MyDocument.objects(some_field="test").modify( upsert=True, new=True, set__bin_field=BIN_VALUE ) - self.assertEqual(doc.some_field, "test") + assert doc.some_field == "test" if six.PY3: - self.assertEqual(doc.bin_field, BIN_VALUE) + assert doc.bin_field == BIN_VALUE else: - self.assertEqual(doc.bin_field, Binary(BIN_VALUE)) + assert doc.bin_field == Binary(BIN_VALUE) def test_update_one(self): """Ensures no regression of bug #1127""" @@ -139,9 +143,9 @@ class TestBinaryField(MongoDBTestCase): n_updated = MyDocument.objects(bin_field=bin_data).update_one( bin_field=BIN_VALUE ) - self.assertEqual(n_updated, 1) + assert n_updated == 1 fetched = MyDocument.objects.with_id(doc.id) if six.PY3: - self.assertEqual(fetched.bin_field, BIN_VALUE) + assert fetched.bin_field == BIN_VALUE else: - self.assertEqual(fetched.bin_field, Binary(BIN_VALUE)) + assert fetched.bin_field == Binary(BIN_VALUE) diff --git a/tests/fields/test_boolean_field.py b/tests/fields/test_boolean_field.py index 22ebb6f7..b38b5ea4 100644 --- a/tests/fields/test_boolean_field.py +++ b/tests/fields/test_boolean_field.py @@ -2,6 +2,7 @@ from mongoengine import * from tests.utils import MongoDBTestCase, get_as_pymongo +import pytest class TestBooleanField(MongoDBTestCase): @@ -11,7 +12,7 @@ class TestBooleanField(MongoDBTestCase): person = Person(admin=True) person.save() - self.assertEqual(get_as_pymongo(person), {"_id": person.id, "admin": True}) + assert get_as_pymongo(person) == {"_id": person.id, "admin": True} def test_validation(self): """Ensure that invalid values cannot be assigned to boolean @@ -26,11 +27,14 @@ class TestBooleanField(MongoDBTestCase): person.validate() person.admin = 2 - self.assertRaises(ValidationError, person.validate) + with pytest.raises(ValidationError): + person.validate() person.admin = "Yes" - self.assertRaises(ValidationError, person.validate) + with pytest.raises(ValidationError): + person.validate() person.admin = "False" - self.assertRaises(ValidationError, person.validate) + with pytest.raises(ValidationError): + person.validate() def test_weirdness_constructor(self): """When attribute is set in contructor, it gets cast into a bool @@ -42,7 +46,7 @@ class TestBooleanField(MongoDBTestCase): admin = BooleanField() new_person = Person(admin="False") - self.assertTrue(new_person.admin) + assert new_person.admin new_person = Person(admin="0") - self.assertTrue(new_person.admin) + assert new_person.admin diff --git a/tests/fields/test_cached_reference_field.py b/tests/fields/test_cached_reference_field.py index 4e467587..e404aae0 100644 --- a/tests/fields/test_cached_reference_field.py +++ b/tests/fields/test_cached_reference_field.py @@ -4,6 +4,7 @@ from decimal import Decimal from mongoengine import * from tests.utils import MongoDBTestCase +import pytest class TestCachedReferenceField(MongoDBTestCase): @@ -46,29 +47,29 @@ class TestCachedReferenceField(MongoDBTestCase): a = Animal(name="Leopard", tag="heavy") a.save() - self.assertEqual(Animal._cached_reference_fields, [Ocorrence.animal]) + assert Animal._cached_reference_fields == [Ocorrence.animal] o = Ocorrence(person="teste", animal=a) o.save() p = Ocorrence(person="Wilson") p.save() - self.assertEqual(Ocorrence.objects(animal=None).count(), 1) + assert Ocorrence.objects(animal=None).count() == 1 - self.assertEqual(a.to_mongo(fields=["tag"]), {"tag": "heavy", "_id": a.pk}) + assert a.to_mongo(fields=["tag"]) == {"tag": "heavy", "_id": a.pk} - self.assertEqual(o.to_mongo()["animal"]["tag"], "heavy") + assert o.to_mongo()["animal"]["tag"] == "heavy" # counts Ocorrence(person="teste 2").save() Ocorrence(person="teste 3").save() count = Ocorrence.objects(animal__tag="heavy").count() - self.assertEqual(count, 1) + assert count == 1 ocorrence = Ocorrence.objects(animal__tag="heavy").first() - self.assertEqual(ocorrence.person, "teste") - self.assertIsInstance(ocorrence.animal, Animal) + assert ocorrence.person == "teste" + assert isinstance(ocorrence.animal, Animal) def test_with_decimal(self): class PersonAuto(Document): @@ -88,10 +89,11 @@ class TestCachedReferenceField(MongoDBTestCase): s = SocialTest(group="dev", person=p) s.save() - self.assertEqual( - SocialTest.objects._collection.find_one({"person.salary": 7000.00}), - {"_id": s.pk, "group": s.group, "person": {"_id": p.pk, "salary": 7000.00}}, - ) + assert SocialTest.objects._collection.find_one({"person.salary": 7000.00}) == { + "_id": s.pk, + "group": s.group, + "person": {"_id": p.pk, "salary": 7000.00}, + } def test_cached_reference_field_reference(self): class Group(Document): @@ -131,18 +133,15 @@ class TestCachedReferenceField(MongoDBTestCase): s2 = SocialData(obs="testing 321", person=p3, tags=["tag3", "tag4"]) s2.save() - self.assertEqual( - SocialData.objects._collection.find_one({"tags": "tag2"}), - { - "_id": s1.pk, - "obs": "testing 123", - "tags": ["tag1", "tag2"], - "person": {"_id": p1.pk, "group": g1.pk}, - }, - ) + assert SocialData.objects._collection.find_one({"tags": "tag2"}) == { + "_id": s1.pk, + "obs": "testing 123", + "tags": ["tag1", "tag2"], + "person": {"_id": p1.pk, "group": g1.pk}, + } - self.assertEqual(SocialData.objects(person__group=g2).count(), 1) - self.assertEqual(SocialData.objects(person__group=g2).first(), s2) + assert SocialData.objects(person__group=g2).count() == 1 + assert SocialData.objects(person__group=g2).first() == s2 def test_cached_reference_field_push_with_fields(self): class Product(Document): @@ -157,26 +156,20 @@ class TestCachedReferenceField(MongoDBTestCase): product1 = Product(name="abc").save() product2 = Product(name="def").save() basket = Basket(products=[product1]).save() - self.assertEqual( - Basket.objects._collection.find_one(), - { - "_id": basket.pk, - "products": [{"_id": product1.pk, "name": product1.name}], - }, - ) + assert Basket.objects._collection.find_one() == { + "_id": basket.pk, + "products": [{"_id": product1.pk, "name": product1.name}], + } # push to list basket.update(push__products=product2) basket.reload() - self.assertEqual( - Basket.objects._collection.find_one(), - { - "_id": basket.pk, - "products": [ - {"_id": product1.pk, "name": product1.name}, - {"_id": product2.pk, "name": product2.name}, - ], - }, - ) + assert Basket.objects._collection.find_one() == { + "_id": basket.pk, + "products": [ + {"_id": product1.pk, "name": product1.name}, + {"_id": product2.pk, "name": product2.name}, + ], + } def test_cached_reference_field_update_all(self): class Person(Document): @@ -194,37 +187,31 @@ class TestCachedReferenceField(MongoDBTestCase): a2.save() a2 = Person.objects.with_id(a2.id) - self.assertEqual(a2.father.tp, a1.tp) + assert a2.father.tp == a1.tp - self.assertEqual( - dict(a2.to_mongo()), - { - "_id": a2.pk, - "name": u"Wilson Junior", - "tp": u"pf", - "father": {"_id": a1.pk, "tp": u"pj"}, - }, - ) + assert dict(a2.to_mongo()) == { + "_id": a2.pk, + "name": u"Wilson Junior", + "tp": u"pf", + "father": {"_id": a1.pk, "tp": u"pj"}, + } - self.assertEqual(Person.objects(father=a1)._query, {"father._id": a1.pk}) - self.assertEqual(Person.objects(father=a1).count(), 1) + assert Person.objects(father=a1)._query == {"father._id": a1.pk} + assert Person.objects(father=a1).count() == 1 Person.objects.update(set__tp="pf") Person.father.sync_all() a2.reload() - self.assertEqual( - dict(a2.to_mongo()), - { - "_id": a2.pk, - "name": u"Wilson Junior", - "tp": u"pf", - "father": {"_id": a1.pk, "tp": u"pf"}, - }, - ) + assert dict(a2.to_mongo()) == { + "_id": a2.pk, + "name": u"Wilson Junior", + "tp": u"pf", + "father": {"_id": a1.pk, "tp": u"pf"}, + } def test_cached_reference_fields_on_embedded_documents(self): - with self.assertRaises(InvalidDocumentError): + with pytest.raises(InvalidDocumentError): class Test(Document): name = StringField() @@ -255,15 +242,12 @@ class TestCachedReferenceField(MongoDBTestCase): a1.save() a2.reload() - self.assertEqual( - dict(a2.to_mongo()), - { - "_id": a2.pk, - "name": "Wilson Junior", - "tp": "pf", - "father": {"_id": a1.pk, "tp": "pf"}, - }, - ) + assert dict(a2.to_mongo()) == { + "_id": a2.pk, + "name": "Wilson Junior", + "tp": "pf", + "father": {"_id": a1.pk, "tp": "pf"}, + } def test_cached_reference_auto_sync_disabled(self): class Persone(Document): @@ -284,15 +268,12 @@ class TestCachedReferenceField(MongoDBTestCase): a1.tp = "pf" a1.save() - self.assertEqual( - Persone.objects._collection.find_one({"_id": a2.pk}), - { - "_id": a2.pk, - "name": "Wilson Junior", - "tp": "pf", - "father": {"_id": a1.pk, "tp": "pj"}, - }, - ) + assert Persone.objects._collection.find_one({"_id": a2.pk}) == { + "_id": a2.pk, + "name": "Wilson Junior", + "tp": "pf", + "father": {"_id": a1.pk, "tp": "pj"}, + } def test_cached_reference_embedded_fields(self): class Owner(EmbeddedDocument): @@ -320,28 +301,29 @@ class TestCachedReferenceField(MongoDBTestCase): o = Ocorrence(person="teste", animal=a) o.save() - self.assertEqual( - dict(a.to_mongo(fields=["tag", "owner.tp"])), - {"_id": a.pk, "tag": "heavy", "owner": {"t": "u"}}, - ) - self.assertEqual(o.to_mongo()["animal"]["tag"], "heavy") - self.assertEqual(o.to_mongo()["animal"]["owner"]["t"], "u") + assert dict(a.to_mongo(fields=["tag", "owner.tp"])) == { + "_id": a.pk, + "tag": "heavy", + "owner": {"t": "u"}, + } + assert o.to_mongo()["animal"]["tag"] == "heavy" + assert o.to_mongo()["animal"]["owner"]["t"] == "u" # Check to_mongo with fields - self.assertNotIn("animal", o.to_mongo(fields=["person"])) + assert "animal" not in o.to_mongo(fields=["person"]) # counts Ocorrence(person="teste 2").save() Ocorrence(person="teste 3").save() count = Ocorrence.objects(animal__tag="heavy", animal__owner__tp="u").count() - self.assertEqual(count, 1) + assert count == 1 ocorrence = Ocorrence.objects( animal__tag="heavy", animal__owner__tp="u" ).first() - self.assertEqual(ocorrence.person, "teste") - self.assertIsInstance(ocorrence.animal, Animal) + assert ocorrence.person == "teste" + assert isinstance(ocorrence.animal, Animal) def test_cached_reference_embedded_list_fields(self): class Owner(EmbeddedDocument): @@ -370,13 +352,14 @@ class TestCachedReferenceField(MongoDBTestCase): o = Ocorrence(person="teste 2", animal=a) o.save() - self.assertEqual( - dict(a.to_mongo(fields=["tag", "owner.tags"])), - {"_id": a.pk, "tag": "heavy", "owner": {"tags": ["cool", "funny"]}}, - ) + assert dict(a.to_mongo(fields=["tag", "owner.tags"])) == { + "_id": a.pk, + "tag": "heavy", + "owner": {"tags": ["cool", "funny"]}, + } - self.assertEqual(o.to_mongo()["animal"]["tag"], "heavy") - self.assertEqual(o.to_mongo()["animal"]["owner"]["tags"], ["cool", "funny"]) + assert o.to_mongo()["animal"]["tag"] == "heavy" + assert o.to_mongo()["animal"]["owner"]["tags"] == ["cool", "funny"] # counts Ocorrence(person="teste 2").save() @@ -385,10 +368,10 @@ class TestCachedReferenceField(MongoDBTestCase): query = Ocorrence.objects( animal__tag="heavy", animal__owner__tags="cool" )._query - self.assertEqual(query, {"animal.owner.tags": "cool", "animal.tag": "heavy"}) + assert query == {"animal.owner.tags": "cool", "animal.tag": "heavy"} ocorrence = Ocorrence.objects( animal__tag="heavy", animal__owner__tags="cool" ).first() - self.assertEqual(ocorrence.person, "teste 2") - self.assertIsInstance(ocorrence.animal, Animal) + assert ocorrence.person == "teste 2" + assert isinstance(ocorrence.animal, Animal) diff --git a/tests/fields/test_complex_datetime_field.py b/tests/fields/test_complex_datetime_field.py index 611c0ff8..f0a6b96e 100644 --- a/tests/fields/test_complex_datetime_field.py +++ b/tests/fields/test_complex_datetime_field.py @@ -28,7 +28,7 @@ class ComplexDateTimeFieldTest(MongoDBTestCase): log.date = d1 log.save() log.reload() - self.assertEqual(log.date, d1) + assert log.date == d1 # Post UTC - microseconds are rounded (down) nearest millisecond - with # default datetimefields @@ -36,7 +36,7 @@ class ComplexDateTimeFieldTest(MongoDBTestCase): log.date = d1 log.save() log.reload() - self.assertEqual(log.date, d1) + assert log.date == d1 # Pre UTC dates microseconds below 1000 are dropped - with default # datetimefields @@ -44,7 +44,7 @@ class ComplexDateTimeFieldTest(MongoDBTestCase): log.date = d1 log.save() log.reload() - self.assertEqual(log.date, d1) + assert log.date == d1 # Pre UTC microseconds above 1000 is wonky - with default datetimefields # log.date has an invalid microsecond value so I can't construct @@ -54,9 +54,9 @@ class ComplexDateTimeFieldTest(MongoDBTestCase): log.date = d1 log.save() log.reload() - self.assertEqual(log.date, d1) + assert log.date == d1 log1 = LogEntry.objects.get(date=d1) - self.assertEqual(log, log1) + assert log == log1 # Test string padding microsecond = map(int, [math.pow(10, x) for x in range(6)]) @@ -64,7 +64,7 @@ class ComplexDateTimeFieldTest(MongoDBTestCase): for values in itertools.product([2014], mm, dd, hh, ii, ss, microsecond): stored = LogEntry(date=datetime.datetime(*values)).to_mongo()["date"] - self.assertTrue( + assert ( re.match("^\d{4},\d{2},\d{2},\d{2},\d{2},\d{2},\d{6}$", stored) is not None ) @@ -73,7 +73,7 @@ class ComplexDateTimeFieldTest(MongoDBTestCase): stored = LogEntry(date_with_dots=datetime.datetime(2014, 1, 1)).to_mongo()[ "date_with_dots" ] - self.assertTrue( + assert ( re.match("^\d{4}.\d{2}.\d{2}.\d{2}.\d{2}.\d{2}.\d{6}$", stored) is not None ) @@ -93,40 +93,40 @@ class ComplexDateTimeFieldTest(MongoDBTestCase): log.save() log1 = LogEntry.objects.get(date=d1) - self.assertEqual(log, log1) + assert log == log1 # create extra 59 log entries for a total of 60 for i in range(1951, 2010): d = datetime.datetime(i, 1, 1, 0, 0, 1, 999) LogEntry(date=d).save() - self.assertEqual(LogEntry.objects.count(), 60) + assert LogEntry.objects.count() == 60 # Test ordering logs = LogEntry.objects.order_by("date") i = 0 while i < 59: - self.assertTrue(logs[i].date <= logs[i + 1].date) + assert logs[i].date <= logs[i + 1].date i += 1 logs = LogEntry.objects.order_by("-date") i = 0 while i < 59: - self.assertTrue(logs[i].date >= logs[i + 1].date) + assert logs[i].date >= logs[i + 1].date i += 1 # Test searching logs = LogEntry.objects.filter(date__gte=datetime.datetime(1980, 1, 1)) - self.assertEqual(logs.count(), 30) + assert logs.count() == 30 logs = LogEntry.objects.filter(date__lte=datetime.datetime(1980, 1, 1)) - self.assertEqual(logs.count(), 30) + assert logs.count() == 30 logs = LogEntry.objects.filter( date__lte=datetime.datetime(2011, 1, 1), date__gte=datetime.datetime(2000, 1, 1), ) - self.assertEqual(logs.count(), 10) + assert logs.count() == 10 LogEntry.drop_collection() @@ -137,17 +137,17 @@ class ComplexDateTimeFieldTest(MongoDBTestCase): logs = list(LogEntry.objects.order_by("date")) for next_idx, log in enumerate(logs[:-1], start=1): next_log = logs[next_idx] - self.assertTrue(log.date < next_log.date) + assert log.date < next_log.date logs = list(LogEntry.objects.order_by("-date")) for next_idx, log in enumerate(logs[:-1], start=1): next_log = logs[next_idx] - self.assertTrue(log.date > next_log.date) + assert log.date > next_log.date logs = LogEntry.objects.filter( date__lte=datetime.datetime(2015, 1, 1, 0, 0, 0, 10000) ) - self.assertEqual(logs.count(), 4) + assert logs.count() == 4 def test_no_default_value(self): class Log(Document): @@ -156,11 +156,11 @@ class ComplexDateTimeFieldTest(MongoDBTestCase): Log.drop_collection() log = Log() - self.assertIsNone(log.timestamp) + assert log.timestamp is None log.save() fetched_log = Log.objects.with_id(log.id) - self.assertIsNone(fetched_log.timestamp) + assert fetched_log.timestamp is None def test_default_static_value(self): NOW = datetime.datetime.utcnow() @@ -171,11 +171,11 @@ class ComplexDateTimeFieldTest(MongoDBTestCase): Log.drop_collection() log = Log() - self.assertEqual(log.timestamp, NOW) + assert log.timestamp == NOW log.save() fetched_log = Log.objects.with_id(log.id) - self.assertEqual(fetched_log.timestamp, NOW) + assert fetched_log.timestamp == NOW def test_default_callable(self): NOW = datetime.datetime.utcnow() @@ -186,8 +186,8 @@ class ComplexDateTimeFieldTest(MongoDBTestCase): Log.drop_collection() log = Log() - self.assertGreaterEqual(log.timestamp, NOW) + assert log.timestamp >= NOW log.save() fetched_log = Log.objects.with_id(log.id) - self.assertGreaterEqual(fetched_log.timestamp, NOW) + assert fetched_log.timestamp >= NOW diff --git a/tests/fields/test_date_field.py b/tests/fields/test_date_field.py index da572134..46fa4f0f 100644 --- a/tests/fields/test_date_field.py +++ b/tests/fields/test_date_field.py @@ -10,6 +10,7 @@ except ImportError: from mongoengine import * from tests.utils import MongoDBTestCase +import pytest class TestDateField(MongoDBTestCase): @@ -23,7 +24,8 @@ class TestDateField(MongoDBTestCase): dt = DateField() md = MyDoc(dt="") - self.assertRaises(ValidationError, md.save) + with pytest.raises(ValidationError): + md.save() def test_date_from_whitespace_string(self): """ @@ -35,7 +37,8 @@ class TestDateField(MongoDBTestCase): dt = DateField() md = MyDoc(dt=" ") - self.assertRaises(ValidationError, md.save) + with pytest.raises(ValidationError): + md.save() def test_default_values_today(self): """Ensure that default field values are used when creating @@ -47,9 +50,9 @@ class TestDateField(MongoDBTestCase): person = Person() person.validate() - self.assertEqual(person.day, person.day) - self.assertEqual(person.day, datetime.date.today()) - self.assertEqual(person._data["day"], person.day) + assert person.day == person.day + assert person.day == datetime.date.today() + assert person._data["day"] == person.day def test_date(self): """Tests showing pymongo date fields @@ -67,7 +70,7 @@ class TestDateField(MongoDBTestCase): log.date = datetime.date.today() log.save() log.reload() - self.assertEqual(log.date, datetime.date.today()) + assert log.date == datetime.date.today() d1 = datetime.datetime(1970, 1, 1, 0, 0, 1, 999) d2 = datetime.datetime(1970, 1, 1, 0, 0, 1) @@ -75,16 +78,16 @@ class TestDateField(MongoDBTestCase): log.date = d1 log.save() log.reload() - self.assertEqual(log.date, d1.date()) - self.assertEqual(log.date, d2.date()) + assert log.date == d1.date() + assert log.date == d2.date() d1 = datetime.datetime(1970, 1, 1, 0, 0, 1, 9999) d2 = datetime.datetime(1970, 1, 1, 0, 0, 1, 9000) log.date = d1 log.save() log.reload() - self.assertEqual(log.date, d1.date()) - self.assertEqual(log.date, d2.date()) + assert log.date == d1.date() + assert log.date == d2.date() if not six.PY3: # Pre UTC dates microseconds below 1000 are dropped @@ -94,8 +97,8 @@ class TestDateField(MongoDBTestCase): log.date = d1 log.save() log.reload() - self.assertEqual(log.date, d1.date()) - self.assertEqual(log.date, d2.date()) + assert log.date == d1.date() + assert log.date == d2.date() def test_regular_usage(self): """Tests for regular datetime fields""" @@ -113,35 +116,35 @@ class TestDateField(MongoDBTestCase): for query in (d1, d1.isoformat(" ")): log1 = LogEntry.objects.get(date=query) - self.assertEqual(log, log1) + assert log == log1 if dateutil: log1 = LogEntry.objects.get(date=d1.isoformat("T")) - self.assertEqual(log, log1) + assert log == log1 # create additional 19 log entries for a total of 20 for i in range(1971, 1990): d = datetime.datetime(i, 1, 1, 0, 0, 1) LogEntry(date=d).save() - self.assertEqual(LogEntry.objects.count(), 20) + assert LogEntry.objects.count() == 20 # Test ordering logs = LogEntry.objects.order_by("date") i = 0 while i < 19: - self.assertTrue(logs[i].date <= logs[i + 1].date) + assert logs[i].date <= logs[i + 1].date i += 1 logs = LogEntry.objects.order_by("-date") i = 0 while i < 19: - self.assertTrue(logs[i].date >= logs[i + 1].date) + assert logs[i].date >= logs[i + 1].date i += 1 # Test searching logs = LogEntry.objects.filter(date__gte=datetime.datetime(1980, 1, 1)) - self.assertEqual(logs.count(), 10) + assert logs.count() == 10 def test_validation(self): """Ensure that invalid values cannot be assigned to datetime @@ -166,6 +169,8 @@ class TestDateField(MongoDBTestCase): log.validate() log.time = -1 - self.assertRaises(ValidationError, log.validate) + with pytest.raises(ValidationError): + log.validate() log.time = "ABC" - self.assertRaises(ValidationError, log.validate) + with pytest.raises(ValidationError): + log.validate() diff --git a/tests/fields/test_datetime_field.py b/tests/fields/test_datetime_field.py index c911390a..8db491c6 100644 --- a/tests/fields/test_datetime_field.py +++ b/tests/fields/test_datetime_field.py @@ -11,6 +11,7 @@ from mongoengine import * from mongoengine import connection from tests.utils import MongoDBTestCase +import pytest class TestDateTimeField(MongoDBTestCase): @@ -24,7 +25,8 @@ class TestDateTimeField(MongoDBTestCase): dt = DateTimeField() md = MyDoc(dt="") - self.assertRaises(ValidationError, md.save) + with pytest.raises(ValidationError): + md.save() def test_datetime_from_whitespace_string(self): """ @@ -36,7 +38,8 @@ class TestDateTimeField(MongoDBTestCase): dt = DateTimeField() md = MyDoc(dt=" ") - self.assertRaises(ValidationError, md.save) + with pytest.raises(ValidationError): + md.save() def test_default_value_utcnow(self): """Ensure that default field values are used when creating @@ -50,11 +53,9 @@ class TestDateTimeField(MongoDBTestCase): person = Person() person.validate() person_created_t0 = person.created - self.assertLess(person.created - utcnow, dt.timedelta(seconds=1)) - self.assertEqual( - person_created_t0, person.created - ) # make sure it does not change - self.assertEqual(person._data["created"], person.created) + assert person.created - utcnow < dt.timedelta(seconds=1) + assert person_created_t0 == person.created # make sure it does not change + assert person._data["created"] == person.created def test_handling_microseconds(self): """Tests showing pymongo datetime fields handling of microseconds. @@ -74,7 +75,7 @@ class TestDateTimeField(MongoDBTestCase): log.date = dt.date.today() log.save() log.reload() - self.assertEqual(log.date.date(), dt.date.today()) + assert log.date.date() == dt.date.today() # Post UTC - microseconds are rounded (down) nearest millisecond and # dropped @@ -84,8 +85,8 @@ class TestDateTimeField(MongoDBTestCase): log.date = d1 log.save() log.reload() - self.assertNotEqual(log.date, d1) - self.assertEqual(log.date, d2) + assert log.date != d1 + assert log.date == d2 # Post UTC - microseconds are rounded (down) nearest millisecond d1 = dt.datetime(1970, 1, 1, 0, 0, 1, 9999) @@ -93,8 +94,8 @@ class TestDateTimeField(MongoDBTestCase): log.date = d1 log.save() log.reload() - self.assertNotEqual(log.date, d1) - self.assertEqual(log.date, d2) + assert log.date != d1 + assert log.date == d2 if not six.PY3: # Pre UTC dates microseconds below 1000 are dropped @@ -104,8 +105,8 @@ class TestDateTimeField(MongoDBTestCase): log.date = d1 log.save() log.reload() - self.assertNotEqual(log.date, d1) - self.assertEqual(log.date, d2) + assert log.date != d1 + assert log.date == d2 def test_regular_usage(self): """Tests for regular datetime fields""" @@ -123,43 +124,43 @@ class TestDateTimeField(MongoDBTestCase): for query in (d1, d1.isoformat(" ")): log1 = LogEntry.objects.get(date=query) - self.assertEqual(log, log1) + assert log == log1 if dateutil: log1 = LogEntry.objects.get(date=d1.isoformat("T")) - self.assertEqual(log, log1) + assert log == log1 # create additional 19 log entries for a total of 20 for i in range(1971, 1990): d = dt.datetime(i, 1, 1, 0, 0, 1) LogEntry(date=d).save() - self.assertEqual(LogEntry.objects.count(), 20) + assert LogEntry.objects.count() == 20 # Test ordering logs = LogEntry.objects.order_by("date") i = 0 while i < 19: - self.assertTrue(logs[i].date <= logs[i + 1].date) + assert logs[i].date <= logs[i + 1].date i += 1 logs = LogEntry.objects.order_by("-date") i = 0 while i < 19: - self.assertTrue(logs[i].date >= logs[i + 1].date) + assert logs[i].date >= logs[i + 1].date i += 1 # Test searching logs = LogEntry.objects.filter(date__gte=dt.datetime(1980, 1, 1)) - self.assertEqual(logs.count(), 10) + assert logs.count() == 10 logs = LogEntry.objects.filter(date__lte=dt.datetime(1980, 1, 1)) - self.assertEqual(logs.count(), 10) + assert logs.count() == 10 logs = LogEntry.objects.filter( date__lte=dt.datetime(1980, 1, 1), date__gte=dt.datetime(1975, 1, 1) ) - self.assertEqual(logs.count(), 5) + assert logs.count() == 5 def test_datetime_validation(self): """Ensure that invalid values cannot be assigned to datetime @@ -187,15 +188,20 @@ class TestDateTimeField(MongoDBTestCase): log.validate() log.time = -1 - self.assertRaises(ValidationError, log.validate) + with pytest.raises(ValidationError): + log.validate() log.time = "ABC" - self.assertRaises(ValidationError, log.validate) + with pytest.raises(ValidationError): + log.validate() log.time = "2019-05-16 21:GARBAGE:12" - self.assertRaises(ValidationError, log.validate) + with pytest.raises(ValidationError): + log.validate() log.time = "2019-05-16 21:42:57.GARBAGE" - self.assertRaises(ValidationError, log.validate) + with pytest.raises(ValidationError): + log.validate() log.time = "2019-05-16 21:42:57.123.456" - self.assertRaises(ValidationError, log.validate) + with pytest.raises(ValidationError): + log.validate() def test_parse_datetime_as_str(self): class DTDoc(Document): @@ -206,15 +212,16 @@ class TestDateTimeField(MongoDBTestCase): # make sure that passing a parsable datetime works dtd = DTDoc() dtd.date = date_str - self.assertIsInstance(dtd.date, six.string_types) + assert isinstance(dtd.date, six.string_types) dtd.save() dtd.reload() - self.assertIsInstance(dtd.date, dt.datetime) - self.assertEqual(str(dtd.date), date_str) + assert isinstance(dtd.date, dt.datetime) + assert str(dtd.date) == date_str dtd.date = "January 1st, 9999999999" - self.assertRaises(ValidationError, dtd.validate) + with pytest.raises(ValidationError): + dtd.validate() class TestDateTimeTzAware(MongoDBTestCase): @@ -235,4 +242,4 @@ class TestDateTimeTzAware(MongoDBTestCase): log = LogEntry.objects.first() log.time = dt.datetime(2013, 1, 1, 0, 0, 0) - self.assertEqual(["time"], log._changed_fields) + assert ["time"] == log._changed_fields diff --git a/tests/fields/test_decimal_field.py b/tests/fields/test_decimal_field.py index 30b7e5ea..b5b95363 100644 --- a/tests/fields/test_decimal_field.py +++ b/tests/fields/test_decimal_field.py @@ -4,6 +4,7 @@ from decimal import Decimal from mongoengine import * from tests.utils import MongoDBTestCase +import pytest class TestDecimalField(MongoDBTestCase): @@ -18,21 +19,26 @@ class TestDecimalField(MongoDBTestCase): Person(height=Decimal("1.89")).save() person = Person.objects.first() - self.assertEqual(person.height, Decimal("1.89")) + assert person.height == Decimal("1.89") person.height = "2.0" person.save() person.height = 0.01 - self.assertRaises(ValidationError, person.validate) + with pytest.raises(ValidationError): + person.validate() person.height = Decimal("0.01") - self.assertRaises(ValidationError, person.validate) + with pytest.raises(ValidationError): + person.validate() person.height = Decimal("4.0") - self.assertRaises(ValidationError, person.validate) + with pytest.raises(ValidationError): + person.validate() person.height = "something invalid" - self.assertRaises(ValidationError, person.validate) + with pytest.raises(ValidationError): + person.validate() person_2 = Person(height="something invalid") - self.assertRaises(ValidationError, person_2.validate) + with pytest.raises(ValidationError): + person_2.validate() def test_comparison(self): class Person(Document): @@ -45,11 +51,11 @@ class TestDecimalField(MongoDBTestCase): Person(money=8).save() Person(money=10).save() - self.assertEqual(2, Person.objects(money__gt=Decimal("7")).count()) - self.assertEqual(2, Person.objects(money__gt=7).count()) - self.assertEqual(2, Person.objects(money__gt="7").count()) + assert 2 == Person.objects(money__gt=Decimal("7")).count() + assert 2 == Person.objects(money__gt=7).count() + assert 2 == Person.objects(money__gt="7").count() - self.assertEqual(3, Person.objects(money__gte="7").count()) + assert 3 == Person.objects(money__gte="7").count() def test_storage(self): class Person(Document): @@ -87,7 +93,7 @@ class TestDecimalField(MongoDBTestCase): ] expected.extend(expected) actual = list(Person.objects.exclude("id").as_pymongo()) - self.assertEqual(expected, actual) + assert expected == actual # How it comes out locally expected = [ @@ -101,4 +107,4 @@ class TestDecimalField(MongoDBTestCase): expected.extend(expected) for field_name in ["float_value", "string_value"]: actual = list(Person.objects().scalar(field_name)) - self.assertEqual(expected, actual) + assert expected == actual diff --git a/tests/fields/test_dict_field.py b/tests/fields/test_dict_field.py index 07bab85b..56df682f 100644 --- a/tests/fields/test_dict_field.py +++ b/tests/fields/test_dict_field.py @@ -3,6 +3,7 @@ from mongoengine import * from mongoengine.base import BaseDict from tests.utils import MongoDBTestCase, get_as_pymongo +import pytest class TestDictField(MongoDBTestCase): @@ -14,7 +15,7 @@ class TestDictField(MongoDBTestCase): info = {"testkey": "testvalue"} post = BlogPost(info=info).save() - self.assertEqual(get_as_pymongo(post), {"_id": post.id, "info": info}) + assert get_as_pymongo(post) == {"_id": post.id, "info": info} def test_general_things(self): """Ensure that dict types work as expected.""" @@ -26,25 +27,32 @@ class TestDictField(MongoDBTestCase): post = BlogPost() post.info = "my post" - self.assertRaises(ValidationError, post.validate) + with pytest.raises(ValidationError): + post.validate() post.info = ["test", "test"] - self.assertRaises(ValidationError, post.validate) + with pytest.raises(ValidationError): + post.validate() post.info = {"$title": "test"} - self.assertRaises(ValidationError, post.validate) + with pytest.raises(ValidationError): + post.validate() post.info = {"nested": {"$title": "test"}} - self.assertRaises(ValidationError, post.validate) + with pytest.raises(ValidationError): + post.validate() post.info = {"the.title": "test"} - self.assertRaises(ValidationError, post.validate) + with pytest.raises(ValidationError): + post.validate() post.info = {"nested": {"the.title": "test"}} - self.assertRaises(ValidationError, post.validate) + with pytest.raises(ValidationError): + post.validate() post.info = {1: "test"} - self.assertRaises(ValidationError, post.validate) + with pytest.raises(ValidationError): + post.validate() post.info = {"title": "test"} post.save() @@ -61,33 +69,27 @@ class TestDictField(MongoDBTestCase): post.info = {"details": {"test": 3}} post.save() - self.assertEqual(BlogPost.objects.count(), 4) - self.assertEqual(BlogPost.objects.filter(info__title__exact="test").count(), 1) - self.assertEqual( - BlogPost.objects.filter(info__details__test__exact="test").count(), 1 - ) + assert BlogPost.objects.count() == 4 + assert BlogPost.objects.filter(info__title__exact="test").count() == 1 + assert BlogPost.objects.filter(info__details__test__exact="test").count() == 1 post = BlogPost.objects.filter(info__title__exact="dollar_sign").first() - self.assertIn("te$t", post["info"]["details"]) + assert "te$t" in post["info"]["details"] # Confirm handles non strings or non existing keys - self.assertEqual( - BlogPost.objects.filter(info__details__test__exact=5).count(), 0 - ) - self.assertEqual( - BlogPost.objects.filter(info__made_up__test__exact="test").count(), 0 - ) + assert BlogPost.objects.filter(info__details__test__exact=5).count() == 0 + assert BlogPost.objects.filter(info__made_up__test__exact="test").count() == 0 post = BlogPost.objects.create(info={"title": "original"}) post.info.update({"title": "updated"}) post.save() post.reload() - self.assertEqual("updated", post.info["title"]) + assert "updated" == post.info["title"] post.info.setdefault("authors", []) post.save() post.reload() - self.assertEqual([], post.info["authors"]) + assert [] == post.info["authors"] def test_dictfield_dump_document(self): """Ensure a DictField can handle another document's dump.""" @@ -114,10 +116,8 @@ class TestDictField(MongoDBTestCase): ).save() doc = Doc(field=to_embed.to_mongo().to_dict()) doc.save() - self.assertIsInstance(doc.field, dict) - self.assertEqual( - doc.field, {"_id": 2, "recursive": {"_id": 1, "recursive": {}}} - ) + assert isinstance(doc.field, dict) + assert doc.field == {"_id": 2, "recursive": {"_id": 1, "recursive": {}}} # Same thing with a Document with a _cls field to_embed_recursive = ToEmbedChild(id=1).save() to_embed_child = ToEmbedChild( @@ -125,7 +125,7 @@ class TestDictField(MongoDBTestCase): ).save() doc = Doc(field=to_embed_child.to_mongo().to_dict()) doc.save() - self.assertIsInstance(doc.field, dict) + assert isinstance(doc.field, dict) expected = { "_id": 2, "_cls": "ToEmbedParent.ToEmbedChild", @@ -135,7 +135,7 @@ class TestDictField(MongoDBTestCase): "recursive": {}, }, } - self.assertEqual(doc.field, expected) + assert doc.field == expected def test_dictfield_strict(self): """Ensure that dict field handles validation if provided a strict field type.""" @@ -150,7 +150,7 @@ class TestDictField(MongoDBTestCase): e.save() # try creating an invalid mapping - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): e.mapping["somestring"] = "abc" e.save() @@ -184,22 +184,21 @@ class TestDictField(MongoDBTestCase): e.save() e2 = Simple.objects.get(id=e.id) - self.assertIsInstance(e2.mapping["somestring"], StringSetting) - self.assertIsInstance(e2.mapping["someint"], IntegerSetting) + assert isinstance(e2.mapping["somestring"], StringSetting) + assert isinstance(e2.mapping["someint"], IntegerSetting) # Test querying - self.assertEqual(Simple.objects.filter(mapping__someint__value=42).count(), 1) - self.assertEqual( - Simple.objects.filter(mapping__nested_dict__number=1).count(), 1 + assert Simple.objects.filter(mapping__someint__value=42).count() == 1 + assert Simple.objects.filter(mapping__nested_dict__number=1).count() == 1 + assert ( + Simple.objects.filter(mapping__nested_dict__complex__value=42).count() == 1 ) - self.assertEqual( - Simple.objects.filter(mapping__nested_dict__complex__value=42).count(), 1 + assert ( + Simple.objects.filter(mapping__nested_dict__list__0__value=42).count() == 1 ) - self.assertEqual( - Simple.objects.filter(mapping__nested_dict__list__0__value=42).count(), 1 - ) - self.assertEqual( - Simple.objects.filter(mapping__nested_dict__list__1__value="foo").count(), 1 + assert ( + Simple.objects.filter(mapping__nested_dict__list__1__value="foo").count() + == 1 ) # Confirm can update @@ -207,11 +206,13 @@ class TestDictField(MongoDBTestCase): Simple.objects().update( set__mapping__nested_dict__list__1=StringSetting(value="Boo") ) - self.assertEqual( - Simple.objects.filter(mapping__nested_dict__list__1__value="foo").count(), 0 + assert ( + Simple.objects.filter(mapping__nested_dict__list__1__value="foo").count() + == 0 ) - self.assertEqual( - Simple.objects.filter(mapping__nested_dict__list__1__value="Boo").count(), 1 + assert ( + Simple.objects.filter(mapping__nested_dict__list__1__value="Boo").count() + == 1 ) def test_push_dict(self): @@ -221,12 +222,12 @@ class TestDictField(MongoDBTestCase): doc = MyModel(events=[{"a": 1}]).save() raw_doc = get_as_pymongo(doc) expected_raw_doc = {"_id": doc.id, "events": [{"a": 1}]} - self.assertEqual(raw_doc, expected_raw_doc) + assert raw_doc == expected_raw_doc MyModel.objects(id=doc.id).update(push__events={}) raw_doc = get_as_pymongo(doc) expected_raw_doc = {"_id": doc.id, "events": [{"a": 1}, {}]} - self.assertEqual(raw_doc, expected_raw_doc) + assert raw_doc == expected_raw_doc def test_ensure_unique_default_instances(self): """Ensure that every field has it's own unique default instance.""" @@ -239,8 +240,8 @@ class TestDictField(MongoDBTestCase): d1.data["foo"] = "bar" d1.data2["foo"] = "bar" d2 = D() - self.assertEqual(d2.data, {}) - self.assertEqual(d2.data2, {}) + assert d2.data == {} + assert d2.data2 == {} def test_dict_field_invalid_dict_value(self): class DictFieldTest(Document): @@ -250,11 +251,13 @@ class TestDictField(MongoDBTestCase): test = DictFieldTest(dictionary=None) test.dictionary # Just access to test getter - self.assertRaises(ValidationError, test.validate) + with pytest.raises(ValidationError): + test.validate() test = DictFieldTest(dictionary=False) test.dictionary # Just access to test getter - self.assertRaises(ValidationError, test.validate) + with pytest.raises(ValidationError): + test.validate() def test_dict_field_raises_validation_error_if_wrongly_assign_embedded_doc(self): class DictFieldTest(Document): @@ -267,12 +270,10 @@ class TestDictField(MongoDBTestCase): embed = Embedded(name="garbage") doc = DictFieldTest(dictionary=embed) - with self.assertRaises(ValidationError) as ctx_err: + with pytest.raises(ValidationError) as ctx_err: doc.validate() - self.assertIn("'dictionary'", str(ctx_err.exception)) - self.assertIn( - "Only dictionaries may be used in a DictField", str(ctx_err.exception) - ) + assert "'dictionary'" in str(ctx_err.exception) + assert "Only dictionaries may be used in a DictField" in str(ctx_err.exception) def test_atomic_update_dict_field(self): """Ensure that the entire DictField can be atomically updated.""" @@ -287,11 +288,11 @@ class TestDictField(MongoDBTestCase): e.save() e.update(set__mapping={"ints": [3, 4]}) e.reload() - self.assertEqual(BaseDict, type(e.mapping)) - self.assertEqual({"ints": [3, 4]}, e.mapping) + assert BaseDict == type(e.mapping) + assert {"ints": [3, 4]} == e.mapping # try creating an invalid mapping - with self.assertRaises(ValueError): + with pytest.raises(ValueError): e.update(set__mapping={"somestrings": ["foo", "bar"]}) def test_dictfield_with_referencefield_complex_nesting_cases(self): @@ -329,13 +330,13 @@ class TestDictField(MongoDBTestCase): e.save() s = Simple.objects.first() - self.assertIsInstance(s.mapping0["someint"], Doc) - self.assertIsInstance(s.mapping1["someint"], Doc) - self.assertIsInstance(s.mapping2["someint"][0], Doc) - self.assertIsInstance(s.mapping3["someint"][0], Doc) - self.assertIsInstance(s.mapping4["someint"]["d"], Doc) - self.assertIsInstance(s.mapping5["someint"]["d"], Doc) - self.assertIsInstance(s.mapping6["someint"][0]["d"], Doc) - self.assertIsInstance(s.mapping7["someint"][0]["d"], Doc) - self.assertIsInstance(s.mapping8["someint"][0]["d"][0], Doc) - self.assertIsInstance(s.mapping9["someint"][0]["d"][0], Doc) + assert isinstance(s.mapping0["someint"], Doc) + assert isinstance(s.mapping1["someint"], Doc) + assert isinstance(s.mapping2["someint"][0], Doc) + assert isinstance(s.mapping3["someint"][0], Doc) + assert isinstance(s.mapping4["someint"]["d"], Doc) + assert isinstance(s.mapping5["someint"]["d"], Doc) + assert isinstance(s.mapping6["someint"][0]["d"], Doc) + assert isinstance(s.mapping7["someint"][0]["d"], Doc) + assert isinstance(s.mapping8["someint"][0]["d"][0], Doc) + assert isinstance(s.mapping9["someint"][0]["d"][0], Doc) diff --git a/tests/fields/test_email_field.py b/tests/fields/test_email_field.py index 06ec5151..b8d3d169 100644 --- a/tests/fields/test_email_field.py +++ b/tests/fields/test_email_field.py @@ -5,6 +5,7 @@ from unittest import SkipTest from mongoengine import * from tests.utils import MongoDBTestCase +import pytest class TestEmailField(MongoDBTestCase): @@ -27,7 +28,8 @@ class TestEmailField(MongoDBTestCase): user.validate() user = User(email="ross@example.com.") - self.assertRaises(ValidationError, user.validate) + with pytest.raises(ValidationError): + user.validate() # unicode domain user = User(email=u"user@пример.рф") @@ -35,11 +37,13 @@ class TestEmailField(MongoDBTestCase): # invalid unicode domain user = User(email=u"user@пример") - self.assertRaises(ValidationError, user.validate) + with pytest.raises(ValidationError): + user.validate() # invalid data type user = User(email=123) - self.assertRaises(ValidationError, user.validate) + with pytest.raises(ValidationError): + user.validate() def test_email_field_unicode_user(self): # Don't run this test on pypy3, which doesn't support unicode regex: @@ -52,7 +56,8 @@ class TestEmailField(MongoDBTestCase): # unicode user shouldn't validate by default... user = User(email=u"Dörte@Sörensen.example.com") - self.assertRaises(ValidationError, user.validate) + with pytest.raises(ValidationError): + user.validate() # ...but it should be fine with allow_utf8_user set to True class User(Document): @@ -67,7 +72,8 @@ class TestEmailField(MongoDBTestCase): # localhost domain shouldn't validate by default... user = User(email="me@localhost") - self.assertRaises(ValidationError, user.validate) + with pytest.raises(ValidationError): + user.validate() # ...but it should be fine if it's whitelisted class User(Document): @@ -82,9 +88,9 @@ class TestEmailField(MongoDBTestCase): invalid_idn = ".google.com" user = User(email="me@%s" % invalid_idn) - with self.assertRaises(ValidationError) as ctx_err: + with pytest.raises(ValidationError) as ctx_err: user.validate() - self.assertIn("domain failed IDN encoding", str(ctx_err.exception)) + assert "domain failed IDN encoding" in str(ctx_err.exception) def test_email_field_ip_domain(self): class User(Document): @@ -96,13 +102,16 @@ class TestEmailField(MongoDBTestCase): # IP address as a domain shouldn't validate by default... user = User(email=valid_ipv4) - self.assertRaises(ValidationError, user.validate) + with pytest.raises(ValidationError): + user.validate() user = User(email=valid_ipv6) - self.assertRaises(ValidationError, user.validate) + with pytest.raises(ValidationError): + user.validate() user = User(email=invalid_ip) - self.assertRaises(ValidationError, user.validate) + with pytest.raises(ValidationError): + user.validate() # ...but it should be fine with allow_ip_domain set to True class User(Document): @@ -116,7 +125,8 @@ class TestEmailField(MongoDBTestCase): # invalid IP should still fail validation user = User(email=invalid_ip) - self.assertRaises(ValidationError, user.validate) + with pytest.raises(ValidationError): + user.validate() def test_email_field_honors_regex(self): class User(Document): @@ -124,8 +134,9 @@ class TestEmailField(MongoDBTestCase): # Fails regex validation user = User(email="me@foo.com") - self.assertRaises(ValidationError, user.validate) + with pytest.raises(ValidationError): + user.validate() # Passes regex validation user = User(email="me@example.com") - self.assertIsNone(user.validate()) + assert user.validate() is None diff --git a/tests/fields/test_embedded_document_field.py b/tests/fields/test_embedded_document_field.py index 8db8c180..4fcf6bf1 100644 --- a/tests/fields/test_embedded_document_field.py +++ b/tests/fields/test_embedded_document_field.py @@ -13,6 +13,7 @@ from mongoengine import ( ) from tests.utils import MongoDBTestCase +import pytest class TestEmbeddedDocumentField(MongoDBTestCase): @@ -21,13 +22,13 @@ class TestEmbeddedDocumentField(MongoDBTestCase): name = StringField() field = EmbeddedDocumentField(MyDoc) - self.assertEqual(field.document_type_obj, MyDoc) + assert field.document_type_obj == MyDoc field2 = EmbeddedDocumentField("MyDoc") - self.assertEqual(field2.document_type_obj, "MyDoc") + assert field2.document_type_obj == "MyDoc" def test___init___throw_error_if_document_type_is_not_EmbeddedDocument(self): - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): EmbeddedDocumentField(dict) def test_document_type_throw_error_if_not_EmbeddedDocument_subclass(self): @@ -35,11 +36,11 @@ class TestEmbeddedDocumentField(MongoDBTestCase): name = StringField() emb = EmbeddedDocumentField("MyDoc") - with self.assertRaises(ValidationError) as ctx: + with pytest.raises(ValidationError) as ctx: emb.document_type - self.assertIn( - "Invalid embedded document class provided to an EmbeddedDocumentField", - str(ctx.exception), + assert ( + "Invalid embedded document class provided to an EmbeddedDocumentField" + in str(ctx.exception) ) def test_embedded_document_field_only_allow_subclasses_of_embedded_document(self): @@ -47,12 +48,12 @@ class TestEmbeddedDocumentField(MongoDBTestCase): class MyDoc(Document): name = StringField() - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): class MyFailingDoc(Document): emb = EmbeddedDocumentField(MyDoc) - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): class MyFailingdoc2(Document): emb = EmbeddedDocumentField("MyDoc") @@ -71,24 +72,24 @@ class TestEmbeddedDocumentField(MongoDBTestCase): p = Person(settings=AdminSettings(foo1="bar1", foo2="bar2"), name="John").save() # Test non exiting attribute - with self.assertRaises(InvalidQueryError) as ctx_err: + with pytest.raises(InvalidQueryError) as ctx_err: Person.objects(settings__notexist="bar").first() - self.assertEqual(unicode(ctx_err.exception), u'Cannot resolve field "notexist"') + assert unicode(ctx_err.exception) == u'Cannot resolve field "notexist"' - with self.assertRaises(LookUpError): + with pytest.raises(LookUpError): Person.objects.only("settings.notexist") # Test existing attribute - self.assertEqual(Person.objects(settings__foo1="bar1").first().id, p.id) + assert Person.objects(settings__foo1="bar1").first().id == p.id only_p = Person.objects.only("settings.foo1").first() - self.assertEqual(only_p.settings.foo1, p.settings.foo1) - self.assertIsNone(only_p.settings.foo2) - self.assertIsNone(only_p.name) + assert only_p.settings.foo1 == p.settings.foo1 + assert only_p.settings.foo2 is None + assert only_p.name is None exclude_p = Person.objects.exclude("settings.foo1").first() - self.assertIsNone(exclude_p.settings.foo1) - self.assertEqual(exclude_p.settings.foo2, p.settings.foo2) - self.assertEqual(exclude_p.name, p.name) + assert exclude_p.settings.foo1 is None + assert exclude_p.settings.foo2 == p.settings.foo2 + assert exclude_p.name == p.name def test_query_embedded_document_attribute_with_inheritance(self): class BaseSettings(EmbeddedDocument): @@ -107,17 +108,17 @@ class TestEmbeddedDocumentField(MongoDBTestCase): p.save() # Test non exiting attribute - with self.assertRaises(InvalidQueryError) as ctx_err: - self.assertEqual(Person.objects(settings__notexist="bar").first().id, p.id) - self.assertEqual(unicode(ctx_err.exception), u'Cannot resolve field "notexist"') + with pytest.raises(InvalidQueryError) as ctx_err: + assert Person.objects(settings__notexist="bar").first().id == p.id + assert unicode(ctx_err.exception) == u'Cannot resolve field "notexist"' # Test existing attribute - self.assertEqual(Person.objects(settings__base_foo="basefoo").first().id, p.id) - self.assertEqual(Person.objects(settings__sub_foo="subfoo").first().id, p.id) + assert Person.objects(settings__base_foo="basefoo").first().id == p.id + assert Person.objects(settings__sub_foo="subfoo").first().id == p.id only_p = Person.objects.only("settings.base_foo", "settings._cls").first() - self.assertEqual(only_p.settings.base_foo, "basefoo") - self.assertIsNone(only_p.settings.sub_foo) + assert only_p.settings.base_foo == "basefoo" + assert only_p.settings.sub_foo is None def test_query_list_embedded_document_with_inheritance(self): class Post(EmbeddedDocument): @@ -137,14 +138,14 @@ class TestEmbeddedDocumentField(MongoDBTestCase): record_text = Record(posts=[TextPost(content="a", title="foo")]).save() records = list(Record.objects(posts__author=record_movie.posts[0].author)) - self.assertEqual(len(records), 1) - self.assertEqual(records[0].id, record_movie.id) + assert len(records) == 1 + assert records[0].id == record_movie.id records = list(Record.objects(posts__content=record_text.posts[0].content)) - self.assertEqual(len(records), 1) - self.assertEqual(records[0].id, record_text.id) + assert len(records) == 1 + assert records[0].id == record_text.id - self.assertEqual(Record.objects(posts__title="foo").count(), 2) + assert Record.objects(posts__title="foo").count() == 2 class TestGenericEmbeddedDocumentField(MongoDBTestCase): @@ -167,13 +168,13 @@ class TestGenericEmbeddedDocumentField(MongoDBTestCase): person.save() person = Person.objects.first() - self.assertIsInstance(person.like, Car) + assert isinstance(person.like, Car) person.like = Dish(food="arroz", number=15) person.save() person = Person.objects.first() - self.assertIsInstance(person.like, Dish) + assert isinstance(person.like, Dish) def test_generic_embedded_document_choices(self): """Ensure you can limit GenericEmbeddedDocument choices.""" @@ -193,13 +194,14 @@ class TestGenericEmbeddedDocumentField(MongoDBTestCase): person = Person(name="Test User") person.like = Car(name="Fiat") - self.assertRaises(ValidationError, person.validate) + with pytest.raises(ValidationError): + person.validate() person.like = Dish(food="arroz", number=15) person.save() person = Person.objects.first() - self.assertIsInstance(person.like, Dish) + assert isinstance(person.like, Dish) def test_generic_list_embedded_document_choices(self): """Ensure you can limit GenericEmbeddedDocument choices inside @@ -221,13 +223,14 @@ class TestGenericEmbeddedDocumentField(MongoDBTestCase): person = Person(name="Test User") person.likes = [Car(name="Fiat")] - self.assertRaises(ValidationError, person.validate) + with pytest.raises(ValidationError): + person.validate() person.likes = [Dish(food="arroz", number=15)] person.save() person = Person.objects.first() - self.assertIsInstance(person.likes[0], Dish) + assert isinstance(person.likes[0], Dish) def test_choices_validation_documents(self): """ @@ -263,7 +266,8 @@ class TestGenericEmbeddedDocumentField(MongoDBTestCase): # Single Entry Failure post = BlogPost(comments=[ModeratorComments(author="mod1", message="message1")]) - self.assertRaises(ValidationError, post.save) + with pytest.raises(ValidationError): + post.save() # Mixed Entry Failure post = BlogPost( @@ -272,7 +276,8 @@ class TestGenericEmbeddedDocumentField(MongoDBTestCase): UserComments(author="user2", message="message2"), ] ) - self.assertRaises(ValidationError, post.save) + with pytest.raises(ValidationError): + post.save() def test_choices_validation_documents_inheritance(self): """ @@ -311,16 +316,16 @@ class TestGenericEmbeddedDocumentField(MongoDBTestCase): p2 = Person(settings=NonAdminSettings(foo2="bar2")).save() # Test non exiting attribute - with self.assertRaises(InvalidQueryError) as ctx_err: + with pytest.raises(InvalidQueryError) as ctx_err: Person.objects(settings__notexist="bar").first() - self.assertEqual(unicode(ctx_err.exception), u'Cannot resolve field "notexist"') + assert unicode(ctx_err.exception) == u'Cannot resolve field "notexist"' - with self.assertRaises(LookUpError): + with pytest.raises(LookUpError): Person.objects.only("settings.notexist") # Test existing attribute - self.assertEqual(Person.objects(settings__foo1="bar1").first().id, p1.id) - self.assertEqual(Person.objects(settings__foo2="bar2").first().id, p2.id) + assert Person.objects(settings__foo1="bar1").first().id == p1.id + assert Person.objects(settings__foo2="bar2").first().id == p2.id def test_query_generic_embedded_document_attribute_with_inheritance(self): class BaseSettings(EmbeddedDocument): @@ -339,10 +344,10 @@ class TestGenericEmbeddedDocumentField(MongoDBTestCase): p.save() # Test non exiting attribute - with self.assertRaises(InvalidQueryError) as ctx_err: - self.assertEqual(Person.objects(settings__notexist="bar").first().id, p.id) - self.assertEqual(unicode(ctx_err.exception), u'Cannot resolve field "notexist"') + with pytest.raises(InvalidQueryError) as ctx_err: + assert Person.objects(settings__notexist="bar").first().id == p.id + assert unicode(ctx_err.exception) == u'Cannot resolve field "notexist"' # Test existing attribute - self.assertEqual(Person.objects(settings__base_foo="basefoo").first().id, p.id) - self.assertEqual(Person.objects(settings__sub_foo="subfoo").first().id, p.id) + assert Person.objects(settings__base_foo="basefoo").first().id == p.id + assert Person.objects(settings__sub_foo="subfoo").first().id == p.id diff --git a/tests/fields/test_fields.py b/tests/fields/test_fields.py index bd2149e6..b27d95d2 100644 --- a/tests/fields/test_fields.py +++ b/tests/fields/test_fields.py @@ -39,6 +39,7 @@ from mongoengine.base import BaseField, EmbeddedDocumentList, _document_registry from mongoengine.errors import DeprecatedError from tests.utils import MongoDBTestCase +import pytest class TestField(MongoDBTestCase): @@ -58,25 +59,25 @@ class TestField(MongoDBTestCase): # Confirm saving now would store values data_to_be_saved = sorted(person.to_mongo().keys()) - self.assertEqual(data_to_be_saved, ["age", "created", "day", "name", "userid"]) + assert data_to_be_saved == ["age", "created", "day", "name", "userid"] - self.assertTrue(person.validate() is None) + assert person.validate() is None - self.assertEqual(person.name, person.name) - self.assertEqual(person.age, person.age) - self.assertEqual(person.userid, person.userid) - self.assertEqual(person.created, person.created) - self.assertEqual(person.day, person.day) + assert person.name == person.name + assert person.age == person.age + assert person.userid == person.userid + assert person.created == person.created + assert person.day == person.day - self.assertEqual(person._data["name"], person.name) - self.assertEqual(person._data["age"], person.age) - self.assertEqual(person._data["userid"], person.userid) - self.assertEqual(person._data["created"], person.created) - self.assertEqual(person._data["day"], person.day) + assert person._data["name"] == person.name + assert person._data["age"] == person.age + assert person._data["userid"] == person.userid + assert person._data["created"] == person.created + assert person._data["day"] == person.day # Confirm introspection changes nothing data_to_be_saved = sorted(person.to_mongo().keys()) - self.assertEqual(data_to_be_saved, ["age", "created", "day", "name", "userid"]) + assert data_to_be_saved == ["age", "created", "day", "name", "userid"] def test_custom_field_validation_raise_deprecated_error_when_validation_return_something( self, @@ -95,13 +96,13 @@ class TestField(MongoDBTestCase): "it should raise a ValidationError if validation fails" ) - with self.assertRaises(DeprecatedError) as ctx_err: + with pytest.raises(DeprecatedError) as ctx_err: Person(name="").validate() - self.assertEqual(str(ctx_err.exception), error) + assert str(ctx_err.exception) == error - with self.assertRaises(DeprecatedError) as ctx_err: + with pytest.raises(DeprecatedError) as ctx_err: Person(name="").save() - self.assertEqual(str(ctx_err.exception), error) + assert str(ctx_err.exception) == error def test_custom_field_validation_raise_validation_error(self): def _not_empty(z): @@ -113,18 +114,16 @@ class TestField(MongoDBTestCase): Person.drop_collection() - with self.assertRaises(ValidationError) as ctx_err: + with pytest.raises(ValidationError) as ctx_err: Person(name="").validate() - self.assertEqual( - "ValidationError (Person:None) (cantbeempty: ['name'])", - str(ctx_err.exception), + assert "ValidationError (Person:None) (cantbeempty: ['name'])" == str( + ctx_err.exception ) - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): Person(name="").save() - self.assertEqual( - "ValidationError (Person:None) (cantbeempty: ['name'])", - str(ctx_err.exception), + assert "ValidationError (Person:None) (cantbeempty: ['name'])" == str( + ctx_err.exception ) Person(name="garbage").validate() @@ -146,23 +145,23 @@ class TestField(MongoDBTestCase): # Confirm saving now would store values data_to_be_saved = sorted(person.to_mongo().keys()) - self.assertEqual(data_to_be_saved, ["age", "created", "userid"]) + assert data_to_be_saved == ["age", "created", "userid"] - self.assertTrue(person.validate() is None) + assert person.validate() is None - self.assertEqual(person.name, person.name) - self.assertEqual(person.age, person.age) - self.assertEqual(person.userid, person.userid) - self.assertEqual(person.created, person.created) + assert person.name == person.name + assert person.age == person.age + assert person.userid == person.userid + assert person.created == person.created - self.assertEqual(person._data["name"], person.name) - self.assertEqual(person._data["age"], person.age) - self.assertEqual(person._data["userid"], person.userid) - self.assertEqual(person._data["created"], person.created) + assert person._data["name"] == person.name + assert person._data["age"] == person.age + assert person._data["userid"] == person.userid + assert person._data["created"] == person.created # Confirm introspection changes nothing data_to_be_saved = sorted(person.to_mongo().keys()) - self.assertEqual(data_to_be_saved, ["age", "created", "userid"]) + assert data_to_be_saved == ["age", "created", "userid"] def test_default_values_when_setting_to_None(self): """Ensure that default field values are used when creating @@ -183,23 +182,23 @@ class TestField(MongoDBTestCase): # Confirm saving now would store values data_to_be_saved = sorted(person.to_mongo().keys()) - self.assertEqual(data_to_be_saved, ["age", "created", "userid"]) + assert data_to_be_saved == ["age", "created", "userid"] - self.assertTrue(person.validate() is None) + assert person.validate() is None - self.assertEqual(person.name, None) - self.assertEqual(person.age, 30) - self.assertEqual(person.userid, "test") - self.assertIsInstance(person.created, datetime.datetime) + assert person.name == None + assert person.age == 30 + assert person.userid == "test" + assert isinstance(person.created, datetime.datetime) - self.assertEqual(person._data["name"], person.name) - self.assertEqual(person._data["age"], person.age) - self.assertEqual(person._data["userid"], person.userid) - self.assertEqual(person._data["created"], person.created) + assert person._data["name"] == person.name + assert person._data["age"] == person.age + assert person._data["userid"] == person.userid + assert person._data["created"] == person.created # Confirm introspection changes nothing data_to_be_saved = sorted(person.to_mongo().keys()) - self.assertEqual(data_to_be_saved, ["age", "created", "userid"]) + assert data_to_be_saved == ["age", "created", "userid"] def test_default_value_is_not_used_when_changing_value_to_empty_list_for_strict_doc( self, @@ -213,7 +212,7 @@ class TestField(MongoDBTestCase): doc.x = [] doc.save() reloaded = Doc.objects.get(id=doc.id) - self.assertEqual(reloaded.x, []) + assert reloaded.x == [] def test_default_value_is_not_used_when_changing_value_to_empty_list_for_dyn_doc( self, @@ -228,7 +227,7 @@ class TestField(MongoDBTestCase): doc.y = 2 # Was triggering the bug doc.save() reloaded = Doc.objects.get(id=doc.id) - self.assertEqual(reloaded.x, []) + assert reloaded.x == [] def test_default_values_when_deleting_value(self): """Ensure that default field values are used after non-default @@ -253,24 +252,24 @@ class TestField(MongoDBTestCase): del person.created data_to_be_saved = sorted(person.to_mongo().keys()) - self.assertEqual(data_to_be_saved, ["age", "created", "userid"]) + assert data_to_be_saved == ["age", "created", "userid"] - self.assertTrue(person.validate() is None) + assert person.validate() is None - self.assertEqual(person.name, None) - self.assertEqual(person.age, 30) - self.assertEqual(person.userid, "test") - self.assertIsInstance(person.created, datetime.datetime) - self.assertNotEqual(person.created, datetime.datetime(2014, 6, 12)) + assert person.name == None + assert person.age == 30 + assert person.userid == "test" + assert isinstance(person.created, datetime.datetime) + assert person.created != datetime.datetime(2014, 6, 12) - self.assertEqual(person._data["name"], person.name) - self.assertEqual(person._data["age"], person.age) - self.assertEqual(person._data["userid"], person.userid) - self.assertEqual(person._data["created"], person.created) + assert person._data["name"] == person.name + assert person._data["age"] == person.age + assert person._data["userid"] == person.userid + assert person._data["created"] == person.created # Confirm introspection changes nothing data_to_be_saved = sorted(person.to_mongo().keys()) - self.assertEqual(data_to_be_saved, ["age", "created", "userid"]) + assert data_to_be_saved == ["age", "created", "userid"] def test_required_values(self): """Ensure that required field constraints are enforced.""" @@ -281,9 +280,11 @@ class TestField(MongoDBTestCase): userid = StringField() person = Person(name="Test User") - self.assertRaises(ValidationError, person.validate) + with pytest.raises(ValidationError): + person.validate() person = Person(age=30) - self.assertRaises(ValidationError, person.validate) + with pytest.raises(ValidationError): + person.validate() def test_not_required_handles_none_in_update(self): """Ensure that every fields should accept None if required is @@ -311,15 +312,15 @@ class TestField(MongoDBTestCase): set__flt_fld=None, set__comp_dt_fld=None, ) - self.assertEqual(res, 1) + assert res == 1 # Retrive data from db and verify it. ret = HandleNoneFields.objects.all()[0] - self.assertIsNone(ret.str_fld) - self.assertIsNone(ret.int_fld) - self.assertIsNone(ret.flt_fld) + assert ret.str_fld is None + assert ret.int_fld is None + assert ret.flt_fld is None - self.assertIsNone(ret.comp_dt_fld) + assert ret.comp_dt_fld is None def test_not_required_handles_none_from_database(self): """Ensure that every field can handle null values from the @@ -349,14 +350,15 @@ class TestField(MongoDBTestCase): # Retrive data from db and verify it. ret = HandleNoneFields.objects.first() - self.assertIsNone(ret.str_fld) - self.assertIsNone(ret.int_fld) - self.assertIsNone(ret.flt_fld) - self.assertIsNone(ret.comp_dt_fld) + assert ret.str_fld is None + assert ret.int_fld is None + assert ret.flt_fld is None + assert ret.comp_dt_fld is None # Retrieved object shouldn't pass validation when a re-save is # attempted. - self.assertRaises(ValidationError, ret.validate) + with pytest.raises(ValidationError): + ret.validate() def test_default_id_validation_as_objectid(self): """Ensure that invalid values cannot be assigned to an @@ -367,13 +369,15 @@ class TestField(MongoDBTestCase): name = StringField() person = Person(name="Test User") - self.assertEqual(person.id, None) + assert person.id == None person.id = 47 - self.assertRaises(ValidationError, person.validate) + with pytest.raises(ValidationError): + person.validate() person.id = "abc" - self.assertRaises(ValidationError, person.validate) + with pytest.raises(ValidationError): + person.validate() person.id = str(ObjectId()) person.validate() @@ -386,19 +390,22 @@ class TestField(MongoDBTestCase): userid = StringField(r"[0-9a-z_]+$") person = Person(name=34) - self.assertRaises(ValidationError, person.validate) + with pytest.raises(ValidationError): + person.validate() # Test regex validation on userid person = Person(userid="test.User") - self.assertRaises(ValidationError, person.validate) + with pytest.raises(ValidationError): + person.validate() person.userid = "test_user" - self.assertEqual(person.userid, "test_user") + assert person.userid == "test_user" person.validate() # Test max length validation on name person = Person(name="Name that is more than twenty characters") - self.assertRaises(ValidationError, person.validate) + with pytest.raises(ValidationError): + person.validate() person.name = "Shorter name" person.validate() @@ -407,19 +414,19 @@ class TestField(MongoDBTestCase): """Ensure that db_field doesn't accept invalid values.""" # dot in the name - with self.assertRaises(ValueError): + with pytest.raises(ValueError): class User(Document): name = StringField(db_field="user.name") # name starting with $ - with self.assertRaises(ValueError): + with pytest.raises(ValueError): class UserX1(Document): name = StringField(db_field="$name") # name containing a null character - with self.assertRaises(ValueError): + with pytest.raises(ValueError): class UserX2(Document): name = StringField(db_field="name\0") @@ -455,9 +462,11 @@ class TestField(MongoDBTestCase): post.validate() post.tags = "fun" - self.assertRaises(ValidationError, post.validate) + with pytest.raises(ValidationError): + post.validate() post.tags = [1, 2] - self.assertRaises(ValidationError, post.validate) + with pytest.raises(ValidationError): + post.validate() post.tags = ["fun", "leisure"] post.validate() @@ -465,30 +474,36 @@ class TestField(MongoDBTestCase): post.validate() post.access_list = "a,b" - self.assertRaises(ValidationError, post.validate) + with pytest.raises(ValidationError): + post.validate() post.access_list = ["c", "d"] - self.assertRaises(ValidationError, post.validate) + with pytest.raises(ValidationError): + post.validate() post.access_list = ["a", "b"] post.validate() - self.assertEqual(post.get_access_list_display(), u"Administration, Manager") + assert post.get_access_list_display() == u"Administration, Manager" post.comments = ["a"] - self.assertRaises(ValidationError, post.validate) + with pytest.raises(ValidationError): + post.validate() post.comments = "yay" - self.assertRaises(ValidationError, post.validate) + with pytest.raises(ValidationError): + post.validate() comments = [Comment(content="Good for you"), Comment(content="Yay.")] post.comments = comments post.validate() post.authors = [Comment()] - self.assertRaises(ValidationError, post.validate) + with pytest.raises(ValidationError): + post.validate() post.authors = [User()] - self.assertRaises(ValidationError, post.validate) + with pytest.raises(ValidationError): + post.validate() user = User() user.save() @@ -496,34 +511,42 @@ class TestField(MongoDBTestCase): post.validate() post.authors_as_lazy = [Comment()] - self.assertRaises(ValidationError, post.validate) + with pytest.raises(ValidationError): + post.validate() post.authors_as_lazy = [User()] - self.assertRaises(ValidationError, post.validate) + with pytest.raises(ValidationError): + post.validate() post.authors_as_lazy = [user] post.validate() post.generic = [1, 2] - self.assertRaises(ValidationError, post.validate) + with pytest.raises(ValidationError): + post.validate() post.generic = [User(), Comment()] - self.assertRaises(ValidationError, post.validate) + with pytest.raises(ValidationError): + post.validate() post.generic = [Comment()] - self.assertRaises(ValidationError, post.validate) + with pytest.raises(ValidationError): + post.validate() post.generic = [user] post.validate() post.generic_as_lazy = [1, 2] - self.assertRaises(ValidationError, post.validate) + with pytest.raises(ValidationError): + post.validate() post.generic_as_lazy = [User(), Comment()] - self.assertRaises(ValidationError, post.validate) + with pytest.raises(ValidationError): + post.validate() post.generic_as_lazy = [Comment()] - self.assertRaises(ValidationError, post.validate) + with pytest.raises(ValidationError): + post.validate() post.generic_as_lazy = [user] post.validate() @@ -549,7 +572,7 @@ class TestField(MongoDBTestCase): post.tags = ["leisure", "fun"] post.save() post.reload() - self.assertEqual(post.tags, ["fun", "leisure"]) + assert post.tags == ["fun", "leisure"] comment1 = Comment(content="Good for you", order=1) comment2 = Comment(content="Yay.", order=0) @@ -557,15 +580,15 @@ class TestField(MongoDBTestCase): post.comments = comments post.save() post.reload() - self.assertEqual(post.comments[0].content, comment2.content) - self.assertEqual(post.comments[1].content, comment1.content) + assert post.comments[0].content == comment2.content + assert post.comments[1].content == comment1.content post.comments[0].order = 2 post.save() post.reload() - self.assertEqual(post.comments[0].content, comment1.content) - self.assertEqual(post.comments[1].content, comment2.content) + assert post.comments[0].content == comment1.content + assert post.comments[1].content == comment2.content def test_reverse_list_sorting(self): """Ensure that a reverse sorted list field properly sorts values""" @@ -590,9 +613,9 @@ class TestField(MongoDBTestCase): catlist.save() catlist.reload() - self.assertEqual(catlist.categories[0].name, cat2.name) - self.assertEqual(catlist.categories[1].name, cat3.name) - self.assertEqual(catlist.categories[2].name, cat1.name) + assert catlist.categories[0].name == cat2.name + assert catlist.categories[1].name == cat3.name + assert catlist.categories[2].name == cat1.name def test_list_field(self): """Ensure that list types work as expected.""" @@ -604,10 +627,12 @@ class TestField(MongoDBTestCase): post = BlogPost() post.info = "my post" - self.assertRaises(ValidationError, post.validate) + with pytest.raises(ValidationError): + post.validate() post.info = {"title": "test"} - self.assertRaises(ValidationError, post.validate) + with pytest.raises(ValidationError): + post.validate() post.info = ["test"] post.save() @@ -620,15 +645,13 @@ class TestField(MongoDBTestCase): post.info = [{"test": 3}] post.save() - self.assertEqual(BlogPost.objects.count(), 3) - self.assertEqual(BlogPost.objects.filter(info__exact="test").count(), 1) - self.assertEqual(BlogPost.objects.filter(info__0__test="test").count(), 1) + assert BlogPost.objects.count() == 3 + assert BlogPost.objects.filter(info__exact="test").count() == 1 + assert BlogPost.objects.filter(info__0__test="test").count() == 1 # Confirm handles non strings or non existing keys - self.assertEqual(BlogPost.objects.filter(info__0__test__exact="5").count(), 0) - self.assertEqual( - BlogPost.objects.filter(info__100__test__exact="test").count(), 0 - ) + assert BlogPost.objects.filter(info__0__test__exact="5").count() == 0 + assert BlogPost.objects.filter(info__100__test__exact="test").count() == 0 # test queries by list post = BlogPost() @@ -637,12 +660,12 @@ class TestField(MongoDBTestCase): post = BlogPost.objects(info=["1", "2"]).get() post.info += ["3", "4"] post.save() - self.assertEqual(BlogPost.objects(info=["1", "2", "3", "4"]).count(), 1) + assert BlogPost.objects(info=["1", "2", "3", "4"]).count() == 1 post = BlogPost.objects(info=["1", "2", "3", "4"]).get() post.info *= 2 post.save() - self.assertEqual( - BlogPost.objects(info=["1", "2", "3", "4", "1", "2", "3", "4"]).count(), 1 + assert ( + BlogPost.objects(info=["1", "2", "3", "4", "1", "2", "3", "4"]).count() == 1 ) def test_list_field_manipulative_operators(self): @@ -670,165 +693,149 @@ class TestField(MongoDBTestCase): reset_post() temp = ["a", "b"] post.info = post.info + temp - self.assertEqual(post.info, ["0", "1", "2", "3", "4", "5", "a", "b"]) + assert post.info == ["0", "1", "2", "3", "4", "5", "a", "b"] post.save() post.reload() - self.assertEqual(post.info, ["0", "1", "2", "3", "4", "5", "a", "b"]) + assert post.info == ["0", "1", "2", "3", "4", "5", "a", "b"] # '__delitem__(index)' # aka 'del list[index]' # aka 'operator.delitem(list, index)' reset_post() del post.info[2] # del from middle ('2') - self.assertEqual(post.info, ["0", "1", "3", "4", "5"]) + assert post.info == ["0", "1", "3", "4", "5"] post.save() post.reload() - self.assertEqual(post.info, ["0", "1", "3", "4", "5"]) + assert post.info == ["0", "1", "3", "4", "5"] # '__delitem__(slice(i, j))' # aka 'del list[i:j]' # aka 'operator.delitem(list, slice(i,j))' reset_post() del post.info[1:3] # removes '1', '2' - self.assertEqual(post.info, ["0", "3", "4", "5"]) + assert post.info == ["0", "3", "4", "5"] post.save() post.reload() - self.assertEqual(post.info, ["0", "3", "4", "5"]) + assert post.info == ["0", "3", "4", "5"] # '__iadd__' # aka 'list += list' reset_post() temp = ["a", "b"] post.info += temp - self.assertEqual(post.info, ["0", "1", "2", "3", "4", "5", "a", "b"]) + assert post.info == ["0", "1", "2", "3", "4", "5", "a", "b"] post.save() post.reload() - self.assertEqual(post.info, ["0", "1", "2", "3", "4", "5", "a", "b"]) + assert post.info == ["0", "1", "2", "3", "4", "5", "a", "b"] # '__imul__' # aka 'list *= number' reset_post() post.info *= 2 - self.assertEqual( - post.info, ["0", "1", "2", "3", "4", "5", "0", "1", "2", "3", "4", "5"] - ) + assert post.info == ["0", "1", "2", "3", "4", "5", "0", "1", "2", "3", "4", "5"] post.save() post.reload() - self.assertEqual( - post.info, ["0", "1", "2", "3", "4", "5", "0", "1", "2", "3", "4", "5"] - ) + assert post.info == ["0", "1", "2", "3", "4", "5", "0", "1", "2", "3", "4", "5"] # '__mul__' # aka 'listA*listB' reset_post() post.info = post.info * 2 - self.assertEqual( - post.info, ["0", "1", "2", "3", "4", "5", "0", "1", "2", "3", "4", "5"] - ) + assert post.info == ["0", "1", "2", "3", "4", "5", "0", "1", "2", "3", "4", "5"] post.save() post.reload() - self.assertEqual( - post.info, ["0", "1", "2", "3", "4", "5", "0", "1", "2", "3", "4", "5"] - ) + assert post.info == ["0", "1", "2", "3", "4", "5", "0", "1", "2", "3", "4", "5"] # '__rmul__' # aka 'listB*listA' reset_post() post.info = 2 * post.info - self.assertEqual( - post.info, ["0", "1", "2", "3", "4", "5", "0", "1", "2", "3", "4", "5"] - ) + assert post.info == ["0", "1", "2", "3", "4", "5", "0", "1", "2", "3", "4", "5"] post.save() post.reload() - self.assertEqual( - post.info, ["0", "1", "2", "3", "4", "5", "0", "1", "2", "3", "4", "5"] - ) + assert post.info == ["0", "1", "2", "3", "4", "5", "0", "1", "2", "3", "4", "5"] # '__setitem__(index, value)' # aka 'list[index]=value' # aka 'setitem(list, value)' reset_post() post.info[4] = "a" - self.assertEqual(post.info, ["0", "1", "2", "3", "a", "5"]) + assert post.info == ["0", "1", "2", "3", "a", "5"] post.save() post.reload() - self.assertEqual(post.info, ["0", "1", "2", "3", "a", "5"]) + assert post.info == ["0", "1", "2", "3", "a", "5"] # __setitem__(index, value) with a negative index reset_post() post.info[-2] = "a" - self.assertEqual(post.info, ["0", "1", "2", "3", "a", "5"]) + assert post.info == ["0", "1", "2", "3", "a", "5"] post.save() post.reload() - self.assertEqual(post.info, ["0", "1", "2", "3", "a", "5"]) + assert post.info == ["0", "1", "2", "3", "a", "5"] # '__setitem__(slice(i, j), listB)' # aka 'listA[i:j] = listB' # aka 'setitem(listA, slice(i, j), listB)' reset_post() post.info[1:3] = ["h", "e", "l", "l", "o"] - self.assertEqual(post.info, ["0", "h", "e", "l", "l", "o", "3", "4", "5"]) + assert post.info == ["0", "h", "e", "l", "l", "o", "3", "4", "5"] post.save() post.reload() - self.assertEqual(post.info, ["0", "h", "e", "l", "l", "o", "3", "4", "5"]) + assert post.info == ["0", "h", "e", "l", "l", "o", "3", "4", "5"] # '__setitem__(slice(i, j), listB)' with negative i and j reset_post() post.info[-5:-3] = ["h", "e", "l", "l", "o"] - self.assertEqual(post.info, ["0", "h", "e", "l", "l", "o", "3", "4", "5"]) + assert post.info == ["0", "h", "e", "l", "l", "o", "3", "4", "5"] post.save() post.reload() - self.assertEqual(post.info, ["0", "h", "e", "l", "l", "o", "3", "4", "5"]) + assert post.info == ["0", "h", "e", "l", "l", "o", "3", "4", "5"] # negative # 'append' reset_post() post.info.append("h") - self.assertEqual(post.info, ["0", "1", "2", "3", "4", "5", "h"]) + assert post.info == ["0", "1", "2", "3", "4", "5", "h"] post.save() post.reload() - self.assertEqual(post.info, ["0", "1", "2", "3", "4", "5", "h"]) + assert post.info == ["0", "1", "2", "3", "4", "5", "h"] # 'extend' reset_post() post.info.extend(["h", "e", "l", "l", "o"]) - self.assertEqual( - post.info, ["0", "1", "2", "3", "4", "5", "h", "e", "l", "l", "o"] - ) + assert post.info == ["0", "1", "2", "3", "4", "5", "h", "e", "l", "l", "o"] post.save() post.reload() - self.assertEqual( - post.info, ["0", "1", "2", "3", "4", "5", "h", "e", "l", "l", "o"] - ) + assert post.info == ["0", "1", "2", "3", "4", "5", "h", "e", "l", "l", "o"] # 'insert' # 'pop' reset_post() x = post.info.pop(2) y = post.info.pop() - self.assertEqual(post.info, ["0", "1", "3", "4"]) - self.assertEqual(x, "2") - self.assertEqual(y, "5") + assert post.info == ["0", "1", "3", "4"] + assert x == "2" + assert y == "5" post.save() post.reload() - self.assertEqual(post.info, ["0", "1", "3", "4"]) + assert post.info == ["0", "1", "3", "4"] # 'remove' reset_post() post.info.remove("2") - self.assertEqual(post.info, ["0", "1", "3", "4", "5"]) + assert post.info == ["0", "1", "3", "4", "5"] post.save() post.reload() - self.assertEqual(post.info, ["0", "1", "3", "4", "5"]) + assert post.info == ["0", "1", "3", "4", "5"] # 'reverse' reset_post() post.info.reverse() - self.assertEqual(post.info, ["5", "4", "3", "2", "1", "0"]) + assert post.info == ["5", "4", "3", "2", "1", "0"] post.save() post.reload() - self.assertEqual(post.info, ["5", "4", "3", "2", "1", "0"]) + assert post.info == ["5", "4", "3", "2", "1", "0"] # 'sort': though this operator method does manipulate the list, it is # tested in the 'test_list_field_lexicograpic_operators' function @@ -844,7 +851,8 @@ class TestField(MongoDBTestCase): # '__hash__' # aka 'hash(list)' - self.assertRaises(TypeError, lambda: hash(post.info)) + with pytest.raises(TypeError): + hash(post.info) def test_list_field_lexicographic_operators(self): """Ensure that ListField works with standard list operators that @@ -883,32 +891,32 @@ class TestField(MongoDBTestCase): blogLargeB.reload() # '__eq__' aka '==' - self.assertEqual(blogLargeA.text_info, blogLargeB.text_info) - self.assertEqual(blogLargeA.bool_info, blogLargeB.bool_info) + assert blogLargeA.text_info == blogLargeB.text_info + assert blogLargeA.bool_info == blogLargeB.bool_info # '__ge__' aka '>=' - self.assertGreaterEqual(blogLargeA.text_info, blogSmall.text_info) - self.assertGreaterEqual(blogLargeA.text_info, blogLargeB.text_info) - self.assertGreaterEqual(blogLargeA.bool_info, blogSmall.bool_info) - self.assertGreaterEqual(blogLargeA.bool_info, blogLargeB.bool_info) + assert blogLargeA.text_info >= blogSmall.text_info + assert blogLargeA.text_info >= blogLargeB.text_info + assert blogLargeA.bool_info >= blogSmall.bool_info + assert blogLargeA.bool_info >= blogLargeB.bool_info # '__gt__' aka '>' - self.assertGreaterEqual(blogLargeA.text_info, blogSmall.text_info) - self.assertGreaterEqual(blogLargeA.bool_info, blogSmall.bool_info) + assert blogLargeA.text_info >= blogSmall.text_info + assert blogLargeA.bool_info >= blogSmall.bool_info # '__le__' aka '<=' - self.assertLessEqual(blogSmall.text_info, blogLargeB.text_info) - self.assertLessEqual(blogLargeA.text_info, blogLargeB.text_info) - self.assertLessEqual(blogSmall.bool_info, blogLargeB.bool_info) - self.assertLessEqual(blogLargeA.bool_info, blogLargeB.bool_info) + assert blogSmall.text_info <= blogLargeB.text_info + assert blogLargeA.text_info <= blogLargeB.text_info + assert blogSmall.bool_info <= blogLargeB.bool_info + assert blogLargeA.bool_info <= blogLargeB.bool_info # '__lt__' aka '<' - self.assertLess(blogSmall.text_info, blogLargeB.text_info) - self.assertLess(blogSmall.bool_info, blogLargeB.bool_info) + assert blogSmall.text_info < blogLargeB.text_info + assert blogSmall.bool_info < blogLargeB.bool_info # '__ne__' aka '!=' - self.assertNotEqual(blogSmall.text_info, blogLargeB.text_info) - self.assertNotEqual(blogSmall.bool_info, blogLargeB.bool_info) + assert blogSmall.text_info != blogLargeB.text_info + assert blogSmall.bool_info != blogLargeB.bool_info # 'sort' blogLargeB.bool_info = [True, False, True, False] @@ -920,14 +928,14 @@ class TestField(MongoDBTestCase): ObjectId("54495ad94c934721ede76d23"), ObjectId("54495ad94c934721ede76f90"), ] - self.assertEqual(blogLargeB.text_info, ["a", "j", "z"]) - self.assertEqual(blogLargeB.oid_info, sorted_target_list) - self.assertEqual(blogLargeB.bool_info, [False, False, True, True]) + assert blogLargeB.text_info == ["a", "j", "z"] + assert blogLargeB.oid_info == sorted_target_list + assert blogLargeB.bool_info == [False, False, True, True] blogLargeB.save() blogLargeB.reload() - self.assertEqual(blogLargeB.text_info, ["a", "j", "z"]) - self.assertEqual(blogLargeB.oid_info, sorted_target_list) - self.assertEqual(blogLargeB.bool_info, [False, False, True, True]) + assert blogLargeB.text_info == ["a", "j", "z"] + assert blogLargeB.oid_info == sorted_target_list + assert blogLargeB.bool_info == [False, False, True, True] def test_list_assignment(self): """Ensure that list field element assignment and slicing work.""" @@ -944,37 +952,37 @@ class TestField(MongoDBTestCase): post.info[0] = 1 post.save() post.reload() - self.assertEqual(post.info[0], 1) + assert post.info[0] == 1 post.info[1:3] = ["n2", "n3"] post.save() post.reload() - self.assertEqual(post.info, [1, "n2", "n3", "4", 5]) + assert post.info == [1, "n2", "n3", "4", 5] post.info[-1] = "n5" post.save() post.reload() - self.assertEqual(post.info, [1, "n2", "n3", "4", "n5"]) + assert post.info == [1, "n2", "n3", "4", "n5"] post.info[-2] = 4 post.save() post.reload() - self.assertEqual(post.info, [1, "n2", "n3", 4, "n5"]) + assert post.info == [1, "n2", "n3", 4, "n5"] post.info[1:-1] = [2] post.save() post.reload() - self.assertEqual(post.info, [1, 2, "n5"]) + assert post.info == [1, 2, "n5"] post.info[:-1] = [1, "n2", "n3", 4] post.save() post.reload() - self.assertEqual(post.info, [1, "n2", "n3", 4, "n5"]) + assert post.info == [1, "n2", "n3", 4, "n5"] post.info[-4:3] = [2, 3] post.save() post.reload() - self.assertEqual(post.info, [1, 2, 3, 4, "n5"]) + assert post.info == [1, 2, 3, 4, "n5"] def test_list_field_passed_in_value(self): class Foo(Document): @@ -988,7 +996,7 @@ class TestField(MongoDBTestCase): foo = Foo(bars=[]) foo.bars.append(bar) - self.assertEqual(repr(foo.bars), "[]") + assert repr(foo.bars) == "[]" def test_list_field_strict(self): """Ensure that list field handles validation if provided @@ -1005,7 +1013,7 @@ class TestField(MongoDBTestCase): e.save() # try creating an invalid mapping - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): e.mapping = ["abc"] e.save() @@ -1021,9 +1029,9 @@ class TestField(MongoDBTestCase): if i < 6: foo.save() else: - with self.assertRaises(ValidationError) as cm: + with pytest.raises(ValidationError) as cm: foo.save() - self.assertIn("List is too long", str(cm.exception)) + assert "List is too long" in str(cm.exception) def test_list_field_max_length_set_operator(self): """Ensure ListField's max_length is respected for a "set" operator.""" @@ -1032,9 +1040,9 @@ class TestField(MongoDBTestCase): items = ListField(IntField(), max_length=3) foo = Foo.objects.create(items=[1, 2, 3]) - with self.assertRaises(ValidationError) as cm: + with pytest.raises(ValidationError) as cm: foo.modify(set__items=[1, 2, 3, 4]) - self.assertIn("List is too long", str(cm.exception)) + assert "List is too long" in str(cm.exception) def test_list_field_rejects_strings(self): """Strings aren't valid list field data types.""" @@ -1046,7 +1054,8 @@ class TestField(MongoDBTestCase): e = Simple() e.mapping = "hello world" - self.assertRaises(ValidationError, e.save) + with pytest.raises(ValidationError): + e.save() def test_complex_field_required(self): """Ensure required cant be None / Empty.""" @@ -1058,7 +1067,8 @@ class TestField(MongoDBTestCase): e = Simple() e.mapping = [] - self.assertRaises(ValidationError, e.save) + with pytest.raises(ValidationError): + e.save() class Simple(Document): mapping = DictField(required=True) @@ -1066,7 +1076,8 @@ class TestField(MongoDBTestCase): Simple.drop_collection() e = Simple() e.mapping = {} - self.assertRaises(ValidationError, e.save) + with pytest.raises(ValidationError): + e.save() def test_complex_field_same_value_not_changed(self): """If a complex field is set to the same value, it should not @@ -1080,7 +1091,7 @@ class TestField(MongoDBTestCase): e = Simple().save() e.mapping = [] - self.assertEqual([], e._changed_fields) + assert [] == e._changed_fields class Simple(Document): mapping = DictField() @@ -1089,7 +1100,7 @@ class TestField(MongoDBTestCase): e = Simple().save() e.mapping = {} - self.assertEqual([], e._changed_fields) + assert [] == e._changed_fields def test_slice_marks_field_as_changed(self): class Simple(Document): @@ -1097,11 +1108,11 @@ class TestField(MongoDBTestCase): simple = Simple(widgets=[1, 2, 3, 4]).save() simple.widgets[:3] = [] - self.assertEqual(["widgets"], simple._changed_fields) + assert ["widgets"] == simple._changed_fields simple.save() simple = simple.reload() - self.assertEqual(simple.widgets, [4]) + assert simple.widgets == [4] def test_del_slice_marks_field_as_changed(self): class Simple(Document): @@ -1109,11 +1120,11 @@ class TestField(MongoDBTestCase): simple = Simple(widgets=[1, 2, 3, 4]).save() del simple.widgets[:3] - self.assertEqual(["widgets"], simple._changed_fields) + assert ["widgets"] == simple._changed_fields simple.save() simple = simple.reload() - self.assertEqual(simple.widgets, [4]) + assert simple.widgets == [4] def test_list_field_with_negative_indices(self): class Simple(Document): @@ -1121,11 +1132,11 @@ class TestField(MongoDBTestCase): simple = Simple(widgets=[1, 2, 3, 4]).save() simple.widgets[-1] = 5 - self.assertEqual(["widgets.3"], simple._changed_fields) + assert ["widgets.3"] == simple._changed_fields simple.save() simple = simple.reload() - self.assertEqual(simple.widgets, [1, 2, 3, 5]) + assert simple.widgets == [1, 2, 3, 5] def test_list_field_complex(self): """Ensure that the list fields can handle the complex types.""" @@ -1159,33 +1170,23 @@ class TestField(MongoDBTestCase): e.save() e2 = Simple.objects.get(id=e.id) - self.assertIsInstance(e2.mapping[0], StringSetting) - self.assertIsInstance(e2.mapping[1], IntegerSetting) + assert isinstance(e2.mapping[0], StringSetting) + assert isinstance(e2.mapping[1], IntegerSetting) # Test querying - self.assertEqual(Simple.objects.filter(mapping__1__value=42).count(), 1) - self.assertEqual(Simple.objects.filter(mapping__2__number=1).count(), 1) - self.assertEqual( - Simple.objects.filter(mapping__2__complex__value=42).count(), 1 - ) - self.assertEqual( - Simple.objects.filter(mapping__2__list__0__value=42).count(), 1 - ) - self.assertEqual( - Simple.objects.filter(mapping__2__list__1__value="foo").count(), 1 - ) + assert Simple.objects.filter(mapping__1__value=42).count() == 1 + assert Simple.objects.filter(mapping__2__number=1).count() == 1 + assert Simple.objects.filter(mapping__2__complex__value=42).count() == 1 + assert Simple.objects.filter(mapping__2__list__0__value=42).count() == 1 + assert Simple.objects.filter(mapping__2__list__1__value="foo").count() == 1 # Confirm can update Simple.objects().update(set__mapping__1=IntegerSetting(value=10)) - self.assertEqual(Simple.objects.filter(mapping__1__value=10).count(), 1) + assert Simple.objects.filter(mapping__1__value=10).count() == 1 Simple.objects().update(set__mapping__2__list__1=StringSetting(value="Boo")) - self.assertEqual( - Simple.objects.filter(mapping__2__list__1__value="foo").count(), 0 - ) - self.assertEqual( - Simple.objects.filter(mapping__2__list__1__value="Boo").count(), 1 - ) + assert Simple.objects.filter(mapping__2__list__1__value="foo").count() == 0 + assert Simple.objects.filter(mapping__2__list__1__value="Boo").count() == 1 def test_embedded_db_field(self): class Embedded(EmbeddedDocument): @@ -1203,9 +1204,9 @@ class TestField(MongoDBTestCase): Test.objects.update_one(inc__embedded__number=1) test = Test.objects.get() - self.assertEqual(test.embedded.number, 2) + assert test.embedded.number == 2 doc = self.db.test.find_one() - self.assertEqual(doc["x"]["i"], 2) + assert doc["x"]["i"] == 2 def test_double_embedded_db_field(self): """Make sure multiple layers of embedded docs resolve db fields @@ -1242,7 +1243,7 @@ class TestField(MongoDBTestCase): b = EmbeddedDocumentField(B, db_field="fb") a = A._from_son(SON([("fb", SON([("fc", SON([("txt", "hi")]))]))])) - self.assertEqual(a.b.c.txt, "hi") + assert a.b.c.txt == "hi" def test_embedded_document_field_cant_reference_using_a_str_if_it_does_not_exist_yet( self, @@ -1277,18 +1278,21 @@ class TestField(MongoDBTestCase): person = Person(name="Test User") person.preferences = "My Preferences" - self.assertRaises(ValidationError, person.validate) + with pytest.raises(ValidationError): + person.validate() # Check that only the right embedded doc works person.preferences = Comment(content="Nice blog post...") - self.assertRaises(ValidationError, person.validate) + with pytest.raises(ValidationError): + person.validate() # Check that the embedded doc is valid person.preferences = PersonPreferences() - self.assertRaises(ValidationError, person.validate) + with pytest.raises(ValidationError): + person.validate() person.preferences = PersonPreferences(food="Cheese", number=47) - self.assertEqual(person.preferences.food, "Cheese") + assert person.preferences.food == "Cheese" person.validate() def test_embedded_document_inheritance(self): @@ -1314,7 +1318,7 @@ class TestField(MongoDBTestCase): post.author = PowerUser(name="Test User", power=47) post.save() - self.assertEqual(47, BlogPost.objects.first().author.power) + assert 47 == BlogPost.objects.first().author.power def test_embedded_document_inheritance_with_list(self): """Ensure that nested list of subclassed embedded documents is @@ -1339,7 +1343,7 @@ class TestField(MongoDBTestCase): foobar = User(groups=[group]) foobar.save() - self.assertEqual(content, User.objects.first().groups[0].content) + assert content == User.objects.first().groups[0].content def test_reference_miss(self): """Ensure an exception is raised when dereferencing an unknown @@ -1362,16 +1366,18 @@ class TestField(MongoDBTestCase): # Reference is no longer valid foo.delete() bar = Bar.objects.get() - self.assertRaises(DoesNotExist, getattr, bar, "ref") - self.assertRaises(DoesNotExist, getattr, bar, "generic_ref") + with pytest.raises(DoesNotExist): + getattr(bar, "ref") + with pytest.raises(DoesNotExist): + getattr(bar, "generic_ref") # When auto_dereference is disabled, there is no trouble returning DBRef bar = Bar.objects.get() expected = foo.to_dbref() bar._fields["ref"]._auto_dereference = False - self.assertEqual(bar.ref, expected) + assert bar.ref == expected bar._fields["generic_ref"]._auto_dereference = False - self.assertEqual(bar.generic_ref, {"_ref": expected, "_cls": "Foo"}) + assert bar.generic_ref == {"_ref": expected, "_cls": "Foo"} def test_list_item_dereference(self): """Ensure that DBRef items in ListFields are dereferenced. @@ -1396,8 +1402,8 @@ class TestField(MongoDBTestCase): group_obj = Group.objects.first() - self.assertEqual(group_obj.members[0].name, user1.name) - self.assertEqual(group_obj.members[1].name, user2.name) + assert group_obj.members[0].name == user1.name + assert group_obj.members[1].name == user2.name def test_recursive_reference(self): """Ensure that ReferenceFields can reference their own documents. @@ -1424,8 +1430,8 @@ class TestField(MongoDBTestCase): peter.save() peter = Employee.objects.with_id(peter.id) - self.assertEqual(peter.boss, bill) - self.assertEqual(peter.friends, friends) + assert peter.boss == bill + assert peter.friends == friends def test_recursive_embedding(self): """Ensure that EmbeddedDocumentFields can contain their own documents. @@ -1450,18 +1456,18 @@ class TestField(MongoDBTestCase): tree.save() tree = Tree.objects.first() - self.assertEqual(len(tree.children), 1) + assert len(tree.children) == 1 - self.assertEqual(len(tree.children[0].children), 1) + assert len(tree.children[0].children) == 1 third_child = TreeNode(name="Child 3") tree.children[0].children.append(third_child) tree.save() - self.assertEqual(len(tree.children), 1) - self.assertEqual(tree.children[0].name, first_child.name) - self.assertEqual(tree.children[0].children[0].name, second_child.name) - self.assertEqual(tree.children[0].children[1].name, third_child.name) + assert len(tree.children) == 1 + assert tree.children[0].name == first_child.name + assert tree.children[0].children[0].name == second_child.name + assert tree.children[0].children[1].name == third_child.name # Test updating tree.children[0].name = "I am Child 1" @@ -1469,28 +1475,28 @@ class TestField(MongoDBTestCase): tree.children[0].children[1].name = "I am Child 3" tree.save() - self.assertEqual(tree.children[0].name, "I am Child 1") - self.assertEqual(tree.children[0].children[0].name, "I am Child 2") - self.assertEqual(tree.children[0].children[1].name, "I am Child 3") + assert tree.children[0].name == "I am Child 1" + assert tree.children[0].children[0].name == "I am Child 2" + assert tree.children[0].children[1].name == "I am Child 3" # Test removal - self.assertEqual(len(tree.children[0].children), 2) + assert len(tree.children[0].children) == 2 del tree.children[0].children[1] tree.save() - self.assertEqual(len(tree.children[0].children), 1) + assert len(tree.children[0].children) == 1 tree.children[0].children.pop(0) tree.save() - self.assertEqual(len(tree.children[0].children), 0) - self.assertEqual(tree.children[0].children, []) + assert len(tree.children[0].children) == 0 + assert tree.children[0].children == [] tree.children[0].children.insert(0, third_child) tree.children[0].children.insert(0, second_child) tree.save() - self.assertEqual(len(tree.children[0].children), 2) - self.assertEqual(tree.children[0].children[0].name, second_child.name) - self.assertEqual(tree.children[0].children[1].name, third_child.name) + assert len(tree.children[0].children) == 2 + assert tree.children[0].children[0].name == second_child.name + assert tree.children[0].children[1].name == third_child.name def test_drop_abstract_document(self): """Ensure that an abstract document cannot be dropped given it @@ -1501,7 +1507,8 @@ class TestField(MongoDBTestCase): name = StringField() meta = {"abstract": True} - self.assertRaises(OperationError, AbstractDoc.drop_collection) + with pytest.raises(OperationError): + AbstractDoc.drop_collection() def test_reference_class_with_abstract_parent(self): """Ensure that a class with an abstract parent can be referenced. @@ -1525,7 +1532,7 @@ class TestField(MongoDBTestCase): brother = Brother(name="Bob", sibling=sister) brother.save() - self.assertEqual(Brother.objects[0].sibling.name, sister.name) + assert Brother.objects[0].sibling.name == sister.name def test_reference_abstract_class(self): """Ensure that an abstract class instance cannot be used in the @@ -1547,7 +1554,8 @@ class TestField(MongoDBTestCase): sister = Sibling(name="Alice") brother = Brother(name="Bob", sibling=sister) - self.assertRaises(ValidationError, brother.save) + with pytest.raises(ValidationError): + brother.save() def test_abstract_reference_base_type(self): """Ensure that an an abstract reference fails validation when given a @@ -1570,7 +1578,8 @@ class TestField(MongoDBTestCase): mother = Mother(name="Carol") mother.save() brother = Brother(name="Bob", sibling=mother) - self.assertRaises(ValidationError, brother.save) + with pytest.raises(ValidationError): + brother.save() def test_generic_reference(self): """Ensure that a GenericReferenceField properly dereferences items. @@ -1601,16 +1610,16 @@ class TestField(MongoDBTestCase): bm = Bookmark.objects(bookmark_object=post_1).first() - self.assertEqual(bm.bookmark_object, post_1) - self.assertIsInstance(bm.bookmark_object, Post) + assert bm.bookmark_object == post_1 + assert isinstance(bm.bookmark_object, Post) bm.bookmark_object = link_1 bm.save() bm = Bookmark.objects(bookmark_object=link_1).first() - self.assertEqual(bm.bookmark_object, link_1) - self.assertIsInstance(bm.bookmark_object, Link) + assert bm.bookmark_object == link_1 + assert isinstance(bm.bookmark_object, Link) def test_generic_reference_list(self): """Ensure that a ListField properly dereferences generic references. @@ -1640,8 +1649,8 @@ class TestField(MongoDBTestCase): user = User.objects(bookmarks__all=[post_1, link_1]).first() - self.assertEqual(user.bookmarks[0], post_1) - self.assertEqual(user.bookmarks[1], link_1) + assert user.bookmarks[0] == post_1 + assert user.bookmarks[1] == link_1 def test_generic_reference_document_not_registered(self): """Ensure dereferencing out of the document registry throws a @@ -1682,7 +1691,7 @@ class TestField(MongoDBTestCase): Person.drop_collection() Person(name="Wilson Jr").save() - self.assertEqual(repr(Person.objects(city=None)), "[]") + assert repr(Person.objects(city=None)) == "[]" def test_generic_reference_choices(self): """Ensure that a GenericReferenceField can handle choices.""" @@ -1707,13 +1716,14 @@ class TestField(MongoDBTestCase): post_1.save() bm = Bookmark(bookmark_object=link_1) - self.assertRaises(ValidationError, bm.validate) + with pytest.raises(ValidationError): + bm.validate() bm = Bookmark(bookmark_object=post_1) bm.save() bm = Bookmark.objects.first() - self.assertEqual(bm.bookmark_object, post_1) + assert bm.bookmark_object == post_1 def test_generic_reference_string_choices(self): """Ensure that a GenericReferenceField can handle choices as strings @@ -1745,7 +1755,8 @@ class TestField(MongoDBTestCase): bm.save() bm = Bookmark(bookmark_object=bm) - self.assertRaises(ValidationError, bm.validate) + with pytest.raises(ValidationError): + bm.validate() def test_generic_reference_choices_no_dereference(self): """Ensure that a GenericReferenceField can handle choices on @@ -1798,13 +1809,14 @@ class TestField(MongoDBTestCase): post_1.save() user = User(bookmarks=[link_1]) - self.assertRaises(ValidationError, user.validate) + with pytest.raises(ValidationError): + user.validate() user = User(bookmarks=[post_1]) user.save() user = User.objects.first() - self.assertEqual(user.bookmarks, [post_1]) + assert user.bookmarks == [post_1] def test_generic_reference_list_item_modification(self): """Ensure that modifications of related documents (through generic reference) don't influence on querying @@ -1832,8 +1844,8 @@ class TestField(MongoDBTestCase): user = User.objects(bookmarks__all=[post_1]).first() - self.assertNotEqual(user, None) - self.assertEqual(user.bookmarks[0], post_1) + assert user != None + assert user.bookmarks[0] == post_1 def test_generic_reference_filter_by_dbref(self): """Ensure we can search for a specific generic reference by @@ -1849,7 +1861,7 @@ class TestField(MongoDBTestCase): doc2 = Doc.objects.create(ref=doc1) doc = Doc.objects.get(ref=DBRef("doc", doc1.pk)) - self.assertEqual(doc, doc2) + assert doc == doc2 def test_generic_reference_is_not_tracked_in_parent_doc(self): """Ensure that modifications of related documents (through generic reference) don't influence @@ -1871,11 +1883,11 @@ class TestField(MongoDBTestCase): doc2 = Doc2(ref=doc1, refs=[doc11]).save() doc2.ref.name = "garbage2" - self.assertEqual(doc2._get_changed_fields(), []) + assert doc2._get_changed_fields() == [] doc2.refs[0].name = "garbage3" - self.assertEqual(doc2._get_changed_fields(), []) - self.assertEqual(doc2._delta(), ({}, {})) + assert doc2._get_changed_fields() == [] + assert doc2._delta() == ({}, {}) def test_generic_reference_field(self): """Ensure we can search for a specific generic reference by @@ -1890,10 +1902,10 @@ class TestField(MongoDBTestCase): doc1 = Doc.objects.create() doc2 = Doc.objects.create(ref=doc1) - self.assertIsInstance(doc1.pk, ObjectId) + assert isinstance(doc1.pk, ObjectId) doc = Doc.objects.get(ref=doc1.pk) - self.assertEqual(doc, doc2) + assert doc == doc2 def test_choices_allow_using_sets_as_choices(self): """Ensure that sets can be used when setting choices @@ -1933,7 +1945,7 @@ class TestField(MongoDBTestCase): size = StringField(choices=("S", "M")) shirt = Shirt(size="XS") - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): shirt.validate() def test_choices_get_field_display(self): @@ -1964,24 +1976,25 @@ class TestField(MongoDBTestCase): shirt2 = Shirt() # Make sure get__display returns the default value (or None) - self.assertEqual(shirt1.get_size_display(), None) - self.assertEqual(shirt1.get_style_display(), "Wide") + assert shirt1.get_size_display() == None + assert shirt1.get_style_display() == "Wide" shirt1.size = "XXL" shirt1.style = "B" shirt2.size = "M" shirt2.style = "S" - self.assertEqual(shirt1.get_size_display(), "Extra Extra Large") - self.assertEqual(shirt1.get_style_display(), "Baggy") - self.assertEqual(shirt2.get_size_display(), "Medium") - self.assertEqual(shirt2.get_style_display(), "Small") + assert shirt1.get_size_display() == "Extra Extra Large" + assert shirt1.get_style_display() == "Baggy" + assert shirt2.get_size_display() == "Medium" + assert shirt2.get_style_display() == "Small" # Set as Z - an invalid choice shirt1.size = "Z" shirt1.style = "Z" - self.assertEqual(shirt1.get_size_display(), "Z") - self.assertEqual(shirt1.get_style_display(), "Z") - self.assertRaises(ValidationError, shirt1.validate) + assert shirt1.get_size_display() == "Z" + assert shirt1.get_style_display() == "Z" + with pytest.raises(ValidationError): + shirt1.validate() def test_simple_choices_validation(self): """Ensure that value is in a container of allowed values. @@ -1999,7 +2012,8 @@ class TestField(MongoDBTestCase): shirt.validate() shirt.size = "XS" - self.assertRaises(ValidationError, shirt.validate) + with pytest.raises(ValidationError): + shirt.validate() def test_simple_choices_get_field_display(self): """Test dynamic helper for returning the display value of a choices @@ -2016,20 +2030,21 @@ class TestField(MongoDBTestCase): shirt = Shirt() - self.assertEqual(shirt.get_size_display(), None) - self.assertEqual(shirt.get_style_display(), "Small") + assert shirt.get_size_display() == None + assert shirt.get_style_display() == "Small" shirt.size = "XXL" shirt.style = "Baggy" - self.assertEqual(shirt.get_size_display(), "XXL") - self.assertEqual(shirt.get_style_display(), "Baggy") + assert shirt.get_size_display() == "XXL" + assert shirt.get_style_display() == "Baggy" # Set as Z - an invalid choice shirt.size = "Z" shirt.style = "Z" - self.assertEqual(shirt.get_size_display(), "Z") - self.assertEqual(shirt.get_style_display(), "Z") - self.assertRaises(ValidationError, shirt.validate) + assert shirt.get_size_display() == "Z" + assert shirt.get_style_display() == "Z" + with pytest.raises(ValidationError): + shirt.validate() def test_simple_choices_validation_invalid_value(self): """Ensure that error messages are correct. @@ -2060,8 +2075,8 @@ class TestField(MongoDBTestCase): except ValidationError as error: # get the validation rules error_dict = error.to_dict() - self.assertEqual(error_dict["size"], SIZE_MESSAGE) - self.assertEqual(error_dict["color"], COLOR_MESSAGE) + assert error_dict["size"] == SIZE_MESSAGE + assert error_dict["color"] == COLOR_MESSAGE def test_recursive_validation(self): """Ensure that a validation result to_dict is available.""" @@ -2082,26 +2097,25 @@ class TestField(MongoDBTestCase): post.comments.append(Comment(content="hello", author=bob)) post.comments.append(Comment(author=bob)) - self.assertRaises(ValidationError, post.validate) + with pytest.raises(ValidationError): + post.validate() try: post.validate() except ValidationError as error: # ValidationError.errors property - self.assertTrue(hasattr(error, "errors")) - self.assertIsInstance(error.errors, dict) - self.assertIn("comments", error.errors) - self.assertIn(1, error.errors["comments"]) - self.assertIsInstance( - error.errors["comments"][1]["content"], ValidationError - ) + assert hasattr(error, "errors") + assert isinstance(error.errors, dict) + assert "comments" in error.errors + assert 1 in error.errors["comments"] + assert isinstance(error.errors["comments"][1]["content"], ValidationError) # ValidationError.schema property error_dict = error.to_dict() - self.assertIsInstance(error_dict, dict) - self.assertIn("comments", error_dict) - self.assertIn(1, error_dict["comments"]) - self.assertIn("content", error_dict["comments"][1]) - self.assertEqual(error_dict["comments"][1]["content"], u"Field is required") + assert isinstance(error_dict, dict) + assert "comments" in error_dict + assert 1 in error_dict["comments"] + assert "content" in error_dict["comments"][1] + assert error_dict["comments"][1]["content"] == u"Field is required" post.comments[1].content = "here we go" post.validate() @@ -2131,10 +2145,10 @@ class TestField(MongoDBTestCase): doc.items = tuples doc.save() x = TestDoc.objects().get() - self.assertIsNotNone(x) - self.assertEqual(len(x.items), 1) - self.assertIn(tuple(x.items[0]), tuples) - self.assertIn(x.items[0], tuples) + assert x is not None + assert len(x.items) == 1 + assert tuple(x.items[0]) in tuples + assert x.items[0] in tuples def test_dynamic_fields_class(self): class Doc2(Document): @@ -2150,13 +2164,14 @@ class TestField(MongoDBTestCase): doc2 = Doc2(field_1="hello") doc = Doc(my_id=1, embed_me=doc2, field_x="x") - self.assertRaises(OperationError, doc.save) + with pytest.raises(OperationError): + doc.save() doc2.save() doc.save() doc = Doc.objects.get() - self.assertEqual(doc.embed_me.field_1, "hello") + assert doc.embed_me.field_1 == "hello" def test_dynamic_fields_embedded_class(self): class Embed(EmbeddedDocument): @@ -2172,7 +2187,7 @@ class TestField(MongoDBTestCase): Doc(my_id=1, embed_me=Embed(field_1="hello"), field_x="x").save() doc = Doc.objects.get() - self.assertEqual(doc.embed_me.field_1, "hello") + assert doc.embed_me.field_1 == "hello" def test_dynamicfield_dump_document(self): """Ensure a DynamicField can handle another document's dump.""" @@ -2197,15 +2212,15 @@ class TestField(MongoDBTestCase): to_embed = ToEmbed(id=2, recursive=to_embed_recursive).save() doc = Doc(field=to_embed) doc.save() - self.assertIsInstance(doc.field, ToEmbed) - self.assertEqual(doc.field, to_embed) + assert isinstance(doc.field, ToEmbed) + assert doc.field == to_embed # Same thing with a Document with a _cls field to_embed_recursive = ToEmbedChild(id=1).save() to_embed_child = ToEmbedChild(id=2, recursive=to_embed_recursive).save() doc = Doc(field=to_embed_child) doc.save() - self.assertIsInstance(doc.field, ToEmbedChild) - self.assertEqual(doc.field, to_embed_child) + assert isinstance(doc.field, ToEmbedChild) + assert doc.field == to_embed_child def test_cls_field(self): class Animal(Document): @@ -2227,10 +2242,10 @@ class TestField(MongoDBTestCase): Dog().save() Fish().save() Human().save() - self.assertEqual( - Animal.objects(_cls__in=["Animal.Mammal.Dog", "Animal.Fish"]).count(), 2 + assert ( + Animal.objects(_cls__in=["Animal.Mammal.Dog", "Animal.Fish"]).count() == 2 ) - self.assertEqual(Animal.objects(_cls__in=["Animal.Fish.Guppy"]).count(), 0) + assert Animal.objects(_cls__in=["Animal.Fish.Guppy"]).count() == 0 def test_sparse_field(self): class Doc(Document): @@ -2249,7 +2264,7 @@ class TestField(MongoDBTestCase): class Doc(Document): foo = StringField() - with self.assertRaises(FieldDoesNotExist): + with pytest.raises(FieldDoesNotExist): Doc(bar="test") def test_undefined_field_exception_with_strict(self): @@ -2262,7 +2277,7 @@ class TestField(MongoDBTestCase): foo = StringField() meta = {"strict": False} - with self.assertRaises(FieldDoesNotExist): + with pytest.raises(FieldDoesNotExist): Doc(bar="test") @@ -2310,20 +2325,20 @@ class TestEmbeddedDocumentListField(MongoDBTestCase): # Test with an embeddedDocument instead of a list(embeddedDocument) # It's an edge case but it used to fail with a vague error, making it difficult to troubleshoot it post = self.BlogPost(comments=comment) - with self.assertRaises(ValidationError) as ctx_err: + with pytest.raises(ValidationError) as ctx_err: post.validate() - self.assertIn("'comments'", str(ctx_err.exception)) - self.assertIn( - "Only lists and tuples may be used in a list field", str(ctx_err.exception) + assert "'comments'" in str(ctx_err.exception) + assert "Only lists and tuples may be used in a list field" in str( + ctx_err.exception ) # Test with a Document post = self.BlogPost(comments=Title(content="garbage")) - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): post.validate() - self.assertIn("'comments'", str(ctx_err.exception)) - self.assertIn( - "Only lists and tuples may be used in a list field", str(ctx_err.exception) + assert "'comments'" in str(ctx_err.exception) + assert "Only lists and tuples may be used in a list field" in str( + ctx_err.exception ) def test_no_keyword_filter(self): @@ -2334,7 +2349,7 @@ class TestEmbeddedDocumentListField(MongoDBTestCase): filtered = self.post1.comments.filter() # Ensure nothing was changed - self.assertListEqual(filtered, self.post1.comments) + assert filtered == self.post1.comments def test_single_keyword_filter(self): """ @@ -2344,10 +2359,10 @@ class TestEmbeddedDocumentListField(MongoDBTestCase): filtered = self.post1.comments.filter(author="user1") # Ensure only 1 entry was returned. - self.assertEqual(len(filtered), 1) + assert len(filtered) == 1 # Ensure the entry returned is the correct entry. - self.assertEqual(filtered[0].author, "user1") + assert filtered[0].author == "user1" def test_multi_keyword_filter(self): """ @@ -2357,11 +2372,11 @@ class TestEmbeddedDocumentListField(MongoDBTestCase): filtered = self.post2.comments.filter(author="user2", message="message2") # Ensure only 1 entry was returned. - self.assertEqual(len(filtered), 1) + assert len(filtered) == 1 # Ensure the entry returned is the correct entry. - self.assertEqual(filtered[0].author, "user2") - self.assertEqual(filtered[0].message, "message2") + assert filtered[0].author == "user2" + assert filtered[0].message == "message2" def test_chained_filter(self): """ @@ -2370,18 +2385,18 @@ class TestEmbeddedDocumentListField(MongoDBTestCase): filtered = self.post2.comments.filter(author="user2").filter(message="message2") # Ensure only 1 entry was returned. - self.assertEqual(len(filtered), 1) + assert len(filtered) == 1 # Ensure the entry returned is the correct entry. - self.assertEqual(filtered[0].author, "user2") - self.assertEqual(filtered[0].message, "message2") + assert filtered[0].author == "user2" + assert filtered[0].message == "message2" def test_unknown_keyword_filter(self): """ Tests the filter method of a List of Embedded Documents when the keyword is not a known keyword. """ - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): self.post2.comments.filter(year=2) def test_no_keyword_exclude(self): @@ -2392,7 +2407,7 @@ class TestEmbeddedDocumentListField(MongoDBTestCase): filtered = self.post1.comments.exclude() # Ensure everything was removed - self.assertListEqual(filtered, []) + assert filtered == [] def test_single_keyword_exclude(self): """ @@ -2402,10 +2417,10 @@ class TestEmbeddedDocumentListField(MongoDBTestCase): excluded = self.post1.comments.exclude(author="user1") # Ensure only 1 entry was returned. - self.assertEqual(len(excluded), 1) + assert len(excluded) == 1 # Ensure the entry returned is the correct entry. - self.assertEqual(excluded[0].author, "user2") + assert excluded[0].author == "user2" def test_multi_keyword_exclude(self): """ @@ -2415,11 +2430,11 @@ class TestEmbeddedDocumentListField(MongoDBTestCase): excluded = self.post2.comments.exclude(author="user3", message="message1") # Ensure only 2 entries were returned. - self.assertEqual(len(excluded), 2) + assert len(excluded) == 2 # Ensure the entries returned are the correct entries. - self.assertEqual(excluded[0].author, "user2") - self.assertEqual(excluded[1].author, "user2") + assert excluded[0].author == "user2" + assert excluded[1].author == "user2" def test_non_matching_exclude(self): """ @@ -2429,14 +2444,14 @@ class TestEmbeddedDocumentListField(MongoDBTestCase): excluded = self.post2.comments.exclude(author="user4") # Ensure the 3 entries still exist. - self.assertEqual(len(excluded), 3) + assert len(excluded) == 3 def test_unknown_keyword_exclude(self): """ Tests the exclude method of a List of Embedded Documents when the keyword is not a known keyword. """ - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): self.post2.comments.exclude(year=2) def test_chained_filter_exclude(self): @@ -2449,25 +2464,25 @@ class TestEmbeddedDocumentListField(MongoDBTestCase): ) # Ensure only 1 entry was returned. - self.assertEqual(len(excluded), 1) + assert len(excluded) == 1 # Ensure the entry returned is the correct entry. - self.assertEqual(excluded[0].author, "user2") - self.assertEqual(excluded[0].message, "message3") + assert excluded[0].author == "user2" + assert excluded[0].message == "message3" def test_count(self): """ Tests the count method of a List of Embedded Documents. """ - self.assertEqual(self.post1.comments.count(), 2) - self.assertEqual(self.post1.comments.count(), len(self.post1.comments)) + assert self.post1.comments.count() == 2 + assert self.post1.comments.count() == len(self.post1.comments) def test_filtered_count(self): """ Tests the filter + count method of a List of Embedded Documents. """ count = self.post1.comments.filter(author="user1").count() - self.assertEqual(count, 1) + assert count == 1 def test_single_keyword_get(self): """ @@ -2475,8 +2490,8 @@ class TestEmbeddedDocumentListField(MongoDBTestCase): single keyword. """ comment = self.post1.comments.get(author="user1") - self.assertIsInstance(comment, self.Comments) - self.assertEqual(comment.author, "user1") + assert isinstance(comment, self.Comments) + assert comment.author == "user1" def test_multi_keyword_get(self): """ @@ -2484,16 +2499,16 @@ class TestEmbeddedDocumentListField(MongoDBTestCase): multiple keywords. """ comment = self.post2.comments.get(author="user2", message="message2") - self.assertIsInstance(comment, self.Comments) - self.assertEqual(comment.author, "user2") - self.assertEqual(comment.message, "message2") + assert isinstance(comment, self.Comments) + assert comment.author == "user2" + assert comment.message == "message2" def test_no_keyword_multiple_return_get(self): """ Tests the get method of a List of Embedded Documents without a keyword to return multiple documents. """ - with self.assertRaises(MultipleObjectsReturned): + with pytest.raises(MultipleObjectsReturned): self.post1.comments.get() def test_keyword_multiple_return_get(self): @@ -2501,7 +2516,7 @@ class TestEmbeddedDocumentListField(MongoDBTestCase): Tests the get method of a List of Embedded Documents with a keyword to return multiple documents. """ - with self.assertRaises(MultipleObjectsReturned): + with pytest.raises(MultipleObjectsReturned): self.post2.comments.get(author="user2") def test_unknown_keyword_get(self): @@ -2509,7 +2524,7 @@ class TestEmbeddedDocumentListField(MongoDBTestCase): Tests the get method of a List of Embedded Documents with an unknown keyword. """ - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): self.post2.comments.get(year=2020) def test_no_result_get(self): @@ -2517,7 +2532,7 @@ class TestEmbeddedDocumentListField(MongoDBTestCase): Tests the get method of a List of Embedded Documents where get returns no results. """ - with self.assertRaises(DoesNotExist): + with pytest.raises(DoesNotExist): self.post1.comments.get(author="user3") def test_first(self): @@ -2528,8 +2543,8 @@ class TestEmbeddedDocumentListField(MongoDBTestCase): comment = self.post1.comments.first() # Ensure a Comment object was returned. - self.assertIsInstance(comment, self.Comments) - self.assertEqual(comment, self.post1.comments[0]) + assert isinstance(comment, self.Comments) + assert comment == self.post1.comments[0] def test_create(self): """ @@ -2539,14 +2554,12 @@ class TestEmbeddedDocumentListField(MongoDBTestCase): self.post1.save() # Ensure the returned value is the comment object. - self.assertIsInstance(comment, self.Comments) - self.assertEqual(comment.author, "user4") - self.assertEqual(comment.message, "message1") + assert isinstance(comment, self.Comments) + assert comment.author == "user4" + assert comment.message == "message1" # Ensure the new comment was actually saved to the database. - self.assertIn( - comment, self.BlogPost.objects(comments__author="user4")[0].comments - ) + assert comment in self.BlogPost.objects(comments__author="user4")[0].comments def test_filtered_create(self): """ @@ -2560,14 +2573,12 @@ class TestEmbeddedDocumentListField(MongoDBTestCase): self.post1.save() # Ensure the returned value is the comment object. - self.assertIsInstance(comment, self.Comments) - self.assertEqual(comment.author, "user4") - self.assertEqual(comment.message, "message1") + assert isinstance(comment, self.Comments) + assert comment.author == "user4" + assert comment.message == "message1" # Ensure the new comment was actually saved to the database. - self.assertIn( - comment, self.BlogPost.objects(comments__author="user4")[0].comments - ) + assert comment in self.BlogPost.objects(comments__author="user4")[0].comments def test_no_keyword_update(self): """ @@ -2579,13 +2590,13 @@ class TestEmbeddedDocumentListField(MongoDBTestCase): self.post1.save() # Ensure that nothing was altered. - self.assertIn(original[0], self.BlogPost.objects(id=self.post1.id)[0].comments) + assert original[0] in self.BlogPost.objects(id=self.post1.id)[0].comments - self.assertIn(original[1], self.BlogPost.objects(id=self.post1.id)[0].comments) + assert original[1] in self.BlogPost.objects(id=self.post1.id)[0].comments # Ensure the method returned 0 as the number of entries # modified - self.assertEqual(number, 0) + assert number == 0 def test_single_keyword_update(self): """ @@ -2598,12 +2609,12 @@ class TestEmbeddedDocumentListField(MongoDBTestCase): comments = self.BlogPost.objects(id=self.post1.id)[0].comments # Ensure that the database was updated properly. - self.assertEqual(comments[0].author, "user4") - self.assertEqual(comments[1].author, "user4") + assert comments[0].author == "user4" + assert comments[1].author == "user4" # Ensure the method returned 2 as the number of entries # modified - self.assertEqual(number, 2) + assert number == 2 def test_unicode(self): """ @@ -2615,7 +2626,7 @@ class TestEmbeddedDocumentListField(MongoDBTestCase): self.Comments(author="user2", message=u"хабарлама"), ] ).save() - self.assertEqual(post.comments.get(message=u"сообщение").author, "user1") + assert post.comments.get(message=u"сообщение").author == "user1" def test_save(self): """ @@ -2627,7 +2638,7 @@ class TestEmbeddedDocumentListField(MongoDBTestCase): comments.save() # Ensure that the new comment has been added to the database. - self.assertIn(new_comment, self.BlogPost.objects(id=self.post1.id)[0].comments) + assert new_comment in self.BlogPost.objects(id=self.post1.id)[0].comments def test_delete(self): """ @@ -2638,17 +2649,17 @@ class TestEmbeddedDocumentListField(MongoDBTestCase): # Ensure that all the comments under post1 were deleted in the # database. - self.assertListEqual(self.BlogPost.objects(id=self.post1.id)[0].comments, []) + assert self.BlogPost.objects(id=self.post1.id)[0].comments == [] # Ensure that post1 comments were deleted from the list. - self.assertListEqual(self.post1.comments, []) + assert self.post1.comments == [] # Ensure that comments still returned a EmbeddedDocumentList object. - self.assertIsInstance(self.post1.comments, EmbeddedDocumentList) + assert isinstance(self.post1.comments, EmbeddedDocumentList) # Ensure that the delete method returned 2 as the number of entries # deleted from the database - self.assertEqual(number, 2) + assert number == 2 def test_empty_list_embedded_documents_with_unique_field(self): """ @@ -2664,7 +2675,7 @@ class TestEmbeddedDocumentListField(MongoDBTestCase): my_list = ListField(EmbeddedDocumentField(EmbeddedWithUnique)) A(my_list=[]).save() - with self.assertRaises(NotUniqueError): + with pytest.raises(NotUniqueError): A(my_list=[]).save() class EmbeddedWithSparseUnique(EmbeddedDocument): @@ -2689,16 +2700,16 @@ class TestEmbeddedDocumentListField(MongoDBTestCase): self.post1.save() # Ensure that only the user2 comment was deleted. - self.assertNotIn(comment, self.BlogPost.objects(id=self.post1.id)[0].comments) - self.assertEqual(len(self.BlogPost.objects(id=self.post1.id)[0].comments), 1) + assert comment not in self.BlogPost.objects(id=self.post1.id)[0].comments + assert len(self.BlogPost.objects(id=self.post1.id)[0].comments) == 1 # Ensure that the user2 comment no longer exists in the list. - self.assertNotIn(comment, self.post1.comments) - self.assertEqual(len(self.post1.comments), 1) + assert comment not in self.post1.comments + assert len(self.post1.comments) == 1 # Ensure that the delete method returned 1 as the number of entries # deleted from the database - self.assertEqual(number, 1) + assert number == 1 def test_custom_data(self): """ @@ -2714,10 +2725,10 @@ class TestEmbeddedDocumentListField(MongoDBTestCase): CustomData.drop_collection() a1 = CustomData(a_field=1, c_field=2).save() - self.assertEqual(2, a1.c_field) - self.assertFalse(hasattr(a1.c_field, "custom_data")) - self.assertTrue(hasattr(CustomData.c_field, "custom_data")) - self.assertEqual(custom_data["a"], CustomData.c_field.custom_data["a"]) + assert 2 == a1.c_field + assert not hasattr(a1.c_field, "custom_data") + assert hasattr(CustomData.c_field, "custom_data") + assert custom_data["a"] == CustomData.c_field.custom_data["a"] if __name__ == "__main__": diff --git a/tests/fields/test_file_field.py b/tests/fields/test_file_field.py index 49eb5bc2..0746db33 100644 --- a/tests/fields/test_file_field.py +++ b/tests/fields/test_file_field.py @@ -64,13 +64,13 @@ class TestFileField(MongoDBTestCase): putfile.save() result = PutFile.objects.first() - self.assertEqual(putfile, result) - self.assertEqual( - "%s" % result.the_file, - "" % result.the_file.grid_id, + assert putfile == result + assert ( + "%s" % result.the_file + == "" % result.the_file.grid_id ) - self.assertEqual(result.the_file.read(), text) - self.assertEqual(result.the_file.content_type, content_type) + assert result.the_file.read() == text + assert result.the_file.content_type == content_type result.the_file.delete() # Remove file from GridFS PutFile.objects.delete() @@ -85,9 +85,9 @@ class TestFileField(MongoDBTestCase): putfile.save() result = PutFile.objects.first() - self.assertEqual(putfile, result) - self.assertEqual(result.the_file.read(), text) - self.assertEqual(result.the_file.content_type, content_type) + assert putfile == result + assert result.the_file.read() == text + assert result.the_file.content_type == content_type result.the_file.delete() def test_file_fields_stream(self): @@ -111,19 +111,19 @@ class TestFileField(MongoDBTestCase): streamfile.save() result = StreamFile.objects.first() - self.assertEqual(streamfile, result) - self.assertEqual(result.the_file.read(), text + more_text) - self.assertEqual(result.the_file.content_type, content_type) + assert streamfile == result + assert result.the_file.read() == text + more_text + assert result.the_file.content_type == content_type result.the_file.seek(0) - self.assertEqual(result.the_file.tell(), 0) - self.assertEqual(result.the_file.read(len(text)), text) - self.assertEqual(result.the_file.tell(), len(text)) - self.assertEqual(result.the_file.read(len(more_text)), more_text) - self.assertEqual(result.the_file.tell(), len(text + more_text)) + assert result.the_file.tell() == 0 + assert result.the_file.read(len(text)) == text + assert result.the_file.tell() == len(text) + assert result.the_file.read(len(more_text)) == more_text + assert result.the_file.tell() == len(text + more_text) result.the_file.delete() # Ensure deleted file returns None - self.assertTrue(result.the_file.read() is None) + assert result.the_file.read() is None def test_file_fields_stream_after_none(self): """Ensure that a file field can be written to after it has been saved as @@ -148,19 +148,19 @@ class TestFileField(MongoDBTestCase): streamfile.save() result = StreamFile.objects.first() - self.assertEqual(streamfile, result) - self.assertEqual(result.the_file.read(), text + more_text) + assert streamfile == result + assert result.the_file.read() == text + more_text # self.assertEqual(result.the_file.content_type, content_type) result.the_file.seek(0) - self.assertEqual(result.the_file.tell(), 0) - self.assertEqual(result.the_file.read(len(text)), text) - self.assertEqual(result.the_file.tell(), len(text)) - self.assertEqual(result.the_file.read(len(more_text)), more_text) - self.assertEqual(result.the_file.tell(), len(text + more_text)) + assert result.the_file.tell() == 0 + assert result.the_file.read(len(text)) == text + assert result.the_file.tell() == len(text) + assert result.the_file.read(len(more_text)) == more_text + assert result.the_file.tell() == len(text + more_text) result.the_file.delete() # Ensure deleted file returns None - self.assertTrue(result.the_file.read() is None) + assert result.the_file.read() is None def test_file_fields_set(self): class SetFile(Document): @@ -176,16 +176,16 @@ class TestFileField(MongoDBTestCase): setfile.save() result = SetFile.objects.first() - self.assertEqual(setfile, result) - self.assertEqual(result.the_file.read(), text) + assert setfile == result + assert result.the_file.read() == text # Try replacing file with new one result.the_file.replace(more_text) result.save() result = SetFile.objects.first() - self.assertEqual(setfile, result) - self.assertEqual(result.the_file.read(), more_text) + assert setfile == result + assert result.the_file.read() == more_text result.the_file.delete() def test_file_field_no_default(self): @@ -205,28 +205,28 @@ class TestFileField(MongoDBTestCase): doc_b = GridDocument.objects.with_id(doc_a.id) doc_b.the_file.replace(f, filename="doc_b") doc_b.save() - self.assertNotEqual(doc_b.the_file.grid_id, None) + assert doc_b.the_file.grid_id != None # Test it matches doc_c = GridDocument.objects.with_id(doc_b.id) - self.assertEqual(doc_b.the_file.grid_id, doc_c.the_file.grid_id) + assert doc_b.the_file.grid_id == doc_c.the_file.grid_id # Test with default doc_d = GridDocument(the_file=six.b("")) doc_d.save() doc_e = GridDocument.objects.with_id(doc_d.id) - self.assertEqual(doc_d.the_file.grid_id, doc_e.the_file.grid_id) + assert doc_d.the_file.grid_id == doc_e.the_file.grid_id doc_e.the_file.replace(f, filename="doc_e") doc_e.save() doc_f = GridDocument.objects.with_id(doc_e.id) - self.assertEqual(doc_e.the_file.grid_id, doc_f.the_file.grid_id) + assert doc_e.the_file.grid_id == doc_f.the_file.grid_id db = GridDocument._get_db() grid_fs = gridfs.GridFS(db) - self.assertEqual(["doc_b", "doc_e"], grid_fs.list()) + assert ["doc_b", "doc_e"] == grid_fs.list() def test_file_uniqueness(self): """Ensure that each instance of a FileField is unique @@ -246,8 +246,8 @@ class TestFileField(MongoDBTestCase): test_file_dupe = TestFile() data = test_file_dupe.the_file.read() # Should be None - self.assertNotEqual(test_file.name, test_file_dupe.name) - self.assertNotEqual(test_file.the_file.read(), data) + assert test_file.name != test_file_dupe.name + assert test_file.the_file.read() != data TestFile.drop_collection() @@ -268,8 +268,8 @@ class TestFileField(MongoDBTestCase): marmot.save() marmot = Animal.objects.get() - self.assertEqual(marmot.photo.content_type, "image/jpeg") - self.assertEqual(marmot.photo.foo, "bar") + assert marmot.photo.content_type == "image/jpeg" + assert marmot.photo.foo == "bar" def test_file_reassigning(self): class TestFile(Document): @@ -278,12 +278,12 @@ class TestFileField(MongoDBTestCase): TestFile.drop_collection() test_file = TestFile(the_file=get_file(TEST_IMAGE_PATH)).save() - self.assertEqual(test_file.the_file.get().length, 8313) + assert test_file.the_file.get().length == 8313 test_file = TestFile.objects.first() test_file.the_file = get_file(TEST_IMAGE2_PATH) test_file.save() - self.assertEqual(test_file.the_file.get().length, 4971) + assert test_file.the_file.get().length == 4971 def test_file_boolean(self): """Ensure that a boolean test of a FileField indicates its presence @@ -295,13 +295,13 @@ class TestFileField(MongoDBTestCase): TestFile.drop_collection() test_file = TestFile() - self.assertFalse(bool(test_file.the_file)) + assert not bool(test_file.the_file) test_file.the_file.put(six.b("Hello, World!"), content_type="text/plain") test_file.save() - self.assertTrue(bool(test_file.the_file)) + assert bool(test_file.the_file) test_file = TestFile.objects.first() - self.assertEqual(test_file.the_file.content_type, "text/plain") + assert test_file.the_file.content_type == "text/plain" def test_file_cmp(self): """Test comparing against other types""" @@ -310,7 +310,7 @@ class TestFileField(MongoDBTestCase): the_file = FileField() test_file = TestFile() - self.assertNotIn(test_file.the_file, [{"test": 1}]) + assert test_file.the_file not in [{"test": 1}] def test_file_disk_space(self): """ Test disk space usage when we delete/replace a file """ @@ -330,16 +330,16 @@ class TestFileField(MongoDBTestCase): files = db.fs.files.find() chunks = db.fs.chunks.find() - self.assertEqual(len(list(files)), 1) - self.assertEqual(len(list(chunks)), 1) + assert len(list(files)) == 1 + assert len(list(chunks)) == 1 # Deleting the docoument should delete the files testfile.delete() files = db.fs.files.find() chunks = db.fs.chunks.find() - self.assertEqual(len(list(files)), 0) - self.assertEqual(len(list(chunks)), 0) + assert len(list(files)) == 0 + assert len(list(chunks)) == 0 # Test case where we don't store a file in the first place testfile = TestFile() @@ -347,15 +347,15 @@ class TestFileField(MongoDBTestCase): files = db.fs.files.find() chunks = db.fs.chunks.find() - self.assertEqual(len(list(files)), 0) - self.assertEqual(len(list(chunks)), 0) + assert len(list(files)) == 0 + assert len(list(chunks)) == 0 testfile.delete() files = db.fs.files.find() chunks = db.fs.chunks.find() - self.assertEqual(len(list(files)), 0) - self.assertEqual(len(list(chunks)), 0) + assert len(list(files)) == 0 + assert len(list(chunks)) == 0 # Test case where we overwrite the file testfile = TestFile() @@ -368,15 +368,15 @@ class TestFileField(MongoDBTestCase): files = db.fs.files.find() chunks = db.fs.chunks.find() - self.assertEqual(len(list(files)), 1) - self.assertEqual(len(list(chunks)), 1) + assert len(list(files)) == 1 + assert len(list(chunks)) == 1 testfile.delete() files = db.fs.files.find() chunks = db.fs.chunks.find() - self.assertEqual(len(list(files)), 0) - self.assertEqual(len(list(chunks)), 0) + assert len(list(files)) == 0 + assert len(list(chunks)) == 0 def test_image_field(self): if not HAS_PIL: @@ -396,9 +396,7 @@ class TestFileField(MongoDBTestCase): t.image.put(f) self.fail("Should have raised an invalidation error") except ValidationError as e: - self.assertEqual( - "%s" % e, "Invalid image: cannot identify image file %s" % f - ) + assert "%s" % e == "Invalid image: cannot identify image file %s" % f t = TestImage() t.image.put(get_file(TEST_IMAGE_PATH)) @@ -406,11 +404,11 @@ class TestFileField(MongoDBTestCase): t = TestImage.objects.first() - self.assertEqual(t.image.format, "PNG") + assert t.image.format == "PNG" w, h = t.image.size - self.assertEqual(w, 371) - self.assertEqual(h, 76) + assert w == 371 + assert h == 76 t.image.delete() @@ -424,12 +422,12 @@ class TestFileField(MongoDBTestCase): TestFile.drop_collection() test_file = TestFile(the_file=get_file(TEST_IMAGE_PATH)).save() - self.assertEqual(test_file.the_file.size, (371, 76)) + assert test_file.the_file.size == (371, 76) test_file = TestFile.objects.first() test_file.the_file = get_file(TEST_IMAGE2_PATH) test_file.save() - self.assertEqual(test_file.the_file.size, (45, 101)) + assert test_file.the_file.size == (45, 101) def test_image_field_resize(self): if not HAS_PIL: @@ -446,11 +444,11 @@ class TestFileField(MongoDBTestCase): t = TestImage.objects.first() - self.assertEqual(t.image.format, "PNG") + assert t.image.format == "PNG" w, h = t.image.size - self.assertEqual(w, 185) - self.assertEqual(h, 37) + assert w == 185 + assert h == 37 t.image.delete() @@ -469,11 +467,11 @@ class TestFileField(MongoDBTestCase): t = TestImage.objects.first() - self.assertEqual(t.image.format, "PNG") + assert t.image.format == "PNG" w, h = t.image.size - self.assertEqual(w, 185) - self.assertEqual(h, 37) + assert w == 185 + assert h == 37 t.image.delete() @@ -492,9 +490,9 @@ class TestFileField(MongoDBTestCase): t = TestImage.objects.first() - self.assertEqual(t.image.thumbnail.format, "PNG") - self.assertEqual(t.image.thumbnail.width, 92) - self.assertEqual(t.image.thumbnail.height, 18) + assert t.image.thumbnail.format == "PNG" + assert t.image.thumbnail.width == 92 + assert t.image.thumbnail.height == 18 t.image.delete() @@ -518,17 +516,17 @@ class TestFileField(MongoDBTestCase): test_file.save() data = get_db("test_files").macumba.files.find_one() - self.assertEqual(data.get("name"), "hello.txt") + assert data.get("name") == "hello.txt" test_file = TestFile.objects.first() - self.assertEqual(test_file.the_file.read(), six.b("Hello, World!")) + assert test_file.the_file.read() == six.b("Hello, World!") test_file = TestFile.objects.first() test_file.the_file = six.b("HELLO, WORLD!") test_file.save() test_file = TestFile.objects.first() - self.assertEqual(test_file.the_file.read(), six.b("HELLO, WORLD!")) + assert test_file.the_file.read() == six.b("HELLO, WORLD!") def test_copyable(self): class PutFile(Document): @@ -546,8 +544,8 @@ class TestFileField(MongoDBTestCase): class TestFile(Document): name = StringField() - self.assertEqual(putfile, copy.copy(putfile)) - self.assertEqual(putfile, copy.deepcopy(putfile)) + assert putfile == copy.copy(putfile) + assert putfile == copy.deepcopy(putfile) def test_get_image_by_grid_id(self): @@ -569,9 +567,7 @@ class TestFileField(MongoDBTestCase): test = TestImage.objects.first() grid_id = test.image1.grid_id - self.assertEqual( - 1, TestImage.objects(Q(image1=grid_id) or Q(image2=grid_id)).count() - ) + assert 1 == TestImage.objects(Q(image1=grid_id) or Q(image2=grid_id)).count() def test_complex_field_filefield(self): """Ensure you can add meta data to file""" @@ -593,9 +589,9 @@ class TestFileField(MongoDBTestCase): marmot.save() marmot = Animal.objects.get() - self.assertEqual(marmot.photos[0].content_type, "image/jpeg") - self.assertEqual(marmot.photos[0].foo, "bar") - self.assertEqual(marmot.photos[0].get().length, 8313) + assert marmot.photos[0].content_type == "image/jpeg" + assert marmot.photos[0].foo == "bar" + assert marmot.photos[0].get().length == 8313 if __name__ == "__main__": diff --git a/tests/fields/test_float_field.py b/tests/fields/test_float_field.py index 9f357ce5..d755fb4e 100644 --- a/tests/fields/test_float_field.py +++ b/tests/fields/test_float_field.py @@ -4,6 +4,7 @@ import six from mongoengine import * from tests.utils import MongoDBTestCase +import pytest class TestFloatField(MongoDBTestCase): @@ -16,8 +17,8 @@ class TestFloatField(MongoDBTestCase): TestDocument(float_fld=None).save() TestDocument(float_fld=1).save() - self.assertEqual(1, TestDocument.objects(float_fld__ne=None).count()) - self.assertEqual(1, TestDocument.objects(float_fld__ne=1).count()) + assert 1 == TestDocument.objects(float_fld__ne=None).count() + assert 1 == TestDocument.objects(float_fld__ne=1).count() def test_validation(self): """Ensure that invalid values cannot be assigned to float fields. @@ -34,16 +35,20 @@ class TestFloatField(MongoDBTestCase): person.validate() person.height = "2.0" - self.assertRaises(ValidationError, person.validate) + with pytest.raises(ValidationError): + person.validate() person.height = 0.01 - self.assertRaises(ValidationError, person.validate) + with pytest.raises(ValidationError): + person.validate() person.height = 4.0 - self.assertRaises(ValidationError, person.validate) + with pytest.raises(ValidationError): + person.validate() person_2 = Person(height="something invalid") - self.assertRaises(ValidationError, person_2.validate) + with pytest.raises(ValidationError): + person_2.validate() big_person = BigPerson() @@ -55,4 +60,5 @@ class TestFloatField(MongoDBTestCase): big_person.validate() big_person.height = 2 ** 100000 # Too big for a float value - self.assertRaises(ValidationError, big_person.validate) + with pytest.raises(ValidationError): + big_person.validate() diff --git a/tests/fields/test_geo_fields.py b/tests/fields/test_geo_fields.py index ff4cbc83..1b912a4b 100644 --- a/tests/fields/test_geo_fields.py +++ b/tests/fields/test_geo_fields.py @@ -11,7 +11,7 @@ class TestGeoField(MongoDBTestCase): Cls(loc=loc).validate() self.fail("Should not validate the location {0}".format(loc)) except ValidationError as e: - self.assertEqual(expected, e.to_dict()["loc"]) + assert expected == e.to_dict()["loc"] def test_geopoint_validation(self): class Location(Document): @@ -299,7 +299,7 @@ class TestGeoField(MongoDBTestCase): location = GeoPointField() geo_indicies = Event._geo_indices() - self.assertEqual(geo_indicies, [{"fields": [("location", "2d")]}]) + assert geo_indicies == [{"fields": [("location", "2d")]}] def test_geopoint_embedded_indexes(self): """Ensure that indexes are created automatically for GeoPointFields on @@ -315,7 +315,7 @@ class TestGeoField(MongoDBTestCase): venue = EmbeddedDocumentField(Venue) geo_indicies = Event._geo_indices() - self.assertEqual(geo_indicies, [{"fields": [("venue.location", "2d")]}]) + assert geo_indicies == [{"fields": [("venue.location", "2d")]}] def test_indexes_2dsphere(self): """Ensure that indexes are created automatically for GeoPointFields. @@ -328,9 +328,9 @@ class TestGeoField(MongoDBTestCase): polygon = PolygonField() geo_indicies = Event._geo_indices() - self.assertIn({"fields": [("line", "2dsphere")]}, geo_indicies) - self.assertIn({"fields": [("polygon", "2dsphere")]}, geo_indicies) - self.assertIn({"fields": [("point", "2dsphere")]}, geo_indicies) + assert {"fields": [("line", "2dsphere")]} in geo_indicies + assert {"fields": [("polygon", "2dsphere")]} in geo_indicies + assert {"fields": [("point", "2dsphere")]} in geo_indicies def test_indexes_2dsphere_embedded(self): """Ensure that indexes are created automatically for GeoPointFields. @@ -347,9 +347,9 @@ class TestGeoField(MongoDBTestCase): venue = EmbeddedDocumentField(Venue) geo_indicies = Event._geo_indices() - self.assertIn({"fields": [("venue.line", "2dsphere")]}, geo_indicies) - self.assertIn({"fields": [("venue.polygon", "2dsphere")]}, geo_indicies) - self.assertIn({"fields": [("venue.point", "2dsphere")]}, geo_indicies) + assert {"fields": [("venue.line", "2dsphere")]} in geo_indicies + assert {"fields": [("venue.polygon", "2dsphere")]} in geo_indicies + assert {"fields": [("venue.point", "2dsphere")]} in geo_indicies def test_geo_indexes_recursion(self): class Location(Document): @@ -365,12 +365,12 @@ class TestGeoField(MongoDBTestCase): Parent(name="Berlin").save() info = Parent._get_collection().index_information() - self.assertNotIn("location_2d", info) + assert "location_2d" not in info info = Location._get_collection().index_information() - self.assertIn("location_2d", info) + assert "location_2d" in info - self.assertEqual(len(Parent._geo_indices()), 0) - self.assertEqual(len(Location._geo_indices()), 1) + assert len(Parent._geo_indices()) == 0 + assert len(Location._geo_indices()) == 1 def test_geo_indexes_auto_index(self): @@ -381,16 +381,16 @@ class TestGeoField(MongoDBTestCase): meta = {"indexes": [[("location", "2dsphere"), ("datetime", 1)]]} - self.assertEqual([], Log._geo_indices()) + assert [] == Log._geo_indices() Log.drop_collection() Log.ensure_indexes() info = Log._get_collection().index_information() - self.assertEqual( - info["location_2dsphere_datetime_1"]["key"], - [("location", "2dsphere"), ("datetime", 1)], - ) + assert info["location_2dsphere_datetime_1"]["key"] == [ + ("location", "2dsphere"), + ("datetime", 1), + ] # Test listing explicitly class Log(Document): @@ -401,16 +401,16 @@ class TestGeoField(MongoDBTestCase): "indexes": [{"fields": [("location", "2dsphere"), ("datetime", 1)]}] } - self.assertEqual([], Log._geo_indices()) + assert [] == Log._geo_indices() Log.drop_collection() Log.ensure_indexes() info = Log._get_collection().index_information() - self.assertEqual( - info["location_2dsphere_datetime_1"]["key"], - [("location", "2dsphere"), ("datetime", 1)], - ) + assert info["location_2dsphere_datetime_1"]["key"] == [ + ("location", "2dsphere"), + ("datetime", 1), + ] if __name__ == "__main__": diff --git a/tests/fields/test_int_field.py b/tests/fields/test_int_field.py index b7db0416..65a5fbad 100644 --- a/tests/fields/test_int_field.py +++ b/tests/fields/test_int_field.py @@ -2,6 +2,7 @@ from mongoengine import * from tests.utils import MongoDBTestCase +import pytest class TestIntField(MongoDBTestCase): @@ -23,11 +24,14 @@ class TestIntField(MongoDBTestCase): person.validate() person.age = -1 - self.assertRaises(ValidationError, person.validate) + with pytest.raises(ValidationError): + person.validate() person.age = 120 - self.assertRaises(ValidationError, person.validate) + with pytest.raises(ValidationError): + person.validate() person.age = "ten" - self.assertRaises(ValidationError, person.validate) + with pytest.raises(ValidationError): + person.validate() def test_ne_operator(self): class TestDocument(Document): @@ -38,5 +42,5 @@ class TestIntField(MongoDBTestCase): TestDocument(int_fld=None).save() TestDocument(int_fld=1).save() - self.assertEqual(1, TestDocument.objects(int_fld__ne=None).count()) - self.assertEqual(1, TestDocument.objects(int_fld__ne=1).count()) + assert 1 == TestDocument.objects(int_fld__ne=None).count() + assert 1 == TestDocument.objects(int_fld__ne=1).count() diff --git a/tests/fields/test_lazy_reference_field.py b/tests/fields/test_lazy_reference_field.py index 2a686d7f..8150574d 100644 --- a/tests/fields/test_lazy_reference_field.py +++ b/tests/fields/test_lazy_reference_field.py @@ -5,13 +5,15 @@ from mongoengine import * from mongoengine.base import LazyReference from tests.utils import MongoDBTestCase +import pytest class TestLazyReferenceField(MongoDBTestCase): def test_lazy_reference_config(self): # Make sure ReferenceField only accepts a document class or a string # with a document class name. - self.assertRaises(ValidationError, LazyReferenceField, EmbeddedDocument) + with pytest.raises(ValidationError): + LazyReferenceField(EmbeddedDocument) def test___repr__(self): class Animal(Document): @@ -25,7 +27,7 @@ class TestLazyReferenceField(MongoDBTestCase): animal = Animal() oc = Ocurrence(animal=animal) - self.assertIn("LazyReference", repr(oc.animal)) + assert "LazyReference" in repr(oc.animal) def test___getattr___unknown_attr_raises_attribute_error(self): class Animal(Document): @@ -39,7 +41,7 @@ class TestLazyReferenceField(MongoDBTestCase): animal = Animal().save() oc = Ocurrence(animal=animal) - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): oc.animal.not_exist def test_lazy_reference_simple(self): @@ -57,19 +59,19 @@ class TestLazyReferenceField(MongoDBTestCase): animal = Animal(name="Leopard", tag="heavy").save() Ocurrence(person="test", animal=animal).save() p = Ocurrence.objects.get() - self.assertIsInstance(p.animal, LazyReference) + assert isinstance(p.animal, LazyReference) fetched_animal = p.animal.fetch() - self.assertEqual(fetched_animal, animal) + assert fetched_animal == animal # `fetch` keep cache on referenced document by default... animal.tag = "not so heavy" animal.save() double_fetch = p.animal.fetch() - self.assertIs(fetched_animal, double_fetch) - self.assertEqual(double_fetch.tag, "heavy") + assert fetched_animal is double_fetch + assert double_fetch.tag == "heavy" # ...unless specified otherwise fetch_force = p.animal.fetch(force=True) - self.assertIsNot(fetch_force, fetched_animal) - self.assertEqual(fetch_force.tag, "not so heavy") + assert fetch_force is not fetched_animal + assert fetch_force.tag == "not so heavy" def test_lazy_reference_fetch_invalid_ref(self): class Animal(Document): @@ -87,8 +89,8 @@ class TestLazyReferenceField(MongoDBTestCase): Ocurrence(person="test", animal=animal).save() animal.delete() p = Ocurrence.objects.get() - self.assertIsInstance(p.animal, LazyReference) - with self.assertRaises(DoesNotExist): + assert isinstance(p.animal, LazyReference) + with pytest.raises(DoesNotExist): p.animal.fetch() def test_lazy_reference_set(self): @@ -122,7 +124,7 @@ class TestLazyReferenceField(MongoDBTestCase): ): p = Ocurrence(person="test", animal=ref).save() p.reload() - self.assertIsInstance(p.animal, LazyReference) + assert isinstance(p.animal, LazyReference) p.animal.fetch() def test_lazy_reference_bad_set(self): @@ -149,7 +151,7 @@ class TestLazyReferenceField(MongoDBTestCase): DBRef(baddoc._get_collection_name(), animal.pk), LazyReference(BadDoc, animal.pk), ): - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): p = Ocurrence(person="test", animal=bad).save() def test_lazy_reference_query_conversion(self): @@ -179,14 +181,14 @@ class TestLazyReferenceField(MongoDBTestCase): post2.save() post = BlogPost.objects(author=m1).first() - self.assertEqual(post.id, post1.id) + assert post.id == post1.id post = BlogPost.objects(author=m2).first() - self.assertEqual(post.id, post2.id) + assert post.id == post2.id # Same thing by passing a LazyReference instance post = BlogPost.objects(author=LazyReference(Member, m2.pk)).first() - self.assertEqual(post.id, post2.id) + assert post.id == post2.id def test_lazy_reference_query_conversion_dbref(self): """Ensure that LazyReferenceFields can be queried using objects and values @@ -215,14 +217,14 @@ class TestLazyReferenceField(MongoDBTestCase): post2.save() post = BlogPost.objects(author=m1).first() - self.assertEqual(post.id, post1.id) + assert post.id == post1.id post = BlogPost.objects(author=m2).first() - self.assertEqual(post.id, post2.id) + assert post.id == post2.id # Same thing by passing a LazyReference instance post = BlogPost.objects(author=LazyReference(Member, m2.pk)).first() - self.assertEqual(post.id, post2.id) + assert post.id == post2.id def test_lazy_reference_passthrough(self): class Animal(Document): @@ -239,20 +241,20 @@ class TestLazyReferenceField(MongoDBTestCase): animal = Animal(name="Leopard", tag="heavy").save() Ocurrence(animal=animal, animal_passthrough=animal).save() p = Ocurrence.objects.get() - self.assertIsInstance(p.animal, LazyReference) - with self.assertRaises(KeyError): + assert isinstance(p.animal, LazyReference) + with pytest.raises(KeyError): p.animal["name"] - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): p.animal.name - self.assertEqual(p.animal.pk, animal.pk) + assert p.animal.pk == animal.pk - self.assertEqual(p.animal_passthrough.name, "Leopard") - self.assertEqual(p.animal_passthrough["name"], "Leopard") + assert p.animal_passthrough.name == "Leopard" + assert p.animal_passthrough["name"] == "Leopard" # Should not be able to access referenced document's methods - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): p.animal.save - with self.assertRaises(KeyError): + with pytest.raises(KeyError): p.animal["save"] def test_lazy_reference_not_set(self): @@ -269,7 +271,7 @@ class TestLazyReferenceField(MongoDBTestCase): Ocurrence(person="foo").save() p = Ocurrence.objects.get() - self.assertIs(p.animal, None) + assert p.animal is None def test_lazy_reference_equality(self): class Animal(Document): @@ -280,12 +282,12 @@ class TestLazyReferenceField(MongoDBTestCase): animal = Animal(name="Leopard", tag="heavy").save() animalref = LazyReference(Animal, animal.pk) - self.assertEqual(animal, animalref) - self.assertEqual(animalref, animal) + assert animal == animalref + assert animalref == animal other_animalref = LazyReference(Animal, ObjectId("54495ad94c934721ede76f90")) - self.assertNotEqual(animal, other_animalref) - self.assertNotEqual(other_animalref, animal) + assert animal != other_animalref + assert other_animalref != animal def test_lazy_reference_embedded(self): class Animal(Document): @@ -308,12 +310,12 @@ class TestLazyReferenceField(MongoDBTestCase): animal2 = Animal(name="cheeta").save() def check_fields_type(occ): - self.assertIsInstance(occ.direct, LazyReference) + assert isinstance(occ.direct, LazyReference) for elem in occ.in_list: - self.assertIsInstance(elem, LazyReference) - self.assertIsInstance(occ.in_embedded.direct, LazyReference) + assert isinstance(elem, LazyReference) + assert isinstance(occ.in_embedded.direct, LazyReference) for elem in occ.in_embedded.in_list: - self.assertIsInstance(elem, LazyReference) + assert isinstance(elem, LazyReference) occ = Ocurrence( in_list=[animal1, animal2], @@ -346,19 +348,19 @@ class TestGenericLazyReferenceField(MongoDBTestCase): animal = Animal(name="Leopard", tag="heavy").save() Ocurrence(person="test", animal=animal).save() p = Ocurrence.objects.get() - self.assertIsInstance(p.animal, LazyReference) + assert isinstance(p.animal, LazyReference) fetched_animal = p.animal.fetch() - self.assertEqual(fetched_animal, animal) + assert fetched_animal == animal # `fetch` keep cache on referenced document by default... animal.tag = "not so heavy" animal.save() double_fetch = p.animal.fetch() - self.assertIs(fetched_animal, double_fetch) - self.assertEqual(double_fetch.tag, "heavy") + assert fetched_animal is double_fetch + assert double_fetch.tag == "heavy" # ...unless specified otherwise fetch_force = p.animal.fetch(force=True) - self.assertIsNot(fetch_force, fetched_animal) - self.assertEqual(fetch_force.tag, "not so heavy") + assert fetch_force is not fetched_animal + assert fetch_force.tag == "not so heavy" def test_generic_lazy_reference_choices(self): class Animal(Document): @@ -385,13 +387,13 @@ class TestGenericLazyReferenceField(MongoDBTestCase): occ_animal = Ocurrence(living_thing=animal, thing=animal).save() occ_vegetal = Ocurrence(living_thing=vegetal, thing=vegetal).save() - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): Ocurrence(living_thing=mineral).save() occ = Ocurrence.objects.get(living_thing=animal) - self.assertEqual(occ, occ_animal) - self.assertIsInstance(occ.thing, LazyReference) - self.assertIsInstance(occ.living_thing, LazyReference) + assert occ == occ_animal + assert isinstance(occ.thing, LazyReference) + assert isinstance(occ.living_thing, LazyReference) occ.thing = vegetal occ.living_thing = vegetal @@ -399,7 +401,7 @@ class TestGenericLazyReferenceField(MongoDBTestCase): occ.thing = mineral occ.living_thing = mineral - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): occ.save() def test_generic_lazy_reference_set(self): @@ -434,7 +436,7 @@ class TestGenericLazyReferenceField(MongoDBTestCase): ): p = Ocurrence(person="test", animal=ref).save() p.reload() - self.assertIsInstance(p.animal, (LazyReference, Document)) + assert isinstance(p.animal, (LazyReference, Document)) p.animal.fetch() def test_generic_lazy_reference_bad_set(self): @@ -455,7 +457,7 @@ class TestGenericLazyReferenceField(MongoDBTestCase): animal = Animal(name="Leopard", tag="heavy").save() baddoc = BadDoc().save() for bad in (42, "foo", baddoc, LazyReference(BadDoc, animal.pk)): - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): p = Ocurrence(person="test", animal=bad).save() def test_generic_lazy_reference_query_conversion(self): @@ -481,14 +483,14 @@ class TestGenericLazyReferenceField(MongoDBTestCase): post2.save() post = BlogPost.objects(author=m1).first() - self.assertEqual(post.id, post1.id) + assert post.id == post1.id post = BlogPost.objects(author=m2).first() - self.assertEqual(post.id, post2.id) + assert post.id == post2.id # Same thing by passing a LazyReference instance post = BlogPost.objects(author=LazyReference(Member, m2.pk)).first() - self.assertEqual(post.id, post2.id) + assert post.id == post2.id def test_generic_lazy_reference_not_set(self): class Animal(Document): @@ -504,7 +506,7 @@ class TestGenericLazyReferenceField(MongoDBTestCase): Ocurrence(person="foo").save() p = Ocurrence.objects.get() - self.assertIs(p.animal, None) + assert p.animal is None def test_generic_lazy_reference_accepts_string_instead_of_class(self): class Animal(Document): @@ -521,7 +523,7 @@ class TestGenericLazyReferenceField(MongoDBTestCase): animal = Animal().save() Ocurrence(animal=animal).save() p = Ocurrence.objects.get() - self.assertEqual(p.animal, animal) + assert p.animal == animal def test_generic_lazy_reference_embedded(self): class Animal(Document): @@ -544,12 +546,12 @@ class TestGenericLazyReferenceField(MongoDBTestCase): animal2 = Animal(name="cheeta").save() def check_fields_type(occ): - self.assertIsInstance(occ.direct, LazyReference) + assert isinstance(occ.direct, LazyReference) for elem in occ.in_list: - self.assertIsInstance(elem, LazyReference) - self.assertIsInstance(occ.in_embedded.direct, LazyReference) + assert isinstance(elem, LazyReference) + assert isinstance(occ.in_embedded.direct, LazyReference) for elem in occ.in_embedded.in_list: - self.assertIsInstance(elem, LazyReference) + assert isinstance(elem, LazyReference) occ = Ocurrence( in_list=[animal1, animal2], diff --git a/tests/fields/test_long_field.py b/tests/fields/test_long_field.py index ab86eccd..51f8e255 100644 --- a/tests/fields/test_long_field.py +++ b/tests/fields/test_long_field.py @@ -10,6 +10,7 @@ from mongoengine import * from mongoengine.connection import get_db from tests.utils import MongoDBTestCase +import pytest class TestLongField(MongoDBTestCase): @@ -24,10 +25,10 @@ class TestLongField(MongoDBTestCase): doc = TestLongFieldConsideredAsInt64(some_long=42).save() db = get_db() - self.assertIsInstance( + assert isinstance( db.test_long_field_considered_as_int64.find()[0]["some_long"], Int64 ) - self.assertIsInstance(doc.some_long, six.integer_types) + assert isinstance(doc.some_long, six.integer_types) def test_long_validation(self): """Ensure that invalid values cannot be assigned to long fields. @@ -41,11 +42,14 @@ class TestLongField(MongoDBTestCase): doc.validate() doc.value = -1 - self.assertRaises(ValidationError, doc.validate) + with pytest.raises(ValidationError): + doc.validate() doc.value = 120 - self.assertRaises(ValidationError, doc.validate) + with pytest.raises(ValidationError): + doc.validate() doc.value = "ten" - self.assertRaises(ValidationError, doc.validate) + with pytest.raises(ValidationError): + doc.validate() def test_long_ne_operator(self): class TestDocument(Document): @@ -56,4 +60,4 @@ class TestLongField(MongoDBTestCase): TestDocument(long_fld=None).save() TestDocument(long_fld=1).save() - self.assertEqual(1, TestDocument.objects(long_fld__ne=None).count()) + assert 1 == TestDocument.objects(long_fld__ne=None).count() diff --git a/tests/fields/test_map_field.py b/tests/fields/test_map_field.py index 54f70aa1..fd56ddd0 100644 --- a/tests/fields/test_map_field.py +++ b/tests/fields/test_map_field.py @@ -4,6 +4,7 @@ import datetime from mongoengine import * from tests.utils import MongoDBTestCase +import pytest class TestMapField(MongoDBTestCase): @@ -19,11 +20,11 @@ class TestMapField(MongoDBTestCase): e.mapping["someint"] = 1 e.save() - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): e.mapping["somestring"] = "abc" e.save() - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): class NoDeclaredType(Document): mapping = MapField() @@ -51,10 +52,10 @@ class TestMapField(MongoDBTestCase): e.save() e2 = Extensible.objects.get(id=e.id) - self.assertIsInstance(e2.mapping["somestring"], StringSetting) - self.assertIsInstance(e2.mapping["someint"], IntegerSetting) + assert isinstance(e2.mapping["somestring"], StringSetting) + assert isinstance(e2.mapping["someint"], IntegerSetting) - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): e.mapping["someint"] = 123 e.save() @@ -74,9 +75,9 @@ class TestMapField(MongoDBTestCase): Test.objects.update_one(inc__my_map__DICTIONARY_KEY__number=1) test = Test.objects.get() - self.assertEqual(test.my_map["DICTIONARY_KEY"].number, 2) + assert test.my_map["DICTIONARY_KEY"].number == 2 doc = self.db.test.find_one() - self.assertEqual(doc["x"]["DICTIONARY_KEY"]["i"], 2) + assert doc["x"]["DICTIONARY_KEY"]["i"] == 2 def test_mapfield_numerical_index(self): """Ensure that MapField accept numeric strings as indexes.""" @@ -116,13 +117,13 @@ class TestMapField(MongoDBTestCase): actions={"friends": Action(operation="drink", object="beer")}, ).save() - self.assertEqual(1, Log.objects(visited__friends__exists=True).count()) + assert 1 == Log.objects(visited__friends__exists=True).count() - self.assertEqual( - 1, - Log.objects( + assert ( + 1 + == Log.objects( actions__friends__operation="drink", actions__friends__object="beer" - ).count(), + ).count() ) def test_map_field_unicode(self): @@ -139,7 +140,7 @@ class TestMapField(MongoDBTestCase): tree.save() - self.assertEqual( - BlogPost.objects.get(id=tree.id).info_dict[u"éééé"].description, - u"VALUE: éééé", + assert ( + BlogPost.objects.get(id=tree.id).info_dict[u"éééé"].description + == u"VALUE: éééé" ) diff --git a/tests/fields/test_reference_field.py b/tests/fields/test_reference_field.py index 783a46da..783d1315 100644 --- a/tests/fields/test_reference_field.py +++ b/tests/fields/test_reference_field.py @@ -4,6 +4,7 @@ from bson import DBRef, SON from mongoengine import * from tests.utils import MongoDBTestCase +import pytest class TestReferenceField(MongoDBTestCase): @@ -24,19 +25,22 @@ class TestReferenceField(MongoDBTestCase): # Make sure ReferenceField only accepts a document class or a string # with a document class name. - self.assertRaises(ValidationError, ReferenceField, EmbeddedDocument) + with pytest.raises(ValidationError): + ReferenceField(EmbeddedDocument) user = User(name="Test User") # Ensure that the referenced object must have been saved post1 = BlogPost(content="Chips and gravy taste good.") post1.author = user - self.assertRaises(ValidationError, post1.save) + with pytest.raises(ValidationError): + post1.save() # Check that an invalid object type cannot be used post2 = BlogPost(content="Chips and chilli taste good.") post1.author = post2 - self.assertRaises(ValidationError, post1.validate) + with pytest.raises(ValidationError): + post1.validate() # Ensure ObjectID's are accepted as references user_object_id = user.pk @@ -52,7 +56,8 @@ class TestReferenceField(MongoDBTestCase): # Make sure referencing a saved document of the *wrong* type fails post2.save() post1.author = post2 - self.assertRaises(ValidationError, post1.validate) + with pytest.raises(ValidationError): + post1.validate() def test_objectid_reference_fields(self): """Make sure storing Object ID references works.""" @@ -67,7 +72,7 @@ class TestReferenceField(MongoDBTestCase): Person(name="Ross", parent=p1.pk).save() p = Person.objects.get(name="Ross") - self.assertEqual(p.parent, p1) + assert p.parent == p1 def test_dbref_reference_fields(self): """Make sure storing references as bson.dbref.DBRef works.""" @@ -81,13 +86,12 @@ class TestReferenceField(MongoDBTestCase): p1 = Person(name="John").save() Person(name="Ross", parent=p1).save() - self.assertEqual( - Person._get_collection().find_one({"name": "Ross"})["parent"], - DBRef("person", p1.pk), + assert Person._get_collection().find_one({"name": "Ross"})["parent"] == DBRef( + "person", p1.pk ) p = Person.objects.get(name="Ross") - self.assertEqual(p.parent, p1) + assert p.parent == p1 def test_dbref_to_mongo(self): """Make sure that calling to_mongo on a ReferenceField which @@ -100,9 +104,7 @@ class TestReferenceField(MongoDBTestCase): parent = ReferenceField("self", dbref=False) p = Person(name="Steve", parent=DBRef("person", "abcdefghijklmnop")) - self.assertEqual( - p.to_mongo(), SON([("name", u"Steve"), ("parent", "abcdefghijklmnop")]) - ) + assert p.to_mongo() == SON([("name", u"Steve"), ("parent", "abcdefghijklmnop")]) def test_objectid_reference_fields(self): class Person(Document): @@ -116,10 +118,10 @@ class TestReferenceField(MongoDBTestCase): col = Person._get_collection() data = col.find_one({"name": "Ross"}) - self.assertEqual(data["parent"], p1.pk) + assert data["parent"] == p1.pk p = Person.objects.get(name="Ross") - self.assertEqual(p.parent, p1) + assert p.parent == p1 def test_undefined_reference(self): """Ensure that ReferenceFields may reference undefined Documents. @@ -144,14 +146,14 @@ class TestReferenceField(MongoDBTestCase): me.save() obj = Product.objects(company=ten_gen).first() - self.assertEqual(obj, mongodb) - self.assertEqual(obj.company, ten_gen) + assert obj == mongodb + assert obj.company == ten_gen obj = Product.objects(company=None).first() - self.assertEqual(obj, me) + assert obj == me obj = Product.objects.get(company=None) - self.assertEqual(obj, me) + assert obj == me def test_reference_query_conversion(self): """Ensure that ReferenceFields can be queried using objects and values @@ -180,10 +182,10 @@ class TestReferenceField(MongoDBTestCase): post2.save() post = BlogPost.objects(author=m1).first() - self.assertEqual(post.id, post1.id) + assert post.id == post1.id post = BlogPost.objects(author=m2).first() - self.assertEqual(post.id, post2.id) + assert post.id == post2.id def test_reference_query_conversion_dbref(self): """Ensure that ReferenceFields can be queried using objects and values @@ -212,7 +214,7 @@ class TestReferenceField(MongoDBTestCase): post2.save() post = BlogPost.objects(author=m1).first() - self.assertEqual(post.id, post1.id) + assert post.id == post1.id post = BlogPost.objects(author=m2).first() - self.assertEqual(post.id, post2.id) + assert post.id == post2.id diff --git a/tests/fields/test_sequence_field.py b/tests/fields/test_sequence_field.py index f2c8388b..aa83f710 100644 --- a/tests/fields/test_sequence_field.py +++ b/tests/fields/test_sequence_field.py @@ -18,17 +18,17 @@ class TestSequenceField(MongoDBTestCase): Person(name="Person %s" % x).save() c = self.db["mongoengine.counters"].find_one({"_id": "person.id"}) - self.assertEqual(c["next"], 10) + assert c["next"] == 10 ids = [i.id for i in Person.objects] - self.assertEqual(ids, range(1, 11)) + assert ids == range(1, 11) c = self.db["mongoengine.counters"].find_one({"_id": "person.id"}) - self.assertEqual(c["next"], 10) + assert c["next"] == 10 Person.id.set_next_value(1000) c = self.db["mongoengine.counters"].find_one({"_id": "person.id"}) - self.assertEqual(c["next"], 1000) + assert c["next"] == 1000 def test_sequence_field_get_next_value(self): class Person(Document): @@ -41,10 +41,10 @@ class TestSequenceField(MongoDBTestCase): for x in range(10): Person(name="Person %s" % x).save() - self.assertEqual(Person.id.get_next_value(), 11) + assert Person.id.get_next_value() == 11 self.db["mongoengine.counters"].drop() - self.assertEqual(Person.id.get_next_value(), 1) + assert Person.id.get_next_value() == 1 class Person(Document): id = SequenceField(primary_key=True, value_decorator=str) @@ -56,10 +56,10 @@ class TestSequenceField(MongoDBTestCase): for x in range(10): Person(name="Person %s" % x).save() - self.assertEqual(Person.id.get_next_value(), "11") + assert Person.id.get_next_value() == "11" self.db["mongoengine.counters"].drop() - self.assertEqual(Person.id.get_next_value(), "1") + assert Person.id.get_next_value() == "1" def test_sequence_field_sequence_name(self): class Person(Document): @@ -73,17 +73,17 @@ class TestSequenceField(MongoDBTestCase): Person(name="Person %s" % x).save() c = self.db["mongoengine.counters"].find_one({"_id": "jelly.id"}) - self.assertEqual(c["next"], 10) + assert c["next"] == 10 ids = [i.id for i in Person.objects] - self.assertEqual(ids, range(1, 11)) + assert ids == range(1, 11) c = self.db["mongoengine.counters"].find_one({"_id": "jelly.id"}) - self.assertEqual(c["next"], 10) + assert c["next"] == 10 Person.id.set_next_value(1000) c = self.db["mongoengine.counters"].find_one({"_id": "jelly.id"}) - self.assertEqual(c["next"], 1000) + assert c["next"] == 1000 def test_multiple_sequence_fields(self): class Person(Document): @@ -98,24 +98,24 @@ class TestSequenceField(MongoDBTestCase): Person(name="Person %s" % x).save() c = self.db["mongoengine.counters"].find_one({"_id": "person.id"}) - self.assertEqual(c["next"], 10) + assert c["next"] == 10 ids = [i.id for i in Person.objects] - self.assertEqual(ids, range(1, 11)) + assert ids == range(1, 11) counters = [i.counter for i in Person.objects] - self.assertEqual(counters, range(1, 11)) + assert counters == range(1, 11) c = self.db["mongoengine.counters"].find_one({"_id": "person.id"}) - self.assertEqual(c["next"], 10) + assert c["next"] == 10 Person.id.set_next_value(1000) c = self.db["mongoengine.counters"].find_one({"_id": "person.id"}) - self.assertEqual(c["next"], 1000) + assert c["next"] == 1000 Person.counter.set_next_value(999) c = self.db["mongoengine.counters"].find_one({"_id": "person.counter"}) - self.assertEqual(c["next"], 999) + assert c["next"] == 999 def test_sequence_fields_reload(self): class Animal(Document): @@ -127,20 +127,20 @@ class TestSequenceField(MongoDBTestCase): a = Animal(name="Boi").save() - self.assertEqual(a.counter, 1) + assert a.counter == 1 a.reload() - self.assertEqual(a.counter, 1) + assert a.counter == 1 a.counter = None - self.assertEqual(a.counter, 2) + assert a.counter == 2 a.save() - self.assertEqual(a.counter, 2) + assert a.counter == 2 a = Animal.objects.first() - self.assertEqual(a.counter, 2) + assert a.counter == 2 a.reload() - self.assertEqual(a.counter, 2) + assert a.counter == 2 def test_multiple_sequence_fields_on_docs(self): class Animal(Document): @@ -160,22 +160,22 @@ class TestSequenceField(MongoDBTestCase): Person(name="Person %s" % x).save() c = self.db["mongoengine.counters"].find_one({"_id": "person.id"}) - self.assertEqual(c["next"], 10) + assert c["next"] == 10 c = self.db["mongoengine.counters"].find_one({"_id": "animal.id"}) - self.assertEqual(c["next"], 10) + assert c["next"] == 10 ids = [i.id for i in Person.objects] - self.assertEqual(ids, range(1, 11)) + assert ids == range(1, 11) id = [i.id for i in Animal.objects] - self.assertEqual(id, range(1, 11)) + assert id == range(1, 11) c = self.db["mongoengine.counters"].find_one({"_id": "person.id"}) - self.assertEqual(c["next"], 10) + assert c["next"] == 10 c = self.db["mongoengine.counters"].find_one({"_id": "animal.id"}) - self.assertEqual(c["next"], 10) + assert c["next"] == 10 def test_sequence_field_value_decorator(self): class Person(Document): @@ -190,13 +190,13 @@ class TestSequenceField(MongoDBTestCase): p.save() c = self.db["mongoengine.counters"].find_one({"_id": "person.id"}) - self.assertEqual(c["next"], 10) + assert c["next"] == 10 ids = [i.id for i in Person.objects] - self.assertEqual(ids, map(str, range(1, 11))) + assert ids == map(str, range(1, 11)) c = self.db["mongoengine.counters"].find_one({"_id": "person.id"}) - self.assertEqual(c["next"], 10) + assert c["next"] == 10 def test_embedded_sequence_field(self): class Comment(EmbeddedDocument): @@ -218,10 +218,10 @@ class TestSequenceField(MongoDBTestCase): ], ).save() c = self.db["mongoengine.counters"].find_one({"_id": "comment.id"}) - self.assertEqual(c["next"], 2) + assert c["next"] == 2 post = Post.objects.first() - self.assertEqual(1, post.comments[0].id) - self.assertEqual(2, post.comments[1].id) + assert 1 == post.comments[0].id + assert 2 == post.comments[1].id def test_inherited_sequencefield(self): class Base(Document): @@ -241,16 +241,14 @@ class TestSequenceField(MongoDBTestCase): foo = Foo(name="Foo") foo.save() - self.assertTrue( - "base.counter" in self.db["mongoengine.counters"].find().distinct("_id") - ) - self.assertFalse( + assert "base.counter" in self.db["mongoengine.counters"].find().distinct("_id") + assert not ( ("foo.counter" or "bar.counter") in self.db["mongoengine.counters"].find().distinct("_id") ) - self.assertNotEqual(foo.counter, bar.counter) - self.assertEqual(foo._fields["counter"].owner_document, Base) - self.assertEqual(bar._fields["counter"].owner_document, Base) + assert foo.counter != bar.counter + assert foo._fields["counter"].owner_document == Base + assert bar._fields["counter"].owner_document == Base def test_no_inherited_sequencefield(self): class Base(Document): @@ -269,13 +267,12 @@ class TestSequenceField(MongoDBTestCase): foo = Foo(name="Foo") foo.save() - self.assertFalse( + assert not ( "base.counter" in self.db["mongoengine.counters"].find().distinct("_id") ) - self.assertTrue( - ("foo.counter" and "bar.counter") - in self.db["mongoengine.counters"].find().distinct("_id") - ) - self.assertEqual(foo.counter, bar.counter) - self.assertEqual(foo._fields["counter"].owner_document, Foo) - self.assertEqual(bar._fields["counter"].owner_document, Bar) + assert ("foo.counter" and "bar.counter") in self.db[ + "mongoengine.counters" + ].find().distinct("_id") + assert foo.counter == bar.counter + assert foo._fields["counter"].owner_document == Foo + assert bar._fields["counter"].owner_document == Bar diff --git a/tests/fields/test_url_field.py b/tests/fields/test_url_field.py index 81baf8d0..e7df0e08 100644 --- a/tests/fields/test_url_field.py +++ b/tests/fields/test_url_field.py @@ -2,6 +2,7 @@ from mongoengine import * from tests.utils import MongoDBTestCase +import pytest class TestURLField(MongoDBTestCase): @@ -13,7 +14,8 @@ class TestURLField(MongoDBTestCase): link = Link() link.url = "google" - self.assertRaises(ValidationError, link.validate) + with pytest.raises(ValidationError): + link.validate() link.url = "http://www.google.com:8080" link.validate() @@ -29,11 +31,11 @@ class TestURLField(MongoDBTestCase): # TODO fix URL validation - this *IS* a valid URL # For now we just want to make sure that the error message is correct - with self.assertRaises(ValidationError) as ctx_err: + with pytest.raises(ValidationError) as ctx_err: link.validate() - self.assertEqual( - unicode(ctx_err.exception), - u"ValidationError (Link:None) (Invalid URL: http://\u043f\u0440\u0438\u0432\u0435\u0442.com: ['url'])", + assert ( + unicode(ctx_err.exception) + == u"ValidationError (Link:None) (Invalid URL: http://\u043f\u0440\u0438\u0432\u0435\u0442.com: ['url'])" ) def test_url_scheme_validation(self): @@ -48,7 +50,8 @@ class TestURLField(MongoDBTestCase): link = Link() link.url = "ws://google.com" - self.assertRaises(ValidationError, link.validate) + with pytest.raises(ValidationError): + link.validate() scheme_link = SchemeLink() scheme_link.url = "ws://google.com" diff --git a/tests/fields/test_uuid_field.py b/tests/fields/test_uuid_field.py index 647dceaf..b1413f95 100644 --- a/tests/fields/test_uuid_field.py +++ b/tests/fields/test_uuid_field.py @@ -4,6 +4,7 @@ import uuid from mongoengine import * from tests.utils import MongoDBTestCase, get_as_pymongo +import pytest class Person(Document): @@ -14,9 +15,7 @@ class TestUUIDField(MongoDBTestCase): def test_storage(self): uid = uuid.uuid4() person = Person(api_key=uid).save() - self.assertEqual( - get_as_pymongo(person), {"_id": person.id, "api_key": str(uid)} - ) + assert get_as_pymongo(person) == {"_id": person.id, "api_key": str(uid)} def test_field_string(self): """Test UUID fields storing as String @@ -25,8 +24,8 @@ class TestUUIDField(MongoDBTestCase): uu = uuid.uuid4() Person(api_key=uu).save() - self.assertEqual(1, Person.objects(api_key=uu).count()) - self.assertEqual(uu, Person.objects.first().api_key) + assert 1 == Person.objects(api_key=uu).count() + assert uu == Person.objects.first().api_key person = Person() valid = (uuid.uuid4(), uuid.uuid1()) @@ -40,7 +39,8 @@ class TestUUIDField(MongoDBTestCase): ) for api_key in invalid: person.api_key = api_key - self.assertRaises(ValidationError, person.validate) + with pytest.raises(ValidationError): + person.validate() def test_field_binary(self): """Test UUID fields storing as Binary object.""" @@ -48,8 +48,8 @@ class TestUUIDField(MongoDBTestCase): uu = uuid.uuid4() Person(api_key=uu).save() - self.assertEqual(1, Person.objects(api_key=uu).count()) - self.assertEqual(uu, Person.objects.first().api_key) + assert 1 == Person.objects(api_key=uu).count() + assert uu == Person.objects.first().api_key person = Person() valid = (uuid.uuid4(), uuid.uuid1()) @@ -63,4 +63,5 @@ class TestUUIDField(MongoDBTestCase): ) for api_key in invalid: person.api_key = api_key - self.assertRaises(ValidationError, person.validate) + with pytest.raises(ValidationError): + person.validate() diff --git a/tests/queryset/test_field_list.py b/tests/queryset/test_field_list.py index 703c2031..d33c4c86 100644 --- a/tests/queryset/test_field_list.py +++ b/tests/queryset/test_field_list.py @@ -2,66 +2,67 @@ import unittest from mongoengine import * from mongoengine.queryset import QueryFieldList +import pytest class TestQueryFieldList(unittest.TestCase): def test_empty(self): q = QueryFieldList() - self.assertFalse(q) + assert not q q = QueryFieldList(always_include=["_cls"]) - self.assertFalse(q) + assert not q def test_include_include(self): q = QueryFieldList() q += QueryFieldList( fields=["a", "b"], value=QueryFieldList.ONLY, _only_called=True ) - self.assertEqual(q.as_dict(), {"a": 1, "b": 1}) + assert q.as_dict() == {"a": 1, "b": 1} q += QueryFieldList(fields=["b", "c"], value=QueryFieldList.ONLY) - self.assertEqual(q.as_dict(), {"a": 1, "b": 1, "c": 1}) + assert q.as_dict() == {"a": 1, "b": 1, "c": 1} def test_include_exclude(self): q = QueryFieldList() q += QueryFieldList(fields=["a", "b"], value=QueryFieldList.ONLY) - self.assertEqual(q.as_dict(), {"a": 1, "b": 1}) + assert q.as_dict() == {"a": 1, "b": 1} q += QueryFieldList(fields=["b", "c"], value=QueryFieldList.EXCLUDE) - self.assertEqual(q.as_dict(), {"a": 1}) + assert q.as_dict() == {"a": 1} def test_exclude_exclude(self): q = QueryFieldList() q += QueryFieldList(fields=["a", "b"], value=QueryFieldList.EXCLUDE) - self.assertEqual(q.as_dict(), {"a": 0, "b": 0}) + assert q.as_dict() == {"a": 0, "b": 0} q += QueryFieldList(fields=["b", "c"], value=QueryFieldList.EXCLUDE) - self.assertEqual(q.as_dict(), {"a": 0, "b": 0, "c": 0}) + assert q.as_dict() == {"a": 0, "b": 0, "c": 0} def test_exclude_include(self): q = QueryFieldList() q += QueryFieldList(fields=["a", "b"], value=QueryFieldList.EXCLUDE) - self.assertEqual(q.as_dict(), {"a": 0, "b": 0}) + assert q.as_dict() == {"a": 0, "b": 0} q += QueryFieldList(fields=["b", "c"], value=QueryFieldList.ONLY) - self.assertEqual(q.as_dict(), {"c": 1}) + assert q.as_dict() == {"c": 1} def test_always_include(self): q = QueryFieldList(always_include=["x", "y"]) q += QueryFieldList(fields=["a", "b", "x"], value=QueryFieldList.EXCLUDE) q += QueryFieldList(fields=["b", "c"], value=QueryFieldList.ONLY) - self.assertEqual(q.as_dict(), {"x": 1, "y": 1, "c": 1}) + assert q.as_dict() == {"x": 1, "y": 1, "c": 1} def test_reset(self): q = QueryFieldList(always_include=["x", "y"]) q += QueryFieldList(fields=["a", "b", "x"], value=QueryFieldList.EXCLUDE) q += QueryFieldList(fields=["b", "c"], value=QueryFieldList.ONLY) - self.assertEqual(q.as_dict(), {"x": 1, "y": 1, "c": 1}) + assert q.as_dict() == {"x": 1, "y": 1, "c": 1} q.reset() - self.assertFalse(q) + assert not q q += QueryFieldList(fields=["b", "c"], value=QueryFieldList.ONLY) - self.assertEqual(q.as_dict(), {"x": 1, "y": 1, "b": 1, "c": 1}) + assert q.as_dict() == {"x": 1, "y": 1, "b": 1, "c": 1} def test_using_a_slice(self): q = QueryFieldList() q += QueryFieldList(fields=["a"], value={"$slice": 5}) - self.assertEqual(q.as_dict(), {"a": {"$slice": 5}}) + assert q.as_dict() == {"a": {"$slice": 5}} class TestOnlyExcludeAll(unittest.TestCase): @@ -90,25 +91,23 @@ class TestOnlyExcludeAll(unittest.TestCase): only = ["b", "c"] qs = MyDoc.objects.fields(**{i: 1 for i in include}) - self.assertEqual( - qs._loaded_fields.as_dict(), {"a": 1, "b": 1, "c": 1, "d": 1, "e": 1} - ) + assert qs._loaded_fields.as_dict() == {"a": 1, "b": 1, "c": 1, "d": 1, "e": 1} qs = qs.only(*only) - self.assertEqual(qs._loaded_fields.as_dict(), {"b": 1, "c": 1}) + assert qs._loaded_fields.as_dict() == {"b": 1, "c": 1} qs = qs.exclude(*exclude) - self.assertEqual(qs._loaded_fields.as_dict(), {"b": 1, "c": 1}) + assert qs._loaded_fields.as_dict() == {"b": 1, "c": 1} qs = MyDoc.objects.fields(**{i: 1 for i in include}) qs = qs.exclude(*exclude) - self.assertEqual(qs._loaded_fields.as_dict(), {"a": 1, "b": 1, "c": 1}) + assert qs._loaded_fields.as_dict() == {"a": 1, "b": 1, "c": 1} qs = qs.only(*only) - self.assertEqual(qs._loaded_fields.as_dict(), {"b": 1, "c": 1}) + assert qs._loaded_fields.as_dict() == {"b": 1, "c": 1} qs = MyDoc.objects.exclude(*exclude) qs = qs.fields(**{i: 1 for i in include}) - self.assertEqual(qs._loaded_fields.as_dict(), {"a": 1, "b": 1, "c": 1}) + assert qs._loaded_fields.as_dict() == {"a": 1, "b": 1, "c": 1} qs = qs.only(*only) - self.assertEqual(qs._loaded_fields.as_dict(), {"b": 1, "c": 1}) + assert qs._loaded_fields.as_dict() == {"b": 1, "c": 1} def test_slicing(self): class MyDoc(Document): @@ -127,15 +126,16 @@ class TestOnlyExcludeAll(unittest.TestCase): qs = qs.exclude(*exclude) qs = qs.only(*only) qs = qs.fields(slice__b=5) - self.assertEqual(qs._loaded_fields.as_dict(), {"b": {"$slice": 5}, "c": 1}) + assert qs._loaded_fields.as_dict() == {"b": {"$slice": 5}, "c": 1} qs = qs.fields(slice__c=[5, 1]) - self.assertEqual( - qs._loaded_fields.as_dict(), {"b": {"$slice": 5}, "c": {"$slice": [5, 1]}} - ) + assert qs._loaded_fields.as_dict() == { + "b": {"$slice": 5}, + "c": {"$slice": [5, 1]}, + } qs = qs.exclude("c") - self.assertEqual(qs._loaded_fields.as_dict(), {"b": {"$slice": 5}}) + assert qs._loaded_fields.as_dict() == {"b": {"$slice": 5}} def test_mix_slice_with_other_fields(self): class MyDoc(Document): @@ -144,7 +144,7 @@ class TestOnlyExcludeAll(unittest.TestCase): c = ListField() qs = MyDoc.objects.fields(a=1, b=0, slice__c=2) - self.assertEqual(qs._loaded_fields.as_dict(), {"c": {"$slice": 2}, "a": 1}) + assert qs._loaded_fields.as_dict() == {"c": {"$slice": 2}, "a": 1} def test_only(self): """Ensure that QuerySet.only only returns the requested fields. @@ -153,20 +153,20 @@ class TestOnlyExcludeAll(unittest.TestCase): person.save() obj = self.Person.objects.only("name").get() - self.assertEqual(obj.name, person.name) - self.assertEqual(obj.age, None) + assert obj.name == person.name + assert obj.age == None obj = self.Person.objects.only("age").get() - self.assertEqual(obj.name, None) - self.assertEqual(obj.age, person.age) + assert obj.name == None + assert obj.age == person.age obj = self.Person.objects.only("name", "age").get() - self.assertEqual(obj.name, person.name) - self.assertEqual(obj.age, person.age) + assert obj.name == person.name + assert obj.age == person.age obj = self.Person.objects.only(*("id", "name")).get() - self.assertEqual(obj.name, person.name) - self.assertEqual(obj.age, None) + assert obj.name == person.name + assert obj.age == None # Check polymorphism still works class Employee(self.Person): @@ -176,12 +176,12 @@ class TestOnlyExcludeAll(unittest.TestCase): employee.save() obj = self.Person.objects(id=employee.id).only("age").get() - self.assertIsInstance(obj, Employee) + assert isinstance(obj, Employee) # Check field names are looked up properly obj = Employee.objects(id=employee.id).only("salary").get() - self.assertEqual(obj.salary, employee.salary) - self.assertEqual(obj.name, None) + assert obj.salary == employee.salary + assert obj.name == None def test_only_with_subfields(self): class User(EmbeddedDocument): @@ -215,29 +215,29 @@ class TestOnlyExcludeAll(unittest.TestCase): post.save() obj = BlogPost.objects.only("author.name").get() - self.assertEqual(obj.content, None) - self.assertEqual(obj.author.email, None) - self.assertEqual(obj.author.name, "Test User") - self.assertEqual(obj.comments, []) + assert obj.content == None + assert obj.author.email == None + assert obj.author.name == "Test User" + assert obj.comments == [] obj = BlogPost.objects.only("various.test_dynamic.some").get() - self.assertEqual(obj.various["test_dynamic"].some, True) + assert obj.various["test_dynamic"].some == True obj = BlogPost.objects.only("content", "comments.title").get() - self.assertEqual(obj.content, "Had a good coffee today...") - self.assertEqual(obj.author, None) - self.assertEqual(obj.comments[0].title, "I aggree") - self.assertEqual(obj.comments[1].title, "Coffee") - self.assertEqual(obj.comments[0].text, None) - self.assertEqual(obj.comments[1].text, None) + assert obj.content == "Had a good coffee today..." + assert obj.author == None + assert obj.comments[0].title == "I aggree" + assert obj.comments[1].title == "Coffee" + assert obj.comments[0].text == None + assert obj.comments[1].text == None obj = BlogPost.objects.only("comments").get() - self.assertEqual(obj.content, None) - self.assertEqual(obj.author, None) - self.assertEqual(obj.comments[0].title, "I aggree") - self.assertEqual(obj.comments[1].title, "Coffee") - self.assertEqual(obj.comments[0].text, "Great post!") - self.assertEqual(obj.comments[1].text, "I hate coffee") + assert obj.content == None + assert obj.author == None + assert obj.comments[0].title == "I aggree" + assert obj.comments[1].title == "Coffee" + assert obj.comments[0].text == "Great post!" + assert obj.comments[1].text == "I hate coffee" BlogPost.drop_collection() @@ -266,10 +266,10 @@ class TestOnlyExcludeAll(unittest.TestCase): post.save() obj = BlogPost.objects.exclude("author", "comments.text").get() - self.assertEqual(obj.author, None) - self.assertEqual(obj.content, "Had a good coffee today...") - self.assertEqual(obj.comments[0].title, "I aggree") - self.assertEqual(obj.comments[0].text, None) + assert obj.author == None + assert obj.content == "Had a good coffee today..." + assert obj.comments[0].title == "I aggree" + assert obj.comments[0].text == None BlogPost.drop_collection() @@ -301,18 +301,18 @@ class TestOnlyExcludeAll(unittest.TestCase): email.save() obj = Email.objects.exclude("content_type").exclude("body").get() - self.assertEqual(obj.sender, "me") - self.assertEqual(obj.to, "you") - self.assertEqual(obj.subject, "From Russia with Love") - self.assertEqual(obj.body, None) - self.assertEqual(obj.content_type, None) + assert obj.sender == "me" + assert obj.to == "you" + assert obj.subject == "From Russia with Love" + assert obj.body == None + assert obj.content_type == None obj = Email.objects.only("sender", "to").exclude("body", "sender").get() - self.assertEqual(obj.sender, None) - self.assertEqual(obj.to, "you") - self.assertEqual(obj.subject, None) - self.assertEqual(obj.body, None) - self.assertEqual(obj.content_type, None) + assert obj.sender == None + assert obj.to == "you" + assert obj.subject == None + assert obj.body == None + assert obj.content_type == None obj = ( Email.objects.exclude("attachments.content") @@ -320,13 +320,13 @@ class TestOnlyExcludeAll(unittest.TestCase): .only("to", "attachments.name") .get() ) - self.assertEqual(obj.attachments[0].name, "file1.doc") - self.assertEqual(obj.attachments[0].content, None) - self.assertEqual(obj.sender, None) - self.assertEqual(obj.to, "you") - self.assertEqual(obj.subject, None) - self.assertEqual(obj.body, None) - self.assertEqual(obj.content_type, None) + assert obj.attachments[0].name == "file1.doc" + assert obj.attachments[0].content == None + assert obj.sender == None + assert obj.to == "you" + assert obj.subject == None + assert obj.body == None + assert obj.content_type == None Email.drop_collection() @@ -355,11 +355,11 @@ class TestOnlyExcludeAll(unittest.TestCase): .all_fields() .get() ) - self.assertEqual(obj.sender, "me") - self.assertEqual(obj.to, "you") - self.assertEqual(obj.subject, "From Russia with Love") - self.assertEqual(obj.body, "Hello!") - self.assertEqual(obj.content_type, "text/plain") + assert obj.sender == "me" + assert obj.to == "you" + assert obj.subject == "From Russia with Love" + assert obj.body == "Hello!" + assert obj.content_type == "text/plain" Email.drop_collection() @@ -377,27 +377,27 @@ class TestOnlyExcludeAll(unittest.TestCase): # first three numbers = Numbers.objects.fields(slice__n=3).get() - self.assertEqual(numbers.n, [0, 1, 2]) + assert numbers.n == [0, 1, 2] # last three numbers = Numbers.objects.fields(slice__n=-3).get() - self.assertEqual(numbers.n, [-3, -2, -1]) + assert numbers.n == [-3, -2, -1] # skip 2, limit 3 numbers = Numbers.objects.fields(slice__n=[2, 3]).get() - self.assertEqual(numbers.n, [2, 3, 4]) + assert numbers.n == [2, 3, 4] # skip to fifth from last, limit 4 numbers = Numbers.objects.fields(slice__n=[-5, 4]).get() - self.assertEqual(numbers.n, [-5, -4, -3, -2]) + assert numbers.n == [-5, -4, -3, -2] # skip to fifth from last, limit 10 numbers = Numbers.objects.fields(slice__n=[-5, 10]).get() - self.assertEqual(numbers.n, [-5, -4, -3, -2, -1]) + assert numbers.n == [-5, -4, -3, -2, -1] # skip to fifth from last, limit 10 dict method numbers = Numbers.objects.fields(n={"$slice": [-5, 10]}).get() - self.assertEqual(numbers.n, [-5, -4, -3, -2, -1]) + assert numbers.n == [-5, -4, -3, -2, -1] def test_slicing_nested_fields(self): """Ensure that query slicing an embedded array works. @@ -417,27 +417,27 @@ class TestOnlyExcludeAll(unittest.TestCase): # first three numbers = Numbers.objects.fields(slice__embedded__n=3).get() - self.assertEqual(numbers.embedded.n, [0, 1, 2]) + assert numbers.embedded.n == [0, 1, 2] # last three numbers = Numbers.objects.fields(slice__embedded__n=-3).get() - self.assertEqual(numbers.embedded.n, [-3, -2, -1]) + assert numbers.embedded.n == [-3, -2, -1] # skip 2, limit 3 numbers = Numbers.objects.fields(slice__embedded__n=[2, 3]).get() - self.assertEqual(numbers.embedded.n, [2, 3, 4]) + assert numbers.embedded.n == [2, 3, 4] # skip to fifth from last, limit 4 numbers = Numbers.objects.fields(slice__embedded__n=[-5, 4]).get() - self.assertEqual(numbers.embedded.n, [-5, -4, -3, -2]) + assert numbers.embedded.n == [-5, -4, -3, -2] # skip to fifth from last, limit 10 numbers = Numbers.objects.fields(slice__embedded__n=[-5, 10]).get() - self.assertEqual(numbers.embedded.n, [-5, -4, -3, -2, -1]) + assert numbers.embedded.n == [-5, -4, -3, -2, -1] # skip to fifth from last, limit 10 dict method numbers = Numbers.objects.fields(embedded__n={"$slice": [-5, 10]}).get() - self.assertEqual(numbers.embedded.n, [-5, -4, -3, -2, -1]) + assert numbers.embedded.n == [-5, -4, -3, -2, -1] def test_exclude_from_subclasses_docs(self): class Base(Document): @@ -456,9 +456,10 @@ class TestOnlyExcludeAll(unittest.TestCase): User(username="mongodb", password="secret").save() user = Base.objects().exclude("password", "wibble").first() - self.assertEqual(user.password, None) + assert user.password == None - self.assertRaises(LookUpError, Base.objects.exclude, "made_up") + with pytest.raises(LookUpError): + Base.objects.exclude("made_up") if __name__ == "__main__": diff --git a/tests/queryset/test_geo.py b/tests/queryset/test_geo.py index 343f864b..a546fdb6 100644 --- a/tests/queryset/test_geo.py +++ b/tests/queryset/test_geo.py @@ -48,14 +48,14 @@ class TestGeoQueries(MongoDBTestCase): # note that "near" will show the san francisco event, too, # although it sorts to last. events = self.Event.objects(location__near=[-87.67892, 41.9120459]) - self.assertEqual(events.count(), 3) - self.assertEqual(list(events), [event1, event3, event2]) + assert events.count() == 3 + assert list(events) == [event1, event3, event2] # ensure ordering is respected by "near" events = self.Event.objects(location__near=[-87.67892, 41.9120459]) events = events.order_by("-date") - self.assertEqual(events.count(), 3) - self.assertEqual(list(events), [event3, event1, event2]) + assert events.count() == 3 + assert list(events) == [event3, event1, event2] def test_near_and_max_distance(self): """Ensure the "max_distance" operator works alongside the "near" @@ -66,8 +66,8 @@ class TestGeoQueries(MongoDBTestCase): # find events within 10 degrees of san francisco point = [-122.415579, 37.7566023] events = self.Event.objects(location__near=point, location__max_distance=10) - self.assertEqual(events.count(), 1) - self.assertEqual(events[0], event2) + assert events.count() == 1 + assert events[0] == event2 def test_near_and_min_distance(self): """Ensure the "min_distance" operator works alongside the "near" @@ -78,7 +78,7 @@ class TestGeoQueries(MongoDBTestCase): # find events at least 10 degrees away of san francisco point = [-122.415579, 37.7566023] events = self.Event.objects(location__near=point, location__min_distance=10) - self.assertEqual(events.count(), 2) + assert events.count() == 2 def test_within_distance(self): """Make sure the "within_distance" operator works.""" @@ -87,29 +87,29 @@ class TestGeoQueries(MongoDBTestCase): # find events within 5 degrees of pitchfork office, chicago point_and_distance = [[-87.67892, 41.9120459], 5] events = self.Event.objects(location__within_distance=point_and_distance) - self.assertEqual(events.count(), 2) + assert events.count() == 2 events = list(events) - self.assertNotIn(event2, events) - self.assertIn(event1, events) - self.assertIn(event3, events) + assert event2 not in events + assert event1 in events + assert event3 in events # find events within 10 degrees of san francisco point_and_distance = [[-122.415579, 37.7566023], 10] events = self.Event.objects(location__within_distance=point_and_distance) - self.assertEqual(events.count(), 1) - self.assertEqual(events[0], event2) + assert events.count() == 1 + assert events[0] == event2 # find events within 1 degree of greenpoint, broolyn, nyc, ny point_and_distance = [[-73.9509714, 40.7237134], 1] events = self.Event.objects(location__within_distance=point_and_distance) - self.assertEqual(events.count(), 0) + assert events.count() == 0 # ensure ordering is respected by "within_distance" point_and_distance = [[-87.67892, 41.9120459], 10] events = self.Event.objects(location__within_distance=point_and_distance) events = events.order_by("-date") - self.assertEqual(events.count(), 2) - self.assertEqual(events[0], event3) + assert events.count() == 2 + assert events[0] == event3 def test_within_box(self): """Ensure the "within_box" operator works.""" @@ -118,8 +118,8 @@ class TestGeoQueries(MongoDBTestCase): # check that within_box works box = [(-125.0, 35.0), (-100.0, 40.0)] events = self.Event.objects(location__within_box=box) - self.assertEqual(events.count(), 1) - self.assertEqual(events[0].id, event2.id) + assert events.count() == 1 + assert events[0].id == event2.id def test_within_polygon(self): """Ensure the "within_polygon" operator works.""" @@ -133,8 +133,8 @@ class TestGeoQueries(MongoDBTestCase): (-87.656164, 41.898061), ] events = self.Event.objects(location__within_polygon=polygon) - self.assertEqual(events.count(), 1) - self.assertEqual(events[0].id, event1.id) + assert events.count() == 1 + assert events[0].id == event1.id polygon2 = [ (-1.742249, 54.033586), @@ -142,7 +142,7 @@ class TestGeoQueries(MongoDBTestCase): (-4.40094, 53.389881), ] events = self.Event.objects(location__within_polygon=polygon2) - self.assertEqual(events.count(), 0) + assert events.count() == 0 def test_2dsphere_near(self): """Make sure the "near" operator works with a PointField, which @@ -154,14 +154,14 @@ class TestGeoQueries(MongoDBTestCase): # note that "near" will show the san francisco event, too, # although it sorts to last. events = self.Event.objects(location__near=[-87.67892, 41.9120459]) - self.assertEqual(events.count(), 3) - self.assertEqual(list(events), [event1, event3, event2]) + assert events.count() == 3 + assert list(events) == [event1, event3, event2] # ensure ordering is respected by "near" events = self.Event.objects(location__near=[-87.67892, 41.9120459]) events = events.order_by("-date") - self.assertEqual(events.count(), 3) - self.assertEqual(list(events), [event3, event1, event2]) + assert events.count() == 3 + assert list(events) == [event3, event1, event2] def test_2dsphere_near_and_max_distance(self): """Ensure the "max_distance" operator works alongside the "near" @@ -172,21 +172,21 @@ class TestGeoQueries(MongoDBTestCase): # find events within 10km of san francisco point = [-122.415579, 37.7566023] events = self.Event.objects(location__near=point, location__max_distance=10000) - self.assertEqual(events.count(), 1) - self.assertEqual(events[0], event2) + assert events.count() == 1 + assert events[0] == event2 # find events within 1km of greenpoint, broolyn, nyc, ny events = self.Event.objects( location__near=[-73.9509714, 40.7237134], location__max_distance=1000 ) - self.assertEqual(events.count(), 0) + assert events.count() == 0 # ensure ordering is respected by "near" events = self.Event.objects( location__near=[-87.67892, 41.9120459], location__max_distance=10000 ).order_by("-date") - self.assertEqual(events.count(), 2) - self.assertEqual(events[0], event3) + assert events.count() == 2 + assert events[0] == event3 def test_2dsphere_geo_within_box(self): """Ensure the "geo_within_box" operator works with a 2dsphere @@ -197,8 +197,8 @@ class TestGeoQueries(MongoDBTestCase): # check that within_box works box = [(-125.0, 35.0), (-100.0, 40.0)] events = self.Event.objects(location__geo_within_box=box) - self.assertEqual(events.count(), 1) - self.assertEqual(events[0].id, event2.id) + assert events.count() == 1 + assert events[0].id == event2.id def test_2dsphere_geo_within_polygon(self): """Ensure the "geo_within_polygon" operator works with a @@ -214,8 +214,8 @@ class TestGeoQueries(MongoDBTestCase): (-87.656164, 41.898061), ] events = self.Event.objects(location__geo_within_polygon=polygon) - self.assertEqual(events.count(), 1) - self.assertEqual(events[0].id, event1.id) + assert events.count() == 1 + assert events[0].id == event1.id polygon2 = [ (-1.742249, 54.033586), @@ -223,7 +223,7 @@ class TestGeoQueries(MongoDBTestCase): (-4.40094, 53.389881), ] events = self.Event.objects(location__geo_within_polygon=polygon2) - self.assertEqual(events.count(), 0) + assert events.count() == 0 def test_2dsphere_near_and_min_max_distance(self): """Ensure "min_distace" and "max_distance" operators work well @@ -237,15 +237,15 @@ class TestGeoQueries(MongoDBTestCase): location__min_distance=1000, location__max_distance=10000, ).order_by("-date") - self.assertEqual(events.count(), 1) - self.assertEqual(events[0], event3) + assert events.count() == 1 + assert events[0] == event3 # ensure ordering is respected by "near" with "min_distance" events = self.Event.objects( location__near=[-87.67892, 41.9120459], location__min_distance=10000 ).order_by("-date") - self.assertEqual(events.count(), 1) - self.assertEqual(events[0], event2) + assert events.count() == 1 + assert events[0] == event2 def test_2dsphere_geo_within_center(self): """Make sure the "geo_within_center" operator works with a @@ -256,11 +256,11 @@ class TestGeoQueries(MongoDBTestCase): # find events within 5 degrees of pitchfork office, chicago point_and_distance = [[-87.67892, 41.9120459], 2] events = self.Event.objects(location__geo_within_center=point_and_distance) - self.assertEqual(events.count(), 2) + assert events.count() == 2 events = list(events) - self.assertNotIn(event2, events) - self.assertIn(event1, events) - self.assertIn(event3, events) + assert event2 not in events + assert event1 in events + assert event3 in events def _test_embedded(self, point_field_class): """Helper test method ensuring given point field class works @@ -290,8 +290,8 @@ class TestGeoQueries(MongoDBTestCase): # note that "near" will show the san francisco event, too, # although it sorts to last. events = Event.objects(venue__location__near=[-87.67892, 41.9120459]) - self.assertEqual(events.count(), 3) - self.assertEqual(list(events), [event1, event3, event2]) + assert events.count() == 3 + assert list(events) == [event1, event3, event2] def test_geo_spatial_embedded(self): """Make sure GeoPointField works properly in an embedded document.""" @@ -319,55 +319,55 @@ class TestGeoQueries(MongoDBTestCase): # Finds both points because they are within 60 km of the reference # point equidistant between them. points = Point.objects(location__near_sphere=[-122, 37.5]) - self.assertEqual(points.count(), 2) + assert points.count() == 2 # Same behavior for _within_spherical_distance points = Point.objects( location__within_spherical_distance=[[-122, 37.5], 60 / earth_radius] ) - self.assertEqual(points.count(), 2) + assert points.count() == 2 points = Point.objects( location__near_sphere=[-122, 37.5], location__max_distance=60 / earth_radius ) - self.assertEqual(points.count(), 2) + assert points.count() == 2 # Test query works with max_distance, being farer from one point points = Point.objects( location__near_sphere=[-122, 37.8], location__max_distance=60 / earth_radius ) close_point = points.first() - self.assertEqual(points.count(), 1) + assert points.count() == 1 # Test query works with min_distance, being farer from one point points = Point.objects( location__near_sphere=[-122, 37.8], location__min_distance=60 / earth_radius ) - self.assertEqual(points.count(), 1) + assert points.count() == 1 far_point = points.first() - self.assertNotEqual(close_point, far_point) + assert close_point != far_point # Finds both points, but orders the north point first because it's # closer to the reference point to the north. points = Point.objects(location__near_sphere=[-122, 38.5]) - self.assertEqual(points.count(), 2) - self.assertEqual(points[0].id, north_point.id) - self.assertEqual(points[1].id, south_point.id) + assert points.count() == 2 + assert points[0].id == north_point.id + assert points[1].id == south_point.id # Finds both points, but orders the south point first because it's # closer to the reference point to the south. points = Point.objects(location__near_sphere=[-122, 36.5]) - self.assertEqual(points.count(), 2) - self.assertEqual(points[0].id, south_point.id) - self.assertEqual(points[1].id, north_point.id) + assert points.count() == 2 + assert points[0].id == south_point.id + assert points[1].id == north_point.id # Finds only one point because only the first point is within 60km of # the reference point to the south. points = Point.objects( location__within_spherical_distance=[[-122, 36.5], 60 / earth_radius] ) - self.assertEqual(points.count(), 1) - self.assertEqual(points[0].id, south_point.id) + assert points.count() == 1 + assert points[0].id == south_point.id def test_linestring(self): class Road(Document): @@ -381,13 +381,13 @@ class TestGeoQueries(MongoDBTestCase): # near point = {"type": "Point", "coordinates": [40, 5]} roads = Road.objects.filter(line__near=point["coordinates"]).count() - self.assertEqual(1, roads) + assert 1 == roads roads = Road.objects.filter(line__near=point).count() - self.assertEqual(1, roads) + assert 1 == roads roads = Road.objects.filter(line__near={"$geometry": point}).count() - self.assertEqual(1, roads) + assert 1 == roads # Within polygon = { @@ -395,37 +395,37 @@ class TestGeoQueries(MongoDBTestCase): "coordinates": [[[40, 5], [40, 6], [41, 6], [41, 5], [40, 5]]], } roads = Road.objects.filter(line__geo_within=polygon["coordinates"]).count() - self.assertEqual(1, roads) + assert 1 == roads roads = Road.objects.filter(line__geo_within=polygon).count() - self.assertEqual(1, roads) + assert 1 == roads roads = Road.objects.filter(line__geo_within={"$geometry": polygon}).count() - self.assertEqual(1, roads) + assert 1 == roads # Intersects line = {"type": "LineString", "coordinates": [[40, 5], [40, 6]]} roads = Road.objects.filter(line__geo_intersects=line["coordinates"]).count() - self.assertEqual(1, roads) + assert 1 == roads roads = Road.objects.filter(line__geo_intersects=line).count() - self.assertEqual(1, roads) + assert 1 == roads roads = Road.objects.filter(line__geo_intersects={"$geometry": line}).count() - self.assertEqual(1, roads) + assert 1 == roads polygon = { "type": "Polygon", "coordinates": [[[40, 5], [40, 6], [41, 6], [41, 5], [40, 5]]], } roads = Road.objects.filter(line__geo_intersects=polygon["coordinates"]).count() - self.assertEqual(1, roads) + assert 1 == roads roads = Road.objects.filter(line__geo_intersects=polygon).count() - self.assertEqual(1, roads) + assert 1 == roads roads = Road.objects.filter(line__geo_intersects={"$geometry": polygon}).count() - self.assertEqual(1, roads) + assert 1 == roads def test_polygon(self): class Road(Document): @@ -439,13 +439,13 @@ class TestGeoQueries(MongoDBTestCase): # near point = {"type": "Point", "coordinates": [40, 5]} roads = Road.objects.filter(poly__near=point["coordinates"]).count() - self.assertEqual(1, roads) + assert 1 == roads roads = Road.objects.filter(poly__near=point).count() - self.assertEqual(1, roads) + assert 1 == roads roads = Road.objects.filter(poly__near={"$geometry": point}).count() - self.assertEqual(1, roads) + assert 1 == roads # Within polygon = { @@ -453,37 +453,37 @@ class TestGeoQueries(MongoDBTestCase): "coordinates": [[[40, 5], [40, 6], [41, 6], [41, 5], [40, 5]]], } roads = Road.objects.filter(poly__geo_within=polygon["coordinates"]).count() - self.assertEqual(1, roads) + assert 1 == roads roads = Road.objects.filter(poly__geo_within=polygon).count() - self.assertEqual(1, roads) + assert 1 == roads roads = Road.objects.filter(poly__geo_within={"$geometry": polygon}).count() - self.assertEqual(1, roads) + assert 1 == roads # Intersects line = {"type": "LineString", "coordinates": [[40, 5], [41, 6]]} roads = Road.objects.filter(poly__geo_intersects=line["coordinates"]).count() - self.assertEqual(1, roads) + assert 1 == roads roads = Road.objects.filter(poly__geo_intersects=line).count() - self.assertEqual(1, roads) + assert 1 == roads roads = Road.objects.filter(poly__geo_intersects={"$geometry": line}).count() - self.assertEqual(1, roads) + assert 1 == roads polygon = { "type": "Polygon", "coordinates": [[[40, 5], [40, 6], [41, 6], [41, 5], [40, 5]]], } roads = Road.objects.filter(poly__geo_intersects=polygon["coordinates"]).count() - self.assertEqual(1, roads) + assert 1 == roads roads = Road.objects.filter(poly__geo_intersects=polygon).count() - self.assertEqual(1, roads) + assert 1 == roads roads = Road.objects.filter(poly__geo_intersects={"$geometry": polygon}).count() - self.assertEqual(1, roads) + assert 1 == roads def test_aspymongo_with_only(self): """Ensure as_pymongo works with only""" @@ -495,13 +495,10 @@ class TestGeoQueries(MongoDBTestCase): p = Place(location=[24.946861267089844, 60.16311983618494]) p.save() qs = Place.objects().only("location") - self.assertDictEqual( - qs.as_pymongo()[0]["location"], - { - u"type": u"Point", - u"coordinates": [24.946861267089844, 60.16311983618494], - }, - ) + assert qs.as_pymongo()[0]["location"] == { + u"type": u"Point", + u"coordinates": [24.946861267089844, 60.16311983618494], + } def test_2dsphere_point_sets_correctly(self): class Location(Document): @@ -511,11 +508,11 @@ class TestGeoQueries(MongoDBTestCase): Location(loc=[1, 2]).save() loc = Location.objects.as_pymongo()[0] - self.assertEqual(loc["loc"], {"type": "Point", "coordinates": [1, 2]}) + assert loc["loc"] == {"type": "Point", "coordinates": [1, 2]} Location.objects.update(set__loc=[2, 1]) loc = Location.objects.as_pymongo()[0] - self.assertEqual(loc["loc"], {"type": "Point", "coordinates": [2, 1]}) + assert loc["loc"] == {"type": "Point", "coordinates": [2, 1]} def test_2dsphere_linestring_sets_correctly(self): class Location(Document): @@ -525,15 +522,11 @@ class TestGeoQueries(MongoDBTestCase): Location(line=[[1, 2], [2, 2]]).save() loc = Location.objects.as_pymongo()[0] - self.assertEqual( - loc["line"], {"type": "LineString", "coordinates": [[1, 2], [2, 2]]} - ) + assert loc["line"] == {"type": "LineString", "coordinates": [[1, 2], [2, 2]]} Location.objects.update(set__line=[[2, 1], [1, 2]]) loc = Location.objects.as_pymongo()[0] - self.assertEqual( - loc["line"], {"type": "LineString", "coordinates": [[2, 1], [1, 2]]} - ) + assert loc["line"] == {"type": "LineString", "coordinates": [[2, 1], [1, 2]]} def test_geojson_PolygonField(self): class Location(Document): @@ -543,17 +536,17 @@ class TestGeoQueries(MongoDBTestCase): Location(poly=[[[40, 5], [40, 6], [41, 6], [40, 5]]]).save() loc = Location.objects.as_pymongo()[0] - self.assertEqual( - loc["poly"], - {"type": "Polygon", "coordinates": [[[40, 5], [40, 6], [41, 6], [40, 5]]]}, - ) + assert loc["poly"] == { + "type": "Polygon", + "coordinates": [[[40, 5], [40, 6], [41, 6], [40, 5]]], + } Location.objects.update(set__poly=[[[40, 4], [40, 6], [41, 6], [40, 4]]]) loc = Location.objects.as_pymongo()[0] - self.assertEqual( - loc["poly"], - {"type": "Polygon", "coordinates": [[[40, 4], [40, 6], [41, 6], [40, 4]]]}, - ) + assert loc["poly"] == { + "type": "Polygon", + "coordinates": [[[40, 4], [40, 6], [41, 6], [40, 4]]], + } if __name__ == "__main__": diff --git a/tests/queryset/test_modify.py b/tests/queryset/test_modify.py index 60f4884c..293a463e 100644 --- a/tests/queryset/test_modify.py +++ b/tests/queryset/test_modify.py @@ -14,14 +14,14 @@ class TestFindAndModify(unittest.TestCase): Doc.drop_collection() def assertDbEqual(self, docs): - self.assertEqual(list(Doc._collection.find().sort("id")), docs) + assert list(Doc._collection.find().sort("id")) == docs def test_modify(self): Doc(id=0, value=0).save() doc = Doc(id=1, value=1).save() old_doc = Doc.objects(id=1).modify(set__value=-1) - self.assertEqual(old_doc.to_json(), doc.to_json()) + assert old_doc.to_json() == doc.to_json() self.assertDbEqual([{"_id": 0, "value": 0}, {"_id": 1, "value": -1}]) def test_modify_with_new(self): @@ -30,18 +30,18 @@ class TestFindAndModify(unittest.TestCase): new_doc = Doc.objects(id=1).modify(set__value=-1, new=True) doc.value = -1 - self.assertEqual(new_doc.to_json(), doc.to_json()) + assert new_doc.to_json() == doc.to_json() self.assertDbEqual([{"_id": 0, "value": 0}, {"_id": 1, "value": -1}]) def test_modify_not_existing(self): Doc(id=0, value=0).save() - self.assertEqual(Doc.objects(id=1).modify(set__value=-1), None) + assert Doc.objects(id=1).modify(set__value=-1) == None self.assertDbEqual([{"_id": 0, "value": 0}]) def test_modify_with_upsert(self): Doc(id=0, value=0).save() old_doc = Doc.objects(id=1).modify(set__value=1, upsert=True) - self.assertEqual(old_doc, None) + assert old_doc == None self.assertDbEqual([{"_id": 0, "value": 0}, {"_id": 1, "value": 1}]) def test_modify_with_upsert_existing(self): @@ -49,13 +49,13 @@ class TestFindAndModify(unittest.TestCase): doc = Doc(id=1, value=1).save() old_doc = Doc.objects(id=1).modify(set__value=-1, upsert=True) - self.assertEqual(old_doc.to_json(), doc.to_json()) + assert old_doc.to_json() == doc.to_json() self.assertDbEqual([{"_id": 0, "value": 0}, {"_id": 1, "value": -1}]) def test_modify_with_upsert_with_new(self): Doc(id=0, value=0).save() new_doc = Doc.objects(id=1).modify(upsert=True, new=True, set__value=1) - self.assertEqual(new_doc.to_mongo(), {"_id": 1, "value": 1}) + assert new_doc.to_mongo() == {"_id": 1, "value": 1} self.assertDbEqual([{"_id": 0, "value": 0}, {"_id": 1, "value": 1}]) def test_modify_with_remove(self): @@ -63,12 +63,12 @@ class TestFindAndModify(unittest.TestCase): doc = Doc(id=1, value=1).save() old_doc = Doc.objects(id=1).modify(remove=True) - self.assertEqual(old_doc.to_json(), doc.to_json()) + assert old_doc.to_json() == doc.to_json() self.assertDbEqual([{"_id": 0, "value": 0}]) def test_find_and_modify_with_remove_not_existing(self): Doc(id=0, value=0).save() - self.assertEqual(Doc.objects(id=1).modify(remove=True), None) + assert Doc.objects(id=1).modify(remove=True) == None self.assertDbEqual([{"_id": 0, "value": 0}]) def test_modify_with_order_by(self): @@ -78,7 +78,7 @@ class TestFindAndModify(unittest.TestCase): doc = Doc(id=3, value=0).save() old_doc = Doc.objects().order_by("-id").modify(set__value=-1) - self.assertEqual(old_doc.to_json(), doc.to_json()) + assert old_doc.to_json() == doc.to_json() self.assertDbEqual( [ {"_id": 0, "value": 3}, @@ -93,7 +93,7 @@ class TestFindAndModify(unittest.TestCase): Doc(id=1, value=1).save() old_doc = Doc.objects(id=1).only("id").modify(set__value=-1) - self.assertEqual(old_doc.to_mongo(), {"_id": 1}) + assert old_doc.to_mongo() == {"_id": 1} self.assertDbEqual([{"_id": 0, "value": 0}, {"_id": 1, "value": -1}]) def test_modify_with_push(self): @@ -106,23 +106,23 @@ class TestFindAndModify(unittest.TestCase): # Push a new tag via modify with new=False (default). BlogPost(id=blog.id).modify(push__tags="code") - self.assertEqual(blog.tags, []) + assert blog.tags == [] blog.reload() - self.assertEqual(blog.tags, ["code"]) + assert blog.tags == ["code"] # Push a new tag via modify with new=True. blog = BlogPost.objects(id=blog.id).modify(push__tags="java", new=True) - self.assertEqual(blog.tags, ["code", "java"]) + assert blog.tags == ["code", "java"] # Push a new tag with a positional argument. blog = BlogPost.objects(id=blog.id).modify(push__tags__0="python", new=True) - self.assertEqual(blog.tags, ["python", "code", "java"]) + assert blog.tags == ["python", "code", "java"] # Push multiple new tags with a positional argument. blog = BlogPost.objects(id=blog.id).modify( push__tags__1=["go", "rust"], new=True ) - self.assertEqual(blog.tags, ["python", "go", "rust", "code", "java"]) + assert blog.tags == ["python", "go", "rust", "code", "java"] if __name__ == "__main__": diff --git a/tests/queryset/test_pickable.py b/tests/queryset/test_pickable.py index 8c4e3426..d41f56df 100644 --- a/tests/queryset/test_pickable.py +++ b/tests/queryset/test_pickable.py @@ -37,13 +37,13 @@ class TestQuerysetPickable(MongoDBTestCase): loadedQs = self._get_loaded(qs) - self.assertEqual(qs.count(), loadedQs.count()) + assert qs.count() == loadedQs.count() # can update loadedQs loadedQs.update(age=23) # check - self.assertEqual(Person.objects.first().age, 23) + assert Person.objects.first().age == 23 def test_pickle_support_filtration(self): Person.objects.create(name="Alice", age=22) @@ -51,9 +51,9 @@ class TestQuerysetPickable(MongoDBTestCase): Person.objects.create(name="Bob", age=23) qs = Person.objects.filter(age__gte=22) - self.assertEqual(qs.count(), 2) + assert qs.count() == 2 loaded = self._get_loaded(qs) - self.assertEqual(loaded.count(), 2) - self.assertEqual(loaded.filter(name="Bob").first().age, 23) + assert loaded.count() == 2 + assert loaded.filter(name="Bob").first().age == 23 diff --git a/tests/queryset/test_queryset.py b/tests/queryset/test_queryset.py index 16213254..d154de8d 100644 --- a/tests/queryset/test_queryset.py +++ b/tests/queryset/test_queryset.py @@ -24,6 +24,7 @@ from mongoengine.queryset import ( QuerySetManager, queryset_manager, ) +import pytest class db_ops_tracker(query_counter): @@ -64,11 +65,11 @@ class TestQueryset(unittest.TestCase): def test_initialisation(self): """Ensure that a QuerySet is correctly initialised by QuerySetManager. """ - self.assertIsInstance(self.Person.objects, QuerySet) - self.assertEqual( - self.Person.objects._collection.name, self.Person._get_collection_name() + assert isinstance(self.Person.objects, QuerySet) + assert ( + self.Person.objects._collection.name == self.Person._get_collection_name() ) - self.assertIsInstance( + assert isinstance( self.Person.objects._collection, pymongo.collection.Collection ) @@ -78,11 +79,11 @@ class TestQueryset(unittest.TestCase): author2 = GenericReferenceField() # test addressing a field from a reference - with self.assertRaises(InvalidQueryError): + with pytest.raises(InvalidQueryError): list(BlogPost.objects(author__name="test")) # should fail for a generic reference as well - with self.assertRaises(InvalidQueryError): + with pytest.raises(InvalidQueryError): list(BlogPost.objects(author2__name="test")) def test_find(self): @@ -92,27 +93,27 @@ class TestQueryset(unittest.TestCase): # Find all people in the collection people = self.Person.objects - self.assertEqual(people.count(), 2) + assert people.count() == 2 results = list(people) - self.assertIsInstance(results[0], self.Person) - self.assertIsInstance(results[0].id, ObjectId) + assert isinstance(results[0], self.Person) + assert isinstance(results[0].id, ObjectId) - self.assertEqual(results[0], user_a) - self.assertEqual(results[0].name, "User A") - self.assertEqual(results[0].age, 20) + assert results[0] == user_a + assert results[0].name == "User A" + assert results[0].age == 20 - self.assertEqual(results[1], user_b) - self.assertEqual(results[1].name, "User B") - self.assertEqual(results[1].age, 30) + assert results[1] == user_b + assert results[1].name == "User B" + assert results[1].age == 30 # Filter people by age people = self.Person.objects(age=20) - self.assertEqual(people.count(), 1) + assert people.count() == 1 person = people.next() - self.assertEqual(person, user_a) - self.assertEqual(person.name, "User A") - self.assertEqual(person.age, 20) + assert person == user_a + assert person.name == "User A" + assert person.age == 20 def test_limit(self): """Ensure that QuerySet.limit works as expected.""" @@ -121,27 +122,27 @@ class TestQueryset(unittest.TestCase): # Test limit on a new queryset people = list(self.Person.objects.limit(1)) - self.assertEqual(len(people), 1) - self.assertEqual(people[0], user_a) + assert len(people) == 1 + assert people[0] == user_a # Test limit on an existing queryset people = self.Person.objects - self.assertEqual(len(people), 2) + assert len(people) == 2 people2 = people.limit(1) - self.assertEqual(len(people), 2) - self.assertEqual(len(people2), 1) - self.assertEqual(people2[0], user_a) + assert len(people) == 2 + assert len(people2) == 1 + assert people2[0] == user_a # Test limit with 0 as parameter people = self.Person.objects.limit(0) - self.assertEqual(people.count(with_limit_and_skip=True), 2) - self.assertEqual(len(people), 2) + assert people.count(with_limit_and_skip=True) == 2 + assert len(people) == 2 # Test chaining of only after limit person = self.Person.objects().limit(1).only("name").first() - self.assertEqual(person, user_a) - self.assertEqual(person.name, "User A") - self.assertEqual(person.age, None) + assert person == user_a + assert person.name == "User A" + assert person.age == None def test_skip(self): """Ensure that QuerySet.skip works as expected.""" @@ -150,26 +151,26 @@ class TestQueryset(unittest.TestCase): # Test skip on a new queryset people = list(self.Person.objects.skip(1)) - self.assertEqual(len(people), 1) - self.assertEqual(people[0], user_b) + assert len(people) == 1 + assert people[0] == user_b # Test skip on an existing queryset people = self.Person.objects - self.assertEqual(len(people), 2) + assert len(people) == 2 people2 = people.skip(1) - self.assertEqual(len(people), 2) - self.assertEqual(len(people2), 1) - self.assertEqual(people2[0], user_b) + assert len(people) == 2 + assert len(people2) == 1 + assert people2[0] == user_b # Test chaining of only after skip person = self.Person.objects().skip(1).only("name").first() - self.assertEqual(person, user_b) - self.assertEqual(person.name, "User B") - self.assertEqual(person.age, None) + assert person == user_b + assert person.name == "User B" + assert person.age == None def test___getitem___invalid_index(self): """Ensure slicing a queryset works as expected.""" - with self.assertRaises(TypeError): + with pytest.raises(TypeError): self.Person.objects()["a"] def test_slice(self): @@ -180,27 +181,27 @@ class TestQueryset(unittest.TestCase): # Test slice limit people = list(self.Person.objects[:2]) - self.assertEqual(len(people), 2) - self.assertEqual(people[0], user_a) - self.assertEqual(people[1], user_b) + assert len(people) == 2 + assert people[0] == user_a + assert people[1] == user_b # Test slice skip people = list(self.Person.objects[1:]) - self.assertEqual(len(people), 2) - self.assertEqual(people[0], user_b) - self.assertEqual(people[1], user_c) + assert len(people) == 2 + assert people[0] == user_b + assert people[1] == user_c # Test slice limit and skip people = list(self.Person.objects[1:2]) - self.assertEqual(len(people), 1) - self.assertEqual(people[0], user_b) + assert len(people) == 1 + assert people[0] == user_b # Test slice limit and skip on an existing queryset people = self.Person.objects - self.assertEqual(len(people), 3) + assert len(people) == 3 people2 = people[1:2] - self.assertEqual(len(people2), 1) - self.assertEqual(people2[0], user_b) + assert len(people2) == 1 + assert people2[0] == user_b # Test slice limit and skip cursor reset qs = self.Person.objects[1:2] @@ -208,31 +209,31 @@ class TestQueryset(unittest.TestCase): qs._cursor qs._cursor_obj = None people = list(qs) - self.assertEqual(len(people), 1) - self.assertEqual(people[0].name, "User B") + assert len(people) == 1 + assert people[0].name == "User B" # Test empty slice people = list(self.Person.objects[1:1]) - self.assertEqual(len(people), 0) + assert len(people) == 0 # Test slice out of range people = list(self.Person.objects[80000:80001]) - self.assertEqual(len(people), 0) + assert len(people) == 0 # Test larger slice __repr__ self.Person.objects.delete() for i in range(55): self.Person(name="A%s" % i, age=i).save() - self.assertEqual(self.Person.objects.count(), 55) - self.assertEqual("Person object", "%s" % self.Person.objects[0]) - self.assertEqual( - "[, ]", - "%s" % self.Person.objects[1:3], + assert self.Person.objects.count() == 55 + assert "Person object" == "%s" % self.Person.objects[0] + assert ( + "[, ]" + == "%s" % self.Person.objects[1:3] ) - self.assertEqual( - "[, ]", - "%s" % self.Person.objects[51:53], + assert ( + "[, ]" + == "%s" % self.Person.objects[51:53] ) def test_find_one(self): @@ -245,40 +246,42 @@ class TestQueryset(unittest.TestCase): # Retrieve the first person from the database person = self.Person.objects.first() - self.assertIsInstance(person, self.Person) - self.assertEqual(person.name, "User A") - self.assertEqual(person.age, 20) + assert isinstance(person, self.Person) + assert person.name == "User A" + assert person.age == 20 # Use a query to filter the people found to just person2 person = self.Person.objects(age=30).first() - self.assertEqual(person.name, "User B") + assert person.name == "User B" person = self.Person.objects(age__lt=30).first() - self.assertEqual(person.name, "User A") + assert person.name == "User A" # Use array syntax person = self.Person.objects[0] - self.assertEqual(person.name, "User A") + assert person.name == "User A" person = self.Person.objects[1] - self.assertEqual(person.name, "User B") + assert person.name == "User B" - with self.assertRaises(IndexError): + with pytest.raises(IndexError): self.Person.objects[2] # Find a document using just the object id person = self.Person.objects.with_id(person1.id) - self.assertEqual(person.name, "User A") + assert person.name == "User A" - with self.assertRaises(InvalidQueryError): + with pytest.raises(InvalidQueryError): self.Person.objects(name="User A").with_id(person1.id) def test_find_only_one(self): """Ensure that a query using ``get`` returns at most one result. """ # Try retrieving when no objects exists - self.assertRaises(DoesNotExist, self.Person.objects.get) - self.assertRaises(self.Person.DoesNotExist, self.Person.objects.get) + with pytest.raises(DoesNotExist): + self.Person.objects.get() + with pytest.raises(self.Person.DoesNotExist): + self.Person.objects.get() person1 = self.Person(name="User A", age=20) person1.save() @@ -286,15 +289,17 @@ class TestQueryset(unittest.TestCase): person2.save() # Retrieve the first person from the database - self.assertRaises(MultipleObjectsReturned, self.Person.objects.get) - self.assertRaises(self.Person.MultipleObjectsReturned, self.Person.objects.get) + with pytest.raises(MultipleObjectsReturned): + self.Person.objects.get() + with pytest.raises(self.Person.MultipleObjectsReturned): + self.Person.objects.get() # Use a query to filter the people found to just person2 person = self.Person.objects.get(age=30) - self.assertEqual(person.name, "User B") + assert person.name == "User B" person = self.Person.objects.get(age__lt=30) - self.assertEqual(person.name, "User A") + assert person.name == "User A" def test_find_array_position(self): """Ensure that query by array position works. @@ -313,10 +318,10 @@ class TestQueryset(unittest.TestCase): Blog.drop_collection() Blog.objects.create(tags=["a", "b"]) - self.assertEqual(Blog.objects(tags__0="a").count(), 1) - self.assertEqual(Blog.objects(tags__0="b").count(), 0) - self.assertEqual(Blog.objects(tags__1="a").count(), 0) - self.assertEqual(Blog.objects(tags__1="b").count(), 1) + assert Blog.objects(tags__0="a").count() == 1 + assert Blog.objects(tags__0="b").count() == 0 + assert Blog.objects(tags__1="a").count() == 0 + assert Blog.objects(tags__1="b").count() == 1 Blog.drop_collection() @@ -328,19 +333,19 @@ class TestQueryset(unittest.TestCase): blog2 = Blog.objects.create(posts=[post2, post1]) blog = Blog.objects(posts__0__comments__0__name="testa").get() - self.assertEqual(blog, blog1) + assert blog == blog1 blog = Blog.objects(posts__0__comments__0__name="testb").get() - self.assertEqual(blog, blog2) + assert blog == blog2 query = Blog.objects(posts__1__comments__1__name="testb") - self.assertEqual(query.count(), 2) + assert query.count() == 2 query = Blog.objects(posts__1__comments__1__name="testa") - self.assertEqual(query.count(), 0) + assert query.count() == 0 query = Blog.objects(posts__0__comments__1__name="testa") - self.assertEqual(query.count(), 0) + assert query.count() == 0 Blog.drop_collection() @@ -351,8 +356,8 @@ class TestQueryset(unittest.TestCase): A.drop_collection() A().save() - self.assertEqual(list(A.objects.none()), []) - self.assertEqual(list(A.objects.none().all()), []) + assert list(A.objects.none()) == [] + assert list(A.objects.none().all()) == [] def test_chaining(self): class A(Document): @@ -376,12 +381,12 @@ class TestQueryset(unittest.TestCase): # Doesn't work q2 = B.objects.filter(ref__in=[a1, a2]) q2 = q2.filter(ref=a1)._query - self.assertEqual(q1, q2) + assert q1 == q2 a_objects = A.objects(s="test1") query = B.objects(ref__in=a_objects) query = query.filter(boolfield=True) - self.assertEqual(query.count(), 1) + assert query.count() == 1 def test_batch_size(self): """Ensure that batch_size works.""" @@ -398,7 +403,7 @@ class TestQueryset(unittest.TestCase): cnt = 0 for a in A.objects.batch_size(10): cnt += 1 - self.assertEqual(cnt, 100) + assert cnt == 100 # test chaining qs = A.objects.all() @@ -406,11 +411,11 @@ class TestQueryset(unittest.TestCase): cnt = 0 for a in qs: cnt += 1 - self.assertEqual(cnt, 9) + assert cnt == 9 # test invalid batch size qs = A.objects.batch_size(-1) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): list(qs) def test_batch_size_cloned(self): @@ -419,9 +424,9 @@ class TestQueryset(unittest.TestCase): # test that batch size gets cloned qs = A.objects.batch_size(5) - self.assertEqual(qs._batch_size, 5) + assert qs._batch_size == 5 qs_clone = qs.clone() - self.assertEqual(qs_clone._batch_size, 5) + assert qs_clone._batch_size == 5 def test_update_write_concern(self): """Test that passing write_concern works""" @@ -437,18 +442,18 @@ class TestQueryset(unittest.TestCase): result = self.Person.objects.update(set__name="Ross", write_concern={"w": 1}) - self.assertEqual(result, 2) + assert result == 2 result = self.Person.objects.update(set__name="Ross", write_concern={"w": 0}) - self.assertEqual(result, None) + assert result == None result = self.Person.objects.update_one( set__name="Test User", write_concern={"w": 1} ) - self.assertEqual(result, 1) + assert result == 1 result = self.Person.objects.update_one( set__name="Test User", write_concern={"w": 0} ) - self.assertEqual(result, None) + assert result == None def test_update_update_has_a_value(self): """Test to ensure that update is passed a value to update to""" @@ -456,10 +461,10 @@ class TestQueryset(unittest.TestCase): author = self.Person.objects.create(name="Test User") - with self.assertRaises(OperationError): + with pytest.raises(OperationError): self.Person.objects(pk=author.pk).update({}) - with self.assertRaises(OperationError): + with pytest.raises(OperationError): self.Person.objects(pk=author.pk).update_one({}) def test_update_array_position(self): @@ -492,7 +497,7 @@ class TestQueryset(unittest.TestCase): # Update all of the first comments of second posts of all blogs Blog.objects().update(set__posts__1__comments__0__name="testc") testc_blogs = Blog.objects(posts__1__comments__0__name="testc") - self.assertEqual(testc_blogs.count(), 2) + assert testc_blogs.count() == 2 Blog.drop_collection() Blog.objects.create(posts=[post1, post2]) @@ -501,10 +506,10 @@ class TestQueryset(unittest.TestCase): # Update only the first blog returned by the query Blog.objects().update_one(set__posts__1__comments__1__name="testc") testc_blogs = Blog.objects(posts__1__comments__1__name="testc") - self.assertEqual(testc_blogs.count(), 1) + assert testc_blogs.count() == 1 # Check that using this indexing syntax on a non-list fails - with self.assertRaises(InvalidQueryError): + with pytest.raises(InvalidQueryError): Blog.objects().update(set__posts__1__comments__0__name__1="asdf") Blog.drop_collection() @@ -531,8 +536,8 @@ class TestQueryset(unittest.TestCase): BlogPost.objects(comments__by="jane").update(inc__comments__S__votes=1) post = BlogPost.objects.first() - self.assertEqual(post.comments[1].by, "jane") - self.assertEqual(post.comments[1].votes, 8) + assert post.comments[1].by == "jane" + assert post.comments[1].votes == 8 def test_update_using_positional_operator_matches_first(self): @@ -547,7 +552,7 @@ class TestQueryset(unittest.TestCase): Simple.objects(x=2).update(inc__x__S=1) simple = Simple.objects.first() - self.assertEqual(simple.x, [1, 3, 3, 2]) + assert simple.x == [1, 3, 3, 2] Simple.drop_collection() # You can set multiples @@ -559,10 +564,10 @@ class TestQueryset(unittest.TestCase): Simple.objects(x=3).update(set__x__S=0) s = Simple.objects() - self.assertEqual(s[0].x, [1, 2, 0, 4]) - self.assertEqual(s[1].x, [2, 0, 4, 5]) - self.assertEqual(s[2].x, [0, 4, 5, 6]) - self.assertEqual(s[3].x, [4, 5, 6, 7]) + assert s[0].x == [1, 2, 0, 4] + assert s[1].x == [2, 0, 4, 5] + assert s[2].x == [0, 4, 5, 6] + assert s[3].x == [4, 5, 6, 7] # Using "$unset" with an expression like this "array.$" will result in # the array item becoming None, not being removed. @@ -570,14 +575,14 @@ class TestQueryset(unittest.TestCase): Simple(x=[1, 2, 3, 4, 3, 2, 3, 4]).save() Simple.objects(x=3).update(unset__x__S=1) simple = Simple.objects.first() - self.assertEqual(simple.x, [1, 2, None, 4, 3, 2, 3, 4]) + assert simple.x == [1, 2, None, 4, 3, 2, 3, 4] # Nested updates arent supported yet.. - with self.assertRaises(OperationError): + with pytest.raises(OperationError): Simple.drop_collection() Simple(x=[{"test": [1, 2, 3, 4]}]).save() Simple.objects(x__test=2).update(set__x__S__test__S=3) - self.assertEqual(simple.x, [1, 2, 3, 4]) + assert simple.x == [1, 2, 3, 4] def test_update_using_positional_operator_embedded_document(self): """Ensure that the embedded documents can be updated using the positional @@ -606,8 +611,8 @@ class TestQueryset(unittest.TestCase): ) post = BlogPost.objects.first() - self.assertEqual(post.comments[0].by, "joe") - self.assertEqual(post.comments[0].votes.score, 4) + assert post.comments[0].by == "joe" + assert post.comments[0].votes.score == 4 def test_update_min_max(self): class Scores(Document): @@ -617,14 +622,14 @@ class TestQueryset(unittest.TestCase): scores = Scores.objects.create(high_score=800, low_score=200) Scores.objects(id=scores.id).update(min__low_score=150) - self.assertEqual(Scores.objects.get(id=scores.id).low_score, 150) + assert Scores.objects.get(id=scores.id).low_score == 150 Scores.objects(id=scores.id).update(min__low_score=250) - self.assertEqual(Scores.objects.get(id=scores.id).low_score, 150) + assert Scores.objects.get(id=scores.id).low_score == 150 Scores.objects(id=scores.id).update(max__high_score=1000) - self.assertEqual(Scores.objects.get(id=scores.id).high_score, 1000) + assert Scores.objects.get(id=scores.id).high_score == 1000 Scores.objects(id=scores.id).update(max__high_score=500) - self.assertEqual(Scores.objects.get(id=scores.id).high_score, 1000) + assert Scores.objects.get(id=scores.id).high_score == 1000 def test_update_multiple(self): class Product(Document): @@ -634,10 +639,10 @@ class TestQueryset(unittest.TestCase): product = Product.objects.create(item="ABC", price=10.99) product = Product.objects.create(item="ABC", price=10.99) Product.objects(id=product.id).update(mul__price=1.25) - self.assertEqual(Product.objects.get(id=product.id).price, 13.7375) + assert Product.objects.get(id=product.id).price == 13.7375 unknown_product = Product.objects.create(item="Unknown") Product.objects(id=unknown_product.id).update(mul__price=100) - self.assertEqual(Product.objects.get(id=unknown_product.id).price, 0) + assert Product.objects.get(id=unknown_product.id).price == 0 def test_updates_can_have_match_operators(self): class Comment(EmbeddedDocument): @@ -663,7 +668,7 @@ class TestQueryset(unittest.TestCase): Post.objects().update_one(pull__comments__vote__lt=1) - self.assertEqual(1, len(Post.objects.first().comments)) + assert 1 == len(Post.objects.first().comments) def test_mapfield_update(self): """Ensure that the MapField can be updated.""" @@ -684,8 +689,8 @@ class TestQueryset(unittest.TestCase): Club.objects().update(set__members={"John": Member(gender="F", age=14)}) club = Club.objects().first() - self.assertEqual(club.members["John"].gender, "F") - self.assertEqual(club.members["John"].age, 14) + assert club.members["John"].gender == "F" + assert club.members["John"].age == 14 def test_dictfield_update(self): """Ensure that the DictField can be updated.""" @@ -700,25 +705,25 @@ class TestQueryset(unittest.TestCase): Club.objects().update(set__members={"John": {"gender": "F", "age": 14}}) club = Club.objects().first() - self.assertEqual(club.members["John"]["gender"], "F") - self.assertEqual(club.members["John"]["age"], 14) + assert club.members["John"]["gender"] == "F" + assert club.members["John"]["age"] == 14 def test_update_results(self): self.Person.drop_collection() result = self.Person(name="Bob", age=25).update(upsert=True, full_result=True) - self.assertIsInstance(result, UpdateResult) - self.assertIn("upserted", result.raw_result) - self.assertFalse(result.raw_result["updatedExisting"]) + assert isinstance(result, UpdateResult) + assert "upserted" in result.raw_result + assert not result.raw_result["updatedExisting"] bob = self.Person.objects.first() result = bob.update(set__age=30, full_result=True) - self.assertIsInstance(result, UpdateResult) - self.assertTrue(result.raw_result["updatedExisting"]) + assert isinstance(result, UpdateResult) + assert result.raw_result["updatedExisting"] self.Person(name="Bob", age=20).save() result = self.Person.objects(name="Bob").update(set__name="bobby", multi=True) - self.assertEqual(result, 2) + assert result == 2 def test_update_validate(self): class EmDoc(EmbeddedDocument): @@ -730,13 +735,12 @@ class TestQueryset(unittest.TestCase): cdt_f = ComplexDateTimeField() ed_f = EmbeddedDocumentField(EmDoc) - self.assertRaises(ValidationError, Doc.objects().update, str_f=1, upsert=True) - self.assertRaises( - ValidationError, Doc.objects().update, dt_f="datetime", upsert=True - ) - self.assertRaises( - ValidationError, Doc.objects().update, ed_f__str_f=1, upsert=True - ) + with pytest.raises(ValidationError): + Doc.objects().update(str_f=1, upsert=True) + with pytest.raises(ValidationError): + Doc.objects().update(dt_f="datetime", upsert=True) + with pytest.raises(ValidationError): + Doc.objects().update(ed_f__str_f=1, upsert=True) def test_update_related_models(self): class TestPerson(Document): @@ -757,20 +761,20 @@ class TestQueryset(unittest.TestCase): o.owner = p p.name = "p2" - self.assertEqual(o._get_changed_fields(), ["owner"]) - self.assertEqual(p._get_changed_fields(), ["name"]) + assert o._get_changed_fields() == ["owner"] + assert p._get_changed_fields() == ["name"] o.save() - self.assertEqual(o._get_changed_fields(), []) - self.assertEqual(p._get_changed_fields(), ["name"]) # Fails; it's empty + assert o._get_changed_fields() == [] + assert p._get_changed_fields() == ["name"] # Fails; it's empty # This will do NOTHING at all, even though we changed the name p.save() p.reload() - self.assertEqual(p.name, "p2") # Fails; it's still `p1` + assert p.name == "p2" # Fails; it's still `p1` def test_upsert(self): self.Person.drop_collection() @@ -778,25 +782,25 @@ class TestQueryset(unittest.TestCase): self.Person.objects(pk=ObjectId(), name="Bob", age=30).update(upsert=True) bob = self.Person.objects.first() - self.assertEqual("Bob", bob.name) - self.assertEqual(30, bob.age) + assert "Bob" == bob.name + assert 30 == bob.age def test_upsert_one(self): self.Person.drop_collection() bob = self.Person.objects(name="Bob", age=30).upsert_one() - self.assertEqual("Bob", bob.name) - self.assertEqual(30, bob.age) + assert "Bob" == bob.name + assert 30 == bob.age bob.name = "Bobby" bob.save() bobby = self.Person.objects(name="Bobby", age=30).upsert_one() - self.assertEqual("Bobby", bobby.name) - self.assertEqual(30, bobby.age) - self.assertEqual(bob.id, bobby.id) + assert "Bobby" == bobby.name + assert 30 == bobby.age + assert bob.id == bobby.id def test_set_on_insert(self): self.Person.drop_collection() @@ -806,8 +810,8 @@ class TestQueryset(unittest.TestCase): ) bob = self.Person.objects.first() - self.assertEqual("Bob", bob.name) - self.assertEqual(30, bob.age) + assert "Bob" == bob.name + assert 30 == bob.age def test_save_and_only_on_fields_with_default(self): class Embed(EmbeddedDocument): @@ -832,9 +836,9 @@ class TestQueryset(unittest.TestCase): # Checking it was saved correctly record.reload() - self.assertEqual(record.field, 2) - self.assertEqual(record.embed_no_default.field, 2) - self.assertEqual(record.embed.field, 2) + assert record.field == 2 + assert record.embed_no_default.field == 2 + assert record.embed.field == 2 # Request only the _id field and save clone = B.objects().only("id").first() @@ -842,9 +846,9 @@ class TestQueryset(unittest.TestCase): # Reload the record and see that the embed data is not lost record.reload() - self.assertEqual(record.field, 2) - self.assertEqual(record.embed_no_default.field, 2) - self.assertEqual(record.embed.field, 2) + assert record.field == 2 + assert record.embed_no_default.field == 2 + assert record.embed.field == 2 def test_bulk_insert(self): """Ensure that bulk insert works""" @@ -863,7 +867,7 @@ class TestQueryset(unittest.TestCase): Blog.drop_collection() # Recreates the collection - self.assertEqual(0, Blog.objects.count()) + assert 0 == Blog.objects.count() comment1 = Comment(name="testa") comment2 = Comment(name="testb") @@ -873,11 +877,11 @@ class TestQueryset(unittest.TestCase): # Check bulk insert using load_bulk=False blogs = [Blog(title="%s" % i, posts=[post1, post2]) for i in range(99)] with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 Blog.objects.insert(blogs, load_bulk=False) - self.assertEqual(q, 1) # 1 entry containing the list of inserts + assert q == 1 # 1 entry containing the list of inserts - self.assertEqual(Blog.objects.count(), len(blogs)) + assert Blog.objects.count() == len(blogs) Blog.drop_collection() Blog.ensure_indexes() @@ -885,9 +889,9 @@ class TestQueryset(unittest.TestCase): # Check bulk insert using load_bulk=True blogs = [Blog(title="%s" % i, posts=[post1, post2]) for i in range(99)] with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 Blog.objects.insert(blogs) - self.assertEqual(q, 2) # 1 for insert 1 for fetch + assert q == 2 # 1 for insert 1 for fetch Blog.drop_collection() @@ -898,25 +902,27 @@ class TestQueryset(unittest.TestCase): blog1 = Blog(title="code", posts=[post1, post2]) blog2 = Blog(title="mongodb", posts=[post2, post1]) blog1, blog2 = Blog.objects.insert([blog1, blog2]) - self.assertEqual(blog1.title, "code") - self.assertEqual(blog2.title, "mongodb") + assert blog1.title == "code" + assert blog2.title == "mongodb" - self.assertEqual(Blog.objects.count(), 2) + assert Blog.objects.count() == 2 # test inserting an existing document (shouldn't be allowed) - with self.assertRaises(OperationError) as cm: + with pytest.raises(OperationError) as cm: blog = Blog.objects.first() Blog.objects.insert(blog) - self.assertEqual( - str(cm.exception), "Some documents have ObjectIds, use doc.update() instead" + assert ( + str(cm.exception) + == "Some documents have ObjectIds, use doc.update() instead" ) # test inserting a query set - with self.assertRaises(OperationError) as cm: + with pytest.raises(OperationError) as cm: blogs_qs = Blog.objects Blog.objects.insert(blogs_qs) - self.assertEqual( - str(cm.exception), "Some documents have ObjectIds, use doc.update() instead" + assert ( + str(cm.exception) + == "Some documents have ObjectIds, use doc.update() instead" ) # insert 1 new doc @@ -927,13 +933,13 @@ class TestQueryset(unittest.TestCase): blog1 = Blog(title="code", posts=[post1, post2]) blog1 = Blog.objects.insert(blog1) - self.assertEqual(blog1.title, "code") - self.assertEqual(Blog.objects.count(), 1) + assert blog1.title == "code" + assert Blog.objects.count() == 1 Blog.drop_collection() blog1 = Blog(title="code", posts=[post1, post2]) obj_id = Blog.objects.insert(blog1, load_bulk=False) - self.assertIsInstance(obj_id, ObjectId) + assert isinstance(obj_id, ObjectId) Blog.drop_collection() post3 = Post(comments=[comment1, comment1]) @@ -941,10 +947,10 @@ class TestQueryset(unittest.TestCase): blog2 = Blog(title="bar", posts=[post2, post3]) Blog.objects.insert([blog1, blog2]) - with self.assertRaises(NotUniqueError): + with pytest.raises(NotUniqueError): Blog.objects.insert(Blog(title=blog2.title)) - self.assertEqual(Blog.objects.count(), 2) + assert Blog.objects.count() == 2 def test_bulk_insert_different_class_fails(self): class Blog(Document): @@ -954,7 +960,7 @@ class TestQueryset(unittest.TestCase): pass # try inserting a different document class - with self.assertRaises(OperationError): + with pytest.raises(OperationError): Blog.objects.insert(Author()) def test_bulk_insert_with_wrong_type(self): @@ -964,10 +970,10 @@ class TestQueryset(unittest.TestCase): Blog.drop_collection() Blog(name="test").save() - with self.assertRaises(OperationError): + with pytest.raises(OperationError): Blog.objects.insert("HELLO WORLD") - with self.assertRaises(OperationError): + with pytest.raises(OperationError): Blog.objects.insert({"name": "garbage"}) def test_bulk_insert_update_input_document_ids(self): @@ -979,23 +985,23 @@ class TestQueryset(unittest.TestCase): # Test with bulk comments = [Comment(idx=idx) for idx in range(20)] for com in comments: - self.assertIsNone(com.id) + assert com.id is None returned_comments = Comment.objects.insert(comments, load_bulk=True) for com in comments: - self.assertIsInstance(com.id, ObjectId) + assert isinstance(com.id, ObjectId) input_mapping = {com.id: com.idx for com in comments} saved_mapping = {com.id: com.idx for com in returned_comments} - self.assertEqual(input_mapping, saved_mapping) + assert input_mapping == saved_mapping Comment.drop_collection() # Test with just one comment = Comment(idx=0) inserted_comment_id = Comment.objects.insert(comment, load_bulk=False) - self.assertEqual(comment.id, inserted_comment_id) + assert comment.id == inserted_comment_id def test_bulk_insert_accepts_doc_with_ids(self): class Comment(Document): @@ -1017,7 +1023,7 @@ class TestQueryset(unittest.TestCase): Comment.objects.insert(com1) - with self.assertRaises(NotUniqueError): + with pytest.raises(NotUniqueError): Comment.objects.insert(com1) def test_get_changed_fields_query_count(self): @@ -1050,28 +1056,28 @@ class TestQueryset(unittest.TestCase): o1 = Organization(name="o1", employees=[p1]).save() with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 # Fetching a document should result in a query. org = Organization.objects.get(id=o1.id) - self.assertEqual(q, 1) + assert q == 1 # Checking changed fields of a newly fetched document should not # result in a query. org._get_changed_fields() - self.assertEqual(q, 1) + assert q == 1 # Saving a doc without changing any of its fields should not result # in a query (with or without cascade=False). org = Organization.objects.get(id=o1.id) with query_counter() as q: org.save() - self.assertEqual(q, 0) + assert q == 0 org = Organization.objects.get(id=o1.id) with query_counter() as q: org.save(cascade=False) - self.assertEqual(q, 0) + assert q == 0 # Saving a doc after you append a reference to it should result in # two db operations (a query for the reference and an update). @@ -1080,7 +1086,7 @@ class TestQueryset(unittest.TestCase): with query_counter() as q: org.employees.append(p2) # dereferences p2 org.save() # saves the org - self.assertEqual(q, 2) + assert q == 2 def test_repeated_iteration(self): """Ensure that QuerySet rewinds itself one iteration finishes. @@ -1097,8 +1103,8 @@ class TestQueryset(unittest.TestCase): break people3 = [person for person in queryset] - self.assertEqual(people1, people2) - self.assertEqual(people1, people3) + assert people1 == people2 + assert people1 == people3 def test_repr(self): """Test repr behavior isnt destructive""" @@ -1116,21 +1122,21 @@ class TestQueryset(unittest.TestCase): docs = Doc.objects.order_by("number") - self.assertEqual(docs.count(), 1000) + assert docs.count() == 1000 docs_string = "%s" % docs - self.assertIn("Doc: 0", docs_string) + assert "Doc: 0" in docs_string - self.assertEqual(docs.count(), 1000) - self.assertIn("(remaining elements truncated)", "%s" % docs) + assert docs.count() == 1000 + assert "(remaining elements truncated)" in "%s" % docs # Limit and skip docs = docs[1:4] - self.assertEqual("[, , ]", "%s" % docs) + assert "[, , ]" == "%s" % docs - self.assertEqual(docs.count(with_limit_and_skip=True), 3) + assert docs.count(with_limit_and_skip=True) == 3 for doc in docs: - self.assertEqual(".. queryset mid-iteration ..", repr(docs)) + assert ".. queryset mid-iteration .." == repr(docs) def test_regex_query_shortcuts(self): """Ensure that contains, startswith, endswith, etc work. @@ -1140,54 +1146,54 @@ class TestQueryset(unittest.TestCase): # Test contains obj = self.Person.objects(name__contains="van").first() - self.assertEqual(obj, person) + assert obj == person obj = self.Person.objects(name__contains="Van").first() - self.assertEqual(obj, None) + assert obj == None # Test icontains obj = self.Person.objects(name__icontains="Van").first() - self.assertEqual(obj, person) + assert obj == person # Test startswith obj = self.Person.objects(name__startswith="Guido").first() - self.assertEqual(obj, person) + assert obj == person obj = self.Person.objects(name__startswith="guido").first() - self.assertEqual(obj, None) + assert obj == None # Test istartswith obj = self.Person.objects(name__istartswith="guido").first() - self.assertEqual(obj, person) + assert obj == person # Test endswith obj = self.Person.objects(name__endswith="Rossum").first() - self.assertEqual(obj, person) + assert obj == person obj = self.Person.objects(name__endswith="rossuM").first() - self.assertEqual(obj, None) + assert obj == None # Test iendswith obj = self.Person.objects(name__iendswith="rossuM").first() - self.assertEqual(obj, person) + assert obj == person # Test exact obj = self.Person.objects(name__exact="Guido van Rossum").first() - self.assertEqual(obj, person) + assert obj == person obj = self.Person.objects(name__exact="Guido van rossum").first() - self.assertEqual(obj, None) + assert obj == None obj = self.Person.objects(name__exact="Guido van Rossu").first() - self.assertEqual(obj, None) + assert obj == None # Test iexact obj = self.Person.objects(name__iexact="gUIDO VAN rOSSUM").first() - self.assertEqual(obj, person) + assert obj == person obj = self.Person.objects(name__iexact="gUIDO VAN rOSSU").first() - self.assertEqual(obj, None) + assert obj == None # Test unsafe expressions person = self.Person(name="Guido van Rossum [.'Geek']") person.save() obj = self.Person.objects(name__icontains="[.'Geek").first() - self.assertEqual(obj, person) + assert obj == person def test_not(self): """Ensure that the __not operator works as expected. @@ -1196,10 +1202,10 @@ class TestQueryset(unittest.TestCase): alice.save() obj = self.Person.objects(name__iexact="alice").first() - self.assertEqual(obj, alice) + assert obj == alice obj = self.Person.objects(name__not__iexact="alice").first() - self.assertEqual(obj, None) + assert obj == None def test_filter_chaining(self): """Ensure filters can be chained together. @@ -1253,12 +1259,12 @@ class TestQueryset(unittest.TestCase): published_posts = published_posts.filter( published_date__lt=datetime.datetime(2010, 1, 7, 0, 0, 0) ) - self.assertEqual(published_posts.count(), 2) + assert published_posts.count() == 2 blog_posts = BlogPost.objects blog_posts = blog_posts.filter(blog__in=[blog_1, blog_2]) blog_posts = blog_posts.filter(blog=blog_3) - self.assertEqual(blog_posts.count(), 0) + assert blog_posts.count() == 0 BlogPost.drop_collection() Blog.drop_collection() @@ -1269,14 +1275,14 @@ class TestQueryset(unittest.TestCase): people = self.Person.objects people = people.filter(name__startswith="Gui").filter(name__not__endswith="tum") - self.assertEqual(people.count(), 1) + assert people.count() == 1 def assertSequence(self, qs, expected): qs = list(qs) expected = list(expected) - self.assertEqual(len(qs), len(expected)) + assert len(qs) == len(expected) for i in range(len(qs)): - self.assertEqual(qs[i], expected[i]) + assert qs[i] == expected[i] def test_ordering(self): """Ensure default ordering is applied and can be overridden. @@ -1327,31 +1333,27 @@ class TestQueryset(unittest.TestCase): # default ordering should be used by default with db_ops_tracker() as q: BlogPost.objects.filter(title="whatever").first() - self.assertEqual(len(q.get_ops()), 1) - self.assertEqual( - q.get_ops()[0][CMD_QUERY_KEY][ORDER_BY_KEY], {"published_date": -1} - ) + assert len(q.get_ops()) == 1 + assert q.get_ops()[0][CMD_QUERY_KEY][ORDER_BY_KEY] == {"published_date": -1} # calling order_by() should clear the default ordering with db_ops_tracker() as q: BlogPost.objects.filter(title="whatever").order_by().first() - self.assertEqual(len(q.get_ops()), 1) - self.assertNotIn(ORDER_BY_KEY, q.get_ops()[0][CMD_QUERY_KEY]) + assert len(q.get_ops()) == 1 + assert ORDER_BY_KEY not in q.get_ops()[0][CMD_QUERY_KEY] # calling an explicit order_by should use a specified sort with db_ops_tracker() as q: BlogPost.objects.filter(title="whatever").order_by("published_date").first() - self.assertEqual(len(q.get_ops()), 1) - self.assertEqual( - q.get_ops()[0][CMD_QUERY_KEY][ORDER_BY_KEY], {"published_date": 1} - ) + assert len(q.get_ops()) == 1 + assert q.get_ops()[0][CMD_QUERY_KEY][ORDER_BY_KEY] == {"published_date": 1} # calling order_by() after an explicit sort should clear it with db_ops_tracker() as q: qs = BlogPost.objects.filter(title="whatever").order_by("published_date") qs.order_by().first() - self.assertEqual(len(q.get_ops()), 1) - self.assertNotIn(ORDER_BY_KEY, q.get_ops()[0][CMD_QUERY_KEY]) + assert len(q.get_ops()) == 1 + assert ORDER_BY_KEY not in q.get_ops()[0][CMD_QUERY_KEY] def test_no_ordering_for_get(self): """ Ensure that Doc.objects.get doesn't use any ordering. @@ -1370,14 +1372,14 @@ class TestQueryset(unittest.TestCase): with db_ops_tracker() as q: BlogPost.objects.get(title="whatever") - self.assertEqual(len(q.get_ops()), 1) - self.assertNotIn(ORDER_BY_KEY, q.get_ops()[0][CMD_QUERY_KEY]) + assert len(q.get_ops()) == 1 + assert ORDER_BY_KEY not in q.get_ops()[0][CMD_QUERY_KEY] # Ordering should be ignored for .get even if we set it explicitly with db_ops_tracker() as q: BlogPost.objects.order_by("-title").get(title="whatever") - self.assertEqual(len(q.get_ops()), 1) - self.assertNotIn(ORDER_BY_KEY, q.get_ops()[0][CMD_QUERY_KEY]) + assert len(q.get_ops()) == 1 + assert ORDER_BY_KEY not in q.get_ops()[0][CMD_QUERY_KEY] def test_find_embedded(self): """Ensure that an embedded document is properly returned from @@ -1397,20 +1399,20 @@ class TestQueryset(unittest.TestCase): BlogPost.objects.create(author=user, content="Had a good coffee today...") result = BlogPost.objects.first() - self.assertIsInstance(result.author, User) - self.assertEqual(result.author.name, "Test User") + assert isinstance(result.author, User) + assert result.author.name == "Test User" result = BlogPost.objects.get(author__name=user.name) - self.assertIsInstance(result.author, User) - self.assertEqual(result.author.name, "Test User") + assert isinstance(result.author, User) + assert result.author.name == "Test User" result = BlogPost.objects.get(author={"name": user.name}) - self.assertIsInstance(result.author, User) - self.assertEqual(result.author.name, "Test User") + assert isinstance(result.author, User) + assert result.author.name == "Test User" # Fails, since the string is not a type that is able to represent the # author's document structure (should be dict) - with self.assertRaises(InvalidQueryError): + with pytest.raises(InvalidQueryError): BlogPost.objects.get(author=user.name) def test_find_empty_embedded(self): @@ -1428,7 +1430,7 @@ class TestQueryset(unittest.TestCase): BlogPost.objects.create(content="Anonymous post...") result = BlogPost.objects.get(author=None) - self.assertEqual(result.author, None) + assert result.author == None def test_find_dict_item(self): """Ensure that DictField items may be found. @@ -1443,7 +1445,7 @@ class TestQueryset(unittest.TestCase): post.save() post_obj = BlogPost.objects(info__title="test").first() - self.assertEqual(post_obj.id, post.id) + assert post_obj.id == post.id BlogPost.drop_collection() @@ -1478,10 +1480,10 @@ class TestQueryset(unittest.TestCase): # Ensure that normal queries work c = BlogPost.objects(published=True).exec_js(js_func, "hits") - self.assertEqual(c, 2) + assert c == 2 c = BlogPost.objects(published=False).exec_js(js_func, "hits") - self.assertEqual(c, 1) + assert c == 1 BlogPost.drop_collection() @@ -1525,7 +1527,7 @@ class TestQueryset(unittest.TestCase): sub_code = BlogPost.objects._sub_js_fields(code) code_chunks = ['doc["cmnts"];', 'doc["doc-name"],', 'doc["cmnts"][i]["body"]'] for chunk in code_chunks: - self.assertIn(chunk, sub_code) + assert chunk in sub_code results = BlogPost.objects.exec_js(code) expected_results = [ @@ -1533,12 +1535,12 @@ class TestQueryset(unittest.TestCase): {u"comment": u"yay", u"document": u"post1"}, {u"comment": u"nice stuff", u"document": u"post2"}, ] - self.assertEqual(results, expected_results) + assert results == expected_results # Test template style code = "{{~comments.content}}" sub_code = BlogPost.objects._sub_js_fields(code) - self.assertEqual("cmnts.body", sub_code) + assert "cmnts.body" == sub_code BlogPost.drop_collection() @@ -1549,13 +1551,13 @@ class TestQueryset(unittest.TestCase): self.Person(name="User B", age=30).save() self.Person(name="User C", age=40).save() - self.assertEqual(self.Person.objects.count(), 3) + assert self.Person.objects.count() == 3 self.Person.objects(age__lt=30).delete() - self.assertEqual(self.Person.objects.count(), 2) + assert self.Person.objects.count() == 2 self.Person.objects.delete() - self.assertEqual(self.Person.objects.count(), 0) + assert self.Person.objects.count() == 0 def test_reverse_delete_rule_cascade(self): """Ensure cascading deletion of referring documents from the database. @@ -1576,9 +1578,9 @@ class TestQueryset(unittest.TestCase): BlogPost(content="Chilling out", author=me).save() BlogPost(content="Pro Testing", author=someoneelse).save() - self.assertEqual(3, BlogPost.objects.count()) + assert 3 == BlogPost.objects.count() self.Person.objects(name="Test User").delete() - self.assertEqual(1, BlogPost.objects.count()) + assert 1 == BlogPost.objects.count() def test_reverse_delete_rule_cascade_on_abstract_document(self): """Ensure cascading deletion of referring documents from the database @@ -1603,9 +1605,9 @@ class TestQueryset(unittest.TestCase): BlogPost(content="Chilling out", author=me).save() BlogPost(content="Pro Testing", author=someoneelse).save() - self.assertEqual(3, BlogPost.objects.count()) + assert 3 == BlogPost.objects.count() self.Person.objects(name="Test User").delete() - self.assertEqual(1, BlogPost.objects.count()) + assert 1 == BlogPost.objects.count() def test_reverse_delete_rule_cascade_cycle(self): """Ensure reference cascading doesn't loop if reference graph isn't @@ -1622,8 +1624,10 @@ class TestQueryset(unittest.TestCase): base.delete() - self.assertRaises(DoesNotExist, base.reload) - self.assertRaises(DoesNotExist, other.reload) + with pytest.raises(DoesNotExist): + base.reload() + with pytest.raises(DoesNotExist): + other.reload() def test_reverse_delete_rule_cascade_complex_cycle(self): """Ensure reference cascading doesn't loop if reference graph isn't @@ -1646,9 +1650,12 @@ class TestQueryset(unittest.TestCase): cat.delete() - self.assertRaises(DoesNotExist, base.reload) - self.assertRaises(DoesNotExist, other.reload) - self.assertRaises(DoesNotExist, other2.reload) + with pytest.raises(DoesNotExist): + base.reload() + with pytest.raises(DoesNotExist): + other.reload() + with pytest.raises(DoesNotExist): + other2.reload() def test_reverse_delete_rule_cascade_self_referencing(self): """Ensure self-referencing CASCADE deletes do not result in infinite @@ -1677,13 +1684,13 @@ class TestQueryset(unittest.TestCase): child_child.save() tree_size = 1 + num_children + (num_children * num_children) - self.assertEqual(tree_size, Category.objects.count()) - self.assertEqual(num_children, Category.objects(parent=base).count()) + assert tree_size == Category.objects.count() + assert num_children == Category.objects(parent=base).count() # The delete should effectively wipe out the Category collection # without resulting in infinite parent-child cascade recursion base.delete() - self.assertEqual(0, Category.objects.count()) + assert 0 == Category.objects.count() def test_reverse_delete_rule_nullify(self): """Ensure nullification of references to deleted documents. @@ -1705,11 +1712,11 @@ class TestQueryset(unittest.TestCase): post = BlogPost(content="Watching TV", category=lameness) post.save() - self.assertEqual(1, BlogPost.objects.count()) - self.assertEqual("Lameness", BlogPost.objects.first().category.name) + assert 1 == BlogPost.objects.count() + assert "Lameness" == BlogPost.objects.first().category.name Category.objects.delete() - self.assertEqual(1, BlogPost.objects.count()) - self.assertEqual(None, BlogPost.objects.first().category) + assert 1 == BlogPost.objects.count() + assert None == BlogPost.objects.first().category def test_reverse_delete_rule_nullify_on_abstract_document(self): """Ensure nullification of references to deleted documents when @@ -1732,11 +1739,11 @@ class TestQueryset(unittest.TestCase): BlogPost(content="Watching TV", author=me).save() - self.assertEqual(1, BlogPost.objects.count()) - self.assertEqual(me, BlogPost.objects.first().author) + assert 1 == BlogPost.objects.count() + assert me == BlogPost.objects.first().author self.Person.objects(name="Test User").delete() - self.assertEqual(1, BlogPost.objects.count()) - self.assertEqual(None, BlogPost.objects.first().author) + assert 1 == BlogPost.objects.count() + assert None == BlogPost.objects.first().author def test_reverse_delete_rule_deny(self): """Ensure deletion gets denied on documents that still have references @@ -1756,7 +1763,8 @@ class TestQueryset(unittest.TestCase): post = BlogPost(content="Watching TV", author=me) post.save() - self.assertRaises(OperationError, self.Person.objects.delete) + with pytest.raises(OperationError): + self.Person.objects.delete() def test_reverse_delete_rule_deny_on_abstract_document(self): """Ensure deletion gets denied on documents that still have references @@ -1777,8 +1785,9 @@ class TestQueryset(unittest.TestCase): BlogPost(content="Watching TV", author=me).save() - self.assertEqual(1, BlogPost.objects.count()) - self.assertRaises(OperationError, self.Person.objects.delete) + assert 1 == BlogPost.objects.count() + with pytest.raises(OperationError): + self.Person.objects.delete() def test_reverse_delete_rule_pull(self): """Ensure pulling of references to deleted documents. @@ -1807,8 +1816,8 @@ class TestQueryset(unittest.TestCase): post.reload() another.reload() - self.assertEqual(post.authors, [me]) - self.assertEqual(another.authors, []) + assert post.authors == [me] + assert another.authors == [] def test_reverse_delete_rule_pull_on_abstract_documents(self): """Ensure pulling of references to deleted documents when reference @@ -1841,8 +1850,8 @@ class TestQueryset(unittest.TestCase): post.reload() another.reload() - self.assertEqual(post.authors, [me]) - self.assertEqual(another.authors, []) + assert post.authors == [me] + assert another.authors == [] def test_delete_with_limits(self): class Log(Document): @@ -1854,7 +1863,7 @@ class TestQueryset(unittest.TestCase): Log().save() Log.objects()[3:5].delete() - self.assertEqual(8, Log.objects.count()) + assert 8 == Log.objects.count() def test_delete_with_limit_handles_delete_rules(self): """Ensure cascading deletion of referring documents from the database. @@ -1875,9 +1884,9 @@ class TestQueryset(unittest.TestCase): BlogPost(content="Chilling out", author=me).save() BlogPost(content="Pro Testing", author=someoneelse).save() - self.assertEqual(3, BlogPost.objects.count()) + assert 3 == BlogPost.objects.count() self.Person.objects()[:1].delete() - self.assertEqual(1, BlogPost.objects.count()) + assert 1 == BlogPost.objects.count() def test_delete_edge_case_with_write_concern_0_return_None(self): """Return None if the delete operation is unacknowledged. @@ -1887,7 +1896,7 @@ class TestQueryset(unittest.TestCase): """ p1 = self.Person(name="User Z", age=20).save() del_result = p1.delete(w=0) - self.assertEqual(None, del_result) + assert None == del_result def test_reference_field_find(self): """Ensure cascading deletion of referring documents from the database. @@ -1903,13 +1912,13 @@ class TestQueryset(unittest.TestCase): me = self.Person(name="Test User").save() BlogPost(content="test 123", author=me).save() - self.assertEqual(1, BlogPost.objects(author=me).count()) - self.assertEqual(1, BlogPost.objects(author=me.pk).count()) - self.assertEqual(1, BlogPost.objects(author="%s" % me.pk).count()) + assert 1 == BlogPost.objects(author=me).count() + assert 1 == BlogPost.objects(author=me.pk).count() + assert 1 == BlogPost.objects(author="%s" % me.pk).count() - self.assertEqual(1, BlogPost.objects(author__in=[me]).count()) - self.assertEqual(1, BlogPost.objects(author__in=[me.pk]).count()) - self.assertEqual(1, BlogPost.objects(author__in=["%s" % me.pk]).count()) + assert 1 == BlogPost.objects(author__in=[me]).count() + assert 1 == BlogPost.objects(author__in=[me.pk]).count() + assert 1 == BlogPost.objects(author__in=["%s" % me.pk]).count() def test_reference_field_find_dbref(self): """Ensure cascading deletion of referring documents from the database. @@ -1925,13 +1934,13 @@ class TestQueryset(unittest.TestCase): me = self.Person(name="Test User").save() BlogPost(content="test 123", author=me).save() - self.assertEqual(1, BlogPost.objects(author=me).count()) - self.assertEqual(1, BlogPost.objects(author=me.pk).count()) - self.assertEqual(1, BlogPost.objects(author="%s" % me.pk).count()) + assert 1 == BlogPost.objects(author=me).count() + assert 1 == BlogPost.objects(author=me.pk).count() + assert 1 == BlogPost.objects(author="%s" % me.pk).count() - self.assertEqual(1, BlogPost.objects(author__in=[me]).count()) - self.assertEqual(1, BlogPost.objects(author__in=[me.pk]).count()) - self.assertEqual(1, BlogPost.objects(author__in=["%s" % me.pk]).count()) + assert 1 == BlogPost.objects(author__in=[me]).count() + assert 1 == BlogPost.objects(author__in=[me.pk]).count() + assert 1 == BlogPost.objects(author__in=["%s" % me.pk]).count() def test_update_intfield_operator(self): class BlogPost(Document): @@ -1944,20 +1953,20 @@ class TestQueryset(unittest.TestCase): BlogPost.objects.update_one(set__hits=10) post.reload() - self.assertEqual(post.hits, 10) + assert post.hits == 10 BlogPost.objects.update_one(inc__hits=1) post.reload() - self.assertEqual(post.hits, 11) + assert post.hits == 11 BlogPost.objects.update_one(dec__hits=1) post.reload() - self.assertEqual(post.hits, 10) + assert post.hits == 10 # Negative dec operator is equal to a positive inc operator BlogPost.objects.update_one(dec__hits=-1) post.reload() - self.assertEqual(post.hits, 11) + assert post.hits == 11 def test_update_decimalfield_operator(self): class BlogPost(Document): @@ -1970,19 +1979,19 @@ class TestQueryset(unittest.TestCase): BlogPost.objects.update_one(inc__review=0.1) # test with floats post.reload() - self.assertEqual(float(post.review), 3.6) + assert float(post.review) == 3.6 BlogPost.objects.update_one(dec__review=0.1) post.reload() - self.assertEqual(float(post.review), 3.5) + assert float(post.review) == 3.5 BlogPost.objects.update_one(inc__review=Decimal(0.12)) # test with Decimal post.reload() - self.assertEqual(float(post.review), 3.62) + assert float(post.review) == 3.62 BlogPost.objects.update_one(dec__review=Decimal(0.12)) post.reload() - self.assertEqual(float(post.review), 3.5) + assert float(post.review) == 3.5 def test_update_decimalfield_operator_not_working_with_force_string(self): class BlogPost(Document): @@ -1993,7 +2002,7 @@ class TestQueryset(unittest.TestCase): post = BlogPost(review=3.5) post.save() - with self.assertRaises(OperationError): + with pytest.raises(OperationError): BlogPost.objects.update_one(inc__review=0.1) # test with floats def test_update_listfield_operator(self): @@ -2011,22 +2020,22 @@ class TestQueryset(unittest.TestCase): # ListField operator BlogPost.objects.update(push__tags="mongo") post.reload() - self.assertIn("mongo", post.tags) + assert "mongo" in post.tags BlogPost.objects.update_one(push_all__tags=["db", "nosql"]) post.reload() - self.assertIn("db", post.tags) - self.assertIn("nosql", post.tags) + assert "db" in post.tags + assert "nosql" in post.tags tags = post.tags[:-1] BlogPost.objects.update(pop__tags=1) post.reload() - self.assertEqual(post.tags, tags) + assert post.tags == tags BlogPost.objects.update_one(add_to_set__tags="unique") BlogPost.objects.update_one(add_to_set__tags="unique") post.reload() - self.assertEqual(post.tags.count("unique"), 1) + assert post.tags.count("unique") == 1 BlogPost.drop_collection() @@ -2038,12 +2047,12 @@ class TestQueryset(unittest.TestCase): post = BlogPost(title="garbage").save() - self.assertNotEqual(post.title, None) + assert post.title != None BlogPost.objects.update_one(unset__title=1) post.reload() - self.assertEqual(post.title, None) + assert post.title == None pymongo_doc = BlogPost.objects.as_pymongo().first() - self.assertNotIn("title", pymongo_doc) + assert "title" not in pymongo_doc def test_update_push_with_position(self): """Ensure that the 'push' update with position works properly. @@ -2060,16 +2069,16 @@ class TestQueryset(unittest.TestCase): BlogPost.objects.filter(id=post.id).update(push__tags="code") BlogPost.objects.filter(id=post.id).update(push__tags__0=["mongodb", "python"]) post.reload() - self.assertEqual(post.tags, ["mongodb", "python", "code"]) + assert post.tags == ["mongodb", "python", "code"] BlogPost.objects.filter(id=post.id).update(set__tags__2="java") post.reload() - self.assertEqual(post.tags, ["mongodb", "python", "java"]) + assert post.tags == ["mongodb", "python", "java"] # test push with singular value BlogPost.objects.filter(id=post.id).update(push__tags__0="scala") post.reload() - self.assertEqual(post.tags, ["scala", "mongodb", "python", "java"]) + assert post.tags == ["scala", "mongodb", "python", "java"] def test_update_push_list_of_list(self): """Ensure that the 'push' update operation works in the list of list @@ -2085,7 +2094,7 @@ class TestQueryset(unittest.TestCase): BlogPost.objects.filter(slug="test").update(push__tags=["value1", 123]) post.reload() - self.assertEqual(post.tags, [["value1", 123]]) + assert post.tags == [["value1", 123]] def test_update_push_and_pull_add_to_set(self): """Ensure that the 'pull' update operation works correctly. @@ -2102,25 +2111,25 @@ class TestQueryset(unittest.TestCase): BlogPost.objects.filter(id=post.id).update(push__tags="code") post.reload() - self.assertEqual(post.tags, ["code"]) + assert post.tags == ["code"] BlogPost.objects.filter(id=post.id).update(push_all__tags=["mongodb", "code"]) post.reload() - self.assertEqual(post.tags, ["code", "mongodb", "code"]) + assert post.tags == ["code", "mongodb", "code"] BlogPost.objects(slug="test").update(pull__tags="code") post.reload() - self.assertEqual(post.tags, ["mongodb"]) + assert post.tags == ["mongodb"] BlogPost.objects(slug="test").update(pull_all__tags=["mongodb", "code"]) post.reload() - self.assertEqual(post.tags, []) + assert post.tags == [] BlogPost.objects(slug="test").update( __raw__={"$addToSet": {"tags": {"$each": ["code", "mongodb", "code"]}}} ) post.reload() - self.assertEqual(post.tags, ["code", "mongodb"]) + assert post.tags == ["code", "mongodb"] def test_add_to_set_each(self): class Item(Document): @@ -2137,7 +2146,7 @@ class TestQueryset(unittest.TestCase): item.update(add_to_set__parents=[parent_1, parent_2, parent_1]) item.reload() - self.assertEqual([parent_1, parent_2], item.parents) + assert [parent_1, parent_2] == item.parents def test_pull_nested(self): class Collaborator(EmbeddedDocument): @@ -2156,9 +2165,9 @@ class TestQueryset(unittest.TestCase): s = Site(name="test", collaborators=[c]).save() Site.objects(id=s.id).update_one(pull__collaborators__user="Esteban") - self.assertEqual(Site.objects.first().collaborators, []) + assert Site.objects.first().collaborators == [] - with self.assertRaises(InvalidQueryError): + with pytest.raises(InvalidQueryError): Site.objects(id=s.id).update_one(pull_all__collaborators__user=["Ross"]) def test_pull_from_nested_embedded(self): @@ -2185,14 +2194,14 @@ class TestQueryset(unittest.TestCase): ).save() Site.objects(id=s.id).update_one(pull__collaborators__helpful=c) - self.assertEqual(Site.objects.first().collaborators["helpful"], []) + assert Site.objects.first().collaborators["helpful"] == [] Site.objects(id=s.id).update_one( pull__collaborators__unhelpful={"name": "Frank"} ) - self.assertEqual(Site.objects.first().collaborators["unhelpful"], []) + assert Site.objects.first().collaborators["unhelpful"] == [] - with self.assertRaises(InvalidQueryError): + with pytest.raises(InvalidQueryError): Site.objects(id=s.id).update_one( pull_all__collaborators__helpful__name=["Ross"] ) @@ -2229,12 +2238,12 @@ class TestQueryset(unittest.TestCase): Site.objects(id=s.id).update_one( pull__collaborators__helpful__name__in=["Esteban"] ) # Pull a - self.assertEqual(Site.objects.first().collaborators["helpful"], [b]) + assert Site.objects.first().collaborators["helpful"] == [b] Site.objects(id=s.id).update_one( pull__collaborators__unhelpful__name__nin=["John"] ) # Pull x - self.assertEqual(Site.objects.first().collaborators["unhelpful"], [y]) + assert Site.objects.first().collaborators["unhelpful"] == [y] def test_pull_from_nested_mapfield(self): class Collaborator(EmbeddedDocument): @@ -2255,14 +2264,14 @@ class TestQueryset(unittest.TestCase): s.save() Site.objects(id=s.id).update_one(pull__collaborators__helpful__user="Esteban") - self.assertEqual(Site.objects.first().collaborators["helpful"], []) + assert Site.objects.first().collaborators["helpful"] == [] Site.objects(id=s.id).update_one( pull__collaborators__unhelpful={"user": "Frank"} ) - self.assertEqual(Site.objects.first().collaborators["unhelpful"], []) + assert Site.objects.first().collaborators["unhelpful"] == [] - with self.assertRaises(InvalidQueryError): + with pytest.raises(InvalidQueryError): Site.objects(id=s.id).update_one( pull_all__collaborators__helpful__user=["Ross"] ) @@ -2280,7 +2289,7 @@ class TestQueryset(unittest.TestCase): bar = Bar(foos=[foo]).save() Bar.objects(id=bar.id).update(pull__foos=foo) bar.reload() - self.assertEqual(len(bar.foos), 0) + assert len(bar.foos) == 0 def test_update_one_check_return_with_full_result(self): class BlogTag(Document): @@ -2290,10 +2299,10 @@ class TestQueryset(unittest.TestCase): BlogTag(name="garbage").save() default_update = BlogTag.objects.update_one(name="new") - self.assertEqual(default_update, 1) + assert default_update == 1 full_result_update = BlogTag.objects.update_one(name="new", full_result=True) - self.assertIsInstance(full_result_update, UpdateResult) + assert isinstance(full_result_update, UpdateResult) def test_update_one_pop_generic_reference(self): class BlogTag(Document): @@ -2316,12 +2325,12 @@ class TestQueryset(unittest.TestCase): post = BlogPost(slug="test-2", tags=[tag_1, tag_2]) post.save() - self.assertEqual(len(post.tags), 2) + assert len(post.tags) == 2 BlogPost.objects(slug="test-2").update_one(pop__tags=-1) post.reload() - self.assertEqual(len(post.tags), 1) + assert len(post.tags) == 1 BlogPost.drop_collection() BlogTag.drop_collection() @@ -2344,15 +2353,15 @@ class TestQueryset(unittest.TestCase): post = BlogPost(slug="test-2", tags=[tag_1, tag_2]) post.save() - self.assertEqual(len(post.tags), 2) + assert len(post.tags) == 2 BlogPost.objects(slug="test-2").update_one(set__tags__0__name="python") post.reload() - self.assertEqual(post.tags[0].name, "python") + assert post.tags[0].name == "python" BlogPost.objects(slug="test-2").update_one(pop__tags=-1) post.reload() - self.assertEqual(len(post.tags), 1) + assert len(post.tags) == 1 BlogPost.drop_collection() @@ -2374,7 +2383,7 @@ class TestQueryset(unittest.TestCase): ) message = message.reload() - self.assertEqual(message.authors[0].name, "Ross") + assert message.authors[0].name == "Ross" Message.objects(authors__name="Ross").update_one( set__authors=[ @@ -2385,9 +2394,9 @@ class TestQueryset(unittest.TestCase): ) message = message.reload() - self.assertEqual(message.authors[0].name, "Harry") - self.assertEqual(message.authors[1].name, "Ross") - self.assertEqual(message.authors[2].name, "Adam") + assert message.authors[0].name == "Harry" + assert message.authors[1].name == "Ross" + assert message.authors[2].name == "Adam" def test_set_generic_embedded_documents(self): class Bar(EmbeddedDocument): @@ -2403,7 +2412,7 @@ class TestQueryset(unittest.TestCase): User.objects(username="abc").update(set__bar=Bar(name="test"), upsert=True) user = User.objects(username="abc").first() - self.assertEqual(user.bar.name, "test") + assert user.bar.name == "test" def test_reload_embedded_docs_instance(self): class SubDoc(EmbeddedDocument): @@ -2415,7 +2424,7 @@ class TestQueryset(unittest.TestCase): doc = Doc(embedded=SubDoc(val=0)).save() doc.reload() - self.assertEqual(doc.pk, doc.embedded._instance.pk) + assert doc.pk == doc.embedded._instance.pk def test_reload_list_embedded_docs_instance(self): class SubDoc(EmbeddedDocument): @@ -2427,7 +2436,7 @@ class TestQueryset(unittest.TestCase): doc = Doc(embedded=[SubDoc(val=0)]).save() doc.reload() - self.assertEqual(doc.pk, doc.embedded[0]._instance.pk) + assert doc.pk == doc.embedded[0]._instance.pk def test_order_by(self): """Ensure that QuerySets may be ordered. @@ -2437,16 +2446,16 @@ class TestQueryset(unittest.TestCase): self.Person(name="User C", age=30).save() names = [p.name for p in self.Person.objects.order_by("-age")] - self.assertEqual(names, ["User B", "User C", "User A"]) + assert names == ["User B", "User C", "User A"] names = [p.name for p in self.Person.objects.order_by("+age")] - self.assertEqual(names, ["User A", "User C", "User B"]) + assert names == ["User A", "User C", "User B"] names = [p.name for p in self.Person.objects.order_by("age")] - self.assertEqual(names, ["User A", "User C", "User B"]) + assert names == ["User A", "User C", "User B"] ages = [p.age for p in self.Person.objects.order_by("-name")] - self.assertEqual(ages, [30, 40, 20]) + assert ages == [30, 40, 20] def test_order_by_optional(self): class BlogPost(Document): @@ -2511,24 +2520,24 @@ class TestQueryset(unittest.TestCase): ages = [p.age for p in only_age] # The .only('age') clause should mean that all names are None - self.assertEqual(names, [None, None, None]) - self.assertEqual(ages, [40, 30, 20]) + assert names == [None, None, None] + assert ages == [40, 30, 20] qs = self.Person.objects.all().order_by("-age") qs = qs.limit(10) ages = [p.age for p in qs] - self.assertEqual(ages, [40, 30, 20]) + assert ages == [40, 30, 20] qs = self.Person.objects.all().limit(10) qs = qs.order_by("-age") ages = [p.age for p in qs] - self.assertEqual(ages, [40, 30, 20]) + assert ages == [40, 30, 20] qs = self.Person.objects.all().skip(0) qs = qs.order_by("-age") ages = [p.age for p in qs] - self.assertEqual(ages, [40, 30, 20]) + assert ages == [40, 30, 20] def test_confirm_order_by_reference_wont_work(self): """Ordering by reference is not possible. Use map / reduce.. or @@ -2551,7 +2560,7 @@ class TestQueryset(unittest.TestCase): Author(author=person_c).save() names = [a.author.name for a in Author.objects.order_by("-author__age")] - self.assertEqual(names, ["User A", "User B", "User C"]) + assert names == ["User A", "User B", "User C"] def test_comment(self): """Make sure adding a comment to the query gets added to the query""" @@ -2573,10 +2582,10 @@ class TestQueryset(unittest.TestCase): ) ops = q.get_ops() - self.assertEqual(len(ops), 2) + assert len(ops) == 2 for op in ops: - self.assertEqual(op[CMD_QUERY_KEY][QUERY_KEY], {"age": {"$gte": 18}}) - self.assertEqual(op[CMD_QUERY_KEY][COMMENT_KEY], "looking for an adult") + assert op[CMD_QUERY_KEY][QUERY_KEY] == {"age": {"$gte": 18}} + assert op[CMD_QUERY_KEY][COMMENT_KEY] == "looking for an adult" def test_map_reduce(self): """Ensure map/reduce is both mapping and reducing. @@ -2613,13 +2622,13 @@ class TestQueryset(unittest.TestCase): # run a map/reduce operation spanning all posts results = BlogPost.objects.map_reduce(map_f, reduce_f, "myresults") results = list(results) - self.assertEqual(len(results), 4) + assert len(results) == 4 music = list(filter(lambda r: r.key == "music", results))[0] - self.assertEqual(music.value, 2) + assert music.value == 2 film = list(filter(lambda r: r.key == "film", results))[0] - self.assertEqual(film.value, 3) + assert film.value == 3 BlogPost.drop_collection() @@ -2640,8 +2649,8 @@ class TestQueryset(unittest.TestCase): post2.save() post3.save() - self.assertEqual(BlogPost._fields["title"].db_field, "_id") - self.assertEqual(BlogPost._meta["id_field"], "title") + assert BlogPost._fields["title"].db_field == "_id" + assert BlogPost._meta["id_field"] == "title" map_f = """ function() { @@ -2663,9 +2672,9 @@ class TestQueryset(unittest.TestCase): results = BlogPost.objects.map_reduce(map_f, reduce_f, "myresults") results = list(results) - self.assertEqual(results[0].object, post1) - self.assertEqual(results[1].object, post2) - self.assertEqual(results[2].object, post3) + assert results[0].object == post1 + assert results[1].object == post2 + assert results[2].object == post3 BlogPost.drop_collection() @@ -2770,50 +2779,41 @@ class TestQueryset(unittest.TestCase): results = list(results) collection = get_db("test2").family_map - self.assertEqual( - collection.find_one({"_id": 1}), - { - "_id": 1, - "value": { - "persons": [ - {"age": 21, "name": u"Wilson Jr"}, - {"age": 45, "name": u"Wilson Father"}, - {"age": 40, "name": u"Eliana Costa"}, - {"age": 17, "name": u"Tayza Mariana"}, - ], - "totalAge": 123, - }, + assert collection.find_one({"_id": 1}) == { + "_id": 1, + "value": { + "persons": [ + {"age": 21, "name": u"Wilson Jr"}, + {"age": 45, "name": u"Wilson Father"}, + {"age": 40, "name": u"Eliana Costa"}, + {"age": 17, "name": u"Tayza Mariana"}, + ], + "totalAge": 123, }, - ) + } - self.assertEqual( - collection.find_one({"_id": 2}), - { - "_id": 2, - "value": { - "persons": [ - {"age": 16, "name": u"Isabella Luanna"}, - {"age": 36, "name": u"Sandra Mara"}, - {"age": 10, "name": u"Igor Gabriel"}, - ], - "totalAge": 62, - }, + assert collection.find_one({"_id": 2}) == { + "_id": 2, + "value": { + "persons": [ + {"age": 16, "name": u"Isabella Luanna"}, + {"age": 36, "name": u"Sandra Mara"}, + {"age": 10, "name": u"Igor Gabriel"}, + ], + "totalAge": 62, }, - ) + } - self.assertEqual( - collection.find_one({"_id": 3}), - { - "_id": 3, - "value": { - "persons": [ - {"age": 30, "name": u"Arthur WA"}, - {"age": 25, "name": u"Paula Leonel"}, - ], - "totalAge": 55, - }, + assert collection.find_one({"_id": 3}) == { + "_id": 3, + "value": { + "persons": [ + {"age": 30, "name": u"Arthur WA"}, + {"age": 25, "name": u"Paula Leonel"}, + ], + "totalAge": 55, }, - ) + } def test_map_reduce_finalize(self): """Ensure that map, reduce, and finalize run and introduce "scope" @@ -2933,10 +2933,10 @@ class TestQueryset(unittest.TestCase): results = list(results) # assert troublesome Buzz article is ranked 1st - self.assertTrue(results[0].object.title.startswith("Google Buzz")) + assert results[0].object.title.startswith("Google Buzz") # assert laser vision is ranked last - self.assertTrue(results[-1].object.title.startswith("How to see")) + assert results[-1].object.title.startswith("How to see") Link.drop_collection() @@ -2956,11 +2956,11 @@ class TestQueryset(unittest.TestCase): def test_assertions(f): f = {key: int(val) for key, val in f.items()} - self.assertEqual(set(["music", "film", "actors", "watch"]), set(f.keys())) - self.assertEqual(f["music"], 3) - self.assertEqual(f["actors"], 2) - self.assertEqual(f["watch"], 2) - self.assertEqual(f["film"], 1) + assert set(["music", "film", "actors", "watch"]) == set(f.keys()) + assert f["music"] == 3 + assert f["actors"] == 2 + assert f["watch"] == 2 + assert f["film"] == 1 exec_js = BlogPost.objects.item_frequencies("tags") map_reduce = BlogPost.objects.item_frequencies("tags", map_reduce=True) @@ -2970,10 +2970,10 @@ class TestQueryset(unittest.TestCase): # Ensure query is taken into account def test_assertions(f): f = {key: int(val) for key, val in f.items()} - self.assertEqual(set(["music", "actors", "watch"]), set(f.keys())) - self.assertEqual(f["music"], 2) - self.assertEqual(f["actors"], 1) - self.assertEqual(f["watch"], 1) + assert set(["music", "actors", "watch"]) == set(f.keys()) + assert f["music"] == 2 + assert f["actors"] == 1 + assert f["watch"] == 1 exec_js = BlogPost.objects(hits__gt=1).item_frequencies("tags") map_reduce = BlogPost.objects(hits__gt=1).item_frequencies( @@ -2984,10 +2984,10 @@ class TestQueryset(unittest.TestCase): # Check that normalization works def test_assertions(f): - self.assertAlmostEqual(f["music"], 3.0 / 8.0) - self.assertAlmostEqual(f["actors"], 2.0 / 8.0) - self.assertAlmostEqual(f["watch"], 2.0 / 8.0) - self.assertAlmostEqual(f["film"], 1.0 / 8.0) + assert round(abs(f["music"] - 3.0 / 8.0), 7) == 0 + assert round(abs(f["actors"] - 2.0 / 8.0), 7) == 0 + assert round(abs(f["watch"] - 2.0 / 8.0), 7) == 0 + assert round(abs(f["film"] - 1.0 / 8.0), 7) == 0 exec_js = BlogPost.objects.item_frequencies("tags", normalize=True) map_reduce = BlogPost.objects.item_frequencies( @@ -2998,9 +2998,9 @@ class TestQueryset(unittest.TestCase): # Check item_frequencies works for non-list fields def test_assertions(f): - self.assertEqual(set([1, 2]), set(f.keys())) - self.assertEqual(f[1], 1) - self.assertEqual(f[2], 2) + assert set([1, 2]) == set(f.keys()) + assert f[1] == 1 + assert f[2] == 2 exec_js = BlogPost.objects.item_frequencies("hits") map_reduce = BlogPost.objects.item_frequencies("hits", map_reduce=True) @@ -3036,9 +3036,9 @@ class TestQueryset(unittest.TestCase): def test_assertions(f): f = {key: int(val) for key, val in f.items()} - self.assertEqual(set(["62-3331-1656", "62-3332-1656"]), set(f.keys())) - self.assertEqual(f["62-3331-1656"], 2) - self.assertEqual(f["62-3332-1656"], 1) + assert set(["62-3331-1656", "62-3332-1656"]) == set(f.keys()) + assert f["62-3331-1656"] == 2 + assert f["62-3332-1656"] == 1 exec_js = Person.objects.item_frequencies("phone.number") map_reduce = Person.objects.item_frequencies("phone.number", map_reduce=True) @@ -3048,8 +3048,8 @@ class TestQueryset(unittest.TestCase): # Ensure query is taken into account def test_assertions(f): f = {key: int(val) for key, val in f.items()} - self.assertEqual(set(["62-3331-1656"]), set(f.keys())) - self.assertEqual(f["62-3331-1656"], 2) + assert set(["62-3331-1656"]) == set(f.keys()) + assert f["62-3331-1656"] == 2 exec_js = Person.objects(phone__number="62-3331-1656").item_frequencies( "phone.number" @@ -3062,8 +3062,8 @@ class TestQueryset(unittest.TestCase): # Check that normalization works def test_assertions(f): - self.assertEqual(f["62-3331-1656"], 2.0 / 3.0) - self.assertEqual(f["62-3332-1656"], 1.0 / 3.0) + assert f["62-3331-1656"] == 2.0 / 3.0 + assert f["62-3332-1656"] == 1.0 / 3.0 exec_js = Person.objects.item_frequencies("phone.number", normalize=True) map_reduce = Person.objects.item_frequencies( @@ -3083,14 +3083,14 @@ class TestQueryset(unittest.TestCase): Person(name="Wilson Jr").save() freq = Person.objects.item_frequencies("city") - self.assertEqual(freq, {"CRB": 1.0, None: 1.0}) + assert freq == {"CRB": 1.0, None: 1.0} freq = Person.objects.item_frequencies("city", normalize=True) - self.assertEqual(freq, {"CRB": 0.5, None: 0.5}) + assert freq == {"CRB": 0.5, None: 0.5} freq = Person.objects.item_frequencies("city", map_reduce=True) - self.assertEqual(freq, {"CRB": 1.0, None: 1.0}) + assert freq == {"CRB": 1.0, None: 1.0} freq = Person.objects.item_frequencies("city", normalize=True, map_reduce=True) - self.assertEqual(freq, {"CRB": 0.5, None: 0.5}) + assert freq == {"CRB": 0.5, None: 0.5} def test_item_frequencies_with_null_embedded(self): class Data(EmbeddedDocument): @@ -3115,10 +3115,10 @@ class TestQueryset(unittest.TestCase): p.save() ot = Person.objects.item_frequencies("extra.tag", map_reduce=False) - self.assertEqual(ot, {None: 1.0, u"friend": 1.0}) + assert ot == {None: 1.0, u"friend": 1.0} ot = Person.objects.item_frequencies("extra.tag", map_reduce=True) - self.assertEqual(ot, {None: 1.0, u"friend": 1.0}) + assert ot == {None: 1.0, u"friend": 1.0} def test_item_frequencies_with_0_values(self): class Test(Document): @@ -3130,9 +3130,9 @@ class TestQueryset(unittest.TestCase): t.save() ot = Test.objects.item_frequencies("val", map_reduce=True) - self.assertEqual(ot, {0: 1}) + assert ot == {0: 1} ot = Test.objects.item_frequencies("val", map_reduce=False) - self.assertEqual(ot, {0: 1}) + assert ot == {0: 1} def test_item_frequencies_with_False_values(self): class Test(Document): @@ -3144,9 +3144,9 @@ class TestQueryset(unittest.TestCase): t.save() ot = Test.objects.item_frequencies("val", map_reduce=True) - self.assertEqual(ot, {False: 1}) + assert ot == {False: 1} ot = Test.objects.item_frequencies("val", map_reduce=False) - self.assertEqual(ot, {False: 1}) + assert ot == {False: 1} def test_item_frequencies_normalize(self): class Test(Document): @@ -3161,31 +3161,32 @@ class TestQueryset(unittest.TestCase): Test(val=2).save() freqs = Test.objects.item_frequencies("val", map_reduce=False, normalize=True) - self.assertEqual(freqs, {1: 50.0 / 70, 2: 20.0 / 70}) + assert freqs == {1: 50.0 / 70, 2: 20.0 / 70} freqs = Test.objects.item_frequencies("val", map_reduce=True, normalize=True) - self.assertEqual(freqs, {1: 50.0 / 70, 2: 20.0 / 70}) + assert freqs == {1: 50.0 / 70, 2: 20.0 / 70} def test_average(self): """Ensure that field can be averaged correctly. """ self.Person(name="person", age=0).save() - self.assertEqual(int(self.Person.objects.average("age")), 0) + assert int(self.Person.objects.average("age")) == 0 ages = [23, 54, 12, 94, 27] for i, age in enumerate(ages): self.Person(name="test%s" % i, age=age).save() avg = float(sum(ages)) / (len(ages) + 1) # take into account the 0 - self.assertAlmostEqual(int(self.Person.objects.average("age")), avg) + assert round(abs(int(self.Person.objects.average("age")) - avg), 7) == 0 self.Person(name="ageless person").save() - self.assertEqual(int(self.Person.objects.average("age")), avg) + assert int(self.Person.objects.average("age")) == avg # dot notation self.Person(name="person meta", person_meta=self.PersonMeta(weight=0)).save() - self.assertAlmostEqual( - int(self.Person.objects.average("person_meta.weight")), 0 + assert ( + round(abs(int(self.Person.objects.average("person_meta.weight")) - 0), 7) + == 0 ) for i, weight in enumerate(ages): @@ -3193,17 +3194,18 @@ class TestQueryset(unittest.TestCase): name="test meta%i", person_meta=self.PersonMeta(weight=weight) ).save() - self.assertAlmostEqual( - int(self.Person.objects.average("person_meta.weight")), avg + assert ( + round(abs(int(self.Person.objects.average("person_meta.weight")) - avg), 7) + == 0 ) self.Person(name="test meta none").save() - self.assertEqual(int(self.Person.objects.average("person_meta.weight")), avg) + assert int(self.Person.objects.average("person_meta.weight")) == avg # test summing over a filtered queryset over_50 = [a for a in ages if a >= 50] avg = float(sum(over_50)) / len(over_50) - self.assertEqual(self.Person.objects.filter(age__gte=50).average("age"), avg) + assert self.Person.objects.filter(age__gte=50).average("age") == avg def test_sum(self): """Ensure that field can be summed over correctly. @@ -3212,25 +3214,24 @@ class TestQueryset(unittest.TestCase): for i, age in enumerate(ages): self.Person(name="test%s" % i, age=age).save() - self.assertEqual(self.Person.objects.sum("age"), sum(ages)) + assert self.Person.objects.sum("age") == sum(ages) self.Person(name="ageless person").save() - self.assertEqual(self.Person.objects.sum("age"), sum(ages)) + assert self.Person.objects.sum("age") == sum(ages) for i, age in enumerate(ages): self.Person( name="test meta%s" % i, person_meta=self.PersonMeta(weight=age) ).save() - self.assertEqual(self.Person.objects.sum("person_meta.weight"), sum(ages)) + assert self.Person.objects.sum("person_meta.weight") == sum(ages) self.Person(name="weightless person").save() - self.assertEqual(self.Person.objects.sum("age"), sum(ages)) + assert self.Person.objects.sum("age") == sum(ages) # test summing over a filtered queryset - self.assertEqual( - self.Person.objects.filter(age__gte=50).sum("age"), - sum([a for a in ages if a >= 50]), + assert self.Person.objects.filter(age__gte=50).sum("age") == sum( + [a for a in ages if a >= 50] ) def test_sum_over_db_field(self): @@ -3246,7 +3247,7 @@ class TestQueryset(unittest.TestCase): UserVisit.objects.create(num_visits=10) UserVisit.objects.create(num_visits=5) - self.assertEqual(UserVisit.objects.sum("num_visits"), 15) + assert UserVisit.objects.sum("num_visits") == 15 def test_average_over_db_field(self): """Ensure that a field mapped to a db field with a different name @@ -3261,7 +3262,7 @@ class TestQueryset(unittest.TestCase): UserVisit.objects.create(num_visits=20) UserVisit.objects.create(num_visits=10) - self.assertEqual(UserVisit.objects.average("num_visits"), 15) + assert UserVisit.objects.average("num_visits") == 15 def test_embedded_average(self): class Pay(EmbeddedDocument): @@ -3278,7 +3279,7 @@ class TestQueryset(unittest.TestCase): Doc(name="Tayza mariana", pay=Pay(value=165)).save() Doc(name="Eliana Costa", pay=Pay(value=115)).save() - self.assertEqual(Doc.objects.average("pay.value"), 240) + assert Doc.objects.average("pay.value") == 240 def test_embedded_array_average(self): class Pay(EmbeddedDocument): @@ -3295,7 +3296,7 @@ class TestQueryset(unittest.TestCase): Doc(name="Tayza mariana", pay=Pay(values=[165, 100])).save() Doc(name="Eliana Costa", pay=Pay(values=[115, 100])).save() - self.assertEqual(Doc.objects.average("pay.values"), 170) + assert Doc.objects.average("pay.values") == 170 def test_array_average(self): class Doc(Document): @@ -3308,7 +3309,7 @@ class TestQueryset(unittest.TestCase): Doc(values=[165, 100]).save() Doc(values=[115, 100]).save() - self.assertEqual(Doc.objects.average("values"), 170) + assert Doc.objects.average("values") == 170 def test_embedded_sum(self): class Pay(EmbeddedDocument): @@ -3325,7 +3326,7 @@ class TestQueryset(unittest.TestCase): Doc(name="Tayza mariana", pay=Pay(value=165)).save() Doc(name="Eliana Costa", pay=Pay(value=115)).save() - self.assertEqual(Doc.objects.sum("pay.value"), 960) + assert Doc.objects.sum("pay.value") == 960 def test_embedded_array_sum(self): class Pay(EmbeddedDocument): @@ -3342,7 +3343,7 @@ class TestQueryset(unittest.TestCase): Doc(name="Tayza mariana", pay=Pay(values=[165, 100])).save() Doc(name="Eliana Costa", pay=Pay(values=[115, 100])).save() - self.assertEqual(Doc.objects.sum("pay.values"), 1360) + assert Doc.objects.sum("pay.values") == 1360 def test_array_sum(self): class Doc(Document): @@ -3355,7 +3356,7 @@ class TestQueryset(unittest.TestCase): Doc(values=[165, 100]).save() Doc(values=[115, 100]).save() - self.assertEqual(Doc.objects.sum("values"), 1360) + assert Doc.objects.sum("values") == 1360 def test_distinct(self): """Ensure that the QuerySet.distinct method works. @@ -3364,14 +3365,12 @@ class TestQueryset(unittest.TestCase): self.Person(name="Mr White", age=20).save() self.Person(name="Mr Orange", age=30).save() self.Person(name="Mr Pink", age=30).save() - self.assertEqual( - set(self.Person.objects.distinct("name")), - set(["Mr Orange", "Mr White", "Mr Pink"]), + assert set(self.Person.objects.distinct("name")) == set( + ["Mr Orange", "Mr White", "Mr Pink"] ) - self.assertEqual(set(self.Person.objects.distinct("age")), set([20, 30])) - self.assertEqual( - set(self.Person.objects(age=30).distinct("name")), - set(["Mr Orange", "Mr Pink"]), + assert set(self.Person.objects.distinct("age")) == set([20, 30]) + assert set(self.Person.objects(age=30).distinct("name")) == set( + ["Mr Orange", "Mr Pink"] ) def test_distinct_handles_references(self): @@ -3390,7 +3389,7 @@ class TestQueryset(unittest.TestCase): foo = Foo(bar=bar) foo.save() - self.assertEqual(Foo.objects.distinct("bar"), [bar]) + assert Foo.objects.distinct("bar") == [bar] def test_text_indexes(self): class News(Document): @@ -3410,8 +3409,8 @@ class TestQueryset(unittest.TestCase): News.drop_collection() info = News.objects._collection.index_information() - self.assertIn("title_text_content_text", info) - self.assertIn("textIndexVersion", info["title_text_content_text"]) + assert "title_text_content_text" in info + assert "textIndexVersion" in info["title_text_content_text"] News( title="Neymar quebrou a vertebra", @@ -3426,11 +3425,11 @@ class TestQueryset(unittest.TestCase): count = News.objects.search_text("neymar", language="portuguese").count() - self.assertEqual(count, 1) + assert count == 1 count = News.objects.search_text("brasil -neymar").count() - self.assertEqual(count, 1) + assert count == 1 News( title=u"As eleições no Brasil já estão em planejamento", @@ -3442,41 +3441,41 @@ class TestQueryset(unittest.TestCase): query = News.objects(is_active=False).search_text("dilma", language="pt")._query - self.assertEqual( - query, - {"$text": {"$search": "dilma", "$language": "pt"}, "is_active": False}, - ) + assert query == { + "$text": {"$search": "dilma", "$language": "pt"}, + "is_active": False, + } - self.assertFalse(new.is_active) - self.assertIn("dilma", new.content) - self.assertIn("planejamento", new.title) + assert not new.is_active + assert "dilma" in new.content + assert "planejamento" in new.title query = News.objects.search_text("candidata") - self.assertEqual(query._search_text, "candidata") + assert query._search_text == "candidata" new = query.first() - self.assertIsInstance(new.get_text_score(), float) + assert isinstance(new.get_text_score(), float) # count query = News.objects.search_text("brasil").order_by("$text_score") - self.assertEqual(query._search_text, "brasil") + assert query._search_text == "brasil" - self.assertEqual(query.count(), 3) - self.assertEqual(query._query, {"$text": {"$search": "brasil"}}) + assert query.count() == 3 + assert query._query == {"$text": {"$search": "brasil"}} cursor_args = query._cursor_args cursor_args_fields = cursor_args["projection"] - self.assertEqual(cursor_args_fields, {"_text_score": {"$meta": "textScore"}}) + assert cursor_args_fields == {"_text_score": {"$meta": "textScore"}} text_scores = [i.get_text_score() for i in query] - self.assertEqual(len(text_scores), 3) + assert len(text_scores) == 3 - self.assertTrue(text_scores[0] > text_scores[1]) - self.assertTrue(text_scores[1] > text_scores[2]) + assert text_scores[0] > text_scores[1] + assert text_scores[1] > text_scores[2] max_text_score = text_scores[0] # get item item = News.objects.search_text("brasil").order_by("$text_score").first() - self.assertEqual(item.get_text_score(), max_text_score) + assert item.get_text_score() == max_text_score def test_distinct_handles_references_to_alias(self): register_connection("testdb", "mongoenginetest2") @@ -3498,7 +3497,7 @@ class TestQueryset(unittest.TestCase): foo = Foo(bar=bar) foo.save() - self.assertEqual(Foo.objects.distinct("bar"), [bar]) + assert Foo.objects.distinct("bar") == [bar] def test_distinct_handles_db_field(self): """Ensure that distinct resolves field name to db_field as expected. @@ -3513,8 +3512,8 @@ class TestQueryset(unittest.TestCase): Product(product_id=2).save() Product(product_id=1).save() - self.assertEqual(set(Product.objects.distinct("product_id")), set([1, 2])) - self.assertEqual(set(Product.objects.distinct("pid")), set([1, 2])) + assert set(Product.objects.distinct("product_id")) == set([1, 2]) + assert set(Product.objects.distinct("pid")) == set([1, 2]) Product.drop_collection() @@ -3536,7 +3535,7 @@ class TestQueryset(unittest.TestCase): Book.objects.create(title="The Stories", authors=[mark_twain, john_tolkien]) authors = Book.objects.distinct("authors") - self.assertEqual(authors, [mark_twain, john_tolkien]) + assert authors == [mark_twain, john_tolkien] def test_distinct_ListField_EmbeddedDocumentField_EmbeddedDocumentField(self): class Continent(EmbeddedDocument): @@ -3570,10 +3569,10 @@ class TestQueryset(unittest.TestCase): Book.objects.create(title="The Stories", authors=[mark_twain, john_tolkien]) country_list = Book.objects.distinct("authors.country") - self.assertEqual(country_list, [scotland, tibet]) + assert country_list == [scotland, tibet] continent_list = Book.objects.distinct("authors.country.continent") - self.assertEqual(continent_list, [europe, asia]) + assert continent_list == [europe, asia] def test_distinct_ListField_ReferenceField(self): class Bar(Document): @@ -3595,7 +3594,7 @@ class TestQueryset(unittest.TestCase): foo = Foo(bar=bar_1, bar_lst=[bar_1, bar_2]) foo.save() - self.assertEqual(Foo.objects.distinct("bar_lst"), [bar_1, bar_2]) + assert Foo.objects.distinct("bar_lst") == [bar_1, bar_2] def test_custom_manager(self): """Ensure that custom QuerySetManager instances work as expected. @@ -3627,15 +3626,15 @@ class TestQueryset(unittest.TestCase): post3 = BlogPost(tags=["film", "actors"]).save() post4 = BlogPost(tags=["film", "actors", "music"], deleted=True).save() - self.assertEqual( - [p.id for p in BlogPost.objects()], [post1.id, post2.id, post3.id] - ) - self.assertEqual( - [p.id for p in BlogPost.objects_1_arg()], [post1.id, post2.id, post3.id] - ) - self.assertEqual([p.id for p in BlogPost.music_posts()], [post1.id, post2.id]) + assert [p.id for p in BlogPost.objects()] == [post1.id, post2.id, post3.id] + assert [p.id for p in BlogPost.objects_1_arg()] == [ + post1.id, + post2.id, + post3.id, + ] + assert [p.id for p in BlogPost.music_posts()] == [post1.id, post2.id] - self.assertEqual([p.id for p in BlogPost.music_posts(True)], [post4.id]) + assert [p.id for p in BlogPost.music_posts(True)] == [post4.id] BlogPost.drop_collection() @@ -3657,12 +3656,12 @@ class TestQueryset(unittest.TestCase): Foo(active=True).save() Foo(active=False).save() - self.assertEqual(1, Foo.objects.count()) - self.assertEqual(1, Foo.with_inactive.count()) + assert 1 == Foo.objects.count() + assert 1 == Foo.with_inactive.count() Foo.with_inactive.first().delete() - self.assertEqual(0, Foo.with_inactive.count()) - self.assertEqual(1, Foo.objects.count()) + assert 0 == Foo.with_inactive.count() + assert 1 == Foo.objects.count() def test_inherit_objects(self): class Foo(Document): @@ -3678,7 +3677,7 @@ class TestQueryset(unittest.TestCase): Bar.drop_collection() Bar.objects.create(active=False) - self.assertEqual(0, Bar.objects.count()) + assert 0 == Bar.objects.count() def test_inherit_objects_override(self): class Foo(Document): @@ -3696,8 +3695,8 @@ class TestQueryset(unittest.TestCase): Bar.drop_collection() Bar.objects.create(active=False) - self.assertEqual(0, Foo.objects.count()) - self.assertEqual(1, Bar.objects.count()) + assert 0 == Foo.objects.count() + assert 1 == Bar.objects.count() def test_query_value_conversion(self): """Ensure that query values are properly converted when necessary. @@ -3718,11 +3717,11 @@ class TestQueryset(unittest.TestCase): # while using a ReferenceField's name - the document should be # converted to an DBRef, which is legal, unlike a Document object post_obj = BlogPost.objects(author=person).first() - self.assertEqual(post.id, post_obj.id) + assert post.id == post_obj.id # Test that lists of values work when using the 'in', 'nin' and 'all' post_obj = BlogPost.objects(author__in=[person]).first() - self.assertEqual(post.id, post_obj.id) + assert post.id == post_obj.id BlogPost.drop_collection() @@ -3746,9 +3745,9 @@ class TestQueryset(unittest.TestCase): Group.objects(id=group.id).update(set__members=[user1, user2]) group.reload() - self.assertEqual(len(group.members), 2) - self.assertEqual(group.members[0].name, user1.name) - self.assertEqual(group.members[1].name, user2.name) + assert len(group.members) == 2 + assert group.members[0].name == user1.name + assert group.members[1].name == user2.name Group.drop_collection() @@ -3776,15 +3775,15 @@ class TestQueryset(unittest.TestCase): ids = [post_1.id, post_2.id, post_5.id] objects = BlogPost.objects.in_bulk(ids) - self.assertEqual(len(objects), 3) + assert len(objects) == 3 - self.assertIn(post_1.id, objects) - self.assertIn(post_2.id, objects) - self.assertIn(post_5.id, objects) + assert post_1.id in objects + assert post_2.id in objects + assert post_5.id in objects - self.assertEqual(objects[post_1.id].title, post_1.title) - self.assertEqual(objects[post_2.id].title, post_2.title) - self.assertEqual(objects[post_5.id].title, post_5.title) + assert objects[post_1.id].title == post_1.title + assert objects[post_2.id].title == post_2.title + assert objects[post_5.id].title == post_5.title BlogPost.drop_collection() @@ -3804,11 +3803,11 @@ class TestQueryset(unittest.TestCase): Post.drop_collection() - self.assertIsInstance(Post.objects, CustomQuerySet) - self.assertFalse(Post.objects.not_empty()) + assert isinstance(Post.objects, CustomQuerySet) + assert not Post.objects.not_empty() Post().save() - self.assertTrue(Post.objects.not_empty()) + assert Post.objects.not_empty() Post.drop_collection() @@ -3828,11 +3827,11 @@ class TestQueryset(unittest.TestCase): Post.drop_collection() - self.assertIsInstance(Post.objects, CustomQuerySet) - self.assertFalse(Post.objects.not_empty()) + assert isinstance(Post.objects, CustomQuerySet) + assert not Post.objects.not_empty() Post().save() - self.assertTrue(Post.objects.not_empty()) + assert Post.objects.not_empty() Post.drop_collection() @@ -3853,8 +3852,8 @@ class TestQueryset(unittest.TestCase): Post().save() Post(is_published=True).save() - self.assertEqual(Post.objects.count(), 2) - self.assertEqual(Post.published.count(), 1) + assert Post.objects.count() == 2 + assert Post.published.count() == 1 Post.drop_collection() @@ -3873,11 +3872,11 @@ class TestQueryset(unittest.TestCase): pass Post.drop_collection() - self.assertIsInstance(Post.objects, CustomQuerySet) - self.assertFalse(Post.objects.not_empty()) + assert isinstance(Post.objects, CustomQuerySet) + assert not Post.objects.not_empty() Post().save() - self.assertTrue(Post.objects.not_empty()) + assert Post.objects.not_empty() Post.drop_collection() @@ -3900,11 +3899,11 @@ class TestQueryset(unittest.TestCase): pass Post.drop_collection() - self.assertIsInstance(Post.objects, CustomQuerySet) - self.assertFalse(Post.objects.not_empty()) + assert isinstance(Post.objects, CustomQuerySet) + assert not Post.objects.not_empty() Post().save() - self.assertTrue(Post.objects.not_empty()) + assert Post.objects.not_empty() Post.drop_collection() @@ -3917,13 +3916,9 @@ class TestQueryset(unittest.TestCase): for i in range(10): Post(title="Post %s" % i).save() - self.assertEqual( - 5, Post.objects.limit(5).skip(5).count(with_limit_and_skip=True) - ) + assert 5 == Post.objects.limit(5).skip(5).count(with_limit_and_skip=True) - self.assertEqual( - 10, Post.objects.limit(5).skip(5).count(with_limit_and_skip=False) - ) + assert 10 == Post.objects.limit(5).skip(5).count(with_limit_and_skip=False) def test_count_and_none(self): """Test count works with None()""" @@ -3935,8 +3930,8 @@ class TestQueryset(unittest.TestCase): for i in range(0, 10): MyDoc().save() - self.assertEqual(MyDoc.objects.count(), 10) - self.assertEqual(MyDoc.objects.none().count(), 0) + assert MyDoc.objects.count() == 10 + assert MyDoc.objects.none().count() == 0 def test_count_list_embedded(self): class B(EmbeddedDocument): @@ -3945,7 +3940,7 @@ class TestQueryset(unittest.TestCase): class A(Document): b = ListField(EmbeddedDocumentField(B)) - self.assertEqual(A.objects(b=[{"c": "c"}]).count(), 0) + assert A.objects(b=[{"c": "c"}]).count() == 0 def test_call_after_limits_set(self): """Ensure that re-filtering after slicing works @@ -3960,7 +3955,7 @@ class TestQueryset(unittest.TestCase): Post(title="Post 2").save() posts = Post.objects.all()[0:1] - self.assertEqual(len(list(posts())), 1) + assert len(list(posts())) == 1 Post.drop_collection() @@ -3976,9 +3971,9 @@ class TestQueryset(unittest.TestCase): n2 = Number.objects.create(n=2) n1 = Number.objects.create(n=1) - self.assertEqual(list(Number.objects), [n2, n1]) - self.assertEqual(list(Number.objects.order_by("n")), [n1, n2]) - self.assertEqual(list(Number.objects.order_by("n").filter()), [n1, n2]) + assert list(Number.objects) == [n2, n1] + assert list(Number.objects.order_by("n")) == [n1, n2] + assert list(Number.objects.order_by("n").filter()) == [n1, n2] Number.drop_collection() @@ -3997,18 +3992,18 @@ class TestQueryset(unittest.TestCase): test = Number.objects test2 = test.clone() - self.assertNotEqual(test, test2) - self.assertEqual(test.count(), test2.count()) + assert test != test2 + assert test.count() == test2.count() test = test.filter(n__gt=11) test2 = test.clone() - self.assertNotEqual(test, test2) - self.assertEqual(test.count(), test2.count()) + assert test != test2 + assert test.count() == test2.count() test = test.limit(10) test2 = test.clone() - self.assertNotEqual(test, test2) - self.assertEqual(test.count(), test2.count()) + assert test != test2 + assert test.count() == test2.count() Number.drop_collection() @@ -4028,7 +4023,7 @@ class TestQueryset(unittest.TestCase): t.switch_db("test2") t.save() - self.assertEqual(len(Number2.objects.using("test2")), 9) + assert len(Number2.objects.using("test2")) == 9 def test_unset_reference(self): class Comment(Document): @@ -4043,10 +4038,10 @@ class TestQueryset(unittest.TestCase): comment = Comment.objects.create(text="test") post = Post.objects.create(comment=comment) - self.assertEqual(post.comment, comment) + assert post.comment == comment Post.objects.update(unset__comment=1) post.reload() - self.assertEqual(post.comment, None) + assert post.comment == None Comment.drop_collection() Post.drop_collection() @@ -4060,8 +4055,8 @@ class TestQueryset(unittest.TestCase): n2 = Number.objects.create(n=2) n1 = Number.objects.create(n=1) - self.assertEqual(list(Number.objects), [n2, n1]) - self.assertEqual(list(Number.objects.order_by("n")), [n1, n2]) + assert list(Number.objects) == [n2, n1] + assert list(Number.objects.order_by("n")) == [n1, n2] Number.drop_collection() @@ -4079,10 +4074,10 @@ class TestQueryset(unittest.TestCase): Number(n=3).save() numbers = [n.n for n in Number.objects.order_by("-n")] - self.assertEqual([3, 2, 1], numbers) + assert [3, 2, 1] == numbers numbers = [n.n for n in Number.objects.order_by("+n")] - self.assertEqual([1, 2, 3], numbers) + assert [1, 2, 3] == numbers Number.drop_collection() def test_ensure_index(self): @@ -4100,7 +4095,7 @@ class TestQueryset(unittest.TestCase): (value["key"], value.get("unique", False), value.get("sparse", False)) for key, value in iteritems(info) ] - self.assertIn(([("_cls", 1), ("message", 1)], False, False), info) + assert ([("_cls", 1), ("message", 1)], False, False) in info def test_where(self): """Ensure that where clauses work. @@ -4120,30 +4115,30 @@ class TestQueryset(unittest.TestCase): c.save() query = IntPair.objects.where("this[~fielda] >= this[~fieldb]") - self.assertEqual('this["fielda"] >= this["fieldb"]', query._where_clause) + assert 'this["fielda"] >= this["fieldb"]' == query._where_clause results = list(query) - self.assertEqual(2, len(results)) - self.assertIn(a, results) - self.assertIn(c, results) + assert 2 == len(results) + assert a in results + assert c in results query = IntPair.objects.where("this[~fielda] == this[~fieldb]") results = list(query) - self.assertEqual(1, len(results)) - self.assertIn(a, results) + assert 1 == len(results) + assert a in results query = IntPair.objects.where( "function() { return this[~fielda] >= this[~fieldb] }" ) - self.assertEqual( - 'function() { return this["fielda"] >= this["fieldb"] }', - query._where_clause, + assert ( + 'function() { return this["fielda"] >= this["fieldb"] }' + == query._where_clause ) results = list(query) - self.assertEqual(2, len(results)) - self.assertIn(a, results) - self.assertIn(c, results) + assert 2 == len(results) + assert a in results + assert c in results - with self.assertRaises(TypeError): + with pytest.raises(TypeError): list(IntPair.objects.where(fielda__gte=3)) def test_scalar(self): @@ -4165,13 +4160,13 @@ class TestQueryset(unittest.TestCase): # set of users (Pretend this has additional filtering.) user_orgs = set(User.objects.scalar("organization")) orgs = Organization.objects(id__in=user_orgs).scalar("name") - self.assertEqual(list(orgs), ["White House"]) + assert list(orgs) == ["White House"] # Efficient for generating listings, too. orgs = Organization.objects.scalar("name").in_bulk(list(user_orgs)) user_map = User.objects.scalar("name", "organization") user_listing = [(user, orgs[org]) for user, org in user_map] - self.assertEqual([("Bob Dole", "White House")], user_listing) + assert [("Bob Dole", "White House")] == user_listing def test_scalar_simple(self): class TestDoc(Document): @@ -4186,10 +4181,10 @@ class TestQueryset(unittest.TestCase): plist = list(TestDoc.objects.scalar("x", "y")) - self.assertEqual(len(plist), 3) - self.assertEqual(plist[0], (10, True)) - self.assertEqual(plist[1], (20, False)) - self.assertEqual(plist[2], (30, True)) + assert len(plist) == 3 + assert plist[0] == (10, True) + assert plist[1] == (20, False) + assert plist[2] == (30, True) class UserDoc(Document): name = StringField() @@ -4204,14 +4199,16 @@ class TestQueryset(unittest.TestCase): ulist = list(UserDoc.objects.scalar("name", "age")) - self.assertEqual( - ulist, - [(u"Wilson Jr", 19), (u"Wilson", 43), (u"Eliana", 37), (u"Tayza", 15)], - ) + assert ulist == [ + (u"Wilson Jr", 19), + (u"Wilson", 43), + (u"Eliana", 37), + (u"Tayza", 15), + ] ulist = list(UserDoc.objects.scalar("name").order_by("age")) - self.assertEqual(ulist, [(u"Tayza"), (u"Wilson Jr"), (u"Eliana"), (u"Wilson")]) + assert ulist == [(u"Tayza"), (u"Wilson Jr"), (u"Eliana"), (u"Wilson")] def test_scalar_embedded(self): class Profile(EmbeddedDocument): @@ -4248,25 +4245,21 @@ class TestQueryset(unittest.TestCase): locale=Locale(city="Brasilia", country="Brazil"), ).save() - self.assertEqual( - list(Person.objects.order_by("profile__age").scalar("profile__name")), - [u"Wilson Jr", u"Gabriel Falcao", u"Lincoln de souza", u"Walter cruz"], - ) + assert list( + Person.objects.order_by("profile__age").scalar("profile__name") + ) == [u"Wilson Jr", u"Gabriel Falcao", u"Lincoln de souza", u"Walter cruz"] ulist = list( Person.objects.order_by("locale.city").scalar( "profile__name", "profile__age", "locale__city" ) ) - self.assertEqual( - ulist, - [ - (u"Lincoln de souza", 28, u"Belo Horizonte"), - (u"Walter cruz", 30, u"Brasilia"), - (u"Wilson Jr", 19, u"Corumba-GO"), - (u"Gabriel Falcao", 23, u"New York"), - ], - ) + assert ulist == [ + (u"Lincoln de souza", 28, u"Belo Horizonte"), + (u"Walter cruz", 30, u"Brasilia"), + (u"Wilson Jr", 19, u"Corumba-GO"), + (u"Gabriel Falcao", 23, u"New York"), + ] def test_scalar_decimal(self): from decimal import Decimal @@ -4279,7 +4272,7 @@ class TestQueryset(unittest.TestCase): Person(name="Wilson Jr", rating=Decimal("1.0")).save() ulist = list(Person.objects.scalar("name", "rating")) - self.assertEqual(ulist, [(u"Wilson Jr", Decimal("1.0"))]) + assert ulist == [(u"Wilson Jr", Decimal("1.0"))] def test_scalar_reference_field(self): class State(Document): @@ -4298,7 +4291,7 @@ class TestQueryset(unittest.TestCase): Person(name="Wilson JR", state=s1).save() plist = list(Person.objects.scalar("name", "state")) - self.assertEqual(plist, [(u"Wilson JR", s1)]) + assert plist == [(u"Wilson JR", s1)] def test_scalar_generic_reference_field(self): class State(Document): @@ -4317,7 +4310,7 @@ class TestQueryset(unittest.TestCase): Person(name="Wilson JR", state=s1).save() plist = list(Person.objects.scalar("name", "state")) - self.assertEqual(plist, [(u"Wilson JR", s1)]) + assert plist == [(u"Wilson JR", s1)] def test_generic_reference_field_with_only_and_as_pymongo(self): class TestPerson(Document): @@ -4342,18 +4335,18 @@ class TestQueryset(unittest.TestCase): .no_dereference() .first() ) - self.assertEqual(activity[0], a1.pk) - self.assertEqual(activity[1]["_ref"], DBRef("test_person", person.pk)) + assert activity[0] == a1.pk + assert activity[1]["_ref"] == DBRef("test_person", person.pk) activity = TestActivity.objects(owner=person).only("id", "owner")[0] - self.assertEqual(activity.pk, a1.pk) - self.assertEqual(activity.owner, person) + assert activity.pk == a1.pk + assert activity.owner == person activity = ( TestActivity.objects(owner=person).only("id", "owner").as_pymongo().first() ) - self.assertEqual(activity["_id"], a1.pk) - self.assertTrue(activity["owner"]["_ref"], DBRef("test_person", person.pk)) + assert activity["_id"] == a1.pk + assert activity["owner"]["_ref"], DBRef("test_person", person.pk) def test_scalar_db_field(self): class TestDoc(Document): @@ -4367,10 +4360,10 @@ class TestQueryset(unittest.TestCase): TestDoc(x=30, y=True).save() plist = list(TestDoc.objects.scalar("x", "y")) - self.assertEqual(len(plist), 3) - self.assertEqual(plist[0], (10, True)) - self.assertEqual(plist[1], (20, False)) - self.assertEqual(plist[2], (30, True)) + assert len(plist) == 3 + assert plist[0] == (10, True) + assert plist[1] == (20, False) + assert plist[2] == (30, True) def test_scalar_primary_key(self): class SettingValue(Document): @@ -4382,7 +4375,7 @@ class TestQueryset(unittest.TestCase): s.save() val = SettingValue.objects.scalar("key", "value") - self.assertEqual(list(val), [("test", "test value")]) + assert list(val) == [("test", "test value")] def test_scalar_cursor_behaviour(self): """Ensure that a query returns a valid set of results. @@ -4394,90 +4387,86 @@ class TestQueryset(unittest.TestCase): # Find all people in the collection people = self.Person.objects.scalar("name") - self.assertEqual(people.count(), 2) + assert people.count() == 2 results = list(people) - self.assertEqual(results[0], "User A") - self.assertEqual(results[1], "User B") + assert results[0] == "User A" + assert results[1] == "User B" # Use a query to filter the people found to just person1 people = self.Person.objects(age=20).scalar("name") - self.assertEqual(people.count(), 1) + assert people.count() == 1 person = people.next() - self.assertEqual(person, "User A") + assert person == "User A" # Test limit people = list(self.Person.objects.limit(1).scalar("name")) - self.assertEqual(len(people), 1) - self.assertEqual(people[0], "User A") + assert len(people) == 1 + assert people[0] == "User A" # Test skip people = list(self.Person.objects.skip(1).scalar("name")) - self.assertEqual(len(people), 1) - self.assertEqual(people[0], "User B") + assert len(people) == 1 + assert people[0] == "User B" person3 = self.Person(name="User C", age=40) person3.save() # Test slice limit people = list(self.Person.objects[:2].scalar("name")) - self.assertEqual(len(people), 2) - self.assertEqual(people[0], "User A") - self.assertEqual(people[1], "User B") + assert len(people) == 2 + assert people[0] == "User A" + assert people[1] == "User B" # Test slice skip people = list(self.Person.objects[1:].scalar("name")) - self.assertEqual(len(people), 2) - self.assertEqual(people[0], "User B") - self.assertEqual(people[1], "User C") + assert len(people) == 2 + assert people[0] == "User B" + assert people[1] == "User C" # Test slice limit and skip people = list(self.Person.objects[1:2].scalar("name")) - self.assertEqual(len(people), 1) - self.assertEqual(people[0], "User B") + assert len(people) == 1 + assert people[0] == "User B" people = list(self.Person.objects[1:1].scalar("name")) - self.assertEqual(len(people), 0) + assert len(people) == 0 # Test slice out of range people = list(self.Person.objects.scalar("name")[80000:80001]) - self.assertEqual(len(people), 0) + assert len(people) == 0 # Test larger slice __repr__ self.Person.objects.delete() for i in range(55): self.Person(name="A%s" % i, age=i).save() - self.assertEqual(self.Person.objects.scalar("name").count(), 55) - self.assertEqual( - "A0", "%s" % self.Person.objects.order_by("name").scalar("name").first() - ) - self.assertEqual( - "A0", "%s" % self.Person.objects.scalar("name").order_by("name")[0] + assert self.Person.objects.scalar("name").count() == 55 + assert ( + "A0" == "%s" % self.Person.objects.order_by("name").scalar("name").first() ) + assert "A0" == "%s" % self.Person.objects.scalar("name").order_by("name")[0] if six.PY3: - self.assertEqual( - "['A1', 'A2']", - "%s" % self.Person.objects.order_by("age").scalar("name")[1:3], + assert ( + "['A1', 'A2']" + == "%s" % self.Person.objects.order_by("age").scalar("name")[1:3] ) - self.assertEqual( - "['A51', 'A52']", - "%s" % self.Person.objects.order_by("age").scalar("name")[51:53], + assert ( + "['A51', 'A52']" + == "%s" % self.Person.objects.order_by("age").scalar("name")[51:53] ) else: - self.assertEqual( - "[u'A1', u'A2']", - "%s" % self.Person.objects.order_by("age").scalar("name")[1:3], + assert ( + "[u'A1', u'A2']" + == "%s" % self.Person.objects.order_by("age").scalar("name")[1:3] ) - self.assertEqual( - "[u'A51', u'A52']", - "%s" % self.Person.objects.order_by("age").scalar("name")[51:53], + assert ( + "[u'A51', u'A52']" + == "%s" % self.Person.objects.order_by("age").scalar("name")[51:53] ) # with_id and in_bulk person = self.Person.objects.order_by("name").first() - self.assertEqual( - "A0", "%s" % self.Person.objects.scalar("name").with_id(person.id) - ) + assert "A0" == "%s" % self.Person.objects.scalar("name").with_id(person.id) pks = self.Person.objects.order_by("age").scalar("pk")[1:3] names = self.Person.objects.scalar("name").in_bulk(list(pks)).values() @@ -4485,7 +4474,7 @@ class TestQueryset(unittest.TestCase): expected = "['A1', 'A2']" else: expected = "[u'A1', u'A2']" - self.assertEqual(expected, "%s" % sorted(names)) + assert expected == "%s" % sorted(names) def test_elem_match(self): class Foo(EmbeddedDocument): @@ -4525,29 +4514,29 @@ class TestQueryset(unittest.TestCase): b3.save() ak = list(Bar.objects(foo__match={"shape": "square", "color": "purple"})) - self.assertEqual([b1], ak) + assert [b1] == ak ak = list(Bar.objects(foo__elemMatch={"shape": "square", "color": "purple"})) - self.assertEqual([b1], ak) + assert [b1] == ak ak = list(Bar.objects(foo__match=Foo(shape="square", color="purple"))) - self.assertEqual([b1], ak) + assert [b1] == ak ak = list( Bar.objects(foo__elemMatch={"shape": "square", "color__exists": True}) ) - self.assertEqual([b1, b2], ak) + assert [b1, b2] == ak ak = list(Bar.objects(foo__match={"shape": "square", "color__exists": True})) - self.assertEqual([b1, b2], ak) + assert [b1, b2] == ak ak = list( Bar.objects(foo__elemMatch={"shape": "square", "color__exists": False}) ) - self.assertEqual([b3], ak) + assert [b3] == ak ak = list(Bar.objects(foo__match={"shape": "square", "color__exists": False})) - self.assertEqual([b3], ak) + assert [b3] == ak def test_upsert_includes_cls(self): """Upserts should include _cls information for inheritable classes @@ -4558,7 +4547,7 @@ class TestQueryset(unittest.TestCase): Test.drop_collection() Test.objects(test="foo").update_one(upsert=True, set__test="foo") - self.assertNotIn("_cls", Test._collection.find_one()) + assert "_cls" not in Test._collection.find_one() class Test(Document): meta = {"allow_inheritance": True} @@ -4567,15 +4556,15 @@ class TestQueryset(unittest.TestCase): Test.drop_collection() Test.objects(test="foo").update_one(upsert=True, set__test="foo") - self.assertIn("_cls", Test._collection.find_one()) + assert "_cls" in Test._collection.find_one() def test_update_upsert_looks_like_a_digit(self): class MyDoc(DynamicDocument): pass MyDoc.drop_collection() - self.assertEqual(1, MyDoc.objects.update_one(upsert=True, inc__47=1)) - self.assertEqual(MyDoc.objects.get()["47"], 1) + assert 1 == MyDoc.objects.update_one(upsert=True, inc__47=1) + assert MyDoc.objects.get()["47"] == 1 def test_dictfield_key_looks_like_a_digit(self): """Only should work with DictField even if they have numeric keys.""" @@ -4586,7 +4575,7 @@ class TestQueryset(unittest.TestCase): MyDoc.drop_collection() doc = MyDoc(test={"47": 1}) doc.save() - self.assertEqual(MyDoc.objects.only("test__47").get().test["47"], 1) + assert MyDoc.objects.only("test__47").get().test["47"] == 1 def test_clear_cls_query(self): class Parent(Document): @@ -4599,32 +4588,28 @@ class TestQueryset(unittest.TestCase): Parent.drop_collection() # Default query includes the "_cls" check. - self.assertEqual( - Parent.objects._query, {"_cls": {"$in": ("Parent", "Parent.Child")}} - ) + assert Parent.objects._query == {"_cls": {"$in": ("Parent", "Parent.Child")}} # Clearing the "_cls" query should work. - self.assertEqual(Parent.objects.clear_cls_query()._query, {}) + assert Parent.objects.clear_cls_query()._query == {} # Clearing the "_cls" query should not persist across queryset instances. - self.assertEqual( - Parent.objects._query, {"_cls": {"$in": ("Parent", "Parent.Child")}} - ) + assert Parent.objects._query == {"_cls": {"$in": ("Parent", "Parent.Child")}} # The rest of the query should not be cleared. - self.assertEqual( - Parent.objects.filter(name="xyz").clear_cls_query()._query, {"name": "xyz"} - ) + assert Parent.objects.filter(name="xyz").clear_cls_query()._query == { + "name": "xyz" + } Parent.objects.create(name="foo") Child.objects.create(name="bar", age=1) - self.assertEqual(Parent.objects.clear_cls_query().count(), 2) - self.assertEqual(Parent.objects.count(), 2) - self.assertEqual(Child.objects().count(), 1) + assert Parent.objects.clear_cls_query().count() == 2 + assert Parent.objects.count() == 2 + assert Child.objects().count() == 1 # XXX This isn't really how you'd want to use `clear_cls_query()`, but # it's a decent test to validate its behavior nonetheless. - self.assertEqual(Child.objects.clear_cls_query().count(), 2) + assert Child.objects.clear_cls_query().count() == 2 def test_read_preference(self): class Bar(Document): @@ -4636,20 +4621,21 @@ class TestQueryset(unittest.TestCase): bar = Bar.objects.create(txt="xyz") bars = list(Bar.objects.read_preference(ReadPreference.PRIMARY)) - self.assertEqual(bars, [bar]) + assert bars == [bar] bars = Bar.objects.read_preference(ReadPreference.SECONDARY_PREFERRED) - self.assertEqual(bars._read_preference, ReadPreference.SECONDARY_PREFERRED) - self.assertEqual( - bars._cursor.collection.read_preference, ReadPreference.SECONDARY_PREFERRED + assert bars._read_preference == ReadPreference.SECONDARY_PREFERRED + assert ( + bars._cursor.collection.read_preference == ReadPreference.SECONDARY_PREFERRED ) # Make sure that `.read_preference(...)` does accept string values. - self.assertRaises(TypeError, Bar.objects.read_preference, "Primary") + with pytest.raises(TypeError): + Bar.objects.read_preference("Primary") def assert_read_pref(qs, expected_read_pref): - self.assertEqual(qs._read_preference, expected_read_pref) - self.assertEqual(qs._cursor.collection.read_preference, expected_read_pref) + assert qs._read_preference == expected_read_pref + assert qs._cursor.collection.read_preference == expected_read_pref # Make sure read preference is respected after a `.skip(...)`. bars = Bar.objects.skip(1).read_preference(ReadPreference.SECONDARY_PREFERRED) @@ -4681,9 +4667,9 @@ class TestQueryset(unittest.TestCase): bars = Bar.objects.read_preference( ReadPreference.SECONDARY_PREFERRED ).aggregate() - self.assertEqual( - bars._CommandCursor__collection.read_preference, - ReadPreference.SECONDARY_PREFERRED, + assert ( + bars._CommandCursor__collection.read_preference + == ReadPreference.SECONDARY_PREFERRED ) def test_json_simple(self): @@ -4702,7 +4688,7 @@ class TestQueryset(unittest.TestCase): json_data = Doc.objects.to_json(sort_keys=True, separators=(",", ":")) doc_objects = list(Doc.objects) - self.assertEqual(doc_objects, Doc.objects.from_json(json_data)) + assert doc_objects == Doc.objects.from_json(json_data) def test_json_complex(self): class EmbeddedDoc(EmbeddedDocument): @@ -4748,7 +4734,7 @@ class TestQueryset(unittest.TestCase): json_data = Doc.objects.to_json() doc_objects = list(Doc.objects) - self.assertEqual(doc_objects, Doc.objects.from_json(json_data)) + assert doc_objects == Doc.objects.from_json(json_data) def test_as_pymongo(self): class LastLogin(EmbeddedDocument): @@ -4774,36 +4760,33 @@ class TestQueryset(unittest.TestCase): ) results = User.objects.as_pymongo() - self.assertEqual(set(results[0].keys()), set(["_id", "name", "age", "price"])) - self.assertEqual( - set(results[1].keys()), set(["_id", "name", "age", "price", "last_login"]) + assert set(results[0].keys()) == set(["_id", "name", "age", "price"]) + assert set(results[1].keys()) == set( + ["_id", "name", "age", "price", "last_login"] ) results = User.objects.only("id", "name").as_pymongo() - self.assertEqual(set(results[0].keys()), set(["_id", "name"])) + assert set(results[0].keys()) == set(["_id", "name"]) users = User.objects.only("name", "price").as_pymongo() results = list(users) - self.assertIsInstance(results[0], dict) - self.assertIsInstance(results[1], dict) - self.assertEqual(results[0]["name"], "Bob Dole") - self.assertEqual(results[0]["price"], 1.11) - self.assertEqual(results[1]["name"], "Barak Obama") - self.assertEqual(results[1]["price"], 2.22) + assert isinstance(results[0], dict) + assert isinstance(results[1], dict) + assert results[0]["name"] == "Bob Dole" + assert results[0]["price"] == 1.11 + assert results[1]["name"] == "Barak Obama" + assert results[1]["price"] == 2.22 users = User.objects.only("name", "last_login").as_pymongo() results = list(users) - self.assertIsInstance(results[0], dict) - self.assertIsInstance(results[1], dict) - self.assertEqual(results[0], {"_id": "Bob", "name": "Bob Dole"}) - self.assertEqual( - results[1], - { - "_id": "Barak", - "name": "Barak Obama", - "last_login": {"location": "White House", "ip": "104.107.108.116"}, - }, - ) + assert isinstance(results[0], dict) + assert isinstance(results[1], dict) + assert results[0] == {"_id": "Bob", "name": "Bob Dole"} + assert results[1] == { + "_id": "Barak", + "name": "Barak Obama", + "last_login": {"location": "White House", "ip": "104.107.108.116"}, + } def test_as_pymongo_returns_cls_attribute_when_using_inheritance(self): class User(Document): @@ -4814,7 +4797,7 @@ class TestQueryset(unittest.TestCase): user = User(name="Bob Dole").save() result = User.objects.as_pymongo().first() - self.assertEqual(result, {"_cls": "User", "_id": user.id, "name": "Bob Dole"}) + assert result == {"_cls": "User", "_id": user.id, "name": "Bob Dole"} def test_as_pymongo_json_limit_fields(self): class User(Document): @@ -4830,30 +4813,30 @@ class TestQueryset(unittest.TestCase): serialized_user = User.objects.exclude( "password_salt", "password_hash" ).as_pymongo()[0] - self.assertEqual({"_id", "email"}, set(serialized_user.keys())) + assert {"_id", "email"} == set(serialized_user.keys()) serialized_user = User.objects.exclude( "id", "password_salt", "password_hash" ).to_json() - self.assertEqual('[{"email": "ross@example.com"}]', serialized_user) + assert '[{"email": "ross@example.com"}]' == serialized_user serialized_user = User.objects.only("email").as_pymongo()[0] - self.assertEqual({"_id", "email"}, set(serialized_user.keys())) + assert {"_id", "email"} == set(serialized_user.keys()) serialized_user = ( User.objects.exclude("password_salt").only("email").as_pymongo()[0] ) - self.assertEqual({"_id", "email"}, set(serialized_user.keys())) + assert {"_id", "email"} == set(serialized_user.keys()) serialized_user = ( User.objects.exclude("password_salt", "id").only("email").as_pymongo()[0] ) - self.assertEqual({"email"}, set(serialized_user.keys())) + assert {"email"} == set(serialized_user.keys()) serialized_user = ( User.objects.exclude("password_salt", "id").only("email").to_json() ) - self.assertEqual('[{"email": "ross@example.com"}]', serialized_user) + assert '[{"email": "ross@example.com"}]' == serialized_user def test_only_after_count(self): """Test that only() works after count()""" @@ -4869,13 +4852,13 @@ class TestQueryset(unittest.TestCase): user_queryset = User.objects(age=50) result = user_queryset.only("name", "age").as_pymongo().first() - self.assertEqual(result, {"_id": user.id, "name": "User", "age": 50}) + assert result == {"_id": user.id, "name": "User", "age": 50} result = user_queryset.count() - self.assertEqual(result, 1) + assert result == 1 result = user_queryset.only("name", "age").as_pymongo().first() - self.assertEqual(result, {"_id": user.id, "name": "User", "age": 50}) + assert result == {"_id": user.id, "name": "User", "age": 50} def test_no_dereference(self): class Organization(Document): @@ -4894,12 +4877,12 @@ class TestQueryset(unittest.TestCase): qs = User.objects() qs_user = qs.first() - self.assertIsInstance(qs.first().organization, Organization) + assert isinstance(qs.first().organization, Organization) - self.assertIsInstance(qs.no_dereference().first().organization, DBRef) + assert isinstance(qs.no_dereference().first().organization, DBRef) - self.assertIsInstance(qs_user.organization, Organization) - self.assertIsInstance(qs.first().organization, Organization) + assert isinstance(qs_user.organization, Organization) + assert isinstance(qs.first().organization, Organization) def test_no_dereference_internals(self): # Test the internals on which queryset.no_dereference relies on @@ -4913,24 +4896,24 @@ class TestQueryset(unittest.TestCase): Organization.drop_collection() cls_organization_field = User.organization - self.assertTrue(cls_organization_field._auto_dereference, True) # default + assert cls_organization_field._auto_dereference, True # default org = Organization(name="whatever").save() User(organization=org).save() qs_no_deref = User.objects().no_dereference() user_no_deref = qs_no_deref.first() - self.assertFalse(qs_no_deref._auto_dereference) + assert not qs_no_deref._auto_dereference # Make sure the instance field is different from the class field instance_org_field = user_no_deref._fields["organization"] - self.assertIsNot(instance_org_field, cls_organization_field) - self.assertFalse(instance_org_field._auto_dereference) + assert instance_org_field is not cls_organization_field + assert not instance_org_field._auto_dereference - self.assertIsInstance(user_no_deref.organization, DBRef) - self.assertTrue( - cls_organization_field._auto_dereference, True - ) # Make sure the class Field wasn't altered + assert isinstance(user_no_deref.organization, DBRef) + assert ( + cls_organization_field._auto_dereference + ), True # Make sure the class Field wasn't altered def test_no_dereference_no_side_effect_on_existing_instance(self): # Relates to issue #1677 - ensures no regression of the bug @@ -4956,13 +4939,13 @@ class TestQueryset(unittest.TestCase): # ReferenceField no_derf_org = user_no_deref.organization # was triggering the bug - self.assertIsInstance(no_derf_org, DBRef) - self.assertIsInstance(user.organization, Organization) + assert isinstance(no_derf_org, DBRef) + assert isinstance(user.organization, Organization) # GenericReferenceField no_derf_org_gen = user_no_deref.organization_gen - self.assertIsInstance(no_derf_org_gen, dict) - self.assertIsInstance(user.organization_gen, Organization) + assert isinstance(no_derf_org_gen, dict) + assert isinstance(user.organization_gen, Organization) def test_no_dereference_embedded_doc(self): class User(Document): @@ -4994,13 +4977,13 @@ class TestQueryset(unittest.TestCase): org = Organization.objects().no_dereference().first() - self.assertNotEqual(id(org._fields["admins"]), id(Organization.admins)) - self.assertFalse(org._fields["admins"]._auto_dereference) + assert id(org._fields["admins"]) != id(Organization.admins) + assert not org._fields["admins"]._auto_dereference admin = org.admins[0] - self.assertIsInstance(admin, DBRef) - self.assertIsInstance(org.member.user, DBRef) - self.assertIsInstance(org.members[0].user, DBRef) + assert isinstance(admin, DBRef) + assert isinstance(org.member.user, DBRef) + assert isinstance(org.members[0].user, DBRef) def test_cached_queryset(self): class Person(Document): @@ -5011,11 +4994,11 @@ class TestQueryset(unittest.TestCase): Person(name="No: %s" % i).save() with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 people = Person.objects [x for x in people] - self.assertEqual(100, len(people._result_cache)) + assert 100 == len(people._result_cache) import platform @@ -5023,15 +5006,15 @@ class TestQueryset(unittest.TestCase): # PyPy evaluates __len__ when iterating with list comprehensions while CPython does not. # This may be a bug in PyPy (PyPy/#1802) but it does not affect # the behavior of MongoEngine. - self.assertEqual(None, people._len) - self.assertEqual(q, 1) + assert None == people._len + assert q == 1 list(people) - self.assertEqual(100, people._len) # Caused by list calling len - self.assertEqual(q, 1) + assert 100 == people._len # Caused by list calling len + assert q == 1 people.count(with_limit_and_skip=True) # count is cached - self.assertEqual(q, 1) + assert q == 1 def test_no_cached_queryset(self): class Person(Document): @@ -5042,17 +5025,17 @@ class TestQueryset(unittest.TestCase): Person(name="No: %s" % i).save() with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 people = Person.objects.no_cache() [x for x in people] - self.assertEqual(q, 1) + assert q == 1 list(people) - self.assertEqual(q, 2) + assert q == 2 people.count() - self.assertEqual(q, 3) + assert q == 3 def test_no_cached_queryset__repr__(self): class Person(Document): @@ -5060,7 +5043,7 @@ class TestQueryset(unittest.TestCase): Person.drop_collection() qs = Person.objects.no_cache() - self.assertEqual(repr(qs), "[]") + assert repr(qs) == "[]" def test_no_cached_on_a_cached_queryset_raise_error(self): class Person(Document): @@ -5070,9 +5053,9 @@ class TestQueryset(unittest.TestCase): Person(name="a").save() qs = Person.objects() _ = list(qs) - with self.assertRaises(OperationError) as ctx_err: + with pytest.raises(OperationError) as ctx_err: qs.no_cache() - self.assertEqual("QuerySet already cached", str(ctx_err.exception)) + assert "QuerySet already cached" == str(ctx_err.exception) def test_no_cached_queryset_no_cache_back_to_cache(self): class Person(Document): @@ -5080,11 +5063,11 @@ class TestQueryset(unittest.TestCase): Person.drop_collection() qs = Person.objects() - self.assertIsInstance(qs, QuerySet) + assert isinstance(qs, QuerySet) qs = qs.no_cache() - self.assertIsInstance(qs, QuerySetNoCache) + assert isinstance(qs, QuerySetNoCache) qs = qs.cache() - self.assertIsInstance(qs, QuerySet) + assert isinstance(qs, QuerySet) def test_cache_not_cloned(self): class User(Document): @@ -5099,12 +5082,12 @@ class TestQueryset(unittest.TestCase): User(name="Bob").save() users = User.objects.all().order_by("name") - self.assertEqual("%s" % users, "[, ]") - self.assertEqual(2, len(users._result_cache)) + assert "%s" % users == "[, ]" + assert 2 == len(users._result_cache) users = users.filter(name="Bob") - self.assertEqual("%s" % users, "[]") - self.assertEqual(1, len(users._result_cache)) + assert "%s" % users == "[]" + assert 1 == len(users._result_cache) def test_no_cache(self): """Ensure you can add meta data to file""" @@ -5122,23 +5105,23 @@ class TestQueryset(unittest.TestCase): docs = Noddy.objects.no_cache() counter = len([1 for i in docs]) - self.assertEqual(counter, 100) + assert counter == 100 - self.assertEqual(len(list(docs)), 100) + assert len(list(docs)) == 100 # Can't directly get a length of a no-cache queryset. - with self.assertRaises(TypeError): + with pytest.raises(TypeError): len(docs) # Another iteration over the queryset should result in another db op. with query_counter() as q: list(docs) - self.assertEqual(q, 1) + assert q == 1 # ... and another one to double-check. with query_counter() as q: list(docs) - self.assertEqual(q, 1) + assert q == 1 def test_nested_queryset_iterator(self): # Try iterating the same queryset twice, nested. @@ -5161,32 +5144,32 @@ class TestQueryset(unittest.TestCase): inner_total_count = 0 with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 - self.assertEqual(users.count(with_limit_and_skip=True), 7) + assert users.count(with_limit_and_skip=True) == 7 for i, outer_user in enumerate(users): - self.assertEqual(outer_user.name, names[i]) + assert outer_user.name == names[i] outer_count += 1 inner_count = 0 # Calling len might disrupt the inner loop if there are bugs - self.assertEqual(users.count(with_limit_and_skip=True), 7) + assert users.count(with_limit_and_skip=True) == 7 for j, inner_user in enumerate(users): - self.assertEqual(inner_user.name, names[j]) + assert inner_user.name == names[j] inner_count += 1 inner_total_count += 1 # inner loop should always be executed seven times - self.assertEqual(inner_count, 7) + assert inner_count == 7 # outer loop should be executed seven times total - self.assertEqual(outer_count, 7) + assert outer_count == 7 # inner loop should be executed fourtynine times total - self.assertEqual(inner_total_count, 7 * 7) + assert inner_total_count == 7 * 7 - self.assertEqual(q, 2) + assert q == 2 def test_no_sub_classes(self): class A(Document): @@ -5209,23 +5192,23 @@ class TestQueryset(unittest.TestCase): B(x=30, y=50).save() C(x=40, y=60).save() - self.assertEqual(A.objects.no_sub_classes().count(), 2) - self.assertEqual(A.objects.count(), 5) + assert A.objects.no_sub_classes().count() == 2 + assert A.objects.count() == 5 - self.assertEqual(B.objects.no_sub_classes().count(), 2) - self.assertEqual(B.objects.count(), 3) + assert B.objects.no_sub_classes().count() == 2 + assert B.objects.count() == 3 - self.assertEqual(C.objects.no_sub_classes().count(), 1) - self.assertEqual(C.objects.count(), 1) + assert C.objects.no_sub_classes().count() == 1 + assert C.objects.count() == 1 for obj in A.objects.no_sub_classes(): - self.assertEqual(obj.__class__, A) + assert obj.__class__ == A for obj in B.objects.no_sub_classes(): - self.assertEqual(obj.__class__, B) + assert obj.__class__ == B for obj in C.objects.no_sub_classes(): - self.assertEqual(obj.__class__, C) + assert obj.__class__ == C def test_query_generic_embedded_document(self): """Ensure that querying sub field on generic_embedded_field works @@ -5245,10 +5228,10 @@ class TestQueryset(unittest.TestCase): Doc(document=B(b_name="B doc")).save() # Using raw in filter working fine - self.assertEqual(Doc.objects(__raw__={"document.a_name": "A doc"}).count(), 1) - self.assertEqual(Doc.objects(__raw__={"document.b_name": "B doc"}).count(), 1) - self.assertEqual(Doc.objects(document__a_name="A doc").count(), 1) - self.assertEqual(Doc.objects(document__b_name="B doc").count(), 1) + assert Doc.objects(__raw__={"document.a_name": "A doc"}).count() == 1 + assert Doc.objects(__raw__={"document.b_name": "B doc"}).count() == 1 + assert Doc.objects(document__a_name="A doc").count() == 1 + assert Doc.objects(document__b_name="B doc").count() == 1 def test_query_reference_to_custom_pk_doc(self): class A(Document): @@ -5263,9 +5246,9 @@ class TestQueryset(unittest.TestCase): a = A.objects.create(id="custom_id") B.objects.create(a=a) - self.assertEqual(B.objects.count(), 1) - self.assertEqual(B.objects.get(a=a).a, a) - self.assertEqual(B.objects.get(a=a.id).a, a) + assert B.objects.count() == 1 + assert B.objects.get(a=a).a == a + assert B.objects.get(a=a.id).a == a def test_cls_query_in_subclassed_docs(self): class Animal(Document): @@ -5279,21 +5262,18 @@ class TestQueryset(unittest.TestCase): class Cat(Animal): pass - self.assertEqual( - Animal.objects(name="Charlie")._query, - { - "name": "Charlie", - "_cls": {"$in": ("Animal", "Animal.Dog", "Animal.Cat")}, - }, - ) - self.assertEqual( - Dog.objects(name="Charlie")._query, - {"name": "Charlie", "_cls": "Animal.Dog"}, - ) - self.assertEqual( - Cat.objects(name="Charlie")._query, - {"name": "Charlie", "_cls": "Animal.Cat"}, - ) + assert Animal.objects(name="Charlie")._query == { + "name": "Charlie", + "_cls": {"$in": ("Animal", "Animal.Dog", "Animal.Cat")}, + } + assert Dog.objects(name="Charlie")._query == { + "name": "Charlie", + "_cls": "Animal.Dog", + } + assert Cat.objects(name="Charlie")._query == { + "name": "Charlie", + "_cls": "Animal.Cat", + } def test_can_have_field_same_name_as_query_operator(self): class Size(Document): @@ -5308,8 +5288,8 @@ class TestQueryset(unittest.TestCase): instance_size = Size(name="Large").save() Example(size=instance_size).save() - self.assertEqual(Example.objects(size=instance_size).count(), 1) - self.assertEqual(Example.objects(size__in=[instance_size]).count(), 1) + assert Example.objects(size=instance_size).count() == 1 + assert Example.objects(size__in=[instance_size]).count() == 1 def test_cursor_in_an_if_stmt(self): class Test(Document): @@ -5347,12 +5327,12 @@ class TestQueryset(unittest.TestCase): if Person.objects: pass - self.assertEqual(q, 1) + assert q == 1 op = q.db.system.profile.find( {"ns": {"$ne": "%s.system.indexes" % q.db.name}} )[0] - self.assertEqual(op["nreturned"], 1) + assert op["nreturned"] == 1 def test_bool_with_ordering(self): ORDER_BY_KEY, CMD_QUERY_KEY = get_key_compat(self.mongodb_version) @@ -5375,7 +5355,7 @@ class TestQueryset(unittest.TestCase): {"ns": {"$ne": "%s.system.indexes" % q.db.name}} )[0] - self.assertNotIn(ORDER_BY_KEY, op[CMD_QUERY_KEY]) + assert ORDER_BY_KEY not in op[CMD_QUERY_KEY] # Check that normal query uses orderby qs2 = Person.objects.order_by("name") @@ -5388,7 +5368,7 @@ class TestQueryset(unittest.TestCase): {"ns": {"$ne": "%s.system.indexes" % q.db.name}} )[0] - self.assertIn(ORDER_BY_KEY, op[CMD_QUERY_KEY]) + assert ORDER_BY_KEY in op[CMD_QUERY_KEY] def test_bool_with_ordering_from_meta_dict(self): ORDER_BY_KEY, CMD_QUERY_KEY = get_key_compat(self.mongodb_version) @@ -5412,16 +5392,12 @@ class TestQueryset(unittest.TestCase): {"ns": {"$ne": "%s.system.indexes" % q.db.name}} )[0] - self.assertNotIn( - "$orderby", - op[CMD_QUERY_KEY], - "BaseQuerySet must remove orderby from meta in boolen test", - ) + assert ( + "$orderby" not in op[CMD_QUERY_KEY] + ), "BaseQuerySet must remove orderby from meta in boolen test" - self.assertEqual(Person.objects.first().name, "A") - self.assertTrue( - Person.objects._has_data(), "Cursor has data and returned False" - ) + assert Person.objects.first().name == "A" + assert Person.objects._has_data(), "Cursor has data and returned False" def test_queryset_aggregation_framework(self): class Person(Document): @@ -5439,13 +5415,10 @@ class TestQueryset(unittest.TestCase): {"$project": {"name": {"$toUpper": "$name"}}} ) - self.assertEqual( - list(data), - [ - {"_id": p1.pk, "name": "ISABELLA LUANNA"}, - {"_id": p2.pk, "name": "WILSON JUNIOR"}, - ], - ) + assert list(data) == [ + {"_id": p1.pk, "name": "ISABELLA LUANNA"}, + {"_id": p2.pk, "name": "WILSON JUNIOR"}, + ] data = ( Person.objects(age__lte=22) @@ -5453,13 +5426,10 @@ class TestQueryset(unittest.TestCase): .aggregate({"$project": {"name": {"$toUpper": "$name"}}}) ) - self.assertEqual( - list(data), - [ - {"_id": p2.pk, "name": "WILSON JUNIOR"}, - {"_id": p1.pk, "name": "ISABELLA LUANNA"}, - ], - ) + assert list(data) == [ + {"_id": p2.pk, "name": "WILSON JUNIOR"}, + {"_id": p1.pk, "name": "ISABELLA LUANNA"}, + ] data = ( Person.objects(age__gte=17, age__lte=40) @@ -5468,12 +5438,10 @@ class TestQueryset(unittest.TestCase): {"$group": {"_id": None, "total": {"$sum": 1}, "avg": {"$avg": "$age"}}} ) ) - self.assertEqual(list(data), [{"_id": None, "avg": 29, "total": 2}]) + assert list(data) == [{"_id": None, "avg": 29, "total": 2}] data = Person.objects().aggregate({"$match": {"name": "Isabella Luanna"}}) - self.assertEqual( - list(data), [{u"_id": p1.pk, u"age": 16, u"name": u"Isabella Luanna"}] - ) + assert list(data) == [{u"_id": p1.pk, u"age": 16, u"name": u"Isabella Luanna"}] def test_queryset_aggregation_with_skip(self): class Person(Document): @@ -5491,13 +5459,10 @@ class TestQueryset(unittest.TestCase): {"$project": {"name": {"$toUpper": "$name"}}} ) - self.assertEqual( - list(data), - [ - {"_id": p2.pk, "name": "WILSON JUNIOR"}, - {"_id": p3.pk, "name": "SANDRA MARA"}, - ], - ) + assert list(data) == [ + {"_id": p2.pk, "name": "WILSON JUNIOR"}, + {"_id": p3.pk, "name": "SANDRA MARA"}, + ] def test_queryset_aggregation_with_limit(self): class Person(Document): @@ -5515,7 +5480,7 @@ class TestQueryset(unittest.TestCase): {"$project": {"name": {"$toUpper": "$name"}}} ) - self.assertEqual(list(data), [{"_id": p1.pk, "name": "ISABELLA LUANNA"}]) + assert list(data) == [{"_id": p1.pk, "name": "ISABELLA LUANNA"}] def test_queryset_aggregation_with_sort(self): class Person(Document): @@ -5533,14 +5498,11 @@ class TestQueryset(unittest.TestCase): {"$project": {"name": {"$toUpper": "$name"}}} ) - self.assertEqual( - list(data), - [ - {"_id": p1.pk, "name": "ISABELLA LUANNA"}, - {"_id": p3.pk, "name": "SANDRA MARA"}, - {"_id": p2.pk, "name": "WILSON JUNIOR"}, - ], - ) + assert list(data) == [ + {"_id": p1.pk, "name": "ISABELLA LUANNA"}, + {"_id": p3.pk, "name": "SANDRA MARA"}, + {"_id": p2.pk, "name": "WILSON JUNIOR"}, + ] def test_queryset_aggregation_with_skip_with_limit(self): class Person(Document): @@ -5560,7 +5522,7 @@ class TestQueryset(unittest.TestCase): .aggregate({"$project": {"name": {"$toUpper": "$name"}}}) ) - self.assertEqual(list(data), [{"_id": p2.pk, "name": "WILSON JUNIOR"}]) + assert list(data) == [{"_id": p2.pk, "name": "WILSON JUNIOR"}] # Make sure limit/skip chaining order has no impact data2 = ( @@ -5569,7 +5531,7 @@ class TestQueryset(unittest.TestCase): .aggregate({"$project": {"name": {"$toUpper": "$name"}}}) ) - self.assertEqual(data, list(data2)) + assert data == list(data2) def test_queryset_aggregation_with_sort_with_limit(self): class Person(Document): @@ -5589,13 +5551,10 @@ class TestQueryset(unittest.TestCase): .aggregate({"$project": {"name": {"$toUpper": "$name"}}}) ) - self.assertEqual( - list(data), - [ - {"_id": p1.pk, "name": "ISABELLA LUANNA"}, - {"_id": p3.pk, "name": "SANDRA MARA"}, - ], - ) + assert list(data) == [ + {"_id": p1.pk, "name": "ISABELLA LUANNA"}, + {"_id": p3.pk, "name": "SANDRA MARA"}, + ] # Verify adding limit/skip steps works as expected data = ( @@ -5604,7 +5563,7 @@ class TestQueryset(unittest.TestCase): .aggregate({"$project": {"name": {"$toUpper": "$name"}}}, {"$limit": 1}) ) - self.assertEqual(list(data), [{"_id": p1.pk, "name": "ISABELLA LUANNA"}]) + assert list(data) == [{"_id": p1.pk, "name": "ISABELLA LUANNA"}] data = ( Person.objects.order_by("name") @@ -5616,7 +5575,7 @@ class TestQueryset(unittest.TestCase): ) ) - self.assertEqual(list(data), [{"_id": p3.pk, "name": "SANDRA MARA"}]) + assert list(data) == [{"_id": p3.pk, "name": "SANDRA MARA"}] def test_queryset_aggregation_with_sort_with_skip(self): class Person(Document): @@ -5636,7 +5595,7 @@ class TestQueryset(unittest.TestCase): .aggregate({"$project": {"name": {"$toUpper": "$name"}}}) ) - self.assertEqual(list(data), [{"_id": p2.pk, "name": "WILSON JUNIOR"}]) + assert list(data) == [{"_id": p2.pk, "name": "WILSON JUNIOR"}] def test_queryset_aggregation_with_sort_with_skip_with_limit(self): class Person(Document): @@ -5657,30 +5616,29 @@ class TestQueryset(unittest.TestCase): .aggregate({"$project": {"name": {"$toUpper": "$name"}}}) ) - self.assertEqual(list(data), [{"_id": p3.pk, "name": "SANDRA MARA"}]) + assert list(data) == [{"_id": p3.pk, "name": "SANDRA MARA"}] def test_delete_count(self): [self.Person(name="User {0}".format(i), age=i * 10).save() for i in range(1, 4)] - self.assertEqual( - self.Person.objects().delete(), 3 + assert ( + self.Person.objects().delete() == 3 ) # test ordinary QuerySey delete count [self.Person(name="User {0}".format(i), age=i * 10).save() for i in range(1, 4)] - self.assertEqual( - self.Person.objects().skip(1).delete(), 2 + assert ( + self.Person.objects().skip(1).delete() == 2 ) # test Document delete with existing documents self.Person.objects().delete() - self.assertEqual( - self.Person.objects().skip(1).delete(), 0 + assert ( + self.Person.objects().skip(1).delete() == 0 ) # test Document delete without existing documents def test_max_time_ms(self): # 778: max_time_ms can get only int or None as input - self.assertRaises( - TypeError, self.Person.objects(name="name").max_time_ms, "not a number" - ) + with pytest.raises(TypeError): + self.Person.objects(name="name").max_time_ms("not a number") def test_subclass_field_query(self): class Animal(Document): @@ -5698,8 +5656,8 @@ class TestQueryset(unittest.TestCase): Animal(is_mamal=False).save() Cat(is_mamal=True, whiskers_length=5.1).save() ScottishCat(is_mamal=True, folded_ears=True).save() - self.assertEqual(Animal.objects(folded_ears=True).count(), 1) - self.assertEqual(Animal.objects(whiskers_length=5.1).count(), 1) + assert Animal.objects(folded_ears=True).count() == 1 + assert Animal.objects(whiskers_length=5.1).count() == 1 def test_loop_over_invalid_id_does_not_crash(self): class Person(Document): @@ -5709,7 +5667,7 @@ class TestQueryset(unittest.TestCase): Person._get_collection().insert_one({"name": "a", "id": ""}) for p in Person.objects(): - self.assertEqual(p.name, "a") + assert p.name == "a" def test_len_during_iteration(self): """Tests that calling len on a queyset during iteration doesn't @@ -5733,7 +5691,7 @@ class TestQueryset(unittest.TestCase): for i, r in enumerate(records): if i == 58: len(records) - self.assertEqual(i, 249) + assert i == 249 # Assert the same behavior is true even if we didn't pre-populate the # result cache. @@ -5741,7 +5699,7 @@ class TestQueryset(unittest.TestCase): for i, r in enumerate(records): if i == 58: len(records) - self.assertEqual(i, 249) + assert i == 249 def test_iteration_within_iteration(self): """You should be able to reliably iterate over all the documents @@ -5760,8 +5718,8 @@ class TestQueryset(unittest.TestCase): for j, doc2 in enumerate(qs): pass - self.assertEqual(i, 249) - self.assertEqual(j, 249) + assert i == 249 + assert j == 249 def test_in_operator_on_non_iterable(self): """Ensure that using the `__in` operator on a non-iterable raises an @@ -5785,24 +5743,26 @@ class TestQueryset(unittest.TestCase): # Make sure using `__in` with a list works blog_posts = BlogPost.objects(authors__in=[author]) - self.assertEqual(list(blog_posts), [post]) + assert list(blog_posts) == [post] # Using `__in` with a non-iterable should raise a TypeError - self.assertRaises(TypeError, BlogPost.objects(authors__in=author.pk).count) + with pytest.raises(TypeError): + BlogPost.objects(authors__in=author.pk).count() # Using `__in` with a `Document` (which is seemingly iterable but not # in a way we'd expect) should raise a TypeError, too - self.assertRaises(TypeError, BlogPost.objects(authors__in=author).count) + with pytest.raises(TypeError): + BlogPost.objects(authors__in=author).count() def test_create_count(self): self.Person.drop_collection() self.Person.objects.create(name="Foo") self.Person.objects.create(name="Bar") self.Person.objects.create(name="Baz") - self.assertEqual(self.Person.objects.count(with_limit_and_skip=True), 3) + assert self.Person.objects.count(with_limit_and_skip=True) == 3 - self.Person.objects.create(name="Foo_1") - self.assertEqual(self.Person.objects.count(with_limit_and_skip=True), 4) + newPerson = self.Person.objects.create(name="Foo_1") + assert self.Person.objects.count(with_limit_and_skip=True) == 4 def test_no_cursor_timeout(self): qs = self.Person.objects() diff --git a/tests/queryset/test_transform.py b/tests/queryset/test_transform.py index 8207351d..be28c3b8 100644 --- a/tests/queryset/test_transform.py +++ b/tests/queryset/test_transform.py @@ -4,6 +4,7 @@ from bson.son import SON from mongoengine import * from mongoengine.queryset import Q, transform +import pytest class TestTransform(unittest.TestCase): @@ -13,23 +14,16 @@ class TestTransform(unittest.TestCase): def test_transform_query(self): """Ensure that the _transform_query function operates correctly. """ - self.assertEqual( - transform.query(name="test", age=30), {"name": "test", "age": 30} - ) - self.assertEqual(transform.query(age__lt=30), {"age": {"$lt": 30}}) - self.assertEqual( - transform.query(age__gt=20, age__lt=50), {"age": {"$gt": 20, "$lt": 50}} - ) - self.assertEqual( - transform.query(age=20, age__gt=50), - {"$and": [{"age": {"$gt": 50}}, {"age": 20}]}, - ) - self.assertEqual( - transform.query(friend__age__gte=30), {"friend.age": {"$gte": 30}} - ) - self.assertEqual( - transform.query(name__exists=True), {"name": {"$exists": True}} - ) + assert transform.query(name="test", age=30) == {"name": "test", "age": 30} + assert transform.query(age__lt=30) == {"age": {"$lt": 30}} + assert transform.query(age__gt=20, age__lt=50) == { + "age": {"$gt": 20, "$lt": 50} + } + assert transform.query(age=20, age__gt=50) == { + "$and": [{"age": {"$gt": 50}}, {"age": 20}] + } + assert transform.query(friend__age__gte=30) == {"friend.age": {"$gte": 30}} + assert transform.query(name__exists=True) == {"name": {"$exists": True}} def test_transform_update(self): class LisDoc(Document): @@ -54,17 +48,17 @@ class TestTransform(unittest.TestCase): ("push", "$push"), ): update = transform.update(DicDoc, **{"%s__dictField__test" % k: doc}) - self.assertIsInstance(update[v]["dictField.test"], dict) + assert isinstance(update[v]["dictField.test"], dict) # Update special cases update = transform.update(DicDoc, unset__dictField__test=doc) - self.assertEqual(update["$unset"]["dictField.test"], 1) + assert update["$unset"]["dictField.test"] == 1 update = transform.update(DicDoc, pull__dictField__test=doc) - self.assertIsInstance(update["$pull"]["dictField"]["test"], dict) + assert isinstance(update["$pull"]["dictField"]["test"], dict) update = transform.update(LisDoc, pull__foo__in=["a"]) - self.assertEqual(update, {"$pull": {"foo": {"$in": ["a"]}}}) + assert update == {"$pull": {"foo": {"$in": ["a"]}}} def test_transform_update_push(self): """Ensure the differences in behvaior between 'push' and 'push_all'""" @@ -73,10 +67,10 @@ class TestTransform(unittest.TestCase): tags = ListField(StringField()) update = transform.update(BlogPost, push__tags=["mongo", "db"]) - self.assertEqual(update, {"$push": {"tags": ["mongo", "db"]}}) + assert update == {"$push": {"tags": ["mongo", "db"]}} update = transform.update(BlogPost, push_all__tags=["mongo", "db"]) - self.assertEqual(update, {"$push": {"tags": {"$each": ["mongo", "db"]}}}) + assert update == {"$push": {"tags": {"$each": ["mongo", "db"]}}} def test_transform_update_no_operator_default_to_set(self): """Ensure the differences in behvaior between 'push' and 'push_all'""" @@ -85,7 +79,7 @@ class TestTransform(unittest.TestCase): tags = ListField(StringField()) update = transform.update(BlogPost, tags=["mongo", "db"]) - self.assertEqual(update, {"$set": {"tags": ["mongo", "db"]}}) + assert update == {"$set": {"tags": ["mongo", "db"]}} def test_query_field_name(self): """Ensure that the correct field name is used when querying. @@ -106,18 +100,18 @@ class TestTransform(unittest.TestCase): post = BlogPost(**data) post.save() - self.assertIn("postTitle", BlogPost.objects(title=data["title"])._query) - self.assertFalse("title" in BlogPost.objects(title=data["title"])._query) - self.assertEqual(BlogPost.objects(title=data["title"]).count(), 1) + assert "postTitle" in BlogPost.objects(title=data["title"])._query + assert not ("title" in BlogPost.objects(title=data["title"])._query) + assert BlogPost.objects(title=data["title"]).count() == 1 - self.assertIn("_id", BlogPost.objects(pk=post.id)._query) - self.assertEqual(BlogPost.objects(pk=post.id).count(), 1) + assert "_id" in BlogPost.objects(pk=post.id)._query + assert BlogPost.objects(pk=post.id).count() == 1 - self.assertIn( - "postComments.commentContent", - BlogPost.objects(comments__content="test")._query, + assert ( + "postComments.commentContent" + in BlogPost.objects(comments__content="test")._query ) - self.assertEqual(BlogPost.objects(comments__content="test").count(), 1) + assert BlogPost.objects(comments__content="test").count() == 1 BlogPost.drop_collection() @@ -135,9 +129,9 @@ class TestTransform(unittest.TestCase): post = BlogPost(**data) post.save() - self.assertIn("_id", BlogPost.objects(pk=data["title"])._query) - self.assertIn("_id", BlogPost.objects(title=data["title"])._query) - self.assertEqual(BlogPost.objects(pk=data["title"]).count(), 1) + assert "_id" in BlogPost.objects(pk=data["title"])._query + assert "_id" in BlogPost.objects(title=data["title"])._query + assert BlogPost.objects(pk=data["title"]).count() == 1 BlogPost.drop_collection() @@ -163,7 +157,7 @@ class TestTransform(unittest.TestCase): q2 = B.objects.filter(a__in=[a1, a2]) q2 = q2.filter(a=a1)._query - self.assertEqual(q1, q2) + assert q1 == q2 def test_raw_query_and_Q_objects(self): """ @@ -179,11 +173,11 @@ class TestTransform(unittest.TestCase): meta = {"allow_inheritance": False} query = Foo.objects(__raw__={"$nor": [{"name": "bar"}]})._query - self.assertEqual(query, {"$nor": [{"name": "bar"}]}) + assert query == {"$nor": [{"name": "bar"}]} q1 = {"$or": [{"a": 1}, {"b": 1}]} query = Foo.objects(Q(__raw__=q1) & Q(c=1))._query - self.assertEqual(query, {"$or": [{"a": 1}, {"b": 1}], "c": 1}) + assert query == {"$or": [{"a": 1}, {"b": 1}], "c": 1} def test_raw_and_merging(self): class Doc(Document): @@ -200,51 +194,39 @@ class TestTransform(unittest.TestCase): } )._query - self.assertEqual( - raw_query, - { - "deleted": False, - "scraped": "yes", - "$nor": [ - {"views.extracted": "no"}, - {"attachments.views.extracted": "no"}, - ], - }, - ) + assert raw_query == { + "deleted": False, + "scraped": "yes", + "$nor": [{"views.extracted": "no"}, {"attachments.views.extracted": "no"}], + } def test_geojson_PointField(self): class Location(Document): loc = PointField() update = transform.update(Location, set__loc=[1, 2]) - self.assertEqual( - update, {"$set": {"loc": {"type": "Point", "coordinates": [1, 2]}}} - ) + assert update == {"$set": {"loc": {"type": "Point", "coordinates": [1, 2]}}} update = transform.update( Location, set__loc={"type": "Point", "coordinates": [1, 2]} ) - self.assertEqual( - update, {"$set": {"loc": {"type": "Point", "coordinates": [1, 2]}}} - ) + assert update == {"$set": {"loc": {"type": "Point", "coordinates": [1, 2]}}} def test_geojson_LineStringField(self): class Location(Document): line = LineStringField() update = transform.update(Location, set__line=[[1, 2], [2, 2]]) - self.assertEqual( - update, - {"$set": {"line": {"type": "LineString", "coordinates": [[1, 2], [2, 2]]}}}, - ) + assert update == { + "$set": {"line": {"type": "LineString", "coordinates": [[1, 2], [2, 2]]}} + } update = transform.update( Location, set__line={"type": "LineString", "coordinates": [[1, 2], [2, 2]]} ) - self.assertEqual( - update, - {"$set": {"line": {"type": "LineString", "coordinates": [[1, 2], [2, 2]]}}}, - ) + assert update == { + "$set": {"line": {"type": "LineString", "coordinates": [[1, 2], [2, 2]]}} + } def test_geojson_PolygonField(self): class Location(Document): @@ -253,17 +235,14 @@ class TestTransform(unittest.TestCase): update = transform.update( Location, set__poly=[[[40, 5], [40, 6], [41, 6], [40, 5]]] ) - self.assertEqual( - update, - { - "$set": { - "poly": { - "type": "Polygon", - "coordinates": [[[40, 5], [40, 6], [41, 6], [40, 5]]], - } + assert update == { + "$set": { + "poly": { + "type": "Polygon", + "coordinates": [[[40, 5], [40, 6], [41, 6], [40, 5]]], } - }, - ) + } + } update = transform.update( Location, @@ -272,17 +251,14 @@ class TestTransform(unittest.TestCase): "coordinates": [[[40, 5], [40, 6], [41, 6], [40, 5]]], }, ) - self.assertEqual( - update, - { - "$set": { - "poly": { - "type": "Polygon", - "coordinates": [[[40, 5], [40, 6], [41, 6], [40, 5]]], - } + assert update == { + "$set": { + "poly": { + "type": "Polygon", + "coordinates": [[[40, 5], [40, 6], [41, 6], [40, 5]]], } - }, - ) + } + } def test_type(self): class Doc(Document): @@ -291,10 +267,10 @@ class TestTransform(unittest.TestCase): Doc(df=True).save() Doc(df=7).save() Doc(df="df").save() - self.assertEqual(Doc.objects(df__type=1).count(), 0) # double - self.assertEqual(Doc.objects(df__type=8).count(), 1) # bool - self.assertEqual(Doc.objects(df__type=2).count(), 1) # str - self.assertEqual(Doc.objects(df__type=16).count(), 1) # int + assert Doc.objects(df__type=1).count() == 0 # double + assert Doc.objects(df__type=8).count() == 1 # bool + assert Doc.objects(df__type=2).count() == 1 # str + assert Doc.objects(df__type=16).count() == 1 # int def test_last_field_name_like_operator(self): class EmbeddedItem(EmbeddedDocument): @@ -309,12 +285,12 @@ class TestTransform(unittest.TestCase): doc = Doc(item=EmbeddedItem(type="axe", name="Heroic axe")) doc.save() - self.assertEqual(1, Doc.objects(item__type__="axe").count()) - self.assertEqual(1, Doc.objects(item__name__="Heroic axe").count()) + assert 1 == Doc.objects(item__type__="axe").count() + assert 1 == Doc.objects(item__name__="Heroic axe").count() Doc.objects(id=doc.id).update(set__item__type__="sword") - self.assertEqual(1, Doc.objects(item__type__="sword").count()) - self.assertEqual(0, Doc.objects(item__type__="axe").count()) + assert 1 == Doc.objects(item__type__="sword").count() + assert 0 == Doc.objects(item__type__="axe").count() def test_understandable_error_raised(self): class Event(Document): @@ -324,7 +300,7 @@ class TestTransform(unittest.TestCase): box = [(35.0, -125.0), (40.0, -100.0)] # I *meant* to execute location__within_box=box events = Event.objects(location__within=box) - with self.assertRaises(InvalidQueryError): + with pytest.raises(InvalidQueryError): events.count() def test_update_pull_for_list_fields(self): @@ -347,24 +323,20 @@ class TestTransform(unittest.TestCase): word = Word(word="abc", index=1) update = transform.update(MainDoc, pull__content__text=word) - self.assertEqual( - update, {"$pull": {"content.text": SON([("word", u"abc"), ("index", 1)])}} - ) + assert update == { + "$pull": {"content.text": SON([("word", u"abc"), ("index", 1)])} + } update = transform.update(MainDoc, pull__content__heading="xyz") - self.assertEqual(update, {"$pull": {"content.heading": "xyz"}}) + assert update == {"$pull": {"content.heading": "xyz"}} update = transform.update(MainDoc, pull__content__text__word__in=["foo", "bar"]) - self.assertEqual( - update, {"$pull": {"content.text": {"word": {"$in": ["foo", "bar"]}}}} - ) + assert update == {"$pull": {"content.text": {"word": {"$in": ["foo", "bar"]}}}} update = transform.update( MainDoc, pull__content__text__word__nin=["foo", "bar"] ) - self.assertEqual( - update, {"$pull": {"content.text": {"word": {"$nin": ["foo", "bar"]}}}} - ) + assert update == {"$pull": {"content.text": {"word": {"$nin": ["foo", "bar"]}}}} if __name__ == "__main__": diff --git a/tests/queryset/test_visitor.py b/tests/queryset/test_visitor.py index acadabd4..a41f9278 100644 --- a/tests/queryset/test_visitor.py +++ b/tests/queryset/test_visitor.py @@ -7,6 +7,7 @@ from bson import ObjectId from mongoengine import * from mongoengine.errors import InvalidQueryError from mongoengine.queryset import Q +import pytest class TestQ(unittest.TestCase): @@ -35,10 +36,10 @@ class TestQ(unittest.TestCase): age = IntField() query = {"$or": [{"age": {"$gte": 18}}, {"name": "test"}]} - self.assertEqual((q1 | q2 | q3 | q4 | q5).to_query(Person), query) + assert (q1 | q2 | q3 | q4 | q5).to_query(Person) == query query = {"age": {"$gte": 18}, "name": "test"} - self.assertEqual((q1 & q2 & q3 & q4 & q5).to_query(Person), query) + assert (q1 & q2 & q3 & q4 & q5).to_query(Person) == query def test_q_with_dbref(self): """Ensure Q objects handle DBRefs correctly""" @@ -53,8 +54,8 @@ class TestQ(unittest.TestCase): user = User.objects.create() Post.objects.create(created_user=user) - self.assertEqual(Post.objects.filter(created_user=user).count(), 1) - self.assertEqual(Post.objects.filter(Q(created_user=user)).count(), 1) + assert Post.objects.filter(created_user=user).count() == 1 + assert Post.objects.filter(Q(created_user=user)).count() == 1 def test_and_combination(self): """Ensure that Q-objects correctly AND together. @@ -65,12 +66,10 @@ class TestQ(unittest.TestCase): y = StringField() query = (Q(x__lt=7) & Q(x__lt=3)).to_query(TestDoc) - self.assertEqual(query, {"$and": [{"x": {"$lt": 7}}, {"x": {"$lt": 3}}]}) + assert query == {"$and": [{"x": {"$lt": 7}}, {"x": {"$lt": 3}}]} query = (Q(y="a") & Q(x__lt=7) & Q(x__lt=3)).to_query(TestDoc) - self.assertEqual( - query, {"$and": [{"y": "a"}, {"x": {"$lt": 7}}, {"x": {"$lt": 3}}]} - ) + assert query == {"$and": [{"y": "a"}, {"x": {"$lt": 7}}, {"x": {"$lt": 3}}]} # Check normal cases work without an error query = Q(x__lt=7) & Q(x__gt=3) @@ -78,7 +77,7 @@ class TestQ(unittest.TestCase): q1 = Q(x__lt=7) q2 = Q(x__gt=3) query = (q1 & q2).to_query(TestDoc) - self.assertEqual(query, {"x": {"$lt": 7, "$gt": 3}}) + assert query == {"x": {"$lt": 7, "$gt": 3}} # More complex nested example query = Q(x__lt=100) & Q(y__ne="NotMyString") @@ -87,7 +86,7 @@ class TestQ(unittest.TestCase): "x": {"$lt": 100, "$gt": -100}, "y": {"$ne": "NotMyString", "$in": ["a", "b", "c"]}, } - self.assertEqual(query.to_query(TestDoc), mongo_query) + assert query.to_query(TestDoc) == mongo_query def test_or_combination(self): """Ensure that Q-objects correctly OR together. @@ -99,7 +98,7 @@ class TestQ(unittest.TestCase): q1 = Q(x__lt=3) q2 = Q(x__gt=7) query = (q1 | q2).to_query(TestDoc) - self.assertEqual(query, {"$or": [{"x": {"$lt": 3}}, {"x": {"$gt": 7}}]}) + assert query == {"$or": [{"x": {"$lt": 3}}, {"x": {"$gt": 7}}]} def test_and_or_combination(self): """Ensure that Q-objects handle ANDing ORed components. @@ -113,15 +112,12 @@ class TestQ(unittest.TestCase): query = Q(x__gt=0) | Q(x__exists=False) query &= Q(x__lt=100) - self.assertEqual( - query.to_query(TestDoc), - { - "$and": [ - {"$or": [{"x": {"$gt": 0}}, {"x": {"$exists": False}}]}, - {"x": {"$lt": 100}}, - ] - }, - ) + assert query.to_query(TestDoc) == { + "$and": [ + {"$or": [{"x": {"$gt": 0}}, {"x": {"$exists": False}}]}, + {"x": {"$lt": 100}}, + ] + } q1 = Q(x__gt=0) | Q(x__exists=False) q2 = Q(x__lt=100) | Q(y=True) @@ -131,16 +127,13 @@ class TestQ(unittest.TestCase): TestDoc(x=10).save() TestDoc(y=True).save() - self.assertEqual( - query, - { - "$and": [ - {"$or": [{"x": {"$gt": 0}}, {"x": {"$exists": False}}]}, - {"$or": [{"x": {"$lt": 100}}, {"y": True}]}, - ] - }, - ) - self.assertEqual(2, TestDoc.objects(q1 & q2).count()) + assert query == { + "$and": [ + {"$or": [{"x": {"$gt": 0}}, {"x": {"$exists": False}}]}, + {"$or": [{"x": {"$lt": 100}}, {"y": True}]}, + ] + } + assert 2 == TestDoc.objects(q1 & q2).count() def test_or_and_or_combination(self): """Ensure that Q-objects handle ORing ANDed ORed components. :) @@ -160,26 +153,23 @@ class TestQ(unittest.TestCase): q2 = Q(x__lt=100) & (Q(y=False) | Q(y__exists=False)) query = (q1 | q2).to_query(TestDoc) - self.assertEqual( - query, - { - "$or": [ - { - "$and": [ - {"x": {"$gt": 0}}, - {"$or": [{"y": True}, {"y": {"$exists": False}}]}, - ] - }, - { - "$and": [ - {"x": {"$lt": 100}}, - {"$or": [{"y": False}, {"y": {"$exists": False}}]}, - ] - }, - ] - }, - ) - self.assertEqual(2, TestDoc.objects(q1 | q2).count()) + assert query == { + "$or": [ + { + "$and": [ + {"x": {"$gt": 0}}, + {"$or": [{"y": True}, {"y": {"$exists": False}}]}, + ] + }, + { + "$and": [ + {"x": {"$lt": 100}}, + {"$or": [{"y": False}, {"y": {"$exists": False}}]}, + ] + }, + ] + } + assert 2 == TestDoc.objects(q1 | q2).count() def test_multiple_occurence_in_field(self): class Test(Document): @@ -192,8 +182,8 @@ class TestQ(unittest.TestCase): q3 = q1 & q2 query = q3.to_query(Test) - self.assertEqual(query["$and"][0], q1.to_query(Test)) - self.assertEqual(query["$and"][1], q2.to_query(Test)) + assert query["$and"][0] == q1.to_query(Test) + assert query["$and"][1] == q2.to_query(Test) def test_q_clone(self): class TestDoc(Document): @@ -207,15 +197,15 @@ class TestQ(unittest.TestCase): # Check normal cases work without an error test = TestDoc.objects(Q(x__lt=7) & Q(x__gt=3)) - self.assertEqual(test.count(), 3) + assert test.count() == 3 test2 = test.clone() - self.assertEqual(test2.count(), 3) - self.assertNotEqual(test2, test) + assert test2.count() == 3 + assert test2 != test test3 = test2.filter(x=6) - self.assertEqual(test3.count(), 1) - self.assertEqual(test.count(), 3) + assert test3.count() == 1 + assert test.count() == 3 def test_q(self): """Ensure that Q objects may be used to query for documents. @@ -252,19 +242,19 @@ class TestQ(unittest.TestCase): # Check ObjectId lookup works obj = BlogPost.objects(id=post1.id).first() - self.assertEqual(obj, post1) + assert obj == post1 # Check Q object combination with one does not exist q = BlogPost.objects(Q(title="Test 5") | Q(published=True)) posts = [post.id for post in q] published_posts = (post2, post3) - self.assertTrue(all(obj.id in posts for obj in published_posts)) + assert all(obj.id in posts for obj in published_posts) q = BlogPost.objects(Q(title="Test 1") | Q(published=True)) posts = [post.id for post in q] published_posts = (post1, post2, post3, post5, post6) - self.assertTrue(all(obj.id in posts for obj in published_posts)) + assert all(obj.id in posts for obj in published_posts) # Check Q object combination date = datetime.datetime(2010, 1, 10) @@ -272,9 +262,9 @@ class TestQ(unittest.TestCase): posts = [post.id for post in q] published_posts = (post1, post2, post3, post4) - self.assertTrue(all(obj.id in posts for obj in published_posts)) + assert all(obj.id in posts for obj in published_posts) - self.assertFalse(any(obj.id in posts for obj in [post5, post6])) + assert not any(obj.id in posts for obj in [post5, post6]) BlogPost.drop_collection() @@ -284,15 +274,15 @@ class TestQ(unittest.TestCase): self.Person(name="user3", age=30).save() self.Person(name="user4", age=40).save() - self.assertEqual(self.Person.objects(Q(age__in=[20])).count(), 2) - self.assertEqual(self.Person.objects(Q(age__in=[20, 30])).count(), 3) + assert self.Person.objects(Q(age__in=[20])).count() == 2 + assert self.Person.objects(Q(age__in=[20, 30])).count() == 3 # Test invalid query objs - with self.assertRaises(InvalidQueryError): + with pytest.raises(InvalidQueryError): self.Person.objects("user1") # filter should fail, too - with self.assertRaises(InvalidQueryError): + with pytest.raises(InvalidQueryError): self.Person.objects.filter("user1") def test_q_regex(self): @@ -302,31 +292,31 @@ class TestQ(unittest.TestCase): person.save() obj = self.Person.objects(Q(name=re.compile("^Gui"))).first() - self.assertEqual(obj, person) + assert obj == person obj = self.Person.objects(Q(name=re.compile("^gui"))).first() - self.assertEqual(obj, None) + assert obj == None obj = self.Person.objects(Q(name=re.compile("^gui", re.I))).first() - self.assertEqual(obj, person) + assert obj == person obj = self.Person.objects(Q(name__not=re.compile("^bob"))).first() - self.assertEqual(obj, person) + assert obj == person obj = self.Person.objects(Q(name__not=re.compile("^Gui"))).first() - self.assertEqual(obj, None) + assert obj == None def test_q_repr(self): - self.assertEqual(repr(Q()), "Q(**{})") - self.assertEqual(repr(Q(name="test")), "Q(**{'name': 'test'})") + assert repr(Q()) == "Q(**{})" + assert repr(Q(name="test")) == "Q(**{'name': 'test'})" - self.assertEqual( - repr(Q(name="test") & Q(age__gte=18)), - "(Q(**{'name': 'test'}) & Q(**{'age__gte': 18}))", + assert ( + repr(Q(name="test") & Q(age__gte=18)) + == "(Q(**{'name': 'test'}) & Q(**{'age__gte': 18}))" ) - self.assertEqual( - repr(Q(name="test") | Q(age__gte=18)), - "(Q(**{'name': 'test'}) | Q(**{'age__gte': 18}))", + assert ( + repr(Q(name="test") | Q(age__gte=18)) + == "(Q(**{'name': 'test'}) | Q(**{'age__gte': 18}))" ) def test_q_lists(self): @@ -341,8 +331,8 @@ class TestQ(unittest.TestCase): BlogPost(tags=["python", "mongo"]).save() BlogPost(tags=["python"]).save() - self.assertEqual(BlogPost.objects(Q(tags="mongo")).count(), 1) - self.assertEqual(BlogPost.objects(Q(tags="python")).count(), 2) + assert BlogPost.objects(Q(tags="mongo")).count() == 1 + assert BlogPost.objects(Q(tags="python")).count() == 2 BlogPost.drop_collection() @@ -355,12 +345,12 @@ class TestQ(unittest.TestCase): pk = ObjectId() User(email="example@example.com", pk=pk).save() - self.assertEqual( - 1, - User.objects.filter(Q(email="example@example.com") | Q(name="John Doe")) + assert ( + 1 + == User.objects.filter(Q(email="example@example.com") | Q(name="John Doe")) .limit(2) .filter(pk=pk) - .count(), + .count() ) def test_chained_q_or_filtering(self): @@ -376,14 +366,12 @@ class TestQ(unittest.TestCase): Item(postables=[Post(name="a"), Post(name="c")]).save() Item(postables=[Post(name="a"), Post(name="b"), Post(name="c")]).save() - self.assertEqual( - Item.objects(Q(postables__name="a") & Q(postables__name="b")).count(), 2 + assert ( + Item.objects(Q(postables__name="a") & Q(postables__name="b")).count() == 2 ) - self.assertEqual( - Item.objects.filter(postables__name="a") - .filter(postables__name="b") - .count(), - 2, + assert ( + Item.objects.filter(postables__name="a").filter(postables__name="b").count() + == 2 ) diff --git a/tests/test_common.py b/tests/test_common.py index 28f0b992..6b6f18de 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -1,5 +1,7 @@ import unittest +import pytest + from mongoengine import Document from mongoengine.common import _import_class @@ -7,8 +9,8 @@ from mongoengine.common import _import_class class TestCommon(unittest.TestCase): def test__import_class(self): doc_cls = _import_class("Document") - self.assertIs(doc_cls, Document) + assert doc_cls is Document def test__import_class_raise_if_not_known(self): - with self.assertRaises(ValueError): + with pytest.raises(ValueError): _import_class("UnknownClass") diff --git a/tests/test_connection.py b/tests/test_connection.py index 1519a835..c73b67d1 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -29,6 +29,7 @@ from mongoengine.connection import ( get_connection, get_db, ) +import pytest def get_tz_awareness(connection): @@ -54,15 +55,15 @@ class ConnectionTest(unittest.TestCase): connect("mongoenginetest") conn = get_connection() - self.assertIsInstance(conn, pymongo.mongo_client.MongoClient) + assert isinstance(conn, pymongo.mongo_client.MongoClient) db = get_db() - self.assertIsInstance(db, pymongo.database.Database) - self.assertEqual(db.name, "mongoenginetest") + assert isinstance(db, pymongo.database.Database) + assert db.name == "mongoenginetest" connect("mongoenginetest2", alias="testdb") conn = get_connection("testdb") - self.assertIsInstance(conn, pymongo.mongo_client.MongoClient) + assert isinstance(conn, pymongo.mongo_client.MongoClient) def test_connect_disconnect_works_properly(self): class History1(Document): @@ -82,31 +83,27 @@ class ConnectionTest(unittest.TestCase): h = History1(name="default").save() h1 = History2(name="db1").save() - self.assertEqual( - list(History1.objects().as_pymongo()), [{"_id": h.id, "name": "default"}] - ) - self.assertEqual( - list(History2.objects().as_pymongo()), [{"_id": h1.id, "name": "db1"}] - ) + assert list(History1.objects().as_pymongo()) == [ + {"_id": h.id, "name": "default"} + ] + assert list(History2.objects().as_pymongo()) == [{"_id": h1.id, "name": "db1"}] disconnect("db1") disconnect("db2") - with self.assertRaises(ConnectionFailure): + with pytest.raises(ConnectionFailure): list(History1.objects().as_pymongo()) - with self.assertRaises(ConnectionFailure): + with pytest.raises(ConnectionFailure): list(History2.objects().as_pymongo()) connect("db1", alias="db1") connect("db2", alias="db2") - self.assertEqual( - list(History1.objects().as_pymongo()), [{"_id": h.id, "name": "default"}] - ) - self.assertEqual( - list(History2.objects().as_pymongo()), [{"_id": h1.id, "name": "db1"}] - ) + assert list(History1.objects().as_pymongo()) == [ + {"_id": h.id, "name": "default"} + ] + assert list(History2.objects().as_pymongo()) == [{"_id": h1.id, "name": "db1"}] def test_connect_different_documents_to_different_database(self): class History(Document): @@ -132,39 +129,35 @@ class ConnectionTest(unittest.TestCase): h1 = History1(name="db1").save() h2 = History2(name="db2").save() - self.assertEqual(History._collection.database.name, DEFAULT_DATABASE_NAME) - self.assertEqual(History1._collection.database.name, "db1") - self.assertEqual(History2._collection.database.name, "db2") + assert History._collection.database.name == DEFAULT_DATABASE_NAME + assert History1._collection.database.name == "db1" + assert History2._collection.database.name == "db2" - self.assertEqual( - list(History.objects().as_pymongo()), [{"_id": h.id, "name": "default"}] - ) - self.assertEqual( - list(History1.objects().as_pymongo()), [{"_id": h1.id, "name": "db1"}] - ) - self.assertEqual( - list(History2.objects().as_pymongo()), [{"_id": h2.id, "name": "db2"}] - ) + assert list(History.objects().as_pymongo()) == [ + {"_id": h.id, "name": "default"} + ] + assert list(History1.objects().as_pymongo()) == [{"_id": h1.id, "name": "db1"}] + assert list(History2.objects().as_pymongo()) == [{"_id": h2.id, "name": "db2"}] def test_connect_fails_if_connect_2_times_with_default_alias(self): connect("mongoenginetest") - with self.assertRaises(ConnectionFailure) as ctx_err: + with pytest.raises(ConnectionFailure) as ctx_err: connect("mongoenginetest2") - self.assertEqual( - "A different connection with alias `default` was already registered. Use disconnect() first", - str(ctx_err.exception), + assert ( + "A different connection with alias `default` was already registered. Use disconnect() first" + == str(ctx_err.exception) ) def test_connect_fails_if_connect_2_times_with_custom_alias(self): connect("mongoenginetest", alias="alias1") - with self.assertRaises(ConnectionFailure) as ctx_err: + with pytest.raises(ConnectionFailure) as ctx_err: connect("mongoenginetest2", alias="alias1") - self.assertEqual( - "A different connection with alias `alias1` was already registered. Use disconnect() first", - str(ctx_err.exception), + assert ( + "A different connection with alias `alias1` was already registered. Use disconnect() first" + == str(ctx_err.exception) ) def test_connect_fails_if_similar_connection_settings_arent_defined_the_same_way( @@ -175,25 +168,25 @@ class ConnectionTest(unittest.TestCase): db_alias = "alias1" connect(db=db_name, alias=db_alias, host="localhost", port=27017) - with self.assertRaises(ConnectionFailure): + with pytest.raises(ConnectionFailure): connect(host="mongodb://localhost:27017/%s" % db_name, alias=db_alias) def test_connect_passes_silently_connect_multiple_times_with_same_config(self): # test default connection to `test` connect() connect() - self.assertEqual(len(mongoengine.connection._connections), 1) + assert len(mongoengine.connection._connections) == 1 connect("test01", alias="test01") connect("test01", alias="test01") - self.assertEqual(len(mongoengine.connection._connections), 2) + assert len(mongoengine.connection._connections) == 2 connect(host="mongodb://localhost:27017/mongoenginetest02", alias="test02") connect(host="mongodb://localhost:27017/mongoenginetest02", alias="test02") - self.assertEqual(len(mongoengine.connection._connections), 3) + assert len(mongoengine.connection._connections) == 3 def test_connect_with_invalid_db_name(self): """Ensure that connect() method fails fast if db name is invalid """ - with self.assertRaises(InvalidName): + with pytest.raises(InvalidName): connect("mongomock://localhost") def test_connect_with_db_name_external(self): @@ -203,20 +196,20 @@ class ConnectionTest(unittest.TestCase): connect("$external") conn = get_connection() - self.assertIsInstance(conn, pymongo.mongo_client.MongoClient) + assert isinstance(conn, pymongo.mongo_client.MongoClient) db = get_db() - self.assertIsInstance(db, pymongo.database.Database) - self.assertEqual(db.name, "$external") + assert isinstance(db, pymongo.database.Database) + assert db.name == "$external" connect("$external", alias="testdb") conn = get_connection("testdb") - self.assertIsInstance(conn, pymongo.mongo_client.MongoClient) + assert isinstance(conn, pymongo.mongo_client.MongoClient) def test_connect_with_invalid_db_name_type(self): """Ensure that connect() method fails fast if db name has invalid type """ - with self.assertRaises(TypeError): + with pytest.raises(TypeError): non_string_db_name = ["e. g. list instead of a string"] connect(non_string_db_name) @@ -230,11 +223,11 @@ class ConnectionTest(unittest.TestCase): connect("mongoenginetest", host="mongomock://localhost") conn = get_connection() - self.assertIsInstance(conn, mongomock.MongoClient) + assert isinstance(conn, mongomock.MongoClient) connect("mongoenginetest2", host="mongomock://localhost", alias="testdb2") conn = get_connection("testdb2") - self.assertIsInstance(conn, mongomock.MongoClient) + assert isinstance(conn, mongomock.MongoClient) connect( "mongoenginetest3", @@ -243,11 +236,11 @@ class ConnectionTest(unittest.TestCase): alias="testdb3", ) conn = get_connection("testdb3") - self.assertIsInstance(conn, mongomock.MongoClient) + assert isinstance(conn, mongomock.MongoClient) connect("mongoenginetest4", is_mock=True, alias="testdb4") conn = get_connection("testdb4") - self.assertIsInstance(conn, mongomock.MongoClient) + assert isinstance(conn, mongomock.MongoClient) connect( host="mongodb://localhost:27017/mongoenginetest5", @@ -255,11 +248,11 @@ class ConnectionTest(unittest.TestCase): alias="testdb5", ) conn = get_connection("testdb5") - self.assertIsInstance(conn, mongomock.MongoClient) + assert isinstance(conn, mongomock.MongoClient) connect(host="mongomock://localhost:27017/mongoenginetest6", alias="testdb6") conn = get_connection("testdb6") - self.assertIsInstance(conn, mongomock.MongoClient) + assert isinstance(conn, mongomock.MongoClient) connect( host="mongomock://localhost:27017/mongoenginetest7", @@ -267,7 +260,7 @@ class ConnectionTest(unittest.TestCase): alias="testdb7", ) conn = get_connection("testdb7") - self.assertIsInstance(conn, mongomock.MongoClient) + assert isinstance(conn, mongomock.MongoClient) def test_default_database_with_mocking(self): """Ensure that the default database is correctly set when using mongomock. @@ -286,8 +279,8 @@ class ConnectionTest(unittest.TestCase): some_document = SomeDocument() # database won't exist until we save a document some_document.save() - self.assertEqual(conn.get_default_database().name, "mongoenginetest") - self.assertEqual(conn.list_database_names()[0], "mongoenginetest") + assert conn.get_default_database().name == "mongoenginetest" + assert conn.database_names()[0] == "mongoenginetest" def test_connect_with_host_list(self): """Ensure that the connect() method works when host is a list @@ -301,22 +294,22 @@ class ConnectionTest(unittest.TestCase): connect(host=["mongomock://localhost"]) conn = get_connection() - self.assertIsInstance(conn, mongomock.MongoClient) + assert isinstance(conn, mongomock.MongoClient) connect(host=["mongodb://localhost"], is_mock=True, alias="testdb2") conn = get_connection("testdb2") - self.assertIsInstance(conn, mongomock.MongoClient) + assert isinstance(conn, mongomock.MongoClient) connect(host=["localhost"], is_mock=True, alias="testdb3") conn = get_connection("testdb3") - self.assertIsInstance(conn, mongomock.MongoClient) + assert isinstance(conn, mongomock.MongoClient) connect( host=["mongomock://localhost:27017", "mongomock://localhost:27018"], alias="testdb4", ) conn = get_connection("testdb4") - self.assertIsInstance(conn, mongomock.MongoClient) + assert isinstance(conn, mongomock.MongoClient) connect( host=["mongodb://localhost:27017", "mongodb://localhost:27018"], @@ -324,13 +317,13 @@ class ConnectionTest(unittest.TestCase): alias="testdb5", ) conn = get_connection("testdb5") - self.assertIsInstance(conn, mongomock.MongoClient) + assert isinstance(conn, mongomock.MongoClient) connect( host=["localhost:27017", "localhost:27018"], is_mock=True, alias="testdb6" ) conn = get_connection("testdb6") - self.assertIsInstance(conn, mongomock.MongoClient) + assert isinstance(conn, mongomock.MongoClient) def test_disconnect_cleans_globals(self): """Ensure that the disconnect() method cleans the globals objects""" @@ -340,20 +333,20 @@ class ConnectionTest(unittest.TestCase): connect("mongoenginetest") - self.assertEqual(len(connections), 1) - self.assertEqual(len(dbs), 0) - self.assertEqual(len(connection_settings), 1) + assert len(connections) == 1 + assert len(dbs) == 0 + assert len(connection_settings) == 1 class TestDoc(Document): pass TestDoc.drop_collection() # triggers the db - self.assertEqual(len(dbs), 1) + assert len(dbs) == 1 disconnect() - self.assertEqual(len(connections), 0) - self.assertEqual(len(dbs), 0) - self.assertEqual(len(connection_settings), 0) + assert len(connections) == 0 + assert len(dbs) == 0 + assert len(connection_settings) == 0 def test_disconnect_cleans_cached_collection_attribute_in_document(self): """Ensure that the disconnect() method works properly""" @@ -362,22 +355,20 @@ class ConnectionTest(unittest.TestCase): class History(Document): pass - self.assertIsNone(History._collection) + assert History._collection is None History.drop_collection() History.objects.first() # will trigger the caching of _collection attribute - self.assertIsNotNone(History._collection) + assert History._collection is not None disconnect() - self.assertIsNone(History._collection) + assert History._collection is None - with self.assertRaises(ConnectionFailure) as ctx_err: + with pytest.raises(ConnectionFailure) as ctx_err: History.objects.first() - self.assertEqual( - "You have not defined a default connection", str(ctx_err.exception) - ) + assert "You have not defined a default connection" == str(ctx_err.exception) def test_connect_disconnect_works_on_same_document(self): """Ensure that the connect/disconnect works properly with a single Document""" @@ -399,7 +390,7 @@ class ConnectionTest(unittest.TestCase): disconnect() # Make sure save doesnt work at this stage - with self.assertRaises(ConnectionFailure): + with pytest.raises(ConnectionFailure): User(name="Wont work").save() # Save in db2 @@ -408,13 +399,13 @@ class ConnectionTest(unittest.TestCase): disconnect() db1_users = list(client[db1].user.find()) - self.assertEqual(db1_users, [{"_id": user1.id, "name": "John is in db1"}]) + assert db1_users == [{"_id": user1.id, "name": "John is in db1"}] db2_users = list(client[db2].user.find()) - self.assertEqual(db2_users, [{"_id": user2.id, "name": "Bob is in db2"}]) + assert db2_users == [{"_id": user2.id, "name": "Bob is in db2"}] def test_disconnect_silently_pass_if_alias_does_not_exist(self): connections = mongoengine.connection._connections - self.assertEqual(len(connections), 0) + assert len(connections) == 0 disconnect(alias="not_exist") def test_disconnect_all(self): @@ -437,26 +428,26 @@ class ConnectionTest(unittest.TestCase): History1.drop_collection() History1.objects.first() - self.assertIsNotNone(History._collection) - self.assertIsNotNone(History1._collection) + assert History._collection is not None + assert History1._collection is not None - self.assertEqual(len(connections), 2) - self.assertEqual(len(dbs), 2) - self.assertEqual(len(connection_settings), 2) + assert len(connections) == 2 + assert len(dbs) == 2 + assert len(connection_settings) == 2 disconnect_all() - self.assertIsNone(History._collection) - self.assertIsNone(History1._collection) + assert History._collection is None + assert History1._collection is None - self.assertEqual(len(connections), 0) - self.assertEqual(len(dbs), 0) - self.assertEqual(len(connection_settings), 0) + assert len(connections) == 0 + assert len(dbs) == 0 + assert len(connection_settings) == 0 - with self.assertRaises(ConnectionFailure): + with pytest.raises(ConnectionFailure): History.objects.first() - with self.assertRaises(ConnectionFailure): + with pytest.raises(ConnectionFailure): History1.objects.first() def test_disconnect_all_silently_pass_if_no_connection_exist(self): @@ -473,7 +464,7 @@ class ConnectionTest(unittest.TestCase): expected_connection.server_info() - self.assertEqual(expected_connection, actual_connection) + assert expected_connection == actual_connection def test_connect_uri(self): """Ensure that the connect() method works properly with URIs.""" @@ -490,11 +481,11 @@ class ConnectionTest(unittest.TestCase): ) conn = get_connection() - self.assertIsInstance(conn, pymongo.mongo_client.MongoClient) + assert isinstance(conn, pymongo.mongo_client.MongoClient) db = get_db() - self.assertIsInstance(db, pymongo.database.Database) - self.assertEqual(db.name, "mongoenginetest") + assert isinstance(db, pymongo.database.Database) + assert db.name == "mongoenginetest" c.admin.system.users.delete_many({}) c.mongoenginetest.system.users.delete_many({}) @@ -506,11 +497,11 @@ class ConnectionTest(unittest.TestCase): connect("mongoenginetest", host="mongodb://localhost/") conn = get_connection() - self.assertIsInstance(conn, pymongo.mongo_client.MongoClient) + assert isinstance(conn, pymongo.mongo_client.MongoClient) db = get_db() - self.assertIsInstance(db, pymongo.database.Database) - self.assertEqual(db.name, "mongoenginetest") + assert isinstance(db, pymongo.database.Database) + assert db.name == "mongoenginetest" def test_connect_uri_default_db(self): """Ensure connect() defaults to the right database name if @@ -519,11 +510,11 @@ class ConnectionTest(unittest.TestCase): connect(host="mongodb://localhost/") conn = get_connection() - self.assertIsInstance(conn, pymongo.mongo_client.MongoClient) + assert isinstance(conn, pymongo.mongo_client.MongoClient) db = get_db() - self.assertIsInstance(db, pymongo.database.Database) - self.assertEqual(db.name, "test") + assert isinstance(db, pymongo.database.Database) + assert db.name == "test" def test_uri_without_credentials_doesnt_override_conn_settings(self): """Ensure connect() uses the username & password params if the URI @@ -536,7 +527,8 @@ class ConnectionTest(unittest.TestCase): # OperationFailure means that mongoengine attempted authentication # w/ the provided username/password and failed - that's the desired # behavior. If the MongoDB URI would override the credentials - self.assertRaises(OperationFailure, get_db) + with pytest.raises(OperationFailure): + get_db() def test_connect_uri_with_authsource(self): """Ensure that the connect() method works well with `authSource` @@ -554,7 +546,8 @@ class ConnectionTest(unittest.TestCase): alias="test1", host="mongodb://username2:password@localhost/mongoenginetest", ) - self.assertRaises(OperationFailure, test_conn.server_info) + with pytest.raises(OperationFailure): + test_conn.server_info() # Authentication succeeds with "authSource" authd_conn = connect( @@ -566,8 +559,8 @@ class ConnectionTest(unittest.TestCase): ), ) db = get_db("test2") - self.assertIsInstance(db, pymongo.database.Database) - self.assertEqual(db.name, "mongoenginetest") + assert isinstance(db, pymongo.database.Database) + assert db.name == "mongoenginetest" # Clear all users authd_conn.admin.system.users.delete_many({}) @@ -577,13 +570,14 @@ class ConnectionTest(unittest.TestCase): """ register_connection("testdb", "mongoenginetest2") - self.assertRaises(ConnectionFailure, get_connection) + with pytest.raises(ConnectionFailure): + get_connection() conn = get_connection("testdb") - self.assertIsInstance(conn, pymongo.mongo_client.MongoClient) + assert isinstance(conn, pymongo.mongo_client.MongoClient) db = get_db("testdb") - self.assertIsInstance(db, pymongo.database.Database) - self.assertEqual(db.name, "mongoenginetest2") + assert isinstance(db, pymongo.database.Database) + assert db.name == "mongoenginetest2" def test_register_connection_defaults(self): """Ensure that defaults are used when the host and port are None. @@ -591,18 +585,18 @@ class ConnectionTest(unittest.TestCase): register_connection("testdb", "mongoenginetest", host=None, port=None) conn = get_connection("testdb") - self.assertIsInstance(conn, pymongo.mongo_client.MongoClient) + assert isinstance(conn, pymongo.mongo_client.MongoClient) def test_connection_kwargs(self): """Ensure that connection kwargs get passed to pymongo.""" connect("mongoenginetest", alias="t1", tz_aware=True) conn = get_connection("t1") - self.assertTrue(get_tz_awareness(conn)) + assert get_tz_awareness(conn) connect("mongoenginetest2", alias="t2") conn = get_connection("t2") - self.assertFalse(get_tz_awareness(conn)) + assert not get_tz_awareness(conn) def test_connection_pool_via_kwarg(self): """Ensure we can specify a max connection pool size using @@ -613,7 +607,7 @@ class ConnectionTest(unittest.TestCase): conn = connect( "mongoenginetest", alias="max_pool_size_via_kwarg", **pool_size_kwargs ) - self.assertEqual(conn.max_pool_size, 100) + assert conn.max_pool_size == 100 def test_connection_pool_via_uri(self): """Ensure we can specify a max connection pool size using @@ -623,7 +617,7 @@ class ConnectionTest(unittest.TestCase): host="mongodb://localhost/test?maxpoolsize=100", alias="max_pool_size_via_uri", ) - self.assertEqual(conn.max_pool_size, 100) + assert conn.max_pool_size == 100 def test_write_concern(self): """Ensure write concern can be specified in connect() via @@ -642,18 +636,18 @@ class ConnectionTest(unittest.TestCase): """ c = connect(host="mongodb://localhost/test?replicaSet=local-rs") db = get_db() - self.assertIsInstance(db, pymongo.database.Database) - self.assertEqual(db.name, "test") + assert isinstance(db, pymongo.database.Database) + assert db.name == "test" def test_connect_with_replicaset_via_kwargs(self): """Ensure connect() works when specifying a replicaSet via the connection kwargs """ c = connect(replicaset="local-rs") - self.assertEqual(c._MongoClient__options.replica_set_name, "local-rs") + assert c._MongoClient__options.replica_set_name == "local-rs" db = get_db() - self.assertIsInstance(db, pymongo.database.Database) - self.assertEqual(db.name, "test") + assert isinstance(db, pymongo.database.Database) + assert db.name == "test" def test_connect_tz_aware(self): connect("mongoenginetest", tz_aware=True) @@ -666,13 +660,13 @@ class ConnectionTest(unittest.TestCase): DateDoc(the_date=d).save() date_doc = DateDoc.objects.first() - self.assertEqual(d, date_doc.the_date) + assert d == date_doc.the_date def test_read_preference_from_parse(self): conn = connect( host="mongodb://a1.vpc,a2.vpc,a3.vpc/prod?readPreference=secondaryPreferred" ) - self.assertEqual(conn.read_preference, ReadPreference.SECONDARY_PREFERRED) + assert conn.read_preference == ReadPreference.SECONDARY_PREFERRED def test_multiple_connection_settings(self): connect("mongoenginetest", alias="t1", host="localhost") @@ -680,27 +674,27 @@ class ConnectionTest(unittest.TestCase): connect("mongoenginetest2", alias="t2", host="127.0.0.1") mongo_connections = mongoengine.connection._connections - self.assertEqual(len(mongo_connections.items()), 2) - self.assertIn("t1", mongo_connections.keys()) - self.assertIn("t2", mongo_connections.keys()) + assert len(mongo_connections.items()) == 2 + assert "t1" in mongo_connections.keys() + assert "t2" in mongo_connections.keys() # Handle PyMongo 3+ Async Connection # Ensure we are connected, throws ServerSelectionTimeoutError otherwise. # Purposely not catching exception to fail test if thrown. mongo_connections["t1"].server_info() mongo_connections["t2"].server_info() - self.assertEqual(mongo_connections["t1"].address[0], "localhost") - self.assertEqual(mongo_connections["t2"].address[0], "127.0.0.1") + assert mongo_connections["t1"].address[0] == "localhost" + assert mongo_connections["t2"].address[0] == "127.0.0.1" def test_connect_2_databases_uses_same_client_if_only_dbname_differs(self): c1 = connect(alias="testdb1", db="testdb1") c2 = connect(alias="testdb2", db="testdb2") - self.assertIs(c1, c2) + assert c1 is c2 def test_connect_2_databases_uses_different_client_if_different_parameters(self): c1 = connect(alias="testdb1", db="testdb1", username="u1") c2 = connect(alias="testdb2", db="testdb2", username="u2") - self.assertIsNot(c1, c2) + assert c1 is not c2 if __name__ == "__main__": diff --git a/tests/test_context_managers.py b/tests/test_context_managers.py index 32e48a70..cf4dd100 100644 --- a/tests/test_context_managers.py +++ b/tests/test_context_managers.py @@ -10,6 +10,7 @@ from mongoengine.context_managers import ( switch_db, ) from mongoengine.pymongo_support import count_documents +import pytest class ContextManagersTest(unittest.TestCase): @@ -23,20 +24,20 @@ class ContextManagersTest(unittest.TestCase): Group.drop_collection() Group(name="hello - default").save() - self.assertEqual(1, Group.objects.count()) + assert 1 == Group.objects.count() with switch_db(Group, "testdb-1") as Group: - self.assertEqual(0, Group.objects.count()) + assert 0 == Group.objects.count() Group(name="hello").save() - self.assertEqual(1, Group.objects.count()) + assert 1 == Group.objects.count() Group.drop_collection() - self.assertEqual(0, Group.objects.count()) + assert 0 == Group.objects.count() - self.assertEqual(1, Group.objects.count()) + assert 1 == Group.objects.count() def test_switch_collection_context_manager(self): connect("mongoenginetest") @@ -51,20 +52,20 @@ class ContextManagersTest(unittest.TestCase): Group.drop_collection() # drops in group1 Group(name="hello - group").save() - self.assertEqual(1, Group.objects.count()) + assert 1 == Group.objects.count() with switch_collection(Group, "group1") as Group: - self.assertEqual(0, Group.objects.count()) + assert 0 == Group.objects.count() Group(name="hello - group1").save() - self.assertEqual(1, Group.objects.count()) + assert 1 == Group.objects.count() Group.drop_collection() - self.assertEqual(0, Group.objects.count()) + assert 0 == Group.objects.count() - self.assertEqual(1, Group.objects.count()) + assert 1 == Group.objects.count() def test_no_dereference_context_manager_object_id(self): """Ensure that DBRef items in ListFields aren't dereferenced. @@ -89,20 +90,20 @@ class ContextManagersTest(unittest.TestCase): Group(ref=user, members=User.objects, generic=user).save() with no_dereference(Group) as NoDeRefGroup: - self.assertTrue(Group._fields["members"]._auto_dereference) - self.assertFalse(NoDeRefGroup._fields["members"]._auto_dereference) + assert Group._fields["members"]._auto_dereference + assert not NoDeRefGroup._fields["members"]._auto_dereference with no_dereference(Group) as Group: group = Group.objects.first() for m in group.members: - self.assertNotIsInstance(m, User) - self.assertNotIsInstance(group.ref, User) - self.assertNotIsInstance(group.generic, User) + assert not isinstance(m, User) + assert not isinstance(group.ref, User) + assert not isinstance(group.generic, User) for m in group.members: - self.assertIsInstance(m, User) - self.assertIsInstance(group.ref, User) - self.assertIsInstance(group.generic, User) + assert isinstance(m, User) + assert isinstance(group.ref, User) + assert isinstance(group.generic, User) def test_no_dereference_context_manager_dbref(self): """Ensure that DBRef items in ListFields aren't dereferenced. @@ -127,18 +128,18 @@ class ContextManagersTest(unittest.TestCase): Group(ref=user, members=User.objects, generic=user).save() with no_dereference(Group) as NoDeRefGroup: - self.assertTrue(Group._fields["members"]._auto_dereference) - self.assertFalse(NoDeRefGroup._fields["members"]._auto_dereference) + assert Group._fields["members"]._auto_dereference + assert not NoDeRefGroup._fields["members"]._auto_dereference with no_dereference(Group) as Group: group = Group.objects.first() - self.assertTrue(all([not isinstance(m, User) for m in group.members])) - self.assertNotIsInstance(group.ref, User) - self.assertNotIsInstance(group.generic, User) + assert all([not isinstance(m, User) for m in group.members]) + assert not isinstance(group.ref, User) + assert not isinstance(group.generic, User) - self.assertTrue(all([isinstance(m, User) for m in group.members])) - self.assertIsInstance(group.ref, User) - self.assertIsInstance(group.generic, User) + assert all([isinstance(m, User) for m in group.members]) + assert isinstance(group.ref, User) + assert isinstance(group.generic, User) def test_no_sub_classes(self): class A(Document): @@ -159,32 +160,32 @@ class ContextManagersTest(unittest.TestCase): B(x=30).save() C(x=40).save() - self.assertEqual(A.objects.count(), 5) - self.assertEqual(B.objects.count(), 3) - self.assertEqual(C.objects.count(), 1) + assert A.objects.count() == 5 + assert B.objects.count() == 3 + assert C.objects.count() == 1 with no_sub_classes(A): - self.assertEqual(A.objects.count(), 2) + assert A.objects.count() == 2 for obj in A.objects: - self.assertEqual(obj.__class__, A) + assert obj.__class__ == A with no_sub_classes(B): - self.assertEqual(B.objects.count(), 2) + assert B.objects.count() == 2 for obj in B.objects: - self.assertEqual(obj.__class__, B) + assert obj.__class__ == B with no_sub_classes(C): - self.assertEqual(C.objects.count(), 1) + assert C.objects.count() == 1 for obj in C.objects: - self.assertEqual(obj.__class__, C) + assert obj.__class__ == C # Confirm context manager exit correctly - self.assertEqual(A.objects.count(), 5) - self.assertEqual(B.objects.count(), 3) - self.assertEqual(C.objects.count(), 1) + assert A.objects.count() == 5 + assert B.objects.count() == 3 + assert C.objects.count() == 1 def test_no_sub_classes_modification_to_document_class_are_temporary(self): class A(Document): @@ -194,27 +195,27 @@ class ContextManagersTest(unittest.TestCase): class B(A): z = IntField() - self.assertEqual(A._subclasses, ("A", "A.B")) + assert A._subclasses == ("A", "A.B") with no_sub_classes(A): - self.assertEqual(A._subclasses, ("A",)) - self.assertEqual(A._subclasses, ("A", "A.B")) + assert A._subclasses == ("A",) + assert A._subclasses == ("A", "A.B") - self.assertEqual(B._subclasses, ("A.B",)) + assert B._subclasses == ("A.B",) with no_sub_classes(B): - self.assertEqual(B._subclasses, ("A.B",)) - self.assertEqual(B._subclasses, ("A.B",)) + assert B._subclasses == ("A.B",) + assert B._subclasses == ("A.B",) def test_no_subclass_context_manager_does_not_swallow_exception(self): class User(Document): name = StringField() - with self.assertRaises(TypeError): + with pytest.raises(TypeError): with no_sub_classes(User): raise TypeError() def test_query_counter_does_not_swallow_exception(self): - with self.assertRaises(TypeError): + with pytest.raises(TypeError): with query_counter() as q: raise TypeError() @@ -227,10 +228,10 @@ class ContextManagersTest(unittest.TestCase): try: NEW_LEVEL = 1 db.set_profiling_level(NEW_LEVEL) - self.assertEqual(db.profiling_level(), NEW_LEVEL) + assert db.profiling_level() == NEW_LEVEL with query_counter() as q: - self.assertEqual(db.profiling_level(), 2) - self.assertEqual(db.profiling_level(), NEW_LEVEL) + assert db.profiling_level() == 2 + assert db.profiling_level() == NEW_LEVEL except Exception: db.set_profiling_level( initial_profiling_level @@ -255,33 +256,31 @@ class ContextManagersTest(unittest.TestCase): counter = 0 with query_counter() as q: - self.assertEqual(q, counter) - self.assertEqual( - q, counter - ) # Ensures previous count query did not get counted + assert q == counter + assert q == counter # Ensures previous count query did not get counted for _ in range(10): issue_1_insert_query() counter += 1 - self.assertEqual(q, counter) + assert q == counter for _ in range(4): issue_1_find_query() counter += 1 - self.assertEqual(q, counter) + assert q == counter for _ in range(3): issue_1_count_query() counter += 1 - self.assertEqual(q, counter) + assert q == counter - self.assertEqual(int(q), counter) # test __int__ - self.assertEqual(repr(q), str(int(q))) # test __repr__ - self.assertGreater(q, -1) # test __gt__ - self.assertGreaterEqual(q, int(q)) # test __gte__ - self.assertNotEqual(q, -1) - self.assertLess(q, 1000) - self.assertLessEqual(q, int(q)) + assert int(q) == counter # test __int__ + assert repr(q) == str(int(q)) # test __repr__ + assert q > -1 # test __gt__ + assert q >= int(q) # test __gte__ + assert q != -1 + assert q < 1000 + assert q <= int(q) def test_query_counter_counts_getmore_queries(self): connect("mongoenginetest") @@ -296,9 +295,9 @@ class ContextManagersTest(unittest.TestCase): ) # first batch of documents contains 101 documents with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 list(collection.find()) - self.assertEqual(q, 2) # 1st select + 1 getmore + assert q == 2 # 1st select + 1 getmore def test_query_counter_ignores_particular_queries(self): connect("mongoenginetest") @@ -308,18 +307,18 @@ class ContextManagersTest(unittest.TestCase): collection.insert_many([{"test": "garbage %s" % i} for i in range(10)]) with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 cursor = collection.find() - self.assertEqual(q, 0) # cursor wasn't opened yet + assert q == 0 # cursor wasn't opened yet _ = next(cursor) # opens the cursor and fires the find query - self.assertEqual(q, 1) + assert q == 1 cursor.close() # issues a `killcursors` query that is ignored by the context - self.assertEqual(q, 1) + assert q == 1 _ = ( db.system.indexes.find_one() ) # queries on db.system.indexes are ignored as well - self.assertEqual(q, 1) + assert q == 1 if __name__ == "__main__": diff --git a/tests/test_datastructures.py b/tests/test_datastructures.py index ff7598be..3a6029c1 100644 --- a/tests/test_datastructures.py +++ b/tests/test_datastructures.py @@ -1,6 +1,8 @@ import unittest from six import iterkeys +import pytest + from mongoengine import Document from mongoengine.base.datastructures import BaseDict, BaseList, StrictDict @@ -31,48 +33,48 @@ class TestBaseDict(unittest.TestCase): dict_items = {"k": "v"} doc = MyDoc() base_dict = BaseDict(dict_items, instance=doc, name="my_name") - self.assertIsInstance(base_dict._instance, Document) - self.assertEqual(base_dict._name, "my_name") - self.assertEqual(base_dict, dict_items) + assert isinstance(base_dict._instance, Document) + assert base_dict._name == "my_name" + assert base_dict == dict_items def test_setdefault_calls_mark_as_changed(self): base_dict = self._get_basedict({}) base_dict.setdefault("k", "v") - self.assertEqual(base_dict._instance._changed_fields, [base_dict._name]) + assert base_dict._instance._changed_fields == [base_dict._name] def test_popitems_calls_mark_as_changed(self): base_dict = self._get_basedict({"k": "v"}) - self.assertEqual(base_dict.popitem(), ("k", "v")) - self.assertEqual(base_dict._instance._changed_fields, [base_dict._name]) - self.assertFalse(base_dict) + assert base_dict.popitem() == ("k", "v") + assert base_dict._instance._changed_fields == [base_dict._name] + assert not base_dict def test_pop_calls_mark_as_changed(self): base_dict = self._get_basedict({"k": "v"}) - self.assertEqual(base_dict.pop("k"), "v") - self.assertEqual(base_dict._instance._changed_fields, [base_dict._name]) - self.assertFalse(base_dict) + assert base_dict.pop("k") == "v" + assert base_dict._instance._changed_fields == [base_dict._name] + assert not base_dict def test_pop_calls_does_not_mark_as_changed_when_it_fails(self): base_dict = self._get_basedict({"k": "v"}) - with self.assertRaises(KeyError): + with pytest.raises(KeyError): base_dict.pop("X") - self.assertFalse(base_dict._instance._changed_fields) + assert not base_dict._instance._changed_fields def test_clear_calls_mark_as_changed(self): base_dict = self._get_basedict({"k": "v"}) base_dict.clear() - self.assertEqual(base_dict._instance._changed_fields, ["my_name"]) - self.assertEqual(base_dict, {}) + assert base_dict._instance._changed_fields == ["my_name"] + assert base_dict == {} def test___delitem___calls_mark_as_changed(self): base_dict = self._get_basedict({"k": "v"}) del base_dict["k"] - self.assertEqual(base_dict._instance._changed_fields, ["my_name.k"]) - self.assertEqual(base_dict, {}) + assert base_dict._instance._changed_fields == ["my_name.k"] + assert base_dict == {} def test___getitem____KeyError(self): base_dict = self._get_basedict({}) - with self.assertRaises(KeyError): + with pytest.raises(KeyError): base_dict["new"] def test___getitem____simple_value(self): @@ -82,62 +84,62 @@ class TestBaseDict(unittest.TestCase): def test___getitem____sublist_gets_converted_to_BaseList(self): base_dict = self._get_basedict({"k": [0, 1, 2]}) sub_list = base_dict["k"] - self.assertEqual(sub_list, [0, 1, 2]) - self.assertIsInstance(sub_list, BaseList) - self.assertIs(sub_list._instance, base_dict._instance) - self.assertEqual(sub_list._name, "my_name.k") - self.assertEqual(base_dict._instance._changed_fields, []) + assert sub_list == [0, 1, 2] + assert isinstance(sub_list, BaseList) + assert sub_list._instance is base_dict._instance + assert sub_list._name == "my_name.k" + assert base_dict._instance._changed_fields == [] # Challenge mark_as_changed from sublist sub_list[1] = None - self.assertEqual(base_dict._instance._changed_fields, ["my_name.k.1"]) + assert base_dict._instance._changed_fields == ["my_name.k.1"] def test___getitem____subdict_gets_converted_to_BaseDict(self): base_dict = self._get_basedict({"k": {"subk": "subv"}}) sub_dict = base_dict["k"] - self.assertEqual(sub_dict, {"subk": "subv"}) - self.assertIsInstance(sub_dict, BaseDict) - self.assertIs(sub_dict._instance, base_dict._instance) - self.assertEqual(sub_dict._name, "my_name.k") - self.assertEqual(base_dict._instance._changed_fields, []) + assert sub_dict == {"subk": "subv"} + assert isinstance(sub_dict, BaseDict) + assert sub_dict._instance is base_dict._instance + assert sub_dict._name == "my_name.k" + assert base_dict._instance._changed_fields == [] # Challenge mark_as_changed from subdict sub_dict["subk"] = None - self.assertEqual(base_dict._instance._changed_fields, ["my_name.k.subk"]) + assert base_dict._instance._changed_fields == ["my_name.k.subk"] def test_get_sublist_gets_converted_to_BaseList_just_like__getitem__(self): base_dict = self._get_basedict({"k": [0, 1, 2]}) sub_list = base_dict.get("k") - self.assertEqual(sub_list, [0, 1, 2]) - self.assertIsInstance(sub_list, BaseList) + assert sub_list == [0, 1, 2] + assert isinstance(sub_list, BaseList) def test_get_returns_the_same_as___getitem__(self): base_dict = self._get_basedict({"k": [0, 1, 2]}) get_ = base_dict.get("k") getitem_ = base_dict["k"] - self.assertEqual(get_, getitem_) + assert get_ == getitem_ def test_get_default(self): base_dict = self._get_basedict({}) sentinel = object() - self.assertEqual(base_dict.get("new"), None) - self.assertIs(base_dict.get("new", sentinel), sentinel) + assert base_dict.get("new") == None + assert base_dict.get("new", sentinel) is sentinel def test___setitem___calls_mark_as_changed(self): base_dict = self._get_basedict({}) base_dict["k"] = "v" - self.assertEqual(base_dict._instance._changed_fields, ["my_name.k"]) - self.assertEqual(base_dict, {"k": "v"}) + assert base_dict._instance._changed_fields == ["my_name.k"] + assert base_dict == {"k": "v"} def test_update_calls_mark_as_changed(self): base_dict = self._get_basedict({}) base_dict.update({"k": "v"}) - self.assertEqual(base_dict._instance._changed_fields, ["my_name"]) + assert base_dict._instance._changed_fields == ["my_name"] def test___setattr____not_tracked_by_changes(self): base_dict = self._get_basedict({}) base_dict.a_new_attr = "test" - self.assertEqual(base_dict._instance._changed_fields, []) + assert base_dict._instance._changed_fields == [] def test___delattr____tracked_by_changes(self): # This is probably a bug as __setattr__ is not tracked @@ -146,7 +148,7 @@ class TestBaseDict(unittest.TestCase): base_dict = self._get_basedict({}) base_dict.a_new_attr = "test" del base_dict.a_new_attr - self.assertEqual(base_dict._instance._changed_fields, ["my_name.a_new_attr"]) + assert base_dict._instance._changed_fields == ["my_name.a_new_attr"] class TestBaseList(unittest.TestCase): @@ -167,14 +169,14 @@ class TestBaseList(unittest.TestCase): list_items = [True] doc = MyDoc() base_list = BaseList(list_items, instance=doc, name="my_name") - self.assertIsInstance(base_list._instance, Document) - self.assertEqual(base_list._name, "my_name") - self.assertEqual(base_list, list_items) + assert isinstance(base_list._instance, Document) + assert base_list._name == "my_name" + assert base_list == list_items def test___iter__(self): values = [True, False, True, False] base_list = BaseList(values, instance=None, name="my_name") - self.assertEqual(values, list(base_list)) + assert values == list(base_list) def test___iter___allow_modification_while_iterating_withou_error(self): # regular list allows for this, thus this subclass must comply to that @@ -185,9 +187,9 @@ class TestBaseList(unittest.TestCase): def test_append_calls_mark_as_changed(self): base_list = self._get_baselist([]) - self.assertFalse(base_list._instance._changed_fields) + assert not base_list._instance._changed_fields base_list.append(True) - self.assertEqual(base_list._instance._changed_fields, ["my_name"]) + assert base_list._instance._changed_fields == ["my_name"] def test_subclass_append(self): # Due to the way mark_as_changed_wrapper is implemented @@ -200,85 +202,85 @@ class TestBaseList(unittest.TestCase): def test___getitem__using_simple_index(self): base_list = self._get_baselist([0, 1, 2]) - self.assertEqual(base_list[0], 0) - self.assertEqual(base_list[1], 1) - self.assertEqual(base_list[-1], 2) + assert base_list[0] == 0 + assert base_list[1] == 1 + assert base_list[-1] == 2 def test___getitem__using_slice(self): base_list = self._get_baselist([0, 1, 2]) - self.assertEqual(base_list[1:3], [1, 2]) - self.assertEqual(base_list[0:3:2], [0, 2]) + assert base_list[1:3] == [1, 2] + assert base_list[0:3:2] == [0, 2] def test___getitem___using_slice_returns_list(self): # Bug: using slice does not properly handles the instance # and mark_as_changed behaviour. base_list = self._get_baselist([0, 1, 2]) sliced = base_list[1:3] - self.assertEqual(sliced, [1, 2]) - self.assertIsInstance(sliced, list) - self.assertEqual(base_list._instance._changed_fields, []) + assert sliced == [1, 2] + assert isinstance(sliced, list) + assert base_list._instance._changed_fields == [] def test___getitem__sublist_returns_BaseList_bound_to_instance(self): base_list = self._get_baselist([[1, 2], [3, 4]]) sub_list = base_list[0] - self.assertEqual(sub_list, [1, 2]) - self.assertIsInstance(sub_list, BaseList) - self.assertIs(sub_list._instance, base_list._instance) - self.assertEqual(sub_list._name, "my_name.0") - self.assertEqual(base_list._instance._changed_fields, []) + assert sub_list == [1, 2] + assert isinstance(sub_list, BaseList) + assert sub_list._instance is base_list._instance + assert sub_list._name == "my_name.0" + assert base_list._instance._changed_fields == [] # Challenge mark_as_changed from sublist sub_list[1] = None - self.assertEqual(base_list._instance._changed_fields, ["my_name.0.1"]) + assert base_list._instance._changed_fields == ["my_name.0.1"] def test___getitem__subdict_returns_BaseList_bound_to_instance(self): base_list = self._get_baselist([{"subk": "subv"}]) sub_dict = base_list[0] - self.assertEqual(sub_dict, {"subk": "subv"}) - self.assertIsInstance(sub_dict, BaseDict) - self.assertIs(sub_dict._instance, base_list._instance) - self.assertEqual(sub_dict._name, "my_name.0") - self.assertEqual(base_list._instance._changed_fields, []) + assert sub_dict == {"subk": "subv"} + assert isinstance(sub_dict, BaseDict) + assert sub_dict._instance is base_list._instance + assert sub_dict._name == "my_name.0" + assert base_list._instance._changed_fields == [] # Challenge mark_as_changed from subdict sub_dict["subk"] = None - self.assertEqual(base_list._instance._changed_fields, ["my_name.0.subk"]) + assert base_list._instance._changed_fields == ["my_name.0.subk"] def test_extend_calls_mark_as_changed(self): base_list = self._get_baselist([]) base_list.extend([True]) - self.assertEqual(base_list._instance._changed_fields, ["my_name"]) + assert base_list._instance._changed_fields == ["my_name"] def test_insert_calls_mark_as_changed(self): base_list = self._get_baselist([]) base_list.insert(0, True) - self.assertEqual(base_list._instance._changed_fields, ["my_name"]) + assert base_list._instance._changed_fields == ["my_name"] def test_remove_calls_mark_as_changed(self): base_list = self._get_baselist([True]) base_list.remove(True) - self.assertEqual(base_list._instance._changed_fields, ["my_name"]) + assert base_list._instance._changed_fields == ["my_name"] def test_remove_not_mark_as_changed_when_it_fails(self): base_list = self._get_baselist([True]) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): base_list.remove(False) - self.assertFalse(base_list._instance._changed_fields) + assert not base_list._instance._changed_fields def test_pop_calls_mark_as_changed(self): base_list = self._get_baselist([True]) base_list.pop() - self.assertEqual(base_list._instance._changed_fields, ["my_name"]) + assert base_list._instance._changed_fields == ["my_name"] def test_reverse_calls_mark_as_changed(self): base_list = self._get_baselist([True, False]) base_list.reverse() - self.assertEqual(base_list._instance._changed_fields, ["my_name"]) + assert base_list._instance._changed_fields == ["my_name"] def test___delitem___calls_mark_as_changed(self): base_list = self._get_baselist([True]) del base_list[0] - self.assertEqual(base_list._instance._changed_fields, ["my_name"]) + assert base_list._instance._changed_fields == ["my_name"] def test___setitem___calls_with_full_slice_mark_as_changed(self): base_list = self._get_baselist([]) @@ -286,8 +288,8 @@ class TestBaseList(unittest.TestCase): 0, 1, ] # Will use __setslice__ under py2 and __setitem__ under py3 - self.assertEqual(base_list._instance._changed_fields, ["my_name"]) - self.assertEqual(base_list, [0, 1]) + assert base_list._instance._changed_fields == ["my_name"] + assert base_list == [0, 1] def test___setitem___calls_with_partial_slice_mark_as_changed(self): base_list = self._get_baselist([0, 1, 2]) @@ -295,66 +297,66 @@ class TestBaseList(unittest.TestCase): 1, 0, ] # Will use __setslice__ under py2 and __setitem__ under py3 - self.assertEqual(base_list._instance._changed_fields, ["my_name"]) - self.assertEqual(base_list, [1, 0, 2]) + assert base_list._instance._changed_fields == ["my_name"] + assert base_list == [1, 0, 2] def test___setitem___calls_with_step_slice_mark_as_changed(self): base_list = self._get_baselist([0, 1, 2]) base_list[0:3:2] = [-1, -2] # uses __setitem__ in both py2 & 3 - self.assertEqual(base_list._instance._changed_fields, ["my_name"]) - self.assertEqual(base_list, [-1, 1, -2]) + assert base_list._instance._changed_fields == ["my_name"] + assert base_list == [-1, 1, -2] def test___setitem___with_slice(self): base_list = self._get_baselist([0, 1, 2, 3, 4, 5]) base_list[0:6:2] = [None, None, None] - self.assertEqual(base_list._instance._changed_fields, ["my_name"]) - self.assertEqual(base_list, [None, 1, None, 3, None, 5]) + assert base_list._instance._changed_fields == ["my_name"] + assert base_list == [None, 1, None, 3, None, 5] def test___setitem___item_0_calls_mark_as_changed(self): base_list = self._get_baselist([True]) base_list[0] = False - self.assertEqual(base_list._instance._changed_fields, ["my_name"]) - self.assertEqual(base_list, [False]) + assert base_list._instance._changed_fields == ["my_name"] + assert base_list == [False] def test___setitem___item_1_calls_mark_as_changed(self): base_list = self._get_baselist([True, True]) base_list[1] = False - self.assertEqual(base_list._instance._changed_fields, ["my_name.1"]) - self.assertEqual(base_list, [True, False]) + assert base_list._instance._changed_fields == ["my_name.1"] + assert base_list == [True, False] def test___delslice___calls_mark_as_changed(self): base_list = self._get_baselist([0, 1]) del base_list[0:1] - self.assertEqual(base_list._instance._changed_fields, ["my_name"]) - self.assertEqual(base_list, [1]) + assert base_list._instance._changed_fields == ["my_name"] + assert base_list == [1] def test___iadd___calls_mark_as_changed(self): base_list = self._get_baselist([True]) base_list += [False] - self.assertEqual(base_list._instance._changed_fields, ["my_name"]) + assert base_list._instance._changed_fields == ["my_name"] def test___imul___calls_mark_as_changed(self): base_list = self._get_baselist([True]) - self.assertEqual(base_list._instance._changed_fields, []) + assert base_list._instance._changed_fields == [] base_list *= 2 - self.assertEqual(base_list._instance._changed_fields, ["my_name"]) + assert base_list._instance._changed_fields == ["my_name"] def test_sort_calls_not_marked_as_changed_when_it_fails(self): base_list = self._get_baselist([True]) - with self.assertRaises(TypeError): + with pytest.raises(TypeError): base_list.sort(key=1) - self.assertEqual(base_list._instance._changed_fields, []) + assert base_list._instance._changed_fields == [] def test_sort_calls_mark_as_changed(self): base_list = self._get_baselist([True, False]) base_list.sort() - self.assertEqual(base_list._instance._changed_fields, ["my_name"]) + assert base_list._instance._changed_fields == ["my_name"] def test_sort_calls_with_key(self): base_list = self._get_baselist([1, 2, 11]) base_list.sort(key=lambda i: str(i)) - self.assertEqual(base_list, [1, 11, 2]) + assert base_list == [1, 11, 2] class TestStrictDict(unittest.TestCase): @@ -366,32 +368,32 @@ class TestStrictDict(unittest.TestCase): def test_init(self): d = self.dtype(a=1, b=1, c=1) - self.assertEqual((d.a, d.b, d.c), (1, 1, 1)) + assert (d.a, d.b, d.c) == (1, 1, 1) def test_iterkeys(self): d = self.dtype(a=1) - self.assertEqual(list(iterkeys(d)), ["a"]) + assert list(iterkeys(d)) == ["a"] def test_len(self): d = self.dtype(a=1) - self.assertEqual(len(d), 1) + assert len(d) == 1 def test_pop(self): d = self.dtype(a=1) - self.assertIn("a", d) + assert "a" in d d.pop("a") - self.assertNotIn("a", d) + assert "a" not in d def test_repr(self): d = self.dtype(a=1, b=2, c=3) - self.assertEqual(repr(d), '{"a": 1, "b": 2, "c": 3}') + assert repr(d) == '{"a": 1, "b": 2, "c": 3}' # make sure quotes are escaped properly d = self.dtype(a='"', b="'", c="") - self.assertEqual(repr(d), '{"a": \'"\', "b": "\'", "c": \'\'}') + assert repr(d) == '{"a": \'"\', "b": "\'", "c": \'\'}' def test_init_fails_on_nonexisting_attrs(self): - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): self.dtype(a=1, b=2, d=3) def test_eq(self): @@ -403,45 +405,46 @@ class TestStrictDict(unittest.TestCase): h = self.strict_dict_class(("a", "c", "b"))(a=1, b=1, c=1) i = self.strict_dict_class(("a", "c", "b"))(a=1, b=1, c=2) - self.assertEqual(d, dd) - self.assertNotEqual(d, e) - self.assertNotEqual(d, f) - self.assertNotEqual(d, g) - self.assertNotEqual(f, d) - self.assertEqual(d, h) - self.assertNotEqual(d, i) + assert d == dd + assert d != e + assert d != f + assert d != g + assert f != d + assert d == h + assert d != i def test_setattr_getattr(self): d = self.dtype() d.a = 1 - self.assertEqual(d.a, 1) - self.assertRaises(AttributeError, getattr, d, "b") + assert d.a == 1 + with pytest.raises(AttributeError): + getattr(d, "b") def test_setattr_raises_on_nonexisting_attr(self): d = self.dtype() - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): d.x = 1 def test_setattr_getattr_special(self): d = self.strict_dict_class(["items"]) d.items = 1 - self.assertEqual(d.items, 1) + assert d.items == 1 def test_get(self): d = self.dtype(a=1) - self.assertEqual(d.get("a"), 1) - self.assertEqual(d.get("b", "bla"), "bla") + assert d.get("a") == 1 + assert d.get("b", "bla") == "bla" def test_items(self): d = self.dtype(a=1) - self.assertEqual(d.items(), [("a", 1)]) + assert d.items() == [("a", 1)] d = self.dtype(a=1, b=2) - self.assertEqual(d.items(), [("a", 1), ("b", 2)]) + assert d.items() == [("a", 1), ("b", 2)] def test_mappings_protocol(self): d = self.dtype(a=1, b=2) - self.assertEqual(dict(d), {"a": 1, "b": 2}) - self.assertEqual(dict(**d), {"a": 1, "b": 2}) + assert dict(d) == {"a": 1, "b": 2} + assert dict(**d) == {"a": 1, "b": 2} if __name__ == "__main__": diff --git a/tests/test_dereference.py b/tests/test_dereference.py index 4730e2e3..b9d92883 100644 --- a/tests/test_dereference.py +++ b/tests/test_dereference.py @@ -42,37 +42,37 @@ class FieldTest(unittest.TestCase): group.save() with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 group_obj = Group.objects.first() - self.assertEqual(q, 1) + assert q == 1 len(group_obj._data["members"]) - self.assertEqual(q, 1) + assert q == 1 len(group_obj.members) - self.assertEqual(q, 2) + assert q == 2 [m for m in group_obj.members] - self.assertEqual(q, 2) + assert q == 2 # Document select_related with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 group_obj = Group.objects.first().select_related() - self.assertEqual(q, 2) + assert q == 2 [m for m in group_obj.members] - self.assertEqual(q, 2) + assert q == 2 # Queryset select_related with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 group_objs = Group.objects.select_related() - self.assertEqual(q, 2) + assert q == 2 for group_obj in group_objs: [m for m in group_obj.members] - self.assertEqual(q, 2) + assert q == 2 User.drop_collection() Group.drop_collection() @@ -99,40 +99,40 @@ class FieldTest(unittest.TestCase): group.reload() # Confirm reload works with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 group_obj = Group.objects.first() - self.assertEqual(q, 1) + assert q == 1 [m for m in group_obj.members] - self.assertEqual(q, 2) - self.assertTrue(group_obj._data["members"]._dereferenced) + assert q == 2 + assert group_obj._data["members"]._dereferenced # verifies that no additional queries gets executed # if we re-iterate over the ListField once it is # dereferenced [m for m in group_obj.members] - self.assertEqual(q, 2) - self.assertTrue(group_obj._data["members"]._dereferenced) + assert q == 2 + assert group_obj._data["members"]._dereferenced # Document select_related with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 group_obj = Group.objects.first().select_related() - self.assertEqual(q, 2) + assert q == 2 [m for m in group_obj.members] - self.assertEqual(q, 2) + assert q == 2 # Queryset select_related with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 group_objs = Group.objects.select_related() - self.assertEqual(q, 2) + assert q == 2 for group_obj in group_objs: [m for m in group_obj.members] - self.assertEqual(q, 2) + assert q == 2 def test_list_item_dereference_orphan_dbref(self): """Ensure that orphan DBRef items in ListFields are dereferenced. @@ -159,21 +159,21 @@ class FieldTest(unittest.TestCase): # Group.members list is an orphan DBRef User.objects[0].delete() with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 group_obj = Group.objects.first() - self.assertEqual(q, 1) + assert q == 1 [m for m in group_obj.members] - self.assertEqual(q, 2) - self.assertTrue(group_obj._data["members"]._dereferenced) + assert q == 2 + assert group_obj._data["members"]._dereferenced # verifies that no additional queries gets executed # if we re-iterate over the ListField once it is # dereferenced [m for m in group_obj.members] - self.assertEqual(q, 2) - self.assertTrue(group_obj._data["members"]._dereferenced) + assert q == 2 + assert group_obj._data["members"]._dereferenced User.drop_collection() Group.drop_collection() @@ -197,8 +197,8 @@ class FieldTest(unittest.TestCase): Group(members=User.objects).save() group = Group.objects.first() - self.assertEqual(Group._get_collection().find_one()["members"], [1]) - self.assertEqual(group.members, [user]) + assert Group._get_collection().find_one()["members"] == [1] + assert group.members == [user] def test_handle_old_style_references(self): """Ensure that DBRef items in ListFields are dereferenced. @@ -231,8 +231,8 @@ class FieldTest(unittest.TestCase): group.save() group = Group.objects.first() - self.assertEqual(group.members[0].name, "user 1") - self.assertEqual(group.members[-1].name, "String!") + assert group.members[0].name == "user 1" + assert group.members[-1].name == "String!" def test_migrate_references(self): """Example of migrating ReferenceField storage @@ -253,12 +253,12 @@ class FieldTest(unittest.TestCase): group = Group(author=user, members=[user]).save() raw_data = Group._get_collection().find_one() - self.assertIsInstance(raw_data["author"], DBRef) - self.assertIsInstance(raw_data["members"][0], DBRef) + assert isinstance(raw_data["author"], DBRef) + assert isinstance(raw_data["members"][0], DBRef) group = Group.objects.first() - self.assertEqual(group.author, user) - self.assertEqual(group.members, [user]) + assert group.author == user + assert group.members == [user] # Migrate the model definition class Group(Document): @@ -273,12 +273,12 @@ class FieldTest(unittest.TestCase): g.save() group = Group.objects.first() - self.assertEqual(group.author, user) - self.assertEqual(group.members, [user]) + assert group.author == user + assert group.members == [user] raw_data = Group._get_collection().find_one() - self.assertIsInstance(raw_data["author"], ObjectId) - self.assertIsInstance(raw_data["members"][0], ObjectId) + assert isinstance(raw_data["author"], ObjectId) + assert isinstance(raw_data["members"][0], ObjectId) def test_recursive_reference(self): """Ensure that ReferenceFields can reference their own documents. @@ -309,43 +309,43 @@ class FieldTest(unittest.TestCase): Employee(name="Funky Gibbon", boss=bill, friends=friends).save() with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 peter = Employee.objects.with_id(peter.id) - self.assertEqual(q, 1) + assert q == 1 peter.boss - self.assertEqual(q, 2) + assert q == 2 peter.friends - self.assertEqual(q, 3) + assert q == 3 # Document select_related with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 peter = Employee.objects.with_id(peter.id).select_related() - self.assertEqual(q, 2) + assert q == 2 - self.assertEqual(peter.boss, bill) - self.assertEqual(q, 2) + assert peter.boss == bill + assert q == 2 - self.assertEqual(peter.friends, friends) - self.assertEqual(q, 2) + assert peter.friends == friends + assert q == 2 # Queryset select_related with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 employees = Employee.objects(boss=bill).select_related() - self.assertEqual(q, 2) + assert q == 2 for employee in employees: - self.assertEqual(employee.boss, bill) - self.assertEqual(q, 2) + assert employee.boss == bill + assert q == 2 - self.assertEqual(employee.friends, friends) - self.assertEqual(q, 2) + assert employee.friends == friends + assert q == 2 def test_list_of_lists_of_references(self): class User(Document): @@ -366,10 +366,10 @@ class FieldTest(unittest.TestCase): u3 = User.objects.create(name="u3") SimpleList.objects.create(users=[u1, u2, u3]) - self.assertEqual(SimpleList.objects.all()[0].users, [u1, u2, u3]) + assert SimpleList.objects.all()[0].users == [u1, u2, u3] Post.objects.create(user_lists=[[u1, u2], [u3]]) - self.assertEqual(Post.objects.all()[0].user_lists, [[u1, u2], [u3]]) + assert Post.objects.all()[0].user_lists == [[u1, u2], [u3]] def test_circular_reference(self): """Ensure you can handle circular references @@ -403,9 +403,7 @@ class FieldTest(unittest.TestCase): daughter.relations.append(self_rel) daughter.save() - self.assertEqual( - "[, ]", "%s" % Person.objects() - ) + assert "[, ]" == "%s" % Person.objects() def test_circular_reference_on_self(self): """Ensure you can handle circular references @@ -432,9 +430,7 @@ class FieldTest(unittest.TestCase): daughter.relations.append(daughter) daughter.save() - self.assertEqual( - "[, ]", "%s" % Person.objects() - ) + assert "[, ]" == "%s" % Person.objects() def test_circular_tree_reference(self): """Ensure you can handle circular references with more than one level @@ -473,9 +469,9 @@ class FieldTest(unittest.TestCase): anna.other.name = "Anna's friends" anna.save() - self.assertEqual( - "[, , , ]", - "%s" % Person.objects(), + assert ( + "[, , , ]" + == "%s" % Person.objects() ) def test_generic_reference(self): @@ -516,52 +512,52 @@ class FieldTest(unittest.TestCase): group.save() with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 group_obj = Group.objects.first() - self.assertEqual(q, 1) + assert q == 1 [m for m in group_obj.members] - self.assertEqual(q, 4) + assert q == 4 [m for m in group_obj.members] - self.assertEqual(q, 4) + assert q == 4 for m in group_obj.members: - self.assertIn("User", m.__class__.__name__) + assert "User" in m.__class__.__name__ # Document select_related with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 group_obj = Group.objects.first().select_related() - self.assertEqual(q, 4) + assert q == 4 [m for m in group_obj.members] - self.assertEqual(q, 4) + assert q == 4 [m for m in group_obj.members] - self.assertEqual(q, 4) + assert q == 4 for m in group_obj.members: - self.assertIn("User", m.__class__.__name__) + assert "User" in m.__class__.__name__ # Queryset select_related with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 group_objs = Group.objects.select_related() - self.assertEqual(q, 4) + assert q == 4 for group_obj in group_objs: [m for m in group_obj.members] - self.assertEqual(q, 4) + assert q == 4 [m for m in group_obj.members] - self.assertEqual(q, 4) + assert q == 4 for m in group_obj.members: - self.assertIn("User", m.__class__.__name__) + assert "User" in m.__class__.__name__ def test_generic_reference_orphan_dbref(self): """Ensure that generic orphan DBRef items in ListFields are dereferenced. @@ -604,18 +600,18 @@ class FieldTest(unittest.TestCase): # an orphan DBRef in the GenericReference ListField UserA.objects[0].delete() with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 group_obj = Group.objects.first() - self.assertEqual(q, 1) + assert q == 1 [m for m in group_obj.members] - self.assertEqual(q, 4) - self.assertTrue(group_obj._data["members"]._dereferenced) + assert q == 4 + assert group_obj._data["members"]._dereferenced [m for m in group_obj.members] - self.assertEqual(q, 4) - self.assertTrue(group_obj._data["members"]._dereferenced) + assert q == 4 + assert group_obj._data["members"]._dereferenced UserA.drop_collection() UserB.drop_collection() @@ -660,52 +656,52 @@ class FieldTest(unittest.TestCase): group.save() with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 group_obj = Group.objects.first() - self.assertEqual(q, 1) + assert q == 1 [m for m in group_obj.members] - self.assertEqual(q, 4) + assert q == 4 [m for m in group_obj.members] - self.assertEqual(q, 4) + assert q == 4 for m in group_obj.members: - self.assertIn("User", m.__class__.__name__) + assert "User" in m.__class__.__name__ # Document select_related with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 group_obj = Group.objects.first().select_related() - self.assertEqual(q, 4) + assert q == 4 [m for m in group_obj.members] - self.assertEqual(q, 4) + assert q == 4 [m for m in group_obj.members] - self.assertEqual(q, 4) + assert q == 4 for m in group_obj.members: - self.assertIn("User", m.__class__.__name__) + assert "User" in m.__class__.__name__ # Queryset select_related with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 group_objs = Group.objects.select_related() - self.assertEqual(q, 4) + assert q == 4 for group_obj in group_objs: [m for m in group_obj.members] - self.assertEqual(q, 4) + assert q == 4 [m for m in group_obj.members] - self.assertEqual(q, 4) + assert q == 4 for m in group_obj.members: - self.assertIn("User", m.__class__.__name__) + assert "User" in m.__class__.__name__ UserA.drop_collection() UserB.drop_collection() @@ -735,43 +731,43 @@ class FieldTest(unittest.TestCase): group.save() with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 group_obj = Group.objects.first() - self.assertEqual(q, 1) + assert q == 1 [m for m in group_obj.members] - self.assertEqual(q, 2) + assert q == 2 for k, m in iteritems(group_obj.members): - self.assertIsInstance(m, User) + assert isinstance(m, User) # Document select_related with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 group_obj = Group.objects.first().select_related() - self.assertEqual(q, 2) + assert q == 2 [m for m in group_obj.members] - self.assertEqual(q, 2) + assert q == 2 for k, m in iteritems(group_obj.members): - self.assertIsInstance(m, User) + assert isinstance(m, User) # Queryset select_related with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 group_objs = Group.objects.select_related() - self.assertEqual(q, 2) + assert q == 2 for group_obj in group_objs: [m for m in group_obj.members] - self.assertEqual(q, 2) + assert q == 2 for k, m in iteritems(group_obj.members): - self.assertIsInstance(m, User) + assert isinstance(m, User) User.drop_collection() Group.drop_collection() @@ -813,65 +809,65 @@ class FieldTest(unittest.TestCase): group.save() with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 group_obj = Group.objects.first() - self.assertEqual(q, 1) + assert q == 1 [m for m in group_obj.members] - self.assertEqual(q, 4) + assert q == 4 [m for m in group_obj.members] - self.assertEqual(q, 4) + assert q == 4 for k, m in iteritems(group_obj.members): - self.assertIn("User", m.__class__.__name__) + assert "User" in m.__class__.__name__ # Document select_related with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 group_obj = Group.objects.first().select_related() - self.assertEqual(q, 4) + assert q == 4 [m for m in group_obj.members] - self.assertEqual(q, 4) + assert q == 4 [m for m in group_obj.members] - self.assertEqual(q, 4) + assert q == 4 for k, m in iteritems(group_obj.members): - self.assertIn("User", m.__class__.__name__) + assert "User" in m.__class__.__name__ # Queryset select_related with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 group_objs = Group.objects.select_related() - self.assertEqual(q, 4) + assert q == 4 for group_obj in group_objs: [m for m in group_obj.members] - self.assertEqual(q, 4) + assert q == 4 [m for m in group_obj.members] - self.assertEqual(q, 4) + assert q == 4 for k, m in iteritems(group_obj.members): - self.assertIn("User", m.__class__.__name__) + assert "User" in m.__class__.__name__ Group.objects.delete() Group().save() with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 group_obj = Group.objects.first() - self.assertEqual(q, 1) + assert q == 1 [m for m in group_obj.members] - self.assertEqual(q, 1) - self.assertEqual(group_obj.members, {}) + assert q == 1 + assert group_obj.members == {} UserA.drop_collection() UserB.drop_collection() @@ -903,52 +899,52 @@ class FieldTest(unittest.TestCase): group.save() with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 group_obj = Group.objects.first() - self.assertEqual(q, 1) + assert q == 1 [m for m in group_obj.members] - self.assertEqual(q, 2) + assert q == 2 [m for m in group_obj.members] - self.assertEqual(q, 2) + assert q == 2 for k, m in iteritems(group_obj.members): - self.assertIsInstance(m, UserA) + assert isinstance(m, UserA) # Document select_related with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 group_obj = Group.objects.first().select_related() - self.assertEqual(q, 2) + assert q == 2 [m for m in group_obj.members] - self.assertEqual(q, 2) + assert q == 2 [m for m in group_obj.members] - self.assertEqual(q, 2) + assert q == 2 for k, m in iteritems(group_obj.members): - self.assertIsInstance(m, UserA) + assert isinstance(m, UserA) # Queryset select_related with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 group_objs = Group.objects.select_related() - self.assertEqual(q, 2) + assert q == 2 for group_obj in group_objs: [m for m in group_obj.members] - self.assertEqual(q, 2) + assert q == 2 [m for m in group_obj.members] - self.assertEqual(q, 2) + assert q == 2 for k, m in iteritems(group_obj.members): - self.assertIsInstance(m, UserA) + assert isinstance(m, UserA) UserA.drop_collection() Group.drop_collection() @@ -990,64 +986,64 @@ class FieldTest(unittest.TestCase): group.save() with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 group_obj = Group.objects.first() - self.assertEqual(q, 1) + assert q == 1 [m for m in group_obj.members] - self.assertEqual(q, 4) + assert q == 4 [m for m in group_obj.members] - self.assertEqual(q, 4) + assert q == 4 for k, m in iteritems(group_obj.members): - self.assertIn("User", m.__class__.__name__) + assert "User" in m.__class__.__name__ # Document select_related with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 group_obj = Group.objects.first().select_related() - self.assertEqual(q, 4) + assert q == 4 [m for m in group_obj.members] - self.assertEqual(q, 4) + assert q == 4 [m for m in group_obj.members] - self.assertEqual(q, 4) + assert q == 4 for k, m in iteritems(group_obj.members): - self.assertIn("User", m.__class__.__name__) + assert "User" in m.__class__.__name__ # Queryset select_related with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 group_objs = Group.objects.select_related() - self.assertEqual(q, 4) + assert q == 4 for group_obj in group_objs: [m for m in group_obj.members] - self.assertEqual(q, 4) + assert q == 4 [m for m in group_obj.members] - self.assertEqual(q, 4) + assert q == 4 for k, m in iteritems(group_obj.members): - self.assertIn("User", m.__class__.__name__) + assert "User" in m.__class__.__name__ Group.objects.delete() Group().save() with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 group_obj = Group.objects.first() - self.assertEqual(q, 1) + assert q == 1 [m for m in group_obj.members] - self.assertEqual(q, 1) + assert q == 1 UserA.drop_collection() UserB.drop_collection() @@ -1075,8 +1071,8 @@ class FieldTest(unittest.TestCase): root.save() root = root.reload() - self.assertEqual(root.children, [company]) - self.assertEqual(company.parents, [root]) + assert root.children == [company] + assert company.parents == [root] def test_dict_in_dbref_instance(self): class Person(Document): @@ -1102,8 +1098,8 @@ class FieldTest(unittest.TestCase): room_101.save() room = Room.objects.first().select_related() - self.assertEqual(room.staffs_with_position[0]["staff"], sarah) - self.assertEqual(room.staffs_with_position[1]["staff"], bob) + assert room.staffs_with_position[0]["staff"] == sarah + assert room.staffs_with_position[1]["staff"] == bob def test_document_reload_no_inheritance(self): class Foo(Document): @@ -1133,8 +1129,8 @@ class FieldTest(unittest.TestCase): foo.save() foo.reload() - self.assertEqual(type(foo.bar), Bar) - self.assertEqual(type(foo.baz), Baz) + assert type(foo.bar) == Bar + assert type(foo.baz) == Baz def test_document_reload_reference_integrity(self): """ @@ -1166,13 +1162,13 @@ class FieldTest(unittest.TestCase): concurrent_change_user = User.objects.get(id=1) concurrent_change_user.name = "new-name" concurrent_change_user.save() - self.assertNotEqual(user.name, "new-name") + assert user.name != "new-name" msg = Message.objects.get(id=1) msg.reload() - self.assertEqual(msg.topic, topic) - self.assertEqual(msg.author, user) - self.assertEqual(msg.author.name, "new-name") + assert msg.topic == topic + assert msg.author == user + assert msg.author.name == "new-name" def test_list_lookup_not_checked_in_map(self): """Ensure we dereference list data correctly @@ -1194,8 +1190,8 @@ class FieldTest(unittest.TestCase): Message(id=1, comments=[c1, c2]).save() msg = Message.objects.get(id=1) - self.assertEqual(0, msg.comments[0].id) - self.assertEqual(1, msg.comments[1].id) + assert 0 == msg.comments[0].id + assert 1 == msg.comments[1].id def test_list_item_dereference_dref_false_save_doesnt_cause_extra_queries(self): """Ensure that DBRef items in ListFields are dereferenced. @@ -1217,15 +1213,15 @@ class FieldTest(unittest.TestCase): Group(name="Test", members=User.objects).save() with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 group_obj = Group.objects.first() - self.assertEqual(q, 1) + assert q == 1 group_obj.name = "new test" group_obj.save() - self.assertEqual(q, 2) + assert q == 2 def test_list_item_dereference_dref_true_save_doesnt_cause_extra_queries(self): """Ensure that DBRef items in ListFields are dereferenced. @@ -1247,15 +1243,15 @@ class FieldTest(unittest.TestCase): Group(name="Test", members=User.objects).save() with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 group_obj = Group.objects.first() - self.assertEqual(q, 1) + assert q == 1 group_obj.name = "new test" group_obj.save() - self.assertEqual(q, 2) + assert q == 2 def test_generic_reference_save_doesnt_cause_extra_queries(self): class UserA(Document): @@ -1287,15 +1283,15 @@ class FieldTest(unittest.TestCase): Group(name="test", members=members).save() with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 group_obj = Group.objects.first() - self.assertEqual(q, 1) + assert q == 1 group_obj.name = "new test" group_obj.save() - self.assertEqual(q, 2) + assert q == 2 def test_objectid_reference_across_databases(self): # mongoenginetest - Is default connection alias from setUp() @@ -1319,10 +1315,10 @@ class FieldTest(unittest.TestCase): # Can't use query_counter across databases - so test the _data object book = Book.objects.first() - self.assertNotIsInstance(book._data["author"], User) + assert not isinstance(book._data["author"], User) book.select_related() - self.assertIsInstance(book._data["author"], User) + assert isinstance(book._data["author"], User) def test_non_ascii_pk(self): """ @@ -1346,7 +1342,7 @@ class FieldTest(unittest.TestCase): BrandGroup(title="top_brands", brands=[brand1, brand2]).save() brand_groups = BrandGroup.objects().all() - self.assertEqual(2, len([brand for bg in brand_groups for brand in bg.brands])) + assert 2 == len([brand for bg in brand_groups for brand in bg.brands]) def test_dereferencing_embedded_listfield_referencefield(self): class Tag(Document): @@ -1370,7 +1366,7 @@ class FieldTest(unittest.TestCase): Page(tags=[tag], posts=[post]).save() page = Page.objects.first() - self.assertEqual(page.tags[0], page.posts[0].tags[0]) + assert page.tags[0] == page.posts[0].tags[0] def test_select_related_follows_embedded_referencefields(self): class Song(Document): @@ -1390,12 +1386,12 @@ class FieldTest(unittest.TestCase): playlist = Playlist.objects.create(items=items) with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 playlist = Playlist.objects.first().select_related() songs = [item.song for item in playlist.items] - self.assertEqual(q, 2) + assert q == 2 if __name__ == "__main__": diff --git a/tests/test_replicaset_connection.py b/tests/test_replicaset_connection.py index e92f3d09..c1ea407c 100644 --- a/tests/test_replicaset_connection.py +++ b/tests/test_replicaset_connection.py @@ -39,7 +39,7 @@ class ConnectionTest(unittest.TestCase): # really??? return - self.assertEqual(conn.read_preference, READ_PREF) + assert conn.read_preference == READ_PREF if __name__ == "__main__": diff --git a/tests/test_signals.py b/tests/test_signals.py index 1d0607d7..b217712b 100644 --- a/tests/test_signals.py +++ b/tests/test_signals.py @@ -245,7 +245,7 @@ class SignalTests(unittest.TestCase): # Note that there is a chance that the following assert fails in case # some receivers (eventually created in other tests) # gets garbage collected (https://pythonhosted.org/blinker/#blinker.base.Signal.connect) - self.assertEqual(self.pre_signals, post_signals) + assert self.pre_signals == post_signals def test_model_signals(self): """ Model saves should throw some signals. """ @@ -267,97 +267,76 @@ class SignalTests(unittest.TestCase): self.get_signal_output(lambda: None) # eliminate signal output a1 = self.Author.objects(name="Bill Shakespeare")[0] - self.assertEqual( - self.get_signal_output(create_author), - [ - "pre_init signal, Author", - {"name": "Bill Shakespeare"}, - "post_init signal, Bill Shakespeare, document._created = True", - ], - ) + assert self.get_signal_output(create_author) == [ + "pre_init signal, Author", + {"name": "Bill Shakespeare"}, + "post_init signal, Bill Shakespeare, document._created = True", + ] a1 = self.Author(name="Bill Shakespeare") - self.assertEqual( - self.get_signal_output(a1.save), - [ - "pre_save signal, Bill Shakespeare", - {}, - "pre_save_post_validation signal, Bill Shakespeare", - "Is created", - {}, - "post_save signal, Bill Shakespeare", - "post_save dirty keys, ['name']", - "Is created", - {}, - ], - ) + assert self.get_signal_output(a1.save) == [ + "pre_save signal, Bill Shakespeare", + {}, + "pre_save_post_validation signal, Bill Shakespeare", + "Is created", + {}, + "post_save signal, Bill Shakespeare", + "post_save dirty keys, ['name']", + "Is created", + {}, + ] a1.reload() a1.name = "William Shakespeare" - self.assertEqual( - self.get_signal_output(a1.save), - [ - "pre_save signal, William Shakespeare", - {}, - "pre_save_post_validation signal, William Shakespeare", - "Is updated", - {}, - "post_save signal, William Shakespeare", - "post_save dirty keys, ['name']", - "Is updated", - {}, - ], - ) + assert self.get_signal_output(a1.save) == [ + "pre_save signal, William Shakespeare", + {}, + "pre_save_post_validation signal, William Shakespeare", + "Is updated", + {}, + "post_save signal, William Shakespeare", + "post_save dirty keys, ['name']", + "Is updated", + {}, + ] - self.assertEqual( - self.get_signal_output(a1.delete), - [ - "pre_delete signal, William Shakespeare", - {}, - "post_delete signal, William Shakespeare", - {}, - ], - ) + assert self.get_signal_output(a1.delete) == [ + "pre_delete signal, William Shakespeare", + {}, + "post_delete signal, William Shakespeare", + {}, + ] - self.assertEqual( - self.get_signal_output(load_existing_author), - [ - "pre_init signal, Author", - {"id": 2, "name": "Bill Shakespeare"}, - "post_init signal, Bill Shakespeare, document._created = False", - ], - ) + assert self.get_signal_output(load_existing_author) == [ + "pre_init signal, Author", + {"id": 2, "name": "Bill Shakespeare"}, + "post_init signal, Bill Shakespeare, document._created = False", + ] - self.assertEqual( - self.get_signal_output(bulk_create_author_with_load), - [ - "pre_init signal, Author", - {"name": "Bill Shakespeare"}, - "post_init signal, Bill Shakespeare, document._created = True", - "pre_bulk_insert signal, []", - {}, - "pre_init signal, Author", - {"id": 3, "name": "Bill Shakespeare"}, - "post_init signal, Bill Shakespeare, document._created = False", - "post_bulk_insert signal, []", - "Is loaded", - {}, - ], - ) + assert self.get_signal_output(bulk_create_author_with_load) == [ + "pre_init signal, Author", + {"name": "Bill Shakespeare"}, + "post_init signal, Bill Shakespeare, document._created = True", + "pre_bulk_insert signal, []", + {}, + "pre_init signal, Author", + {"id": 3, "name": "Bill Shakespeare"}, + "post_init signal, Bill Shakespeare, document._created = False", + "post_bulk_insert signal, []", + "Is loaded", + {}, + ] - self.assertEqual( - self.get_signal_output(bulk_create_author_without_load), - [ - "pre_init signal, Author", - {"name": "Bill Shakespeare"}, - "post_init signal, Bill Shakespeare, document._created = True", - "pre_bulk_insert signal, []", - {}, - "post_bulk_insert signal, []", - "Not loaded", - {}, - ], - ) + assert self.get_signal_output(bulk_create_author_without_load) == [ + "pre_init signal, Author", + {"name": "Bill Shakespeare"}, + "post_init signal, Bill Shakespeare, document._created = True", + "pre_bulk_insert signal, []", + {}, + "post_bulk_insert signal, []", + "Not loaded", + {}, + ] def test_signal_kwargs(self): """ Make sure signal_kwargs is passed to signals calls. """ @@ -367,83 +346,74 @@ class SignalTests(unittest.TestCase): a.save(signal_kwargs={"live": True, "die": False}) a.delete(signal_kwargs={"live": False, "die": True}) - self.assertEqual( - self.get_signal_output(live_and_let_die), - [ - "pre_init signal, Author", - {"name": "Bill Shakespeare"}, - "post_init signal, Bill Shakespeare, document._created = True", - "pre_save signal, Bill Shakespeare", - {"die": False, "live": True}, - "pre_save_post_validation signal, Bill Shakespeare", - "Is created", - {"die": False, "live": True}, - "post_save signal, Bill Shakespeare", - "post_save dirty keys, ['name']", - "Is created", - {"die": False, "live": True}, - "pre_delete signal, Bill Shakespeare", - {"die": True, "live": False}, - "post_delete signal, Bill Shakespeare", - {"die": True, "live": False}, - ], - ) + assert self.get_signal_output(live_and_let_die) == [ + "pre_init signal, Author", + {"name": "Bill Shakespeare"}, + "post_init signal, Bill Shakespeare, document._created = True", + "pre_save signal, Bill Shakespeare", + {"die": False, "live": True}, + "pre_save_post_validation signal, Bill Shakespeare", + "Is created", + {"die": False, "live": True}, + "post_save signal, Bill Shakespeare", + "post_save dirty keys, ['name']", + "Is created", + {"die": False, "live": True}, + "pre_delete signal, Bill Shakespeare", + {"die": True, "live": False}, + "post_delete signal, Bill Shakespeare", + {"die": True, "live": False}, + ] def bulk_create_author(): a1 = self.Author(name="Bill Shakespeare") self.Author.objects.insert([a1], signal_kwargs={"key": True}) - self.assertEqual( - self.get_signal_output(bulk_create_author), - [ - "pre_init signal, Author", - {"name": "Bill Shakespeare"}, - "post_init signal, Bill Shakespeare, document._created = True", - "pre_bulk_insert signal, []", - {"key": True}, - "pre_init signal, Author", - {"id": 2, "name": "Bill Shakespeare"}, - "post_init signal, Bill Shakespeare, document._created = False", - "post_bulk_insert signal, []", - "Is loaded", - {"key": True}, - ], - ) + assert self.get_signal_output(bulk_create_author) == [ + "pre_init signal, Author", + {"name": "Bill Shakespeare"}, + "post_init signal, Bill Shakespeare, document._created = True", + "pre_bulk_insert signal, []", + {"key": True}, + "pre_init signal, Author", + {"id": 2, "name": "Bill Shakespeare"}, + "post_init signal, Bill Shakespeare, document._created = False", + "post_bulk_insert signal, []", + "Is loaded", + {"key": True}, + ] def test_queryset_delete_signals(self): """ Queryset delete should throw some signals. """ self.Another(name="Bill Shakespeare").save() - self.assertEqual( - self.get_signal_output(self.Another.objects.delete), - [ - "pre_delete signal, Bill Shakespeare", - {}, - "post_delete signal, Bill Shakespeare", - {}, - ], - ) + assert self.get_signal_output(self.Another.objects.delete) == [ + "pre_delete signal, Bill Shakespeare", + {}, + "post_delete signal, Bill Shakespeare", + {}, + ] def test_signals_with_explicit_doc_ids(self): """ Model saves must have a created flag the first time.""" ei = self.ExplicitId(id=123) # post save must received the created flag, even if there's already # an object id present - self.assertEqual(self.get_signal_output(ei.save), ["Is created"]) + assert self.get_signal_output(ei.save) == ["Is created"] # second time, it must be an update - self.assertEqual(self.get_signal_output(ei.save), ["Is updated"]) + assert self.get_signal_output(ei.save) == ["Is updated"] def test_signals_with_switch_collection(self): ei = self.ExplicitId(id=123) ei.switch_collection("explicit__1") - self.assertEqual(self.get_signal_output(ei.save), ["Is created"]) + assert self.get_signal_output(ei.save) == ["Is created"] ei.switch_collection("explicit__1") - self.assertEqual(self.get_signal_output(ei.save), ["Is updated"]) + assert self.get_signal_output(ei.save) == ["Is updated"] ei.switch_collection("explicit__1", keep_created=False) - self.assertEqual(self.get_signal_output(ei.save), ["Is created"]) + assert self.get_signal_output(ei.save) == ["Is created"] ei.switch_collection("explicit__1", keep_created=False) - self.assertEqual(self.get_signal_output(ei.save), ["Is created"]) + assert self.get_signal_output(ei.save) == ["Is created"] def test_signals_with_switch_db(self): connect("mongoenginetest") @@ -451,14 +421,14 @@ class SignalTests(unittest.TestCase): ei = self.ExplicitId(id=123) ei.switch_db("testdb-1") - self.assertEqual(self.get_signal_output(ei.save), ["Is created"]) + assert self.get_signal_output(ei.save) == ["Is created"] ei.switch_db("testdb-1") - self.assertEqual(self.get_signal_output(ei.save), ["Is updated"]) + assert self.get_signal_output(ei.save) == ["Is updated"] ei.switch_db("testdb-1", keep_created=False) - self.assertEqual(self.get_signal_output(ei.save), ["Is created"]) + assert self.get_signal_output(ei.save) == ["Is created"] ei.switch_db("testdb-1", keep_created=False) - self.assertEqual(self.get_signal_output(ei.save), ["Is created"]) + assert self.get_signal_output(ei.save) == ["Is created"] def test_signals_bulk_insert(self): def bulk_set_active_post(): @@ -470,16 +440,13 @@ class SignalTests(unittest.TestCase): self.Post.objects.insert(posts) results = self.get_signal_output(bulk_set_active_post) - self.assertEqual( - results, - [ - "pre_bulk_insert signal, [(, {'active': False}), (, {'active': False}), (, {'active': False})]", - {}, - "post_bulk_insert signal, [(, {'active': True}), (, {'active': True}), (, {'active': True})]", - "Is loaded", - {}, - ], - ) + assert results == [ + "pre_bulk_insert signal, [(, {'active': False}), (, {'active': False}), (, {'active': False})]", + {}, + "post_bulk_insert signal, [(, {'active': True}), (, {'active': True}), (, {'active': True})]", + "Is loaded", + {}, + ] if __name__ == "__main__": diff --git a/tests/test_utils.py b/tests/test_utils.py index 897c19b2..ccb44aac 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -2,6 +2,7 @@ import re import unittest from mongoengine.base.utils import LazyRegexCompiler +import pytest signal_output = [] @@ -12,21 +13,21 @@ class LazyRegexCompilerTest(unittest.TestCase): EMAIL_REGEX = LazyRegexCompiler("@", flags=32) descriptor = UserEmail.__dict__["EMAIL_REGEX"] - self.assertIsNone(descriptor._compiled_regex) + assert descriptor._compiled_regex is None regex = UserEmail.EMAIL_REGEX - self.assertEqual(regex, re.compile("@", flags=32)) - self.assertEqual(regex.search("user@domain.com").group(), "@") + assert regex == re.compile("@", flags=32) + assert regex.search("user@domain.com").group() == "@" user_email = UserEmail() - self.assertIs(user_email.EMAIL_REGEX, UserEmail.EMAIL_REGEX) + assert user_email.EMAIL_REGEX is UserEmail.EMAIL_REGEX def test_lazy_regex_compiler_verify_cannot_set_descriptor_on_instance(self): class UserEmail(object): EMAIL_REGEX = LazyRegexCompiler("@") user_email = UserEmail() - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): user_email.EMAIL_REGEX = re.compile("@") def test_lazy_regex_compiler_verify_can_override_class_attr(self): @@ -34,6 +35,4 @@ class LazyRegexCompilerTest(unittest.TestCase): EMAIL_REGEX = LazyRegexCompiler("@") UserEmail.EMAIL_REGEX = re.compile("cookies") - self.assertEqual( - UserEmail.EMAIL_REGEX.search("Cake & cookies").group(), "cookies" - ) + assert UserEmail.EMAIL_REGEX.search("Cake & cookies").group() == "cookies"