diff --git a/mongoengine/queryset/queryset.py b/mongoengine/queryset/queryset.py index 5077f895..dc5fab4c 100644 --- a/mongoengine/queryset/queryset.py +++ b/mongoengine/queryset/queryset.py @@ -71,7 +71,10 @@ class QuerySet(object): # If inheritance is allowed, only return instances and instances of # subclasses of the class being used if document._meta.get('allow_inheritance') is True: - self._initial_query = {"_cls": {"$in": self._document._subclasses}} + if len(self._document._subclasses) == 1: + self._initial_query = {"_cls": self._document._subclasses[0]} + else: + self._initial_query = {"_cls": {"$in": self._document._subclasses}} self._loaded_fields = QueryFieldList(always_include=['_cls']) self._cursor_obj = None self._limit = None diff --git a/tests/queryset/queryset.py b/tests/queryset/queryset.py index 507408d7..07ddf2d9 100644 --- a/tests/queryset/queryset.py +++ b/tests/queryset/queryset.py @@ -3392,6 +3392,34 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(B.objects.get(a=a).a, a) self.assertEqual(B.objects.get(a=a.id).a, a) + def test_cls_query_in_subclassed_docs(self): + + class Animal(Document): + name = StringField() + + meta = { + 'allow_inheritance': True + } + + class Dog(Animal): + pass + + class Cat(Animal): + pass + + self.assertEqual(Animal.objects(name='Charlie')._query, { + 'name': 'Charlie', + '_cls': { '$in': ('Animal', 'Animal.Dog', 'Animal.Cat') } + }) + self.assertEqual(Dog.objects(name='Charlie')._query, { + 'name': 'Charlie', + '_cls': 'Animal.Dog' + }) + self.assertEqual(Cat.objects(name='Charlie')._query, { + 'name': 'Charlie', + '_cls': 'Animal.Cat' + }) + if __name__ == '__main__': unittest.main()