Merge pull request #2055 from bagerard/improve_test_cov

Improve test cov
This commit is contained in:
erdenezul 2019-05-18 12:40:20 +02:00 committed by GitHub
commit 597b962ad5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 228 additions and 42 deletions

View File

@ -184,9 +184,6 @@ class DocumentMetaclass(type):
if issubclass(new_class, EmbeddedDocument): if issubclass(new_class, EmbeddedDocument):
raise InvalidDocumentError('CachedReferenceFields is not ' raise InvalidDocumentError('CachedReferenceFields is not '
'allowed in EmbeddedDocuments') 'allowed in EmbeddedDocuments')
if not f.document_type:
raise InvalidDocumentError(
'Document is not available to sync')
if f.auto_sync: if f.auto_sync:
f.start_listener() f.start_listener()

View File

@ -31,7 +31,6 @@ def _import_class(cls_name):
field_classes = _field_list_cache field_classes = _field_list_cache
queryset_classes = ('OperationError',)
deref_classes = ('DeReference',) deref_classes = ('DeReference',)
if cls_name == 'BaseDocument': if cls_name == 'BaseDocument':
@ -43,14 +42,11 @@ def _import_class(cls_name):
elif cls_name in field_classes: elif cls_name in field_classes:
from mongoengine import fields as module from mongoengine import fields as module
import_classes = field_classes import_classes = field_classes
elif cls_name in queryset_classes:
from mongoengine import queryset as module
import_classes = queryset_classes
elif cls_name in deref_classes: elif cls_name in deref_classes:
from mongoengine import dereference as module from mongoengine import dereference as module
import_classes = deref_classes import_classes = deref_classes
else: else:
raise ValueError('No import set for: ' % cls_name) raise ValueError('No import set for: %s' % cls_name)
for cls in import_classes: for cls in import_classes:
_class_registry_cache[cls] = getattr(module, cls) _class_registry_cache[cls] = getattr(module, cls)

View File

@ -110,9 +110,6 @@ class ValidationError(AssertionError):
def build_dict(source): def build_dict(source):
errors_dict = {} errors_dict = {}
if not source:
return errors_dict
if isinstance(source, dict): if isinstance(source, dict):
for field_name, error in iteritems(source): for field_name, error in iteritems(source):
errors_dict[field_name] = build_dict(error) errors_dict[field_name] = build_dict(error)

View File

@ -152,12 +152,10 @@ class URLField(StringField):
scheme = value.split('://')[0].lower() scheme = value.split('://')[0].lower()
if scheme not in self.schemes: if scheme not in self.schemes:
self.error(u'Invalid scheme {} in URL: {}'.format(scheme, value)) self.error(u'Invalid scheme {} in URL: {}'.format(scheme, value))
return
# Then check full URL # Then check full URL
if not self.url_regex.match(value): if not self.url_regex.match(value):
self.error(u'Invalid URL: {}'.format(value)) self.error(u'Invalid URL: {}'.format(value))
return
class EmailField(StringField): class EmailField(StringField):
@ -259,10 +257,10 @@ class EmailField(StringField):
try: try:
domain_part = domain_part.encode('idna').decode('ascii') domain_part = domain_part.encode('idna').decode('ascii')
except UnicodeError: except UnicodeError:
self.error(self.error_msg % value) self.error("%s %s" % (self.error_msg % value, "(domain failed IDN encoding)"))
else: else:
if not self.validate_domain_part(domain_part): if not self.validate_domain_part(domain_part):
self.error(self.error_msg % value) self.error("%s %s" % (self.error_msg % value, "(domain validation failed)"))
class IntField(BaseField): class IntField(BaseField):

View File

@ -197,7 +197,7 @@ class BaseQuerySet(object):
only_fields=self.only_fields only_fields=self.only_fields
) )
raise AttributeError('Provide a slice or an integer index') raise TypeError('Provide a slice or an integer index')
def __iter__(self): def __iter__(self):
raise NotImplementedError raise NotImplementedError

View File

@ -88,18 +88,10 @@ def query(_doc_cls=None, **kwargs):
singular_ops = [None, 'ne', 'gt', 'gte', 'lt', 'lte', 'not'] singular_ops = [None, 'ne', 'gt', 'gte', 'lt', 'lte', 'not']
singular_ops += STRING_OPERATORS singular_ops += STRING_OPERATORS
if op in singular_ops: if op in singular_ops:
if isinstance(field, six.string_types): value = field.prepare_query_value(op, value)
if (op in STRING_OPERATORS and
isinstance(value, six.string_types)):
StringField = _import_class('StringField')
value = StringField.prepare_query_value(op, value)
else:
value = field
else:
value = field.prepare_query_value(op, value)
if isinstance(field, CachedReferenceField) and value: if isinstance(field, CachedReferenceField) and value:
value = value['_id'] value = value['_id']
elif op in ('in', 'nin', 'all', 'near') and not isinstance(value, dict): elif op in ('in', 'nin', 'all', 'near') and not isinstance(value, dict):
# Raise an error if the in/nin/all/near param is not iterable. # Raise an error if the in/nin/all/near param is not iterable.
@ -308,10 +300,6 @@ def update(_doc_cls=None, **update):
key = '.'.join(parts) key = '.'.join(parts)
if not op:
raise InvalidQueryError('Updates must supply an operation '
'eg: set__FIELD=value')
if 'pull' in op and '.' in key: if 'pull' in op and '.' in key:
# Dot operators don't work on pull operations # Dot operators don't work on pull operations
# unless they point to a list field # unless they point to a list field

View File

@ -593,8 +593,9 @@ class IndexesTest(unittest.TestCase):
# Two posts with the same slug is not allowed # Two posts with the same slug is not allowed
post2 = BlogPost(title='test2', slug='test') post2 = BlogPost(title='test2', slug='test')
self.assertRaises(NotUniqueError, post2.save) self.assertRaises(NotUniqueError, post2.save)
self.assertRaises(NotUniqueError, BlogPost.objects.insert, post2)
# Ensure backwards compatibilty for errors # Ensure backwards compatibility for errors
self.assertRaises(OperationError, post2.save) self.assertRaises(OperationError, post2.save)
@requires_mongodb_gte_34 @requires_mongodb_gte_34
@ -826,6 +827,18 @@ class IndexesTest(unittest.TestCase):
self.assertEqual(3600, self.assertEqual(3600,
info['created_1']['expireAfterSeconds']) info['created_1']['expireAfterSeconds'])
def test_index_drop_dups_silently_ignored(self):
class Customer(Document):
cust_id = IntField(unique=True, required=True)
meta = {
'indexes': ['cust_id'],
'index_drop_dups': True,
'allow_inheritance': False,
}
Customer.drop_collection()
Customer.objects.first()
def test_unique_and_indexes(self): def test_unique_and_indexes(self):
"""Ensure that 'unique' constraints aren't overridden by """Ensure that 'unique' constraints aren't overridden by
meta.indexes. meta.indexes.
@ -842,11 +855,16 @@ class IndexesTest(unittest.TestCase):
cust.save() cust.save()
cust_dupe = Customer(cust_id=1) cust_dupe = Customer(cust_id=1)
try: with self.assertRaises(NotUniqueError):
cust_dupe.save() cust_dupe.save()
raise AssertionError("We saved a dupe!")
except NotUniqueError: cust = Customer(cust_id=2)
pass cust.save()
# duplicate key on update
with self.assertRaises(NotUniqueError):
cust.cust_id = 1
cust.save()
def test_primary_save_duplicate_update_existing_object(self): def test_primary_save_duplicate_update_existing_object(self):
"""If you set a field as primary, then unexpected behaviour can occur. """If you set a field as primary, then unexpected behaviour can occur.

View File

@ -420,6 +420,12 @@ class InstanceTest(MongoDBTestCase):
person.save() person.save()
person.to_dbref() person.to_dbref()
def test_key_like_attribute_access(self):
person = self.Person(age=30)
self.assertEqual(person['age'], 30)
with self.assertRaises(KeyError):
person['unknown_attr']
def test_save_abstract_document(self): def test_save_abstract_document(self):
"""Saving an abstract document should fail.""" """Saving an abstract document should fail."""
class Doc(Document): class Doc(Document):

View File

@ -40,6 +40,11 @@ class GeoFieldTest(unittest.TestCase):
expected = "Both values (%s) in point must be float or int" % repr(coord) expected = "Both values (%s) in point must be float or int" % repr(coord)
self._test_for_expected_error(Location, coord, expected) self._test_for_expected_error(Location, coord, expected)
invalid_coords = [21, 4, 'a']
for coord in invalid_coords:
expected = "GeoPointField can only accept tuples or lists of (x, y)"
self._test_for_expected_error(Location, coord, expected)
def test_point_validation(self): def test_point_validation(self):
class Location(Document): class Location(Document):
loc = PointField() loc = PointField()

View File

@ -208,10 +208,7 @@ class TestCachedReferenceField(MongoDBTestCase):
('pj', "PJ") ('pj', "PJ")
) )
name = StringField() name = StringField()
tp = StringField( tp = StringField(choices=TYPES)
choices=TYPES
)
father = CachedReferenceField('self', fields=('tp',)) father = CachedReferenceField('self', fields=('tp',))
Person.drop_collection() Person.drop_collection()
@ -222,6 +219,9 @@ class TestCachedReferenceField(MongoDBTestCase):
a2 = Person(name='Wilson Junior', tp='pf', father=a1) a2 = Person(name='Wilson Junior', tp='pf', father=a1)
a2.save() a2.save()
a2 = Person.objects.with_id(a2.id)
self.assertEqual(a2.father.tp, a1.tp)
self.assertEqual(dict(a2.to_mongo()), { self.assertEqual(dict(a2.to_mongo()), {
"_id": a2.pk, "_id": a2.pk,
"name": u"Wilson Junior", "name": u"Wilson Junior",
@ -374,6 +374,9 @@ class TestCachedReferenceField(MongoDBTestCase):
self.assertEqual(o.to_mongo()['animal']['tag'], 'heavy') self.assertEqual(o.to_mongo()['animal']['tag'], 'heavy')
self.assertEqual(o.to_mongo()['animal']['owner']['t'], 'u') self.assertEqual(o.to_mongo()['animal']['owner']['t'], 'u')
# Check to_mongo with fields
self.assertNotIn('animal', o.to_mongo(fields=['person']))
# counts # counts
Ocorrence(person="teste 2").save() Ocorrence(person="teste 2").save()
Ocorrence(person="teste 3").save() Ocorrence(person="teste 3").save()

View File

@ -172,6 +172,9 @@ class TestDateTimeField(MongoDBTestCase):
log.time = datetime.datetime.now().isoformat(' ') log.time = datetime.datetime.now().isoformat(' ')
log.validate() log.validate()
log.time = '2019-05-16 21:42:57.897847'
log.validate()
if dateutil: if dateutil:
log.time = datetime.datetime.now().isoformat('T') log.time = datetime.datetime.now().isoformat('T')
log.validate() log.validate()
@ -180,6 +183,12 @@ class TestDateTimeField(MongoDBTestCase):
self.assertRaises(ValidationError, log.validate) self.assertRaises(ValidationError, log.validate)
log.time = 'ABC' log.time = 'ABC'
self.assertRaises(ValidationError, log.validate) self.assertRaises(ValidationError, log.validate)
log.time = '2019-05-16 21:GARBAGE:12'
self.assertRaises(ValidationError, log.validate)
log.time = '2019-05-16 21:42:57.GARBAGE'
self.assertRaises(ValidationError, log.validate)
log.time = '2019-05-16 21:42:57.123.456'
self.assertRaises(ValidationError, log.validate)
class TestDateTimeTzAware(MongoDBTestCase): class TestDateTimeTzAware(MongoDBTestCase):

View File

@ -75,6 +75,16 @@ class TestEmailField(MongoDBTestCase):
user = User(email='me@localhost') user = User(email='me@localhost')
user.validate() user.validate()
def test_email_domain_validation_fails_if_invalid_idn(self):
class User(Document):
email = EmailField()
invalid_idn = '.google.com'
user = User(email='me@%s' % invalid_idn)
with self.assertRaises(ValidationError) as ctx_err:
user.validate()
self.assertIn("domain failed IDN encoding", str(ctx_err.exception))
def test_email_field_ip_domain(self): def test_email_field_ip_domain(self):
class User(Document): class User(Document):
email = EmailField() email = EmailField()

View File

@ -13,6 +13,35 @@ class TestLazyReferenceField(MongoDBTestCase):
# with a document class name. # with a document class name.
self.assertRaises(ValidationError, LazyReferenceField, EmbeddedDocument) self.assertRaises(ValidationError, LazyReferenceField, EmbeddedDocument)
def test___repr__(self):
class Animal(Document):
pass
class Ocurrence(Document):
animal = LazyReferenceField(Animal)
Animal.drop_collection()
Ocurrence.drop_collection()
animal = Animal()
oc = Ocurrence(animal=animal)
self.assertIn('LazyReference', repr(oc.animal))
def test___getattr___unknown_attr_raises_attribute_error(self):
class Animal(Document):
pass
class Ocurrence(Document):
animal = LazyReferenceField(Animal)
Animal.drop_collection()
Ocurrence.drop_collection()
animal = Animal().save()
oc = Ocurrence(animal=animal)
with self.assertRaises(AttributeError):
oc.animal.not_exist
def test_lazy_reference_simple(self): def test_lazy_reference_simple(self):
class Animal(Document): class Animal(Document):
name = StringField() name = StringField()
@ -479,6 +508,23 @@ class TestGenericLazyReferenceField(MongoDBTestCase):
p = Ocurrence.objects.get() p = Ocurrence.objects.get()
self.assertIs(p.animal, None) self.assertIs(p.animal, None)
def test_generic_lazy_reference_accepts_string_instead_of_class(self):
class Animal(Document):
name = StringField()
tag = StringField()
class Ocurrence(Document):
person = StringField()
animal = GenericLazyReferenceField('Animal')
Animal.drop_collection()
Ocurrence.drop_collection()
animal = Animal().save()
Ocurrence(animal=animal).save()
p = Ocurrence.objects.get()
self.assertEqual(p.animal, animal)
def test_generic_lazy_reference_embedded(self): def test_generic_lazy_reference_embedded(self):
class Animal(Document): class Animal(Document):
name = StringField() name = StringField()

View File

@ -39,9 +39,9 @@ class TestLongField(MongoDBTestCase):
doc.value = -1 doc.value = -1
self.assertRaises(ValidationError, doc.validate) self.assertRaises(ValidationError, doc.validate)
doc.age = 120 doc.value = 120
self.assertRaises(ValidationError, doc.validate) self.assertRaises(ValidationError, doc.validate)
doc.age = 'ten' doc.value = 'ten'
self.assertRaises(ValidationError, doc.validate) self.assertRaises(ValidationError, doc.validate)
def test_long_ne_operator(self): def test_long_ne_operator(self):

View File

@ -158,6 +158,11 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(person.name, 'User B') self.assertEqual(person.name, 'User B')
self.assertEqual(person.age, None) self.assertEqual(person.age, None)
def test___getitem___invalid_index(self):
"""Ensure slicing a queryset works as expected."""
with self.assertRaises(TypeError):
self.Person.objects()['a']
def test_slice(self): def test_slice(self):
"""Ensure slicing a queryset works as expected.""" """Ensure slicing a queryset works as expected."""
user_a = self.Person.objects.create(name='User A', age=20) user_a = self.Person.objects.create(name='User A', age=20)
@ -986,6 +991,29 @@ class QuerySetTest(unittest.TestCase):
inserted_comment_id = Comment.objects.insert(comment, load_bulk=False) inserted_comment_id = Comment.objects.insert(comment, load_bulk=False)
self.assertEqual(comment.id, inserted_comment_id) self.assertEqual(comment.id, inserted_comment_id)
def test_bulk_insert_accepts_doc_with_ids(self):
class Comment(Document):
id = IntField(primary_key=True)
Comment.drop_collection()
com1 = Comment(id=0)
com2 = Comment(id=1)
Comment.objects.insert([com1, com2])
def test_insert_raise_if_duplicate_in_constraint(self):
class Comment(Document):
id = IntField(primary_key=True)
Comment.drop_collection()
com1 = Comment(id=0)
Comment.objects.insert(com1)
with self.assertRaises(NotUniqueError):
Comment.objects.insert(com1)
def test_get_changed_fields_query_count(self): def test_get_changed_fields_query_count(self):
"""Make sure we don't perform unnecessary db operations when """Make sure we don't perform unnecessary db operations when
none of document's fields were updated. none of document's fields were updated.
@ -3604,6 +3632,11 @@ class QuerySetTest(unittest.TestCase):
opts = {"deleted": False} opts = {"deleted": False}
return qryset(**opts) return qryset(**opts)
@queryset_manager
def objects_1_arg(qryset):
opts = {"deleted": False}
return qryset(**opts)
@queryset_manager @queryset_manager
def music_posts(doc_cls, queryset, deleted=False): def music_posts(doc_cls, queryset, deleted=False):
return queryset(tags='music', return queryset(tags='music',
@ -3618,6 +3651,8 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual([p.id for p in BlogPost.objects()], self.assertEqual([p.id for p in BlogPost.objects()],
[post1.id, post2.id, post3.id]) [post1.id, post2.id, post3.id])
self.assertEqual([p.id for p in BlogPost.objects_1_arg()],
[post1.id, post2.id, post3.id])
self.assertEqual([p.id for p in BlogPost.music_posts()], self.assertEqual([p.id for p in BlogPost.music_posts()],
[post1.id, post2.id]) [post1.id, post2.id])
@ -5002,6 +5037,38 @@ class QuerySetTest(unittest.TestCase):
people.count() people.count()
self.assertEqual(q, 3) self.assertEqual(q, 3)
def test_no_cached_queryset__repr__(self):
class Person(Document):
name = StringField()
Person.drop_collection()
qs = Person.objects.no_cache()
self.assertEqual(repr(qs), '[]')
def test_no_cached_on_a_cached_queryset_raise_error(self):
class Person(Document):
name = StringField()
Person.drop_collection()
Person(name='a').save()
qs = Person.objects()
_ = list(qs)
with self.assertRaises(OperationError) as ctx_err:
qs.no_cache()
self.assertEqual("QuerySet already cached", str(ctx_err.exception))
def test_no_cached_queryset_no_cache_back_to_cache(self):
class Person(Document):
name = StringField()
Person.drop_collection()
qs = Person.objects()
self.assertIsInstance(qs, QuerySet)
qs = qs.no_cache()
self.assertIsInstance(qs, QuerySetNoCache)
qs = qs.cache()
self.assertIsInstance(qs, QuerySet)
def test_cache_not_cloned(self): def test_cache_not_cloned(self):
class User(Document): class User(Document):

View File

@ -71,6 +71,14 @@ class TransformTest(unittest.TestCase):
update = transform.update(BlogPost, push_all__tags=['mongo', 'db']) update = transform.update(BlogPost, push_all__tags=['mongo', 'db'])
self.assertEqual(update, {'$push': {'tags': {'$each': ['mongo', 'db']}}}) self.assertEqual(update, {'$push': {'tags': {'$each': ['mongo', 'db']}}})
def test_transform_update_no_operator_default_to_set(self):
"""Ensure the differences in behvaior between 'push' and 'push_all'"""
class BlogPost(Document):
tags = ListField(StringField())
update = transform.update(BlogPost, tags=['mongo', 'db'])
self.assertEqual(update, {'$set': {'tags': ['mongo', 'db']}})
def test_query_field_name(self): def test_query_field_name(self):
"""Ensure that the correct field name is used when querying. """Ensure that the correct field name is used when querying.
""" """

15
tests/test_common.py Normal file
View File

@ -0,0 +1,15 @@
import unittest
from mongoengine.common import _import_class
from mongoengine import Document
class TestCommon(unittest.TestCase):
def test__import_class(self):
doc_cls = _import_class("Document")
self.assertIs(doc_cls, Document)
def test__import_class_raise_if_not_known(self):
with self.assertRaises(ValueError):
_import_class("UnknownClass")

View File

@ -270,6 +270,14 @@ class ContextManagersTest(unittest.TestCase):
counter += 1 counter += 1
self.assertEqual(q, counter) self.assertEqual(q, counter)
self.assertEqual(int(q), counter) # test __int__
self.assertEqual(repr(q), str(int(q))) # test __repr__
self.assertGreater(q, -1) # test __gt__
self.assertGreaterEqual(q, int(q)) # test __gte__
self.assertNotEqual(q, -1)
self.assertLess(q, 1000)
self.assertLessEqual(q, int(q))
def test_query_counter_counts_getmore_queries(self): def test_query_counter_counts_getmore_queries(self):
connect('mongoenginetest') connect('mongoenginetest')
db = get_db() db = get_db()

View File

@ -1,4 +1,5 @@
import unittest import unittest
from six import iterkeys
from mongoengine import Document from mongoengine import Document
from mongoengine.base.datastructures import StrictDict, BaseList, BaseDict from mongoengine.base.datastructures import StrictDict, BaseList, BaseDict
@ -368,6 +369,20 @@ class TestStrictDict(unittest.TestCase):
d = self.dtype(a=1, b=1, c=1) d = self.dtype(a=1, b=1, c=1)
self.assertEqual((d.a, d.b, d.c), (1, 1, 1)) self.assertEqual((d.a, d.b, d.c), (1, 1, 1))
def test_iterkeys(self):
d = self.dtype(a=1)
self.assertEqual(list(iterkeys(d)), ['a'])
def test_len(self):
d = self.dtype(a=1)
self.assertEqual(len(d), 1)
def test_pop(self):
d = self.dtype(a=1)
self.assertIn('a', d)
d.pop('a')
self.assertNotIn('a', d)
def test_repr(self): def test_repr(self):
d = self.dtype(a=1, b=2, c=3) d = self.dtype(a=1, b=2, c=3)
self.assertEqual(repr(d), '{"a": 1, "b": 2, "c": 3}') self.assertEqual(repr(d), '{"a": 1, "b": 2, "c": 3}')