diff --git a/tests/document/inheritance.py b/tests/document/inheritance.py index ab73824f..2897e1d1 100644 --- a/tests/document/inheritance.py +++ b/tests/document/inheritance.py @@ -251,19 +251,17 @@ class InheritanceTest(unittest.TestCase): self.assertEqual(classes, [Human]) def test_allow_inheritance(self): - """Ensure that inheritance may be disabled on simple classes and that - _cls and _subclasses will not be used. + """Ensure that inheritance is disabled by default on simple + classes and that _cls will not be used. """ - class Animal(Document): name = StringField() - def create_dog_class(): + # can't inherit because Animal didn't explicitly allow inheritance + with self.assertRaises(ValueError): class Dog(Animal): pass - self.assertRaises(ValueError, create_dog_class) - # Check that _cls etc aren't present on simple documents dog = Animal(name='dog').save() self.assertEqual(dog.to_mongo().keys(), ['_id', 'name']) @@ -273,17 +271,15 @@ class InheritanceTest(unittest.TestCase): self.assertFalse('_cls' in obj) def test_cant_turn_off_inheritance_on_subclass(self): - """Ensure if inheritance is on in a subclass you cant turn it off + """Ensure if inheritance is on in a subclass you cant turn it off. """ - class Animal(Document): name = StringField() meta = {'allow_inheritance': True} - def create_mammal_class(): + with self.assertRaises(ValueError): class Mammal(Animal): meta = {'allow_inheritance': False} - self.assertRaises(ValueError, create_mammal_class) def test_allow_inheritance_abstract_document(self): """Ensure that abstract documents can set inheritance rules and that @@ -296,10 +292,9 @@ class InheritanceTest(unittest.TestCase): class Animal(FinalDocument): name = StringField() - def create_mammal_class(): + with self.assertRaises(ValueError): class Mammal(Animal): pass - self.assertRaises(ValueError, create_mammal_class) # Check that _cls isn't present in simple documents doc = Animal(name='dog') @@ -358,29 +353,26 @@ class InheritanceTest(unittest.TestCase): self.assertEqual(berlin.pk, berlin.auto_id_0) def test_abstract_document_creation_does_not_fail(self): - class City(Document): continent = StringField() meta = {'abstract': True, 'allow_inheritance': False} + bkk = City(continent='asia') self.assertEqual(None, bkk.pk) # TODO: expected error? Shouldn't we create a new error type? - self.assertRaises(KeyError, lambda: setattr(bkk, 'pk', 1)) + with self.assertRaises(KeyError): + setattr(bkk, 'pk', 1) def test_allow_inheritance_embedded_document(self): - """Ensure embedded documents respect inheritance - """ - + """Ensure embedded documents respect inheritance.""" class Comment(EmbeddedDocument): content = StringField() - def create_special_comment(): + with self.assertRaises(ValueError): class SpecialComment(Comment): pass - self.assertRaises(ValueError, create_special_comment) - doc = Comment(content='test') self.assertFalse('_cls' in doc.to_mongo()) @@ -452,11 +444,11 @@ class InheritanceTest(unittest.TestCase): self.assertEqual(Guppy._get_collection_name(), 'fish') self.assertEqual(Human._get_collection_name(), 'human') - def create_bad_abstract(): + # ensure that a subclass of a non-abstract class can't be abstract + with self.assertRaises(ValueError): class EvilHuman(Human): evil = BooleanField(default=True) meta = {'abstract': True} - self.assertRaises(ValueError, create_bad_abstract) def test_abstract_embedded_documents(self): # 789: EmbeddedDocument shouldn't inherit abstract diff --git a/tests/document/instance.py b/tests/document/instance.py index 43cd2d68..b92bafa9 100644 --- a/tests/document/instance.py +++ b/tests/document/instance.py @@ -99,21 +99,18 @@ class InstanceTest(unittest.TestCase): self.assertEqual(options['size'], 4096) # Check that the document cannot be redefined with different options - def recreate_log_document(): - class Log(Document): - date = DateTimeField(default=datetime.now) - meta = { - 'max_documents': 11, - } - # Create the collection by accessing Document.objects - Log.objects - self.assertRaises(InvalidCollectionError, recreate_log_document) + class Log(Document): + date = DateTimeField(default=datetime.now) + meta = { + 'max_documents': 11, + } - Log.drop_collection() + # Accessing Document.objects creates the collection + with self.assertRaises(InvalidCollectionError): + Log.objects def test_capped_collection_default(self): - """Ensure that capped collections defaults work properly. - """ + """Ensure that capped collections defaults work properly.""" class Log(Document): date = DateTimeField(default=datetime.now) meta = { @@ -131,16 +128,14 @@ class InstanceTest(unittest.TestCase): self.assertEqual(options['size'], 10 * 2**20) # Check that the document with default value can be recreated - def recreate_log_document(): - class Log(Document): - date = DateTimeField(default=datetime.now) - meta = { - 'max_documents': 10, - } - # Create the collection by accessing Document.objects - Log.objects - recreate_log_document() - Log.drop_collection() + class Log(Document): + date = DateTimeField(default=datetime.now) + meta = { + 'max_documents': 10, + } + + # Create the collection by accessing Document.objects + Log.objects def test_capped_collection_no_max_size_problems(self): """Ensure that capped collections with odd max_size work properly. @@ -163,16 +158,14 @@ class InstanceTest(unittest.TestCase): self.assertTrue(options['size'] >= 10000) # Check that the document with odd max_size value can be recreated - def recreate_log_document(): - class Log(Document): - date = DateTimeField(default=datetime.now) - meta = { - 'max_size': 10000, - } - # Create the collection by accessing Document.objects - Log.objects - recreate_log_document() - Log.drop_collection() + class Log(Document): + date = DateTimeField(default=datetime.now) + meta = { + 'max_size': 10000, + } + + # Create the collection by accessing Document.objects + Log.objects def test_repr(self): """Ensure that unicode representation works @@ -353,14 +346,14 @@ class InstanceTest(unittest.TestCase): self.assertEqual(User._fields['username'].db_field, '_id') self.assertEqual(User._meta['id_field'], 'username') - def create_invalid_user(): - User(name='test').save() # no primary key field - self.assertRaises(ValidationError, create_invalid_user) + # test no primary key field + self.assertRaises(ValidationError, User(name='test').save) - def define_invalid_user(): + # define a subclass with a different primary key field than the + # parent + with self.assertRaises(ValueError): class EmailUser(User): email = StringField(primary_key=True) - self.assertRaises(ValueError, define_invalid_user) class EmailUser(User): email = StringField() @@ -410,9 +403,8 @@ class InstanceTest(unittest.TestCase): # and the NicePlace model not being imported in at query time. del(_document_registry['Place.NicePlace']) - def query_without_importing_nice_place(): + with self.assertRaises(NotRegistered): list(Place.objects.all()) - self.assertRaises(NotRegistered, query_without_importing_nice_place) def test_document_registry_regressions(self): @@ -794,8 +786,10 @@ class InstanceTest(unittest.TestCase): def test_modify_empty(self): doc = self.Person(name="bob", age=10).save() - self.assertRaises( - InvalidDocumentError, lambda: self.Person().modify(set__age=10)) + + with self.assertRaises(InvalidDocumentError): + self.Person().modify(set__age=10) + self.assertDbEqual([dict(doc.to_mongo())]) def test_modify_invalid_query(self): @@ -803,9 +797,8 @@ class InstanceTest(unittest.TestCase): doc2 = self.Person(name="jim", age=20).save() docs = [dict(doc1.to_mongo()), dict(doc2.to_mongo())] - self.assertRaises( - InvalidQueryError, - lambda: doc1.modify({'id': doc2.id}, set__value=20)) + with self.assertRaises(InvalidQueryError): + doc1.modify({'id': doc2.id}, set__value=20) self.assertDbEqual(docs) @@ -1289,12 +1282,11 @@ class InstanceTest(unittest.TestCase): def test_document_update(self): - def update_not_saved_raises(): + # try updating a non-saved document + with self.assertRaises(OperationError): person = self.Person(name='dcrosta') person.update(set__name='Dan Crosta') - self.assertRaises(OperationError, update_not_saved_raises) - author = self.Person(name='dcrosta') author.save() @@ -1304,19 +1296,17 @@ class InstanceTest(unittest.TestCase): p1 = self.Person.objects.first() self.assertEqual(p1.name, author.name) - def update_no_value_raises(): + # try sending an empty update + with self.assertRaises(OperationError): person = self.Person.objects.first() person.update() - self.assertRaises(OperationError, update_no_value_raises) - - def update_no_op_should_default_to_set(): - person = self.Person.objects.first() - person.update(name="Dan") - person.reload() - return person.name - - self.assertEqual("Dan", update_no_op_should_default_to_set()) + # update that doesn't explicitly specify an operator should default + # to 'set__' + person = self.Person.objects.first() + person.update(name="Dan") + person.reload() + self.assertEqual("Dan", person.name) def test_update_unique_field(self): class Doc(Document): @@ -1325,8 +1315,8 @@ class InstanceTest(unittest.TestCase): doc1 = Doc(name="first").save() doc2 = Doc(name="second").save() - self.assertRaises(NotUniqueError, lambda: - doc2.update(set__name=doc1.name)) + with self.assertRaises(NotUniqueError): + doc2.update(set__name=doc1.name) def test_embedded_update(self): """ @@ -1844,15 +1834,13 @@ class InstanceTest(unittest.TestCase): def test_duplicate_db_fields_raise_invalid_document_error(self): """Ensure a InvalidDocumentError is thrown if duplicate fields - declare the same db_field""" - - def throw_invalid_document_error(): + declare the same db_field. + """ + with self.assertRaises(InvalidDocumentError): class Foo(Document): name = StringField() name2 = StringField(db_field='name') - self.assertRaises(InvalidDocumentError, throw_invalid_document_error) - def test_invalid_son(self): """Raise an error if loading invalid data""" class Occurrence(EmbeddedDocument): @@ -1864,11 +1852,13 @@ class InstanceTest(unittest.TestCase): forms = ListField(StringField(), default=list) occurs = ListField(EmbeddedDocumentField(Occurrence), default=list) - def raise_invalid_document(): - Word._from_son({'stem': [1, 2, 3], 'forms': 1, 'count': 'one', - 'occurs': {"hello": None}}) - - self.assertRaises(InvalidDocumentError, raise_invalid_document) + with self.assertRaises(InvalidDocumentError): + Word._from_son({ + 'stem': [1, 2, 3], + 'forms': 1, + 'count': 'one', + 'occurs': {"hello": None} + }) def test_reverse_delete_rule_cascade_and_nullify(self): """Ensure that a referenced document is also deleted upon deletion. @@ -2099,8 +2089,7 @@ class InstanceTest(unittest.TestCase): self.assertEqual(Bar.objects.get().foo, None) def test_invalid_reverse_delete_rule_raise_errors(self): - - def throw_invalid_document_error(): + with self.assertRaises(InvalidDocumentError): class Blog(Document): content = StringField() authors = MapField(ReferenceField( @@ -2110,21 +2099,15 @@ class InstanceTest(unittest.TestCase): self.Person, reverse_delete_rule=NULLIFY)) - self.assertRaises(InvalidDocumentError, throw_invalid_document_error) - - def throw_invalid_document_error_embedded(): + with self.assertRaises(InvalidDocumentError): class Parents(EmbeddedDocument): father = ReferenceField('Person', reverse_delete_rule=DENY) mother = ReferenceField('Person', reverse_delete_rule=DENY) - self.assertRaises( - InvalidDocumentError, throw_invalid_document_error_embedded) - def test_reverse_delete_rule_cascade_recurs(self): """Ensure that a chain of documents is also deleted upon cascaded deletion. """ - class BlogPost(Document): content = StringField() author = ReferenceField(self.Person, reverse_delete_rule=CASCADE) @@ -2340,15 +2323,14 @@ class InstanceTest(unittest.TestCase): pickle_doc.save() pickle_doc.delete() - def test_throw_invalid_document_error(self): - - # test handles people trying to upsert - def throw_invalid_document_error(): + def test_override_method_with_field(self): + """Test creating a field with a field name that would override + the "validate" method. + """ + with self.assertRaises(InvalidDocumentError): class Blog(Document): validate = DictField() - self.assertRaises(InvalidDocumentError, throw_invalid_document_error) - def test_mutating_documents(self): class B(EmbeddedDocument): @@ -2811,11 +2793,10 @@ class InstanceTest(unittest.TestCase): log.log = "Saving" log.save() - def change_shard_key(): + # try to change the shard key + with self.assertRaises(OperationError): log.machine = "127.0.0.1" - self.assertRaises(OperationError, change_shard_key) - def test_shard_key_in_embedded_document(self): class Foo(EmbeddedDocument): foo = StringField() @@ -2836,12 +2817,11 @@ class InstanceTest(unittest.TestCase): bar_doc.bar = 'baz' bar_doc.save() - def change_shard_key(): + # try to change the shard key + with self.assertRaises(OperationError): bar_doc.foo.foo = 'something' bar_doc.save() - self.assertRaises(OperationError, change_shard_key) - def test_shard_key_primary(self): class LogEntry(Document): machine = StringField(primary_key=True) @@ -2862,11 +2842,10 @@ class InstanceTest(unittest.TestCase): log.log = "Saving" log.save() - def change_shard_key(): + # try to change the shard key + with self.assertRaises(OperationError): log.machine = "127.0.0.1" - self.assertRaises(OperationError, change_shard_key) - def test_kwargs_simple(self): class Embedded(EmbeddedDocument): @@ -2951,11 +2930,9 @@ class InstanceTest(unittest.TestCase): def test_bad_mixed_creation(self): """Ensure that document gives correct error when duplicating arguments """ - def construct_bad_instance(): + with self.assertRaises(TypeError): return self.Person("Test User", 42, name="Bad User") - self.assertRaises(TypeError, construct_bad_instance) - def test_data_contains_id_field(self): """Ensure that asking for _data returns 'id' """ diff --git a/tests/document/validation.py b/tests/document/validation.py index 1ff88ab5..105bc8b0 100644 --- a/tests/document/validation.py +++ b/tests/document/validation.py @@ -153,14 +153,14 @@ class ValidatorErrorTest(unittest.TestCase): s = SubDoc() - self.assertRaises(ValidationError, lambda: s.validate()) + self.assertRaises(ValidationError, s.validate) d1.e = s d2.e = s del d1 - self.assertRaises(ValidationError, lambda: d2.validate()) + self.assertRaises(ValidationError, d2.validate) def test_parent_reference_in_child_document(self): """ diff --git a/tests/fields/fields.py b/tests/fields/fields.py index b4396f39..17b76742 100644 --- a/tests/fields/fields.py +++ b/tests/fields/fields.py @@ -1115,12 +1115,11 @@ class FieldTest(unittest.TestCase): e.mapping = [1] e.save() - def create_invalid_mapping(): + # try creating an invalid mapping + with self.assertRaises(ValidationError): e.mapping = ["abc"] e.save() - self.assertRaises(ValidationError, create_invalid_mapping) - Simple.drop_collection() def test_list_field_rejects_strings(self): @@ -1387,12 +1386,11 @@ class FieldTest(unittest.TestCase): e.mapping['someint'] = 1 e.save() - def create_invalid_mapping(): + # try creating an invalid mapping + with self.assertRaises(ValidationError): e.mapping['somestring'] = "abc" e.save() - self.assertRaises(ValidationError, create_invalid_mapping) - Simple.drop_collection() def test_dictfield_complex(self): @@ -1465,11 +1463,10 @@ class FieldTest(unittest.TestCase): self.assertEqual(BaseDict, type(e.mapping)) self.assertEqual({"ints": [3, 4]}, e.mapping) - def create_invalid_mapping(): + # try creating an invalid mapping + with self.assertRaises(ValueError): e.update(set__mapping={"somestrings": ["foo", "bar", ]}) - self.assertRaises(ValueError, create_invalid_mapping) - Simple.drop_collection() def test_mapfield(self): @@ -1484,18 +1481,14 @@ class FieldTest(unittest.TestCase): e.mapping['someint'] = 1 e.save() - def create_invalid_mapping(): + with self.assertRaises(ValidationError): e.mapping['somestring'] = "abc" e.save() - self.assertRaises(ValidationError, create_invalid_mapping) - - def create_invalid_class(): + with self.assertRaises(ValidationError): class NoDeclaredType(Document): mapping = MapField() - self.assertRaises(ValidationError, create_invalid_class) - Simple.drop_collection() def test_complex_mapfield(self): @@ -1524,14 +1517,10 @@ class FieldTest(unittest.TestCase): self.assertTrue(isinstance(e2.mapping['somestring'], StringSetting)) self.assertTrue(isinstance(e2.mapping['someint'], IntegerSetting)) - def create_invalid_mapping(): + with self.assertRaises(ValidationError): e.mapping['someint'] = 123 e.save() - self.assertRaises(ValidationError, create_invalid_mapping) - - Extensible.drop_collection() - def test_embedded_mapfield_db_field(self): class Embedded(EmbeddedDocument): @@ -1741,8 +1730,8 @@ class FieldTest(unittest.TestCase): # Reference is no longer valid foo.delete() bar = Bar.objects.get() - self.assertRaises(DoesNotExist, lambda: getattr(bar, 'ref')) - self.assertRaises(DoesNotExist, lambda: getattr(bar, 'generic_ref')) + self.assertRaises(DoesNotExist, getattr, bar, 'ref') + self.assertRaises(DoesNotExist, getattr, bar, 'generic_ref') # When auto_dereference is disabled, there is no trouble returning DBRef bar = Bar.objects.get() @@ -2017,7 +2006,7 @@ class FieldTest(unittest.TestCase): }) def test_cached_reference_fields_on_embedded_documents(self): - def build(): + with self.assertRaises(InvalidDocumentError): class Test(Document): name = StringField() @@ -2026,8 +2015,6 @@ class FieldTest(unittest.TestCase): 'test': CachedReferenceField(Test) }) - self.assertRaises(InvalidDocumentError, build) - def test_cached_reference_auto_sync(self): class Person(Document): TYPES = ( @@ -3815,9 +3802,7 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase): filtered = self.post1.comments.filter() # Ensure nothing was changed - # < 2.6 Incompatible > - # self.assertListEqual(filtered, self.post1.comments) - self.assertEqual(filtered, self.post1.comments) + self.assertListEqual(filtered, self.post1.comments) def test_single_keyword_filter(self): """ @@ -3868,10 +3853,8 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase): Tests the filter method of a List of Embedded Documents when the keyword is not a known keyword. """ - # < 2.6 Incompatible > - # with self.assertRaises(AttributeError): - # self.post2.comments.filter(year=2) - self.assertRaises(AttributeError, self.post2.comments.filter, year=2) + with self.assertRaises(AttributeError): + self.post2.comments.filter(year=2) def test_no_keyword_exclude(self): """ @@ -3881,9 +3864,7 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase): filtered = self.post1.comments.exclude() # Ensure everything was removed - # < 2.6 Incompatible > - # self.assertListEqual(filtered, []) - self.assertEqual(filtered, []) + self.assertListEqual(filtered, []) def test_single_keyword_exclude(self): """ @@ -3929,10 +3910,8 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase): Tests the exclude method of a List of Embedded Documents when the keyword is not a known keyword. """ - # < 2.6 Incompatible > - # with self.assertRaises(AttributeError): - # self.post2.comments.exclude(year=2) - self.assertRaises(AttributeError, self.post2.comments.exclude, year=2) + with self.assertRaises(AttributeError): + self.post2.comments.exclude(year=2) def test_chained_filter_exclude(self): """ @@ -3970,10 +3949,7 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase): single keyword. """ comment = self.post1.comments.get(author='user1') - - # < 2.6 Incompatible > - # self.assertIsInstance(comment, self.Comments) - self.assertTrue(isinstance(comment, self.Comments)) + self.assertIsInstance(comment, self.Comments) self.assertEqual(comment.author, 'user1') def test_multi_keyword_get(self): @@ -3982,10 +3958,7 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase): multiple keywords. """ comment = self.post2.comments.get(author='user2', message='message2') - - # < 2.6 Incompatible > - # self.assertIsInstance(comment, self.Comments) - self.assertTrue(isinstance(comment, self.Comments)) + self.assertIsInstance(comment, self.Comments) self.assertEqual(comment.author, 'user2') self.assertEqual(comment.message, 'message2') @@ -3994,44 +3967,32 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase): Tests the get method of a List of Embedded Documents without a keyword to return multiple documents. """ - # < 2.6 Incompatible > - # with self.assertRaises(MultipleObjectsReturned): - # self.post1.comments.get() - self.assertRaises(MultipleObjectsReturned, self.post1.comments.get) + with self.assertRaises(MultipleObjectsReturned): + self.post1.comments.get() def test_keyword_multiple_return_get(self): """ Tests the get method of a List of Embedded Documents with a keyword to return multiple documents. """ - # < 2.6 Incompatible > - # with self.assertRaises(MultipleObjectsReturned): - # self.post2.comments.get(author='user2') - self.assertRaises( - MultipleObjectsReturned, self.post2.comments.get, author='user2' - ) + with self.assertRaises(MultipleObjectsReturned): + self.post2.comments.get(author='user2') def test_unknown_keyword_get(self): """ Tests the get method of a List of Embedded Documents with an unknown keyword. """ - # < 2.6 Incompatible > - # with self.assertRaises(AttributeError): - # self.post2.comments.get(year=2020) - self.assertRaises(AttributeError, self.post2.comments.get, year=2020) + with self.assertRaises(AttributeError): + self.post2.comments.get(year=2020) def test_no_result_get(self): """ Tests the get method of a List of Embedded Documents where get returns no results. """ - # < 2.6 Incompatible > - # with self.assertRaises(DoesNotExist): - # self.post1.comments.get(author='user3') - self.assertRaises( - DoesNotExist, self.post1.comments.get, author='user3' - ) + with self.assertRaises(DoesNotExist): + self.post1.comments.get(author='user3') def test_first(self): """ @@ -4041,9 +4002,7 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase): comment = self.post1.comments.first() # Ensure a Comment object was returned. - # < 2.6 Incompatible > - # self.assertIsInstance(comment, self.Comments) - self.assertTrue(isinstance(comment, self.Comments)) + self.assertIsInstance(comment, self.Comments) self.assertEqual(comment, self.post1.comments[0]) def test_create(self): @@ -4056,22 +4015,14 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase): self.post1.save() # Ensure the returned value is the comment object. - # < 2.6 Incompatible > - # self.assertIsInstance(comment, self.Comments) - self.assertTrue(isinstance(comment, self.Comments)) + self.assertIsInstance(comment, self.Comments) self.assertEqual(comment.author, 'user4') self.assertEqual(comment.message, 'message1') # Ensure the new comment was actually saved to the database. - # < 2.6 Incompatible > - # self.assertIn( - # comment, - # self.BlogPost.objects(comments__author='user4')[0].comments - # ) - self.assertTrue( - comment in self.BlogPost.objects( - comments__author='user4' - )[0].comments + self.assertIn( + comment, + self.BlogPost.objects(comments__author='user4')[0].comments ) def test_filtered_create(self): @@ -4086,22 +4037,14 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase): self.post1.save() # Ensure the returned value is the comment object. - # < 2.6 Incompatible > - # self.assertIsInstance(comment, self.Comments) - self.assertTrue(isinstance(comment, self.Comments)) + self.assertIsInstance(comment, self.Comments) self.assertEqual(comment.author, 'user4') self.assertEqual(comment.message, 'message1') # Ensure the new comment was actually saved to the database. - # < 2.6 Incompatible > - # self.assertIn( - # comment, - # self.BlogPost.objects(comments__author='user4')[0].comments - # ) - self.assertTrue( - comment in self.BlogPost.objects( - comments__author='user4' - )[0].comments + self.assertIn( + comment, + self.BlogPost.objects(comments__author='user4')[0].comments ) def test_no_keyword_update(self): @@ -4114,22 +4057,14 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase): self.post1.save() # Ensure that nothing was altered. - # < 2.6 Incompatible > - # self.assertIn( - # original[0], - # self.BlogPost.objects(id=self.post1.id)[0].comments - # ) - self.assertTrue( - original[0] in self.BlogPost.objects(id=self.post1.id)[0].comments + self.assertIn( + original[0], + self.BlogPost.objects(id=self.post1.id)[0].comments ) - # < 2.6 Incompatible > - # self.assertIn( - # original[1], - # self.BlogPost.objects(id=self.post1.id)[0].comments - # ) - self.assertTrue( - original[1] in self.BlogPost.objects(id=self.post1.id)[0].comments + self.assertIn( + original[1], + self.BlogPost.objects(id=self.post1.id)[0].comments ) # Ensure the method returned 0 as the number of entries @@ -4175,13 +4110,9 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase): comments.save() # Ensure that the new comment has been added to the database. - # < 2.6 Incompatible > - # self.assertIn( - # new_comment, - # self.BlogPost.objects(id=self.post1.id)[0].comments - # ) - self.assertTrue( - new_comment in self.BlogPost.objects(id=self.post1.id)[0].comments + self.assertIn( + new_comment, + self.BlogPost.objects(id=self.post1.id)[0].comments ) def test_delete(self): @@ -4193,23 +4124,15 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase): # Ensure that all the comments under post1 were deleted in the # database. - # < 2.6 Incompatible > - # self.assertListEqual( - # self.BlogPost.objects(id=self.post1.id)[0].comments, [] - # ) - self.assertEqual( + self.assertListEqual( self.BlogPost.objects(id=self.post1.id)[0].comments, [] ) # Ensure that post1 comments were deleted from the list. - # < 2.6 Incompatible > - # self.assertListEqual(self.post1.comments, []) - self.assertEqual(self.post1.comments, []) + self.assertListEqual(self.post1.comments, []) # Ensure that comments still returned a EmbeddedDocumentList object. - # < 2.6 Incompatible > - # self.assertIsInstance(self.post1.comments, EmbeddedDocumentList) - self.assertTrue(isinstance(self.post1.comments, EmbeddedDocumentList)) + self.assertIsInstance(self.post1.comments, EmbeddedDocumentList) # Ensure that the delete method returned 2 as the number of entries # deleted from the database @@ -4249,21 +4172,15 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase): self.post1.save() # Ensure that only the user2 comment was deleted. - # < 2.6 Incompatible > - # self.assertNotIn( - # comment, self.BlogPost.objects(id=self.post1.id)[0].comments - # ) - self.assertTrue( - comment not in self.BlogPost.objects(id=self.post1.id)[0].comments + self.assertNotIn( + comment, self.BlogPost.objects(id=self.post1.id)[0].comments ) self.assertEqual( len(self.BlogPost.objects(id=self.post1.id)[0].comments), 1 ) # Ensure that the user2 comment no longer exists in the list. - # < 2.6 Incompatible > - # self.assertNotIn(comment, self.post1.comments) - self.assertTrue(comment not in self.post1.comments) + self.assertNotIn(comment, self.post1.comments) self.assertEqual(len(self.post1.comments), 1) # Ensure that the delete method returned 1 as the number of entries diff --git a/tests/queryset/queryset.py b/tests/queryset/queryset.py index a56b6c26..e4c71de7 100644 --- a/tests/queryset/queryset.py +++ b/tests/queryset/queryset.py @@ -25,7 +25,10 @@ __all__ = ("QuerySetTest",) class db_ops_tracker(query_counter): def get_ops(self): - ignore_query = {"ns": {"$ne": "%s.system.indexes" % self.db.name}} + ignore_query = { + 'ns': {'$ne': '%s.system.indexes' % self.db.name}, + 'command.count': {'$ne': 'system.profile'} + } return list(self.db.system.profile.find(ignore_query)) @@ -94,12 +97,12 @@ class QuerySetTest(unittest.TestCase): author = ReferenceField(self.Person) author2 = GenericReferenceField() - def test_reference(): + # test addressing a field from a reference + with self.assertRaises(InvalidQueryError): list(BlogPost.objects(author__name="test")) - self.assertRaises(InvalidQueryError, test_reference) - - def test_generic_reference(): + # should fail for a generic reference as well + with self.assertRaises(InvalidQueryError): list(BlogPost.objects(author2__name="test")) def test_find(self): @@ -218,14 +221,15 @@ class QuerySetTest(unittest.TestCase): person = self.Person.objects[1] self.assertEqual(person.name, "User B") - self.assertRaises(IndexError, self.Person.objects.__getitem__, 2) + with self.assertRaises(IndexError): + self.Person.objects[2] # Find a document using just the object id person = self.Person.objects.with_id(person1.id) self.assertEqual(person.name, "User A") - self.assertRaises( - InvalidQueryError, self.Person.objects(name="User A").with_id, person1.id) + with self.assertRaises(InvalidQueryError): + self.Person.objects(name="User A").with_id(person1.id) def test_find_only_one(self): """Ensure that a query using ``get`` returns at most one result. @@ -363,7 +367,8 @@ class QuerySetTest(unittest.TestCase): # test invalid batch size qs = A.objects.batch_size(-1) - self.assertRaises(ValueError, lambda: list(qs)) + with self.assertRaises(ValueError): + list(qs) def test_update_write_concern(self): """Test that passing write_concern works""" @@ -392,18 +397,14 @@ class QuerySetTest(unittest.TestCase): """Test to ensure that update is passed a value to update to""" self.Person.drop_collection() - author = self.Person(name='Test User') - author.save() + author = self.Person.objects.create(name='Test User') - def update_raises(): + with self.assertRaises(OperationError): self.Person.objects(pk=author.pk).update({}) - def update_one_raises(): + with self.assertRaises(OperationError): self.Person.objects(pk=author.pk).update_one({}) - self.assertRaises(OperationError, update_raises) - self.assertRaises(OperationError, update_one_raises) - def test_update_array_position(self): """Ensure that updating by array position works. @@ -431,8 +432,8 @@ class QuerySetTest(unittest.TestCase): Blog.objects.create(posts=[post2, post1]) # Update all of the first comments of second posts of all blogs - Blog.objects().update(set__posts__1__comments__0__name="testc") - testc_blogs = Blog.objects(posts__1__comments__0__name="testc") + Blog.objects().update(set__posts__1__comments__0__name='testc') + testc_blogs = Blog.objects(posts__1__comments__0__name='testc') self.assertEqual(testc_blogs.count(), 2) Blog.drop_collection() @@ -441,14 +442,13 @@ class QuerySetTest(unittest.TestCase): # Update only the first blog returned by the query Blog.objects().update_one( - set__posts__1__comments__1__name="testc") - testc_blogs = Blog.objects(posts__1__comments__1__name="testc") + set__posts__1__comments__1__name='testc') + testc_blogs = Blog.objects(posts__1__comments__1__name='testc') self.assertEqual(testc_blogs.count(), 1) # Check that using this indexing syntax on a non-list fails - def non_list_indexing(): - Blog.objects().update(set__posts__1__comments__0__name__1="asdf") - self.assertRaises(InvalidQueryError, non_list_indexing) + with self.assertRaises(InvalidQueryError): + Blog.objects().update(set__posts__1__comments__0__name__1='asdf') Blog.drop_collection() @@ -516,15 +516,12 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(simple.x, [1, 2, None, 4, 3, 2, 3, 4]) # Nested updates arent supported yet.. - def update_nested(): + with self.assertRaises(OperationError): Simple.drop_collection() Simple(x=[{'test': [1, 2, 3, 4]}]).save() Simple.objects(x__test=2).update(set__x__S__test__S=3) self.assertEqual(simple.x, [1, 2, 3, 4]) - self.assertRaises(OperationError, update_nested) - Simple.drop_collection() - def test_update_using_positional_operator_embedded_document(self): """Ensure that the embedded documents can be updated using the positional operator.""" @@ -839,30 +836,31 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(Blog.objects.count(), 2) - # test handles people trying to upsert - def throw_operation_error(): + # test inserting an existing document (shouldn't be allowed) + with self.assertRaises(OperationError): + blog = Blog.objects.first() + Blog.objects.insert(blog) + + # test inserting a query set + with self.assertRaises(OperationError): blogs = Blog.objects Blog.objects.insert(blogs) - self.assertRaises(OperationError, throw_operation_error) - - # Test can insert new doc + # insert a new doc new_post = Blog(title="code123", id=ObjectId()) Blog.objects.insert(new_post) - # test handles other classes being inserted - def throw_operation_error_wrong_doc(): - class Author(Document): - pass + class Author(Document): + pass + + # try inserting a different document class + with self.assertRaises(OperationError): Blog.objects.insert(Author()) - self.assertRaises(OperationError, throw_operation_error_wrong_doc) - - def throw_operation_error_not_a_document(): + # try inserting a non-document + with self.assertRaises(OperationError): Blog.objects.insert("HELLO WORLD") - self.assertRaises(OperationError, throw_operation_error_not_a_document) - Blog.drop_collection() blog1 = Blog(title="code", posts=[post1, post2]) @@ -882,14 +880,13 @@ class QuerySetTest(unittest.TestCase): blog3 = Blog(title="baz", posts=[post1, post2]) Blog.objects.insert([blog1, blog2]) - def throw_operation_error_not_unique(): + with self.assertRaises(NotUniqueError): Blog.objects.insert([blog2, blog3]) - self.assertRaises(NotUniqueError, throw_operation_error_not_unique) self.assertEqual(Blog.objects.count(), 2) - Blog.objects.insert([blog2, blog3], write_concern={"w": 0, - 'continue_on_error': True}) + Blog.objects.insert([blog2, blog3], + write_concern={"w": 0, 'continue_on_error': True}) self.assertEqual(Blog.objects.count(), 3) def test_get_changed_fields_query_count(self): @@ -1233,7 +1230,9 @@ class QuerySetTest(unittest.TestCase): BlogPost.objects.filter(title='whatever').first() self.assertEqual(len(q.get_ops()), 1) self.assertEqual( - q.get_ops()[0]['query']['$orderby'], {u'published_date': -1}) + q.get_ops()[0]['query']['$orderby'], + {'published_date': -1} + ) with db_ops_tracker() as q: BlogPost.objects.filter(title='whatever').order_by().first() @@ -1910,12 +1909,10 @@ class QuerySetTest(unittest.TestCase): Site.objects(id=s.id).update_one(pull__collaborators__user='Esteban') self.assertEqual(Site.objects.first().collaborators, []) - def pull_all(): + with self.assertRaises(InvalidQueryError): Site.objects(id=s.id).update_one( pull_all__collaborators__user=['Ross']) - self.assertRaises(InvalidQueryError, pull_all) - def test_pull_from_nested_embedded(self): class User(EmbeddedDocument): @@ -1946,12 +1943,10 @@ class QuerySetTest(unittest.TestCase): pull__collaborators__unhelpful={'name': 'Frank'}) self.assertEqual(Site.objects.first().collaborators['unhelpful'], []) - def pull_all(): + with self.assertRaises(InvalidQueryError): Site.objects(id=s.id).update_one( pull_all__collaborators__helpful__name=['Ross']) - self.assertRaises(InvalidQueryError, pull_all) - def test_pull_from_nested_mapfield(self): class Collaborator(EmbeddedDocument): @@ -1980,12 +1975,10 @@ class QuerySetTest(unittest.TestCase): pull__collaborators__unhelpful={'user': 'Frank'}) self.assertEqual(Site.objects.first().collaborators['unhelpful'], []) - def pull_all(): + with self.assertRaises(InvalidQueryError): Site.objects(id=s.id).update_one( pull_all__collaborators__helpful__user=['Ross']) - self.assertRaises(InvalidQueryError, pull_all) - def test_update_one_pop_generic_reference(self): class BlogTag(Document): @@ -3821,11 +3814,9 @@ class QuerySetTest(unittest.TestCase): self.assertTrue(a in results) self.assertTrue(c in results) - def invalid_where(): + with self.assertRaises(TypeError): list(IntPair.objects.where(fielda__gte=3)) - self.assertRaises(TypeError, invalid_where) - def test_scalar(self): class Organization(Document): @@ -4550,7 +4541,9 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(counter, 100) self.assertEqual(len(list(docs)), 100) - self.assertRaises(TypeError, lambda: len(docs)) + + with self.assertRaises(TypeError): + len(docs) with query_counter() as q: self.assertEqual(q, 0) @@ -4875,7 +4868,9 @@ class QuerySetTest(unittest.TestCase): def test_max_time_ms(self): # 778: max_time_ms can get only int or None as input - self.assertRaises(TypeError, self.Person.objects(name="name").max_time_ms, "not a number") + self.assertRaises(TypeError, + self.Person.objects(name="name").max_time_ms, + 'not a number') def test_subclass_field_query(self): class Animal(Document): diff --git a/tests/queryset/transform.py b/tests/queryset/transform.py index 06fe4ea5..20ab0b3f 100644 --- a/tests/queryset/transform.py +++ b/tests/queryset/transform.py @@ -238,7 +238,8 @@ class TransformTest(unittest.TestCase): box = [(35.0, -125.0), (40.0, -100.0)] # I *meant* to execute location__within_box=box events = Event.objects(location__within=box) - self.assertRaises(InvalidQueryError, lambda: events.count()) + with self.assertRaises(InvalidQueryError): + events.count() if __name__ == '__main__': diff --git a/tests/queryset/visitor.py b/tests/queryset/visitor.py index ee2ef59a..6f020e88 100644 --- a/tests/queryset/visitor.py +++ b/tests/queryset/visitor.py @@ -268,14 +268,13 @@ class QTest(unittest.TestCase): self.assertEqual(self.Person.objects(Q(age__in=[20, 30])).count(), 3) # Test invalid query objs - def wrong_query_objs(): + with self.assertRaises(InvalidQueryError): self.Person.objects('user1') - def wrong_query_objs_filter(): - self.Person.objects('user1') + # filter should fail, too + with self.assertRaises(InvalidQueryError): + self.Person.objects.filter('user1') - self.assertRaises(InvalidQueryError, wrong_query_objs) - self.assertRaises(InvalidQueryError, wrong_query_objs_filter) def test_q_regex(self): """Ensure that Q objects can be queried using regexes. diff --git a/tests/test_datastructures.py b/tests/test_datastructures.py index eb40e767..6830a188 100644 --- a/tests/test_datastructures.py +++ b/tests/test_datastructures.py @@ -23,7 +23,8 @@ class TestStrictDict(unittest.TestCase): self.assertEqual(repr(d), '{"a": \'"\', "b": "\'", "c": \'\'}') def test_init_fails_on_nonexisting_attrs(self): - self.assertRaises(AttributeError, lambda: self.dtype(a=1, b=2, d=3)) + with self.assertRaises(AttributeError): + self.dtype(a=1, b=2, d=3) def test_eq(self): d = self.dtype(a=1, b=1, c=1) @@ -46,14 +47,12 @@ class TestStrictDict(unittest.TestCase): d = self.dtype() d.a = 1 self.assertEqual(d.a, 1) - self.assertRaises(AttributeError, lambda: d.b) + self.assertRaises(AttributeError, getattr, d, 'b') def test_setattr_raises_on_nonexisting_attr(self): d = self.dtype() - - def _f(): + with self.assertRaises(AttributeError): d.x = 1 - self.assertRaises(AttributeError, _f) def test_setattr_getattr_special(self): d = self.strict_dict_class(["items"])