refs #709, added CachedReferenceField.sync_all to sync all documents on demand
parent 6c4aee1479, commit 87c97efce0
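The new CachedReferenceField.sync_all() walks every document of the referenced type and re-writes the cached copy held by the owning documents, so caches can be refreshed after bulk updates that bypass save(). A rough usage sketch, reusing the Person model from the test added at the bottom of this diff (the model, field names, and database name are illustrative, taken from that test, not from library documentation):

    from mongoengine import connect, Document, StringField, CachedReferenceField

    connect('example_db')  # assumes a local MongoDB instance

    class Person(Document):
        name = StringField()
        tp = StringField(choices=(('pf', 'PF'), ('pj', 'PJ')))
        # cache only the 'tp' field of the referenced Person
        father = CachedReferenceField('self', fields=('tp',))

    # A queryset-level update does not go through Document.save(), so the
    # cached 'father.tp' values embedded in other documents go stale...
    Person.objects.update(set__tp='pf')

    # ...and sync_all() re-writes the cached fields on every document that
    # references a Person through this field.
    Person.father.sync_all()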
@@ -11,10 +11,12 @@ from mongoengine.errors import ValidationError
 from mongoengine.base.common import ALLOW_INHERITANCE
 from mongoengine.base.datastructures import BaseDict, BaseList
 
-__all__ = ("BaseField", "ComplexBaseField", "ObjectIdField", "GeoJsonBaseField")
+__all__ = ("BaseField", "ComplexBaseField",
+           "ObjectIdField", "GeoJsonBaseField")
 
 
 class BaseField(object):
+
     """A base class for fields in a MongoDB document. Instances of this class
     may be added to subclasses of `Document` to define a document's schema.
@@ -60,6 +62,7 @@ class BaseField(object):
             used when generating model forms from the document model.
         """
         self.db_field = (db_field or name) if not primary_key else '_id'
+
         if name:
             msg = "Fields' 'name' attribute deprecated in favour of 'db_field'"
             warnings.warn(msg, DeprecationWarning)
@@ -105,7 +108,7 @@ class BaseField(object):
         if instance._initialised:
             try:
                 if (self.name not in instance._data or
                         instance._data[self.name] != value):
                     instance._mark_as_changed(self.name)
             except:
                 # Values cant be compared eg: naive and tz datetimes
@@ -175,6 +178,7 @@ class BaseField(object):
 
 
 class ComplexBaseField(BaseField):
+
     """Handles complex fields, such as lists / dictionaries.
 
     Allows for nesting of embedded documents inside complex types.
@@ -197,7 +201,7 @@ class ComplexBaseField(BaseField):
         GenericReferenceField = _import_class('GenericReferenceField')
         dereference = (self._auto_dereference and
                        (self.field is None or isinstance(self.field,
                         (GenericReferenceField, ReferenceField))))
 
         _dereference = _import_class("DeReference")()
 
@@ -212,7 +216,7 @@ class ComplexBaseField(BaseField):
 
         # Convert lists / values so we can watch for any changes on them
         if (isinstance(value, (list, tuple)) and
                 not isinstance(value, BaseList)):
             value = BaseList(value, instance, self.name)
             instance._data[self.name] = value
         elif isinstance(value, dict) and not isinstance(value, BaseDict):
@@ -220,8 +224,8 @@ class ComplexBaseField(BaseField):
             instance._data[self.name] = value
 
         if (self._auto_dereference and instance._initialised and
                 isinstance(value, (BaseList, BaseDict))
                 and not value._dereferenced):
             value = _dereference(
                 value, max_depth=1, instance=instance, name=self.name
             )
@@ -384,6 +388,7 @@ class ComplexBaseField(BaseField):
 
 
 class ObjectIdField(BaseField):
+
     """A field wrapper around MongoDB's ObjectIds.
     """
 
@@ -412,6 +417,7 @@ class ObjectIdField(BaseField):
 
 
 class GeoJsonBaseField(BaseField):
+
     """A geo json field storing a geojson style object.
     .. versionadded:: 0.8
     """
@@ -435,7 +441,8 @@ class GeoJsonBaseField(BaseField):
         if isinstance(value, dict):
             if set(value.keys()) == set(['type', 'coordinates']):
                 if value['type'] != self._type:
-                    self.error('%s type must be "%s"' % (self._name, self._type))
+                    self.error('%s type must be "%s"' %
+                               (self._name, self._type))
                 return self.validate(value['coordinates'])
             else:
                 self.error('%s can only accept a valid GeoJson dictionary'
@@ -30,7 +30,8 @@ class DocumentMetaclass(type):
             return super_new(cls, name, bases, attrs)
 
         attrs['_is_document'] = attrs.get('_is_document', False)
+        attrs['_cached_reference_fields'] = []
 
         # EmbeddedDocuments could have meta data for inheritance
         if 'meta' in attrs:
             attrs['_meta'] = attrs.pop('meta')
@@ -172,10 +173,17 @@ class DocumentMetaclass(type):
             f = field
             f.owner_document = new_class
             delete_rule = getattr(f, 'reverse_delete_rule', DO_NOTHING)
-            if isinstance(f, CachedReferenceField) and issubclass(
-                    new_class, EmbeddedDocument):
-                raise InvalidDocumentError(
-                    "CachedReferenceFields is not allowed in EmbeddedDocuments")
+            if isinstance(f, CachedReferenceField):
+
+                if issubclass(new_class, EmbeddedDocument):
+                    raise InvalidDocumentError(
+                        "CachedReferenceFields is not allowed in EmbeddedDocuments")
+
+                if not f.document_type:
+                    raise InvalidDocumentError(
+                        "Document is not avaiable to sync")
+
+                f.document_type._cached_reference_fields.append(f)
             if isinstance(f, ComplexBaseField) and hasattr(f, 'field'):
                 delete_rule = getattr(f.field,
                                       'reverse_delete_rule',
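The metaclass change above gives every Document class a _cached_reference_fields registry and, for each CachedReferenceField, registers the field on the class it references. The test at the bottom of this diff checks exactly that behaviour (Animal and Ocorrence are the test suite's own example models):

    # Ocorrence, defined earlier in the test module, declares
    # animal = CachedReferenceField(Animal, ...), so the referenced class
    # ends up knowing which fields cache it:
    Animal._cached_reference_fields   # -> [Ocorrence.animal]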
@@ -1047,12 +1047,13 @@ class CachedReferenceField(BaseField):
         doc_tipe = self.document_type
 
         if isinstance(document, Document):
-            # We need the id from the saved object to create the DBRef
+            # Wen need the id from the saved object to create the DBRef
             id_ = document.pk
             if id_ is None:
                 self.error('You can only reference documents once they have'
                            ' been saved to the database')
         else:
+            raise SystemError(document)
             self.error('Only accept a document object')
 
         value = {
@@ -1065,7 +1066,14 @@ class CachedReferenceField(BaseField):
     def prepare_query_value(self, op, value):
         if value is None:
             return None
-        return self.to_mongo(value)
+
+        if isinstance(value, Document):
+            if value.pk is None:
+                self.error('You can only reference documents once they have'
+                           ' been saved to the database')
+            return {'_id': value.pk}
+
+        raise NotImplementedError
 
     def validate(self, value):
 
@@ -1079,6 +1087,22 @@ class CachedReferenceField(BaseField):
     def lookup_member(self, member_name):
         return self.document_type._fields.get(member_name)
 
+    def sync_all(self):
+        update_key = 'set__%s' % self.name
+        errors = []
+
+        for doc in self.document_type.objects:
+            filter_kwargs = {}
+            filter_kwargs[self.name] = doc
+
+            update_kwargs = {}
+            update_kwargs[update_key] = doc
+
+            errors.append((filter_kwargs, update_kwargs))
+
+            self.owner_document.objects(
+                **filter_kwargs).update(**update_kwargs)
+
 
 class GenericReferenceField(BaseField):
 
@@ -12,21 +12,21 @@ __all__ = ('query', 'update')
 
 COMPARISON_OPERATORS = ('ne', 'gt', 'gte', 'lt', 'lte', 'in', 'nin', 'mod',
                         'all', 'size', 'exists', 'not', 'elemMatch')
 GEO_OPERATORS = ('within_distance', 'within_spherical_distance',
                  'within_box', 'within_polygon', 'near', 'near_sphere',
                  'max_distance', 'geo_within', 'geo_within_box',
                  'geo_within_polygon', 'geo_within_center',
                  'geo_within_sphere', 'geo_intersects')
 STRING_OPERATORS = ('contains', 'icontains', 'startswith',
                     'istartswith', 'endswith', 'iendswith',
                     'exact', 'iexact')
 CUSTOM_OPERATORS = ('match',)
 MATCH_OPERATORS = (COMPARISON_OPERATORS + GEO_OPERATORS +
                    STRING_OPERATORS + CUSTOM_OPERATORS)
 
 UPDATE_OPERATORS = ('set', 'unset', 'inc', 'dec', 'pop', 'push',
                     'push_all', 'pull', 'pull_all', 'add_to_set',
                     'set_on_insert')
 
 
 def query(_doc_cls=None, _field_operation=False, **query):
@@ -60,14 +60,20 @@ def query(_doc_cls=None, _field_operation=False, **query):
             raise InvalidQueryError(e)
         parts = []
 
+        CachedReferenceField = _import_class('CachedReferenceField')
+
         cleaned_fields = []
         for field in fields:
             append_field = True
             if isinstance(field, basestring):
                 parts.append(field)
                 append_field = False
+            # is last and CachedReferenceField
+            elif isinstance(field, CachedReferenceField) and fields[-1] == field:
+                parts.append('%s._id' % field.db_field)
             else:
                 parts.append(field.db_field)
 
             if append_field:
                 cleaned_fields.append(field)
@@ -79,13 +85,17 @@ def query(_doc_cls=None, _field_operation=False, **query):
         if op in singular_ops:
             if isinstance(field, basestring):
                 if (op in STRING_OPERATORS and
                         isinstance(value, basestring)):
                     StringField = _import_class('StringField')
                     value = StringField.prepare_query_value(op, value)
                 else:
                     value = field
             else:
                 value = field.prepare_query_value(op, value)
+
+                if isinstance(field, CachedReferenceField) and value:
+                    value = value['_id']
+
         elif op in ('in', 'nin', 'all', 'near') and not isinstance(value, dict):
             # 'in', 'nin' and 'all' require a list of values
             value = [field.prepare_query_value(op, v) for v in value]
@@ -125,10 +135,12 @@ def query(_doc_cls=None, _field_operation=False, **query):
                         continue
                     value_son[k] = v
                 if (get_connection().max_wire_version <= 1):
-                    value_son['$maxDistance'] = value_dict['$maxDistance']
+                    value_son['$maxDistance'] = value_dict[
+                        '$maxDistance']
                 else:
                     value_son['$near'] = SON(value_son['$near'])
-                    value_son['$near']['$maxDistance'] = value_dict['$maxDistance']
+                    value_son['$near'][
+                        '$maxDistance'] = value_dict['$maxDistance']
             else:
                 for k, v in value_dict.iteritems():
                     if k == '$maxDistance':
@@ -264,7 +276,8 @@ def update(_doc_cls=None, **update):
             if ListField in field_classes:
                 # Join all fields via dot notation to the last ListField
                 # Then process as normal
-                last_listField = len(cleaned_fields) - field_classes.index(ListField)
+                last_listField = len(
+                    cleaned_fields) - field_classes.index(ListField)
                 key = ".".join(parts[:last_listField])
                 parts = parts[last_listField:]
                 parts.insert(0, key)
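With the query-translation changes above, filtering on a CachedReferenceField by a document object is rewritten to match the embedded copy's _id instead of going through to_mongo(). The test below pins this down; a minimal illustration using that test's Person model (the names come from the test, not from the library docs):

    a1 = Person(name="Wilson Father", tp="pj")
    a1.save()

    # The queryset translates the document into an _id match on the cached copy:
    Person.objects(father=a1)._query   # -> {'father._id': a1.pk}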
@@ -1518,11 +1518,18 @@ class FieldTest(unittest.TestCase):
         Animal.drop_collection()
         Ocorrence.drop_collection()
 
-        a = Animal(nam="Leopard", tag="heavy")
+        a = Animal(name="Leopard", tag="heavy")
         a.save()
+
+        self.assertEqual(Animal._cached_reference_fields, [Ocorrence.animal])
         o = Ocorrence(person="teste", animal=a)
         o.save()
+
+        p = Ocorrence(person="Wilson")
+        p.save()
+
+        self.assertEqual(Ocorrence.objects(animal=None).count(), 1)
+
         self.assertEqual(
             a.to_mongo(fields=['tag']), {'tag': 'heavy', "_id": a.pk})
 
@@ -1539,6 +1546,56 @@ class FieldTest(unittest.TestCase):
         self.assertEqual(ocorrence.person, "teste")
         self.assertTrue(isinstance(ocorrence.animal, Animal))
 
+    def test_cached_reference_field_update_all(self):
+        class Person(Document):
+            TYPES = (
+                ('pf', "PF"),
+                ('pj', "PJ")
+            )
+            name = StringField()
+            tp = StringField(
+                choices=TYPES
+            )
+
+            father = CachedReferenceField('self', fields=('tp',))
+
+        Person.drop_collection()
+
+        a1 = Person(name="Wilson Father", tp="pj")
+        a1.save()
+
+        a2 = Person(name='Wilson Junior', tp='pf', father=a1)
+        a2.save()
+
+        self.assertEqual(dict(a2.to_mongo()), {
+            "_id": a2.pk,
+            "name": u"Wilson Junior",
+            "tp": u"pf",
+            "father": {
+                "_id": a1.pk,
+                "tp": u"pj"
+            }
+        })
+
+        self.assertEqual(Person.objects(father=a1)._query, {
+            'father._id': a1.pk
+        })
+        self.assertEqual(Person.objects(father=a1).count(), 1)
+
+        Person.objects.update(set__tp="pf")
+        Person.father.sync_all()
+
+        a2.reload()
+        self.assertEqual(dict(a2.to_mongo()), {
+            "_id": a2.pk,
+            "name": u"Wilson Junior",
+            "tp": u"pf",
+            "father": {
+                "_id": a1.pk,
+                "tp": u"pf"
+            }
+        })
+
     def test_cached_reference_fields_on_embedded_documents(self):
         def build():
             class Test(Document):