diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index 0c39253b..1b3257ac 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -314,6 +314,27 @@ class QueryFieldList(object): def __nonzero__(self): return bool(self.fields) +class SelectResult(object): + """ + Used for .select method in QuerySet + """ + def __init__(self, cursor, fields): + self._cursor = cursor + self._fields = fields + + def next(self): + try: + data = self._cursor.next() + return [data.get(f) for f in self._fields] + except StopIteration, e: + self.rewind() + raise e + + def rewind(self): + self._cursor.rewind() + + def __iter__(self): + return self class QuerySet(object): """A set of results returned from a query. Wraps a MongoDB cursor, @@ -548,33 +569,38 @@ class QuerySet(object): cursor_args['fields'] = self._loaded_fields.as_dict() return cursor_args + def _build_cursor(self, **cursor_args): + obj = self._collection.find(self._query, + **cursor_args) + # Apply where clauses to cursor + if self._where_clause: + obj.where(self._where_clause) + + # apply default ordering + if self._ordering: + obj.sort(self._ordering) + elif self._document._meta['ordering']: + self._ordering = self._get_order_key_list( + *self._document._meta['ordering']) + obj.sort(self._ordering) + + if self._limit is not None: + obj.limit(self._limit) + + if self._skip is not None: + obj.skip(self._skip) + + if self._hint != -1: + obj.hint(self._hint) + + return obj + @property def _cursor(self): if self._cursor_obj is None: - - self._cursor_obj = self._collection.find(self._query, - **self._cursor_args) - # Apply where clauses to cursor - if self._where_clause: - self._cursor_obj.where(self._where_clause) - - # apply default ordering - if self._ordering: - self._cursor_obj.sort(self._ordering) - elif self._document._meta['ordering']: - self.order_by(*self._document._meta['ordering']) - - if self._limit is not None: - self._cursor_obj.limit(self._limit) - - if self._skip is not None: - self._cursor_obj.skip(self._skip) - - if self._hint != -1: - self._cursor_obj.hint(self._hint) - + self._cursor_obj = self._build_cursor(**self._cursor_args) + return self._cursor_obj - @classmethod def _lookup_field(cls, document, parts): """Lookup a field based on its attribute and return a list containing @@ -803,6 +829,16 @@ class QuerySet(object): doc.save() return doc + def select(self, *fields): + """ + Select a field and make a tuple of element + """ + cursor_args = self._cursor_args + cursor_args['fields'] = self._fields_to_dbfields(fields) + cursor = self._build_cursor(**cursor_args) + + return SelectResult(cursor, fields) + def first(self): """Retrieve the first object matching the query. """ @@ -1163,13 +1199,9 @@ class QuerySet(object): ret.append(field) return ret - def order_by(self, *keys): - """Order the :class:`~mongoengine.queryset.QuerySet` by the keys. The - order may be specified by prepending each of the keys by a + or a -. - Ascending order is assumed. - - :param keys: fields to order the query results by; keys may be - prefixed with **+** or **-** to determine the ordering direction + def _get_order_key_list(self, *keys): + """ + Build order list for query """ key_list = [] for key in keys: @@ -1186,6 +1218,18 @@ class QuerySet(object): pass key_list.append((key, direction)) + return key_list + + def order_by(self, *keys): + """Order the :class:`~mongoengine.queryset.QuerySet` by the keys. The + order may be specified by prepending each of the keys by a + or a -. + Ascending order is assumed. + + :param keys: fields to order the query results by; keys may be + prefixed with **+** or **-** to determine the ordering direction + """ + + key_list = self._get_order_key_list(*keys) self._ordering = key_list self._cursor.sort(key_list) return self diff --git a/tests/queryset.py b/tests/queryset.py index 37fa5247..0f5d84cf 100644 --- a/tests/queryset.py +++ b/tests/queryset.py @@ -2909,6 +2909,51 @@ class QueryFieldListTest(unittest.TestCase): ak = list(Bar.objects(foo__match={'shape': "square", "color": "purple"})) self.assertEqual([b1], ak) + + def test_select(self): + class TestDoc(Document): + x = IntField() + y = BooleanField() + + TestDoc.drop_collection() + + TestDoc(x=10, y=True).save() + TestDoc(x=20, y=False).save() + TestDoc(x=30, y=True).save() + + plist = list(TestDoc.objects.select('x', 'y')) + + self.assertEqual(len(plist), 3) + self.assertEqual(plist[0], [10, True]) + self.assertEqual(plist[1], [20, False]) + self.assertEqual(plist[2], [30, True]) + + class UserDoc(Document): + name = StringField() + age = IntField() + + UserDoc.drop_collection() + + UserDoc(name="Wilson Jr", age=19).save() + UserDoc(name="Wilson", age=43).save() + UserDoc(name="Eliana", age=37).save() + UserDoc(name="Tayza", age=15).save() + + ulist = list(UserDoc.objects.select('name', 'age')) + + self.assertEqual(ulist, [ + [u'Wilson Jr', 19], + [u'Wilson', 43], + [u'Eliana', 37], + [u'Tayza', 15]]) + + ulist = list(UserDoc.objects.order_by('age').select('name')) + + self.assertEqual(ulist, [ + [u'Tayza'], + [u'Wilson Jr'], + [u'Eliana'], + [u'Wilson']]) if __name__ == '__main__': unittest.main()