Merge branch 'dev' into pull_124

This commit is contained in:
Ross Lawley
2011-05-25 09:54:56 +01:00
10 changed files with 1184 additions and 240 deletions

View File

@@ -8,6 +8,7 @@ import pymongo.objectid
import re
import copy
import itertools
import operator
__all__ = ['queryset_manager', 'Q', 'InvalidQueryError',
'InvalidCollectionError', 'DO_NOTHING', 'NULLIFY', 'CASCADE', 'DENY']
@@ -280,30 +281,30 @@ class QueryFieldList(object):
ONLY = True
EXCLUDE = False
def __init__(self, fields=[], direction=ONLY, always_include=[]):
self.direction = direction
def __init__(self, fields=[], value=ONLY, always_include=[]):
self.value = value
self.fields = set(fields)
self.always_include = set(always_include)
def as_dict(self):
return dict((field, self.direction) for field in self.fields)
return dict((field, self.value) for field in self.fields)
def __add__(self, f):
if not self.fields:
self.fields = f.fields
self.direction = f.direction
elif self.direction is self.ONLY and f.direction is self.ONLY:
self.value = f.value
elif self.value is self.ONLY and f.value is self.ONLY:
self.fields = self.fields.intersection(f.fields)
elif self.direction is self.EXCLUDE and f.direction is self.EXCLUDE:
elif self.value is self.EXCLUDE and f.value is self.EXCLUDE:
self.fields = self.fields.union(f.fields)
elif self.direction is self.ONLY and f.direction is self.EXCLUDE:
elif self.value is self.ONLY and f.value is self.EXCLUDE:
self.fields -= f.fields
elif self.direction is self.EXCLUDE and f.direction is self.ONLY:
self.direction = self.ONLY
elif self.value is self.EXCLUDE and f.value is self.ONLY:
self.value = self.ONLY
self.fields = f.fields - self.fields
if self.always_include:
if self.direction is self.ONLY and self.fields:
if self.value is self.ONLY and self.fields:
self.fields = self.fields.union(self.always_include)
else:
self.fields -= self.always_include
@@ -311,7 +312,7 @@ class QueryFieldList(object):
def reset(self):
self.fields = set([])
self.direction = self.ONLY
self.value = self.ONLY
def __nonzero__(self):
return bool(self.fields)
@@ -334,6 +335,7 @@ class QuerySet(object):
self._ordering = []
self._snapshot = False
self._timeout = True
self._class_check = True
# If inheritance is allowed, only return instances and instances of
# subclasses of the class being used
@@ -344,11 +346,26 @@ class QuerySet(object):
self._limit = None
self._skip = None
def clone(self):
"""Creates a copy of the current :class:`~mongoengine.queryset.QuerySet`"""
c = self.__class__(self._document, self._collection_obj)
copy_props = ('_initial_query', '_query_obj', '_where_clause',
'_loaded_fields', '_ordering', '_snapshot',
'_timeout', '_limit', '_skip')
for prop in copy_props:
val = getattr(self, prop)
setattr(c, prop, copy.deepcopy(val))
return c
@property
def _query(self):
if self._mongo_query is None:
self._mongo_query = self._query_obj.to_query(self._document)
self._mongo_query.update(self._initial_query)
if self._class_check:
self._mongo_query.update(self._initial_query)
return self._mongo_query
def ensure_index(self, key_or_list, drop_dups=False, background=False,
@@ -399,7 +416,7 @@ class QuerySet(object):
return index_list
def __call__(self, q_obj=None, **query):
def __call__(self, q_obj=None, class_check=True, **query):
"""Filter the selected documents by calling the
:class:`~mongoengine.queryset.QuerySet` with a query.
@@ -407,16 +424,17 @@ class QuerySet(object):
the query; the :class:`~mongoengine.queryset.QuerySet` is filtered
multiple times with different :class:`~mongoengine.queryset.Q`
objects, only the last one will be used
:param class_check: If set to False bypass class name check when
querying collection
:param query: Django-style query keyword arguments
"""
#if q_obj:
#self._where_clause = q_obj.as_js(self._document)
query = Q(**query)
if q_obj:
query &= q_obj
self._query_obj &= query
self._mongo_query = None
self._cursor_obj = None
self._class_check = class_check
return self
def filter(self, *q_objs, **query):
@@ -440,17 +458,17 @@ class QuerySet(object):
drop_dups = self._document._meta.get('index_drop_dups', False)
index_opts = self._document._meta.get('index_options', {})
# Ensure indexes created by uniqueness constraints
for index in self._document._meta['unique_indexes']:
self._collection.ensure_index(index, unique=True,
background=background, drop_dups=drop_dups, **index_opts)
# Ensure document-defined indexes are created
if self._document._meta['indexes']:
for key_or_list in self._document._meta['indexes']:
self._collection.ensure_index(key_or_list,
background=background, **index_opts)
# Ensure indexes created by uniqueness constraints
for index in self._document._meta['unique_indexes']:
self._collection.ensure_index(index, unique=True,
background=background, drop_dups=drop_dups, **index_opts)
# If _types is being used (for polymorphism), it needs an index
if '_types' in self._query:
self._collection.ensure_index('_types',
@@ -474,7 +492,7 @@ class QuerySet(object):
}
if self._loaded_fields:
cursor_args['fields'] = self._loaded_fields.as_dict()
self._cursor_obj = self._collection.find(self._query,
self._cursor_obj = self._collection.find(self._query,
**cursor_args)
# Apply where clauses to cursor
if self._where_clause:
@@ -504,6 +522,15 @@ class QuerySet(object):
fields = []
field = None
for field_name in parts:
# Handle ListField indexing:
if field_name.isdigit():
try:
field = field.field
except AttributeError, err:
raise InvalidQueryError(
"Can't use index on unsubscriptable field (%s)" % err)
fields.append(field_name)
continue
if field is None:
# Look up first field from the document
if field_name == 'pk':
@@ -528,14 +555,14 @@ class QuerySet(object):
return '.'.join(parts)
@classmethod
def _transform_query(cls, _doc_cls=None, **query):
def _transform_query(cls, _doc_cls=None, _field_operation=False, **query):
"""Transform a query from Django-style format to Mongo format.
"""
operators = ['ne', 'gt', 'gte', 'lt', 'lte', 'in', 'nin', 'mod',
'all', 'size', 'exists', 'not']
geo_operators = ['within_distance', 'within_box', 'near']
match_operators = ['contains', 'icontains', 'startswith',
'istartswith', 'endswith', 'iendswith',
geo_operators = ['within_distance', 'within_spherical_distance', 'within_box', 'near', 'near_sphere']
match_operators = ['contains', 'icontains', 'startswith',
'istartswith', 'endswith', 'iendswith',
'exact', 'iexact']
mongo_query = {}
@@ -577,8 +604,12 @@ class QuerySet(object):
if op in geo_operators:
if op == "within_distance":
value = {'$within': {'$center': value}}
elif op == "within_spherical_distance":
value = {'$within': {'$centerSphere': value}}
elif op == "near":
value = {'$near': value}
elif op == "near_sphere":
value = {'$nearSphere': value}
elif op == 'within_box':
value = {'$within': {'$box': value}}
else:
@@ -620,9 +651,9 @@ class QuerySet(object):
raise self._document.DoesNotExist("%s matching query does not exist."
% self._document._class_name)
def get_or_create(self, *q_objs, **query):
"""Retrieve unique object or create, if it doesn't exist. Returns a tuple of
``(object, created)``, where ``object`` is the retrieved or created object
def get_or_create(self, write_options=None, *q_objs, **query):
"""Retrieve unique object or create, if it doesn't exist. Returns a tuple of
``(object, created)``, where ``object`` is the retrieved or created object
and ``created`` is a boolean specifying whether a new object was created. Raises
:class:`~mongoengine.queryset.MultipleObjectsReturned` or
`DocumentName.MultipleObjectsReturned` if multiple results are found.
@@ -630,6 +661,10 @@ class QuerySet(object):
dictionary of default values for the new document may be provided as a
keyword argument called :attr:`defaults`.
:param write_options: optional extra keyword arguments used if we
have to create a new document.
Passes any write_options onto :meth:`~mongoengine.document.Document.save`
.. versionadded:: 0.3
"""
defaults = query.get('defaults', {})
@@ -641,7 +676,7 @@ class QuerySet(object):
if count == 0:
query.update(defaults)
doc = self._document(**query)
doc.save()
doc.save(write_options=write_options)
return doc, True
elif count == 1:
return self.first(), False
@@ -725,7 +760,7 @@ class QuerySet(object):
def __len__(self):
return self.count()
def map_reduce(self, map_f, reduce_f, finalize_f=None, limit=None,
def map_reduce(self, map_f, reduce_f, output, finalize_f=None, limit=None,
scope=None, keep_temp=False):
"""Perform a map/reduce query using the current query spec
and ordering. While ``map_reduce`` respects ``QuerySet`` chaining,
@@ -739,26 +774,26 @@ class QuerySet(object):
:param map_f: map function, as :class:`~pymongo.code.Code` or string
:param reduce_f: reduce function, as
:class:`~pymongo.code.Code` or string
:param output: output collection name
:param finalize_f: finalize function, an optional function that
performs any post-reduction processing.
:param scope: values to insert into map/reduce global scope. Optional.
:param limit: number of objects from current query to provide
to map/reduce method
:param keep_temp: keep temporary table (boolean, default ``True``)
Returns an iterator yielding
:class:`~mongoengine.document.MapReduceDocument`.
.. note:: Map/Reduce requires server version **>= 1.1.1**. The PyMongo
.. note:: Map/Reduce changed in server version **>= 1.7.4**. The PyMongo
:meth:`~pymongo.collection.Collection.map_reduce` helper requires
PyMongo version **>= 1.2**.
PyMongo version **>= 1.11**.
.. versionadded:: 0.3
"""
from document import MapReduceDocument
if not hasattr(self._collection, "map_reduce"):
raise NotImplementedError("Requires MongoDB >= 1.1.1")
raise NotImplementedError("Requires MongoDB >= 1.7.1")
map_f_scope = {}
if isinstance(map_f, pymongo.code.Code):
@@ -789,8 +824,7 @@ class QuerySet(object):
if limit:
mr_args['limit'] = limit
results = self._collection.map_reduce(map_f, reduce_f, **mr_args)
results = self._collection.map_reduce(map_f, reduce_f, output, **mr_args)
results = results.find()
if self._ordering:
@@ -835,7 +869,7 @@ class QuerySet(object):
self._skip, self._limit = key.start, key.stop
except IndexError, err:
# PyMongo raises an error if key.start == key.stop, catch it,
# bin it, kill it.
# bin it, kill it.
start = key.start or 0
if start >= 0 and key.stop >= 0 and key.step is None:
if start == key.stop:
@@ -868,10 +902,8 @@ class QuerySet(object):
.. versionadded:: 0.3
"""
fields = self._fields_to_dbfields(fields)
self._loaded_fields += QueryFieldList(fields, direction=QueryFieldList.ONLY)
return self
fields = dict([(f, QueryFieldList.ONLY) for f in fields])
return self.fields(**fields)
def exclude(self, *fields):
"""Opposite to .only(), exclude some document's fields. ::
@@ -880,8 +912,44 @@ class QuerySet(object):
:param fields: fields to exclude
"""
fields = self._fields_to_dbfields(fields)
self._loaded_fields += QueryFieldList(fields, direction=QueryFieldList.EXCLUDE)
fields = dict([(f, QueryFieldList.EXCLUDE) for f in fields])
return self.fields(**fields)
def fields(self, **kwargs):
"""Manipulate how you load this document's fields. Used by `.only()`
and `.exclude()` to manipulate which fields to retrieve. Fields also
allows for a greater level of control for example:
Retrieving a Subrange of Array Elements
---------------------------------------
You can use the $slice operator to retrieve a subrange of elements in
an array ::
post = BlogPost.objects(...).fields(slice__comments=5) // first 5 comments
:param kwargs: A dictionary identifying what to include
.. versionadded:: 0.5
"""
# Check for an operator and transform to mongo-style if there is
operators = ["slice"]
cleaned_fields = []
for key, value in kwargs.items():
parts = key.split('__')
op = None
if parts[0] in operators:
op = parts.pop(0)
value = {'$' + op: value}
key = '.'.join(parts)
cleaned_fields.append((key, value))
fields = sorted(cleaned_fields, key=operator.itemgetter(1))
for value, group in itertools.groupby(fields, lambda x: x[1]):
fields = [field for field, value in group]
fields = self._fields_to_dbfields(fields)
self._loaded_fields += QueryFieldList(fields, value=value)
return self
def all_fields(self):
@@ -917,6 +985,10 @@ class QuerySet(object):
if key[0] in ('-', '+'):
key = key[1:]
key = key.replace('__', '.')
try:
key = QuerySet._translate_field_name(self._document, key)
except:
pass
key_list.append((key, direction))
self._ordering = key_list
@@ -1007,10 +1079,17 @@ class QuerySet(object):
if _doc_cls:
# Switch field names to proper names [set in Field(name='foo')]
fields = QuerySet._lookup_field(_doc_cls, parts)
parts = [field.db_field for field in fields]
parts = []
for field in fields:
if isinstance(field, str):
parts.append(field)
else:
parts.append(field.db_field)
# Convert value to proper value
field = fields[-1]
if op in (None, 'set', 'push', 'pull', 'addToSet'):
value = field.prepare_query_value(op, value)
elif op in ('pushAll', 'pullAll'):
@@ -1029,22 +1108,27 @@ class QuerySet(object):
return mongo_update
def update(self, safe_update=True, upsert=False, **update):
"""Perform an atomic update on the fields matched by the query. When
def update(self, safe_update=True, upsert=False, write_options=None, **update):
"""Perform an atomic update on the fields matched by the query. When
``safe_update`` is used, the number of affected documents is returned.
:param safe: check if the operation succeeded before returning
:param update: Django-style update keyword arguments
:param safe_update: check if the operation succeeded before returning
:param upsert: Any existing document with that "_id" is overwritten.
:param write_options: extra keyword arguments for :meth:`~pymongo.collection.Collection.update`
.. versionadded:: 0.2
"""
if pymongo.version < '1.1.1':
raise OperationError('update() method requires PyMongo 1.1.1+')
if not write_options:
write_options = {}
update = QuerySet._transform_update(self._document, **update)
try:
ret = self._collection.update(self._query, update, multi=True,
upsert=upsert, safe=safe_update)
upsert=upsert, safe=safe_update,
**write_options)
if ret is not None and 'n' in ret:
return ret['n']
except pymongo.errors.OperationFailure, err:
@@ -1053,22 +1137,27 @@ class QuerySet(object):
raise OperationError(message)
raise OperationError(u'Update failed (%s)' % unicode(err))
def update_one(self, safe_update=True, upsert=False, **update):
"""Perform an atomic update on first field matched by the query. When
def update_one(self, safe_update=True, upsert=False, write_options=None, **update):
"""Perform an atomic update on first field matched by the query. When
``safe_update`` is used, the number of affected documents is returned.
:param safe: check if the operation succeeded before returning
:param safe_update: check if the operation succeeded before returning
:param upsert: Any existing document with that "_id" is overwritten.
:param write_options: extra keyword arguments for :meth:`~pymongo.collection.Collection.update`
:param update: Django-style update keyword arguments
.. versionadded:: 0.2
"""
if not write_options:
write_options = {}
update = QuerySet._transform_update(self._document, **update)
try:
# Explicitly provide 'multi=False' to newer versions of PyMongo
# as the default may change to 'True'
if pymongo.version >= '1.1.1':
ret = self._collection.update(self._query, update, multi=False,
upsert=upsert, safe=safe_update)
upsert=upsert, safe=safe_update,
**write_options)
else:
# Older versions of PyMongo don't support 'multi'
ret = self._collection.update(self._query, update,
@@ -1082,8 +1171,8 @@ class QuerySet(object):
return self
def _sub_js_fields(self, code):
"""When fields are specified with [~fieldname] syntax, where
*fieldname* is the Python name of a field, *fieldname* will be
"""When fields are specified with [~fieldname] syntax, where
*fieldname* is the Python name of a field, *fieldname* will be
substituted for the MongoDB name of the field (specified using the
:attr:`name` keyword argument in a field's constructor).
"""
@@ -1106,9 +1195,9 @@ class QuerySet(object):
options specified as keyword arguments.
As fields in MongoEngine may use different names in the database (set
using the :attr:`db_field` keyword argument to a :class:`Field`
using the :attr:`db_field` keyword argument to a :class:`Field`
constructor), a mechanism exists for replacing MongoEngine field names
with the database field names in Javascript code. When accessing a
with the database field names in Javascript code. When accessing a
field, use square-bracket notation, and prefix the MongoEngine field
name with a tilde (~).
@@ -1241,8 +1330,11 @@ class QuerySet(object):
class QuerySetManager(object):
def __init__(self, manager_func=None):
self._manager_func = manager_func
get_queryset = None
def __init__(self, queryset_func=None):
if queryset_func:
self.get_queryset = queryset_func
self._collections = {}
def __get__(self, instance, owner):
@@ -1259,7 +1351,7 @@ class QuerySetManager(object):
# Create collection as a capped collection if specified
if owner._meta['max_size'] or owner._meta['max_documents']:
# Get max document limit and max byte size from meta
max_size = owner._meta['max_size'] or 10000000 # 10MB default
max_size = owner._meta['max_size'] or 10000000 # 10MB default
max_documents = owner._meta['max_documents']
if collection in db.collection_names():
@@ -1286,11 +1378,11 @@ class QuerySetManager(object):
# owner is the document that contains the QuerySetManager
queryset_class = owner._meta['queryset_class'] or QuerySet
queryset = queryset_class(owner, self._collections[(db, collection)])
if self._manager_func:
if self._manager_func.func_code.co_argcount == 1:
queryset = self._manager_func(queryset)
if self.get_queryset:
if self.get_queryset.func_code.co_argcount == 1:
queryset = self.get_queryset(queryset)
else:
queryset = self._manager_func(owner, queryset)
queryset = self.get_queryset(owner, queryset)
return queryset