diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 5bee10f6..42f068c6 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -211,13 +211,15 @@ class ListField(BaseField): def __get__(self, instance, owner): """Descriptor to automatically dereference references. """ + global _model_registry + if instance is None: # Document class being used rather than a document object return self - if isinstance(self.field, (ReferenceField, GenericReferenceField)): + if isinstance(self.field, ReferenceField): referenced_type = self.field.document_type - # Get value from document instance if available + # Get value from document instance if available value_list = instance._data.get(self.name) if value_list: deref_list = [] @@ -230,19 +232,19 @@ class ListField(BaseField): deref_list.append(value) instance._data[self.name] = deref_list - # if isinstance(self.field, GenericReferenceField): - # value_list = instance._data.get(self.name) - # if value_list: - # deref_list = [] - # for value in value_list: - # # Dereference DBRefs - # if isinstance(value, pymongo.dbref.DBRef): - # value = _get_db().dereference(value) - # referenced_type = value. - # deref_list.append() - # else: - # deref_list.append(value) - # instance._data[self.name] = deref_list + if isinstance(self.field, GenericReferenceField): + value_list = instance._data.get(self.name) + if value_list: + deref_list = [] + for value in value_list: + # Dereference DBRefs + if isinstance(value, pymongo.dbref.DBRef): + value = _get_db().dereference(value) + referenced_type = _model_registry[value['_cls']] + deref_list.append(referenced_type._from_son(value)) + else: + deref_list.append(value) + instance._data[self.name] = deref_list return super(ListField, self).__get__(instance, owner) diff --git a/tests/fields.py b/tests/fields.py index 844ee547..7f8749a9 100644 --- a/tests/fields.py +++ b/tests/fields.py @@ -392,55 +392,40 @@ class FieldTest(unittest.TestCase): Post.drop_collection() Bookmark.drop_collection() - # def test_generic_reference_list(self): - # """Ensure that a ListField properly dereferences - # relationships to *any* model via GenericReferenceField. - # """ - # class Link(Document): - # title = StringField() - # - # class Post(Document): - # title = StringField() - # - # class User(Document): - # bookmarks = ListField(GenericReferenceField()) - # - # Link.drop_collection() - # Post.drop_collection() - # User.drop_collection() - # - # link_1 = Link(title="Pitchfork") - # link_1.save() - # - # post_1 = Post(title="Behind the Scenes of the Pavement Reunion") - # post_1.save() - # - # user = User(bookmarks=[post_1, link_1]) - # user.save() - # - # del user - # - # user = User.objects().first() - # print user.bookmarks - # - # # print dir(user) - # - # # self.assertEqual(bm.bookmark_object, post_1) - # # self.assertEqual(bm._data['bookmark_object'].__class__, - # # pymongo.dbref.DBRef) - # # - # # bm.bookmark_object = link_1 - # # bm.save() - # # - # # bm.reload() - # # - # # self.assertEqual(bm.bookmark_object, link_1) - # # self.assertEqual(bm._data['bookmark_object'].__class__, - # # pymongo.dbref.DBRef) - # - # Link.drop_collection() - # Post.drop_collection() - # User.drop_collection() + def test_generic_reference_list(self): + """Ensure that a ListField properly dereferences + relationships to *any* model via GenericReferenceField. + """ + class Link(Document): + title = StringField() + + class Post(Document): + title = StringField() + + class User(Document): + bookmarks = ListField(GenericReferenceField()) + + Link.drop_collection() + Post.drop_collection() + User.drop_collection() + + link_1 = Link(title="Pitchfork") + link_1.save() + + post_1 = Post(title="Behind the Scenes of the Pavement Reunion") + post_1.save() + + user = User(bookmarks=[post_1, link_1]) + user.save() + + user.reload() + + self.assertEqual(user.bookmarks[0], post_1) + self.assertEqual(user.bookmarks[1], link_1) + + Link.drop_collection() + Post.drop_collection() + User.drop_collection() if __name__ == '__main__':