diff --git a/mongoengine/queryset/base.py b/mongoengine/queryset/base.py index c8a30783..38389fbf 100644 --- a/mongoengine/queryset/base.py +++ b/mongoengine/queryset/base.py @@ -1248,6 +1248,23 @@ class BaseQuerySet(object): else: return 0 + def aggregate_sum(self, field): + """Sum over the values of the specified field. + + :param field: the field to sum over; use dot-notation to refer to + embedded document fields + + This method is more performant than the regular `sum`, because it uses + the aggregation framework instead of map-reduce. + """ + result = self._document._get_collection().aggregate([ + { '$match': self._query }, + { '$group': { '_id': 'sum', 'total': { '$sum': '$' + field } } } + ]) + if result['result']: + return result['result'][0]['total'] + return 0 + def average(self, field): """Average over the values of the specified field. @@ -1303,6 +1320,23 @@ class BaseQuerySet(object): else: return 0 + def aggregate_average(self, field): + """Average over the values of the specified field. + + :param field: the field to average over; use dot-notation to refer to + embedded document fields + + This method is more performant than the regular `average`, because it + uses the aggregation framework instead of map-reduce. + """ + result = self._document._get_collection().aggregate([ + { '$match': self._query }, + { '$group': { '_id': 'avg', 'total': { '$avg': '$' + field } } } + ]) + if result['result']: + return result['result'][0]['total'] + return 0 + def item_frequencies(self, field, normalize=False, map_reduce=True): """Returns a dictionary of all items present in a field across the whole queried set of documents, and their corresponding frequency. diff --git a/tests/queryset/queryset.py b/tests/queryset/queryset.py index e7eb4901..d4348678 100644 --- a/tests/queryset/queryset.py +++ b/tests/queryset/queryset.py @@ -2706,26 +2706,58 @@ class QuerySetTest(unittest.TestCase): avg = float(sum(ages)) / (len(ages) + 1) # take into account the 0 self.assertAlmostEqual(int(self.Person.objects.average('age')), avg) + self.assertAlmostEqual( + int(self.Person.objects.aggregate_average('age')), avg + ) self.Person(name='ageless person').save() self.assertEqual(int(self.Person.objects.average('age')), avg) + self.assertEqual( + int(self.Person.objects.aggregate_average('age')), avg + ) # dot notation self.Person( name='person meta', person_meta=self.PersonMeta(weight=0)).save() self.assertAlmostEqual( int(self.Person.objects.average('person_meta.weight')), 0) + self.assertAlmostEqual( + int(self.Person.objects.aggregate_average('person_meta.weight')), + 0 + ) for i, weight in enumerate(ages): self.Person( name='test meta%i', person_meta=self.PersonMeta(weight=weight)).save() self.assertAlmostEqual( - int(self.Person.objects.average('person_meta.weight')), avg) + int(self.Person.objects.average('person_meta.weight')), avg + ) + self.assertAlmostEqual( + int(self.Person.objects.aggregate_average('person_meta.weight')), + avg + ) self.Person(name='test meta none').save() self.assertEqual( - int(self.Person.objects.average('person_meta.weight')), avg) + int(self.Person.objects.average('person_meta.weight')), avg + ) + self.assertEqual( + int(self.Person.objects.aggregate_average('person_meta.weight')), + avg + ) + + # test summing over a filtered queryset + over_50 = [a for a in ages if a >= 50] + avg = float(sum(over_50)) / len(over_50) + self.assertEqual( + self.Person.objects.filter(age__gte=50).average('age'), + avg + ) + self.assertEqual( + self.Person.objects.filter(age__gte=50).aggregate_average('age'), + avg + ) def test_sum(self): """Ensure that field can be summed over correctly. @@ -2734,20 +2766,44 @@ class QuerySetTest(unittest.TestCase): for i, age in enumerate(ages): self.Person(name='test%s' % i, age=age).save() - self.assertEqual(int(self.Person.objects.sum('age')), sum(ages)) + self.assertEqual(self.Person.objects.sum('age'), sum(ages)) + self.assertEqual( + self.Person.objects.aggregate_sum('age'), sum(ages) + ) self.Person(name='ageless person').save() - self.assertEqual(int(self.Person.objects.sum('age')), sum(ages)) + self.assertEqual(self.Person.objects.sum('age'), sum(ages)) + self.assertEqual( + self.Person.objects.aggregate_sum('age'), sum(ages) + ) for i, age in enumerate(ages): self.Person(name='test meta%s' % i, person_meta=self.PersonMeta(weight=age)).save() self.assertEqual( - int(self.Person.objects.sum('person_meta.weight')), sum(ages)) + self.Person.objects.sum('person_meta.weight'), sum(ages) + ) + self.assertEqual( + self.Person.objects.aggregate_sum('person_meta.weight'), + sum(ages) + ) self.Person(name='weightless person').save() - self.assertEqual(int(self.Person.objects.sum('age')), sum(ages)) + self.assertEqual(self.Person.objects.sum('age'), sum(ages)) + self.assertEqual( + self.Person.objects.aggregate_sum('age'), sum(ages) + ) + + # test summing over a filtered queryset + self.assertEqual( + self.Person.objects.filter(age__gte=50).sum('age'), + sum([a for a in ages if a >= 50]) + ) + self.assertEqual( + self.Person.objects.filter(age__gte=50).aggregate_sum('age'), + sum([a for a in ages if a >= 50]) + ) def test_embedded_average(self): class Pay(EmbeddedDocument):