diff --git a/mongoengine/dereference.py b/mongoengine/dereference.py index b1529d3c..8928d1ef 100644 --- a/mongoengine/dereference.py +++ b/mongoengine/dereference.py @@ -112,6 +112,10 @@ class DeReference(object): for ref in references: if '_cls' in ref: doc = get_document(ref["_cls"])._from_son(ref) + elif doc_type is None: + doc = get_document( + ''.join(x.capitalize() + for x in col.split('_')))._from_son(ref) else: doc = doc_type._from_son(ref) object_map[doc.id] = doc diff --git a/tests/dereference.py b/tests/dereference.py index 8a4b310e..9f0d4330 100644 --- a/tests/dereference.py +++ b/tests/dereference.py @@ -810,3 +810,34 @@ class FieldTest(unittest.TestCase): room = Room.objects.first().select_related() self.assertEquals(room.staffs_with_position[0]['staff'], sarah) self.assertEquals(room.staffs_with_position[1]['staff'], bob) + + def test_document_reload_no_inheritance(self): + class Foo(Document): + meta = {'allow_inheritance': False} + bar = ReferenceField('Bar') + baz = ReferenceField('Baz') + + class Bar(Document): + meta = {'allow_inheritance': False} + msg = StringField(required=True, default='Blammo!') + + class Baz(Document): + meta = {'allow_inheritance': False} + msg = StringField(required=True, default='Kaboom!') + + Foo.drop_collection() + Bar.drop_collection() + Baz.drop_collection() + + bar = Bar() + bar.save() + baz = Baz() + baz.save() + foo = Foo() + foo.bar = bar + foo.baz = baz + foo.save() + foo.reload() + + self.assertEquals(type(foo.bar), Bar) + self.assertEquals(type(foo.baz), Baz)