Added ComplexBaseField
* Handles the efficient lazy dereferencing of DBrefs. * Handles complex nested values in ListFields and DictFields * Allows for both strictly declared ListFields and DictFields where the embedded value must be of a field type or no restrictions where the values can be a mix of field types / values. * Handles DBrefences of documents where allow_inheritance = False.
This commit is contained in:
		| @@ -132,15 +132,19 @@ class BaseField(object): | ||||
|         self.validate(value) | ||||
|  | ||||
|  | ||||
| class DereferenceBaseField(BaseField): | ||||
|     """Handles the lazy dereferencing of a queryset.  Will dereference all | ||||
| class ComplexBaseField(BaseField): | ||||
|     """Handles complex fields, such as lists / dictionaries. | ||||
|  | ||||
|     Allows for nesting of embedded documents inside complex types. | ||||
|     Handles the lazy dereferencing of a queryset by lazily dereferencing all | ||||
|     items in a list / dict rather than one at a time. | ||||
|     """ | ||||
|  | ||||
|     field = None | ||||
|  | ||||
|     def __get__(self, instance, owner): | ||||
|         """Descriptor to automatically dereference references. | ||||
|         """ | ||||
|         from fields import ReferenceField, GenericReferenceField | ||||
|         from connection import _get_db | ||||
|  | ||||
|         if instance is None: | ||||
| @@ -149,68 +153,175 @@ class DereferenceBaseField(BaseField): | ||||
|  | ||||
|         # Get value from document instance if available | ||||
|         value_list = instance._data.get(self.name) | ||||
|         if not value_list: | ||||
|             return super(DereferenceBaseField, self).__get__(instance, owner) | ||||
|         if not value_list or isinstance(value_list, basestring): | ||||
|             return super(ComplexBaseField, self).__get__(instance, owner) | ||||
|  | ||||
|         is_list = False | ||||
|         if not hasattr(value_list, 'items'): | ||||
|             is_list = True | ||||
|             value_list = dict([(k,v) for k,v in enumerate(value_list)]) | ||||
|  | ||||
|         if isinstance(self.field, ReferenceField) and value_list: | ||||
|         for k,v in value_list.items(): | ||||
|             if isinstance(v, dict) and '_cls' in v and '_ref' not in v: | ||||
|                 value_list[k] = get_document(v['_cls'].split('.')[-1])._from_son(v) | ||||
|  | ||||
|         # Handle all dereferencing | ||||
|         db = _get_db() | ||||
|         dbref = {} | ||||
|         collections = {} | ||||
|  | ||||
|         for k, v in value_list.items(): | ||||
|             dbref[k] = v | ||||
|             # Save any DBRefs | ||||
|             if isinstance(v, (pymongo.dbref.DBRef)): | ||||
|                 # direct reference (DBRef) | ||||
|                 collections.setdefault(v.collection, []).append((k, v)) | ||||
|             elif isinstance(v, (dict, pymongo.son.SON)) and '_ref' in v: | ||||
|                 # generic reference | ||||
|                 collection =  get_document(v['_cls'])._meta['collection'] | ||||
|                 collections.setdefault(collection, []).append((k, v)) | ||||
|  | ||||
|         # For each collection get the references | ||||
|         for collection, dbrefs in collections.items(): | ||||
|                 id_map = dict([(v.id, k) for k, v in dbrefs]) | ||||
|             id_map = {} | ||||
|             for k, v in dbrefs: | ||||
|                 if isinstance(v, (pymongo.dbref.DBRef)): | ||||
|                     # direct reference (DBRef), has no _cls information | ||||
|                     id_map[v.id] = (k, None) | ||||
|                 elif isinstance(v, (dict, pymongo.son.SON)) and '_ref' in v: | ||||
|                     # generic reference - includes _cls information | ||||
|                     id_map[v['_ref'].id] = (k, get_document(v['_cls'])) | ||||
|  | ||||
|             references = db[collection].find({'_id': {'$in': id_map.keys()}}) | ||||
|             for ref in references: | ||||
|                     key = id_map[ref['_id']] | ||||
|                     dbref[key] = get_document(ref['_cls'])._from_son(ref) | ||||
|  | ||||
|             if is_list: | ||||
|                 dbref = [v for k,v in sorted(dbref.items(), key=itemgetter(0))] | ||||
|             instance._data[self.name] = dbref | ||||
|  | ||||
|         # Get value from document instance if available | ||||
|         if isinstance(self.field, GenericReferenceField) and value_list: | ||||
|             db = _get_db() | ||||
|             value_list = [(k,v) for k,v in value_list.items()] | ||||
|             dbref = {} | ||||
|             classes = {} | ||||
|  | ||||
|             for k, v in value_list: | ||||
|                 dbref[k] = v | ||||
|                 # Save any DBRefs | ||||
|                 if isinstance(v, (dict, pymongo.son.SON)): | ||||
|                     classes.setdefault(v['_cls'], []).append((k, v)) | ||||
|  | ||||
|             # For each collection get the references | ||||
|             for doc_cls, dbrefs in classes.items(): | ||||
|                 id_map = dict([(v['_ref'].id, k) for k, v in dbrefs]) | ||||
|                 doc_cls = get_document(doc_cls) | ||||
|                 collection = doc_cls._meta['collection'] | ||||
|                 references = db[collection].find({'_id': {'$in': id_map.keys()}}) | ||||
|  | ||||
|                 for ref in references: | ||||
|                     key = id_map[ref['_id']] | ||||
|                 key, doc_cls = id_map[ref['_id']] | ||||
|                 if not doc_cls:  # If no doc_cls get it from the referenced doc | ||||
|                     doc_cls = get_document(ref['_cls']) | ||||
|                 dbref[key] = doc_cls._from_son(ref) | ||||
|  | ||||
|         if is_list: | ||||
|             dbref = [v for k,v in sorted(dbref.items(), key=itemgetter(0))] | ||||
|  | ||||
|         instance._data[self.name] = dbref | ||||
|         return super(ComplexBaseField, self).__get__(instance, owner) | ||||
|  | ||||
|         return super(DereferenceBaseField, self).__get__(instance, owner) | ||||
|     def to_python(self, value): | ||||
|         """Convert a MongoDB-compatible type to a Python type. | ||||
|         """ | ||||
|         from mongoengine import Document | ||||
|  | ||||
|         if isinstance(value, basestring): | ||||
|             return value | ||||
|  | ||||
|         if hasattr(value, 'to_python'): | ||||
|             return value.to_python() | ||||
|  | ||||
|         is_list = False | ||||
|         if not hasattr(value, 'items'): | ||||
|             try: | ||||
|                 is_list = True | ||||
|                 value = dict([(k,v) for k,v in enumerate(value)]) | ||||
|             except TypeError:  # Not iterable return the value | ||||
|                 return value | ||||
|  | ||||
|         if self.field: | ||||
|             value_dict = dict([(key, self.field.to_python(item)) for key, item in value.items()]) | ||||
|         else: | ||||
|             value_dict = {} | ||||
|             for k,v in value.items(): | ||||
|                 if isinstance(v, Document): | ||||
|                     # We need the id from the saved object to create the DBRef | ||||
|                     if v.pk is None: | ||||
|                         raise ValidationError('You can only reference documents once ' | ||||
|                                       'they have been saved to the database') | ||||
|                     collection = v._meta['collection'] | ||||
|                     value_dict[k] = pymongo.dbref.DBRef(collection, v.pk) | ||||
|                 elif hasattr(v, 'to_python'): | ||||
|                     value_dict[k] = v.to_python() | ||||
|                 else: | ||||
|                     value_dict[k] = self.to_python(v) | ||||
|  | ||||
|         if is_list:  # Convert back to a list | ||||
|             return [v for k,v in sorted(value_dict.items(), key=itemgetter(0))] | ||||
|         return value_dict | ||||
|  | ||||
|     def to_mongo(self, value): | ||||
|         """Convert a Python type to a MongoDB-compatible type. | ||||
|         """ | ||||
|         from mongoengine import Document | ||||
|  | ||||
|         if isinstance(value, basestring): | ||||
|             return value | ||||
|  | ||||
|         if hasattr(value, 'to_mongo'): | ||||
|             return value.to_mongo() | ||||
|  | ||||
|         is_list = False | ||||
|         if not hasattr(value, 'items'): | ||||
|             try: | ||||
|                 is_list = True | ||||
|                 value = dict([(k,v) for k,v in enumerate(value)]) | ||||
|             except TypeError:  # Not iterable return the value | ||||
|                 return value | ||||
|  | ||||
|         if self.field: | ||||
|             value_dict = dict([(key, self.field.to_mongo(item)) for key, item in value.items()]) | ||||
|         else: | ||||
|             value_dict = {} | ||||
|             for k,v in value.items(): | ||||
|                 if isinstance(v, Document): | ||||
|                     # We need the id from the saved object to create the DBRef | ||||
|                     if v.pk is None: | ||||
|                         raise ValidationError('You can only reference documents once ' | ||||
|                                       'they have been saved to the database') | ||||
|  | ||||
|                     # If its a document that is not inheritable it won't have | ||||
|                     # _types / _cls data so make it a generic reference allows | ||||
|                     # us to dereference | ||||
|                     meta = getattr(v, 'meta', getattr(v, '_meta', {})) | ||||
|                     if meta and not meta['allow_inheritance'] and not self.field: | ||||
|                         from fields import GenericReferenceField | ||||
|                         value_dict[k] = GenericReferenceField().to_mongo(v) | ||||
|                     else: | ||||
|                         collection = v._meta['collection'] | ||||
|                         value_dict[k] = pymongo.dbref.DBRef(collection, v.pk) | ||||
|                 elif hasattr(v, 'to_mongo'): | ||||
|                     value_dict[k] = v.to_mongo() | ||||
|                 else: | ||||
|                     value_dict[k] = self.to_mongo(v) | ||||
|  | ||||
|         if is_list:  # Convert back to a list | ||||
|             return [v for k,v in sorted(value_dict.items(), key=itemgetter(0))] | ||||
|         return value_dict | ||||
|  | ||||
|     def validate(self, value): | ||||
|         """If field provided ensure the value is valid. | ||||
|         """ | ||||
|         if self.field: | ||||
|             try: | ||||
|                 if hasattr(value, 'iteritems'): | ||||
|                     [self.field.validate(v) for k,v in value.iteritems()] | ||||
|                 else: | ||||
|                      [self.field.validate(v) for v in value] | ||||
|             except Exception, err: | ||||
|                 raise ValidationError('Invalid %s item (%s)' % ( | ||||
|                         self.field.__class__.__name__, str(v))) | ||||
|  | ||||
|     def prepare_query_value(self, op, value): | ||||
|         return self.to_mongo(value) | ||||
|  | ||||
|     def lookup_member(self, member_name): | ||||
|         if self.field: | ||||
|             return self.field.lookup_member(member_name) | ||||
|         return None | ||||
|  | ||||
|     def _set_owner_document(self, owner_document): | ||||
|         if self.field: | ||||
|             self.field.owner_document = owner_document | ||||
|         self._owner_document = owner_document | ||||
|  | ||||
|     def _get_owner_document(self, owner_document): | ||||
|         self._owner_document = owner_document | ||||
|  | ||||
|     owner_document = property(_get_owner_document, _set_owner_document) | ||||
|  | ||||
|  | ||||
| class ObjectIdField(BaseField): | ||||
| @@ -219,7 +330,6 @@ class ObjectIdField(BaseField): | ||||
|  | ||||
|     def to_python(self, value): | ||||
|         return value | ||||
|         # return unicode(value) | ||||
|  | ||||
|     def to_mongo(self, value): | ||||
|         if not isinstance(value, pymongo.objectid.ObjectId): | ||||
|   | ||||
| @@ -1,4 +1,4 @@ | ||||
| from base import (BaseField, DereferenceBaseField, ObjectIdField, | ||||
| from base import (BaseField, ComplexBaseField, ObjectIdField, | ||||
|                   ValidationError, get_document) | ||||
| from queryset import DO_NOTHING | ||||
| from document import Document, EmbeddedDocument | ||||
| @@ -301,6 +301,8 @@ class EmbeddedDocumentField(BaseField): | ||||
|         return value | ||||
|  | ||||
|     def to_mongo(self, value): | ||||
|         if isinstance(value, basestring): | ||||
|             return value | ||||
|         return self.document_type.to_mongo(value) | ||||
|  | ||||
|     def validate(self, value): | ||||
| @@ -320,7 +322,7 @@ class EmbeddedDocumentField(BaseField): | ||||
|         return self.to_mongo(value) | ||||
|  | ||||
|  | ||||
| class ListField(DereferenceBaseField): | ||||
| class ListField(ComplexBaseField): | ||||
|     """A list field that wraps a standard field, allowing multiple instances | ||||
|     of the field to be used as a list in the database. | ||||
|     """ | ||||
| @@ -328,48 +330,25 @@ class ListField(DereferenceBaseField): | ||||
|     # ListFields cannot be indexed with _types - MongoDB doesn't support this | ||||
|     _index_with_types = False | ||||
|  | ||||
|     def __init__(self, field, **kwargs): | ||||
|         if not isinstance(field, BaseField): | ||||
|             raise ValidationError('Argument to ListField constructor must be ' | ||||
|                                   'a valid field') | ||||
|     def __init__(self, field=None, **kwargs): | ||||
|         self.field = field | ||||
|         kwargs.setdefault('default', lambda: []) | ||||
|         super(ListField, self).__init__(**kwargs) | ||||
|  | ||||
|     def to_python(self, value): | ||||
|         return [self.field.to_python(item) for item in value] | ||||
|  | ||||
|     def to_mongo(self, value): | ||||
|         return [self.field.to_mongo(item) for item in value] | ||||
|  | ||||
|     def validate(self, value): | ||||
|         """Make sure that a list of valid fields is being used. | ||||
|         """ | ||||
|         if not isinstance(value, (list, tuple)): | ||||
|             raise ValidationError('Only lists and tuples may be used in a ' | ||||
|                                   'list field') | ||||
|  | ||||
|         try: | ||||
|             [self.field.validate(item) for item in value] | ||||
|         except Exception, err: | ||||
|             raise ValidationError('Invalid ListField item (%s)' % str(item)) | ||||
|         super(ListField, self).validate(value) | ||||
|  | ||||
|     def prepare_query_value(self, op, value): | ||||
|         if op in ('set', 'unset'): | ||||
|         if self.field: | ||||
|             if op in ('set', 'unset') and not isinstance(value, basestring): | ||||
|                 return [self.field.prepare_query_value(op, v) for v in value] | ||||
|             return self.field.prepare_query_value(op, value) | ||||
|  | ||||
|     def lookup_member(self, member_name): | ||||
|         return self.field.lookup_member(member_name) | ||||
|  | ||||
|     def _set_owner_document(self, owner_document): | ||||
|         self.field.owner_document = owner_document | ||||
|         self._owner_document = owner_document | ||||
|  | ||||
|     def _get_owner_document(self, owner_document): | ||||
|         self._owner_document = owner_document | ||||
|  | ||||
|     owner_document = property(_get_owner_document, _set_owner_document) | ||||
|         return super(ListField, self).prepare_query_value(op, value) | ||||
|  | ||||
|  | ||||
| class SortedListField(ListField): | ||||
| @@ -388,20 +367,21 @@ class SortedListField(ListField): | ||||
|         super(SortedListField, self).__init__(field, **kwargs) | ||||
|  | ||||
|     def to_mongo(self, value): | ||||
|         value = super(SortedListField, self).to_mongo(value) | ||||
|         if self._ordering is not None: | ||||
|             return sorted([self.field.to_mongo(item) for item in value], | ||||
|                           key=itemgetter(self._ordering)) | ||||
|         return sorted([self.field.to_mongo(item) for item in value]) | ||||
|             return sorted(value, key=itemgetter(self._ordering)) | ||||
|         return sorted(value) | ||||
|  | ||||
|  | ||||
| class DictField(BaseField): | ||||
| class DictField(ComplexBaseField): | ||||
|     """A dictionary field that wraps a standard Python dictionary. This is | ||||
|     similar to an embedded document, but the structure is not defined. | ||||
|  | ||||
|     .. versionadded:: 0.3 | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, basecls=None, *args, **kwargs): | ||||
|     def __init__(self, basecls=None, field=None, *args, **kwargs): | ||||
|         self.field = field | ||||
|         self.basecls = basecls or BaseField | ||||
|         assert issubclass(self.basecls, BaseField) | ||||
|         kwargs.setdefault('default', lambda: {}) | ||||
| @@ -417,6 +397,7 @@ class DictField(BaseField): | ||||
|         if any(('.' in k or '$' in k) for k in value): | ||||
|             raise ValidationError('Invalid dictionary key name - keys may not ' | ||||
|                                   'contain "." or "$" characters') | ||||
|         super(DictField, self).validate(value) | ||||
|  | ||||
|     def lookup_member(self, member_name): | ||||
|         return DictField(basecls=self.basecls, db_field=member_name) | ||||
| @@ -432,7 +413,7 @@ class DictField(BaseField): | ||||
|         return super(DictField, self).prepare_query_value(op, value) | ||||
|  | ||||
|  | ||||
| class MapField(DereferenceBaseField): | ||||
| class MapField(DictField): | ||||
|     """A field that maps a name to a specified field type. Similar to | ||||
|     a DictField, except the 'value' of each item must match the specified | ||||
|     field type. | ||||
| @@ -444,50 +425,7 @@ class MapField(DereferenceBaseField): | ||||
|         if not isinstance(field, BaseField): | ||||
|             raise ValidationError('Argument to MapField constructor must be ' | ||||
|                                   'a valid field') | ||||
|         self.field = field | ||||
|         kwargs.setdefault('default', lambda: {}) | ||||
|         super(MapField, self).__init__(*args, **kwargs) | ||||
|  | ||||
|     def validate(self, value): | ||||
|         """Make sure that a list of valid fields is being used. | ||||
|         """ | ||||
|         if not isinstance(value, dict): | ||||
|             raise ValidationError('Only dictionaries may be used in a ' | ||||
|                                   'DictField') | ||||
|  | ||||
|         if any(('.' in k or '$' in k) for k in value): | ||||
|             raise ValidationError('Invalid dictionary key name - keys may not ' | ||||
|                                   'contain "." or "$" characters') | ||||
|  | ||||
|         try: | ||||
|             [self.field.validate(item) for item in value.values()] | ||||
|         except Exception, err: | ||||
|             raise ValidationError('Invalid MapField item (%s)' % str(item)) | ||||
|  | ||||
|     def to_python(self, value): | ||||
|         return dict([(key, self.field.to_python(item)) for key, item in value.iteritems()]) | ||||
|  | ||||
|     def to_mongo(self, value): | ||||
|         return dict([(key, self.field.to_mongo(item)) for key, item in value.iteritems()]) | ||||
|  | ||||
|     def prepare_query_value(self, op, value): | ||||
|         if op not in ('set', 'unset'): | ||||
|             return self.field.prepare_query_value(op, value) | ||||
|         for key in value: | ||||
|             value[key] = self.field.prepare_query_value(op, value[key]) | ||||
|         return value | ||||
|  | ||||
|     def lookup_member(self, member_name): | ||||
|         return self.field.lookup_member(member_name) | ||||
|  | ||||
|     def _set_owner_document(self, owner_document): | ||||
|         self.field.owner_document = owner_document | ||||
|         self._owner_document = owner_document | ||||
|  | ||||
|     def _get_owner_document(self, owner_document): | ||||
|         self._owner_document = owner_document | ||||
|  | ||||
|     owner_document = property(_get_owner_document, _set_owner_document) | ||||
|         super(MapField, self).__init__(field=field, *args, **kwargs) | ||||
|  | ||||
|  | ||||
| class ReferenceField(BaseField): | ||||
|   | ||||
| @@ -549,11 +549,12 @@ class QuerySet(object): | ||||
|             parts = [parts] | ||||
|         fields = [] | ||||
|         field = None | ||||
|  | ||||
|         for field_name in parts: | ||||
|             # Handle ListField indexing: | ||||
|             if field_name.isdigit(): | ||||
|                 try: | ||||
|                     field = field.field | ||||
|                     new_field = field.field | ||||
|                 except AttributeError, err: | ||||
|                     raise InvalidQueryError( | ||||
|                         "Can't use index on unsubscriptable field (%s)" % err) | ||||
| @@ -567,11 +568,17 @@ class QuerySet(object): | ||||
|                 field = document._fields[field_name] | ||||
|             else: | ||||
|                 # Look up subfield on the previous field | ||||
|                 field = field.lookup_member(field_name) | ||||
|                 if field is None: | ||||
|                 new_field = field.lookup_member(field_name) | ||||
|                 from base import ComplexBaseField | ||||
|                 if not new_field and isinstance(field, ComplexBaseField): | ||||
|                     fields.append(field_name) | ||||
|                     continue | ||||
|                 elif not new_field: | ||||
|                     raise InvalidQueryError('Cannot resolve field "%s"' | ||||
|                                                 % field_name) | ||||
|                 field = new_field  # update field to the new field type | ||||
|             fields.append(field) | ||||
|  | ||||
|         return fields | ||||
|  | ||||
|     @classmethod | ||||
| @@ -615,13 +622,32 @@ class QuerySet(object): | ||||
|             if _doc_cls: | ||||
|                 # Switch field names to proper names [set in Field(name='foo')] | ||||
|                 fields = QuerySet._lookup_field(_doc_cls, parts) | ||||
|                 parts = [field.db_field for field in fields] | ||||
|                 parts = [] | ||||
|  | ||||
|                 cleaned_fields = [] | ||||
|                 append_field = True | ||||
|                 for field in fields: | ||||
|                     if isinstance(field, str): | ||||
|                         parts.append(field) | ||||
|                         append_field = False | ||||
|                     else: | ||||
|                         parts.append(field.db_field) | ||||
|                     if append_field: | ||||
|                         cleaned_fields.append(field) | ||||
|  | ||||
|                 # Convert value to proper value | ||||
|                 field = fields[-1] | ||||
|                 field = cleaned_fields[-1] | ||||
|  | ||||
|                 singular_ops = [None, 'ne', 'gt', 'gte', 'lt', 'lte', 'not'] | ||||
|                 singular_ops += match_operators | ||||
|                 if op in singular_ops: | ||||
|                     if isinstance(field, basestring): | ||||
|                         if op in match_operators and isinstance(value, basestring): | ||||
|                             from mongoengine import StringField | ||||
|                             value = StringField().prepare_query_value(op, value) | ||||
|                         else: | ||||
|                             value = field | ||||
|                     else: | ||||
|                         value = field.prepare_query_value(op, value) | ||||
|                 elif op in ('in', 'nin', 'all', 'near'): | ||||
|                     # 'in', 'nin' and 'all' require a list of values | ||||
| @@ -1170,14 +1196,19 @@ class QuerySet(object): | ||||
|                 fields = QuerySet._lookup_field(_doc_cls, parts) | ||||
|                 parts = [] | ||||
|  | ||||
|                 cleaned_fields = [] | ||||
|                 append_field = True | ||||
|                 for field in fields: | ||||
|                     if isinstance(field, str): | ||||
|                         parts.append(field) | ||||
|                         append_field = False | ||||
|                     else: | ||||
|                         parts.append(field.db_field) | ||||
|                     if append_field: | ||||
|                         cleaned_fields.append(field) | ||||
|  | ||||
|                 # Convert value to proper value | ||||
|                 field = fields[-1] | ||||
|                 field = cleaned_fields[-1] | ||||
|  | ||||
|                 if op in (None, 'set', 'push', 'pull', 'addToSet'): | ||||
|                     value = field.prepare_query_value(op, value) | ||||
|   | ||||
| @@ -122,6 +122,64 @@ class FieldTest(unittest.TestCase): | ||||
|             [m for m in group_obj.members] | ||||
|             self.assertEqual(q, 4) | ||||
|  | ||||
|             for m in group_obj.members: | ||||
|                 self.assertTrue('User' in m.__class__.__name__) | ||||
|  | ||||
|         UserA.drop_collection() | ||||
|         UserB.drop_collection() | ||||
|         UserC.drop_collection() | ||||
|         Group.drop_collection() | ||||
|  | ||||
|     def test_list_field_complex(self): | ||||
|  | ||||
|         class UserA(Document): | ||||
|             name = StringField() | ||||
|  | ||||
|         class UserB(Document): | ||||
|             name = StringField() | ||||
|  | ||||
|         class UserC(Document): | ||||
|             name = StringField() | ||||
|  | ||||
|         class Group(Document): | ||||
|             members = ListField() | ||||
|  | ||||
|         UserA.drop_collection() | ||||
|         UserB.drop_collection() | ||||
|         UserC.drop_collection() | ||||
|         Group.drop_collection() | ||||
|  | ||||
|         members = [] | ||||
|         for i in xrange(1, 51): | ||||
|             a = UserA(name='User A %s' % i) | ||||
|             a.save() | ||||
|  | ||||
|             b = UserB(name='User B %s' % i) | ||||
|             b.save() | ||||
|  | ||||
|             c = UserC(name='User C %s' % i) | ||||
|             c.save() | ||||
|  | ||||
|             members += [a, b, c] | ||||
|  | ||||
|         group = Group(members=members) | ||||
|         group.save() | ||||
|  | ||||
|         with query_counter() as q: | ||||
|             self.assertEqual(q, 0) | ||||
|  | ||||
|             group_obj = Group.objects.first() | ||||
|             self.assertEqual(q, 1) | ||||
|  | ||||
|             [m for m in group_obj.members] | ||||
|             self.assertEqual(q, 4) | ||||
|  | ||||
|             [m for m in group_obj.members] | ||||
|             self.assertEqual(q, 4) | ||||
|  | ||||
|             for m in group_obj.members: | ||||
|                 self.assertTrue('User' in m.__class__.__name__) | ||||
|  | ||||
|         UserA.drop_collection() | ||||
|         UserB.drop_collection() | ||||
|         UserC.drop_collection() | ||||
| @@ -156,10 +214,13 @@ class FieldTest(unittest.TestCase): | ||||
|             [m for m in group_obj.members] | ||||
|             self.assertEqual(q, 2) | ||||
|  | ||||
|             for k, m in group_obj.members.iteritems(): | ||||
|                 self.assertTrue(isinstance(m, User)) | ||||
|  | ||||
|         User.drop_collection() | ||||
|         Group.drop_collection() | ||||
|  | ||||
|     def ztest_generic_reference_dict_field(self): | ||||
|     def test_dict_field(self): | ||||
|  | ||||
|         class UserA(Document): | ||||
|             name = StringField() | ||||
| @@ -206,6 +267,9 @@ class FieldTest(unittest.TestCase): | ||||
|             [m for m in group_obj.members] | ||||
|             self.assertEqual(q, 4) | ||||
|  | ||||
|             for k, m in group_obj.members.iteritems(): | ||||
|                 self.assertTrue('User' in m.__class__.__name__) | ||||
|  | ||||
|         group.members = {} | ||||
|         group.save() | ||||
|  | ||||
| @@ -218,11 +282,54 @@ class FieldTest(unittest.TestCase): | ||||
|             [m for m in group_obj.members] | ||||
|             self.assertEqual(q, 1) | ||||
|  | ||||
|             for k, m in group_obj.members.iteritems(): | ||||
|                 self.assertTrue('User' in m.__class__.__name__) | ||||
|  | ||||
|         UserA.drop_collection() | ||||
|         UserB.drop_collection() | ||||
|         UserC.drop_collection() | ||||
|         Group.drop_collection() | ||||
|  | ||||
|     def test_dict_field_no_field_inheritance(self): | ||||
|  | ||||
|         class UserA(Document): | ||||
|             name = StringField() | ||||
|             meta = {'allow_inheritance': False} | ||||
|  | ||||
|         class Group(Document): | ||||
|             members = DictField() | ||||
|  | ||||
|         UserA.drop_collection() | ||||
|         Group.drop_collection() | ||||
|  | ||||
|         members = [] | ||||
|         for i in xrange(1, 51): | ||||
|             a = UserA(name='User A %s' % i) | ||||
|             a.save() | ||||
|  | ||||
|             members += [a] | ||||
|  | ||||
|         group = Group(members=dict([(str(u.id), u) for u in members])) | ||||
|         group.save() | ||||
|  | ||||
|         with query_counter() as q: | ||||
|             self.assertEqual(q, 0) | ||||
|  | ||||
|             group_obj = Group.objects.first() | ||||
|             self.assertEqual(q, 1) | ||||
|  | ||||
|             [m for m in group_obj.members] | ||||
|             self.assertEqual(q, 2) | ||||
|  | ||||
|             [m for m in group_obj.members] | ||||
|             self.assertEqual(q, 2) | ||||
|  | ||||
|             for k, m in group_obj.members.iteritems(): | ||||
|                 self.assertTrue(isinstance(m, UserA)) | ||||
|  | ||||
|         UserA.drop_collection() | ||||
|         Group.drop_collection() | ||||
|  | ||||
|     def test_generic_reference_map_field(self): | ||||
|  | ||||
|         class UserA(Document): | ||||
| @@ -270,6 +377,9 @@ class FieldTest(unittest.TestCase): | ||||
|             [m for m in group_obj.members] | ||||
|             self.assertEqual(q, 4) | ||||
|  | ||||
|             for k, m in group_obj.members.iteritems(): | ||||
|                 self.assertTrue('User' in m.__class__.__name__) | ||||
|  | ||||
|         group.members = {} | ||||
|         group.save() | ||||
|  | ||||
|   | ||||
							
								
								
									
										287
									
								
								tests/fields.py
									
									
									
									
									
								
							
							
						
						
									
										287
									
								
								tests/fields.py
									
									
									
									
									
								
							| @@ -322,6 +322,108 @@ class FieldTest(unittest.TestCase): | ||||
|  | ||||
|         BlogPost.drop_collection() | ||||
|  | ||||
|     def test_list_field(self): | ||||
|         """Ensure that list types work as expected. | ||||
|         """ | ||||
|         class BlogPost(Document): | ||||
|             info = ListField() | ||||
|  | ||||
|         BlogPost.drop_collection() | ||||
|  | ||||
|         post = BlogPost() | ||||
|         post.info = 'my post' | ||||
|         self.assertRaises(ValidationError, post.validate) | ||||
|  | ||||
|         post.info = {'title': 'test'} | ||||
|         self.assertRaises(ValidationError, post.validate) | ||||
|  | ||||
|         post.info = ['test'] | ||||
|         post.save() | ||||
|  | ||||
|         post = BlogPost() | ||||
|         post.info = [{'test': 'test'}] | ||||
|         post.save() | ||||
|  | ||||
|         post = BlogPost() | ||||
|         post.info = [{'test': 3}] | ||||
|         post.save() | ||||
|  | ||||
|  | ||||
|         self.assertEquals(BlogPost.objects.count(), 3) | ||||
|         self.assertEquals(BlogPost.objects.filter(info__exact='test').count(), 1) | ||||
|         self.assertEquals(BlogPost.objects.filter(info__0__test='test').count(), 1) | ||||
|  | ||||
|         # Confirm handles non strings or non existing keys | ||||
|         self.assertEquals(BlogPost.objects.filter(info__0__test__exact='5').count(), 0) | ||||
|         self.assertEquals(BlogPost.objects.filter(info__100__test__exact='test').count(), 0) | ||||
|         BlogPost.drop_collection() | ||||
|  | ||||
|     def test_list_field_Strict(self): | ||||
|         """Ensure that list field handles validation if provided a strict field type.""" | ||||
|  | ||||
|         class Simple(Document): | ||||
|             mapping = ListField(field=IntField()) | ||||
|  | ||||
|         Simple.drop_collection() | ||||
|  | ||||
|         e = Simple() | ||||
|         e.mapping = [1] | ||||
|         e.save() | ||||
|  | ||||
|         def create_invalid_mapping(): | ||||
|             e.mapping = ["abc"] | ||||
|             e.save() | ||||
|  | ||||
|         self.assertRaises(ValidationError, create_invalid_mapping) | ||||
|  | ||||
|         Simple.drop_collection() | ||||
|  | ||||
|     def test_list_field_complex(self): | ||||
|         """Ensure that the list fields can handle the complex types.""" | ||||
|  | ||||
|         class SettingBase(EmbeddedDocument): | ||||
|             pass | ||||
|  | ||||
|         class StringSetting(SettingBase): | ||||
|             value = StringField() | ||||
|  | ||||
|         class IntegerSetting(SettingBase): | ||||
|             value = IntField() | ||||
|  | ||||
|         class Simple(Document): | ||||
|             mapping = ListField() | ||||
|  | ||||
|         Simple.drop_collection() | ||||
|         e = Simple() | ||||
|         e.mapping.append(StringSetting(value='foo')) | ||||
|         e.mapping.append(IntegerSetting(value=42)) | ||||
|         e.mapping.append({'number': 1, 'string': 'Hi!', 'float': 1.001, | ||||
|                           'complex': IntegerSetting(value=42), 'list': | ||||
|                           [IntegerSetting(value=42), StringSetting(value='foo')]}) | ||||
|         e.save() | ||||
|  | ||||
|         e2 = Simple.objects.get(id=e.id) | ||||
|         self.assertTrue(isinstance(e2.mapping[0], StringSetting)) | ||||
|         self.assertTrue(isinstance(e2.mapping[1], IntegerSetting)) | ||||
|  | ||||
|         # Test querying | ||||
|         self.assertEquals(Simple.objects.filter(mapping__1__value=42).count(), 1) | ||||
|         self.assertEquals(Simple.objects.filter(mapping__2__number=1).count(), 1) | ||||
|         self.assertEquals(Simple.objects.filter(mapping__2__complex__value=42).count(), 1) | ||||
|         self.assertEquals(Simple.objects.filter(mapping__2__list__0__value=42).count(), 1) | ||||
|         self.assertEquals(Simple.objects.filter(mapping__2__list__1__value='foo').count(), 1) | ||||
|  | ||||
|         # Confirm can update | ||||
|         Simple.objects().update(set__mapping__1=IntegerSetting(value=10)) | ||||
|         self.assertEquals(Simple.objects.filter(mapping__1__value=10).count(), 1) | ||||
|  | ||||
|         Simple.objects().update( | ||||
|             set__mapping__2__list__1=StringSetting(value='Boo')) | ||||
|         self.assertEquals(Simple.objects.filter(mapping__2__list__1__value='foo').count(), 0) | ||||
|         self.assertEquals(Simple.objects.filter(mapping__2__list__1__value='Boo').count(), 1) | ||||
|  | ||||
|         Simple.drop_collection() | ||||
|  | ||||
|     def test_dict_field(self): | ||||
|         """Ensure that dict types work as expected. | ||||
|         """ | ||||
| @@ -363,6 +465,131 @@ class FieldTest(unittest.TestCase): | ||||
|         self.assertEquals(BlogPost.objects.filter(info__made_up__test__exact='test').count(), 0) | ||||
|         BlogPost.drop_collection() | ||||
|  | ||||
|     def test_dictfield_Strict(self): | ||||
|         """Ensure that dict field handles validation if provided a strict field type.""" | ||||
|  | ||||
|         class Simple(Document): | ||||
|             mapping = DictField(field=IntField()) | ||||
|  | ||||
|         Simple.drop_collection() | ||||
|  | ||||
|         e = Simple() | ||||
|         e.mapping['someint'] = 1 | ||||
|         e.save() | ||||
|  | ||||
|         def create_invalid_mapping(): | ||||
|             e.mapping['somestring'] = "abc" | ||||
|             e.save() | ||||
|  | ||||
|         self.assertRaises(ValidationError, create_invalid_mapping) | ||||
|  | ||||
|         Simple.drop_collection() | ||||
|  | ||||
|     def test_dictfield_complex(self): | ||||
|         """Ensure that the dict field can handle the complex types.""" | ||||
|  | ||||
|         class SettingBase(EmbeddedDocument): | ||||
|             pass | ||||
|  | ||||
|         class StringSetting(SettingBase): | ||||
|             value = StringField() | ||||
|  | ||||
|         class IntegerSetting(SettingBase): | ||||
|             value = IntField() | ||||
|  | ||||
|         class Simple(Document): | ||||
|             mapping = DictField() | ||||
|  | ||||
|         Simple.drop_collection() | ||||
|         e = Simple() | ||||
|         e.mapping['somestring'] = StringSetting(value='foo') | ||||
|         e.mapping['someint'] = IntegerSetting(value=42) | ||||
|         e.mapping['nested_dict'] = {'number': 1, 'string': 'Hi!', 'float': 1.001, | ||||
|                                  'complex': IntegerSetting(value=42), 'list': | ||||
|                                 [IntegerSetting(value=42), StringSetting(value='foo')]} | ||||
|         e.save() | ||||
|  | ||||
|         e2 = Simple.objects.get(id=e.id) | ||||
|         self.assertTrue(isinstance(e2.mapping['somestring'], StringSetting)) | ||||
|         self.assertTrue(isinstance(e2.mapping['someint'], IntegerSetting)) | ||||
|  | ||||
|         # Test querying | ||||
|         self.assertEquals(Simple.objects.filter(mapping__someint__value=42).count(), 1) | ||||
|         self.assertEquals(Simple.objects.filter(mapping__nested_dict__number=1).count(), 1) | ||||
|         self.assertEquals(Simple.objects.filter(mapping__nested_dict__complex__value=42).count(), 1) | ||||
|         self.assertEquals(Simple.objects.filter(mapping__nested_dict__list__0__value=42).count(), 1) | ||||
|         self.assertEquals(Simple.objects.filter(mapping__nested_dict__list__1__value='foo').count(), 1) | ||||
|  | ||||
|         # Confirm can update | ||||
|         Simple.objects().update( | ||||
|             set__mapping={"someint": IntegerSetting(value=10)}) | ||||
|         Simple.objects().update( | ||||
|             set__mapping__nested_dict__list__1=StringSetting(value='Boo')) | ||||
|         self.assertEquals(Simple.objects.filter(mapping__nested_dict__list__1__value='foo').count(), 0) | ||||
|         self.assertEquals(Simple.objects.filter(mapping__nested_dict__list__1__value='Boo').count(), 1) | ||||
|  | ||||
|         Simple.drop_collection() | ||||
|  | ||||
|     def test_mapfield(self): | ||||
|         """Ensure that the MapField handles the declared type.""" | ||||
|  | ||||
|         class Simple(Document): | ||||
|             mapping = MapField(IntField()) | ||||
|  | ||||
|         Simple.drop_collection() | ||||
|  | ||||
|         e = Simple() | ||||
|         e.mapping['someint'] = 1 | ||||
|         e.save() | ||||
|  | ||||
|         def create_invalid_mapping(): | ||||
|             e.mapping['somestring'] = "abc" | ||||
|             e.save() | ||||
|  | ||||
|         self.assertRaises(ValidationError, create_invalid_mapping) | ||||
|  | ||||
|         def create_invalid_class(): | ||||
|             class NoDeclaredType(Document): | ||||
|                 mapping = MapField() | ||||
|  | ||||
|         self.assertRaises(ValidationError, create_invalid_class) | ||||
|  | ||||
|         Simple.drop_collection() | ||||
|  | ||||
|     def test_complex_mapfield(self): | ||||
|         """Ensure that the MapField can handle complex declared types.""" | ||||
|  | ||||
|         class SettingBase(EmbeddedDocument): | ||||
|             pass | ||||
|  | ||||
|         class StringSetting(SettingBase): | ||||
|             value = StringField() | ||||
|  | ||||
|         class IntegerSetting(SettingBase): | ||||
|             value = IntField() | ||||
|  | ||||
|         class Extensible(Document): | ||||
|             mapping = MapField(EmbeddedDocumentField(SettingBase)) | ||||
|  | ||||
|         Extensible.drop_collection() | ||||
|  | ||||
|         e = Extensible() | ||||
|         e.mapping['somestring'] = StringSetting(value='foo') | ||||
|         e.mapping['someint'] = IntegerSetting(value=42) | ||||
|         e.save() | ||||
|  | ||||
|         e2 = Extensible.objects.get(id=e.id) | ||||
|         self.assertTrue(isinstance(e2.mapping['somestring'], StringSetting)) | ||||
|         self.assertTrue(isinstance(e2.mapping['someint'], IntegerSetting)) | ||||
|  | ||||
|         def create_invalid_mapping(): | ||||
|             e.mapping['someint'] = 123 | ||||
|             e.save() | ||||
|  | ||||
|         self.assertRaises(ValidationError, create_invalid_mapping) | ||||
|  | ||||
|         Extensible.drop_collection() | ||||
|  | ||||
|     def test_embedded_document_validation(self): | ||||
|         """Ensure that invalid embedded documents cannot be assigned to | ||||
|         embedded document fields. | ||||
| @@ -933,66 +1160,6 @@ class FieldTest(unittest.TestCase): | ||||
|         self.assertEqual(d2.data, {}) | ||||
|         self.assertEqual(d2.data2, {}) | ||||
|  | ||||
|     def test_mapfield(self): | ||||
|         """Ensure that the MapField handles the declared type.""" | ||||
|  | ||||
|         class Simple(Document): | ||||
|             mapping = MapField(IntField()) | ||||
|  | ||||
|         Simple.drop_collection() | ||||
|  | ||||
|         e = Simple() | ||||
|         e.mapping['someint'] = 1 | ||||
|         e.save() | ||||
|  | ||||
|         def create_invalid_mapping(): | ||||
|             e.mapping['somestring'] = "abc" | ||||
|             e.save() | ||||
|  | ||||
|         self.assertRaises(ValidationError, create_invalid_mapping) | ||||
|  | ||||
|         def create_invalid_class(): | ||||
|             class NoDeclaredType(Document): | ||||
|                 mapping = MapField() | ||||
|  | ||||
|         self.assertRaises(ValidationError, create_invalid_class) | ||||
|  | ||||
|         Simple.drop_collection() | ||||
|  | ||||
|     def test_complex_mapfield(self): | ||||
|         """Ensure that the MapField can handle complex declared types.""" | ||||
|  | ||||
|         class SettingBase(EmbeddedDocument): | ||||
|             pass | ||||
|  | ||||
|         class StringSetting(SettingBase): | ||||
|             value = StringField() | ||||
|  | ||||
|         class IntegerSetting(SettingBase): | ||||
|             value = IntField() | ||||
|  | ||||
|         class Extensible(Document): | ||||
|             mapping = MapField(EmbeddedDocumentField(SettingBase)) | ||||
|  | ||||
|         Extensible.drop_collection() | ||||
|  | ||||
|         e = Extensible() | ||||
|         e.mapping['somestring'] = StringSetting(value='foo') | ||||
|         e.mapping['someint'] = IntegerSetting(value=42) | ||||
|         e.save() | ||||
|  | ||||
|         e2 = Extensible.objects.get(id=e.id) | ||||
|         self.assertTrue(isinstance(e2.mapping['somestring'], StringSetting)) | ||||
|         self.assertTrue(isinstance(e2.mapping['someint'], IntegerSetting)) | ||||
|  | ||||
|         def create_invalid_mapping(): | ||||
|             e.mapping['someint'] = 123 | ||||
|             e.save() | ||||
|  | ||||
|         self.assertRaises(ValidationError, create_invalid_mapping) | ||||
|  | ||||
|         Extensible.drop_collection() | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|     unittest.main() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user