Updated for pymongo

This commit is contained in:
Ross Lawley 2012-07-20 16:17:35 +01:00
parent 69989365c7
commit 3da37fbf6e
7 changed files with 158 additions and 138 deletions

View File

@ -12,7 +12,7 @@ __all__ = (document.__all__ + fields.__all__ + connection.__all__ +
__author__ = 'Harry Marr' __author__ = 'Harry Marr'
VERSION = (0, 4, 0) VERSION = (0, 4, 1)
def get_version(): def get_version():
version = '%s.%s' % (VERSION[0], VERSION[1]) version = '%s.%s' % (VERSION[0], VERSION[1])

View File

@ -2,8 +2,8 @@ from queryset import QuerySet, QuerySetManager
from queryset import DoesNotExist, MultipleObjectsReturned from queryset import DoesNotExist, MultipleObjectsReturned
import sys import sys
import bson
import pymongo import pymongo
import pymongo.objectid
_document_registry = {} _document_registry = {}
@ -21,11 +21,11 @@ class BaseField(object):
may be added to subclasses of `Document` to define a document's schema. may be added to subclasses of `Document` to define a document's schema.
""" """
# Fields may have _types inserted into indexes by default # Fields may have _types inserted into indexes by default
_index_with_types = True _index_with_types = True
_geo_index = False _geo_index = False
def __init__(self, db_field=None, name=None, required=False, default=None, def __init__(self, db_field=None, name=None, required=False, default=None,
unique=False, unique_with=None, primary_key=False, unique=False, unique_with=None, primary_key=False,
validation=None, choices=None): validation=None, choices=None):
self.db_field = (db_field or name) if not primary_key else '_id' self.db_field = (db_field or name) if not primary_key else '_id'
@ -43,7 +43,7 @@ class BaseField(object):
self.choices = choices self.choices = choices
def __get__(self, instance, owner): def __get__(self, instance, owner):
"""Descriptor for retrieving a value from a field in a document. Do """Descriptor for retrieving a value from a field in a document. Do
any necessary conversion between Python and MongoDB types. any necessary conversion between Python and MongoDB types.
""" """
if instance is None: if instance is None:
@ -111,9 +111,9 @@ class ObjectIdField(BaseField):
# return unicode(value) # return unicode(value)
def to_mongo(self, value): def to_mongo(self, value):
if not isinstance(value, pymongo.objectid.ObjectId): if not isinstance(value, bson.objectid.ObjectId):
try: try:
return pymongo.objectid.ObjectId(unicode(value)) return bson.objectid.ObjectId(unicode(value))
except Exception, e: except Exception, e:
#e.message attribute has been deprecated since Python 2.6 #e.message attribute has been deprecated since Python 2.6
raise ValidationError(unicode(e)) raise ValidationError(unicode(e))
@ -124,7 +124,7 @@ class ObjectIdField(BaseField):
def validate(self, value): def validate(self, value):
try: try:
pymongo.objectid.ObjectId(unicode(value)) bson.objectid.ObjectId(unicode(value))
except: except:
raise ValidationError('Invalid Object ID') raise ValidationError('Invalid Object ID')
@ -153,8 +153,8 @@ class DocumentMetaclass(type):
superclasses.update(base._superclasses) superclasses.update(base._superclasses)
if hasattr(base, '_meta'): if hasattr(base, '_meta'):
# Ensure that the Document class may be subclassed - # Ensure that the Document class may be subclassed -
# inheritance may be disabled to remove dependency on # inheritance may be disabled to remove dependency on
# additional fields _cls and _types # additional fields _cls and _types
if base._meta.get('allow_inheritance', True) == False: if base._meta.get('allow_inheritance', True) == False:
raise ValueError('Document %s may not be subclassed' % raise ValueError('Document %s may not be subclassed' %
@ -193,12 +193,12 @@ class DocumentMetaclass(type):
module = attrs.get('__module__') module = attrs.get('__module__')
base_excs = tuple(base.DoesNotExist for base in bases base_excs = tuple(base.DoesNotExist for base in bases
if hasattr(base, 'DoesNotExist')) or (DoesNotExist,) if hasattr(base, 'DoesNotExist')) or (DoesNotExist,)
exc = subclass_exception('DoesNotExist', base_excs, module) exc = subclass_exception('DoesNotExist', base_excs, module)
new_class.add_to_class('DoesNotExist', exc) new_class.add_to_class('DoesNotExist', exc)
base_excs = tuple(base.MultipleObjectsReturned for base in bases base_excs = tuple(base.MultipleObjectsReturned for base in bases
if hasattr(base, 'MultipleObjectsReturned')) if hasattr(base, 'MultipleObjectsReturned'))
base_excs = base_excs or (MultipleObjectsReturned,) base_excs = base_excs or (MultipleObjectsReturned,)
exc = subclass_exception('MultipleObjectsReturned', base_excs, module) exc = subclass_exception('MultipleObjectsReturned', base_excs, module)
@ -220,9 +220,9 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
def __new__(cls, name, bases, attrs): def __new__(cls, name, bases, attrs):
super_new = super(TopLevelDocumentMetaclass, cls).__new__ super_new = super(TopLevelDocumentMetaclass, cls).__new__
# Classes defined in this package are abstract and should not have # Classes defined in this package are abstract and should not have
# their own metadata with DB collection, etc. # their own metadata with DB collection, etc.
# __metaclass__ is only set on the class with the __metaclass__ # __metaclass__ is only set on the class with the __metaclass__
# attribute (i.e. it is not set on subclasses). This differentiates # attribute (i.e. it is not set on subclasses). This differentiates
# 'real' documents from the 'Document' class # 'real' documents from the 'Document' class
if attrs.get('__metaclass__') == TopLevelDocumentMetaclass: if attrs.get('__metaclass__') == TopLevelDocumentMetaclass:
@ -347,7 +347,7 @@ class BaseDocument(object):
are present. are present.
""" """
# Get a list of tuples of field names and their current values # Get a list of tuples of field names and their current values
fields = [(field, getattr(self, name)) fields = [(field, getattr(self, name))
for name, field in self._fields.items()] for name, field in self._fields.items()]
# Ensure that each field is matched to a valid value # Ensure that each field is matched to a valid value

View File

@ -40,16 +40,16 @@ class Document(BaseDocument):
presence of `_cls` and `_types`, set :attr:`allow_inheritance` to presence of `_cls` and `_types`, set :attr:`allow_inheritance` to
``False`` in the :attr:`meta` dictionary. ``False`` in the :attr:`meta` dictionary.
A :class:`~mongoengine.Document` may use a **Capped Collection** by A :class:`~mongoengine.Document` may use a **Capped Collection** by
specifying :attr:`max_documents` and :attr:`max_size` in the :attr:`meta` specifying :attr:`max_documents` and :attr:`max_size` in the :attr:`meta`
dictionary. :attr:`max_documents` is the maximum number of documents that dictionary. :attr:`max_documents` is the maximum number of documents that
is allowed to be stored in the collection, and :attr:`max_size` is the is allowed to be stored in the collection, and :attr:`max_size` is the
maximum size of the collection in bytes. If :attr:`max_size` is not maximum size of the collection in bytes. If :attr:`max_size` is not
specified and :attr:`max_documents` is, :attr:`max_size` defaults to specified and :attr:`max_documents` is, :attr:`max_size` defaults to
10000000 bytes (10MB). 10000000 bytes (10MB).
Indexes may be created by specifying :attr:`indexes` in the :attr:`meta` Indexes may be created by specifying :attr:`indexes` in the :attr:`meta`
dictionary. The value should be a list of field names or tuples of field dictionary. The value should be a list of field names or tuples of field
names. Index direction may be specified by prefixing the field names with names. Index direction may be specified by prefixing the field names with
a **+** or **-** sign. a **+** or **-** sign.
""" """
@ -61,11 +61,11 @@ class Document(BaseDocument):
document already exists, it will be updated, otherwise it will be document already exists, it will be updated, otherwise it will be
created. created.
If ``safe=True`` and the operation is unsuccessful, an If ``safe=True`` and the operation is unsuccessful, an
:class:`~mongoengine.OperationError` will be raised. :class:`~mongoengine.OperationError` will be raised.
:param safe: check if the operation succeeded before returning :param safe: check if the operation succeeded before returning
:param force_insert: only try to create a new document, don't allow :param force_insert: only try to create a new document, don't allow
updates of existing documents updates of existing documents
:param validate: validates the document; set to ``False`` for skiping :param validate: validates the document; set to ``False`` for skiping
""" """
@ -123,9 +123,9 @@ class MapReduceDocument(object):
"""A document returned from a map/reduce query. """A document returned from a map/reduce query.
:param collection: An instance of :class:`~pymongo.Collection` :param collection: An instance of :class:`~pymongo.Collection`
:param key: Document/result key, often an instance of :param key: Document/result key, often an instance of
:class:`~pymongo.objectid.ObjectId`. If supplied as :class:`~bson.objectid.ObjectId`. If supplied as
an ``ObjectId`` found in the given ``collection``, an ``ObjectId`` found in the given ``collection``,
the object can be accessed via the ``object`` property. the object can be accessed via the ``object`` property.
:param value: The result(s) for this key. :param value: The result(s) for this key.
@ -140,7 +140,7 @@ class MapReduceDocument(object):
@property @property
def object(self): def object(self):
"""Lazy-load the object referenced by ``self.key``. ``self.key`` """Lazy-load the object referenced by ``self.key``. ``self.key``
should be the ``primary_key``. should be the ``primary_key``.
""" """
id_field = self._document()._meta['id_field'] id_field = self._document()._meta['id_field']

View File

@ -5,9 +5,9 @@ from operator import itemgetter
import re import re
import pymongo import pymongo
import pymongo.dbref import bson.dbref
import pymongo.son import bson.son
import pymongo.binary import bson.binary
import datetime import datetime
import decimal import decimal
import gridfs import gridfs
@ -300,13 +300,13 @@ class ListField(BaseField):
if isinstance(self.field, ReferenceField): if isinstance(self.field, ReferenceField):
referenced_type = self.field.document_type referenced_type = self.field.document_type
# Get value from document instance if available # Get value from document instance if available
value_list = instance._data.get(self.name) value_list = instance._data.get(self.name)
if value_list: if value_list:
deref_list = [] deref_list = []
for value in value_list: for value in value_list:
# Dereference DBRefs # Dereference DBRefs
if isinstance(value, (pymongo.dbref.DBRef)): if isinstance(value, (bson.dbref.DBRef)):
value = _get_db().dereference(value) value = _get_db().dereference(value)
deref_list.append(referenced_type._from_son(value)) deref_list.append(referenced_type._from_son(value))
else: else:
@ -319,7 +319,7 @@ class ListField(BaseField):
deref_list = [] deref_list = []
for value in value_list: for value in value_list:
# Dereference DBRefs # Dereference DBRefs
if isinstance(value, (dict, pymongo.son.SON)): if isinstance(value, (dict, bson.son.SON)):
deref_list.append(self.field.dereference(value)) deref_list.append(self.field.dereference(value))
else: else:
deref_list.append(value) deref_list.append(value)
@ -444,7 +444,7 @@ class ReferenceField(BaseField):
# Get value from document instance if available # Get value from document instance if available
value = instance._data.get(self.name) value = instance._data.get(self.name)
# Dereference DBRefs # Dereference DBRefs
if isinstance(value, (pymongo.dbref.DBRef)): if isinstance(value, (bson.dbref.DBRef)):
value = _get_db().dereference(value) value = _get_db().dereference(value)
if value is not None: if value is not None:
instance._data[self.name] = self.document_type._from_son(value) instance._data[self.name] = self.document_type._from_son(value)
@ -466,13 +466,13 @@ class ReferenceField(BaseField):
id_ = id_field.to_mongo(id_) id_ = id_field.to_mongo(id_)
collection = self.document_type._meta['collection'] collection = self.document_type._meta['collection']
return pymongo.dbref.DBRef(collection, id_) return bson.dbref.DBRef(collection, id_)
def prepare_query_value(self, op, value): def prepare_query_value(self, op, value):
return self.to_mongo(value) return self.to_mongo(value)
def validate(self, value): def validate(self, value):
assert isinstance(value, (self.document_type, pymongo.dbref.DBRef)) assert isinstance(value, (self.document_type, bson.dbref.DBRef))
def lookup_member(self, member_name): def lookup_member(self, member_name):
return self.document_type._fields.get(member_name) return self.document_type._fields.get(member_name)
@ -490,7 +490,7 @@ class GenericReferenceField(BaseField):
return self return self
value = instance._data.get(self.name) value = instance._data.get(self.name)
if isinstance(value, (dict, pymongo.son.SON)): if isinstance(value, (dict, bson.son.SON)):
instance._data[self.name] = self.dereference(value) instance._data[self.name] = self.dereference(value)
return super(GenericReferenceField, self).__get__(instance, owner) return super(GenericReferenceField, self).__get__(instance, owner)
@ -518,7 +518,7 @@ class GenericReferenceField(BaseField):
id_ = id_field.to_mongo(id_) id_ = id_field.to_mongo(id_)
collection = document._meta['collection'] collection = document._meta['collection']
ref = pymongo.dbref.DBRef(collection, id_) ref = bson.dbref.DBRef(collection, id_)
return {'_cls': document.__class__.__name__, '_ref': ref} return {'_cls': document.__class__.__name__, '_ref': ref}
def prepare_query_value(self, op, value): def prepare_query_value(self, op, value):
@ -534,7 +534,7 @@ class BinaryField(BaseField):
super(BinaryField, self).__init__(**kwargs) super(BinaryField, self).__init__(**kwargs)
def to_mongo(self, value): def to_mongo(self, value):
return pymongo.binary.Binary(value) return bson.binary.Binary(value)
def to_python(self, value): def to_python(self, value):
# Returns str not unicode as this is binary data # Returns str not unicode as this is binary data
@ -603,7 +603,7 @@ class GridFSProxy(object):
if not self.newfile: if not self.newfile:
self.new_file() self.new_file()
self.grid_id = self.newfile._id self.grid_id = self.newfile._id
self.newfile.writelines(lines) self.newfile.writelines(lines)
def read(self): def read(self):
try: try:
@ -680,7 +680,7 @@ class FileField(BaseField):
def validate(self, value): def validate(self, value):
if value.grid_id is not None: if value.grid_id is not None:
assert isinstance(value, GridFSProxy) assert isinstance(value, GridFSProxy)
assert isinstance(value.grid_id, pymongo.objectid.ObjectId) assert isinstance(value.grid_id, bson.objectid.ObjectId)
class GeoPointField(BaseField): class GeoPointField(BaseField):

View File

@ -2,9 +2,9 @@ from connection import _get_db
import pprint import pprint
import pymongo import pymongo
import pymongo.code import bson.code
import pymongo.dbref import bson.dbref
import pymongo.objectid import bson.objectid
import re import re
import copy import copy
import itertools import itertools
@ -424,7 +424,7 @@ class QuerySet(object):
} }
if self._loaded_fields: if self._loaded_fields:
cursor_args['fields'] = self._loaded_fields cursor_args['fields'] = self._loaded_fields
self._cursor_obj = self._collection.find(self._query, self._cursor_obj = self._collection.find(self._query,
**cursor_args) **cursor_args)
# Apply where clauses to cursor # Apply where clauses to cursor
if self._where_clause: if self._where_clause:
@ -476,8 +476,8 @@ class QuerySet(object):
operators = ['ne', 'gt', 'gte', 'lt', 'lte', 'in', 'nin', 'mod', operators = ['ne', 'gt', 'gte', 'lt', 'lte', 'in', 'nin', 'mod',
'all', 'size', 'exists', 'not'] 'all', 'size', 'exists', 'not']
geo_operators = ['within_distance', 'within_box', 'near'] geo_operators = ['within_distance', 'within_box', 'near']
match_operators = ['contains', 'icontains', 'startswith', match_operators = ['contains', 'icontains', 'startswith',
'istartswith', 'endswith', 'iendswith', 'istartswith', 'endswith', 'iendswith',
'exact', 'iexact'] 'exact', 'iexact']
mongo_query = {} mongo_query = {}
@ -563,8 +563,8 @@ class QuerySet(object):
% self._document._class_name) % self._document._class_name)
def get_or_create(self, *q_objs, **query): def get_or_create(self, *q_objs, **query):
"""Retrieve unique object or create, if it doesn't exist. Returns a tuple of """Retrieve unique object or create, if it doesn't exist. Returns a tuple of
``(object, created)``, where ``object`` is the retrieved or created object ``(object, created)``, where ``object`` is the retrieved or created object
and ``created`` is a boolean specifying whether a new object was created. Raises and ``created`` is a boolean specifying whether a new object was created. Raises
:class:`~mongoengine.queryset.MultipleObjectsReturned` or :class:`~mongoengine.queryset.MultipleObjectsReturned` or
`DocumentName.MultipleObjectsReturned` if multiple results are found. `DocumentName.MultipleObjectsReturned` if multiple results are found.
@ -667,8 +667,8 @@ class QuerySet(object):
def __len__(self): def __len__(self):
return self.count() return self.count()
def map_reduce(self, map_f, reduce_f, finalize_f=None, limit=None, def map_reduce(self, map_f, reduce_f, output, finalize_f=None, limit=None,
scope=None, keep_temp=False): scope=None):
"""Perform a map/reduce query using the current query spec """Perform a map/reduce query using the current query spec
and ordering. While ``map_reduce`` respects ``QuerySet`` chaining, and ordering. While ``map_reduce`` respects ``QuerySet`` chaining,
it must be the last call made, as it does not return a maleable it must be the last call made, as it does not return a maleable
@ -678,52 +678,61 @@ class QuerySet(object):
and :meth:`~mongoengine.tests.QuerySetTest.test_map_advanced` and :meth:`~mongoengine.tests.QuerySetTest.test_map_advanced`
tests in ``tests.queryset.QuerySetTest`` for usage examples. tests in ``tests.queryset.QuerySetTest`` for usage examples.
:param map_f: map function, as :class:`~pymongo.code.Code` or string :param map_f: map function, as :class:`~bson.code.Code` or string
:param reduce_f: reduce function, as :param reduce_f: reduce function, as
:class:`~pymongo.code.Code` or string :class:`~bson.code.Code` or string
:param output: output collection name, if set to 'inline' will try to
use :class:`~pymongo.collection.Collection.inline_map_reduce`
This can also be a dictionary containing output options
see: http://docs.mongodb.org/manual/reference/commands/#mapReduce
:param finalize_f: finalize function, an optional function that :param finalize_f: finalize function, an optional function that
performs any post-reduction processing. performs any post-reduction processing.
:param scope: values to insert into map/reduce global scope. Optional. :param scope: values to insert into map/reduce global scope. Optional.
:param limit: number of objects from current query to provide :param limit: number of objects from current query to provide
to map/reduce method to map/reduce method
:param keep_temp: keep temporary table (boolean, default ``True``)
Returns an iterator yielding Returns an iterator yielding
:class:`~mongoengine.document.MapReduceDocument`. :class:`~mongoengine.document.MapReduceDocument`.
.. note:: Map/Reduce requires server version **>= 1.1.1**. The PyMongo .. note::
:meth:`~pymongo.collection.Collection.map_reduce` helper requires
PyMongo version **>= 1.2**. Map/Reduce changed in server version **>= 1.7.4**. The PyMongo
:meth:`~pymongo.collection.Collection.map_reduce` helper requires
PyMongo version **>= 1.11**.
.. versionchanged:: 0.5
- removed ``keep_temp`` keyword argument, which was only relevant
for MongoDB server versions older than 1.7.4
.. versionadded:: 0.3 .. versionadded:: 0.3
""" """
from document import MapReduceDocument from document import MapReduceDocument
if not hasattr(self._collection, "map_reduce"): if not hasattr(self._collection, "map_reduce"):
raise NotImplementedError("Requires MongoDB >= 1.1.1") raise NotImplementedError("Requires MongoDB >= 1.7.1")
map_f_scope = {} map_f_scope = {}
if isinstance(map_f, pymongo.code.Code): if isinstance(map_f, bson.code.Code):
map_f_scope = map_f.scope map_f_scope = map_f.scope
map_f = unicode(map_f) map_f = unicode(map_f)
map_f = pymongo.code.Code(self._sub_js_fields(map_f), map_f_scope) map_f = bson.code.Code(self._sub_js_fields(map_f), map_f_scope)
reduce_f_scope = {} reduce_f_scope = {}
if isinstance(reduce_f, pymongo.code.Code): if isinstance(reduce_f, bson.code.Code):
reduce_f_scope = reduce_f.scope reduce_f_scope = reduce_f.scope
reduce_f = unicode(reduce_f) reduce_f = unicode(reduce_f)
reduce_f_code = self._sub_js_fields(reduce_f) reduce_f_code = self._sub_js_fields(reduce_f)
reduce_f = pymongo.code.Code(reduce_f_code, reduce_f_scope) reduce_f = bson.code.Code(reduce_f_code, reduce_f_scope)
mr_args = {'query': self._query, 'keeptemp': keep_temp} mr_args = {'query': self._query}
if finalize_f: if finalize_f:
finalize_f_scope = {} finalize_f_scope = {}
if isinstance(finalize_f, pymongo.code.Code): if isinstance(finalize_f, bson.code.Code):
finalize_f_scope = finalize_f.scope finalize_f_scope = finalize_f.scope
finalize_f = unicode(finalize_f) finalize_f = unicode(finalize_f)
finalize_f_code = self._sub_js_fields(finalize_f) finalize_f_code = self._sub_js_fields(finalize_f)
finalize_f = pymongo.code.Code(finalize_f_code, finalize_f_scope) finalize_f = bson.code.Code(finalize_f_code, finalize_f_scope)
mr_args['finalize'] = finalize_f mr_args['finalize'] = finalize_f
if scope: if scope:
@ -732,8 +741,16 @@ class QuerySet(object):
if limit: if limit:
mr_args['limit'] = limit mr_args['limit'] = limit
results = self._collection.map_reduce(map_f, reduce_f, **mr_args) if output == 'inline' and not self._ordering:
results = results.find() map_reduce_function = 'inline_map_reduce'
else:
map_reduce_function = 'map_reduce'
mr_args['out'] = output
results = getattr(self._collection, map_reduce_function)(map_f, reduce_f, **mr_args)
if map_reduce_function == 'map_reduce':
results = results.find()
if self._ordering: if self._ordering:
results = results.sort(self._ordering) results = results.sort(self._ordering)
@ -777,7 +794,7 @@ class QuerySet(object):
self._skip, self._limit = key.start, key.stop self._skip, self._limit = key.start, key.stop
except IndexError, err: except IndexError, err:
# PyMongo raises an error if key.start == key.stop, catch it, # PyMongo raises an error if key.start == key.stop, catch it,
# bin it, kill it. # bin it, kill it.
start = key.start or 0 start = key.start or 0
if start >= 0 and key.stop >= 0 and key.step is None: if start >= 0 and key.stop >= 0 and key.step is None:
if start == key.stop: if start == key.stop:
@ -933,7 +950,7 @@ class QuerySet(object):
return mongo_update return mongo_update
def update(self, safe_update=True, upsert=False, **update): def update(self, safe_update=True, upsert=False, **update):
"""Perform an atomic update on the fields matched by the query. When """Perform an atomic update on the fields matched by the query. When
``safe_update`` is used, the number of affected documents is returned. ``safe_update`` is used, the number of affected documents is returned.
:param safe: check if the operation succeeded before returning :param safe: check if the operation succeeded before returning
@ -957,7 +974,7 @@ class QuerySet(object):
raise OperationError(u'Update failed (%s)' % unicode(err)) raise OperationError(u'Update failed (%s)' % unicode(err))
def update_one(self, safe_update=True, upsert=False, **update): def update_one(self, safe_update=True, upsert=False, **update):
"""Perform an atomic update on first field matched by the query. When """Perform an atomic update on first field matched by the query. When
``safe_update`` is used, the number of affected documents is returned. ``safe_update`` is used, the number of affected documents is returned.
:param safe: check if the operation succeeded before returning :param safe: check if the operation succeeded before returning
@ -985,8 +1002,8 @@ class QuerySet(object):
return self return self
def _sub_js_fields(self, code): def _sub_js_fields(self, code):
"""When fields are specified with [~fieldname] syntax, where """When fields are specified with [~fieldname] syntax, where
*fieldname* is the Python name of a field, *fieldname* will be *fieldname* is the Python name of a field, *fieldname* will be
substituted for the MongoDB name of the field (specified using the substituted for the MongoDB name of the field (specified using the
:attr:`name` keyword argument in a field's constructor). :attr:`name` keyword argument in a field's constructor).
""" """
@ -1009,9 +1026,9 @@ class QuerySet(object):
options specified as keyword arguments. options specified as keyword arguments.
As fields in MongoEngine may use different names in the database (set As fields in MongoEngine may use different names in the database (set
using the :attr:`db_field` keyword argument to a :class:`Field` using the :attr:`db_field` keyword argument to a :class:`Field`
constructor), a mechanism exists for replacing MongoEngine field names constructor), a mechanism exists for replacing MongoEngine field names
with the database field names in Javascript code. When accessing a with the database field names in Javascript code. When accessing a
field, use square-bracket notation, and prefix the MongoEngine field field, use square-bracket notation, and prefix the MongoEngine field
name with a tilde (~). name with a tilde (~).
@ -1037,7 +1054,7 @@ class QuerySet(object):
query['$where'] = self._where_clause query['$where'] = self._where_clause
scope['query'] = query scope['query'] = query
code = pymongo.code.Code(code, scope=scope) code = bson.code.Code(code, scope=scope)
db = _get_db() db = _get_db()
return db.eval(code, *fields) return db.eval(code, *fields)

View File

@ -1,5 +1,6 @@
import unittest import unittest
from datetime import datetime from datetime import datetime
import bson
import pymongo import pymongo
from mongoengine import * from mongoengine import *
@ -7,7 +8,7 @@ from mongoengine.connection import _get_db
class DocumentTest(unittest.TestCase): class DocumentTest(unittest.TestCase):
def setUp(self): def setUp(self):
connect(db='mongoenginetest') connect(db='mongoenginetest')
self.db = _get_db() self.db = _get_db()
@ -38,7 +39,7 @@ class DocumentTest(unittest.TestCase):
name = name_field name = name_field
age = age_field age = age_field
non_field = True non_field = True
self.assertEqual(Person._fields['name'], name_field) self.assertEqual(Person._fields['name'], name_field)
self.assertEqual(Person._fields['age'], age_field) self.assertEqual(Person._fields['age'], age_field)
self.assertFalse('non_field' in Person._fields) self.assertFalse('non_field' in Person._fields)
@ -60,7 +61,7 @@ class DocumentTest(unittest.TestCase):
mammal_superclasses = {'Animal': Animal} mammal_superclasses = {'Animal': Animal}
self.assertEqual(Mammal._superclasses, mammal_superclasses) self.assertEqual(Mammal._superclasses, mammal_superclasses)
dog_superclasses = { dog_superclasses = {
'Animal': Animal, 'Animal': Animal,
'Animal.Mammal': Mammal, 'Animal.Mammal': Mammal,
@ -68,7 +69,7 @@ class DocumentTest(unittest.TestCase):
self.assertEqual(Dog._superclasses, dog_superclasses) self.assertEqual(Dog._superclasses, dog_superclasses)
def test_get_subclasses(self): def test_get_subclasses(self):
"""Ensure that the correct list of subclasses is retrieved by the """Ensure that the correct list of subclasses is retrieved by the
_get_subclasses method. _get_subclasses method.
""" """
class Animal(Document): pass class Animal(Document): pass
@ -78,15 +79,15 @@ class DocumentTest(unittest.TestCase):
class Dog(Mammal): pass class Dog(Mammal): pass
mammal_subclasses = { mammal_subclasses = {
'Animal.Mammal.Dog': Dog, 'Animal.Mammal.Dog': Dog,
'Animal.Mammal.Human': Human 'Animal.Mammal.Human': Human
} }
self.assertEqual(Mammal._get_subclasses(), mammal_subclasses) self.assertEqual(Mammal._get_subclasses(), mammal_subclasses)
animal_subclasses = { animal_subclasses = {
'Animal.Fish': Fish, 'Animal.Fish': Fish,
'Animal.Mammal': Mammal, 'Animal.Mammal': Mammal,
'Animal.Mammal.Dog': Dog, 'Animal.Mammal.Dog': Dog,
'Animal.Mammal.Human': Human 'Animal.Mammal.Human': Human
} }
self.assertEqual(Animal._get_subclasses(), animal_subclasses) self.assertEqual(Animal._get_subclasses(), animal_subclasses)
@ -124,7 +125,7 @@ class DocumentTest(unittest.TestCase):
self.assertTrue('name' in Employee._fields) self.assertTrue('name' in Employee._fields)
self.assertTrue('salary' in Employee._fields) self.assertTrue('salary' in Employee._fields)
self.assertEqual(Employee._meta['collection'], self.assertEqual(Employee._meta['collection'],
self.Person._meta['collection']) self.Person._meta['collection'])
# Ensure that MRO error is not raised # Ensure that MRO error is not raised
@ -146,7 +147,7 @@ class DocumentTest(unittest.TestCase):
class Dog(Animal): class Dog(Animal):
pass pass
self.assertRaises(ValueError, create_dog_class) self.assertRaises(ValueError, create_dog_class)
# Check that _cls etc aren't present on simple documents # Check that _cls etc aren't present on simple documents
dog = Animal(name='dog') dog = Animal(name='dog')
dog.save() dog.save()
@ -161,7 +162,7 @@ class DocumentTest(unittest.TestCase):
class Employee(self.Person): class Employee(self.Person):
meta = {'allow_inheritance': False} meta = {'allow_inheritance': False}
self.assertRaises(ValueError, create_employee_class) self.assertRaises(ValueError, create_employee_class)
# Test the same for embedded documents # Test the same for embedded documents
class Comment(EmbeddedDocument): class Comment(EmbeddedDocument):
content = StringField() content = StringField()
@ -186,7 +187,7 @@ class DocumentTest(unittest.TestCase):
class Person(Document): class Person(Document):
name = StringField() name = StringField()
meta = {'collection': collection} meta = {'collection': collection}
user = Person(name="Test User") user = Person(name="Test User")
user.save() user.save()
self.assertTrue(collection in self.db.collection_names()) self.assertTrue(collection in self.db.collection_names())
@ -280,7 +281,7 @@ class DocumentTest(unittest.TestCase):
tags = ListField(StringField()) tags = ListField(StringField())
meta = { meta = {
'indexes': [ 'indexes': [
'-date', '-date',
'tags', 'tags',
('category', '-date') ('category', '-date')
], ],
@ -296,12 +297,12 @@ class DocumentTest(unittest.TestCase):
list(BlogPost.objects) list(BlogPost.objects)
info = BlogPost.objects._collection.index_information() info = BlogPost.objects._collection.index_information()
info = [value['key'] for key, value in info.iteritems()] info = [value['key'] for key, value in info.iteritems()]
self.assertTrue([('_types', 1), ('category', 1), ('addDate', -1)] self.assertTrue([('_types', 1), ('category', 1), ('addDate', -1)]
in info) in info)
self.assertTrue([('_types', 1), ('addDate', -1)] in info) self.assertTrue([('_types', 1), ('addDate', -1)] in info)
# tags is a list field so it shouldn't have _types in the index # tags is a list field so it shouldn't have _types in the index
self.assertTrue([('tags', 1)] in info) self.assertTrue([('tags', 1)] in info)
class ExtendedBlogPost(BlogPost): class ExtendedBlogPost(BlogPost):
title = StringField() title = StringField()
meta = {'indexes': ['title']} meta = {'indexes': ['title']}
@ -311,7 +312,7 @@ class DocumentTest(unittest.TestCase):
list(ExtendedBlogPost.objects) list(ExtendedBlogPost.objects)
info = ExtendedBlogPost.objects._collection.index_information() info = ExtendedBlogPost.objects._collection.index_information()
info = [value['key'] for key, value in info.iteritems()] info = [value['key'] for key, value in info.iteritems()]
self.assertTrue([('_types', 1), ('category', 1), ('addDate', -1)] self.assertTrue([('_types', 1), ('category', 1), ('addDate', -1)]
in info) in info)
self.assertTrue([('_types', 1), ('addDate', -1)] in info) self.assertTrue([('_types', 1), ('addDate', -1)] in info)
self.assertTrue([('_types', 1), ('title', 1)] in info) self.assertTrue([('_types', 1), ('title', 1)] in info)
@ -380,7 +381,7 @@ class DocumentTest(unittest.TestCase):
class EmailUser(User): class EmailUser(User):
email = StringField() email = StringField()
user = User(username='test', name='test user') user = User(username='test', name='test user')
user.save() user.save()
@ -391,20 +392,20 @@ class DocumentTest(unittest.TestCase):
user_son = User.objects._collection.find_one() user_son = User.objects._collection.find_one()
self.assertEqual(user_son['_id'], 'test') self.assertEqual(user_son['_id'], 'test')
self.assertTrue('username' not in user_son['_id']) self.assertTrue('username' not in user_son['_id'])
User.drop_collection() User.drop_collection()
user = User(pk='mongo', name='mongo user') user = User(pk='mongo', name='mongo user')
user.save() user.save()
user_obj = User.objects.first() user_obj = User.objects.first()
self.assertEqual(user_obj.id, 'mongo') self.assertEqual(user_obj.id, 'mongo')
self.assertEqual(user_obj.pk, 'mongo') self.assertEqual(user_obj.pk, 'mongo')
user_son = User.objects._collection.find_one() user_son = User.objects._collection.find_one()
self.assertEqual(user_son['_id'], 'mongo') self.assertEqual(user_son['_id'], 'mongo')
self.assertTrue('username' not in user_son['_id']) self.assertTrue('username' not in user_son['_id'])
User.drop_collection() User.drop_collection()
def test_creation(self): def test_creation(self):
@ -457,18 +458,18 @@ class DocumentTest(unittest.TestCase):
""" """
class Comment(EmbeddedDocument): class Comment(EmbeddedDocument):
content = StringField() content = StringField()
self.assertTrue('content' in Comment._fields) self.assertTrue('content' in Comment._fields)
self.assertFalse('id' in Comment._fields) self.assertFalse('id' in Comment._fields)
self.assertFalse('collection' in Comment._meta) self.assertFalse('collection' in Comment._meta)
def test_embedded_document_validation(self): def test_embedded_document_validation(self):
"""Ensure that embedded documents may be validated. """Ensure that embedded documents may be validated.
""" """
class Comment(EmbeddedDocument): class Comment(EmbeddedDocument):
date = DateTimeField() date = DateTimeField()
content = StringField(required=True) content = StringField(required=True)
comment = Comment() comment = Comment()
self.assertRaises(ValidationError, comment.validate) self.assertRaises(ValidationError, comment.validate)
@ -496,7 +497,7 @@ class DocumentTest(unittest.TestCase):
# Test skipping validation on save # Test skipping validation on save
class Recipient(Document): class Recipient(Document):
email = EmailField(required=True) email = EmailField(required=True)
recipient = Recipient(email='root@localhost') recipient = Recipient(email='root@localhost')
self.assertRaises(ValidationError, recipient.save) self.assertRaises(ValidationError, recipient.save)
try: try:
@ -517,19 +518,19 @@ class DocumentTest(unittest.TestCase):
"""Ensure that a document may be saved with a custom _id. """Ensure that a document may be saved with a custom _id.
""" """
# Create person object and save it to the database # Create person object and save it to the database
person = self.Person(name='Test User', age=30, person = self.Person(name='Test User', age=30,
id='497ce96f395f2f052a494fd4') id='497ce96f395f2f052a494fd4')
person.save() person.save()
# Ensure that the object is in the database with the correct _id # Ensure that the object is in the database with the correct _id
collection = self.db[self.Person._meta['collection']] collection = self.db[self.Person._meta['collection']]
person_obj = collection.find_one({'name': 'Test User'}) person_obj = collection.find_one({'name': 'Test User'})
self.assertEqual(str(person_obj['_id']), '497ce96f395f2f052a494fd4') self.assertEqual(str(person_obj['_id']), '497ce96f395f2f052a494fd4')
def test_save_custom_pk(self): def test_save_custom_pk(self):
"""Ensure that a document may be saved with a custom _id using pk alias. """Ensure that a document may be saved with a custom _id using pk alias.
""" """
# Create person object and save it to the database # Create person object and save it to the database
person = self.Person(name='Test User', age=30, person = self.Person(name='Test User', age=30,
pk='497ce96f395f2f052a494fd4') pk='497ce96f395f2f052a494fd4')
person.save() person.save()
# Ensure that the object is in the database with the correct _id # Ensure that the object is in the database with the correct _id
@ -565,7 +566,7 @@ class DocumentTest(unittest.TestCase):
BlogPost.drop_collection() BlogPost.drop_collection()
def test_save_embedded_document(self): def test_save_embedded_document(self):
"""Ensure that a document with an embedded document field may be """Ensure that a document with an embedded document field may be
saved in the database. saved in the database.
""" """
class EmployeeDetails(EmbeddedDocument): class EmployeeDetails(EmbeddedDocument):
@ -591,7 +592,7 @@ class DocumentTest(unittest.TestCase):
def test_save_reference(self): def test_save_reference(self):
"""Ensure that a document reference field may be saved in the database. """Ensure that a document reference field may be saved in the database.
""" """
class BlogPost(Document): class BlogPost(Document):
meta = {'collection': 'blogpost_1'} meta = {'collection': 'blogpost_1'}
content = StringField() content = StringField()
@ -610,8 +611,8 @@ class DocumentTest(unittest.TestCase):
post_obj = BlogPost.objects.first() post_obj = BlogPost.objects.first()
# Test laziness # Test laziness
self.assertTrue(isinstance(post_obj._data['author'], self.assertTrue(isinstance(post_obj._data['author'],
pymongo.dbref.DBRef)) bson.dbref.DBRef))
self.assertTrue(isinstance(post_obj.author, self.Person)) self.assertTrue(isinstance(post_obj.author, self.Person))
self.assertEqual(post_obj.author.name, 'Test User') self.assertEqual(post_obj.author.name, 'Test User')

View File

@ -3,6 +3,7 @@
import unittest import unittest
import pymongo import pymongo
import bson
from datetime import datetime, timedelta from datetime import datetime, timedelta
from mongoengine.queryset import (QuerySet, MultipleObjectsReturned, from mongoengine.queryset import (QuerySet, MultipleObjectsReturned,
@ -58,7 +59,7 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(len(people), 2) self.assertEqual(len(people), 2)
results = list(people) results = list(people)
self.assertTrue(isinstance(results[0], self.Person)) self.assertTrue(isinstance(results[0], self.Person))
self.assertTrue(isinstance(results[0].id, (pymongo.objectid.ObjectId, self.assertTrue(isinstance(results[0].id, (bson.objectid.ObjectId,
str, unicode))) str, unicode)))
self.assertEqual(results[0].name, "User A") self.assertEqual(results[0].name, "User A")
self.assertEqual(results[0].age, 20) self.assertEqual(results[0].age, 20)
@ -162,7 +163,7 @@ class QuerySetTest(unittest.TestCase):
person = self.Person.objects.get(age__lt=30) person = self.Person.objects.get(age__lt=30)
self.assertEqual(person.name, "User A") self.assertEqual(person.name, "User A")
def test_find_array_position(self): def test_find_array_position(self):
"""Ensure that query by array position works. """Ensure that query by array position works.
""" """
@ -177,7 +178,7 @@ class QuerySetTest(unittest.TestCase):
posts = ListField(EmbeddedDocumentField(Post)) posts = ListField(EmbeddedDocumentField(Post))
Blog.drop_collection() Blog.drop_collection()
Blog.objects.create(tags=['a', 'b']) Blog.objects.create(tags=['a', 'b'])
self.assertEqual(len(Blog.objects(tags__0='a')), 1) self.assertEqual(len(Blog.objects(tags__0='a')), 1)
self.assertEqual(len(Blog.objects(tags__0='b')), 0) self.assertEqual(len(Blog.objects(tags__0='b')), 0)
@ -226,16 +227,16 @@ class QuerySetTest(unittest.TestCase):
person, created = self.Person.objects.get_or_create(age=30) person, created = self.Person.objects.get_or_create(age=30)
self.assertEqual(person.name, "User B") self.assertEqual(person.name, "User B")
self.assertEqual(created, False) self.assertEqual(created, False)
person, created = self.Person.objects.get_or_create(age__lt=30) person, created = self.Person.objects.get_or_create(age__lt=30)
self.assertEqual(person.name, "User A") self.assertEqual(person.name, "User A")
self.assertEqual(created, False) self.assertEqual(created, False)
# Try retrieving when no objects exists - new doc should be created # Try retrieving when no objects exists - new doc should be created
kwargs = dict(age=50, defaults={'name': 'User C'}) kwargs = dict(age=50, defaults={'name': 'User C'})
person, created = self.Person.objects.get_or_create(**kwargs) person, created = self.Person.objects.get_or_create(**kwargs)
self.assertEqual(created, True) self.assertEqual(created, True)
person = self.Person.objects.get(age=50) person = self.Person.objects.get(age=50)
self.assertEqual(person.name, "User C") self.assertEqual(person.name, "User C")
@ -328,7 +329,7 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(obj, person) self.assertEqual(obj, person)
obj = self.Person.objects(Q(name__iexact='gUIDO VAN rOSSU')).first() obj = self.Person.objects(Q(name__iexact='gUIDO VAN rOSSU')).first()
self.assertEqual(obj, None) self.assertEqual(obj, None)
# Test unsafe expressions # Test unsafe expressions
person = self.Person(name='Guido van Rossum [.\'Geek\']') person = self.Person(name='Guido van Rossum [.\'Geek\']')
person.save() person.save()
@ -559,7 +560,7 @@ class QuerySetTest(unittest.TestCase):
obj = self.Person.objects(Q(name__not=re.compile('^bob'))).first() obj = self.Person.objects(Q(name__not=re.compile('^bob'))).first()
self.assertEqual(obj, person) self.assertEqual(obj, person)
obj = self.Person.objects(Q(name__not=re.compile('^Gui'))).first() obj = self.Person.objects(Q(name__not=re.compile('^Gui'))).first()
self.assertEqual(obj, None) self.assertEqual(obj, None)
@ -631,7 +632,7 @@ class QuerySetTest(unittest.TestCase):
class BlogPost(Document): class BlogPost(Document):
name = StringField(db_field='doc-name') name = StringField(db_field='doc-name')
comments = ListField(EmbeddedDocumentField(Comment), comments = ListField(EmbeddedDocumentField(Comment),
db_field='cmnts') db_field='cmnts')
BlogPost.drop_collection() BlogPost.drop_collection()
@ -733,7 +734,7 @@ class QuerySetTest(unittest.TestCase):
BlogPost.objects.update_one(add_to_set__tags='unique') BlogPost.objects.update_one(add_to_set__tags='unique')
post.reload() post.reload()
self.assertEqual(post.tags.count('unique'), 1) self.assertEqual(post.tags.count('unique'), 1)
BlogPost.drop_collection() BlogPost.drop_collection()
def test_update_pull(self): def test_update_pull(self):
@ -802,7 +803,7 @@ class QuerySetTest(unittest.TestCase):
""" """
# run a map/reduce operation spanning all posts # run a map/reduce operation spanning all posts
results = BlogPost.objects.map_reduce(map_f, reduce_f) results = BlogPost.objects.map_reduce(map_f, reduce_f, "myresults")
results = list(results) results = list(results)
self.assertEqual(len(results), 4) self.assertEqual(len(results), 4)
@ -813,7 +814,7 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(film.value, 3) self.assertEqual(film.value, 3)
BlogPost.drop_collection() BlogPost.drop_collection()
def test_map_reduce_with_custom_object_ids(self): def test_map_reduce_with_custom_object_ids(self):
"""Ensure that QuerySet.map_reduce works properly with custom """Ensure that QuerySet.map_reduce works properly with custom
primary keys. primary keys.
@ -822,24 +823,24 @@ class QuerySetTest(unittest.TestCase):
class BlogPost(Document): class BlogPost(Document):
title = StringField(primary_key=True) title = StringField(primary_key=True)
tags = ListField(StringField()) tags = ListField(StringField())
post1 = BlogPost(title="Post #1", tags=["mongodb", "mongoengine"]) post1 = BlogPost(title="Post #1", tags=["mongodb", "mongoengine"])
post2 = BlogPost(title="Post #2", tags=["django", "mongodb"]) post2 = BlogPost(title="Post #2", tags=["django", "mongodb"])
post3 = BlogPost(title="Post #3", tags=["hitchcock films"]) post3 = BlogPost(title="Post #3", tags=["hitchcock films"])
post1.save() post1.save()
post2.save() post2.save()
post3.save() post3.save()
self.assertEqual(BlogPost._fields['title'].db_field, '_id') self.assertEqual(BlogPost._fields['title'].db_field, '_id')
self.assertEqual(BlogPost._meta['id_field'], 'title') self.assertEqual(BlogPost._meta['id_field'], 'title')
map_f = """ map_f = """
function() { function() {
emit(this._id, 1); emit(this._id, 1);
} }
""" """
# reduce to a list of tag ids and counts # reduce to a list of tag ids and counts
reduce_f = """ reduce_f = """
function(key, values) { function(key, values) {
@ -850,10 +851,10 @@ class QuerySetTest(unittest.TestCase):
return total; return total;
} }
""" """
results = BlogPost.objects.map_reduce(map_f, reduce_f) results = BlogPost.objects.map_reduce(map_f, reduce_f, "myresults")
results = list(results) results = list(results)
self.assertEqual(results[0].object, post1) self.assertEqual(results[0].object, post1)
self.assertEqual(results[1].object, post2) self.assertEqual(results[1].object, post2)
self.assertEqual(results[2].object, post3) self.assertEqual(results[2].object, post3)
@ -943,7 +944,7 @@ class QuerySetTest(unittest.TestCase):
finalize_f = """ finalize_f = """
function(key, value) { function(key, value) {
// f(sec_since_epoch,y,z) = // f(sec_since_epoch,y,z) =
// log10(z) + ((y*sec_since_epoch) / 45000) // log10(z) + ((y*sec_since_epoch) / 45000)
z_10 = Math.log(value.z) / Math.log(10); z_10 = Math.log(value.z) / Math.log(10);
weight = z_10 + ((value.y * value.t_s) / 45000); weight = z_10 + ((value.y * value.t_s) / 45000);
@ -962,6 +963,7 @@ class QuerySetTest(unittest.TestCase):
results = Link.objects.order_by("-value") results = Link.objects.order_by("-value")
results = results.map_reduce(map_f, results = results.map_reduce(map_f,
reduce_f, reduce_f,
"myresults",
finalize_f=finalize_f, finalize_f=finalize_f,
scope=scope) scope=scope)
results = list(results) results = list(results)
@ -1289,12 +1291,12 @@ class QuerySetTest(unittest.TestCase):
title = StringField() title = StringField()
date = DateTimeField() date = DateTimeField()
location = GeoPointField() location = GeoPointField()
def __unicode__(self): def __unicode__(self):
return self.title return self.title
Event.drop_collection() Event.drop_collection()
event1 = Event(title="Coltrane Motion @ Double Door", event1 = Event(title="Coltrane Motion @ Double Door",
date=datetime.now() - timedelta(days=1), date=datetime.now() - timedelta(days=1),
location=[41.909889, -87.677137]) location=[41.909889, -87.677137])
@ -1304,7 +1306,7 @@ class QuerySetTest(unittest.TestCase):
event3 = Event(title="Coltrane Motion @ Empty Bottle", event3 = Event(title="Coltrane Motion @ Empty Bottle",
date=datetime.now(), date=datetime.now(),
location=[41.900474, -87.686638]) location=[41.900474, -87.686638])
event1.save() event1.save()
event2.save() event2.save()
event3.save() event3.save()
@ -1324,24 +1326,24 @@ class QuerySetTest(unittest.TestCase):
self.assertTrue(event2 not in events) self.assertTrue(event2 not in events)
self.assertTrue(event1 in events) self.assertTrue(event1 in events)
self.assertTrue(event3 in events) self.assertTrue(event3 in events)
# ensure ordering is respected by "near" # ensure ordering is respected by "near"
events = Event.objects(location__near=[41.9120459, -87.67892]) events = Event.objects(location__near=[41.9120459, -87.67892])
events = events.order_by("-date") events = events.order_by("-date")
self.assertEqual(events.count(), 3) self.assertEqual(events.count(), 3)
self.assertEqual(list(events), [event3, event1, event2]) self.assertEqual(list(events), [event3, event1, event2])
# find events around san francisco # find events around san francisco
point_and_distance = [[37.7566023, -122.415579], 10] point_and_distance = [[37.7566023, -122.415579], 10]
events = Event.objects(location__within_distance=point_and_distance) events = Event.objects(location__within_distance=point_and_distance)
self.assertEqual(events.count(), 1) self.assertEqual(events.count(), 1)
self.assertEqual(events[0], event2) self.assertEqual(events[0], event2)
# find events within 1 mile of greenpoint, broolyn, nyc, ny # find events within 1 mile of greenpoint, broolyn, nyc, ny
point_and_distance = [[40.7237134, -73.9509714], 1] point_and_distance = [[40.7237134, -73.9509714], 1]
events = Event.objects(location__within_distance=point_and_distance) events = Event.objects(location__within_distance=point_and_distance)
self.assertEqual(events.count(), 0) self.assertEqual(events.count(), 0)
# ensure ordering is respected by "within_distance" # ensure ordering is respected by "within_distance"
point_and_distance = [[41.9120459, -87.67892], 10] point_and_distance = [[41.9120459, -87.67892], 10]
events = Event.objects(location__within_distance=point_and_distance) events = Event.objects(location__within_distance=point_and_distance)
@ -1354,7 +1356,7 @@ class QuerySetTest(unittest.TestCase):
events = Event.objects(location__within_box=box) events = Event.objects(location__within_box=box)
self.assertEqual(events.count(), 1) self.assertEqual(events.count(), 1)
self.assertEqual(events[0].id, event2.id) self.assertEqual(events[0].id, event2.id)
Event.drop_collection() Event.drop_collection()
def test_custom_querysets(self): def test_custom_querysets(self):
@ -1398,7 +1400,7 @@ class QTest(unittest.TestCase):
query = {'age': {'$gte': 18}, 'name': 'test'} query = {'age': {'$gte': 18}, 'name': 'test'}
self.assertEqual((q1 & q2 & q3 & q4 & q5).to_query(Person), query) self.assertEqual((q1 & q2 & q3 & q4 & q5).to_query(Person), query)
def test_q_with_dbref(self): def test_q_with_dbref(self):
"""Ensure Q objects handle DBRefs correctly""" """Ensure Q objects handle DBRefs correctly"""
connect(db='mongoenginetest') connect(db='mongoenginetest')
@ -1440,7 +1442,7 @@ class QTest(unittest.TestCase):
query = Q(x__lt=100) & Q(y__ne='NotMyString') query = Q(x__lt=100) & Q(y__ne='NotMyString')
query &= Q(y__in=['a', 'b', 'c']) & Q(x__gt=-100) query &= Q(y__in=['a', 'b', 'c']) & Q(x__gt=-100)
mongo_query = { mongo_query = {
'x': {'$lt': 100, '$gt': -100}, 'x': {'$lt': 100, '$gt': -100},
'y': {'$ne': 'NotMyString', '$in': ['a', 'b', 'c']}, 'y': {'$ne': 'NotMyString', '$in': ['a', 'b', 'c']},
} }
self.assertEqual(query.to_query(TestDoc), mongo_query) self.assertEqual(query.to_query(TestDoc), mongo_query)