Compare commits

...

1 Commits

Author SHA1 Message Date
Stefan Wojcik
98e1df0c45 Add continue_on_error optional kwarg to QuerySet.insert 2017-01-14 23:04:55 -05:00
2 changed files with 54 additions and 19 deletions

View File

@ -296,22 +296,25 @@ class BaseQuerySet(object):
result = None
return result
def insert(self, doc_or_docs, load_bulk=True,
write_concern=None, signal_kwargs=None):
def insert(self, doc_or_docs, load_bulk=True, write_concern=None,
signal_kwargs=None, continue_on_error=None):
"""bulk insert documents
:param doc_or_docs: a document or list of documents to be inserted
:param load_bulk (optional): If True returns the list of document
instances
:param write_concern: Extra keyword arguments are passed down to
:meth:`~pymongo.collection.Collection.insert`
which will be used as options for the resultant
``getLastError`` command. For example,
``insert(..., {w: 2, fsync: True})`` will wait until at least
two servers have recorded the write and will force an fsync on
each server being written to.
:param write_concern: Optional keyword argument passed down to
:meth:`~pymongo.collection.Collection.insert`, representing
the write concern. For example,
``insert(..., write_concert={w: 2, fsync: True})`` will
wait until at least two servers have recorded the write
and will force an fsync on each server being written to.
:parm signal_kwargs: (optional) kwargs dictionary to be passed to
the signal calls.
:param continue_on_error: Optional keyword argument passed down to
:meth:`~pymongo.collection.Collection.insert`. Defines what
to do when a document cannot be inserted (e.g. due to
duplicate IDs). Read PyMongo's docs for more info.
By default returns document instances, set ``load_bulk`` to False to
return just ``ObjectIds``
@ -322,12 +325,10 @@ class BaseQuerySet(object):
"""
Document = _import_class('Document')
if write_concern is None:
write_concern = {}
# Determine if we're inserting one doc or more
docs = doc_or_docs
return_one = False
if isinstance(docs, Document) or issubclass(docs.__class__, Document):
if isinstance(docs, Document):
return_one = True
docs = [docs]
@ -344,9 +345,16 @@ class BaseQuerySet(object):
signals.pre_bulk_insert.send(self._document,
documents=docs, **signal_kwargs)
# Resolve optional insert kwargs
insert_kwargs = {}
if write_concern is not None:
insert_kwargs.update(write_concern)
if continue_on_error is not None:
insert_kwargs['continue_on_error'] = continue_on_error
raw = [doc.to_mongo() for doc in docs]
try:
ids = self._collection.insert(raw, **write_concern)
ids = self._collection.insert(raw, **insert_kwargs)
except pymongo.errors.DuplicateKeyError as err:
message = 'Could not save document (%s)'
raise NotUniqueError(message % six.text_type(err))

View File

@ -766,8 +766,7 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(record.embed.field, 2)
def test_bulk_insert(self):
"""Ensure that bulk insert works
"""
"""Ensure that bulk insert works."""
class Comment(EmbeddedDocument):
name = StringField()
@ -885,9 +884,37 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(Blog.objects.count(), 2)
Blog.objects.insert([blog2, blog3],
write_concern={"w": 0, 'continue_on_error': True})
self.assertEqual(Blog.objects.count(), 3)
def test_bulk_insert_continue_on_error(self):
"""Ensure that bulk insert works with the continue_on_error option."""
class Person(Document):
email = EmailField(unique=True)
Person.drop_collection()
Person.objects.insert([
Person(email='alice@example.com'),
Person(email='bob@example.com')
])
self.assertEqual(Person.objects.count(), 2)
new_docs = [
Person(email='alice@example.com'), # dupe
Person(email='bob@example.com'), # dupe
Person(email='steve@example.com') # new one
]
# By default inserting dupe docs should fail and no new docs should
# be inserted.
with self.assertRaises(NotUniqueError):
Person.objects.insert(new_docs)
self.assertEqual(Person.objects.count(), 2)
# With continue_on_error, new doc should be inserted, even though we
# still get a NotUniqueError caused by the other 2 dupes.
with self.assertRaises(NotUniqueError):
Person.objects.insert(new_docs, continue_on_error=True)
self.assertEqual(Person.objects.count(), 3)
def test_get_changed_fields_query_count(self):