From a1d43fecd962ea856e361fe59bf0913269f8eb84 Mon Sep 17 00:00:00 2001 From: Greg Banks Date: Wed, 11 Apr 2012 16:37:22 -0700 Subject: [PATCH 1/2] fix for issue 473 --- mongoengine/dereference.py | 3 +++ tests/dereference.py | 19 +++++++++++++++++++ 2 files changed, 22 insertions(+) diff --git a/mongoengine/dereference.py b/mongoengine/dereference.py index b1529d3c..b67c2d2f 100644 --- a/mongoengine/dereference.py +++ b/mongoengine/dereference.py @@ -113,6 +113,9 @@ class DeReference(object): if '_cls' in ref: doc = get_document(ref["_cls"])._from_son(ref) else: + if doc_type is None: + doc_type = get_document( + ''.join(x.capitalize() for x in col.split('_'))) doc = doc_type._from_son(ref) object_map[doc.id] = doc return object_map diff --git a/tests/dereference.py b/tests/dereference.py index 8a4b310e..0ed64e6c 100644 --- a/tests/dereference.py +++ b/tests/dereference.py @@ -810,3 +810,22 @@ class FieldTest(unittest.TestCase): room = Room.objects.first().select_related() self.assertEquals(room.staffs_with_position[0]['staff'], sarah) self.assertEquals(room.staffs_with_position[1]['staff'], bob) + + def test_document_reload_no_inheritance(self): + class Foo(Document): + meta = {'allow_inheritance': False} + bar = ReferenceField('Bar') + + class Bar(Document): + meta = {'allow_inheritance': False} + msg = StringField(required=True, default='Blammo!') + + Foo.drop_collection() + Bar.drop_collection() + + bar = Bar() + bar.save() + foo = Foo() + foo.bar = bar + foo.save() + foo.reload() From 49a66ba81a4353445ec0917173f50a65880ad7bc Mon Sep 17 00:00:00 2001 From: Greg Banks Date: Thu, 12 Apr 2012 11:42:10 -0700 Subject: [PATCH 2/2] whoops, don't dereference all references as the first type encountered --- mongoengine/dereference.py | 7 ++++--- tests/dereference.py | 12 ++++++++++++ 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/mongoengine/dereference.py b/mongoengine/dereference.py index b67c2d2f..8928d1ef 100644 --- a/mongoengine/dereference.py +++ b/mongoengine/dereference.py @@ -112,10 +112,11 @@ class DeReference(object): for ref in references: if '_cls' in ref: doc = get_document(ref["_cls"])._from_son(ref) + elif doc_type is None: + doc = get_document( + ''.join(x.capitalize() + for x in col.split('_')))._from_son(ref) else: - if doc_type is None: - doc_type = get_document( - ''.join(x.capitalize() for x in col.split('_'))) doc = doc_type._from_son(ref) object_map[doc.id] = doc return object_map diff --git a/tests/dereference.py b/tests/dereference.py index 0ed64e6c..9f0d4330 100644 --- a/tests/dereference.py +++ b/tests/dereference.py @@ -815,17 +815,29 @@ class FieldTest(unittest.TestCase): class Foo(Document): meta = {'allow_inheritance': False} bar = ReferenceField('Bar') + baz = ReferenceField('Baz') class Bar(Document): meta = {'allow_inheritance': False} msg = StringField(required=True, default='Blammo!') + class Baz(Document): + meta = {'allow_inheritance': False} + msg = StringField(required=True, default='Kaboom!') + Foo.drop_collection() Bar.drop_collection() + Baz.drop_collection() bar = Bar() bar.save() + baz = Baz() + baz.save() foo = Foo() foo.bar = bar + foo.baz = baz foo.save() foo.reload() + + self.assertEquals(type(foo.bar), Bar) + self.assertEquals(type(foo.baz), Baz)