diff --git a/mongoengine/base.py b/mongoengine/base.py index 4e3154fd..ce61547e 100644 --- a/mongoengine/base.py +++ b/mongoengine/base.py @@ -5,6 +5,7 @@ from queryset import DO_NOTHING import sys import pymongo import pymongo.objectid +from operator import itemgetter class NotRegistered(Exception): @@ -127,6 +128,87 @@ class BaseField(object): self.validate(value) +class DereferenceBaseField(BaseField): + """Handles the lazy dereferencing of a queryset. Will dereference all + items in a list / dict rather than one at a time. + """ + + def __get__(self, instance, owner): + """Descriptor to automatically dereference references. + """ + from fields import ReferenceField, GenericReferenceField + from connection import _get_db + + if instance is None: + # Document class being used rather than a document object + return self + + # Get value from document instance if available + value_list = instance._data.get(self.name) + if not value_list: + return super(DereferenceBaseField, self).__get__(instance, owner) + + is_list = False + if not hasattr(value_list, 'items'): + is_list = True + value_list = dict([(k,v) for k,v in enumerate(value_list)]) + + if isinstance(self.field, ReferenceField) and value_list: + db = _get_db() + dbref = {} + collections = {} + + for k, v in value_list.items(): + dbref[k] = v + # Save any DBRefs + if isinstance(v, (pymongo.dbref.DBRef)): + collections.setdefault(v.collection, []).append((k, v)) + + # For each collection get the references + for collection, dbrefs in collections.items(): + id_map = dict([(v.id, k) for k, v in dbrefs]) + references = db[collection].find({'_id': {'$in': id_map.keys()}}) + for ref in references: + key = id_map[ref['_id']] + dbref[key] = get_document(ref['_cls'])._from_son(ref) + + if is_list: + dbref = [v for k,v in sorted(dbref.items(), key=itemgetter(0))] + instance._data[self.name] = dbref + + # Get value from document instance if available + if isinstance(self.field, GenericReferenceField) and value_list: + db = _get_db() + value_list = [(k,v) for k,v in value_list.items()] + dbref = {} + classes = {} + + for k, v in value_list: + dbref[k] = v + # Save any DBRefs + if isinstance(v, (dict, pymongo.son.SON)): + classes.setdefault(v['_cls'], []).append((k, v)) + + # For each collection get the references + for doc_cls, dbrefs in classes.items(): + id_map = dict([(v['_ref'].id, k) for k, v in dbrefs]) + doc_cls = get_document(doc_cls) + collection = doc_cls._meta['collection'] + references = db[collection].find({'_id': {'$in': id_map.keys()}}) + + for ref in references: + key = id_map[ref['_id']] + dbref[key] = doc_cls._from_son(ref) + + if is_list: + dbref = [v for k,v in sorted(dbref.items(), key=itemgetter(0))] + + instance._data[self.name] = dbref + + return super(DereferenceBaseField, self).__get__(instance, owner) + + + class ObjectIdField(BaseField): """An field wrapper around MongoDB's ObjectIds. """ diff --git a/mongoengine/fields.py b/mongoengine/fields.py index c21829c9..dc03fc05 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -1,4 +1,5 @@ -from base import BaseField, ObjectIdField, ValidationError, get_document +from base import (BaseField, DereferenceBaseField, ObjectIdField, + ValidationError, get_document) from queryset import DO_NOTHING from document import Document, EmbeddedDocument from connection import _get_db @@ -12,7 +13,6 @@ import pymongo.binary import datetime, time import decimal import gridfs -import warnings __all__ = ['StringField', 'IntField', 'FloatField', 'BooleanField', @@ -118,8 +118,8 @@ class EmailField(StringField): EMAIL_REGEX = re.compile( r"(^[-!#$%&'*+/=?^_`{}|~0-9A-Z]+(\.[-!#$%&'*+/=?^_`{}|~0-9A-Z]+)*" # dot-atom - r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]|\\[\001-011\013\014\016-\177])*"' # quoted-string - r')@(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,6}\.?$', re.IGNORECASE # domain + r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]|\\[\001-011\013\014\016-\177])*"' # quoted-string + r')@(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,6}\.?$', re.IGNORECASE # domain ) def validate(self, value): @@ -153,6 +153,7 @@ class IntField(BaseField): def prepare_query_value(self, op, value): return int(value) + class FloatField(BaseField): """An floating point number field. """ @@ -178,6 +179,7 @@ class FloatField(BaseField): def prepare_query_value(self, op, value): return float(value) + class DecimalField(BaseField): """A fixed-point decimal number field. @@ -252,21 +254,21 @@ class DateTimeField(BaseField): else: usecs = 0 kwargs = {'microsecond': usecs} - try: # Seconds are optional, so try converting seconds first. + try: # Seconds are optional, so try converting seconds first. return datetime.datetime(*time.strptime(value, '%Y-%m-%d %H:%M:%S')[:6], **kwargs) - except ValueError: - try: # Try without seconds. + try: # Try without seconds. return datetime.datetime(*time.strptime(value, '%Y-%m-%d %H:%M')[:5], **kwargs) - except ValueError: # Try without hour/minutes/seconds. + except ValueError: # Try without hour/minutes/seconds. try: return datetime.datetime(*time.strptime(value, '%Y-%m-%d')[:3], **kwargs) except ValueError: return None + class EmbeddedDocumentField(BaseField): """An embedded document field. Only valid values are subclasses of :class:`~mongoengine.EmbeddedDocument`. @@ -314,7 +316,7 @@ class EmbeddedDocumentField(BaseField): return self.to_mongo(value) -class ListField(BaseField): +class ListField(DereferenceBaseField): """A list field that wraps a standard field, allowing multiple instances of the field to be used as a list in the database. """ @@ -330,63 +332,6 @@ class ListField(BaseField): kwargs.setdefault('default', lambda: []) super(ListField, self).__init__(**kwargs) - def __get__(self, instance, owner): - """Descriptor to automatically dereference references. - """ - if instance is None: - # Document class being used rather than a document object - return self - - # Get value from document instance if available - value_list = instance._data.get(self.name) - if isinstance(self.field, ReferenceField) and value_list: - db = _get_db() - value_list = [(k,v) for k,v in enumerate(value_list)] - deref_list = [] - collections = {} - - for k, v in value_list: - deref_list.append(v) - # Save any DBRefs - if isinstance(v, (pymongo.dbref.DBRef)): - collections.setdefault(v.collection, []).append((k, v)) - - # For each collection get the references - for collection, dbrefs in collections.items(): - id_map = dict([(v.id, k) for k, v in dbrefs]) - references = db[collection].find({'_id': {'$in': id_map.keys()}}) - for ref in references: - key = id_map[ref['_id']] - deref_list[key] = get_document(ref['_cls'])._from_son(ref) - instance._data[self.name] = deref_list - - # Get value from document instance if available - if isinstance(self.field, GenericReferenceField) and value_list: - - db = _get_db() - value_list = [(k,v) for k,v in enumerate(value_list)] - deref_list = [] - classes = {} - - for k, v in value_list: - deref_list.append(v) - # Save any DBRefs - if isinstance(v, (dict, pymongo.son.SON)): - classes.setdefault(v['_cls'], []).append((k, v)) - - # For each collection get the references - for doc_cls, dbrefs in classes.items(): - id_map = dict([(v['_ref'].id, k) for k, v in dbrefs]) - doc_cls = get_document(doc_cls) - collection = doc_cls._meta['collection'] - references = db[collection].find({'_id': {'$in': id_map.keys()}}) - - for ref in references: - key = id_map[ref['_id']] - deref_list[key] = doc_cls._from_son(ref) - instance._data[self.name] = deref_list - return super(ListField, self).__get__(instance, owner) - def to_python(self, value): return [self.field.to_python(item) for item in value] @@ -480,10 +425,10 @@ class DictField(BaseField): if op in match_operators and isinstance(value, basestring): return StringField().prepare_query_value(op, value) - return super(DictField,self).prepare_query_value(op, value) + return super(DictField, self).prepare_query_value(op, value) -class MapField(BaseField): +class MapField(DereferenceBaseField): """A field that maps a name to a specified field type. Similar to a DictField, except the 'value' of each item must match the specified field type. @@ -515,68 +460,11 @@ class MapField(BaseField): except Exception, err: raise ValidationError('Invalid MapField item (%s)' % str(item)) - def __get__(self, instance, owner): - """Descriptor to automatically dereference references. - """ - if instance is None: - # Document class being used rather than a document object - return self - - # Get value from document instance if available - value_list = instance._data.get(self.name) - if isinstance(self.field, ReferenceField) and value_list: - db = _get_db() - deref_dict = {} - collections = {} - - for k, v in value_list.items(): - deref_dict[k] = v - # Save any DBRefs - if isinstance(v, (pymongo.dbref.DBRef)): - collections.setdefault(v.collection, []).append((k, v)) - - # For each collection get the references - for collection, dbrefs in collections.items(): - id_map = dict([(v.id, k) for k, v in dbrefs]) - references = db[collection].find({'_id': {'$in': id_map.keys()}}) - for ref in references: - key = id_map[ref['_id']] - deref_dict[key] = get_document(ref['_cls'])._from_son(ref) - instance._data[self.name] = deref_dict - - # Get value from document instance if available - if isinstance(self.field, GenericReferenceField) and value_list: - - db = _get_db() - value_list = [(k,v) for k,v in value_list.items()] - deref_dict = {} - classes = {} - - for k, v in value_list: - deref_dict[k] = v - # Save any DBRefs - if isinstance(v, (dict, pymongo.son.SON)): - classes.setdefault(v['_cls'], []).append((k, v)) - - # For each collection get the references - for doc_cls, dbrefs in classes.items(): - id_map = dict([(v['_ref'].id, k) for k, v in dbrefs]) - doc_cls = get_document(doc_cls) - collection = doc_cls._meta['collection'] - references = db[collection].find({'_id': {'$in': id_map.keys()}}) - - for ref in references: - key = id_map[ref['_id']] - deref_dict[key] = doc_cls._from_son(ref) - instance._data[self.name] = deref_dict - - return super(MapField, self).__get__(instance, owner) - def to_python(self, value): - return dict( [(key,self.field.to_python(item)) for key,item in value.iteritems()] ) + return dict([(key, self.field.to_python(item)) for key, item in value.iteritems()]) def to_mongo(self, value): - return dict( [(key,self.field.to_mongo(item)) for key,item in value.iteritems()] ) + return dict([(key, self.field.to_mongo(item)) for key, item in value.iteritems()]) def prepare_query_value(self, op, value): if op not in ('set', 'unset'): @@ -794,11 +682,11 @@ class GridFSProxy(object): self.newfile = self.fs.new_file(**kwargs) self.grid_id = self.newfile._id - def put(self, file, **kwargs): + def put(self, file_obj, **kwargs): if self.grid_id: raise GridFSError('This document already has a file. Either delete ' 'it or call replace to overwrite it') - self.grid_id = self.fs.put(file, **kwargs) + self.grid_id = self.fs.put(file_obj, **kwargs) def write(self, string): if self.grid_id: @@ -827,9 +715,9 @@ class GridFSProxy(object): self.grid_id = None self.gridout = None - def replace(self, file, **kwargs): + def replace(self, file_obj, **kwargs): self.delete() - self.put(file, **kwargs) + self.put(file_obj, **kwargs) def close(self): if self.newfile: @@ -911,76 +799,3 @@ class GeoPointField(BaseField): if (not isinstance(value[0], (float, int)) and not isinstance(value[1], (float, int))): raise ValidationError('Both values in point must be float or int.') - - - -class DereferenceMixin(object): - """ WORK IN PROGRESS""" - - def __get__(self, instance, owner): - """Descriptor to automatically dereference references. - """ - if instance is None: - # Document class being used rather than a document object - return self - - # Get value from document instance if available - value_list = instance._data.get(self.name) - if not value_list: - return super(MapField, self).__get__(instance, owner) - - is_dict = True - if not hasattr(value_list, 'items'): - is_dict = False - value_list = dict([(k,v) for k,v in enumerate(value_list)]) - - if isinstance(self.field, ReferenceField) and value_list: - db = _get_db() - dbref = {} - if not is_dict: - dbref = [] - collections = {} - - for k, v in value_list.items(): - dbref[k] = v - # Save any DBRefs - if isinstance(v, (pymongo.dbref.DBRef)): - collections.setdefault(v.collection, []).append((k, v)) - - # For each collection get the references - for collection, dbrefs in collections.items(): - id_map = dict([(v.id, k) for k, v in dbrefs]) - references = db[collection].find({'_id': {'$in': id_map.keys()}}) - for ref in references: - key = id_map[ref['_id']] - dbref[key] = get_document(ref['_cls'])._from_son(ref) - - instance._data[self.name] = dbref - - # Get value from document instance if available - if isinstance(self.field, GenericReferenceField) and value_list: - - db = _get_db() - value_list = [(k,v) for k,v in value_list.items()] - dbref = {} - classes = {} - - for k, v in value_list: - dbref[k] = v - # Save any DBRefs - if isinstance(v, (dict, pymongo.son.SON)): - classes.setdefault(v['_cls'], []).append((k, v)) - - # For each collection get the references - for doc_cls, dbrefs in classes.items(): - id_map = dict([(v['_ref'].id, k) for k, v in dbrefs]) - doc_cls = get_document(doc_cls) - collection = doc_cls._meta['collection'] - references = db[collection].find({'_id': {'$in': id_map.keys()}}) - - for ref in references: - key = id_map[ref['_id']] - dbref[key] = doc_cls._from_son(ref) - instance._data[self.name] = dbref - - return super(DereferenceField, self).__get__(instance, owner) \ No newline at end of file diff --git a/tests/dereference.py b/tests/dereference.py index 2764ee72..b6cee89e 100644 --- a/tests/dereference.py +++ b/tests/dereference.py @@ -11,7 +11,7 @@ class FieldTest(unittest.TestCase): connect(db='mongoenginetest') self.db = _get_db() - def ztest_list_item_dereference(self): + def test_list_item_dereference(self): """Ensure that DBRef items in ListFields are dereferenced. """ class User(Document): @@ -42,7 +42,7 @@ class FieldTest(unittest.TestCase): User.drop_collection() Group.drop_collection() - def ztest_recursive_reference(self): + def test_recursive_reference(self): """Ensure that ReferenceFields can reference their own documents. """ class Employee(Document): @@ -75,7 +75,7 @@ class FieldTest(unittest.TestCase): peter.friends self.assertEqual(q, 3) - def ztest_generic_reference(self): + def test_generic_reference(self): class UserA(Document): name = StringField()