Merge pull request #1131 from noirbizarre/fix-instance-back-references
Fix instance back references
This commit is contained in:
commit
d92d41cb05
@ -135,6 +135,10 @@ class BaseField(object):
|
|||||||
EmbeddedDocument = _import_class('EmbeddedDocument')
|
EmbeddedDocument = _import_class('EmbeddedDocument')
|
||||||
if isinstance(value, EmbeddedDocument):
|
if isinstance(value, EmbeddedDocument):
|
||||||
value._instance = weakref.proxy(instance)
|
value._instance = weakref.proxy(instance)
|
||||||
|
elif isinstance(value, (list, tuple)):
|
||||||
|
for v in value:
|
||||||
|
if isinstance(v, EmbeddedDocument):
|
||||||
|
v._instance = weakref.proxy(instance)
|
||||||
instance._data[self.name] = value
|
instance._data[self.name] = value
|
||||||
|
|
||||||
def error(self, message="", errors=None, field_name=None):
|
def error(self, message="", errors=None, field_name=None):
|
||||||
|
@ -7,6 +7,7 @@ import os
|
|||||||
import pickle
|
import pickle
|
||||||
import unittest
|
import unittest
|
||||||
import uuid
|
import uuid
|
||||||
|
import weakref
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from bson import DBRef, ObjectId
|
from bson import DBRef, ObjectId
|
||||||
@ -30,6 +31,8 @@ TEST_IMAGE_PATH = os.path.join(os.path.dirname(__file__),
|
|||||||
__all__ = ("InstanceTest",)
|
__all__ = ("InstanceTest",)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class InstanceTest(unittest.TestCase):
|
class InstanceTest(unittest.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
@ -63,6 +66,14 @@ class InstanceTest(unittest.TestCase):
|
|||||||
list(self.Person._get_collection().find().sort("id")),
|
list(self.Person._get_collection().find().sort("id")),
|
||||||
sorted(docs, key=lambda doc: doc["_id"]))
|
sorted(docs, key=lambda doc: doc["_id"]))
|
||||||
|
|
||||||
|
def assertHasInstance(self, field, instance):
|
||||||
|
self.assertTrue(hasattr(field, "_instance"))
|
||||||
|
self.assertTrue(field._instance is not None)
|
||||||
|
if isinstance(field._instance, weakref.ProxyType):
|
||||||
|
self.assertTrue(field._instance.__eq__(instance))
|
||||||
|
else:
|
||||||
|
self.assertEqual(field._instance, instance)
|
||||||
|
|
||||||
def test_capped_collection(self):
|
def test_capped_collection(self):
|
||||||
"""Ensure that capped collections work properly.
|
"""Ensure that capped collections work properly.
|
||||||
"""
|
"""
|
||||||
@ -608,10 +619,12 @@ class InstanceTest(unittest.TestCase):
|
|||||||
embedded_field = EmbeddedDocumentField(Embedded)
|
embedded_field = EmbeddedDocumentField(Embedded)
|
||||||
|
|
||||||
Doc.drop_collection()
|
Doc.drop_collection()
|
||||||
Doc(embedded_field=Embedded(string="Hi")).save()
|
doc = Doc(embedded_field=Embedded(string="Hi"))
|
||||||
|
self.assertHasInstance(doc.embedded_field, doc)
|
||||||
|
|
||||||
|
doc.save()
|
||||||
doc = Doc.objects.get()
|
doc = Doc.objects.get()
|
||||||
self.assertEqual(doc, doc.embedded_field._instance)
|
self.assertHasInstance(doc.embedded_field, doc)
|
||||||
|
|
||||||
def test_embedded_document_complex_instance(self):
|
def test_embedded_document_complex_instance(self):
|
||||||
"""Ensure that embedded documents in complex fields can reference
|
"""Ensure that embedded documents in complex fields can reference
|
||||||
@ -623,10 +636,12 @@ class InstanceTest(unittest.TestCase):
|
|||||||
embedded_field = ListField(EmbeddedDocumentField(Embedded))
|
embedded_field = ListField(EmbeddedDocumentField(Embedded))
|
||||||
|
|
||||||
Doc.drop_collection()
|
Doc.drop_collection()
|
||||||
Doc(embedded_field=[Embedded(string="Hi")]).save()
|
doc = Doc(embedded_field=[Embedded(string="Hi")])
|
||||||
|
self.assertHasInstance(doc.embedded_field[0], doc)
|
||||||
|
|
||||||
|
doc.save()
|
||||||
doc = Doc.objects.get()
|
doc = Doc.objects.get()
|
||||||
self.assertEqual(doc, doc.embedded_field[0]._instance)
|
self.assertHasInstance(doc.embedded_field[0], doc)
|
||||||
|
|
||||||
def test_instance_is_set_on_setattr(self):
|
def test_instance_is_set_on_setattr(self):
|
||||||
|
|
||||||
@ -639,11 +654,28 @@ class InstanceTest(unittest.TestCase):
|
|||||||
Account.drop_collection()
|
Account.drop_collection()
|
||||||
acc = Account()
|
acc = Account()
|
||||||
acc.email = Email(email='test@example.com')
|
acc.email = Email(email='test@example.com')
|
||||||
self.assertTrue(hasattr(acc._data["email"], "_instance"))
|
self.assertHasInstance(acc._data["email"], acc)
|
||||||
acc.save()
|
acc.save()
|
||||||
|
|
||||||
acc1 = Account.objects.first()
|
acc1 = Account.objects.first()
|
||||||
self.assertTrue(hasattr(acc1._data["email"], "_instance"))
|
self.assertHasInstance(acc1._data["email"], acc1)
|
||||||
|
|
||||||
|
def test_instance_is_set_on_setattr_on_embedded_document_list(self):
|
||||||
|
|
||||||
|
class Email(EmbeddedDocument):
|
||||||
|
email = EmailField()
|
||||||
|
|
||||||
|
class Account(Document):
|
||||||
|
emails = EmbeddedDocumentListField(Email)
|
||||||
|
|
||||||
|
Account.drop_collection()
|
||||||
|
acc = Account()
|
||||||
|
acc.emails = [Email(email='test@example.com')]
|
||||||
|
self.assertHasInstance(acc._data["emails"][0], acc)
|
||||||
|
acc.save()
|
||||||
|
|
||||||
|
acc1 = Account.objects.first()
|
||||||
|
self.assertHasInstance(acc1._data["emails"][0], acc1)
|
||||||
|
|
||||||
def test_document_clean(self):
|
def test_document_clean(self):
|
||||||
class TestDocument(Document):
|
class TestDocument(Document):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user