merged conflicts

This commit is contained in:
blackbrrr 2010-03-09 15:19:14 -06:00
commit 22a6ec7794
4 changed files with 336 additions and 70 deletions

View File

@ -22,7 +22,7 @@ sys.path.append(os.path.abspath('..'))
# Add any Sphinx extension module names here, as strings. They can be extensions # Add any Sphinx extension module names here, as strings. They can be extensions
# coming with Sphinx (named 'sphinx.ext.*') or your custom ones. # coming with Sphinx (named 'sphinx.ext.*') or your custom ones.
extensions = ['sphinx.ext.autodoc'] extensions = ['sphinx.ext.autodoc', 'sphinx.ext.todo']
# Add any paths that contain templates here, relative to this directory. # Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates'] templates_path = ['_templates']

View File

@ -115,3 +115,41 @@ class Document(BaseDocument):
""" """
db = _get_db() db = _get_db()
db.drop_collection(cls._meta['collection']) db.drop_collection(cls._meta['collection'])
class MapReduceDocument(object):
"""A document returned from a map/reduce query.
:param collection: An instance of :class:`~pymongo.Collection`
:param key: Document/result key, often an instance of
:class:`~pymongo.objectid.ObjectId`. If supplied as
an ``ObjectId`` found in the given ``collection``,
the object can be accessed via the ``object`` property.
:param value: The result(s) for this key.
.. versionadded:: 0.2.2
"""
def __init__(self, document, collection, key, value):
self._document = document
self._collection = collection
self.key = key
self.value = value
@property
def object(self):
"""Lazy-load the object referenced by ``self.key``. If ``self.key``
is not an ``ObjectId``, simply return ``self.key``.
"""
if not isinstance(self.key, (pymongo.objectid.ObjectId)):
try:
self.key = pymongo.objectid.ObjectId(self.key)
except:
return self.key
if not hasattr(self, "_key_object"):
self._key_object = self._document.objects.with_id(self.key)
return self._key_object
return self._key_object

View File

@ -5,7 +5,7 @@ import re
import copy import copy
__all__ = ['queryset_manager', 'Q', 'InvalidQueryError', __all__ = ['queryset_manager', 'Q', 'InvalidQueryError',
'InvalidCollectionError'] 'InvalidCollectionError']
# The maximum number of items to display in a QuerySet.__repr__ # The maximum number of items to display in a QuerySet.__repr__
@ -13,7 +13,7 @@ REPR_OUTPUT_SIZE = 20
class DoesNotExist(Exception): class DoesNotExist(Exception):
pass pass
class MultipleObjectsReturned(Exception): class MultipleObjectsReturned(Exception):
@ -30,8 +30,9 @@ class OperationError(Exception):
RE_TYPE = type(re.compile('')) RE_TYPE = type(re.compile(''))
class Q(object): class Q(object):
OR = '||' OR = '||'
AND = '&&' AND = '&&'
OPERATORS = { OPERATORS = {
@ -52,7 +53,7 @@ class Q(object):
'regex_eq': '%(value)s.test(this.%(field)s)', 'regex_eq': '%(value)s.test(this.%(field)s)',
'regex_ne': '!%(value)s.test(this.%(field)s)', 'regex_ne': '!%(value)s.test(this.%(field)s)',
} }
def __init__(self, **query): def __init__(self, **query):
self.query = [query] self.query = [query]
@ -103,7 +104,7 @@ class Q(object):
js.append(operation_js) js.append(operation_js)
else: else:
# Construct the JS for this field # Construct the JS for this field
value, field_js = self._build_op_js(op, key, value, value_name) (op, key, value, value_name)
js_scope[value_name] = value js_scope[value_name] = value
js.append(field_js) js.append(field_js)
return ' && '.join(js) return ' && '.join(js)
@ -132,10 +133,10 @@ class Q(object):
return value, operation_js return value, operation_js
class QuerySet(object): class QuerySet(object):
"""A set of results returned from a query. Wraps a MongoDB cursor, """A set of results returned from a query. Wraps a MongoDB cursor,
providing :class:`~mongoengine.Document` objects as the results. providing :class:`~mongoengine.Document` objects as the results.
""" """
def __init__(self, document, collection): def __init__(self, document, collection):
self._document = document self._document = document
self._collection_obj = collection self._collection_obj = collection
@ -143,7 +144,7 @@ class QuerySet(object):
self._query = {} self._query = {}
self._where_clause = None self._where_clause = None
self._loaded_fields = [] self._loaded_fields = []
# If inheritance is allowed, only return instances and instances of # If inheritance is allowed, only return instances and instances of
# subclasses of the class being used # subclasses of the class being used
if document._meta.get('allow_inheritance'): if document._meta.get('allow_inheritance'):
@ -151,7 +152,7 @@ class QuerySet(object):
self._cursor_obj = None self._cursor_obj = None
self._limit = None self._limit = None
self._skip = None self._skip = None
def ensure_index(self, key_or_list): def ensure_index(self, key_or_list):
"""Ensure that the given indexes are in place. """Ensure that the given indexes are in place.
@ -199,7 +200,7 @@ class QuerySet(object):
return index_list return index_list
def __call__(self, q_obj=None, **query): def __call__(self, q_obj=None, **query):
"""Filter the selected documents by calling the """Filter the selected documents by calling the
:class:`~mongoengine.queryset.QuerySet` with a query. :class:`~mongoengine.queryset.QuerySet` with a query.
:param q_obj: a :class:`~mongoengine.queryset.Q` object to be used in :param q_obj: a :class:`~mongoengine.queryset.Q` object to be used in
@ -213,7 +214,7 @@ class QuerySet(object):
query = QuerySet._transform_query(_doc_cls=self._document, **query) query = QuerySet._transform_query(_doc_cls=self._document, **query)
self._query.update(query) self._query.update(query)
return self return self
def filter(self, *q_objs, **query): def filter(self, *q_objs, **query):
"""An alias of :meth:`~mongoengine.queryset.QuerySet.__call__` """An alias of :meth:`~mongoengine.queryset.QuerySet.__call__`
""" """
@ -253,11 +254,11 @@ class QuerySet(object):
# Apply where clauses to cursor # Apply where clauses to cursor
if self._where_clause: if self._where_clause:
self._cursor_obj.where(self._where_clause) self._cursor_obj.where(self._where_clause)
# apply default ordering # apply default ordering
if self._document._meta['ordering']: if self._document._meta['ordering']:
self.order_by(*self._document._meta['ordering']) self.order_by(*self._document._meta['ordering'])
return self._cursor_obj return self._cursor_obj
@classmethod @classmethod
@ -337,8 +338,8 @@ class QuerySet(object):
return mongo_query return mongo_query
def get(self, *q_objs, **query): def get(self, *q_objs, **query):
"""Retrieve the the matching object raising """Retrieve the the matching object raising
:class:`~mongoengine.queryset.MultipleObjectsReturned` or :class:`~mongoengine.queryset.MultipleObjectsReturned` or
:class:`~mongoengine.queryset.DoesNotExist` exceptions if multiple or :class:`~mongoengine.queryset.DoesNotExist` exceptions if multiple or
no results are found. no results are found.
""" """
@ -354,15 +355,15 @@ class QuerySet(object):
def get_or_create(self, *q_objs, **query): def get_or_create(self, *q_objs, **query):
"""Retreive unique object or create, if it doesn't exist. Raises """Retreive unique object or create, if it doesn't exist. Raises
:class:`~mongoengine.queryset.MultipleObjectsReturned` if multiple :class:`~mongoengine.queryset.MultipleObjectsReturned` if multiple
results are found. A new document will be created if the document results are found. A new document will be created if the document
doesn't exists; a dictionary of default values for the new document doesn't exists; a dictionary of default values for the new document
may be provided as a keyword argument called :attr:`defaults`. may be provided as a keyword argument called :attr:`defaults`.
""" """
defaults = query.get('defaults', {}) defaults = query.get('defaults', {})
if query.has_key('defaults'): if 'defaults' in query:
del query['defaults'] del query['defaults']
self.__call__(*q_objs, **query) self.__call__(*q_objs, **query)
count = self.count() count = self.count()
if count == 0: if count == 0:
@ -430,6 +431,70 @@ class QuerySet(object):
def __len__(self): def __len__(self):
return self.count() return self.count()
def map_reduce(self, map_f, reduce_f, finalize_f=None, limit=None,
scope=None, keep_temp=False):
"""Perform a map/reduce query using the current query spec
and ordering. While ``map_reduce`` respects ``QuerySet`` chaining,
it must be the last call made, as it does not return a maleable
``QuerySet``.
See the :meth:`~mongoengine.tests.QuerySetTest.test_map_reduce`
and :meth:`~mongoengine.tests.QuerySetTest.test_map_advanced`
tests in ``tests.queryset.QuerySetTest`` for usage examples.
:param map_f: map function, as :class:`~pymongo.code.Code` or string
:param reduce_f: reduce function, as
:class:`~pymongo.code.Code` or string
:param finalize_f: finalize function, an optional function that
performs any post-reduction processing.
:param scope: values to insert into map/reduce global scope. Optional.
:param limit: number of objects from current query to provide
to map/reduce method
:param keep_temp: keep temporary table (boolean, default ``True``)
Returns an iterator yielding
:class:`~mongoengine.document.MapReduceDocument`.
.. note:: Map/Reduce requires server version **>= 1.1.1**. The PyMongo
:meth:`~pymongo.collection.Collection.map_reduce` helper requires
PyMongo version **>= 1.2**.
.. versionadded:: 0.2.2
"""
from document import MapReduceDocument
if not hasattr(self._collection, "map_reduce"):
raise NotImplementedError("Requires MongoDB >= 1.1.1")
if not isinstance(map_f, pymongo.code.Code):
map_f = pymongo.code.Code(map_f)
if not isinstance(reduce_f, pymongo.code.Code):
reduce_f = pymongo.code.Code(reduce_f)
mr_args = {'query': self._query, 'keeptemp': keep_temp}
if finalize_f:
if not isinstance(finalize_f, pymongo.code.Code):
finalize_f = pymongo.code.Code(finalize_f)
mr_args['finalize'] = finalize_f
if scope:
mr_args['scope'] = scope
if limit:
mr_args['limit'] = limit
results = self._collection.map_reduce(map_f, reduce_f, **mr_args)
results = results.find()
if self._ordering:
results = results.sort(self._ordering)
for doc in results:
yield MapReduceDocument(self._document, self._collection,
doc['_id'], doc['value'])
def limit(self, n): def limit(self, n):
"""Limit the number of returned documents to `n`. This may also be """Limit the number of returned documents to `n`. This may also be
achieved using array-slicing syntax (e.g. ``User.objects[:5]``). achieved using array-slicing syntax (e.g. ``User.objects[:5]``).
@ -441,6 +506,7 @@ class QuerySet(object):
else: else:
self._cursor.limit(n) self._cursor.limit(n)
self._limit = n self._limit = n
# Return self to allow chaining # Return self to allow chaining
return self return self
@ -514,13 +580,14 @@ class QuerySet(object):
direction = pymongo.DESCENDING direction = pymongo.DESCENDING
if key[0] in ('-', '+'): if key[0] in ('-', '+'):
key = key[1:] key = key[1:]
key_list.append((key, direction)) key_list.append((key, direction))
self._ordering = key_list
self._cursor.sort(key_list) self._cursor.sort(key_list)
return self return self
def explain(self, format=False): def explain(self, format=False):
"""Return an explain plan record for the """Return an explain plan record for the
:class:`~mongoengine.queryset.QuerySet`\ 's cursor. :class:`~mongoengine.queryset.QuerySet`\ 's cursor.
:param format: format the plan before returning it :param format: format the plan before returning it
@ -531,7 +598,7 @@ class QuerySet(object):
import pprint import pprint
plan = pprint.pformat(plan) plan = pprint.pformat(plan)
return plan return plan
def delete(self, safe=False): def delete(self, safe=False):
"""Delete the documents matched by the query. """Delete the documents matched by the query.
@ -543,7 +610,7 @@ class QuerySet(object):
def _transform_update(cls, _doc_cls=None, **update): def _transform_update(cls, _doc_cls=None, **update):
"""Transform an update spec from Django-style format to Mongo format. """Transform an update spec from Django-style format to Mongo format.
""" """
operators = ['set', 'unset', 'inc', 'dec', 'push', 'push_all', 'pull', operators = ['set', 'unset', 'inc', 'dec', 'push', 'push_all', 'pull',
'pull_all'] 'pull_all']
mongo_update = {} mongo_update = {}
@ -601,7 +668,7 @@ class QuerySet(object):
update = QuerySet._transform_update(self._document, **update) update = QuerySet._transform_update(self._document, **update)
try: try:
self._collection.update(self._query, update, safe=safe_update, self._collection.update(self._query, update, safe=safe_update,
multi=True) multi=True)
except pymongo.errors.OperationFailure, err: except pymongo.errors.OperationFailure, err:
if unicode(err) == u'multi not coded yet': if unicode(err) == u'multi not coded yet':
@ -622,7 +689,7 @@ class QuerySet(object):
# Explicitly provide 'multi=False' to newer versions of PyMongo # Explicitly provide 'multi=False' to newer versions of PyMongo
# as the default may change to 'True' # as the default may change to 'True'
if pymongo.version >= '1.1.1': if pymongo.version >= '1.1.1':
self._collection.update(self._query, update, safe=safe_update, self._collection.update(self._query, update, safe=safe_update,
multi=False) multi=False)
else: else:
# Older versions of PyMongo don't support 'multi' # Older versions of PyMongo don't support 'multi'
@ -652,8 +719,8 @@ class QuerySet(object):
"""Execute a Javascript function on the server. A list of fields may be """Execute a Javascript function on the server. A list of fields may be
provided, which will be translated to their correct names and supplied provided, which will be translated to their correct names and supplied
as the arguments to the function. A few extra variables are added to as the arguments to the function. A few extra variables are added to
the function's scope: ``collection``, which is the name of the the function's scope: ``collection``, which is the name of the
collection in use; ``query``, which is an object representing the collection in use; ``query``, which is an object representing the
current query; and ``options``, which is an object containing any current query; and ``options``, which is an object containing any
options specified as keyword arguments. options specified as keyword arguments.
@ -667,7 +734,7 @@ class QuerySet(object):
:param code: a string of Javascript code to execute :param code: a string of Javascript code to execute
:param fields: fields that you will be using in your function, which :param fields: fields that you will be using in your function, which
will be passed in to your function as arguments will be passed in to your function as arguments
:param options: options that you want available to the function :param options: options that you want available to the function
(accessed in Javascript through the ``options`` object) (accessed in Javascript through the ``options`` object)
""" """
code = self._sub_js_fields(code) code = self._sub_js_fields(code)
@ -684,7 +751,7 @@ class QuerySet(object):
query = self._query query = self._query
if self._where_clause: if self._where_clause:
query['$where'] = self._where_clause query['$where'] = self._where_clause
scope['query'] = query scope['query'] = query
code = pymongo.code.Code(code, scope=scope) code = pymongo.code.Code(code, scope=scope)
@ -732,7 +799,7 @@ class QuerySet(object):
def item_frequencies(self, list_field, normalize=False): def item_frequencies(self, list_field, normalize=False):
"""Returns a dictionary of all items present in a list field across """Returns a dictionary of all items present in a list field across
the whole queried set of documents, and their corresponding frequency. the whole queried set of documents, and their corresponding frequency.
This is useful for generating tag clouds, or searching documents. This is useful for generating tag clouds, or searching documents.
:param list_field: the list field to use :param list_field: the list field to use
:param normalize: normalize the results so they add to 1.0 :param normalize: normalize the results so they add to 1.0
@ -782,7 +849,7 @@ class QuerySetManager(object):
self._collection = None self._collection = None
def __get__(self, instance, owner): def __get__(self, instance, owner):
"""Descriptor for instantiating a new QuerySet object when """Descriptor for instantiating a new QuerySet object when
Document.objects is accessed. Document.objects is accessed.
""" """
if instance is not None: if instance is not None:
@ -801,7 +868,7 @@ class QuerySetManager(object):
if collection in db.collection_names(): if collection in db.collection_names():
self._collection = db[collection] self._collection = db[collection]
# The collection already exists, check if its capped # The collection already exists, check if its capped
# options match the specified capped options # options match the specified capped options
options = self._collection.options() options = self._collection.options()
if options.get('max') != max_documents or \ if options.get('max') != max_documents or \
@ -817,7 +884,7 @@ class QuerySetManager(object):
self._collection = db.create_collection(collection, opts) self._collection = db.create_collection(collection, opts)
else: else:
self._collection = db[collection] self._collection = db[collection]
# owner is the document that contains the QuerySetManager # owner is the document that contains the QuerySetManager
queryset = QuerySet(owner, self._collection) queryset = QuerySet(owner, self._collection)
if self._manager_func: if self._manager_func:
@ -827,6 +894,7 @@ class QuerySetManager(object):
queryset = self._manager_func(owner, queryset) queryset = self._manager_func(owner, queryset)
return queryset return queryset
def queryset_manager(func): def queryset_manager(func):
"""Decorator that allows you to define custom QuerySet managers on """Decorator that allows you to define custom QuerySet managers on
:class:`~mongoengine.Document` classes. The manager must be a function that :class:`~mongoengine.Document` classes. The manager must be a function that

View File

@ -1,14 +1,17 @@
# -*- coding: utf-8 -*-
import unittest import unittest
import pymongo import pymongo
from datetime import datetime from datetime import datetime, timedelta
from mongoengine.queryset import (QuerySet, MultipleObjectsReturned, from mongoengine.queryset import (QuerySet, MultipleObjectsReturned,
DoesNotExist) DoesNotExist)
from mongoengine import * from mongoengine import *
class QuerySetTest(unittest.TestCase): class QuerySetTest(unittest.TestCase):
def setUp(self): def setUp(self):
connect(db='mongoenginetest') connect(db='mongoenginetest')
@ -16,12 +19,12 @@ class QuerySetTest(unittest.TestCase):
name = StringField() name = StringField()
age = IntField() age = IntField()
self.Person = Person self.Person = Person
def test_initialisation(self): def test_initialisation(self):
"""Ensure that a QuerySet is correctly initialised by QuerySetManager. """Ensure that a QuerySet is correctly initialised by QuerySetManager.
""" """
self.assertTrue(isinstance(self.Person.objects, QuerySet)) self.assertTrue(isinstance(self.Person.objects, QuerySet))
self.assertEqual(self.Person.objects._collection.name, self.assertEqual(self.Person.objects._collection.name,
self.Person._meta['collection']) self.Person._meta['collection'])
self.assertTrue(isinstance(self.Person.objects._collection, self.assertTrue(isinstance(self.Person.objects._collection,
pymongo.collection.Collection)) pymongo.collection.Collection))
@ -31,15 +34,15 @@ class QuerySetTest(unittest.TestCase):
""" """
self.assertEqual(QuerySet._transform_query(name='test', age=30), self.assertEqual(QuerySet._transform_query(name='test', age=30),
{'name': 'test', 'age': 30}) {'name': 'test', 'age': 30})
self.assertEqual(QuerySet._transform_query(age__lt=30), self.assertEqual(QuerySet._transform_query(age__lt=30),
{'age': {'$lt': 30}}) {'age': {'$lt': 30}})
self.assertEqual(QuerySet._transform_query(age__gt=20, age__lt=50), self.assertEqual(QuerySet._transform_query(age__gt=20, age__lt=50),
{'age': {'$gt': 20, '$lt': 50}}) {'age': {'$gt': 20, '$lt': 50}})
self.assertEqual(QuerySet._transform_query(age=20, age__gt=50), self.assertEqual(QuerySet._transform_query(age=20, age__gt=50),
{'age': 20}) {'age': 20})
self.assertEqual(QuerySet._transform_query(friend__age__gte=30), self.assertEqual(QuerySet._transform_query(friend__age__gte=30),
{'friend.age': {'$gte': 30}}) {'friend.age': {'$gte': 30}})
self.assertEqual(QuerySet._transform_query(name__exists=True), self.assertEqual(QuerySet._transform_query(name__exists=True),
{'name': {'$exists': True}}) {'name': {'$exists': True}})
def test_find(self): def test_find(self):
@ -134,7 +137,7 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(person.name, "User B") self.assertEqual(person.name, "User B")
self.assertRaises(IndexError, self.Person.objects.__getitem__, 2) self.assertRaises(IndexError, self.Person.objects.__getitem__, 2)
# Find a document using just the object id # Find a document using just the object id
person = self.Person.objects.with_id(person1.id) person = self.Person.objects.with_id(person1.id)
self.assertEqual(person.name, "User A") self.assertEqual(person.name, "User A")
@ -170,7 +173,7 @@ class QuerySetTest(unittest.TestCase):
person2.save() person2.save()
# Retrieve the first person from the database # Retrieve the first person from the database
self.assertRaises(MultipleObjectsReturned, self.assertRaises(MultipleObjectsReturned,
self.Person.objects.get_or_create) self.Person.objects.get_or_create)
# Use a query to filter the people found to just person2 # Use a query to filter the people found to just person2
@ -244,36 +247,36 @@ class QuerySetTest(unittest.TestCase):
"""Ensure filters can be chained together. """Ensure filters can be chained together.
""" """
from datetime import datetime from datetime import datetime
class BlogPost(Document): class BlogPost(Document):
title = StringField() title = StringField()
is_published = BooleanField() is_published = BooleanField()
published_date = DateTimeField() published_date = DateTimeField()
@queryset_manager @queryset_manager
def published(doc_cls, queryset): def published(doc_cls, queryset):
return queryset(is_published=True) return queryset(is_published=True)
blog_post_1 = BlogPost(title="Blog Post #1", blog_post_1 = BlogPost(title="Blog Post #1",
is_published = True, is_published = True,
published_date=datetime(2010, 1, 5, 0, 0 ,0)) published_date=datetime(2010, 1, 5, 0, 0 ,0))
blog_post_2 = BlogPost(title="Blog Post #2", blog_post_2 = BlogPost(title="Blog Post #2",
is_published = True, is_published = True,
published_date=datetime(2010, 1, 6, 0, 0 ,0)) published_date=datetime(2010, 1, 6, 0, 0 ,0))
blog_post_3 = BlogPost(title="Blog Post #3", blog_post_3 = BlogPost(title="Blog Post #3",
is_published = True, is_published = True,
published_date=datetime(2010, 1, 7, 0, 0 ,0)) published_date=datetime(2010, 1, 7, 0, 0 ,0))
blog_post_1.save() blog_post_1.save()
blog_post_2.save() blog_post_2.save()
blog_post_3.save() blog_post_3.save()
# find all published blog posts before 2010-01-07 # find all published blog posts before 2010-01-07
published_posts = BlogPost.published() published_posts = BlogPost.published()
published_posts = published_posts.filter( published_posts = published_posts.filter(
published_date__lt=datetime(2010, 1, 7, 0, 0 ,0)) published_date__lt=datetime(2010, 1, 7, 0, 0 ,0))
self.assertEqual(published_posts.count(), 2) self.assertEqual(published_posts.count(), 2)
BlogPost.drop_collection() BlogPost.drop_collection()
def test_ordering(self): def test_ordering(self):
@ -289,22 +292,22 @@ class QuerySetTest(unittest.TestCase):
BlogPost.drop_collection() BlogPost.drop_collection()
blog_post_1 = BlogPost(title="Blog Post #1", blog_post_1 = BlogPost(title="Blog Post #1",
published_date=datetime(2010, 1, 5, 0, 0 ,0)) published_date=datetime(2010, 1, 5, 0, 0 ,0))
blog_post_2 = BlogPost(title="Blog Post #2", blog_post_2 = BlogPost(title="Blog Post #2",
published_date=datetime(2010, 1, 6, 0, 0 ,0)) published_date=datetime(2010, 1, 6, 0, 0 ,0))
blog_post_3 = BlogPost(title="Blog Post #3", blog_post_3 = BlogPost(title="Blog Post #3",
published_date=datetime(2010, 1, 7, 0, 0 ,0)) published_date=datetime(2010, 1, 7, 0, 0 ,0))
blog_post_1.save() blog_post_1.save()
blog_post_2.save() blog_post_2.save()
blog_post_3.save() blog_post_3.save()
# get the "first" BlogPost using default ordering # get the "first" BlogPost using default ordering
# from BlogPost.meta.ordering # from BlogPost.meta.ordering
latest_post = BlogPost.objects.first() latest_post = BlogPost.objects.first()
self.assertEqual(latest_post.title, "Blog Post #3") self.assertEqual(latest_post.title, "Blog Post #3")
# override default ordering, order BlogPosts by "published_date" # override default ordering, order BlogPosts by "published_date"
first_post = BlogPost.objects.order_by("+published_date").first() first_post = BlogPost.objects.order_by("+published_date").first()
self.assertEqual(first_post.title, "Blog Post #1") self.assertEqual(first_post.title, "Blog Post #1")
@ -363,7 +366,7 @@ class QuerySetTest(unittest.TestCase):
result = BlogPost.objects.first() result = BlogPost.objects.first()
self.assertTrue(isinstance(result.author, User)) self.assertTrue(isinstance(result.author, User))
self.assertEqual(result.author.name, 'Test User') self.assertEqual(result.author.name, 'Test User')
BlogPost.drop_collection() BlogPost.drop_collection()
def test_find_dict_item(self): def test_find_dict_item(self):
@ -430,7 +433,7 @@ class QuerySetTest(unittest.TestCase):
self.Person(name='user2', age=20).save() self.Person(name='user2', age=20).save()
self.Person(name='user3', age=30).save() self.Person(name='user3', age=30).save()
self.Person(name='user4', age=40).save() self.Person(name='user4', age=40).save()
self.assertEqual(len(self.Person.objects(Q(age__in=[20]))), 2) self.assertEqual(len(self.Person.objects(Q(age__in=[20]))), 2)
self.assertEqual(len(self.Person.objects(Q(age__in=[20, 30]))), 3) self.assertEqual(len(self.Person.objects(Q(age__in=[20, 30]))), 3)
@ -615,10 +618,167 @@ class QuerySetTest(unittest.TestCase):
names = [p.name for p in self.Person.objects.order_by('age')] names = [p.name for p in self.Person.objects.order_by('age')]
self.assertEqual(names, ['User A', 'User C', 'User B']) self.assertEqual(names, ['User A', 'User C', 'User B'])
ages = [p.age for p in self.Person.objects.order_by('-name')] ages = [p.age for p in self.Person.objects.order_by('-name')]
self.assertEqual(ages, [30, 40, 20]) self.assertEqual(ages, [30, 40, 20])
def test_map_reduce(self):
"""Ensure map/reduce is both mapping and reducing.
"""
class Song(Document):
artists = ListField(StringField())
title = StringField()
is_cover = BooleanField()
Song.drop_collection()
Song(title="Gloria", is_cover=True, artists=['Patti Smith']).save()
Song(title="Redondo beach", is_cover=False,
artists=['Patti Smith']).save()
Song(title="My Generation", is_cover=True,
artists=['Patti Smith', 'John Cale']).save()
map_f = """
function() {
this.artists.forEach(function(artist) {
emit(artist, 1);
});
}
"""
reduce_f = """
function(key, values) {
var total = 0;
for(var i=0; i<values.length; i++) {
total += values[i];
}
return total;
}
"""
# ensure both artists are found
results = Song.objects.map_reduce(map_f, reduce_f)
results = list(results)
self.assertEqual(len(results), 2)
# query for a count of Songs per artist, ordered by -count.
# Patti Smith has 3 song credits, and should therefore be first.
results = Song.objects.order_by("-value").map_reduce(map_f, reduce_f)
results = list(results)
self.assertEqual(results[0].key, "Patti Smith")
self.assertEqual(results[0].value, 3.0)
Song.drop_collection()
def test_map_reduce_finalize(self):
"""Ensure that map, reduce, and finalize run and introduce "scope"
by simulating "hotness" ranking with Reddit algorithm.
"""
from time import mktime
class Link(Document):
title = StringField()
up_votes = IntField()
down_votes = IntField()
submitted = DateTimeField()
Link.drop_collection()
now = datetime.utcnow()
# Note: Test data taken from a custom Reddit homepage on
# Fri, 12 Feb 2010 14:36:00 -0600. Link ordering should
# reflect order of insertion below.
Link(title = "Google Buzz auto-followed a woman's abusive ex ...",
up_votes = 1079,
down_votes = 553,
submitted = now-timedelta(hours=4)).save()
Link(title = "We did it! Barbie is a computer engineer.",
up_votes = 481,
down_votes = 124,
submitted = now-timedelta(hours=2)).save()
Link(title = "This Is A Mosquito Getting Killed By A Laser",
up_votes = 1446,
down_votes = 530,
submitted=now-timedelta(hours=13)).save()
Link(title = "Arabic flashcards land physics student in jail.",
up_votes = 215,
down_votes = 105,
submitted = now-timedelta(hours=6)).save()
Link(title = "The Burger Lab: Presenting, the Flood Burger",
up_votes = 48,
down_votes = 17,
submitted = now-timedelta(hours=5)).save()
Link(title="How to see polarization with the naked eye",
up_votes = 74,
down_votes = 13,
submitted = now-timedelta(hours=10)).save()
map_f = """
function() {
emit(this._id, {up_delta: this.up_votes - this.down_votes,
sub_date: this.submitted.getTime() / 1000})
}
"""
reduce_f = """
function(key, values) {
data = values[0];
x = data.up_delta;
// calculate time diff between reddit epoch and submission
sec_since_epoch = data.sub_date - reddit_epoch;
// calculate 'Y'
if(x > 0) {
y = 1;
} else if (x = 0) {
y = 0;
} else {
y = -1;
}
// calculate 'Z', the maximal value
if(Math.abs(x) >= 1) {
z = Math.abs(x);
} else {
z = 1;
}
return {x: x, y: y, z: z, t_s: sec_since_epoch};
}
"""
finalize_f = """
function(key, value) {
// f(sec_since_epoch,y,z) = log10(z) + ((y*sec_since_epoch) / 45000)
z_10 = Math.log(value.z) / Math.log(10);
weight = z_10 + ((value.y * value.t_s) / 45000);
return weight;
}
"""
reddit_epoch = mktime(datetime(2005, 12, 8, 7, 46, 43).timetuple())
scope = {'reddit_epoch': reddit_epoch}
# ensure both artists are found
results = Link.objects.order_by("-value")
results = results.map_reduce(map_f,
reduce_f,
finalize_f=finalize_f,
scope=scope)
results = list(results)
# assert troublesome Buzz article is ranked 1st
self.assertTrue(results[0].object.title.startswith("Google Buzz"))
# assert laser vision is ranked last
self.assertTrue(results[-1].object.title.startswith("How to see"))
Link.drop_collection()
def test_item_frequencies(self): def test_item_frequencies(self):
"""Ensure that item frequencies are properly generated from lists. """Ensure that item frequencies are properly generated from lists.
""" """
@ -715,20 +875,20 @@ class QuerySetTest(unittest.TestCase):
title = StringField(name='postTitle') title = StringField(name='postTitle')
comments = ListField(EmbeddedDocumentField(Comment), comments = ListField(EmbeddedDocumentField(Comment),
name='postComments') name='postComments')
BlogPost.drop_collection() BlogPost.drop_collection()
data = {'title': 'Post 1', 'comments': [Comment(content='test')]} data = {'title': 'Post 1', 'comments': [Comment(content='test')]}
BlogPost(**data).save() BlogPost(**data).save()
self.assertTrue('postTitle' in self.assertTrue('postTitle' in
BlogPost.objects(title=data['title'])._query) BlogPost.objects(title=data['title'])._query)
self.assertFalse('title' in self.assertFalse('title' in
BlogPost.objects(title=data['title'])._query) BlogPost.objects(title=data['title'])._query)
self.assertEqual(len(BlogPost.objects(title=data['title'])), 1) self.assertEqual(len(BlogPost.objects(title=data['title'])), 1)
self.assertTrue('postComments.commentContent' in self.assertTrue('postComments.commentContent' in
BlogPost.objects(comments__content='test')._query) BlogPost.objects(comments__content='test')._query)
self.assertEqual(len(BlogPost.objects(comments__content='test')), 1) self.assertEqual(len(BlogPost.objects(comments__content='test')), 1)
@ -749,7 +909,7 @@ class QuerySetTest(unittest.TestCase):
post.save() post.save()
# Test that query may be performed by providing a document as a value # Test that query may be performed by providing a document as a value
# while using a ReferenceField's name - the document should be # while using a ReferenceField's name - the document should be
# converted to an DBRef, which is legal, unlike a Document object # converted to an DBRef, which is legal, unlike a Document object
post_obj = BlogPost.objects(author=person).first() post_obj = BlogPost.objects(author=person).first()
self.assertEqual(post.id, post_obj.id) self.assertEqual(post.id, post_obj.id)
@ -852,7 +1012,7 @@ class QuerySetTest(unittest.TestCase):
class QTest(unittest.TestCase): class QTest(unittest.TestCase):
def test_or_and(self): def test_or_and(self):
"""Ensure that Q objects may be combined correctly. """Ensure that Q objects may be combined correctly.
""" """
@ -876,8 +1036,8 @@ class QTest(unittest.TestCase):
examples = [ examples = [
({'name': 'test'}, 'this.name == i0f0', {'i0f0': 'test'}), ({'name': 'test'}, 'this.name == i0f0', {'i0f0': 'test'}),
({'age': {'$gt': 18}}, 'this.age > i0f0o0', {'i0f0o0': 18}), ({'age': {'$gt': 18}}, 'this.age > i0f0o0', {'i0f0o0': 18}),
({'name': 'test', 'age': {'$gt': 18, '$lte': 65}}, ({'name': 'test', 'age': {'$gt': 18, '$lte': 65}},
'this.age <= i0f0o0 && this.age > i0f0o1 && this.name == i0f1', 'this.age <= i0f0o0 && this.age > i0f0o1 && this.name == i0f1',
{'i0f0o0': 65, 'i0f0o1': 18, 'i0f1': 'test'}), {'i0f0o0': 65, 'i0f0o1': 18, 'i0f1': 'test'}),
] ]
for item, js, scope in examples: for item, js, scope in examples: