More precise "created" keyword argument signals

If a document has a user given id value, the post_save signal always got the
"created" keyword argument with False value (unless force_insert is True).

This patch uses the result of getlasterror to check whether the save was an
update or not.
This commit is contained in:
helduel 2012-11-08 16:30:29 +01:00
parent c31488add9
commit 1a93b9b226
2 changed files with 44 additions and 4 deletions

View File

@ -233,13 +233,24 @@ class Document(BaseDocument):
actual_key = self._db_field_map.get(k, k) actual_key = self._db_field_map.get(k, k)
select_dict[actual_key] = doc[actual_key] select_dict[actual_key] = doc[actual_key]
def is_new_object(last_error):
if last_error is not None:
updated = last_error.get("updatedExisting")
if updated is not None:
return not updated
return created
upsert = self._created upsert = self._created
if updates: if updates:
collection.update(select_dict, {"$set": updates}, last_error = collection.update(select_dict,
upsert=upsert, safe=safe, **write_options) {"$set": updates}, upsert=upsert, safe=safe,
**write_options)
created = is_new_object(last_error)
if removals: if removals:
collection.update(select_dict, {"$unset": removals}, last_error = collection.update(select_dict,
upsert=upsert, safe=safe, **write_options) {"$unset": removals}, upsert=upsert, safe=safe,
**write_options)
created = created or is_new_object(last_error)
warn_cascade = not cascade and 'cascade' not in self._meta warn_cascade = not cascade and 'cascade' not in self._meta
cascade = (self._meta.get('cascade', True) cascade = (self._meta.get('cascade', True)

View File

@ -108,6 +108,20 @@ class SignalTests(unittest.TestCase):
signal_output.append('post_delete Another signal, %s' % document) signal_output.append('post_delete Another signal, %s' % document)
self.Another = Another self.Another = Another
class ExplicitId(Document):
id = IntField(primary_key=True)
@classmethod
def post_save(cls, sender, document, **kwargs):
if 'created' in kwargs:
if kwargs['created']:
signal_output.append('Is created')
else:
signal_output.append('Is updated')
self.ExplicitId = ExplicitId
self.ExplicitId.objects.delete()
# Save up the number of connected signals so that we can check at the end # Save up the number of connected signals so that we can check at the end
# that all the signals we register get properly unregistered # that all the signals we register get properly unregistered
self.pre_signals = ( self.pre_signals = (
@ -137,6 +151,8 @@ class SignalTests(unittest.TestCase):
signals.pre_delete.connect(Another.pre_delete, sender=Another) signals.pre_delete.connect(Another.pre_delete, sender=Another)
signals.post_delete.connect(Another.post_delete, sender=Another) signals.post_delete.connect(Another.post_delete, sender=Another)
signals.post_save.connect(ExplicitId.post_save, sender=ExplicitId)
def tearDown(self): def tearDown(self):
signals.pre_init.disconnect(self.Author.pre_init) signals.pre_init.disconnect(self.Author.pre_init)
signals.post_init.disconnect(self.Author.post_init) signals.post_init.disconnect(self.Author.post_init)
@ -154,6 +170,8 @@ class SignalTests(unittest.TestCase):
signals.post_save.disconnect(self.Another.post_save) signals.post_save.disconnect(self.Another.post_save)
signals.pre_save.disconnect(self.Another.pre_save) signals.pre_save.disconnect(self.Another.pre_save)
signals.post_save.disconnect(self.ExplicitId.post_save)
# Check that all our signals got disconnected properly. # Check that all our signals got disconnected properly.
post_signals = ( post_signals = (
len(signals.pre_init.receivers), len(signals.pre_init.receivers),
@ -166,6 +184,8 @@ class SignalTests(unittest.TestCase):
len(signals.post_bulk_insert.receivers), len(signals.post_bulk_insert.receivers),
) )
self.ExplicitId.objects.delete()
self.assertEqual(self.pre_signals, post_signals) self.assertEqual(self.pre_signals, post_signals)
def test_model_signals(self): def test_model_signals(self):
@ -228,3 +248,12 @@ class SignalTests(unittest.TestCase):
]) ])
self.Author.objects.delete() self.Author.objects.delete()
def test_signals_with_explicit_doc_ids(self):
""" Model saves must have a created flag the first time."""
ei = self.ExplicitId(id=123)
# post save must received the created flag, even if there's already
# an object id present
self.assertEqual(self.get_signal_output(ei.save), ['Is created'])
# second time, it must be an update
self.assertEqual(self.get_signal_output(ei.save), ['Is updated'])