From 331f8b8ae7ef31badb0db3ddf4b7e843406ea807 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wilson=20J=C3=BAnior?= Date: Tue, 9 Aug 2011 14:31:26 -0300 Subject: [PATCH] fixes dereference for documents (allow_inheritance = False) --- mongoengine/dereference.py | 18 +++++++++++++++--- tests/document.py | 25 +++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 3 deletions(-) diff --git a/mongoengine/dereference.py b/mongoengine/dereference.py index 6bfabd94..7fe9ba2f 100644 --- a/mongoengine/dereference.py +++ b/mongoengine/dereference.py @@ -3,6 +3,7 @@ import operator import pymongo from base import BaseDict, BaseList, get_document, TopLevelDocumentMetaclass +from fields import ReferenceField from connection import _get_db from queryset import QuerySet from document import Document @@ -32,8 +33,16 @@ class DeReference(object): items = [i for i in items] self.max_depth = max_depth + + doc_type = None + if instance and instance._fields: + doc_type = instance._fields[name].field + + if isinstance(doc_type, ReferenceField): + doc_type = doc_type.document_type + self.reference_map = self._find_references(items) - self.object_map = self._fetch_objects() + self.object_map = self._fetch_objects(doc_type=doc_type) return self._attach_objects(items, 0, instance, name, get) def _find_references(self, items, depth=0): @@ -80,7 +89,7 @@ class DeReference(object): depth += 1 return reference_map - def _fetch_objects(self): + def _fetch_objects(self, doc_type=None): """Fetch all references and convert to their document objects """ object_map = {} @@ -94,7 +103,10 @@ class DeReference(object): else: # Generic reference: use the refs data to convert to document references = _get_db()[col].find({'_id': {'$in': refs}}) for ref in references: - doc = get_document(ref['_cls'])._from_son(ref) + if '_cls' in ref: + doc = get_document(ref['_cls'])._from_son(ref) + else: + doc = doc_type._from_son(ref) object_map[doc.id] = doc return object_map diff --git a/tests/document.py b/tests/document.py index 1c9b90ed..90a0bc25 100644 --- a/tests/document.py +++ b/tests/document.py @@ -289,6 +289,31 @@ class DocumentTest(unittest.TestCase): Zoo.drop_collection() Animal.drop_collection() + def test_reference_inheritance(self): + class Stats(Document): + created = DateTimeField(default=datetime.now) + + meta = {'allow_inheritance': False} + + class CompareStats(Document): + generated = DateTimeField(default=datetime.now) + stats = ListField(ReferenceField(Stats)) + + Stats.drop_collection() + CompareStats.drop_collection() + + list_stats = [] + + for i in xrange(10): + s = Stats() + s.save() + list_stats.append(s) + + cmp_stats = CompareStats(stats=list_stats) + cmp_stats.save() + + self.assertEqual(list_stats, CompareStats.objects.first().stats) + def test_inheritance(self): """Ensure that document may inherit fields from a superclass document. """