Added support for text search and text_score.

This commit is contained in:
Wilson Júnior 2014-07-07 20:24:37 -03:00
parent c6e846e0ae
commit f7ebf8dedd
5 changed files with 449 additions and 197 deletions

View File

@ -12,3 +12,4 @@ User Guide
querying querying
gridfs gridfs
signals signals
text-indexes

View File

@ -0,0 +1,47 @@
===========
Text Search
===========
After MongoDB 2.4 version, supports search documents by text indexes.
Defining a Document with text index
===================================
Use the *$* prefix to set a text index, Look the declaration::
class News(Document):
title = StringField()
content = StringField()
is_active = BooleanField()
meta = {'indexes': [
{'fields': ['$title', "$content"],
'default_language': 'english',
'weight': {'title': 10, 'content': 2}
}
]}
Querying
========
Saving a document::
News(title="Using mongodb text search",
content="Testing text search").save()
News(title="MongoEngine 0.9 released",
content="Various improvements").save()
Next, start a text search using :attr:`QuerySet.search_text` method::
document = News.objects.search_text('testing').first()
document.title # may be: "Using mongodb text search"
document = News.objects.search_text('released').first()
document.title # may be: "MongoEngine 0.9 released"

View File

@ -41,6 +41,7 @@ class InvalidCollectionError(Exception):
class EmbeddedDocument(BaseDocument): class EmbeddedDocument(BaseDocument):
"""A :class:`~mongoengine.Document` that isn't stored in its own """A :class:`~mongoengine.Document` that isn't stored in its own
collection. :class:`~mongoengine.EmbeddedDocument`\ s should be used as collection. :class:`~mongoengine.EmbeddedDocument`\ s should be used as
fields on :class:`~mongoengine.Document`\ s through the fields on :class:`~mongoengine.Document`\ s through the
@ -77,6 +78,7 @@ class EmbeddedDocument(BaseDocument):
class Document(BaseDocument): class Document(BaseDocument):
"""The base class used for defining the structure and properties of """The base class used for defining the structure and properties of
collections of documents stored in MongoDB. Inherit from this class, and collections of documents stored in MongoDB. Inherit from this class, and
add fields as class attributes to define a document's structure. add fields as class attributes to define a document's structure.
@ -132,6 +134,7 @@ class Document(BaseDocument):
def pk(): def pk():
"""Primary key alias """Primary key alias
""" """
def fget(self): def fget(self):
return getattr(self, self._meta['id_field']) return getattr(self, self._meta['id_field'])
@ -140,6 +143,13 @@ class Document(BaseDocument):
return property(fget, fset) return property(fget, fset)
pk = pk() pk = pk()
@property
def text_score(self):
"""
Used for text searchs
"""
return self._data.get('text_score')
@classmethod @classmethod
def _get_db(cls): def _get_db(cls):
"""Some Model using other db_alias""" """Some Model using other db_alias"""
@ -282,9 +292,9 @@ class Document(BaseDocument):
upsert=upsert, **write_concern) upsert=upsert, **write_concern)
created = is_new_object(last_error) created = is_new_object(last_error)
if cascade is None: if cascade is None:
cascade = self._meta.get('cascade', False) or cascade_kwargs is not None cascade = self._meta.get(
'cascade', False) or cascade_kwargs is not None
if cascade: if cascade:
kwargs = { kwargs = {
@ -377,7 +387,8 @@ class Document(BaseDocument):
del(query["_cls"]) del(query["_cls"])
return self._qs.filter(**query).update_one(**kwargs) return self._qs.filter(**query).update_one(**kwargs)
else: else:
raise OperationError('attempt to update a document not yet saved') raise OperationError(
'attempt to update a document not yet saved')
# Need to add shard key to query, or you get an error # Need to add shard key to query, or you get an error
return self._qs.filter(**self._object_key).update_one(**kwargs) return self._qs.filter(**self._object_key).update_one(**kwargs)
@ -396,7 +407,8 @@ class Document(BaseDocument):
signals.pre_delete.send(self.__class__, document=self) signals.pre_delete.send(self.__class__, document=self)
try: try:
self._qs.filter(**self._object_key).delete(write_concern=write_concern, _from_doc_delete=True) self._qs.filter(
**self._object_key).delete(write_concern=write_concern, _from_doc_delete=True)
except pymongo.errors.OperationFailure, err: except pymongo.errors.OperationFailure, err:
message = u'Could not delete document (%s)' % err.message message = u'Could not delete document (%s)' % err.message
raise OperationError(message) raise OperationError(message)
@ -621,6 +633,7 @@ class Document(BaseDocument):
# get all the base classes, subclasses and sieblings # get all the base classes, subclasses and sieblings
classes = [] classes = []
def get_classes(cls): def get_classes(cls):
if (cls not in classes and if (cls not in classes and
@ -678,7 +691,8 @@ class Document(BaseDocument):
""" """
required = cls.list_indexes() required = cls.list_indexes()
existing = [info['key'] for info in cls._get_collection().index_information().values()] existing = [info['key']
for info in cls._get_collection().index_information().values()]
missing = [index for index in required if index not in existing] missing = [index for index in required if index not in existing]
extra = [index for index in existing if index not in required] extra = [index for index in existing if index not in required]
@ -696,6 +710,7 @@ class Document(BaseDocument):
class DynamicDocument(Document): class DynamicDocument(Document):
"""A Dynamic Document class allowing flexible, expandable and uncontrolled """A Dynamic Document class allowing flexible, expandable and uncontrolled
schemas. As a :class:`~mongoengine.Document` subclass, acts in the same schemas. As a :class:`~mongoengine.Document` subclass, acts in the same
way as an ordinary document but has expando style properties. Any data way as an ordinary document but has expando style properties. Any data
@ -727,6 +742,7 @@ class DynamicDocument(Document):
class DynamicEmbeddedDocument(EmbeddedDocument): class DynamicEmbeddedDocument(EmbeddedDocument):
"""A Dynamic Embedded Document class allowing flexible, expandable and """A Dynamic Embedded Document class allowing flexible, expandable and
uncontrolled schemas. See :class:`~mongoengine.DynamicDocument` for more uncontrolled schemas. See :class:`~mongoengine.DynamicDocument` for more
information about dynamic documents. information about dynamic documents.
@ -753,6 +769,7 @@ class DynamicEmbeddedDocument(EmbeddedDocument):
class MapReduceDocument(object): class MapReduceDocument(object):
"""A document returned from a map/reduce query. """A document returned from a map/reduce query.
:param collection: An instance of :class:`~pymongo.Collection` :param collection: An instance of :class:`~pymongo.Collection`
@ -783,7 +800,7 @@ class MapReduceDocument(object):
try: try:
self.key = id_field_type(self.key) self.key = id_field_type(self.key)
except: except:
raise Exception("Could not cast key as %s" % \ raise Exception("Could not cast key as %s" %
id_field_type.__name__) id_field_type.__name__)
if not hasattr(self, "_key_object"): if not hasattr(self, "_key_object"):

View File

@ -39,6 +39,7 @@ RE_TYPE = type(re.compile(''))
class BaseQuerySet(object): class BaseQuerySet(object):
"""A set of results returned from a query. Wraps a MongoDB cursor, """A set of results returned from a query. Wraps a MongoDB cursor,
providing :class:`~mongoengine.Document` objects as the results. providing :class:`~mongoengine.Document` objects as the results.
""" """
@ -64,6 +65,8 @@ class BaseQuerySet(object):
self._none = False self._none = False
self._as_pymongo = False self._as_pymongo = False
self._as_pymongo_coerce = False self._as_pymongo_coerce = False
self._search_text = None
self._include_text_scores = False
# If inheritance is allowed, only return instances and instances of # If inheritance is allowed, only return instances and instances of
# subclasses of the class being used # subclasses of the class being used
@ -71,7 +74,8 @@ class BaseQuerySet(object):
if len(self._document._subclasses) == 1: if len(self._document._subclasses) == 1:
self._initial_query = {"_cls": self._document._subclasses[0]} self._initial_query = {"_cls": self._document._subclasses[0]}
else: else:
self._initial_query = {"_cls": {"$in": self._document._subclasses}} self._initial_query = {
"_cls": {"$in": self._document._subclasses}}
self._loaded_fields = QueryFieldList(always_include=['_cls']) self._loaded_fields = QueryFieldList(always_include=['_cls'])
self._cursor_obj = None self._cursor_obj = None
self._limit = None self._limit = None
@ -148,6 +152,7 @@ class BaseQuerySet(object):
return queryset._get_scalar( return queryset._get_scalar(
queryset._document._from_son(queryset._cursor[key], queryset._document._from_son(queryset._cursor[key],
_auto_dereference=self._auto_dereference)) _auto_dereference=self._auto_dereference))
if queryset._as_pymongo: if queryset._as_pymongo:
return queryset._get_as_pymongo(queryset._cursor[key]) return queryset._get_as_pymongo(queryset._cursor[key])
return queryset._document._from_son(queryset._cursor[key], return queryset._document._from_son(queryset._cursor[key],
@ -184,6 +189,35 @@ class BaseQuerySet(object):
""" """
return self.__call__(*q_objs, **query) return self.__call__(*q_objs, **query)
def search_text(self, text, language=None, include_text_scores=False):
"""
Start a text search, using text indexes.
:param language: The language that determines the list of stop words
for the search and the rules for the stemmer and tokenizer.
If not specified, the search uses the default language of the index.
For supported languages, see `Text Search Languages <http://docs.mongodb.org/manual/reference/text-search-languages/#text-search-languages>`.
:param include_text_scores: If True, automaticaly add a text_score attribute to Document.
"""
queryset = self.clone()
if queryset._search_text:
raise OperationError(
"Is not possible to use search_text two times.")
query_kwargs = {'$search': text}
if language:
query_kwargs['$language'] = language
queryset._query_obj &= Q(__raw__={'$text': query_kwargs})
queryset._mongo_query = None
queryset._cursor_obj = None
queryset._search_text = text
queryset._include_text_scores = include_text_scores
return queryset
def get(self, *q_objs, **query): def get(self, *q_objs, **query):
"""Retrieve the the matching object raising """Retrieve the the matching object raising
:class:`~mongoengine.queryset.MultipleObjectsReturned` or :class:`~mongoengine.queryset.MultipleObjectsReturned` or
@ -322,10 +356,10 @@ class BaseQuerySet(object):
try: try:
ids = self._collection.insert(raw, **write_concern) ids = self._collection.insert(raw, **write_concern)
except pymongo.errors.DuplicateKeyError, err: except pymongo.errors.DuplicateKeyError, err:
message = 'Could not save document (%s)'; message = 'Could not save document (%s)'
raise NotUniqueError(message % unicode(err)) raise NotUniqueError(message % unicode(err))
except pymongo.errors.OperationFailure, err: except pymongo.errors.OperationFailure, err:
message = 'Could not save document (%s)'; message = 'Could not save document (%s)'
if re.match('^E1100[01] duplicate key', unicode(err)): if re.match('^E1100[01] duplicate key', unicode(err)):
# E11000 - duplicate key error index # E11000 - duplicate key error index
# E11001 - duplicate key on update # E11001 - duplicate key on update
@ -418,7 +452,8 @@ class BaseQuerySet(object):
write_concern=write_concern, write_concern=write_concern,
**{'pull_all__%s' % field_name: self}) **{'pull_all__%s' % field_name: self})
queryset._collection.remove(queryset._query, write_concern=write_concern) queryset._collection.remove(
queryset._query, write_concern=write_concern)
def update(self, upsert=False, multi=True, write_concern=None, def update(self, upsert=False, multi=True, write_concern=None,
full_result=False, **update): full_result=False, **update):
@ -515,7 +550,8 @@ class BaseQuerySet(object):
raise OperationError("Conflicting parameters: remove and new") raise OperationError("Conflicting parameters: remove and new")
if not update and not upsert and not remove: if not update and not upsert and not remove:
raise OperationError("No update parameters, must either update or remove") raise OperationError(
"No update parameters, must either update or remove")
queryset = self.clone() queryset = self.clone()
query = queryset._query query = queryset._query
@ -622,13 +658,15 @@ class BaseQuerySet(object):
:class:`~mongoengine.queryset.base.BaseQuerySet` into another child class :class:`~mongoengine.queryset.base.BaseQuerySet` into another child class
""" """
if not isinstance(cls, BaseQuerySet): if not isinstance(cls, BaseQuerySet):
raise OperationError('%s is not a subclass of BaseQuerySet' % cls.__name__) raise OperationError(
'%s is not a subclass of BaseQuerySet' % cls.__name__)
copy_props = ('_mongo_query', '_initial_query', '_none', '_query_obj', copy_props = ('_mongo_query', '_initial_query', '_none', '_query_obj',
'_where_clause', '_loaded_fields', '_ordering', '_snapshot', '_where_clause', '_loaded_fields', '_ordering', '_snapshot',
'_timeout', '_class_check', '_slave_okay', '_read_preference', '_timeout', '_class_check', '_slave_okay', '_read_preference',
'_iter', '_scalar', '_as_pymongo', '_as_pymongo_coerce', '_iter', '_scalar', '_as_pymongo', '_as_pymongo_coerce',
'_limit', '_skip', '_hint', '_auto_dereference') '_limit', '_skip', '_hint', '_auto_dereference',
'_search_text', '_include_text_scores')
for prop in copy_props: for prop in copy_props:
val = getattr(self, prop) val = getattr(self, prop)
@ -714,11 +752,14 @@ class BaseQuerySet(object):
distinct = self._dereference(queryset._cursor.distinct(field), 1, distinct = self._dereference(queryset._cursor.distinct(field), 1,
name=field, instance=self._document) name=field, instance=self._document)
# We may need to cast to the correct type eg. ListField(EmbeddedDocumentField) # We may need to cast to the correct type eg.
doc_field = getattr(self._document._fields.get(field), "field", None) # ListField(EmbeddedDocumentField)
doc_field = getattr(
self._document._fields.get(field), "field", None)
instance = getattr(doc_field, "document_type", False) instance = getattr(doc_field, "document_type", False)
EmbeddedDocumentField = _import_class('EmbeddedDocumentField') EmbeddedDocumentField = _import_class('EmbeddedDocumentField')
GenericEmbeddedDocumentField = _import_class('GenericEmbeddedDocumentField') GenericEmbeddedDocumentField = _import_class(
'GenericEmbeddedDocumentField')
if instance and isinstance(doc_field, (EmbeddedDocumentField, if instance and isinstance(doc_field, (EmbeddedDocumentField,
GenericEmbeddedDocumentField)): GenericEmbeddedDocumentField)):
distinct = [instance(**doc) for doc in distinct] distinct = [instance(**doc) for doc in distinct]
@ -799,7 +840,8 @@ class BaseQuerySet(object):
for value, group in itertools.groupby(fields, lambda x: x[1]): for value, group in itertools.groupby(fields, lambda x: x[1]):
fields = [field for field, value in group] fields = [field for field, value in group]
fields = queryset._fields_to_dbfields(fields) fields = queryset._fields_to_dbfields(fields)
queryset._loaded_fields += QueryFieldList(fields, value=value, _only_called=_only_called) queryset._loaded_fields += QueryFieldList(
fields, value=value, _only_called=_only_called)
return queryset return queryset
@ -1036,7 +1078,6 @@ class BaseQuerySet(object):
ordered_output.append(('db', get_db(db_alias).name)) ordered_output.append(('db', get_db(db_alias).name))
del remaing_args[0] del remaing_args[0]
for part in remaing_args: for part in remaing_args:
value = output.get(part) value = output.get(part)
if value: if value:
@ -1292,6 +1333,13 @@ class BaseQuerySet(object):
cursor_args['slave_okay'] = self._slave_okay cursor_args['slave_okay'] = self._slave_okay
if self._loaded_fields: if self._loaded_fields:
cursor_args['fields'] = self._loaded_fields.as_dict() cursor_args['fields'] = self._loaded_fields.as_dict()
if self._include_text_scores:
if 'fields' not in cursor_args:
cursor_args['fields'] = {}
cursor_args['fields']['text_score'] = {'$meta': "textScore"}
return cursor_args return cursor_args
@property @property

View File

@ -1,3 +1,5 @@
# -*- coding: utf-8 -*-
import sys import sys
sys.path[0:0] = [""] sys.path[0:0] = [""]
@ -26,6 +28,7 @@ __all__ = ("QuerySetTest",)
class db_ops_tracker(query_counter): class db_ops_tracker(query_counter):
def get_ops(self): def get_ops(self):
ignore_query = {"ns": {"$ne": "%s.system.indexes" % self.db.name}} ignore_query = {"ns": {"$ne": "%s.system.indexes" % self.db.name}}
return list(self.db.system.profile.find(ignore_query)) return list(self.db.system.profile.find(ignore_query))
@ -150,8 +153,10 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(self.Person.objects.count(), 55) self.assertEqual(self.Person.objects.count(), 55)
self.assertEqual("Person object", "%s" % self.Person.objects[0]) self.assertEqual("Person object", "%s" % self.Person.objects[0])
self.assertEqual("[<Person: Person object>, <Person: Person object>]", "%s" % self.Person.objects[1:3]) self.assertEqual(
self.assertEqual("[<Person: Person object>, <Person: Person object>]", "%s" % self.Person.objects[51:53]) "[<Person: Person object>, <Person: Person object>]", "%s" % self.Person.objects[1:3])
self.assertEqual(
"[<Person: Person object>, <Person: Person object>]", "%s" % self.Person.objects[51:53])
def test_find_one(self): def test_find_one(self):
"""Ensure that a query using find_one returns a valid result. """Ensure that a query using find_one returns a valid result.
@ -187,7 +192,8 @@ class QuerySetTest(unittest.TestCase):
person = self.Person.objects.with_id(person1.id) person = self.Person.objects.with_id(person1.id)
self.assertEqual(person.name, "User A") self.assertEqual(person.name, "User A")
self.assertRaises(InvalidQueryError, self.Person.objects(name="User A").with_id, person1.id) self.assertRaises(
InvalidQueryError, self.Person.objects(name="User A").with_id, person1.id)
def test_find_only_one(self): def test_find_only_one(self):
"""Ensure that a query using ``get`` returns at most one result. """Ensure that a query using ``get`` returns at most one result.
@ -480,7 +486,8 @@ class QuerySetTest(unittest.TestCase):
BlogPost(title="ABC", comments=[c1, c2]).save() BlogPost(title="ABC", comments=[c1, c2]).save()
BlogPost.objects(comments__by="joe").update(set__comments__S__votes=Vote(score=4)) BlogPost.objects(comments__by="joe").update(
set__comments__S__votes=Vote(score=4))
post = BlogPost.objects.first() post = BlogPost.objects.first()
self.assertEqual(post.comments[0].by, 'joe') self.assertEqual(post.comments[0].by, 'joe')
@ -551,7 +558,8 @@ class QuerySetTest(unittest.TestCase):
def test_update_results(self): def test_update_results(self):
self.Person.drop_collection() self.Person.drop_collection()
result = self.Person(name="Bob", age=25).update(upsert=True, full_result=True) result = self.Person(name="Bob", age=25).update(
upsert=True, full_result=True)
self.assertTrue(isinstance(result, dict)) self.assertTrue(isinstance(result, dict))
self.assertTrue("upserted" in result) self.assertTrue("upserted" in result)
self.assertFalse(result["updatedExisting"]) self.assertFalse(result["updatedExisting"])
@ -562,13 +570,15 @@ class QuerySetTest(unittest.TestCase):
self.assertTrue(result["updatedExisting"]) self.assertTrue(result["updatedExisting"])
self.Person(name="Bob", age=20).save() self.Person(name="Bob", age=20).save()
result = self.Person.objects(name="Bob").update(set__name="bobby", multi=True) result = self.Person.objects(name="Bob").update(
set__name="bobby", multi=True)
self.assertEqual(result, 2) self.assertEqual(result, 2)
def test_upsert(self): def test_upsert(self):
self.Person.drop_collection() self.Person.drop_collection()
self.Person.objects(pk=ObjectId(), name="Bob", age=30).update(upsert=True) self.Person.objects(
pk=ObjectId(), name="Bob", age=30).update(upsert=True)
bob = self.Person.objects.first() bob = self.Person.objects.first()
self.assertEqual("Bob", bob.name) self.assertEqual("Bob", bob.name)
@ -586,7 +596,8 @@ class QuerySetTest(unittest.TestCase):
def test_set_on_insert(self): def test_set_on_insert(self):
self.Person.drop_collection() self.Person.drop_collection()
self.Person.objects(pk=ObjectId()).update(set__name='Bob', set_on_insert__age=30, upsert=True) self.Person.objects(pk=ObjectId()).update(
set__name='Bob', set_on_insert__age=30, upsert=True)
bob = self.Person.objects.first() bob = self.Person.objects.first()
self.assertEqual("Bob", bob.name) self.assertEqual("Bob", bob.name)
@ -660,7 +671,8 @@ class QuerySetTest(unittest.TestCase):
if (get_connection().max_wire_version <= 1): if (get_connection().max_wire_version <= 1):
self.assertEqual(q, 1) self.assertEqual(q, 1)
else: else:
self.assertEqual(q, 99) # profiling logs each doc now in the bulk op # profiling logs each doc now in the bulk op
self.assertEqual(q, 99)
Blog.drop_collection() Blog.drop_collection()
Blog.ensure_indexes() Blog.ensure_indexes()
@ -672,7 +684,8 @@ class QuerySetTest(unittest.TestCase):
if (get_connection().max_wire_version <= 1): if (get_connection().max_wire_version <= 1):
self.assertEqual(q, 2) # 1 for insert, and 1 for in bulk fetch self.assertEqual(q, 2) # 1 for insert, and 1 for in bulk fetch
else: else:
self.assertEqual(q, 100) # 99 for insert, and 1 for in bulk fetch # 99 for insert, and 1 for in bulk fetch
self.assertEqual(q, 100)
Blog.drop_collection() Blog.drop_collection()
@ -1069,7 +1082,8 @@ class QuerySetTest(unittest.TestCase):
with db_ops_tracker() as q: with db_ops_tracker() as q:
BlogPost.objects.filter(title='whatever').first() BlogPost.objects.filter(title='whatever').first()
self.assertEqual(len(q.get_ops()), 1) self.assertEqual(len(q.get_ops()), 1)
self.assertEqual(q.get_ops()[0]['query']['$orderby'], {u'published_date': -1}) self.assertEqual(
q.get_ops()[0]['query']['$orderby'], {u'published_date': -1})
with db_ops_tracker() as q: with db_ops_tracker() as q:
BlogPost.objects.filter(title='whatever').order_by().first() BlogPost.objects.filter(title='whatever').order_by().first()
@ -1088,7 +1102,8 @@ class QuerySetTest(unittest.TestCase):
'ordering': ['-published_date'] 'ordering': ['-published_date']
} }
BlogPost.objects.create(title='whatever', published_date=datetime.utcnow()) BlogPost.objects.create(
title='whatever', published_date=datetime.utcnow())
with db_ops_tracker() as q: with db_ops_tracker() as q:
BlogPost.objects.get(title='whatever') BlogPost.objects.get(title='whatever')
@ -1139,7 +1154,6 @@ class QuerySetTest(unittest.TestCase):
BlogPost.drop_collection() BlogPost.drop_collection()
def test_exec_js_query(self): def test_exec_js_query(self):
"""Ensure that queries are properly formed for use in exec_js. """Ensure that queries are properly formed for use in exec_js.
""" """
@ -1413,7 +1427,6 @@ class QuerySetTest(unittest.TestCase):
self.Person.objects()[:1].delete() self.Person.objects()[:1].delete()
self.assertEqual(1, BlogPost.objects.count()) self.assertEqual(1, BlogPost.objects.count())
def test_reference_field_find(self): def test_reference_field_find(self):
"""Ensure cascading deletion of referring documents from the database. """Ensure cascading deletion of referring documents from the database.
""" """
@ -1433,7 +1446,8 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(1, BlogPost.objects(author__in=[me]).count()) self.assertEqual(1, BlogPost.objects(author__in=[me]).count())
self.assertEqual(1, BlogPost.objects(author__in=[me.pk]).count()) self.assertEqual(1, BlogPost.objects(author__in=[me.pk]).count())
self.assertEqual(1, BlogPost.objects(author__in=["%s" % me.pk]).count()) self.assertEqual(
1, BlogPost.objects(author__in=["%s" % me.pk]).count())
def test_reference_field_find_dbref(self): def test_reference_field_find_dbref(self):
"""Ensure cascading deletion of referring documents from the database. """Ensure cascading deletion of referring documents from the database.
@ -1454,7 +1468,8 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(1, BlogPost.objects(author__in=[me]).count()) self.assertEqual(1, BlogPost.objects(author__in=[me]).count())
self.assertEqual(1, BlogPost.objects(author__in=[me.pk]).count()) self.assertEqual(1, BlogPost.objects(author__in=[me.pk]).count())
self.assertEqual(1, BlogPost.objects(author__in=["%s" % me.pk]).count()) self.assertEqual(
1, BlogPost.objects(author__in=["%s" % me.pk]).count())
def test_update(self): def test_update(self):
"""Ensure that atomic updates work properly. """Ensure that atomic updates work properly.
@ -1522,7 +1537,8 @@ class QuerySetTest(unittest.TestCase):
post.reload() post.reload()
self.assertEqual(post.tags, ["code"]) self.assertEqual(post.tags, ["code"])
BlogPost.objects.filter(id=post.id).update(push_all__tags=["mongodb", "code"]) BlogPost.objects.filter(id=post.id).update(
push_all__tags=["mongodb", "code"])
post.reload() post.reload()
self.assertEqual(post.tags, ["code", "mongodb", "code"]) self.assertEqual(post.tags, ["code", "mongodb", "code"])
@ -1530,12 +1546,13 @@ class QuerySetTest(unittest.TestCase):
post.reload() post.reload()
self.assertEqual(post.tags, ["mongodb"]) self.assertEqual(post.tags, ["mongodb"])
BlogPost.objects(slug="test").update(
BlogPost.objects(slug="test").update(pull_all__tags=["mongodb", "code"]) pull_all__tags=["mongodb", "code"])
post.reload() post.reload()
self.assertEqual(post.tags, []) self.assertEqual(post.tags, [])
BlogPost.objects(slug="test").update(__raw__={"$addToSet": {"tags": {"$each": ["code", "mongodb", "code"]}}}) BlogPost.objects(slug="test").update(
__raw__={"$addToSet": {"tags": {"$each": ["code", "mongodb", "code"]}}})
post.reload() post.reload()
self.assertEqual(post.tags, ["code", "mongodb"]) self.assertEqual(post.tags, ["code", "mongodb"])
@ -1568,7 +1585,6 @@ class QuerySetTest(unittest.TestCase):
name = StringField(max_length=75, unique=True, required=True) name = StringField(max_length=75, unique=True, required=True)
collaborators = ListField(EmbeddedDocumentField(Collaborator)) collaborators = ListField(EmbeddedDocumentField(Collaborator))
Site.drop_collection() Site.drop_collection()
c = Collaborator(user='Esteban') c = Collaborator(user='Esteban')
@ -1578,7 +1594,8 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(Site.objects.first().collaborators, []) self.assertEqual(Site.objects.first().collaborators, [])
def pull_all(): def pull_all():
Site.objects(id=s.id).update_one(pull_all__collaborators__user=['Ross']) Site.objects(id=s.id).update_one(
pull_all__collaborators__user=['Ross'])
self.assertRaises(InvalidQueryError, pull_all) self.assertRaises(InvalidQueryError, pull_all)
@ -1598,21 +1615,23 @@ class QuerySetTest(unittest.TestCase):
name = StringField(max_length=75, unique=True, required=True) name = StringField(max_length=75, unique=True, required=True)
collaborators = EmbeddedDocumentField(Collaborator) collaborators = EmbeddedDocumentField(Collaborator)
Site.drop_collection() Site.drop_collection()
c = User(name='Esteban') c = User(name='Esteban')
f = User(name='Frank') f = User(name='Frank')
s = Site(name="test", collaborators=Collaborator(helpful=[c], unhelpful=[f])).save() s = Site(name="test", collaborators=Collaborator(
helpful=[c], unhelpful=[f])).save()
Site.objects(id=s.id).update_one(pull__collaborators__helpful=c) Site.objects(id=s.id).update_one(pull__collaborators__helpful=c)
self.assertEqual(Site.objects.first().collaborators['helpful'], []) self.assertEqual(Site.objects.first().collaborators['helpful'], [])
Site.objects(id=s.id).update_one(pull__collaborators__unhelpful={'name': 'Frank'}) Site.objects(id=s.id).update_one(
pull__collaborators__unhelpful={'name': 'Frank'})
self.assertEqual(Site.objects.first().collaborators['unhelpful'], []) self.assertEqual(Site.objects.first().collaborators['unhelpful'], [])
def pull_all(): def pull_all():
Site.objects(id=s.id).update_one(pull_all__collaborators__helpful__name=['Ross']) Site.objects(id=s.id).update_one(
pull_all__collaborators__helpful__name=['Ross'])
self.assertRaises(InvalidQueryError, pull_all) self.assertRaises(InvalidQueryError, pull_all)
@ -1626,8 +1645,8 @@ class QuerySetTest(unittest.TestCase):
class Site(Document): class Site(Document):
name = StringField(max_length=75, unique=True, required=True) name = StringField(max_length=75, unique=True, required=True)
collaborators = MapField(ListField(EmbeddedDocumentField(Collaborator))) collaborators = MapField(
ListField(EmbeddedDocumentField(Collaborator)))
Site.drop_collection() Site.drop_collection()
@ -1636,14 +1655,17 @@ class QuerySetTest(unittest.TestCase):
s = Site(name="test", collaborators={'helpful': [c], 'unhelpful': [f]}) s = Site(name="test", collaborators={'helpful': [c], 'unhelpful': [f]})
s.save() s.save()
Site.objects(id=s.id).update_one(pull__collaborators__helpful__user='Esteban') Site.objects(id=s.id).update_one(
pull__collaborators__helpful__user='Esteban')
self.assertEqual(Site.objects.first().collaborators['helpful'], []) self.assertEqual(Site.objects.first().collaborators['helpful'], [])
Site.objects(id=s.id).update_one(pull__collaborators__unhelpful={'user':'Frank'}) Site.objects(id=s.id).update_one(
pull__collaborators__unhelpful={'user': 'Frank'})
self.assertEqual(Site.objects.first().collaborators['unhelpful'], []) self.assertEqual(Site.objects.first().collaborators['unhelpful'], [])
def pull_all(): def pull_all():
Site.objects(id=s.id).update_one(pull_all__collaborators__helpful__user=['Ross']) Site.objects(id=s.id).update_one(
pull_all__collaborators__helpful__user=['Ross'])
self.assertRaises(InvalidQueryError, pull_all) self.assertRaises(InvalidQueryError, pull_all)
@ -1893,7 +1915,8 @@ class QuerySetTest(unittest.TestCase):
Author(author=person_b).save() Author(author=person_b).save()
Author(author=person_c).save() Author(author=person_c).save()
names = [a.author.name for a in Author.objects.order_by('-author__age')] names = [
a.author.name for a in Author.objects.order_by('-author__age')]
self.assertEqual(names, ['User A', 'User B', 'User C']) self.assertEqual(names, ['User A', 'User B', 'User C'])
def test_map_reduce(self): def test_map_reduce(self):
@ -2250,7 +2273,8 @@ class QuerySetTest(unittest.TestCase):
def test_assertions(f): def test_assertions(f):
f = dict((key, int(val)) for key, val in f.items()) f = dict((key, int(val)) for key, val in f.items())
self.assertEqual(set(['music', 'film', 'actors', 'watch']), set(f.keys())) self.assertEqual(
set(['music', 'film', 'actors', 'watch']), set(f.keys()))
self.assertEqual(f['music'], 3) self.assertEqual(f['music'], 3)
self.assertEqual(f['actors'], 2) self.assertEqual(f['actors'], 2)
self.assertEqual(f['watch'], 2) self.assertEqual(f['watch'], 2)
@ -2270,7 +2294,8 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(f['watch'], 1) self.assertEqual(f['watch'], 1)
exec_js = BlogPost.objects(hits__gt=1).item_frequencies('tags') exec_js = BlogPost.objects(hits__gt=1).item_frequencies('tags')
map_reduce = BlogPost.objects(hits__gt=1).item_frequencies('tags', map_reduce=True) map_reduce = BlogPost.objects(
hits__gt=1).item_frequencies('tags', map_reduce=True)
test_assertions(exec_js) test_assertions(exec_js)
test_assertions(map_reduce) test_assertions(map_reduce)
@ -2282,7 +2307,8 @@ class QuerySetTest(unittest.TestCase):
self.assertAlmostEqual(f['film'], 1.0 / 8.0) self.assertAlmostEqual(f['film'], 1.0 / 8.0)
exec_js = BlogPost.objects.item_frequencies('tags', normalize=True) exec_js = BlogPost.objects.item_frequencies('tags', normalize=True)
map_reduce = BlogPost.objects.item_frequencies('tags', normalize=True, map_reduce=True) map_reduce = BlogPost.objects.item_frequencies(
'tags', normalize=True, map_reduce=True)
test_assertions(exec_js) test_assertions(exec_js)
test_assertions(map_reduce) test_assertions(map_reduce)
@ -2324,15 +2350,16 @@ class QuerySetTest(unittest.TestCase):
doc.phone = Phone(number='62-3332-1656') doc.phone = Phone(number='62-3332-1656')
doc.save() doc.save()
def test_assertions(f): def test_assertions(f):
f = dict((key, int(val)) for key, val in f.items()) f = dict((key, int(val)) for key, val in f.items())
self.assertEqual(set(['62-3331-1656', '62-3332-1656']), set(f.keys())) self.assertEqual(
set(['62-3331-1656', '62-3332-1656']), set(f.keys()))
self.assertEqual(f['62-3331-1656'], 2) self.assertEqual(f['62-3331-1656'], 2)
self.assertEqual(f['62-3332-1656'], 1) self.assertEqual(f['62-3332-1656'], 1)
exec_js = Person.objects.item_frequencies('phone.number') exec_js = Person.objects.item_frequencies('phone.number')
map_reduce = Person.objects.item_frequencies('phone.number', map_reduce=True) map_reduce = Person.objects.item_frequencies(
'phone.number', map_reduce=True)
test_assertions(exec_js) test_assertions(exec_js)
test_assertions(map_reduce) test_assertions(map_reduce)
@ -2342,8 +2369,10 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(set(['62-3331-1656']), set(f.keys())) self.assertEqual(set(['62-3331-1656']), set(f.keys()))
self.assertEqual(f['62-3331-1656'], 2) self.assertEqual(f['62-3331-1656'], 2)
exec_js = Person.objects(phone__number='62-3331-1656').item_frequencies('phone.number') exec_js = Person.objects(
map_reduce = Person.objects(phone__number='62-3331-1656').item_frequencies('phone.number', map_reduce=True) phone__number='62-3331-1656').item_frequencies('phone.number')
map_reduce = Person.objects(
phone__number='62-3331-1656').item_frequencies('phone.number', map_reduce=True)
test_assertions(exec_js) test_assertions(exec_js)
test_assertions(map_reduce) test_assertions(map_reduce)
@ -2352,8 +2381,10 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(f['62-3331-1656'], 2.0 / 3.0) self.assertEqual(f['62-3331-1656'], 2.0 / 3.0)
self.assertEqual(f['62-3332-1656'], 1.0 / 3.0) self.assertEqual(f['62-3332-1656'], 1.0 / 3.0)
exec_js = Person.objects.item_frequencies('phone.number', normalize=True) exec_js = Person.objects.item_frequencies(
map_reduce = Person.objects.item_frequencies('phone.number', normalize=True, map_reduce=True) 'phone.number', normalize=True)
map_reduce = Person.objects.item_frequencies(
'phone.number', normalize=True, map_reduce=True)
test_assertions(exec_js) test_assertions(exec_js)
test_assertions(map_reduce) test_assertions(map_reduce)
@ -2373,10 +2404,10 @@ class QuerySetTest(unittest.TestCase):
freq = Person.objects.item_frequencies('city', normalize=True) freq = Person.objects.item_frequencies('city', normalize=True)
self.assertEqual(freq, {'CRB': 0.5, None: 0.5}) self.assertEqual(freq, {'CRB': 0.5, None: 0.5})
freq = Person.objects.item_frequencies('city', map_reduce=True) freq = Person.objects.item_frequencies('city', map_reduce=True)
self.assertEqual(freq, {'CRB': 1.0, None: 1.0}) self.assertEqual(freq, {'CRB': 1.0, None: 1.0})
freq = Person.objects.item_frequencies('city', normalize=True, map_reduce=True) freq = Person.objects.item_frequencies(
'city', normalize=True, map_reduce=True)
self.assertEqual(freq, {'CRB': 0.5, None: 0.5}) self.assertEqual(freq, {'CRB': 0.5, None: 0.5})
def test_item_frequencies_with_null_embedded(self): def test_item_frequencies_with_null_embedded(self):
@ -2447,10 +2478,12 @@ class QuerySetTest(unittest.TestCase):
for i in xrange(20): for i in xrange(20):
Test(val=2).save() Test(val=2).save()
freqs = Test.objects.item_frequencies('val', map_reduce=False, normalize=True) freqs = Test.objects.item_frequencies(
'val', map_reduce=False, normalize=True)
self.assertEqual(freqs, {1: 50.0 / 70, 2: 20.0 / 70}) self.assertEqual(freqs, {1: 50.0 / 70, 2: 20.0 / 70})
freqs = Test.objects.item_frequencies('val', map_reduce=True, normalize=True) freqs = Test.objects.item_frequencies(
'val', map_reduce=True, normalize=True)
self.assertEqual(freqs, {1: 50.0 / 70, 2: 20.0 / 70}) self.assertEqual(freqs, {1: 50.0 / 70, 2: 20.0 / 70})
def test_average(self): def test_average(self):
@ -2470,17 +2503,21 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(int(self.Person.objects.average('age')), avg) self.assertEqual(int(self.Person.objects.average('age')), avg)
# dot notation # dot notation
self.Person(name='person meta', person_meta=self.PersonMeta(weight=0)).save() self.Person(
self.assertAlmostEqual(int(self.Person.objects.average('person_meta.weight')), 0) name='person meta', person_meta=self.PersonMeta(weight=0)).save()
self.assertAlmostEqual(
int(self.Person.objects.average('person_meta.weight')), 0)
for i, weight in enumerate(ages): for i, weight in enumerate(ages):
self.Person(name='test meta%i', person_meta=self.PersonMeta(weight=weight)).save() self.Person(
name='test meta%i', person_meta=self.PersonMeta(weight=weight)).save()
self.assertAlmostEqual(int(self.Person.objects.average('person_meta.weight')), avg) self.assertAlmostEqual(
int(self.Person.objects.average('person_meta.weight')), avg)
self.Person(name='test meta none').save() self.Person(name='test meta none').save()
self.assertEqual(int(self.Person.objects.average('person_meta.weight')), avg) self.assertEqual(
int(self.Person.objects.average('person_meta.weight')), avg)
def test_sum(self): def test_sum(self):
"""Ensure that field can be summed over correctly. """Ensure that field can be summed over correctly.
@ -2495,9 +2532,11 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(int(self.Person.objects.sum('age')), sum(ages)) self.assertEqual(int(self.Person.objects.sum('age')), sum(ages))
for i, age in enumerate(ages): for i, age in enumerate(ages):
self.Person(name='test meta%s' % i, person_meta=self.PersonMeta(weight=age)).save() self.Person(name='test meta%s' %
i, person_meta=self.PersonMeta(weight=age)).save()
self.assertEqual(int(self.Person.objects.sum('person_meta.weight')), sum(ages)) self.assertEqual(
int(self.Person.objects.sum('person_meta.weight')), sum(ages))
self.Person(name='weightless person').save() self.Person(name='weightless person').save()
self.assertEqual(int(self.Person.objects.sum('age')), sum(ages)) self.assertEqual(int(self.Person.objects.sum('age')), sum(ages))
@ -2598,7 +2637,6 @@ class QuerySetTest(unittest.TestCase):
Doc.objects.sum('pay.value'), Doc.objects.sum('pay.value'),
960) 960)
def test_embedded_array_sum(self): def test_embedded_array_sum(self):
class Pay(EmbeddedDocument): class Pay(EmbeddedDocument):
values = ListField(DecimalField()) values = ListField(DecimalField())
@ -2673,6 +2711,68 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(Foo.objects.distinct("bar"), [bar]) self.assertEqual(Foo.objects.distinct("bar"), [bar])
def test_text_indexes(self):
class News(Document):
title = StringField()
content = StringField()
is_active = BooleanField(default=True)
meta = {'indexes': [
{'fields': ['$title', "$content"],
'default_language': 'portuguese',
'weight': {'title': 10, 'content': 2}
}
]}
News.drop_collection()
info = News.objects._collection.index_information()
self.assertTrue('title_text_content_text' in info)
self.assertTrue('textIndexVersion' in info['title_text_content_text'])
News(title="Neymar quebrou a vertebra",
content="O Brasil sofre com a perda de Neymar").save()
News(title="Brasil passa para as quartas de finais",
content="Com o brasil nas quartas de finais teremos um "
"jogo complicado com a alemanha").save()
count = News.objects.search_text(
"neymar", language="portuguese").count()
self.assertEqual(count, 1)
count = News.objects.search_text(
"brasil -neymar").count()
self.assertEqual(count, 1)
News(title=u"As eleições no Brasil já estão em planejamento",
content=u"A candidata dilma roussef já começa o teu planejamento",
is_active=False).save()
new = News.objects(is_active=False).search_text(
"dilma", language="pt").first()
query = News.objects(is_active=False).search_text(
"dilma", language="pt")._query
self.assertEqual(
query, {'$text': {
'$search': 'dilma', '$language': 'pt'},
'is_active': False})
self.assertEqual(new.is_active, False)
self.assertTrue('dilma' in new.content)
self.assertTrue('planejamento' in new.title)
query = News.objects.search_text(
"candidata", include_text_scores=True)
self.assertTrue(query._include_text_scores)
new = query.first()
self.assertTrue(isinstance(new.text_score, float))
def test_distinct_handles_references_to_alias(self): def test_distinct_handles_references_to_alias(self):
register_connection('testdb', 'mongoenginetest2') register_connection('testdb', 'mongoenginetest2')
@ -2729,8 +2829,10 @@ class QuerySetTest(unittest.TestCase):
john_tolkien = Author(name="John Ronald Reuel Tolkien") john_tolkien = Author(name="John Ronald Reuel Tolkien")
book = Book(title="Tom Sawyer", authors=[mark_twain]).save() book = Book(title="Tom Sawyer", authors=[mark_twain]).save()
book = Book(title="The Lord of the Rings", authors=[john_tolkien]).save() book = Book(
book = Book(title="The Stories", authors=[mark_twain, john_tolkien]).save() title="The Lord of the Rings", authors=[john_tolkien]).save()
book = Book(
title="The Stories", authors=[mark_twain, john_tolkien]).save()
authors = Book.objects.distinct("authors") authors = Book.objects.distinct("authors")
self.assertEqual(authors, [mark_twain, john_tolkien]) self.assertEqual(authors, [mark_twain, john_tolkien])
@ -2845,6 +2947,7 @@ class QuerySetTest(unittest.TestCase):
return queryset(active=True) return queryset(active=True)
class Bar(Foo): class Bar(Foo):
@queryset_manager @queryset_manager
def objects(klass, queryset): def objects(klass, queryset):
return queryset(active=False) return queryset(active=False)
@ -2916,7 +3019,8 @@ class QuerySetTest(unittest.TestCase):
t = Test(testdict={'f': 'Value'}) t = Test(testdict={'f': 'Value'})
t.save() t.save()
self.assertEqual(Test.objects(testdict__f__startswith='Val').count(), 1) self.assertEqual(
Test.objects(testdict__f__startswith='Val').count(), 1)
self.assertEqual(Test.objects(testdict__f='Value').count(), 1) self.assertEqual(Test.objects(testdict__f='Value').count(), 1)
Test.drop_collection() Test.drop_collection()
@ -2927,7 +3031,8 @@ class QuerySetTest(unittest.TestCase):
t.save() t.save()
self.assertEqual(Test.objects(testdict__f='Value').count(), 1) self.assertEqual(Test.objects(testdict__f='Value').count(), 1)
self.assertEqual(Test.objects(testdict__f__startswith='Val').count(), 1) self.assertEqual(
Test.objects(testdict__f__startswith='Val').count(), 1)
Test.drop_collection() Test.drop_collection()
def test_bulk(self): def test_bulk(self):
@ -2972,6 +3077,7 @@ class QuerySetTest(unittest.TestCase):
"""Ensure that custom QuerySet classes may be used. """Ensure that custom QuerySet classes may be used.
""" """
class CustomQuerySet(QuerySet): class CustomQuerySet(QuerySet):
def not_empty(self): def not_empty(self):
return self.count() > 0 return self.count() > 0
@ -2993,6 +3099,7 @@ class QuerySetTest(unittest.TestCase):
""" """
class CustomQuerySet(QuerySet): class CustomQuerySet(QuerySet):
def not_empty(self): def not_empty(self):
return self.count() > 0 return self.count() > 0
@ -3040,6 +3147,7 @@ class QuerySetTest(unittest.TestCase):
""" """
class CustomQuerySet(QuerySet): class CustomQuerySet(QuerySet):
def not_empty(self): def not_empty(self):
return self.count() > 0 return self.count() > 0
@ -3063,6 +3171,7 @@ class QuerySetTest(unittest.TestCase):
""" """
class CustomQuerySet(QuerySet): class CustomQuerySet(QuerySet):
def not_empty(self): def not_empty(self):
return self.count() > 0 return self.count() > 0
@ -3096,7 +3205,8 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(5, Post.objects.limit(5).skip(5).count()) self.assertEqual(5, Post.objects.limit(5).skip(5).count())
self.assertEqual(10, Post.objects.limit(5).skip(5).count(with_limit_and_skip=False)) self.assertEqual(
10, Post.objects.limit(5).skip(5).count(with_limit_and_skip=False))
def test_count_and_none(self): def test_count_and_none(self):
"""Test count works with None()""" """Test count works with None()"""
@ -3278,7 +3388,8 @@ class QuerySetTest(unittest.TestCase):
c.save() c.save()
query = IntPair.objects.where('this[~fielda] >= this[~fieldb]') query = IntPair.objects.where('this[~fielda] >= this[~fieldb]')
self.assertEqual('this["fielda"] >= this["fieldb"]', query._where_clause) self.assertEqual(
'this["fielda"] >= this["fieldb"]', query._where_clause)
results = list(query) results = list(query)
self.assertEqual(2, len(results)) self.assertEqual(2, len(results))
self.assertTrue(a in results) self.assertTrue(a in results)
@ -3289,8 +3400,10 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(1, len(results)) self.assertEqual(1, len(results))
self.assertTrue(a in results) self.assertTrue(a in results)
query = IntPair.objects.where('function() { return this[~fielda] >= this[~fieldb] }') query = IntPair.objects.where(
self.assertEqual('function() { return this["fielda"] >= this["fieldb"] }', query._where_clause) 'function() { return this[~fielda] >= this[~fieldb] }')
self.assertEqual(
'function() { return this["fielda"] >= this["fieldb"] }', query._where_clause)
results = list(query) results = list(query)
self.assertEqual(2, len(results)) self.assertEqual(2, len(results))
self.assertTrue(a in results) self.assertTrue(a in results)
@ -3404,7 +3517,8 @@ class QuerySetTest(unittest.TestCase):
locale=Locale(city="Brasilia", country="Brazil")).save() locale=Locale(city="Brasilia", country="Brazil")).save()
self.assertEqual( self.assertEqual(
list(Person.objects.order_by('profile__age').scalar('profile__name')), list(Person.objects.order_by(
'profile__age').scalar('profile__name')),
[u'Wilson Jr', u'Gabriel Falcao', u'Lincoln de souza', u'Walter cruz']) [u'Wilson Jr', u'Gabriel Falcao', u'Lincoln de souza', u'Walter cruz'])
ulist = list(Person.objects.order_by('locale.city') ulist = list(Person.objects.order_by('locale.city')
@ -3417,6 +3531,7 @@ class QuerySetTest(unittest.TestCase):
def test_scalar_decimal(self): def test_scalar_decimal(self):
from decimal import Decimal from decimal import Decimal
class Person(Document): class Person(Document):
name = StringField() name = StringField()
rating = DecimalField() rating = DecimalField()
@ -3427,7 +3542,6 @@ class QuerySetTest(unittest.TestCase):
ulist = list(Person.objects.scalar('name', 'rating')) ulist = list(Person.objects.scalar('name', 'rating'))
self.assertEqual(ulist, [(u'Wilson Jr', Decimal('1.0'))]) self.assertEqual(ulist, [(u'Wilson Jr', Decimal('1.0'))])
def test_scalar_reference_field(self): def test_scalar_reference_field(self):
class State(Document): class State(Document):
name = StringField() name = StringField()
@ -3561,24 +3675,33 @@ class QuerySetTest(unittest.TestCase):
self.Person(name='A%s' % i, age=i).save() self.Person(name='A%s' % i, age=i).save()
self.assertEqual(self.Person.objects.scalar('name').count(), 55) self.assertEqual(self.Person.objects.scalar('name').count(), 55)
self.assertEqual("A0", "%s" % self.Person.objects.order_by('name').scalar('name').first()) self.assertEqual(
self.assertEqual("A0", "%s" % self.Person.objects.scalar('name').order_by('name')[0]) "A0", "%s" % self.Person.objects.order_by('name').scalar('name').first())
self.assertEqual(
"A0", "%s" % self.Person.objects.scalar('name').order_by('name')[0])
if PY3: if PY3:
self.assertEqual("['A1', 'A2']", "%s" % self.Person.objects.order_by('age').scalar('name')[1:3]) self.assertEqual(
self.assertEqual("['A51', 'A52']", "%s" % self.Person.objects.order_by('age').scalar('name')[51:53]) "['A1', 'A2']", "%s" % self.Person.objects.order_by('age').scalar('name')[1:3])
self.assertEqual("['A51', 'A52']", "%s" % self.Person.objects.order_by(
'age').scalar('name')[51:53])
else: else:
self.assertEqual("[u'A1', u'A2']", "%s" % self.Person.objects.order_by('age').scalar('name')[1:3]) self.assertEqual("[u'A1', u'A2']", "%s" % self.Person.objects.order_by(
self.assertEqual("[u'A51', u'A52']", "%s" % self.Person.objects.order_by('age').scalar('name')[51:53]) 'age').scalar('name')[1:3])
self.assertEqual("[u'A51', u'A52']", "%s" % self.Person.objects.order_by(
'age').scalar('name')[51:53])
# with_id and in_bulk # with_id and in_bulk
person = self.Person.objects.order_by('name').first() person = self.Person.objects.order_by('name').first()
self.assertEqual("A0", "%s" % self.Person.objects.scalar('name').with_id(person.id)) self.assertEqual("A0", "%s" %
self.Person.objects.scalar('name').with_id(person.id))
pks = self.Person.objects.order_by('age').scalar('pk')[1:3] pks = self.Person.objects.order_by('age').scalar('pk')[1:3]
if PY3: if PY3:
self.assertEqual("['A1', 'A2']", "%s" % sorted(self.Person.objects.scalar('name').in_bulk(list(pks)).values())) self.assertEqual("['A1', 'A2']", "%s" % sorted(
self.Person.objects.scalar('name').in_bulk(list(pks)).values()))
else: else:
self.assertEqual("[u'A1', u'A2']", "%s" % sorted(self.Person.objects.scalar('name').in_bulk(list(pks)).values())) self.assertEqual("[u'A1', u'A2']", "%s" % sorted(
self.Person.objects.scalar('name').in_bulk(list(pks)).values()))
def test_elem_match(self): def test_elem_match(self):
class Foo(EmbeddedDocument): class Foo(EmbeddedDocument):
@ -3601,10 +3724,12 @@ class QuerySetTest(unittest.TestCase):
Foo(shape="circle", color="purple", thick=False)]) Foo(shape="circle", color="purple", thick=False)])
b2.save() b2.save()
ak = list(Bar.objects(foo__match={'shape': "square", "color": "purple"})) ak = list(
Bar.objects(foo__match={'shape': "square", "color": "purple"}))
self.assertEqual([b1], ak) self.assertEqual([b1], ak)
ak = list(Bar.objects(foo__elemMatch={'shape': "square", "color": "purple"})) ak = list(
Bar.objects(foo__elemMatch={'shape': "square", "color": "purple"}))
self.assertEqual([b1], ak) self.assertEqual([b1], ak)
ak = list(Bar.objects(foo__match=Foo(shape="square", color="purple"))) ak = list(Bar.objects(foo__match=Foo(shape="square", color="purple")))
@ -3660,7 +3785,8 @@ class QuerySetTest(unittest.TestCase):
read_preference='Primary') read_preference='Primary')
bars = Bar.objects(read_preference=ReadPreference.SECONDARY_PREFERRED) bars = Bar.objects(read_preference=ReadPreference.SECONDARY_PREFERRED)
self.assertEqual(bars._read_preference, ReadPreference.SECONDARY_PREFERRED) self.assertEqual(
bars._read_preference, ReadPreference.SECONDARY_PREFERRED)
def test_json_simple(self): def test_json_simple(self):
@ -3702,13 +3828,15 @@ class QuerySetTest(unittest.TestCase):
list_field = ListField(default=lambda: [1, 2, 3]) list_field = ListField(default=lambda: [1, 2, 3])
dict_field = DictField(default=lambda: {"hello": "world"}) dict_field = DictField(default=lambda: {"hello": "world"})
objectid_field = ObjectIdField(default=ObjectId) objectid_field = ObjectIdField(default=ObjectId)
reference_field = ReferenceField(Simple, default=lambda: Simple().save()) reference_field = ReferenceField(
Simple, default=lambda: Simple().save())
map_field = MapField(IntField(), default=lambda: {"simple": 1}) map_field = MapField(IntField(), default=lambda: {"simple": 1})
decimal_field = DecimalField(default=1.0) decimal_field = DecimalField(default=1.0)
complex_datetime_field = ComplexDateTimeField(default=datetime.now) complex_datetime_field = ComplexDateTimeField(default=datetime.now)
url_field = URLField(default="http://mongoengine.org") url_field = URLField(default="http://mongoengine.org")
dynamic_field = DynamicField(default=1) dynamic_field = DynamicField(default=1)
generic_reference_field = GenericReferenceField(default=lambda: Simple().save()) generic_reference_field = GenericReferenceField(
default=lambda: Simple().save())
sorted_list_field = SortedListField(IntField(), sorted_list_field = SortedListField(IntField(),
default=lambda: [1, 2, 3]) default=lambda: [1, 2, 3])
email_field = EmailField(default="ross@example.com") email_field = EmailField(default="ross@example.com")
@ -3754,7 +3882,8 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(results[1]['price'], 2.22) self.assertEqual(results[1]['price'], 2.22)
# Test coerce_types # Test coerce_types
users = User.objects.only('name', 'price').as_pymongo(coerce_types=True) users = User.objects.only(
'name', 'price').as_pymongo(coerce_types=True)
results = list(users) results = list(users)
self.assertTrue(isinstance(results[0], dict)) self.assertTrue(isinstance(results[0], dict))
self.assertTrue(isinstance(results[1], dict)) self.assertTrue(isinstance(results[1], dict))
@ -3767,22 +3896,29 @@ class QuerySetTest(unittest.TestCase):
class User(Document): class User(Document):
email = EmailField(unique=True, required=True) email = EmailField(unique=True, required=True)
password_hash = StringField(db_field='password_hash', required=True) password_hash = StringField(
password_salt = StringField(db_field='password_salt', required=True) db_field='password_hash', required=True)
password_salt = StringField(
db_field='password_salt', required=True)
User.drop_collection() User.drop_collection()
User(email="ross@example.com", password_salt="SomeSalt", password_hash="SomeHash").save() User(email="ross@example.com", password_salt="SomeSalt",
password_hash="SomeHash").save()
serialized_user = User.objects.exclude('password_salt', 'password_hash').as_pymongo()[0] serialized_user = User.objects.exclude(
'password_salt', 'password_hash').as_pymongo()[0]
self.assertEqual(set(['_id', 'email']), set(serialized_user.keys())) self.assertEqual(set(['_id', 'email']), set(serialized_user.keys()))
serialized_user = User.objects.exclude('id', 'password_salt', 'password_hash').to_json() serialized_user = User.objects.exclude(
'id', 'password_salt', 'password_hash').to_json()
self.assertEqual('[{"email": "ross@example.com"}]', serialized_user) self.assertEqual('[{"email": "ross@example.com"}]', serialized_user)
serialized_user = User.objects.exclude('password_salt').only('email').as_pymongo()[0] serialized_user = User.objects.exclude(
'password_salt').only('email').as_pymongo()[0]
self.assertEqual(set(['email']), set(serialized_user.keys())) self.assertEqual(set(['email']), set(serialized_user.keys()))
serialized_user = User.objects.exclude('password_salt').only('email').to_json() serialized_user = User.objects.exclude(
'password_salt').only('email').to_json()
self.assertEqual('[{"email": "ross@example.com"}]', serialized_user) self.assertEqual('[{"email": "ross@example.com"}]', serialized_user)
def test_no_dereference(self): def test_no_dereference(self):
@ -3827,7 +3963,8 @@ class QuerySetTest(unittest.TestCase):
if platform.python_implementation() != "PyPy": if platform.python_implementation() != "PyPy":
# PyPy evaluates __len__ when iterating with list comprehensions while CPython does not. # PyPy evaluates __len__ when iterating with list comprehensions while CPython does not.
# This may be a bug in PyPy (PyPy/#1802) but it does not affect the behavior of MongoEngine. # This may be a bug in PyPy (PyPy/#1802) but it does not affect
# the behavior of MongoEngine.
self.assertEqual(None, people._len) self.assertEqual(None, people._len)
self.assertEqual(q, 1) self.assertEqual(q, 1)
@ -3946,10 +4083,13 @@ class QuerySetTest(unittest.TestCase):
inner_count += 1 inner_count += 1
inner_total_count += 1 inner_total_count += 1
self.assertEqual(inner_count, 7) # inner loop should always be executed seven times # inner loop should always be executed seven times
self.assertEqual(inner_count, 7)
self.assertEqual(outer_count, 7) # outer loop should be executed seven times total # outer loop should be executed seven times total
self.assertEqual(inner_total_count, 7 * 7) # inner loop should be executed fourtynine times total self.assertEqual(outer_count, 7)
# inner loop should be executed fourtynine times total
self.assertEqual(inner_total_count, 7 * 7)
self.assertEqual(q, 2) self.assertEqual(q, 2)
@ -4099,7 +4239,6 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(op['nreturned'], 1) self.assertEqual(op['nreturned'], 1)
def test_bool_with_ordering(self): def test_bool_with_ordering(self):
class Person(Document): class Person(Document):