Added ComplexBaseField
* Handles the efficient lazy dereferencing of DBrefs. * Handles complex nested values in ListFields and DictFields * Allows for both strictly declared ListFields and DictFields where the embedded value must be of a field type or no restrictions where the values can be a mix of field types / values. * Handles DBrefences of documents where allow_inheritance = False.
This commit is contained in:
		| @@ -132,15 +132,19 @@ class BaseField(object): | |||||||
|         self.validate(value) |         self.validate(value) | ||||||
|  |  | ||||||
|  |  | ||||||
| class DereferenceBaseField(BaseField): | class ComplexBaseField(BaseField): | ||||||
|     """Handles the lazy dereferencing of a queryset.  Will dereference all |     """Handles complex fields, such as lists / dictionaries. | ||||||
|  |  | ||||||
|  |     Allows for nesting of embedded documents inside complex types. | ||||||
|  |     Handles the lazy dereferencing of a queryset by lazily dereferencing all | ||||||
|     items in a list / dict rather than one at a time. |     items in a list / dict rather than one at a time. | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|  |     field = None | ||||||
|  |  | ||||||
|     def __get__(self, instance, owner): |     def __get__(self, instance, owner): | ||||||
|         """Descriptor to automatically dereference references. |         """Descriptor to automatically dereference references. | ||||||
|         """ |         """ | ||||||
|         from fields import ReferenceField, GenericReferenceField |  | ||||||
|         from connection import _get_db |         from connection import _get_db | ||||||
|  |  | ||||||
|         if instance is None: |         if instance is None: | ||||||
| @@ -149,68 +153,175 @@ class DereferenceBaseField(BaseField): | |||||||
|  |  | ||||||
|         # Get value from document instance if available |         # Get value from document instance if available | ||||||
|         value_list = instance._data.get(self.name) |         value_list = instance._data.get(self.name) | ||||||
|         if not value_list: |         if not value_list or isinstance(value_list, basestring): | ||||||
|             return super(DereferenceBaseField, self).__get__(instance, owner) |             return super(ComplexBaseField, self).__get__(instance, owner) | ||||||
|  |  | ||||||
|         is_list = False |         is_list = False | ||||||
|         if not hasattr(value_list, 'items'): |         if not hasattr(value_list, 'items'): | ||||||
|             is_list = True |             is_list = True | ||||||
|             value_list = dict([(k,v) for k,v in enumerate(value_list)]) |             value_list = dict([(k,v) for k,v in enumerate(value_list)]) | ||||||
|  |  | ||||||
|         if isinstance(self.field, ReferenceField) and value_list: |         for k,v in value_list.items(): | ||||||
|             db = _get_db() |             if isinstance(v, dict) and '_cls' in v and '_ref' not in v: | ||||||
|             dbref = {} |                 value_list[k] = get_document(v['_cls'].split('.')[-1])._from_son(v) | ||||||
|             collections = {} |  | ||||||
|  |  | ||||||
|             for k, v in value_list.items(): |         # Handle all dereferencing | ||||||
|                 dbref[k] = v |         db = _get_db() | ||||||
|                 # Save any DBRefs |         dbref = {} | ||||||
|  |         collections = {} | ||||||
|  |         for k, v in value_list.items(): | ||||||
|  |             dbref[k] = v | ||||||
|  |             # Save any DBRefs | ||||||
|  |             if isinstance(v, (pymongo.dbref.DBRef)): | ||||||
|  |                 # direct reference (DBRef) | ||||||
|  |                 collections.setdefault(v.collection, []).append((k, v)) | ||||||
|  |             elif isinstance(v, (dict, pymongo.son.SON)) and '_ref' in v: | ||||||
|  |                 # generic reference | ||||||
|  |                 collection =  get_document(v['_cls'])._meta['collection'] | ||||||
|  |                 collections.setdefault(collection, []).append((k, v)) | ||||||
|  |  | ||||||
|  |         # For each collection get the references | ||||||
|  |         for collection, dbrefs in collections.items(): | ||||||
|  |             id_map = {} | ||||||
|  |             for k, v in dbrefs: | ||||||
|                 if isinstance(v, (pymongo.dbref.DBRef)): |                 if isinstance(v, (pymongo.dbref.DBRef)): | ||||||
|                     collections.setdefault(v.collection, []).append((k, v)) |                     # direct reference (DBRef), has no _cls information | ||||||
|  |                     id_map[v.id] = (k, None) | ||||||
|  |                 elif isinstance(v, (dict, pymongo.son.SON)) and '_ref' in v: | ||||||
|  |                     # generic reference - includes _cls information | ||||||
|  |                     id_map[v['_ref'].id] = (k, get_document(v['_cls'])) | ||||||
|  |  | ||||||
|             # For each collection get the references |             references = db[collection].find({'_id': {'$in': id_map.keys()}}) | ||||||
|             for collection, dbrefs in collections.items(): |             for ref in references: | ||||||
|                 id_map = dict([(v.id, k) for k, v in dbrefs]) |                 key, doc_cls = id_map[ref['_id']] | ||||||
|                 references = db[collection].find({'_id': {'$in': id_map.keys()}}) |                 if not doc_cls:  # If no doc_cls get it from the referenced doc | ||||||
|                 for ref in references: |                     doc_cls = get_document(ref['_cls']) | ||||||
|                     key = id_map[ref['_id']] |                 dbref[key] = doc_cls._from_son(ref) | ||||||
|                     dbref[key] = get_document(ref['_cls'])._from_son(ref) |  | ||||||
|  |  | ||||||
|             if is_list: |         if is_list: | ||||||
|                 dbref = [v for k,v in sorted(dbref.items(), key=itemgetter(0))] |             dbref = [v for k,v in sorted(dbref.items(), key=itemgetter(0))] | ||||||
|             instance._data[self.name] = dbref |         instance._data[self.name] = dbref | ||||||
|  |         return super(ComplexBaseField, self).__get__(instance, owner) | ||||||
|  |  | ||||||
|         # Get value from document instance if available |     def to_python(self, value): | ||||||
|         if isinstance(self.field, GenericReferenceField) and value_list: |         """Convert a MongoDB-compatible type to a Python type. | ||||||
|             db = _get_db() |         """ | ||||||
|             value_list = [(k,v) for k,v in value_list.items()] |         from mongoengine import Document | ||||||
|             dbref = {} |  | ||||||
|             classes = {} |  | ||||||
|  |  | ||||||
|             for k, v in value_list: |         if isinstance(value, basestring): | ||||||
|                 dbref[k] = v |             return value | ||||||
|                 # Save any DBRefs |  | ||||||
|                 if isinstance(v, (dict, pymongo.son.SON)): |  | ||||||
|                     classes.setdefault(v['_cls'], []).append((k, v)) |  | ||||||
|  |  | ||||||
|             # For each collection get the references |         if hasattr(value, 'to_python'): | ||||||
|             for doc_cls, dbrefs in classes.items(): |             return value.to_python() | ||||||
|                 id_map = dict([(v['_ref'].id, k) for k, v in dbrefs]) |  | ||||||
|                 doc_cls = get_document(doc_cls) |  | ||||||
|                 collection = doc_cls._meta['collection'] |  | ||||||
|                 references = db[collection].find({'_id': {'$in': id_map.keys()}}) |  | ||||||
|  |  | ||||||
|                 for ref in references: |         is_list = False | ||||||
|                     key = id_map[ref['_id']] |         if not hasattr(value, 'items'): | ||||||
|                     dbref[key] = doc_cls._from_son(ref) |             try: | ||||||
|  |                 is_list = True | ||||||
|  |                 value = dict([(k,v) for k,v in enumerate(value)]) | ||||||
|  |             except TypeError:  # Not iterable return the value | ||||||
|  |                 return value | ||||||
|  |  | ||||||
|             if is_list: |         if self.field: | ||||||
|                 dbref = [v for k,v in sorted(dbref.items(), key=itemgetter(0))] |             value_dict = dict([(key, self.field.to_python(item)) for key, item in value.items()]) | ||||||
|  |         else: | ||||||
|  |             value_dict = {} | ||||||
|  |             for k,v in value.items(): | ||||||
|  |                 if isinstance(v, Document): | ||||||
|  |                     # We need the id from the saved object to create the DBRef | ||||||
|  |                     if v.pk is None: | ||||||
|  |                         raise ValidationError('You can only reference documents once ' | ||||||
|  |                                       'they have been saved to the database') | ||||||
|  |                     collection = v._meta['collection'] | ||||||
|  |                     value_dict[k] = pymongo.dbref.DBRef(collection, v.pk) | ||||||
|  |                 elif hasattr(v, 'to_python'): | ||||||
|  |                     value_dict[k] = v.to_python() | ||||||
|  |                 else: | ||||||
|  |                     value_dict[k] = self.to_python(v) | ||||||
|  |  | ||||||
|             instance._data[self.name] = dbref |         if is_list:  # Convert back to a list | ||||||
|  |             return [v for k,v in sorted(value_dict.items(), key=itemgetter(0))] | ||||||
|  |         return value_dict | ||||||
|  |  | ||||||
|         return super(DereferenceBaseField, self).__get__(instance, owner) |     def to_mongo(self, value): | ||||||
|  |         """Convert a Python type to a MongoDB-compatible type. | ||||||
|  |         """ | ||||||
|  |         from mongoengine import Document | ||||||
|  |  | ||||||
|  |         if isinstance(value, basestring): | ||||||
|  |             return value | ||||||
|  |  | ||||||
|  |         if hasattr(value, 'to_mongo'): | ||||||
|  |             return value.to_mongo() | ||||||
|  |  | ||||||
|  |         is_list = False | ||||||
|  |         if not hasattr(value, 'items'): | ||||||
|  |             try: | ||||||
|  |                 is_list = True | ||||||
|  |                 value = dict([(k,v) for k,v in enumerate(value)]) | ||||||
|  |             except TypeError:  # Not iterable return the value | ||||||
|  |                 return value | ||||||
|  |  | ||||||
|  |         if self.field: | ||||||
|  |             value_dict = dict([(key, self.field.to_mongo(item)) for key, item in value.items()]) | ||||||
|  |         else: | ||||||
|  |             value_dict = {} | ||||||
|  |             for k,v in value.items(): | ||||||
|  |                 if isinstance(v, Document): | ||||||
|  |                     # We need the id from the saved object to create the DBRef | ||||||
|  |                     if v.pk is None: | ||||||
|  |                         raise ValidationError('You can only reference documents once ' | ||||||
|  |                                       'they have been saved to the database') | ||||||
|  |  | ||||||
|  |                     # If its a document that is not inheritable it won't have | ||||||
|  |                     # _types / _cls data so make it a generic reference allows | ||||||
|  |                     # us to dereference | ||||||
|  |                     meta = getattr(v, 'meta', getattr(v, '_meta', {})) | ||||||
|  |                     if meta and not meta['allow_inheritance'] and not self.field: | ||||||
|  |                         from fields import GenericReferenceField | ||||||
|  |                         value_dict[k] = GenericReferenceField().to_mongo(v) | ||||||
|  |                     else: | ||||||
|  |                         collection = v._meta['collection'] | ||||||
|  |                         value_dict[k] = pymongo.dbref.DBRef(collection, v.pk) | ||||||
|  |                 elif hasattr(v, 'to_mongo'): | ||||||
|  |                     value_dict[k] = v.to_mongo() | ||||||
|  |                 else: | ||||||
|  |                     value_dict[k] = self.to_mongo(v) | ||||||
|  |  | ||||||
|  |         if is_list:  # Convert back to a list | ||||||
|  |             return [v for k,v in sorted(value_dict.items(), key=itemgetter(0))] | ||||||
|  |         return value_dict | ||||||
|  |  | ||||||
|  |     def validate(self, value): | ||||||
|  |         """If field provided ensure the value is valid. | ||||||
|  |         """ | ||||||
|  |         if self.field: | ||||||
|  |             try: | ||||||
|  |                 if hasattr(value, 'iteritems'): | ||||||
|  |                     [self.field.validate(v) for k,v in value.iteritems()] | ||||||
|  |                 else: | ||||||
|  |                      [self.field.validate(v) for v in value] | ||||||
|  |             except Exception, err: | ||||||
|  |                 raise ValidationError('Invalid %s item (%s)' % ( | ||||||
|  |                         self.field.__class__.__name__, str(v))) | ||||||
|  |  | ||||||
|  |     def prepare_query_value(self, op, value): | ||||||
|  |         return self.to_mongo(value) | ||||||
|  |  | ||||||
|  |     def lookup_member(self, member_name): | ||||||
|  |         if self.field: | ||||||
|  |             return self.field.lookup_member(member_name) | ||||||
|  |         return None | ||||||
|  |  | ||||||
|  |     def _set_owner_document(self, owner_document): | ||||||
|  |         if self.field: | ||||||
|  |             self.field.owner_document = owner_document | ||||||
|  |         self._owner_document = owner_document | ||||||
|  |  | ||||||
|  |     def _get_owner_document(self, owner_document): | ||||||
|  |         self._owner_document = owner_document | ||||||
|  |  | ||||||
|  |     owner_document = property(_get_owner_document, _set_owner_document) | ||||||
|  |  | ||||||
|  |  | ||||||
| class ObjectIdField(BaseField): | class ObjectIdField(BaseField): | ||||||
| @@ -219,7 +330,6 @@ class ObjectIdField(BaseField): | |||||||
|  |  | ||||||
|     def to_python(self, value): |     def to_python(self, value): | ||||||
|         return value |         return value | ||||||
|         # return unicode(value) |  | ||||||
|  |  | ||||||
|     def to_mongo(self, value): |     def to_mongo(self, value): | ||||||
|         if not isinstance(value, pymongo.objectid.ObjectId): |         if not isinstance(value, pymongo.objectid.ObjectId): | ||||||
|   | |||||||
| @@ -1,4 +1,4 @@ | |||||||
| from base import (BaseField, DereferenceBaseField, ObjectIdField, | from base import (BaseField, ComplexBaseField, ObjectIdField, | ||||||
|                   ValidationError, get_document) |                   ValidationError, get_document) | ||||||
| from queryset import DO_NOTHING | from queryset import DO_NOTHING | ||||||
| from document import Document, EmbeddedDocument | from document import Document, EmbeddedDocument | ||||||
| @@ -301,6 +301,8 @@ class EmbeddedDocumentField(BaseField): | |||||||
|         return value |         return value | ||||||
|  |  | ||||||
|     def to_mongo(self, value): |     def to_mongo(self, value): | ||||||
|  |         if isinstance(value, basestring): | ||||||
|  |             return value | ||||||
|         return self.document_type.to_mongo(value) |         return self.document_type.to_mongo(value) | ||||||
|  |  | ||||||
|     def validate(self, value): |     def validate(self, value): | ||||||
| @@ -320,7 +322,7 @@ class EmbeddedDocumentField(BaseField): | |||||||
|         return self.to_mongo(value) |         return self.to_mongo(value) | ||||||
|  |  | ||||||
|  |  | ||||||
| class ListField(DereferenceBaseField): | class ListField(ComplexBaseField): | ||||||
|     """A list field that wraps a standard field, allowing multiple instances |     """A list field that wraps a standard field, allowing multiple instances | ||||||
|     of the field to be used as a list in the database. |     of the field to be used as a list in the database. | ||||||
|     """ |     """ | ||||||
| @@ -328,48 +330,25 @@ class ListField(DereferenceBaseField): | |||||||
|     # ListFields cannot be indexed with _types - MongoDB doesn't support this |     # ListFields cannot be indexed with _types - MongoDB doesn't support this | ||||||
|     _index_with_types = False |     _index_with_types = False | ||||||
|  |  | ||||||
|     def __init__(self, field, **kwargs): |     def __init__(self, field=None, **kwargs): | ||||||
|         if not isinstance(field, BaseField): |  | ||||||
|             raise ValidationError('Argument to ListField constructor must be ' |  | ||||||
|                                   'a valid field') |  | ||||||
|         self.field = field |         self.field = field | ||||||
|         kwargs.setdefault('default', lambda: []) |         kwargs.setdefault('default', lambda: []) | ||||||
|         super(ListField, self).__init__(**kwargs) |         super(ListField, self).__init__(**kwargs) | ||||||
|  |  | ||||||
|     def to_python(self, value): |  | ||||||
|         return [self.field.to_python(item) for item in value] |  | ||||||
|  |  | ||||||
|     def to_mongo(self, value): |  | ||||||
|         return [self.field.to_mongo(item) for item in value] |  | ||||||
|  |  | ||||||
|     def validate(self, value): |     def validate(self, value): | ||||||
|         """Make sure that a list of valid fields is being used. |         """Make sure that a list of valid fields is being used. | ||||||
|         """ |         """ | ||||||
|         if not isinstance(value, (list, tuple)): |         if not isinstance(value, (list, tuple)): | ||||||
|             raise ValidationError('Only lists and tuples may be used in a ' |             raise ValidationError('Only lists and tuples may be used in a ' | ||||||
|                                   'list field') |                                   'list field') | ||||||
|  |         super(ListField, self).validate(value) | ||||||
|         try: |  | ||||||
|             [self.field.validate(item) for item in value] |  | ||||||
|         except Exception, err: |  | ||||||
|             raise ValidationError('Invalid ListField item (%s)' % str(item)) |  | ||||||
|  |  | ||||||
|     def prepare_query_value(self, op, value): |     def prepare_query_value(self, op, value): | ||||||
|         if op in ('set', 'unset'): |         if self.field: | ||||||
|             return [self.field.prepare_query_value(op, v) for v in value] |             if op in ('set', 'unset') and not isinstance(value, basestring): | ||||||
|         return self.field.prepare_query_value(op, value) |                 return [self.field.prepare_query_value(op, v) for v in value] | ||||||
|  |             return self.field.prepare_query_value(op, value) | ||||||
|     def lookup_member(self, member_name): |         return super(ListField, self).prepare_query_value(op, value) | ||||||
|         return self.field.lookup_member(member_name) |  | ||||||
|  |  | ||||||
|     def _set_owner_document(self, owner_document): |  | ||||||
|         self.field.owner_document = owner_document |  | ||||||
|         self._owner_document = owner_document |  | ||||||
|  |  | ||||||
|     def _get_owner_document(self, owner_document): |  | ||||||
|         self._owner_document = owner_document |  | ||||||
|  |  | ||||||
|     owner_document = property(_get_owner_document, _set_owner_document) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class SortedListField(ListField): | class SortedListField(ListField): | ||||||
| @@ -388,20 +367,21 @@ class SortedListField(ListField): | |||||||
|         super(SortedListField, self).__init__(field, **kwargs) |         super(SortedListField, self).__init__(field, **kwargs) | ||||||
|  |  | ||||||
|     def to_mongo(self, value): |     def to_mongo(self, value): | ||||||
|  |         value = super(SortedListField, self).to_mongo(value) | ||||||
|         if self._ordering is not None: |         if self._ordering is not None: | ||||||
|             return sorted([self.field.to_mongo(item) for item in value], |             return sorted(value, key=itemgetter(self._ordering)) | ||||||
|                           key=itemgetter(self._ordering)) |         return sorted(value) | ||||||
|         return sorted([self.field.to_mongo(item) for item in value]) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class DictField(BaseField): | class DictField(ComplexBaseField): | ||||||
|     """A dictionary field that wraps a standard Python dictionary. This is |     """A dictionary field that wraps a standard Python dictionary. This is | ||||||
|     similar to an embedded document, but the structure is not defined. |     similar to an embedded document, but the structure is not defined. | ||||||
|  |  | ||||||
|     .. versionadded:: 0.3 |     .. versionadded:: 0.3 | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     def __init__(self, basecls=None, *args, **kwargs): |     def __init__(self, basecls=None, field=None, *args, **kwargs): | ||||||
|  |         self.field = field | ||||||
|         self.basecls = basecls or BaseField |         self.basecls = basecls or BaseField | ||||||
|         assert issubclass(self.basecls, BaseField) |         assert issubclass(self.basecls, BaseField) | ||||||
|         kwargs.setdefault('default', lambda: {}) |         kwargs.setdefault('default', lambda: {}) | ||||||
| @@ -417,6 +397,7 @@ class DictField(BaseField): | |||||||
|         if any(('.' in k or '$' in k) for k in value): |         if any(('.' in k or '$' in k) for k in value): | ||||||
|             raise ValidationError('Invalid dictionary key name - keys may not ' |             raise ValidationError('Invalid dictionary key name - keys may not ' | ||||||
|                                   'contain "." or "$" characters') |                                   'contain "." or "$" characters') | ||||||
|  |         super(DictField, self).validate(value) | ||||||
|  |  | ||||||
|     def lookup_member(self, member_name): |     def lookup_member(self, member_name): | ||||||
|         return DictField(basecls=self.basecls, db_field=member_name) |         return DictField(basecls=self.basecls, db_field=member_name) | ||||||
| @@ -432,7 +413,7 @@ class DictField(BaseField): | |||||||
|         return super(DictField, self).prepare_query_value(op, value) |         return super(DictField, self).prepare_query_value(op, value) | ||||||
|  |  | ||||||
|  |  | ||||||
| class MapField(DereferenceBaseField): | class MapField(DictField): | ||||||
|     """A field that maps a name to a specified field type. Similar to |     """A field that maps a name to a specified field type. Similar to | ||||||
|     a DictField, except the 'value' of each item must match the specified |     a DictField, except the 'value' of each item must match the specified | ||||||
|     field type. |     field type. | ||||||
| @@ -444,50 +425,7 @@ class MapField(DereferenceBaseField): | |||||||
|         if not isinstance(field, BaseField): |         if not isinstance(field, BaseField): | ||||||
|             raise ValidationError('Argument to MapField constructor must be ' |             raise ValidationError('Argument to MapField constructor must be ' | ||||||
|                                   'a valid field') |                                   'a valid field') | ||||||
|         self.field = field |         super(MapField, self).__init__(field=field, *args, **kwargs) | ||||||
|         kwargs.setdefault('default', lambda: {}) |  | ||||||
|         super(MapField, self).__init__(*args, **kwargs) |  | ||||||
|  |  | ||||||
|     def validate(self, value): |  | ||||||
|         """Make sure that a list of valid fields is being used. |  | ||||||
|         """ |  | ||||||
|         if not isinstance(value, dict): |  | ||||||
|             raise ValidationError('Only dictionaries may be used in a ' |  | ||||||
|                                   'DictField') |  | ||||||
|  |  | ||||||
|         if any(('.' in k or '$' in k) for k in value): |  | ||||||
|             raise ValidationError('Invalid dictionary key name - keys may not ' |  | ||||||
|                                   'contain "." or "$" characters') |  | ||||||
|  |  | ||||||
|         try: |  | ||||||
|             [self.field.validate(item) for item in value.values()] |  | ||||||
|         except Exception, err: |  | ||||||
|             raise ValidationError('Invalid MapField item (%s)' % str(item)) |  | ||||||
|  |  | ||||||
|     def to_python(self, value): |  | ||||||
|         return dict([(key, self.field.to_python(item)) for key, item in value.iteritems()]) |  | ||||||
|  |  | ||||||
|     def to_mongo(self, value): |  | ||||||
|         return dict([(key, self.field.to_mongo(item)) for key, item in value.iteritems()]) |  | ||||||
|  |  | ||||||
|     def prepare_query_value(self, op, value): |  | ||||||
|         if op not in ('set', 'unset'): |  | ||||||
|             return self.field.prepare_query_value(op, value) |  | ||||||
|         for key in value: |  | ||||||
|             value[key] = self.field.prepare_query_value(op, value[key]) |  | ||||||
|         return value |  | ||||||
|  |  | ||||||
|     def lookup_member(self, member_name): |  | ||||||
|         return self.field.lookup_member(member_name) |  | ||||||
|  |  | ||||||
|     def _set_owner_document(self, owner_document): |  | ||||||
|         self.field.owner_document = owner_document |  | ||||||
|         self._owner_document = owner_document |  | ||||||
|  |  | ||||||
|     def _get_owner_document(self, owner_document): |  | ||||||
|         self._owner_document = owner_document |  | ||||||
|  |  | ||||||
|     owner_document = property(_get_owner_document, _set_owner_document) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class ReferenceField(BaseField): | class ReferenceField(BaseField): | ||||||
|   | |||||||
| @@ -549,11 +549,12 @@ class QuerySet(object): | |||||||
|             parts = [parts] |             parts = [parts] | ||||||
|         fields = [] |         fields = [] | ||||||
|         field = None |         field = None | ||||||
|  |  | ||||||
|         for field_name in parts: |         for field_name in parts: | ||||||
|             # Handle ListField indexing: |             # Handle ListField indexing: | ||||||
|             if field_name.isdigit(): |             if field_name.isdigit(): | ||||||
|                 try: |                 try: | ||||||
|                     field = field.field |                     new_field = field.field | ||||||
|                 except AttributeError, err: |                 except AttributeError, err: | ||||||
|                     raise InvalidQueryError( |                     raise InvalidQueryError( | ||||||
|                         "Can't use index on unsubscriptable field (%s)" % err) |                         "Can't use index on unsubscriptable field (%s)" % err) | ||||||
| @@ -567,11 +568,17 @@ class QuerySet(object): | |||||||
|                 field = document._fields[field_name] |                 field = document._fields[field_name] | ||||||
|             else: |             else: | ||||||
|                 # Look up subfield on the previous field |                 # Look up subfield on the previous field | ||||||
|                 field = field.lookup_member(field_name) |                 new_field = field.lookup_member(field_name) | ||||||
|                 if field is None: |                 from base import ComplexBaseField | ||||||
|  |                 if not new_field and isinstance(field, ComplexBaseField): | ||||||
|  |                     fields.append(field_name) | ||||||
|  |                     continue | ||||||
|  |                 elif not new_field: | ||||||
|                     raise InvalidQueryError('Cannot resolve field "%s"' |                     raise InvalidQueryError('Cannot resolve field "%s"' | ||||||
|                                             % field_name) |                                                 % field_name) | ||||||
|  |                 field = new_field  # update field to the new field type | ||||||
|             fields.append(field) |             fields.append(field) | ||||||
|  |  | ||||||
|         return fields |         return fields | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
| @@ -615,14 +622,33 @@ class QuerySet(object): | |||||||
|             if _doc_cls: |             if _doc_cls: | ||||||
|                 # Switch field names to proper names [set in Field(name='foo')] |                 # Switch field names to proper names [set in Field(name='foo')] | ||||||
|                 fields = QuerySet._lookup_field(_doc_cls, parts) |                 fields = QuerySet._lookup_field(_doc_cls, parts) | ||||||
|                 parts = [field.db_field for field in fields] |                 parts = [] | ||||||
|  |  | ||||||
|  |                 cleaned_fields = [] | ||||||
|  |                 append_field = True | ||||||
|  |                 for field in fields: | ||||||
|  |                     if isinstance(field, str): | ||||||
|  |                         parts.append(field) | ||||||
|  |                         append_field = False | ||||||
|  |                     else: | ||||||
|  |                         parts.append(field.db_field) | ||||||
|  |                     if append_field: | ||||||
|  |                         cleaned_fields.append(field) | ||||||
|  |  | ||||||
|                 # Convert value to proper value |                 # Convert value to proper value | ||||||
|                 field = fields[-1] |                 field = cleaned_fields[-1] | ||||||
|  |  | ||||||
|                 singular_ops = [None, 'ne', 'gt', 'gte', 'lt', 'lte', 'not'] |                 singular_ops = [None, 'ne', 'gt', 'gte', 'lt', 'lte', 'not'] | ||||||
|                 singular_ops += match_operators |                 singular_ops += match_operators | ||||||
|                 if op in singular_ops: |                 if op in singular_ops: | ||||||
|                     value = field.prepare_query_value(op, value) |                     if isinstance(field, basestring): | ||||||
|  |                         if op in match_operators and isinstance(value, basestring): | ||||||
|  |                             from mongoengine import StringField | ||||||
|  |                             value = StringField().prepare_query_value(op, value) | ||||||
|  |                         else: | ||||||
|  |                             value = field | ||||||
|  |                     else: | ||||||
|  |                         value = field.prepare_query_value(op, value) | ||||||
|                 elif op in ('in', 'nin', 'all', 'near'): |                 elif op in ('in', 'nin', 'all', 'near'): | ||||||
|                     # 'in', 'nin' and 'all' require a list of values |                     # 'in', 'nin' and 'all' require a list of values | ||||||
|                     value = [field.prepare_query_value(op, v) for v in value] |                     value = [field.prepare_query_value(op, v) for v in value] | ||||||
| @@ -1170,14 +1196,19 @@ class QuerySet(object): | |||||||
|                 fields = QuerySet._lookup_field(_doc_cls, parts) |                 fields = QuerySet._lookup_field(_doc_cls, parts) | ||||||
|                 parts = [] |                 parts = [] | ||||||
|  |  | ||||||
|  |                 cleaned_fields = [] | ||||||
|  |                 append_field = True | ||||||
|                 for field in fields: |                 for field in fields: | ||||||
|                     if isinstance(field, str): |                     if isinstance(field, str): | ||||||
|                         parts.append(field) |                         parts.append(field) | ||||||
|  |                         append_field = False | ||||||
|                     else: |                     else: | ||||||
|                         parts.append(field.db_field) |                         parts.append(field.db_field) | ||||||
|  |                     if append_field: | ||||||
|  |                         cleaned_fields.append(field) | ||||||
|  |  | ||||||
|                 # Convert value to proper value |                 # Convert value to proper value | ||||||
|                 field = fields[-1] |                 field = cleaned_fields[-1] | ||||||
|  |  | ||||||
|                 if op in (None, 'set', 'push', 'pull', 'addToSet'): |                 if op in (None, 'set', 'push', 'pull', 'addToSet'): | ||||||
|                     value = field.prepare_query_value(op, value) |                     value = field.prepare_query_value(op, value) | ||||||
|   | |||||||
| @@ -122,6 +122,64 @@ class FieldTest(unittest.TestCase): | |||||||
|             [m for m in group_obj.members] |             [m for m in group_obj.members] | ||||||
|             self.assertEqual(q, 4) |             self.assertEqual(q, 4) | ||||||
|  |  | ||||||
|  |             for m in group_obj.members: | ||||||
|  |                 self.assertTrue('User' in m.__class__.__name__) | ||||||
|  |  | ||||||
|  |         UserA.drop_collection() | ||||||
|  |         UserB.drop_collection() | ||||||
|  |         UserC.drop_collection() | ||||||
|  |         Group.drop_collection() | ||||||
|  |  | ||||||
|  |     def test_list_field_complex(self): | ||||||
|  |  | ||||||
|  |         class UserA(Document): | ||||||
|  |             name = StringField() | ||||||
|  |  | ||||||
|  |         class UserB(Document): | ||||||
|  |             name = StringField() | ||||||
|  |  | ||||||
|  |         class UserC(Document): | ||||||
|  |             name = StringField() | ||||||
|  |  | ||||||
|  |         class Group(Document): | ||||||
|  |             members = ListField() | ||||||
|  |  | ||||||
|  |         UserA.drop_collection() | ||||||
|  |         UserB.drop_collection() | ||||||
|  |         UserC.drop_collection() | ||||||
|  |         Group.drop_collection() | ||||||
|  |  | ||||||
|  |         members = [] | ||||||
|  |         for i in xrange(1, 51): | ||||||
|  |             a = UserA(name='User A %s' % i) | ||||||
|  |             a.save() | ||||||
|  |  | ||||||
|  |             b = UserB(name='User B %s' % i) | ||||||
|  |             b.save() | ||||||
|  |  | ||||||
|  |             c = UserC(name='User C %s' % i) | ||||||
|  |             c.save() | ||||||
|  |  | ||||||
|  |             members += [a, b, c] | ||||||
|  |  | ||||||
|  |         group = Group(members=members) | ||||||
|  |         group.save() | ||||||
|  |  | ||||||
|  |         with query_counter() as q: | ||||||
|  |             self.assertEqual(q, 0) | ||||||
|  |  | ||||||
|  |             group_obj = Group.objects.first() | ||||||
|  |             self.assertEqual(q, 1) | ||||||
|  |  | ||||||
|  |             [m for m in group_obj.members] | ||||||
|  |             self.assertEqual(q, 4) | ||||||
|  |  | ||||||
|  |             [m for m in group_obj.members] | ||||||
|  |             self.assertEqual(q, 4) | ||||||
|  |  | ||||||
|  |             for m in group_obj.members: | ||||||
|  |                 self.assertTrue('User' in m.__class__.__name__) | ||||||
|  |  | ||||||
|         UserA.drop_collection() |         UserA.drop_collection() | ||||||
|         UserB.drop_collection() |         UserB.drop_collection() | ||||||
|         UserC.drop_collection() |         UserC.drop_collection() | ||||||
| @@ -156,10 +214,13 @@ class FieldTest(unittest.TestCase): | |||||||
|             [m for m in group_obj.members] |             [m for m in group_obj.members] | ||||||
|             self.assertEqual(q, 2) |             self.assertEqual(q, 2) | ||||||
|  |  | ||||||
|  |             for k, m in group_obj.members.iteritems(): | ||||||
|  |                 self.assertTrue(isinstance(m, User)) | ||||||
|  |  | ||||||
|         User.drop_collection() |         User.drop_collection() | ||||||
|         Group.drop_collection() |         Group.drop_collection() | ||||||
|  |  | ||||||
|     def ztest_generic_reference_dict_field(self): |     def test_dict_field(self): | ||||||
|  |  | ||||||
|         class UserA(Document): |         class UserA(Document): | ||||||
|             name = StringField() |             name = StringField() | ||||||
| @@ -206,6 +267,9 @@ class FieldTest(unittest.TestCase): | |||||||
|             [m for m in group_obj.members] |             [m for m in group_obj.members] | ||||||
|             self.assertEqual(q, 4) |             self.assertEqual(q, 4) | ||||||
|  |  | ||||||
|  |             for k, m in group_obj.members.iteritems(): | ||||||
|  |                 self.assertTrue('User' in m.__class__.__name__) | ||||||
|  |  | ||||||
|         group.members = {} |         group.members = {} | ||||||
|         group.save() |         group.save() | ||||||
|  |  | ||||||
| @@ -218,11 +282,54 @@ class FieldTest(unittest.TestCase): | |||||||
|             [m for m in group_obj.members] |             [m for m in group_obj.members] | ||||||
|             self.assertEqual(q, 1) |             self.assertEqual(q, 1) | ||||||
|  |  | ||||||
|  |             for k, m in group_obj.members.iteritems(): | ||||||
|  |                 self.assertTrue('User' in m.__class__.__name__) | ||||||
|  |  | ||||||
|         UserA.drop_collection() |         UserA.drop_collection() | ||||||
|         UserB.drop_collection() |         UserB.drop_collection() | ||||||
|         UserC.drop_collection() |         UserC.drop_collection() | ||||||
|         Group.drop_collection() |         Group.drop_collection() | ||||||
|  |  | ||||||
|  |     def test_dict_field_no_field_inheritance(self): | ||||||
|  |  | ||||||
|  |         class UserA(Document): | ||||||
|  |             name = StringField() | ||||||
|  |             meta = {'allow_inheritance': False} | ||||||
|  |  | ||||||
|  |         class Group(Document): | ||||||
|  |             members = DictField() | ||||||
|  |  | ||||||
|  |         UserA.drop_collection() | ||||||
|  |         Group.drop_collection() | ||||||
|  |  | ||||||
|  |         members = [] | ||||||
|  |         for i in xrange(1, 51): | ||||||
|  |             a = UserA(name='User A %s' % i) | ||||||
|  |             a.save() | ||||||
|  |  | ||||||
|  |             members += [a] | ||||||
|  |  | ||||||
|  |         group = Group(members=dict([(str(u.id), u) for u in members])) | ||||||
|  |         group.save() | ||||||
|  |  | ||||||
|  |         with query_counter() as q: | ||||||
|  |             self.assertEqual(q, 0) | ||||||
|  |  | ||||||
|  |             group_obj = Group.objects.first() | ||||||
|  |             self.assertEqual(q, 1) | ||||||
|  |  | ||||||
|  |             [m for m in group_obj.members] | ||||||
|  |             self.assertEqual(q, 2) | ||||||
|  |  | ||||||
|  |             [m for m in group_obj.members] | ||||||
|  |             self.assertEqual(q, 2) | ||||||
|  |  | ||||||
|  |             for k, m in group_obj.members.iteritems(): | ||||||
|  |                 self.assertTrue(isinstance(m, UserA)) | ||||||
|  |  | ||||||
|  |         UserA.drop_collection() | ||||||
|  |         Group.drop_collection() | ||||||
|  |  | ||||||
|     def test_generic_reference_map_field(self): |     def test_generic_reference_map_field(self): | ||||||
|  |  | ||||||
|         class UserA(Document): |         class UserA(Document): | ||||||
| @@ -270,6 +377,9 @@ class FieldTest(unittest.TestCase): | |||||||
|             [m for m in group_obj.members] |             [m for m in group_obj.members] | ||||||
|             self.assertEqual(q, 4) |             self.assertEqual(q, 4) | ||||||
|  |  | ||||||
|  |             for k, m in group_obj.members.iteritems(): | ||||||
|  |                 self.assertTrue('User' in m.__class__.__name__) | ||||||
|  |  | ||||||
|         group.members = {} |         group.members = {} | ||||||
|         group.save() |         group.save() | ||||||
|  |  | ||||||
|   | |||||||
							
								
								
									
										287
									
								
								tests/fields.py
									
									
									
									
									
								
							
							
						
						
									
										287
									
								
								tests/fields.py
									
									
									
									
									
								
							| @@ -322,6 +322,108 @@ class FieldTest(unittest.TestCase): | |||||||
|  |  | ||||||
|         BlogPost.drop_collection() |         BlogPost.drop_collection() | ||||||
|  |  | ||||||
|  |     def test_list_field(self): | ||||||
|  |         """Ensure that list types work as expected. | ||||||
|  |         """ | ||||||
|  |         class BlogPost(Document): | ||||||
|  |             info = ListField() | ||||||
|  |  | ||||||
|  |         BlogPost.drop_collection() | ||||||
|  |  | ||||||
|  |         post = BlogPost() | ||||||
|  |         post.info = 'my post' | ||||||
|  |         self.assertRaises(ValidationError, post.validate) | ||||||
|  |  | ||||||
|  |         post.info = {'title': 'test'} | ||||||
|  |         self.assertRaises(ValidationError, post.validate) | ||||||
|  |  | ||||||
|  |         post.info = ['test'] | ||||||
|  |         post.save() | ||||||
|  |  | ||||||
|  |         post = BlogPost() | ||||||
|  |         post.info = [{'test': 'test'}] | ||||||
|  |         post.save() | ||||||
|  |  | ||||||
|  |         post = BlogPost() | ||||||
|  |         post.info = [{'test': 3}] | ||||||
|  |         post.save() | ||||||
|  |  | ||||||
|  |  | ||||||
|  |         self.assertEquals(BlogPost.objects.count(), 3) | ||||||
|  |         self.assertEquals(BlogPost.objects.filter(info__exact='test').count(), 1) | ||||||
|  |         self.assertEquals(BlogPost.objects.filter(info__0__test='test').count(), 1) | ||||||
|  |  | ||||||
|  |         # Confirm handles non strings or non existing keys | ||||||
|  |         self.assertEquals(BlogPost.objects.filter(info__0__test__exact='5').count(), 0) | ||||||
|  |         self.assertEquals(BlogPost.objects.filter(info__100__test__exact='test').count(), 0) | ||||||
|  |         BlogPost.drop_collection() | ||||||
|  |  | ||||||
|  |     def test_list_field_Strict(self): | ||||||
|  |         """Ensure that list field handles validation if provided a strict field type.""" | ||||||
|  |  | ||||||
|  |         class Simple(Document): | ||||||
|  |             mapping = ListField(field=IntField()) | ||||||
|  |  | ||||||
|  |         Simple.drop_collection() | ||||||
|  |  | ||||||
|  |         e = Simple() | ||||||
|  |         e.mapping = [1] | ||||||
|  |         e.save() | ||||||
|  |  | ||||||
|  |         def create_invalid_mapping(): | ||||||
|  |             e.mapping = ["abc"] | ||||||
|  |             e.save() | ||||||
|  |  | ||||||
|  |         self.assertRaises(ValidationError, create_invalid_mapping) | ||||||
|  |  | ||||||
|  |         Simple.drop_collection() | ||||||
|  |  | ||||||
|  |     def test_list_field_complex(self): | ||||||
|  |         """Ensure that the list fields can handle the complex types.""" | ||||||
|  |  | ||||||
|  |         class SettingBase(EmbeddedDocument): | ||||||
|  |             pass | ||||||
|  |  | ||||||
|  |         class StringSetting(SettingBase): | ||||||
|  |             value = StringField() | ||||||
|  |  | ||||||
|  |         class IntegerSetting(SettingBase): | ||||||
|  |             value = IntField() | ||||||
|  |  | ||||||
|  |         class Simple(Document): | ||||||
|  |             mapping = ListField() | ||||||
|  |  | ||||||
|  |         Simple.drop_collection() | ||||||
|  |         e = Simple() | ||||||
|  |         e.mapping.append(StringSetting(value='foo')) | ||||||
|  |         e.mapping.append(IntegerSetting(value=42)) | ||||||
|  |         e.mapping.append({'number': 1, 'string': 'Hi!', 'float': 1.001, | ||||||
|  |                           'complex': IntegerSetting(value=42), 'list': | ||||||
|  |                           [IntegerSetting(value=42), StringSetting(value='foo')]}) | ||||||
|  |         e.save() | ||||||
|  |  | ||||||
|  |         e2 = Simple.objects.get(id=e.id) | ||||||
|  |         self.assertTrue(isinstance(e2.mapping[0], StringSetting)) | ||||||
|  |         self.assertTrue(isinstance(e2.mapping[1], IntegerSetting)) | ||||||
|  |  | ||||||
|  |         # Test querying | ||||||
|  |         self.assertEquals(Simple.objects.filter(mapping__1__value=42).count(), 1) | ||||||
|  |         self.assertEquals(Simple.objects.filter(mapping__2__number=1).count(), 1) | ||||||
|  |         self.assertEquals(Simple.objects.filter(mapping__2__complex__value=42).count(), 1) | ||||||
|  |         self.assertEquals(Simple.objects.filter(mapping__2__list__0__value=42).count(), 1) | ||||||
|  |         self.assertEquals(Simple.objects.filter(mapping__2__list__1__value='foo').count(), 1) | ||||||
|  |  | ||||||
|  |         # Confirm can update | ||||||
|  |         Simple.objects().update(set__mapping__1=IntegerSetting(value=10)) | ||||||
|  |         self.assertEquals(Simple.objects.filter(mapping__1__value=10).count(), 1) | ||||||
|  |  | ||||||
|  |         Simple.objects().update( | ||||||
|  |             set__mapping__2__list__1=StringSetting(value='Boo')) | ||||||
|  |         self.assertEquals(Simple.objects.filter(mapping__2__list__1__value='foo').count(), 0) | ||||||
|  |         self.assertEquals(Simple.objects.filter(mapping__2__list__1__value='Boo').count(), 1) | ||||||
|  |  | ||||||
|  |         Simple.drop_collection() | ||||||
|  |  | ||||||
|     def test_dict_field(self): |     def test_dict_field(self): | ||||||
|         """Ensure that dict types work as expected. |         """Ensure that dict types work as expected. | ||||||
|         """ |         """ | ||||||
| @@ -363,6 +465,131 @@ class FieldTest(unittest.TestCase): | |||||||
|         self.assertEquals(BlogPost.objects.filter(info__made_up__test__exact='test').count(), 0) |         self.assertEquals(BlogPost.objects.filter(info__made_up__test__exact='test').count(), 0) | ||||||
|         BlogPost.drop_collection() |         BlogPost.drop_collection() | ||||||
|  |  | ||||||
|  |     def test_dictfield_Strict(self): | ||||||
|  |         """Ensure that dict field handles validation if provided a strict field type.""" | ||||||
|  |  | ||||||
|  |         class Simple(Document): | ||||||
|  |             mapping = DictField(field=IntField()) | ||||||
|  |  | ||||||
|  |         Simple.drop_collection() | ||||||
|  |  | ||||||
|  |         e = Simple() | ||||||
|  |         e.mapping['someint'] = 1 | ||||||
|  |         e.save() | ||||||
|  |  | ||||||
|  |         def create_invalid_mapping(): | ||||||
|  |             e.mapping['somestring'] = "abc" | ||||||
|  |             e.save() | ||||||
|  |  | ||||||
|  |         self.assertRaises(ValidationError, create_invalid_mapping) | ||||||
|  |  | ||||||
|  |         Simple.drop_collection() | ||||||
|  |  | ||||||
|  |     def test_dictfield_complex(self): | ||||||
|  |         """Ensure that the dict field can handle the complex types.""" | ||||||
|  |  | ||||||
|  |         class SettingBase(EmbeddedDocument): | ||||||
|  |             pass | ||||||
|  |  | ||||||
|  |         class StringSetting(SettingBase): | ||||||
|  |             value = StringField() | ||||||
|  |  | ||||||
|  |         class IntegerSetting(SettingBase): | ||||||
|  |             value = IntField() | ||||||
|  |  | ||||||
|  |         class Simple(Document): | ||||||
|  |             mapping = DictField() | ||||||
|  |  | ||||||
|  |         Simple.drop_collection() | ||||||
|  |         e = Simple() | ||||||
|  |         e.mapping['somestring'] = StringSetting(value='foo') | ||||||
|  |         e.mapping['someint'] = IntegerSetting(value=42) | ||||||
|  |         e.mapping['nested_dict'] = {'number': 1, 'string': 'Hi!', 'float': 1.001, | ||||||
|  |                                  'complex': IntegerSetting(value=42), 'list': | ||||||
|  |                                 [IntegerSetting(value=42), StringSetting(value='foo')]} | ||||||
|  |         e.save() | ||||||
|  |  | ||||||
|  |         e2 = Simple.objects.get(id=e.id) | ||||||
|  |         self.assertTrue(isinstance(e2.mapping['somestring'], StringSetting)) | ||||||
|  |         self.assertTrue(isinstance(e2.mapping['someint'], IntegerSetting)) | ||||||
|  |  | ||||||
|  |         # Test querying | ||||||
|  |         self.assertEquals(Simple.objects.filter(mapping__someint__value=42).count(), 1) | ||||||
|  |         self.assertEquals(Simple.objects.filter(mapping__nested_dict__number=1).count(), 1) | ||||||
|  |         self.assertEquals(Simple.objects.filter(mapping__nested_dict__complex__value=42).count(), 1) | ||||||
|  |         self.assertEquals(Simple.objects.filter(mapping__nested_dict__list__0__value=42).count(), 1) | ||||||
|  |         self.assertEquals(Simple.objects.filter(mapping__nested_dict__list__1__value='foo').count(), 1) | ||||||
|  |  | ||||||
|  |         # Confirm can update | ||||||
|  |         Simple.objects().update( | ||||||
|  |             set__mapping={"someint": IntegerSetting(value=10)}) | ||||||
|  |         Simple.objects().update( | ||||||
|  |             set__mapping__nested_dict__list__1=StringSetting(value='Boo')) | ||||||
|  |         self.assertEquals(Simple.objects.filter(mapping__nested_dict__list__1__value='foo').count(), 0) | ||||||
|  |         self.assertEquals(Simple.objects.filter(mapping__nested_dict__list__1__value='Boo').count(), 1) | ||||||
|  |  | ||||||
|  |         Simple.drop_collection() | ||||||
|  |  | ||||||
|  |     def test_mapfield(self): | ||||||
|  |         """Ensure that the MapField handles the declared type.""" | ||||||
|  |  | ||||||
|  |         class Simple(Document): | ||||||
|  |             mapping = MapField(IntField()) | ||||||
|  |  | ||||||
|  |         Simple.drop_collection() | ||||||
|  |  | ||||||
|  |         e = Simple() | ||||||
|  |         e.mapping['someint'] = 1 | ||||||
|  |         e.save() | ||||||
|  |  | ||||||
|  |         def create_invalid_mapping(): | ||||||
|  |             e.mapping['somestring'] = "abc" | ||||||
|  |             e.save() | ||||||
|  |  | ||||||
|  |         self.assertRaises(ValidationError, create_invalid_mapping) | ||||||
|  |  | ||||||
|  |         def create_invalid_class(): | ||||||
|  |             class NoDeclaredType(Document): | ||||||
|  |                 mapping = MapField() | ||||||
|  |  | ||||||
|  |         self.assertRaises(ValidationError, create_invalid_class) | ||||||
|  |  | ||||||
|  |         Simple.drop_collection() | ||||||
|  |  | ||||||
|  |     def test_complex_mapfield(self): | ||||||
|  |         """Ensure that the MapField can handle complex declared types.""" | ||||||
|  |  | ||||||
|  |         class SettingBase(EmbeddedDocument): | ||||||
|  |             pass | ||||||
|  |  | ||||||
|  |         class StringSetting(SettingBase): | ||||||
|  |             value = StringField() | ||||||
|  |  | ||||||
|  |         class IntegerSetting(SettingBase): | ||||||
|  |             value = IntField() | ||||||
|  |  | ||||||
|  |         class Extensible(Document): | ||||||
|  |             mapping = MapField(EmbeddedDocumentField(SettingBase)) | ||||||
|  |  | ||||||
|  |         Extensible.drop_collection() | ||||||
|  |  | ||||||
|  |         e = Extensible() | ||||||
|  |         e.mapping['somestring'] = StringSetting(value='foo') | ||||||
|  |         e.mapping['someint'] = IntegerSetting(value=42) | ||||||
|  |         e.save() | ||||||
|  |  | ||||||
|  |         e2 = Extensible.objects.get(id=e.id) | ||||||
|  |         self.assertTrue(isinstance(e2.mapping['somestring'], StringSetting)) | ||||||
|  |         self.assertTrue(isinstance(e2.mapping['someint'], IntegerSetting)) | ||||||
|  |  | ||||||
|  |         def create_invalid_mapping(): | ||||||
|  |             e.mapping['someint'] = 123 | ||||||
|  |             e.save() | ||||||
|  |  | ||||||
|  |         self.assertRaises(ValidationError, create_invalid_mapping) | ||||||
|  |  | ||||||
|  |         Extensible.drop_collection() | ||||||
|  |  | ||||||
|     def test_embedded_document_validation(self): |     def test_embedded_document_validation(self): | ||||||
|         """Ensure that invalid embedded documents cannot be assigned to |         """Ensure that invalid embedded documents cannot be assigned to | ||||||
|         embedded document fields. |         embedded document fields. | ||||||
| @@ -933,66 +1160,6 @@ class FieldTest(unittest.TestCase): | |||||||
|         self.assertEqual(d2.data, {}) |         self.assertEqual(d2.data, {}) | ||||||
|         self.assertEqual(d2.data2, {}) |         self.assertEqual(d2.data2, {}) | ||||||
|  |  | ||||||
|     def test_mapfield(self): |  | ||||||
|         """Ensure that the MapField handles the declared type.""" |  | ||||||
|  |  | ||||||
|         class Simple(Document): |  | ||||||
|             mapping = MapField(IntField()) |  | ||||||
|  |  | ||||||
|         Simple.drop_collection() |  | ||||||
|  |  | ||||||
|         e = Simple() |  | ||||||
|         e.mapping['someint'] = 1 |  | ||||||
|         e.save() |  | ||||||
|  |  | ||||||
|         def create_invalid_mapping(): |  | ||||||
|             e.mapping['somestring'] = "abc" |  | ||||||
|             e.save() |  | ||||||
|  |  | ||||||
|         self.assertRaises(ValidationError, create_invalid_mapping) |  | ||||||
|  |  | ||||||
|         def create_invalid_class(): |  | ||||||
|             class NoDeclaredType(Document): |  | ||||||
|                 mapping = MapField() |  | ||||||
|  |  | ||||||
|         self.assertRaises(ValidationError, create_invalid_class) |  | ||||||
|  |  | ||||||
|         Simple.drop_collection() |  | ||||||
|  |  | ||||||
|     def test_complex_mapfield(self): |  | ||||||
|         """Ensure that the MapField can handle complex declared types.""" |  | ||||||
|  |  | ||||||
|         class SettingBase(EmbeddedDocument): |  | ||||||
|             pass |  | ||||||
|  |  | ||||||
|         class StringSetting(SettingBase): |  | ||||||
|             value = StringField() |  | ||||||
|  |  | ||||||
|         class IntegerSetting(SettingBase): |  | ||||||
|             value = IntField() |  | ||||||
|  |  | ||||||
|         class Extensible(Document): |  | ||||||
|             mapping = MapField(EmbeddedDocumentField(SettingBase)) |  | ||||||
|  |  | ||||||
|         Extensible.drop_collection() |  | ||||||
|  |  | ||||||
|         e = Extensible() |  | ||||||
|         e.mapping['somestring'] = StringSetting(value='foo') |  | ||||||
|         e.mapping['someint'] = IntegerSetting(value=42) |  | ||||||
|         e.save() |  | ||||||
|  |  | ||||||
|         e2 = Extensible.objects.get(id=e.id) |  | ||||||
|         self.assertTrue(isinstance(e2.mapping['somestring'], StringSetting)) |  | ||||||
|         self.assertTrue(isinstance(e2.mapping['someint'], IntegerSetting)) |  | ||||||
|  |  | ||||||
|         def create_invalid_mapping(): |  | ||||||
|             e.mapping['someint'] = 123 |  | ||||||
|             e.save() |  | ||||||
|  |  | ||||||
|         self.assertRaises(ValidationError, create_invalid_mapping) |  | ||||||
|  |  | ||||||
|         Extensible.drop_collection() |  | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|     unittest.main() |     unittest.main() | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user