From 32bab13a8acc0e091094468be0a6d890719bfec5 Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Tue, 24 May 2011 12:50:48 +0100 Subject: [PATCH] Added MapField, similar to DictField Similar to DictField except the value of each entry is always of a certain (declared) field type. Thanks again to @theojulienne for the code #108 --- mongoengine/fields.py | 93 ++++++++++++++++++++++++++++++++++++++++++- tests/fields.py | 61 ++++++++++++++++++++++++++++ 2 files changed, 153 insertions(+), 1 deletion(-) diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 0cc8219b..d1f9b665 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -17,7 +17,7 @@ import warnings __all__ = ['StringField', 'IntField', 'FloatField', 'BooleanField', 'DateTimeField', 'EmbeddedDocumentField', 'ListField', 'DictField', - 'ObjectIdField', 'ReferenceField', 'ValidationError', + 'ObjectIdField', 'ReferenceField', 'ValidationError', 'MapField', 'DecimalField', 'URLField', 'GenericReferenceField', 'FileField', 'BinaryField', 'SortedListField', 'EmailField', 'GeoPointField'] @@ -451,6 +451,97 @@ class DictField(BaseField): def lookup_member(self, member_name): return self.basecls(db_field=member_name) + +class MapField(BaseField): + """A field that maps a name to a specified field type. Similar to + a DictField, except the 'value' of each item must match the specified + field type. + + .. versionadded:: 0.5 + """ + + def __init__(self, field=None, *args, **kwargs): + if not isinstance(field, BaseField): + raise ValidationError('Argument to MapField constructor must be ' + 'a valid field') + self.field = field + kwargs.setdefault('default', lambda: {}) + super(MapField, self).__init__(*args, **kwargs) + + def validate(self, value): + """Make sure that a list of valid fields is being used. + """ + if not isinstance(value, dict): + raise ValidationError('Only dictionaries may be used in a ' + 'DictField') + + if any(('.' in k or '$' in k) for k in value): + raise ValidationError('Invalid dictionary key name - keys may not ' + 'contain "." or "$" characters') + + try: + [self.field.validate(item) for item in value.values()] + except Exception, err: + raise ValidationError('Invalid MapField item (%s)' % str(item)) + + def __get__(self, instance, owner): + """Descriptor to automatically dereference references. + """ + if instance is None: + # Document class being used rather than a document object + return self + + if isinstance(self.field, ReferenceField): + referenced_type = self.field.document_type + # Get value from document instance if available + value_dict = instance._data.get(self.name) + if value_dict: + deref_dict = [] + for key,value in value_dict.iteritems(): + # Dereference DBRefs + if isinstance(value, (pymongo.dbref.DBRef)): + value = _get_db().dereference(value) + deref_dict[key] = referenced_type._from_son(value) + else: + deref_dict[key] = value + instance._data[self.name] = deref_dict + + if isinstance(self.field, GenericReferenceField): + value_dict = instance._data.get(self.name) + if value_dict: + deref_dict = [] + for key,value in value_dict.iteritems(): + # Dereference DBRefs + if isinstance(value, (dict, pymongo.son.SON)): + deref_dict[key] = self.field.dereference(value) + else: + deref_dict[key] = value + instance._data[self.name] = deref_dict + + return super(MapField, self).__get__(instance, owner) + + def to_python(self, value): + return dict( [(key,self.field.to_python(item)) for key,item in value.iteritems()] ) + + def to_mongo(self, value): + return dict( [(key,self.field.to_mongo(item)) for key,item in value.iteritems()] ) + + def prepare_query_value(self, op, value): + return self.field.prepare_query_value(op, value) + + def lookup_member(self, member_name): + return self.field.lookup_member(member_name) + + def _set_owner_document(self, owner_document): + self.field.owner_document = owner_document + self._owner_document = owner_document + + def _get_owner_document(self, owner_document): + self._owner_document = owner_document + + owner_document = property(_get_owner_document, _set_owner_document) + + class ReferenceField(BaseField): """A reference to a document that will be automatically dereferenced on access (lazily). diff --git a/tests/fields.py b/tests/fields.py index 38409b6a..62bd3a1f 100644 --- a/tests/fields.py +++ b/tests/fields.py @@ -825,5 +825,66 @@ class FieldTest(unittest.TestCase): self.assertEqual(d2.data, {}) self.assertEqual(d2.data2, {}) + def test_mapfield(self): + """Ensure that the MapField handles the declared type.""" + + class Simple(Document): + mapping = MapField(IntField()) + + Simple.drop_collection() + + e = Simple() + e.mapping['someint'] = 1 + e.save() + + def create_invalid_mapping(): + e.mapping['somestring'] = "abc" + e.save() + + self.assertRaises(ValidationError, create_invalid_mapping) + + def create_invalid_class(): + class NoDeclaredType(Document): + mapping = MapField() + + self.assertRaises(ValidationError, create_invalid_class) + + Simple.drop_collection() + + def test_complex_mapfield(self): + """Ensure that the MapField can handle complex declared types.""" + + class SettingBase(EmbeddedDocument): + pass + + class StringSetting(SettingBase): + value = StringField() + + class IntegerSetting(SettingBase): + value = IntField() + + class Extensible(Document): + mapping = MapField(EmbeddedDocumentField(SettingBase)) + + Extensible.drop_collection() + + e = Extensible() + e.mapping['somestring'] = StringSetting(value='foo') + e.mapping['someint'] = IntegerSetting(value=42) + e.save() + + e2 = Extensible.objects.get(id=e.id) + self.assertTrue(isinstance(e2.mapping['somestring'], StringSetting)) + self.assertTrue(isinstance(e2.mapping['someint'], IntegerSetting)) + + def create_invalid_mapping(): + e.mapping['someint'] = 123 + e.save() + + self.assertRaises(ValidationError, create_invalid_mapping) + + Extensible.drop_collection() + + if __name__ == '__main__': unittest.main()