added support for dereferences

This commit is contained in:
Wilson Júnior 2011-12-13 11:54:19 -02:00
parent ed5fba6b0f
commit ca7b2371fb
2 changed files with 68 additions and 8 deletions

View File

@ -318,24 +318,51 @@ class SelectResult(object):
"""
Used for .select method in QuerySet
"""
def __init__(self, cursor, fields):
def __init__(self, document_type, cursor, fields, dbfields):
from base import BaseField
from fields import ReferenceField
# Caches for optimization
self.ReferenceField = ReferenceField
self._cursor = cursor
self._fields = [f.split('.') for f in fields]
def _get_value(self, keys, data):
f = []
for field, dbfield in itertools.izip(fields, dbfields):
p = document_type
for path in field.split('.'):
if p and isinstance(p, BaseField):
p = p.lookup_member(path)
elif p:
p = getattr(p, path)
else:
break
f.append((dbfield.split('.'), p))
self._fields = f
def _get_value(self, keys, field_type, data):
for key in keys:
if data:
data = data.get(key)
else:
break
return data
if isinstance(field_type, self.ReferenceField):
doc_type = field_type.document_type
data = doc_type._get_db().dereference(data)
if data:
return doc_type._from_son(data)
return field_type.to_python(data)
def next(self):
try:
data = self._cursor.next()
return [self._get_value(f, data)
for f in self._fields]
return [self._get_value(k, t, data)
for k, t in self._fields]
except StopIteration, e:
self.rewind()
raise e
@ -843,11 +870,13 @@ class QuerySet(object):
"""
Select a field and make a tuple of element
"""
dbfields = self._fields_to_dbfields(fields)
cursor_args = self._cursor_args
cursor_args['fields'] = self._fields_to_dbfields(fields)
cursor_args['fields'] = dbfields
cursor = self._build_cursor(**cursor_args)
return SelectResult(cursor, fields)
return SelectResult(self._document, cursor, fields, dbfields)
def first(self):
"""Retrieve the first object matching the query.

View File

@ -2994,7 +2994,38 @@ class QueryFieldListTest(unittest.TestCase):
[u'Walter cruz', 30, u'Brasilia'],
[u'Wilson Jr', 19, u'Corumba-GO'],
[u'Gabriel Falcao', 23, u'New York']])
def test_select_decimal(self):
from decimal import Decimal
class Person(Document):
name = StringField()
rating = DecimalField()
Person.drop_collection()
Person(name="Wilson Jr", rating=Decimal('1.0')).save()
ulist = list(Person.objects.select('name', 'rating'))
self.assertEqual(ulist, [[u'Wilson Jr', Decimal('1.0')]])
def test_select_reference_field(self):
class State(Document):
name = StringField()
class Person(Document):
name = StringField()
state = ReferenceField(State)
State.drop_collection()
Person.drop_collection()
s1 = State(name="Goias")
s1.save()
Person(name="Wilson JR", state=s1).save()
plist = list(Person.objects.select('name', 'state'))
self.assertEqual(plist, [[u'Wilson JR', s1]])
if __name__ == '__main__':
unittest.main()