From 30fdd3e184165b722cd470fdbd470634374bff77 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wilson=20J=C3=BAnior?= Date: Mon, 7 Jul 2014 12:54:41 -0300 Subject: [PATCH 1/9] Added initial CachedReferenceField --- mongoengine/base/document.py | 83 +++++++---- mongoengine/fields.py | 206 +++++++++++++++++++++++----- tests/fields/fields.py | 257 +++++++++++++++++++++++++++++------ 3 files changed, 445 insertions(+), 101 deletions(-) diff --git a/mongoengine/base/document.py b/mongoengine/base/document.py index 961f5be1..593c237d 100644 --- a/mongoengine/base/document.py +++ b/mongoengine/base/document.py @@ -23,9 +23,10 @@ __all__ = ('BaseDocument', 'NON_FIELD_ERRORS') NON_FIELD_ERRORS = '__all__' + class BaseDocument(object): __slots__ = ('_changed_fields', '_initialised', '_created', '_data', - '_dynamic_fields', '_auto_id_field', '_db_field_map', '_cls', '__weakref__') + '_dynamic_fields', '_auto_id_field', '_db_field_map', '_cls', '__weakref__') _dynamic = False _dynamic_lock = True @@ -50,7 +51,8 @@ class BaseDocument(object): for value in args: name = next(field) if name in values: - raise TypeError("Multiple values for keyword argument '" + name + "'") + raise TypeError( + "Multiple values for keyword argument '" + name + "'") values[name] = value __auto_convert = values.pop("__auto_convert", True) signals.pre_init.send(self.__class__, document=self, values=values) @@ -58,7 +60,8 @@ class BaseDocument(object): if self.STRICT and not self._dynamic: self._data = StrictDict.create(allowed_keys=self._fields_ordered)() else: - self._data = SemiStrictDict.create(allowed_keys=self._fields_ordered)() + self._data = SemiStrictDict.create( + allowed_keys=self._fields_ordered)() _created = values.pop("_created", True) self._data = {} @@ -144,8 +147,8 @@ class BaseDocument(object): self__created = True if (self._is_document and not self__created and - name in self._meta.get('shard_key', tuple()) and - self._data.get(name) != value): + name in self._meta.get('shard_key', tuple()) and + self._data.get(name) != value): OperationError = _import_class('OperationError') msg = "Shard Keys are immutable. Tried to update %s" % name raise OperationError(msg) @@ -156,8 +159,8 @@ class BaseDocument(object): self__initialised = False # Check if the user has created a new instance of a class if (self._is_document and self__initialised - and self__created and name == self._meta['id_field']): - super(BaseDocument, self).__setattr__('_created', False) + and self__created and name == self._meta['id_field']): + super(BaseDocument, self).__setattr__('_created', False) super(BaseDocument, self).__setattr__(name, value) @@ -174,7 +177,7 @@ class BaseDocument(object): if isinstance(data["_data"], SON): data["_data"] = self.__class__._from_son(data["_data"])._data for k in ('_changed_fields', '_initialised', '_created', '_data', - '_dynamic_fields'): + '_dynamic_fields'): if k in data: setattr(self, k, data[k]) if '_fields_ordered' in data: @@ -257,26 +260,45 @@ class BaseDocument(object): """ pass - def to_mongo(self, use_db_field=True): - """Return as SON data ready for use with MongoDB. + def to_mongo(self, use_db_field=True, fields=[]): + """ + Return as SON data ready for use with MongoDB. """ data = SON() data["_id"] = None data['_cls'] = self._class_name + EmbeddedDocumentField = _import_class("EmbeddedDocumentField") + + # only root fields ['test1.a', 'test2'] => ['test1', 'test2'] + root_fields = set([f.split('.')[0] for f in fields]) for field_name in self: + if root_fields and field_name not in root_fields: + continue + value = self._data.get(field_name, None) field = self._fields.get(field_name) + if field is None and self._dynamic: field = self._dynamic_fields.get(field_name) if value is not None: EmbeddedDocument = _import_class("EmbeddedDocument") - if isinstance(value, (EmbeddedDocument)) and use_db_field==False: + if isinstance(value, (EmbeddedDocument)) and \ + not use_db_field: value = field.to_mongo(value, use_db_field) else: value = field.to_mongo(value) + if isinstance(field, EmbeddedDocumentField) and fields: + key = '%s.' % field_name + + value = field.to_mongo(value, fields=[ + i.replace(key, '') for i in fields if i.startswith(key)]) + + elif value is not None: + value = field.to_mongo(value) + # Handle self generating fields if value is None and field._auto_gen: value = field.generate() @@ -299,7 +321,7 @@ class BaseDocument(object): # Only add _cls if allow_inheritance is True if (not hasattr(self, '_meta') or - not self._meta.get('allow_inheritance', ALLOW_INHERITANCE)): + not self._meta.get('allow_inheritance', ALLOW_INHERITANCE)): data.pop('_cls') return data @@ -321,7 +343,8 @@ class BaseDocument(object): self._data.get(name)) for name in self._fields_ordered] EmbeddedDocumentField = _import_class("EmbeddedDocumentField") - GenericEmbeddedDocumentField = _import_class("GenericEmbeddedDocumentField") + GenericEmbeddedDocumentField = _import_class( + "GenericEmbeddedDocumentField") for field, value in fields: if value is not None: @@ -352,7 +375,8 @@ class BaseDocument(object): """Converts a document to JSON. :param use_db_field: Set to True by default but enables the output of the json structure with the field names and not the mongodb store db_names in case of set to False """ - use_db_field = kwargs.pop('use_db_field') if kwargs.has_key('use_db_field') else True + use_db_field = kwargs.pop('use_db_field') if kwargs.has_key( + 'use_db_field') else True return json_util.dumps(self.to_mongo(use_db_field), *args, **kwargs) @classmethod @@ -387,7 +411,7 @@ class BaseDocument(object): # Convert lists / values so we can watch for any changes on them if (isinstance(value, (list, tuple)) and - not isinstance(value, BaseList)): + not isinstance(value, BaseList)): value = BaseList(value, self, name) elif isinstance(value, dict) and not isinstance(value, BaseDict): value = BaseDict(value, self, name) @@ -452,9 +476,10 @@ class BaseDocument(object): if hasattr(value, '_get_changed_fields'): changed = value._get_changed_fields(inspected) changed_fields += ["%s%s" % (list_key, k) - for k in changed if k] + for k in changed if k] elif isinstance(value, (list, tuple, dict)): - self._nestable_types_changed_fields(changed_fields, list_key, value, inspected) + self._nestable_types_changed_fields( + changed_fields, list_key, value, inspected) def _get_changed_fields(self, inspected=None): """Returns a list of all fields that have explicitly been changed. @@ -484,16 +509,17 @@ class BaseDocument(object): if isinstance(field, ReferenceField): continue elif (isinstance(data, (EmbeddedDocument, DynamicEmbeddedDocument)) - and db_field_name not in changed_fields): + and db_field_name not in changed_fields): # Find all embedded fields that have been changed changed = data._get_changed_fields(inspected) changed_fields += ["%s%s" % (key, k) for k in changed if k] elif (isinstance(data, (list, tuple, dict)) and db_field_name not in changed_fields): if (hasattr(field, 'field') and - isinstance(field.field, ReferenceField)): + isinstance(field.field, ReferenceField)): continue - self._nestable_types_changed_fields(changed_fields, key, data, inspected) + self._nestable_types_changed_fields( + changed_fields, key, data, inspected) return changed_fields def _delta(self): @@ -539,7 +565,7 @@ class BaseDocument(object): # If we've set a value that ain't the default value dont unset it. default = None if (self._dynamic and len(parts) and parts[0] in - self._dynamic_fields): + self._dynamic_fields): del(set_data[path]) unset_data[path] = 1 continue @@ -625,13 +651,14 @@ class BaseDocument(object): if errors_dict: errors = "\n".join(["%s - %s" % (k, v) - for k, v in errors_dict.items()]) + for k, v in errors_dict.items()]) msg = ("Invalid data to create a `%s` instance.\n%s" % (cls._class_name, errors)) raise InvalidDocumentError(msg) if cls.STRICT: - data = dict((k, v) for k,v in data.iteritems() if k in cls._fields) + data = dict((k, v) + for k, v in data.iteritems() if k in cls._fields) obj = cls(__auto_convert=False, _created=False, **data) obj._changed_fields = changed_fields if not _auto_dereference: @@ -778,7 +805,7 @@ class BaseDocument(object): # Grab any embedded document field unique indexes if (field.__class__.__name__ == "EmbeddedDocumentField" and - field.document_type != cls): + field.document_type != cls): field_namespace = "%s." % field_name doc_cls = field.document_type unique_indexes += doc_cls._unique_with_indexes(field_namespace) @@ -794,7 +821,8 @@ class BaseDocument(object): geo_field_type_names = ["EmbeddedDocumentField", "GeoPointField", "PointField", "LineStringField", "PolygonField"] - geo_field_types = tuple([_import_class(field) for field in geo_field_type_names]) + geo_field_types = tuple([_import_class(field) + for field in geo_field_type_names]) for field in cls._fields.values(): if not isinstance(field, geo_field_types): @@ -804,13 +832,14 @@ class BaseDocument(object): if field_cls in inspected: continue if hasattr(field_cls, '_geo_indices'): - geo_indices += field_cls._geo_indices(inspected, parent_field=field.db_field) + geo_indices += field_cls._geo_indices( + inspected, parent_field=field.db_field) elif field._geo_index: field_name = field.db_field if parent_field: field_name = "%s.%s" % (parent_field, field_name) geo_indices.append({'fields': - [(field_name, field._geo_index)]}) + [(field_name, field._geo_index)]}) return geo_indices @classmethod diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 0320898b..22bb7423 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -34,22 +34,24 @@ except ImportError: Image = None ImageOps = None -__all__ = ['StringField', 'URLField', 'EmailField', 'IntField', 'LongField', - 'FloatField', 'DecimalField', 'BooleanField', 'DateTimeField', - 'ComplexDateTimeField', 'EmbeddedDocumentField', 'ObjectIdField', - 'GenericEmbeddedDocumentField', 'DynamicField', 'ListField', - 'SortedListField', 'DictField', 'MapField', 'ReferenceField', - 'GenericReferenceField', 'BinaryField', 'GridFSError', - 'GridFSProxy', 'FileField', 'ImageGridFsProxy', - 'ImproperlyConfigured', 'ImageField', 'GeoPointField', 'PointField', - 'LineStringField', 'PolygonField', 'SequenceField', 'UUIDField', - 'GeoJsonBaseField'] +__all__ = [ + 'StringField', 'URLField', 'EmailField', 'IntField', 'LongField', + 'FloatField', 'DecimalField', 'BooleanField', 'DateTimeField', + 'ComplexDateTimeField', 'EmbeddedDocumentField', 'ObjectIdField', + 'GenericEmbeddedDocumentField', 'DynamicField', 'ListField', + 'SortedListField', 'DictField', 'MapField', 'ReferenceField', + 'CachedReferenceField', 'GenericReferenceField', 'BinaryField', + 'GridFSError', 'GridFSProxy', 'FileField', 'ImageGridFsProxy', + 'ImproperlyConfigured', 'ImageField', 'GeoPointField', 'PointField', + 'LineStringField', 'PolygonField', 'SequenceField', 'UUIDField', + 'GeoJsonBaseField'] RECURSIVE_REFERENCE_CONSTANT = 'self' class StringField(BaseField): + """A unicode string field. """ @@ -109,6 +111,7 @@ class StringField(BaseField): class URLField(StringField): + """A field that validates input as an URL. .. versionadded:: 0.3 @@ -116,7 +119,8 @@ class URLField(StringField): _URL_REGEX = re.compile( r'^(?:http|ftp)s?://' # http:// or https:// - r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|' # domain... + # domain... + r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|' r'localhost|' # localhost... r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})' # ...or ip r'(?::\d+)?' # optional port @@ -145,15 +149,19 @@ class URLField(StringField): class EmailField(StringField): + """A field that validates input as an E-Mail-Address. .. versionadded:: 0.4 """ EMAIL_REGEX = re.compile( - r"(^[-!#$%&'*+/=?^_`{}|~0-9A-Z]+(\.[-!#$%&'*+/=?^_`{}|~0-9A-Z]+)*" # dot-atom - r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]|\\[\001-011\013\014\016-\177])*"' # quoted-string - r')@(?:[A-Z0-9](?:[A-Z0-9-]{0,253}[A-Z0-9])?\.)+[A-Z]{2,6}$', re.IGNORECASE # domain + # dot-atom + r"(^[-!#$%&'*+/=?^_`{}|~0-9A-Z]+(\.[-!#$%&'*+/=?^_`{}|~0-9A-Z]+)*" + # quoted-string + r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]|\\[\001-011\013\014\016-\177])*"' + # domain + r')@(?:[A-Z0-9](?:[A-Z0-9-]{0,253}[A-Z0-9])?\.)+[A-Z]{2,6}$', re.IGNORECASE ) def validate(self, value): @@ -163,6 +171,7 @@ class EmailField(StringField): class IntField(BaseField): + """An 32-bit integer field. """ @@ -197,6 +206,7 @@ class IntField(BaseField): class LongField(BaseField): + """An 64-bit integer field. """ @@ -231,6 +241,7 @@ class LongField(BaseField): class FloatField(BaseField): + """An floating point number field. """ @@ -265,6 +276,7 @@ class FloatField(BaseField): class DecimalField(BaseField): + """A fixed-point decimal number field. .. versionchanged:: 0.8 @@ -338,6 +350,7 @@ class DecimalField(BaseField): class BooleanField(BaseField): + """A boolean field type. .. versionadded:: 0.1.2 @@ -356,6 +369,7 @@ class BooleanField(BaseField): class DateTimeField(BaseField): + """A datetime field. Uses the python-dateutil library if available alternatively use time.strptime @@ -406,15 +420,15 @@ class DateTimeField(BaseField): kwargs = {'microsecond': usecs} try: # Seconds are optional, so try converting seconds first. return datetime.datetime(*time.strptime(value, - '%Y-%m-%d %H:%M:%S')[:6], **kwargs) + '%Y-%m-%d %H:%M:%S')[:6], **kwargs) except ValueError: try: # Try without seconds. return datetime.datetime(*time.strptime(value, - '%Y-%m-%d %H:%M')[:5], **kwargs) + '%Y-%m-%d %H:%M')[:5], **kwargs) except ValueError: # Try without hour/minutes/seconds. try: return datetime.datetime(*time.strptime(value, - '%Y-%m-%d')[:3], **kwargs) + '%Y-%m-%d')[:3], **kwargs) except ValueError: return None @@ -423,6 +437,7 @@ class DateTimeField(BaseField): class ComplexDateTimeField(StringField): + """ ComplexDateTimeField handles microseconds exactly instead of rounding like DateTimeField does. @@ -525,6 +540,7 @@ class ComplexDateTimeField(StringField): class EmbeddedDocumentField(BaseField): + """An embedded document field - with a declared document_type. Only valid values are subclasses of :class:`~mongoengine.EmbeddedDocument`. """ @@ -551,7 +567,7 @@ class EmbeddedDocumentField(BaseField): return self.document_type._from_son(value) return value - def to_mongo(self, value, use_db_field=True): + def to_mongo(self, value, use_db_field=True, fields=[]): if not isinstance(value, self.document_type): return value return self.document_type.to_mongo(value, use_db_field) @@ -574,6 +590,7 @@ class EmbeddedDocumentField(BaseField): class GenericEmbeddedDocumentField(BaseField): + """A generic embedded document field - allows any :class:`~mongoengine.EmbeddedDocument` to be stored. @@ -612,6 +629,7 @@ class GenericEmbeddedDocumentField(BaseField): class DynamicField(BaseField): + """A truly dynamic field type capable of handling different and varying types of data. @@ -675,6 +693,7 @@ class DynamicField(BaseField): class ListField(ComplexBaseField): + """A list field that wraps a standard field, allowing multiple instances of the field to be used as a list in the database. @@ -693,21 +712,22 @@ class ListField(ComplexBaseField): """Make sure that a list of valid fields is being used. """ if (not isinstance(value, (list, tuple, QuerySet)) or - isinstance(value, basestring)): + isinstance(value, basestring)): self.error('Only lists and tuples may be used in a list field') super(ListField, self).validate(value) def prepare_query_value(self, op, value): if self.field: if op in ('set', 'unset') and (not isinstance(value, basestring) - and not isinstance(value, BaseDocument) - and hasattr(value, '__iter__')): + and not isinstance(value, BaseDocument) + and hasattr(value, '__iter__')): return [self.field.prepare_query_value(op, v) for v in value] return self.field.prepare_query_value(op, value) return super(ListField, self).prepare_query_value(op, value) class SortedListField(ListField): + """A ListField that sorts the contents of its list before writing to the database in order to ensure that a sorted list is always retrieved. @@ -739,6 +759,7 @@ class SortedListField(ListField): reverse=self._order_reverse) return sorted(value, reverse=self._order_reverse) + def key_not_string(d): """ Helper function to recursively determine if any key in a dictionary is not a string. @@ -747,6 +768,7 @@ def key_not_string(d): if not isinstance(k, basestring) or (isinstance(v, dict) and key_not_string(v)): return True + def key_has_dot_or_dollar(d): """ Helper function to recursively determine if any key in a dictionary contains a dot or a dollar sign. @@ -755,7 +777,9 @@ def key_has_dot_or_dollar(d): if ('.' in k or '$' in k) or (isinstance(v, dict) and key_has_dot_or_dollar(v)): return True + class DictField(ComplexBaseField): + """A dictionary field that wraps a standard Python dictionary. This is similar to an embedded document, but the structure is not defined. @@ -807,6 +831,7 @@ class DictField(ComplexBaseField): class MapField(DictField): + """A field that maps a name to a specified field type. Similar to a DictField, except the 'value' of each item must match the specified field type. @@ -822,6 +847,7 @@ class MapField(DictField): class ReferenceField(BaseField): + """A reference to a document that will be automatically dereferenced on access (lazily). @@ -932,7 +958,7 @@ class ReferenceField(BaseField): """Convert a MongoDB-compatible type to a Python type. """ if (not self.dbref and - not isinstance(value, (DBRef, Document, EmbeddedDocument))): + not isinstance(value, (DBRef, Document, EmbeddedDocument))): collection = self.document_type._get_collection_name() value = DBRef(collection, self.document_type.id.to_python(value)) return value @@ -955,7 +981,106 @@ class ReferenceField(BaseField): return self.document_type._fields.get(member_name) +class CachedReferenceField(BaseField): + + """ + A referencefield with cache fields support + .. versionadded:: 0.9 + """ + + def __init__(self, document_type, fields=[], **kwargs): + """Initialises the Cached Reference Field. + + :param fields: A list of fields to be cached in document + """ + if not isinstance(document_type, basestring) and \ + not issubclass(document_type, (Document, basestring)): + + self.error('Argument to CachedReferenceField constructor must be a' + ' document class or a string') + + self.document_type_obj = document_type + self.fields = fields + super(CachedReferenceField, self).__init__(**kwargs) + + def to_python(self, value): + """Convert a MongoDB-compatible type to a Python type. + """ + if isinstance(value, dict): + collection = self.document_type._get_collection_name() + value = DBRef( + collection, self.document_type.id.to_python(value['_id'])) + + return value + + @property + def document_type(self): + if isinstance(self.document_type_obj, basestring): + if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT: + self.document_type_obj = self.owner_document + else: + self.document_type_obj = get_document(self.document_type_obj) + return self.document_type_obj + + def __get__(self, instance, owner): + """Descriptor to allow lazy dereferencing. + """ + if instance is None: + # Document class being used rather than a document object + return self + + # Get value from document instance if available + value = instance._data.get(self.name) + self._auto_dereference = instance._fields[self.name]._auto_dereference + # Dereference DBRefs + if self._auto_dereference and isinstance(value, DBRef): + value = self.document_type._get_db().dereference(value) + if value is not None: + instance._data[self.name] = self.document_type._from_son(value) + + return super(CachedReferenceField, self).__get__(instance, owner) + + def to_mongo(self, document): + id_field_name = self.document_type._meta['id_field'] + id_field = self.document_type._fields[id_field_name] + doc_tipe = self.document_type + + if isinstance(document, Document): + # We need the id from the saved object to create the DBRef + id_ = document.pk + if id_ is None: + self.error('You can only reference documents once they have' + ' been saved to the database') + else: + self.error('Only accept a document object') + + value = { + "_id": id_field.to_mongo(id_) + } + + value.update(dict(document.to_mongo(fields=self.fields))) + return value + + def prepare_query_value(self, op, value): + if value is None: + return None + return self.to_mongo(value) + + def validate(self, value): + + if not isinstance(value, (self.document_type)): + self.error("A CachedReferenceField only accepts documents") + + if isinstance(value, Document) and value.id is None: + self.error('You can only reference documents once they have been ' + 'saved to the database') + + def lookup_member(self, member_name): + return self.document_type._fields.get(member_name) + + class GenericReferenceField(BaseField): + """A reference to *any* :class:`~mongoengine.document.Document` subclass that will be automatically dereferenced on access (lazily). @@ -974,6 +1099,7 @@ class GenericReferenceField(BaseField): return self value = instance._data.get(self.name) + self._auto_dereference = instance._fields[self.name]._auto_dereference if self._auto_dereference and isinstance(value, (dict, SON)): instance._data[self.name] = self.dereference(value) @@ -1036,6 +1162,7 @@ class GenericReferenceField(BaseField): class BinaryField(BaseField): + """A binary data field. """ @@ -1056,7 +1183,7 @@ class BinaryField(BaseField): if not isinstance(value, (bin_type, txt_type, Binary)): self.error("BinaryField only accepts instances of " "(%s, %s, Binary)" % ( - bin_type.__name__, txt_type.__name__)) + bin_type.__name__, txt_type.__name__)) if self.max_bytes is not None and len(value) > self.max_bytes: self.error('Binary value is too long') @@ -1067,6 +1194,7 @@ class GridFSError(Exception): class GridFSProxy(object): + """Proxy object to handle writing and reading of files to and from GridFS .. versionadded:: 0.4 @@ -1121,7 +1249,8 @@ class GridFSProxy(object): return '<%s: %s>' % (self.__class__.__name__, self.grid_id) def __str__(self): - name = getattr(self.get(), 'filename', self.grid_id) if self.get() else '(no file)' + name = getattr( + self.get(), 'filename', self.grid_id) if self.get() else '(no file)' return '<%s: %s>' % (self.__class__.__name__, name) def __eq__(self, other): @@ -1135,7 +1264,8 @@ class GridFSProxy(object): @property def fs(self): if not self._fs: - self._fs = gridfs.GridFS(get_db(self.db_alias), self.collection_name) + self._fs = gridfs.GridFS( + get_db(self.db_alias), self.collection_name) return self._fs def get(self, id=None): @@ -1209,6 +1339,7 @@ class GridFSProxy(object): class FileField(BaseField): + """A GridFS storage field. .. versionadded:: 0.4 @@ -1253,7 +1384,8 @@ class FileField(BaseField): pass # Create a new proxy object as we don't already have one - instance._data[key] = self.get_proxy_obj(key=key, instance=instance) + instance._data[key] = self.get_proxy_obj( + key=key, instance=instance) instance._data[key].put(value) else: instance._data[key] = value @@ -1291,11 +1423,13 @@ class FileField(BaseField): class ImageGridFsProxy(GridFSProxy): + """ Proxy for ImageField versionadded: 0.6 """ + def put(self, file_obj, **kwargs): """ Insert a image in database @@ -1341,7 +1475,8 @@ class ImageGridFsProxy(GridFSProxy): size = field.thumbnail_size if size['force']: - thumbnail = ImageOps.fit(img, (size['width'], size['height']), Image.ANTIALIAS) + thumbnail = ImageOps.fit( + img, (size['width'], size['height']), Image.ANTIALIAS) else: thumbnail = img.copy() thumbnail.thumbnail((size['width'], @@ -1367,7 +1502,7 @@ class ImageGridFsProxy(GridFSProxy): **kwargs) def delete(self, *args, **kwargs): - #deletes thumbnail + # deletes thumbnail out = self.get() if out and out.thumbnail_id: self.fs.delete(out.thumbnail_id) @@ -1427,6 +1562,7 @@ class ImproperlyConfigured(Exception): class ImageField(FileField): + """ A Image File storage field. @@ -1465,6 +1601,7 @@ class ImageField(FileField): class SequenceField(BaseField): + """Provides a sequental counter see: http://www.mongodb.org/display/DOCS/Object+IDs#ObjectIDs-SequenceNumbers @@ -1534,7 +1671,7 @@ class SequenceField(BaseField): data = collection.find_one({"_id": sequence_id}) if data: - return self.value_decorator(data['next']+1) + return self.value_decorator(data['next'] + 1) return self.value_decorator(1) @@ -1579,6 +1716,7 @@ class SequenceField(BaseField): class UUIDField(BaseField): + """A UUID field. .. versionadded:: 0.6 @@ -1631,6 +1769,7 @@ class UUIDField(BaseField): class GeoPointField(BaseField): + """A list storing a longitude and latitude coordinate. .. note:: this represents a generic point in a 2D plane and a legacy way of @@ -1651,13 +1790,16 @@ class GeoPointField(BaseField): 'of (x, y)') if not len(value) == 2: - self.error("Value (%s) must be a two-dimensional point" % repr(value)) + self.error("Value (%s) must be a two-dimensional point" % + repr(value)) elif (not isinstance(value[0], (float, int)) or not isinstance(value[1], (float, int))): - self.error("Both values (%s) in point must be float or int" % repr(value)) + self.error( + "Both values (%s) in point must be float or int" % repr(value)) class PointField(GeoJsonBaseField): + """A GeoJSON field storing a longitude and latitude coordinate. The data is represented as: @@ -1677,6 +1819,7 @@ class PointField(GeoJsonBaseField): class LineStringField(GeoJsonBaseField): + """A GeoJSON field storing a line of longitude and latitude coordinates. The data is represented as: @@ -1695,6 +1838,7 @@ class LineStringField(GeoJsonBaseField): class PolygonField(GeoJsonBaseField): + """A GeoJSON field storing a polygon of longitude and latitude coordinates. The data is represented as: diff --git a/tests/fields/fields.py b/tests/fields/fields.py index c108f37e..5da06981 100644 --- a/tests/fields/fields.py +++ b/tests/fields/fields.py @@ -47,7 +47,8 @@ class FieldTest(unittest.TestCase): # Confirm saving now would store values data_to_be_saved = sorted(person.to_mongo().keys()) - self.assertEqual(data_to_be_saved, ['age', 'created', 'name', 'userid']) + self.assertEqual( + data_to_be_saved, ['age', 'created', 'name', 'userid']) self.assertTrue(person.validate() is None) @@ -63,7 +64,8 @@ class FieldTest(unittest.TestCase): # Confirm introspection changes nothing data_to_be_saved = sorted(person.to_mongo().keys()) - self.assertEqual(data_to_be_saved, ['age', 'created', 'name', 'userid']) + self.assertEqual( + data_to_be_saved, ['age', 'created', 'name', 'userid']) def test_default_values_set_to_None(self): """Ensure that default field values are used when creating a document. @@ -587,7 +589,8 @@ class FieldTest(unittest.TestCase): LogEntry.drop_collection() - # Post UTC - microseconds are rounded (down) nearest millisecond and dropped + # Post UTC - microseconds are rounded (down) nearest millisecond and + # dropped d1 = datetime.datetime(1970, 01, 01, 00, 00, 01, 999) d2 = datetime.datetime(1970, 01, 01, 00, 00, 01) log = LogEntry() @@ -688,7 +691,8 @@ class FieldTest(unittest.TestCase): LogEntry.drop_collection() - # Post UTC - microseconds are rounded (down) nearest millisecond and dropped - with default datetimefields + # Post UTC - microseconds are rounded (down) nearest millisecond and + # dropped - with default datetimefields d1 = datetime.datetime(1970, 01, 01, 00, 00, 01, 999) log = LogEntry() log.date = d1 @@ -696,14 +700,16 @@ class FieldTest(unittest.TestCase): log.reload() self.assertEqual(log.date, d1) - # Post UTC - microseconds are rounded (down) nearest millisecond - with default datetimefields + # Post UTC - microseconds are rounded (down) nearest millisecond - with + # default datetimefields d1 = datetime.datetime(1970, 01, 01, 00, 00, 01, 9999) log.date = d1 log.save() log.reload() self.assertEqual(log.date, d1) - # Pre UTC dates microseconds below 1000 are dropped - with default datetimefields + # Pre UTC dates microseconds below 1000 are dropped - with default + # datetimefields d1 = datetime.datetime(1969, 12, 31, 23, 59, 59, 999) log.date = d1 log.save() @@ -929,12 +935,16 @@ class FieldTest(unittest.TestCase): post.save() self.assertEqual(BlogPost.objects.count(), 3) - self.assertEqual(BlogPost.objects.filter(info__exact='test').count(), 1) - self.assertEqual(BlogPost.objects.filter(info__0__test='test').count(), 1) + self.assertEqual( + BlogPost.objects.filter(info__exact='test').count(), 1) + self.assertEqual( + BlogPost.objects.filter(info__0__test='test').count(), 1) # Confirm handles non strings or non existing keys - self.assertEqual(BlogPost.objects.filter(info__0__test__exact='5').count(), 0) - self.assertEqual(BlogPost.objects.filter(info__100__test__exact='test').count(), 0) + self.assertEqual( + BlogPost.objects.filter(info__0__test__exact='5').count(), 0) + self.assertEqual( + BlogPost.objects.filter(info__100__test__exact='test').count(), 0) BlogPost.drop_collection() def test_list_field_passed_in_value(self): @@ -951,7 +961,6 @@ class FieldTest(unittest.TestCase): foo.bars.append(bar) self.assertEqual(repr(foo.bars), '[]') - def test_list_field_strict(self): """Ensure that list field handles validation if provided a strict field type.""" @@ -1082,20 +1091,28 @@ class FieldTest(unittest.TestCase): self.assertTrue(isinstance(e2.mapping[1], IntegerSetting)) # Test querying - self.assertEqual(Simple.objects.filter(mapping__1__value=42).count(), 1) - self.assertEqual(Simple.objects.filter(mapping__2__number=1).count(), 1) - self.assertEqual(Simple.objects.filter(mapping__2__complex__value=42).count(), 1) - self.assertEqual(Simple.objects.filter(mapping__2__list__0__value=42).count(), 1) - self.assertEqual(Simple.objects.filter(mapping__2__list__1__value='foo').count(), 1) + self.assertEqual( + Simple.objects.filter(mapping__1__value=42).count(), 1) + self.assertEqual( + Simple.objects.filter(mapping__2__number=1).count(), 1) + self.assertEqual( + Simple.objects.filter(mapping__2__complex__value=42).count(), 1) + self.assertEqual( + Simple.objects.filter(mapping__2__list__0__value=42).count(), 1) + self.assertEqual( + Simple.objects.filter(mapping__2__list__1__value='foo').count(), 1) # Confirm can update Simple.objects().update(set__mapping__1=IntegerSetting(value=10)) - self.assertEqual(Simple.objects.filter(mapping__1__value=10).count(), 1) + self.assertEqual( + Simple.objects.filter(mapping__1__value=10).count(), 1) Simple.objects().update( set__mapping__2__list__1=StringSetting(value='Boo')) - self.assertEqual(Simple.objects.filter(mapping__2__list__1__value='foo').count(), 0) - self.assertEqual(Simple.objects.filter(mapping__2__list__1__value='Boo').count(), 1) + self.assertEqual( + Simple.objects.filter(mapping__2__list__1__value='foo').count(), 0) + self.assertEqual( + Simple.objects.filter(mapping__2__list__1__value='Boo').count(), 1) Simple.drop_collection() @@ -1141,12 +1158,16 @@ class FieldTest(unittest.TestCase): post.save() self.assertEqual(BlogPost.objects.count(), 3) - self.assertEqual(BlogPost.objects.filter(info__title__exact='test').count(), 1) - self.assertEqual(BlogPost.objects.filter(info__details__test__exact='test').count(), 1) + self.assertEqual( + BlogPost.objects.filter(info__title__exact='test').count(), 1) + self.assertEqual( + BlogPost.objects.filter(info__details__test__exact='test').count(), 1) # Confirm handles non strings or non existing keys - self.assertEqual(BlogPost.objects.filter(info__details__test__exact=5).count(), 0) - self.assertEqual(BlogPost.objects.filter(info__made_up__test__exact='test').count(), 0) + self.assertEqual( + BlogPost.objects.filter(info__details__test__exact=5).count(), 0) + self.assertEqual( + BlogPost.objects.filter(info__made_up__test__exact='test').count(), 0) post = BlogPost.objects.create(info={'title': 'original'}) post.info.update({'title': 'updated'}) @@ -1207,19 +1228,26 @@ class FieldTest(unittest.TestCase): self.assertTrue(isinstance(e2.mapping['someint'], IntegerSetting)) # Test querying - self.assertEqual(Simple.objects.filter(mapping__someint__value=42).count(), 1) - self.assertEqual(Simple.objects.filter(mapping__nested_dict__number=1).count(), 1) - self.assertEqual(Simple.objects.filter(mapping__nested_dict__complex__value=42).count(), 1) - self.assertEqual(Simple.objects.filter(mapping__nested_dict__list__0__value=42).count(), 1) - self.assertEqual(Simple.objects.filter(mapping__nested_dict__list__1__value='foo').count(), 1) + self.assertEqual( + Simple.objects.filter(mapping__someint__value=42).count(), 1) + self.assertEqual( + Simple.objects.filter(mapping__nested_dict__number=1).count(), 1) + self.assertEqual( + Simple.objects.filter(mapping__nested_dict__complex__value=42).count(), 1) + self.assertEqual( + Simple.objects.filter(mapping__nested_dict__list__0__value=42).count(), 1) + self.assertEqual( + Simple.objects.filter(mapping__nested_dict__list__1__value='foo').count(), 1) # Confirm can update Simple.objects().update( set__mapping={"someint": IntegerSetting(value=10)}) Simple.objects().update( set__mapping__nested_dict__list__1=StringSetting(value='Boo')) - self.assertEqual(Simple.objects.filter(mapping__nested_dict__list__1__value='foo').count(), 0) - self.assertEqual(Simple.objects.filter(mapping__nested_dict__list__1__value='Boo').count(), 1) + self.assertEqual( + Simple.objects.filter(mapping__nested_dict__list__1__value='foo').count(), 0) + self.assertEqual( + Simple.objects.filter(mapping__nested_dict__list__1__value='Boo').count(), 1) Simple.drop_collection() @@ -1290,7 +1318,7 @@ class FieldTest(unittest.TestCase): class Test(Document): my_map = MapField(field=EmbeddedDocumentField(Embedded), - db_field='x') + db_field='x') Test.drop_collection() @@ -1334,7 +1362,7 @@ class FieldTest(unittest.TestCase): Log(name="wilson", visited={'friends': datetime.datetime.now()}).save() self.assertEqual(1, Log.objects( - visited__friends__exists=True).count()) + visited__friends__exists=True).count()) def test_embedded_db_field(self): @@ -1477,6 +1505,151 @@ class FieldTest(unittest.TestCase): mongoed = p1.to_mongo() self.assertTrue(isinstance(mongoed['parent'], ObjectId)) + def test_cached_reference_fields(self): + class Animal(Document): + name = StringField() + tag = StringField() + + class Ocorrence(Document): + person = StringField() + animal = CachedReferenceField( + Animal, fields=['tag']) + + Animal.drop_collection() + Ocorrence.drop_collection() + + a = Animal(nam="Leopard", tag="heavy") + a.save() + + o = Ocorrence(person="teste", animal=a) + o.save() + self.assertEqual( + a.to_mongo(fields=['tag']), {'tag': 'heavy', "_id": a.pk}) + + self.assertEqual(o.to_mongo()['animal']['tag'], 'heavy') + + # counts + Ocorrence(person="teste 2").save() + Ocorrence(person="teste 3").save() + + count = Ocorrence.objects(animal__tag='heavy').count() + self.assertEqual(count, 1) + + ocorrence = Ocorrence.objects(animal__tag='heavy').first() + self.assertEqual(ocorrence.person, "teste") + self.assertTrue(isinstance(ocorrence.animal, Animal)) + + def test_cached_reference_embedded_fields(self): + class Owner(EmbeddedDocument): + TPS = ( + ('n', "Normal"), + ('u', "Urgent") + ) + name = StringField() + tp = StringField( + verbose_name="Type", + db_field="t", + choices=TPS) + + class Animal(Document): + name = StringField() + tag = StringField() + + owner = EmbeddedDocumentField(Owner) + + class Ocorrence(Document): + person = StringField() + animal = CachedReferenceField( + Animal, fields=['tag', 'owner.tp']) + + Animal.drop_collection() + Ocorrence.drop_collection() + + a = Animal(nam="Leopard", tag="heavy", + owner=Owner(tp='u', name="Wilson Júnior") + ) + a.save() + + o = Ocorrence(person="teste", animal=a) + o.save() + self.assertEqual(dict(a.to_mongo(fields=['tag', 'owner.tp'])), { + '_id': a.pk, + 'tag': 'heavy', + 'owner': { + 't': 'u' + } + }) + self.assertEqual(o.to_mongo()['animal']['tag'], 'heavy') + self.assertEqual(o.to_mongo()['animal']['owner']['t'], 'u') + + # counts + Ocorrence(person="teste 2").save() + Ocorrence(person="teste 3").save() + + count = Ocorrence.objects( + animal__tag='heavy', animal__owner__tp='u').count() + self.assertEqual(count, 1) + + ocorrence = Ocorrence.objects( + animal__tag='heavy', + animal__owner__tp='u').first() + self.assertEqual(ocorrence.person, "teste") + self.assertTrue(isinstance(ocorrence.animal, Animal)) + + def test_cached_reference_embedded_list_fields(self): + class Owner(EmbeddedDocument): + name = StringField() + tags = ListField(StringField()) + + class Animal(Document): + name = StringField() + tag = StringField() + + owner = EmbeddedDocumentField(Owner) + + class Ocorrence(Document): + person = StringField() + animal = CachedReferenceField( + Animal, fields=['tag', 'owner.tags']) + + Animal.drop_collection() + Ocorrence.drop_collection() + + a = Animal(nam="Leopard", tag="heavy", + owner=Owner(tags=['cool', 'funny'], + name="Wilson Júnior") + ) + a.save() + + o = Ocorrence(person="teste 2", animal=a) + o.save() + self.assertEqual(dict(a.to_mongo(fields=['tag', 'owner.tags'])), { + '_id': a.pk, + 'tag': 'heavy', + 'owner': { + 'tags': ['cool', 'funny'] + } + }) + + self.assertEqual(o.to_mongo()['animal']['tag'], 'heavy') + self.assertEqual(o.to_mongo()['animal']['owner']['tags'], + ['cool', 'funny']) + + # counts + Ocorrence(person="teste 2").save() + Ocorrence(person="teste 3").save() + + query = Ocorrence.objects( + animal__tag='heavy', animal__owner__tags='cool')._query + self.assertEqual( + query, {'animal.owner.tags': 'cool', 'animal.tag': 'heavy'}) + + ocorrence = Ocorrence.objects( + animal__tag='heavy', + animal__owner__tags='cool').first() + self.assertEqual(ocorrence.person, "teste 2") + self.assertTrue(isinstance(ocorrence.animal, Animal)) + def test_objectid_reference_fields(self): class Person(Document): @@ -1834,8 +2007,7 @@ class FieldTest(unittest.TestCase): Person(name="Wilson Jr").save() self.assertEqual(repr(Person.objects(city=None)), - "[]") - + "[]") def test_generic_reference_choices(self): """Ensure that a GenericReferenceField can handle choices @@ -1982,7 +2154,8 @@ class FieldTest(unittest.TestCase): attachment_required.blob = Binary(b('\xe6\x00\xc4\xff\x07')) attachment_required.validate() - attachment_size_limit = AttachmentSizeLimit(blob=b('\xe6\x00\xc4\xff\x07')) + attachment_size_limit = AttachmentSizeLimit( + blob=b('\xe6\x00\xc4\xff\x07')) self.assertRaises(ValidationError, attachment_size_limit.validate) attachment_size_limit.blob = b('\xe6\x00\xc4\xff') attachment_size_limit.validate() @@ -2030,8 +2203,8 @@ class FieldTest(unittest.TestCase): """ class Shirt(Document): size = StringField(max_length=3, choices=( - ('S', 'Small'), ('M', 'Medium'), ('L', 'Large'), - ('XL', 'Extra Large'), ('XXL', 'Extra Extra Large'))) + ('S', 'Small'), ('M', 'Medium'), ('L', 'Large'), + ('XL', 'Extra Large'), ('XXL', 'Extra Extra Large'))) style = StringField(max_length=3, choices=( ('S', 'Small'), ('B', 'Baggy'), ('W', 'wide')), default='S') @@ -2061,7 +2234,7 @@ class FieldTest(unittest.TestCase): """ class Shirt(Document): size = StringField(max_length=3, - choices=('S', 'M', 'L', 'XL', 'XXL')) + choices=('S', 'M', 'L', 'XL', 'XXL')) Shirt.drop_collection() @@ -2179,7 +2352,6 @@ class FieldTest(unittest.TestCase): c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) self.assertEqual(c['next'], 1000) - def test_sequence_field_get_next_value(self): class Person(Document): id = SequenceField(primary_key=True) @@ -2368,7 +2540,6 @@ class FieldTest(unittest.TestCase): self.assertEqual(1, post.comments[0].id) self.assertEqual(2, post.comments[1].id) - def test_generic_embedded_document(self): class Car(EmbeddedDocument): name = StringField() @@ -2478,7 +2649,7 @@ class FieldTest(unittest.TestCase): self.assertTrue('comments' in error.errors) self.assertTrue(1 in error.errors['comments']) self.assertTrue(isinstance(error.errors['comments'][1]['content'], - ValidationError)) + ValidationError)) # ValidationError.schema property error_dict = error.to_dict() @@ -2604,11 +2775,11 @@ class FieldTest(unittest.TestCase): DictFieldTest.drop_collection() test = DictFieldTest(dictionary=None) - test.dictionary # Just access to test getter + test.dictionary # Just access to test getter self.assertRaises(ValidationError, test.validate) test = DictFieldTest(dictionary=False) - test.dictionary # Just access to test getter + test.dictionary # Just access to test getter self.assertRaises(ValidationError, test.validate) From 73549a904484c1f549a2079d7aa05f4d218de279 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wilson=20J=C3=BAnior?= Date: Thu, 17 Jul 2014 09:41:06 -0300 Subject: [PATCH 2/9] fixes for rebase branch --- mongoengine/base/document.py | 27 +++++++++++++-------------- mongoengine/fields.py | 3 ++- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/mongoengine/base/document.py b/mongoengine/base/document.py index 593c237d..7509fffe 100644 --- a/mongoengine/base/document.py +++ b/mongoengine/base/document.py @@ -268,7 +268,6 @@ class BaseDocument(object): data["_id"] = None data['_cls'] = self._class_name EmbeddedDocumentField = _import_class("EmbeddedDocumentField") - # only root fields ['test1.a', 'test2'] => ['test1', 'test2'] root_fields = set([f.split('.')[0] for f in fields]) @@ -283,22 +282,22 @@ class BaseDocument(object): field = self._dynamic_fields.get(field_name) if value is not None: - EmbeddedDocument = _import_class("EmbeddedDocument") - if isinstance(value, (EmbeddedDocument)) and \ - not use_db_field: - value = field.to_mongo(value, use_db_field) + + if isinstance(field, (EmbeddedDocumentField)): + if fields: + key = '%s.' % field_name + embedded_fields = [ + i.replace(key, '') for i in fields + if i.startswith(key)] + + else: + embedded_fields = [] + + value = field.to_mongo(value, use_db_field=use_db_field, + fields=embedded_fields) else: value = field.to_mongo(value) - if isinstance(field, EmbeddedDocumentField) and fields: - key = '%s.' % field_name - - value = field.to_mongo(value, fields=[ - i.replace(key, '') for i in fields if i.startswith(key)]) - - elif value is not None: - value = field.to_mongo(value) - # Handle self generating fields if value is None and field._auto_gen: value = field.generate() diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 22bb7423..58271435 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -570,7 +570,8 @@ class EmbeddedDocumentField(BaseField): def to_mongo(self, value, use_db_field=True, fields=[]): if not isinstance(value, self.document_type): return value - return self.document_type.to_mongo(value, use_db_field) + return self.document_type.to_mongo(value, use_db_field, + fields=fields) def validate(self, value, clean=True): """Make sure that the document instance is an instance of the From 6c4aee147933dab020cced417bd556411becb3da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wilson=20J=C3=BAnior?= Date: Thu, 17 Jul 2014 13:42:34 -0300 Subject: [PATCH 3/9] added CachedReferenceField restriction to use in EmbeddedDocument --- mongoengine/base/metaclasses.py | 40 ++++++++++++++++++++------------- mongoengine/common.py | 1 + tests/fields/fields.py | 12 ++++++++++ 3 files changed, 38 insertions(+), 15 deletions(-) diff --git a/mongoengine/base/metaclasses.py b/mongoengine/base/metaclasses.py index 4b2e8b9b..887c9abc 100644 --- a/mongoengine/base/metaclasses.py +++ b/mongoengine/base/metaclasses.py @@ -16,6 +16,7 @@ __all__ = ('DocumentMetaclass', 'TopLevelDocumentMetaclass') class DocumentMetaclass(type): + """Metaclass for all documents. """ @@ -90,7 +91,7 @@ class DocumentMetaclass(type): # Set _fields and db_field maps attrs['_fields'] = doc_fields attrs['_db_field_map'] = dict([(k, getattr(v, 'db_field', k)) - for k, v in doc_fields.iteritems()]) + for k, v in doc_fields.iteritems()]) attrs['_reverse_db_field_map'] = dict( (v, k) for k, v in attrs['_db_field_map'].iteritems()) @@ -105,7 +106,7 @@ class DocumentMetaclass(type): class_name = [name] for base in flattened_bases: if (not getattr(base, '_is_base_cls', True) and - not getattr(base, '_meta', {}).get('abstract', True)): + not getattr(base, '_meta', {}).get('abstract', True)): # Collate heirarchy for _cls and _subclasses class_name.append(base.__name__) @@ -115,7 +116,7 @@ class DocumentMetaclass(type): allow_inheritance = base._meta.get('allow_inheritance', ALLOW_INHERITANCE) if (allow_inheritance is not True and - not base._meta.get('abstract')): + not base._meta.get('abstract')): raise ValueError('Document %s may not be subclassed' % base.__name__) @@ -141,7 +142,8 @@ class DocumentMetaclass(type): base._subclasses += (_cls,) base._types = base._subclasses # TODO depreciate _types - Document, EmbeddedDocument, DictField = cls._import_classes() + (Document, EmbeddedDocument, DictField, + CachedReferenceField) = cls._import_classes() if issubclass(new_class, Document): new_class._collection = None @@ -170,6 +172,10 @@ class DocumentMetaclass(type): f = field f.owner_document = new_class delete_rule = getattr(f, 'reverse_delete_rule', DO_NOTHING) + if isinstance(f, CachedReferenceField) and issubclass( + new_class, EmbeddedDocument): + raise InvalidDocumentError( + "CachedReferenceFields is not allowed in EmbeddedDocuments") if isinstance(f, ComplexBaseField) and hasattr(f, 'field'): delete_rule = getattr(f.field, 'reverse_delete_rule', @@ -191,7 +197,7 @@ class DocumentMetaclass(type): field.name, delete_rule) if (field.name and hasattr(Document, field.name) and - EmbeddedDocument not in new_class.mro()): + EmbeddedDocument not in new_class.mro()): msg = ("%s is a document method and not a valid " "field name" % field.name) raise InvalidDocumentError(msg) @@ -224,10 +230,12 @@ class DocumentMetaclass(type): Document = _import_class('Document') EmbeddedDocument = _import_class('EmbeddedDocument') DictField = _import_class('DictField') - return (Document, EmbeddedDocument, DictField) + CachedReferenceField = _import_class('CachedReferenceField') + return (Document, EmbeddedDocument, DictField, CachedReferenceField) class TopLevelDocumentMetaclass(DocumentMetaclass): + """Metaclass for top-level documents (i.e. documents that have their own collection in the database. """ @@ -275,21 +283,21 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): # Find the parent document class parent_doc_cls = [b for b in flattened_bases - if b.__class__ == TopLevelDocumentMetaclass] + if b.__class__ == TopLevelDocumentMetaclass] parent_doc_cls = None if not parent_doc_cls else parent_doc_cls[0] # Prevent classes setting collection different to their parents # If parent wasn't an abstract class if (parent_doc_cls and 'collection' in attrs.get('_meta', {}) - and not parent_doc_cls._meta.get('abstract', True)): - msg = "Trying to set a collection on a subclass (%s)" % name - warnings.warn(msg, SyntaxWarning) - del(attrs['_meta']['collection']) + and not parent_doc_cls._meta.get('abstract', True)): + msg = "Trying to set a collection on a subclass (%s)" % name + warnings.warn(msg, SyntaxWarning) + del(attrs['_meta']['collection']) # Ensure abstract documents have abstract bases if attrs.get('_is_base_cls') or attrs['_meta'].get('abstract'): if (parent_doc_cls and - not parent_doc_cls._meta.get('abstract', False)): + not parent_doc_cls._meta.get('abstract', False)): msg = "Abstract document cannot have non-abstract base" raise ValueError(msg) return super_new(cls, name, bases, attrs) @@ -306,7 +314,7 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): # Set collection in the meta if its callable if (getattr(base, '_is_document', False) and - not base._meta.get('abstract')): + not base._meta.get('abstract')): collection = meta.get('collection', None) if callable(collection): meta['collection'] = collection(base) @@ -318,7 +326,7 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): simple_class = all([b._meta.get('abstract') for b in flattened_bases if hasattr(b, '_meta')]) if (not simple_class and meta['allow_inheritance'] is False and - not meta['abstract']): + not meta['abstract']): raise ValueError('Only direct subclasses of Document may set ' '"allow_inheritance" to False') @@ -378,7 +386,7 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): for exc in exceptions_to_merge: name = exc.__name__ parents = tuple(getattr(base, name) for base in flattened_bases - if hasattr(base, name)) or (exc,) + if hasattr(base, name)) or (exc,) # Create new exception and set to new_class exception = type(name, parents, {'__module__': module}) setattr(new_class, name, exception) @@ -387,6 +395,7 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): class MetaDict(dict): + """Custom dictionary for meta classes. Handles the merging of set indexes """ @@ -401,5 +410,6 @@ class MetaDict(dict): class BasesTuple(tuple): + """Special class to handle introspection of bases tuple in __new__""" pass diff --git a/mongoengine/common.py b/mongoengine/common.py index daa194b9..7c0c18d2 100644 --- a/mongoengine/common.py +++ b/mongoengine/common.py @@ -25,6 +25,7 @@ def _import_class(cls_name): 'GenericEmbeddedDocumentField', 'GeoPointField', 'PointField', 'LineStringField', 'ListField', 'PolygonField', 'ReferenceField', 'StringField', + 'CachedReferenceField', 'ComplexBaseField', 'GeoJsonBaseField') queryset_classes = ('OperationError',) deref_classes = ('DeReference',) diff --git a/tests/fields/fields.py b/tests/fields/fields.py index 5da06981..c82c936b 100644 --- a/tests/fields/fields.py +++ b/tests/fields/fields.py @@ -1539,6 +1539,18 @@ class FieldTest(unittest.TestCase): self.assertEqual(ocorrence.person, "teste") self.assertTrue(isinstance(ocorrence.animal, Animal)) + def test_cached_reference_fields_on_embedded_documents(self): + def build(): + class Test(Document): + name = StringField() + + type('WrongEmbeddedDocument', ( + EmbeddedDocument,), { + 'test': CachedReferenceField(Test) + }) + + self.assertRaises(InvalidDocumentError, build) + def test_cached_reference_embedded_fields(self): class Owner(EmbeddedDocument): TPS = ( From 87c97efce08ca727c36b02dbf0e111c6400520ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wilson=20J=C3=BAnior?= Date: Fri, 25 Jul 2014 08:44:59 -0300 Subject: [PATCH 4/9] refs #709, added CachedReferenceField.sync_all to sync all documents on demand --- mongoengine/base/fields.py | 21 +++++++---- mongoengine/base/metaclasses.py | 18 +++++++--- mongoengine/fields.py | 28 +++++++++++++-- mongoengine/queryset/transform.py | 49 +++++++++++++++---------- tests/fields/fields.py | 59 ++++++++++++++++++++++++++++++- 5 files changed, 142 insertions(+), 33 deletions(-) diff --git a/mongoengine/base/fields.py b/mongoengine/base/fields.py index ad173191..c163e6e7 100644 --- a/mongoengine/base/fields.py +++ b/mongoengine/base/fields.py @@ -11,10 +11,12 @@ from mongoengine.errors import ValidationError from mongoengine.base.common import ALLOW_INHERITANCE from mongoengine.base.datastructures import BaseDict, BaseList -__all__ = ("BaseField", "ComplexBaseField", "ObjectIdField", "GeoJsonBaseField") +__all__ = ("BaseField", "ComplexBaseField", + "ObjectIdField", "GeoJsonBaseField") class BaseField(object): + """A base class for fields in a MongoDB document. Instances of this class may be added to subclasses of `Document` to define a document's schema. @@ -60,6 +62,7 @@ class BaseField(object): used when generating model forms from the document model. """ self.db_field = (db_field or name) if not primary_key else '_id' + if name: msg = "Fields' 'name' attribute deprecated in favour of 'db_field'" warnings.warn(msg, DeprecationWarning) @@ -105,7 +108,7 @@ class BaseField(object): if instance._initialised: try: if (self.name not in instance._data or - instance._data[self.name] != value): + instance._data[self.name] != value): instance._mark_as_changed(self.name) except: # Values cant be compared eg: naive and tz datetimes @@ -175,6 +178,7 @@ class BaseField(object): class ComplexBaseField(BaseField): + """Handles complex fields, such as lists / dictionaries. Allows for nesting of embedded documents inside complex types. @@ -197,7 +201,7 @@ class ComplexBaseField(BaseField): GenericReferenceField = _import_class('GenericReferenceField') dereference = (self._auto_dereference and (self.field is None or isinstance(self.field, - (GenericReferenceField, ReferenceField)))) + (GenericReferenceField, ReferenceField)))) _dereference = _import_class("DeReference")() @@ -212,7 +216,7 @@ class ComplexBaseField(BaseField): # Convert lists / values so we can watch for any changes on them if (isinstance(value, (list, tuple)) and - not isinstance(value, BaseList)): + not isinstance(value, BaseList)): value = BaseList(value, instance, self.name) instance._data[self.name] = value elif isinstance(value, dict) and not isinstance(value, BaseDict): @@ -220,8 +224,8 @@ class ComplexBaseField(BaseField): instance._data[self.name] = value if (self._auto_dereference and instance._initialised and - isinstance(value, (BaseList, BaseDict)) - and not value._dereferenced): + isinstance(value, (BaseList, BaseDict)) + and not value._dereferenced): value = _dereference( value, max_depth=1, instance=instance, name=self.name ) @@ -384,6 +388,7 @@ class ComplexBaseField(BaseField): class ObjectIdField(BaseField): + """A field wrapper around MongoDB's ObjectIds. """ @@ -412,6 +417,7 @@ class ObjectIdField(BaseField): class GeoJsonBaseField(BaseField): + """A geo json field storing a geojson style object. .. versionadded:: 0.8 """ @@ -435,7 +441,8 @@ class GeoJsonBaseField(BaseField): if isinstance(value, dict): if set(value.keys()) == set(['type', 'coordinates']): if value['type'] != self._type: - self.error('%s type must be "%s"' % (self._name, self._type)) + self.error('%s type must be "%s"' % + (self._name, self._type)) return self.validate(value['coordinates']) else: self.error('%s can only accept a valid GeoJson dictionary' diff --git a/mongoengine/base/metaclasses.py b/mongoengine/base/metaclasses.py index 887c9abc..b7157a35 100644 --- a/mongoengine/base/metaclasses.py +++ b/mongoengine/base/metaclasses.py @@ -30,7 +30,8 @@ class DocumentMetaclass(type): return super_new(cls, name, bases, attrs) attrs['_is_document'] = attrs.get('_is_document', False) - + attrs['_cached_reference_fields'] = [] + # EmbeddedDocuments could have meta data for inheritance if 'meta' in attrs: attrs['_meta'] = attrs.pop('meta') @@ -172,10 +173,17 @@ class DocumentMetaclass(type): f = field f.owner_document = new_class delete_rule = getattr(f, 'reverse_delete_rule', DO_NOTHING) - if isinstance(f, CachedReferenceField) and issubclass( - new_class, EmbeddedDocument): - raise InvalidDocumentError( - "CachedReferenceFields is not allowed in EmbeddedDocuments") + if isinstance(f, CachedReferenceField): + + if issubclass(new_class, EmbeddedDocument): + raise InvalidDocumentError( + "CachedReferenceFields is not allowed in EmbeddedDocuments") + if not f.document_type: + raise InvalidDocumentError( + "Document is not avaiable to sync") + + f.document_type._cached_reference_fields.append(f) + if isinstance(f, ComplexBaseField) and hasattr(f, 'field'): delete_rule = getattr(f.field, 'reverse_delete_rule', diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 58271435..abe2a491 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -1047,12 +1047,13 @@ class CachedReferenceField(BaseField): doc_tipe = self.document_type if isinstance(document, Document): - # We need the id from the saved object to create the DBRef + # Wen need the id from the saved object to create the DBRef id_ = document.pk if id_ is None: self.error('You can only reference documents once they have' ' been saved to the database') else: + raise SystemError(document) self.error('Only accept a document object') value = { @@ -1065,7 +1066,14 @@ class CachedReferenceField(BaseField): def prepare_query_value(self, op, value): if value is None: return None - return self.to_mongo(value) + + if isinstance(value, Document): + if value.pk is None: + self.error('You can only reference documents once they have' + ' been saved to the database') + return {'_id': value.pk} + + raise NotImplementedError def validate(self, value): @@ -1079,6 +1087,22 @@ class CachedReferenceField(BaseField): def lookup_member(self, member_name): return self.document_type._fields.get(member_name) + def sync_all(self): + update_key = 'set__%s' % self.name + errors = [] + + for doc in self.document_type.objects: + filter_kwargs = {} + filter_kwargs[self.name] = doc + + update_kwargs = {} + update_kwargs[update_key] = doc + + errors.append((filter_kwargs, update_kwargs)) + + self.owner_document.objects( + **filter_kwargs).update(**update_kwargs) + class GenericReferenceField(BaseField): diff --git a/mongoengine/queryset/transform.py b/mongoengine/queryset/transform.py index 3345ae64..e575d9d6 100644 --- a/mongoengine/queryset/transform.py +++ b/mongoengine/queryset/transform.py @@ -12,21 +12,21 @@ __all__ = ('query', 'update') COMPARISON_OPERATORS = ('ne', 'gt', 'gte', 'lt', 'lte', 'in', 'nin', 'mod', 'all', 'size', 'exists', 'not', 'elemMatch') -GEO_OPERATORS = ('within_distance', 'within_spherical_distance', - 'within_box', 'within_polygon', 'near', 'near_sphere', - 'max_distance', 'geo_within', 'geo_within_box', - 'geo_within_polygon', 'geo_within_center', - 'geo_within_sphere', 'geo_intersects') -STRING_OPERATORS = ('contains', 'icontains', 'startswith', - 'istartswith', 'endswith', 'iendswith', - 'exact', 'iexact') -CUSTOM_OPERATORS = ('match',) -MATCH_OPERATORS = (COMPARISON_OPERATORS + GEO_OPERATORS + - STRING_OPERATORS + CUSTOM_OPERATORS) +GEO_OPERATORS = ('within_distance', 'within_spherical_distance', + 'within_box', 'within_polygon', 'near', 'near_sphere', + 'max_distance', 'geo_within', 'geo_within_box', + 'geo_within_polygon', 'geo_within_center', + 'geo_within_sphere', 'geo_intersects') +STRING_OPERATORS = ('contains', 'icontains', 'startswith', + 'istartswith', 'endswith', 'iendswith', + 'exact', 'iexact') +CUSTOM_OPERATORS = ('match',) +MATCH_OPERATORS = (COMPARISON_OPERATORS + GEO_OPERATORS + + STRING_OPERATORS + CUSTOM_OPERATORS) -UPDATE_OPERATORS = ('set', 'unset', 'inc', 'dec', 'pop', 'push', - 'push_all', 'pull', 'pull_all', 'add_to_set', - 'set_on_insert') +UPDATE_OPERATORS = ('set', 'unset', 'inc', 'dec', 'pop', 'push', + 'push_all', 'pull', 'pull_all', 'add_to_set', + 'set_on_insert') def query(_doc_cls=None, _field_operation=False, **query): @@ -60,14 +60,20 @@ def query(_doc_cls=None, _field_operation=False, **query): raise InvalidQueryError(e) parts = [] + CachedReferenceField = _import_class('CachedReferenceField') + cleaned_fields = [] for field in fields: append_field = True if isinstance(field, basestring): parts.append(field) append_field = False + # is last and CachedReferenceField + elif isinstance(field, CachedReferenceField) and fields[-1] == field: + parts.append('%s._id' % field.db_field) else: parts.append(field.db_field) + if append_field: cleaned_fields.append(field) @@ -79,13 +85,17 @@ def query(_doc_cls=None, _field_operation=False, **query): if op in singular_ops: if isinstance(field, basestring): if (op in STRING_OPERATORS and - isinstance(value, basestring)): + isinstance(value, basestring)): StringField = _import_class('StringField') value = StringField.prepare_query_value(op, value) else: value = field else: value = field.prepare_query_value(op, value) + + if isinstance(field, CachedReferenceField) and value: + value = value['_id'] + elif op in ('in', 'nin', 'all', 'near') and not isinstance(value, dict): # 'in', 'nin' and 'all' require a list of values value = [field.prepare_query_value(op, v) for v in value] @@ -125,10 +135,12 @@ def query(_doc_cls=None, _field_operation=False, **query): continue value_son[k] = v if (get_connection().max_wire_version <= 1): - value_son['$maxDistance'] = value_dict['$maxDistance'] + value_son['$maxDistance'] = value_dict[ + '$maxDistance'] else: value_son['$near'] = SON(value_son['$near']) - value_son['$near']['$maxDistance'] = value_dict['$maxDistance'] + value_son['$near'][ + '$maxDistance'] = value_dict['$maxDistance'] else: for k, v in value_dict.iteritems(): if k == '$maxDistance': @@ -264,7 +276,8 @@ def update(_doc_cls=None, **update): if ListField in field_classes: # Join all fields via dot notation to the last ListField # Then process as normal - last_listField = len(cleaned_fields) - field_classes.index(ListField) + last_listField = len( + cleaned_fields) - field_classes.index(ListField) key = ".".join(parts[:last_listField]) parts = parts[last_listField:] parts.insert(0, key) diff --git a/tests/fields/fields.py b/tests/fields/fields.py index c82c936b..d5ae3329 100644 --- a/tests/fields/fields.py +++ b/tests/fields/fields.py @@ -1518,11 +1518,18 @@ class FieldTest(unittest.TestCase): Animal.drop_collection() Ocorrence.drop_collection() - a = Animal(nam="Leopard", tag="heavy") + a = Animal(name="Leopard", tag="heavy") a.save() + self.assertEqual(Animal._cached_reference_fields, [Ocorrence.animal]) o = Ocorrence(person="teste", animal=a) o.save() + + p = Ocorrence(person="Wilson") + p.save() + + self.assertEqual(Ocorrence.objects(animal=None).count(), 1) + self.assertEqual( a.to_mongo(fields=['tag']), {'tag': 'heavy', "_id": a.pk}) @@ -1539,6 +1546,56 @@ class FieldTest(unittest.TestCase): self.assertEqual(ocorrence.person, "teste") self.assertTrue(isinstance(ocorrence.animal, Animal)) + def test_cached_reference_field_update_all(self): + class Person(Document): + TYPES = ( + ('pf', "PF"), + ('pj', "PJ") + ) + name = StringField() + tp = StringField( + choices=TYPES + ) + + father = CachedReferenceField('self', fields=('tp',)) + + Person.drop_collection() + + a1 = Person(name="Wilson Father", tp="pj") + a1.save() + + a2 = Person(name='Wilson Junior', tp='pf', father=a1) + a2.save() + + self.assertEqual(dict(a2.to_mongo()), { + "_id": a2.pk, + "name": u"Wilson Junior", + "tp": u"pf", + "father": { + "_id": a1.pk, + "tp": u"pj" + } + }) + + self.assertEqual(Person.objects(father=a1)._query, { + 'father._id': a1.pk + }) + self.assertEqual(Person.objects(father=a1).count(), 1) + + Person.objects.update(set__tp="pf") + Person.father.sync_all() + + a2.reload() + self.assertEqual(dict(a2.to_mongo()), { + "_id": a2.pk, + "name": u"Wilson Junior", + "tp": u"pf", + "father": { + "_id": a1.pk, + "tp": u"pf" + } + }) + def test_cached_reference_fields_on_embedded_documents(self): def build(): class Test(Document): From 15bbf26b93197240fb2800a4f69e14d000038e60 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wilson=20J=C3=BAnior?= Date: Fri, 25 Jul 2014 08:48:24 -0300 Subject: [PATCH 5/9] refs #709, fix typos --- mongoengine/fields.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mongoengine/fields.py b/mongoengine/fields.py index abe2a491..9b19f25d 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -1047,13 +1047,12 @@ class CachedReferenceField(BaseField): doc_tipe = self.document_type if isinstance(document, Document): - # Wen need the id from the saved object to create the DBRef + # We need the id from the saved object to create the DBRef id_ = document.pk if id_ is None: self.error('You can only reference documents once they have' ' been saved to the database') else: - raise SystemError(document) self.error('Only accept a document object') value = { From 6c0112c2be3da38e50471ffa764a64f2008ab541 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wilson=20J=C3=BAnior?= Date: Fri, 25 Jul 2014 18:12:26 -0300 Subject: [PATCH 6/9] refs #709, added support to disable auto_sync --- mongoengine/base/metaclasses.py | 9 +++-- mongoengine/fields.py | 29 ++++++++++++-- tests/fields/fields.py | 70 +++++++++++++++++++++++++++++++++ 3 files changed, 101 insertions(+), 7 deletions(-) diff --git a/mongoengine/base/metaclasses.py b/mongoengine/base/metaclasses.py index b7157a35..a4bd0144 100644 --- a/mongoengine/base/metaclasses.py +++ b/mongoengine/base/metaclasses.py @@ -31,7 +31,7 @@ class DocumentMetaclass(type): attrs['_is_document'] = attrs.get('_is_document', False) attrs['_cached_reference_fields'] = [] - + # EmbeddedDocuments could have meta data for inheritance if 'meta' in attrs: attrs['_meta'] = attrs.pop('meta') @@ -181,9 +181,12 @@ class DocumentMetaclass(type): if not f.document_type: raise InvalidDocumentError( "Document is not avaiable to sync") - + + if f.auto_sync: + f.start_listener() + f.document_type._cached_reference_fields.append(f) - + if isinstance(f, ComplexBaseField) and hasattr(f, 'field'): delete_rule = getattr(f.field, 'reverse_delete_rule', diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 9b19f25d..14fcde68 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -989,10 +989,11 @@ class CachedReferenceField(BaseField): .. versionadded:: 0.9 """ - def __init__(self, document_type, fields=[], **kwargs): + def __init__(self, document_type, fields=[], auto_sync=True, **kwargs): """Initialises the Cached Reference Field. :param fields: A list of fields to be cached in document + :param auto_sync: if True documents are auto updated. """ if not isinstance(document_type, basestring) and \ not issubclass(document_type, (Document, basestring)): @@ -1000,10 +1001,33 @@ class CachedReferenceField(BaseField): self.error('Argument to CachedReferenceField constructor must be a' ' document class or a string') + self.auto_sync = auto_sync self.document_type_obj = document_type self.fields = fields super(CachedReferenceField, self).__init__(**kwargs) + def start_listener(self): + """ + Start listener for document alterations, and update relacted docs + """ + from mongoengine import signals + signals.post_save.connect(self.on_document_pre_save, + sender=self.document_type) + + def on_document_pre_save(self, sender, document, created, **kwargs): + if not created: + update_kwargs = { + 'set__%s__%s' % (self.name, k): v + for k, v in document._delta()[0].items() + if k in self.fields} + + if update_kwargs: + filter_kwargs = {} + filter_kwargs[self.name] = document + + self.owner_document.objects( + **filter_kwargs).update(**update_kwargs) + def to_python(self, value): """Convert a MongoDB-compatible type to a Python type. """ @@ -1088,7 +1112,6 @@ class CachedReferenceField(BaseField): def sync_all(self): update_key = 'set__%s' % self.name - errors = [] for doc in self.document_type.objects: filter_kwargs = {} @@ -1097,8 +1120,6 @@ class CachedReferenceField(BaseField): update_kwargs = {} update_kwargs[update_key] = doc - errors.append((filter_kwargs, update_kwargs)) - self.owner_document.objects( **filter_kwargs).update(**update_kwargs) diff --git a/tests/fields/fields.py b/tests/fields/fields.py index d5ae3329..77490ddb 100644 --- a/tests/fields/fields.py +++ b/tests/fields/fields.py @@ -1608,6 +1608,76 @@ class FieldTest(unittest.TestCase): self.assertRaises(InvalidDocumentError, build) + def test_cached_reference_auto_sync(self): + class Person(Document): + TYPES = ( + ('pf', "PF"), + ('pj', "PJ") + ) + name = StringField() + tp = StringField( + choices=TYPES + ) + + father = CachedReferenceField('self', fields=('tp',)) + + Person.drop_collection() + + a1 = Person(name="Wilson Father", tp="pj") + a1.save() + + a2 = Person(name='Wilson Junior', tp='pf', father=a1) + a2.save() + + a1.tp = 'pf' + a1.save() + + a2.reload() + self.assertEqual(dict(a2.to_mongo()), { + '_id': a2.pk, + 'name': 'Wilson Junior', + 'tp': 'pf', + 'father': { + '_id': a1.pk, + 'tp': 'pf' + } + }) + + def test_cached_reference_auto_sync_disabled(self): + class Persone(Document): + TYPES = ( + ('pf', "PF"), + ('pj', "PJ") + ) + name = StringField() + tp = StringField( + choices=TYPES + ) + + father = CachedReferenceField( + 'self', fields=('tp',), auto_sync=False) + + Persone.drop_collection() + + a1 = Persone(name="Wilson Father", tp="pj") + a1.save() + + a2 = Persone(name='Wilson Junior', tp='pf', father=a1) + a2.save() + + a1.tp = 'pf' + a1.save() + + self.assertEqual(Persone.objects._collection.find_one({'_id': a2.pk}), { + '_id': a2.pk, + 'name': 'Wilson Junior', + 'tp': 'pf', + 'father': { + '_id': a1.pk, + 'tp': 'pj' + } + }) + def test_cached_reference_embedded_fields(self): class Owner(EmbeddedDocument): TPS = ( From e33a5bbef5d4f74bf8a63aeaddd390ef26b70eb3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wilson=20J=C3=BAnior?= Date: Sat, 26 Jul 2014 07:24:04 -0300 Subject: [PATCH 7/9] fixes for python2.6 --- mongoengine/fields.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 14fcde68..a6ffb94f 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -1016,10 +1016,10 @@ class CachedReferenceField(BaseField): def on_document_pre_save(self, sender, document, created, **kwargs): if not created: - update_kwargs = { - 'set__%s__%s' % (self.name, k): v + update_kwargs = dict( + ('set__%s__%s' % (self.name, k), v) for k, v in document._delta()[0].items() - if k in self.fields} + if k in self.fields) if update_kwargs: filter_kwargs = {} From b4d6f6b9470c832338338be6312420dea398380d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wilson=20J=C3=BAnior?= Date: Wed, 30 Jul 2014 09:32:33 -0300 Subject: [PATCH 8/9] added documentation about CachedReferenceField --- docs/apireference.rst | 1 + mongoengine/fields.py | 19 ++++------ tests/fields/fields.py | 85 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 94 insertions(+), 11 deletions(-) diff --git a/docs/apireference.rst b/docs/apireference.rst index 6c42d40a..9b4f2c66 100644 --- a/docs/apireference.rst +++ b/docs/apireference.rst @@ -84,6 +84,7 @@ Fields .. autoclass:: mongoengine.fields.MapField .. autoclass:: mongoengine.fields.ReferenceField .. autoclass:: mongoengine.fields.GenericReferenceField +.. autoclass:: mongoengine.fields.CachedReferenceField .. autoclass:: mongoengine.fields.BinaryField .. autoclass:: mongoengine.fields.FileField .. autoclass:: mongoengine.fields.ImageField diff --git a/mongoengine/fields.py b/mongoengine/fields.py index a6ffb94f..7bbc221a 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -985,7 +985,7 @@ class ReferenceField(BaseField): class CachedReferenceField(BaseField): """ - A referencefield with cache fields support + A referencefield with cache fields to porpuse pseudo-joins .. versionadded:: 0.9 """ @@ -1007,9 +1007,6 @@ class CachedReferenceField(BaseField): super(CachedReferenceField, self).__init__(**kwargs) def start_listener(self): - """ - Start listener for document alterations, and update relacted docs - """ from mongoengine import signals signals.post_save.connect(self.on_document_pre_save, sender=self.document_type) @@ -1029,8 +1026,6 @@ class CachedReferenceField(BaseField): **filter_kwargs).update(**update_kwargs) def to_python(self, value): - """Convert a MongoDB-compatible type to a Python type. - """ if isinstance(value, dict): collection = self.document_type._get_collection_name() value = DBRef( @@ -1048,8 +1043,6 @@ class CachedReferenceField(BaseField): return self.document_type_obj def __get__(self, instance, owner): - """Descriptor to allow lazy dereferencing. - """ if instance is None: # Document class being used rather than a document object return self @@ -1079,9 +1072,9 @@ class CachedReferenceField(BaseField): else: self.error('Only accept a document object') - value = { - "_id": id_field.to_mongo(id_) - } + value = SON(( + ("_id", id_field.to_mongo(id_)), + )) value.update(dict(document.to_mongo(fields=self.fields))) return value @@ -1111,6 +1104,10 @@ class CachedReferenceField(BaseField): return self.document_type._fields.get(member_name) def sync_all(self): + """ + Sync all cached fields on demand. + Caution: this operation may be slower. + """ update_key = 'set__%s' % self.name for doc in self.document_type.objects: diff --git a/tests/fields/fields.py b/tests/fields/fields.py index 77490ddb..342a13b3 100644 --- a/tests/fields/fields.py +++ b/tests/fields/fields.py @@ -1546,6 +1546,91 @@ class FieldTest(unittest.TestCase): self.assertEqual(ocorrence.person, "teste") self.assertTrue(isinstance(ocorrence.animal, Animal)) + def test_cached_reference_field_decimal(self): + class PersonAuto(Document): + name = StringField() + salary = DecimalField() + + class SocialTest(Document): + group = StringField() + person = CachedReferenceField( + PersonAuto, + fields=('salary',)) + + PersonAuto.drop_collection() + SocialTest.drop_collection() + + p = PersonAuto(name="Alberto", salary=Decimal('7000.00')) + p.save() + + s = SocialTest(group="dev", person=p) + s.save() + + self.assertEqual( + SocialTest.objects._collection.find_one({'person.salary': 7000.00}), { + '_id': s.pk, + 'group': s.group, + 'person': { + '_id': p.pk, + 'salary': p.salary + } + }) + + def test_cached_reference_field_reference(self): + class Group(Document): + name = StringField() + + class Person(Document): + name = StringField() + group = ReferenceField(Group) + + class SocialData(Document): + obs = StringField() + tags = ListField( + StringField()) + person = CachedReferenceField( + Person, + fields=('group',)) + + Group.drop_collection() + Person.drop_collection() + SocialData.drop_collection() + + g1 = Group(name='dev') + g1.save() + + g2 = Group(name="designers") + g2.save() + + p1 = Person(name="Alberto", group=g1) + p1.save() + + p2 = Person(name="Andre", group=g1) + p2.save() + + p3 = Person(name="Afro design", group=g2) + p3.save() + + s1 = SocialData(obs="testing 123", person=p1, tags=['tag1', 'tag2']) + s1.save() + + s2 = SocialData(obs="testing 321", person=p3, tags=['tag3', 'tag4']) + s2.save() + + self.assertEqual(SocialData.objects._collection.find_one( + {'tags': 'tag2'}), { + '_id': s1.pk, + 'obs': 'testing 123', + 'tags': ['tag1', 'tag2'], + 'person': { + '_id': p1.pk, + 'group': g1.pk + } + }) + + self.assertEqual(SocialData.objects(person__group=g2).count(), 1) + self.assertEqual(SocialData.objects(person__group=g2).first(), s2) + def test_cached_reference_field_update_all(self): class Person(Document): TYPES = ( From f17f8b48c21adfe85953aedc8ed7cfc81bbf150c Mon Sep 17 00:00:00 2001 From: Wilson Junior Date: Sun, 3 Aug 2014 18:59:50 -0400 Subject: [PATCH 9/9] small fixes for python2.6 --- tests/fields/fields.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/fields/fields.py b/tests/fields/fields.py index 342a13b3..0af22a34 100644 --- a/tests/fields/fields.py +++ b/tests/fields/fields.py @@ -1572,7 +1572,7 @@ class FieldTest(unittest.TestCase): 'group': s.group, 'person': { '_id': p.pk, - 'salary': p.salary + 'salary': 7000.00 } })