diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index 526fe861..8e115a99 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -195,12 +195,37 @@ class QuerySet(object): def __iter__(self): return self + def exec_js(self, code, fields): + """Execute a Javascript function on the server. Two arguments will be + provided by default - the collection name, and the query object. A list + of fields may be provided, which will be translated to their correct + names and supplied as the remaining arguments to the function. + """ + fields = [QuerySet._translate_field_name(self._document, field) + for field in fields] + db = _get_db() + collection = self._document._meta['collection'] + return db.eval(code, collection, self._query, *fields) + + def sum(self, field): + """Sum over the values of the specified field. + """ + sum_func = """ + function(collection, query, sumField) { + var total = 0.0; + db[collection].find(query).forEach(function(doc) { + total += doc[sumField] || 0.0; + }); + return total; + } + """ + return self.exec_js(sum_func, [field]) + def item_frequencies(self, list_field): """Returns a dictionary of all items present in a list field across the whole queried set of documents, and their corresponding frequency. This is useful for generating tag clouds, or searching documents. """ - list_field = QuerySet._translate_field_name(self._document, list_field) freq_func = """ function(collection, query, listField) { var frequencies = {}; @@ -212,9 +237,7 @@ class QuerySet(object): return frequencies; } """ - db = _get_db() - collection = self._document._meta['collection'] - return db.eval(freq_func, collection, self._query, list_field) + return self.exec_js(freq_func, [list_field]) class QuerySetManager(object): diff --git a/tests/queryset.py b/tests/queryset.py index 8d097ca6..b35b8d48 100644 --- a/tests/queryset.py +++ b/tests/queryset.py @@ -214,6 +214,15 @@ class QuerySetTest(unittest.TestCase): BlogPost.drop_collection() + def test_sum(self): + """Ensure that field can be summed over correctly. + """ + ages = [23, 54, 12, 94, 27] + for i, age in enumerate(ages): + self.Person(name='test%s' % i, age=age).save() + + self.assertEqual(int(self.Person.objects.sum('age')), sum(ages)) + def test_custom_manager(self): """Ensure that custom QuerySetManager instances work as expected. """