Follow ReferenceFields in EmbeddedDocuments with select_related

For the following structure:

    class Playlist(Document):
        items = ListField(EmbeddedDocumentField("PlaylistItem"))

    class PlaylistItem(EmbeddedDocument):
        song = ReferenceField("Song")

    class Song(Document):
        title = StringField()

this patch prevents the N+1 queries otherwise required to fetch all
the `Song` instances referenced by all the `PlaylistItem`s.
This commit is contained in:
Clay McClure 2014-06-19 19:33:46 -04:00
parent 4e7b5d4af8
commit 170693cf0b
3 changed files with 28 additions and 3 deletions

View File

@ -6,6 +6,7 @@ Changelog
Changes in 0.9.X - DEV Changes in 0.9.X - DEV
====================== ======================
- Follow ReferenceFields in EmbeddedDocuments with select_related #690
- Added preliminary support for text indexes #680 - Added preliminary support for text indexes #680
- Added `elemMatch` operator as well - `match` is too obscure #653 - Added `elemMatch` operator as well - `match` is too obscure #653
- Added support for progressive JPEG #486 #548 - Added support for progressive JPEG #486 #548

View File

@ -95,7 +95,7 @@ class DeReference(object):
# Recursively find dbreferences # Recursively find dbreferences
depth += 1 depth += 1
for k, item in iterator: for k, item in iterator:
if isinstance(item, Document): if isinstance(item, (Document, EmbeddedDocument)):
for field_name, field in item._fields.iteritems(): for field_name, field in item._fields.iteritems():
v = item._data.get(field_name, None) v = item._data.get(field_name, None)
if isinstance(v, (DBRef)): if isinstance(v, (DBRef)):
@ -202,7 +202,7 @@ class DeReference(object):
if k in self.object_map and not is_list: if k in self.object_map and not is_list:
data[k] = self.object_map[k] data[k] = self.object_map[k]
elif isinstance(v, Document): elif isinstance(v, (Document, EmbeddedDocument)):
for field_name, field in v._fields.iteritems(): for field_name, field in v._fields.iteritems():
v = data[k]._data.get(field_name, None) v = data[k]._data.get(field_name, None)
if isinstance(v, (DBRef)): if isinstance(v, (DBRef)):

View File

@ -1219,6 +1219,30 @@ class FieldTest(unittest.TestCase):
page = Page.objects.first() page = Page.objects.first()
self.assertEqual(page.tags[0], page.posts[0].tags[0]) self.assertEqual(page.tags[0], page.posts[0].tags[0])
def test_select_related_follows_embedded_referencefields(self):
class Playlist(Document):
items = ListField(EmbeddedDocumentField("PlaylistItem"))
class PlaylistItem(EmbeddedDocument):
song = ReferenceField("Song")
class Song(Document):
title = StringField()
Playlist.drop_collection()
Song.drop_collection()
songs = [Song.objects.create(title="song %d" % i) for i in range(3)]
items = [PlaylistItem(song=song) for song in songs]
playlist = Playlist.objects.create(items=items)
with query_counter() as q:
self.assertEqual(q, 0)
playlist = Playlist.objects.first().select_related()
songs = [item.song for item in playlist.items]
self.assertEqual(q, 2)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()