From 1a93b9b2263a71f179a35760b84b5d740fdd4f57 Mon Sep 17 00:00:00 2001 From: helduel Date: Thu, 8 Nov 2012 16:30:29 +0100 Subject: [PATCH] More precise "created" keyword argument signals If a document has a user given id value, the post_save signal always got the "created" keyword argument with False value (unless force_insert is True). This patch uses the result of getlasterror to check whether the save was an update or not. --- mongoengine/document.py | 19 +++++++++++++++---- tests/test_signals.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 4 deletions(-) diff --git a/mongoengine/document.py b/mongoengine/document.py index 7b3afafb..694d1ed4 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -233,13 +233,24 @@ class Document(BaseDocument): actual_key = self._db_field_map.get(k, k) select_dict[actual_key] = doc[actual_key] + def is_new_object(last_error): + if last_error is not None: + updated = last_error.get("updatedExisting") + if updated is not None: + return not updated + return created + upsert = self._created if updates: - collection.update(select_dict, {"$set": updates}, - upsert=upsert, safe=safe, **write_options) + last_error = collection.update(select_dict, + {"$set": updates}, upsert=upsert, safe=safe, + **write_options) + created = is_new_object(last_error) if removals: - collection.update(select_dict, {"$unset": removals}, - upsert=upsert, safe=safe, **write_options) + last_error = collection.update(select_dict, + {"$unset": removals}, upsert=upsert, safe=safe, + **write_options) + created = created or is_new_object(last_error) warn_cascade = not cascade and 'cascade' not in self._meta cascade = (self._meta.get('cascade', True) diff --git a/tests/test_signals.py b/tests/test_signals.py index d1199248..2ca820da 100644 --- a/tests/test_signals.py +++ b/tests/test_signals.py @@ -108,6 +108,20 @@ class SignalTests(unittest.TestCase): signal_output.append('post_delete Another signal, %s' % document) self.Another = Another + + class ExplicitId(Document): + id = IntField(primary_key=True) + + @classmethod + def post_save(cls, sender, document, **kwargs): + if 'created' in kwargs: + if kwargs['created']: + signal_output.append('Is created') + else: + signal_output.append('Is updated') + + self.ExplicitId = ExplicitId + self.ExplicitId.objects.delete() # Save up the number of connected signals so that we can check at the end # that all the signals we register get properly unregistered self.pre_signals = ( @@ -137,6 +151,8 @@ class SignalTests(unittest.TestCase): signals.pre_delete.connect(Another.pre_delete, sender=Another) signals.post_delete.connect(Another.post_delete, sender=Another) + signals.post_save.connect(ExplicitId.post_save, sender=ExplicitId) + def tearDown(self): signals.pre_init.disconnect(self.Author.pre_init) signals.post_init.disconnect(self.Author.post_init) @@ -154,6 +170,8 @@ class SignalTests(unittest.TestCase): signals.post_save.disconnect(self.Another.post_save) signals.pre_save.disconnect(self.Another.pre_save) + signals.post_save.disconnect(self.ExplicitId.post_save) + # Check that all our signals got disconnected properly. post_signals = ( len(signals.pre_init.receivers), @@ -166,6 +184,8 @@ class SignalTests(unittest.TestCase): len(signals.post_bulk_insert.receivers), ) + self.ExplicitId.objects.delete() + self.assertEqual(self.pre_signals, post_signals) def test_model_signals(self): @@ -228,3 +248,12 @@ class SignalTests(unittest.TestCase): ]) self.Author.objects.delete() + + def test_signals_with_explicit_doc_ids(self): + """ Model saves must have a created flag the first time.""" + ei = self.ExplicitId(id=123) + # post save must received the created flag, even if there's already + # an object id present + self.assertEqual(self.get_signal_output(ei.save), ['Is created']) + # second time, it must be an update + self.assertEqual(self.get_signal_output(ei.save), ['Is updated'])