From 94cad89e321b92239171fd0a2f11095fa2f01b09 Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Wed, 15 Jun 2011 11:22:27 +0100 Subject: [PATCH] Fixes to item_frequencies - now handles path lookups fixes #194 --- .gitignore | 1 + mongoengine/queryset.py | 39 ++++++++++++++++++------- tests/queryset.py | 63 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 93 insertions(+), 10 deletions(-) diff --git a/.gitignore b/.gitignore index 8951a0ce..315674fe 100644 --- a/.gitignore +++ b/.gitignore @@ -12,3 +12,4 @@ env/ .settings .project .pydevproject +tests/bugfix.py diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index 666567e2..4ffa5324 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -1303,7 +1303,16 @@ class QuerySet(object): # Substitute the correct name for the field into the javascript return u'["%s"]' % fields[-1].db_field - return re.sub(u'\[\s*~([A-z_][A-z_0-9.]+?)\s*\]', field_sub, code) + def field_path_sub(match): + # Extract just the field name, and look up the field objects + field_name = match.group(1).split('.') + fields = QuerySet._lookup_field(self._document, field_name) + # Substitute the correct name for the field into the javascript + return ".".join([f.db_field for f in fields]) + + code = re.sub(u'\[\s*~([A-z_][A-z_0-9.]+?)\s*\]', field_sub, code) + code = re.sub(u'\{\{\s*~([A-z_][A-z_0-9.]+?)\s*\}\}', field_path_sub, code) + return code def exec_js(self, code, *fields, **options): """Execute a Javascript function on the server. A list of fields may be @@ -1405,12 +1414,15 @@ class QuerySet(object): def _item_frequencies_map_reduce(self, field, normalize=False): map_func = """ function() { - if (this[~%(field)s].constructor == Array) { - this[~%(field)s].forEach(function(item) { + path = '{{~%(field)s}}'.split('.'); + field = this; + for (p in path) { field = field[path[p]]; } + if (field.constructor == Array) { + field.forEach(function(item) { emit(item, 1); }); } else { - emit(this[~%(field)s], 1); + emit(field, 1); } } """ % dict(field=field) @@ -1443,12 +1455,16 @@ class QuerySet(object): def _item_frequencies_exec_js(self, field, normalize=False): """Uses exec_js to execute""" freq_func = """ - function(field) { + function(path) { + path = path.split('.'); + if (options.normalize) { var total = 0.0; db[collection].find(query).forEach(function(doc) { - if (doc[field].constructor == Array) { - total += doc[field].length; + field = doc; + for (p in path) { field = field[path[p]]; } + if (field.constructor == Array) { + total += field.length; } else { total++; } @@ -1461,18 +1477,21 @@ class QuerySet(object): inc /= total; } db[collection].find(query).forEach(function(doc) { - if (doc[field].constructor == Array) { - doc[field].forEach(function(item) { + field = doc; + for (p in path) { field = field[path[p]]; } + if (field.constructor == Array) { + field.forEach(function(item) { frequencies[item] = inc + (isNaN(frequencies[item]) ? 0: frequencies[item]); }); } else { - var item = doc[field]; + var item = field; frequencies[item] = inc + (isNaN(frequencies[item]) ? 0: frequencies[item]); } }); return frequencies; } """ + return self.exec_js(freq_func, field, normalize=normalize) def __repr__(self): diff --git a/tests/queryset.py b/tests/queryset.py index 37140f4a..cc219fba 100644 --- a/tests/queryset.py +++ b/tests/queryset.py @@ -1116,6 +1116,11 @@ class QuerySetTest(unittest.TestCase): ] self.assertEqual(results, expected_results) + # Test template style + code = "{{~comments.content}}" + sub_code = BlogPost.objects._sub_js_fields(code) + self.assertEquals("cmnts.body", sub_code) + BlogPost.drop_collection() def test_delete(self): @@ -1637,6 +1642,64 @@ class QuerySetTest(unittest.TestCase): BlogPost.drop_collection() + def test_item_frequencies_on_embedded(self): + """Ensure that item frequencies are properly generated from lists. + """ + + class Phone(EmbeddedDocument): + number = StringField() + + class Person(Document): + name = StringField() + phone = EmbeddedDocumentField(Phone) + + Person.drop_collection() + + doc = Person(name="Guido") + doc.phone = Phone(number='62-3331-1656') + doc.save() + + doc = Person(name="Marr") + doc.phone = Phone(number='62-3331-1656') + doc.save() + + doc = Person(name="WP Junior") + doc.phone = Phone(number='62-3332-1656') + doc.save() + + + def test_assertions(f): + f = dict((key, int(val)) for key, val in f.items()) + self.assertEqual(set(['62-3331-1656', '62-3332-1656']), set(f.keys())) + self.assertEqual(f['62-3331-1656'], 2) + self.assertEqual(f['62-3332-1656'], 1) + + exec_js = Person.objects.item_frequencies('phone.number') + map_reduce = Person.objects.item_frequencies('phone.number', map_reduce=True) + test_assertions(exec_js) + test_assertions(map_reduce) + + # Ensure query is taken into account + def test_assertions(f): + f = dict((key, int(val)) for key, val in f.items()) + self.assertEqual(set(['62-3331-1656']), set(f.keys())) + self.assertEqual(f['62-3331-1656'], 2) + + exec_js = Person.objects(phone__number='62-3331-1656').item_frequencies('phone.number') + map_reduce = Person.objects(phone__number='62-3331-1656').item_frequencies('phone.number', map_reduce=True) + test_assertions(exec_js) + test_assertions(map_reduce) + + # Check that normalization works + def test_assertions(f): + self.assertEqual(f['62-3331-1656'], 2.0/3.0) + self.assertEqual(f['62-3332-1656'], 1.0/3.0) + + exec_js = Person.objects.item_frequencies('phone.number', normalize=True) + map_reduce = Person.objects.item_frequencies('phone.number', normalize=True, map_reduce=True) + test_assertions(exec_js) + test_assertions(map_reduce) + def test_average(self): """Ensure that field can be averaged correctly. """