Merge pull request #1923 from bagerard/update_document_ids_during_bulk_insert

Update the ids of the given documents during bulk insert
This commit is contained in:
erdenezul 2018-10-17 19:03:11 +08:00 committed by GitHub
commit 0afd5a40d6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 99 additions and 52 deletions

View File

@ -5,6 +5,7 @@ Changelog
Development
===========
- QuerySet limit function behaviour: Passing 0 as parameter will return all the documents in the cursor #1611
- bulk insert updates the ids of the input documents #1919
- (Fill this out as you fix issues and develop your features).
=======
Changes in 0.15.4

View File

@ -372,14 +372,17 @@ class BaseQuerySet(object):
raise NotUniqueError(message % six.text_type(err))
raise OperationError(message % six.text_type(err))
# Apply inserted_ids to documents
for doc, doc_id in zip(docs, ids):
doc.pk = doc_id
if not load_bulk:
signals.post_bulk_insert.send(
self._document, documents=docs, loaded=False, **signal_kwargs)
return ids[0] if return_one else ids
documents = self.in_bulk(ids)
results = []
for obj_id in ids:
results.append(documents.get(obj_id))
results = [documents.get(obj_id) for obj_id in ids]
signals.post_bulk_insert.send(
self._document, documents=results, loaded=True, **signal_kwargs)
return results[0] if return_one else results

View File

@ -704,38 +704,38 @@ class QuerySetTest(unittest.TestCase):
self.assertRaises(ValidationError, Doc.objects().update, ed_f__str_f=1, upsert=True)
def test_update_related_models(self):
class TestPerson(Document):
name = StringField()
class TestPerson(Document):
name = StringField()
class TestOrganization(Document):
name = StringField()
owner = ReferenceField(TestPerson)
class TestOrganization(Document):
name = StringField()
owner = ReferenceField(TestPerson)
TestPerson.drop_collection()
TestOrganization.drop_collection()
TestPerson.drop_collection()
TestOrganization.drop_collection()
p = TestPerson(name='p1')
p.save()
o = TestOrganization(name='o1')
o.save()
p = TestPerson(name='p1')
p.save()
o = TestOrganization(name='o1')
o.save()
o.owner = p
p.name = 'p2'
o.owner = p
p.name = 'p2'
self.assertEqual(o._get_changed_fields(), ['owner'])
self.assertEqual(p._get_changed_fields(), ['name'])
self.assertEqual(o._get_changed_fields(), ['owner'])
self.assertEqual(p._get_changed_fields(), ['name'])
o.save()
o.save()
self.assertEqual(o._get_changed_fields(), [])
self.assertEqual(p._get_changed_fields(), ['name']) # Fails; it's empty
self.assertEqual(o._get_changed_fields(), [])
self.assertEqual(p._get_changed_fields(), ['name']) # Fails; it's empty
# This will do NOTHING at all, even though we changed the name
p.save()
# This will do NOTHING at all, even though we changed the name
p.save()
p.reload()
p.reload()
self.assertEqual(p.name, 'p2') # Fails; it's still `p1`
self.assertEqual(p.name, 'p2') # Fails; it's still `p1`
def test_upsert(self):
self.Person.drop_collection()
@ -839,18 +839,18 @@ class QuerySetTest(unittest.TestCase):
comment2 = Comment(name='testb')
post1 = Post(comments=[comment1, comment2])
post2 = Post(comments=[comment2, comment2])
blogs = [Blog(title="post %s" % i, posts=[post1, post2])
for i in range(99)]
# Check bulk insert using load_bulk=False
blogs = [Blog(title="%s" % i, posts=[post1, post2])
for i in range(99)]
with query_counter() as q:
self.assertEqual(q, 0)
Blog.objects.insert(blogs, load_bulk=False)
if MONGO_VER == MONGODB_32:
self.assertEqual(q, 1) # 1 entry containing the list of inserts
self.assertEqual(q, 1) # 1 entry containing the list of inserts
else:
self.assertEqual(q, len(blogs)) # 1 entry per doc inserted
self.assertEqual(q, len(blogs)) # 1 entry per doc inserted
self.assertEqual(Blog.objects.count(), len(blogs))
@ -858,14 +858,16 @@ class QuerySetTest(unittest.TestCase):
Blog.ensure_indexes()
# Check bulk insert using load_bulk=True
blogs = [Blog(title="%s" % i, posts=[post1, post2])
for i in range(99)]
with query_counter() as q:
self.assertEqual(q, 0)
Blog.objects.insert(blogs)
if MONGO_VER == MONGODB_32:
self.assertEqual(q, 2) # 1 for insert 1 for fetch
self.assertEqual(q, 2) # 1 for insert 1 for fetch
else:
self.assertEqual(q, len(blogs)+1) # + 1 to fetch all docs
self.assertEqual(q, len(blogs)+1) # + 1 to fetch all docs
Blog.drop_collection()
@ -882,30 +884,21 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(Blog.objects.count(), 2)
# test inserting an existing document (shouldn't be allowed)
with self.assertRaises(OperationError):
with self.assertRaises(OperationError) as cm:
blog = Blog.objects.first()
Blog.objects.insert(blog)
self.assertEqual(str(cm.exception), 'Some documents have ObjectIds use doc.update() instead')
# test inserting a query set
with self.assertRaises(OperationError):
blogs = Blog.objects
Blog.objects.insert(blogs)
with self.assertRaises(OperationError) as cm:
blogs_qs = Blog.objects
Blog.objects.insert(blogs_qs)
self.assertEqual(str(cm.exception), 'Some documents have ObjectIds use doc.update() instead')
# insert a new doc
# insert 1 new doc
new_post = Blog(title="code123", id=ObjectId())
Blog.objects.insert(new_post)
class Author(Document):
pass
# try inserting a different document class
with self.assertRaises(OperationError):
Blog.objects.insert(Author())
# try inserting a non-document
with self.assertRaises(OperationError):
Blog.objects.insert("HELLO WORLD")
Blog.drop_collection()
blog1 = Blog(title="code", posts=[post1, post2])
@ -916,20 +909,70 @@ class QuerySetTest(unittest.TestCase):
Blog.drop_collection()
blog1 = Blog(title="code", posts=[post1, post2])
obj_id = Blog.objects.insert(blog1, load_bulk=False)
self.assertEqual(obj_id.__class__.__name__, 'ObjectId')
self.assertIsInstance(obj_id, ObjectId)
Blog.drop_collection()
post3 = Post(comments=[comment1, comment1])
blog1 = Blog(title="foo", posts=[post1, post2])
blog2 = Blog(title="bar", posts=[post2, post3])
blog3 = Blog(title="baz", posts=[post1, post2])
Blog.objects.insert([blog1, blog2])
with self.assertRaises(NotUniqueError):
Blog.objects.insert([blog2, blog3])
Blog.objects.insert(Blog(title=blog2.title))
self.assertEqual(Blog.objects.count(), 2)
def test_bulk_insert_different_class_fails(self):
class Blog(Document):
pass
class Author(Document):
pass
# try inserting a different document class
with self.assertRaises(OperationError):
Blog.objects.insert(Author())
def test_bulk_insert_with_wrong_type(self):
class Blog(Document):
name = StringField()
Blog.drop_collection()
Blog(name='test').save()
with self.assertRaises(OperationError):
Blog.objects.insert("HELLO WORLD")
with self.assertRaises(OperationError):
Blog.objects.insert({'name': 'garbage'})
def test_bulk_insert_update_input_document_ids(self):
class Comment(Document):
idx = IntField()
Comment.drop_collection()
# Test with bulk
comments = [Comment(idx=idx) for idx in range(20)]
for com in comments:
self.assertIsNone(com.id)
returned_comments = Comment.objects.insert(comments, load_bulk=True)
for com in comments:
self.assertIsInstance(com.id, ObjectId)
input_mapping = {com.id: com.idx for com in comments}
saved_mapping = {com.id: com.idx for com in returned_comments}
self.assertEqual(input_mapping, saved_mapping)
Comment.drop_collection()
# Test with just one
comment = Comment(idx=0)
inserted_comment_id = Comment.objects.insert(comment, load_bulk=False)
self.assertEqual(comment.id, inserted_comment_id)
def test_get_changed_fields_query_count(self):
"""Make sure we don't perform unnecessary db operations when
none of document's fields were updated.
@ -2007,7 +2050,7 @@ class QuerySetTest(unittest.TestCase):
post.reload()
self.assertEqual(post.tags, ['mongodb', 'python', 'java'])
#test push with singular value
# test push with singular value
BlogPost.objects.filter(id=post.id).update(push__tags__0='scala')
post.reload()
self.assertEqual(post.tags, ['scala', 'mongodb', 'python', 'java'])
@ -2286,7 +2329,7 @@ class QuerySetTest(unittest.TestCase):
class User(Document):
username = StringField()
bar = GenericEmbeddedDocumentField(choices=[Bar,])
bar = GenericEmbeddedDocumentField(choices=[Bar])
User.drop_collection()