Merge pull request #1923 from bagerard/update_document_ids_during_bulk_insert

Update the ids of the given documents during bulk insert
This commit is contained in:
erdenezul 2018-10-17 19:03:11 +08:00 committed by GitHub
commit 0afd5a40d6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 99 additions and 52 deletions

View File

@ -5,6 +5,7 @@ Changelog
Development
===========
- QuerySet limit function behaviour: Passing 0 as parameter will return all the documents in the cursor #1611
- bulk insert updates the ids of the input documents #1919
- (Fill this out as you fix issues and develop your features).
=======
Changes in 0.15.4

View File

@ -372,14 +372,17 @@ class BaseQuerySet(object):
raise NotUniqueError(message % six.text_type(err))
raise OperationError(message % six.text_type(err))
# Apply inserted_ids to documents
for doc, doc_id in zip(docs, ids):
doc.pk = doc_id
if not load_bulk:
signals.post_bulk_insert.send(
self._document, documents=docs, loaded=False, **signal_kwargs)
return ids[0] if return_one else ids
documents = self.in_bulk(ids)
results = []
for obj_id in ids:
results.append(documents.get(obj_id))
results = [documents.get(obj_id) for obj_id in ids]
signals.post_bulk_insert.send(
self._document, documents=results, loaded=True, **signal_kwargs)
return results[0] if return_one else results

View File

@ -839,10 +839,10 @@ class QuerySetTest(unittest.TestCase):
comment2 = Comment(name='testb')
post1 = Post(comments=[comment1, comment2])
post2 = Post(comments=[comment2, comment2])
blogs = [Blog(title="post %s" % i, posts=[post1, post2])
for i in range(99)]
# Check bulk insert using load_bulk=False
blogs = [Blog(title="%s" % i, posts=[post1, post2])
for i in range(99)]
with query_counter() as q:
self.assertEqual(q, 0)
Blog.objects.insert(blogs, load_bulk=False)
@ -858,6 +858,8 @@ class QuerySetTest(unittest.TestCase):
Blog.ensure_indexes()
# Check bulk insert using load_bulk=True
blogs = [Blog(title="%s" % i, posts=[post1, post2])
for i in range(99)]
with query_counter() as q:
self.assertEqual(q, 0)
Blog.objects.insert(blogs)
@ -882,30 +884,21 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(Blog.objects.count(), 2)
# test inserting an existing document (shouldn't be allowed)
with self.assertRaises(OperationError):
with self.assertRaises(OperationError) as cm:
blog = Blog.objects.first()
Blog.objects.insert(blog)
self.assertEqual(str(cm.exception), 'Some documents have ObjectIds use doc.update() instead')
# test inserting a query set
with self.assertRaises(OperationError):
blogs = Blog.objects
Blog.objects.insert(blogs)
with self.assertRaises(OperationError) as cm:
blogs_qs = Blog.objects
Blog.objects.insert(blogs_qs)
self.assertEqual(str(cm.exception), 'Some documents have ObjectIds use doc.update() instead')
# insert a new doc
# insert 1 new doc
new_post = Blog(title="code123", id=ObjectId())
Blog.objects.insert(new_post)
class Author(Document):
pass
# try inserting a different document class
with self.assertRaises(OperationError):
Blog.objects.insert(Author())
# try inserting a non-document
with self.assertRaises(OperationError):
Blog.objects.insert("HELLO WORLD")
Blog.drop_collection()
blog1 = Blog(title="code", posts=[post1, post2])
@ -916,20 +909,70 @@ class QuerySetTest(unittest.TestCase):
Blog.drop_collection()
blog1 = Blog(title="code", posts=[post1, post2])
obj_id = Blog.objects.insert(blog1, load_bulk=False)
self.assertEqual(obj_id.__class__.__name__, 'ObjectId')
self.assertIsInstance(obj_id, ObjectId)
Blog.drop_collection()
post3 = Post(comments=[comment1, comment1])
blog1 = Blog(title="foo", posts=[post1, post2])
blog2 = Blog(title="bar", posts=[post2, post3])
blog3 = Blog(title="baz", posts=[post1, post2])
Blog.objects.insert([blog1, blog2])
with self.assertRaises(NotUniqueError):
Blog.objects.insert([blog2, blog3])
Blog.objects.insert(Blog(title=blog2.title))
self.assertEqual(Blog.objects.count(), 2)
def test_bulk_insert_different_class_fails(self):
class Blog(Document):
pass
class Author(Document):
pass
# try inserting a different document class
with self.assertRaises(OperationError):
Blog.objects.insert(Author())
def test_bulk_insert_with_wrong_type(self):
class Blog(Document):
name = StringField()
Blog.drop_collection()
Blog(name='test').save()
with self.assertRaises(OperationError):
Blog.objects.insert("HELLO WORLD")
with self.assertRaises(OperationError):
Blog.objects.insert({'name': 'garbage'})
def test_bulk_insert_update_input_document_ids(self):
class Comment(Document):
idx = IntField()
Comment.drop_collection()
# Test with bulk
comments = [Comment(idx=idx) for idx in range(20)]
for com in comments:
self.assertIsNone(com.id)
returned_comments = Comment.objects.insert(comments, load_bulk=True)
for com in comments:
self.assertIsInstance(com.id, ObjectId)
input_mapping = {com.id: com.idx for com in comments}
saved_mapping = {com.id: com.idx for com in returned_comments}
self.assertEqual(input_mapping, saved_mapping)
Comment.drop_collection()
# Test with just one
comment = Comment(idx=0)
inserted_comment_id = Comment.objects.insert(comment, load_bulk=False)
self.assertEqual(comment.id, inserted_comment_id)
def test_get_changed_fields_query_count(self):
"""Make sure we don't perform unnecessary db operations when
none of document's fields were updated.
@ -2286,7 +2329,7 @@ class QuerySetTest(unittest.TestCase):
class User(Document):
username = StringField()
bar = GenericEmbeddedDocumentField(choices=[Bar,])
bar = GenericEmbeddedDocumentField(choices=[Bar])
User.drop_collection()