diff --git a/mongoengine/base/fields.py b/mongoengine/base/fields.py index ad173191..c163e6e7 100644 --- a/mongoengine/base/fields.py +++ b/mongoengine/base/fields.py @@ -11,10 +11,12 @@ from mongoengine.errors import ValidationError from mongoengine.base.common import ALLOW_INHERITANCE from mongoengine.base.datastructures import BaseDict, BaseList -__all__ = ("BaseField", "ComplexBaseField", "ObjectIdField", "GeoJsonBaseField") +__all__ = ("BaseField", "ComplexBaseField", + "ObjectIdField", "GeoJsonBaseField") class BaseField(object): + """A base class for fields in a MongoDB document. Instances of this class may be added to subclasses of `Document` to define a document's schema. @@ -60,6 +62,7 @@ class BaseField(object): used when generating model forms from the document model. """ self.db_field = (db_field or name) if not primary_key else '_id' + if name: msg = "Fields' 'name' attribute deprecated in favour of 'db_field'" warnings.warn(msg, DeprecationWarning) @@ -105,7 +108,7 @@ class BaseField(object): if instance._initialised: try: if (self.name not in instance._data or - instance._data[self.name] != value): + instance._data[self.name] != value): instance._mark_as_changed(self.name) except: # Values cant be compared eg: naive and tz datetimes @@ -175,6 +178,7 @@ class BaseField(object): class ComplexBaseField(BaseField): + """Handles complex fields, such as lists / dictionaries. Allows for nesting of embedded documents inside complex types. @@ -197,7 +201,7 @@ class ComplexBaseField(BaseField): GenericReferenceField = _import_class('GenericReferenceField') dereference = (self._auto_dereference and (self.field is None or isinstance(self.field, - (GenericReferenceField, ReferenceField)))) + (GenericReferenceField, ReferenceField)))) _dereference = _import_class("DeReference")() @@ -212,7 +216,7 @@ class ComplexBaseField(BaseField): # Convert lists / values so we can watch for any changes on them if (isinstance(value, (list, tuple)) and - not isinstance(value, BaseList)): + not isinstance(value, BaseList)): value = BaseList(value, instance, self.name) instance._data[self.name] = value elif isinstance(value, dict) and not isinstance(value, BaseDict): @@ -220,8 +224,8 @@ class ComplexBaseField(BaseField): instance._data[self.name] = value if (self._auto_dereference and instance._initialised and - isinstance(value, (BaseList, BaseDict)) - and not value._dereferenced): + isinstance(value, (BaseList, BaseDict)) + and not value._dereferenced): value = _dereference( value, max_depth=1, instance=instance, name=self.name ) @@ -384,6 +388,7 @@ class ComplexBaseField(BaseField): class ObjectIdField(BaseField): + """A field wrapper around MongoDB's ObjectIds. """ @@ -412,6 +417,7 @@ class ObjectIdField(BaseField): class GeoJsonBaseField(BaseField): + """A geo json field storing a geojson style object. .. versionadded:: 0.8 """ @@ -435,7 +441,8 @@ class GeoJsonBaseField(BaseField): if isinstance(value, dict): if set(value.keys()) == set(['type', 'coordinates']): if value['type'] != self._type: - self.error('%s type must be "%s"' % (self._name, self._type)) + self.error('%s type must be "%s"' % + (self._name, self._type)) return self.validate(value['coordinates']) else: self.error('%s can only accept a valid GeoJson dictionary' diff --git a/mongoengine/base/metaclasses.py b/mongoengine/base/metaclasses.py index 887c9abc..b7157a35 100644 --- a/mongoengine/base/metaclasses.py +++ b/mongoengine/base/metaclasses.py @@ -30,7 +30,8 @@ class DocumentMetaclass(type): return super_new(cls, name, bases, attrs) attrs['_is_document'] = attrs.get('_is_document', False) - + attrs['_cached_reference_fields'] = [] + # EmbeddedDocuments could have meta data for inheritance if 'meta' in attrs: attrs['_meta'] = attrs.pop('meta') @@ -172,10 +173,17 @@ class DocumentMetaclass(type): f = field f.owner_document = new_class delete_rule = getattr(f, 'reverse_delete_rule', DO_NOTHING) - if isinstance(f, CachedReferenceField) and issubclass( - new_class, EmbeddedDocument): - raise InvalidDocumentError( - "CachedReferenceFields is not allowed in EmbeddedDocuments") + if isinstance(f, CachedReferenceField): + + if issubclass(new_class, EmbeddedDocument): + raise InvalidDocumentError( + "CachedReferenceFields is not allowed in EmbeddedDocuments") + if not f.document_type: + raise InvalidDocumentError( + "Document is not avaiable to sync") + + f.document_type._cached_reference_fields.append(f) + if isinstance(f, ComplexBaseField) and hasattr(f, 'field'): delete_rule = getattr(f.field, 'reverse_delete_rule', diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 58271435..abe2a491 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -1047,12 +1047,13 @@ class CachedReferenceField(BaseField): doc_tipe = self.document_type if isinstance(document, Document): - # We need the id from the saved object to create the DBRef + # Wen need the id from the saved object to create the DBRef id_ = document.pk if id_ is None: self.error('You can only reference documents once they have' ' been saved to the database') else: + raise SystemError(document) self.error('Only accept a document object') value = { @@ -1065,7 +1066,14 @@ class CachedReferenceField(BaseField): def prepare_query_value(self, op, value): if value is None: return None - return self.to_mongo(value) + + if isinstance(value, Document): + if value.pk is None: + self.error('You can only reference documents once they have' + ' been saved to the database') + return {'_id': value.pk} + + raise NotImplementedError def validate(self, value): @@ -1079,6 +1087,22 @@ class CachedReferenceField(BaseField): def lookup_member(self, member_name): return self.document_type._fields.get(member_name) + def sync_all(self): + update_key = 'set__%s' % self.name + errors = [] + + for doc in self.document_type.objects: + filter_kwargs = {} + filter_kwargs[self.name] = doc + + update_kwargs = {} + update_kwargs[update_key] = doc + + errors.append((filter_kwargs, update_kwargs)) + + self.owner_document.objects( + **filter_kwargs).update(**update_kwargs) + class GenericReferenceField(BaseField): diff --git a/mongoengine/queryset/transform.py b/mongoengine/queryset/transform.py index 3345ae64..e575d9d6 100644 --- a/mongoengine/queryset/transform.py +++ b/mongoengine/queryset/transform.py @@ -12,21 +12,21 @@ __all__ = ('query', 'update') COMPARISON_OPERATORS = ('ne', 'gt', 'gte', 'lt', 'lte', 'in', 'nin', 'mod', 'all', 'size', 'exists', 'not', 'elemMatch') -GEO_OPERATORS = ('within_distance', 'within_spherical_distance', - 'within_box', 'within_polygon', 'near', 'near_sphere', - 'max_distance', 'geo_within', 'geo_within_box', - 'geo_within_polygon', 'geo_within_center', - 'geo_within_sphere', 'geo_intersects') -STRING_OPERATORS = ('contains', 'icontains', 'startswith', - 'istartswith', 'endswith', 'iendswith', - 'exact', 'iexact') -CUSTOM_OPERATORS = ('match',) -MATCH_OPERATORS = (COMPARISON_OPERATORS + GEO_OPERATORS + - STRING_OPERATORS + CUSTOM_OPERATORS) +GEO_OPERATORS = ('within_distance', 'within_spherical_distance', + 'within_box', 'within_polygon', 'near', 'near_sphere', + 'max_distance', 'geo_within', 'geo_within_box', + 'geo_within_polygon', 'geo_within_center', + 'geo_within_sphere', 'geo_intersects') +STRING_OPERATORS = ('contains', 'icontains', 'startswith', + 'istartswith', 'endswith', 'iendswith', + 'exact', 'iexact') +CUSTOM_OPERATORS = ('match',) +MATCH_OPERATORS = (COMPARISON_OPERATORS + GEO_OPERATORS + + STRING_OPERATORS + CUSTOM_OPERATORS) -UPDATE_OPERATORS = ('set', 'unset', 'inc', 'dec', 'pop', 'push', - 'push_all', 'pull', 'pull_all', 'add_to_set', - 'set_on_insert') +UPDATE_OPERATORS = ('set', 'unset', 'inc', 'dec', 'pop', 'push', + 'push_all', 'pull', 'pull_all', 'add_to_set', + 'set_on_insert') def query(_doc_cls=None, _field_operation=False, **query): @@ -60,14 +60,20 @@ def query(_doc_cls=None, _field_operation=False, **query): raise InvalidQueryError(e) parts = [] + CachedReferenceField = _import_class('CachedReferenceField') + cleaned_fields = [] for field in fields: append_field = True if isinstance(field, basestring): parts.append(field) append_field = False + # is last and CachedReferenceField + elif isinstance(field, CachedReferenceField) and fields[-1] == field: + parts.append('%s._id' % field.db_field) else: parts.append(field.db_field) + if append_field: cleaned_fields.append(field) @@ -79,13 +85,17 @@ def query(_doc_cls=None, _field_operation=False, **query): if op in singular_ops: if isinstance(field, basestring): if (op in STRING_OPERATORS and - isinstance(value, basestring)): + isinstance(value, basestring)): StringField = _import_class('StringField') value = StringField.prepare_query_value(op, value) else: value = field else: value = field.prepare_query_value(op, value) + + if isinstance(field, CachedReferenceField) and value: + value = value['_id'] + elif op in ('in', 'nin', 'all', 'near') and not isinstance(value, dict): # 'in', 'nin' and 'all' require a list of values value = [field.prepare_query_value(op, v) for v in value] @@ -125,10 +135,12 @@ def query(_doc_cls=None, _field_operation=False, **query): continue value_son[k] = v if (get_connection().max_wire_version <= 1): - value_son['$maxDistance'] = value_dict['$maxDistance'] + value_son['$maxDistance'] = value_dict[ + '$maxDistance'] else: value_son['$near'] = SON(value_son['$near']) - value_son['$near']['$maxDistance'] = value_dict['$maxDistance'] + value_son['$near'][ + '$maxDistance'] = value_dict['$maxDistance'] else: for k, v in value_dict.iteritems(): if k == '$maxDistance': @@ -264,7 +276,8 @@ def update(_doc_cls=None, **update): if ListField in field_classes: # Join all fields via dot notation to the last ListField # Then process as normal - last_listField = len(cleaned_fields) - field_classes.index(ListField) + last_listField = len( + cleaned_fields) - field_classes.index(ListField) key = ".".join(parts[:last_listField]) parts = parts[last_listField:] parts.insert(0, key) diff --git a/tests/fields/fields.py b/tests/fields/fields.py index c82c936b..d5ae3329 100644 --- a/tests/fields/fields.py +++ b/tests/fields/fields.py @@ -1518,11 +1518,18 @@ class FieldTest(unittest.TestCase): Animal.drop_collection() Ocorrence.drop_collection() - a = Animal(nam="Leopard", tag="heavy") + a = Animal(name="Leopard", tag="heavy") a.save() + self.assertEqual(Animal._cached_reference_fields, [Ocorrence.animal]) o = Ocorrence(person="teste", animal=a) o.save() + + p = Ocorrence(person="Wilson") + p.save() + + self.assertEqual(Ocorrence.objects(animal=None).count(), 1) + self.assertEqual( a.to_mongo(fields=['tag']), {'tag': 'heavy', "_id": a.pk}) @@ -1539,6 +1546,56 @@ class FieldTest(unittest.TestCase): self.assertEqual(ocorrence.person, "teste") self.assertTrue(isinstance(ocorrence.animal, Animal)) + def test_cached_reference_field_update_all(self): + class Person(Document): + TYPES = ( + ('pf', "PF"), + ('pj', "PJ") + ) + name = StringField() + tp = StringField( + choices=TYPES + ) + + father = CachedReferenceField('self', fields=('tp',)) + + Person.drop_collection() + + a1 = Person(name="Wilson Father", tp="pj") + a1.save() + + a2 = Person(name='Wilson Junior', tp='pf', father=a1) + a2.save() + + self.assertEqual(dict(a2.to_mongo()), { + "_id": a2.pk, + "name": u"Wilson Junior", + "tp": u"pf", + "father": { + "_id": a1.pk, + "tp": u"pj" + } + }) + + self.assertEqual(Person.objects(father=a1)._query, { + 'father._id': a1.pk + }) + self.assertEqual(Person.objects(father=a1).count(), 1) + + Person.objects.update(set__tp="pf") + Person.father.sync_all() + + a2.reload() + self.assertEqual(dict(a2.to_mongo()), { + "_id": a2.pk, + "name": u"Wilson Junior", + "tp": u"pf", + "father": { + "_id": a1.pk, + "tp": u"pf" + } + }) + def test_cached_reference_fields_on_embedded_documents(self): def build(): class Test(Document):