refs #709, added CachedReferenceField.sync_all to sync all documents on demand

Wilson Júnior 2014-07-25 08:44:59 -03:00
parent 6c4aee1479
commit 87c97efce0
5 changed files with 142 additions and 33 deletions
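The new API in one picture — a minimal usage sketch based on the test_cached_reference_field_update_all test added below (the Person model and field names come from that test; the only new API is CachedReferenceField.sync_all):

from mongoengine import Document, StringField, CachedReferenceField

class Person(Document):
    name = StringField()
    tp = StringField(choices=(('pf', "PF"), ('pj', "PJ")))
    # Cache only 'tp' from the referenced Person on each referencing document
    father = CachedReferenceField('self', fields=('tp',))

dad = Person(name="Wilson Father", tp="pj").save()
son = Person(name="Wilson Junior", tp="pf", father=dad).save()

# A bulk queryset update bypasses Document.save(), so the cached copies of
# 'father' embedded in other documents go stale...
Person.objects.update(set__tp="pf")

# ...and sync_all rewrites the cached sub-document on every referencing
# document from the current state of each referenced Person.
Person.father.sync_all()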

View File

@@ -11,10 +11,12 @@ from mongoengine.errors import ValidationError
 from mongoengine.base.common import ALLOW_INHERITANCE
 from mongoengine.base.datastructures import BaseDict, BaseList
-__all__ = ("BaseField", "ComplexBaseField", "ObjectIdField", "GeoJsonBaseField")
+__all__ = ("BaseField", "ComplexBaseField",
+           "ObjectIdField", "GeoJsonBaseField")
 class BaseField(object):
     """A base class for fields in a MongoDB document. Instances of this class
     may be added to subclasses of `Document` to define a document's schema.
@@ -60,6 +62,7 @@ class BaseField(object):
         used when generating model forms from the document model.
         """
         self.db_field = (db_field or name) if not primary_key else '_id'
         if name:
             msg = "Fields' 'name' attribute deprecated in favour of 'db_field'"
             warnings.warn(msg, DeprecationWarning)
@@ -175,6 +178,7 @@ class BaseField(object):
 class ComplexBaseField(BaseField):
     """Handles complex fields, such as lists / dictionaries.
     Allows for nesting of embedded documents inside complex types.
@@ -384,6 +388,7 @@ class ComplexBaseField(BaseField):
 class ObjectIdField(BaseField):
     """A field wrapper around MongoDB's ObjectIds.
     """
@@ -412,6 +417,7 @@ class ObjectIdField(BaseField):
 class GeoJsonBaseField(BaseField):
     """A geo json field storing a geojson style object.
     .. versionadded:: 0.8
     """
@@ -435,7 +441,8 @@ class GeoJsonBaseField(BaseField):
         if isinstance(value, dict):
             if set(value.keys()) == set(['type', 'coordinates']):
                 if value['type'] != self._type:
-                    self.error('%s type must be "%s"' % (self._name, self._type))
+                    self.error('%s type must be "%s"' %
+                               (self._name, self._type))
                 return self.validate(value['coordinates'])
             else:
                 self.error('%s can only accept a valid GeoJson dictionary'
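The GeoJsonBaseField hunks above are only autopep8 line wrapping; for context, the validation they touch accepts plain GeoJSON-style dicts. A tiny sketch, assuming a PointField, whose _type is 'Point':

# Accepted: exactly the keys 'type' and 'coordinates', with a matching type;
# the coordinates are then validated as if they had been passed directly.
value = {'type': 'Point', 'coordinates': [-87.62, 41.88]}
# A dict of any other shape hits the "can only accept a valid GeoJson
# dictionary" error path shown above.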

View File

@@ -30,6 +30,7 @@ class DocumentMetaclass(type):
             return super_new(cls, name, bases, attrs)
         attrs['_is_document'] = attrs.get('_is_document', False)
+        attrs['_cached_reference_fields'] = []
         # EmbeddedDocuments could have meta data for inheritance
         if 'meta' in attrs:
@@ -172,10 +173,17 @@ class DocumentMetaclass(type):
            f = field
            f.owner_document = new_class
            delete_rule = getattr(f, 'reverse_delete_rule', DO_NOTHING)
-           if isinstance(f, CachedReferenceField) and issubclass(
-                   new_class, EmbeddedDocument):
-               raise InvalidDocumentError(
-                   "CachedReferenceFields is not allowed in EmbeddedDocuments")
+           if isinstance(f, CachedReferenceField):
+               if issubclass(new_class, EmbeddedDocument):
+                   raise InvalidDocumentError(
+                       "CachedReferenceFields is not allowed in EmbeddedDocuments")
+               if not f.document_type:
+                   raise InvalidDocumentError(
+                       "Document is not avaiable to sync")
+               f.document_type._cached_reference_fields.append(f)
            if isinstance(f, ComplexBaseField) and hasattr(f, 'field'):
                delete_rule = getattr(f.field,
                                      'reverse_delete_rule',
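The net effect of this metaclass change: every referenced document class now keeps a list of the CachedReferenceFields pointing at it, and declaring such a field on an EmbeddedDocument fails at class-definition time. A sketch using the Animal/Ocorrence names from the existing tests — the field definitions here are reconstructed assumptions; only the final assertion appears verbatim in the test below:

from mongoengine import Document, StringField, CachedReferenceField

class Animal(Document):
    name = StringField()
    tag = StringField()

class Ocorrence(Document):
    person = StringField()
    animal = CachedReferenceField(Animal, fields=['tag'])  # assumed definition

# DocumentMetaclass registers the field on the *referenced* class:
assert Animal._cached_reference_fields == [Ocorrence.animal]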

View File

@@ -1047,12 +1047,13 @@ class CachedReferenceField(BaseField):
         doc_tipe = self.document_type
         if isinstance(document, Document):
             # We need the id from the saved object to create the DBRef
             id_ = document.pk
             if id_ is None:
                 self.error('You can only reference documents once they have'
                            ' been saved to the database')
         else:
+            raise SystemError(document)
             self.error('Only accept a document object')
         value = {
@@ -1065,7 +1066,14 @@ class CachedReferenceField(BaseField):
     def prepare_query_value(self, op, value):
         if value is None:
             return None
-        return self.to_mongo(value)
+        if isinstance(value, Document):
+            if value.pk is None:
+                self.error('You can only reference documents once they have'
+                           ' been saved to the database')
+            return {'_id': value.pk}
+        raise NotImplementedError
     def validate(self, value):
@@ -1079,6 +1087,22 @@ class CachedReferenceField(BaseField):
     def lookup_member(self, member_name):
         return self.document_type._fields.get(member_name)
+    def sync_all(self):
+        update_key = 'set__%s' % self.name
+        errors = []
+        for doc in self.document_type.objects:
+            filter_kwargs = {}
+            filter_kwargs[self.name] = doc
+            update_kwargs = {}
+            update_kwargs[update_key] = doc
+            errors.append((filter_kwargs, update_kwargs))
+            self.owner_document.objects(
+                **filter_kwargs).update(**update_kwargs)
 class GenericReferenceField(BaseField):
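Each pass of the sync_all loop above is a single queryset update against the owner document. Spelled out for the Person/father model from the new test (a sketch of what the generated kwargs amount to, not additional API):

# For one referenced document `doc`, sync_all builds and runs roughly:
filter_kwargs = {'father': doc}          # filter_kwargs[self.name] = doc
update_kwargs = {'set__father': doc}     # 'set__%s' % self.name
Person.objects(**filter_kwargs).update(**update_kwargs)
# i.e. every Person whose cached 'father' points at `doc` has that cached
# sub-document rewritten from doc's current field values.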

View File

@@ -60,14 +60,20 @@ def query(_doc_cls=None, _field_operation=False, **query):
                 raise InvalidQueryError(e)
             parts = []
+            CachedReferenceField = _import_class('CachedReferenceField')
             cleaned_fields = []
             for field in fields:
                 append_field = True
                 if isinstance(field, basestring):
                     parts.append(field)
                     append_field = False
+                # is last and CachedReferenceField
+                elif isinstance(field, CachedReferenceField) and fields[-1] == field:
+                    parts.append('%s._id' % field.db_field)
                 else:
                     parts.append(field.db_field)
                 if append_field:
                     cleaned_fields.append(field)
@@ -86,6 +92,10 @@ def query(_doc_cls=None, _field_operation=False, **query):
                    value = field
                else:
                    value = field.prepare_query_value(op, value)
+                   if isinstance(field, CachedReferenceField) and value:
+                       value = value['_id']
            elif op in ('in', 'nin', 'all', 'near') and not isinstance(value, dict):
                # 'in', 'nin' and 'all' require a list of values
                value = [field.prepare_query_value(op, v) for v in value]
@@ -125,10 +135,12 @@ def query(_doc_cls=None, _field_operation=False, **query):
                        continue
                    value_son[k] = v
                if (get_connection().max_wire_version <= 1):
-                   value_son['$maxDistance'] = value_dict['$maxDistance']
+                   value_son['$maxDistance'] = value_dict[
+                       '$maxDistance']
                else:
                    value_son['$near'] = SON(value_son['$near'])
-                   value_son['$near']['$maxDistance'] = value_dict['$maxDistance']
+                   value_son['$near'][
+                       '$maxDistance'] = value_dict['$maxDistance']
            else:
                for k, v in value_dict.iteritems():
                    if k == '$maxDistance':
@@ -264,7 +276,8 @@ def update(_doc_cls=None, **update):
        if ListField in field_classes:
            # Join all fields via dot notation to the last ListField
            # Then process as normal
-           last_listField = len(cleaned_fields) - field_classes.index(ListField)
+           last_listField = len(
+               cleaned_fields) - field_classes.index(ListField)
            key = ".".join(parts[:last_listField])
            parts = parts[last_listField:]
            parts.insert(0, key)
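With these transform changes, filtering on a CachedReferenceField by a document instance compares only the embedded _id. The new test below asserts exactly this; as a quick sketch with its Person model:

a1 = Person(name="Wilson Father", tp="pj").save()

# The CachedReferenceField is the last lookup part, so the key becomes
# 'father._id', and prepare_query_value reduces the document to its pk:
assert Person.objects(father=a1)._query == {'father._id': a1.pk}

# Querying with None is unchanged and still matches documents with no cached
# reference stored (see the Ocorrence.objects(animal=None) assertion below).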

View File

@@ -1518,11 +1518,18 @@ class FieldTest(unittest.TestCase):
         Animal.drop_collection()
         Ocorrence.drop_collection()
-        a = Animal(nam="Leopard", tag="heavy")
+        a = Animal(name="Leopard", tag="heavy")
         a.save()
+        self.assertEqual(Animal._cached_reference_fields, [Ocorrence.animal])
         o = Ocorrence(person="teste", animal=a)
         o.save()
+        p = Ocorrence(person="Wilson")
+        p.save()
+        self.assertEqual(Ocorrence.objects(animal=None).count(), 1)
         self.assertEqual(
             a.to_mongo(fields=['tag']), {'tag': 'heavy', "_id": a.pk})
@@ -1539,6 +1546,56 @@ class FieldTest(unittest.TestCase):
         self.assertEqual(ocorrence.person, "teste")
         self.assertTrue(isinstance(ocorrence.animal, Animal))
+    def test_cached_reference_field_update_all(self):
+        class Person(Document):
+            TYPES = (
+                ('pf', "PF"),
+                ('pj', "PJ")
+            )
+            name = StringField()
+            tp = StringField(
+                choices=TYPES
+            )
+            father = CachedReferenceField('self', fields=('tp',))
+
+        Person.drop_collection()
+
+        a1 = Person(name="Wilson Father", tp="pj")
+        a1.save()
+
+        a2 = Person(name='Wilson Junior', tp='pf', father=a1)
+        a2.save()
+
+        self.assertEqual(dict(a2.to_mongo()), {
+            "_id": a2.pk,
+            "name": u"Wilson Junior",
+            "tp": u"pf",
+            "father": {
+                "_id": a1.pk,
+                "tp": u"pj"
+            }
+        })
+
+        self.assertEqual(Person.objects(father=a1)._query, {
+            'father._id': a1.pk
+        })
+        self.assertEqual(Person.objects(father=a1).count(), 1)
+
+        Person.objects.update(set__tp="pf")
+        Person.father.sync_all()
+
+        a2.reload()
+        self.assertEqual(dict(a2.to_mongo()), {
+            "_id": a2.pk,
+            "name": u"Wilson Junior",
+            "tp": u"pf",
+            "father": {
+                "_id": a1.pk,
+                "tp": u"pf"
+            }
+        })
+
     def test_cached_reference_fields_on_embedded_documents(self):
         def build():
             class Test(Document):