From 013227323d8725e6ffd2e13aecfd0f3a9165d276 Mon Sep 17 00:00:00 2001 From: Ashley Whetter Date: Tue, 10 Nov 2015 14:29:25 +0000 Subject: [PATCH] ReferenceFields can now reference abstract Document types A class that inherits from an abstract Document type is stored in the database as a reference with a 'cls' field that is the class name of the document being stored. Fixes #837 --- mongoengine/fields.py | 25 ++++++++++++++++----- tests/fields/fields.py | 50 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+), 6 deletions(-) diff --git a/mongoengine/fields.py b/mongoengine/fields.py index f5899311..13538c89 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -928,9 +928,14 @@ class ReferenceField(BaseField): self._auto_dereference = instance._fields[self.name]._auto_dereference # Dereference DBRefs if self._auto_dereference and isinstance(value, DBRef): - value = self.document_type._get_db().dereference(value) + if hasattr(value, 'cls'): + # Dereference using the class type specified in the reference + cls = get_document(value.cls) + else: + cls = self.document_type + value = cls._get_db().dereference(value) if value is not None: - instance._data[self.name] = self.document_type._from_son(value) + instance._data[self.name] = cls._from_son(value) return super(ReferenceField, self).__get__(instance, owner) @@ -940,22 +945,30 @@ class ReferenceField(BaseField): return document.id return document - id_field_name = self.document_type._meta['id_field'] - id_field = self.document_type._fields[id_field_name] - if isinstance(document, Document): # We need the id from the saved object to create the DBRef id_ = document.pk if id_ is None: self.error('You can only reference documents once they have' ' been saved to the database') + + # Use the attributes from the document instance, so that they + # override the attributes of this field's document type + cls = document else: id_ = document + cls = self.document_type + + id_field_name = cls._meta['id_field'] + id_field = cls._fields[id_field_name] id_ = id_field.to_mongo(id_) if self.dbref: - collection = self.document_type._get_collection_name() + collection = cls._get_collection_name() return DBRef(collection, id_) + elif self.document_type._meta.get('abstract'): + collection = cls._get_collection_name() + return DBRef(collection, id_, cls=cls._class_name) return id_ diff --git a/tests/fields/fields.py b/tests/fields/fields.py index 7ef298fc..860a5749 100644 --- a/tests/fields/fields.py +++ b/tests/fields/fields.py @@ -2281,6 +2281,56 @@ class FieldTest(unittest.TestCase): Member.drop_collection() BlogPost.drop_collection() + def test_reference_class_with_abstract_parent(self): + """Ensure that a class with an abstract parent can be referenced. + """ + class Sibling(Document): + name = StringField() + meta = {"abstract": True} + + class Sister(Sibling): + pass + + class Brother(Sibling): + sibling = ReferenceField(Sibling) + + Sister.drop_collection() + Brother.drop_collection() + + sister = Sister(name="Alice") + sister.save() + brother = Brother(name="Bob", sibling=sister) + brother.save() + + self.assertEquals(Brother.objects[0].sibling.name, sister.name) + + Sister.drop_collection() + Brother.drop_collection() + + def test_reference_abstract_class(self): + """Ensure that an abstract class instance cannot be used in the + reference of that abstract class. + """ + class Sibling(Document): + name = StringField() + meta = {"abstract": True} + + class Sister(Sibling): + pass + + class Brother(Sibling): + sibling = ReferenceField(Sibling) + + Sister.drop_collection() + Brother.drop_collection() + + sister = Sibling(name="Alice") + brother = Brother(name="Bob", sibling=sister) + self.assertRaises(ValidationError, brother.save) + + Sister.drop_collection() + Brother.drop_collection() + def test_generic_reference(self): """Ensure that a GenericReferenceField properly dereferences items. """