fixes dereference for documents (allow_inheritance = False)

This commit is contained in:
Wilson Júnior 2011-08-09 14:31:26 -03:00
parent 3f3f93b0fa
commit 331f8b8ae7
2 changed files with 40 additions and 3 deletions

View File

@ -3,6 +3,7 @@ import operator
import pymongo
from base import BaseDict, BaseList, get_document, TopLevelDocumentMetaclass
from fields import ReferenceField
from connection import _get_db
from queryset import QuerySet
from document import Document
@ -32,8 +33,16 @@ class DeReference(object):
items = [i for i in items]
self.max_depth = max_depth
doc_type = None
if instance and instance._fields:
doc_type = instance._fields[name].field
if isinstance(doc_type, ReferenceField):
doc_type = doc_type.document_type
self.reference_map = self._find_references(items)
self.object_map = self._fetch_objects()
self.object_map = self._fetch_objects(doc_type=doc_type)
return self._attach_objects(items, 0, instance, name, get)
def _find_references(self, items, depth=0):
@ -80,7 +89,7 @@ class DeReference(object):
depth += 1
return reference_map
def _fetch_objects(self):
def _fetch_objects(self, doc_type=None):
"""Fetch all references and convert to their document objects
"""
object_map = {}
@ -94,7 +103,10 @@ class DeReference(object):
else: # Generic reference: use the refs data to convert to document
references = _get_db()[col].find({'_id': {'$in': refs}})
for ref in references:
doc = get_document(ref['_cls'])._from_son(ref)
if '_cls' in ref:
doc = get_document(ref['_cls'])._from_son(ref)
else:
doc = doc_type._from_son(ref)
object_map[doc.id] = doc
return object_map

View File

@ -289,6 +289,31 @@ class DocumentTest(unittest.TestCase):
Zoo.drop_collection()
Animal.drop_collection()
def test_reference_inheritance(self):
class Stats(Document):
created = DateTimeField(default=datetime.now)
meta = {'allow_inheritance': False}
class CompareStats(Document):
generated = DateTimeField(default=datetime.now)
stats = ListField(ReferenceField(Stats))
Stats.drop_collection()
CompareStats.drop_collection()
list_stats = []
for i in xrange(10):
s = Stats()
s.save()
list_stats.append(s)
cmp_stats = CompareStats(stats=list_stats)
cmp_stats.save()
self.assertEqual(list_stats, CompareStats.objects.first().stats)
def test_inheritance(self):
"""Ensure that document may inherit fields from a superclass document.
"""