Added ability to have scalar return values instead of partially-populated Document instances.

This commit is contained in:
Alice Bevan-McGregor 2012-01-25 17:06:58 -05:00 committed by Ross Lawley
parent 6bad4bd415
commit 9a190eb00d

View File

@ -411,6 +411,7 @@ class QuerySet(object):
self._timeout = True
self._class_check = True
self._slave_okay = False
self._scalar = []
# If inheritance is allowed, only return instances and instances of
# subclasses of the class being used
@ -977,8 +978,13 @@ class QuerySet(object):
docs = self._collection.find({'_id': {'$in': object_ids}},
**self._cursor_args)
for doc in docs:
doc_map[doc['_id']] = self._document._from_son(doc)
if self._scalar:
for doc in docs:
doc_map[doc['_id']] = self._get_scalar(
self._document._from_son(doc))
else:
for doc in docs:
doc_map[doc['_id']] = self._document._from_son(doc)
return doc_map
@ -988,6 +994,9 @@ class QuerySet(object):
try:
if self._limit == 0:
raise StopIteration
if self._scalar:
return self._get_scalar(self._document._from_son(
self._cursor.next()))
return self._document._from_son(self._cursor.next())
except StopIteration, e:
self.rewind()
@ -1164,6 +1173,9 @@ class QuerySet(object):
return self
# Integer index provided
elif isinstance(key, int):
if self._scalar:
return self._get_scalar(self._document._from_son(
self._cursor[key]))
return self._document._from_son(self._cursor[key])
raise AttributeError
@ -1490,6 +1502,38 @@ class QuerySet(object):
self.rewind()
return self
def _get_scalar(self, doc):
def lookup(obj, name):
chunks = name.split('__')
for chunk in chunks:
obj = getattr(obj, chunk)
return obj
data = [lookup(doc, n) for n in self._scalar]
if len(data) == 1:
return data[0]
return tuple(data)
def scalar(self, *fields):
"""Instead of returning Document instances, return either a specific
value or a tuple of values in order.
This effects all results and can be unset by calling ``scalar``
without arguments. Calls ``only`` automatically.
:param fields: One or more fields to return instead of a Document.
"""
self._scalar = list(fields)
if fields:
self.only(*fields)
else:
self.all_fields()
return self
def _sub_js_fields(self, code):
"""When fields are specified with [~fieldname] syntax, where
*fieldname* is the Python name of a field, *fieldname* will be