diff --git a/AUTHORS b/AUTHORS index 93fe819e..aecdcaa9 100644 --- a/AUTHORS +++ b/AUTHORS @@ -3,3 +3,4 @@ Matt Dennewitz Deepak Thukral Florian Schlachter Steve Challis +Ross Lawley diff --git a/docs/changelog.rst b/docs/changelog.rst index 686b326f..ed877ebb 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -5,6 +5,11 @@ Changelog Changes in dev ============== +- Added slave_okay kwarg to queryset +- Added insert method for bulk inserts +- Added blinker signal support +- Added query_counter context manager for tests +- Added DereferenceBaseField - for improved performance in field dereferencing - Added optional map_reduce method item_frequencies - Added inline_map_reduce option to map_reduce - Updated connection exception so it provides more info on the cause. diff --git a/docs/django.rst b/docs/django.rst index 8a490571..4478b94f 100644 --- a/docs/django.rst +++ b/docs/django.rst @@ -49,10 +49,11 @@ Storage ======= With MongoEngine's support for GridFS via the :class:`~mongoengine.FileField`, it is useful to have a Django file storage backend that wraps this. The new -storage module is called :class:`~mongoengine.django.GridFSStorage`. Using it -is very similar to using the default FileSystemStorage.:: - - fs = mongoengine.django.GridFSStorage() +storage module is called :class:`~mongoengine.django.storage.GridFSStorage`. +Using it is very similar to using the default FileSystemStorage.:: + + from mongoengine.django.storage import GridFSStorage + fs = GridFSStorage() filename = fs.save('hello.txt', 'Hello, World!') diff --git a/docs/guide/defining-documents.rst b/docs/guide/defining-documents.rst index e333674e..a524520c 100644 --- a/docs/guide/defining-documents.rst +++ b/docs/guide/defining-documents.rst @@ -341,9 +341,10 @@ Indexes You can specify indexes on collections to make querying faster. 
This is done by creating a list of index specifications called :attr:`indexes` in the :attr:`~mongoengine.Document.meta` dictionary, where an index specification may -either be a single field name, or a tuple containing multiple field names. A -direction may be specified on fields by prefixing the field name with a **+** -or a **-** sign. Note that direction only matters on multi-field indexes. :: +either be a single field name, a tuple containing multiple field names, or a +dictionary containing a full index definition. A direction may be specified on +fields by prefixing the field name with a **+** or a **-** sign. Note that +direction only matters on multi-field indexes. :: class Page(Document): title = StringField() @@ -352,6 +353,21 @@ or a **-** sign. Note that direction only matters on multi-field indexes. :: 'indexes': ['title', ('title', '-rating')] } +If a dictionary is passed then the following options are available: + +:attr:`fields` (Default: None) + The fields to index. Specified in the same format as described above. + +:attr:`types` (Default: True) + Whether the index should have the :attr:`_types` field added automatically + to the start of the index. + +:attr:`sparse` (Default: False) + Whether the index should be sparse. + +:attr:`unique` (Default: False) + Whether the index should be unique. + .. note:: Geospatial indexes will be automatically created for all :class:`~mongoengine.GeoPointField`\ s diff --git a/docs/guide/index.rst b/docs/guide/index.rst index aac72469..d56e7479 100644 --- a/docs/guide/index.rst +++ b/docs/guide/index.rst @@ -11,3 +11,4 @@ User Guide document-instances querying gridfs + signals diff --git a/docs/guide/signals.rst b/docs/guide/signals.rst new file mode 100644 index 00000000..d80a421b --- /dev/null +++ b/docs/guide/signals.rst @@ -0,0 +1,49 @@ +.. _signals: + +Signals +======= + +.. 
versionadded:: 0.5 + +Signal support is provided by the excellent `blinker`_ library and +will gracefully fall back if it is not available. + + +The following document signals exist in MongoEngine and are pretty self-explanatory: + + * `mongoengine.signals.pre_init` + * `mongoengine.signals.post_init` + * `mongoengine.signals.pre_save` + * `mongoengine.signals.post_save` + * `mongoengine.signals.pre_delete` + * `mongoengine.signals.post_delete` + +Example usage:: + + from mongoengine import * + from mongoengine import signals + + class Author(Document): + name = StringField() + + def __unicode__(self): + return self.name + + @classmethod + def pre_save(cls, instance, **kwargs): + logging.debug("Pre Save: %s" % instance.name) + + @classmethod + def post_save(cls, instance, **kwargs): + logging.debug("Post Save: %s" % instance.name) + if 'created' in kwargs: + if kwargs['created']: + logging.debug("Created") + else: + logging.debug("Updated") + + signals.pre_save.connect(Author.pre_save) + signals.post_save.connect(Author.post_save) + + +.. 
_blinker: http://pypi.python.org/pypi/blinker \ No newline at end of file diff --git a/mongoengine/__init__.py b/mongoengine/__init__.py index 6d18ffe7..de635f96 100644 --- a/mongoengine/__init__.py +++ b/mongoengine/__init__.py @@ -6,9 +6,11 @@ import connection from connection import * import queryset from queryset import * +import signals +from signals import * __all__ = (document.__all__ + fields.__all__ + connection.__all__ + - queryset.__all__) + queryset.__all__ + signals.__all__) __author__ = 'Harry Marr' diff --git a/mongoengine/base.py b/mongoengine/base.py index ffceb794..76bb1ab7 100644 --- a/mongoengine/base.py +++ b/mongoengine/base.py @@ -2,9 +2,12 @@ from queryset import QuerySet, QuerySetManager from queryset import DoesNotExist, MultipleObjectsReturned from queryset import DO_NOTHING +from mongoengine import signals + import sys import pymongo import pymongo.objectid +from operator import itemgetter class NotRegistered(Exception): @@ -126,6 +129,88 @@ class BaseField(object): self.validate(value) + +class DereferenceBaseField(BaseField): + """Handles the lazy dereferencing of a queryset. Will dereference all + items in a list / dict rather than one at a time. + """ + + def __get__(self, instance, owner): + """Descriptor to automatically dereference references. 
+ """ + from fields import ReferenceField, GenericReferenceField + from connection import _get_db + + if instance is None: + # Document class being used rather than a document object + return self + + # Get value from document instance if available + value_list = instance._data.get(self.name) + if not value_list: + return super(DereferenceBaseField, self).__get__(instance, owner) + + is_list = False + if not hasattr(value_list, 'items'): + is_list = True + value_list = dict([(k,v) for k,v in enumerate(value_list)]) + + if isinstance(self.field, ReferenceField) and value_list: + db = _get_db() + dbref = {} + collections = {} + + for k, v in value_list.items(): + dbref[k] = v + # Save any DBRefs + if isinstance(v, (pymongo.dbref.DBRef)): + collections.setdefault(v.collection, []).append((k, v)) + + # For each collection get the references + for collection, dbrefs in collections.items(): + id_map = dict([(v.id, k) for k, v in dbrefs]) + references = db[collection].find({'_id': {'$in': id_map.keys()}}) + for ref in references: + key = id_map[ref['_id']] + dbref[key] = get_document(ref['_cls'])._from_son(ref) + + if is_list: + dbref = [v for k,v in sorted(dbref.items(), key=itemgetter(0))] + instance._data[self.name] = dbref + + # Get value from document instance if available + if isinstance(self.field, GenericReferenceField) and value_list: + db = _get_db() + value_list = [(k,v) for k,v in value_list.items()] + dbref = {} + classes = {} + + for k, v in value_list: + dbref[k] = v + # Save any DBRefs + if isinstance(v, (dict, pymongo.son.SON)): + classes.setdefault(v['_cls'], []).append((k, v)) + + # For each collection get the references + for doc_cls, dbrefs in classes.items(): + id_map = dict([(v['_ref'].id, k) for k, v in dbrefs]) + doc_cls = get_document(doc_cls) + collection = doc_cls._meta['collection'] + references = db[collection].find({'_id': {'$in': id_map.keys()}}) + + for ref in references: + key = id_map[ref['_id']] + dbref[key] = doc_cls._from_son(ref) + + 
if is_list: + dbref = [v for k,v in sorted(dbref.items(), key=itemgetter(0))] + + instance._data[self.name] = dbref + + return super(DereferenceBaseField, self).__get__(instance, owner) + + + class ObjectIdField(BaseField): """An field wrapper around MongoDB's ObjectIds. """ @@ -382,6 +467,8 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): class BaseDocument(object): def __init__(self, **values): + signals.pre_init.send(self, values=values) + self._data = {} # Assign default values to instance for attr_name in self._fields.keys(): @@ -395,6 +482,8 @@ class BaseDocument(object): except AttributeError: pass + signals.post_init.send(self) + def validate(self): """Ensure that all fields' values are valid and that required fields are present. diff --git a/mongoengine/document.py b/mongoengine/document.py index 771b9229..b563f427 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -1,3 +1,4 @@ +from mongoengine import signals from base import (DocumentMetaclass, TopLevelDocumentMetaclass, BaseDocument, ValidationError) from queryset import OperationError @@ -75,6 +76,8 @@ class Document(BaseDocument): For example, ``save(..., w=2, fsync=True)`` will wait until at least two servers have recorded the write and will force an fsync on each server being written to. """ + signals.pre_save.send(self) + if validate: self.validate() @@ -82,6 +85,7 @@ class Document(BaseDocument): write_options = {} doc = self.to_mongo() + created = '_id' not in doc try: collection = self.__class__.objects._collection if force_insert: @@ -96,12 +100,16 @@ class Document(BaseDocument): id_field = self._meta['id_field'] self[id_field] = self._fields[id_field].to_python(object_id) + signals.post_save.send(self, created=created) + def delete(self, safe=False): """Delete the :class:`~mongoengine.Document` from the database. This will only take effect if the document has been previously saved. 
:param safe: check if the operation succeeded before returning """ + signals.pre_delete.send(self) + id_field = self._meta['id_field'] object_id = self._fields[id_field].to_mongo(self[id_field]) try: @@ -110,6 +118,8 @@ class Document(BaseDocument): message = u'Could not delete document (%s)' % err.message raise OperationError(message) + signals.post_delete.send(self) + @classmethod def register_delete_rule(cls, document_cls, field_name, rule): """This method registers the delete rules to apply when removing this diff --git a/mongoengine/fields.py b/mongoengine/fields.py index b2aab5a4..1995d345 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -1,4 +1,5 @@ -from base import BaseField, ObjectIdField, ValidationError, get_document +from base import (BaseField, DereferenceBaseField, ObjectIdField, + ValidationError, get_document) from queryset import DO_NOTHING from document import Document, EmbeddedDocument from connection import _get_db @@ -12,7 +13,6 @@ import pymongo.binary import datetime, time import decimal import gridfs -import warnings __all__ = ['StringField', 'IntField', 'FloatField', 'BooleanField', @@ -118,8 +118,8 @@ class EmailField(StringField): EMAIL_REGEX = re.compile( r"(^[-!#$%&'*+/=?^_`{}|~0-9A-Z]+(\.[-!#$%&'*+/=?^_`{}|~0-9A-Z]+)*" # dot-atom - r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]|\\[\001-011\013\014\016-\177])*"' # quoted-string - r')@(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,6}\.?$', re.IGNORECASE # domain + r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]|\\[\001-011\013\014\016-\177])*"' # quoted-string + r')@(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,6}\.?$', re.IGNORECASE # domain ) def validate(self, value): @@ -153,6 +153,7 @@ class IntField(BaseField): def prepare_query_value(self, op, value): return int(value) + class FloatField(BaseField): """An floating point number field. 
""" @@ -178,6 +179,7 @@ class FloatField(BaseField): def prepare_query_value(self, op, value): return float(value) + class DecimalField(BaseField): """A fixed-point decimal number field. @@ -227,6 +229,10 @@ class BooleanField(BaseField): class DateTimeField(BaseField): """A datetime field. + + Note: Microseconds are rounded to the nearest millisecond. + Pre UTC microsecond support is effecively broken see + `tests.field.test_datetime` for more information. """ def validate(self, value): @@ -252,21 +258,21 @@ class DateTimeField(BaseField): else: usecs = 0 kwargs = {'microsecond': usecs} - try: # Seconds are optional, so try converting seconds first. + try: # Seconds are optional, so try converting seconds first. return datetime.datetime(*time.strptime(value, '%Y-%m-%d %H:%M:%S')[:6], **kwargs) - except ValueError: - try: # Try without seconds. + try: # Try without seconds. return datetime.datetime(*time.strptime(value, '%Y-%m-%d %H:%M')[:5], **kwargs) - except ValueError: # Try without hour/minutes/seconds. + except ValueError: # Try without hour/minutes/seconds. try: return datetime.datetime(*time.strptime(value, '%Y-%m-%d')[:3], **kwargs) except ValueError: return None + class EmbeddedDocumentField(BaseField): """An embedded document field. Only valid values are subclasses of :class:`~mongoengine.EmbeddedDocument`. @@ -314,7 +320,7 @@ class EmbeddedDocumentField(BaseField): return self.to_mongo(value) -class ListField(BaseField): +class ListField(DereferenceBaseField): """A list field that wraps a standard field, allowing multiple instances of the field to be used as a list in the database. """ @@ -330,42 +336,6 @@ class ListField(BaseField): kwargs.setdefault('default', lambda: []) super(ListField, self).__init__(**kwargs) - def __get__(self, instance, owner): - """Descriptor to automatically dereference references. 
- """ - if instance is None: - # Document class being used rather than a document object - return self - - if isinstance(self.field, ReferenceField): - referenced_type = self.field.document_type - # Get value from document instance if available - value_list = instance._data.get(self.name) - if value_list: - deref_list = [] - for value in value_list: - # Dereference DBRefs - if isinstance(value, (pymongo.dbref.DBRef)): - value = _get_db().dereference(value) - deref_list.append(referenced_type._from_son(value)) - else: - deref_list.append(value) - instance._data[self.name] = deref_list - - if isinstance(self.field, GenericReferenceField): - value_list = instance._data.get(self.name) - if value_list: - deref_list = [] - for value in value_list: - # Dereference DBRefs - if isinstance(value, (dict, pymongo.son.SON)): - deref_list.append(self.field.dereference(value)) - else: - deref_list.append(value) - instance._data[self.name] = deref_list - - return super(ListField, self).__get__(instance, owner) - def to_python(self, value): return [self.field.to_python(item) for item in value] @@ -459,10 +429,10 @@ class DictField(BaseField): if op in match_operators and isinstance(value, basestring): return StringField().prepare_query_value(op, value) - return super(DictField,self).prepare_query_value(op, value) + return super(DictField, self).prepare_query_value(op, value) -class MapField(BaseField): +class MapField(DereferenceBaseField): """A field that maps a name to a specified field type. Similar to a DictField, except the 'value' of each item must match the specified field type. @@ -494,47 +464,11 @@ class MapField(BaseField): except Exception, err: raise ValidationError('Invalid MapField item (%s)' % str(item)) - def __get__(self, instance, owner): - """Descriptor to automatically dereference references. 
- """ - if instance is None: - # Document class being used rather than a document object - return self - - if isinstance(self.field, ReferenceField): - referenced_type = self.field.document_type - # Get value from document instance if available - value_dict = instance._data.get(self.name) - if value_dict: - deref_dict = [] - for key,value in value_dict.iteritems(): - # Dereference DBRefs - if isinstance(value, (pymongo.dbref.DBRef)): - value = _get_db().dereference(value) - deref_dict[key] = referenced_type._from_son(value) - else: - deref_dict[key] = value - instance._data[self.name] = deref_dict - - if isinstance(self.field, GenericReferenceField): - value_dict = instance._data.get(self.name) - if value_dict: - deref_dict = [] - for key,value in value_dict.iteritems(): - # Dereference DBRefs - if isinstance(value, (dict, pymongo.son.SON)): - deref_dict[key] = self.field.dereference(value) - else: - deref_dict[key] = value - instance._data[self.name] = deref_dict - - return super(MapField, self).__get__(instance, owner) - def to_python(self, value): - return dict( [(key,self.field.to_python(item)) for key,item in value.iteritems()] ) + return dict([(key, self.field.to_python(item)) for key, item in value.iteritems()]) def to_mongo(self, value): - return dict( [(key,self.field.to_mongo(item)) for key,item in value.iteritems()] ) + return dict([(key, self.field.to_mongo(item)) for key, item in value.iteritems()]) def prepare_query_value(self, op, value): if op not in ('set', 'unset'): @@ -752,11 +686,11 @@ class GridFSProxy(object): self.newfile = self.fs.new_file(**kwargs) self.grid_id = self.newfile._id - def put(self, file, **kwargs): + def put(self, file_obj, **kwargs): if self.grid_id: raise GridFSError('This document already has a file. 
Either delete ' 'it or call replace to overwrite it') - self.grid_id = self.fs.put(file, **kwargs) + self.grid_id = self.fs.put(file_obj, **kwargs) def write(self, string): if self.grid_id: @@ -785,9 +719,9 @@ class GridFSProxy(object): self.grid_id = None self.gridout = None - def replace(self, file, **kwargs): + def replace(self, file_obj, **kwargs): self.delete() - self.put(file, **kwargs) + self.put(file_obj, **kwargs) def close(self): if self.newfile: diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index 303afb6a..1dfe55af 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -336,6 +336,7 @@ class QuerySet(object): self._snapshot = False self._timeout = True self._class_check = True + self._slave_okay = False # If inheritance is allowed, only return instances and instances of # subclasses of the class being used @@ -352,7 +353,7 @@ class QuerySet(object): copy_props = ('_initial_query', '_query_obj', '_where_clause', '_loaded_fields', '_ordering', '_snapshot', - '_timeout', '_limit', '_skip') + '_timeout', '_limit', '_skip', '_slave_okay') for prop in copy_props: val = getattr(self, prop) @@ -376,21 +377,27 @@ class QuerySet(object): construct a multi-field index); keys may be prefixed with a **+** or a **-** to determine the index ordering """ - index_list = QuerySet._build_index_spec(self._document, key_or_list) - self._collection.ensure_index(index_list, drop_dups=drop_dups, - background=background) + index_spec = QuerySet._build_index_spec(self._document, key_or_list) + self._collection.ensure_index( + index_spec['fields'], + drop_dups=drop_dups, + background=background, + sparse=index_spec.get('sparse', False), + unique=index_spec.get('unique', False)) return self @classmethod - def _build_index_spec(cls, doc_cls, key_or_list): + def _build_index_spec(cls, doc_cls, spec): """Build a PyMongo index spec from a MongoEngine index spec. 
""" - if isinstance(key_or_list, basestring): - key_or_list = [key_or_list] + if isinstance(spec, basestring): + spec = {'fields': [spec]} + if isinstance(spec, (list, tuple)): + spec = {'fields': spec} index_list = [] use_types = doc_cls._meta.get('allow_inheritance', True) - for key in key_or_list: + for key in spec['fields']: # Get direction from + or - direction = pymongo.ASCENDING if key.startswith("-"): @@ -410,15 +417,22 @@ class QuerySet(object): if use_types and not all(f._index_with_types for f in fields): use_types = False - # If _types is being used, create an index for it + # If _types is being used, prepend it to every specified index index_types = doc_cls._meta.get('index_types', True) allow_inheritance = doc_cls._meta.get('allow_inheritance') - if index_types and allow_inheritance and use_types: + if spec.get('types', index_types) and allow_inheritance and use_types: index_list.insert(0, ('_types', 1)) - return index_list + spec['fields'] = index_list - def __call__(self, q_obj=None, class_check=True, **query): + if spec.get('sparse', False) and len(spec['fields']) > 1: + raise ValueError( + 'Sparse indexes can only have one field in them. ' + 'See https://jira.mongodb.org/browse/SERVER-2193') + + return spec + + def __call__(self, q_obj=None, class_check=True, slave_okay=False, **query): """Filter the selected documents by calling the :class:`~mongoengine.queryset.QuerySet` with a query. @@ -428,6 +442,8 @@ class QuerySet(object): objects, only the last one will be used :param class_check: If set to False bypass class name check when querying collection + :param slave_okay: if True, allows this query to be run against a + replica secondary. 
:param query: Django-style query keyword arguments """ query = Q(**query) @@ -468,9 +484,12 @@ class QuerySet(object): # Ensure document-defined indexes are created if self._document._meta['indexes']: - for key_or_list in self._document._meta['indexes']: - self._collection.ensure_index(key_or_list, - background=background, **index_opts) + for spec in self._document._meta['indexes']: + opts = index_opts.copy() + opts['unique'] = spec.get('unique', False) + opts['sparse'] = spec.get('sparse', False) + self._collection.ensure_index(spec['fields'], + background=background, **opts) # If _types is being used (for polymorphism), it needs an index if index_types and '_types' in self._query: @@ -486,17 +505,23 @@ class QuerySet(object): return self._collection_obj + @property + def _cursor_args(self): + cursor_args = { + 'snapshot': self._snapshot, + 'timeout': self._timeout, + 'slave_okay': self._slave_okay + } + if self._loaded_fields: + cursor_args['fields'] = self._loaded_fields.as_dict() + return cursor_args + @property def _cursor(self): if self._cursor_obj is None: - cursor_args = { - 'snapshot': self._snapshot, - 'timeout': self._timeout, - } - if self._loaded_fields: - cursor_args['fields'] = self._loaded_fields.as_dict() + self._cursor_obj = self._collection.find(self._query, - **cursor_args) + **self._cursor_args) # Apply where clauses to cursor if self._where_clause: self._cursor_obj.where(self._where_clause) @@ -705,6 +730,46 @@ class QuerySet(object): result = None return result + def insert(self, doc_or_docs, load_bulk=True): + """bulk insert documents + + :param docs_or_doc: a document or list of documents to be inserted + :param load_bulk (optional): If True returns the list of document instances + + By default returns document instances, set ``load_bulk`` to False to + return just ``ObjectIds`` + + .. 
versionadded:: 0.5 + """ + from document import Document + + docs = doc_or_docs + return_one = False + if isinstance(docs, Document) or issubclass(docs.__class__, Document): + return_one = True + docs = [docs] + + raw = [] + for doc in docs: + if not isinstance(doc, self._document): + msg = "Some documents inserted aren't instances of %s" % str(self._document) + raise OperationError(msg) + if doc.pk: + msg = "Some documents have ObjectIds use doc.update() instead" + raise OperationError(msg) + raw.append(doc.to_mongo()) + + ids = self._collection.insert(raw) + + if not load_bulk: + return return_one and ids[0] or ids + + documents = self.in_bulk(ids) + results = [] + for obj_id in ids: + results.append(documents.get(obj_id)) + return return_one and results[0] or results + def with_id(self, object_id): """Retrieve the object matching the id provided. @@ -713,7 +778,7 @@ class QuerySet(object): id_field = self._document._meta['id_field'] object_id = self._document._fields[id_field].to_mongo(object_id) - result = self._collection.find_one({'_id': object_id}) + result = self._collection.find_one({'_id': object_id}, **self._cursor_args) if result is not None: result = self._document._from_son(result) return result @@ -729,7 +794,8 @@ class QuerySet(object): """ doc_map = {} - docs = self._collection.find({'_id': {'$in': object_ids}}) + docs = self._collection.find({'_id': {'$in': object_ids}}, + **self._cursor_args) for doc in docs: doc_map[doc['_id']] = self._document._from_son(doc) @@ -1026,6 +1092,7 @@ class QuerySet(object): :param enabled: whether or not snapshot mode is enabled """ self._snapshot = enabled + return self def timeout(self, enabled): """Enable or disable the default mongod timeout when querying. @@ -1033,6 +1100,15 @@ class QuerySet(object): :param enabled: whether or not the timeout is used """ self._timeout = enabled + return self + + def slave_okay(self, enabled): + """Enable or disable the slave_okay when querying. 
+ + :param enabled: whether or not the slave_okay is enabled + """ + self._slave_okay = enabled + return self def delete(self, safe=False): """Delete the documents matched by the query. diff --git a/mongoengine/signals.py b/mongoengine/signals.py new file mode 100644 index 00000000..0a697534 --- /dev/null +++ b/mongoengine/signals.py @@ -0,0 +1,44 @@ +# -*- coding: utf-8 -*- + +__all__ = ['pre_init', 'post_init', 'pre_save', 'post_save', + 'pre_delete', 'post_delete'] + +signals_available = False +try: + from blinker import Namespace + signals_available = True +except ImportError: + class Namespace(object): + def signal(self, name, doc=None): + return _FakeSignal(name, doc) + + class _FakeSignal(object): + """If blinker is unavailable, create a fake class with the same + interface that allows sending of signals but will fail with an + error on anything else. Instead of doing anything on send, it + will just ignore the arguments and do nothing instead. + """ + + def __init__(self, name, doc=None): + self.name = name + self.__doc__ = doc + + def _fail(self, *args, **kwargs): + raise RuntimeError('signalling support is unavailable ' + 'because the blinker library is ' + 'not installed.') + send = lambda *a, **kw: None + connect = disconnect = has_receivers_for = receivers_for = \ + temporarily_connected_to = _fail + del _fail + +# the namespace for code signals. If you are not mongoengine code, do +# not put signals in here. Create your own namespace instead. 
+_signals = Namespace() + +pre_init = _signals.signal('pre_init') +post_init = _signals.signal('post_init') +pre_save = _signals.signal('pre_save') +post_save = _signals.signal('post_save') +pre_delete = _signals.signal('pre_delete') +post_delete = _signals.signal('post_delete') diff --git a/mongoengine/tests.py b/mongoengine/tests.py new file mode 100644 index 00000000..9584bc7c --- /dev/null +++ b/mongoengine/tests.py @@ -0,0 +1,59 @@ +from mongoengine.connection import _get_db + + +class query_counter(object): + """ Query_counter contextmanager to get the number of queries. """ + + def __init__(self): + """ Construct the query_counter. """ + self.counter = 0 + self.db = _get_db() + + def __enter__(self): + """ On every with block we need to drop the profile collection. """ + self.db.set_profiling_level(0) + self.db.system.profile.drop() + self.db.set_profiling_level(2) + return self + + def __exit__(self, t, value, traceback): + """ Reset the profiling level. """ + self.db.set_profiling_level(0) + + def __eq__(self, value): + """ == Compare querycounter. """ + return value == self._get_count() + + def __ne__(self, value): + """ != Compare querycounter. """ + return not self.__eq__(value) + + def __lt__(self, value): + """ < Compare querycounter. """ + return self._get_count() < value + + def __le__(self, value): + """ <= Compare querycounter. """ + return self._get_count() <= value + + def __gt__(self, value): + """ > Compare querycounter. """ + return self._get_count() > value + + def __ge__(self, value): + """ >= Compare querycounter. """ + return self._get_count() >= value + + def __int__(self): + """ int representation. """ + return self._get_count() + + def __repr__(self): + """ repr query_counter as the number of queries. """ + return u"%s" % self._get_count() + + def _get_count(self): + """ Get the number of queries. 
""" + count = self.db.system.profile.find().count() - self.counter + self.counter += 1 + return count diff --git a/setup.py b/setup.py index 0c19d8d0..d3be64b3 100644 --- a/setup.py +++ b/setup.py @@ -45,6 +45,6 @@ setup(name='mongoengine', long_description=LONG_DESCRIPTION, platforms=['any'], classifiers=CLASSIFIERS, - install_requires=['pymongo'], + install_requires=['pymongo', 'blinker'], test_suite='tests', ) diff --git a/tests/dereference.py b/tests/dereference.py new file mode 100644 index 00000000..b6cee89e --- /dev/null +++ b/tests/dereference.py @@ -0,0 +1,288 @@ +import unittest + +from mongoengine import * +from mongoengine.connection import _get_db +from mongoengine.tests import query_counter + + +class FieldTest(unittest.TestCase): + + def setUp(self): + connect(db='mongoenginetest') + self.db = _get_db() + + def test_list_item_dereference(self): + """Ensure that DBRef items in ListFields are dereferenced. + """ + class User(Document): + name = StringField() + + class Group(Document): + members = ListField(ReferenceField(User)) + + User.drop_collection() + Group.drop_collection() + + for i in xrange(1, 51): + user = User(name='user %s' % i) + user.save() + + group = Group(members=User.objects) + group.save() + + with query_counter() as q: + self.assertEqual(q, 0) + + group_obj = Group.objects.first() + self.assertEqual(q, 1) + + [m for m in group_obj.members] + self.assertEqual(q, 2) + + User.drop_collection() + Group.drop_collection() + + def test_recursive_reference(self): + """Ensure that ReferenceFields can reference their own documents. 
+ """ + class Employee(Document): + name = StringField() + boss = ReferenceField('self') + friends = ListField(ReferenceField('self')) + + bill = Employee(name='Bill Lumbergh') + bill.save() + + michael = Employee(name='Michael Bolton') + michael.save() + + samir = Employee(name='Samir Nagheenanajar') + samir.save() + + friends = [michael, samir] + peter = Employee(name='Peter Gibbons', boss=bill, friends=friends) + peter.save() + + with query_counter() as q: + self.assertEqual(q, 0) + + peter = Employee.objects.with_id(peter.id) + self.assertEqual(q, 1) + + peter.boss + self.assertEqual(q, 2) + + peter.friends + self.assertEqual(q, 3) + + def test_generic_reference(self): + + class UserA(Document): + name = StringField() + + class UserB(Document): + name = StringField() + + class UserC(Document): + name = StringField() + + class Group(Document): + members = ListField(GenericReferenceField()) + + UserA.drop_collection() + UserB.drop_collection() + UserC.drop_collection() + Group.drop_collection() + + members = [] + for i in xrange(1, 51): + a = UserA(name='User A %s' % i) + a.save() + + b = UserB(name='User B %s' % i) + b.save() + + c = UserC(name='User C %s' % i) + c.save() + + members += [a, b, c] + + group = Group(members=members) + group.save() + + with query_counter() as q: + self.assertEqual(q, 0) + + group_obj = Group.objects.first() + self.assertEqual(q, 1) + + [m for m in group_obj.members] + self.assertEqual(q, 4) + + [m for m in group_obj.members] + self.assertEqual(q, 4) + + UserA.drop_collection() + UserB.drop_collection() + UserC.drop_collection() + Group.drop_collection() + + def test_map_field_reference(self): + + class User(Document): + name = StringField() + + class Group(Document): + members = MapField(ReferenceField(User)) + + User.drop_collection() + Group.drop_collection() + + members = [] + for i in xrange(1, 51): + user = User(name='user %s' % i) + user.save() + members.append(user) + + group = Group(members=dict([(str(u.id), u) for u in 
members])) + group.save() + + with query_counter() as q: + self.assertEqual(q, 0) + + group_obj = Group.objects.first() + self.assertEqual(q, 1) + + [m for m in group_obj.members] + self.assertEqual(q, 2) + + User.drop_collection() + Group.drop_collection() + + def ztest_generic_reference_dict_field(self): + + class UserA(Document): + name = StringField() + + class UserB(Document): + name = StringField() + + class UserC(Document): + name = StringField() + + class Group(Document): + members = DictField() + + UserA.drop_collection() + UserB.drop_collection() + UserC.drop_collection() + Group.drop_collection() + + members = [] + for i in xrange(1, 51): + a = UserA(name='User A %s' % i) + a.save() + + b = UserB(name='User B %s' % i) + b.save() + + c = UserC(name='User C %s' % i) + c.save() + + members += [a, b, c] + + group = Group(members=dict([(str(u.id), u) for u in members])) + group.save() + + with query_counter() as q: + self.assertEqual(q, 0) + + group_obj = Group.objects.first() + self.assertEqual(q, 1) + + [m for m in group_obj.members] + self.assertEqual(q, 4) + + [m for m in group_obj.members] + self.assertEqual(q, 4) + + group.members = {} + group.save() + + with query_counter() as q: + self.assertEqual(q, 0) + + group_obj = Group.objects.first() + self.assertEqual(q, 1) + + [m for m in group_obj.members] + self.assertEqual(q, 1) + + UserA.drop_collection() + UserB.drop_collection() + UserC.drop_collection() + Group.drop_collection() + + def test_generic_reference_map_field(self): + + class UserA(Document): + name = StringField() + + class UserB(Document): + name = StringField() + + class UserC(Document): + name = StringField() + + class Group(Document): + members = MapField(GenericReferenceField()) + + UserA.drop_collection() + UserB.drop_collection() + UserC.drop_collection() + Group.drop_collection() + + members = [] + for i in xrange(1, 51): + a = UserA(name='User A %s' % i) + a.save() + + b = UserB(name='User B %s' % i) + b.save() + + c = 
UserC(name='User C %s' % i) + c.save() + + members += [a, b, c] + + group = Group(members=dict([(str(u.id), u) for u in members])) + group.save() + + with query_counter() as q: + self.assertEqual(q, 0) + + group_obj = Group.objects.first() + self.assertEqual(q, 1) + + [m for m in group_obj.members] + self.assertEqual(q, 4) + + [m for m in group_obj.members] + self.assertEqual(q, 4) + + group.members = {} + group.save() + + with query_counter() as q: + self.assertEqual(q, 0) + + group_obj = Group.objects.first() + self.assertEqual(q, 1) + + [m for m in group_obj.members] + self.assertEqual(q, 1) + + UserA.drop_collection() + UserB.drop_collection() + UserC.drop_collection() + Group.drop_collection() \ No newline at end of file diff --git a/tests/document.py b/tests/document.py index fe67312e..a8120469 100644 --- a/tests/document.py +++ b/tests/document.py @@ -377,6 +377,40 @@ class DocumentTest(unittest.TestCase): BlogPost.drop_collection() + + def test_dictionary_indexes(self): + """Ensure that indexes are used when meta[indexes] contains dictionaries + instead of lists. + """ + class BlogPost(Document): + date = DateTimeField(db_field='addDate', default=datetime.now) + category = StringField() + tags = ListField(StringField()) + meta = { + 'indexes': [ + { 'fields': ['-date'], 'unique': True, + 'sparse': True, 'types': False }, + ], + } + + BlogPost.drop_collection() + + info = BlogPost.objects._collection.index_information() + # _id, '-date' + self.assertEqual(len(info), 3) + + # Indexes are lazy so use list() to perform query + list(BlogPost.objects) + info = BlogPost.objects._collection.index_information() + info = [(value['key'], + value.get('unique', False), + value.get('sparse', False)) + for key, value in info.iteritems()] + self.assertTrue(([('addDate', -1)], True, True) in info) + + BlogPost.drop_collection() + + def test_unique(self): """Ensure that uniqueness constraints are applied to fields. 
""" diff --git a/tests/fields.py b/tests/fields.py index 00b1c886..320e33db 100644 --- a/tests/fields.py +++ b/tests/fields.py @@ -187,6 +187,66 @@ class FieldTest(unittest.TestCase): log.time = '1pm' self.assertRaises(ValidationError, log.validate) + def test_datetime(self): + """Tests showing pymongo datetime fields handling of microseconds. + Microseconds are rounded to the nearest millisecond and pre UTC + handling is wonky. + + See: http://api.mongodb.org/python/current/api/bson/son.html#dt + """ + class LogEntry(Document): + date = DateTimeField() + + LogEntry.drop_collection() + + # Post UTC - microseconds are rounded (down) nearest millisecond and dropped + d1 = datetime.datetime(1970, 01, 01, 00, 00, 01, 999) + d2 = datetime.datetime(1970, 01, 01, 00, 00, 01) + log = LogEntry() + log.date = d1 + log.save() + log.reload() + self.assertNotEquals(log.date, d1) + self.assertEquals(log.date, d2) + + # Post UTC - microseconds are rounded (down) nearest millisecond + d1 = datetime.datetime(1970, 01, 01, 00, 00, 01, 9999) + d2 = datetime.datetime(1970, 01, 01, 00, 00, 01, 9000) + log.date = d1 + log.save() + log.reload() + self.assertNotEquals(log.date, d1) + self.assertEquals(log.date, d2) + + # Pre UTC dates microseconds below 1000 are dropped + d1 = datetime.datetime(1969, 12, 31, 23, 59, 59, 999) + d2 = datetime.datetime(1969, 12, 31, 23, 59, 59) + log.date = d1 + log.save() + log.reload() + self.assertNotEquals(log.date, d1) + self.assertEquals(log.date, d2) + + # Pre UTC microseconds above 1000 is wonky. + # log.date has an invalid microsecond value so I can't construct + # a date to compare. 
+ # + # However, the timedelta is predictable with pre UTC timestamps + # It always adds 16 seconds and [777216-776217] microseconds + for i in xrange(1001, 3113, 33): + d1 = datetime.datetime(1969, 12, 31, 23, 59, 59, i) + log.date = d1 + log.save() + log.reload() + self.assertNotEquals(log.date, d1) + + delta = log.date - d1 + self.assertEquals(delta.seconds, 16) + microseconds = 777216 - (i % 1000) + self.assertEquals(delta.microseconds, microseconds) + + LogEntry.drop_collection() + def test_list_validation(self): """Ensure that a list field only accepts lists with valid elements. """ diff --git a/tests/queryset.py b/tests/queryset.py index 1e5e7a5a..37140f4a 100644 --- a/tests/queryset.py +++ b/tests/queryset.py @@ -9,6 +9,7 @@ from mongoengine.queryset import (QuerySet, QuerySetManager, MultipleObjectsReturned, DoesNotExist, QueryFieldList) from mongoengine import * +from mongoengine.tests import query_counter class QuerySetTest(unittest.TestCase): @@ -331,6 +332,125 @@ class QuerySetTest(unittest.TestCase): person = self.Person.objects.get(age=50) self.assertEqual(person.name, "User C") + def test_bulk_insert(self): + """Ensure that bulk inserts work.
+ """ + + class Comment(EmbeddedDocument): + name = StringField() + + class Post(EmbeddedDocument): + comments = ListField(EmbeddedDocumentField(Comment)) + + class Blog(Document): + title = StringField() + tags = ListField(StringField()) + posts = ListField(EmbeddedDocumentField(Post)) + + Blog.drop_collection() + + with query_counter() as q: + self.assertEqual(q, 0) + + comment1 = Comment(name='testa') + comment2 = Comment(name='testb') + post1 = Post(comments=[comment1, comment2]) + post2 = Post(comments=[comment2, comment2]) + + blogs = [] + for i in xrange(1, 100): + blogs.append(Blog(title="post %s" % i, posts=[post1, post2])) + + Blog.objects.insert(blogs, load_bulk=False) + self.assertEqual(q, 2) # 1 for the initial connection and 1 for the insert + + Blog.objects.insert(blogs) + self.assertEqual(q, 4) # 1 for insert, and 1 for in bulk + + Blog.drop_collection() + + comment1 = Comment(name='testa') + comment2 = Comment(name='testb') + post1 = Post(comments=[comment1, comment2]) + post2 = Post(comments=[comment2, comment2]) + blog1 = Blog(title="code", posts=[post1, post2]) + blog2 = Blog(title="mongodb", posts=[post2, post1]) + blog1, blog2 = Blog.objects.insert([blog1, blog2]) + self.assertEqual(blog1.title, "code") + self.assertEqual(blog2.title, "mongodb") + + self.assertEqual(Blog.objects.count(), 2) + + # test handles people trying to upsert + def throw_operation_error(): + blogs = Blog.objects + Blog.objects.insert(blogs) + + self.assertRaises(OperationError, throw_operation_error) + + # test handles other classes being inserted + def throw_operation_error_wrong_doc(): + class Author(Document): + pass + Blog.objects.insert(Author()) + + self.assertRaises(OperationError, throw_operation_error_wrong_doc) + + def throw_operation_error_not_a_document(): + Blog.objects.insert("HELLO WORLD") + + self.assertRaises(OperationError, throw_operation_error_not_a_document) + + Blog.drop_collection() + + blog1 = Blog(title="code", posts=[post1, post2]) + blog1 = 
Blog.objects.insert(blog1) + self.assertEqual(blog1.title, "code") + self.assertEqual(Blog.objects.count(), 1) + + Blog.drop_collection() + blog1 = Blog(title="code", posts=[post1, post2]) + obj_id = Blog.objects.insert(blog1, load_bulk=False) + self.assertEquals(obj_id.__class__.__name__, 'ObjectId') + + def test_slave_okay(self): + """Ensures that a query can take slave_okay syntax + """ + person1 = self.Person(name="User A", age=20) + person1.save() + person2 = self.Person(name="User B", age=30) + person2.save() + + # Retrieve the first person from the database + person = self.Person.objects.slave_okay(True).first() + self.assertTrue(isinstance(person, self.Person)) + self.assertEqual(person.name, "User A") + self.assertEqual(person.age, 20) + + def test_cursor_args(self): + """Ensures the cursor args can be set as expected + """ + p = self.Person.objects + # Check default + self.assertEqual(p._cursor_args, + {'snapshot': False, 'slave_okay': False, 'timeout': True}) + + p.snapshot(False).slave_okay(False).timeout(False) + self.assertEqual(p._cursor_args, + {'snapshot': False, 'slave_okay': False, 'timeout': False}) + + p.snapshot(True).slave_okay(False).timeout(False) + self.assertEqual(p._cursor_args, + {'snapshot': True, 'slave_okay': False, 'timeout': False}) + + p.snapshot(True).slave_okay(True).timeout(False) + self.assertEqual(p._cursor_args, + {'snapshot': True, 'slave_okay': True, 'timeout': False}) + + p.snapshot(True).slave_okay(True).timeout(True) + self.assertEqual(p._cursor_args, + {'snapshot': True, 'slave_okay': True, 'timeout': True}) + def test_repeated_iteration(self): """Ensure that QuerySet rewinds itself one iteration finishes. """ @@ -2115,8 +2235,27 @@ class QuerySetTest(unittest.TestCase): Number.drop_collection() + def test_ensure_index(self): + """Ensure that manual creation of indexes works. 
+ """ + class Comment(Document): + message = StringField() + + Comment.objects.ensure_index('message') + + info = Comment.objects._collection.index_information() + info = [(value['key'], + value.get('unique', False), + value.get('sparse', False)) + for key, value in info.iteritems()] + self.assertTrue(([('_types', 1), ('message', 1)], False, False) in info) + + class QTest(unittest.TestCase): + def setUp(self): + connect(db='mongoenginetest') + def test_empty_q(self): """Ensure that empty Q objects won't hurt. """ diff --git a/tests/signals.py b/tests/signals.py new file mode 100644 index 00000000..fff2d398 --- /dev/null +++ b/tests/signals.py @@ -0,0 +1,130 @@ +# -*- coding: utf-8 -*- +import unittest + +from mongoengine import * +from mongoengine import signals + +signal_output = [] + + +class SignalTests(unittest.TestCase): + """ + Testing signals before/after saving and deleting. + """ + + def get_signal_output(self, fn, *args, **kwargs): + # Flush any existing signal output + global signal_output + signal_output = [] + fn(*args, **kwargs) + return signal_output + + def setUp(self): + connect(db='mongoenginetest') + class Author(Document): + name = StringField() + + def __unicode__(self): + return self.name + + @classmethod + def pre_init(cls, instance, **kwargs): + signal_output.append('pre_init signal, %s' % cls.__name__) + signal_output.append(str(kwargs['values'])) + + @classmethod + def post_init(cls, instance, **kwargs): + signal_output.append('post_init signal, %s' % instance) + + @classmethod + def pre_save(cls, instance, **kwargs): + signal_output.append('pre_save signal, %s' % instance) + + @classmethod + def post_save(cls, instance, **kwargs): + signal_output.append('post_save signal, %s' % instance) + if 'created' in kwargs: + if kwargs['created']: + signal_output.append('Is created') + else: + signal_output.append('Is updated') + + @classmethod + def pre_delete(cls, instance, **kwargs): + signal_output.append('pre_delete signal, %s' % instance) + + 
@classmethod + def post_delete(cls, instance, **kwargs): + signal_output.append('post_delete signal, %s' % instance) + + self.Author = Author + + # Save up the number of connected signals so that we can check at the end + # that all the signals we register get properly unregistered + self.pre_signals = ( + len(signals.pre_init.receivers), + len(signals.post_init.receivers), + len(signals.pre_save.receivers), + len(signals.post_save.receivers), + len(signals.pre_delete.receivers), + len(signals.post_delete.receivers) + ) + + signals.pre_init.connect(Author.pre_init) + signals.post_init.connect(Author.post_init) + signals.pre_save.connect(Author.pre_save) + signals.post_save.connect(Author.post_save) + signals.pre_delete.connect(Author.pre_delete) + signals.post_delete.connect(Author.post_delete) + + def tearDown(self): + signals.pre_init.disconnect(self.Author.pre_init) + signals.post_init.disconnect(self.Author.post_init) + signals.post_delete.disconnect(self.Author.post_delete) + signals.pre_delete.disconnect(self.Author.pre_delete) + signals.post_save.disconnect(self.Author.post_save) + signals.pre_save.disconnect(self.Author.pre_save) + + # Check that all our signals got disconnected properly. + post_signals = ( + len(signals.pre_init.receivers), + len(signals.post_init.receivers), + len(signals.pre_save.receivers), + len(signals.post_save.receivers), + len(signals.pre_delete.receivers), + len(signals.post_delete.receivers) + ) + + self.assertEqual(self.pre_signals, post_signals) + + def test_model_signals(self): + """ Model saves should throw some signals. 
""" + + def create_author(): + a1 = self.Author(name='Bill Shakespeare') + + self.assertEqual(self.get_signal_output(create_author), [ + "pre_init signal, Author", + "{'name': 'Bill Shakespeare'}", + "post_init signal, Bill Shakespeare", + ]) + + a1 = self.Author(name='Bill Shakespeare') + self.assertEqual(self.get_signal_output(a1.save), [ + "pre_save signal, Bill Shakespeare", + "post_save signal, Bill Shakespeare", + "Is created" + ]) + + a1.reload() + a1.name='William Shakespeare' + self.assertEqual(self.get_signal_output(a1.save), [ + "pre_save signal, William Shakespeare", + "post_save signal, William Shakespeare", + "Is updated" + ]) + + self.assertEqual(self.get_signal_output(a1.delete), [ + 'pre_delete signal, William Shakespeare', + 'post_delete signal, William Shakespeare', + ]) \ No newline at end of file