diff --git a/docs/guide/signals.rst b/docs/guide/signals.rst index d80a421b..3c3159f8 100644 --- a/docs/guide/signals.rst +++ b/docs/guide/signals.rst @@ -30,20 +30,20 @@ Example usage:: return self.name @classmethod - def pre_save(cls, instance, **kwargs): - logging.debug("Pre Save: %s" % instance.name) + def pre_save(cls, sender, document, **kwargs): + logging.debug("Pre Save: %s" % document.name) @classmethod - def post_save(cls, instance, **kwargs): - logging.debug("Post Save: %s" % instance.name) + def post_save(cls, sender, document, **kwargs): + logging.debug("Post Save: %s" % document.name) if 'created' in kwargs: if kwargs['created']: logging.debug("Created") else: logging.debug("Updated") - signals.pre_save.connect(Author.pre_save) - signals.post_save.connect(Author.post_save) + signals.pre_save.connect(Author.pre_save, sender=Author) + signals.post_save.connect(Author.post_save, sender=Author) .. _blinker: http://pypi.python.org/pypi/blinker \ No newline at end of file diff --git a/mongoengine/base.py b/mongoengine/base.py index 8a0ded51..c5b704e1 100644 --- a/mongoengine/base.py +++ b/mongoengine/base.py @@ -600,7 +600,7 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): class BaseDocument(object): def __init__(self, **values): - signals.pre_init.send(self, values=values) + signals.pre_init.send(self.__class__, document=self, values=values) self._data = {} # Assign default values to instance @@ -619,7 +619,7 @@ class BaseDocument(object): except AttributeError: pass - signals.post_init.send(self) + signals.post_init.send(self.__class__, document=self) def _get_FIELD_display(self, field): """Returns the display value for a choice field""" diff --git a/mongoengine/document.py b/mongoengine/document.py index 2f40eec7..69b19e2c 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -91,7 +91,7 @@ class Document(BaseDocument): For example, ``save(..., w=2, fsync=True)`` will wait until at least two servers have recorded the write and will force an fsync on each server being written to. """ - signals.pre_save.send(self) + signals.pre_save.send(self.__class__, document=self) if validate: self.validate() @@ -122,7 +122,7 @@ class Document(BaseDocument): id_field = self._meta['id_field'] self[id_field] = self._fields[id_field].to_python(object_id) self._changed_fields = [] - signals.post_save.send(self, created=created) + signals.post_save.send(self.__class__, document=self, created=created) def delete(self, safe=False): """Delete the :class:`~mongoengine.Document` from the database. This @@ -130,7 +130,7 @@ class Document(BaseDocument): :param safe: check if the operation succeeded before returning """ - signals.pre_delete.send(self) + signals.pre_delete.send(self.__class__, document=self) id_field = self._meta['id_field'] object_id = self._fields[id_field].to_mongo(self[id_field]) @@ -140,7 +140,7 @@ class Document(BaseDocument): message = u'Could not delete document (%s)' % err.message raise OperationError(message) - signals.post_delete.send(self) + signals.post_delete.send(self.__class__, document=self) def reload(self): """Reloads all attributes from the database. diff --git a/tests/signals.py b/tests/signals.py index fff2d398..9c413379 100644 --- a/tests/signals.py +++ b/tests/signals.py @@ -28,21 +28,21 @@ class SignalTests(unittest.TestCase): return self.name @classmethod - def pre_init(cls, instance, **kwargs): + def pre_init(cls, sender, document, *args, **kwargs): signal_output.append('pre_init signal, %s' % cls.__name__) signal_output.append(str(kwargs['values'])) @classmethod - def post_init(cls, instance, **kwargs): - signal_output.append('post_init signal, %s' % instance) + def post_init(cls, sender, document, **kwargs): + signal_output.append('post_init signal, %s' % document) @classmethod - def pre_save(cls, instance, **kwargs): - signal_output.append('pre_save signal, %s' % instance) + def pre_save(cls, sender, document, **kwargs): + signal_output.append('pre_save signal, %s' % document) @classmethod - def post_save(cls, instance, **kwargs): - signal_output.append('post_save signal, %s' % instance) + def post_save(cls, sender, document, **kwargs): + signal_output.append('post_save signal, %s' % document) if 'created' in kwargs: if kwargs['created']: signal_output.append('Is created') @@ -50,15 +50,52 @@ class SignalTests(unittest.TestCase): signal_output.append('Is updated') @classmethod - def pre_delete(cls, instance, **kwargs): - signal_output.append('pre_delete signal, %s' % instance) + def pre_delete(cls, sender, document, **kwargs): + signal_output.append('pre_delete signal, %s' % document) @classmethod - def post_delete(cls, instance, **kwargs): - signal_output.append('post_delete signal, %s' % instance) - + def post_delete(cls, sender, document, **kwargs): + signal_output.append('post_delete signal, %s' % document) self.Author = Author + + class Another(Document): + name = StringField() + + def __unicode__(self): + return self.name + + @classmethod + def pre_init(cls, sender, document, **kwargs): + signal_output.append('pre_init Another signal, %s' % cls.__name__) + signal_output.append(str(kwargs['values'])) + + @classmethod + def post_init(cls, sender, document, **kwargs): + signal_output.append('post_init Another signal, %s' % document) + + @classmethod + def pre_save(cls, sender, document, **kwargs): + signal_output.append('pre_save Another signal, %s' % document) + + @classmethod + def post_save(cls, sender, document, **kwargs): + signal_output.append('post_save Another signal, %s' % document) + if 'created' in kwargs: + if kwargs['created']: + signal_output.append('Is created') + else: + signal_output.append('Is updated') + + @classmethod + def pre_delete(cls, sender, document, **kwargs): + signal_output.append('pre_delete Another signal, %s' % document) + + @classmethod + def post_delete(cls, sender, document, **kwargs): + signal_output.append('post_delete Another signal, %s' % document) + + self.Another = Another # Save up the number of connected signals so that we can check at the end # that all the signals we register get properly unregistered self.pre_signals = ( @@ -70,12 +107,19 @@ class SignalTests(unittest.TestCase): len(signals.post_delete.receivers) ) - signals.pre_init.connect(Author.pre_init) - signals.post_init.connect(Author.post_init) - signals.pre_save.connect(Author.pre_save) - signals.post_save.connect(Author.post_save) - signals.pre_delete.connect(Author.pre_delete) - signals.post_delete.connect(Author.post_delete) + signals.pre_init.connect(Author.pre_init, sender=Author) + signals.post_init.connect(Author.post_init, sender=Author) + signals.pre_save.connect(Author.pre_save, sender=Author) + signals.post_save.connect(Author.post_save, sender=Author) + signals.pre_delete.connect(Author.pre_delete, sender=Author) + signals.post_delete.connect(Author.post_delete, sender=Author) + + signals.pre_init.connect(Another.pre_init, sender=Another) + signals.post_init.connect(Another.post_init, sender=Another) + signals.pre_save.connect(Another.pre_save, sender=Another) + signals.post_save.connect(Another.post_save, sender=Another) + signals.pre_delete.connect(Another.pre_delete, sender=Another) + signals.post_delete.connect(Another.post_delete, sender=Another) def tearDown(self): signals.pre_init.disconnect(self.Author.pre_init) @@ -85,6 +129,13 @@ class SignalTests(unittest.TestCase): signals.post_save.disconnect(self.Author.post_save) signals.pre_save.disconnect(self.Author.pre_save) + signals.pre_init.disconnect(self.Another.pre_init) + signals.post_init.disconnect(self.Another.post_init) + signals.post_delete.disconnect(self.Another.post_delete) + signals.pre_delete.disconnect(self.Another.pre_delete) + signals.post_save.disconnect(self.Another.post_save) + signals.pre_save.disconnect(self.Another.pre_save) + # Check that all our signals got disconnected properly. post_signals = ( len(signals.pre_init.receivers),