From 4b9bacf7316275d2c0c1efa7b5850b98374679cc Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Mon, 6 Jun 2011 17:21:54 +0100 Subject: [PATCH] Added ComplexBaseField * Handles the efficient lazy dereferencing of DBrefs. * Handles complex nested values in ListFields and DictFields * Allows for both strictly declared ListFields and DictFields where the embedded value must be of a field type or no restrictions where the values can be a mix of field types / values. * Handles DBrefences of documents where allow_inheritance = False. --- mongoengine/base.py | 206 +++++++++++++++++++++------- mongoengine/fields.py | 102 +++----------- mongoengine/queryset.py | 47 +++++-- tests/dereference.py | 112 +++++++++++++++- tests/fields.py | 287 +++++++++++++++++++++++++++++++--------- 5 files changed, 555 insertions(+), 199 deletions(-) diff --git a/mongoengine/base.py b/mongoengine/base.py index 8a0a1f23..a22795c7 100644 --- a/mongoengine/base.py +++ b/mongoengine/base.py @@ -132,15 +132,19 @@ class BaseField(object): self.validate(value) -class DereferenceBaseField(BaseField): - """Handles the lazy dereferencing of a queryset. Will dereference all +class ComplexBaseField(BaseField): + """Handles complex fields, such as lists / dictionaries. + + Allows for nesting of embedded documents inside complex types. + Handles the lazy dereferencing of a queryset by lazily dereferencing all items in a list / dict rather than one at a time. """ + field = None + def __get__(self, instance, owner): """Descriptor to automatically dereference references. """ - from fields import ReferenceField, GenericReferenceField from connection import _get_db if instance is None: @@ -149,68 +153,175 @@ class DereferenceBaseField(BaseField): # Get value from document instance if available value_list = instance._data.get(self.name) - if not value_list: - return super(DereferenceBaseField, self).__get__(instance, owner) + if not value_list or isinstance(value_list, basestring): + return super(ComplexBaseField, self).__get__(instance, owner) is_list = False if not hasattr(value_list, 'items'): is_list = True value_list = dict([(k,v) for k,v in enumerate(value_list)]) - if isinstance(self.field, ReferenceField) and value_list: - db = _get_db() - dbref = {} - collections = {} + for k,v in value_list.items(): + if isinstance(v, dict) and '_cls' in v and '_ref' not in v: + value_list[k] = get_document(v['_cls'].split('.')[-1])._from_son(v) - for k, v in value_list.items(): - dbref[k] = v - # Save any DBRefs + # Handle all dereferencing + db = _get_db() + dbref = {} + collections = {} + for k, v in value_list.items(): + dbref[k] = v + # Save any DBRefs + if isinstance(v, (pymongo.dbref.DBRef)): + # direct reference (DBRef) + collections.setdefault(v.collection, []).append((k, v)) + elif isinstance(v, (dict, pymongo.son.SON)) and '_ref' in v: + # generic reference + collection = get_document(v['_cls'])._meta['collection'] + collections.setdefault(collection, []).append((k, v)) + + # For each collection get the references + for collection, dbrefs in collections.items(): + id_map = {} + for k, v in dbrefs: if isinstance(v, (pymongo.dbref.DBRef)): - collections.setdefault(v.collection, []).append((k, v)) + # direct reference (DBRef), has no _cls information + id_map[v.id] = (k, None) + elif isinstance(v, (dict, pymongo.son.SON)) and '_ref' in v: + # generic reference - includes _cls information + id_map[v['_ref'].id] = (k, get_document(v['_cls'])) - # For each collection get the references - for collection, dbrefs in collections.items(): - id_map = dict([(v.id, k) for k, v in dbrefs]) - references = db[collection].find({'_id': {'$in': id_map.keys()}}) - for ref in references: - key = id_map[ref['_id']] - dbref[key] = get_document(ref['_cls'])._from_son(ref) + references = db[collection].find({'_id': {'$in': id_map.keys()}}) + for ref in references: + key, doc_cls = id_map[ref['_id']] + if not doc_cls: # If no doc_cls get it from the referenced doc + doc_cls = get_document(ref['_cls']) + dbref[key] = doc_cls._from_son(ref) - if is_list: - dbref = [v for k,v in sorted(dbref.items(), key=itemgetter(0))] - instance._data[self.name] = dbref + if is_list: + dbref = [v for k,v in sorted(dbref.items(), key=itemgetter(0))] + instance._data[self.name] = dbref + return super(ComplexBaseField, self).__get__(instance, owner) - # Get value from document instance if available - if isinstance(self.field, GenericReferenceField) and value_list: - db = _get_db() - value_list = [(k,v) for k,v in value_list.items()] - dbref = {} - classes = {} + def to_python(self, value): + """Convert a MongoDB-compatible type to a Python type. + """ + from mongoengine import Document - for k, v in value_list: - dbref[k] = v - # Save any DBRefs - if isinstance(v, (dict, pymongo.son.SON)): - classes.setdefault(v['_cls'], []).append((k, v)) + if isinstance(value, basestring): + return value - # For each collection get the references - for doc_cls, dbrefs in classes.items(): - id_map = dict([(v['_ref'].id, k) for k, v in dbrefs]) - doc_cls = get_document(doc_cls) - collection = doc_cls._meta['collection'] - references = db[collection].find({'_id': {'$in': id_map.keys()}}) + if hasattr(value, 'to_python'): + return value.to_python() - for ref in references: - key = id_map[ref['_id']] - dbref[key] = doc_cls._from_son(ref) + is_list = False + if not hasattr(value, 'items'): + try: + is_list = True + value = dict([(k,v) for k,v in enumerate(value)]) + except TypeError: # Not iterable return the value + return value - if is_list: - dbref = [v for k,v in sorted(dbref.items(), key=itemgetter(0))] + if self.field: + value_dict = dict([(key, self.field.to_python(item)) for key, item in value.items()]) + else: + value_dict = {} + for k,v in value.items(): + if isinstance(v, Document): + # We need the id from the saved object to create the DBRef + if v.pk is None: + raise ValidationError('You can only reference documents once ' + 'they have been saved to the database') + collection = v._meta['collection'] + value_dict[k] = pymongo.dbref.DBRef(collection, v.pk) + elif hasattr(v, 'to_python'): + value_dict[k] = v.to_python() + else: + value_dict[k] = self.to_python(v) - instance._data[self.name] = dbref + if is_list: # Convert back to a list + return [v for k,v in sorted(value_dict.items(), key=itemgetter(0))] + return value_dict - return super(DereferenceBaseField, self).__get__(instance, owner) + def to_mongo(self, value): + """Convert a Python type to a MongoDB-compatible type. + """ + from mongoengine import Document + if isinstance(value, basestring): + return value + + if hasattr(value, 'to_mongo'): + return value.to_mongo() + + is_list = False + if not hasattr(value, 'items'): + try: + is_list = True + value = dict([(k,v) for k,v in enumerate(value)]) + except TypeError: # Not iterable return the value + return value + + if self.field: + value_dict = dict([(key, self.field.to_mongo(item)) for key, item in value.items()]) + else: + value_dict = {} + for k,v in value.items(): + if isinstance(v, Document): + # We need the id from the saved object to create the DBRef + if v.pk is None: + raise ValidationError('You can only reference documents once ' + 'they have been saved to the database') + + # If its a document that is not inheritable it won't have + # _types / _cls data so make it a generic reference allows + # us to dereference + meta = getattr(v, 'meta', getattr(v, '_meta', {})) + if meta and not meta['allow_inheritance'] and not self.field: + from fields import GenericReferenceField + value_dict[k] = GenericReferenceField().to_mongo(v) + else: + collection = v._meta['collection'] + value_dict[k] = pymongo.dbref.DBRef(collection, v.pk) + elif hasattr(v, 'to_mongo'): + value_dict[k] = v.to_mongo() + else: + value_dict[k] = self.to_mongo(v) + + if is_list: # Convert back to a list + return [v for k,v in sorted(value_dict.items(), key=itemgetter(0))] + return value_dict + + def validate(self, value): + """If field provided ensure the value is valid. + """ + if self.field: + try: + if hasattr(value, 'iteritems'): + [self.field.validate(v) for k,v in value.iteritems()] + else: + [self.field.validate(v) for v in value] + except Exception, err: + raise ValidationError('Invalid %s item (%s)' % ( + self.field.__class__.__name__, str(v))) + + def prepare_query_value(self, op, value): + return self.to_mongo(value) + + def lookup_member(self, member_name): + if self.field: + return self.field.lookup_member(member_name) + return None + + def _set_owner_document(self, owner_document): + if self.field: + self.field.owner_document = owner_document + self._owner_document = owner_document + + def _get_owner_document(self, owner_document): + self._owner_document = owner_document + + owner_document = property(_get_owner_document, _set_owner_document) class ObjectIdField(BaseField): @@ -219,7 +330,6 @@ class ObjectIdField(BaseField): def to_python(self, value): return value - # return unicode(value) def to_mongo(self, value): if not isinstance(value, pymongo.objectid.ObjectId): diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 1995d345..f9b2580b 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -1,4 +1,4 @@ -from base import (BaseField, DereferenceBaseField, ObjectIdField, +from base import (BaseField, ComplexBaseField, ObjectIdField, ValidationError, get_document) from queryset import DO_NOTHING from document import Document, EmbeddedDocument @@ -301,6 +301,8 @@ class EmbeddedDocumentField(BaseField): return value def to_mongo(self, value): + if isinstance(value, basestring): + return value return self.document_type.to_mongo(value) def validate(self, value): @@ -320,7 +322,7 @@ class EmbeddedDocumentField(BaseField): return self.to_mongo(value) -class ListField(DereferenceBaseField): +class ListField(ComplexBaseField): """A list field that wraps a standard field, allowing multiple instances of the field to be used as a list in the database. """ @@ -328,48 +330,25 @@ class ListField(DereferenceBaseField): # ListFields cannot be indexed with _types - MongoDB doesn't support this _index_with_types = False - def __init__(self, field, **kwargs): - if not isinstance(field, BaseField): - raise ValidationError('Argument to ListField constructor must be ' - 'a valid field') + def __init__(self, field=None, **kwargs): self.field = field kwargs.setdefault('default', lambda: []) super(ListField, self).__init__(**kwargs) - def to_python(self, value): - return [self.field.to_python(item) for item in value] - - def to_mongo(self, value): - return [self.field.to_mongo(item) for item in value] - def validate(self, value): """Make sure that a list of valid fields is being used. """ if not isinstance(value, (list, tuple)): raise ValidationError('Only lists and tuples may be used in a ' 'list field') - - try: - [self.field.validate(item) for item in value] - except Exception, err: - raise ValidationError('Invalid ListField item (%s)' % str(item)) + super(ListField, self).validate(value) def prepare_query_value(self, op, value): - if op in ('set', 'unset'): - return [self.field.prepare_query_value(op, v) for v in value] - return self.field.prepare_query_value(op, value) - - def lookup_member(self, member_name): - return self.field.lookup_member(member_name) - - def _set_owner_document(self, owner_document): - self.field.owner_document = owner_document - self._owner_document = owner_document - - def _get_owner_document(self, owner_document): - self._owner_document = owner_document - - owner_document = property(_get_owner_document, _set_owner_document) + if self.field: + if op in ('set', 'unset') and not isinstance(value, basestring): + return [self.field.prepare_query_value(op, v) for v in value] + return self.field.prepare_query_value(op, value) + return super(ListField, self).prepare_query_value(op, value) class SortedListField(ListField): @@ -388,20 +367,21 @@ class SortedListField(ListField): super(SortedListField, self).__init__(field, **kwargs) def to_mongo(self, value): + value = super(SortedListField, self).to_mongo(value) if self._ordering is not None: - return sorted([self.field.to_mongo(item) for item in value], - key=itemgetter(self._ordering)) - return sorted([self.field.to_mongo(item) for item in value]) + return sorted(value, key=itemgetter(self._ordering)) + return sorted(value) -class DictField(BaseField): +class DictField(ComplexBaseField): """A dictionary field that wraps a standard Python dictionary. This is similar to an embedded document, but the structure is not defined. .. versionadded:: 0.3 """ - def __init__(self, basecls=None, *args, **kwargs): + def __init__(self, basecls=None, field=None, *args, **kwargs): + self.field = field self.basecls = basecls or BaseField assert issubclass(self.basecls, BaseField) kwargs.setdefault('default', lambda: {}) @@ -417,6 +397,7 @@ class DictField(BaseField): if any(('.' in k or '$' in k) for k in value): raise ValidationError('Invalid dictionary key name - keys may not ' 'contain "." or "$" characters') + super(DictField, self).validate(value) def lookup_member(self, member_name): return DictField(basecls=self.basecls, db_field=member_name) @@ -432,7 +413,7 @@ class DictField(BaseField): return super(DictField, self).prepare_query_value(op, value) -class MapField(DereferenceBaseField): +class MapField(DictField): """A field that maps a name to a specified field type. Similar to a DictField, except the 'value' of each item must match the specified field type. @@ -444,50 +425,7 @@ class MapField(DereferenceBaseField): if not isinstance(field, BaseField): raise ValidationError('Argument to MapField constructor must be ' 'a valid field') - self.field = field - kwargs.setdefault('default', lambda: {}) - super(MapField, self).__init__(*args, **kwargs) - - def validate(self, value): - """Make sure that a list of valid fields is being used. - """ - if not isinstance(value, dict): - raise ValidationError('Only dictionaries may be used in a ' - 'DictField') - - if any(('.' in k or '$' in k) for k in value): - raise ValidationError('Invalid dictionary key name - keys may not ' - 'contain "." or "$" characters') - - try: - [self.field.validate(item) for item in value.values()] - except Exception, err: - raise ValidationError('Invalid MapField item (%s)' % str(item)) - - def to_python(self, value): - return dict([(key, self.field.to_python(item)) for key, item in value.iteritems()]) - - def to_mongo(self, value): - return dict([(key, self.field.to_mongo(item)) for key, item in value.iteritems()]) - - def prepare_query_value(self, op, value): - if op not in ('set', 'unset'): - return self.field.prepare_query_value(op, value) - for key in value: - value[key] = self.field.prepare_query_value(op, value[key]) - return value - - def lookup_member(self, member_name): - return self.field.lookup_member(member_name) - - def _set_owner_document(self, owner_document): - self.field.owner_document = owner_document - self._owner_document = owner_document - - def _get_owner_document(self, owner_document): - self._owner_document = owner_document - - owner_document = property(_get_owner_document, _set_owner_document) + super(MapField, self).__init__(field=field, *args, **kwargs) class ReferenceField(BaseField): diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index 1dfe55af..666567e2 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -549,11 +549,12 @@ class QuerySet(object): parts = [parts] fields = [] field = None + for field_name in parts: # Handle ListField indexing: if field_name.isdigit(): try: - field = field.field + new_field = field.field except AttributeError, err: raise InvalidQueryError( "Can't use index on unsubscriptable field (%s)" % err) @@ -567,11 +568,17 @@ class QuerySet(object): field = document._fields[field_name] else: # Look up subfield on the previous field - field = field.lookup_member(field_name) - if field is None: + new_field = field.lookup_member(field_name) + from base import ComplexBaseField + if not new_field and isinstance(field, ComplexBaseField): + fields.append(field_name) + continue + elif not new_field: raise InvalidQueryError('Cannot resolve field "%s"' - % field_name) + % field_name) + field = new_field # update field to the new field type fields.append(field) + return fields @classmethod @@ -615,14 +622,33 @@ class QuerySet(object): if _doc_cls: # Switch field names to proper names [set in Field(name='foo')] fields = QuerySet._lookup_field(_doc_cls, parts) - parts = [field.db_field for field in fields] + parts = [] + + cleaned_fields = [] + append_field = True + for field in fields: + if isinstance(field, str): + parts.append(field) + append_field = False + else: + parts.append(field.db_field) + if append_field: + cleaned_fields.append(field) # Convert value to proper value - field = fields[-1] + field = cleaned_fields[-1] + singular_ops = [None, 'ne', 'gt', 'gte', 'lt', 'lte', 'not'] singular_ops += match_operators if op in singular_ops: - value = field.prepare_query_value(op, value) + if isinstance(field, basestring): + if op in match_operators and isinstance(value, basestring): + from mongoengine import StringField + value = StringField().prepare_query_value(op, value) + else: + value = field + else: + value = field.prepare_query_value(op, value) elif op in ('in', 'nin', 'all', 'near'): # 'in', 'nin' and 'all' require a list of values value = [field.prepare_query_value(op, v) for v in value] @@ -1170,14 +1196,19 @@ class QuerySet(object): fields = QuerySet._lookup_field(_doc_cls, parts) parts = [] + cleaned_fields = [] + append_field = True for field in fields: if isinstance(field, str): parts.append(field) + append_field = False else: parts.append(field.db_field) + if append_field: + cleaned_fields.append(field) # Convert value to proper value - field = fields[-1] + field = cleaned_fields[-1] if op in (None, 'set', 'push', 'pull', 'addToSet'): value = field.prepare_query_value(op, value) diff --git a/tests/dereference.py b/tests/dereference.py index b6cee89e..68792721 100644 --- a/tests/dereference.py +++ b/tests/dereference.py @@ -122,6 +122,64 @@ class FieldTest(unittest.TestCase): [m for m in group_obj.members] self.assertEqual(q, 4) + for m in group_obj.members: + self.assertTrue('User' in m.__class__.__name__) + + UserA.drop_collection() + UserB.drop_collection() + UserC.drop_collection() + Group.drop_collection() + + def test_list_field_complex(self): + + class UserA(Document): + name = StringField() + + class UserB(Document): + name = StringField() + + class UserC(Document): + name = StringField() + + class Group(Document): + members = ListField() + + UserA.drop_collection() + UserB.drop_collection() + UserC.drop_collection() + Group.drop_collection() + + members = [] + for i in xrange(1, 51): + a = UserA(name='User A %s' % i) + a.save() + + b = UserB(name='User B %s' % i) + b.save() + + c = UserC(name='User C %s' % i) + c.save() + + members += [a, b, c] + + group = Group(members=members) + group.save() + + with query_counter() as q: + self.assertEqual(q, 0) + + group_obj = Group.objects.first() + self.assertEqual(q, 1) + + [m for m in group_obj.members] + self.assertEqual(q, 4) + + [m for m in group_obj.members] + self.assertEqual(q, 4) + + for m in group_obj.members: + self.assertTrue('User' in m.__class__.__name__) + UserA.drop_collection() UserB.drop_collection() UserC.drop_collection() @@ -156,10 +214,13 @@ class FieldTest(unittest.TestCase): [m for m in group_obj.members] self.assertEqual(q, 2) + for k, m in group_obj.members.iteritems(): + self.assertTrue(isinstance(m, User)) + User.drop_collection() Group.drop_collection() - def ztest_generic_reference_dict_field(self): + def test_dict_field(self): class UserA(Document): name = StringField() @@ -206,6 +267,9 @@ class FieldTest(unittest.TestCase): [m for m in group_obj.members] self.assertEqual(q, 4) + for k, m in group_obj.members.iteritems(): + self.assertTrue('User' in m.__class__.__name__) + group.members = {} group.save() @@ -218,11 +282,54 @@ class FieldTest(unittest.TestCase): [m for m in group_obj.members] self.assertEqual(q, 1) + for k, m in group_obj.members.iteritems(): + self.assertTrue('User' in m.__class__.__name__) + UserA.drop_collection() UserB.drop_collection() UserC.drop_collection() Group.drop_collection() + def test_dict_field_no_field_inheritance(self): + + class UserA(Document): + name = StringField() + meta = {'allow_inheritance': False} + + class Group(Document): + members = DictField() + + UserA.drop_collection() + Group.drop_collection() + + members = [] + for i in xrange(1, 51): + a = UserA(name='User A %s' % i) + a.save() + + members += [a] + + group = Group(members=dict([(str(u.id), u) for u in members])) + group.save() + + with query_counter() as q: + self.assertEqual(q, 0) + + group_obj = Group.objects.first() + self.assertEqual(q, 1) + + [m for m in group_obj.members] + self.assertEqual(q, 2) + + [m for m in group_obj.members] + self.assertEqual(q, 2) + + for k, m in group_obj.members.iteritems(): + self.assertTrue(isinstance(m, UserA)) + + UserA.drop_collection() + Group.drop_collection() + def test_generic_reference_map_field(self): class UserA(Document): @@ -270,6 +377,9 @@ class FieldTest(unittest.TestCase): [m for m in group_obj.members] self.assertEqual(q, 4) + for k, m in group_obj.members.iteritems(): + self.assertTrue('User' in m.__class__.__name__) + group.members = {} group.save() diff --git a/tests/fields.py b/tests/fields.py index d8970043..4d51ed51 100644 --- a/tests/fields.py +++ b/tests/fields.py @@ -322,6 +322,108 @@ class FieldTest(unittest.TestCase): BlogPost.drop_collection() + def test_list_field(self): + """Ensure that list types work as expected. + """ + class BlogPost(Document): + info = ListField() + + BlogPost.drop_collection() + + post = BlogPost() + post.info = 'my post' + self.assertRaises(ValidationError, post.validate) + + post.info = {'title': 'test'} + self.assertRaises(ValidationError, post.validate) + + post.info = ['test'] + post.save() + + post = BlogPost() + post.info = [{'test': 'test'}] + post.save() + + post = BlogPost() + post.info = [{'test': 3}] + post.save() + + + self.assertEquals(BlogPost.objects.count(), 3) + self.assertEquals(BlogPost.objects.filter(info__exact='test').count(), 1) + self.assertEquals(BlogPost.objects.filter(info__0__test='test').count(), 1) + + # Confirm handles non strings or non existing keys + self.assertEquals(BlogPost.objects.filter(info__0__test__exact='5').count(), 0) + self.assertEquals(BlogPost.objects.filter(info__100__test__exact='test').count(), 0) + BlogPost.drop_collection() + + def test_list_field_Strict(self): + """Ensure that list field handles validation if provided a strict field type.""" + + class Simple(Document): + mapping = ListField(field=IntField()) + + Simple.drop_collection() + + e = Simple() + e.mapping = [1] + e.save() + + def create_invalid_mapping(): + e.mapping = ["abc"] + e.save() + + self.assertRaises(ValidationError, create_invalid_mapping) + + Simple.drop_collection() + + def test_list_field_complex(self): + """Ensure that the list fields can handle the complex types.""" + + class SettingBase(EmbeddedDocument): + pass + + class StringSetting(SettingBase): + value = StringField() + + class IntegerSetting(SettingBase): + value = IntField() + + class Simple(Document): + mapping = ListField() + + Simple.drop_collection() + e = Simple() + e.mapping.append(StringSetting(value='foo')) + e.mapping.append(IntegerSetting(value=42)) + e.mapping.append({'number': 1, 'string': 'Hi!', 'float': 1.001, + 'complex': IntegerSetting(value=42), 'list': + [IntegerSetting(value=42), StringSetting(value='foo')]}) + e.save() + + e2 = Simple.objects.get(id=e.id) + self.assertTrue(isinstance(e2.mapping[0], StringSetting)) + self.assertTrue(isinstance(e2.mapping[1], IntegerSetting)) + + # Test querying + self.assertEquals(Simple.objects.filter(mapping__1__value=42).count(), 1) + self.assertEquals(Simple.objects.filter(mapping__2__number=1).count(), 1) + self.assertEquals(Simple.objects.filter(mapping__2__complex__value=42).count(), 1) + self.assertEquals(Simple.objects.filter(mapping__2__list__0__value=42).count(), 1) + self.assertEquals(Simple.objects.filter(mapping__2__list__1__value='foo').count(), 1) + + # Confirm can update + Simple.objects().update(set__mapping__1=IntegerSetting(value=10)) + self.assertEquals(Simple.objects.filter(mapping__1__value=10).count(), 1) + + Simple.objects().update( + set__mapping__2__list__1=StringSetting(value='Boo')) + self.assertEquals(Simple.objects.filter(mapping__2__list__1__value='foo').count(), 0) + self.assertEquals(Simple.objects.filter(mapping__2__list__1__value='Boo').count(), 1) + + Simple.drop_collection() + def test_dict_field(self): """Ensure that dict types work as expected. """ @@ -363,6 +465,131 @@ class FieldTest(unittest.TestCase): self.assertEquals(BlogPost.objects.filter(info__made_up__test__exact='test').count(), 0) BlogPost.drop_collection() + def test_dictfield_Strict(self): + """Ensure that dict field handles validation if provided a strict field type.""" + + class Simple(Document): + mapping = DictField(field=IntField()) + + Simple.drop_collection() + + e = Simple() + e.mapping['someint'] = 1 + e.save() + + def create_invalid_mapping(): + e.mapping['somestring'] = "abc" + e.save() + + self.assertRaises(ValidationError, create_invalid_mapping) + + Simple.drop_collection() + + def test_dictfield_complex(self): + """Ensure that the dict field can handle the complex types.""" + + class SettingBase(EmbeddedDocument): + pass + + class StringSetting(SettingBase): + value = StringField() + + class IntegerSetting(SettingBase): + value = IntField() + + class Simple(Document): + mapping = DictField() + + Simple.drop_collection() + e = Simple() + e.mapping['somestring'] = StringSetting(value='foo') + e.mapping['someint'] = IntegerSetting(value=42) + e.mapping['nested_dict'] = {'number': 1, 'string': 'Hi!', 'float': 1.001, + 'complex': IntegerSetting(value=42), 'list': + [IntegerSetting(value=42), StringSetting(value='foo')]} + e.save() + + e2 = Simple.objects.get(id=e.id) + self.assertTrue(isinstance(e2.mapping['somestring'], StringSetting)) + self.assertTrue(isinstance(e2.mapping['someint'], IntegerSetting)) + + # Test querying + self.assertEquals(Simple.objects.filter(mapping__someint__value=42).count(), 1) + self.assertEquals(Simple.objects.filter(mapping__nested_dict__number=1).count(), 1) + self.assertEquals(Simple.objects.filter(mapping__nested_dict__complex__value=42).count(), 1) + self.assertEquals(Simple.objects.filter(mapping__nested_dict__list__0__value=42).count(), 1) + self.assertEquals(Simple.objects.filter(mapping__nested_dict__list__1__value='foo').count(), 1) + + # Confirm can update + Simple.objects().update( + set__mapping={"someint": IntegerSetting(value=10)}) + Simple.objects().update( + set__mapping__nested_dict__list__1=StringSetting(value='Boo')) + self.assertEquals(Simple.objects.filter(mapping__nested_dict__list__1__value='foo').count(), 0) + self.assertEquals(Simple.objects.filter(mapping__nested_dict__list__1__value='Boo').count(), 1) + + Simple.drop_collection() + + def test_mapfield(self): + """Ensure that the MapField handles the declared type.""" + + class Simple(Document): + mapping = MapField(IntField()) + + Simple.drop_collection() + + e = Simple() + e.mapping['someint'] = 1 + e.save() + + def create_invalid_mapping(): + e.mapping['somestring'] = "abc" + e.save() + + self.assertRaises(ValidationError, create_invalid_mapping) + + def create_invalid_class(): + class NoDeclaredType(Document): + mapping = MapField() + + self.assertRaises(ValidationError, create_invalid_class) + + Simple.drop_collection() + + def test_complex_mapfield(self): + """Ensure that the MapField can handle complex declared types.""" + + class SettingBase(EmbeddedDocument): + pass + + class StringSetting(SettingBase): + value = StringField() + + class IntegerSetting(SettingBase): + value = IntField() + + class Extensible(Document): + mapping = MapField(EmbeddedDocumentField(SettingBase)) + + Extensible.drop_collection() + + e = Extensible() + e.mapping['somestring'] = StringSetting(value='foo') + e.mapping['someint'] = IntegerSetting(value=42) + e.save() + + e2 = Extensible.objects.get(id=e.id) + self.assertTrue(isinstance(e2.mapping['somestring'], StringSetting)) + self.assertTrue(isinstance(e2.mapping['someint'], IntegerSetting)) + + def create_invalid_mapping(): + e.mapping['someint'] = 123 + e.save() + + self.assertRaises(ValidationError, create_invalid_mapping) + + Extensible.drop_collection() + def test_embedded_document_validation(self): """Ensure that invalid embedded documents cannot be assigned to embedded document fields. @@ -933,66 +1160,6 @@ class FieldTest(unittest.TestCase): self.assertEqual(d2.data, {}) self.assertEqual(d2.data2, {}) - def test_mapfield(self): - """Ensure that the MapField handles the declared type.""" - - class Simple(Document): - mapping = MapField(IntField()) - - Simple.drop_collection() - - e = Simple() - e.mapping['someint'] = 1 - e.save() - - def create_invalid_mapping(): - e.mapping['somestring'] = "abc" - e.save() - - self.assertRaises(ValidationError, create_invalid_mapping) - - def create_invalid_class(): - class NoDeclaredType(Document): - mapping = MapField() - - self.assertRaises(ValidationError, create_invalid_class) - - Simple.drop_collection() - - def test_complex_mapfield(self): - """Ensure that the MapField can handle complex declared types.""" - - class SettingBase(EmbeddedDocument): - pass - - class StringSetting(SettingBase): - value = StringField() - - class IntegerSetting(SettingBase): - value = IntField() - - class Extensible(Document): - mapping = MapField(EmbeddedDocumentField(SettingBase)) - - Extensible.drop_collection() - - e = Extensible() - e.mapping['somestring'] = StringSetting(value='foo') - e.mapping['someint'] = IntegerSetting(value=42) - e.save() - - e2 = Extensible.objects.get(id=e.id) - self.assertTrue(isinstance(e2.mapping['somestring'], StringSetting)) - self.assertTrue(isinstance(e2.mapping['someint'], IntegerSetting)) - - def create_invalid_mapping(): - e.mapping['someint'] = 123 - e.save() - - self.assertRaises(ValidationError, create_invalid_mapping) - - Extensible.drop_collection() - if __name__ == '__main__': unittest.main()