use with self.assertRaises for readability

This commit is contained in:
Stefan Wojcik
2016-12-10 22:33:39 -05:00
parent a8884391c2
commit 3ebe3748fa
8 changed files with 206 additions and 326 deletions

View File

@@ -25,7 +25,10 @@ __all__ = ("QuerySetTest",)
class db_ops_tracker(query_counter):
def get_ops(self):
ignore_query = {"ns": {"$ne": "%s.system.indexes" % self.db.name}}
ignore_query = {
'ns': {'$ne': '%s.system.indexes' % self.db.name},
'command.count': {'$ne': 'system.profile'}
}
return list(self.db.system.profile.find(ignore_query))
@@ -94,12 +97,12 @@ class QuerySetTest(unittest.TestCase):
author = ReferenceField(self.Person)
author2 = GenericReferenceField()
def test_reference():
# test addressing a field from a reference
with self.assertRaises(InvalidQueryError):
list(BlogPost.objects(author__name="test"))
self.assertRaises(InvalidQueryError, test_reference)
def test_generic_reference():
# should fail for a generic reference as well
with self.assertRaises(InvalidQueryError):
list(BlogPost.objects(author2__name="test"))
def test_find(self):
@@ -218,14 +221,15 @@ class QuerySetTest(unittest.TestCase):
person = self.Person.objects[1]
self.assertEqual(person.name, "User B")
self.assertRaises(IndexError, self.Person.objects.__getitem__, 2)
with self.assertRaises(IndexError):
self.Person.objects[2]
# Find a document using just the object id
person = self.Person.objects.with_id(person1.id)
self.assertEqual(person.name, "User A")
self.assertRaises(
InvalidQueryError, self.Person.objects(name="User A").with_id, person1.id)
with self.assertRaises(InvalidQueryError):
self.Person.objects(name="User A").with_id(person1.id)
def test_find_only_one(self):
"""Ensure that a query using ``get`` returns at most one result.
@@ -363,7 +367,8 @@ class QuerySetTest(unittest.TestCase):
# test invalid batch size
qs = A.objects.batch_size(-1)
self.assertRaises(ValueError, lambda: list(qs))
with self.assertRaises(ValueError):
list(qs)
def test_update_write_concern(self):
"""Test that passing write_concern works"""
@@ -392,18 +397,14 @@ class QuerySetTest(unittest.TestCase):
"""Test to ensure that update is passed a value to update to"""
self.Person.drop_collection()
author = self.Person(name='Test User')
author.save()
author = self.Person.objects.create(name='Test User')
def update_raises():
with self.assertRaises(OperationError):
self.Person.objects(pk=author.pk).update({})
def update_one_raises():
with self.assertRaises(OperationError):
self.Person.objects(pk=author.pk).update_one({})
self.assertRaises(OperationError, update_raises)
self.assertRaises(OperationError, update_one_raises)
def test_update_array_position(self):
"""Ensure that updating by array position works.
@@ -431,8 +432,8 @@ class QuerySetTest(unittest.TestCase):
Blog.objects.create(posts=[post2, post1])
# Update all of the first comments of second posts of all blogs
Blog.objects().update(set__posts__1__comments__0__name="testc")
testc_blogs = Blog.objects(posts__1__comments__0__name="testc")
Blog.objects().update(set__posts__1__comments__0__name='testc')
testc_blogs = Blog.objects(posts__1__comments__0__name='testc')
self.assertEqual(testc_blogs.count(), 2)
Blog.drop_collection()
@@ -441,14 +442,13 @@ class QuerySetTest(unittest.TestCase):
# Update only the first blog returned by the query
Blog.objects().update_one(
set__posts__1__comments__1__name="testc")
testc_blogs = Blog.objects(posts__1__comments__1__name="testc")
set__posts__1__comments__1__name='testc')
testc_blogs = Blog.objects(posts__1__comments__1__name='testc')
self.assertEqual(testc_blogs.count(), 1)
# Check that using this indexing syntax on a non-list fails
def non_list_indexing():
Blog.objects().update(set__posts__1__comments__0__name__1="asdf")
self.assertRaises(InvalidQueryError, non_list_indexing)
with self.assertRaises(InvalidQueryError):
Blog.objects().update(set__posts__1__comments__0__name__1='asdf')
Blog.drop_collection()
@@ -516,15 +516,12 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(simple.x, [1, 2, None, 4, 3, 2, 3, 4])
# Nested updates arent supported yet..
def update_nested():
with self.assertRaises(OperationError):
Simple.drop_collection()
Simple(x=[{'test': [1, 2, 3, 4]}]).save()
Simple.objects(x__test=2).update(set__x__S__test__S=3)
self.assertEqual(simple.x, [1, 2, 3, 4])
self.assertRaises(OperationError, update_nested)
Simple.drop_collection()
def test_update_using_positional_operator_embedded_document(self):
"""Ensure that the embedded documents can be updated using the positional
operator."""
@@ -839,30 +836,31 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(Blog.objects.count(), 2)
# test handles people trying to upsert
def throw_operation_error():
# test inserting an existing document (shouldn't be allowed)
with self.assertRaises(OperationError):
blog = Blog.objects.first()
Blog.objects.insert(blog)
# test inserting a query set
with self.assertRaises(OperationError):
blogs = Blog.objects
Blog.objects.insert(blogs)
self.assertRaises(OperationError, throw_operation_error)
# Test can insert new doc
# insert a new doc
new_post = Blog(title="code123", id=ObjectId())
Blog.objects.insert(new_post)
# test handles other classes being inserted
def throw_operation_error_wrong_doc():
class Author(Document):
pass
class Author(Document):
pass
# try inserting a different document class
with self.assertRaises(OperationError):
Blog.objects.insert(Author())
self.assertRaises(OperationError, throw_operation_error_wrong_doc)
def throw_operation_error_not_a_document():
# try inserting a non-document
with self.assertRaises(OperationError):
Blog.objects.insert("HELLO WORLD")
self.assertRaises(OperationError, throw_operation_error_not_a_document)
Blog.drop_collection()
blog1 = Blog(title="code", posts=[post1, post2])
@@ -882,14 +880,13 @@ class QuerySetTest(unittest.TestCase):
blog3 = Blog(title="baz", posts=[post1, post2])
Blog.objects.insert([blog1, blog2])
def throw_operation_error_not_unique():
with self.assertRaises(NotUniqueError):
Blog.objects.insert([blog2, blog3])
self.assertRaises(NotUniqueError, throw_operation_error_not_unique)
self.assertEqual(Blog.objects.count(), 2)
Blog.objects.insert([blog2, blog3], write_concern={"w": 0,
'continue_on_error': True})
Blog.objects.insert([blog2, blog3],
write_concern={"w": 0, 'continue_on_error': True})
self.assertEqual(Blog.objects.count(), 3)
def test_get_changed_fields_query_count(self):
@@ -1233,7 +1230,9 @@ class QuerySetTest(unittest.TestCase):
BlogPost.objects.filter(title='whatever').first()
self.assertEqual(len(q.get_ops()), 1)
self.assertEqual(
q.get_ops()[0]['query']['$orderby'], {u'published_date': -1})
q.get_ops()[0]['query']['$orderby'],
{'published_date': -1}
)
with db_ops_tracker() as q:
BlogPost.objects.filter(title='whatever').order_by().first()
@@ -1910,12 +1909,10 @@ class QuerySetTest(unittest.TestCase):
Site.objects(id=s.id).update_one(pull__collaborators__user='Esteban')
self.assertEqual(Site.objects.first().collaborators, [])
def pull_all():
with self.assertRaises(InvalidQueryError):
Site.objects(id=s.id).update_one(
pull_all__collaborators__user=['Ross'])
self.assertRaises(InvalidQueryError, pull_all)
def test_pull_from_nested_embedded(self):
class User(EmbeddedDocument):
@@ -1946,12 +1943,10 @@ class QuerySetTest(unittest.TestCase):
pull__collaborators__unhelpful={'name': 'Frank'})
self.assertEqual(Site.objects.first().collaborators['unhelpful'], [])
def pull_all():
with self.assertRaises(InvalidQueryError):
Site.objects(id=s.id).update_one(
pull_all__collaborators__helpful__name=['Ross'])
self.assertRaises(InvalidQueryError, pull_all)
def test_pull_from_nested_mapfield(self):
class Collaborator(EmbeddedDocument):
@@ -1980,12 +1975,10 @@ class QuerySetTest(unittest.TestCase):
pull__collaborators__unhelpful={'user': 'Frank'})
self.assertEqual(Site.objects.first().collaborators['unhelpful'], [])
def pull_all():
with self.assertRaises(InvalidQueryError):
Site.objects(id=s.id).update_one(
pull_all__collaborators__helpful__user=['Ross'])
self.assertRaises(InvalidQueryError, pull_all)
def test_update_one_pop_generic_reference(self):
class BlogTag(Document):
@@ -3821,11 +3814,9 @@ class QuerySetTest(unittest.TestCase):
self.assertTrue(a in results)
self.assertTrue(c in results)
def invalid_where():
with self.assertRaises(TypeError):
list(IntPair.objects.where(fielda__gte=3))
self.assertRaises(TypeError, invalid_where)
def test_scalar(self):
class Organization(Document):
@@ -4550,7 +4541,9 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(counter, 100)
self.assertEqual(len(list(docs)), 100)
self.assertRaises(TypeError, lambda: len(docs))
with self.assertRaises(TypeError):
len(docs)
with query_counter() as q:
self.assertEqual(q, 0)
@@ -4875,7 +4868,9 @@ class QuerySetTest(unittest.TestCase):
def test_max_time_ms(self):
# 778: max_time_ms can get only int or None as input
self.assertRaises(TypeError, self.Person.objects(name="name").max_time_ms, "not a number")
self.assertRaises(TypeError,
self.Person.objects(name="name").max_time_ms,
'not a number')
def test_subclass_field_query(self):
class Animal(Document):

View File

@@ -238,7 +238,8 @@ class TransformTest(unittest.TestCase):
box = [(35.0, -125.0), (40.0, -100.0)]
# I *meant* to execute location__within_box=box
events = Event.objects(location__within=box)
self.assertRaises(InvalidQueryError, lambda: events.count())
with self.assertRaises(InvalidQueryError):
events.count()
if __name__ == '__main__':

View File

@@ -268,14 +268,13 @@ class QTest(unittest.TestCase):
self.assertEqual(self.Person.objects(Q(age__in=[20, 30])).count(), 3)
# Test invalid query objs
def wrong_query_objs():
with self.assertRaises(InvalidQueryError):
self.Person.objects('user1')
def wrong_query_objs_filter():
self.Person.objects('user1')
# filter should fail, too
with self.assertRaises(InvalidQueryError):
self.Person.objects.filter('user1')
self.assertRaises(InvalidQueryError, wrong_query_objs)
self.assertRaises(InvalidQueryError, wrong_query_objs_filter)
def test_q_regex(self):
"""Ensure that Q objects can be queried using regexes.