diff --git a/docs/changelog.rst b/docs/changelog.rst index b3925e75..466fdcbf 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -4,6 +4,7 @@ Changelog Changes in 0.6.15 ================= +- Added support for null / zero / false values in item_frequencies - Fixed cascade save edge case - Fixed geo index creation through reference fields - Added support for args / kwargs when using @queryset_manager diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index 4f7443f7..6499c3e0 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -1718,10 +1718,11 @@ class QuerySet(object): def _item_frequencies_map_reduce(self, field, normalize=False): map_func = """ function() { - path = '{{~%(field)s}}'.split('.'); - field = this; + var path = '{{~%(field)s}}'.split('.'); + var field = this; + for (p in path) { - if (field) + if (typeof field != 'undefined') field = field[path[p]]; else break; @@ -1730,7 +1731,7 @@ class QuerySet(object): field.forEach(function(item) { emit(item, 1); }); - } else if (field) { + } else if (typeof field != 'undefined') { emit(field, 1); } else { emit(null, 1); @@ -1754,12 +1755,12 @@ class QuerySet(object): if isinstance(key, float): if int(key) == key: key = int(key) - key = str(key) - frequencies[key] = f.value + frequencies[key] = int(f.value) if normalize: count = sum(frequencies.values()) - frequencies = dict([(k, v / count) for k, v in frequencies.items()]) + frequencies = dict([(k, float(v) / count) + for k, v in frequencies.items()]) return frequencies @@ -1767,31 +1768,28 @@ class QuerySet(object): """Uses exec_js to execute""" freq_func = """ function(path) { - path = path.split('.'); + var path = path.split('.'); - if (options.normalize) { - var total = 0.0; - db[collection].find(query).forEach(function(doc) { - field = doc; - for (p in path) { - if (field) - field = field[path[p]]; - else - break; - } - if (field && field.constructor == Array) { - total += field.length; - } else { - total++; - } - }); - } + var total = 0.0; + db[collection].find(query).forEach(function(doc) { + var field = doc; + for (p in path) { + if (field) + field = field[path[p]]; + else + break; + } + if (field && field.constructor == Array) { + total += field.length; + } else { + total++; + } + }); var frequencies = {}; + var types = {}; var inc = 1.0; - if (options.normalize) { - inc /= total; - } + db[collection].find(query).forEach(function(doc) { field = doc; for (p in path) { @@ -1806,17 +1804,28 @@ class QuerySet(object): }); } else { var item = field; + types[item] = item; frequencies[item] = inc + (isNaN(frequencies[item]) ? 0: frequencies[item]); } }); - return frequencies; + return [total, frequencies, types]; } """ - data = self.exec_js(freq_func, field, normalize=normalize) - if 'undefined' in data: - data[None] = data['undefined'] - del(data['undefined']) - return data + total, data, types = self.exec_js(freq_func, field) + values = dict([(types.get(k), int(v)) for k, v in data.iteritems()]) + + if normalize: + values = dict([(k, float(v) / total) for k, v in values.items()]) + + frequencies = {} + for k, v in values.iteritems(): + if isinstance(k, float): + if int(k) == k: + k = int(k) + + frequencies[k] = v + + return frequencies def __repr__(self): """Provides the string representation of the QuerySet diff --git a/tests/test_queryset.py b/tests/test_queryset.py index 939451c1..1bac6a97 100644 --- a/tests/test_queryset.py +++ b/tests/test_queryset.py @@ -1994,9 +1994,9 @@ class QuerySetTest(unittest.TestCase): # Check item_frequencies works for non-list fields def test_assertions(f): - self.assertEqual(set(['1', '2']), set(f.keys())) - self.assertEqual(f['1'], 1) - self.assertEqual(f['2'], 2) + self.assertEqual(set([1, 2]), set(f.keys())) + self.assertEqual(f[1], 1) + self.assertEqual(f[2], 2) exec_js = BlogPost.objects.item_frequencies('hits') map_reduce = BlogPost.objects.item_frequencies('hits', map_reduce=True) @@ -2096,7 +2096,6 @@ class QuerySetTest(unittest.TestCase): data = EmbeddedDocumentField(Data, required=True) extra = EmbeddedDocumentField(Extra) - Person.drop_collection() p = Person() @@ -2114,6 +2113,52 @@ class QuerySetTest(unittest.TestCase): ot = Person.objects.item_frequencies('extra.tag', map_reduce=True) self.assertEquals(ot, {None: 1.0, u'friend': 1.0}) + def test_item_frequencies_with_0_values(self): + class Test(Document): + val = IntField() + + Test.drop_collection() + t = Test() + t.val = 0 + t.save() + + ot = Test.objects.item_frequencies('val', map_reduce=True) + self.assertEquals(ot, {0: 1}) + ot = Test.objects.item_frequencies('val', map_reduce=False) + self.assertEquals(ot, {0: 1}) + + def test_item_frequencies_with_False_values(self): + class Test(Document): + val = BooleanField() + + Test.drop_collection() + t = Test() + t.val = False + t.save() + + ot = Test.objects.item_frequencies('val', map_reduce=True) + self.assertEquals(ot, {False: 1}) + ot = Test.objects.item_frequencies('val', map_reduce=False) + self.assertEquals(ot, {False: 1}) + + def test_item_frequencies_normalize(self): + class Test(Document): + val = IntField() + + Test.drop_collection() + + for i in xrange(50): + Test(val=1).save() + + for i in xrange(20): + Test(val=2).save() + + freqs = Test.objects.item_frequencies('val', map_reduce=False, normalize=True) + self.assertEquals(freqs, {1: 50.0/70, 2: 20.0/70}) + + freqs = Test.objects.item_frequencies('val', map_reduce=True, normalize=True) + self.assertEquals(freqs, {1: 50.0/70, 2: 20.0/70}) + def test_average(self): """Ensure that field can be averaged correctly. """