Fix BaseQuerySet#sum and BaseQuerySet#average for fields that specify a db_field
This commit is contained in:
		| @@ -1271,9 +1271,10 @@ class BaseQuerySet(object): | |||||||
|         :param field: the field to sum over; use dot notation to refer to |         :param field: the field to sum over; use dot notation to refer to | ||||||
|             embedded document fields |             embedded document fields | ||||||
|         """ |         """ | ||||||
|  |         db_field = self._fields_to_dbfields([field]).pop() | ||||||
|         pipeline = [ |         pipeline = [ | ||||||
|             {'$match': self._query}, |             {'$match': self._query}, | ||||||
|             {'$group': {'_id': 'sum', 'total': {'$sum': '$' + field}}} |             {'$group': {'_id': 'sum', 'total': {'$sum': '$' + db_field}}} | ||||||
|         ] |         ] | ||||||
|  |  | ||||||
|         # if we're performing a sum over a list field, we sum up all the |         # if we're performing a sum over a list field, we sum up all the | ||||||
| @@ -1300,9 +1301,10 @@ class BaseQuerySet(object): | |||||||
|         :param field: the field to average over; use dot notation to refer to |         :param field: the field to average over; use dot notation to refer to | ||||||
|             embedded document fields |             embedded document fields | ||||||
|         """ |         """ | ||||||
|  |         db_field = self._fields_to_dbfields([field]).pop() | ||||||
|         pipeline = [ |         pipeline = [ | ||||||
|             {'$match': self._query}, |             {'$match': self._query}, | ||||||
|             {'$group': {'_id': 'avg', 'total': {'$avg': '$' + field}}} |             {'$group': {'_id': 'avg', 'total': {'$avg': '$' + db_field}}} | ||||||
|         ] |         ] | ||||||
|  |  | ||||||
|         # if we're performing an average over a list field, we average out |         # if we're performing an average over a list field, we average out | ||||||
|   | |||||||
| @@ -2838,6 +2838,34 @@ class QuerySetTest(unittest.TestCase): | |||||||
|             sum([a for a in ages if a >= 50]) |             sum([a for a in ages if a >= 50]) | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |     def test_sum_over_db_field(self): | ||||||
|  |         """Ensure that a field mapped to a db field with a different name | ||||||
|  |         can be summed over correctly. | ||||||
|  |         """ | ||||||
|  |         class UserVisit(Document): | ||||||
|  |             num_visits = IntField(db_field='visits') | ||||||
|  |  | ||||||
|  |         UserVisit.drop_collection() | ||||||
|  |  | ||||||
|  |         UserVisit.objects.create(num_visits=10) | ||||||
|  |         UserVisit.objects.create(num_visits=5) | ||||||
|  |  | ||||||
|  |         self.assertEqual(UserVisit.objects.sum('num_visits'), 15) | ||||||
|  |  | ||||||
|  |     def test_average_over_db_field(self): | ||||||
|  |         """Ensure that a field mapped to a db field with a different name | ||||||
|  |         can have its average computed correctly. | ||||||
|  |         """ | ||||||
|  |         class UserVisit(Document): | ||||||
|  |             num_visits = IntField(db_field='visits') | ||||||
|  |  | ||||||
|  |         UserVisit.drop_collection() | ||||||
|  |  | ||||||
|  |         UserVisit.objects.create(num_visits=20) | ||||||
|  |         UserVisit.objects.create(num_visits=10) | ||||||
|  |  | ||||||
|  |         self.assertEqual(UserVisit.objects.average('num_visits'), 15) | ||||||
|  |  | ||||||
|     def test_embedded_average(self): |     def test_embedded_average(self): | ||||||
|         class Pay(EmbeddedDocument): |         class Pay(EmbeddedDocument): | ||||||
|             value = DecimalField() |             value = DecimalField() | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user