added support for dereferences
This commit is contained in:
parent
ed5fba6b0f
commit
ca7b2371fb
@ -318,24 +318,51 @@ class SelectResult(object):
|
|||||||
"""
|
"""
|
||||||
Used for .select method in QuerySet
|
Used for .select method in QuerySet
|
||||||
"""
|
"""
|
||||||
def __init__(self, cursor, fields):
|
def __init__(self, document_type, cursor, fields, dbfields):
|
||||||
self._cursor = cursor
|
from base import BaseField
|
||||||
self._fields = [f.split('.') for f in fields]
|
from fields import ReferenceField
|
||||||
|
# Caches for optimization
|
||||||
|
self.ReferenceField = ReferenceField
|
||||||
|
|
||||||
def _get_value(self, keys, data):
|
self._cursor = cursor
|
||||||
|
|
||||||
|
f = []
|
||||||
|
for field, dbfield in itertools.izip(fields, dbfields):
|
||||||
|
|
||||||
|
p = document_type
|
||||||
|
for path in field.split('.'):
|
||||||
|
if p and isinstance(p, BaseField):
|
||||||
|
p = p.lookup_member(path)
|
||||||
|
elif p:
|
||||||
|
p = getattr(p, path)
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
f.append((dbfield.split('.'), p))
|
||||||
|
|
||||||
|
self._fields = f
|
||||||
|
|
||||||
|
def _get_value(self, keys, field_type, data):
|
||||||
for key in keys:
|
for key in keys:
|
||||||
if data:
|
if data:
|
||||||
data = data.get(key)
|
data = data.get(key)
|
||||||
else:
|
else:
|
||||||
break
|
break
|
||||||
|
|
||||||
return data
|
if isinstance(field_type, self.ReferenceField):
|
||||||
|
doc_type = field_type.document_type
|
||||||
|
data = doc_type._get_db().dereference(data)
|
||||||
|
|
||||||
|
if data:
|
||||||
|
return doc_type._from_son(data)
|
||||||
|
|
||||||
|
return field_type.to_python(data)
|
||||||
|
|
||||||
def next(self):
|
def next(self):
|
||||||
try:
|
try:
|
||||||
data = self._cursor.next()
|
data = self._cursor.next()
|
||||||
return [self._get_value(f, data)
|
return [self._get_value(k, t, data)
|
||||||
for f in self._fields]
|
for k, t in self._fields]
|
||||||
except StopIteration, e:
|
except StopIteration, e:
|
||||||
self.rewind()
|
self.rewind()
|
||||||
raise e
|
raise e
|
||||||
@ -843,11 +870,13 @@ class QuerySet(object):
|
|||||||
"""
|
"""
|
||||||
Select a field and make a tuple of element
|
Select a field and make a tuple of element
|
||||||
"""
|
"""
|
||||||
|
dbfields = self._fields_to_dbfields(fields)
|
||||||
|
|
||||||
cursor_args = self._cursor_args
|
cursor_args = self._cursor_args
|
||||||
cursor_args['fields'] = self._fields_to_dbfields(fields)
|
cursor_args['fields'] = dbfields
|
||||||
cursor = self._build_cursor(**cursor_args)
|
cursor = self._build_cursor(**cursor_args)
|
||||||
|
|
||||||
return SelectResult(cursor, fields)
|
return SelectResult(self._document, cursor, fields, dbfields)
|
||||||
|
|
||||||
def first(self):
|
def first(self):
|
||||||
"""Retrieve the first object matching the query.
|
"""Retrieve the first object matching the query.
|
||||||
|
@ -2995,6 +2995,37 @@ class QueryFieldListTest(unittest.TestCase):
|
|||||||
[u'Wilson Jr', 19, u'Corumba-GO'],
|
[u'Wilson Jr', 19, u'Corumba-GO'],
|
||||||
[u'Gabriel Falcao', 23, u'New York']])
|
[u'Gabriel Falcao', 23, u'New York']])
|
||||||
|
|
||||||
|
def test_select_decimal(self):
|
||||||
|
from decimal import Decimal
|
||||||
|
class Person(Document):
|
||||||
|
name = StringField()
|
||||||
|
rating = DecimalField()
|
||||||
|
|
||||||
|
Person.drop_collection()
|
||||||
|
Person(name="Wilson Jr", rating=Decimal('1.0')).save()
|
||||||
|
|
||||||
|
ulist = list(Person.objects.select('name', 'rating'))
|
||||||
|
self.assertEqual(ulist, [[u'Wilson Jr', Decimal('1.0')]])
|
||||||
|
|
||||||
|
|
||||||
|
def test_select_reference_field(self):
|
||||||
|
class State(Document):
|
||||||
|
name = StringField()
|
||||||
|
|
||||||
|
class Person(Document):
|
||||||
|
name = StringField()
|
||||||
|
state = ReferenceField(State)
|
||||||
|
|
||||||
|
State.drop_collection()
|
||||||
|
Person.drop_collection()
|
||||||
|
|
||||||
|
s1 = State(name="Goias")
|
||||||
|
s1.save()
|
||||||
|
|
||||||
|
Person(name="Wilson JR", state=s1).save()
|
||||||
|
|
||||||
|
plist = list(Person.objects.select('name', 'state'))
|
||||||
|
self.assertEqual(plist, [[u'Wilson JR', s1]])
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user