Merge branch 'v0.4' of git://github.com/hmarr/mongoengine into v0.4

Conflicts:
	docs/changelog.rst
	mongoengine/base.py
	mongoengine/queryset.py
This commit is contained in:
Steve Challis 2010-10-17 23:48:20 +01:00
commit 39e27735cc
13 changed files with 703 additions and 227 deletions

View File

@ -66,3 +66,5 @@ Fields
.. autoclass:: mongoengine.GenericReferenceField .. autoclass:: mongoengine.GenericReferenceField
.. autoclass:: mongoengine.FileField .. autoclass:: mongoengine.FileField
.. autoclass:: mongoengine.GeoPointField

View File

@ -6,16 +6,26 @@ Changes in v0.4
=============== ===============
- Added ``GridFSStorage`` Django storage backend - Added ``GridFSStorage`` Django storage backend
- Added ``FileField`` for GridFS support - Added ``FileField`` for GridFS support
- New Q-object implementation, which is no longer based on Javascript
- Added ``SortedListField`` - Added ``SortedListField``
- Added ``EmailField`` - Added ``EmailField``
- Added ``GeoPointField`` - Added ``GeoPointField``
- Added ``exact`` and ``iexact`` match operators to ``QuerySet`` - Added ``exact`` and ``iexact`` match operators to ``QuerySet``
- Added ``get_document_or_404`` and ``get_list_or_404`` Django shortcuts - Added ``get_document_or_404`` and ``get_list_or_404`` Django shortcuts
- Fixed bug in Q-objects - Added new query operators for Geo queries
- Added ``not`` query operator
- Added new update operators: ``pop`` and ``add_to_set``
- Added ``__raw__`` query parameter
- Added support for custom querysets
- Fixed document inheritance primary key issue - Fixed document inheritance primary key issue
- Added support for querying by array element position
- Base class can now be defined for ``DictField`` - Base class can now be defined for ``DictField``
- Fixed MRO error that occured on document inheritance - Fixed MRO error that occured on document inheritance
- Added ``QuerySet.distinct``, ``QuerySet.create``, ``QuerySet.snapshot``,
``QuerySet.timeout`` and ``QuerySet.all``
- Subsequent calls to ``connect()`` now work
- Introduced ``min_length`` for ``StringField`` - Introduced ``min_length`` for ``StringField``
- Fixed multi-process connection issue
- Other minor fixes - Other minor fixes
Changes in v0.3 Changes in v0.3

View File

@ -47,11 +47,11 @@ are as follows:
* :class:`~mongoengine.ReferenceField` * :class:`~mongoengine.ReferenceField`
* :class:`~mongoengine.GenericReferenceField` * :class:`~mongoengine.GenericReferenceField`
* :class:`~mongoengine.BooleanField` * :class:`~mongoengine.BooleanField`
* :class:`~mongoengine.GeoLocationField`
* :class:`~mongoengine.FileField` * :class:`~mongoengine.FileField`
* :class:`~mongoengine.EmailField` * :class:`~mongoengine.EmailField`
* :class:`~mongoengine.SortedListField` * :class:`~mongoengine.SortedListField`
* :class:`~mongoengine.BinaryField` * :class:`~mongoengine.BinaryField`
* :class:`~mongoengine.GeoPointField`
Field arguments Field arguments
--------------- ---------------
@ -72,6 +72,25 @@ arguments can be set on all fields:
:attr:`default` (Default: None) :attr:`default` (Default: None)
A value to use when no value is set for this field. A value to use when no value is set for this field.
The definion of default parameters follow `the general rules on Python
<http://docs.python.org/reference/compound_stmts.html#function-definitions>`__,
which means that some care should be taken when dealing with default mutable objects
(like in :class:`~mongoengine.ListField` or :class:`~mongoengine.DictField`)::
class ExampleFirst(Document):
# Default an empty list
values = ListField(IntField(), default=list)
class ExampleSecond(Document):
# Default a set of values
values = ListField(IntField(), default=lambda: [1,2,3])
class ExampleDangerous(Document):
# This can make an .append call to add values to the default (and all the following objects),
# instead to just an object
values = ListField(IntField(), default=[1,2,3])
:attr:`unique` (Default: False) :attr:`unique` (Default: False)
When True, no documents in the collection will have the same value for this When True, no documents in the collection will have the same value for this
field. field.
@ -279,6 +298,10 @@ or a **-** sign. Note that direction only matters on multi-field indexes. ::
meta = { meta = {
'indexes': ['title', ('title', '-rating')] 'indexes': ['title', ('title', '-rating')]
} }
.. note::
Geospatial indexes will be automatically created for all
:class:`~mongoengine.GeoPointField`\ s
Ordering Ordering
======== ========

View File

@ -53,6 +53,16 @@ lists that contain that item will be matched::
# 'tags' list # 'tags' list
Page.objects(tags='coding') Page.objects(tags='coding')
Raw queries
-----------
It is possible to provide a raw PyMongo query as a query parameter, which will
be integrated directly into the query. This is done using the ``__raw__``
keyword argument::
Page.objects(__raw__={'tags': 'coding'})
.. versionadded:: 0.4
Query operators Query operators
=============== ===============
Operators other than equality may also be used in queries; just attach the Operators other than equality may also be used in queries; just attach the
@ -68,6 +78,8 @@ Available operators are as follows:
* ``lte`` -- less than or equal to * ``lte`` -- less than or equal to
* ``gt`` -- greater than * ``gt`` -- greater than
* ``gte`` -- greater than or equal to * ``gte`` -- greater than or equal to
* ``not`` -- negate a standard check, may be used before other operators (e.g.
``Q(age__not__mod=5)``)
* ``in`` -- value is in list (a list of values should be provided) * ``in`` -- value is in list (a list of values should be provided)
* ``nin`` -- value is not in list (a list of values should be provided) * ``nin`` -- value is not in list (a list of values should be provided)
* ``mod`` -- ``value % x == y``, where ``x`` and ``y`` are two provided values * ``mod`` -- ``value % x == y``, where ``x`` and ``y`` are two provided values
@ -89,6 +101,27 @@ expressions:
.. versionadded:: 0.3 .. versionadded:: 0.3
There are a few special operators for performing geographical queries, that
may used with :class:`~mongoengine.GeoPointField`\ s:
* ``within_distance`` -- provide a list containing a point and a maximum
distance (e.g. [(41.342, -87.653), 5])
* ``within_box`` -- filter documents to those within a given bounding box (e.g.
[(35.0, -125.0), (40.0, -100.0)])
* ``near`` -- order the documents by how close they are to a given point
.. versionadded:: 0.4
Querying by position
====================
It is possible to query by position in a list by using a numerical value as a
query operator. So if you wanted to find all pages whose first tag was ``db``,
you could use the following query::
BlogPost.objects(tags__0='db')
.. versionadded:: 0.4
Limiting and skipping results Limiting and skipping results
============================= =============================
Just as with traditional ORMs, you may limit the number of results returned, or Just as with traditional ORMs, you may limit the number of results returned, or
@ -181,6 +214,22 @@ custom manager methods as you like::
assert len(BlogPost.objects) == 2 assert len(BlogPost.objects) == 2
assert len(BlogPost.live_posts) == 1 assert len(BlogPost.live_posts) == 1
Custom QuerySets
================
Should you want to add custom methods for interacting with or filtering
documents, extending the :class:`~mongoengine.queryset.QuerySet` class may be
the way to go. To use a custom :class:`~mongoengine.queryset.QuerySet` class on
a document, set ``queryset_class`` to the custom class in a
:class:`~mongoengine.Document`\ s ``meta`` dictionary::
class AwesomerQuerySet(QuerySet):
pass
class Page(Document):
meta = {'queryset_class': AwesomerQuerySet}
.. versionadded:: 0.4
Aggregation Aggregation
=========== ===========
MongoDB provides some aggregation methods out of the box, but there are not as MongoDB provides some aggregation methods out of the box, but there are not as
@ -402,8 +451,10 @@ that you may use with these methods:
* ``pop`` -- remove the last item from a list * ``pop`` -- remove the last item from a list
* ``push`` -- append a value to a list * ``push`` -- append a value to a list
* ``push_all`` -- append several values to a list * ``push_all`` -- append several values to a list
* ``pop`` -- remove the first or last element of a list
* ``pull`` -- remove a value from a list * ``pull`` -- remove a value from a list
* ``pull_all`` -- remove several values from a list * ``pull_all`` -- remove several values from a list
* ``add_to_set`` -- add value to a list only if its not in the list already
The syntax for atomic updates is similar to the querying syntax, but the The syntax for atomic updates is similar to the querying syntax, but the
modifier comes before the field, not after it:: modifier comes before the field, not after it::

View File

@ -7,7 +7,7 @@ MongoDB. To install it, simply run
.. code-block:: console .. code-block:: console
# easy_install -U mongoengine # pip install -U mongoengine
The source is available on `GitHub <http://github.com/hmarr/mongoengine>`_. The source is available on `GitHub <http://github.com/hmarr/mongoengine>`_.

View File

@ -12,7 +12,7 @@ __all__ = (document.__all__ + fields.__all__ + connection.__all__ +
__author__ = 'Harry Marr' __author__ = 'Harry Marr'
VERSION = (0, 3, 0) VERSION = (0, 4, 0)
def get_version(): def get_version():
version = '%s.%s' % (VERSION[0], VERSION[1]) version = '%s.%s' % (VERSION[0], VERSION[1])

View File

@ -3,6 +3,7 @@ from queryset import DoesNotExist, MultipleObjectsReturned
import sys import sys
import pymongo import pymongo
import pymongo.objectid
_document_registry = {} _document_registry = {}
@ -203,6 +204,9 @@ class DocumentMetaclass(type):
exc = subclass_exception('MultipleObjectsReturned', base_excs, module) exc = subclass_exception('MultipleObjectsReturned', base_excs, module)
new_class.add_to_class('MultipleObjectsReturned', exc) new_class.add_to_class('MultipleObjectsReturned', exc)
global _document_registry
_document_registry[name] = new_class
return new_class return new_class
def add_to_class(self, name, value): def add_to_class(self, name, value):
@ -215,8 +219,6 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
""" """
def __new__(cls, name, bases, attrs): def __new__(cls, name, bases, attrs):
global _document_registry
super_new = super(TopLevelDocumentMetaclass, cls).__new__ super_new = super(TopLevelDocumentMetaclass, cls).__new__
# Classes defined in this package are abstract and should not have # Classes defined in this package are abstract and should not have
# their own metadata with DB collection, etc. # their own metadata with DB collection, etc.
@ -321,8 +323,6 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
new_class._fields['id'] = ObjectIdField(db_field='_id') new_class._fields['id'] = ObjectIdField(db_field='_id')
new_class.id = new_class._fields['id'] new_class.id = new_class._fields['id']
_document_registry[name] = new_class
return new_class return new_class

View File

@ -4,11 +4,12 @@ import multiprocessing
__all__ = ['ConnectionError', 'connect'] __all__ = ['ConnectionError', 'connect']
_connection_settings = { _connection_defaults = {
'host': 'localhost', 'host': 'localhost',
'port': 27017, 'port': 27017,
} }
_connection = {} _connection = {}
_connection_settings = _connection_defaults.copy()
_db_name = None _db_name = None
_db_username = None _db_username = None
@ -20,25 +21,25 @@ class ConnectionError(Exception):
pass pass
def _get_connection(): def _get_connection(reconnect=False):
global _connection global _connection
identity = get_identity() identity = get_identity()
# Connect to the database if not already connected # Connect to the database if not already connected
if _connection.get(identity) is None: if _connection.get(identity) is None or reconnect:
try: try:
_connection[identity] = Connection(**_connection_settings) _connection[identity] = Connection(**_connection_settings)
except: except:
raise ConnectionError('Cannot connect to the database') raise ConnectionError('Cannot connect to the database')
return _connection[identity] return _connection[identity]
def _get_db(): def _get_db(reconnect=False):
global _db, _connection global _db, _connection
identity = get_identity() identity = get_identity()
# Connect if not already connected # Connect if not already connected
if _connection.get(identity) is None: if _connection.get(identity) is None or reconnect:
_connection[identity] = _get_connection() _connection[identity] = _get_connection(reconnect=reconnect)
if _db.get(identity) is None: if _db.get(identity) is None or reconnect:
# _db_name will be None if the user hasn't called connect() # _db_name will be None if the user hasn't called connect()
if _db_name is None: if _db_name is None:
raise ConnectionError('Not connected to the database') raise ConnectionError('Not connected to the database')
@ -61,9 +62,10 @@ def connect(db, username=None, password=None, **kwargs):
the default port on localhost. If authentication is needed, provide the default port on localhost. If authentication is needed, provide
username and password arguments as well. username and password arguments as well.
""" """
global _connection_settings, _db_name, _db_username, _db_password global _connection_settings, _db_name, _db_username, _db_password, _db
_connection_settings.update(kwargs) _connection_settings = dict(_connection_defaults, **kwargs)
_db_name = db _db_name = db
_db_username = username _db_username = username
_db_password = password _db_password = password
return _get_db() return _get_db(reconnect=True)

View File

@ -5,6 +5,9 @@ from operator import itemgetter
import re import re
import pymongo import pymongo
import pymongo.dbref
import pymongo.son
import pymongo.binary
import datetime import datetime
import decimal import decimal
import gridfs import gridfs
@ -106,8 +109,11 @@ class URLField(StringField):
message = 'This URL appears to be a broken link: %s' % e message = 'This URL appears to be a broken link: %s' % e
raise ValidationError(message) raise ValidationError(message)
class EmailField(StringField): class EmailField(StringField):
"""A field that validates input as an E-Mail-Address. """A field that validates input as an E-Mail-Address.
.. versionadded:: 0.4
""" """
EMAIL_REGEX = re.compile( EMAIL_REGEX = re.compile(
@ -115,11 +121,12 @@ class EmailField(StringField):
r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]|\\[\001-011\013\014\016-\177])*"' # quoted-string r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]|\\[\001-011\013\014\016-\177])*"' # quoted-string
r')@(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,6}\.?$', re.IGNORECASE # domain r')@(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,6}\.?$', re.IGNORECASE # domain
) )
def validate(self, value): def validate(self, value):
if not EmailField.EMAIL_REGEX.match(value): if not EmailField.EMAIL_REGEX.match(value):
raise ValidationError('Invalid Mail-address: %s' % value) raise ValidationError('Invalid Mail-address: %s' % value)
class IntField(BaseField): class IntField(BaseField):
"""An integer field. """An integer field.
""" """
@ -143,6 +150,7 @@ class IntField(BaseField):
if self.max_value is not None and value > self.max_value: if self.max_value is not None and value > self.max_value:
raise ValidationError('Integer value is too large') raise ValidationError('Integer value is too large')
class FloatField(BaseField): class FloatField(BaseField):
"""An floating point number field. """An floating point number field.
""" """
@ -179,7 +187,7 @@ class DecimalField(BaseField):
if not isinstance(value, basestring): if not isinstance(value, basestring):
value = unicode(value) value = unicode(value)
return decimal.Decimal(value) return decimal.Decimal(value)
def to_mongo(self, value): def to_mongo(self, value):
return unicode(value) return unicode(value)
@ -198,6 +206,7 @@ class DecimalField(BaseField):
if self.max_value is not None and value > self.max_value: if self.max_value is not None and value > self.max_value:
raise ValidationError('Decimal value is too large') raise ValidationError('Decimal value is too large')
class BooleanField(BaseField): class BooleanField(BaseField):
"""A boolean field type. """A boolean field type.
@ -210,6 +219,7 @@ class BooleanField(BaseField):
def validate(self, value): def validate(self, value):
assert isinstance(value, bool) assert isinstance(value, bool)
class DateTimeField(BaseField): class DateTimeField(BaseField):
"""A datetime field. """A datetime field.
""" """
@ -217,38 +227,49 @@ class DateTimeField(BaseField):
def validate(self, value): def validate(self, value):
assert isinstance(value, datetime.datetime) assert isinstance(value, datetime.datetime)
class EmbeddedDocumentField(BaseField): class EmbeddedDocumentField(BaseField):
"""An embedded document field. Only valid values are subclasses of """An embedded document field. Only valid values are subclasses of
:class:`~mongoengine.EmbeddedDocument`. :class:`~mongoengine.EmbeddedDocument`.
""" """
def __init__(self, document, **kwargs): def __init__(self, document_type, **kwargs):
if not issubclass(document, EmbeddedDocument): if not isinstance(document_type, basestring):
raise ValidationError('Invalid embedded document class provided ' if not issubclass(document_type, EmbeddedDocument):
'to an EmbeddedDocumentField') raise ValidationError('Invalid embedded document class '
self.document = document 'provided to an EmbeddedDocumentField')
self.document_type_obj = document_type
super(EmbeddedDocumentField, self).__init__(**kwargs) super(EmbeddedDocumentField, self).__init__(**kwargs)
@property
def document_type(self):
if isinstance(self.document_type_obj, basestring):
if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT:
self.document_type_obj = self.owner_document
else:
self.document_type_obj = get_document(self.document_type_obj)
return self.document_type_obj
def to_python(self, value): def to_python(self, value):
if not isinstance(value, self.document): if not isinstance(value, self.document_type):
return self.document._from_son(value) return self.document_type._from_son(value)
return value return value
def to_mongo(self, value): def to_mongo(self, value):
return self.document.to_mongo(value) return self.document_type.to_mongo(value)
def validate(self, value): def validate(self, value):
"""Make sure that the document instance is an instance of the """Make sure that the document instance is an instance of the
EmbeddedDocument subclass provided when the document was defined. EmbeddedDocument subclass provided when the document was defined.
""" """
# Using isinstance also works for subclasses of self.document # Using isinstance also works for subclasses of self.document
if not isinstance(value, self.document): if not isinstance(value, self.document_type):
raise ValidationError('Invalid embedded document instance ' raise ValidationError('Invalid embedded document instance '
'provided to an EmbeddedDocumentField') 'provided to an EmbeddedDocumentField')
self.document.validate(value) self.document_type.validate(value)
def lookup_member(self, member_name): def lookup_member(self, member_name):
return self.document._fields.get(member_name) return self.document_type._fields.get(member_name)
def prepare_query_value(self, op, value): def prepare_query_value(self, op, value):
return self.to_mongo(value) return self.to_mongo(value)
@ -322,20 +343,32 @@ class ListField(BaseField):
try: try:
[self.field.validate(item) for item in value] [self.field.validate(item) for item in value]
except Exception, err: except Exception, err:
raise ValidationError('Invalid ListField item (%s)' % str(err)) raise ValidationError('Invalid ListField item (%s)' % str(item))
def prepare_query_value(self, op, value): def prepare_query_value(self, op, value):
if op in ('set', 'unset'): if op in ('set', 'unset'):
return [self.field.to_mongo(v) for v in value] return [self.field.prepare_query_value(op, v) for v in value]
return self.field.to_mongo(value) return self.field.prepare_query_value(op, value)
def lookup_member(self, member_name): def lookup_member(self, member_name):
return self.field.lookup_member(member_name) return self.field.lookup_member(member_name)
def _set_owner_document(self, owner_document):
self.field.owner_document = owner_document
self._owner_document = owner_document
def _get_owner_document(self, owner_document):
self._owner_document = owner_document
owner_document = property(_get_owner_document, _set_owner_document)
class SortedListField(ListField): class SortedListField(ListField):
"""A ListField that sorts the contents of its list before writing to """A ListField that sorts the contents of its list before writing to
the database in order to ensure that a sorted list is always the database in order to ensure that a sorted list is always
retrieved. retrieved.
.. versionadded:: 0.4
""" """
_ordering = None _ordering = None
@ -351,6 +384,7 @@ class SortedListField(ListField):
key=itemgetter(self._ordering)) key=itemgetter(self._ordering))
return sorted([self.field.to_mongo(item) for item in value]) return sorted([self.field.to_mongo(item) for item in value])
class DictField(BaseField): class DictField(BaseField):
"""A dictionary field that wraps a standard Python dictionary. This is """A dictionary field that wraps a standard Python dictionary. This is
similar to an embedded document, but the structure is not defined. similar to an embedded document, but the structure is not defined.
@ -389,7 +423,6 @@ class ReferenceField(BaseField):
raise ValidationError('Argument to ReferenceField constructor ' raise ValidationError('Argument to ReferenceField constructor '
'must be a document class or a string') 'must be a document class or a string')
self.document_type_obj = document_type self.document_type_obj = document_type
self.document_obj = None
super(ReferenceField, self).__init__(**kwargs) super(ReferenceField, self).__init__(**kwargs)
@property @property
@ -444,6 +477,7 @@ class ReferenceField(BaseField):
def lookup_member(self, member_name): def lookup_member(self, member_name):
return self.document_type._fields.get(member_name) return self.document_type._fields.get(member_name)
class GenericReferenceField(BaseField): class GenericReferenceField(BaseField):
"""A reference to *any* :class:`~mongoengine.document.Document` subclass """A reference to *any* :class:`~mongoengine.document.Document` subclass
that will be automatically dereferenced on access (lazily). that will be automatically dereferenced on access (lazily).
@ -488,7 +522,7 @@ class GenericReferenceField(BaseField):
return {'_cls': document.__class__.__name__, '_ref': ref} return {'_cls': document.__class__.__name__, '_ref': ref}
def prepare_query_value(self, op, value): def prepare_query_value(self, op, value):
return self.to_mongo(value)['_ref'] return self.to_mongo(value)
class BinaryField(BaseField): class BinaryField(BaseField):
@ -655,6 +689,8 @@ class FileField(BaseField):
class GeoPointField(BaseField): class GeoPointField(BaseField):
"""A list storing a latitude and longitude. """A list storing a latitude and longitude.
.. versionadded:: 0.4
""" """
_geo_index = True _geo_index = True

View File

@ -2,8 +2,12 @@ from connection import _get_db
import pprint import pprint
import pymongo import pymongo
import pymongo.code
import pymongo.dbref
import pymongo.objectid
import re import re
import copy import copy
import itertools
__all__ = ['queryset_manager', 'Q', 'InvalidQueryError', __all__ = ['queryset_manager', 'Q', 'InvalidQueryError',
'InvalidCollectionError'] 'InvalidCollectionError']
@ -15,6 +19,7 @@ REPR_OUTPUT_SIZE = 20
class DoesNotExist(Exception): class DoesNotExist(Exception):
pass pass
class MultipleObjectsReturned(Exception): class MultipleObjectsReturned(Exception):
pass pass
@ -26,50 +31,192 @@ class InvalidQueryError(Exception):
class OperationError(Exception): class OperationError(Exception):
pass pass
class InvalidCollectionError(Exception): class InvalidCollectionError(Exception):
pass pass
RE_TYPE = type(re.compile('')) RE_TYPE = type(re.compile(''))
class Q(object): class QNodeVisitor(object):
"""Base visitor class for visiting Q-object nodes in a query tree.
"""
OR = '||' def visit_combination(self, combination):
AND = '&&' """Called by QCombination objects.
OPERATORS = { """
'eq': ('((this.%(field)s instanceof Array) && ' return combination
' this.%(field)s.indexOf(%(value)s) != -1) ||'
' this.%(field)s == %(value)s'),
'ne': 'this.%(field)s != %(value)s',
'gt': 'this.%(field)s > %(value)s',
'gte': 'this.%(field)s >= %(value)s',
'lt': 'this.%(field)s < %(value)s',
'lte': 'this.%(field)s <= %(value)s',
'lte': 'this.%(field)s <= %(value)s',
'in': '%(value)s.indexOf(this.%(field)s) != -1',
'nin': '%(value)s.indexOf(this.%(field)s) == -1',
'mod': '%(field)s %% %(value)s',
'all': ('%(value)s.every(function(a){'
'return this.%(field)s.indexOf(a) != -1 })'),
'size': 'this.%(field)s.length == %(value)s',
'exists': 'this.%(field)s != null',
'regex_eq': '%(value)s.test(this.%(field)s)',
'regex_ne': '!%(value)s.test(this.%(field)s)',
}
def __init__(self, **query): def visit_query(self, query):
self.query = [query] """Called by (New)Q objects.
"""
return query
def _combine(self, other, op):
obj = Q() class SimplificationVisitor(QNodeVisitor):
if not other.query[0]: """Simplifies query trees by combinging unnecessary 'and' connection nodes
into a single Q-object.
"""
def visit_combination(self, combination):
if combination.operation == combination.AND:
# The simplification only applies to 'simple' queries
if all(isinstance(node, Q) for node in combination.children):
queries = [node.query for node in combination.children]
return Q(**self._query_conjunction(queries))
return combination
def _query_conjunction(self, queries):
"""Merges query dicts - effectively &ing them together.
"""
query_ops = set()
combined_query = {}
for query in queries:
ops = set(query.keys())
# Make sure that the same operation isn't applied more than once
# to a single field
intersection = ops.intersection(query_ops)
if intersection:
msg = 'Duplicate query contitions: '
raise InvalidQueryError(msg + ', '.join(intersection))
query_ops.update(ops)
combined_query.update(copy.deepcopy(query))
return combined_query
class QueryTreeTransformerVisitor(QNodeVisitor):
"""Transforms the query tree in to a form that may be used with MongoDB.
"""
def visit_combination(self, combination):
if combination.operation == combination.AND:
# MongoDB doesn't allow us to have too many $or operations in our
# queries, so the aim is to move the ORs up the tree to one
# 'master' $or. Firstly, we must find all the necessary parts (part
# of an AND combination or just standard Q object), and store them
# separately from the OR parts.
or_groups = []
and_parts = []
for node in combination.children:
if isinstance(node, QCombination):
if node.operation == node.OR:
# Any of the children in an $or component may cause
# the query to succeed
or_groups.append(node.children)
elif node.operation == node.AND:
and_parts.append(node)
elif isinstance(node, Q):
and_parts.append(node)
# Now we combine the parts into a usable query. AND together all of
# the necessary parts. Then for each $or part, create a new query
# that ANDs the necessary part with the $or part.
clauses = []
for or_group in itertools.product(*or_groups):
q_object = reduce(lambda a, b: a & b, and_parts, Q())
q_object = reduce(lambda a, b: a & b, or_group, q_object)
clauses.append(q_object)
# Finally, $or the generated clauses in to one query. Each of the
# clauses is sufficient for the query to succeed.
return reduce(lambda a, b: a | b, clauses, Q())
if combination.operation == combination.OR:
children = []
# Crush any nested ORs in to this combination as MongoDB doesn't
# support nested $or operations
for node in combination.children:
if (isinstance(node, QCombination) and
node.operation == combination.OR):
children += node.children
else:
children.append(node)
combination.children = children
return combination
class QueryCompilerVisitor(QNodeVisitor):
"""Compiles the nodes in a query tree to a PyMongo-compatible query
dictionary.
"""
def __init__(self, document):
self.document = document
def visit_combination(self, combination):
if combination.operation == combination.OR:
return {'$or': combination.children}
elif combination.operation == combination.AND:
return self._mongo_query_conjunction(combination.children)
return combination
def visit_query(self, query):
return QuerySet._transform_query(self.document, **query.query)
def _mongo_query_conjunction(self, queries):
"""Merges Mongo query dicts - effectively &ing them together.
"""
combined_query = {}
for query in queries:
for field, ops in query.items():
if field not in combined_query:
combined_query[field] = ops
else:
# The field is already present in the query the only way
# we can merge is if both the existing value and the new
# value are operation dicts, reject anything else
if (not isinstance(combined_query[field], dict) or
not isinstance(ops, dict)):
message = 'Conflicting values for ' + field
raise InvalidQueryError(message)
current_ops = set(combined_query[field].keys())
new_ops = set(ops.keys())
# Make sure that the same operation isn't applied more than
# once to a single field
intersection = current_ops.intersection(new_ops)
if intersection:
msg = 'Duplicate query contitions: '
raise InvalidQueryError(msg + ', '.join(intersection))
# Right! We've got two non-overlapping dicts of operations!
combined_query[field].update(copy.deepcopy(ops))
return combined_query
class QNode(object):
"""Base class for nodes in query trees.
"""
AND = 0
OR = 1
def to_query(self, document):
query = self.accept(SimplificationVisitor())
query = query.accept(QueryTreeTransformerVisitor())
query = query.accept(QueryCompilerVisitor(document))
return query
def accept(self, visitor):
raise NotImplementedError
def _combine(self, other, operation):
"""Combine this node with another node into a QCombination object.
"""
if other.empty:
return self return self
if self.query[0]:
obj.query = (['('] + copy.deepcopy(self.query) + [op] + if self.empty:
copy.deepcopy(other.query) + [')']) return other
else:
obj.query = copy.deepcopy(other.query) return QCombination(operation, [self, other])
return obj
@property
def empty(self):
return False
def __or__(self, other): def __or__(self, other):
return self._combine(other, self.OR) return self._combine(other, self.OR)
@ -77,78 +224,49 @@ class Q(object):
def __and__(self, other): def __and__(self, other):
return self._combine(other, self.AND) return self._combine(other, self.AND)
def as_js(self, document):
js = [] class QCombination(QNode):
js_scope = {} """Represents the combination of several conditions by a given logical
for i, item in enumerate(self.query): operator.
if isinstance(item, dict): """
item_query = QuerySet._transform_query(document, **item)
# item_query will values will either be a value or a dict def __init__(self, operation, children):
js.append(self._item_query_as_js(item_query, js_scope, i)) self.operation = operation
self.children = []
for node in children:
# If the child is a combination of the same type, we can merge its
# children directly into this combinations children
if isinstance(node, QCombination) and node.operation == operation:
self.children += node.children
else: else:
js.append(item) self.children.append(node)
return pymongo.code.Code(' '.join(js), js_scope)
def _item_query_as_js(self, item_query, js_scope, item_num): def accept(self, visitor):
# item_query will be in one of the following forms for i in range(len(self.children)):
# {'age': 25, 'name': 'Test'} self.children[i] = self.children[i].accept(visitor)
# {'age': {'$lt': 25}, 'name': {'$in': ['Test', 'Example']}
# {'age': {'$lt': 25, '$gt': 18}}
js = []
for i, (key, value) in enumerate(item_query.items()):
op = 'eq'
# Construct a variable name for the value in the JS
value_name = 'i%sf%s' % (item_num, i)
if isinstance(value, dict):
# Multiple operators for this field
for j, (op, value) in enumerate(value.items()):
# Create a custom variable name for this operator
op_value_name = '%so%s' % (value_name, j)
# Construct the JS that uses this op
value, operation_js = self._build_op_js(op, key, value,
op_value_name)
# Update the js scope with the value for this op
js_scope[op_value_name] = value
js.append(operation_js)
else:
# Construct the JS for this field
value, field_js = self._build_op_js(op, key, value, value_name)
js_scope[value_name] = value
js.append(field_js)
return ' && '.join(js)
def _build_op_js(self, op, key, value, value_name): return visitor.visit_combination(self)
"""Substitute the values in to the correct chunk of Javascript.
"""
if isinstance(value, RE_TYPE):
# Regexes are handled specially
if op.strip('$') == 'ne':
op_js = Q.OPERATORS['regex_ne']
else:
op_js = Q.OPERATORS['regex_eq']
else:
op_js = Q.OPERATORS[op.strip('$')]
# Comparing two ObjectIds in Javascript doesn't work.. @property
if isinstance(value, pymongo.objectid.ObjectId): def empty(self):
value = unicode(value) return not bool(self.children)
# Handle DBRef
if isinstance(value, pymongo.dbref.DBRef):
op_js = '(this.%(field)s.$id == "%(id)s" &&'\
' this.%(field)s.$ref == "%(ref)s")' % {
'field': key,
'id': unicode(value.id),
'ref': unicode(value.collection)
}
value = None
# Perform the substitution class Q(QNode):
operation_js = op_js % { """A simple query object, used in a query tree to build up more complex
'field': key, query structures.
'value': value_name """
}
return value, operation_js def __init__(self, **query):
self.query = query
def accept(self, visitor):
return visitor.visit_query(self)
@property
def empty(self):
return not bool(self.query)
class QuerySet(object): class QuerySet(object):
"""A set of results returned from a query. Wraps a MongoDB cursor, """A set of results returned from a query. Wraps a MongoDB cursor,
@ -159,19 +277,30 @@ class QuerySet(object):
self._document = document self._document = document
self._collection_obj = collection self._collection_obj = collection
self._accessed_collection = False self._accessed_collection = False
self._query = {} self._mongo_query = None
self._query_obj = Q()
self._initial_query = {}
self._where_clause = None self._where_clause = None
self._loaded_fields = [] self._loaded_fields = []
self._ordering = [] self._ordering = []
self._snapshot = False
self._timeout = True
# If inheritance is allowed, only return instances and instances of # If inheritance is allowed, only return instances and instances of
# subclasses of the class being used # subclasses of the class being used
if document._meta.get('allow_inheritance'): if document._meta.get('allow_inheritance'):
self._query = {'_types': self._document._class_name} self._initial_query = {'_types': self._document._class_name}
self._cursor_obj = None self._cursor_obj = None
self._limit = None self._limit = None
self._skip = None self._skip = None
@property
def _query(self):
if self._mongo_query is None:
self._mongo_query = self._query_obj.to_query(self._document)
self._mongo_query.update(self._initial_query)
return self._mongo_query
def ensure_index(self, key_or_list, drop_dups=False, background=False, def ensure_index(self, key_or_list, drop_dups=False, background=False,
**kwargs): **kwargs):
"""Ensure that the given indexes are in place. """Ensure that the given indexes are in place.
@ -230,10 +359,14 @@ class QuerySet(object):
objects, only the last one will be used objects, only the last one will be used
:param query: Django-style query keyword arguments :param query: Django-style query keyword arguments
""" """
#if q_obj:
#self._where_clause = q_obj.as_js(self._document)
query = Q(**query)
if q_obj: if q_obj:
self._where_clause = q_obj.as_js(self._document) query &= q_obj
query = QuerySet._transform_query(_doc_cls=self._document, **query) self._query_obj &= query
self._query.update(query) self._mongo_query = None
self._cursor_obj = None
return self return self
def filter(self, *q_objs, **query): def filter(self, *q_objs, **query):
@ -285,9 +418,12 @@ class QuerySet(object):
@property @property
def _cursor(self): def _cursor(self):
if self._cursor_obj is None: if self._cursor_obj is None:
cursor_args = {} cursor_args = {
'snapshot': self._snapshot,
'timeout': self._timeout,
}
if self._loaded_fields: if self._loaded_fields:
cursor_args = {'fields': self._loaded_fields} cursor_args['fields'] = self._loaded_fields
self._cursor_obj = self._collection.find(self._query, self._cursor_obj = self._collection.find(self._query,
**cursor_args) **cursor_args)
# Apply where clauses to cursor # Apply where clauses to cursor
@ -335,7 +471,7 @@ class QuerySet(object):
"""Transform a query from Django-style format to Mongo format. """Transform a query from Django-style format to Mongo format.
""" """
operators = ['ne', 'gt', 'gte', 'lt', 'lte', 'in', 'nin', 'mod', operators = ['ne', 'gt', 'gte', 'lt', 'lte', 'in', 'nin', 'mod',
'all', 'size', 'exists'] 'all', 'size', 'exists', 'not']
geo_operators = ['within_distance', 'within_box', 'near'] geo_operators = ['within_distance', 'within_box', 'near']
match_operators = ['contains', 'icontains', 'startswith', match_operators = ['contains', 'icontains', 'startswith',
'istartswith', 'endswith', 'iendswith', 'istartswith', 'endswith', 'iendswith',
@ -343,6 +479,10 @@ class QuerySet(object):
mongo_query = {} mongo_query = {}
for key, value in query.items(): for key, value in query.items():
if key == "__raw__":
mongo_query.update(value)
continue
parts = key.split('__') parts = key.split('__')
indices = [(i, p) for i, p in enumerate(parts) if p.isdigit()] indices = [(i, p) for i, p in enumerate(parts) if p.isdigit()]
parts = [part for part in parts if not part.isdigit()] parts = [part for part in parts if not part.isdigit()]
@ -351,6 +491,11 @@ class QuerySet(object):
if parts[-1] in operators + match_operators + geo_operators: if parts[-1] in operators + match_operators + geo_operators:
op = parts.pop() op = parts.pop()
negate = False
if parts[-1] == 'not':
parts.pop()
negate = True
if _doc_cls: if _doc_cls:
# Switch field names to proper names [set in Field(name='foo')] # Switch field names to proper names [set in Field(name='foo')]
fields = QuerySet._lookup_field(_doc_cls, parts) fields = QuerySet._lookup_field(_doc_cls, parts)
@ -358,7 +503,7 @@ class QuerySet(object):
# Convert value to proper value # Convert value to proper value
field = fields[-1] field = fields[-1]
singular_ops = [None, 'ne', 'gt', 'gte', 'lt', 'lte'] singular_ops = [None, 'ne', 'gt', 'gte', 'lt', 'lte', 'not']
singular_ops += match_operators singular_ops += match_operators
if op in singular_ops: if op in singular_ops:
value = field.prepare_query_value(op, value) value = field.prepare_query_value(op, value)
@ -366,9 +511,6 @@ class QuerySet(object):
# 'in', 'nin' and 'all' require a list of values # 'in', 'nin' and 'all' require a list of values
value = [field.prepare_query_value(op, v) for v in value] value = [field.prepare_query_value(op, v) for v in value]
if field.__class__.__name__ == 'GenericReferenceField':
parts.append('_ref')
# if op and op not in match_operators: # if op and op not in match_operators:
if op: if op:
if op in geo_operators: if op in geo_operators:
@ -383,7 +525,10 @@ class QuerySet(object):
"been implemented" % op) "been implemented" % op)
elif op not in match_operators: elif op not in match_operators:
value = {'$' + op: value} value = {'$' + op: value}
if negate:
value = {'$not': value}
for i, part in indices: for i, part in indices:
parts.insert(i, part) parts.insert(i, part)
key = '.'.join(parts) key = '.'.join(parts)
@ -445,6 +590,7 @@ class QuerySet(object):
def create(self, **kwargs): def create(self, **kwargs):
"""Create new object. Returns the saved object instance. """Create new object. Returns the saved object instance.
.. versionadded:: 0.4 .. versionadded:: 0.4
""" """
doc = self._document(**kwargs) doc = self._document(**kwargs)
@ -641,6 +787,7 @@ class QuerySet(object):
# Integer index provided # Integer index provided
elif isinstance(key, int): elif isinstance(key, int):
return self._document._from_son(self._cursor[key]) return self._document._from_son(self._cursor[key])
raise AttributeError
def distinct(self, field): def distinct(self, field):
"""Return a list of distinct values for a given field. """Return a list of distinct values for a given field.
@ -649,7 +796,7 @@ class QuerySet(object):
.. versionadded:: 0.4 .. versionadded:: 0.4
""" """
return self._collection.distinct(field) return self._cursor.distinct(field)
def only(self, *fields): def only(self, *fields):
"""Load only a subset of this document's fields. :: """Load only a subset of this document's fields. ::
@ -709,6 +856,20 @@ class QuerySet(object):
plan = pprint.pformat(plan) plan = pprint.pformat(plan)
return plan return plan
def snapshot(self, enabled):
"""Enable or disable snapshot mode when querying.
:param enabled: whether or not snapshot mode is enabled
"""
self._snapshot = enabled
def timeout(self, enabled):
"""Enable or disable the default mongod timeout when querying.
:param enabled: whether or not the timeout is used
"""
self._timeout = enabled
def delete(self, safe=False): def delete(self, safe=False):
"""Delete the documents matched by the query. """Delete the documents matched by the query.
@ -721,7 +882,7 @@ class QuerySet(object):
"""Transform an update spec from Django-style format to Mongo format. """Transform an update spec from Django-style format to Mongo format.
""" """
operators = ['set', 'unset', 'inc', 'dec', 'pop', 'push', 'push_all', operators = ['set', 'unset', 'inc', 'dec', 'pop', 'push', 'push_all',
'pull', 'pull_all'] 'pull', 'pull_all', 'add_to_set']
mongo_update = {} mongo_update = {}
for key, value in update.items(): for key, value in update.items():
@ -739,6 +900,8 @@ class QuerySet(object):
op = 'inc' op = 'inc'
if value > 0: if value > 0:
value = -value value = -value
elif op == 'add_to_set':
op = op.replace('_to_set', 'ToSet')
if _doc_cls: if _doc_cls:
# Switch field names to proper names [set in Field(name='foo')] # Switch field names to proper names [set in Field(name='foo')]
@ -747,7 +910,8 @@ class QuerySet(object):
# Convert value to proper value # Convert value to proper value
field = fields[-1] field = fields[-1]
if op in (None, 'set', 'unset', 'pop', 'push', 'pull'): if op in (None, 'set', 'unset', 'pop', 'push', 'pull',
'addToSet'):
value = field.prepare_query_value(op, value) value = field.prepare_query_value(op, value)
elif op in ('pushAll', 'pullAll'): elif op in ('pushAll', 'pullAll'):
value = [field.prepare_query_value(op, v) for v in value] value = [field.prepare_query_value(op, v) for v in value]
@ -913,20 +1077,27 @@ class QuerySet(object):
""" """
return self.exec_js(average_func, field) return self.exec_js(average_func, field)
def item_frequencies(self, list_field, normalize=False): def item_frequencies(self, field, normalize=False):
"""Returns a dictionary of all items present in a list field across """Returns a dictionary of all items present in a field across
the whole queried set of documents, and their corresponding frequency. the whole queried set of documents, and their corresponding frequency.
This is useful for generating tag clouds, or searching documents. This is useful for generating tag clouds, or searching documents.
:param list_field: the list field to use If the field is a :class:`~mongoengine.ListField`, the items within
each list will be counted individually.
:param field: the field to use
:param normalize: normalize the results so they add to 1.0 :param normalize: normalize the results so they add to 1.0
""" """
freq_func = """ freq_func = """
function(listField) { function(field) {
if (options.normalize) { if (options.normalize) {
var total = 0.0; var total = 0.0;
db[collection].find(query).forEach(function(doc) { db[collection].find(query).forEach(function(doc) {
total += doc[listField].length; if (doc[field].constructor == Array) {
total += doc[field].length;
} else {
total++;
}
}); });
} }
@ -936,14 +1107,19 @@ class QuerySet(object):
inc /= total; inc /= total;
} }
db[collection].find(query).forEach(function(doc) { db[collection].find(query).forEach(function(doc) {
doc[listField].forEach(function(item) { if (doc[field].constructor == Array) {
doc[field].forEach(function(item) {
frequencies[item] = inc + (frequencies[item] || 0);
});
} else {
var item = doc[field];
frequencies[item] = inc + (frequencies[item] || 0); frequencies[item] = inc + (frequencies[item] || 0);
}); }
}); });
return frequencies; return frequencies;
} }
""" """
return self.exec_js(freq_func, list_field, normalize=normalize) return self.exec_js(freq_func, field, normalize=normalize)
def __repr__(self): def __repr__(self):
limit = REPR_OUTPUT_SIZE + 1 limit = REPR_OUTPUT_SIZE + 1
@ -959,7 +1135,7 @@ class QuerySetManager(object):
def __init__(self, manager_func=None): def __init__(self, manager_func=None):
self._manager_func = manager_func self._manager_func = manager_func
self._collection = None self._collections = {}
def __get__(self, instance, owner): def __get__(self, instance, owner):
"""Descriptor for instantiating a new QuerySet object when """Descriptor for instantiating a new QuerySet object when
@ -969,10 +1145,9 @@ class QuerySetManager(object):
# Document class being used rather than a document object # Document class being used rather than a document object
return self return self
if self._collection is None: db = _get_db()
db = _get_db() collection = owner._meta['collection']
collection = owner._meta['collection'] if (db, collection) not in self._collections:
# Create collection as a capped collection if specified # Create collection as a capped collection if specified
if owner._meta['max_size'] or owner._meta['max_documents']: if owner._meta['max_size'] or owner._meta['max_documents']:
# Get max document limit and max byte size from meta # Get max document limit and max byte size from meta
@ -980,10 +1155,10 @@ class QuerySetManager(object):
max_documents = owner._meta['max_documents'] max_documents = owner._meta['max_documents']
if collection in db.collection_names(): if collection in db.collection_names():
self._collection = db[collection] self._collections[(db, collection)] = db[collection]
# The collection already exists, check if its capped # The collection already exists, check if its capped
# options match the specified capped options # options match the specified capped options
options = self._collection.options() options = self._collections[(db, collection)].options()
if options.get('max') != max_documents or \ if options.get('max') != max_documents or \
options.get('size') != max_size: options.get('size') != max_size:
msg = ('Cannot create collection "%s" as a capped ' msg = ('Cannot create collection "%s" as a capped '
@ -994,13 +1169,15 @@ class QuerySetManager(object):
opts = {'capped': True, 'size': max_size} opts = {'capped': True, 'size': max_size}
if max_documents: if max_documents:
opts['max'] = max_documents opts['max'] = max_documents
self._collection = db.create_collection(collection, **opts) self._collections[(db, collection)] = db.create_collection(
collection, **opts
)
else: else:
self._collection = db[collection] self._collections[(db, collection)] = db[collection]
# owner is the document that contains the QuerySetManager # owner is the document that contains the QuerySetManager
queryset_class = owner._meta['queryset_class'] or QuerySet queryset_class = owner._meta['queryset_class'] or QuerySet
queryset = queryset_class(owner, self._collection) queryset = queryset_class(owner, self._collections[(db, collection)])
if self._manager_func: if self._manager_func:
if self._manager_func.func_code.co_argcount == 1: if self._manager_func.func_code.co_argcount == 1:
queryset = self._manager_func(queryset) queryset = self._manager_func(queryset)

View File

@ -200,6 +200,37 @@ class DocumentTest(unittest.TestCase):
Person.drop_collection() Person.drop_collection()
self.assertFalse(collection in self.db.collection_names()) self.assertFalse(collection in self.db.collection_names())
def test_inherited_collections(self):
"""Ensure that subclassed documents don't override parents' collections.
"""
class Drink(Document):
name = StringField()
class AlcoholicDrink(Drink):
meta = {'collection': 'booze'}
class Drinker(Document):
drink = GenericReferenceField()
Drink.drop_collection()
AlcoholicDrink.drop_collection()
Drinker.drop_collection()
red_bull = Drink(name='Red Bull')
red_bull.save()
programmer = Drinker(drink=red_bull)
programmer.save()
beer = AlcoholicDrink(name='Beer')
beer.save()
real_person = Drinker(drink=beer)
real_person.save()
self.assertEqual(Drinker.objects[0].drink.name, red_bull.name)
self.assertEqual(Drinker.objects[1].drink.name, beer.name)
def test_capped_collection(self): def test_capped_collection(self):
"""Ensure that capped collections work properly. """Ensure that capped collections work properly.
""" """

View File

@ -189,6 +189,9 @@ class FieldTest(unittest.TestCase):
def test_list_validation(self): def test_list_validation(self):
"""Ensure that a list field only accepts lists with valid elements. """Ensure that a list field only accepts lists with valid elements.
""" """
class User(Document):
pass
class Comment(EmbeddedDocument): class Comment(EmbeddedDocument):
content = StringField() content = StringField()
@ -196,6 +199,7 @@ class FieldTest(unittest.TestCase):
content = StringField() content = StringField()
comments = ListField(EmbeddedDocumentField(Comment)) comments = ListField(EmbeddedDocumentField(Comment))
tags = ListField(StringField()) tags = ListField(StringField())
authors = ListField(ReferenceField(User))
post = BlogPost(content='Went for a walk today...') post = BlogPost(content='Went for a walk today...')
post.validate() post.validate()
@ -210,15 +214,21 @@ class FieldTest(unittest.TestCase):
post.tags = ('fun', 'leisure') post.tags = ('fun', 'leisure')
post.validate() post.validate()
comments = [Comment(content='Good for you'), Comment(content='Yay.')]
post.comments = comments
post.validate()
post.comments = ['a'] post.comments = ['a']
self.assertRaises(ValidationError, post.validate) self.assertRaises(ValidationError, post.validate)
post.comments = 'yay' post.comments = 'yay'
self.assertRaises(ValidationError, post.validate) self.assertRaises(ValidationError, post.validate)
comments = [Comment(content='Good for you'), Comment(content='Yay.')]
post.comments = comments
post.validate()
post.authors = [Comment()]
self.assertRaises(ValidationError, post.validate)
post.authors = [User()]
post.validate()
def test_sorted_list_sorting(self): def test_sorted_list_sorting(self):
"""Ensure that a sorted list field properly sorts values. """Ensure that a sorted list field properly sorts values.
""" """
@ -395,14 +405,54 @@ class FieldTest(unittest.TestCase):
class Employee(Document): class Employee(Document):
name = StringField() name = StringField()
boss = ReferenceField('self') boss = ReferenceField('self')
friends = ListField(ReferenceField('self'))
bill = Employee(name='Bill Lumbergh') bill = Employee(name='Bill Lumbergh')
bill.save() bill.save()
peter = Employee(name='Peter Gibbons', boss=bill)
michael = Employee(name='Michael Bolton')
michael.save()
samir = Employee(name='Samir Nagheenanajar')
samir.save()
friends = [michael, samir]
peter = Employee(name='Peter Gibbons', boss=bill, friends=friends)
peter.save() peter.save()
peter = Employee.objects.with_id(peter.id) peter = Employee.objects.with_id(peter.id)
self.assertEqual(peter.boss, bill) self.assertEqual(peter.boss, bill)
self.assertEqual(peter.friends, friends)
def test_recursive_embedding(self):
"""Ensure that EmbeddedDocumentFields can contain their own documents.
"""
class Tree(Document):
name = StringField()
children = ListField(EmbeddedDocumentField('TreeNode'))
class TreeNode(EmbeddedDocument):
name = StringField()
children = ListField(EmbeddedDocumentField('self'))
tree = Tree(name="Tree")
first_child = TreeNode(name="Child 1")
tree.children.append(first_child)
second_child = TreeNode(name="Child 2")
first_child.children.append(second_child)
third_child = TreeNode(name="Child 3")
first_child.children.append(third_child)
tree.save()
tree_obj = Tree.objects.first()
self.assertEqual(len(tree.children), 1)
self.assertEqual(tree.children[0].name, first_child.name)
self.assertEqual(tree.children[0].children[0].name, second_child.name)
self.assertEqual(tree.children[0].children[1].name, third_child.name)
def test_undefined_reference(self): def test_undefined_reference(self):
"""Ensure that ReferenceFields may reference undefined Documents. """Ensure that ReferenceFields may reference undefined Documents.

View File

@ -53,9 +53,6 @@ class QuerySetTest(unittest.TestCase):
person2 = self.Person(name="User B", age=30) person2 = self.Person(name="User B", age=30)
person2.save() person2.save()
q1 = Q(name='test')
q2 = Q(age__gte=18)
# Find all people in the collection # Find all people in the collection
people = self.Person.objects people = self.Person.objects
self.assertEqual(len(people), 2) self.assertEqual(len(people), 2)
@ -156,7 +153,8 @@ class QuerySetTest(unittest.TestCase):
# Retrieve the first person from the database # Retrieve the first person from the database
self.assertRaises(MultipleObjectsReturned, self.Person.objects.get) self.assertRaises(MultipleObjectsReturned, self.Person.objects.get)
self.assertRaises(self.Person.MultipleObjectsReturned, self.Person.objects.get) self.assertRaises(self.Person.MultipleObjectsReturned,
self.Person.objects.get)
# Use a query to filter the people found to just person2 # Use a query to filter the people found to just person2
person = self.Person.objects.get(age=30) person = self.Person.objects.get(age=30)
@ -234,7 +232,8 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(created, False) self.assertEqual(created, False)
# Try retrieving when no objects exists - new doc should be created # Try retrieving when no objects exists - new doc should be created
person, created = self.Person.objects.get_or_create(age=50, defaults={'name': 'User C'}) kwargs = dict(age=50, defaults={'name': 'User C'})
person, created = self.Person.objects.get_or_create(**kwargs)
self.assertEqual(created, True) self.assertEqual(created, True)
person = self.Person.objects.get(age=50) person = self.Person.objects.get(age=50)
@ -337,6 +336,18 @@ class QuerySetTest(unittest.TestCase):
obj = self.Person.objects(Q(name__icontains='[.\'Geek')).first() obj = self.Person.objects(Q(name__icontains='[.\'Geek')).first()
self.assertEqual(obj, person) self.assertEqual(obj, person)
def test_not(self):
"""Ensure that the __not operator works as expected.
"""
alice = self.Person(name='Alice', age=25)
alice.save()
obj = self.Person.objects(name__iexact='alice').first()
self.assertEqual(obj, alice)
obj = self.Person.objects(name__not__iexact='alice').first()
self.assertEqual(obj, None)
def test_filter_chaining(self): def test_filter_chaining(self):
"""Ensure filters can be chained together. """Ensure filters can be chained together.
""" """
@ -546,9 +557,10 @@ class QuerySetTest(unittest.TestCase):
obj = self.Person.objects(Q(name=re.compile('^gui', re.I))).first() obj = self.Person.objects(Q(name=re.compile('^gui', re.I))).first()
self.assertEqual(obj, person) self.assertEqual(obj, person)
obj = self.Person.objects(Q(name__ne=re.compile('^bob'))).first() obj = self.Person.objects(Q(name__not=re.compile('^bob'))).first()
self.assertEqual(obj, person) self.assertEqual(obj, person)
obj = self.Person.objects(Q(name__ne=re.compile('^Gui'))).first()
obj = self.Person.objects(Q(name__not=re.compile('^Gui'))).first()
self.assertEqual(obj, None) self.assertEqual(obj, None)
def test_q_lists(self): def test_q_lists(self):
@ -717,6 +729,11 @@ class QuerySetTest(unittest.TestCase):
post.reload() post.reload()
self.assertEqual(post.tags, tags) self.assertEqual(post.tags, tags)
BlogPost.objects.update_one(add_to_set__tags='unique')
BlogPost.objects.update_one(add_to_set__tags='unique')
post.reload()
self.assertEqual(post.tags.count('unique'), 1)
BlogPost.drop_collection() BlogPost.drop_collection()
def test_update_pull(self): def test_update_pull(self):
@ -968,7 +985,7 @@ class QuerySetTest(unittest.TestCase):
BlogPost(hits=1, tags=['music', 'film', 'actors']).save() BlogPost(hits=1, tags=['music', 'film', 'actors']).save()
BlogPost(hits=2, tags=['music']).save() BlogPost(hits=2, tags=['music']).save()
BlogPost(hits=3, tags=['music', 'actors']).save() BlogPost(hits=2, tags=['music', 'actors']).save()
f = BlogPost.objects.item_frequencies('tags') f = BlogPost.objects.item_frequencies('tags')
f = dict((key, int(val)) for key, val in f.items()) f = dict((key, int(val)) for key, val in f.items())
@ -990,6 +1007,13 @@ class QuerySetTest(unittest.TestCase):
self.assertAlmostEqual(f['actors'], 2.0/6.0) self.assertAlmostEqual(f['actors'], 2.0/6.0)
self.assertAlmostEqual(f['film'], 1.0/6.0) self.assertAlmostEqual(f['film'], 1.0/6.0)
# Check item_frequencies works for non-list fields
f = BlogPost.objects.item_frequencies('hits')
f = dict((key, int(val)) for key, val in f.items())
self.assertEqual(set(['1', '2']), set(f.keys()))
self.assertEqual(f['1'], 1)
self.assertEqual(f['2'], 2)
BlogPost.drop_collection() BlogPost.drop_collection()
def test_average(self): def test_average(self):
@ -1026,9 +1050,13 @@ class QuerySetTest(unittest.TestCase):
self.Person(name='Mr Orange', age=20).save() self.Person(name='Mr Orange', age=20).save()
self.Person(name='Mr White', age=20).save() self.Person(name='Mr White', age=20).save()
self.Person(name='Mr Orange', age=30).save() self.Person(name='Mr Orange', age=30).save()
self.assertEqual(self.Person.objects.distinct('name'), self.Person(name='Mr Pink', age=30).save()
['Mr Orange', 'Mr White']) self.assertEqual(set(self.Person.objects.distinct('name')),
self.assertEqual(self.Person.objects.distinct('age'), [20, 30]) set(['Mr Orange', 'Mr White', 'Mr Pink']))
self.assertEqual(set(self.Person.objects.distinct('age')),
set([20, 30]))
self.assertEqual(set(self.Person.objects(age=30).distinct('name')),
set(['Mr Orange', 'Mr Pink']))
def test_custom_manager(self): def test_custom_manager(self):
"""Ensure that custom QuerySetManager instances work as expected. """Ensure that custom QuerySetManager instances work as expected.
@ -1330,43 +1358,6 @@ class QuerySetTest(unittest.TestCase):
class QTest(unittest.TestCase): class QTest(unittest.TestCase):
def test_or_and(self):
"""Ensure that Q objects may be combined correctly.
"""
q1 = Q(name='test')
q2 = Q(age__gte=18)
query = ['(', {'name': 'test'}, '||', {'age__gte': 18}, ')']
self.assertEqual((q1 | q2).query, query)
query = ['(', {'name': 'test'}, '&&', {'age__gte': 18}, ')']
self.assertEqual((q1 & q2).query, query)
query = ['(', '(', {'name': 'test'}, '&&', {'age__gte': 18}, ')', '||',
{'name': 'example'}, ')']
self.assertEqual((q1 & q2 | Q(name='example')).query, query)
def test_item_query_as_js(self):
"""Ensure that the _item_query_as_js utilitiy method works properly.
"""
q = Q()
examples = [
({'name': 'test'}, ('((this.name instanceof Array) && '
'this.name.indexOf(i0f0) != -1) || this.name == i0f0'),
{'i0f0': 'test'}),
({'age': {'$gt': 18}}, 'this.age > i0f0o0', {'i0f0o0': 18}),
({'name': 'test', 'age': {'$gt': 18, '$lte': 65}},
('this.age <= i0f0o0 && this.age > i0f0o1 && '
'((this.name instanceof Array) && '
'this.name.indexOf(i0f1) != -1) || this.name == i0f1'),
{'i0f0o0': 65, 'i0f0o1': 18, 'i0f1': 'test'}),
]
for item, js, scope in examples:
test_scope = {}
self.assertEqual(q._item_query_as_js(item, test_scope, 0), js)
self.assertEqual(scope, test_scope)
def test_empty_q(self): def test_empty_q(self):
"""Ensure that empty Q objects won't hurt. """Ensure that empty Q objects won't hurt.
""" """
@ -1376,11 +1367,15 @@ class QTest(unittest.TestCase):
q4 = Q(name='test') q4 = Q(name='test')
q5 = Q() q5 = Q()
query = ['(', {'age__gte': 18}, '||', {'name': 'test'}, ')'] class Person(Document):
self.assertEqual((q1 | q2 | q3 | q4 | q5).query, query) name = StringField()
age = IntField()
query = ['(', {'age__gte': 18}, '&&', {'name': 'test'}, ')'] query = {'$or': [{'age': {'$gte': 18}}, {'name': 'test'}]}
self.assertEqual((q1 & q2 & q3 & q4 & q5).query, query) self.assertEqual((q1 | q2 | q3 | q4 | q5).to_query(Person), query)
query = {'age': {'$gte': 18}, 'name': 'test'}
self.assertEqual((q1 & q2 & q3 & q4 & q5).to_query(Person), query)
def test_q_with_dbref(self): def test_q_with_dbref(self):
"""Ensure Q objects handle DBRefs correctly""" """Ensure Q objects handle DBRefs correctly"""
@ -1398,6 +1393,105 @@ class QTest(unittest.TestCase):
self.assertEqual(Post.objects.filter(created_user=user).count(), 1) self.assertEqual(Post.objects.filter(created_user=user).count(), 1)
self.assertEqual(Post.objects.filter(Q(created_user=user)).count(), 1) self.assertEqual(Post.objects.filter(Q(created_user=user)).count(), 1)
def test_and_combination(self):
"""Ensure that Q-objects correctly AND together.
"""
class TestDoc(Document):
x = IntField()
y = StringField()
# Check than an error is raised when conflicting queries are anded
def invalid_combination():
query = Q(x__lt=7) & Q(x__lt=3)
query.to_query(TestDoc)
self.assertRaises(InvalidQueryError, invalid_combination)
# Check normal cases work without an error
query = Q(x__lt=7) & Q(x__gt=3)
q1 = Q(x__lt=7)
q2 = Q(x__gt=3)
query = (q1 & q2).to_query(TestDoc)
self.assertEqual(query, {'x': {'$lt': 7, '$gt': 3}})
# More complex nested example
query = Q(x__lt=100) & Q(y__ne='NotMyString')
query &= Q(y__in=['a', 'b', 'c']) & Q(x__gt=-100)
mongo_query = {
'x': {'$lt': 100, '$gt': -100},
'y': {'$ne': 'NotMyString', '$in': ['a', 'b', 'c']},
}
self.assertEqual(query.to_query(TestDoc), mongo_query)
def test_or_combination(self):
"""Ensure that Q-objects correctly OR together.
"""
class TestDoc(Document):
x = IntField()
q1 = Q(x__lt=3)
q2 = Q(x__gt=7)
query = (q1 | q2).to_query(TestDoc)
self.assertEqual(query, {
'$or': [
{'x': {'$lt': 3}},
{'x': {'$gt': 7}},
]
})
def test_and_or_combination(self):
"""Ensure that Q-objects handle ANDing ORed components.
"""
class TestDoc(Document):
x = IntField()
y = BooleanField()
query = (Q(x__gt=0) | Q(x__exists=False))
query &= Q(x__lt=100)
self.assertEqual(query.to_query(TestDoc), {
'$or': [
{'x': {'$lt': 100, '$gt': 0}},
{'x': {'$lt': 100, '$exists': False}},
]
})
q1 = (Q(x__gt=0) | Q(x__exists=False))
q2 = (Q(x__lt=100) | Q(y=True))
query = (q1 & q2).to_query(TestDoc)
self.assertEqual(['$or'], query.keys())
conditions = [
{'x': {'$lt': 100, '$gt': 0}},
{'x': {'$lt': 100, '$exists': False}},
{'x': {'$gt': 0}, 'y': True},
{'x': {'$exists': False}, 'y': True},
]
self.assertEqual(len(conditions), len(query['$or']))
for condition in conditions:
self.assertTrue(condition in query['$or'])
def test_or_and_or_combination(self):
"""Ensure that Q-objects handle ORing ANDed ORed components. :)
"""
class TestDoc(Document):
x = IntField()
y = BooleanField()
q1 = (Q(x__gt=0) & (Q(y=True) | Q(y__exists=False)))
q2 = (Q(x__lt=100) & (Q(y=False) | Q(y__exists=False)))
query = (q1 | q2).to_query(TestDoc)
self.assertEqual(['$or'], query.keys())
conditions = [
{'x': {'$gt': 0}, 'y': True},
{'x': {'$gt': 0}, 'y': {'$exists': False}},
{'x': {'$lt': 100}, 'y':False},
{'x': {'$lt': 100}, 'y': {'$exists': False}},
]
self.assertEqual(len(conditions), len(query['$or']))
for condition in conditions:
self.assertTrue(condition in query['$or'])
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()