Merge pull request #1923 from bagerard/update_document_ids_during_bulk_insert

Update the ids of the given documents during bulk insert
This commit is contained in:
erdenezul 2018-10-17 19:03:11 +08:00 committed by GitHub
commit 0afd5a40d6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 99 additions and 52 deletions

View File

@ -5,6 +5,7 @@ Changelog
Development Development
=========== ===========
- QuerySet limit function behaviour: Passing 0 as parameter will return all the documents in the cursor #1611 - QuerySet limit function behaviour: Passing 0 as parameter will return all the documents in the cursor #1611
- bulk insert updates the ids of the input documents #1919
- (Fill this out as you fix issues and develop your features). - (Fill this out as you fix issues and develop your features).
======= =======
Changes in 0.15.4 Changes in 0.15.4

View File

@ -372,14 +372,17 @@ class BaseQuerySet(object):
raise NotUniqueError(message % six.text_type(err)) raise NotUniqueError(message % six.text_type(err))
raise OperationError(message % six.text_type(err)) raise OperationError(message % six.text_type(err))
# Apply inserted_ids to documents
for doc, doc_id in zip(docs, ids):
doc.pk = doc_id
if not load_bulk: if not load_bulk:
signals.post_bulk_insert.send( signals.post_bulk_insert.send(
self._document, documents=docs, loaded=False, **signal_kwargs) self._document, documents=docs, loaded=False, **signal_kwargs)
return ids[0] if return_one else ids return ids[0] if return_one else ids
documents = self.in_bulk(ids) documents = self.in_bulk(ids)
results = [] results = [documents.get(obj_id) for obj_id in ids]
for obj_id in ids:
results.append(documents.get(obj_id))
signals.post_bulk_insert.send( signals.post_bulk_insert.send(
self._document, documents=results, loaded=True, **signal_kwargs) self._document, documents=results, loaded=True, **signal_kwargs)
return results[0] if return_one else results return results[0] if return_one else results

View File

@ -704,38 +704,38 @@ class QuerySetTest(unittest.TestCase):
self.assertRaises(ValidationError, Doc.objects().update, ed_f__str_f=1, upsert=True) self.assertRaises(ValidationError, Doc.objects().update, ed_f__str_f=1, upsert=True)
def test_update_related_models(self): def test_update_related_models(self):
class TestPerson(Document): class TestPerson(Document):
name = StringField() name = StringField()
class TestOrganization(Document): class TestOrganization(Document):
name = StringField() name = StringField()
owner = ReferenceField(TestPerson) owner = ReferenceField(TestPerson)
TestPerson.drop_collection() TestPerson.drop_collection()
TestOrganization.drop_collection() TestOrganization.drop_collection()
p = TestPerson(name='p1') p = TestPerson(name='p1')
p.save() p.save()
o = TestOrganization(name='o1') o = TestOrganization(name='o1')
o.save() o.save()
o.owner = p o.owner = p
p.name = 'p2' p.name = 'p2'
self.assertEqual(o._get_changed_fields(), ['owner']) self.assertEqual(o._get_changed_fields(), ['owner'])
self.assertEqual(p._get_changed_fields(), ['name']) self.assertEqual(p._get_changed_fields(), ['name'])
o.save() o.save()
self.assertEqual(o._get_changed_fields(), []) self.assertEqual(o._get_changed_fields(), [])
self.assertEqual(p._get_changed_fields(), ['name']) # Fails; it's empty self.assertEqual(p._get_changed_fields(), ['name']) # Fails; it's empty
# This will do NOTHING at all, even though we changed the name # This will do NOTHING at all, even though we changed the name
p.save() p.save()
p.reload() p.reload()
self.assertEqual(p.name, 'p2') # Fails; it's still `p1` self.assertEqual(p.name, 'p2') # Fails; it's still `p1`
def test_upsert(self): def test_upsert(self):
self.Person.drop_collection() self.Person.drop_collection()
@ -839,18 +839,18 @@ class QuerySetTest(unittest.TestCase):
comment2 = Comment(name='testb') comment2 = Comment(name='testb')
post1 = Post(comments=[comment1, comment2]) post1 = Post(comments=[comment1, comment2])
post2 = Post(comments=[comment2, comment2]) post2 = Post(comments=[comment2, comment2])
blogs = [Blog(title="post %s" % i, posts=[post1, post2])
for i in range(99)]
# Check bulk insert using load_bulk=False # Check bulk insert using load_bulk=False
blogs = [Blog(title="%s" % i, posts=[post1, post2])
for i in range(99)]
with query_counter() as q: with query_counter() as q:
self.assertEqual(q, 0) self.assertEqual(q, 0)
Blog.objects.insert(blogs, load_bulk=False) Blog.objects.insert(blogs, load_bulk=False)
if MONGO_VER == MONGODB_32: if MONGO_VER == MONGODB_32:
self.assertEqual(q, 1) # 1 entry containing the list of inserts self.assertEqual(q, 1) # 1 entry containing the list of inserts
else: else:
self.assertEqual(q, len(blogs)) # 1 entry per doc inserted self.assertEqual(q, len(blogs)) # 1 entry per doc inserted
self.assertEqual(Blog.objects.count(), len(blogs)) self.assertEqual(Blog.objects.count(), len(blogs))
@ -858,14 +858,16 @@ class QuerySetTest(unittest.TestCase):
Blog.ensure_indexes() Blog.ensure_indexes()
# Check bulk insert using load_bulk=True # Check bulk insert using load_bulk=True
blogs = [Blog(title="%s" % i, posts=[post1, post2])
for i in range(99)]
with query_counter() as q: with query_counter() as q:
self.assertEqual(q, 0) self.assertEqual(q, 0)
Blog.objects.insert(blogs) Blog.objects.insert(blogs)
if MONGO_VER == MONGODB_32: if MONGO_VER == MONGODB_32:
self.assertEqual(q, 2) # 1 for insert 1 for fetch self.assertEqual(q, 2) # 1 for insert 1 for fetch
else: else:
self.assertEqual(q, len(blogs)+1) # + 1 to fetch all docs self.assertEqual(q, len(blogs)+1) # + 1 to fetch all docs
Blog.drop_collection() Blog.drop_collection()
@ -882,30 +884,21 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(Blog.objects.count(), 2) self.assertEqual(Blog.objects.count(), 2)
# test inserting an existing document (shouldn't be allowed) # test inserting an existing document (shouldn't be allowed)
with self.assertRaises(OperationError): with self.assertRaises(OperationError) as cm:
blog = Blog.objects.first() blog = Blog.objects.first()
Blog.objects.insert(blog) Blog.objects.insert(blog)
self.assertEqual(str(cm.exception), 'Some documents have ObjectIds use doc.update() instead')
# test inserting a query set # test inserting a query set
with self.assertRaises(OperationError): with self.assertRaises(OperationError) as cm:
blogs = Blog.objects blogs_qs = Blog.objects
Blog.objects.insert(blogs) Blog.objects.insert(blogs_qs)
self.assertEqual(str(cm.exception), 'Some documents have ObjectIds use doc.update() instead')
# insert a new doc # insert 1 new doc
new_post = Blog(title="code123", id=ObjectId()) new_post = Blog(title="code123", id=ObjectId())
Blog.objects.insert(new_post) Blog.objects.insert(new_post)
class Author(Document):
pass
# try inserting a different document class
with self.assertRaises(OperationError):
Blog.objects.insert(Author())
# try inserting a non-document
with self.assertRaises(OperationError):
Blog.objects.insert("HELLO WORLD")
Blog.drop_collection() Blog.drop_collection()
blog1 = Blog(title="code", posts=[post1, post2]) blog1 = Blog(title="code", posts=[post1, post2])
@ -916,20 +909,70 @@ class QuerySetTest(unittest.TestCase):
Blog.drop_collection() Blog.drop_collection()
blog1 = Blog(title="code", posts=[post1, post2]) blog1 = Blog(title="code", posts=[post1, post2])
obj_id = Blog.objects.insert(blog1, load_bulk=False) obj_id = Blog.objects.insert(blog1, load_bulk=False)
self.assertEqual(obj_id.__class__.__name__, 'ObjectId') self.assertIsInstance(obj_id, ObjectId)
Blog.drop_collection() Blog.drop_collection()
post3 = Post(comments=[comment1, comment1]) post3 = Post(comments=[comment1, comment1])
blog1 = Blog(title="foo", posts=[post1, post2]) blog1 = Blog(title="foo", posts=[post1, post2])
blog2 = Blog(title="bar", posts=[post2, post3]) blog2 = Blog(title="bar", posts=[post2, post3])
blog3 = Blog(title="baz", posts=[post1, post2])
Blog.objects.insert([blog1, blog2]) Blog.objects.insert([blog1, blog2])
with self.assertRaises(NotUniqueError): with self.assertRaises(NotUniqueError):
Blog.objects.insert([blog2, blog3]) Blog.objects.insert(Blog(title=blog2.title))
self.assertEqual(Blog.objects.count(), 2) self.assertEqual(Blog.objects.count(), 2)
def test_bulk_insert_different_class_fails(self):
class Blog(Document):
pass
class Author(Document):
pass
# try inserting a different document class
with self.assertRaises(OperationError):
Blog.objects.insert(Author())
def test_bulk_insert_with_wrong_type(self):
class Blog(Document):
name = StringField()
Blog.drop_collection()
Blog(name='test').save()
with self.assertRaises(OperationError):
Blog.objects.insert("HELLO WORLD")
with self.assertRaises(OperationError):
Blog.objects.insert({'name': 'garbage'})
def test_bulk_insert_update_input_document_ids(self):
class Comment(Document):
idx = IntField()
Comment.drop_collection()
# Test with bulk
comments = [Comment(idx=idx) for idx in range(20)]
for com in comments:
self.assertIsNone(com.id)
returned_comments = Comment.objects.insert(comments, load_bulk=True)
for com in comments:
self.assertIsInstance(com.id, ObjectId)
input_mapping = {com.id: com.idx for com in comments}
saved_mapping = {com.id: com.idx for com in returned_comments}
self.assertEqual(input_mapping, saved_mapping)
Comment.drop_collection()
# Test with just one
comment = Comment(idx=0)
inserted_comment_id = Comment.objects.insert(comment, load_bulk=False)
self.assertEqual(comment.id, inserted_comment_id)
def test_get_changed_fields_query_count(self): def test_get_changed_fields_query_count(self):
"""Make sure we don't perform unnecessary db operations when """Make sure we don't perform unnecessary db operations when
none of document's fields were updated. none of document's fields were updated.
@ -2007,7 +2050,7 @@ class QuerySetTest(unittest.TestCase):
post.reload() post.reload()
self.assertEqual(post.tags, ['mongodb', 'python', 'java']) self.assertEqual(post.tags, ['mongodb', 'python', 'java'])
#test push with singular value # test push with singular value
BlogPost.objects.filter(id=post.id).update(push__tags__0='scala') BlogPost.objects.filter(id=post.id).update(push__tags__0='scala')
post.reload() post.reload()
self.assertEqual(post.tags, ['scala', 'mongodb', 'python', 'java']) self.assertEqual(post.tags, ['scala', 'mongodb', 'python', 'java'])
@ -2286,7 +2329,7 @@ class QuerySetTest(unittest.TestCase):
class User(Document): class User(Document):
username = StringField() username = StringField()
bar = GenericEmbeddedDocumentField(choices=[Bar,]) bar = GenericEmbeddedDocumentField(choices=[Bar])
User.drop_collection() User.drop_collection()