Merge pull request #1336 from closeio/aggregate-sum-and-avg
Replace map-reduce based QuerySet.sum/average with aggregation-based implementations
This commit is contained in:
commit
9d6f9b1f26
@ -15,6 +15,7 @@ Changes in 0.10.7 - DEV
|
|||||||
- ListField now handles negative indices correctly. #1270
|
- ListField now handles negative indices correctly. #1270
|
||||||
- Fixed AttributeError when initializing EmbeddedDocument with positional args. #681
|
- Fixed AttributeError when initializing EmbeddedDocument with positional args. #681
|
||||||
- Fixed no_cursor_timeout error with pymongo 3.0+ #1304
|
- Fixed no_cursor_timeout error with pymongo 3.0+ #1304
|
||||||
|
- Replaced map-reduce based QuerySet.sum/average with aggregation-based implementations #1336
|
||||||
|
|
||||||
Changes in 0.10.6
|
Changes in 0.10.6
|
||||||
=================
|
=================
|
||||||
|
@ -1237,66 +1237,28 @@ class BaseQuerySet(object):
|
|||||||
def sum(self, field):
|
def sum(self, field):
|
||||||
"""Sum over the values of the specified field.
|
"""Sum over the values of the specified field.
|
||||||
|
|
||||||
:param field: the field to sum over; use dot-notation to refer to
|
:param field: the field to sum over; use dot notation to refer to
|
||||||
embedded document fields
|
embedded document fields
|
||||||
|
|
||||||
.. versionchanged:: 0.5 - updated to map_reduce as db.eval doesn't work
|
|
||||||
with sharding.
|
|
||||||
"""
|
"""
|
||||||
map_func = """
|
pipeline = [
|
||||||
function() {
|
|
||||||
var path = '{{~%(field)s}}'.split('.'),
|
|
||||||
field = this;
|
|
||||||
|
|
||||||
for (p in path) {
|
|
||||||
if (typeof field != 'undefined')
|
|
||||||
field = field[path[p]];
|
|
||||||
else
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (field && field.constructor == Array) {
|
|
||||||
field.forEach(function(item) {
|
|
||||||
emit(1, item||0);
|
|
||||||
});
|
|
||||||
} else if (typeof field != 'undefined') {
|
|
||||||
emit(1, field||0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
""" % dict(field=field)
|
|
||||||
|
|
||||||
reduce_func = Code("""
|
|
||||||
function(key, values) {
|
|
||||||
var sum = 0;
|
|
||||||
for (var i in values) {
|
|
||||||
sum += values[i];
|
|
||||||
}
|
|
||||||
return sum;
|
|
||||||
}
|
|
||||||
""")
|
|
||||||
|
|
||||||
for result in self.map_reduce(map_func, reduce_func, output='inline'):
|
|
||||||
return result.value
|
|
||||||
else:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
def aggregate_sum(self, field):
|
|
||||||
"""Sum over the values of the specified field.
|
|
||||||
|
|
||||||
:param field: the field to sum over; use dot-notation to refer to
|
|
||||||
embedded document fields
|
|
||||||
|
|
||||||
This method is more performant than the regular `sum`, because it uses
|
|
||||||
the aggregation framework instead of map-reduce.
|
|
||||||
"""
|
|
||||||
result = self._document._get_collection().aggregate([
|
|
||||||
{'$match': self._query},
|
{'$match': self._query},
|
||||||
{'$group': {'_id': 'sum', 'total': {'$sum': '$' + field}}}
|
{'$group': {'_id': 'sum', 'total': {'$sum': '$' + field}}}
|
||||||
])
|
]
|
||||||
|
|
||||||
|
# if we're performing a sum over a list field, we sum up all the
|
||||||
|
# elements in the list, hence we need to $unwind the arrays first
|
||||||
|
ListField = _import_class('ListField')
|
||||||
|
field_parts = field.split('.')
|
||||||
|
field_instances = self._document._lookup_field(field_parts)
|
||||||
|
if isinstance(field_instances[-1], ListField):
|
||||||
|
pipeline.insert(1, {'$unwind': '$' + field})
|
||||||
|
|
||||||
|
result = self._document._get_collection().aggregate(pipeline)
|
||||||
if IS_PYMONGO_3:
|
if IS_PYMONGO_3:
|
||||||
result = list(result)
|
result = tuple(result)
|
||||||
else:
|
else:
|
||||||
result = result.get('result')
|
result = result.get('result')
|
||||||
|
|
||||||
if result:
|
if result:
|
||||||
return result[0]['total']
|
return result[0]['total']
|
||||||
return 0
|
return 0
|
||||||
@ -1304,73 +1266,26 @@ class BaseQuerySet(object):
|
|||||||
def average(self, field):
|
def average(self, field):
|
||||||
"""Average over the values of the specified field.
|
"""Average over the values of the specified field.
|
||||||
|
|
||||||
:param field: the field to average over; use dot-notation to refer to
|
:param field: the field to average over; use dot notation to refer to
|
||||||
embedded document fields
|
embedded document fields
|
||||||
|
|
||||||
.. versionchanged:: 0.5 - updated to map_reduce as db.eval doesn't work
|
|
||||||
with sharding.
|
|
||||||
"""
|
"""
|
||||||
map_func = """
|
pipeline = [
|
||||||
function() {
|
|
||||||
var path = '{{~%(field)s}}'.split('.'),
|
|
||||||
field = this;
|
|
||||||
|
|
||||||
for (p in path) {
|
|
||||||
if (typeof field != 'undefined')
|
|
||||||
field = field[path[p]];
|
|
||||||
else
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (field && field.constructor == Array) {
|
|
||||||
field.forEach(function(item) {
|
|
||||||
emit(1, {t: item||0, c: 1});
|
|
||||||
});
|
|
||||||
} else if (typeof field != 'undefined') {
|
|
||||||
emit(1, {t: field||0, c: 1});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
""" % dict(field=field)
|
|
||||||
|
|
||||||
reduce_func = Code("""
|
|
||||||
function(key, values) {
|
|
||||||
var out = {t: 0, c: 0};
|
|
||||||
for (var i in values) {
|
|
||||||
var value = values[i];
|
|
||||||
out.t += value.t;
|
|
||||||
out.c += value.c;
|
|
||||||
}
|
|
||||||
return out;
|
|
||||||
}
|
|
||||||
""")
|
|
||||||
|
|
||||||
finalize_func = Code("""
|
|
||||||
function(key, value) {
|
|
||||||
return value.t / value.c;
|
|
||||||
}
|
|
||||||
""")
|
|
||||||
|
|
||||||
for result in self.map_reduce(map_func, reduce_func,
|
|
||||||
finalize_f=finalize_func, output='inline'):
|
|
||||||
return result.value
|
|
||||||
else:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
def aggregate_average(self, field):
|
|
||||||
"""Average over the values of the specified field.
|
|
||||||
|
|
||||||
:param field: the field to average over; use dot-notation to refer to
|
|
||||||
embedded document fields
|
|
||||||
|
|
||||||
This method is more performant than the regular `average`, because it
|
|
||||||
uses the aggregation framework instead of map-reduce.
|
|
||||||
"""
|
|
||||||
result = self._document._get_collection().aggregate([
|
|
||||||
{'$match': self._query},
|
{'$match': self._query},
|
||||||
{'$group': {'_id': 'avg', 'total': {'$avg': '$' + field}}}
|
{'$group': {'_id': 'avg', 'total': {'$avg': '$' + field}}}
|
||||||
])
|
]
|
||||||
|
|
||||||
|
# if we're performing an average over a list field, we average out
|
||||||
|
# all the elements in the list, hence we need to $unwind the arrays
|
||||||
|
# first
|
||||||
|
ListField = _import_class('ListField')
|
||||||
|
field_parts = field.split('.')
|
||||||
|
field_instances = self._document._lookup_field(field_parts)
|
||||||
|
if isinstance(field_instances[-1], ListField):
|
||||||
|
pipeline.insert(1, {'$unwind': '$' + field})
|
||||||
|
|
||||||
|
result = self._document._get_collection().aggregate(pipeline)
|
||||||
if IS_PYMONGO_3:
|
if IS_PYMONGO_3:
|
||||||
result = list(result)
|
result = tuple(result)
|
||||||
else:
|
else:
|
||||||
result = result.get('result')
|
result = result.get('result')
|
||||||
if result:
|
if result:
|
||||||
|
@ -2766,25 +2766,15 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
|
|
||||||
avg = float(sum(ages)) / (len(ages) + 1) # take into account the 0
|
avg = float(sum(ages)) / (len(ages) + 1) # take into account the 0
|
||||||
self.assertAlmostEqual(int(self.Person.objects.average('age')), avg)
|
self.assertAlmostEqual(int(self.Person.objects.average('age')), avg)
|
||||||
self.assertAlmostEqual(
|
|
||||||
int(self.Person.objects.aggregate_average('age')), avg
|
|
||||||
)
|
|
||||||
|
|
||||||
self.Person(name='ageless person').save()
|
self.Person(name='ageless person').save()
|
||||||
self.assertEqual(int(self.Person.objects.average('age')), avg)
|
self.assertEqual(int(self.Person.objects.average('age')), avg)
|
||||||
self.assertEqual(
|
|
||||||
int(self.Person.objects.aggregate_average('age')), avg
|
|
||||||
)
|
|
||||||
|
|
||||||
# dot notation
|
# dot notation
|
||||||
self.Person(
|
self.Person(
|
||||||
name='person meta', person_meta=self.PersonMeta(weight=0)).save()
|
name='person meta', person_meta=self.PersonMeta(weight=0)).save()
|
||||||
self.assertAlmostEqual(
|
self.assertAlmostEqual(
|
||||||
int(self.Person.objects.average('person_meta.weight')), 0)
|
int(self.Person.objects.average('person_meta.weight')), 0)
|
||||||
self.assertAlmostEqual(
|
|
||||||
int(self.Person.objects.aggregate_average('person_meta.weight')),
|
|
||||||
0
|
|
||||||
)
|
|
||||||
|
|
||||||
for i, weight in enumerate(ages):
|
for i, weight in enumerate(ages):
|
||||||
self.Person(
|
self.Person(
|
||||||
@ -2793,19 +2783,11 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
self.assertAlmostEqual(
|
self.assertAlmostEqual(
|
||||||
int(self.Person.objects.average('person_meta.weight')), avg
|
int(self.Person.objects.average('person_meta.weight')), avg
|
||||||
)
|
)
|
||||||
self.assertAlmostEqual(
|
|
||||||
int(self.Person.objects.aggregate_average('person_meta.weight')),
|
|
||||||
avg
|
|
||||||
)
|
|
||||||
|
|
||||||
self.Person(name='test meta none').save()
|
self.Person(name='test meta none').save()
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
int(self.Person.objects.average('person_meta.weight')), avg
|
int(self.Person.objects.average('person_meta.weight')), avg
|
||||||
)
|
)
|
||||||
self.assertEqual(
|
|
||||||
int(self.Person.objects.aggregate_average('person_meta.weight')),
|
|
||||||
avg
|
|
||||||
)
|
|
||||||
|
|
||||||
# test summing over a filtered queryset
|
# test summing over a filtered queryset
|
||||||
over_50 = [a for a in ages if a >= 50]
|
over_50 = [a for a in ages if a >= 50]
|
||||||
@ -2814,10 +2796,6 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
self.Person.objects.filter(age__gte=50).average('age'),
|
self.Person.objects.filter(age__gte=50).average('age'),
|
||||||
avg
|
avg
|
||||||
)
|
)
|
||||||
self.assertEqual(
|
|
||||||
self.Person.objects.filter(age__gte=50).aggregate_average('age'),
|
|
||||||
avg
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_sum(self):
|
def test_sum(self):
|
||||||
"""Ensure that field can be summed over correctly.
|
"""Ensure that field can be summed over correctly.
|
||||||
@ -2827,15 +2805,9 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
self.Person(name='test%s' % i, age=age).save()
|
self.Person(name='test%s' % i, age=age).save()
|
||||||
|
|
||||||
self.assertEqual(self.Person.objects.sum('age'), sum(ages))
|
self.assertEqual(self.Person.objects.sum('age'), sum(ages))
|
||||||
self.assertEqual(
|
|
||||||
self.Person.objects.aggregate_sum('age'), sum(ages)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.Person(name='ageless person').save()
|
self.Person(name='ageless person').save()
|
||||||
self.assertEqual(self.Person.objects.sum('age'), sum(ages))
|
self.assertEqual(self.Person.objects.sum('age'), sum(ages))
|
||||||
self.assertEqual(
|
|
||||||
self.Person.objects.aggregate_sum('age'), sum(ages)
|
|
||||||
)
|
|
||||||
|
|
||||||
for i, age in enumerate(ages):
|
for i, age in enumerate(ages):
|
||||||
self.Person(name='test meta%s' %
|
self.Person(name='test meta%s' %
|
||||||
@ -2844,26 +2816,15 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
self.Person.objects.sum('person_meta.weight'), sum(ages)
|
self.Person.objects.sum('person_meta.weight'), sum(ages)
|
||||||
)
|
)
|
||||||
self.assertEqual(
|
|
||||||
self.Person.objects.aggregate_sum('person_meta.weight'),
|
|
||||||
sum(ages)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.Person(name='weightless person').save()
|
self.Person(name='weightless person').save()
|
||||||
self.assertEqual(self.Person.objects.sum('age'), sum(ages))
|
self.assertEqual(self.Person.objects.sum('age'), sum(ages))
|
||||||
self.assertEqual(
|
|
||||||
self.Person.objects.aggregate_sum('age'), sum(ages)
|
|
||||||
)
|
|
||||||
|
|
||||||
# test summing over a filtered queryset
|
# test summing over a filtered queryset
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
self.Person.objects.filter(age__gte=50).sum('age'),
|
self.Person.objects.filter(age__gte=50).sum('age'),
|
||||||
sum([a for a in ages if a >= 50])
|
sum([a for a in ages if a >= 50])
|
||||||
)
|
)
|
||||||
self.assertEqual(
|
|
||||||
self.Person.objects.filter(age__gte=50).aggregate_sum('age'),
|
|
||||||
sum([a for a in ages if a >= 50])
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_embedded_average(self):
|
def test_embedded_average(self):
|
||||||
class Pay(EmbeddedDocument):
|
class Pay(EmbeddedDocument):
|
||||||
@ -2876,21 +2837,12 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
|
|
||||||
Doc.drop_collection()
|
Doc.drop_collection()
|
||||||
|
|
||||||
Doc(name=u"Wilson Junior",
|
Doc(name='Wilson Junior', pay=Pay(value=150)).save()
|
||||||
pay=Pay(value=150)).save()
|
Doc(name='Isabella Luanna', pay=Pay(value=530)).save()
|
||||||
|
Doc(name='Tayza mariana', pay=Pay(value=165)).save()
|
||||||
|
Doc(name='Eliana Costa', pay=Pay(value=115)).save()
|
||||||
|
|
||||||
Doc(name=u"Isabella Luanna",
|
self.assertEqual(Doc.objects.average('pay.value'), 240)
|
||||||
pay=Pay(value=530)).save()
|
|
||||||
|
|
||||||
Doc(name=u"Tayza mariana",
|
|
||||||
pay=Pay(value=165)).save()
|
|
||||||
|
|
||||||
Doc(name=u"Eliana Costa",
|
|
||||||
pay=Pay(value=115)).save()
|
|
||||||
|
|
||||||
self.assertEqual(
|
|
||||||
Doc.objects.average('pay.value'),
|
|
||||||
240)
|
|
||||||
|
|
||||||
def test_embedded_array_average(self):
|
def test_embedded_array_average(self):
|
||||||
class Pay(EmbeddedDocument):
|
class Pay(EmbeddedDocument):
|
||||||
@ -2898,26 +2850,16 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
|
|
||||||
class Doc(Document):
|
class Doc(Document):
|
||||||
name = StringField()
|
name = StringField()
|
||||||
pay = EmbeddedDocumentField(
|
pay = EmbeddedDocumentField(Pay)
|
||||||
Pay)
|
|
||||||
|
|
||||||
Doc.drop_collection()
|
Doc.drop_collection()
|
||||||
|
|
||||||
Doc(name=u"Wilson Junior",
|
Doc(name='Wilson Junior', pay=Pay(values=[150, 100])).save()
|
||||||
pay=Pay(values=[150, 100])).save()
|
Doc(name='Isabella Luanna', pay=Pay(values=[530, 100])).save()
|
||||||
|
Doc(name='Tayza mariana', pay=Pay(values=[165, 100])).save()
|
||||||
|
Doc(name='Eliana Costa', pay=Pay(values=[115, 100])).save()
|
||||||
|
|
||||||
Doc(name=u"Isabella Luanna",
|
self.assertEqual(Doc.objects.average('pay.values'), 170)
|
||||||
pay=Pay(values=[530, 100])).save()
|
|
||||||
|
|
||||||
Doc(name=u"Tayza mariana",
|
|
||||||
pay=Pay(values=[165, 100])).save()
|
|
||||||
|
|
||||||
Doc(name=u"Eliana Costa",
|
|
||||||
pay=Pay(values=[115, 100])).save()
|
|
||||||
|
|
||||||
self.assertEqual(
|
|
||||||
Doc.objects.average('pay.values'),
|
|
||||||
170)
|
|
||||||
|
|
||||||
def test_array_average(self):
|
def test_array_average(self):
|
||||||
class Doc(Document):
|
class Doc(Document):
|
||||||
@ -2930,9 +2872,7 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
Doc(values=[165, 100]).save()
|
Doc(values=[165, 100]).save()
|
||||||
Doc(values=[115, 100]).save()
|
Doc(values=[115, 100]).save()
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(Doc.objects.average('values'), 170)
|
||||||
Doc.objects.average('values'),
|
|
||||||
170)
|
|
||||||
|
|
||||||
def test_embedded_sum(self):
|
def test_embedded_sum(self):
|
||||||
class Pay(EmbeddedDocument):
|
class Pay(EmbeddedDocument):
|
||||||
@ -2940,26 +2880,16 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
|
|
||||||
class Doc(Document):
|
class Doc(Document):
|
||||||
name = StringField()
|
name = StringField()
|
||||||
pay = EmbeddedDocumentField(
|
pay = EmbeddedDocumentField(Pay)
|
||||||
Pay)
|
|
||||||
|
|
||||||
Doc.drop_collection()
|
Doc.drop_collection()
|
||||||
|
|
||||||
Doc(name=u"Wilson Junior",
|
Doc(name='Wilson Junior', pay=Pay(value=150)).save()
|
||||||
pay=Pay(value=150)).save()
|
Doc(name='Isabella Luanna', pay=Pay(value=530)).save()
|
||||||
|
Doc(name='Tayza mariana', pay=Pay(value=165)).save()
|
||||||
|
Doc(name='Eliana Costa', pay=Pay(value=115)).save()
|
||||||
|
|
||||||
Doc(name=u"Isabella Luanna",
|
self.assertEqual(Doc.objects.sum('pay.value'), 960)
|
||||||
pay=Pay(value=530)).save()
|
|
||||||
|
|
||||||
Doc(name=u"Tayza mariana",
|
|
||||||
pay=Pay(value=165)).save()
|
|
||||||
|
|
||||||
Doc(name=u"Eliana Costa",
|
|
||||||
pay=Pay(value=115)).save()
|
|
||||||
|
|
||||||
self.assertEqual(
|
|
||||||
Doc.objects.sum('pay.value'),
|
|
||||||
960)
|
|
||||||
|
|
||||||
def test_embedded_array_sum(self):
|
def test_embedded_array_sum(self):
|
||||||
class Pay(EmbeddedDocument):
|
class Pay(EmbeddedDocument):
|
||||||
@ -2967,26 +2897,16 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
|
|
||||||
class Doc(Document):
|
class Doc(Document):
|
||||||
name = StringField()
|
name = StringField()
|
||||||
pay = EmbeddedDocumentField(
|
pay = EmbeddedDocumentField(Pay)
|
||||||
Pay)
|
|
||||||
|
|
||||||
Doc.drop_collection()
|
Doc.drop_collection()
|
||||||
|
|
||||||
Doc(name=u"Wilson Junior",
|
Doc(name='Wilson Junior', pay=Pay(values=[150, 100])).save()
|
||||||
pay=Pay(values=[150, 100])).save()
|
Doc(name='Isabella Luanna', pay=Pay(values=[530, 100])).save()
|
||||||
|
Doc(name='Tayza mariana', pay=Pay(values=[165, 100])).save()
|
||||||
|
Doc(name='Eliana Costa', pay=Pay(values=[115, 100])).save()
|
||||||
|
|
||||||
Doc(name=u"Isabella Luanna",
|
self.assertEqual(Doc.objects.sum('pay.values'), 1360)
|
||||||
pay=Pay(values=[530, 100])).save()
|
|
||||||
|
|
||||||
Doc(name=u"Tayza mariana",
|
|
||||||
pay=Pay(values=[165, 100])).save()
|
|
||||||
|
|
||||||
Doc(name=u"Eliana Costa",
|
|
||||||
pay=Pay(values=[115, 100])).save()
|
|
||||||
|
|
||||||
self.assertEqual(
|
|
||||||
Doc.objects.sum('pay.values'),
|
|
||||||
1360)
|
|
||||||
|
|
||||||
def test_array_sum(self):
|
def test_array_sum(self):
|
||||||
class Doc(Document):
|
class Doc(Document):
|
||||||
@ -2999,9 +2919,7 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
Doc(values=[165, 100]).save()
|
Doc(values=[165, 100]).save()
|
||||||
Doc(values=[115, 100]).save()
|
Doc(values=[115, 100]).save()
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(Doc.objects.sum('values'), 1360)
|
||||||
Doc.objects.sum('values'),
|
|
||||||
1360)
|
|
||||||
|
|
||||||
def test_distinct(self):
|
def test_distinct(self):
|
||||||
"""Ensure that the QuerySet.distinct method works.
|
"""Ensure that the QuerySet.distinct method works.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user