From 56f00a64d77655bee2d00ebd783d07655a6900ff Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Mon, 6 Jun 2011 12:37:06 +0100 Subject: [PATCH] Added bulk insert method. Updated changelog and added tests / query_counter tests --- docs/changelog.rst | 1 + mongoengine/queryset.py | 42 ++++++++++++++++++++- tests/queryset.py | 83 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 125 insertions(+), 1 deletion(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 659bdb4e..29ecdf7a 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -5,6 +5,7 @@ Changelog Changes in dev ============== +- Added insert method for bulk inserts - Added blinker signal support - Added query_counter context manager for tests - Added DereferenceBaseField - for improved performance in field dereferencing diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index 2de15ed4..0e87db7a 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -378,7 +378,7 @@ class QuerySet(object): """ index_spec = QuerySet._build_index_spec(self._document, key_or_list) self._collection.ensure_index( - index_spec['fields'], + index_spec['fields'], drop_dups=drop_dups, background=background, sparse=index_spec.get('sparse', False), @@ -719,6 +719,46 @@ class QuerySet(object): result = None return result + def insert(self, doc_or_docs, load_bulk=True): + """bulk insert documents + + :param docs_or_doc: a document or list of documents to be inserted + :param load_bulk (optional): If True returns the list of document instances + + By default returns document instances, set ``load_bulk`` to False to + return just ``ObjectIds`` + + .. versionadded:: 0.5 + """ + from document import Document + + docs = doc_or_docs + return_one = False + if isinstance(docs, Document) or issubclass(docs.__class__, Document): + return_one = True + docs = [docs] + + raw = [] + for doc in docs: + if not isinstance(doc, self._document): + msg = "Some documents inserted aren't instances of %s" % str(self._document) + raise OperationError(msg) + if doc.pk: + msg = "Some documents have ObjectIds use doc.update() instead" + raise OperationError(msg) + raw.append(doc.to_mongo()) + + ids = self._collection.insert(raw) + + if not load_bulk: + return return_one and ids[0] or ids + + documents = self.in_bulk(ids) + results = [] + for obj_id in ids: + results.append(documents.get(obj_id)) + return return_one and results[0] or results + def with_id(self, object_id): """Retrieve the object matching the id provided. diff --git a/tests/queryset.py b/tests/queryset.py index 8d046902..0b64e3e9 100644 --- a/tests/queryset.py +++ b/tests/queryset.py @@ -9,6 +9,7 @@ from mongoengine.queryset import (QuerySet, QuerySetManager, MultipleObjectsReturned, DoesNotExist, QueryFieldList) from mongoengine import * +from mongoengine.tests import query_counter class QuerySetTest(unittest.TestCase): @@ -331,6 +332,88 @@ class QuerySetTest(unittest.TestCase): person = self.Person.objects.get(age=50) self.assertEqual(person.name, "User C") + def test_bulk_insert(self): + """Ensure that query by array position works. + """ + + class Comment(EmbeddedDocument): + name = StringField() + + class Post(EmbeddedDocument): + comments = ListField(EmbeddedDocumentField(Comment)) + + class Blog(Document): + title = StringField() + tags = ListField(StringField()) + posts = ListField(EmbeddedDocumentField(Post)) + + Blog.drop_collection() + + with query_counter() as q: + self.assertEqual(q, 0) + + comment1 = Comment(name='testa') + comment2 = Comment(name='testb') + post1 = Post(comments=[comment1, comment2]) + post2 = Post(comments=[comment2, comment2]) + + blogs = [] + for i in xrange(1, 100): + blogs.append(Blog(title="post %s" % i, posts=[post1, post2])) + + Blog.objects.insert(blogs, load_bulk=False) + self.assertEqual(q, 2) # 1 for the inital connection and 1 for the insert + + Blog.objects.insert(blogs) + self.assertEqual(q, 4) # 1 for insert, and 1 for in bulk + + Blog.drop_collection() + + comment1 = Comment(name='testa') + comment2 = Comment(name='testb') + post1 = Post(comments=[comment1, comment2]) + post2 = Post(comments=[comment2, comment2]) + blog1 = Blog(title="code", posts=[post1, post2]) + blog2 = Blog(title="mongodb", posts=[post2, post1]) + blog1, blog2 = Blog.objects.insert([blog1, blog2]) + self.assertEqual(blog1.title, "code") + self.assertEqual(blog2.title, "mongodb") + + self.assertEqual(Blog.objects.count(), 2) + + # test handles people trying to upsert + def throw_operation_error(): + blogs = Blog.objects + Blog.objects.insert(blogs) + + self.assertRaises(OperationError, throw_operation_error) + + # test handles other classes being inserted + def throw_operation_error_wrong_doc(): + class Author(Document): + pass + Blog.objects.insert(Author()) + + self.assertRaises(OperationError, throw_operation_error_wrong_doc) + + def throw_operation_error_not_a_document(): + Blog.objects.insert("HELLO WORLD") + + self.assertRaises(OperationError, throw_operation_error_not_a_document) + + Blog.drop_collection() + + blog1 = Blog(title="code", posts=[post1, post2]) + blog1 = Blog.objects.insert(blog1) + self.assertEqual(blog1.title, "code") + self.assertEqual(Blog.objects.count(), 1) + + Blog.drop_collection() + blog1 = Blog(title="code", posts=[post1, post2]) + obj_id = Blog.objects.insert(blog1, load_bulk=False) + self.assertEquals(obj_id.__class__.__name__, 'ObjectId') + + def test_repeated_iteration(self): """Ensure that QuerySet rewinds itself one iteration finishes. """