diff --git a/docs/changelog.rst b/docs/changelog.rst index 57f49867..2ada58e7 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -4,6 +4,7 @@ Changelog Changes in 0.6.15 ================= +- Fixed geo index creation through reference fields - Added support for args / kwargs when using @queryset_manager - Deref list custom id fix diff --git a/mongoengine/base.py b/mongoengine/base.py index 6eb0b0f9..0d216d97 100644 --- a/mongoengine/base.py +++ b/mongoengine/base.py @@ -1113,7 +1113,11 @@ Invalid data to create a `%s` instance.\n%s""".strip() % (cls._class_name, error inspected = inspected or [] geo_indices = [] inspected.append(cls) + + from fields import EmbeddedDocumentField, GeoPointField for field in cls._fields.values(): + if not isinstance(field, (EmbeddedDocumentField, GeoPointField)): + continue if hasattr(field, 'document_type'): field_cls = field.document_type if field_cls in inspected: diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index f94b7092..4f7443f7 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -483,7 +483,6 @@ class QuerySet(object): self._collection.ensure_index(index_spec, background=background, **index_opts) - @classmethod def _build_index_spec(cls, doc_cls, spec): """Build a PyMongo index spec from a MongoEngine index spec. diff --git a/tests/test_document.py b/tests/test_document.py index 491a6856..30d92447 100644 --- a/tests/test_document.py +++ b/tests/test_document.py @@ -872,15 +872,26 @@ class DocumentTest(unittest.TestCase): def test_geo_indexes_recursion(self): - class User(Document): - channel = ReferenceField('Channel') + class Location(Document): + name = StringField() location = GeoPointField() - class Channel(Document): - user = ReferenceField('User') - location = GeoPointField() + class Parent(Document): + name = StringField() + location = ReferenceField(Location) - self.assertEquals(len(User._geo_indices()), 2) + Location.drop_collection() + Parent.drop_collection() + + list(Parent.objects) + + collection = Parent._get_collection() + info = collection.index_information() + + self.assertFalse('location_2d' in info) + + self.assertEquals(len(Parent._geo_indices()), 0) + self.assertEquals(len(Location._geo_indices()), 1) def test_covered_index(self): """Ensure that covered indexes can be used @@ -3170,7 +3181,7 @@ name: Field is required ("name")""" class Person(BasePerson): name = StringField(required=True) - + p = Person(age=15) self.assertRaises(ValidationError, p.validate)