From 29c887f30b0f7db13d30c920d29d2b4f2f490047 Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Tue, 20 Aug 2013 12:21:20 +0000 Subject: [PATCH] Updated field filter logic - can now exclude subclass fields (#443) --- docs/changelog.rst | 1 + mongoengine/queryset/base.py | 31 ++++++++++++++++++++++++++----- tests/queryset/field_list.py | 23 +++++++++++++++++++++++ 3 files changed, 50 insertions(+), 5 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 74e2e503..489f2ff5 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -4,6 +4,7 @@ Changelog Changes in 0.8.4 ================ +- Fixed can now exclude subclass fields (#443) - Fixed dereference issue with embedded listfield referencefields (#439) - Fixed slice when using inheritance causing fields to be excluded (#437) - Fixed ._get_db() attribute after a Document.switch_db() (#441) diff --git a/mongoengine/queryset/base.py b/mongoengine/queryset/base.py index a6ba49b2..7af9daad 100644 --- a/mongoengine/queryset/base.py +++ b/mongoengine/queryset/base.py @@ -14,8 +14,9 @@ from pymongo.common import validate_read_preference from mongoengine import signals from mongoengine.common import _import_class +from mongoengine.base.common import get_document from mongoengine.errors import (OperationError, NotUniqueError, - InvalidQueryError) + InvalidQueryError, LookUpError) from mongoengine.queryset import transform from mongoengine.queryset.field_list import QueryFieldList @@ -1333,13 +1334,33 @@ class BaseQuerySet(object): return frequencies - def _fields_to_dbfields(self, fields): + def _fields_to_dbfields(self, fields, subdoc=False): """Translate fields paths to its db equivalents""" ret = [] + subclasses = [] + document = self._document + if document._meta['allow_inheritance']: + subclasses = [get_document(x) + for x in document._subclasses][1:] for field in fields: - field = ".".join(f.db_field for f in - self._document._lookup_field(field.split('.'))) - ret.append(field) + try: + field = ".".join(f.db_field for f in + document._lookup_field(field.split('.'))) + ret.append(field) + except LookUpError, e: + found = False + for subdoc in subclasses: + try: + subfield = ".".join(f.db_field for f in + subdoc._lookup_field(field.split('.'))) + ret.append(subfield) + found = True + break + except LookUpError, e: + pass + + if not found: + raise e return ret def _get_order_by(self, keys): diff --git a/tests/queryset/field_list.py b/tests/queryset/field_list.py index a18e167e..7d66d263 100644 --- a/tests/queryset/field_list.py +++ b/tests/queryset/field_list.py @@ -399,5 +399,28 @@ class OnlyExcludeAllTest(unittest.TestCase): numbers = Numbers.objects.fields(embedded__n={"$slice": [-5, 10]}).get() self.assertEqual(numbers.embedded.n, [-5, -4, -3, -2, -1]) + + def test_exclude_from_subclasses_docs(self): + + class Base(Document): + username = StringField() + + meta = {'allow_inheritance': True} + + class Anon(Base): + anon = BooleanField() + + class User(Base): + password = StringField() + wibble = StringField() + + Base.drop_collection() + User(username="mongodb", password="secret").save() + + user = Base.objects().exclude("password", "wibble").first() + self.assertEqual(user.password, None) + + self.assertRaises(LookUpError, Base.objects.exclude, "made_up") + if __name__ == '__main__': unittest.main()