Added ComplexBaseField

* Handles the efficient lazy dereferencing of DBrefs.
* Handles complex nested values in ListFields and DictFields
* Allows for both strictly declared ListFields and DictFields where the embedded
value must be of a field type or no restrictions where the values can be a mix
of field types / values.
* Handles DBrefences of documents where allow_inheritance = False.
This commit is contained in:
Ross Lawley
2011-06-06 17:21:54 +01:00
parent 602d7dad00
commit 4b9bacf731
5 changed files with 555 additions and 199 deletions

View File

@@ -132,15 +132,19 @@ class BaseField(object):
self.validate(value)
class DereferenceBaseField(BaseField):
"""Handles the lazy dereferencing of a queryset. Will dereference all
class ComplexBaseField(BaseField):
"""Handles complex fields, such as lists / dictionaries.
Allows for nesting of embedded documents inside complex types.
Handles the lazy dereferencing of a queryset by lazily dereferencing all
items in a list / dict rather than one at a time.
"""
field = None
def __get__(self, instance, owner):
"""Descriptor to automatically dereference references.
"""
from fields import ReferenceField, GenericReferenceField
from connection import _get_db
if instance is None:
@@ -149,68 +153,175 @@ class DereferenceBaseField(BaseField):
# Get value from document instance if available
value_list = instance._data.get(self.name)
if not value_list:
return super(DereferenceBaseField, self).__get__(instance, owner)
if not value_list or isinstance(value_list, basestring):
return super(ComplexBaseField, self).__get__(instance, owner)
is_list = False
if not hasattr(value_list, 'items'):
is_list = True
value_list = dict([(k,v) for k,v in enumerate(value_list)])
if isinstance(self.field, ReferenceField) and value_list:
db = _get_db()
dbref = {}
collections = {}
for k,v in value_list.items():
if isinstance(v, dict) and '_cls' in v and '_ref' not in v:
value_list[k] = get_document(v['_cls'].split('.')[-1])._from_son(v)
for k, v in value_list.items():
dbref[k] = v
# Save any DBRefs
# Handle all dereferencing
db = _get_db()
dbref = {}
collections = {}
for k, v in value_list.items():
dbref[k] = v
# Save any DBRefs
if isinstance(v, (pymongo.dbref.DBRef)):
# direct reference (DBRef)
collections.setdefault(v.collection, []).append((k, v))
elif isinstance(v, (dict, pymongo.son.SON)) and '_ref' in v:
# generic reference
collection = get_document(v['_cls'])._meta['collection']
collections.setdefault(collection, []).append((k, v))
# For each collection get the references
for collection, dbrefs in collections.items():
id_map = {}
for k, v in dbrefs:
if isinstance(v, (pymongo.dbref.DBRef)):
collections.setdefault(v.collection, []).append((k, v))
# direct reference (DBRef), has no _cls information
id_map[v.id] = (k, None)
elif isinstance(v, (dict, pymongo.son.SON)) and '_ref' in v:
# generic reference - includes _cls information
id_map[v['_ref'].id] = (k, get_document(v['_cls']))
# For each collection get the references
for collection, dbrefs in collections.items():
id_map = dict([(v.id, k) for k, v in dbrefs])
references = db[collection].find({'_id': {'$in': id_map.keys()}})
for ref in references:
key = id_map[ref['_id']]
dbref[key] = get_document(ref['_cls'])._from_son(ref)
references = db[collection].find({'_id': {'$in': id_map.keys()}})
for ref in references:
key, doc_cls = id_map[ref['_id']]
if not doc_cls: # If no doc_cls get it from the referenced doc
doc_cls = get_document(ref['_cls'])
dbref[key] = doc_cls._from_son(ref)
if is_list:
dbref = [v for k,v in sorted(dbref.items(), key=itemgetter(0))]
instance._data[self.name] = dbref
if is_list:
dbref = [v for k,v in sorted(dbref.items(), key=itemgetter(0))]
instance._data[self.name] = dbref
return super(ComplexBaseField, self).__get__(instance, owner)
# Get value from document instance if available
if isinstance(self.field, GenericReferenceField) and value_list:
db = _get_db()
value_list = [(k,v) for k,v in value_list.items()]
dbref = {}
classes = {}
def to_python(self, value):
"""Convert a MongoDB-compatible type to a Python type.
"""
from mongoengine import Document
for k, v in value_list:
dbref[k] = v
# Save any DBRefs
if isinstance(v, (dict, pymongo.son.SON)):
classes.setdefault(v['_cls'], []).append((k, v))
if isinstance(value, basestring):
return value
# For each collection get the references
for doc_cls, dbrefs in classes.items():
id_map = dict([(v['_ref'].id, k) for k, v in dbrefs])
doc_cls = get_document(doc_cls)
collection = doc_cls._meta['collection']
references = db[collection].find({'_id': {'$in': id_map.keys()}})
if hasattr(value, 'to_python'):
return value.to_python()
for ref in references:
key = id_map[ref['_id']]
dbref[key] = doc_cls._from_son(ref)
is_list = False
if not hasattr(value, 'items'):
try:
is_list = True
value = dict([(k,v) for k,v in enumerate(value)])
except TypeError: # Not iterable return the value
return value
if is_list:
dbref = [v for k,v in sorted(dbref.items(), key=itemgetter(0))]
if self.field:
value_dict = dict([(key, self.field.to_python(item)) for key, item in value.items()])
else:
value_dict = {}
for k,v in value.items():
if isinstance(v, Document):
# We need the id from the saved object to create the DBRef
if v.pk is None:
raise ValidationError('You can only reference documents once '
'they have been saved to the database')
collection = v._meta['collection']
value_dict[k] = pymongo.dbref.DBRef(collection, v.pk)
elif hasattr(v, 'to_python'):
value_dict[k] = v.to_python()
else:
value_dict[k] = self.to_python(v)
instance._data[self.name] = dbref
if is_list: # Convert back to a list
return [v for k,v in sorted(value_dict.items(), key=itemgetter(0))]
return value_dict
return super(DereferenceBaseField, self).__get__(instance, owner)
def to_mongo(self, value):
"""Convert a Python type to a MongoDB-compatible type.
"""
from mongoengine import Document
if isinstance(value, basestring):
return value
if hasattr(value, 'to_mongo'):
return value.to_mongo()
is_list = False
if not hasattr(value, 'items'):
try:
is_list = True
value = dict([(k,v) for k,v in enumerate(value)])
except TypeError: # Not iterable return the value
return value
if self.field:
value_dict = dict([(key, self.field.to_mongo(item)) for key, item in value.items()])
else:
value_dict = {}
for k,v in value.items():
if isinstance(v, Document):
# We need the id from the saved object to create the DBRef
if v.pk is None:
raise ValidationError('You can only reference documents once '
'they have been saved to the database')
# If its a document that is not inheritable it won't have
# _types / _cls data so make it a generic reference allows
# us to dereference
meta = getattr(v, 'meta', getattr(v, '_meta', {}))
if meta and not meta['allow_inheritance'] and not self.field:
from fields import GenericReferenceField
value_dict[k] = GenericReferenceField().to_mongo(v)
else:
collection = v._meta['collection']
value_dict[k] = pymongo.dbref.DBRef(collection, v.pk)
elif hasattr(v, 'to_mongo'):
value_dict[k] = v.to_mongo()
else:
value_dict[k] = self.to_mongo(v)
if is_list: # Convert back to a list
return [v for k,v in sorted(value_dict.items(), key=itemgetter(0))]
return value_dict
def validate(self, value):
"""If field provided ensure the value is valid.
"""
if self.field:
try:
if hasattr(value, 'iteritems'):
[self.field.validate(v) for k,v in value.iteritems()]
else:
[self.field.validate(v) for v in value]
except Exception, err:
raise ValidationError('Invalid %s item (%s)' % (
self.field.__class__.__name__, str(v)))
def prepare_query_value(self, op, value):
return self.to_mongo(value)
def lookup_member(self, member_name):
if self.field:
return self.field.lookup_member(member_name)
return None
def _set_owner_document(self, owner_document):
if self.field:
self.field.owner_document = owner_document
self._owner_document = owner_document
def _get_owner_document(self, owner_document):
self._owner_document = owner_document
owner_document = property(_get_owner_document, _set_owner_document)
class ObjectIdField(BaseField):
@@ -219,7 +330,6 @@ class ObjectIdField(BaseField):
def to_python(self, value):
return value
# return unicode(value)
def to_mongo(self, value):
if not isinstance(value, pymongo.objectid.ObjectId):

View File

@@ -1,4 +1,4 @@
from base import (BaseField, DereferenceBaseField, ObjectIdField,
from base import (BaseField, ComplexBaseField, ObjectIdField,
ValidationError, get_document)
from queryset import DO_NOTHING
from document import Document, EmbeddedDocument
@@ -301,6 +301,8 @@ class EmbeddedDocumentField(BaseField):
return value
def to_mongo(self, value):
if isinstance(value, basestring):
return value
return self.document_type.to_mongo(value)
def validate(self, value):
@@ -320,7 +322,7 @@ class EmbeddedDocumentField(BaseField):
return self.to_mongo(value)
class ListField(DereferenceBaseField):
class ListField(ComplexBaseField):
"""A list field that wraps a standard field, allowing multiple instances
of the field to be used as a list in the database.
"""
@@ -328,48 +330,25 @@ class ListField(DereferenceBaseField):
# ListFields cannot be indexed with _types - MongoDB doesn't support this
_index_with_types = False
def __init__(self, field, **kwargs):
if not isinstance(field, BaseField):
raise ValidationError('Argument to ListField constructor must be '
'a valid field')
def __init__(self, field=None, **kwargs):
self.field = field
kwargs.setdefault('default', lambda: [])
super(ListField, self).__init__(**kwargs)
def to_python(self, value):
return [self.field.to_python(item) for item in value]
def to_mongo(self, value):
return [self.field.to_mongo(item) for item in value]
def validate(self, value):
"""Make sure that a list of valid fields is being used.
"""
if not isinstance(value, (list, tuple)):
raise ValidationError('Only lists and tuples may be used in a '
'list field')
try:
[self.field.validate(item) for item in value]
except Exception, err:
raise ValidationError('Invalid ListField item (%s)' % str(item))
super(ListField, self).validate(value)
def prepare_query_value(self, op, value):
if op in ('set', 'unset'):
return [self.field.prepare_query_value(op, v) for v in value]
return self.field.prepare_query_value(op, value)
def lookup_member(self, member_name):
return self.field.lookup_member(member_name)
def _set_owner_document(self, owner_document):
self.field.owner_document = owner_document
self._owner_document = owner_document
def _get_owner_document(self, owner_document):
self._owner_document = owner_document
owner_document = property(_get_owner_document, _set_owner_document)
if self.field:
if op in ('set', 'unset') and not isinstance(value, basestring):
return [self.field.prepare_query_value(op, v) for v in value]
return self.field.prepare_query_value(op, value)
return super(ListField, self).prepare_query_value(op, value)
class SortedListField(ListField):
@@ -388,20 +367,21 @@ class SortedListField(ListField):
super(SortedListField, self).__init__(field, **kwargs)
def to_mongo(self, value):
value = super(SortedListField, self).to_mongo(value)
if self._ordering is not None:
return sorted([self.field.to_mongo(item) for item in value],
key=itemgetter(self._ordering))
return sorted([self.field.to_mongo(item) for item in value])
return sorted(value, key=itemgetter(self._ordering))
return sorted(value)
class DictField(BaseField):
class DictField(ComplexBaseField):
"""A dictionary field that wraps a standard Python dictionary. This is
similar to an embedded document, but the structure is not defined.
.. versionadded:: 0.3
"""
def __init__(self, basecls=None, *args, **kwargs):
def __init__(self, basecls=None, field=None, *args, **kwargs):
self.field = field
self.basecls = basecls or BaseField
assert issubclass(self.basecls, BaseField)
kwargs.setdefault('default', lambda: {})
@@ -417,6 +397,7 @@ class DictField(BaseField):
if any(('.' in k or '$' in k) for k in value):
raise ValidationError('Invalid dictionary key name - keys may not '
'contain "." or "$" characters')
super(DictField, self).validate(value)
def lookup_member(self, member_name):
return DictField(basecls=self.basecls, db_field=member_name)
@@ -432,7 +413,7 @@ class DictField(BaseField):
return super(DictField, self).prepare_query_value(op, value)
class MapField(DereferenceBaseField):
class MapField(DictField):
"""A field that maps a name to a specified field type. Similar to
a DictField, except the 'value' of each item must match the specified
field type.
@@ -444,50 +425,7 @@ class MapField(DereferenceBaseField):
if not isinstance(field, BaseField):
raise ValidationError('Argument to MapField constructor must be '
'a valid field')
self.field = field
kwargs.setdefault('default', lambda: {})
super(MapField, self).__init__(*args, **kwargs)
def validate(self, value):
"""Make sure that a list of valid fields is being used.
"""
if not isinstance(value, dict):
raise ValidationError('Only dictionaries may be used in a '
'DictField')
if any(('.' in k or '$' in k) for k in value):
raise ValidationError('Invalid dictionary key name - keys may not '
'contain "." or "$" characters')
try:
[self.field.validate(item) for item in value.values()]
except Exception, err:
raise ValidationError('Invalid MapField item (%s)' % str(item))
def to_python(self, value):
return dict([(key, self.field.to_python(item)) for key, item in value.iteritems()])
def to_mongo(self, value):
return dict([(key, self.field.to_mongo(item)) for key, item in value.iteritems()])
def prepare_query_value(self, op, value):
if op not in ('set', 'unset'):
return self.field.prepare_query_value(op, value)
for key in value:
value[key] = self.field.prepare_query_value(op, value[key])
return value
def lookup_member(self, member_name):
return self.field.lookup_member(member_name)
def _set_owner_document(self, owner_document):
self.field.owner_document = owner_document
self._owner_document = owner_document
def _get_owner_document(self, owner_document):
self._owner_document = owner_document
owner_document = property(_get_owner_document, _set_owner_document)
super(MapField, self).__init__(field=field, *args, **kwargs)
class ReferenceField(BaseField):

View File

@@ -549,11 +549,12 @@ class QuerySet(object):
parts = [parts]
fields = []
field = None
for field_name in parts:
# Handle ListField indexing:
if field_name.isdigit():
try:
field = field.field
new_field = field.field
except AttributeError, err:
raise InvalidQueryError(
"Can't use index on unsubscriptable field (%s)" % err)
@@ -567,11 +568,17 @@ class QuerySet(object):
field = document._fields[field_name]
else:
# Look up subfield on the previous field
field = field.lookup_member(field_name)
if field is None:
new_field = field.lookup_member(field_name)
from base import ComplexBaseField
if not new_field and isinstance(field, ComplexBaseField):
fields.append(field_name)
continue
elif not new_field:
raise InvalidQueryError('Cannot resolve field "%s"'
% field_name)
% field_name)
field = new_field # update field to the new field type
fields.append(field)
return fields
@classmethod
@@ -615,14 +622,33 @@ class QuerySet(object):
if _doc_cls:
# Switch field names to proper names [set in Field(name='foo')]
fields = QuerySet._lookup_field(_doc_cls, parts)
parts = [field.db_field for field in fields]
parts = []
cleaned_fields = []
append_field = True
for field in fields:
if isinstance(field, str):
parts.append(field)
append_field = False
else:
parts.append(field.db_field)
if append_field:
cleaned_fields.append(field)
# Convert value to proper value
field = fields[-1]
field = cleaned_fields[-1]
singular_ops = [None, 'ne', 'gt', 'gte', 'lt', 'lte', 'not']
singular_ops += match_operators
if op in singular_ops:
value = field.prepare_query_value(op, value)
if isinstance(field, basestring):
if op in match_operators and isinstance(value, basestring):
from mongoengine import StringField
value = StringField().prepare_query_value(op, value)
else:
value = field
else:
value = field.prepare_query_value(op, value)
elif op in ('in', 'nin', 'all', 'near'):
# 'in', 'nin' and 'all' require a list of values
value = [field.prepare_query_value(op, v) for v in value]
@@ -1170,14 +1196,19 @@ class QuerySet(object):
fields = QuerySet._lookup_field(_doc_cls, parts)
parts = []
cleaned_fields = []
append_field = True
for field in fields:
if isinstance(field, str):
parts.append(field)
append_field = False
else:
parts.append(field.db_field)
if append_field:
cleaned_fields.append(field)
# Convert value to proper value
field = fields[-1]
field = cleaned_fields[-1]
if op in (None, 'set', 'push', 'pull', 'addToSet'):
value = field.prepare_query_value(op, value)