From 692f00864d982a8c54f08bcda3712b43d9708751 Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Tue, 22 Jan 2013 15:16:58 +0000 Subject: [PATCH] Fixed inheritance and unique index creation (#140) --- docs/changelog.rst | 1 + mongoengine/base/document.py | 79 +++++++++++++++++++++++---------- mongoengine/base/metaclasses.py | 5 +-- mongoengine/document.py | 17 ++----- tests/document/indexes.py | 9 ++-- 5 files changed, 65 insertions(+), 46 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 1905a9d3..cb0ac6c6 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -29,6 +29,7 @@ Changes in 0.8.X - Querysets now return clones and are no longer edit in place (#56) - Added support for $maxDistance (#179) - Uses getlasterror to test created on updated saves (#163) +- Fixed inheritance and unique index creation (#140) Changes in 0.7.9 ================ diff --git a/mongoengine/base/document.py b/mongoengine/base/document.py index 93bde8ec..9f400618 100644 --- a/mongoengine/base/document.py +++ b/mongoengine/base/document.py @@ -509,6 +509,34 @@ class BaseDocument(object): obj._created = False return obj + @classmethod + def _build_index_specs(cls, meta_indexes): + """Generate and merge the full index specs + """ + + geo_indices = cls._geo_indices() + unique_indices = cls._unique_with_indexes() + index_specs = [cls._build_index_spec(spec) + for spec in meta_indexes] + + def merge_index_specs(index_specs, indices): + if not indices: + return index_specs + + spec_fields = [v['fields'] + for k, v in enumerate(index_specs)] + # Merge unqiue_indexes with existing specs + for k, v in enumerate(indices): + if v['fields'] in spec_fields: + index_specs[spec_fields.index(v['fields'])].update(v) + else: + index_specs.append(v) + return index_specs + + index_specs = merge_index_specs(index_specs, geo_indices) + index_specs = merge_index_specs(index_specs, unique_indices) + return index_specs + @classmethod def _build_index_spec(cls, spec): """Build a PyMongo index spec from a MongoEngine index spec. @@ -576,6 +604,7 @@ class BaseDocument(object): """ unique_indexes = [] for field_name, field in cls._fields.items(): + sparse = False # Generate a list of indexes needed by uniqueness constraints if field.unique: field.required = True @@ -596,11 +625,14 @@ class BaseDocument(object): unique_with.append('.'.join(name_parts)) # Unique field should be required parts[-1].required = True + sparse = (not sparse and + parts[-1].name not in cls.__dict__) unique_fields += unique_with # Add the new index to the list - index = [("%s%s" % (namespace, f), pymongo.ASCENDING) + fields = [("%s%s" % (namespace, f), pymongo.ASCENDING) for f in unique_fields] + index = {'fields': fields, 'unique': True, 'sparse': sparse} unique_indexes.append(index) # Grab any embedded document field unique indexes @@ -612,6 +644,29 @@ class BaseDocument(object): return unique_indexes + @classmethod + def _geo_indices(cls, inspected=None): + inspected = inspected or [] + geo_indices = [] + inspected.append(cls) + + EmbeddedDocumentField = _import_class("EmbeddedDocumentField") + GeoPointField = _import_class("GeoPointField") + + for field in cls._fields.values(): + if not isinstance(field, (EmbeddedDocumentField, GeoPointField)): + continue + if hasattr(field, 'document_type'): + field_cls = field.document_type + if field_cls in inspected: + continue + if hasattr(field_cls, '_geo_indices'): + geo_indices += field_cls._geo_indices(inspected) + elif field._geo_index: + geo_indices.append({'fields': + [(field.db_field, pymongo.GEO2D)]}) + return geo_indices + @classmethod def _lookup_field(cls, parts): """Lookup a field based on its attribute and return a list containing @@ -671,28 +726,6 @@ class BaseDocument(object): parts = [f.db_field for f in cls._lookup_field(parts)] return '.'.join(parts) - @classmethod - def _geo_indices(cls, inspected=None): - inspected = inspected or [] - geo_indices = [] - inspected.append(cls) - - EmbeddedDocumentField = _import_class("EmbeddedDocumentField") - GeoPointField = _import_class("GeoPointField") - - for field in cls._fields.values(): - if not isinstance(field, (EmbeddedDocumentField, GeoPointField)): - continue - if hasattr(field, 'document_type'): - field_cls = field.document_type - if field_cls in inspected: - continue - if hasattr(field_cls, '_geo_indices'): - geo_indices += field_cls._geo_indices(inspected) - elif field._geo_index: - geo_indices.append(field) - return geo_indices - def __set_field_display(self): """Dynamically set the display value for a field with choices""" for attr_name, field in self._fields.items(): diff --git a/mongoengine/base/metaclasses.py b/mongoengine/base/metaclasses.py index af39e144..2b63bfa8 100644 --- a/mongoengine/base/metaclasses.py +++ b/mongoengine/base/metaclasses.py @@ -329,10 +329,7 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): meta = new_class._meta # Set index specifications - meta['index_specs'] = [new_class._build_index_spec(spec) - for spec in meta['indexes']] - unique_indexes = new_class._unique_with_indexes() - new_class._meta['unique_indexes'] = unique_indexes + meta['index_specs'] = new_class._build_index_specs(meta['indexes']) # If collection is a callable - call it and set the value collection = meta.get('collection') diff --git a/mongoengine/document.py b/mongoengine/document.py index 69d4d406..fff7efad 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -105,7 +105,7 @@ class Document(BaseDocument): By default, _cls will be added to the start of every index (that doesn't contain a list) if allow_inheritance is True. This can be - disabled by either setting types to False on the specific index or + disabled by either setting cls to False on the specific index or by setting index_cls to False on the meta dictionary for the document. """ @@ -481,12 +481,6 @@ class Document(BaseDocument): first_field = fields[0][0] return first_field == '_cls' - # Ensure indexes created by uniqueness constraints - for index in cls._meta['unique_indexes']: - cls_indexed = cls_indexed or includes_cls(index) - collection.ensure_index(index, unique=True, background=background, - drop_dups=drop_dups, **index_opts) - # Ensure document-defined indexes are created if cls._meta['index_specs']: index_spec = cls._meta['index_specs'] @@ -496,7 +490,8 @@ class Document(BaseDocument): cls_indexed = cls_indexed or includes_cls(fields) opts = index_opts.copy() opts.update(spec) - collection.ensure_index(fields, background=background, **opts) + collection.ensure_index(fields, background=background, + drop_dups=drop_dups, **opts) # If _cls is being used (for polymorphism), it needs an index, # only if another index doesn't begin with _cls @@ -505,12 +500,6 @@ class Document(BaseDocument): collection.ensure_index('_cls', background=background, **index_opts) - # Add geo indicies - for field in cls._geo_indices(): - index_spec = [(field.db_field, pymongo.GEO2D)] - collection.ensure_index(index_spec, background=background, - **index_opts) - class DynamicDocument(Document): """A Dynamic Document class allowing flexible, expandable and uncontrolled diff --git a/tests/document/indexes.py b/tests/document/indexes.py index cf25f61e..fb278aa7 100644 --- a/tests/document/indexes.py +++ b/tests/document/indexes.py @@ -259,13 +259,12 @@ class IndexesTest(unittest.TestCase): tags = ListField(StringField()) meta = { 'indexes': [ - {'fields': ['-date'], 'unique': True, - 'sparse': True, 'types': False}, + {'fields': ['-date'], 'unique': True, 'sparse': True}, ], } self.assertEqual([{'fields': [('addDate', -1)], 'unique': True, - 'sparse': True, 'types': False}], + 'sparse': True}], BlogPost._meta['index_specs']) BlogPost.drop_collection() @@ -674,7 +673,7 @@ class IndexesTest(unittest.TestCase): User.drop_collection() - def test_types_index_with_pk(self): + def test_index_with_pk(self): """Ensure you can use `pk` as part of a query""" class Comment(EmbeddedDocument): @@ -687,7 +686,7 @@ class IndexesTest(unittest.TestCase): {'fields': ['pk', 'comments.comment_id'], 'unique': True}]} except UnboundLocalError: - self.fail('Unbound local error at types index + pk definition') + self.fail('Unbound local error at index + pk definition') info = BlogPost.objects._collection.index_information() info = [value['key'] for key, value in info.iteritems()]