From 999cdfd9970c2d1c160cef04d13113c8ad3fe85f Mon Sep 17 00:00:00 2001 From: Stefan Wojcik Date: Fri, 2 Dec 2016 11:32:38 -0500 Subject: [PATCH] Fix BaseQuerySet#sum and BaseQuerySet#average for fields that specify a db_field --- mongoengine/queryset/base.py | 6 ++++-- tests/queryset/queryset.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/mongoengine/queryset/base.py b/mongoengine/queryset/base.py index 8b99d69c..bc70b44a 100644 --- a/mongoengine/queryset/base.py +++ b/mongoengine/queryset/base.py @@ -1271,9 +1271,10 @@ class BaseQuerySet(object): :param field: the field to sum over; use dot notation to refer to embedded document fields """ + db_field = self._fields_to_dbfields([field]).pop() pipeline = [ {'$match': self._query}, - {'$group': {'_id': 'sum', 'total': {'$sum': '$' + field}}} + {'$group': {'_id': 'sum', 'total': {'$sum': '$' + db_field}}} ] # if we're performing a sum over a list field, we sum up all the @@ -1300,9 +1301,10 @@ class BaseQuerySet(object): :param field: the field to average over; use dot notation to refer to embedded document fields """ + db_field = self._fields_to_dbfields([field]).pop() pipeline = [ {'$match': self._query}, - {'$group': {'_id': 'avg', 'total': {'$avg': '$' + field}}} + {'$group': {'_id': 'avg', 'total': {'$avg': '$' + db_field}}} ] # if we're performing an average over a list field, we average out diff --git a/tests/queryset/queryset.py b/tests/queryset/queryset.py index 8430dcef..6e2e9e5d 100644 --- a/tests/queryset/queryset.py +++ b/tests/queryset/queryset.py @@ -2838,6 +2838,34 @@ class QuerySetTest(unittest.TestCase): sum([a for a in ages if a >= 50]) ) + def test_sum_over_db_field(self): + """Ensure that a field mapped to a db field with a different name + can be summed over correctly. + """ + class UserVisit(Document): + num_visits = IntField(db_field='visits') + + UserVisit.drop_collection() + + UserVisit.objects.create(num_visits=10) + UserVisit.objects.create(num_visits=5) + + self.assertEqual(UserVisit.objects.sum('num_visits'), 15) + + def test_average_over_db_field(self): + """Ensure that a field mapped to a db field with a different name + can have its average computed correctly. + """ + class UserVisit(Document): + num_visits = IntField(db_field='visits') + + UserVisit.drop_collection() + + UserVisit.objects.create(num_visits=20) + UserVisit.objects.create(num_visits=10) + + self.assertEqual(UserVisit.objects.average('num_visits'), 15) + def test_embedded_average(self): class Pay(EmbeddedDocument): value = DecimalField()