Inheritance is off by default (MongoEngine/mongoengine#122)

This commit is contained in:
Ross Lawley
2012-10-17 11:36:18 +00:00
parent 6f29d12386
commit 3d5b6ae332
20 changed files with 245 additions and 177 deletions

View File

@@ -2,7 +2,7 @@ from mongoengine.errors import NotRegistered
__all__ = ('ALLOW_INHERITANCE', 'get_document', '_document_registry')
ALLOW_INHERITANCE = True
ALLOW_INHERITANCE = False
_document_registry = {}

View File

@@ -50,7 +50,6 @@ class BaseDocument(object):
for key, value in values.iteritems():
key = self._reverse_db_field_map.get(key, key)
setattr(self, key, value)
# Set any get_fieldname_display methods
self.__set_field_display()
@@ -83,6 +82,11 @@ class BaseDocument(object):
if hasattr(self, '_changed_fields'):
self._mark_as_changed(name)
# Check if the user has created a new instance of a class
if (self._is_document and self._initialised
and self._created and name == self._meta['id_field']):
super(BaseDocument, self).__setattr__('_created', False)
if (self._is_document and not self._created and
name in self._meta.get('shard_key', tuple()) and
self._data.get(name) != value):
@@ -171,14 +175,24 @@ class BaseDocument(object):
"""Return data dictionary ready for use with MongoDB.
"""
data = {}
for field_name, field in self._fields.items():
value = getattr(self, field_name, None)
for field_name, field in self._fields.iteritems():
value = self._data.get(field_name, None)
if value is not None:
data[field.db_field] = field.to_mongo(value)
# Only add _cls if allow_inheritance is not False
if not (hasattr(self, '_meta') and
self._meta.get('allow_inheritance', ALLOW_INHERITANCE) == False):
value = field.to_mongo(value)
# Handle self generating fields
if value is None and field._auto_gen:
value = field.generate()
self._data[field_name] = value
if value is not None:
data[field.db_field] = value
# Only add _cls if allow_inheritance is True
if (hasattr(self, '_meta') and
self._meta.get('allow_inheritance', ALLOW_INHERITANCE) == True):
data['_cls'] = self._class_name
if '_id' in data and data['_id'] is None:
del data['_id']
@@ -194,7 +208,7 @@ class BaseDocument(object):
are present.
"""
# Get a list of tuples of field names and their current values
fields = [(field, getattr(self, name))
fields = [(field, self._data.get(name))
for name, field in self._fields.items()]
# Ensure that each field is matched to a valid value
@@ -207,7 +221,7 @@ class BaseDocument(object):
errors[field.name] = error.errors or error
except (ValueError, AttributeError, AssertionError), error:
errors[field.name] = error
elif field.required:
elif field.required and not getattr(field, '_auto_gen', False):
errors[field.name] = ValidationError('Field is required',
field_name=field.name)
if errors:
@@ -313,6 +327,7 @@ class BaseDocument(object):
"""
# Handles cases where not loaded from_son but has _id
doc = self.to_mongo()
set_fields = self._get_changed_fields()
set_data = {}
unset_data = {}
@@ -370,7 +385,6 @@ class BaseDocument(object):
if hasattr(d, '_fields'):
field_name = d._reverse_db_field_map.get(db_field_name,
db_field_name)
if field_name in d._fields:
default = d._fields.get(field_name).default
else:
@@ -379,6 +393,7 @@ class BaseDocument(object):
if default is not None:
if callable(default):
default = default()
if default != value:
continue
@@ -399,15 +414,12 @@ class BaseDocument(object):
# get the class name from the document, falling back to the given
# class if unavailable
class_name = son.get('_cls', cls._class_name)
data = dict(("%s" % key, value) for key, value in son.items())
data = dict(("%s" % key, value) for key, value in son.iteritems())
if not UNICODE_KWARGS:
# python 2.6.4 and lower cannot handle unicode keys
# passed to class constructor example: cls(**data)
to_str_keys_recursive(data)
if '_cls' in data:
del data['_cls']
# Return correct subclass for document type
if class_name != cls._class_name:
cls = get_document(class_name)
@@ -415,7 +427,7 @@ class BaseDocument(object):
changed_fields = []
errors_dict = {}
for field_name, field in cls._fields.items():
for field_name, field in cls._fields.iteritems():
if field.db_field in data:
value = data[field.db_field]
try:

View File

@@ -21,6 +21,7 @@ class BaseField(object):
name = None
_geo_index = False
_auto_gen = False # Call `generate` to generate a value
# These track each time a Field instance is created. Used to retain order.
# The auto_creation_counter is used for fields that MongoEngine implicitly
@@ -36,7 +37,6 @@ class BaseField(object):
if name:
msg = "Fields' 'name' attribute deprecated in favour of 'db_field'"
warnings.warn(msg, DeprecationWarning)
self.name = None
self.required = required or primary_key
self.default = default
self.unique = bool(unique or unique_with)
@@ -62,7 +62,6 @@ class BaseField(object):
if instance is None:
# Document class being used rather than a document object
return self
# Get value from document instance if available, if not use default
value = instance._data.get(self.name)
@@ -241,12 +240,21 @@ class ComplexBaseField(BaseField):
"""Convert a Python type to a MongoDB-compatible type.
"""
Document = _import_class("Document")
EmbeddedDocument = _import_class("EmbeddedDocument")
GenericReferenceField = _import_class("GenericReferenceField")
if isinstance(value, basestring):
return value
if hasattr(value, 'to_mongo'):
return value.to_mongo()
if isinstance(value, Document):
return GenericReferenceField().to_mongo(value)
cls = value.__class__
val = value.to_mongo()
# If we its a document thats not inherited add _cls
if (isinstance(value, EmbeddedDocument)):
val['_cls'] = cls.__name__
return val
is_list = False
if not hasattr(value, 'items'):
@@ -258,10 +266,10 @@ class ComplexBaseField(BaseField):
if self.field:
value_dict = dict([(key, self.field.to_mongo(item))
for key, item in value.items()])
for key, item in value.iteritems()])
else:
value_dict = {}
for k, v in value.items():
for k, v in value.iteritems():
if isinstance(v, Document):
# We need the id from the saved object to create the DBRef
if v.pk is None:
@@ -274,16 +282,19 @@ class ComplexBaseField(BaseField):
meta = getattr(v, '_meta', {})
allow_inheritance = (
meta.get('allow_inheritance', ALLOW_INHERITANCE)
== False)
if allow_inheritance and not self.field:
GenericReferenceField = _import_class(
"GenericReferenceField")
== True)
if not allow_inheritance and not self.field:
value_dict[k] = GenericReferenceField().to_mongo(v)
else:
collection = v._get_collection_name()
value_dict[k] = DBRef(collection, v.pk)
elif hasattr(v, 'to_mongo'):
value_dict[k] = v.to_mongo()
cls = v.__class__
val = v.to_mongo()
# If we its a document thats not inherited add _cls
if (isinstance(v, (Document, EmbeddedDocument))):
val['_cls'] = cls.__name__
value_dict[k] = val
else:
value_dict[k] = self.to_mongo(v)

View File

@@ -34,6 +34,17 @@ class DocumentMetaclass(type):
if 'meta' in attrs:
attrs['_meta'] = attrs.pop('meta')
# EmbeddedDocuments should inherit meta data
if '_meta' not in attrs:
meta = MetaDict()
for base in flattened_bases[::-1]:
# Add any mixin metadata from plain objects
if hasattr(base, 'meta'):
meta.merge(base.meta)
elif hasattr(base, '_meta'):
meta.merge(base._meta)
attrs['_meta'] = meta
# Handle document Fields
# Merge all fields from subclasses
@@ -52,6 +63,7 @@ class DocumentMetaclass(type):
if not attr_value.db_field:
attr_value.db_field = attr_name
base_fields[attr_name] = attr_value
doc_fields.update(base_fields)
# Discover any document fields
@@ -98,15 +110,7 @@ class DocumentMetaclass(type):
# inheritance of classes where inheritance is set to False
allow_inheritance = base._meta.get('allow_inheritance',
ALLOW_INHERITANCE)
if (not getattr(base, '_is_base_cls', True)
and allow_inheritance is None):
warnings.warn(
"%s uses inheritance, the default for "
"allow_inheritance is changing to off by default. "
"Please add it to the document meta." % name,
FutureWarning
)
elif (allow_inheritance == False and
if (allow_inheritance != True and
not base._meta.get('abstract')):
raise ValueError('Document %s may not be subclassed' %
base.__name__)
@@ -353,6 +357,7 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
if not new_class._meta.get('id_field'):
new_class._meta['id_field'] = 'id'
new_class._fields['id'] = ObjectIdField(db_field='_id')
new_class._fields['id'].name = 'id'
new_class.id = new_class._fields['id']
# Merge in exceptions with parent hierarchy