diff --git a/AUTHORS b/AUTHORS index 93fe819e..aecdcaa9 100644 --- a/AUTHORS +++ b/AUTHORS @@ -3,3 +3,4 @@ Matt Dennewitz Deepak Thukral Florian Schlachter Steve Challis +Ross Lawley diff --git a/docs/changelog.rst b/docs/changelog.rst index 686b326f..659bdb4e 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -5,6 +5,9 @@ Changelog Changes in dev ============== +- Added blinker signal support +- Added query_counter context manager for tests +- Added DereferenceBaseField - for improved performance in field dereferencing - Added optional map_reduce method item_frequencies - Added inline_map_reduce option to map_reduce - Updated connection exception so it provides more info on the cause. diff --git a/docs/guide/index.rst b/docs/guide/index.rst index aac72469..d56e7479 100644 --- a/docs/guide/index.rst +++ b/docs/guide/index.rst @@ -11,3 +11,4 @@ User Guide document-instances querying gridfs + signals diff --git a/docs/guide/signals.rst b/docs/guide/signals.rst new file mode 100644 index 00000000..d80a421b --- /dev/null +++ b/docs/guide/signals.rst @@ -0,0 +1,49 @@ +.. _signals: + +Signals +======= + +.. versionadded:: 0.5 + +Signal support is provided by the excellent `blinker`_ library and +will gracefully fall back if it is not available. + + +The following document signals exist in MongoEngine and are pretty self explaintary: + + * `mongoengine.signals.pre_init` + * `mongoengine.signals.post_init` + * `mongoengine.signals.pre_save` + * `mongoengine.signals.post_save` + * `mongoengine.signals.pre_delete` + * `mongoengine.signals.post_delete` + +Example usage:: + + from mongoengine import * + from mongoengine import signals + + class Author(Document): + name = StringField() + + def __unicode__(self): + return self.name + + @classmethod + def pre_save(cls, instance, **kwargs): + logging.debug("Pre Save: %s" % instance.name) + + @classmethod + def post_save(cls, instance, **kwargs): + logging.debug("Post Save: %s" % instance.name) + if 'created' in kwargs: + if kwargs['created']: + logging.debug("Created") + else: + logging.debug("Updated") + + signals.pre_save.connect(Author.pre_save) + signals.post_save.connect(Author.post_save) + + +.. _blinker: http://pypi.python.org/pypi/blinker \ No newline at end of file diff --git a/mongoengine/__init__.py b/mongoengine/__init__.py index 6d18ffe7..de635f96 100644 --- a/mongoengine/__init__.py +++ b/mongoengine/__init__.py @@ -6,9 +6,11 @@ import connection from connection import * import queryset from queryset import * +import signals +from signals import * __all__ = (document.__all__ + fields.__all__ + connection.__all__ + - queryset.__all__) + queryset.__all__ + signals.__all__) __author__ = 'Harry Marr' diff --git a/mongoengine/base.py b/mongoengine/base.py index ffceb794..76bb1ab7 100644 --- a/mongoengine/base.py +++ b/mongoengine/base.py @@ -2,9 +2,12 @@ from queryset import QuerySet, QuerySetManager from queryset import DoesNotExist, MultipleObjectsReturned from queryset import DO_NOTHING +from mongoengine import signals + import sys import pymongo import pymongo.objectid +from operator import itemgetter class NotRegistered(Exception): @@ -126,6 +129,88 @@ class BaseField(object): self.validate(value) + +class DereferenceBaseField(BaseField): + """Handles the lazy dereferencing of a queryset. Will dereference all + items in a list / dict rather than one at a time. + """ + + def __get__(self, instance, owner): + """Descriptor to automatically dereference references. + """ + from fields import ReferenceField, GenericReferenceField + from connection import _get_db + + if instance is None: + # Document class being used rather than a document object + return self + + # Get value from document instance if available + value_list = instance._data.get(self.name) + if not value_list: + return super(DereferenceBaseField, self).__get__(instance, owner) + + is_list = False + if not hasattr(value_list, 'items'): + is_list = True + value_list = dict([(k,v) for k,v in enumerate(value_list)]) + + if isinstance(self.field, ReferenceField) and value_list: + db = _get_db() + dbref = {} + collections = {} + + for k, v in value_list.items(): + dbref[k] = v + # Save any DBRefs + if isinstance(v, (pymongo.dbref.DBRef)): + collections.setdefault(v.collection, []).append((k, v)) + + # For each collection get the references + for collection, dbrefs in collections.items(): + id_map = dict([(v.id, k) for k, v in dbrefs]) + references = db[collection].find({'_id': {'$in': id_map.keys()}}) + for ref in references: + key = id_map[ref['_id']] + dbref[key] = get_document(ref['_cls'])._from_son(ref) + + if is_list: + dbref = [v for k,v in sorted(dbref.items(), key=itemgetter(0))] + instance._data[self.name] = dbref + + # Get value from document instance if available + if isinstance(self.field, GenericReferenceField) and value_list: + db = _get_db() + value_list = [(k,v) for k,v in value_list.items()] + dbref = {} + classes = {} + + for k, v in value_list: + dbref[k] = v + # Save any DBRefs + if isinstance(v, (dict, pymongo.son.SON)): + classes.setdefault(v['_cls'], []).append((k, v)) + + # For each collection get the references + for doc_cls, dbrefs in classes.items(): + id_map = dict([(v['_ref'].id, k) for k, v in dbrefs]) + doc_cls = get_document(doc_cls) + collection = doc_cls._meta['collection'] + references = db[collection].find({'_id': {'$in': id_map.keys()}}) + + for ref in references: + key = id_map[ref['_id']] + dbref[key] = doc_cls._from_son(ref) + + if is_list: + dbref = [v for k,v in sorted(dbref.items(), key=itemgetter(0))] + + instance._data[self.name] = dbref + + return super(DereferenceBaseField, self).__get__(instance, owner) + + + class ObjectIdField(BaseField): """An field wrapper around MongoDB's ObjectIds. """ @@ -382,6 +467,8 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): class BaseDocument(object): def __init__(self, **values): + signals.pre_init.send(self, values=values) + self._data = {} # Assign default values to instance for attr_name in self._fields.keys(): @@ -395,6 +482,8 @@ class BaseDocument(object): except AttributeError: pass + signals.post_init.send(self) + def validate(self): """Ensure that all fields' values are valid and that required fields are present. diff --git a/mongoengine/document.py b/mongoengine/document.py index 771b9229..b563f427 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -1,3 +1,4 @@ +from mongoengine import signals from base import (DocumentMetaclass, TopLevelDocumentMetaclass, BaseDocument, ValidationError) from queryset import OperationError @@ -75,6 +76,8 @@ class Document(BaseDocument): For example, ``save(..., w=2, fsync=True)`` will wait until at least two servers have recorded the write and will force an fsync on each server being written to. """ + signals.pre_save.send(self) + if validate: self.validate() @@ -82,6 +85,7 @@ class Document(BaseDocument): write_options = {} doc = self.to_mongo() + created = '_id' not in doc try: collection = self.__class__.objects._collection if force_insert: @@ -96,12 +100,16 @@ class Document(BaseDocument): id_field = self._meta['id_field'] self[id_field] = self._fields[id_field].to_python(object_id) + signals.post_save.send(self, created=created) + def delete(self, safe=False): """Delete the :class:`~mongoengine.Document` from the database. This will only take effect if the document has been previously saved. :param safe: check if the operation succeeded before returning """ + signals.pre_delete.send(self) + id_field = self._meta['id_field'] object_id = self._fields[id_field].to_mongo(self[id_field]) try: @@ -110,6 +118,8 @@ class Document(BaseDocument): message = u'Could not delete document (%s)' % err.message raise OperationError(message) + signals.post_delete.send(self) + @classmethod def register_delete_rule(cls, document_cls, field_name, rule): """This method registers the delete rules to apply when removing this diff --git a/mongoengine/fields.py b/mongoengine/fields.py index b2aab5a4..dc03fc05 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -1,4 +1,5 @@ -from base import BaseField, ObjectIdField, ValidationError, get_document +from base import (BaseField, DereferenceBaseField, ObjectIdField, + ValidationError, get_document) from queryset import DO_NOTHING from document import Document, EmbeddedDocument from connection import _get_db @@ -12,7 +13,6 @@ import pymongo.binary import datetime, time import decimal import gridfs -import warnings __all__ = ['StringField', 'IntField', 'FloatField', 'BooleanField', @@ -118,8 +118,8 @@ class EmailField(StringField): EMAIL_REGEX = re.compile( r"(^[-!#$%&'*+/=?^_`{}|~0-9A-Z]+(\.[-!#$%&'*+/=?^_`{}|~0-9A-Z]+)*" # dot-atom - r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]|\\[\001-011\013\014\016-\177])*"' # quoted-string - r')@(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,6}\.?$', re.IGNORECASE # domain + r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]|\\[\001-011\013\014\016-\177])*"' # quoted-string + r')@(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,6}\.?$', re.IGNORECASE # domain ) def validate(self, value): @@ -153,6 +153,7 @@ class IntField(BaseField): def prepare_query_value(self, op, value): return int(value) + class FloatField(BaseField): """An floating point number field. """ @@ -178,6 +179,7 @@ class FloatField(BaseField): def prepare_query_value(self, op, value): return float(value) + class DecimalField(BaseField): """A fixed-point decimal number field. @@ -252,21 +254,21 @@ class DateTimeField(BaseField): else: usecs = 0 kwargs = {'microsecond': usecs} - try: # Seconds are optional, so try converting seconds first. + try: # Seconds are optional, so try converting seconds first. return datetime.datetime(*time.strptime(value, '%Y-%m-%d %H:%M:%S')[:6], **kwargs) - except ValueError: - try: # Try without seconds. + try: # Try without seconds. return datetime.datetime(*time.strptime(value, '%Y-%m-%d %H:%M')[:5], **kwargs) - except ValueError: # Try without hour/minutes/seconds. + except ValueError: # Try without hour/minutes/seconds. try: return datetime.datetime(*time.strptime(value, '%Y-%m-%d')[:3], **kwargs) except ValueError: return None + class EmbeddedDocumentField(BaseField): """An embedded document field. Only valid values are subclasses of :class:`~mongoengine.EmbeddedDocument`. @@ -314,7 +316,7 @@ class EmbeddedDocumentField(BaseField): return self.to_mongo(value) -class ListField(BaseField): +class ListField(DereferenceBaseField): """A list field that wraps a standard field, allowing multiple instances of the field to be used as a list in the database. """ @@ -330,42 +332,6 @@ class ListField(BaseField): kwargs.setdefault('default', lambda: []) super(ListField, self).__init__(**kwargs) - def __get__(self, instance, owner): - """Descriptor to automatically dereference references. - """ - if instance is None: - # Document class being used rather than a document object - return self - - if isinstance(self.field, ReferenceField): - referenced_type = self.field.document_type - # Get value from document instance if available - value_list = instance._data.get(self.name) - if value_list: - deref_list = [] - for value in value_list: - # Dereference DBRefs - if isinstance(value, (pymongo.dbref.DBRef)): - value = _get_db().dereference(value) - deref_list.append(referenced_type._from_son(value)) - else: - deref_list.append(value) - instance._data[self.name] = deref_list - - if isinstance(self.field, GenericReferenceField): - value_list = instance._data.get(self.name) - if value_list: - deref_list = [] - for value in value_list: - # Dereference DBRefs - if isinstance(value, (dict, pymongo.son.SON)): - deref_list.append(self.field.dereference(value)) - else: - deref_list.append(value) - instance._data[self.name] = deref_list - - return super(ListField, self).__get__(instance, owner) - def to_python(self, value): return [self.field.to_python(item) for item in value] @@ -459,10 +425,10 @@ class DictField(BaseField): if op in match_operators and isinstance(value, basestring): return StringField().prepare_query_value(op, value) - return super(DictField,self).prepare_query_value(op, value) + return super(DictField, self).prepare_query_value(op, value) -class MapField(BaseField): +class MapField(DereferenceBaseField): """A field that maps a name to a specified field type. Similar to a DictField, except the 'value' of each item must match the specified field type. @@ -494,47 +460,11 @@ class MapField(BaseField): except Exception, err: raise ValidationError('Invalid MapField item (%s)' % str(item)) - def __get__(self, instance, owner): - """Descriptor to automatically dereference references. - """ - if instance is None: - # Document class being used rather than a document object - return self - - if isinstance(self.field, ReferenceField): - referenced_type = self.field.document_type - # Get value from document instance if available - value_dict = instance._data.get(self.name) - if value_dict: - deref_dict = [] - for key,value in value_dict.iteritems(): - # Dereference DBRefs - if isinstance(value, (pymongo.dbref.DBRef)): - value = _get_db().dereference(value) - deref_dict[key] = referenced_type._from_son(value) - else: - deref_dict[key] = value - instance._data[self.name] = deref_dict - - if isinstance(self.field, GenericReferenceField): - value_dict = instance._data.get(self.name) - if value_dict: - deref_dict = [] - for key,value in value_dict.iteritems(): - # Dereference DBRefs - if isinstance(value, (dict, pymongo.son.SON)): - deref_dict[key] = self.field.dereference(value) - else: - deref_dict[key] = value - instance._data[self.name] = deref_dict - - return super(MapField, self).__get__(instance, owner) - def to_python(self, value): - return dict( [(key,self.field.to_python(item)) for key,item in value.iteritems()] ) + return dict([(key, self.field.to_python(item)) for key, item in value.iteritems()]) def to_mongo(self, value): - return dict( [(key,self.field.to_mongo(item)) for key,item in value.iteritems()] ) + return dict([(key, self.field.to_mongo(item)) for key, item in value.iteritems()]) def prepare_query_value(self, op, value): if op not in ('set', 'unset'): @@ -752,11 +682,11 @@ class GridFSProxy(object): self.newfile = self.fs.new_file(**kwargs) self.grid_id = self.newfile._id - def put(self, file, **kwargs): + def put(self, file_obj, **kwargs): if self.grid_id: raise GridFSError('This document already has a file. Either delete ' 'it or call replace to overwrite it') - self.grid_id = self.fs.put(file, **kwargs) + self.grid_id = self.fs.put(file_obj, **kwargs) def write(self, string): if self.grid_id: @@ -785,9 +715,9 @@ class GridFSProxy(object): self.grid_id = None self.gridout = None - def replace(self, file, **kwargs): + def replace(self, file_obj, **kwargs): self.delete() - self.put(file, **kwargs) + self.put(file_obj, **kwargs) def close(self): if self.newfile: diff --git a/mongoengine/signals.py b/mongoengine/signals.py new file mode 100644 index 00000000..0a697534 --- /dev/null +++ b/mongoengine/signals.py @@ -0,0 +1,44 @@ +# -*- coding: utf-8 -*- + +__all__ = ['pre_init', 'post_init', 'pre_save', 'post_save', + 'pre_delete', 'post_delete'] + +signals_available = False +try: + from blinker import Namespace + signals_available = True +except ImportError: + class Namespace(object): + def signal(self, name, doc=None): + return _FakeSignal(name, doc) + + class _FakeSignal(object): + """If blinker is unavailable, create a fake class with the same + interface that allows sending of signals but will fail with an + error on anything else. Instead of doing anything on send, it + will just ignore the arguments and do nothing instead. + """ + + def __init__(self, name, doc=None): + self.name = name + self.__doc__ = doc + + def _fail(self, *args, **kwargs): + raise RuntimeError('signalling support is unavailable ' + 'because the blinker library is ' + 'not installed.') + send = lambda *a, **kw: None + connect = disconnect = has_receivers_for = receivers_for = \ + temporarily_connected_to = _fail + del _fail + +# the namespace for code signals. If you are not mongoengine code, do +# not put signals in here. Create your own namespace instead. +_signals = Namespace() + +pre_init = _signals.signal('pre_init') +post_init = _signals.signal('post_init') +pre_save = _signals.signal('pre_save') +post_save = _signals.signal('post_save') +pre_delete = _signals.signal('pre_delete') +post_delete = _signals.signal('post_delete') diff --git a/mongoengine/tests.py b/mongoengine/tests.py new file mode 100644 index 00000000..9584bc7c --- /dev/null +++ b/mongoengine/tests.py @@ -0,0 +1,59 @@ +from mongoengine.connection import _get_db + + +class query_counter(object): + """ Query_counter contextmanager to get the number of queries. """ + + def __init__(self): + """ Construct the query_counter. """ + self.counter = 0 + self.db = _get_db() + + def __enter__(self): + """ On every with block we need to drop the profile collection. """ + self.db.set_profiling_level(0) + self.db.system.profile.drop() + self.db.set_profiling_level(2) + return self + + def __exit__(self, t, value, traceback): + """ Reset the profiling level. """ + self.db.set_profiling_level(0) + + def __eq__(self, value): + """ == Compare querycounter. """ + return value == self._get_count() + + def __ne__(self, value): + """ != Compare querycounter. """ + return not self.__eq__(value) + + def __lt__(self, value): + """ < Compare querycounter. """ + return self._get_count() < value + + def __le__(self, value): + """ <= Compare querycounter. """ + return self._get_count() <= value + + def __gt__(self, value): + """ > Compare querycounter. """ + return self._get_count() > value + + def __ge__(self, value): + """ >= Compare querycounter. """ + return self._get_count() >= value + + def __int__(self): + """ int representation. """ + return self._get_count() + + def __repr__(self): + """ repr query_counter as the number of queries. """ + return u"%s" % self._get_count() + + def _get_count(self): + """ Get the number of queries. """ + count = self.db.system.profile.find().count() - self.counter + self.counter += 1 + return count diff --git a/setup.py b/setup.py index 0c19d8d0..d3be64b3 100644 --- a/setup.py +++ b/setup.py @@ -45,6 +45,6 @@ setup(name='mongoengine', long_description=LONG_DESCRIPTION, platforms=['any'], classifiers=CLASSIFIERS, - install_requires=['pymongo'], + install_requires=['pymongo', 'blinker'], test_suite='tests', ) diff --git a/tests/dereference.py b/tests/dereference.py new file mode 100644 index 00000000..b6cee89e --- /dev/null +++ b/tests/dereference.py @@ -0,0 +1,288 @@ +import unittest + +from mongoengine import * +from mongoengine.connection import _get_db +from mongoengine.tests import query_counter + + +class FieldTest(unittest.TestCase): + + def setUp(self): + connect(db='mongoenginetest') + self.db = _get_db() + + def test_list_item_dereference(self): + """Ensure that DBRef items in ListFields are dereferenced. + """ + class User(Document): + name = StringField() + + class Group(Document): + members = ListField(ReferenceField(User)) + + User.drop_collection() + Group.drop_collection() + + for i in xrange(1, 51): + user = User(name='user %s' % i) + user.save() + + group = Group(members=User.objects) + group.save() + + with query_counter() as q: + self.assertEqual(q, 0) + + group_obj = Group.objects.first() + self.assertEqual(q, 1) + + [m for m in group_obj.members] + self.assertEqual(q, 2) + + User.drop_collection() + Group.drop_collection() + + def test_recursive_reference(self): + """Ensure that ReferenceFields can reference their own documents. + """ + class Employee(Document): + name = StringField() + boss = ReferenceField('self') + friends = ListField(ReferenceField('self')) + + bill = Employee(name='Bill Lumbergh') + bill.save() + + michael = Employee(name='Michael Bolton') + michael.save() + + samir = Employee(name='Samir Nagheenanajar') + samir.save() + + friends = [michael, samir] + peter = Employee(name='Peter Gibbons', boss=bill, friends=friends) + peter.save() + + with query_counter() as q: + self.assertEqual(q, 0) + + peter = Employee.objects.with_id(peter.id) + self.assertEqual(q, 1) + + peter.boss + self.assertEqual(q, 2) + + peter.friends + self.assertEqual(q, 3) + + def test_generic_reference(self): + + class UserA(Document): + name = StringField() + + class UserB(Document): + name = StringField() + + class UserC(Document): + name = StringField() + + class Group(Document): + members = ListField(GenericReferenceField()) + + UserA.drop_collection() + UserB.drop_collection() + UserC.drop_collection() + Group.drop_collection() + + members = [] + for i in xrange(1, 51): + a = UserA(name='User A %s' % i) + a.save() + + b = UserB(name='User B %s' % i) + b.save() + + c = UserC(name='User C %s' % i) + c.save() + + members += [a, b, c] + + group = Group(members=members) + group.save() + + with query_counter() as q: + self.assertEqual(q, 0) + + group_obj = Group.objects.first() + self.assertEqual(q, 1) + + [m for m in group_obj.members] + self.assertEqual(q, 4) + + [m for m in group_obj.members] + self.assertEqual(q, 4) + + UserA.drop_collection() + UserB.drop_collection() + UserC.drop_collection() + Group.drop_collection() + + def test_map_field_reference(self): + + class User(Document): + name = StringField() + + class Group(Document): + members = MapField(ReferenceField(User)) + + User.drop_collection() + Group.drop_collection() + + members = [] + for i in xrange(1, 51): + user = User(name='user %s' % i) + user.save() + members.append(user) + + group = Group(members=dict([(str(u.id), u) for u in members])) + group.save() + + with query_counter() as q: + self.assertEqual(q, 0) + + group_obj = Group.objects.first() + self.assertEqual(q, 1) + + [m for m in group_obj.members] + self.assertEqual(q, 2) + + User.drop_collection() + Group.drop_collection() + + def ztest_generic_reference_dict_field(self): + + class UserA(Document): + name = StringField() + + class UserB(Document): + name = StringField() + + class UserC(Document): + name = StringField() + + class Group(Document): + members = DictField() + + UserA.drop_collection() + UserB.drop_collection() + UserC.drop_collection() + Group.drop_collection() + + members = [] + for i in xrange(1, 51): + a = UserA(name='User A %s' % i) + a.save() + + b = UserB(name='User B %s' % i) + b.save() + + c = UserC(name='User C %s' % i) + c.save() + + members += [a, b, c] + + group = Group(members=dict([(str(u.id), u) for u in members])) + group.save() + + with query_counter() as q: + self.assertEqual(q, 0) + + group_obj = Group.objects.first() + self.assertEqual(q, 1) + + [m for m in group_obj.members] + self.assertEqual(q, 4) + + [m for m in group_obj.members] + self.assertEqual(q, 4) + + group.members = {} + group.save() + + with query_counter() as q: + self.assertEqual(q, 0) + + group_obj = Group.objects.first() + self.assertEqual(q, 1) + + [m for m in group_obj.members] + self.assertEqual(q, 1) + + UserA.drop_collection() + UserB.drop_collection() + UserC.drop_collection() + Group.drop_collection() + + def test_generic_reference_map_field(self): + + class UserA(Document): + name = StringField() + + class UserB(Document): + name = StringField() + + class UserC(Document): + name = StringField() + + class Group(Document): + members = MapField(GenericReferenceField()) + + UserA.drop_collection() + UserB.drop_collection() + UserC.drop_collection() + Group.drop_collection() + + members = [] + for i in xrange(1, 51): + a = UserA(name='User A %s' % i) + a.save() + + b = UserB(name='User B %s' % i) + b.save() + + c = UserC(name='User C %s' % i) + c.save() + + members += [a, b, c] + + group = Group(members=dict([(str(u.id), u) for u in members])) + group.save() + + with query_counter() as q: + self.assertEqual(q, 0) + + group_obj = Group.objects.first() + self.assertEqual(q, 1) + + [m for m in group_obj.members] + self.assertEqual(q, 4) + + [m for m in group_obj.members] + self.assertEqual(q, 4) + + group.members = {} + group.save() + + with query_counter() as q: + self.assertEqual(q, 0) + + group_obj = Group.objects.first() + self.assertEqual(q, 1) + + [m for m in group_obj.members] + self.assertEqual(q, 1) + + UserA.drop_collection() + UserB.drop_collection() + UserC.drop_collection() + Group.drop_collection() \ No newline at end of file diff --git a/tests/signals.py b/tests/signals.py new file mode 100644 index 00000000..fff2d398 --- /dev/null +++ b/tests/signals.py @@ -0,0 +1,130 @@ +# -*- coding: utf-8 -*- +import unittest + +from mongoengine import * +from mongoengine import signals + +signal_output = [] + + +class SignalTests(unittest.TestCase): + """ + Testing signals before/after saving and deleting. + """ + + def get_signal_output(self, fn, *args, **kwargs): + # Flush any existing signal output + global signal_output + signal_output = [] + fn(*args, **kwargs) + return signal_output + + def setUp(self): + connect(db='mongoenginetest') + class Author(Document): + name = StringField() + + def __unicode__(self): + return self.name + + @classmethod + def pre_init(cls, instance, **kwargs): + signal_output.append('pre_init signal, %s' % cls.__name__) + signal_output.append(str(kwargs['values'])) + + @classmethod + def post_init(cls, instance, **kwargs): + signal_output.append('post_init signal, %s' % instance) + + @classmethod + def pre_save(cls, instance, **kwargs): + signal_output.append('pre_save signal, %s' % instance) + + @classmethod + def post_save(cls, instance, **kwargs): + signal_output.append('post_save signal, %s' % instance) + if 'created' in kwargs: + if kwargs['created']: + signal_output.append('Is created') + else: + signal_output.append('Is updated') + + @classmethod + def pre_delete(cls, instance, **kwargs): + signal_output.append('pre_delete signal, %s' % instance) + + @classmethod + def post_delete(cls, instance, **kwargs): + signal_output.append('post_delete signal, %s' % instance) + + self.Author = Author + + # Save up the number of connected signals so that we can check at the end + # that all the signals we register get properly unregistered + self.pre_signals = ( + len(signals.pre_init.receivers), + len(signals.post_init.receivers), + len(signals.pre_save.receivers), + len(signals.post_save.receivers), + len(signals.pre_delete.receivers), + len(signals.post_delete.receivers) + ) + + signals.pre_init.connect(Author.pre_init) + signals.post_init.connect(Author.post_init) + signals.pre_save.connect(Author.pre_save) + signals.post_save.connect(Author.post_save) + signals.pre_delete.connect(Author.pre_delete) + signals.post_delete.connect(Author.post_delete) + + def tearDown(self): + signals.pre_init.disconnect(self.Author.pre_init) + signals.post_init.disconnect(self.Author.post_init) + signals.post_delete.disconnect(self.Author.post_delete) + signals.pre_delete.disconnect(self.Author.pre_delete) + signals.post_save.disconnect(self.Author.post_save) + signals.pre_save.disconnect(self.Author.pre_save) + + # Check that all our signals got disconnected properly. + post_signals = ( + len(signals.pre_init.receivers), + len(signals.post_init.receivers), + len(signals.pre_save.receivers), + len(signals.post_save.receivers), + len(signals.pre_delete.receivers), + len(signals.post_delete.receivers) + ) + + self.assertEqual(self.pre_signals, post_signals) + + def test_model_signals(self): + """ Model saves should throw some signals. """ + + def create_author(): + a1 = self.Author(name='Bill Shakespeare') + + self.assertEqual(self.get_signal_output(create_author), [ + "pre_init signal, Author", + "{'name': 'Bill Shakespeare'}", + "post_init signal, Bill Shakespeare", + ]) + + a1 = self.Author(name='Bill Shakespeare') + self.assertEqual(self.get_signal_output(a1.save), [ + "pre_save signal, Bill Shakespeare", + "post_save signal, Bill Shakespeare", + "Is created" + ]) + + a1.reload() + a1.name='William Shakespeare' + self.assertEqual(self.get_signal_output(a1.save), [ + "pre_save signal, William Shakespeare", + "post_save signal, William Shakespeare", + "Is updated" + ]) + + self.assertEqual(self.get_signal_output(a1.delete), [ + 'pre_delete signal, William Shakespeare', + 'post_delete signal, William Shakespeare', + ]) \ No newline at end of file