diff --git a/mongoengine/queryset/base.py b/mongoengine/queryset/base.py
index f9b8c96d..454d781a 100644
--- a/mongoengine/queryset/base.py
+++ b/mongoengine/queryset/base.py
@@ -296,7 +296,6 @@ class BaseQuerySet(object):
             return_one = True
             docs = [docs]
 
-        raw = []
         for doc in docs:
             if not isinstance(doc, self._document):
                 msg = ("Some documents inserted aren't instances of %s"
@@ -305,9 +304,10 @@ class BaseQuerySet(object):
             if doc.pk and not doc._created:
                 msg = "Some documents have ObjectIds use doc.update() instead"
                 raise OperationError(msg)
-            raw.append(doc.to_mongo())
 
         signals.pre_bulk_insert.send(self._document, documents=docs)
+
+        raw = [doc.to_mongo() for doc in docs]
         try:
             ids = self._collection.insert(raw, **write_concern)
         except pymongo.errors.DuplicateKeyError, err:
diff --git a/tests/test_signals.py b/tests/test_signals.py
index 8672925c..78305b7d 100644
--- a/tests/test_signals.py
+++ b/tests/test_signals.py
@@ -118,6 +118,35 @@ class SignalTests(unittest.TestCase):
         self.ExplicitId = ExplicitId
         ExplicitId.drop_collection()
 
+        class Post(Document):
+            title = StringField()
+            content = StringField()
+            active = BooleanField(default=False)
+
+            def __unicode__(self):
+                return self.title
+
+            @classmethod
+            def pre_bulk_insert(cls, sender, documents, **kwargs):
+                signal_output.append('pre_bulk_insert signal, %s' %
+                                     [(doc, {'active': documents[n].active})
+                                      for n, doc in enumerate(documents)])
+
+                # make changes here, this is just an example -
+                # it could be anything that needs pre-validation or look-ups before bulk inserting
+                for document in documents:
+                    if not document.active:
+                        document.active = True
+
+            @classmethod
+            def post_bulk_insert(cls, sender, documents, **kwargs):
+                signal_output.append('post_bulk_insert signal, %s' %
+                                     [(doc, {'active': documents[n].active})
+                                      for n, doc in enumerate(documents)])
+
+        self.Post = Post
+        Post.drop_collection()
+
         # Save up the number of connected signals so that we can check at the
         # end that all the signals we register get properly unregistered
         self.pre_signals = (
@@ -147,6 +176,9 @@ class SignalTests(unittest.TestCase):
 
         signals.post_save.connect(ExplicitId.post_save, sender=ExplicitId)
 
+        signals.pre_bulk_insert.connect(Post.pre_bulk_insert, sender=Post)
+        signals.post_bulk_insert.connect(Post.post_bulk_insert, sender=Post)
+
     def tearDown(self):
         signals.pre_init.disconnect(self.Author.pre_init)
         signals.post_init.disconnect(self.Author.post_init)
@@ -163,6 +195,9 @@ class SignalTests(unittest.TestCase):
 
         signals.post_save.disconnect(self.ExplicitId.post_save)
 
+        signals.pre_bulk_insert.disconnect(self.Post.pre_bulk_insert)
+        signals.post_bulk_insert.disconnect(self.Post.post_bulk_insert)
+
         # Check that all our signals got disconnected properly.
        post_signals = (
            len(signals.pre_init.receivers),
@@ -199,7 +234,7 @@ class SignalTests(unittest.TestCase):
         a.save()
         self.get_signal_output(lambda: None) # eliminate signal output
         a1 = self.Author.objects(name='Bill Shakespeare')[0]
-        
+
         self.assertEqual(self.get_signal_output(create_author), [
             "pre_init signal, Author",
             "{'name': 'Bill Shakespeare'}",
@@ -306,6 +341,20 @@ class SignalTests(unittest.TestCase):
         ei.switch_db("testdb-1", keep_created=False)
         self.assertEqual(self.get_signal_output(ei.save), ['Is created'])
 
+    def test_signals_bulk_insert(self):
+        def bulk_set_active_post():
+            posts = [
+                self.Post(title='Post 1'),
+                self.Post(title='Post 2'),
+                self.Post(title='Post 3')
+            ]
+            self.Post.objects.insert(posts)
+
+        results = self.get_signal_output(bulk_set_active_post)
+        self.assertEqual(results, [
+            "pre_bulk_insert signal, [(<Post: Post 1>, {'active': False}), (<Post: Post 2>, {'active': False}), (<Post: Post 3>, {'active': False})]",
+            "post_bulk_insert signal, [(<Post: Post 1>, {'active': True}), (<Post: Post 2>, {'active': True}), (<Post: Post 3>, {'active': True})]"
+        ])
 
 if __name__ == '__main__':
     unittest.main()
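
For context on what the reordering in queryset/base.py enables: a handler connected to pre_bulk_insert now receives the documents before to_mongo() is called, so changes made in the handler end up in the data written by QuerySet.insert(). A minimal sketch of that behaviour, assuming a local mongod and an illustrative BlogPost model with a set_defaults handler (neither is part of this patch):

from mongoengine import Document, StringField, BooleanField, connect, signals

connect('signals_demo')  # hypothetical database name; assumes a local mongod


class BlogPost(Document):
    title = StringField()
    published = BooleanField(default=False)


def set_defaults(sender, documents, **kwargs):
    # With this patch, the handler runs before the documents are serialized,
    # so mutations made here are included in the bulk insert.
    for document in documents:
        if not document.published:
            document.published = True


signals.pre_bulk_insert.connect(set_defaults, sender=BlogPost)

posts = BlogPost.objects.insert([BlogPost(title='One'), BlogPost(title='Two')])
assert all(post.published for post in posts)  # defaults applied before the write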