From b887ea96236c62a04a6962dabc747dda14eb1057 Mon Sep 17 00:00:00 2001 From: otrofimov Date: Thu, 8 Aug 2019 11:55:45 +0300 Subject: [PATCH 01/59] Implement collation for queryset --- mongoengine/queryset/base.py | 31 +++++++++++++++++++++++++++++++ tests/document/indexes.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 60 insertions(+) diff --git a/mongoengine/queryset/base.py b/mongoengine/queryset/base.py index ba3ac95a..b0e1bff2 100644 --- a/mongoengine/queryset/base.py +++ b/mongoengine/queryset/base.py @@ -80,6 +80,7 @@ class BaseQuerySet(object): self._limit = None self._skip = None self._hint = -1 # Using -1 as None is a valid value for hint + self._collation = None self._batch_size = None self.only_fields = [] self._max_time_ms = None @@ -781,6 +782,7 @@ class BaseQuerySet(object): "_limit", "_skip", "_hint", + "_collation", "_auto_dereference", "_search_text", "only_fields", @@ -863,6 +865,32 @@ class BaseQuerySet(object): return queryset + def collation(self, collation=None): + """ + Collation allows users to specify language-specific rules for string + comparison, such as rules for lettercase and accent marks. + :param collation: `~pymongo.collation.Collation` or dict with + following fields: + { + locale: str, + caseLevel: bool, + caseFirst: str, + strength: int, + numericOrdering: bool, + alternate: str, + maxVariable: str, + backwards: str + } + Collation should be added to indexes like in test example + """ + queryset = self.clone() + queryset._collation = collation + + if queryset._cursor_obj: + queryset._cursor_obj.collation(collation) + + return queryset + def batch_size(self, size): """Limit the number of documents returned in a single batch (each batch requires a round trip to the server). @@ -1636,6 +1664,9 @@ class BaseQuerySet(object): if self._hint != -1: self._cursor_obj.hint(self._hint) + if self._collation is not None: + self._cursor_obj.collation(self._collation) + if self._batch_size is not None: self._cursor_obj.batch_size(self._batch_size) diff --git a/tests/document/indexes.py b/tests/document/indexes.py index 570e619e..0bc23d1c 100644 --- a/tests/document/indexes.py +++ b/tests/document/indexes.py @@ -539,6 +539,35 @@ class IndexesTest(unittest.TestCase): with self.assertRaises(ValueError): BlogPost.objects.hint(("tags", 1)).count() + def test_collation(self): + base = {'locale': "en", 'strength': 2} + + class BlogPost(Document): + name = StringField() + meta = {"indexes": [ + {"fields": ["name"], "name": 'name_index', + 'collation': base} + ]} + + BlogPost.drop_collection() + + names = tuple("%sag %i" % ('t' if n % 2 == 0 else 'T', n) for n in range(10)) + for name in names: + BlogPost(name=name).save() + + query_result = BlogPost.objects.collation(base).order_by('name') + self.assertEqual([x.name for x in query_result], + sorted(names, key=lambda x: x.lower())) + self.assertEqual(10, query_result.count()) + + incorrect_collation = {'arndom': 'wrdo'} + with self.assertRaises(OperationFailure): + BlogPost.objects.collation(incorrect_collation).count() + + query_result = BlogPost.objects.collation({}).order_by('name') + self.assertEqual([x.name for x in query_result], + sorted(names)) + def test_unique(self): """Ensure that uniqueness constraints are applied to fields. """ From fbb3bf869c9cdea0b6c5060e2f57fabf5b8c5e5d Mon Sep 17 00:00:00 2001 From: otrofimov Date: Thu, 8 Aug 2019 15:56:20 +0300 Subject: [PATCH 02/59] compatibility with black --- tests/document/indexes.py | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/tests/document/indexes.py b/tests/document/indexes.py index 0bc23d1c..fa3d1706 100644 --- a/tests/document/indexes.py +++ b/tests/document/indexes.py @@ -540,33 +540,34 @@ class IndexesTest(unittest.TestCase): BlogPost.objects.hint(("tags", 1)).count() def test_collation(self): - base = {'locale': "en", 'strength': 2} + base = {"locale": "en", "strength": 2} class BlogPost(Document): name = StringField() - meta = {"indexes": [ - {"fields": ["name"], "name": 'name_index', - 'collation': base} - ]} + meta = { + "indexes": [ + {"fields": ["name"], "name": "name_index", "collation": base} + ] + } BlogPost.drop_collection() - names = tuple("%sag %i" % ('t' if n % 2 == 0 else 'T', n) for n in range(10)) + names = tuple("%sag %i" % ("t" if n % 2 == 0 else "T", n) for n in range(10)) for name in names: BlogPost(name=name).save() - query_result = BlogPost.objects.collation(base).order_by('name') - self.assertEqual([x.name for x in query_result], - sorted(names, key=lambda x: x.lower())) + query_result = BlogPost.objects.collation(base).order_by("name") + self.assertEqual( + [x.name for x in query_result], sorted(names, key=lambda x: x.lower()) + ) self.assertEqual(10, query_result.count()) - incorrect_collation = {'arndom': 'wrdo'} + incorrect_collation = {"arndom": "wrdo"} with self.assertRaises(OperationFailure): BlogPost.objects.collation(incorrect_collation).count() - query_result = BlogPost.objects.collation({}).order_by('name') - self.assertEqual([x.name for x in query_result], - sorted(names)) + query_result = BlogPost.objects.collation({}).order_by("name") + self.assertEqual([x.name for x in query_result], sorted(names)) def test_unique(self): """Ensure that uniqueness constraints are applied to fields. From eecbb5ca90192a28354efd817f1383ad674100ba Mon Sep 17 00:00:00 2001 From: Ali Mirlou Date: Tue, 20 Aug 2019 19:53:49 +0430 Subject: [PATCH 03/59] Fix small typo --- mongoengine/fields.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 9b9fef6e..f8f527a3 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -2291,7 +2291,7 @@ class LineStringField(GeoJsonBaseField): .. code-block:: js {'type' : 'LineString' , - 'coordinates' : [[x1, y1], [x1, y1] ... [xn, yn]]} + 'coordinates' : [[x1, y1], [x2, y2] ... [xn, yn]]} You can either pass a dict with the full information or a list of points. From e86cf962e99e15eaa59e14380ff50fd9e25ac6fc Mon Sep 17 00:00:00 2001 From: Erdenezul Batmunkh Date: Wed, 21 Aug 2019 13:08:30 +0200 Subject: [PATCH 04/59] Change misleading error message --- mongoengine/queryset/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mongoengine/queryset/base.py b/mongoengine/queryset/base.py index ba3ac95a..46b20d78 100644 --- a/mongoengine/queryset/base.py +++ b/mongoengine/queryset/base.py @@ -355,8 +355,8 @@ class BaseQuerySet(object): except pymongo.errors.BulkWriteError as err: # inserting documents that already have an _id field will # give huge performance debt or raise - message = u"Document must not have _id value before bulk write (%s)" - raise NotUniqueError(message % six.text_type(err)) + message = u"Bulk write error: (%s)" + raise NotUniqueError(message % six.text_type(err.details)) except pymongo.errors.OperationFailure as err: message = "Could not save document (%s)" if re.match("^E1100[01] duplicate key", six.text_type(err)): From 71a6f3d1a46702bf5d66e704519671e828626cd2 Mon Sep 17 00:00:00 2001 From: otrofimov Date: Wed, 21 Aug 2019 18:26:10 +0300 Subject: [PATCH 05/59] test_collation: Added test with `pymongo.collation.Collation` object Readable list of BlogPost names for test --- tests/document/indexes.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/tests/document/indexes.py b/tests/document/indexes.py index fa3d1706..dcd3fc6a 100644 --- a/tests/document/indexes.py +++ b/tests/document/indexes.py @@ -3,6 +3,7 @@ import unittest from datetime import datetime from nose.plugins.skip import SkipTest +from pymongo.collation import Collation from pymongo.errors import OperationFailure import pymongo from six import iteritems @@ -552,7 +553,7 @@ class IndexesTest(unittest.TestCase): BlogPost.drop_collection() - names = tuple("%sag %i" % ("t" if n % 2 == 0 else "T", n) for n in range(10)) + names = ["tag1", "Tag2", "tag3", "Tag4", "tag5"] for name in names: BlogPost(name=name).save() @@ -560,7 +561,13 @@ class IndexesTest(unittest.TestCase): self.assertEqual( [x.name for x in query_result], sorted(names, key=lambda x: x.lower()) ) - self.assertEqual(10, query_result.count()) + self.assertEqual(5, query_result.count()) + + query_result = BlogPost.objects.collation(Collation(**base)).order_by("name") + self.assertEqual( + [x.name for x in query_result], sorted(names, key=lambda x: x.lower()) + ) + self.assertEqual(5, query_result.count()) incorrect_collation = {"arndom": "wrdo"} with self.assertRaises(OperationFailure): From ddececbfead80b2565c4c086a9cbd52ed07e17c0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Wed, 28 Aug 2019 16:01:44 +0300 Subject: [PATCH 06/59] rename all test files so that they are prefixed by test_{orginal_filename}.py --- tests/__init__.py | 4 -- tests/all_warnings/__init__.py | 40 ------------- tests/all_warnings/test_warnings.py | 37 ++++++++++++ tests/document/__init__.py | 13 ----- ...class_methods.py => test_class_methods.py} | 9 +-- tests/document/{delta.py => test_delta.py} | 4 +- .../document/{dynamic.py => test_dynamic.py} | 2 +- .../document/{indexes.py => test_indexes.py} | 5 +- .../{inheritance.py => test_inheritance.py} | 8 +-- .../{instance.py => test_instance.py} | 4 +- ...lisation.py => test_json_serialisation.py} | 11 +--- .../{validation.py => test_validation.py} | 8 +-- tests/fields/__init__.py | 3 - tests/fields/test_binary_field.py | 5 +- tests/fields/{fields.py => test_fields.py} | 57 +++++++++---------- .../{file_tests.py => test_file_field.py} | 6 +- tests/fields/{geo.py => test_geo_fields.py} | 10 +--- tests/queryset/__init__.py | 6 -- .../{field_list.py => test_field_list.py} | 6 +- tests/queryset/{geo.py => test_geo.py} | 5 +- .../{queryset.py => test_queryset.py} | 2 +- 21 files changed, 90 insertions(+), 155 deletions(-) create mode 100644 tests/all_warnings/test_warnings.py rename tests/document/{class_methods.py => test_class_methods.py} (99%) rename tests/document/{delta.py => test_delta.py} (99%) rename tests/document/{dynamic.py => test_dynamic.py} (99%) rename tests/document/{indexes.py => test_indexes.py} (99%) rename tests/document/{inheritance.py => test_inheritance.py} (99%) rename tests/document/{instance.py => test_instance.py} (99%) rename tests/document/{json_serialisation.py => test_json_serialisation.py} (95%) rename tests/document/{validation.py => test_validation.py} (97%) rename tests/fields/{fields.py => test_fields.py} (99%) rename tests/fields/{file_tests.py => test_file_field.py} (99%) rename tests/fields/{geo.py => test_geo_fields.py} (98%) rename tests/queryset/{field_list.py => test_field_list.py} (99%) rename tests/queryset/{geo.py => test_geo.py} (99%) rename tests/queryset/{queryset.py => test_queryset.py} (99%) diff --git a/tests/__init__.py b/tests/__init__.py index 08db7186..e69de29b 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,4 +0,0 @@ -from .all_warnings import AllWarnings -from .document import * -from .queryset import * -from .fields import * diff --git a/tests/all_warnings/__init__.py b/tests/all_warnings/__init__.py index a755e7a3..e69de29b 100644 --- a/tests/all_warnings/__init__.py +++ b/tests/all_warnings/__init__.py @@ -1,40 +0,0 @@ -""" -This test has been put into a module. This is because it tests warnings that -only get triggered on first hit. This way we can ensure its imported into the -top level and called first by the test suite. -""" -import unittest -import warnings - -from mongoengine import * - - -__all__ = ("AllWarnings",) - - -class AllWarnings(unittest.TestCase): - def setUp(self): - connect(db="mongoenginetest") - self.warning_list = [] - self.showwarning_default = warnings.showwarning - warnings.showwarning = self.append_to_warning_list - - def append_to_warning_list(self, message, category, *args): - self.warning_list.append({"message": message, "category": category}) - - def tearDown(self): - # restore default handling of warnings - warnings.showwarning = self.showwarning_default - - def test_document_collection_syntax_warning(self): - class NonAbstractBase(Document): - meta = {"allow_inheritance": True} - - class InheritedDocumentFailTest(NonAbstractBase): - meta = {"collection": "fail"} - - warning = self.warning_list[0] - self.assertEqual(SyntaxWarning, warning["category"]) - self.assertEqual( - "non_abstract_base", InheritedDocumentFailTest._get_collection_name() - ) diff --git a/tests/all_warnings/test_warnings.py b/tests/all_warnings/test_warnings.py new file mode 100644 index 00000000..67204617 --- /dev/null +++ b/tests/all_warnings/test_warnings.py @@ -0,0 +1,37 @@ +""" +This test has been put into a module. This is because it tests warnings that +only get triggered on first hit. This way we can ensure its imported into the +top level and called first by the test suite. +""" +import unittest +import warnings + +from mongoengine import * + + +class TestAllWarnings(unittest.TestCase): + def setUp(self): + connect(db="mongoenginetest") + self.warning_list = [] + self.showwarning_default = warnings.showwarning + warnings.showwarning = self.append_to_warning_list + + def append_to_warning_list(self, message, category, *args): + self.warning_list.append({"message": message, "category": category}) + + def tearDown(self): + # restore default handling of warnings + warnings.showwarning = self.showwarning_default + + def test_document_collection_syntax_warning(self): + class NonAbstractBase(Document): + meta = {"allow_inheritance": True} + + class InheritedDocumentFailTest(NonAbstractBase): + meta = {"collection": "fail"} + + warning = self.warning_list[0] + self.assertEqual(SyntaxWarning, warning["category"]) + self.assertEqual( + "non_abstract_base", InheritedDocumentFailTest._get_collection_name() + ) diff --git a/tests/document/__init__.py b/tests/document/__init__.py index f2230c48..e69de29b 100644 --- a/tests/document/__init__.py +++ b/tests/document/__init__.py @@ -1,13 +0,0 @@ -import unittest - -from .class_methods import * -from .delta import * -from .dynamic import * -from .indexes import * -from .inheritance import * -from .instance import * -from .json_serialisation import * -from .validation import * - -if __name__ == "__main__": - unittest.main() diff --git a/tests/document/class_methods.py b/tests/document/test_class_methods.py similarity index 99% rename from tests/document/class_methods.py rename to tests/document/test_class_methods.py index 87f1215b..c5df0843 100644 --- a/tests/document/class_methods.py +++ b/tests/document/test_class_methods.py @@ -2,15 +2,12 @@ import unittest from mongoengine import * -from mongoengine.pymongo_support import list_collection_names - -from mongoengine.queryset import NULLIFY, PULL from mongoengine.connection import get_db - -__all__ = ("ClassMethodsTest",) +from mongoengine.pymongo_support import list_collection_names +from mongoengine.queryset import NULLIFY, PULL -class ClassMethodsTest(unittest.TestCase): +class TestClassMethods(unittest.TestCase): def setUp(self): connect(db="mongoenginetest") self.db = get_db() diff --git a/tests/document/delta.py b/tests/document/test_delta.py similarity index 99% rename from tests/document/delta.py rename to tests/document/test_delta.py index 8f1575e6..632d9b3f 100644 --- a/tests/document/delta.py +++ b/tests/document/test_delta.py @@ -7,9 +7,9 @@ from mongoengine.pymongo_support import list_collection_names from tests.utils import MongoDBTestCase -class DeltaTest(MongoDBTestCase): +class TestDelta(MongoDBTestCase): def setUp(self): - super(DeltaTest, self).setUp() + super(TestDelta, self).setUp() class Person(Document): name = StringField() diff --git a/tests/document/dynamic.py b/tests/document/test_dynamic.py similarity index 99% rename from tests/document/dynamic.py rename to tests/document/test_dynamic.py index 414d3352..6b517d24 100644 --- a/tests/document/dynamic.py +++ b/tests/document/test_dynamic.py @@ -179,7 +179,7 @@ class TestDynamicDocument(MongoDBTestCase): def test_three_level_complex_data_lookups(self): """Ensure you can query three level document dynamic fields""" - p = self.Person.objects.create(misc={"hello": {"hello2": "world"}}) + self.Person.objects.create(misc={"hello": {"hello2": "world"}}) self.assertEqual(1, self.Person.objects(misc__hello__hello2="world").count()) def test_complex_embedded_document_validation(self): diff --git a/tests/document/indexes.py b/tests/document/test_indexes.py similarity index 99% rename from tests/document/indexes.py rename to tests/document/test_indexes.py index 570e619e..f94eb359 100644 --- a/tests/document/indexes.py +++ b/tests/document/test_indexes.py @@ -4,16 +4,13 @@ from datetime import datetime from nose.plugins.skip import SkipTest from pymongo.errors import OperationFailure -import pymongo from six import iteritems from mongoengine import * from mongoengine.connection import get_db -__all__ = ("IndexesTest",) - -class IndexesTest(unittest.TestCase): +class TestIndexes(unittest.TestCase): def setUp(self): self.connection = connect(db="mongoenginetest") self.db = get_db() diff --git a/tests/document/inheritance.py b/tests/document/test_inheritance.py similarity index 99% rename from tests/document/inheritance.py rename to tests/document/test_inheritance.py index 4f21d5f4..4bb46e58 100644 --- a/tests/document/inheritance.py +++ b/tests/document/test_inheritance.py @@ -15,13 +15,11 @@ from mongoengine import ( StringField, ) from mongoengine.pymongo_support import list_collection_names -from tests.utils import MongoDBTestCase from tests.fixtures import Base - -__all__ = ("InheritanceTest",) +from tests.utils import MongoDBTestCase -class InheritanceTest(MongoDBTestCase): +class TestInheritance(MongoDBTestCase): def tearDown(self): for collection in list_collection_names(self.db): self.db.drop_collection(collection) @@ -401,7 +399,7 @@ class InheritanceTest(MongoDBTestCase): class Animal(FinalDocument): name = StringField() - with self.assertRaises(ValueError) as cm: + with self.assertRaises(ValueError): class Mammal(Animal): pass diff --git a/tests/document/instance.py b/tests/document/test_instance.py similarity index 99% rename from tests/document/instance.py rename to tests/document/test_instance.py index d8841a40..9b4a16e5 100644 --- a/tests/document/instance.py +++ b/tests/document/test_instance.py @@ -39,10 +39,8 @@ from tests.utils import MongoDBTestCase, get_as_pymongo TEST_IMAGE_PATH = os.path.join(os.path.dirname(__file__), "../fields/mongoengine.png") -__all__ = ("InstanceTest",) - -class InstanceTest(MongoDBTestCase): +class TestInstance(MongoDBTestCase): def setUp(self): class Job(EmbeddedDocument): name = StringField() diff --git a/tests/document/json_serialisation.py b/tests/document/test_json_serialisation.py similarity index 95% rename from tests/document/json_serialisation.py rename to tests/document/test_json_serialisation.py index 33d5a6d9..26a4a6c1 100644 --- a/tests/document/json_serialisation.py +++ b/tests/document/test_json_serialisation.py @@ -1,21 +1,14 @@ import unittest import uuid -from nose.plugins.skip import SkipTest from datetime import datetime from bson import ObjectId -import pymongo - from mongoengine import * - -__all__ = ("TestJson",) +from tests.utils import MongoDBTestCase -class TestJson(unittest.TestCase): - def setUp(self): - connect(db="mongoenginetest") - +class TestJson(MongoDBTestCase): def test_json_names(self): """ Going to test reported issue: diff --git a/tests/document/validation.py b/tests/document/test_validation.py similarity index 97% rename from tests/document/validation.py rename to tests/document/test_validation.py index 78199231..7449dd33 100644 --- a/tests/document/validation.py +++ b/tests/document/test_validation.py @@ -3,14 +3,10 @@ import unittest from datetime import datetime from mongoengine import * - -__all__ = ("ValidatorErrorTest",) +from tests.utils import MongoDBTestCase -class ValidatorErrorTest(unittest.TestCase): - def setUp(self): - connect(db="mongoenginetest") - +class TestValidatorError(MongoDBTestCase): def test_to_dict(self): """Ensure a ValidationError handles error to_dict correctly. """ diff --git a/tests/fields/__init__.py b/tests/fields/__init__.py index 4994d0c6..e69de29b 100644 --- a/tests/fields/__init__.py +++ b/tests/fields/__init__.py @@ -1,3 +0,0 @@ -from .fields import * -from .file_tests import * -from .geo import * diff --git a/tests/fields/test_binary_field.py b/tests/fields/test_binary_field.py index df4bf2de..719df922 100644 --- a/tests/fields/test_binary_field.py +++ b/tests/fields/test_binary_field.py @@ -1,11 +1,10 @@ # -*- coding: utf-8 -*- import uuid +from bson import Binary from nose.plugins.skip import SkipTest import six -from bson import Binary - from mongoengine import * from tests.utils import MongoDBTestCase @@ -77,8 +76,6 @@ class TestBinaryField(MongoDBTestCase): self.assertEqual(0, Attachment.objects.count()) def test_primary_filter_by_binary_pk_as_str(self): - raise SkipTest("Querying by id as string is not currently supported") - class Attachment(Document): id = BinaryField(primary_key=True) diff --git a/tests/fields/fields.py b/tests/fields/test_fields.py similarity index 99% rename from tests/fields/fields.py rename to tests/fields/test_fields.py index 49e9508c..d9279c22 100644 --- a/tests/fields/fields.py +++ b/tests/fields/test_fields.py @@ -2,39 +2,38 @@ import datetime import unittest +from bson import DBRef, ObjectId, SON from nose.plugins.skip import SkipTest -from bson import DBRef, ObjectId, SON - from mongoengine import ( - Document, - StringField, - IntField, - DateTimeField, - DateField, - ValidationError, + BooleanField, ComplexDateTimeField, - FloatField, - ListField, - ReferenceField, + DateField, + DateTimeField, DictField, + Document, + DoesNotExist, + DynamicDocument, + DynamicField, EmbeddedDocument, EmbeddedDocumentField, - GenericReferenceField, - DoesNotExist, - NotRegistered, - OperationError, - DynamicField, - FieldDoesNotExist, EmbeddedDocumentListField, - MultipleObjectsReturned, - NotUniqueError, - BooleanField, - ObjectIdField, - SortedListField, + FieldDoesNotExist, + FloatField, GenericLazyReferenceField, + GenericReferenceField, + IntField, LazyReferenceField, - DynamicDocument, + ListField, + MultipleObjectsReturned, + NotRegistered, + NotUniqueError, + ObjectIdField, + OperationError, + ReferenceField, + SortedListField, + StringField, + ValidationError, ) from mongoengine.base import BaseField, EmbeddedDocumentList, _document_registry from mongoengine.errors import DeprecatedError @@ -42,7 +41,7 @@ from mongoengine.errors import DeprecatedError from tests.utils import MongoDBTestCase -class FieldTest(MongoDBTestCase): +class TestField(MongoDBTestCase): def test_default_values_nothing_set(self): """Ensure that default field values are used when creating a document. @@ -343,7 +342,7 @@ class FieldTest(MongoDBTestCase): doc.save() # Unset all the fields - obj = HandleNoneFields._get_collection().update( + HandleNoneFields._get_collection().update( {"_id": doc.id}, {"$unset": {"str_fld": 1, "int_fld": 1, "flt_fld": 1, "comp_dt_fld": 1}}, ) @@ -416,13 +415,13 @@ class FieldTest(MongoDBTestCase): # name starting with $ with self.assertRaises(ValueError): - class User(Document): + class UserX1(Document): name = StringField(db_field="$name") # name containing a null character with self.assertRaises(ValueError): - class User(Document): + class UserX2(Document): name = StringField(db_field="name\0") def test_list_validation(self): @@ -2267,7 +2266,7 @@ class FieldTest(MongoDBTestCase): Doc(bar="test") -class EmbeddedDocumentListFieldTestCase(MongoDBTestCase): +class TestEmbeddedDocumentListField(MongoDBTestCase): def setUp(self): """ Create two BlogPost entries in the database, each with @@ -2320,7 +2319,7 @@ class EmbeddedDocumentListFieldTestCase(MongoDBTestCase): # Test with a Document post = self.BlogPost(comments=Title(content="garbage")) - with self.assertRaises(ValidationError) as e: + with self.assertRaises(ValidationError): post.validate() self.assertIn("'comments'", str(ctx_err.exception)) self.assertIn( diff --git a/tests/fields/file_tests.py b/tests/fields/test_file_field.py similarity index 99% rename from tests/fields/file_tests.py rename to tests/fields/test_file_field.py index dd2fe609..49eb5bc2 100644 --- a/tests/fields/file_tests.py +++ b/tests/fields/test_file_field.py @@ -1,13 +1,13 @@ # -*- coding: utf-8 -*- import copy import os -import unittest import tempfile +import unittest import gridfs +from nose.plugins.skip import SkipTest import six -from nose.plugins.skip import SkipTest from mongoengine import * from mongoengine.connection import get_db from mongoengine.python_support import StringIO @@ -35,7 +35,7 @@ def get_file(path): return bytes_io -class FileTest(MongoDBTestCase): +class TestFileField(MongoDBTestCase): def tearDown(self): self.db.drop_collection("fs.files") self.db.drop_collection("fs.chunks") diff --git a/tests/fields/geo.py b/tests/fields/test_geo_fields.py similarity index 98% rename from tests/fields/geo.py rename to tests/fields/test_geo_fields.py index 446d7171..ff4cbc83 100644 --- a/tests/fields/geo.py +++ b/tests/fields/test_geo_fields.py @@ -2,16 +2,10 @@ import unittest from mongoengine import * -from mongoengine.connection import get_db - -__all__ = ("GeoFieldTest",) +from tests.utils import MongoDBTestCase -class GeoFieldTest(unittest.TestCase): - def setUp(self): - connect(db="mongoenginetest") - self.db = get_db() - +class TestGeoField(MongoDBTestCase): def _test_for_expected_error(self, Cls, loc, expected): try: Cls(loc=loc).validate() diff --git a/tests/queryset/__init__.py b/tests/queryset/__init__.py index 31016966..e69de29b 100644 --- a/tests/queryset/__init__.py +++ b/tests/queryset/__init__.py @@ -1,6 +0,0 @@ -from .transform import * -from .field_list import * -from .queryset import * -from .visitor import * -from .geo import * -from .modify import * diff --git a/tests/queryset/field_list.py b/tests/queryset/test_field_list.py similarity index 99% rename from tests/queryset/field_list.py rename to tests/queryset/test_field_list.py index 9f0fe827..703c2031 100644 --- a/tests/queryset/field_list.py +++ b/tests/queryset/test_field_list.py @@ -3,10 +3,8 @@ import unittest from mongoengine import * from mongoengine.queryset import QueryFieldList -__all__ = ("QueryFieldListTest", "OnlyExcludeAllTest") - -class QueryFieldListTest(unittest.TestCase): +class TestQueryFieldList(unittest.TestCase): def test_empty(self): q = QueryFieldList() self.assertFalse(q) @@ -66,7 +64,7 @@ class QueryFieldListTest(unittest.TestCase): self.assertEqual(q.as_dict(), {"a": {"$slice": 5}}) -class OnlyExcludeAllTest(unittest.TestCase): +class TestOnlyExcludeAll(unittest.TestCase): def setUp(self): connect(db="mongoenginetest") diff --git a/tests/queryset/geo.py b/tests/queryset/test_geo.py similarity index 99% rename from tests/queryset/geo.py rename to tests/queryset/test_geo.py index 95dc913d..343f864b 100644 --- a/tests/queryset/geo.py +++ b/tests/queryset/test_geo.py @@ -6,10 +6,7 @@ from mongoengine import * from tests.utils import MongoDBTestCase -__all__ = ("GeoQueriesTest",) - - -class GeoQueriesTest(MongoDBTestCase): +class TestGeoQueries(MongoDBTestCase): def _create_event_data(self, point_field_class=GeoPointField): """Create some sample data re-used in many of the tests below.""" diff --git a/tests/queryset/queryset.py b/tests/queryset/test_queryset.py similarity index 99% rename from tests/queryset/queryset.py rename to tests/queryset/test_queryset.py index 9dc68f2e..a9ecaef5 100644 --- a/tests/queryset/queryset.py +++ b/tests/queryset/test_queryset.py @@ -41,7 +41,7 @@ def get_key_compat(mongo_ver): return ORDER_BY_KEY, CMD_QUERY_KEY -class QuerySetTest(unittest.TestCase): +class TestQueryset(unittest.TestCase): def setUp(self): connect(db="mongoenginetest") connect(db="mongoenginetest2", alias="test2") From a06e605e671bfcbe336addd780ae7c9e79069b99 Mon Sep 17 00:00:00 2001 From: Erdenezul Batmunkh Date: Thu, 29 Aug 2019 11:11:27 +0200 Subject: [PATCH 07/59] Add BulkWriteError exception --- mongoengine/errors.py | 5 +++++ mongoengine/queryset/base.py | 3 ++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/mongoengine/errors.py b/mongoengine/errors.py index 9852f2a1..b76243d3 100644 --- a/mongoengine/errors.py +++ b/mongoengine/errors.py @@ -12,6 +12,7 @@ __all__ = ( "InvalidQueryError", "OperationError", "NotUniqueError", + "BulkWriteError", "FieldDoesNotExist", "ValidationError", "SaveConditionError", @@ -51,6 +52,10 @@ class NotUniqueError(OperationError): pass +class BulkWriteError(OperationError): + pass + + class SaveConditionError(OperationError): pass diff --git a/mongoengine/queryset/base.py b/mongoengine/queryset/base.py index 46b20d78..6d3fb41a 100644 --- a/mongoengine/queryset/base.py +++ b/mongoengine/queryset/base.py @@ -20,6 +20,7 @@ from mongoengine.common import _import_class from mongoengine.connection import get_db from mongoengine.context_managers import set_write_concern, switch_db from mongoengine.errors import ( + BulkWriteError, InvalidQueryError, LookUpError, NotUniqueError, @@ -356,7 +357,7 @@ class BaseQuerySet(object): # inserting documents that already have an _id field will # give huge performance debt or raise message = u"Bulk write error: (%s)" - raise NotUniqueError(message % six.text_type(err.details)) + raise BulkWriteError(message % six.text_type(err.details)) except pymongo.errors.OperationFailure as err: message = "Could not save document (%s)" if re.match("^E1100[01] duplicate key", six.text_type(err)): From 2267b7e7d740409dd1b9f648d1fe5b9e1cfdb7a4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Fri, 30 Aug 2019 16:27:56 +0300 Subject: [PATCH 08/59] rename remaining files for pytest migration --- tests/queryset/{modify.py => test_modify.py} | 6 ++---- tests/queryset/{pickable.py => test_pickable.py} | 11 +++-------- tests/queryset/{transform.py => test_transform.py} | 4 +--- tests/queryset/{visitor.py => test_visitor.py} | 4 +--- 4 files changed, 7 insertions(+), 18 deletions(-) rename tests/queryset/{modify.py => test_modify.py} (96%) rename tests/queryset/{pickable.py => test_pickable.py} (87%) rename tests/queryset/{transform.py => test_transform.py} (99%) rename tests/queryset/{visitor.py => test_visitor.py} (99%) diff --git a/tests/queryset/modify.py b/tests/queryset/test_modify.py similarity index 96% rename from tests/queryset/modify.py rename to tests/queryset/test_modify.py index e092d11c..60f4884c 100644 --- a/tests/queryset/modify.py +++ b/tests/queryset/test_modify.py @@ -1,8 +1,6 @@ import unittest -from mongoengine import connect, Document, IntField, StringField, ListField - -__all__ = ("FindAndModifyTest",) +from mongoengine import Document, IntField, ListField, StringField, connect class Doc(Document): @@ -10,7 +8,7 @@ class Doc(Document): value = IntField() -class FindAndModifyTest(unittest.TestCase): +class TestFindAndModify(unittest.TestCase): def setUp(self): connect(db="mongoenginetest") Doc.drop_collection() diff --git a/tests/queryset/pickable.py b/tests/queryset/test_pickable.py similarity index 87% rename from tests/queryset/pickable.py rename to tests/queryset/test_pickable.py index 0945fcbc..fbdd1ff0 100644 --- a/tests/queryset/pickable.py +++ b/tests/queryset/test_pickable.py @@ -1,10 +1,8 @@ import pickle import unittest -from pymongo.mongo_client import MongoClient -from mongoengine import Document, StringField, IntField -from mongoengine.connection import connect -__author__ = "stas" +from mongoengine import Document, IntField, StringField +from mongoengine.connection import connect class Person(Document): @@ -20,11 +18,8 @@ class TestQuerysetPickable(unittest.TestCase): def setUp(self): super(TestQuerysetPickable, self).setUp() - - connection = connect(db="test") # type: pymongo.mongo_client.MongoClient - + connection = connect(db="test") connection.drop_database("test") - self.john = Person.objects.create(name="John", age=21) def test_picke_simple_qs(self): diff --git a/tests/queryset/transform.py b/tests/queryset/test_transform.py similarity index 99% rename from tests/queryset/transform.py rename to tests/queryset/test_transform.py index cfcd8c22..8207351d 100644 --- a/tests/queryset/transform.py +++ b/tests/queryset/test_transform.py @@ -5,10 +5,8 @@ from bson.son import SON from mongoengine import * from mongoengine.queryset import Q, transform -__all__ = ("TransformTest",) - -class TransformTest(unittest.TestCase): +class TestTransform(unittest.TestCase): def setUp(self): connect(db="mongoenginetest") diff --git a/tests/queryset/visitor.py b/tests/queryset/test_visitor.py similarity index 99% rename from tests/queryset/visitor.py rename to tests/queryset/test_visitor.py index 0a22416f..acadabd4 100644 --- a/tests/queryset/visitor.py +++ b/tests/queryset/test_visitor.py @@ -8,10 +8,8 @@ from mongoengine import * from mongoengine.errors import InvalidQueryError from mongoengine.queryset import Q -__all__ = ("QTest",) - -class QTest(unittest.TestCase): +class TestQ(unittest.TestCase): def setUp(self): connect(db="mongoenginetest") From 693195f70be3b675757f700e40c6a66a74aa4b50 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Sat, 31 Aug 2019 22:28:31 +0300 Subject: [PATCH 09/59] fix test_pickable that was brought back to life recently --- tests/queryset/test_pickable.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/queryset/test_pickable.py b/tests/queryset/test_pickable.py index fbdd1ff0..8c4e3426 100644 --- a/tests/queryset/test_pickable.py +++ b/tests/queryset/test_pickable.py @@ -3,6 +3,7 @@ import unittest from mongoengine import Document, IntField, StringField from mongoengine.connection import connect +from tests.utils import MongoDBTestCase class Person(Document): @@ -10,7 +11,7 @@ class Person(Document): age = IntField() -class TestQuerysetPickable(unittest.TestCase): +class TestQuerysetPickable(MongoDBTestCase): """ Test for adding pickling support for QuerySet instances See issue https://github.com/MongoEngine/mongoengine/issues/442 @@ -18,8 +19,6 @@ class TestQuerysetPickable(unittest.TestCase): def setUp(self): super(TestQuerysetPickable, self).setUp() - connection = connect(db="test") - connection.drop_database("test") self.john = Person.objects.create(name="John", age=21) def test_picke_simple_qs(self): From 47f8a126ca167cb8fe020e3cc5604b155dfcdebc Mon Sep 17 00:00:00 2001 From: Arto Jantunen Date: Tue, 3 Sep 2019 14:36:06 +0300 Subject: [PATCH 10/59] Only set no_cursor_timeout when requested Previously this was always set for all requests. The parameter is only documented as supported for certain queries, so this was probably wrong. Mongo version 4.2 fails update queries that have this parameter set making mongoengine unusable there. Fixes #2148. --- mongoengine/queryset/base.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mongoengine/queryset/base.py b/mongoengine/queryset/base.py index ba3ac95a..ffa099ac 100644 --- a/mongoengine/queryset/base.py +++ b/mongoengine/queryset/base.py @@ -1576,7 +1576,9 @@ class BaseQuerySet(object): if self._snapshot: msg = "The snapshot option is not anymore available with PyMongo 3+" warnings.warn(msg, DeprecationWarning) - cursor_args = {"no_cursor_timeout": not self._timeout} + cursor_args = {} + if not self._timeout: + cursor_args["no_cursor_timeout"] = True if self._loaded_fields: cursor_args[fields_name] = self._loaded_fields.as_dict() From 1dbe7a3163033703720f4ccd3c6391b5f3f8d490 Mon Sep 17 00:00:00 2001 From: Erdenezul Batmunkh Date: Tue, 3 Sep 2019 16:17:09 +0200 Subject: [PATCH 11/59] Add log in changelog --- docs/changelog.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/changelog.rst b/docs/changelog.rst index 55fa4b25..5422f113 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -18,6 +18,7 @@ Development - Fix updating/modifying/deleting/reloading a document that's sharded by a field with ``db_field`` specified. #2125 - ``ListField`` now accepts an optional ``max_length`` parameter. #2110 - The codebase is now formatted using ``black``. #2109 +- In bulk write insert, the detailed error message would raise in exception. Changes in 0.18.2 ================= From 7d94af0e3181751894884cbfcd55fe9383db028b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Wed, 11 Sep 2019 21:53:30 +0200 Subject: [PATCH 12/59] add test coverage for no_cursor_timeout to support recent fix --- mongoengine/queryset/base.py | 1 + tests/queryset/test_queryset.py | 12 +++++++++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/mongoengine/queryset/base.py b/mongoengine/queryset/base.py index ffa099ac..570ad37f 100644 --- a/mongoengine/queryset/base.py +++ b/mongoengine/queryset/base.py @@ -1576,6 +1576,7 @@ class BaseQuerySet(object): if self._snapshot: msg = "The snapshot option is not anymore available with PyMongo 3+" warnings.warn(msg, DeprecationWarning) + cursor_args = {} if not self._timeout: cursor_args["no_cursor_timeout"] = True diff --git a/tests/queryset/test_queryset.py b/tests/queryset/test_queryset.py index a9ecaef5..e7e59905 100644 --- a/tests/queryset/test_queryset.py +++ b/tests/queryset/test_queryset.py @@ -5809,9 +5809,19 @@ class TestQueryset(unittest.TestCase): self.Person.objects.create(name="Baz") self.assertEqual(self.Person.objects.count(with_limit_and_skip=True), 3) - newPerson = self.Person.objects.create(name="Foo_1") + self.Person.objects.create(name="Foo_1") self.assertEqual(self.Person.objects.count(with_limit_and_skip=True), 4) + def test_no_cursor_timeout(self): + qs = self.Person.objects() + self.assertEqual(qs._cursor_args, {}) # ensure no regression of #2148 + + qs = self.Person.objects().timeout(True) + self.assertEqual(qs._cursor_args, {}) + + qs = self.Person.objects().timeout(False) + self.assertEqual(qs._cursor_args, {"no_cursor_timeout": True}) + if __name__ == "__main__": unittest.main() From 7ac74b1c1f60967852bea294c318adc0a45e347e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Sun, 15 Sep 2019 23:27:34 +0200 Subject: [PATCH 13/59] Document Model.objects.aggregate entrypoint with an example --- docs/guide/querying.rst | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/docs/guide/querying.rst b/docs/guide/querying.rst index 151855a6..50218aed 100644 --- a/docs/guide/querying.rst +++ b/docs/guide/querying.rst @@ -349,9 +349,9 @@ Just as with limiting and skipping results, there is a method on a You could technically use ``len(User.objects)`` to get the same result, but it would be significantly slower than :meth:`~mongoengine.queryset.QuerySet.count`. When you execute a server-side count query, you let MongoDB do the heavy -lifting and you receive a single integer over the wire. Meanwhile, len() +lifting and you receive a single integer over the wire. Meanwhile, ``len()`` retrieves all the results, places them in a local cache, and finally counts -them. If we compare the performance of the two operations, len() is much slower +them. If we compare the performance of the two operations, ``len()`` is much slower than :meth:`~mongoengine.queryset.QuerySet.count`. Further aggregation @@ -386,6 +386,18 @@ would be generating "tag-clouds":: top_tags = sorted(tag_freqs.items(), key=itemgetter(1), reverse=True)[:10] +MongoDB aggregation API +----------------------- +If you need to run aggregation pipelines, MongoEngine provides an entry point to `pymongo's aggregation framework `_ + through :meth:`~mongoengine.queryset.base.aggregate`. Checkout pymongo's documentation for the syntax and pipeline. +An example of its use would be :: + + class Person(Document): + name = StringField() + + pipeline = [{"$project": {"name": {"$toUpper": "$name"}}}] + data = Person.objects().aggregate(*pipeline) # Would return e.g: [{"_id": ObjectId('5d7eac82aae098e4ed3784c7'), "name": "JOHN DOE"}] + Query efficiency and performance ================================ From be2c4f2b3cdfe13d9abe92312409761d91f0040c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Mon, 16 Sep 2019 21:15:35 +0200 Subject: [PATCH 14/59] fix formatting and improve doc based on review --- docs/guide/querying.rst | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/docs/guide/querying.rst b/docs/guide/querying.rst index 50218aed..d64c169c 100644 --- a/docs/guide/querying.rst +++ b/docs/guide/querying.rst @@ -388,15 +388,22 @@ would be generating "tag-clouds":: MongoDB aggregation API ----------------------- -If you need to run aggregation pipelines, MongoEngine provides an entry point to `pymongo's aggregation framework `_ - through :meth:`~mongoengine.queryset.base.aggregate`. Checkout pymongo's documentation for the syntax and pipeline. -An example of its use would be :: +If you need to run aggregation pipelines, MongoEngine provides an entry point `Pymongo's aggregation framework `_ +through :meth:`~mongoengine.queryset.QuerySet.aggregate`. Check out Pymongo's documentation for the syntax and pipeline. +An example of its use would be:: class Person(Document): name = StringField() - pipeline = [{"$project": {"name": {"$toUpper": "$name"}}}] - data = Person.objects().aggregate(*pipeline) # Would return e.g: [{"_id": ObjectId('5d7eac82aae098e4ed3784c7'), "name": "JOHN DOE"}] + Person(name='John').save() + Person(name='Bob').save() + + pipeline = [ + {"$sort" : {"name" : -1}}, + {"$project": {"_id": 0, "name": {"$toUpper": "$name"}}} + ] + data = Person.objects().aggregate(*pipeline) + assert data == [{'name': 'BOB'}, {'name': 'JOHN'}] Query efficiency and performance ================================ From e3cd553f8211d473b54fb3c91da8bf9fa2ad053d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Fri, 4 Oct 2019 21:30:32 +0200 Subject: [PATCH 15/59] add latest pymongo 3.9 as part of the CI --- .travis.yml | 5 ++++- mongoengine/queryset/base.py | 1 + tests/queryset/test_queryset.py | 26 +++++++++----------------- tox.ini | 1 + 4 files changed, 15 insertions(+), 18 deletions(-) diff --git a/.travis.yml b/.travis.yml index 54a6befd..af1e2b14 100644 --- a/.travis.yml +++ b/.travis.yml @@ -32,10 +32,11 @@ env: global: - MONGODB_3_4=3.4.17 - MONGODB_3_6=3.6.12 + - PYMONGO_3_9=3.9 - PYMONGO_3_6=3.6 - PYMONGO_3_4=3.4 matrix: - - MONGODB=${MONGODB_3_4} PYMONGO=${PYMONGO_3_6} + - MONGODB=${MONGODB_3_4} PYMONGO=${PYMONGO_3_9} matrix: @@ -47,6 +48,8 @@ matrix: env: MONGODB=${MONGODB_3_4} PYMONGO=${PYMONGO_3_4} - python: 3.7 env: MONGODB=${MONGODB_3_6} PYMONGO=${PYMONGO_3_6} + - python: 3.7 + env: MONGODB=${MONGODB_3_6} PYMONGO=${PYMONGO_3_9} install: diff --git a/mongoengine/queryset/base.py b/mongoengine/queryset/base.py index cde06d54..a09cbf99 100644 --- a/mongoengine/queryset/base.py +++ b/mongoengine/queryset/base.py @@ -1639,6 +1639,7 @@ class BaseQuerySet(object): ).find(self._query, **self._cursor_args) else: self._cursor_obj = self._collection.find(self._query, **self._cursor_args) + # Apply "where" clauses to cursor if self._where_clause: where_clause = self._sub_js_fields(self._where_clause) diff --git a/tests/queryset/test_queryset.py b/tests/queryset/test_queryset.py index e7e59905..16213254 100644 --- a/tests/queryset/test_queryset.py +++ b/tests/queryset/test_queryset.py @@ -4641,43 +4641,35 @@ class TestQueryset(unittest.TestCase): bars = Bar.objects.read_preference(ReadPreference.SECONDARY_PREFERRED) self.assertEqual(bars._read_preference, ReadPreference.SECONDARY_PREFERRED) self.assertEqual( - bars._cursor._Cursor__read_preference, ReadPreference.SECONDARY_PREFERRED + bars._cursor.collection.read_preference, ReadPreference.SECONDARY_PREFERRED ) # Make sure that `.read_preference(...)` does accept string values. self.assertRaises(TypeError, Bar.objects.read_preference, "Primary") + def assert_read_pref(qs, expected_read_pref): + self.assertEqual(qs._read_preference, expected_read_pref) + self.assertEqual(qs._cursor.collection.read_preference, expected_read_pref) + # Make sure read preference is respected after a `.skip(...)`. bars = Bar.objects.skip(1).read_preference(ReadPreference.SECONDARY_PREFERRED) - self.assertEqual(bars._read_preference, ReadPreference.SECONDARY_PREFERRED) - self.assertEqual( - bars._cursor._Cursor__read_preference, ReadPreference.SECONDARY_PREFERRED - ) + assert_read_pref(bars, ReadPreference.SECONDARY_PREFERRED) # Make sure read preference is respected after a `.limit(...)`. bars = Bar.objects.limit(1).read_preference(ReadPreference.SECONDARY_PREFERRED) - self.assertEqual(bars._read_preference, ReadPreference.SECONDARY_PREFERRED) - self.assertEqual( - bars._cursor._Cursor__read_preference, ReadPreference.SECONDARY_PREFERRED - ) + assert_read_pref(bars, ReadPreference.SECONDARY_PREFERRED) # Make sure read preference is respected after an `.order_by(...)`. bars = Bar.objects.order_by("txt").read_preference( ReadPreference.SECONDARY_PREFERRED ) - self.assertEqual(bars._read_preference, ReadPreference.SECONDARY_PREFERRED) - self.assertEqual( - bars._cursor._Cursor__read_preference, ReadPreference.SECONDARY_PREFERRED - ) + assert_read_pref(bars, ReadPreference.SECONDARY_PREFERRED) # Make sure read preference is respected after a `.hint(...)`. bars = Bar.objects.hint([("txt", 1)]).read_preference( ReadPreference.SECONDARY_PREFERRED ) - self.assertEqual(bars._read_preference, ReadPreference.SECONDARY_PREFERRED) - self.assertEqual( - bars._cursor._Cursor__read_preference, ReadPreference.SECONDARY_PREFERRED - ) + assert_read_pref(bars, ReadPreference.SECONDARY_PREFERRED) def test_read_preference_aggregation_framework(self): class Bar(Document): diff --git a/tox.ini b/tox.ini index a1ae8444..b4a57818 100644 --- a/tox.ini +++ b/tox.ini @@ -8,5 +8,6 @@ deps = nose mg34: pymongo>=3.4,<3.5 mg36: pymongo>=3.6,<3.7 + mg39: pymongo>=3.9,<4.0 setenv = PYTHON_EGG_CACHE = {envdir}/python-eggs From 71e8d9a49067f1c790739f43fb3ff35baf01c458 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Tue, 10 Sep 2019 23:02:32 +0200 Subject: [PATCH 16/59] Added a FAQ to doc and Document the fact that we dont support motor --- docs/faq.rst | 13 +++++++++++++ docs/index.rst | 4 ++++ 2 files changed, 17 insertions(+) create mode 100644 docs/faq.rst diff --git a/docs/faq.rst b/docs/faq.rst new file mode 100644 index 00000000..27cd6937 --- /dev/null +++ b/docs/faq.rst @@ -0,0 +1,13 @@ +========================== +Frequently Asked Questions +========================== + +Does MongoEngine support asynchronous drivers (Motor, TxMongo)? +--------------------------------------------------------------- + +No, MongoEngine is exclusively based on PyMongo and isn't designed to support other driver. +If this is a requirement for your project, check the alternative: `uMongo`_ and `MotorEngine`_. + +.. _uMongo: https://umongo.readthedocs.io/ +.. _MotorEngine: https://motorengine.readthedocs.io/ + diff --git a/docs/index.rst b/docs/index.rst index 2102df02..686ef547 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -23,6 +23,9 @@ MongoDB. To install it, simply run :doc:`upgrade` How to upgrade MongoEngine. +:doc:`faq` + Frequently Asked Questions + :doc:`django` Using MongoEngine and Django @@ -73,6 +76,7 @@ formats for offline reading. apireference changelog upgrade + faq django Indices and tables From 19f12f3f2f380987c23dcafb8f288eb51177363d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Fri, 4 Oct 2019 21:51:12 +0200 Subject: [PATCH 17/59] document pymongo in RTD and make it point to github --- docs/index.rst | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/index.rst b/docs/index.rst index 686ef547..662968d4 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -29,6 +29,12 @@ MongoDB. To install it, simply run :doc:`django` Using MongoEngine and Django +MongoDB and driver support +-------------------------- + +MongoEngine is based on the PyMongo driver and tested against multiple versions of MongoDB. +For further details, please refer to the `readme `_. + Community --------- From 1e17b5ac66148387a18d078f4b21cc406beef4f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Sat, 5 Oct 2019 14:24:54 +0200 Subject: [PATCH 18/59] Fix docstring format to improve pycharm inspection --- mongoengine/base/document.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mongoengine/base/document.py b/mongoengine/base/document.py index a962a82b..2be8dd6f 100644 --- a/mongoengine/base/document.py +++ b/mongoengine/base/document.py @@ -62,13 +62,13 @@ class BaseDocument(object): """ Initialise a document or an embedded document. - :param dict values: A dictionary of keys and values for the document. + :param values: A dictionary of keys and values for the document. It may contain additional reserved keywords, e.g. "__auto_convert". - :param bool __auto_convert: If True, supplied values will be converted + :param __auto_convert: If True, supplied values will be converted to Python-type values via each field's `to_python` method. - :param set __only_fields: A set of fields that have been loaded for + :param __only_fields: A set of fields that have been loaded for this document. Empty if all fields have been loaded. - :param bool _created: Indicates whether this is a brand new document + :param _created: Indicates whether this is a brand new document or whether it's already been persisted before. Defaults to true. """ self._initialised = False From 5bcc6791947f0520a171dc5d4b843c9321efe683 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Thu, 10 Oct 2019 22:55:44 +0200 Subject: [PATCH 19/59] fix 2 pymongo deprecation warnings --- tests/test_connection.py | 40 ++++++++++++++++++++-------------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/tests/test_connection.py b/tests/test_connection.py index b7dc9268..071f4207 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1,35 +1,33 @@ import datetime +from bson.tz_util import utc +from nose.plugins.skip import SkipTest +import pymongo from pymongo import MongoClient -from pymongo.errors import OperationFailure, InvalidName from pymongo import ReadPreference +from pymongo.errors import InvalidName, OperationFailure -from mongoengine import Document try: import unittest2 as unittest except ImportError: import unittest -from nose.plugins.skip import SkipTest -import pymongo -from bson.tz_util import utc - -from mongoengine import ( - connect, - register_connection, - Document, - DateTimeField, - disconnect_all, - StringField, -) import mongoengine.connection +from mongoengine import ( + DateTimeField, + Document, + StringField, + connect, + disconnect_all, + register_connection, +) from mongoengine.connection import ( ConnectionFailure, - get_db, - get_connection, - disconnect, DEFAULT_DATABASE_NAME, + disconnect, + get_connection, + get_db, ) @@ -289,7 +287,7 @@ class ConnectionTest(unittest.TestCase): # database won't exist until we save a document some_document.save() self.assertEqual(conn.get_default_database().name, "mongoenginetest") - self.assertEqual(conn.database_names()[0], "mongoenginetest") + self.assertEqual(conn.list_database_names()[0], "mongoenginetest") def test_connect_with_host_list(self): """Ensure that the connect() method works when host is a list @@ -631,8 +629,10 @@ class ConnectionTest(unittest.TestCase): """Ensure write concern can be specified in connect() via a kwarg or as part of the connection URI. """ - conn1 = connect(alias="conn1", host="mongodb://localhost/testing?w=1&j=true") - conn2 = connect("testing", alias="conn2", w=1, j=True) + conn1 = connect( + alias="conn1", host="mongodb://localhost/testing?w=1&journal=true" + ) + conn2 = connect("testing", alias="conn2", w=1, journal=True) self.assertEqual(conn1.write_concern.document, {"w": 1, "j": True}) self.assertEqual(conn2.write_concern.document, {"w": 1, "j": True}) From c60ed32f3a795a36f15e419fb09fe40b826947d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Wed, 16 Oct 2019 21:25:17 +0200 Subject: [PATCH 20/59] Documented how pymongo.monitoring can be used with MongoEngine --- docs/changelog.rst | 2 + docs/guide/index.rst | 1 + docs/guide/logging-monitoring.rst | 80 +++++++++++++++++++++++++++++++ 3 files changed, 83 insertions(+) create mode 100644 docs/guide/logging-monitoring.rst diff --git a/docs/changelog.rst b/docs/changelog.rst index 5422f113..58d7f272 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -6,6 +6,8 @@ Changelog Development =========== - (Fill this out as you fix issues and develop your features). +- Documentation improvements: + - Documented how `pymongo.monitoring` can be used to log all queries issued by MongoEngine to the driver. - BREAKING CHANGE: ``class_check`` and ``read_preference`` keyword arguments are no longer available when filtering a ``QuerySet``. #2112 - Instead of ``Doc.objects(foo=bar, read_preference=...)`` use ``Doc.objects(foo=bar).read_preference(...)``. - Instead of ``Doc.objects(foo=bar, class_check=False)`` use ``Doc.objects(foo=bar).clear_cls_query(...)``. diff --git a/docs/guide/index.rst b/docs/guide/index.rst index 46eb7af2..a0364ec1 100644 --- a/docs/guide/index.rst +++ b/docs/guide/index.rst @@ -13,4 +13,5 @@ User Guide gridfs signals text-indexes + logging-monitoring mongomock diff --git a/docs/guide/logging-monitoring.rst b/docs/guide/logging-monitoring.rst new file mode 100644 index 00000000..9f523b79 --- /dev/null +++ b/docs/guide/logging-monitoring.rst @@ -0,0 +1,80 @@ +================== +Logging/Monitoring +================== + +It is possible to use `pymongo.monitoring `_ to monitor +the driver events (e.g: queries, connections, etc). This can be handy if you want to monitor the queries issued by +MongoEngine to the driver. + +To use `pymongo.monitoring` with MongoEngine, you need to make sure that you are registering the listeners +**before** establishing the database connection (i.e calling `connect`): + +The following snippet provides a basic logging of all command events: + +.. code-block:: python + + import logging + from pymongo import monitoring + from mongoengine import * + + log = logging.getLogger() + log.setLevel(logging.DEBUG) + logging.basicConfig(level=logging.DEBUG) + + + class CommandLogger(monitoring.CommandListener): + + def started(self, event): + log.debug("Command {0.command_name} with request id " + "{0.request_id} started on server " + "{0.connection_id}".format(event)) + + def succeeded(self, event): + log.debug("Command {0.command_name} with request id " + "{0.request_id} on server {0.connection_id} " + "succeeded in {0.duration_micros} " + "microseconds".format(event)) + + def failed(self, event): + log.debug("Command {0.command_name} with request id " + "{0.request_id} on server {0.connection_id} " + "failed in {0.duration_micros} " + "microseconds".format(event)) + + monitoring.register(CommandLogger()) + + + class Jedi(Document): + name = StringField() + + + connect() + + + log.info('GO!') + + log.info('Saving an item through MongoEngine...') + Jedi(name='Obi-Wan Kenobii').save() + + log.info('Querying through MongoEngine...') + obiwan = Jedi.objects.first() + + log.info('Updating through MongoEngine...') + obiwan.name = 'Obi-Wan Kenobi' + obiwan.save() + + +Executing this prints the following output:: + + INFO:root:GO! + INFO:root:Saving an item through MongoEngine... + DEBUG:root:Command insert with request id 1681692777 started on server ('localhost', 27017) + DEBUG:root:Command insert with request id 1681692777 on server ('localhost', 27017) succeeded in 562 microseconds + INFO:root:Querying through MongoEngine... + DEBUG:root:Command find with request id 1714636915 started on server ('localhost', 27017) + DEBUG:root:Command find with request id 1714636915 on server ('localhost', 27017) succeeded in 341 microseconds + INFO:root:Updating through MongoEngine... + DEBUG:root:Command update with request id 1957747793 started on server ('localhost', 27017) + DEBUG:root:Command update with request id 1957747793 on server ('localhost', 27017) succeeded in 455 microseconds + +More details can of course be obtained by checking the `event` argument from the `CommandListener`. From 8bf5370b6cdac97e00b314d6cd57016494a25873 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Mon, 28 Oct 2019 22:05:13 +0100 Subject: [PATCH 21/59] Improve error message from InvalidDocumentError whenever an embedded document has a bad shape (e.g due to migration) --- docs/changelog.rst | 1 + mongoengine/base/document.py | 9 +++++++-- tests/document/test_instance.py | 23 +++++++++++++++++++++++ 3 files changed, 31 insertions(+), 2 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 5422f113..a717b837 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -15,6 +15,7 @@ Development - If you catch/use ``MongoEngineConnectionError`` in your code, you'll have to rename it. - BREAKING CHANGE: Positional arguments when instantiating a document are no longer supported. #2103 - From now on keyword arguments (e.g. ``Doc(field_name=value)``) are required. +- Improve error message related to InvalidDocumentError #2180 - Fix updating/modifying/deleting/reloading a document that's sharded by a field with ``db_field`` specified. #2125 - ``ListField`` now accepts an optional ``max_length`` parameter. #2110 - The codebase is now formatted using ``black``. #2109 diff --git a/mongoengine/base/document.py b/mongoengine/base/document.py index a962a82b..a967436a 100644 --- a/mongoengine/base/document.py +++ b/mongoengine/base/document.py @@ -732,7 +732,10 @@ class BaseDocument(object): only_fields = [] if son and not isinstance(son, dict): - raise ValueError("The source SON object needs to be of type 'dict'") + raise ValueError( + "The source SON object needs to be of type 'dict' but a '%s' was found" + % type(son) + ) # Get the class name from the document, falling back to the given # class if unavailable @@ -770,7 +773,9 @@ class BaseDocument(object): errors_dict[field_name] = e if errors_dict: - errors = "\n".join(["%s - %s" % (k, v) for k, v in errors_dict.items()]) + errors = "\n".join( + ["Field '%s' - %s" % (k, v) for k, v in errors_dict.items()] + ) msg = "Invalid data to create a `%s` instance.\n%s" % ( cls._class_name, errors, diff --git a/tests/document/test_instance.py b/tests/document/test_instance.py index 9b4a16e5..60e5313d 100644 --- a/tests/document/test_instance.py +++ b/tests/document/test_instance.py @@ -3656,6 +3656,29 @@ class TestInstance(MongoDBTestCase): with self.assertRaises(DuplicateKeyError): User.objects().select_related() + def test_embedded_document_failed_while_loading_instance_when_it_is_not_a_dict( + self + ): + class LightSaber(EmbeddedDocument): + color = StringField() + + class Jedi(Document): + light_saber = EmbeddedDocumentField(LightSaber) + + coll = Jedi._get_collection() + Jedi(light_saber=LightSaber(color="red")).save() + _ = list(Jedi.objects) # Ensure a proper document loads without errors + + # Forces a document with a wrong shape (may occur in case of migration) + coll.insert_one({"light_saber": "I_should_be_a_dict"}) + + with self.assertRaises(InvalidDocumentError) as cm: + list(Jedi.objects) + self.assertEqual( + str(cm.exception), + "Invalid data to create a `Jedi` instance.\nField 'light_saber' - The source SON object needs to be of type 'dict' but a '' was found", + ) + class ObjectKeyTestCase(MongoDBTestCase): def test_object_key_simple_document(self): From 54ca7bf09fa70a4fba3d83d8fc77090cddaaae67 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Mon, 28 Oct 2019 22:38:21 +0100 Subject: [PATCH 22/59] fix associated test to avoid discrepencies btw py2 and py3 --- tests/document/test_instance.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/document/test_instance.py b/tests/document/test_instance.py index 60e5313d..7a868d29 100644 --- a/tests/document/test_instance.py +++ b/tests/document/test_instance.py @@ -3670,13 +3670,16 @@ class TestInstance(MongoDBTestCase): _ = list(Jedi.objects) # Ensure a proper document loads without errors # Forces a document with a wrong shape (may occur in case of migration) - coll.insert_one({"light_saber": "I_should_be_a_dict"}) + value = u"I_should_be_a_dict" + coll.insert_one({"light_saber": value}) with self.assertRaises(InvalidDocumentError) as cm: list(Jedi.objects) + self.assertEqual( str(cm.exception), - "Invalid data to create a `Jedi` instance.\nField 'light_saber' - The source SON object needs to be of type 'dict' but a '' was found", + "Invalid data to create a `Jedi` instance.\nField 'light_saber' - The source SON object needs to be of type 'dict' but a '%s' was found" + % type(value), ) From bbfa97886188584ffcc7cfb73d084e2206832c42 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Sun, 25 Aug 2019 15:21:30 +0300 Subject: [PATCH 23/59] switch test runner from nose to pytest --- .travis.yml | 6 ++--- README.rst | 11 ++++---- setup.cfg | 11 ++++---- setup.py | 66 ++++++++++++++++++++++++++++++++++++++++++++++-- tests/test_ci.py | 9 +++++++ tox.ini | 2 +- 6 files changed, 87 insertions(+), 18 deletions(-) create mode 100644 tests/test_ci.py diff --git a/.travis.yml b/.travis.yml index af1e2b14..9d2ba8c1 100644 --- a/.travis.yml +++ b/.travis.yml @@ -63,8 +63,8 @@ install: - pip install flake8 flake8-import-order - pip install tox # tox 3.11.0 has requirement virtualenv>=14.0.0 - pip install virtualenv # virtualenv>=14.0.0 has dropped Python 3.2 support (and pypy3 is based on py32) - # Install the tox venv. - - tox -e $(echo py$TRAVIS_PYTHON_VERSION-mg$PYMONGO | tr -d . | sed -e 's/pypypy/pypy/') -- -e test + # Install the tox venv (we make pytest avoid running the test by giving a bad pattern) + - tox -e $(echo py$TRAVIS_PYTHON_VERSION-mg$PYMONGO | tr -d . | sed -e 's/pypypy/pypy/') -- -a "-k=test_ci_placeholder" # Install black for Python v3.7 only. - if [[ $TRAVIS_PYTHON_VERSION == '3.7' ]]; then pip install black; fi @@ -76,7 +76,7 @@ before_script: - mongo --eval 'db.version();' # Make sure mongo is awake script: - - tox -e $(echo py$TRAVIS_PYTHON_VERSION-mg$PYMONGO | tr -d . | sed -e 's/pypypy/pypy/') -- --with-coverage + - tox -e $(echo py$TRAVIS_PYTHON_VERSION-mg$PYMONGO | tr -d . | sed -e 's/pypypy/pypy/') #-- --with-coverage # For now only submit coveralls for Python v2.7. Python v3.x currently shows # 0% coverage. That's caused by 'use_2to3', which builds the py3-compatible diff --git a/README.rst b/README.rst index 679980f8..853d8fbe 100644 --- a/README.rst +++ b/README.rst @@ -116,7 +116,8 @@ Some simple examples of what MongoEngine code looks like: Tests ===== To run the test suite, ensure you are running a local instance of MongoDB on -the standard port and have ``nose`` installed. Then, run ``python setup.py nosetests``. +the standard port and have ``pytest`` installed. Then, run ``python setup.py test`` +or simply ``pytest``. To run the test suite on every supported Python and PyMongo version, you can use ``tox``. You'll need to make sure you have each supported Python version @@ -129,16 +130,14 @@ installed in your environment and then: # Run the test suites $ tox -If you wish to run a subset of tests, use the nosetests convention: +If you wish to run a subset of tests, use the pytest convention: .. code-block:: shell # Run all the tests in a particular test file - $ python setup.py nosetests --tests tests/fields/fields.py + $ pytest tests/fields/test_fields.py # Run only particular test class in that file - $ python setup.py nosetests --tests tests/fields/fields.py:FieldTest - # Use the -s option if you want to print some debug statements or use pdb - $ python setup.py nosetests --tests tests/fields/fields.py:FieldTest -s + $ pytest tests/fields/test_fields.py::TestField Community ========= diff --git a/setup.cfg b/setup.cfg index 4bded428..ae1b4f7e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,11 +1,10 @@ -[nosetests] -verbosity=2 -detailed-errors=1 -#tests=tests -cover-package=mongoengine - [flake8] ignore=E501,F401,F403,F405,I201,I202,W504, W605, W503 exclude=build,dist,docs,venv,venv3,.tox,.eggs,tests max-complexity=47 application-import-names=mongoengine,tests + +[tool:pytest] +# Limits the discovery to tests directory +# avoids that it runs for instance the benchmark +testpaths = tests diff --git a/setup.py b/setup.py index c73a93ff..81cc9744 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,9 @@ import os import sys + +from pkg_resources import normalize_path from setuptools import find_packages, setup +from setuptools.command.test import test as TestCommand # Hack to silence atexit traceback in newer python versions try: @@ -24,6 +27,65 @@ def get_version(version_tuple): return ".".join(map(str, version_tuple)) +class PyTest(TestCommand): + """Will force pytest to search for tests inside the build directory + for 2to3 converted code (used by tox), instead of the current directory. + Required as long as we need 2to3 + + Known Limitation: https://tox.readthedocs.io/en/latest/example/pytest.html#known-issues-and-limitations + Source: https://www.hackzine.org/python-testing-with-pytest-and-2to3-plus-tox-and-travis-ci.html + """ + + # https://pytest.readthedocs.io/en/2.7.3/goodpractises.html#integration-with-setuptools-test-commands + # Allows to provide pytest command arguments through the test runner command `python setup.py test` + # e.g: `python setup.py test -a "-k=test"` + user_options = [("pytest-args=", "a", "Arguments to pass to py.test")] + + def initialize_options(self): + TestCommand.initialize_options(self) + self.pytest_args = "" + + def finalize_options(self): + TestCommand.finalize_options(self) + self.test_args = ["tests"] + self.test_suite = True + + def run_tests(self): + # import here, cause outside the eggs aren't loaded + from pkg_resources import _namespace_packages + import pytest + + # Purge modules under test from sys.modules. The test loader will + # re-import them from the build location. Required when 2to3 is used + # with namespace packages. + if sys.version_info >= (3,) and getattr(self.distribution, "use_2to3", False): + print("Hack for 2to3", self.test_args) + module = self.test_args[-1].split(".")[0] + if module in _namespace_packages: + del_modules = [] + if module in sys.modules: + del_modules.append(module) + module += "." + for name in sys.modules: + if name.startswith(module): + del_modules.append(name) + map(sys.modules.__delitem__, del_modules) + + # Run on the build directory for 2to3-built code + # This will prevent the old 2.x code from being found + # by py.test discovery mechanism, that apparently + # ignores sys.path.. + ei_cmd = self.get_finalized_command("egg_info") + self.test_args = [normalize_path(ei_cmd.egg_base)] + + print(self.test_args, self.pytest_args) + cmd_args = self.test_args + ([self.pytest_args] if self.pytest_args else []) + print(cmd_args) + errno = pytest.main(cmd_args) + + sys.exit(errno) + + # Dirty hack to get version number from monogengine/__init__.py - we can't # import it as it depends on PyMongo and PyMongo isn't installed until this # file is read @@ -51,7 +113,7 @@ CLASSIFIERS = [ extra_opts = { "packages": find_packages(exclude=["tests", "tests.*"]), - "tests_require": ["nose", "coverage==4.2", "blinker", "Pillow>=2.0.0"], + "tests_require": ["pytest<5.0", "coverage==4.2", "blinker", "Pillow>=2.0.0"], } if sys.version_info[0] == 3: extra_opts["use_2to3"] = True @@ -79,6 +141,6 @@ setup( platforms=["any"], classifiers=CLASSIFIERS, install_requires=["pymongo>=3.4", "six"], - test_suite="nose.collector", + cmdclass={"test": PyTest}, **extra_opts ) diff --git a/tests/test_ci.py b/tests/test_ci.py new file mode 100644 index 00000000..04a800eb --- /dev/null +++ b/tests/test_ci.py @@ -0,0 +1,9 @@ +def test_ci_placeholder(): + # This empty test is used within the CI to + # setup the tox venv without running the test suite + # if we simply skip all test with pytest -k=wrong_pattern + # pytest command would return with exit_code=5 (i.e "no tests run") + # making travis fail + # this empty test is the recommended way to handle this + # as described in https://github.com/pytest-dev/pytest/issues/2393 + pass diff --git a/tox.ini b/tox.ini index b4a57818..94ccc9cf 100644 --- a/tox.ini +++ b/tox.ini @@ -3,7 +3,7 @@ envlist = {py27,py35,pypy,pypy3}-{mg34,mg36} [testenv] commands = - python setup.py nosetests {posargs} + python setup.py test {posargs} deps = nose mg34: pymongo>=3.4,<3.5 From 5a16dda50d228d7670bfa7467be7e6124369406f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Mon, 26 Aug 2019 17:12:56 +0300 Subject: [PATCH 24/59] fix coverage for pytest runner --- .travis.yml | 4 ++-- setup.py | 8 +++++++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/.travis.yml b/.travis.yml index 9d2ba8c1..2992d416 100644 --- a/.travis.yml +++ b/.travis.yml @@ -76,13 +76,13 @@ before_script: - mongo --eval 'db.version();' # Make sure mongo is awake script: - - tox -e $(echo py$TRAVIS_PYTHON_VERSION-mg$PYMONGO | tr -d . | sed -e 's/pypypy/pypy/') #-- --with-coverage + - tox -e $(echo py$TRAVIS_PYTHON_VERSION-mg$PYMONGO | tr -d . | sed -e 's/pypypy/pypy/') -- -a "--cov=mongoengine" # For now only submit coveralls for Python v2.7. Python v3.x currently shows # 0% coverage. That's caused by 'use_2to3', which builds the py3-compatible # code in a separate dir and runs tests on that. after_success: -- if [[ $TRAVIS_PYTHON_VERSION == '2.7' ]]; then coveralls --verbose; fi +- if [[ $TRAVIS_PYTHON_VERSION == '2.7' ]]; then cat .coverage & coveralls --verbose; fi notifications: irc: irc.freenode.org#mongoengine diff --git a/setup.py b/setup.py index 81cc9744..94f71162 100644 --- a/setup.py +++ b/setup.py @@ -113,7 +113,13 @@ CLASSIFIERS = [ extra_opts = { "packages": find_packages(exclude=["tests", "tests.*"]), - "tests_require": ["pytest<5.0", "coverage==4.2", "blinker", "Pillow>=2.0.0"], + "tests_require": [ + "pytest<5.0", + "pytest-cov", + "coverage", + "blinker", + "Pillow>=2.0.0", + ], } if sys.version_info[0] == 3: extra_opts["use_2to3"] = True From 51ea3e3c6ff562260616127cde2f806008abc85d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Wed, 28 Aug 2019 15:07:27 +0300 Subject: [PATCH 25/59] fix for recent coverage/coveralls compatibility issue --- .travis.yml | 4 ++-- setup.py | 9 +++------ 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/.travis.yml b/.travis.yml index 2992d416..cbf34cde 100644 --- a/.travis.yml +++ b/.travis.yml @@ -63,7 +63,7 @@ install: - pip install flake8 flake8-import-order - pip install tox # tox 3.11.0 has requirement virtualenv>=14.0.0 - pip install virtualenv # virtualenv>=14.0.0 has dropped Python 3.2 support (and pypy3 is based on py32) - # Install the tox venv (we make pytest avoid running the test by giving a bad pattern) + # tox dryrun to setup the tox venv (we run a mock test). - tox -e $(echo py$TRAVIS_PYTHON_VERSION-mg$PYMONGO | tr -d . | sed -e 's/pypypy/pypy/') -- -a "-k=test_ci_placeholder" # Install black for Python v3.7 only. - if [[ $TRAVIS_PYTHON_VERSION == '3.7' ]]; then pip install black; fi @@ -82,7 +82,7 @@ script: # 0% coverage. That's caused by 'use_2to3', which builds the py3-compatible # code in a separate dir and runs tests on that. after_success: -- if [[ $TRAVIS_PYTHON_VERSION == '2.7' ]]; then cat .coverage & coveralls --verbose; fi +- if [[ $TRAVIS_PYTHON_VERSION == '2.7' ]]; then coveralls --verbose; else echo "coveralls only sent for py27"; fi notifications: irc: irc.freenode.org#mongoengine diff --git a/setup.py b/setup.py index 94f71162..2bc1ae1c 100644 --- a/setup.py +++ b/setup.py @@ -37,8 +37,9 @@ class PyTest(TestCommand): """ # https://pytest.readthedocs.io/en/2.7.3/goodpractises.html#integration-with-setuptools-test-commands - # Allows to provide pytest command arguments through the test runner command `python setup.py test` + # Allows to provide pytest command argument through the test runner command `python setup.py test` # e.g: `python setup.py test -a "-k=test"` + # This only works for 1 argument though user_options = [("pytest-args=", "a", "Arguments to pass to py.test")] def initialize_options(self): @@ -59,7 +60,6 @@ class PyTest(TestCommand): # re-import them from the build location. Required when 2to3 is used # with namespace packages. if sys.version_info >= (3,) and getattr(self.distribution, "use_2to3", False): - print("Hack for 2to3", self.test_args) module = self.test_args[-1].split(".")[0] if module in _namespace_packages: del_modules = [] @@ -78,11 +78,8 @@ class PyTest(TestCommand): ei_cmd = self.get_finalized_command("egg_info") self.test_args = [normalize_path(ei_cmd.egg_base)] - print(self.test_args, self.pytest_args) cmd_args = self.test_args + ([self.pytest_args] if self.pytest_args else []) - print(cmd_args) errno = pytest.main(cmd_args) - sys.exit(errno) @@ -116,7 +113,7 @@ extra_opts = { "tests_require": [ "pytest<5.0", "pytest-cov", - "coverage", + "coverage<5.0", # recent coverage switched to sqlite format for the .coverage file which isn't handled properly by coveralls "blinker", "Pillow>=2.0.0", ], From 6040b4b494f93efea415ad6e05a1d33e5834e6c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Thu, 31 Oct 2019 21:33:19 +0100 Subject: [PATCH 26/59] update changelog --- docs/changelog.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/changelog.rst b/docs/changelog.rst index 58d7f272..249d99b1 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -19,6 +19,7 @@ Development - From now on keyword arguments (e.g. ``Doc(field_name=value)``) are required. - Fix updating/modifying/deleting/reloading a document that's sharded by a field with ``db_field`` specified. #2125 - ``ListField`` now accepts an optional ``max_length`` parameter. #2110 +- Switch from nosetest to pytest as test runner #2114 - The codebase is now formatted using ``black``. #2109 - In bulk write insert, the detailed error message would raise in exception. From 37ca79e9c58e04ce0fd3a6775d804eb4dad6d8c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Thu, 31 Oct 2019 22:39:53 +0100 Subject: [PATCH 27/59] fix black formatting --- docs/conf.py | 3 ++- mongoengine/context_managers.py | 4 ++-- mongoengine/queryset/base.py | 4 +--- tests/document/test_instance.py | 4 ++-- tests/fields/test_complex_datetime_field.py | 2 +- tests/fields/test_embedded_document_field.py | 12 +++++------- tests/fields/test_fields.py | 8 ++++---- tests/fields/test_reference_field.py | 2 +- tests/test_common.py | 2 +- tests/test_connection.py | 2 +- tests/test_context_managers.py | 6 +++--- tests/test_datastructures.py | 10 +++++----- tests/test_replicaset_connection.py | 2 +- tests/test_utils.py | 2 +- tests/utils.py | 2 +- 15 files changed, 31 insertions(+), 34 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 0d642e0c..48c8e859 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -11,7 +11,8 @@ # All configuration values have a default; values that are commented out # serve to show the default. -import sys, os +import os +import sys import sphinx_rtd_theme diff --git a/mongoengine/context_managers.py b/mongoengine/context_managers.py index 3424a5d5..d8dfeaac 100644 --- a/mongoengine/context_managers.py +++ b/mongoengine/context_managers.py @@ -247,8 +247,8 @@ class query_counter(object): - self._ctx_query_counter ) self._ctx_query_counter += ( - 1 - ) # Account for the query we just issued to gather the information + 1 # Account for the query we just issued to gather the information + ) return count diff --git a/mongoengine/queryset/base.py b/mongoengine/queryset/base.py index a09cbf99..a648391e 100644 --- a/mongoengine/queryset/base.py +++ b/mongoengine/queryset/base.py @@ -1193,9 +1193,7 @@ class BaseQuerySet(object): validate_read_preference("read_preference", read_preference) queryset = self.clone() queryset._read_preference = read_preference - queryset._cursor_obj = ( - None - ) # we need to re-create the cursor object whenever we apply read_preference + queryset._cursor_obj = None # we need to re-create the cursor object whenever we apply read_preference return queryset def scalar(self, *fields): diff --git a/tests/document/test_instance.py b/tests/document/test_instance.py index 9b4a16e5..203e2cce 100644 --- a/tests/document/test_instance.py +++ b/tests/document/test_instance.py @@ -1615,7 +1615,7 @@ class TestInstance(MongoDBTestCase): self.assertEqual(person.active, False) def test__get_changed_fields_same_ids_reference_field_does_not_enters_infinite_loop_embedded_doc( - self + self, ): # Refers to Issue #1685 class EmbeddedChildModel(EmbeddedDocument): @@ -1629,7 +1629,7 @@ class TestInstance(MongoDBTestCase): self.assertEqual(changed_fields, []) def test__get_changed_fields_same_ids_reference_field_does_not_enters_infinite_loop_different_doc( - self + self, ): # Refers to Issue #1685 class User(Document): diff --git a/tests/fields/test_complex_datetime_field.py b/tests/fields/test_complex_datetime_field.py index 4eea5bdc..611c0ff8 100644 --- a/tests/fields/test_complex_datetime_field.py +++ b/tests/fields/test_complex_datetime_field.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- import datetime -import math import itertools +import math import re from mongoengine import * diff --git a/tests/fields/test_embedded_document_field.py b/tests/fields/test_embedded_document_field.py index 6b420781..8db8c180 100644 --- a/tests/fields/test_embedded_document_field.py +++ b/tests/fields/test_embedded_document_field.py @@ -1,17 +1,15 @@ # -*- coding: utf-8 -*- from mongoengine import ( Document, - StringField, - ValidationError, EmbeddedDocument, EmbeddedDocumentField, - InvalidQueryError, - LookUpError, - IntField, GenericEmbeddedDocumentField, + IntField, + InvalidQueryError, ListField, - EmbeddedDocumentListField, - ReferenceField, + LookUpError, + StringField, + ValidationError, ) from tests.utils import MongoDBTestCase diff --git a/tests/fields/test_fields.py b/tests/fields/test_fields.py index d9279c22..bd2149e6 100644 --- a/tests/fields/test_fields.py +++ b/tests/fields/test_fields.py @@ -79,7 +79,7 @@ class TestField(MongoDBTestCase): self.assertEqual(data_to_be_saved, ["age", "created", "day", "name", "userid"]) def test_custom_field_validation_raise_deprecated_error_when_validation_return_something( - self + self, ): # Covers introduction of a breaking change in the validation parameter (0.18) def _not_empty(z): @@ -202,7 +202,7 @@ class TestField(MongoDBTestCase): self.assertEqual(data_to_be_saved, ["age", "created", "userid"]) def test_default_value_is_not_used_when_changing_value_to_empty_list_for_strict_doc( - self + self, ): """List field with default can be set to the empty list (strict)""" # Issue #1733 @@ -216,7 +216,7 @@ class TestField(MongoDBTestCase): self.assertEqual(reloaded.x, []) def test_default_value_is_not_used_when_changing_value_to_empty_list_for_dyn_doc( - self + self, ): """List field with default can be set to the empty list (dynamic)""" # Issue #1733 @@ -1245,7 +1245,7 @@ class TestField(MongoDBTestCase): self.assertEqual(a.b.c.txt, "hi") def test_embedded_document_field_cant_reference_using_a_str_if_it_does_not_exist_yet( - self + self, ): raise SkipTest( "Using a string reference in an EmbeddedDocumentField does not work if the class isnt registerd yet" diff --git a/tests/fields/test_reference_field.py b/tests/fields/test_reference_field.py index 5fd053fe..783a46da 100644 --- a/tests/fields/test_reference_field.py +++ b/tests/fields/test_reference_field.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -from bson import SON, DBRef +from bson import DBRef, SON from mongoengine import * diff --git a/tests/test_common.py b/tests/test_common.py index 5d702668..28f0b992 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -1,7 +1,7 @@ import unittest -from mongoengine.common import _import_class from mongoengine import Document +from mongoengine.common import _import_class class TestCommon(unittest.TestCase): diff --git a/tests/test_connection.py b/tests/test_connection.py index 071f4207..1519a835 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -168,7 +168,7 @@ class ConnectionTest(unittest.TestCase): ) def test_connect_fails_if_similar_connection_settings_arent_defined_the_same_way( - self + self, ): """Intended to keep the detecton function simple but robust""" db_name = "mongoenginetest" diff --git a/tests/test_context_managers.py b/tests/test_context_managers.py index dc9b9bf3..32e48a70 100644 --- a/tests/test_context_managers.py +++ b/tests/test_context_managers.py @@ -3,11 +3,11 @@ import unittest from mongoengine import * from mongoengine.connection import get_db from mongoengine.context_managers import ( - switch_db, - switch_collection, - no_sub_classes, no_dereference, + no_sub_classes, query_counter, + switch_collection, + switch_db, ) from mongoengine.pymongo_support import count_documents diff --git a/tests/test_datastructures.py b/tests/test_datastructures.py index 7def2ac7..ff7598be 100644 --- a/tests/test_datastructures.py +++ b/tests/test_datastructures.py @@ -2,7 +2,7 @@ import unittest from six import iterkeys from mongoengine import Document -from mongoengine.base.datastructures import StrictDict, BaseList, BaseDict +from mongoengine.base.datastructures import BaseDict, BaseList, StrictDict class DocumentStub(object): @@ -20,8 +20,8 @@ class TestBaseDict(unittest.TestCase): fake_doc = DocumentStub() base_list = BaseDict(dict_items, instance=None, name="my_name") base_list._instance = ( - fake_doc - ) # hack to inject the mock, it does not work in the constructor + fake_doc # hack to inject the mock, it does not work in the constructor + ) return base_list def test___init___(self): @@ -156,8 +156,8 @@ class TestBaseList(unittest.TestCase): fake_doc = DocumentStub() base_list = BaseList(list_items, instance=None, name="my_name") base_list._instance = ( - fake_doc - ) # hack to inject the mock, it does not work in the constructor + fake_doc # hack to inject the mock, it does not work in the constructor + ) return base_list def test___init___(self): diff --git a/tests/test_replicaset_connection.py b/tests/test_replicaset_connection.py index 6dfab407..e92f3d09 100644 --- a/tests/test_replicaset_connection.py +++ b/tests/test_replicaset_connection.py @@ -1,7 +1,7 @@ import unittest -from pymongo import ReadPreference from pymongo import MongoClient +from pymongo import ReadPreference import mongoengine from mongoengine.connection import ConnectionFailure diff --git a/tests/test_utils.py b/tests/test_utils.py index 2d1e8b00..897c19b2 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,5 +1,5 @@ -import unittest import re +import unittest from mongoengine.base.utils import LazyRegexCompiler diff --git a/tests/utils.py b/tests/utils.py index eb3f016f..0719d6ef 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -4,7 +4,7 @@ import unittest from nose.plugins.skip import SkipTest from mongoengine import connect -from mongoengine.connection import get_db, disconnect_all +from mongoengine.connection import disconnect_all, get_db from mongoengine.mongodb_support import get_mongodb_version From ac25f4b98bd8c4b6daad46faf1e8a163928d7bc2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Fri, 30 Aug 2019 16:13:30 +0300 Subject: [PATCH 28/59] ran unittest2pytest --- tests/all_warnings/test_warnings.py | 6 +- tests/document/test_class_methods.py | 133 +- tests/document/test_delta.py | 638 +++---- tests/document/test_dynamic.py | 193 +- tests/document/test_indexes.py | 285 ++- tests/document/test_inheritance.py | 256 ++- tests/document/test_instance.py | 1047 +++++----- tests/document/test_json_serialisation.py | 8 +- tests/document/test_validation.py | 76 +- tests/fields/test_binary_field.py | 40 +- tests/fields/test_boolean_field.py | 16 +- tests/fields/test_cached_reference_field.py | 179 +- tests/fields/test_complex_datetime_field.py | 46 +- tests/fields/test_date_field.py | 45 +- tests/fields/test_datetime_field.py | 71 +- tests/fields/test_decimal_field.py | 30 +- tests/fields/test_dict_field.py | 139 +- tests/fields/test_email_field.py | 37 +- tests/fields/test_embedded_document_field.py | 103 +- tests/fields/test_fields.py | 851 ++++----- tests/fields/test_file_field.py | 166 +- tests/fields/test_float_field.py | 20 +- tests/fields/test_geo_fields.py | 46 +- tests/fields/test_int_field.py | 14 +- tests/fields/test_lazy_reference_field.py | 118 +- tests/fields/test_long_field.py | 16 +- tests/fields/test_map_field.py | 31 +- tests/fields/test_reference_field.py | 46 +- tests/fields/test_sequence_field.py | 99 +- tests/fields/test_url_field.py | 15 +- tests/fields/test_uuid_field.py | 19 +- tests/queryset/test_field_list.py | 197 +- tests/queryset/test_geo.py | 205 +- tests/queryset/test_modify.py | 32 +- tests/queryset/test_pickable.py | 10 +- tests/queryset/test_queryset.py | 1784 +++++++++--------- tests/queryset/test_transform.py | 178 +- tests/queryset/test_visitor.py | 172 +- tests/test_common.py | 6 +- tests/test_connection.py | 254 ++- tests/test_context_managers.py | 139 +- tests/test_datastructures.py | 241 +-- tests/test_dereference.py | 386 ++-- tests/test_replicaset_connection.py | 2 +- tests/test_signals.py | 265 ++- tests/test_utils.py | 15 +- 46 files changed, 4247 insertions(+), 4428 deletions(-) diff --git a/tests/all_warnings/test_warnings.py b/tests/all_warnings/test_warnings.py index 67204617..a9910121 100644 --- a/tests/all_warnings/test_warnings.py +++ b/tests/all_warnings/test_warnings.py @@ -31,7 +31,5 @@ class TestAllWarnings(unittest.TestCase): meta = {"collection": "fail"} warning = self.warning_list[0] - self.assertEqual(SyntaxWarning, warning["category"]) - self.assertEqual( - "non_abstract_base", InheritedDocumentFailTest._get_collection_name() - ) + assert SyntaxWarning == warning["category"] + assert "non_abstract_base" == InheritedDocumentFailTest._get_collection_name() diff --git a/tests/document/test_class_methods.py b/tests/document/test_class_methods.py index c5df0843..98909d2f 100644 --- a/tests/document/test_class_methods.py +++ b/tests/document/test_class_methods.py @@ -29,43 +29,40 @@ class TestClassMethods(unittest.TestCase): def test_definition(self): """Ensure that document may be defined using fields. """ - self.assertEqual( - ["_cls", "age", "id", "name"], sorted(self.Person._fields.keys()) - ) - self.assertEqual( - ["IntField", "ObjectIdField", "StringField", "StringField"], - sorted([x.__class__.__name__ for x in self.Person._fields.values()]), + assert ["_cls", "age", "id", "name"] == sorted(self.Person._fields.keys()) + assert ["IntField", "ObjectIdField", "StringField", "StringField"] == sorted( + [x.__class__.__name__ for x in self.Person._fields.values()] ) def test_get_db(self): """Ensure that get_db returns the expected db. """ db = self.Person._get_db() - self.assertEqual(self.db, db) + assert self.db == db def test_get_collection_name(self): """Ensure that get_collection_name returns the expected collection name. """ collection_name = "person" - self.assertEqual(collection_name, self.Person._get_collection_name()) + assert collection_name == self.Person._get_collection_name() def test_get_collection(self): """Ensure that get_collection returns the expected collection. """ collection_name = "person" collection = self.Person._get_collection() - self.assertEqual(self.db[collection_name], collection) + assert self.db[collection_name] == collection def test_drop_collection(self): """Ensure that the collection may be dropped from the database. """ collection_name = "person" self.Person(name="Test").save() - self.assertIn(collection_name, list_collection_names(self.db)) + assert collection_name in list_collection_names(self.db) self.Person.drop_collection() - self.assertNotIn(collection_name, list_collection_names(self.db)) + assert collection_name not in list_collection_names(self.db) def test_register_delete_rule(self): """Ensure that register delete rule adds a delete rule to the document @@ -75,12 +72,10 @@ class TestClassMethods(unittest.TestCase): class Job(Document): employee = ReferenceField(self.Person) - self.assertEqual(self.Person._meta.get("delete_rules"), None) + assert self.Person._meta.get("delete_rules") == None self.Person.register_delete_rule(Job, "employee", NULLIFY) - self.assertEqual( - self.Person._meta["delete_rules"], {(Job, "employee"): NULLIFY} - ) + assert self.Person._meta["delete_rules"] == {(Job, "employee"): NULLIFY} def test_compare_indexes(self): """ Ensure that the indexes are properly created and that @@ -98,22 +93,22 @@ class TestClassMethods(unittest.TestCase): BlogPost.drop_collection() BlogPost.ensure_indexes() - self.assertEqual(BlogPost.compare_indexes(), {"missing": [], "extra": []}) + assert BlogPost.compare_indexes() == {"missing": [], "extra": []} BlogPost.ensure_index(["author", "description"]) - self.assertEqual( - BlogPost.compare_indexes(), - {"missing": [], "extra": [[("author", 1), ("description", 1)]]}, - ) + assert BlogPost.compare_indexes() == { + "missing": [], + "extra": [[("author", 1), ("description", 1)]], + } BlogPost._get_collection().drop_index("author_1_description_1") - self.assertEqual(BlogPost.compare_indexes(), {"missing": [], "extra": []}) + assert BlogPost.compare_indexes() == {"missing": [], "extra": []} BlogPost._get_collection().drop_index("author_1_title_1") - self.assertEqual( - BlogPost.compare_indexes(), - {"missing": [[("author", 1), ("title", 1)]], "extra": []}, - ) + assert BlogPost.compare_indexes() == { + "missing": [[("author", 1), ("title", 1)]], + "extra": [], + } def test_compare_indexes_inheritance(self): """ Ensure that the indexes are properly created and that @@ -138,22 +133,22 @@ class TestClassMethods(unittest.TestCase): BlogPost.ensure_indexes() BlogPostWithTags.ensure_indexes() - self.assertEqual(BlogPost.compare_indexes(), {"missing": [], "extra": []}) + assert BlogPost.compare_indexes() == {"missing": [], "extra": []} BlogPostWithTags.ensure_index(["author", "tag_list"]) - self.assertEqual( - BlogPost.compare_indexes(), - {"missing": [], "extra": [[("_cls", 1), ("author", 1), ("tag_list", 1)]]}, - ) + assert BlogPost.compare_indexes() == { + "missing": [], + "extra": [[("_cls", 1), ("author", 1), ("tag_list", 1)]], + } BlogPostWithTags._get_collection().drop_index("_cls_1_author_1_tag_list_1") - self.assertEqual(BlogPost.compare_indexes(), {"missing": [], "extra": []}) + assert BlogPost.compare_indexes() == {"missing": [], "extra": []} BlogPostWithTags._get_collection().drop_index("_cls_1_author_1_tags_1") - self.assertEqual( - BlogPost.compare_indexes(), - {"missing": [[("_cls", 1), ("author", 1), ("tags", 1)]], "extra": []}, - ) + assert BlogPost.compare_indexes() == { + "missing": [[("_cls", 1), ("author", 1), ("tags", 1)]], + "extra": [], + } def test_compare_indexes_multiple_subclasses(self): """ Ensure that compare_indexes behaves correctly if called from a @@ -182,13 +177,9 @@ class TestClassMethods(unittest.TestCase): BlogPostWithTags.ensure_indexes() BlogPostWithCustomField.ensure_indexes() - self.assertEqual(BlogPost.compare_indexes(), {"missing": [], "extra": []}) - self.assertEqual( - BlogPostWithTags.compare_indexes(), {"missing": [], "extra": []} - ) - self.assertEqual( - BlogPostWithCustomField.compare_indexes(), {"missing": [], "extra": []} - ) + assert BlogPost.compare_indexes() == {"missing": [], "extra": []} + assert BlogPostWithTags.compare_indexes() == {"missing": [], "extra": []} + assert BlogPostWithCustomField.compare_indexes() == {"missing": [], "extra": []} def test_compare_indexes_for_text_indexes(self): """ Ensure that compare_indexes behaves correctly for text indexes """ @@ -210,7 +201,7 @@ class TestClassMethods(unittest.TestCase): Doc.ensure_indexes() actual = Doc.compare_indexes() expected = {"missing": [], "extra": []} - self.assertEqual(actual, expected) + assert actual == expected def test_list_indexes_inheritance(self): """ ensure that all of the indexes are listed regardless of the super- @@ -240,19 +231,14 @@ class TestClassMethods(unittest.TestCase): BlogPostWithTags.ensure_indexes() BlogPostWithTagsAndExtraText.ensure_indexes() - self.assertEqual(BlogPost.list_indexes(), BlogPostWithTags.list_indexes()) - self.assertEqual( - BlogPost.list_indexes(), BlogPostWithTagsAndExtraText.list_indexes() - ) - self.assertEqual( - BlogPost.list_indexes(), - [ - [("_cls", 1), ("author", 1), ("tags", 1)], - [("_cls", 1), ("author", 1), ("tags", 1), ("extra_text", 1)], - [(u"_id", 1)], - [("_cls", 1)], - ], - ) + assert BlogPost.list_indexes() == BlogPostWithTags.list_indexes() + assert BlogPost.list_indexes() == BlogPostWithTagsAndExtraText.list_indexes() + assert BlogPost.list_indexes() == [ + [("_cls", 1), ("author", 1), ("tags", 1)], + [("_cls", 1), ("author", 1), ("tags", 1), ("extra_text", 1)], + [(u"_id", 1)], + [("_cls", 1)], + ] def test_register_delete_rule_inherited(self): class Vaccine(Document): @@ -271,8 +257,8 @@ class TestClassMethods(unittest.TestCase): class Cat(Animal): name = StringField(required=True) - self.assertEqual(Vaccine._meta["delete_rules"][(Animal, "vaccine_made")], PULL) - self.assertEqual(Vaccine._meta["delete_rules"][(Cat, "vaccine_made")], PULL) + assert Vaccine._meta["delete_rules"][(Animal, "vaccine_made")] == PULL + assert Vaccine._meta["delete_rules"][(Cat, "vaccine_made")] == PULL def test_collection_naming(self): """Ensure that a collection with a specified name may be used. @@ -281,19 +267,17 @@ class TestClassMethods(unittest.TestCase): class DefaultNamingTest(Document): pass - self.assertEqual( - "default_naming_test", DefaultNamingTest._get_collection_name() - ) + assert "default_naming_test" == DefaultNamingTest._get_collection_name() class CustomNamingTest(Document): meta = {"collection": "pimp_my_collection"} - self.assertEqual("pimp_my_collection", CustomNamingTest._get_collection_name()) + assert "pimp_my_collection" == CustomNamingTest._get_collection_name() class DynamicNamingTest(Document): meta = {"collection": lambda c: "DYNAMO"} - self.assertEqual("DYNAMO", DynamicNamingTest._get_collection_name()) + assert "DYNAMO" == DynamicNamingTest._get_collection_name() # Use Abstract class to handle backwards compatibility class BaseDocument(Document): @@ -302,14 +286,12 @@ class TestClassMethods(unittest.TestCase): class OldNamingConvention(BaseDocument): pass - self.assertEqual( - "oldnamingconvention", OldNamingConvention._get_collection_name() - ) + assert "oldnamingconvention" == OldNamingConvention._get_collection_name() class InheritedAbstractNamingTest(BaseDocument): meta = {"collection": "wibble"} - self.assertEqual("wibble", InheritedAbstractNamingTest._get_collection_name()) + assert "wibble" == InheritedAbstractNamingTest._get_collection_name() # Mixin tests class BaseMixin(object): @@ -318,8 +300,9 @@ class TestClassMethods(unittest.TestCase): class OldMixinNamingConvention(Document, BaseMixin): pass - self.assertEqual( - "oldmixinnamingconvention", OldMixinNamingConvention._get_collection_name() + assert ( + "oldmixinnamingconvention" + == OldMixinNamingConvention._get_collection_name() ) class BaseMixin(object): @@ -331,7 +314,7 @@ class TestClassMethods(unittest.TestCase): class MyDocument(BaseDocument): pass - self.assertEqual("basedocument", MyDocument._get_collection_name()) + assert "basedocument" == MyDocument._get_collection_name() def test_custom_collection_name_operations(self): """Ensure that a collection with a specified name is used as expected. @@ -343,16 +326,16 @@ class TestClassMethods(unittest.TestCase): meta = {"collection": collection_name} Person(name="Test User").save() - self.assertIn(collection_name, list_collection_names(self.db)) + assert collection_name in list_collection_names(self.db) user_obj = self.db[collection_name].find_one() - self.assertEqual(user_obj["name"], "Test User") + assert user_obj["name"] == "Test User" user_obj = Person.objects[0] - self.assertEqual(user_obj.name, "Test User") + assert user_obj.name == "Test User" Person.drop_collection() - self.assertNotIn(collection_name, list_collection_names(self.db)) + assert collection_name not in list_collection_names(self.db) def test_collection_name_and_primary(self): """Ensure that a collection with a specified name may be used. @@ -365,7 +348,7 @@ class TestClassMethods(unittest.TestCase): Person(name="Test User").save() user_obj = Person.objects.first() - self.assertEqual(user_obj.name, "Test User") + assert user_obj.name == "Test User" Person.drop_collection() diff --git a/tests/document/test_delta.py b/tests/document/test_delta.py index 632d9b3f..2324211b 100644 --- a/tests/document/test_delta.py +++ b/tests/document/test_delta.py @@ -41,40 +41,40 @@ class TestDelta(MongoDBTestCase): doc.save() doc = Doc.objects.first() - self.assertEqual(doc._get_changed_fields(), []) - self.assertEqual(doc._delta(), ({}, {})) + assert doc._get_changed_fields() == [] + assert doc._delta() == ({}, {}) doc.string_field = "hello" - self.assertEqual(doc._get_changed_fields(), ["string_field"]) - self.assertEqual(doc._delta(), ({"string_field": "hello"}, {})) + assert doc._get_changed_fields() == ["string_field"] + assert doc._delta() == ({"string_field": "hello"}, {}) doc._changed_fields = [] doc.int_field = 1 - self.assertEqual(doc._get_changed_fields(), ["int_field"]) - self.assertEqual(doc._delta(), ({"int_field": 1}, {})) + assert doc._get_changed_fields() == ["int_field"] + assert doc._delta() == ({"int_field": 1}, {}) doc._changed_fields = [] dict_value = {"hello": "world", "ping": "pong"} doc.dict_field = dict_value - self.assertEqual(doc._get_changed_fields(), ["dict_field"]) - self.assertEqual(doc._delta(), ({"dict_field": dict_value}, {})) + assert doc._get_changed_fields() == ["dict_field"] + assert doc._delta() == ({"dict_field": dict_value}, {}) doc._changed_fields = [] list_value = ["1", 2, {"hello": "world"}] doc.list_field = list_value - self.assertEqual(doc._get_changed_fields(), ["list_field"]) - self.assertEqual(doc._delta(), ({"list_field": list_value}, {})) + assert doc._get_changed_fields() == ["list_field"] + assert doc._delta() == ({"list_field": list_value}, {}) # Test unsetting doc._changed_fields = [] doc.dict_field = {} - self.assertEqual(doc._get_changed_fields(), ["dict_field"]) - self.assertEqual(doc._delta(), ({}, {"dict_field": 1})) + assert doc._get_changed_fields() == ["dict_field"] + assert doc._delta() == ({}, {"dict_field": 1}) doc._changed_fields = [] doc.list_field = [] - self.assertEqual(doc._get_changed_fields(), ["list_field"]) - self.assertEqual(doc._delta(), ({}, {"list_field": 1})) + assert doc._get_changed_fields() == ["list_field"] + assert doc._delta() == ({}, {"list_field": 1}) def test_delta_recursive(self): self.delta_recursive(Document, EmbeddedDocument) @@ -102,8 +102,8 @@ class TestDelta(MongoDBTestCase): doc.save() doc = Doc.objects.first() - self.assertEqual(doc._get_changed_fields(), []) - self.assertEqual(doc._delta(), ({}, {})) + assert doc._get_changed_fields() == [] + assert doc._delta() == ({}, {}) embedded_1 = Embedded() embedded_1.id = "010101" @@ -113,7 +113,7 @@ class TestDelta(MongoDBTestCase): embedded_1.list_field = ["1", 2, {"hello": "world"}] doc.embedded_field = embedded_1 - self.assertEqual(doc._get_changed_fields(), ["embedded_field"]) + assert doc._get_changed_fields() == ["embedded_field"] embedded_delta = { "id": "010101", @@ -122,27 +122,27 @@ class TestDelta(MongoDBTestCase): "dict_field": {"hello": "world"}, "list_field": ["1", 2, {"hello": "world"}], } - self.assertEqual(doc.embedded_field._delta(), (embedded_delta, {})) - self.assertEqual(doc._delta(), ({"embedded_field": embedded_delta}, {})) + assert doc.embedded_field._delta() == (embedded_delta, {}) + assert doc._delta() == ({"embedded_field": embedded_delta}, {}) doc.save() doc = doc.reload(10) doc.embedded_field.dict_field = {} - self.assertEqual(doc._get_changed_fields(), ["embedded_field.dict_field"]) - self.assertEqual(doc.embedded_field._delta(), ({}, {"dict_field": 1})) - self.assertEqual(doc._delta(), ({}, {"embedded_field.dict_field": 1})) + assert doc._get_changed_fields() == ["embedded_field.dict_field"] + assert doc.embedded_field._delta() == ({}, {"dict_field": 1}) + assert doc._delta() == ({}, {"embedded_field.dict_field": 1}) doc.save() doc = doc.reload(10) - self.assertEqual(doc.embedded_field.dict_field, {}) + assert doc.embedded_field.dict_field == {} doc.embedded_field.list_field = [] - self.assertEqual(doc._get_changed_fields(), ["embedded_field.list_field"]) - self.assertEqual(doc.embedded_field._delta(), ({}, {"list_field": 1})) - self.assertEqual(doc._delta(), ({}, {"embedded_field.list_field": 1})) + assert doc._get_changed_fields() == ["embedded_field.list_field"] + assert doc.embedded_field._delta() == ({}, {"list_field": 1}) + assert doc._delta() == ({}, {"embedded_field.list_field": 1}) doc.save() doc = doc.reload(10) - self.assertEqual(doc.embedded_field.list_field, []) + assert doc.embedded_field.list_field == [] embedded_2 = Embedded() embedded_2.string_field = "hello" @@ -151,148 +151,128 @@ class TestDelta(MongoDBTestCase): embedded_2.list_field = ["1", 2, {"hello": "world"}] doc.embedded_field.list_field = ["1", 2, embedded_2] - self.assertEqual(doc._get_changed_fields(), ["embedded_field.list_field"]) + assert doc._get_changed_fields() == ["embedded_field.list_field"] - self.assertEqual( - doc.embedded_field._delta(), - ( - { - "list_field": [ - "1", - 2, - { - "_cls": "Embedded", - "string_field": "hello", - "dict_field": {"hello": "world"}, - "int_field": 1, - "list_field": ["1", 2, {"hello": "world"}], - }, - ] - }, - {}, - ), + assert doc.embedded_field._delta() == ( + { + "list_field": [ + "1", + 2, + { + "_cls": "Embedded", + "string_field": "hello", + "dict_field": {"hello": "world"}, + "int_field": 1, + "list_field": ["1", 2, {"hello": "world"}], + }, + ] + }, + {}, ) - self.assertEqual( - doc._delta(), - ( - { - "embedded_field.list_field": [ - "1", - 2, - { - "_cls": "Embedded", - "string_field": "hello", - "dict_field": {"hello": "world"}, - "int_field": 1, - "list_field": ["1", 2, {"hello": "world"}], - }, - ] - }, - {}, - ), + assert doc._delta() == ( + { + "embedded_field.list_field": [ + "1", + 2, + { + "_cls": "Embedded", + "string_field": "hello", + "dict_field": {"hello": "world"}, + "int_field": 1, + "list_field": ["1", 2, {"hello": "world"}], + }, + ] + }, + {}, ) doc.save() doc = doc.reload(10) - self.assertEqual(doc.embedded_field.list_field[0], "1") - self.assertEqual(doc.embedded_field.list_field[1], 2) + assert doc.embedded_field.list_field[0] == "1" + assert doc.embedded_field.list_field[1] == 2 for k in doc.embedded_field.list_field[2]._fields: - self.assertEqual(doc.embedded_field.list_field[2][k], embedded_2[k]) + assert doc.embedded_field.list_field[2][k] == embedded_2[k] doc.embedded_field.list_field[2].string_field = "world" - self.assertEqual( - doc._get_changed_fields(), ["embedded_field.list_field.2.string_field"] + assert doc._get_changed_fields() == ["embedded_field.list_field.2.string_field"] + assert doc.embedded_field._delta() == ( + {"list_field.2.string_field": "world"}, + {}, ) - self.assertEqual( - doc.embedded_field._delta(), ({"list_field.2.string_field": "world"}, {}) - ) - self.assertEqual( - doc._delta(), ({"embedded_field.list_field.2.string_field": "world"}, {}) + assert doc._delta() == ( + {"embedded_field.list_field.2.string_field": "world"}, + {}, ) doc.save() doc = doc.reload(10) - self.assertEqual(doc.embedded_field.list_field[2].string_field, "world") + assert doc.embedded_field.list_field[2].string_field == "world" # Test multiple assignments doc.embedded_field.list_field[2].string_field = "hello world" doc.embedded_field.list_field[2] = doc.embedded_field.list_field[2] - self.assertEqual(doc._get_changed_fields(), ["embedded_field.list_field.2"]) - self.assertEqual( - doc.embedded_field._delta(), - ( - { - "list_field.2": { - "_cls": "Embedded", - "string_field": "hello world", - "int_field": 1, - "list_field": ["1", 2, {"hello": "world"}], - "dict_field": {"hello": "world"}, - } - }, - {}, - ), + assert doc._get_changed_fields() == ["embedded_field.list_field.2"] + assert doc.embedded_field._delta() == ( + { + "list_field.2": { + "_cls": "Embedded", + "string_field": "hello world", + "int_field": 1, + "list_field": ["1", 2, {"hello": "world"}], + "dict_field": {"hello": "world"}, + } + }, + {}, ) - self.assertEqual( - doc._delta(), - ( - { - "embedded_field.list_field.2": { - "_cls": "Embedded", - "string_field": "hello world", - "int_field": 1, - "list_field": ["1", 2, {"hello": "world"}], - "dict_field": {"hello": "world"}, - } - }, - {}, - ), + assert doc._delta() == ( + { + "embedded_field.list_field.2": { + "_cls": "Embedded", + "string_field": "hello world", + "int_field": 1, + "list_field": ["1", 2, {"hello": "world"}], + "dict_field": {"hello": "world"}, + } + }, + {}, ) doc.save() doc = doc.reload(10) - self.assertEqual(doc.embedded_field.list_field[2].string_field, "hello world") + assert doc.embedded_field.list_field[2].string_field == "hello world" # Test list native methods doc.embedded_field.list_field[2].list_field.pop(0) - self.assertEqual( - doc._delta(), - ({"embedded_field.list_field.2.list_field": [2, {"hello": "world"}]}, {}), + assert doc._delta() == ( + {"embedded_field.list_field.2.list_field": [2, {"hello": "world"}]}, + {}, ) doc.save() doc = doc.reload(10) doc.embedded_field.list_field[2].list_field.append(1) - self.assertEqual( - doc._delta(), - ( - {"embedded_field.list_field.2.list_field": [2, {"hello": "world"}, 1]}, - {}, - ), + assert doc._delta() == ( + {"embedded_field.list_field.2.list_field": [2, {"hello": "world"}, 1]}, + {}, ) doc.save() doc = doc.reload(10) - self.assertEqual( - doc.embedded_field.list_field[2].list_field, [2, {"hello": "world"}, 1] - ) + assert doc.embedded_field.list_field[2].list_field == [2, {"hello": "world"}, 1] doc.embedded_field.list_field[2].list_field.sort(key=str) doc.save() doc = doc.reload(10) - self.assertEqual( - doc.embedded_field.list_field[2].list_field, [1, 2, {"hello": "world"}] - ) + assert doc.embedded_field.list_field[2].list_field == [1, 2, {"hello": "world"}] del doc.embedded_field.list_field[2].list_field[2]["hello"] - self.assertEqual( - doc._delta(), ({}, {"embedded_field.list_field.2.list_field.2.hello": 1}) + assert doc._delta() == ( + {}, + {"embedded_field.list_field.2.list_field.2.hello": 1}, ) doc.save() doc = doc.reload(10) del doc.embedded_field.list_field[2].list_field - self.assertEqual( - doc._delta(), ({}, {"embedded_field.list_field.2.list_field": 1}) - ) + assert doc._delta() == ({}, {"embedded_field.list_field.2.list_field": 1}) doc.save() doc = doc.reload(10) @@ -302,12 +282,8 @@ class TestDelta(MongoDBTestCase): doc = doc.reload(10) doc.dict_field["Embedded"].string_field = "Hello World" - self.assertEqual( - doc._get_changed_fields(), ["dict_field.Embedded.string_field"] - ) - self.assertEqual( - doc._delta(), ({"dict_field.Embedded.string_field": "Hello World"}, {}) - ) + assert doc._get_changed_fields() == ["dict_field.Embedded.string_field"] + assert doc._delta() == ({"dict_field.Embedded.string_field": "Hello World"}, {}) def test_circular_reference_deltas(self): self.circular_reference_deltas(Document, Document) @@ -338,8 +314,8 @@ class TestDelta(MongoDBTestCase): p = Person.objects[0].select_related() o = Organization.objects.first() - self.assertEqual(p.owns[0], o) - self.assertEqual(o.owner, p) + assert p.owns[0] == o + assert o.owner == p def test_circular_reference_deltas_2(self): self.circular_reference_deltas_2(Document, Document) @@ -379,9 +355,9 @@ class TestDelta(MongoDBTestCase): e = Person.objects.get(name="employee") o = Organization.objects.first() - self.assertEqual(p.owns[0], o) - self.assertEqual(o.owner, p) - self.assertEqual(e.employer, o) + assert p.owns[0] == o + assert o.owner == p + assert e.employer == o return person, organization, employee @@ -401,40 +377,40 @@ class TestDelta(MongoDBTestCase): doc.save() doc = Doc.objects.first() - self.assertEqual(doc._get_changed_fields(), []) - self.assertEqual(doc._delta(), ({}, {})) + assert doc._get_changed_fields() == [] + assert doc._delta() == ({}, {}) doc.string_field = "hello" - self.assertEqual(doc._get_changed_fields(), ["db_string_field"]) - self.assertEqual(doc._delta(), ({"db_string_field": "hello"}, {})) + assert doc._get_changed_fields() == ["db_string_field"] + assert doc._delta() == ({"db_string_field": "hello"}, {}) doc._changed_fields = [] doc.int_field = 1 - self.assertEqual(doc._get_changed_fields(), ["db_int_field"]) - self.assertEqual(doc._delta(), ({"db_int_field": 1}, {})) + assert doc._get_changed_fields() == ["db_int_field"] + assert doc._delta() == ({"db_int_field": 1}, {}) doc._changed_fields = [] dict_value = {"hello": "world", "ping": "pong"} doc.dict_field = dict_value - self.assertEqual(doc._get_changed_fields(), ["db_dict_field"]) - self.assertEqual(doc._delta(), ({"db_dict_field": dict_value}, {})) + assert doc._get_changed_fields() == ["db_dict_field"] + assert doc._delta() == ({"db_dict_field": dict_value}, {}) doc._changed_fields = [] list_value = ["1", 2, {"hello": "world"}] doc.list_field = list_value - self.assertEqual(doc._get_changed_fields(), ["db_list_field"]) - self.assertEqual(doc._delta(), ({"db_list_field": list_value}, {})) + assert doc._get_changed_fields() == ["db_list_field"] + assert doc._delta() == ({"db_list_field": list_value}, {}) # Test unsetting doc._changed_fields = [] doc.dict_field = {} - self.assertEqual(doc._get_changed_fields(), ["db_dict_field"]) - self.assertEqual(doc._delta(), ({}, {"db_dict_field": 1})) + assert doc._get_changed_fields() == ["db_dict_field"] + assert doc._delta() == ({}, {"db_dict_field": 1}) doc._changed_fields = [] doc.list_field = [] - self.assertEqual(doc._get_changed_fields(), ["db_list_field"]) - self.assertEqual(doc._delta(), ({}, {"db_list_field": 1})) + assert doc._get_changed_fields() == ["db_list_field"] + assert doc._delta() == ({}, {"db_list_field": 1}) # Test it saves that data doc = Doc() @@ -447,10 +423,10 @@ class TestDelta(MongoDBTestCase): doc.save() doc = doc.reload(10) - self.assertEqual(doc.string_field, "hello") - self.assertEqual(doc.int_field, 1) - self.assertEqual(doc.dict_field, {"hello": "world"}) - self.assertEqual(doc.list_field, ["1", 2, {"hello": "world"}]) + assert doc.string_field == "hello" + assert doc.int_field == 1 + assert doc.dict_field == {"hello": "world"} + assert doc.list_field == ["1", 2, {"hello": "world"}] def test_delta_recursive_db_field(self): self.delta_recursive_db_field(Document, EmbeddedDocument) @@ -479,8 +455,8 @@ class TestDelta(MongoDBTestCase): doc.save() doc = Doc.objects.first() - self.assertEqual(doc._get_changed_fields(), []) - self.assertEqual(doc._delta(), ({}, {})) + assert doc._get_changed_fields() == [] + assert doc._delta() == ({}, {}) embedded_1 = Embedded() embedded_1.string_field = "hello" @@ -489,7 +465,7 @@ class TestDelta(MongoDBTestCase): embedded_1.list_field = ["1", 2, {"hello": "world"}] doc.embedded_field = embedded_1 - self.assertEqual(doc._get_changed_fields(), ["db_embedded_field"]) + assert doc._get_changed_fields() == ["db_embedded_field"] embedded_delta = { "db_string_field": "hello", @@ -497,27 +473,27 @@ class TestDelta(MongoDBTestCase): "db_dict_field": {"hello": "world"}, "db_list_field": ["1", 2, {"hello": "world"}], } - self.assertEqual(doc.embedded_field._delta(), (embedded_delta, {})) - self.assertEqual(doc._delta(), ({"db_embedded_field": embedded_delta}, {})) + assert doc.embedded_field._delta() == (embedded_delta, {}) + assert doc._delta() == ({"db_embedded_field": embedded_delta}, {}) doc.save() doc = doc.reload(10) doc.embedded_field.dict_field = {} - self.assertEqual(doc._get_changed_fields(), ["db_embedded_field.db_dict_field"]) - self.assertEqual(doc.embedded_field._delta(), ({}, {"db_dict_field": 1})) - self.assertEqual(doc._delta(), ({}, {"db_embedded_field.db_dict_field": 1})) + assert doc._get_changed_fields() == ["db_embedded_field.db_dict_field"] + assert doc.embedded_field._delta() == ({}, {"db_dict_field": 1}) + assert doc._delta() == ({}, {"db_embedded_field.db_dict_field": 1}) doc.save() doc = doc.reload(10) - self.assertEqual(doc.embedded_field.dict_field, {}) + assert doc.embedded_field.dict_field == {} doc.embedded_field.list_field = [] - self.assertEqual(doc._get_changed_fields(), ["db_embedded_field.db_list_field"]) - self.assertEqual(doc.embedded_field._delta(), ({}, {"db_list_field": 1})) - self.assertEqual(doc._delta(), ({}, {"db_embedded_field.db_list_field": 1})) + assert doc._get_changed_fields() == ["db_embedded_field.db_list_field"] + assert doc.embedded_field._delta() == ({}, {"db_list_field": 1}) + assert doc._delta() == ({}, {"db_embedded_field.db_list_field": 1}) doc.save() doc = doc.reload(10) - self.assertEqual(doc.embedded_field.list_field, []) + assert doc.embedded_field.list_field == [] embedded_2 = Embedded() embedded_2.string_field = "hello" @@ -526,166 +502,142 @@ class TestDelta(MongoDBTestCase): embedded_2.list_field = ["1", 2, {"hello": "world"}] doc.embedded_field.list_field = ["1", 2, embedded_2] - self.assertEqual(doc._get_changed_fields(), ["db_embedded_field.db_list_field"]) - self.assertEqual( - doc.embedded_field._delta(), - ( - { - "db_list_field": [ - "1", - 2, - { - "_cls": "Embedded", - "db_string_field": "hello", - "db_dict_field": {"hello": "world"}, - "db_int_field": 1, - "db_list_field": ["1", 2, {"hello": "world"}], - }, - ] - }, - {}, - ), + assert doc._get_changed_fields() == ["db_embedded_field.db_list_field"] + assert doc.embedded_field._delta() == ( + { + "db_list_field": [ + "1", + 2, + { + "_cls": "Embedded", + "db_string_field": "hello", + "db_dict_field": {"hello": "world"}, + "db_int_field": 1, + "db_list_field": ["1", 2, {"hello": "world"}], + }, + ] + }, + {}, ) - self.assertEqual( - doc._delta(), - ( - { - "db_embedded_field.db_list_field": [ - "1", - 2, - { - "_cls": "Embedded", - "db_string_field": "hello", - "db_dict_field": {"hello": "world"}, - "db_int_field": 1, - "db_list_field": ["1", 2, {"hello": "world"}], - }, - ] - }, - {}, - ), + assert doc._delta() == ( + { + "db_embedded_field.db_list_field": [ + "1", + 2, + { + "_cls": "Embedded", + "db_string_field": "hello", + "db_dict_field": {"hello": "world"}, + "db_int_field": 1, + "db_list_field": ["1", 2, {"hello": "world"}], + }, + ] + }, + {}, ) doc.save() doc = doc.reload(10) - self.assertEqual(doc.embedded_field.list_field[0], "1") - self.assertEqual(doc.embedded_field.list_field[1], 2) + assert doc.embedded_field.list_field[0] == "1" + assert doc.embedded_field.list_field[1] == 2 for k in doc.embedded_field.list_field[2]._fields: - self.assertEqual(doc.embedded_field.list_field[2][k], embedded_2[k]) + assert doc.embedded_field.list_field[2][k] == embedded_2[k] doc.embedded_field.list_field[2].string_field = "world" - self.assertEqual( - doc._get_changed_fields(), - ["db_embedded_field.db_list_field.2.db_string_field"], + assert doc._get_changed_fields() == [ + "db_embedded_field.db_list_field.2.db_string_field" + ] + assert doc.embedded_field._delta() == ( + {"db_list_field.2.db_string_field": "world"}, + {}, ) - self.assertEqual( - doc.embedded_field._delta(), - ({"db_list_field.2.db_string_field": "world"}, {}), - ) - self.assertEqual( - doc._delta(), - ({"db_embedded_field.db_list_field.2.db_string_field": "world"}, {}), + assert doc._delta() == ( + {"db_embedded_field.db_list_field.2.db_string_field": "world"}, + {}, ) doc.save() doc = doc.reload(10) - self.assertEqual(doc.embedded_field.list_field[2].string_field, "world") + assert doc.embedded_field.list_field[2].string_field == "world" # Test multiple assignments doc.embedded_field.list_field[2].string_field = "hello world" doc.embedded_field.list_field[2] = doc.embedded_field.list_field[2] - self.assertEqual( - doc._get_changed_fields(), ["db_embedded_field.db_list_field.2"] + assert doc._get_changed_fields() == ["db_embedded_field.db_list_field.2"] + assert doc.embedded_field._delta() == ( + { + "db_list_field.2": { + "_cls": "Embedded", + "db_string_field": "hello world", + "db_int_field": 1, + "db_list_field": ["1", 2, {"hello": "world"}], + "db_dict_field": {"hello": "world"}, + } + }, + {}, ) - self.assertEqual( - doc.embedded_field._delta(), - ( - { - "db_list_field.2": { - "_cls": "Embedded", - "db_string_field": "hello world", - "db_int_field": 1, - "db_list_field": ["1", 2, {"hello": "world"}], - "db_dict_field": {"hello": "world"}, - } - }, - {}, - ), - ) - self.assertEqual( - doc._delta(), - ( - { - "db_embedded_field.db_list_field.2": { - "_cls": "Embedded", - "db_string_field": "hello world", - "db_int_field": 1, - "db_list_field": ["1", 2, {"hello": "world"}], - "db_dict_field": {"hello": "world"}, - } - }, - {}, - ), + assert doc._delta() == ( + { + "db_embedded_field.db_list_field.2": { + "_cls": "Embedded", + "db_string_field": "hello world", + "db_int_field": 1, + "db_list_field": ["1", 2, {"hello": "world"}], + "db_dict_field": {"hello": "world"}, + } + }, + {}, ) doc.save() doc = doc.reload(10) - self.assertEqual(doc.embedded_field.list_field[2].string_field, "hello world") + assert doc.embedded_field.list_field[2].string_field == "hello world" # Test list native methods doc.embedded_field.list_field[2].list_field.pop(0) - self.assertEqual( - doc._delta(), - ( - { - "db_embedded_field.db_list_field.2.db_list_field": [ - 2, - {"hello": "world"}, - ] - }, - {}, - ), + assert doc._delta() == ( + { + "db_embedded_field.db_list_field.2.db_list_field": [ + 2, + {"hello": "world"}, + ] + }, + {}, ) doc.save() doc = doc.reload(10) doc.embedded_field.list_field[2].list_field.append(1) - self.assertEqual( - doc._delta(), - ( - { - "db_embedded_field.db_list_field.2.db_list_field": [ - 2, - {"hello": "world"}, - 1, - ] - }, - {}, - ), + assert doc._delta() == ( + { + "db_embedded_field.db_list_field.2.db_list_field": [ + 2, + {"hello": "world"}, + 1, + ] + }, + {}, ) doc.save() doc = doc.reload(10) - self.assertEqual( - doc.embedded_field.list_field[2].list_field, [2, {"hello": "world"}, 1] - ) + assert doc.embedded_field.list_field[2].list_field == [2, {"hello": "world"}, 1] doc.embedded_field.list_field[2].list_field.sort(key=str) doc.save() doc = doc.reload(10) - self.assertEqual( - doc.embedded_field.list_field[2].list_field, [1, 2, {"hello": "world"}] - ) + assert doc.embedded_field.list_field[2].list_field == [1, 2, {"hello": "world"}] del doc.embedded_field.list_field[2].list_field[2]["hello"] - self.assertEqual( - doc._delta(), - ({}, {"db_embedded_field.db_list_field.2.db_list_field.2.hello": 1}), + assert doc._delta() == ( + {}, + {"db_embedded_field.db_list_field.2.db_list_field.2.hello": 1}, ) doc.save() doc = doc.reload(10) del doc.embedded_field.list_field[2].list_field - self.assertEqual( - doc._delta(), ({}, {"db_embedded_field.db_list_field.2.db_list_field": 1}) + assert doc._delta() == ( + {}, + {"db_embedded_field.db_list_field.2.db_list_field": 1}, ) def test_delta_for_dynamic_documents(self): @@ -696,14 +648,16 @@ class TestDelta(MongoDBTestCase): Person.drop_collection() p = Person(name="James", age=34) - self.assertEqual( - p._delta(), (SON([("_cls", "Person"), ("name", "James"), ("age", 34)]), {}) + assert p._delta() == ( + SON([("_cls", "Person"), ("name", "James"), ("age", 34)]), + {}, ) p.doc = 123 del p.doc - self.assertEqual( - p._delta(), (SON([("_cls", "Person"), ("name", "James"), ("age", 34)]), {}) + assert p._delta() == ( + SON([("_cls", "Person"), ("name", "James"), ("age", 34)]), + {}, ) p = Person() @@ -712,18 +666,18 @@ class TestDelta(MongoDBTestCase): p.save() p.age = 24 - self.assertEqual(p.age, 24) - self.assertEqual(p._get_changed_fields(), ["age"]) - self.assertEqual(p._delta(), ({"age": 24}, {})) + assert p.age == 24 + assert p._get_changed_fields() == ["age"] + assert p._delta() == ({"age": 24}, {}) p = Person.objects(age=22).get() p.age = 24 - self.assertEqual(p.age, 24) - self.assertEqual(p._get_changed_fields(), ["age"]) - self.assertEqual(p._delta(), ({"age": 24}, {})) + assert p.age == 24 + assert p._get_changed_fields() == ["age"] + assert p._delta() == ({"age": 24}, {}) p.save() - self.assertEqual(1, Person.objects(age=24).count()) + assert 1 == Person.objects(age=24).count() def test_dynamic_delta(self): class Doc(DynamicDocument): @@ -734,40 +688,40 @@ class TestDelta(MongoDBTestCase): doc.save() doc = Doc.objects.first() - self.assertEqual(doc._get_changed_fields(), []) - self.assertEqual(doc._delta(), ({}, {})) + assert doc._get_changed_fields() == [] + assert doc._delta() == ({}, {}) doc.string_field = "hello" - self.assertEqual(doc._get_changed_fields(), ["string_field"]) - self.assertEqual(doc._delta(), ({"string_field": "hello"}, {})) + assert doc._get_changed_fields() == ["string_field"] + assert doc._delta() == ({"string_field": "hello"}, {}) doc._changed_fields = [] doc.int_field = 1 - self.assertEqual(doc._get_changed_fields(), ["int_field"]) - self.assertEqual(doc._delta(), ({"int_field": 1}, {})) + assert doc._get_changed_fields() == ["int_field"] + assert doc._delta() == ({"int_field": 1}, {}) doc._changed_fields = [] dict_value = {"hello": "world", "ping": "pong"} doc.dict_field = dict_value - self.assertEqual(doc._get_changed_fields(), ["dict_field"]) - self.assertEqual(doc._delta(), ({"dict_field": dict_value}, {})) + assert doc._get_changed_fields() == ["dict_field"] + assert doc._delta() == ({"dict_field": dict_value}, {}) doc._changed_fields = [] list_value = ["1", 2, {"hello": "world"}] doc.list_field = list_value - self.assertEqual(doc._get_changed_fields(), ["list_field"]) - self.assertEqual(doc._delta(), ({"list_field": list_value}, {})) + assert doc._get_changed_fields() == ["list_field"] + assert doc._delta() == ({"list_field": list_value}, {}) # Test unsetting doc._changed_fields = [] doc.dict_field = {} - self.assertEqual(doc._get_changed_fields(), ["dict_field"]) - self.assertEqual(doc._delta(), ({}, {"dict_field": 1})) + assert doc._get_changed_fields() == ["dict_field"] + assert doc._delta() == ({}, {"dict_field": 1}) doc._changed_fields = [] doc.list_field = [] - self.assertEqual(doc._get_changed_fields(), ["list_field"]) - self.assertEqual(doc._delta(), ({}, {"list_field": 1})) + assert doc._get_changed_fields() == ["list_field"] + assert doc._delta() == ({}, {"list_field": 1}) def test_delta_with_dbref_true(self): person, organization, employee = self.circular_reference_deltas_2( @@ -775,16 +729,16 @@ class TestDelta(MongoDBTestCase): ) employee.name = "test" - self.assertEqual(organization._get_changed_fields(), []) + assert organization._get_changed_fields() == [] updates, removals = organization._delta() - self.assertEqual({}, removals) - self.assertEqual({}, updates) + assert {} == removals + assert {} == updates organization.employees.append(person) updates, removals = organization._delta() - self.assertEqual({}, removals) - self.assertIn("employees", updates) + assert {} == removals + assert "employees" in updates def test_delta_with_dbref_false(self): person, organization, employee = self.circular_reference_deltas_2( @@ -792,16 +746,16 @@ class TestDelta(MongoDBTestCase): ) employee.name = "test" - self.assertEqual(organization._get_changed_fields(), []) + assert organization._get_changed_fields() == [] updates, removals = organization._delta() - self.assertEqual({}, removals) - self.assertEqual({}, updates) + assert {} == removals + assert {} == updates organization.employees.append(person) updates, removals = organization._delta() - self.assertEqual({}, removals) - self.assertIn("employees", updates) + assert {} == removals + assert "employees" in updates def test_nested_nested_fields_mark_as_changed(self): class EmbeddedDoc(EmbeddedDocument): @@ -821,11 +775,11 @@ class TestDelta(MongoDBTestCase): subdoc = mydoc.subs["a"]["b"] subdoc.name = "bar" - self.assertEqual(["name"], subdoc._get_changed_fields()) - self.assertEqual(["subs.a.b.name"], mydoc._get_changed_fields()) + assert ["name"] == subdoc._get_changed_fields() + assert ["subs.a.b.name"] == mydoc._get_changed_fields() mydoc._clear_changed_fields() - self.assertEqual([], mydoc._get_changed_fields()) + assert [] == mydoc._get_changed_fields() def test_lower_level_mark_as_changed(self): class EmbeddedDoc(EmbeddedDocument): @@ -840,17 +794,17 @@ class TestDelta(MongoDBTestCase): mydoc = MyDoc.objects.first() mydoc.subs["a"] = EmbeddedDoc() - self.assertEqual(["subs.a"], mydoc._get_changed_fields()) + assert ["subs.a"] == mydoc._get_changed_fields() subdoc = mydoc.subs["a"] subdoc.name = "bar" - self.assertEqual(["name"], subdoc._get_changed_fields()) - self.assertEqual(["subs.a"], mydoc._get_changed_fields()) + assert ["name"] == subdoc._get_changed_fields() + assert ["subs.a"] == mydoc._get_changed_fields() mydoc.save() mydoc._clear_changed_fields() - self.assertEqual([], mydoc._get_changed_fields()) + assert [] == mydoc._get_changed_fields() def test_upper_level_mark_as_changed(self): class EmbeddedDoc(EmbeddedDocument): @@ -867,15 +821,15 @@ class TestDelta(MongoDBTestCase): subdoc = mydoc.subs["a"] subdoc.name = "bar" - self.assertEqual(["name"], subdoc._get_changed_fields()) - self.assertEqual(["subs.a.name"], mydoc._get_changed_fields()) + assert ["name"] == subdoc._get_changed_fields() + assert ["subs.a.name"] == mydoc._get_changed_fields() mydoc.subs["a"] = EmbeddedDoc() - self.assertEqual(["subs.a"], mydoc._get_changed_fields()) + assert ["subs.a"] == mydoc._get_changed_fields() mydoc.save() mydoc._clear_changed_fields() - self.assertEqual([], mydoc._get_changed_fields()) + assert [] == mydoc._get_changed_fields() def test_referenced_object_changed_attributes(self): """Ensures that when you save a new reference to a field, the referenced object isn't altered""" @@ -902,22 +856,22 @@ class TestDelta(MongoDBTestCase): org1.reload() org2.reload() user.reload() - self.assertEqual(org1.name, "Org 1") - self.assertEqual(org2.name, "Org 2") - self.assertEqual(user.name, "Fred") + assert org1.name == "Org 1" + assert org2.name == "Org 2" + assert user.name == "Fred" user.name = "Harold" user.org = org2 org2.name = "New Org 2" - self.assertEqual(org2.name, "New Org 2") + assert org2.name == "New Org 2" user.save() org2.save() - self.assertEqual(org2.name, "New Org 2") + assert org2.name == "New Org 2" org2.reload() - self.assertEqual(org2.name, "New Org 2") + assert org2.name == "New Org 2" def test_delta_for_nested_map_fields(self): class UInfoDocument(Document): @@ -950,12 +904,12 @@ class TestDelta(MongoDBTestCase): d.users["007"]["rolist"].append(EmbeddedRole(type="oops")) d.users["007"]["info"] = uinfo delta = d._delta() - self.assertEqual(True, "users.007.roles.666" in delta[0]) - self.assertEqual(True, "users.007.rolist" in delta[0]) - self.assertEqual(True, "users.007.info" in delta[0]) - self.assertEqual("superadmin", delta[0]["users.007.roles.666"]["type"]) - self.assertEqual("oops", delta[0]["users.007.rolist"][0]["type"]) - self.assertEqual(uinfo.id, delta[0]["users.007.info"]) + assert True == ("users.007.roles.666" in delta[0]) + assert True == ("users.007.rolist" in delta[0]) + assert True == ("users.007.info" in delta[0]) + assert "superadmin" == delta[0]["users.007.roles.666"]["type"] + assert "oops" == delta[0]["users.007.rolist"][0]["type"] + assert uinfo.id == delta[0]["users.007.info"] if __name__ == "__main__": diff --git a/tests/document/test_dynamic.py b/tests/document/test_dynamic.py index 6b517d24..a6f46862 100644 --- a/tests/document/test_dynamic.py +++ b/tests/document/test_dynamic.py @@ -2,6 +2,7 @@ import unittest from mongoengine import * from tests.utils import MongoDBTestCase +import pytest __all__ = ("TestDynamicDocument",) @@ -25,15 +26,15 @@ class TestDynamicDocument(MongoDBTestCase): p.name = "James" p.age = 34 - self.assertEqual(p.to_mongo(), {"_cls": "Person", "name": "James", "age": 34}) - self.assertEqual(p.to_mongo().keys(), ["_cls", "name", "age"]) + assert p.to_mongo() == {"_cls": "Person", "name": "James", "age": 34} + assert p.to_mongo().keys() == ["_cls", "name", "age"] p.save() - self.assertEqual(p.to_mongo().keys(), ["_id", "_cls", "name", "age"]) + assert p.to_mongo().keys() == ["_id", "_cls", "name", "age"] - self.assertEqual(self.Person.objects.first().age, 34) + assert self.Person.objects.first().age == 34 # Confirm no changes to self.Person - self.assertFalse(hasattr(self.Person, "age")) + assert not hasattr(self.Person, "age") def test_change_scope_of_variable(self): """Test changing the scope of a dynamic field has no adverse effects""" @@ -47,7 +48,7 @@ class TestDynamicDocument(MongoDBTestCase): p.save() p = self.Person.objects.get() - self.assertEqual(p.misc, {"hello": "world"}) + assert p.misc == {"hello": "world"} def test_delete_dynamic_field(self): """Test deleting a dynamic field works""" @@ -62,19 +63,19 @@ class TestDynamicDocument(MongoDBTestCase): p.save() p = self.Person.objects.get() - self.assertEqual(p.misc, {"hello": "world"}) + assert p.misc == {"hello": "world"} collection = self.db[self.Person._get_collection_name()] obj = collection.find_one() - self.assertEqual(sorted(obj.keys()), ["_cls", "_id", "misc", "name"]) + assert sorted(obj.keys()) == ["_cls", "_id", "misc", "name"] del p.misc p.save() p = self.Person.objects.get() - self.assertFalse(hasattr(p, "misc")) + assert not hasattr(p, "misc") obj = collection.find_one() - self.assertEqual(sorted(obj.keys()), ["_cls", "_id", "name"]) + assert sorted(obj.keys()) == ["_cls", "_id", "name"] def test_reload_after_unsetting(self): p = self.Person() @@ -88,12 +89,12 @@ class TestDynamicDocument(MongoDBTestCase): p = self.Person.objects.create() p.update(age=1) - self.assertEqual(len(p._data), 3) - self.assertEqual(sorted(p._data.keys()), ["_cls", "id", "name"]) + assert len(p._data) == 3 + assert sorted(p._data.keys()) == ["_cls", "id", "name"] p.reload() - self.assertEqual(len(p._data), 4) - self.assertEqual(sorted(p._data.keys()), ["_cls", "age", "id", "name"]) + assert len(p._data) == 4 + assert sorted(p._data.keys()) == ["_cls", "age", "id", "name"] def test_fields_without_underscore(self): """Ensure we can query dynamic fields""" @@ -103,16 +104,18 @@ class TestDynamicDocument(MongoDBTestCase): p.save() raw_p = Person.objects.as_pymongo().get(id=p.id) - self.assertEqual(raw_p, {"_cls": u"Person", "_id": p.id, "name": u"Dean"}) + assert raw_p == {"_cls": u"Person", "_id": p.id, "name": u"Dean"} p.name = "OldDean" p.newattr = "garbage" p.save() raw_p = Person.objects.as_pymongo().get(id=p.id) - self.assertEqual( - raw_p, - {"_cls": u"Person", "_id": p.id, "name": "OldDean", "newattr": u"garbage"}, - ) + assert raw_p == { + "_cls": u"Person", + "_id": p.id, + "name": "OldDean", + "newattr": u"garbage", + } def test_fields_containing_underscore(self): """Ensure we can query dynamic fields""" @@ -127,14 +130,14 @@ class TestDynamicDocument(MongoDBTestCase): p.save() raw_p = WeirdPerson.objects.as_pymongo().get(id=p.id) - self.assertEqual(raw_p, {"_id": p.id, "_name": u"Dean", "name": u"Dean"}) + assert raw_p == {"_id": p.id, "_name": u"Dean", "name": u"Dean"} p.name = "OldDean" p._name = "NewDean" p._newattr1 = "garbage" # Unknown fields won't be added p.save() raw_p = WeirdPerson.objects.as_pymongo().get(id=p.id) - self.assertEqual(raw_p, {"_id": p.id, "_name": u"NewDean", "name": u"OldDean"}) + assert raw_p == {"_id": p.id, "_name": u"NewDean", "name": u"OldDean"} def test_dynamic_document_queries(self): """Ensure we can query dynamic fields""" @@ -143,10 +146,10 @@ class TestDynamicDocument(MongoDBTestCase): p.age = 22 p.save() - self.assertEqual(1, self.Person.objects(age=22).count()) + assert 1 == self.Person.objects(age=22).count() p = self.Person.objects(age=22) p = p.get() - self.assertEqual(22, p.age) + assert 22 == p.age def test_complex_dynamic_document_queries(self): class Person(DynamicDocument): @@ -166,8 +169,8 @@ class TestDynamicDocument(MongoDBTestCase): p2.age = 10 p2.save() - self.assertEqual(Person.objects(age__icontains="ten").count(), 2) - self.assertEqual(Person.objects(age__gte=10).count(), 1) + assert Person.objects(age__icontains="ten").count() == 2 + assert Person.objects(age__gte=10).count() == 1 def test_complex_data_lookups(self): """Ensure you can query dynamic document dynamic fields""" @@ -175,12 +178,12 @@ class TestDynamicDocument(MongoDBTestCase): p.misc = {"hello": "world"} p.save() - self.assertEqual(1, self.Person.objects(misc__hello="world").count()) + assert 1 == self.Person.objects(misc__hello="world").count() def test_three_level_complex_data_lookups(self): """Ensure you can query three level document dynamic fields""" self.Person.objects.create(misc={"hello": {"hello2": "world"}}) - self.assertEqual(1, self.Person.objects(misc__hello__hello2="world").count()) + assert 1 == self.Person.objects(misc__hello__hello2="world").count() def test_complex_embedded_document_validation(self): """Ensure embedded dynamic documents may be validated""" @@ -198,11 +201,13 @@ class TestDynamicDocument(MongoDBTestCase): embedded_doc_1.validate() embedded_doc_2 = Embedded(content="this is not a url") - self.assertRaises(ValidationError, embedded_doc_2.validate) + with pytest.raises(ValidationError): + embedded_doc_2.validate() doc.embedded_field_1 = embedded_doc_1 doc.embedded_field_2 = embedded_doc_2 - self.assertRaises(ValidationError, doc.validate) + with pytest.raises(ValidationError): + doc.validate() def test_inheritance(self): """Ensure that dynamic document plays nice with inheritance""" @@ -212,11 +217,9 @@ class TestDynamicDocument(MongoDBTestCase): Employee.drop_collection() - self.assertIn("name", Employee._fields) - self.assertIn("salary", Employee._fields) - self.assertEqual( - Employee._get_collection_name(), self.Person._get_collection_name() - ) + assert "name" in Employee._fields + assert "salary" in Employee._fields + assert Employee._get_collection_name() == self.Person._get_collection_name() joe_bloggs = Employee() joe_bloggs.name = "Joe Bloggs" @@ -224,11 +227,11 @@ class TestDynamicDocument(MongoDBTestCase): joe_bloggs.age = 20 joe_bloggs.save() - self.assertEqual(1, self.Person.objects(age=20).count()) - self.assertEqual(1, Employee.objects(age=20).count()) + assert 1 == self.Person.objects(age=20).count() + assert 1 == Employee.objects(age=20).count() joe_bloggs = self.Person.objects.first() - self.assertIsInstance(joe_bloggs, Employee) + assert isinstance(joe_bloggs, Employee) def test_embedded_dynamic_document(self): """Test dynamic embedded documents""" @@ -249,26 +252,23 @@ class TestDynamicDocument(MongoDBTestCase): embedded_1.list_field = ["1", 2, {"hello": "world"}] doc.embedded_field = embedded_1 - self.assertEqual( - doc.to_mongo(), - { - "embedded_field": { - "_cls": "Embedded", - "string_field": "hello", - "int_field": 1, - "dict_field": {"hello": "world"}, - "list_field": ["1", 2, {"hello": "world"}], - } - }, - ) + assert doc.to_mongo() == { + "embedded_field": { + "_cls": "Embedded", + "string_field": "hello", + "int_field": 1, + "dict_field": {"hello": "world"}, + "list_field": ["1", 2, {"hello": "world"}], + } + } doc.save() doc = Doc.objects.first() - self.assertEqual(doc.embedded_field.__class__, Embedded) - self.assertEqual(doc.embedded_field.string_field, "hello") - self.assertEqual(doc.embedded_field.int_field, 1) - self.assertEqual(doc.embedded_field.dict_field, {"hello": "world"}) - self.assertEqual(doc.embedded_field.list_field, ["1", 2, {"hello": "world"}]) + assert doc.embedded_field.__class__ == Embedded + assert doc.embedded_field.string_field == "hello" + assert doc.embedded_field.int_field == 1 + assert doc.embedded_field.dict_field == {"hello": "world"} + assert doc.embedded_field.list_field == ["1", 2, {"hello": "world"}] def test_complex_embedded_documents(self): """Test complex dynamic embedded documents setups""" @@ -296,44 +296,41 @@ class TestDynamicDocument(MongoDBTestCase): embedded_1.list_field = ["1", 2, embedded_2] doc.embedded_field = embedded_1 - self.assertEqual( - doc.to_mongo(), - { - "embedded_field": { - "_cls": "Embedded", - "string_field": "hello", - "int_field": 1, - "dict_field": {"hello": "world"}, - "list_field": [ - "1", - 2, - { - "_cls": "Embedded", - "string_field": "hello", - "int_field": 1, - "dict_field": {"hello": "world"}, - "list_field": ["1", 2, {"hello": "world"}], - }, - ], - } - }, - ) + assert doc.to_mongo() == { + "embedded_field": { + "_cls": "Embedded", + "string_field": "hello", + "int_field": 1, + "dict_field": {"hello": "world"}, + "list_field": [ + "1", + 2, + { + "_cls": "Embedded", + "string_field": "hello", + "int_field": 1, + "dict_field": {"hello": "world"}, + "list_field": ["1", 2, {"hello": "world"}], + }, + ], + } + } doc.save() doc = Doc.objects.first() - self.assertEqual(doc.embedded_field.__class__, Embedded) - self.assertEqual(doc.embedded_field.string_field, "hello") - self.assertEqual(doc.embedded_field.int_field, 1) - self.assertEqual(doc.embedded_field.dict_field, {"hello": "world"}) - self.assertEqual(doc.embedded_field.list_field[0], "1") - self.assertEqual(doc.embedded_field.list_field[1], 2) + assert doc.embedded_field.__class__ == Embedded + assert doc.embedded_field.string_field == "hello" + assert doc.embedded_field.int_field == 1 + assert doc.embedded_field.dict_field == {"hello": "world"} + assert doc.embedded_field.list_field[0] == "1" + assert doc.embedded_field.list_field[1] == 2 embedded_field = doc.embedded_field.list_field[2] - self.assertEqual(embedded_field.__class__, Embedded) - self.assertEqual(embedded_field.string_field, "hello") - self.assertEqual(embedded_field.int_field, 1) - self.assertEqual(embedded_field.dict_field, {"hello": "world"}) - self.assertEqual(embedded_field.list_field, ["1", 2, {"hello": "world"}]) + assert embedded_field.__class__ == Embedded + assert embedded_field.string_field == "hello" + assert embedded_field.int_field == 1 + assert embedded_field.dict_field == {"hello": "world"} + assert embedded_field.list_field == ["1", 2, {"hello": "world"}] def test_dynamic_and_embedded(self): """Ensure embedded documents play nicely""" @@ -352,18 +349,18 @@ class TestDynamicDocument(MongoDBTestCase): person.address.city = "Lundenne" person.save() - self.assertEqual(Person.objects.first().address.city, "Lundenne") + assert Person.objects.first().address.city == "Lundenne" person = Person.objects.first() person.address = Address(city="Londinium") person.save() - self.assertEqual(Person.objects.first().address.city, "Londinium") + assert Person.objects.first().address.city == "Londinium" person = Person.objects.first() person.age = 35 person.save() - self.assertEqual(Person.objects.first().age, 35) + assert Person.objects.first().age == 35 def test_dynamic_embedded_works_with_only(self): """Ensure custom fieldnames on a dynamic embedded document are found by qs.only()""" @@ -380,10 +377,10 @@ class TestDynamicDocument(MongoDBTestCase): name="Eric", address=Address(city="San Francisco", street_number="1337") ).save() - self.assertEqual(Person.objects.first().address.street_number, "1337") - self.assertEqual( - Person.objects.only("address__street_number").first().address.street_number, - "1337", + assert Person.objects.first().address.street_number == "1337" + assert ( + Person.objects.only("address__street_number").first().address.street_number + == "1337" ) def test_dynamic_and_embedded_dict_access(self): @@ -408,20 +405,20 @@ class TestDynamicDocument(MongoDBTestCase): person["address"]["city"] = "Lundenne" person.save() - self.assertEqual(Person.objects.first().address.city, "Lundenne") + assert Person.objects.first().address.city == "Lundenne" - self.assertEqual(Person.objects.first().phone, "555-1212") + assert Person.objects.first().phone == "555-1212" person = Person.objects.first() person.address = Address(city="Londinium") person.save() - self.assertEqual(Person.objects.first().address.city, "Londinium") + assert Person.objects.first().address.city == "Londinium" person = Person.objects.first() person["age"] = 35 person.save() - self.assertEqual(Person.objects.first().age, 35) + assert Person.objects.first().age == 35 if __name__ == "__main__": diff --git a/tests/document/test_indexes.py b/tests/document/test_indexes.py index 1b0304c4..cc1aae52 100644 --- a/tests/document/test_indexes.py +++ b/tests/document/test_indexes.py @@ -9,6 +9,7 @@ from six import iteritems from mongoengine import * from mongoengine.connection import get_db +import pytest class TestIndexes(unittest.TestCase): @@ -53,15 +54,15 @@ class TestIndexes(unittest.TestCase): {"fields": [("tags", 1)]}, {"fields": [("category", 1), ("addDate", -1)]}, ] - self.assertEqual(expected_specs, BlogPost._meta["index_specs"]) + assert expected_specs == BlogPost._meta["index_specs"] BlogPost.ensure_indexes() info = BlogPost.objects._collection.index_information() # _id, '-date', 'tags', ('cat', 'date') - self.assertEqual(len(info), 4) + assert len(info) == 4 info = [value["key"] for key, value in iteritems(info)] for expected in expected_specs: - self.assertIn(expected["fields"], info) + assert expected["fields"] in info def _index_test_inheritance(self, InheritFrom): class BlogPost(InheritFrom): @@ -78,7 +79,7 @@ class TestIndexes(unittest.TestCase): {"fields": [("_cls", 1), ("tags", 1)]}, {"fields": [("_cls", 1), ("category", 1), ("addDate", -1)]}, ] - self.assertEqual(expected_specs, BlogPost._meta["index_specs"]) + assert expected_specs == BlogPost._meta["index_specs"] BlogPost.ensure_indexes() info = BlogPost.objects._collection.index_information() @@ -86,17 +87,17 @@ class TestIndexes(unittest.TestCase): # NB: there is no index on _cls by itself, since # the indices on -date and tags will both contain # _cls as first element in the key - self.assertEqual(len(info), 4) + assert len(info) == 4 info = [value["key"] for key, value in iteritems(info)] for expected in expected_specs: - self.assertIn(expected["fields"], info) + assert expected["fields"] in info class ExtendedBlogPost(BlogPost): title = StringField() meta = {"indexes": ["title"]} expected_specs.append({"fields": [("_cls", 1), ("title", 1)]}) - self.assertEqual(expected_specs, ExtendedBlogPost._meta["index_specs"]) + assert expected_specs == ExtendedBlogPost._meta["index_specs"] BlogPost.drop_collection() @@ -104,7 +105,7 @@ class TestIndexes(unittest.TestCase): info = ExtendedBlogPost.objects._collection.index_information() info = [value["key"] for key, value in iteritems(info)] for expected in expected_specs: - self.assertIn(expected["fields"], info) + assert expected["fields"] in info def test_indexes_document_inheritance(self): """Ensure that indexes are used when meta[indexes] is specified for @@ -128,10 +129,8 @@ class TestIndexes(unittest.TestCase): class B(A): description = StringField() - self.assertEqual(A._meta["index_specs"], B._meta["index_specs"]) - self.assertEqual( - [{"fields": [("_cls", 1), ("title", 1)]}], A._meta["index_specs"] - ) + assert A._meta["index_specs"] == B._meta["index_specs"] + assert [{"fields": [("_cls", 1), ("title", 1)]}] == A._meta["index_specs"] def test_index_no_cls(self): """Ensure index specs are inhertited correctly""" @@ -144,11 +143,11 @@ class TestIndexes(unittest.TestCase): "index_cls": False, } - self.assertEqual([("title", 1)], A._meta["index_specs"][0]["fields"]) + assert [("title", 1)] == A._meta["index_specs"][0]["fields"] A._get_collection().drop_indexes() A.ensure_indexes() info = A._get_collection().index_information() - self.assertEqual(len(info.keys()), 2) + assert len(info.keys()) == 2 class B(A): c = StringField() @@ -158,8 +157,8 @@ class TestIndexes(unittest.TestCase): "allow_inheritance": True, } - self.assertEqual([("c", 1)], B._meta["index_specs"][1]["fields"]) - self.assertEqual([("_cls", 1), ("d", 1)], B._meta["index_specs"][2]["fields"]) + assert [("c", 1)] == B._meta["index_specs"][1]["fields"] + assert [("_cls", 1), ("d", 1)] == B._meta["index_specs"][2]["fields"] def test_build_index_spec_is_not_destructive(self): class MyDoc(Document): @@ -167,12 +166,12 @@ class TestIndexes(unittest.TestCase): meta = {"indexes": ["keywords"], "allow_inheritance": False} - self.assertEqual(MyDoc._meta["index_specs"], [{"fields": [("keywords", 1)]}]) + assert MyDoc._meta["index_specs"] == [{"fields": [("keywords", 1)]}] # Force index creation MyDoc.ensure_indexes() - self.assertEqual(MyDoc._meta["index_specs"], [{"fields": [("keywords", 1)]}]) + assert MyDoc._meta["index_specs"] == [{"fields": [("keywords", 1)]}] def test_embedded_document_index_meta(self): """Ensure that embedded document indexes are created explicitly @@ -187,7 +186,7 @@ class TestIndexes(unittest.TestCase): meta = {"indexes": ["rank.title"], "allow_inheritance": False} - self.assertEqual([{"fields": [("rank.title", 1)]}], Person._meta["index_specs"]) + assert [{"fields": [("rank.title", 1)]}] == Person._meta["index_specs"] Person.drop_collection() @@ -195,7 +194,7 @@ class TestIndexes(unittest.TestCase): list(Person.objects) info = Person.objects._collection.index_information() info = [value["key"] for key, value in iteritems(info)] - self.assertIn([("rank.title", 1)], info) + assert [("rank.title", 1)] in info def test_explicit_geo2d_index(self): """Ensure that geo2d indexes work when created via meta[indexes] @@ -205,14 +204,12 @@ class TestIndexes(unittest.TestCase): location = DictField() meta = {"allow_inheritance": True, "indexes": ["*location.point"]} - self.assertEqual( - [{"fields": [("location.point", "2d")]}], Place._meta["index_specs"] - ) + assert [{"fields": [("location.point", "2d")]}] == Place._meta["index_specs"] Place.ensure_indexes() info = Place._get_collection().index_information() info = [value["key"] for key, value in iteritems(info)] - self.assertIn([("location.point", "2d")], info) + assert [("location.point", "2d")] in info def test_explicit_geo2d_index_embedded(self): """Ensure that geo2d indexes work when created via meta[indexes] @@ -225,14 +222,14 @@ class TestIndexes(unittest.TestCase): current = DictField(field=EmbeddedDocumentField("EmbeddedLocation")) meta = {"allow_inheritance": True, "indexes": ["*current.location.point"]} - self.assertEqual( - [{"fields": [("current.location.point", "2d")]}], Place._meta["index_specs"] - ) + assert [{"fields": [("current.location.point", "2d")]}] == Place._meta[ + "index_specs" + ] Place.ensure_indexes() info = Place._get_collection().index_information() info = [value["key"] for key, value in iteritems(info)] - self.assertIn([("current.location.point", "2d")], info) + assert [("current.location.point", "2d")] in info def test_explicit_geosphere_index(self): """Ensure that geosphere indexes work when created via meta[indexes] @@ -242,14 +239,14 @@ class TestIndexes(unittest.TestCase): location = DictField() meta = {"allow_inheritance": True, "indexes": ["(location.point"]} - self.assertEqual( - [{"fields": [("location.point", "2dsphere")]}], Place._meta["index_specs"] - ) + assert [{"fields": [("location.point", "2dsphere")]}] == Place._meta[ + "index_specs" + ] Place.ensure_indexes() info = Place._get_collection().index_information() info = [value["key"] for key, value in iteritems(info)] - self.assertIn([("location.point", "2dsphere")], info) + assert [("location.point", "2dsphere")] in info def test_explicit_geohaystack_index(self): """Ensure that geohaystack indexes work when created via meta[indexes] @@ -264,15 +261,14 @@ class TestIndexes(unittest.TestCase): name = StringField() meta = {"indexes": [(")location.point", "name")]} - self.assertEqual( - [{"fields": [("location.point", "geoHaystack"), ("name", 1)]}], - Place._meta["index_specs"], - ) + assert [ + {"fields": [("location.point", "geoHaystack"), ("name", 1)]} + ] == Place._meta["index_specs"] Place.ensure_indexes() info = Place._get_collection().index_information() info = [value["key"] for key, value in iteritems(info)] - self.assertIn([("location.point", "geoHaystack")], info) + assert [("location.point", "geoHaystack")] in info def test_create_geohaystack_index(self): """Ensure that geohaystack indexes can be created @@ -285,7 +281,7 @@ class TestIndexes(unittest.TestCase): Place.create_index({"fields": (")location.point", "name")}, bucketSize=10) info = Place._get_collection().index_information() info = [value["key"] for key, value in iteritems(info)] - self.assertIn([("location.point", "geoHaystack"), ("name", 1)], info) + assert [("location.point", "geoHaystack"), ("name", 1)] in info def test_dictionary_indexes(self): """Ensure that indexes are used when meta[indexes] contains @@ -298,16 +294,15 @@ class TestIndexes(unittest.TestCase): tags = ListField(StringField()) meta = {"indexes": [{"fields": ["-date"], "unique": True, "sparse": True}]} - self.assertEqual( - [{"fields": [("addDate", -1)], "unique": True, "sparse": True}], - BlogPost._meta["index_specs"], - ) + assert [ + {"fields": [("addDate", -1)], "unique": True, "sparse": True} + ] == BlogPost._meta["index_specs"] BlogPost.drop_collection() info = BlogPost.objects._collection.index_information() # _id, '-date' - self.assertEqual(len(info), 2) + assert len(info) == 2 # Indexes are lazy so use list() to perform query list(BlogPost.objects) @@ -316,7 +311,7 @@ class TestIndexes(unittest.TestCase): (value["key"], value.get("unique", False), value.get("sparse", False)) for key, value in iteritems(info) ] - self.assertIn(([("addDate", -1)], True, True), info) + assert ([("addDate", -1)], True, True) in info BlogPost.drop_collection() @@ -338,11 +333,9 @@ class TestIndexes(unittest.TestCase): Person(name="test", user_guid="123").save() - self.assertEqual(1, Person.objects.count()) + assert 1 == Person.objects.count() info = Person.objects._collection.index_information() - self.assertEqual( - sorted(info.keys()), ["_cls_1_name_1", "_cls_1_user_guid_1", "_id_"] - ) + assert sorted(info.keys()) == ["_cls_1_name_1", "_cls_1_user_guid_1", "_id_"] def test_disable_index_creation(self): """Tests setting auto_create_index to False on the connection will @@ -365,13 +358,13 @@ class TestIndexes(unittest.TestCase): User(user_guid="123").save() MongoUser(user_guid="123").save() - self.assertEqual(2, User.objects.count()) + assert 2 == User.objects.count() info = User.objects._collection.index_information() - self.assertEqual(list(info.keys()), ["_id_"]) + assert list(info.keys()) == ["_id_"] User.ensure_indexes() info = User.objects._collection.index_information() - self.assertEqual(sorted(info.keys()), ["_cls_1_user_guid_1", "_id_"]) + assert sorted(info.keys()) == ["_cls_1_user_guid_1", "_id_"] def test_embedded_document_index(self): """Tests settings an index on an embedded document @@ -389,7 +382,7 @@ class TestIndexes(unittest.TestCase): BlogPost.drop_collection() info = BlogPost.objects._collection.index_information() - self.assertEqual(sorted(info.keys()), ["_id_", "date.yr_-1"]) + assert sorted(info.keys()) == ["_id_", "date.yr_-1"] def test_list_embedded_document_index(self): """Ensure list embedded documents can be indexed @@ -408,7 +401,7 @@ class TestIndexes(unittest.TestCase): info = BlogPost.objects._collection.index_information() # we don't use _cls in with list fields by default - self.assertEqual(sorted(info.keys()), ["_id_", "tags.tag_1"]) + assert sorted(info.keys()) == ["_id_", "tags.tag_1"] post1 = BlogPost( title="Embedded Indexes tests in place", @@ -426,7 +419,7 @@ class TestIndexes(unittest.TestCase): RecursiveDocument.ensure_indexes() info = RecursiveDocument._get_collection().index_information() - self.assertEqual(sorted(info.keys()), ["_cls_1", "_id_"]) + assert sorted(info.keys()) == ["_cls_1", "_id_"] def test_covered_index(self): """Ensure that covered indexes can be used @@ -446,46 +439,45 @@ class TestIndexes(unittest.TestCase): # Need to be explicit about covered indexes as mongoDB doesn't know if # the documents returned might have more keys in that here. query_plan = Test.objects(id=obj.id).exclude("a").explain() - self.assertEqual( + assert ( query_plan.get("queryPlanner") .get("winningPlan") .get("inputStage") - .get("stage"), - "IDHACK", + .get("stage") + == "IDHACK" ) query_plan = Test.objects(id=obj.id).only("id").explain() - self.assertEqual( + assert ( query_plan.get("queryPlanner") .get("winningPlan") .get("inputStage") - .get("stage"), - "IDHACK", + .get("stage") + == "IDHACK" ) query_plan = Test.objects(a=1).only("a").exclude("id").explain() - self.assertEqual( + assert ( query_plan.get("queryPlanner") .get("winningPlan") .get("inputStage") - .get("stage"), - "IXSCAN", + .get("stage") + == "IXSCAN" ) - self.assertEqual( - query_plan.get("queryPlanner").get("winningPlan").get("stage"), "PROJECTION" + assert ( + query_plan.get("queryPlanner").get("winningPlan").get("stage") + == "PROJECTION" ) query_plan = Test.objects(a=1).explain() - self.assertEqual( + assert ( query_plan.get("queryPlanner") .get("winningPlan") .get("inputStage") - .get("stage"), - "IXSCAN", - ) - self.assertEqual( - query_plan.get("queryPlanner").get("winningPlan").get("stage"), "FETCH" + .get("stage") + == "IXSCAN" ) + assert query_plan.get("queryPlanner").get("winningPlan").get("stage") == "FETCH" def test_index_on_id(self): class BlogPost(Document): @@ -498,9 +490,7 @@ class TestIndexes(unittest.TestCase): BlogPost.drop_collection() indexes = BlogPost.objects._collection.index_information() - self.assertEqual( - indexes["categories_1__id_1"]["key"], [("categories", 1), ("_id", 1)] - ) + assert indexes["categories_1__id_1"]["key"] == [("categories", 1), ("_id", 1)] def test_hint(self): TAGS_INDEX_NAME = "tags_1" @@ -516,25 +506,25 @@ class TestIndexes(unittest.TestCase): BlogPost(tags=tags).save() # Hinting by shape should work. - self.assertEqual(BlogPost.objects.hint([("tags", 1)]).count(), 10) + assert BlogPost.objects.hint([("tags", 1)]).count() == 10 # Hinting by index name should work. - self.assertEqual(BlogPost.objects.hint(TAGS_INDEX_NAME).count(), 10) + assert BlogPost.objects.hint(TAGS_INDEX_NAME).count() == 10 # Clearing the hint should work fine. - self.assertEqual(BlogPost.objects.hint().count(), 10) - self.assertEqual(BlogPost.objects.hint([("ZZ", 1)]).hint().count(), 10) + assert BlogPost.objects.hint().count() == 10 + assert BlogPost.objects.hint([("ZZ", 1)]).hint().count() == 10 # Hinting on a non-existent index shape should fail. - with self.assertRaises(OperationFailure): + with pytest.raises(OperationFailure): BlogPost.objects.hint([("ZZ", 1)]).count() # Hinting on a non-existent index name should fail. - with self.assertRaises(OperationFailure): + with pytest.raises(OperationFailure): BlogPost.objects.hint("Bad Name").count() # Invalid shape argument (missing list brackets) should fail. - with self.assertRaises(ValueError): + with pytest.raises(ValueError): BlogPost.objects.hint(("tags", 1)).count() def test_collation(self): @@ -588,11 +578,14 @@ class TestIndexes(unittest.TestCase): # Two posts with the same slug is not allowed post2 = BlogPost(title="test2", slug="test") - self.assertRaises(NotUniqueError, post2.save) - self.assertRaises(NotUniqueError, BlogPost.objects.insert, post2) + with pytest.raises(NotUniqueError): + post2.save() + with pytest.raises(NotUniqueError): + BlogPost.objects.insert(post2) # Ensure backwards compatibility for errors - self.assertRaises(OperationError, post2.save) + with pytest.raises(OperationError): + post2.save() def test_primary_key_unique_not_working(self): """Relates to #1445""" @@ -602,23 +595,21 @@ class TestIndexes(unittest.TestCase): Blog.drop_collection() - with self.assertRaises(OperationFailure) as ctx_err: + with pytest.raises(OperationFailure) as ctx_err: Blog(id="garbage").save() # One of the errors below should happen. Which one depends on the # PyMongo version and dict order. err_msg = str(ctx_err.exception) - self.assertTrue( - any( - [ - "The field 'unique' is not valid for an _id index specification" - in err_msg, - "The field 'background' is not valid for an _id index specification" - in err_msg, - "The field 'sparse' is not valid for an _id index specification" - in err_msg, - ] - ) + assert any( + [ + "The field 'unique' is not valid for an _id index specification" + in err_msg, + "The field 'background' is not valid for an _id index specification" + in err_msg, + "The field 'sparse' is not valid for an _id index specification" + in err_msg, + ] ) def test_unique_with(self): @@ -644,7 +635,8 @@ class TestIndexes(unittest.TestCase): # Now there will be two docs with the same slug and the same day: fail post3 = BlogPost(title="test3", date=Date(year=2010), slug="test") - self.assertRaises(OperationError, post3.save) + with pytest.raises(OperationError): + post3.save() def test_unique_embedded_document(self): """Ensure that uniqueness constraints are applied to fields on embedded documents. @@ -669,7 +661,8 @@ class TestIndexes(unittest.TestCase): # Now there will be two docs with the same sub.slug post3 = BlogPost(title="test3", sub=SubDocument(year=2010, slug="test")) - self.assertRaises(NotUniqueError, post3.save) + with pytest.raises(NotUniqueError): + post3.save() def test_unique_embedded_document_in_list(self): """ @@ -699,7 +692,8 @@ class TestIndexes(unittest.TestCase): post2 = BlogPost(title="test2", subs=[SubDocument(year=2014, slug="conflict")]) - self.assertRaises(NotUniqueError, post2.save) + with pytest.raises(NotUniqueError): + post2.save() def test_unique_embedded_document_in_sorted_list(self): """ @@ -729,12 +723,13 @@ class TestIndexes(unittest.TestCase): # confirm that the unique index is created indexes = BlogPost._get_collection().index_information() - self.assertIn("subs.slug_1", indexes) - self.assertTrue(indexes["subs.slug_1"]["unique"]) + assert "subs.slug_1" in indexes + assert indexes["subs.slug_1"]["unique"] post2 = BlogPost(title="test2", subs=[SubDocument(year=2014, slug="conflict")]) - self.assertRaises(NotUniqueError, post2.save) + with pytest.raises(NotUniqueError): + post2.save() def test_unique_embedded_document_in_embedded_document_list(self): """ @@ -764,12 +759,13 @@ class TestIndexes(unittest.TestCase): # confirm that the unique index is created indexes = BlogPost._get_collection().index_information() - self.assertIn("subs.slug_1", indexes) - self.assertTrue(indexes["subs.slug_1"]["unique"]) + assert "subs.slug_1" in indexes + assert indexes["subs.slug_1"]["unique"] post2 = BlogPost(title="test2", subs=[SubDocument(year=2014, slug="conflict")]) - self.assertRaises(NotUniqueError, post2.save) + with pytest.raises(NotUniqueError): + post2.save() def test_unique_with_embedded_document_and_embedded_unique(self): """Ensure that uniqueness constraints are applied to fields on @@ -795,11 +791,13 @@ class TestIndexes(unittest.TestCase): # Now there will be two docs with the same sub.slug post3 = BlogPost(title="test3", sub=SubDocument(year=2010, slug="test")) - self.assertRaises(NotUniqueError, post3.save) + with pytest.raises(NotUniqueError): + post3.save() # Now there will be two docs with the same title and year post3 = BlogPost(title="test1", sub=SubDocument(year=2009, slug="test-1")) - self.assertRaises(NotUniqueError, post3.save) + with pytest.raises(NotUniqueError): + post3.save() def test_ttl_indexes(self): class Log(Document): @@ -811,7 +809,7 @@ class TestIndexes(unittest.TestCase): # Indexes are lazy so use list() to perform query list(Log.objects) info = Log.objects._collection.index_information() - self.assertEqual(3600, info["created_1"]["expireAfterSeconds"]) + assert 3600 == info["created_1"]["expireAfterSeconds"] def test_index_drop_dups_silently_ignored(self): class Customer(Document): @@ -839,14 +837,14 @@ class TestIndexes(unittest.TestCase): cust.save() cust_dupe = Customer(cust_id=1) - with self.assertRaises(NotUniqueError): + with pytest.raises(NotUniqueError): cust_dupe.save() cust = Customer(cust_id=2) cust.save() # duplicate key on update - with self.assertRaises(NotUniqueError): + with pytest.raises(NotUniqueError): cust.cust_id = 1 cust.save() @@ -867,8 +865,8 @@ class TestIndexes(unittest.TestCase): user = User(name="huangz", password="secret2") user.save() - self.assertEqual(User.objects.count(), 1) - self.assertEqual(User.objects.get().password, "secret2") + assert User.objects.count() == 1 + assert User.objects.get().password == "secret2" def test_unique_and_primary_create(self): """Create a new record with a duplicate primary key @@ -882,11 +880,11 @@ class TestIndexes(unittest.TestCase): User.drop_collection() User.objects.create(name="huangz", password="secret") - with self.assertRaises(NotUniqueError): + with pytest.raises(NotUniqueError): User.objects.create(name="huangz", password="secret2") - self.assertEqual(User.objects.count(), 1) - self.assertEqual(User.objects.get().password, "secret") + assert User.objects.count() == 1 + assert User.objects.get().password == "secret" def test_index_with_pk(self): """Ensure you can use `pk` as part of a query""" @@ -910,7 +908,7 @@ class TestIndexes(unittest.TestCase): info = BlogPost.objects._collection.index_information() info = [value["key"] for key, value in iteritems(info)] index_item = [("_id", 1), ("comments.comment_id", 1)] - self.assertIn(index_item, info) + assert index_item in info def test_compound_key_embedded(self): class CompoundKey(EmbeddedDocument): @@ -924,10 +922,8 @@ class TestIndexes(unittest.TestCase): my_key = CompoundKey(name="n", term="ok") report = ReportEmbedded(text="OK", key=my_key).save() - self.assertEqual( - {"text": "OK", "_id": {"term": "ok", "name": "n"}}, report.to_mongo() - ) - self.assertEqual(report, ReportEmbedded.objects.get(pk=my_key)) + assert {"text": "OK", "_id": {"term": "ok", "name": "n"}} == report.to_mongo() + assert report == ReportEmbedded.objects.get(pk=my_key) def test_compound_key_dictfield(self): class ReportDictField(Document): @@ -937,15 +933,13 @@ class TestIndexes(unittest.TestCase): my_key = {"name": "n", "term": "ok"} report = ReportDictField(text="OK", key=my_key).save() - self.assertEqual( - {"text": "OK", "_id": {"term": "ok", "name": "n"}}, report.to_mongo() - ) + assert {"text": "OK", "_id": {"term": "ok", "name": "n"}} == report.to_mongo() # We can't directly call ReportDictField.objects.get(pk=my_key), # because dicts are unordered, and if the order in MongoDB is # different than the one in `my_key`, this test will fail. - self.assertEqual(report, ReportDictField.objects.get(pk__name=my_key["name"])) - self.assertEqual(report, ReportDictField.objects.get(pk__term=my_key["term"])) + assert report == ReportDictField.objects.get(pk__name=my_key["name"]) + assert report == ReportDictField.objects.get(pk__term=my_key["term"]) def test_string_indexes(self): class MyDoc(Document): @@ -954,8 +948,8 @@ class TestIndexes(unittest.TestCase): info = MyDoc.objects._collection.index_information() info = [value["key"] for key, value in iteritems(info)] - self.assertIn([("provider_ids.foo", 1)], info) - self.assertIn([("provider_ids.bar", 1)], info) + assert [("provider_ids.foo", 1)] in info + assert [("provider_ids.bar", 1)] in info def test_sparse_compound_indexes(self): class MyDoc(Document): @@ -967,11 +961,10 @@ class TestIndexes(unittest.TestCase): } info = MyDoc.objects._collection.index_information() - self.assertEqual( - [("provider_ids.foo", 1), ("provider_ids.bar", 1)], - info["provider_ids.foo_1_provider_ids.bar_1"]["key"], - ) - self.assertTrue(info["provider_ids.foo_1_provider_ids.bar_1"]["sparse"]) + assert [("provider_ids.foo", 1), ("provider_ids.bar", 1)] == info[ + "provider_ids.foo_1_provider_ids.bar_1" + ]["key"] + assert info["provider_ids.foo_1_provider_ids.bar_1"]["sparse"] def test_text_indexes(self): class Book(Document): @@ -979,9 +972,9 @@ class TestIndexes(unittest.TestCase): meta = {"indexes": ["$title"]} indexes = Book.objects._collection.index_information() - self.assertIn("title_text", indexes) + assert "title_text" in indexes key = indexes["title_text"]["key"] - self.assertIn(("_fts", "text"), key) + assert ("_fts", "text") in key def test_hashed_indexes(self): class Book(Document): @@ -989,8 +982,8 @@ class TestIndexes(unittest.TestCase): meta = {"indexes": ["#ref_id"]} indexes = Book.objects._collection.index_information() - self.assertIn("ref_id_hashed", indexes) - self.assertIn(("ref_id", "hashed"), indexes["ref_id_hashed"]["key"]) + assert "ref_id_hashed" in indexes + assert ("ref_id", "hashed") in indexes["ref_id_hashed"]["key"] def test_indexes_after_database_drop(self): """ @@ -1027,7 +1020,8 @@ class TestIndexes(unittest.TestCase): # Create Post #2 post2 = BlogPost(title="test2", slug="test") - self.assertRaises(NotUniqueError, post2.save) + with pytest.raises(NotUniqueError): + post2.save() finally: # Drop the temporary database at the end connection.drop_database("tempdatabase") @@ -1074,15 +1068,12 @@ class TestIndexes(unittest.TestCase): "dropDups" ] # drop the index dropDups - it is deprecated in MongoDB 3+ - self.assertEqual( - index_info, - { - "txt_1": {"key": [("txt", 1)], "background": False}, - "_id_": {"key": [("_id", 1)]}, - "txt2_1": {"key": [("txt2", 1)], "background": False}, - "_cls_1": {"key": [("_cls", 1)], "background": False}, - }, - ) + assert index_info == { + "txt_1": {"key": [("txt", 1)], "background": False}, + "_id_": {"key": [("_id", 1)]}, + "txt2_1": {"key": [("txt2", 1)], "background": False}, + "_cls_1": {"key": [("_cls", 1)], "background": False}, + } def test_compound_index_underscore_cls_not_overwritten(self): """ @@ -1105,7 +1096,7 @@ class TestIndexes(unittest.TestCase): TestDoc.ensure_indexes() index_info = TestDoc._get_collection().index_information() - self.assertIn("shard_1_1__cls_1_txt_1_1", index_info) + assert "shard_1_1__cls_1_txt_1_1" in index_info if __name__ == "__main__": diff --git a/tests/document/test_inheritance.py b/tests/document/test_inheritance.py index 4bb46e58..6a913b3e 100644 --- a/tests/document/test_inheritance.py +++ b/tests/document/test_inheritance.py @@ -17,6 +17,7 @@ from mongoengine import ( from mongoengine.pymongo_support import list_collection_names from tests.fixtures import Base from tests.utils import MongoDBTestCase +import pytest class TestInheritance(MongoDBTestCase): @@ -37,12 +38,12 @@ class TestInheritance(MongoDBTestCase): meta = {"allow_inheritance": True} test_doc = DataDoc(name="test", embed=EmbedData(data="data")) - self.assertEqual(test_doc._cls, "DataDoc") - self.assertEqual(test_doc.embed._cls, "EmbedData") + assert test_doc._cls == "DataDoc" + assert test_doc.embed._cls == "EmbedData" test_doc.save() saved_doc = DataDoc.objects.with_id(test_doc.id) - self.assertEqual(test_doc._cls, saved_doc._cls) - self.assertEqual(test_doc.embed._cls, saved_doc.embed._cls) + assert test_doc._cls == saved_doc._cls + assert test_doc.embed._cls == saved_doc.embed._cls test_doc.delete() def test_superclasses(self): @@ -67,12 +68,12 @@ class TestInheritance(MongoDBTestCase): class Human(Mammal): pass - self.assertEqual(Animal._superclasses, ()) - self.assertEqual(Fish._superclasses, ("Animal",)) - self.assertEqual(Guppy._superclasses, ("Animal", "Animal.Fish")) - self.assertEqual(Mammal._superclasses, ("Animal",)) - self.assertEqual(Dog._superclasses, ("Animal", "Animal.Mammal")) - self.assertEqual(Human._superclasses, ("Animal", "Animal.Mammal")) + assert Animal._superclasses == () + assert Fish._superclasses == ("Animal",) + assert Guppy._superclasses == ("Animal", "Animal.Fish") + assert Mammal._superclasses == ("Animal",) + assert Dog._superclasses == ("Animal", "Animal.Mammal") + assert Human._superclasses == ("Animal", "Animal.Mammal") def test_external_superclasses(self): """Ensure that the correct list of super classes is assembled when @@ -97,18 +98,12 @@ class TestInheritance(MongoDBTestCase): class Human(Mammal): pass - self.assertEqual(Animal._superclasses, ("Base",)) - self.assertEqual(Fish._superclasses, ("Base", "Base.Animal")) - self.assertEqual( - Guppy._superclasses, ("Base", "Base.Animal", "Base.Animal.Fish") - ) - self.assertEqual(Mammal._superclasses, ("Base", "Base.Animal")) - self.assertEqual( - Dog._superclasses, ("Base", "Base.Animal", "Base.Animal.Mammal") - ) - self.assertEqual( - Human._superclasses, ("Base", "Base.Animal", "Base.Animal.Mammal") - ) + assert Animal._superclasses == ("Base",) + assert Fish._superclasses == ("Base", "Base.Animal") + assert Guppy._superclasses == ("Base", "Base.Animal", "Base.Animal.Fish") + assert Mammal._superclasses == ("Base", "Base.Animal") + assert Dog._superclasses == ("Base", "Base.Animal", "Base.Animal.Mammal") + assert Human._superclasses == ("Base", "Base.Animal", "Base.Animal.Mammal") def test_subclasses(self): """Ensure that the correct list of _subclasses (subclasses) is @@ -133,24 +128,22 @@ class TestInheritance(MongoDBTestCase): class Human(Mammal): pass - self.assertEqual( - Animal._subclasses, - ( - "Animal", - "Animal.Fish", - "Animal.Fish.Guppy", - "Animal.Mammal", - "Animal.Mammal.Dog", - "Animal.Mammal.Human", - ), + assert Animal._subclasses == ( + "Animal", + "Animal.Fish", + "Animal.Fish.Guppy", + "Animal.Mammal", + "Animal.Mammal.Dog", + "Animal.Mammal.Human", ) - self.assertEqual(Fish._subclasses, ("Animal.Fish", "Animal.Fish.Guppy")) - self.assertEqual(Guppy._subclasses, ("Animal.Fish.Guppy",)) - self.assertEqual( - Mammal._subclasses, - ("Animal.Mammal", "Animal.Mammal.Dog", "Animal.Mammal.Human"), + assert Fish._subclasses == ("Animal.Fish", "Animal.Fish.Guppy") + assert Guppy._subclasses == ("Animal.Fish.Guppy",) + assert Mammal._subclasses == ( + "Animal.Mammal", + "Animal.Mammal.Dog", + "Animal.Mammal.Human", ) - self.assertEqual(Human._subclasses, ("Animal.Mammal.Human",)) + assert Human._subclasses == ("Animal.Mammal.Human",) def test_external_subclasses(self): """Ensure that the correct list of _subclasses (subclasses) is @@ -175,30 +168,22 @@ class TestInheritance(MongoDBTestCase): class Human(Mammal): pass - self.assertEqual( - Animal._subclasses, - ( - "Base.Animal", - "Base.Animal.Fish", - "Base.Animal.Fish.Guppy", - "Base.Animal.Mammal", - "Base.Animal.Mammal.Dog", - "Base.Animal.Mammal.Human", - ), + assert Animal._subclasses == ( + "Base.Animal", + "Base.Animal.Fish", + "Base.Animal.Fish.Guppy", + "Base.Animal.Mammal", + "Base.Animal.Mammal.Dog", + "Base.Animal.Mammal.Human", ) - self.assertEqual( - Fish._subclasses, ("Base.Animal.Fish", "Base.Animal.Fish.Guppy") + assert Fish._subclasses == ("Base.Animal.Fish", "Base.Animal.Fish.Guppy") + assert Guppy._subclasses == ("Base.Animal.Fish.Guppy",) + assert Mammal._subclasses == ( + "Base.Animal.Mammal", + "Base.Animal.Mammal.Dog", + "Base.Animal.Mammal.Human", ) - self.assertEqual(Guppy._subclasses, ("Base.Animal.Fish.Guppy",)) - self.assertEqual( - Mammal._subclasses, - ( - "Base.Animal.Mammal", - "Base.Animal.Mammal.Dog", - "Base.Animal.Mammal.Human", - ), - ) - self.assertEqual(Human._subclasses, ("Base.Animal.Mammal.Human",)) + assert Human._subclasses == ("Base.Animal.Mammal.Human",) def test_dynamic_declarations(self): """Test that declaring an extra class updates meta data""" @@ -206,33 +191,31 @@ class TestInheritance(MongoDBTestCase): class Animal(Document): meta = {"allow_inheritance": True} - self.assertEqual(Animal._superclasses, ()) - self.assertEqual(Animal._subclasses, ("Animal",)) + assert Animal._superclasses == () + assert Animal._subclasses == ("Animal",) # Test dynamically adding a class changes the meta data class Fish(Animal): pass - self.assertEqual(Animal._superclasses, ()) - self.assertEqual(Animal._subclasses, ("Animal", "Animal.Fish")) + assert Animal._superclasses == () + assert Animal._subclasses == ("Animal", "Animal.Fish") - self.assertEqual(Fish._superclasses, ("Animal",)) - self.assertEqual(Fish._subclasses, ("Animal.Fish",)) + assert Fish._superclasses == ("Animal",) + assert Fish._subclasses == ("Animal.Fish",) # Test dynamically adding an inherited class changes the meta data class Pike(Fish): pass - self.assertEqual(Animal._superclasses, ()) - self.assertEqual( - Animal._subclasses, ("Animal", "Animal.Fish", "Animal.Fish.Pike") - ) + assert Animal._superclasses == () + assert Animal._subclasses == ("Animal", "Animal.Fish", "Animal.Fish.Pike") - self.assertEqual(Fish._superclasses, ("Animal",)) - self.assertEqual(Fish._subclasses, ("Animal.Fish", "Animal.Fish.Pike")) + assert Fish._superclasses == ("Animal",) + assert Fish._subclasses == ("Animal.Fish", "Animal.Fish.Pike") - self.assertEqual(Pike._superclasses, ("Animal", "Animal.Fish")) - self.assertEqual(Pike._subclasses, ("Animal.Fish.Pike",)) + assert Pike._superclasses == ("Animal", "Animal.Fish") + assert Pike._subclasses == ("Animal.Fish.Pike",) def test_inheritance_meta_data(self): """Ensure that document may inherit fields from a superclass document. @@ -247,10 +230,10 @@ class TestInheritance(MongoDBTestCase): class Employee(Person): salary = IntField() - self.assertEqual( - ["_cls", "age", "id", "name", "salary"], sorted(Employee._fields.keys()) + assert ["_cls", "age", "id", "name", "salary"] == sorted( + Employee._fields.keys() ) - self.assertEqual(Employee._get_collection_name(), Person._get_collection_name()) + assert Employee._get_collection_name() == Person._get_collection_name() def test_inheritance_to_mongo_keys(self): """Ensure that document may inherit fields from a superclass document. @@ -265,17 +248,17 @@ class TestInheritance(MongoDBTestCase): class Employee(Person): salary = IntField() - self.assertEqual( - ["_cls", "age", "id", "name", "salary"], sorted(Employee._fields.keys()) + assert ["_cls", "age", "id", "name", "salary"] == sorted( + Employee._fields.keys() ) - self.assertEqual( - Person(name="Bob", age=35).to_mongo().keys(), ["_cls", "name", "age"] - ) - self.assertEqual( - Employee(name="Bob", age=35, salary=0).to_mongo().keys(), - ["_cls", "name", "age", "salary"], - ) - self.assertEqual(Employee._get_collection_name(), Person._get_collection_name()) + assert Person(name="Bob", age=35).to_mongo().keys() == ["_cls", "name", "age"] + assert Employee(name="Bob", age=35, salary=0).to_mongo().keys() == [ + "_cls", + "name", + "age", + "salary", + ] + assert Employee._get_collection_name() == Person._get_collection_name() def test_indexes_and_multiple_inheritance(self): """ Ensure that all of the indexes are created for a document with @@ -301,13 +284,10 @@ class TestInheritance(MongoDBTestCase): C.ensure_indexes() - self.assertEqual( - sorted( - [idx["key"] for idx in C._get_collection().index_information().values()] - ), - sorted( - [[(u"_cls", 1), (u"b", 1)], [(u"_id", 1)], [(u"_cls", 1), (u"a", 1)]] - ), + assert sorted( + [idx["key"] for idx in C._get_collection().index_information().values()] + ) == sorted( + [[(u"_cls", 1), (u"b", 1)], [(u"_id", 1)], [(u"_cls", 1), (u"a", 1)]] ) def test_polymorphic_queries(self): @@ -338,13 +318,13 @@ class TestInheritance(MongoDBTestCase): Human().save() classes = [obj.__class__ for obj in Animal.objects] - self.assertEqual(classes, [Animal, Fish, Mammal, Dog, Human]) + assert classes == [Animal, Fish, Mammal, Dog, Human] classes = [obj.__class__ for obj in Mammal.objects] - self.assertEqual(classes, [Mammal, Dog, Human]) + assert classes == [Mammal, Dog, Human] classes = [obj.__class__ for obj in Human.objects] - self.assertEqual(classes, [Human]) + assert classes == [Human] def test_allow_inheritance(self): """Ensure that inheritance is disabled by default on simple @@ -355,20 +335,20 @@ class TestInheritance(MongoDBTestCase): name = StringField() # can't inherit because Animal didn't explicitly allow inheritance - with self.assertRaises(ValueError) as cm: + with pytest.raises(ValueError) as cm: class Dog(Animal): pass - self.assertIn("Document Animal may not be subclassed", str(cm.exception)) + assert "Document Animal may not be subclassed" in str(cm.exception) # Check that _cls etc aren't present on simple documents dog = Animal(name="dog").save() - self.assertEqual(dog.to_mongo().keys(), ["_id", "name"]) + assert dog.to_mongo().keys() == ["_id", "name"] collection = self.db[Animal._get_collection_name()] obj = collection.find_one() - self.assertNotIn("_cls", obj) + assert "_cls" not in obj def test_cant_turn_off_inheritance_on_subclass(self): """Ensure if inheritance is on in a subclass you cant turn it off. @@ -378,14 +358,14 @@ class TestInheritance(MongoDBTestCase): name = StringField() meta = {"allow_inheritance": True} - with self.assertRaises(ValueError) as cm: + with pytest.raises(ValueError) as cm: class Mammal(Animal): meta = {"allow_inheritance": False} - self.assertEqual( - str(cm.exception), - 'Only direct subclasses of Document may set "allow_inheritance" to False', + assert ( + str(cm.exception) + == 'Only direct subclasses of Document may set "allow_inheritance" to False' ) def test_allow_inheritance_abstract_document(self): @@ -399,14 +379,14 @@ class TestInheritance(MongoDBTestCase): class Animal(FinalDocument): name = StringField() - with self.assertRaises(ValueError): + with pytest.raises(ValueError): class Mammal(Animal): pass # Check that _cls isn't present in simple documents doc = Animal(name="dog") - self.assertNotIn("_cls", doc.to_mongo()) + assert "_cls" not in doc.to_mongo() def test_using_abstract_class_in_reference_field(self): # Ensures no regression of #1920 @@ -452,10 +432,10 @@ class TestInheritance(MongoDBTestCase): name = StringField() berlin = EuropeanCity(name="Berlin", continent="Europe") - self.assertEqual(len(berlin._db_field_map), len(berlin._fields_ordered)) - self.assertEqual(len(berlin._reverse_db_field_map), len(berlin._fields_ordered)) - self.assertEqual(len(berlin._fields_ordered), 3) - self.assertEqual(berlin._fields_ordered[0], "id") + assert len(berlin._db_field_map) == len(berlin._fields_ordered) + assert len(berlin._reverse_db_field_map) == len(berlin._fields_ordered) + assert len(berlin._fields_ordered) == 3 + assert berlin._fields_ordered[0] == "id" def test_auto_id_not_set_if_specific_in_parent_class(self): class City(Document): @@ -467,10 +447,10 @@ class TestInheritance(MongoDBTestCase): name = StringField() berlin = EuropeanCity(name="Berlin", continent="Europe") - self.assertEqual(len(berlin._db_field_map), len(berlin._fields_ordered)) - self.assertEqual(len(berlin._reverse_db_field_map), len(berlin._fields_ordered)) - self.assertEqual(len(berlin._fields_ordered), 3) - self.assertEqual(berlin._fields_ordered[0], "city_id") + assert len(berlin._db_field_map) == len(berlin._fields_ordered) + assert len(berlin._reverse_db_field_map) == len(berlin._fields_ordered) + assert len(berlin._fields_ordered) == 3 + assert berlin._fields_ordered[0] == "city_id" def test_auto_id_vs_non_pk_id_field(self): class City(Document): @@ -482,12 +462,12 @@ class TestInheritance(MongoDBTestCase): name = StringField() berlin = EuropeanCity(name="Berlin", continent="Europe") - self.assertEqual(len(berlin._db_field_map), len(berlin._fields_ordered)) - self.assertEqual(len(berlin._reverse_db_field_map), len(berlin._fields_ordered)) - self.assertEqual(len(berlin._fields_ordered), 4) - self.assertEqual(berlin._fields_ordered[0], "auto_id_0") + assert len(berlin._db_field_map) == len(berlin._fields_ordered) + assert len(berlin._reverse_db_field_map) == len(berlin._fields_ordered) + assert len(berlin._fields_ordered) == 4 + assert berlin._fields_ordered[0] == "auto_id_0" berlin.save() - self.assertEqual(berlin.pk, berlin.auto_id_0) + assert berlin.pk == berlin.auto_id_0 def test_abstract_document_creation_does_not_fail(self): class City(Document): @@ -495,9 +475,9 @@ class TestInheritance(MongoDBTestCase): meta = {"abstract": True, "allow_inheritance": False} city = City(continent="asia") - self.assertEqual(None, city.pk) + assert None == city.pk # TODO: expected error? Shouldn't we create a new error type? - with self.assertRaises(KeyError): + with pytest.raises(KeyError): setattr(city, "pk", 1) def test_allow_inheritance_embedded_document(self): @@ -506,20 +486,20 @@ class TestInheritance(MongoDBTestCase): class Comment(EmbeddedDocument): content = StringField() - with self.assertRaises(ValueError): + with pytest.raises(ValueError): class SpecialComment(Comment): pass doc = Comment(content="test") - self.assertNotIn("_cls", doc.to_mongo()) + assert "_cls" not in doc.to_mongo() class Comment(EmbeddedDocument): content = StringField() meta = {"allow_inheritance": True} doc = Comment(content="test") - self.assertIn("_cls", doc.to_mongo()) + assert "_cls" in doc.to_mongo() def test_document_inheritance(self): """Ensure mutliple inheritance of abstract documents @@ -537,7 +517,7 @@ class TestInheritance(MongoDBTestCase): pass except Exception: - self.assertTrue(False, "Couldn't create MyDocument class") + assert False, "Couldn't create MyDocument class" def test_abstract_documents(self): """Ensure that a document superclass can be marked as abstract @@ -574,20 +554,20 @@ class TestInheritance(MongoDBTestCase): for k, v in iteritems(defaults): for cls in [Animal, Fish, Guppy]: - self.assertEqual(cls._meta[k], v) + assert cls._meta[k] == v - self.assertNotIn("collection", Animal._meta) - self.assertNotIn("collection", Mammal._meta) + assert "collection" not in Animal._meta + assert "collection" not in Mammal._meta - self.assertEqual(Animal._get_collection_name(), None) - self.assertEqual(Mammal._get_collection_name(), None) + assert Animal._get_collection_name() == None + assert Mammal._get_collection_name() == None - self.assertEqual(Fish._get_collection_name(), "fish") - self.assertEqual(Guppy._get_collection_name(), "fish") - self.assertEqual(Human._get_collection_name(), "human") + assert Fish._get_collection_name() == "fish" + assert Guppy._get_collection_name() == "fish" + assert Human._get_collection_name() == "human" # ensure that a subclass of a non-abstract class can't be abstract - with self.assertRaises(ValueError): + with pytest.raises(ValueError): class EvilHuman(Human): evil = BooleanField(default=True) @@ -601,7 +581,7 @@ class TestInheritance(MongoDBTestCase): class B(A): pass - self.assertFalse(B._meta["abstract"]) + assert not B._meta["abstract"] def test_inherited_collections(self): """Ensure that subclassed documents don't override parents' @@ -647,8 +627,8 @@ class TestInheritance(MongoDBTestCase): real_person = Drinker(drink=beer) real_person.save() - self.assertEqual(Drinker.objects[0].drink.name, red_bull.name) - self.assertEqual(Drinker.objects[1].drink.name, beer.name) + assert Drinker.objects[0].drink.name == red_bull.name + assert Drinker.objects[1].drink.name == beer.name if __name__ == "__main__": diff --git a/tests/document/test_instance.py b/tests/document/test_instance.py index 203e2cce..01dc492b 100644 --- a/tests/document/test_instance.py +++ b/tests/document/test_instance.py @@ -36,6 +36,7 @@ from tests.fixtures import ( PickleTest, ) from tests.utils import MongoDBTestCase, get_as_pymongo +import pytest TEST_IMAGE_PATH = os.path.join(os.path.dirname(__file__), "../fields/mongoengine.png") @@ -63,18 +64,17 @@ class TestInstance(MongoDBTestCase): self.db.drop_collection(collection) def assertDbEqual(self, docs): - self.assertEqual( - list(self.Person._get_collection().find().sort("id")), - sorted(docs, key=lambda doc: doc["_id"]), + assert list(self.Person._get_collection().find().sort("id")) == sorted( + docs, key=lambda doc: doc["_id"] ) def assertHasInstance(self, field, instance): - self.assertTrue(hasattr(field, "_instance")) - self.assertTrue(field._instance is not None) + assert hasattr(field, "_instance") + assert field._instance is not None if isinstance(field._instance, weakref.ProxyType): - self.assertTrue(field._instance.__eq__(instance)) + assert field._instance.__eq__(instance) else: - self.assertEqual(field._instance, instance) + assert field._instance == instance def test_capped_collection(self): """Ensure that capped collections work properly.""" @@ -89,16 +89,16 @@ class TestInstance(MongoDBTestCase): for _ in range(10): Log().save() - self.assertEqual(Log.objects.count(), 10) + assert Log.objects.count() == 10 # Check that extra documents don't increase the size Log().save() - self.assertEqual(Log.objects.count(), 10) + assert Log.objects.count() == 10 options = Log.objects._collection.options() - self.assertEqual(options["capped"], True) - self.assertEqual(options["max"], 10) - self.assertEqual(options["size"], 4096) + assert options["capped"] == True + assert options["max"] == 10 + assert options["size"] == 4096 # Check that the document cannot be redefined with different options class Log(Document): @@ -106,7 +106,7 @@ class TestInstance(MongoDBTestCase): meta = {"max_documents": 11} # Accessing Document.objects creates the collection - with self.assertRaises(InvalidCollectionError): + with pytest.raises(InvalidCollectionError): Log.objects def test_capped_collection_default(self): @@ -122,9 +122,9 @@ class TestInstance(MongoDBTestCase): Log().save() options = Log.objects._collection.options() - self.assertEqual(options["capped"], True) - self.assertEqual(options["max"], 10) - self.assertEqual(options["size"], 10 * 2 ** 20) + assert options["capped"] == True + assert options["max"] == 10 + assert options["size"] == 10 * 2 ** 20 # Check that the document with default value can be recreated class Log(Document): @@ -150,8 +150,8 @@ class TestInstance(MongoDBTestCase): Log().save() options = Log.objects._collection.options() - self.assertEqual(options["capped"], True) - self.assertTrue(options["size"] >= 10000) + assert options["capped"] == True + assert options["size"] >= 10000 # Check that the document with odd max_size value can be recreated class Log(Document): @@ -173,7 +173,7 @@ class TestInstance(MongoDBTestCase): doc = Article(title=u"привет мир") - self.assertEqual("", repr(doc)) + assert "" == repr(doc) def test_repr_none(self): """Ensure None values are handled correctly.""" @@ -185,11 +185,11 @@ class TestInstance(MongoDBTestCase): return None doc = Article(title=u"привет мир") - self.assertEqual("", repr(doc)) + assert "" == repr(doc) def test_queryset_resurrects_dropped_collection(self): self.Person.drop_collection() - self.assertEqual([], list(self.Person.objects())) + assert [] == list(self.Person.objects()) # Ensure works correctly with inhertited classes class Actor(self.Person): @@ -197,7 +197,7 @@ class TestInstance(MongoDBTestCase): Actor.objects() self.Person.drop_collection() - self.assertEqual([], list(Actor.objects())) + assert [] == list(Actor.objects()) def test_polymorphic_references(self): """Ensure that the correct subclasses are returned from a query @@ -237,7 +237,7 @@ class TestInstance(MongoDBTestCase): zoo.reload() classes = [a.__class__ for a in Zoo.objects.first().animals] - self.assertEqual(classes, [Animal, Fish, Mammal, Dog, Human]) + assert classes == [Animal, Fish, Mammal, Dog, Human] Zoo.drop_collection() @@ -250,7 +250,7 @@ class TestInstance(MongoDBTestCase): zoo.reload() classes = [a.__class__ for a in Zoo.objects.first().animals] - self.assertEqual(classes, [Animal, Fish, Mammal, Dog, Human]) + assert classes == [Animal, Fish, Mammal, Dog, Human] def test_reference_inheritance(self): class Stats(Document): @@ -275,7 +275,7 @@ class TestInstance(MongoDBTestCase): cmp_stats = CompareStats(stats=list_stats) cmp_stats.save() - self.assertEqual(list_stats, CompareStats.objects.first().stats) + assert list_stats == CompareStats.objects.first().stats def test_db_field_load(self): """Ensure we load data correctly from the right db field.""" @@ -294,8 +294,8 @@ class TestInstance(MongoDBTestCase): Person(name="Fred").save() - self.assertEqual(Person.objects.get(name="Jack").rank, "Corporal") - self.assertEqual(Person.objects.get(name="Fred").rank, "Private") + assert Person.objects.get(name="Jack").rank == "Corporal" + assert Person.objects.get(name="Fred").rank == "Private" def test_db_embedded_doc_field_load(self): """Ensure we load embedded document data correctly.""" @@ -318,8 +318,8 @@ class TestInstance(MongoDBTestCase): Person(name="Jack", rank_=Rank(title="Corporal")).save() Person(name="Fred").save() - self.assertEqual(Person.objects.get(name="Jack").rank, "Corporal") - self.assertEqual(Person.objects.get(name="Fred").rank, "Private") + assert Person.objects.get(name="Jack").rank == "Corporal" + assert Person.objects.get(name="Fred").rank == "Private" def test_custom_id_field(self): """Ensure that documents may be created with custom primary keys.""" @@ -332,15 +332,15 @@ class TestInstance(MongoDBTestCase): User.drop_collection() - self.assertEqual(User._fields["username"].db_field, "_id") - self.assertEqual(User._meta["id_field"], "username") + assert User._fields["username"].db_field == "_id" + assert User._meta["id_field"] == "username" User.objects.create(username="test", name="test user") user = User.objects.first() - self.assertEqual(user.id, "test") - self.assertEqual(user.pk, "test") + assert user.id == "test" + assert user.pk == "test" user_dict = User.objects._collection.find_one() - self.assertEqual(user_dict["_id"], "test") + assert user_dict["_id"] == "test" def test_change_custom_id_field_in_subclass(self): """Subclasses cannot override which field is the primary key.""" @@ -350,13 +350,13 @@ class TestInstance(MongoDBTestCase): name = StringField() meta = {"allow_inheritance": True} - with self.assertRaises(ValueError) as e: + with pytest.raises(ValueError) as e: class EmailUser(User): email = StringField(primary_key=True) exc = e.exception - self.assertEqual(str(exc), "Cannot override primary key field") + assert str(exc) == "Cannot override primary key field" def test_custom_id_field_is_required(self): """Ensure the custom primary key field is required.""" @@ -365,10 +365,10 @@ class TestInstance(MongoDBTestCase): username = StringField(primary_key=True) name = StringField() - with self.assertRaises(ValidationError) as e: + with pytest.raises(ValidationError) as e: User(name="test").save() exc = e.exception - self.assertTrue("Field is required: ['username']" in str(exc)) + assert "Field is required: ['username']" in str(exc) def test_document_not_registered(self): class Place(Document): @@ -388,7 +388,7 @@ class TestInstance(MongoDBTestCase): # and the NicePlace model not being imported in at query time. del _document_registry["Place.NicePlace"] - with self.assertRaises(NotRegistered): + with pytest.raises(NotRegistered): list(Place.objects.all()) def test_document_registry_regressions(self): @@ -401,26 +401,27 @@ class TestInstance(MongoDBTestCase): Location.drop_collection() - self.assertEqual(Area, get_document("Area")) - self.assertEqual(Area, get_document("Location.Area")) + assert Area == get_document("Area") + assert Area == get_document("Location.Area") def test_creation(self): """Ensure that document may be created using keyword arguments.""" person = self.Person(name="Test User", age=30) - self.assertEqual(person.name, "Test User") - self.assertEqual(person.age, 30) + assert person.name == "Test User" + assert person.age == 30 def test_to_dbref(self): """Ensure that you can get a dbref of a document.""" person = self.Person(name="Test User", age=30) - self.assertRaises(OperationError, person.to_dbref) + with pytest.raises(OperationError): + person.to_dbref() person.save() person.to_dbref() def test_key_like_attribute_access(self): person = self.Person(age=30) - self.assertEqual(person["age"], 30) - with self.assertRaises(KeyError): + assert person["age"] == 30 + with pytest.raises(KeyError): person["unknown_attr"] def test_save_abstract_document(self): @@ -430,7 +431,7 @@ class TestInstance(MongoDBTestCase): name = StringField() meta = {"abstract": True} - with self.assertRaises(InvalidDocumentError): + with pytest.raises(InvalidDocumentError): Doc(name="aaa").save() def test_reload(self): @@ -443,20 +444,20 @@ class TestInstance(MongoDBTestCase): person_obj.age = 21 person_obj.save() - self.assertEqual(person.name, "Test User") - self.assertEqual(person.age, 20) + assert person.name == "Test User" + assert person.age == 20 person.reload("age") - self.assertEqual(person.name, "Test User") - self.assertEqual(person.age, 21) + assert person.name == "Test User" + assert person.age == 21 person.reload() - self.assertEqual(person.name, "Mr Test User") - self.assertEqual(person.age, 21) + assert person.name == "Mr Test User" + assert person.age == 21 person.reload() - self.assertEqual(person.name, "Mr Test User") - self.assertEqual(person.age, 21) + assert person.name == "Mr Test User" + assert person.age == 21 def test_reload_sharded(self): class Animal(Document): @@ -471,9 +472,10 @@ class TestInstance(MongoDBTestCase): with query_counter() as q: doc.reload() query_op = q.db.system.profile.find({"ns": "mongoenginetest.animal"})[0] - self.assertEqual( - set(query_op[CMD_QUERY_KEY]["filter"].keys()), {"_id", "superphylum"} - ) + assert set(query_op[CMD_QUERY_KEY]["filter"].keys()) == { + "_id", + "superphylum", + } def test_reload_sharded_with_db_field(self): class Person(Document): @@ -488,9 +490,7 @@ class TestInstance(MongoDBTestCase): with query_counter() as q: doc.reload() query_op = q.db.system.profile.find({"ns": "mongoenginetest.person"})[0] - self.assertEqual( - set(query_op[CMD_QUERY_KEY]["filter"].keys()), {"_id", "country"} - ) + assert set(query_op[CMD_QUERY_KEY]["filter"].keys()) == {"_id", "country"} def test_reload_sharded_nested(self): class SuperPhylum(EmbeddedDocument): @@ -526,15 +526,11 @@ class TestInstance(MongoDBTestCase): doc.name = "Cat" doc.save() query_op = q.db.system.profile.find({"ns": "mongoenginetest.animal"})[0] - self.assertEqual(query_op["op"], "update") + assert query_op["op"] == "update" if mongo_db <= MONGODB_34: - self.assertEqual( - set(query_op["query"].keys()), set(["_id", "is_mammal"]) - ) + assert set(query_op["query"].keys()) == set(["_id", "is_mammal"]) else: - self.assertEqual( - set(query_op["command"]["q"].keys()), set(["_id", "is_mammal"]) - ) + assert set(query_op["command"]["q"].keys()) == set(["_id", "is_mammal"]) Animal.drop_collection() @@ -551,12 +547,12 @@ class TestInstance(MongoDBTestCase): user.name = "John" user.number = 2 - self.assertEqual(user._get_changed_fields(), ["name", "number"]) + assert user._get_changed_fields() == ["name", "number"] user.reload("number") - self.assertEqual(user._get_changed_fields(), ["name"]) + assert user._get_changed_fields() == ["name"] user.save() user.reload() - self.assertEqual(user.name, "John") + assert user.name == "John" def test_reload_referencing(self): """Ensures reloading updates weakrefs correctly.""" @@ -587,47 +583,44 @@ class TestInstance(MongoDBTestCase): doc.embedded_field.list_field.append(1) doc.embedded_field.dict_field["woot"] = "woot" - self.assertEqual( - doc._get_changed_fields(), - [ - "list_field", - "dict_field.woot", - "embedded_field.list_field", - "embedded_field.dict_field.woot", - ], - ) + assert doc._get_changed_fields() == [ + "list_field", + "dict_field.woot", + "embedded_field.list_field", + "embedded_field.dict_field.woot", + ] doc.save() - self.assertEqual(len(doc.list_field), 4) + assert len(doc.list_field) == 4 doc = doc.reload(10) - self.assertEqual(doc._get_changed_fields(), []) - self.assertEqual(len(doc.list_field), 4) - self.assertEqual(len(doc.dict_field), 2) - self.assertEqual(len(doc.embedded_field.list_field), 4) - self.assertEqual(len(doc.embedded_field.dict_field), 2) + assert doc._get_changed_fields() == [] + assert len(doc.list_field) == 4 + assert len(doc.dict_field) == 2 + assert len(doc.embedded_field.list_field) == 4 + assert len(doc.embedded_field.dict_field) == 2 doc.list_field.append(1) doc.save() doc.dict_field["extra"] = 1 doc = doc.reload(10, "list_field") - self.assertEqual(doc._get_changed_fields(), ["dict_field.extra"]) - self.assertEqual(len(doc.list_field), 5) - self.assertEqual(len(doc.dict_field), 3) - self.assertEqual(len(doc.embedded_field.list_field), 4) - self.assertEqual(len(doc.embedded_field.dict_field), 2) + assert doc._get_changed_fields() == ["dict_field.extra"] + assert len(doc.list_field) == 5 + assert len(doc.dict_field) == 3 + assert len(doc.embedded_field.list_field) == 4 + assert len(doc.embedded_field.dict_field) == 2 def test_reload_doesnt_exist(self): class Foo(Document): pass f = Foo() - with self.assertRaises(Foo.DoesNotExist): + with pytest.raises(Foo.DoesNotExist): f.reload() f.save() f.delete() - with self.assertRaises(Foo.DoesNotExist): + with pytest.raises(Foo.DoesNotExist): f.reload() def test_reload_of_non_strict_with_special_field_name(self): @@ -646,27 +639,29 @@ class TestInstance(MongoDBTestCase): post = Post.objects.first() post.reload() - self.assertEqual(post.title, "Items eclipse") - self.assertEqual(post.items, ["more lorem", "even more ipsum"]) + assert post.title == "Items eclipse" + assert post.items == ["more lorem", "even more ipsum"] def test_dictionary_access(self): """Ensure that dictionary-style field access works properly.""" person = self.Person(name="Test User", age=30, job=self.Job()) - self.assertEqual(person["name"], "Test User") + assert person["name"] == "Test User" - self.assertRaises(KeyError, person.__getitem__, "salary") - self.assertRaises(KeyError, person.__setitem__, "salary", 50) + with pytest.raises(KeyError): + person.__getitem__("salary") + with pytest.raises(KeyError): + person.__setitem__("salary", 50) person["name"] = "Another User" - self.assertEqual(person["name"], "Another User") + assert person["name"] == "Another User" # Length = length(assigned fields + id) - self.assertEqual(len(person), 5) + assert len(person) == 5 - self.assertIn("age", person) + assert "age" in person person.age = None - self.assertNotIn("age", person) - self.assertNotIn("nationality", person) + assert "age" not in person + assert "nationality" not in person def test_embedded_document_to_mongo(self): class Person(EmbeddedDocument): @@ -678,20 +673,20 @@ class TestInstance(MongoDBTestCase): class Employee(Person): salary = IntField() - self.assertEqual( - Person(name="Bob", age=35).to_mongo().keys(), ["_cls", "name", "age"] - ) - self.assertEqual( - Employee(name="Bob", age=35, salary=0).to_mongo().keys(), - ["_cls", "name", "age", "salary"], - ) + assert Person(name="Bob", age=35).to_mongo().keys() == ["_cls", "name", "age"] + assert Employee(name="Bob", age=35, salary=0).to_mongo().keys() == [ + "_cls", + "name", + "age", + "salary", + ] def test_embedded_document_to_mongo_id(self): class SubDoc(EmbeddedDocument): id = StringField(required=True) sub_doc = SubDoc(id="abc") - self.assertEqual(sub_doc.to_mongo().keys(), ["id"]) + assert sub_doc.to_mongo().keys() == ["id"] def test_embedded_document(self): """Ensure that embedded documents are set up correctly.""" @@ -699,8 +694,8 @@ class TestInstance(MongoDBTestCase): class Comment(EmbeddedDocument): content = StringField() - self.assertIn("content", Comment._fields) - self.assertNotIn("id", Comment._fields) + assert "content" in Comment._fields + assert "id" not in Comment._fields def test_embedded_document_instance(self): """Ensure that embedded documents can reference parent instance.""" @@ -753,7 +748,7 @@ class TestInstance(MongoDBTestCase): .to_mongo(use_db_field=False) .to_dict() ) - self.assertEqual(d["embedded_field"], [{"string": "Hi"}]) + assert d["embedded_field"] == [{"string": "Hi"}] def test_instance_is_set_on_setattr(self): class Email(EmbeddedDocument): @@ -796,7 +791,7 @@ class TestInstance(MongoDBTestCase): def clean(self): raise CustomError() - with self.assertRaises(CustomError): + with pytest.raises(CustomError): TestDocument().save() TestDocument().save(clean=False) @@ -816,10 +811,10 @@ class TestInstance(MongoDBTestCase): BlogPost.drop_collection() post = BlogPost(content="unchecked").save() - self.assertEqual(post.content, "checked") + assert post.content == "checked" # Make sure pre_save_post_validation changes makes it to the db raw_doc = get_as_pymongo(post) - self.assertEqual(raw_doc, {"content": "checked", "_id": post.id}) + assert raw_doc == {"content": "checked", "_id": post.id} # Important to disconnect as it could cause some assertions in test_signals # to fail (due to the garbage collection timing of this signal) @@ -840,17 +835,17 @@ class TestInstance(MongoDBTestCase): # Ensure clean=False prevent call to clean t = TestDocument(status="published") t.save(clean=False) - self.assertEqual(t.status, "published") - self.assertEqual(t.cleaned, False) + assert t.status == "published" + assert t.cleaned == False t = TestDocument(status="published") - self.assertEqual(t.cleaned, False) + assert t.cleaned == False t.save(clean=True) - self.assertEqual(t.status, "published") - self.assertEqual(t.cleaned, True) + assert t.status == "published" + assert t.cleaned == True raw_doc = get_as_pymongo(t) # Make sure clean changes makes it to the db - self.assertEqual(raw_doc, {"status": "published", "cleaned": True, "_id": t.id}) + assert raw_doc == {"status": "published", "cleaned": True, "_id": t.id} def test_document_embedded_clean(self): class TestEmbeddedDocument(EmbeddedDocument): @@ -875,15 +870,15 @@ class TestInstance(MongoDBTestCase): t = TestDocument(doc=TestEmbeddedDocument(x=10, y=25, z=15)) - with self.assertRaises(ValidationError) as cm: + with pytest.raises(ValidationError) as cm: t.save() expected_msg = "Value of z != x + y" - self.assertIn(expected_msg, cm.exception.message) - self.assertEqual(cm.exception.to_dict(), {"doc": {"__all__": expected_msg}}) + assert expected_msg in cm.exception.message + assert cm.exception.to_dict() == {"doc": {"__all__": expected_msg}} t = TestDocument(doc=TestEmbeddedDocument(x=10, y=25)).save() - self.assertEqual(t.doc.z, 35) + assert t.doc.z == 35 # Asserts not raises t = TestDocument(doc=TestEmbeddedDocument(x=15, y=35, z=5)) @@ -892,7 +887,7 @@ class TestInstance(MongoDBTestCase): def test_modify_empty(self): doc = self.Person(name="bob", age=10).save() - with self.assertRaises(InvalidDocumentError): + with pytest.raises(InvalidDocumentError): self.Person().modify(set__age=10) self.assertDbEqual([dict(doc.to_mongo())]) @@ -902,7 +897,7 @@ class TestInstance(MongoDBTestCase): doc2 = self.Person(name="jim", age=20).save() docs = [dict(doc1.to_mongo()), dict(doc2.to_mongo())] - with self.assertRaises(InvalidQueryError): + with pytest.raises(InvalidQueryError): doc1.modify({"id": doc2.id}, set__value=20) self.assertDbEqual(docs) @@ -913,7 +908,7 @@ class TestInstance(MongoDBTestCase): docs = [dict(doc1.to_mongo()), dict(doc2.to_mongo())] n_modified = doc1.modify({"name": doc2.name}, set__age=100) - self.assertEqual(n_modified, 0) + assert n_modified == 0 self.assertDbEqual(docs) @@ -923,7 +918,7 @@ class TestInstance(MongoDBTestCase): docs = [dict(doc1.to_mongo())] n_modified = doc2.modify({"name": doc2.name}, set__age=100) - self.assertEqual(n_modified, 0) + assert n_modified == 0 self.assertDbEqual(docs) @@ -943,13 +938,13 @@ class TestInstance(MongoDBTestCase): n_modified = doc.modify( set__age=21, set__job__name="MongoDB", unset__job__years=True ) - self.assertEqual(n_modified, 1) + assert n_modified == 1 doc_copy.age = 21 doc_copy.job.name = "MongoDB" del doc_copy.job.years - self.assertEqual(doc.to_json(), doc_copy.to_json()) - self.assertEqual(doc._get_changed_fields(), []) + assert doc.to_json() == doc_copy.to_json() + assert doc._get_changed_fields() == [] self.assertDbEqual([dict(other_doc.to_mongo()), dict(doc.to_mongo())]) @@ -965,27 +960,25 @@ class TestInstance(MongoDBTestCase): tags=["python"], content=Content(keywords=["ipsum"]) ) - self.assertEqual(post.tags, ["python"]) + assert post.tags == ["python"] post.modify(push__tags__0=["code", "mongo"]) - self.assertEqual(post.tags, ["code", "mongo", "python"]) + assert post.tags == ["code", "mongo", "python"] # Assert same order of the list items is maintained in the db - self.assertEqual( - BlogPost._get_collection().find_one({"_id": post.pk})["tags"], - ["code", "mongo", "python"], - ) + assert BlogPost._get_collection().find_one({"_id": post.pk})["tags"] == [ + "code", + "mongo", + "python", + ] - self.assertEqual(post.content.keywords, ["ipsum"]) + assert post.content.keywords == ["ipsum"] post.modify(push__content__keywords__0=["lorem"]) - self.assertEqual(post.content.keywords, ["lorem", "ipsum"]) + assert post.content.keywords == ["lorem", "ipsum"] # Assert same order of the list items is maintained in the db - self.assertEqual( - BlogPost._get_collection().find_one({"_id": post.pk})["content"][ - "keywords" - ], - ["lorem", "ipsum"], - ) + assert BlogPost._get_collection().find_one({"_id": post.pk})["content"][ + "keywords" + ] == ["lorem", "ipsum"] def test_save(self): """Ensure that a document may be saved in the database.""" @@ -996,28 +989,30 @@ class TestInstance(MongoDBTestCase): # Ensure that the object is in the database raw_doc = get_as_pymongo(person) - self.assertEqual( - raw_doc, - {"_cls": "Person", "name": "Test User", "age": 30, "_id": person.id}, - ) + assert raw_doc == { + "_cls": "Person", + "name": "Test User", + "age": 30, + "_id": person.id, + } def test_save_skip_validation(self): class Recipient(Document): email = EmailField(required=True) recipient = Recipient(email="not-an-email") - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): recipient.save() recipient.save(validate=False) raw_doc = get_as_pymongo(recipient) - self.assertEqual(raw_doc, {"email": "not-an-email", "_id": recipient.id}) + assert raw_doc == {"email": "not-an-email", "_id": recipient.id} def test_save_with_bad_id(self): class Clown(Document): id = IntField(primary_key=True) - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): Clown(id="not_an_int").save() def test_save_to_a_value_that_equates_to_false(self): @@ -1037,7 +1032,7 @@ class TestInstance(MongoDBTestCase): user.save() user.reload() - self.assertEqual(user.thing.count, 0) + assert user.thing.count == 0 def test_save_max_recursion_not_hit(self): class Person(Document): @@ -1085,7 +1080,7 @@ class TestInstance(MongoDBTestCase): b.name = "world" b.save() - self.assertEqual(b.picture, b.bar.picture, b.bar.bar.picture) + assert b.picture == b.bar.picture, b.bar.bar.picture def test_save_cascades(self): class Person(Document): @@ -1107,7 +1102,7 @@ class TestInstance(MongoDBTestCase): p.save(cascade=True) p1.reload() - self.assertEqual(p1.name, p.parent.name) + assert p1.name == p.parent.name def test_save_cascade_kwargs(self): class Person(Document): @@ -1127,7 +1122,7 @@ class TestInstance(MongoDBTestCase): p1.reload() p2.reload() - self.assertEqual(p1.name, p2.parent.name) + assert p1.name == p2.parent.name def test_save_cascade_meta_false(self): class Person(Document): @@ -1151,11 +1146,11 @@ class TestInstance(MongoDBTestCase): p.save() p1.reload() - self.assertNotEqual(p1.name, p.parent.name) + assert p1.name != p.parent.name p.save(cascade=True) p1.reload() - self.assertEqual(p1.name, p.parent.name) + assert p1.name == p.parent.name def test_save_cascade_meta_true(self): class Person(Document): @@ -1179,7 +1174,7 @@ class TestInstance(MongoDBTestCase): p.save() p1.reload() - self.assertNotEqual(p1.name, p.parent.name) + assert p1.name != p.parent.name def test_save_cascades_generically(self): class Person(Document): @@ -1200,11 +1195,11 @@ class TestInstance(MongoDBTestCase): p.save() p1.reload() - self.assertNotEqual(p1.name, p.parent.name) + assert p1.name != p.parent.name p.save(cascade=True) p1.reload() - self.assertEqual(p1.name, p.parent.name) + assert p1.name == p.parent.name def test_save_atomicity_condition(self): class Widget(Document): @@ -1226,64 +1221,61 @@ class TestInstance(MongoDBTestCase): # ignore save_condition on new record creation w1.save(save_condition={"save_id": UUID(42)}) w1.reload() - self.assertFalse(w1.toggle) - self.assertEqual(w1.save_id, UUID(1)) - self.assertEqual(w1.count, 0) + assert not w1.toggle + assert w1.save_id == UUID(1) + assert w1.count == 0 # mismatch in save_condition prevents save and raise exception flip(w1) - self.assertTrue(w1.toggle) - self.assertEqual(w1.count, 1) - self.assertRaises( - SaveConditionError, w1.save, save_condition={"save_id": UUID(42)} - ) + assert w1.toggle + assert w1.count == 1 + with pytest.raises(SaveConditionError): + w1.save(save_condition={"save_id": UUID(42)}) w1.reload() - self.assertFalse(w1.toggle) - self.assertEqual(w1.count, 0) + assert not w1.toggle + assert w1.count == 0 # matched save_condition allows save flip(w1) - self.assertTrue(w1.toggle) - self.assertEqual(w1.count, 1) + assert w1.toggle + assert w1.count == 1 w1.save(save_condition={"save_id": UUID(1)}) w1.reload() - self.assertTrue(w1.toggle) - self.assertEqual(w1.count, 1) + assert w1.toggle + assert w1.count == 1 # save_condition can be used to ensure atomic read & updates # i.e., prevent interleaved reads and writes from separate contexts w2 = Widget.objects.get() - self.assertEqual(w1, w2) + assert w1 == w2 old_id = w1.save_id flip(w1) w1.save_id = UUID(2) w1.save(save_condition={"save_id": old_id}) w1.reload() - self.assertFalse(w1.toggle) - self.assertEqual(w1.count, 2) + assert not w1.toggle + assert w1.count == 2 flip(w2) flip(w2) - self.assertRaises( - SaveConditionError, w2.save, save_condition={"save_id": old_id} - ) + with pytest.raises(SaveConditionError): + w2.save(save_condition={"save_id": old_id}) w2.reload() - self.assertFalse(w2.toggle) - self.assertEqual(w2.count, 2) + assert not w2.toggle + assert w2.count == 2 # save_condition uses mongoengine-style operator syntax flip(w1) w1.save(save_condition={"count__lt": w1.count}) w1.reload() - self.assertTrue(w1.toggle) - self.assertEqual(w1.count, 3) + assert w1.toggle + assert w1.count == 3 flip(w1) - self.assertRaises( - SaveConditionError, w1.save, save_condition={"count__gte": w1.count} - ) + with pytest.raises(SaveConditionError): + w1.save(save_condition={"count__gte": w1.count}) w1.reload() - self.assertTrue(w1.toggle) - self.assertEqual(w1.count, 3) + assert w1.toggle + assert w1.count == 3 def test_save_update_selectively(self): class WildBoy(Document): @@ -1303,8 +1295,8 @@ class TestInstance(MongoDBTestCase): boy2.save() fresh_boy = WildBoy.objects().first() - self.assertEqual(fresh_boy.age, 99) - self.assertEqual(fresh_boy.name, "Bob") + assert fresh_boy.age == 99 + assert fresh_boy.name == "Bob" def test_save_update_selectively_with_custom_pk(self): # Prevents regression of #2082 @@ -1326,8 +1318,8 @@ class TestInstance(MongoDBTestCase): boy2.save() fresh_boy = WildBoy.objects().first() - self.assertEqual(fresh_boy.age, 99) - self.assertEqual(fresh_boy.name, "Bob") + assert fresh_boy.age == 99 + assert fresh_boy.name == "Bob" def test_update(self): """Ensure that an existing document is updated instead of be @@ -1343,20 +1335,20 @@ class TestInstance(MongoDBTestCase): same_person.save() # Confirm only one object - self.assertEqual(self.Person.objects.count(), 1) + assert self.Person.objects.count() == 1 # reload person.reload() same_person.reload() # Confirm the same - self.assertEqual(person, same_person) - self.assertEqual(person.name, same_person.name) - self.assertEqual(person.age, same_person.age) + assert person == same_person + assert person.name == same_person.name + assert person.age == same_person.age # Confirm the saved values - self.assertEqual(person.name, "Test") - self.assertEqual(person.age, 30) + assert person.name == "Test" + assert person.age == 30 # Test only / exclude only updates included fields person = self.Person.objects.only("name").get() @@ -1364,8 +1356,8 @@ class TestInstance(MongoDBTestCase): person.save() person.reload() - self.assertEqual(person.name, "User") - self.assertEqual(person.age, 30) + assert person.name == "User" + assert person.age == 30 # test exclude only updates set fields person = self.Person.objects.exclude("name").get() @@ -1373,8 +1365,8 @@ class TestInstance(MongoDBTestCase): person.save() person.reload() - self.assertEqual(person.name, "User") - self.assertEqual(person.age, 21) + assert person.name == "User" + assert person.age == 21 # Test only / exclude can set non excluded / included fields person = self.Person.objects.only("name").get() @@ -1383,8 +1375,8 @@ class TestInstance(MongoDBTestCase): person.save() person.reload() - self.assertEqual(person.name, "Test") - self.assertEqual(person.age, 30) + assert person.name == "Test" + assert person.age == 30 # test exclude only updates set fields person = self.Person.objects.exclude("name").get() @@ -1393,8 +1385,8 @@ class TestInstance(MongoDBTestCase): person.save() person.reload() - self.assertEqual(person.name, "User") - self.assertEqual(person.age, 21) + assert person.name == "User" + assert person.age == 21 # Confirm does remove unrequired fields person = self.Person.objects.exclude("name").get() @@ -1402,8 +1394,8 @@ class TestInstance(MongoDBTestCase): person.save() person.reload() - self.assertEqual(person.name, "User") - self.assertEqual(person.age, None) + assert person.name == "User" + assert person.age == None person = self.Person.objects.get() person.name = None @@ -1411,20 +1403,20 @@ class TestInstance(MongoDBTestCase): person.save() person.reload() - self.assertEqual(person.name, None) - self.assertEqual(person.age, None) + assert person.name == None + assert person.age == None def test_update_rename_operator(self): """Test the $rename operator.""" coll = self.Person._get_collection() doc = self.Person(name="John").save() raw_doc = coll.find_one({"_id": doc.pk}) - self.assertEqual(set(raw_doc.keys()), set(["_id", "_cls", "name"])) + assert set(raw_doc.keys()) == set(["_id", "_cls", "name"]) doc.update(rename__name="first_name") raw_doc = coll.find_one({"_id": doc.pk}) - self.assertEqual(set(raw_doc.keys()), set(["_id", "_cls", "first_name"])) - self.assertEqual(raw_doc["first_name"], "John") + assert set(raw_doc.keys()) == set(["_id", "_cls", "first_name"]) + assert raw_doc["first_name"] == "John" def test_inserts_if_you_set_the_pk(self): p1 = self.Person(name="p1", id=bson.ObjectId()).save() @@ -1432,7 +1424,7 @@ class TestInstance(MongoDBTestCase): p2.id = bson.ObjectId() p2.save() - self.assertEqual(2, self.Person.objects.count()) + assert 2 == self.Person.objects.count() def test_can_save_if_not_included(self): class EmbeddedDoc(EmbeddedDocument): @@ -1480,13 +1472,13 @@ class TestInstance(MongoDBTestCase): my_doc.save() my_doc = Doc.objects.get(string_field="string") - self.assertEqual(my_doc.string_field, "string") - self.assertEqual(my_doc.int_field, 1) + assert my_doc.string_field == "string" + assert my_doc.int_field == 1 def test_document_update(self): # try updating a non-saved document - with self.assertRaises(OperationError): + with pytest.raises(OperationError): person = self.Person(name="dcrosta") person.update(set__name="Dan Crosta") @@ -1497,10 +1489,10 @@ class TestInstance(MongoDBTestCase): author.reload() p1 = self.Person.objects.first() - self.assertEqual(p1.name, author.name) + assert p1.name == author.name # try sending an empty update - with self.assertRaises(OperationError): + with pytest.raises(OperationError): person = self.Person.objects.first() person.update() @@ -1509,7 +1501,7 @@ class TestInstance(MongoDBTestCase): person = self.Person.objects.first() person.update(name="Dan") person.reload() - self.assertEqual("Dan", person.name) + assert "Dan" == person.name def test_update_unique_field(self): class Doc(Document): @@ -1518,7 +1510,7 @@ class TestInstance(MongoDBTestCase): doc1 = Doc(name="first").save() doc2 = Doc(name="second").save() - with self.assertRaises(NotUniqueError): + with pytest.raises(NotUniqueError): doc2.update(set__name=doc1.name) def test_embedded_update(self): @@ -1540,7 +1532,7 @@ class TestInstance(MongoDBTestCase): site.save() site = Site.objects.first() - self.assertEqual(site.page.log_message, "Error: Dummy message") + assert site.page.log_message == "Error: Dummy message" def test_update_list_field(self): """Test update on `ListField` with $pull + $in. @@ -1558,7 +1550,7 @@ class TestInstance(MongoDBTestCase): doc.update(pull__foo__in=["a", "c"]) doc = Doc.objects.first() - self.assertEqual(doc.foo, ["b"]) + assert doc.foo == ["b"] def test_embedded_update_db_field(self): """Test update on `EmbeddedDocumentField` fields when db_field @@ -1584,7 +1576,7 @@ class TestInstance(MongoDBTestCase): site.save() site = Site.objects.first() - self.assertEqual(site.page.log_message, "Error: Dummy message") + assert site.page.log_message == "Error: Dummy message" def test_save_only_changed_fields(self): """Ensure save only sets / unsets changed fields.""" @@ -1610,9 +1602,9 @@ class TestInstance(MongoDBTestCase): same_person.save() person = self.Person.objects.get() - self.assertEqual(person.name, "User") - self.assertEqual(person.age, 21) - self.assertEqual(person.active, False) + assert person.name == "User" + assert person.age == 21 + assert person.active == False def test__get_changed_fields_same_ids_reference_field_does_not_enters_infinite_loop_embedded_doc( self, @@ -1626,7 +1618,7 @@ class TestInstance(MongoDBTestCase): emb = EmbeddedChildModel(id={"1": [1]}) changed_fields = ParentModel(child=emb)._get_changed_fields() - self.assertEqual(changed_fields, []) + assert changed_fields == [] def test__get_changed_fields_same_ids_reference_field_does_not_enters_infinite_loop_different_doc( self, @@ -1647,8 +1639,8 @@ class TestInstance(MongoDBTestCase): message = Message(id=1, author=user).save() message.author.name = "tutu" - self.assertEqual(message._get_changed_fields(), []) - self.assertEqual(user._get_changed_fields(), ["name"]) + assert message._get_changed_fields() == [] + assert user._get_changed_fields() == ["name"] def test__get_changed_fields_same_ids_embedded(self): # Refers to Issue #1768 @@ -1667,11 +1659,11 @@ class TestInstance(MongoDBTestCase): message = Message(id=1, author=user).save() message.author.name = "tutu" - self.assertEqual(message._get_changed_fields(), ["author.name"]) + assert message._get_changed_fields() == ["author.name"] message.save() message_fetched = Message.objects.with_id(message.id) - self.assertEqual(message_fetched.author.name, "tutu") + assert message_fetched.author.name == "tutu" def test_query_count_when_saving(self): """Ensure references don't cause extra fetches when saving""" @@ -1707,65 +1699,65 @@ class TestInstance(MongoDBTestCase): user = User.objects.first() # Even if stored as ObjectId's internally mongoengine uses DBRefs # As ObjectId's aren't automatically derefenced - self.assertIsInstance(user._data["orgs"][0], DBRef) - self.assertIsInstance(user.orgs[0], Organization) - self.assertIsInstance(user._data["orgs"][0], Organization) + assert isinstance(user._data["orgs"][0], DBRef) + assert isinstance(user.orgs[0], Organization) + assert isinstance(user._data["orgs"][0], Organization) # Changing a value with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 sub = UserSubscription.objects.first() - self.assertEqual(q, 1) + assert q == 1 sub.name = "Test Sub" sub.save() - self.assertEqual(q, 2) + assert q == 2 # Changing a value that will cascade with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 sub = UserSubscription.objects.first() - self.assertEqual(q, 1) + assert q == 1 sub.user.name = "Test" - self.assertEqual(q, 2) + assert q == 2 sub.save(cascade=True) - self.assertEqual(q, 3) + assert q == 3 # Changing a value and one that will cascade with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 sub = UserSubscription.objects.first() sub.name = "Test Sub 2" - self.assertEqual(q, 1) + assert q == 1 sub.user.name = "Test 2" - self.assertEqual(q, 2) + assert q == 2 sub.save(cascade=True) - self.assertEqual(q, 4) # One for the UserSub and one for the User + assert q == 4 # One for the UserSub and one for the User # Saving with just the refs with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 sub = UserSubscription(user=u1.pk, feed=f1.pk) - self.assertEqual(q, 0) + assert q == 0 sub.save() - self.assertEqual(q, 1) + assert q == 1 # Saving with just the refs on a ListField with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 User(name="Bob", orgs=[o1.pk, o2.pk]).save() - self.assertEqual(q, 1) + assert q == 1 # Saving new objects with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 user = User.objects.first() - self.assertEqual(q, 1) + assert q == 1 feed = Feed.objects.first() - self.assertEqual(q, 2) + assert q == 2 sub = UserSubscription(user=user, feed=feed) - self.assertEqual(q, 2) # Check no change + assert q == 2 # Check no change sub.save() - self.assertEqual(q, 3) + assert q == 3 def test_set_unset_one_operation(self): """Ensure that $set and $unset actions are performed in the @@ -1781,14 +1773,14 @@ class TestInstance(MongoDBTestCase): # write an entity with a single prop foo = FooBar(foo="foo").save() - self.assertEqual(foo.foo, "foo") + assert foo.foo == "foo" del foo.foo foo.bar = "bar" with query_counter() as q: - self.assertEqual(0, q) + assert 0 == q foo.save() - self.assertEqual(1, q) + assert 1 == q def test_save_only_changed_fields_recursive(self): """Ensure save only sets / unsets changed fields.""" @@ -1810,34 +1802,34 @@ class TestInstance(MongoDBTestCase): person.reload() person = self.Person.objects.get() - self.assertTrue(person.comments[0].published) + assert person.comments[0].published person.comments[0].published = False person.save() person = self.Person.objects.get() - self.assertFalse(person.comments[0].published) + assert not person.comments[0].published # Simple dict w person.comments_dict["first_post"] = Comment() person.save() person = self.Person.objects.get() - self.assertTrue(person.comments_dict["first_post"].published) + assert person.comments_dict["first_post"].published person.comments_dict["first_post"].published = False person.save() person = self.Person.objects.get() - self.assertFalse(person.comments_dict["first_post"].published) + assert not person.comments_dict["first_post"].published def test_delete(self): """Ensure that document may be deleted using the delete method.""" person = self.Person(name="Test User", age=30) person.save() - self.assertEqual(self.Person.objects.count(), 1) + assert self.Person.objects.count() == 1 person.delete() - self.assertEqual(self.Person.objects.count(), 0) + assert self.Person.objects.count() == 0 def test_save_custom_id(self): """Ensure that a document may be saved with a custom _id.""" @@ -1849,7 +1841,7 @@ class TestInstance(MongoDBTestCase): # Ensure that the object is in the database with the correct _id collection = self.db[self.Person._get_collection_name()] person_obj = collection.find_one({"name": "Test User"}) - self.assertEqual(str(person_obj["_id"]), "497ce96f395f2f052a494fd4") + assert str(person_obj["_id"]) == "497ce96f395f2f052a494fd4" def test_save_custom_pk(self): """Ensure that a document may be saved with a custom _id using @@ -1862,7 +1854,7 @@ class TestInstance(MongoDBTestCase): # Ensure that the object is in the database with the correct _id collection = self.db[self.Person._get_collection_name()] person_obj = collection.find_one({"name": "Test User"}) - self.assertEqual(str(person_obj["_id"]), "497ce96f395f2f052a494fd4") + assert str(person_obj["_id"]) == "497ce96f395f2f052a494fd4" def test_save_list(self): """Ensure that a list field may be properly saved.""" @@ -1885,9 +1877,9 @@ class TestInstance(MongoDBTestCase): collection = self.db[BlogPost._get_collection_name()] post_obj = collection.find_one() - self.assertEqual(post_obj["tags"], tags) + assert post_obj["tags"] == tags for comment_obj, comment in zip(post_obj["comments"], comments): - self.assertEqual(comment_obj["content"], comment["content"]) + assert comment_obj["content"] == comment["content"] def test_list_search_by_embedded(self): class User(Document): @@ -1944,9 +1936,9 @@ class TestInstance(MongoDBTestCase): p4 = Page(comments=[Comment(user=u2, comment="Heavy Metal song")]) p4.save() - self.assertEqual([p1, p2], list(Page.objects.filter(comments__user=u1))) - self.assertEqual([p1, p2, p4], list(Page.objects.filter(comments__user=u2))) - self.assertEqual([p1, p3], list(Page.objects.filter(comments__user=u3))) + assert [p1, p2] == list(Page.objects.filter(comments__user=u1)) + assert [p1, p2, p4] == list(Page.objects.filter(comments__user=u2)) + assert [p1, p3] == list(Page.objects.filter(comments__user=u3)) def test_save_embedded_document(self): """Ensure that a document with an embedded document field may @@ -1968,11 +1960,11 @@ class TestInstance(MongoDBTestCase): # Ensure that the object is in the database collection = self.db[self.Person._get_collection_name()] employee_obj = collection.find_one({"name": "Test Employee"}) - self.assertEqual(employee_obj["name"], "Test Employee") - self.assertEqual(employee_obj["age"], 50) + assert employee_obj["name"] == "Test Employee" + assert employee_obj["age"] == 50 # Ensure that the 'details' embedded object saved correctly - self.assertEqual(employee_obj["details"]["position"], "Developer") + assert employee_obj["details"]["position"] == "Developer" def test_embedded_update_after_save(self): """Test update of `EmbeddedDocumentField` attached to a newly @@ -1994,7 +1986,7 @@ class TestInstance(MongoDBTestCase): site.save() site = Site.objects.first() - self.assertEqual(site.page.log_message, "Error: Dummy message") + assert site.page.log_message == "Error: Dummy message" def test_updating_an_embedded_document(self): """Ensure that a document with an embedded document field may @@ -2019,18 +2011,18 @@ class TestInstance(MongoDBTestCase): promoted_employee.save() promoted_employee.reload() - self.assertEqual(promoted_employee.name, "Test Employee") - self.assertEqual(promoted_employee.age, 50) + assert promoted_employee.name == "Test Employee" + assert promoted_employee.age == 50 # Ensure that the 'details' embedded object saved correctly - self.assertEqual(promoted_employee.details.position, "Senior Developer") + assert promoted_employee.details.position == "Senior Developer" # Test removal promoted_employee.details = None promoted_employee.save() promoted_employee.reload() - self.assertEqual(promoted_employee.details, None) + assert promoted_employee.details == None def test_object_mixins(self): class NameMixin(object): @@ -2039,12 +2031,12 @@ class TestInstance(MongoDBTestCase): class Foo(EmbeddedDocument, NameMixin): quantity = IntField() - self.assertEqual(["name", "quantity"], sorted(Foo._fields.keys())) + assert ["name", "quantity"] == sorted(Foo._fields.keys()) class Bar(Document, NameMixin): widgets = StringField() - self.assertEqual(["id", "name", "widgets"], sorted(Bar._fields.keys())) + assert ["id", "name", "widgets"] == sorted(Bar._fields.keys()) def test_mixin_inheritance(self): class BaseMixIn(object): @@ -2064,10 +2056,10 @@ class TestInstance(MongoDBTestCase): t = TestDoc.objects.first() - self.assertEqual(t.age, 19) - self.assertEqual(t.comment, "great!") - self.assertEqual(t.data, "test") - self.assertEqual(t.count, 12) + assert t.age == 19 + assert t.comment == "great!" + assert t.data == "test" + assert t.count == 12 def test_save_reference(self): """Ensure that a document reference field may be saved in the @@ -2092,22 +2084,22 @@ class TestInstance(MongoDBTestCase): post_obj = BlogPost.objects.first() # Test laziness - self.assertIsInstance(post_obj._data["author"], bson.DBRef) - self.assertIsInstance(post_obj.author, self.Person) - self.assertEqual(post_obj.author.name, "Test User") + assert isinstance(post_obj._data["author"], bson.DBRef) + assert isinstance(post_obj.author, self.Person) + assert post_obj.author.name == "Test User" # Ensure that the dereferenced object may be changed and saved post_obj.author.age = 25 post_obj.author.save() author = list(self.Person.objects(name="Test User"))[-1] - self.assertEqual(author.age, 25) + assert author.age == 25 def test_duplicate_db_fields_raise_invalid_document_error(self): """Ensure a InvalidDocumentError is thrown if duplicate fields declare the same db_field. """ - with self.assertRaises(InvalidDocumentError): + with pytest.raises(InvalidDocumentError): class Foo(Document): name = StringField() @@ -2125,7 +2117,7 @@ class TestInstance(MongoDBTestCase): forms = ListField(StringField(), default=list) occurs = ListField(EmbeddedDocumentField(Occurrence), default=list) - with self.assertRaises(InvalidDocumentError): + with pytest.raises(InvalidDocumentError): Word._from_son( { "stem": [1, 2, 3], @@ -2136,7 +2128,7 @@ class TestInstance(MongoDBTestCase): ) # Tests for issue #1438: https://github.com/MongoEngine/mongoengine/issues/1438 - with self.assertRaises(ValueError): + with pytest.raises(ValueError): Word._from_son("this is not a valid SON dict") def test_reverse_delete_rule_cascade_and_nullify(self): @@ -2165,12 +2157,12 @@ class TestInstance(MongoDBTestCase): reviewer.delete() # No effect on the BlogPost - self.assertEqual(BlogPost.objects.count(), 1) - self.assertEqual(BlogPost.objects.get().reviewer, None) + assert BlogPost.objects.count() == 1 + assert BlogPost.objects.get().reviewer == None # Delete the Person, which should lead to deletion of the BlogPost, too author.delete() - self.assertEqual(BlogPost.objects.count(), 0) + assert BlogPost.objects.count() == 0 def test_reverse_delete_rule_pull(self): """Ensure that a referenced document is also deleted with @@ -2189,7 +2181,7 @@ class TestInstance(MongoDBTestCase): parent_record.save() child_record.delete() - self.assertEqual(Record.objects(name="parent").get().children, []) + assert Record.objects(name="parent").get().children == [] def test_reverse_delete_rule_with_custom_id_field(self): """Ensure that a referenced document with custom primary key @@ -2211,11 +2203,11 @@ class TestInstance(MongoDBTestCase): book = Book(author=user, reviewer=reviewer).save() reviewer.delete() - self.assertEqual(Book.objects.count(), 1) - self.assertEqual(Book.objects.get().reviewer, None) + assert Book.objects.count() == 1 + assert Book.objects.get().reviewer == None user.delete() - self.assertEqual(Book.objects.count(), 0) + assert Book.objects.count() == 0 def test_reverse_delete_rule_with_shared_id_among_collections(self): """Ensure that cascade delete rule doesn't mix id among @@ -2239,16 +2231,16 @@ class TestInstance(MongoDBTestCase): user_2.delete() # Deleting user_2 should also delete book_1 but not book_2 - self.assertEqual(Book.objects.count(), 1) - self.assertEqual(Book.objects.get(), book_2) + assert Book.objects.count() == 1 + assert Book.objects.get() == book_2 user_3 = User(id=3).save() book_3 = Book(id=3, author=user_3).save() user_3.delete() # Deleting user_3 should also delete book_3 - self.assertEqual(Book.objects.count(), 1) - self.assertEqual(Book.objects.get(), book_2) + assert Book.objects.count() == 1 + assert Book.objects.get() == book_2 def test_reverse_delete_rule_with_document_inheritance(self): """Ensure that a referenced document is also deleted upon @@ -2278,12 +2270,12 @@ class TestInstance(MongoDBTestCase): post.save() reviewer.delete() - self.assertEqual(BlogPost.objects.count(), 1) - self.assertEqual(BlogPost.objects.get().reviewer, None) + assert BlogPost.objects.count() == 1 + assert BlogPost.objects.get().reviewer == None # Delete the Writer should lead to deletion of the BlogPost author.delete() - self.assertEqual(BlogPost.objects.count(), 0) + assert BlogPost.objects.count() == 0 def test_reverse_delete_rule_cascade_and_nullify_complex_field(self): """Ensure that a referenced document is also deleted upon @@ -2315,12 +2307,12 @@ class TestInstance(MongoDBTestCase): # Deleting the reviewer should have no effect on the BlogPost reviewer.delete() - self.assertEqual(BlogPost.objects.count(), 1) - self.assertEqual(BlogPost.objects.get().reviewers, []) + assert BlogPost.objects.count() == 1 + assert BlogPost.objects.get().reviewers == [] # Delete the Person, which should lead to deletion of the BlogPost, too author.delete() - self.assertEqual(BlogPost.objects.count(), 0) + assert BlogPost.objects.count() == 0 def test_reverse_delete_rule_cascade_triggers_pre_delete_signal(self): """Ensure the pre_delete signal is triggered upon a cascading @@ -2357,7 +2349,7 @@ class TestInstance(MongoDBTestCase): # the pre-delete signal should have decremented the editor's queue editor = Editor.objects(name="Max P.").get() - self.assertEqual(editor.review_queue, 0) + assert editor.review_queue == 0 def test_two_way_reverse_delete_rule(self): """Ensure that Bi-Directional relationships work with @@ -2389,11 +2381,11 @@ class TestInstance(MongoDBTestCase): f.delete() - self.assertEqual(Bar.objects.count(), 1) # No effect on the BlogPost - self.assertEqual(Bar.objects.get().foo, None) + assert Bar.objects.count() == 1 # No effect on the BlogPost + assert Bar.objects.get().foo == None def test_invalid_reverse_delete_rule_raise_errors(self): - with self.assertRaises(InvalidDocumentError): + with pytest.raises(InvalidDocumentError): class Blog(Document): content = StringField() @@ -2404,7 +2396,7 @@ class TestInstance(MongoDBTestCase): field=ReferenceField(self.Person, reverse_delete_rule=NULLIFY) ) - with self.assertRaises(InvalidDocumentError): + with pytest.raises(InvalidDocumentError): class Parents(EmbeddedDocument): father = ReferenceField("Person", reverse_delete_rule=DENY) @@ -2441,7 +2433,7 @@ class TestInstance(MongoDBTestCase): # Delete the Person, which should lead to deletion of the BlogPost, # and, recursively to the Comment, too author.delete() - self.assertEqual(Comment.objects.count(), 0) + assert Comment.objects.count() == 0 def test_reverse_delete_rule_deny(self): """Ensure that a document cannot be referenced if there are @@ -2463,19 +2455,18 @@ class TestInstance(MongoDBTestCase): post.save() # Delete the Person should be denied - self.assertRaises(OperationError, author.delete) # Should raise denied error - self.assertEqual( - BlogPost.objects.count(), 1 - ) # No objects may have been deleted - self.assertEqual(self.Person.objects.count(), 1) + with pytest.raises(OperationError): + author.delete() # Should raise denied error + assert BlogPost.objects.count() == 1 # No objects may have been deleted + assert self.Person.objects.count() == 1 # Other users, that don't have BlogPosts must be removable, like normal author = self.Person(name="Another User") author.save() - self.assertEqual(self.Person.objects.count(), 2) + assert self.Person.objects.count() == 2 author.delete() - self.assertEqual(self.Person.objects.count(), 1) + assert self.Person.objects.count() == 1 def subclasses_and_unique_keys_works(self): class A(Document): @@ -2491,8 +2482,8 @@ class TestInstance(MongoDBTestCase): A().save() B(foo=True).save() - self.assertEqual(A.objects.count(), 2) - self.assertEqual(B.objects.count(), 1) + assert A.objects.count() == 2 + assert B.objects.count() == 1 def test_document_hash(self): """Test document in list, dict, set.""" @@ -2518,12 +2509,12 @@ class TestInstance(MongoDBTestCase): # Make sure docs are properly identified in a list (__eq__ is used # for the comparison). all_user_list = list(User.objects.all()) - self.assertIn(u1, all_user_list) - self.assertIn(u2, all_user_list) - self.assertIn(u3, all_user_list) - self.assertNotIn(u4, all_user_list) # New object - self.assertNotIn(b1, all_user_list) # Other object - self.assertNotIn(b2, all_user_list) # Other object + assert u1 in all_user_list + assert u2 in all_user_list + assert u3 in all_user_list + assert u4 not in all_user_list # New object + assert b1 not in all_user_list # Other object + assert b2 not in all_user_list # Other object # Make sure docs can be used as keys in a dict (__hash__ is used # for hashing the docs). @@ -2531,27 +2522,27 @@ class TestInstance(MongoDBTestCase): for u in User.objects.all(): all_user_dic[u] = "OK" - self.assertEqual(all_user_dic.get(u1, False), "OK") - self.assertEqual(all_user_dic.get(u2, False), "OK") - self.assertEqual(all_user_dic.get(u3, False), "OK") - self.assertEqual(all_user_dic.get(u4, False), False) # New object - self.assertEqual(all_user_dic.get(b1, False), False) # Other object - self.assertEqual(all_user_dic.get(b2, False), False) # Other object + assert all_user_dic.get(u1, False) == "OK" + assert all_user_dic.get(u2, False) == "OK" + assert all_user_dic.get(u3, False) == "OK" + assert all_user_dic.get(u4, False) == False # New object + assert all_user_dic.get(b1, False) == False # Other object + assert all_user_dic.get(b2, False) == False # Other object # Make sure docs are properly identified in a set (__hash__ is used # for hashing the docs). all_user_set = set(User.objects.all()) - self.assertIn(u1, all_user_set) - self.assertNotIn(u4, all_user_set) - self.assertNotIn(b1, all_user_list) - self.assertNotIn(b2, all_user_list) + assert u1 in all_user_set + assert u4 not in all_user_set + assert b1 not in all_user_list + assert b2 not in all_user_list # Make sure duplicate docs aren't accepted in the set - self.assertEqual(len(all_user_set), 3) + assert len(all_user_set) == 3 all_user_set.add(u1) all_user_set.add(u2) all_user_set.add(u3) - self.assertEqual(len(all_user_set), 3) + assert len(all_user_set) == 3 def test_picklable(self): pickle_doc = PickleTest(number=1, string="One", lists=["1", "2"]) @@ -2564,21 +2555,21 @@ class TestInstance(MongoDBTestCase): pickled_doc = pickle.dumps(pickle_doc) resurrected = pickle.loads(pickled_doc) - self.assertEqual(resurrected, pickle_doc) + assert resurrected == pickle_doc # Test pickling changed data pickle_doc.lists.append("3") pickled_doc = pickle.dumps(pickle_doc) resurrected = pickle.loads(pickled_doc) - self.assertEqual(resurrected, pickle_doc) + assert resurrected == pickle_doc resurrected.string = "Two" resurrected.save() pickle_doc = PickleTest.objects.first() - self.assertEqual(resurrected, pickle_doc) - self.assertEqual(pickle_doc.string, "Two") - self.assertEqual(pickle_doc.lists, ["1", "2", "3"]) + assert resurrected == pickle_doc + assert pickle_doc.string == "Two" + assert pickle_doc.lists == ["1", "2", "3"] def test_regular_document_pickle(self): pickle_doc = PickleTest(number=1, string="One", lists=["1", "2"]) @@ -2594,11 +2585,12 @@ class TestInstance(MongoDBTestCase): fixtures.PickleTest = fixtures.NewDocumentPickleTest resurrected = pickle.loads(pickled_doc) - self.assertEqual(resurrected.__class__, fixtures.NewDocumentPickleTest) - self.assertEqual( - resurrected._fields_ordered, fixtures.NewDocumentPickleTest._fields_ordered + assert resurrected.__class__ == fixtures.NewDocumentPickleTest + assert ( + resurrected._fields_ordered + == fixtures.NewDocumentPickleTest._fields_ordered ) - self.assertNotEqual(resurrected._fields_ordered, pickle_doc._fields_ordered) + assert resurrected._fields_ordered != pickle_doc._fields_ordered # The local PickleTest is still a ref to the original fixtures.PickleTest = PickleTest @@ -2617,19 +2609,17 @@ class TestInstance(MongoDBTestCase): pickled_doc = pickle.dumps(pickle_doc) resurrected = pickle.loads(pickled_doc) - self.assertEqual(resurrected, pickle_doc) - self.assertEqual(resurrected._fields_ordered, pickle_doc._fields_ordered) - self.assertEqual( - resurrected._dynamic_fields.keys(), pickle_doc._dynamic_fields.keys() - ) + assert resurrected == pickle_doc + assert resurrected._fields_ordered == pickle_doc._fields_ordered + assert resurrected._dynamic_fields.keys() == pickle_doc._dynamic_fields.keys() - self.assertEqual(resurrected.embedded, pickle_doc.embedded) - self.assertEqual( - resurrected.embedded._fields_ordered, pickle_doc.embedded._fields_ordered + assert resurrected.embedded == pickle_doc.embedded + assert ( + resurrected.embedded._fields_ordered == pickle_doc.embedded._fields_ordered ) - self.assertEqual( - resurrected.embedded._dynamic_fields.keys(), - pickle_doc.embedded._dynamic_fields.keys(), + assert ( + resurrected.embedded._dynamic_fields.keys() + == pickle_doc.embedded._dynamic_fields.keys() ) def test_picklable_on_signals(self): @@ -2642,7 +2632,7 @@ class TestInstance(MongoDBTestCase): """Test creating a field with a field name that would override the "validate" method. """ - with self.assertRaises(InvalidDocumentError): + with pytest.raises(InvalidDocumentError): class Blog(Document): validate = DictField() @@ -2659,7 +2649,7 @@ class TestInstance(MongoDBTestCase): a = A() a.save() a.reload() - self.assertEqual(a.b.field1, "field1") + assert a.b.field1 == "field1" class C(EmbeddedDocument): c_field = StringField(default="cfield") @@ -2676,7 +2666,7 @@ class TestInstance(MongoDBTestCase): a.save() a.reload() - self.assertEqual(a.b.field2.c_field, "new value") + assert a.b.field2.c_field == "new value" def test_can_save_false_values(self): """Ensures you can save False values on save.""" @@ -2692,7 +2682,7 @@ class TestInstance(MongoDBTestCase): d.archived = False d.save() - self.assertEqual(Doc.objects(archived=False).count(), 1) + assert Doc.objects(archived=False).count() == 1 def test_can_save_false_values_dynamic(self): """Ensures you can save False values on dynamic docs.""" @@ -2707,7 +2697,7 @@ class TestInstance(MongoDBTestCase): d.archived = False d.save() - self.assertEqual(Doc.objects(archived=False).count(), 1) + assert Doc.objects(archived=False).count() == 1 def test_do_not_save_unchanged_references(self): """Ensures cascading saves dont auto update""" @@ -2768,8 +2758,8 @@ class TestInstance(MongoDBTestCase): hp = Book.objects.create(name="Harry Potter") # Selects - self.assertEqual(User.objects.first(), bob) - self.assertEqual(Book.objects.first(), hp) + assert User.objects.first() == bob + assert Book.objects.first() == hp # DeReference class AuthorBooks(Document): @@ -2783,27 +2773,23 @@ class TestInstance(MongoDBTestCase): ab = AuthorBooks.objects.create(author=bob, book=hp) # select - self.assertEqual(AuthorBooks.objects.first(), ab) - self.assertEqual(AuthorBooks.objects.first().book, hp) - self.assertEqual(AuthorBooks.objects.first().author, bob) - self.assertEqual(AuthorBooks.objects.filter(author=bob).first(), ab) - self.assertEqual(AuthorBooks.objects.filter(book=hp).first(), ab) + assert AuthorBooks.objects.first() == ab + assert AuthorBooks.objects.first().book == hp + assert AuthorBooks.objects.first().author == bob + assert AuthorBooks.objects.filter(author=bob).first() == ab + assert AuthorBooks.objects.filter(book=hp).first() == ab # DB Alias - self.assertEqual(User._get_db(), get_db("testdb-1")) - self.assertEqual(Book._get_db(), get_db("testdb-2")) - self.assertEqual(AuthorBooks._get_db(), get_db("testdb-3")) + assert User._get_db() == get_db("testdb-1") + assert Book._get_db() == get_db("testdb-2") + assert AuthorBooks._get_db() == get_db("testdb-3") # Collections - self.assertEqual( - User._get_collection(), get_db("testdb-1")[User._get_collection_name()] - ) - self.assertEqual( - Book._get_collection(), get_db("testdb-2")[Book._get_collection_name()] - ) - self.assertEqual( - AuthorBooks._get_collection(), - get_db("testdb-3")[AuthorBooks._get_collection_name()], + assert User._get_collection() == get_db("testdb-1")[User._get_collection_name()] + assert Book._get_collection() == get_db("testdb-2")[Book._get_collection_name()] + assert ( + AuthorBooks._get_collection() + == get_db("testdb-3")[AuthorBooks._get_collection_name()] ) def test_db_alias_overrides(self): @@ -2826,9 +2812,9 @@ class TestInstance(MongoDBTestCase): A.objects.all() - self.assertEqual("testdb-2", B._meta.get("db_alias")) - self.assertEqual("mongoenginetest", A._get_collection().database.name) - self.assertEqual("mongoenginetest2", B._get_collection().database.name) + assert "testdb-2" == B._meta.get("db_alias") + assert "mongoenginetest" == A._get_collection().database.name + assert "mongoenginetest2" == B._get_collection().database.name def test_db_alias_propagates(self): """db_alias propagates?""" @@ -2841,7 +2827,7 @@ class TestInstance(MongoDBTestCase): class B(A): pass - self.assertEqual("testdb-1", B._meta.get("db_alias")) + assert "testdb-1" == B._meta.get("db_alias") def test_db_ref_usage(self): """DB Ref usage in dict_fields.""" @@ -2898,11 +2884,9 @@ class TestInstance(MongoDBTestCase): Book.objects.create(name="9", author=jon, extra={"a": peter.to_dbref()}) # Checks - self.assertEqual( - ",".join([str(b) for b in Book.objects.all()]), "1,2,3,4,5,6,7,8,9" - ) + assert ",".join([str(b) for b in Book.objects.all()]) == "1,2,3,4,5,6,7,8,9" # bob related books - self.assertEqual( + assert ( ",".join( [ str(b) @@ -2910,12 +2894,12 @@ class TestInstance(MongoDBTestCase): Q(extra__a=bob) | Q(author=bob) | Q(extra__b=bob) ) ] - ), - "1,2,3,4", + ) + == "1,2,3,4" ) # Susan & Karl related books - self.assertEqual( + assert ( ",".join( [ str(b) @@ -2925,12 +2909,12 @@ class TestInstance(MongoDBTestCase): | Q(extra__b__all=[karl.to_dbref(), susan.to_dbref()]) ) ] - ), - "1", + ) + == "1" ) # $Where - self.assertEqual( + assert ( u",".join( [ str(b) @@ -2943,8 +2927,8 @@ class TestInstance(MongoDBTestCase): } ) ] - ), - "1,2", + ) + == "1,2" ) def test_switch_db_instance(self): @@ -2958,7 +2942,7 @@ class TestInstance(MongoDBTestCase): Group.drop_collection() Group(name="hello - default").save() - self.assertEqual(1, Group.objects.count()) + assert 1 == Group.objects.count() group = Group.objects.first() group.switch_db("testdb-1") @@ -2967,10 +2951,10 @@ class TestInstance(MongoDBTestCase): with switch_db(Group, "testdb-1") as Group: group = Group.objects.first() - self.assertEqual("hello - testdb!", group.name) + assert "hello - testdb!" == group.name group = Group.objects.first() - self.assertEqual("hello - default", group.name) + assert "hello - default" == group.name # Slightly contrived now - perform an update # Only works as they have the same object_id @@ -2979,12 +2963,12 @@ class TestInstance(MongoDBTestCase): with switch_db(Group, "testdb-1") as Group: group = Group.objects.first() - self.assertEqual("hello - update", group.name) + assert "hello - update" == group.name Group.drop_collection() - self.assertEqual(0, Group.objects.count()) + assert 0 == Group.objects.count() group = Group.objects.first() - self.assertEqual("hello - default", group.name) + assert "hello - default" == group.name # Totally contrived now - perform a delete # Only works as they have the same object_id @@ -2992,10 +2976,10 @@ class TestInstance(MongoDBTestCase): group.delete() with switch_db(Group, "testdb-1") as Group: - self.assertEqual(0, Group.objects.count()) + assert 0 == Group.objects.count() group = Group.objects.first() - self.assertEqual("hello - default", group.name) + assert "hello - default" == group.name def test_load_undefined_fields(self): class User(Document): @@ -3007,7 +2991,8 @@ class TestInstance(MongoDBTestCase): {"name": "John", "foo": "Bar", "data": [1, 2, 3]} ) - self.assertRaises(FieldDoesNotExist, User.objects.first) + with pytest.raises(FieldDoesNotExist): + User.objects.first() def test_load_undefined_fields_with_strict_false(self): class User(Document): @@ -3022,11 +3007,11 @@ class TestInstance(MongoDBTestCase): ) user = User.objects.first() - self.assertEqual(user.name, "John") - self.assertFalse(hasattr(user, "foo")) - self.assertEqual(user._data["foo"], "Bar") - self.assertFalse(hasattr(user, "data")) - self.assertEqual(user._data["data"], [1, 2, 3]) + assert user.name == "John" + assert not hasattr(user, "foo") + assert user._data["foo"] == "Bar" + assert not hasattr(user, "data") + assert user._data["data"] == [1, 2, 3] def test_load_undefined_fields_on_embedded_document(self): class Thing(EmbeddedDocument): @@ -3045,7 +3030,8 @@ class TestInstance(MongoDBTestCase): } ) - self.assertRaises(FieldDoesNotExist, User.objects.first) + with pytest.raises(FieldDoesNotExist): + User.objects.first() def test_load_undefined_fields_on_embedded_document_with_strict_false_on_doc(self): class Thing(EmbeddedDocument): @@ -3066,7 +3052,8 @@ class TestInstance(MongoDBTestCase): } ) - self.assertRaises(FieldDoesNotExist, User.objects.first) + with pytest.raises(FieldDoesNotExist): + User.objects.first() def test_load_undefined_fields_on_embedded_document_with_strict_false(self): class Thing(EmbeddedDocument): @@ -3088,12 +3075,12 @@ class TestInstance(MongoDBTestCase): ) user = User.objects.first() - self.assertEqual(user.name, "John") - self.assertEqual(user.thing.name, "My thing") - self.assertFalse(hasattr(user.thing, "foo")) - self.assertEqual(user.thing._data["foo"], "Bar") - self.assertFalse(hasattr(user.thing, "data")) - self.assertEqual(user.thing._data["data"], [1, 2, 3]) + assert user.name == "John" + assert user.thing.name == "My thing" + assert not hasattr(user.thing, "foo") + assert user.thing._data["foo"] == "Bar" + assert not hasattr(user.thing, "data") + assert user.thing._data["data"] == [1, 2, 3] def test_spaces_in_keys(self): class Embedded(DynamicEmbeddedDocument): @@ -3108,7 +3095,7 @@ class TestInstance(MongoDBTestCase): doc.save() one = Doc.objects.filter(**{"hello world": 1}).count() - self.assertEqual(1, one) + assert 1 == one def test_shard_key(self): class LogEntry(Document): @@ -3123,13 +3110,13 @@ class TestInstance(MongoDBTestCase): log.machine = "Localhost" log.save() - self.assertTrue(log.id is not None) + assert log.id is not None log.log = "Saving" log.save() # try to change the shard key - with self.assertRaises(OperationError): + with pytest.raises(OperationError): log.machine = "127.0.0.1" def test_shard_key_in_embedded_document(self): @@ -3145,13 +3132,13 @@ class TestInstance(MongoDBTestCase): bar_doc = Bar(foo=foo_doc, bar="world") bar_doc.save() - self.assertTrue(bar_doc.id is not None) + assert bar_doc.id is not None bar_doc.bar = "baz" bar_doc.save() # try to change the shard key - with self.assertRaises(OperationError): + with pytest.raises(OperationError): bar_doc.foo.foo = "something" bar_doc.save() @@ -3168,13 +3155,13 @@ class TestInstance(MongoDBTestCase): log.machine = "Localhost" log.save() - self.assertTrue(log.id is not None) + assert log.id is not None log.log = "Saving" log.save() # try to change the shard key - with self.assertRaises(OperationError): + with pytest.raises(OperationError): log.machine = "127.0.0.1" def test_kwargs_simple(self): @@ -3191,8 +3178,8 @@ class TestInstance(MongoDBTestCase): classic_doc = Doc(doc_name="my doc", doc=Embedded(name="embedded doc")) dict_doc = Doc(**{"doc_name": "my doc", "doc": {"name": "embedded doc"}}) - self.assertEqual(classic_doc, dict_doc) - self.assertEqual(classic_doc._data, dict_doc._data) + assert classic_doc == dict_doc + assert classic_doc._data == dict_doc._data def test_kwargs_complex(self): class Embedded(EmbeddedDocument): @@ -3216,48 +3203,48 @@ class TestInstance(MongoDBTestCase): } ) - self.assertEqual(classic_doc, dict_doc) - self.assertEqual(classic_doc._data, dict_doc._data) + assert classic_doc == dict_doc + assert classic_doc._data == dict_doc._data def test_positional_creation(self): """Document cannot be instantiated using positional arguments.""" - with self.assertRaises(TypeError) as e: + with pytest.raises(TypeError) as e: person = self.Person("Test User", 42) expected_msg = ( "Instantiating a document with positional arguments is not " "supported. Please use `field_name=value` keyword arguments." ) - self.assertEqual(str(e.exception), expected_msg) + assert str(e.exception) == expected_msg def test_mixed_creation(self): """Document cannot be instantiated using mixed arguments.""" - with self.assertRaises(TypeError) as e: + with pytest.raises(TypeError) as e: person = self.Person("Test User", age=42) expected_msg = ( "Instantiating a document with positional arguments is not " "supported. Please use `field_name=value` keyword arguments." ) - self.assertEqual(str(e.exception), expected_msg) + assert str(e.exception) == expected_msg def test_positional_creation_embedded(self): """Embedded document cannot be created using positional arguments.""" - with self.assertRaises(TypeError) as e: + with pytest.raises(TypeError) as e: job = self.Job("Test Job", 4) expected_msg = ( "Instantiating a document with positional arguments is not " "supported. Please use `field_name=value` keyword arguments." ) - self.assertEqual(str(e.exception), expected_msg) + assert str(e.exception) == expected_msg def test_mixed_creation_embedded(self): """Embedded document cannot be created using mixed arguments.""" - with self.assertRaises(TypeError) as e: + with pytest.raises(TypeError) as e: job = self.Job("Test Job", years=4) expected_msg = ( "Instantiating a document with positional arguments is not " "supported. Please use `field_name=value` keyword arguments." ) - self.assertEqual(str(e.exception), expected_msg) + assert str(e.exception) == expected_msg def test_data_contains_id_field(self): """Ensure that asking for _data returns 'id'.""" @@ -3269,8 +3256,8 @@ class TestInstance(MongoDBTestCase): Person(name="Harry Potter").save() person = Person.objects.first() - self.assertIn("id", person._data.keys()) - self.assertEqual(person._data.get("id"), person.id) + assert "id" in person._data.keys() + assert person._data.get("id") == person.id def test_complex_nesting_document_and_embedded_document(self): class Macro(EmbeddedDocument): @@ -3310,8 +3297,8 @@ class TestInstance(MongoDBTestCase): system.save() system = NodesSystem.objects.first() - self.assertEqual( - "UNDEFINED", system.nodes["node"].parameters["param"].macros["test"].value + assert ( + "UNDEFINED" == system.nodes["node"].parameters["param"].macros["test"].value ) def test_embedded_document_equality(self): @@ -3328,9 +3315,9 @@ class TestInstance(MongoDBTestCase): f1 = Embedded._from_son(e.to_mongo()) f2 = Embedded._from_son(e.to_mongo()) - self.assertEqual(f1, f2) + assert f1 == f2 f1.ref # Dereferences lazily - self.assertEqual(f1, f2) + assert f1 == f2 def test_dbref_equality(self): class Test2(Document): @@ -3361,36 +3348,36 @@ class TestInstance(MongoDBTestCase): dbref2 = f._data["test2"] obj2 = f.test2 - self.assertIsInstance(dbref2, DBRef) - self.assertIsInstance(obj2, Test2) - self.assertEqual(obj2.id, dbref2.id) - self.assertEqual(obj2, dbref2) - self.assertEqual(dbref2, obj2) + assert isinstance(dbref2, DBRef) + assert isinstance(obj2, Test2) + assert obj2.id == dbref2.id + assert obj2 == dbref2 + assert dbref2 == obj2 dbref3 = f._data["test3"] obj3 = f.test3 - self.assertIsInstance(dbref3, DBRef) - self.assertIsInstance(obj3, Test3) - self.assertEqual(obj3.id, dbref3.id) - self.assertEqual(obj3, dbref3) - self.assertEqual(dbref3, obj3) + assert isinstance(dbref3, DBRef) + assert isinstance(obj3, Test3) + assert obj3.id == dbref3.id + assert obj3 == dbref3 + assert dbref3 == obj3 - self.assertEqual(obj2.id, obj3.id) - self.assertEqual(dbref2.id, dbref3.id) - self.assertNotEqual(dbref2, dbref3) - self.assertNotEqual(dbref3, dbref2) - self.assertNotEqual(dbref2, dbref3) - self.assertNotEqual(dbref3, dbref2) + assert obj2.id == obj3.id + assert dbref2.id == dbref3.id + assert dbref2 != dbref3 + assert dbref3 != dbref2 + assert dbref2 != dbref3 + assert dbref3 != dbref2 - self.assertNotEqual(obj2, dbref3) - self.assertNotEqual(dbref3, obj2) - self.assertNotEqual(obj2, dbref3) - self.assertNotEqual(dbref3, obj2) + assert obj2 != dbref3 + assert dbref3 != obj2 + assert obj2 != dbref3 + assert dbref3 != obj2 - self.assertNotEqual(obj3, dbref2) - self.assertNotEqual(dbref2, obj3) - self.assertNotEqual(obj3, dbref2) - self.assertNotEqual(dbref2, obj3) + assert obj3 != dbref2 + assert dbref2 != obj3 + assert obj3 != dbref2 + assert dbref2 != obj3 def test_default_values(self): class Person(Document): @@ -3405,7 +3392,7 @@ class TestInstance(MongoDBTestCase): p2.name = "alon2" p2.save() p3 = Person.objects().only("created_on")[0] - self.assertEqual(orig_created_on, p3.created_on) + assert orig_created_on == p3.created_on class Person(Document): created_on = DateTimeField(default=lambda: datetime.utcnow()) @@ -3414,10 +3401,10 @@ class TestInstance(MongoDBTestCase): p4 = Person.objects()[0] p4.save() - self.assertEqual(p4.height, 189) + assert p4.height == 189 # However the default will not be fixed in DB - self.assertEqual(Person.objects(height=189).count(), 0) + assert Person.objects(height=189).count() == 0 # alter DB for the new default coll = Person._get_collection() @@ -3425,7 +3412,7 @@ class TestInstance(MongoDBTestCase): if "height" not in person: coll.update_one({"_id": person["_id"]}, {"$set": {"height": 189}}) - self.assertEqual(Person.objects(height=189).count(), 1) + assert Person.objects(height=189).count() == 1 def test_shard_key_mutability_after_from_json(self): """Ensure that a document ID can be modified after from_json. @@ -3445,11 +3432,11 @@ class TestInstance(MongoDBTestCase): meta = {"shard_key": ("id", "name")} p = Person.from_json('{"name": "name", "age": 27}', created=True) - self.assertEqual(p._created, True) + assert p._created == True p.name = "new name" p.id = "12345" - self.assertEqual(p.name, "new name") - self.assertEqual(p.id, "12345") + assert p.name == "new name" + assert p.id == "12345" def test_shard_key_mutability_after_from_son(self): """Ensure that a document ID can be modified after _from_son. @@ -3463,11 +3450,11 @@ class TestInstance(MongoDBTestCase): meta = {"shard_key": ("id", "name")} p = Person._from_son({"name": "name", "age": 27}, created=True) - self.assertEqual(p._created, True) + assert p._created == True p.name = "new name" p.id = "12345" - self.assertEqual(p.name, "new name") - self.assertEqual(p.id, "12345") + assert p.name == "new name" + assert p.id == "12345" def test_from_json_created_false_without_an_id(self): class Person(Document): @@ -3476,14 +3463,14 @@ class TestInstance(MongoDBTestCase): Person.objects.delete() p = Person.from_json('{"name": "name"}', created=False) - self.assertEqual(p._created, False) - self.assertEqual(p.id, None) + assert p._created == False + assert p.id == None # Make sure the document is subsequently persisted correctly. p.save() - self.assertTrue(p.id is not None) + assert p.id is not None saved_p = Person.objects.get(id=p.id) - self.assertEqual(saved_p.name, "name") + assert saved_p.name == "name" def test_from_json_created_false_with_an_id(self): """See https://github.com/mongoengine/mongoengine/issues/1854""" @@ -3496,13 +3483,13 @@ class TestInstance(MongoDBTestCase): p = Person.from_json( '{"_id": "5b85a8b04ec5dc2da388296e", "name": "name"}', created=False ) - self.assertEqual(p._created, False) - self.assertEqual(p._changed_fields, []) - self.assertEqual(p.name, "name") - self.assertEqual(p.id, ObjectId("5b85a8b04ec5dc2da388296e")) + assert p._created == False + assert p._changed_fields == [] + assert p.name == "name" + assert p.id == ObjectId("5b85a8b04ec5dc2da388296e") p.save() - with self.assertRaises(DoesNotExist): + with pytest.raises(DoesNotExist): # Since the object is considered as already persisted (thanks to # `created=False` and an existing ID), and we haven't changed any # fields (i.e. `_changed_fields` is empty), the document is @@ -3510,12 +3497,12 @@ class TestInstance(MongoDBTestCase): # nothing. Person.objects.get(id=p.id) - self.assertFalse(p._created) + assert not p._created p.name = "a new name" - self.assertEqual(p._changed_fields, ["name"]) + assert p._changed_fields == ["name"] p.save() saved_p = Person.objects.get(id=p.id) - self.assertEqual(saved_p.name, p.name) + assert saved_p.name == p.name def test_from_json_created_true_with_an_id(self): class Person(Document): @@ -3526,15 +3513,15 @@ class TestInstance(MongoDBTestCase): p = Person.from_json( '{"_id": "5b85a8b04ec5dc2da388296e", "name": "name"}', created=True ) - self.assertTrue(p._created) - self.assertEqual(p._changed_fields, []) - self.assertEqual(p.name, "name") - self.assertEqual(p.id, ObjectId("5b85a8b04ec5dc2da388296e")) + assert p._created + assert p._changed_fields == [] + assert p.name == "name" + assert p.id == ObjectId("5b85a8b04ec5dc2da388296e") p.save() saved_p = Person.objects.get(id=p.id) - self.assertEqual(saved_p, p) - self.assertEqual(saved_p.name, "name") + assert saved_p == p + assert saved_p.name == "name" def test_null_field(self): # 734 @@ -3553,13 +3540,13 @@ class TestInstance(MongoDBTestCase): u_from_db = User.objects.get(name="user") u_from_db.height = None u_from_db.save() - self.assertEqual(u_from_db.height, None) + assert u_from_db.height == None # 864 - self.assertEqual(u_from_db.str_fld, None) - self.assertEqual(u_from_db.int_fld, None) - self.assertEqual(u_from_db.flt_fld, None) - self.assertEqual(u_from_db.dt_fld, None) - self.assertEqual(u_from_db.cdt_fld, None) + assert u_from_db.str_fld == None + assert u_from_db.int_fld == None + assert u_from_db.flt_fld == None + assert u_from_db.dt_fld == None + assert u_from_db.cdt_fld == None # 735 User.objects.delete() @@ -3567,7 +3554,7 @@ class TestInstance(MongoDBTestCase): u.save() User.objects(name="user").update_one(set__height=None, upsert=True) u_from_db = User.objects.get(name="user") - self.assertEqual(u_from_db.height, None) + assert u_from_db.height == None def test_not_saved_eq(self): """Ensure we can compare documents not saved. @@ -3578,8 +3565,8 @@ class TestInstance(MongoDBTestCase): p = Person() p1 = Person() - self.assertNotEqual(p, p1) - self.assertEqual(p, p) + assert p != p1 + assert p == p def test_list_iter(self): # 914 @@ -3592,10 +3579,10 @@ class TestInstance(MongoDBTestCase): A.objects.delete() A(l=[B(v="1"), B(v="2"), B(v="3")]).save() a = A.objects.get() - self.assertEqual(a.l._instance, a) + assert a.l._instance == a for idx, b in enumerate(a.l): - self.assertEqual(b._instance, a) - self.assertEqual(idx, 2) + assert b._instance == a + assert idx == 2 def test_falsey_pk(self): """Ensure that we can create and update a document with Falsey PK.""" @@ -3625,7 +3612,7 @@ class TestInstance(MongoDBTestCase): blog.update(push__tags__0=["mongodb", "code"]) blog.reload() - self.assertEqual(blog.tags, ["mongodb", "code", "python"]) + assert blog.tags == ["mongodb", "code", "python"] def test_push_nested_list(self): """Ensure that push update works in nested list""" @@ -3637,7 +3624,7 @@ class TestInstance(MongoDBTestCase): blog = BlogPost(slug="test").save() blog.update(push__tags=["value1", 123]) blog.reload() - self.assertEqual(blog.tags, [["value1", 123]]) + assert blog.tags == [["value1", 123]] def test_accessing_objects_with_indexes_error(self): insert_result = self.db.company.insert_many( @@ -3653,7 +3640,7 @@ class TestInstance(MongoDBTestCase): company = ReferenceField(Company) # Ensure index creation exception aren't swallowed (#1688) - with self.assertRaises(DuplicateKeyError): + with pytest.raises(DuplicateKeyError): User.objects().select_related() @@ -3663,10 +3650,10 @@ class ObjectKeyTestCase(MongoDBTestCase): title = StringField() book = Book(title="Whatever") - self.assertEqual(book._object_key, {"pk": None}) + assert book._object_key == {"pk": None} book.pk = ObjectId() - self.assertEqual(book._object_key, {"pk": book.pk}) + assert book._object_key == {"pk": book.pk} def test_object_key_with_custom_primary_key(self): class Book(Document): @@ -3674,10 +3661,10 @@ class ObjectKeyTestCase(MongoDBTestCase): title = StringField() book = Book(title="Sapiens") - self.assertEqual(book._object_key, {"pk": None}) + assert book._object_key == {"pk": None} book = Book(pk="0062316117") - self.assertEqual(book._object_key, {"pk": "0062316117"}) + assert book._object_key == {"pk": "0062316117"} def test_object_key_in_a_sharded_collection(self): class Book(Document): @@ -3685,9 +3672,9 @@ class ObjectKeyTestCase(MongoDBTestCase): meta = {"shard_key": ("pk", "title")} book = Book() - self.assertEqual(book._object_key, {"pk": None, "title": None}) + assert book._object_key == {"pk": None, "title": None} book = Book(pk=ObjectId(), title="Sapiens") - self.assertEqual(book._object_key, {"pk": book.pk, "title": "Sapiens"}) + assert book._object_key == {"pk": book.pk, "title": "Sapiens"} def test_object_key_with_custom_db_field(self): class Book(Document): @@ -3695,7 +3682,7 @@ class ObjectKeyTestCase(MongoDBTestCase): meta = {"shard_key": ("pk", "author")} book = Book(pk=ObjectId(), author="Author") - self.assertEqual(book._object_key, {"pk": book.pk, "author": "Author"}) + assert book._object_key == {"pk": book.pk, "author": "Author"} def test_object_key_with_nested_shard_key(self): class Author(EmbeddedDocument): @@ -3706,7 +3693,7 @@ class ObjectKeyTestCase(MongoDBTestCase): meta = {"shard_key": ("pk", "author.name")} book = Book(pk=ObjectId(), author=Author(name="Author")) - self.assertEqual(book._object_key, {"pk": book.pk, "author__name": "Author"}) + assert book._object_key == {"pk": book.pk, "author__name": "Author"} if __name__ == "__main__": diff --git a/tests/document/test_json_serialisation.py b/tests/document/test_json_serialisation.py index 26a4a6c1..593d34f8 100644 --- a/tests/document/test_json_serialisation.py +++ b/tests/document/test_json_serialisation.py @@ -32,7 +32,7 @@ class TestJson(MongoDBTestCase): expected_json = """{"embedded":{"string":"Inner Hello"},"string":"Hello"}""" - self.assertEqual(doc_json, expected_json) + assert doc_json == expected_json def test_json_simple(self): class Embedded(EmbeddedDocument): @@ -52,9 +52,9 @@ class TestJson(MongoDBTestCase): doc_json = doc.to_json(sort_keys=True, separators=(",", ":")) expected_json = """{"embedded_field":{"string":"Hi"},"string":"Hi"}""" - self.assertEqual(doc_json, expected_json) + assert doc_json == expected_json - self.assertEqual(doc, Doc.from_json(doc.to_json())) + assert doc == Doc.from_json(doc.to_json()) def test_json_complex(self): class EmbeddedDoc(EmbeddedDocument): @@ -99,7 +99,7 @@ class TestJson(MongoDBTestCase): return json.loads(self.to_json()) == json.loads(other.to_json()) doc = Doc() - self.assertEqual(doc, Doc.from_json(doc.to_json())) + assert doc == Doc.from_json(doc.to_json()) if __name__ == "__main__": diff --git a/tests/document/test_validation.py b/tests/document/test_validation.py index 7449dd33..80601994 100644 --- a/tests/document/test_validation.py +++ b/tests/document/test_validation.py @@ -4,6 +4,7 @@ from datetime import datetime from mongoengine import * from tests.utils import MongoDBTestCase +import pytest class TestValidatorError(MongoDBTestCase): @@ -11,12 +12,12 @@ class TestValidatorError(MongoDBTestCase): """Ensure a ValidationError handles error to_dict correctly. """ error = ValidationError("root") - self.assertEqual(error.to_dict(), {}) + assert error.to_dict() == {} # 1st level error schema error.errors = {"1st": ValidationError("bad 1st")} - self.assertIn("1st", error.to_dict()) - self.assertEqual(error.to_dict()["1st"], "bad 1st") + assert "1st" in error.to_dict() + assert error.to_dict()["1st"] == "bad 1st" # 2nd level error schema error.errors = { @@ -24,10 +25,10 @@ class TestValidatorError(MongoDBTestCase): "bad 1st", errors={"2nd": ValidationError("bad 2nd")} ) } - self.assertIn("1st", error.to_dict()) - self.assertIsInstance(error.to_dict()["1st"], dict) - self.assertIn("2nd", error.to_dict()["1st"]) - self.assertEqual(error.to_dict()["1st"]["2nd"], "bad 2nd") + assert "1st" in error.to_dict() + assert isinstance(error.to_dict()["1st"], dict) + assert "2nd" in error.to_dict()["1st"] + assert error.to_dict()["1st"]["2nd"] == "bad 2nd" # moar levels error.errors = { @@ -45,13 +46,13 @@ class TestValidatorError(MongoDBTestCase): }, ) } - self.assertIn("1st", error.to_dict()) - self.assertIn("2nd", error.to_dict()["1st"]) - self.assertIn("3rd", error.to_dict()["1st"]["2nd"]) - self.assertIn("4th", error.to_dict()["1st"]["2nd"]["3rd"]) - self.assertEqual(error.to_dict()["1st"]["2nd"]["3rd"]["4th"], "Inception") + assert "1st" in error.to_dict() + assert "2nd" in error.to_dict()["1st"] + assert "3rd" in error.to_dict()["1st"]["2nd"] + assert "4th" in error.to_dict()["1st"]["2nd"]["3rd"] + assert error.to_dict()["1st"]["2nd"]["3rd"]["4th"] == "Inception" - self.assertEqual(error.message, "root(2nd.3rd.4th.Inception: ['1st'])") + assert error.message == "root(2nd.3rd.4th.Inception: ['1st'])" def test_model_validation(self): class User(Document): @@ -61,19 +62,19 @@ class TestValidatorError(MongoDBTestCase): try: User().validate() except ValidationError as e: - self.assertIn("User:None", e.message) - self.assertEqual( - e.to_dict(), - {"username": "Field is required", "name": "Field is required"}, - ) + assert "User:None" in e.message + assert e.to_dict() == { + "username": "Field is required", + "name": "Field is required", + } user = User(username="RossC0", name="Ross").save() user.name = None try: user.save() except ValidationError as e: - self.assertIn("User:RossC0", e.message) - self.assertEqual(e.to_dict(), {"name": "Field is required"}) + assert "User:RossC0" in e.message + assert e.to_dict() == {"name": "Field is required"} def test_fields_rewrite(self): class BasePerson(Document): @@ -85,7 +86,8 @@ class TestValidatorError(MongoDBTestCase): name = StringField(required=True) p = Person(age=15) - self.assertRaises(ValidationError, p.validate) + with pytest.raises(ValidationError): + p.validate() def test_embedded_document_validation(self): """Ensure that embedded documents may be validated. @@ -96,17 +98,19 @@ class TestValidatorError(MongoDBTestCase): content = StringField(required=True) comment = Comment() - self.assertRaises(ValidationError, comment.validate) + with pytest.raises(ValidationError): + comment.validate() comment.content = "test" comment.validate() comment.date = 4 - self.assertRaises(ValidationError, comment.validate) + with pytest.raises(ValidationError): + comment.validate() comment.date = datetime.now() comment.validate() - self.assertEqual(comment._instance, None) + assert comment._instance == None def test_embedded_db_field_validate(self): class SubDoc(EmbeddedDocument): @@ -119,10 +123,8 @@ class TestValidatorError(MongoDBTestCase): try: Doc(id="bad").validate() except ValidationError as e: - self.assertIn("SubDoc:None", e.message) - self.assertEqual( - e.to_dict(), {"e": {"val": "OK could not be converted to int"}} - ) + assert "SubDoc:None" in e.message + assert e.to_dict() == {"e": {"val": "OK could not be converted to int"}} Doc.drop_collection() @@ -130,18 +132,16 @@ class TestValidatorError(MongoDBTestCase): doc = Doc.objects.first() keys = doc._data.keys() - self.assertEqual(2, len(keys)) - self.assertIn("e", keys) - self.assertIn("id", keys) + assert 2 == len(keys) + assert "e" in keys + assert "id" in keys doc.e.val = "OK" try: doc.save() except ValidationError as e: - self.assertIn("Doc:test", e.message) - self.assertEqual( - e.to_dict(), {"e": {"val": "OK could not be converted to int"}} - ) + assert "Doc:test" in e.message + assert e.to_dict() == {"e": {"val": "OK could not be converted to int"}} def test_embedded_weakref(self): class SubDoc(EmbeddedDocument): @@ -157,14 +157,16 @@ class TestValidatorError(MongoDBTestCase): s = SubDoc() - self.assertRaises(ValidationError, s.validate) + with pytest.raises(ValidationError): + s.validate() d1.e = s d2.e = s del d1 - self.assertRaises(ValidationError, d2.validate) + with pytest.raises(ValidationError): + d2.validate() def test_parent_reference_in_child_document(self): """ diff --git a/tests/fields/test_binary_field.py b/tests/fields/test_binary_field.py index 719df922..86ee2654 100644 --- a/tests/fields/test_binary_field.py +++ b/tests/fields/test_binary_field.py @@ -7,6 +7,7 @@ import six from mongoengine import * from tests.utils import MongoDBTestCase +import pytest BIN_VALUE = six.b( "\xa9\xf3\x8d(\xd7\x03\x84\xb4k[\x0f\xe3\xa2\x19\x85p[J\xa3\xd2>\xde\xe6\x87\xb1\x7f\xc6\xe6\xd9r\x18\xf5" @@ -31,8 +32,8 @@ class TestBinaryField(MongoDBTestCase): attachment.save() attachment_1 = Attachment.objects().first() - self.assertEqual(MIME_TYPE, attachment_1.content_type) - self.assertEqual(BLOB, six.binary_type(attachment_1.blob)) + assert MIME_TYPE == attachment_1.content_type + assert BLOB == six.binary_type(attachment_1.blob) def test_validation_succeeds(self): """Ensure that valid values can be assigned to binary fields. @@ -45,13 +46,15 @@ class TestBinaryField(MongoDBTestCase): blob = BinaryField(max_bytes=4) attachment_required = AttachmentRequired() - self.assertRaises(ValidationError, attachment_required.validate) + with pytest.raises(ValidationError): + attachment_required.validate() attachment_required.blob = Binary(six.b("\xe6\x00\xc4\xff\x07")) attachment_required.validate() _5_BYTES = six.b("\xe6\x00\xc4\xff\x07") _4_BYTES = six.b("\xe6\x00\xc4\xff") - self.assertRaises(ValidationError, AttachmentSizeLimit(blob=_5_BYTES).validate) + with pytest.raises(ValidationError): + AttachmentSizeLimit(blob=_5_BYTES).validate() AttachmentSizeLimit(blob=_4_BYTES).validate() def test_validation_fails(self): @@ -61,7 +64,8 @@ class TestBinaryField(MongoDBTestCase): blob = BinaryField() for invalid_data in (2, u"Im_a_unicode", ["some_str"]): - self.assertRaises(ValidationError, Attachment(blob=invalid_data).validate) + with pytest.raises(ValidationError): + Attachment(blob=invalid_data).validate() def test__primary(self): class Attachment(Document): @@ -70,10 +74,10 @@ class TestBinaryField(MongoDBTestCase): Attachment.drop_collection() binary_id = uuid.uuid4().bytes att = Attachment(id=binary_id).save() - self.assertEqual(1, Attachment.objects.count()) - self.assertEqual(1, Attachment.objects.filter(id=att.id).count()) + assert 1 == Attachment.objects.count() + assert 1 == Attachment.objects.filter(id=att.id).count() att.delete() - self.assertEqual(0, Attachment.objects.count()) + assert 0 == Attachment.objects.count() def test_primary_filter_by_binary_pk_as_str(self): class Attachment(Document): @@ -82,9 +86,9 @@ class TestBinaryField(MongoDBTestCase): Attachment.drop_collection() binary_id = uuid.uuid4().bytes att = Attachment(id=binary_id).save() - self.assertEqual(1, Attachment.objects.filter(id=binary_id).count()) + assert 1 == Attachment.objects.filter(id=binary_id).count() att.delete() - self.assertEqual(0, Attachment.objects.count()) + assert 0 == Attachment.objects.count() def test_match_querying_with_bytes(self): class MyDocument(Document): @@ -94,7 +98,7 @@ class TestBinaryField(MongoDBTestCase): doc = MyDocument(bin_field=BIN_VALUE).save() matched_doc = MyDocument.objects(bin_field=BIN_VALUE).first() - self.assertEqual(matched_doc.id, doc.id) + assert matched_doc.id == doc.id def test_match_querying_with_binary(self): class MyDocument(Document): @@ -105,7 +109,7 @@ class TestBinaryField(MongoDBTestCase): doc = MyDocument(bin_field=BIN_VALUE).save() matched_doc = MyDocument.objects(bin_field=Binary(BIN_VALUE)).first() - self.assertEqual(matched_doc.id, doc.id) + assert matched_doc.id == doc.id def test_modify_operation__set(self): """Ensures no regression of bug #1127""" @@ -119,11 +123,11 @@ class TestBinaryField(MongoDBTestCase): doc = MyDocument.objects(some_field="test").modify( upsert=True, new=True, set__bin_field=BIN_VALUE ) - self.assertEqual(doc.some_field, "test") + assert doc.some_field == "test" if six.PY3: - self.assertEqual(doc.bin_field, BIN_VALUE) + assert doc.bin_field == BIN_VALUE else: - self.assertEqual(doc.bin_field, Binary(BIN_VALUE)) + assert doc.bin_field == Binary(BIN_VALUE) def test_update_one(self): """Ensures no regression of bug #1127""" @@ -139,9 +143,9 @@ class TestBinaryField(MongoDBTestCase): n_updated = MyDocument.objects(bin_field=bin_data).update_one( bin_field=BIN_VALUE ) - self.assertEqual(n_updated, 1) + assert n_updated == 1 fetched = MyDocument.objects.with_id(doc.id) if six.PY3: - self.assertEqual(fetched.bin_field, BIN_VALUE) + assert fetched.bin_field == BIN_VALUE else: - self.assertEqual(fetched.bin_field, Binary(BIN_VALUE)) + assert fetched.bin_field == Binary(BIN_VALUE) diff --git a/tests/fields/test_boolean_field.py b/tests/fields/test_boolean_field.py index 22ebb6f7..b38b5ea4 100644 --- a/tests/fields/test_boolean_field.py +++ b/tests/fields/test_boolean_field.py @@ -2,6 +2,7 @@ from mongoengine import * from tests.utils import MongoDBTestCase, get_as_pymongo +import pytest class TestBooleanField(MongoDBTestCase): @@ -11,7 +12,7 @@ class TestBooleanField(MongoDBTestCase): person = Person(admin=True) person.save() - self.assertEqual(get_as_pymongo(person), {"_id": person.id, "admin": True}) + assert get_as_pymongo(person) == {"_id": person.id, "admin": True} def test_validation(self): """Ensure that invalid values cannot be assigned to boolean @@ -26,11 +27,14 @@ class TestBooleanField(MongoDBTestCase): person.validate() person.admin = 2 - self.assertRaises(ValidationError, person.validate) + with pytest.raises(ValidationError): + person.validate() person.admin = "Yes" - self.assertRaises(ValidationError, person.validate) + with pytest.raises(ValidationError): + person.validate() person.admin = "False" - self.assertRaises(ValidationError, person.validate) + with pytest.raises(ValidationError): + person.validate() def test_weirdness_constructor(self): """When attribute is set in contructor, it gets cast into a bool @@ -42,7 +46,7 @@ class TestBooleanField(MongoDBTestCase): admin = BooleanField() new_person = Person(admin="False") - self.assertTrue(new_person.admin) + assert new_person.admin new_person = Person(admin="0") - self.assertTrue(new_person.admin) + assert new_person.admin diff --git a/tests/fields/test_cached_reference_field.py b/tests/fields/test_cached_reference_field.py index 4e467587..e404aae0 100644 --- a/tests/fields/test_cached_reference_field.py +++ b/tests/fields/test_cached_reference_field.py @@ -4,6 +4,7 @@ from decimal import Decimal from mongoengine import * from tests.utils import MongoDBTestCase +import pytest class TestCachedReferenceField(MongoDBTestCase): @@ -46,29 +47,29 @@ class TestCachedReferenceField(MongoDBTestCase): a = Animal(name="Leopard", tag="heavy") a.save() - self.assertEqual(Animal._cached_reference_fields, [Ocorrence.animal]) + assert Animal._cached_reference_fields == [Ocorrence.animal] o = Ocorrence(person="teste", animal=a) o.save() p = Ocorrence(person="Wilson") p.save() - self.assertEqual(Ocorrence.objects(animal=None).count(), 1) + assert Ocorrence.objects(animal=None).count() == 1 - self.assertEqual(a.to_mongo(fields=["tag"]), {"tag": "heavy", "_id": a.pk}) + assert a.to_mongo(fields=["tag"]) == {"tag": "heavy", "_id": a.pk} - self.assertEqual(o.to_mongo()["animal"]["tag"], "heavy") + assert o.to_mongo()["animal"]["tag"] == "heavy" # counts Ocorrence(person="teste 2").save() Ocorrence(person="teste 3").save() count = Ocorrence.objects(animal__tag="heavy").count() - self.assertEqual(count, 1) + assert count == 1 ocorrence = Ocorrence.objects(animal__tag="heavy").first() - self.assertEqual(ocorrence.person, "teste") - self.assertIsInstance(ocorrence.animal, Animal) + assert ocorrence.person == "teste" + assert isinstance(ocorrence.animal, Animal) def test_with_decimal(self): class PersonAuto(Document): @@ -88,10 +89,11 @@ class TestCachedReferenceField(MongoDBTestCase): s = SocialTest(group="dev", person=p) s.save() - self.assertEqual( - SocialTest.objects._collection.find_one({"person.salary": 7000.00}), - {"_id": s.pk, "group": s.group, "person": {"_id": p.pk, "salary": 7000.00}}, - ) + assert SocialTest.objects._collection.find_one({"person.salary": 7000.00}) == { + "_id": s.pk, + "group": s.group, + "person": {"_id": p.pk, "salary": 7000.00}, + } def test_cached_reference_field_reference(self): class Group(Document): @@ -131,18 +133,15 @@ class TestCachedReferenceField(MongoDBTestCase): s2 = SocialData(obs="testing 321", person=p3, tags=["tag3", "tag4"]) s2.save() - self.assertEqual( - SocialData.objects._collection.find_one({"tags": "tag2"}), - { - "_id": s1.pk, - "obs": "testing 123", - "tags": ["tag1", "tag2"], - "person": {"_id": p1.pk, "group": g1.pk}, - }, - ) + assert SocialData.objects._collection.find_one({"tags": "tag2"}) == { + "_id": s1.pk, + "obs": "testing 123", + "tags": ["tag1", "tag2"], + "person": {"_id": p1.pk, "group": g1.pk}, + } - self.assertEqual(SocialData.objects(person__group=g2).count(), 1) - self.assertEqual(SocialData.objects(person__group=g2).first(), s2) + assert SocialData.objects(person__group=g2).count() == 1 + assert SocialData.objects(person__group=g2).first() == s2 def test_cached_reference_field_push_with_fields(self): class Product(Document): @@ -157,26 +156,20 @@ class TestCachedReferenceField(MongoDBTestCase): product1 = Product(name="abc").save() product2 = Product(name="def").save() basket = Basket(products=[product1]).save() - self.assertEqual( - Basket.objects._collection.find_one(), - { - "_id": basket.pk, - "products": [{"_id": product1.pk, "name": product1.name}], - }, - ) + assert Basket.objects._collection.find_one() == { + "_id": basket.pk, + "products": [{"_id": product1.pk, "name": product1.name}], + } # push to list basket.update(push__products=product2) basket.reload() - self.assertEqual( - Basket.objects._collection.find_one(), - { - "_id": basket.pk, - "products": [ - {"_id": product1.pk, "name": product1.name}, - {"_id": product2.pk, "name": product2.name}, - ], - }, - ) + assert Basket.objects._collection.find_one() == { + "_id": basket.pk, + "products": [ + {"_id": product1.pk, "name": product1.name}, + {"_id": product2.pk, "name": product2.name}, + ], + } def test_cached_reference_field_update_all(self): class Person(Document): @@ -194,37 +187,31 @@ class TestCachedReferenceField(MongoDBTestCase): a2.save() a2 = Person.objects.with_id(a2.id) - self.assertEqual(a2.father.tp, a1.tp) + assert a2.father.tp == a1.tp - self.assertEqual( - dict(a2.to_mongo()), - { - "_id": a2.pk, - "name": u"Wilson Junior", - "tp": u"pf", - "father": {"_id": a1.pk, "tp": u"pj"}, - }, - ) + assert dict(a2.to_mongo()) == { + "_id": a2.pk, + "name": u"Wilson Junior", + "tp": u"pf", + "father": {"_id": a1.pk, "tp": u"pj"}, + } - self.assertEqual(Person.objects(father=a1)._query, {"father._id": a1.pk}) - self.assertEqual(Person.objects(father=a1).count(), 1) + assert Person.objects(father=a1)._query == {"father._id": a1.pk} + assert Person.objects(father=a1).count() == 1 Person.objects.update(set__tp="pf") Person.father.sync_all() a2.reload() - self.assertEqual( - dict(a2.to_mongo()), - { - "_id": a2.pk, - "name": u"Wilson Junior", - "tp": u"pf", - "father": {"_id": a1.pk, "tp": u"pf"}, - }, - ) + assert dict(a2.to_mongo()) == { + "_id": a2.pk, + "name": u"Wilson Junior", + "tp": u"pf", + "father": {"_id": a1.pk, "tp": u"pf"}, + } def test_cached_reference_fields_on_embedded_documents(self): - with self.assertRaises(InvalidDocumentError): + with pytest.raises(InvalidDocumentError): class Test(Document): name = StringField() @@ -255,15 +242,12 @@ class TestCachedReferenceField(MongoDBTestCase): a1.save() a2.reload() - self.assertEqual( - dict(a2.to_mongo()), - { - "_id": a2.pk, - "name": "Wilson Junior", - "tp": "pf", - "father": {"_id": a1.pk, "tp": "pf"}, - }, - ) + assert dict(a2.to_mongo()) == { + "_id": a2.pk, + "name": "Wilson Junior", + "tp": "pf", + "father": {"_id": a1.pk, "tp": "pf"}, + } def test_cached_reference_auto_sync_disabled(self): class Persone(Document): @@ -284,15 +268,12 @@ class TestCachedReferenceField(MongoDBTestCase): a1.tp = "pf" a1.save() - self.assertEqual( - Persone.objects._collection.find_one({"_id": a2.pk}), - { - "_id": a2.pk, - "name": "Wilson Junior", - "tp": "pf", - "father": {"_id": a1.pk, "tp": "pj"}, - }, - ) + assert Persone.objects._collection.find_one({"_id": a2.pk}) == { + "_id": a2.pk, + "name": "Wilson Junior", + "tp": "pf", + "father": {"_id": a1.pk, "tp": "pj"}, + } def test_cached_reference_embedded_fields(self): class Owner(EmbeddedDocument): @@ -320,28 +301,29 @@ class TestCachedReferenceField(MongoDBTestCase): o = Ocorrence(person="teste", animal=a) o.save() - self.assertEqual( - dict(a.to_mongo(fields=["tag", "owner.tp"])), - {"_id": a.pk, "tag": "heavy", "owner": {"t": "u"}}, - ) - self.assertEqual(o.to_mongo()["animal"]["tag"], "heavy") - self.assertEqual(o.to_mongo()["animal"]["owner"]["t"], "u") + assert dict(a.to_mongo(fields=["tag", "owner.tp"])) == { + "_id": a.pk, + "tag": "heavy", + "owner": {"t": "u"}, + } + assert o.to_mongo()["animal"]["tag"] == "heavy" + assert o.to_mongo()["animal"]["owner"]["t"] == "u" # Check to_mongo with fields - self.assertNotIn("animal", o.to_mongo(fields=["person"])) + assert "animal" not in o.to_mongo(fields=["person"]) # counts Ocorrence(person="teste 2").save() Ocorrence(person="teste 3").save() count = Ocorrence.objects(animal__tag="heavy", animal__owner__tp="u").count() - self.assertEqual(count, 1) + assert count == 1 ocorrence = Ocorrence.objects( animal__tag="heavy", animal__owner__tp="u" ).first() - self.assertEqual(ocorrence.person, "teste") - self.assertIsInstance(ocorrence.animal, Animal) + assert ocorrence.person == "teste" + assert isinstance(ocorrence.animal, Animal) def test_cached_reference_embedded_list_fields(self): class Owner(EmbeddedDocument): @@ -370,13 +352,14 @@ class TestCachedReferenceField(MongoDBTestCase): o = Ocorrence(person="teste 2", animal=a) o.save() - self.assertEqual( - dict(a.to_mongo(fields=["tag", "owner.tags"])), - {"_id": a.pk, "tag": "heavy", "owner": {"tags": ["cool", "funny"]}}, - ) + assert dict(a.to_mongo(fields=["tag", "owner.tags"])) == { + "_id": a.pk, + "tag": "heavy", + "owner": {"tags": ["cool", "funny"]}, + } - self.assertEqual(o.to_mongo()["animal"]["tag"], "heavy") - self.assertEqual(o.to_mongo()["animal"]["owner"]["tags"], ["cool", "funny"]) + assert o.to_mongo()["animal"]["tag"] == "heavy" + assert o.to_mongo()["animal"]["owner"]["tags"] == ["cool", "funny"] # counts Ocorrence(person="teste 2").save() @@ -385,10 +368,10 @@ class TestCachedReferenceField(MongoDBTestCase): query = Ocorrence.objects( animal__tag="heavy", animal__owner__tags="cool" )._query - self.assertEqual(query, {"animal.owner.tags": "cool", "animal.tag": "heavy"}) + assert query == {"animal.owner.tags": "cool", "animal.tag": "heavy"} ocorrence = Ocorrence.objects( animal__tag="heavy", animal__owner__tags="cool" ).first() - self.assertEqual(ocorrence.person, "teste 2") - self.assertIsInstance(ocorrence.animal, Animal) + assert ocorrence.person == "teste 2" + assert isinstance(ocorrence.animal, Animal) diff --git a/tests/fields/test_complex_datetime_field.py b/tests/fields/test_complex_datetime_field.py index 611c0ff8..f0a6b96e 100644 --- a/tests/fields/test_complex_datetime_field.py +++ b/tests/fields/test_complex_datetime_field.py @@ -28,7 +28,7 @@ class ComplexDateTimeFieldTest(MongoDBTestCase): log.date = d1 log.save() log.reload() - self.assertEqual(log.date, d1) + assert log.date == d1 # Post UTC - microseconds are rounded (down) nearest millisecond - with # default datetimefields @@ -36,7 +36,7 @@ class ComplexDateTimeFieldTest(MongoDBTestCase): log.date = d1 log.save() log.reload() - self.assertEqual(log.date, d1) + assert log.date == d1 # Pre UTC dates microseconds below 1000 are dropped - with default # datetimefields @@ -44,7 +44,7 @@ class ComplexDateTimeFieldTest(MongoDBTestCase): log.date = d1 log.save() log.reload() - self.assertEqual(log.date, d1) + assert log.date == d1 # Pre UTC microseconds above 1000 is wonky - with default datetimefields # log.date has an invalid microsecond value so I can't construct @@ -54,9 +54,9 @@ class ComplexDateTimeFieldTest(MongoDBTestCase): log.date = d1 log.save() log.reload() - self.assertEqual(log.date, d1) + assert log.date == d1 log1 = LogEntry.objects.get(date=d1) - self.assertEqual(log, log1) + assert log == log1 # Test string padding microsecond = map(int, [math.pow(10, x) for x in range(6)]) @@ -64,7 +64,7 @@ class ComplexDateTimeFieldTest(MongoDBTestCase): for values in itertools.product([2014], mm, dd, hh, ii, ss, microsecond): stored = LogEntry(date=datetime.datetime(*values)).to_mongo()["date"] - self.assertTrue( + assert ( re.match("^\d{4},\d{2},\d{2},\d{2},\d{2},\d{2},\d{6}$", stored) is not None ) @@ -73,7 +73,7 @@ class ComplexDateTimeFieldTest(MongoDBTestCase): stored = LogEntry(date_with_dots=datetime.datetime(2014, 1, 1)).to_mongo()[ "date_with_dots" ] - self.assertTrue( + assert ( re.match("^\d{4}.\d{2}.\d{2}.\d{2}.\d{2}.\d{2}.\d{6}$", stored) is not None ) @@ -93,40 +93,40 @@ class ComplexDateTimeFieldTest(MongoDBTestCase): log.save() log1 = LogEntry.objects.get(date=d1) - self.assertEqual(log, log1) + assert log == log1 # create extra 59 log entries for a total of 60 for i in range(1951, 2010): d = datetime.datetime(i, 1, 1, 0, 0, 1, 999) LogEntry(date=d).save() - self.assertEqual(LogEntry.objects.count(), 60) + assert LogEntry.objects.count() == 60 # Test ordering logs = LogEntry.objects.order_by("date") i = 0 while i < 59: - self.assertTrue(logs[i].date <= logs[i + 1].date) + assert logs[i].date <= logs[i + 1].date i += 1 logs = LogEntry.objects.order_by("-date") i = 0 while i < 59: - self.assertTrue(logs[i].date >= logs[i + 1].date) + assert logs[i].date >= logs[i + 1].date i += 1 # Test searching logs = LogEntry.objects.filter(date__gte=datetime.datetime(1980, 1, 1)) - self.assertEqual(logs.count(), 30) + assert logs.count() == 30 logs = LogEntry.objects.filter(date__lte=datetime.datetime(1980, 1, 1)) - self.assertEqual(logs.count(), 30) + assert logs.count() == 30 logs = LogEntry.objects.filter( date__lte=datetime.datetime(2011, 1, 1), date__gte=datetime.datetime(2000, 1, 1), ) - self.assertEqual(logs.count(), 10) + assert logs.count() == 10 LogEntry.drop_collection() @@ -137,17 +137,17 @@ class ComplexDateTimeFieldTest(MongoDBTestCase): logs = list(LogEntry.objects.order_by("date")) for next_idx, log in enumerate(logs[:-1], start=1): next_log = logs[next_idx] - self.assertTrue(log.date < next_log.date) + assert log.date < next_log.date logs = list(LogEntry.objects.order_by("-date")) for next_idx, log in enumerate(logs[:-1], start=1): next_log = logs[next_idx] - self.assertTrue(log.date > next_log.date) + assert log.date > next_log.date logs = LogEntry.objects.filter( date__lte=datetime.datetime(2015, 1, 1, 0, 0, 0, 10000) ) - self.assertEqual(logs.count(), 4) + assert logs.count() == 4 def test_no_default_value(self): class Log(Document): @@ -156,11 +156,11 @@ class ComplexDateTimeFieldTest(MongoDBTestCase): Log.drop_collection() log = Log() - self.assertIsNone(log.timestamp) + assert log.timestamp is None log.save() fetched_log = Log.objects.with_id(log.id) - self.assertIsNone(fetched_log.timestamp) + assert fetched_log.timestamp is None def test_default_static_value(self): NOW = datetime.datetime.utcnow() @@ -171,11 +171,11 @@ class ComplexDateTimeFieldTest(MongoDBTestCase): Log.drop_collection() log = Log() - self.assertEqual(log.timestamp, NOW) + assert log.timestamp == NOW log.save() fetched_log = Log.objects.with_id(log.id) - self.assertEqual(fetched_log.timestamp, NOW) + assert fetched_log.timestamp == NOW def test_default_callable(self): NOW = datetime.datetime.utcnow() @@ -186,8 +186,8 @@ class ComplexDateTimeFieldTest(MongoDBTestCase): Log.drop_collection() log = Log() - self.assertGreaterEqual(log.timestamp, NOW) + assert log.timestamp >= NOW log.save() fetched_log = Log.objects.with_id(log.id) - self.assertGreaterEqual(fetched_log.timestamp, NOW) + assert fetched_log.timestamp >= NOW diff --git a/tests/fields/test_date_field.py b/tests/fields/test_date_field.py index da572134..46fa4f0f 100644 --- a/tests/fields/test_date_field.py +++ b/tests/fields/test_date_field.py @@ -10,6 +10,7 @@ except ImportError: from mongoengine import * from tests.utils import MongoDBTestCase +import pytest class TestDateField(MongoDBTestCase): @@ -23,7 +24,8 @@ class TestDateField(MongoDBTestCase): dt = DateField() md = MyDoc(dt="") - self.assertRaises(ValidationError, md.save) + with pytest.raises(ValidationError): + md.save() def test_date_from_whitespace_string(self): """ @@ -35,7 +37,8 @@ class TestDateField(MongoDBTestCase): dt = DateField() md = MyDoc(dt=" ") - self.assertRaises(ValidationError, md.save) + with pytest.raises(ValidationError): + md.save() def test_default_values_today(self): """Ensure that default field values are used when creating @@ -47,9 +50,9 @@ class TestDateField(MongoDBTestCase): person = Person() person.validate() - self.assertEqual(person.day, person.day) - self.assertEqual(person.day, datetime.date.today()) - self.assertEqual(person._data["day"], person.day) + assert person.day == person.day + assert person.day == datetime.date.today() + assert person._data["day"] == person.day def test_date(self): """Tests showing pymongo date fields @@ -67,7 +70,7 @@ class TestDateField(MongoDBTestCase): log.date = datetime.date.today() log.save() log.reload() - self.assertEqual(log.date, datetime.date.today()) + assert log.date == datetime.date.today() d1 = datetime.datetime(1970, 1, 1, 0, 0, 1, 999) d2 = datetime.datetime(1970, 1, 1, 0, 0, 1) @@ -75,16 +78,16 @@ class TestDateField(MongoDBTestCase): log.date = d1 log.save() log.reload() - self.assertEqual(log.date, d1.date()) - self.assertEqual(log.date, d2.date()) + assert log.date == d1.date() + assert log.date == d2.date() d1 = datetime.datetime(1970, 1, 1, 0, 0, 1, 9999) d2 = datetime.datetime(1970, 1, 1, 0, 0, 1, 9000) log.date = d1 log.save() log.reload() - self.assertEqual(log.date, d1.date()) - self.assertEqual(log.date, d2.date()) + assert log.date == d1.date() + assert log.date == d2.date() if not six.PY3: # Pre UTC dates microseconds below 1000 are dropped @@ -94,8 +97,8 @@ class TestDateField(MongoDBTestCase): log.date = d1 log.save() log.reload() - self.assertEqual(log.date, d1.date()) - self.assertEqual(log.date, d2.date()) + assert log.date == d1.date() + assert log.date == d2.date() def test_regular_usage(self): """Tests for regular datetime fields""" @@ -113,35 +116,35 @@ class TestDateField(MongoDBTestCase): for query in (d1, d1.isoformat(" ")): log1 = LogEntry.objects.get(date=query) - self.assertEqual(log, log1) + assert log == log1 if dateutil: log1 = LogEntry.objects.get(date=d1.isoformat("T")) - self.assertEqual(log, log1) + assert log == log1 # create additional 19 log entries for a total of 20 for i in range(1971, 1990): d = datetime.datetime(i, 1, 1, 0, 0, 1) LogEntry(date=d).save() - self.assertEqual(LogEntry.objects.count(), 20) + assert LogEntry.objects.count() == 20 # Test ordering logs = LogEntry.objects.order_by("date") i = 0 while i < 19: - self.assertTrue(logs[i].date <= logs[i + 1].date) + assert logs[i].date <= logs[i + 1].date i += 1 logs = LogEntry.objects.order_by("-date") i = 0 while i < 19: - self.assertTrue(logs[i].date >= logs[i + 1].date) + assert logs[i].date >= logs[i + 1].date i += 1 # Test searching logs = LogEntry.objects.filter(date__gte=datetime.datetime(1980, 1, 1)) - self.assertEqual(logs.count(), 10) + assert logs.count() == 10 def test_validation(self): """Ensure that invalid values cannot be assigned to datetime @@ -166,6 +169,8 @@ class TestDateField(MongoDBTestCase): log.validate() log.time = -1 - self.assertRaises(ValidationError, log.validate) + with pytest.raises(ValidationError): + log.validate() log.time = "ABC" - self.assertRaises(ValidationError, log.validate) + with pytest.raises(ValidationError): + log.validate() diff --git a/tests/fields/test_datetime_field.py b/tests/fields/test_datetime_field.py index c911390a..8db491c6 100644 --- a/tests/fields/test_datetime_field.py +++ b/tests/fields/test_datetime_field.py @@ -11,6 +11,7 @@ from mongoengine import * from mongoengine import connection from tests.utils import MongoDBTestCase +import pytest class TestDateTimeField(MongoDBTestCase): @@ -24,7 +25,8 @@ class TestDateTimeField(MongoDBTestCase): dt = DateTimeField() md = MyDoc(dt="") - self.assertRaises(ValidationError, md.save) + with pytest.raises(ValidationError): + md.save() def test_datetime_from_whitespace_string(self): """ @@ -36,7 +38,8 @@ class TestDateTimeField(MongoDBTestCase): dt = DateTimeField() md = MyDoc(dt=" ") - self.assertRaises(ValidationError, md.save) + with pytest.raises(ValidationError): + md.save() def test_default_value_utcnow(self): """Ensure that default field values are used when creating @@ -50,11 +53,9 @@ class TestDateTimeField(MongoDBTestCase): person = Person() person.validate() person_created_t0 = person.created - self.assertLess(person.created - utcnow, dt.timedelta(seconds=1)) - self.assertEqual( - person_created_t0, person.created - ) # make sure it does not change - self.assertEqual(person._data["created"], person.created) + assert person.created - utcnow < dt.timedelta(seconds=1) + assert person_created_t0 == person.created # make sure it does not change + assert person._data["created"] == person.created def test_handling_microseconds(self): """Tests showing pymongo datetime fields handling of microseconds. @@ -74,7 +75,7 @@ class TestDateTimeField(MongoDBTestCase): log.date = dt.date.today() log.save() log.reload() - self.assertEqual(log.date.date(), dt.date.today()) + assert log.date.date() == dt.date.today() # Post UTC - microseconds are rounded (down) nearest millisecond and # dropped @@ -84,8 +85,8 @@ class TestDateTimeField(MongoDBTestCase): log.date = d1 log.save() log.reload() - self.assertNotEqual(log.date, d1) - self.assertEqual(log.date, d2) + assert log.date != d1 + assert log.date == d2 # Post UTC - microseconds are rounded (down) nearest millisecond d1 = dt.datetime(1970, 1, 1, 0, 0, 1, 9999) @@ -93,8 +94,8 @@ class TestDateTimeField(MongoDBTestCase): log.date = d1 log.save() log.reload() - self.assertNotEqual(log.date, d1) - self.assertEqual(log.date, d2) + assert log.date != d1 + assert log.date == d2 if not six.PY3: # Pre UTC dates microseconds below 1000 are dropped @@ -104,8 +105,8 @@ class TestDateTimeField(MongoDBTestCase): log.date = d1 log.save() log.reload() - self.assertNotEqual(log.date, d1) - self.assertEqual(log.date, d2) + assert log.date != d1 + assert log.date == d2 def test_regular_usage(self): """Tests for regular datetime fields""" @@ -123,43 +124,43 @@ class TestDateTimeField(MongoDBTestCase): for query in (d1, d1.isoformat(" ")): log1 = LogEntry.objects.get(date=query) - self.assertEqual(log, log1) + assert log == log1 if dateutil: log1 = LogEntry.objects.get(date=d1.isoformat("T")) - self.assertEqual(log, log1) + assert log == log1 # create additional 19 log entries for a total of 20 for i in range(1971, 1990): d = dt.datetime(i, 1, 1, 0, 0, 1) LogEntry(date=d).save() - self.assertEqual(LogEntry.objects.count(), 20) + assert LogEntry.objects.count() == 20 # Test ordering logs = LogEntry.objects.order_by("date") i = 0 while i < 19: - self.assertTrue(logs[i].date <= logs[i + 1].date) + assert logs[i].date <= logs[i + 1].date i += 1 logs = LogEntry.objects.order_by("-date") i = 0 while i < 19: - self.assertTrue(logs[i].date >= logs[i + 1].date) + assert logs[i].date >= logs[i + 1].date i += 1 # Test searching logs = LogEntry.objects.filter(date__gte=dt.datetime(1980, 1, 1)) - self.assertEqual(logs.count(), 10) + assert logs.count() == 10 logs = LogEntry.objects.filter(date__lte=dt.datetime(1980, 1, 1)) - self.assertEqual(logs.count(), 10) + assert logs.count() == 10 logs = LogEntry.objects.filter( date__lte=dt.datetime(1980, 1, 1), date__gte=dt.datetime(1975, 1, 1) ) - self.assertEqual(logs.count(), 5) + assert logs.count() == 5 def test_datetime_validation(self): """Ensure that invalid values cannot be assigned to datetime @@ -187,15 +188,20 @@ class TestDateTimeField(MongoDBTestCase): log.validate() log.time = -1 - self.assertRaises(ValidationError, log.validate) + with pytest.raises(ValidationError): + log.validate() log.time = "ABC" - self.assertRaises(ValidationError, log.validate) + with pytest.raises(ValidationError): + log.validate() log.time = "2019-05-16 21:GARBAGE:12" - self.assertRaises(ValidationError, log.validate) + with pytest.raises(ValidationError): + log.validate() log.time = "2019-05-16 21:42:57.GARBAGE" - self.assertRaises(ValidationError, log.validate) + with pytest.raises(ValidationError): + log.validate() log.time = "2019-05-16 21:42:57.123.456" - self.assertRaises(ValidationError, log.validate) + with pytest.raises(ValidationError): + log.validate() def test_parse_datetime_as_str(self): class DTDoc(Document): @@ -206,15 +212,16 @@ class TestDateTimeField(MongoDBTestCase): # make sure that passing a parsable datetime works dtd = DTDoc() dtd.date = date_str - self.assertIsInstance(dtd.date, six.string_types) + assert isinstance(dtd.date, six.string_types) dtd.save() dtd.reload() - self.assertIsInstance(dtd.date, dt.datetime) - self.assertEqual(str(dtd.date), date_str) + assert isinstance(dtd.date, dt.datetime) + assert str(dtd.date) == date_str dtd.date = "January 1st, 9999999999" - self.assertRaises(ValidationError, dtd.validate) + with pytest.raises(ValidationError): + dtd.validate() class TestDateTimeTzAware(MongoDBTestCase): @@ -235,4 +242,4 @@ class TestDateTimeTzAware(MongoDBTestCase): log = LogEntry.objects.first() log.time = dt.datetime(2013, 1, 1, 0, 0, 0) - self.assertEqual(["time"], log._changed_fields) + assert ["time"] == log._changed_fields diff --git a/tests/fields/test_decimal_field.py b/tests/fields/test_decimal_field.py index 30b7e5ea..b5b95363 100644 --- a/tests/fields/test_decimal_field.py +++ b/tests/fields/test_decimal_field.py @@ -4,6 +4,7 @@ from decimal import Decimal from mongoengine import * from tests.utils import MongoDBTestCase +import pytest class TestDecimalField(MongoDBTestCase): @@ -18,21 +19,26 @@ class TestDecimalField(MongoDBTestCase): Person(height=Decimal("1.89")).save() person = Person.objects.first() - self.assertEqual(person.height, Decimal("1.89")) + assert person.height == Decimal("1.89") person.height = "2.0" person.save() person.height = 0.01 - self.assertRaises(ValidationError, person.validate) + with pytest.raises(ValidationError): + person.validate() person.height = Decimal("0.01") - self.assertRaises(ValidationError, person.validate) + with pytest.raises(ValidationError): + person.validate() person.height = Decimal("4.0") - self.assertRaises(ValidationError, person.validate) + with pytest.raises(ValidationError): + person.validate() person.height = "something invalid" - self.assertRaises(ValidationError, person.validate) + with pytest.raises(ValidationError): + person.validate() person_2 = Person(height="something invalid") - self.assertRaises(ValidationError, person_2.validate) + with pytest.raises(ValidationError): + person_2.validate() def test_comparison(self): class Person(Document): @@ -45,11 +51,11 @@ class TestDecimalField(MongoDBTestCase): Person(money=8).save() Person(money=10).save() - self.assertEqual(2, Person.objects(money__gt=Decimal("7")).count()) - self.assertEqual(2, Person.objects(money__gt=7).count()) - self.assertEqual(2, Person.objects(money__gt="7").count()) + assert 2 == Person.objects(money__gt=Decimal("7")).count() + assert 2 == Person.objects(money__gt=7).count() + assert 2 == Person.objects(money__gt="7").count() - self.assertEqual(3, Person.objects(money__gte="7").count()) + assert 3 == Person.objects(money__gte="7").count() def test_storage(self): class Person(Document): @@ -87,7 +93,7 @@ class TestDecimalField(MongoDBTestCase): ] expected.extend(expected) actual = list(Person.objects.exclude("id").as_pymongo()) - self.assertEqual(expected, actual) + assert expected == actual # How it comes out locally expected = [ @@ -101,4 +107,4 @@ class TestDecimalField(MongoDBTestCase): expected.extend(expected) for field_name in ["float_value", "string_value"]: actual = list(Person.objects().scalar(field_name)) - self.assertEqual(expected, actual) + assert expected == actual diff --git a/tests/fields/test_dict_field.py b/tests/fields/test_dict_field.py index 07bab85b..56df682f 100644 --- a/tests/fields/test_dict_field.py +++ b/tests/fields/test_dict_field.py @@ -3,6 +3,7 @@ from mongoengine import * from mongoengine.base import BaseDict from tests.utils import MongoDBTestCase, get_as_pymongo +import pytest class TestDictField(MongoDBTestCase): @@ -14,7 +15,7 @@ class TestDictField(MongoDBTestCase): info = {"testkey": "testvalue"} post = BlogPost(info=info).save() - self.assertEqual(get_as_pymongo(post), {"_id": post.id, "info": info}) + assert get_as_pymongo(post) == {"_id": post.id, "info": info} def test_general_things(self): """Ensure that dict types work as expected.""" @@ -26,25 +27,32 @@ class TestDictField(MongoDBTestCase): post = BlogPost() post.info = "my post" - self.assertRaises(ValidationError, post.validate) + with pytest.raises(ValidationError): + post.validate() post.info = ["test", "test"] - self.assertRaises(ValidationError, post.validate) + with pytest.raises(ValidationError): + post.validate() post.info = {"$title": "test"} - self.assertRaises(ValidationError, post.validate) + with pytest.raises(ValidationError): + post.validate() post.info = {"nested": {"$title": "test"}} - self.assertRaises(ValidationError, post.validate) + with pytest.raises(ValidationError): + post.validate() post.info = {"the.title": "test"} - self.assertRaises(ValidationError, post.validate) + with pytest.raises(ValidationError): + post.validate() post.info = {"nested": {"the.title": "test"}} - self.assertRaises(ValidationError, post.validate) + with pytest.raises(ValidationError): + post.validate() post.info = {1: "test"} - self.assertRaises(ValidationError, post.validate) + with pytest.raises(ValidationError): + post.validate() post.info = {"title": "test"} post.save() @@ -61,33 +69,27 @@ class TestDictField(MongoDBTestCase): post.info = {"details": {"test": 3}} post.save() - self.assertEqual(BlogPost.objects.count(), 4) - self.assertEqual(BlogPost.objects.filter(info__title__exact="test").count(), 1) - self.assertEqual( - BlogPost.objects.filter(info__details__test__exact="test").count(), 1 - ) + assert BlogPost.objects.count() == 4 + assert BlogPost.objects.filter(info__title__exact="test").count() == 1 + assert BlogPost.objects.filter(info__details__test__exact="test").count() == 1 post = BlogPost.objects.filter(info__title__exact="dollar_sign").first() - self.assertIn("te$t", post["info"]["details"]) + assert "te$t" in post["info"]["details"] # Confirm handles non strings or non existing keys - self.assertEqual( - BlogPost.objects.filter(info__details__test__exact=5).count(), 0 - ) - self.assertEqual( - BlogPost.objects.filter(info__made_up__test__exact="test").count(), 0 - ) + assert BlogPost.objects.filter(info__details__test__exact=5).count() == 0 + assert BlogPost.objects.filter(info__made_up__test__exact="test").count() == 0 post = BlogPost.objects.create(info={"title": "original"}) post.info.update({"title": "updated"}) post.save() post.reload() - self.assertEqual("updated", post.info["title"]) + assert "updated" == post.info["title"] post.info.setdefault("authors", []) post.save() post.reload() - self.assertEqual([], post.info["authors"]) + assert [] == post.info["authors"] def test_dictfield_dump_document(self): """Ensure a DictField can handle another document's dump.""" @@ -114,10 +116,8 @@ class TestDictField(MongoDBTestCase): ).save() doc = Doc(field=to_embed.to_mongo().to_dict()) doc.save() - self.assertIsInstance(doc.field, dict) - self.assertEqual( - doc.field, {"_id": 2, "recursive": {"_id": 1, "recursive": {}}} - ) + assert isinstance(doc.field, dict) + assert doc.field == {"_id": 2, "recursive": {"_id": 1, "recursive": {}}} # Same thing with a Document with a _cls field to_embed_recursive = ToEmbedChild(id=1).save() to_embed_child = ToEmbedChild( @@ -125,7 +125,7 @@ class TestDictField(MongoDBTestCase): ).save() doc = Doc(field=to_embed_child.to_mongo().to_dict()) doc.save() - self.assertIsInstance(doc.field, dict) + assert isinstance(doc.field, dict) expected = { "_id": 2, "_cls": "ToEmbedParent.ToEmbedChild", @@ -135,7 +135,7 @@ class TestDictField(MongoDBTestCase): "recursive": {}, }, } - self.assertEqual(doc.field, expected) + assert doc.field == expected def test_dictfield_strict(self): """Ensure that dict field handles validation if provided a strict field type.""" @@ -150,7 +150,7 @@ class TestDictField(MongoDBTestCase): e.save() # try creating an invalid mapping - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): e.mapping["somestring"] = "abc" e.save() @@ -184,22 +184,21 @@ class TestDictField(MongoDBTestCase): e.save() e2 = Simple.objects.get(id=e.id) - self.assertIsInstance(e2.mapping["somestring"], StringSetting) - self.assertIsInstance(e2.mapping["someint"], IntegerSetting) + assert isinstance(e2.mapping["somestring"], StringSetting) + assert isinstance(e2.mapping["someint"], IntegerSetting) # Test querying - self.assertEqual(Simple.objects.filter(mapping__someint__value=42).count(), 1) - self.assertEqual( - Simple.objects.filter(mapping__nested_dict__number=1).count(), 1 + assert Simple.objects.filter(mapping__someint__value=42).count() == 1 + assert Simple.objects.filter(mapping__nested_dict__number=1).count() == 1 + assert ( + Simple.objects.filter(mapping__nested_dict__complex__value=42).count() == 1 ) - self.assertEqual( - Simple.objects.filter(mapping__nested_dict__complex__value=42).count(), 1 + assert ( + Simple.objects.filter(mapping__nested_dict__list__0__value=42).count() == 1 ) - self.assertEqual( - Simple.objects.filter(mapping__nested_dict__list__0__value=42).count(), 1 - ) - self.assertEqual( - Simple.objects.filter(mapping__nested_dict__list__1__value="foo").count(), 1 + assert ( + Simple.objects.filter(mapping__nested_dict__list__1__value="foo").count() + == 1 ) # Confirm can update @@ -207,11 +206,13 @@ class TestDictField(MongoDBTestCase): Simple.objects().update( set__mapping__nested_dict__list__1=StringSetting(value="Boo") ) - self.assertEqual( - Simple.objects.filter(mapping__nested_dict__list__1__value="foo").count(), 0 + assert ( + Simple.objects.filter(mapping__nested_dict__list__1__value="foo").count() + == 0 ) - self.assertEqual( - Simple.objects.filter(mapping__nested_dict__list__1__value="Boo").count(), 1 + assert ( + Simple.objects.filter(mapping__nested_dict__list__1__value="Boo").count() + == 1 ) def test_push_dict(self): @@ -221,12 +222,12 @@ class TestDictField(MongoDBTestCase): doc = MyModel(events=[{"a": 1}]).save() raw_doc = get_as_pymongo(doc) expected_raw_doc = {"_id": doc.id, "events": [{"a": 1}]} - self.assertEqual(raw_doc, expected_raw_doc) + assert raw_doc == expected_raw_doc MyModel.objects(id=doc.id).update(push__events={}) raw_doc = get_as_pymongo(doc) expected_raw_doc = {"_id": doc.id, "events": [{"a": 1}, {}]} - self.assertEqual(raw_doc, expected_raw_doc) + assert raw_doc == expected_raw_doc def test_ensure_unique_default_instances(self): """Ensure that every field has it's own unique default instance.""" @@ -239,8 +240,8 @@ class TestDictField(MongoDBTestCase): d1.data["foo"] = "bar" d1.data2["foo"] = "bar" d2 = D() - self.assertEqual(d2.data, {}) - self.assertEqual(d2.data2, {}) + assert d2.data == {} + assert d2.data2 == {} def test_dict_field_invalid_dict_value(self): class DictFieldTest(Document): @@ -250,11 +251,13 @@ class TestDictField(MongoDBTestCase): test = DictFieldTest(dictionary=None) test.dictionary # Just access to test getter - self.assertRaises(ValidationError, test.validate) + with pytest.raises(ValidationError): + test.validate() test = DictFieldTest(dictionary=False) test.dictionary # Just access to test getter - self.assertRaises(ValidationError, test.validate) + with pytest.raises(ValidationError): + test.validate() def test_dict_field_raises_validation_error_if_wrongly_assign_embedded_doc(self): class DictFieldTest(Document): @@ -267,12 +270,10 @@ class TestDictField(MongoDBTestCase): embed = Embedded(name="garbage") doc = DictFieldTest(dictionary=embed) - with self.assertRaises(ValidationError) as ctx_err: + with pytest.raises(ValidationError) as ctx_err: doc.validate() - self.assertIn("'dictionary'", str(ctx_err.exception)) - self.assertIn( - "Only dictionaries may be used in a DictField", str(ctx_err.exception) - ) + assert "'dictionary'" in str(ctx_err.exception) + assert "Only dictionaries may be used in a DictField" in str(ctx_err.exception) def test_atomic_update_dict_field(self): """Ensure that the entire DictField can be atomically updated.""" @@ -287,11 +288,11 @@ class TestDictField(MongoDBTestCase): e.save() e.update(set__mapping={"ints": [3, 4]}) e.reload() - self.assertEqual(BaseDict, type(e.mapping)) - self.assertEqual({"ints": [3, 4]}, e.mapping) + assert BaseDict == type(e.mapping) + assert {"ints": [3, 4]} == e.mapping # try creating an invalid mapping - with self.assertRaises(ValueError): + with pytest.raises(ValueError): e.update(set__mapping={"somestrings": ["foo", "bar"]}) def test_dictfield_with_referencefield_complex_nesting_cases(self): @@ -329,13 +330,13 @@ class TestDictField(MongoDBTestCase): e.save() s = Simple.objects.first() - self.assertIsInstance(s.mapping0["someint"], Doc) - self.assertIsInstance(s.mapping1["someint"], Doc) - self.assertIsInstance(s.mapping2["someint"][0], Doc) - self.assertIsInstance(s.mapping3["someint"][0], Doc) - self.assertIsInstance(s.mapping4["someint"]["d"], Doc) - self.assertIsInstance(s.mapping5["someint"]["d"], Doc) - self.assertIsInstance(s.mapping6["someint"][0]["d"], Doc) - self.assertIsInstance(s.mapping7["someint"][0]["d"], Doc) - self.assertIsInstance(s.mapping8["someint"][0]["d"][0], Doc) - self.assertIsInstance(s.mapping9["someint"][0]["d"][0], Doc) + assert isinstance(s.mapping0["someint"], Doc) + assert isinstance(s.mapping1["someint"], Doc) + assert isinstance(s.mapping2["someint"][0], Doc) + assert isinstance(s.mapping3["someint"][0], Doc) + assert isinstance(s.mapping4["someint"]["d"], Doc) + assert isinstance(s.mapping5["someint"]["d"], Doc) + assert isinstance(s.mapping6["someint"][0]["d"], Doc) + assert isinstance(s.mapping7["someint"][0]["d"], Doc) + assert isinstance(s.mapping8["someint"][0]["d"][0], Doc) + assert isinstance(s.mapping9["someint"][0]["d"][0], Doc) diff --git a/tests/fields/test_email_field.py b/tests/fields/test_email_field.py index 06ec5151..b8d3d169 100644 --- a/tests/fields/test_email_field.py +++ b/tests/fields/test_email_field.py @@ -5,6 +5,7 @@ from unittest import SkipTest from mongoengine import * from tests.utils import MongoDBTestCase +import pytest class TestEmailField(MongoDBTestCase): @@ -27,7 +28,8 @@ class TestEmailField(MongoDBTestCase): user.validate() user = User(email="ross@example.com.") - self.assertRaises(ValidationError, user.validate) + with pytest.raises(ValidationError): + user.validate() # unicode domain user = User(email=u"user@пример.рф") @@ -35,11 +37,13 @@ class TestEmailField(MongoDBTestCase): # invalid unicode domain user = User(email=u"user@пример") - self.assertRaises(ValidationError, user.validate) + with pytest.raises(ValidationError): + user.validate() # invalid data type user = User(email=123) - self.assertRaises(ValidationError, user.validate) + with pytest.raises(ValidationError): + user.validate() def test_email_field_unicode_user(self): # Don't run this test on pypy3, which doesn't support unicode regex: @@ -52,7 +56,8 @@ class TestEmailField(MongoDBTestCase): # unicode user shouldn't validate by default... user = User(email=u"Dörte@Sörensen.example.com") - self.assertRaises(ValidationError, user.validate) + with pytest.raises(ValidationError): + user.validate() # ...but it should be fine with allow_utf8_user set to True class User(Document): @@ -67,7 +72,8 @@ class TestEmailField(MongoDBTestCase): # localhost domain shouldn't validate by default... user = User(email="me@localhost") - self.assertRaises(ValidationError, user.validate) + with pytest.raises(ValidationError): + user.validate() # ...but it should be fine if it's whitelisted class User(Document): @@ -82,9 +88,9 @@ class TestEmailField(MongoDBTestCase): invalid_idn = ".google.com" user = User(email="me@%s" % invalid_idn) - with self.assertRaises(ValidationError) as ctx_err: + with pytest.raises(ValidationError) as ctx_err: user.validate() - self.assertIn("domain failed IDN encoding", str(ctx_err.exception)) + assert "domain failed IDN encoding" in str(ctx_err.exception) def test_email_field_ip_domain(self): class User(Document): @@ -96,13 +102,16 @@ class TestEmailField(MongoDBTestCase): # IP address as a domain shouldn't validate by default... user = User(email=valid_ipv4) - self.assertRaises(ValidationError, user.validate) + with pytest.raises(ValidationError): + user.validate() user = User(email=valid_ipv6) - self.assertRaises(ValidationError, user.validate) + with pytest.raises(ValidationError): + user.validate() user = User(email=invalid_ip) - self.assertRaises(ValidationError, user.validate) + with pytest.raises(ValidationError): + user.validate() # ...but it should be fine with allow_ip_domain set to True class User(Document): @@ -116,7 +125,8 @@ class TestEmailField(MongoDBTestCase): # invalid IP should still fail validation user = User(email=invalid_ip) - self.assertRaises(ValidationError, user.validate) + with pytest.raises(ValidationError): + user.validate() def test_email_field_honors_regex(self): class User(Document): @@ -124,8 +134,9 @@ class TestEmailField(MongoDBTestCase): # Fails regex validation user = User(email="me@foo.com") - self.assertRaises(ValidationError, user.validate) + with pytest.raises(ValidationError): + user.validate() # Passes regex validation user = User(email="me@example.com") - self.assertIsNone(user.validate()) + assert user.validate() is None diff --git a/tests/fields/test_embedded_document_field.py b/tests/fields/test_embedded_document_field.py index 8db8c180..4fcf6bf1 100644 --- a/tests/fields/test_embedded_document_field.py +++ b/tests/fields/test_embedded_document_field.py @@ -13,6 +13,7 @@ from mongoengine import ( ) from tests.utils import MongoDBTestCase +import pytest class TestEmbeddedDocumentField(MongoDBTestCase): @@ -21,13 +22,13 @@ class TestEmbeddedDocumentField(MongoDBTestCase): name = StringField() field = EmbeddedDocumentField(MyDoc) - self.assertEqual(field.document_type_obj, MyDoc) + assert field.document_type_obj == MyDoc field2 = EmbeddedDocumentField("MyDoc") - self.assertEqual(field2.document_type_obj, "MyDoc") + assert field2.document_type_obj == "MyDoc" def test___init___throw_error_if_document_type_is_not_EmbeddedDocument(self): - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): EmbeddedDocumentField(dict) def test_document_type_throw_error_if_not_EmbeddedDocument_subclass(self): @@ -35,11 +36,11 @@ class TestEmbeddedDocumentField(MongoDBTestCase): name = StringField() emb = EmbeddedDocumentField("MyDoc") - with self.assertRaises(ValidationError) as ctx: + with pytest.raises(ValidationError) as ctx: emb.document_type - self.assertIn( - "Invalid embedded document class provided to an EmbeddedDocumentField", - str(ctx.exception), + assert ( + "Invalid embedded document class provided to an EmbeddedDocumentField" + in str(ctx.exception) ) def test_embedded_document_field_only_allow_subclasses_of_embedded_document(self): @@ -47,12 +48,12 @@ class TestEmbeddedDocumentField(MongoDBTestCase): class MyDoc(Document): name = StringField() - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): class MyFailingDoc(Document): emb = EmbeddedDocumentField(MyDoc) - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): class MyFailingdoc2(Document): emb = EmbeddedDocumentField("MyDoc") @@ -71,24 +72,24 @@ class TestEmbeddedDocumentField(MongoDBTestCase): p = Person(settings=AdminSettings(foo1="bar1", foo2="bar2"), name="John").save() # Test non exiting attribute - with self.assertRaises(InvalidQueryError) as ctx_err: + with pytest.raises(InvalidQueryError) as ctx_err: Person.objects(settings__notexist="bar").first() - self.assertEqual(unicode(ctx_err.exception), u'Cannot resolve field "notexist"') + assert unicode(ctx_err.exception) == u'Cannot resolve field "notexist"' - with self.assertRaises(LookUpError): + with pytest.raises(LookUpError): Person.objects.only("settings.notexist") # Test existing attribute - self.assertEqual(Person.objects(settings__foo1="bar1").first().id, p.id) + assert Person.objects(settings__foo1="bar1").first().id == p.id only_p = Person.objects.only("settings.foo1").first() - self.assertEqual(only_p.settings.foo1, p.settings.foo1) - self.assertIsNone(only_p.settings.foo2) - self.assertIsNone(only_p.name) + assert only_p.settings.foo1 == p.settings.foo1 + assert only_p.settings.foo2 is None + assert only_p.name is None exclude_p = Person.objects.exclude("settings.foo1").first() - self.assertIsNone(exclude_p.settings.foo1) - self.assertEqual(exclude_p.settings.foo2, p.settings.foo2) - self.assertEqual(exclude_p.name, p.name) + assert exclude_p.settings.foo1 is None + assert exclude_p.settings.foo2 == p.settings.foo2 + assert exclude_p.name == p.name def test_query_embedded_document_attribute_with_inheritance(self): class BaseSettings(EmbeddedDocument): @@ -107,17 +108,17 @@ class TestEmbeddedDocumentField(MongoDBTestCase): p.save() # Test non exiting attribute - with self.assertRaises(InvalidQueryError) as ctx_err: - self.assertEqual(Person.objects(settings__notexist="bar").first().id, p.id) - self.assertEqual(unicode(ctx_err.exception), u'Cannot resolve field "notexist"') + with pytest.raises(InvalidQueryError) as ctx_err: + assert Person.objects(settings__notexist="bar").first().id == p.id + assert unicode(ctx_err.exception) == u'Cannot resolve field "notexist"' # Test existing attribute - self.assertEqual(Person.objects(settings__base_foo="basefoo").first().id, p.id) - self.assertEqual(Person.objects(settings__sub_foo="subfoo").first().id, p.id) + assert Person.objects(settings__base_foo="basefoo").first().id == p.id + assert Person.objects(settings__sub_foo="subfoo").first().id == p.id only_p = Person.objects.only("settings.base_foo", "settings._cls").first() - self.assertEqual(only_p.settings.base_foo, "basefoo") - self.assertIsNone(only_p.settings.sub_foo) + assert only_p.settings.base_foo == "basefoo" + assert only_p.settings.sub_foo is None def test_query_list_embedded_document_with_inheritance(self): class Post(EmbeddedDocument): @@ -137,14 +138,14 @@ class TestEmbeddedDocumentField(MongoDBTestCase): record_text = Record(posts=[TextPost(content="a", title="foo")]).save() records = list(Record.objects(posts__author=record_movie.posts[0].author)) - self.assertEqual(len(records), 1) - self.assertEqual(records[0].id, record_movie.id) + assert len(records) == 1 + assert records[0].id == record_movie.id records = list(Record.objects(posts__content=record_text.posts[0].content)) - self.assertEqual(len(records), 1) - self.assertEqual(records[0].id, record_text.id) + assert len(records) == 1 + assert records[0].id == record_text.id - self.assertEqual(Record.objects(posts__title="foo").count(), 2) + assert Record.objects(posts__title="foo").count() == 2 class TestGenericEmbeddedDocumentField(MongoDBTestCase): @@ -167,13 +168,13 @@ class TestGenericEmbeddedDocumentField(MongoDBTestCase): person.save() person = Person.objects.first() - self.assertIsInstance(person.like, Car) + assert isinstance(person.like, Car) person.like = Dish(food="arroz", number=15) person.save() person = Person.objects.first() - self.assertIsInstance(person.like, Dish) + assert isinstance(person.like, Dish) def test_generic_embedded_document_choices(self): """Ensure you can limit GenericEmbeddedDocument choices.""" @@ -193,13 +194,14 @@ class TestGenericEmbeddedDocumentField(MongoDBTestCase): person = Person(name="Test User") person.like = Car(name="Fiat") - self.assertRaises(ValidationError, person.validate) + with pytest.raises(ValidationError): + person.validate() person.like = Dish(food="arroz", number=15) person.save() person = Person.objects.first() - self.assertIsInstance(person.like, Dish) + assert isinstance(person.like, Dish) def test_generic_list_embedded_document_choices(self): """Ensure you can limit GenericEmbeddedDocument choices inside @@ -221,13 +223,14 @@ class TestGenericEmbeddedDocumentField(MongoDBTestCase): person = Person(name="Test User") person.likes = [Car(name="Fiat")] - self.assertRaises(ValidationError, person.validate) + with pytest.raises(ValidationError): + person.validate() person.likes = [Dish(food="arroz", number=15)] person.save() person = Person.objects.first() - self.assertIsInstance(person.likes[0], Dish) + assert isinstance(person.likes[0], Dish) def test_choices_validation_documents(self): """ @@ -263,7 +266,8 @@ class TestGenericEmbeddedDocumentField(MongoDBTestCase): # Single Entry Failure post = BlogPost(comments=[ModeratorComments(author="mod1", message="message1")]) - self.assertRaises(ValidationError, post.save) + with pytest.raises(ValidationError): + post.save() # Mixed Entry Failure post = BlogPost( @@ -272,7 +276,8 @@ class TestGenericEmbeddedDocumentField(MongoDBTestCase): UserComments(author="user2", message="message2"), ] ) - self.assertRaises(ValidationError, post.save) + with pytest.raises(ValidationError): + post.save() def test_choices_validation_documents_inheritance(self): """ @@ -311,16 +316,16 @@ class TestGenericEmbeddedDocumentField(MongoDBTestCase): p2 = Person(settings=NonAdminSettings(foo2="bar2")).save() # Test non exiting attribute - with self.assertRaises(InvalidQueryError) as ctx_err: + with pytest.raises(InvalidQueryError) as ctx_err: Person.objects(settings__notexist="bar").first() - self.assertEqual(unicode(ctx_err.exception), u'Cannot resolve field "notexist"') + assert unicode(ctx_err.exception) == u'Cannot resolve field "notexist"' - with self.assertRaises(LookUpError): + with pytest.raises(LookUpError): Person.objects.only("settings.notexist") # Test existing attribute - self.assertEqual(Person.objects(settings__foo1="bar1").first().id, p1.id) - self.assertEqual(Person.objects(settings__foo2="bar2").first().id, p2.id) + assert Person.objects(settings__foo1="bar1").first().id == p1.id + assert Person.objects(settings__foo2="bar2").first().id == p2.id def test_query_generic_embedded_document_attribute_with_inheritance(self): class BaseSettings(EmbeddedDocument): @@ -339,10 +344,10 @@ class TestGenericEmbeddedDocumentField(MongoDBTestCase): p.save() # Test non exiting attribute - with self.assertRaises(InvalidQueryError) as ctx_err: - self.assertEqual(Person.objects(settings__notexist="bar").first().id, p.id) - self.assertEqual(unicode(ctx_err.exception), u'Cannot resolve field "notexist"') + with pytest.raises(InvalidQueryError) as ctx_err: + assert Person.objects(settings__notexist="bar").first().id == p.id + assert unicode(ctx_err.exception) == u'Cannot resolve field "notexist"' # Test existing attribute - self.assertEqual(Person.objects(settings__base_foo="basefoo").first().id, p.id) - self.assertEqual(Person.objects(settings__sub_foo="subfoo").first().id, p.id) + assert Person.objects(settings__base_foo="basefoo").first().id == p.id + assert Person.objects(settings__sub_foo="subfoo").first().id == p.id diff --git a/tests/fields/test_fields.py b/tests/fields/test_fields.py index bd2149e6..b27d95d2 100644 --- a/tests/fields/test_fields.py +++ b/tests/fields/test_fields.py @@ -39,6 +39,7 @@ from mongoengine.base import BaseField, EmbeddedDocumentList, _document_registry from mongoengine.errors import DeprecatedError from tests.utils import MongoDBTestCase +import pytest class TestField(MongoDBTestCase): @@ -58,25 +59,25 @@ class TestField(MongoDBTestCase): # Confirm saving now would store values data_to_be_saved = sorted(person.to_mongo().keys()) - self.assertEqual(data_to_be_saved, ["age", "created", "day", "name", "userid"]) + assert data_to_be_saved == ["age", "created", "day", "name", "userid"] - self.assertTrue(person.validate() is None) + assert person.validate() is None - self.assertEqual(person.name, person.name) - self.assertEqual(person.age, person.age) - self.assertEqual(person.userid, person.userid) - self.assertEqual(person.created, person.created) - self.assertEqual(person.day, person.day) + assert person.name == person.name + assert person.age == person.age + assert person.userid == person.userid + assert person.created == person.created + assert person.day == person.day - self.assertEqual(person._data["name"], person.name) - self.assertEqual(person._data["age"], person.age) - self.assertEqual(person._data["userid"], person.userid) - self.assertEqual(person._data["created"], person.created) - self.assertEqual(person._data["day"], person.day) + assert person._data["name"] == person.name + assert person._data["age"] == person.age + assert person._data["userid"] == person.userid + assert person._data["created"] == person.created + assert person._data["day"] == person.day # Confirm introspection changes nothing data_to_be_saved = sorted(person.to_mongo().keys()) - self.assertEqual(data_to_be_saved, ["age", "created", "day", "name", "userid"]) + assert data_to_be_saved == ["age", "created", "day", "name", "userid"] def test_custom_field_validation_raise_deprecated_error_when_validation_return_something( self, @@ -95,13 +96,13 @@ class TestField(MongoDBTestCase): "it should raise a ValidationError if validation fails" ) - with self.assertRaises(DeprecatedError) as ctx_err: + with pytest.raises(DeprecatedError) as ctx_err: Person(name="").validate() - self.assertEqual(str(ctx_err.exception), error) + assert str(ctx_err.exception) == error - with self.assertRaises(DeprecatedError) as ctx_err: + with pytest.raises(DeprecatedError) as ctx_err: Person(name="").save() - self.assertEqual(str(ctx_err.exception), error) + assert str(ctx_err.exception) == error def test_custom_field_validation_raise_validation_error(self): def _not_empty(z): @@ -113,18 +114,16 @@ class TestField(MongoDBTestCase): Person.drop_collection() - with self.assertRaises(ValidationError) as ctx_err: + with pytest.raises(ValidationError) as ctx_err: Person(name="").validate() - self.assertEqual( - "ValidationError (Person:None) (cantbeempty: ['name'])", - str(ctx_err.exception), + assert "ValidationError (Person:None) (cantbeempty: ['name'])" == str( + ctx_err.exception ) - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): Person(name="").save() - self.assertEqual( - "ValidationError (Person:None) (cantbeempty: ['name'])", - str(ctx_err.exception), + assert "ValidationError (Person:None) (cantbeempty: ['name'])" == str( + ctx_err.exception ) Person(name="garbage").validate() @@ -146,23 +145,23 @@ class TestField(MongoDBTestCase): # Confirm saving now would store values data_to_be_saved = sorted(person.to_mongo().keys()) - self.assertEqual(data_to_be_saved, ["age", "created", "userid"]) + assert data_to_be_saved == ["age", "created", "userid"] - self.assertTrue(person.validate() is None) + assert person.validate() is None - self.assertEqual(person.name, person.name) - self.assertEqual(person.age, person.age) - self.assertEqual(person.userid, person.userid) - self.assertEqual(person.created, person.created) + assert person.name == person.name + assert person.age == person.age + assert person.userid == person.userid + assert person.created == person.created - self.assertEqual(person._data["name"], person.name) - self.assertEqual(person._data["age"], person.age) - self.assertEqual(person._data["userid"], person.userid) - self.assertEqual(person._data["created"], person.created) + assert person._data["name"] == person.name + assert person._data["age"] == person.age + assert person._data["userid"] == person.userid + assert person._data["created"] == person.created # Confirm introspection changes nothing data_to_be_saved = sorted(person.to_mongo().keys()) - self.assertEqual(data_to_be_saved, ["age", "created", "userid"]) + assert data_to_be_saved == ["age", "created", "userid"] def test_default_values_when_setting_to_None(self): """Ensure that default field values are used when creating @@ -183,23 +182,23 @@ class TestField(MongoDBTestCase): # Confirm saving now would store values data_to_be_saved = sorted(person.to_mongo().keys()) - self.assertEqual(data_to_be_saved, ["age", "created", "userid"]) + assert data_to_be_saved == ["age", "created", "userid"] - self.assertTrue(person.validate() is None) + assert person.validate() is None - self.assertEqual(person.name, None) - self.assertEqual(person.age, 30) - self.assertEqual(person.userid, "test") - self.assertIsInstance(person.created, datetime.datetime) + assert person.name == None + assert person.age == 30 + assert person.userid == "test" + assert isinstance(person.created, datetime.datetime) - self.assertEqual(person._data["name"], person.name) - self.assertEqual(person._data["age"], person.age) - self.assertEqual(person._data["userid"], person.userid) - self.assertEqual(person._data["created"], person.created) + assert person._data["name"] == person.name + assert person._data["age"] == person.age + assert person._data["userid"] == person.userid + assert person._data["created"] == person.created # Confirm introspection changes nothing data_to_be_saved = sorted(person.to_mongo().keys()) - self.assertEqual(data_to_be_saved, ["age", "created", "userid"]) + assert data_to_be_saved == ["age", "created", "userid"] def test_default_value_is_not_used_when_changing_value_to_empty_list_for_strict_doc( self, @@ -213,7 +212,7 @@ class TestField(MongoDBTestCase): doc.x = [] doc.save() reloaded = Doc.objects.get(id=doc.id) - self.assertEqual(reloaded.x, []) + assert reloaded.x == [] def test_default_value_is_not_used_when_changing_value_to_empty_list_for_dyn_doc( self, @@ -228,7 +227,7 @@ class TestField(MongoDBTestCase): doc.y = 2 # Was triggering the bug doc.save() reloaded = Doc.objects.get(id=doc.id) - self.assertEqual(reloaded.x, []) + assert reloaded.x == [] def test_default_values_when_deleting_value(self): """Ensure that default field values are used after non-default @@ -253,24 +252,24 @@ class TestField(MongoDBTestCase): del person.created data_to_be_saved = sorted(person.to_mongo().keys()) - self.assertEqual(data_to_be_saved, ["age", "created", "userid"]) + assert data_to_be_saved == ["age", "created", "userid"] - self.assertTrue(person.validate() is None) + assert person.validate() is None - self.assertEqual(person.name, None) - self.assertEqual(person.age, 30) - self.assertEqual(person.userid, "test") - self.assertIsInstance(person.created, datetime.datetime) - self.assertNotEqual(person.created, datetime.datetime(2014, 6, 12)) + assert person.name == None + assert person.age == 30 + assert person.userid == "test" + assert isinstance(person.created, datetime.datetime) + assert person.created != datetime.datetime(2014, 6, 12) - self.assertEqual(person._data["name"], person.name) - self.assertEqual(person._data["age"], person.age) - self.assertEqual(person._data["userid"], person.userid) - self.assertEqual(person._data["created"], person.created) + assert person._data["name"] == person.name + assert person._data["age"] == person.age + assert person._data["userid"] == person.userid + assert person._data["created"] == person.created # Confirm introspection changes nothing data_to_be_saved = sorted(person.to_mongo().keys()) - self.assertEqual(data_to_be_saved, ["age", "created", "userid"]) + assert data_to_be_saved == ["age", "created", "userid"] def test_required_values(self): """Ensure that required field constraints are enforced.""" @@ -281,9 +280,11 @@ class TestField(MongoDBTestCase): userid = StringField() person = Person(name="Test User") - self.assertRaises(ValidationError, person.validate) + with pytest.raises(ValidationError): + person.validate() person = Person(age=30) - self.assertRaises(ValidationError, person.validate) + with pytest.raises(ValidationError): + person.validate() def test_not_required_handles_none_in_update(self): """Ensure that every fields should accept None if required is @@ -311,15 +312,15 @@ class TestField(MongoDBTestCase): set__flt_fld=None, set__comp_dt_fld=None, ) - self.assertEqual(res, 1) + assert res == 1 # Retrive data from db and verify it. ret = HandleNoneFields.objects.all()[0] - self.assertIsNone(ret.str_fld) - self.assertIsNone(ret.int_fld) - self.assertIsNone(ret.flt_fld) + assert ret.str_fld is None + assert ret.int_fld is None + assert ret.flt_fld is None - self.assertIsNone(ret.comp_dt_fld) + assert ret.comp_dt_fld is None def test_not_required_handles_none_from_database(self): """Ensure that every field can handle null values from the @@ -349,14 +350,15 @@ class TestField(MongoDBTestCase): # Retrive data from db and verify it. ret = HandleNoneFields.objects.first() - self.assertIsNone(ret.str_fld) - self.assertIsNone(ret.int_fld) - self.assertIsNone(ret.flt_fld) - self.assertIsNone(ret.comp_dt_fld) + assert ret.str_fld is None + assert ret.int_fld is None + assert ret.flt_fld is None + assert ret.comp_dt_fld is None # Retrieved object shouldn't pass validation when a re-save is # attempted. - self.assertRaises(ValidationError, ret.validate) + with pytest.raises(ValidationError): + ret.validate() def test_default_id_validation_as_objectid(self): """Ensure that invalid values cannot be assigned to an @@ -367,13 +369,15 @@ class TestField(MongoDBTestCase): name = StringField() person = Person(name="Test User") - self.assertEqual(person.id, None) + assert person.id == None person.id = 47 - self.assertRaises(ValidationError, person.validate) + with pytest.raises(ValidationError): + person.validate() person.id = "abc" - self.assertRaises(ValidationError, person.validate) + with pytest.raises(ValidationError): + person.validate() person.id = str(ObjectId()) person.validate() @@ -386,19 +390,22 @@ class TestField(MongoDBTestCase): userid = StringField(r"[0-9a-z_]+$") person = Person(name=34) - self.assertRaises(ValidationError, person.validate) + with pytest.raises(ValidationError): + person.validate() # Test regex validation on userid person = Person(userid="test.User") - self.assertRaises(ValidationError, person.validate) + with pytest.raises(ValidationError): + person.validate() person.userid = "test_user" - self.assertEqual(person.userid, "test_user") + assert person.userid == "test_user" person.validate() # Test max length validation on name person = Person(name="Name that is more than twenty characters") - self.assertRaises(ValidationError, person.validate) + with pytest.raises(ValidationError): + person.validate() person.name = "Shorter name" person.validate() @@ -407,19 +414,19 @@ class TestField(MongoDBTestCase): """Ensure that db_field doesn't accept invalid values.""" # dot in the name - with self.assertRaises(ValueError): + with pytest.raises(ValueError): class User(Document): name = StringField(db_field="user.name") # name starting with $ - with self.assertRaises(ValueError): + with pytest.raises(ValueError): class UserX1(Document): name = StringField(db_field="$name") # name containing a null character - with self.assertRaises(ValueError): + with pytest.raises(ValueError): class UserX2(Document): name = StringField(db_field="name\0") @@ -455,9 +462,11 @@ class TestField(MongoDBTestCase): post.validate() post.tags = "fun" - self.assertRaises(ValidationError, post.validate) + with pytest.raises(ValidationError): + post.validate() post.tags = [1, 2] - self.assertRaises(ValidationError, post.validate) + with pytest.raises(ValidationError): + post.validate() post.tags = ["fun", "leisure"] post.validate() @@ -465,30 +474,36 @@ class TestField(MongoDBTestCase): post.validate() post.access_list = "a,b" - self.assertRaises(ValidationError, post.validate) + with pytest.raises(ValidationError): + post.validate() post.access_list = ["c", "d"] - self.assertRaises(ValidationError, post.validate) + with pytest.raises(ValidationError): + post.validate() post.access_list = ["a", "b"] post.validate() - self.assertEqual(post.get_access_list_display(), u"Administration, Manager") + assert post.get_access_list_display() == u"Administration, Manager" post.comments = ["a"] - self.assertRaises(ValidationError, post.validate) + with pytest.raises(ValidationError): + post.validate() post.comments = "yay" - self.assertRaises(ValidationError, post.validate) + with pytest.raises(ValidationError): + post.validate() comments = [Comment(content="Good for you"), Comment(content="Yay.")] post.comments = comments post.validate() post.authors = [Comment()] - self.assertRaises(ValidationError, post.validate) + with pytest.raises(ValidationError): + post.validate() post.authors = [User()] - self.assertRaises(ValidationError, post.validate) + with pytest.raises(ValidationError): + post.validate() user = User() user.save() @@ -496,34 +511,42 @@ class TestField(MongoDBTestCase): post.validate() post.authors_as_lazy = [Comment()] - self.assertRaises(ValidationError, post.validate) + with pytest.raises(ValidationError): + post.validate() post.authors_as_lazy = [User()] - self.assertRaises(ValidationError, post.validate) + with pytest.raises(ValidationError): + post.validate() post.authors_as_lazy = [user] post.validate() post.generic = [1, 2] - self.assertRaises(ValidationError, post.validate) + with pytest.raises(ValidationError): + post.validate() post.generic = [User(), Comment()] - self.assertRaises(ValidationError, post.validate) + with pytest.raises(ValidationError): + post.validate() post.generic = [Comment()] - self.assertRaises(ValidationError, post.validate) + with pytest.raises(ValidationError): + post.validate() post.generic = [user] post.validate() post.generic_as_lazy = [1, 2] - self.assertRaises(ValidationError, post.validate) + with pytest.raises(ValidationError): + post.validate() post.generic_as_lazy = [User(), Comment()] - self.assertRaises(ValidationError, post.validate) + with pytest.raises(ValidationError): + post.validate() post.generic_as_lazy = [Comment()] - self.assertRaises(ValidationError, post.validate) + with pytest.raises(ValidationError): + post.validate() post.generic_as_lazy = [user] post.validate() @@ -549,7 +572,7 @@ class TestField(MongoDBTestCase): post.tags = ["leisure", "fun"] post.save() post.reload() - self.assertEqual(post.tags, ["fun", "leisure"]) + assert post.tags == ["fun", "leisure"] comment1 = Comment(content="Good for you", order=1) comment2 = Comment(content="Yay.", order=0) @@ -557,15 +580,15 @@ class TestField(MongoDBTestCase): post.comments = comments post.save() post.reload() - self.assertEqual(post.comments[0].content, comment2.content) - self.assertEqual(post.comments[1].content, comment1.content) + assert post.comments[0].content == comment2.content + assert post.comments[1].content == comment1.content post.comments[0].order = 2 post.save() post.reload() - self.assertEqual(post.comments[0].content, comment1.content) - self.assertEqual(post.comments[1].content, comment2.content) + assert post.comments[0].content == comment1.content + assert post.comments[1].content == comment2.content def test_reverse_list_sorting(self): """Ensure that a reverse sorted list field properly sorts values""" @@ -590,9 +613,9 @@ class TestField(MongoDBTestCase): catlist.save() catlist.reload() - self.assertEqual(catlist.categories[0].name, cat2.name) - self.assertEqual(catlist.categories[1].name, cat3.name) - self.assertEqual(catlist.categories[2].name, cat1.name) + assert catlist.categories[0].name == cat2.name + assert catlist.categories[1].name == cat3.name + assert catlist.categories[2].name == cat1.name def test_list_field(self): """Ensure that list types work as expected.""" @@ -604,10 +627,12 @@ class TestField(MongoDBTestCase): post = BlogPost() post.info = "my post" - self.assertRaises(ValidationError, post.validate) + with pytest.raises(ValidationError): + post.validate() post.info = {"title": "test"} - self.assertRaises(ValidationError, post.validate) + with pytest.raises(ValidationError): + post.validate() post.info = ["test"] post.save() @@ -620,15 +645,13 @@ class TestField(MongoDBTestCase): post.info = [{"test": 3}] post.save() - self.assertEqual(BlogPost.objects.count(), 3) - self.assertEqual(BlogPost.objects.filter(info__exact="test").count(), 1) - self.assertEqual(BlogPost.objects.filter(info__0__test="test").count(), 1) + assert BlogPost.objects.count() == 3 + assert BlogPost.objects.filter(info__exact="test").count() == 1 + assert BlogPost.objects.filter(info__0__test="test").count() == 1 # Confirm handles non strings or non existing keys - self.assertEqual(BlogPost.objects.filter(info__0__test__exact="5").count(), 0) - self.assertEqual( - BlogPost.objects.filter(info__100__test__exact="test").count(), 0 - ) + assert BlogPost.objects.filter(info__0__test__exact="5").count() == 0 + assert BlogPost.objects.filter(info__100__test__exact="test").count() == 0 # test queries by list post = BlogPost() @@ -637,12 +660,12 @@ class TestField(MongoDBTestCase): post = BlogPost.objects(info=["1", "2"]).get() post.info += ["3", "4"] post.save() - self.assertEqual(BlogPost.objects(info=["1", "2", "3", "4"]).count(), 1) + assert BlogPost.objects(info=["1", "2", "3", "4"]).count() == 1 post = BlogPost.objects(info=["1", "2", "3", "4"]).get() post.info *= 2 post.save() - self.assertEqual( - BlogPost.objects(info=["1", "2", "3", "4", "1", "2", "3", "4"]).count(), 1 + assert ( + BlogPost.objects(info=["1", "2", "3", "4", "1", "2", "3", "4"]).count() == 1 ) def test_list_field_manipulative_operators(self): @@ -670,165 +693,149 @@ class TestField(MongoDBTestCase): reset_post() temp = ["a", "b"] post.info = post.info + temp - self.assertEqual(post.info, ["0", "1", "2", "3", "4", "5", "a", "b"]) + assert post.info == ["0", "1", "2", "3", "4", "5", "a", "b"] post.save() post.reload() - self.assertEqual(post.info, ["0", "1", "2", "3", "4", "5", "a", "b"]) + assert post.info == ["0", "1", "2", "3", "4", "5", "a", "b"] # '__delitem__(index)' # aka 'del list[index]' # aka 'operator.delitem(list, index)' reset_post() del post.info[2] # del from middle ('2') - self.assertEqual(post.info, ["0", "1", "3", "4", "5"]) + assert post.info == ["0", "1", "3", "4", "5"] post.save() post.reload() - self.assertEqual(post.info, ["0", "1", "3", "4", "5"]) + assert post.info == ["0", "1", "3", "4", "5"] # '__delitem__(slice(i, j))' # aka 'del list[i:j]' # aka 'operator.delitem(list, slice(i,j))' reset_post() del post.info[1:3] # removes '1', '2' - self.assertEqual(post.info, ["0", "3", "4", "5"]) + assert post.info == ["0", "3", "4", "5"] post.save() post.reload() - self.assertEqual(post.info, ["0", "3", "4", "5"]) + assert post.info == ["0", "3", "4", "5"] # '__iadd__' # aka 'list += list' reset_post() temp = ["a", "b"] post.info += temp - self.assertEqual(post.info, ["0", "1", "2", "3", "4", "5", "a", "b"]) + assert post.info == ["0", "1", "2", "3", "4", "5", "a", "b"] post.save() post.reload() - self.assertEqual(post.info, ["0", "1", "2", "3", "4", "5", "a", "b"]) + assert post.info == ["0", "1", "2", "3", "4", "5", "a", "b"] # '__imul__' # aka 'list *= number' reset_post() post.info *= 2 - self.assertEqual( - post.info, ["0", "1", "2", "3", "4", "5", "0", "1", "2", "3", "4", "5"] - ) + assert post.info == ["0", "1", "2", "3", "4", "5", "0", "1", "2", "3", "4", "5"] post.save() post.reload() - self.assertEqual( - post.info, ["0", "1", "2", "3", "4", "5", "0", "1", "2", "3", "4", "5"] - ) + assert post.info == ["0", "1", "2", "3", "4", "5", "0", "1", "2", "3", "4", "5"] # '__mul__' # aka 'listA*listB' reset_post() post.info = post.info * 2 - self.assertEqual( - post.info, ["0", "1", "2", "3", "4", "5", "0", "1", "2", "3", "4", "5"] - ) + assert post.info == ["0", "1", "2", "3", "4", "5", "0", "1", "2", "3", "4", "5"] post.save() post.reload() - self.assertEqual( - post.info, ["0", "1", "2", "3", "4", "5", "0", "1", "2", "3", "4", "5"] - ) + assert post.info == ["0", "1", "2", "3", "4", "5", "0", "1", "2", "3", "4", "5"] # '__rmul__' # aka 'listB*listA' reset_post() post.info = 2 * post.info - self.assertEqual( - post.info, ["0", "1", "2", "3", "4", "5", "0", "1", "2", "3", "4", "5"] - ) + assert post.info == ["0", "1", "2", "3", "4", "5", "0", "1", "2", "3", "4", "5"] post.save() post.reload() - self.assertEqual( - post.info, ["0", "1", "2", "3", "4", "5", "0", "1", "2", "3", "4", "5"] - ) + assert post.info == ["0", "1", "2", "3", "4", "5", "0", "1", "2", "3", "4", "5"] # '__setitem__(index, value)' # aka 'list[index]=value' # aka 'setitem(list, value)' reset_post() post.info[4] = "a" - self.assertEqual(post.info, ["0", "1", "2", "3", "a", "5"]) + assert post.info == ["0", "1", "2", "3", "a", "5"] post.save() post.reload() - self.assertEqual(post.info, ["0", "1", "2", "3", "a", "5"]) + assert post.info == ["0", "1", "2", "3", "a", "5"] # __setitem__(index, value) with a negative index reset_post() post.info[-2] = "a" - self.assertEqual(post.info, ["0", "1", "2", "3", "a", "5"]) + assert post.info == ["0", "1", "2", "3", "a", "5"] post.save() post.reload() - self.assertEqual(post.info, ["0", "1", "2", "3", "a", "5"]) + assert post.info == ["0", "1", "2", "3", "a", "5"] # '__setitem__(slice(i, j), listB)' # aka 'listA[i:j] = listB' # aka 'setitem(listA, slice(i, j), listB)' reset_post() post.info[1:3] = ["h", "e", "l", "l", "o"] - self.assertEqual(post.info, ["0", "h", "e", "l", "l", "o", "3", "4", "5"]) + assert post.info == ["0", "h", "e", "l", "l", "o", "3", "4", "5"] post.save() post.reload() - self.assertEqual(post.info, ["0", "h", "e", "l", "l", "o", "3", "4", "5"]) + assert post.info == ["0", "h", "e", "l", "l", "o", "3", "4", "5"] # '__setitem__(slice(i, j), listB)' with negative i and j reset_post() post.info[-5:-3] = ["h", "e", "l", "l", "o"] - self.assertEqual(post.info, ["0", "h", "e", "l", "l", "o", "3", "4", "5"]) + assert post.info == ["0", "h", "e", "l", "l", "o", "3", "4", "5"] post.save() post.reload() - self.assertEqual(post.info, ["0", "h", "e", "l", "l", "o", "3", "4", "5"]) + assert post.info == ["0", "h", "e", "l", "l", "o", "3", "4", "5"] # negative # 'append' reset_post() post.info.append("h") - self.assertEqual(post.info, ["0", "1", "2", "3", "4", "5", "h"]) + assert post.info == ["0", "1", "2", "3", "4", "5", "h"] post.save() post.reload() - self.assertEqual(post.info, ["0", "1", "2", "3", "4", "5", "h"]) + assert post.info == ["0", "1", "2", "3", "4", "5", "h"] # 'extend' reset_post() post.info.extend(["h", "e", "l", "l", "o"]) - self.assertEqual( - post.info, ["0", "1", "2", "3", "4", "5", "h", "e", "l", "l", "o"] - ) + assert post.info == ["0", "1", "2", "3", "4", "5", "h", "e", "l", "l", "o"] post.save() post.reload() - self.assertEqual( - post.info, ["0", "1", "2", "3", "4", "5", "h", "e", "l", "l", "o"] - ) + assert post.info == ["0", "1", "2", "3", "4", "5", "h", "e", "l", "l", "o"] # 'insert' # 'pop' reset_post() x = post.info.pop(2) y = post.info.pop() - self.assertEqual(post.info, ["0", "1", "3", "4"]) - self.assertEqual(x, "2") - self.assertEqual(y, "5") + assert post.info == ["0", "1", "3", "4"] + assert x == "2" + assert y == "5" post.save() post.reload() - self.assertEqual(post.info, ["0", "1", "3", "4"]) + assert post.info == ["0", "1", "3", "4"] # 'remove' reset_post() post.info.remove("2") - self.assertEqual(post.info, ["0", "1", "3", "4", "5"]) + assert post.info == ["0", "1", "3", "4", "5"] post.save() post.reload() - self.assertEqual(post.info, ["0", "1", "3", "4", "5"]) + assert post.info == ["0", "1", "3", "4", "5"] # 'reverse' reset_post() post.info.reverse() - self.assertEqual(post.info, ["5", "4", "3", "2", "1", "0"]) + assert post.info == ["5", "4", "3", "2", "1", "0"] post.save() post.reload() - self.assertEqual(post.info, ["5", "4", "3", "2", "1", "0"]) + assert post.info == ["5", "4", "3", "2", "1", "0"] # 'sort': though this operator method does manipulate the list, it is # tested in the 'test_list_field_lexicograpic_operators' function @@ -844,7 +851,8 @@ class TestField(MongoDBTestCase): # '__hash__' # aka 'hash(list)' - self.assertRaises(TypeError, lambda: hash(post.info)) + with pytest.raises(TypeError): + hash(post.info) def test_list_field_lexicographic_operators(self): """Ensure that ListField works with standard list operators that @@ -883,32 +891,32 @@ class TestField(MongoDBTestCase): blogLargeB.reload() # '__eq__' aka '==' - self.assertEqual(blogLargeA.text_info, blogLargeB.text_info) - self.assertEqual(blogLargeA.bool_info, blogLargeB.bool_info) + assert blogLargeA.text_info == blogLargeB.text_info + assert blogLargeA.bool_info == blogLargeB.bool_info # '__ge__' aka '>=' - self.assertGreaterEqual(blogLargeA.text_info, blogSmall.text_info) - self.assertGreaterEqual(blogLargeA.text_info, blogLargeB.text_info) - self.assertGreaterEqual(blogLargeA.bool_info, blogSmall.bool_info) - self.assertGreaterEqual(blogLargeA.bool_info, blogLargeB.bool_info) + assert blogLargeA.text_info >= blogSmall.text_info + assert blogLargeA.text_info >= blogLargeB.text_info + assert blogLargeA.bool_info >= blogSmall.bool_info + assert blogLargeA.bool_info >= blogLargeB.bool_info # '__gt__' aka '>' - self.assertGreaterEqual(blogLargeA.text_info, blogSmall.text_info) - self.assertGreaterEqual(blogLargeA.bool_info, blogSmall.bool_info) + assert blogLargeA.text_info >= blogSmall.text_info + assert blogLargeA.bool_info >= blogSmall.bool_info # '__le__' aka '<=' - self.assertLessEqual(blogSmall.text_info, blogLargeB.text_info) - self.assertLessEqual(blogLargeA.text_info, blogLargeB.text_info) - self.assertLessEqual(blogSmall.bool_info, blogLargeB.bool_info) - self.assertLessEqual(blogLargeA.bool_info, blogLargeB.bool_info) + assert blogSmall.text_info <= blogLargeB.text_info + assert blogLargeA.text_info <= blogLargeB.text_info + assert blogSmall.bool_info <= blogLargeB.bool_info + assert blogLargeA.bool_info <= blogLargeB.bool_info # '__lt__' aka '<' - self.assertLess(blogSmall.text_info, blogLargeB.text_info) - self.assertLess(blogSmall.bool_info, blogLargeB.bool_info) + assert blogSmall.text_info < blogLargeB.text_info + assert blogSmall.bool_info < blogLargeB.bool_info # '__ne__' aka '!=' - self.assertNotEqual(blogSmall.text_info, blogLargeB.text_info) - self.assertNotEqual(blogSmall.bool_info, blogLargeB.bool_info) + assert blogSmall.text_info != blogLargeB.text_info + assert blogSmall.bool_info != blogLargeB.bool_info # 'sort' blogLargeB.bool_info = [True, False, True, False] @@ -920,14 +928,14 @@ class TestField(MongoDBTestCase): ObjectId("54495ad94c934721ede76d23"), ObjectId("54495ad94c934721ede76f90"), ] - self.assertEqual(blogLargeB.text_info, ["a", "j", "z"]) - self.assertEqual(blogLargeB.oid_info, sorted_target_list) - self.assertEqual(blogLargeB.bool_info, [False, False, True, True]) + assert blogLargeB.text_info == ["a", "j", "z"] + assert blogLargeB.oid_info == sorted_target_list + assert blogLargeB.bool_info == [False, False, True, True] blogLargeB.save() blogLargeB.reload() - self.assertEqual(blogLargeB.text_info, ["a", "j", "z"]) - self.assertEqual(blogLargeB.oid_info, sorted_target_list) - self.assertEqual(blogLargeB.bool_info, [False, False, True, True]) + assert blogLargeB.text_info == ["a", "j", "z"] + assert blogLargeB.oid_info == sorted_target_list + assert blogLargeB.bool_info == [False, False, True, True] def test_list_assignment(self): """Ensure that list field element assignment and slicing work.""" @@ -944,37 +952,37 @@ class TestField(MongoDBTestCase): post.info[0] = 1 post.save() post.reload() - self.assertEqual(post.info[0], 1) + assert post.info[0] == 1 post.info[1:3] = ["n2", "n3"] post.save() post.reload() - self.assertEqual(post.info, [1, "n2", "n3", "4", 5]) + assert post.info == [1, "n2", "n3", "4", 5] post.info[-1] = "n5" post.save() post.reload() - self.assertEqual(post.info, [1, "n2", "n3", "4", "n5"]) + assert post.info == [1, "n2", "n3", "4", "n5"] post.info[-2] = 4 post.save() post.reload() - self.assertEqual(post.info, [1, "n2", "n3", 4, "n5"]) + assert post.info == [1, "n2", "n3", 4, "n5"] post.info[1:-1] = [2] post.save() post.reload() - self.assertEqual(post.info, [1, 2, "n5"]) + assert post.info == [1, 2, "n5"] post.info[:-1] = [1, "n2", "n3", 4] post.save() post.reload() - self.assertEqual(post.info, [1, "n2", "n3", 4, "n5"]) + assert post.info == [1, "n2", "n3", 4, "n5"] post.info[-4:3] = [2, 3] post.save() post.reload() - self.assertEqual(post.info, [1, 2, 3, 4, "n5"]) + assert post.info == [1, 2, 3, 4, "n5"] def test_list_field_passed_in_value(self): class Foo(Document): @@ -988,7 +996,7 @@ class TestField(MongoDBTestCase): foo = Foo(bars=[]) foo.bars.append(bar) - self.assertEqual(repr(foo.bars), "[]") + assert repr(foo.bars) == "[]" def test_list_field_strict(self): """Ensure that list field handles validation if provided @@ -1005,7 +1013,7 @@ class TestField(MongoDBTestCase): e.save() # try creating an invalid mapping - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): e.mapping = ["abc"] e.save() @@ -1021,9 +1029,9 @@ class TestField(MongoDBTestCase): if i < 6: foo.save() else: - with self.assertRaises(ValidationError) as cm: + with pytest.raises(ValidationError) as cm: foo.save() - self.assertIn("List is too long", str(cm.exception)) + assert "List is too long" in str(cm.exception) def test_list_field_max_length_set_operator(self): """Ensure ListField's max_length is respected for a "set" operator.""" @@ -1032,9 +1040,9 @@ class TestField(MongoDBTestCase): items = ListField(IntField(), max_length=3) foo = Foo.objects.create(items=[1, 2, 3]) - with self.assertRaises(ValidationError) as cm: + with pytest.raises(ValidationError) as cm: foo.modify(set__items=[1, 2, 3, 4]) - self.assertIn("List is too long", str(cm.exception)) + assert "List is too long" in str(cm.exception) def test_list_field_rejects_strings(self): """Strings aren't valid list field data types.""" @@ -1046,7 +1054,8 @@ class TestField(MongoDBTestCase): e = Simple() e.mapping = "hello world" - self.assertRaises(ValidationError, e.save) + with pytest.raises(ValidationError): + e.save() def test_complex_field_required(self): """Ensure required cant be None / Empty.""" @@ -1058,7 +1067,8 @@ class TestField(MongoDBTestCase): e = Simple() e.mapping = [] - self.assertRaises(ValidationError, e.save) + with pytest.raises(ValidationError): + e.save() class Simple(Document): mapping = DictField(required=True) @@ -1066,7 +1076,8 @@ class TestField(MongoDBTestCase): Simple.drop_collection() e = Simple() e.mapping = {} - self.assertRaises(ValidationError, e.save) + with pytest.raises(ValidationError): + e.save() def test_complex_field_same_value_not_changed(self): """If a complex field is set to the same value, it should not @@ -1080,7 +1091,7 @@ class TestField(MongoDBTestCase): e = Simple().save() e.mapping = [] - self.assertEqual([], e._changed_fields) + assert [] == e._changed_fields class Simple(Document): mapping = DictField() @@ -1089,7 +1100,7 @@ class TestField(MongoDBTestCase): e = Simple().save() e.mapping = {} - self.assertEqual([], e._changed_fields) + assert [] == e._changed_fields def test_slice_marks_field_as_changed(self): class Simple(Document): @@ -1097,11 +1108,11 @@ class TestField(MongoDBTestCase): simple = Simple(widgets=[1, 2, 3, 4]).save() simple.widgets[:3] = [] - self.assertEqual(["widgets"], simple._changed_fields) + assert ["widgets"] == simple._changed_fields simple.save() simple = simple.reload() - self.assertEqual(simple.widgets, [4]) + assert simple.widgets == [4] def test_del_slice_marks_field_as_changed(self): class Simple(Document): @@ -1109,11 +1120,11 @@ class TestField(MongoDBTestCase): simple = Simple(widgets=[1, 2, 3, 4]).save() del simple.widgets[:3] - self.assertEqual(["widgets"], simple._changed_fields) + assert ["widgets"] == simple._changed_fields simple.save() simple = simple.reload() - self.assertEqual(simple.widgets, [4]) + assert simple.widgets == [4] def test_list_field_with_negative_indices(self): class Simple(Document): @@ -1121,11 +1132,11 @@ class TestField(MongoDBTestCase): simple = Simple(widgets=[1, 2, 3, 4]).save() simple.widgets[-1] = 5 - self.assertEqual(["widgets.3"], simple._changed_fields) + assert ["widgets.3"] == simple._changed_fields simple.save() simple = simple.reload() - self.assertEqual(simple.widgets, [1, 2, 3, 5]) + assert simple.widgets == [1, 2, 3, 5] def test_list_field_complex(self): """Ensure that the list fields can handle the complex types.""" @@ -1159,33 +1170,23 @@ class TestField(MongoDBTestCase): e.save() e2 = Simple.objects.get(id=e.id) - self.assertIsInstance(e2.mapping[0], StringSetting) - self.assertIsInstance(e2.mapping[1], IntegerSetting) + assert isinstance(e2.mapping[0], StringSetting) + assert isinstance(e2.mapping[1], IntegerSetting) # Test querying - self.assertEqual(Simple.objects.filter(mapping__1__value=42).count(), 1) - self.assertEqual(Simple.objects.filter(mapping__2__number=1).count(), 1) - self.assertEqual( - Simple.objects.filter(mapping__2__complex__value=42).count(), 1 - ) - self.assertEqual( - Simple.objects.filter(mapping__2__list__0__value=42).count(), 1 - ) - self.assertEqual( - Simple.objects.filter(mapping__2__list__1__value="foo").count(), 1 - ) + assert Simple.objects.filter(mapping__1__value=42).count() == 1 + assert Simple.objects.filter(mapping__2__number=1).count() == 1 + assert Simple.objects.filter(mapping__2__complex__value=42).count() == 1 + assert Simple.objects.filter(mapping__2__list__0__value=42).count() == 1 + assert Simple.objects.filter(mapping__2__list__1__value="foo").count() == 1 # Confirm can update Simple.objects().update(set__mapping__1=IntegerSetting(value=10)) - self.assertEqual(Simple.objects.filter(mapping__1__value=10).count(), 1) + assert Simple.objects.filter(mapping__1__value=10).count() == 1 Simple.objects().update(set__mapping__2__list__1=StringSetting(value="Boo")) - self.assertEqual( - Simple.objects.filter(mapping__2__list__1__value="foo").count(), 0 - ) - self.assertEqual( - Simple.objects.filter(mapping__2__list__1__value="Boo").count(), 1 - ) + assert Simple.objects.filter(mapping__2__list__1__value="foo").count() == 0 + assert Simple.objects.filter(mapping__2__list__1__value="Boo").count() == 1 def test_embedded_db_field(self): class Embedded(EmbeddedDocument): @@ -1203,9 +1204,9 @@ class TestField(MongoDBTestCase): Test.objects.update_one(inc__embedded__number=1) test = Test.objects.get() - self.assertEqual(test.embedded.number, 2) + assert test.embedded.number == 2 doc = self.db.test.find_one() - self.assertEqual(doc["x"]["i"], 2) + assert doc["x"]["i"] == 2 def test_double_embedded_db_field(self): """Make sure multiple layers of embedded docs resolve db fields @@ -1242,7 +1243,7 @@ class TestField(MongoDBTestCase): b = EmbeddedDocumentField(B, db_field="fb") a = A._from_son(SON([("fb", SON([("fc", SON([("txt", "hi")]))]))])) - self.assertEqual(a.b.c.txt, "hi") + assert a.b.c.txt == "hi" def test_embedded_document_field_cant_reference_using_a_str_if_it_does_not_exist_yet( self, @@ -1277,18 +1278,21 @@ class TestField(MongoDBTestCase): person = Person(name="Test User") person.preferences = "My Preferences" - self.assertRaises(ValidationError, person.validate) + with pytest.raises(ValidationError): + person.validate() # Check that only the right embedded doc works person.preferences = Comment(content="Nice blog post...") - self.assertRaises(ValidationError, person.validate) + with pytest.raises(ValidationError): + person.validate() # Check that the embedded doc is valid person.preferences = PersonPreferences() - self.assertRaises(ValidationError, person.validate) + with pytest.raises(ValidationError): + person.validate() person.preferences = PersonPreferences(food="Cheese", number=47) - self.assertEqual(person.preferences.food, "Cheese") + assert person.preferences.food == "Cheese" person.validate() def test_embedded_document_inheritance(self): @@ -1314,7 +1318,7 @@ class TestField(MongoDBTestCase): post.author = PowerUser(name="Test User", power=47) post.save() - self.assertEqual(47, BlogPost.objects.first().author.power) + assert 47 == BlogPost.objects.first().author.power def test_embedded_document_inheritance_with_list(self): """Ensure that nested list of subclassed embedded documents is @@ -1339,7 +1343,7 @@ class TestField(MongoDBTestCase): foobar = User(groups=[group]) foobar.save() - self.assertEqual(content, User.objects.first().groups[0].content) + assert content == User.objects.first().groups[0].content def test_reference_miss(self): """Ensure an exception is raised when dereferencing an unknown @@ -1362,16 +1366,18 @@ class TestField(MongoDBTestCase): # Reference is no longer valid foo.delete() bar = Bar.objects.get() - self.assertRaises(DoesNotExist, getattr, bar, "ref") - self.assertRaises(DoesNotExist, getattr, bar, "generic_ref") + with pytest.raises(DoesNotExist): + getattr(bar, "ref") + with pytest.raises(DoesNotExist): + getattr(bar, "generic_ref") # When auto_dereference is disabled, there is no trouble returning DBRef bar = Bar.objects.get() expected = foo.to_dbref() bar._fields["ref"]._auto_dereference = False - self.assertEqual(bar.ref, expected) + assert bar.ref == expected bar._fields["generic_ref"]._auto_dereference = False - self.assertEqual(bar.generic_ref, {"_ref": expected, "_cls": "Foo"}) + assert bar.generic_ref == {"_ref": expected, "_cls": "Foo"} def test_list_item_dereference(self): """Ensure that DBRef items in ListFields are dereferenced. @@ -1396,8 +1402,8 @@ class TestField(MongoDBTestCase): group_obj = Group.objects.first() - self.assertEqual(group_obj.members[0].name, user1.name) - self.assertEqual(group_obj.members[1].name, user2.name) + assert group_obj.members[0].name == user1.name + assert group_obj.members[1].name == user2.name def test_recursive_reference(self): """Ensure that ReferenceFields can reference their own documents. @@ -1424,8 +1430,8 @@ class TestField(MongoDBTestCase): peter.save() peter = Employee.objects.with_id(peter.id) - self.assertEqual(peter.boss, bill) - self.assertEqual(peter.friends, friends) + assert peter.boss == bill + assert peter.friends == friends def test_recursive_embedding(self): """Ensure that EmbeddedDocumentFields can contain their own documents. @@ -1450,18 +1456,18 @@ class TestField(MongoDBTestCase): tree.save() tree = Tree.objects.first() - self.assertEqual(len(tree.children), 1) + assert len(tree.children) == 1 - self.assertEqual(len(tree.children[0].children), 1) + assert len(tree.children[0].children) == 1 third_child = TreeNode(name="Child 3") tree.children[0].children.append(third_child) tree.save() - self.assertEqual(len(tree.children), 1) - self.assertEqual(tree.children[0].name, first_child.name) - self.assertEqual(tree.children[0].children[0].name, second_child.name) - self.assertEqual(tree.children[0].children[1].name, third_child.name) + assert len(tree.children) == 1 + assert tree.children[0].name == first_child.name + assert tree.children[0].children[0].name == second_child.name + assert tree.children[0].children[1].name == third_child.name # Test updating tree.children[0].name = "I am Child 1" @@ -1469,28 +1475,28 @@ class TestField(MongoDBTestCase): tree.children[0].children[1].name = "I am Child 3" tree.save() - self.assertEqual(tree.children[0].name, "I am Child 1") - self.assertEqual(tree.children[0].children[0].name, "I am Child 2") - self.assertEqual(tree.children[0].children[1].name, "I am Child 3") + assert tree.children[0].name == "I am Child 1" + assert tree.children[0].children[0].name == "I am Child 2" + assert tree.children[0].children[1].name == "I am Child 3" # Test removal - self.assertEqual(len(tree.children[0].children), 2) + assert len(tree.children[0].children) == 2 del tree.children[0].children[1] tree.save() - self.assertEqual(len(tree.children[0].children), 1) + assert len(tree.children[0].children) == 1 tree.children[0].children.pop(0) tree.save() - self.assertEqual(len(tree.children[0].children), 0) - self.assertEqual(tree.children[0].children, []) + assert len(tree.children[0].children) == 0 + assert tree.children[0].children == [] tree.children[0].children.insert(0, third_child) tree.children[0].children.insert(0, second_child) tree.save() - self.assertEqual(len(tree.children[0].children), 2) - self.assertEqual(tree.children[0].children[0].name, second_child.name) - self.assertEqual(tree.children[0].children[1].name, third_child.name) + assert len(tree.children[0].children) == 2 + assert tree.children[0].children[0].name == second_child.name + assert tree.children[0].children[1].name == third_child.name def test_drop_abstract_document(self): """Ensure that an abstract document cannot be dropped given it @@ -1501,7 +1507,8 @@ class TestField(MongoDBTestCase): name = StringField() meta = {"abstract": True} - self.assertRaises(OperationError, AbstractDoc.drop_collection) + with pytest.raises(OperationError): + AbstractDoc.drop_collection() def test_reference_class_with_abstract_parent(self): """Ensure that a class with an abstract parent can be referenced. @@ -1525,7 +1532,7 @@ class TestField(MongoDBTestCase): brother = Brother(name="Bob", sibling=sister) brother.save() - self.assertEqual(Brother.objects[0].sibling.name, sister.name) + assert Brother.objects[0].sibling.name == sister.name def test_reference_abstract_class(self): """Ensure that an abstract class instance cannot be used in the @@ -1547,7 +1554,8 @@ class TestField(MongoDBTestCase): sister = Sibling(name="Alice") brother = Brother(name="Bob", sibling=sister) - self.assertRaises(ValidationError, brother.save) + with pytest.raises(ValidationError): + brother.save() def test_abstract_reference_base_type(self): """Ensure that an an abstract reference fails validation when given a @@ -1570,7 +1578,8 @@ class TestField(MongoDBTestCase): mother = Mother(name="Carol") mother.save() brother = Brother(name="Bob", sibling=mother) - self.assertRaises(ValidationError, brother.save) + with pytest.raises(ValidationError): + brother.save() def test_generic_reference(self): """Ensure that a GenericReferenceField properly dereferences items. @@ -1601,16 +1610,16 @@ class TestField(MongoDBTestCase): bm = Bookmark.objects(bookmark_object=post_1).first() - self.assertEqual(bm.bookmark_object, post_1) - self.assertIsInstance(bm.bookmark_object, Post) + assert bm.bookmark_object == post_1 + assert isinstance(bm.bookmark_object, Post) bm.bookmark_object = link_1 bm.save() bm = Bookmark.objects(bookmark_object=link_1).first() - self.assertEqual(bm.bookmark_object, link_1) - self.assertIsInstance(bm.bookmark_object, Link) + assert bm.bookmark_object == link_1 + assert isinstance(bm.bookmark_object, Link) def test_generic_reference_list(self): """Ensure that a ListField properly dereferences generic references. @@ -1640,8 +1649,8 @@ class TestField(MongoDBTestCase): user = User.objects(bookmarks__all=[post_1, link_1]).first() - self.assertEqual(user.bookmarks[0], post_1) - self.assertEqual(user.bookmarks[1], link_1) + assert user.bookmarks[0] == post_1 + assert user.bookmarks[1] == link_1 def test_generic_reference_document_not_registered(self): """Ensure dereferencing out of the document registry throws a @@ -1682,7 +1691,7 @@ class TestField(MongoDBTestCase): Person.drop_collection() Person(name="Wilson Jr").save() - self.assertEqual(repr(Person.objects(city=None)), "[]") + assert repr(Person.objects(city=None)) == "[]" def test_generic_reference_choices(self): """Ensure that a GenericReferenceField can handle choices.""" @@ -1707,13 +1716,14 @@ class TestField(MongoDBTestCase): post_1.save() bm = Bookmark(bookmark_object=link_1) - self.assertRaises(ValidationError, bm.validate) + with pytest.raises(ValidationError): + bm.validate() bm = Bookmark(bookmark_object=post_1) bm.save() bm = Bookmark.objects.first() - self.assertEqual(bm.bookmark_object, post_1) + assert bm.bookmark_object == post_1 def test_generic_reference_string_choices(self): """Ensure that a GenericReferenceField can handle choices as strings @@ -1745,7 +1755,8 @@ class TestField(MongoDBTestCase): bm.save() bm = Bookmark(bookmark_object=bm) - self.assertRaises(ValidationError, bm.validate) + with pytest.raises(ValidationError): + bm.validate() def test_generic_reference_choices_no_dereference(self): """Ensure that a GenericReferenceField can handle choices on @@ -1798,13 +1809,14 @@ class TestField(MongoDBTestCase): post_1.save() user = User(bookmarks=[link_1]) - self.assertRaises(ValidationError, user.validate) + with pytest.raises(ValidationError): + user.validate() user = User(bookmarks=[post_1]) user.save() user = User.objects.first() - self.assertEqual(user.bookmarks, [post_1]) + assert user.bookmarks == [post_1] def test_generic_reference_list_item_modification(self): """Ensure that modifications of related documents (through generic reference) don't influence on querying @@ -1832,8 +1844,8 @@ class TestField(MongoDBTestCase): user = User.objects(bookmarks__all=[post_1]).first() - self.assertNotEqual(user, None) - self.assertEqual(user.bookmarks[0], post_1) + assert user != None + assert user.bookmarks[0] == post_1 def test_generic_reference_filter_by_dbref(self): """Ensure we can search for a specific generic reference by @@ -1849,7 +1861,7 @@ class TestField(MongoDBTestCase): doc2 = Doc.objects.create(ref=doc1) doc = Doc.objects.get(ref=DBRef("doc", doc1.pk)) - self.assertEqual(doc, doc2) + assert doc == doc2 def test_generic_reference_is_not_tracked_in_parent_doc(self): """Ensure that modifications of related documents (through generic reference) don't influence @@ -1871,11 +1883,11 @@ class TestField(MongoDBTestCase): doc2 = Doc2(ref=doc1, refs=[doc11]).save() doc2.ref.name = "garbage2" - self.assertEqual(doc2._get_changed_fields(), []) + assert doc2._get_changed_fields() == [] doc2.refs[0].name = "garbage3" - self.assertEqual(doc2._get_changed_fields(), []) - self.assertEqual(doc2._delta(), ({}, {})) + assert doc2._get_changed_fields() == [] + assert doc2._delta() == ({}, {}) def test_generic_reference_field(self): """Ensure we can search for a specific generic reference by @@ -1890,10 +1902,10 @@ class TestField(MongoDBTestCase): doc1 = Doc.objects.create() doc2 = Doc.objects.create(ref=doc1) - self.assertIsInstance(doc1.pk, ObjectId) + assert isinstance(doc1.pk, ObjectId) doc = Doc.objects.get(ref=doc1.pk) - self.assertEqual(doc, doc2) + assert doc == doc2 def test_choices_allow_using_sets_as_choices(self): """Ensure that sets can be used when setting choices @@ -1933,7 +1945,7 @@ class TestField(MongoDBTestCase): size = StringField(choices=("S", "M")) shirt = Shirt(size="XS") - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): shirt.validate() def test_choices_get_field_display(self): @@ -1964,24 +1976,25 @@ class TestField(MongoDBTestCase): shirt2 = Shirt() # Make sure get__display returns the default value (or None) - self.assertEqual(shirt1.get_size_display(), None) - self.assertEqual(shirt1.get_style_display(), "Wide") + assert shirt1.get_size_display() == None + assert shirt1.get_style_display() == "Wide" shirt1.size = "XXL" shirt1.style = "B" shirt2.size = "M" shirt2.style = "S" - self.assertEqual(shirt1.get_size_display(), "Extra Extra Large") - self.assertEqual(shirt1.get_style_display(), "Baggy") - self.assertEqual(shirt2.get_size_display(), "Medium") - self.assertEqual(shirt2.get_style_display(), "Small") + assert shirt1.get_size_display() == "Extra Extra Large" + assert shirt1.get_style_display() == "Baggy" + assert shirt2.get_size_display() == "Medium" + assert shirt2.get_style_display() == "Small" # Set as Z - an invalid choice shirt1.size = "Z" shirt1.style = "Z" - self.assertEqual(shirt1.get_size_display(), "Z") - self.assertEqual(shirt1.get_style_display(), "Z") - self.assertRaises(ValidationError, shirt1.validate) + assert shirt1.get_size_display() == "Z" + assert shirt1.get_style_display() == "Z" + with pytest.raises(ValidationError): + shirt1.validate() def test_simple_choices_validation(self): """Ensure that value is in a container of allowed values. @@ -1999,7 +2012,8 @@ class TestField(MongoDBTestCase): shirt.validate() shirt.size = "XS" - self.assertRaises(ValidationError, shirt.validate) + with pytest.raises(ValidationError): + shirt.validate() def test_simple_choices_get_field_display(self): """Test dynamic helper for returning the display value of a choices @@ -2016,20 +2030,21 @@ class TestField(MongoDBTestCase): shirt = Shirt() - self.assertEqual(shirt.get_size_display(), None) - self.assertEqual(shirt.get_style_display(), "Small") + assert shirt.get_size_display() == None + assert shirt.get_style_display() == "Small" shirt.size = "XXL" shirt.style = "Baggy" - self.assertEqual(shirt.get_size_display(), "XXL") - self.assertEqual(shirt.get_style_display(), "Baggy") + assert shirt.get_size_display() == "XXL" + assert shirt.get_style_display() == "Baggy" # Set as Z - an invalid choice shirt.size = "Z" shirt.style = "Z" - self.assertEqual(shirt.get_size_display(), "Z") - self.assertEqual(shirt.get_style_display(), "Z") - self.assertRaises(ValidationError, shirt.validate) + assert shirt.get_size_display() == "Z" + assert shirt.get_style_display() == "Z" + with pytest.raises(ValidationError): + shirt.validate() def test_simple_choices_validation_invalid_value(self): """Ensure that error messages are correct. @@ -2060,8 +2075,8 @@ class TestField(MongoDBTestCase): except ValidationError as error: # get the validation rules error_dict = error.to_dict() - self.assertEqual(error_dict["size"], SIZE_MESSAGE) - self.assertEqual(error_dict["color"], COLOR_MESSAGE) + assert error_dict["size"] == SIZE_MESSAGE + assert error_dict["color"] == COLOR_MESSAGE def test_recursive_validation(self): """Ensure that a validation result to_dict is available.""" @@ -2082,26 +2097,25 @@ class TestField(MongoDBTestCase): post.comments.append(Comment(content="hello", author=bob)) post.comments.append(Comment(author=bob)) - self.assertRaises(ValidationError, post.validate) + with pytest.raises(ValidationError): + post.validate() try: post.validate() except ValidationError as error: # ValidationError.errors property - self.assertTrue(hasattr(error, "errors")) - self.assertIsInstance(error.errors, dict) - self.assertIn("comments", error.errors) - self.assertIn(1, error.errors["comments"]) - self.assertIsInstance( - error.errors["comments"][1]["content"], ValidationError - ) + assert hasattr(error, "errors") + assert isinstance(error.errors, dict) + assert "comments" in error.errors + assert 1 in error.errors["comments"] + assert isinstance(error.errors["comments"][1]["content"], ValidationError) # ValidationError.schema property error_dict = error.to_dict() - self.assertIsInstance(error_dict, dict) - self.assertIn("comments", error_dict) - self.assertIn(1, error_dict["comments"]) - self.assertIn("content", error_dict["comments"][1]) - self.assertEqual(error_dict["comments"][1]["content"], u"Field is required") + assert isinstance(error_dict, dict) + assert "comments" in error_dict + assert 1 in error_dict["comments"] + assert "content" in error_dict["comments"][1] + assert error_dict["comments"][1]["content"] == u"Field is required" post.comments[1].content = "here we go" post.validate() @@ -2131,10 +2145,10 @@ class TestField(MongoDBTestCase): doc.items = tuples doc.save() x = TestDoc.objects().get() - self.assertIsNotNone(x) - self.assertEqual(len(x.items), 1) - self.assertIn(tuple(x.items[0]), tuples) - self.assertIn(x.items[0], tuples) + assert x is not None + assert len(x.items) == 1 + assert tuple(x.items[0]) in tuples + assert x.items[0] in tuples def test_dynamic_fields_class(self): class Doc2(Document): @@ -2150,13 +2164,14 @@ class TestField(MongoDBTestCase): doc2 = Doc2(field_1="hello") doc = Doc(my_id=1, embed_me=doc2, field_x="x") - self.assertRaises(OperationError, doc.save) + with pytest.raises(OperationError): + doc.save() doc2.save() doc.save() doc = Doc.objects.get() - self.assertEqual(doc.embed_me.field_1, "hello") + assert doc.embed_me.field_1 == "hello" def test_dynamic_fields_embedded_class(self): class Embed(EmbeddedDocument): @@ -2172,7 +2187,7 @@ class TestField(MongoDBTestCase): Doc(my_id=1, embed_me=Embed(field_1="hello"), field_x="x").save() doc = Doc.objects.get() - self.assertEqual(doc.embed_me.field_1, "hello") + assert doc.embed_me.field_1 == "hello" def test_dynamicfield_dump_document(self): """Ensure a DynamicField can handle another document's dump.""" @@ -2197,15 +2212,15 @@ class TestField(MongoDBTestCase): to_embed = ToEmbed(id=2, recursive=to_embed_recursive).save() doc = Doc(field=to_embed) doc.save() - self.assertIsInstance(doc.field, ToEmbed) - self.assertEqual(doc.field, to_embed) + assert isinstance(doc.field, ToEmbed) + assert doc.field == to_embed # Same thing with a Document with a _cls field to_embed_recursive = ToEmbedChild(id=1).save() to_embed_child = ToEmbedChild(id=2, recursive=to_embed_recursive).save() doc = Doc(field=to_embed_child) doc.save() - self.assertIsInstance(doc.field, ToEmbedChild) - self.assertEqual(doc.field, to_embed_child) + assert isinstance(doc.field, ToEmbedChild) + assert doc.field == to_embed_child def test_cls_field(self): class Animal(Document): @@ -2227,10 +2242,10 @@ class TestField(MongoDBTestCase): Dog().save() Fish().save() Human().save() - self.assertEqual( - Animal.objects(_cls__in=["Animal.Mammal.Dog", "Animal.Fish"]).count(), 2 + assert ( + Animal.objects(_cls__in=["Animal.Mammal.Dog", "Animal.Fish"]).count() == 2 ) - self.assertEqual(Animal.objects(_cls__in=["Animal.Fish.Guppy"]).count(), 0) + assert Animal.objects(_cls__in=["Animal.Fish.Guppy"]).count() == 0 def test_sparse_field(self): class Doc(Document): @@ -2249,7 +2264,7 @@ class TestField(MongoDBTestCase): class Doc(Document): foo = StringField() - with self.assertRaises(FieldDoesNotExist): + with pytest.raises(FieldDoesNotExist): Doc(bar="test") def test_undefined_field_exception_with_strict(self): @@ -2262,7 +2277,7 @@ class TestField(MongoDBTestCase): foo = StringField() meta = {"strict": False} - with self.assertRaises(FieldDoesNotExist): + with pytest.raises(FieldDoesNotExist): Doc(bar="test") @@ -2310,20 +2325,20 @@ class TestEmbeddedDocumentListField(MongoDBTestCase): # Test with an embeddedDocument instead of a list(embeddedDocument) # It's an edge case but it used to fail with a vague error, making it difficult to troubleshoot it post = self.BlogPost(comments=comment) - with self.assertRaises(ValidationError) as ctx_err: + with pytest.raises(ValidationError) as ctx_err: post.validate() - self.assertIn("'comments'", str(ctx_err.exception)) - self.assertIn( - "Only lists and tuples may be used in a list field", str(ctx_err.exception) + assert "'comments'" in str(ctx_err.exception) + assert "Only lists and tuples may be used in a list field" in str( + ctx_err.exception ) # Test with a Document post = self.BlogPost(comments=Title(content="garbage")) - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): post.validate() - self.assertIn("'comments'", str(ctx_err.exception)) - self.assertIn( - "Only lists and tuples may be used in a list field", str(ctx_err.exception) + assert "'comments'" in str(ctx_err.exception) + assert "Only lists and tuples may be used in a list field" in str( + ctx_err.exception ) def test_no_keyword_filter(self): @@ -2334,7 +2349,7 @@ class TestEmbeddedDocumentListField(MongoDBTestCase): filtered = self.post1.comments.filter() # Ensure nothing was changed - self.assertListEqual(filtered, self.post1.comments) + assert filtered == self.post1.comments def test_single_keyword_filter(self): """ @@ -2344,10 +2359,10 @@ class TestEmbeddedDocumentListField(MongoDBTestCase): filtered = self.post1.comments.filter(author="user1") # Ensure only 1 entry was returned. - self.assertEqual(len(filtered), 1) + assert len(filtered) == 1 # Ensure the entry returned is the correct entry. - self.assertEqual(filtered[0].author, "user1") + assert filtered[0].author == "user1" def test_multi_keyword_filter(self): """ @@ -2357,11 +2372,11 @@ class TestEmbeddedDocumentListField(MongoDBTestCase): filtered = self.post2.comments.filter(author="user2", message="message2") # Ensure only 1 entry was returned. - self.assertEqual(len(filtered), 1) + assert len(filtered) == 1 # Ensure the entry returned is the correct entry. - self.assertEqual(filtered[0].author, "user2") - self.assertEqual(filtered[0].message, "message2") + assert filtered[0].author == "user2" + assert filtered[0].message == "message2" def test_chained_filter(self): """ @@ -2370,18 +2385,18 @@ class TestEmbeddedDocumentListField(MongoDBTestCase): filtered = self.post2.comments.filter(author="user2").filter(message="message2") # Ensure only 1 entry was returned. - self.assertEqual(len(filtered), 1) + assert len(filtered) == 1 # Ensure the entry returned is the correct entry. - self.assertEqual(filtered[0].author, "user2") - self.assertEqual(filtered[0].message, "message2") + assert filtered[0].author == "user2" + assert filtered[0].message == "message2" def test_unknown_keyword_filter(self): """ Tests the filter method of a List of Embedded Documents when the keyword is not a known keyword. """ - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): self.post2.comments.filter(year=2) def test_no_keyword_exclude(self): @@ -2392,7 +2407,7 @@ class TestEmbeddedDocumentListField(MongoDBTestCase): filtered = self.post1.comments.exclude() # Ensure everything was removed - self.assertListEqual(filtered, []) + assert filtered == [] def test_single_keyword_exclude(self): """ @@ -2402,10 +2417,10 @@ class TestEmbeddedDocumentListField(MongoDBTestCase): excluded = self.post1.comments.exclude(author="user1") # Ensure only 1 entry was returned. - self.assertEqual(len(excluded), 1) + assert len(excluded) == 1 # Ensure the entry returned is the correct entry. - self.assertEqual(excluded[0].author, "user2") + assert excluded[0].author == "user2" def test_multi_keyword_exclude(self): """ @@ -2415,11 +2430,11 @@ class TestEmbeddedDocumentListField(MongoDBTestCase): excluded = self.post2.comments.exclude(author="user3", message="message1") # Ensure only 2 entries were returned. - self.assertEqual(len(excluded), 2) + assert len(excluded) == 2 # Ensure the entries returned are the correct entries. - self.assertEqual(excluded[0].author, "user2") - self.assertEqual(excluded[1].author, "user2") + assert excluded[0].author == "user2" + assert excluded[1].author == "user2" def test_non_matching_exclude(self): """ @@ -2429,14 +2444,14 @@ class TestEmbeddedDocumentListField(MongoDBTestCase): excluded = self.post2.comments.exclude(author="user4") # Ensure the 3 entries still exist. - self.assertEqual(len(excluded), 3) + assert len(excluded) == 3 def test_unknown_keyword_exclude(self): """ Tests the exclude method of a List of Embedded Documents when the keyword is not a known keyword. """ - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): self.post2.comments.exclude(year=2) def test_chained_filter_exclude(self): @@ -2449,25 +2464,25 @@ class TestEmbeddedDocumentListField(MongoDBTestCase): ) # Ensure only 1 entry was returned. - self.assertEqual(len(excluded), 1) + assert len(excluded) == 1 # Ensure the entry returned is the correct entry. - self.assertEqual(excluded[0].author, "user2") - self.assertEqual(excluded[0].message, "message3") + assert excluded[0].author == "user2" + assert excluded[0].message == "message3" def test_count(self): """ Tests the count method of a List of Embedded Documents. """ - self.assertEqual(self.post1.comments.count(), 2) - self.assertEqual(self.post1.comments.count(), len(self.post1.comments)) + assert self.post1.comments.count() == 2 + assert self.post1.comments.count() == len(self.post1.comments) def test_filtered_count(self): """ Tests the filter + count method of a List of Embedded Documents. """ count = self.post1.comments.filter(author="user1").count() - self.assertEqual(count, 1) + assert count == 1 def test_single_keyword_get(self): """ @@ -2475,8 +2490,8 @@ class TestEmbeddedDocumentListField(MongoDBTestCase): single keyword. """ comment = self.post1.comments.get(author="user1") - self.assertIsInstance(comment, self.Comments) - self.assertEqual(comment.author, "user1") + assert isinstance(comment, self.Comments) + assert comment.author == "user1" def test_multi_keyword_get(self): """ @@ -2484,16 +2499,16 @@ class TestEmbeddedDocumentListField(MongoDBTestCase): multiple keywords. """ comment = self.post2.comments.get(author="user2", message="message2") - self.assertIsInstance(comment, self.Comments) - self.assertEqual(comment.author, "user2") - self.assertEqual(comment.message, "message2") + assert isinstance(comment, self.Comments) + assert comment.author == "user2" + assert comment.message == "message2" def test_no_keyword_multiple_return_get(self): """ Tests the get method of a List of Embedded Documents without a keyword to return multiple documents. """ - with self.assertRaises(MultipleObjectsReturned): + with pytest.raises(MultipleObjectsReturned): self.post1.comments.get() def test_keyword_multiple_return_get(self): @@ -2501,7 +2516,7 @@ class TestEmbeddedDocumentListField(MongoDBTestCase): Tests the get method of a List of Embedded Documents with a keyword to return multiple documents. """ - with self.assertRaises(MultipleObjectsReturned): + with pytest.raises(MultipleObjectsReturned): self.post2.comments.get(author="user2") def test_unknown_keyword_get(self): @@ -2509,7 +2524,7 @@ class TestEmbeddedDocumentListField(MongoDBTestCase): Tests the get method of a List of Embedded Documents with an unknown keyword. """ - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): self.post2.comments.get(year=2020) def test_no_result_get(self): @@ -2517,7 +2532,7 @@ class TestEmbeddedDocumentListField(MongoDBTestCase): Tests the get method of a List of Embedded Documents where get returns no results. """ - with self.assertRaises(DoesNotExist): + with pytest.raises(DoesNotExist): self.post1.comments.get(author="user3") def test_first(self): @@ -2528,8 +2543,8 @@ class TestEmbeddedDocumentListField(MongoDBTestCase): comment = self.post1.comments.first() # Ensure a Comment object was returned. - self.assertIsInstance(comment, self.Comments) - self.assertEqual(comment, self.post1.comments[0]) + assert isinstance(comment, self.Comments) + assert comment == self.post1.comments[0] def test_create(self): """ @@ -2539,14 +2554,12 @@ class TestEmbeddedDocumentListField(MongoDBTestCase): self.post1.save() # Ensure the returned value is the comment object. - self.assertIsInstance(comment, self.Comments) - self.assertEqual(comment.author, "user4") - self.assertEqual(comment.message, "message1") + assert isinstance(comment, self.Comments) + assert comment.author == "user4" + assert comment.message == "message1" # Ensure the new comment was actually saved to the database. - self.assertIn( - comment, self.BlogPost.objects(comments__author="user4")[0].comments - ) + assert comment in self.BlogPost.objects(comments__author="user4")[0].comments def test_filtered_create(self): """ @@ -2560,14 +2573,12 @@ class TestEmbeddedDocumentListField(MongoDBTestCase): self.post1.save() # Ensure the returned value is the comment object. - self.assertIsInstance(comment, self.Comments) - self.assertEqual(comment.author, "user4") - self.assertEqual(comment.message, "message1") + assert isinstance(comment, self.Comments) + assert comment.author == "user4" + assert comment.message == "message1" # Ensure the new comment was actually saved to the database. - self.assertIn( - comment, self.BlogPost.objects(comments__author="user4")[0].comments - ) + assert comment in self.BlogPost.objects(comments__author="user4")[0].comments def test_no_keyword_update(self): """ @@ -2579,13 +2590,13 @@ class TestEmbeddedDocumentListField(MongoDBTestCase): self.post1.save() # Ensure that nothing was altered. - self.assertIn(original[0], self.BlogPost.objects(id=self.post1.id)[0].comments) + assert original[0] in self.BlogPost.objects(id=self.post1.id)[0].comments - self.assertIn(original[1], self.BlogPost.objects(id=self.post1.id)[0].comments) + assert original[1] in self.BlogPost.objects(id=self.post1.id)[0].comments # Ensure the method returned 0 as the number of entries # modified - self.assertEqual(number, 0) + assert number == 0 def test_single_keyword_update(self): """ @@ -2598,12 +2609,12 @@ class TestEmbeddedDocumentListField(MongoDBTestCase): comments = self.BlogPost.objects(id=self.post1.id)[0].comments # Ensure that the database was updated properly. - self.assertEqual(comments[0].author, "user4") - self.assertEqual(comments[1].author, "user4") + assert comments[0].author == "user4" + assert comments[1].author == "user4" # Ensure the method returned 2 as the number of entries # modified - self.assertEqual(number, 2) + assert number == 2 def test_unicode(self): """ @@ -2615,7 +2626,7 @@ class TestEmbeddedDocumentListField(MongoDBTestCase): self.Comments(author="user2", message=u"хабарлама"), ] ).save() - self.assertEqual(post.comments.get(message=u"сообщение").author, "user1") + assert post.comments.get(message=u"сообщение").author == "user1" def test_save(self): """ @@ -2627,7 +2638,7 @@ class TestEmbeddedDocumentListField(MongoDBTestCase): comments.save() # Ensure that the new comment has been added to the database. - self.assertIn(new_comment, self.BlogPost.objects(id=self.post1.id)[0].comments) + assert new_comment in self.BlogPost.objects(id=self.post1.id)[0].comments def test_delete(self): """ @@ -2638,17 +2649,17 @@ class TestEmbeddedDocumentListField(MongoDBTestCase): # Ensure that all the comments under post1 were deleted in the # database. - self.assertListEqual(self.BlogPost.objects(id=self.post1.id)[0].comments, []) + assert self.BlogPost.objects(id=self.post1.id)[0].comments == [] # Ensure that post1 comments were deleted from the list. - self.assertListEqual(self.post1.comments, []) + assert self.post1.comments == [] # Ensure that comments still returned a EmbeddedDocumentList object. - self.assertIsInstance(self.post1.comments, EmbeddedDocumentList) + assert isinstance(self.post1.comments, EmbeddedDocumentList) # Ensure that the delete method returned 2 as the number of entries # deleted from the database - self.assertEqual(number, 2) + assert number == 2 def test_empty_list_embedded_documents_with_unique_field(self): """ @@ -2664,7 +2675,7 @@ class TestEmbeddedDocumentListField(MongoDBTestCase): my_list = ListField(EmbeddedDocumentField(EmbeddedWithUnique)) A(my_list=[]).save() - with self.assertRaises(NotUniqueError): + with pytest.raises(NotUniqueError): A(my_list=[]).save() class EmbeddedWithSparseUnique(EmbeddedDocument): @@ -2689,16 +2700,16 @@ class TestEmbeddedDocumentListField(MongoDBTestCase): self.post1.save() # Ensure that only the user2 comment was deleted. - self.assertNotIn(comment, self.BlogPost.objects(id=self.post1.id)[0].comments) - self.assertEqual(len(self.BlogPost.objects(id=self.post1.id)[0].comments), 1) + assert comment not in self.BlogPost.objects(id=self.post1.id)[0].comments + assert len(self.BlogPost.objects(id=self.post1.id)[0].comments) == 1 # Ensure that the user2 comment no longer exists in the list. - self.assertNotIn(comment, self.post1.comments) - self.assertEqual(len(self.post1.comments), 1) + assert comment not in self.post1.comments + assert len(self.post1.comments) == 1 # Ensure that the delete method returned 1 as the number of entries # deleted from the database - self.assertEqual(number, 1) + assert number == 1 def test_custom_data(self): """ @@ -2714,10 +2725,10 @@ class TestEmbeddedDocumentListField(MongoDBTestCase): CustomData.drop_collection() a1 = CustomData(a_field=1, c_field=2).save() - self.assertEqual(2, a1.c_field) - self.assertFalse(hasattr(a1.c_field, "custom_data")) - self.assertTrue(hasattr(CustomData.c_field, "custom_data")) - self.assertEqual(custom_data["a"], CustomData.c_field.custom_data["a"]) + assert 2 == a1.c_field + assert not hasattr(a1.c_field, "custom_data") + assert hasattr(CustomData.c_field, "custom_data") + assert custom_data["a"] == CustomData.c_field.custom_data["a"] if __name__ == "__main__": diff --git a/tests/fields/test_file_field.py b/tests/fields/test_file_field.py index 49eb5bc2..0746db33 100644 --- a/tests/fields/test_file_field.py +++ b/tests/fields/test_file_field.py @@ -64,13 +64,13 @@ class TestFileField(MongoDBTestCase): putfile.save() result = PutFile.objects.first() - self.assertEqual(putfile, result) - self.assertEqual( - "%s" % result.the_file, - "" % result.the_file.grid_id, + assert putfile == result + assert ( + "%s" % result.the_file + == "" % result.the_file.grid_id ) - self.assertEqual(result.the_file.read(), text) - self.assertEqual(result.the_file.content_type, content_type) + assert result.the_file.read() == text + assert result.the_file.content_type == content_type result.the_file.delete() # Remove file from GridFS PutFile.objects.delete() @@ -85,9 +85,9 @@ class TestFileField(MongoDBTestCase): putfile.save() result = PutFile.objects.first() - self.assertEqual(putfile, result) - self.assertEqual(result.the_file.read(), text) - self.assertEqual(result.the_file.content_type, content_type) + assert putfile == result + assert result.the_file.read() == text + assert result.the_file.content_type == content_type result.the_file.delete() def test_file_fields_stream(self): @@ -111,19 +111,19 @@ class TestFileField(MongoDBTestCase): streamfile.save() result = StreamFile.objects.first() - self.assertEqual(streamfile, result) - self.assertEqual(result.the_file.read(), text + more_text) - self.assertEqual(result.the_file.content_type, content_type) + assert streamfile == result + assert result.the_file.read() == text + more_text + assert result.the_file.content_type == content_type result.the_file.seek(0) - self.assertEqual(result.the_file.tell(), 0) - self.assertEqual(result.the_file.read(len(text)), text) - self.assertEqual(result.the_file.tell(), len(text)) - self.assertEqual(result.the_file.read(len(more_text)), more_text) - self.assertEqual(result.the_file.tell(), len(text + more_text)) + assert result.the_file.tell() == 0 + assert result.the_file.read(len(text)) == text + assert result.the_file.tell() == len(text) + assert result.the_file.read(len(more_text)) == more_text + assert result.the_file.tell() == len(text + more_text) result.the_file.delete() # Ensure deleted file returns None - self.assertTrue(result.the_file.read() is None) + assert result.the_file.read() is None def test_file_fields_stream_after_none(self): """Ensure that a file field can be written to after it has been saved as @@ -148,19 +148,19 @@ class TestFileField(MongoDBTestCase): streamfile.save() result = StreamFile.objects.first() - self.assertEqual(streamfile, result) - self.assertEqual(result.the_file.read(), text + more_text) + assert streamfile == result + assert result.the_file.read() == text + more_text # self.assertEqual(result.the_file.content_type, content_type) result.the_file.seek(0) - self.assertEqual(result.the_file.tell(), 0) - self.assertEqual(result.the_file.read(len(text)), text) - self.assertEqual(result.the_file.tell(), len(text)) - self.assertEqual(result.the_file.read(len(more_text)), more_text) - self.assertEqual(result.the_file.tell(), len(text + more_text)) + assert result.the_file.tell() == 0 + assert result.the_file.read(len(text)) == text + assert result.the_file.tell() == len(text) + assert result.the_file.read(len(more_text)) == more_text + assert result.the_file.tell() == len(text + more_text) result.the_file.delete() # Ensure deleted file returns None - self.assertTrue(result.the_file.read() is None) + assert result.the_file.read() is None def test_file_fields_set(self): class SetFile(Document): @@ -176,16 +176,16 @@ class TestFileField(MongoDBTestCase): setfile.save() result = SetFile.objects.first() - self.assertEqual(setfile, result) - self.assertEqual(result.the_file.read(), text) + assert setfile == result + assert result.the_file.read() == text # Try replacing file with new one result.the_file.replace(more_text) result.save() result = SetFile.objects.first() - self.assertEqual(setfile, result) - self.assertEqual(result.the_file.read(), more_text) + assert setfile == result + assert result.the_file.read() == more_text result.the_file.delete() def test_file_field_no_default(self): @@ -205,28 +205,28 @@ class TestFileField(MongoDBTestCase): doc_b = GridDocument.objects.with_id(doc_a.id) doc_b.the_file.replace(f, filename="doc_b") doc_b.save() - self.assertNotEqual(doc_b.the_file.grid_id, None) + assert doc_b.the_file.grid_id != None # Test it matches doc_c = GridDocument.objects.with_id(doc_b.id) - self.assertEqual(doc_b.the_file.grid_id, doc_c.the_file.grid_id) + assert doc_b.the_file.grid_id == doc_c.the_file.grid_id # Test with default doc_d = GridDocument(the_file=six.b("")) doc_d.save() doc_e = GridDocument.objects.with_id(doc_d.id) - self.assertEqual(doc_d.the_file.grid_id, doc_e.the_file.grid_id) + assert doc_d.the_file.grid_id == doc_e.the_file.grid_id doc_e.the_file.replace(f, filename="doc_e") doc_e.save() doc_f = GridDocument.objects.with_id(doc_e.id) - self.assertEqual(doc_e.the_file.grid_id, doc_f.the_file.grid_id) + assert doc_e.the_file.grid_id == doc_f.the_file.grid_id db = GridDocument._get_db() grid_fs = gridfs.GridFS(db) - self.assertEqual(["doc_b", "doc_e"], grid_fs.list()) + assert ["doc_b", "doc_e"] == grid_fs.list() def test_file_uniqueness(self): """Ensure that each instance of a FileField is unique @@ -246,8 +246,8 @@ class TestFileField(MongoDBTestCase): test_file_dupe = TestFile() data = test_file_dupe.the_file.read() # Should be None - self.assertNotEqual(test_file.name, test_file_dupe.name) - self.assertNotEqual(test_file.the_file.read(), data) + assert test_file.name != test_file_dupe.name + assert test_file.the_file.read() != data TestFile.drop_collection() @@ -268,8 +268,8 @@ class TestFileField(MongoDBTestCase): marmot.save() marmot = Animal.objects.get() - self.assertEqual(marmot.photo.content_type, "image/jpeg") - self.assertEqual(marmot.photo.foo, "bar") + assert marmot.photo.content_type == "image/jpeg" + assert marmot.photo.foo == "bar" def test_file_reassigning(self): class TestFile(Document): @@ -278,12 +278,12 @@ class TestFileField(MongoDBTestCase): TestFile.drop_collection() test_file = TestFile(the_file=get_file(TEST_IMAGE_PATH)).save() - self.assertEqual(test_file.the_file.get().length, 8313) + assert test_file.the_file.get().length == 8313 test_file = TestFile.objects.first() test_file.the_file = get_file(TEST_IMAGE2_PATH) test_file.save() - self.assertEqual(test_file.the_file.get().length, 4971) + assert test_file.the_file.get().length == 4971 def test_file_boolean(self): """Ensure that a boolean test of a FileField indicates its presence @@ -295,13 +295,13 @@ class TestFileField(MongoDBTestCase): TestFile.drop_collection() test_file = TestFile() - self.assertFalse(bool(test_file.the_file)) + assert not bool(test_file.the_file) test_file.the_file.put(six.b("Hello, World!"), content_type="text/plain") test_file.save() - self.assertTrue(bool(test_file.the_file)) + assert bool(test_file.the_file) test_file = TestFile.objects.first() - self.assertEqual(test_file.the_file.content_type, "text/plain") + assert test_file.the_file.content_type == "text/plain" def test_file_cmp(self): """Test comparing against other types""" @@ -310,7 +310,7 @@ class TestFileField(MongoDBTestCase): the_file = FileField() test_file = TestFile() - self.assertNotIn(test_file.the_file, [{"test": 1}]) + assert test_file.the_file not in [{"test": 1}] def test_file_disk_space(self): """ Test disk space usage when we delete/replace a file """ @@ -330,16 +330,16 @@ class TestFileField(MongoDBTestCase): files = db.fs.files.find() chunks = db.fs.chunks.find() - self.assertEqual(len(list(files)), 1) - self.assertEqual(len(list(chunks)), 1) + assert len(list(files)) == 1 + assert len(list(chunks)) == 1 # Deleting the docoument should delete the files testfile.delete() files = db.fs.files.find() chunks = db.fs.chunks.find() - self.assertEqual(len(list(files)), 0) - self.assertEqual(len(list(chunks)), 0) + assert len(list(files)) == 0 + assert len(list(chunks)) == 0 # Test case where we don't store a file in the first place testfile = TestFile() @@ -347,15 +347,15 @@ class TestFileField(MongoDBTestCase): files = db.fs.files.find() chunks = db.fs.chunks.find() - self.assertEqual(len(list(files)), 0) - self.assertEqual(len(list(chunks)), 0) + assert len(list(files)) == 0 + assert len(list(chunks)) == 0 testfile.delete() files = db.fs.files.find() chunks = db.fs.chunks.find() - self.assertEqual(len(list(files)), 0) - self.assertEqual(len(list(chunks)), 0) + assert len(list(files)) == 0 + assert len(list(chunks)) == 0 # Test case where we overwrite the file testfile = TestFile() @@ -368,15 +368,15 @@ class TestFileField(MongoDBTestCase): files = db.fs.files.find() chunks = db.fs.chunks.find() - self.assertEqual(len(list(files)), 1) - self.assertEqual(len(list(chunks)), 1) + assert len(list(files)) == 1 + assert len(list(chunks)) == 1 testfile.delete() files = db.fs.files.find() chunks = db.fs.chunks.find() - self.assertEqual(len(list(files)), 0) - self.assertEqual(len(list(chunks)), 0) + assert len(list(files)) == 0 + assert len(list(chunks)) == 0 def test_image_field(self): if not HAS_PIL: @@ -396,9 +396,7 @@ class TestFileField(MongoDBTestCase): t.image.put(f) self.fail("Should have raised an invalidation error") except ValidationError as e: - self.assertEqual( - "%s" % e, "Invalid image: cannot identify image file %s" % f - ) + assert "%s" % e == "Invalid image: cannot identify image file %s" % f t = TestImage() t.image.put(get_file(TEST_IMAGE_PATH)) @@ -406,11 +404,11 @@ class TestFileField(MongoDBTestCase): t = TestImage.objects.first() - self.assertEqual(t.image.format, "PNG") + assert t.image.format == "PNG" w, h = t.image.size - self.assertEqual(w, 371) - self.assertEqual(h, 76) + assert w == 371 + assert h == 76 t.image.delete() @@ -424,12 +422,12 @@ class TestFileField(MongoDBTestCase): TestFile.drop_collection() test_file = TestFile(the_file=get_file(TEST_IMAGE_PATH)).save() - self.assertEqual(test_file.the_file.size, (371, 76)) + assert test_file.the_file.size == (371, 76) test_file = TestFile.objects.first() test_file.the_file = get_file(TEST_IMAGE2_PATH) test_file.save() - self.assertEqual(test_file.the_file.size, (45, 101)) + assert test_file.the_file.size == (45, 101) def test_image_field_resize(self): if not HAS_PIL: @@ -446,11 +444,11 @@ class TestFileField(MongoDBTestCase): t = TestImage.objects.first() - self.assertEqual(t.image.format, "PNG") + assert t.image.format == "PNG" w, h = t.image.size - self.assertEqual(w, 185) - self.assertEqual(h, 37) + assert w == 185 + assert h == 37 t.image.delete() @@ -469,11 +467,11 @@ class TestFileField(MongoDBTestCase): t = TestImage.objects.first() - self.assertEqual(t.image.format, "PNG") + assert t.image.format == "PNG" w, h = t.image.size - self.assertEqual(w, 185) - self.assertEqual(h, 37) + assert w == 185 + assert h == 37 t.image.delete() @@ -492,9 +490,9 @@ class TestFileField(MongoDBTestCase): t = TestImage.objects.first() - self.assertEqual(t.image.thumbnail.format, "PNG") - self.assertEqual(t.image.thumbnail.width, 92) - self.assertEqual(t.image.thumbnail.height, 18) + assert t.image.thumbnail.format == "PNG" + assert t.image.thumbnail.width == 92 + assert t.image.thumbnail.height == 18 t.image.delete() @@ -518,17 +516,17 @@ class TestFileField(MongoDBTestCase): test_file.save() data = get_db("test_files").macumba.files.find_one() - self.assertEqual(data.get("name"), "hello.txt") + assert data.get("name") == "hello.txt" test_file = TestFile.objects.first() - self.assertEqual(test_file.the_file.read(), six.b("Hello, World!")) + assert test_file.the_file.read() == six.b("Hello, World!") test_file = TestFile.objects.first() test_file.the_file = six.b("HELLO, WORLD!") test_file.save() test_file = TestFile.objects.first() - self.assertEqual(test_file.the_file.read(), six.b("HELLO, WORLD!")) + assert test_file.the_file.read() == six.b("HELLO, WORLD!") def test_copyable(self): class PutFile(Document): @@ -546,8 +544,8 @@ class TestFileField(MongoDBTestCase): class TestFile(Document): name = StringField() - self.assertEqual(putfile, copy.copy(putfile)) - self.assertEqual(putfile, copy.deepcopy(putfile)) + assert putfile == copy.copy(putfile) + assert putfile == copy.deepcopy(putfile) def test_get_image_by_grid_id(self): @@ -569,9 +567,7 @@ class TestFileField(MongoDBTestCase): test = TestImage.objects.first() grid_id = test.image1.grid_id - self.assertEqual( - 1, TestImage.objects(Q(image1=grid_id) or Q(image2=grid_id)).count() - ) + assert 1 == TestImage.objects(Q(image1=grid_id) or Q(image2=grid_id)).count() def test_complex_field_filefield(self): """Ensure you can add meta data to file""" @@ -593,9 +589,9 @@ class TestFileField(MongoDBTestCase): marmot.save() marmot = Animal.objects.get() - self.assertEqual(marmot.photos[0].content_type, "image/jpeg") - self.assertEqual(marmot.photos[0].foo, "bar") - self.assertEqual(marmot.photos[0].get().length, 8313) + assert marmot.photos[0].content_type == "image/jpeg" + assert marmot.photos[0].foo == "bar" + assert marmot.photos[0].get().length == 8313 if __name__ == "__main__": diff --git a/tests/fields/test_float_field.py b/tests/fields/test_float_field.py index 9f357ce5..d755fb4e 100644 --- a/tests/fields/test_float_field.py +++ b/tests/fields/test_float_field.py @@ -4,6 +4,7 @@ import six from mongoengine import * from tests.utils import MongoDBTestCase +import pytest class TestFloatField(MongoDBTestCase): @@ -16,8 +17,8 @@ class TestFloatField(MongoDBTestCase): TestDocument(float_fld=None).save() TestDocument(float_fld=1).save() - self.assertEqual(1, TestDocument.objects(float_fld__ne=None).count()) - self.assertEqual(1, TestDocument.objects(float_fld__ne=1).count()) + assert 1 == TestDocument.objects(float_fld__ne=None).count() + assert 1 == TestDocument.objects(float_fld__ne=1).count() def test_validation(self): """Ensure that invalid values cannot be assigned to float fields. @@ -34,16 +35,20 @@ class TestFloatField(MongoDBTestCase): person.validate() person.height = "2.0" - self.assertRaises(ValidationError, person.validate) + with pytest.raises(ValidationError): + person.validate() person.height = 0.01 - self.assertRaises(ValidationError, person.validate) + with pytest.raises(ValidationError): + person.validate() person.height = 4.0 - self.assertRaises(ValidationError, person.validate) + with pytest.raises(ValidationError): + person.validate() person_2 = Person(height="something invalid") - self.assertRaises(ValidationError, person_2.validate) + with pytest.raises(ValidationError): + person_2.validate() big_person = BigPerson() @@ -55,4 +60,5 @@ class TestFloatField(MongoDBTestCase): big_person.validate() big_person.height = 2 ** 100000 # Too big for a float value - self.assertRaises(ValidationError, big_person.validate) + with pytest.raises(ValidationError): + big_person.validate() diff --git a/tests/fields/test_geo_fields.py b/tests/fields/test_geo_fields.py index ff4cbc83..1b912a4b 100644 --- a/tests/fields/test_geo_fields.py +++ b/tests/fields/test_geo_fields.py @@ -11,7 +11,7 @@ class TestGeoField(MongoDBTestCase): Cls(loc=loc).validate() self.fail("Should not validate the location {0}".format(loc)) except ValidationError as e: - self.assertEqual(expected, e.to_dict()["loc"]) + assert expected == e.to_dict()["loc"] def test_geopoint_validation(self): class Location(Document): @@ -299,7 +299,7 @@ class TestGeoField(MongoDBTestCase): location = GeoPointField() geo_indicies = Event._geo_indices() - self.assertEqual(geo_indicies, [{"fields": [("location", "2d")]}]) + assert geo_indicies == [{"fields": [("location", "2d")]}] def test_geopoint_embedded_indexes(self): """Ensure that indexes are created automatically for GeoPointFields on @@ -315,7 +315,7 @@ class TestGeoField(MongoDBTestCase): venue = EmbeddedDocumentField(Venue) geo_indicies = Event._geo_indices() - self.assertEqual(geo_indicies, [{"fields": [("venue.location", "2d")]}]) + assert geo_indicies == [{"fields": [("venue.location", "2d")]}] def test_indexes_2dsphere(self): """Ensure that indexes are created automatically for GeoPointFields. @@ -328,9 +328,9 @@ class TestGeoField(MongoDBTestCase): polygon = PolygonField() geo_indicies = Event._geo_indices() - self.assertIn({"fields": [("line", "2dsphere")]}, geo_indicies) - self.assertIn({"fields": [("polygon", "2dsphere")]}, geo_indicies) - self.assertIn({"fields": [("point", "2dsphere")]}, geo_indicies) + assert {"fields": [("line", "2dsphere")]} in geo_indicies + assert {"fields": [("polygon", "2dsphere")]} in geo_indicies + assert {"fields": [("point", "2dsphere")]} in geo_indicies def test_indexes_2dsphere_embedded(self): """Ensure that indexes are created automatically for GeoPointFields. @@ -347,9 +347,9 @@ class TestGeoField(MongoDBTestCase): venue = EmbeddedDocumentField(Venue) geo_indicies = Event._geo_indices() - self.assertIn({"fields": [("venue.line", "2dsphere")]}, geo_indicies) - self.assertIn({"fields": [("venue.polygon", "2dsphere")]}, geo_indicies) - self.assertIn({"fields": [("venue.point", "2dsphere")]}, geo_indicies) + assert {"fields": [("venue.line", "2dsphere")]} in geo_indicies + assert {"fields": [("venue.polygon", "2dsphere")]} in geo_indicies + assert {"fields": [("venue.point", "2dsphere")]} in geo_indicies def test_geo_indexes_recursion(self): class Location(Document): @@ -365,12 +365,12 @@ class TestGeoField(MongoDBTestCase): Parent(name="Berlin").save() info = Parent._get_collection().index_information() - self.assertNotIn("location_2d", info) + assert "location_2d" not in info info = Location._get_collection().index_information() - self.assertIn("location_2d", info) + assert "location_2d" in info - self.assertEqual(len(Parent._geo_indices()), 0) - self.assertEqual(len(Location._geo_indices()), 1) + assert len(Parent._geo_indices()) == 0 + assert len(Location._geo_indices()) == 1 def test_geo_indexes_auto_index(self): @@ -381,16 +381,16 @@ class TestGeoField(MongoDBTestCase): meta = {"indexes": [[("location", "2dsphere"), ("datetime", 1)]]} - self.assertEqual([], Log._geo_indices()) + assert [] == Log._geo_indices() Log.drop_collection() Log.ensure_indexes() info = Log._get_collection().index_information() - self.assertEqual( - info["location_2dsphere_datetime_1"]["key"], - [("location", "2dsphere"), ("datetime", 1)], - ) + assert info["location_2dsphere_datetime_1"]["key"] == [ + ("location", "2dsphere"), + ("datetime", 1), + ] # Test listing explicitly class Log(Document): @@ -401,16 +401,16 @@ class TestGeoField(MongoDBTestCase): "indexes": [{"fields": [("location", "2dsphere"), ("datetime", 1)]}] } - self.assertEqual([], Log._geo_indices()) + assert [] == Log._geo_indices() Log.drop_collection() Log.ensure_indexes() info = Log._get_collection().index_information() - self.assertEqual( - info["location_2dsphere_datetime_1"]["key"], - [("location", "2dsphere"), ("datetime", 1)], - ) + assert info["location_2dsphere_datetime_1"]["key"] == [ + ("location", "2dsphere"), + ("datetime", 1), + ] if __name__ == "__main__": diff --git a/tests/fields/test_int_field.py b/tests/fields/test_int_field.py index b7db0416..65a5fbad 100644 --- a/tests/fields/test_int_field.py +++ b/tests/fields/test_int_field.py @@ -2,6 +2,7 @@ from mongoengine import * from tests.utils import MongoDBTestCase +import pytest class TestIntField(MongoDBTestCase): @@ -23,11 +24,14 @@ class TestIntField(MongoDBTestCase): person.validate() person.age = -1 - self.assertRaises(ValidationError, person.validate) + with pytest.raises(ValidationError): + person.validate() person.age = 120 - self.assertRaises(ValidationError, person.validate) + with pytest.raises(ValidationError): + person.validate() person.age = "ten" - self.assertRaises(ValidationError, person.validate) + with pytest.raises(ValidationError): + person.validate() def test_ne_operator(self): class TestDocument(Document): @@ -38,5 +42,5 @@ class TestIntField(MongoDBTestCase): TestDocument(int_fld=None).save() TestDocument(int_fld=1).save() - self.assertEqual(1, TestDocument.objects(int_fld__ne=None).count()) - self.assertEqual(1, TestDocument.objects(int_fld__ne=1).count()) + assert 1 == TestDocument.objects(int_fld__ne=None).count() + assert 1 == TestDocument.objects(int_fld__ne=1).count() diff --git a/tests/fields/test_lazy_reference_field.py b/tests/fields/test_lazy_reference_field.py index 2a686d7f..8150574d 100644 --- a/tests/fields/test_lazy_reference_field.py +++ b/tests/fields/test_lazy_reference_field.py @@ -5,13 +5,15 @@ from mongoengine import * from mongoengine.base import LazyReference from tests.utils import MongoDBTestCase +import pytest class TestLazyReferenceField(MongoDBTestCase): def test_lazy_reference_config(self): # Make sure ReferenceField only accepts a document class or a string # with a document class name. - self.assertRaises(ValidationError, LazyReferenceField, EmbeddedDocument) + with pytest.raises(ValidationError): + LazyReferenceField(EmbeddedDocument) def test___repr__(self): class Animal(Document): @@ -25,7 +27,7 @@ class TestLazyReferenceField(MongoDBTestCase): animal = Animal() oc = Ocurrence(animal=animal) - self.assertIn("LazyReference", repr(oc.animal)) + assert "LazyReference" in repr(oc.animal) def test___getattr___unknown_attr_raises_attribute_error(self): class Animal(Document): @@ -39,7 +41,7 @@ class TestLazyReferenceField(MongoDBTestCase): animal = Animal().save() oc = Ocurrence(animal=animal) - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): oc.animal.not_exist def test_lazy_reference_simple(self): @@ -57,19 +59,19 @@ class TestLazyReferenceField(MongoDBTestCase): animal = Animal(name="Leopard", tag="heavy").save() Ocurrence(person="test", animal=animal).save() p = Ocurrence.objects.get() - self.assertIsInstance(p.animal, LazyReference) + assert isinstance(p.animal, LazyReference) fetched_animal = p.animal.fetch() - self.assertEqual(fetched_animal, animal) + assert fetched_animal == animal # `fetch` keep cache on referenced document by default... animal.tag = "not so heavy" animal.save() double_fetch = p.animal.fetch() - self.assertIs(fetched_animal, double_fetch) - self.assertEqual(double_fetch.tag, "heavy") + assert fetched_animal is double_fetch + assert double_fetch.tag == "heavy" # ...unless specified otherwise fetch_force = p.animal.fetch(force=True) - self.assertIsNot(fetch_force, fetched_animal) - self.assertEqual(fetch_force.tag, "not so heavy") + assert fetch_force is not fetched_animal + assert fetch_force.tag == "not so heavy" def test_lazy_reference_fetch_invalid_ref(self): class Animal(Document): @@ -87,8 +89,8 @@ class TestLazyReferenceField(MongoDBTestCase): Ocurrence(person="test", animal=animal).save() animal.delete() p = Ocurrence.objects.get() - self.assertIsInstance(p.animal, LazyReference) - with self.assertRaises(DoesNotExist): + assert isinstance(p.animal, LazyReference) + with pytest.raises(DoesNotExist): p.animal.fetch() def test_lazy_reference_set(self): @@ -122,7 +124,7 @@ class TestLazyReferenceField(MongoDBTestCase): ): p = Ocurrence(person="test", animal=ref).save() p.reload() - self.assertIsInstance(p.animal, LazyReference) + assert isinstance(p.animal, LazyReference) p.animal.fetch() def test_lazy_reference_bad_set(self): @@ -149,7 +151,7 @@ class TestLazyReferenceField(MongoDBTestCase): DBRef(baddoc._get_collection_name(), animal.pk), LazyReference(BadDoc, animal.pk), ): - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): p = Ocurrence(person="test", animal=bad).save() def test_lazy_reference_query_conversion(self): @@ -179,14 +181,14 @@ class TestLazyReferenceField(MongoDBTestCase): post2.save() post = BlogPost.objects(author=m1).first() - self.assertEqual(post.id, post1.id) + assert post.id == post1.id post = BlogPost.objects(author=m2).first() - self.assertEqual(post.id, post2.id) + assert post.id == post2.id # Same thing by passing a LazyReference instance post = BlogPost.objects(author=LazyReference(Member, m2.pk)).first() - self.assertEqual(post.id, post2.id) + assert post.id == post2.id def test_lazy_reference_query_conversion_dbref(self): """Ensure that LazyReferenceFields can be queried using objects and values @@ -215,14 +217,14 @@ class TestLazyReferenceField(MongoDBTestCase): post2.save() post = BlogPost.objects(author=m1).first() - self.assertEqual(post.id, post1.id) + assert post.id == post1.id post = BlogPost.objects(author=m2).first() - self.assertEqual(post.id, post2.id) + assert post.id == post2.id # Same thing by passing a LazyReference instance post = BlogPost.objects(author=LazyReference(Member, m2.pk)).first() - self.assertEqual(post.id, post2.id) + assert post.id == post2.id def test_lazy_reference_passthrough(self): class Animal(Document): @@ -239,20 +241,20 @@ class TestLazyReferenceField(MongoDBTestCase): animal = Animal(name="Leopard", tag="heavy").save() Ocurrence(animal=animal, animal_passthrough=animal).save() p = Ocurrence.objects.get() - self.assertIsInstance(p.animal, LazyReference) - with self.assertRaises(KeyError): + assert isinstance(p.animal, LazyReference) + with pytest.raises(KeyError): p.animal["name"] - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): p.animal.name - self.assertEqual(p.animal.pk, animal.pk) + assert p.animal.pk == animal.pk - self.assertEqual(p.animal_passthrough.name, "Leopard") - self.assertEqual(p.animal_passthrough["name"], "Leopard") + assert p.animal_passthrough.name == "Leopard" + assert p.animal_passthrough["name"] == "Leopard" # Should not be able to access referenced document's methods - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): p.animal.save - with self.assertRaises(KeyError): + with pytest.raises(KeyError): p.animal["save"] def test_lazy_reference_not_set(self): @@ -269,7 +271,7 @@ class TestLazyReferenceField(MongoDBTestCase): Ocurrence(person="foo").save() p = Ocurrence.objects.get() - self.assertIs(p.animal, None) + assert p.animal is None def test_lazy_reference_equality(self): class Animal(Document): @@ -280,12 +282,12 @@ class TestLazyReferenceField(MongoDBTestCase): animal = Animal(name="Leopard", tag="heavy").save() animalref = LazyReference(Animal, animal.pk) - self.assertEqual(animal, animalref) - self.assertEqual(animalref, animal) + assert animal == animalref + assert animalref == animal other_animalref = LazyReference(Animal, ObjectId("54495ad94c934721ede76f90")) - self.assertNotEqual(animal, other_animalref) - self.assertNotEqual(other_animalref, animal) + assert animal != other_animalref + assert other_animalref != animal def test_lazy_reference_embedded(self): class Animal(Document): @@ -308,12 +310,12 @@ class TestLazyReferenceField(MongoDBTestCase): animal2 = Animal(name="cheeta").save() def check_fields_type(occ): - self.assertIsInstance(occ.direct, LazyReference) + assert isinstance(occ.direct, LazyReference) for elem in occ.in_list: - self.assertIsInstance(elem, LazyReference) - self.assertIsInstance(occ.in_embedded.direct, LazyReference) + assert isinstance(elem, LazyReference) + assert isinstance(occ.in_embedded.direct, LazyReference) for elem in occ.in_embedded.in_list: - self.assertIsInstance(elem, LazyReference) + assert isinstance(elem, LazyReference) occ = Ocurrence( in_list=[animal1, animal2], @@ -346,19 +348,19 @@ class TestGenericLazyReferenceField(MongoDBTestCase): animal = Animal(name="Leopard", tag="heavy").save() Ocurrence(person="test", animal=animal).save() p = Ocurrence.objects.get() - self.assertIsInstance(p.animal, LazyReference) + assert isinstance(p.animal, LazyReference) fetched_animal = p.animal.fetch() - self.assertEqual(fetched_animal, animal) + assert fetched_animal == animal # `fetch` keep cache on referenced document by default... animal.tag = "not so heavy" animal.save() double_fetch = p.animal.fetch() - self.assertIs(fetched_animal, double_fetch) - self.assertEqual(double_fetch.tag, "heavy") + assert fetched_animal is double_fetch + assert double_fetch.tag == "heavy" # ...unless specified otherwise fetch_force = p.animal.fetch(force=True) - self.assertIsNot(fetch_force, fetched_animal) - self.assertEqual(fetch_force.tag, "not so heavy") + assert fetch_force is not fetched_animal + assert fetch_force.tag == "not so heavy" def test_generic_lazy_reference_choices(self): class Animal(Document): @@ -385,13 +387,13 @@ class TestGenericLazyReferenceField(MongoDBTestCase): occ_animal = Ocurrence(living_thing=animal, thing=animal).save() occ_vegetal = Ocurrence(living_thing=vegetal, thing=vegetal).save() - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): Ocurrence(living_thing=mineral).save() occ = Ocurrence.objects.get(living_thing=animal) - self.assertEqual(occ, occ_animal) - self.assertIsInstance(occ.thing, LazyReference) - self.assertIsInstance(occ.living_thing, LazyReference) + assert occ == occ_animal + assert isinstance(occ.thing, LazyReference) + assert isinstance(occ.living_thing, LazyReference) occ.thing = vegetal occ.living_thing = vegetal @@ -399,7 +401,7 @@ class TestGenericLazyReferenceField(MongoDBTestCase): occ.thing = mineral occ.living_thing = mineral - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): occ.save() def test_generic_lazy_reference_set(self): @@ -434,7 +436,7 @@ class TestGenericLazyReferenceField(MongoDBTestCase): ): p = Ocurrence(person="test", animal=ref).save() p.reload() - self.assertIsInstance(p.animal, (LazyReference, Document)) + assert isinstance(p.animal, (LazyReference, Document)) p.animal.fetch() def test_generic_lazy_reference_bad_set(self): @@ -455,7 +457,7 @@ class TestGenericLazyReferenceField(MongoDBTestCase): animal = Animal(name="Leopard", tag="heavy").save() baddoc = BadDoc().save() for bad in (42, "foo", baddoc, LazyReference(BadDoc, animal.pk)): - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): p = Ocurrence(person="test", animal=bad).save() def test_generic_lazy_reference_query_conversion(self): @@ -481,14 +483,14 @@ class TestGenericLazyReferenceField(MongoDBTestCase): post2.save() post = BlogPost.objects(author=m1).first() - self.assertEqual(post.id, post1.id) + assert post.id == post1.id post = BlogPost.objects(author=m2).first() - self.assertEqual(post.id, post2.id) + assert post.id == post2.id # Same thing by passing a LazyReference instance post = BlogPost.objects(author=LazyReference(Member, m2.pk)).first() - self.assertEqual(post.id, post2.id) + assert post.id == post2.id def test_generic_lazy_reference_not_set(self): class Animal(Document): @@ -504,7 +506,7 @@ class TestGenericLazyReferenceField(MongoDBTestCase): Ocurrence(person="foo").save() p = Ocurrence.objects.get() - self.assertIs(p.animal, None) + assert p.animal is None def test_generic_lazy_reference_accepts_string_instead_of_class(self): class Animal(Document): @@ -521,7 +523,7 @@ class TestGenericLazyReferenceField(MongoDBTestCase): animal = Animal().save() Ocurrence(animal=animal).save() p = Ocurrence.objects.get() - self.assertEqual(p.animal, animal) + assert p.animal == animal def test_generic_lazy_reference_embedded(self): class Animal(Document): @@ -544,12 +546,12 @@ class TestGenericLazyReferenceField(MongoDBTestCase): animal2 = Animal(name="cheeta").save() def check_fields_type(occ): - self.assertIsInstance(occ.direct, LazyReference) + assert isinstance(occ.direct, LazyReference) for elem in occ.in_list: - self.assertIsInstance(elem, LazyReference) - self.assertIsInstance(occ.in_embedded.direct, LazyReference) + assert isinstance(elem, LazyReference) + assert isinstance(occ.in_embedded.direct, LazyReference) for elem in occ.in_embedded.in_list: - self.assertIsInstance(elem, LazyReference) + assert isinstance(elem, LazyReference) occ = Ocurrence( in_list=[animal1, animal2], diff --git a/tests/fields/test_long_field.py b/tests/fields/test_long_field.py index ab86eccd..51f8e255 100644 --- a/tests/fields/test_long_field.py +++ b/tests/fields/test_long_field.py @@ -10,6 +10,7 @@ from mongoengine import * from mongoengine.connection import get_db from tests.utils import MongoDBTestCase +import pytest class TestLongField(MongoDBTestCase): @@ -24,10 +25,10 @@ class TestLongField(MongoDBTestCase): doc = TestLongFieldConsideredAsInt64(some_long=42).save() db = get_db() - self.assertIsInstance( + assert isinstance( db.test_long_field_considered_as_int64.find()[0]["some_long"], Int64 ) - self.assertIsInstance(doc.some_long, six.integer_types) + assert isinstance(doc.some_long, six.integer_types) def test_long_validation(self): """Ensure that invalid values cannot be assigned to long fields. @@ -41,11 +42,14 @@ class TestLongField(MongoDBTestCase): doc.validate() doc.value = -1 - self.assertRaises(ValidationError, doc.validate) + with pytest.raises(ValidationError): + doc.validate() doc.value = 120 - self.assertRaises(ValidationError, doc.validate) + with pytest.raises(ValidationError): + doc.validate() doc.value = "ten" - self.assertRaises(ValidationError, doc.validate) + with pytest.raises(ValidationError): + doc.validate() def test_long_ne_operator(self): class TestDocument(Document): @@ -56,4 +60,4 @@ class TestLongField(MongoDBTestCase): TestDocument(long_fld=None).save() TestDocument(long_fld=1).save() - self.assertEqual(1, TestDocument.objects(long_fld__ne=None).count()) + assert 1 == TestDocument.objects(long_fld__ne=None).count() diff --git a/tests/fields/test_map_field.py b/tests/fields/test_map_field.py index 54f70aa1..fd56ddd0 100644 --- a/tests/fields/test_map_field.py +++ b/tests/fields/test_map_field.py @@ -4,6 +4,7 @@ import datetime from mongoengine import * from tests.utils import MongoDBTestCase +import pytest class TestMapField(MongoDBTestCase): @@ -19,11 +20,11 @@ class TestMapField(MongoDBTestCase): e.mapping["someint"] = 1 e.save() - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): e.mapping["somestring"] = "abc" e.save() - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): class NoDeclaredType(Document): mapping = MapField() @@ -51,10 +52,10 @@ class TestMapField(MongoDBTestCase): e.save() e2 = Extensible.objects.get(id=e.id) - self.assertIsInstance(e2.mapping["somestring"], StringSetting) - self.assertIsInstance(e2.mapping["someint"], IntegerSetting) + assert isinstance(e2.mapping["somestring"], StringSetting) + assert isinstance(e2.mapping["someint"], IntegerSetting) - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): e.mapping["someint"] = 123 e.save() @@ -74,9 +75,9 @@ class TestMapField(MongoDBTestCase): Test.objects.update_one(inc__my_map__DICTIONARY_KEY__number=1) test = Test.objects.get() - self.assertEqual(test.my_map["DICTIONARY_KEY"].number, 2) + assert test.my_map["DICTIONARY_KEY"].number == 2 doc = self.db.test.find_one() - self.assertEqual(doc["x"]["DICTIONARY_KEY"]["i"], 2) + assert doc["x"]["DICTIONARY_KEY"]["i"] == 2 def test_mapfield_numerical_index(self): """Ensure that MapField accept numeric strings as indexes.""" @@ -116,13 +117,13 @@ class TestMapField(MongoDBTestCase): actions={"friends": Action(operation="drink", object="beer")}, ).save() - self.assertEqual(1, Log.objects(visited__friends__exists=True).count()) + assert 1 == Log.objects(visited__friends__exists=True).count() - self.assertEqual( - 1, - Log.objects( + assert ( + 1 + == Log.objects( actions__friends__operation="drink", actions__friends__object="beer" - ).count(), + ).count() ) def test_map_field_unicode(self): @@ -139,7 +140,7 @@ class TestMapField(MongoDBTestCase): tree.save() - self.assertEqual( - BlogPost.objects.get(id=tree.id).info_dict[u"éééé"].description, - u"VALUE: éééé", + assert ( + BlogPost.objects.get(id=tree.id).info_dict[u"éééé"].description + == u"VALUE: éééé" ) diff --git a/tests/fields/test_reference_field.py b/tests/fields/test_reference_field.py index 783a46da..783d1315 100644 --- a/tests/fields/test_reference_field.py +++ b/tests/fields/test_reference_field.py @@ -4,6 +4,7 @@ from bson import DBRef, SON from mongoengine import * from tests.utils import MongoDBTestCase +import pytest class TestReferenceField(MongoDBTestCase): @@ -24,19 +25,22 @@ class TestReferenceField(MongoDBTestCase): # Make sure ReferenceField only accepts a document class or a string # with a document class name. - self.assertRaises(ValidationError, ReferenceField, EmbeddedDocument) + with pytest.raises(ValidationError): + ReferenceField(EmbeddedDocument) user = User(name="Test User") # Ensure that the referenced object must have been saved post1 = BlogPost(content="Chips and gravy taste good.") post1.author = user - self.assertRaises(ValidationError, post1.save) + with pytest.raises(ValidationError): + post1.save() # Check that an invalid object type cannot be used post2 = BlogPost(content="Chips and chilli taste good.") post1.author = post2 - self.assertRaises(ValidationError, post1.validate) + with pytest.raises(ValidationError): + post1.validate() # Ensure ObjectID's are accepted as references user_object_id = user.pk @@ -52,7 +56,8 @@ class TestReferenceField(MongoDBTestCase): # Make sure referencing a saved document of the *wrong* type fails post2.save() post1.author = post2 - self.assertRaises(ValidationError, post1.validate) + with pytest.raises(ValidationError): + post1.validate() def test_objectid_reference_fields(self): """Make sure storing Object ID references works.""" @@ -67,7 +72,7 @@ class TestReferenceField(MongoDBTestCase): Person(name="Ross", parent=p1.pk).save() p = Person.objects.get(name="Ross") - self.assertEqual(p.parent, p1) + assert p.parent == p1 def test_dbref_reference_fields(self): """Make sure storing references as bson.dbref.DBRef works.""" @@ -81,13 +86,12 @@ class TestReferenceField(MongoDBTestCase): p1 = Person(name="John").save() Person(name="Ross", parent=p1).save() - self.assertEqual( - Person._get_collection().find_one({"name": "Ross"})["parent"], - DBRef("person", p1.pk), + assert Person._get_collection().find_one({"name": "Ross"})["parent"] == DBRef( + "person", p1.pk ) p = Person.objects.get(name="Ross") - self.assertEqual(p.parent, p1) + assert p.parent == p1 def test_dbref_to_mongo(self): """Make sure that calling to_mongo on a ReferenceField which @@ -100,9 +104,7 @@ class TestReferenceField(MongoDBTestCase): parent = ReferenceField("self", dbref=False) p = Person(name="Steve", parent=DBRef("person", "abcdefghijklmnop")) - self.assertEqual( - p.to_mongo(), SON([("name", u"Steve"), ("parent", "abcdefghijklmnop")]) - ) + assert p.to_mongo() == SON([("name", u"Steve"), ("parent", "abcdefghijklmnop")]) def test_objectid_reference_fields(self): class Person(Document): @@ -116,10 +118,10 @@ class TestReferenceField(MongoDBTestCase): col = Person._get_collection() data = col.find_one({"name": "Ross"}) - self.assertEqual(data["parent"], p1.pk) + assert data["parent"] == p1.pk p = Person.objects.get(name="Ross") - self.assertEqual(p.parent, p1) + assert p.parent == p1 def test_undefined_reference(self): """Ensure that ReferenceFields may reference undefined Documents. @@ -144,14 +146,14 @@ class TestReferenceField(MongoDBTestCase): me.save() obj = Product.objects(company=ten_gen).first() - self.assertEqual(obj, mongodb) - self.assertEqual(obj.company, ten_gen) + assert obj == mongodb + assert obj.company == ten_gen obj = Product.objects(company=None).first() - self.assertEqual(obj, me) + assert obj == me obj = Product.objects.get(company=None) - self.assertEqual(obj, me) + assert obj == me def test_reference_query_conversion(self): """Ensure that ReferenceFields can be queried using objects and values @@ -180,10 +182,10 @@ class TestReferenceField(MongoDBTestCase): post2.save() post = BlogPost.objects(author=m1).first() - self.assertEqual(post.id, post1.id) + assert post.id == post1.id post = BlogPost.objects(author=m2).first() - self.assertEqual(post.id, post2.id) + assert post.id == post2.id def test_reference_query_conversion_dbref(self): """Ensure that ReferenceFields can be queried using objects and values @@ -212,7 +214,7 @@ class TestReferenceField(MongoDBTestCase): post2.save() post = BlogPost.objects(author=m1).first() - self.assertEqual(post.id, post1.id) + assert post.id == post1.id post = BlogPost.objects(author=m2).first() - self.assertEqual(post.id, post2.id) + assert post.id == post2.id diff --git a/tests/fields/test_sequence_field.py b/tests/fields/test_sequence_field.py index f2c8388b..aa83f710 100644 --- a/tests/fields/test_sequence_field.py +++ b/tests/fields/test_sequence_field.py @@ -18,17 +18,17 @@ class TestSequenceField(MongoDBTestCase): Person(name="Person %s" % x).save() c = self.db["mongoengine.counters"].find_one({"_id": "person.id"}) - self.assertEqual(c["next"], 10) + assert c["next"] == 10 ids = [i.id for i in Person.objects] - self.assertEqual(ids, range(1, 11)) + assert ids == range(1, 11) c = self.db["mongoengine.counters"].find_one({"_id": "person.id"}) - self.assertEqual(c["next"], 10) + assert c["next"] == 10 Person.id.set_next_value(1000) c = self.db["mongoengine.counters"].find_one({"_id": "person.id"}) - self.assertEqual(c["next"], 1000) + assert c["next"] == 1000 def test_sequence_field_get_next_value(self): class Person(Document): @@ -41,10 +41,10 @@ class TestSequenceField(MongoDBTestCase): for x in range(10): Person(name="Person %s" % x).save() - self.assertEqual(Person.id.get_next_value(), 11) + assert Person.id.get_next_value() == 11 self.db["mongoengine.counters"].drop() - self.assertEqual(Person.id.get_next_value(), 1) + assert Person.id.get_next_value() == 1 class Person(Document): id = SequenceField(primary_key=True, value_decorator=str) @@ -56,10 +56,10 @@ class TestSequenceField(MongoDBTestCase): for x in range(10): Person(name="Person %s" % x).save() - self.assertEqual(Person.id.get_next_value(), "11") + assert Person.id.get_next_value() == "11" self.db["mongoengine.counters"].drop() - self.assertEqual(Person.id.get_next_value(), "1") + assert Person.id.get_next_value() == "1" def test_sequence_field_sequence_name(self): class Person(Document): @@ -73,17 +73,17 @@ class TestSequenceField(MongoDBTestCase): Person(name="Person %s" % x).save() c = self.db["mongoengine.counters"].find_one({"_id": "jelly.id"}) - self.assertEqual(c["next"], 10) + assert c["next"] == 10 ids = [i.id for i in Person.objects] - self.assertEqual(ids, range(1, 11)) + assert ids == range(1, 11) c = self.db["mongoengine.counters"].find_one({"_id": "jelly.id"}) - self.assertEqual(c["next"], 10) + assert c["next"] == 10 Person.id.set_next_value(1000) c = self.db["mongoengine.counters"].find_one({"_id": "jelly.id"}) - self.assertEqual(c["next"], 1000) + assert c["next"] == 1000 def test_multiple_sequence_fields(self): class Person(Document): @@ -98,24 +98,24 @@ class TestSequenceField(MongoDBTestCase): Person(name="Person %s" % x).save() c = self.db["mongoengine.counters"].find_one({"_id": "person.id"}) - self.assertEqual(c["next"], 10) + assert c["next"] == 10 ids = [i.id for i in Person.objects] - self.assertEqual(ids, range(1, 11)) + assert ids == range(1, 11) counters = [i.counter for i in Person.objects] - self.assertEqual(counters, range(1, 11)) + assert counters == range(1, 11) c = self.db["mongoengine.counters"].find_one({"_id": "person.id"}) - self.assertEqual(c["next"], 10) + assert c["next"] == 10 Person.id.set_next_value(1000) c = self.db["mongoengine.counters"].find_one({"_id": "person.id"}) - self.assertEqual(c["next"], 1000) + assert c["next"] == 1000 Person.counter.set_next_value(999) c = self.db["mongoengine.counters"].find_one({"_id": "person.counter"}) - self.assertEqual(c["next"], 999) + assert c["next"] == 999 def test_sequence_fields_reload(self): class Animal(Document): @@ -127,20 +127,20 @@ class TestSequenceField(MongoDBTestCase): a = Animal(name="Boi").save() - self.assertEqual(a.counter, 1) + assert a.counter == 1 a.reload() - self.assertEqual(a.counter, 1) + assert a.counter == 1 a.counter = None - self.assertEqual(a.counter, 2) + assert a.counter == 2 a.save() - self.assertEqual(a.counter, 2) + assert a.counter == 2 a = Animal.objects.first() - self.assertEqual(a.counter, 2) + assert a.counter == 2 a.reload() - self.assertEqual(a.counter, 2) + assert a.counter == 2 def test_multiple_sequence_fields_on_docs(self): class Animal(Document): @@ -160,22 +160,22 @@ class TestSequenceField(MongoDBTestCase): Person(name="Person %s" % x).save() c = self.db["mongoengine.counters"].find_one({"_id": "person.id"}) - self.assertEqual(c["next"], 10) + assert c["next"] == 10 c = self.db["mongoengine.counters"].find_one({"_id": "animal.id"}) - self.assertEqual(c["next"], 10) + assert c["next"] == 10 ids = [i.id for i in Person.objects] - self.assertEqual(ids, range(1, 11)) + assert ids == range(1, 11) id = [i.id for i in Animal.objects] - self.assertEqual(id, range(1, 11)) + assert id == range(1, 11) c = self.db["mongoengine.counters"].find_one({"_id": "person.id"}) - self.assertEqual(c["next"], 10) + assert c["next"] == 10 c = self.db["mongoengine.counters"].find_one({"_id": "animal.id"}) - self.assertEqual(c["next"], 10) + assert c["next"] == 10 def test_sequence_field_value_decorator(self): class Person(Document): @@ -190,13 +190,13 @@ class TestSequenceField(MongoDBTestCase): p.save() c = self.db["mongoengine.counters"].find_one({"_id": "person.id"}) - self.assertEqual(c["next"], 10) + assert c["next"] == 10 ids = [i.id for i in Person.objects] - self.assertEqual(ids, map(str, range(1, 11))) + assert ids == map(str, range(1, 11)) c = self.db["mongoengine.counters"].find_one({"_id": "person.id"}) - self.assertEqual(c["next"], 10) + assert c["next"] == 10 def test_embedded_sequence_field(self): class Comment(EmbeddedDocument): @@ -218,10 +218,10 @@ class TestSequenceField(MongoDBTestCase): ], ).save() c = self.db["mongoengine.counters"].find_one({"_id": "comment.id"}) - self.assertEqual(c["next"], 2) + assert c["next"] == 2 post = Post.objects.first() - self.assertEqual(1, post.comments[0].id) - self.assertEqual(2, post.comments[1].id) + assert 1 == post.comments[0].id + assert 2 == post.comments[1].id def test_inherited_sequencefield(self): class Base(Document): @@ -241,16 +241,14 @@ class TestSequenceField(MongoDBTestCase): foo = Foo(name="Foo") foo.save() - self.assertTrue( - "base.counter" in self.db["mongoengine.counters"].find().distinct("_id") - ) - self.assertFalse( + assert "base.counter" in self.db["mongoengine.counters"].find().distinct("_id") + assert not ( ("foo.counter" or "bar.counter") in self.db["mongoengine.counters"].find().distinct("_id") ) - self.assertNotEqual(foo.counter, bar.counter) - self.assertEqual(foo._fields["counter"].owner_document, Base) - self.assertEqual(bar._fields["counter"].owner_document, Base) + assert foo.counter != bar.counter + assert foo._fields["counter"].owner_document == Base + assert bar._fields["counter"].owner_document == Base def test_no_inherited_sequencefield(self): class Base(Document): @@ -269,13 +267,12 @@ class TestSequenceField(MongoDBTestCase): foo = Foo(name="Foo") foo.save() - self.assertFalse( + assert not ( "base.counter" in self.db["mongoengine.counters"].find().distinct("_id") ) - self.assertTrue( - ("foo.counter" and "bar.counter") - in self.db["mongoengine.counters"].find().distinct("_id") - ) - self.assertEqual(foo.counter, bar.counter) - self.assertEqual(foo._fields["counter"].owner_document, Foo) - self.assertEqual(bar._fields["counter"].owner_document, Bar) + assert ("foo.counter" and "bar.counter") in self.db[ + "mongoengine.counters" + ].find().distinct("_id") + assert foo.counter == bar.counter + assert foo._fields["counter"].owner_document == Foo + assert bar._fields["counter"].owner_document == Bar diff --git a/tests/fields/test_url_field.py b/tests/fields/test_url_field.py index 81baf8d0..e7df0e08 100644 --- a/tests/fields/test_url_field.py +++ b/tests/fields/test_url_field.py @@ -2,6 +2,7 @@ from mongoengine import * from tests.utils import MongoDBTestCase +import pytest class TestURLField(MongoDBTestCase): @@ -13,7 +14,8 @@ class TestURLField(MongoDBTestCase): link = Link() link.url = "google" - self.assertRaises(ValidationError, link.validate) + with pytest.raises(ValidationError): + link.validate() link.url = "http://www.google.com:8080" link.validate() @@ -29,11 +31,11 @@ class TestURLField(MongoDBTestCase): # TODO fix URL validation - this *IS* a valid URL # For now we just want to make sure that the error message is correct - with self.assertRaises(ValidationError) as ctx_err: + with pytest.raises(ValidationError) as ctx_err: link.validate() - self.assertEqual( - unicode(ctx_err.exception), - u"ValidationError (Link:None) (Invalid URL: http://\u043f\u0440\u0438\u0432\u0435\u0442.com: ['url'])", + assert ( + unicode(ctx_err.exception) + == u"ValidationError (Link:None) (Invalid URL: http://\u043f\u0440\u0438\u0432\u0435\u0442.com: ['url'])" ) def test_url_scheme_validation(self): @@ -48,7 +50,8 @@ class TestURLField(MongoDBTestCase): link = Link() link.url = "ws://google.com" - self.assertRaises(ValidationError, link.validate) + with pytest.raises(ValidationError): + link.validate() scheme_link = SchemeLink() scheme_link.url = "ws://google.com" diff --git a/tests/fields/test_uuid_field.py b/tests/fields/test_uuid_field.py index 647dceaf..b1413f95 100644 --- a/tests/fields/test_uuid_field.py +++ b/tests/fields/test_uuid_field.py @@ -4,6 +4,7 @@ import uuid from mongoengine import * from tests.utils import MongoDBTestCase, get_as_pymongo +import pytest class Person(Document): @@ -14,9 +15,7 @@ class TestUUIDField(MongoDBTestCase): def test_storage(self): uid = uuid.uuid4() person = Person(api_key=uid).save() - self.assertEqual( - get_as_pymongo(person), {"_id": person.id, "api_key": str(uid)} - ) + assert get_as_pymongo(person) == {"_id": person.id, "api_key": str(uid)} def test_field_string(self): """Test UUID fields storing as String @@ -25,8 +24,8 @@ class TestUUIDField(MongoDBTestCase): uu = uuid.uuid4() Person(api_key=uu).save() - self.assertEqual(1, Person.objects(api_key=uu).count()) - self.assertEqual(uu, Person.objects.first().api_key) + assert 1 == Person.objects(api_key=uu).count() + assert uu == Person.objects.first().api_key person = Person() valid = (uuid.uuid4(), uuid.uuid1()) @@ -40,7 +39,8 @@ class TestUUIDField(MongoDBTestCase): ) for api_key in invalid: person.api_key = api_key - self.assertRaises(ValidationError, person.validate) + with pytest.raises(ValidationError): + person.validate() def test_field_binary(self): """Test UUID fields storing as Binary object.""" @@ -48,8 +48,8 @@ class TestUUIDField(MongoDBTestCase): uu = uuid.uuid4() Person(api_key=uu).save() - self.assertEqual(1, Person.objects(api_key=uu).count()) - self.assertEqual(uu, Person.objects.first().api_key) + assert 1 == Person.objects(api_key=uu).count() + assert uu == Person.objects.first().api_key person = Person() valid = (uuid.uuid4(), uuid.uuid1()) @@ -63,4 +63,5 @@ class TestUUIDField(MongoDBTestCase): ) for api_key in invalid: person.api_key = api_key - self.assertRaises(ValidationError, person.validate) + with pytest.raises(ValidationError): + person.validate() diff --git a/tests/queryset/test_field_list.py b/tests/queryset/test_field_list.py index 703c2031..d33c4c86 100644 --- a/tests/queryset/test_field_list.py +++ b/tests/queryset/test_field_list.py @@ -2,66 +2,67 @@ import unittest from mongoengine import * from mongoengine.queryset import QueryFieldList +import pytest class TestQueryFieldList(unittest.TestCase): def test_empty(self): q = QueryFieldList() - self.assertFalse(q) + assert not q q = QueryFieldList(always_include=["_cls"]) - self.assertFalse(q) + assert not q def test_include_include(self): q = QueryFieldList() q += QueryFieldList( fields=["a", "b"], value=QueryFieldList.ONLY, _only_called=True ) - self.assertEqual(q.as_dict(), {"a": 1, "b": 1}) + assert q.as_dict() == {"a": 1, "b": 1} q += QueryFieldList(fields=["b", "c"], value=QueryFieldList.ONLY) - self.assertEqual(q.as_dict(), {"a": 1, "b": 1, "c": 1}) + assert q.as_dict() == {"a": 1, "b": 1, "c": 1} def test_include_exclude(self): q = QueryFieldList() q += QueryFieldList(fields=["a", "b"], value=QueryFieldList.ONLY) - self.assertEqual(q.as_dict(), {"a": 1, "b": 1}) + assert q.as_dict() == {"a": 1, "b": 1} q += QueryFieldList(fields=["b", "c"], value=QueryFieldList.EXCLUDE) - self.assertEqual(q.as_dict(), {"a": 1}) + assert q.as_dict() == {"a": 1} def test_exclude_exclude(self): q = QueryFieldList() q += QueryFieldList(fields=["a", "b"], value=QueryFieldList.EXCLUDE) - self.assertEqual(q.as_dict(), {"a": 0, "b": 0}) + assert q.as_dict() == {"a": 0, "b": 0} q += QueryFieldList(fields=["b", "c"], value=QueryFieldList.EXCLUDE) - self.assertEqual(q.as_dict(), {"a": 0, "b": 0, "c": 0}) + assert q.as_dict() == {"a": 0, "b": 0, "c": 0} def test_exclude_include(self): q = QueryFieldList() q += QueryFieldList(fields=["a", "b"], value=QueryFieldList.EXCLUDE) - self.assertEqual(q.as_dict(), {"a": 0, "b": 0}) + assert q.as_dict() == {"a": 0, "b": 0} q += QueryFieldList(fields=["b", "c"], value=QueryFieldList.ONLY) - self.assertEqual(q.as_dict(), {"c": 1}) + assert q.as_dict() == {"c": 1} def test_always_include(self): q = QueryFieldList(always_include=["x", "y"]) q += QueryFieldList(fields=["a", "b", "x"], value=QueryFieldList.EXCLUDE) q += QueryFieldList(fields=["b", "c"], value=QueryFieldList.ONLY) - self.assertEqual(q.as_dict(), {"x": 1, "y": 1, "c": 1}) + assert q.as_dict() == {"x": 1, "y": 1, "c": 1} def test_reset(self): q = QueryFieldList(always_include=["x", "y"]) q += QueryFieldList(fields=["a", "b", "x"], value=QueryFieldList.EXCLUDE) q += QueryFieldList(fields=["b", "c"], value=QueryFieldList.ONLY) - self.assertEqual(q.as_dict(), {"x": 1, "y": 1, "c": 1}) + assert q.as_dict() == {"x": 1, "y": 1, "c": 1} q.reset() - self.assertFalse(q) + assert not q q += QueryFieldList(fields=["b", "c"], value=QueryFieldList.ONLY) - self.assertEqual(q.as_dict(), {"x": 1, "y": 1, "b": 1, "c": 1}) + assert q.as_dict() == {"x": 1, "y": 1, "b": 1, "c": 1} def test_using_a_slice(self): q = QueryFieldList() q += QueryFieldList(fields=["a"], value={"$slice": 5}) - self.assertEqual(q.as_dict(), {"a": {"$slice": 5}}) + assert q.as_dict() == {"a": {"$slice": 5}} class TestOnlyExcludeAll(unittest.TestCase): @@ -90,25 +91,23 @@ class TestOnlyExcludeAll(unittest.TestCase): only = ["b", "c"] qs = MyDoc.objects.fields(**{i: 1 for i in include}) - self.assertEqual( - qs._loaded_fields.as_dict(), {"a": 1, "b": 1, "c": 1, "d": 1, "e": 1} - ) + assert qs._loaded_fields.as_dict() == {"a": 1, "b": 1, "c": 1, "d": 1, "e": 1} qs = qs.only(*only) - self.assertEqual(qs._loaded_fields.as_dict(), {"b": 1, "c": 1}) + assert qs._loaded_fields.as_dict() == {"b": 1, "c": 1} qs = qs.exclude(*exclude) - self.assertEqual(qs._loaded_fields.as_dict(), {"b": 1, "c": 1}) + assert qs._loaded_fields.as_dict() == {"b": 1, "c": 1} qs = MyDoc.objects.fields(**{i: 1 for i in include}) qs = qs.exclude(*exclude) - self.assertEqual(qs._loaded_fields.as_dict(), {"a": 1, "b": 1, "c": 1}) + assert qs._loaded_fields.as_dict() == {"a": 1, "b": 1, "c": 1} qs = qs.only(*only) - self.assertEqual(qs._loaded_fields.as_dict(), {"b": 1, "c": 1}) + assert qs._loaded_fields.as_dict() == {"b": 1, "c": 1} qs = MyDoc.objects.exclude(*exclude) qs = qs.fields(**{i: 1 for i in include}) - self.assertEqual(qs._loaded_fields.as_dict(), {"a": 1, "b": 1, "c": 1}) + assert qs._loaded_fields.as_dict() == {"a": 1, "b": 1, "c": 1} qs = qs.only(*only) - self.assertEqual(qs._loaded_fields.as_dict(), {"b": 1, "c": 1}) + assert qs._loaded_fields.as_dict() == {"b": 1, "c": 1} def test_slicing(self): class MyDoc(Document): @@ -127,15 +126,16 @@ class TestOnlyExcludeAll(unittest.TestCase): qs = qs.exclude(*exclude) qs = qs.only(*only) qs = qs.fields(slice__b=5) - self.assertEqual(qs._loaded_fields.as_dict(), {"b": {"$slice": 5}, "c": 1}) + assert qs._loaded_fields.as_dict() == {"b": {"$slice": 5}, "c": 1} qs = qs.fields(slice__c=[5, 1]) - self.assertEqual( - qs._loaded_fields.as_dict(), {"b": {"$slice": 5}, "c": {"$slice": [5, 1]}} - ) + assert qs._loaded_fields.as_dict() == { + "b": {"$slice": 5}, + "c": {"$slice": [5, 1]}, + } qs = qs.exclude("c") - self.assertEqual(qs._loaded_fields.as_dict(), {"b": {"$slice": 5}}) + assert qs._loaded_fields.as_dict() == {"b": {"$slice": 5}} def test_mix_slice_with_other_fields(self): class MyDoc(Document): @@ -144,7 +144,7 @@ class TestOnlyExcludeAll(unittest.TestCase): c = ListField() qs = MyDoc.objects.fields(a=1, b=0, slice__c=2) - self.assertEqual(qs._loaded_fields.as_dict(), {"c": {"$slice": 2}, "a": 1}) + assert qs._loaded_fields.as_dict() == {"c": {"$slice": 2}, "a": 1} def test_only(self): """Ensure that QuerySet.only only returns the requested fields. @@ -153,20 +153,20 @@ class TestOnlyExcludeAll(unittest.TestCase): person.save() obj = self.Person.objects.only("name").get() - self.assertEqual(obj.name, person.name) - self.assertEqual(obj.age, None) + assert obj.name == person.name + assert obj.age == None obj = self.Person.objects.only("age").get() - self.assertEqual(obj.name, None) - self.assertEqual(obj.age, person.age) + assert obj.name == None + assert obj.age == person.age obj = self.Person.objects.only("name", "age").get() - self.assertEqual(obj.name, person.name) - self.assertEqual(obj.age, person.age) + assert obj.name == person.name + assert obj.age == person.age obj = self.Person.objects.only(*("id", "name")).get() - self.assertEqual(obj.name, person.name) - self.assertEqual(obj.age, None) + assert obj.name == person.name + assert obj.age == None # Check polymorphism still works class Employee(self.Person): @@ -176,12 +176,12 @@ class TestOnlyExcludeAll(unittest.TestCase): employee.save() obj = self.Person.objects(id=employee.id).only("age").get() - self.assertIsInstance(obj, Employee) + assert isinstance(obj, Employee) # Check field names are looked up properly obj = Employee.objects(id=employee.id).only("salary").get() - self.assertEqual(obj.salary, employee.salary) - self.assertEqual(obj.name, None) + assert obj.salary == employee.salary + assert obj.name == None def test_only_with_subfields(self): class User(EmbeddedDocument): @@ -215,29 +215,29 @@ class TestOnlyExcludeAll(unittest.TestCase): post.save() obj = BlogPost.objects.only("author.name").get() - self.assertEqual(obj.content, None) - self.assertEqual(obj.author.email, None) - self.assertEqual(obj.author.name, "Test User") - self.assertEqual(obj.comments, []) + assert obj.content == None + assert obj.author.email == None + assert obj.author.name == "Test User" + assert obj.comments == [] obj = BlogPost.objects.only("various.test_dynamic.some").get() - self.assertEqual(obj.various["test_dynamic"].some, True) + assert obj.various["test_dynamic"].some == True obj = BlogPost.objects.only("content", "comments.title").get() - self.assertEqual(obj.content, "Had a good coffee today...") - self.assertEqual(obj.author, None) - self.assertEqual(obj.comments[0].title, "I aggree") - self.assertEqual(obj.comments[1].title, "Coffee") - self.assertEqual(obj.comments[0].text, None) - self.assertEqual(obj.comments[1].text, None) + assert obj.content == "Had a good coffee today..." + assert obj.author == None + assert obj.comments[0].title == "I aggree" + assert obj.comments[1].title == "Coffee" + assert obj.comments[0].text == None + assert obj.comments[1].text == None obj = BlogPost.objects.only("comments").get() - self.assertEqual(obj.content, None) - self.assertEqual(obj.author, None) - self.assertEqual(obj.comments[0].title, "I aggree") - self.assertEqual(obj.comments[1].title, "Coffee") - self.assertEqual(obj.comments[0].text, "Great post!") - self.assertEqual(obj.comments[1].text, "I hate coffee") + assert obj.content == None + assert obj.author == None + assert obj.comments[0].title == "I aggree" + assert obj.comments[1].title == "Coffee" + assert obj.comments[0].text == "Great post!" + assert obj.comments[1].text == "I hate coffee" BlogPost.drop_collection() @@ -266,10 +266,10 @@ class TestOnlyExcludeAll(unittest.TestCase): post.save() obj = BlogPost.objects.exclude("author", "comments.text").get() - self.assertEqual(obj.author, None) - self.assertEqual(obj.content, "Had a good coffee today...") - self.assertEqual(obj.comments[0].title, "I aggree") - self.assertEqual(obj.comments[0].text, None) + assert obj.author == None + assert obj.content == "Had a good coffee today..." + assert obj.comments[0].title == "I aggree" + assert obj.comments[0].text == None BlogPost.drop_collection() @@ -301,18 +301,18 @@ class TestOnlyExcludeAll(unittest.TestCase): email.save() obj = Email.objects.exclude("content_type").exclude("body").get() - self.assertEqual(obj.sender, "me") - self.assertEqual(obj.to, "you") - self.assertEqual(obj.subject, "From Russia with Love") - self.assertEqual(obj.body, None) - self.assertEqual(obj.content_type, None) + assert obj.sender == "me" + assert obj.to == "you" + assert obj.subject == "From Russia with Love" + assert obj.body == None + assert obj.content_type == None obj = Email.objects.only("sender", "to").exclude("body", "sender").get() - self.assertEqual(obj.sender, None) - self.assertEqual(obj.to, "you") - self.assertEqual(obj.subject, None) - self.assertEqual(obj.body, None) - self.assertEqual(obj.content_type, None) + assert obj.sender == None + assert obj.to == "you" + assert obj.subject == None + assert obj.body == None + assert obj.content_type == None obj = ( Email.objects.exclude("attachments.content") @@ -320,13 +320,13 @@ class TestOnlyExcludeAll(unittest.TestCase): .only("to", "attachments.name") .get() ) - self.assertEqual(obj.attachments[0].name, "file1.doc") - self.assertEqual(obj.attachments[0].content, None) - self.assertEqual(obj.sender, None) - self.assertEqual(obj.to, "you") - self.assertEqual(obj.subject, None) - self.assertEqual(obj.body, None) - self.assertEqual(obj.content_type, None) + assert obj.attachments[0].name == "file1.doc" + assert obj.attachments[0].content == None + assert obj.sender == None + assert obj.to == "you" + assert obj.subject == None + assert obj.body == None + assert obj.content_type == None Email.drop_collection() @@ -355,11 +355,11 @@ class TestOnlyExcludeAll(unittest.TestCase): .all_fields() .get() ) - self.assertEqual(obj.sender, "me") - self.assertEqual(obj.to, "you") - self.assertEqual(obj.subject, "From Russia with Love") - self.assertEqual(obj.body, "Hello!") - self.assertEqual(obj.content_type, "text/plain") + assert obj.sender == "me" + assert obj.to == "you" + assert obj.subject == "From Russia with Love" + assert obj.body == "Hello!" + assert obj.content_type == "text/plain" Email.drop_collection() @@ -377,27 +377,27 @@ class TestOnlyExcludeAll(unittest.TestCase): # first three numbers = Numbers.objects.fields(slice__n=3).get() - self.assertEqual(numbers.n, [0, 1, 2]) + assert numbers.n == [0, 1, 2] # last three numbers = Numbers.objects.fields(slice__n=-3).get() - self.assertEqual(numbers.n, [-3, -2, -1]) + assert numbers.n == [-3, -2, -1] # skip 2, limit 3 numbers = Numbers.objects.fields(slice__n=[2, 3]).get() - self.assertEqual(numbers.n, [2, 3, 4]) + assert numbers.n == [2, 3, 4] # skip to fifth from last, limit 4 numbers = Numbers.objects.fields(slice__n=[-5, 4]).get() - self.assertEqual(numbers.n, [-5, -4, -3, -2]) + assert numbers.n == [-5, -4, -3, -2] # skip to fifth from last, limit 10 numbers = Numbers.objects.fields(slice__n=[-5, 10]).get() - self.assertEqual(numbers.n, [-5, -4, -3, -2, -1]) + assert numbers.n == [-5, -4, -3, -2, -1] # skip to fifth from last, limit 10 dict method numbers = Numbers.objects.fields(n={"$slice": [-5, 10]}).get() - self.assertEqual(numbers.n, [-5, -4, -3, -2, -1]) + assert numbers.n == [-5, -4, -3, -2, -1] def test_slicing_nested_fields(self): """Ensure that query slicing an embedded array works. @@ -417,27 +417,27 @@ class TestOnlyExcludeAll(unittest.TestCase): # first three numbers = Numbers.objects.fields(slice__embedded__n=3).get() - self.assertEqual(numbers.embedded.n, [0, 1, 2]) + assert numbers.embedded.n == [0, 1, 2] # last three numbers = Numbers.objects.fields(slice__embedded__n=-3).get() - self.assertEqual(numbers.embedded.n, [-3, -2, -1]) + assert numbers.embedded.n == [-3, -2, -1] # skip 2, limit 3 numbers = Numbers.objects.fields(slice__embedded__n=[2, 3]).get() - self.assertEqual(numbers.embedded.n, [2, 3, 4]) + assert numbers.embedded.n == [2, 3, 4] # skip to fifth from last, limit 4 numbers = Numbers.objects.fields(slice__embedded__n=[-5, 4]).get() - self.assertEqual(numbers.embedded.n, [-5, -4, -3, -2]) + assert numbers.embedded.n == [-5, -4, -3, -2] # skip to fifth from last, limit 10 numbers = Numbers.objects.fields(slice__embedded__n=[-5, 10]).get() - self.assertEqual(numbers.embedded.n, [-5, -4, -3, -2, -1]) + assert numbers.embedded.n == [-5, -4, -3, -2, -1] # skip to fifth from last, limit 10 dict method numbers = Numbers.objects.fields(embedded__n={"$slice": [-5, 10]}).get() - self.assertEqual(numbers.embedded.n, [-5, -4, -3, -2, -1]) + assert numbers.embedded.n == [-5, -4, -3, -2, -1] def test_exclude_from_subclasses_docs(self): class Base(Document): @@ -456,9 +456,10 @@ class TestOnlyExcludeAll(unittest.TestCase): User(username="mongodb", password="secret").save() user = Base.objects().exclude("password", "wibble").first() - self.assertEqual(user.password, None) + assert user.password == None - self.assertRaises(LookUpError, Base.objects.exclude, "made_up") + with pytest.raises(LookUpError): + Base.objects.exclude("made_up") if __name__ == "__main__": diff --git a/tests/queryset/test_geo.py b/tests/queryset/test_geo.py index 343f864b..a546fdb6 100644 --- a/tests/queryset/test_geo.py +++ b/tests/queryset/test_geo.py @@ -48,14 +48,14 @@ class TestGeoQueries(MongoDBTestCase): # note that "near" will show the san francisco event, too, # although it sorts to last. events = self.Event.objects(location__near=[-87.67892, 41.9120459]) - self.assertEqual(events.count(), 3) - self.assertEqual(list(events), [event1, event3, event2]) + assert events.count() == 3 + assert list(events) == [event1, event3, event2] # ensure ordering is respected by "near" events = self.Event.objects(location__near=[-87.67892, 41.9120459]) events = events.order_by("-date") - self.assertEqual(events.count(), 3) - self.assertEqual(list(events), [event3, event1, event2]) + assert events.count() == 3 + assert list(events) == [event3, event1, event2] def test_near_and_max_distance(self): """Ensure the "max_distance" operator works alongside the "near" @@ -66,8 +66,8 @@ class TestGeoQueries(MongoDBTestCase): # find events within 10 degrees of san francisco point = [-122.415579, 37.7566023] events = self.Event.objects(location__near=point, location__max_distance=10) - self.assertEqual(events.count(), 1) - self.assertEqual(events[0], event2) + assert events.count() == 1 + assert events[0] == event2 def test_near_and_min_distance(self): """Ensure the "min_distance" operator works alongside the "near" @@ -78,7 +78,7 @@ class TestGeoQueries(MongoDBTestCase): # find events at least 10 degrees away of san francisco point = [-122.415579, 37.7566023] events = self.Event.objects(location__near=point, location__min_distance=10) - self.assertEqual(events.count(), 2) + assert events.count() == 2 def test_within_distance(self): """Make sure the "within_distance" operator works.""" @@ -87,29 +87,29 @@ class TestGeoQueries(MongoDBTestCase): # find events within 5 degrees of pitchfork office, chicago point_and_distance = [[-87.67892, 41.9120459], 5] events = self.Event.objects(location__within_distance=point_and_distance) - self.assertEqual(events.count(), 2) + assert events.count() == 2 events = list(events) - self.assertNotIn(event2, events) - self.assertIn(event1, events) - self.assertIn(event3, events) + assert event2 not in events + assert event1 in events + assert event3 in events # find events within 10 degrees of san francisco point_and_distance = [[-122.415579, 37.7566023], 10] events = self.Event.objects(location__within_distance=point_and_distance) - self.assertEqual(events.count(), 1) - self.assertEqual(events[0], event2) + assert events.count() == 1 + assert events[0] == event2 # find events within 1 degree of greenpoint, broolyn, nyc, ny point_and_distance = [[-73.9509714, 40.7237134], 1] events = self.Event.objects(location__within_distance=point_and_distance) - self.assertEqual(events.count(), 0) + assert events.count() == 0 # ensure ordering is respected by "within_distance" point_and_distance = [[-87.67892, 41.9120459], 10] events = self.Event.objects(location__within_distance=point_and_distance) events = events.order_by("-date") - self.assertEqual(events.count(), 2) - self.assertEqual(events[0], event3) + assert events.count() == 2 + assert events[0] == event3 def test_within_box(self): """Ensure the "within_box" operator works.""" @@ -118,8 +118,8 @@ class TestGeoQueries(MongoDBTestCase): # check that within_box works box = [(-125.0, 35.0), (-100.0, 40.0)] events = self.Event.objects(location__within_box=box) - self.assertEqual(events.count(), 1) - self.assertEqual(events[0].id, event2.id) + assert events.count() == 1 + assert events[0].id == event2.id def test_within_polygon(self): """Ensure the "within_polygon" operator works.""" @@ -133,8 +133,8 @@ class TestGeoQueries(MongoDBTestCase): (-87.656164, 41.898061), ] events = self.Event.objects(location__within_polygon=polygon) - self.assertEqual(events.count(), 1) - self.assertEqual(events[0].id, event1.id) + assert events.count() == 1 + assert events[0].id == event1.id polygon2 = [ (-1.742249, 54.033586), @@ -142,7 +142,7 @@ class TestGeoQueries(MongoDBTestCase): (-4.40094, 53.389881), ] events = self.Event.objects(location__within_polygon=polygon2) - self.assertEqual(events.count(), 0) + assert events.count() == 0 def test_2dsphere_near(self): """Make sure the "near" operator works with a PointField, which @@ -154,14 +154,14 @@ class TestGeoQueries(MongoDBTestCase): # note that "near" will show the san francisco event, too, # although it sorts to last. events = self.Event.objects(location__near=[-87.67892, 41.9120459]) - self.assertEqual(events.count(), 3) - self.assertEqual(list(events), [event1, event3, event2]) + assert events.count() == 3 + assert list(events) == [event1, event3, event2] # ensure ordering is respected by "near" events = self.Event.objects(location__near=[-87.67892, 41.9120459]) events = events.order_by("-date") - self.assertEqual(events.count(), 3) - self.assertEqual(list(events), [event3, event1, event2]) + assert events.count() == 3 + assert list(events) == [event3, event1, event2] def test_2dsphere_near_and_max_distance(self): """Ensure the "max_distance" operator works alongside the "near" @@ -172,21 +172,21 @@ class TestGeoQueries(MongoDBTestCase): # find events within 10km of san francisco point = [-122.415579, 37.7566023] events = self.Event.objects(location__near=point, location__max_distance=10000) - self.assertEqual(events.count(), 1) - self.assertEqual(events[0], event2) + assert events.count() == 1 + assert events[0] == event2 # find events within 1km of greenpoint, broolyn, nyc, ny events = self.Event.objects( location__near=[-73.9509714, 40.7237134], location__max_distance=1000 ) - self.assertEqual(events.count(), 0) + assert events.count() == 0 # ensure ordering is respected by "near" events = self.Event.objects( location__near=[-87.67892, 41.9120459], location__max_distance=10000 ).order_by("-date") - self.assertEqual(events.count(), 2) - self.assertEqual(events[0], event3) + assert events.count() == 2 + assert events[0] == event3 def test_2dsphere_geo_within_box(self): """Ensure the "geo_within_box" operator works with a 2dsphere @@ -197,8 +197,8 @@ class TestGeoQueries(MongoDBTestCase): # check that within_box works box = [(-125.0, 35.0), (-100.0, 40.0)] events = self.Event.objects(location__geo_within_box=box) - self.assertEqual(events.count(), 1) - self.assertEqual(events[0].id, event2.id) + assert events.count() == 1 + assert events[0].id == event2.id def test_2dsphere_geo_within_polygon(self): """Ensure the "geo_within_polygon" operator works with a @@ -214,8 +214,8 @@ class TestGeoQueries(MongoDBTestCase): (-87.656164, 41.898061), ] events = self.Event.objects(location__geo_within_polygon=polygon) - self.assertEqual(events.count(), 1) - self.assertEqual(events[0].id, event1.id) + assert events.count() == 1 + assert events[0].id == event1.id polygon2 = [ (-1.742249, 54.033586), @@ -223,7 +223,7 @@ class TestGeoQueries(MongoDBTestCase): (-4.40094, 53.389881), ] events = self.Event.objects(location__geo_within_polygon=polygon2) - self.assertEqual(events.count(), 0) + assert events.count() == 0 def test_2dsphere_near_and_min_max_distance(self): """Ensure "min_distace" and "max_distance" operators work well @@ -237,15 +237,15 @@ class TestGeoQueries(MongoDBTestCase): location__min_distance=1000, location__max_distance=10000, ).order_by("-date") - self.assertEqual(events.count(), 1) - self.assertEqual(events[0], event3) + assert events.count() == 1 + assert events[0] == event3 # ensure ordering is respected by "near" with "min_distance" events = self.Event.objects( location__near=[-87.67892, 41.9120459], location__min_distance=10000 ).order_by("-date") - self.assertEqual(events.count(), 1) - self.assertEqual(events[0], event2) + assert events.count() == 1 + assert events[0] == event2 def test_2dsphere_geo_within_center(self): """Make sure the "geo_within_center" operator works with a @@ -256,11 +256,11 @@ class TestGeoQueries(MongoDBTestCase): # find events within 5 degrees of pitchfork office, chicago point_and_distance = [[-87.67892, 41.9120459], 2] events = self.Event.objects(location__geo_within_center=point_and_distance) - self.assertEqual(events.count(), 2) + assert events.count() == 2 events = list(events) - self.assertNotIn(event2, events) - self.assertIn(event1, events) - self.assertIn(event3, events) + assert event2 not in events + assert event1 in events + assert event3 in events def _test_embedded(self, point_field_class): """Helper test method ensuring given point field class works @@ -290,8 +290,8 @@ class TestGeoQueries(MongoDBTestCase): # note that "near" will show the san francisco event, too, # although it sorts to last. events = Event.objects(venue__location__near=[-87.67892, 41.9120459]) - self.assertEqual(events.count(), 3) - self.assertEqual(list(events), [event1, event3, event2]) + assert events.count() == 3 + assert list(events) == [event1, event3, event2] def test_geo_spatial_embedded(self): """Make sure GeoPointField works properly in an embedded document.""" @@ -319,55 +319,55 @@ class TestGeoQueries(MongoDBTestCase): # Finds both points because they are within 60 km of the reference # point equidistant between them. points = Point.objects(location__near_sphere=[-122, 37.5]) - self.assertEqual(points.count(), 2) + assert points.count() == 2 # Same behavior for _within_spherical_distance points = Point.objects( location__within_spherical_distance=[[-122, 37.5], 60 / earth_radius] ) - self.assertEqual(points.count(), 2) + assert points.count() == 2 points = Point.objects( location__near_sphere=[-122, 37.5], location__max_distance=60 / earth_radius ) - self.assertEqual(points.count(), 2) + assert points.count() == 2 # Test query works with max_distance, being farer from one point points = Point.objects( location__near_sphere=[-122, 37.8], location__max_distance=60 / earth_radius ) close_point = points.first() - self.assertEqual(points.count(), 1) + assert points.count() == 1 # Test query works with min_distance, being farer from one point points = Point.objects( location__near_sphere=[-122, 37.8], location__min_distance=60 / earth_radius ) - self.assertEqual(points.count(), 1) + assert points.count() == 1 far_point = points.first() - self.assertNotEqual(close_point, far_point) + assert close_point != far_point # Finds both points, but orders the north point first because it's # closer to the reference point to the north. points = Point.objects(location__near_sphere=[-122, 38.5]) - self.assertEqual(points.count(), 2) - self.assertEqual(points[0].id, north_point.id) - self.assertEqual(points[1].id, south_point.id) + assert points.count() == 2 + assert points[0].id == north_point.id + assert points[1].id == south_point.id # Finds both points, but orders the south point first because it's # closer to the reference point to the south. points = Point.objects(location__near_sphere=[-122, 36.5]) - self.assertEqual(points.count(), 2) - self.assertEqual(points[0].id, south_point.id) - self.assertEqual(points[1].id, north_point.id) + assert points.count() == 2 + assert points[0].id == south_point.id + assert points[1].id == north_point.id # Finds only one point because only the first point is within 60km of # the reference point to the south. points = Point.objects( location__within_spherical_distance=[[-122, 36.5], 60 / earth_radius] ) - self.assertEqual(points.count(), 1) - self.assertEqual(points[0].id, south_point.id) + assert points.count() == 1 + assert points[0].id == south_point.id def test_linestring(self): class Road(Document): @@ -381,13 +381,13 @@ class TestGeoQueries(MongoDBTestCase): # near point = {"type": "Point", "coordinates": [40, 5]} roads = Road.objects.filter(line__near=point["coordinates"]).count() - self.assertEqual(1, roads) + assert 1 == roads roads = Road.objects.filter(line__near=point).count() - self.assertEqual(1, roads) + assert 1 == roads roads = Road.objects.filter(line__near={"$geometry": point}).count() - self.assertEqual(1, roads) + assert 1 == roads # Within polygon = { @@ -395,37 +395,37 @@ class TestGeoQueries(MongoDBTestCase): "coordinates": [[[40, 5], [40, 6], [41, 6], [41, 5], [40, 5]]], } roads = Road.objects.filter(line__geo_within=polygon["coordinates"]).count() - self.assertEqual(1, roads) + assert 1 == roads roads = Road.objects.filter(line__geo_within=polygon).count() - self.assertEqual(1, roads) + assert 1 == roads roads = Road.objects.filter(line__geo_within={"$geometry": polygon}).count() - self.assertEqual(1, roads) + assert 1 == roads # Intersects line = {"type": "LineString", "coordinates": [[40, 5], [40, 6]]} roads = Road.objects.filter(line__geo_intersects=line["coordinates"]).count() - self.assertEqual(1, roads) + assert 1 == roads roads = Road.objects.filter(line__geo_intersects=line).count() - self.assertEqual(1, roads) + assert 1 == roads roads = Road.objects.filter(line__geo_intersects={"$geometry": line}).count() - self.assertEqual(1, roads) + assert 1 == roads polygon = { "type": "Polygon", "coordinates": [[[40, 5], [40, 6], [41, 6], [41, 5], [40, 5]]], } roads = Road.objects.filter(line__geo_intersects=polygon["coordinates"]).count() - self.assertEqual(1, roads) + assert 1 == roads roads = Road.objects.filter(line__geo_intersects=polygon).count() - self.assertEqual(1, roads) + assert 1 == roads roads = Road.objects.filter(line__geo_intersects={"$geometry": polygon}).count() - self.assertEqual(1, roads) + assert 1 == roads def test_polygon(self): class Road(Document): @@ -439,13 +439,13 @@ class TestGeoQueries(MongoDBTestCase): # near point = {"type": "Point", "coordinates": [40, 5]} roads = Road.objects.filter(poly__near=point["coordinates"]).count() - self.assertEqual(1, roads) + assert 1 == roads roads = Road.objects.filter(poly__near=point).count() - self.assertEqual(1, roads) + assert 1 == roads roads = Road.objects.filter(poly__near={"$geometry": point}).count() - self.assertEqual(1, roads) + assert 1 == roads # Within polygon = { @@ -453,37 +453,37 @@ class TestGeoQueries(MongoDBTestCase): "coordinates": [[[40, 5], [40, 6], [41, 6], [41, 5], [40, 5]]], } roads = Road.objects.filter(poly__geo_within=polygon["coordinates"]).count() - self.assertEqual(1, roads) + assert 1 == roads roads = Road.objects.filter(poly__geo_within=polygon).count() - self.assertEqual(1, roads) + assert 1 == roads roads = Road.objects.filter(poly__geo_within={"$geometry": polygon}).count() - self.assertEqual(1, roads) + assert 1 == roads # Intersects line = {"type": "LineString", "coordinates": [[40, 5], [41, 6]]} roads = Road.objects.filter(poly__geo_intersects=line["coordinates"]).count() - self.assertEqual(1, roads) + assert 1 == roads roads = Road.objects.filter(poly__geo_intersects=line).count() - self.assertEqual(1, roads) + assert 1 == roads roads = Road.objects.filter(poly__geo_intersects={"$geometry": line}).count() - self.assertEqual(1, roads) + assert 1 == roads polygon = { "type": "Polygon", "coordinates": [[[40, 5], [40, 6], [41, 6], [41, 5], [40, 5]]], } roads = Road.objects.filter(poly__geo_intersects=polygon["coordinates"]).count() - self.assertEqual(1, roads) + assert 1 == roads roads = Road.objects.filter(poly__geo_intersects=polygon).count() - self.assertEqual(1, roads) + assert 1 == roads roads = Road.objects.filter(poly__geo_intersects={"$geometry": polygon}).count() - self.assertEqual(1, roads) + assert 1 == roads def test_aspymongo_with_only(self): """Ensure as_pymongo works with only""" @@ -495,13 +495,10 @@ class TestGeoQueries(MongoDBTestCase): p = Place(location=[24.946861267089844, 60.16311983618494]) p.save() qs = Place.objects().only("location") - self.assertDictEqual( - qs.as_pymongo()[0]["location"], - { - u"type": u"Point", - u"coordinates": [24.946861267089844, 60.16311983618494], - }, - ) + assert qs.as_pymongo()[0]["location"] == { + u"type": u"Point", + u"coordinates": [24.946861267089844, 60.16311983618494], + } def test_2dsphere_point_sets_correctly(self): class Location(Document): @@ -511,11 +508,11 @@ class TestGeoQueries(MongoDBTestCase): Location(loc=[1, 2]).save() loc = Location.objects.as_pymongo()[0] - self.assertEqual(loc["loc"], {"type": "Point", "coordinates": [1, 2]}) + assert loc["loc"] == {"type": "Point", "coordinates": [1, 2]} Location.objects.update(set__loc=[2, 1]) loc = Location.objects.as_pymongo()[0] - self.assertEqual(loc["loc"], {"type": "Point", "coordinates": [2, 1]}) + assert loc["loc"] == {"type": "Point", "coordinates": [2, 1]} def test_2dsphere_linestring_sets_correctly(self): class Location(Document): @@ -525,15 +522,11 @@ class TestGeoQueries(MongoDBTestCase): Location(line=[[1, 2], [2, 2]]).save() loc = Location.objects.as_pymongo()[0] - self.assertEqual( - loc["line"], {"type": "LineString", "coordinates": [[1, 2], [2, 2]]} - ) + assert loc["line"] == {"type": "LineString", "coordinates": [[1, 2], [2, 2]]} Location.objects.update(set__line=[[2, 1], [1, 2]]) loc = Location.objects.as_pymongo()[0] - self.assertEqual( - loc["line"], {"type": "LineString", "coordinates": [[2, 1], [1, 2]]} - ) + assert loc["line"] == {"type": "LineString", "coordinates": [[2, 1], [1, 2]]} def test_geojson_PolygonField(self): class Location(Document): @@ -543,17 +536,17 @@ class TestGeoQueries(MongoDBTestCase): Location(poly=[[[40, 5], [40, 6], [41, 6], [40, 5]]]).save() loc = Location.objects.as_pymongo()[0] - self.assertEqual( - loc["poly"], - {"type": "Polygon", "coordinates": [[[40, 5], [40, 6], [41, 6], [40, 5]]]}, - ) + assert loc["poly"] == { + "type": "Polygon", + "coordinates": [[[40, 5], [40, 6], [41, 6], [40, 5]]], + } Location.objects.update(set__poly=[[[40, 4], [40, 6], [41, 6], [40, 4]]]) loc = Location.objects.as_pymongo()[0] - self.assertEqual( - loc["poly"], - {"type": "Polygon", "coordinates": [[[40, 4], [40, 6], [41, 6], [40, 4]]]}, - ) + assert loc["poly"] == { + "type": "Polygon", + "coordinates": [[[40, 4], [40, 6], [41, 6], [40, 4]]], + } if __name__ == "__main__": diff --git a/tests/queryset/test_modify.py b/tests/queryset/test_modify.py index 60f4884c..293a463e 100644 --- a/tests/queryset/test_modify.py +++ b/tests/queryset/test_modify.py @@ -14,14 +14,14 @@ class TestFindAndModify(unittest.TestCase): Doc.drop_collection() def assertDbEqual(self, docs): - self.assertEqual(list(Doc._collection.find().sort("id")), docs) + assert list(Doc._collection.find().sort("id")) == docs def test_modify(self): Doc(id=0, value=0).save() doc = Doc(id=1, value=1).save() old_doc = Doc.objects(id=1).modify(set__value=-1) - self.assertEqual(old_doc.to_json(), doc.to_json()) + assert old_doc.to_json() == doc.to_json() self.assertDbEqual([{"_id": 0, "value": 0}, {"_id": 1, "value": -1}]) def test_modify_with_new(self): @@ -30,18 +30,18 @@ class TestFindAndModify(unittest.TestCase): new_doc = Doc.objects(id=1).modify(set__value=-1, new=True) doc.value = -1 - self.assertEqual(new_doc.to_json(), doc.to_json()) + assert new_doc.to_json() == doc.to_json() self.assertDbEqual([{"_id": 0, "value": 0}, {"_id": 1, "value": -1}]) def test_modify_not_existing(self): Doc(id=0, value=0).save() - self.assertEqual(Doc.objects(id=1).modify(set__value=-1), None) + assert Doc.objects(id=1).modify(set__value=-1) == None self.assertDbEqual([{"_id": 0, "value": 0}]) def test_modify_with_upsert(self): Doc(id=0, value=0).save() old_doc = Doc.objects(id=1).modify(set__value=1, upsert=True) - self.assertEqual(old_doc, None) + assert old_doc == None self.assertDbEqual([{"_id": 0, "value": 0}, {"_id": 1, "value": 1}]) def test_modify_with_upsert_existing(self): @@ -49,13 +49,13 @@ class TestFindAndModify(unittest.TestCase): doc = Doc(id=1, value=1).save() old_doc = Doc.objects(id=1).modify(set__value=-1, upsert=True) - self.assertEqual(old_doc.to_json(), doc.to_json()) + assert old_doc.to_json() == doc.to_json() self.assertDbEqual([{"_id": 0, "value": 0}, {"_id": 1, "value": -1}]) def test_modify_with_upsert_with_new(self): Doc(id=0, value=0).save() new_doc = Doc.objects(id=1).modify(upsert=True, new=True, set__value=1) - self.assertEqual(new_doc.to_mongo(), {"_id": 1, "value": 1}) + assert new_doc.to_mongo() == {"_id": 1, "value": 1} self.assertDbEqual([{"_id": 0, "value": 0}, {"_id": 1, "value": 1}]) def test_modify_with_remove(self): @@ -63,12 +63,12 @@ class TestFindAndModify(unittest.TestCase): doc = Doc(id=1, value=1).save() old_doc = Doc.objects(id=1).modify(remove=True) - self.assertEqual(old_doc.to_json(), doc.to_json()) + assert old_doc.to_json() == doc.to_json() self.assertDbEqual([{"_id": 0, "value": 0}]) def test_find_and_modify_with_remove_not_existing(self): Doc(id=0, value=0).save() - self.assertEqual(Doc.objects(id=1).modify(remove=True), None) + assert Doc.objects(id=1).modify(remove=True) == None self.assertDbEqual([{"_id": 0, "value": 0}]) def test_modify_with_order_by(self): @@ -78,7 +78,7 @@ class TestFindAndModify(unittest.TestCase): doc = Doc(id=3, value=0).save() old_doc = Doc.objects().order_by("-id").modify(set__value=-1) - self.assertEqual(old_doc.to_json(), doc.to_json()) + assert old_doc.to_json() == doc.to_json() self.assertDbEqual( [ {"_id": 0, "value": 3}, @@ -93,7 +93,7 @@ class TestFindAndModify(unittest.TestCase): Doc(id=1, value=1).save() old_doc = Doc.objects(id=1).only("id").modify(set__value=-1) - self.assertEqual(old_doc.to_mongo(), {"_id": 1}) + assert old_doc.to_mongo() == {"_id": 1} self.assertDbEqual([{"_id": 0, "value": 0}, {"_id": 1, "value": -1}]) def test_modify_with_push(self): @@ -106,23 +106,23 @@ class TestFindAndModify(unittest.TestCase): # Push a new tag via modify with new=False (default). BlogPost(id=blog.id).modify(push__tags="code") - self.assertEqual(blog.tags, []) + assert blog.tags == [] blog.reload() - self.assertEqual(blog.tags, ["code"]) + assert blog.tags == ["code"] # Push a new tag via modify with new=True. blog = BlogPost.objects(id=blog.id).modify(push__tags="java", new=True) - self.assertEqual(blog.tags, ["code", "java"]) + assert blog.tags == ["code", "java"] # Push a new tag with a positional argument. blog = BlogPost.objects(id=blog.id).modify(push__tags__0="python", new=True) - self.assertEqual(blog.tags, ["python", "code", "java"]) + assert blog.tags == ["python", "code", "java"] # Push multiple new tags with a positional argument. blog = BlogPost.objects(id=blog.id).modify( push__tags__1=["go", "rust"], new=True ) - self.assertEqual(blog.tags, ["python", "go", "rust", "code", "java"]) + assert blog.tags == ["python", "go", "rust", "code", "java"] if __name__ == "__main__": diff --git a/tests/queryset/test_pickable.py b/tests/queryset/test_pickable.py index 8c4e3426..d41f56df 100644 --- a/tests/queryset/test_pickable.py +++ b/tests/queryset/test_pickable.py @@ -37,13 +37,13 @@ class TestQuerysetPickable(MongoDBTestCase): loadedQs = self._get_loaded(qs) - self.assertEqual(qs.count(), loadedQs.count()) + assert qs.count() == loadedQs.count() # can update loadedQs loadedQs.update(age=23) # check - self.assertEqual(Person.objects.first().age, 23) + assert Person.objects.first().age == 23 def test_pickle_support_filtration(self): Person.objects.create(name="Alice", age=22) @@ -51,9 +51,9 @@ class TestQuerysetPickable(MongoDBTestCase): Person.objects.create(name="Bob", age=23) qs = Person.objects.filter(age__gte=22) - self.assertEqual(qs.count(), 2) + assert qs.count() == 2 loaded = self._get_loaded(qs) - self.assertEqual(loaded.count(), 2) - self.assertEqual(loaded.filter(name="Bob").first().age, 23) + assert loaded.count() == 2 + assert loaded.filter(name="Bob").first().age == 23 diff --git a/tests/queryset/test_queryset.py b/tests/queryset/test_queryset.py index 16213254..d154de8d 100644 --- a/tests/queryset/test_queryset.py +++ b/tests/queryset/test_queryset.py @@ -24,6 +24,7 @@ from mongoengine.queryset import ( QuerySetManager, queryset_manager, ) +import pytest class db_ops_tracker(query_counter): @@ -64,11 +65,11 @@ class TestQueryset(unittest.TestCase): def test_initialisation(self): """Ensure that a QuerySet is correctly initialised by QuerySetManager. """ - self.assertIsInstance(self.Person.objects, QuerySet) - self.assertEqual( - self.Person.objects._collection.name, self.Person._get_collection_name() + assert isinstance(self.Person.objects, QuerySet) + assert ( + self.Person.objects._collection.name == self.Person._get_collection_name() ) - self.assertIsInstance( + assert isinstance( self.Person.objects._collection, pymongo.collection.Collection ) @@ -78,11 +79,11 @@ class TestQueryset(unittest.TestCase): author2 = GenericReferenceField() # test addressing a field from a reference - with self.assertRaises(InvalidQueryError): + with pytest.raises(InvalidQueryError): list(BlogPost.objects(author__name="test")) # should fail for a generic reference as well - with self.assertRaises(InvalidQueryError): + with pytest.raises(InvalidQueryError): list(BlogPost.objects(author2__name="test")) def test_find(self): @@ -92,27 +93,27 @@ class TestQueryset(unittest.TestCase): # Find all people in the collection people = self.Person.objects - self.assertEqual(people.count(), 2) + assert people.count() == 2 results = list(people) - self.assertIsInstance(results[0], self.Person) - self.assertIsInstance(results[0].id, ObjectId) + assert isinstance(results[0], self.Person) + assert isinstance(results[0].id, ObjectId) - self.assertEqual(results[0], user_a) - self.assertEqual(results[0].name, "User A") - self.assertEqual(results[0].age, 20) + assert results[0] == user_a + assert results[0].name == "User A" + assert results[0].age == 20 - self.assertEqual(results[1], user_b) - self.assertEqual(results[1].name, "User B") - self.assertEqual(results[1].age, 30) + assert results[1] == user_b + assert results[1].name == "User B" + assert results[1].age == 30 # Filter people by age people = self.Person.objects(age=20) - self.assertEqual(people.count(), 1) + assert people.count() == 1 person = people.next() - self.assertEqual(person, user_a) - self.assertEqual(person.name, "User A") - self.assertEqual(person.age, 20) + assert person == user_a + assert person.name == "User A" + assert person.age == 20 def test_limit(self): """Ensure that QuerySet.limit works as expected.""" @@ -121,27 +122,27 @@ class TestQueryset(unittest.TestCase): # Test limit on a new queryset people = list(self.Person.objects.limit(1)) - self.assertEqual(len(people), 1) - self.assertEqual(people[0], user_a) + assert len(people) == 1 + assert people[0] == user_a # Test limit on an existing queryset people = self.Person.objects - self.assertEqual(len(people), 2) + assert len(people) == 2 people2 = people.limit(1) - self.assertEqual(len(people), 2) - self.assertEqual(len(people2), 1) - self.assertEqual(people2[0], user_a) + assert len(people) == 2 + assert len(people2) == 1 + assert people2[0] == user_a # Test limit with 0 as parameter people = self.Person.objects.limit(0) - self.assertEqual(people.count(with_limit_and_skip=True), 2) - self.assertEqual(len(people), 2) + assert people.count(with_limit_and_skip=True) == 2 + assert len(people) == 2 # Test chaining of only after limit person = self.Person.objects().limit(1).only("name").first() - self.assertEqual(person, user_a) - self.assertEqual(person.name, "User A") - self.assertEqual(person.age, None) + assert person == user_a + assert person.name == "User A" + assert person.age == None def test_skip(self): """Ensure that QuerySet.skip works as expected.""" @@ -150,26 +151,26 @@ class TestQueryset(unittest.TestCase): # Test skip on a new queryset people = list(self.Person.objects.skip(1)) - self.assertEqual(len(people), 1) - self.assertEqual(people[0], user_b) + assert len(people) == 1 + assert people[0] == user_b # Test skip on an existing queryset people = self.Person.objects - self.assertEqual(len(people), 2) + assert len(people) == 2 people2 = people.skip(1) - self.assertEqual(len(people), 2) - self.assertEqual(len(people2), 1) - self.assertEqual(people2[0], user_b) + assert len(people) == 2 + assert len(people2) == 1 + assert people2[0] == user_b # Test chaining of only after skip person = self.Person.objects().skip(1).only("name").first() - self.assertEqual(person, user_b) - self.assertEqual(person.name, "User B") - self.assertEqual(person.age, None) + assert person == user_b + assert person.name == "User B" + assert person.age == None def test___getitem___invalid_index(self): """Ensure slicing a queryset works as expected.""" - with self.assertRaises(TypeError): + with pytest.raises(TypeError): self.Person.objects()["a"] def test_slice(self): @@ -180,27 +181,27 @@ class TestQueryset(unittest.TestCase): # Test slice limit people = list(self.Person.objects[:2]) - self.assertEqual(len(people), 2) - self.assertEqual(people[0], user_a) - self.assertEqual(people[1], user_b) + assert len(people) == 2 + assert people[0] == user_a + assert people[1] == user_b # Test slice skip people = list(self.Person.objects[1:]) - self.assertEqual(len(people), 2) - self.assertEqual(people[0], user_b) - self.assertEqual(people[1], user_c) + assert len(people) == 2 + assert people[0] == user_b + assert people[1] == user_c # Test slice limit and skip people = list(self.Person.objects[1:2]) - self.assertEqual(len(people), 1) - self.assertEqual(people[0], user_b) + assert len(people) == 1 + assert people[0] == user_b # Test slice limit and skip on an existing queryset people = self.Person.objects - self.assertEqual(len(people), 3) + assert len(people) == 3 people2 = people[1:2] - self.assertEqual(len(people2), 1) - self.assertEqual(people2[0], user_b) + assert len(people2) == 1 + assert people2[0] == user_b # Test slice limit and skip cursor reset qs = self.Person.objects[1:2] @@ -208,31 +209,31 @@ class TestQueryset(unittest.TestCase): qs._cursor qs._cursor_obj = None people = list(qs) - self.assertEqual(len(people), 1) - self.assertEqual(people[0].name, "User B") + assert len(people) == 1 + assert people[0].name == "User B" # Test empty slice people = list(self.Person.objects[1:1]) - self.assertEqual(len(people), 0) + assert len(people) == 0 # Test slice out of range people = list(self.Person.objects[80000:80001]) - self.assertEqual(len(people), 0) + assert len(people) == 0 # Test larger slice __repr__ self.Person.objects.delete() for i in range(55): self.Person(name="A%s" % i, age=i).save() - self.assertEqual(self.Person.objects.count(), 55) - self.assertEqual("Person object", "%s" % self.Person.objects[0]) - self.assertEqual( - "[, ]", - "%s" % self.Person.objects[1:3], + assert self.Person.objects.count() == 55 + assert "Person object" == "%s" % self.Person.objects[0] + assert ( + "[, ]" + == "%s" % self.Person.objects[1:3] ) - self.assertEqual( - "[, ]", - "%s" % self.Person.objects[51:53], + assert ( + "[, ]" + == "%s" % self.Person.objects[51:53] ) def test_find_one(self): @@ -245,40 +246,42 @@ class TestQueryset(unittest.TestCase): # Retrieve the first person from the database person = self.Person.objects.first() - self.assertIsInstance(person, self.Person) - self.assertEqual(person.name, "User A") - self.assertEqual(person.age, 20) + assert isinstance(person, self.Person) + assert person.name == "User A" + assert person.age == 20 # Use a query to filter the people found to just person2 person = self.Person.objects(age=30).first() - self.assertEqual(person.name, "User B") + assert person.name == "User B" person = self.Person.objects(age__lt=30).first() - self.assertEqual(person.name, "User A") + assert person.name == "User A" # Use array syntax person = self.Person.objects[0] - self.assertEqual(person.name, "User A") + assert person.name == "User A" person = self.Person.objects[1] - self.assertEqual(person.name, "User B") + assert person.name == "User B" - with self.assertRaises(IndexError): + with pytest.raises(IndexError): self.Person.objects[2] # Find a document using just the object id person = self.Person.objects.with_id(person1.id) - self.assertEqual(person.name, "User A") + assert person.name == "User A" - with self.assertRaises(InvalidQueryError): + with pytest.raises(InvalidQueryError): self.Person.objects(name="User A").with_id(person1.id) def test_find_only_one(self): """Ensure that a query using ``get`` returns at most one result. """ # Try retrieving when no objects exists - self.assertRaises(DoesNotExist, self.Person.objects.get) - self.assertRaises(self.Person.DoesNotExist, self.Person.objects.get) + with pytest.raises(DoesNotExist): + self.Person.objects.get() + with pytest.raises(self.Person.DoesNotExist): + self.Person.objects.get() person1 = self.Person(name="User A", age=20) person1.save() @@ -286,15 +289,17 @@ class TestQueryset(unittest.TestCase): person2.save() # Retrieve the first person from the database - self.assertRaises(MultipleObjectsReturned, self.Person.objects.get) - self.assertRaises(self.Person.MultipleObjectsReturned, self.Person.objects.get) + with pytest.raises(MultipleObjectsReturned): + self.Person.objects.get() + with pytest.raises(self.Person.MultipleObjectsReturned): + self.Person.objects.get() # Use a query to filter the people found to just person2 person = self.Person.objects.get(age=30) - self.assertEqual(person.name, "User B") + assert person.name == "User B" person = self.Person.objects.get(age__lt=30) - self.assertEqual(person.name, "User A") + assert person.name == "User A" def test_find_array_position(self): """Ensure that query by array position works. @@ -313,10 +318,10 @@ class TestQueryset(unittest.TestCase): Blog.drop_collection() Blog.objects.create(tags=["a", "b"]) - self.assertEqual(Blog.objects(tags__0="a").count(), 1) - self.assertEqual(Blog.objects(tags__0="b").count(), 0) - self.assertEqual(Blog.objects(tags__1="a").count(), 0) - self.assertEqual(Blog.objects(tags__1="b").count(), 1) + assert Blog.objects(tags__0="a").count() == 1 + assert Blog.objects(tags__0="b").count() == 0 + assert Blog.objects(tags__1="a").count() == 0 + assert Blog.objects(tags__1="b").count() == 1 Blog.drop_collection() @@ -328,19 +333,19 @@ class TestQueryset(unittest.TestCase): blog2 = Blog.objects.create(posts=[post2, post1]) blog = Blog.objects(posts__0__comments__0__name="testa").get() - self.assertEqual(blog, blog1) + assert blog == blog1 blog = Blog.objects(posts__0__comments__0__name="testb").get() - self.assertEqual(blog, blog2) + assert blog == blog2 query = Blog.objects(posts__1__comments__1__name="testb") - self.assertEqual(query.count(), 2) + assert query.count() == 2 query = Blog.objects(posts__1__comments__1__name="testa") - self.assertEqual(query.count(), 0) + assert query.count() == 0 query = Blog.objects(posts__0__comments__1__name="testa") - self.assertEqual(query.count(), 0) + assert query.count() == 0 Blog.drop_collection() @@ -351,8 +356,8 @@ class TestQueryset(unittest.TestCase): A.drop_collection() A().save() - self.assertEqual(list(A.objects.none()), []) - self.assertEqual(list(A.objects.none().all()), []) + assert list(A.objects.none()) == [] + assert list(A.objects.none().all()) == [] def test_chaining(self): class A(Document): @@ -376,12 +381,12 @@ class TestQueryset(unittest.TestCase): # Doesn't work q2 = B.objects.filter(ref__in=[a1, a2]) q2 = q2.filter(ref=a1)._query - self.assertEqual(q1, q2) + assert q1 == q2 a_objects = A.objects(s="test1") query = B.objects(ref__in=a_objects) query = query.filter(boolfield=True) - self.assertEqual(query.count(), 1) + assert query.count() == 1 def test_batch_size(self): """Ensure that batch_size works.""" @@ -398,7 +403,7 @@ class TestQueryset(unittest.TestCase): cnt = 0 for a in A.objects.batch_size(10): cnt += 1 - self.assertEqual(cnt, 100) + assert cnt == 100 # test chaining qs = A.objects.all() @@ -406,11 +411,11 @@ class TestQueryset(unittest.TestCase): cnt = 0 for a in qs: cnt += 1 - self.assertEqual(cnt, 9) + assert cnt == 9 # test invalid batch size qs = A.objects.batch_size(-1) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): list(qs) def test_batch_size_cloned(self): @@ -419,9 +424,9 @@ class TestQueryset(unittest.TestCase): # test that batch size gets cloned qs = A.objects.batch_size(5) - self.assertEqual(qs._batch_size, 5) + assert qs._batch_size == 5 qs_clone = qs.clone() - self.assertEqual(qs_clone._batch_size, 5) + assert qs_clone._batch_size == 5 def test_update_write_concern(self): """Test that passing write_concern works""" @@ -437,18 +442,18 @@ class TestQueryset(unittest.TestCase): result = self.Person.objects.update(set__name="Ross", write_concern={"w": 1}) - self.assertEqual(result, 2) + assert result == 2 result = self.Person.objects.update(set__name="Ross", write_concern={"w": 0}) - self.assertEqual(result, None) + assert result == None result = self.Person.objects.update_one( set__name="Test User", write_concern={"w": 1} ) - self.assertEqual(result, 1) + assert result == 1 result = self.Person.objects.update_one( set__name="Test User", write_concern={"w": 0} ) - self.assertEqual(result, None) + assert result == None def test_update_update_has_a_value(self): """Test to ensure that update is passed a value to update to""" @@ -456,10 +461,10 @@ class TestQueryset(unittest.TestCase): author = self.Person.objects.create(name="Test User") - with self.assertRaises(OperationError): + with pytest.raises(OperationError): self.Person.objects(pk=author.pk).update({}) - with self.assertRaises(OperationError): + with pytest.raises(OperationError): self.Person.objects(pk=author.pk).update_one({}) def test_update_array_position(self): @@ -492,7 +497,7 @@ class TestQueryset(unittest.TestCase): # Update all of the first comments of second posts of all blogs Blog.objects().update(set__posts__1__comments__0__name="testc") testc_blogs = Blog.objects(posts__1__comments__0__name="testc") - self.assertEqual(testc_blogs.count(), 2) + assert testc_blogs.count() == 2 Blog.drop_collection() Blog.objects.create(posts=[post1, post2]) @@ -501,10 +506,10 @@ class TestQueryset(unittest.TestCase): # Update only the first blog returned by the query Blog.objects().update_one(set__posts__1__comments__1__name="testc") testc_blogs = Blog.objects(posts__1__comments__1__name="testc") - self.assertEqual(testc_blogs.count(), 1) + assert testc_blogs.count() == 1 # Check that using this indexing syntax on a non-list fails - with self.assertRaises(InvalidQueryError): + with pytest.raises(InvalidQueryError): Blog.objects().update(set__posts__1__comments__0__name__1="asdf") Blog.drop_collection() @@ -531,8 +536,8 @@ class TestQueryset(unittest.TestCase): BlogPost.objects(comments__by="jane").update(inc__comments__S__votes=1) post = BlogPost.objects.first() - self.assertEqual(post.comments[1].by, "jane") - self.assertEqual(post.comments[1].votes, 8) + assert post.comments[1].by == "jane" + assert post.comments[1].votes == 8 def test_update_using_positional_operator_matches_first(self): @@ -547,7 +552,7 @@ class TestQueryset(unittest.TestCase): Simple.objects(x=2).update(inc__x__S=1) simple = Simple.objects.first() - self.assertEqual(simple.x, [1, 3, 3, 2]) + assert simple.x == [1, 3, 3, 2] Simple.drop_collection() # You can set multiples @@ -559,10 +564,10 @@ class TestQueryset(unittest.TestCase): Simple.objects(x=3).update(set__x__S=0) s = Simple.objects() - self.assertEqual(s[0].x, [1, 2, 0, 4]) - self.assertEqual(s[1].x, [2, 0, 4, 5]) - self.assertEqual(s[2].x, [0, 4, 5, 6]) - self.assertEqual(s[3].x, [4, 5, 6, 7]) + assert s[0].x == [1, 2, 0, 4] + assert s[1].x == [2, 0, 4, 5] + assert s[2].x == [0, 4, 5, 6] + assert s[3].x == [4, 5, 6, 7] # Using "$unset" with an expression like this "array.$" will result in # the array item becoming None, not being removed. @@ -570,14 +575,14 @@ class TestQueryset(unittest.TestCase): Simple(x=[1, 2, 3, 4, 3, 2, 3, 4]).save() Simple.objects(x=3).update(unset__x__S=1) simple = Simple.objects.first() - self.assertEqual(simple.x, [1, 2, None, 4, 3, 2, 3, 4]) + assert simple.x == [1, 2, None, 4, 3, 2, 3, 4] # Nested updates arent supported yet.. - with self.assertRaises(OperationError): + with pytest.raises(OperationError): Simple.drop_collection() Simple(x=[{"test": [1, 2, 3, 4]}]).save() Simple.objects(x__test=2).update(set__x__S__test__S=3) - self.assertEqual(simple.x, [1, 2, 3, 4]) + assert simple.x == [1, 2, 3, 4] def test_update_using_positional_operator_embedded_document(self): """Ensure that the embedded documents can be updated using the positional @@ -606,8 +611,8 @@ class TestQueryset(unittest.TestCase): ) post = BlogPost.objects.first() - self.assertEqual(post.comments[0].by, "joe") - self.assertEqual(post.comments[0].votes.score, 4) + assert post.comments[0].by == "joe" + assert post.comments[0].votes.score == 4 def test_update_min_max(self): class Scores(Document): @@ -617,14 +622,14 @@ class TestQueryset(unittest.TestCase): scores = Scores.objects.create(high_score=800, low_score=200) Scores.objects(id=scores.id).update(min__low_score=150) - self.assertEqual(Scores.objects.get(id=scores.id).low_score, 150) + assert Scores.objects.get(id=scores.id).low_score == 150 Scores.objects(id=scores.id).update(min__low_score=250) - self.assertEqual(Scores.objects.get(id=scores.id).low_score, 150) + assert Scores.objects.get(id=scores.id).low_score == 150 Scores.objects(id=scores.id).update(max__high_score=1000) - self.assertEqual(Scores.objects.get(id=scores.id).high_score, 1000) + assert Scores.objects.get(id=scores.id).high_score == 1000 Scores.objects(id=scores.id).update(max__high_score=500) - self.assertEqual(Scores.objects.get(id=scores.id).high_score, 1000) + assert Scores.objects.get(id=scores.id).high_score == 1000 def test_update_multiple(self): class Product(Document): @@ -634,10 +639,10 @@ class TestQueryset(unittest.TestCase): product = Product.objects.create(item="ABC", price=10.99) product = Product.objects.create(item="ABC", price=10.99) Product.objects(id=product.id).update(mul__price=1.25) - self.assertEqual(Product.objects.get(id=product.id).price, 13.7375) + assert Product.objects.get(id=product.id).price == 13.7375 unknown_product = Product.objects.create(item="Unknown") Product.objects(id=unknown_product.id).update(mul__price=100) - self.assertEqual(Product.objects.get(id=unknown_product.id).price, 0) + assert Product.objects.get(id=unknown_product.id).price == 0 def test_updates_can_have_match_operators(self): class Comment(EmbeddedDocument): @@ -663,7 +668,7 @@ class TestQueryset(unittest.TestCase): Post.objects().update_one(pull__comments__vote__lt=1) - self.assertEqual(1, len(Post.objects.first().comments)) + assert 1 == len(Post.objects.first().comments) def test_mapfield_update(self): """Ensure that the MapField can be updated.""" @@ -684,8 +689,8 @@ class TestQueryset(unittest.TestCase): Club.objects().update(set__members={"John": Member(gender="F", age=14)}) club = Club.objects().first() - self.assertEqual(club.members["John"].gender, "F") - self.assertEqual(club.members["John"].age, 14) + assert club.members["John"].gender == "F" + assert club.members["John"].age == 14 def test_dictfield_update(self): """Ensure that the DictField can be updated.""" @@ -700,25 +705,25 @@ class TestQueryset(unittest.TestCase): Club.objects().update(set__members={"John": {"gender": "F", "age": 14}}) club = Club.objects().first() - self.assertEqual(club.members["John"]["gender"], "F") - self.assertEqual(club.members["John"]["age"], 14) + assert club.members["John"]["gender"] == "F" + assert club.members["John"]["age"] == 14 def test_update_results(self): self.Person.drop_collection() result = self.Person(name="Bob", age=25).update(upsert=True, full_result=True) - self.assertIsInstance(result, UpdateResult) - self.assertIn("upserted", result.raw_result) - self.assertFalse(result.raw_result["updatedExisting"]) + assert isinstance(result, UpdateResult) + assert "upserted" in result.raw_result + assert not result.raw_result["updatedExisting"] bob = self.Person.objects.first() result = bob.update(set__age=30, full_result=True) - self.assertIsInstance(result, UpdateResult) - self.assertTrue(result.raw_result["updatedExisting"]) + assert isinstance(result, UpdateResult) + assert result.raw_result["updatedExisting"] self.Person(name="Bob", age=20).save() result = self.Person.objects(name="Bob").update(set__name="bobby", multi=True) - self.assertEqual(result, 2) + assert result == 2 def test_update_validate(self): class EmDoc(EmbeddedDocument): @@ -730,13 +735,12 @@ class TestQueryset(unittest.TestCase): cdt_f = ComplexDateTimeField() ed_f = EmbeddedDocumentField(EmDoc) - self.assertRaises(ValidationError, Doc.objects().update, str_f=1, upsert=True) - self.assertRaises( - ValidationError, Doc.objects().update, dt_f="datetime", upsert=True - ) - self.assertRaises( - ValidationError, Doc.objects().update, ed_f__str_f=1, upsert=True - ) + with pytest.raises(ValidationError): + Doc.objects().update(str_f=1, upsert=True) + with pytest.raises(ValidationError): + Doc.objects().update(dt_f="datetime", upsert=True) + with pytest.raises(ValidationError): + Doc.objects().update(ed_f__str_f=1, upsert=True) def test_update_related_models(self): class TestPerson(Document): @@ -757,20 +761,20 @@ class TestQueryset(unittest.TestCase): o.owner = p p.name = "p2" - self.assertEqual(o._get_changed_fields(), ["owner"]) - self.assertEqual(p._get_changed_fields(), ["name"]) + assert o._get_changed_fields() == ["owner"] + assert p._get_changed_fields() == ["name"] o.save() - self.assertEqual(o._get_changed_fields(), []) - self.assertEqual(p._get_changed_fields(), ["name"]) # Fails; it's empty + assert o._get_changed_fields() == [] + assert p._get_changed_fields() == ["name"] # Fails; it's empty # This will do NOTHING at all, even though we changed the name p.save() p.reload() - self.assertEqual(p.name, "p2") # Fails; it's still `p1` + assert p.name == "p2" # Fails; it's still `p1` def test_upsert(self): self.Person.drop_collection() @@ -778,25 +782,25 @@ class TestQueryset(unittest.TestCase): self.Person.objects(pk=ObjectId(), name="Bob", age=30).update(upsert=True) bob = self.Person.objects.first() - self.assertEqual("Bob", bob.name) - self.assertEqual(30, bob.age) + assert "Bob" == bob.name + assert 30 == bob.age def test_upsert_one(self): self.Person.drop_collection() bob = self.Person.objects(name="Bob", age=30).upsert_one() - self.assertEqual("Bob", bob.name) - self.assertEqual(30, bob.age) + assert "Bob" == bob.name + assert 30 == bob.age bob.name = "Bobby" bob.save() bobby = self.Person.objects(name="Bobby", age=30).upsert_one() - self.assertEqual("Bobby", bobby.name) - self.assertEqual(30, bobby.age) - self.assertEqual(bob.id, bobby.id) + assert "Bobby" == bobby.name + assert 30 == bobby.age + assert bob.id == bobby.id def test_set_on_insert(self): self.Person.drop_collection() @@ -806,8 +810,8 @@ class TestQueryset(unittest.TestCase): ) bob = self.Person.objects.first() - self.assertEqual("Bob", bob.name) - self.assertEqual(30, bob.age) + assert "Bob" == bob.name + assert 30 == bob.age def test_save_and_only_on_fields_with_default(self): class Embed(EmbeddedDocument): @@ -832,9 +836,9 @@ class TestQueryset(unittest.TestCase): # Checking it was saved correctly record.reload() - self.assertEqual(record.field, 2) - self.assertEqual(record.embed_no_default.field, 2) - self.assertEqual(record.embed.field, 2) + assert record.field == 2 + assert record.embed_no_default.field == 2 + assert record.embed.field == 2 # Request only the _id field and save clone = B.objects().only("id").first() @@ -842,9 +846,9 @@ class TestQueryset(unittest.TestCase): # Reload the record and see that the embed data is not lost record.reload() - self.assertEqual(record.field, 2) - self.assertEqual(record.embed_no_default.field, 2) - self.assertEqual(record.embed.field, 2) + assert record.field == 2 + assert record.embed_no_default.field == 2 + assert record.embed.field == 2 def test_bulk_insert(self): """Ensure that bulk insert works""" @@ -863,7 +867,7 @@ class TestQueryset(unittest.TestCase): Blog.drop_collection() # Recreates the collection - self.assertEqual(0, Blog.objects.count()) + assert 0 == Blog.objects.count() comment1 = Comment(name="testa") comment2 = Comment(name="testb") @@ -873,11 +877,11 @@ class TestQueryset(unittest.TestCase): # Check bulk insert using load_bulk=False blogs = [Blog(title="%s" % i, posts=[post1, post2]) for i in range(99)] with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 Blog.objects.insert(blogs, load_bulk=False) - self.assertEqual(q, 1) # 1 entry containing the list of inserts + assert q == 1 # 1 entry containing the list of inserts - self.assertEqual(Blog.objects.count(), len(blogs)) + assert Blog.objects.count() == len(blogs) Blog.drop_collection() Blog.ensure_indexes() @@ -885,9 +889,9 @@ class TestQueryset(unittest.TestCase): # Check bulk insert using load_bulk=True blogs = [Blog(title="%s" % i, posts=[post1, post2]) for i in range(99)] with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 Blog.objects.insert(blogs) - self.assertEqual(q, 2) # 1 for insert 1 for fetch + assert q == 2 # 1 for insert 1 for fetch Blog.drop_collection() @@ -898,25 +902,27 @@ class TestQueryset(unittest.TestCase): blog1 = Blog(title="code", posts=[post1, post2]) blog2 = Blog(title="mongodb", posts=[post2, post1]) blog1, blog2 = Blog.objects.insert([blog1, blog2]) - self.assertEqual(blog1.title, "code") - self.assertEqual(blog2.title, "mongodb") + assert blog1.title == "code" + assert blog2.title == "mongodb" - self.assertEqual(Blog.objects.count(), 2) + assert Blog.objects.count() == 2 # test inserting an existing document (shouldn't be allowed) - with self.assertRaises(OperationError) as cm: + with pytest.raises(OperationError) as cm: blog = Blog.objects.first() Blog.objects.insert(blog) - self.assertEqual( - str(cm.exception), "Some documents have ObjectIds, use doc.update() instead" + assert ( + str(cm.exception) + == "Some documents have ObjectIds, use doc.update() instead" ) # test inserting a query set - with self.assertRaises(OperationError) as cm: + with pytest.raises(OperationError) as cm: blogs_qs = Blog.objects Blog.objects.insert(blogs_qs) - self.assertEqual( - str(cm.exception), "Some documents have ObjectIds, use doc.update() instead" + assert ( + str(cm.exception) + == "Some documents have ObjectIds, use doc.update() instead" ) # insert 1 new doc @@ -927,13 +933,13 @@ class TestQueryset(unittest.TestCase): blog1 = Blog(title="code", posts=[post1, post2]) blog1 = Blog.objects.insert(blog1) - self.assertEqual(blog1.title, "code") - self.assertEqual(Blog.objects.count(), 1) + assert blog1.title == "code" + assert Blog.objects.count() == 1 Blog.drop_collection() blog1 = Blog(title="code", posts=[post1, post2]) obj_id = Blog.objects.insert(blog1, load_bulk=False) - self.assertIsInstance(obj_id, ObjectId) + assert isinstance(obj_id, ObjectId) Blog.drop_collection() post3 = Post(comments=[comment1, comment1]) @@ -941,10 +947,10 @@ class TestQueryset(unittest.TestCase): blog2 = Blog(title="bar", posts=[post2, post3]) Blog.objects.insert([blog1, blog2]) - with self.assertRaises(NotUniqueError): + with pytest.raises(NotUniqueError): Blog.objects.insert(Blog(title=blog2.title)) - self.assertEqual(Blog.objects.count(), 2) + assert Blog.objects.count() == 2 def test_bulk_insert_different_class_fails(self): class Blog(Document): @@ -954,7 +960,7 @@ class TestQueryset(unittest.TestCase): pass # try inserting a different document class - with self.assertRaises(OperationError): + with pytest.raises(OperationError): Blog.objects.insert(Author()) def test_bulk_insert_with_wrong_type(self): @@ -964,10 +970,10 @@ class TestQueryset(unittest.TestCase): Blog.drop_collection() Blog(name="test").save() - with self.assertRaises(OperationError): + with pytest.raises(OperationError): Blog.objects.insert("HELLO WORLD") - with self.assertRaises(OperationError): + with pytest.raises(OperationError): Blog.objects.insert({"name": "garbage"}) def test_bulk_insert_update_input_document_ids(self): @@ -979,23 +985,23 @@ class TestQueryset(unittest.TestCase): # Test with bulk comments = [Comment(idx=idx) for idx in range(20)] for com in comments: - self.assertIsNone(com.id) + assert com.id is None returned_comments = Comment.objects.insert(comments, load_bulk=True) for com in comments: - self.assertIsInstance(com.id, ObjectId) + assert isinstance(com.id, ObjectId) input_mapping = {com.id: com.idx for com in comments} saved_mapping = {com.id: com.idx for com in returned_comments} - self.assertEqual(input_mapping, saved_mapping) + assert input_mapping == saved_mapping Comment.drop_collection() # Test with just one comment = Comment(idx=0) inserted_comment_id = Comment.objects.insert(comment, load_bulk=False) - self.assertEqual(comment.id, inserted_comment_id) + assert comment.id == inserted_comment_id def test_bulk_insert_accepts_doc_with_ids(self): class Comment(Document): @@ -1017,7 +1023,7 @@ class TestQueryset(unittest.TestCase): Comment.objects.insert(com1) - with self.assertRaises(NotUniqueError): + with pytest.raises(NotUniqueError): Comment.objects.insert(com1) def test_get_changed_fields_query_count(self): @@ -1050,28 +1056,28 @@ class TestQueryset(unittest.TestCase): o1 = Organization(name="o1", employees=[p1]).save() with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 # Fetching a document should result in a query. org = Organization.objects.get(id=o1.id) - self.assertEqual(q, 1) + assert q == 1 # Checking changed fields of a newly fetched document should not # result in a query. org._get_changed_fields() - self.assertEqual(q, 1) + assert q == 1 # Saving a doc without changing any of its fields should not result # in a query (with or without cascade=False). org = Organization.objects.get(id=o1.id) with query_counter() as q: org.save() - self.assertEqual(q, 0) + assert q == 0 org = Organization.objects.get(id=o1.id) with query_counter() as q: org.save(cascade=False) - self.assertEqual(q, 0) + assert q == 0 # Saving a doc after you append a reference to it should result in # two db operations (a query for the reference and an update). @@ -1080,7 +1086,7 @@ class TestQueryset(unittest.TestCase): with query_counter() as q: org.employees.append(p2) # dereferences p2 org.save() # saves the org - self.assertEqual(q, 2) + assert q == 2 def test_repeated_iteration(self): """Ensure that QuerySet rewinds itself one iteration finishes. @@ -1097,8 +1103,8 @@ class TestQueryset(unittest.TestCase): break people3 = [person for person in queryset] - self.assertEqual(people1, people2) - self.assertEqual(people1, people3) + assert people1 == people2 + assert people1 == people3 def test_repr(self): """Test repr behavior isnt destructive""" @@ -1116,21 +1122,21 @@ class TestQueryset(unittest.TestCase): docs = Doc.objects.order_by("number") - self.assertEqual(docs.count(), 1000) + assert docs.count() == 1000 docs_string = "%s" % docs - self.assertIn("Doc: 0", docs_string) + assert "Doc: 0" in docs_string - self.assertEqual(docs.count(), 1000) - self.assertIn("(remaining elements truncated)", "%s" % docs) + assert docs.count() == 1000 + assert "(remaining elements truncated)" in "%s" % docs # Limit and skip docs = docs[1:4] - self.assertEqual("[, , ]", "%s" % docs) + assert "[, , ]" == "%s" % docs - self.assertEqual(docs.count(with_limit_and_skip=True), 3) + assert docs.count(with_limit_and_skip=True) == 3 for doc in docs: - self.assertEqual(".. queryset mid-iteration ..", repr(docs)) + assert ".. queryset mid-iteration .." == repr(docs) def test_regex_query_shortcuts(self): """Ensure that contains, startswith, endswith, etc work. @@ -1140,54 +1146,54 @@ class TestQueryset(unittest.TestCase): # Test contains obj = self.Person.objects(name__contains="van").first() - self.assertEqual(obj, person) + assert obj == person obj = self.Person.objects(name__contains="Van").first() - self.assertEqual(obj, None) + assert obj == None # Test icontains obj = self.Person.objects(name__icontains="Van").first() - self.assertEqual(obj, person) + assert obj == person # Test startswith obj = self.Person.objects(name__startswith="Guido").first() - self.assertEqual(obj, person) + assert obj == person obj = self.Person.objects(name__startswith="guido").first() - self.assertEqual(obj, None) + assert obj == None # Test istartswith obj = self.Person.objects(name__istartswith="guido").first() - self.assertEqual(obj, person) + assert obj == person # Test endswith obj = self.Person.objects(name__endswith="Rossum").first() - self.assertEqual(obj, person) + assert obj == person obj = self.Person.objects(name__endswith="rossuM").first() - self.assertEqual(obj, None) + assert obj == None # Test iendswith obj = self.Person.objects(name__iendswith="rossuM").first() - self.assertEqual(obj, person) + assert obj == person # Test exact obj = self.Person.objects(name__exact="Guido van Rossum").first() - self.assertEqual(obj, person) + assert obj == person obj = self.Person.objects(name__exact="Guido van rossum").first() - self.assertEqual(obj, None) + assert obj == None obj = self.Person.objects(name__exact="Guido van Rossu").first() - self.assertEqual(obj, None) + assert obj == None # Test iexact obj = self.Person.objects(name__iexact="gUIDO VAN rOSSUM").first() - self.assertEqual(obj, person) + assert obj == person obj = self.Person.objects(name__iexact="gUIDO VAN rOSSU").first() - self.assertEqual(obj, None) + assert obj == None # Test unsafe expressions person = self.Person(name="Guido van Rossum [.'Geek']") person.save() obj = self.Person.objects(name__icontains="[.'Geek").first() - self.assertEqual(obj, person) + assert obj == person def test_not(self): """Ensure that the __not operator works as expected. @@ -1196,10 +1202,10 @@ class TestQueryset(unittest.TestCase): alice.save() obj = self.Person.objects(name__iexact="alice").first() - self.assertEqual(obj, alice) + assert obj == alice obj = self.Person.objects(name__not__iexact="alice").first() - self.assertEqual(obj, None) + assert obj == None def test_filter_chaining(self): """Ensure filters can be chained together. @@ -1253,12 +1259,12 @@ class TestQueryset(unittest.TestCase): published_posts = published_posts.filter( published_date__lt=datetime.datetime(2010, 1, 7, 0, 0, 0) ) - self.assertEqual(published_posts.count(), 2) + assert published_posts.count() == 2 blog_posts = BlogPost.objects blog_posts = blog_posts.filter(blog__in=[blog_1, blog_2]) blog_posts = blog_posts.filter(blog=blog_3) - self.assertEqual(blog_posts.count(), 0) + assert blog_posts.count() == 0 BlogPost.drop_collection() Blog.drop_collection() @@ -1269,14 +1275,14 @@ class TestQueryset(unittest.TestCase): people = self.Person.objects people = people.filter(name__startswith="Gui").filter(name__not__endswith="tum") - self.assertEqual(people.count(), 1) + assert people.count() == 1 def assertSequence(self, qs, expected): qs = list(qs) expected = list(expected) - self.assertEqual(len(qs), len(expected)) + assert len(qs) == len(expected) for i in range(len(qs)): - self.assertEqual(qs[i], expected[i]) + assert qs[i] == expected[i] def test_ordering(self): """Ensure default ordering is applied and can be overridden. @@ -1327,31 +1333,27 @@ class TestQueryset(unittest.TestCase): # default ordering should be used by default with db_ops_tracker() as q: BlogPost.objects.filter(title="whatever").first() - self.assertEqual(len(q.get_ops()), 1) - self.assertEqual( - q.get_ops()[0][CMD_QUERY_KEY][ORDER_BY_KEY], {"published_date": -1} - ) + assert len(q.get_ops()) == 1 + assert q.get_ops()[0][CMD_QUERY_KEY][ORDER_BY_KEY] == {"published_date": -1} # calling order_by() should clear the default ordering with db_ops_tracker() as q: BlogPost.objects.filter(title="whatever").order_by().first() - self.assertEqual(len(q.get_ops()), 1) - self.assertNotIn(ORDER_BY_KEY, q.get_ops()[0][CMD_QUERY_KEY]) + assert len(q.get_ops()) == 1 + assert ORDER_BY_KEY not in q.get_ops()[0][CMD_QUERY_KEY] # calling an explicit order_by should use a specified sort with db_ops_tracker() as q: BlogPost.objects.filter(title="whatever").order_by("published_date").first() - self.assertEqual(len(q.get_ops()), 1) - self.assertEqual( - q.get_ops()[0][CMD_QUERY_KEY][ORDER_BY_KEY], {"published_date": 1} - ) + assert len(q.get_ops()) == 1 + assert q.get_ops()[0][CMD_QUERY_KEY][ORDER_BY_KEY] == {"published_date": 1} # calling order_by() after an explicit sort should clear it with db_ops_tracker() as q: qs = BlogPost.objects.filter(title="whatever").order_by("published_date") qs.order_by().first() - self.assertEqual(len(q.get_ops()), 1) - self.assertNotIn(ORDER_BY_KEY, q.get_ops()[0][CMD_QUERY_KEY]) + assert len(q.get_ops()) == 1 + assert ORDER_BY_KEY not in q.get_ops()[0][CMD_QUERY_KEY] def test_no_ordering_for_get(self): """ Ensure that Doc.objects.get doesn't use any ordering. @@ -1370,14 +1372,14 @@ class TestQueryset(unittest.TestCase): with db_ops_tracker() as q: BlogPost.objects.get(title="whatever") - self.assertEqual(len(q.get_ops()), 1) - self.assertNotIn(ORDER_BY_KEY, q.get_ops()[0][CMD_QUERY_KEY]) + assert len(q.get_ops()) == 1 + assert ORDER_BY_KEY not in q.get_ops()[0][CMD_QUERY_KEY] # Ordering should be ignored for .get even if we set it explicitly with db_ops_tracker() as q: BlogPost.objects.order_by("-title").get(title="whatever") - self.assertEqual(len(q.get_ops()), 1) - self.assertNotIn(ORDER_BY_KEY, q.get_ops()[0][CMD_QUERY_KEY]) + assert len(q.get_ops()) == 1 + assert ORDER_BY_KEY not in q.get_ops()[0][CMD_QUERY_KEY] def test_find_embedded(self): """Ensure that an embedded document is properly returned from @@ -1397,20 +1399,20 @@ class TestQueryset(unittest.TestCase): BlogPost.objects.create(author=user, content="Had a good coffee today...") result = BlogPost.objects.first() - self.assertIsInstance(result.author, User) - self.assertEqual(result.author.name, "Test User") + assert isinstance(result.author, User) + assert result.author.name == "Test User" result = BlogPost.objects.get(author__name=user.name) - self.assertIsInstance(result.author, User) - self.assertEqual(result.author.name, "Test User") + assert isinstance(result.author, User) + assert result.author.name == "Test User" result = BlogPost.objects.get(author={"name": user.name}) - self.assertIsInstance(result.author, User) - self.assertEqual(result.author.name, "Test User") + assert isinstance(result.author, User) + assert result.author.name == "Test User" # Fails, since the string is not a type that is able to represent the # author's document structure (should be dict) - with self.assertRaises(InvalidQueryError): + with pytest.raises(InvalidQueryError): BlogPost.objects.get(author=user.name) def test_find_empty_embedded(self): @@ -1428,7 +1430,7 @@ class TestQueryset(unittest.TestCase): BlogPost.objects.create(content="Anonymous post...") result = BlogPost.objects.get(author=None) - self.assertEqual(result.author, None) + assert result.author == None def test_find_dict_item(self): """Ensure that DictField items may be found. @@ -1443,7 +1445,7 @@ class TestQueryset(unittest.TestCase): post.save() post_obj = BlogPost.objects(info__title="test").first() - self.assertEqual(post_obj.id, post.id) + assert post_obj.id == post.id BlogPost.drop_collection() @@ -1478,10 +1480,10 @@ class TestQueryset(unittest.TestCase): # Ensure that normal queries work c = BlogPost.objects(published=True).exec_js(js_func, "hits") - self.assertEqual(c, 2) + assert c == 2 c = BlogPost.objects(published=False).exec_js(js_func, "hits") - self.assertEqual(c, 1) + assert c == 1 BlogPost.drop_collection() @@ -1525,7 +1527,7 @@ class TestQueryset(unittest.TestCase): sub_code = BlogPost.objects._sub_js_fields(code) code_chunks = ['doc["cmnts"];', 'doc["doc-name"],', 'doc["cmnts"][i]["body"]'] for chunk in code_chunks: - self.assertIn(chunk, sub_code) + assert chunk in sub_code results = BlogPost.objects.exec_js(code) expected_results = [ @@ -1533,12 +1535,12 @@ class TestQueryset(unittest.TestCase): {u"comment": u"yay", u"document": u"post1"}, {u"comment": u"nice stuff", u"document": u"post2"}, ] - self.assertEqual(results, expected_results) + assert results == expected_results # Test template style code = "{{~comments.content}}" sub_code = BlogPost.objects._sub_js_fields(code) - self.assertEqual("cmnts.body", sub_code) + assert "cmnts.body" == sub_code BlogPost.drop_collection() @@ -1549,13 +1551,13 @@ class TestQueryset(unittest.TestCase): self.Person(name="User B", age=30).save() self.Person(name="User C", age=40).save() - self.assertEqual(self.Person.objects.count(), 3) + assert self.Person.objects.count() == 3 self.Person.objects(age__lt=30).delete() - self.assertEqual(self.Person.objects.count(), 2) + assert self.Person.objects.count() == 2 self.Person.objects.delete() - self.assertEqual(self.Person.objects.count(), 0) + assert self.Person.objects.count() == 0 def test_reverse_delete_rule_cascade(self): """Ensure cascading deletion of referring documents from the database. @@ -1576,9 +1578,9 @@ class TestQueryset(unittest.TestCase): BlogPost(content="Chilling out", author=me).save() BlogPost(content="Pro Testing", author=someoneelse).save() - self.assertEqual(3, BlogPost.objects.count()) + assert 3 == BlogPost.objects.count() self.Person.objects(name="Test User").delete() - self.assertEqual(1, BlogPost.objects.count()) + assert 1 == BlogPost.objects.count() def test_reverse_delete_rule_cascade_on_abstract_document(self): """Ensure cascading deletion of referring documents from the database @@ -1603,9 +1605,9 @@ class TestQueryset(unittest.TestCase): BlogPost(content="Chilling out", author=me).save() BlogPost(content="Pro Testing", author=someoneelse).save() - self.assertEqual(3, BlogPost.objects.count()) + assert 3 == BlogPost.objects.count() self.Person.objects(name="Test User").delete() - self.assertEqual(1, BlogPost.objects.count()) + assert 1 == BlogPost.objects.count() def test_reverse_delete_rule_cascade_cycle(self): """Ensure reference cascading doesn't loop if reference graph isn't @@ -1622,8 +1624,10 @@ class TestQueryset(unittest.TestCase): base.delete() - self.assertRaises(DoesNotExist, base.reload) - self.assertRaises(DoesNotExist, other.reload) + with pytest.raises(DoesNotExist): + base.reload() + with pytest.raises(DoesNotExist): + other.reload() def test_reverse_delete_rule_cascade_complex_cycle(self): """Ensure reference cascading doesn't loop if reference graph isn't @@ -1646,9 +1650,12 @@ class TestQueryset(unittest.TestCase): cat.delete() - self.assertRaises(DoesNotExist, base.reload) - self.assertRaises(DoesNotExist, other.reload) - self.assertRaises(DoesNotExist, other2.reload) + with pytest.raises(DoesNotExist): + base.reload() + with pytest.raises(DoesNotExist): + other.reload() + with pytest.raises(DoesNotExist): + other2.reload() def test_reverse_delete_rule_cascade_self_referencing(self): """Ensure self-referencing CASCADE deletes do not result in infinite @@ -1677,13 +1684,13 @@ class TestQueryset(unittest.TestCase): child_child.save() tree_size = 1 + num_children + (num_children * num_children) - self.assertEqual(tree_size, Category.objects.count()) - self.assertEqual(num_children, Category.objects(parent=base).count()) + assert tree_size == Category.objects.count() + assert num_children == Category.objects(parent=base).count() # The delete should effectively wipe out the Category collection # without resulting in infinite parent-child cascade recursion base.delete() - self.assertEqual(0, Category.objects.count()) + assert 0 == Category.objects.count() def test_reverse_delete_rule_nullify(self): """Ensure nullification of references to deleted documents. @@ -1705,11 +1712,11 @@ class TestQueryset(unittest.TestCase): post = BlogPost(content="Watching TV", category=lameness) post.save() - self.assertEqual(1, BlogPost.objects.count()) - self.assertEqual("Lameness", BlogPost.objects.first().category.name) + assert 1 == BlogPost.objects.count() + assert "Lameness" == BlogPost.objects.first().category.name Category.objects.delete() - self.assertEqual(1, BlogPost.objects.count()) - self.assertEqual(None, BlogPost.objects.first().category) + assert 1 == BlogPost.objects.count() + assert None == BlogPost.objects.first().category def test_reverse_delete_rule_nullify_on_abstract_document(self): """Ensure nullification of references to deleted documents when @@ -1732,11 +1739,11 @@ class TestQueryset(unittest.TestCase): BlogPost(content="Watching TV", author=me).save() - self.assertEqual(1, BlogPost.objects.count()) - self.assertEqual(me, BlogPost.objects.first().author) + assert 1 == BlogPost.objects.count() + assert me == BlogPost.objects.first().author self.Person.objects(name="Test User").delete() - self.assertEqual(1, BlogPost.objects.count()) - self.assertEqual(None, BlogPost.objects.first().author) + assert 1 == BlogPost.objects.count() + assert None == BlogPost.objects.first().author def test_reverse_delete_rule_deny(self): """Ensure deletion gets denied on documents that still have references @@ -1756,7 +1763,8 @@ class TestQueryset(unittest.TestCase): post = BlogPost(content="Watching TV", author=me) post.save() - self.assertRaises(OperationError, self.Person.objects.delete) + with pytest.raises(OperationError): + self.Person.objects.delete() def test_reverse_delete_rule_deny_on_abstract_document(self): """Ensure deletion gets denied on documents that still have references @@ -1777,8 +1785,9 @@ class TestQueryset(unittest.TestCase): BlogPost(content="Watching TV", author=me).save() - self.assertEqual(1, BlogPost.objects.count()) - self.assertRaises(OperationError, self.Person.objects.delete) + assert 1 == BlogPost.objects.count() + with pytest.raises(OperationError): + self.Person.objects.delete() def test_reverse_delete_rule_pull(self): """Ensure pulling of references to deleted documents. @@ -1807,8 +1816,8 @@ class TestQueryset(unittest.TestCase): post.reload() another.reload() - self.assertEqual(post.authors, [me]) - self.assertEqual(another.authors, []) + assert post.authors == [me] + assert another.authors == [] def test_reverse_delete_rule_pull_on_abstract_documents(self): """Ensure pulling of references to deleted documents when reference @@ -1841,8 +1850,8 @@ class TestQueryset(unittest.TestCase): post.reload() another.reload() - self.assertEqual(post.authors, [me]) - self.assertEqual(another.authors, []) + assert post.authors == [me] + assert another.authors == [] def test_delete_with_limits(self): class Log(Document): @@ -1854,7 +1863,7 @@ class TestQueryset(unittest.TestCase): Log().save() Log.objects()[3:5].delete() - self.assertEqual(8, Log.objects.count()) + assert 8 == Log.objects.count() def test_delete_with_limit_handles_delete_rules(self): """Ensure cascading deletion of referring documents from the database. @@ -1875,9 +1884,9 @@ class TestQueryset(unittest.TestCase): BlogPost(content="Chilling out", author=me).save() BlogPost(content="Pro Testing", author=someoneelse).save() - self.assertEqual(3, BlogPost.objects.count()) + assert 3 == BlogPost.objects.count() self.Person.objects()[:1].delete() - self.assertEqual(1, BlogPost.objects.count()) + assert 1 == BlogPost.objects.count() def test_delete_edge_case_with_write_concern_0_return_None(self): """Return None if the delete operation is unacknowledged. @@ -1887,7 +1896,7 @@ class TestQueryset(unittest.TestCase): """ p1 = self.Person(name="User Z", age=20).save() del_result = p1.delete(w=0) - self.assertEqual(None, del_result) + assert None == del_result def test_reference_field_find(self): """Ensure cascading deletion of referring documents from the database. @@ -1903,13 +1912,13 @@ class TestQueryset(unittest.TestCase): me = self.Person(name="Test User").save() BlogPost(content="test 123", author=me).save() - self.assertEqual(1, BlogPost.objects(author=me).count()) - self.assertEqual(1, BlogPost.objects(author=me.pk).count()) - self.assertEqual(1, BlogPost.objects(author="%s" % me.pk).count()) + assert 1 == BlogPost.objects(author=me).count() + assert 1 == BlogPost.objects(author=me.pk).count() + assert 1 == BlogPost.objects(author="%s" % me.pk).count() - self.assertEqual(1, BlogPost.objects(author__in=[me]).count()) - self.assertEqual(1, BlogPost.objects(author__in=[me.pk]).count()) - self.assertEqual(1, BlogPost.objects(author__in=["%s" % me.pk]).count()) + assert 1 == BlogPost.objects(author__in=[me]).count() + assert 1 == BlogPost.objects(author__in=[me.pk]).count() + assert 1 == BlogPost.objects(author__in=["%s" % me.pk]).count() def test_reference_field_find_dbref(self): """Ensure cascading deletion of referring documents from the database. @@ -1925,13 +1934,13 @@ class TestQueryset(unittest.TestCase): me = self.Person(name="Test User").save() BlogPost(content="test 123", author=me).save() - self.assertEqual(1, BlogPost.objects(author=me).count()) - self.assertEqual(1, BlogPost.objects(author=me.pk).count()) - self.assertEqual(1, BlogPost.objects(author="%s" % me.pk).count()) + assert 1 == BlogPost.objects(author=me).count() + assert 1 == BlogPost.objects(author=me.pk).count() + assert 1 == BlogPost.objects(author="%s" % me.pk).count() - self.assertEqual(1, BlogPost.objects(author__in=[me]).count()) - self.assertEqual(1, BlogPost.objects(author__in=[me.pk]).count()) - self.assertEqual(1, BlogPost.objects(author__in=["%s" % me.pk]).count()) + assert 1 == BlogPost.objects(author__in=[me]).count() + assert 1 == BlogPost.objects(author__in=[me.pk]).count() + assert 1 == BlogPost.objects(author__in=["%s" % me.pk]).count() def test_update_intfield_operator(self): class BlogPost(Document): @@ -1944,20 +1953,20 @@ class TestQueryset(unittest.TestCase): BlogPost.objects.update_one(set__hits=10) post.reload() - self.assertEqual(post.hits, 10) + assert post.hits == 10 BlogPost.objects.update_one(inc__hits=1) post.reload() - self.assertEqual(post.hits, 11) + assert post.hits == 11 BlogPost.objects.update_one(dec__hits=1) post.reload() - self.assertEqual(post.hits, 10) + assert post.hits == 10 # Negative dec operator is equal to a positive inc operator BlogPost.objects.update_one(dec__hits=-1) post.reload() - self.assertEqual(post.hits, 11) + assert post.hits == 11 def test_update_decimalfield_operator(self): class BlogPost(Document): @@ -1970,19 +1979,19 @@ class TestQueryset(unittest.TestCase): BlogPost.objects.update_one(inc__review=0.1) # test with floats post.reload() - self.assertEqual(float(post.review), 3.6) + assert float(post.review) == 3.6 BlogPost.objects.update_one(dec__review=0.1) post.reload() - self.assertEqual(float(post.review), 3.5) + assert float(post.review) == 3.5 BlogPost.objects.update_one(inc__review=Decimal(0.12)) # test with Decimal post.reload() - self.assertEqual(float(post.review), 3.62) + assert float(post.review) == 3.62 BlogPost.objects.update_one(dec__review=Decimal(0.12)) post.reload() - self.assertEqual(float(post.review), 3.5) + assert float(post.review) == 3.5 def test_update_decimalfield_operator_not_working_with_force_string(self): class BlogPost(Document): @@ -1993,7 +2002,7 @@ class TestQueryset(unittest.TestCase): post = BlogPost(review=3.5) post.save() - with self.assertRaises(OperationError): + with pytest.raises(OperationError): BlogPost.objects.update_one(inc__review=0.1) # test with floats def test_update_listfield_operator(self): @@ -2011,22 +2020,22 @@ class TestQueryset(unittest.TestCase): # ListField operator BlogPost.objects.update(push__tags="mongo") post.reload() - self.assertIn("mongo", post.tags) + assert "mongo" in post.tags BlogPost.objects.update_one(push_all__tags=["db", "nosql"]) post.reload() - self.assertIn("db", post.tags) - self.assertIn("nosql", post.tags) + assert "db" in post.tags + assert "nosql" in post.tags tags = post.tags[:-1] BlogPost.objects.update(pop__tags=1) post.reload() - self.assertEqual(post.tags, tags) + assert post.tags == tags BlogPost.objects.update_one(add_to_set__tags="unique") BlogPost.objects.update_one(add_to_set__tags="unique") post.reload() - self.assertEqual(post.tags.count("unique"), 1) + assert post.tags.count("unique") == 1 BlogPost.drop_collection() @@ -2038,12 +2047,12 @@ class TestQueryset(unittest.TestCase): post = BlogPost(title="garbage").save() - self.assertNotEqual(post.title, None) + assert post.title != None BlogPost.objects.update_one(unset__title=1) post.reload() - self.assertEqual(post.title, None) + assert post.title == None pymongo_doc = BlogPost.objects.as_pymongo().first() - self.assertNotIn("title", pymongo_doc) + assert "title" not in pymongo_doc def test_update_push_with_position(self): """Ensure that the 'push' update with position works properly. @@ -2060,16 +2069,16 @@ class TestQueryset(unittest.TestCase): BlogPost.objects.filter(id=post.id).update(push__tags="code") BlogPost.objects.filter(id=post.id).update(push__tags__0=["mongodb", "python"]) post.reload() - self.assertEqual(post.tags, ["mongodb", "python", "code"]) + assert post.tags == ["mongodb", "python", "code"] BlogPost.objects.filter(id=post.id).update(set__tags__2="java") post.reload() - self.assertEqual(post.tags, ["mongodb", "python", "java"]) + assert post.tags == ["mongodb", "python", "java"] # test push with singular value BlogPost.objects.filter(id=post.id).update(push__tags__0="scala") post.reload() - self.assertEqual(post.tags, ["scala", "mongodb", "python", "java"]) + assert post.tags == ["scala", "mongodb", "python", "java"] def test_update_push_list_of_list(self): """Ensure that the 'push' update operation works in the list of list @@ -2085,7 +2094,7 @@ class TestQueryset(unittest.TestCase): BlogPost.objects.filter(slug="test").update(push__tags=["value1", 123]) post.reload() - self.assertEqual(post.tags, [["value1", 123]]) + assert post.tags == [["value1", 123]] def test_update_push_and_pull_add_to_set(self): """Ensure that the 'pull' update operation works correctly. @@ -2102,25 +2111,25 @@ class TestQueryset(unittest.TestCase): BlogPost.objects.filter(id=post.id).update(push__tags="code") post.reload() - self.assertEqual(post.tags, ["code"]) + assert post.tags == ["code"] BlogPost.objects.filter(id=post.id).update(push_all__tags=["mongodb", "code"]) post.reload() - self.assertEqual(post.tags, ["code", "mongodb", "code"]) + assert post.tags == ["code", "mongodb", "code"] BlogPost.objects(slug="test").update(pull__tags="code") post.reload() - self.assertEqual(post.tags, ["mongodb"]) + assert post.tags == ["mongodb"] BlogPost.objects(slug="test").update(pull_all__tags=["mongodb", "code"]) post.reload() - self.assertEqual(post.tags, []) + assert post.tags == [] BlogPost.objects(slug="test").update( __raw__={"$addToSet": {"tags": {"$each": ["code", "mongodb", "code"]}}} ) post.reload() - self.assertEqual(post.tags, ["code", "mongodb"]) + assert post.tags == ["code", "mongodb"] def test_add_to_set_each(self): class Item(Document): @@ -2137,7 +2146,7 @@ class TestQueryset(unittest.TestCase): item.update(add_to_set__parents=[parent_1, parent_2, parent_1]) item.reload() - self.assertEqual([parent_1, parent_2], item.parents) + assert [parent_1, parent_2] == item.parents def test_pull_nested(self): class Collaborator(EmbeddedDocument): @@ -2156,9 +2165,9 @@ class TestQueryset(unittest.TestCase): s = Site(name="test", collaborators=[c]).save() Site.objects(id=s.id).update_one(pull__collaborators__user="Esteban") - self.assertEqual(Site.objects.first().collaborators, []) + assert Site.objects.first().collaborators == [] - with self.assertRaises(InvalidQueryError): + with pytest.raises(InvalidQueryError): Site.objects(id=s.id).update_one(pull_all__collaborators__user=["Ross"]) def test_pull_from_nested_embedded(self): @@ -2185,14 +2194,14 @@ class TestQueryset(unittest.TestCase): ).save() Site.objects(id=s.id).update_one(pull__collaborators__helpful=c) - self.assertEqual(Site.objects.first().collaborators["helpful"], []) + assert Site.objects.first().collaborators["helpful"] == [] Site.objects(id=s.id).update_one( pull__collaborators__unhelpful={"name": "Frank"} ) - self.assertEqual(Site.objects.first().collaborators["unhelpful"], []) + assert Site.objects.first().collaborators["unhelpful"] == [] - with self.assertRaises(InvalidQueryError): + with pytest.raises(InvalidQueryError): Site.objects(id=s.id).update_one( pull_all__collaborators__helpful__name=["Ross"] ) @@ -2229,12 +2238,12 @@ class TestQueryset(unittest.TestCase): Site.objects(id=s.id).update_one( pull__collaborators__helpful__name__in=["Esteban"] ) # Pull a - self.assertEqual(Site.objects.first().collaborators["helpful"], [b]) + assert Site.objects.first().collaborators["helpful"] == [b] Site.objects(id=s.id).update_one( pull__collaborators__unhelpful__name__nin=["John"] ) # Pull x - self.assertEqual(Site.objects.first().collaborators["unhelpful"], [y]) + assert Site.objects.first().collaborators["unhelpful"] == [y] def test_pull_from_nested_mapfield(self): class Collaborator(EmbeddedDocument): @@ -2255,14 +2264,14 @@ class TestQueryset(unittest.TestCase): s.save() Site.objects(id=s.id).update_one(pull__collaborators__helpful__user="Esteban") - self.assertEqual(Site.objects.first().collaborators["helpful"], []) + assert Site.objects.first().collaborators["helpful"] == [] Site.objects(id=s.id).update_one( pull__collaborators__unhelpful={"user": "Frank"} ) - self.assertEqual(Site.objects.first().collaborators["unhelpful"], []) + assert Site.objects.first().collaborators["unhelpful"] == [] - with self.assertRaises(InvalidQueryError): + with pytest.raises(InvalidQueryError): Site.objects(id=s.id).update_one( pull_all__collaborators__helpful__user=["Ross"] ) @@ -2280,7 +2289,7 @@ class TestQueryset(unittest.TestCase): bar = Bar(foos=[foo]).save() Bar.objects(id=bar.id).update(pull__foos=foo) bar.reload() - self.assertEqual(len(bar.foos), 0) + assert len(bar.foos) == 0 def test_update_one_check_return_with_full_result(self): class BlogTag(Document): @@ -2290,10 +2299,10 @@ class TestQueryset(unittest.TestCase): BlogTag(name="garbage").save() default_update = BlogTag.objects.update_one(name="new") - self.assertEqual(default_update, 1) + assert default_update == 1 full_result_update = BlogTag.objects.update_one(name="new", full_result=True) - self.assertIsInstance(full_result_update, UpdateResult) + assert isinstance(full_result_update, UpdateResult) def test_update_one_pop_generic_reference(self): class BlogTag(Document): @@ -2316,12 +2325,12 @@ class TestQueryset(unittest.TestCase): post = BlogPost(slug="test-2", tags=[tag_1, tag_2]) post.save() - self.assertEqual(len(post.tags), 2) + assert len(post.tags) == 2 BlogPost.objects(slug="test-2").update_one(pop__tags=-1) post.reload() - self.assertEqual(len(post.tags), 1) + assert len(post.tags) == 1 BlogPost.drop_collection() BlogTag.drop_collection() @@ -2344,15 +2353,15 @@ class TestQueryset(unittest.TestCase): post = BlogPost(slug="test-2", tags=[tag_1, tag_2]) post.save() - self.assertEqual(len(post.tags), 2) + assert len(post.tags) == 2 BlogPost.objects(slug="test-2").update_one(set__tags__0__name="python") post.reload() - self.assertEqual(post.tags[0].name, "python") + assert post.tags[0].name == "python" BlogPost.objects(slug="test-2").update_one(pop__tags=-1) post.reload() - self.assertEqual(len(post.tags), 1) + assert len(post.tags) == 1 BlogPost.drop_collection() @@ -2374,7 +2383,7 @@ class TestQueryset(unittest.TestCase): ) message = message.reload() - self.assertEqual(message.authors[0].name, "Ross") + assert message.authors[0].name == "Ross" Message.objects(authors__name="Ross").update_one( set__authors=[ @@ -2385,9 +2394,9 @@ class TestQueryset(unittest.TestCase): ) message = message.reload() - self.assertEqual(message.authors[0].name, "Harry") - self.assertEqual(message.authors[1].name, "Ross") - self.assertEqual(message.authors[2].name, "Adam") + assert message.authors[0].name == "Harry" + assert message.authors[1].name == "Ross" + assert message.authors[2].name == "Adam" def test_set_generic_embedded_documents(self): class Bar(EmbeddedDocument): @@ -2403,7 +2412,7 @@ class TestQueryset(unittest.TestCase): User.objects(username="abc").update(set__bar=Bar(name="test"), upsert=True) user = User.objects(username="abc").first() - self.assertEqual(user.bar.name, "test") + assert user.bar.name == "test" def test_reload_embedded_docs_instance(self): class SubDoc(EmbeddedDocument): @@ -2415,7 +2424,7 @@ class TestQueryset(unittest.TestCase): doc = Doc(embedded=SubDoc(val=0)).save() doc.reload() - self.assertEqual(doc.pk, doc.embedded._instance.pk) + assert doc.pk == doc.embedded._instance.pk def test_reload_list_embedded_docs_instance(self): class SubDoc(EmbeddedDocument): @@ -2427,7 +2436,7 @@ class TestQueryset(unittest.TestCase): doc = Doc(embedded=[SubDoc(val=0)]).save() doc.reload() - self.assertEqual(doc.pk, doc.embedded[0]._instance.pk) + assert doc.pk == doc.embedded[0]._instance.pk def test_order_by(self): """Ensure that QuerySets may be ordered. @@ -2437,16 +2446,16 @@ class TestQueryset(unittest.TestCase): self.Person(name="User C", age=30).save() names = [p.name for p in self.Person.objects.order_by("-age")] - self.assertEqual(names, ["User B", "User C", "User A"]) + assert names == ["User B", "User C", "User A"] names = [p.name for p in self.Person.objects.order_by("+age")] - self.assertEqual(names, ["User A", "User C", "User B"]) + assert names == ["User A", "User C", "User B"] names = [p.name for p in self.Person.objects.order_by("age")] - self.assertEqual(names, ["User A", "User C", "User B"]) + assert names == ["User A", "User C", "User B"] ages = [p.age for p in self.Person.objects.order_by("-name")] - self.assertEqual(ages, [30, 40, 20]) + assert ages == [30, 40, 20] def test_order_by_optional(self): class BlogPost(Document): @@ -2511,24 +2520,24 @@ class TestQueryset(unittest.TestCase): ages = [p.age for p in only_age] # The .only('age') clause should mean that all names are None - self.assertEqual(names, [None, None, None]) - self.assertEqual(ages, [40, 30, 20]) + assert names == [None, None, None] + assert ages == [40, 30, 20] qs = self.Person.objects.all().order_by("-age") qs = qs.limit(10) ages = [p.age for p in qs] - self.assertEqual(ages, [40, 30, 20]) + assert ages == [40, 30, 20] qs = self.Person.objects.all().limit(10) qs = qs.order_by("-age") ages = [p.age for p in qs] - self.assertEqual(ages, [40, 30, 20]) + assert ages == [40, 30, 20] qs = self.Person.objects.all().skip(0) qs = qs.order_by("-age") ages = [p.age for p in qs] - self.assertEqual(ages, [40, 30, 20]) + assert ages == [40, 30, 20] def test_confirm_order_by_reference_wont_work(self): """Ordering by reference is not possible. Use map / reduce.. or @@ -2551,7 +2560,7 @@ class TestQueryset(unittest.TestCase): Author(author=person_c).save() names = [a.author.name for a in Author.objects.order_by("-author__age")] - self.assertEqual(names, ["User A", "User B", "User C"]) + assert names == ["User A", "User B", "User C"] def test_comment(self): """Make sure adding a comment to the query gets added to the query""" @@ -2573,10 +2582,10 @@ class TestQueryset(unittest.TestCase): ) ops = q.get_ops() - self.assertEqual(len(ops), 2) + assert len(ops) == 2 for op in ops: - self.assertEqual(op[CMD_QUERY_KEY][QUERY_KEY], {"age": {"$gte": 18}}) - self.assertEqual(op[CMD_QUERY_KEY][COMMENT_KEY], "looking for an adult") + assert op[CMD_QUERY_KEY][QUERY_KEY] == {"age": {"$gte": 18}} + assert op[CMD_QUERY_KEY][COMMENT_KEY] == "looking for an adult" def test_map_reduce(self): """Ensure map/reduce is both mapping and reducing. @@ -2613,13 +2622,13 @@ class TestQueryset(unittest.TestCase): # run a map/reduce operation spanning all posts results = BlogPost.objects.map_reduce(map_f, reduce_f, "myresults") results = list(results) - self.assertEqual(len(results), 4) + assert len(results) == 4 music = list(filter(lambda r: r.key == "music", results))[0] - self.assertEqual(music.value, 2) + assert music.value == 2 film = list(filter(lambda r: r.key == "film", results))[0] - self.assertEqual(film.value, 3) + assert film.value == 3 BlogPost.drop_collection() @@ -2640,8 +2649,8 @@ class TestQueryset(unittest.TestCase): post2.save() post3.save() - self.assertEqual(BlogPost._fields["title"].db_field, "_id") - self.assertEqual(BlogPost._meta["id_field"], "title") + assert BlogPost._fields["title"].db_field == "_id" + assert BlogPost._meta["id_field"] == "title" map_f = """ function() { @@ -2663,9 +2672,9 @@ class TestQueryset(unittest.TestCase): results = BlogPost.objects.map_reduce(map_f, reduce_f, "myresults") results = list(results) - self.assertEqual(results[0].object, post1) - self.assertEqual(results[1].object, post2) - self.assertEqual(results[2].object, post3) + assert results[0].object == post1 + assert results[1].object == post2 + assert results[2].object == post3 BlogPost.drop_collection() @@ -2770,50 +2779,41 @@ class TestQueryset(unittest.TestCase): results = list(results) collection = get_db("test2").family_map - self.assertEqual( - collection.find_one({"_id": 1}), - { - "_id": 1, - "value": { - "persons": [ - {"age": 21, "name": u"Wilson Jr"}, - {"age": 45, "name": u"Wilson Father"}, - {"age": 40, "name": u"Eliana Costa"}, - {"age": 17, "name": u"Tayza Mariana"}, - ], - "totalAge": 123, - }, + assert collection.find_one({"_id": 1}) == { + "_id": 1, + "value": { + "persons": [ + {"age": 21, "name": u"Wilson Jr"}, + {"age": 45, "name": u"Wilson Father"}, + {"age": 40, "name": u"Eliana Costa"}, + {"age": 17, "name": u"Tayza Mariana"}, + ], + "totalAge": 123, }, - ) + } - self.assertEqual( - collection.find_one({"_id": 2}), - { - "_id": 2, - "value": { - "persons": [ - {"age": 16, "name": u"Isabella Luanna"}, - {"age": 36, "name": u"Sandra Mara"}, - {"age": 10, "name": u"Igor Gabriel"}, - ], - "totalAge": 62, - }, + assert collection.find_one({"_id": 2}) == { + "_id": 2, + "value": { + "persons": [ + {"age": 16, "name": u"Isabella Luanna"}, + {"age": 36, "name": u"Sandra Mara"}, + {"age": 10, "name": u"Igor Gabriel"}, + ], + "totalAge": 62, }, - ) + } - self.assertEqual( - collection.find_one({"_id": 3}), - { - "_id": 3, - "value": { - "persons": [ - {"age": 30, "name": u"Arthur WA"}, - {"age": 25, "name": u"Paula Leonel"}, - ], - "totalAge": 55, - }, + assert collection.find_one({"_id": 3}) == { + "_id": 3, + "value": { + "persons": [ + {"age": 30, "name": u"Arthur WA"}, + {"age": 25, "name": u"Paula Leonel"}, + ], + "totalAge": 55, }, - ) + } def test_map_reduce_finalize(self): """Ensure that map, reduce, and finalize run and introduce "scope" @@ -2933,10 +2933,10 @@ class TestQueryset(unittest.TestCase): results = list(results) # assert troublesome Buzz article is ranked 1st - self.assertTrue(results[0].object.title.startswith("Google Buzz")) + assert results[0].object.title.startswith("Google Buzz") # assert laser vision is ranked last - self.assertTrue(results[-1].object.title.startswith("How to see")) + assert results[-1].object.title.startswith("How to see") Link.drop_collection() @@ -2956,11 +2956,11 @@ class TestQueryset(unittest.TestCase): def test_assertions(f): f = {key: int(val) for key, val in f.items()} - self.assertEqual(set(["music", "film", "actors", "watch"]), set(f.keys())) - self.assertEqual(f["music"], 3) - self.assertEqual(f["actors"], 2) - self.assertEqual(f["watch"], 2) - self.assertEqual(f["film"], 1) + assert set(["music", "film", "actors", "watch"]) == set(f.keys()) + assert f["music"] == 3 + assert f["actors"] == 2 + assert f["watch"] == 2 + assert f["film"] == 1 exec_js = BlogPost.objects.item_frequencies("tags") map_reduce = BlogPost.objects.item_frequencies("tags", map_reduce=True) @@ -2970,10 +2970,10 @@ class TestQueryset(unittest.TestCase): # Ensure query is taken into account def test_assertions(f): f = {key: int(val) for key, val in f.items()} - self.assertEqual(set(["music", "actors", "watch"]), set(f.keys())) - self.assertEqual(f["music"], 2) - self.assertEqual(f["actors"], 1) - self.assertEqual(f["watch"], 1) + assert set(["music", "actors", "watch"]) == set(f.keys()) + assert f["music"] == 2 + assert f["actors"] == 1 + assert f["watch"] == 1 exec_js = BlogPost.objects(hits__gt=1).item_frequencies("tags") map_reduce = BlogPost.objects(hits__gt=1).item_frequencies( @@ -2984,10 +2984,10 @@ class TestQueryset(unittest.TestCase): # Check that normalization works def test_assertions(f): - self.assertAlmostEqual(f["music"], 3.0 / 8.0) - self.assertAlmostEqual(f["actors"], 2.0 / 8.0) - self.assertAlmostEqual(f["watch"], 2.0 / 8.0) - self.assertAlmostEqual(f["film"], 1.0 / 8.0) + assert round(abs(f["music"] - 3.0 / 8.0), 7) == 0 + assert round(abs(f["actors"] - 2.0 / 8.0), 7) == 0 + assert round(abs(f["watch"] - 2.0 / 8.0), 7) == 0 + assert round(abs(f["film"] - 1.0 / 8.0), 7) == 0 exec_js = BlogPost.objects.item_frequencies("tags", normalize=True) map_reduce = BlogPost.objects.item_frequencies( @@ -2998,9 +2998,9 @@ class TestQueryset(unittest.TestCase): # Check item_frequencies works for non-list fields def test_assertions(f): - self.assertEqual(set([1, 2]), set(f.keys())) - self.assertEqual(f[1], 1) - self.assertEqual(f[2], 2) + assert set([1, 2]) == set(f.keys()) + assert f[1] == 1 + assert f[2] == 2 exec_js = BlogPost.objects.item_frequencies("hits") map_reduce = BlogPost.objects.item_frequencies("hits", map_reduce=True) @@ -3036,9 +3036,9 @@ class TestQueryset(unittest.TestCase): def test_assertions(f): f = {key: int(val) for key, val in f.items()} - self.assertEqual(set(["62-3331-1656", "62-3332-1656"]), set(f.keys())) - self.assertEqual(f["62-3331-1656"], 2) - self.assertEqual(f["62-3332-1656"], 1) + assert set(["62-3331-1656", "62-3332-1656"]) == set(f.keys()) + assert f["62-3331-1656"] == 2 + assert f["62-3332-1656"] == 1 exec_js = Person.objects.item_frequencies("phone.number") map_reduce = Person.objects.item_frequencies("phone.number", map_reduce=True) @@ -3048,8 +3048,8 @@ class TestQueryset(unittest.TestCase): # Ensure query is taken into account def test_assertions(f): f = {key: int(val) for key, val in f.items()} - self.assertEqual(set(["62-3331-1656"]), set(f.keys())) - self.assertEqual(f["62-3331-1656"], 2) + assert set(["62-3331-1656"]) == set(f.keys()) + assert f["62-3331-1656"] == 2 exec_js = Person.objects(phone__number="62-3331-1656").item_frequencies( "phone.number" @@ -3062,8 +3062,8 @@ class TestQueryset(unittest.TestCase): # Check that normalization works def test_assertions(f): - self.assertEqual(f["62-3331-1656"], 2.0 / 3.0) - self.assertEqual(f["62-3332-1656"], 1.0 / 3.0) + assert f["62-3331-1656"] == 2.0 / 3.0 + assert f["62-3332-1656"] == 1.0 / 3.0 exec_js = Person.objects.item_frequencies("phone.number", normalize=True) map_reduce = Person.objects.item_frequencies( @@ -3083,14 +3083,14 @@ class TestQueryset(unittest.TestCase): Person(name="Wilson Jr").save() freq = Person.objects.item_frequencies("city") - self.assertEqual(freq, {"CRB": 1.0, None: 1.0}) + assert freq == {"CRB": 1.0, None: 1.0} freq = Person.objects.item_frequencies("city", normalize=True) - self.assertEqual(freq, {"CRB": 0.5, None: 0.5}) + assert freq == {"CRB": 0.5, None: 0.5} freq = Person.objects.item_frequencies("city", map_reduce=True) - self.assertEqual(freq, {"CRB": 1.0, None: 1.0}) + assert freq == {"CRB": 1.0, None: 1.0} freq = Person.objects.item_frequencies("city", normalize=True, map_reduce=True) - self.assertEqual(freq, {"CRB": 0.5, None: 0.5}) + assert freq == {"CRB": 0.5, None: 0.5} def test_item_frequencies_with_null_embedded(self): class Data(EmbeddedDocument): @@ -3115,10 +3115,10 @@ class TestQueryset(unittest.TestCase): p.save() ot = Person.objects.item_frequencies("extra.tag", map_reduce=False) - self.assertEqual(ot, {None: 1.0, u"friend": 1.0}) + assert ot == {None: 1.0, u"friend": 1.0} ot = Person.objects.item_frequencies("extra.tag", map_reduce=True) - self.assertEqual(ot, {None: 1.0, u"friend": 1.0}) + assert ot == {None: 1.0, u"friend": 1.0} def test_item_frequencies_with_0_values(self): class Test(Document): @@ -3130,9 +3130,9 @@ class TestQueryset(unittest.TestCase): t.save() ot = Test.objects.item_frequencies("val", map_reduce=True) - self.assertEqual(ot, {0: 1}) + assert ot == {0: 1} ot = Test.objects.item_frequencies("val", map_reduce=False) - self.assertEqual(ot, {0: 1}) + assert ot == {0: 1} def test_item_frequencies_with_False_values(self): class Test(Document): @@ -3144,9 +3144,9 @@ class TestQueryset(unittest.TestCase): t.save() ot = Test.objects.item_frequencies("val", map_reduce=True) - self.assertEqual(ot, {False: 1}) + assert ot == {False: 1} ot = Test.objects.item_frequencies("val", map_reduce=False) - self.assertEqual(ot, {False: 1}) + assert ot == {False: 1} def test_item_frequencies_normalize(self): class Test(Document): @@ -3161,31 +3161,32 @@ class TestQueryset(unittest.TestCase): Test(val=2).save() freqs = Test.objects.item_frequencies("val", map_reduce=False, normalize=True) - self.assertEqual(freqs, {1: 50.0 / 70, 2: 20.0 / 70}) + assert freqs == {1: 50.0 / 70, 2: 20.0 / 70} freqs = Test.objects.item_frequencies("val", map_reduce=True, normalize=True) - self.assertEqual(freqs, {1: 50.0 / 70, 2: 20.0 / 70}) + assert freqs == {1: 50.0 / 70, 2: 20.0 / 70} def test_average(self): """Ensure that field can be averaged correctly. """ self.Person(name="person", age=0).save() - self.assertEqual(int(self.Person.objects.average("age")), 0) + assert int(self.Person.objects.average("age")) == 0 ages = [23, 54, 12, 94, 27] for i, age in enumerate(ages): self.Person(name="test%s" % i, age=age).save() avg = float(sum(ages)) / (len(ages) + 1) # take into account the 0 - self.assertAlmostEqual(int(self.Person.objects.average("age")), avg) + assert round(abs(int(self.Person.objects.average("age")) - avg), 7) == 0 self.Person(name="ageless person").save() - self.assertEqual(int(self.Person.objects.average("age")), avg) + assert int(self.Person.objects.average("age")) == avg # dot notation self.Person(name="person meta", person_meta=self.PersonMeta(weight=0)).save() - self.assertAlmostEqual( - int(self.Person.objects.average("person_meta.weight")), 0 + assert ( + round(abs(int(self.Person.objects.average("person_meta.weight")) - 0), 7) + == 0 ) for i, weight in enumerate(ages): @@ -3193,17 +3194,18 @@ class TestQueryset(unittest.TestCase): name="test meta%i", person_meta=self.PersonMeta(weight=weight) ).save() - self.assertAlmostEqual( - int(self.Person.objects.average("person_meta.weight")), avg + assert ( + round(abs(int(self.Person.objects.average("person_meta.weight")) - avg), 7) + == 0 ) self.Person(name="test meta none").save() - self.assertEqual(int(self.Person.objects.average("person_meta.weight")), avg) + assert int(self.Person.objects.average("person_meta.weight")) == avg # test summing over a filtered queryset over_50 = [a for a in ages if a >= 50] avg = float(sum(over_50)) / len(over_50) - self.assertEqual(self.Person.objects.filter(age__gte=50).average("age"), avg) + assert self.Person.objects.filter(age__gte=50).average("age") == avg def test_sum(self): """Ensure that field can be summed over correctly. @@ -3212,25 +3214,24 @@ class TestQueryset(unittest.TestCase): for i, age in enumerate(ages): self.Person(name="test%s" % i, age=age).save() - self.assertEqual(self.Person.objects.sum("age"), sum(ages)) + assert self.Person.objects.sum("age") == sum(ages) self.Person(name="ageless person").save() - self.assertEqual(self.Person.objects.sum("age"), sum(ages)) + assert self.Person.objects.sum("age") == sum(ages) for i, age in enumerate(ages): self.Person( name="test meta%s" % i, person_meta=self.PersonMeta(weight=age) ).save() - self.assertEqual(self.Person.objects.sum("person_meta.weight"), sum(ages)) + assert self.Person.objects.sum("person_meta.weight") == sum(ages) self.Person(name="weightless person").save() - self.assertEqual(self.Person.objects.sum("age"), sum(ages)) + assert self.Person.objects.sum("age") == sum(ages) # test summing over a filtered queryset - self.assertEqual( - self.Person.objects.filter(age__gte=50).sum("age"), - sum([a for a in ages if a >= 50]), + assert self.Person.objects.filter(age__gte=50).sum("age") == sum( + [a for a in ages if a >= 50] ) def test_sum_over_db_field(self): @@ -3246,7 +3247,7 @@ class TestQueryset(unittest.TestCase): UserVisit.objects.create(num_visits=10) UserVisit.objects.create(num_visits=5) - self.assertEqual(UserVisit.objects.sum("num_visits"), 15) + assert UserVisit.objects.sum("num_visits") == 15 def test_average_over_db_field(self): """Ensure that a field mapped to a db field with a different name @@ -3261,7 +3262,7 @@ class TestQueryset(unittest.TestCase): UserVisit.objects.create(num_visits=20) UserVisit.objects.create(num_visits=10) - self.assertEqual(UserVisit.objects.average("num_visits"), 15) + assert UserVisit.objects.average("num_visits") == 15 def test_embedded_average(self): class Pay(EmbeddedDocument): @@ -3278,7 +3279,7 @@ class TestQueryset(unittest.TestCase): Doc(name="Tayza mariana", pay=Pay(value=165)).save() Doc(name="Eliana Costa", pay=Pay(value=115)).save() - self.assertEqual(Doc.objects.average("pay.value"), 240) + assert Doc.objects.average("pay.value") == 240 def test_embedded_array_average(self): class Pay(EmbeddedDocument): @@ -3295,7 +3296,7 @@ class TestQueryset(unittest.TestCase): Doc(name="Tayza mariana", pay=Pay(values=[165, 100])).save() Doc(name="Eliana Costa", pay=Pay(values=[115, 100])).save() - self.assertEqual(Doc.objects.average("pay.values"), 170) + assert Doc.objects.average("pay.values") == 170 def test_array_average(self): class Doc(Document): @@ -3308,7 +3309,7 @@ class TestQueryset(unittest.TestCase): Doc(values=[165, 100]).save() Doc(values=[115, 100]).save() - self.assertEqual(Doc.objects.average("values"), 170) + assert Doc.objects.average("values") == 170 def test_embedded_sum(self): class Pay(EmbeddedDocument): @@ -3325,7 +3326,7 @@ class TestQueryset(unittest.TestCase): Doc(name="Tayza mariana", pay=Pay(value=165)).save() Doc(name="Eliana Costa", pay=Pay(value=115)).save() - self.assertEqual(Doc.objects.sum("pay.value"), 960) + assert Doc.objects.sum("pay.value") == 960 def test_embedded_array_sum(self): class Pay(EmbeddedDocument): @@ -3342,7 +3343,7 @@ class TestQueryset(unittest.TestCase): Doc(name="Tayza mariana", pay=Pay(values=[165, 100])).save() Doc(name="Eliana Costa", pay=Pay(values=[115, 100])).save() - self.assertEqual(Doc.objects.sum("pay.values"), 1360) + assert Doc.objects.sum("pay.values") == 1360 def test_array_sum(self): class Doc(Document): @@ -3355,7 +3356,7 @@ class TestQueryset(unittest.TestCase): Doc(values=[165, 100]).save() Doc(values=[115, 100]).save() - self.assertEqual(Doc.objects.sum("values"), 1360) + assert Doc.objects.sum("values") == 1360 def test_distinct(self): """Ensure that the QuerySet.distinct method works. @@ -3364,14 +3365,12 @@ class TestQueryset(unittest.TestCase): self.Person(name="Mr White", age=20).save() self.Person(name="Mr Orange", age=30).save() self.Person(name="Mr Pink", age=30).save() - self.assertEqual( - set(self.Person.objects.distinct("name")), - set(["Mr Orange", "Mr White", "Mr Pink"]), + assert set(self.Person.objects.distinct("name")) == set( + ["Mr Orange", "Mr White", "Mr Pink"] ) - self.assertEqual(set(self.Person.objects.distinct("age")), set([20, 30])) - self.assertEqual( - set(self.Person.objects(age=30).distinct("name")), - set(["Mr Orange", "Mr Pink"]), + assert set(self.Person.objects.distinct("age")) == set([20, 30]) + assert set(self.Person.objects(age=30).distinct("name")) == set( + ["Mr Orange", "Mr Pink"] ) def test_distinct_handles_references(self): @@ -3390,7 +3389,7 @@ class TestQueryset(unittest.TestCase): foo = Foo(bar=bar) foo.save() - self.assertEqual(Foo.objects.distinct("bar"), [bar]) + assert Foo.objects.distinct("bar") == [bar] def test_text_indexes(self): class News(Document): @@ -3410,8 +3409,8 @@ class TestQueryset(unittest.TestCase): News.drop_collection() info = News.objects._collection.index_information() - self.assertIn("title_text_content_text", info) - self.assertIn("textIndexVersion", info["title_text_content_text"]) + assert "title_text_content_text" in info + assert "textIndexVersion" in info["title_text_content_text"] News( title="Neymar quebrou a vertebra", @@ -3426,11 +3425,11 @@ class TestQueryset(unittest.TestCase): count = News.objects.search_text("neymar", language="portuguese").count() - self.assertEqual(count, 1) + assert count == 1 count = News.objects.search_text("brasil -neymar").count() - self.assertEqual(count, 1) + assert count == 1 News( title=u"As eleições no Brasil já estão em planejamento", @@ -3442,41 +3441,41 @@ class TestQueryset(unittest.TestCase): query = News.objects(is_active=False).search_text("dilma", language="pt")._query - self.assertEqual( - query, - {"$text": {"$search": "dilma", "$language": "pt"}, "is_active": False}, - ) + assert query == { + "$text": {"$search": "dilma", "$language": "pt"}, + "is_active": False, + } - self.assertFalse(new.is_active) - self.assertIn("dilma", new.content) - self.assertIn("planejamento", new.title) + assert not new.is_active + assert "dilma" in new.content + assert "planejamento" in new.title query = News.objects.search_text("candidata") - self.assertEqual(query._search_text, "candidata") + assert query._search_text == "candidata" new = query.first() - self.assertIsInstance(new.get_text_score(), float) + assert isinstance(new.get_text_score(), float) # count query = News.objects.search_text("brasil").order_by("$text_score") - self.assertEqual(query._search_text, "brasil") + assert query._search_text == "brasil" - self.assertEqual(query.count(), 3) - self.assertEqual(query._query, {"$text": {"$search": "brasil"}}) + assert query.count() == 3 + assert query._query == {"$text": {"$search": "brasil"}} cursor_args = query._cursor_args cursor_args_fields = cursor_args["projection"] - self.assertEqual(cursor_args_fields, {"_text_score": {"$meta": "textScore"}}) + assert cursor_args_fields == {"_text_score": {"$meta": "textScore"}} text_scores = [i.get_text_score() for i in query] - self.assertEqual(len(text_scores), 3) + assert len(text_scores) == 3 - self.assertTrue(text_scores[0] > text_scores[1]) - self.assertTrue(text_scores[1] > text_scores[2]) + assert text_scores[0] > text_scores[1] + assert text_scores[1] > text_scores[2] max_text_score = text_scores[0] # get item item = News.objects.search_text("brasil").order_by("$text_score").first() - self.assertEqual(item.get_text_score(), max_text_score) + assert item.get_text_score() == max_text_score def test_distinct_handles_references_to_alias(self): register_connection("testdb", "mongoenginetest2") @@ -3498,7 +3497,7 @@ class TestQueryset(unittest.TestCase): foo = Foo(bar=bar) foo.save() - self.assertEqual(Foo.objects.distinct("bar"), [bar]) + assert Foo.objects.distinct("bar") == [bar] def test_distinct_handles_db_field(self): """Ensure that distinct resolves field name to db_field as expected. @@ -3513,8 +3512,8 @@ class TestQueryset(unittest.TestCase): Product(product_id=2).save() Product(product_id=1).save() - self.assertEqual(set(Product.objects.distinct("product_id")), set([1, 2])) - self.assertEqual(set(Product.objects.distinct("pid")), set([1, 2])) + assert set(Product.objects.distinct("product_id")) == set([1, 2]) + assert set(Product.objects.distinct("pid")) == set([1, 2]) Product.drop_collection() @@ -3536,7 +3535,7 @@ class TestQueryset(unittest.TestCase): Book.objects.create(title="The Stories", authors=[mark_twain, john_tolkien]) authors = Book.objects.distinct("authors") - self.assertEqual(authors, [mark_twain, john_tolkien]) + assert authors == [mark_twain, john_tolkien] def test_distinct_ListField_EmbeddedDocumentField_EmbeddedDocumentField(self): class Continent(EmbeddedDocument): @@ -3570,10 +3569,10 @@ class TestQueryset(unittest.TestCase): Book.objects.create(title="The Stories", authors=[mark_twain, john_tolkien]) country_list = Book.objects.distinct("authors.country") - self.assertEqual(country_list, [scotland, tibet]) + assert country_list == [scotland, tibet] continent_list = Book.objects.distinct("authors.country.continent") - self.assertEqual(continent_list, [europe, asia]) + assert continent_list == [europe, asia] def test_distinct_ListField_ReferenceField(self): class Bar(Document): @@ -3595,7 +3594,7 @@ class TestQueryset(unittest.TestCase): foo = Foo(bar=bar_1, bar_lst=[bar_1, bar_2]) foo.save() - self.assertEqual(Foo.objects.distinct("bar_lst"), [bar_1, bar_2]) + assert Foo.objects.distinct("bar_lst") == [bar_1, bar_2] def test_custom_manager(self): """Ensure that custom QuerySetManager instances work as expected. @@ -3627,15 +3626,15 @@ class TestQueryset(unittest.TestCase): post3 = BlogPost(tags=["film", "actors"]).save() post4 = BlogPost(tags=["film", "actors", "music"], deleted=True).save() - self.assertEqual( - [p.id for p in BlogPost.objects()], [post1.id, post2.id, post3.id] - ) - self.assertEqual( - [p.id for p in BlogPost.objects_1_arg()], [post1.id, post2.id, post3.id] - ) - self.assertEqual([p.id for p in BlogPost.music_posts()], [post1.id, post2.id]) + assert [p.id for p in BlogPost.objects()] == [post1.id, post2.id, post3.id] + assert [p.id for p in BlogPost.objects_1_arg()] == [ + post1.id, + post2.id, + post3.id, + ] + assert [p.id for p in BlogPost.music_posts()] == [post1.id, post2.id] - self.assertEqual([p.id for p in BlogPost.music_posts(True)], [post4.id]) + assert [p.id for p in BlogPost.music_posts(True)] == [post4.id] BlogPost.drop_collection() @@ -3657,12 +3656,12 @@ class TestQueryset(unittest.TestCase): Foo(active=True).save() Foo(active=False).save() - self.assertEqual(1, Foo.objects.count()) - self.assertEqual(1, Foo.with_inactive.count()) + assert 1 == Foo.objects.count() + assert 1 == Foo.with_inactive.count() Foo.with_inactive.first().delete() - self.assertEqual(0, Foo.with_inactive.count()) - self.assertEqual(1, Foo.objects.count()) + assert 0 == Foo.with_inactive.count() + assert 1 == Foo.objects.count() def test_inherit_objects(self): class Foo(Document): @@ -3678,7 +3677,7 @@ class TestQueryset(unittest.TestCase): Bar.drop_collection() Bar.objects.create(active=False) - self.assertEqual(0, Bar.objects.count()) + assert 0 == Bar.objects.count() def test_inherit_objects_override(self): class Foo(Document): @@ -3696,8 +3695,8 @@ class TestQueryset(unittest.TestCase): Bar.drop_collection() Bar.objects.create(active=False) - self.assertEqual(0, Foo.objects.count()) - self.assertEqual(1, Bar.objects.count()) + assert 0 == Foo.objects.count() + assert 1 == Bar.objects.count() def test_query_value_conversion(self): """Ensure that query values are properly converted when necessary. @@ -3718,11 +3717,11 @@ class TestQueryset(unittest.TestCase): # while using a ReferenceField's name - the document should be # converted to an DBRef, which is legal, unlike a Document object post_obj = BlogPost.objects(author=person).first() - self.assertEqual(post.id, post_obj.id) + assert post.id == post_obj.id # Test that lists of values work when using the 'in', 'nin' and 'all' post_obj = BlogPost.objects(author__in=[person]).first() - self.assertEqual(post.id, post_obj.id) + assert post.id == post_obj.id BlogPost.drop_collection() @@ -3746,9 +3745,9 @@ class TestQueryset(unittest.TestCase): Group.objects(id=group.id).update(set__members=[user1, user2]) group.reload() - self.assertEqual(len(group.members), 2) - self.assertEqual(group.members[0].name, user1.name) - self.assertEqual(group.members[1].name, user2.name) + assert len(group.members) == 2 + assert group.members[0].name == user1.name + assert group.members[1].name == user2.name Group.drop_collection() @@ -3776,15 +3775,15 @@ class TestQueryset(unittest.TestCase): ids = [post_1.id, post_2.id, post_5.id] objects = BlogPost.objects.in_bulk(ids) - self.assertEqual(len(objects), 3) + assert len(objects) == 3 - self.assertIn(post_1.id, objects) - self.assertIn(post_2.id, objects) - self.assertIn(post_5.id, objects) + assert post_1.id in objects + assert post_2.id in objects + assert post_5.id in objects - self.assertEqual(objects[post_1.id].title, post_1.title) - self.assertEqual(objects[post_2.id].title, post_2.title) - self.assertEqual(objects[post_5.id].title, post_5.title) + assert objects[post_1.id].title == post_1.title + assert objects[post_2.id].title == post_2.title + assert objects[post_5.id].title == post_5.title BlogPost.drop_collection() @@ -3804,11 +3803,11 @@ class TestQueryset(unittest.TestCase): Post.drop_collection() - self.assertIsInstance(Post.objects, CustomQuerySet) - self.assertFalse(Post.objects.not_empty()) + assert isinstance(Post.objects, CustomQuerySet) + assert not Post.objects.not_empty() Post().save() - self.assertTrue(Post.objects.not_empty()) + assert Post.objects.not_empty() Post.drop_collection() @@ -3828,11 +3827,11 @@ class TestQueryset(unittest.TestCase): Post.drop_collection() - self.assertIsInstance(Post.objects, CustomQuerySet) - self.assertFalse(Post.objects.not_empty()) + assert isinstance(Post.objects, CustomQuerySet) + assert not Post.objects.not_empty() Post().save() - self.assertTrue(Post.objects.not_empty()) + assert Post.objects.not_empty() Post.drop_collection() @@ -3853,8 +3852,8 @@ class TestQueryset(unittest.TestCase): Post().save() Post(is_published=True).save() - self.assertEqual(Post.objects.count(), 2) - self.assertEqual(Post.published.count(), 1) + assert Post.objects.count() == 2 + assert Post.published.count() == 1 Post.drop_collection() @@ -3873,11 +3872,11 @@ class TestQueryset(unittest.TestCase): pass Post.drop_collection() - self.assertIsInstance(Post.objects, CustomQuerySet) - self.assertFalse(Post.objects.not_empty()) + assert isinstance(Post.objects, CustomQuerySet) + assert not Post.objects.not_empty() Post().save() - self.assertTrue(Post.objects.not_empty()) + assert Post.objects.not_empty() Post.drop_collection() @@ -3900,11 +3899,11 @@ class TestQueryset(unittest.TestCase): pass Post.drop_collection() - self.assertIsInstance(Post.objects, CustomQuerySet) - self.assertFalse(Post.objects.not_empty()) + assert isinstance(Post.objects, CustomQuerySet) + assert not Post.objects.not_empty() Post().save() - self.assertTrue(Post.objects.not_empty()) + assert Post.objects.not_empty() Post.drop_collection() @@ -3917,13 +3916,9 @@ class TestQueryset(unittest.TestCase): for i in range(10): Post(title="Post %s" % i).save() - self.assertEqual( - 5, Post.objects.limit(5).skip(5).count(with_limit_and_skip=True) - ) + assert 5 == Post.objects.limit(5).skip(5).count(with_limit_and_skip=True) - self.assertEqual( - 10, Post.objects.limit(5).skip(5).count(with_limit_and_skip=False) - ) + assert 10 == Post.objects.limit(5).skip(5).count(with_limit_and_skip=False) def test_count_and_none(self): """Test count works with None()""" @@ -3935,8 +3930,8 @@ class TestQueryset(unittest.TestCase): for i in range(0, 10): MyDoc().save() - self.assertEqual(MyDoc.objects.count(), 10) - self.assertEqual(MyDoc.objects.none().count(), 0) + assert MyDoc.objects.count() == 10 + assert MyDoc.objects.none().count() == 0 def test_count_list_embedded(self): class B(EmbeddedDocument): @@ -3945,7 +3940,7 @@ class TestQueryset(unittest.TestCase): class A(Document): b = ListField(EmbeddedDocumentField(B)) - self.assertEqual(A.objects(b=[{"c": "c"}]).count(), 0) + assert A.objects(b=[{"c": "c"}]).count() == 0 def test_call_after_limits_set(self): """Ensure that re-filtering after slicing works @@ -3960,7 +3955,7 @@ class TestQueryset(unittest.TestCase): Post(title="Post 2").save() posts = Post.objects.all()[0:1] - self.assertEqual(len(list(posts())), 1) + assert len(list(posts())) == 1 Post.drop_collection() @@ -3976,9 +3971,9 @@ class TestQueryset(unittest.TestCase): n2 = Number.objects.create(n=2) n1 = Number.objects.create(n=1) - self.assertEqual(list(Number.objects), [n2, n1]) - self.assertEqual(list(Number.objects.order_by("n")), [n1, n2]) - self.assertEqual(list(Number.objects.order_by("n").filter()), [n1, n2]) + assert list(Number.objects) == [n2, n1] + assert list(Number.objects.order_by("n")) == [n1, n2] + assert list(Number.objects.order_by("n").filter()) == [n1, n2] Number.drop_collection() @@ -3997,18 +3992,18 @@ class TestQueryset(unittest.TestCase): test = Number.objects test2 = test.clone() - self.assertNotEqual(test, test2) - self.assertEqual(test.count(), test2.count()) + assert test != test2 + assert test.count() == test2.count() test = test.filter(n__gt=11) test2 = test.clone() - self.assertNotEqual(test, test2) - self.assertEqual(test.count(), test2.count()) + assert test != test2 + assert test.count() == test2.count() test = test.limit(10) test2 = test.clone() - self.assertNotEqual(test, test2) - self.assertEqual(test.count(), test2.count()) + assert test != test2 + assert test.count() == test2.count() Number.drop_collection() @@ -4028,7 +4023,7 @@ class TestQueryset(unittest.TestCase): t.switch_db("test2") t.save() - self.assertEqual(len(Number2.objects.using("test2")), 9) + assert len(Number2.objects.using("test2")) == 9 def test_unset_reference(self): class Comment(Document): @@ -4043,10 +4038,10 @@ class TestQueryset(unittest.TestCase): comment = Comment.objects.create(text="test") post = Post.objects.create(comment=comment) - self.assertEqual(post.comment, comment) + assert post.comment == comment Post.objects.update(unset__comment=1) post.reload() - self.assertEqual(post.comment, None) + assert post.comment == None Comment.drop_collection() Post.drop_collection() @@ -4060,8 +4055,8 @@ class TestQueryset(unittest.TestCase): n2 = Number.objects.create(n=2) n1 = Number.objects.create(n=1) - self.assertEqual(list(Number.objects), [n2, n1]) - self.assertEqual(list(Number.objects.order_by("n")), [n1, n2]) + assert list(Number.objects) == [n2, n1] + assert list(Number.objects.order_by("n")) == [n1, n2] Number.drop_collection() @@ -4079,10 +4074,10 @@ class TestQueryset(unittest.TestCase): Number(n=3).save() numbers = [n.n for n in Number.objects.order_by("-n")] - self.assertEqual([3, 2, 1], numbers) + assert [3, 2, 1] == numbers numbers = [n.n for n in Number.objects.order_by("+n")] - self.assertEqual([1, 2, 3], numbers) + assert [1, 2, 3] == numbers Number.drop_collection() def test_ensure_index(self): @@ -4100,7 +4095,7 @@ class TestQueryset(unittest.TestCase): (value["key"], value.get("unique", False), value.get("sparse", False)) for key, value in iteritems(info) ] - self.assertIn(([("_cls", 1), ("message", 1)], False, False), info) + assert ([("_cls", 1), ("message", 1)], False, False) in info def test_where(self): """Ensure that where clauses work. @@ -4120,30 +4115,30 @@ class TestQueryset(unittest.TestCase): c.save() query = IntPair.objects.where("this[~fielda] >= this[~fieldb]") - self.assertEqual('this["fielda"] >= this["fieldb"]', query._where_clause) + assert 'this["fielda"] >= this["fieldb"]' == query._where_clause results = list(query) - self.assertEqual(2, len(results)) - self.assertIn(a, results) - self.assertIn(c, results) + assert 2 == len(results) + assert a in results + assert c in results query = IntPair.objects.where("this[~fielda] == this[~fieldb]") results = list(query) - self.assertEqual(1, len(results)) - self.assertIn(a, results) + assert 1 == len(results) + assert a in results query = IntPair.objects.where( "function() { return this[~fielda] >= this[~fieldb] }" ) - self.assertEqual( - 'function() { return this["fielda"] >= this["fieldb"] }', - query._where_clause, + assert ( + 'function() { return this["fielda"] >= this["fieldb"] }' + == query._where_clause ) results = list(query) - self.assertEqual(2, len(results)) - self.assertIn(a, results) - self.assertIn(c, results) + assert 2 == len(results) + assert a in results + assert c in results - with self.assertRaises(TypeError): + with pytest.raises(TypeError): list(IntPair.objects.where(fielda__gte=3)) def test_scalar(self): @@ -4165,13 +4160,13 @@ class TestQueryset(unittest.TestCase): # set of users (Pretend this has additional filtering.) user_orgs = set(User.objects.scalar("organization")) orgs = Organization.objects(id__in=user_orgs).scalar("name") - self.assertEqual(list(orgs), ["White House"]) + assert list(orgs) == ["White House"] # Efficient for generating listings, too. orgs = Organization.objects.scalar("name").in_bulk(list(user_orgs)) user_map = User.objects.scalar("name", "organization") user_listing = [(user, orgs[org]) for user, org in user_map] - self.assertEqual([("Bob Dole", "White House")], user_listing) + assert [("Bob Dole", "White House")] == user_listing def test_scalar_simple(self): class TestDoc(Document): @@ -4186,10 +4181,10 @@ class TestQueryset(unittest.TestCase): plist = list(TestDoc.objects.scalar("x", "y")) - self.assertEqual(len(plist), 3) - self.assertEqual(plist[0], (10, True)) - self.assertEqual(plist[1], (20, False)) - self.assertEqual(plist[2], (30, True)) + assert len(plist) == 3 + assert plist[0] == (10, True) + assert plist[1] == (20, False) + assert plist[2] == (30, True) class UserDoc(Document): name = StringField() @@ -4204,14 +4199,16 @@ class TestQueryset(unittest.TestCase): ulist = list(UserDoc.objects.scalar("name", "age")) - self.assertEqual( - ulist, - [(u"Wilson Jr", 19), (u"Wilson", 43), (u"Eliana", 37), (u"Tayza", 15)], - ) + assert ulist == [ + (u"Wilson Jr", 19), + (u"Wilson", 43), + (u"Eliana", 37), + (u"Tayza", 15), + ] ulist = list(UserDoc.objects.scalar("name").order_by("age")) - self.assertEqual(ulist, [(u"Tayza"), (u"Wilson Jr"), (u"Eliana"), (u"Wilson")]) + assert ulist == [(u"Tayza"), (u"Wilson Jr"), (u"Eliana"), (u"Wilson")] def test_scalar_embedded(self): class Profile(EmbeddedDocument): @@ -4248,25 +4245,21 @@ class TestQueryset(unittest.TestCase): locale=Locale(city="Brasilia", country="Brazil"), ).save() - self.assertEqual( - list(Person.objects.order_by("profile__age").scalar("profile__name")), - [u"Wilson Jr", u"Gabriel Falcao", u"Lincoln de souza", u"Walter cruz"], - ) + assert list( + Person.objects.order_by("profile__age").scalar("profile__name") + ) == [u"Wilson Jr", u"Gabriel Falcao", u"Lincoln de souza", u"Walter cruz"] ulist = list( Person.objects.order_by("locale.city").scalar( "profile__name", "profile__age", "locale__city" ) ) - self.assertEqual( - ulist, - [ - (u"Lincoln de souza", 28, u"Belo Horizonte"), - (u"Walter cruz", 30, u"Brasilia"), - (u"Wilson Jr", 19, u"Corumba-GO"), - (u"Gabriel Falcao", 23, u"New York"), - ], - ) + assert ulist == [ + (u"Lincoln de souza", 28, u"Belo Horizonte"), + (u"Walter cruz", 30, u"Brasilia"), + (u"Wilson Jr", 19, u"Corumba-GO"), + (u"Gabriel Falcao", 23, u"New York"), + ] def test_scalar_decimal(self): from decimal import Decimal @@ -4279,7 +4272,7 @@ class TestQueryset(unittest.TestCase): Person(name="Wilson Jr", rating=Decimal("1.0")).save() ulist = list(Person.objects.scalar("name", "rating")) - self.assertEqual(ulist, [(u"Wilson Jr", Decimal("1.0"))]) + assert ulist == [(u"Wilson Jr", Decimal("1.0"))] def test_scalar_reference_field(self): class State(Document): @@ -4298,7 +4291,7 @@ class TestQueryset(unittest.TestCase): Person(name="Wilson JR", state=s1).save() plist = list(Person.objects.scalar("name", "state")) - self.assertEqual(plist, [(u"Wilson JR", s1)]) + assert plist == [(u"Wilson JR", s1)] def test_scalar_generic_reference_field(self): class State(Document): @@ -4317,7 +4310,7 @@ class TestQueryset(unittest.TestCase): Person(name="Wilson JR", state=s1).save() plist = list(Person.objects.scalar("name", "state")) - self.assertEqual(plist, [(u"Wilson JR", s1)]) + assert plist == [(u"Wilson JR", s1)] def test_generic_reference_field_with_only_and_as_pymongo(self): class TestPerson(Document): @@ -4342,18 +4335,18 @@ class TestQueryset(unittest.TestCase): .no_dereference() .first() ) - self.assertEqual(activity[0], a1.pk) - self.assertEqual(activity[1]["_ref"], DBRef("test_person", person.pk)) + assert activity[0] == a1.pk + assert activity[1]["_ref"] == DBRef("test_person", person.pk) activity = TestActivity.objects(owner=person).only("id", "owner")[0] - self.assertEqual(activity.pk, a1.pk) - self.assertEqual(activity.owner, person) + assert activity.pk == a1.pk + assert activity.owner == person activity = ( TestActivity.objects(owner=person).only("id", "owner").as_pymongo().first() ) - self.assertEqual(activity["_id"], a1.pk) - self.assertTrue(activity["owner"]["_ref"], DBRef("test_person", person.pk)) + assert activity["_id"] == a1.pk + assert activity["owner"]["_ref"], DBRef("test_person", person.pk) def test_scalar_db_field(self): class TestDoc(Document): @@ -4367,10 +4360,10 @@ class TestQueryset(unittest.TestCase): TestDoc(x=30, y=True).save() plist = list(TestDoc.objects.scalar("x", "y")) - self.assertEqual(len(plist), 3) - self.assertEqual(plist[0], (10, True)) - self.assertEqual(plist[1], (20, False)) - self.assertEqual(plist[2], (30, True)) + assert len(plist) == 3 + assert plist[0] == (10, True) + assert plist[1] == (20, False) + assert plist[2] == (30, True) def test_scalar_primary_key(self): class SettingValue(Document): @@ -4382,7 +4375,7 @@ class TestQueryset(unittest.TestCase): s.save() val = SettingValue.objects.scalar("key", "value") - self.assertEqual(list(val), [("test", "test value")]) + assert list(val) == [("test", "test value")] def test_scalar_cursor_behaviour(self): """Ensure that a query returns a valid set of results. @@ -4394,90 +4387,86 @@ class TestQueryset(unittest.TestCase): # Find all people in the collection people = self.Person.objects.scalar("name") - self.assertEqual(people.count(), 2) + assert people.count() == 2 results = list(people) - self.assertEqual(results[0], "User A") - self.assertEqual(results[1], "User B") + assert results[0] == "User A" + assert results[1] == "User B" # Use a query to filter the people found to just person1 people = self.Person.objects(age=20).scalar("name") - self.assertEqual(people.count(), 1) + assert people.count() == 1 person = people.next() - self.assertEqual(person, "User A") + assert person == "User A" # Test limit people = list(self.Person.objects.limit(1).scalar("name")) - self.assertEqual(len(people), 1) - self.assertEqual(people[0], "User A") + assert len(people) == 1 + assert people[0] == "User A" # Test skip people = list(self.Person.objects.skip(1).scalar("name")) - self.assertEqual(len(people), 1) - self.assertEqual(people[0], "User B") + assert len(people) == 1 + assert people[0] == "User B" person3 = self.Person(name="User C", age=40) person3.save() # Test slice limit people = list(self.Person.objects[:2].scalar("name")) - self.assertEqual(len(people), 2) - self.assertEqual(people[0], "User A") - self.assertEqual(people[1], "User B") + assert len(people) == 2 + assert people[0] == "User A" + assert people[1] == "User B" # Test slice skip people = list(self.Person.objects[1:].scalar("name")) - self.assertEqual(len(people), 2) - self.assertEqual(people[0], "User B") - self.assertEqual(people[1], "User C") + assert len(people) == 2 + assert people[0] == "User B" + assert people[1] == "User C" # Test slice limit and skip people = list(self.Person.objects[1:2].scalar("name")) - self.assertEqual(len(people), 1) - self.assertEqual(people[0], "User B") + assert len(people) == 1 + assert people[0] == "User B" people = list(self.Person.objects[1:1].scalar("name")) - self.assertEqual(len(people), 0) + assert len(people) == 0 # Test slice out of range people = list(self.Person.objects.scalar("name")[80000:80001]) - self.assertEqual(len(people), 0) + assert len(people) == 0 # Test larger slice __repr__ self.Person.objects.delete() for i in range(55): self.Person(name="A%s" % i, age=i).save() - self.assertEqual(self.Person.objects.scalar("name").count(), 55) - self.assertEqual( - "A0", "%s" % self.Person.objects.order_by("name").scalar("name").first() - ) - self.assertEqual( - "A0", "%s" % self.Person.objects.scalar("name").order_by("name")[0] + assert self.Person.objects.scalar("name").count() == 55 + assert ( + "A0" == "%s" % self.Person.objects.order_by("name").scalar("name").first() ) + assert "A0" == "%s" % self.Person.objects.scalar("name").order_by("name")[0] if six.PY3: - self.assertEqual( - "['A1', 'A2']", - "%s" % self.Person.objects.order_by("age").scalar("name")[1:3], + assert ( + "['A1', 'A2']" + == "%s" % self.Person.objects.order_by("age").scalar("name")[1:3] ) - self.assertEqual( - "['A51', 'A52']", - "%s" % self.Person.objects.order_by("age").scalar("name")[51:53], + assert ( + "['A51', 'A52']" + == "%s" % self.Person.objects.order_by("age").scalar("name")[51:53] ) else: - self.assertEqual( - "[u'A1', u'A2']", - "%s" % self.Person.objects.order_by("age").scalar("name")[1:3], + assert ( + "[u'A1', u'A2']" + == "%s" % self.Person.objects.order_by("age").scalar("name")[1:3] ) - self.assertEqual( - "[u'A51', u'A52']", - "%s" % self.Person.objects.order_by("age").scalar("name")[51:53], + assert ( + "[u'A51', u'A52']" + == "%s" % self.Person.objects.order_by("age").scalar("name")[51:53] ) # with_id and in_bulk person = self.Person.objects.order_by("name").first() - self.assertEqual( - "A0", "%s" % self.Person.objects.scalar("name").with_id(person.id) - ) + assert "A0" == "%s" % self.Person.objects.scalar("name").with_id(person.id) pks = self.Person.objects.order_by("age").scalar("pk")[1:3] names = self.Person.objects.scalar("name").in_bulk(list(pks)).values() @@ -4485,7 +4474,7 @@ class TestQueryset(unittest.TestCase): expected = "['A1', 'A2']" else: expected = "[u'A1', u'A2']" - self.assertEqual(expected, "%s" % sorted(names)) + assert expected == "%s" % sorted(names) def test_elem_match(self): class Foo(EmbeddedDocument): @@ -4525,29 +4514,29 @@ class TestQueryset(unittest.TestCase): b3.save() ak = list(Bar.objects(foo__match={"shape": "square", "color": "purple"})) - self.assertEqual([b1], ak) + assert [b1] == ak ak = list(Bar.objects(foo__elemMatch={"shape": "square", "color": "purple"})) - self.assertEqual([b1], ak) + assert [b1] == ak ak = list(Bar.objects(foo__match=Foo(shape="square", color="purple"))) - self.assertEqual([b1], ak) + assert [b1] == ak ak = list( Bar.objects(foo__elemMatch={"shape": "square", "color__exists": True}) ) - self.assertEqual([b1, b2], ak) + assert [b1, b2] == ak ak = list(Bar.objects(foo__match={"shape": "square", "color__exists": True})) - self.assertEqual([b1, b2], ak) + assert [b1, b2] == ak ak = list( Bar.objects(foo__elemMatch={"shape": "square", "color__exists": False}) ) - self.assertEqual([b3], ak) + assert [b3] == ak ak = list(Bar.objects(foo__match={"shape": "square", "color__exists": False})) - self.assertEqual([b3], ak) + assert [b3] == ak def test_upsert_includes_cls(self): """Upserts should include _cls information for inheritable classes @@ -4558,7 +4547,7 @@ class TestQueryset(unittest.TestCase): Test.drop_collection() Test.objects(test="foo").update_one(upsert=True, set__test="foo") - self.assertNotIn("_cls", Test._collection.find_one()) + assert "_cls" not in Test._collection.find_one() class Test(Document): meta = {"allow_inheritance": True} @@ -4567,15 +4556,15 @@ class TestQueryset(unittest.TestCase): Test.drop_collection() Test.objects(test="foo").update_one(upsert=True, set__test="foo") - self.assertIn("_cls", Test._collection.find_one()) + assert "_cls" in Test._collection.find_one() def test_update_upsert_looks_like_a_digit(self): class MyDoc(DynamicDocument): pass MyDoc.drop_collection() - self.assertEqual(1, MyDoc.objects.update_one(upsert=True, inc__47=1)) - self.assertEqual(MyDoc.objects.get()["47"], 1) + assert 1 == MyDoc.objects.update_one(upsert=True, inc__47=1) + assert MyDoc.objects.get()["47"] == 1 def test_dictfield_key_looks_like_a_digit(self): """Only should work with DictField even if they have numeric keys.""" @@ -4586,7 +4575,7 @@ class TestQueryset(unittest.TestCase): MyDoc.drop_collection() doc = MyDoc(test={"47": 1}) doc.save() - self.assertEqual(MyDoc.objects.only("test__47").get().test["47"], 1) + assert MyDoc.objects.only("test__47").get().test["47"] == 1 def test_clear_cls_query(self): class Parent(Document): @@ -4599,32 +4588,28 @@ class TestQueryset(unittest.TestCase): Parent.drop_collection() # Default query includes the "_cls" check. - self.assertEqual( - Parent.objects._query, {"_cls": {"$in": ("Parent", "Parent.Child")}} - ) + assert Parent.objects._query == {"_cls": {"$in": ("Parent", "Parent.Child")}} # Clearing the "_cls" query should work. - self.assertEqual(Parent.objects.clear_cls_query()._query, {}) + assert Parent.objects.clear_cls_query()._query == {} # Clearing the "_cls" query should not persist across queryset instances. - self.assertEqual( - Parent.objects._query, {"_cls": {"$in": ("Parent", "Parent.Child")}} - ) + assert Parent.objects._query == {"_cls": {"$in": ("Parent", "Parent.Child")}} # The rest of the query should not be cleared. - self.assertEqual( - Parent.objects.filter(name="xyz").clear_cls_query()._query, {"name": "xyz"} - ) + assert Parent.objects.filter(name="xyz").clear_cls_query()._query == { + "name": "xyz" + } Parent.objects.create(name="foo") Child.objects.create(name="bar", age=1) - self.assertEqual(Parent.objects.clear_cls_query().count(), 2) - self.assertEqual(Parent.objects.count(), 2) - self.assertEqual(Child.objects().count(), 1) + assert Parent.objects.clear_cls_query().count() == 2 + assert Parent.objects.count() == 2 + assert Child.objects().count() == 1 # XXX This isn't really how you'd want to use `clear_cls_query()`, but # it's a decent test to validate its behavior nonetheless. - self.assertEqual(Child.objects.clear_cls_query().count(), 2) + assert Child.objects.clear_cls_query().count() == 2 def test_read_preference(self): class Bar(Document): @@ -4636,20 +4621,21 @@ class TestQueryset(unittest.TestCase): bar = Bar.objects.create(txt="xyz") bars = list(Bar.objects.read_preference(ReadPreference.PRIMARY)) - self.assertEqual(bars, [bar]) + assert bars == [bar] bars = Bar.objects.read_preference(ReadPreference.SECONDARY_PREFERRED) - self.assertEqual(bars._read_preference, ReadPreference.SECONDARY_PREFERRED) - self.assertEqual( - bars._cursor.collection.read_preference, ReadPreference.SECONDARY_PREFERRED + assert bars._read_preference == ReadPreference.SECONDARY_PREFERRED + assert ( + bars._cursor.collection.read_preference == ReadPreference.SECONDARY_PREFERRED ) # Make sure that `.read_preference(...)` does accept string values. - self.assertRaises(TypeError, Bar.objects.read_preference, "Primary") + with pytest.raises(TypeError): + Bar.objects.read_preference("Primary") def assert_read_pref(qs, expected_read_pref): - self.assertEqual(qs._read_preference, expected_read_pref) - self.assertEqual(qs._cursor.collection.read_preference, expected_read_pref) + assert qs._read_preference == expected_read_pref + assert qs._cursor.collection.read_preference == expected_read_pref # Make sure read preference is respected after a `.skip(...)`. bars = Bar.objects.skip(1).read_preference(ReadPreference.SECONDARY_PREFERRED) @@ -4681,9 +4667,9 @@ class TestQueryset(unittest.TestCase): bars = Bar.objects.read_preference( ReadPreference.SECONDARY_PREFERRED ).aggregate() - self.assertEqual( - bars._CommandCursor__collection.read_preference, - ReadPreference.SECONDARY_PREFERRED, + assert ( + bars._CommandCursor__collection.read_preference + == ReadPreference.SECONDARY_PREFERRED ) def test_json_simple(self): @@ -4702,7 +4688,7 @@ class TestQueryset(unittest.TestCase): json_data = Doc.objects.to_json(sort_keys=True, separators=(",", ":")) doc_objects = list(Doc.objects) - self.assertEqual(doc_objects, Doc.objects.from_json(json_data)) + assert doc_objects == Doc.objects.from_json(json_data) def test_json_complex(self): class EmbeddedDoc(EmbeddedDocument): @@ -4748,7 +4734,7 @@ class TestQueryset(unittest.TestCase): json_data = Doc.objects.to_json() doc_objects = list(Doc.objects) - self.assertEqual(doc_objects, Doc.objects.from_json(json_data)) + assert doc_objects == Doc.objects.from_json(json_data) def test_as_pymongo(self): class LastLogin(EmbeddedDocument): @@ -4774,36 +4760,33 @@ class TestQueryset(unittest.TestCase): ) results = User.objects.as_pymongo() - self.assertEqual(set(results[0].keys()), set(["_id", "name", "age", "price"])) - self.assertEqual( - set(results[1].keys()), set(["_id", "name", "age", "price", "last_login"]) + assert set(results[0].keys()) == set(["_id", "name", "age", "price"]) + assert set(results[1].keys()) == set( + ["_id", "name", "age", "price", "last_login"] ) results = User.objects.only("id", "name").as_pymongo() - self.assertEqual(set(results[0].keys()), set(["_id", "name"])) + assert set(results[0].keys()) == set(["_id", "name"]) users = User.objects.only("name", "price").as_pymongo() results = list(users) - self.assertIsInstance(results[0], dict) - self.assertIsInstance(results[1], dict) - self.assertEqual(results[0]["name"], "Bob Dole") - self.assertEqual(results[0]["price"], 1.11) - self.assertEqual(results[1]["name"], "Barak Obama") - self.assertEqual(results[1]["price"], 2.22) + assert isinstance(results[0], dict) + assert isinstance(results[1], dict) + assert results[0]["name"] == "Bob Dole" + assert results[0]["price"] == 1.11 + assert results[1]["name"] == "Barak Obama" + assert results[1]["price"] == 2.22 users = User.objects.only("name", "last_login").as_pymongo() results = list(users) - self.assertIsInstance(results[0], dict) - self.assertIsInstance(results[1], dict) - self.assertEqual(results[0], {"_id": "Bob", "name": "Bob Dole"}) - self.assertEqual( - results[1], - { - "_id": "Barak", - "name": "Barak Obama", - "last_login": {"location": "White House", "ip": "104.107.108.116"}, - }, - ) + assert isinstance(results[0], dict) + assert isinstance(results[1], dict) + assert results[0] == {"_id": "Bob", "name": "Bob Dole"} + assert results[1] == { + "_id": "Barak", + "name": "Barak Obama", + "last_login": {"location": "White House", "ip": "104.107.108.116"}, + } def test_as_pymongo_returns_cls_attribute_when_using_inheritance(self): class User(Document): @@ -4814,7 +4797,7 @@ class TestQueryset(unittest.TestCase): user = User(name="Bob Dole").save() result = User.objects.as_pymongo().first() - self.assertEqual(result, {"_cls": "User", "_id": user.id, "name": "Bob Dole"}) + assert result == {"_cls": "User", "_id": user.id, "name": "Bob Dole"} def test_as_pymongo_json_limit_fields(self): class User(Document): @@ -4830,30 +4813,30 @@ class TestQueryset(unittest.TestCase): serialized_user = User.objects.exclude( "password_salt", "password_hash" ).as_pymongo()[0] - self.assertEqual({"_id", "email"}, set(serialized_user.keys())) + assert {"_id", "email"} == set(serialized_user.keys()) serialized_user = User.objects.exclude( "id", "password_salt", "password_hash" ).to_json() - self.assertEqual('[{"email": "ross@example.com"}]', serialized_user) + assert '[{"email": "ross@example.com"}]' == serialized_user serialized_user = User.objects.only("email").as_pymongo()[0] - self.assertEqual({"_id", "email"}, set(serialized_user.keys())) + assert {"_id", "email"} == set(serialized_user.keys()) serialized_user = ( User.objects.exclude("password_salt").only("email").as_pymongo()[0] ) - self.assertEqual({"_id", "email"}, set(serialized_user.keys())) + assert {"_id", "email"} == set(serialized_user.keys()) serialized_user = ( User.objects.exclude("password_salt", "id").only("email").as_pymongo()[0] ) - self.assertEqual({"email"}, set(serialized_user.keys())) + assert {"email"} == set(serialized_user.keys()) serialized_user = ( User.objects.exclude("password_salt", "id").only("email").to_json() ) - self.assertEqual('[{"email": "ross@example.com"}]', serialized_user) + assert '[{"email": "ross@example.com"}]' == serialized_user def test_only_after_count(self): """Test that only() works after count()""" @@ -4869,13 +4852,13 @@ class TestQueryset(unittest.TestCase): user_queryset = User.objects(age=50) result = user_queryset.only("name", "age").as_pymongo().first() - self.assertEqual(result, {"_id": user.id, "name": "User", "age": 50}) + assert result == {"_id": user.id, "name": "User", "age": 50} result = user_queryset.count() - self.assertEqual(result, 1) + assert result == 1 result = user_queryset.only("name", "age").as_pymongo().first() - self.assertEqual(result, {"_id": user.id, "name": "User", "age": 50}) + assert result == {"_id": user.id, "name": "User", "age": 50} def test_no_dereference(self): class Organization(Document): @@ -4894,12 +4877,12 @@ class TestQueryset(unittest.TestCase): qs = User.objects() qs_user = qs.first() - self.assertIsInstance(qs.first().organization, Organization) + assert isinstance(qs.first().organization, Organization) - self.assertIsInstance(qs.no_dereference().first().organization, DBRef) + assert isinstance(qs.no_dereference().first().organization, DBRef) - self.assertIsInstance(qs_user.organization, Organization) - self.assertIsInstance(qs.first().organization, Organization) + assert isinstance(qs_user.organization, Organization) + assert isinstance(qs.first().organization, Organization) def test_no_dereference_internals(self): # Test the internals on which queryset.no_dereference relies on @@ -4913,24 +4896,24 @@ class TestQueryset(unittest.TestCase): Organization.drop_collection() cls_organization_field = User.organization - self.assertTrue(cls_organization_field._auto_dereference, True) # default + assert cls_organization_field._auto_dereference, True # default org = Organization(name="whatever").save() User(organization=org).save() qs_no_deref = User.objects().no_dereference() user_no_deref = qs_no_deref.first() - self.assertFalse(qs_no_deref._auto_dereference) + assert not qs_no_deref._auto_dereference # Make sure the instance field is different from the class field instance_org_field = user_no_deref._fields["organization"] - self.assertIsNot(instance_org_field, cls_organization_field) - self.assertFalse(instance_org_field._auto_dereference) + assert instance_org_field is not cls_organization_field + assert not instance_org_field._auto_dereference - self.assertIsInstance(user_no_deref.organization, DBRef) - self.assertTrue( - cls_organization_field._auto_dereference, True - ) # Make sure the class Field wasn't altered + assert isinstance(user_no_deref.organization, DBRef) + assert ( + cls_organization_field._auto_dereference + ), True # Make sure the class Field wasn't altered def test_no_dereference_no_side_effect_on_existing_instance(self): # Relates to issue #1677 - ensures no regression of the bug @@ -4956,13 +4939,13 @@ class TestQueryset(unittest.TestCase): # ReferenceField no_derf_org = user_no_deref.organization # was triggering the bug - self.assertIsInstance(no_derf_org, DBRef) - self.assertIsInstance(user.organization, Organization) + assert isinstance(no_derf_org, DBRef) + assert isinstance(user.organization, Organization) # GenericReferenceField no_derf_org_gen = user_no_deref.organization_gen - self.assertIsInstance(no_derf_org_gen, dict) - self.assertIsInstance(user.organization_gen, Organization) + assert isinstance(no_derf_org_gen, dict) + assert isinstance(user.organization_gen, Organization) def test_no_dereference_embedded_doc(self): class User(Document): @@ -4994,13 +4977,13 @@ class TestQueryset(unittest.TestCase): org = Organization.objects().no_dereference().first() - self.assertNotEqual(id(org._fields["admins"]), id(Organization.admins)) - self.assertFalse(org._fields["admins"]._auto_dereference) + assert id(org._fields["admins"]) != id(Organization.admins) + assert not org._fields["admins"]._auto_dereference admin = org.admins[0] - self.assertIsInstance(admin, DBRef) - self.assertIsInstance(org.member.user, DBRef) - self.assertIsInstance(org.members[0].user, DBRef) + assert isinstance(admin, DBRef) + assert isinstance(org.member.user, DBRef) + assert isinstance(org.members[0].user, DBRef) def test_cached_queryset(self): class Person(Document): @@ -5011,11 +4994,11 @@ class TestQueryset(unittest.TestCase): Person(name="No: %s" % i).save() with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 people = Person.objects [x for x in people] - self.assertEqual(100, len(people._result_cache)) + assert 100 == len(people._result_cache) import platform @@ -5023,15 +5006,15 @@ class TestQueryset(unittest.TestCase): # PyPy evaluates __len__ when iterating with list comprehensions while CPython does not. # This may be a bug in PyPy (PyPy/#1802) but it does not affect # the behavior of MongoEngine. - self.assertEqual(None, people._len) - self.assertEqual(q, 1) + assert None == people._len + assert q == 1 list(people) - self.assertEqual(100, people._len) # Caused by list calling len - self.assertEqual(q, 1) + assert 100 == people._len # Caused by list calling len + assert q == 1 people.count(with_limit_and_skip=True) # count is cached - self.assertEqual(q, 1) + assert q == 1 def test_no_cached_queryset(self): class Person(Document): @@ -5042,17 +5025,17 @@ class TestQueryset(unittest.TestCase): Person(name="No: %s" % i).save() with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 people = Person.objects.no_cache() [x for x in people] - self.assertEqual(q, 1) + assert q == 1 list(people) - self.assertEqual(q, 2) + assert q == 2 people.count() - self.assertEqual(q, 3) + assert q == 3 def test_no_cached_queryset__repr__(self): class Person(Document): @@ -5060,7 +5043,7 @@ class TestQueryset(unittest.TestCase): Person.drop_collection() qs = Person.objects.no_cache() - self.assertEqual(repr(qs), "[]") + assert repr(qs) == "[]" def test_no_cached_on_a_cached_queryset_raise_error(self): class Person(Document): @@ -5070,9 +5053,9 @@ class TestQueryset(unittest.TestCase): Person(name="a").save() qs = Person.objects() _ = list(qs) - with self.assertRaises(OperationError) as ctx_err: + with pytest.raises(OperationError) as ctx_err: qs.no_cache() - self.assertEqual("QuerySet already cached", str(ctx_err.exception)) + assert "QuerySet already cached" == str(ctx_err.exception) def test_no_cached_queryset_no_cache_back_to_cache(self): class Person(Document): @@ -5080,11 +5063,11 @@ class TestQueryset(unittest.TestCase): Person.drop_collection() qs = Person.objects() - self.assertIsInstance(qs, QuerySet) + assert isinstance(qs, QuerySet) qs = qs.no_cache() - self.assertIsInstance(qs, QuerySetNoCache) + assert isinstance(qs, QuerySetNoCache) qs = qs.cache() - self.assertIsInstance(qs, QuerySet) + assert isinstance(qs, QuerySet) def test_cache_not_cloned(self): class User(Document): @@ -5099,12 +5082,12 @@ class TestQueryset(unittest.TestCase): User(name="Bob").save() users = User.objects.all().order_by("name") - self.assertEqual("%s" % users, "[, ]") - self.assertEqual(2, len(users._result_cache)) + assert "%s" % users == "[, ]" + assert 2 == len(users._result_cache) users = users.filter(name="Bob") - self.assertEqual("%s" % users, "[]") - self.assertEqual(1, len(users._result_cache)) + assert "%s" % users == "[]" + assert 1 == len(users._result_cache) def test_no_cache(self): """Ensure you can add meta data to file""" @@ -5122,23 +5105,23 @@ class TestQueryset(unittest.TestCase): docs = Noddy.objects.no_cache() counter = len([1 for i in docs]) - self.assertEqual(counter, 100) + assert counter == 100 - self.assertEqual(len(list(docs)), 100) + assert len(list(docs)) == 100 # Can't directly get a length of a no-cache queryset. - with self.assertRaises(TypeError): + with pytest.raises(TypeError): len(docs) # Another iteration over the queryset should result in another db op. with query_counter() as q: list(docs) - self.assertEqual(q, 1) + assert q == 1 # ... and another one to double-check. with query_counter() as q: list(docs) - self.assertEqual(q, 1) + assert q == 1 def test_nested_queryset_iterator(self): # Try iterating the same queryset twice, nested. @@ -5161,32 +5144,32 @@ class TestQueryset(unittest.TestCase): inner_total_count = 0 with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 - self.assertEqual(users.count(with_limit_and_skip=True), 7) + assert users.count(with_limit_and_skip=True) == 7 for i, outer_user in enumerate(users): - self.assertEqual(outer_user.name, names[i]) + assert outer_user.name == names[i] outer_count += 1 inner_count = 0 # Calling len might disrupt the inner loop if there are bugs - self.assertEqual(users.count(with_limit_and_skip=True), 7) + assert users.count(with_limit_and_skip=True) == 7 for j, inner_user in enumerate(users): - self.assertEqual(inner_user.name, names[j]) + assert inner_user.name == names[j] inner_count += 1 inner_total_count += 1 # inner loop should always be executed seven times - self.assertEqual(inner_count, 7) + assert inner_count == 7 # outer loop should be executed seven times total - self.assertEqual(outer_count, 7) + assert outer_count == 7 # inner loop should be executed fourtynine times total - self.assertEqual(inner_total_count, 7 * 7) + assert inner_total_count == 7 * 7 - self.assertEqual(q, 2) + assert q == 2 def test_no_sub_classes(self): class A(Document): @@ -5209,23 +5192,23 @@ class TestQueryset(unittest.TestCase): B(x=30, y=50).save() C(x=40, y=60).save() - self.assertEqual(A.objects.no_sub_classes().count(), 2) - self.assertEqual(A.objects.count(), 5) + assert A.objects.no_sub_classes().count() == 2 + assert A.objects.count() == 5 - self.assertEqual(B.objects.no_sub_classes().count(), 2) - self.assertEqual(B.objects.count(), 3) + assert B.objects.no_sub_classes().count() == 2 + assert B.objects.count() == 3 - self.assertEqual(C.objects.no_sub_classes().count(), 1) - self.assertEqual(C.objects.count(), 1) + assert C.objects.no_sub_classes().count() == 1 + assert C.objects.count() == 1 for obj in A.objects.no_sub_classes(): - self.assertEqual(obj.__class__, A) + assert obj.__class__ == A for obj in B.objects.no_sub_classes(): - self.assertEqual(obj.__class__, B) + assert obj.__class__ == B for obj in C.objects.no_sub_classes(): - self.assertEqual(obj.__class__, C) + assert obj.__class__ == C def test_query_generic_embedded_document(self): """Ensure that querying sub field on generic_embedded_field works @@ -5245,10 +5228,10 @@ class TestQueryset(unittest.TestCase): Doc(document=B(b_name="B doc")).save() # Using raw in filter working fine - self.assertEqual(Doc.objects(__raw__={"document.a_name": "A doc"}).count(), 1) - self.assertEqual(Doc.objects(__raw__={"document.b_name": "B doc"}).count(), 1) - self.assertEqual(Doc.objects(document__a_name="A doc").count(), 1) - self.assertEqual(Doc.objects(document__b_name="B doc").count(), 1) + assert Doc.objects(__raw__={"document.a_name": "A doc"}).count() == 1 + assert Doc.objects(__raw__={"document.b_name": "B doc"}).count() == 1 + assert Doc.objects(document__a_name="A doc").count() == 1 + assert Doc.objects(document__b_name="B doc").count() == 1 def test_query_reference_to_custom_pk_doc(self): class A(Document): @@ -5263,9 +5246,9 @@ class TestQueryset(unittest.TestCase): a = A.objects.create(id="custom_id") B.objects.create(a=a) - self.assertEqual(B.objects.count(), 1) - self.assertEqual(B.objects.get(a=a).a, a) - self.assertEqual(B.objects.get(a=a.id).a, a) + assert B.objects.count() == 1 + assert B.objects.get(a=a).a == a + assert B.objects.get(a=a.id).a == a def test_cls_query_in_subclassed_docs(self): class Animal(Document): @@ -5279,21 +5262,18 @@ class TestQueryset(unittest.TestCase): class Cat(Animal): pass - self.assertEqual( - Animal.objects(name="Charlie")._query, - { - "name": "Charlie", - "_cls": {"$in": ("Animal", "Animal.Dog", "Animal.Cat")}, - }, - ) - self.assertEqual( - Dog.objects(name="Charlie")._query, - {"name": "Charlie", "_cls": "Animal.Dog"}, - ) - self.assertEqual( - Cat.objects(name="Charlie")._query, - {"name": "Charlie", "_cls": "Animal.Cat"}, - ) + assert Animal.objects(name="Charlie")._query == { + "name": "Charlie", + "_cls": {"$in": ("Animal", "Animal.Dog", "Animal.Cat")}, + } + assert Dog.objects(name="Charlie")._query == { + "name": "Charlie", + "_cls": "Animal.Dog", + } + assert Cat.objects(name="Charlie")._query == { + "name": "Charlie", + "_cls": "Animal.Cat", + } def test_can_have_field_same_name_as_query_operator(self): class Size(Document): @@ -5308,8 +5288,8 @@ class TestQueryset(unittest.TestCase): instance_size = Size(name="Large").save() Example(size=instance_size).save() - self.assertEqual(Example.objects(size=instance_size).count(), 1) - self.assertEqual(Example.objects(size__in=[instance_size]).count(), 1) + assert Example.objects(size=instance_size).count() == 1 + assert Example.objects(size__in=[instance_size]).count() == 1 def test_cursor_in_an_if_stmt(self): class Test(Document): @@ -5347,12 +5327,12 @@ class TestQueryset(unittest.TestCase): if Person.objects: pass - self.assertEqual(q, 1) + assert q == 1 op = q.db.system.profile.find( {"ns": {"$ne": "%s.system.indexes" % q.db.name}} )[0] - self.assertEqual(op["nreturned"], 1) + assert op["nreturned"] == 1 def test_bool_with_ordering(self): ORDER_BY_KEY, CMD_QUERY_KEY = get_key_compat(self.mongodb_version) @@ -5375,7 +5355,7 @@ class TestQueryset(unittest.TestCase): {"ns": {"$ne": "%s.system.indexes" % q.db.name}} )[0] - self.assertNotIn(ORDER_BY_KEY, op[CMD_QUERY_KEY]) + assert ORDER_BY_KEY not in op[CMD_QUERY_KEY] # Check that normal query uses orderby qs2 = Person.objects.order_by("name") @@ -5388,7 +5368,7 @@ class TestQueryset(unittest.TestCase): {"ns": {"$ne": "%s.system.indexes" % q.db.name}} )[0] - self.assertIn(ORDER_BY_KEY, op[CMD_QUERY_KEY]) + assert ORDER_BY_KEY in op[CMD_QUERY_KEY] def test_bool_with_ordering_from_meta_dict(self): ORDER_BY_KEY, CMD_QUERY_KEY = get_key_compat(self.mongodb_version) @@ -5412,16 +5392,12 @@ class TestQueryset(unittest.TestCase): {"ns": {"$ne": "%s.system.indexes" % q.db.name}} )[0] - self.assertNotIn( - "$orderby", - op[CMD_QUERY_KEY], - "BaseQuerySet must remove orderby from meta in boolen test", - ) + assert ( + "$orderby" not in op[CMD_QUERY_KEY] + ), "BaseQuerySet must remove orderby from meta in boolen test" - self.assertEqual(Person.objects.first().name, "A") - self.assertTrue( - Person.objects._has_data(), "Cursor has data and returned False" - ) + assert Person.objects.first().name == "A" + assert Person.objects._has_data(), "Cursor has data and returned False" def test_queryset_aggregation_framework(self): class Person(Document): @@ -5439,13 +5415,10 @@ class TestQueryset(unittest.TestCase): {"$project": {"name": {"$toUpper": "$name"}}} ) - self.assertEqual( - list(data), - [ - {"_id": p1.pk, "name": "ISABELLA LUANNA"}, - {"_id": p2.pk, "name": "WILSON JUNIOR"}, - ], - ) + assert list(data) == [ + {"_id": p1.pk, "name": "ISABELLA LUANNA"}, + {"_id": p2.pk, "name": "WILSON JUNIOR"}, + ] data = ( Person.objects(age__lte=22) @@ -5453,13 +5426,10 @@ class TestQueryset(unittest.TestCase): .aggregate({"$project": {"name": {"$toUpper": "$name"}}}) ) - self.assertEqual( - list(data), - [ - {"_id": p2.pk, "name": "WILSON JUNIOR"}, - {"_id": p1.pk, "name": "ISABELLA LUANNA"}, - ], - ) + assert list(data) == [ + {"_id": p2.pk, "name": "WILSON JUNIOR"}, + {"_id": p1.pk, "name": "ISABELLA LUANNA"}, + ] data = ( Person.objects(age__gte=17, age__lte=40) @@ -5468,12 +5438,10 @@ class TestQueryset(unittest.TestCase): {"$group": {"_id": None, "total": {"$sum": 1}, "avg": {"$avg": "$age"}}} ) ) - self.assertEqual(list(data), [{"_id": None, "avg": 29, "total": 2}]) + assert list(data) == [{"_id": None, "avg": 29, "total": 2}] data = Person.objects().aggregate({"$match": {"name": "Isabella Luanna"}}) - self.assertEqual( - list(data), [{u"_id": p1.pk, u"age": 16, u"name": u"Isabella Luanna"}] - ) + assert list(data) == [{u"_id": p1.pk, u"age": 16, u"name": u"Isabella Luanna"}] def test_queryset_aggregation_with_skip(self): class Person(Document): @@ -5491,13 +5459,10 @@ class TestQueryset(unittest.TestCase): {"$project": {"name": {"$toUpper": "$name"}}} ) - self.assertEqual( - list(data), - [ - {"_id": p2.pk, "name": "WILSON JUNIOR"}, - {"_id": p3.pk, "name": "SANDRA MARA"}, - ], - ) + assert list(data) == [ + {"_id": p2.pk, "name": "WILSON JUNIOR"}, + {"_id": p3.pk, "name": "SANDRA MARA"}, + ] def test_queryset_aggregation_with_limit(self): class Person(Document): @@ -5515,7 +5480,7 @@ class TestQueryset(unittest.TestCase): {"$project": {"name": {"$toUpper": "$name"}}} ) - self.assertEqual(list(data), [{"_id": p1.pk, "name": "ISABELLA LUANNA"}]) + assert list(data) == [{"_id": p1.pk, "name": "ISABELLA LUANNA"}] def test_queryset_aggregation_with_sort(self): class Person(Document): @@ -5533,14 +5498,11 @@ class TestQueryset(unittest.TestCase): {"$project": {"name": {"$toUpper": "$name"}}} ) - self.assertEqual( - list(data), - [ - {"_id": p1.pk, "name": "ISABELLA LUANNA"}, - {"_id": p3.pk, "name": "SANDRA MARA"}, - {"_id": p2.pk, "name": "WILSON JUNIOR"}, - ], - ) + assert list(data) == [ + {"_id": p1.pk, "name": "ISABELLA LUANNA"}, + {"_id": p3.pk, "name": "SANDRA MARA"}, + {"_id": p2.pk, "name": "WILSON JUNIOR"}, + ] def test_queryset_aggregation_with_skip_with_limit(self): class Person(Document): @@ -5560,7 +5522,7 @@ class TestQueryset(unittest.TestCase): .aggregate({"$project": {"name": {"$toUpper": "$name"}}}) ) - self.assertEqual(list(data), [{"_id": p2.pk, "name": "WILSON JUNIOR"}]) + assert list(data) == [{"_id": p2.pk, "name": "WILSON JUNIOR"}] # Make sure limit/skip chaining order has no impact data2 = ( @@ -5569,7 +5531,7 @@ class TestQueryset(unittest.TestCase): .aggregate({"$project": {"name": {"$toUpper": "$name"}}}) ) - self.assertEqual(data, list(data2)) + assert data == list(data2) def test_queryset_aggregation_with_sort_with_limit(self): class Person(Document): @@ -5589,13 +5551,10 @@ class TestQueryset(unittest.TestCase): .aggregate({"$project": {"name": {"$toUpper": "$name"}}}) ) - self.assertEqual( - list(data), - [ - {"_id": p1.pk, "name": "ISABELLA LUANNA"}, - {"_id": p3.pk, "name": "SANDRA MARA"}, - ], - ) + assert list(data) == [ + {"_id": p1.pk, "name": "ISABELLA LUANNA"}, + {"_id": p3.pk, "name": "SANDRA MARA"}, + ] # Verify adding limit/skip steps works as expected data = ( @@ -5604,7 +5563,7 @@ class TestQueryset(unittest.TestCase): .aggregate({"$project": {"name": {"$toUpper": "$name"}}}, {"$limit": 1}) ) - self.assertEqual(list(data), [{"_id": p1.pk, "name": "ISABELLA LUANNA"}]) + assert list(data) == [{"_id": p1.pk, "name": "ISABELLA LUANNA"}] data = ( Person.objects.order_by("name") @@ -5616,7 +5575,7 @@ class TestQueryset(unittest.TestCase): ) ) - self.assertEqual(list(data), [{"_id": p3.pk, "name": "SANDRA MARA"}]) + assert list(data) == [{"_id": p3.pk, "name": "SANDRA MARA"}] def test_queryset_aggregation_with_sort_with_skip(self): class Person(Document): @@ -5636,7 +5595,7 @@ class TestQueryset(unittest.TestCase): .aggregate({"$project": {"name": {"$toUpper": "$name"}}}) ) - self.assertEqual(list(data), [{"_id": p2.pk, "name": "WILSON JUNIOR"}]) + assert list(data) == [{"_id": p2.pk, "name": "WILSON JUNIOR"}] def test_queryset_aggregation_with_sort_with_skip_with_limit(self): class Person(Document): @@ -5657,30 +5616,29 @@ class TestQueryset(unittest.TestCase): .aggregate({"$project": {"name": {"$toUpper": "$name"}}}) ) - self.assertEqual(list(data), [{"_id": p3.pk, "name": "SANDRA MARA"}]) + assert list(data) == [{"_id": p3.pk, "name": "SANDRA MARA"}] def test_delete_count(self): [self.Person(name="User {0}".format(i), age=i * 10).save() for i in range(1, 4)] - self.assertEqual( - self.Person.objects().delete(), 3 + assert ( + self.Person.objects().delete() == 3 ) # test ordinary QuerySey delete count [self.Person(name="User {0}".format(i), age=i * 10).save() for i in range(1, 4)] - self.assertEqual( - self.Person.objects().skip(1).delete(), 2 + assert ( + self.Person.objects().skip(1).delete() == 2 ) # test Document delete with existing documents self.Person.objects().delete() - self.assertEqual( - self.Person.objects().skip(1).delete(), 0 + assert ( + self.Person.objects().skip(1).delete() == 0 ) # test Document delete without existing documents def test_max_time_ms(self): # 778: max_time_ms can get only int or None as input - self.assertRaises( - TypeError, self.Person.objects(name="name").max_time_ms, "not a number" - ) + with pytest.raises(TypeError): + self.Person.objects(name="name").max_time_ms("not a number") def test_subclass_field_query(self): class Animal(Document): @@ -5698,8 +5656,8 @@ class TestQueryset(unittest.TestCase): Animal(is_mamal=False).save() Cat(is_mamal=True, whiskers_length=5.1).save() ScottishCat(is_mamal=True, folded_ears=True).save() - self.assertEqual(Animal.objects(folded_ears=True).count(), 1) - self.assertEqual(Animal.objects(whiskers_length=5.1).count(), 1) + assert Animal.objects(folded_ears=True).count() == 1 + assert Animal.objects(whiskers_length=5.1).count() == 1 def test_loop_over_invalid_id_does_not_crash(self): class Person(Document): @@ -5709,7 +5667,7 @@ class TestQueryset(unittest.TestCase): Person._get_collection().insert_one({"name": "a", "id": ""}) for p in Person.objects(): - self.assertEqual(p.name, "a") + assert p.name == "a" def test_len_during_iteration(self): """Tests that calling len on a queyset during iteration doesn't @@ -5733,7 +5691,7 @@ class TestQueryset(unittest.TestCase): for i, r in enumerate(records): if i == 58: len(records) - self.assertEqual(i, 249) + assert i == 249 # Assert the same behavior is true even if we didn't pre-populate the # result cache. @@ -5741,7 +5699,7 @@ class TestQueryset(unittest.TestCase): for i, r in enumerate(records): if i == 58: len(records) - self.assertEqual(i, 249) + assert i == 249 def test_iteration_within_iteration(self): """You should be able to reliably iterate over all the documents @@ -5760,8 +5718,8 @@ class TestQueryset(unittest.TestCase): for j, doc2 in enumerate(qs): pass - self.assertEqual(i, 249) - self.assertEqual(j, 249) + assert i == 249 + assert j == 249 def test_in_operator_on_non_iterable(self): """Ensure that using the `__in` operator on a non-iterable raises an @@ -5785,24 +5743,26 @@ class TestQueryset(unittest.TestCase): # Make sure using `__in` with a list works blog_posts = BlogPost.objects(authors__in=[author]) - self.assertEqual(list(blog_posts), [post]) + assert list(blog_posts) == [post] # Using `__in` with a non-iterable should raise a TypeError - self.assertRaises(TypeError, BlogPost.objects(authors__in=author.pk).count) + with pytest.raises(TypeError): + BlogPost.objects(authors__in=author.pk).count() # Using `__in` with a `Document` (which is seemingly iterable but not # in a way we'd expect) should raise a TypeError, too - self.assertRaises(TypeError, BlogPost.objects(authors__in=author).count) + with pytest.raises(TypeError): + BlogPost.objects(authors__in=author).count() def test_create_count(self): self.Person.drop_collection() self.Person.objects.create(name="Foo") self.Person.objects.create(name="Bar") self.Person.objects.create(name="Baz") - self.assertEqual(self.Person.objects.count(with_limit_and_skip=True), 3) + assert self.Person.objects.count(with_limit_and_skip=True) == 3 - self.Person.objects.create(name="Foo_1") - self.assertEqual(self.Person.objects.count(with_limit_and_skip=True), 4) + newPerson = self.Person.objects.create(name="Foo_1") + assert self.Person.objects.count(with_limit_and_skip=True) == 4 def test_no_cursor_timeout(self): qs = self.Person.objects() diff --git a/tests/queryset/test_transform.py b/tests/queryset/test_transform.py index 8207351d..be28c3b8 100644 --- a/tests/queryset/test_transform.py +++ b/tests/queryset/test_transform.py @@ -4,6 +4,7 @@ from bson.son import SON from mongoengine import * from mongoengine.queryset import Q, transform +import pytest class TestTransform(unittest.TestCase): @@ -13,23 +14,16 @@ class TestTransform(unittest.TestCase): def test_transform_query(self): """Ensure that the _transform_query function operates correctly. """ - self.assertEqual( - transform.query(name="test", age=30), {"name": "test", "age": 30} - ) - self.assertEqual(transform.query(age__lt=30), {"age": {"$lt": 30}}) - self.assertEqual( - transform.query(age__gt=20, age__lt=50), {"age": {"$gt": 20, "$lt": 50}} - ) - self.assertEqual( - transform.query(age=20, age__gt=50), - {"$and": [{"age": {"$gt": 50}}, {"age": 20}]}, - ) - self.assertEqual( - transform.query(friend__age__gte=30), {"friend.age": {"$gte": 30}} - ) - self.assertEqual( - transform.query(name__exists=True), {"name": {"$exists": True}} - ) + assert transform.query(name="test", age=30) == {"name": "test", "age": 30} + assert transform.query(age__lt=30) == {"age": {"$lt": 30}} + assert transform.query(age__gt=20, age__lt=50) == { + "age": {"$gt": 20, "$lt": 50} + } + assert transform.query(age=20, age__gt=50) == { + "$and": [{"age": {"$gt": 50}}, {"age": 20}] + } + assert transform.query(friend__age__gte=30) == {"friend.age": {"$gte": 30}} + assert transform.query(name__exists=True) == {"name": {"$exists": True}} def test_transform_update(self): class LisDoc(Document): @@ -54,17 +48,17 @@ class TestTransform(unittest.TestCase): ("push", "$push"), ): update = transform.update(DicDoc, **{"%s__dictField__test" % k: doc}) - self.assertIsInstance(update[v]["dictField.test"], dict) + assert isinstance(update[v]["dictField.test"], dict) # Update special cases update = transform.update(DicDoc, unset__dictField__test=doc) - self.assertEqual(update["$unset"]["dictField.test"], 1) + assert update["$unset"]["dictField.test"] == 1 update = transform.update(DicDoc, pull__dictField__test=doc) - self.assertIsInstance(update["$pull"]["dictField"]["test"], dict) + assert isinstance(update["$pull"]["dictField"]["test"], dict) update = transform.update(LisDoc, pull__foo__in=["a"]) - self.assertEqual(update, {"$pull": {"foo": {"$in": ["a"]}}}) + assert update == {"$pull": {"foo": {"$in": ["a"]}}} def test_transform_update_push(self): """Ensure the differences in behvaior between 'push' and 'push_all'""" @@ -73,10 +67,10 @@ class TestTransform(unittest.TestCase): tags = ListField(StringField()) update = transform.update(BlogPost, push__tags=["mongo", "db"]) - self.assertEqual(update, {"$push": {"tags": ["mongo", "db"]}}) + assert update == {"$push": {"tags": ["mongo", "db"]}} update = transform.update(BlogPost, push_all__tags=["mongo", "db"]) - self.assertEqual(update, {"$push": {"tags": {"$each": ["mongo", "db"]}}}) + assert update == {"$push": {"tags": {"$each": ["mongo", "db"]}}} def test_transform_update_no_operator_default_to_set(self): """Ensure the differences in behvaior between 'push' and 'push_all'""" @@ -85,7 +79,7 @@ class TestTransform(unittest.TestCase): tags = ListField(StringField()) update = transform.update(BlogPost, tags=["mongo", "db"]) - self.assertEqual(update, {"$set": {"tags": ["mongo", "db"]}}) + assert update == {"$set": {"tags": ["mongo", "db"]}} def test_query_field_name(self): """Ensure that the correct field name is used when querying. @@ -106,18 +100,18 @@ class TestTransform(unittest.TestCase): post = BlogPost(**data) post.save() - self.assertIn("postTitle", BlogPost.objects(title=data["title"])._query) - self.assertFalse("title" in BlogPost.objects(title=data["title"])._query) - self.assertEqual(BlogPost.objects(title=data["title"]).count(), 1) + assert "postTitle" in BlogPost.objects(title=data["title"])._query + assert not ("title" in BlogPost.objects(title=data["title"])._query) + assert BlogPost.objects(title=data["title"]).count() == 1 - self.assertIn("_id", BlogPost.objects(pk=post.id)._query) - self.assertEqual(BlogPost.objects(pk=post.id).count(), 1) + assert "_id" in BlogPost.objects(pk=post.id)._query + assert BlogPost.objects(pk=post.id).count() == 1 - self.assertIn( - "postComments.commentContent", - BlogPost.objects(comments__content="test")._query, + assert ( + "postComments.commentContent" + in BlogPost.objects(comments__content="test")._query ) - self.assertEqual(BlogPost.objects(comments__content="test").count(), 1) + assert BlogPost.objects(comments__content="test").count() == 1 BlogPost.drop_collection() @@ -135,9 +129,9 @@ class TestTransform(unittest.TestCase): post = BlogPost(**data) post.save() - self.assertIn("_id", BlogPost.objects(pk=data["title"])._query) - self.assertIn("_id", BlogPost.objects(title=data["title"])._query) - self.assertEqual(BlogPost.objects(pk=data["title"]).count(), 1) + assert "_id" in BlogPost.objects(pk=data["title"])._query + assert "_id" in BlogPost.objects(title=data["title"])._query + assert BlogPost.objects(pk=data["title"]).count() == 1 BlogPost.drop_collection() @@ -163,7 +157,7 @@ class TestTransform(unittest.TestCase): q2 = B.objects.filter(a__in=[a1, a2]) q2 = q2.filter(a=a1)._query - self.assertEqual(q1, q2) + assert q1 == q2 def test_raw_query_and_Q_objects(self): """ @@ -179,11 +173,11 @@ class TestTransform(unittest.TestCase): meta = {"allow_inheritance": False} query = Foo.objects(__raw__={"$nor": [{"name": "bar"}]})._query - self.assertEqual(query, {"$nor": [{"name": "bar"}]}) + assert query == {"$nor": [{"name": "bar"}]} q1 = {"$or": [{"a": 1}, {"b": 1}]} query = Foo.objects(Q(__raw__=q1) & Q(c=1))._query - self.assertEqual(query, {"$or": [{"a": 1}, {"b": 1}], "c": 1}) + assert query == {"$or": [{"a": 1}, {"b": 1}], "c": 1} def test_raw_and_merging(self): class Doc(Document): @@ -200,51 +194,39 @@ class TestTransform(unittest.TestCase): } )._query - self.assertEqual( - raw_query, - { - "deleted": False, - "scraped": "yes", - "$nor": [ - {"views.extracted": "no"}, - {"attachments.views.extracted": "no"}, - ], - }, - ) + assert raw_query == { + "deleted": False, + "scraped": "yes", + "$nor": [{"views.extracted": "no"}, {"attachments.views.extracted": "no"}], + } def test_geojson_PointField(self): class Location(Document): loc = PointField() update = transform.update(Location, set__loc=[1, 2]) - self.assertEqual( - update, {"$set": {"loc": {"type": "Point", "coordinates": [1, 2]}}} - ) + assert update == {"$set": {"loc": {"type": "Point", "coordinates": [1, 2]}}} update = transform.update( Location, set__loc={"type": "Point", "coordinates": [1, 2]} ) - self.assertEqual( - update, {"$set": {"loc": {"type": "Point", "coordinates": [1, 2]}}} - ) + assert update == {"$set": {"loc": {"type": "Point", "coordinates": [1, 2]}}} def test_geojson_LineStringField(self): class Location(Document): line = LineStringField() update = transform.update(Location, set__line=[[1, 2], [2, 2]]) - self.assertEqual( - update, - {"$set": {"line": {"type": "LineString", "coordinates": [[1, 2], [2, 2]]}}}, - ) + assert update == { + "$set": {"line": {"type": "LineString", "coordinates": [[1, 2], [2, 2]]}} + } update = transform.update( Location, set__line={"type": "LineString", "coordinates": [[1, 2], [2, 2]]} ) - self.assertEqual( - update, - {"$set": {"line": {"type": "LineString", "coordinates": [[1, 2], [2, 2]]}}}, - ) + assert update == { + "$set": {"line": {"type": "LineString", "coordinates": [[1, 2], [2, 2]]}} + } def test_geojson_PolygonField(self): class Location(Document): @@ -253,17 +235,14 @@ class TestTransform(unittest.TestCase): update = transform.update( Location, set__poly=[[[40, 5], [40, 6], [41, 6], [40, 5]]] ) - self.assertEqual( - update, - { - "$set": { - "poly": { - "type": "Polygon", - "coordinates": [[[40, 5], [40, 6], [41, 6], [40, 5]]], - } + assert update == { + "$set": { + "poly": { + "type": "Polygon", + "coordinates": [[[40, 5], [40, 6], [41, 6], [40, 5]]], } - }, - ) + } + } update = transform.update( Location, @@ -272,17 +251,14 @@ class TestTransform(unittest.TestCase): "coordinates": [[[40, 5], [40, 6], [41, 6], [40, 5]]], }, ) - self.assertEqual( - update, - { - "$set": { - "poly": { - "type": "Polygon", - "coordinates": [[[40, 5], [40, 6], [41, 6], [40, 5]]], - } + assert update == { + "$set": { + "poly": { + "type": "Polygon", + "coordinates": [[[40, 5], [40, 6], [41, 6], [40, 5]]], } - }, - ) + } + } def test_type(self): class Doc(Document): @@ -291,10 +267,10 @@ class TestTransform(unittest.TestCase): Doc(df=True).save() Doc(df=7).save() Doc(df="df").save() - self.assertEqual(Doc.objects(df__type=1).count(), 0) # double - self.assertEqual(Doc.objects(df__type=8).count(), 1) # bool - self.assertEqual(Doc.objects(df__type=2).count(), 1) # str - self.assertEqual(Doc.objects(df__type=16).count(), 1) # int + assert Doc.objects(df__type=1).count() == 0 # double + assert Doc.objects(df__type=8).count() == 1 # bool + assert Doc.objects(df__type=2).count() == 1 # str + assert Doc.objects(df__type=16).count() == 1 # int def test_last_field_name_like_operator(self): class EmbeddedItem(EmbeddedDocument): @@ -309,12 +285,12 @@ class TestTransform(unittest.TestCase): doc = Doc(item=EmbeddedItem(type="axe", name="Heroic axe")) doc.save() - self.assertEqual(1, Doc.objects(item__type__="axe").count()) - self.assertEqual(1, Doc.objects(item__name__="Heroic axe").count()) + assert 1 == Doc.objects(item__type__="axe").count() + assert 1 == Doc.objects(item__name__="Heroic axe").count() Doc.objects(id=doc.id).update(set__item__type__="sword") - self.assertEqual(1, Doc.objects(item__type__="sword").count()) - self.assertEqual(0, Doc.objects(item__type__="axe").count()) + assert 1 == Doc.objects(item__type__="sword").count() + assert 0 == Doc.objects(item__type__="axe").count() def test_understandable_error_raised(self): class Event(Document): @@ -324,7 +300,7 @@ class TestTransform(unittest.TestCase): box = [(35.0, -125.0), (40.0, -100.0)] # I *meant* to execute location__within_box=box events = Event.objects(location__within=box) - with self.assertRaises(InvalidQueryError): + with pytest.raises(InvalidQueryError): events.count() def test_update_pull_for_list_fields(self): @@ -347,24 +323,20 @@ class TestTransform(unittest.TestCase): word = Word(word="abc", index=1) update = transform.update(MainDoc, pull__content__text=word) - self.assertEqual( - update, {"$pull": {"content.text": SON([("word", u"abc"), ("index", 1)])}} - ) + assert update == { + "$pull": {"content.text": SON([("word", u"abc"), ("index", 1)])} + } update = transform.update(MainDoc, pull__content__heading="xyz") - self.assertEqual(update, {"$pull": {"content.heading": "xyz"}}) + assert update == {"$pull": {"content.heading": "xyz"}} update = transform.update(MainDoc, pull__content__text__word__in=["foo", "bar"]) - self.assertEqual( - update, {"$pull": {"content.text": {"word": {"$in": ["foo", "bar"]}}}} - ) + assert update == {"$pull": {"content.text": {"word": {"$in": ["foo", "bar"]}}}} update = transform.update( MainDoc, pull__content__text__word__nin=["foo", "bar"] ) - self.assertEqual( - update, {"$pull": {"content.text": {"word": {"$nin": ["foo", "bar"]}}}} - ) + assert update == {"$pull": {"content.text": {"word": {"$nin": ["foo", "bar"]}}}} if __name__ == "__main__": diff --git a/tests/queryset/test_visitor.py b/tests/queryset/test_visitor.py index acadabd4..a41f9278 100644 --- a/tests/queryset/test_visitor.py +++ b/tests/queryset/test_visitor.py @@ -7,6 +7,7 @@ from bson import ObjectId from mongoengine import * from mongoengine.errors import InvalidQueryError from mongoengine.queryset import Q +import pytest class TestQ(unittest.TestCase): @@ -35,10 +36,10 @@ class TestQ(unittest.TestCase): age = IntField() query = {"$or": [{"age": {"$gte": 18}}, {"name": "test"}]} - self.assertEqual((q1 | q2 | q3 | q4 | q5).to_query(Person), query) + assert (q1 | q2 | q3 | q4 | q5).to_query(Person) == query query = {"age": {"$gte": 18}, "name": "test"} - self.assertEqual((q1 & q2 & q3 & q4 & q5).to_query(Person), query) + assert (q1 & q2 & q3 & q4 & q5).to_query(Person) == query def test_q_with_dbref(self): """Ensure Q objects handle DBRefs correctly""" @@ -53,8 +54,8 @@ class TestQ(unittest.TestCase): user = User.objects.create() Post.objects.create(created_user=user) - self.assertEqual(Post.objects.filter(created_user=user).count(), 1) - self.assertEqual(Post.objects.filter(Q(created_user=user)).count(), 1) + assert Post.objects.filter(created_user=user).count() == 1 + assert Post.objects.filter(Q(created_user=user)).count() == 1 def test_and_combination(self): """Ensure that Q-objects correctly AND together. @@ -65,12 +66,10 @@ class TestQ(unittest.TestCase): y = StringField() query = (Q(x__lt=7) & Q(x__lt=3)).to_query(TestDoc) - self.assertEqual(query, {"$and": [{"x": {"$lt": 7}}, {"x": {"$lt": 3}}]}) + assert query == {"$and": [{"x": {"$lt": 7}}, {"x": {"$lt": 3}}]} query = (Q(y="a") & Q(x__lt=7) & Q(x__lt=3)).to_query(TestDoc) - self.assertEqual( - query, {"$and": [{"y": "a"}, {"x": {"$lt": 7}}, {"x": {"$lt": 3}}]} - ) + assert query == {"$and": [{"y": "a"}, {"x": {"$lt": 7}}, {"x": {"$lt": 3}}]} # Check normal cases work without an error query = Q(x__lt=7) & Q(x__gt=3) @@ -78,7 +77,7 @@ class TestQ(unittest.TestCase): q1 = Q(x__lt=7) q2 = Q(x__gt=3) query = (q1 & q2).to_query(TestDoc) - self.assertEqual(query, {"x": {"$lt": 7, "$gt": 3}}) + assert query == {"x": {"$lt": 7, "$gt": 3}} # More complex nested example query = Q(x__lt=100) & Q(y__ne="NotMyString") @@ -87,7 +86,7 @@ class TestQ(unittest.TestCase): "x": {"$lt": 100, "$gt": -100}, "y": {"$ne": "NotMyString", "$in": ["a", "b", "c"]}, } - self.assertEqual(query.to_query(TestDoc), mongo_query) + assert query.to_query(TestDoc) == mongo_query def test_or_combination(self): """Ensure that Q-objects correctly OR together. @@ -99,7 +98,7 @@ class TestQ(unittest.TestCase): q1 = Q(x__lt=3) q2 = Q(x__gt=7) query = (q1 | q2).to_query(TestDoc) - self.assertEqual(query, {"$or": [{"x": {"$lt": 3}}, {"x": {"$gt": 7}}]}) + assert query == {"$or": [{"x": {"$lt": 3}}, {"x": {"$gt": 7}}]} def test_and_or_combination(self): """Ensure that Q-objects handle ANDing ORed components. @@ -113,15 +112,12 @@ class TestQ(unittest.TestCase): query = Q(x__gt=0) | Q(x__exists=False) query &= Q(x__lt=100) - self.assertEqual( - query.to_query(TestDoc), - { - "$and": [ - {"$or": [{"x": {"$gt": 0}}, {"x": {"$exists": False}}]}, - {"x": {"$lt": 100}}, - ] - }, - ) + assert query.to_query(TestDoc) == { + "$and": [ + {"$or": [{"x": {"$gt": 0}}, {"x": {"$exists": False}}]}, + {"x": {"$lt": 100}}, + ] + } q1 = Q(x__gt=0) | Q(x__exists=False) q2 = Q(x__lt=100) | Q(y=True) @@ -131,16 +127,13 @@ class TestQ(unittest.TestCase): TestDoc(x=10).save() TestDoc(y=True).save() - self.assertEqual( - query, - { - "$and": [ - {"$or": [{"x": {"$gt": 0}}, {"x": {"$exists": False}}]}, - {"$or": [{"x": {"$lt": 100}}, {"y": True}]}, - ] - }, - ) - self.assertEqual(2, TestDoc.objects(q1 & q2).count()) + assert query == { + "$and": [ + {"$or": [{"x": {"$gt": 0}}, {"x": {"$exists": False}}]}, + {"$or": [{"x": {"$lt": 100}}, {"y": True}]}, + ] + } + assert 2 == TestDoc.objects(q1 & q2).count() def test_or_and_or_combination(self): """Ensure that Q-objects handle ORing ANDed ORed components. :) @@ -160,26 +153,23 @@ class TestQ(unittest.TestCase): q2 = Q(x__lt=100) & (Q(y=False) | Q(y__exists=False)) query = (q1 | q2).to_query(TestDoc) - self.assertEqual( - query, - { - "$or": [ - { - "$and": [ - {"x": {"$gt": 0}}, - {"$or": [{"y": True}, {"y": {"$exists": False}}]}, - ] - }, - { - "$and": [ - {"x": {"$lt": 100}}, - {"$or": [{"y": False}, {"y": {"$exists": False}}]}, - ] - }, - ] - }, - ) - self.assertEqual(2, TestDoc.objects(q1 | q2).count()) + assert query == { + "$or": [ + { + "$and": [ + {"x": {"$gt": 0}}, + {"$or": [{"y": True}, {"y": {"$exists": False}}]}, + ] + }, + { + "$and": [ + {"x": {"$lt": 100}}, + {"$or": [{"y": False}, {"y": {"$exists": False}}]}, + ] + }, + ] + } + assert 2 == TestDoc.objects(q1 | q2).count() def test_multiple_occurence_in_field(self): class Test(Document): @@ -192,8 +182,8 @@ class TestQ(unittest.TestCase): q3 = q1 & q2 query = q3.to_query(Test) - self.assertEqual(query["$and"][0], q1.to_query(Test)) - self.assertEqual(query["$and"][1], q2.to_query(Test)) + assert query["$and"][0] == q1.to_query(Test) + assert query["$and"][1] == q2.to_query(Test) def test_q_clone(self): class TestDoc(Document): @@ -207,15 +197,15 @@ class TestQ(unittest.TestCase): # Check normal cases work without an error test = TestDoc.objects(Q(x__lt=7) & Q(x__gt=3)) - self.assertEqual(test.count(), 3) + assert test.count() == 3 test2 = test.clone() - self.assertEqual(test2.count(), 3) - self.assertNotEqual(test2, test) + assert test2.count() == 3 + assert test2 != test test3 = test2.filter(x=6) - self.assertEqual(test3.count(), 1) - self.assertEqual(test.count(), 3) + assert test3.count() == 1 + assert test.count() == 3 def test_q(self): """Ensure that Q objects may be used to query for documents. @@ -252,19 +242,19 @@ class TestQ(unittest.TestCase): # Check ObjectId lookup works obj = BlogPost.objects(id=post1.id).first() - self.assertEqual(obj, post1) + assert obj == post1 # Check Q object combination with one does not exist q = BlogPost.objects(Q(title="Test 5") | Q(published=True)) posts = [post.id for post in q] published_posts = (post2, post3) - self.assertTrue(all(obj.id in posts for obj in published_posts)) + assert all(obj.id in posts for obj in published_posts) q = BlogPost.objects(Q(title="Test 1") | Q(published=True)) posts = [post.id for post in q] published_posts = (post1, post2, post3, post5, post6) - self.assertTrue(all(obj.id in posts for obj in published_posts)) + assert all(obj.id in posts for obj in published_posts) # Check Q object combination date = datetime.datetime(2010, 1, 10) @@ -272,9 +262,9 @@ class TestQ(unittest.TestCase): posts = [post.id for post in q] published_posts = (post1, post2, post3, post4) - self.assertTrue(all(obj.id in posts for obj in published_posts)) + assert all(obj.id in posts for obj in published_posts) - self.assertFalse(any(obj.id in posts for obj in [post5, post6])) + assert not any(obj.id in posts for obj in [post5, post6]) BlogPost.drop_collection() @@ -284,15 +274,15 @@ class TestQ(unittest.TestCase): self.Person(name="user3", age=30).save() self.Person(name="user4", age=40).save() - self.assertEqual(self.Person.objects(Q(age__in=[20])).count(), 2) - self.assertEqual(self.Person.objects(Q(age__in=[20, 30])).count(), 3) + assert self.Person.objects(Q(age__in=[20])).count() == 2 + assert self.Person.objects(Q(age__in=[20, 30])).count() == 3 # Test invalid query objs - with self.assertRaises(InvalidQueryError): + with pytest.raises(InvalidQueryError): self.Person.objects("user1") # filter should fail, too - with self.assertRaises(InvalidQueryError): + with pytest.raises(InvalidQueryError): self.Person.objects.filter("user1") def test_q_regex(self): @@ -302,31 +292,31 @@ class TestQ(unittest.TestCase): person.save() obj = self.Person.objects(Q(name=re.compile("^Gui"))).first() - self.assertEqual(obj, person) + assert obj == person obj = self.Person.objects(Q(name=re.compile("^gui"))).first() - self.assertEqual(obj, None) + assert obj == None obj = self.Person.objects(Q(name=re.compile("^gui", re.I))).first() - self.assertEqual(obj, person) + assert obj == person obj = self.Person.objects(Q(name__not=re.compile("^bob"))).first() - self.assertEqual(obj, person) + assert obj == person obj = self.Person.objects(Q(name__not=re.compile("^Gui"))).first() - self.assertEqual(obj, None) + assert obj == None def test_q_repr(self): - self.assertEqual(repr(Q()), "Q(**{})") - self.assertEqual(repr(Q(name="test")), "Q(**{'name': 'test'})") + assert repr(Q()) == "Q(**{})" + assert repr(Q(name="test")) == "Q(**{'name': 'test'})" - self.assertEqual( - repr(Q(name="test") & Q(age__gte=18)), - "(Q(**{'name': 'test'}) & Q(**{'age__gte': 18}))", + assert ( + repr(Q(name="test") & Q(age__gte=18)) + == "(Q(**{'name': 'test'}) & Q(**{'age__gte': 18}))" ) - self.assertEqual( - repr(Q(name="test") | Q(age__gte=18)), - "(Q(**{'name': 'test'}) | Q(**{'age__gte': 18}))", + assert ( + repr(Q(name="test") | Q(age__gte=18)) + == "(Q(**{'name': 'test'}) | Q(**{'age__gte': 18}))" ) def test_q_lists(self): @@ -341,8 +331,8 @@ class TestQ(unittest.TestCase): BlogPost(tags=["python", "mongo"]).save() BlogPost(tags=["python"]).save() - self.assertEqual(BlogPost.objects(Q(tags="mongo")).count(), 1) - self.assertEqual(BlogPost.objects(Q(tags="python")).count(), 2) + assert BlogPost.objects(Q(tags="mongo")).count() == 1 + assert BlogPost.objects(Q(tags="python")).count() == 2 BlogPost.drop_collection() @@ -355,12 +345,12 @@ class TestQ(unittest.TestCase): pk = ObjectId() User(email="example@example.com", pk=pk).save() - self.assertEqual( - 1, - User.objects.filter(Q(email="example@example.com") | Q(name="John Doe")) + assert ( + 1 + == User.objects.filter(Q(email="example@example.com") | Q(name="John Doe")) .limit(2) .filter(pk=pk) - .count(), + .count() ) def test_chained_q_or_filtering(self): @@ -376,14 +366,12 @@ class TestQ(unittest.TestCase): Item(postables=[Post(name="a"), Post(name="c")]).save() Item(postables=[Post(name="a"), Post(name="b"), Post(name="c")]).save() - self.assertEqual( - Item.objects(Q(postables__name="a") & Q(postables__name="b")).count(), 2 + assert ( + Item.objects(Q(postables__name="a") & Q(postables__name="b")).count() == 2 ) - self.assertEqual( - Item.objects.filter(postables__name="a") - .filter(postables__name="b") - .count(), - 2, + assert ( + Item.objects.filter(postables__name="a").filter(postables__name="b").count() + == 2 ) diff --git a/tests/test_common.py b/tests/test_common.py index 28f0b992..6b6f18de 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -1,5 +1,7 @@ import unittest +import pytest + from mongoengine import Document from mongoengine.common import _import_class @@ -7,8 +9,8 @@ from mongoengine.common import _import_class class TestCommon(unittest.TestCase): def test__import_class(self): doc_cls = _import_class("Document") - self.assertIs(doc_cls, Document) + assert doc_cls is Document def test__import_class_raise_if_not_known(self): - with self.assertRaises(ValueError): + with pytest.raises(ValueError): _import_class("UnknownClass") diff --git a/tests/test_connection.py b/tests/test_connection.py index 1519a835..c73b67d1 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -29,6 +29,7 @@ from mongoengine.connection import ( get_connection, get_db, ) +import pytest def get_tz_awareness(connection): @@ -54,15 +55,15 @@ class ConnectionTest(unittest.TestCase): connect("mongoenginetest") conn = get_connection() - self.assertIsInstance(conn, pymongo.mongo_client.MongoClient) + assert isinstance(conn, pymongo.mongo_client.MongoClient) db = get_db() - self.assertIsInstance(db, pymongo.database.Database) - self.assertEqual(db.name, "mongoenginetest") + assert isinstance(db, pymongo.database.Database) + assert db.name == "mongoenginetest" connect("mongoenginetest2", alias="testdb") conn = get_connection("testdb") - self.assertIsInstance(conn, pymongo.mongo_client.MongoClient) + assert isinstance(conn, pymongo.mongo_client.MongoClient) def test_connect_disconnect_works_properly(self): class History1(Document): @@ -82,31 +83,27 @@ class ConnectionTest(unittest.TestCase): h = History1(name="default").save() h1 = History2(name="db1").save() - self.assertEqual( - list(History1.objects().as_pymongo()), [{"_id": h.id, "name": "default"}] - ) - self.assertEqual( - list(History2.objects().as_pymongo()), [{"_id": h1.id, "name": "db1"}] - ) + assert list(History1.objects().as_pymongo()) == [ + {"_id": h.id, "name": "default"} + ] + assert list(History2.objects().as_pymongo()) == [{"_id": h1.id, "name": "db1"}] disconnect("db1") disconnect("db2") - with self.assertRaises(ConnectionFailure): + with pytest.raises(ConnectionFailure): list(History1.objects().as_pymongo()) - with self.assertRaises(ConnectionFailure): + with pytest.raises(ConnectionFailure): list(History2.objects().as_pymongo()) connect("db1", alias="db1") connect("db2", alias="db2") - self.assertEqual( - list(History1.objects().as_pymongo()), [{"_id": h.id, "name": "default"}] - ) - self.assertEqual( - list(History2.objects().as_pymongo()), [{"_id": h1.id, "name": "db1"}] - ) + assert list(History1.objects().as_pymongo()) == [ + {"_id": h.id, "name": "default"} + ] + assert list(History2.objects().as_pymongo()) == [{"_id": h1.id, "name": "db1"}] def test_connect_different_documents_to_different_database(self): class History(Document): @@ -132,39 +129,35 @@ class ConnectionTest(unittest.TestCase): h1 = History1(name="db1").save() h2 = History2(name="db2").save() - self.assertEqual(History._collection.database.name, DEFAULT_DATABASE_NAME) - self.assertEqual(History1._collection.database.name, "db1") - self.assertEqual(History2._collection.database.name, "db2") + assert History._collection.database.name == DEFAULT_DATABASE_NAME + assert History1._collection.database.name == "db1" + assert History2._collection.database.name == "db2" - self.assertEqual( - list(History.objects().as_pymongo()), [{"_id": h.id, "name": "default"}] - ) - self.assertEqual( - list(History1.objects().as_pymongo()), [{"_id": h1.id, "name": "db1"}] - ) - self.assertEqual( - list(History2.objects().as_pymongo()), [{"_id": h2.id, "name": "db2"}] - ) + assert list(History.objects().as_pymongo()) == [ + {"_id": h.id, "name": "default"} + ] + assert list(History1.objects().as_pymongo()) == [{"_id": h1.id, "name": "db1"}] + assert list(History2.objects().as_pymongo()) == [{"_id": h2.id, "name": "db2"}] def test_connect_fails_if_connect_2_times_with_default_alias(self): connect("mongoenginetest") - with self.assertRaises(ConnectionFailure) as ctx_err: + with pytest.raises(ConnectionFailure) as ctx_err: connect("mongoenginetest2") - self.assertEqual( - "A different connection with alias `default` was already registered. Use disconnect() first", - str(ctx_err.exception), + assert ( + "A different connection with alias `default` was already registered. Use disconnect() first" + == str(ctx_err.exception) ) def test_connect_fails_if_connect_2_times_with_custom_alias(self): connect("mongoenginetest", alias="alias1") - with self.assertRaises(ConnectionFailure) as ctx_err: + with pytest.raises(ConnectionFailure) as ctx_err: connect("mongoenginetest2", alias="alias1") - self.assertEqual( - "A different connection with alias `alias1` was already registered. Use disconnect() first", - str(ctx_err.exception), + assert ( + "A different connection with alias `alias1` was already registered. Use disconnect() first" + == str(ctx_err.exception) ) def test_connect_fails_if_similar_connection_settings_arent_defined_the_same_way( @@ -175,25 +168,25 @@ class ConnectionTest(unittest.TestCase): db_alias = "alias1" connect(db=db_name, alias=db_alias, host="localhost", port=27017) - with self.assertRaises(ConnectionFailure): + with pytest.raises(ConnectionFailure): connect(host="mongodb://localhost:27017/%s" % db_name, alias=db_alias) def test_connect_passes_silently_connect_multiple_times_with_same_config(self): # test default connection to `test` connect() connect() - self.assertEqual(len(mongoengine.connection._connections), 1) + assert len(mongoengine.connection._connections) == 1 connect("test01", alias="test01") connect("test01", alias="test01") - self.assertEqual(len(mongoengine.connection._connections), 2) + assert len(mongoengine.connection._connections) == 2 connect(host="mongodb://localhost:27017/mongoenginetest02", alias="test02") connect(host="mongodb://localhost:27017/mongoenginetest02", alias="test02") - self.assertEqual(len(mongoengine.connection._connections), 3) + assert len(mongoengine.connection._connections) == 3 def test_connect_with_invalid_db_name(self): """Ensure that connect() method fails fast if db name is invalid """ - with self.assertRaises(InvalidName): + with pytest.raises(InvalidName): connect("mongomock://localhost") def test_connect_with_db_name_external(self): @@ -203,20 +196,20 @@ class ConnectionTest(unittest.TestCase): connect("$external") conn = get_connection() - self.assertIsInstance(conn, pymongo.mongo_client.MongoClient) + assert isinstance(conn, pymongo.mongo_client.MongoClient) db = get_db() - self.assertIsInstance(db, pymongo.database.Database) - self.assertEqual(db.name, "$external") + assert isinstance(db, pymongo.database.Database) + assert db.name == "$external" connect("$external", alias="testdb") conn = get_connection("testdb") - self.assertIsInstance(conn, pymongo.mongo_client.MongoClient) + assert isinstance(conn, pymongo.mongo_client.MongoClient) def test_connect_with_invalid_db_name_type(self): """Ensure that connect() method fails fast if db name has invalid type """ - with self.assertRaises(TypeError): + with pytest.raises(TypeError): non_string_db_name = ["e. g. list instead of a string"] connect(non_string_db_name) @@ -230,11 +223,11 @@ class ConnectionTest(unittest.TestCase): connect("mongoenginetest", host="mongomock://localhost") conn = get_connection() - self.assertIsInstance(conn, mongomock.MongoClient) + assert isinstance(conn, mongomock.MongoClient) connect("mongoenginetest2", host="mongomock://localhost", alias="testdb2") conn = get_connection("testdb2") - self.assertIsInstance(conn, mongomock.MongoClient) + assert isinstance(conn, mongomock.MongoClient) connect( "mongoenginetest3", @@ -243,11 +236,11 @@ class ConnectionTest(unittest.TestCase): alias="testdb3", ) conn = get_connection("testdb3") - self.assertIsInstance(conn, mongomock.MongoClient) + assert isinstance(conn, mongomock.MongoClient) connect("mongoenginetest4", is_mock=True, alias="testdb4") conn = get_connection("testdb4") - self.assertIsInstance(conn, mongomock.MongoClient) + assert isinstance(conn, mongomock.MongoClient) connect( host="mongodb://localhost:27017/mongoenginetest5", @@ -255,11 +248,11 @@ class ConnectionTest(unittest.TestCase): alias="testdb5", ) conn = get_connection("testdb5") - self.assertIsInstance(conn, mongomock.MongoClient) + assert isinstance(conn, mongomock.MongoClient) connect(host="mongomock://localhost:27017/mongoenginetest6", alias="testdb6") conn = get_connection("testdb6") - self.assertIsInstance(conn, mongomock.MongoClient) + assert isinstance(conn, mongomock.MongoClient) connect( host="mongomock://localhost:27017/mongoenginetest7", @@ -267,7 +260,7 @@ class ConnectionTest(unittest.TestCase): alias="testdb7", ) conn = get_connection("testdb7") - self.assertIsInstance(conn, mongomock.MongoClient) + assert isinstance(conn, mongomock.MongoClient) def test_default_database_with_mocking(self): """Ensure that the default database is correctly set when using mongomock. @@ -286,8 +279,8 @@ class ConnectionTest(unittest.TestCase): some_document = SomeDocument() # database won't exist until we save a document some_document.save() - self.assertEqual(conn.get_default_database().name, "mongoenginetest") - self.assertEqual(conn.list_database_names()[0], "mongoenginetest") + assert conn.get_default_database().name == "mongoenginetest" + assert conn.database_names()[0] == "mongoenginetest" def test_connect_with_host_list(self): """Ensure that the connect() method works when host is a list @@ -301,22 +294,22 @@ class ConnectionTest(unittest.TestCase): connect(host=["mongomock://localhost"]) conn = get_connection() - self.assertIsInstance(conn, mongomock.MongoClient) + assert isinstance(conn, mongomock.MongoClient) connect(host=["mongodb://localhost"], is_mock=True, alias="testdb2") conn = get_connection("testdb2") - self.assertIsInstance(conn, mongomock.MongoClient) + assert isinstance(conn, mongomock.MongoClient) connect(host=["localhost"], is_mock=True, alias="testdb3") conn = get_connection("testdb3") - self.assertIsInstance(conn, mongomock.MongoClient) + assert isinstance(conn, mongomock.MongoClient) connect( host=["mongomock://localhost:27017", "mongomock://localhost:27018"], alias="testdb4", ) conn = get_connection("testdb4") - self.assertIsInstance(conn, mongomock.MongoClient) + assert isinstance(conn, mongomock.MongoClient) connect( host=["mongodb://localhost:27017", "mongodb://localhost:27018"], @@ -324,13 +317,13 @@ class ConnectionTest(unittest.TestCase): alias="testdb5", ) conn = get_connection("testdb5") - self.assertIsInstance(conn, mongomock.MongoClient) + assert isinstance(conn, mongomock.MongoClient) connect( host=["localhost:27017", "localhost:27018"], is_mock=True, alias="testdb6" ) conn = get_connection("testdb6") - self.assertIsInstance(conn, mongomock.MongoClient) + assert isinstance(conn, mongomock.MongoClient) def test_disconnect_cleans_globals(self): """Ensure that the disconnect() method cleans the globals objects""" @@ -340,20 +333,20 @@ class ConnectionTest(unittest.TestCase): connect("mongoenginetest") - self.assertEqual(len(connections), 1) - self.assertEqual(len(dbs), 0) - self.assertEqual(len(connection_settings), 1) + assert len(connections) == 1 + assert len(dbs) == 0 + assert len(connection_settings) == 1 class TestDoc(Document): pass TestDoc.drop_collection() # triggers the db - self.assertEqual(len(dbs), 1) + assert len(dbs) == 1 disconnect() - self.assertEqual(len(connections), 0) - self.assertEqual(len(dbs), 0) - self.assertEqual(len(connection_settings), 0) + assert len(connections) == 0 + assert len(dbs) == 0 + assert len(connection_settings) == 0 def test_disconnect_cleans_cached_collection_attribute_in_document(self): """Ensure that the disconnect() method works properly""" @@ -362,22 +355,20 @@ class ConnectionTest(unittest.TestCase): class History(Document): pass - self.assertIsNone(History._collection) + assert History._collection is None History.drop_collection() History.objects.first() # will trigger the caching of _collection attribute - self.assertIsNotNone(History._collection) + assert History._collection is not None disconnect() - self.assertIsNone(History._collection) + assert History._collection is None - with self.assertRaises(ConnectionFailure) as ctx_err: + with pytest.raises(ConnectionFailure) as ctx_err: History.objects.first() - self.assertEqual( - "You have not defined a default connection", str(ctx_err.exception) - ) + assert "You have not defined a default connection" == str(ctx_err.exception) def test_connect_disconnect_works_on_same_document(self): """Ensure that the connect/disconnect works properly with a single Document""" @@ -399,7 +390,7 @@ class ConnectionTest(unittest.TestCase): disconnect() # Make sure save doesnt work at this stage - with self.assertRaises(ConnectionFailure): + with pytest.raises(ConnectionFailure): User(name="Wont work").save() # Save in db2 @@ -408,13 +399,13 @@ class ConnectionTest(unittest.TestCase): disconnect() db1_users = list(client[db1].user.find()) - self.assertEqual(db1_users, [{"_id": user1.id, "name": "John is in db1"}]) + assert db1_users == [{"_id": user1.id, "name": "John is in db1"}] db2_users = list(client[db2].user.find()) - self.assertEqual(db2_users, [{"_id": user2.id, "name": "Bob is in db2"}]) + assert db2_users == [{"_id": user2.id, "name": "Bob is in db2"}] def test_disconnect_silently_pass_if_alias_does_not_exist(self): connections = mongoengine.connection._connections - self.assertEqual(len(connections), 0) + assert len(connections) == 0 disconnect(alias="not_exist") def test_disconnect_all(self): @@ -437,26 +428,26 @@ class ConnectionTest(unittest.TestCase): History1.drop_collection() History1.objects.first() - self.assertIsNotNone(History._collection) - self.assertIsNotNone(History1._collection) + assert History._collection is not None + assert History1._collection is not None - self.assertEqual(len(connections), 2) - self.assertEqual(len(dbs), 2) - self.assertEqual(len(connection_settings), 2) + assert len(connections) == 2 + assert len(dbs) == 2 + assert len(connection_settings) == 2 disconnect_all() - self.assertIsNone(History._collection) - self.assertIsNone(History1._collection) + assert History._collection is None + assert History1._collection is None - self.assertEqual(len(connections), 0) - self.assertEqual(len(dbs), 0) - self.assertEqual(len(connection_settings), 0) + assert len(connections) == 0 + assert len(dbs) == 0 + assert len(connection_settings) == 0 - with self.assertRaises(ConnectionFailure): + with pytest.raises(ConnectionFailure): History.objects.first() - with self.assertRaises(ConnectionFailure): + with pytest.raises(ConnectionFailure): History1.objects.first() def test_disconnect_all_silently_pass_if_no_connection_exist(self): @@ -473,7 +464,7 @@ class ConnectionTest(unittest.TestCase): expected_connection.server_info() - self.assertEqual(expected_connection, actual_connection) + assert expected_connection == actual_connection def test_connect_uri(self): """Ensure that the connect() method works properly with URIs.""" @@ -490,11 +481,11 @@ class ConnectionTest(unittest.TestCase): ) conn = get_connection() - self.assertIsInstance(conn, pymongo.mongo_client.MongoClient) + assert isinstance(conn, pymongo.mongo_client.MongoClient) db = get_db() - self.assertIsInstance(db, pymongo.database.Database) - self.assertEqual(db.name, "mongoenginetest") + assert isinstance(db, pymongo.database.Database) + assert db.name == "mongoenginetest" c.admin.system.users.delete_many({}) c.mongoenginetest.system.users.delete_many({}) @@ -506,11 +497,11 @@ class ConnectionTest(unittest.TestCase): connect("mongoenginetest", host="mongodb://localhost/") conn = get_connection() - self.assertIsInstance(conn, pymongo.mongo_client.MongoClient) + assert isinstance(conn, pymongo.mongo_client.MongoClient) db = get_db() - self.assertIsInstance(db, pymongo.database.Database) - self.assertEqual(db.name, "mongoenginetest") + assert isinstance(db, pymongo.database.Database) + assert db.name == "mongoenginetest" def test_connect_uri_default_db(self): """Ensure connect() defaults to the right database name if @@ -519,11 +510,11 @@ class ConnectionTest(unittest.TestCase): connect(host="mongodb://localhost/") conn = get_connection() - self.assertIsInstance(conn, pymongo.mongo_client.MongoClient) + assert isinstance(conn, pymongo.mongo_client.MongoClient) db = get_db() - self.assertIsInstance(db, pymongo.database.Database) - self.assertEqual(db.name, "test") + assert isinstance(db, pymongo.database.Database) + assert db.name == "test" def test_uri_without_credentials_doesnt_override_conn_settings(self): """Ensure connect() uses the username & password params if the URI @@ -536,7 +527,8 @@ class ConnectionTest(unittest.TestCase): # OperationFailure means that mongoengine attempted authentication # w/ the provided username/password and failed - that's the desired # behavior. If the MongoDB URI would override the credentials - self.assertRaises(OperationFailure, get_db) + with pytest.raises(OperationFailure): + get_db() def test_connect_uri_with_authsource(self): """Ensure that the connect() method works well with `authSource` @@ -554,7 +546,8 @@ class ConnectionTest(unittest.TestCase): alias="test1", host="mongodb://username2:password@localhost/mongoenginetest", ) - self.assertRaises(OperationFailure, test_conn.server_info) + with pytest.raises(OperationFailure): + test_conn.server_info() # Authentication succeeds with "authSource" authd_conn = connect( @@ -566,8 +559,8 @@ class ConnectionTest(unittest.TestCase): ), ) db = get_db("test2") - self.assertIsInstance(db, pymongo.database.Database) - self.assertEqual(db.name, "mongoenginetest") + assert isinstance(db, pymongo.database.Database) + assert db.name == "mongoenginetest" # Clear all users authd_conn.admin.system.users.delete_many({}) @@ -577,13 +570,14 @@ class ConnectionTest(unittest.TestCase): """ register_connection("testdb", "mongoenginetest2") - self.assertRaises(ConnectionFailure, get_connection) + with pytest.raises(ConnectionFailure): + get_connection() conn = get_connection("testdb") - self.assertIsInstance(conn, pymongo.mongo_client.MongoClient) + assert isinstance(conn, pymongo.mongo_client.MongoClient) db = get_db("testdb") - self.assertIsInstance(db, pymongo.database.Database) - self.assertEqual(db.name, "mongoenginetest2") + assert isinstance(db, pymongo.database.Database) + assert db.name == "mongoenginetest2" def test_register_connection_defaults(self): """Ensure that defaults are used when the host and port are None. @@ -591,18 +585,18 @@ class ConnectionTest(unittest.TestCase): register_connection("testdb", "mongoenginetest", host=None, port=None) conn = get_connection("testdb") - self.assertIsInstance(conn, pymongo.mongo_client.MongoClient) + assert isinstance(conn, pymongo.mongo_client.MongoClient) def test_connection_kwargs(self): """Ensure that connection kwargs get passed to pymongo.""" connect("mongoenginetest", alias="t1", tz_aware=True) conn = get_connection("t1") - self.assertTrue(get_tz_awareness(conn)) + assert get_tz_awareness(conn) connect("mongoenginetest2", alias="t2") conn = get_connection("t2") - self.assertFalse(get_tz_awareness(conn)) + assert not get_tz_awareness(conn) def test_connection_pool_via_kwarg(self): """Ensure we can specify a max connection pool size using @@ -613,7 +607,7 @@ class ConnectionTest(unittest.TestCase): conn = connect( "mongoenginetest", alias="max_pool_size_via_kwarg", **pool_size_kwargs ) - self.assertEqual(conn.max_pool_size, 100) + assert conn.max_pool_size == 100 def test_connection_pool_via_uri(self): """Ensure we can specify a max connection pool size using @@ -623,7 +617,7 @@ class ConnectionTest(unittest.TestCase): host="mongodb://localhost/test?maxpoolsize=100", alias="max_pool_size_via_uri", ) - self.assertEqual(conn.max_pool_size, 100) + assert conn.max_pool_size == 100 def test_write_concern(self): """Ensure write concern can be specified in connect() via @@ -642,18 +636,18 @@ class ConnectionTest(unittest.TestCase): """ c = connect(host="mongodb://localhost/test?replicaSet=local-rs") db = get_db() - self.assertIsInstance(db, pymongo.database.Database) - self.assertEqual(db.name, "test") + assert isinstance(db, pymongo.database.Database) + assert db.name == "test" def test_connect_with_replicaset_via_kwargs(self): """Ensure connect() works when specifying a replicaSet via the connection kwargs """ c = connect(replicaset="local-rs") - self.assertEqual(c._MongoClient__options.replica_set_name, "local-rs") + assert c._MongoClient__options.replica_set_name == "local-rs" db = get_db() - self.assertIsInstance(db, pymongo.database.Database) - self.assertEqual(db.name, "test") + assert isinstance(db, pymongo.database.Database) + assert db.name == "test" def test_connect_tz_aware(self): connect("mongoenginetest", tz_aware=True) @@ -666,13 +660,13 @@ class ConnectionTest(unittest.TestCase): DateDoc(the_date=d).save() date_doc = DateDoc.objects.first() - self.assertEqual(d, date_doc.the_date) + assert d == date_doc.the_date def test_read_preference_from_parse(self): conn = connect( host="mongodb://a1.vpc,a2.vpc,a3.vpc/prod?readPreference=secondaryPreferred" ) - self.assertEqual(conn.read_preference, ReadPreference.SECONDARY_PREFERRED) + assert conn.read_preference == ReadPreference.SECONDARY_PREFERRED def test_multiple_connection_settings(self): connect("mongoenginetest", alias="t1", host="localhost") @@ -680,27 +674,27 @@ class ConnectionTest(unittest.TestCase): connect("mongoenginetest2", alias="t2", host="127.0.0.1") mongo_connections = mongoengine.connection._connections - self.assertEqual(len(mongo_connections.items()), 2) - self.assertIn("t1", mongo_connections.keys()) - self.assertIn("t2", mongo_connections.keys()) + assert len(mongo_connections.items()) == 2 + assert "t1" in mongo_connections.keys() + assert "t2" in mongo_connections.keys() # Handle PyMongo 3+ Async Connection # Ensure we are connected, throws ServerSelectionTimeoutError otherwise. # Purposely not catching exception to fail test if thrown. mongo_connections["t1"].server_info() mongo_connections["t2"].server_info() - self.assertEqual(mongo_connections["t1"].address[0], "localhost") - self.assertEqual(mongo_connections["t2"].address[0], "127.0.0.1") + assert mongo_connections["t1"].address[0] == "localhost" + assert mongo_connections["t2"].address[0] == "127.0.0.1" def test_connect_2_databases_uses_same_client_if_only_dbname_differs(self): c1 = connect(alias="testdb1", db="testdb1") c2 = connect(alias="testdb2", db="testdb2") - self.assertIs(c1, c2) + assert c1 is c2 def test_connect_2_databases_uses_different_client_if_different_parameters(self): c1 = connect(alias="testdb1", db="testdb1", username="u1") c2 = connect(alias="testdb2", db="testdb2", username="u2") - self.assertIsNot(c1, c2) + assert c1 is not c2 if __name__ == "__main__": diff --git a/tests/test_context_managers.py b/tests/test_context_managers.py index 32e48a70..cf4dd100 100644 --- a/tests/test_context_managers.py +++ b/tests/test_context_managers.py @@ -10,6 +10,7 @@ from mongoengine.context_managers import ( switch_db, ) from mongoengine.pymongo_support import count_documents +import pytest class ContextManagersTest(unittest.TestCase): @@ -23,20 +24,20 @@ class ContextManagersTest(unittest.TestCase): Group.drop_collection() Group(name="hello - default").save() - self.assertEqual(1, Group.objects.count()) + assert 1 == Group.objects.count() with switch_db(Group, "testdb-1") as Group: - self.assertEqual(0, Group.objects.count()) + assert 0 == Group.objects.count() Group(name="hello").save() - self.assertEqual(1, Group.objects.count()) + assert 1 == Group.objects.count() Group.drop_collection() - self.assertEqual(0, Group.objects.count()) + assert 0 == Group.objects.count() - self.assertEqual(1, Group.objects.count()) + assert 1 == Group.objects.count() def test_switch_collection_context_manager(self): connect("mongoenginetest") @@ -51,20 +52,20 @@ class ContextManagersTest(unittest.TestCase): Group.drop_collection() # drops in group1 Group(name="hello - group").save() - self.assertEqual(1, Group.objects.count()) + assert 1 == Group.objects.count() with switch_collection(Group, "group1") as Group: - self.assertEqual(0, Group.objects.count()) + assert 0 == Group.objects.count() Group(name="hello - group1").save() - self.assertEqual(1, Group.objects.count()) + assert 1 == Group.objects.count() Group.drop_collection() - self.assertEqual(0, Group.objects.count()) + assert 0 == Group.objects.count() - self.assertEqual(1, Group.objects.count()) + assert 1 == Group.objects.count() def test_no_dereference_context_manager_object_id(self): """Ensure that DBRef items in ListFields aren't dereferenced. @@ -89,20 +90,20 @@ class ContextManagersTest(unittest.TestCase): Group(ref=user, members=User.objects, generic=user).save() with no_dereference(Group) as NoDeRefGroup: - self.assertTrue(Group._fields["members"]._auto_dereference) - self.assertFalse(NoDeRefGroup._fields["members"]._auto_dereference) + assert Group._fields["members"]._auto_dereference + assert not NoDeRefGroup._fields["members"]._auto_dereference with no_dereference(Group) as Group: group = Group.objects.first() for m in group.members: - self.assertNotIsInstance(m, User) - self.assertNotIsInstance(group.ref, User) - self.assertNotIsInstance(group.generic, User) + assert not isinstance(m, User) + assert not isinstance(group.ref, User) + assert not isinstance(group.generic, User) for m in group.members: - self.assertIsInstance(m, User) - self.assertIsInstance(group.ref, User) - self.assertIsInstance(group.generic, User) + assert isinstance(m, User) + assert isinstance(group.ref, User) + assert isinstance(group.generic, User) def test_no_dereference_context_manager_dbref(self): """Ensure that DBRef items in ListFields aren't dereferenced. @@ -127,18 +128,18 @@ class ContextManagersTest(unittest.TestCase): Group(ref=user, members=User.objects, generic=user).save() with no_dereference(Group) as NoDeRefGroup: - self.assertTrue(Group._fields["members"]._auto_dereference) - self.assertFalse(NoDeRefGroup._fields["members"]._auto_dereference) + assert Group._fields["members"]._auto_dereference + assert not NoDeRefGroup._fields["members"]._auto_dereference with no_dereference(Group) as Group: group = Group.objects.first() - self.assertTrue(all([not isinstance(m, User) for m in group.members])) - self.assertNotIsInstance(group.ref, User) - self.assertNotIsInstance(group.generic, User) + assert all([not isinstance(m, User) for m in group.members]) + assert not isinstance(group.ref, User) + assert not isinstance(group.generic, User) - self.assertTrue(all([isinstance(m, User) for m in group.members])) - self.assertIsInstance(group.ref, User) - self.assertIsInstance(group.generic, User) + assert all([isinstance(m, User) for m in group.members]) + assert isinstance(group.ref, User) + assert isinstance(group.generic, User) def test_no_sub_classes(self): class A(Document): @@ -159,32 +160,32 @@ class ContextManagersTest(unittest.TestCase): B(x=30).save() C(x=40).save() - self.assertEqual(A.objects.count(), 5) - self.assertEqual(B.objects.count(), 3) - self.assertEqual(C.objects.count(), 1) + assert A.objects.count() == 5 + assert B.objects.count() == 3 + assert C.objects.count() == 1 with no_sub_classes(A): - self.assertEqual(A.objects.count(), 2) + assert A.objects.count() == 2 for obj in A.objects: - self.assertEqual(obj.__class__, A) + assert obj.__class__ == A with no_sub_classes(B): - self.assertEqual(B.objects.count(), 2) + assert B.objects.count() == 2 for obj in B.objects: - self.assertEqual(obj.__class__, B) + assert obj.__class__ == B with no_sub_classes(C): - self.assertEqual(C.objects.count(), 1) + assert C.objects.count() == 1 for obj in C.objects: - self.assertEqual(obj.__class__, C) + assert obj.__class__ == C # Confirm context manager exit correctly - self.assertEqual(A.objects.count(), 5) - self.assertEqual(B.objects.count(), 3) - self.assertEqual(C.objects.count(), 1) + assert A.objects.count() == 5 + assert B.objects.count() == 3 + assert C.objects.count() == 1 def test_no_sub_classes_modification_to_document_class_are_temporary(self): class A(Document): @@ -194,27 +195,27 @@ class ContextManagersTest(unittest.TestCase): class B(A): z = IntField() - self.assertEqual(A._subclasses, ("A", "A.B")) + assert A._subclasses == ("A", "A.B") with no_sub_classes(A): - self.assertEqual(A._subclasses, ("A",)) - self.assertEqual(A._subclasses, ("A", "A.B")) + assert A._subclasses == ("A",) + assert A._subclasses == ("A", "A.B") - self.assertEqual(B._subclasses, ("A.B",)) + assert B._subclasses == ("A.B",) with no_sub_classes(B): - self.assertEqual(B._subclasses, ("A.B",)) - self.assertEqual(B._subclasses, ("A.B",)) + assert B._subclasses == ("A.B",) + assert B._subclasses == ("A.B",) def test_no_subclass_context_manager_does_not_swallow_exception(self): class User(Document): name = StringField() - with self.assertRaises(TypeError): + with pytest.raises(TypeError): with no_sub_classes(User): raise TypeError() def test_query_counter_does_not_swallow_exception(self): - with self.assertRaises(TypeError): + with pytest.raises(TypeError): with query_counter() as q: raise TypeError() @@ -227,10 +228,10 @@ class ContextManagersTest(unittest.TestCase): try: NEW_LEVEL = 1 db.set_profiling_level(NEW_LEVEL) - self.assertEqual(db.profiling_level(), NEW_LEVEL) + assert db.profiling_level() == NEW_LEVEL with query_counter() as q: - self.assertEqual(db.profiling_level(), 2) - self.assertEqual(db.profiling_level(), NEW_LEVEL) + assert db.profiling_level() == 2 + assert db.profiling_level() == NEW_LEVEL except Exception: db.set_profiling_level( initial_profiling_level @@ -255,33 +256,31 @@ class ContextManagersTest(unittest.TestCase): counter = 0 with query_counter() as q: - self.assertEqual(q, counter) - self.assertEqual( - q, counter - ) # Ensures previous count query did not get counted + assert q == counter + assert q == counter # Ensures previous count query did not get counted for _ in range(10): issue_1_insert_query() counter += 1 - self.assertEqual(q, counter) + assert q == counter for _ in range(4): issue_1_find_query() counter += 1 - self.assertEqual(q, counter) + assert q == counter for _ in range(3): issue_1_count_query() counter += 1 - self.assertEqual(q, counter) + assert q == counter - self.assertEqual(int(q), counter) # test __int__ - self.assertEqual(repr(q), str(int(q))) # test __repr__ - self.assertGreater(q, -1) # test __gt__ - self.assertGreaterEqual(q, int(q)) # test __gte__ - self.assertNotEqual(q, -1) - self.assertLess(q, 1000) - self.assertLessEqual(q, int(q)) + assert int(q) == counter # test __int__ + assert repr(q) == str(int(q)) # test __repr__ + assert q > -1 # test __gt__ + assert q >= int(q) # test __gte__ + assert q != -1 + assert q < 1000 + assert q <= int(q) def test_query_counter_counts_getmore_queries(self): connect("mongoenginetest") @@ -296,9 +295,9 @@ class ContextManagersTest(unittest.TestCase): ) # first batch of documents contains 101 documents with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 list(collection.find()) - self.assertEqual(q, 2) # 1st select + 1 getmore + assert q == 2 # 1st select + 1 getmore def test_query_counter_ignores_particular_queries(self): connect("mongoenginetest") @@ -308,18 +307,18 @@ class ContextManagersTest(unittest.TestCase): collection.insert_many([{"test": "garbage %s" % i} for i in range(10)]) with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 cursor = collection.find() - self.assertEqual(q, 0) # cursor wasn't opened yet + assert q == 0 # cursor wasn't opened yet _ = next(cursor) # opens the cursor and fires the find query - self.assertEqual(q, 1) + assert q == 1 cursor.close() # issues a `killcursors` query that is ignored by the context - self.assertEqual(q, 1) + assert q == 1 _ = ( db.system.indexes.find_one() ) # queries on db.system.indexes are ignored as well - self.assertEqual(q, 1) + assert q == 1 if __name__ == "__main__": diff --git a/tests/test_datastructures.py b/tests/test_datastructures.py index ff7598be..3a6029c1 100644 --- a/tests/test_datastructures.py +++ b/tests/test_datastructures.py @@ -1,6 +1,8 @@ import unittest from six import iterkeys +import pytest + from mongoengine import Document from mongoengine.base.datastructures import BaseDict, BaseList, StrictDict @@ -31,48 +33,48 @@ class TestBaseDict(unittest.TestCase): dict_items = {"k": "v"} doc = MyDoc() base_dict = BaseDict(dict_items, instance=doc, name="my_name") - self.assertIsInstance(base_dict._instance, Document) - self.assertEqual(base_dict._name, "my_name") - self.assertEqual(base_dict, dict_items) + assert isinstance(base_dict._instance, Document) + assert base_dict._name == "my_name" + assert base_dict == dict_items def test_setdefault_calls_mark_as_changed(self): base_dict = self._get_basedict({}) base_dict.setdefault("k", "v") - self.assertEqual(base_dict._instance._changed_fields, [base_dict._name]) + assert base_dict._instance._changed_fields == [base_dict._name] def test_popitems_calls_mark_as_changed(self): base_dict = self._get_basedict({"k": "v"}) - self.assertEqual(base_dict.popitem(), ("k", "v")) - self.assertEqual(base_dict._instance._changed_fields, [base_dict._name]) - self.assertFalse(base_dict) + assert base_dict.popitem() == ("k", "v") + assert base_dict._instance._changed_fields == [base_dict._name] + assert not base_dict def test_pop_calls_mark_as_changed(self): base_dict = self._get_basedict({"k": "v"}) - self.assertEqual(base_dict.pop("k"), "v") - self.assertEqual(base_dict._instance._changed_fields, [base_dict._name]) - self.assertFalse(base_dict) + assert base_dict.pop("k") == "v" + assert base_dict._instance._changed_fields == [base_dict._name] + assert not base_dict def test_pop_calls_does_not_mark_as_changed_when_it_fails(self): base_dict = self._get_basedict({"k": "v"}) - with self.assertRaises(KeyError): + with pytest.raises(KeyError): base_dict.pop("X") - self.assertFalse(base_dict._instance._changed_fields) + assert not base_dict._instance._changed_fields def test_clear_calls_mark_as_changed(self): base_dict = self._get_basedict({"k": "v"}) base_dict.clear() - self.assertEqual(base_dict._instance._changed_fields, ["my_name"]) - self.assertEqual(base_dict, {}) + assert base_dict._instance._changed_fields == ["my_name"] + assert base_dict == {} def test___delitem___calls_mark_as_changed(self): base_dict = self._get_basedict({"k": "v"}) del base_dict["k"] - self.assertEqual(base_dict._instance._changed_fields, ["my_name.k"]) - self.assertEqual(base_dict, {}) + assert base_dict._instance._changed_fields == ["my_name.k"] + assert base_dict == {} def test___getitem____KeyError(self): base_dict = self._get_basedict({}) - with self.assertRaises(KeyError): + with pytest.raises(KeyError): base_dict["new"] def test___getitem____simple_value(self): @@ -82,62 +84,62 @@ class TestBaseDict(unittest.TestCase): def test___getitem____sublist_gets_converted_to_BaseList(self): base_dict = self._get_basedict({"k": [0, 1, 2]}) sub_list = base_dict["k"] - self.assertEqual(sub_list, [0, 1, 2]) - self.assertIsInstance(sub_list, BaseList) - self.assertIs(sub_list._instance, base_dict._instance) - self.assertEqual(sub_list._name, "my_name.k") - self.assertEqual(base_dict._instance._changed_fields, []) + assert sub_list == [0, 1, 2] + assert isinstance(sub_list, BaseList) + assert sub_list._instance is base_dict._instance + assert sub_list._name == "my_name.k" + assert base_dict._instance._changed_fields == [] # Challenge mark_as_changed from sublist sub_list[1] = None - self.assertEqual(base_dict._instance._changed_fields, ["my_name.k.1"]) + assert base_dict._instance._changed_fields == ["my_name.k.1"] def test___getitem____subdict_gets_converted_to_BaseDict(self): base_dict = self._get_basedict({"k": {"subk": "subv"}}) sub_dict = base_dict["k"] - self.assertEqual(sub_dict, {"subk": "subv"}) - self.assertIsInstance(sub_dict, BaseDict) - self.assertIs(sub_dict._instance, base_dict._instance) - self.assertEqual(sub_dict._name, "my_name.k") - self.assertEqual(base_dict._instance._changed_fields, []) + assert sub_dict == {"subk": "subv"} + assert isinstance(sub_dict, BaseDict) + assert sub_dict._instance is base_dict._instance + assert sub_dict._name == "my_name.k" + assert base_dict._instance._changed_fields == [] # Challenge mark_as_changed from subdict sub_dict["subk"] = None - self.assertEqual(base_dict._instance._changed_fields, ["my_name.k.subk"]) + assert base_dict._instance._changed_fields == ["my_name.k.subk"] def test_get_sublist_gets_converted_to_BaseList_just_like__getitem__(self): base_dict = self._get_basedict({"k": [0, 1, 2]}) sub_list = base_dict.get("k") - self.assertEqual(sub_list, [0, 1, 2]) - self.assertIsInstance(sub_list, BaseList) + assert sub_list == [0, 1, 2] + assert isinstance(sub_list, BaseList) def test_get_returns_the_same_as___getitem__(self): base_dict = self._get_basedict({"k": [0, 1, 2]}) get_ = base_dict.get("k") getitem_ = base_dict["k"] - self.assertEqual(get_, getitem_) + assert get_ == getitem_ def test_get_default(self): base_dict = self._get_basedict({}) sentinel = object() - self.assertEqual(base_dict.get("new"), None) - self.assertIs(base_dict.get("new", sentinel), sentinel) + assert base_dict.get("new") == None + assert base_dict.get("new", sentinel) is sentinel def test___setitem___calls_mark_as_changed(self): base_dict = self._get_basedict({}) base_dict["k"] = "v" - self.assertEqual(base_dict._instance._changed_fields, ["my_name.k"]) - self.assertEqual(base_dict, {"k": "v"}) + assert base_dict._instance._changed_fields == ["my_name.k"] + assert base_dict == {"k": "v"} def test_update_calls_mark_as_changed(self): base_dict = self._get_basedict({}) base_dict.update({"k": "v"}) - self.assertEqual(base_dict._instance._changed_fields, ["my_name"]) + assert base_dict._instance._changed_fields == ["my_name"] def test___setattr____not_tracked_by_changes(self): base_dict = self._get_basedict({}) base_dict.a_new_attr = "test" - self.assertEqual(base_dict._instance._changed_fields, []) + assert base_dict._instance._changed_fields == [] def test___delattr____tracked_by_changes(self): # This is probably a bug as __setattr__ is not tracked @@ -146,7 +148,7 @@ class TestBaseDict(unittest.TestCase): base_dict = self._get_basedict({}) base_dict.a_new_attr = "test" del base_dict.a_new_attr - self.assertEqual(base_dict._instance._changed_fields, ["my_name.a_new_attr"]) + assert base_dict._instance._changed_fields == ["my_name.a_new_attr"] class TestBaseList(unittest.TestCase): @@ -167,14 +169,14 @@ class TestBaseList(unittest.TestCase): list_items = [True] doc = MyDoc() base_list = BaseList(list_items, instance=doc, name="my_name") - self.assertIsInstance(base_list._instance, Document) - self.assertEqual(base_list._name, "my_name") - self.assertEqual(base_list, list_items) + assert isinstance(base_list._instance, Document) + assert base_list._name == "my_name" + assert base_list == list_items def test___iter__(self): values = [True, False, True, False] base_list = BaseList(values, instance=None, name="my_name") - self.assertEqual(values, list(base_list)) + assert values == list(base_list) def test___iter___allow_modification_while_iterating_withou_error(self): # regular list allows for this, thus this subclass must comply to that @@ -185,9 +187,9 @@ class TestBaseList(unittest.TestCase): def test_append_calls_mark_as_changed(self): base_list = self._get_baselist([]) - self.assertFalse(base_list._instance._changed_fields) + assert not base_list._instance._changed_fields base_list.append(True) - self.assertEqual(base_list._instance._changed_fields, ["my_name"]) + assert base_list._instance._changed_fields == ["my_name"] def test_subclass_append(self): # Due to the way mark_as_changed_wrapper is implemented @@ -200,85 +202,85 @@ class TestBaseList(unittest.TestCase): def test___getitem__using_simple_index(self): base_list = self._get_baselist([0, 1, 2]) - self.assertEqual(base_list[0], 0) - self.assertEqual(base_list[1], 1) - self.assertEqual(base_list[-1], 2) + assert base_list[0] == 0 + assert base_list[1] == 1 + assert base_list[-1] == 2 def test___getitem__using_slice(self): base_list = self._get_baselist([0, 1, 2]) - self.assertEqual(base_list[1:3], [1, 2]) - self.assertEqual(base_list[0:3:2], [0, 2]) + assert base_list[1:3] == [1, 2] + assert base_list[0:3:2] == [0, 2] def test___getitem___using_slice_returns_list(self): # Bug: using slice does not properly handles the instance # and mark_as_changed behaviour. base_list = self._get_baselist([0, 1, 2]) sliced = base_list[1:3] - self.assertEqual(sliced, [1, 2]) - self.assertIsInstance(sliced, list) - self.assertEqual(base_list._instance._changed_fields, []) + assert sliced == [1, 2] + assert isinstance(sliced, list) + assert base_list._instance._changed_fields == [] def test___getitem__sublist_returns_BaseList_bound_to_instance(self): base_list = self._get_baselist([[1, 2], [3, 4]]) sub_list = base_list[0] - self.assertEqual(sub_list, [1, 2]) - self.assertIsInstance(sub_list, BaseList) - self.assertIs(sub_list._instance, base_list._instance) - self.assertEqual(sub_list._name, "my_name.0") - self.assertEqual(base_list._instance._changed_fields, []) + assert sub_list == [1, 2] + assert isinstance(sub_list, BaseList) + assert sub_list._instance is base_list._instance + assert sub_list._name == "my_name.0" + assert base_list._instance._changed_fields == [] # Challenge mark_as_changed from sublist sub_list[1] = None - self.assertEqual(base_list._instance._changed_fields, ["my_name.0.1"]) + assert base_list._instance._changed_fields == ["my_name.0.1"] def test___getitem__subdict_returns_BaseList_bound_to_instance(self): base_list = self._get_baselist([{"subk": "subv"}]) sub_dict = base_list[0] - self.assertEqual(sub_dict, {"subk": "subv"}) - self.assertIsInstance(sub_dict, BaseDict) - self.assertIs(sub_dict._instance, base_list._instance) - self.assertEqual(sub_dict._name, "my_name.0") - self.assertEqual(base_list._instance._changed_fields, []) + assert sub_dict == {"subk": "subv"} + assert isinstance(sub_dict, BaseDict) + assert sub_dict._instance is base_list._instance + assert sub_dict._name == "my_name.0" + assert base_list._instance._changed_fields == [] # Challenge mark_as_changed from subdict sub_dict["subk"] = None - self.assertEqual(base_list._instance._changed_fields, ["my_name.0.subk"]) + assert base_list._instance._changed_fields == ["my_name.0.subk"] def test_extend_calls_mark_as_changed(self): base_list = self._get_baselist([]) base_list.extend([True]) - self.assertEqual(base_list._instance._changed_fields, ["my_name"]) + assert base_list._instance._changed_fields == ["my_name"] def test_insert_calls_mark_as_changed(self): base_list = self._get_baselist([]) base_list.insert(0, True) - self.assertEqual(base_list._instance._changed_fields, ["my_name"]) + assert base_list._instance._changed_fields == ["my_name"] def test_remove_calls_mark_as_changed(self): base_list = self._get_baselist([True]) base_list.remove(True) - self.assertEqual(base_list._instance._changed_fields, ["my_name"]) + assert base_list._instance._changed_fields == ["my_name"] def test_remove_not_mark_as_changed_when_it_fails(self): base_list = self._get_baselist([True]) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): base_list.remove(False) - self.assertFalse(base_list._instance._changed_fields) + assert not base_list._instance._changed_fields def test_pop_calls_mark_as_changed(self): base_list = self._get_baselist([True]) base_list.pop() - self.assertEqual(base_list._instance._changed_fields, ["my_name"]) + assert base_list._instance._changed_fields == ["my_name"] def test_reverse_calls_mark_as_changed(self): base_list = self._get_baselist([True, False]) base_list.reverse() - self.assertEqual(base_list._instance._changed_fields, ["my_name"]) + assert base_list._instance._changed_fields == ["my_name"] def test___delitem___calls_mark_as_changed(self): base_list = self._get_baselist([True]) del base_list[0] - self.assertEqual(base_list._instance._changed_fields, ["my_name"]) + assert base_list._instance._changed_fields == ["my_name"] def test___setitem___calls_with_full_slice_mark_as_changed(self): base_list = self._get_baselist([]) @@ -286,8 +288,8 @@ class TestBaseList(unittest.TestCase): 0, 1, ] # Will use __setslice__ under py2 and __setitem__ under py3 - self.assertEqual(base_list._instance._changed_fields, ["my_name"]) - self.assertEqual(base_list, [0, 1]) + assert base_list._instance._changed_fields == ["my_name"] + assert base_list == [0, 1] def test___setitem___calls_with_partial_slice_mark_as_changed(self): base_list = self._get_baselist([0, 1, 2]) @@ -295,66 +297,66 @@ class TestBaseList(unittest.TestCase): 1, 0, ] # Will use __setslice__ under py2 and __setitem__ under py3 - self.assertEqual(base_list._instance._changed_fields, ["my_name"]) - self.assertEqual(base_list, [1, 0, 2]) + assert base_list._instance._changed_fields == ["my_name"] + assert base_list == [1, 0, 2] def test___setitem___calls_with_step_slice_mark_as_changed(self): base_list = self._get_baselist([0, 1, 2]) base_list[0:3:2] = [-1, -2] # uses __setitem__ in both py2 & 3 - self.assertEqual(base_list._instance._changed_fields, ["my_name"]) - self.assertEqual(base_list, [-1, 1, -2]) + assert base_list._instance._changed_fields == ["my_name"] + assert base_list == [-1, 1, -2] def test___setitem___with_slice(self): base_list = self._get_baselist([0, 1, 2, 3, 4, 5]) base_list[0:6:2] = [None, None, None] - self.assertEqual(base_list._instance._changed_fields, ["my_name"]) - self.assertEqual(base_list, [None, 1, None, 3, None, 5]) + assert base_list._instance._changed_fields == ["my_name"] + assert base_list == [None, 1, None, 3, None, 5] def test___setitem___item_0_calls_mark_as_changed(self): base_list = self._get_baselist([True]) base_list[0] = False - self.assertEqual(base_list._instance._changed_fields, ["my_name"]) - self.assertEqual(base_list, [False]) + assert base_list._instance._changed_fields == ["my_name"] + assert base_list == [False] def test___setitem___item_1_calls_mark_as_changed(self): base_list = self._get_baselist([True, True]) base_list[1] = False - self.assertEqual(base_list._instance._changed_fields, ["my_name.1"]) - self.assertEqual(base_list, [True, False]) + assert base_list._instance._changed_fields == ["my_name.1"] + assert base_list == [True, False] def test___delslice___calls_mark_as_changed(self): base_list = self._get_baselist([0, 1]) del base_list[0:1] - self.assertEqual(base_list._instance._changed_fields, ["my_name"]) - self.assertEqual(base_list, [1]) + assert base_list._instance._changed_fields == ["my_name"] + assert base_list == [1] def test___iadd___calls_mark_as_changed(self): base_list = self._get_baselist([True]) base_list += [False] - self.assertEqual(base_list._instance._changed_fields, ["my_name"]) + assert base_list._instance._changed_fields == ["my_name"] def test___imul___calls_mark_as_changed(self): base_list = self._get_baselist([True]) - self.assertEqual(base_list._instance._changed_fields, []) + assert base_list._instance._changed_fields == [] base_list *= 2 - self.assertEqual(base_list._instance._changed_fields, ["my_name"]) + assert base_list._instance._changed_fields == ["my_name"] def test_sort_calls_not_marked_as_changed_when_it_fails(self): base_list = self._get_baselist([True]) - with self.assertRaises(TypeError): + with pytest.raises(TypeError): base_list.sort(key=1) - self.assertEqual(base_list._instance._changed_fields, []) + assert base_list._instance._changed_fields == [] def test_sort_calls_mark_as_changed(self): base_list = self._get_baselist([True, False]) base_list.sort() - self.assertEqual(base_list._instance._changed_fields, ["my_name"]) + assert base_list._instance._changed_fields == ["my_name"] def test_sort_calls_with_key(self): base_list = self._get_baselist([1, 2, 11]) base_list.sort(key=lambda i: str(i)) - self.assertEqual(base_list, [1, 11, 2]) + assert base_list == [1, 11, 2] class TestStrictDict(unittest.TestCase): @@ -366,32 +368,32 @@ class TestStrictDict(unittest.TestCase): def test_init(self): d = self.dtype(a=1, b=1, c=1) - self.assertEqual((d.a, d.b, d.c), (1, 1, 1)) + assert (d.a, d.b, d.c) == (1, 1, 1) def test_iterkeys(self): d = self.dtype(a=1) - self.assertEqual(list(iterkeys(d)), ["a"]) + assert list(iterkeys(d)) == ["a"] def test_len(self): d = self.dtype(a=1) - self.assertEqual(len(d), 1) + assert len(d) == 1 def test_pop(self): d = self.dtype(a=1) - self.assertIn("a", d) + assert "a" in d d.pop("a") - self.assertNotIn("a", d) + assert "a" not in d def test_repr(self): d = self.dtype(a=1, b=2, c=3) - self.assertEqual(repr(d), '{"a": 1, "b": 2, "c": 3}') + assert repr(d) == '{"a": 1, "b": 2, "c": 3}' # make sure quotes are escaped properly d = self.dtype(a='"', b="'", c="") - self.assertEqual(repr(d), '{"a": \'"\', "b": "\'", "c": \'\'}') + assert repr(d) == '{"a": \'"\', "b": "\'", "c": \'\'}' def test_init_fails_on_nonexisting_attrs(self): - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): self.dtype(a=1, b=2, d=3) def test_eq(self): @@ -403,45 +405,46 @@ class TestStrictDict(unittest.TestCase): h = self.strict_dict_class(("a", "c", "b"))(a=1, b=1, c=1) i = self.strict_dict_class(("a", "c", "b"))(a=1, b=1, c=2) - self.assertEqual(d, dd) - self.assertNotEqual(d, e) - self.assertNotEqual(d, f) - self.assertNotEqual(d, g) - self.assertNotEqual(f, d) - self.assertEqual(d, h) - self.assertNotEqual(d, i) + assert d == dd + assert d != e + assert d != f + assert d != g + assert f != d + assert d == h + assert d != i def test_setattr_getattr(self): d = self.dtype() d.a = 1 - self.assertEqual(d.a, 1) - self.assertRaises(AttributeError, getattr, d, "b") + assert d.a == 1 + with pytest.raises(AttributeError): + getattr(d, "b") def test_setattr_raises_on_nonexisting_attr(self): d = self.dtype() - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): d.x = 1 def test_setattr_getattr_special(self): d = self.strict_dict_class(["items"]) d.items = 1 - self.assertEqual(d.items, 1) + assert d.items == 1 def test_get(self): d = self.dtype(a=1) - self.assertEqual(d.get("a"), 1) - self.assertEqual(d.get("b", "bla"), "bla") + assert d.get("a") == 1 + assert d.get("b", "bla") == "bla" def test_items(self): d = self.dtype(a=1) - self.assertEqual(d.items(), [("a", 1)]) + assert d.items() == [("a", 1)] d = self.dtype(a=1, b=2) - self.assertEqual(d.items(), [("a", 1), ("b", 2)]) + assert d.items() == [("a", 1), ("b", 2)] def test_mappings_protocol(self): d = self.dtype(a=1, b=2) - self.assertEqual(dict(d), {"a": 1, "b": 2}) - self.assertEqual(dict(**d), {"a": 1, "b": 2}) + assert dict(d) == {"a": 1, "b": 2} + assert dict(**d) == {"a": 1, "b": 2} if __name__ == "__main__": diff --git a/tests/test_dereference.py b/tests/test_dereference.py index 4730e2e3..b9d92883 100644 --- a/tests/test_dereference.py +++ b/tests/test_dereference.py @@ -42,37 +42,37 @@ class FieldTest(unittest.TestCase): group.save() with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 group_obj = Group.objects.first() - self.assertEqual(q, 1) + assert q == 1 len(group_obj._data["members"]) - self.assertEqual(q, 1) + assert q == 1 len(group_obj.members) - self.assertEqual(q, 2) + assert q == 2 [m for m in group_obj.members] - self.assertEqual(q, 2) + assert q == 2 # Document select_related with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 group_obj = Group.objects.first().select_related() - self.assertEqual(q, 2) + assert q == 2 [m for m in group_obj.members] - self.assertEqual(q, 2) + assert q == 2 # Queryset select_related with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 group_objs = Group.objects.select_related() - self.assertEqual(q, 2) + assert q == 2 for group_obj in group_objs: [m for m in group_obj.members] - self.assertEqual(q, 2) + assert q == 2 User.drop_collection() Group.drop_collection() @@ -99,40 +99,40 @@ class FieldTest(unittest.TestCase): group.reload() # Confirm reload works with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 group_obj = Group.objects.first() - self.assertEqual(q, 1) + assert q == 1 [m for m in group_obj.members] - self.assertEqual(q, 2) - self.assertTrue(group_obj._data["members"]._dereferenced) + assert q == 2 + assert group_obj._data["members"]._dereferenced # verifies that no additional queries gets executed # if we re-iterate over the ListField once it is # dereferenced [m for m in group_obj.members] - self.assertEqual(q, 2) - self.assertTrue(group_obj._data["members"]._dereferenced) + assert q == 2 + assert group_obj._data["members"]._dereferenced # Document select_related with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 group_obj = Group.objects.first().select_related() - self.assertEqual(q, 2) + assert q == 2 [m for m in group_obj.members] - self.assertEqual(q, 2) + assert q == 2 # Queryset select_related with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 group_objs = Group.objects.select_related() - self.assertEqual(q, 2) + assert q == 2 for group_obj in group_objs: [m for m in group_obj.members] - self.assertEqual(q, 2) + assert q == 2 def test_list_item_dereference_orphan_dbref(self): """Ensure that orphan DBRef items in ListFields are dereferenced. @@ -159,21 +159,21 @@ class FieldTest(unittest.TestCase): # Group.members list is an orphan DBRef User.objects[0].delete() with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 group_obj = Group.objects.first() - self.assertEqual(q, 1) + assert q == 1 [m for m in group_obj.members] - self.assertEqual(q, 2) - self.assertTrue(group_obj._data["members"]._dereferenced) + assert q == 2 + assert group_obj._data["members"]._dereferenced # verifies that no additional queries gets executed # if we re-iterate over the ListField once it is # dereferenced [m for m in group_obj.members] - self.assertEqual(q, 2) - self.assertTrue(group_obj._data["members"]._dereferenced) + assert q == 2 + assert group_obj._data["members"]._dereferenced User.drop_collection() Group.drop_collection() @@ -197,8 +197,8 @@ class FieldTest(unittest.TestCase): Group(members=User.objects).save() group = Group.objects.first() - self.assertEqual(Group._get_collection().find_one()["members"], [1]) - self.assertEqual(group.members, [user]) + assert Group._get_collection().find_one()["members"] == [1] + assert group.members == [user] def test_handle_old_style_references(self): """Ensure that DBRef items in ListFields are dereferenced. @@ -231,8 +231,8 @@ class FieldTest(unittest.TestCase): group.save() group = Group.objects.first() - self.assertEqual(group.members[0].name, "user 1") - self.assertEqual(group.members[-1].name, "String!") + assert group.members[0].name == "user 1" + assert group.members[-1].name == "String!" def test_migrate_references(self): """Example of migrating ReferenceField storage @@ -253,12 +253,12 @@ class FieldTest(unittest.TestCase): group = Group(author=user, members=[user]).save() raw_data = Group._get_collection().find_one() - self.assertIsInstance(raw_data["author"], DBRef) - self.assertIsInstance(raw_data["members"][0], DBRef) + assert isinstance(raw_data["author"], DBRef) + assert isinstance(raw_data["members"][0], DBRef) group = Group.objects.first() - self.assertEqual(group.author, user) - self.assertEqual(group.members, [user]) + assert group.author == user + assert group.members == [user] # Migrate the model definition class Group(Document): @@ -273,12 +273,12 @@ class FieldTest(unittest.TestCase): g.save() group = Group.objects.first() - self.assertEqual(group.author, user) - self.assertEqual(group.members, [user]) + assert group.author == user + assert group.members == [user] raw_data = Group._get_collection().find_one() - self.assertIsInstance(raw_data["author"], ObjectId) - self.assertIsInstance(raw_data["members"][0], ObjectId) + assert isinstance(raw_data["author"], ObjectId) + assert isinstance(raw_data["members"][0], ObjectId) def test_recursive_reference(self): """Ensure that ReferenceFields can reference their own documents. @@ -309,43 +309,43 @@ class FieldTest(unittest.TestCase): Employee(name="Funky Gibbon", boss=bill, friends=friends).save() with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 peter = Employee.objects.with_id(peter.id) - self.assertEqual(q, 1) + assert q == 1 peter.boss - self.assertEqual(q, 2) + assert q == 2 peter.friends - self.assertEqual(q, 3) + assert q == 3 # Document select_related with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 peter = Employee.objects.with_id(peter.id).select_related() - self.assertEqual(q, 2) + assert q == 2 - self.assertEqual(peter.boss, bill) - self.assertEqual(q, 2) + assert peter.boss == bill + assert q == 2 - self.assertEqual(peter.friends, friends) - self.assertEqual(q, 2) + assert peter.friends == friends + assert q == 2 # Queryset select_related with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 employees = Employee.objects(boss=bill).select_related() - self.assertEqual(q, 2) + assert q == 2 for employee in employees: - self.assertEqual(employee.boss, bill) - self.assertEqual(q, 2) + assert employee.boss == bill + assert q == 2 - self.assertEqual(employee.friends, friends) - self.assertEqual(q, 2) + assert employee.friends == friends + assert q == 2 def test_list_of_lists_of_references(self): class User(Document): @@ -366,10 +366,10 @@ class FieldTest(unittest.TestCase): u3 = User.objects.create(name="u3") SimpleList.objects.create(users=[u1, u2, u3]) - self.assertEqual(SimpleList.objects.all()[0].users, [u1, u2, u3]) + assert SimpleList.objects.all()[0].users == [u1, u2, u3] Post.objects.create(user_lists=[[u1, u2], [u3]]) - self.assertEqual(Post.objects.all()[0].user_lists, [[u1, u2], [u3]]) + assert Post.objects.all()[0].user_lists == [[u1, u2], [u3]] def test_circular_reference(self): """Ensure you can handle circular references @@ -403,9 +403,7 @@ class FieldTest(unittest.TestCase): daughter.relations.append(self_rel) daughter.save() - self.assertEqual( - "[, ]", "%s" % Person.objects() - ) + assert "[, ]" == "%s" % Person.objects() def test_circular_reference_on_self(self): """Ensure you can handle circular references @@ -432,9 +430,7 @@ class FieldTest(unittest.TestCase): daughter.relations.append(daughter) daughter.save() - self.assertEqual( - "[, ]", "%s" % Person.objects() - ) + assert "[, ]" == "%s" % Person.objects() def test_circular_tree_reference(self): """Ensure you can handle circular references with more than one level @@ -473,9 +469,9 @@ class FieldTest(unittest.TestCase): anna.other.name = "Anna's friends" anna.save() - self.assertEqual( - "[, , , ]", - "%s" % Person.objects(), + assert ( + "[, , , ]" + == "%s" % Person.objects() ) def test_generic_reference(self): @@ -516,52 +512,52 @@ class FieldTest(unittest.TestCase): group.save() with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 group_obj = Group.objects.first() - self.assertEqual(q, 1) + assert q == 1 [m for m in group_obj.members] - self.assertEqual(q, 4) + assert q == 4 [m for m in group_obj.members] - self.assertEqual(q, 4) + assert q == 4 for m in group_obj.members: - self.assertIn("User", m.__class__.__name__) + assert "User" in m.__class__.__name__ # Document select_related with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 group_obj = Group.objects.first().select_related() - self.assertEqual(q, 4) + assert q == 4 [m for m in group_obj.members] - self.assertEqual(q, 4) + assert q == 4 [m for m in group_obj.members] - self.assertEqual(q, 4) + assert q == 4 for m in group_obj.members: - self.assertIn("User", m.__class__.__name__) + assert "User" in m.__class__.__name__ # Queryset select_related with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 group_objs = Group.objects.select_related() - self.assertEqual(q, 4) + assert q == 4 for group_obj in group_objs: [m for m in group_obj.members] - self.assertEqual(q, 4) + assert q == 4 [m for m in group_obj.members] - self.assertEqual(q, 4) + assert q == 4 for m in group_obj.members: - self.assertIn("User", m.__class__.__name__) + assert "User" in m.__class__.__name__ def test_generic_reference_orphan_dbref(self): """Ensure that generic orphan DBRef items in ListFields are dereferenced. @@ -604,18 +600,18 @@ class FieldTest(unittest.TestCase): # an orphan DBRef in the GenericReference ListField UserA.objects[0].delete() with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 group_obj = Group.objects.first() - self.assertEqual(q, 1) + assert q == 1 [m for m in group_obj.members] - self.assertEqual(q, 4) - self.assertTrue(group_obj._data["members"]._dereferenced) + assert q == 4 + assert group_obj._data["members"]._dereferenced [m for m in group_obj.members] - self.assertEqual(q, 4) - self.assertTrue(group_obj._data["members"]._dereferenced) + assert q == 4 + assert group_obj._data["members"]._dereferenced UserA.drop_collection() UserB.drop_collection() @@ -660,52 +656,52 @@ class FieldTest(unittest.TestCase): group.save() with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 group_obj = Group.objects.first() - self.assertEqual(q, 1) + assert q == 1 [m for m in group_obj.members] - self.assertEqual(q, 4) + assert q == 4 [m for m in group_obj.members] - self.assertEqual(q, 4) + assert q == 4 for m in group_obj.members: - self.assertIn("User", m.__class__.__name__) + assert "User" in m.__class__.__name__ # Document select_related with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 group_obj = Group.objects.first().select_related() - self.assertEqual(q, 4) + assert q == 4 [m for m in group_obj.members] - self.assertEqual(q, 4) + assert q == 4 [m for m in group_obj.members] - self.assertEqual(q, 4) + assert q == 4 for m in group_obj.members: - self.assertIn("User", m.__class__.__name__) + assert "User" in m.__class__.__name__ # Queryset select_related with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 group_objs = Group.objects.select_related() - self.assertEqual(q, 4) + assert q == 4 for group_obj in group_objs: [m for m in group_obj.members] - self.assertEqual(q, 4) + assert q == 4 [m for m in group_obj.members] - self.assertEqual(q, 4) + assert q == 4 for m in group_obj.members: - self.assertIn("User", m.__class__.__name__) + assert "User" in m.__class__.__name__ UserA.drop_collection() UserB.drop_collection() @@ -735,43 +731,43 @@ class FieldTest(unittest.TestCase): group.save() with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 group_obj = Group.objects.first() - self.assertEqual(q, 1) + assert q == 1 [m for m in group_obj.members] - self.assertEqual(q, 2) + assert q == 2 for k, m in iteritems(group_obj.members): - self.assertIsInstance(m, User) + assert isinstance(m, User) # Document select_related with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 group_obj = Group.objects.first().select_related() - self.assertEqual(q, 2) + assert q == 2 [m for m in group_obj.members] - self.assertEqual(q, 2) + assert q == 2 for k, m in iteritems(group_obj.members): - self.assertIsInstance(m, User) + assert isinstance(m, User) # Queryset select_related with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 group_objs = Group.objects.select_related() - self.assertEqual(q, 2) + assert q == 2 for group_obj in group_objs: [m for m in group_obj.members] - self.assertEqual(q, 2) + assert q == 2 for k, m in iteritems(group_obj.members): - self.assertIsInstance(m, User) + assert isinstance(m, User) User.drop_collection() Group.drop_collection() @@ -813,65 +809,65 @@ class FieldTest(unittest.TestCase): group.save() with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 group_obj = Group.objects.first() - self.assertEqual(q, 1) + assert q == 1 [m for m in group_obj.members] - self.assertEqual(q, 4) + assert q == 4 [m for m in group_obj.members] - self.assertEqual(q, 4) + assert q == 4 for k, m in iteritems(group_obj.members): - self.assertIn("User", m.__class__.__name__) + assert "User" in m.__class__.__name__ # Document select_related with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 group_obj = Group.objects.first().select_related() - self.assertEqual(q, 4) + assert q == 4 [m for m in group_obj.members] - self.assertEqual(q, 4) + assert q == 4 [m for m in group_obj.members] - self.assertEqual(q, 4) + assert q == 4 for k, m in iteritems(group_obj.members): - self.assertIn("User", m.__class__.__name__) + assert "User" in m.__class__.__name__ # Queryset select_related with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 group_objs = Group.objects.select_related() - self.assertEqual(q, 4) + assert q == 4 for group_obj in group_objs: [m for m in group_obj.members] - self.assertEqual(q, 4) + assert q == 4 [m for m in group_obj.members] - self.assertEqual(q, 4) + assert q == 4 for k, m in iteritems(group_obj.members): - self.assertIn("User", m.__class__.__name__) + assert "User" in m.__class__.__name__ Group.objects.delete() Group().save() with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 group_obj = Group.objects.first() - self.assertEqual(q, 1) + assert q == 1 [m for m in group_obj.members] - self.assertEqual(q, 1) - self.assertEqual(group_obj.members, {}) + assert q == 1 + assert group_obj.members == {} UserA.drop_collection() UserB.drop_collection() @@ -903,52 +899,52 @@ class FieldTest(unittest.TestCase): group.save() with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 group_obj = Group.objects.first() - self.assertEqual(q, 1) + assert q == 1 [m for m in group_obj.members] - self.assertEqual(q, 2) + assert q == 2 [m for m in group_obj.members] - self.assertEqual(q, 2) + assert q == 2 for k, m in iteritems(group_obj.members): - self.assertIsInstance(m, UserA) + assert isinstance(m, UserA) # Document select_related with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 group_obj = Group.objects.first().select_related() - self.assertEqual(q, 2) + assert q == 2 [m for m in group_obj.members] - self.assertEqual(q, 2) + assert q == 2 [m for m in group_obj.members] - self.assertEqual(q, 2) + assert q == 2 for k, m in iteritems(group_obj.members): - self.assertIsInstance(m, UserA) + assert isinstance(m, UserA) # Queryset select_related with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 group_objs = Group.objects.select_related() - self.assertEqual(q, 2) + assert q == 2 for group_obj in group_objs: [m for m in group_obj.members] - self.assertEqual(q, 2) + assert q == 2 [m for m in group_obj.members] - self.assertEqual(q, 2) + assert q == 2 for k, m in iteritems(group_obj.members): - self.assertIsInstance(m, UserA) + assert isinstance(m, UserA) UserA.drop_collection() Group.drop_collection() @@ -990,64 +986,64 @@ class FieldTest(unittest.TestCase): group.save() with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 group_obj = Group.objects.first() - self.assertEqual(q, 1) + assert q == 1 [m for m in group_obj.members] - self.assertEqual(q, 4) + assert q == 4 [m for m in group_obj.members] - self.assertEqual(q, 4) + assert q == 4 for k, m in iteritems(group_obj.members): - self.assertIn("User", m.__class__.__name__) + assert "User" in m.__class__.__name__ # Document select_related with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 group_obj = Group.objects.first().select_related() - self.assertEqual(q, 4) + assert q == 4 [m for m in group_obj.members] - self.assertEqual(q, 4) + assert q == 4 [m for m in group_obj.members] - self.assertEqual(q, 4) + assert q == 4 for k, m in iteritems(group_obj.members): - self.assertIn("User", m.__class__.__name__) + assert "User" in m.__class__.__name__ # Queryset select_related with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 group_objs = Group.objects.select_related() - self.assertEqual(q, 4) + assert q == 4 for group_obj in group_objs: [m for m in group_obj.members] - self.assertEqual(q, 4) + assert q == 4 [m for m in group_obj.members] - self.assertEqual(q, 4) + assert q == 4 for k, m in iteritems(group_obj.members): - self.assertIn("User", m.__class__.__name__) + assert "User" in m.__class__.__name__ Group.objects.delete() Group().save() with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 group_obj = Group.objects.first() - self.assertEqual(q, 1) + assert q == 1 [m for m in group_obj.members] - self.assertEqual(q, 1) + assert q == 1 UserA.drop_collection() UserB.drop_collection() @@ -1075,8 +1071,8 @@ class FieldTest(unittest.TestCase): root.save() root = root.reload() - self.assertEqual(root.children, [company]) - self.assertEqual(company.parents, [root]) + assert root.children == [company] + assert company.parents == [root] def test_dict_in_dbref_instance(self): class Person(Document): @@ -1102,8 +1098,8 @@ class FieldTest(unittest.TestCase): room_101.save() room = Room.objects.first().select_related() - self.assertEqual(room.staffs_with_position[0]["staff"], sarah) - self.assertEqual(room.staffs_with_position[1]["staff"], bob) + assert room.staffs_with_position[0]["staff"] == sarah + assert room.staffs_with_position[1]["staff"] == bob def test_document_reload_no_inheritance(self): class Foo(Document): @@ -1133,8 +1129,8 @@ class FieldTest(unittest.TestCase): foo.save() foo.reload() - self.assertEqual(type(foo.bar), Bar) - self.assertEqual(type(foo.baz), Baz) + assert type(foo.bar) == Bar + assert type(foo.baz) == Baz def test_document_reload_reference_integrity(self): """ @@ -1166,13 +1162,13 @@ class FieldTest(unittest.TestCase): concurrent_change_user = User.objects.get(id=1) concurrent_change_user.name = "new-name" concurrent_change_user.save() - self.assertNotEqual(user.name, "new-name") + assert user.name != "new-name" msg = Message.objects.get(id=1) msg.reload() - self.assertEqual(msg.topic, topic) - self.assertEqual(msg.author, user) - self.assertEqual(msg.author.name, "new-name") + assert msg.topic == topic + assert msg.author == user + assert msg.author.name == "new-name" def test_list_lookup_not_checked_in_map(self): """Ensure we dereference list data correctly @@ -1194,8 +1190,8 @@ class FieldTest(unittest.TestCase): Message(id=1, comments=[c1, c2]).save() msg = Message.objects.get(id=1) - self.assertEqual(0, msg.comments[0].id) - self.assertEqual(1, msg.comments[1].id) + assert 0 == msg.comments[0].id + assert 1 == msg.comments[1].id def test_list_item_dereference_dref_false_save_doesnt_cause_extra_queries(self): """Ensure that DBRef items in ListFields are dereferenced. @@ -1217,15 +1213,15 @@ class FieldTest(unittest.TestCase): Group(name="Test", members=User.objects).save() with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 group_obj = Group.objects.first() - self.assertEqual(q, 1) + assert q == 1 group_obj.name = "new test" group_obj.save() - self.assertEqual(q, 2) + assert q == 2 def test_list_item_dereference_dref_true_save_doesnt_cause_extra_queries(self): """Ensure that DBRef items in ListFields are dereferenced. @@ -1247,15 +1243,15 @@ class FieldTest(unittest.TestCase): Group(name="Test", members=User.objects).save() with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 group_obj = Group.objects.first() - self.assertEqual(q, 1) + assert q == 1 group_obj.name = "new test" group_obj.save() - self.assertEqual(q, 2) + assert q == 2 def test_generic_reference_save_doesnt_cause_extra_queries(self): class UserA(Document): @@ -1287,15 +1283,15 @@ class FieldTest(unittest.TestCase): Group(name="test", members=members).save() with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 group_obj = Group.objects.first() - self.assertEqual(q, 1) + assert q == 1 group_obj.name = "new test" group_obj.save() - self.assertEqual(q, 2) + assert q == 2 def test_objectid_reference_across_databases(self): # mongoenginetest - Is default connection alias from setUp() @@ -1319,10 +1315,10 @@ class FieldTest(unittest.TestCase): # Can't use query_counter across databases - so test the _data object book = Book.objects.first() - self.assertNotIsInstance(book._data["author"], User) + assert not isinstance(book._data["author"], User) book.select_related() - self.assertIsInstance(book._data["author"], User) + assert isinstance(book._data["author"], User) def test_non_ascii_pk(self): """ @@ -1346,7 +1342,7 @@ class FieldTest(unittest.TestCase): BrandGroup(title="top_brands", brands=[brand1, brand2]).save() brand_groups = BrandGroup.objects().all() - self.assertEqual(2, len([brand for bg in brand_groups for brand in bg.brands])) + assert 2 == len([brand for bg in brand_groups for brand in bg.brands]) def test_dereferencing_embedded_listfield_referencefield(self): class Tag(Document): @@ -1370,7 +1366,7 @@ class FieldTest(unittest.TestCase): Page(tags=[tag], posts=[post]).save() page = Page.objects.first() - self.assertEqual(page.tags[0], page.posts[0].tags[0]) + assert page.tags[0] == page.posts[0].tags[0] def test_select_related_follows_embedded_referencefields(self): class Song(Document): @@ -1390,12 +1386,12 @@ class FieldTest(unittest.TestCase): playlist = Playlist.objects.create(items=items) with query_counter() as q: - self.assertEqual(q, 0) + assert q == 0 playlist = Playlist.objects.first().select_related() songs = [item.song for item in playlist.items] - self.assertEqual(q, 2) + assert q == 2 if __name__ == "__main__": diff --git a/tests/test_replicaset_connection.py b/tests/test_replicaset_connection.py index e92f3d09..c1ea407c 100644 --- a/tests/test_replicaset_connection.py +++ b/tests/test_replicaset_connection.py @@ -39,7 +39,7 @@ class ConnectionTest(unittest.TestCase): # really??? return - self.assertEqual(conn.read_preference, READ_PREF) + assert conn.read_preference == READ_PREF if __name__ == "__main__": diff --git a/tests/test_signals.py b/tests/test_signals.py index 1d0607d7..b217712b 100644 --- a/tests/test_signals.py +++ b/tests/test_signals.py @@ -245,7 +245,7 @@ class SignalTests(unittest.TestCase): # Note that there is a chance that the following assert fails in case # some receivers (eventually created in other tests) # gets garbage collected (https://pythonhosted.org/blinker/#blinker.base.Signal.connect) - self.assertEqual(self.pre_signals, post_signals) + assert self.pre_signals == post_signals def test_model_signals(self): """ Model saves should throw some signals. """ @@ -267,97 +267,76 @@ class SignalTests(unittest.TestCase): self.get_signal_output(lambda: None) # eliminate signal output a1 = self.Author.objects(name="Bill Shakespeare")[0] - self.assertEqual( - self.get_signal_output(create_author), - [ - "pre_init signal, Author", - {"name": "Bill Shakespeare"}, - "post_init signal, Bill Shakespeare, document._created = True", - ], - ) + assert self.get_signal_output(create_author) == [ + "pre_init signal, Author", + {"name": "Bill Shakespeare"}, + "post_init signal, Bill Shakespeare, document._created = True", + ] a1 = self.Author(name="Bill Shakespeare") - self.assertEqual( - self.get_signal_output(a1.save), - [ - "pre_save signal, Bill Shakespeare", - {}, - "pre_save_post_validation signal, Bill Shakespeare", - "Is created", - {}, - "post_save signal, Bill Shakespeare", - "post_save dirty keys, ['name']", - "Is created", - {}, - ], - ) + assert self.get_signal_output(a1.save) == [ + "pre_save signal, Bill Shakespeare", + {}, + "pre_save_post_validation signal, Bill Shakespeare", + "Is created", + {}, + "post_save signal, Bill Shakespeare", + "post_save dirty keys, ['name']", + "Is created", + {}, + ] a1.reload() a1.name = "William Shakespeare" - self.assertEqual( - self.get_signal_output(a1.save), - [ - "pre_save signal, William Shakespeare", - {}, - "pre_save_post_validation signal, William Shakespeare", - "Is updated", - {}, - "post_save signal, William Shakespeare", - "post_save dirty keys, ['name']", - "Is updated", - {}, - ], - ) + assert self.get_signal_output(a1.save) == [ + "pre_save signal, William Shakespeare", + {}, + "pre_save_post_validation signal, William Shakespeare", + "Is updated", + {}, + "post_save signal, William Shakespeare", + "post_save dirty keys, ['name']", + "Is updated", + {}, + ] - self.assertEqual( - self.get_signal_output(a1.delete), - [ - "pre_delete signal, William Shakespeare", - {}, - "post_delete signal, William Shakespeare", - {}, - ], - ) + assert self.get_signal_output(a1.delete) == [ + "pre_delete signal, William Shakespeare", + {}, + "post_delete signal, William Shakespeare", + {}, + ] - self.assertEqual( - self.get_signal_output(load_existing_author), - [ - "pre_init signal, Author", - {"id": 2, "name": "Bill Shakespeare"}, - "post_init signal, Bill Shakespeare, document._created = False", - ], - ) + assert self.get_signal_output(load_existing_author) == [ + "pre_init signal, Author", + {"id": 2, "name": "Bill Shakespeare"}, + "post_init signal, Bill Shakespeare, document._created = False", + ] - self.assertEqual( - self.get_signal_output(bulk_create_author_with_load), - [ - "pre_init signal, Author", - {"name": "Bill Shakespeare"}, - "post_init signal, Bill Shakespeare, document._created = True", - "pre_bulk_insert signal, []", - {}, - "pre_init signal, Author", - {"id": 3, "name": "Bill Shakespeare"}, - "post_init signal, Bill Shakespeare, document._created = False", - "post_bulk_insert signal, []", - "Is loaded", - {}, - ], - ) + assert self.get_signal_output(bulk_create_author_with_load) == [ + "pre_init signal, Author", + {"name": "Bill Shakespeare"}, + "post_init signal, Bill Shakespeare, document._created = True", + "pre_bulk_insert signal, []", + {}, + "pre_init signal, Author", + {"id": 3, "name": "Bill Shakespeare"}, + "post_init signal, Bill Shakespeare, document._created = False", + "post_bulk_insert signal, []", + "Is loaded", + {}, + ] - self.assertEqual( - self.get_signal_output(bulk_create_author_without_load), - [ - "pre_init signal, Author", - {"name": "Bill Shakespeare"}, - "post_init signal, Bill Shakespeare, document._created = True", - "pre_bulk_insert signal, []", - {}, - "post_bulk_insert signal, []", - "Not loaded", - {}, - ], - ) + assert self.get_signal_output(bulk_create_author_without_load) == [ + "pre_init signal, Author", + {"name": "Bill Shakespeare"}, + "post_init signal, Bill Shakespeare, document._created = True", + "pre_bulk_insert signal, []", + {}, + "post_bulk_insert signal, []", + "Not loaded", + {}, + ] def test_signal_kwargs(self): """ Make sure signal_kwargs is passed to signals calls. """ @@ -367,83 +346,74 @@ class SignalTests(unittest.TestCase): a.save(signal_kwargs={"live": True, "die": False}) a.delete(signal_kwargs={"live": False, "die": True}) - self.assertEqual( - self.get_signal_output(live_and_let_die), - [ - "pre_init signal, Author", - {"name": "Bill Shakespeare"}, - "post_init signal, Bill Shakespeare, document._created = True", - "pre_save signal, Bill Shakespeare", - {"die": False, "live": True}, - "pre_save_post_validation signal, Bill Shakespeare", - "Is created", - {"die": False, "live": True}, - "post_save signal, Bill Shakespeare", - "post_save dirty keys, ['name']", - "Is created", - {"die": False, "live": True}, - "pre_delete signal, Bill Shakespeare", - {"die": True, "live": False}, - "post_delete signal, Bill Shakespeare", - {"die": True, "live": False}, - ], - ) + assert self.get_signal_output(live_and_let_die) == [ + "pre_init signal, Author", + {"name": "Bill Shakespeare"}, + "post_init signal, Bill Shakespeare, document._created = True", + "pre_save signal, Bill Shakespeare", + {"die": False, "live": True}, + "pre_save_post_validation signal, Bill Shakespeare", + "Is created", + {"die": False, "live": True}, + "post_save signal, Bill Shakespeare", + "post_save dirty keys, ['name']", + "Is created", + {"die": False, "live": True}, + "pre_delete signal, Bill Shakespeare", + {"die": True, "live": False}, + "post_delete signal, Bill Shakespeare", + {"die": True, "live": False}, + ] def bulk_create_author(): a1 = self.Author(name="Bill Shakespeare") self.Author.objects.insert([a1], signal_kwargs={"key": True}) - self.assertEqual( - self.get_signal_output(bulk_create_author), - [ - "pre_init signal, Author", - {"name": "Bill Shakespeare"}, - "post_init signal, Bill Shakespeare, document._created = True", - "pre_bulk_insert signal, []", - {"key": True}, - "pre_init signal, Author", - {"id": 2, "name": "Bill Shakespeare"}, - "post_init signal, Bill Shakespeare, document._created = False", - "post_bulk_insert signal, []", - "Is loaded", - {"key": True}, - ], - ) + assert self.get_signal_output(bulk_create_author) == [ + "pre_init signal, Author", + {"name": "Bill Shakespeare"}, + "post_init signal, Bill Shakespeare, document._created = True", + "pre_bulk_insert signal, []", + {"key": True}, + "pre_init signal, Author", + {"id": 2, "name": "Bill Shakespeare"}, + "post_init signal, Bill Shakespeare, document._created = False", + "post_bulk_insert signal, []", + "Is loaded", + {"key": True}, + ] def test_queryset_delete_signals(self): """ Queryset delete should throw some signals. """ self.Another(name="Bill Shakespeare").save() - self.assertEqual( - self.get_signal_output(self.Another.objects.delete), - [ - "pre_delete signal, Bill Shakespeare", - {}, - "post_delete signal, Bill Shakespeare", - {}, - ], - ) + assert self.get_signal_output(self.Another.objects.delete) == [ + "pre_delete signal, Bill Shakespeare", + {}, + "post_delete signal, Bill Shakespeare", + {}, + ] def test_signals_with_explicit_doc_ids(self): """ Model saves must have a created flag the first time.""" ei = self.ExplicitId(id=123) # post save must received the created flag, even if there's already # an object id present - self.assertEqual(self.get_signal_output(ei.save), ["Is created"]) + assert self.get_signal_output(ei.save) == ["Is created"] # second time, it must be an update - self.assertEqual(self.get_signal_output(ei.save), ["Is updated"]) + assert self.get_signal_output(ei.save) == ["Is updated"] def test_signals_with_switch_collection(self): ei = self.ExplicitId(id=123) ei.switch_collection("explicit__1") - self.assertEqual(self.get_signal_output(ei.save), ["Is created"]) + assert self.get_signal_output(ei.save) == ["Is created"] ei.switch_collection("explicit__1") - self.assertEqual(self.get_signal_output(ei.save), ["Is updated"]) + assert self.get_signal_output(ei.save) == ["Is updated"] ei.switch_collection("explicit__1", keep_created=False) - self.assertEqual(self.get_signal_output(ei.save), ["Is created"]) + assert self.get_signal_output(ei.save) == ["Is created"] ei.switch_collection("explicit__1", keep_created=False) - self.assertEqual(self.get_signal_output(ei.save), ["Is created"]) + assert self.get_signal_output(ei.save) == ["Is created"] def test_signals_with_switch_db(self): connect("mongoenginetest") @@ -451,14 +421,14 @@ class SignalTests(unittest.TestCase): ei = self.ExplicitId(id=123) ei.switch_db("testdb-1") - self.assertEqual(self.get_signal_output(ei.save), ["Is created"]) + assert self.get_signal_output(ei.save) == ["Is created"] ei.switch_db("testdb-1") - self.assertEqual(self.get_signal_output(ei.save), ["Is updated"]) + assert self.get_signal_output(ei.save) == ["Is updated"] ei.switch_db("testdb-1", keep_created=False) - self.assertEqual(self.get_signal_output(ei.save), ["Is created"]) + assert self.get_signal_output(ei.save) == ["Is created"] ei.switch_db("testdb-1", keep_created=False) - self.assertEqual(self.get_signal_output(ei.save), ["Is created"]) + assert self.get_signal_output(ei.save) == ["Is created"] def test_signals_bulk_insert(self): def bulk_set_active_post(): @@ -470,16 +440,13 @@ class SignalTests(unittest.TestCase): self.Post.objects.insert(posts) results = self.get_signal_output(bulk_set_active_post) - self.assertEqual( - results, - [ - "pre_bulk_insert signal, [(, {'active': False}), (, {'active': False}), (, {'active': False})]", - {}, - "post_bulk_insert signal, [(, {'active': True}), (, {'active': True}), (, {'active': True})]", - "Is loaded", - {}, - ], - ) + assert results == [ + "pre_bulk_insert signal, [(, {'active': False}), (, {'active': False}), (, {'active': False})]", + {}, + "post_bulk_insert signal, [(, {'active': True}), (, {'active': True}), (, {'active': True})]", + "Is loaded", + {}, + ] if __name__ == "__main__": diff --git a/tests/test_utils.py b/tests/test_utils.py index 897c19b2..ccb44aac 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -2,6 +2,7 @@ import re import unittest from mongoengine.base.utils import LazyRegexCompiler +import pytest signal_output = [] @@ -12,21 +13,21 @@ class LazyRegexCompilerTest(unittest.TestCase): EMAIL_REGEX = LazyRegexCompiler("@", flags=32) descriptor = UserEmail.__dict__["EMAIL_REGEX"] - self.assertIsNone(descriptor._compiled_regex) + assert descriptor._compiled_regex is None regex = UserEmail.EMAIL_REGEX - self.assertEqual(regex, re.compile("@", flags=32)) - self.assertEqual(regex.search("user@domain.com").group(), "@") + assert regex == re.compile("@", flags=32) + assert regex.search("user@domain.com").group() == "@" user_email = UserEmail() - self.assertIs(user_email.EMAIL_REGEX, UserEmail.EMAIL_REGEX) + assert user_email.EMAIL_REGEX is UserEmail.EMAIL_REGEX def test_lazy_regex_compiler_verify_cannot_set_descriptor_on_instance(self): class UserEmail(object): EMAIL_REGEX = LazyRegexCompiler("@") user_email = UserEmail() - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): user_email.EMAIL_REGEX = re.compile("@") def test_lazy_regex_compiler_verify_can_override_class_attr(self): @@ -34,6 +35,4 @@ class LazyRegexCompilerTest(unittest.TestCase): EMAIL_REGEX = LazyRegexCompiler("@") UserEmail.EMAIL_REGEX = re.compile("cookies") - self.assertEqual( - UserEmail.EMAIL_REGEX.search("Cake & cookies").group(), "cookies" - ) + assert UserEmail.EMAIL_REGEX.search("Cake & cookies").group() == "cookies" From 3e764d068c2b09c500b6226505e662389e6427b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Sat, 31 Aug 2019 22:40:54 +0300 Subject: [PATCH 29/59] fix remaining assertRaises --- tests/document/test_indexes.py | 4 +- tests/document/test_inheritance.py | 10 ++--- tests/document/test_instance.py | 36 +++++++-------- tests/fields/test_dict_field.py | 8 ++-- tests/fields/test_email_field.py | 5 ++- tests/fields/test_embedded_document_field.py | 20 ++++----- tests/fields/test_fields.py | 46 +++++++++----------- tests/fields/test_url_field.py | 4 +- tests/queryset/test_queryset.py | 11 +++-- tests/test_connection.py | 12 ++--- 10 files changed, 76 insertions(+), 80 deletions(-) diff --git a/tests/document/test_indexes.py b/tests/document/test_indexes.py index cc1aae52..6c31054a 100644 --- a/tests/document/test_indexes.py +++ b/tests/document/test_indexes.py @@ -595,12 +595,12 @@ class TestIndexes(unittest.TestCase): Blog.drop_collection() - with pytest.raises(OperationFailure) as ctx_err: + with pytest.raises(OperationFailure) as exc_info: Blog(id="garbage").save() # One of the errors below should happen. Which one depends on the # PyMongo version and dict order. - err_msg = str(ctx_err.exception) + err_msg = str(exc_info.value) assert any( [ "The field 'unique' is not valid for an _id index specification" diff --git a/tests/document/test_inheritance.py b/tests/document/test_inheritance.py index 6a913b3e..3e515653 100644 --- a/tests/document/test_inheritance.py +++ b/tests/document/test_inheritance.py @@ -335,13 +335,13 @@ class TestInheritance(MongoDBTestCase): name = StringField() # can't inherit because Animal didn't explicitly allow inheritance - with pytest.raises(ValueError) as cm: + with pytest.raises( + ValueError, match="Document Animal may not be subclassed" + ) as exc_info: class Dog(Animal): pass - assert "Document Animal may not be subclassed" in str(cm.exception) - # Check that _cls etc aren't present on simple documents dog = Animal(name="dog").save() assert dog.to_mongo().keys() == ["_id", "name"] @@ -358,13 +358,13 @@ class TestInheritance(MongoDBTestCase): name = StringField() meta = {"allow_inheritance": True} - with pytest.raises(ValueError) as cm: + with pytest.raises(ValueError) as exc_info: class Mammal(Animal): meta = {"allow_inheritance": False} assert ( - str(cm.exception) + str(exc_info.value) == 'Only direct subclasses of Document may set "allow_inheritance" to False' ) diff --git a/tests/document/test_instance.py b/tests/document/test_instance.py index 01dc492b..c7bc113e 100644 --- a/tests/document/test_instance.py +++ b/tests/document/test_instance.py @@ -350,14 +350,11 @@ class TestInstance(MongoDBTestCase): name = StringField() meta = {"allow_inheritance": True} - with pytest.raises(ValueError) as e: + with pytest.raises(ValueError, match="Cannot override primary key field") as e: class EmailUser(User): email = StringField(primary_key=True) - exc = e.exception - assert str(exc) == "Cannot override primary key field" - def test_custom_id_field_is_required(self): """Ensure the custom primary key field is required.""" @@ -365,10 +362,9 @@ class TestInstance(MongoDBTestCase): username = StringField(primary_key=True) name = StringField() - with pytest.raises(ValidationError) as e: + with pytest.raises(ValidationError) as exc_info: User(name="test").save() - exc = e.exception - assert "Field is required: ['username']" in str(exc) + assert "Field is required: ['username']" in str(exc_info.value) def test_document_not_registered(self): class Place(Document): @@ -870,12 +866,12 @@ class TestInstance(MongoDBTestCase): t = TestDocument(doc=TestEmbeddedDocument(x=10, y=25, z=15)) - with pytest.raises(ValidationError) as cm: + with pytest.raises(ValidationError) as exc_info: t.save() expected_msg = "Value of z != x + y" - assert expected_msg in cm.exception.message - assert cm.exception.to_dict() == {"doc": {"__all__": expected_msg}} + assert expected_msg in str(exc_info.value) + assert exc_info.value.to_dict() == {"doc": {"__all__": expected_msg}} t = TestDocument(doc=TestEmbeddedDocument(x=10, y=25)).save() assert t.doc.z == 35 @@ -3208,43 +3204,47 @@ class TestInstance(MongoDBTestCase): def test_positional_creation(self): """Document cannot be instantiated using positional arguments.""" - with pytest.raises(TypeError) as e: + with pytest.raises(TypeError) as exc_info: person = self.Person("Test User", 42) + expected_msg = ( "Instantiating a document with positional arguments is not " "supported. Please use `field_name=value` keyword arguments." ) - assert str(e.exception) == expected_msg + assert str(exc_info.value) == expected_msg def test_mixed_creation(self): """Document cannot be instantiated using mixed arguments.""" - with pytest.raises(TypeError) as e: + with pytest.raises(TypeError) as exc_info: person = self.Person("Test User", age=42) + expected_msg = ( "Instantiating a document with positional arguments is not " "supported. Please use `field_name=value` keyword arguments." ) - assert str(e.exception) == expected_msg + assert str(exc_info.value) == expected_msg def test_positional_creation_embedded(self): """Embedded document cannot be created using positional arguments.""" - with pytest.raises(TypeError) as e: + with pytest.raises(TypeError) as exc_info: job = self.Job("Test Job", 4) + expected_msg = ( "Instantiating a document with positional arguments is not " "supported. Please use `field_name=value` keyword arguments." ) - assert str(e.exception) == expected_msg + assert str(exc_info.value) == expected_msg def test_mixed_creation_embedded(self): """Embedded document cannot be created using mixed arguments.""" - with pytest.raises(TypeError) as e: + with pytest.raises(TypeError) as exc_info: job = self.Job("Test Job", years=4) + expected_msg = ( "Instantiating a document with positional arguments is not " "supported. Please use `field_name=value` keyword arguments." ) - assert str(e.exception) == expected_msg + assert str(exc_info.value) == expected_msg def test_data_contains_id_field(self): """Ensure that asking for _data returns 'id'.""" diff --git a/tests/fields/test_dict_field.py b/tests/fields/test_dict_field.py index 56df682f..7dda2a9c 100644 --- a/tests/fields/test_dict_field.py +++ b/tests/fields/test_dict_field.py @@ -270,10 +270,12 @@ class TestDictField(MongoDBTestCase): embed = Embedded(name="garbage") doc = DictFieldTest(dictionary=embed) - with pytest.raises(ValidationError) as ctx_err: + with pytest.raises(ValidationError) as exc_info: doc.validate() - assert "'dictionary'" in str(ctx_err.exception) - assert "Only dictionaries may be used in a DictField" in str(ctx_err.exception) + + error_msg = str(exc_info.value) + assert "'dictionary'" in error_msg + assert "Only dictionaries may be used in a DictField" in error_msg def test_atomic_update_dict_field(self): """Ensure that the entire DictField can be atomically updated.""" diff --git a/tests/fields/test_email_field.py b/tests/fields/test_email_field.py index b8d3d169..902a7c42 100644 --- a/tests/fields/test_email_field.py +++ b/tests/fields/test_email_field.py @@ -88,9 +88,10 @@ class TestEmailField(MongoDBTestCase): invalid_idn = ".google.com" user = User(email="me@%s" % invalid_idn) - with pytest.raises(ValidationError) as ctx_err: + + with pytest.raises(ValidationError) as exc_info: user.validate() - assert "domain failed IDN encoding" in str(ctx_err.exception) + assert "domain failed IDN encoding" in str(exc_info.value) def test_email_field_ip_domain(self): class User(Document): diff --git a/tests/fields/test_embedded_document_field.py b/tests/fields/test_embedded_document_field.py index 4fcf6bf1..9e6871cc 100644 --- a/tests/fields/test_embedded_document_field.py +++ b/tests/fields/test_embedded_document_field.py @@ -36,11 +36,11 @@ class TestEmbeddedDocumentField(MongoDBTestCase): name = StringField() emb = EmbeddedDocumentField("MyDoc") - with pytest.raises(ValidationError) as ctx: + with pytest.raises(ValidationError) as exc_info: emb.document_type assert ( "Invalid embedded document class provided to an EmbeddedDocumentField" - in str(ctx.exception) + in str(exc_info.value) ) def test_embedded_document_field_only_allow_subclasses_of_embedded_document(self): @@ -72,9 +72,9 @@ class TestEmbeddedDocumentField(MongoDBTestCase): p = Person(settings=AdminSettings(foo1="bar1", foo2="bar2"), name="John").save() # Test non exiting attribute - with pytest.raises(InvalidQueryError) as ctx_err: + with pytest.raises(InvalidQueryError) as exc_info: Person.objects(settings__notexist="bar").first() - assert unicode(ctx_err.exception) == u'Cannot resolve field "notexist"' + assert unicode(exc_info.value) == u'Cannot resolve field "notexist"' with pytest.raises(LookUpError): Person.objects.only("settings.notexist") @@ -108,9 +108,9 @@ class TestEmbeddedDocumentField(MongoDBTestCase): p.save() # Test non exiting attribute - with pytest.raises(InvalidQueryError) as ctx_err: + with pytest.raises(InvalidQueryError) as exc_info: assert Person.objects(settings__notexist="bar").first().id == p.id - assert unicode(ctx_err.exception) == u'Cannot resolve field "notexist"' + assert unicode(exc_info.value) == u'Cannot resolve field "notexist"' # Test existing attribute assert Person.objects(settings__base_foo="basefoo").first().id == p.id @@ -316,9 +316,9 @@ class TestGenericEmbeddedDocumentField(MongoDBTestCase): p2 = Person(settings=NonAdminSettings(foo2="bar2")).save() # Test non exiting attribute - with pytest.raises(InvalidQueryError) as ctx_err: + with pytest.raises(InvalidQueryError) as exc_info: Person.objects(settings__notexist="bar").first() - assert unicode(ctx_err.exception) == u'Cannot resolve field "notexist"' + assert unicode(exc_info.value) == u'Cannot resolve field "notexist"' with pytest.raises(LookUpError): Person.objects.only("settings.notexist") @@ -344,9 +344,9 @@ class TestGenericEmbeddedDocumentField(MongoDBTestCase): p.save() # Test non exiting attribute - with pytest.raises(InvalidQueryError) as ctx_err: + with pytest.raises(InvalidQueryError) as exc_info: assert Person.objects(settings__notexist="bar").first().id == p.id - assert unicode(ctx_err.exception) == u'Cannot resolve field "notexist"' + assert unicode(exc_info.value) == u'Cannot resolve field "notexist"' # Test existing attribute assert Person.objects(settings__base_foo="basefoo").first().id == p.id diff --git a/tests/fields/test_fields.py b/tests/fields/test_fields.py index b27d95d2..0ce65087 100644 --- a/tests/fields/test_fields.py +++ b/tests/fields/test_fields.py @@ -96,13 +96,13 @@ class TestField(MongoDBTestCase): "it should raise a ValidationError if validation fails" ) - with pytest.raises(DeprecatedError) as ctx_err: + with pytest.raises(DeprecatedError) as exc_info: Person(name="").validate() - assert str(ctx_err.exception) == error + assert str(exc_info.value) == error - with pytest.raises(DeprecatedError) as ctx_err: + with pytest.raises(DeprecatedError) as exc_info: Person(name="").save() - assert str(ctx_err.exception) == error + assert str(exc_info.value) == error def test_custom_field_validation_raise_validation_error(self): def _not_empty(z): @@ -114,16 +114,10 @@ class TestField(MongoDBTestCase): Person.drop_collection() - with pytest.raises(ValidationError) as ctx_err: + with pytest.raises(ValidationError) as exc_info: Person(name="").validate() assert "ValidationError (Person:None) (cantbeempty: ['name'])" == str( - ctx_err.exception - ) - - with pytest.raises(ValidationError): - Person(name="").save() - assert "ValidationError (Person:None) (cantbeempty: ['name'])" == str( - ctx_err.exception + exc_info.value ) Person(name="garbage").validate() @@ -1029,9 +1023,9 @@ class TestField(MongoDBTestCase): if i < 6: foo.save() else: - with pytest.raises(ValidationError) as cm: + with pytest.raises(ValidationError) as exc_info: foo.save() - assert "List is too long" in str(cm.exception) + assert "List is too long" in str(exc_info.value) def test_list_field_max_length_set_operator(self): """Ensure ListField's max_length is respected for a "set" operator.""" @@ -1040,9 +1034,9 @@ class TestField(MongoDBTestCase): items = ListField(IntField(), max_length=3) foo = Foo.objects.create(items=[1, 2, 3]) - with pytest.raises(ValidationError) as cm: + with pytest.raises(ValidationError) as exc_info: foo.modify(set__items=[1, 2, 3, 4]) - assert "List is too long" in str(cm.exception) + assert "List is too long" in str(exc_info.value) def test_list_field_rejects_strings(self): """Strings aren't valid list field data types.""" @@ -2325,21 +2319,21 @@ class TestEmbeddedDocumentListField(MongoDBTestCase): # Test with an embeddedDocument instead of a list(embeddedDocument) # It's an edge case but it used to fail with a vague error, making it difficult to troubleshoot it post = self.BlogPost(comments=comment) - with pytest.raises(ValidationError) as ctx_err: + with pytest.raises(ValidationError) as exc_info: post.validate() - assert "'comments'" in str(ctx_err.exception) - assert "Only lists and tuples may be used in a list field" in str( - ctx_err.exception - ) + + error_msg = str(exc_info.value) + assert "'comments'" in error_msg + assert "Only lists and tuples may be used in a list field" in error_msg # Test with a Document post = self.BlogPost(comments=Title(content="garbage")) - with pytest.raises(ValidationError): + with pytest.raises(ValidationError) as exc_info: post.validate() - assert "'comments'" in str(ctx_err.exception) - assert "Only lists and tuples may be used in a list field" in str( - ctx_err.exception - ) + + error_msg = str(exc_info.value) + assert "'comments'" in error_msg + assert "Only lists and tuples may be used in a list field" in error_msg def test_no_keyword_filter(self): """ diff --git a/tests/fields/test_url_field.py b/tests/fields/test_url_field.py index e7df0e08..e125f56a 100644 --- a/tests/fields/test_url_field.py +++ b/tests/fields/test_url_field.py @@ -31,10 +31,10 @@ class TestURLField(MongoDBTestCase): # TODO fix URL validation - this *IS* a valid URL # For now we just want to make sure that the error message is correct - with pytest.raises(ValidationError) as ctx_err: + with pytest.raises(ValidationError) as exc_info: link.validate() assert ( - unicode(ctx_err.exception) + unicode(exc_info.value) == u"ValidationError (Link:None) (Invalid URL: http://\u043f\u0440\u0438\u0432\u0435\u0442.com: ['url'])" ) diff --git a/tests/queryset/test_queryset.py b/tests/queryset/test_queryset.py index d154de8d..31abb42f 100644 --- a/tests/queryset/test_queryset.py +++ b/tests/queryset/test_queryset.py @@ -908,20 +908,20 @@ class TestQueryset(unittest.TestCase): assert Blog.objects.count() == 2 # test inserting an existing document (shouldn't be allowed) - with pytest.raises(OperationError) as cm: + with pytest.raises(OperationError) as exc_info: blog = Blog.objects.first() Blog.objects.insert(blog) assert ( - str(cm.exception) + str(exc_info.value) == "Some documents have ObjectIds, use doc.update() instead" ) # test inserting a query set - with pytest.raises(OperationError) as cm: + with pytest.raises(OperationError) as exc_info: blogs_qs = Blog.objects Blog.objects.insert(blogs_qs) assert ( - str(cm.exception) + str(exc_info.value) == "Some documents have ObjectIds, use doc.update() instead" ) @@ -5053,9 +5053,8 @@ class TestQueryset(unittest.TestCase): Person(name="a").save() qs = Person.objects() _ = list(qs) - with pytest.raises(OperationError) as ctx_err: + with pytest.raises(OperationError, match="QuerySet already cached") as ctx_err: qs.no_cache() - assert "QuerySet already cached" == str(ctx_err.exception) def test_no_cached_queryset_no_cache_back_to_cache(self): class Person(Document): diff --git a/tests/test_connection.py b/tests/test_connection.py index c73b67d1..8db69b0c 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -142,22 +142,22 @@ class ConnectionTest(unittest.TestCase): def test_connect_fails_if_connect_2_times_with_default_alias(self): connect("mongoenginetest") - with pytest.raises(ConnectionFailure) as ctx_err: + with pytest.raises(ConnectionFailure) as exc_info: connect("mongoenginetest2") assert ( "A different connection with alias `default` was already registered. Use disconnect() first" - == str(ctx_err.exception) + == str(exc_info.value) ) def test_connect_fails_if_connect_2_times_with_custom_alias(self): connect("mongoenginetest", alias="alias1") - with pytest.raises(ConnectionFailure) as ctx_err: + with pytest.raises(ConnectionFailure) as exc_info: connect("mongoenginetest2", alias="alias1") assert ( "A different connection with alias `alias1` was already registered. Use disconnect() first" - == str(ctx_err.exception) + == str(exc_info.value) ) def test_connect_fails_if_similar_connection_settings_arent_defined_the_same_way( @@ -366,9 +366,9 @@ class ConnectionTest(unittest.TestCase): assert History._collection is None - with pytest.raises(ConnectionFailure) as ctx_err: + with pytest.raises(ConnectionFailure) as exc_info: History.objects.first() - assert "You have not defined a default connection" == str(ctx_err.exception) + assert "You have not defined a default connection" == str(exc_info.value) def test_connect_disconnect_works_on_same_document(self): """Ensure that the connect/disconnect works properly with a single Document""" From c61c6a85253e76fe5ef8d7da48af94d248e3786f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Sat, 31 Aug 2019 22:51:13 +0300 Subject: [PATCH 30/59] fix == None assertions --- tests/document/test_class_methods.py | 2 +- tests/document/test_inheritance.py | 4 +-- tests/document/test_instance.py | 32 +++++++++--------- tests/document/test_validation.py | 2 +- tests/fields/test_fields.py | 10 +++--- tests/queryset/test_field_list.py | 50 ++++++++++++++-------------- tests/queryset/test_modify.py | 6 ++-- tests/queryset/test_queryset.py | 28 ++++++++-------- tests/queryset/test_visitor.py | 4 +-- tests/test_datastructures.py | 2 +- 10 files changed, 70 insertions(+), 70 deletions(-) diff --git a/tests/document/test_class_methods.py b/tests/document/test_class_methods.py index 98909d2f..be883b2a 100644 --- a/tests/document/test_class_methods.py +++ b/tests/document/test_class_methods.py @@ -72,7 +72,7 @@ class TestClassMethods(unittest.TestCase): class Job(Document): employee = ReferenceField(self.Person) - assert self.Person._meta.get("delete_rules") == None + assert self.Person._meta.get("delete_rules") is None self.Person.register_delete_rule(Job, "employee", NULLIFY) assert self.Person._meta["delete_rules"] == {(Job, "employee"): NULLIFY} diff --git a/tests/document/test_inheritance.py b/tests/document/test_inheritance.py index 3e515653..b6b6088a 100644 --- a/tests/document/test_inheritance.py +++ b/tests/document/test_inheritance.py @@ -559,8 +559,8 @@ class TestInheritance(MongoDBTestCase): assert "collection" not in Animal._meta assert "collection" not in Mammal._meta - assert Animal._get_collection_name() == None - assert Mammal._get_collection_name() == None + assert Animal._get_collection_name() is None + assert Mammal._get_collection_name() is None assert Fish._get_collection_name() == "fish" assert Guppy._get_collection_name() == "fish" diff --git a/tests/document/test_instance.py b/tests/document/test_instance.py index c7bc113e..57815355 100644 --- a/tests/document/test_instance.py +++ b/tests/document/test_instance.py @@ -1391,7 +1391,7 @@ class TestInstance(MongoDBTestCase): person.reload() assert person.name == "User" - assert person.age == None + assert person.age is None person = self.Person.objects.get() person.name = None @@ -1399,8 +1399,8 @@ class TestInstance(MongoDBTestCase): person.save() person.reload() - assert person.name == None - assert person.age == None + assert person.name is None + assert person.age is None def test_update_rename_operator(self): """Test the $rename operator.""" @@ -2018,7 +2018,7 @@ class TestInstance(MongoDBTestCase): promoted_employee.save() promoted_employee.reload() - assert promoted_employee.details == None + assert promoted_employee.details is None def test_object_mixins(self): class NameMixin(object): @@ -2154,7 +2154,7 @@ class TestInstance(MongoDBTestCase): reviewer.delete() # No effect on the BlogPost assert BlogPost.objects.count() == 1 - assert BlogPost.objects.get().reviewer == None + assert BlogPost.objects.get().reviewer is None # Delete the Person, which should lead to deletion of the BlogPost, too author.delete() @@ -2200,7 +2200,7 @@ class TestInstance(MongoDBTestCase): reviewer.delete() assert Book.objects.count() == 1 - assert Book.objects.get().reviewer == None + assert Book.objects.get().reviewer is None user.delete() assert Book.objects.count() == 0 @@ -2267,7 +2267,7 @@ class TestInstance(MongoDBTestCase): reviewer.delete() assert BlogPost.objects.count() == 1 - assert BlogPost.objects.get().reviewer == None + assert BlogPost.objects.get().reviewer is None # Delete the Writer should lead to deletion of the BlogPost author.delete() @@ -2378,7 +2378,7 @@ class TestInstance(MongoDBTestCase): f.delete() assert Bar.objects.count() == 1 # No effect on the BlogPost - assert Bar.objects.get().foo == None + assert Bar.objects.get().foo is None def test_invalid_reverse_delete_rule_raise_errors(self): with pytest.raises(InvalidDocumentError): @@ -3464,7 +3464,7 @@ class TestInstance(MongoDBTestCase): p = Person.from_json('{"name": "name"}', created=False) assert p._created == False - assert p.id == None + assert p.id is None # Make sure the document is subsequently persisted correctly. p.save() @@ -3540,13 +3540,13 @@ class TestInstance(MongoDBTestCase): u_from_db = User.objects.get(name="user") u_from_db.height = None u_from_db.save() - assert u_from_db.height == None + assert u_from_db.height is None # 864 - assert u_from_db.str_fld == None - assert u_from_db.int_fld == None - assert u_from_db.flt_fld == None - assert u_from_db.dt_fld == None - assert u_from_db.cdt_fld == None + assert u_from_db.str_fld is None + assert u_from_db.int_fld is None + assert u_from_db.flt_fld is None + assert u_from_db.dt_fld is None + assert u_from_db.cdt_fld is None # 735 User.objects.delete() @@ -3554,7 +3554,7 @@ class TestInstance(MongoDBTestCase): u.save() User.objects(name="user").update_one(set__height=None, upsert=True) u_from_db = User.objects.get(name="user") - assert u_from_db.height == None + assert u_from_db.height is None def test_not_saved_eq(self): """Ensure we can compare documents not saved. diff --git a/tests/document/test_validation.py b/tests/document/test_validation.py index 80601994..dfae5bae 100644 --- a/tests/document/test_validation.py +++ b/tests/document/test_validation.py @@ -110,7 +110,7 @@ class TestValidatorError(MongoDBTestCase): comment.date = datetime.now() comment.validate() - assert comment._instance == None + assert comment._instance is None def test_embedded_db_field_validate(self): class SubDoc(EmbeddedDocument): diff --git a/tests/fields/test_fields.py b/tests/fields/test_fields.py index 0ce65087..21cc78be 100644 --- a/tests/fields/test_fields.py +++ b/tests/fields/test_fields.py @@ -180,7 +180,7 @@ class TestField(MongoDBTestCase): assert person.validate() is None - assert person.name == None + assert person.name is None assert person.age == 30 assert person.userid == "test" assert isinstance(person.created, datetime.datetime) @@ -250,7 +250,7 @@ class TestField(MongoDBTestCase): assert person.validate() is None - assert person.name == None + assert person.name is None assert person.age == 30 assert person.userid == "test" assert isinstance(person.created, datetime.datetime) @@ -363,7 +363,7 @@ class TestField(MongoDBTestCase): name = StringField() person = Person(name="Test User") - assert person.id == None + assert person.id is None person.id = 47 with pytest.raises(ValidationError): @@ -1970,7 +1970,7 @@ class TestField(MongoDBTestCase): shirt2 = Shirt() # Make sure get__display returns the default value (or None) - assert shirt1.get_size_display() == None + assert shirt1.get_size_display() is None assert shirt1.get_style_display() == "Wide" shirt1.size = "XXL" @@ -2024,7 +2024,7 @@ class TestField(MongoDBTestCase): shirt = Shirt() - assert shirt.get_size_display() == None + assert shirt.get_size_display() is None assert shirt.get_style_display() == "Small" shirt.size = "XXL" diff --git a/tests/queryset/test_field_list.py b/tests/queryset/test_field_list.py index d33c4c86..a2bf6f1f 100644 --- a/tests/queryset/test_field_list.py +++ b/tests/queryset/test_field_list.py @@ -154,10 +154,10 @@ class TestOnlyExcludeAll(unittest.TestCase): obj = self.Person.objects.only("name").get() assert obj.name == person.name - assert obj.age == None + assert obj.age is None obj = self.Person.objects.only("age").get() - assert obj.name == None + assert obj.name is None assert obj.age == person.age obj = self.Person.objects.only("name", "age").get() @@ -166,7 +166,7 @@ class TestOnlyExcludeAll(unittest.TestCase): obj = self.Person.objects.only(*("id", "name")).get() assert obj.name == person.name - assert obj.age == None + assert obj.age is None # Check polymorphism still works class Employee(self.Person): @@ -181,7 +181,7 @@ class TestOnlyExcludeAll(unittest.TestCase): # Check field names are looked up properly obj = Employee.objects(id=employee.id).only("salary").get() assert obj.salary == employee.salary - assert obj.name == None + assert obj.name is None def test_only_with_subfields(self): class User(EmbeddedDocument): @@ -215,8 +215,8 @@ class TestOnlyExcludeAll(unittest.TestCase): post.save() obj = BlogPost.objects.only("author.name").get() - assert obj.content == None - assert obj.author.email == None + assert obj.content is None + assert obj.author.email is None assert obj.author.name == "Test User" assert obj.comments == [] @@ -225,15 +225,15 @@ class TestOnlyExcludeAll(unittest.TestCase): obj = BlogPost.objects.only("content", "comments.title").get() assert obj.content == "Had a good coffee today..." - assert obj.author == None + assert obj.author is None assert obj.comments[0].title == "I aggree" assert obj.comments[1].title == "Coffee" - assert obj.comments[0].text == None - assert obj.comments[1].text == None + assert obj.comments[0].text is None + assert obj.comments[1].text is None obj = BlogPost.objects.only("comments").get() - assert obj.content == None - assert obj.author == None + assert obj.content is None + assert obj.author is None assert obj.comments[0].title == "I aggree" assert obj.comments[1].title == "Coffee" assert obj.comments[0].text == "Great post!" @@ -266,10 +266,10 @@ class TestOnlyExcludeAll(unittest.TestCase): post.save() obj = BlogPost.objects.exclude("author", "comments.text").get() - assert obj.author == None + assert obj.author is None assert obj.content == "Had a good coffee today..." assert obj.comments[0].title == "I aggree" - assert obj.comments[0].text == None + assert obj.comments[0].text is None BlogPost.drop_collection() @@ -304,15 +304,15 @@ class TestOnlyExcludeAll(unittest.TestCase): assert obj.sender == "me" assert obj.to == "you" assert obj.subject == "From Russia with Love" - assert obj.body == None - assert obj.content_type == None + assert obj.body is None + assert obj.content_type is None obj = Email.objects.only("sender", "to").exclude("body", "sender").get() - assert obj.sender == None + assert obj.sender is None assert obj.to == "you" - assert obj.subject == None - assert obj.body == None - assert obj.content_type == None + assert obj.subject is None + assert obj.body is None + assert obj.content_type is None obj = ( Email.objects.exclude("attachments.content") @@ -321,12 +321,12 @@ class TestOnlyExcludeAll(unittest.TestCase): .get() ) assert obj.attachments[0].name == "file1.doc" - assert obj.attachments[0].content == None - assert obj.sender == None + assert obj.attachments[0].content is None + assert obj.sender is None assert obj.to == "you" - assert obj.subject == None - assert obj.body == None - assert obj.content_type == None + assert obj.subject is None + assert obj.body is None + assert obj.content_type is None Email.drop_collection() @@ -456,7 +456,7 @@ class TestOnlyExcludeAll(unittest.TestCase): User(username="mongodb", password="secret").save() user = Base.objects().exclude("password", "wibble").first() - assert user.password == None + assert user.password is None with pytest.raises(LookUpError): Base.objects.exclude("made_up") diff --git a/tests/queryset/test_modify.py b/tests/queryset/test_modify.py index 293a463e..556e6d9e 100644 --- a/tests/queryset/test_modify.py +++ b/tests/queryset/test_modify.py @@ -35,13 +35,13 @@ class TestFindAndModify(unittest.TestCase): def test_modify_not_existing(self): Doc(id=0, value=0).save() - assert Doc.objects(id=1).modify(set__value=-1) == None + assert Doc.objects(id=1).modify(set__value=-1) is None self.assertDbEqual([{"_id": 0, "value": 0}]) def test_modify_with_upsert(self): Doc(id=0, value=0).save() old_doc = Doc.objects(id=1).modify(set__value=1, upsert=True) - assert old_doc == None + assert old_doc is None self.assertDbEqual([{"_id": 0, "value": 0}, {"_id": 1, "value": 1}]) def test_modify_with_upsert_existing(self): @@ -68,7 +68,7 @@ class TestFindAndModify(unittest.TestCase): def test_find_and_modify_with_remove_not_existing(self): Doc(id=0, value=0).save() - assert Doc.objects(id=1).modify(remove=True) == None + assert Doc.objects(id=1).modify(remove=True) is None self.assertDbEqual([{"_id": 0, "value": 0}]) def test_modify_with_order_by(self): diff --git a/tests/queryset/test_queryset.py b/tests/queryset/test_queryset.py index 31abb42f..f3606609 100644 --- a/tests/queryset/test_queryset.py +++ b/tests/queryset/test_queryset.py @@ -142,7 +142,7 @@ class TestQueryset(unittest.TestCase): person = self.Person.objects().limit(1).only("name").first() assert person == user_a assert person.name == "User A" - assert person.age == None + assert person.age is None def test_skip(self): """Ensure that QuerySet.skip works as expected.""" @@ -166,7 +166,7 @@ class TestQueryset(unittest.TestCase): person = self.Person.objects().skip(1).only("name").first() assert person == user_b assert person.name == "User B" - assert person.age == None + assert person.age is None def test___getitem___invalid_index(self): """Ensure slicing a queryset works as expected.""" @@ -444,7 +444,7 @@ class TestQueryset(unittest.TestCase): assert result == 2 result = self.Person.objects.update(set__name="Ross", write_concern={"w": 0}) - assert result == None + assert result is None result = self.Person.objects.update_one( set__name="Test User", write_concern={"w": 1} @@ -453,7 +453,7 @@ class TestQueryset(unittest.TestCase): result = self.Person.objects.update_one( set__name="Test User", write_concern={"w": 0} ) - assert result == None + assert result is None def test_update_update_has_a_value(self): """Test to ensure that update is passed a value to update to""" @@ -1148,7 +1148,7 @@ class TestQueryset(unittest.TestCase): obj = self.Person.objects(name__contains="van").first() assert obj == person obj = self.Person.objects(name__contains="Van").first() - assert obj == None + assert obj is None # Test icontains obj = self.Person.objects(name__icontains="Van").first() @@ -1158,7 +1158,7 @@ class TestQueryset(unittest.TestCase): obj = self.Person.objects(name__startswith="Guido").first() assert obj == person obj = self.Person.objects(name__startswith="guido").first() - assert obj == None + assert obj is None # Test istartswith obj = self.Person.objects(name__istartswith="guido").first() @@ -1168,7 +1168,7 @@ class TestQueryset(unittest.TestCase): obj = self.Person.objects(name__endswith="Rossum").first() assert obj == person obj = self.Person.objects(name__endswith="rossuM").first() - assert obj == None + assert obj is None # Test iendswith obj = self.Person.objects(name__iendswith="rossuM").first() @@ -1178,15 +1178,15 @@ class TestQueryset(unittest.TestCase): obj = self.Person.objects(name__exact="Guido van Rossum").first() assert obj == person obj = self.Person.objects(name__exact="Guido van rossum").first() - assert obj == None + assert obj is None obj = self.Person.objects(name__exact="Guido van Rossu").first() - assert obj == None + assert obj is None # Test iexact obj = self.Person.objects(name__iexact="gUIDO VAN rOSSUM").first() assert obj == person obj = self.Person.objects(name__iexact="gUIDO VAN rOSSU").first() - assert obj == None + assert obj is None # Test unsafe expressions person = self.Person(name="Guido van Rossum [.'Geek']") @@ -1205,7 +1205,7 @@ class TestQueryset(unittest.TestCase): assert obj == alice obj = self.Person.objects(name__not__iexact="alice").first() - assert obj == None + assert obj is None def test_filter_chaining(self): """Ensure filters can be chained together. @@ -1430,7 +1430,7 @@ class TestQueryset(unittest.TestCase): BlogPost.objects.create(content="Anonymous post...") result = BlogPost.objects.get(author=None) - assert result.author == None + assert result.author is None def test_find_dict_item(self): """Ensure that DictField items may be found. @@ -2050,7 +2050,7 @@ class TestQueryset(unittest.TestCase): assert post.title != None BlogPost.objects.update_one(unset__title=1) post.reload() - assert post.title == None + assert post.title is None pymongo_doc = BlogPost.objects.as_pymongo().first() assert "title" not in pymongo_doc @@ -4041,7 +4041,7 @@ class TestQueryset(unittest.TestCase): assert post.comment == comment Post.objects.update(unset__comment=1) post.reload() - assert post.comment == None + assert post.comment is None Comment.drop_collection() Post.drop_collection() diff --git a/tests/queryset/test_visitor.py b/tests/queryset/test_visitor.py index a41f9278..9706d012 100644 --- a/tests/queryset/test_visitor.py +++ b/tests/queryset/test_visitor.py @@ -294,7 +294,7 @@ class TestQ(unittest.TestCase): obj = self.Person.objects(Q(name=re.compile("^Gui"))).first() assert obj == person obj = self.Person.objects(Q(name=re.compile("^gui"))).first() - assert obj == None + assert obj is None obj = self.Person.objects(Q(name=re.compile("^gui", re.I))).first() assert obj == person @@ -303,7 +303,7 @@ class TestQ(unittest.TestCase): assert obj == person obj = self.Person.objects(Q(name__not=re.compile("^Gui"))).first() - assert obj == None + assert obj is None def test_q_repr(self): assert repr(Q()) == "Q(**{})" diff --git a/tests/test_datastructures.py b/tests/test_datastructures.py index 3a6029c1..24cda40d 100644 --- a/tests/test_datastructures.py +++ b/tests/test_datastructures.py @@ -122,7 +122,7 @@ class TestBaseDict(unittest.TestCase): def test_get_default(self): base_dict = self._get_basedict({}) sentinel = object() - assert base_dict.get("new") == None + assert base_dict.get("new") is None assert base_dict.get("new", sentinel) is sentinel def test___setitem___calls_mark_as_changed(self): From bc0c55e49a58f1a8104ffa77f1b7b87c605504da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Sun, 1 Sep 2019 15:03:29 +0300 Subject: [PATCH 31/59] improve tests health (flake8 warnings) --- tests/document/test_dynamic.py | 3 +- tests/document/test_indexes.py | 2 +- tests/document/test_inheritance.py | 8 ++-- tests/document/test_instance.py | 40 ++++++++++---------- tests/document/test_validation.py | 3 +- tests/fields/test_binary_field.py | 3 +- tests/fields/test_boolean_field.py | 6 +-- tests/fields/test_cached_reference_field.py | 6 +-- tests/fields/test_date_field.py | 4 +- tests/fields/test_datetime_field.py | 3 +- tests/fields/test_decimal_field.py | 6 +-- tests/fields/test_dict_field.py | 5 ++- tests/fields/test_email_field.py | 6 +-- tests/fields/test_embedded_document_field.py | 3 +- tests/fields/test_fields.py | 4 +- tests/fields/test_file_field.py | 3 +- tests/fields/test_float_field.py | 2 +- tests/fields/test_int_field.py | 3 +- tests/fields/test_lazy_reference_field.py | 2 +- tests/fields/test_long_field.py | 2 +- tests/fields/test_map_field.py | 6 +-- tests/fields/test_reference_field.py | 18 +-------- tests/fields/test_url_field.py | 3 +- tests/fields/test_uuid_field.py | 6 +-- tests/fixtures.py | 4 +- tests/queryset/test_field_list.py | 5 ++- tests/queryset/test_queryset.py | 26 ++++++------- tests/queryset/test_transform.py | 2 +- tests/queryset/test_visitor.py | 2 +- tests/test_connection.py | 7 ++-- tests/test_context_managers.py | 4 +- tests/test_datastructures.py | 2 +- tests/test_replicaset_connection.py | 6 +-- tests/test_utils.py | 3 +- 34 files changed, 97 insertions(+), 111 deletions(-) diff --git a/tests/document/test_dynamic.py b/tests/document/test_dynamic.py index a6f46862..0032dfd9 100644 --- a/tests/document/test_dynamic.py +++ b/tests/document/test_dynamic.py @@ -1,8 +1,9 @@ import unittest +import pytest + from mongoengine import * from tests.utils import MongoDBTestCase -import pytest __all__ = ("TestDynamicDocument",) diff --git a/tests/document/test_indexes.py b/tests/document/test_indexes.py index 6c31054a..dc6c5c8e 100644 --- a/tests/document/test_indexes.py +++ b/tests/document/test_indexes.py @@ -5,11 +5,11 @@ from datetime import datetime from nose.plugins.skip import SkipTest from pymongo.collation import Collation from pymongo.errors import OperationFailure +import pytest from six import iteritems from mongoengine import * from mongoengine.connection import get_db -import pytest class TestIndexes(unittest.TestCase): diff --git a/tests/document/test_inheritance.py b/tests/document/test_inheritance.py index b6b6088a..5072f841 100644 --- a/tests/document/test_inheritance.py +++ b/tests/document/test_inheritance.py @@ -2,6 +2,7 @@ import unittest import warnings +import pytest from six import iteritems from mongoengine import ( @@ -17,7 +18,6 @@ from mongoengine import ( from mongoengine.pymongo_support import list_collection_names from tests.fixtures import Base from tests.utils import MongoDBTestCase -import pytest class TestInheritance(MongoDBTestCase): @@ -335,9 +335,7 @@ class TestInheritance(MongoDBTestCase): name = StringField() # can't inherit because Animal didn't explicitly allow inheritance - with pytest.raises( - ValueError, match="Document Animal may not be subclassed" - ) as exc_info: + with pytest.raises(ValueError, match="Document Animal may not be subclassed"): class Dog(Animal): pass @@ -475,7 +473,7 @@ class TestInheritance(MongoDBTestCase): meta = {"abstract": True, "allow_inheritance": False} city = City(continent="asia") - assert None == city.pk + assert city.pk is None # TODO: expected error? Shouldn't we create a new error type? with pytest.raises(KeyError): setattr(city, "pk", 1) diff --git a/tests/document/test_instance.py b/tests/document/test_instance.py index 57815355..9d533129 100644 --- a/tests/document/test_instance.py +++ b/tests/document/test_instance.py @@ -9,6 +9,7 @@ from datetime import datetime import bson from bson import DBRef, ObjectId from pymongo.errors import DuplicateKeyError +import pytest from six import iteritems from mongoengine import * @@ -36,7 +37,6 @@ from tests.fixtures import ( PickleTest, ) from tests.utils import MongoDBTestCase, get_as_pymongo -import pytest TEST_IMAGE_PATH = os.path.join(os.path.dirname(__file__), "../fields/mongoengine.png") @@ -96,7 +96,7 @@ class TestInstance(MongoDBTestCase): assert Log.objects.count() == 10 options = Log.objects._collection.options() - assert options["capped"] == True + assert options["capped"] is True assert options["max"] == 10 assert options["size"] == 4096 @@ -122,7 +122,7 @@ class TestInstance(MongoDBTestCase): Log().save() options = Log.objects._collection.options() - assert options["capped"] == True + assert options["capped"] is True assert options["max"] == 10 assert options["size"] == 10 * 2 ** 20 @@ -150,7 +150,7 @@ class TestInstance(MongoDBTestCase): Log().save() options = Log.objects._collection.options() - assert options["capped"] == True + assert options["capped"] is True assert options["size"] >= 10000 # Check that the document with odd max_size value can be recreated @@ -350,7 +350,7 @@ class TestInstance(MongoDBTestCase): name = StringField() meta = {"allow_inheritance": True} - with pytest.raises(ValueError, match="Cannot override primary key field") as e: + with pytest.raises(ValueError, match="Cannot override primary key field"): class EmailUser(User): email = StringField(primary_key=True) @@ -620,7 +620,7 @@ class TestInstance(MongoDBTestCase): f.reload() def test_reload_of_non_strict_with_special_field_name(self): - """Ensures reloading works for documents with meta strict == False.""" + """Ensures reloading works for documents with meta strict is False.""" class Post(Document): meta = {"strict": False} @@ -832,13 +832,13 @@ class TestInstance(MongoDBTestCase): t = TestDocument(status="published") t.save(clean=False) assert t.status == "published" - assert t.cleaned == False + assert t.cleaned is False t = TestDocument(status="published") - assert t.cleaned == False + assert t.cleaned is False t.save(clean=True) assert t.status == "published" - assert t.cleaned == True + assert t.cleaned is True raw_doc = get_as_pymongo(t) # Make sure clean changes makes it to the db assert raw_doc == {"status": "published", "cleaned": True, "_id": t.id} @@ -1600,7 +1600,7 @@ class TestInstance(MongoDBTestCase): person = self.Person.objects.get() assert person.name == "User" assert person.age == 21 - assert person.active == False + assert person.active is False def test__get_changed_fields_same_ids_reference_field_does_not_enters_infinite_loop_embedded_doc( self, @@ -2521,9 +2521,9 @@ class TestInstance(MongoDBTestCase): assert all_user_dic.get(u1, False) == "OK" assert all_user_dic.get(u2, False) == "OK" assert all_user_dic.get(u3, False) == "OK" - assert all_user_dic.get(u4, False) == False # New object - assert all_user_dic.get(b1, False) == False # Other object - assert all_user_dic.get(b2, False) == False # Other object + assert all_user_dic.get(u4, False) is False # New object + assert all_user_dic.get(b1, False) is False # Other object + assert all_user_dic.get(b2, False) is False # Other object # Make sure docs are properly identified in a set (__hash__ is used # for hashing the docs). @@ -3216,7 +3216,7 @@ class TestInstance(MongoDBTestCase): def test_mixed_creation(self): """Document cannot be instantiated using mixed arguments.""" with pytest.raises(TypeError) as exc_info: - person = self.Person("Test User", age=42) + self.Person("Test User", age=42) expected_msg = ( "Instantiating a document with positional arguments is not " @@ -3227,7 +3227,7 @@ class TestInstance(MongoDBTestCase): def test_positional_creation_embedded(self): """Embedded document cannot be created using positional arguments.""" with pytest.raises(TypeError) as exc_info: - job = self.Job("Test Job", 4) + self.Job("Test Job", 4) expected_msg = ( "Instantiating a document with positional arguments is not " @@ -3238,7 +3238,7 @@ class TestInstance(MongoDBTestCase): def test_mixed_creation_embedded(self): """Embedded document cannot be created using mixed arguments.""" with pytest.raises(TypeError) as exc_info: - job = self.Job("Test Job", years=4) + self.Job("Test Job", years=4) expected_msg = ( "Instantiating a document with positional arguments is not " @@ -3432,7 +3432,7 @@ class TestInstance(MongoDBTestCase): meta = {"shard_key": ("id", "name")} p = Person.from_json('{"name": "name", "age": 27}', created=True) - assert p._created == True + assert p._created is True p.name = "new name" p.id = "12345" assert p.name == "new name" @@ -3450,7 +3450,7 @@ class TestInstance(MongoDBTestCase): meta = {"shard_key": ("id", "name")} p = Person._from_son({"name": "name", "age": 27}, created=True) - assert p._created == True + assert p._created is True p.name = "new name" p.id = "12345" assert p.name == "new name" @@ -3463,7 +3463,7 @@ class TestInstance(MongoDBTestCase): Person.objects.delete() p = Person.from_json('{"name": "name"}', created=False) - assert p._created == False + assert p._created is False assert p.id is None # Make sure the document is subsequently persisted correctly. @@ -3483,7 +3483,7 @@ class TestInstance(MongoDBTestCase): p = Person.from_json( '{"_id": "5b85a8b04ec5dc2da388296e", "name": "name"}', created=False ) - assert p._created == False + assert p._created is False assert p._changed_fields == [] assert p.name == "name" assert p.id == ObjectId("5b85a8b04ec5dc2da388296e") diff --git a/tests/document/test_validation.py b/tests/document/test_validation.py index dfae5bae..2439f283 100644 --- a/tests/document/test_validation.py +++ b/tests/document/test_validation.py @@ -2,9 +2,10 @@ import unittest from datetime import datetime +import pytest + from mongoengine import * from tests.utils import MongoDBTestCase -import pytest class TestValidatorError(MongoDBTestCase): diff --git a/tests/fields/test_binary_field.py b/tests/fields/test_binary_field.py index 86ee2654..e2a1b8d6 100644 --- a/tests/fields/test_binary_field.py +++ b/tests/fields/test_binary_field.py @@ -2,12 +2,11 @@ import uuid from bson import Binary -from nose.plugins.skip import SkipTest +import pytest import six from mongoengine import * from tests.utils import MongoDBTestCase -import pytest BIN_VALUE = six.b( "\xa9\xf3\x8d(\xd7\x03\x84\xb4k[\x0f\xe3\xa2\x19\x85p[J\xa3\xd2>\xde\xe6\x87\xb1\x7f\xc6\xe6\xd9r\x18\xf5" diff --git a/tests/fields/test_boolean_field.py b/tests/fields/test_boolean_field.py index b38b5ea4..041f9f56 100644 --- a/tests/fields/test_boolean_field.py +++ b/tests/fields/test_boolean_field.py @@ -1,9 +1,9 @@ # -*- coding: utf-8 -*- -from mongoengine import * - -from tests.utils import MongoDBTestCase, get_as_pymongo import pytest +from mongoengine import * +from tests.utils import MongoDBTestCase, get_as_pymongo + class TestBooleanField(MongoDBTestCase): def test_storage(self): diff --git a/tests/fields/test_cached_reference_field.py b/tests/fields/test_cached_reference_field.py index e404aae0..bb4c57d2 100644 --- a/tests/fields/test_cached_reference_field.py +++ b/tests/fields/test_cached_reference_field.py @@ -1,11 +1,11 @@ # -*- coding: utf-8 -*- from decimal import Decimal -from mongoengine import * - -from tests.utils import MongoDBTestCase import pytest +from mongoengine import * +from tests.utils import MongoDBTestCase + class TestCachedReferenceField(MongoDBTestCase): def test_get_and_save(self): diff --git a/tests/fields/test_date_field.py b/tests/fields/test_date_field.py index 46fa4f0f..e94ed0ce 100644 --- a/tests/fields/test_date_field.py +++ b/tests/fields/test_date_field.py @@ -1,5 +1,7 @@ # -*- coding: utf-8 -*- import datetime + +import pytest import six try: @@ -8,9 +10,7 @@ except ImportError: dateutil = None from mongoengine import * - from tests.utils import MongoDBTestCase -import pytest class TestDateField(MongoDBTestCase): diff --git a/tests/fields/test_datetime_field.py b/tests/fields/test_datetime_field.py index 8db491c6..70debac5 100644 --- a/tests/fields/test_datetime_field.py +++ b/tests/fields/test_datetime_field.py @@ -1,5 +1,7 @@ # -*- coding: utf-8 -*- import datetime as dt + +import pytest import six try: @@ -11,7 +13,6 @@ from mongoengine import * from mongoengine import connection from tests.utils import MongoDBTestCase -import pytest class TestDateTimeField(MongoDBTestCase): diff --git a/tests/fields/test_decimal_field.py b/tests/fields/test_decimal_field.py index b5b95363..c531166f 100644 --- a/tests/fields/test_decimal_field.py +++ b/tests/fields/test_decimal_field.py @@ -1,11 +1,11 @@ # -*- coding: utf-8 -*- from decimal import Decimal -from mongoengine import * - -from tests.utils import MongoDBTestCase import pytest +from mongoengine import * +from tests.utils import MongoDBTestCase + class TestDecimalField(MongoDBTestCase): def test_validation(self): diff --git a/tests/fields/test_dict_field.py b/tests/fields/test_dict_field.py index 7dda2a9c..e88128f9 100644 --- a/tests/fields/test_dict_field.py +++ b/tests/fields/test_dict_field.py @@ -1,9 +1,10 @@ # -*- coding: utf-8 -*- +import pytest + from mongoengine import * from mongoengine.base import BaseDict from tests.utils import MongoDBTestCase, get_as_pymongo -import pytest class TestDictField(MongoDBTestCase): @@ -290,7 +291,7 @@ class TestDictField(MongoDBTestCase): e.save() e.update(set__mapping={"ints": [3, 4]}) e.reload() - assert BaseDict == type(e.mapping) + assert isinstance(e.mapping, BaseDict) assert {"ints": [3, 4]} == e.mapping # try creating an invalid mapping diff --git a/tests/fields/test_email_field.py b/tests/fields/test_email_field.py index 902a7c42..55255df5 100644 --- a/tests/fields/test_email_field.py +++ b/tests/fields/test_email_field.py @@ -2,11 +2,11 @@ import sys from unittest import SkipTest -from mongoengine import * - -from tests.utils import MongoDBTestCase import pytest +from mongoengine import * +from tests.utils import MongoDBTestCase + class TestEmailField(MongoDBTestCase): def test_generic_behavior(self): diff --git a/tests/fields/test_embedded_document_field.py b/tests/fields/test_embedded_document_field.py index 9e6871cc..eeddac1e 100644 --- a/tests/fields/test_embedded_document_field.py +++ b/tests/fields/test_embedded_document_field.py @@ -1,4 +1,6 @@ # -*- coding: utf-8 -*- +import pytest + from mongoengine import ( Document, EmbeddedDocument, @@ -13,7 +15,6 @@ from mongoengine import ( ) from tests.utils import MongoDBTestCase -import pytest class TestEmbeddedDocumentField(MongoDBTestCase): diff --git a/tests/fields/test_fields.py b/tests/fields/test_fields.py index 21cc78be..b8c916f8 100644 --- a/tests/fields/test_fields.py +++ b/tests/fields/test_fields.py @@ -4,6 +4,7 @@ import unittest from bson import DBRef, ObjectId, SON from nose.plugins.skip import SkipTest +import pytest from mongoengine import ( BooleanField, @@ -39,7 +40,6 @@ from mongoengine.base import BaseField, EmbeddedDocumentList, _document_registry from mongoengine.errors import DeprecatedError from tests.utils import MongoDBTestCase -import pytest class TestField(MongoDBTestCase): @@ -1838,7 +1838,7 @@ class TestField(MongoDBTestCase): user = User.objects(bookmarks__all=[post_1]).first() - assert user != None + assert user is not None assert user.bookmarks[0] == post_1 def test_generic_reference_filter_by_dbref(self): diff --git a/tests/fields/test_file_field.py b/tests/fields/test_file_field.py index 0746db33..fb8cacff 100644 --- a/tests/fields/test_file_field.py +++ b/tests/fields/test_file_field.py @@ -137,7 +137,6 @@ class TestFileField(MongoDBTestCase): text = six.b("Hello, World!") more_text = six.b("Foo Bar") - content_type = "text/plain" streamfile = StreamFile() streamfile.save() @@ -205,7 +204,7 @@ class TestFileField(MongoDBTestCase): doc_b = GridDocument.objects.with_id(doc_a.id) doc_b.the_file.replace(f, filename="doc_b") doc_b.save() - assert doc_b.the_file.grid_id != None + assert doc_b.the_file.grid_id is not None # Test it matches doc_c = GridDocument.objects.with_id(doc_b.id) diff --git a/tests/fields/test_float_field.py b/tests/fields/test_float_field.py index d755fb4e..a1cd7a0a 100644 --- a/tests/fields/test_float_field.py +++ b/tests/fields/test_float_field.py @@ -1,10 +1,10 @@ # -*- coding: utf-8 -*- +import pytest import six from mongoengine import * from tests.utils import MongoDBTestCase -import pytest class TestFloatField(MongoDBTestCase): diff --git a/tests/fields/test_int_field.py b/tests/fields/test_int_field.py index 65a5fbad..1f9c5a77 100644 --- a/tests/fields/test_int_field.py +++ b/tests/fields/test_int_field.py @@ -1,8 +1,9 @@ # -*- coding: utf-8 -*- +import pytest + from mongoengine import * from tests.utils import MongoDBTestCase -import pytest class TestIntField(MongoDBTestCase): diff --git a/tests/fields/test_lazy_reference_field.py b/tests/fields/test_lazy_reference_field.py index 8150574d..b5b8690e 100644 --- a/tests/fields/test_lazy_reference_field.py +++ b/tests/fields/test_lazy_reference_field.py @@ -1,11 +1,11 @@ # -*- coding: utf-8 -*- from bson import DBRef, ObjectId +import pytest from mongoengine import * from mongoengine.base import LazyReference from tests.utils import MongoDBTestCase -import pytest class TestLazyReferenceField(MongoDBTestCase): diff --git a/tests/fields/test_long_field.py b/tests/fields/test_long_field.py index 51f8e255..da4f04c8 100644 --- a/tests/fields/test_long_field.py +++ b/tests/fields/test_long_field.py @@ -1,4 +1,5 @@ # -*- coding: utf-8 -*- +import pytest import six try: @@ -10,7 +11,6 @@ from mongoengine import * from mongoengine.connection import get_db from tests.utils import MongoDBTestCase -import pytest class TestLongField(MongoDBTestCase): diff --git a/tests/fields/test_map_field.py b/tests/fields/test_map_field.py index fd56ddd0..8b8b1c46 100644 --- a/tests/fields/test_map_field.py +++ b/tests/fields/test_map_field.py @@ -1,11 +1,11 @@ # -*- coding: utf-8 -*- import datetime -from mongoengine import * - -from tests.utils import MongoDBTestCase import pytest +from mongoengine import * +from tests.utils import MongoDBTestCase + class TestMapField(MongoDBTestCase): def test_mapfield(self): diff --git a/tests/fields/test_reference_field.py b/tests/fields/test_reference_field.py index 783d1315..949eac67 100644 --- a/tests/fields/test_reference_field.py +++ b/tests/fields/test_reference_field.py @@ -1,10 +1,9 @@ # -*- coding: utf-8 -*- from bson import DBRef, SON +import pytest from mongoengine import * - from tests.utils import MongoDBTestCase -import pytest class TestReferenceField(MongoDBTestCase): @@ -59,21 +58,6 @@ class TestReferenceField(MongoDBTestCase): with pytest.raises(ValidationError): post1.validate() - def test_objectid_reference_fields(self): - """Make sure storing Object ID references works.""" - - class Person(Document): - name = StringField() - parent = ReferenceField("self") - - Person.drop_collection() - - p1 = Person(name="John").save() - Person(name="Ross", parent=p1.pk).save() - - p = Person.objects.get(name="Ross") - assert p.parent == p1 - def test_dbref_reference_fields(self): """Make sure storing references as bson.dbref.DBRef works.""" diff --git a/tests/fields/test_url_field.py b/tests/fields/test_url_field.py index e125f56a..948a4788 100644 --- a/tests/fields/test_url_field.py +++ b/tests/fields/test_url_field.py @@ -1,8 +1,9 @@ # -*- coding: utf-8 -*- +import pytest + from mongoengine import * from tests.utils import MongoDBTestCase -import pytest class TestURLField(MongoDBTestCase): diff --git a/tests/fields/test_uuid_field.py b/tests/fields/test_uuid_field.py index b1413f95..21b7a090 100644 --- a/tests/fields/test_uuid_field.py +++ b/tests/fields/test_uuid_field.py @@ -1,11 +1,11 @@ # -*- coding: utf-8 -*- import uuid -from mongoengine import * - -from tests.utils import MongoDBTestCase, get_as_pymongo import pytest +from mongoengine import * +from tests.utils import MongoDBTestCase, get_as_pymongo + class Person(Document): api_key = UUIDField(binary=False) diff --git a/tests/fixtures.py b/tests/fixtures.py index 9f06f1ab..59fc3bf3 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -42,11 +42,11 @@ class PickleSignalsTest(Document): @classmethod def post_save(self, sender, document, created, **kwargs): - pickled = pickle.dumps(document) + pickle.dumps(document) @classmethod def post_delete(self, sender, document, **kwargs): - pickled = pickle.dumps(document) + pickle.dumps(document) signals.post_save.connect(PickleSignalsTest.post_save, sender=PickleSignalsTest) diff --git a/tests/queryset/test_field_list.py b/tests/queryset/test_field_list.py index a2bf6f1f..fbdde23b 100644 --- a/tests/queryset/test_field_list.py +++ b/tests/queryset/test_field_list.py @@ -1,8 +1,9 @@ import unittest +import pytest + from mongoengine import * from mongoengine.queryset import QueryFieldList -import pytest class TestQueryFieldList(unittest.TestCase): @@ -221,7 +222,7 @@ class TestOnlyExcludeAll(unittest.TestCase): assert obj.comments == [] obj = BlogPost.objects.only("various.test_dynamic.some").get() - assert obj.various["test_dynamic"].some == True + assert obj.various["test_dynamic"].some is True obj = BlogPost.objects.only("content", "comments.title").get() assert obj.content == "Had a good coffee today..." diff --git a/tests/queryset/test_queryset.py b/tests/queryset/test_queryset.py index f3606609..79f5793d 100644 --- a/tests/queryset/test_queryset.py +++ b/tests/queryset/test_queryset.py @@ -9,6 +9,7 @@ from bson import DBRef, ObjectId import pymongo from pymongo.read_preferences import ReadPreference from pymongo.results import UpdateResult +import pytest import six from six import iteritems @@ -24,7 +25,6 @@ from mongoengine.queryset import ( QuerySetManager, queryset_manager, ) -import pytest class db_ops_tracker(query_counter): @@ -1712,11 +1712,11 @@ class TestQueryset(unittest.TestCase): post = BlogPost(content="Watching TV", category=lameness) post.save() - assert 1 == BlogPost.objects.count() - assert "Lameness" == BlogPost.objects.first().category.name + assert BlogPost.objects.count() == 1 + assert BlogPost.objects.first().category.name == "Lameness" Category.objects.delete() - assert 1 == BlogPost.objects.count() - assert None == BlogPost.objects.first().category + assert BlogPost.objects.count() == 1 + assert BlogPost.objects.first().category is None def test_reverse_delete_rule_nullify_on_abstract_document(self): """Ensure nullification of references to deleted documents when @@ -1739,11 +1739,11 @@ class TestQueryset(unittest.TestCase): BlogPost(content="Watching TV", author=me).save() - assert 1 == BlogPost.objects.count() - assert me == BlogPost.objects.first().author + assert BlogPost.objects.count() == 1 + assert BlogPost.objects.first().author == me self.Person.objects(name="Test User").delete() - assert 1 == BlogPost.objects.count() - assert None == BlogPost.objects.first().author + assert BlogPost.objects.count() == 1 + assert BlogPost.objects.first().author is None def test_reverse_delete_rule_deny(self): """Ensure deletion gets denied on documents that still have references @@ -1896,7 +1896,7 @@ class TestQueryset(unittest.TestCase): """ p1 = self.Person(name="User Z", age=20).save() del_result = p1.delete(w=0) - assert None == del_result + assert del_result is None def test_reference_field_find(self): """Ensure cascading deletion of referring documents from the database. @@ -2047,7 +2047,7 @@ class TestQueryset(unittest.TestCase): post = BlogPost(title="garbage").save() - assert post.title != None + assert post.title is not None BlogPost.objects.update_one(unset__title=1) post.reload() assert post.title is None @@ -5006,7 +5006,7 @@ class TestQueryset(unittest.TestCase): # PyPy evaluates __len__ when iterating with list comprehensions while CPython does not. # This may be a bug in PyPy (PyPy/#1802) but it does not affect # the behavior of MongoEngine. - assert None == people._len + assert people._len is None assert q == 1 list(people) @@ -5053,7 +5053,7 @@ class TestQueryset(unittest.TestCase): Person(name="a").save() qs = Person.objects() _ = list(qs) - with pytest.raises(OperationError, match="QuerySet already cached") as ctx_err: + with pytest.raises(OperationError, match="QuerySet already cached"): qs.no_cache() def test_no_cached_queryset_no_cache_back_to_cache(self): diff --git a/tests/queryset/test_transform.py b/tests/queryset/test_transform.py index be28c3b8..3898809e 100644 --- a/tests/queryset/test_transform.py +++ b/tests/queryset/test_transform.py @@ -1,10 +1,10 @@ import unittest from bson.son import SON +import pytest from mongoengine import * from mongoengine.queryset import Q, transform -import pytest class TestTransform(unittest.TestCase): diff --git a/tests/queryset/test_visitor.py b/tests/queryset/test_visitor.py index 9706d012..e597e3d8 100644 --- a/tests/queryset/test_visitor.py +++ b/tests/queryset/test_visitor.py @@ -3,11 +3,11 @@ import re import unittest from bson import ObjectId +import pytest from mongoengine import * from mongoengine.errors import InvalidQueryError from mongoengine.queryset import Q -import pytest class TestQ(unittest.TestCase): diff --git a/tests/test_connection.py b/tests/test_connection.py index 8db69b0c..07edcbba 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -3,10 +3,10 @@ import datetime from bson.tz_util import utc from nose.plugins.skip import SkipTest import pymongo -from pymongo import MongoClient -from pymongo import ReadPreference -from pymongo.errors import InvalidName, OperationFailure +from pymongo import MongoClient, ReadPreference +from pymongo.errors import InvalidName, OperationFailure +import pytest try: import unittest2 as unittest @@ -29,7 +29,6 @@ from mongoengine.connection import ( get_connection, get_db, ) -import pytest def get_tz_awareness(connection): diff --git a/tests/test_context_managers.py b/tests/test_context_managers.py index cf4dd100..d68afbb0 100644 --- a/tests/test_context_managers.py +++ b/tests/test_context_managers.py @@ -1,5 +1,7 @@ import unittest +import pytest + from mongoengine import * from mongoengine.connection import get_db from mongoengine.context_managers import ( @@ -10,7 +12,6 @@ from mongoengine.context_managers import ( switch_db, ) from mongoengine.pymongo_support import count_documents -import pytest class ContextManagersTest(unittest.TestCase): @@ -214,7 +215,6 @@ class ContextManagersTest(unittest.TestCase): raise TypeError() def test_query_counter_does_not_swallow_exception(self): - with pytest.raises(TypeError): with query_counter() as q: raise TypeError() diff --git a/tests/test_datastructures.py b/tests/test_datastructures.py index 24cda40d..ad421a72 100644 --- a/tests/test_datastructures.py +++ b/tests/test_datastructures.py @@ -1,7 +1,7 @@ import unittest -from six import iterkeys import pytest +from six import iterkeys from mongoengine import Document from mongoengine.base.datastructures import BaseDict, BaseList, StrictDict diff --git a/tests/test_replicaset_connection.py b/tests/test_replicaset_connection.py index c1ea407c..5d83da00 100644 --- a/tests/test_replicaset_connection.py +++ b/tests/test_replicaset_connection.py @@ -1,7 +1,6 @@ import unittest -from pymongo import MongoClient -from pymongo import ReadPreference +from pymongo import MongoClient, ReadPreference import mongoengine from mongoengine.connection import ConnectionFailure @@ -25,14 +24,13 @@ class ConnectionTest(unittest.TestCase): def test_replicaset_uri_passes_read_preference(self): """Requires a replica set called "rs" on port 27017 """ - try: conn = mongoengine.connect( db="mongoenginetest", host="mongodb://localhost/mongoenginetest?replicaSet=rs", read_preference=READ_PREF, ) - except ConnectionFailure as e: + except ConnectionFailure: return if not isinstance(conn, CONN_CLASS): diff --git a/tests/test_utils.py b/tests/test_utils.py index ccb44aac..ef396571 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,9 +1,10 @@ import re import unittest -from mongoengine.base.utils import LazyRegexCompiler import pytest +from mongoengine.base.utils import LazyRegexCompiler + signal_output = [] From 799cdafae63b3ac22aafed4400bff10796a8ffee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Sun, 1 Sep 2019 15:27:11 +0300 Subject: [PATCH 32/59] remove references to nose --- CONTRIBUTING.rst | 2 +- requirements.txt | 1 - setup.py | 2 +- tests/document/test_indexes.py | 3 +-- tests/fields/test_email_field.py | 6 ------ tests/fields/test_fields.py | 13 ++++++------- tests/fields/test_file_field.py | 29 +++++++++-------------------- tests/test_connection.py | 31 +++++++++++++++---------------- tests/utils.py | 10 ++++------ tox.ini | 1 - 10 files changed, 37 insertions(+), 61 deletions(-) diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst index 4711c1d3..56bae31f 100644 --- a/CONTRIBUTING.rst +++ b/CONTRIBUTING.rst @@ -54,7 +54,7 @@ General Guidelines should adapt to the breaking change in docs/upgrade.rst. - Write inline documentation for new classes and methods. - Write tests and make sure they pass (make sure you have a mongod - running on the default port, then execute ``python setup.py nosetests`` + running on the default port, then execute ``python setup.py test`` from the cmd line to run the test suite). - Ensure tests pass on all supported Python, PyMongo, and MongoDB versions. You can test various Python and PyMongo versions locally by executing diff --git a/requirements.txt b/requirements.txt index 46eabac3..43e5261b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,3 @@ -nose pymongo>=3.4 six==1.10.0 Sphinx==1.5.5 diff --git a/setup.py b/setup.py index 2bc1ae1c..939e8e50 100644 --- a/setup.py +++ b/setup.py @@ -120,7 +120,7 @@ extra_opts = { } if sys.version_info[0] == 3: extra_opts["use_2to3"] = True - if "test" in sys.argv or "nosetests" in sys.argv: + if "test" in sys.argv: extra_opts["packages"] = find_packages() extra_opts["package_data"] = { "tests": ["fields/mongoengine.png", "fields/mongodb_leaf.png"] diff --git a/tests/document/test_indexes.py b/tests/document/test_indexes.py index dc6c5c8e..90402c46 100644 --- a/tests/document/test_indexes.py +++ b/tests/document/test_indexes.py @@ -2,7 +2,6 @@ import unittest from datetime import datetime -from nose.plugins.skip import SkipTest from pymongo.collation import Collation from pymongo.errors import OperationFailure import pytest @@ -251,7 +250,7 @@ class TestIndexes(unittest.TestCase): def test_explicit_geohaystack_index(self): """Ensure that geohaystack indexes work when created via meta[indexes] """ - raise SkipTest( + pytest.skip( "GeoHaystack index creation is not supported for now" "from meta, as it requires a bucketSize parameter." ) diff --git a/tests/fields/test_email_field.py b/tests/fields/test_email_field.py index 55255df5..5a58ede4 100644 --- a/tests/fields/test_email_field.py +++ b/tests/fields/test_email_field.py @@ -1,6 +1,5 @@ # -*- coding: utf-8 -*- import sys -from unittest import SkipTest import pytest @@ -46,11 +45,6 @@ class TestEmailField(MongoDBTestCase): user.validate() def test_email_field_unicode_user(self): - # Don't run this test on pypy3, which doesn't support unicode regex: - # https://bitbucket.org/pypy/pypy/issues/1821/regular-expression-doesnt-find-unicode - if sys.version_info[:2] == (3, 2): - raise SkipTest("unicode email addresses are not supported on PyPy 3") - class User(Document): email = EmailField() diff --git a/tests/fields/test_fields.py b/tests/fields/test_fields.py index b8c916f8..652f6903 100644 --- a/tests/fields/test_fields.py +++ b/tests/fields/test_fields.py @@ -3,7 +3,6 @@ import datetime import unittest from bson import DBRef, ObjectId, SON -from nose.plugins.skip import SkipTest import pytest from mongoengine import ( @@ -1239,17 +1238,17 @@ class TestField(MongoDBTestCase): a = A._from_son(SON([("fb", SON([("fc", SON([("txt", "hi")]))]))])) assert a.b.c.txt == "hi" + @pytest.mark.xfail( + reason="Using a string reference in an EmbeddedDocumentField does not work if the class isnt registerd yet", + raises=NotRegistered, + ) def test_embedded_document_field_cant_reference_using_a_str_if_it_does_not_exist_yet( self, ): - raise SkipTest( - "Using a string reference in an EmbeddedDocumentField does not work if the class isnt registerd yet" - ) - class MyDoc2(Document): - emb = EmbeddedDocumentField("MyDoc") + emb = EmbeddedDocumentField("MyFunkyDoc123") - class MyDoc(EmbeddedDocument): + class MyFunkyDoc123(EmbeddedDocument): name = StringField() def test_embedded_document_validation(self): diff --git a/tests/fields/test_file_field.py b/tests/fields/test_file_field.py index fb8cacff..bfc86511 100644 --- a/tests/fields/test_file_field.py +++ b/tests/fields/test_file_field.py @@ -5,7 +5,7 @@ import tempfile import unittest import gridfs -from nose.plugins.skip import SkipTest +import pytest import six from mongoengine import * @@ -21,6 +21,8 @@ except ImportError: from tests.utils import MongoDBTestCase +require_pil = pytest.mark.skipif(not HAS_PIL, reason="PIL not installed") + TEST_IMAGE_PATH = os.path.join(os.path.dirname(__file__), "mongoengine.png") TEST_IMAGE2_PATH = os.path.join(os.path.dirname(__file__), "mongodb_leaf.png") @@ -377,10 +379,8 @@ class TestFileField(MongoDBTestCase): assert len(list(files)) == 0 assert len(list(chunks)) == 0 + @require_pil def test_image_field(self): - if not HAS_PIL: - raise SkipTest("PIL not installed") - class TestImage(Document): image = ImageField() @@ -411,10 +411,8 @@ class TestFileField(MongoDBTestCase): t.image.delete() + @require_pil def test_image_field_reassigning(self): - if not HAS_PIL: - raise SkipTest("PIL not installed") - class TestFile(Document): the_file = ImageField() @@ -428,10 +426,8 @@ class TestFileField(MongoDBTestCase): test_file.save() assert test_file.the_file.size == (45, 101) + @require_pil def test_image_field_resize(self): - if not HAS_PIL: - raise SkipTest("PIL not installed") - class TestImage(Document): image = ImageField(size=(185, 37)) @@ -451,10 +447,8 @@ class TestFileField(MongoDBTestCase): t.image.delete() + @require_pil def test_image_field_resize_force(self): - if not HAS_PIL: - raise SkipTest("PIL not installed") - class TestImage(Document): image = ImageField(size=(185, 37, True)) @@ -474,10 +468,8 @@ class TestFileField(MongoDBTestCase): t.image.delete() + @require_pil def test_image_field_thumbnail(self): - if not HAS_PIL: - raise SkipTest("PIL not installed") - class TestImage(Document): image = ImageField(thumbnail_size=(92, 18)) @@ -546,11 +538,8 @@ class TestFileField(MongoDBTestCase): assert putfile == copy.copy(putfile) assert putfile == copy.deepcopy(putfile) + @require_pil def test_get_image_by_grid_id(self): - - if not HAS_PIL: - raise SkipTest("PIL not installed") - class TestImage(Document): image1 = ImageField() diff --git a/tests/test_connection.py b/tests/test_connection.py index 07edcbba..acaab904 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1,7 +1,6 @@ import datetime from bson.tz_util import utc -from nose.plugins.skip import SkipTest import pymongo from pymongo import MongoClient, ReadPreference @@ -35,6 +34,18 @@ def get_tz_awareness(connection): return connection.codec_options.tz_aware +try: + import mongomock + + MONGOMOCK_INSTALLED = True +except ImportError: + MONGOMOCK_INSTALLED = False + +require_mongomock = pytest.mark.skipif( + not MONGOMOCK_INSTALLED, reason="you need mongomock installed to run this testcase" +) + + class ConnectionTest(unittest.TestCase): @classmethod def setUpClass(cls): @@ -212,14 +223,10 @@ class ConnectionTest(unittest.TestCase): non_string_db_name = ["e. g. list instead of a string"] connect(non_string_db_name) + @require_mongomock def test_connect_in_mocking(self): """Ensure that the connect() method works properly in mocking. """ - try: - import mongomock - except ImportError: - raise SkipTest("you need mongomock installed to run this testcase") - connect("mongoenginetest", host="mongomock://localhost") conn = get_connection() assert isinstance(conn, mongomock.MongoClient) @@ -261,14 +268,10 @@ class ConnectionTest(unittest.TestCase): conn = get_connection("testdb7") assert isinstance(conn, mongomock.MongoClient) + @require_mongomock def test_default_database_with_mocking(self): """Ensure that the default database is correctly set when using mongomock. """ - try: - import mongomock - except ImportError: - raise SkipTest("you need mongomock installed to run this testcase") - disconnect_all() class SomeDocument(Document): @@ -281,16 +284,12 @@ class ConnectionTest(unittest.TestCase): assert conn.get_default_database().name == "mongoenginetest" assert conn.database_names()[0] == "mongoenginetest" + @require_mongomock def test_connect_with_host_list(self): """Ensure that the connect() method works when host is a list Uses mongomock to test w/o needing multiple mongod/mongos processes """ - try: - import mongomock - except ImportError: - raise SkipTest("you need mongomock installed to run this testcase") - connect(host=["mongomock://localhost"]) conn = get_connection() assert isinstance(conn, mongomock.MongoClient) diff --git a/tests/utils.py b/tests/utils.py index 0719d6ef..7ee22c3c 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,7 +1,6 @@ -import operator import unittest -from nose.plugins.skip import SkipTest +import pytest from mongoengine import connect from mongoengine.connection import disconnect_all, get_db @@ -37,7 +36,7 @@ def get_as_pymongo(doc): def _decorated_with_ver_requirement(func, mongo_version_req, oper): """Return a MongoDB version requirement decorator. - The resulting decorator will raise a SkipTest exception if the current + The resulting decorator will skip the test if the current MongoDB version doesn't match the provided version/operator. For example, if you define a decorator like so: @@ -59,9 +58,8 @@ def _decorated_with_ver_requirement(func, mongo_version_req, oper): if oper(mongodb_v, mongo_version_req): return func(*args, **kwargs) - raise SkipTest( - "Needs MongoDB v{}+".format(".".join(str(n) for n in mongo_version_req)) - ) + pretty_version = ".".join(str(n) for n in mongo_version_req) + pytest.skip("Needs MongoDB v{}+".format(pretty_version)) _inner.__name__ = func.__name__ _inner.__doc__ = func.__doc__ diff --git a/tox.ini b/tox.ini index 94ccc9cf..349b5577 100644 --- a/tox.ini +++ b/tox.ini @@ -5,7 +5,6 @@ envlist = {py27,py35,pypy,pypy3}-{mg34,mg36} commands = python setup.py test {posargs} deps = - nose mg34: pymongo>=3.4,<3.5 mg36: pymongo>=3.6,<3.7 mg39: pymongo>=3.9,<4.0 From d8924ed8920f856d3754db38c7c5e8adf0f96ece Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Mon, 2 Sep 2019 08:50:46 +0300 Subject: [PATCH 33/59] remove inheritance from unittest.TestCase on basic test classes --- tests/queryset/test_field_list.py | 2 +- tests/test_common.py | 2 +- tests/test_context_managers.py | 2 +- tests/test_datastructures.py | 10 +++++----- tests/test_signals.py | 2 +- tests/test_utils.py | 2 +- 6 files changed, 10 insertions(+), 10 deletions(-) diff --git a/tests/queryset/test_field_list.py b/tests/queryset/test_field_list.py index fbdde23b..be7903fd 100644 --- a/tests/queryset/test_field_list.py +++ b/tests/queryset/test_field_list.py @@ -6,7 +6,7 @@ from mongoengine import * from mongoengine.queryset import QueryFieldList -class TestQueryFieldList(unittest.TestCase): +class TestQueryFieldList: def test_empty(self): q = QueryFieldList() assert not q diff --git a/tests/test_common.py b/tests/test_common.py index 6b6f18de..1779a91b 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -6,7 +6,7 @@ from mongoengine import Document from mongoengine.common import _import_class -class TestCommon(unittest.TestCase): +class TestCommon: def test__import_class(self): doc_cls = _import_class("Document") assert doc_cls is Document diff --git a/tests/test_context_managers.py b/tests/test_context_managers.py index d68afbb0..c10a0224 100644 --- a/tests/test_context_managers.py +++ b/tests/test_context_managers.py @@ -14,7 +14,7 @@ from mongoengine.context_managers import ( from mongoengine.pymongo_support import count_documents -class ContextManagersTest(unittest.TestCase): +class TestContextManagers: def test_switch_db_context_manager(self): connect("mongoenginetest") register_connection("testdb-1", "mongoenginetest2") diff --git a/tests/test_datastructures.py b/tests/test_datastructures.py index ad421a72..7b5d7d11 100644 --- a/tests/test_datastructures.py +++ b/tests/test_datastructures.py @@ -15,7 +15,7 @@ class DocumentStub(object): self._changed_fields.append(key) -class TestBaseDict(unittest.TestCase): +class TestBaseDict: @staticmethod def _get_basedict(dict_items): """Get a BaseList bound to a fake document instance""" @@ -151,7 +151,7 @@ class TestBaseDict(unittest.TestCase): assert base_dict._instance._changed_fields == ["my_name.a_new_attr"] -class TestBaseList(unittest.TestCase): +class TestBaseList: @staticmethod def _get_baselist(list_items): """Get a BaseList bound to a fake document instance""" @@ -360,12 +360,12 @@ class TestBaseList(unittest.TestCase): class TestStrictDict(unittest.TestCase): - def strict_dict_class(self, *args, **kwargs): - return StrictDict.create(*args, **kwargs) - def setUp(self): self.dtype = self.strict_dict_class(("a", "b", "c")) + def strict_dict_class(self, *args, **kwargs): + return StrictDict.create(*args, **kwargs) + def test_init(self): d = self.dtype(a=1, b=1, c=1) assert (d.a, d.b, d.c) == (1, 1, 1) diff --git a/tests/test_signals.py b/tests/test_signals.py index b217712b..d79eaf75 100644 --- a/tests/test_signals.py +++ b/tests/test_signals.py @@ -7,7 +7,7 @@ from mongoengine import signals signal_output = [] -class SignalTests(unittest.TestCase): +class TestSignal(unittest.TestCase): """ Testing signals before/after saving and deleting. """ diff --git a/tests/test_utils.py b/tests/test_utils.py index ef396571..dd178273 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -8,7 +8,7 @@ from mongoengine.base.utils import LazyRegexCompiler signal_output = [] -class LazyRegexCompilerTest(unittest.TestCase): +class TestLazyRegexCompiler: def test_lazy_regex_compiler_verify_laziness_of_descriptor(self): class UserEmail(object): EMAIL_REGEX = LazyRegexCompiler("@", flags=32) From 81647d67a0c97b3d0fac6cb687c385fc2827a108 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Thu, 31 Oct 2019 23:06:40 +0100 Subject: [PATCH 34/59] fix recent tests update with unittest2pytest --- tests/document/test_indexes.py | 16 ++++++---------- tests/queryset/test_queryset.py | 9 +++++---- tests/test_connection.py | 4 ++-- 3 files changed, 13 insertions(+), 16 deletions(-) diff --git a/tests/document/test_indexes.py b/tests/document/test_indexes.py index 90402c46..be857b59 100644 --- a/tests/document/test_indexes.py +++ b/tests/document/test_indexes.py @@ -544,23 +544,19 @@ class TestIndexes(unittest.TestCase): BlogPost(name=name).save() query_result = BlogPost.objects.collation(base).order_by("name") - self.assertEqual( - [x.name for x in query_result], sorted(names, key=lambda x: x.lower()) - ) - self.assertEqual(5, query_result.count()) + assert [x.name for x in query_result] == sorted(names, key=lambda x: x.lower()) + assert 5 == query_result.count() query_result = BlogPost.objects.collation(Collation(**base)).order_by("name") - self.assertEqual( - [x.name for x in query_result], sorted(names, key=lambda x: x.lower()) - ) - self.assertEqual(5, query_result.count()) + assert [x.name for x in query_result] == sorted(names, key=lambda x: x.lower()) + assert 5 == query_result.count() incorrect_collation = {"arndom": "wrdo"} - with self.assertRaises(OperationFailure): + with pytest.raises(OperationFailure): BlogPost.objects.collation(incorrect_collation).count() query_result = BlogPost.objects.collation({}).order_by("name") - self.assertEqual([x.name for x in query_result], sorted(names)) + assert [x.name for x in query_result] == sorted(names) def test_unique(self): """Ensure that uniqueness constraints are applied to fields. diff --git a/tests/queryset/test_queryset.py b/tests/queryset/test_queryset.py index 79f5793d..7812ab66 100644 --- a/tests/queryset/test_queryset.py +++ b/tests/queryset/test_queryset.py @@ -4626,7 +4626,8 @@ class TestQueryset(unittest.TestCase): bars = Bar.objects.read_preference(ReadPreference.SECONDARY_PREFERRED) assert bars._read_preference == ReadPreference.SECONDARY_PREFERRED assert ( - bars._cursor.collection.read_preference == ReadPreference.SECONDARY_PREFERRED + bars._cursor.collection.read_preference + == ReadPreference.SECONDARY_PREFERRED ) # Make sure that `.read_preference(...)` does accept string values. @@ -5765,13 +5766,13 @@ class TestQueryset(unittest.TestCase): def test_no_cursor_timeout(self): qs = self.Person.objects() - self.assertEqual(qs._cursor_args, {}) # ensure no regression of #2148 + assert qs._cursor_args == {} # ensure no regression of #2148 qs = self.Person.objects().timeout(True) - self.assertEqual(qs._cursor_args, {}) + assert qs._cursor_args == {} qs = self.Person.objects().timeout(False) - self.assertEqual(qs._cursor_args, {"no_cursor_timeout": True}) + assert qs._cursor_args == {"no_cursor_timeout": True} if __name__ == "__main__": diff --git a/tests/test_connection.py b/tests/test_connection.py index acaab904..e40a6994 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -625,8 +625,8 @@ class ConnectionTest(unittest.TestCase): alias="conn1", host="mongodb://localhost/testing?w=1&journal=true" ) conn2 = connect("testing", alias="conn2", w=1, journal=True) - self.assertEqual(conn1.write_concern.document, {"w": 1, "j": True}) - self.assertEqual(conn2.write_concern.document, {"w": 1, "j": True}) + assert conn1.write_concern.document == {"w": 1, "j": True} + assert conn2.write_concern.document == {"w": 1, "j": True} def test_connect_with_replicaset_via_uri(self): """Ensure connect() works when specifying a replicaSet via the From ff749a7a0a7b9a86b3745cad393effcf594db5f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Philip=20G=C3=B6pfert?= Date: Wed, 6 Nov 2019 10:35:16 +0100 Subject: [PATCH 35/59] Specify version of requirement In `README.rst`, a version of `six` of at least `1.10.0` is specified. This was missing from the requirements, potentially leading to broken installations. --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 939e8e50..ceb5afad 100644 --- a/setup.py +++ b/setup.py @@ -143,7 +143,7 @@ setup( long_description=LONG_DESCRIPTION, platforms=["any"], classifiers=CLASSIFIERS, - install_requires=["pymongo>=3.4", "six"], + install_requires=["pymongo>=3.4", "six>=1.10.0"], cmdclass={"test": PyTest}, **extra_opts ) From d3420918cd9804243900c8566b7e155044e668cb Mon Sep 17 00:00:00 2001 From: Eloi Zalczer Date: Mon, 18 Nov 2019 17:16:06 +0100 Subject: [PATCH 36/59] Added alias parameter in query_counter --- mongoengine/context_managers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mongoengine/context_managers.py b/mongoengine/context_managers.py index d8dfeaac..5920b724 100644 --- a/mongoengine/context_managers.py +++ b/mongoengine/context_managers.py @@ -182,10 +182,10 @@ class query_counter(object): - Some queries are ignored by default by the counter (killcursors, db.system.indexes) """ - def __init__(self): + def __init__(self, alias=DEFAULT_CONNECTION_NAME): """Construct the query_counter """ - self.db = get_db() + self.db = get_db(alias=alias) self.initial_profiling_level = None self._ctx_query_counter = 0 # number of queries issued by the context From 0bf08db7b943eba85d7e0dd85d161df4e615a371 Mon Sep 17 00:00:00 2001 From: Eloi Zalczer Date: Mon, 2 Dec 2019 10:07:33 +0100 Subject: [PATCH 37/59] Added test case for query_counter alias --- tests/document/test_instance.py | 40 ++++++++++++++++++++++++++++++++- 1 file changed, 39 insertions(+), 1 deletion(-) diff --git a/tests/document/test_instance.py b/tests/document/test_instance.py index 1d3e18d0..c8ad2ff3 100644 --- a/tests/document/test_instance.py +++ b/tests/document/test_instance.py @@ -2825,6 +2825,44 @@ class TestInstance(MongoDBTestCase): assert "testdb-1" == B._meta.get("db_alias") + def test_query_counter_alias(self): + """query_counter works properly with db aliases?""" + # Register a connection with db_alias testdb-1 + register_connection("testdb-1", "mongoenginetest2") + + class A(Document): + """Uses default db_alias + """ + + name = StringField() + + class B(Document): + """Uses testdb-1 db_alias + """ + + name = StringField() + meta = {"db_alias": "testdb-1"} + + with query_counter() as q: + assert q == 0 + a = A.objects.create(name="A") + assert q == 1 + a = A.objects.first() + assert q == 2 + a.name = "Test A" + a.save() + assert q == 3 + + with query_counter(alias="testdb-1") as q: + assert q == 0 + b = B.objects.create(name="B") + assert q == 1 + b = B.objects.first() + assert q == 2 + b.name = "Test B" + b.save() + assert q == 3 + def test_db_ref_usage(self): """DB Ref usage in dict_fields.""" @@ -3644,7 +3682,7 @@ class TestInstance(MongoDBTestCase): User.objects().select_related() def test_embedded_document_failed_while_loading_instance_when_it_is_not_a_dict( - self + self, ): class LightSaber(EmbeddedDocument): color = StringField() From 0458ef869eb07d98c2ebb4da82dc3ca0bcd94a49 Mon Sep 17 00:00:00 2001 From: Filip Kucharczyk Date: Tue, 3 Dec 2019 00:42:10 +0100 Subject: [PATCH 38/59] Add __eq__ to Q and Q operations --- mongoengine/queryset/visitor.py | 12 ++++++++++++ tests/queryset/test_visitor.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+) diff --git a/mongoengine/queryset/visitor.py b/mongoengine/queryset/visitor.py index 0fe139fd..058c722a 100644 --- a/mongoengine/queryset/visitor.py +++ b/mongoengine/queryset/visitor.py @@ -96,9 +96,11 @@ class QNode(object): """Combine this node with another node into a QCombination object. """ + # If the other Q() is empty, ignore it and just use `self`. if getattr(other, "empty", True): return self + # Or if this Q is empty, ignore it and just use `other`. if self.empty: return other @@ -146,6 +148,13 @@ class QCombination(QNode): def empty(self): return not bool(self.children) + def __eq__(self, other): + return ( + self.__class__ == other.__class__ + and self.operation == other.operation + and self.children == other.children + ) + class Q(QNode): """A simple query object, used in a query tree to build up more complex @@ -164,3 +173,6 @@ class Q(QNode): @property def empty(self): return not bool(self.query) + + def __eq__(self, other): + return self.__class__ == other.__class__ and self.query == other.query diff --git a/tests/queryset/test_visitor.py b/tests/queryset/test_visitor.py index e597e3d8..e8504abd 100644 --- a/tests/queryset/test_visitor.py +++ b/tests/queryset/test_visitor.py @@ -374,6 +374,38 @@ class TestQ(unittest.TestCase): == 2 ) + def test_equality(self): + assert Q(name="John") == Q(name="John") + assert Q() == Q() + + def test_inequality(self): + assert Q(name="John") != Q(name="Ralph") + + def test_operation_equality(self): + q1 = Q(name="John") | Q(title="Sir") & Q(surname="Paul") + q2 = Q(name="John") | Q(title="Sir") & Q(surname="Paul") + assert q1 == q2 + + def test_operation_inequality(self): + q1 = Q(name="John") | Q(title="Sir") + q2 = Q(title="Sir") | Q(name="John") + assert q1 != q2 + + def test_combine_and_empty(self): + q = Q(x=1) + assert q & Q() == q + assert Q() & q == q + + def test_combine_and_both_empty(self): + assert Q() & Q() == Q() + + def test_combine_or_empty(self): + q = Q(x=1) + assert q | Q() == q + assert Q() | q == q + + def test_combine_or_both_empty(self): + assert Q() | Q() == Q() if __name__ == "__main__": unittest.main() From 091238a2cfd3e77fba724ad8264bae78c360c675 Mon Sep 17 00:00:00 2001 From: Filip Kucharczyk Date: Tue, 3 Dec 2019 00:54:46 +0100 Subject: [PATCH 39/59] Update Authors --- AUTHORS | 1 + 1 file changed, 1 insertion(+) diff --git a/AUTHORS b/AUTHORS index 45a754cc..aa044bd2 100644 --- a/AUTHORS +++ b/AUTHORS @@ -252,3 +252,4 @@ that much better: * Paulo Amaral (https://github.com/pauloAmaral) * Gaurav Dadhania (https://github.com/GVRV) * Yurii Andrieiev (https://github.com/yandrieiev) + * Filip Kucharczyk (https://github.com/Pacu2) From f7f0e10d4d3748381007617a119758e40bdd76bb Mon Sep 17 00:00:00 2001 From: Filip Kucharczyk Date: Tue, 3 Dec 2019 00:54:53 +0100 Subject: [PATCH 40/59] Update changelog --- docs/changelog.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/changelog.rst b/docs/changelog.rst index 0b4893a6..102e826d 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -23,6 +23,7 @@ Development - Switch from nosetest to pytest as test runner #2114 - The codebase is now formatted using ``black``. #2109 - In bulk write insert, the detailed error message would raise in exception. +- Added ability to compare Q and Q operations #2204 Changes in 0.18.2 ================= From 3f75f30f2675375fef0bf14fdbff63480676e056 Mon Sep 17 00:00:00 2001 From: Filip Kucharczyk Date: Tue, 3 Dec 2019 09:03:49 +0100 Subject: [PATCH 41/59] Run black --- tests/queryset/test_visitor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/queryset/test_visitor.py b/tests/queryset/test_visitor.py index e8504abd..afa00839 100644 --- a/tests/queryset/test_visitor.py +++ b/tests/queryset/test_visitor.py @@ -407,5 +407,6 @@ class TestQ(unittest.TestCase): def test_combine_or_both_empty(self): assert Q() | Q() == Q() + if __name__ == "__main__": unittest.main() From af82c07acc3226f7ed65818536680343f2fd83c6 Mon Sep 17 00:00:00 2001 From: Filip Kucharczyk Date: Tue, 3 Dec 2019 09:19:02 +0100 Subject: [PATCH 42/59] Reformat with black --- tests/document/test_instance.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/document/test_instance.py b/tests/document/test_instance.py index 1d3e18d0..173e02f2 100644 --- a/tests/document/test_instance.py +++ b/tests/document/test_instance.py @@ -3644,7 +3644,7 @@ class TestInstance(MongoDBTestCase): User.objects().select_related() def test_embedded_document_failed_while_loading_instance_when_it_is_not_a_dict( - self + self, ): class LightSaber(EmbeddedDocument): color = StringField() From 78b240b740b34de450d30a00c669a9283a8b37de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Wed, 4 Dec 2019 21:49:17 +0100 Subject: [PATCH 43/59] updated changelog + improved query_counter test --- docs/changelog.rst | 1 + mongoengine/context_managers.py | 2 +- tests/document/test_instance.py | 38 --------------------------- tests/test_context_managers.py | 46 +++++++++++++++++++++++++++++++++ 4 files changed, 48 insertions(+), 39 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 102e826d..99081957 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -24,6 +24,7 @@ Development - The codebase is now formatted using ``black``. #2109 - In bulk write insert, the detailed error message would raise in exception. - Added ability to compare Q and Q operations #2204 +- Added ability to use a db alias on query_counter #2194 Changes in 0.18.2 ================= diff --git a/mongoengine/context_managers.py b/mongoengine/context_managers.py index 5920b724..1592ceef 100644 --- a/mongoengine/context_managers.py +++ b/mongoengine/context_managers.py @@ -171,7 +171,7 @@ class no_sub_classes(object): class query_counter(object): """Query_counter context manager to get the number of queries. This works by updating the `profiling_level` of the database so that all queries get logged, - resetting the db.system.profile collection at the beginnig of the context and counting the new entries. + resetting the db.system.profile collection at the beginning of the context and counting the new entries. This was designed for debugging purpose. In fact it is a global counter so queries issued by other threads/processes can interfere with it diff --git a/tests/document/test_instance.py b/tests/document/test_instance.py index c8ad2ff3..173e02f2 100644 --- a/tests/document/test_instance.py +++ b/tests/document/test_instance.py @@ -2825,44 +2825,6 @@ class TestInstance(MongoDBTestCase): assert "testdb-1" == B._meta.get("db_alias") - def test_query_counter_alias(self): - """query_counter works properly with db aliases?""" - # Register a connection with db_alias testdb-1 - register_connection("testdb-1", "mongoenginetest2") - - class A(Document): - """Uses default db_alias - """ - - name = StringField() - - class B(Document): - """Uses testdb-1 db_alias - """ - - name = StringField() - meta = {"db_alias": "testdb-1"} - - with query_counter() as q: - assert q == 0 - a = A.objects.create(name="A") - assert q == 1 - a = A.objects.first() - assert q == 2 - a.name = "Test A" - a.save() - assert q == 3 - - with query_counter(alias="testdb-1") as q: - assert q == 0 - b = B.objects.create(name="B") - assert q == 1 - b = B.objects.first() - assert q == 2 - b.name = "Test B" - b.save() - assert q == 3 - def test_db_ref_usage(self): """DB Ref usage in dict_fields.""" diff --git a/tests/test_context_managers.py b/tests/test_context_managers.py index c10a0224..fa3f5960 100644 --- a/tests/test_context_managers.py +++ b/tests/test_context_managers.py @@ -282,6 +282,52 @@ class TestContextManagers: assert q < 1000 assert q <= int(q) + def test_query_counter_alias(self): + """query_counter works properly with db aliases?""" + # Register a connection with db_alias testdb-1 + register_connection("testdb-1", "mongoenginetest2") + + class A(Document): + """Uses default db_alias""" + + name = StringField() + + class B(Document): + """Uses testdb-1 db_alias""" + + name = StringField() + meta = {"db_alias": "testdb-1"} + + A.drop_collection() + B.drop_collection() + + with query_counter() as q: + assert q == 0 + A.objects.create(name="A") + assert q == 1 + a = A.objects.first() + assert q == 2 + a.name = "Test A" + a.save() + assert q == 3 + # querying the other db should'nt alter the counter + B.objects().first() + assert q == 3 + + with query_counter(alias="testdb-1") as q: + assert q == 0 + B.objects.create(name="B") + assert q == 1 + b = B.objects.first() + assert q == 2 + b.name = "Test B" + b.save() + assert b.name == "Test B" + assert q == 3 + # querying the other db should'nt alter the counter + A.objects().first() + assert q == 3 + def test_query_counter_counts_getmore_queries(self): connect("mongoenginetest") db = get_db() From cb77bb6b69bb80c79c97d2f0792d173fc2f443d4 Mon Sep 17 00:00:00 2001 From: Filip Kucharczyk Date: Thu, 5 Dec 2019 00:21:03 +0100 Subject: [PATCH 44/59] Implement __bool__ on Q and QCombination --- mongoengine/queryset/base.py | 2 +- mongoengine/queryset/visitor.py | 20 ++++++++++++++++---- tests/queryset/test_visitor.py | 11 +++++++++++ 3 files changed, 28 insertions(+), 5 deletions(-) diff --git a/mongoengine/queryset/base.py b/mongoengine/queryset/base.py index a648391e..c6f467cc 100644 --- a/mongoengine/queryset/base.py +++ b/mongoengine/queryset/base.py @@ -686,7 +686,7 @@ class BaseQuerySet(object): .. versionchanged:: 0.6 Raises InvalidQueryError if filter has been set """ queryset = self.clone() - if not queryset._query_obj.empty: + if queryset._query_obj: msg = "Cannot use a filter whilst using `with_id`" raise InvalidQueryError(msg) return queryset.filter(pk=object_id).first() diff --git a/mongoengine/queryset/visitor.py b/mongoengine/queryset/visitor.py index 058c722a..a7295ae5 100644 --- a/mongoengine/queryset/visitor.py +++ b/mongoengine/queryset/visitor.py @@ -2,6 +2,8 @@ import copy from mongoengine.errors import InvalidQueryError from mongoengine.queryset import transform +import warnings + __all__ = ("Q", "QNode") @@ -101,13 +103,15 @@ class QNode(object): return self # Or if this Q is empty, ignore it and just use `other`. - if self.empty: + if not self: return other return QCombination(operation, [self, other]) @property def empty(self): + msg = "'empty' property is deprecated in favour of using 'not bool(filter)" + warnings.warn(msg, DeprecationWarning) return False def __or__(self, other): @@ -137,6 +141,9 @@ class QCombination(QNode): op = " & " if self.operation is self.AND else " | " return "(%s)" % op.join([repr(node) for node in self.children]) + def __bool__(self): + return bool(self.children) + def accept(self, visitor): for i in range(len(self.children)): if isinstance(self.children[i], QNode): @@ -146,6 +153,8 @@ class QCombination(QNode): @property def empty(self): + msg = "'empty' property is deprecated in favour of using 'not bool(filter)" + warnings.warn(msg, DeprecationWarning) return not bool(self.children) def __eq__(self, other): @@ -167,12 +176,15 @@ class Q(QNode): def __repr__(self): return "Q(**%s)" % repr(self.query) + def __bool__(self): + return bool(self.query) + + def __eq__(self, other): + return self.__class__ == other.__class__ and self.query == other.query + def accept(self, visitor): return visitor.visit_query(self) @property def empty(self): return not bool(self.query) - - def __eq__(self, other): - return self.__class__ == other.__class__ and self.query == other.query diff --git a/tests/queryset/test_visitor.py b/tests/queryset/test_visitor.py index afa00839..81e0f253 100644 --- a/tests/queryset/test_visitor.py +++ b/tests/queryset/test_visitor.py @@ -407,6 +407,17 @@ class TestQ(unittest.TestCase): def test_combine_or_both_empty(self): assert Q() | Q() == Q() + def test_q_bool(self): + assert Q(name="John") + assert not Q() + + def test_combine_bool(self): + assert not Q() & Q() + assert Q() & Q(name="John") + assert Q(name="John") & Q() + assert Q() | Q(name="John") + assert Q(name="John") | Q() + if __name__ == "__main__": unittest.main() From bd6c52e025fbe60473a6f009b100eb4d8edbfe83 Mon Sep 17 00:00:00 2001 From: Filip Kucharczyk Date: Thu, 5 Dec 2019 00:30:03 +0100 Subject: [PATCH 45/59] Changelog --- docs/changelog.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/changelog.rst b/docs/changelog.rst index 99081957..e2ffa41e 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -25,6 +25,7 @@ Development - In bulk write insert, the detailed error message would raise in exception. - Added ability to compare Q and Q operations #2204 - Added ability to use a db alias on query_counter #2194 +- Added ability to check if Q or Q operations is empty by parsing them to bool #2210 Changes in 0.18.2 ================= From 5f14d958ac32925df18e757f77f729f3bfb79c5a Mon Sep 17 00:00:00 2001 From: Filip Kucharczyk Date: Thu, 5 Dec 2019 00:46:57 +0100 Subject: [PATCH 46/59] Sort imports --- mongoengine/queryset/visitor.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mongoengine/queryset/visitor.py b/mongoengine/queryset/visitor.py index a7295ae5..8038d23f 100644 --- a/mongoengine/queryset/visitor.py +++ b/mongoengine/queryset/visitor.py @@ -1,9 +1,8 @@ import copy +import warnings from mongoengine.errors import InvalidQueryError from mongoengine.queryset import transform -import warnings - __all__ = ("Q", "QNode") From 6e8196d475953f88bd70207e81234bc07e1526d0 Mon Sep 17 00:00:00 2001 From: Filip Kucharczyk Date: Thu, 5 Dec 2019 01:31:37 +0100 Subject: [PATCH 47/59] Python 2.x compatibility --- mongoengine/queryset/visitor.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mongoengine/queryset/visitor.py b/mongoengine/queryset/visitor.py index 8038d23f..7faed897 100644 --- a/mongoengine/queryset/visitor.py +++ b/mongoengine/queryset/visitor.py @@ -143,6 +143,8 @@ class QCombination(QNode): def __bool__(self): return bool(self.children) + __nonzero__ = __bool__ # For Py2 support + def accept(self, visitor): for i in range(len(self.children)): if isinstance(self.children[i], QNode): @@ -178,6 +180,8 @@ class Q(QNode): def __bool__(self): return bool(self.query) + __nonzero__ = __bool__ # For Py2 support + def __eq__(self, other): return self.__class__ == other.__class__ and self.query == other.query From 1b38309d70efc122720d9c5d3fcc6d362436ed62 Mon Sep 17 00:00:00 2001 From: Filip Kucharczyk Date: Fri, 6 Dec 2019 10:14:22 +0100 Subject: [PATCH 48/59] Revert 'empty' usage to it's previous state --- mongoengine/queryset/base.py | 2 +- mongoengine/queryset/visitor.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mongoengine/queryset/base.py b/mongoengine/queryset/base.py index c6f467cc..a648391e 100644 --- a/mongoengine/queryset/base.py +++ b/mongoengine/queryset/base.py @@ -686,7 +686,7 @@ class BaseQuerySet(object): .. versionchanged:: 0.6 Raises InvalidQueryError if filter has been set """ queryset = self.clone() - if queryset._query_obj: + if not queryset._query_obj.empty: msg = "Cannot use a filter whilst using `with_id`" raise InvalidQueryError(msg) return queryset.filter(pk=object_id).first() diff --git a/mongoengine/queryset/visitor.py b/mongoengine/queryset/visitor.py index 7faed897..470839c1 100644 --- a/mongoengine/queryset/visitor.py +++ b/mongoengine/queryset/visitor.py @@ -102,14 +102,14 @@ class QNode(object): return self # Or if this Q is empty, ignore it and just use `other`. - if not self: + if self.empty: return other return QCombination(operation, [self, other]) @property def empty(self): - msg = "'empty' property is deprecated in favour of using 'not bool(filter)" + msg = "'empty' property is deprecated in favour of using 'not bool(filter)'" warnings.warn(msg, DeprecationWarning) return False @@ -154,7 +154,7 @@ class QCombination(QNode): @property def empty(self): - msg = "'empty' property is deprecated in favour of using 'not bool(filter)" + msg = "'empty' property is deprecated in favour of using 'not bool(filter)'" warnings.warn(msg, DeprecationWarning) return not bool(self.children) From e83132f32c254f04cb505d1ede1d90f0dac84b18 Mon Sep 17 00:00:00 2001 From: Filip Kucharczyk Date: Tue, 10 Dec 2019 11:51:33 +0100 Subject: [PATCH 49/59] Note deprecation of 'empty' in changelog --- docs/changelog.rst | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index e2ffa41e..bc01a403 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -17,6 +17,9 @@ Development - If you catch/use ``MongoEngineConnectionError`` in your code, you'll have to rename it. - BREAKING CHANGE: Positional arguments when instantiating a document are no longer supported. #2103 - From now on keyword arguments (e.g. ``Doc(field_name=value)``) are required. +- DEPRECATION: ``Q.empty`` & ``QNode.empty`` are marked as deprecated and will be removed in a next version of MongoEngine. #2210 + - Added ability to check if Q or QNode are empty by parsing them to bool. + - Instead of ``Q(name="John").empty`` use ``not Q(name="John")``. - Improve error message related to InvalidDocumentError #2180 - Fix updating/modifying/deleting/reloading a document that's sharded by a field with ``db_field`` specified. #2125 - ``ListField`` now accepts an optional ``max_length`` parameter. #2110 @@ -25,7 +28,6 @@ Development - In bulk write insert, the detailed error message would raise in exception. - Added ability to compare Q and Q operations #2204 - Added ability to use a db alias on query_counter #2194 -- Added ability to check if Q or Q operations is empty by parsing them to bool #2210 Changes in 0.18.2 ================= From 3b099f936a02444b3bf02c7dcdf13b1f2fc3b895 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Fri, 13 Dec 2019 21:32:45 +0100 Subject: [PATCH 50/59] provide additional details on how inheritance works in doc --- docs/guide/defining-documents.rst | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/docs/guide/defining-documents.rst b/docs/guide/defining-documents.rst index 9dcca88c..652c5cd9 100644 --- a/docs/guide/defining-documents.rst +++ b/docs/guide/defining-documents.rst @@ -744,7 +744,7 @@ Document inheritance To create a specialised type of a :class:`~mongoengine.Document` you have defined, you may subclass it and add any extra fields or methods you may need. -As this is new class is not a direct subclass of +As this new class is not a direct subclass of :class:`~mongoengine.Document`, it will not be stored in its own collection; it will use the same collection as its superclass uses. This allows for more convenient and efficient retrieval of related documents -- all you need do is @@ -767,6 +767,27 @@ document.:: Setting :attr:`allow_inheritance` to True should also be used in :class:`~mongoengine.EmbeddedDocument` class in case you need to subclass it +When it comes to querying using :attr:`.objects()`, querying `Page.objects()` will query +both `Page` and `DatedPage` whereas querying `DatedPage` will only query the `DatedPage` documents. +Behind the scenes, MongoEngine deals with inheritance by adding a :attr:`_cls` attribute that contains +the class name in every documents. When a document is loaded, MongoEngine checks +it's :attr:`_cls` attribute and use that class to construct the instance.:: + + Page(title='a funky title').save() + DatedPage(title='another title', date=datetime.utcnow()).save() + + print(Page.objects().count()) # 2 + print(DatedPage.objects().count()) # 1 + + # print documents in their native form + # we remove 'id' to avoid polluting the output with unnecessary detail + qs = Page.objects.exclude('id').as_pymongo() + print(list(qs)) + # [ + # {'_cls': u 'Page', 'title': 'a funky title'}, + # {'_cls': u 'Page.DatedPage', 'title': u 'another title', 'date': datetime.datetime(2019, 12, 13, 20, 16, 59, 993000)} + # ] + Working with existing data -------------------------- As MongoEngine no longer defaults to needing :attr:`_cls`, you can quickly and From 280a73af3bea8e9232a5ebe761d451840f025135 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Sat, 14 Dec 2019 21:44:59 +0100 Subject: [PATCH 51/59] minor fix in doc of NULLIFY to improve #834 --- docs/guide/defining-documents.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/guide/defining-documents.rst b/docs/guide/defining-documents.rst index 9dcca88c..82388d3d 100644 --- a/docs/guide/defining-documents.rst +++ b/docs/guide/defining-documents.rst @@ -352,7 +352,7 @@ Its value can take any of the following constants: Deletion is denied if there still exist references to the object being deleted. :const:`mongoengine.NULLIFY` - Any object's fields still referring to the object being deleted are removed + Any object's fields still referring to the object being deleted are set to None (using MongoDB's "unset" operation), effectively nullifying the relationship. :const:`mongoengine.CASCADE` Any object containing fields that are referring to the object being deleted From 50882e5bb09b74faddfac3cb93afd278ea94ced2 Mon Sep 17 00:00:00 2001 From: Eric Timmons Date: Wed, 16 Oct 2019 09:49:40 -0400 Subject: [PATCH 52/59] Add failing test Test that __eq__ for EmbeddedDocuments with LazyReferenceFields works as expected. --- tests/document/test_instance.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/tests/document/test_instance.py b/tests/document/test_instance.py index 173e02f2..6ba6827e 100644 --- a/tests/document/test_instance.py +++ b/tests/document/test_instance.py @@ -3319,6 +3319,38 @@ class TestInstance(MongoDBTestCase): f1.ref # Dereferences lazily assert f1 == f2 + def test_embedded_document_equality_with_lazy_ref(self): + class Job(EmbeddedDocument): + boss = LazyReferenceField('Person') + + class Person(Document): + job = EmbeddedDocumentField(Job) + + Person.drop_collection() + + boss = Person() + worker = Person(job=Job(boss=boss)) + boss.save() + worker.save() + + worker1 = Person.objects.get(id=worker.id) + + # worker1.job should be equal to the job used originally to create the + # document. + self.assertEqual(worker1.job, worker.job) + + # worker1.job should be equal to a newly created Job EmbeddedDocument + # using either the Boss object or his ID. + self.assertEqual(worker1.job, Job(boss=boss)) + self.assertEqual(worker1.job, Job(boss=boss.id)) + + # The above equalities should also hold after worker1.job.boss has been + # fetch()ed. + worker1.job.boss.fetch() + self.assertEqual(worker1.job, worker.job) + self.assertEqual(worker1.job, Job(boss=boss)) + self.assertEqual(worker1.job, Job(boss=boss.id)) + def test_dbref_equality(self): class Test2(Document): name = StringField() From dc7b96a5691335e970b13fb30ef62426b126e2bd Mon Sep 17 00:00:00 2001 From: Eric Timmons Date: Wed, 16 Oct 2019 09:50:47 -0400 Subject: [PATCH 53/59] Make python value for LazyReferenceFields be a DBRef Previously, when reading a LazyReferenceField from the DB, it was stored internally in the parent document's _data field as an ObjectId. However, this meant that equality tests using an enclosing EmbeddedDocument would not return True when the EmbeddedDocument being compared to contained a DBRef or Document in _data. Enclosing Documents were largely unaffected because they look at the primary key for equality (which EmbeddedDocuments lack). This makes the internal Python representation of a LazyReferenceField (before the LazyReference itself has been constructed) a DBRef, using code identical to ReferenceField. --- mongoengine/fields.py | 9 +++++++++ tests/document/test_instance.py | 2 +- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/mongoengine/fields.py b/mongoengine/fields.py index f8f527a3..0c29d1bc 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -2502,6 +2502,15 @@ class LazyReferenceField(BaseField): else: return pk + def to_python(self, value): + """Convert a MongoDB-compatible type to a Python type.""" + if not self.dbref and not isinstance( + value, (DBRef, Document, EmbeddedDocument) + ): + collection = self.document_type._get_collection_name() + value = DBRef(collection, self.document_type.id.to_python(value)) + return value + def validate(self, value): if isinstance(value, LazyReference): if value.collection != self.document_type._get_collection_name(): diff --git a/tests/document/test_instance.py b/tests/document/test_instance.py index 6ba6827e..07376b4b 100644 --- a/tests/document/test_instance.py +++ b/tests/document/test_instance.py @@ -3321,7 +3321,7 @@ class TestInstance(MongoDBTestCase): def test_embedded_document_equality_with_lazy_ref(self): class Job(EmbeddedDocument): - boss = LazyReferenceField('Person') + boss = LazyReferenceField("Person") class Person(Document): job = EmbeddedDocumentField(Job) From 0d4e61d489a9264863cecdfed08fc9e67a74d03a Mon Sep 17 00:00:00 2001 From: Eric Timmons Date: Wed, 16 Oct 2019 10:01:19 -0400 Subject: [PATCH 54/59] Add daewok to AUTHORS per contributing guidelines --- AUTHORS | 1 + 1 file changed, 1 insertion(+) diff --git a/AUTHORS b/AUTHORS index aa044bd2..374e2f7f 100644 --- a/AUTHORS +++ b/AUTHORS @@ -253,3 +253,4 @@ that much better: * Gaurav Dadhania (https://github.com/GVRV) * Yurii Andrieiev (https://github.com/yandrieiev) * Filip Kucharczyk (https://github.com/Pacu2) + * Eric Timmons (https://github.com/daewok) From 68dc2925fbea13702fa23ced0afd786d77b2ca28 Mon Sep 17 00:00:00 2001 From: Eric Timmons Date: Sun, 15 Dec 2019 12:08:04 -0500 Subject: [PATCH 55/59] Add LazyReferenceField with dbref=True to embedded_document equality test --- tests/document/test_instance.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/document/test_instance.py b/tests/document/test_instance.py index 07376b4b..b899684f 100644 --- a/tests/document/test_instance.py +++ b/tests/document/test_instance.py @@ -3322,6 +3322,7 @@ class TestInstance(MongoDBTestCase): def test_embedded_document_equality_with_lazy_ref(self): class Job(EmbeddedDocument): boss = LazyReferenceField("Person") + boss_dbref = LazyReferenceField("Person", dbref=True) class Person(Document): job = EmbeddedDocumentField(Job) @@ -3329,7 +3330,7 @@ class TestInstance(MongoDBTestCase): Person.drop_collection() boss = Person() - worker = Person(job=Job(boss=boss)) + worker = Person(job=Job(boss=boss, boss_dbref=boss)) boss.save() worker.save() @@ -3341,15 +3342,15 @@ class TestInstance(MongoDBTestCase): # worker1.job should be equal to a newly created Job EmbeddedDocument # using either the Boss object or his ID. - self.assertEqual(worker1.job, Job(boss=boss)) - self.assertEqual(worker1.job, Job(boss=boss.id)) + self.assertEqual(worker1.job, Job(boss=boss, boss_dbref=boss)) + self.assertEqual(worker1.job, Job(boss=boss.id, boss_dbref=boss.id)) # The above equalities should also hold after worker1.job.boss has been # fetch()ed. worker1.job.boss.fetch() self.assertEqual(worker1.job, worker.job) - self.assertEqual(worker1.job, Job(boss=boss)) - self.assertEqual(worker1.job, Job(boss=boss.id)) + self.assertEqual(worker1.job, Job(boss=boss, boss_dbref=boss)) + self.assertEqual(worker1.job, Job(boss=boss.id, boss_dbref=boss.id)) def test_dbref_equality(self): class Test2(Document): From 329f030a41da4d93aaec1f3ccee634e898f7d289 Mon Sep 17 00:00:00 2001 From: Eric Timmons Date: Sun, 15 Dec 2019 20:15:13 -0500 Subject: [PATCH 56/59] Always store a DBRef, Document, or EmbeddedDocument in LazyReferenceField._data This is required to handle the case of equality tests on a LazyReferenceField with dbref=True when comparing against a field instantiated with an ObjectId. --- mongoengine/fields.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 0c29d1bc..a385559d 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -2504,9 +2504,7 @@ class LazyReferenceField(BaseField): def to_python(self, value): """Convert a MongoDB-compatible type to a Python type.""" - if not self.dbref and not isinstance( - value, (DBRef, Document, EmbeddedDocument) - ): + if not isinstance(value, (DBRef, Document, EmbeddedDocument)): collection = self.document_type._get_collection_name() value = DBRef(collection, self.document_type.id.to_python(value)) return value From cfd4d6a161556ef4a8aa355468384554eb684442 Mon Sep 17 00:00:00 2001 From: Eric Timmons Date: Sun, 15 Dec 2019 12:02:24 -0500 Subject: [PATCH 57/59] Add breaking change to changelog for LazyReferenceField representation in _data --- docs/changelog.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/changelog.rst b/docs/changelog.rst index bc01a403..b308c5fb 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -17,6 +17,7 @@ Development - If you catch/use ``MongoEngineConnectionError`` in your code, you'll have to rename it. - BREAKING CHANGE: Positional arguments when instantiating a document are no longer supported. #2103 - From now on keyword arguments (e.g. ``Doc(field_name=value)``) are required. +- BREAKING CHANGE: A ``LazyReferenceField`` is now stored in the ``_data`` field of its parent as a ``DBRef``, ``Document``, or ``EmbeddedDocument`` (``ObjectId`` is no longer allowed). #2182 - DEPRECATION: ``Q.empty`` & ``QNode.empty`` are marked as deprecated and will be removed in a next version of MongoEngine. #2210 - Added ability to check if Q or QNode are empty by parsing them to bool. - Instead of ``Q(name="John").empty`` use ``not Q(name="John")``. From 332bd767d43af14bbb783e779d585cd2dbcf21de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Fri, 20 Dec 2019 22:51:08 +0100 Subject: [PATCH 58/59] minor fixes in tests --- docs/guide/mongomock.rst | 6 +++--- tests/document/test_instance.py | 22 +++++++++++----------- tests/fields/test_file_field.py | 2 +- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/docs/guide/mongomock.rst b/docs/guide/mongomock.rst index d70ee6a6..040ff912 100644 --- a/docs/guide/mongomock.rst +++ b/docs/guide/mongomock.rst @@ -2,10 +2,10 @@ Use mongomock for testing ============================== -`mongomock `_ is a package to do just +`mongomock `_ is a package to do just what the name implies, mocking a mongo database. -To use with mongoengine, simply specify mongomock when connecting with +To use with mongoengine, simply specify mongomock when connecting with mongoengine: .. code-block:: python @@ -45,4 +45,4 @@ Example of test file: pers.save() fresh_pers = Person.objects().first() - self.assertEqual(fresh_pers.name, 'John') + assert fresh_pers.name == 'John' diff --git a/tests/document/test_instance.py b/tests/document/test_instance.py index b899684f..609d0690 100644 --- a/tests/document/test_instance.py +++ b/tests/document/test_instance.py @@ -3338,19 +3338,19 @@ class TestInstance(MongoDBTestCase): # worker1.job should be equal to the job used originally to create the # document. - self.assertEqual(worker1.job, worker.job) + assert worker1.job == worker.job # worker1.job should be equal to a newly created Job EmbeddedDocument # using either the Boss object or his ID. - self.assertEqual(worker1.job, Job(boss=boss, boss_dbref=boss)) - self.assertEqual(worker1.job, Job(boss=boss.id, boss_dbref=boss.id)) + assert worker1.job == Job(boss=boss, boss_dbref=boss) + assert worker1.job == Job(boss=boss.id, boss_dbref=boss.id) # The above equalities should also hold after worker1.job.boss has been # fetch()ed. worker1.job.boss.fetch() - self.assertEqual(worker1.job, worker.job) - self.assertEqual(worker1.job, Job(boss=boss, boss_dbref=boss)) - self.assertEqual(worker1.job, Job(boss=boss.id, boss_dbref=boss.id)) + assert worker1.job == worker.job + assert worker1.job == Job(boss=boss, boss_dbref=boss) + assert worker1.job == Job(boss=boss.id, boss_dbref=boss.id) def test_dbref_equality(self): class Test2(Document): @@ -3693,13 +3693,13 @@ class TestInstance(MongoDBTestCase): value = u"I_should_be_a_dict" coll.insert_one({"light_saber": value}) - with self.assertRaises(InvalidDocumentError) as cm: + with pytest.raises(InvalidDocumentError) as exc_info: list(Jedi.objects) - self.assertEqual( - str(cm.exception), - "Invalid data to create a `Jedi` instance.\nField 'light_saber' - The source SON object needs to be of type 'dict' but a '%s' was found" - % type(value), + assert str( + exc_info.value + ) == "Invalid data to create a `Jedi` instance.\nField 'light_saber' - The source SON object needs to be of type 'dict' but a '%s' was found" % type( + value ) diff --git a/tests/fields/test_file_field.py b/tests/fields/test_file_field.py index bfc86511..b8ece1a9 100644 --- a/tests/fields/test_file_field.py +++ b/tests/fields/test_file_field.py @@ -151,7 +151,7 @@ class TestFileField(MongoDBTestCase): result = StreamFile.objects.first() assert streamfile == result assert result.the_file.read() == text + more_text - # self.assertEqual(result.the_file.content_type, content_type) + # assert result.the_file.content_type == content_type result.the_file.seek(0) assert result.the_file.tell() == 0 assert result.the_file.read(len(text)) == text From 1170de1e8e30b976895c1c92cca134089dc5b806 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Fri, 20 Dec 2019 23:16:29 +0100 Subject: [PATCH 59/59] added explicit doc for order_by #2117 --- docs/guide/mongomock.rst | 2 +- docs/guide/querying.rst | 15 ++++++++++++++- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/docs/guide/mongomock.rst b/docs/guide/mongomock.rst index 040ff912..141d7b69 100644 --- a/docs/guide/mongomock.rst +++ b/docs/guide/mongomock.rst @@ -21,7 +21,7 @@ or with an alias: conn = get_connection('testdb') Example of test file: --------- +--------------------- .. code-block:: python import unittest diff --git a/docs/guide/querying.rst b/docs/guide/querying.rst index d64c169c..121325ae 100644 --- a/docs/guide/querying.rst +++ b/docs/guide/querying.rst @@ -222,6 +222,18 @@ keyword argument:: .. versionadded:: 0.4 +Sorting/Ordering results +======================== +It is possible to order the results by 1 or more keys using :meth:`~mongoengine.queryset.QuerySet.order_by`. +The order may be specified by prepending each of the keys by "+" or "-". Ascending order is assumed if there's no prefix.:: + + # Order by ascending date + blogs = BlogPost.objects().order_by('date') # equivalent to .order_by('+date') + + # Order by ascending date first, then descending title + blogs = BlogPost.objects().order_by('+date', '-title') + + Limiting and skipping results ============================= Just as with traditional ORMs, you may limit the number of results returned or @@ -585,7 +597,8 @@ cannot use the `$` syntax in keyword arguments it has been mapped to `S`:: ['database', 'mongodb'] From MongoDB version 2.6, push operator supports $position value which allows -to push values with index. +to push values with index:: + >>> post = BlogPost(title="Test", tags=["mongo"]) >>> post.save() >>> post.update(push__tags__0=["database", "code"])