From 98e1df0c450f65fd920803ea873715c8596e22c6 Mon Sep 17 00:00:00 2001 From: Stefan Wojcik Date: Sat, 14 Jan 2017 23:04:55 -0500 Subject: [PATCH] Add `continue_on_error` optional kwarg to QuerySet.insert --- mongoengine/queryset/base.py | 36 +++++++++++++++++++++-------------- tests/queryset/queryset.py | 37 +++++++++++++++++++++++++++++++----- 2 files changed, 54 insertions(+), 19 deletions(-) diff --git a/mongoengine/queryset/base.py b/mongoengine/queryset/base.py index 098f198e..142ebb66 100644 --- a/mongoengine/queryset/base.py +++ b/mongoengine/queryset/base.py @@ -296,22 +296,25 @@ class BaseQuerySet(object): result = None return result - def insert(self, doc_or_docs, load_bulk=True, - write_concern=None, signal_kwargs=None): + def insert(self, doc_or_docs, load_bulk=True, write_concern=None, + signal_kwargs=None, continue_on_error=None): """bulk insert documents :param doc_or_docs: a document or list of documents to be inserted :param load_bulk (optional): If True returns the list of document instances - :param write_concern: Extra keyword arguments are passed down to - :meth:`~pymongo.collection.Collection.insert` - which will be used as options for the resultant - ``getLastError`` command. For example, - ``insert(..., {w: 2, fsync: True})`` will wait until at least - two servers have recorded the write and will force an fsync on - each server being written to. + :param write_concern: Optional keyword argument passed down to + :meth:`~pymongo.collection.Collection.insert`, representing + the write concern. For example, + ``insert(..., write_concert={w: 2, fsync: True})`` will + wait until at least two servers have recorded the write + and will force an fsync on each server being written to. :parm signal_kwargs: (optional) kwargs dictionary to be passed to the signal calls. + :param continue_on_error: Optional keyword argument passed down to + :meth:`~pymongo.collection.Collection.insert`. Defines what + to do when a document cannot be inserted (e.g. due to + duplicate IDs). Read PyMongo's docs for more info. By default returns document instances, set ``load_bulk`` to False to return just ``ObjectIds`` @@ -322,12 +325,10 @@ class BaseQuerySet(object): """ Document = _import_class('Document') - if write_concern is None: - write_concern = {} - + # Determine if we're inserting one doc or more docs = doc_or_docs return_one = False - if isinstance(docs, Document) or issubclass(docs.__class__, Document): + if isinstance(docs, Document): return_one = True docs = [docs] @@ -344,9 +345,16 @@ class BaseQuerySet(object): signals.pre_bulk_insert.send(self._document, documents=docs, **signal_kwargs) + # Resolve optional insert kwargs + insert_kwargs = {} + if write_concern is not None: + insert_kwargs.update(write_concern) + if continue_on_error is not None: + insert_kwargs['continue_on_error'] = continue_on_error + raw = [doc.to_mongo() for doc in docs] try: - ids = self._collection.insert(raw, **write_concern) + ids = self._collection.insert(raw, **insert_kwargs) except pymongo.errors.DuplicateKeyError as err: message = 'Could not save document (%s)' raise NotUniqueError(message % six.text_type(err)) diff --git a/tests/queryset/queryset.py b/tests/queryset/queryset.py index 2d5b5b0f..8768e707 100644 --- a/tests/queryset/queryset.py +++ b/tests/queryset/queryset.py @@ -766,8 +766,7 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(record.embed.field, 2) def test_bulk_insert(self): - """Ensure that bulk insert works - """ + """Ensure that bulk insert works.""" class Comment(EmbeddedDocument): name = StringField() @@ -885,9 +884,37 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(Blog.objects.count(), 2) - Blog.objects.insert([blog2, blog3], - write_concern={"w": 0, 'continue_on_error': True}) - self.assertEqual(Blog.objects.count(), 3) + def test_bulk_insert_continue_on_error(self): + """Ensure that bulk insert works with the continue_on_error option.""" + + class Person(Document): + email = EmailField(unique=True) + + Person.drop_collection() + + Person.objects.insert([ + Person(email='alice@example.com'), + Person(email='bob@example.com') + ]) + self.assertEqual(Person.objects.count(), 2) + + new_docs = [ + Person(email='alice@example.com'), # dupe + Person(email='bob@example.com'), # dupe + Person(email='steve@example.com') # new one + ] + + # By default inserting dupe docs should fail and no new docs should + # be inserted. + with self.assertRaises(NotUniqueError): + Person.objects.insert(new_docs) + self.assertEqual(Person.objects.count(), 2) + + # With continue_on_error, new doc should be inserted, even though we + # still get a NotUniqueError caused by the other 2 dupes. + with self.assertRaises(NotUniqueError): + Person.objects.insert(new_docs, continue_on_error=True) + self.assertEqual(Person.objects.count(), 3) def test_get_changed_fields_query_count(self):