diff --git a/tests/document/instance.py b/tests/document/instance.py index 7a393416..8caa5675 100644 --- a/tests/document/instance.py +++ b/tests/document/instance.py @@ -7,6 +7,7 @@ import os import pickle import unittest import uuid +import weakref from datetime import datetime from bson import DBRef, ObjectId @@ -30,6 +31,8 @@ TEST_IMAGE_PATH = os.path.join(os.path.dirname(__file__), __all__ = ("InstanceTest",) + + class InstanceTest(unittest.TestCase): def setUp(self): @@ -63,6 +66,14 @@ class InstanceTest(unittest.TestCase): list(self.Person._get_collection().find().sort("id")), sorted(docs, key=lambda doc: doc["_id"])) + def assertHasInstance(self, field, instance): + self.assertTrue(hasattr(field, "_instance")) + self.assertIsNotNone(field._instance) + if isinstance(field._instance, weakref.ProxyType): + self.assertTrue(field._instance.__eq__(instance)) + else: + self.assertEqual(field._instance, instance) + def test_capped_collection(self): """Ensure that capped collections work properly. """ @@ -608,10 +619,12 @@ class InstanceTest(unittest.TestCase): embedded_field = EmbeddedDocumentField(Embedded) Doc.drop_collection() - Doc(embedded_field=Embedded(string="Hi")).save() + doc = Doc(embedded_field=Embedded(string="Hi")) + self.assertHasInstance(doc.embedded_field, doc) + doc.save() doc = Doc.objects.get() - self.assertEqual(doc, doc.embedded_field._instance) + self.assertHasInstance(doc.embedded_field, doc) def test_embedded_document_complex_instance(self): """Ensure that embedded documents in complex fields can reference @@ -623,10 +636,12 @@ class InstanceTest(unittest.TestCase): embedded_field = ListField(EmbeddedDocumentField(Embedded)) Doc.drop_collection() - Doc(embedded_field=[Embedded(string="Hi")]).save() + doc = Doc(embedded_field=[Embedded(string="Hi")]) + self.assertHasInstance(doc.embedded_field[0], doc) + doc.save() doc = Doc.objects.get() - self.assertEqual(doc, doc.embedded_field[0]._instance) + self.assertHasInstance(doc.embedded_field[0], doc) def test_instance_is_set_on_setattr(self): @@ -639,11 +654,28 @@ class InstanceTest(unittest.TestCase): Account.drop_collection() acc = Account() acc.email = Email(email='test@example.com') - self.assertTrue(hasattr(acc._data["email"], "_instance")) + self.assertHasInstance(acc._data["email"], acc) acc.save() acc1 = Account.objects.first() - self.assertTrue(hasattr(acc1._data["email"], "_instance")) + self.assertHasInstance(acc1._data["email"], acc1) + + def test_instance_is_set_on_setattr_on_embedded_document_list(self): + + class Email(EmbeddedDocument): + email = EmailField() + + class Account(Document): + emails = EmbeddedDocumentListField(Email) + + Account.drop_collection() + acc = Account() + acc.emails = [Email(email='test@example.com')] + self.assertHasInstance(acc._data["emails"][0], acc) + acc.save() + + acc1 = Account.objects.first() + self.assertHasInstance(acc1._data["emails"][0], acc1) def test_document_clean(self): class TestDocument(Document):