diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index 4c57bd7b..b7120be5 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -362,6 +362,21 @@ class QuerySet(object): result = self._document._from_son(result) return result + def in_bulk(self, object_ids): + """Retrieve a set of documents by their ids. + + :param object_ids: a list or tuple of ``ObjectId``s + :rtype: dict of ObjectIds as keys and collection-specific + Document subclasses as values. + """ + doc_map = {} + + docs = self._collection.find({'_id': {'$in': object_ids}}) + for doc in docs: + doc_map[doc['_id']] = self._document._from_son(doc) + + return doc_map + def next(self): """Wrap the result in a :class:`~mongoengine.Document` object. """ diff --git a/tests/queryset.py b/tests/queryset.py index 02f53f33..ecf7ada1 100644 --- a/tests/queryset.py +++ b/tests/queryset.py @@ -645,6 +645,41 @@ class QuerySetTest(unittest.TestCase): self.assertFalse([('_types', 1)] in info.values()) BlogPost.drop_collection() + + def test_bulk(self): + """Ensure bulk querying by object id returns a proper dict. + """ + class BlogPost(Document): + title = StringField() + + BlogPost.drop_collection() + + post_1 = BlogPost(title="Post #1") + post_2 = BlogPost(title="Post #2") + post_3 = BlogPost(title="Post #3") + post_4 = BlogPost(title="Post #4") + post_5 = BlogPost(title="Post #5") + + post_1.save() + post_2.save() + post_3.save() + post_4.save() + post_5.save() + + ids = [post_1.id, post_2.id, post_5.id] + objects = BlogPost.objects.in_bulk(ids) + + self.assertEqual(len(objects), 3) + + self.assertTrue(post_1.id in objects) + self.assertTrue(post_2.id in objects) + self.assertTrue(post_5.id in objects) + + self.assertTrue(objects[post_1.id].title == post_1.title) + self.assertTrue(objects[post_2.id].title == post_2.title) + self.assertTrue(objects[post_3.id].title == post_3.title) + + BlogPost.drop_collection() def tearDown(self): self.Person.drop_collection()