clean up document instance tests

This commit is contained in:
Stefan Wojcik 2017-03-02 00:25:56 -05:00
parent 741643af5f
commit 30e8b8186f

View File

@ -28,8 +28,6 @@ TEST_IMAGE_PATH = os.path.join(os.path.dirname(__file__),
__all__ = ("InstanceTest",) __all__ = ("InstanceTest",)
class InstanceTest(unittest.TestCase): class InstanceTest(unittest.TestCase):
def setUp(self): def setUp(self):
@ -72,8 +70,7 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(field._instance, instance) self.assertEqual(field._instance, instance)
def test_capped_collection(self): def test_capped_collection(self):
"""Ensure that capped collections work properly. """Ensure that capped collections work properly."""
"""
class Log(Document): class Log(Document):
date = DateTimeField(default=datetime.now) date = DateTimeField(default=datetime.now)
meta = { meta = {
@ -181,8 +178,7 @@ class InstanceTest(unittest.TestCase):
self.assertEqual('<Article: привет мир>', repr(doc)) self.assertEqual('<Article: привет мир>', repr(doc))
def test_repr_none(self): def test_repr_none(self):
"""Ensure None values handled correctly """Ensure None values are handled correctly."""
"""
class Article(Document): class Article(Document):
title = StringField() title = StringField()
@ -190,25 +186,23 @@ class InstanceTest(unittest.TestCase):
return None return None
doc = Article(title=u'привет мир') doc = Article(title=u'привет мир')
self.assertEqual('<Article: None>', repr(doc)) self.assertEqual('<Article: None>', repr(doc))
def test_queryset_resurrects_dropped_collection(self): def test_queryset_resurrects_dropped_collection(self):
self.Person.drop_collection() self.Person.drop_collection()
self.assertEqual([], list(self.Person.objects())) self.assertEqual([], list(self.Person.objects()))
# Ensure works correctly with inhertited classes
class Actor(self.Person): class Actor(self.Person):
pass pass
# Ensure works correctly with inhertited classes
Actor.objects() Actor.objects()
self.Person.drop_collection() self.Person.drop_collection()
self.assertEqual([], list(Actor.objects())) self.assertEqual([], list(Actor.objects()))
def test_polymorphic_references(self): def test_polymorphic_references(self):
"""Ensure that the correct subclasses are returned from a query when """Ensure that the correct subclasses are returned from a query
using references / generic references when using references / generic references
""" """
class Animal(Document): class Animal(Document):
meta = {'allow_inheritance': True} meta = {'allow_inheritance': True}
@ -258,9 +252,6 @@ class InstanceTest(unittest.TestCase):
classes = [a.__class__ for a in Zoo.objects.first().animals] classes = [a.__class__ for a in Zoo.objects.first().animals]
self.assertEqual(classes, [Animal, Fish, Mammal, Dog, Human]) self.assertEqual(classes, [Animal, Fish, Mammal, Dog, Human])
Zoo.drop_collection()
Animal.drop_collection()
def test_reference_inheritance(self): def test_reference_inheritance(self):
class Stats(Document): class Stats(Document):
created = DateTimeField(default=datetime.now) created = DateTimeField(default=datetime.now)
@ -287,8 +278,7 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(list_stats, CompareStats.objects.first().stats) self.assertEqual(list_stats, CompareStats.objects.first().stats)
def test_db_field_load(self): def test_db_field_load(self):
"""Ensure we load data correctly """Ensure we load data correctly from the right db field."""
"""
class Person(Document): class Person(Document):
name = StringField(required=True) name = StringField(required=True)
_rank = StringField(required=False, db_field="rank") _rank = StringField(required=False, db_field="rank")
@ -307,8 +297,7 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(Person.objects.get(name="Fred").rank, "Private") self.assertEqual(Person.objects.get(name="Fred").rank, "Private")
def test_db_embedded_doc_field_load(self): def test_db_embedded_doc_field_load(self):
"""Ensure we load embedded document data correctly """Ensure we load embedded document data correctly."""
"""
class Rank(EmbeddedDocument): class Rank(EmbeddedDocument):
title = StringField(required=True) title = StringField(required=True)
@ -333,8 +322,7 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(Person.objects.get(name="Fred").rank, "Private") self.assertEqual(Person.objects.get(name="Fred").rank, "Private")
def test_custom_id_field(self): def test_custom_id_field(self):
"""Ensure that documents may be created with custom primary keys. """Ensure that documents may be created with custom primary keys."""
"""
class User(Document): class User(Document):
username = StringField(primary_key=True) username = StringField(primary_key=True)
name = StringField() name = StringField()
@ -382,10 +370,7 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(user_son['_id'], 'mongo') self.assertEqual(user_son['_id'], 'mongo')
self.assertTrue('username' not in user_son['_id']) self.assertTrue('username' not in user_son['_id'])
User.drop_collection()
def test_document_not_registered(self): def test_document_not_registered(self):
class Place(Document): class Place(Document):
name = StringField() name = StringField()
@ -407,7 +392,6 @@ class InstanceTest(unittest.TestCase):
list(Place.objects.all()) list(Place.objects.all())
def test_document_registry_regressions(self): def test_document_registry_regressions(self):
class Location(Document): class Location(Document):
name = StringField() name = StringField()
meta = {'allow_inheritance': True} meta = {'allow_inheritance': True}
@ -421,18 +405,16 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(Area, get_document("Location.Area")) self.assertEqual(Area, get_document("Location.Area"))
def test_creation(self): def test_creation(self):
"""Ensure that document may be created using keyword arguments. """Ensure that document may be created using keyword arguments."""
"""
person = self.Person(name="Test User", age=30) person = self.Person(name="Test User", age=30)
self.assertEqual(person.name, "Test User") self.assertEqual(person.name, "Test User")
self.assertEqual(person.age, 30) self.assertEqual(person.age, 30)
def test_to_dbref(self): def test_to_dbref(self):
"""Ensure that you can get a dbref of a document""" """Ensure that you can get a dbref of a document."""
person = self.Person(name="Test User", age=30) person = self.Person(name="Test User", age=30)
self.assertRaises(OperationError, person.to_dbref) self.assertRaises(OperationError, person.to_dbref)
person.save() person.save()
person.to_dbref() person.to_dbref()
def test_save_abstract_document(self): def test_save_abstract_document(self):
@ -445,8 +427,7 @@ class InstanceTest(unittest.TestCase):
Doc(name='aaa').save() Doc(name='aaa').save()
def test_reload(self): def test_reload(self):
"""Ensure that attributes may be reloaded. """Ensure that attributes may be reloaded."""
"""
person = self.Person(name="Test User", age=20) person = self.Person(name="Test User", age=20)
person.save() person.save()
@ -479,7 +460,6 @@ class InstanceTest(unittest.TestCase):
doc = Animal(superphylum='Deuterostomia') doc = Animal(superphylum='Deuterostomia')
doc.save() doc.save()
doc.reload() doc.reload()
Animal.drop_collection()
def test_reload_sharded_nested(self): def test_reload_sharded_nested(self):
class SuperPhylum(EmbeddedDocument): class SuperPhylum(EmbeddedDocument):
@ -493,11 +473,9 @@ class InstanceTest(unittest.TestCase):
doc = Animal(superphylum=SuperPhylum(name='Deuterostomia')) doc = Animal(superphylum=SuperPhylum(name='Deuterostomia'))
doc.save() doc.save()
doc.reload() doc.reload()
Animal.drop_collection()
def test_reload_referencing(self): def test_reload_referencing(self):
"""Ensures reloading updates weakrefs correctly """Ensures reloading updates weakrefs correctly."""
"""
class Embedded(EmbeddedDocument): class Embedded(EmbeddedDocument):
dict_field = DictField() dict_field = DictField()
list_field = ListField() list_field = ListField()
@ -569,8 +547,7 @@ class InstanceTest(unittest.TestCase):
self.assertFalse("Threw wrong exception") self.assertFalse("Threw wrong exception")
def test_reload_of_non_strict_with_special_field_name(self): def test_reload_of_non_strict_with_special_field_name(self):
"""Ensures reloading works for documents with meta strict == False """Ensures reloading works for documents with meta strict == False."""
"""
class Post(Document): class Post(Document):
meta = { meta = {
'strict': False 'strict': False
@ -591,8 +568,7 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(post.items, ["more lorem", "even more ipsum"]) self.assertEqual(post.items, ["more lorem", "even more ipsum"])
def test_dictionary_access(self): def test_dictionary_access(self):
"""Ensure that dictionary-style field access works properly. """Ensure that dictionary-style field access works properly."""
"""
person = self.Person(name='Test User', age=30, job=self.Job()) person = self.Person(name='Test User', age=30, job=self.Job())
self.assertEqual(person['name'], 'Test User') self.assertEqual(person['name'], 'Test User')
@ -634,8 +610,7 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(sub_doc.to_mongo().keys(), ['id']) self.assertEqual(sub_doc.to_mongo().keys(), ['id'])
def test_embedded_document(self): def test_embedded_document(self):
"""Ensure that embedded documents are set up correctly. """Ensure that embedded documents are set up correctly."""
"""
class Comment(EmbeddedDocument): class Comment(EmbeddedDocument):
content = StringField() content = StringField()
@ -643,8 +618,7 @@ class InstanceTest(unittest.TestCase):
self.assertFalse('id' in Comment._fields) self.assertFalse('id' in Comment._fields)
def test_embedded_document_instance(self): def test_embedded_document_instance(self):
"""Ensure that embedded documents can reference parent instance """Ensure that embedded documents can reference parent instance."""
"""
class Embedded(EmbeddedDocument): class Embedded(EmbeddedDocument):
string = StringField() string = StringField()
@ -652,6 +626,7 @@ class InstanceTest(unittest.TestCase):
embedded_field = EmbeddedDocumentField(Embedded) embedded_field = EmbeddedDocumentField(Embedded)
Doc.drop_collection() Doc.drop_collection()
doc = Doc(embedded_field=Embedded(string="Hi")) doc = Doc(embedded_field=Embedded(string="Hi"))
self.assertHasInstance(doc.embedded_field, doc) self.assertHasInstance(doc.embedded_field, doc)
@ -661,7 +636,8 @@ class InstanceTest(unittest.TestCase):
def test_embedded_document_complex_instance(self): def test_embedded_document_complex_instance(self):
"""Ensure that embedded documents in complex fields can reference """Ensure that embedded documents in complex fields can reference
parent instance""" parent instance.
"""
class Embedded(EmbeddedDocument): class Embedded(EmbeddedDocument):
string = StringField() string = StringField()
@ -677,8 +653,7 @@ class InstanceTest(unittest.TestCase):
self.assertHasInstance(doc.embedded_field[0], doc) self.assertHasInstance(doc.embedded_field[0], doc)
def test_embedded_document_complex_instance_no_use_db_field(self): def test_embedded_document_complex_instance_no_use_db_field(self):
"""Ensure that use_db_field is propagated to list of Emb Docs """Ensure that use_db_field is propagated to list of Emb Docs."""
"""
class Embedded(EmbeddedDocument): class Embedded(EmbeddedDocument):
string = StringField(db_field='s') string = StringField(db_field='s')
@ -690,7 +665,6 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(d['embedded_field'], [{'string': 'Hi'}]) self.assertEqual(d['embedded_field'], [{'string': 'Hi'}])
def test_instance_is_set_on_setattr(self): def test_instance_is_set_on_setattr(self):
class Email(EmbeddedDocument): class Email(EmbeddedDocument):
email = EmailField() email = EmailField()
@ -698,6 +672,7 @@ class InstanceTest(unittest.TestCase):
email = EmbeddedDocumentField(Email) email = EmbeddedDocumentField(Email)
Account.drop_collection() Account.drop_collection()
acc = Account() acc = Account()
acc.email = Email(email='test@example.com') acc.email = Email(email='test@example.com')
self.assertHasInstance(acc._data["email"], acc) self.assertHasInstance(acc._data["email"], acc)
@ -707,7 +682,6 @@ class InstanceTest(unittest.TestCase):
self.assertHasInstance(acc1._data["email"], acc1) self.assertHasInstance(acc1._data["email"], acc1)
def test_instance_is_set_on_setattr_on_embedded_document_list(self): def test_instance_is_set_on_setattr_on_embedded_document_list(self):
class Email(EmbeddedDocument): class Email(EmbeddedDocument):
email = EmailField() email = EmailField()
@ -853,32 +827,28 @@ class InstanceTest(unittest.TestCase):
self.assertDbEqual([dict(other_doc.to_mongo()), dict(doc.to_mongo())]) self.assertDbEqual([dict(other_doc.to_mongo()), dict(doc.to_mongo())])
def test_save(self): def test_save(self):
"""Ensure that a document may be saved in the database. """Ensure that a document may be saved in the database."""
"""
# Create person object and save it to the database # Create person object and save it to the database
person = self.Person(name='Test User', age=30) person = self.Person(name='Test User', age=30)
person.save() person.save()
# Ensure that the object is in the database # Ensure that the object is in the database
collection = self.db[self.Person._get_collection_name()] collection = self.db[self.Person._get_collection_name()]
person_obj = collection.find_one({'name': 'Test User'}) person_obj = collection.find_one({'name': 'Test User'})
self.assertEqual(person_obj['name'], 'Test User') self.assertEqual(person_obj['name'], 'Test User')
self.assertEqual(person_obj['age'], 30) self.assertEqual(person_obj['age'], 30)
self.assertEqual(person_obj['_id'], person.id) self.assertEqual(person_obj['_id'], person.id)
# Test skipping validation on save
# Test skipping validation on save
class Recipient(Document): class Recipient(Document):
email = EmailField(required=True) email = EmailField(required=True)
recipient = Recipient(email='root@localhost') recipient = Recipient(email='root@localhost')
self.assertRaises(ValidationError, recipient.save) self.assertRaises(ValidationError, recipient.save)
try:
recipient.save(validate=False) recipient.save(validate=False)
except ValidationError:
self.fail()
def test_save_to_a_value_that_equates_to_false(self): def test_save_to_a_value_that_equates_to_false(self):
class Thing(EmbeddedDocument): class Thing(EmbeddedDocument):
count = IntField() count = IntField()
@ -898,7 +868,6 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(user.thing.count, 0) self.assertEqual(user.thing.count, 0)
def test_save_max_recursion_not_hit(self): def test_save_max_recursion_not_hit(self):
class Person(Document): class Person(Document):
name = StringField() name = StringField()
parent = ReferenceField('self') parent = ReferenceField('self')
@ -924,7 +893,6 @@ class InstanceTest(unittest.TestCase):
p0.save() p0.save()
def test_save_max_recursion_not_hit_with_file_field(self): def test_save_max_recursion_not_hit_with_file_field(self):
class Foo(Document): class Foo(Document):
name = StringField() name = StringField()
picture = FileField() picture = FileField()
@ -948,7 +916,6 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(b.picture, b.bar.picture, b.bar.bar.picture) self.assertEqual(b.picture, b.bar.picture, b.bar.bar.picture)
def test_save_cascades(self): def test_save_cascades(self):
class Person(Document): class Person(Document):
name = StringField() name = StringField()
parent = ReferenceField('self') parent = ReferenceField('self')
@ -971,7 +938,6 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(p1.name, p.parent.name) self.assertEqual(p1.name, p.parent.name)
def test_save_cascade_kwargs(self): def test_save_cascade_kwargs(self):
class Person(Document): class Person(Document):
name = StringField() name = StringField()
parent = ReferenceField('self') parent = ReferenceField('self')
@ -992,7 +958,6 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(p1.name, p2.parent.name) self.assertEqual(p1.name, p2.parent.name)
def test_save_cascade_meta_false(self): def test_save_cascade_meta_false(self):
class Person(Document): class Person(Document):
name = StringField() name = StringField()
parent = ReferenceField('self') parent = ReferenceField('self')
@ -1021,7 +986,6 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(p1.name, p.parent.name) self.assertEqual(p1.name, p.parent.name)
def test_save_cascade_meta_true(self): def test_save_cascade_meta_true(self):
class Person(Document): class Person(Document):
name = StringField() name = StringField()
parent = ReferenceField('self') parent = ReferenceField('self')
@ -1046,7 +1010,6 @@ class InstanceTest(unittest.TestCase):
self.assertNotEqual(p1.name, p.parent.name) self.assertNotEqual(p1.name, p.parent.name)
def test_save_cascades_generically(self): def test_save_cascades_generically(self):
class Person(Document): class Person(Document):
name = StringField() name = StringField()
parent = GenericReferenceField() parent = GenericReferenceField()
@ -1072,7 +1035,6 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(p1.name, p.parent.name) self.assertEqual(p1.name, p.parent.name)
def test_save_atomicity_condition(self): def test_save_atomicity_condition(self):
class Widget(Document): class Widget(Document):
toggle = BooleanField(default=False) toggle = BooleanField(default=False)
count = IntField(default=0) count = IntField(default=0)
@ -1150,7 +1112,8 @@ class InstanceTest(unittest.TestCase):
def test_update(self): def test_update(self):
"""Ensure that an existing document is updated instead of be """Ensure that an existing document is updated instead of be
overwritten.""" overwritten.
"""
# Create person object and save it to the database # Create person object and save it to the database
person = self.Person(name='Test User', age=30) person = self.Person(name='Test User', age=30)
person.save() person.save()
@ -1254,7 +1217,6 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(2, self.Person.objects.count()) self.assertEqual(2, self.Person.objects.count())
def test_can_save_if_not_included(self): def test_can_save_if_not_included(self):
class EmbeddedDoc(EmbeddedDocument): class EmbeddedDoc(EmbeddedDocument):
pass pass
@ -1341,10 +1303,7 @@ class InstanceTest(unittest.TestCase):
doc2.update(set__name=doc1.name) doc2.update(set__name=doc1.name)
def test_embedded_update(self): def test_embedded_update(self):
""" """Test update on `EmbeddedDocumentField` fields."""
Test update on `EmbeddedDocumentField` fields
"""
class Page(EmbeddedDocument): class Page(EmbeddedDocument):
log_message = StringField(verbose_name="Log message", log_message = StringField(verbose_name="Log message",
required=True) required=True)
@ -1365,11 +1324,9 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(site.page.log_message, "Error: Dummy message") self.assertEqual(site.page.log_message, "Error: Dummy message")
def test_embedded_update_db_field(self): def test_embedded_update_db_field(self):
"""Test update on `EmbeddedDocumentField` fields when db_field
is other than default.
""" """
Test update on `EmbeddedDocumentField` fields when db_field is other
than default.
"""
class Page(EmbeddedDocument): class Page(EmbeddedDocument):
log_message = StringField(verbose_name="Log message", log_message = StringField(verbose_name="Log message",
db_field="page_log_message", db_field="page_log_message",
@ -1392,9 +1349,7 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(site.page.log_message, "Error: Dummy message") self.assertEqual(site.page.log_message, "Error: Dummy message")
def test_save_only_changed_fields(self): def test_save_only_changed_fields(self):
"""Ensure save only sets / unsets changed fields """Ensure save only sets / unsets changed fields."""
"""
class User(self.Person): class User(self.Person):
active = BooleanField(default=True) active = BooleanField(default=True)
@ -1514,8 +1469,8 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(q, 3) self.assertEqual(q, 3)
def test_set_unset_one_operation(self): def test_set_unset_one_operation(self):
"""Ensure that $set and $unset actions are performed in the same """Ensure that $set and $unset actions are performed in the
operation. same operation.
""" """
class FooBar(Document): class FooBar(Document):
foo = StringField(default=None) foo = StringField(default=None)
@ -1536,9 +1491,7 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(1, q) self.assertEqual(1, q)
def test_save_only_changed_fields_recursive(self): def test_save_only_changed_fields_recursive(self):
"""Ensure save only sets / unsets changed fields """Ensure save only sets / unsets changed fields."""
"""
class Comment(EmbeddedDocument): class Comment(EmbeddedDocument):
published = BooleanField(default=True) published = BooleanField(default=True)
@ -1578,8 +1531,7 @@ class InstanceTest(unittest.TestCase):
self.assertFalse(person.comments_dict['first_post'].published) self.assertFalse(person.comments_dict['first_post'].published)
def test_delete(self): def test_delete(self):
"""Ensure that document may be deleted using the delete method. """Ensure that document may be deleted using the delete method."""
"""
person = self.Person(name="Test User", age=30) person = self.Person(name="Test User", age=30)
person.save() person.save()
self.assertEqual(self.Person.objects.count(), 1) self.assertEqual(self.Person.objects.count(), 1)
@ -1587,33 +1539,34 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(self.Person.objects.count(), 0) self.assertEqual(self.Person.objects.count(), 0)
def test_save_custom_id(self): def test_save_custom_id(self):
"""Ensure that a document may be saved with a custom _id. """Ensure that a document may be saved with a custom _id."""
"""
# Create person object and save it to the database # Create person object and save it to the database
person = self.Person(name='Test User', age=30, person = self.Person(name='Test User', age=30,
id='497ce96f395f2f052a494fd4') id='497ce96f395f2f052a494fd4')
person.save() person.save()
# Ensure that the object is in the database with the correct _id # Ensure that the object is in the database with the correct _id
collection = self.db[self.Person._get_collection_name()] collection = self.db[self.Person._get_collection_name()]
person_obj = collection.find_one({'name': 'Test User'}) person_obj = collection.find_one({'name': 'Test User'})
self.assertEqual(str(person_obj['_id']), '497ce96f395f2f052a494fd4') self.assertEqual(str(person_obj['_id']), '497ce96f395f2f052a494fd4')
def test_save_custom_pk(self): def test_save_custom_pk(self):
""" """Ensure that a document may be saved with a custom _id using
Ensure that a document may be saved with a custom _id using pk alias. pk alias.
""" """
# Create person object and save it to the database # Create person object and save it to the database
person = self.Person(name='Test User', age=30, person = self.Person(name='Test User', age=30,
pk='497ce96f395f2f052a494fd4') pk='497ce96f395f2f052a494fd4')
person.save() person.save()
# Ensure that the object is in the database with the correct _id # Ensure that the object is in the database with the correct _id
collection = self.db[self.Person._get_collection_name()] collection = self.db[self.Person._get_collection_name()]
person_obj = collection.find_one({'name': 'Test User'}) person_obj = collection.find_one({'name': 'Test User'})
self.assertEqual(str(person_obj['_id']), '497ce96f395f2f052a494fd4') self.assertEqual(str(person_obj['_id']), '497ce96f395f2f052a494fd4')
def test_save_list(self): def test_save_list(self):
"""Ensure that a list field may be properly saved. """Ensure that a list field may be properly saved."""
"""
class Comment(EmbeddedDocument): class Comment(EmbeddedDocument):
content = StringField() content = StringField()
@ -1636,8 +1589,6 @@ class InstanceTest(unittest.TestCase):
for comment_obj, comment in zip(post_obj['comments'], comments): for comment_obj, comment in zip(post_obj['comments'], comments):
self.assertEqual(comment_obj['content'], comment['content']) self.assertEqual(comment_obj['content'], comment['content'])
BlogPost.drop_collection()
def test_list_search_by_embedded(self): def test_list_search_by_embedded(self):
class User(Document): class User(Document):
username = StringField(required=True) username = StringField(required=True)
@ -1697,8 +1648,8 @@ class InstanceTest(unittest.TestCase):
list(Page.objects.filter(comments__user=u3))) list(Page.objects.filter(comments__user=u3)))
def test_save_embedded_document(self): def test_save_embedded_document(self):
"""Ensure that a document with an embedded document field may be """Ensure that a document with an embedded document field may
saved in the database. be saved in the database.
""" """
class EmployeeDetails(EmbeddedDocument): class EmployeeDetails(EmbeddedDocument):
position = StringField() position = StringField()
@ -1717,13 +1668,13 @@ class InstanceTest(unittest.TestCase):
employee_obj = collection.find_one({'name': 'Test Employee'}) employee_obj = collection.find_one({'name': 'Test Employee'})
self.assertEqual(employee_obj['name'], 'Test Employee') self.assertEqual(employee_obj['name'], 'Test Employee')
self.assertEqual(employee_obj['age'], 50) self.assertEqual(employee_obj['age'], 50)
# Ensure that the 'details' embedded object saved correctly # Ensure that the 'details' embedded object saved correctly
self.assertEqual(employee_obj['details']['position'], 'Developer') self.assertEqual(employee_obj['details']['position'], 'Developer')
def test_embedded_update_after_save(self): def test_embedded_update_after_save(self):
""" """Test update of `EmbeddedDocumentField` attached to a newly
Test update of `EmbeddedDocumentField` attached to a newly saved saved document.
document.
""" """
class Page(EmbeddedDocument): class Page(EmbeddedDocument):
log_message = StringField(verbose_name="Log message", log_message = StringField(verbose_name="Log message",
@ -1744,8 +1695,8 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(site.page.log_message, "Error: Dummy message") self.assertEqual(site.page.log_message, "Error: Dummy message")
def test_updating_an_embedded_document(self): def test_updating_an_embedded_document(self):
"""Ensure that a document with an embedded document field may be """Ensure that a document with an embedded document field may
saved in the database. be saved in the database.
""" """
class EmployeeDetails(EmbeddedDocument): class EmployeeDetails(EmbeddedDocument):
position = StringField() position = StringField()
@ -1780,7 +1731,6 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(promoted_employee.details, None) self.assertEqual(promoted_employee.details, None)
def test_object_mixins(self): def test_object_mixins(self):
class NameMixin(object): class NameMixin(object):
name = StringField() name = StringField()
@ -1819,9 +1769,9 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(t.count, 12) self.assertEqual(t.count, 12)
def test_save_reference(self): def test_save_reference(self):
"""Ensure that a document reference field may be saved in the database. """Ensure that a document reference field may be saved in the
database.
""" """
class BlogPost(Document): class BlogPost(Document):
meta = {'collection': 'blogpost_1'} meta = {'collection': 'blogpost_1'}
content = StringField() content = StringField()
@ -1852,8 +1802,6 @@ class InstanceTest(unittest.TestCase):
author = list(self.Person.objects(name='Test User'))[-1] author = list(self.Person.objects(name='Test User'))[-1]
self.assertEqual(author.age, 25) self.assertEqual(author.age, 25)
BlogPost.drop_collection()
def test_duplicate_db_fields_raise_invalid_document_error(self): def test_duplicate_db_fields_raise_invalid_document_error(self):
"""Ensure a InvalidDocumentError is thrown if duplicate fields """Ensure a InvalidDocumentError is thrown if duplicate fields
declare the same db_field. declare the same db_field.
@ -1864,7 +1812,7 @@ class InstanceTest(unittest.TestCase):
name2 = StringField(db_field='name') name2 = StringField(db_field='name')
def test_invalid_son(self): def test_invalid_son(self):
"""Raise an error if loading invalid data""" """Raise an error if loading invalid data."""
class Occurrence(EmbeddedDocument): class Occurrence(EmbeddedDocument):
number = IntField() number = IntField()
@ -1887,9 +1835,9 @@ class InstanceTest(unittest.TestCase):
Word._from_son('this is not a valid SON dict') Word._from_son('this is not a valid SON dict')
def test_reverse_delete_rule_cascade_and_nullify(self): def test_reverse_delete_rule_cascade_and_nullify(self):
"""Ensure that a referenced document is also deleted upon deletion. """Ensure that a referenced document is also deleted upon
deletion.
""" """
class BlogPost(Document): class BlogPost(Document):
content = StringField() content = StringField()
author = ReferenceField(self.Person, reverse_delete_rule=CASCADE) author = ReferenceField(self.Person, reverse_delete_rule=CASCADE)
@ -1944,7 +1892,8 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(Book.objects.count(), 0) self.assertEqual(Book.objects.count(), 0)
def test_reverse_delete_rule_with_shared_id_among_collections(self): def test_reverse_delete_rule_with_shared_id_among_collections(self):
"""Ensure that cascade delete rule doesn't mix id among collections. """Ensure that cascade delete rule doesn't mix id among
collections.
""" """
class User(Document): class User(Document):
id = IntField(primary_key=True) id = IntField(primary_key=True)
@ -1975,10 +1924,9 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(Book.objects.get(), book_2) self.assertEqual(Book.objects.get(), book_2)
def test_reverse_delete_rule_with_document_inheritance(self): def test_reverse_delete_rule_with_document_inheritance(self):
"""Ensure that a referenced document is also deleted upon deletion """Ensure that a referenced document is also deleted upon
of a child document. deletion of a child document.
""" """
class Writer(self.Person): class Writer(self.Person):
pass pass
@ -2010,10 +1958,9 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(BlogPost.objects.count(), 0) self.assertEqual(BlogPost.objects.count(), 0)
def test_reverse_delete_rule_cascade_and_nullify_complex_field(self): def test_reverse_delete_rule_cascade_and_nullify_complex_field(self):
"""Ensure that a referenced document is also deleted upon deletion for """Ensure that a referenced document is also deleted upon
complex fields. deletion for complex fields.
""" """
class BlogPost(Document): class BlogPost(Document):
content = StringField() content = StringField()
authors = ListField(ReferenceField( authors = ListField(ReferenceField(
@ -2022,7 +1969,6 @@ class InstanceTest(unittest.TestCase):
self.Person, reverse_delete_rule=NULLIFY)) self.Person, reverse_delete_rule=NULLIFY))
self.Person.drop_collection() self.Person.drop_collection()
BlogPost.drop_collection() BlogPost.drop_collection()
author = self.Person(name='Test User') author = self.Person(name='Test User')
@ -2046,10 +1992,10 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(BlogPost.objects.count(), 0) self.assertEqual(BlogPost.objects.count(), 0)
def test_reverse_delete_rule_cascade_triggers_pre_delete_signal(self): def test_reverse_delete_rule_cascade_triggers_pre_delete_signal(self):
""" ensure the pre_delete signal is triggered upon a cascading deletion """Ensure the pre_delete signal is triggered upon a cascading
setup a blog post with content, an author and editor deletion setup a blog post with content, an author and editor
delete the author which triggers deletion of blogpost via cascade delete the author which triggers deletion of blogpost via
blog post's pre_delete signal alters an editor attribute cascade blog post's pre_delete signal alters an editor attribute.
""" """
class Editor(self.Person): class Editor(self.Person):
review_queue = IntField(default=0) review_queue = IntField(default=0)
@ -2077,6 +2023,7 @@ class InstanceTest(unittest.TestCase):
# delete the author, the post is also deleted due to the CASCADE rule # delete the author, the post is also deleted due to the CASCADE rule
author.delete() author.delete()
# the pre-delete signal should have decremented the editor's queue # the pre-delete signal should have decremented the editor's queue
editor = Editor.objects(name='Max P.').get() editor = Editor.objects(name='Max P.').get()
self.assertEqual(editor.review_queue, 0) self.assertEqual(editor.review_queue, 0)
@ -2085,7 +2032,6 @@ class InstanceTest(unittest.TestCase):
"""Ensure that Bi-Directional relationships work with """Ensure that Bi-Directional relationships work with
reverse_delete_rule reverse_delete_rule
""" """
class Bar(Document): class Bar(Document):
content = StringField() content = StringField()
foo = ReferenceField('Foo') foo = ReferenceField('Foo')
@ -2131,8 +2077,8 @@ class InstanceTest(unittest.TestCase):
mother = ReferenceField('Person', reverse_delete_rule=DENY) mother = ReferenceField('Person', reverse_delete_rule=DENY)
def test_reverse_delete_rule_cascade_recurs(self): def test_reverse_delete_rule_cascade_recurs(self):
"""Ensure that a chain of documents is also deleted upon cascaded """Ensure that a chain of documents is also deleted upon
deletion. cascaded deletion.
""" """
class BlogPost(Document): class BlogPost(Document):
content = StringField() content = StringField()
@ -2162,15 +2108,10 @@ class InstanceTest(unittest.TestCase):
author.delete() author.delete()
self.assertEqual(Comment.objects.count(), 0) self.assertEqual(Comment.objects.count(), 0)
self.Person.drop_collection()
BlogPost.drop_collection()
Comment.drop_collection()
def test_reverse_delete_rule_deny(self): def test_reverse_delete_rule_deny(self):
"""Ensure that a document cannot be referenced if there are still """Ensure that a document cannot be referenced if there are
documents referring to it. still documents referring to it.
""" """
class BlogPost(Document): class BlogPost(Document):
content = StringField() content = StringField()
author = ReferenceField(self.Person, reverse_delete_rule=DENY) author = ReferenceField(self.Person, reverse_delete_rule=DENY)
@ -2198,11 +2139,7 @@ class InstanceTest(unittest.TestCase):
author.delete() author.delete()
self.assertEqual(self.Person.objects.count(), 1) self.assertEqual(self.Person.objects.count(), 1)
self.Person.drop_collection()
BlogPost.drop_collection()
def subclasses_and_unique_keys_works(self): def subclasses_and_unique_keys_works(self):
class A(Document): class A(Document):
pass pass
@ -2218,12 +2155,9 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(A.objects.count(), 2) self.assertEqual(A.objects.count(), 2)
self.assertEqual(B.objects.count(), 1) self.assertEqual(B.objects.count(), 1)
A.drop_collection()
B.drop_collection()
def test_document_hash(self): def test_document_hash(self):
"""Test document in list, dict, set """Test document in list, dict, set."""
"""
class User(Document): class User(Document):
pass pass
@ -2266,11 +2200,9 @@ class InstanceTest(unittest.TestCase):
# in Set # in Set
all_user_set = set(User.objects.all()) all_user_set = set(User.objects.all())
self.assertTrue(u1 in all_user_set) self.assertTrue(u1 in all_user_set)
def test_picklable(self): def test_picklable(self):
pickle_doc = PickleTest(number=1, string="One", lists=['1', '2']) pickle_doc = PickleTest(number=1, string="One", lists=['1', '2'])
pickle_doc.embedded = PickleEmbedded() pickle_doc.embedded = PickleEmbedded()
pickled_doc = pickle.dumps(pickle_doc) # make sure pickling works even before the doc is saved pickled_doc = pickle.dumps(pickle_doc) # make sure pickling works even before the doc is saved
@ -2296,7 +2228,6 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(pickle_doc.lists, ["1", "2", "3"]) self.assertEqual(pickle_doc.lists, ["1", "2", "3"])
def test_regular_document_pickle(self): def test_regular_document_pickle(self):
pickle_doc = PickleTest(number=1, string="One", lists=['1', '2']) pickle_doc = PickleTest(number=1, string="One", lists=['1', '2'])
pickled_doc = pickle.dumps(pickle_doc) # make sure pickling works even before the doc is saved pickled_doc = pickle.dumps(pickle_doc) # make sure pickling works even before the doc is saved
pickle_doc.save() pickle_doc.save()
@ -2319,7 +2250,6 @@ class InstanceTest(unittest.TestCase):
fixtures.PickleTest = PickleTest fixtures.PickleTest = PickleTest
def test_dynamic_document_pickle(self): def test_dynamic_document_pickle(self):
pickle_doc = PickleDynamicTest( pickle_doc = PickleDynamicTest(
name="test", number=1, string="One", lists=['1', '2']) name="test", number=1, string="One", lists=['1', '2'])
pickle_doc.embedded = PickleDynamicEmbedded(foo="Bar") pickle_doc.embedded = PickleDynamicEmbedded(foo="Bar")
@ -2358,7 +2288,6 @@ class InstanceTest(unittest.TestCase):
validate = DictField() validate = DictField()
def test_mutating_documents(self): def test_mutating_documents(self):
class B(EmbeddedDocument): class B(EmbeddedDocument):
field1 = StringField(default='field1') field1 = StringField(default='field1')
@ -2366,6 +2295,7 @@ class InstanceTest(unittest.TestCase):
b = EmbeddedDocumentField(B, default=lambda: B()) b = EmbeddedDocumentField(B, default=lambda: B())
A.drop_collection() A.drop_collection()
a = A() a = A()
a.save() a.save()
a.reload() a.reload()
@ -2389,12 +2319,13 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(a.b.field2.c_field, 'new value') self.assertEqual(a.b.field2.c_field, 'new value')
def test_can_save_false_values(self): def test_can_save_false_values(self):
"""Ensures you can save False values on save""" """Ensures you can save False values on save."""
class Doc(Document): class Doc(Document):
foo = StringField() foo = StringField()
archived = BooleanField(default=False, required=True) archived = BooleanField(default=False, required=True)
Doc.drop_collection() Doc.drop_collection()
d = Doc() d = Doc()
d.save() d.save()
d.archived = False d.archived = False
@ -2403,11 +2334,12 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(Doc.objects(archived=False).count(), 1) self.assertEqual(Doc.objects(archived=False).count(), 1)
def test_can_save_false_values_dynamic(self): def test_can_save_false_values_dynamic(self):
"""Ensures you can save False values on dynamic docs""" """Ensures you can save False values on dynamic docs."""
class Doc(DynamicDocument): class Doc(DynamicDocument):
foo = StringField() foo = StringField()
Doc.drop_collection() Doc.drop_collection()
d = Doc() d = Doc()
d.save() d.save()
d.archived = False d.archived = False
@ -2447,7 +2379,7 @@ class InstanceTest(unittest.TestCase):
Collection.update = orig_update Collection.update = orig_update
def test_db_alias_tests(self): def test_db_alias_tests(self):
""" DB Alias tests """ """DB Alias tests."""
# mongoenginetest - Is default connection alias from setUp() # mongoenginetest - Is default connection alias from setUp()
# Register Aliases # Register Aliases
register_connection('testdb-1', 'mongoenginetest2') register_connection('testdb-1', 'mongoenginetest2')
@ -2509,8 +2441,7 @@ class InstanceTest(unittest.TestCase):
get_db("testdb-3")[AuthorBooks._get_collection_name()]) get_db("testdb-3")[AuthorBooks._get_collection_name()])
def test_db_alias_overrides(self): def test_db_alias_overrides(self):
"""db_alias can be overriden """Test db_alias can be overriden."""
"""
# Register a connection with db_alias testdb-2 # Register a connection with db_alias testdb-2
register_connection('testdb-2', 'mongoenginetest2') register_connection('testdb-2', 'mongoenginetest2')
@ -2534,8 +2465,7 @@ class InstanceTest(unittest.TestCase):
B._get_collection().database.name) B._get_collection().database.name)
def test_db_alias_propagates(self): def test_db_alias_propagates(self):
"""db_alias propagates? """db_alias propagates?"""
"""
register_connection('testdb-1', 'mongoenginetest2') register_connection('testdb-1', 'mongoenginetest2')
class A(Document): class A(Document):
@ -2548,8 +2478,7 @@ class InstanceTest(unittest.TestCase):
self.assertEqual('testdb-1', B._meta.get('db_alias')) self.assertEqual('testdb-1', B._meta.get('db_alias'))
def test_db_ref_usage(self): def test_db_ref_usage(self):
""" DB Ref usage in dict_fields""" """DB Ref usage in dict_fields."""
class User(Document): class User(Document):
name = StringField() name = StringField()
@ -2784,7 +2713,6 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(user.thing._data['data'], [1, 2, 3]) self.assertEqual(user.thing._data['data'], [1, 2, 3])
def test_spaces_in_keys(self): def test_spaces_in_keys(self):
class Embedded(DynamicEmbeddedDocument): class Embedded(DynamicEmbeddedDocument):
pass pass
@ -2873,7 +2801,6 @@ class InstanceTest(unittest.TestCase):
log.machine = "127.0.0.1" log.machine = "127.0.0.1"
def test_kwargs_simple(self): def test_kwargs_simple(self):
class Embedded(EmbeddedDocument): class Embedded(EmbeddedDocument):
name = StringField() name = StringField()
@ -2893,7 +2820,6 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(classic_doc._data, dict_doc._data) self.assertEqual(classic_doc._data, dict_doc._data)
def test_kwargs_complex(self): def test_kwargs_complex(self):
class Embedded(EmbeddedDocument): class Embedded(EmbeddedDocument):
name = StringField() name = StringField()
@ -2916,36 +2842,35 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(classic_doc._data, dict_doc._data) self.assertEqual(classic_doc._data, dict_doc._data)
def test_positional_creation(self): def test_positional_creation(self):
"""Ensure that document may be created using positional arguments. """Ensure that document may be created using positional arguments."""
"""
person = self.Person("Test User", 42) person = self.Person("Test User", 42)
self.assertEqual(person.name, "Test User") self.assertEqual(person.name, "Test User")
self.assertEqual(person.age, 42) self.assertEqual(person.age, 42)
def test_mixed_creation(self): def test_mixed_creation(self):
"""Ensure that document may be created using mixed arguments. """Ensure that document may be created using mixed arguments."""
"""
person = self.Person("Test User", age=42) person = self.Person("Test User", age=42)
self.assertEqual(person.name, "Test User") self.assertEqual(person.name, "Test User")
self.assertEqual(person.age, 42) self.assertEqual(person.age, 42)
def test_positional_creation_embedded(self): def test_positional_creation_embedded(self):
"""Ensure that embedded document may be created using positional arguments. """Ensure that embedded document may be created using positional
arguments.
""" """
job = self.Job("Test Job", 4) job = self.Job("Test Job", 4)
self.assertEqual(job.name, "Test Job") self.assertEqual(job.name, "Test Job")
self.assertEqual(job.years, 4) self.assertEqual(job.years, 4)
def test_mixed_creation_embedded(self): def test_mixed_creation_embedded(self):
"""Ensure that embedded document may be created using mixed arguments. """Ensure that embedded document may be created using mixed
arguments.
""" """
job = self.Job("Test Job", years=4) job = self.Job("Test Job", years=4)
self.assertEqual(job.name, "Test Job") self.assertEqual(job.name, "Test Job")
self.assertEqual(job.years, 4) self.assertEqual(job.years, 4)
def test_mixed_creation_dynamic(self): def test_mixed_creation_dynamic(self):
"""Ensure that document may be created using mixed arguments. """Ensure that document may be created using mixed arguments."""
"""
class Person(DynamicDocument): class Person(DynamicDocument):
name = StringField() name = StringField()
@ -2954,14 +2879,14 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(person.age, 42) self.assertEqual(person.age, 42)
def test_bad_mixed_creation(self): def test_bad_mixed_creation(self):
"""Ensure that document gives correct error when duplicating arguments """Ensure that document gives correct error when duplicating
arguments.
""" """
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
return self.Person("Test User", 42, name="Bad User") return self.Person("Test User", 42, name="Bad User")
def test_data_contains_id_field(self): def test_data_contains_id_field(self):
"""Ensure that asking for _data returns 'id' """Ensure that asking for _data returns 'id'."""
"""
class Person(Document): class Person(Document):
name = StringField() name = StringField()
@ -2973,7 +2898,6 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(person._data.get('id'), person.id) self.assertEqual(person._data.get('id'), person.id)
def test_complex_nesting_document_and_embedded_document(self): def test_complex_nesting_document_and_embedded_document(self):
class Macro(EmbeddedDocument): class Macro(EmbeddedDocument):
value = DynamicField(default="UNDEFINED") value = DynamicField(default="UNDEFINED")
@ -3016,7 +2940,6 @@ class InstanceTest(unittest.TestCase):
system.nodes["node"].parameters["param"].macros["test"].value) system.nodes["node"].parameters["param"].macros["test"].value)
def test_embedded_document_equality(self): def test_embedded_document_equality(self):
class Test(Document): class Test(Document):
field = StringField(required=True) field = StringField(required=True)
@ -3202,8 +3125,7 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(idx, 2) self.assertEqual(idx, 2)
def test_falsey_pk(self): def test_falsey_pk(self):
"""Ensure that we can create and update a document with Falsey PK. """Ensure that we can create and update a document with Falsey PK."""
"""
class Person(Document): class Person(Document):
age = IntField(primary_key=True) age = IntField(primary_key=True)
height = FloatField() height = FloatField()