Merge pull request #1021 from elasticsales/aggregate-sum-and-avg
aggregate_sum/average + unit tests
This commit is contained in:
		| @@ -25,6 +25,7 @@ Changes in 0.9.X - DEV | |||||||
| - Updated URL and Email Field regex validators, added schemes argument to URLField validation. #652 | - Updated URL and Email Field regex validators, added schemes argument to URLField validation. #652 | ||||||
| - Removed get_or_create() deprecated since 0.8.0. #300 | - Removed get_or_create() deprecated since 0.8.0. #300 | ||||||
| - Capped collection multiple of 256. #1011 | - Capped collection multiple of 256. #1011 | ||||||
|  | - Added `BaseQuerySet.aggregate_sum` and `BaseQuerySet.aggregate_average` methods. #1021 | ||||||
|  |  | ||||||
| Changes in 0.9.0 | Changes in 0.9.0 | ||||||
| ================ | ================ | ||||||
|   | |||||||
| @@ -1248,6 +1248,27 @@ class BaseQuerySet(object): | |||||||
|         else: |         else: | ||||||
|             return 0 |             return 0 | ||||||
|  |  | ||||||
|  |     def aggregate_sum(self, field): | ||||||
|  |         """Sum over the values of the specified field. | ||||||
|  |  | ||||||
|  |         :param field: the field to sum over; use dot-notation to refer to | ||||||
|  |             embedded document fields | ||||||
|  |  | ||||||
|  |         This method is more performant than the regular `sum`, because it uses | ||||||
|  |         the aggregation framework instead of map-reduce. | ||||||
|  |         """ | ||||||
|  |         result = self._document._get_collection().aggregate([ | ||||||
|  |             { '$match': self._query }, | ||||||
|  |             { '$group': { '_id': 'sum', 'total': { '$sum': '$' + field } } } | ||||||
|  |         ]) | ||||||
|  |         if IS_PYMONGO_3: | ||||||
|  |             result = list(result) | ||||||
|  |         else: | ||||||
|  |             result = result.get('result') | ||||||
|  |         if result: | ||||||
|  |             return result[0]['total'] | ||||||
|  |         return 0 | ||||||
|  |  | ||||||
|     def average(self, field): |     def average(self, field): | ||||||
|         """Average over the values of the specified field. |         """Average over the values of the specified field. | ||||||
|  |  | ||||||
| @@ -1303,6 +1324,27 @@ class BaseQuerySet(object): | |||||||
|         else: |         else: | ||||||
|             return 0 |             return 0 | ||||||
|  |  | ||||||
|  |     def aggregate_average(self, field): | ||||||
|  |         """Average over the values of the specified field. | ||||||
|  |  | ||||||
|  |         :param field: the field to average over; use dot-notation to refer to | ||||||
|  |             embedded document fields | ||||||
|  |  | ||||||
|  |         This method is more performant than the regular `average`, because it | ||||||
|  |         uses the aggregation framework instead of map-reduce. | ||||||
|  |         """ | ||||||
|  |         result = self._document._get_collection().aggregate([ | ||||||
|  |             { '$match': self._query }, | ||||||
|  |             { '$group': { '_id': 'avg', 'total': { '$avg': '$' + field } } } | ||||||
|  |         ]) | ||||||
|  |         if IS_PYMONGO_3: | ||||||
|  |             result = list(result) | ||||||
|  |         else: | ||||||
|  |             result = result.get('result') | ||||||
|  |         if result: | ||||||
|  |             return result[0]['total'] | ||||||
|  |         return 0 | ||||||
|  |  | ||||||
|     def item_frequencies(self, field, normalize=False, map_reduce=True): |     def item_frequencies(self, field, normalize=False, map_reduce=True): | ||||||
|         """Returns a dictionary of all items present in a field across |         """Returns a dictionary of all items present in a field across | ||||||
|         the whole queried set of documents, and their corresponding frequency. |         the whole queried set of documents, and their corresponding frequency. | ||||||
|   | |||||||
| @@ -2706,26 +2706,58 @@ class QuerySetTest(unittest.TestCase): | |||||||
|  |  | ||||||
|         avg = float(sum(ages)) / (len(ages) + 1)  # take into account the 0 |         avg = float(sum(ages)) / (len(ages) + 1)  # take into account the 0 | ||||||
|         self.assertAlmostEqual(int(self.Person.objects.average('age')), avg) |         self.assertAlmostEqual(int(self.Person.objects.average('age')), avg) | ||||||
|  |         self.assertAlmostEqual( | ||||||
|  |             int(self.Person.objects.aggregate_average('age')), avg | ||||||
|  |         ) | ||||||
|  |  | ||||||
|         self.Person(name='ageless person').save() |         self.Person(name='ageless person').save() | ||||||
|         self.assertEqual(int(self.Person.objects.average('age')), avg) |         self.assertEqual(int(self.Person.objects.average('age')), avg) | ||||||
|  |         self.assertEqual( | ||||||
|  |             int(self.Person.objects.aggregate_average('age')), avg | ||||||
|  |         ) | ||||||
|  |  | ||||||
|         # dot notation |         # dot notation | ||||||
|         self.Person( |         self.Person( | ||||||
|             name='person meta', person_meta=self.PersonMeta(weight=0)).save() |             name='person meta', person_meta=self.PersonMeta(weight=0)).save() | ||||||
|         self.assertAlmostEqual( |         self.assertAlmostEqual( | ||||||
|             int(self.Person.objects.average('person_meta.weight')), 0) |             int(self.Person.objects.average('person_meta.weight')), 0) | ||||||
|  |         self.assertAlmostEqual( | ||||||
|  |             int(self.Person.objects.aggregate_average('person_meta.weight')), | ||||||
|  |             0 | ||||||
|  |         ) | ||||||
|  |  | ||||||
|         for i, weight in enumerate(ages): |         for i, weight in enumerate(ages): | ||||||
|             self.Person( |             self.Person( | ||||||
|                 name='test meta%i', person_meta=self.PersonMeta(weight=weight)).save() |                 name='test meta%i', person_meta=self.PersonMeta(weight=weight)).save() | ||||||
|  |  | ||||||
|         self.assertAlmostEqual( |         self.assertAlmostEqual( | ||||||
|             int(self.Person.objects.average('person_meta.weight')), avg) |             int(self.Person.objects.average('person_meta.weight')), avg | ||||||
|  |         ) | ||||||
|  |         self.assertAlmostEqual( | ||||||
|  |             int(self.Person.objects.aggregate_average('person_meta.weight')), | ||||||
|  |             avg | ||||||
|  |         ) | ||||||
|  |  | ||||||
|         self.Person(name='test meta none').save() |         self.Person(name='test meta none').save() | ||||||
|         self.assertEqual( |         self.assertEqual( | ||||||
|             int(self.Person.objects.average('person_meta.weight')), avg) |             int(self.Person.objects.average('person_meta.weight')), avg | ||||||
|  |         ) | ||||||
|  |         self.assertEqual( | ||||||
|  |             int(self.Person.objects.aggregate_average('person_meta.weight')), | ||||||
|  |             avg | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         # test averaging over a filtered queryset | ||||||
|  |         over_50 = [a for a in ages if a >= 50] | ||||||
|  |         avg = float(sum(over_50)) / len(over_50) | ||||||
|  |         self.assertEqual( | ||||||
|  |             self.Person.objects.filter(age__gte=50).average('age'), | ||||||
|  |             avg | ||||||
|  |         ) | ||||||
|  |         self.assertEqual( | ||||||
|  |             self.Person.objects.filter(age__gte=50).aggregate_average('age'), | ||||||
|  |             avg | ||||||
|  |         ) | ||||||
|  |  | ||||||
|     def test_sum(self): |     def test_sum(self): | ||||||
|         """Ensure that field can be summed over correctly. |         """Ensure that field can be summed over correctly. | ||||||
| @@ -2734,20 +2766,44 @@ class QuerySetTest(unittest.TestCase): | |||||||
|         for i, age in enumerate(ages): |         for i, age in enumerate(ages): | ||||||
|             self.Person(name='test%s' % i, age=age).save() |             self.Person(name='test%s' % i, age=age).save() | ||||||
|  |  | ||||||
|         self.assertEqual(int(self.Person.objects.sum('age')), sum(ages)) |         self.assertEqual(self.Person.objects.sum('age'), sum(ages)) | ||||||
|  |         self.assertEqual( | ||||||
|  |             self.Person.objects.aggregate_sum('age'), sum(ages) | ||||||
|  |         ) | ||||||
|  |  | ||||||
|         self.Person(name='ageless person').save() |         self.Person(name='ageless person').save() | ||||||
|         self.assertEqual(int(self.Person.objects.sum('age')), sum(ages)) |         self.assertEqual(self.Person.objects.sum('age'), sum(ages)) | ||||||
|  |         self.assertEqual( | ||||||
|  |             self.Person.objects.aggregate_sum('age'), sum(ages) | ||||||
|  |         ) | ||||||
|  |  | ||||||
|         for i, age in enumerate(ages): |         for i, age in enumerate(ages): | ||||||
|             self.Person(name='test meta%s' % |             self.Person(name='test meta%s' % | ||||||
|                         i, person_meta=self.PersonMeta(weight=age)).save() |                         i, person_meta=self.PersonMeta(weight=age)).save() | ||||||
|  |  | ||||||
|         self.assertEqual( |         self.assertEqual( | ||||||
|             int(self.Person.objects.sum('person_meta.weight')), sum(ages)) |             self.Person.objects.sum('person_meta.weight'), sum(ages) | ||||||
|  |         ) | ||||||
|  |         self.assertEqual( | ||||||
|  |             self.Person.objects.aggregate_sum('person_meta.weight'), | ||||||
|  |             sum(ages) | ||||||
|  |         ) | ||||||
|  |  | ||||||
|         self.Person(name='weightless person').save() |         self.Person(name='weightless person').save() | ||||||
|         self.assertEqual(int(self.Person.objects.sum('age')), sum(ages)) |         self.assertEqual(self.Person.objects.sum('age'), sum(ages)) | ||||||
|  |         self.assertEqual( | ||||||
|  |             self.Person.objects.aggregate_sum('age'), sum(ages) | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         # test summing over a filtered queryset | ||||||
|  |         self.assertEqual( | ||||||
|  |             self.Person.objects.filter(age__gte=50).sum('age'), | ||||||
|  |             sum([a for a in ages if a >= 50]) | ||||||
|  |         ) | ||||||
|  |         self.assertEqual( | ||||||
|  |             self.Person.objects.filter(age__gte=50).aggregate_sum('age'), | ||||||
|  |             sum([a for a in ages if a >= 50]) | ||||||
|  |         ) | ||||||
|  |  | ||||||
|     def test_embedded_average(self): |     def test_embedded_average(self): | ||||||
|         class Pay(EmbeddedDocument): |         class Pay(EmbeddedDocument): | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user