From 0526f577ffd39ead42936f029aed225f8dc1d48e Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Wed, 22 Aug 2012 09:27:18 +0100 Subject: [PATCH] Embedded Documents still can inherit fields MongoEngine/mongoengine#84 --- mongoengine/base.py | 22 ++++++++++++++-------- tests/test_document.py | 15 +++++++++++++++ 2 files changed, 29 insertions(+), 8 deletions(-) diff --git a/mongoengine/base.py b/mongoengine/base.py index 3c3dcdc9..3d78aaa0 100644 --- a/mongoengine/base.py +++ b/mongoengine/base.py @@ -514,6 +514,18 @@ class DocumentMetaclass(type): if hasattr(base, '_fields'): doc_fields.update(base._fields) + # Standard object mixin - merge in any Fields + if not hasattr(base, '_meta'): + base_fields = {} + for attr_name, attr_value in base.__dict__.iteritems(): + if not isinstance(attr_value, BaseField): + continue + attr_value.name = attr_name + if not attr_value.db_field: + attr_value.db_field = attr_name + base_fields[attr_name] = attr_value + doc_fields.update(base_fields) + # Discover any document fields field_names = {} for attr_name, attr_value in attrs.iteritems(): @@ -537,9 +549,8 @@ class DocumentMetaclass(type): # Set _fields and db_field maps attrs['_fields'] = doc_fields - attrs['_db_field_map'] = dict( - ((k, v.db_field) for k, v in doc_fields.items() - if k != v.db_field)) + attrs['_db_field_map'] = dict([(k, getattr(v, 'db_field', k)) + for k, v in doc_fields.iteritems()]) attrs['_reverse_db_field_map'] = dict( (v, k) for k, v in attrs['_db_field_map'].iteritems()) @@ -757,11 +768,6 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): if callable(collection): meta['collection'] = collection(base) - # Standard object mixin - merge in any Fields - if not hasattr(base, '_meta'): - attrs.update(dict([(k, v) for k, v in base.__dict__.items() - if issubclass(v.__class__, BaseField)])) - meta.merge(attrs.get('_meta', {})) # Top level meta # Only simple classes (direct subclasses of Document) diff --git a/tests/test_document.py b/tests/test_document.py index fdca119f..b8e3a775 100644 --- a/tests/test_document.py +++ b/tests/test_document.py @@ -2585,6 +2585,21 @@ class DocumentTest(unittest.TestCase): Person.drop_collection() + def test_object_mixins(self): + + class NameMixin(object): + name = StringField() + + class Foo(EmbeddedDocument, NameMixin): + quantity = IntField() + + self.assertEqual(['name', 'quantity'], sorted(Foo._fields.keys())) + + class Bar(Document, NameMixin): + widgets = StringField() + + self.assertEqual(['id', 'name', 'widgets'], sorted(Bar._fields.keys())) + def test_mixin_inheritance(self): class BaseMixIn(object): count = IntField()