diff --git a/docs/changelog.rst b/docs/changelog.rst
index 64047c38..3bca6a17 100644
--- a/docs/changelog.rst
+++ b/docs/changelog.rst
@@ -4,6 +4,9 @@ Changelog
 Changes in 0.7.X
 =================
 
+- Embedded Documents don't care about inheritance
+
+
 - Use weakref proxies in base lists / dicts (MongoEngine/mongoengine#74)
 - Improved queryset filtering (hmarr/mongoengine#554)
 - Fixed Dynamic Documents and Embedded Documents (hmarr/mongoengine#561)
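
A note on the first new changelog entry: with this patch, EmbeddedDocument subclasses are built by DocumentMetaclass without the allow_inheritance subclassing ban (the old test asserting a ValueError is deleted from tests/test_document.py further down). A minimal sketch of the behaviour this enables, using hypothetical models:

    from mongoengine import EmbeddedDocument, StringField

    class Comment(EmbeddedDocument):
        content = StringField()
        meta = {'allow_inheritance': False}

    # Before this patch, defining the subclass below raised ValueError at
    # class-definition time; embedded documents may now be subclassed
    # regardless of the meta flag.
    class SpecialComment(Comment):
        pass
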
diff --git a/mongoengine/base.py b/mongoengine/base.py
index 6c86506d..a05403db 100644
--- a/mongoengine/base.py
+++ b/mongoengine/base.py
@@ -17,6 +17,12 @@ import pymongo
 from bson import ObjectId
 from bson.dbref import DBRef
 
+ALLOW_INHERITANCE = True
+
+_document_registry = {}
+_class_registry = {}
+
+
 class NotRegistered(Exception):
     pass
 
@@ -111,16 +117,13 @@ class ValidationError(AssertionError):
         return ' '.join(["%s: %s" % (k, v) for k, v in error_dict.iteritems()])
 
 
-_document_registry = {}
-_module_registry = {}
-
-
 def get_document(name):
     doc = _document_registry.get(name, None)
     if not doc:
         # Possible old style names
         end = ".%s" % name
-        possible_match = [k for k in _document_registry.keys() if k.endswith(end)]
+        possible_match = [k for k in _document_registry.keys()
+                          if k.endswith(end)]
         if len(possible_match) == 1:
             doc = _document_registry.get(possible_match.pop(), None)
     if not doc:
@@ -153,7 +156,8 @@ class BaseField(object):
 
     def __init__(self, db_field=None, name=None, required=False, default=None,
                  unique=False, unique_with=None, primary_key=False,
-                 validation=None, choices=None, verbose_name=None, help_text=None):
+                 validation=None, choices=None, verbose_name=None,
+                 help_text=None):
         self.db_field = (db_field or name) if not primary_key else '_id'
         if name:
             msg = "Fields' 'name' attribute deprecated in favour of 'db_field'"
@@ -237,11 +241,15 @@ class BaseField(object):
             value_to_check = value.__class__ if is_cls else value
             err_msg = 'an instance' if is_cls else 'one'
             if isinstance(self.choices[0], (list, tuple)):
-                option_keys = [option_key for option_key, option_value in self.choices]
+                option_keys = [k for k, v in self.choices]
                 if value_to_check not in option_keys:
-                    self.error('Value must be %s of %s' % (err_msg, unicode(option_keys)))
+                    msg = ('Value must be %s of %s' %
+                           (err_msg, unicode(option_keys)))
+                    self.error(msg)
             elif value_to_check not in self.choices:
-                self.error('Value must be %s of %s' % (err_msg, unicode(self.choices)))
+                msg = ('Value must be %s of %s' %
+                       (err_msg, unicode(self.choices)))
+                self.error(msg)
 
         # check validation argument
         if self.validation is not None:
@@ -286,7 +294,8 @@ class ComplexBaseField(BaseField):
         value = super(ComplexBaseField, self).__get__(instance, owner)
 
         # Convert lists / values so we can watch for any changes on them
-        if isinstance(value, (list, tuple)) and not isinstance(value, BaseList):
+        if (isinstance(value, (list, tuple)) and
+           not isinstance(value, BaseList)):
             value = BaseList(value, instance, self.name)
             instance._data[self.name] = value
         elif isinstance(value, dict) and not isinstance(value, BaseDict):
@@ -328,7 +337,8 @@ class ComplexBaseField(BaseField):
             return value
 
         if self.field:
-            value_dict = dict([(key, self.field.to_python(item)) for key, item in value.items()])
+            value_dict = dict([(key, self.field.to_python(item))
+                               for key, item in value.items()])
         else:
             value_dict = {}
             for k, v in value.items():
@@ -345,7 +355,8 @@ class ComplexBaseField(BaseField):
                     value_dict[k] = self.to_python(v)
 
         if is_list:  # Convert back to a list
-            return [v for k, v in sorted(value_dict.items(), key=operator.itemgetter(0))]
+            return [v for k, v in sorted(value_dict.items(),
+                                         key=operator.itemgetter(0))]
         return value_dict
 
     def to_mongo(self, value):
@@ -368,7 +379,8 @@ class ComplexBaseField(BaseField):
             return value
 
         if self.field:
-            value_dict = dict([(key, self.field.to_mongo(item)) for key, item in value.items()])
+            value_dict = dict([(key, self.field.to_mongo(item))
+                               for key, item in value.items()])
         else:
             value_dict = {}
             for k, v in value.items():
@@ -381,8 +393,11 @@ class ComplexBaseField(BaseField):
                     # If its a document that is not inheritable it won't have
                     # _types / _cls data so make it a generic reference allows
                     # us to dereference
-                    meta = getattr(v, 'meta', getattr(v, '_meta', {}))
-                    if meta and not meta.get('allow_inheritance', True) and not self.field:
+                    meta = getattr(v, '_meta', {})
+                    allow_inheritance = (
+                        meta.get('allow_inheritance', ALLOW_INHERITANCE)
+                        != False)
+                    if not allow_inheritance and not self.field:
                         from fields import GenericReferenceField
                         value_dict[k] = GenericReferenceField().to_mongo(v)
                     else:
@@ -394,7 +409,8 @@ class ComplexBaseField(BaseField):
                     value_dict[k] = self.to_mongo(v)
 
         if is_list:  # Convert back to a list
-            return [v for k, v in sorted(value_dict.items(), key=operator.itemgetter(0))]
+            return [v for k, v in sorted(value_dict.items(),
+                                         key=operator.itemgetter(0))]
         return value_dict
 
     def validate(self, value):
@@ -479,77 +495,25 @@ class DocumentMetaclass(type):
     """
 
     def __new__(cls, name, bases, attrs):
-
-        def _get_mixin_fields(base):
-            attrs = {}
-            attrs.update(dict([(k, v) for k, v in base.__dict__.items()
-                               if issubclass(v.__class__, BaseField)]))
-
-            # Handle simple mixin's with meta
-            if hasattr(base, 'meta') and not isinstance(base, DocumentMetaclass):
-                meta = attrs.get('meta', {})
-                meta.update(base.meta)
-                attrs['meta'] = meta
-
-            for p_base in base.__bases__:
-                #optimize :-)
-                if p_base in (object, BaseDocument):
-                    continue
-
-                attrs.update(_get_mixin_fields(p_base))
-            return attrs
-
-        metaclass = attrs.get('my_metaclass')
+        bases = cls._get_bases(bases)
         super_new = super(DocumentMetaclass, cls).__new__
+
+        # If a base class just call super
+        metaclass = attrs.get('my_metaclass')
         if metaclass and issubclass(metaclass, DocumentMetaclass):
             return super_new(cls, name, bases, attrs)
 
+        attrs['_is_document'] = attrs.get('_is_document', False)
+
+        # Handle document Fields
+
+        # Merge all fields from subclasses
         doc_fields = {}
-        class_name = [name]
-        superclasses = {}
-        simple_class = True
-        for base in bases:
-            # Include all fields present in superclasses
+        for base in bases[::-1]:
             if hasattr(base, '_fields'):
                 doc_fields.update(base._fields)
-
-                # Get superclasses from superclass
-                superclasses[base._class_name] = base
-                superclasses.update(base._superclasses)
-            else:  # Add any mixin fields
-                attrs.update(_get_mixin_fields(base))
-            if hasattr(base, '_meta') and not base._meta.get('abstract'):
-                # Ensure that the Document class may be subclassed -
-                # inheritance may be disabled to remove dependency on
-                # additional fields _cls and _types
-                class_name.append(base._class_name)
-                if not base._meta.get('allow_inheritance_defined', True):
-                    warnings.warn(
-                        "%s uses inheritance, the default for allow_inheritance "
-                        "is changing to off by default. Please add it to the "
-                        "document meta." % name,
-                        FutureWarning
-                    )
-                if base._meta.get('allow_inheritance', True) == False:
-                    raise ValueError('Document %s may not be subclassed' %
-                                     base.__name__)
-                else:
-                    simple_class = False
-
-        doc_class_name = '.'.join(reversed(class_name))
-        meta = attrs.get('_meta', {})
-        meta.update(attrs.get('meta', {}))
-
-        if 'allow_inheritance' not in meta:
-            meta['allow_inheritance'] = True
-
-        # Only simple classes - direct subclasses of Document - may set
-        # allow_inheritance to False
-        if not simple_class and not meta['allow_inheritance'] and not meta['abstract']:
-            raise ValueError('Only direct subclasses of Document may set '
-                             '"allow_inheritance" to False')
-
-        # Add the document's fields to the _fields attribute
+        # Discover any document fields
         field_names = {}
         for attr_name, attr_value in attrs.iteritems():
             if not isinstance(attr_value, BaseField):
@@ -559,69 +523,93 @@ class DocumentMetaclass(type):
                 attr_value.db_field = attr_name
             doc_fields[attr_name] = attr_value
 
-            field_names[attr_value.db_field] = field_names.get(attr_value.db_field, 0) + 1
+            # Count names to ensure no db_field redefinitions
+            field_names[attr_value.db_field] = field_names.get(
+                attr_value.db_field, 0) + 1
 
+        # Ensure no duplicate db_fields
         duplicate_db_fields = [k for k, v in field_names.items() if v > 1]
         if duplicate_db_fields:
-            raise InvalidDocumentError("Multiple db_fields defined for: %s " % ", ".join(duplicate_db_fields))
+            msg = ("Multiple db_fields defined for: %s " %
+                   ", ".join(duplicate_db_fields))
+            raise InvalidDocumentError(msg)
+
+        # Set _fields and db_field maps
         attrs['_fields'] = doc_fields
-        attrs['_db_field_map'] = dict([(k, v.db_field) for k, v in doc_fields.items() if k != v.db_field])
-        attrs['_reverse_db_field_map'] = dict([(v, k) for k, v in attrs['_db_field_map'].items()])
-        attrs['_meta'] = meta
-        attrs['_class_name'] = doc_class_name
+        attrs['_db_field_map'] = dict(
+            ((k, v.db_field) for k, v in doc_fields.items()
+             if k != v.db_field))
+        attrs['_reverse_db_field_map'] = dict(
+            (v, k) for k, v in attrs['_db_field_map'].iteritems())
+
+        #
+        # Set document hierarchy
+        #
+        superclasses = {}
+        class_name = [name]
+        for base in bases:
+            if (not getattr(base, '_is_base_cls', True) and
+                not getattr(base, '_meta', {}).get('abstract', True)):
+                # Collate hierarchy for _cls and _types
+                class_name.append(base.__name__)
+
+                # Get superclasses from superclass
+                superclasses[base._class_name] = base
+                superclasses.update(base._superclasses)
+        attrs['_class_name'] = '.'.join(reversed(class_name))
         attrs['_superclasses'] = superclasses
 
-        if 'Document' not in _module_registry:
-            from mongoengine.document import Document, EmbeddedDocument
-            from mongoengine.fields import DictField
-            _module_registry['Document'] = Document
-            _module_registry['EmbeddedDocument'] = EmbeddedDocument
-            _module_registry['DictField'] = DictField
-        else:
-            Document = _module_registry.get('Document')
-            EmbeddedDocument = _module_registry.get('EmbeddedDocument')
-            DictField = _module_registry.get('DictField')
-
+        # Create the new_class
        new_class = super_new(cls, name, bases, attrs)
 
+        # Handle delete rules
+        Document, EmbeddedDocument, DictField = cls._import_classes()
         for field in new_class._fields.itervalues():
-            field.owner_document = new_class
-
-            delete_rule = getattr(field, 'reverse_delete_rule', DO_NOTHING)
             f = field
+            f.owner_document = new_class
+            delete_rule = getattr(f, 'reverse_delete_rule', DO_NOTHING)
             if isinstance(f, ComplexBaseField) and hasattr(f, 'field'):
-                delete_rule = getattr(f.field, 'reverse_delete_rule', DO_NOTHING)
+                delete_rule = getattr(f.field,
+                                      'reverse_delete_rule',
+                                      DO_NOTHING)
                 if isinstance(f, DictField) and delete_rule != DO_NOTHING:
-                    raise InvalidDocumentError("Reverse delete rules are not supported for %s (field: %s)" % (field.__class__.__name__, field.name))
+                    msg = ("Reverse delete rules are not supported "
+                           "for %s (field: %s)" %
+                           (field.__class__.__name__, field.name))
+                    raise InvalidDocumentError(msg)
+
                 f = field.field
 
             if delete_rule != DO_NOTHING:
                 if issubclass(new_class, EmbeddedDocument):
-                    raise InvalidDocumentError("Reverse delete rules are not supported for EmbeddedDocuments (field: %s)" % field.name)
-                f.document_type.register_delete_rule(new_class, field.name, delete_rule)
+                    msg = ("Reverse delete rules are not supported for "
+                           "EmbeddedDocuments (field: %s)" % field.name)
+                    raise InvalidDocumentError(msg)
+                f.document_type.register_delete_rule(new_class,
+                                                     field.name, delete_rule)
 
-            if (field.name and
-                hasattr(Document, field.name) and
+            if (field.name and hasattr(Document, field.name) and
                 EmbeddedDocument not in new_class.mro()):
-                raise InvalidDocumentError("%s is a document method and not a valid field name" % field.name)
+                msg = ("%s is a document method and not a valid "
+                       "field name" % field.name)
+                raise InvalidDocumentError(msg)
 
+        # Merge in exceptions with parent hierarchy
+        exceptions_to_merge = (DoesNotExist, MultipleObjectsReturned)
         module = attrs.get('__module__')
+        for exc in exceptions_to_merge:
+            name = exc.__name__
+            parents = tuple(getattr(base, name) for base in bases
+                            if hasattr(base, name)) or (exc,)
+            # Create new exception and set to new_class
+            exception = type(name, parents, {'__module__': module})
+            setattr(new_class, name, exception)
 
-        base_excs = tuple(base.DoesNotExist for base in bases
-                          if hasattr(base, 'DoesNotExist')) or (DoesNotExist,)
-        exc = subclass_exception('DoesNotExist', base_excs, module)
-        new_class.add_to_class('DoesNotExist', exc)
+        # Add class to the _document_registry
+        _document_registry[new_class._class_name] = new_class
 
-        base_excs = tuple(base.MultipleObjectsReturned for base in bases
-                          if hasattr(base, 'MultipleObjectsReturned'))
-        base_excs = base_excs or (MultipleObjectsReturned,)
-        exc = subclass_exception('MultipleObjectsReturned', base_excs, module)
-        new_class.add_to_class('MultipleObjectsReturned', exc)
-
-        global _document_registry
-        _document_registry[doc_class_name] = new_class
-
-        # in Python 2, User-defined methods objects have special read-only
+        # In Python 2, User-defined methods objects have special read-only
         # attributes 'im_func' and 'im_self' which contain the function obj
         # and class instance object respectively. With Python 3 these special
         # attributes have been replaced by __func__ and __self__. The Blinker
@@ -633,15 +621,40 @@ class DocumentMetaclass(type):
             if isinstance(val, classmethod):
                 f = val.__get__(new_class)
                 if hasattr(f, '__func__') and not hasattr(f, 'im_func'):
-                    f.__dict__.update({'im_func':getattr(f, '__func__')})
+                    f.__dict__.update({'im_func': getattr(f, '__func__')})
                 if hasattr(f, '__self__') and not hasattr(f, 'im_self'):
-                    f.__dict__.update({'im_self':getattr(f, '__self__')})
+                    f.__dict__.update({'im_self': getattr(f, '__self__')})
 
         return new_class
 
     def add_to_class(self, name, value):
         setattr(self, name, value)
 
+    @classmethod
+    def _get_bases(cls, bases):
+        if isinstance(bases, BasesTuple):
+            return bases
+        seen = []
+        bases = cls.__get_bases(bases)
+        unique_bases = (b for b in bases if not (b in seen or seen.append(b)))
+        return BasesTuple(unique_bases)
+
+    @classmethod
+    def __get_bases(cls, bases):
+        for base in bases:
+            if base is object:
+                continue
+            yield base
+            for child_base in cls.__get_bases(base.__bases__):
+                yield child_base
+
+    @classmethod
+    def _import_classes(cls):
+        Document = _import_class('Document')
+        EmbeddedDocument = _import_class('EmbeddedDocument')
+        DictField = _import_class('DictField')
+        return (Document, EmbeddedDocument, DictField)
+
 
 class TopLevelDocumentMetaclass(DocumentMetaclass):
     """Metaclass for top-level documents (i.e. documents that have their own
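
The _get_bases helpers added above flatten the whole base-class graph into a BasesTuple, skipping object and de-duplicating while preserving discovery order via the "b in seen or seen.append(b)" idiom. A self-contained sketch of that idiom with hypothetical classes (not MongoEngine code):

    class A(object): pass
    class B(A): pass
    class C(A): pass

    def flatten_bases(bases):
        # Depth-first walk over every ancestor except object
        for base in bases:
            if base is object:
                continue
            yield base
            for child_base in flatten_bases(base.__bases__):
                yield child_base

    seen = []
    # list.append returns None, so the 'or' both records b and keeps it
    unique = [b for b in flatten_bases((B, C))
              if not (b in seen or seen.append(b))]
    assert unique == [B, A, C]
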
@@ -649,125 +662,157 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
     """
 
     def __new__(cls, name, bases, attrs):
+
+        bases = cls._get_bases(bases)
         super_new = super(TopLevelDocumentMetaclass, cls).__new__
 
-        # Classes defined in this package are abstract and should not have
-        # their own metadata with DB collection, etc.
-        # __metaclass__ is only set on the class with the __metaclass__
-        # attribute (i.e. it is not set on subclasses). This differentiates
-        # 'real' documents from the 'Document' class
-        #
-        # Also assume a class is abstract if it has abstract set to True in
-        # its meta dictionary. This allows custom Document superclasses.
-        if (attrs.get('my_metaclass') == TopLevelDocumentMetaclass or
-            ('meta' in attrs and attrs['meta'].get('abstract', False))):
-            # Make sure no base class was non-abstract
-            non_abstract_bases = [b for b in bases
-                if hasattr(b, '_meta') and not b._meta.get('abstract', False)]
-            if non_abstract_bases:
-                raise ValueError("Abstract document cannot have non-abstract base")
+
+        # Set default _meta data if base class, otherwise get user defined meta
+        if (attrs.get('my_metaclass') == TopLevelDocumentMetaclass):
+            # defaults
+            attrs['_meta'] = {
+                'abstract': True,
+                'max_documents': None,
+                'max_size': None,
+                'ordering': [],  # default ordering applied at runtime
+                'indexes': [],  # indexes to be ensured at runtime
+                'id_field': None,
+                'index_background': False,
+                'index_drop_dups': False,
+                'index_opts': None,
+                'delete_rules': None,
+                'allow_inheritance': None,
+            }
+            attrs['_is_base_cls'] = True
+            attrs['_meta'].update(attrs.get('meta', {}))
+        else:
+            attrs['_meta'] = attrs.get('meta', {})
+            # Explicitly set abstract to false unless set
+            attrs['_meta']['abstract'] = attrs['_meta'].get('abstract', False)
+            attrs['_is_base_cls'] = False
+
+        # Set flag marking as document class - as opposed to an object mixin
+        attrs['_is_document'] = True
+
+        # Ensure queryset_class is inherited
+        if 'objects' in attrs:
+            manager = attrs['objects']
+            if hasattr(manager, 'queryset_class'):
+                attrs['_meta']['queryset_class'] = manager.queryset_class
+
+        # Clean up top level meta
+        if 'meta' in attrs:
+            del(attrs['meta'])
+
+        # Find the parent document class
+        parent_doc_cls = [b for b in bases
+                          if b.__class__ == TopLevelDocumentMetaclass]
+        parent_doc_cls = None if not parent_doc_cls else parent_doc_cls[0]
+
+        # Prevent classes setting collection different to their parents
+        # If parent wasn't an abstract class
+        if (parent_doc_cls and 'collection' in attrs.get('_meta', {})
+            and not parent_doc_cls._meta.get('abstract', True)):
+            msg = "Trying to set a collection on a subclass (%s)" % name
+            warnings.warn(msg, SyntaxWarning)
+            del(attrs['_meta']['collection'])
+
+        # Ensure abstract documents have abstract bases
+        if attrs.get('_is_base_cls') or attrs['_meta'].get('abstract'):
+            if (parent_doc_cls and
+                not parent_doc_cls._meta.get('abstract', False)):
+                msg = "Abstract document cannot have non-abstract base"
+                raise ValueError(msg)
             return super_new(cls, name, bases, attrs)
 
-        collection = ''.join('_%s' % c if c.isupper() else c for c in name).strip('_').lower()
+        # Merge base class metas.
+        # Uses a special MetaDict that handles various merging rules
+        meta = MetaDict()
+        for base in bases[::-1]:
+            # Add any mixin metadata from plain objects
+            if hasattr(base, 'meta'):
+                meta.merge(base.meta)
+            elif hasattr(base, '_meta'):
+                # Warn if allow_inheritance isn't set and prevent
+                # inheritance of classes where inheritance is set to False
+                allow_inheritance = base._meta.get('allow_inheritance',
                                                    ALLOW_INHERITANCE)
+                if not base._is_base_cls and allow_inheritance is None:
+                    warnings.warn(
+                        "%s uses inheritance, the default for "
+                        "allow_inheritance is changing to off by default. "
+                        "Please add it to the document meta." % name,
+                        FutureWarning
+                    )
+                elif (allow_inheritance == False and
+                      not base._meta.get('abstract')):
+                    raise ValueError('Document %s may not be subclassed' %
+                                     base.__name__)
+                meta.merge(base._meta)
-        id_field = None
-        abstract_base_indexes = []
-        base_indexes = []
-        base_meta = {}
+                # Set collection in the meta if it's callable
+                if (getattr(base, '_is_document', False) and
+                    not base._meta.get('abstract')):
+                    collection = meta.get('collection', None)
+                    if callable(collection):
+                        meta['collection'] = collection(base)
 
-        # Subclassed documents inherit collection from superclass
-        for base in bases:
-            if hasattr(base, '_meta'):
-                if 'collection' in attrs.get('meta', {}) and not base._meta.get('abstract', False):
-                    import warnings
-                    msg = "Trying to set a collection on a subclass (%s)" % name
-                    warnings.warn(msg, SyntaxWarning)
-                    del(attrs['meta']['collection'])
-                if base._get_collection_name():
-                    collection = base._get_collection_name()
+            # Standard object mixin - merge in any Fields
+            if not hasattr(base, '_meta'):
+                attrs.update(dict([(k, v) for k, v in base.__dict__.items()
+                                   if issubclass(v.__class__, BaseField)]))
 
-                # Propagate inherited values
-                keys_to_propogate = (
-                    'index_background', 'index_drop_dups', 'index_opts',
-                    'allow_inheritance', 'queryset_class', 'db_alias',
-                    'shard_key'
-                )
-                for key in keys_to_propogate:
-                    if key in base._meta:
-                        base_meta[key] = base._meta[key]
+        meta.merge(attrs.get('_meta', {}))  # Top level meta
 
-                id_field = id_field or base._meta.get('id_field')
-                if base._meta.get('abstract', False):
-                    abstract_base_indexes += base._meta.get('indexes', [])
-                else:
-                    base_indexes += base._meta.get('indexes', [])
-                try:
-                    base_meta['objects'] = base.__getattribute__(base, 'objects')
-                except TypeError:
-                    pass
-                except AttributeError:
-                    pass
+        # Only simple classes (direct subclasses of Document)
+        # may set allow_inheritance to False
+        simple_class = all([b._meta.get('abstract')
+                            for b in bases if hasattr(b, '_meta')])
+        if (not simple_class and meta['allow_inheritance'] == False and
+            not meta['abstract']):
+            raise ValueError('Only direct subclasses of Document may set '
+                             '"allow_inheritance" to False')
 
-        # defaults
-        meta = {
-            'abstract': False,
-            'collection': collection,
-            'max_documents': None,
-            'max_size': None,
-            'ordering': [],  # default ordering applied at runtime
-            'indexes': [],  # indexes to be ensured at runtime
-            'id_field': id_field,
-            'index_background': False,
-            'index_drop_dups': False,
-            'index_opts': {},
-            'queryset_class': QuerySet,
-            'delete_rules': {},
-            'allow_inheritance': True
-        }
-
-        allow_inheritance_defined = ('allow_inheritance' in base_meta or
-                                     'allow_inheritance'in attrs.get('meta', {}))
-        meta['allow_inheritance_defined'] = allow_inheritance_defined
-        meta.update(base_meta)
-
-        # Apply document-defined meta options
-        meta.update(attrs.get('meta', {}))
+        # Set default collection name
+        if 'collection' not in meta:
+            meta['collection'] = ''.join('_%s' % c if c.isupper() else c
+                                         for c in name).strip('_').lower()
         attrs['_meta'] = meta
 
-        # Set up collection manager, needs the class to have fields so use
-        # DocumentMetaclass before instantiating CollectionManager object
+        # Call super and get the new class
         new_class = super_new(cls, name, bases, attrs)
 
-        collection = attrs['_meta'].get('collection', None)
-        if callable(collection):
-            new_class._meta['collection'] = collection(new_class)
-
-        # Provide a default queryset unless one has been manually provided
-        manager = attrs.get('objects', meta.get('objects', QuerySetManager()))
-        if hasattr(manager, 'queryset_class'):
-            meta['queryset_class'] = manager.queryset_class
-        new_class.objects = manager
-
-        indicies = list(meta['indexes']) + abstract_base_indexes
-        user_indexes = [QuerySet._build_index_spec(new_class, spec)
-                        for spec in indicies] + base_indexes
-        new_class._meta['indexes'] = user_indexes
+        meta = new_class._meta
+        # Set index specifications
+        meta['index_specs'] = [QuerySet._build_index_spec(new_class, spec)
+                               for spec in meta['indexes']]
         unique_indexes = cls._unique_with_indexes(new_class)
         new_class._meta['unique_indexes'] = unique_indexes
 
+        # If collection is a callable - call it and set the value
+        collection = meta.get('collection')
+        if callable(collection):
+            new_class._meta['collection'] = collection(new_class)
+
+        # Provide a default queryset unless one has been set
+        manager = attrs.get('objects', QuerySetManager())
+        new_class.objects = manager
+
+        # Validate the fields and set primary key if needed
         for field_name, field in new_class._fields.iteritems():
-            # Check for custom primary key
             if field.primary_key:
-                current_pk = new_class._meta['id_field']
+                # Ensure only one primary key is set
+                current_pk = new_class._meta.get('id_field')
                 if current_pk and current_pk != field_name:
                     raise ValueError('Cannot override primary key field')
 
+                # Set primary key
                 if not current_pk:
                     new_class._meta['id_field'] = field_name
-                    # Make 'Document.id' an alias to the real primary key field
                     new_class.id = field
 
-        if not new_class._meta['id_field']:
+        # Set primary key if not defined by the document
+        if not new_class._meta.get('id_field'):
             new_class._meta['id_field'] = 'id'
             new_class._fields['id'] = ObjectIdField(db_field='_id')
             new_class.id = new_class._fields['id']
@@ -776,6 +821,9 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
 
     @classmethod
     def _unique_with_indexes(cls, new_class, namespace=""):
+        """
+        Find and set unique indexes
+        """
         unique_indexes = []
         for field_name, field in new_class._fields.items():
             # Generate a list of indexes needed by uniqueness constraints
@@ -801,18 +849,34 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
                     unique_fields += unique_with
 
                 # Add the new index to the list
-                index = [("%s%s" % (namespace, f), pymongo.ASCENDING) for f in unique_fields]
+                index = [("%s%s" % (namespace, f), pymongo.ASCENDING)
+                         for f in unique_fields]
                 unique_indexes.append(index)
 
             # Grab any embedded document field unique indexes
-            if field.__class__.__name__ == "EmbeddedDocumentField" and field.document_type != new_class:
+            if (field.__class__.__name__ == "EmbeddedDocumentField" and
+                field.document_type != new_class):
                 field_namespace = "%s." % field_name
                 unique_indexes += cls._unique_with_indexes(field.document_type,
-                                    field_namespace)
+                                                           field_namespace)
 
         return unique_indexes
 
 
+class MetaDict(dict):
+    """Custom dictionary for meta classes.
+    Handles the merging of set indexes
+    """
+    _merge_options = ('indexes',)
+
+    def merge(self, new_options):
+        for k, v in new_options.iteritems():
+            if k in self._merge_options:
+                self[k] = self.get(k, []) + v
+            else:
+                self[k] = v
+
+
 class BaseDocument(object):
     _dynamic = False
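
MetaDict.merge overwrites most keys but concatenates those listed in _merge_options, which is how index definitions accumulate down a document hierarchy while scalar options are simply inherited. A quick usage sketch (Python 2, matching the iteritems call above):

    meta = MetaDict({'indexes': ['name'], 'ordering': ['-created']})
    meta.merge({'indexes': ['age'], 'ordering': ['+name']})

    assert meta['indexes'] == ['name', 'age']  # concatenated
    assert meta['ordering'] == ['+name']       # replaced outright
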
@@ -877,10 +941,11 @@ class BaseDocument(object):
             self._data[name] = value
             if hasattr(self, '_changed_fields'):
                 self._mark_as_changed(name)
-
-        if not self._created and name in self._meta.get('shard_key', tuple()):
-            from queryset import OperationError
-            raise OperationError("Shard Keys are immutable. Tried to update %s" % name)
+        if (self._is_document and not self._created and
+           name in self._meta.get('shard_key', tuple())):
+            OperationError = _import_class('OperationError')
+            msg = "Shard Keys are immutable. Tried to update %s" % name
+            raise OperationError(msg)
 
         super(BaseDocument, self).__setattr__(name, value)
 
@@ -912,7 +977,8 @@ class BaseDocument(object):
             value = data
 
         # Convert lists / values so we can watch for any changes on them
-        if isinstance(value, (list, tuple)) and not isinstance(value, BaseList):
+        if (isinstance(value, (list, tuple)) and
+           not isinstance(value, BaseList)):
             value = BaseList(value, self, name)
         elif isinstance(value, dict) and not isinstance(value, BaseDict):
             value = BaseDict(value, self, name)
@@ -953,7 +1019,7 @@ class BaseDocument(object):
             data[field.db_field] = field.to_mongo(value)
         # Only add _cls and _types if allow_inheritance is not False
         if not (hasattr(self, '_meta') and
-                self._meta.get('allow_inheritance', True) == False):
+                self._meta.get('allow_inheritance', ALLOW_INHERITANCE) == False):
             data['_cls'] = self._class_name
             data['_types'] = self._superclasses.keys() + [self._class_name]
         if '_id' in data and data['_id'] is None:
@@ -1012,9 +1078,11 @@ class BaseDocument(object):
                 changed_fields.append(field_name)
 
         if errors_dict:
-            errors = "\n".join(["%s - %s" % (k, v) for k, v in errors_dict.items()])
-            raise InvalidDocumentError("""
-Invalid data to create a `%s` instance.\n%s""".strip() % (cls._class_name, errors))
+            errors = "\n".join(["%s - %s" % (k, v)
+                                for k, v in errors_dict.items()])
+            msg = ("Invalid data to create a `%s` instance.\n%s"
+                   % (cls._class_name, errors))
+            raise InvalidDocumentError(msg)
 
         if PY25:
             data = dict([(str(k), v) for k, v in data.items()])
@@ -1029,7 +1097,8 @@ Invalid data to create a `%s` instance.\n%s""".strip() % (cls._class_name, error
         if not key:
             return
         key = self._db_field_map.get(key, key)
-        if hasattr(self, '_changed_fields') and key not in self._changed_fields:
+        if (hasattr(self, '_changed_fields') and
+           key not in self._changed_fields):
             self._changed_fields.append(key)
 
     def _get_changed_fields(self, key='', inspected=None):
@@ -1059,9 +1128,14 @@ Invalid data to create a `%s` instance.\n%s""".strip() % (cls._class_name, error
                 continue
 
             inspected.add(field.id)
-            if isinstance(field, (EmbeddedDocument, DynamicEmbeddedDocument)) and db_field_name not in _changed_fields:  # Grab all embedded fields that have been changed
-                _changed_fields += ["%s%s" % (key, k) for k in field._get_changed_fields(key, inspected) if k]
-            elif isinstance(field, (list, tuple, dict)) and db_field_name not in _changed_fields:  # Loop list / dict fields as they contain documents
+            if (isinstance(field, (EmbeddedDocument, DynamicEmbeddedDocument))
+               and db_field_name not in _changed_fields):
+                # Find all embedded fields that have been changed
+                changed = field._get_changed_fields(key, inspected)
+                _changed_fields += ["%s%s" % (key, k) for k in changed if k]
+            elif (isinstance(field, (list, tuple, dict)) and
+                  db_field_name not in _changed_fields):
+                # Loop list / dict fields as they contain documents
                 # Determine the iterator to use
                 if not hasattr(field, 'items'):
                     iterator = enumerate(field)
@@ -1071,7 +1145,9 @@ Invalid data to create a `%s` instance.\n%s""".strip() % (cls._class_name, error
                 if not hasattr(value, '_get_changed_fields'):
                     continue
                 list_key = "%s%s." % (key, index)
-                _changed_fields += ["%s%s" % (list_key, k) for k in value._get_changed_fields(list_key, inspected) if k]
+                changed = value._get_changed_fields(list_key, inspected)
+                _changed_fields += ["%s%s" % (list_key, k)
+                                    for k in changed if k]
         return _changed_fields
 
     def _delta(self):
@@ -1113,7 +1189,8 @@ Invalid data to create a `%s` instance.\n%s""".strip() % (cls._class_name, error
 
             # If we've set a value that ain't the default value dont unset it.
             default = None
-            if self._dynamic and len(parts) and parts[0] in self._dynamic_fields:
+            if (self._dynamic and len(parts) and
+                parts[0] in self._dynamic_fields):
                 del(set_data[path])
                 unset_data[path] = 1
                 continue
@@ -1126,7 +1203,8 @@ Invalid data to create a `%s` instance.\n%s""".strip() % (cls._class_name, error
             for p in parts:
                 if p.isdigit():
                     d = d[int(p)]
-                elif hasattr(d, '__getattribute__') and not isinstance(d, dict):
+                elif (hasattr(d, '__getattribute__') and
+                      not isinstance(d, dict)):
                     real_path = d._reverse_db_field_map.get(p, p)
                     d = getattr(d, real_path)
                 else:
@@ -1172,7 +1250,8 @@ Invalid data to create a `%s` instance.\n%s""".strip() % (cls._class_name, error
         return geo_indices
 
     def __getstate__(self):
-        removals = ["get_%s_display" % k for k, v in self._fields.items() if v.choices]
+        removals = ("get_%s_display" % k
+                    for k, v in self._fields.items() if v.choices)
         for k in removals:
             if hasattr(self, k):
                 delattr(self, k)
@@ -1183,9 +1262,12 @@ Invalid data to create a `%s` instance.\n%s""".strip() % (cls._class_name, error
         self.__set_field_display()
 
     def __set_field_display(self):
+        """Dynamically set the display value for a field with choices"""
         for attr_name, field in self._fields.items():
-            if field.choices:  # dynamically adds a way to get the display value for a field with choices
-                setattr(self, 'get_%s_display' % attr_name, partial(self.__get_field_display, field=field))
+            if field.choices:
+                setattr(self,
+                        'get_%s_display' % attr_name,
+                        partial(self.__get_field_display, field=field))
 
     def __get_field_display(self, field):
         """Returns the display value for a choice field"""
@@ -1257,6 +1339,11 @@ Invalid data to create a `%s` instance.\n%s""".strip() % (cls._class_name, error
         return hash(self.pk)
 
 
+class BasesTuple(tuple):
+    """Special class to handle introspection of bases tuple in __new__"""
+    pass
+
+
 class BaseList(list):
     """A special list so we can watch any changes
     """
@@ -1378,5 +1465,18 @@ class BaseDict(dict):
             self._instance._mark_as_changed(self._name)
 
 
-def subclass_exception(name, parents, module):
-    return type(name, parents, {'__module__': module})
+def _import_class(cls_name):
+    """Cached mechanism for imports"""
+    if cls_name in _class_registry:
+        return _class_registry.get(cls_name)
+
+    if cls_name == 'Document':
+        from mongoengine.document import Document as cls
+    elif cls_name == 'EmbeddedDocument':
+        from mongoengine.document import EmbeddedDocument as cls
+    elif cls_name == 'DictField':
+        from mongoengine.fields import DictField as cls
+    elif cls_name == 'OperationError':
+        from queryset import OperationError as cls
+
+    _class_registry[cls_name] = cls
+    return cls
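
The module-level ALLOW_INHERITANCE = True default threaded through BaseDocument.to_mongo above means _cls and _types are still written unless a document explicitly opts out. Roughly, with a hypothetical model:

    from mongoengine import Document, StringField

    class LogEntry(Document):
        message = StringField()
        meta = {'allow_inheritance': False}

    son = LogEntry(message='booted').to_mongo()
    assert '_cls' not in son and '_types' not in son
    # Subclassing a document that leaves the flag unset keeps the old
    # behaviour (keys are written) but triggers the metaclass FutureWarning.
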
""" - cls._meta['delete_rules'][(document_cls, field_name)] = rule + delete_rules = cls._meta.get('delete_rules') or {} + delete_rules[(document_cls, field_name)] = rule + cls._meta['delete_rules'] = delete_rules @classmethod def drop_collection(cls): @@ -392,6 +394,7 @@ class DynamicDocument(Document): __metaclass__ = TopLevelDocumentMetaclass _dynamic = True + meta = {'abstract': True} def __delattr__(self, *args, **kwargs): """Deletes the attribute by setting to None and allowing _delta to unset diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 448b0c6a..42025134 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -747,7 +747,6 @@ class ReferenceField(BaseField): def prepare_query_value(self, op, value): if value is None: return None - return self.to_mongo(value) def validate(self, value): diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index dbd8ad46..446a42ab 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -352,7 +352,7 @@ class QuerySet(object): # If inheritance is allowed, only return instances and instances of # subclasses of the class being used - if document._meta.get('allow_inheritance'): + if document._meta.get('allow_inheritance') != False: self._initial_query = {'_types': self._document._class_name} self._loaded_fields = QueryFieldList(always_include=['_cls']) self._cursor_obj = None @@ -442,7 +442,7 @@ class QuerySet(object): """ background = self._document._meta.get('index_background', False) drop_dups = self._document._meta.get('index_drop_dups', False) - index_opts = self._document._meta.get('index_opts', {}) + index_opts = self._document._meta.get('index_opts') or {} index_types = self._document._meta.get('index_types', True) # determine if an index which we are creating includes @@ -450,6 +450,7 @@ class QuerySet(object): # an extra index on _type, as mongodb will use the existing # index to service queries against _type types_indexed = False + def includes_types(fields): first_field = None if len(fields): @@ -466,8 +467,8 @@ class QuerySet(object): background=background, drop_dups=drop_dups, **index_opts) # Ensure document-defined indexes are created - if self._document._meta['indexes']: - for spec in self._document._meta['indexes']: + if self._document._meta['index_specs']: + for spec in self._document._meta['index_specs']: types_indexed = types_indexed or includes_types(spec['fields']) opts = index_opts.copy() opts['unique'] = spec.get('unique', False) @@ -498,7 +499,10 @@ class QuerySet(object): index_list = [] direction = None - use_types = doc_cls._meta.get('allow_inheritance', True) + + allow_inheritance = doc_cls._meta.get('allow_inheritance') != False + use_types = allow_inheritance + for key in spec['fields']: # Get ASCENDING direction from +, DESCENDING from -, and GEO2D from * direction = pymongo.ASCENDING @@ -516,7 +520,8 @@ class QuerySet(object): key = '_id' else: fields = QuerySet._lookup_field(doc_cls, parts) - parts = [field if field == '_id' else field.db_field for field in fields] + parts = [field if field == '_id' else field.db_field + for field in fields] key = '.'.join(parts) index_list.append((key, direction)) @@ -530,8 +535,9 @@ class QuerySet(object): # If _types is being used, prepend it to every specified index index_types = doc_cls._meta.get('index_types', True) - allow_inheritance = doc_cls._meta.get('allow_inheritance') - if spec.get('types', index_types) and allow_inheritance and use_types and direction is not pymongo.GEO2D: + + if (spec.get('types', index_types) and 
diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py
index dbd8ad46..446a42ab 100644
--- a/mongoengine/queryset.py
+++ b/mongoengine/queryset.py
@@ -352,7 +352,7 @@ class QuerySet(object):
 
         # If inheritance is allowed, only return instances and instances of
         # subclasses of the class being used
-        if document._meta.get('allow_inheritance'):
+        if document._meta.get('allow_inheritance') != False:
             self._initial_query = {'_types': self._document._class_name}
             self._loaded_fields = QueryFieldList(always_include=['_cls'])
         self._cursor_obj = None
@@ -442,7 +442,7 @@ class QuerySet(object):
         """
         background = self._document._meta.get('index_background', False)
         drop_dups = self._document._meta.get('index_drop_dups', False)
-        index_opts = self._document._meta.get('index_opts', {})
+        index_opts = self._document._meta.get('index_opts') or {}
         index_types = self._document._meta.get('index_types', True)
 
         # determine if an index which we are creating includes
@@ -450,6 +450,7 @@ class QuerySet(object):
         # an extra index on _type, as mongodb will use the existing
         # index to service queries against _type
         types_indexed = False
+
         def includes_types(fields):
             first_field = None
             if len(fields):
@@ -466,8 +467,8 @@ class QuerySet(object):
                                background=background, drop_dups=drop_dups, **index_opts)
 
         # Ensure document-defined indexes are created
-        if self._document._meta['indexes']:
-            for spec in self._document._meta['indexes']:
+        if self._document._meta['index_specs']:
+            for spec in self._document._meta['index_specs']:
                 types_indexed = types_indexed or includes_types(spec['fields'])
                 opts = index_opts.copy()
                 opts['unique'] = spec.get('unique', False)
@@ -498,7 +499,10 @@ class QuerySet(object):
 
         index_list = []
         direction = None
-        use_types = doc_cls._meta.get('allow_inheritance', True)
+
+        allow_inheritance = doc_cls._meta.get('allow_inheritance') != False
+        use_types = allow_inheritance
+
         for key in spec['fields']:
             # Get ASCENDING direction from +, DESCENDING from -, and GEO2D from *
             direction = pymongo.ASCENDING
@@ -516,7 +520,8 @@ class QuerySet(object):
                 key = '_id'
             else:
                 fields = QuerySet._lookup_field(doc_cls, parts)
-                parts = [field if field == '_id' else field.db_field for field in fields]
+                parts = [field if field == '_id' else field.db_field
+                         for field in fields]
                 key = '.'.join(parts)
             index_list.append((key, direction))
 
@@ -530,8 +535,9 @@ class QuerySet(object):
 
         # If _types is being used, prepend it to every specified index
         index_types = doc_cls._meta.get('index_types', True)
-        allow_inheritance = doc_cls._meta.get('allow_inheritance')
-        if spec.get('types', index_types) and allow_inheritance and use_types and direction is not pymongo.GEO2D:
+
+        if (spec.get('types', index_types) and allow_inheritance and use_types
+            and direction is not pymongo.GEO2D):
             index_list.insert(0, ('_types', 1))
 
         spec['fields'] = index_list
@@ -1329,9 +1335,10 @@ class QuerySet(object):
         """
         doc = self._document
 
+        delete_rules = doc._meta.get('delete_rules') or {}
         # Check for DENY rules before actually deleting/nullifying any other
         # references
-        for rule_entry in doc._meta['delete_rules']:
+        for rule_entry in delete_rules:
             document_cls, field_name = rule_entry
             rule = doc._meta['delete_rules'][rule_entry]
             if rule == DENY and document_cls.objects(**{field_name + '__in': self}).count() > 0:
@@ -1339,12 +1346,14 @@ class QuerySet(object):
                       (document_cls.__name__, field_name)
                 raise OperationError(msg)
 
-        for rule_entry in doc._meta['delete_rules']:
+        for rule_entry in delete_rules:
             document_cls, field_name = rule_entry
             rule = doc._meta['delete_rules'][rule_entry]
             if rule == CASCADE:
                 ref_q = document_cls.objects(**{field_name + '__in': self})
-                if doc != document_cls or (doc == document_cls and ref_q.count() > 0):
+                ref_q_count = ref_q.count()
+                if (doc != document_cls and ref_q_count > 0
+                    or (doc == document_cls and ref_q_count > 0)):
                     ref_q.delete(safe=safe)
             elif rule == NULLIFY:
                 document_cls.objects(**{field_name + '__in': self}).update(
@@ -1915,7 +1924,7 @@ class QuerySetManager(object):
             return self
 
         # owner is the document that contains the QuerySetManager
-        queryset_class = owner._meta['queryset_class'] or QuerySet
+        queryset_class = owner._meta.get('queryset_class') or QuerySet
         queryset = queryset_class(owner, owner._get_collection())
         if self.get_queryset:
             arg_count = self.get_queryset.func_code.co_argcount
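
The queryset changes replace truthiness checks with explicit != False comparisons so that the new allow_inheritance = None default (unset: warn, but behave as before) still scopes queries by _types. In effect, with hypothetical classes:

    from mongoengine import Document, StringField

    class Animal(Document):
        name = StringField()
        meta = {'allow_inheritance': True}

    class Dog(Animal):
        pass

    # QuerySet.__init__ seeds Dog.objects with {'_types': 'Animal.Dog'},
    # so only Dog documents (and their subclasses) match; setting
    # meta = {'allow_inheritance': False} skips the _types filter entirely.
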
diff --git a/setup.cfg b/setup.cfg
index f28e6687..d95a9176 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,5 +1,5 @@
 [nosetests]
-verbosity = 2
+verbosity = 3
 detailed-errors = 1
 #with-coverage = 1
 #cover-erase = 1
@@ -8,4 +8,4 @@ detailed-errors = 1
 #cover-package = mongoengine
 py3where = build
 where = tests
-#tests = test_bugfix.py
\ No newline at end of file
+#tests = test_bugfix.py
\ No newline at end of file
diff --git a/tests/test_django.py b/tests/test_django.py
index 678d7cfe..398fd3e0 100644
--- a/tests/test_django.py
+++ b/tests/test_django.py
@@ -18,8 +18,8 @@ try:
     from mongoengine.django.sessions import SessionStore, MongoSession
 except Exception, err:
     if PY3:
-        SessionTestsMixin = type #dummy value so no error
-        SessionStore = None #dummy value so no error
+        SessionTestsMixin = type  # dummy value so no error
+        SessionStore = None  # dummy value so no error
     else:
         raise err
diff --git a/tests/test_document.py b/tests/test_document.py
index 8654e3f8..da329ca5 100644
--- a/tests/test_document.py
+++ b/tests/test_document.py
@@ -387,19 +387,6 @@ class DocumentTest(unittest.TestCase):
                 meta = {'allow_inheritance': False}
         self.assertRaises(ValueError, create_employee_class)
 
-        # Test the same for embedded documents
-        class Comment(EmbeddedDocument):
-            content = StringField()
-            meta = {'allow_inheritance': False}
-
-        def create_special_comment():
-            class SpecialComment(Comment):
-                pass
-        self.assertRaises(ValueError, create_special_comment)
-
-        comment = Comment(content='test')
-        self.assertFalse('_cls' in comment.to_mongo())
-        self.assertFalse('_types' in comment.to_mongo())
 
     def test_allow_inheritance_abstract_document(self):
         """Ensure that abstract documents can set inheritance rules and that
@@ -491,9 +478,20 @@ class DocumentTest(unittest.TestCase):
         """Ensure that a document superclass can be marked as abstract
         thereby not using it as the name for the collection."""
 
+        defaults = {'index_background': True,
+                    'index_drop_dups': True,
+                    'index_opts': {'hello': 'world'},
+                    'allow_inheritance': True,
+                    'queryset_class': 'QuerySet',
+                    'db_alias': 'myDB',
+                    'shard_key': ('hello', 'world')}
+
+        meta_settings = {'abstract': True}
+        meta_settings.update(defaults)
+
         class Animal(Document):
             name = StringField()
-            meta = {'abstract': True}
+            meta = meta_settings
 
         class Fish(Animal): pass
         class Guppy(Fish): pass
@@ -502,6 +500,10 @@ class DocumentTest(unittest.TestCase):
             meta = {'abstract': True}
         class Human(Mammal): pass
 
+        for k, v in defaults.iteritems():
+            for cls in [Animal, Fish, Guppy]:
+                self.assertEqual(cls._meta[k], v)
+
         self.assertFalse('collection' in Animal._meta)
         self.assertFalse('collection' in Mammal._meta)
 
@@ -564,6 +566,7 @@ class DocumentTest(unittest.TestCase):
 
         class Drink(Document):
             name = StringField()
+            meta = {'allow_inheritance': True}
 
         class Drinker(Document):
             drink = GenericReferenceField()
@@ -799,7 +802,6 @@ class DocumentTest(unittest.TestCase):
 
             user_guid = StringField(required=True)
 
-
         class Person(UserBase):
             meta = {
                 'indexes': ['name'],
@@ -1325,7 +1327,6 @@ class DocumentTest(unittest.TestCase):
 
         self.assertTrue('content' in Comment._fields)
         self.assertFalse('id' in Comment._fields)
-        self.assertFalse('collection' in Comment._meta)
 
     def test_embedded_document_validation(self):
         """Ensure that embedded documents may be validated.
""" - class BlogPost(Document): + class BlogPost2(Document): content = StringField() authors = ListField(ReferenceField(self.Person, reverse_delete_rule=CASCADE)) reviewers = ListField(ReferenceField(self.Person, reverse_delete_rule=NULLIFY)) self.Person.drop_collection() - BlogPost.drop_collection() + + BlogPost2.drop_collection() author = self.Person(name='Test User') author.save() @@ -2685,18 +2680,19 @@ class DocumentTest(unittest.TestCase): reviewer = self.Person(name='Re Viewer') reviewer.save() - post = BlogPost(content= 'Watched some TV') + post = BlogPost2(content='Watched some TV') post.authors = [author] post.reviewers = [reviewer] post.save() + # Deleting the reviewer should have no effect on the BlogPost2 reviewer.delete() - self.assertEqual(len(BlogPost.objects), 1) # No effect on the BlogPost - self.assertEqual(BlogPost.objects.get().reviewers, []) + self.assertEqual(len(BlogPost2.objects), 1) + self.assertEqual(BlogPost2.objects.get().reviewers, []) # Delete the Person, which should lead to deletion of the BlogPost, too author.delete() - self.assertEqual(len(BlogPost.objects), 0) + self.assertEqual(len(BlogPost2.objects), 0) def test_two_way_reverse_delete_rule(self): """Ensure that Bi-Directional relationships work with @@ -3074,7 +3070,7 @@ class DocumentTest(unittest.TestCase): self.assertEqual('testdb-1', B._meta.get('db_alias')) def test_db_ref_usage(self): - """ DB Ref usage in __raw__ queries """ + """ DB Ref usage in dict_fields""" class User(Document): name = StringField() @@ -3216,7 +3212,6 @@ class ValidatorErrorTest(unittest.TestCase): one = Doc.objects.filter(**{'hello world': 1}).count() self.assertEqual(1, one) - def test_fields_rewrite(self): class BasePerson(Document): name = StringField() @@ -3226,7 +3221,6 @@ class ValidatorErrorTest(unittest.TestCase): class Person(BasePerson): name = StringField(required=True) - p = Person(age=15) self.assertRaises(ValidationError, p.validate) diff --git a/tests/test_queryset.py b/tests/test_queryset.py index 979dc6f1..591a82ac 100644 --- a/tests/test_queryset.py +++ b/tests/test_queryset.py @@ -24,6 +24,8 @@ class QuerySetTest(unittest.TestCase): name = StringField() age = IntField() meta = {'allow_inheritance': True} + + Person.drop_collection() self.Person = Person def test_initialisation(self):