From ca7b2371fbbba4c705b6aaa12025b54eab329d16 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wilson=20J=C3=BAnior?= Date: Tue, 13 Dec 2011 11:54:19 -0200 Subject: [PATCH] added support for dereferences --- mongoengine/queryset.py | 45 +++++++++++++++++++++++++++++++++-------- tests/queryset.py | 31 ++++++++++++++++++++++++++++ 2 files changed, 68 insertions(+), 8 deletions(-) diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index 1a2ff1c3..80bff6f5 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -318,24 +318,51 @@ class SelectResult(object): """ Used for .select method in QuerySet """ - def __init__(self, cursor, fields): + def __init__(self, document_type, cursor, fields, dbfields): + from base import BaseField + from fields import ReferenceField + # Caches for optimization + self.ReferenceField = ReferenceField + self._cursor = cursor - self._fields = [f.split('.') for f in fields] - def _get_value(self, keys, data): + f = [] + for field, dbfield in itertools.izip(fields, dbfields): + + p = document_type + for path in field.split('.'): + if p and isinstance(p, BaseField): + p = p.lookup_member(path) + elif p: + p = getattr(p, path) + else: + break + + f.append((dbfield.split('.'), p)) + + self._fields = f + + def _get_value(self, keys, field_type, data): for key in keys: if data: data = data.get(key) else: break - return data + if isinstance(field_type, self.ReferenceField): + doc_type = field_type.document_type + data = doc_type._get_db().dereference(data) + + if data: + return doc_type._from_son(data) + + return field_type.to_python(data) def next(self): try: data = self._cursor.next() - return [self._get_value(f, data) - for f in self._fields] + return [self._get_value(k, t, data) + for k, t in self._fields] except StopIteration, e: self.rewind() raise e @@ -843,11 +870,13 @@ class QuerySet(object): """ Select a field and make a tuple of element """ + dbfields = self._fields_to_dbfields(fields) + cursor_args = self._cursor_args - cursor_args['fields'] = self._fields_to_dbfields(fields) + cursor_args['fields'] = dbfields cursor = self._build_cursor(**cursor_args) - return SelectResult(cursor, fields) + return SelectResult(self._document, cursor, fields, dbfields) def first(self): """Retrieve the first object matching the query. diff --git a/tests/queryset.py b/tests/queryset.py index 991fb00a..02e931e4 100644 --- a/tests/queryset.py +++ b/tests/queryset.py @@ -2994,7 +2994,38 @@ class QueryFieldListTest(unittest.TestCase): [u'Walter cruz', 30, u'Brasilia'], [u'Wilson Jr', 19, u'Corumba-GO'], [u'Gabriel Falcao', 23, u'New York']]) + + def test_select_decimal(self): + from decimal import Decimal + class Person(Document): + name = StringField() + rating = DecimalField() + Person.drop_collection() + Person(name="Wilson Jr", rating=Decimal('1.0')).save() + + ulist = list(Person.objects.select('name', 'rating')) + self.assertEqual(ulist, [[u'Wilson Jr', Decimal('1.0')]]) + + + def test_select_reference_field(self): + class State(Document): + name = StringField() + + class Person(Document): + name = StringField() + state = ReferenceField(State) + + State.drop_collection() + Person.drop_collection() + + s1 = State(name="Goias") + s1.save() + + Person(name="Wilson JR", state=s1).save() + + plist = list(Person.objects.select('name', 'state')) + self.assertEqual(plist, [[u'Wilson JR', s1]]) if __name__ == '__main__': unittest.main()