refs #709, added CachedReferenceField.sync_all to sync all documents on demand
This commit is contained in:
		@@ -11,10 +11,12 @@ from mongoengine.errors import ValidationError
 | 
				
			|||||||
from mongoengine.base.common import ALLOW_INHERITANCE
 | 
					from mongoengine.base.common import ALLOW_INHERITANCE
 | 
				
			||||||
from mongoengine.base.datastructures import BaseDict, BaseList
 | 
					from mongoengine.base.datastructures import BaseDict, BaseList
 | 
				
			||||||
 | 
					
 | 
				
			||||||
__all__ = ("BaseField", "ComplexBaseField", "ObjectIdField", "GeoJsonBaseField")
 | 
					__all__ = ("BaseField", "ComplexBaseField",
 | 
				
			||||||
 | 
					           "ObjectIdField", "GeoJsonBaseField")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class BaseField(object):
 | 
					class BaseField(object):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    """A base class for fields in a MongoDB document. Instances of this class
 | 
					    """A base class for fields in a MongoDB document. Instances of this class
 | 
				
			||||||
    may be added to subclasses of `Document` to define a document's schema.
 | 
					    may be added to subclasses of `Document` to define a document's schema.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -60,6 +62,7 @@ class BaseField(object):
 | 
				
			|||||||
            used when generating model forms from the document model.
 | 
					            used when generating model forms from the document model.
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        self.db_field = (db_field or name) if not primary_key else '_id'
 | 
					        self.db_field = (db_field or name) if not primary_key else '_id'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if name:
 | 
					        if name:
 | 
				
			||||||
            msg = "Fields' 'name' attribute deprecated in favour of 'db_field'"
 | 
					            msg = "Fields' 'name' attribute deprecated in favour of 'db_field'"
 | 
				
			||||||
            warnings.warn(msg, DeprecationWarning)
 | 
					            warnings.warn(msg, DeprecationWarning)
 | 
				
			||||||
@@ -105,7 +108,7 @@ class BaseField(object):
 | 
				
			|||||||
        if instance._initialised:
 | 
					        if instance._initialised:
 | 
				
			||||||
            try:
 | 
					            try:
 | 
				
			||||||
                if (self.name not in instance._data or
 | 
					                if (self.name not in instance._data or
 | 
				
			||||||
                   instance._data[self.name] != value):
 | 
					                        instance._data[self.name] != value):
 | 
				
			||||||
                    instance._mark_as_changed(self.name)
 | 
					                    instance._mark_as_changed(self.name)
 | 
				
			||||||
            except:
 | 
					            except:
 | 
				
			||||||
                # Values cant be compared eg: naive and tz datetimes
 | 
					                # Values cant be compared eg: naive and tz datetimes
 | 
				
			||||||
@@ -175,6 +178,7 @@ class BaseField(object):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class ComplexBaseField(BaseField):
 | 
					class ComplexBaseField(BaseField):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    """Handles complex fields, such as lists / dictionaries.
 | 
					    """Handles complex fields, such as lists / dictionaries.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    Allows for nesting of embedded documents inside complex types.
 | 
					    Allows for nesting of embedded documents inside complex types.
 | 
				
			||||||
@@ -197,7 +201,7 @@ class ComplexBaseField(BaseField):
 | 
				
			|||||||
        GenericReferenceField = _import_class('GenericReferenceField')
 | 
					        GenericReferenceField = _import_class('GenericReferenceField')
 | 
				
			||||||
        dereference = (self._auto_dereference and
 | 
					        dereference = (self._auto_dereference and
 | 
				
			||||||
                       (self.field is None or isinstance(self.field,
 | 
					                       (self.field is None or isinstance(self.field,
 | 
				
			||||||
                        (GenericReferenceField, ReferenceField))))
 | 
					                                                         (GenericReferenceField, ReferenceField))))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        _dereference = _import_class("DeReference")()
 | 
					        _dereference = _import_class("DeReference")()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -212,7 +216,7 @@ class ComplexBaseField(BaseField):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        # Convert lists / values so we can watch for any changes on them
 | 
					        # Convert lists / values so we can watch for any changes on them
 | 
				
			||||||
        if (isinstance(value, (list, tuple)) and
 | 
					        if (isinstance(value, (list, tuple)) and
 | 
				
			||||||
           not isinstance(value, BaseList)):
 | 
					                not isinstance(value, BaseList)):
 | 
				
			||||||
            value = BaseList(value, instance, self.name)
 | 
					            value = BaseList(value, instance, self.name)
 | 
				
			||||||
            instance._data[self.name] = value
 | 
					            instance._data[self.name] = value
 | 
				
			||||||
        elif isinstance(value, dict) and not isinstance(value, BaseDict):
 | 
					        elif isinstance(value, dict) and not isinstance(value, BaseDict):
 | 
				
			||||||
@@ -220,8 +224,8 @@ class ComplexBaseField(BaseField):
 | 
				
			|||||||
            instance._data[self.name] = value
 | 
					            instance._data[self.name] = value
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if (self._auto_dereference and instance._initialised and
 | 
					        if (self._auto_dereference and instance._initialised and
 | 
				
			||||||
           isinstance(value, (BaseList, BaseDict))
 | 
					                isinstance(value, (BaseList, BaseDict))
 | 
				
			||||||
           and not value._dereferenced):
 | 
					                and not value._dereferenced):
 | 
				
			||||||
            value = _dereference(
 | 
					            value = _dereference(
 | 
				
			||||||
                value, max_depth=1, instance=instance, name=self.name
 | 
					                value, max_depth=1, instance=instance, name=self.name
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
@@ -384,6 +388,7 @@ class ComplexBaseField(BaseField):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class ObjectIdField(BaseField):
 | 
					class ObjectIdField(BaseField):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    """A field wrapper around MongoDB's ObjectIds.
 | 
					    """A field wrapper around MongoDB's ObjectIds.
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -412,6 +417,7 @@ class ObjectIdField(BaseField):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class GeoJsonBaseField(BaseField):
 | 
					class GeoJsonBaseField(BaseField):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    """A geo json field storing a geojson style object.
 | 
					    """A geo json field storing a geojson style object.
 | 
				
			||||||
    .. versionadded:: 0.8
 | 
					    .. versionadded:: 0.8
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
@@ -435,7 +441,8 @@ class GeoJsonBaseField(BaseField):
 | 
				
			|||||||
        if isinstance(value, dict):
 | 
					        if isinstance(value, dict):
 | 
				
			||||||
            if set(value.keys()) == set(['type', 'coordinates']):
 | 
					            if set(value.keys()) == set(['type', 'coordinates']):
 | 
				
			||||||
                if value['type'] != self._type:
 | 
					                if value['type'] != self._type:
 | 
				
			||||||
                    self.error('%s type must be "%s"' % (self._name, self._type))
 | 
					                    self.error('%s type must be "%s"' %
 | 
				
			||||||
 | 
					                               (self._name, self._type))
 | 
				
			||||||
                return self.validate(value['coordinates'])
 | 
					                return self.validate(value['coordinates'])
 | 
				
			||||||
            else:
 | 
					            else:
 | 
				
			||||||
                self.error('%s can only accept a valid GeoJson dictionary'
 | 
					                self.error('%s can only accept a valid GeoJson dictionary'
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -30,7 +30,8 @@ class DocumentMetaclass(type):
 | 
				
			|||||||
            return super_new(cls, name, bases, attrs)
 | 
					            return super_new(cls, name, bases, attrs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        attrs['_is_document'] = attrs.get('_is_document', False)
 | 
					        attrs['_is_document'] = attrs.get('_is_document', False)
 | 
				
			||||||
 | 
					        attrs['_cached_reference_fields'] = []
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
        # EmbeddedDocuments could have meta data for inheritance
 | 
					        # EmbeddedDocuments could have meta data for inheritance
 | 
				
			||||||
        if 'meta' in attrs:
 | 
					        if 'meta' in attrs:
 | 
				
			||||||
            attrs['_meta'] = attrs.pop('meta')
 | 
					            attrs['_meta'] = attrs.pop('meta')
 | 
				
			||||||
@@ -172,10 +173,17 @@ class DocumentMetaclass(type):
 | 
				
			|||||||
            f = field
 | 
					            f = field
 | 
				
			||||||
            f.owner_document = new_class
 | 
					            f.owner_document = new_class
 | 
				
			||||||
            delete_rule = getattr(f, 'reverse_delete_rule', DO_NOTHING)
 | 
					            delete_rule = getattr(f, 'reverse_delete_rule', DO_NOTHING)
 | 
				
			||||||
            if isinstance(f, CachedReferenceField) and issubclass(
 | 
					            if isinstance(f, CachedReferenceField):
 | 
				
			||||||
                    new_class, EmbeddedDocument):
 | 
					
 | 
				
			||||||
                raise InvalidDocumentError(
 | 
					                if issubclass(new_class, EmbeddedDocument):
 | 
				
			||||||
                    "CachedReferenceFields is not allowed in EmbeddedDocuments")
 | 
					                    raise InvalidDocumentError(
 | 
				
			||||||
 | 
					                        "CachedReferenceFields is not allowed in EmbeddedDocuments")
 | 
				
			||||||
 | 
					                if not f.document_type:
 | 
				
			||||||
 | 
					                    raise InvalidDocumentError(
 | 
				
			||||||
 | 
					                        "Document is not avaiable to sync")
 | 
				
			||||||
 | 
					                
 | 
				
			||||||
 | 
					                f.document_type._cached_reference_fields.append(f)
 | 
				
			||||||
 | 
					                
 | 
				
			||||||
            if isinstance(f, ComplexBaseField) and hasattr(f, 'field'):
 | 
					            if isinstance(f, ComplexBaseField) and hasattr(f, 'field'):
 | 
				
			||||||
                delete_rule = getattr(f.field,
 | 
					                delete_rule = getattr(f.field,
 | 
				
			||||||
                                      'reverse_delete_rule',
 | 
					                                      'reverse_delete_rule',
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1047,12 +1047,13 @@ class CachedReferenceField(BaseField):
 | 
				
			|||||||
        doc_tipe = self.document_type
 | 
					        doc_tipe = self.document_type
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if isinstance(document, Document):
 | 
					        if isinstance(document, Document):
 | 
				
			||||||
            # We need the id from the saved object to create the DBRef
 | 
					            # Wen need the id from the saved object to create the DBRef
 | 
				
			||||||
            id_ = document.pk
 | 
					            id_ = document.pk
 | 
				
			||||||
            if id_ is None:
 | 
					            if id_ is None:
 | 
				
			||||||
                self.error('You can only reference documents once they have'
 | 
					                self.error('You can only reference documents once they have'
 | 
				
			||||||
                           ' been saved to the database')
 | 
					                           ' been saved to the database')
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
 | 
					            raise SystemError(document)
 | 
				
			||||||
            self.error('Only accept a document object')
 | 
					            self.error('Only accept a document object')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        value = {
 | 
					        value = {
 | 
				
			||||||
@@ -1065,7 +1066,14 @@ class CachedReferenceField(BaseField):
 | 
				
			|||||||
    def prepare_query_value(self, op, value):
 | 
					    def prepare_query_value(self, op, value):
 | 
				
			||||||
        if value is None:
 | 
					        if value is None:
 | 
				
			||||||
            return None
 | 
					            return None
 | 
				
			||||||
        return self.to_mongo(value)
 | 
					
 | 
				
			||||||
 | 
					        if isinstance(value, Document):
 | 
				
			||||||
 | 
					            if value.pk is None:
 | 
				
			||||||
 | 
					                self.error('You can only reference documents once they have'
 | 
				
			||||||
 | 
					                           ' been saved to the database')
 | 
				
			||||||
 | 
					            return {'_id': value.pk}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        raise NotImplementedError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def validate(self, value):
 | 
					    def validate(self, value):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -1079,6 +1087,22 @@ class CachedReferenceField(BaseField):
 | 
				
			|||||||
    def lookup_member(self, member_name):
 | 
					    def lookup_member(self, member_name):
 | 
				
			||||||
        return self.document_type._fields.get(member_name)
 | 
					        return self.document_type._fields.get(member_name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def sync_all(self):
 | 
				
			||||||
 | 
					        update_key = 'set__%s' % self.name
 | 
				
			||||||
 | 
					        errors = []
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        for doc in self.document_type.objects:
 | 
				
			||||||
 | 
					            filter_kwargs = {}
 | 
				
			||||||
 | 
					            filter_kwargs[self.name] = doc
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            update_kwargs = {}
 | 
				
			||||||
 | 
					            update_kwargs[update_key] = doc
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            errors.append((filter_kwargs, update_kwargs))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            self.owner_document.objects(
 | 
				
			||||||
 | 
					                **filter_kwargs).update(**update_kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class GenericReferenceField(BaseField):
 | 
					class GenericReferenceField(BaseField):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -12,21 +12,21 @@ __all__ = ('query', 'update')
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
COMPARISON_OPERATORS = ('ne', 'gt', 'gte', 'lt', 'lte', 'in', 'nin', 'mod',
 | 
					COMPARISON_OPERATORS = ('ne', 'gt', 'gte', 'lt', 'lte', 'in', 'nin', 'mod',
 | 
				
			||||||
                        'all', 'size', 'exists', 'not', 'elemMatch')
 | 
					                        'all', 'size', 'exists', 'not', 'elemMatch')
 | 
				
			||||||
GEO_OPERATORS        = ('within_distance', 'within_spherical_distance',
 | 
					GEO_OPERATORS = ('within_distance', 'within_spherical_distance',
 | 
				
			||||||
                        'within_box', 'within_polygon', 'near', 'near_sphere',
 | 
					                 'within_box', 'within_polygon', 'near', 'near_sphere',
 | 
				
			||||||
                        'max_distance', 'geo_within', 'geo_within_box',
 | 
					                 'max_distance', 'geo_within', 'geo_within_box',
 | 
				
			||||||
                        'geo_within_polygon', 'geo_within_center',
 | 
					                 'geo_within_polygon', 'geo_within_center',
 | 
				
			||||||
                        'geo_within_sphere', 'geo_intersects')
 | 
					                 'geo_within_sphere', 'geo_intersects')
 | 
				
			||||||
STRING_OPERATORS     = ('contains', 'icontains', 'startswith',
 | 
					STRING_OPERATORS = ('contains', 'icontains', 'startswith',
 | 
				
			||||||
                        'istartswith', 'endswith', 'iendswith',
 | 
					                    'istartswith', 'endswith', 'iendswith',
 | 
				
			||||||
                        'exact', 'iexact')
 | 
					                    'exact', 'iexact')
 | 
				
			||||||
CUSTOM_OPERATORS     = ('match',)
 | 
					CUSTOM_OPERATORS = ('match',)
 | 
				
			||||||
MATCH_OPERATORS      = (COMPARISON_OPERATORS + GEO_OPERATORS +
 | 
					MATCH_OPERATORS = (COMPARISON_OPERATORS + GEO_OPERATORS +
 | 
				
			||||||
                        STRING_OPERATORS + CUSTOM_OPERATORS)
 | 
					                   STRING_OPERATORS + CUSTOM_OPERATORS)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
UPDATE_OPERATORS     = ('set', 'unset', 'inc', 'dec', 'pop', 'push',
 | 
					UPDATE_OPERATORS = ('set', 'unset', 'inc', 'dec', 'pop', 'push',
 | 
				
			||||||
                        'push_all', 'pull', 'pull_all', 'add_to_set',
 | 
					                    'push_all', 'pull', 'pull_all', 'add_to_set',
 | 
				
			||||||
                        'set_on_insert')
 | 
					                    'set_on_insert')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def query(_doc_cls=None, _field_operation=False, **query):
 | 
					def query(_doc_cls=None, _field_operation=False, **query):
 | 
				
			||||||
@@ -60,14 +60,20 @@ def query(_doc_cls=None, _field_operation=False, **query):
 | 
				
			|||||||
                raise InvalidQueryError(e)
 | 
					                raise InvalidQueryError(e)
 | 
				
			||||||
            parts = []
 | 
					            parts = []
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            CachedReferenceField = _import_class('CachedReferenceField')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            cleaned_fields = []
 | 
					            cleaned_fields = []
 | 
				
			||||||
            for field in fields:
 | 
					            for field in fields:
 | 
				
			||||||
                append_field = True
 | 
					                append_field = True
 | 
				
			||||||
                if isinstance(field, basestring):
 | 
					                if isinstance(field, basestring):
 | 
				
			||||||
                    parts.append(field)
 | 
					                    parts.append(field)
 | 
				
			||||||
                    append_field = False
 | 
					                    append_field = False
 | 
				
			||||||
 | 
					                # is last and CachedReferenceField
 | 
				
			||||||
 | 
					                elif isinstance(field, CachedReferenceField) and fields[-1] == field:
 | 
				
			||||||
 | 
					                    parts.append('%s._id' % field.db_field)
 | 
				
			||||||
                else:
 | 
					                else:
 | 
				
			||||||
                    parts.append(field.db_field)
 | 
					                    parts.append(field.db_field)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                if append_field:
 | 
					                if append_field:
 | 
				
			||||||
                    cleaned_fields.append(field)
 | 
					                    cleaned_fields.append(field)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -79,13 +85,17 @@ def query(_doc_cls=None, _field_operation=False, **query):
 | 
				
			|||||||
            if op in singular_ops:
 | 
					            if op in singular_ops:
 | 
				
			||||||
                if isinstance(field, basestring):
 | 
					                if isinstance(field, basestring):
 | 
				
			||||||
                    if (op in STRING_OPERATORS and
 | 
					                    if (op in STRING_OPERATORS and
 | 
				
			||||||
                       isinstance(value, basestring)):
 | 
					                            isinstance(value, basestring)):
 | 
				
			||||||
                        StringField = _import_class('StringField')
 | 
					                        StringField = _import_class('StringField')
 | 
				
			||||||
                        value = StringField.prepare_query_value(op, value)
 | 
					                        value = StringField.prepare_query_value(op, value)
 | 
				
			||||||
                    else:
 | 
					                    else:
 | 
				
			||||||
                        value = field
 | 
					                        value = field
 | 
				
			||||||
                else:
 | 
					                else:
 | 
				
			||||||
                    value = field.prepare_query_value(op, value)
 | 
					                    value = field.prepare_query_value(op, value)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    if isinstance(field, CachedReferenceField) and value:
 | 
				
			||||||
 | 
					                        value = value['_id']
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            elif op in ('in', 'nin', 'all', 'near') and not isinstance(value, dict):
 | 
					            elif op in ('in', 'nin', 'all', 'near') and not isinstance(value, dict):
 | 
				
			||||||
                # 'in', 'nin' and 'all' require a list of values
 | 
					                # 'in', 'nin' and 'all' require a list of values
 | 
				
			||||||
                value = [field.prepare_query_value(op, v) for v in value]
 | 
					                value = [field.prepare_query_value(op, v) for v in value]
 | 
				
			||||||
@@ -125,10 +135,12 @@ def query(_doc_cls=None, _field_operation=False, **query):
 | 
				
			|||||||
                                continue
 | 
					                                continue
 | 
				
			||||||
                            value_son[k] = v
 | 
					                            value_son[k] = v
 | 
				
			||||||
                        if (get_connection().max_wire_version <= 1):
 | 
					                        if (get_connection().max_wire_version <= 1):
 | 
				
			||||||
                            value_son['$maxDistance'] = value_dict['$maxDistance']
 | 
					                            value_son['$maxDistance'] = value_dict[
 | 
				
			||||||
 | 
					                                '$maxDistance']
 | 
				
			||||||
                        else:
 | 
					                        else:
 | 
				
			||||||
                            value_son['$near'] = SON(value_son['$near'])
 | 
					                            value_son['$near'] = SON(value_son['$near'])
 | 
				
			||||||
                            value_son['$near']['$maxDistance'] = value_dict['$maxDistance']
 | 
					                            value_son['$near'][
 | 
				
			||||||
 | 
					                                '$maxDistance'] = value_dict['$maxDistance']
 | 
				
			||||||
                    else:
 | 
					                    else:
 | 
				
			||||||
                        for k, v in value_dict.iteritems():
 | 
					                        for k, v in value_dict.iteritems():
 | 
				
			||||||
                            if k == '$maxDistance':
 | 
					                            if k == '$maxDistance':
 | 
				
			||||||
@@ -264,7 +276,8 @@ def update(_doc_cls=None, **update):
 | 
				
			|||||||
            if ListField in field_classes:
 | 
					            if ListField in field_classes:
 | 
				
			||||||
                # Join all fields via dot notation to the last ListField
 | 
					                # Join all fields via dot notation to the last ListField
 | 
				
			||||||
                # Then process as normal
 | 
					                # Then process as normal
 | 
				
			||||||
                last_listField = len(cleaned_fields) - field_classes.index(ListField)
 | 
					                last_listField = len(
 | 
				
			||||||
 | 
					                    cleaned_fields) - field_classes.index(ListField)
 | 
				
			||||||
                key = ".".join(parts[:last_listField])
 | 
					                key = ".".join(parts[:last_listField])
 | 
				
			||||||
                parts = parts[last_listField:]
 | 
					                parts = parts[last_listField:]
 | 
				
			||||||
                parts.insert(0, key)
 | 
					                parts.insert(0, key)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1518,11 +1518,18 @@ class FieldTest(unittest.TestCase):
 | 
				
			|||||||
        Animal.drop_collection()
 | 
					        Animal.drop_collection()
 | 
				
			||||||
        Ocorrence.drop_collection()
 | 
					        Ocorrence.drop_collection()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        a = Animal(nam="Leopard", tag="heavy")
 | 
					        a = Animal(name="Leopard", tag="heavy")
 | 
				
			||||||
        a.save()
 | 
					        a.save()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.assertEqual(Animal._cached_reference_fields, [Ocorrence.animal])
 | 
				
			||||||
        o = Ocorrence(person="teste", animal=a)
 | 
					        o = Ocorrence(person="teste", animal=a)
 | 
				
			||||||
        o.save()
 | 
					        o.save()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        p = Ocorrence(person="Wilson")
 | 
				
			||||||
 | 
					        p.save()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.assertEqual(Ocorrence.objects(animal=None).count(), 1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.assertEqual(
 | 
					        self.assertEqual(
 | 
				
			||||||
            a.to_mongo(fields=['tag']), {'tag': 'heavy', "_id": a.pk})
 | 
					            a.to_mongo(fields=['tag']), {'tag': 'heavy', "_id": a.pk})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -1539,6 +1546,56 @@ class FieldTest(unittest.TestCase):
 | 
				
			|||||||
        self.assertEqual(ocorrence.person, "teste")
 | 
					        self.assertEqual(ocorrence.person, "teste")
 | 
				
			||||||
        self.assertTrue(isinstance(ocorrence.animal, Animal))
 | 
					        self.assertTrue(isinstance(ocorrence.animal, Animal))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_cached_reference_field_update_all(self):
 | 
				
			||||||
 | 
					        class Person(Document):
 | 
				
			||||||
 | 
					            TYPES = (
 | 
				
			||||||
 | 
					                ('pf', "PF"),
 | 
				
			||||||
 | 
					                ('pj', "PJ")
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            name = StringField()
 | 
				
			||||||
 | 
					            tp = StringField(
 | 
				
			||||||
 | 
					                choices=TYPES
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            father = CachedReferenceField('self', fields=('tp',))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Person.drop_collection()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        a1 = Person(name="Wilson Father", tp="pj")
 | 
				
			||||||
 | 
					        a1.save()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        a2 = Person(name='Wilson Junior', tp='pf', father=a1)
 | 
				
			||||||
 | 
					        a2.save()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.assertEqual(dict(a2.to_mongo()), {
 | 
				
			||||||
 | 
					            "_id": a2.pk,
 | 
				
			||||||
 | 
					            "name": u"Wilson Junior",
 | 
				
			||||||
 | 
					            "tp": u"pf",
 | 
				
			||||||
 | 
					            "father": {
 | 
				
			||||||
 | 
					                "_id": a1.pk,
 | 
				
			||||||
 | 
					                "tp": u"pj"
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					        })
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.assertEqual(Person.objects(father=a1)._query, {
 | 
				
			||||||
 | 
					            'father._id': a1.pk
 | 
				
			||||||
 | 
					        })
 | 
				
			||||||
 | 
					        self.assertEqual(Person.objects(father=a1).count(), 1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Person.objects.update(set__tp="pf")
 | 
				
			||||||
 | 
					        Person.father.sync_all()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        a2.reload()
 | 
				
			||||||
 | 
					        self.assertEqual(dict(a2.to_mongo()), {
 | 
				
			||||||
 | 
					            "_id": a2.pk,
 | 
				
			||||||
 | 
					            "name": u"Wilson Junior",
 | 
				
			||||||
 | 
					            "tp": u"pf",
 | 
				
			||||||
 | 
					            "father": {
 | 
				
			||||||
 | 
					                "_id": a1.pk,
 | 
				
			||||||
 | 
					                "tp": u"pf"
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					        })
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_cached_reference_fields_on_embedded_documents(self):
 | 
					    def test_cached_reference_fields_on_embedded_documents(self):
 | 
				
			||||||
        def build():
 | 
					        def build():
 | 
				
			||||||
            class Test(Document):
 | 
					            class Test(Document):
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user