Fixed inheritance and unique index creation (#140)

This commit is contained in:
Ross Lawley 2013-01-22 15:16:58 +00:00
parent 344dc64df8
commit 692f00864d
5 changed files with 65 additions and 46 deletions

View File

@ -29,6 +29,7 @@ Changes in 0.8.X
- Querysets now return clones and are no longer edit in place (#56)
- Added support for $maxDistance (#179)
- Uses getlasterror to test created on updated saves (#163)
- Fixed inheritance and unique index creation (#140)
Changes in 0.7.9
================

View File

@ -509,6 +509,34 @@ class BaseDocument(object):
obj._created = False
return obj
@classmethod
def _build_index_specs(cls, meta_indexes):
"""Generate and merge the full index specs
"""
geo_indices = cls._geo_indices()
unique_indices = cls._unique_with_indexes()
index_specs = [cls._build_index_spec(spec)
for spec in meta_indexes]
def merge_index_specs(index_specs, indices):
if not indices:
return index_specs
spec_fields = [v['fields']
for k, v in enumerate(index_specs)]
# Merge unqiue_indexes with existing specs
for k, v in enumerate(indices):
if v['fields'] in spec_fields:
index_specs[spec_fields.index(v['fields'])].update(v)
else:
index_specs.append(v)
return index_specs
index_specs = merge_index_specs(index_specs, geo_indices)
index_specs = merge_index_specs(index_specs, unique_indices)
return index_specs
@classmethod
def _build_index_spec(cls, spec):
"""Build a PyMongo index spec from a MongoEngine index spec.
@ -576,6 +604,7 @@ class BaseDocument(object):
"""
unique_indexes = []
for field_name, field in cls._fields.items():
sparse = False
# Generate a list of indexes needed by uniqueness constraints
if field.unique:
field.required = True
@ -596,11 +625,14 @@ class BaseDocument(object):
unique_with.append('.'.join(name_parts))
# Unique field should be required
parts[-1].required = True
sparse = (not sparse and
parts[-1].name not in cls.__dict__)
unique_fields += unique_with
# Add the new index to the list
index = [("%s%s" % (namespace, f), pymongo.ASCENDING)
fields = [("%s%s" % (namespace, f), pymongo.ASCENDING)
for f in unique_fields]
index = {'fields': fields, 'unique': True, 'sparse': sparse}
unique_indexes.append(index)
# Grab any embedded document field unique indexes
@ -612,6 +644,29 @@ class BaseDocument(object):
return unique_indexes
@classmethod
def _geo_indices(cls, inspected=None):
inspected = inspected or []
geo_indices = []
inspected.append(cls)
EmbeddedDocumentField = _import_class("EmbeddedDocumentField")
GeoPointField = _import_class("GeoPointField")
for field in cls._fields.values():
if not isinstance(field, (EmbeddedDocumentField, GeoPointField)):
continue
if hasattr(field, 'document_type'):
field_cls = field.document_type
if field_cls in inspected:
continue
if hasattr(field_cls, '_geo_indices'):
geo_indices += field_cls._geo_indices(inspected)
elif field._geo_index:
geo_indices.append({'fields':
[(field.db_field, pymongo.GEO2D)]})
return geo_indices
@classmethod
def _lookup_field(cls, parts):
"""Lookup a field based on its attribute and return a list containing
@ -671,28 +726,6 @@ class BaseDocument(object):
parts = [f.db_field for f in cls._lookup_field(parts)]
return '.'.join(parts)
@classmethod
def _geo_indices(cls, inspected=None):
inspected = inspected or []
geo_indices = []
inspected.append(cls)
EmbeddedDocumentField = _import_class("EmbeddedDocumentField")
GeoPointField = _import_class("GeoPointField")
for field in cls._fields.values():
if not isinstance(field, (EmbeddedDocumentField, GeoPointField)):
continue
if hasattr(field, 'document_type'):
field_cls = field.document_type
if field_cls in inspected:
continue
if hasattr(field_cls, '_geo_indices'):
geo_indices += field_cls._geo_indices(inspected)
elif field._geo_index:
geo_indices.append(field)
return geo_indices
def __set_field_display(self):
"""Dynamically set the display value for a field with choices"""
for attr_name, field in self._fields.items():

View File

@ -329,10 +329,7 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
meta = new_class._meta
# Set index specifications
meta['index_specs'] = [new_class._build_index_spec(spec)
for spec in meta['indexes']]
unique_indexes = new_class._unique_with_indexes()
new_class._meta['unique_indexes'] = unique_indexes
meta['index_specs'] = new_class._build_index_specs(meta['indexes'])
# If collection is a callable - call it and set the value
collection = meta.get('collection')

View File

@ -105,7 +105,7 @@ class Document(BaseDocument):
By default, _cls will be added to the start of every index (that
doesn't contain a list) if allow_inheritance is True. This can be
disabled by either setting types to False on the specific index or
disabled by either setting cls to False on the specific index or
by setting index_cls to False on the meta dictionary for the document.
"""
@ -481,12 +481,6 @@ class Document(BaseDocument):
first_field = fields[0][0]
return first_field == '_cls'
# Ensure indexes created by uniqueness constraints
for index in cls._meta['unique_indexes']:
cls_indexed = cls_indexed or includes_cls(index)
collection.ensure_index(index, unique=True, background=background,
drop_dups=drop_dups, **index_opts)
# Ensure document-defined indexes are created
if cls._meta['index_specs']:
index_spec = cls._meta['index_specs']
@ -496,7 +490,8 @@ class Document(BaseDocument):
cls_indexed = cls_indexed or includes_cls(fields)
opts = index_opts.copy()
opts.update(spec)
collection.ensure_index(fields, background=background, **opts)
collection.ensure_index(fields, background=background,
drop_dups=drop_dups, **opts)
# If _cls is being used (for polymorphism), it needs an index,
# only if another index doesn't begin with _cls
@ -505,12 +500,6 @@ class Document(BaseDocument):
collection.ensure_index('_cls', background=background,
**index_opts)
# Add geo indicies
for field in cls._geo_indices():
index_spec = [(field.db_field, pymongo.GEO2D)]
collection.ensure_index(index_spec, background=background,
**index_opts)
class DynamicDocument(Document):
"""A Dynamic Document class allowing flexible, expandable and uncontrolled

View File

@ -259,13 +259,12 @@ class IndexesTest(unittest.TestCase):
tags = ListField(StringField())
meta = {
'indexes': [
{'fields': ['-date'], 'unique': True,
'sparse': True, 'types': False},
{'fields': ['-date'], 'unique': True, 'sparse': True},
],
}
self.assertEqual([{'fields': [('addDate', -1)], 'unique': True,
'sparse': True, 'types': False}],
'sparse': True}],
BlogPost._meta['index_specs'])
BlogPost.drop_collection()
@ -674,7 +673,7 @@ class IndexesTest(unittest.TestCase):
User.drop_collection()
def test_types_index_with_pk(self):
def test_index_with_pk(self):
"""Ensure you can use `pk` as part of a query"""
class Comment(EmbeddedDocument):
@ -687,7 +686,7 @@ class IndexesTest(unittest.TestCase):
{'fields': ['pk', 'comments.comment_id'],
'unique': True}]}
except UnboundLocalError:
self.fail('Unbound local error at types index + pk definition')
self.fail('Unbound local error at index + pk definition')
info = BlogPost.objects._collection.index_information()
info = [value['key'] for key, value in info.iteritems()]