Added .scalar to QuerySet

More efficient than the previous .values_list implementation. Ref #393
Reverted some of the .values_list code that's no longer needed.

Closes #415
This commit is contained in:
Ross Lawley 2012-01-27 11:41:42 +00:00
parent 9a190eb00d
commit f60a49d6f6
6 changed files with 123 additions and 187 deletions

View File

@ -1,11 +1,11 @@
The PRIMARY AUTHORS are (and/or have been): The PRIMARY AUTHORS are (and/or have been):
Ross Lawley <ross.lawley@gmail.com>
Harry Marr <harry@hmarr.com> Harry Marr <harry@hmarr.com>
Matt Dennewitz <mattdennewitz@gmail.com> Matt Dennewitz <mattdennewitz@gmail.com>
Deepak Thukral <iapain@yahoo.com> Deepak Thukral <iapain@yahoo.com>
Florian Schlachter <flori@n-schlachter.de> Florian Schlachter <flori@n-schlachter.de>
Steve Challis <steve@stevechallis.com> Steve Challis <steve@stevechallis.com>
Ross Lawley <ross.lawley@gmail.com>
Wilson Júnior <wilsonpjunior@gmail.com> Wilson Júnior <wilsonpjunior@gmail.com>
Dan Crosta https://github.com/dcrosta Dan Crosta https://github.com/dcrosta

View File

@ -5,8 +5,8 @@ Changelog
Changes in dev Changes in dev
============== ==============
- Added scalar for efficiently returning partial data values (aliased to values_list)
- Fixed limit skip bug - Fixed limit skip bug
- Added values_list for returning a list of data
- Improved Inheritance / Mixin - Improved Inheritance / Mixin
- Added sharding support - Added sharding support
- Added pymongo 2.1 support - Added pymongo 2.1 support

View File

@ -468,14 +468,14 @@ class ObjectIdField(BaseField):
class DocumentMetaclass(type): class DocumentMetaclass(type):
"""Metaclass for all documents. """Metaclass for all documents.
""" """
def __new__(cls, name, bases, attrs): def __new__(cls, name, bases, attrs):
def _get_mixin_fields(base): def _get_mixin_fields(base):
attrs = {} attrs = {}
attrs.update(dict([(k, v) for k, v in base.__dict__.items() attrs.update(dict([(k, v) for k, v in base.__dict__.items()
if issubclass(v.__class__, BaseField)])) if issubclass(v.__class__, BaseField)]))
for p_base in base.__bases__: for p_base in base.__bases__:
#optimize :-) #optimize :-)
if p_base in (object, BaseDocument): if p_base in (object, BaseDocument):

View File

@ -314,82 +314,6 @@ class QueryFieldList(object):
def __nonzero__(self): def __nonzero__(self):
return bool(self.fields) return bool(self.fields)
class ListResult(object):
"""
Used for .values_list method in QuerySet
"""
def __init__(self, document_type, cursor, fields, dbfields):
from base import BaseField
from fields import ReferenceField, GenericReferenceField
# Caches for optimization
self.ReferenceField = ReferenceField
self.GenericReferenceField = GenericReferenceField
self._cursor = cursor
f = []
for field, dbfield in itertools.izip(fields, dbfields):
p = document_type
for path in field.split('.'):
if p and isinstance(p, BaseField):
p = p.lookup_member(path)
elif p:
p = getattr(p, path)
else:
break
f.append((dbfield.split('.'), p))
self._fields = f
def _get_value(self, keys, field_type, data):
for key in keys:
if data:
data = data.get(key)
else:
break
if isinstance(field_type, self.ReferenceField):
doc_type = field_type.document_type
data = doc_type._get_db().dereference(data)
if data:
return doc_type._from_son(data)
elif isinstance(field_type, self.GenericReferenceField):
if data and isinstance(data, (dict, pymongo.dbref.DBRef)):
return field_type.dereference(data)
if data is None:
return
return field_type.to_python(data)
def next(self):
try:
data = self._cursor.next()
return [self._get_value(k, t, data)
for k, t in self._fields]
except StopIteration, e:
self.rewind()
raise e
def rewind(self):
self._cursor.rewind()
def count(self):
"""
Count the selected elements in the query.
"""
return self._cursor.count(with_limit_and_skip=True)
def __len__(self):
return self.count()
def __iter__(self):
return self
class QuerySet(object): class QuerySet(object):
"""A set of results returned from a query. Wraps a MongoDB cursor, """A set of results returned from a query. Wraps a MongoDB cursor,
@ -625,38 +549,33 @@ class QuerySet(object):
cursor_args['fields'] = self._loaded_fields.as_dict() cursor_args['fields'] = self._loaded_fields.as_dict()
return cursor_args return cursor_args
def _build_cursor(self, **cursor_args):
obj = self._collection.find(self._query,
**cursor_args)
# Apply where clauses to cursor
if self._where_clause:
obj.where(self._where_clause)
# apply default ordering
if self._ordering:
obj.sort(self._ordering)
elif self._document._meta['ordering']:
self._ordering = self._get_order_key_list(
*self._document._meta['ordering'])
obj.sort(self._ordering)
if self._limit is not None:
obj.limit(self._limit - (self._skip or 0))
if self._skip is not None:
obj.skip(self._skip)
if self._hint != -1:
obj.hint(self._hint)
return obj
@property @property
def _cursor(self): def _cursor(self):
if self._cursor_obj is None: if self._cursor_obj is None:
self._cursor_obj = self._build_cursor(**self._cursor_args)
self._cursor_obj = self._collection.find(self._query,
**self._cursor_args)
# Apply where clauses to cursor
if self._where_clause:
self._cursor_obj.where(self._where_clause)
# apply default ordering
if self._ordering:
self._cursor_obj.sort(self._ordering)
elif self._document._meta['ordering']:
self.order_by(*self._document._meta['ordering'])
if self._limit is not None:
self._cursor_obj.limit(self._limit - (self._skip or 0))
if self._skip is not None:
self._cursor_obj.skip(self._skip)
if self._hint != -1:
self._cursor_obj.hint(self._hint)
return self._cursor_obj return self._cursor_obj
@classmethod @classmethod
def _lookup_field(cls, document, parts): def _lookup_field(cls, document, parts):
"""Lookup a field based on its attribute and return a list containing """Lookup a field based on its attribute and return a list containing
@ -885,19 +804,6 @@ class QuerySet(object):
doc.save() doc.save()
return doc return doc
def values_list(self, *fields):
"""
make a list of elements
.. versionadded:: 0.6
"""
dbfields = self._fields_to_dbfields(fields)
cursor_args = self._cursor_args
cursor_args['fields'] = dbfields
cursor = self._build_cursor(**cursor_args)
return ListResult(self._document, cursor, fields, dbfields)
def first(self): def first(self):
"""Retrieve the first object matching the query. """Retrieve the first object matching the query.
""" """
@ -1269,9 +1175,13 @@ class QuerySet(object):
ret.append(field) ret.append(field)
return ret return ret
def _get_order_key_list(self, *keys): def order_by(self, *keys):
""" """Order the :class:`~mongoengine.queryset.QuerySet` by the keys. The
Build order list for query order may be specified by prepending each of the keys by a + or a -.
Ascending order is assumed.
:param keys: fields to order the query results by; keys may be
prefixed with **+** or **-** to determine the ordering direction
""" """
key_list = [] key_list = []
for key in keys: for key in keys:
@ -1288,18 +1198,6 @@ class QuerySet(object):
pass pass
key_list.append((key, direction)) key_list.append((key, direction))
return key_list
def order_by(self, *keys):
"""Order the :class:`~mongoengine.queryset.QuerySet` by the keys. The
order may be specified by prepending each of the keys by a + or a -.
Ascending order is assumed.
:param keys: fields to order the query results by; keys may be
prefixed with **+** or **-** to determine the ordering direction
"""
key_list = self._get_order_key_list(*keys)
self._ordering = key_list self._ordering = key_list
self._cursor.sort(key_list) self._cursor.sort(key_list)
return self return self
@ -1503,37 +1401,43 @@ class QuerySet(object):
return self return self
def _get_scalar(self, doc): def _get_scalar(self, doc):
def lookup(obj, name): def lookup(obj, name):
chunks = name.split('__') chunks = name.split('__')
for chunk in chunks: for chunk in chunks:
if hasattr(obj, '_db_field_map'):
chunk = obj._db_field_map.get(chunk, chunk)
obj = getattr(obj, chunk) obj = getattr(obj, chunk)
return obj return obj
data = [lookup(doc, n) for n in self._scalar] data = [lookup(doc, n) for n in self._scalar]
if len(data) == 1: if len(data) == 1:
return data[0] return data[0]
return tuple(data) return tuple(data)
def scalar(self, *fields): def scalar(self, *fields):
"""Instead of returning Document instances, return either a specific """Instead of returning Document instances, return either a specific
value or a tuple of values in order. value or a tuple of values in order.
This affects all results and can be unset by calling ``scalar`` This affects all results and can be unset by calling ``scalar``
without arguments. Calls ``only`` automatically. without arguments. Calls ``only`` automatically.
:param fields: One or more fields to return instead of a Document. :param fields: One or more fields to return instead of a Document.
""" """
self._scalar = list(fields) self._scalar = list(fields)
if fields: if fields:
self.only(*fields) self.only(*fields)
else: else:
self.all_fields() self.all_fields()
return self return self
def values_list(self, *fields):
"""An alias for scalar"""
return self.scalar(*fields)
def _sub_js_fields(self, code): def _sub_js_fields(self, code):
"""When fields are specified with [~fieldname] syntax, where """When fields are specified with [~fieldname] syntax, where
*fieldname* is the Python name of a field, *fieldname* will be *fieldname* is the Python name of a field, *fieldname* will be

View File

@ -38,7 +38,9 @@ setup(name='mongoengine',
packages=find_packages(), packages=find_packages(),
author='Harry Marr', author='Harry Marr',
author_email='harry.marr@{nospam}gmail.com', author_email='harry.marr@{nospam}gmail.com',
url='http://hmarr.com/mongoengine/', maintainer="Ross Lawley",
maintainer_email="ross.lawley@gmail.com",
url='http://mongoengine.org/',
license='MIT', license='MIT',
include_package_data=True, include_package_data=True,
description=DESCRIPTION, description=DESCRIPTION,

View File

@ -2909,8 +2909,38 @@ class QueryFieldListTest(unittest.TestCase):
ak = list(Bar.objects(foo__match={'shape': "square", "color": "purple"})) ak = list(Bar.objects(foo__match={'shape': "square", "color": "purple"}))
self.assertEqual([b1], ak) self.assertEqual([b1], ak)
def test_values_list(self): def test_scalar(self):
class Organization(Document):
id = ObjectIdField('_id')
name = StringField()
class User(Document):
id = ObjectIdField('_id')
name = StringField()
organization = ObjectIdField()
User.drop_collection()
Organization.drop_collection()
whitehouse = Organization(name="White House")
whitehouse.save()
User(name="Bob Dole", organization=whitehouse.id).save()
# Efficient way to get all unique organization names for a given
# set of users (Pretend this has additional filtering.)
user_orgs = set(User.objects.scalar('organization'))
orgs = Organization.objects(id__in=user_orgs).scalar('name')
self.assertEqual(list(orgs), ['White House'])
# Efficient for generating listings, too.
orgs = Organization.objects.scalar('name').in_bulk(list(user_orgs))
user_map = User.objects.scalar('name', 'organization')
user_listing = [(user, orgs[org]) for user, org in user_map]
self.assertEqual([("Bob Dole", "White House")], user_listing)
def test_scalar_simple(self):
class TestDoc(Document): class TestDoc(Document):
x = IntField() x = IntField()
y = BooleanField() y = BooleanField()
@ -2921,12 +2951,12 @@ class QueryFieldListTest(unittest.TestCase):
TestDoc(x=20, y=False).save() TestDoc(x=20, y=False).save()
TestDoc(x=30, y=True).save() TestDoc(x=30, y=True).save()
plist = list(TestDoc.objects.values_list('x', 'y')) plist = list(TestDoc.objects.scalar('x', 'y'))
self.assertEqual(len(plist), 3) self.assertEqual(len(plist), 3)
self.assertEqual(plist[0], [10, True]) self.assertEqual(plist[0], (10, True))
self.assertEqual(plist[1], [20, False]) self.assertEqual(plist[1], (20, False))
self.assertEqual(plist[2], [30, True]) self.assertEqual(plist[2], (30, True))
class UserDoc(Document): class UserDoc(Document):
name = StringField() name = StringField()
@ -2939,23 +2969,23 @@ class QueryFieldListTest(unittest.TestCase):
UserDoc(name="Eliana", age=37).save() UserDoc(name="Eliana", age=37).save()
UserDoc(name="Tayza", age=15).save() UserDoc(name="Tayza", age=15).save()
ulist = list(UserDoc.objects.values_list('name', 'age')) ulist = list(UserDoc.objects.scalar('name', 'age'))
self.assertEqual(ulist, [ self.assertEqual(ulist, [
[u'Wilson Jr', 19], (u'Wilson Jr', 19),
[u'Wilson', 43], (u'Wilson', 43),
[u'Eliana', 37], (u'Eliana', 37),
[u'Tayza', 15]]) (u'Tayza', 15)])
ulist = list(UserDoc.objects.order_by('age').values_list('name')) ulist = list(UserDoc.objects.scalar('name').order_by('age'))
self.assertEqual(ulist, [ self.assertEqual(ulist, [
[u'Tayza'], (u'Tayza'),
[u'Wilson Jr'], (u'Wilson Jr'),
[u'Eliana'], (u'Eliana'),
[u'Wilson']]) (u'Wilson')])
def test_values_list_embedded(self): def test_scalar_embedded(self):
class Profile(EmbeddedDocument): class Profile(EmbeddedDocument):
name = StringField() name = StringField()
age = IntField() age = IntField()
@ -2983,32 +3013,31 @@ class QueryFieldListTest(unittest.TestCase):
locale=Locale(city="Brasilia", country="Brazil")).save() locale=Locale(city="Brasilia", country="Brazil")).save()
self.assertEqual( self.assertEqual(
list(Person.objects.order_by('profile.age').values_list('profile.name')), list(Person.objects.order_by('profile__age').scalar('profile__name')),
[[u'Wilson Jr'], [u'Gabriel Falcao'], [u'Wilson Jr', u'Gabriel Falcao', u'Lincoln de souza', u'Walter cruz'])
[u'Lincoln de souza'], [u'Walter cruz']])
ulist = list(Person.objects.order_by('locale.city') ulist = list(Person.objects.order_by('locale.city')
.values_list('profile.name', 'profile.age', 'locale.city')) .scalar('profile__name', 'profile__age', 'locale__city'))
self.assertEqual(ulist, self.assertEqual(ulist,
[[u'Lincoln de souza', 28, u'Belo Horizonte'], [(u'Lincoln de souza', 28, u'Belo Horizonte'),
[u'Walter cruz', 30, u'Brasilia'], (u'Walter cruz', 30, u'Brasilia'),
[u'Wilson Jr', 19, u'Corumba-GO'], (u'Wilson Jr', 19, u'Corumba-GO'),
[u'Gabriel Falcao', 23, u'New York']]) (u'Gabriel Falcao', 23, u'New York')])
def test_values_list_decimal(self): def test_scalar_decimal(self):
from decimal import Decimal from decimal import Decimal
class Person(Document): class Person(Document):
name = StringField() name = StringField()
rating = DecimalField() rating = DecimalField()
Person.drop_collection() Person.drop_collection()
Person(name="Wilson Jr", rating=Decimal('1.0')).save() Person(name="Wilson Jr", rating=Decimal('1.0')).save()
ulist = list(Person.objects.values_list('name', 'rating')) ulist = list(Person.objects.scalar('name', 'rating'))
self.assertEqual(ulist, [[u'Wilson Jr', Decimal('1.0')]]) self.assertEqual(ulist, [(u'Wilson Jr', Decimal('1.0'))])
def test_values_list_reference_field(self): def test_scalar_reference_field(self):
class State(Document): class State(Document):
name = StringField() name = StringField()
@ -3024,10 +3053,10 @@ class QueryFieldListTest(unittest.TestCase):
Person(name="Wilson JR", state=s1).save() Person(name="Wilson JR", state=s1).save()
plist = list(Person.objects.values_list('name', 'state')) plist = list(Person.objects.scalar('name', 'state'))
self.assertEqual(plist, [[u'Wilson JR', s1]]) self.assertEqual(plist, [(u'Wilson JR', s1)])
def test_values_list_generic_reference_field(self): def test_scalar_generic_reference_field(self):
class State(Document): class State(Document):
name = StringField() name = StringField()
@ -3043,13 +3072,14 @@ class QueryFieldListTest(unittest.TestCase):
Person(name="Wilson JR", state=s1).save() Person(name="Wilson JR", state=s1).save()
plist = list(Person.objects.values_list('name', 'state')) plist = list(Person.objects.scalar('name', 'state'))
self.assertEqual(plist, [[u'Wilson JR', s1]]) self.assertEqual(plist, [(u'Wilson JR', s1)])
def test_scalar_db_field(self):
def test_values_list_db_field(self):
class TestDoc(Document): class TestDoc(Document):
x = IntField(db_field="y") x = IntField()
y = BooleanField(db_field="x") y = BooleanField()
TestDoc.drop_collection() TestDoc.drop_collection()
@ -3057,12 +3087,12 @@ class QueryFieldListTest(unittest.TestCase):
TestDoc(x=20, y=False).save() TestDoc(x=20, y=False).save()
TestDoc(x=30, y=True).save() TestDoc(x=30, y=True).save()
plist = list(TestDoc.objects.values_list('x', 'y')) plist = list(TestDoc.objects.scalar('x', 'y'))
self.assertEqual(len(plist), 3) self.assertEqual(len(plist), 3)
self.assertEqual(plist[0], [10, True]) self.assertEqual(plist[0], (10, True))
self.assertEqual(plist[1], [20, False]) self.assertEqual(plist[1], (20, False))
self.assertEqual(plist[2], [30, True]) self.assertEqual(plist[2], (30, True))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()