diff --git a/AUTHORS b/AUTHORS index 411e274d..fbca84e4 100644 --- a/AUTHORS +++ b/AUTHORS @@ -230,3 +230,4 @@ that much better: * Amit Lichtenberg (https://github.com/amitlicht) * Lars Butler (https://github.com/larsbutler) * George Macon (https://github.com/gmacon) + * Ashley Whetter (https://github.com/AWhetter) diff --git a/docs/changelog.rst b/docs/changelog.rst index 5809314c..37720431 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -6,6 +6,7 @@ Changes in 0.10.2 ================= - Allow shard key to point to a field in an embedded document. #551 - Allow arbirary metadata in fields. #1129 +- ReferenceFields now support abstract document types. #837 Changes in 0.10.1 ======================= diff --git a/mongoengine/fields.py b/mongoengine/fields.py index ad3c468b..3887fd84 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -895,6 +895,10 @@ class ReferenceField(BaseField): or as the :class:`~pymongo.objectid.ObjectId`.id . :param reverse_delete_rule: Determines what to do when the referring object is deleted + + .. note :: + A reference to an abstract document type is always stored as a + :class:`~pymongo.dbref.DBRef`, regardless of the value of `dbref`. """ if not isinstance(document_type, basestring): if not issubclass(document_type, (Document, basestring)): @@ -927,9 +931,14 @@ class ReferenceField(BaseField): self._auto_dereference = instance._fields[self.name]._auto_dereference # Dereference DBRefs if self._auto_dereference and isinstance(value, DBRef): - value = self.document_type._get_db().dereference(value) + if hasattr(value, 'cls'): + # Dereference using the class type specified in the reference + cls = get_document(value.cls) + else: + cls = self.document_type + value = cls._get_db().dereference(value) if value is not None: - instance._data[self.name] = self.document_type._from_son(value) + instance._data[self.name] = cls._from_son(value) return super(ReferenceField, self).__get__(instance, owner) @@ -939,21 +948,29 @@ class ReferenceField(BaseField): return document.id return document - id_field_name = self.document_type._meta['id_field'] - id_field = self.document_type._fields[id_field_name] - if isinstance(document, Document): # We need the id from the saved object to create the DBRef id_ = document.pk if id_ is None: self.error('You can only reference documents once they have' ' been saved to the database') + + # Use the attributes from the document instance, so that they + # override the attributes of this field's document type + cls = document else: id_ = document + cls = self.document_type + + id_field_name = cls._meta['id_field'] + id_field = cls._fields[id_field_name] id_ = id_field.to_mongo(id_) - if self.dbref: - collection = self.document_type._get_collection_name() + if self.document_type._meta.get('abstract'): + collection = cls._get_collection_name() + return DBRef(collection, id_, cls=cls._class_name) + elif self.dbref: + collection = cls._get_collection_name() return DBRef(collection, id_) return id_ @@ -982,6 +999,14 @@ class ReferenceField(BaseField): self.error('You can only reference documents once they have been ' 'saved to the database') + if self.document_type._meta.get('abstract') and \ + not isinstance(value, self.document_type): + self.error('%s is not an instance of abstract reference' + ' type %s' % (value._class_name, + self.document_type._class_name) + ) + + def lookup_member(self, member_name): return self.document_type._fields.get(member_name) diff --git a/tests/fields/fields.py b/tests/fields/fields.py index 7ef298fc..15daecb9 100644 --- a/tests/fields/fields.py +++ b/tests/fields/fields.py @@ -2281,6 +2281,81 @@ class FieldTest(unittest.TestCase): Member.drop_collection() BlogPost.drop_collection() + def test_reference_class_with_abstract_parent(self): + """Ensure that a class with an abstract parent can be referenced. + """ + class Sibling(Document): + name = StringField() + meta = {"abstract": True} + + class Sister(Sibling): + pass + + class Brother(Sibling): + sibling = ReferenceField(Sibling) + + Sister.drop_collection() + Brother.drop_collection() + + sister = Sister(name="Alice") + sister.save() + brother = Brother(name="Bob", sibling=sister) + brother.save() + + self.assertEquals(Brother.objects[0].sibling.name, sister.name) + + Sister.drop_collection() + Brother.drop_collection() + + def test_reference_abstract_class(self): + """Ensure that an abstract class instance cannot be used in the + reference of that abstract class. + """ + class Sibling(Document): + name = StringField() + meta = {"abstract": True} + + class Sister(Sibling): + pass + + class Brother(Sibling): + sibling = ReferenceField(Sibling) + + Sister.drop_collection() + Brother.drop_collection() + + sister = Sibling(name="Alice") + brother = Brother(name="Bob", sibling=sister) + self.assertRaises(ValidationError, brother.save) + + Sister.drop_collection() + Brother.drop_collection() + + def test_abstract_reference_base_type(self): + """Ensure that an an abstract reference fails validation when given a + Document that does not inherit from the abstract type. + """ + class Sibling(Document): + name = StringField() + meta = {"abstract": True} + + class Brother(Sibling): + sibling = ReferenceField(Sibling) + + class Mother(Document): + name = StringField() + + Brother.drop_collection() + Mother.drop_collection() + + mother = Mother(name="Carol") + mother.save() + brother = Brother(name="Bob", sibling=mother) + self.assertRaises(ValidationError, brother.save) + + Brother.drop_collection() + Mother.drop_collection() + def test_generic_reference(self): """Ensure that a GenericReferenceField properly dereferences items. """