From c3a88404356c40702ca0204193ede1b1bd21e0ad Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Tue, 24 May 2011 20:27:19 +0100 Subject: [PATCH 1/7] Blinker signals added --- mongoengine/base.py | 6 ++ mongoengine/document.py | 10 ++++ mongoengine/signals.py | 41 +++++++++++++ setup.py | 6 +- tests/signals.py | 130 ++++++++++++++++++++++++++++++++++++++++ 5 files changed, 190 insertions(+), 3 deletions(-) create mode 100644 mongoengine/signals.py create mode 100644 tests/signals.py diff --git a/mongoengine/base.py b/mongoengine/base.py index ffceb794..101bb73f 100644 --- a/mongoengine/base.py +++ b/mongoengine/base.py @@ -2,6 +2,8 @@ from queryset import QuerySet, QuerySetManager from queryset import DoesNotExist, MultipleObjectsReturned from queryset import DO_NOTHING +from mongoengine import signals + import sys import pymongo import pymongo.objectid @@ -382,6 +384,8 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): class BaseDocument(object): def __init__(self, **values): + signals.pre_init.send(self, values=values) + self._data = {} # Assign default values to instance for attr_name in self._fields.keys(): @@ -395,6 +399,8 @@ class BaseDocument(object): except AttributeError: pass + signals.post_init.send(self) + def validate(self): """Ensure that all fields' values are valid and that required fields are present. diff --git a/mongoengine/document.py b/mongoengine/document.py index 771b9229..b563f427 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -1,3 +1,4 @@ +from mongoengine import signals from base import (DocumentMetaclass, TopLevelDocumentMetaclass, BaseDocument, ValidationError) from queryset import OperationError @@ -75,6 +76,8 @@ class Document(BaseDocument): For example, ``save(..., w=2, fsync=True)`` will wait until at least two servers have recorded the write and will force an fsync on each server being written to. """ + signals.pre_save.send(self) + if validate: self.validate() @@ -82,6 +85,7 @@ class Document(BaseDocument): write_options = {} doc = self.to_mongo() + created = '_id' not in doc try: collection = self.__class__.objects._collection if force_insert: @@ -96,12 +100,16 @@ class Document(BaseDocument): id_field = self._meta['id_field'] self[id_field] = self._fields[id_field].to_python(object_id) + signals.post_save.send(self, created=created) + def delete(self, safe=False): """Delete the :class:`~mongoengine.Document` from the database. This will only take effect if the document has been previously saved. :param safe: check if the operation succeeded before returning """ + signals.pre_delete.send(self) + id_field = self._meta['id_field'] object_id = self._fields[id_field].to_mongo(self[id_field]) try: @@ -110,6 +118,8 @@ class Document(BaseDocument): message = u'Could not delete document (%s)' % err.message raise OperationError(message) + signals.post_delete.send(self) + @classmethod def register_delete_rule(cls, document_cls, field_name, rule): """This method registers the delete rules to apply when removing this diff --git a/mongoengine/signals.py b/mongoengine/signals.py new file mode 100644 index 00000000..4caa5530 --- /dev/null +++ b/mongoengine/signals.py @@ -0,0 +1,41 @@ +# -*- coding: utf-8 -*- + +signals_available = False +try: + from blinker import Namespace + signals_available = True +except ImportError: + class Namespace(object): + def signal(self, name, doc=None): + return _FakeSignal(name, doc) + + class _FakeSignal(object): + """If blinker is unavailable, create a fake class with the same + interface that allows sending of signals but will fail with an + error on anything else. Instead of doing anything on send, it + will just ignore the arguments and do nothing instead. + """ + + def __init__(self, name, doc=None): + self.name = name + self.__doc__ = doc + + def _fail(self, *args, **kwargs): + raise RuntimeError('signalling support is unavailable ' + 'because the blinker library is ' + 'not installed.') + send = lambda *a, **kw: None + connect = disconnect = has_receivers_for = receivers_for = \ + temporarily_connected_to = _fail + del _fail + +# the namespace for code signals. If you are not mongoengine code, do +# not put signals in here. Create your own namespace instead. +_signals = Namespace() + +pre_init = _signals.signal('pre_init') +post_init = _signals.signal('post_init') +pre_save = _signals.signal('pre_save') +post_save = _signals.signal('post_save') +pre_delete = _signals.signal('pre_delete') +post_delete = _signals.signal('post_delete') diff --git a/setup.py b/setup.py index e0585b7c..01a201d5 100644 --- a/setup.py +++ b/setup.py @@ -15,7 +15,7 @@ def get_version(version_tuple): version = '%s.%s' % (version, version_tuple[2]) return version -# Dirty hack to get version number from monogengine/__init__.py - we can't +# Dirty hack to get version number from monogengine/__init__.py - we can't # import it as it depends on PyMongo and PyMongo isn't installed until this # file is read init = os.path.join(os.path.dirname(__file__), 'mongoengine', '__init__.py') @@ -45,6 +45,6 @@ setup(name='mongoengine', long_description=LONG_DESCRIPTION, platforms=['any'], classifiers=CLASSIFIERS, - install_requires=['pymongo'], - test_suite='tests', + install_requires=['pymongo', 'blinker'], + test_suite='tests.signals', ) diff --git a/tests/signals.py b/tests/signals.py new file mode 100644 index 00000000..fff2d398 --- /dev/null +++ b/tests/signals.py @@ -0,0 +1,130 @@ +# -*- coding: utf-8 -*- +import unittest + +from mongoengine import * +from mongoengine import signals + +signal_output = [] + + +class SignalTests(unittest.TestCase): + """ + Testing signals before/after saving and deleting. + """ + + def get_signal_output(self, fn, *args, **kwargs): + # Flush any existing signal output + global signal_output + signal_output = [] + fn(*args, **kwargs) + return signal_output + + def setUp(self): + connect(db='mongoenginetest') + class Author(Document): + name = StringField() + + def __unicode__(self): + return self.name + + @classmethod + def pre_init(cls, instance, **kwargs): + signal_output.append('pre_init signal, %s' % cls.__name__) + signal_output.append(str(kwargs['values'])) + + @classmethod + def post_init(cls, instance, **kwargs): + signal_output.append('post_init signal, %s' % instance) + + @classmethod + def pre_save(cls, instance, **kwargs): + signal_output.append('pre_save signal, %s' % instance) + + @classmethod + def post_save(cls, instance, **kwargs): + signal_output.append('post_save signal, %s' % instance) + if 'created' in kwargs: + if kwargs['created']: + signal_output.append('Is created') + else: + signal_output.append('Is updated') + + @classmethod + def pre_delete(cls, instance, **kwargs): + signal_output.append('pre_delete signal, %s' % instance) + + @classmethod + def post_delete(cls, instance, **kwargs): + signal_output.append('post_delete signal, %s' % instance) + + self.Author = Author + + # Save up the number of connected signals so that we can check at the end + # that all the signals we register get properly unregistered + self.pre_signals = ( + len(signals.pre_init.receivers), + len(signals.post_init.receivers), + len(signals.pre_save.receivers), + len(signals.post_save.receivers), + len(signals.pre_delete.receivers), + len(signals.post_delete.receivers) + ) + + signals.pre_init.connect(Author.pre_init) + signals.post_init.connect(Author.post_init) + signals.pre_save.connect(Author.pre_save) + signals.post_save.connect(Author.post_save) + signals.pre_delete.connect(Author.pre_delete) + signals.post_delete.connect(Author.post_delete) + + def tearDown(self): + signals.pre_init.disconnect(self.Author.pre_init) + signals.post_init.disconnect(self.Author.post_init) + signals.post_delete.disconnect(self.Author.post_delete) + signals.pre_delete.disconnect(self.Author.pre_delete) + signals.post_save.disconnect(self.Author.post_save) + signals.pre_save.disconnect(self.Author.pre_save) + + # Check that all our signals got disconnected properly. + post_signals = ( + len(signals.pre_init.receivers), + len(signals.post_init.receivers), + len(signals.pre_save.receivers), + len(signals.post_save.receivers), + len(signals.pre_delete.receivers), + len(signals.post_delete.receivers) + ) + + self.assertEqual(self.pre_signals, post_signals) + + def test_model_signals(self): + """ Model saves should throw some signals. """ + + def create_author(): + a1 = self.Author(name='Bill Shakespeare') + + self.assertEqual(self.get_signal_output(create_author), [ + "pre_init signal, Author", + "{'name': 'Bill Shakespeare'}", + "post_init signal, Bill Shakespeare", + ]) + + a1 = self.Author(name='Bill Shakespeare') + self.assertEqual(self.get_signal_output(a1.save), [ + "pre_save signal, Bill Shakespeare", + "post_save signal, Bill Shakespeare", + "Is created" + ]) + + a1.reload() + a1.name='William Shakespeare' + self.assertEqual(self.get_signal_output(a1.save), [ + "pre_save signal, William Shakespeare", + "post_save signal, William Shakespeare", + "Is updated" + ]) + + self.assertEqual(self.get_signal_output(a1.delete), [ + 'pre_delete signal, William Shakespeare', + 'post_delete signal, William Shakespeare', + ]) \ No newline at end of file From 0708d1bedc53d933750e2b871a1c16626627b1f7 Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Wed, 25 May 2011 09:34:50 +0100 Subject: [PATCH 2/7] Run all tests... --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 01a201d5..d3be64b3 100644 --- a/setup.py +++ b/setup.py @@ -46,5 +46,5 @@ setup(name='mongoengine', platforms=['any'], classifiers=CLASSIFIERS, install_requires=['pymongo', 'blinker'], - test_suite='tests.signals', + test_suite='tests', ) From 5d778648e697651ca681d5347051e6071cfe8487 Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Fri, 27 May 2011 11:33:40 +0100 Subject: [PATCH 3/7] Inital tests for dereferencing improvements --- mongoengine/base.py | 1 + mongoengine/fields.py | 215 +++++++++++++++++++++++-------- mongoengine/tests.py | 58 +++++++++ tests/dereference.py | 288 ++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 512 insertions(+), 50 deletions(-) create mode 100644 mongoengine/tests.py create mode 100644 tests/dereference.py diff --git a/mongoengine/base.py b/mongoengine/base.py index ffceb794..4e3154fd 100644 --- a/mongoengine/base.py +++ b/mongoengine/base.py @@ -126,6 +126,7 @@ class BaseField(object): self.validate(value) + class ObjectIdField(BaseField): """An field wrapper around MongoDB's ObjectIds. """ diff --git a/mongoengine/fields.py b/mongoengine/fields.py index b2aab5a4..c21829c9 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -337,33 +337,54 @@ class ListField(BaseField): # Document class being used rather than a document object return self - if isinstance(self.field, ReferenceField): - referenced_type = self.field.document_type - # Get value from document instance if available - value_list = instance._data.get(self.name) - if value_list: - deref_list = [] - for value in value_list: - # Dereference DBRefs - if isinstance(value, (pymongo.dbref.DBRef)): - value = _get_db().dereference(value) - deref_list.append(referenced_type._from_son(value)) - else: - deref_list.append(value) - instance._data[self.name] = deref_list + # Get value from document instance if available + value_list = instance._data.get(self.name) + if isinstance(self.field, ReferenceField) and value_list: + db = _get_db() + value_list = [(k,v) for k,v in enumerate(value_list)] + deref_list = [] + collections = {} - if isinstance(self.field, GenericReferenceField): - value_list = instance._data.get(self.name) - if value_list: - deref_list = [] - for value in value_list: - # Dereference DBRefs - if isinstance(value, (dict, pymongo.son.SON)): - deref_list.append(self.field.dereference(value)) - else: - deref_list.append(value) - instance._data[self.name] = deref_list + for k, v in value_list: + deref_list.append(v) + # Save any DBRefs + if isinstance(v, (pymongo.dbref.DBRef)): + collections.setdefault(v.collection, []).append((k, v)) + # For each collection get the references + for collection, dbrefs in collections.items(): + id_map = dict([(v.id, k) for k, v in dbrefs]) + references = db[collection].find({'_id': {'$in': id_map.keys()}}) + for ref in references: + key = id_map[ref['_id']] + deref_list[key] = get_document(ref['_cls'])._from_son(ref) + instance._data[self.name] = deref_list + + # Get value from document instance if available + if isinstance(self.field, GenericReferenceField) and value_list: + + db = _get_db() + value_list = [(k,v) for k,v in enumerate(value_list)] + deref_list = [] + classes = {} + + for k, v in value_list: + deref_list.append(v) + # Save any DBRefs + if isinstance(v, (dict, pymongo.son.SON)): + classes.setdefault(v['_cls'], []).append((k, v)) + + # For each collection get the references + for doc_cls, dbrefs in classes.items(): + id_map = dict([(v['_ref'].id, k) for k, v in dbrefs]) + doc_cls = get_document(doc_cls) + collection = doc_cls._meta['collection'] + references = db[collection].find({'_id': {'$in': id_map.keys()}}) + + for ref in references: + key = id_map[ref['_id']] + deref_list[key] = doc_cls._from_son(ref) + instance._data[self.name] = deref_list return super(ListField, self).__get__(instance, owner) def to_python(self, value): @@ -501,32 +522,53 @@ class MapField(BaseField): # Document class being used rather than a document object return self - if isinstance(self.field, ReferenceField): - referenced_type = self.field.document_type - # Get value from document instance if available - value_dict = instance._data.get(self.name) - if value_dict: - deref_dict = [] - for key,value in value_dict.iteritems(): - # Dereference DBRefs - if isinstance(value, (pymongo.dbref.DBRef)): - value = _get_db().dereference(value) - deref_dict[key] = referenced_type._from_son(value) - else: - deref_dict[key] = value - instance._data[self.name] = deref_dict + # Get value from document instance if available + value_list = instance._data.get(self.name) + if isinstance(self.field, ReferenceField) and value_list: + db = _get_db() + deref_dict = {} + collections = {} - if isinstance(self.field, GenericReferenceField): - value_dict = instance._data.get(self.name) - if value_dict: - deref_dict = [] - for key,value in value_dict.iteritems(): - # Dereference DBRefs - if isinstance(value, (dict, pymongo.son.SON)): - deref_dict[key] = self.field.dereference(value) - else: - deref_dict[key] = value - instance._data[self.name] = deref_dict + for k, v in value_list.items(): + deref_dict[k] = v + # Save any DBRefs + if isinstance(v, (pymongo.dbref.DBRef)): + collections.setdefault(v.collection, []).append((k, v)) + + # For each collection get the references + for collection, dbrefs in collections.items(): + id_map = dict([(v.id, k) for k, v in dbrefs]) + references = db[collection].find({'_id': {'$in': id_map.keys()}}) + for ref in references: + key = id_map[ref['_id']] + deref_dict[key] = get_document(ref['_cls'])._from_son(ref) + instance._data[self.name] = deref_dict + + # Get value from document instance if available + if isinstance(self.field, GenericReferenceField) and value_list: + + db = _get_db() + value_list = [(k,v) for k,v in value_list.items()] + deref_dict = {} + classes = {} + + for k, v in value_list: + deref_dict[k] = v + # Save any DBRefs + if isinstance(v, (dict, pymongo.son.SON)): + classes.setdefault(v['_cls'], []).append((k, v)) + + # For each collection get the references + for doc_cls, dbrefs in classes.items(): + id_map = dict([(v['_ref'].id, k) for k, v in dbrefs]) + doc_cls = get_document(doc_cls) + collection = doc_cls._meta['collection'] + references = db[collection].find({'_id': {'$in': id_map.keys()}}) + + for ref in references: + key = id_map[ref['_id']] + deref_dict[key] = doc_cls._from_son(ref) + instance._data[self.name] = deref_dict return super(MapField, self).__get__(instance, owner) @@ -869,3 +911,76 @@ class GeoPointField(BaseField): if (not isinstance(value[0], (float, int)) and not isinstance(value[1], (float, int))): raise ValidationError('Both values in point must be float or int.') + + + +class DereferenceMixin(object): + """ WORK IN PROGRESS""" + + def __get__(self, instance, owner): + """Descriptor to automatically dereference references. + """ + if instance is None: + # Document class being used rather than a document object + return self + + # Get value from document instance if available + value_list = instance._data.get(self.name) + if not value_list: + return super(MapField, self).__get__(instance, owner) + + is_dict = True + if not hasattr(value_list, 'items'): + is_dict = False + value_list = dict([(k,v) for k,v in enumerate(value_list)]) + + if isinstance(self.field, ReferenceField) and value_list: + db = _get_db() + dbref = {} + if not is_dict: + dbref = [] + collections = {} + + for k, v in value_list.items(): + dbref[k] = v + # Save any DBRefs + if isinstance(v, (pymongo.dbref.DBRef)): + collections.setdefault(v.collection, []).append((k, v)) + + # For each collection get the references + for collection, dbrefs in collections.items(): + id_map = dict([(v.id, k) for k, v in dbrefs]) + references = db[collection].find({'_id': {'$in': id_map.keys()}}) + for ref in references: + key = id_map[ref['_id']] + dbref[key] = get_document(ref['_cls'])._from_son(ref) + + instance._data[self.name] = dbref + + # Get value from document instance if available + if isinstance(self.field, GenericReferenceField) and value_list: + + db = _get_db() + value_list = [(k,v) for k,v in value_list.items()] + dbref = {} + classes = {} + + for k, v in value_list: + dbref[k] = v + # Save any DBRefs + if isinstance(v, (dict, pymongo.son.SON)): + classes.setdefault(v['_cls'], []).append((k, v)) + + # For each collection get the references + for doc_cls, dbrefs in classes.items(): + id_map = dict([(v['_ref'].id, k) for k, v in dbrefs]) + doc_cls = get_document(doc_cls) + collection = doc_cls._meta['collection'] + references = db[collection].find({'_id': {'$in': id_map.keys()}}) + + for ref in references: + key = id_map[ref['_id']] + dbref[key] = doc_cls._from_son(ref) + instance._data[self.name] = dbref + + return super(DereferenceField, self).__get__(instance, owner) \ No newline at end of file diff --git a/mongoengine/tests.py b/mongoengine/tests.py new file mode 100644 index 00000000..4932bc2c --- /dev/null +++ b/mongoengine/tests.py @@ -0,0 +1,58 @@ +from mongoengine.connection import _get_db + +class query_counter(object): + """ Query_counter contextmanager to get the number of queries. """ + + def __init__(self): + """ Construct the query_counter. """ + self.counter = 0 + self.db = _get_db() + + def __enter__(self): + """ On every with block we need to drop the profile collection. """ + self.db.set_profiling_level(0) + self.db.system.profile.drop() + self.db.set_profiling_level(2) + return self + + def __exit__(self, t, value, traceback): + """ Reset the profiling level. """ + self.db.set_profiling_level(0) + + def __eq__(self, value): + """ == Compare querycounter. """ + return value == self._get_count() + + def __ne__(self, value): + """ != Compare querycounter. """ + return not self.__eq__(value) + + def __lt__(self, value): + """ < Compare querycounter. """ + return self._get_count() < value + + def __le__(self, value): + """ <= Compare querycounter. """ + return self._get_count() <= value + + def __gt__(self, value): + """ > Compare querycounter. """ + return self._get_count() > value + + def __ge__(self, value): + """ >= Compare querycounter. """ + return self._get_count() >= value + + def __int__(self): + """ int representation. """ + return self._get_count() + + def __repr__(self): + """ repr query_counter as the number of queries. """ + return u"%s" % self._get_count() + + def _get_count(self): + """ Get the number of queries. """ + count = self.db.system.profile.find().count() - self.counter + self.counter += 1 + return count diff --git a/tests/dereference.py b/tests/dereference.py new file mode 100644 index 00000000..2764ee72 --- /dev/null +++ b/tests/dereference.py @@ -0,0 +1,288 @@ +import unittest + +from mongoengine import * +from mongoengine.connection import _get_db +from mongoengine.tests import query_counter + + +class FieldTest(unittest.TestCase): + + def setUp(self): + connect(db='mongoenginetest') + self.db = _get_db() + + def ztest_list_item_dereference(self): + """Ensure that DBRef items in ListFields are dereferenced. + """ + class User(Document): + name = StringField() + + class Group(Document): + members = ListField(ReferenceField(User)) + + User.drop_collection() + Group.drop_collection() + + for i in xrange(1, 51): + user = User(name='user %s' % i) + user.save() + + group = Group(members=User.objects) + group.save() + + with query_counter() as q: + self.assertEqual(q, 0) + + group_obj = Group.objects.first() + self.assertEqual(q, 1) + + [m for m in group_obj.members] + self.assertEqual(q, 2) + + User.drop_collection() + Group.drop_collection() + + def ztest_recursive_reference(self): + """Ensure that ReferenceFields can reference their own documents. + """ + class Employee(Document): + name = StringField() + boss = ReferenceField('self') + friends = ListField(ReferenceField('self')) + + bill = Employee(name='Bill Lumbergh') + bill.save() + + michael = Employee(name='Michael Bolton') + michael.save() + + samir = Employee(name='Samir Nagheenanajar') + samir.save() + + friends = [michael, samir] + peter = Employee(name='Peter Gibbons', boss=bill, friends=friends) + peter.save() + + with query_counter() as q: + self.assertEqual(q, 0) + + peter = Employee.objects.with_id(peter.id) + self.assertEqual(q, 1) + + peter.boss + self.assertEqual(q, 2) + + peter.friends + self.assertEqual(q, 3) + + def ztest_generic_reference(self): + + class UserA(Document): + name = StringField() + + class UserB(Document): + name = StringField() + + class UserC(Document): + name = StringField() + + class Group(Document): + members = ListField(GenericReferenceField()) + + UserA.drop_collection() + UserB.drop_collection() + UserC.drop_collection() + Group.drop_collection() + + members = [] + for i in xrange(1, 51): + a = UserA(name='User A %s' % i) + a.save() + + b = UserB(name='User B %s' % i) + b.save() + + c = UserC(name='User C %s' % i) + c.save() + + members += [a, b, c] + + group = Group(members=members) + group.save() + + with query_counter() as q: + self.assertEqual(q, 0) + + group_obj = Group.objects.first() + self.assertEqual(q, 1) + + [m for m in group_obj.members] + self.assertEqual(q, 4) + + [m for m in group_obj.members] + self.assertEqual(q, 4) + + UserA.drop_collection() + UserB.drop_collection() + UserC.drop_collection() + Group.drop_collection() + + def test_map_field_reference(self): + + class User(Document): + name = StringField() + + class Group(Document): + members = MapField(ReferenceField(User)) + + User.drop_collection() + Group.drop_collection() + + members = [] + for i in xrange(1, 51): + user = User(name='user %s' % i) + user.save() + members.append(user) + + group = Group(members=dict([(str(u.id), u) for u in members])) + group.save() + + with query_counter() as q: + self.assertEqual(q, 0) + + group_obj = Group.objects.first() + self.assertEqual(q, 1) + + [m for m in group_obj.members] + self.assertEqual(q, 2) + + User.drop_collection() + Group.drop_collection() + + def ztest_generic_reference_dict_field(self): + + class UserA(Document): + name = StringField() + + class UserB(Document): + name = StringField() + + class UserC(Document): + name = StringField() + + class Group(Document): + members = DictField() + + UserA.drop_collection() + UserB.drop_collection() + UserC.drop_collection() + Group.drop_collection() + + members = [] + for i in xrange(1, 51): + a = UserA(name='User A %s' % i) + a.save() + + b = UserB(name='User B %s' % i) + b.save() + + c = UserC(name='User C %s' % i) + c.save() + + members += [a, b, c] + + group = Group(members=dict([(str(u.id), u) for u in members])) + group.save() + + with query_counter() as q: + self.assertEqual(q, 0) + + group_obj = Group.objects.first() + self.assertEqual(q, 1) + + [m for m in group_obj.members] + self.assertEqual(q, 4) + + [m for m in group_obj.members] + self.assertEqual(q, 4) + + group.members = {} + group.save() + + with query_counter() as q: + self.assertEqual(q, 0) + + group_obj = Group.objects.first() + self.assertEqual(q, 1) + + [m for m in group_obj.members] + self.assertEqual(q, 1) + + UserA.drop_collection() + UserB.drop_collection() + UserC.drop_collection() + Group.drop_collection() + + def test_generic_reference_map_field(self): + + class UserA(Document): + name = StringField() + + class UserB(Document): + name = StringField() + + class UserC(Document): + name = StringField() + + class Group(Document): + members = MapField(GenericReferenceField()) + + UserA.drop_collection() + UserB.drop_collection() + UserC.drop_collection() + Group.drop_collection() + + members = [] + for i in xrange(1, 51): + a = UserA(name='User A %s' % i) + a.save() + + b = UserB(name='User B %s' % i) + b.save() + + c = UserC(name='User C %s' % i) + c.save() + + members += [a, b, c] + + group = Group(members=dict([(str(u.id), u) for u in members])) + group.save() + + with query_counter() as q: + self.assertEqual(q, 0) + + group_obj = Group.objects.first() + self.assertEqual(q, 1) + + [m for m in group_obj.members] + self.assertEqual(q, 4) + + [m for m in group_obj.members] + self.assertEqual(q, 4) + + group.members = {} + group.save() + + with query_counter() as q: + self.assertEqual(q, 0) + + group_obj = Group.objects.first() + self.assertEqual(q, 1) + + [m for m in group_obj.members] + self.assertEqual(q, 1) + + UserA.drop_collection() + UserB.drop_collection() + UserC.drop_collection() + Group.drop_collection() \ No newline at end of file From ec7effa0ef8c3a71d1f8dd0695639f60763b9858 Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Mon, 6 Jun 2011 11:04:06 +0100 Subject: [PATCH 4/7] Added DereferenceBaseField class Handles the lazy dereferencing of all items in a list / dict. Improves query efficiency by an order of magnitude. --- mongoengine/base.py | 82 ++++++++++++++++ mongoengine/fields.py | 223 ++++-------------------------------------- tests/dereference.py | 6 +- 3 files changed, 104 insertions(+), 207 deletions(-) diff --git a/mongoengine/base.py b/mongoengine/base.py index 4e3154fd..ce61547e 100644 --- a/mongoengine/base.py +++ b/mongoengine/base.py @@ -5,6 +5,7 @@ from queryset import DO_NOTHING import sys import pymongo import pymongo.objectid +from operator import itemgetter class NotRegistered(Exception): @@ -127,6 +128,87 @@ class BaseField(object): self.validate(value) +class DereferenceBaseField(BaseField): + """Handles the lazy dereferencing of a queryset. Will dereference all + items in a list / dict rather than one at a time. + """ + + def __get__(self, instance, owner): + """Descriptor to automatically dereference references. + """ + from fields import ReferenceField, GenericReferenceField + from connection import _get_db + + if instance is None: + # Document class being used rather than a document object + return self + + # Get value from document instance if available + value_list = instance._data.get(self.name) + if not value_list: + return super(DereferenceBaseField, self).__get__(instance, owner) + + is_list = False + if not hasattr(value_list, 'items'): + is_list = True + value_list = dict([(k,v) for k,v in enumerate(value_list)]) + + if isinstance(self.field, ReferenceField) and value_list: + db = _get_db() + dbref = {} + collections = {} + + for k, v in value_list.items(): + dbref[k] = v + # Save any DBRefs + if isinstance(v, (pymongo.dbref.DBRef)): + collections.setdefault(v.collection, []).append((k, v)) + + # For each collection get the references + for collection, dbrefs in collections.items(): + id_map = dict([(v.id, k) for k, v in dbrefs]) + references = db[collection].find({'_id': {'$in': id_map.keys()}}) + for ref in references: + key = id_map[ref['_id']] + dbref[key] = get_document(ref['_cls'])._from_son(ref) + + if is_list: + dbref = [v for k,v in sorted(dbref.items(), key=itemgetter(0))] + instance._data[self.name] = dbref + + # Get value from document instance if available + if isinstance(self.field, GenericReferenceField) and value_list: + db = _get_db() + value_list = [(k,v) for k,v in value_list.items()] + dbref = {} + classes = {} + + for k, v in value_list: + dbref[k] = v + # Save any DBRefs + if isinstance(v, (dict, pymongo.son.SON)): + classes.setdefault(v['_cls'], []).append((k, v)) + + # For each collection get the references + for doc_cls, dbrefs in classes.items(): + id_map = dict([(v['_ref'].id, k) for k, v in dbrefs]) + doc_cls = get_document(doc_cls) + collection = doc_cls._meta['collection'] + references = db[collection].find({'_id': {'$in': id_map.keys()}}) + + for ref in references: + key = id_map[ref['_id']] + dbref[key] = doc_cls._from_son(ref) + + if is_list: + dbref = [v for k,v in sorted(dbref.items(), key=itemgetter(0))] + + instance._data[self.name] = dbref + + return super(DereferenceBaseField, self).__get__(instance, owner) + + + class ObjectIdField(BaseField): """An field wrapper around MongoDB's ObjectIds. """ diff --git a/mongoengine/fields.py b/mongoengine/fields.py index c21829c9..dc03fc05 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -1,4 +1,5 @@ -from base import BaseField, ObjectIdField, ValidationError, get_document +from base import (BaseField, DereferenceBaseField, ObjectIdField, + ValidationError, get_document) from queryset import DO_NOTHING from document import Document, EmbeddedDocument from connection import _get_db @@ -12,7 +13,6 @@ import pymongo.binary import datetime, time import decimal import gridfs -import warnings __all__ = ['StringField', 'IntField', 'FloatField', 'BooleanField', @@ -118,8 +118,8 @@ class EmailField(StringField): EMAIL_REGEX = re.compile( r"(^[-!#$%&'*+/=?^_`{}|~0-9A-Z]+(\.[-!#$%&'*+/=?^_`{}|~0-9A-Z]+)*" # dot-atom - r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]|\\[\001-011\013\014\016-\177])*"' # quoted-string - r')@(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,6}\.?$', re.IGNORECASE # domain + r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]|\\[\001-011\013\014\016-\177])*"' # quoted-string + r')@(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,6}\.?$', re.IGNORECASE # domain ) def validate(self, value): @@ -153,6 +153,7 @@ class IntField(BaseField): def prepare_query_value(self, op, value): return int(value) + class FloatField(BaseField): """An floating point number field. """ @@ -178,6 +179,7 @@ class FloatField(BaseField): def prepare_query_value(self, op, value): return float(value) + class DecimalField(BaseField): """A fixed-point decimal number field. @@ -252,21 +254,21 @@ class DateTimeField(BaseField): else: usecs = 0 kwargs = {'microsecond': usecs} - try: # Seconds are optional, so try converting seconds first. + try: # Seconds are optional, so try converting seconds first. return datetime.datetime(*time.strptime(value, '%Y-%m-%d %H:%M:%S')[:6], **kwargs) - except ValueError: - try: # Try without seconds. + try: # Try without seconds. return datetime.datetime(*time.strptime(value, '%Y-%m-%d %H:%M')[:5], **kwargs) - except ValueError: # Try without hour/minutes/seconds. + except ValueError: # Try without hour/minutes/seconds. try: return datetime.datetime(*time.strptime(value, '%Y-%m-%d')[:3], **kwargs) except ValueError: return None + class EmbeddedDocumentField(BaseField): """An embedded document field. Only valid values are subclasses of :class:`~mongoengine.EmbeddedDocument`. @@ -314,7 +316,7 @@ class EmbeddedDocumentField(BaseField): return self.to_mongo(value) -class ListField(BaseField): +class ListField(DereferenceBaseField): """A list field that wraps a standard field, allowing multiple instances of the field to be used as a list in the database. """ @@ -330,63 +332,6 @@ class ListField(BaseField): kwargs.setdefault('default', lambda: []) super(ListField, self).__init__(**kwargs) - def __get__(self, instance, owner): - """Descriptor to automatically dereference references. - """ - if instance is None: - # Document class being used rather than a document object - return self - - # Get value from document instance if available - value_list = instance._data.get(self.name) - if isinstance(self.field, ReferenceField) and value_list: - db = _get_db() - value_list = [(k,v) for k,v in enumerate(value_list)] - deref_list = [] - collections = {} - - for k, v in value_list: - deref_list.append(v) - # Save any DBRefs - if isinstance(v, (pymongo.dbref.DBRef)): - collections.setdefault(v.collection, []).append((k, v)) - - # For each collection get the references - for collection, dbrefs in collections.items(): - id_map = dict([(v.id, k) for k, v in dbrefs]) - references = db[collection].find({'_id': {'$in': id_map.keys()}}) - for ref in references: - key = id_map[ref['_id']] - deref_list[key] = get_document(ref['_cls'])._from_son(ref) - instance._data[self.name] = deref_list - - # Get value from document instance if available - if isinstance(self.field, GenericReferenceField) and value_list: - - db = _get_db() - value_list = [(k,v) for k,v in enumerate(value_list)] - deref_list = [] - classes = {} - - for k, v in value_list: - deref_list.append(v) - # Save any DBRefs - if isinstance(v, (dict, pymongo.son.SON)): - classes.setdefault(v['_cls'], []).append((k, v)) - - # For each collection get the references - for doc_cls, dbrefs in classes.items(): - id_map = dict([(v['_ref'].id, k) for k, v in dbrefs]) - doc_cls = get_document(doc_cls) - collection = doc_cls._meta['collection'] - references = db[collection].find({'_id': {'$in': id_map.keys()}}) - - for ref in references: - key = id_map[ref['_id']] - deref_list[key] = doc_cls._from_son(ref) - instance._data[self.name] = deref_list - return super(ListField, self).__get__(instance, owner) - def to_python(self, value): return [self.field.to_python(item) for item in value] @@ -480,10 +425,10 @@ class DictField(BaseField): if op in match_operators and isinstance(value, basestring): return StringField().prepare_query_value(op, value) - return super(DictField,self).prepare_query_value(op, value) + return super(DictField, self).prepare_query_value(op, value) -class MapField(BaseField): +class MapField(DereferenceBaseField): """A field that maps a name to a specified field type. Similar to a DictField, except the 'value' of each item must match the specified field type. @@ -515,68 +460,11 @@ class MapField(BaseField): except Exception, err: raise ValidationError('Invalid MapField item (%s)' % str(item)) - def __get__(self, instance, owner): - """Descriptor to automatically dereference references. - """ - if instance is None: - # Document class being used rather than a document object - return self - - # Get value from document instance if available - value_list = instance._data.get(self.name) - if isinstance(self.field, ReferenceField) and value_list: - db = _get_db() - deref_dict = {} - collections = {} - - for k, v in value_list.items(): - deref_dict[k] = v - # Save any DBRefs - if isinstance(v, (pymongo.dbref.DBRef)): - collections.setdefault(v.collection, []).append((k, v)) - - # For each collection get the references - for collection, dbrefs in collections.items(): - id_map = dict([(v.id, k) for k, v in dbrefs]) - references = db[collection].find({'_id': {'$in': id_map.keys()}}) - for ref in references: - key = id_map[ref['_id']] - deref_dict[key] = get_document(ref['_cls'])._from_son(ref) - instance._data[self.name] = deref_dict - - # Get value from document instance if available - if isinstance(self.field, GenericReferenceField) and value_list: - - db = _get_db() - value_list = [(k,v) for k,v in value_list.items()] - deref_dict = {} - classes = {} - - for k, v in value_list: - deref_dict[k] = v - # Save any DBRefs - if isinstance(v, (dict, pymongo.son.SON)): - classes.setdefault(v['_cls'], []).append((k, v)) - - # For each collection get the references - for doc_cls, dbrefs in classes.items(): - id_map = dict([(v['_ref'].id, k) for k, v in dbrefs]) - doc_cls = get_document(doc_cls) - collection = doc_cls._meta['collection'] - references = db[collection].find({'_id': {'$in': id_map.keys()}}) - - for ref in references: - key = id_map[ref['_id']] - deref_dict[key] = doc_cls._from_son(ref) - instance._data[self.name] = deref_dict - - return super(MapField, self).__get__(instance, owner) - def to_python(self, value): - return dict( [(key,self.field.to_python(item)) for key,item in value.iteritems()] ) + return dict([(key, self.field.to_python(item)) for key, item in value.iteritems()]) def to_mongo(self, value): - return dict( [(key,self.field.to_mongo(item)) for key,item in value.iteritems()] ) + return dict([(key, self.field.to_mongo(item)) for key, item in value.iteritems()]) def prepare_query_value(self, op, value): if op not in ('set', 'unset'): @@ -794,11 +682,11 @@ class GridFSProxy(object): self.newfile = self.fs.new_file(**kwargs) self.grid_id = self.newfile._id - def put(self, file, **kwargs): + def put(self, file_obj, **kwargs): if self.grid_id: raise GridFSError('This document already has a file. Either delete ' 'it or call replace to overwrite it') - self.grid_id = self.fs.put(file, **kwargs) + self.grid_id = self.fs.put(file_obj, **kwargs) def write(self, string): if self.grid_id: @@ -827,9 +715,9 @@ class GridFSProxy(object): self.grid_id = None self.gridout = None - def replace(self, file, **kwargs): + def replace(self, file_obj, **kwargs): self.delete() - self.put(file, **kwargs) + self.put(file_obj, **kwargs) def close(self): if self.newfile: @@ -911,76 +799,3 @@ class GeoPointField(BaseField): if (not isinstance(value[0], (float, int)) and not isinstance(value[1], (float, int))): raise ValidationError('Both values in point must be float or int.') - - - -class DereferenceMixin(object): - """ WORK IN PROGRESS""" - - def __get__(self, instance, owner): - """Descriptor to automatically dereference references. - """ - if instance is None: - # Document class being used rather than a document object - return self - - # Get value from document instance if available - value_list = instance._data.get(self.name) - if not value_list: - return super(MapField, self).__get__(instance, owner) - - is_dict = True - if not hasattr(value_list, 'items'): - is_dict = False - value_list = dict([(k,v) for k,v in enumerate(value_list)]) - - if isinstance(self.field, ReferenceField) and value_list: - db = _get_db() - dbref = {} - if not is_dict: - dbref = [] - collections = {} - - for k, v in value_list.items(): - dbref[k] = v - # Save any DBRefs - if isinstance(v, (pymongo.dbref.DBRef)): - collections.setdefault(v.collection, []).append((k, v)) - - # For each collection get the references - for collection, dbrefs in collections.items(): - id_map = dict([(v.id, k) for k, v in dbrefs]) - references = db[collection].find({'_id': {'$in': id_map.keys()}}) - for ref in references: - key = id_map[ref['_id']] - dbref[key] = get_document(ref['_cls'])._from_son(ref) - - instance._data[self.name] = dbref - - # Get value from document instance if available - if isinstance(self.field, GenericReferenceField) and value_list: - - db = _get_db() - value_list = [(k,v) for k,v in value_list.items()] - dbref = {} - classes = {} - - for k, v in value_list: - dbref[k] = v - # Save any DBRefs - if isinstance(v, (dict, pymongo.son.SON)): - classes.setdefault(v['_cls'], []).append((k, v)) - - # For each collection get the references - for doc_cls, dbrefs in classes.items(): - id_map = dict([(v['_ref'].id, k) for k, v in dbrefs]) - doc_cls = get_document(doc_cls) - collection = doc_cls._meta['collection'] - references = db[collection].find({'_id': {'$in': id_map.keys()}}) - - for ref in references: - key = id_map[ref['_id']] - dbref[key] = doc_cls._from_son(ref) - instance._data[self.name] = dbref - - return super(DereferenceField, self).__get__(instance, owner) \ No newline at end of file diff --git a/tests/dereference.py b/tests/dereference.py index 2764ee72..b6cee89e 100644 --- a/tests/dereference.py +++ b/tests/dereference.py @@ -11,7 +11,7 @@ class FieldTest(unittest.TestCase): connect(db='mongoenginetest') self.db = _get_db() - def ztest_list_item_dereference(self): + def test_list_item_dereference(self): """Ensure that DBRef items in ListFields are dereferenced. """ class User(Document): @@ -42,7 +42,7 @@ class FieldTest(unittest.TestCase): User.drop_collection() Group.drop_collection() - def ztest_recursive_reference(self): + def test_recursive_reference(self): """Ensure that ReferenceFields can reference their own documents. """ class Employee(Document): @@ -75,7 +75,7 @@ class FieldTest(unittest.TestCase): peter.friends self.assertEqual(q, 3) - def ztest_generic_reference(self): + def test_generic_reference(self): class UserA(Document): name = StringField() From 7312db5c252bf3c395357cba3b7254cdccd1c6c0 Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Mon, 6 Jun 2011 11:07:27 +0100 Subject: [PATCH 5/7] Updated docs / authors. Thanks @jorgebastida for the awesome query_counter test context manager. --- AUTHORS | 1 + docs/changelog.rst | 2 ++ mongoengine/tests.py | 1 + 3 files changed, 4 insertions(+) diff --git a/AUTHORS b/AUTHORS index 93fe819e..aecdcaa9 100644 --- a/AUTHORS +++ b/AUTHORS @@ -3,3 +3,4 @@ Matt Dennewitz Deepak Thukral Florian Schlachter Steve Challis +Ross Lawley diff --git a/docs/changelog.rst b/docs/changelog.rst index 686b326f..58da0d94 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -5,6 +5,8 @@ Changelog Changes in dev ============== +- Added query_counter context manager for tests +- Added DereferenceBaseField - for improved performance in field dereferencing - Added optional map_reduce method item_frequencies - Added inline_map_reduce option to map_reduce - Updated connection exception so it provides more info on the cause. diff --git a/mongoengine/tests.py b/mongoengine/tests.py index 4932bc2c..9584bc7c 100644 --- a/mongoengine/tests.py +++ b/mongoengine/tests.py @@ -1,5 +1,6 @@ from mongoengine.connection import _get_db + class query_counter(object): """ Query_counter contextmanager to get the number of queries. """ From 0e4507811611b80be6529f2376c5e3e9b4d5bdef Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Mon, 6 Jun 2011 11:34:43 +0100 Subject: [PATCH 6/7] Added Blinker signal support --- docs/changelog.rst | 1 + docs/guide/index.rst | 1 + mongoengine/__init__.py | 4 +++- mongoengine/signals.py | 3 +++ 4 files changed, 8 insertions(+), 1 deletion(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 58da0d94..659bdb4e 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -5,6 +5,7 @@ Changelog Changes in dev ============== +- Added blinker signal support - Added query_counter context manager for tests - Added DereferenceBaseField - for improved performance in field dereferencing - Added optional map_reduce method item_frequencies diff --git a/docs/guide/index.rst b/docs/guide/index.rst index aac72469..d56e7479 100644 --- a/docs/guide/index.rst +++ b/docs/guide/index.rst @@ -11,3 +11,4 @@ User Guide document-instances querying gridfs + signals diff --git a/mongoengine/__init__.py b/mongoengine/__init__.py index 6d18ffe7..de635f96 100644 --- a/mongoengine/__init__.py +++ b/mongoengine/__init__.py @@ -6,9 +6,11 @@ import connection from connection import * import queryset from queryset import * +import signals +from signals import * __all__ = (document.__all__ + fields.__all__ + connection.__all__ + - queryset.__all__) + queryset.__all__ + signals.__all__) __author__ = 'Harry Marr' diff --git a/mongoengine/signals.py b/mongoengine/signals.py index 4caa5530..0a697534 100644 --- a/mongoengine/signals.py +++ b/mongoengine/signals.py @@ -1,5 +1,8 @@ # -*- coding: utf-8 -*- +__all__ = ['pre_init', 'post_init', 'pre_save', 'post_save', + 'pre_delete', 'post_delete'] + signals_available = False try: from blinker import Namespace From 74b5043ef9441938c6668af6eb510adccc8e531a Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Mon, 6 Jun 2011 11:39:58 +0100 Subject: [PATCH 7/7] Added signals documentation --- docs/guide/signals.rst | 49 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) create mode 100644 docs/guide/signals.rst diff --git a/docs/guide/signals.rst b/docs/guide/signals.rst new file mode 100644 index 00000000..d80a421b --- /dev/null +++ b/docs/guide/signals.rst @@ -0,0 +1,49 @@ +.. _signals: + +Signals +======= + +.. versionadded:: 0.5 + +Signal support is provided by the excellent `blinker`_ library and +will gracefully fall back if it is not available. + + +The following document signals exist in MongoEngine and are pretty self explaintary: + + * `mongoengine.signals.pre_init` + * `mongoengine.signals.post_init` + * `mongoengine.signals.pre_save` + * `mongoengine.signals.post_save` + * `mongoengine.signals.pre_delete` + * `mongoengine.signals.post_delete` + +Example usage:: + + from mongoengine import * + from mongoengine import signals + + class Author(Document): + name = StringField() + + def __unicode__(self): + return self.name + + @classmethod + def pre_save(cls, instance, **kwargs): + logging.debug("Pre Save: %s" % instance.name) + + @classmethod + def post_save(cls, instance, **kwargs): + logging.debug("Post Save: %s" % instance.name) + if 'created' in kwargs: + if kwargs['created']: + logging.debug("Created") + else: + logging.debug("Updated") + + signals.pre_save.connect(Author.pre_save) + signals.post_save.connect(Author.post_save) + + +.. _blinker: http://pypi.python.org/pypi/blinker \ No newline at end of file