Added DereferenceBaseField class

Handles the lazy dereferencing of all items in a list / dict.
Improves query efficiency by an order of magnitude.
This commit is contained in:
Ross Lawley 2011-06-06 11:04:06 +01:00
parent 5d778648e6
commit ec7effa0ef
3 changed files with 104 additions and 207 deletions

View File

@ -5,6 +5,7 @@ from queryset import DO_NOTHING
import sys
import pymongo
import pymongo.objectid
from operator import itemgetter
class NotRegistered(Exception):
@ -127,6 +128,87 @@ class BaseField(object):
self.validate(value)
class DereferenceBaseField(BaseField):
    """Handles the lazy dereferencing of a queryset. Will dereference all
    items in a list / dict rather than one at a time.

    References are grouped per collection and fetched with a single
    ``$in`` query per collection instead of one query per item.
    """

    def __get__(self, instance, owner):
        """Descriptor to automatically dereference references.

        When accessed on the class (``instance is None``) the field itself
        is returned.  Otherwise, if the wrapped ``self.field`` is a
        ``ReferenceField`` or ``GenericReferenceField``, every stored
        reference is replaced by its document, the result is written back
        to ``instance._data[self.name]``, and the lookup is delegated to
        ``BaseField.__get__``.

        Items whose referenced document is not returned by the query are
        left untouched.  A document referenced more than once is placed at
        every position that refers to it (the same document object is
        shared between those positions).
        """
        # Imported lazily: the fields module itself imports from this one,
        # so a module-level import would be circular.
        from fields import ReferenceField, GenericReferenceField
        from connection import _get_db

        if instance is None:
            # Document class being used rather than a document object
            return self

        # Get value from document instance if available
        value_list = instance._data.get(self.name)
        if not value_list:
            return super(DereferenceBaseField, self).__get__(instance, owner)

        # Anything without ``items`` is treated as a list and handled as an
        # {index: item} mapping so both shapes share the same code path.
        is_list = not hasattr(value_list, 'items')
        if is_list:
            value_list = dict(enumerate(value_list))

        if isinstance(self.field, ReferenceField) and value_list:
            db = _get_db()
            dbref = {}
            collections = {}
            for k, v in value_list.items():
                dbref[k] = v
                # Save any DBRefs
                if isinstance(v, pymongo.dbref.DBRef):
                    collections.setdefault(v.collection, []).append((k, v))

            # For each collection get the references
            for collection, dbrefs in collections.items():
                # One _id may be referenced from several keys, so keep
                # every key rather than only the last one seen.
                id_map = {}
                for k, v in dbrefs:
                    id_map.setdefault(v.id, []).append(k)
                references = db[collection].find(
                    {'_id': {'$in': list(id_map)}})
                for ref in references:
                    doc = get_document(ref['_cls'])._from_son(ref)
                    for key in id_map[ref['_id']]:
                        dbref[key] = doc

            if is_list:
                # Restore the original list order from the integer keys.
                dbref = [v for k, v in sorted(dbref.items(),
                                              key=itemgetter(0))]
            instance._data[self.name] = dbref

        if isinstance(self.field, GenericReferenceField) and value_list:
            db = _get_db()
            dbref = {}
            classes = {}
            for k, v in value_list.items():
                dbref[k] = v
                # Save any generic references (stored as a dict / SON
                # carrying '_cls' and '_ref')
                if isinstance(v, (dict, pymongo.son.SON)):
                    classes.setdefault(v['_cls'], []).append((k, v))

            # For each document class get the references
            for doc_cls, dbrefs in classes.items():
                # Same grouping as above: keep all keys per referenced _id.
                id_map = {}
                for k, v in dbrefs:
                    id_map.setdefault(v['_ref'].id, []).append(k)
                doc_cls = get_document(doc_cls)
                collection = doc_cls._meta['collection']
                references = db[collection].find(
                    {'_id': {'$in': list(id_map)}})
                for ref in references:
                    doc = doc_cls._from_son(ref)
                    for key in id_map[ref['_id']]:
                        dbref[key] = doc

            if is_list:
                # Restore the original list order from the integer keys.
                dbref = [v for k, v in sorted(dbref.items(),
                                              key=itemgetter(0))]
            instance._data[self.name] = dbref

        return super(DereferenceBaseField, self).__get__(instance, owner)
class ObjectIdField(BaseField):
"""An field wrapper around MongoDB's ObjectIds.
"""

View File

@ -1,4 +1,5 @@
from base import BaseField, ObjectIdField, ValidationError, get_document
from base import (BaseField, DereferenceBaseField, ObjectIdField,
ValidationError, get_document)
from queryset import DO_NOTHING
from document import Document, EmbeddedDocument
from connection import _get_db
@ -12,7 +13,6 @@ import pymongo.binary
import datetime, time
import decimal
import gridfs
import warnings
__all__ = ['StringField', 'IntField', 'FloatField', 'BooleanField',
@ -153,6 +153,7 @@ class IntField(BaseField):
def prepare_query_value(self, op, value):
    """Coerce a query value with ``int()`` and return the result.

    ``op`` is accepted but not used.  Whatever ``int(value)`` raises for
    an unconvertible value propagates to the caller.
    """
    as_int = int(value)
    return as_int
class FloatField(BaseField):
"""An floating point number field.
"""
@ -178,6 +179,7 @@ class FloatField(BaseField):
def prepare_query_value(self, op, value):
    """Coerce a query value with ``float()`` and return the result.

    ``op`` is accepted but not used.  Whatever ``float(value)`` raises for
    an unconvertible value propagates to the caller.
    """
    as_float = float(value)
    return as_float
class DecimalField(BaseField):
"""A fixed-point decimal number field.
@ -255,7 +257,6 @@ class DateTimeField(BaseField):
try: # Seconds are optional, so try converting seconds first.
return datetime.datetime(*time.strptime(value, '%Y-%m-%d %H:%M:%S')[:6],
**kwargs)
except ValueError:
try: # Try without seconds.
return datetime.datetime(*time.strptime(value, '%Y-%m-%d %H:%M')[:5],
@ -267,6 +268,7 @@ class DateTimeField(BaseField):
except ValueError:
return None
class EmbeddedDocumentField(BaseField):
"""An embedded document field. Only valid values are subclasses of
:class:`~mongoengine.EmbeddedDocument`.
@ -314,7 +316,7 @@ class EmbeddedDocumentField(BaseField):
return self.to_mongo(value)
class ListField(BaseField):
class ListField(DereferenceBaseField):
"""A list field that wraps a standard field, allowing multiple instances
of the field to be used as a list in the database.
"""
@ -330,63 +332,6 @@ class ListField(BaseField):
kwargs.setdefault('default', lambda: [])
super(ListField, self).__init__(**kwargs)
def __get__(self, instance, owner):
"""Descriptor to automatically dereference references.
"""
if instance is None:
# Document class being used rather than a document object
return self
# Get value from document instance if available
value_list = instance._data.get(self.name)
if isinstance(self.field, ReferenceField) and value_list:
db = _get_db()
value_list = [(k,v) for k,v in enumerate(value_list)]
deref_list = []
collections = {}
for k, v in value_list:
deref_list.append(v)
# Save any DBRefs
if isinstance(v, (pymongo.dbref.DBRef)):
collections.setdefault(v.collection, []).append((k, v))
# For each collection get the references
for collection, dbrefs in collections.items():
id_map = dict([(v.id, k) for k, v in dbrefs])
references = db[collection].find({'_id': {'$in': id_map.keys()}})
for ref in references:
key = id_map[ref['_id']]
deref_list[key] = get_document(ref['_cls'])._from_son(ref)
instance._data[self.name] = deref_list
# Get value from document instance if available
if isinstance(self.field, GenericReferenceField) and value_list:
db = _get_db()
value_list = [(k,v) for k,v in enumerate(value_list)]
deref_list = []
classes = {}
for k, v in value_list:
deref_list.append(v)
# Save any DBRefs
if isinstance(v, (dict, pymongo.son.SON)):
classes.setdefault(v['_cls'], []).append((k, v))
# For each collection get the references
for doc_cls, dbrefs in classes.items():
id_map = dict([(v['_ref'].id, k) for k, v in dbrefs])
doc_cls = get_document(doc_cls)
collection = doc_cls._meta['collection']
references = db[collection].find({'_id': {'$in': id_map.keys()}})
for ref in references:
key = id_map[ref['_id']]
deref_list[key] = doc_cls._from_son(ref)
instance._data[self.name] = deref_list
return super(ListField, self).__get__(instance, owner)
def to_python(self, value):
    """Return a new list with ``self.field.to_python`` applied to each
    item of ``value``, in iteration order.
    """
    convert = self.field.to_python
    converted = []
    for element in value:
        converted.append(convert(element))
    return converted
@ -480,10 +425,10 @@ class DictField(BaseField):
if op in match_operators and isinstance(value, basestring):
return StringField().prepare_query_value(op, value)
return super(DictField,self).prepare_query_value(op, value)
return super(DictField, self).prepare_query_value(op, value)
class MapField(BaseField):
class MapField(DereferenceBaseField):
"""A field that maps a name to a specified field type. Similar to
a DictField, except the 'value' of each item must match the specified
field type.
@ -515,68 +460,11 @@ class MapField(BaseField):
except Exception, err:
raise ValidationError('Invalid MapField item (%s)' % str(item))
def __get__(self, instance, owner):
"""Descriptor to automatically dereference references.
"""
if instance is None:
# Document class being used rather than a document object
return self
# Get value from document instance if available
value_list = instance._data.get(self.name)
if isinstance(self.field, ReferenceField) and value_list:
db = _get_db()
deref_dict = {}
collections = {}
for k, v in value_list.items():
deref_dict[k] = v
# Save any DBRefs
if isinstance(v, (pymongo.dbref.DBRef)):
collections.setdefault(v.collection, []).append((k, v))
# For each collection get the references
for collection, dbrefs in collections.items():
id_map = dict([(v.id, k) for k, v in dbrefs])
references = db[collection].find({'_id': {'$in': id_map.keys()}})
for ref in references:
key = id_map[ref['_id']]
deref_dict[key] = get_document(ref['_cls'])._from_son(ref)
instance._data[self.name] = deref_dict
# Get value from document instance if available
if isinstance(self.field, GenericReferenceField) and value_list:
db = _get_db()
value_list = [(k,v) for k,v in value_list.items()]
deref_dict = {}
classes = {}
for k, v in value_list:
deref_dict[k] = v
# Save any DBRefs
if isinstance(v, (dict, pymongo.son.SON)):
classes.setdefault(v['_cls'], []).append((k, v))
# For each collection get the references
for doc_cls, dbrefs in classes.items():
id_map = dict([(v['_ref'].id, k) for k, v in dbrefs])
doc_cls = get_document(doc_cls)
collection = doc_cls._meta['collection']
references = db[collection].find({'_id': {'$in': id_map.keys()}})
for ref in references:
key = id_map[ref['_id']]
deref_dict[key] = doc_cls._from_son(ref)
instance._data[self.name] = deref_dict
return super(MapField, self).__get__(instance, owner)
def to_python(self, value):
return dict( [(key,self.field.to_python(item)) for key,item in value.iteritems()] )
return dict([(key, self.field.to_python(item)) for key, item in value.iteritems()])
def to_mongo(self, value):
return dict( [(key,self.field.to_mongo(item)) for key,item in value.iteritems()] )
return dict([(key, self.field.to_mongo(item)) for key, item in value.iteritems()])
def prepare_query_value(self, op, value):
if op not in ('set', 'unset'):
@ -794,11 +682,11 @@ class GridFSProxy(object):
self.newfile = self.fs.new_file(**kwargs)
self.grid_id = self.newfile._id
def put(self, file, **kwargs):
def put(self, file_obj, **kwargs):
if self.grid_id:
raise GridFSError('This document already has a file. Either delete '
'it or call replace to overwrite it')
self.grid_id = self.fs.put(file, **kwargs)
self.grid_id = self.fs.put(file_obj, **kwargs)
def write(self, string):
if self.grid_id:
@ -827,9 +715,9 @@ class GridFSProxy(object):
self.grid_id = None
self.gridout = None
def replace(self, file, **kwargs):
def replace(self, file_obj, **kwargs):
self.delete()
self.put(file, **kwargs)
self.put(file_obj, **kwargs)
def close(self):
if self.newfile:
@ -911,76 +799,3 @@ class GeoPointField(BaseField):
if (not isinstance(value[0], (float, int)) and
not isinstance(value[1], (float, int))):
raise ValidationError('Both values in point must be float or int.')
class DereferenceMixin(object):
""" WORK IN PROGRESS"""
def __get__(self, instance, owner):
"""Descriptor to automatically dereference references.
"""
if instance is None:
# Document class being used rather than a document object
return self
# Get value from document instance if available
value_list = instance._data.get(self.name)
if not value_list:
return super(MapField, self).__get__(instance, owner)
is_dict = True
if not hasattr(value_list, 'items'):
is_dict = False
value_list = dict([(k,v) for k,v in enumerate(value_list)])
if isinstance(self.field, ReferenceField) and value_list:
db = _get_db()
dbref = {}
if not is_dict:
dbref = []
collections = {}
for k, v in value_list.items():
dbref[k] = v
# Save any DBRefs
if isinstance(v, (pymongo.dbref.DBRef)):
collections.setdefault(v.collection, []).append((k, v))
# For each collection get the references
for collection, dbrefs in collections.items():
id_map = dict([(v.id, k) for k, v in dbrefs])
references = db[collection].find({'_id': {'$in': id_map.keys()}})
for ref in references:
key = id_map[ref['_id']]
dbref[key] = get_document(ref['_cls'])._from_son(ref)
instance._data[self.name] = dbref
# Get value from document instance if available
if isinstance(self.field, GenericReferenceField) and value_list:
db = _get_db()
value_list = [(k,v) for k,v in value_list.items()]
dbref = {}
classes = {}
for k, v in value_list:
dbref[k] = v
# Save any DBRefs
if isinstance(v, (dict, pymongo.son.SON)):
classes.setdefault(v['_cls'], []).append((k, v))
# For each collection get the references
for doc_cls, dbrefs in classes.items():
id_map = dict([(v['_ref'].id, k) for k, v in dbrefs])
doc_cls = get_document(doc_cls)
collection = doc_cls._meta['collection']
references = db[collection].find({'_id': {'$in': id_map.keys()}})
for ref in references:
key = id_map[ref['_id']]
dbref[key] = doc_cls._from_son(ref)
instance._data[self.name] = dbref
return super(DereferenceField, self).__get__(instance, owner)

View File

@ -11,7 +11,7 @@ class FieldTest(unittest.TestCase):
connect(db='mongoenginetest')
self.db = _get_db()
def ztest_list_item_dereference(self):
def test_list_item_dereference(self):
"""Ensure that DBRef items in ListFields are dereferenced.
"""
class User(Document):
@ -42,7 +42,7 @@ class FieldTest(unittest.TestCase):
User.drop_collection()
Group.drop_collection()
def ztest_recursive_reference(self):
def test_recursive_reference(self):
"""Ensure that ReferenceFields can reference their own documents.
"""
class Employee(Document):
@ -75,7 +75,7 @@ class FieldTest(unittest.TestCase):
peter.friends
self.assertEqual(q, 3)
def ztest_generic_reference(self):
def test_generic_reference(self):
class UserA(Document):
name = StringField()