From 170693cf0bf9b114ef9a032739c5ebf5d53a56da Mon Sep 17 00:00:00 2001 From: Clay McClure Date: Thu, 19 Jun 2014 19:33:46 -0400 Subject: [PATCH] Follow ReferenceFields in EmbeddedDocuments with select_related For the following structure: class Playlist(Document): items = ListField(EmbeddedDocumentField("PlaylistItem")) class PlaylistItem(EmbeddedDocument): song = ReferenceField("Song") class Song(Document): title = StringField() this patch prevents the N+1 queries otherwise required to fetch all the `Song` instances referenced by all the `PlaylistItem`s. --- docs/changelog.rst | 1 + mongoengine/dereference.py | 4 ++-- tests/test_dereference.py | 26 +++++++++++++++++++++++++- 3 files changed, 28 insertions(+), 3 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 9a55b91b..52347d01 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -6,6 +6,7 @@ Changelog Changes in 0.9.X - DEV ====================== +- Follow ReferenceFields in EmbeddedDocuments with select_related #690 - Added preliminary support for text indexes #680 - Added `elemMatch` operator as well - `match` is too obscure #653 - Added support for progressive JPEG #486 #548 diff --git a/mongoengine/dereference.py b/mongoengine/dereference.py index 18235b96..f9c8ecd6 100644 --- a/mongoengine/dereference.py +++ b/mongoengine/dereference.py @@ -95,7 +95,7 @@ class DeReference(object): # Recursively find dbreferences depth += 1 for k, item in iterator: - if isinstance(item, Document): + if isinstance(item, (Document, EmbeddedDocument)): for field_name, field in item._fields.iteritems(): v = item._data.get(field_name, None) if isinstance(v, (DBRef)): @@ -202,7 +202,7 @@ class DeReference(object): if k in self.object_map and not is_list: data[k] = self.object_map[k] - elif isinstance(v, Document): + elif isinstance(v, (Document, EmbeddedDocument)): for field_name, field in v._fields.iteritems(): v = data[k]._data.get(field_name, None) if isinstance(v, (DBRef)): diff --git a/tests/test_dereference.py b/tests/test_dereference.py index dc416007..c37ada59 100644 --- a/tests/test_dereference.py +++ b/tests/test_dereference.py @@ -1219,6 +1219,30 @@ class FieldTest(unittest.TestCase): page = Page.objects.first() self.assertEqual(page.tags[0], page.posts[0].tags[0]) + def test_select_related_follows_embedded_referencefields(self): + class Playlist(Document): + items = ListField(EmbeddedDocumentField("PlaylistItem")) + + class PlaylistItem(EmbeddedDocument): + song = ReferenceField("Song") + + class Song(Document): + title = StringField() + + Playlist.drop_collection() + Song.drop_collection() + + songs = [Song.objects.create(title="song %d" % i) for i in range(3)] + items = [PlaylistItem(song=song) for song in songs] + playlist = Playlist.objects.create(items=items) + + with query_counter() as q: + self.assertEqual(q, 0) + + playlist = Playlist.objects.first().select_related() + songs = [item.song for item in playlist.items] + + self.assertEqual(q, 2) + if __name__ == '__main__': unittest.main() -