Merge branch 'v0.4' of git://github.com/hmarr/mongoengine into v0.4

Conflicts:
	docs/changelog.rst
	mongoengine/base.py
	mongoengine/queryset.py
This commit is contained in:
Steve Challis
2010-10-17 23:48:20 +01:00
13 changed files with 703 additions and 227 deletions

View File

@@ -12,7 +12,7 @@ __all__ = (document.__all__ + fields.__all__ + connection.__all__ +
__author__ = 'Harry Marr'
VERSION = (0, 3, 0)
VERSION = (0, 4, 0)
def get_version():
version = '%s.%s' % (VERSION[0], VERSION[1])

View File

@@ -3,6 +3,7 @@ from queryset import DoesNotExist, MultipleObjectsReturned
import sys
import pymongo
import pymongo.objectid
_document_registry = {}
@@ -203,6 +204,9 @@ class DocumentMetaclass(type):
exc = subclass_exception('MultipleObjectsReturned', base_excs, module)
new_class.add_to_class('MultipleObjectsReturned', exc)
global _document_registry
_document_registry[name] = new_class
return new_class
def add_to_class(self, name, value):
@@ -215,8 +219,6 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
"""
def __new__(cls, name, bases, attrs):
global _document_registry
super_new = super(TopLevelDocumentMetaclass, cls).__new__
# Classes defined in this package are abstract and should not have
# their own metadata with DB collection, etc.
@@ -321,8 +323,6 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
new_class._fields['id'] = ObjectIdField(db_field='_id')
new_class.id = new_class._fields['id']
_document_registry[name] = new_class
return new_class

View File

@@ -4,11 +4,12 @@ import multiprocessing
__all__ = ['ConnectionError', 'connect']
_connection_settings = {
_connection_defaults = {
'host': 'localhost',
'port': 27017,
}
_connection = {}
_connection_settings = _connection_defaults.copy()
_db_name = None
_db_username = None
@@ -20,25 +21,25 @@ class ConnectionError(Exception):
pass
def _get_connection():
def _get_connection(reconnect=False):
global _connection
identity = get_identity()
# Connect to the database if not already connected
if _connection.get(identity) is None:
if _connection.get(identity) is None or reconnect:
try:
_connection[identity] = Connection(**_connection_settings)
except:
raise ConnectionError('Cannot connect to the database')
return _connection[identity]
def _get_db():
def _get_db(reconnect=False):
global _db, _connection
identity = get_identity()
# Connect if not already connected
if _connection.get(identity) is None:
_connection[identity] = _get_connection()
if _connection.get(identity) is None or reconnect:
_connection[identity] = _get_connection(reconnect=reconnect)
if _db.get(identity) is None:
if _db.get(identity) is None or reconnect:
# _db_name will be None if the user hasn't called connect()
if _db_name is None:
raise ConnectionError('Not connected to the database')
@@ -61,9 +62,10 @@ def connect(db, username=None, password=None, **kwargs):
the default port on localhost. If authentication is needed, provide
username and password arguments as well.
"""
global _connection_settings, _db_name, _db_username, _db_password
_connection_settings.update(kwargs)
global _connection_settings, _db_name, _db_username, _db_password, _db
_connection_settings = dict(_connection_defaults, **kwargs)
_db_name = db
_db_username = username
_db_password = password
return _get_db()
return _get_db(reconnect=True)

View File

@@ -5,6 +5,9 @@ from operator import itemgetter
import re
import pymongo
import pymongo.dbref
import pymongo.son
import pymongo.binary
import datetime
import decimal
import gridfs
@@ -106,8 +109,11 @@ class URLField(StringField):
message = 'This URL appears to be a broken link: %s' % e
raise ValidationError(message)
class EmailField(StringField):
"""A field that validates input as an E-Mail-Address.
.. versionadded:: 0.4
"""
EMAIL_REGEX = re.compile(
@@ -115,11 +121,12 @@ class EmailField(StringField):
r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]|\\[\001-011\013\014\016-\177])*"' # quoted-string
r')@(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,6}\.?$', re.IGNORECASE # domain
)
def validate(self, value):
if not EmailField.EMAIL_REGEX.match(value):
raise ValidationError('Invalid Mail-address: %s' % value)
class IntField(BaseField):
"""An integer field.
"""
@@ -143,6 +150,7 @@ class IntField(BaseField):
if self.max_value is not None and value > self.max_value:
raise ValidationError('Integer value is too large')
class FloatField(BaseField):
"""An floating point number field.
"""
@@ -179,7 +187,7 @@ class DecimalField(BaseField):
if not isinstance(value, basestring):
value = unicode(value)
return decimal.Decimal(value)
def to_mongo(self, value):
return unicode(value)
@@ -198,6 +206,7 @@ class DecimalField(BaseField):
if self.max_value is not None and value > self.max_value:
raise ValidationError('Decimal value is too large')
class BooleanField(BaseField):
"""A boolean field type.
@@ -210,6 +219,7 @@ class BooleanField(BaseField):
def validate(self, value):
assert isinstance(value, bool)
class DateTimeField(BaseField):
"""A datetime field.
"""
@@ -217,38 +227,49 @@ class DateTimeField(BaseField):
def validate(self, value):
assert isinstance(value, datetime.datetime)
class EmbeddedDocumentField(BaseField):
"""An embedded document field. Only valid values are subclasses of
:class:`~mongoengine.EmbeddedDocument`.
"""
def __init__(self, document, **kwargs):
if not issubclass(document, EmbeddedDocument):
raise ValidationError('Invalid embedded document class provided '
'to an EmbeddedDocumentField')
self.document = document
def __init__(self, document_type, **kwargs):
if not isinstance(document_type, basestring):
if not issubclass(document_type, EmbeddedDocument):
raise ValidationError('Invalid embedded document class '
'provided to an EmbeddedDocumentField')
self.document_type_obj = document_type
super(EmbeddedDocumentField, self).__init__(**kwargs)
@property
def document_type(self):
if isinstance(self.document_type_obj, basestring):
if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT:
self.document_type_obj = self.owner_document
else:
self.document_type_obj = get_document(self.document_type_obj)
return self.document_type_obj
def to_python(self, value):
if not isinstance(value, self.document):
return self.document._from_son(value)
if not isinstance(value, self.document_type):
return self.document_type._from_son(value)
return value
def to_mongo(self, value):
return self.document.to_mongo(value)
return self.document_type.to_mongo(value)
def validate(self, value):
"""Make sure that the document instance is an instance of the
EmbeddedDocument subclass provided when the document was defined.
"""
# Using isinstance also works for subclasses of self.document
if not isinstance(value, self.document):
if not isinstance(value, self.document_type):
raise ValidationError('Invalid embedded document instance '
'provided to an EmbeddedDocumentField')
self.document.validate(value)
self.document_type.validate(value)
def lookup_member(self, member_name):
return self.document._fields.get(member_name)
return self.document_type._fields.get(member_name)
def prepare_query_value(self, op, value):
return self.to_mongo(value)
@@ -322,20 +343,32 @@ class ListField(BaseField):
try:
[self.field.validate(item) for item in value]
except Exception, err:
raise ValidationError('Invalid ListField item (%s)' % str(err))
raise ValidationError('Invalid ListField item (%s)' % str(item))
def prepare_query_value(self, op, value):
if op in ('set', 'unset'):
return [self.field.to_mongo(v) for v in value]
return self.field.to_mongo(value)
return [self.field.prepare_query_value(op, v) for v in value]
return self.field.prepare_query_value(op, value)
def lookup_member(self, member_name):
return self.field.lookup_member(member_name)
def _set_owner_document(self, owner_document):
    """Record the owning document on this field and on the wrapped field.

    :param owner_document: the document class that declares this field
    """
    # The wrapped item field needs the owner too, so it is assigned on
    # ``self.field`` as well as on this field.
    self.field.owner_document = owner_document
    self._owner_document = owner_document

def _get_owner_document(self):
    """Return the owning document previously stored by the setter.

    The earlier version of this getter took an extra ``owner_document``
    argument and assigned it instead of returning anything, so reading the
    ``owner_document`` property raised ``TypeError``.
    """
    return self._owner_document

# Expose the getter/setter pair as the ``owner_document`` attribute.
owner_document = property(_get_owner_document, _set_owner_document)
class SortedListField(ListField):
"""A ListField that sorts the contents of its list before writing to
the database in order to ensure that a sorted list is always
retrieved.
.. versionadded:: 0.4
"""
_ordering = None
@@ -351,6 +384,7 @@ class SortedListField(ListField):
key=itemgetter(self._ordering))
return sorted([self.field.to_mongo(item) for item in value])
class DictField(BaseField):
"""A dictionary field that wraps a standard Python dictionary. This is
similar to an embedded document, but the structure is not defined.
@@ -389,7 +423,6 @@ class ReferenceField(BaseField):
raise ValidationError('Argument to ReferenceField constructor '
'must be a document class or a string')
self.document_type_obj = document_type
self.document_obj = None
super(ReferenceField, self).__init__(**kwargs)
@property
@@ -444,6 +477,7 @@ class ReferenceField(BaseField):
def lookup_member(self, member_name):
return self.document_type._fields.get(member_name)
class GenericReferenceField(BaseField):
"""A reference to *any* :class:`~mongoengine.document.Document` subclass
that will be automatically dereferenced on access (lazily).
@@ -488,7 +522,7 @@ class GenericReferenceField(BaseField):
return {'_cls': document.__class__.__name__, '_ref': ref}
def prepare_query_value(self, op, value):
return self.to_mongo(value)['_ref']
return self.to_mongo(value)
class BinaryField(BaseField):
@@ -655,6 +689,8 @@ class FileField(BaseField):
class GeoPointField(BaseField):
"""A list storing a latitude and longitude.
.. versionadded:: 0.4
"""
_geo_index = True

View File

@@ -2,8 +2,12 @@ from connection import _get_db
import pprint
import pymongo
import pymongo.code
import pymongo.dbref
import pymongo.objectid
import re
import copy
import itertools
__all__ = ['queryset_manager', 'Q', 'InvalidQueryError',
'InvalidCollectionError']
@@ -15,6 +19,7 @@ REPR_OUTPUT_SIZE = 20
class DoesNotExist(Exception):
pass
class MultipleObjectsReturned(Exception):
pass
@@ -26,50 +31,192 @@ class InvalidQueryError(Exception):
class OperationError(Exception):
pass
class InvalidCollectionError(Exception):
pass
RE_TYPE = type(re.compile(''))
class Q(object):
class QNodeVisitor(object):
"""Base visitor class for visiting Q-object nodes in a query tree.
"""
OR = '||'
AND = '&&'
OPERATORS = {
'eq': ('((this.%(field)s instanceof Array) && '
' this.%(field)s.indexOf(%(value)s) != -1) ||'
' this.%(field)s == %(value)s'),
'ne': 'this.%(field)s != %(value)s',
'gt': 'this.%(field)s > %(value)s',
'gte': 'this.%(field)s >= %(value)s',
'lt': 'this.%(field)s < %(value)s',
'lte': 'this.%(field)s <= %(value)s',
'lte': 'this.%(field)s <= %(value)s',
'in': '%(value)s.indexOf(this.%(field)s) != -1',
'nin': '%(value)s.indexOf(this.%(field)s) == -1',
'mod': '%(field)s %% %(value)s',
'all': ('%(value)s.every(function(a){'
'return this.%(field)s.indexOf(a) != -1 })'),
'size': 'this.%(field)s.length == %(value)s',
'exists': 'this.%(field)s != null',
'regex_eq': '%(value)s.test(this.%(field)s)',
'regex_ne': '!%(value)s.test(this.%(field)s)',
}
def visit_combination(self, combination):
"""Called by QCombination objects.
"""
return combination
def __init__(self, **query):
self.query = [query]
def visit_query(self, query):
"""Called by (New)Q objects.
"""
return query
def _combine(self, other, op):
obj = Q()
if not other.query[0]:
class SimplificationVisitor(QNodeVisitor):
    """Simplifies query trees by combining unnecessary 'and' connection nodes
    into a single Q-object.
    """

    def visit_combination(self, combination):
        """Collapse an AND combination of plain Q objects into one Q.

        Any other combination (an OR, or an AND with at least one child that
        is not a Q) is returned unchanged.
        """
        if combination.operation != combination.AND:
            return combination

        children = combination.children
        # The simplification only applies to 'simple' queries
        if not all(isinstance(child, Q) for child in children):
            return combination

        merged = self._query_conjunction([child.query for child in children])
        return Q(**merged)

    def _query_conjunction(self, queries):
        """Merge several query dicts into one, i.e. AND them together.

        :param queries: iterable of keyword-style query dicts
        :raises InvalidQueryError: if the same key occurs in more than one
            of the given dicts
        """
        seen_ops = set()
        merged = {}
        for query in queries:
            ops = set(query.keys())
            # A key may only be supplied by one of the merged queries
            clashes = ops.intersection(seen_ops)
            if clashes:
                msg = 'Duplicate query contitions: '
                raise InvalidQueryError(msg + ', '.join(clashes))

            seen_ops.update(ops)
            # Deep-copy so the merged dict shares nothing with its sources
            merged.update(copy.deepcopy(query))
        return merged
class QueryTreeTransformerVisitor(QNodeVisitor):
    """Transforms the query tree in to a form that may be used with MongoDB.
    """

    def visit_combination(self, combination):
        """Rewrite a combination node.

        AND nodes are expanded so that any OR children end up in a single
        top-level OR; OR nodes have directly nested ORs merged into them.
        """
        if combination.operation == combination.AND:
            return self._hoist_ors(combination)

        if combination.operation == combination.OR:
            combination.children = self._flatten_ors(combination)

        return combination

    def _hoist_ors(self, combination):
        """Return a node equivalent to the AND ``combination`` in which the
        OR parts have been moved up into one 'master' OR.
        """
        # MongoDB doesn't allow too many $or operations in a query, so the
        # children are split into the parts that must always hold
        # (``required``) and the groups of alternatives coming from OR
        # children (``alternatives``).
        required = []
        alternatives = []
        for child in combination.children:
            if isinstance(child, QCombination):
                if child.operation == child.OR:
                    # Any one member of an OR group may satisfy the query
                    alternatives.append(child.children)
                elif child.operation == child.AND:
                    required.append(child)
            elif isinstance(child, Q):
                required.append(child)

        conjoin = lambda left, right: left & right

        # One clause per way of picking a single member from every OR
        # group: the required parts ANDed with that pick.
        clauses = []
        for choice in itertools.product(*alternatives):
            clause = reduce(conjoin, required, Q())
            clauses.append(reduce(conjoin, choice, clause))

        # Each clause is sufficient on its own, so OR them all together.
        return reduce(lambda left, right: left | right, clauses, Q())

    def _flatten_ors(self, combination):
        """Return the children of the OR ``combination`` with the children
        of any directly nested OR spliced in place of that nested node.
        """
        # MongoDB doesn't support nested $or operations
        flattened = []
        for child in combination.children:
            nested_or = (isinstance(child, QCombination) and
                         child.operation == combination.OR)
            if nested_or:
                flattened.extend(child.children)
            else:
                flattened.append(child)
        return flattened
class QueryCompilerVisitor(QNodeVisitor):
    """Compiles the nodes in a query tree to a PyMongo-compatible query
    dictionary.
    """

    def __init__(self, document):
        """
        :param document: the document class handed to
            ``QuerySet._transform_query`` when compiling each Q object
        """
        self.document = document

    def visit_combination(self, combination):
        """Compile a combination node.

        An OR becomes ``{'$or': children}``; an AND has its children merged
        into a single dict. Any other operation is returned unchanged.
        """
        if combination.operation == combination.OR:
            return {'$or': combination.children}
        elif combination.operation == combination.AND:
            return self._mongo_query_conjunction(combination.children)
        return combination

    def visit_query(self, query):
        """Compile a single Q object via ``QuerySet._transform_query``."""
        return QuerySet._transform_query(self.document, **query.query)

    def _mongo_query_conjunction(self, queries):
        """Merges Mongo query dicts - effectively &ing them together.

        :param queries: iterable of Mongo-style query dicts
        :raises InvalidQueryError: if a field appears twice with values that
            are not both operation dicts, or with the same operation
        """
        combined_query = {}
        for query in queries:
            for field, ops in query.items():
                if field not in combined_query:
                    # Later merges update this entry in place. Store a
                    # shallow copy of operation dicts so the update never
                    # writes into the dict owned by the input query.
                    if isinstance(ops, dict):
                        ops = copy.copy(ops)
                    combined_query[field] = ops
                    continue

                # The field is already present in the query the only way
                # we can merge is if both the existing value and the new
                # value are operation dicts, reject anything else
                if (not isinstance(combined_query[field], dict) or
                        not isinstance(ops, dict)):
                    message = 'Conflicting values for ' + field
                    raise InvalidQueryError(message)

                current_ops = set(combined_query[field].keys())
                new_ops = set(ops.keys())
                # Make sure that the same operation isn't applied more than
                # once to a single field
                intersection = current_ops.intersection(new_ops)
                if intersection:
                    msg = 'Duplicate query contitions: '
                    raise InvalidQueryError(msg + ', '.join(intersection))

                # Right! We've got two non-overlapping dicts of operations!
                combined_query[field].update(copy.deepcopy(ops))
        return combined_query
class QNode(object):
"""Base class for nodes in query trees.
"""
AND = 0
OR = 1
def to_query(self, document):
query = self.accept(SimplificationVisitor())
query = query.accept(QueryTreeTransformerVisitor())
query = query.accept(QueryCompilerVisitor(document))
return query
def accept(self, visitor):
raise NotImplementedError
def _combine(self, other, operation):
"""Combine this node with another node into a QCombination object.
"""
if other.empty:
return self
if self.query[0]:
obj.query = (['('] + copy.deepcopy(self.query) + [op] +
copy.deepcopy(other.query) + [')'])
else:
obj.query = copy.deepcopy(other.query)
return obj
if self.empty:
return other
return QCombination(operation, [self, other])
@property
def empty(self):
return False
def __or__(self, other):
return self._combine(other, self.OR)
@@ -77,78 +224,49 @@ class Q(object):
def __and__(self, other):
return self._combine(other, self.AND)
def as_js(self, document):
js = []
js_scope = {}
for i, item in enumerate(self.query):
if isinstance(item, dict):
item_query = QuerySet._transform_query(document, **item)
# item_query will values will either be a value or a dict
js.append(self._item_query_as_js(item_query, js_scope, i))
class QCombination(QNode):
"""Represents the combination of several conditions by a given logical
operator.
"""
def __init__(self, operation, children):
self.operation = operation
self.children = []
for node in children:
# If the child is a combination of the same type, we can merge its
# children directly into this combinations children
if isinstance(node, QCombination) and node.operation == operation:
self.children += node.children
else:
js.append(item)
return pymongo.code.Code(' '.join(js), js_scope)
self.children.append(node)
def _item_query_as_js(self, item_query, js_scope, item_num):
# item_query will be in one of the following forms
# {'age': 25, 'name': 'Test'}
# {'age': {'$lt': 25}, 'name': {'$in': ['Test', 'Example']}
# {'age': {'$lt': 25, '$gt': 18}}
js = []
for i, (key, value) in enumerate(item_query.items()):
op = 'eq'
# Construct a variable name for the value in the JS
value_name = 'i%sf%s' % (item_num, i)
if isinstance(value, dict):
# Multiple operators for this field
for j, (op, value) in enumerate(value.items()):
# Create a custom variable name for this operator
op_value_name = '%so%s' % (value_name, j)
# Construct the JS that uses this op
value, operation_js = self._build_op_js(op, key, value,
op_value_name)
# Update the js scope with the value for this op
js_scope[op_value_name] = value
js.append(operation_js)
else:
# Construct the JS for this field
value, field_js = self._build_op_js(op, key, value, value_name)
js_scope[value_name] = value
js.append(field_js)
return ' && '.join(js)
def accept(self, visitor):
for i in range(len(self.children)):
self.children[i] = self.children[i].accept(visitor)
def _build_op_js(self, op, key, value, value_name):
"""Substitute the values in to the correct chunk of Javascript.
"""
if isinstance(value, RE_TYPE):
# Regexes are handled specially
if op.strip('$') == 'ne':
op_js = Q.OPERATORS['regex_ne']
else:
op_js = Q.OPERATORS['regex_eq']
else:
op_js = Q.OPERATORS[op.strip('$')]
return visitor.visit_combination(self)
# Comparing two ObjectIds in Javascript doesn't work..
if isinstance(value, pymongo.objectid.ObjectId):
value = unicode(value)
@property
def empty(self):
return not bool(self.children)
# Handle DBRef
if isinstance(value, pymongo.dbref.DBRef):
op_js = '(this.%(field)s.$id == "%(id)s" &&'\
' this.%(field)s.$ref == "%(ref)s")' % {
'field': key,
'id': unicode(value.id),
'ref': unicode(value.collection)
}
value = None
# Perform the substitution
operation_js = op_js % {
'field': key,
'value': value_name
}
return value, operation_js
class Q(QNode):
    """Leaf node of a query tree: a plain set of keyword query conditions
    that can be combined with other nodes to form more complex queries.
    """

    def __init__(self, **query):
        # The keyword conditions are stored exactly as given
        self.query = query

    @property
    def empty(self):
        """True when no conditions were supplied."""
        return not self.query

    def accept(self, visitor):
        """Dispatch to the visitor's ``visit_query`` hook and return its
        result.
        """
        return visitor.visit_query(self)
class QuerySet(object):
"""A set of results returned from a query. Wraps a MongoDB cursor,
@@ -159,19 +277,30 @@ class QuerySet(object):
self._document = document
self._collection_obj = collection
self._accessed_collection = False
self._query = {}
self._mongo_query = None
self._query_obj = Q()
self._initial_query = {}
self._where_clause = None
self._loaded_fields = []
self._ordering = []
self._snapshot = False
self._timeout = True
# If inheritance is allowed, only return instances and instances of
# subclasses of the class being used
if document._meta.get('allow_inheritance'):
self._query = {'_types': self._document._class_name}
self._initial_query = {'_types': self._document._class_name}
self._cursor_obj = None
self._limit = None
self._skip = None
@property
def _query(self):
if self._mongo_query is None:
self._mongo_query = self._query_obj.to_query(self._document)
self._mongo_query.update(self._initial_query)
return self._mongo_query
def ensure_index(self, key_or_list, drop_dups=False, background=False,
**kwargs):
"""Ensure that the given indexes are in place.
@@ -230,10 +359,14 @@ class QuerySet(object):
objects, only the last one will be used
:param query: Django-style query keyword arguments
"""
#if q_obj:
#self._where_clause = q_obj.as_js(self._document)
query = Q(**query)
if q_obj:
self._where_clause = q_obj.as_js(self._document)
query = QuerySet._transform_query(_doc_cls=self._document, **query)
self._query.update(query)
query &= q_obj
self._query_obj &= query
self._mongo_query = None
self._cursor_obj = None
return self
def filter(self, *q_objs, **query):
@@ -285,9 +418,12 @@ class QuerySet(object):
@property
def _cursor(self):
if self._cursor_obj is None:
cursor_args = {}
cursor_args = {
'snapshot': self._snapshot,
'timeout': self._timeout,
}
if self._loaded_fields:
cursor_args = {'fields': self._loaded_fields}
cursor_args['fields'] = self._loaded_fields
self._cursor_obj = self._collection.find(self._query,
**cursor_args)
# Apply where clauses to cursor
@@ -335,7 +471,7 @@ class QuerySet(object):
"""Transform a query from Django-style format to Mongo format.
"""
operators = ['ne', 'gt', 'gte', 'lt', 'lte', 'in', 'nin', 'mod',
'all', 'size', 'exists']
'all', 'size', 'exists', 'not']
geo_operators = ['within_distance', 'within_box', 'near']
match_operators = ['contains', 'icontains', 'startswith',
'istartswith', 'endswith', 'iendswith',
@@ -343,6 +479,10 @@ class QuerySet(object):
mongo_query = {}
for key, value in query.items():
if key == "__raw__":
mongo_query.update(value)
continue
parts = key.split('__')
indices = [(i, p) for i, p in enumerate(parts) if p.isdigit()]
parts = [part for part in parts if not part.isdigit()]
@@ -351,6 +491,11 @@ class QuerySet(object):
if parts[-1] in operators + match_operators + geo_operators:
op = parts.pop()
negate = False
if parts[-1] == 'not':
parts.pop()
negate = True
if _doc_cls:
# Switch field names to proper names [set in Field(name='foo')]
fields = QuerySet._lookup_field(_doc_cls, parts)
@@ -358,7 +503,7 @@ class QuerySet(object):
# Convert value to proper value
field = fields[-1]
singular_ops = [None, 'ne', 'gt', 'gte', 'lt', 'lte']
singular_ops = [None, 'ne', 'gt', 'gte', 'lt', 'lte', 'not']
singular_ops += match_operators
if op in singular_ops:
value = field.prepare_query_value(op, value)
@@ -366,9 +511,6 @@ class QuerySet(object):
# 'in', 'nin' and 'all' require a list of values
value = [field.prepare_query_value(op, v) for v in value]
if field.__class__.__name__ == 'GenericReferenceField':
parts.append('_ref')
# if op and op not in match_operators:
if op:
if op in geo_operators:
@@ -383,7 +525,10 @@ class QuerySet(object):
"been implemented" % op)
elif op not in match_operators:
value = {'$' + op: value}
if negate:
value = {'$not': value}
for i, part in indices:
parts.insert(i, part)
key = '.'.join(parts)
@@ -445,6 +590,7 @@ class QuerySet(object):
def create(self, **kwargs):
"""Create new object. Returns the saved object instance.
.. versionadded:: 0.4
"""
doc = self._document(**kwargs)
@@ -641,6 +787,7 @@ class QuerySet(object):
# Integer index provided
elif isinstance(key, int):
return self._document._from_son(self._cursor[key])
raise AttributeError
def distinct(self, field):
"""Return a list of distinct values for a given field.
@@ -649,7 +796,7 @@ class QuerySet(object):
.. versionadded:: 0.4
"""
return self._collection.distinct(field)
return self._cursor.distinct(field)
def only(self, *fields):
"""Load only a subset of this document's fields. ::
@@ -709,6 +856,20 @@ class QuerySet(object):
plan = pprint.pformat(plan)
return plan
def snapshot(self, enabled):
    """Enable or disable snapshot mode when querying.

    :param enabled: whether or not snapshot mode is enabled
    :returns: this queryset, so that further calls can be chained
    """
    self._snapshot = enabled
    # Return the queryset (as ``__call__`` does) instead of None so the
    # call can be used inside a chain.
    return self
def timeout(self, enabled):
    """Enable or disable the default mongod timeout when querying.

    :param enabled: whether or not the timeout is used
    :returns: this queryset, so that further calls can be chained
    """
    self._timeout = enabled
    # Return the queryset (as ``__call__`` does) instead of None so the
    # call can be used inside a chain.
    return self
def delete(self, safe=False):
"""Delete the documents matched by the query.
@@ -721,7 +882,7 @@ class QuerySet(object):
"""Transform an update spec from Django-style format to Mongo format.
"""
operators = ['set', 'unset', 'inc', 'dec', 'pop', 'push', 'push_all',
'pull', 'pull_all']
'pull', 'pull_all', 'add_to_set']
mongo_update = {}
for key, value in update.items():
@@ -739,6 +900,8 @@ class QuerySet(object):
op = 'inc'
if value > 0:
value = -value
elif op == 'add_to_set':
op = op.replace('_to_set', 'ToSet')
if _doc_cls:
# Switch field names to proper names [set in Field(name='foo')]
@@ -747,7 +910,8 @@ class QuerySet(object):
# Convert value to proper value
field = fields[-1]
if op in (None, 'set', 'unset', 'pop', 'push', 'pull'):
if op in (None, 'set', 'unset', 'pop', 'push', 'pull',
'addToSet'):
value = field.prepare_query_value(op, value)
elif op in ('pushAll', 'pullAll'):
value = [field.prepare_query_value(op, v) for v in value]
@@ -913,20 +1077,27 @@ class QuerySet(object):
"""
return self.exec_js(average_func, field)
def item_frequencies(self, list_field, normalize=False):
"""Returns a dictionary of all items present in a list field across
def item_frequencies(self, field, normalize=False):
"""Returns a dictionary of all items present in a field across
the whole queried set of documents, and their corresponding frequency.
This is useful for generating tag clouds, or searching documents.
:param list_field: the list field to use
If the field is a :class:`~mongoengine.ListField`, the items within
each list will be counted individually.
:param field: the field to use
:param normalize: normalize the results so they add to 1.0
"""
freq_func = """
function(listField) {
function(field) {
if (options.normalize) {
var total = 0.0;
db[collection].find(query).forEach(function(doc) {
total += doc[listField].length;
if (doc[field].constructor == Array) {
total += doc[field].length;
} else {
total++;
}
});
}
@@ -936,14 +1107,19 @@ class QuerySet(object):
inc /= total;
}
db[collection].find(query).forEach(function(doc) {
doc[listField].forEach(function(item) {
if (doc[field].constructor == Array) {
doc[field].forEach(function(item) {
frequencies[item] = inc + (frequencies[item] || 0);
});
} else {
var item = doc[field];
frequencies[item] = inc + (frequencies[item] || 0);
});
}
});
return frequencies;
}
"""
return self.exec_js(freq_func, list_field, normalize=normalize)
return self.exec_js(freq_func, field, normalize=normalize)
def __repr__(self):
limit = REPR_OUTPUT_SIZE + 1
@@ -959,7 +1135,7 @@ class QuerySetManager(object):
def __init__(self, manager_func=None):
self._manager_func = manager_func
self._collection = None
self._collections = {}
def __get__(self, instance, owner):
"""Descriptor for instantiating a new QuerySet object when
@@ -969,10 +1145,9 @@ class QuerySetManager(object):
# Document class being used rather than a document object
return self
if self._collection is None:
db = _get_db()
collection = owner._meta['collection']
db = _get_db()
collection = owner._meta['collection']
if (db, collection) not in self._collections:
# Create collection as a capped collection if specified
if owner._meta['max_size'] or owner._meta['max_documents']:
# Get max document limit and max byte size from meta
@@ -980,10 +1155,10 @@ class QuerySetManager(object):
max_documents = owner._meta['max_documents']
if collection in db.collection_names():
self._collection = db[collection]
self._collections[(db, collection)] = db[collection]
# The collection already exists, check if its capped
# options match the specified capped options
options = self._collection.options()
options = self._collections[(db, collection)].options()
if options.get('max') != max_documents or \
options.get('size') != max_size:
msg = ('Cannot create collection "%s" as a capped '
@@ -994,13 +1169,15 @@ class QuerySetManager(object):
opts = {'capped': True, 'size': max_size}
if max_documents:
opts['max'] = max_documents
self._collection = db.create_collection(collection, **opts)
self._collections[(db, collection)] = db.create_collection(
collection, **opts
)
else:
self._collection = db[collection]
self._collections[(db, collection)] = db[collection]
# owner is the document that contains the QuerySetManager
queryset_class = owner._meta['queryset_class'] or QuerySet
queryset = queryset_class(owner, self._collection)
queryset = queryset_class(owner, self._collections[(db, collection)])
if self._manager_func:
if self._manager_func.func_code.co_argcount == 1:
queryset = self._manager_func(queryset)