add "safe" and "write_options" parameters to QuerySet.insert similar to Document.save

This commit is contained in:
Greg Banks 2012-04-26 13:56:52 -07:00
parent 2769d6d7ca
commit 0bb9781b91
3 changed files with 44 additions and 6 deletions

View File

@ -147,8 +147,9 @@ class Document(BaseDocument):
:meth:`~pymongo.collection.Collection.save` OR :meth:`~pymongo.collection.Collection.save` OR
:meth:`~pymongo.collection.Collection.insert` :meth:`~pymongo.collection.Collection.insert`
which will be used as options for the resultant ``getLastError`` command. which will be used as options for the resultant ``getLastError`` command.
For example, ``save(..., w=2, fsync=True)`` will wait until at least two servers For example, ``save(..., write_options={w: 2, fsync: True}, ...)`` will
have recorded the write and will force an fsync on each server being written to. wait until at least two servers have recorded the write and will force an
fsync on each server being written to.
:param cascade: Sets the flag for cascading saves. You can set a default by setting :param cascade: Sets the flag for cascading saves. You can set a default by setting
"cascade" in the document __meta__ "cascade" in the document __meta__
:param cascade_kwargs: optional kwargs dictionary to be passed throw to cascading saves :param cascade_kwargs: optional kwargs dictionary to be passed throw to cascading saves

View File

@ -824,11 +824,21 @@ class QuerySet(object):
result = None result = None
return result return result
def insert(self, doc_or_docs, load_bulk=True): def insert(self, doc_or_docs, load_bulk=True, safe=False, write_options=None):
"""bulk insert documents """bulk insert documents
If ``safe=True`` and the operation is unsuccessful, an
:class:`~mongoengine.OperationError` will be raised.
:param docs_or_doc: a document or list of documents to be inserted :param docs_or_doc: a document or list of documents to be inserted
:param load_bulk (optional): If True returns the list of document instances :param load_bulk (optional): If True returns the list of document instances
:param safe: check if the operation succeeded before returning
:param write_options: Extra keyword arguments are passed down to
:meth:`~pymongo.collection.Collection.insert`
which will be used as options for the resultant ``getLastError`` command.
For example, ``insert(..., {w: 2, fsync: True})`` will wait until at least two
servers have recorded the write and will force an fsync on each server being
written to.
By default returns document instances, set ``load_bulk`` to False to By default returns document instances, set ``load_bulk`` to False to
return just ``ObjectIds`` return just ``ObjectIds``
@ -837,6 +847,10 @@ class QuerySet(object):
""" """
from document import Document from document import Document
if not write_options:
write_options = {}
write_options.update({'safe': safe})
docs = doc_or_docs docs = doc_or_docs
return_one = False return_one = False
if isinstance(docs, Document) or issubclass(docs.__class__, Document): if isinstance(docs, Document) or issubclass(docs.__class__, Document):
@ -854,7 +868,13 @@ class QuerySet(object):
raw.append(doc.to_mongo()) raw.append(doc.to_mongo())
signals.pre_bulk_insert.send(self._document, documents=docs) signals.pre_bulk_insert.send(self._document, documents=docs)
ids = self._collection.insert(raw) try:
ids = self._collection.insert(raw, **write_options)
except pymongo.errors.OperationFailure, err:
message = 'Could not save document (%s)'
if u'duplicate key' in unicode(err):
message = u'Tried to save duplicate unique keys (%s)'
raise OperationError(message % unicode(err))
if not load_bulk: if not load_bulk:
signals.post_bulk_insert.send( signals.post_bulk_insert.send(

View File

@ -480,7 +480,7 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(person.name, "User C") self.assertEqual(person.name, "User C")
def test_bulk_insert(self): def test_bulk_insert(self):
"""Ensure that query by array position works. """Ensure that bulk insert works
""" """
class Comment(EmbeddedDocument): class Comment(EmbeddedDocument):
@ -490,7 +490,7 @@ class QuerySetTest(unittest.TestCase):
comments = ListField(EmbeddedDocumentField(Comment)) comments = ListField(EmbeddedDocumentField(Comment))
class Blog(Document): class Blog(Document):
title = StringField() title = StringField(unique=True)
tags = ListField(StringField()) tags = ListField(StringField())
posts = ListField(EmbeddedDocumentField(Post)) posts = ListField(EmbeddedDocumentField(Post))
@ -563,6 +563,23 @@ class QuerySetTest(unittest.TestCase):
obj_id = Blog.objects.insert(blog1, load_bulk=False) obj_id = Blog.objects.insert(blog1, load_bulk=False)
self.assertEquals(obj_id.__class__.__name__, 'ObjectId') self.assertEquals(obj_id.__class__.__name__, 'ObjectId')
Blog.drop_collection()
post3 = Post(comments=[comment1, comment1])
blog1 = Blog(title="foo", posts=[post1, post2])
blog2 = Blog(title="bar", posts=[post2, post3])
blog3 = Blog(title="baz", posts=[post1, post2])
Blog.objects.insert([blog1, blog2])
def throw_operation_error_not_unique():
Blog.objects.insert([blog2, blog3], safe=True)
self.assertRaises(OperationError, throw_operation_error_not_unique)
self.assertEqual(Blog.objects.count(), 2)
Blog.objects.insert([blog2, blog3], write_options={'continue_on_error': True})
self.assertEqual(Blog.objects.count(), 3)
def test_slave_okay(self): def test_slave_okay(self):
"""Ensures that a query can take slave_okay syntax """Ensures that a query can take slave_okay syntax
""" """