From 4a269eb2c4bafc808cfc8d290ff5d34ed8228316 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wilson=20J=C3=BAnior?= Date: Mon, 12 Dec 2011 13:39:37 -0200 Subject: [PATCH 1/6] added .select method --- mongoengine/queryset.py | 104 ++++++++++++++++++++++++++++------------ tests/queryset.py | 45 +++++++++++++++++ 2 files changed, 119 insertions(+), 30 deletions(-) diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index 0c39253b..1b3257ac 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -314,6 +314,27 @@ class QueryFieldList(object): def __nonzero__(self): return bool(self.fields) +class SelectResult(object): + """ + Used for .select method in QuerySet + """ + def __init__(self, cursor, fields): + self._cursor = cursor + self._fields = fields + + def next(self): + try: + data = self._cursor.next() + return [data.get(f) for f in self._fields] + except StopIteration, e: + self.rewind() + raise e + + def rewind(self): + self._cursor.rewind() + + def __iter__(self): + return self class QuerySet(object): """A set of results returned from a query. Wraps a MongoDB cursor, @@ -548,33 +569,38 @@ class QuerySet(object): cursor_args['fields'] = self._loaded_fields.as_dict() return cursor_args + def _build_cursor(self, **cursor_args): + obj = self._collection.find(self._query, + **cursor_args) + # Apply where clauses to cursor + if self._where_clause: + obj.where(self._where_clause) + + # apply default ordering + if self._ordering: + obj.sort(self._ordering) + elif self._document._meta['ordering']: + self._ordering = self._get_order_key_list( + *self._document._meta['ordering']) + obj.sort(self._ordering) + + if self._limit is not None: + obj.limit(self._limit) + + if self._skip is not None: + obj.skip(self._skip) + + if self._hint != -1: + obj.hint(self._hint) + + return obj + @property def _cursor(self): if self._cursor_obj is None: - - self._cursor_obj = self._collection.find(self._query, - **self._cursor_args) - # Apply where clauses to cursor - if self._where_clause: - self._cursor_obj.where(self._where_clause) - - # apply default ordering - if self._ordering: - self._cursor_obj.sort(self._ordering) - elif self._document._meta['ordering']: - self.order_by(*self._document._meta['ordering']) - - if self._limit is not None: - self._cursor_obj.limit(self._limit) - - if self._skip is not None: - self._cursor_obj.skip(self._skip) - - if self._hint != -1: - self._cursor_obj.hint(self._hint) - + self._cursor_obj = self._build_cursor(**self._cursor_args) + return self._cursor_obj - @classmethod def _lookup_field(cls, document, parts): """Lookup a field based on its attribute and return a list containing @@ -803,6 +829,16 @@ class QuerySet(object): doc.save() return doc + def select(self, *fields): + """ + Select a field and make a tuple of element + """ + cursor_args = self._cursor_args + cursor_args['fields'] = self._fields_to_dbfields(fields) + cursor = self._build_cursor(**cursor_args) + + return SelectResult(cursor, fields) + def first(self): """Retrieve the first object matching the query. """ @@ -1163,13 +1199,9 @@ class QuerySet(object): ret.append(field) return ret - def order_by(self, *keys): - """Order the :class:`~mongoengine.queryset.QuerySet` by the keys. The - order may be specified by prepending each of the keys by a + or a -. - Ascending order is assumed. - - :param keys: fields to order the query results by; keys may be - prefixed with **+** or **-** to determine the ordering direction + def _get_order_key_list(self, *keys): + """ + Build order list for query """ key_list = [] for key in keys: @@ -1186,6 +1218,18 @@ class QuerySet(object): pass key_list.append((key, direction)) + return key_list + + def order_by(self, *keys): + """Order the :class:`~mongoengine.queryset.QuerySet` by the keys. The + order may be specified by prepending each of the keys by a + or a -. + Ascending order is assumed. + + :param keys: fields to order the query results by; keys may be + prefixed with **+** or **-** to determine the ordering direction + """ + + key_list = self._get_order_key_list(*keys) self._ordering = key_list self._cursor.sort(key_list) return self diff --git a/tests/queryset.py b/tests/queryset.py index 37fa5247..0f5d84cf 100644 --- a/tests/queryset.py +++ b/tests/queryset.py @@ -2909,6 +2909,51 @@ class QueryFieldListTest(unittest.TestCase): ak = list(Bar.objects(foo__match={'shape': "square", "color": "purple"})) self.assertEqual([b1], ak) + + def test_select(self): + class TestDoc(Document): + x = IntField() + y = BooleanField() + + TestDoc.drop_collection() + + TestDoc(x=10, y=True).save() + TestDoc(x=20, y=False).save() + TestDoc(x=30, y=True).save() + + plist = list(TestDoc.objects.select('x', 'y')) + + self.assertEqual(len(plist), 3) + self.assertEqual(plist[0], [10, True]) + self.assertEqual(plist[1], [20, False]) + self.assertEqual(plist[2], [30, True]) + + class UserDoc(Document): + name = StringField() + age = IntField() + + UserDoc.drop_collection() + + UserDoc(name="Wilson Jr", age=19).save() + UserDoc(name="Wilson", age=43).save() + UserDoc(name="Eliana", age=37).save() + UserDoc(name="Tayza", age=15).save() + + ulist = list(UserDoc.objects.select('name', 'age')) + + self.assertEqual(ulist, [ + [u'Wilson Jr', 19], + [u'Wilson', 43], + [u'Eliana', 37], + [u'Tayza', 15]]) + + ulist = list(UserDoc.objects.order_by('age').select('name')) + + self.assertEqual(ulist, [ + [u'Tayza'], + [u'Wilson Jr'], + [u'Eliana'], + [u'Wilson']]) if __name__ == '__main__': unittest.main() From ed5fba6b0f8794c5d0c2b66a4ede2fdd6346ec75 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wilson=20J=C3=BAnior?= Date: Tue, 13 Dec 2011 07:46:49 -0200 Subject: [PATCH 2/6] support for embedded fields --- mongoengine/queryset.py | 14 ++++++++++++-- tests/queryset.py | 41 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 2 deletions(-) diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index 1b3257ac..1a2ff1c3 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -320,12 +320,22 @@ class SelectResult(object): """ def __init__(self, cursor, fields): self._cursor = cursor - self._fields = fields + self._fields = [f.split('.') for f in fields] + + def _get_value(self, keys, data): + for key in keys: + if data: + data = data.get(key) + else: + break + + return data def next(self): try: data = self._cursor.next() - return [data.get(f) for f in self._fields] + return [self._get_value(f, data) + for f in self._fields] except StopIteration, e: self.rewind() raise e diff --git a/tests/queryset.py b/tests/queryset.py index 0f5d84cf..991fb00a 100644 --- a/tests/queryset.py +++ b/tests/queryset.py @@ -2955,5 +2955,46 @@ class QueryFieldListTest(unittest.TestCase): [u'Eliana'], [u'Wilson']]) + def test_select_embedded(self): + class Profile(EmbeddedDocument): + name = StringField() + age = IntField() + + class Locale(EmbeddedDocument): + city = StringField() + country = StringField() + + class Person(Document): + profile = EmbeddedDocumentField(Profile) + locale = EmbeddedDocumentField(Locale) + + Person.drop_collection() + + Person(profile=Profile(name="Wilson Jr", age=19), + locale=Locale(city="Corumba-GO", country="Brazil")).save() + + Person(profile=Profile(name="Gabriel Falcao", age=23), + locale=Locale(city="New York", country="USA")).save() + + Person(profile=Profile(name="Lincoln de souza", age=28), + locale=Locale(city="Belo Horizonte", country="Brazil")).save() + + Person(profile=Profile(name="Walter cruz", age=30), + locale=Locale(city="Brasilia", country="Brazil")).save() + + self.assertEqual( + list(Person.objects.order_by('profile.age').select('profile.name')), + [[u'Wilson Jr'], [u'Gabriel Falcao'], + [u'Lincoln de souza'], [u'Walter cruz']]) + + ulist = list(Person.objects.order_by('locale.city') + .select('profile.name', 'profile.age', 'locale.city')) + self.assertEqual(ulist, + [[u'Lincoln de souza', 28, u'Belo Horizonte'], + [u'Walter cruz', 30, u'Brasilia'], + [u'Wilson Jr', 19, u'Corumba-GO'], + [u'Gabriel Falcao', 23, u'New York']]) + + if __name__ == '__main__': unittest.main() From ca7b2371fbbba4c705b6aaa12025b54eab329d16 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wilson=20J=C3=BAnior?= Date: Tue, 13 Dec 2011 11:54:19 -0200 Subject: [PATCH 3/6] added support for dereferences --- mongoengine/queryset.py | 45 +++++++++++++++++++++++++++++++++-------- tests/queryset.py | 31 ++++++++++++++++++++++++++++ 2 files changed, 68 insertions(+), 8 deletions(-) diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index 1a2ff1c3..80bff6f5 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -318,24 +318,51 @@ class SelectResult(object): """ Used for .select method in QuerySet """ - def __init__(self, cursor, fields): + def __init__(self, document_type, cursor, fields, dbfields): + from base import BaseField + from fields import ReferenceField + # Caches for optimization + self.ReferenceField = ReferenceField + self._cursor = cursor - self._fields = [f.split('.') for f in fields] - def _get_value(self, keys, data): + f = [] + for field, dbfield in itertools.izip(fields, dbfields): + + p = document_type + for path in field.split('.'): + if p and isinstance(p, BaseField): + p = p.lookup_member(path) + elif p: + p = getattr(p, path) + else: + break + + f.append((dbfield.split('.'), p)) + + self._fields = f + + def _get_value(self, keys, field_type, data): for key in keys: if data: data = data.get(key) else: break - return data + if isinstance(field_type, self.ReferenceField): + doc_type = field_type.document_type + data = doc_type._get_db().dereference(data) + + if data: + return doc_type._from_son(data) + + return field_type.to_python(data) def next(self): try: data = self._cursor.next() - return [self._get_value(f, data) - for f in self._fields] + return [self._get_value(k, t, data) + for k, t in self._fields] except StopIteration, e: self.rewind() raise e @@ -843,11 +870,13 @@ class QuerySet(object): """ Select a field and make a tuple of element """ + dbfields = self._fields_to_dbfields(fields) + cursor_args = self._cursor_args - cursor_args['fields'] = self._fields_to_dbfields(fields) + cursor_args['fields'] = dbfields cursor = self._build_cursor(**cursor_args) - return SelectResult(cursor, fields) + return SelectResult(self._document, cursor, fields, dbfields) def first(self): """Retrieve the first object matching the query. diff --git a/tests/queryset.py b/tests/queryset.py index 991fb00a..02e931e4 100644 --- a/tests/queryset.py +++ b/tests/queryset.py @@ -2994,7 +2994,38 @@ class QueryFieldListTest(unittest.TestCase): [u'Walter cruz', 30, u'Brasilia'], [u'Wilson Jr', 19, u'Corumba-GO'], [u'Gabriel Falcao', 23, u'New York']]) + + def test_select_decimal(self): + from decimal import Decimal + class Person(Document): + name = StringField() + rating = DecimalField() + Person.drop_collection() + Person(name="Wilson Jr", rating=Decimal('1.0')).save() + + ulist = list(Person.objects.select('name', 'rating')) + self.assertEqual(ulist, [[u'Wilson Jr', Decimal('1.0')]]) + + + def test_select_reference_field(self): + class State(Document): + name = StringField() + + class Person(Document): + name = StringField() + state = ReferenceField(State) + + State.drop_collection() + Person.drop_collection() + + s1 = State(name="Goias") + s1.save() + + Person(name="Wilson JR", state=s1).save() + + plist = list(Person.objects.select('name', 'state')) + self.assertEqual(plist, [[u'Wilson JR', s1]]) if __name__ == '__main__': unittest.main() From 7c1afd00313c5e0eaad64079efa8be127a0a4f1b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wilson=20J=C3=BAnior?= Date: Tue, 13 Dec 2011 11:56:35 -0200 Subject: [PATCH 4/6] tests for db_field --- tests/queryset.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/queryset.py b/tests/queryset.py index 02e931e4..4b619560 100644 --- a/tests/queryset.py +++ b/tests/queryset.py @@ -3027,5 +3027,24 @@ class QueryFieldListTest(unittest.TestCase): plist = list(Person.objects.select('name', 'state')) self.assertEqual(plist, [[u'Wilson JR', s1]]) + + def test_select_db_field(self): + class TestDoc(Document): + x = IntField(db_field="y") + y = BooleanField(db_field="x") + + TestDoc.drop_collection() + + TestDoc(x=10, y=True).save() + TestDoc(x=20, y=False).save() + TestDoc(x=30, y=True).save() + + plist = list(TestDoc.objects.select('x', 'y')) + + self.assertEqual(len(plist), 3) + self.assertEqual(plist[0], [10, True]) + self.assertEqual(plist[1], [20, False]) + self.assertEqual(plist[2], [30, True]) + if __name__ == '__main__': unittest.main() From 62219d96480dd54004c5563ec926ad07d127bd9c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wilson=20J=C3=BAnior?= Date: Fri, 16 Dec 2011 11:07:38 -0200 Subject: [PATCH 5/6] changed name --- mongoengine/queryset.py | 19 ++++++++++++------ tests/queryset.py | 44 +++++++++++++++++++++++++++++------------ 2 files changed, 44 insertions(+), 19 deletions(-) diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index 80bff6f5..23e581b8 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -314,15 +314,17 @@ class QueryFieldList(object): def __nonzero__(self): return bool(self.fields) -class SelectResult(object): +class ListResult(object): """ - Used for .select method in QuerySet + Used for .values_list method in QuerySet """ def __init__(self, document_type, cursor, fields, dbfields): from base import BaseField - from fields import ReferenceField + from fields import ReferenceField, GenericReferenceField # Caches for optimization + self.ReferenceField = ReferenceField + self.GenericReferenceField = GenericReferenceField self._cursor = cursor @@ -356,6 +358,10 @@ class SelectResult(object): if data: return doc_type._from_son(data) + elif isinstance(field_type, self.GenericReferenceField): + if data and isinstance(data, (dict, pymongo.dbref.DBRef)): + return field_type.dereference(data) + return field_type.to_python(data) def next(self): @@ -866,9 +872,10 @@ class QuerySet(object): doc.save() return doc - def select(self, *fields): + def values_list(self, *fields): """ - Select a field and make a tuple of element + make a list of elements + .. versionadded:: 0.6 """ dbfields = self._fields_to_dbfields(fields) @@ -876,7 +883,7 @@ class QuerySet(object): cursor_args['fields'] = dbfields cursor = self._build_cursor(**cursor_args) - return SelectResult(self._document, cursor, fields, dbfields) + return ListResult(self._document, cursor, fields, dbfields) def first(self): """Retrieve the first object matching the query. diff --git a/tests/queryset.py b/tests/queryset.py index 4b619560..b749340c 100644 --- a/tests/queryset.py +++ b/tests/queryset.py @@ -2910,7 +2910,7 @@ class QueryFieldListTest(unittest.TestCase): ak = list(Bar.objects(foo__match={'shape': "square", "color": "purple"})) self.assertEqual([b1], ak) - def test_select(self): + def test_values_list(self): class TestDoc(Document): x = IntField() y = BooleanField() @@ -2921,7 +2921,7 @@ class QueryFieldListTest(unittest.TestCase): TestDoc(x=20, y=False).save() TestDoc(x=30, y=True).save() - plist = list(TestDoc.objects.select('x', 'y')) + plist = list(TestDoc.objects.values_list('x', 'y')) self.assertEqual(len(plist), 3) self.assertEqual(plist[0], [10, True]) @@ -2939,7 +2939,7 @@ class QueryFieldListTest(unittest.TestCase): UserDoc(name="Eliana", age=37).save() UserDoc(name="Tayza", age=15).save() - ulist = list(UserDoc.objects.select('name', 'age')) + ulist = list(UserDoc.objects.values_list('name', 'age')) self.assertEqual(ulist, [ [u'Wilson Jr', 19], @@ -2947,7 +2947,7 @@ class QueryFieldListTest(unittest.TestCase): [u'Eliana', 37], [u'Tayza', 15]]) - ulist = list(UserDoc.objects.order_by('age').select('name')) + ulist = list(UserDoc.objects.order_by('age').values_list('name')) self.assertEqual(ulist, [ [u'Tayza'], @@ -2955,7 +2955,7 @@ class QueryFieldListTest(unittest.TestCase): [u'Eliana'], [u'Wilson']]) - def test_select_embedded(self): + def test_values_list_embedded(self): class Profile(EmbeddedDocument): name = StringField() age = IntField() @@ -2983,19 +2983,19 @@ class QueryFieldListTest(unittest.TestCase): locale=Locale(city="Brasilia", country="Brazil")).save() self.assertEqual( - list(Person.objects.order_by('profile.age').select('profile.name')), + list(Person.objects.order_by('profile.age').values_list('profile.name')), [[u'Wilson Jr'], [u'Gabriel Falcao'], [u'Lincoln de souza'], [u'Walter cruz']]) ulist = list(Person.objects.order_by('locale.city') - .select('profile.name', 'profile.age', 'locale.city')) + .values_list('profile.name', 'profile.age', 'locale.city')) self.assertEqual(ulist, [[u'Lincoln de souza', 28, u'Belo Horizonte'], [u'Walter cruz', 30, u'Brasilia'], [u'Wilson Jr', 19, u'Corumba-GO'], [u'Gabriel Falcao', 23, u'New York']]) - def test_select_decimal(self): + def test_values_list_decimal(self): from decimal import Decimal class Person(Document): name = StringField() @@ -3004,11 +3004,11 @@ class QueryFieldListTest(unittest.TestCase): Person.drop_collection() Person(name="Wilson Jr", rating=Decimal('1.0')).save() - ulist = list(Person.objects.select('name', 'rating')) + ulist = list(Person.objects.values_list('name', 'rating')) self.assertEqual(ulist, [[u'Wilson Jr', Decimal('1.0')]]) - def test_select_reference_field(self): + def test_values_list_reference_field(self): class State(Document): name = StringField() @@ -3024,11 +3024,29 @@ class QueryFieldListTest(unittest.TestCase): Person(name="Wilson JR", state=s1).save() - plist = list(Person.objects.select('name', 'state')) + plist = list(Person.objects.values_list('name', 'state')) self.assertEqual(plist, [[u'Wilson JR', s1]]) + def test_values_list_generic_reference_field(self): + class State(Document): + name = StringField() - def test_select_db_field(self): + class Person(Document): + name = StringField() + state = GenericReferenceField() + + State.drop_collection() + Person.drop_collection() + + s1 = State(name="Goias") + s1.save() + + Person(name="Wilson JR", state=s1).save() + + plist = list(Person.objects.values_list('name', 'state')) + self.assertEqual(plist, [[u'Wilson JR', s1]]) + + def test_values_list_db_field(self): class TestDoc(Document): x = IntField(db_field="y") y = BooleanField(db_field="x") @@ -3039,7 +3057,7 @@ class QueryFieldListTest(unittest.TestCase): TestDoc(x=20, y=False).save() TestDoc(x=30, y=True).save() - plist = list(TestDoc.objects.select('x', 'y')) + plist = list(TestDoc.objects.values_list('x', 'y')) self.assertEqual(len(plist), 3) self.assertEqual(plist[0], [10, True]) From 5ee4b4a5ac4193dc8add4168c0262f0786844a4f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wilson=20J=C3=BAnior?= Date: Fri, 16 Dec 2011 11:49:20 -0200 Subject: [PATCH 6/6] added count/len for ListResult --- mongoengine/queryset.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index 23e581b8..2f902273 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -376,6 +376,15 @@ class ListResult(object): def rewind(self): self._cursor.rewind() + def count(self): + """ + Count the selected elements in the query. + """ + return self._cursor.count(with_limit_and_skip=True) + + def __len__(self): + return self.count() + def __iter__(self): return self