diff --git a/docs/apireference.rst b/docs/apireference.rst index 4fff317a..34d4536d 100644 --- a/docs/apireference.rst +++ b/docs/apireference.rst @@ -66,3 +66,5 @@ Fields .. autoclass:: mongoengine.GenericReferenceField .. autoclass:: mongoengine.FileField + +.. autoclass:: mongoengine.GeoPointField diff --git a/docs/changelog.rst b/docs/changelog.rst index 29f49cf1..d7c6fe85 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -6,16 +6,26 @@ Changes in v0.4 =============== - Added ``GridFSStorage`` Django storage backend - Added ``FileField`` for GridFS support +- New Q-object implementation, which is no longer based on Javascript - Added ``SortedListField`` - Added ``EmailField`` - Added ``GeoPointField`` - Added ``exact`` and ``iexact`` match operators to ``QuerySet`` - Added ``get_document_or_404`` and ``get_list_or_404`` Django shortcuts -- Fixed bug in Q-objects +- Added new query operators for Geo queries +- Added ``not`` query operator +- Added new update operators: ``pop`` and ``add_to_set`` +- Added ``__raw__`` query parameter +- Added support for custom querysets - Fixed document inheritance primary key issue +- Added support for querying by array element position - Base class can now be defined for ``DictField`` - Fixed MRO error that occured on document inheritance +- Added ``QuerySet.distinct``, ``QuerySet.create``, ``QuerySet.snapshot``, + ``QuerySet.timeout`` and ``QuerySet.all`` +- Subsequent calls to ``connect()`` now work - Introduced ``min_length`` for ``StringField`` +- Fixed multi-process connection issue - Other minor fixes Changes in v0.3 diff --git a/docs/guide/defining-documents.rst b/docs/guide/defining-documents.rst index 0caa4227..106d4ec8 100644 --- a/docs/guide/defining-documents.rst +++ b/docs/guide/defining-documents.rst @@ -47,11 +47,11 @@ are as follows: * :class:`~mongoengine.ReferenceField` * :class:`~mongoengine.GenericReferenceField` * :class:`~mongoengine.BooleanField` -* :class:`~mongoengine.GeoLocationField` * :class:`~mongoengine.FileField` * :class:`~mongoengine.EmailField` * :class:`~mongoengine.SortedListField` * :class:`~mongoengine.BinaryField` +* :class:`~mongoengine.GeoPointField` Field arguments --------------- @@ -72,6 +72,25 @@ arguments can be set on all fields: :attr:`default` (Default: None) A value to use when no value is set for this field. + The definion of default parameters follow `the general rules on Python + `__, + which means that some care should be taken when dealing with default mutable objects + (like in :class:`~mongoengine.ListField` or :class:`~mongoengine.DictField`):: + + class ExampleFirst(Document): + # Default an empty list + values = ListField(IntField(), default=list) + + class ExampleSecond(Document): + # Default a set of values + values = ListField(IntField(), default=lambda: [1,2,3]) + + class ExampleDangerous(Document): + # This can make an .append call to add values to the default (and all the following objects), + # instead to just an object + values = ListField(IntField(), default=[1,2,3]) + + :attr:`unique` (Default: False) When True, no documents in the collection will have the same value for this field. @@ -279,6 +298,10 @@ or a **-** sign. Note that direction only matters on multi-field indexes. :: meta = { 'indexes': ['title', ('title', '-rating')] } + +.. note:: + Geospatial indexes will be automatically created for all + :class:`~mongoengine.GeoPointField`\ s Ordering ======== diff --git a/docs/guide/querying.rst b/docs/guide/querying.rst index bef19bc5..58bf9f63 100644 --- a/docs/guide/querying.rst +++ b/docs/guide/querying.rst @@ -53,6 +53,16 @@ lists that contain that item will be matched:: # 'tags' list Page.objects(tags='coding') +Raw queries +----------- +It is possible to provide a raw PyMongo query as a query parameter, which will +be integrated directly into the query. This is done using the ``__raw__`` +keyword argument:: + + Page.objects(__raw__={'tags': 'coding'}) + +.. versionadded:: 0.4 + Query operators =============== Operators other than equality may also be used in queries; just attach the @@ -68,6 +78,8 @@ Available operators are as follows: * ``lte`` -- less than or equal to * ``gt`` -- greater than * ``gte`` -- greater than or equal to +* ``not`` -- negate a standard check, may be used before other operators (e.g. + ``Q(age__not__mod=5)``) * ``in`` -- value is in list (a list of values should be provided) * ``nin`` -- value is not in list (a list of values should be provided) * ``mod`` -- ``value % x == y``, where ``x`` and ``y`` are two provided values @@ -89,6 +101,27 @@ expressions: .. versionadded:: 0.3 +There are a few special operators for performing geographical queries, that +may used with :class:`~mongoengine.GeoPointField`\ s: + +* ``within_distance`` -- provide a list containing a point and a maximum + distance (e.g. [(41.342, -87.653), 5]) +* ``within_box`` -- filter documents to those within a given bounding box (e.g. + [(35.0, -125.0), (40.0, -100.0)]) +* ``near`` -- order the documents by how close they are to a given point + +.. versionadded:: 0.4 + +Querying by position +==================== +It is possible to query by position in a list by using a numerical value as a +query operator. So if you wanted to find all pages whose first tag was ``db``, +you could use the following query:: + + BlogPost.objects(tags__0='db') + +.. versionadded:: 0.4 + Limiting and skipping results ============================= Just as with traditional ORMs, you may limit the number of results returned, or @@ -181,6 +214,22 @@ custom manager methods as you like:: assert len(BlogPost.objects) == 2 assert len(BlogPost.live_posts) == 1 +Custom QuerySets +================ +Should you want to add custom methods for interacting with or filtering +documents, extending the :class:`~mongoengine.queryset.QuerySet` class may be +the way to go. To use a custom :class:`~mongoengine.queryset.QuerySet` class on +a document, set ``queryset_class`` to the custom class in a +:class:`~mongoengine.Document`\ s ``meta`` dictionary:: + + class AwesomerQuerySet(QuerySet): + pass + + class Page(Document): + meta = {'queryset_class': AwesomerQuerySet} + +.. versionadded:: 0.4 + Aggregation =========== MongoDB provides some aggregation methods out of the box, but there are not as @@ -402,8 +451,10 @@ that you may use with these methods: * ``pop`` -- remove the last item from a list * ``push`` -- append a value to a list * ``push_all`` -- append several values to a list +* ``pop`` -- remove the first or last element of a list * ``pull`` -- remove a value from a list * ``pull_all`` -- remove several values from a list +* ``add_to_set`` -- add value to a list only if its not in the list already The syntax for atomic updates is similar to the querying syntax, but the modifier comes before the field, not after it:: diff --git a/docs/index.rst b/docs/index.rst index a28b344c..ccb7fbe2 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -7,7 +7,7 @@ MongoDB. To install it, simply run .. code-block:: console - # easy_install -U mongoengine + # pip install -U mongoengine The source is available on `GitHub `_. diff --git a/mongoengine/__init__.py b/mongoengine/__init__.py index e01d31ae..6d18ffe7 100644 --- a/mongoengine/__init__.py +++ b/mongoengine/__init__.py @@ -12,7 +12,7 @@ __all__ = (document.__all__ + fields.__all__ + connection.__all__ + __author__ = 'Harry Marr' -VERSION = (0, 3, 0) +VERSION = (0, 4, 0) def get_version(): version = '%s.%s' % (VERSION[0], VERSION[1]) diff --git a/mongoengine/base.py b/mongoengine/base.py index af21db0e..31377f07 100644 --- a/mongoengine/base.py +++ b/mongoengine/base.py @@ -3,6 +3,7 @@ from queryset import DoesNotExist, MultipleObjectsReturned import sys import pymongo +import pymongo.objectid _document_registry = {} @@ -203,6 +204,9 @@ class DocumentMetaclass(type): exc = subclass_exception('MultipleObjectsReturned', base_excs, module) new_class.add_to_class('MultipleObjectsReturned', exc) + global _document_registry + _document_registry[name] = new_class + return new_class def add_to_class(self, name, value): @@ -215,8 +219,6 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): """ def __new__(cls, name, bases, attrs): - global _document_registry - super_new = super(TopLevelDocumentMetaclass, cls).__new__ # Classes defined in this package are abstract and should not have # their own metadata with DB collection, etc. @@ -321,8 +323,6 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): new_class._fields['id'] = ObjectIdField(db_field='_id') new_class.id = new_class._fields['id'] - _document_registry[name] = new_class - return new_class diff --git a/mongoengine/connection.py b/mongoengine/connection.py index 94cc6ea1..814fde13 100644 --- a/mongoengine/connection.py +++ b/mongoengine/connection.py @@ -4,11 +4,12 @@ import multiprocessing __all__ = ['ConnectionError', 'connect'] -_connection_settings = { +_connection_defaults = { 'host': 'localhost', 'port': 27017, } _connection = {} +_connection_settings = _connection_defaults.copy() _db_name = None _db_username = None @@ -20,25 +21,25 @@ class ConnectionError(Exception): pass -def _get_connection(): +def _get_connection(reconnect=False): global _connection identity = get_identity() # Connect to the database if not already connected - if _connection.get(identity) is None: + if _connection.get(identity) is None or reconnect: try: _connection[identity] = Connection(**_connection_settings) except: raise ConnectionError('Cannot connect to the database') return _connection[identity] -def _get_db(): +def _get_db(reconnect=False): global _db, _connection identity = get_identity() # Connect if not already connected - if _connection.get(identity) is None: - _connection[identity] = _get_connection() + if _connection.get(identity) is None or reconnect: + _connection[identity] = _get_connection(reconnect=reconnect) - if _db.get(identity) is None: + if _db.get(identity) is None or reconnect: # _db_name will be None if the user hasn't called connect() if _db_name is None: raise ConnectionError('Not connected to the database') @@ -61,9 +62,10 @@ def connect(db, username=None, password=None, **kwargs): the default port on localhost. If authentication is needed, provide username and password arguments as well. """ - global _connection_settings, _db_name, _db_username, _db_password - _connection_settings.update(kwargs) + global _connection_settings, _db_name, _db_username, _db_password, _db + _connection_settings = dict(_connection_defaults, **kwargs) _db_name = db _db_username = username _db_password = password - return _get_db() \ No newline at end of file + return _get_db(reconnect=True) + diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 23a88dee..30a508d1 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -5,6 +5,9 @@ from operator import itemgetter import re import pymongo +import pymongo.dbref +import pymongo.son +import pymongo.binary import datetime import decimal import gridfs @@ -106,8 +109,11 @@ class URLField(StringField): message = 'This URL appears to be a broken link: %s' % e raise ValidationError(message) + class EmailField(StringField): """A field that validates input as an E-Mail-Address. + + .. versionadded:: 0.4 """ EMAIL_REGEX = re.compile( @@ -115,11 +121,12 @@ class EmailField(StringField): r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]|\\[\001-011\013\014\016-\177])*"' # quoted-string r')@(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,6}\.?$', re.IGNORECASE # domain ) - + def validate(self, value): if not EmailField.EMAIL_REGEX.match(value): raise ValidationError('Invalid Mail-address: %s' % value) + class IntField(BaseField): """An integer field. """ @@ -143,6 +150,7 @@ class IntField(BaseField): if self.max_value is not None and value > self.max_value: raise ValidationError('Integer value is too large') + class FloatField(BaseField): """An floating point number field. """ @@ -179,7 +187,7 @@ class DecimalField(BaseField): if not isinstance(value, basestring): value = unicode(value) return decimal.Decimal(value) - + def to_mongo(self, value): return unicode(value) @@ -198,6 +206,7 @@ class DecimalField(BaseField): if self.max_value is not None and value > self.max_value: raise ValidationError('Decimal value is too large') + class BooleanField(BaseField): """A boolean field type. @@ -210,6 +219,7 @@ class BooleanField(BaseField): def validate(self, value): assert isinstance(value, bool) + class DateTimeField(BaseField): """A datetime field. """ @@ -217,38 +227,49 @@ class DateTimeField(BaseField): def validate(self, value): assert isinstance(value, datetime.datetime) + class EmbeddedDocumentField(BaseField): """An embedded document field. Only valid values are subclasses of :class:`~mongoengine.EmbeddedDocument`. """ - def __init__(self, document, **kwargs): - if not issubclass(document, EmbeddedDocument): - raise ValidationError('Invalid embedded document class provided ' - 'to an EmbeddedDocumentField') - self.document = document + def __init__(self, document_type, **kwargs): + if not isinstance(document_type, basestring): + if not issubclass(document_type, EmbeddedDocument): + raise ValidationError('Invalid embedded document class ' + 'provided to an EmbeddedDocumentField') + self.document_type_obj = document_type super(EmbeddedDocumentField, self).__init__(**kwargs) + @property + def document_type(self): + if isinstance(self.document_type_obj, basestring): + if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT: + self.document_type_obj = self.owner_document + else: + self.document_type_obj = get_document(self.document_type_obj) + return self.document_type_obj + def to_python(self, value): - if not isinstance(value, self.document): - return self.document._from_son(value) + if not isinstance(value, self.document_type): + return self.document_type._from_son(value) return value def to_mongo(self, value): - return self.document.to_mongo(value) + return self.document_type.to_mongo(value) def validate(self, value): """Make sure that the document instance is an instance of the EmbeddedDocument subclass provided when the document was defined. """ # Using isinstance also works for subclasses of self.document - if not isinstance(value, self.document): + if not isinstance(value, self.document_type): raise ValidationError('Invalid embedded document instance ' 'provided to an EmbeddedDocumentField') - self.document.validate(value) + self.document_type.validate(value) def lookup_member(self, member_name): - return self.document._fields.get(member_name) + return self.document_type._fields.get(member_name) def prepare_query_value(self, op, value): return self.to_mongo(value) @@ -322,20 +343,32 @@ class ListField(BaseField): try: [self.field.validate(item) for item in value] except Exception, err: - raise ValidationError('Invalid ListField item (%s)' % str(err)) + raise ValidationError('Invalid ListField item (%s)' % str(item)) def prepare_query_value(self, op, value): if op in ('set', 'unset'): - return [self.field.to_mongo(v) for v in value] - return self.field.to_mongo(value) + return [self.field.prepare_query_value(op, v) for v in value] + return self.field.prepare_query_value(op, value) def lookup_member(self, member_name): return self.field.lookup_member(member_name) + def _set_owner_document(self, owner_document): + self.field.owner_document = owner_document + self._owner_document = owner_document + + def _get_owner_document(self, owner_document): + self._owner_document = owner_document + + owner_document = property(_get_owner_document, _set_owner_document) + + class SortedListField(ListField): """A ListField that sorts the contents of its list before writing to the database in order to ensure that a sorted list is always retrieved. + + .. versionadded:: 0.4 """ _ordering = None @@ -351,6 +384,7 @@ class SortedListField(ListField): key=itemgetter(self._ordering)) return sorted([self.field.to_mongo(item) for item in value]) + class DictField(BaseField): """A dictionary field that wraps a standard Python dictionary. This is similar to an embedded document, but the structure is not defined. @@ -389,7 +423,6 @@ class ReferenceField(BaseField): raise ValidationError('Argument to ReferenceField constructor ' 'must be a document class or a string') self.document_type_obj = document_type - self.document_obj = None super(ReferenceField, self).__init__(**kwargs) @property @@ -444,6 +477,7 @@ class ReferenceField(BaseField): def lookup_member(self, member_name): return self.document_type._fields.get(member_name) + class GenericReferenceField(BaseField): """A reference to *any* :class:`~mongoengine.document.Document` subclass that will be automatically dereferenced on access (lazily). @@ -488,7 +522,7 @@ class GenericReferenceField(BaseField): return {'_cls': document.__class__.__name__, '_ref': ref} def prepare_query_value(self, op, value): - return self.to_mongo(value)['_ref'] + return self.to_mongo(value) class BinaryField(BaseField): @@ -655,6 +689,8 @@ class FileField(BaseField): class GeoPointField(BaseField): """A list storing a latitude and longitude. + + .. versionadded:: 0.4 """ _geo_index = True diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index 1795395c..95479bd2 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -2,8 +2,12 @@ from connection import _get_db import pprint import pymongo +import pymongo.code +import pymongo.dbref +import pymongo.objectid import re import copy +import itertools __all__ = ['queryset_manager', 'Q', 'InvalidQueryError', 'InvalidCollectionError'] @@ -15,6 +19,7 @@ REPR_OUTPUT_SIZE = 20 class DoesNotExist(Exception): pass + class MultipleObjectsReturned(Exception): pass @@ -26,50 +31,192 @@ class InvalidQueryError(Exception): class OperationError(Exception): pass + class InvalidCollectionError(Exception): pass + RE_TYPE = type(re.compile('')) -class Q(object): +class QNodeVisitor(object): + """Base visitor class for visiting Q-object nodes in a query tree. + """ - OR = '||' - AND = '&&' - OPERATORS = { - 'eq': ('((this.%(field)s instanceof Array) && ' - ' this.%(field)s.indexOf(%(value)s) != -1) ||' - ' this.%(field)s == %(value)s'), - 'ne': 'this.%(field)s != %(value)s', - 'gt': 'this.%(field)s > %(value)s', - 'gte': 'this.%(field)s >= %(value)s', - 'lt': 'this.%(field)s < %(value)s', - 'lte': 'this.%(field)s <= %(value)s', - 'lte': 'this.%(field)s <= %(value)s', - 'in': '%(value)s.indexOf(this.%(field)s) != -1', - 'nin': '%(value)s.indexOf(this.%(field)s) == -1', - 'mod': '%(field)s %% %(value)s', - 'all': ('%(value)s.every(function(a){' - 'return this.%(field)s.indexOf(a) != -1 })'), - 'size': 'this.%(field)s.length == %(value)s', - 'exists': 'this.%(field)s != null', - 'regex_eq': '%(value)s.test(this.%(field)s)', - 'regex_ne': '!%(value)s.test(this.%(field)s)', - } + def visit_combination(self, combination): + """Called by QCombination objects. + """ + return combination - def __init__(self, **query): - self.query = [query] + def visit_query(self, query): + """Called by (New)Q objects. + """ + return query - def _combine(self, other, op): - obj = Q() - if not other.query[0]: + +class SimplificationVisitor(QNodeVisitor): + """Simplifies query trees by combinging unnecessary 'and' connection nodes + into a single Q-object. + """ + + def visit_combination(self, combination): + if combination.operation == combination.AND: + # The simplification only applies to 'simple' queries + if all(isinstance(node, Q) for node in combination.children): + queries = [node.query for node in combination.children] + return Q(**self._query_conjunction(queries)) + return combination + + def _query_conjunction(self, queries): + """Merges query dicts - effectively &ing them together. + """ + query_ops = set() + combined_query = {} + for query in queries: + ops = set(query.keys()) + # Make sure that the same operation isn't applied more than once + # to a single field + intersection = ops.intersection(query_ops) + if intersection: + msg = 'Duplicate query contitions: ' + raise InvalidQueryError(msg + ', '.join(intersection)) + + query_ops.update(ops) + combined_query.update(copy.deepcopy(query)) + return combined_query + + +class QueryTreeTransformerVisitor(QNodeVisitor): + """Transforms the query tree in to a form that may be used with MongoDB. + """ + + def visit_combination(self, combination): + if combination.operation == combination.AND: + # MongoDB doesn't allow us to have too many $or operations in our + # queries, so the aim is to move the ORs up the tree to one + # 'master' $or. Firstly, we must find all the necessary parts (part + # of an AND combination or just standard Q object), and store them + # separately from the OR parts. + or_groups = [] + and_parts = [] + for node in combination.children: + if isinstance(node, QCombination): + if node.operation == node.OR: + # Any of the children in an $or component may cause + # the query to succeed + or_groups.append(node.children) + elif node.operation == node.AND: + and_parts.append(node) + elif isinstance(node, Q): + and_parts.append(node) + + # Now we combine the parts into a usable query. AND together all of + # the necessary parts. Then for each $or part, create a new query + # that ANDs the necessary part with the $or part. + clauses = [] + for or_group in itertools.product(*or_groups): + q_object = reduce(lambda a, b: a & b, and_parts, Q()) + q_object = reduce(lambda a, b: a & b, or_group, q_object) + clauses.append(q_object) + + # Finally, $or the generated clauses in to one query. Each of the + # clauses is sufficient for the query to succeed. + return reduce(lambda a, b: a | b, clauses, Q()) + + if combination.operation == combination.OR: + children = [] + # Crush any nested ORs in to this combination as MongoDB doesn't + # support nested $or operations + for node in combination.children: + if (isinstance(node, QCombination) and + node.operation == combination.OR): + children += node.children + else: + children.append(node) + combination.children = children + + return combination + + +class QueryCompilerVisitor(QNodeVisitor): + """Compiles the nodes in a query tree to a PyMongo-compatible query + dictionary. + """ + + def __init__(self, document): + self.document = document + + def visit_combination(self, combination): + if combination.operation == combination.OR: + return {'$or': combination.children} + elif combination.operation == combination.AND: + return self._mongo_query_conjunction(combination.children) + return combination + + def visit_query(self, query): + return QuerySet._transform_query(self.document, **query.query) + + def _mongo_query_conjunction(self, queries): + """Merges Mongo query dicts - effectively &ing them together. + """ + combined_query = {} + for query in queries: + for field, ops in query.items(): + if field not in combined_query: + combined_query[field] = ops + else: + # The field is already present in the query the only way + # we can merge is if both the existing value and the new + # value are operation dicts, reject anything else + if (not isinstance(combined_query[field], dict) or + not isinstance(ops, dict)): + message = 'Conflicting values for ' + field + raise InvalidQueryError(message) + + current_ops = set(combined_query[field].keys()) + new_ops = set(ops.keys()) + # Make sure that the same operation isn't applied more than + # once to a single field + intersection = current_ops.intersection(new_ops) + if intersection: + msg = 'Duplicate query contitions: ' + raise InvalidQueryError(msg + ', '.join(intersection)) + + # Right! We've got two non-overlapping dicts of operations! + combined_query[field].update(copy.deepcopy(ops)) + return combined_query + + +class QNode(object): + """Base class for nodes in query trees. + """ + + AND = 0 + OR = 1 + + def to_query(self, document): + query = self.accept(SimplificationVisitor()) + query = query.accept(QueryTreeTransformerVisitor()) + query = query.accept(QueryCompilerVisitor(document)) + return query + + def accept(self, visitor): + raise NotImplementedError + + def _combine(self, other, operation): + """Combine this node with another node into a QCombination object. + """ + if other.empty: return self - if self.query[0]: - obj.query = (['('] + copy.deepcopy(self.query) + [op] + - copy.deepcopy(other.query) + [')']) - else: - obj.query = copy.deepcopy(other.query) - return obj + + if self.empty: + return other + + return QCombination(operation, [self, other]) + + @property + def empty(self): + return False def __or__(self, other): return self._combine(other, self.OR) @@ -77,78 +224,49 @@ class Q(object): def __and__(self, other): return self._combine(other, self.AND) - def as_js(self, document): - js = [] - js_scope = {} - for i, item in enumerate(self.query): - if isinstance(item, dict): - item_query = QuerySet._transform_query(document, **item) - # item_query will values will either be a value or a dict - js.append(self._item_query_as_js(item_query, js_scope, i)) + +class QCombination(QNode): + """Represents the combination of several conditions by a given logical + operator. + """ + + def __init__(self, operation, children): + self.operation = operation + self.children = [] + for node in children: + # If the child is a combination of the same type, we can merge its + # children directly into this combinations children + if isinstance(node, QCombination) and node.operation == operation: + self.children += node.children else: - js.append(item) - return pymongo.code.Code(' '.join(js), js_scope) + self.children.append(node) - def _item_query_as_js(self, item_query, js_scope, item_num): - # item_query will be in one of the following forms - # {'age': 25, 'name': 'Test'} - # {'age': {'$lt': 25}, 'name': {'$in': ['Test', 'Example']} - # {'age': {'$lt': 25, '$gt': 18}} - js = [] - for i, (key, value) in enumerate(item_query.items()): - op = 'eq' - # Construct a variable name for the value in the JS - value_name = 'i%sf%s' % (item_num, i) - if isinstance(value, dict): - # Multiple operators for this field - for j, (op, value) in enumerate(value.items()): - # Create a custom variable name for this operator - op_value_name = '%so%s' % (value_name, j) - # Construct the JS that uses this op - value, operation_js = self._build_op_js(op, key, value, - op_value_name) - # Update the js scope with the value for this op - js_scope[op_value_name] = value - js.append(operation_js) - else: - # Construct the JS for this field - value, field_js = self._build_op_js(op, key, value, value_name) - js_scope[value_name] = value - js.append(field_js) - return ' && '.join(js) + def accept(self, visitor): + for i in range(len(self.children)): + self.children[i] = self.children[i].accept(visitor) - def _build_op_js(self, op, key, value, value_name): - """Substitute the values in to the correct chunk of Javascript. - """ - if isinstance(value, RE_TYPE): - # Regexes are handled specially - if op.strip('$') == 'ne': - op_js = Q.OPERATORS['regex_ne'] - else: - op_js = Q.OPERATORS['regex_eq'] - else: - op_js = Q.OPERATORS[op.strip('$')] + return visitor.visit_combination(self) - # Comparing two ObjectIds in Javascript doesn't work.. - if isinstance(value, pymongo.objectid.ObjectId): - value = unicode(value) + @property + def empty(self): + return not bool(self.children) - # Handle DBRef - if isinstance(value, pymongo.dbref.DBRef): - op_js = '(this.%(field)s.$id == "%(id)s" &&'\ - ' this.%(field)s.$ref == "%(ref)s")' % { - 'field': key, - 'id': unicode(value.id), - 'ref': unicode(value.collection) - } - value = None - # Perform the substitution - operation_js = op_js % { - 'field': key, - 'value': value_name - } - return value, operation_js +class Q(QNode): + """A simple query object, used in a query tree to build up more complex + query structures. + """ + + def __init__(self, **query): + self.query = query + + def accept(self, visitor): + return visitor.visit_query(self) + + @property + def empty(self): + return not bool(self.query) + class QuerySet(object): """A set of results returned from a query. Wraps a MongoDB cursor, @@ -159,19 +277,30 @@ class QuerySet(object): self._document = document self._collection_obj = collection self._accessed_collection = False - self._query = {} + self._mongo_query = None + self._query_obj = Q() + self._initial_query = {} self._where_clause = None self._loaded_fields = [] self._ordering = [] + self._snapshot = False + self._timeout = True # If inheritance is allowed, only return instances and instances of # subclasses of the class being used if document._meta.get('allow_inheritance'): - self._query = {'_types': self._document._class_name} + self._initial_query = {'_types': self._document._class_name} self._cursor_obj = None self._limit = None self._skip = None + @property + def _query(self): + if self._mongo_query is None: + self._mongo_query = self._query_obj.to_query(self._document) + self._mongo_query.update(self._initial_query) + return self._mongo_query + def ensure_index(self, key_or_list, drop_dups=False, background=False, **kwargs): """Ensure that the given indexes are in place. @@ -230,10 +359,14 @@ class QuerySet(object): objects, only the last one will be used :param query: Django-style query keyword arguments """ + #if q_obj: + #self._where_clause = q_obj.as_js(self._document) + query = Q(**query) if q_obj: - self._where_clause = q_obj.as_js(self._document) - query = QuerySet._transform_query(_doc_cls=self._document, **query) - self._query.update(query) + query &= q_obj + self._query_obj &= query + self._mongo_query = None + self._cursor_obj = None return self def filter(self, *q_objs, **query): @@ -285,9 +418,12 @@ class QuerySet(object): @property def _cursor(self): if self._cursor_obj is None: - cursor_args = {} + cursor_args = { + 'snapshot': self._snapshot, + 'timeout': self._timeout, + } if self._loaded_fields: - cursor_args = {'fields': self._loaded_fields} + cursor_args['fields'] = self._loaded_fields self._cursor_obj = self._collection.find(self._query, **cursor_args) # Apply where clauses to cursor @@ -335,7 +471,7 @@ class QuerySet(object): """Transform a query from Django-style format to Mongo format. """ operators = ['ne', 'gt', 'gte', 'lt', 'lte', 'in', 'nin', 'mod', - 'all', 'size', 'exists'] + 'all', 'size', 'exists', 'not'] geo_operators = ['within_distance', 'within_box', 'near'] match_operators = ['contains', 'icontains', 'startswith', 'istartswith', 'endswith', 'iendswith', @@ -343,6 +479,10 @@ class QuerySet(object): mongo_query = {} for key, value in query.items(): + if key == "__raw__": + mongo_query.update(value) + continue + parts = key.split('__') indices = [(i, p) for i, p in enumerate(parts) if p.isdigit()] parts = [part for part in parts if not part.isdigit()] @@ -351,6 +491,11 @@ class QuerySet(object): if parts[-1] in operators + match_operators + geo_operators: op = parts.pop() + negate = False + if parts[-1] == 'not': + parts.pop() + negate = True + if _doc_cls: # Switch field names to proper names [set in Field(name='foo')] fields = QuerySet._lookup_field(_doc_cls, parts) @@ -358,7 +503,7 @@ class QuerySet(object): # Convert value to proper value field = fields[-1] - singular_ops = [None, 'ne', 'gt', 'gte', 'lt', 'lte'] + singular_ops = [None, 'ne', 'gt', 'gte', 'lt', 'lte', 'not'] singular_ops += match_operators if op in singular_ops: value = field.prepare_query_value(op, value) @@ -366,9 +511,6 @@ class QuerySet(object): # 'in', 'nin' and 'all' require a list of values value = [field.prepare_query_value(op, v) for v in value] - if field.__class__.__name__ == 'GenericReferenceField': - parts.append('_ref') - # if op and op not in match_operators: if op: if op in geo_operators: @@ -383,7 +525,10 @@ class QuerySet(object): "been implemented" % op) elif op not in match_operators: value = {'$' + op: value} - + + if negate: + value = {'$not': value} + for i, part in indices: parts.insert(i, part) key = '.'.join(parts) @@ -445,6 +590,7 @@ class QuerySet(object): def create(self, **kwargs): """Create new object. Returns the saved object instance. + .. versionadded:: 0.4 """ doc = self._document(**kwargs) @@ -641,6 +787,7 @@ class QuerySet(object): # Integer index provided elif isinstance(key, int): return self._document._from_son(self._cursor[key]) + raise AttributeError def distinct(self, field): """Return a list of distinct values for a given field. @@ -649,7 +796,7 @@ class QuerySet(object): .. versionadded:: 0.4 """ - return self._collection.distinct(field) + return self._cursor.distinct(field) def only(self, *fields): """Load only a subset of this document's fields. :: @@ -709,6 +856,20 @@ class QuerySet(object): plan = pprint.pformat(plan) return plan + def snapshot(self, enabled): + """Enable or disable snapshot mode when querying. + + :param enabled: whether or not snapshot mode is enabled + """ + self._snapshot = enabled + + def timeout(self, enabled): + """Enable or disable the default mongod timeout when querying. + + :param enabled: whether or not the timeout is used + """ + self._timeout = enabled + def delete(self, safe=False): """Delete the documents matched by the query. @@ -721,7 +882,7 @@ class QuerySet(object): """Transform an update spec from Django-style format to Mongo format. """ operators = ['set', 'unset', 'inc', 'dec', 'pop', 'push', 'push_all', - 'pull', 'pull_all'] + 'pull', 'pull_all', 'add_to_set'] mongo_update = {} for key, value in update.items(): @@ -739,6 +900,8 @@ class QuerySet(object): op = 'inc' if value > 0: value = -value + elif op == 'add_to_set': + op = op.replace('_to_set', 'ToSet') if _doc_cls: # Switch field names to proper names [set in Field(name='foo')] @@ -747,7 +910,8 @@ class QuerySet(object): # Convert value to proper value field = fields[-1] - if op in (None, 'set', 'unset', 'pop', 'push', 'pull'): + if op in (None, 'set', 'unset', 'pop', 'push', 'pull', + 'addToSet'): value = field.prepare_query_value(op, value) elif op in ('pushAll', 'pullAll'): value = [field.prepare_query_value(op, v) for v in value] @@ -913,20 +1077,27 @@ class QuerySet(object): """ return self.exec_js(average_func, field) - def item_frequencies(self, list_field, normalize=False): - """Returns a dictionary of all items present in a list field across + def item_frequencies(self, field, normalize=False): + """Returns a dictionary of all items present in a field across the whole queried set of documents, and their corresponding frequency. This is useful for generating tag clouds, or searching documents. - :param list_field: the list field to use + If the field is a :class:`~mongoengine.ListField`, the items within + each list will be counted individually. + + :param field: the field to use :param normalize: normalize the results so they add to 1.0 """ freq_func = """ - function(listField) { + function(field) { if (options.normalize) { var total = 0.0; db[collection].find(query).forEach(function(doc) { - total += doc[listField].length; + if (doc[field].constructor == Array) { + total += doc[field].length; + } else { + total++; + } }); } @@ -936,14 +1107,19 @@ class QuerySet(object): inc /= total; } db[collection].find(query).forEach(function(doc) { - doc[listField].forEach(function(item) { + if (doc[field].constructor == Array) { + doc[field].forEach(function(item) { + frequencies[item] = inc + (frequencies[item] || 0); + }); + } else { + var item = doc[field]; frequencies[item] = inc + (frequencies[item] || 0); - }); + } }); return frequencies; } """ - return self.exec_js(freq_func, list_field, normalize=normalize) + return self.exec_js(freq_func, field, normalize=normalize) def __repr__(self): limit = REPR_OUTPUT_SIZE + 1 @@ -959,7 +1135,7 @@ class QuerySetManager(object): def __init__(self, manager_func=None): self._manager_func = manager_func - self._collection = None + self._collections = {} def __get__(self, instance, owner): """Descriptor for instantiating a new QuerySet object when @@ -969,10 +1145,9 @@ class QuerySetManager(object): # Document class being used rather than a document object return self - if self._collection is None: - db = _get_db() - collection = owner._meta['collection'] - + db = _get_db() + collection = owner._meta['collection'] + if (db, collection) not in self._collections: # Create collection as a capped collection if specified if owner._meta['max_size'] or owner._meta['max_documents']: # Get max document limit and max byte size from meta @@ -980,10 +1155,10 @@ class QuerySetManager(object): max_documents = owner._meta['max_documents'] if collection in db.collection_names(): - self._collection = db[collection] + self._collections[(db, collection)] = db[collection] # The collection already exists, check if its capped # options match the specified capped options - options = self._collection.options() + options = self._collections[(db, collection)].options() if options.get('max') != max_documents or \ options.get('size') != max_size: msg = ('Cannot create collection "%s" as a capped ' @@ -994,13 +1169,15 @@ class QuerySetManager(object): opts = {'capped': True, 'size': max_size} if max_documents: opts['max'] = max_documents - self._collection = db.create_collection(collection, **opts) + self._collections[(db, collection)] = db.create_collection( + collection, **opts + ) else: - self._collection = db[collection] + self._collections[(db, collection)] = db[collection] # owner is the document that contains the QuerySetManager queryset_class = owner._meta['queryset_class'] or QuerySet - queryset = queryset_class(owner, self._collection) + queryset = queryset_class(owner, self._collections[(db, collection)]) if self._manager_func: if self._manager_func.func_code.co_argcount == 1: queryset = self._manager_func(queryset) diff --git a/tests/document.py b/tests/document.py index c2e0d6f9..aa901813 100644 --- a/tests/document.py +++ b/tests/document.py @@ -200,6 +200,37 @@ class DocumentTest(unittest.TestCase): Person.drop_collection() self.assertFalse(collection in self.db.collection_names()) + def test_inherited_collections(self): + """Ensure that subclassed documents don't override parents' collections. + """ + class Drink(Document): + name = StringField() + + class AlcoholicDrink(Drink): + meta = {'collection': 'booze'} + + class Drinker(Document): + drink = GenericReferenceField() + + Drink.drop_collection() + AlcoholicDrink.drop_collection() + Drinker.drop_collection() + + red_bull = Drink(name='Red Bull') + red_bull.save() + + programmer = Drinker(drink=red_bull) + programmer.save() + + beer = AlcoholicDrink(name='Beer') + beer.save() + + real_person = Drinker(drink=beer) + real_person.save() + + self.assertEqual(Drinker.objects[0].drink.name, red_bull.name) + self.assertEqual(Drinker.objects[1].drink.name, beer.name) + def test_capped_collection(self): """Ensure that capped collections work properly. """ diff --git a/tests/fields.py b/tests/fields.py index f5f38fc2..5602cdec 100644 --- a/tests/fields.py +++ b/tests/fields.py @@ -189,6 +189,9 @@ class FieldTest(unittest.TestCase): def test_list_validation(self): """Ensure that a list field only accepts lists with valid elements. """ + class User(Document): + pass + class Comment(EmbeddedDocument): content = StringField() @@ -196,6 +199,7 @@ class FieldTest(unittest.TestCase): content = StringField() comments = ListField(EmbeddedDocumentField(Comment)) tags = ListField(StringField()) + authors = ListField(ReferenceField(User)) post = BlogPost(content='Went for a walk today...') post.validate() @@ -210,15 +214,21 @@ class FieldTest(unittest.TestCase): post.tags = ('fun', 'leisure') post.validate() - comments = [Comment(content='Good for you'), Comment(content='Yay.')] - post.comments = comments - post.validate() - post.comments = ['a'] self.assertRaises(ValidationError, post.validate) post.comments = 'yay' self.assertRaises(ValidationError, post.validate) + comments = [Comment(content='Good for you'), Comment(content='Yay.')] + post.comments = comments + post.validate() + + post.authors = [Comment()] + self.assertRaises(ValidationError, post.validate) + + post.authors = [User()] + post.validate() + def test_sorted_list_sorting(self): """Ensure that a sorted list field properly sorts values. """ @@ -395,14 +405,54 @@ class FieldTest(unittest.TestCase): class Employee(Document): name = StringField() boss = ReferenceField('self') + friends = ListField(ReferenceField('self')) bill = Employee(name='Bill Lumbergh') bill.save() - peter = Employee(name='Peter Gibbons', boss=bill) + + michael = Employee(name='Michael Bolton') + michael.save() + + samir = Employee(name='Samir Nagheenanajar') + samir.save() + + friends = [michael, samir] + peter = Employee(name='Peter Gibbons', boss=bill, friends=friends) peter.save() peter = Employee.objects.with_id(peter.id) self.assertEqual(peter.boss, bill) + self.assertEqual(peter.friends, friends) + + def test_recursive_embedding(self): + """Ensure that EmbeddedDocumentFields can contain their own documents. + """ + class Tree(Document): + name = StringField() + children = ListField(EmbeddedDocumentField('TreeNode')) + + class TreeNode(EmbeddedDocument): + name = StringField() + children = ListField(EmbeddedDocumentField('self')) + + tree = Tree(name="Tree") + + first_child = TreeNode(name="Child 1") + tree.children.append(first_child) + + second_child = TreeNode(name="Child 2") + first_child.children.append(second_child) + + third_child = TreeNode(name="Child 3") + first_child.children.append(third_child) + + tree.save() + + tree_obj = Tree.objects.first() + self.assertEqual(len(tree.children), 1) + self.assertEqual(tree.children[0].name, first_child.name) + self.assertEqual(tree.children[0].children[0].name, second_child.name) + self.assertEqual(tree.children[0].children[1].name, third_child.name) def test_undefined_reference(self): """Ensure that ReferenceFields may reference undefined Documents. diff --git a/tests/queryset.py b/tests/queryset.py index 59a8216e..38e1cb33 100644 --- a/tests/queryset.py +++ b/tests/queryset.py @@ -53,9 +53,6 @@ class QuerySetTest(unittest.TestCase): person2 = self.Person(name="User B", age=30) person2.save() - q1 = Q(name='test') - q2 = Q(age__gte=18) - # Find all people in the collection people = self.Person.objects self.assertEqual(len(people), 2) @@ -156,7 +153,8 @@ class QuerySetTest(unittest.TestCase): # Retrieve the first person from the database self.assertRaises(MultipleObjectsReturned, self.Person.objects.get) - self.assertRaises(self.Person.MultipleObjectsReturned, self.Person.objects.get) + self.assertRaises(self.Person.MultipleObjectsReturned, + self.Person.objects.get) # Use a query to filter the people found to just person2 person = self.Person.objects.get(age=30) @@ -234,7 +232,8 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(created, False) # Try retrieving when no objects exists - new doc should be created - person, created = self.Person.objects.get_or_create(age=50, defaults={'name': 'User C'}) + kwargs = dict(age=50, defaults={'name': 'User C'}) + person, created = self.Person.objects.get_or_create(**kwargs) self.assertEqual(created, True) person = self.Person.objects.get(age=50) @@ -337,6 +336,18 @@ class QuerySetTest(unittest.TestCase): obj = self.Person.objects(Q(name__icontains='[.\'Geek')).first() self.assertEqual(obj, person) + def test_not(self): + """Ensure that the __not operator works as expected. + """ + alice = self.Person(name='Alice', age=25) + alice.save() + + obj = self.Person.objects(name__iexact='alice').first() + self.assertEqual(obj, alice) + + obj = self.Person.objects(name__not__iexact='alice').first() + self.assertEqual(obj, None) + def test_filter_chaining(self): """Ensure filters can be chained together. """ @@ -546,9 +557,10 @@ class QuerySetTest(unittest.TestCase): obj = self.Person.objects(Q(name=re.compile('^gui', re.I))).first() self.assertEqual(obj, person) - obj = self.Person.objects(Q(name__ne=re.compile('^bob'))).first() + obj = self.Person.objects(Q(name__not=re.compile('^bob'))).first() self.assertEqual(obj, person) - obj = self.Person.objects(Q(name__ne=re.compile('^Gui'))).first() + + obj = self.Person.objects(Q(name__not=re.compile('^Gui'))).first() self.assertEqual(obj, None) def test_q_lists(self): @@ -717,6 +729,11 @@ class QuerySetTest(unittest.TestCase): post.reload() self.assertEqual(post.tags, tags) + BlogPost.objects.update_one(add_to_set__tags='unique') + BlogPost.objects.update_one(add_to_set__tags='unique') + post.reload() + self.assertEqual(post.tags.count('unique'), 1) + BlogPost.drop_collection() def test_update_pull(self): @@ -968,7 +985,7 @@ class QuerySetTest(unittest.TestCase): BlogPost(hits=1, tags=['music', 'film', 'actors']).save() BlogPost(hits=2, tags=['music']).save() - BlogPost(hits=3, tags=['music', 'actors']).save() + BlogPost(hits=2, tags=['music', 'actors']).save() f = BlogPost.objects.item_frequencies('tags') f = dict((key, int(val)) for key, val in f.items()) @@ -990,6 +1007,13 @@ class QuerySetTest(unittest.TestCase): self.assertAlmostEqual(f['actors'], 2.0/6.0) self.assertAlmostEqual(f['film'], 1.0/6.0) + # Check item_frequencies works for non-list fields + f = BlogPost.objects.item_frequencies('hits') + f = dict((key, int(val)) for key, val in f.items()) + self.assertEqual(set(['1', '2']), set(f.keys())) + self.assertEqual(f['1'], 1) + self.assertEqual(f['2'], 2) + BlogPost.drop_collection() def test_average(self): @@ -1026,9 +1050,13 @@ class QuerySetTest(unittest.TestCase): self.Person(name='Mr Orange', age=20).save() self.Person(name='Mr White', age=20).save() self.Person(name='Mr Orange', age=30).save() - self.assertEqual(self.Person.objects.distinct('name'), - ['Mr Orange', 'Mr White']) - self.assertEqual(self.Person.objects.distinct('age'), [20, 30]) + self.Person(name='Mr Pink', age=30).save() + self.assertEqual(set(self.Person.objects.distinct('name')), + set(['Mr Orange', 'Mr White', 'Mr Pink'])) + self.assertEqual(set(self.Person.objects.distinct('age')), + set([20, 30])) + self.assertEqual(set(self.Person.objects(age=30).distinct('name')), + set(['Mr Orange', 'Mr Pink'])) def test_custom_manager(self): """Ensure that custom QuerySetManager instances work as expected. @@ -1330,43 +1358,6 @@ class QuerySetTest(unittest.TestCase): class QTest(unittest.TestCase): - def test_or_and(self): - """Ensure that Q objects may be combined correctly. - """ - q1 = Q(name='test') - q2 = Q(age__gte=18) - - query = ['(', {'name': 'test'}, '||', {'age__gte': 18}, ')'] - self.assertEqual((q1 | q2).query, query) - - query = ['(', {'name': 'test'}, '&&', {'age__gte': 18}, ')'] - self.assertEqual((q1 & q2).query, query) - - query = ['(', '(', {'name': 'test'}, '&&', {'age__gte': 18}, ')', '||', - {'name': 'example'}, ')'] - self.assertEqual((q1 & q2 | Q(name='example')).query, query) - - def test_item_query_as_js(self): - """Ensure that the _item_query_as_js utilitiy method works properly. - """ - q = Q() - examples = [ - - ({'name': 'test'}, ('((this.name instanceof Array) && ' - 'this.name.indexOf(i0f0) != -1) || this.name == i0f0'), - {'i0f0': 'test'}), - ({'age': {'$gt': 18}}, 'this.age > i0f0o0', {'i0f0o0': 18}), - ({'name': 'test', 'age': {'$gt': 18, '$lte': 65}}, - ('this.age <= i0f0o0 && this.age > i0f0o1 && ' - '((this.name instanceof Array) && ' - 'this.name.indexOf(i0f1) != -1) || this.name == i0f1'), - {'i0f0o0': 65, 'i0f0o1': 18, 'i0f1': 'test'}), - ] - for item, js, scope in examples: - test_scope = {} - self.assertEqual(q._item_query_as_js(item, test_scope, 0), js) - self.assertEqual(scope, test_scope) - def test_empty_q(self): """Ensure that empty Q objects won't hurt. """ @@ -1376,11 +1367,15 @@ class QTest(unittest.TestCase): q4 = Q(name='test') q5 = Q() - query = ['(', {'age__gte': 18}, '||', {'name': 'test'}, ')'] - self.assertEqual((q1 | q2 | q3 | q4 | q5).query, query) + class Person(Document): + name = StringField() + age = IntField() - query = ['(', {'age__gte': 18}, '&&', {'name': 'test'}, ')'] - self.assertEqual((q1 & q2 & q3 & q4 & q5).query, query) + query = {'$or': [{'age': {'$gte': 18}}, {'name': 'test'}]} + self.assertEqual((q1 | q2 | q3 | q4 | q5).to_query(Person), query) + + query = {'age': {'$gte': 18}, 'name': 'test'} + self.assertEqual((q1 & q2 & q3 & q4 & q5).to_query(Person), query) def test_q_with_dbref(self): """Ensure Q objects handle DBRefs correctly""" @@ -1398,6 +1393,105 @@ class QTest(unittest.TestCase): self.assertEqual(Post.objects.filter(created_user=user).count(), 1) self.assertEqual(Post.objects.filter(Q(created_user=user)).count(), 1) + def test_and_combination(self): + """Ensure that Q-objects correctly AND together. + """ + class TestDoc(Document): + x = IntField() + y = StringField() + + # Check than an error is raised when conflicting queries are anded + def invalid_combination(): + query = Q(x__lt=7) & Q(x__lt=3) + query.to_query(TestDoc) + self.assertRaises(InvalidQueryError, invalid_combination) + + # Check normal cases work without an error + query = Q(x__lt=7) & Q(x__gt=3) + + q1 = Q(x__lt=7) + q2 = Q(x__gt=3) + query = (q1 & q2).to_query(TestDoc) + self.assertEqual(query, {'x': {'$lt': 7, '$gt': 3}}) + + # More complex nested example + query = Q(x__lt=100) & Q(y__ne='NotMyString') + query &= Q(y__in=['a', 'b', 'c']) & Q(x__gt=-100) + mongo_query = { + 'x': {'$lt': 100, '$gt': -100}, + 'y': {'$ne': 'NotMyString', '$in': ['a', 'b', 'c']}, + } + self.assertEqual(query.to_query(TestDoc), mongo_query) + + def test_or_combination(self): + """Ensure that Q-objects correctly OR together. + """ + class TestDoc(Document): + x = IntField() + + q1 = Q(x__lt=3) + q2 = Q(x__gt=7) + query = (q1 | q2).to_query(TestDoc) + self.assertEqual(query, { + '$or': [ + {'x': {'$lt': 3}}, + {'x': {'$gt': 7}}, + ] + }) + + def test_and_or_combination(self): + """Ensure that Q-objects handle ANDing ORed components. + """ + class TestDoc(Document): + x = IntField() + y = BooleanField() + + query = (Q(x__gt=0) | Q(x__exists=False)) + query &= Q(x__lt=100) + self.assertEqual(query.to_query(TestDoc), { + '$or': [ + {'x': {'$lt': 100, '$gt': 0}}, + {'x': {'$lt': 100, '$exists': False}}, + ] + }) + + q1 = (Q(x__gt=0) | Q(x__exists=False)) + q2 = (Q(x__lt=100) | Q(y=True)) + query = (q1 & q2).to_query(TestDoc) + + self.assertEqual(['$or'], query.keys()) + conditions = [ + {'x': {'$lt': 100, '$gt': 0}}, + {'x': {'$lt': 100, '$exists': False}}, + {'x': {'$gt': 0}, 'y': True}, + {'x': {'$exists': False}, 'y': True}, + ] + self.assertEqual(len(conditions), len(query['$or'])) + for condition in conditions: + self.assertTrue(condition in query['$or']) + + def test_or_and_or_combination(self): + """Ensure that Q-objects handle ORing ANDed ORed components. :) + """ + class TestDoc(Document): + x = IntField() + y = BooleanField() + + q1 = (Q(x__gt=0) & (Q(y=True) | Q(y__exists=False))) + q2 = (Q(x__lt=100) & (Q(y=False) | Q(y__exists=False))) + query = (q1 | q2).to_query(TestDoc) + + self.assertEqual(['$or'], query.keys()) + conditions = [ + {'x': {'$gt': 0}, 'y': True}, + {'x': {'$gt': 0}, 'y': {'$exists': False}}, + {'x': {'$lt': 100}, 'y':False}, + {'x': {'$lt': 100}, 'y': {'$exists': False}}, + ] + self.assertEqual(len(conditions), len(query['$or'])) + for condition in conditions: + self.assertTrue(condition in query['$or']) + if __name__ == '__main__': unittest.main()