Fix BaseQuerySet#sum and BaseQuerySet#average for fields that specify a db_field
This commit is contained in:
parent
2b7417c728
commit
999cdfd997
@ -1271,9 +1271,10 @@ class BaseQuerySet(object):
|
||||
:param field: the field to sum over; use dot notation to refer to
|
||||
embedded document fields
|
||||
"""
|
||||
db_field = self._fields_to_dbfields([field]).pop()
|
||||
pipeline = [
|
||||
{'$match': self._query},
|
||||
{'$group': {'_id': 'sum', 'total': {'$sum': '$' + field}}}
|
||||
{'$group': {'_id': 'sum', 'total': {'$sum': '$' + db_field}}}
|
||||
]
|
||||
|
||||
# if we're performing a sum over a list field, we sum up all the
|
||||
@ -1300,9 +1301,10 @@ class BaseQuerySet(object):
|
||||
:param field: the field to average over; use dot notation to refer to
|
||||
embedded document fields
|
||||
"""
|
||||
db_field = self._fields_to_dbfields([field]).pop()
|
||||
pipeline = [
|
||||
{'$match': self._query},
|
||||
{'$group': {'_id': 'avg', 'total': {'$avg': '$' + field}}}
|
||||
{'$group': {'_id': 'avg', 'total': {'$avg': '$' + db_field}}}
|
||||
]
|
||||
|
||||
# if we're performing an average over a list field, we average out
|
||||
|
@ -2838,6 +2838,34 @@ class QuerySetTest(unittest.TestCase):
|
||||
sum([a for a in ages if a >= 50])
|
||||
)
|
||||
|
||||
def test_sum_over_db_field(self):
|
||||
"""Ensure that a field mapped to a db field with a different name
|
||||
can be summed over correctly.
|
||||
"""
|
||||
class UserVisit(Document):
|
||||
num_visits = IntField(db_field='visits')
|
||||
|
||||
UserVisit.drop_collection()
|
||||
|
||||
UserVisit.objects.create(num_visits=10)
|
||||
UserVisit.objects.create(num_visits=5)
|
||||
|
||||
self.assertEqual(UserVisit.objects.sum('num_visits'), 15)
|
||||
|
||||
def test_average_over_db_field(self):
|
||||
"""Ensure that a field mapped to a db field with a different name
|
||||
can have its average computed correctly.
|
||||
"""
|
||||
class UserVisit(Document):
|
||||
num_visits = IntField(db_field='visits')
|
||||
|
||||
UserVisit.drop_collection()
|
||||
|
||||
UserVisit.objects.create(num_visits=20)
|
||||
UserVisit.objects.create(num_visits=10)
|
||||
|
||||
self.assertEqual(UserVisit.objects.average('num_visits'), 15)
|
||||
|
||||
def test_embedded_average(self):
|
||||
class Pay(EmbeddedDocument):
|
||||
value = DecimalField()
|
||||
|
Loading…
x
Reference in New Issue
Block a user