diff --git a/mongoengine/base.py b/mongoengine/base.py index 62bf80b5..83fd34ee 100644 --- a/mongoengine/base.py +++ b/mongoengine/base.py @@ -102,6 +102,7 @@ class DocumentMetaclass(type): doc_fields = {} class_name = [name] superclasses = {} + simple_class = True for base in bases: # Include all fields present in superclasses if hasattr(base, '_fields'): @@ -110,6 +111,29 @@ class DocumentMetaclass(type): # Get superclasses from superclass superclasses[base._class_name] = base superclasses.update(base._superclasses) + + if hasattr(base, '_meta'): + # Ensure that the Document class may be subclassed - + # inheritance may be disabled to remove dependency on + # additional fields _cls and _types + if base._meta.get('allow_inheritance', True) == False: + raise ValueError('Document %s may not be subclassed' % + base.__name__) + else: + simple_class = False + + meta = attrs.get('_meta', attrs.get('meta', {})) + + if 'allow_inheritance' not in meta: + meta['allow_inheritance'] = True + + # Only simple classes - direct subclasses of Document - may set + # allow_inheritance to False + if not simple_class and not meta['allow_inheritance']: + raise ValueError('Only direct subclasses of Document may set ' + '"allow_inheritance" to False') + attrs['_meta'] = meta + attrs['_class_name'] = '.'.join(reversed(class_name)) attrs['_superclasses'] = superclasses @@ -142,21 +166,12 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): collection = name.lower() - simple_class = True id_field = None base_indexes = [] # Subclassed documents inherit collection from superclass for base in bases: if hasattr(base, '_meta') and 'collection' in base._meta: - # Ensure that the Document class may be subclassed - - # inheritance may be disabled to remove dependency on - # additional fields _cls and _types - if base._meta.get('allow_inheritance', True) == False: - raise ValueError('Document %s may not be subclassed' % - base.__name__) - else: - simple_class = False collection = base._meta['collection'] id_field = id_field or base._meta.get('id_field') @@ -164,7 +179,6 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): meta = { 'collection': collection, - 'allow_inheritance': True, 'max_documents': None, 'max_size': None, 'ordering': [], # default ordering applied at runtime @@ -174,12 +188,6 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): # Apply document-defined meta options meta.update(attrs.get('meta', {})) - - # Only simple classes - direct subclasses of Document - may set - # allow_inheritance to False - if not simple_class and not meta['allow_inheritance']: - raise ValueError('Only direct subclasses of Document may set ' - '"allow_inheritance" to False') attrs['_meta'] = meta # Set up collection manager, needs the class to have fields so use @@ -337,8 +345,12 @@ class BaseDocument(object): if value is not None: data[field.name] = field.to_mongo(value) # Only add _cls and _types if allow_inheritance is not False - if not (hasattr(self, '_meta') and - self._meta.get('allow_inheritance', True) == False): + #if not (hasattr(self, '_meta') and + # self._meta.get('allow_inheritance', True) == False): + ah = True + if hasattr(self, '_meta'): + ah = self._meta.get('allow_inheritance', True) + if ah: data['_cls'] = self._class_name data['_types'] = self._superclasses.keys() + [self._class_name] return data diff --git a/tests/document.py b/tests/document.py index 0c0b220b..1b58781c 100644 --- a/tests/document.py +++ b/tests/document.py @@ -156,6 +156,20 @@ class DocumentTest(unittest.TestCase): class Employee(self.Person): meta = {'allow_inheritance': False} self.assertRaises(ValueError, create_employee_class) + + # Test the same for embedded documents + class Comment(EmbeddedDocument): + content = StringField() + meta = {'allow_inheritance': False} + + def create_special_comment(): + class SpecialComment(Comment): + pass + self.assertRaises(ValueError, create_special_comment) + + comment = Comment(content='test') + self.assertFalse('_cls' in comment.to_mongo()) + self.assertFalse('_types' in comment.to_mongo()) def test_collection_name(self): """Ensure that a collection with a specified name may be used. @@ -391,7 +405,7 @@ class DocumentTest(unittest.TestCase): self.assertTrue('content' in Comment._fields) self.assertFalse('id' in Comment._fields) - self.assertFalse(hasattr(Comment, '_meta')) + self.assertFalse('collection' in Comment._meta) def test_embedded_document_validation(self): """Ensure that embedded documents may be validated.