diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index 0c39253b..2f902273 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -314,6 +314,79 @@ class QueryFieldList(object): def __nonzero__(self): return bool(self.fields) +class ListResult(object): + """ + Used for .values_list method in QuerySet + """ + def __init__(self, document_type, cursor, fields, dbfields): + from base import BaseField + from fields import ReferenceField, GenericReferenceField + # Caches for optimization + + self.ReferenceField = ReferenceField + self.GenericReferenceField = GenericReferenceField + + self._cursor = cursor + + f = [] + for field, dbfield in itertools.izip(fields, dbfields): + + p = document_type + for path in field.split('.'): + if p and isinstance(p, BaseField): + p = p.lookup_member(path) + elif p: + p = getattr(p, path) + else: + break + + f.append((dbfield.split('.'), p)) + + self._fields = f + + def _get_value(self, keys, field_type, data): + for key in keys: + if data: + data = data.get(key) + else: + break + + if isinstance(field_type, self.ReferenceField): + doc_type = field_type.document_type + data = doc_type._get_db().dereference(data) + + if data: + return doc_type._from_son(data) + + elif isinstance(field_type, self.GenericReferenceField): + if data and isinstance(data, (dict, pymongo.dbref.DBRef)): + return field_type.dereference(data) + + return field_type.to_python(data) + + def next(self): + try: + data = self._cursor.next() + return [self._get_value(k, t, data) + for k, t in self._fields] + except StopIteration, e: + self.rewind() + raise e + + def rewind(self): + self._cursor.rewind() + + def count(self): + """ + Count the selected elements in the query. + """ + return self._cursor.count(with_limit_and_skip=True) + + def __len__(self): + return self.count() + + def __iter__(self): + return self class QuerySet(object): """A set of results returned from a query. Wraps a MongoDB cursor, @@ -548,33 +621,38 @@ class QuerySet(object): cursor_args['fields'] = self._loaded_fields.as_dict() return cursor_args + def _build_cursor(self, **cursor_args): + obj = self._collection.find(self._query, + **cursor_args) + # Apply where clauses to cursor + if self._where_clause: + obj.where(self._where_clause) + + # apply default ordering + if self._ordering: + obj.sort(self._ordering) + elif self._document._meta['ordering']: + self._ordering = self._get_order_key_list( + *self._document._meta['ordering']) + obj.sort(self._ordering) + + if self._limit is not None: + obj.limit(self._limit) + + if self._skip is not None: + obj.skip(self._skip) + + if self._hint != -1: + obj.hint(self._hint) + + return obj + @property def _cursor(self): if self._cursor_obj is None: - - self._cursor_obj = self._collection.find(self._query, - **self._cursor_args) - # Apply where clauses to cursor - if self._where_clause: - self._cursor_obj.where(self._where_clause) - - # apply default ordering - if self._ordering: - self._cursor_obj.sort(self._ordering) - elif self._document._meta['ordering']: - self.order_by(*self._document._meta['ordering']) - - if self._limit is not None: - self._cursor_obj.limit(self._limit) - - if self._skip is not None: - self._cursor_obj.skip(self._skip) - - if self._hint != -1: - self._cursor_obj.hint(self._hint) - + self._cursor_obj = self._build_cursor(**self._cursor_args) + return self._cursor_obj - @classmethod def _lookup_field(cls, document, parts): """Lookup a field based on its attribute and return a list containing @@ -803,6 +881,19 @@ class QuerySet(object): doc.save() return doc + def values_list(self, *fields): + """ + make a list of elements + .. versionadded:: 0.6 + """ + dbfields = self._fields_to_dbfields(fields) + + cursor_args = self._cursor_args + cursor_args['fields'] = dbfields + cursor = self._build_cursor(**cursor_args) + + return ListResult(self._document, cursor, fields, dbfields) + def first(self): """Retrieve the first object matching the query. """ @@ -1163,13 +1254,9 @@ class QuerySet(object): ret.append(field) return ret - def order_by(self, *keys): - """Order the :class:`~mongoengine.queryset.QuerySet` by the keys. The - order may be specified by prepending each of the keys by a + or a -. - Ascending order is assumed. - - :param keys: fields to order the query results by; keys may be - prefixed with **+** or **-** to determine the ordering direction + def _get_order_key_list(self, *keys): + """ + Build order list for query """ key_list = [] for key in keys: @@ -1186,6 +1273,18 @@ class QuerySet(object): pass key_list.append((key, direction)) + return key_list + + def order_by(self, *keys): + """Order the :class:`~mongoengine.queryset.QuerySet` by the keys. The + order may be specified by prepending each of the keys by a + or a -. + Ascending order is assumed. + + :param keys: fields to order the query results by; keys may be + prefixed with **+** or **-** to determine the ordering direction + """ + + key_list = self._get_order_key_list(*keys) self._ordering = key_list self._cursor.sort(key_list) return self diff --git a/tests/queryset.py b/tests/queryset.py index 37fa5247..b749340c 100644 --- a/tests/queryset.py +++ b/tests/queryset.py @@ -2909,6 +2909,160 @@ class QueryFieldListTest(unittest.TestCase): ak = list(Bar.objects(foo__match={'shape': "square", "color": "purple"})) self.assertEqual([b1], ak) + + def test_values_list(self): + class TestDoc(Document): + x = IntField() + y = BooleanField() + + TestDoc.drop_collection() + + TestDoc(x=10, y=True).save() + TestDoc(x=20, y=False).save() + TestDoc(x=30, y=True).save() + + plist = list(TestDoc.objects.values_list('x', 'y')) + + self.assertEqual(len(plist), 3) + self.assertEqual(plist[0], [10, True]) + self.assertEqual(plist[1], [20, False]) + self.assertEqual(plist[2], [30, True]) + + class UserDoc(Document): + name = StringField() + age = IntField() + + UserDoc.drop_collection() + + UserDoc(name="Wilson Jr", age=19).save() + UserDoc(name="Wilson", age=43).save() + UserDoc(name="Eliana", age=37).save() + UserDoc(name="Tayza", age=15).save() + + ulist = list(UserDoc.objects.values_list('name', 'age')) + + self.assertEqual(ulist, [ + [u'Wilson Jr', 19], + [u'Wilson', 43], + [u'Eliana', 37], + [u'Tayza', 15]]) + + ulist = list(UserDoc.objects.order_by('age').values_list('name')) + + self.assertEqual(ulist, [ + [u'Tayza'], + [u'Wilson Jr'], + [u'Eliana'], + [u'Wilson']]) + + def test_values_list_embedded(self): + class Profile(EmbeddedDocument): + name = StringField() + age = IntField() + + class Locale(EmbeddedDocument): + city = StringField() + country = StringField() + + class Person(Document): + profile = EmbeddedDocumentField(Profile) + locale = EmbeddedDocumentField(Locale) + + Person.drop_collection() + + Person(profile=Profile(name="Wilson Jr", age=19), + locale=Locale(city="Corumba-GO", country="Brazil")).save() + + Person(profile=Profile(name="Gabriel Falcao", age=23), + locale=Locale(city="New York", country="USA")).save() + + Person(profile=Profile(name="Lincoln de souza", age=28), + locale=Locale(city="Belo Horizonte", country="Brazil")).save() + + Person(profile=Profile(name="Walter cruz", age=30), + locale=Locale(city="Brasilia", country="Brazil")).save() + + self.assertEqual( + list(Person.objects.order_by('profile.age').values_list('profile.name')), + [[u'Wilson Jr'], [u'Gabriel Falcao'], + [u'Lincoln de souza'], [u'Walter cruz']]) + + ulist = list(Person.objects.order_by('locale.city') + .values_list('profile.name', 'profile.age', 'locale.city')) + self.assertEqual(ulist, + [[u'Lincoln de souza', 28, u'Belo Horizonte'], + [u'Walter cruz', 30, u'Brasilia'], + [u'Wilson Jr', 19, u'Corumba-GO'], + [u'Gabriel Falcao', 23, u'New York']]) + + def test_values_list_decimal(self): + from decimal import Decimal + class Person(Document): + name = StringField() + rating = DecimalField() + + Person.drop_collection() + Person(name="Wilson Jr", rating=Decimal('1.0')).save() + + ulist = list(Person.objects.values_list('name', 'rating')) + self.assertEqual(ulist, [[u'Wilson Jr', Decimal('1.0')]]) + + + def test_values_list_reference_field(self): + class State(Document): + name = StringField() + + class Person(Document): + name = StringField() + state = ReferenceField(State) + + State.drop_collection() + Person.drop_collection() + + s1 = State(name="Goias") + s1.save() + + Person(name="Wilson JR", state=s1).save() + + plist = list(Person.objects.values_list('name', 'state')) + self.assertEqual(plist, [[u'Wilson JR', s1]]) + + def test_values_list_generic_reference_field(self): + class State(Document): + name = StringField() + + class Person(Document): + name = StringField() + state = GenericReferenceField() + + State.drop_collection() + Person.drop_collection() + + s1 = State(name="Goias") + s1.save() + + Person(name="Wilson JR", state=s1).save() + + plist = list(Person.objects.values_list('name', 'state')) + self.assertEqual(plist, [[u'Wilson JR', s1]]) + + def test_values_list_db_field(self): + class TestDoc(Document): + x = IntField(db_field="y") + y = BooleanField(db_field="x") + + TestDoc.drop_collection() + + TestDoc(x=10, y=True).save() + TestDoc(x=20, y=False).save() + TestDoc(x=30, y=True).save() + + plist = list(TestDoc.objects.values_list('x', 'y')) + + self.assertEqual(len(plist), 3) + self.assertEqual(plist[0], [10, True]) + self.assertEqual(plist[1], [20, False]) + self.assertEqual(plist[2], [30, True]) if __name__ == '__main__': unittest.main()