From dd095279c86d55f0497f751f8ca4cb56aeb36d55 Mon Sep 17 00:00:00 2001 From: Stefan Wojcik Date: Sun, 7 Jun 2015 14:15:23 -0700 Subject: [PATCH 1/3] aggregate_sum/average + unit tests --- mongoengine/queryset/base.py | 34 ++++++++++++++++++ tests/queryset/queryset.py | 68 ++++++++++++++++++++++++++++++++---- 2 files changed, 96 insertions(+), 6 deletions(-) diff --git a/mongoengine/queryset/base.py b/mongoengine/queryset/base.py index c8a30783..38389fbf 100644 --- a/mongoengine/queryset/base.py +++ b/mongoengine/queryset/base.py @@ -1248,6 +1248,23 @@ class BaseQuerySet(object): else: return 0 + def aggregate_sum(self, field): + """Sum over the values of the specified field. + + :param field: the field to sum over; use dot-notation to refer to + embedded document fields + + This method is more performant than the regular `sum`, because it uses + the aggregation framework instead of map-reduce. + """ + result = self._document._get_collection().aggregate([ + { '$match': self._query }, + { '$group': { '_id': 'sum', 'total': { '$sum': '$' + field } } } + ]) + if result['result']: + return result['result'][0]['total'] + return 0 + def average(self, field): """Average over the values of the specified field. @@ -1303,6 +1320,23 @@ class BaseQuerySet(object): else: return 0 + def aggregate_average(self, field): + """Average over the values of the specified field. + + :param field: the field to average over; use dot-notation to refer to + embedded document fields + + This method is more performant than the regular `average`, because it + uses the aggregation framework instead of map-reduce. + """ + result = self._document._get_collection().aggregate([ + { '$match': self._query }, + { '$group': { '_id': 'avg', 'total': { '$avg': '$' + field } } } + ]) + if result['result']: + return result['result'][0]['total'] + return 0 + def item_frequencies(self, field, normalize=False, map_reduce=True): """Returns a dictionary of all items present in a field across the whole queried set of documents, and their corresponding frequency. diff --git a/tests/queryset/queryset.py b/tests/queryset/queryset.py index e7eb4901..d4348678 100644 --- a/tests/queryset/queryset.py +++ b/tests/queryset/queryset.py @@ -2706,26 +2706,58 @@ class QuerySetTest(unittest.TestCase): avg = float(sum(ages)) / (len(ages) + 1) # take into account the 0 self.assertAlmostEqual(int(self.Person.objects.average('age')), avg) + self.assertAlmostEqual( + int(self.Person.objects.aggregate_average('age')), avg + ) self.Person(name='ageless person').save() self.assertEqual(int(self.Person.objects.average('age')), avg) + self.assertEqual( + int(self.Person.objects.aggregate_average('age')), avg + ) # dot notation self.Person( name='person meta', person_meta=self.PersonMeta(weight=0)).save() self.assertAlmostEqual( int(self.Person.objects.average('person_meta.weight')), 0) + self.assertAlmostEqual( + int(self.Person.objects.aggregate_average('person_meta.weight')), + 0 + ) for i, weight in enumerate(ages): self.Person( name='test meta%i', person_meta=self.PersonMeta(weight=weight)).save() self.assertAlmostEqual( - int(self.Person.objects.average('person_meta.weight')), avg) + int(self.Person.objects.average('person_meta.weight')), avg + ) + self.assertAlmostEqual( + int(self.Person.objects.aggregate_average('person_meta.weight')), + avg + ) self.Person(name='test meta none').save() self.assertEqual( - int(self.Person.objects.average('person_meta.weight')), avg) + int(self.Person.objects.average('person_meta.weight')), avg + ) + self.assertEqual( + int(self.Person.objects.aggregate_average('person_meta.weight')), + avg + ) + + # test summing over a filtered queryset + over_50 = [a for a in ages if a >= 50] + avg = float(sum(over_50)) / len(over_50) + self.assertEqual( + self.Person.objects.filter(age__gte=50).average('age'), + avg + ) + self.assertEqual( + self.Person.objects.filter(age__gte=50).aggregate_average('age'), + avg + ) def test_sum(self): """Ensure that field can be summed over correctly. @@ -2734,20 +2766,44 @@ class QuerySetTest(unittest.TestCase): for i, age in enumerate(ages): self.Person(name='test%s' % i, age=age).save() - self.assertEqual(int(self.Person.objects.sum('age')), sum(ages)) + self.assertEqual(self.Person.objects.sum('age'), sum(ages)) + self.assertEqual( + self.Person.objects.aggregate_sum('age'), sum(ages) + ) self.Person(name='ageless person').save() - self.assertEqual(int(self.Person.objects.sum('age')), sum(ages)) + self.assertEqual(self.Person.objects.sum('age'), sum(ages)) + self.assertEqual( + self.Person.objects.aggregate_sum('age'), sum(ages) + ) for i, age in enumerate(ages): self.Person(name='test meta%s' % i, person_meta=self.PersonMeta(weight=age)).save() self.assertEqual( - int(self.Person.objects.sum('person_meta.weight')), sum(ages)) + self.Person.objects.sum('person_meta.weight'), sum(ages) + ) + self.assertEqual( + self.Person.objects.aggregate_sum('person_meta.weight'), + sum(ages) + ) self.Person(name='weightless person').save() - self.assertEqual(int(self.Person.objects.sum('age')), sum(ages)) + self.assertEqual(self.Person.objects.sum('age'), sum(ages)) + self.assertEqual( + self.Person.objects.aggregate_sum('age'), sum(ages) + ) + + # test summing over a filtered queryset + self.assertEqual( + self.Person.objects.filter(age__gte=50).sum('age'), + sum([a for a in ages if a >= 50]) + ) + self.assertEqual( + self.Person.objects.filter(age__gte=50).aggregate_sum('age'), + sum([a for a in ages if a >= 50]) + ) def test_embedded_average(self): class Pay(EmbeddedDocument): From 12337802657d7e6d03a1fd19e7f412d54b3ef05d Mon Sep 17 00:00:00 2001 From: Stefan Wojcik Date: Mon, 8 Jun 2015 13:46:19 -0700 Subject: [PATCH 2/3] make aggregate_sum/average compatible with pymongo 3.x --- mongoengine/queryset/base.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/mongoengine/queryset/base.py b/mongoengine/queryset/base.py index 38389fbf..b949e121 100644 --- a/mongoengine/queryset/base.py +++ b/mongoengine/queryset/base.py @@ -1261,8 +1261,13 @@ class BaseQuerySet(object): { '$match': self._query }, { '$group': { '_id': 'sum', 'total': { '$sum': '$' + field } } } ]) - if result['result']: - return result['result'][0]['total'] + if IS_PYMONGO_3: + result = list(result) + if result: + return result[0]['total'] + else: + if result['result']: + return result['result'][0]['total'] return 0 def average(self, field): @@ -1333,8 +1338,13 @@ class BaseQuerySet(object): { '$match': self._query }, { '$group': { '_id': 'avg', 'total': { '$avg': '$' + field } } } ]) - if result['result']: - return result['result'][0]['total'] + if IS_PYMONGO_3: + result = list(result) + if result: + return result[0]['total'] + else: + if result['result']: + return result['result'][0]['total'] return 0 def item_frequencies(self, field, normalize=False, map_reduce=True): From b7ef82cb67d11787f7b305028690c542ad048301 Mon Sep 17 00:00:00 2001 From: Stefan Wojcik Date: Wed, 17 Jun 2015 17:05:10 -0700 Subject: [PATCH 3/3] style tweaks + changelog entry --- docs/changelog.rst | 1 + mongoengine/queryset/base.py | 14 ++++++-------- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 48e8b9aa..b9ad5b0e 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -25,6 +25,7 @@ Changes in 0.9.X - DEV - Updated URL and Email Field regex validators, added schemes argument to URLField validation. #652 - Removed get_or_create() deprecated since 0.8.0. #300 - Capped collection multiple of 256. #1011 +- Added `BaseQuerySet.aggregate_sum` and `BaseQuerySet.aggregate_average` methods. Changes in 0.9.0 ================ diff --git a/mongoengine/queryset/base.py b/mongoengine/queryset/base.py index b949e121..c3abd46a 100644 --- a/mongoengine/queryset/base.py +++ b/mongoengine/queryset/base.py @@ -1263,11 +1263,10 @@ class BaseQuerySet(object): ]) if IS_PYMONGO_3: result = list(result) - if result: - return result[0]['total'] else: - if result['result']: - return result['result'][0]['total'] + result = result.get('result') + if result: + return result[0]['total'] return 0 def average(self, field): @@ -1340,11 +1339,10 @@ class BaseQuerySet(object): ]) if IS_PYMONGO_3: result = list(result) - if result: - return result[0]['total'] else: - if result['result']: - return result['result'][0]['total'] + result = result.get('result') + if result: + return result[0]['total'] return 0 def item_frequencies(self, field, normalize=False, map_reduce=True):