From c3a88404356c40702ca0204193ede1b1bd21e0ad Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Tue, 24 May 2011 20:27:19 +0100 Subject: [PATCH 01/19] Blinker signals added --- mongoengine/base.py | 6 ++ mongoengine/document.py | 10 ++++ mongoengine/signals.py | 41 +++++++++++++ setup.py | 6 +- tests/signals.py | 130 ++++++++++++++++++++++++++++++++++++++++ 5 files changed, 190 insertions(+), 3 deletions(-) create mode 100644 mongoengine/signals.py create mode 100644 tests/signals.py diff --git a/mongoengine/base.py b/mongoengine/base.py index ffceb794..101bb73f 100644 --- a/mongoengine/base.py +++ b/mongoengine/base.py @@ -2,6 +2,8 @@ from queryset import QuerySet, QuerySetManager from queryset import DoesNotExist, MultipleObjectsReturned from queryset import DO_NOTHING +from mongoengine import signals + import sys import pymongo import pymongo.objectid @@ -382,6 +384,8 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): class BaseDocument(object): def __init__(self, **values): + signals.pre_init.send(self, values=values) + self._data = {} # Assign default values to instance for attr_name in self._fields.keys(): @@ -395,6 +399,8 @@ class BaseDocument(object): except AttributeError: pass + signals.post_init.send(self) + def validate(self): """Ensure that all fields' values are valid and that required fields are present. diff --git a/mongoengine/document.py b/mongoengine/document.py index 771b9229..b563f427 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -1,3 +1,4 @@ +from mongoengine import signals from base import (DocumentMetaclass, TopLevelDocumentMetaclass, BaseDocument, ValidationError) from queryset import OperationError @@ -75,6 +76,8 @@ class Document(BaseDocument): For example, ``save(..., w=2, fsync=True)`` will wait until at least two servers have recorded the write and will force an fsync on each server being written to. """ + signals.pre_save.send(self) + if validate: self.validate() @@ -82,6 +85,7 @@ class Document(BaseDocument): write_options = {} doc = self.to_mongo() + created = '_id' not in doc try: collection = self.__class__.objects._collection if force_insert: @@ -96,12 +100,16 @@ class Document(BaseDocument): id_field = self._meta['id_field'] self[id_field] = self._fields[id_field].to_python(object_id) + signals.post_save.send(self, created=created) + def delete(self, safe=False): """Delete the :class:`~mongoengine.Document` from the database. This will only take effect if the document has been previously saved. 
:param safe: check if the operation succeeded before returning """ + signals.pre_delete.send(self) + id_field = self._meta['id_field'] object_id = self._fields[id_field].to_mongo(self[id_field]) try: @@ -110,6 +118,8 @@ class Document(BaseDocument): message = u'Could not delete document (%s)' % err.message raise OperationError(message) + signals.post_delete.send(self) + @classmethod def register_delete_rule(cls, document_cls, field_name, rule): """This method registers the delete rules to apply when removing this diff --git a/mongoengine/signals.py b/mongoengine/signals.py new file mode 100644 index 00000000..4caa5530 --- /dev/null +++ b/mongoengine/signals.py @@ -0,0 +1,41 @@ +# -*- coding: utf-8 -*- + +signals_available = False +try: + from blinker import Namespace + signals_available = True +except ImportError: + class Namespace(object): + def signal(self, name, doc=None): + return _FakeSignal(name, doc) + + class _FakeSignal(object): + """If blinker is unavailable, create a fake class with the same + interface that allows sending of signals but will fail with an + error on anything else. Instead of doing anything on send, it + will just ignore the arguments and do nothing instead. + """ + + def __init__(self, name, doc=None): + self.name = name + self.__doc__ = doc + + def _fail(self, *args, **kwargs): + raise RuntimeError('signalling support is unavailable ' + 'because the blinker library is ' + 'not installed.') + send = lambda *a, **kw: None + connect = disconnect = has_receivers_for = receivers_for = \ + temporarily_connected_to = _fail + del _fail + +# the namespace for code signals. If you are not mongoengine code, do +# not put signals in here. Create your own namespace instead. +_signals = Namespace() + +pre_init = _signals.signal('pre_init') +post_init = _signals.signal('post_init') +pre_save = _signals.signal('pre_save') +post_save = _signals.signal('post_save') +pre_delete = _signals.signal('pre_delete') +post_delete = _signals.signal('post_delete') diff --git a/setup.py b/setup.py index e0585b7c..01a201d5 100644 --- a/setup.py +++ b/setup.py @@ -15,7 +15,7 @@ def get_version(version_tuple): version = '%s.%s' % (version, version_tuple[2]) return version -# Dirty hack to get version number from monogengine/__init__.py - we can't +# Dirty hack to get version number from monogengine/__init__.py - we can't # import it as it depends on PyMongo and PyMongo isn't installed until this # file is read init = os.path.join(os.path.dirname(__file__), 'mongoengine', '__init__.py') @@ -45,6 +45,6 @@ setup(name='mongoengine', long_description=LONG_DESCRIPTION, platforms=['any'], classifiers=CLASSIFIERS, - install_requires=['pymongo'], - test_suite='tests', + install_requires=['pymongo', 'blinker'], + test_suite='tests.signals', ) diff --git a/tests/signals.py b/tests/signals.py new file mode 100644 index 00000000..fff2d398 --- /dev/null +++ b/tests/signals.py @@ -0,0 +1,130 @@ +# -*- coding: utf-8 -*- +import unittest + +from mongoengine import * +from mongoengine import signals + +signal_output = [] + + +class SignalTests(unittest.TestCase): + """ + Testing signals before/after saving and deleting. 
+ """ + + def get_signal_output(self, fn, *args, **kwargs): + # Flush any existing signal output + global signal_output + signal_output = [] + fn(*args, **kwargs) + return signal_output + + def setUp(self): + connect(db='mongoenginetest') + class Author(Document): + name = StringField() + + def __unicode__(self): + return self.name + + @classmethod + def pre_init(cls, instance, **kwargs): + signal_output.append('pre_init signal, %s' % cls.__name__) + signal_output.append(str(kwargs['values'])) + + @classmethod + def post_init(cls, instance, **kwargs): + signal_output.append('post_init signal, %s' % instance) + + @classmethod + def pre_save(cls, instance, **kwargs): + signal_output.append('pre_save signal, %s' % instance) + + @classmethod + def post_save(cls, instance, **kwargs): + signal_output.append('post_save signal, %s' % instance) + if 'created' in kwargs: + if kwargs['created']: + signal_output.append('Is created') + else: + signal_output.append('Is updated') + + @classmethod + def pre_delete(cls, instance, **kwargs): + signal_output.append('pre_delete signal, %s' % instance) + + @classmethod + def post_delete(cls, instance, **kwargs): + signal_output.append('post_delete signal, %s' % instance) + + self.Author = Author + + # Save up the number of connected signals so that we can check at the end + # that all the signals we register get properly unregistered + self.pre_signals = ( + len(signals.pre_init.receivers), + len(signals.post_init.receivers), + len(signals.pre_save.receivers), + len(signals.post_save.receivers), + len(signals.pre_delete.receivers), + len(signals.post_delete.receivers) + ) + + signals.pre_init.connect(Author.pre_init) + signals.post_init.connect(Author.post_init) + signals.pre_save.connect(Author.pre_save) + signals.post_save.connect(Author.post_save) + signals.pre_delete.connect(Author.pre_delete) + signals.post_delete.connect(Author.post_delete) + + def tearDown(self): + signals.pre_init.disconnect(self.Author.pre_init) + signals.post_init.disconnect(self.Author.post_init) + signals.post_delete.disconnect(self.Author.post_delete) + signals.pre_delete.disconnect(self.Author.pre_delete) + signals.post_save.disconnect(self.Author.post_save) + signals.pre_save.disconnect(self.Author.pre_save) + + # Check that all our signals got disconnected properly. + post_signals = ( + len(signals.pre_init.receivers), + len(signals.post_init.receivers), + len(signals.pre_save.receivers), + len(signals.post_save.receivers), + len(signals.pre_delete.receivers), + len(signals.post_delete.receivers) + ) + + self.assertEqual(self.pre_signals, post_signals) + + def test_model_signals(self): + """ Model saves should throw some signals. 
""" + + def create_author(): + a1 = self.Author(name='Bill Shakespeare') + + self.assertEqual(self.get_signal_output(create_author), [ + "pre_init signal, Author", + "{'name': 'Bill Shakespeare'}", + "post_init signal, Bill Shakespeare", + ]) + + a1 = self.Author(name='Bill Shakespeare') + self.assertEqual(self.get_signal_output(a1.save), [ + "pre_save signal, Bill Shakespeare", + "post_save signal, Bill Shakespeare", + "Is created" + ]) + + a1.reload() + a1.name='William Shakespeare' + self.assertEqual(self.get_signal_output(a1.save), [ + "pre_save signal, William Shakespeare", + "post_save signal, William Shakespeare", + "Is updated" + ]) + + self.assertEqual(self.get_signal_output(a1.delete), [ + 'pre_delete signal, William Shakespeare', + 'post_delete signal, William Shakespeare', + ]) \ No newline at end of file From 0708d1bedc53d933750e2b871a1c16626627b1f7 Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Wed, 25 May 2011 09:34:50 +0100 Subject: [PATCH 02/19] Run all tests... --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 01a201d5..d3be64b3 100644 --- a/setup.py +++ b/setup.py @@ -46,5 +46,5 @@ setup(name='mongoengine', platforms=['any'], classifiers=CLASSIFIERS, install_requires=['pymongo', 'blinker'], - test_suite='tests.signals', + test_suite='tests', ) From 6f5bd7b0b90eb33760896ba907634013d404b4c8 Mon Sep 17 00:00:00 2001 From: Colin Howe Date: Thu, 26 May 2011 18:54:52 +0100 Subject: [PATCH 03/19] Test needs a connection... --- tests/queryset.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/queryset.py b/tests/queryset.py index 1f03fbd9..081ffb32 100644 --- a/tests/queryset.py +++ b/tests/queryset.py @@ -2101,6 +2101,9 @@ class QuerySetTest(unittest.TestCase): class QTest(unittest.TestCase): + def setUp(self): + connect(db='mongoenginetest') + def test_empty_q(self): """Ensure that empty Q objects won't hurt. """ From 1fa47206aa817dc4556e703de9121d69fb8b064c Mon Sep 17 00:00:00 2001 From: Colin Howe Date: Thu, 26 May 2011 19:39:41 +0100 Subject: [PATCH 04/19] Support for sparse indexes and omitting types from indexes --- mongoengine/queryset.py | 31 ++++++++++++++++++++++--------- tests/document.py | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 56 insertions(+), 9 deletions(-) diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index 17a1b0da..68afefca 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -382,15 +382,17 @@ class QuerySet(object): return self @classmethod - def _build_index_spec(cls, doc_cls, key_or_list): + def _build_index_spec(cls, doc_cls, spec): """Build a PyMongo index spec from a MongoEngine index spec. 
""" - if isinstance(key_or_list, basestring): - key_or_list = [key_or_list] + if isinstance(spec, basestring): + spec = {'fields': [spec]} + if isinstance(spec, (list, tuple)): + spec = {'fields': spec} index_list = [] use_types = doc_cls._meta.get('allow_inheritance', True) - for key in key_or_list: + for key in spec['fields']: # Get direction from + or - direction = pymongo.ASCENDING if key.startswith("-"): @@ -411,10 +413,18 @@ class QuerySet(object): use_types = False # If _types is being used, prepend it to every specified index - if doc_cls._meta.get('allow_inheritance') and use_types: + if (spec.get('types', True) and doc_cls._meta.get('allow_inheritance') + and use_types): index_list.insert(0, ('_types', 1)) - return index_list + spec['fields'] = index_list + + if spec.get('sparse', False) and len(spec['fields']) > 1: + raise ValueError( + 'Sparse indexes can only have one field in them. ' + 'See https://jira.mongodb.org/browse/SERVER-2193') + + return spec def __call__(self, q_obj=None, class_check=True, **query): """Filter the selected documents by calling the @@ -465,9 +475,12 @@ class QuerySet(object): # Ensure document-defined indexes are created if self._document._meta['indexes']: - for key_or_list in self._document._meta['indexes']: - self._collection.ensure_index(key_or_list, - background=background, **index_opts) + for spec in self._document._meta['indexes']: + opts = index_opts.copy() + opts['unique'] = spec.get('unique', False) + opts['sparse'] = spec.get('sparse', False) + self._collection.ensure_index(spec['fields'], + background=background, **opts) # If _types is being used (for polymorphism), it needs an index if '_types' in self._query: diff --git a/tests/document.py b/tests/document.py index fe67312e..a8120469 100644 --- a/tests/document.py +++ b/tests/document.py @@ -377,6 +377,40 @@ class DocumentTest(unittest.TestCase): BlogPost.drop_collection() + + def test_dictionary_indexes(self): + """Ensure that indexes are used when meta[indexes] contains dictionaries + instead of lists. + """ + class BlogPost(Document): + date = DateTimeField(db_field='addDate', default=datetime.now) + category = StringField() + tags = ListField(StringField()) + meta = { + 'indexes': [ + { 'fields': ['-date'], 'unique': True, + 'sparse': True, 'types': False }, + ], + } + + BlogPost.drop_collection() + + info = BlogPost.objects._collection.index_information() + # _id, '-date' + self.assertEqual(len(info), 3) + + # Indexes are lazy so use list() to perform query + list(BlogPost.objects) + info = BlogPost.objects._collection.index_information() + info = [(value['key'], + value.get('unique', False), + value.get('sparse', False)) + for key, value in info.iteritems()] + self.assertTrue(([('addDate', -1)], True, True) in info) + + BlogPost.drop_collection() + + def test_unique(self): """Ensure that uniqueness constraints are applied to fields. 
""" From 5d778648e697651ca681d5347051e6071cfe8487 Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Fri, 27 May 2011 11:33:40 +0100 Subject: [PATCH 05/19] Inital tests for dereferencing improvements --- mongoengine/base.py | 1 + mongoengine/fields.py | 215 +++++++++++++++++++++++-------- mongoengine/tests.py | 58 +++++++++ tests/dereference.py | 288 ++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 512 insertions(+), 50 deletions(-) create mode 100644 mongoengine/tests.py create mode 100644 tests/dereference.py diff --git a/mongoengine/base.py b/mongoengine/base.py index ffceb794..4e3154fd 100644 --- a/mongoengine/base.py +++ b/mongoengine/base.py @@ -126,6 +126,7 @@ class BaseField(object): self.validate(value) + class ObjectIdField(BaseField): """An field wrapper around MongoDB's ObjectIds. """ diff --git a/mongoengine/fields.py b/mongoengine/fields.py index b2aab5a4..c21829c9 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -337,33 +337,54 @@ class ListField(BaseField): # Document class being used rather than a document object return self - if isinstance(self.field, ReferenceField): - referenced_type = self.field.document_type - # Get value from document instance if available - value_list = instance._data.get(self.name) - if value_list: - deref_list = [] - for value in value_list: - # Dereference DBRefs - if isinstance(value, (pymongo.dbref.DBRef)): - value = _get_db().dereference(value) - deref_list.append(referenced_type._from_son(value)) - else: - deref_list.append(value) - instance._data[self.name] = deref_list + # Get value from document instance if available + value_list = instance._data.get(self.name) + if isinstance(self.field, ReferenceField) and value_list: + db = _get_db() + value_list = [(k,v) for k,v in enumerate(value_list)] + deref_list = [] + collections = {} - if isinstance(self.field, GenericReferenceField): - value_list = instance._data.get(self.name) - if value_list: - deref_list = [] - for value in value_list: - # Dereference DBRefs - if isinstance(value, (dict, pymongo.son.SON)): - deref_list.append(self.field.dereference(value)) - else: - deref_list.append(value) - instance._data[self.name] = deref_list + for k, v in value_list: + deref_list.append(v) + # Save any DBRefs + if isinstance(v, (pymongo.dbref.DBRef)): + collections.setdefault(v.collection, []).append((k, v)) + # For each collection get the references + for collection, dbrefs in collections.items(): + id_map = dict([(v.id, k) for k, v in dbrefs]) + references = db[collection].find({'_id': {'$in': id_map.keys()}}) + for ref in references: + key = id_map[ref['_id']] + deref_list[key] = get_document(ref['_cls'])._from_son(ref) + instance._data[self.name] = deref_list + + # Get value from document instance if available + if isinstance(self.field, GenericReferenceField) and value_list: + + db = _get_db() + value_list = [(k,v) for k,v in enumerate(value_list)] + deref_list = [] + classes = {} + + for k, v in value_list: + deref_list.append(v) + # Save any DBRefs + if isinstance(v, (dict, pymongo.son.SON)): + classes.setdefault(v['_cls'], []).append((k, v)) + + # For each collection get the references + for doc_cls, dbrefs in classes.items(): + id_map = dict([(v['_ref'].id, k) for k, v in dbrefs]) + doc_cls = get_document(doc_cls) + collection = doc_cls._meta['collection'] + references = db[collection].find({'_id': {'$in': id_map.keys()}}) + + for ref in references: + key = id_map[ref['_id']] + deref_list[key] = doc_cls._from_son(ref) + instance._data[self.name] = deref_list return 
super(ListField, self).__get__(instance, owner) def to_python(self, value): @@ -501,32 +522,53 @@ class MapField(BaseField): # Document class being used rather than a document object return self - if isinstance(self.field, ReferenceField): - referenced_type = self.field.document_type - # Get value from document instance if available - value_dict = instance._data.get(self.name) - if value_dict: - deref_dict = [] - for key,value in value_dict.iteritems(): - # Dereference DBRefs - if isinstance(value, (pymongo.dbref.DBRef)): - value = _get_db().dereference(value) - deref_dict[key] = referenced_type._from_son(value) - else: - deref_dict[key] = value - instance._data[self.name] = deref_dict + # Get value from document instance if available + value_list = instance._data.get(self.name) + if isinstance(self.field, ReferenceField) and value_list: + db = _get_db() + deref_dict = {} + collections = {} - if isinstance(self.field, GenericReferenceField): - value_dict = instance._data.get(self.name) - if value_dict: - deref_dict = [] - for key,value in value_dict.iteritems(): - # Dereference DBRefs - if isinstance(value, (dict, pymongo.son.SON)): - deref_dict[key] = self.field.dereference(value) - else: - deref_dict[key] = value - instance._data[self.name] = deref_dict + for k, v in value_list.items(): + deref_dict[k] = v + # Save any DBRefs + if isinstance(v, (pymongo.dbref.DBRef)): + collections.setdefault(v.collection, []).append((k, v)) + + # For each collection get the references + for collection, dbrefs in collections.items(): + id_map = dict([(v.id, k) for k, v in dbrefs]) + references = db[collection].find({'_id': {'$in': id_map.keys()}}) + for ref in references: + key = id_map[ref['_id']] + deref_dict[key] = get_document(ref['_cls'])._from_son(ref) + instance._data[self.name] = deref_dict + + # Get value from document instance if available + if isinstance(self.field, GenericReferenceField) and value_list: + + db = _get_db() + value_list = [(k,v) for k,v in value_list.items()] + deref_dict = {} + classes = {} + + for k, v in value_list: + deref_dict[k] = v + # Save any DBRefs + if isinstance(v, (dict, pymongo.son.SON)): + classes.setdefault(v['_cls'], []).append((k, v)) + + # For each collection get the references + for doc_cls, dbrefs in classes.items(): + id_map = dict([(v['_ref'].id, k) for k, v in dbrefs]) + doc_cls = get_document(doc_cls) + collection = doc_cls._meta['collection'] + references = db[collection].find({'_id': {'$in': id_map.keys()}}) + + for ref in references: + key = id_map[ref['_id']] + deref_dict[key] = doc_cls._from_son(ref) + instance._data[self.name] = deref_dict return super(MapField, self).__get__(instance, owner) @@ -869,3 +911,76 @@ class GeoPointField(BaseField): if (not isinstance(value[0], (float, int)) and not isinstance(value[1], (float, int))): raise ValidationError('Both values in point must be float or int.') + + + +class DereferenceMixin(object): + """ WORK IN PROGRESS""" + + def __get__(self, instance, owner): + """Descriptor to automatically dereference references. 
+ """ + if instance is None: + # Document class being used rather than a document object + return self + + # Get value from document instance if available + value_list = instance._data.get(self.name) + if not value_list: + return super(MapField, self).__get__(instance, owner) + + is_dict = True + if not hasattr(value_list, 'items'): + is_dict = False + value_list = dict([(k,v) for k,v in enumerate(value_list)]) + + if isinstance(self.field, ReferenceField) and value_list: + db = _get_db() + dbref = {} + if not is_dict: + dbref = [] + collections = {} + + for k, v in value_list.items(): + dbref[k] = v + # Save any DBRefs + if isinstance(v, (pymongo.dbref.DBRef)): + collections.setdefault(v.collection, []).append((k, v)) + + # For each collection get the references + for collection, dbrefs in collections.items(): + id_map = dict([(v.id, k) for k, v in dbrefs]) + references = db[collection].find({'_id': {'$in': id_map.keys()}}) + for ref in references: + key = id_map[ref['_id']] + dbref[key] = get_document(ref['_cls'])._from_son(ref) + + instance._data[self.name] = dbref + + # Get value from document instance if available + if isinstance(self.field, GenericReferenceField) and value_list: + + db = _get_db() + value_list = [(k,v) for k,v in value_list.items()] + dbref = {} + classes = {} + + for k, v in value_list: + dbref[k] = v + # Save any DBRefs + if isinstance(v, (dict, pymongo.son.SON)): + classes.setdefault(v['_cls'], []).append((k, v)) + + # For each collection get the references + for doc_cls, dbrefs in classes.items(): + id_map = dict([(v['_ref'].id, k) for k, v in dbrefs]) + doc_cls = get_document(doc_cls) + collection = doc_cls._meta['collection'] + references = db[collection].find({'_id': {'$in': id_map.keys()}}) + + for ref in references: + key = id_map[ref['_id']] + dbref[key] = doc_cls._from_son(ref) + instance._data[self.name] = dbref + + return super(DereferenceField, self).__get__(instance, owner) \ No newline at end of file diff --git a/mongoengine/tests.py b/mongoengine/tests.py new file mode 100644 index 00000000..4932bc2c --- /dev/null +++ b/mongoengine/tests.py @@ -0,0 +1,58 @@ +from mongoengine.connection import _get_db + +class query_counter(object): + """ Query_counter contextmanager to get the number of queries. """ + + def __init__(self): + """ Construct the query_counter. """ + self.counter = 0 + self.db = _get_db() + + def __enter__(self): + """ On every with block we need to drop the profile collection. """ + self.db.set_profiling_level(0) + self.db.system.profile.drop() + self.db.set_profiling_level(2) + return self + + def __exit__(self, t, value, traceback): + """ Reset the profiling level. """ + self.db.set_profiling_level(0) + + def __eq__(self, value): + """ == Compare querycounter. """ + return value == self._get_count() + + def __ne__(self, value): + """ != Compare querycounter. """ + return not self.__eq__(value) + + def __lt__(self, value): + """ < Compare querycounter. """ + return self._get_count() < value + + def __le__(self, value): + """ <= Compare querycounter. """ + return self._get_count() <= value + + def __gt__(self, value): + """ > Compare querycounter. """ + return self._get_count() > value + + def __ge__(self, value): + """ >= Compare querycounter. """ + return self._get_count() >= value + + def __int__(self): + """ int representation. """ + return self._get_count() + + def __repr__(self): + """ repr query_counter as the number of queries. """ + return u"%s" % self._get_count() + + def _get_count(self): + """ Get the number of queries. 
""" + count = self.db.system.profile.find().count() - self.counter + self.counter += 1 + return count diff --git a/tests/dereference.py b/tests/dereference.py new file mode 100644 index 00000000..2764ee72 --- /dev/null +++ b/tests/dereference.py @@ -0,0 +1,288 @@ +import unittest + +from mongoengine import * +from mongoengine.connection import _get_db +from mongoengine.tests import query_counter + + +class FieldTest(unittest.TestCase): + + def setUp(self): + connect(db='mongoenginetest') + self.db = _get_db() + + def ztest_list_item_dereference(self): + """Ensure that DBRef items in ListFields are dereferenced. + """ + class User(Document): + name = StringField() + + class Group(Document): + members = ListField(ReferenceField(User)) + + User.drop_collection() + Group.drop_collection() + + for i in xrange(1, 51): + user = User(name='user %s' % i) + user.save() + + group = Group(members=User.objects) + group.save() + + with query_counter() as q: + self.assertEqual(q, 0) + + group_obj = Group.objects.first() + self.assertEqual(q, 1) + + [m for m in group_obj.members] + self.assertEqual(q, 2) + + User.drop_collection() + Group.drop_collection() + + def ztest_recursive_reference(self): + """Ensure that ReferenceFields can reference their own documents. + """ + class Employee(Document): + name = StringField() + boss = ReferenceField('self') + friends = ListField(ReferenceField('self')) + + bill = Employee(name='Bill Lumbergh') + bill.save() + + michael = Employee(name='Michael Bolton') + michael.save() + + samir = Employee(name='Samir Nagheenanajar') + samir.save() + + friends = [michael, samir] + peter = Employee(name='Peter Gibbons', boss=bill, friends=friends) + peter.save() + + with query_counter() as q: + self.assertEqual(q, 0) + + peter = Employee.objects.with_id(peter.id) + self.assertEqual(q, 1) + + peter.boss + self.assertEqual(q, 2) + + peter.friends + self.assertEqual(q, 3) + + def ztest_generic_reference(self): + + class UserA(Document): + name = StringField() + + class UserB(Document): + name = StringField() + + class UserC(Document): + name = StringField() + + class Group(Document): + members = ListField(GenericReferenceField()) + + UserA.drop_collection() + UserB.drop_collection() + UserC.drop_collection() + Group.drop_collection() + + members = [] + for i in xrange(1, 51): + a = UserA(name='User A %s' % i) + a.save() + + b = UserB(name='User B %s' % i) + b.save() + + c = UserC(name='User C %s' % i) + c.save() + + members += [a, b, c] + + group = Group(members=members) + group.save() + + with query_counter() as q: + self.assertEqual(q, 0) + + group_obj = Group.objects.first() + self.assertEqual(q, 1) + + [m for m in group_obj.members] + self.assertEqual(q, 4) + + [m for m in group_obj.members] + self.assertEqual(q, 4) + + UserA.drop_collection() + UserB.drop_collection() + UserC.drop_collection() + Group.drop_collection() + + def test_map_field_reference(self): + + class User(Document): + name = StringField() + + class Group(Document): + members = MapField(ReferenceField(User)) + + User.drop_collection() + Group.drop_collection() + + members = [] + for i in xrange(1, 51): + user = User(name='user %s' % i) + user.save() + members.append(user) + + group = Group(members=dict([(str(u.id), u) for u in members])) + group.save() + + with query_counter() as q: + self.assertEqual(q, 0) + + group_obj = Group.objects.first() + self.assertEqual(q, 1) + + [m for m in group_obj.members] + self.assertEqual(q, 2) + + User.drop_collection() + Group.drop_collection() + + def 
ztest_generic_reference_dict_field(self): + + class UserA(Document): + name = StringField() + + class UserB(Document): + name = StringField() + + class UserC(Document): + name = StringField() + + class Group(Document): + members = DictField() + + UserA.drop_collection() + UserB.drop_collection() + UserC.drop_collection() + Group.drop_collection() + + members = [] + for i in xrange(1, 51): + a = UserA(name='User A %s' % i) + a.save() + + b = UserB(name='User B %s' % i) + b.save() + + c = UserC(name='User C %s' % i) + c.save() + + members += [a, b, c] + + group = Group(members=dict([(str(u.id), u) for u in members])) + group.save() + + with query_counter() as q: + self.assertEqual(q, 0) + + group_obj = Group.objects.first() + self.assertEqual(q, 1) + + [m for m in group_obj.members] + self.assertEqual(q, 4) + + [m for m in group_obj.members] + self.assertEqual(q, 4) + + group.members = {} + group.save() + + with query_counter() as q: + self.assertEqual(q, 0) + + group_obj = Group.objects.first() + self.assertEqual(q, 1) + + [m for m in group_obj.members] + self.assertEqual(q, 1) + + UserA.drop_collection() + UserB.drop_collection() + UserC.drop_collection() + Group.drop_collection() + + def test_generic_reference_map_field(self): + + class UserA(Document): + name = StringField() + + class UserB(Document): + name = StringField() + + class UserC(Document): + name = StringField() + + class Group(Document): + members = MapField(GenericReferenceField()) + + UserA.drop_collection() + UserB.drop_collection() + UserC.drop_collection() + Group.drop_collection() + + members = [] + for i in xrange(1, 51): + a = UserA(name='User A %s' % i) + a.save() + + b = UserB(name='User B %s' % i) + b.save() + + c = UserC(name='User C %s' % i) + c.save() + + members += [a, b, c] + + group = Group(members=dict([(str(u.id), u) for u in members])) + group.save() + + with query_counter() as q: + self.assertEqual(q, 0) + + group_obj = Group.objects.first() + self.assertEqual(q, 1) + + [m for m in group_obj.members] + self.assertEqual(q, 4) + + [m for m in group_obj.members] + self.assertEqual(q, 4) + + group.members = {} + group.save() + + with query_counter() as q: + self.assertEqual(q, 0) + + group_obj = Group.objects.first() + self.assertEqual(q, 1) + + [m for m in group_obj.members] + self.assertEqual(q, 1) + + UserA.drop_collection() + UserB.drop_collection() + UserC.drop_collection() + Group.drop_collection() \ No newline at end of file From 40df08c74c44546fd04f23f1cba4da0f5f162d0e Mon Sep 17 00:00:00 2001 From: Colin Howe Date: Sun, 29 May 2011 13:33:00 +0100 Subject: [PATCH 06/19] Fix QuerySet.ensure_index for new index specs --- mongoengine/queryset.py | 10 +++++++--- tests/queryset.py | 16 ++++++++++++++++ 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index 68afefca..2de15ed4 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -376,9 +376,13 @@ class QuerySet(object): construct a multi-field index); keys may be prefixed with a **+** or a **-** to determine the index ordering """ - index_list = QuerySet._build_index_spec(self._document, key_or_list) - self._collection.ensure_index(index_list, drop_dups=drop_dups, - background=background) + index_spec = QuerySet._build_index_spec(self._document, key_or_list) + self._collection.ensure_index( + index_spec['fields'], + drop_dups=drop_dups, + background=background, + sparse=index_spec.get('sparse', False), + unique=index_spec.get('unique', False)) return self @classmethod diff --git 
a/tests/queryset.py b/tests/queryset.py
index 081ffb32..8d046902 100644
--- a/tests/queryset.py
+++ b/tests/queryset.py
@@ -2099,6 +2099,22 @@ class QuerySetTest(unittest.TestCase):

         Number.drop_collection()

+    def test_ensure_index(self):
+        """Ensure that manual creation of indexes works.
+        """
+        class Comment(Document):
+            message = StringField()
+
+        Comment.objects.ensure_index('message')
+
+        info = Comment.objects._collection.index_information()
+        info = [(value['key'],
+                 value.get('unique', False),
+                 value.get('sparse', False))
+                for key, value in info.iteritems()]
+        self.assertTrue(([('_types', 1), ('message', 1)], False, False) in info)
+
+
 class QTest(unittest.TestCase):

     def setUp(self):

From 9a2cf206b22f7e9697b5e2d7ea47d37230f68206 Mon Sep 17 00:00:00 2001
From: Colin Howe
Date: Sun, 29 May 2011 13:38:54 +0100
Subject: [PATCH 07/19] Documentation for new-style indices

---
 docs/guide/defining-documents.rst | 22 +++++++++++++++++++---
 1 file changed, 19 insertions(+), 3 deletions(-)

diff --git a/docs/guide/defining-documents.rst b/docs/guide/defining-documents.rst
index e333674e..a524520c 100644
--- a/docs/guide/defining-documents.rst
+++ b/docs/guide/defining-documents.rst
@@ -341,9 +341,10 @@ Indexes
 You can specify indexes on collections to make querying faster. This is done
 by creating a list of index specifications called :attr:`indexes` in the
 :attr:`~mongoengine.Document.meta` dictionary, where an index specification may
-either be a single field name, or a tuple containing multiple field names. A
-direction may be specified on fields by prefixing the field name with a **+**
-or a **-** sign. Note that direction only matters on multi-field indexes. ::
+either be a single field name, a tuple containing multiple field names, or a
+dictionary containing a full index definition. A direction may be specified on
+fields by prefixing the field name with a **+** or a **-** sign. Note that
+direction only matters on multi-field indexes. ::

     class Page(Document):
         title = StringField()
@@ -352,6 +353,21 @@ or a **-** sign. Note that direction only matters on multi-field indexes. ::
             'indexes': ['title', ('title', '-rating')]
         }

+If a dictionary is passed then the following options are available:
+
+:attr:`fields` (Default: None)
+    The fields to index. Specified in the same format as described above.
+
+:attr:`types` (Default: True)
+    Whether the index should have the :attr:`_types` field added automatically
+    to the start of the index.
+
+:attr:`sparse` (Default: False)
+    Whether the index should be sparse.
+
+:attr:`unique` (Default: False)
+    Whether the index should be unique.
+
 .. note::
     Geospatial indexes will be automatically created for all
     :class:`~mongoengine.GeoPointField`\ s

From ec7effa0ef8c3a71d1f8dd0695639f60763b9858 Mon Sep 17 00:00:00 2001
From: Ross Lawley
Date: Mon, 6 Jun 2011 11:04:06 +0100
Subject: [PATCH 08/19] Added DereferenceBaseField class

Handles the lazy dereferencing of all items in a list / dict. Improves
query efficiency by an order of magnitude.
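For illustration, a rough usage sketch of the access pattern this speeds up
(the User/Group models are hypothetical, mirroring the new dereference tests);
iterating a list of references now costs one extra query in total rather than
one query per item::

    class User(Document):
        name = StringField()

    class Group(Document):
        members = ListField(ReferenceField(User))

    group = Group.objects.first()              # 1 query
    names = [m.name for m in group.members]    # 1 more query loads all members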
--- mongoengine/base.py | 82 ++++++++++++++++ mongoengine/fields.py | 223 ++++-------------------------------------- tests/dereference.py | 6 +- 3 files changed, 104 insertions(+), 207 deletions(-) diff --git a/mongoengine/base.py b/mongoengine/base.py index 4e3154fd..ce61547e 100644 --- a/mongoengine/base.py +++ b/mongoengine/base.py @@ -5,6 +5,7 @@ from queryset import DO_NOTHING import sys import pymongo import pymongo.objectid +from operator import itemgetter class NotRegistered(Exception): @@ -127,6 +128,87 @@ class BaseField(object): self.validate(value) +class DereferenceBaseField(BaseField): + """Handles the lazy dereferencing of a queryset. Will dereference all + items in a list / dict rather than one at a time. + """ + + def __get__(self, instance, owner): + """Descriptor to automatically dereference references. + """ + from fields import ReferenceField, GenericReferenceField + from connection import _get_db + + if instance is None: + # Document class being used rather than a document object + return self + + # Get value from document instance if available + value_list = instance._data.get(self.name) + if not value_list: + return super(DereferenceBaseField, self).__get__(instance, owner) + + is_list = False + if not hasattr(value_list, 'items'): + is_list = True + value_list = dict([(k,v) for k,v in enumerate(value_list)]) + + if isinstance(self.field, ReferenceField) and value_list: + db = _get_db() + dbref = {} + collections = {} + + for k, v in value_list.items(): + dbref[k] = v + # Save any DBRefs + if isinstance(v, (pymongo.dbref.DBRef)): + collections.setdefault(v.collection, []).append((k, v)) + + # For each collection get the references + for collection, dbrefs in collections.items(): + id_map = dict([(v.id, k) for k, v in dbrefs]) + references = db[collection].find({'_id': {'$in': id_map.keys()}}) + for ref in references: + key = id_map[ref['_id']] + dbref[key] = get_document(ref['_cls'])._from_son(ref) + + if is_list: + dbref = [v for k,v in sorted(dbref.items(), key=itemgetter(0))] + instance._data[self.name] = dbref + + # Get value from document instance if available + if isinstance(self.field, GenericReferenceField) and value_list: + db = _get_db() + value_list = [(k,v) for k,v in value_list.items()] + dbref = {} + classes = {} + + for k, v in value_list: + dbref[k] = v + # Save any DBRefs + if isinstance(v, (dict, pymongo.son.SON)): + classes.setdefault(v['_cls'], []).append((k, v)) + + # For each collection get the references + for doc_cls, dbrefs in classes.items(): + id_map = dict([(v['_ref'].id, k) for k, v in dbrefs]) + doc_cls = get_document(doc_cls) + collection = doc_cls._meta['collection'] + references = db[collection].find({'_id': {'$in': id_map.keys()}}) + + for ref in references: + key = id_map[ref['_id']] + dbref[key] = doc_cls._from_son(ref) + + if is_list: + dbref = [v for k,v in sorted(dbref.items(), key=itemgetter(0))] + + instance._data[self.name] = dbref + + return super(DereferenceBaseField, self).__get__(instance, owner) + + + class ObjectIdField(BaseField): """An field wrapper around MongoDB's ObjectIds. 
""" diff --git a/mongoengine/fields.py b/mongoengine/fields.py index c21829c9..dc03fc05 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -1,4 +1,5 @@ -from base import BaseField, ObjectIdField, ValidationError, get_document +from base import (BaseField, DereferenceBaseField, ObjectIdField, + ValidationError, get_document) from queryset import DO_NOTHING from document import Document, EmbeddedDocument from connection import _get_db @@ -12,7 +13,6 @@ import pymongo.binary import datetime, time import decimal import gridfs -import warnings __all__ = ['StringField', 'IntField', 'FloatField', 'BooleanField', @@ -118,8 +118,8 @@ class EmailField(StringField): EMAIL_REGEX = re.compile( r"(^[-!#$%&'*+/=?^_`{}|~0-9A-Z]+(\.[-!#$%&'*+/=?^_`{}|~0-9A-Z]+)*" # dot-atom - r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]|\\[\001-011\013\014\016-\177])*"' # quoted-string - r')@(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,6}\.?$', re.IGNORECASE # domain + r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]|\\[\001-011\013\014\016-\177])*"' # quoted-string + r')@(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,6}\.?$', re.IGNORECASE # domain ) def validate(self, value): @@ -153,6 +153,7 @@ class IntField(BaseField): def prepare_query_value(self, op, value): return int(value) + class FloatField(BaseField): """An floating point number field. """ @@ -178,6 +179,7 @@ class FloatField(BaseField): def prepare_query_value(self, op, value): return float(value) + class DecimalField(BaseField): """A fixed-point decimal number field. @@ -252,21 +254,21 @@ class DateTimeField(BaseField): else: usecs = 0 kwargs = {'microsecond': usecs} - try: # Seconds are optional, so try converting seconds first. + try: # Seconds are optional, so try converting seconds first. return datetime.datetime(*time.strptime(value, '%Y-%m-%d %H:%M:%S')[:6], **kwargs) - except ValueError: - try: # Try without seconds. + try: # Try without seconds. return datetime.datetime(*time.strptime(value, '%Y-%m-%d %H:%M')[:5], **kwargs) - except ValueError: # Try without hour/minutes/seconds. + except ValueError: # Try without hour/minutes/seconds. try: return datetime.datetime(*time.strptime(value, '%Y-%m-%d')[:3], **kwargs) except ValueError: return None + class EmbeddedDocumentField(BaseField): """An embedded document field. Only valid values are subclasses of :class:`~mongoengine.EmbeddedDocument`. @@ -314,7 +316,7 @@ class EmbeddedDocumentField(BaseField): return self.to_mongo(value) -class ListField(BaseField): +class ListField(DereferenceBaseField): """A list field that wraps a standard field, allowing multiple instances of the field to be used as a list in the database. """ @@ -330,63 +332,6 @@ class ListField(BaseField): kwargs.setdefault('default', lambda: []) super(ListField, self).__init__(**kwargs) - def __get__(self, instance, owner): - """Descriptor to automatically dereference references. 
- """ - if instance is None: - # Document class being used rather than a document object - return self - - # Get value from document instance if available - value_list = instance._data.get(self.name) - if isinstance(self.field, ReferenceField) and value_list: - db = _get_db() - value_list = [(k,v) for k,v in enumerate(value_list)] - deref_list = [] - collections = {} - - for k, v in value_list: - deref_list.append(v) - # Save any DBRefs - if isinstance(v, (pymongo.dbref.DBRef)): - collections.setdefault(v.collection, []).append((k, v)) - - # For each collection get the references - for collection, dbrefs in collections.items(): - id_map = dict([(v.id, k) for k, v in dbrefs]) - references = db[collection].find({'_id': {'$in': id_map.keys()}}) - for ref in references: - key = id_map[ref['_id']] - deref_list[key] = get_document(ref['_cls'])._from_son(ref) - instance._data[self.name] = deref_list - - # Get value from document instance if available - if isinstance(self.field, GenericReferenceField) and value_list: - - db = _get_db() - value_list = [(k,v) for k,v in enumerate(value_list)] - deref_list = [] - classes = {} - - for k, v in value_list: - deref_list.append(v) - # Save any DBRefs - if isinstance(v, (dict, pymongo.son.SON)): - classes.setdefault(v['_cls'], []).append((k, v)) - - # For each collection get the references - for doc_cls, dbrefs in classes.items(): - id_map = dict([(v['_ref'].id, k) for k, v in dbrefs]) - doc_cls = get_document(doc_cls) - collection = doc_cls._meta['collection'] - references = db[collection].find({'_id': {'$in': id_map.keys()}}) - - for ref in references: - key = id_map[ref['_id']] - deref_list[key] = doc_cls._from_son(ref) - instance._data[self.name] = deref_list - return super(ListField, self).__get__(instance, owner) - def to_python(self, value): return [self.field.to_python(item) for item in value] @@ -480,10 +425,10 @@ class DictField(BaseField): if op in match_operators and isinstance(value, basestring): return StringField().prepare_query_value(op, value) - return super(DictField,self).prepare_query_value(op, value) + return super(DictField, self).prepare_query_value(op, value) -class MapField(BaseField): +class MapField(DereferenceBaseField): """A field that maps a name to a specified field type. Similar to a DictField, except the 'value' of each item must match the specified field type. @@ -515,68 +460,11 @@ class MapField(BaseField): except Exception, err: raise ValidationError('Invalid MapField item (%s)' % str(item)) - def __get__(self, instance, owner): - """Descriptor to automatically dereference references. 
- """ - if instance is None: - # Document class being used rather than a document object - return self - - # Get value from document instance if available - value_list = instance._data.get(self.name) - if isinstance(self.field, ReferenceField) and value_list: - db = _get_db() - deref_dict = {} - collections = {} - - for k, v in value_list.items(): - deref_dict[k] = v - # Save any DBRefs - if isinstance(v, (pymongo.dbref.DBRef)): - collections.setdefault(v.collection, []).append((k, v)) - - # For each collection get the references - for collection, dbrefs in collections.items(): - id_map = dict([(v.id, k) for k, v in dbrefs]) - references = db[collection].find({'_id': {'$in': id_map.keys()}}) - for ref in references: - key = id_map[ref['_id']] - deref_dict[key] = get_document(ref['_cls'])._from_son(ref) - instance._data[self.name] = deref_dict - - # Get value from document instance if available - if isinstance(self.field, GenericReferenceField) and value_list: - - db = _get_db() - value_list = [(k,v) for k,v in value_list.items()] - deref_dict = {} - classes = {} - - for k, v in value_list: - deref_dict[k] = v - # Save any DBRefs - if isinstance(v, (dict, pymongo.son.SON)): - classes.setdefault(v['_cls'], []).append((k, v)) - - # For each collection get the references - for doc_cls, dbrefs in classes.items(): - id_map = dict([(v['_ref'].id, k) for k, v in dbrefs]) - doc_cls = get_document(doc_cls) - collection = doc_cls._meta['collection'] - references = db[collection].find({'_id': {'$in': id_map.keys()}}) - - for ref in references: - key = id_map[ref['_id']] - deref_dict[key] = doc_cls._from_son(ref) - instance._data[self.name] = deref_dict - - return super(MapField, self).__get__(instance, owner) - def to_python(self, value): - return dict( [(key,self.field.to_python(item)) for key,item in value.iteritems()] ) + return dict([(key, self.field.to_python(item)) for key, item in value.iteritems()]) def to_mongo(self, value): - return dict( [(key,self.field.to_mongo(item)) for key,item in value.iteritems()] ) + return dict([(key, self.field.to_mongo(item)) for key, item in value.iteritems()]) def prepare_query_value(self, op, value): if op not in ('set', 'unset'): @@ -794,11 +682,11 @@ class GridFSProxy(object): self.newfile = self.fs.new_file(**kwargs) self.grid_id = self.newfile._id - def put(self, file, **kwargs): + def put(self, file_obj, **kwargs): if self.grid_id: raise GridFSError('This document already has a file. Either delete ' 'it or call replace to overwrite it') - self.grid_id = self.fs.put(file, **kwargs) + self.grid_id = self.fs.put(file_obj, **kwargs) def write(self, string): if self.grid_id: @@ -827,9 +715,9 @@ class GridFSProxy(object): self.grid_id = None self.gridout = None - def replace(self, file, **kwargs): + def replace(self, file_obj, **kwargs): self.delete() - self.put(file, **kwargs) + self.put(file_obj, **kwargs) def close(self): if self.newfile: @@ -911,76 +799,3 @@ class GeoPointField(BaseField): if (not isinstance(value[0], (float, int)) and not isinstance(value[1], (float, int))): raise ValidationError('Both values in point must be float or int.') - - - -class DereferenceMixin(object): - """ WORK IN PROGRESS""" - - def __get__(self, instance, owner): - """Descriptor to automatically dereference references. 
- """ - if instance is None: - # Document class being used rather than a document object - return self - - # Get value from document instance if available - value_list = instance._data.get(self.name) - if not value_list: - return super(MapField, self).__get__(instance, owner) - - is_dict = True - if not hasattr(value_list, 'items'): - is_dict = False - value_list = dict([(k,v) for k,v in enumerate(value_list)]) - - if isinstance(self.field, ReferenceField) and value_list: - db = _get_db() - dbref = {} - if not is_dict: - dbref = [] - collections = {} - - for k, v in value_list.items(): - dbref[k] = v - # Save any DBRefs - if isinstance(v, (pymongo.dbref.DBRef)): - collections.setdefault(v.collection, []).append((k, v)) - - # For each collection get the references - for collection, dbrefs in collections.items(): - id_map = dict([(v.id, k) for k, v in dbrefs]) - references = db[collection].find({'_id': {'$in': id_map.keys()}}) - for ref in references: - key = id_map[ref['_id']] - dbref[key] = get_document(ref['_cls'])._from_son(ref) - - instance._data[self.name] = dbref - - # Get value from document instance if available - if isinstance(self.field, GenericReferenceField) and value_list: - - db = _get_db() - value_list = [(k,v) for k,v in value_list.items()] - dbref = {} - classes = {} - - for k, v in value_list: - dbref[k] = v - # Save any DBRefs - if isinstance(v, (dict, pymongo.son.SON)): - classes.setdefault(v['_cls'], []).append((k, v)) - - # For each collection get the references - for doc_cls, dbrefs in classes.items(): - id_map = dict([(v['_ref'].id, k) for k, v in dbrefs]) - doc_cls = get_document(doc_cls) - collection = doc_cls._meta['collection'] - references = db[collection].find({'_id': {'$in': id_map.keys()}}) - - for ref in references: - key = id_map[ref['_id']] - dbref[key] = doc_cls._from_son(ref) - instance._data[self.name] = dbref - - return super(DereferenceField, self).__get__(instance, owner) \ No newline at end of file diff --git a/tests/dereference.py b/tests/dereference.py index 2764ee72..b6cee89e 100644 --- a/tests/dereference.py +++ b/tests/dereference.py @@ -11,7 +11,7 @@ class FieldTest(unittest.TestCase): connect(db='mongoenginetest') self.db = _get_db() - def ztest_list_item_dereference(self): + def test_list_item_dereference(self): """Ensure that DBRef items in ListFields are dereferenced. """ class User(Document): @@ -42,7 +42,7 @@ class FieldTest(unittest.TestCase): User.drop_collection() Group.drop_collection() - def ztest_recursive_reference(self): + def test_recursive_reference(self): """Ensure that ReferenceFields can reference their own documents. """ class Employee(Document): @@ -75,7 +75,7 @@ class FieldTest(unittest.TestCase): peter.friends self.assertEqual(q, 3) - def ztest_generic_reference(self): + def test_generic_reference(self): class UserA(Document): name = StringField() From 7312db5c252bf3c395357cba3b7254cdccd1c6c0 Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Mon, 6 Jun 2011 11:07:27 +0100 Subject: [PATCH 09/19] Updated docs / authors. Thanks @jorgebastida for the awesome query_counter test context manager. 
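As a usage sketch (the Group document here is hypothetical, as in the new
dereference tests), the counter compares equal to the number of profiled
queries issued inside the block::

    from mongoengine.tests import query_counter

    with query_counter() as q:
        assert q == 0
        group = Group.objects.first()   # first() hits the database once
        assert q == 1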
--- AUTHORS | 1 + docs/changelog.rst | 2 ++ mongoengine/tests.py | 1 + 3 files changed, 4 insertions(+) diff --git a/AUTHORS b/AUTHORS index 93fe819e..aecdcaa9 100644 --- a/AUTHORS +++ b/AUTHORS @@ -3,3 +3,4 @@ Matt Dennewitz Deepak Thukral Florian Schlachter Steve Challis +Ross Lawley diff --git a/docs/changelog.rst b/docs/changelog.rst index 686b326f..58da0d94 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -5,6 +5,8 @@ Changelog Changes in dev ============== +- Added query_counter context manager for tests +- Added DereferenceBaseField - for improved performance in field dereferencing - Added optional map_reduce method item_frequencies - Added inline_map_reduce option to map_reduce - Updated connection exception so it provides more info on the cause. diff --git a/mongoengine/tests.py b/mongoengine/tests.py index 4932bc2c..9584bc7c 100644 --- a/mongoengine/tests.py +++ b/mongoengine/tests.py @@ -1,5 +1,6 @@ from mongoengine.connection import _get_db + class query_counter(object): """ Query_counter contextmanager to get the number of queries. """ From 0e4507811611b80be6529f2376c5e3e9b4d5bdef Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Mon, 6 Jun 2011 11:34:43 +0100 Subject: [PATCH 10/19] Added Blinker signal support --- docs/changelog.rst | 1 + docs/guide/index.rst | 1 + mongoengine/__init__.py | 4 +++- mongoengine/signals.py | 3 +++ 4 files changed, 8 insertions(+), 1 deletion(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 58da0d94..659bdb4e 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -5,6 +5,7 @@ Changelog Changes in dev ============== +- Added blinker signal support - Added query_counter context manager for tests - Added DereferenceBaseField - for improved performance in field dereferencing - Added optional map_reduce method item_frequencies diff --git a/docs/guide/index.rst b/docs/guide/index.rst index aac72469..d56e7479 100644 --- a/docs/guide/index.rst +++ b/docs/guide/index.rst @@ -11,3 +11,4 @@ User Guide document-instances querying gridfs + signals diff --git a/mongoengine/__init__.py b/mongoengine/__init__.py index 6d18ffe7..de635f96 100644 --- a/mongoengine/__init__.py +++ b/mongoengine/__init__.py @@ -6,9 +6,11 @@ import connection from connection import * import queryset from queryset import * +import signals +from signals import * __all__ = (document.__all__ + fields.__all__ + connection.__all__ + - queryset.__all__) + queryset.__all__ + signals.__all__) __author__ = 'Harry Marr' diff --git a/mongoengine/signals.py b/mongoengine/signals.py index 4caa5530..0a697534 100644 --- a/mongoengine/signals.py +++ b/mongoengine/signals.py @@ -1,5 +1,8 @@ # -*- coding: utf-8 -*- +__all__ = ['pre_init', 'post_init', 'pre_save', 'post_save', + 'pre_delete', 'post_delete'] + signals_available = False try: from blinker import Namespace From 74b5043ef9441938c6668af6eb510adccc8e531a Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Mon, 6 Jun 2011 11:39:58 +0100 Subject: [PATCH 11/19] Added signals documentation --- docs/guide/signals.rst | 49 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) create mode 100644 docs/guide/signals.rst diff --git a/docs/guide/signals.rst b/docs/guide/signals.rst new file mode 100644 index 00000000..d80a421b --- /dev/null +++ b/docs/guide/signals.rst @@ -0,0 +1,49 @@ +.. _signals: + +Signals +======= + +.. versionadded:: 0.5 + +Signal support is provided by the excellent `blinker`_ library and +will gracefully fall back if it is not available. 
+
+
+The following document signals exist in MongoEngine and are pretty self-explanatory:
+
+ * `mongoengine.signals.pre_init`
+ * `mongoengine.signals.post_init`
+ * `mongoengine.signals.pre_save`
+ * `mongoengine.signals.post_save`
+ * `mongoengine.signals.pre_delete`
+ * `mongoengine.signals.post_delete`
+
+Example usage::
+
+    from mongoengine import *
+    from mongoengine import signals
+
+    class Author(Document):
+        name = StringField()
+
+        def __unicode__(self):
+            return self.name
+
+        @classmethod
+        def pre_save(cls, instance, **kwargs):
+            logging.debug("Pre Save: %s" % instance.name)
+
+        @classmethod
+        def post_save(cls, instance, **kwargs):
+            logging.debug("Post Save: %s" % instance.name)
+            if 'created' in kwargs:
+                if kwargs['created']:
+                    logging.debug("Created")
+                else:
+                    logging.debug("Updated")
+
+    signals.pre_save.connect(Author.pre_save)
+    signals.post_save.connect(Author.post_save)
+
+
+.. _blinker: http://pypi.python.org/pypi/blinker
\ No newline at end of file

From 56f00a64d77655bee2d00ebd783d07655a6900ff Mon Sep 17 00:00:00 2001
From: Ross Lawley
Date: Mon, 6 Jun 2011 12:37:06 +0100
Subject: [PATCH 12/19] Added bulk insert method.

Updated changelog and added tests / query_counter tests
---
 docs/changelog.rst | 1 +
 mongoengine/queryset.py | 42 ++++++++++++++++++++-
 tests/queryset.py | 83 +++++++++++++++++++++++++++++++++++++++++
 3 files changed, 125 insertions(+), 1 deletion(-)

diff --git a/docs/changelog.rst b/docs/changelog.rst
index 659bdb4e..29ecdf7a 100644
--- a/docs/changelog.rst
+++ b/docs/changelog.rst
@@ -5,6 +5,7 @@ Changelog
 Changes in dev
 ==============

+- Added insert method for bulk inserts
 - Added blinker signal support
 - Added query_counter context manager for tests
 - Added DereferenceBaseField - for improved performance in field dereferencing
diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py
index 2de15ed4..0e87db7a 100644
--- a/mongoengine/queryset.py
+++ b/mongoengine/queryset.py
@@ -378,7 +378,7 @@ class QuerySet(object):
         """
         index_spec = QuerySet._build_index_spec(self._document, key_or_list)
         self._collection.ensure_index(
-            index_spec['fields'], 
+            index_spec['fields'],
             drop_dups=drop_dups,
             background=background,
             sparse=index_spec.get('sparse', False),
             unique=index_spec.get('unique', False))
         return self
@@ -719,6 +719,46 @@ class QuerySet(object):
             result = None
         return result

+    def insert(self, doc_or_docs, load_bulk=True):
+        """Bulk insert documents
+
+        :param doc_or_docs: a document or list of documents to be inserted
+        :param load_bulk (optional): If True, returns the list of document instances
+
+        By default returns document instances, set ``load_bulk`` to False to
+        return just ``ObjectIds``
+
+        .. versionadded:: 0.5
+        """
+        from document import Document
+
+        docs = doc_or_docs
+        return_one = False
+        if isinstance(docs, Document) or issubclass(docs.__class__, Document):
+            return_one = True
+            docs = [docs]
+
+        raw = []
+        for doc in docs:
+            if not isinstance(doc, self._document):
+                msg = "Some documents inserted aren't instances of %s" % str(self._document)
+                raise OperationError(msg)
+            if doc.pk:
+                msg = "Some documents have ObjectIds use doc.update() instead"
+                raise OperationError(msg)
+            raw.append(doc.to_mongo())
+
+        ids = self._collection.insert(raw)
+
+        if not load_bulk:
+            return return_one and ids[0] or ids
+
+        documents = self.in_bulk(ids)
+        results = []
+        for obj_id in ids:
+            results.append(documents.get(obj_id))
+        return return_one and results[0] or results
+
     def with_id(self, object_id):
         """Retrieve the object matching the id provided.
diff --git a/tests/queryset.py b/tests/queryset.py index 8d046902..0b64e3e9 100644 --- a/tests/queryset.py +++ b/tests/queryset.py @@ -9,6 +9,7 @@ from mongoengine.queryset import (QuerySet, QuerySetManager, MultipleObjectsReturned, DoesNotExist, QueryFieldList) from mongoengine import * +from mongoengine.tests import query_counter class QuerySetTest(unittest.TestCase): @@ -331,6 +332,88 @@ class QuerySetTest(unittest.TestCase): person = self.Person.objects.get(age=50) self.assertEqual(person.name, "User C") + def test_bulk_insert(self): + """Ensure that query by array position works. + """ + + class Comment(EmbeddedDocument): + name = StringField() + + class Post(EmbeddedDocument): + comments = ListField(EmbeddedDocumentField(Comment)) + + class Blog(Document): + title = StringField() + tags = ListField(StringField()) + posts = ListField(EmbeddedDocumentField(Post)) + + Blog.drop_collection() + + with query_counter() as q: + self.assertEqual(q, 0) + + comment1 = Comment(name='testa') + comment2 = Comment(name='testb') + post1 = Post(comments=[comment1, comment2]) + post2 = Post(comments=[comment2, comment2]) + + blogs = [] + for i in xrange(1, 100): + blogs.append(Blog(title="post %s" % i, posts=[post1, post2])) + + Blog.objects.insert(blogs, load_bulk=False) + self.assertEqual(q, 2) # 1 for the inital connection and 1 for the insert + + Blog.objects.insert(blogs) + self.assertEqual(q, 4) # 1 for insert, and 1 for in bulk + + Blog.drop_collection() + + comment1 = Comment(name='testa') + comment2 = Comment(name='testb') + post1 = Post(comments=[comment1, comment2]) + post2 = Post(comments=[comment2, comment2]) + blog1 = Blog(title="code", posts=[post1, post2]) + blog2 = Blog(title="mongodb", posts=[post2, post1]) + blog1, blog2 = Blog.objects.insert([blog1, blog2]) + self.assertEqual(blog1.title, "code") + self.assertEqual(blog2.title, "mongodb") + + self.assertEqual(Blog.objects.count(), 2) + + # test handles people trying to upsert + def throw_operation_error(): + blogs = Blog.objects + Blog.objects.insert(blogs) + + self.assertRaises(OperationError, throw_operation_error) + + # test handles other classes being inserted + def throw_operation_error_wrong_doc(): + class Author(Document): + pass + Blog.objects.insert(Author()) + + self.assertRaises(OperationError, throw_operation_error_wrong_doc) + + def throw_operation_error_not_a_document(): + Blog.objects.insert("HELLO WORLD") + + self.assertRaises(OperationError, throw_operation_error_not_a_document) + + Blog.drop_collection() + + blog1 = Blog(title="code", posts=[post1, post2]) + blog1 = Blog.objects.insert(blog1) + self.assertEqual(blog1.title, "code") + self.assertEqual(Blog.objects.count(), 1) + + Blog.drop_collection() + blog1 = Blog(title="code", posts=[post1, post2]) + obj_id = Blog.objects.insert(blog1, load_bulk=False) + self.assertEquals(obj_id.__class__.__name__, 'ObjectId') + + def test_repeated_iteration(self): """Ensure that QuerySet rewinds itself one iteration finishes. """ From 55e20bda12ea6ee7a39d6d5ebdf124bfb5cc4689 Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Mon, 6 Jun 2011 14:35:46 +0100 Subject: [PATCH 13/19] Added slave_okay syntax to querysets. * slave_okay (optional): if True, allows this query to be run against a replica secondary. 
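As a usage sketch (mirroring the new test case below), the kwarg is passed
when calling the queryset::

    # allow this read to be served by a replica secondary
    person = Person.objects(slave_okay=True).first()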
---
 mongoengine/queryset.py |  7 ++++++-
 tests/queryset.py       | 13 +++++++++++++
 2 files changed, 19 insertions(+), 1 deletion(-)

diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py
index 0e87db7a..7b4fef35 100644
--- a/mongoengine/queryset.py
+++ b/mongoengine/queryset.py
@@ -336,6 +336,7 @@ class QuerySet(object):
         self._snapshot = False
         self._timeout = True
         self._class_check = True
+        self._slave_okay = False

         # If inheritance is allowed, only return instances and instances of
         # subclasses of the class being used
@@ -430,7 +431,7 @@ class QuerySet(object):

         return spec

-    def __call__(self, q_obj=None, class_check=True, **query):
+    def __call__(self, q_obj=None, class_check=True, slave_okay=False, **query):
         """Filter the selected documents by calling the
         :class:`~mongoengine.queryset.QuerySet` with a query.

@@ -440,6 +441,8 @@ class QuerySet(object):
             objects, only the last one will be used
         :param class_check: If set to False bypass class name check when
             querying collection
+        :param slave_okay: if True, allows this query to be run against a
+            replica secondary.
         :param query: Django-style query keyword arguments
         """
         query = Q(**query)
@@ -449,6 +452,7 @@ class QuerySet(object):
         self._mongo_query = None
         self._cursor_obj = None
         self._class_check = class_check
+        self._slave_okay = slave_okay
         return self

     def filter(self, *q_objs, **query):
@@ -506,6 +510,7 @@ class QuerySet(object):
             cursor_args = {
                 'snapshot': self._snapshot,
                 'timeout': self._timeout,
+                'slave_okay': self._slave_okay
             }
             if self._loaded_fields:
                 cursor_args['fields'] = self._loaded_fields.as_dict()
diff --git a/tests/queryset.py b/tests/queryset.py
index 0b64e3e9..28d44861 100644
--- a/tests/queryset.py
+++ b/tests/queryset.py
@@ -413,6 +413,19 @@ class QuerySetTest(unittest.TestCase):
         obj_id = Blog.objects.insert(blog1, load_bulk=False)
         self.assertEquals(obj_id.__class__.__name__, 'ObjectId')

+    def test_slave_okay(self):
+        """Ensure that a query can take slave_okay syntax.
+        """
+        person1 = self.Person(name="User A", age=20)
+        person1.save()
+        person2 = self.Person(name="User B", age=30)
+        person2.save()
+
+        # Retrieve the first person from the database
+        person = self.Person.objects(slave_okay=True).first()
+        self.assertTrue(isinstance(person, self.Person))
+        self.assertEqual(person.name, "User A")
+        self.assertEqual(person.age, 20)

     def test_repeated_iteration(self):
         """Ensure that QuerySet rewinds itself one iteration finishes.

From 711db45c022cae092069432d42e9267411f80008 Mon Sep 17 00:00:00 2001
From: Ross Lawley
Date: Mon, 6 Jun 2011 14:36:44 +0100
Subject: [PATCH 14/19] Changelist updated

---
 docs/changelog.rst | 1 +
 1 file changed, 1 insertion(+)

diff --git a/docs/changelog.rst b/docs/changelog.rst
index 29ecdf7a..ed877ebb 100644
--- a/docs/changelog.rst
+++ b/docs/changelog.rst
@@ -5,6 +5,7 @@ Changelog
 Changes in dev
 ==============

+- Added slave_okay kwarg to queryset
 - Added insert method for bulk inserts
 - Added blinker signal support
 - Added query_counter context manager for tests

From d63bf0abde50b50ef091426aa0cef9b1646ed308 Mon Sep 17 00:00:00 2001
From: kuno
Date: Tue, 7 Jun 2011 20:19:29 +0800
Subject: [PATCH 15/19] fixed import path typo in django documents

---
 docs/django.rst | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/django.rst b/docs/django.rst
index 8a490571..bbfbb565 100644
--- a/docs/django.rst
+++ b/docs/django.rst
@@ -52,7 +52,7 @@ it is useful to have a Django file storage backend that wraps this. The new
 storage module is called :class:`~mongoengine.django.GridFSStorage`. Using it
 is very similar to using the default FileSystemStorage.::

-    fs = mongoengine.django.GridFSStorage()
+    fs = mongoengine.django.storage.GridFSStorage()

     filename = fs.save('hello.txt', 'Hello, World!')

From c059ad47f24394c3bb3f4b4f24a6f9e91280c4ab Mon Sep 17 00:00:00 2001
From: Ross Lawley
Date: Tue, 7 Jun 2011 15:14:41 +0100
Subject: [PATCH 16/19] Updated django docs refs #186

---
 docs/django.rst | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/docs/django.rst b/docs/django.rst
index bbfbb565..4478b94f 100644
--- a/docs/django.rst
+++ b/docs/django.rst
@@ -49,10 +49,11 @@ Storage
 =======
 With MongoEngine's support for GridFS via the :class:`~mongoengine.FileField`,
 it is useful to have a Django file storage backend that wraps this. The new
-storage module is called :class:`~mongoengine.django.GridFSStorage`. Using it
-is very similar to using the default FileSystemStorage.::
-
-    fs = mongoengine.django.storage.GridFSStorage()
+storage module is called :class:`~mongoengine.django.storage.GridFSStorage`.
+Using it is very similar to using the default FileSystemStorage.::
+
+    from mongoengine.django.storage import GridFSStorage
+    fs = GridFSStorage()

     filename = fs.save('hello.txt', 'Hello, World!')

From cfcd77b193da1eb03ef5632f88cd2189f58b2974 Mon Sep 17 00:00:00 2001
From: Ross Lawley
Date: Wed, 8 Jun 2011 10:33:56 +0100
Subject: [PATCH 17/19] Added tests displaying datetime behaviour. Updated
 datetimefield documentation

---
 mongoengine/fields.py |  4 +++
 tests/fields.py       | 60 +++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 64 insertions(+)

diff --git a/mongoengine/fields.py b/mongoengine/fields.py
index dc03fc05..1995d345 100644
--- a/mongoengine/fields.py
+++ b/mongoengine/fields.py
@@ -229,6 +229,10 @@ class BooleanField(BaseField):

 class DateTimeField(BaseField):
     """A datetime field.
+
+    Note: Microseconds are rounded to the nearest millisecond.
+    Pre UTC microsecond support is effectively broken - see
+    `tests.fields.test_datetime` for more information.
     """

     def validate(self, value):
diff --git a/tests/fields.py b/tests/fields.py
index 00b1c886..320e33db 100644
--- a/tests/fields.py
+++ b/tests/fields.py
@@ -187,6 +187,66 @@ class FieldTest(unittest.TestCase):
         log.time = '1pm'
         self.assertRaises(ValidationError, log.validate)

+    def test_datetime(self):
+        """Tests showing pymongo datetime fields' handling of microseconds.
+        Microseconds are rounded to the nearest millisecond and pre UTC
+        handling is wonky.
+
+        See: http://api.mongodb.org/python/current/api/bson/son.html#dt
+        """
+        class LogEntry(Document):
+            date = DateTimeField()
+
+        LogEntry.drop_collection()
+
+        # Post UTC - microseconds are rounded (down) to the nearest millisecond and dropped
+        d1 = datetime.datetime(1970, 01, 01, 00, 00, 01, 999)
+        d2 = datetime.datetime(1970, 01, 01, 00, 00, 01)
+        log = LogEntry()
+        log.date = d1
+        log.save()
+        log.reload()
+        self.assertNotEquals(log.date, d1)
+        self.assertEquals(log.date, d2)
+
+        # Post UTC - microseconds are rounded (down) to the nearest millisecond
+        d1 = datetime.datetime(1970, 01, 01, 00, 00, 01, 9999)
+        d2 = datetime.datetime(1970, 01, 01, 00, 00, 01, 9000)
+        log.date = d1
+        log.save()
+        log.reload()
+        self.assertNotEquals(log.date, d1)
+        self.assertEquals(log.date, d2)
+
+        # Pre UTC dates with microseconds below 1000 are dropped
+        d1 = datetime.datetime(1969, 12, 31, 23, 59, 59, 999)
+        d2 = datetime.datetime(1969, 12, 31, 23, 59, 59)
+        log.date = d1
+        log.save()
+        log.reload()
+        self.assertNotEquals(log.date, d1)
+        self.assertEquals(log.date, d2)
+
+        # Pre UTC microsecond handling above 1000 is wonky.
+        # log.date has an invalid microsecond value so I can't construct
+        # a date to compare.
+        #
+        # However, the timedelta is predictable with pre UTC timestamps:
+        # it always adds 16 seconds and [777216-776217] microseconds
+        for i in xrange(1001, 3113, 33):
+            d1 = datetime.datetime(1969, 12, 31, 23, 59, 59, i)
+            log.date = d1
+            log.save()
+            log.reload()
+            self.assertNotEquals(log.date, d1)
+
+            delta = log.date - d1
+            self.assertEquals(delta.seconds, 16)
+            microseconds = 777216 - (i % 1000)
+            self.assertEquals(delta.microseconds, microseconds)
+
+        LogEntry.drop_collection()
+
     def test_list_validation(self):
         """Ensure that a list field only accepts lists with valid elements.
         """

From d15f5ccbf43e31557c43eb238028537e9a59c089 Mon Sep 17 00:00:00 2001
From: Ross Lawley
Date: Wed, 8 Jun 2011 10:41:08 +0100
Subject: [PATCH 18/19] Added _slave_okay to clone

---
 mongoengine/queryset.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py
index 7b4fef35..a1e1245f 100644
--- a/mongoengine/queryset.py
+++ b/mongoengine/queryset.py
@@ -353,7 +353,7 @@ class QuerySet(object):

         copy_props = ('_initial_query', '_query_obj', '_where_clause',
                       '_loaded_fields', '_ordering', '_snapshot',
-                      '_timeout', '_limit', '_skip')
+                      '_timeout', '_limit', '_skip', '_slave_okay')

         for prop in copy_props:
             val = getattr(self, prop)

From 3c88faa889e01071c6953992307112f20140f2f7 Mon Sep 17 00:00:00 2001
From: Ross Lawley
Date: Wed, 8 Jun 2011 12:06:26 +0100
Subject: [PATCH 19/19] Updated slave_okay syntax

Now inline with .timeout() and .snapshot().
Made them chainable - so it's easier to use - and added tests for cursor_args
---
 mongoengine/queryset.py | 37 ++++++++++++++++++++++++++-----------
 tests/queryset.py       | 26 +++++++++++++++++++++++++-
 2 files changed, 51 insertions(+), 12 deletions(-)

diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py
index a1e1245f..f542cc87 100644
--- a/mongoengine/queryset.py
+++ b/mongoengine/queryset.py
@@ -452,7 +452,6 @@ class QuerySet(object):
         self._mongo_query = None
         self._cursor_obj = None
         self._class_check = class_check
-        self._slave_okay = slave_okay
         return self

     def filter(self, *q_objs, **query):
@@ -504,18 +503,23 @@ class QuerySet(object):

         return self._collection_obj

+    @property
+    def _cursor_args(self):
+        cursor_args = {
+            'snapshot': self._snapshot,
+            'timeout': self._timeout,
+            'slave_okay': self._slave_okay
+        }
+        if self._loaded_fields:
+            cursor_args['fields'] = self._loaded_fields.as_dict()
+        return cursor_args
+
     @property
     def _cursor(self):
         if self._cursor_obj is None:
-            cursor_args = {
-                'snapshot': self._snapshot,
-                'timeout': self._timeout,
-                'slave_okay': self._slave_okay
-            }
-            if self._loaded_fields:
-                cursor_args['fields'] = self._loaded_fields.as_dict()
+
             self._cursor_obj = self._collection.find(self._query,
-                                                     **cursor_args)
+                                                     **self._cursor_args)
             # Apply where clauses to cursor
             if self._where_clause:
                 self._cursor_obj.where(self._where_clause)
@@ -772,7 +776,7 @@ class QuerySet(object):
         id_field = self._document._meta['id_field']
         object_id = self._document._fields[id_field].to_mongo(object_id)

-        result = self._collection.find_one({'_id': object_id})
+        result = self._collection.find_one({'_id': object_id}, **self._cursor_args)
         if result is not None:
             result = self._document._from_son(result)
         return result
@@ -788,7 +792,8 @@ class QuerySet(object):
         """
         doc_map = {}

-        docs = self._collection.find({'_id': {'$in': object_ids}})
+        docs = self._collection.find({'_id': {'$in': object_ids}},
+                                     **self._cursor_args)
         for doc in docs:
             doc_map[doc['_id']] = self._document._from_son(doc)

@@ -1085,6 +1090,7 @@ class QuerySet(object):
         :param enabled: whether or not snapshot mode is enabled
         """
         self._snapshot = enabled
+        return self

     def timeout(self, enabled):
         """Enable or disable the default mongod timeout when querying.
@@ -1092,6 +1098,15 @@ class QuerySet(object):
         :param enabled: whether or not the timeout is used
         """
         self._timeout = enabled
+        return self
+
+    def slave_okay(self, enabled):
+        """Enable or disable slave_okay when querying.
+
+        :param enabled: whether or not slave_okay is enabled
+        """
+        self._slave_okay = enabled
+        return self

     def delete(self, safe=False):
         """Delete the documents matched by the query.
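
To illustrate the chainable form introduced above before moving on to its
tests, a brief usage sketch (``Person`` is the document class used in the test
suite; the calls simply set flags that reach pymongo via ``_cursor_args``)::

    # Each option returns the QuerySet, so the calls can be chained
    # and combined with normal filtering.
    people = Person.objects(age__gte=18).slave_okay(True).timeout(False).snapshot(True)

    for person in people:
        print person.name
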
diff --git a/tests/queryset.py b/tests/queryset.py
index 28d44861..1947254b 100644
--- a/tests/queryset.py
+++ b/tests/queryset.py
@@ -422,11 +422,35 @@ class QuerySetTest(unittest.TestCase):
         person2.save()

         # Retrieve the first person from the database
-        person = self.Person.objects(slave_okay=True).first()
+        person = self.Person.objects.slave_okay(True).first()
         self.assertTrue(isinstance(person, self.Person))
         self.assertEqual(person.name, "User A")
         self.assertEqual(person.age, 20)

+    def test_cursor_args(self):
+        """Ensures the cursor args can be set as expected
+        """
+        p = self.Person.objects
+        # Check default
+        self.assertEqual(p._cursor_args,
+                         {'snapshot': False, 'slave_okay': False, 'timeout': True})
+
+        p.snapshot(False).slave_okay(False).timeout(False)
+        self.assertEqual(p._cursor_args,
+                         {'snapshot': False, 'slave_okay': False, 'timeout': False})
+
+        p.snapshot(True).slave_okay(False).timeout(False)
+        self.assertEqual(p._cursor_args,
+                         {'snapshot': True, 'slave_okay': False, 'timeout': False})
+
+        p.snapshot(True).slave_okay(True).timeout(False)
+        self.assertEqual(p._cursor_args,
+                         {'snapshot': True, 'slave_okay': True, 'timeout': False})
+
+        p.snapshot(True).slave_okay(True).timeout(True)
+        self.assertEqual(p._cursor_args,
+                         {'snapshot': True, 'slave_okay': True, 'timeout': True})
+
     def test_repeated_iteration(self):
         """Ensure that QuerySet rewinds itself one iteration finishes.
         """
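
For reference, the millisecond truncation exercised by ``test_datetime`` in the
fields tests earlier in this series can be reproduced directly (a sketch,
assuming the ``LogEntry`` document defined in that test and an open
connection)::

    import datetime

    d1 = datetime.datetime(1970, 1, 1, 0, 0, 1, 999)  # 999 microseconds

    log = LogEntry(date=d1)
    log.save()
    log.reload()

    # BSON datetimes store milliseconds only, so the 999 microseconds are
    # dropped on the round trip.
    assert log.date == datetime.datetime(1970, 1, 1, 0, 0, 1)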