diff --git a/mongoengine/queryset/base.py b/mongoengine/queryset/base.py index 0b2898f2..d3bb4c4b 100644 --- a/mongoengine/queryset/base.py +++ b/mongoengine/queryset/base.py @@ -1002,26 +1002,27 @@ class BaseQuerySet(object): .. versionchanged:: 0.5 - updated to map_reduce as db.eval doesnt work with sharding. """ - map_func = Code(""" + map_func = """ function() { - function deepFind(obj, path) { - var paths = path.split('.') - , current = obj - , i; + var path = '{{~%(field)s}}'.split('.'), + field = this; - for (i = 0; i < paths.length; ++i) { - if (current[paths[i]] == undefined) { - return undefined; - } else { - current = current[paths[i]]; - } - } - return current; + for (p in path) { + if (typeof field != 'undefined') + field = field[path[p]]; + else + break; } - emit(1, deepFind(this, field) || 0); + if (field && field.constructor == Array) { + field.forEach(function(item) { + emit(1, item||0); + }); + } else if (typeof field != 'undefined') { + emit(1, field||0); + } } - """, scope={'field': field}) + """ % dict(field=field) reduce_func = Code(""" function(key, values) { @@ -1047,28 +1048,27 @@ class BaseQuerySet(object): .. versionchanged:: 0.5 - updated to map_reduce as db.eval doesnt work with sharding. """ - map_func = Code(""" + map_func = """ function() { - function deepFind(obj, path) { - var paths = path.split('.') - , current = obj - , i; + var path = '{{~%(field)s}}'.split('.'), + field = this; - for (i = 0; i < paths.length; ++i) { - if (current[paths[i]] == undefined) { - return undefined; - } else { - current = current[paths[i]]; - } - } - return current; + for (p in path) { + if (typeof field != 'undefined') + field = field[path[p]]; + else + break; } - val = deepFind(this, field) - if (val !== undefined) - emit(1, {t: val || 0, c: 1}); + if (field && field.constructor == Array) { + field.forEach(function(item) { + emit(1, {t: item||0, c: 1}); + }); + } else if (typeof field != 'undefined') { + emit(1, {t: field||0, c: 1}); + } } - """, scope={'field': field}) + """ % dict(field=field) reduce_func = Code(""" function(key, values) { diff --git a/tests/queryset/queryset.py b/tests/queryset/queryset.py index 6e3eb9bf..c56b31eb 100644 --- a/tests/queryset/queryset.py +++ b/tests/queryset/queryset.py @@ -2246,6 +2246,145 @@ class QuerySetTest(unittest.TestCase): self.Person(name='weightless person').save() self.assertEqual(int(self.Person.objects.sum('age')), sum(ages)) + def test_embedded_average(self): + class Pay(EmbeddedDocument): + value = DecimalField() + + class Doc(Document): + name = StringField() + pay = EmbeddedDocumentField( + Pay) + + Doc.drop_collection() + + Doc(name=u"Wilson Junior", + pay=Pay(value=150)).save() + + Doc(name=u"Isabella Luanna", + pay=Pay(value=530)).save() + + Doc(name=u"Tayza mariana", + pay=Pay(value=165)).save() + + Doc(name=u"Eliana Costa", + pay=Pay(value=115)).save() + + self.assertEqual( + Doc.objects.average('pay.value'), + 240) + + def test_embedded_array_average(self): + class Pay(EmbeddedDocument): + values = ListField(DecimalField()) + + class Doc(Document): + name = StringField() + pay = EmbeddedDocumentField( + Pay) + + Doc.drop_collection() + + Doc(name=u"Wilson Junior", + pay=Pay(values=[150, 100])).save() + + Doc(name=u"Isabella Luanna", + pay=Pay(values=[530, 100])).save() + + Doc(name=u"Tayza mariana", + pay=Pay(values=[165, 100])).save() + + Doc(name=u"Eliana Costa", + pay=Pay(values=[115, 100])).save() + + self.assertEqual( + Doc.objects.average('pay.values'), + 170) + + def test_array_average(self): + class Doc(Document): + values = ListField(DecimalField()) + + Doc.drop_collection() + + Doc(values=[150, 100]).save() + Doc(values=[530, 100]).save() + Doc(values=[165, 100]).save() + Doc(values=[115, 100]).save() + + self.assertEqual( + Doc.objects.average('values'), + 170) + + def test_embedded_sum(self): + class Pay(EmbeddedDocument): + value = DecimalField() + + class Doc(Document): + name = StringField() + pay = EmbeddedDocumentField( + Pay) + + Doc.drop_collection() + + Doc(name=u"Wilson Junior", + pay=Pay(value=150)).save() + + Doc(name=u"Isabella Luanna", + pay=Pay(value=530)).save() + + Doc(name=u"Tayza mariana", + pay=Pay(value=165)).save() + + Doc(name=u"Eliana Costa", + pay=Pay(value=115)).save() + + self.assertEqual( + Doc.objects.sum('pay.value'), + 960) + + + def test_embedded_array_sum(self): + class Pay(EmbeddedDocument): + values = ListField(DecimalField()) + + class Doc(Document): + name = StringField() + pay = EmbeddedDocumentField( + Pay) + + Doc.drop_collection() + + Doc(name=u"Wilson Junior", + pay=Pay(values=[150, 100])).save() + + Doc(name=u"Isabella Luanna", + pay=Pay(values=[530, 100])).save() + + Doc(name=u"Tayza mariana", + pay=Pay(values=[165, 100])).save() + + Doc(name=u"Eliana Costa", + pay=Pay(values=[115, 100])).save() + + self.assertEqual( + Doc.objects.sum('pay.values'), + 1360) + + def test_array_sum(self): + class Doc(Document): + values = ListField(DecimalField()) + + Doc.drop_collection() + + Doc(values=[150, 100]).save() + Doc(values=[530, 100]).save() + Doc(values=[165, 100]).save() + Doc(values=[115, 100]).save() + + self.assertEqual( + Doc.objects.sum('values'), + 1360) + def test_distinct(self): """Ensure that the QuerySet.distinct method works. """