use with self.assertRaises for readability

This commit is contained in:
Stefan Wojcik 2016-12-10 22:33:39 -05:00
parent a8884391c2
commit 3ebe3748fa
8 changed files with 206 additions and 326 deletions

View File

@ -251,19 +251,17 @@ class InheritanceTest(unittest.TestCase):
self.assertEqual(classes, [Human]) self.assertEqual(classes, [Human])
def test_allow_inheritance(self): def test_allow_inheritance(self):
"""Ensure that inheritance may be disabled on simple classes and that """Ensure that inheritance is disabled by default on simple
_cls and _subclasses will not be used. classes and that _cls will not be used.
""" """
class Animal(Document): class Animal(Document):
name = StringField() name = StringField()
def create_dog_class(): # can't inherit because Animal didn't explicitly allow inheritance
with self.assertRaises(ValueError):
class Dog(Animal): class Dog(Animal):
pass pass
self.assertRaises(ValueError, create_dog_class)
# Check that _cls etc aren't present on simple documents # Check that _cls etc aren't present on simple documents
dog = Animal(name='dog').save() dog = Animal(name='dog').save()
self.assertEqual(dog.to_mongo().keys(), ['_id', 'name']) self.assertEqual(dog.to_mongo().keys(), ['_id', 'name'])
@ -273,17 +271,15 @@ class InheritanceTest(unittest.TestCase):
self.assertFalse('_cls' in obj) self.assertFalse('_cls' in obj)
def test_cant_turn_off_inheritance_on_subclass(self): def test_cant_turn_off_inheritance_on_subclass(self):
"""Ensure if inheritance is on in a subclass you cant turn it off """Ensure if inheritance is on in a subclass you cant turn it off.
""" """
class Animal(Document): class Animal(Document):
name = StringField() name = StringField()
meta = {'allow_inheritance': True} meta = {'allow_inheritance': True}
def create_mammal_class(): with self.assertRaises(ValueError):
class Mammal(Animal): class Mammal(Animal):
meta = {'allow_inheritance': False} meta = {'allow_inheritance': False}
self.assertRaises(ValueError, create_mammal_class)
def test_allow_inheritance_abstract_document(self): def test_allow_inheritance_abstract_document(self):
"""Ensure that abstract documents can set inheritance rules and that """Ensure that abstract documents can set inheritance rules and that
@ -296,10 +292,9 @@ class InheritanceTest(unittest.TestCase):
class Animal(FinalDocument): class Animal(FinalDocument):
name = StringField() name = StringField()
def create_mammal_class(): with self.assertRaises(ValueError):
class Mammal(Animal): class Mammal(Animal):
pass pass
self.assertRaises(ValueError, create_mammal_class)
# Check that _cls isn't present in simple documents # Check that _cls isn't present in simple documents
doc = Animal(name='dog') doc = Animal(name='dog')
@ -358,29 +353,26 @@ class InheritanceTest(unittest.TestCase):
self.assertEqual(berlin.pk, berlin.auto_id_0) self.assertEqual(berlin.pk, berlin.auto_id_0)
def test_abstract_document_creation_does_not_fail(self): def test_abstract_document_creation_does_not_fail(self):
class City(Document): class City(Document):
continent = StringField() continent = StringField()
meta = {'abstract': True, meta = {'abstract': True,
'allow_inheritance': False} 'allow_inheritance': False}
bkk = City(continent='asia') bkk = City(continent='asia')
self.assertEqual(None, bkk.pk) self.assertEqual(None, bkk.pk)
# TODO: expected error? Shouldn't we create a new error type? # TODO: expected error? Shouldn't we create a new error type?
self.assertRaises(KeyError, lambda: setattr(bkk, 'pk', 1)) with self.assertRaises(KeyError):
setattr(bkk, 'pk', 1)
def test_allow_inheritance_embedded_document(self): def test_allow_inheritance_embedded_document(self):
"""Ensure embedded documents respect inheritance """Ensure embedded documents respect inheritance."""
"""
class Comment(EmbeddedDocument): class Comment(EmbeddedDocument):
content = StringField() content = StringField()
def create_special_comment(): with self.assertRaises(ValueError):
class SpecialComment(Comment): class SpecialComment(Comment):
pass pass
self.assertRaises(ValueError, create_special_comment)
doc = Comment(content='test') doc = Comment(content='test')
self.assertFalse('_cls' in doc.to_mongo()) self.assertFalse('_cls' in doc.to_mongo())
@ -452,11 +444,11 @@ class InheritanceTest(unittest.TestCase):
self.assertEqual(Guppy._get_collection_name(), 'fish') self.assertEqual(Guppy._get_collection_name(), 'fish')
self.assertEqual(Human._get_collection_name(), 'human') self.assertEqual(Human._get_collection_name(), 'human')
def create_bad_abstract(): # ensure that a subclass of a non-abstract class can't be abstract
with self.assertRaises(ValueError):
class EvilHuman(Human): class EvilHuman(Human):
evil = BooleanField(default=True) evil = BooleanField(default=True)
meta = {'abstract': True} meta = {'abstract': True}
self.assertRaises(ValueError, create_bad_abstract)
def test_abstract_embedded_documents(self): def test_abstract_embedded_documents(self):
# 789: EmbeddedDocument shouldn't inherit abstract # 789: EmbeddedDocument shouldn't inherit abstract

View File

@ -99,21 +99,18 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(options['size'], 4096) self.assertEqual(options['size'], 4096)
# Check that the document cannot be redefined with different options # Check that the document cannot be redefined with different options
def recreate_log_document():
class Log(Document): class Log(Document):
date = DateTimeField(default=datetime.now) date = DateTimeField(default=datetime.now)
meta = { meta = {
'max_documents': 11, 'max_documents': 11,
} }
# Create the collection by accessing Document.objects
Log.objects
self.assertRaises(InvalidCollectionError, recreate_log_document)
Log.drop_collection() # Accessing Document.objects creates the collection
with self.assertRaises(InvalidCollectionError):
Log.objects
def test_capped_collection_default(self): def test_capped_collection_default(self):
"""Ensure that capped collections defaults work properly. """Ensure that capped collections defaults work properly."""
"""
class Log(Document): class Log(Document):
date = DateTimeField(default=datetime.now) date = DateTimeField(default=datetime.now)
meta = { meta = {
@ -131,16 +128,14 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(options['size'], 10 * 2**20) self.assertEqual(options['size'], 10 * 2**20)
# Check that the document with default value can be recreated # Check that the document with default value can be recreated
def recreate_log_document():
class Log(Document): class Log(Document):
date = DateTimeField(default=datetime.now) date = DateTimeField(default=datetime.now)
meta = { meta = {
'max_documents': 10, 'max_documents': 10,
} }
# Create the collection by accessing Document.objects # Create the collection by accessing Document.objects
Log.objects Log.objects
recreate_log_document()
Log.drop_collection()
def test_capped_collection_no_max_size_problems(self): def test_capped_collection_no_max_size_problems(self):
"""Ensure that capped collections with odd max_size work properly. """Ensure that capped collections with odd max_size work properly.
@ -163,16 +158,14 @@ class InstanceTest(unittest.TestCase):
self.assertTrue(options['size'] >= 10000) self.assertTrue(options['size'] >= 10000)
# Check that the document with odd max_size value can be recreated # Check that the document with odd max_size value can be recreated
def recreate_log_document():
class Log(Document): class Log(Document):
date = DateTimeField(default=datetime.now) date = DateTimeField(default=datetime.now)
meta = { meta = {
'max_size': 10000, 'max_size': 10000,
} }
# Create the collection by accessing Document.objects # Create the collection by accessing Document.objects
Log.objects Log.objects
recreate_log_document()
Log.drop_collection()
def test_repr(self): def test_repr(self):
"""Ensure that unicode representation works """Ensure that unicode representation works
@ -353,14 +346,14 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(User._fields['username'].db_field, '_id') self.assertEqual(User._fields['username'].db_field, '_id')
self.assertEqual(User._meta['id_field'], 'username') self.assertEqual(User._meta['id_field'], 'username')
def create_invalid_user(): # test no primary key field
User(name='test').save() # no primary key field self.assertRaises(ValidationError, User(name='test').save)
self.assertRaises(ValidationError, create_invalid_user)
def define_invalid_user(): # define a subclass with a different primary key field than the
# parent
with self.assertRaises(ValueError):
class EmailUser(User): class EmailUser(User):
email = StringField(primary_key=True) email = StringField(primary_key=True)
self.assertRaises(ValueError, define_invalid_user)
class EmailUser(User): class EmailUser(User):
email = StringField() email = StringField()
@ -410,9 +403,8 @@ class InstanceTest(unittest.TestCase):
# and the NicePlace model not being imported in at query time. # and the NicePlace model not being imported in at query time.
del(_document_registry['Place.NicePlace']) del(_document_registry['Place.NicePlace'])
def query_without_importing_nice_place(): with self.assertRaises(NotRegistered):
list(Place.objects.all()) list(Place.objects.all())
self.assertRaises(NotRegistered, query_without_importing_nice_place)
def test_document_registry_regressions(self): def test_document_registry_regressions(self):
@ -794,8 +786,10 @@ class InstanceTest(unittest.TestCase):
def test_modify_empty(self): def test_modify_empty(self):
doc = self.Person(name="bob", age=10).save() doc = self.Person(name="bob", age=10).save()
self.assertRaises(
InvalidDocumentError, lambda: self.Person().modify(set__age=10)) with self.assertRaises(InvalidDocumentError):
self.Person().modify(set__age=10)
self.assertDbEqual([dict(doc.to_mongo())]) self.assertDbEqual([dict(doc.to_mongo())])
def test_modify_invalid_query(self): def test_modify_invalid_query(self):
@ -803,9 +797,8 @@ class InstanceTest(unittest.TestCase):
doc2 = self.Person(name="jim", age=20).save() doc2 = self.Person(name="jim", age=20).save()
docs = [dict(doc1.to_mongo()), dict(doc2.to_mongo())] docs = [dict(doc1.to_mongo()), dict(doc2.to_mongo())]
self.assertRaises( with self.assertRaises(InvalidQueryError):
InvalidQueryError, doc1.modify({'id': doc2.id}, set__value=20)
lambda: doc1.modify({'id': doc2.id}, set__value=20))
self.assertDbEqual(docs) self.assertDbEqual(docs)
@ -1289,12 +1282,11 @@ class InstanceTest(unittest.TestCase):
def test_document_update(self): def test_document_update(self):
def update_not_saved_raises(): # try updating a non-saved document
with self.assertRaises(OperationError):
person = self.Person(name='dcrosta') person = self.Person(name='dcrosta')
person.update(set__name='Dan Crosta') person.update(set__name='Dan Crosta')
self.assertRaises(OperationError, update_not_saved_raises)
author = self.Person(name='dcrosta') author = self.Person(name='dcrosta')
author.save() author.save()
@ -1304,19 +1296,17 @@ class InstanceTest(unittest.TestCase):
p1 = self.Person.objects.first() p1 = self.Person.objects.first()
self.assertEqual(p1.name, author.name) self.assertEqual(p1.name, author.name)
def update_no_value_raises(): # try sending an empty update
with self.assertRaises(OperationError):
person = self.Person.objects.first() person = self.Person.objects.first()
person.update() person.update()
self.assertRaises(OperationError, update_no_value_raises) # update that doesn't explicitly specify an operator should default
# to 'set__'
def update_no_op_should_default_to_set():
person = self.Person.objects.first() person = self.Person.objects.first()
person.update(name="Dan") person.update(name="Dan")
person.reload() person.reload()
return person.name self.assertEqual("Dan", person.name)
self.assertEqual("Dan", update_no_op_should_default_to_set())
def test_update_unique_field(self): def test_update_unique_field(self):
class Doc(Document): class Doc(Document):
@ -1325,8 +1315,8 @@ class InstanceTest(unittest.TestCase):
doc1 = Doc(name="first").save() doc1 = Doc(name="first").save()
doc2 = Doc(name="second").save() doc2 = Doc(name="second").save()
self.assertRaises(NotUniqueError, lambda: with self.assertRaises(NotUniqueError):
doc2.update(set__name=doc1.name)) doc2.update(set__name=doc1.name)
def test_embedded_update(self): def test_embedded_update(self):
""" """
@ -1844,15 +1834,13 @@ class InstanceTest(unittest.TestCase):
def test_duplicate_db_fields_raise_invalid_document_error(self): def test_duplicate_db_fields_raise_invalid_document_error(self):
"""Ensure a InvalidDocumentError is thrown if duplicate fields """Ensure a InvalidDocumentError is thrown if duplicate fields
declare the same db_field""" declare the same db_field.
"""
def throw_invalid_document_error(): with self.assertRaises(InvalidDocumentError):
class Foo(Document): class Foo(Document):
name = StringField() name = StringField()
name2 = StringField(db_field='name') name2 = StringField(db_field='name')
self.assertRaises(InvalidDocumentError, throw_invalid_document_error)
def test_invalid_son(self): def test_invalid_son(self):
"""Raise an error if loading invalid data""" """Raise an error if loading invalid data"""
class Occurrence(EmbeddedDocument): class Occurrence(EmbeddedDocument):
@ -1864,11 +1852,13 @@ class InstanceTest(unittest.TestCase):
forms = ListField(StringField(), default=list) forms = ListField(StringField(), default=list)
occurs = ListField(EmbeddedDocumentField(Occurrence), default=list) occurs = ListField(EmbeddedDocumentField(Occurrence), default=list)
def raise_invalid_document(): with self.assertRaises(InvalidDocumentError):
Word._from_son({'stem': [1, 2, 3], 'forms': 1, 'count': 'one', Word._from_son({
'occurs': {"hello": None}}) 'stem': [1, 2, 3],
'forms': 1,
self.assertRaises(InvalidDocumentError, raise_invalid_document) 'count': 'one',
'occurs': {"hello": None}
})
def test_reverse_delete_rule_cascade_and_nullify(self): def test_reverse_delete_rule_cascade_and_nullify(self):
"""Ensure that a referenced document is also deleted upon deletion. """Ensure that a referenced document is also deleted upon deletion.
@ -2099,8 +2089,7 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(Bar.objects.get().foo, None) self.assertEqual(Bar.objects.get().foo, None)
def test_invalid_reverse_delete_rule_raise_errors(self): def test_invalid_reverse_delete_rule_raise_errors(self):
with self.assertRaises(InvalidDocumentError):
def throw_invalid_document_error():
class Blog(Document): class Blog(Document):
content = StringField() content = StringField()
authors = MapField(ReferenceField( authors = MapField(ReferenceField(
@ -2110,21 +2099,15 @@ class InstanceTest(unittest.TestCase):
self.Person, self.Person,
reverse_delete_rule=NULLIFY)) reverse_delete_rule=NULLIFY))
self.assertRaises(InvalidDocumentError, throw_invalid_document_error) with self.assertRaises(InvalidDocumentError):
def throw_invalid_document_error_embedded():
class Parents(EmbeddedDocument): class Parents(EmbeddedDocument):
father = ReferenceField('Person', reverse_delete_rule=DENY) father = ReferenceField('Person', reverse_delete_rule=DENY)
mother = ReferenceField('Person', reverse_delete_rule=DENY) mother = ReferenceField('Person', reverse_delete_rule=DENY)
self.assertRaises(
InvalidDocumentError, throw_invalid_document_error_embedded)
def test_reverse_delete_rule_cascade_recurs(self): def test_reverse_delete_rule_cascade_recurs(self):
"""Ensure that a chain of documents is also deleted upon cascaded """Ensure that a chain of documents is also deleted upon cascaded
deletion. deletion.
""" """
class BlogPost(Document): class BlogPost(Document):
content = StringField() content = StringField()
author = ReferenceField(self.Person, reverse_delete_rule=CASCADE) author = ReferenceField(self.Person, reverse_delete_rule=CASCADE)
@ -2340,15 +2323,14 @@ class InstanceTest(unittest.TestCase):
pickle_doc.save() pickle_doc.save()
pickle_doc.delete() pickle_doc.delete()
def test_throw_invalid_document_error(self): def test_override_method_with_field(self):
"""Test creating a field with a field name that would override
# test handles people trying to upsert the "validate" method.
def throw_invalid_document_error(): """
with self.assertRaises(InvalidDocumentError):
class Blog(Document): class Blog(Document):
validate = DictField() validate = DictField()
self.assertRaises(InvalidDocumentError, throw_invalid_document_error)
def test_mutating_documents(self): def test_mutating_documents(self):
class B(EmbeddedDocument): class B(EmbeddedDocument):
@ -2811,11 +2793,10 @@ class InstanceTest(unittest.TestCase):
log.log = "Saving" log.log = "Saving"
log.save() log.save()
def change_shard_key(): # try to change the shard key
with self.assertRaises(OperationError):
log.machine = "127.0.0.1" log.machine = "127.0.0.1"
self.assertRaises(OperationError, change_shard_key)
def test_shard_key_in_embedded_document(self): def test_shard_key_in_embedded_document(self):
class Foo(EmbeddedDocument): class Foo(EmbeddedDocument):
foo = StringField() foo = StringField()
@ -2836,12 +2817,11 @@ class InstanceTest(unittest.TestCase):
bar_doc.bar = 'baz' bar_doc.bar = 'baz'
bar_doc.save() bar_doc.save()
def change_shard_key(): # try to change the shard key
with self.assertRaises(OperationError):
bar_doc.foo.foo = 'something' bar_doc.foo.foo = 'something'
bar_doc.save() bar_doc.save()
self.assertRaises(OperationError, change_shard_key)
def test_shard_key_primary(self): def test_shard_key_primary(self):
class LogEntry(Document): class LogEntry(Document):
machine = StringField(primary_key=True) machine = StringField(primary_key=True)
@ -2862,11 +2842,10 @@ class InstanceTest(unittest.TestCase):
log.log = "Saving" log.log = "Saving"
log.save() log.save()
def change_shard_key(): # try to change the shard key
with self.assertRaises(OperationError):
log.machine = "127.0.0.1" log.machine = "127.0.0.1"
self.assertRaises(OperationError, change_shard_key)
def test_kwargs_simple(self): def test_kwargs_simple(self):
class Embedded(EmbeddedDocument): class Embedded(EmbeddedDocument):
@ -2951,11 +2930,9 @@ class InstanceTest(unittest.TestCase):
def test_bad_mixed_creation(self): def test_bad_mixed_creation(self):
"""Ensure that document gives correct error when duplicating arguments """Ensure that document gives correct error when duplicating arguments
""" """
def construct_bad_instance(): with self.assertRaises(TypeError):
return self.Person("Test User", 42, name="Bad User") return self.Person("Test User", 42, name="Bad User")
self.assertRaises(TypeError, construct_bad_instance)
def test_data_contains_id_field(self): def test_data_contains_id_field(self):
"""Ensure that asking for _data returns 'id' """Ensure that asking for _data returns 'id'
""" """

View File

@ -153,14 +153,14 @@ class ValidatorErrorTest(unittest.TestCase):
s = SubDoc() s = SubDoc()
self.assertRaises(ValidationError, lambda: s.validate()) self.assertRaises(ValidationError, s.validate)
d1.e = s d1.e = s
d2.e = s d2.e = s
del d1 del d1
self.assertRaises(ValidationError, lambda: d2.validate()) self.assertRaises(ValidationError, d2.validate)
def test_parent_reference_in_child_document(self): def test_parent_reference_in_child_document(self):
""" """

View File

@ -1115,12 +1115,11 @@ class FieldTest(unittest.TestCase):
e.mapping = [1] e.mapping = [1]
e.save() e.save()
def create_invalid_mapping(): # try creating an invalid mapping
with self.assertRaises(ValidationError):
e.mapping = ["abc"] e.mapping = ["abc"]
e.save() e.save()
self.assertRaises(ValidationError, create_invalid_mapping)
Simple.drop_collection() Simple.drop_collection()
def test_list_field_rejects_strings(self): def test_list_field_rejects_strings(self):
@ -1387,12 +1386,11 @@ class FieldTest(unittest.TestCase):
e.mapping['someint'] = 1 e.mapping['someint'] = 1
e.save() e.save()
def create_invalid_mapping(): # try creating an invalid mapping
with self.assertRaises(ValidationError):
e.mapping['somestring'] = "abc" e.mapping['somestring'] = "abc"
e.save() e.save()
self.assertRaises(ValidationError, create_invalid_mapping)
Simple.drop_collection() Simple.drop_collection()
def test_dictfield_complex(self): def test_dictfield_complex(self):
@ -1465,11 +1463,10 @@ class FieldTest(unittest.TestCase):
self.assertEqual(BaseDict, type(e.mapping)) self.assertEqual(BaseDict, type(e.mapping))
self.assertEqual({"ints": [3, 4]}, e.mapping) self.assertEqual({"ints": [3, 4]}, e.mapping)
def create_invalid_mapping(): # try creating an invalid mapping
with self.assertRaises(ValueError):
e.update(set__mapping={"somestrings": ["foo", "bar", ]}) e.update(set__mapping={"somestrings": ["foo", "bar", ]})
self.assertRaises(ValueError, create_invalid_mapping)
Simple.drop_collection() Simple.drop_collection()
def test_mapfield(self): def test_mapfield(self):
@ -1484,18 +1481,14 @@ class FieldTest(unittest.TestCase):
e.mapping['someint'] = 1 e.mapping['someint'] = 1
e.save() e.save()
def create_invalid_mapping(): with self.assertRaises(ValidationError):
e.mapping['somestring'] = "abc" e.mapping['somestring'] = "abc"
e.save() e.save()
self.assertRaises(ValidationError, create_invalid_mapping) with self.assertRaises(ValidationError):
def create_invalid_class():
class NoDeclaredType(Document): class NoDeclaredType(Document):
mapping = MapField() mapping = MapField()
self.assertRaises(ValidationError, create_invalid_class)
Simple.drop_collection() Simple.drop_collection()
def test_complex_mapfield(self): def test_complex_mapfield(self):
@ -1524,14 +1517,10 @@ class FieldTest(unittest.TestCase):
self.assertTrue(isinstance(e2.mapping['somestring'], StringSetting)) self.assertTrue(isinstance(e2.mapping['somestring'], StringSetting))
self.assertTrue(isinstance(e2.mapping['someint'], IntegerSetting)) self.assertTrue(isinstance(e2.mapping['someint'], IntegerSetting))
def create_invalid_mapping(): with self.assertRaises(ValidationError):
e.mapping['someint'] = 123 e.mapping['someint'] = 123
e.save() e.save()
self.assertRaises(ValidationError, create_invalid_mapping)
Extensible.drop_collection()
def test_embedded_mapfield_db_field(self): def test_embedded_mapfield_db_field(self):
class Embedded(EmbeddedDocument): class Embedded(EmbeddedDocument):
@ -1741,8 +1730,8 @@ class FieldTest(unittest.TestCase):
# Reference is no longer valid # Reference is no longer valid
foo.delete() foo.delete()
bar = Bar.objects.get() bar = Bar.objects.get()
self.assertRaises(DoesNotExist, lambda: getattr(bar, 'ref')) self.assertRaises(DoesNotExist, getattr, bar, 'ref')
self.assertRaises(DoesNotExist, lambda: getattr(bar, 'generic_ref')) self.assertRaises(DoesNotExist, getattr, bar, 'generic_ref')
# When auto_dereference is disabled, there is no trouble returning DBRef # When auto_dereference is disabled, there is no trouble returning DBRef
bar = Bar.objects.get() bar = Bar.objects.get()
@ -2017,7 +2006,7 @@ class FieldTest(unittest.TestCase):
}) })
def test_cached_reference_fields_on_embedded_documents(self): def test_cached_reference_fields_on_embedded_documents(self):
def build(): with self.assertRaises(InvalidDocumentError):
class Test(Document): class Test(Document):
name = StringField() name = StringField()
@ -2026,8 +2015,6 @@ class FieldTest(unittest.TestCase):
'test': CachedReferenceField(Test) 'test': CachedReferenceField(Test)
}) })
self.assertRaises(InvalidDocumentError, build)
def test_cached_reference_auto_sync(self): def test_cached_reference_auto_sync(self):
class Person(Document): class Person(Document):
TYPES = ( TYPES = (
@ -3815,9 +3802,7 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase):
filtered = self.post1.comments.filter() filtered = self.post1.comments.filter()
# Ensure nothing was changed # Ensure nothing was changed
# < 2.6 Incompatible > self.assertListEqual(filtered, self.post1.comments)
# self.assertListEqual(filtered, self.post1.comments)
self.assertEqual(filtered, self.post1.comments)
def test_single_keyword_filter(self): def test_single_keyword_filter(self):
""" """
@ -3868,10 +3853,8 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase):
Tests the filter method of a List of Embedded Documents Tests the filter method of a List of Embedded Documents
when the keyword is not a known keyword. when the keyword is not a known keyword.
""" """
# < 2.6 Incompatible > with self.assertRaises(AttributeError):
# with self.assertRaises(AttributeError): self.post2.comments.filter(year=2)
# self.post2.comments.filter(year=2)
self.assertRaises(AttributeError, self.post2.comments.filter, year=2)
def test_no_keyword_exclude(self): def test_no_keyword_exclude(self):
""" """
@ -3881,9 +3864,7 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase):
filtered = self.post1.comments.exclude() filtered = self.post1.comments.exclude()
# Ensure everything was removed # Ensure everything was removed
# < 2.6 Incompatible > self.assertListEqual(filtered, [])
# self.assertListEqual(filtered, [])
self.assertEqual(filtered, [])
def test_single_keyword_exclude(self): def test_single_keyword_exclude(self):
""" """
@ -3929,10 +3910,8 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase):
Tests the exclude method of a List of Embedded Documents Tests the exclude method of a List of Embedded Documents
when the keyword is not a known keyword. when the keyword is not a known keyword.
""" """
# < 2.6 Incompatible > with self.assertRaises(AttributeError):
# with self.assertRaises(AttributeError): self.post2.comments.exclude(year=2)
# self.post2.comments.exclude(year=2)
self.assertRaises(AttributeError, self.post2.comments.exclude, year=2)
def test_chained_filter_exclude(self): def test_chained_filter_exclude(self):
""" """
@ -3970,10 +3949,7 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase):
single keyword. single keyword.
""" """
comment = self.post1.comments.get(author='user1') comment = self.post1.comments.get(author='user1')
self.assertIsInstance(comment, self.Comments)
# < 2.6 Incompatible >
# self.assertIsInstance(comment, self.Comments)
self.assertTrue(isinstance(comment, self.Comments))
self.assertEqual(comment.author, 'user1') self.assertEqual(comment.author, 'user1')
def test_multi_keyword_get(self): def test_multi_keyword_get(self):
@ -3982,10 +3958,7 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase):
multiple keywords. multiple keywords.
""" """
comment = self.post2.comments.get(author='user2', message='message2') comment = self.post2.comments.get(author='user2', message='message2')
self.assertIsInstance(comment, self.Comments)
# < 2.6 Incompatible >
# self.assertIsInstance(comment, self.Comments)
self.assertTrue(isinstance(comment, self.Comments))
self.assertEqual(comment.author, 'user2') self.assertEqual(comment.author, 'user2')
self.assertEqual(comment.message, 'message2') self.assertEqual(comment.message, 'message2')
@ -3994,44 +3967,32 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase):
Tests the get method of a List of Embedded Documents without Tests the get method of a List of Embedded Documents without
a keyword to return multiple documents. a keyword to return multiple documents.
""" """
# < 2.6 Incompatible > with self.assertRaises(MultipleObjectsReturned):
# with self.assertRaises(MultipleObjectsReturned): self.post1.comments.get()
# self.post1.comments.get()
self.assertRaises(MultipleObjectsReturned, self.post1.comments.get)
def test_keyword_multiple_return_get(self): def test_keyword_multiple_return_get(self):
""" """
Tests the get method of a List of Embedded Documents with a keyword Tests the get method of a List of Embedded Documents with a keyword
to return multiple documents. to return multiple documents.
""" """
# < 2.6 Incompatible > with self.assertRaises(MultipleObjectsReturned):
# with self.assertRaises(MultipleObjectsReturned): self.post2.comments.get(author='user2')
# self.post2.comments.get(author='user2')
self.assertRaises(
MultipleObjectsReturned, self.post2.comments.get, author='user2'
)
def test_unknown_keyword_get(self): def test_unknown_keyword_get(self):
""" """
Tests the get method of a List of Embedded Documents with an Tests the get method of a List of Embedded Documents with an
unknown keyword. unknown keyword.
""" """
# < 2.6 Incompatible > with self.assertRaises(AttributeError):
# with self.assertRaises(AttributeError): self.post2.comments.get(year=2020)
# self.post2.comments.get(year=2020)
self.assertRaises(AttributeError, self.post2.comments.get, year=2020)
def test_no_result_get(self): def test_no_result_get(self):
""" """
Tests the get method of a List of Embedded Documents where get Tests the get method of a List of Embedded Documents where get
returns no results. returns no results.
""" """
# < 2.6 Incompatible > with self.assertRaises(DoesNotExist):
# with self.assertRaises(DoesNotExist): self.post1.comments.get(author='user3')
# self.post1.comments.get(author='user3')
self.assertRaises(
DoesNotExist, self.post1.comments.get, author='user3'
)
def test_first(self): def test_first(self):
""" """
@ -4041,9 +4002,7 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase):
comment = self.post1.comments.first() comment = self.post1.comments.first()
# Ensure a Comment object was returned. # Ensure a Comment object was returned.
# < 2.6 Incompatible > self.assertIsInstance(comment, self.Comments)
# self.assertIsInstance(comment, self.Comments)
self.assertTrue(isinstance(comment, self.Comments))
self.assertEqual(comment, self.post1.comments[0]) self.assertEqual(comment, self.post1.comments[0])
def test_create(self): def test_create(self):
@ -4056,22 +4015,14 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase):
self.post1.save() self.post1.save()
# Ensure the returned value is the comment object. # Ensure the returned value is the comment object.
# < 2.6 Incompatible > self.assertIsInstance(comment, self.Comments)
# self.assertIsInstance(comment, self.Comments)
self.assertTrue(isinstance(comment, self.Comments))
self.assertEqual(comment.author, 'user4') self.assertEqual(comment.author, 'user4')
self.assertEqual(comment.message, 'message1') self.assertEqual(comment.message, 'message1')
# Ensure the new comment was actually saved to the database. # Ensure the new comment was actually saved to the database.
# < 2.6 Incompatible > self.assertIn(
# self.assertIn( comment,
# comment, self.BlogPost.objects(comments__author='user4')[0].comments
# self.BlogPost.objects(comments__author='user4')[0].comments
# )
self.assertTrue(
comment in self.BlogPost.objects(
comments__author='user4'
)[0].comments
) )
def test_filtered_create(self): def test_filtered_create(self):
@ -4086,22 +4037,14 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase):
self.post1.save() self.post1.save()
# Ensure the returned value is the comment object. # Ensure the returned value is the comment object.
# < 2.6 Incompatible > self.assertIsInstance(comment, self.Comments)
# self.assertIsInstance(comment, self.Comments)
self.assertTrue(isinstance(comment, self.Comments))
self.assertEqual(comment.author, 'user4') self.assertEqual(comment.author, 'user4')
self.assertEqual(comment.message, 'message1') self.assertEqual(comment.message, 'message1')
# Ensure the new comment was actually saved to the database. # Ensure the new comment was actually saved to the database.
# < 2.6 Incompatible > self.assertIn(
# self.assertIn( comment,
# comment, self.BlogPost.objects(comments__author='user4')[0].comments
# self.BlogPost.objects(comments__author='user4')[0].comments
# )
self.assertTrue(
comment in self.BlogPost.objects(
comments__author='user4'
)[0].comments
) )
def test_no_keyword_update(self): def test_no_keyword_update(self):
@ -4114,22 +4057,14 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase):
self.post1.save() self.post1.save()
# Ensure that nothing was altered. # Ensure that nothing was altered.
# < 2.6 Incompatible > self.assertIn(
# self.assertIn( original[0],
# original[0], self.BlogPost.objects(id=self.post1.id)[0].comments
# self.BlogPost.objects(id=self.post1.id)[0].comments
# )
self.assertTrue(
original[0] in self.BlogPost.objects(id=self.post1.id)[0].comments
) )
# < 2.6 Incompatible > self.assertIn(
# self.assertIn( original[1],
# original[1], self.BlogPost.objects(id=self.post1.id)[0].comments
# self.BlogPost.objects(id=self.post1.id)[0].comments
# )
self.assertTrue(
original[1] in self.BlogPost.objects(id=self.post1.id)[0].comments
) )
# Ensure the method returned 0 as the number of entries # Ensure the method returned 0 as the number of entries
@ -4175,13 +4110,9 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase):
comments.save() comments.save()
# Ensure that the new comment has been added to the database. # Ensure that the new comment has been added to the database.
# < 2.6 Incompatible > self.assertIn(
# self.assertIn( new_comment,
# new_comment, self.BlogPost.objects(id=self.post1.id)[0].comments
# self.BlogPost.objects(id=self.post1.id)[0].comments
# )
self.assertTrue(
new_comment in self.BlogPost.objects(id=self.post1.id)[0].comments
) )
def test_delete(self): def test_delete(self):
@ -4193,23 +4124,15 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase):
# Ensure that all the comments under post1 were deleted in the # Ensure that all the comments under post1 were deleted in the
# database. # database.
# < 2.6 Incompatible > self.assertListEqual(
# self.assertListEqual(
# self.BlogPost.objects(id=self.post1.id)[0].comments, []
# )
self.assertEqual(
self.BlogPost.objects(id=self.post1.id)[0].comments, [] self.BlogPost.objects(id=self.post1.id)[0].comments, []
) )
# Ensure that post1 comments were deleted from the list. # Ensure that post1 comments were deleted from the list.
# < 2.6 Incompatible > self.assertListEqual(self.post1.comments, [])
# self.assertListEqual(self.post1.comments, [])
self.assertEqual(self.post1.comments, [])
# Ensure that comments still returned a EmbeddedDocumentList object. # Ensure that comments still returned a EmbeddedDocumentList object.
# < 2.6 Incompatible > self.assertIsInstance(self.post1.comments, EmbeddedDocumentList)
# self.assertIsInstance(self.post1.comments, EmbeddedDocumentList)
self.assertTrue(isinstance(self.post1.comments, EmbeddedDocumentList))
# Ensure that the delete method returned 2 as the number of entries # Ensure that the delete method returned 2 as the number of entries
# deleted from the database # deleted from the database
@ -4249,21 +4172,15 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase):
self.post1.save() self.post1.save()
# Ensure that only the user2 comment was deleted. # Ensure that only the user2 comment was deleted.
# < 2.6 Incompatible > self.assertNotIn(
# self.assertNotIn( comment, self.BlogPost.objects(id=self.post1.id)[0].comments
# comment, self.BlogPost.objects(id=self.post1.id)[0].comments
# )
self.assertTrue(
comment not in self.BlogPost.objects(id=self.post1.id)[0].comments
) )
self.assertEqual( self.assertEqual(
len(self.BlogPost.objects(id=self.post1.id)[0].comments), 1 len(self.BlogPost.objects(id=self.post1.id)[0].comments), 1
) )
# Ensure that the user2 comment no longer exists in the list. # Ensure that the user2 comment no longer exists in the list.
# < 2.6 Incompatible > self.assertNotIn(comment, self.post1.comments)
# self.assertNotIn(comment, self.post1.comments)
self.assertTrue(comment not in self.post1.comments)
self.assertEqual(len(self.post1.comments), 1) self.assertEqual(len(self.post1.comments), 1)
# Ensure that the delete method returned 1 as the number of entries # Ensure that the delete method returned 1 as the number of entries

View File

@ -25,7 +25,10 @@ __all__ = ("QuerySetTest",)
class db_ops_tracker(query_counter): class db_ops_tracker(query_counter):
def get_ops(self): def get_ops(self):
ignore_query = {"ns": {"$ne": "%s.system.indexes" % self.db.name}} ignore_query = {
'ns': {'$ne': '%s.system.indexes' % self.db.name},
'command.count': {'$ne': 'system.profile'}
}
return list(self.db.system.profile.find(ignore_query)) return list(self.db.system.profile.find(ignore_query))
@ -94,12 +97,12 @@ class QuerySetTest(unittest.TestCase):
author = ReferenceField(self.Person) author = ReferenceField(self.Person)
author2 = GenericReferenceField() author2 = GenericReferenceField()
def test_reference(): # test addressing a field from a reference
with self.assertRaises(InvalidQueryError):
list(BlogPost.objects(author__name="test")) list(BlogPost.objects(author__name="test"))
self.assertRaises(InvalidQueryError, test_reference) # should fail for a generic reference as well
with self.assertRaises(InvalidQueryError):
def test_generic_reference():
list(BlogPost.objects(author2__name="test")) list(BlogPost.objects(author2__name="test"))
def test_find(self): def test_find(self):
@ -218,14 +221,15 @@ class QuerySetTest(unittest.TestCase):
person = self.Person.objects[1] person = self.Person.objects[1]
self.assertEqual(person.name, "User B") self.assertEqual(person.name, "User B")
self.assertRaises(IndexError, self.Person.objects.__getitem__, 2) with self.assertRaises(IndexError):
self.Person.objects[2]
# Find a document using just the object id # Find a document using just the object id
person = self.Person.objects.with_id(person1.id) person = self.Person.objects.with_id(person1.id)
self.assertEqual(person.name, "User A") self.assertEqual(person.name, "User A")
self.assertRaises( with self.assertRaises(InvalidQueryError):
InvalidQueryError, self.Person.objects(name="User A").with_id, person1.id) self.Person.objects(name="User A").with_id(person1.id)
def test_find_only_one(self): def test_find_only_one(self):
"""Ensure that a query using ``get`` returns at most one result. """Ensure that a query using ``get`` returns at most one result.
@ -363,7 +367,8 @@ class QuerySetTest(unittest.TestCase):
# test invalid batch size # test invalid batch size
qs = A.objects.batch_size(-1) qs = A.objects.batch_size(-1)
self.assertRaises(ValueError, lambda: list(qs)) with self.assertRaises(ValueError):
list(qs)
def test_update_write_concern(self): def test_update_write_concern(self):
"""Test that passing write_concern works""" """Test that passing write_concern works"""
@ -392,18 +397,14 @@ class QuerySetTest(unittest.TestCase):
"""Test to ensure that update is passed a value to update to""" """Test to ensure that update is passed a value to update to"""
self.Person.drop_collection() self.Person.drop_collection()
author = self.Person(name='Test User') author = self.Person.objects.create(name='Test User')
author.save()
def update_raises(): with self.assertRaises(OperationError):
self.Person.objects(pk=author.pk).update({}) self.Person.objects(pk=author.pk).update({})
def update_one_raises(): with self.assertRaises(OperationError):
self.Person.objects(pk=author.pk).update_one({}) self.Person.objects(pk=author.pk).update_one({})
self.assertRaises(OperationError, update_raises)
self.assertRaises(OperationError, update_one_raises)
def test_update_array_position(self): def test_update_array_position(self):
"""Ensure that updating by array position works. """Ensure that updating by array position works.
@ -431,8 +432,8 @@ class QuerySetTest(unittest.TestCase):
Blog.objects.create(posts=[post2, post1]) Blog.objects.create(posts=[post2, post1])
# Update all of the first comments of second posts of all blogs # Update all of the first comments of second posts of all blogs
Blog.objects().update(set__posts__1__comments__0__name="testc") Blog.objects().update(set__posts__1__comments__0__name='testc')
testc_blogs = Blog.objects(posts__1__comments__0__name="testc") testc_blogs = Blog.objects(posts__1__comments__0__name='testc')
self.assertEqual(testc_blogs.count(), 2) self.assertEqual(testc_blogs.count(), 2)
Blog.drop_collection() Blog.drop_collection()
@ -441,14 +442,13 @@ class QuerySetTest(unittest.TestCase):
# Update only the first blog returned by the query # Update only the first blog returned by the query
Blog.objects().update_one( Blog.objects().update_one(
set__posts__1__comments__1__name="testc") set__posts__1__comments__1__name='testc')
testc_blogs = Blog.objects(posts__1__comments__1__name="testc") testc_blogs = Blog.objects(posts__1__comments__1__name='testc')
self.assertEqual(testc_blogs.count(), 1) self.assertEqual(testc_blogs.count(), 1)
# Check that using this indexing syntax on a non-list fails # Check that using this indexing syntax on a non-list fails
def non_list_indexing(): with self.assertRaises(InvalidQueryError):
Blog.objects().update(set__posts__1__comments__0__name__1="asdf") Blog.objects().update(set__posts__1__comments__0__name__1='asdf')
self.assertRaises(InvalidQueryError, non_list_indexing)
Blog.drop_collection() Blog.drop_collection()
@ -516,15 +516,12 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(simple.x, [1, 2, None, 4, 3, 2, 3, 4]) self.assertEqual(simple.x, [1, 2, None, 4, 3, 2, 3, 4])
# Nested updates arent supported yet.. # Nested updates arent supported yet..
def update_nested(): with self.assertRaises(OperationError):
Simple.drop_collection() Simple.drop_collection()
Simple(x=[{'test': [1, 2, 3, 4]}]).save() Simple(x=[{'test': [1, 2, 3, 4]}]).save()
Simple.objects(x__test=2).update(set__x__S__test__S=3) Simple.objects(x__test=2).update(set__x__S__test__S=3)
self.assertEqual(simple.x, [1, 2, 3, 4]) self.assertEqual(simple.x, [1, 2, 3, 4])
self.assertRaises(OperationError, update_nested)
Simple.drop_collection()
def test_update_using_positional_operator_embedded_document(self): def test_update_using_positional_operator_embedded_document(self):
"""Ensure that the embedded documents can be updated using the positional """Ensure that the embedded documents can be updated using the positional
operator.""" operator."""
@ -839,30 +836,31 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(Blog.objects.count(), 2) self.assertEqual(Blog.objects.count(), 2)
# test handles people trying to upsert # test inserting an existing document (shouldn't be allowed)
def throw_operation_error(): with self.assertRaises(OperationError):
blog = Blog.objects.first()
Blog.objects.insert(blog)
# test inserting a query set
with self.assertRaises(OperationError):
blogs = Blog.objects blogs = Blog.objects
Blog.objects.insert(blogs) Blog.objects.insert(blogs)
self.assertRaises(OperationError, throw_operation_error) # insert a new doc
# Test can insert new doc
new_post = Blog(title="code123", id=ObjectId()) new_post = Blog(title="code123", id=ObjectId())
Blog.objects.insert(new_post) Blog.objects.insert(new_post)
# test handles other classes being inserted
def throw_operation_error_wrong_doc():
class Author(Document): class Author(Document):
pass pass
# try inserting a different document class
with self.assertRaises(OperationError):
Blog.objects.insert(Author()) Blog.objects.insert(Author())
self.assertRaises(OperationError, throw_operation_error_wrong_doc) # try inserting a non-document
with self.assertRaises(OperationError):
def throw_operation_error_not_a_document():
Blog.objects.insert("HELLO WORLD") Blog.objects.insert("HELLO WORLD")
self.assertRaises(OperationError, throw_operation_error_not_a_document)
Blog.drop_collection() Blog.drop_collection()
blog1 = Blog(title="code", posts=[post1, post2]) blog1 = Blog(title="code", posts=[post1, post2])
@ -882,14 +880,13 @@ class QuerySetTest(unittest.TestCase):
blog3 = Blog(title="baz", posts=[post1, post2]) blog3 = Blog(title="baz", posts=[post1, post2])
Blog.objects.insert([blog1, blog2]) Blog.objects.insert([blog1, blog2])
def throw_operation_error_not_unique(): with self.assertRaises(NotUniqueError):
Blog.objects.insert([blog2, blog3]) Blog.objects.insert([blog2, blog3])
self.assertRaises(NotUniqueError, throw_operation_error_not_unique)
self.assertEqual(Blog.objects.count(), 2) self.assertEqual(Blog.objects.count(), 2)
Blog.objects.insert([blog2, blog3], write_concern={"w": 0, Blog.objects.insert([blog2, blog3],
'continue_on_error': True}) write_concern={"w": 0, 'continue_on_error': True})
self.assertEqual(Blog.objects.count(), 3) self.assertEqual(Blog.objects.count(), 3)
def test_get_changed_fields_query_count(self): def test_get_changed_fields_query_count(self):
@ -1233,7 +1230,9 @@ class QuerySetTest(unittest.TestCase):
BlogPost.objects.filter(title='whatever').first() BlogPost.objects.filter(title='whatever').first()
self.assertEqual(len(q.get_ops()), 1) self.assertEqual(len(q.get_ops()), 1)
self.assertEqual( self.assertEqual(
q.get_ops()[0]['query']['$orderby'], {u'published_date': -1}) q.get_ops()[0]['query']['$orderby'],
{'published_date': -1}
)
with db_ops_tracker() as q: with db_ops_tracker() as q:
BlogPost.objects.filter(title='whatever').order_by().first() BlogPost.objects.filter(title='whatever').order_by().first()
@ -1910,12 +1909,10 @@ class QuerySetTest(unittest.TestCase):
Site.objects(id=s.id).update_one(pull__collaborators__user='Esteban') Site.objects(id=s.id).update_one(pull__collaborators__user='Esteban')
self.assertEqual(Site.objects.first().collaborators, []) self.assertEqual(Site.objects.first().collaborators, [])
def pull_all(): with self.assertRaises(InvalidQueryError):
Site.objects(id=s.id).update_one( Site.objects(id=s.id).update_one(
pull_all__collaborators__user=['Ross']) pull_all__collaborators__user=['Ross'])
self.assertRaises(InvalidQueryError, pull_all)
def test_pull_from_nested_embedded(self): def test_pull_from_nested_embedded(self):
class User(EmbeddedDocument): class User(EmbeddedDocument):
@ -1946,12 +1943,10 @@ class QuerySetTest(unittest.TestCase):
pull__collaborators__unhelpful={'name': 'Frank'}) pull__collaborators__unhelpful={'name': 'Frank'})
self.assertEqual(Site.objects.first().collaborators['unhelpful'], []) self.assertEqual(Site.objects.first().collaborators['unhelpful'], [])
def pull_all(): with self.assertRaises(InvalidQueryError):
Site.objects(id=s.id).update_one( Site.objects(id=s.id).update_one(
pull_all__collaborators__helpful__name=['Ross']) pull_all__collaborators__helpful__name=['Ross'])
self.assertRaises(InvalidQueryError, pull_all)
def test_pull_from_nested_mapfield(self): def test_pull_from_nested_mapfield(self):
class Collaborator(EmbeddedDocument): class Collaborator(EmbeddedDocument):
@ -1980,12 +1975,10 @@ class QuerySetTest(unittest.TestCase):
pull__collaborators__unhelpful={'user': 'Frank'}) pull__collaborators__unhelpful={'user': 'Frank'})
self.assertEqual(Site.objects.first().collaborators['unhelpful'], []) self.assertEqual(Site.objects.first().collaborators['unhelpful'], [])
def pull_all(): with self.assertRaises(InvalidQueryError):
Site.objects(id=s.id).update_one( Site.objects(id=s.id).update_one(
pull_all__collaborators__helpful__user=['Ross']) pull_all__collaborators__helpful__user=['Ross'])
self.assertRaises(InvalidQueryError, pull_all)
def test_update_one_pop_generic_reference(self): def test_update_one_pop_generic_reference(self):
class BlogTag(Document): class BlogTag(Document):
@ -3821,11 +3814,9 @@ class QuerySetTest(unittest.TestCase):
self.assertTrue(a in results) self.assertTrue(a in results)
self.assertTrue(c in results) self.assertTrue(c in results)
def invalid_where(): with self.assertRaises(TypeError):
list(IntPair.objects.where(fielda__gte=3)) list(IntPair.objects.where(fielda__gte=3))
self.assertRaises(TypeError, invalid_where)
def test_scalar(self): def test_scalar(self):
class Organization(Document): class Organization(Document):
@ -4550,7 +4541,9 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(counter, 100) self.assertEqual(counter, 100)
self.assertEqual(len(list(docs)), 100) self.assertEqual(len(list(docs)), 100)
self.assertRaises(TypeError, lambda: len(docs))
with self.assertRaises(TypeError):
len(docs)
with query_counter() as q: with query_counter() as q:
self.assertEqual(q, 0) self.assertEqual(q, 0)
@ -4875,7 +4868,9 @@ class QuerySetTest(unittest.TestCase):
def test_max_time_ms(self): def test_max_time_ms(self):
# 778: max_time_ms can get only int or None as input # 778: max_time_ms can get only int or None as input
self.assertRaises(TypeError, self.Person.objects(name="name").max_time_ms, "not a number") self.assertRaises(TypeError,
self.Person.objects(name="name").max_time_ms,
'not a number')
def test_subclass_field_query(self): def test_subclass_field_query(self):
class Animal(Document): class Animal(Document):

View File

@ -238,7 +238,8 @@ class TransformTest(unittest.TestCase):
box = [(35.0, -125.0), (40.0, -100.0)] box = [(35.0, -125.0), (40.0, -100.0)]
# I *meant* to execute location__within_box=box # I *meant* to execute location__within_box=box
events = Event.objects(location__within=box) events = Event.objects(location__within=box)
self.assertRaises(InvalidQueryError, lambda: events.count()) with self.assertRaises(InvalidQueryError):
events.count()
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -268,14 +268,13 @@ class QTest(unittest.TestCase):
self.assertEqual(self.Person.objects(Q(age__in=[20, 30])).count(), 3) self.assertEqual(self.Person.objects(Q(age__in=[20, 30])).count(), 3)
# Test invalid query objs # Test invalid query objs
def wrong_query_objs(): with self.assertRaises(InvalidQueryError):
self.Person.objects('user1') self.Person.objects('user1')
def wrong_query_objs_filter(): # filter should fail, too
self.Person.objects('user1') with self.assertRaises(InvalidQueryError):
self.Person.objects.filter('user1')
self.assertRaises(InvalidQueryError, wrong_query_objs)
self.assertRaises(InvalidQueryError, wrong_query_objs_filter)
def test_q_regex(self): def test_q_regex(self):
"""Ensure that Q objects can be queried using regexes. """Ensure that Q objects can be queried using regexes.

View File

@ -23,7 +23,8 @@ class TestStrictDict(unittest.TestCase):
self.assertEqual(repr(d), '{"a": \'"\', "b": "\'", "c": \'\'}') self.assertEqual(repr(d), '{"a": \'"\', "b": "\'", "c": \'\'}')
def test_init_fails_on_nonexisting_attrs(self): def test_init_fails_on_nonexisting_attrs(self):
self.assertRaises(AttributeError, lambda: self.dtype(a=1, b=2, d=3)) with self.assertRaises(AttributeError):
self.dtype(a=1, b=2, d=3)
def test_eq(self): def test_eq(self):
d = self.dtype(a=1, b=1, c=1) d = self.dtype(a=1, b=1, c=1)
@ -46,14 +47,12 @@ class TestStrictDict(unittest.TestCase):
d = self.dtype() d = self.dtype()
d.a = 1 d.a = 1
self.assertEqual(d.a, 1) self.assertEqual(d.a, 1)
self.assertRaises(AttributeError, lambda: d.b) self.assertRaises(AttributeError, getattr, d, 'b')
def test_setattr_raises_on_nonexisting_attr(self): def test_setattr_raises_on_nonexisting_attr(self):
d = self.dtype() d = self.dtype()
with self.assertRaises(AttributeError):
def _f():
d.x = 1 d.x = 1
self.assertRaises(AttributeError, _f)
def test_setattr_getattr_special(self): def test_setattr_getattr_special(self):
d = self.strict_dict_class(["items"]) d = self.strict_dict_class(["items"])