From 30fdd3e184165b722cd470fdbd470634374bff77 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wilson=20J=C3=BAnior?= Date: Mon, 7 Jul 2014 12:54:41 -0300 Subject: [PATCH] Added initial CachedReferenceField --- mongoengine/base/document.py | 83 +++++++---- mongoengine/fields.py | 206 +++++++++++++++++++++++----- tests/fields/fields.py | 257 +++++++++++++++++++++++++++++------ 3 files changed, 445 insertions(+), 101 deletions(-) diff --git a/mongoengine/base/document.py b/mongoengine/base/document.py index 961f5be1..593c237d 100644 --- a/mongoengine/base/document.py +++ b/mongoengine/base/document.py @@ -23,9 +23,10 @@ __all__ = ('BaseDocument', 'NON_FIELD_ERRORS') NON_FIELD_ERRORS = '__all__' + class BaseDocument(object): __slots__ = ('_changed_fields', '_initialised', '_created', '_data', - '_dynamic_fields', '_auto_id_field', '_db_field_map', '_cls', '__weakref__') + '_dynamic_fields', '_auto_id_field', '_db_field_map', '_cls', '__weakref__') _dynamic = False _dynamic_lock = True @@ -50,7 +51,8 @@ class BaseDocument(object): for value in args: name = next(field) if name in values: - raise TypeError("Multiple values for keyword argument '" + name + "'") + raise TypeError( + "Multiple values for keyword argument '" + name + "'") values[name] = value __auto_convert = values.pop("__auto_convert", True) signals.pre_init.send(self.__class__, document=self, values=values) @@ -58,7 +60,8 @@ class BaseDocument(object): if self.STRICT and not self._dynamic: self._data = StrictDict.create(allowed_keys=self._fields_ordered)() else: - self._data = SemiStrictDict.create(allowed_keys=self._fields_ordered)() + self._data = SemiStrictDict.create( + allowed_keys=self._fields_ordered)() _created = values.pop("_created", True) self._data = {} @@ -144,8 +147,8 @@ class BaseDocument(object): self__created = True if (self._is_document and not self__created and - name in self._meta.get('shard_key', tuple()) and - self._data.get(name) != value): + name in self._meta.get('shard_key', tuple()) and + self._data.get(name) != value): OperationError = _import_class('OperationError') msg = "Shard Keys are immutable. Tried to update %s" % name raise OperationError(msg) @@ -156,8 +159,8 @@ class BaseDocument(object): self__initialised = False # Check if the user has created a new instance of a class if (self._is_document and self__initialised - and self__created and name == self._meta['id_field']): - super(BaseDocument, self).__setattr__('_created', False) + and self__created and name == self._meta['id_field']): + super(BaseDocument, self).__setattr__('_created', False) super(BaseDocument, self).__setattr__(name, value) @@ -174,7 +177,7 @@ class BaseDocument(object): if isinstance(data["_data"], SON): data["_data"] = self.__class__._from_son(data["_data"])._data for k in ('_changed_fields', '_initialised', '_created', '_data', - '_dynamic_fields'): + '_dynamic_fields'): if k in data: setattr(self, k, data[k]) if '_fields_ordered' in data: @@ -257,26 +260,45 @@ class BaseDocument(object): """ pass - def to_mongo(self, use_db_field=True): - """Return as SON data ready for use with MongoDB. + def to_mongo(self, use_db_field=True, fields=[]): + """ + Return as SON data ready for use with MongoDB. """ data = SON() data["_id"] = None data['_cls'] = self._class_name + EmbeddedDocumentField = _import_class("EmbeddedDocumentField") + + # only root fields ['test1.a', 'test2'] => ['test1', 'test2'] + root_fields = set([f.split('.')[0] for f in fields]) for field_name in self: + if root_fields and field_name not in root_fields: + continue + value = self._data.get(field_name, None) field = self._fields.get(field_name) + if field is None and self._dynamic: field = self._dynamic_fields.get(field_name) if value is not None: EmbeddedDocument = _import_class("EmbeddedDocument") - if isinstance(value, (EmbeddedDocument)) and use_db_field==False: + if isinstance(value, (EmbeddedDocument)) and \ + not use_db_field: value = field.to_mongo(value, use_db_field) else: value = field.to_mongo(value) + if isinstance(field, EmbeddedDocumentField) and fields: + key = '%s.' % field_name + + value = field.to_mongo(value, fields=[ + i.replace(key, '') for i in fields if i.startswith(key)]) + + elif value is not None: + value = field.to_mongo(value) + # Handle self generating fields if value is None and field._auto_gen: value = field.generate() @@ -299,7 +321,7 @@ class BaseDocument(object): # Only add _cls if allow_inheritance is True if (not hasattr(self, '_meta') or - not self._meta.get('allow_inheritance', ALLOW_INHERITANCE)): + not self._meta.get('allow_inheritance', ALLOW_INHERITANCE)): data.pop('_cls') return data @@ -321,7 +343,8 @@ class BaseDocument(object): self._data.get(name)) for name in self._fields_ordered] EmbeddedDocumentField = _import_class("EmbeddedDocumentField") - GenericEmbeddedDocumentField = _import_class("GenericEmbeddedDocumentField") + GenericEmbeddedDocumentField = _import_class( + "GenericEmbeddedDocumentField") for field, value in fields: if value is not None: @@ -352,7 +375,8 @@ class BaseDocument(object): """Converts a document to JSON. :param use_db_field: Set to True by default but enables the output of the json structure with the field names and not the mongodb store db_names in case of set to False """ - use_db_field = kwargs.pop('use_db_field') if kwargs.has_key('use_db_field') else True + use_db_field = kwargs.pop('use_db_field') if kwargs.has_key( + 'use_db_field') else True return json_util.dumps(self.to_mongo(use_db_field), *args, **kwargs) @classmethod @@ -387,7 +411,7 @@ class BaseDocument(object): # Convert lists / values so we can watch for any changes on them if (isinstance(value, (list, tuple)) and - not isinstance(value, BaseList)): + not isinstance(value, BaseList)): value = BaseList(value, self, name) elif isinstance(value, dict) and not isinstance(value, BaseDict): value = BaseDict(value, self, name) @@ -452,9 +476,10 @@ class BaseDocument(object): if hasattr(value, '_get_changed_fields'): changed = value._get_changed_fields(inspected) changed_fields += ["%s%s" % (list_key, k) - for k in changed if k] + for k in changed if k] elif isinstance(value, (list, tuple, dict)): - self._nestable_types_changed_fields(changed_fields, list_key, value, inspected) + self._nestable_types_changed_fields( + changed_fields, list_key, value, inspected) def _get_changed_fields(self, inspected=None): """Returns a list of all fields that have explicitly been changed. @@ -484,16 +509,17 @@ class BaseDocument(object): if isinstance(field, ReferenceField): continue elif (isinstance(data, (EmbeddedDocument, DynamicEmbeddedDocument)) - and db_field_name not in changed_fields): + and db_field_name not in changed_fields): # Find all embedded fields that have been changed changed = data._get_changed_fields(inspected) changed_fields += ["%s%s" % (key, k) for k in changed if k] elif (isinstance(data, (list, tuple, dict)) and db_field_name not in changed_fields): if (hasattr(field, 'field') and - isinstance(field.field, ReferenceField)): + isinstance(field.field, ReferenceField)): continue - self._nestable_types_changed_fields(changed_fields, key, data, inspected) + self._nestable_types_changed_fields( + changed_fields, key, data, inspected) return changed_fields def _delta(self): @@ -539,7 +565,7 @@ class BaseDocument(object): # If we've set a value that ain't the default value dont unset it. default = None if (self._dynamic and len(parts) and parts[0] in - self._dynamic_fields): + self._dynamic_fields): del(set_data[path]) unset_data[path] = 1 continue @@ -625,13 +651,14 @@ class BaseDocument(object): if errors_dict: errors = "\n".join(["%s - %s" % (k, v) - for k, v in errors_dict.items()]) + for k, v in errors_dict.items()]) msg = ("Invalid data to create a `%s` instance.\n%s" % (cls._class_name, errors)) raise InvalidDocumentError(msg) if cls.STRICT: - data = dict((k, v) for k,v in data.iteritems() if k in cls._fields) + data = dict((k, v) + for k, v in data.iteritems() if k in cls._fields) obj = cls(__auto_convert=False, _created=False, **data) obj._changed_fields = changed_fields if not _auto_dereference: @@ -778,7 +805,7 @@ class BaseDocument(object): # Grab any embedded document field unique indexes if (field.__class__.__name__ == "EmbeddedDocumentField" and - field.document_type != cls): + field.document_type != cls): field_namespace = "%s." % field_name doc_cls = field.document_type unique_indexes += doc_cls._unique_with_indexes(field_namespace) @@ -794,7 +821,8 @@ class BaseDocument(object): geo_field_type_names = ["EmbeddedDocumentField", "GeoPointField", "PointField", "LineStringField", "PolygonField"] - geo_field_types = tuple([_import_class(field) for field in geo_field_type_names]) + geo_field_types = tuple([_import_class(field) + for field in geo_field_type_names]) for field in cls._fields.values(): if not isinstance(field, geo_field_types): @@ -804,13 +832,14 @@ class BaseDocument(object): if field_cls in inspected: continue if hasattr(field_cls, '_geo_indices'): - geo_indices += field_cls._geo_indices(inspected, parent_field=field.db_field) + geo_indices += field_cls._geo_indices( + inspected, parent_field=field.db_field) elif field._geo_index: field_name = field.db_field if parent_field: field_name = "%s.%s" % (parent_field, field_name) geo_indices.append({'fields': - [(field_name, field._geo_index)]}) + [(field_name, field._geo_index)]}) return geo_indices @classmethod diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 0320898b..22bb7423 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -34,22 +34,24 @@ except ImportError: Image = None ImageOps = None -__all__ = ['StringField', 'URLField', 'EmailField', 'IntField', 'LongField', - 'FloatField', 'DecimalField', 'BooleanField', 'DateTimeField', - 'ComplexDateTimeField', 'EmbeddedDocumentField', 'ObjectIdField', - 'GenericEmbeddedDocumentField', 'DynamicField', 'ListField', - 'SortedListField', 'DictField', 'MapField', 'ReferenceField', - 'GenericReferenceField', 'BinaryField', 'GridFSError', - 'GridFSProxy', 'FileField', 'ImageGridFsProxy', - 'ImproperlyConfigured', 'ImageField', 'GeoPointField', 'PointField', - 'LineStringField', 'PolygonField', 'SequenceField', 'UUIDField', - 'GeoJsonBaseField'] +__all__ = [ + 'StringField', 'URLField', 'EmailField', 'IntField', 'LongField', + 'FloatField', 'DecimalField', 'BooleanField', 'DateTimeField', + 'ComplexDateTimeField', 'EmbeddedDocumentField', 'ObjectIdField', + 'GenericEmbeddedDocumentField', 'DynamicField', 'ListField', + 'SortedListField', 'DictField', 'MapField', 'ReferenceField', + 'CachedReferenceField', 'GenericReferenceField', 'BinaryField', + 'GridFSError', 'GridFSProxy', 'FileField', 'ImageGridFsProxy', + 'ImproperlyConfigured', 'ImageField', 'GeoPointField', 'PointField', + 'LineStringField', 'PolygonField', 'SequenceField', 'UUIDField', + 'GeoJsonBaseField'] RECURSIVE_REFERENCE_CONSTANT = 'self' class StringField(BaseField): + """A unicode string field. """ @@ -109,6 +111,7 @@ class StringField(BaseField): class URLField(StringField): + """A field that validates input as an URL. .. versionadded:: 0.3 @@ -116,7 +119,8 @@ class URLField(StringField): _URL_REGEX = re.compile( r'^(?:http|ftp)s?://' # http:// or https:// - r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|' # domain... + # domain... + r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|' r'localhost|' # localhost... r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})' # ...or ip r'(?::\d+)?' # optional port @@ -145,15 +149,19 @@ class URLField(StringField): class EmailField(StringField): + """A field that validates input as an E-Mail-Address. .. versionadded:: 0.4 """ EMAIL_REGEX = re.compile( - r"(^[-!#$%&'*+/=?^_`{}|~0-9A-Z]+(\.[-!#$%&'*+/=?^_`{}|~0-9A-Z]+)*" # dot-atom - r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]|\\[\001-011\013\014\016-\177])*"' # quoted-string - r')@(?:[A-Z0-9](?:[A-Z0-9-]{0,253}[A-Z0-9])?\.)+[A-Z]{2,6}$', re.IGNORECASE # domain + # dot-atom + r"(^[-!#$%&'*+/=?^_`{}|~0-9A-Z]+(\.[-!#$%&'*+/=?^_`{}|~0-9A-Z]+)*" + # quoted-string + r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]|\\[\001-011\013\014\016-\177])*"' + # domain + r')@(?:[A-Z0-9](?:[A-Z0-9-]{0,253}[A-Z0-9])?\.)+[A-Z]{2,6}$', re.IGNORECASE ) def validate(self, value): @@ -163,6 +171,7 @@ class EmailField(StringField): class IntField(BaseField): + """An 32-bit integer field. """ @@ -197,6 +206,7 @@ class IntField(BaseField): class LongField(BaseField): + """An 64-bit integer field. """ @@ -231,6 +241,7 @@ class LongField(BaseField): class FloatField(BaseField): + """An floating point number field. """ @@ -265,6 +276,7 @@ class FloatField(BaseField): class DecimalField(BaseField): + """A fixed-point decimal number field. .. versionchanged:: 0.8 @@ -338,6 +350,7 @@ class DecimalField(BaseField): class BooleanField(BaseField): + """A boolean field type. .. versionadded:: 0.1.2 @@ -356,6 +369,7 @@ class BooleanField(BaseField): class DateTimeField(BaseField): + """A datetime field. Uses the python-dateutil library if available alternatively use time.strptime @@ -406,15 +420,15 @@ class DateTimeField(BaseField): kwargs = {'microsecond': usecs} try: # Seconds are optional, so try converting seconds first. return datetime.datetime(*time.strptime(value, - '%Y-%m-%d %H:%M:%S')[:6], **kwargs) + '%Y-%m-%d %H:%M:%S')[:6], **kwargs) except ValueError: try: # Try without seconds. return datetime.datetime(*time.strptime(value, - '%Y-%m-%d %H:%M')[:5], **kwargs) + '%Y-%m-%d %H:%M')[:5], **kwargs) except ValueError: # Try without hour/minutes/seconds. try: return datetime.datetime(*time.strptime(value, - '%Y-%m-%d')[:3], **kwargs) + '%Y-%m-%d')[:3], **kwargs) except ValueError: return None @@ -423,6 +437,7 @@ class DateTimeField(BaseField): class ComplexDateTimeField(StringField): + """ ComplexDateTimeField handles microseconds exactly instead of rounding like DateTimeField does. @@ -525,6 +540,7 @@ class ComplexDateTimeField(StringField): class EmbeddedDocumentField(BaseField): + """An embedded document field - with a declared document_type. Only valid values are subclasses of :class:`~mongoengine.EmbeddedDocument`. """ @@ -551,7 +567,7 @@ class EmbeddedDocumentField(BaseField): return self.document_type._from_son(value) return value - def to_mongo(self, value, use_db_field=True): + def to_mongo(self, value, use_db_field=True, fields=[]): if not isinstance(value, self.document_type): return value return self.document_type.to_mongo(value, use_db_field) @@ -574,6 +590,7 @@ class EmbeddedDocumentField(BaseField): class GenericEmbeddedDocumentField(BaseField): + """A generic embedded document field - allows any :class:`~mongoengine.EmbeddedDocument` to be stored. @@ -612,6 +629,7 @@ class GenericEmbeddedDocumentField(BaseField): class DynamicField(BaseField): + """A truly dynamic field type capable of handling different and varying types of data. @@ -675,6 +693,7 @@ class DynamicField(BaseField): class ListField(ComplexBaseField): + """A list field that wraps a standard field, allowing multiple instances of the field to be used as a list in the database. @@ -693,21 +712,22 @@ class ListField(ComplexBaseField): """Make sure that a list of valid fields is being used. """ if (not isinstance(value, (list, tuple, QuerySet)) or - isinstance(value, basestring)): + isinstance(value, basestring)): self.error('Only lists and tuples may be used in a list field') super(ListField, self).validate(value) def prepare_query_value(self, op, value): if self.field: if op in ('set', 'unset') and (not isinstance(value, basestring) - and not isinstance(value, BaseDocument) - and hasattr(value, '__iter__')): + and not isinstance(value, BaseDocument) + and hasattr(value, '__iter__')): return [self.field.prepare_query_value(op, v) for v in value] return self.field.prepare_query_value(op, value) return super(ListField, self).prepare_query_value(op, value) class SortedListField(ListField): + """A ListField that sorts the contents of its list before writing to the database in order to ensure that a sorted list is always retrieved. @@ -739,6 +759,7 @@ class SortedListField(ListField): reverse=self._order_reverse) return sorted(value, reverse=self._order_reverse) + def key_not_string(d): """ Helper function to recursively determine if any key in a dictionary is not a string. @@ -747,6 +768,7 @@ def key_not_string(d): if not isinstance(k, basestring) or (isinstance(v, dict) and key_not_string(v)): return True + def key_has_dot_or_dollar(d): """ Helper function to recursively determine if any key in a dictionary contains a dot or a dollar sign. @@ -755,7 +777,9 @@ def key_has_dot_or_dollar(d): if ('.' in k or '$' in k) or (isinstance(v, dict) and key_has_dot_or_dollar(v)): return True + class DictField(ComplexBaseField): + """A dictionary field that wraps a standard Python dictionary. This is similar to an embedded document, but the structure is not defined. @@ -807,6 +831,7 @@ class DictField(ComplexBaseField): class MapField(DictField): + """A field that maps a name to a specified field type. Similar to a DictField, except the 'value' of each item must match the specified field type. @@ -822,6 +847,7 @@ class MapField(DictField): class ReferenceField(BaseField): + """A reference to a document that will be automatically dereferenced on access (lazily). @@ -932,7 +958,7 @@ class ReferenceField(BaseField): """Convert a MongoDB-compatible type to a Python type. """ if (not self.dbref and - not isinstance(value, (DBRef, Document, EmbeddedDocument))): + not isinstance(value, (DBRef, Document, EmbeddedDocument))): collection = self.document_type._get_collection_name() value = DBRef(collection, self.document_type.id.to_python(value)) return value @@ -955,7 +981,106 @@ class ReferenceField(BaseField): return self.document_type._fields.get(member_name) +class CachedReferenceField(BaseField): + + """ + A referencefield with cache fields support + .. versionadded:: 0.9 + """ + + def __init__(self, document_type, fields=[], **kwargs): + """Initialises the Cached Reference Field. + + :param fields: A list of fields to be cached in document + """ + if not isinstance(document_type, basestring) and \ + not issubclass(document_type, (Document, basestring)): + + self.error('Argument to CachedReferenceField constructor must be a' + ' document class or a string') + + self.document_type_obj = document_type + self.fields = fields + super(CachedReferenceField, self).__init__(**kwargs) + + def to_python(self, value): + """Convert a MongoDB-compatible type to a Python type. + """ + if isinstance(value, dict): + collection = self.document_type._get_collection_name() + value = DBRef( + collection, self.document_type.id.to_python(value['_id'])) + + return value + + @property + def document_type(self): + if isinstance(self.document_type_obj, basestring): + if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT: + self.document_type_obj = self.owner_document + else: + self.document_type_obj = get_document(self.document_type_obj) + return self.document_type_obj + + def __get__(self, instance, owner): + """Descriptor to allow lazy dereferencing. + """ + if instance is None: + # Document class being used rather than a document object + return self + + # Get value from document instance if available + value = instance._data.get(self.name) + self._auto_dereference = instance._fields[self.name]._auto_dereference + # Dereference DBRefs + if self._auto_dereference and isinstance(value, DBRef): + value = self.document_type._get_db().dereference(value) + if value is not None: + instance._data[self.name] = self.document_type._from_son(value) + + return super(CachedReferenceField, self).__get__(instance, owner) + + def to_mongo(self, document): + id_field_name = self.document_type._meta['id_field'] + id_field = self.document_type._fields[id_field_name] + doc_tipe = self.document_type + + if isinstance(document, Document): + # We need the id from the saved object to create the DBRef + id_ = document.pk + if id_ is None: + self.error('You can only reference documents once they have' + ' been saved to the database') + else: + self.error('Only accept a document object') + + value = { + "_id": id_field.to_mongo(id_) + } + + value.update(dict(document.to_mongo(fields=self.fields))) + return value + + def prepare_query_value(self, op, value): + if value is None: + return None + return self.to_mongo(value) + + def validate(self, value): + + if not isinstance(value, (self.document_type)): + self.error("A CachedReferenceField only accepts documents") + + if isinstance(value, Document) and value.id is None: + self.error('You can only reference documents once they have been ' + 'saved to the database') + + def lookup_member(self, member_name): + return self.document_type._fields.get(member_name) + + class GenericReferenceField(BaseField): + """A reference to *any* :class:`~mongoengine.document.Document` subclass that will be automatically dereferenced on access (lazily). @@ -974,6 +1099,7 @@ class GenericReferenceField(BaseField): return self value = instance._data.get(self.name) + self._auto_dereference = instance._fields[self.name]._auto_dereference if self._auto_dereference and isinstance(value, (dict, SON)): instance._data[self.name] = self.dereference(value) @@ -1036,6 +1162,7 @@ class GenericReferenceField(BaseField): class BinaryField(BaseField): + """A binary data field. """ @@ -1056,7 +1183,7 @@ class BinaryField(BaseField): if not isinstance(value, (bin_type, txt_type, Binary)): self.error("BinaryField only accepts instances of " "(%s, %s, Binary)" % ( - bin_type.__name__, txt_type.__name__)) + bin_type.__name__, txt_type.__name__)) if self.max_bytes is not None and len(value) > self.max_bytes: self.error('Binary value is too long') @@ -1067,6 +1194,7 @@ class GridFSError(Exception): class GridFSProxy(object): + """Proxy object to handle writing and reading of files to and from GridFS .. versionadded:: 0.4 @@ -1121,7 +1249,8 @@ class GridFSProxy(object): return '<%s: %s>' % (self.__class__.__name__, self.grid_id) def __str__(self): - name = getattr(self.get(), 'filename', self.grid_id) if self.get() else '(no file)' + name = getattr( + self.get(), 'filename', self.grid_id) if self.get() else '(no file)' return '<%s: %s>' % (self.__class__.__name__, name) def __eq__(self, other): @@ -1135,7 +1264,8 @@ class GridFSProxy(object): @property def fs(self): if not self._fs: - self._fs = gridfs.GridFS(get_db(self.db_alias), self.collection_name) + self._fs = gridfs.GridFS( + get_db(self.db_alias), self.collection_name) return self._fs def get(self, id=None): @@ -1209,6 +1339,7 @@ class GridFSProxy(object): class FileField(BaseField): + """A GridFS storage field. .. versionadded:: 0.4 @@ -1253,7 +1384,8 @@ class FileField(BaseField): pass # Create a new proxy object as we don't already have one - instance._data[key] = self.get_proxy_obj(key=key, instance=instance) + instance._data[key] = self.get_proxy_obj( + key=key, instance=instance) instance._data[key].put(value) else: instance._data[key] = value @@ -1291,11 +1423,13 @@ class FileField(BaseField): class ImageGridFsProxy(GridFSProxy): + """ Proxy for ImageField versionadded: 0.6 """ + def put(self, file_obj, **kwargs): """ Insert a image in database @@ -1341,7 +1475,8 @@ class ImageGridFsProxy(GridFSProxy): size = field.thumbnail_size if size['force']: - thumbnail = ImageOps.fit(img, (size['width'], size['height']), Image.ANTIALIAS) + thumbnail = ImageOps.fit( + img, (size['width'], size['height']), Image.ANTIALIAS) else: thumbnail = img.copy() thumbnail.thumbnail((size['width'], @@ -1367,7 +1502,7 @@ class ImageGridFsProxy(GridFSProxy): **kwargs) def delete(self, *args, **kwargs): - #deletes thumbnail + # deletes thumbnail out = self.get() if out and out.thumbnail_id: self.fs.delete(out.thumbnail_id) @@ -1427,6 +1562,7 @@ class ImproperlyConfigured(Exception): class ImageField(FileField): + """ A Image File storage field. @@ -1465,6 +1601,7 @@ class ImageField(FileField): class SequenceField(BaseField): + """Provides a sequental counter see: http://www.mongodb.org/display/DOCS/Object+IDs#ObjectIDs-SequenceNumbers @@ -1534,7 +1671,7 @@ class SequenceField(BaseField): data = collection.find_one({"_id": sequence_id}) if data: - return self.value_decorator(data['next']+1) + return self.value_decorator(data['next'] + 1) return self.value_decorator(1) @@ -1579,6 +1716,7 @@ class SequenceField(BaseField): class UUIDField(BaseField): + """A UUID field. .. versionadded:: 0.6 @@ -1631,6 +1769,7 @@ class UUIDField(BaseField): class GeoPointField(BaseField): + """A list storing a longitude and latitude coordinate. .. note:: this represents a generic point in a 2D plane and a legacy way of @@ -1651,13 +1790,16 @@ class GeoPointField(BaseField): 'of (x, y)') if not len(value) == 2: - self.error("Value (%s) must be a two-dimensional point" % repr(value)) + self.error("Value (%s) must be a two-dimensional point" % + repr(value)) elif (not isinstance(value[0], (float, int)) or not isinstance(value[1], (float, int))): - self.error("Both values (%s) in point must be float or int" % repr(value)) + self.error( + "Both values (%s) in point must be float or int" % repr(value)) class PointField(GeoJsonBaseField): + """A GeoJSON field storing a longitude and latitude coordinate. The data is represented as: @@ -1677,6 +1819,7 @@ class PointField(GeoJsonBaseField): class LineStringField(GeoJsonBaseField): + """A GeoJSON field storing a line of longitude and latitude coordinates. The data is represented as: @@ -1695,6 +1838,7 @@ class LineStringField(GeoJsonBaseField): class PolygonField(GeoJsonBaseField): + """A GeoJSON field storing a polygon of longitude and latitude coordinates. The data is represented as: diff --git a/tests/fields/fields.py b/tests/fields/fields.py index c108f37e..5da06981 100644 --- a/tests/fields/fields.py +++ b/tests/fields/fields.py @@ -47,7 +47,8 @@ class FieldTest(unittest.TestCase): # Confirm saving now would store values data_to_be_saved = sorted(person.to_mongo().keys()) - self.assertEqual(data_to_be_saved, ['age', 'created', 'name', 'userid']) + self.assertEqual( + data_to_be_saved, ['age', 'created', 'name', 'userid']) self.assertTrue(person.validate() is None) @@ -63,7 +64,8 @@ class FieldTest(unittest.TestCase): # Confirm introspection changes nothing data_to_be_saved = sorted(person.to_mongo().keys()) - self.assertEqual(data_to_be_saved, ['age', 'created', 'name', 'userid']) + self.assertEqual( + data_to_be_saved, ['age', 'created', 'name', 'userid']) def test_default_values_set_to_None(self): """Ensure that default field values are used when creating a document. @@ -587,7 +589,8 @@ class FieldTest(unittest.TestCase): LogEntry.drop_collection() - # Post UTC - microseconds are rounded (down) nearest millisecond and dropped + # Post UTC - microseconds are rounded (down) nearest millisecond and + # dropped d1 = datetime.datetime(1970, 01, 01, 00, 00, 01, 999) d2 = datetime.datetime(1970, 01, 01, 00, 00, 01) log = LogEntry() @@ -688,7 +691,8 @@ class FieldTest(unittest.TestCase): LogEntry.drop_collection() - # Post UTC - microseconds are rounded (down) nearest millisecond and dropped - with default datetimefields + # Post UTC - microseconds are rounded (down) nearest millisecond and + # dropped - with default datetimefields d1 = datetime.datetime(1970, 01, 01, 00, 00, 01, 999) log = LogEntry() log.date = d1 @@ -696,14 +700,16 @@ class FieldTest(unittest.TestCase): log.reload() self.assertEqual(log.date, d1) - # Post UTC - microseconds are rounded (down) nearest millisecond - with default datetimefields + # Post UTC - microseconds are rounded (down) nearest millisecond - with + # default datetimefields d1 = datetime.datetime(1970, 01, 01, 00, 00, 01, 9999) log.date = d1 log.save() log.reload() self.assertEqual(log.date, d1) - # Pre UTC dates microseconds below 1000 are dropped - with default datetimefields + # Pre UTC dates microseconds below 1000 are dropped - with default + # datetimefields d1 = datetime.datetime(1969, 12, 31, 23, 59, 59, 999) log.date = d1 log.save() @@ -929,12 +935,16 @@ class FieldTest(unittest.TestCase): post.save() self.assertEqual(BlogPost.objects.count(), 3) - self.assertEqual(BlogPost.objects.filter(info__exact='test').count(), 1) - self.assertEqual(BlogPost.objects.filter(info__0__test='test').count(), 1) + self.assertEqual( + BlogPost.objects.filter(info__exact='test').count(), 1) + self.assertEqual( + BlogPost.objects.filter(info__0__test='test').count(), 1) # Confirm handles non strings or non existing keys - self.assertEqual(BlogPost.objects.filter(info__0__test__exact='5').count(), 0) - self.assertEqual(BlogPost.objects.filter(info__100__test__exact='test').count(), 0) + self.assertEqual( + BlogPost.objects.filter(info__0__test__exact='5').count(), 0) + self.assertEqual( + BlogPost.objects.filter(info__100__test__exact='test').count(), 0) BlogPost.drop_collection() def test_list_field_passed_in_value(self): @@ -951,7 +961,6 @@ class FieldTest(unittest.TestCase): foo.bars.append(bar) self.assertEqual(repr(foo.bars), '[]') - def test_list_field_strict(self): """Ensure that list field handles validation if provided a strict field type.""" @@ -1082,20 +1091,28 @@ class FieldTest(unittest.TestCase): self.assertTrue(isinstance(e2.mapping[1], IntegerSetting)) # Test querying - self.assertEqual(Simple.objects.filter(mapping__1__value=42).count(), 1) - self.assertEqual(Simple.objects.filter(mapping__2__number=1).count(), 1) - self.assertEqual(Simple.objects.filter(mapping__2__complex__value=42).count(), 1) - self.assertEqual(Simple.objects.filter(mapping__2__list__0__value=42).count(), 1) - self.assertEqual(Simple.objects.filter(mapping__2__list__1__value='foo').count(), 1) + self.assertEqual( + Simple.objects.filter(mapping__1__value=42).count(), 1) + self.assertEqual( + Simple.objects.filter(mapping__2__number=1).count(), 1) + self.assertEqual( + Simple.objects.filter(mapping__2__complex__value=42).count(), 1) + self.assertEqual( + Simple.objects.filter(mapping__2__list__0__value=42).count(), 1) + self.assertEqual( + Simple.objects.filter(mapping__2__list__1__value='foo').count(), 1) # Confirm can update Simple.objects().update(set__mapping__1=IntegerSetting(value=10)) - self.assertEqual(Simple.objects.filter(mapping__1__value=10).count(), 1) + self.assertEqual( + Simple.objects.filter(mapping__1__value=10).count(), 1) Simple.objects().update( set__mapping__2__list__1=StringSetting(value='Boo')) - self.assertEqual(Simple.objects.filter(mapping__2__list__1__value='foo').count(), 0) - self.assertEqual(Simple.objects.filter(mapping__2__list__1__value='Boo').count(), 1) + self.assertEqual( + Simple.objects.filter(mapping__2__list__1__value='foo').count(), 0) + self.assertEqual( + Simple.objects.filter(mapping__2__list__1__value='Boo').count(), 1) Simple.drop_collection() @@ -1141,12 +1158,16 @@ class FieldTest(unittest.TestCase): post.save() self.assertEqual(BlogPost.objects.count(), 3) - self.assertEqual(BlogPost.objects.filter(info__title__exact='test').count(), 1) - self.assertEqual(BlogPost.objects.filter(info__details__test__exact='test').count(), 1) + self.assertEqual( + BlogPost.objects.filter(info__title__exact='test').count(), 1) + self.assertEqual( + BlogPost.objects.filter(info__details__test__exact='test').count(), 1) # Confirm handles non strings or non existing keys - self.assertEqual(BlogPost.objects.filter(info__details__test__exact=5).count(), 0) - self.assertEqual(BlogPost.objects.filter(info__made_up__test__exact='test').count(), 0) + self.assertEqual( + BlogPost.objects.filter(info__details__test__exact=5).count(), 0) + self.assertEqual( + BlogPost.objects.filter(info__made_up__test__exact='test').count(), 0) post = BlogPost.objects.create(info={'title': 'original'}) post.info.update({'title': 'updated'}) @@ -1207,19 +1228,26 @@ class FieldTest(unittest.TestCase): self.assertTrue(isinstance(e2.mapping['someint'], IntegerSetting)) # Test querying - self.assertEqual(Simple.objects.filter(mapping__someint__value=42).count(), 1) - self.assertEqual(Simple.objects.filter(mapping__nested_dict__number=1).count(), 1) - self.assertEqual(Simple.objects.filter(mapping__nested_dict__complex__value=42).count(), 1) - self.assertEqual(Simple.objects.filter(mapping__nested_dict__list__0__value=42).count(), 1) - self.assertEqual(Simple.objects.filter(mapping__nested_dict__list__1__value='foo').count(), 1) + self.assertEqual( + Simple.objects.filter(mapping__someint__value=42).count(), 1) + self.assertEqual( + Simple.objects.filter(mapping__nested_dict__number=1).count(), 1) + self.assertEqual( + Simple.objects.filter(mapping__nested_dict__complex__value=42).count(), 1) + self.assertEqual( + Simple.objects.filter(mapping__nested_dict__list__0__value=42).count(), 1) + self.assertEqual( + Simple.objects.filter(mapping__nested_dict__list__1__value='foo').count(), 1) # Confirm can update Simple.objects().update( set__mapping={"someint": IntegerSetting(value=10)}) Simple.objects().update( set__mapping__nested_dict__list__1=StringSetting(value='Boo')) - self.assertEqual(Simple.objects.filter(mapping__nested_dict__list__1__value='foo').count(), 0) - self.assertEqual(Simple.objects.filter(mapping__nested_dict__list__1__value='Boo').count(), 1) + self.assertEqual( + Simple.objects.filter(mapping__nested_dict__list__1__value='foo').count(), 0) + self.assertEqual( + Simple.objects.filter(mapping__nested_dict__list__1__value='Boo').count(), 1) Simple.drop_collection() @@ -1290,7 +1318,7 @@ class FieldTest(unittest.TestCase): class Test(Document): my_map = MapField(field=EmbeddedDocumentField(Embedded), - db_field='x') + db_field='x') Test.drop_collection() @@ -1334,7 +1362,7 @@ class FieldTest(unittest.TestCase): Log(name="wilson", visited={'friends': datetime.datetime.now()}).save() self.assertEqual(1, Log.objects( - visited__friends__exists=True).count()) + visited__friends__exists=True).count()) def test_embedded_db_field(self): @@ -1477,6 +1505,151 @@ class FieldTest(unittest.TestCase): mongoed = p1.to_mongo() self.assertTrue(isinstance(mongoed['parent'], ObjectId)) + def test_cached_reference_fields(self): + class Animal(Document): + name = StringField() + tag = StringField() + + class Ocorrence(Document): + person = StringField() + animal = CachedReferenceField( + Animal, fields=['tag']) + + Animal.drop_collection() + Ocorrence.drop_collection() + + a = Animal(nam="Leopard", tag="heavy") + a.save() + + o = Ocorrence(person="teste", animal=a) + o.save() + self.assertEqual( + a.to_mongo(fields=['tag']), {'tag': 'heavy', "_id": a.pk}) + + self.assertEqual(o.to_mongo()['animal']['tag'], 'heavy') + + # counts + Ocorrence(person="teste 2").save() + Ocorrence(person="teste 3").save() + + count = Ocorrence.objects(animal__tag='heavy').count() + self.assertEqual(count, 1) + + ocorrence = Ocorrence.objects(animal__tag='heavy').first() + self.assertEqual(ocorrence.person, "teste") + self.assertTrue(isinstance(ocorrence.animal, Animal)) + + def test_cached_reference_embedded_fields(self): + class Owner(EmbeddedDocument): + TPS = ( + ('n', "Normal"), + ('u', "Urgent") + ) + name = StringField() + tp = StringField( + verbose_name="Type", + db_field="t", + choices=TPS) + + class Animal(Document): + name = StringField() + tag = StringField() + + owner = EmbeddedDocumentField(Owner) + + class Ocorrence(Document): + person = StringField() + animal = CachedReferenceField( + Animal, fields=['tag', 'owner.tp']) + + Animal.drop_collection() + Ocorrence.drop_collection() + + a = Animal(nam="Leopard", tag="heavy", + owner=Owner(tp='u', name="Wilson Júnior") + ) + a.save() + + o = Ocorrence(person="teste", animal=a) + o.save() + self.assertEqual(dict(a.to_mongo(fields=['tag', 'owner.tp'])), { + '_id': a.pk, + 'tag': 'heavy', + 'owner': { + 't': 'u' + } + }) + self.assertEqual(o.to_mongo()['animal']['tag'], 'heavy') + self.assertEqual(o.to_mongo()['animal']['owner']['t'], 'u') + + # counts + Ocorrence(person="teste 2").save() + Ocorrence(person="teste 3").save() + + count = Ocorrence.objects( + animal__tag='heavy', animal__owner__tp='u').count() + self.assertEqual(count, 1) + + ocorrence = Ocorrence.objects( + animal__tag='heavy', + animal__owner__tp='u').first() + self.assertEqual(ocorrence.person, "teste") + self.assertTrue(isinstance(ocorrence.animal, Animal)) + + def test_cached_reference_embedded_list_fields(self): + class Owner(EmbeddedDocument): + name = StringField() + tags = ListField(StringField()) + + class Animal(Document): + name = StringField() + tag = StringField() + + owner = EmbeddedDocumentField(Owner) + + class Ocorrence(Document): + person = StringField() + animal = CachedReferenceField( + Animal, fields=['tag', 'owner.tags']) + + Animal.drop_collection() + Ocorrence.drop_collection() + + a = Animal(nam="Leopard", tag="heavy", + owner=Owner(tags=['cool', 'funny'], + name="Wilson Júnior") + ) + a.save() + + o = Ocorrence(person="teste 2", animal=a) + o.save() + self.assertEqual(dict(a.to_mongo(fields=['tag', 'owner.tags'])), { + '_id': a.pk, + 'tag': 'heavy', + 'owner': { + 'tags': ['cool', 'funny'] + } + }) + + self.assertEqual(o.to_mongo()['animal']['tag'], 'heavy') + self.assertEqual(o.to_mongo()['animal']['owner']['tags'], + ['cool', 'funny']) + + # counts + Ocorrence(person="teste 2").save() + Ocorrence(person="teste 3").save() + + query = Ocorrence.objects( + animal__tag='heavy', animal__owner__tags='cool')._query + self.assertEqual( + query, {'animal.owner.tags': 'cool', 'animal.tag': 'heavy'}) + + ocorrence = Ocorrence.objects( + animal__tag='heavy', + animal__owner__tags='cool').first() + self.assertEqual(ocorrence.person, "teste 2") + self.assertTrue(isinstance(ocorrence.animal, Animal)) + def test_objectid_reference_fields(self): class Person(Document): @@ -1834,8 +2007,7 @@ class FieldTest(unittest.TestCase): Person(name="Wilson Jr").save() self.assertEqual(repr(Person.objects(city=None)), - "[]") - + "[]") def test_generic_reference_choices(self): """Ensure that a GenericReferenceField can handle choices @@ -1982,7 +2154,8 @@ class FieldTest(unittest.TestCase): attachment_required.blob = Binary(b('\xe6\x00\xc4\xff\x07')) attachment_required.validate() - attachment_size_limit = AttachmentSizeLimit(blob=b('\xe6\x00\xc4\xff\x07')) + attachment_size_limit = AttachmentSizeLimit( + blob=b('\xe6\x00\xc4\xff\x07')) self.assertRaises(ValidationError, attachment_size_limit.validate) attachment_size_limit.blob = b('\xe6\x00\xc4\xff') attachment_size_limit.validate() @@ -2030,8 +2203,8 @@ class FieldTest(unittest.TestCase): """ class Shirt(Document): size = StringField(max_length=3, choices=( - ('S', 'Small'), ('M', 'Medium'), ('L', 'Large'), - ('XL', 'Extra Large'), ('XXL', 'Extra Extra Large'))) + ('S', 'Small'), ('M', 'Medium'), ('L', 'Large'), + ('XL', 'Extra Large'), ('XXL', 'Extra Extra Large'))) style = StringField(max_length=3, choices=( ('S', 'Small'), ('B', 'Baggy'), ('W', 'wide')), default='S') @@ -2061,7 +2234,7 @@ class FieldTest(unittest.TestCase): """ class Shirt(Document): size = StringField(max_length=3, - choices=('S', 'M', 'L', 'XL', 'XXL')) + choices=('S', 'M', 'L', 'XL', 'XXL')) Shirt.drop_collection() @@ -2179,7 +2352,6 @@ class FieldTest(unittest.TestCase): c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) self.assertEqual(c['next'], 1000) - def test_sequence_field_get_next_value(self): class Person(Document): id = SequenceField(primary_key=True) @@ -2368,7 +2540,6 @@ class FieldTest(unittest.TestCase): self.assertEqual(1, post.comments[0].id) self.assertEqual(2, post.comments[1].id) - def test_generic_embedded_document(self): class Car(EmbeddedDocument): name = StringField() @@ -2478,7 +2649,7 @@ class FieldTest(unittest.TestCase): self.assertTrue('comments' in error.errors) self.assertTrue(1 in error.errors['comments']) self.assertTrue(isinstance(error.errors['comments'][1]['content'], - ValidationError)) + ValidationError)) # ValidationError.schema property error_dict = error.to_dict() @@ -2604,11 +2775,11 @@ class FieldTest(unittest.TestCase): DictFieldTest.drop_collection() test = DictFieldTest(dictionary=None) - test.dictionary # Just access to test getter + test.dictionary # Just access to test getter self.assertRaises(ValidationError, test.validate) test = DictFieldTest(dictionary=False) - test.dictionary # Just access to test getter + test.dictionary # Just access to test getter self.assertRaises(ValidationError, test.validate)