diff --git a/mongoengine/document.py b/mongoengine/document.py index 3b6df4f8..6345e6da 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -571,39 +571,55 @@ class Document(BaseDocument): if cls._meta.get('abstract'): return [] - indexes = [] - index_cls = cls._meta.get('index_cls', True) + # get all the base classes, subclasses and sieblings + classes = [] + def get_classes(cls): - # Ensure document-defined indexes are created - if cls._meta['index_specs']: - index_spec = cls._meta['index_specs'] - for spec in index_spec: - spec = spec.copy() - fields = spec.pop('fields') - indexes.append(fields) + if (cls not in classes and + isinstance(cls, TopLevelDocumentMetaclass)): + classes.append(cls) - # add all of the indexes from the base classes - if go_up: for base_cls in cls.__bases__: - if isinstance(base_cls, TopLevelDocumentMetaclass): - for index in base_cls.list_indexes(go_up=True, go_down=False): - if index not in indexes: - indexes.append(index) - - # add all of the indexes from subclasses - if go_down: + if (isinstance(base_cls, TopLevelDocumentMetaclass) and + base_cls != Document and + not base_cls._meta.get('abstract') and + base_cls._get_collection().full_name == cls._get_collection().full_name and + base_cls not in classes): + classes.append(base_cls) + get_classes(base_cls) for subclass in cls.__subclasses__(): - for index in subclass.list_indexes(go_up=False, go_down=True): - if index not in indexes: - indexes.append(index) + if (isinstance(base_cls, TopLevelDocumentMetaclass) and + subclass._get_collection().full_name == cls._get_collection().full_name and + subclass not in classes): + classes.append(subclass) + get_classes(subclass) + + get_classes(cls) + + # get the indexes spec for all of the gathered classes + def get_indexes_spec(cls): + indexes = [] + + if cls._meta['index_specs']: + index_spec = cls._meta['index_specs'] + for spec in index_spec: + spec = spec.copy() + fields = spec.pop('fields') + indexes.append(fields) + return indexes + + indexes = [] + for cls in classes: + for index in get_indexes_spec(cls): + if index not in indexes: + indexes.append(index) # finish up by appending { '_id': 1 } and { '_cls': 1 }, if needed - if go_up and go_down: - if [(u'_id', 1)] not in indexes: - indexes.append([(u'_id', 1)]) - if (index_cls and - cls._meta.get('allow_inheritance', ALLOW_INHERITANCE) is True): - indexes.append([(u'_cls', 1)]) + if [(u'_id', 1)] not in indexes: + indexes.append([(u'_id', 1)]) + if (cls._meta.get('index_cls', True) and + cls._meta.get('allow_inheritance', ALLOW_INHERITANCE) is True): + indexes.append([(u'_cls', 1)]) return indexes diff --git a/tests/document/class_methods.py b/tests/document/class_methods.py index 6bd2e3c0..52e3794c 100644 --- a/tests/document/class_methods.py +++ b/tests/document/class_methods.py @@ -152,6 +152,43 @@ class ClassMethodsTest(unittest.TestCase): BlogPostWithTags._get_collection().drop_index('_cls_1_author_1_tags_1') self.assertEqual(BlogPost.compare_indexes(), { 'missing': [[('_cls', 1), ('author', 1), ('tags', 1)]], 'extra': [] }) + def test_compare_indexes_multiple_subclasses(self): + """ Ensure that compare_indexes behaves correctly if called from a + class, which base class has multiple subclasses + """ + + class BlogPost(Document): + author = StringField() + title = StringField() + description = StringField() + + meta = { + 'allow_inheritance': True + } + + class BlogPostWithTags(BlogPost): + tags = StringField() + tag_list = ListField(StringField()) + + meta = { + 'indexes': [('author', 'tags')] + } + + class BlogPostWithCustomField(BlogPost): + custom = DictField() + + meta = { + 'indexes': [('author', 'custom')] + } + + BlogPost.ensure_indexes() + BlogPostWithTags.ensure_indexes() + BlogPostWithCustomField.ensure_indexes() + + self.assertEqual(BlogPost.compare_indexes(), { 'missing': [], 'extra': [] }) + self.assertEqual(BlogPostWithTags.compare_indexes(), { 'missing': [], 'extra': [] }) + self.assertEqual(BlogPostWithCustomField.compare_indexes(), { 'missing': [], 'extra': [] }) + def test_list_indexes_inheritance(self): """ ensure that all of the indexes are listed regardless of the super- or sub-class that we call it from @@ -190,7 +227,6 @@ class ClassMethodsTest(unittest.TestCase): BlogPostWithTags.list_indexes()) self.assertEqual(BlogPost.list_indexes(), BlogPostWithTagsAndExtraText.list_indexes()) - print BlogPost.list_indexes() self.assertEqual(BlogPost.list_indexes(), [[('_cls', 1), ('author', 1), ('tags', 1)], [('_cls', 1), ('author', 1), ('tags', 1), ('extra_text', 1)], diff --git a/tests/document/inheritance.py b/tests/document/inheritance.py index f0116311..28490c95 100644 --- a/tests/document/inheritance.py +++ b/tests/document/inheritance.py @@ -189,6 +189,41 @@ class InheritanceTest(unittest.TestCase): self.assertEqual(Employee._get_collection_name(), Person._get_collection_name()) + def test_indexes_and_multiple_inheritance(self): + """ Ensure that all of the indexes are created for a document with + multiple inheritance. + """ + + class A(Document): + a = StringField() + + meta = { + 'allow_inheritance': True, + 'indexes': ['a'] + } + + class B(Document): + b = StringField() + + meta = { + 'allow_inheritance': True, + 'indexes': ['b'] + } + + class C(A, B): + pass + + A.drop_collection() + B.drop_collection() + C.drop_collection() + + C.ensure_indexes() + + self.assertEqual( + [idx['key'] for idx in C._get_collection().index_information().values()], + [[(u'_cls', 1), (u'b', 1)], [(u'_id', 1)], [(u'_cls', 1), (u'a', 1)]] + ) + def test_polymorphic_queries(self): """Ensure that the correct subclasses are returned from a query """