Merge remote branch 'upstream/dev' into dev

This commit is contained in:
Colin Howe 2011-06-08 12:24:06 +01:00
commit 6081fc6faf
19 changed files with 1060 additions and 121 deletions

View File

@ -3,3 +3,4 @@ Matt Dennewitz <mattdennewitz@gmail.com>
Deepak Thukral <iapain@yahoo.com>
Florian Schlachter <flori@n-schlachter.de>
Steve Challis <steve@stevechallis.com>
Ross Lawley <ross.lawley@gmail.com>

View File

@ -5,6 +5,11 @@ Changelog
Changes in dev
==============
- Added slave_okay kwarg to queryset
- Added insert method for bulk inserts
- Added blinker signal support
- Added query_counter context manager for tests
- Added DereferenceBaseField - for improved performance in field dereferencing
- Added optional map_reduce method item_frequencies
- Added inline_map_reduce option to map_reduce
- Updated connection exception so it provides more info on the cause.

View File

@ -49,10 +49,11 @@ Storage
=======
With MongoEngine's support for GridFS via the :class:`~mongoengine.FileField`,
it is useful to have a Django file storage backend that wraps this. The new
storage module is called :class:`~mongoengine.django.GridFSStorage`. Using it
is very similar to using the default FileSystemStorage.::
fs = mongoengine.django.GridFSStorage()
storage module is called :class:`~mongoengine.django.storage.GridFSStorage`.
Using it is very similar to using the default FileSystemStorage.::
from mongoengine.django.storage import GridFSStorage
fs = GridFSStorage()
filename = fs.save('hello.txt', 'Hello, World!')

View File

@ -341,9 +341,10 @@ Indexes
You can specify indexes on collections to make querying faster. This is done
by creating a list of index specifications called :attr:`indexes` in the
:attr:`~mongoengine.Document.meta` dictionary, where an index specification may
either be a single field name, or a tuple containing multiple field names. A
direction may be specified on fields by prefixing the field name with a **+**
or a **-** sign. Note that direction only matters on multi-field indexes. ::
either be a single field name, a tuple containing multiple field names, or a
dictionary containing a full index definition. A direction may be specified on
fields by prefixing the field name with a **+** or a **-** sign. Note that
direction only matters on multi-field indexes. ::
class Page(Document):
title = StringField()
@ -352,6 +353,21 @@ or a **-** sign. Note that direction only matters on multi-field indexes. ::
'indexes': ['title', ('title', '-rating')]
}
If a dictionary is passed then the following options are available:
:attr:`fields` (Default: None)
The fields to index. Specified in the same format as described above.
:attr:`types` (Default: True)
Whether the index should have the :attr:`_types` field added automatically
to the start of the index.
:attr:`sparse` (Default: False)
Whether the index should be sparse.
:attr:`unique` (Default: False)
Whether the index should be sparse.
.. note::
Geospatial indexes will be automatically created for all
:class:`~mongoengine.GeoPointField`\ s

View File

@ -11,3 +11,4 @@ User Guide
document-instances
querying
gridfs
signals

49
docs/guide/signals.rst Normal file
View File

@ -0,0 +1,49 @@
.. _signals:
Signals
=======
.. versionadded:: 0.5
Signal support is provided by the excellent `blinker`_ library and
will gracefully fall back if it is not available.
The following document signals exist in MongoEngine and are pretty self explaintary:
* `mongoengine.signals.pre_init`
* `mongoengine.signals.post_init`
* `mongoengine.signals.pre_save`
* `mongoengine.signals.post_save`
* `mongoengine.signals.pre_delete`
* `mongoengine.signals.post_delete`
Example usage::
from mongoengine import *
from mongoengine import signals
class Author(Document):
name = StringField()
def __unicode__(self):
return self.name
@classmethod
def pre_save(cls, instance, **kwargs):
logging.debug("Pre Save: %s" % instance.name)
@classmethod
def post_save(cls, instance, **kwargs):
logging.debug("Post Save: %s" % instance.name)
if 'created' in kwargs:
if kwargs['created']:
logging.debug("Created")
else:
logging.debug("Updated")
signals.pre_save.connect(Author.pre_save)
signals.post_save.connect(Author.post_save)
.. _blinker: http://pypi.python.org/pypi/blinker

View File

@ -6,9 +6,11 @@ import connection
from connection import *
import queryset
from queryset import *
import signals
from signals import *
__all__ = (document.__all__ + fields.__all__ + connection.__all__ +
queryset.__all__)
queryset.__all__ + signals.__all__)
__author__ = 'Harry Marr'

View File

@ -2,9 +2,12 @@ from queryset import QuerySet, QuerySetManager
from queryset import DoesNotExist, MultipleObjectsReturned
from queryset import DO_NOTHING
from mongoengine import signals
import sys
import pymongo
import pymongo.objectid
from operator import itemgetter
class NotRegistered(Exception):
@ -126,6 +129,88 @@ class BaseField(object):
self.validate(value)
class DereferenceBaseField(BaseField):
"""Handles the lazy dereferencing of a queryset. Will dereference all
items in a list / dict rather than one at a time.
"""
def __get__(self, instance, owner):
"""Descriptor to automatically dereference references.
"""
from fields import ReferenceField, GenericReferenceField
from connection import _get_db
if instance is None:
# Document class being used rather than a document object
return self
# Get value from document instance if available
value_list = instance._data.get(self.name)
if not value_list:
return super(DereferenceBaseField, self).__get__(instance, owner)
is_list = False
if not hasattr(value_list, 'items'):
is_list = True
value_list = dict([(k,v) for k,v in enumerate(value_list)])
if isinstance(self.field, ReferenceField) and value_list:
db = _get_db()
dbref = {}
collections = {}
for k, v in value_list.items():
dbref[k] = v
# Save any DBRefs
if isinstance(v, (pymongo.dbref.DBRef)):
collections.setdefault(v.collection, []).append((k, v))
# For each collection get the references
for collection, dbrefs in collections.items():
id_map = dict([(v.id, k) for k, v in dbrefs])
references = db[collection].find({'_id': {'$in': id_map.keys()}})
for ref in references:
key = id_map[ref['_id']]
dbref[key] = get_document(ref['_cls'])._from_son(ref)
if is_list:
dbref = [v for k,v in sorted(dbref.items(), key=itemgetter(0))]
instance._data[self.name] = dbref
# Get value from document instance if available
if isinstance(self.field, GenericReferenceField) and value_list:
db = _get_db()
value_list = [(k,v) for k,v in value_list.items()]
dbref = {}
classes = {}
for k, v in value_list:
dbref[k] = v
# Save any DBRefs
if isinstance(v, (dict, pymongo.son.SON)):
classes.setdefault(v['_cls'], []).append((k, v))
# For each collection get the references
for doc_cls, dbrefs in classes.items():
id_map = dict([(v['_ref'].id, k) for k, v in dbrefs])
doc_cls = get_document(doc_cls)
collection = doc_cls._meta['collection']
references = db[collection].find({'_id': {'$in': id_map.keys()}})
for ref in references:
key = id_map[ref['_id']]
dbref[key] = doc_cls._from_son(ref)
if is_list:
dbref = [v for k,v in sorted(dbref.items(), key=itemgetter(0))]
instance._data[self.name] = dbref
return super(DereferenceBaseField, self).__get__(instance, owner)
class ObjectIdField(BaseField):
"""An field wrapper around MongoDB's ObjectIds.
"""
@ -382,6 +467,8 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
class BaseDocument(object):
def __init__(self, **values):
signals.pre_init.send(self, values=values)
self._data = {}
# Assign default values to instance
for attr_name in self._fields.keys():
@ -395,6 +482,8 @@ class BaseDocument(object):
except AttributeError:
pass
signals.post_init.send(self)
def validate(self):
"""Ensure that all fields' values are valid and that required fields
are present.

View File

@ -1,3 +1,4 @@
from mongoengine import signals
from base import (DocumentMetaclass, TopLevelDocumentMetaclass, BaseDocument,
ValidationError)
from queryset import OperationError
@ -75,6 +76,8 @@ class Document(BaseDocument):
For example, ``save(..., w=2, fsync=True)`` will wait until at least two servers
have recorded the write and will force an fsync on each server being written to.
"""
signals.pre_save.send(self)
if validate:
self.validate()
@ -82,6 +85,7 @@ class Document(BaseDocument):
write_options = {}
doc = self.to_mongo()
created = '_id' not in doc
try:
collection = self.__class__.objects._collection
if force_insert:
@ -96,12 +100,16 @@ class Document(BaseDocument):
id_field = self._meta['id_field']
self[id_field] = self._fields[id_field].to_python(object_id)
signals.post_save.send(self, created=created)
def delete(self, safe=False):
"""Delete the :class:`~mongoengine.Document` from the database. This
will only take effect if the document has been previously saved.
:param safe: check if the operation succeeded before returning
"""
signals.pre_delete.send(self)
id_field = self._meta['id_field']
object_id = self._fields[id_field].to_mongo(self[id_field])
try:
@ -110,6 +118,8 @@ class Document(BaseDocument):
message = u'Could not delete document (%s)' % err.message
raise OperationError(message)
signals.post_delete.send(self)
@classmethod
def register_delete_rule(cls, document_cls, field_name, rule):
"""This method registers the delete rules to apply when removing this

View File

@ -1,4 +1,5 @@
from base import BaseField, ObjectIdField, ValidationError, get_document
from base import (BaseField, DereferenceBaseField, ObjectIdField,
ValidationError, get_document)
from queryset import DO_NOTHING
from document import Document, EmbeddedDocument
from connection import _get_db
@ -12,7 +13,6 @@ import pymongo.binary
import datetime, time
import decimal
import gridfs
import warnings
__all__ = ['StringField', 'IntField', 'FloatField', 'BooleanField',
@ -118,8 +118,8 @@ class EmailField(StringField):
EMAIL_REGEX = re.compile(
r"(^[-!#$%&'*+/=?^_`{}|~0-9A-Z]+(\.[-!#$%&'*+/=?^_`{}|~0-9A-Z]+)*" # dot-atom
r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]|\\[\001-011\013\014\016-\177])*"' # quoted-string
r')@(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,6}\.?$', re.IGNORECASE # domain
r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]|\\[\001-011\013\014\016-\177])*"' # quoted-string
r')@(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,6}\.?$', re.IGNORECASE # domain
)
def validate(self, value):
@ -153,6 +153,7 @@ class IntField(BaseField):
def prepare_query_value(self, op, value):
return int(value)
class FloatField(BaseField):
"""An floating point number field.
"""
@ -178,6 +179,7 @@ class FloatField(BaseField):
def prepare_query_value(self, op, value):
return float(value)
class DecimalField(BaseField):
"""A fixed-point decimal number field.
@ -227,6 +229,10 @@ class BooleanField(BaseField):
class DateTimeField(BaseField):
"""A datetime field.
Note: Microseconds are rounded to the nearest millisecond.
Pre UTC microsecond support is effecively broken see
`tests.field.test_datetime` for more information.
"""
def validate(self, value):
@ -252,21 +258,21 @@ class DateTimeField(BaseField):
else:
usecs = 0
kwargs = {'microsecond': usecs}
try: # Seconds are optional, so try converting seconds first.
try: # Seconds are optional, so try converting seconds first.
return datetime.datetime(*time.strptime(value, '%Y-%m-%d %H:%M:%S')[:6],
**kwargs)
except ValueError:
try: # Try without seconds.
try: # Try without seconds.
return datetime.datetime(*time.strptime(value, '%Y-%m-%d %H:%M')[:5],
**kwargs)
except ValueError: # Try without hour/minutes/seconds.
except ValueError: # Try without hour/minutes/seconds.
try:
return datetime.datetime(*time.strptime(value, '%Y-%m-%d')[:3],
**kwargs)
except ValueError:
return None
class EmbeddedDocumentField(BaseField):
"""An embedded document field. Only valid values are subclasses of
:class:`~mongoengine.EmbeddedDocument`.
@ -314,7 +320,7 @@ class EmbeddedDocumentField(BaseField):
return self.to_mongo(value)
class ListField(BaseField):
class ListField(DereferenceBaseField):
"""A list field that wraps a standard field, allowing multiple instances
of the field to be used as a list in the database.
"""
@ -330,42 +336,6 @@ class ListField(BaseField):
kwargs.setdefault('default', lambda: [])
super(ListField, self).__init__(**kwargs)
def __get__(self, instance, owner):
"""Descriptor to automatically dereference references.
"""
if instance is None:
# Document class being used rather than a document object
return self
if isinstance(self.field, ReferenceField):
referenced_type = self.field.document_type
# Get value from document instance if available
value_list = instance._data.get(self.name)
if value_list:
deref_list = []
for value in value_list:
# Dereference DBRefs
if isinstance(value, (pymongo.dbref.DBRef)):
value = _get_db().dereference(value)
deref_list.append(referenced_type._from_son(value))
else:
deref_list.append(value)
instance._data[self.name] = deref_list
if isinstance(self.field, GenericReferenceField):
value_list = instance._data.get(self.name)
if value_list:
deref_list = []
for value in value_list:
# Dereference DBRefs
if isinstance(value, (dict, pymongo.son.SON)):
deref_list.append(self.field.dereference(value))
else:
deref_list.append(value)
instance._data[self.name] = deref_list
return super(ListField, self).__get__(instance, owner)
def to_python(self, value):
return [self.field.to_python(item) for item in value]
@ -459,10 +429,10 @@ class DictField(BaseField):
if op in match_operators and isinstance(value, basestring):
return StringField().prepare_query_value(op, value)
return super(DictField,self).prepare_query_value(op, value)
return super(DictField, self).prepare_query_value(op, value)
class MapField(BaseField):
class MapField(DereferenceBaseField):
"""A field that maps a name to a specified field type. Similar to
a DictField, except the 'value' of each item must match the specified
field type.
@ -494,47 +464,11 @@ class MapField(BaseField):
except Exception, err:
raise ValidationError('Invalid MapField item (%s)' % str(item))
def __get__(self, instance, owner):
"""Descriptor to automatically dereference references.
"""
if instance is None:
# Document class being used rather than a document object
return self
if isinstance(self.field, ReferenceField):
referenced_type = self.field.document_type
# Get value from document instance if available
value_dict = instance._data.get(self.name)
if value_dict:
deref_dict = []
for key,value in value_dict.iteritems():
# Dereference DBRefs
if isinstance(value, (pymongo.dbref.DBRef)):
value = _get_db().dereference(value)
deref_dict[key] = referenced_type._from_son(value)
else:
deref_dict[key] = value
instance._data[self.name] = deref_dict
if isinstance(self.field, GenericReferenceField):
value_dict = instance._data.get(self.name)
if value_dict:
deref_dict = []
for key,value in value_dict.iteritems():
# Dereference DBRefs
if isinstance(value, (dict, pymongo.son.SON)):
deref_dict[key] = self.field.dereference(value)
else:
deref_dict[key] = value
instance._data[self.name] = deref_dict
return super(MapField, self).__get__(instance, owner)
def to_python(self, value):
return dict( [(key,self.field.to_python(item)) for key,item in value.iteritems()] )
return dict([(key, self.field.to_python(item)) for key, item in value.iteritems()])
def to_mongo(self, value):
return dict( [(key,self.field.to_mongo(item)) for key,item in value.iteritems()] )
return dict([(key, self.field.to_mongo(item)) for key, item in value.iteritems()])
def prepare_query_value(self, op, value):
if op not in ('set', 'unset'):
@ -752,11 +686,11 @@ class GridFSProxy(object):
self.newfile = self.fs.new_file(**kwargs)
self.grid_id = self.newfile._id
def put(self, file, **kwargs):
def put(self, file_obj, **kwargs):
if self.grid_id:
raise GridFSError('This document already has a file. Either delete '
'it or call replace to overwrite it')
self.grid_id = self.fs.put(file, **kwargs)
self.grid_id = self.fs.put(file_obj, **kwargs)
def write(self, string):
if self.grid_id:
@ -785,9 +719,9 @@ class GridFSProxy(object):
self.grid_id = None
self.gridout = None
def replace(self, file, **kwargs):
def replace(self, file_obj, **kwargs):
self.delete()
self.put(file, **kwargs)
self.put(file_obj, **kwargs)
def close(self):
if self.newfile:

View File

@ -336,6 +336,7 @@ class QuerySet(object):
self._snapshot = False
self._timeout = True
self._class_check = True
self._slave_okay = False
# If inheritance is allowed, only return instances and instances of
# subclasses of the class being used
@ -352,7 +353,7 @@ class QuerySet(object):
copy_props = ('_initial_query', '_query_obj', '_where_clause',
'_loaded_fields', '_ordering', '_snapshot',
'_timeout', '_limit', '_skip')
'_timeout', '_limit', '_skip', '_slave_okay')
for prop in copy_props:
val = getattr(self, prop)
@ -376,21 +377,27 @@ class QuerySet(object):
construct a multi-field index); keys may be prefixed with a **+**
or a **-** to determine the index ordering
"""
index_list = QuerySet._build_index_spec(self._document, key_or_list)
self._collection.ensure_index(index_list, drop_dups=drop_dups,
background=background)
index_spec = QuerySet._build_index_spec(self._document, key_or_list)
self._collection.ensure_index(
index_spec['fields'],
drop_dups=drop_dups,
background=background,
sparse=index_spec.get('sparse', False),
unique=index_spec.get('unique', False))
return self
@classmethod
def _build_index_spec(cls, doc_cls, key_or_list):
def _build_index_spec(cls, doc_cls, spec):
"""Build a PyMongo index spec from a MongoEngine index spec.
"""
if isinstance(key_or_list, basestring):
key_or_list = [key_or_list]
if isinstance(spec, basestring):
spec = {'fields': [spec]}
if isinstance(spec, (list, tuple)):
spec = {'fields': spec}
index_list = []
use_types = doc_cls._meta.get('allow_inheritance', True)
for key in key_or_list:
for key in spec['fields']:
# Get direction from + or -
direction = pymongo.ASCENDING
if key.startswith("-"):
@ -411,12 +418,20 @@ class QuerySet(object):
use_types = False
# If _types is being used, prepend it to every specified index
if doc_cls._meta.get('allow_inheritance') and use_types:
if (spec.get('types', True) and doc_cls._meta.get('allow_inheritance')
and use_types):
index_list.insert(0, ('_types', 1))
return index_list
spec['fields'] = index_list
def __call__(self, q_obj=None, class_check=True, **query):
if spec.get('sparse', False) and len(spec['fields']) > 1:
raise ValueError(
'Sparse indexes can only have one field in them. '
'See https://jira.mongodb.org/browse/SERVER-2193')
return spec
def __call__(self, q_obj=None, class_check=True, slave_okay=False, **query):
"""Filter the selected documents by calling the
:class:`~mongoengine.queryset.QuerySet` with a query.
@ -426,6 +441,8 @@ class QuerySet(object):
objects, only the last one will be used
:param class_check: If set to False bypass class name check when
querying collection
:param slave_okay: if True, allows this query to be run against a
replica secondary.
:param query: Django-style query keyword arguments
"""
query = Q(**query)
@ -465,9 +482,12 @@ class QuerySet(object):
# Ensure document-defined indexes are created
if self._document._meta['indexes']:
for key_or_list in self._document._meta['indexes']:
self._collection.ensure_index(key_or_list,
background=background, **index_opts)
for spec in self._document._meta['indexes']:
opts = index_opts.copy()
opts['unique'] = spec.get('unique', False)
opts['sparse'] = spec.get('sparse', False)
self._collection.ensure_index(spec['fields'],
background=background, **opts)
# If _types is being used (for polymorphism), it needs an index
if '_types' in self._query:
@ -483,17 +503,23 @@ class QuerySet(object):
return self._collection_obj
@property
def _cursor_args(self):
cursor_args = {
'snapshot': self._snapshot,
'timeout': self._timeout,
'slave_okay': self._slave_okay
}
if self._loaded_fields:
cursor_args['fields'] = self._loaded_fields.as_dict()
return cursor_args
@property
def _cursor(self):
if self._cursor_obj is None:
cursor_args = {
'snapshot': self._snapshot,
'timeout': self._timeout,
}
if self._loaded_fields:
cursor_args['fields'] = self._loaded_fields.as_dict()
self._cursor_obj = self._collection.find(self._query,
**cursor_args)
**self._cursor_args)
# Apply where clauses to cursor
if self._where_clause:
self._cursor_obj.where(self._where_clause)
@ -702,6 +728,46 @@ class QuerySet(object):
result = None
return result
def insert(self, doc_or_docs, load_bulk=True):
"""bulk insert documents
:param docs_or_doc: a document or list of documents to be inserted
:param load_bulk (optional): If True returns the list of document instances
By default returns document instances, set ``load_bulk`` to False to
return just ``ObjectIds``
.. versionadded:: 0.5
"""
from document import Document
docs = doc_or_docs
return_one = False
if isinstance(docs, Document) or issubclass(docs.__class__, Document):
return_one = True
docs = [docs]
raw = []
for doc in docs:
if not isinstance(doc, self._document):
msg = "Some documents inserted aren't instances of %s" % str(self._document)
raise OperationError(msg)
if doc.pk:
msg = "Some documents have ObjectIds use doc.update() instead"
raise OperationError(msg)
raw.append(doc.to_mongo())
ids = self._collection.insert(raw)
if not load_bulk:
return return_one and ids[0] or ids
documents = self.in_bulk(ids)
results = []
for obj_id in ids:
results.append(documents.get(obj_id))
return return_one and results[0] or results
def with_id(self, object_id):
"""Retrieve the object matching the id provided.
@ -710,7 +776,7 @@ class QuerySet(object):
id_field = self._document._meta['id_field']
object_id = self._document._fields[id_field].to_mongo(object_id)
result = self._collection.find_one({'_id': object_id})
result = self._collection.find_one({'_id': object_id}, **self._cursor_args)
if result is not None:
result = self._document._from_son(result)
return result
@ -726,7 +792,8 @@ class QuerySet(object):
"""
doc_map = {}
docs = self._collection.find({'_id': {'$in': object_ids}})
docs = self._collection.find({'_id': {'$in': object_ids}},
**self._cursor_args)
for doc in docs:
doc_map[doc['_id']] = self._document._from_son(doc)
@ -1023,6 +1090,7 @@ class QuerySet(object):
:param enabled: whether or not snapshot mode is enabled
"""
self._snapshot = enabled
return self
def timeout(self, enabled):
"""Enable or disable the default mongod timeout when querying.
@ -1030,6 +1098,15 @@ class QuerySet(object):
:param enabled: whether or not the timeout is used
"""
self._timeout = enabled
return self
def slave_okay(self, enabled):
"""Enable or disable the slave_okay when querying.
:param enabled: whether or not the slave_okay is enabled
"""
self._slave_okay = enabled
return self
def delete(self, safe=False):
"""Delete the documents matched by the query.

44
mongoengine/signals.py Normal file
View File

@ -0,0 +1,44 @@
# -*- coding: utf-8 -*-
__all__ = ['pre_init', 'post_init', 'pre_save', 'post_save',
'pre_delete', 'post_delete']
signals_available = False
try:
from blinker import Namespace
signals_available = True
except ImportError:
class Namespace(object):
def signal(self, name, doc=None):
return _FakeSignal(name, doc)
class _FakeSignal(object):
"""If blinker is unavailable, create a fake class with the same
interface that allows sending of signals but will fail with an
error on anything else. Instead of doing anything on send, it
will just ignore the arguments and do nothing instead.
"""
def __init__(self, name, doc=None):
self.name = name
self.__doc__ = doc
def _fail(self, *args, **kwargs):
raise RuntimeError('signalling support is unavailable '
'because the blinker library is '
'not installed.')
send = lambda *a, **kw: None
connect = disconnect = has_receivers_for = receivers_for = \
temporarily_connected_to = _fail
del _fail
# the namespace for code signals. If you are not mongoengine code, do
# not put signals in here. Create your own namespace instead.
_signals = Namespace()
pre_init = _signals.signal('pre_init')
post_init = _signals.signal('post_init')
pre_save = _signals.signal('pre_save')
post_save = _signals.signal('post_save')
pre_delete = _signals.signal('pre_delete')
post_delete = _signals.signal('post_delete')

59
mongoengine/tests.py Normal file
View File

@ -0,0 +1,59 @@
from mongoengine.connection import _get_db
class query_counter(object):
""" Query_counter contextmanager to get the number of queries. """
def __init__(self):
""" Construct the query_counter. """
self.counter = 0
self.db = _get_db()
def __enter__(self):
""" On every with block we need to drop the profile collection. """
self.db.set_profiling_level(0)
self.db.system.profile.drop()
self.db.set_profiling_level(2)
return self
def __exit__(self, t, value, traceback):
""" Reset the profiling level. """
self.db.set_profiling_level(0)
def __eq__(self, value):
""" == Compare querycounter. """
return value == self._get_count()
def __ne__(self, value):
""" != Compare querycounter. """
return not self.__eq__(value)
def __lt__(self, value):
""" < Compare querycounter. """
return self._get_count() < value
def __le__(self, value):
""" <= Compare querycounter. """
return self._get_count() <= value
def __gt__(self, value):
""" > Compare querycounter. """
return self._get_count() > value
def __ge__(self, value):
""" >= Compare querycounter. """
return self._get_count() >= value
def __int__(self):
""" int representation. """
return self._get_count()
def __repr__(self):
""" repr query_counter as the number of queries. """
return u"%s" % self._get_count()
def _get_count(self):
""" Get the number of queries. """
count = self.db.system.profile.find().count() - self.counter
self.counter += 1
return count

View File

@ -45,6 +45,6 @@ setup(name='mongoengine',
long_description=LONG_DESCRIPTION,
platforms=['any'],
classifiers=CLASSIFIERS,
install_requires=['pymongo'],
install_requires=['pymongo', 'blinker'],
test_suite='tests',
)

288
tests/dereference.py Normal file
View File

@ -0,0 +1,288 @@
import unittest
from mongoengine import *
from mongoengine.connection import _get_db
from mongoengine.tests import query_counter
class FieldTest(unittest.TestCase):
def setUp(self):
connect(db='mongoenginetest')
self.db = _get_db()
def test_list_item_dereference(self):
"""Ensure that DBRef items in ListFields are dereferenced.
"""
class User(Document):
name = StringField()
class Group(Document):
members = ListField(ReferenceField(User))
User.drop_collection()
Group.drop_collection()
for i in xrange(1, 51):
user = User(name='user %s' % i)
user.save()
group = Group(members=User.objects)
group.save()
with query_counter() as q:
self.assertEqual(q, 0)
group_obj = Group.objects.first()
self.assertEqual(q, 1)
[m for m in group_obj.members]
self.assertEqual(q, 2)
User.drop_collection()
Group.drop_collection()
def test_recursive_reference(self):
"""Ensure that ReferenceFields can reference their own documents.
"""
class Employee(Document):
name = StringField()
boss = ReferenceField('self')
friends = ListField(ReferenceField('self'))
bill = Employee(name='Bill Lumbergh')
bill.save()
michael = Employee(name='Michael Bolton')
michael.save()
samir = Employee(name='Samir Nagheenanajar')
samir.save()
friends = [michael, samir]
peter = Employee(name='Peter Gibbons', boss=bill, friends=friends)
peter.save()
with query_counter() as q:
self.assertEqual(q, 0)
peter = Employee.objects.with_id(peter.id)
self.assertEqual(q, 1)
peter.boss
self.assertEqual(q, 2)
peter.friends
self.assertEqual(q, 3)
def test_generic_reference(self):
class UserA(Document):
name = StringField()
class UserB(Document):
name = StringField()
class UserC(Document):
name = StringField()
class Group(Document):
members = ListField(GenericReferenceField())
UserA.drop_collection()
UserB.drop_collection()
UserC.drop_collection()
Group.drop_collection()
members = []
for i in xrange(1, 51):
a = UserA(name='User A %s' % i)
a.save()
b = UserB(name='User B %s' % i)
b.save()
c = UserC(name='User C %s' % i)
c.save()
members += [a, b, c]
group = Group(members=members)
group.save()
with query_counter() as q:
self.assertEqual(q, 0)
group_obj = Group.objects.first()
self.assertEqual(q, 1)
[m for m in group_obj.members]
self.assertEqual(q, 4)
[m for m in group_obj.members]
self.assertEqual(q, 4)
UserA.drop_collection()
UserB.drop_collection()
UserC.drop_collection()
Group.drop_collection()
def test_map_field_reference(self):
class User(Document):
name = StringField()
class Group(Document):
members = MapField(ReferenceField(User))
User.drop_collection()
Group.drop_collection()
members = []
for i in xrange(1, 51):
user = User(name='user %s' % i)
user.save()
members.append(user)
group = Group(members=dict([(str(u.id), u) for u in members]))
group.save()
with query_counter() as q:
self.assertEqual(q, 0)
group_obj = Group.objects.first()
self.assertEqual(q, 1)
[m for m in group_obj.members]
self.assertEqual(q, 2)
User.drop_collection()
Group.drop_collection()
def ztest_generic_reference_dict_field(self):
class UserA(Document):
name = StringField()
class UserB(Document):
name = StringField()
class UserC(Document):
name = StringField()
class Group(Document):
members = DictField()
UserA.drop_collection()
UserB.drop_collection()
UserC.drop_collection()
Group.drop_collection()
members = []
for i in xrange(1, 51):
a = UserA(name='User A %s' % i)
a.save()
b = UserB(name='User B %s' % i)
b.save()
c = UserC(name='User C %s' % i)
c.save()
members += [a, b, c]
group = Group(members=dict([(str(u.id), u) for u in members]))
group.save()
with query_counter() as q:
self.assertEqual(q, 0)
group_obj = Group.objects.first()
self.assertEqual(q, 1)
[m for m in group_obj.members]
self.assertEqual(q, 4)
[m for m in group_obj.members]
self.assertEqual(q, 4)
group.members = {}
group.save()
with query_counter() as q:
self.assertEqual(q, 0)
group_obj = Group.objects.first()
self.assertEqual(q, 1)
[m for m in group_obj.members]
self.assertEqual(q, 1)
UserA.drop_collection()
UserB.drop_collection()
UserC.drop_collection()
Group.drop_collection()
def test_generic_reference_map_field(self):
class UserA(Document):
name = StringField()
class UserB(Document):
name = StringField()
class UserC(Document):
name = StringField()
class Group(Document):
members = MapField(GenericReferenceField())
UserA.drop_collection()
UserB.drop_collection()
UserC.drop_collection()
Group.drop_collection()
members = []
for i in xrange(1, 51):
a = UserA(name='User A %s' % i)
a.save()
b = UserB(name='User B %s' % i)
b.save()
c = UserC(name='User C %s' % i)
c.save()
members += [a, b, c]
group = Group(members=dict([(str(u.id), u) for u in members]))
group.save()
with query_counter() as q:
self.assertEqual(q, 0)
group_obj = Group.objects.first()
self.assertEqual(q, 1)
[m for m in group_obj.members]
self.assertEqual(q, 4)
[m for m in group_obj.members]
self.assertEqual(q, 4)
group.members = {}
group.save()
with query_counter() as q:
self.assertEqual(q, 0)
group_obj = Group.objects.first()
self.assertEqual(q, 1)
[m for m in group_obj.members]
self.assertEqual(q, 1)
UserA.drop_collection()
UserB.drop_collection()
UserC.drop_collection()
Group.drop_collection()

View File

@ -377,6 +377,40 @@ class DocumentTest(unittest.TestCase):
BlogPost.drop_collection()
def test_dictionary_indexes(self):
"""Ensure that indexes are used when meta[indexes] contains dictionaries
instead of lists.
"""
class BlogPost(Document):
date = DateTimeField(db_field='addDate', default=datetime.now)
category = StringField()
tags = ListField(StringField())
meta = {
'indexes': [
{ 'fields': ['-date'], 'unique': True,
'sparse': True, 'types': False },
],
}
BlogPost.drop_collection()
info = BlogPost.objects._collection.index_information()
# _id, '-date'
self.assertEqual(len(info), 3)
# Indexes are lazy so use list() to perform query
list(BlogPost.objects)
info = BlogPost.objects._collection.index_information()
info = [(value['key'],
value.get('unique', False),
value.get('sparse', False))
for key, value in info.iteritems()]
self.assertTrue(([('addDate', -1)], True, True) in info)
BlogPost.drop_collection()
def test_unique(self):
"""Ensure that uniqueness constraints are applied to fields.
"""

View File

@ -187,6 +187,66 @@ class FieldTest(unittest.TestCase):
log.time = '1pm'
self.assertRaises(ValidationError, log.validate)
def test_datetime(self):
"""Tests showing pymongo datetime fields handling of microseconds.
Microseconds are rounded to the nearest millisecond and pre UTC
handling is wonky.
See: http://api.mongodb.org/python/current/api/bson/son.html#dt
"""
class LogEntry(Document):
date = DateTimeField()
LogEntry.drop_collection()
# Post UTC - microseconds are rounded (down) nearest millisecond and dropped
d1 = datetime.datetime(1970, 01, 01, 00, 00, 01, 999)
d2 = datetime.datetime(1970, 01, 01, 00, 00, 01)
log = LogEntry()
log.date = d1
log.save()
log.reload()
self.assertNotEquals(log.date, d1)
self.assertEquals(log.date, d2)
# Post UTC - microseconds are rounded (down) nearest millisecond
d1 = datetime.datetime(1970, 01, 01, 00, 00, 01, 9999)
d2 = datetime.datetime(1970, 01, 01, 00, 00, 01, 9000)
log.date = d1
log.save()
log.reload()
self.assertNotEquals(log.date, d1)
self.assertEquals(log.date, d2)
# Pre UTC dates microseconds below 1000 are dropped
d1 = datetime.datetime(1969, 12, 31, 23, 59, 59, 999)
d2 = datetime.datetime(1969, 12, 31, 23, 59, 59)
log.date = d1
log.save()
log.reload()
self.assertNotEquals(log.date, d1)
self.assertEquals(log.date, d2)
# Pre UTC microseconds above 1000 is wonky.
# log.date has an invalid microsecond value so I can't construct
# a date to compare.
#
# However, the timedelta is predicable with pre UTC timestamps
# It always adds 16 seconds and [777216-776217] microseconds
for i in xrange(1001, 3113, 33):
d1 = datetime.datetime(1969, 12, 31, 23, 59, 59, i)
log.date = d1
log.save()
log.reload()
self.assertNotEquals(log.date, d1)
delta = log.date - d1
self.assertEquals(delta.seconds, 16)
microseconds = 777216 - (i % 1000)
self.assertEquals(delta.microseconds, microseconds)
LogEntry.drop_collection()
def test_list_validation(self):
"""Ensure that a list field only accepts lists with valid elements.
"""

View File

@ -9,6 +9,7 @@ from mongoengine.queryset import (QuerySet, QuerySetManager,
MultipleObjectsReturned, DoesNotExist,
QueryFieldList)
from mongoengine import *
from mongoengine.tests import query_counter
class QuerySetTest(unittest.TestCase):
@ -331,6 +332,125 @@ class QuerySetTest(unittest.TestCase):
person = self.Person.objects.get(age=50)
self.assertEqual(person.name, "User C")
def test_bulk_insert(self):
"""Ensure that query by array position works.
"""
class Comment(EmbeddedDocument):
name = StringField()
class Post(EmbeddedDocument):
comments = ListField(EmbeddedDocumentField(Comment))
class Blog(Document):
title = StringField()
tags = ListField(StringField())
posts = ListField(EmbeddedDocumentField(Post))
Blog.drop_collection()
with query_counter() as q:
self.assertEqual(q, 0)
comment1 = Comment(name='testa')
comment2 = Comment(name='testb')
post1 = Post(comments=[comment1, comment2])
post2 = Post(comments=[comment2, comment2])
blogs = []
for i in xrange(1, 100):
blogs.append(Blog(title="post %s" % i, posts=[post1, post2]))
Blog.objects.insert(blogs, load_bulk=False)
self.assertEqual(q, 2) # 1 for the inital connection and 1 for the insert
Blog.objects.insert(blogs)
self.assertEqual(q, 4) # 1 for insert, and 1 for in bulk
Blog.drop_collection()
comment1 = Comment(name='testa')
comment2 = Comment(name='testb')
post1 = Post(comments=[comment1, comment2])
post2 = Post(comments=[comment2, comment2])
blog1 = Blog(title="code", posts=[post1, post2])
blog2 = Blog(title="mongodb", posts=[post2, post1])
blog1, blog2 = Blog.objects.insert([blog1, blog2])
self.assertEqual(blog1.title, "code")
self.assertEqual(blog2.title, "mongodb")
self.assertEqual(Blog.objects.count(), 2)
# test handles people trying to upsert
def throw_operation_error():
blogs = Blog.objects
Blog.objects.insert(blogs)
self.assertRaises(OperationError, throw_operation_error)
# test handles other classes being inserted
def throw_operation_error_wrong_doc():
class Author(Document):
pass
Blog.objects.insert(Author())
self.assertRaises(OperationError, throw_operation_error_wrong_doc)
def throw_operation_error_not_a_document():
Blog.objects.insert("HELLO WORLD")
self.assertRaises(OperationError, throw_operation_error_not_a_document)
Blog.drop_collection()
blog1 = Blog(title="code", posts=[post1, post2])
blog1 = Blog.objects.insert(blog1)
self.assertEqual(blog1.title, "code")
self.assertEqual(Blog.objects.count(), 1)
Blog.drop_collection()
blog1 = Blog(title="code", posts=[post1, post2])
obj_id = Blog.objects.insert(blog1, load_bulk=False)
self.assertEquals(obj_id.__class__.__name__, 'ObjectId')
def test_slave_okay(self):
"""Ensures that a query can take slave_okay syntax
"""
person1 = self.Person(name="User A", age=20)
person1.save()
person2 = self.Person(name="User B", age=30)
person2.save()
# Retrieve the first person from the database
person = self.Person.objects.slave_okay(True).first()
self.assertTrue(isinstance(person, self.Person))
self.assertEqual(person.name, "User A")
self.assertEqual(person.age, 20)
def test_cursor_args(self):
"""Ensures the cursor args can be set as expected
"""
p = self.Person.objects
# Check default
self.assertEqual(p._cursor_args,
{'snapshot': False, 'slave_okay': False, 'timeout': True})
p.snapshot(False).slave_okay(False).timeout(False)
self.assertEqual(p._cursor_args,
{'snapshot': False, 'slave_okay': False, 'timeout': False})
p.snapshot(True).slave_okay(False).timeout(False)
self.assertEqual(p._cursor_args,
{'snapshot': True, 'slave_okay': False, 'timeout': False})
p.snapshot(True).slave_okay(True).timeout(False)
self.assertEqual(p._cursor_args,
{'snapshot': True, 'slave_okay': True, 'timeout': False})
p.snapshot(True).slave_okay(True).timeout(True)
self.assertEqual(p._cursor_args,
{'snapshot': True, 'slave_okay': True, 'timeout': True})
def test_repeated_iteration(self):
"""Ensure that QuerySet rewinds itself one iteration finishes.
"""
@ -2099,8 +2219,27 @@ class QuerySetTest(unittest.TestCase):
Number.drop_collection()
def test_ensure_index(self):
"""Ensure that manual creation of indexes works.
"""
class Comment(Document):
message = StringField()
Comment.objects.ensure_index('message')
info = Comment.objects._collection.index_information()
info = [(value['key'],
value.get('unique', False),
value.get('sparse', False))
for key, value in info.iteritems()]
self.assertTrue(([('_types', 1), ('message', 1)], False, False) in info)
class QTest(unittest.TestCase):
def setUp(self):
connect(db='mongoenginetest')
def test_empty_q(self):
"""Ensure that empty Q objects won't hurt.
"""

130
tests/signals.py Normal file
View File

@ -0,0 +1,130 @@
# -*- coding: utf-8 -*-
import unittest
from mongoengine import *
from mongoengine import signals
signal_output = []
class SignalTests(unittest.TestCase):
"""
Testing signals before/after saving and deleting.
"""
def get_signal_output(self, fn, *args, **kwargs):
# Flush any existing signal output
global signal_output
signal_output = []
fn(*args, **kwargs)
return signal_output
def setUp(self):
connect(db='mongoenginetest')
class Author(Document):
name = StringField()
def __unicode__(self):
return self.name
@classmethod
def pre_init(cls, instance, **kwargs):
signal_output.append('pre_init signal, %s' % cls.__name__)
signal_output.append(str(kwargs['values']))
@classmethod
def post_init(cls, instance, **kwargs):
signal_output.append('post_init signal, %s' % instance)
@classmethod
def pre_save(cls, instance, **kwargs):
signal_output.append('pre_save signal, %s' % instance)
@classmethod
def post_save(cls, instance, **kwargs):
signal_output.append('post_save signal, %s' % instance)
if 'created' in kwargs:
if kwargs['created']:
signal_output.append('Is created')
else:
signal_output.append('Is updated')
@classmethod
def pre_delete(cls, instance, **kwargs):
signal_output.append('pre_delete signal, %s' % instance)
@classmethod
def post_delete(cls, instance, **kwargs):
signal_output.append('post_delete signal, %s' % instance)
self.Author = Author
# Save up the number of connected signals so that we can check at the end
# that all the signals we register get properly unregistered
self.pre_signals = (
len(signals.pre_init.receivers),
len(signals.post_init.receivers),
len(signals.pre_save.receivers),
len(signals.post_save.receivers),
len(signals.pre_delete.receivers),
len(signals.post_delete.receivers)
)
signals.pre_init.connect(Author.pre_init)
signals.post_init.connect(Author.post_init)
signals.pre_save.connect(Author.pre_save)
signals.post_save.connect(Author.post_save)
signals.pre_delete.connect(Author.pre_delete)
signals.post_delete.connect(Author.post_delete)
def tearDown(self):
signals.pre_init.disconnect(self.Author.pre_init)
signals.post_init.disconnect(self.Author.post_init)
signals.post_delete.disconnect(self.Author.post_delete)
signals.pre_delete.disconnect(self.Author.pre_delete)
signals.post_save.disconnect(self.Author.post_save)
signals.pre_save.disconnect(self.Author.pre_save)
# Check that all our signals got disconnected properly.
post_signals = (
len(signals.pre_init.receivers),
len(signals.post_init.receivers),
len(signals.pre_save.receivers),
len(signals.post_save.receivers),
len(signals.pre_delete.receivers),
len(signals.post_delete.receivers)
)
self.assertEqual(self.pre_signals, post_signals)
def test_model_signals(self):
""" Model saves should throw some signals. """
def create_author():
a1 = self.Author(name='Bill Shakespeare')
self.assertEqual(self.get_signal_output(create_author), [
"pre_init signal, Author",
"{'name': 'Bill Shakespeare'}",
"post_init signal, Bill Shakespeare",
])
a1 = self.Author(name='Bill Shakespeare')
self.assertEqual(self.get_signal_output(a1.save), [
"pre_save signal, Bill Shakespeare",
"post_save signal, Bill Shakespeare",
"Is created"
])
a1.reload()
a1.name='William Shakespeare'
self.assertEqual(self.get_signal_output(a1.save), [
"pre_save signal, William Shakespeare",
"post_save signal, William Shakespeare",
"Is updated"
])
self.assertEqual(self.get_signal_output(a1.delete), [
'pre_delete signal, William Shakespeare',
'post_delete signal, William Shakespeare',
])