Merge pull request #1131 from noirbizarre/fix-instance-back-references

Fix instance back references
This commit is contained in:
Omer Katz 2015-11-08 12:14:37 +02:00
commit d92d41cb05
2 changed files with 42 additions and 6 deletions

View File

@ -135,6 +135,10 @@ class BaseField(object):
EmbeddedDocument = _import_class('EmbeddedDocument')
if isinstance(value, EmbeddedDocument):
value._instance = weakref.proxy(instance)
elif isinstance(value, (list, tuple)):
for v in value:
if isinstance(v, EmbeddedDocument):
v._instance = weakref.proxy(instance)
instance._data[self.name] = value
def error(self, message="", errors=None, field_name=None):

View File

@ -7,6 +7,7 @@ import os
import pickle
import unittest
import uuid
import weakref
from datetime import datetime
from bson import DBRef, ObjectId
@ -30,6 +31,8 @@ TEST_IMAGE_PATH = os.path.join(os.path.dirname(__file__),
__all__ = ("InstanceTest",)
class InstanceTest(unittest.TestCase):
def setUp(self):
@ -63,6 +66,14 @@ class InstanceTest(unittest.TestCase):
list(self.Person._get_collection().find().sort("id")),
sorted(docs, key=lambda doc: doc["_id"]))
def assertHasInstance(self, field, instance):
self.assertTrue(hasattr(field, "_instance"))
self.assertTrue(field._instance is not None)
if isinstance(field._instance, weakref.ProxyType):
self.assertTrue(field._instance.__eq__(instance))
else:
self.assertEqual(field._instance, instance)
def test_capped_collection(self):
"""Ensure that capped collections work properly.
"""
@ -608,10 +619,12 @@ class InstanceTest(unittest.TestCase):
embedded_field = EmbeddedDocumentField(Embedded)
Doc.drop_collection()
Doc(embedded_field=Embedded(string="Hi")).save()
doc = Doc(embedded_field=Embedded(string="Hi"))
self.assertHasInstance(doc.embedded_field, doc)
doc.save()
doc = Doc.objects.get()
self.assertEqual(doc, doc.embedded_field._instance)
self.assertHasInstance(doc.embedded_field, doc)
def test_embedded_document_complex_instance(self):
"""Ensure that embedded documents in complex fields can reference
@ -623,10 +636,12 @@ class InstanceTest(unittest.TestCase):
embedded_field = ListField(EmbeddedDocumentField(Embedded))
Doc.drop_collection()
Doc(embedded_field=[Embedded(string="Hi")]).save()
doc = Doc(embedded_field=[Embedded(string="Hi")])
self.assertHasInstance(doc.embedded_field[0], doc)
doc.save()
doc = Doc.objects.get()
self.assertEqual(doc, doc.embedded_field[0]._instance)
self.assertHasInstance(doc.embedded_field[0], doc)
def test_instance_is_set_on_setattr(self):
@ -639,11 +654,28 @@ class InstanceTest(unittest.TestCase):
Account.drop_collection()
acc = Account()
acc.email = Email(email='test@example.com')
self.assertTrue(hasattr(acc._data["email"], "_instance"))
self.assertHasInstance(acc._data["email"], acc)
acc.save()
acc1 = Account.objects.first()
self.assertTrue(hasattr(acc1._data["email"], "_instance"))
self.assertHasInstance(acc1._data["email"], acc1)
def test_instance_is_set_on_setattr_on_embedded_document_list(self):
class Email(EmbeddedDocument):
email = EmailField()
class Account(Document):
emails = EmbeddedDocumentListField(Email)
Account.drop_collection()
acc = Account()
acc.emails = [Email(email='test@example.com')]
self.assertHasInstance(acc._data["emails"][0], acc)
acc.save()
acc1 = Account.objects.first()
self.assertHasInstance(acc1._data["emails"][0], acc1)
def test_document_clean(self):
class TestDocument(Document):