From f60a49d6f6fc4ad213597667c957a69607fbcd44 Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Fri, 27 Jan 2012 11:41:42 +0000 Subject: [PATCH] Added .scalar to Queryset More efficient than the previous .values_list implementation Ref #393 Reverted some of the .values_list code thats no longer needed. Closes #415 --- AUTHORS | 2 +- docs/changelog.rst | 2 +- mongoengine/base.py | 6 +- mongoengine/queryset.py | 180 ++++++++++------------------------------ setup.py | 4 +- tests/queryset.py | 116 ++++++++++++++++---------- 6 files changed, 123 insertions(+), 187 deletions(-) diff --git a/AUTHORS b/AUTHORS index d7e10030..2b52ecb5 100644 --- a/AUTHORS +++ b/AUTHORS @@ -1,11 +1,11 @@ The PRIMARY AUTHORS are (and/or have been): +Ross Lawley Harry Marr Matt Dennewitz Deepak Thukral Florian Schlachter Steve Challis -Ross Lawley Wilson JĂșnior Dan Crosta https://github.com/dcrosta diff --git a/docs/changelog.rst b/docs/changelog.rst index 0215335f..75230a28 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -5,8 +5,8 @@ Changelog Changes in dev ============== +- Added scalar for efficiently returning partial data values (aliased to values_list) - Fixed limit skip bug -- Added values_list for returning a list of data - Improved Inheritance / Mixin - Added sharding support - Added pymongo 2.1 support diff --git a/mongoengine/base.py b/mongoengine/base.py index 163a198f..369c10f3 100644 --- a/mongoengine/base.py +++ b/mongoengine/base.py @@ -468,14 +468,14 @@ class ObjectIdField(BaseField): class DocumentMetaclass(type): """Metaclass for all documents. """ - - + + def __new__(cls, name, bases, attrs): def _get_mixin_fields(base): attrs = {} attrs.update(dict([(k, v) for k, v in base.__dict__.items() if issubclass(v.__class__, BaseField)])) - + for p_base in base.__bases__: #optimize :-) if p_base in (object, BaseDocument): diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index c36390b9..c4ac6375 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -314,82 +314,6 @@ class QueryFieldList(object): def __nonzero__(self): return bool(self.fields) -class ListResult(object): - """ - Used for .values_list method in QuerySet - """ - def __init__(self, document_type, cursor, fields, dbfields): - from base import BaseField - from fields import ReferenceField, GenericReferenceField - # Caches for optimization - - self.ReferenceField = ReferenceField - self.GenericReferenceField = GenericReferenceField - - self._cursor = cursor - - f = [] - for field, dbfield in itertools.izip(fields, dbfields): - - p = document_type - for path in field.split('.'): - if p and isinstance(p, BaseField): - p = p.lookup_member(path) - elif p: - p = getattr(p, path) - else: - break - - f.append((dbfield.split('.'), p)) - - self._fields = f - - def _get_value(self, keys, field_type, data): - for key in keys: - if data: - data = data.get(key) - else: - break - - if isinstance(field_type, self.ReferenceField): - doc_type = field_type.document_type - data = doc_type._get_db().dereference(data) - - if data: - return doc_type._from_son(data) - - elif isinstance(field_type, self.GenericReferenceField): - if data and isinstance(data, (dict, pymongo.dbref.DBRef)): - return field_type.dereference(data) - - if data is None: - return - - return field_type.to_python(data) - - def next(self): - try: - data = self._cursor.next() - return [self._get_value(k, t, data) - for k, t in self._fields] - except StopIteration, e: - self.rewind() - raise e - - def rewind(self): - self._cursor.rewind() - - def count(self): - """ - Count the selected elements in the query. - """ - return self._cursor.count(with_limit_and_skip=True) - - def __len__(self): - return self.count() - - def __iter__(self): - return self class QuerySet(object): """A set of results returned from a query. Wraps a MongoDB cursor, @@ -625,38 +549,33 @@ class QuerySet(object): cursor_args['fields'] = self._loaded_fields.as_dict() return cursor_args - def _build_cursor(self, **cursor_args): - obj = self._collection.find(self._query, - **cursor_args) - # Apply where clauses to cursor - if self._where_clause: - obj.where(self._where_clause) - - # apply default ordering - if self._ordering: - obj.sort(self._ordering) - elif self._document._meta['ordering']: - self._ordering = self._get_order_key_list( - *self._document._meta['ordering']) - obj.sort(self._ordering) - - if self._limit is not None: - obj.limit(self._limit - (self._skip or 0)) - - if self._skip is not None: - obj.skip(self._skip) - - if self._hint != -1: - obj.hint(self._hint) - - return obj - @property def _cursor(self): if self._cursor_obj is None: - self._cursor_obj = self._build_cursor(**self._cursor_args) + + self._cursor_obj = self._collection.find(self._query, + **self._cursor_args) + # Apply where clauses to cursor + if self._where_clause: + self._cursor_obj.where(self._where_clause) + + # apply default ordering + if self._ordering: + self._cursor_obj.sort(self._ordering) + elif self._document._meta['ordering']: + self.order_by(*self._document._meta['ordering']) + + if self._limit is not None: + self._cursor_obj.limit(self._limit - (self._skip or 0)) + + if self._skip is not None: + self._cursor_obj.skip(self._skip) + + if self._hint != -1: + self._cursor_obj.hint(self._hint) return self._cursor_obj + @classmethod def _lookup_field(cls, document, parts): """Lookup a field based on its attribute and return a list containing @@ -885,19 +804,6 @@ class QuerySet(object): doc.save() return doc - def values_list(self, *fields): - """ - make a list of elements - .. versionadded:: 0.6 - """ - dbfields = self._fields_to_dbfields(fields) - - cursor_args = self._cursor_args - cursor_args['fields'] = dbfields - cursor = self._build_cursor(**cursor_args) - - return ListResult(self._document, cursor, fields, dbfields) - def first(self): """Retrieve the first object matching the query. """ @@ -1269,9 +1175,13 @@ class QuerySet(object): ret.append(field) return ret - def _get_order_key_list(self, *keys): - """ - Build order list for query + def order_by(self, *keys): + """Order the :class:`~mongoengine.queryset.QuerySet` by the keys. The + order may be specified by prepending each of the keys by a + or a -. + Ascending order is assumed. + + :param keys: fields to order the query results by; keys may be + prefixed with **+** or **-** to determine the ordering direction """ key_list = [] for key in keys: @@ -1288,18 +1198,6 @@ class QuerySet(object): pass key_list.append((key, direction)) - return key_list - - def order_by(self, *keys): - """Order the :class:`~mongoengine.queryset.QuerySet` by the keys. The - order may be specified by prepending each of the keys by a + or a -. - Ascending order is assumed. - - :param keys: fields to order the query results by; keys may be - prefixed with **+** or **-** to determine the ordering direction - """ - - key_list = self._get_order_key_list(*keys) self._ordering = key_list self._cursor.sort(key_list) return self @@ -1503,37 +1401,43 @@ class QuerySet(object): return self def _get_scalar(self, doc): + def lookup(obj, name): chunks = name.split('__') for chunk in chunks: + if hasattr(obj, '_db_field_map'): + chunk = obj._db_field_map.get(chunk, chunk) obj = getattr(obj, chunk) return obj - + data = [lookup(doc, n) for n in self._scalar] - if len(data) == 1: return data[0] - + return tuple(data) def scalar(self, *fields): """Instead of returning Document instances, return either a specific value or a tuple of values in order. - + This effects all results and can be unset by calling ``scalar`` without arguments. Calls ``only`` automatically. - + :param fields: One or more fields to return instead of a Document. """ self._scalar = list(fields) - + if fields: self.only(*fields) else: self.all_fields() - + return self + def values_list(self, *fields): + """An alias for scalar""" + return self.scalar(*fields) + def _sub_js_fields(self, code): """When fields are specified with [~fieldname] syntax, where *fieldname* is the Python name of a field, *fieldname* will be diff --git a/setup.py b/setup.py index b0c29bf0..7a69d83a 100644 --- a/setup.py +++ b/setup.py @@ -38,7 +38,9 @@ setup(name='mongoengine', packages=find_packages(), author='Harry Marr', author_email='harry.marr@{nospam}gmail.com', - url='http://hmarr.com/mongoengine/', + maintainer="Ross Lawley", + maintainer_email="ross.lawley@gmail.com", + url='http://mongoengine.org/', license='MIT', include_package_data=True, description=DESCRIPTION, diff --git a/tests/queryset.py b/tests/queryset.py index b749340c..a84fd5b0 100644 --- a/tests/queryset.py +++ b/tests/queryset.py @@ -2909,8 +2909,38 @@ class QueryFieldListTest(unittest.TestCase): ak = list(Bar.objects(foo__match={'shape': "square", "color": "purple"})) self.assertEqual([b1], ak) - - def test_values_list(self): + + def test_scalar(self): + + class Organization(Document): + id = ObjectIdField('_id') + name = StringField() + + class User(Document): + id = ObjectIdField('_id') + name = StringField() + organization = ObjectIdField() + + User.drop_collection() + Organization.drop_collection() + + whitehouse = Organization(name="White House") + whitehouse.save() + User(name="Bob Dole", organization=whitehouse.id).save() + + # Efficient way to get all unique organization names for a given + # set of users (Pretend this has additional filtering.) + user_orgs = set(User.objects.scalar('organization')) + orgs = Organization.objects(id__in=user_orgs).scalar('name') + self.assertEqual(list(orgs), ['White House']) + + # Efficient for generating listings, too. + orgs = Organization.objects.scalar('name').in_bulk(list(user_orgs)) + user_map = User.objects.scalar('name', 'organization') + user_listing = [(user, orgs[org]) for user, org in user_map] + self.assertEqual([("Bob Dole", "White House")], user_listing) + + def test_scalar_simple(self): class TestDoc(Document): x = IntField() y = BooleanField() @@ -2921,12 +2951,12 @@ class QueryFieldListTest(unittest.TestCase): TestDoc(x=20, y=False).save() TestDoc(x=30, y=True).save() - plist = list(TestDoc.objects.values_list('x', 'y')) + plist = list(TestDoc.objects.scalar('x', 'y')) self.assertEqual(len(plist), 3) - self.assertEqual(plist[0], [10, True]) - self.assertEqual(plist[1], [20, False]) - self.assertEqual(plist[2], [30, True]) + self.assertEqual(plist[0], (10, True)) + self.assertEqual(plist[1], (20, False)) + self.assertEqual(plist[2], (30, True)) class UserDoc(Document): name = StringField() @@ -2939,23 +2969,23 @@ class QueryFieldListTest(unittest.TestCase): UserDoc(name="Eliana", age=37).save() UserDoc(name="Tayza", age=15).save() - ulist = list(UserDoc.objects.values_list('name', 'age')) + ulist = list(UserDoc.objects.scalar('name', 'age')) self.assertEqual(ulist, [ - [u'Wilson Jr', 19], - [u'Wilson', 43], - [u'Eliana', 37], - [u'Tayza', 15]]) + (u'Wilson Jr', 19), + (u'Wilson', 43), + (u'Eliana', 37), + (u'Tayza', 15)]) - ulist = list(UserDoc.objects.order_by('age').values_list('name')) + ulist = list(UserDoc.objects.scalar('name').order_by('age')) self.assertEqual(ulist, [ - [u'Tayza'], - [u'Wilson Jr'], - [u'Eliana'], - [u'Wilson']]) + (u'Tayza'), + (u'Wilson Jr'), + (u'Eliana'), + (u'Wilson')]) - def test_values_list_embedded(self): + def test_scalar_embedded(self): class Profile(EmbeddedDocument): name = StringField() age = IntField() @@ -2983,32 +3013,31 @@ class QueryFieldListTest(unittest.TestCase): locale=Locale(city="Brasilia", country="Brazil")).save() self.assertEqual( - list(Person.objects.order_by('profile.age').values_list('profile.name')), - [[u'Wilson Jr'], [u'Gabriel Falcao'], - [u'Lincoln de souza'], [u'Walter cruz']]) + list(Person.objects.order_by('profile__age').scalar('profile__name')), + [u'Wilson Jr', u'Gabriel Falcao', u'Lincoln de souza', u'Walter cruz']) ulist = list(Person.objects.order_by('locale.city') - .values_list('profile.name', 'profile.age', 'locale.city')) + .scalar('profile__name', 'profile__age', 'locale__city')) self.assertEqual(ulist, - [[u'Lincoln de souza', 28, u'Belo Horizonte'], - [u'Walter cruz', 30, u'Brasilia'], - [u'Wilson Jr', 19, u'Corumba-GO'], - [u'Gabriel Falcao', 23, u'New York']]) + [(u'Lincoln de souza', 28, u'Belo Horizonte'), + (u'Walter cruz', 30, u'Brasilia'), + (u'Wilson Jr', 19, u'Corumba-GO'), + (u'Gabriel Falcao', 23, u'New York')]) - def test_values_list_decimal(self): + def test_scalar_decimal(self): from decimal import Decimal class Person(Document): name = StringField() rating = DecimalField() - + Person.drop_collection() Person(name="Wilson Jr", rating=Decimal('1.0')).save() - ulist = list(Person.objects.values_list('name', 'rating')) - self.assertEqual(ulist, [[u'Wilson Jr', Decimal('1.0')]]) + ulist = list(Person.objects.scalar('name', 'rating')) + self.assertEqual(ulist, [(u'Wilson Jr', Decimal('1.0'))]) - def test_values_list_reference_field(self): + def test_scalar_reference_field(self): class State(Document): name = StringField() @@ -3024,10 +3053,10 @@ class QueryFieldListTest(unittest.TestCase): Person(name="Wilson JR", state=s1).save() - plist = list(Person.objects.values_list('name', 'state')) - self.assertEqual(plist, [[u'Wilson JR', s1]]) + plist = list(Person.objects.scalar('name', 'state')) + self.assertEqual(plist, [(u'Wilson JR', s1)]) - def test_values_list_generic_reference_field(self): + def test_scalar_generic_reference_field(self): class State(Document): name = StringField() @@ -3043,13 +3072,14 @@ class QueryFieldListTest(unittest.TestCase): Person(name="Wilson JR", state=s1).save() - plist = list(Person.objects.values_list('name', 'state')) - self.assertEqual(plist, [[u'Wilson JR', s1]]) + plist = list(Person.objects.scalar('name', 'state')) + self.assertEqual(plist, [(u'Wilson JR', s1)]) + + def test_scalar_db_field(self): - def test_values_list_db_field(self): class TestDoc(Document): - x = IntField(db_field="y") - y = BooleanField(db_field="x") + x = IntField() + y = BooleanField() TestDoc.drop_collection() @@ -3057,12 +3087,12 @@ class QueryFieldListTest(unittest.TestCase): TestDoc(x=20, y=False).save() TestDoc(x=30, y=True).save() - plist = list(TestDoc.objects.values_list('x', 'y')) - + plist = list(TestDoc.objects.scalar('x', 'y')) self.assertEqual(len(plist), 3) - self.assertEqual(plist[0], [10, True]) - self.assertEqual(plist[1], [20, False]) - self.assertEqual(plist[2], [30, True]) + self.assertEqual(plist[0], (10, True)) + self.assertEqual(plist[1], (20, False)) + self.assertEqual(plist[2], (30, True)) + if __name__ == '__main__': unittest.main()