diff --git a/docs/changelog.rst b/docs/changelog.rst index edc0fb1a..69581a4b 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -5,6 +5,7 @@ Changelog Development =========== - QuerySet limit function behaviour: Passing 0 as parameter will return all the documents in the cursor #1611 +- bulk insert updates the ids of the input documents #1919 - (Fill this out as you fix issues and develop your features). ======= Changes in 0.15.4 diff --git a/mongoengine/queryset/base.py b/mongoengine/queryset/base.py index d3a5f050..9ea214cc 100644 --- a/mongoengine/queryset/base.py +++ b/mongoengine/queryset/base.py @@ -372,14 +372,17 @@ class BaseQuerySet(object): raise NotUniqueError(message % six.text_type(err)) raise OperationError(message % six.text_type(err)) + # Apply inserted_ids to documents + for doc, doc_id in zip(docs, ids): + doc.pk = doc_id + if not load_bulk: signals.post_bulk_insert.send( self._document, documents=docs, loaded=False, **signal_kwargs) return ids[0] if return_one else ids + documents = self.in_bulk(ids) - results = [] - for obj_id in ids: - results.append(documents.get(obj_id)) + results = [documents.get(obj_id) for obj_id in ids] signals.post_bulk_insert.send( self._document, documents=results, loaded=True, **signal_kwargs) return results[0] if return_one else results diff --git a/tests/queryset/queryset.py b/tests/queryset/queryset.py index c5004ed2..f4bc1dcd 100644 --- a/tests/queryset/queryset.py +++ b/tests/queryset/queryset.py @@ -704,38 +704,38 @@ class QuerySetTest(unittest.TestCase): self.assertRaises(ValidationError, Doc.objects().update, ed_f__str_f=1, upsert=True) def test_update_related_models(self): - class TestPerson(Document): - name = StringField() + class TestPerson(Document): + name = StringField() - class TestOrganization(Document): - name = StringField() - owner = ReferenceField(TestPerson) + class TestOrganization(Document): + name = StringField() + owner = ReferenceField(TestPerson) - TestPerson.drop_collection() - TestOrganization.drop_collection() + TestPerson.drop_collection() + TestOrganization.drop_collection() - p = TestPerson(name='p1') - p.save() - o = TestOrganization(name='o1') - o.save() + p = TestPerson(name='p1') + p.save() + o = TestOrganization(name='o1') + o.save() - o.owner = p - p.name = 'p2' + o.owner = p + p.name = 'p2' - self.assertEqual(o._get_changed_fields(), ['owner']) - self.assertEqual(p._get_changed_fields(), ['name']) + self.assertEqual(o._get_changed_fields(), ['owner']) + self.assertEqual(p._get_changed_fields(), ['name']) - o.save() + o.save() - self.assertEqual(o._get_changed_fields(), []) - self.assertEqual(p._get_changed_fields(), ['name']) # Fails; it's empty + self.assertEqual(o._get_changed_fields(), []) + self.assertEqual(p._get_changed_fields(), ['name']) # Fails; it's empty - # This will do NOTHING at all, even though we changed the name - p.save() + # This will do NOTHING at all, even though we changed the name + p.save() - p.reload() + p.reload() - self.assertEqual(p.name, 'p2') # Fails; it's still `p1` + self.assertEqual(p.name, 'p2') # Fails; it's still `p1` def test_upsert(self): self.Person.drop_collection() @@ -839,18 +839,18 @@ class QuerySetTest(unittest.TestCase): comment2 = Comment(name='testb') post1 = Post(comments=[comment1, comment2]) post2 = Post(comments=[comment2, comment2]) - blogs = [Blog(title="post %s" % i, posts=[post1, post2]) - for i in range(99)] # Check bulk insert using load_bulk=False + blogs = [Blog(title="%s" % i, posts=[post1, post2]) + for i in range(99)] with query_counter() as q: self.assertEqual(q, 0) Blog.objects.insert(blogs, load_bulk=False) if MONGO_VER == MONGODB_32: - self.assertEqual(q, 1) # 1 entry containing the list of inserts + self.assertEqual(q, 1) # 1 entry containing the list of inserts else: - self.assertEqual(q, len(blogs)) # 1 entry per doc inserted + self.assertEqual(q, len(blogs)) # 1 entry per doc inserted self.assertEqual(Blog.objects.count(), len(blogs)) @@ -858,14 +858,16 @@ class QuerySetTest(unittest.TestCase): Blog.ensure_indexes() # Check bulk insert using load_bulk=True + blogs = [Blog(title="%s" % i, posts=[post1, post2]) + for i in range(99)] with query_counter() as q: self.assertEqual(q, 0) Blog.objects.insert(blogs) if MONGO_VER == MONGODB_32: - self.assertEqual(q, 2) # 1 for insert 1 for fetch + self.assertEqual(q, 2) # 1 for insert 1 for fetch else: - self.assertEqual(q, len(blogs)+1) # + 1 to fetch all docs + self.assertEqual(q, len(blogs)+1) # + 1 to fetch all docs Blog.drop_collection() @@ -882,30 +884,21 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(Blog.objects.count(), 2) # test inserting an existing document (shouldn't be allowed) - with self.assertRaises(OperationError): + with self.assertRaises(OperationError) as cm: blog = Blog.objects.first() Blog.objects.insert(blog) + self.assertEqual(str(cm.exception), 'Some documents have ObjectIds use doc.update() instead') # test inserting a query set - with self.assertRaises(OperationError): - blogs = Blog.objects - Blog.objects.insert(blogs) + with self.assertRaises(OperationError) as cm: + blogs_qs = Blog.objects + Blog.objects.insert(blogs_qs) + self.assertEqual(str(cm.exception), 'Some documents have ObjectIds use doc.update() instead') - # insert a new doc + # insert 1 new doc new_post = Blog(title="code123", id=ObjectId()) Blog.objects.insert(new_post) - class Author(Document): - pass - - # try inserting a different document class - with self.assertRaises(OperationError): - Blog.objects.insert(Author()) - - # try inserting a non-document - with self.assertRaises(OperationError): - Blog.objects.insert("HELLO WORLD") - Blog.drop_collection() blog1 = Blog(title="code", posts=[post1, post2]) @@ -916,20 +909,70 @@ class QuerySetTest(unittest.TestCase): Blog.drop_collection() blog1 = Blog(title="code", posts=[post1, post2]) obj_id = Blog.objects.insert(blog1, load_bulk=False) - self.assertEqual(obj_id.__class__.__name__, 'ObjectId') + self.assertIsInstance(obj_id, ObjectId) Blog.drop_collection() post3 = Post(comments=[comment1, comment1]) blog1 = Blog(title="foo", posts=[post1, post2]) blog2 = Blog(title="bar", posts=[post2, post3]) - blog3 = Blog(title="baz", posts=[post1, post2]) Blog.objects.insert([blog1, blog2]) with self.assertRaises(NotUniqueError): - Blog.objects.insert([blog2, blog3]) + Blog.objects.insert(Blog(title=blog2.title)) self.assertEqual(Blog.objects.count(), 2) + def test_bulk_insert_different_class_fails(self): + class Blog(Document): + pass + + class Author(Document): + pass + + # try inserting a different document class + with self.assertRaises(OperationError): + Blog.objects.insert(Author()) + + def test_bulk_insert_with_wrong_type(self): + class Blog(Document): + name = StringField() + + Blog.drop_collection() + Blog(name='test').save() + + with self.assertRaises(OperationError): + Blog.objects.insert("HELLO WORLD") + + with self.assertRaises(OperationError): + Blog.objects.insert({'name': 'garbage'}) + + def test_bulk_insert_update_input_document_ids(self): + class Comment(Document): + idx = IntField() + + Comment.drop_collection() + + # Test with bulk + comments = [Comment(idx=idx) for idx in range(20)] + for com in comments: + self.assertIsNone(com.id) + + returned_comments = Comment.objects.insert(comments, load_bulk=True) + + for com in comments: + self.assertIsInstance(com.id, ObjectId) + + input_mapping = {com.id: com.idx for com in comments} + saved_mapping = {com.id: com.idx for com in returned_comments} + self.assertEqual(input_mapping, saved_mapping) + + Comment.drop_collection() + + # Test with just one + comment = Comment(idx=0) + inserted_comment_id = Comment.objects.insert(comment, load_bulk=False) + self.assertEqual(comment.id, inserted_comment_id) + def test_get_changed_fields_query_count(self): """Make sure we don't perform unnecessary db operations when none of document's fields were updated. @@ -2007,7 +2050,7 @@ class QuerySetTest(unittest.TestCase): post.reload() self.assertEqual(post.tags, ['mongodb', 'python', 'java']) - #test push with singular value + # test push with singular value BlogPost.objects.filter(id=post.id).update(push__tags__0='scala') post.reload() self.assertEqual(post.tags, ['scala', 'mongodb', 'python', 'java']) @@ -2286,7 +2329,7 @@ class QuerySetTest(unittest.TestCase): class User(Document): username = StringField() - bar = GenericEmbeddedDocumentField(choices=[Bar,]) + bar = GenericEmbeddedDocumentField(choices=[Bar]) User.drop_collection()