diff --git a/mongoengine/base.py b/mongoengine/base.py index 12c760aa..8101aa00 100644 --- a/mongoengine/base.py +++ b/mongoengine/base.py @@ -8,7 +8,7 @@ import weakref import sys import pymongo import pymongo.objectid -from operator import itemgetter +import operator from functools import partial @@ -163,70 +163,14 @@ class ComplexBaseField(BaseField): def __get__(self, instance, owner): """Descriptor to automatically dereference references. """ - from connection import _get_db - if instance is None: # Document class being used rather than a document object return self - # Get value from document instance if available - value_list = instance._data.get(self.name) - if not value_list or isinstance(value_list, basestring): - return super(ComplexBaseField, self).__get__(instance, owner) - - is_list = False - if not hasattr(value_list, 'items'): - is_list = True - value_list = dict([(k,v) for k,v in enumerate(value_list)]) - - for k,v in value_list.items(): - if isinstance(v, dict) and '_cls' in v and '_ref' not in v: - value_list[k] = get_document(v['_cls'])._from_son(v) - - # Handle all dereferencing - db = _get_db() - dbref = {} - collections = {} - for k,v in value_list.items(): - - # Save any DBRefs - if isinstance(v, (pymongo.dbref.DBRef)): - # direct reference (DBRef) - collections.setdefault(v.collection, []).append((k,v)) - elif isinstance(v, (dict, pymongo.son.SON)): - if '_ref' in v: - # generic reference - collection = get_document(v['_cls'])._get_collection_name() - collections.setdefault(collection, []).append((k,v)) - else: - # Use BaseDict so can watch any changes - dbref[k] = BaseDict(v, instance=instance, name=self.name) - else: - dbref[k] = v - - # For each collection get the references - for collection, dbrefs in collections.items(): - id_map = {} - for k,v in dbrefs: - if isinstance(v, (pymongo.dbref.DBRef)): - # direct reference (DBRef), has no _cls information - id_map[v.id] = (k, None) - elif isinstance(v, (dict, pymongo.son.SON)) and '_ref' in v: - # generic reference - includes _cls information - id_map[v['_ref'].id] = (k, get_document(v['_cls'])) - - references = db[collection].find({'_id': {'$in': id_map.keys()}}) - for ref in references: - key, doc_cls = id_map[ref['_id']] - if not doc_cls: # If no doc_cls get it from the referenced doc - doc_cls = get_document(ref['_cls']) - dbref[key] = doc_cls._from_son(ref) - - if is_list: - dbref = BaseList([v for k,v in sorted(dbref.items(), key=itemgetter(0))], instance=instance, name=self.name) - else: - dbref = BaseDict(dbref, instance=instance, name=self.name) - instance._data[self.name] = dbref + from dereference import dereference + instance._data[self.name] = dereference( + instance._data.get(self.name), max_depth=1, instance=instance, name=self.name, get=True + ) return super(ComplexBaseField, self).__get__(instance, owner) def to_python(self, value): @@ -266,7 +210,7 @@ class ComplexBaseField(BaseField): value_dict[k] = self.to_python(v) if is_list: # Convert back to a list - return [v for k,v in sorted(value_dict.items(), key=itemgetter(0))] + return [v for k,v in sorted(value_dict.items(), key=operator.itemgetter(0))] return value_dict def to_mongo(self, value): @@ -315,7 +259,7 @@ class ComplexBaseField(BaseField): value_dict[k] = self.to_mongo(v) if is_list: # Convert back to a list - return [v for k,v in sorted(value_dict.items(), key=itemgetter(0))] + return [v for k,v in sorted(value_dict.items(), key=operator.itemgetter(0))] return value_dict def validate(self, value): @@ -907,7 +851,7 @@ class BaseList(list): """ def __init__(self, list_items, instance, name): - self.instance = weakref.proxy(instance) + self.instance = instance self.name = name super(BaseList, self).__init__(list_items) @@ -958,7 +902,7 @@ class BaseDict(dict): """ def __init__(self, dict_items, instance, name): - self.instance = weakref.proxy(instance) + self.instance = instance self.name = name super(BaseDict, self).__init__(dict_items) diff --git a/mongoengine/dereference.py b/mongoengine/dereference.py new file mode 100644 index 00000000..9192901c --- /dev/null +++ b/mongoengine/dereference.py @@ -0,0 +1,171 @@ +import operator + +import pymongo + +from base import BaseDict, BaseList, get_document +from connection import _get_db +from queryset import QuerySet + + +class DeReference(object): + + def __call__(self, items, max_depth=1, instance=None, name=None, get=False): + """ + Cheaply dereferences the items to a set depth. + Also handles the convertion of complex data types. + + :param items: The iterable (dict, list, queryset) to be dereferenced. + :param max_depth: The maximum depth to recurse to + :param instance: The owning instance used for tracking changes by + :class:`~mongoengine.base.ComplexBaseField` + :param name: The name of the field, used for tracking changes by + :class:`~mongoengine.base.ComplexBaseField` + :param get: A boolean determining if being called by __get__ + """ + if items is None or isinstance(items, basestring): + return items + + # cheapest way to convert a queryset to a list + # list(queryset) uses a count() query to determine length + if isinstance(items, QuerySet): + items = [i for i in items] + + self.max_depth = max_depth + self.reference_map = self._find_references(items) + self.object_map = self._fetch_objects() + return self._attach_objects(items, 0, instance, name, get) + + def _find_references(self, items, depth=0): + """ + Recursively finds all db references to be dereferenced + + :param items: The iterable (dict, list, queryset) + :param depth: The current depth of recursion + """ + reference_map = {} + if not items: + return reference_map + + # Determine the iterator to use + if not hasattr(items, 'items'): + iterator = enumerate(items) + else: + iterator = items.iteritems() + + # Recursively find dbreferences + for k, item in iterator: + if hasattr(item, '_fields'): + for field_name, field in item._fields.iteritems(): + v = item._data.get(field_name, None) + if isinstance(v, (pymongo.dbref.DBRef)): + reference_map.setdefault(field.document_type, []).append(v.id) + elif isinstance(v, (dict, pymongo.son.SON)) and '_ref' in v: + reference_map.setdefault(get_document(v['_cls']), []).append(v['_ref'].id) + elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth: + field_cls = getattr(getattr(field, 'field', None), 'document_type', None) + references = self._find_references(v, depth) + for key, refs in references.iteritems(): + if field_cls: + key = field_cls + reference_map.setdefault(key, []).extend(refs) + elif isinstance(item, (pymongo.dbref.DBRef)): + reference_map.setdefault(item.collection, []).append(item.id) + elif isinstance(item, (dict, pymongo.son.SON)) and '_ref' in item: + reference_map.setdefault(get_document(item['_cls']), []).append(item['_ref'].id) + elif isinstance(item, (dict, list, tuple)) and depth <= self.max_depth: + references = self._find_references(item, depth) + for key, refs in references.iteritems(): + reference_map.setdefault(key, []).extend(refs) + depth += 1 + return reference_map + + def _fetch_objects(self): + """Fetch all references and convert to their document objects + """ + object_map = {} + for col, dbrefs in self.reference_map.iteritems(): + keys = object_map.keys() + refs = list(set([dbref for dbref in dbrefs if str(dbref) not in keys])) + if hasattr(col, 'objects'): # We have a document class for the refs + references = col.objects.in_bulk(refs) + for key, doc in references.iteritems(): + object_map[key] = doc + else: # Generic reference: use the refs data to convert to document + references = _get_db()[col].find({'_id': {'$in': refs}}) + for ref in references: + doc = get_document(ref['_cls'])._from_son(ref) + object_map[doc.id] = doc + return object_map + + def _attach_objects(self, items, depth=0, instance=None, name=None, get=False): + """ + Recursively finds all db references to be dereferenced + + :param items: The iterable (dict, list, queryset) + :param depth: The current depth of recursion + :param instance: The owning instance used for tracking changes by + :class:`~mongoengine.base.ComplexBaseField` + :param name: The name of the field, used for tracking changes by + :class:`~mongoengine.base.ComplexBaseField` + :param get: A boolean determining if being called by __get__ + """ + if not items: + if isinstance(items, (BaseDict, BaseList)): + return items + + if instance: + if isinstance(items, dict): + return BaseDict(items, instance=instance, name=name) + else: + return BaseList(items, instance=instance, name=name) + + if isinstance(items, (dict, pymongo.son.SON)): + if '_ref' in items: + return self.object_map.get(items['_ref'].id, items) + elif '_types' in items and '_cls' in items: + doc = get_document(items['_cls'])._from_son(items) + if not get: + doc._data = self._attach_objects(doc._data, depth, doc, name, get) + return doc + + if not hasattr(items, 'items'): + is_list = True + iterator = enumerate(items) + data = [] + else: + is_list = False + iterator = items.iteritems() + data = {} + + for k, v in iterator: + if is_list: + data.append(v) + else: + data[k] = v + + if k in self.object_map: + data[k] = self.object_map[k] + elif hasattr(v, '_fields'): + for field_name, field in v._fields.iteritems(): + v = data[k]._data.get(field_name, None) + if isinstance(v, (pymongo.dbref.DBRef)): + data[k]._data[field_name] = self.object_map.get(v.id, v) + elif isinstance(v, (dict, pymongo.son.SON)) and '_ref' in v: + data[k]._data[field_name] = self.object_map.get(v['_ref'].id, v) + elif isinstance(v, dict) and depth < self.max_depth: + data[k]._data[field_name] = self._attach_objects(v, depth, instance=instance, name=name, get=get) + elif isinstance(v, (list, tuple)): + data[k]._data[field_name] = self._attach_objects(v, depth, instance=instance, name=name, get=get) + elif isinstance(v, (dict, list, tuple)) and depth < self.max_depth: + data[k] = self._attach_objects(v, depth, instance=instance, name=name, get=get) + elif hasattr(v, 'id'): + data[k] = self.object_map.get(v.id, v) + + if instance and name: + if is_list: + return BaseList(data, instance=instance, name=name) + return BaseDict(data, instance=instance, name=name) + depth += 1 + return data + +dereference = DeReference() diff --git a/mongoengine/document.py b/mongoengine/document.py index e20500d6..31a2530c 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -193,6 +193,11 @@ class Document(BaseDocument): signals.post_delete.send(self.__class__, document=self) + def select_related(self, max_depth=1): + from dereference import dereference + self._data = dereference(self._data, max_depth) + return self + def reload(self): """Reloads all attributes from the database. diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index 82138fec..6b110ff0 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -801,13 +801,7 @@ class QuerySet(object): :param object_id: the value for the id of the document to look up """ - id_field = self._document._meta['id_field'] - object_id = self._document._fields[id_field].to_mongo(object_id) - - result = self._collection.find_one({'_id': object_id}, **self._cursor_args) - if result is not None: - result = self._document._from_son(result) - return result + return self._document.objects(pk=object_id).first() def in_bulk(self, object_ids): """Retrieve a set of documents by their ids. @@ -1530,6 +1524,9 @@ class QuerySet(object): data[-1] = "...(remaining elements truncated)..." return repr(data) + def select_related(self, max_depth=1): + from dereference import dereference + return dereference(self, max_depth=max_depth) class QuerySetManager(object): diff --git a/tests/dereference.py b/tests/dereference.py index 4040d5bd..a98267fd 100644 --- a/tests/dereference.py +++ b/tests/dereference.py @@ -30,6 +30,9 @@ class FieldTest(unittest.TestCase): group = Group(members=User.objects) group.save() + group = Group(members=User.objects) + group.save() + with query_counter() as q: self.assertEqual(q, 0) @@ -39,6 +42,24 @@ class FieldTest(unittest.TestCase): [m for m in group_obj.members] self.assertEqual(q, 2) + # Document select_related + with query_counter() as q: + self.assertEqual(q, 0) + + group_obj = Group.objects.first().select_related() + self.assertEqual(q, 2) + [m for m in group_obj.members] + self.assertEqual(q, 2) + + # Queryset select_related + with query_counter() as q: + self.assertEqual(q, 0) + group_objs = Group.objects.select_related() + self.assertEqual(q, 2) + for group_obj in group_objs: + [m for m in group_obj.members] + self.assertEqual(q, 2) + User.drop_collection() Group.drop_collection() @@ -50,6 +71,8 @@ class FieldTest(unittest.TestCase): boss = ReferenceField('self') friends = ListField(ReferenceField('self')) + Employee.drop_collection() + bill = Employee(name='Bill Lumbergh') bill.save() @@ -63,6 +86,10 @@ class FieldTest(unittest.TestCase): peter = Employee(name='Peter Gibbons', boss=bill, friends=friends) peter.save() + Employee(name='Funky Gibbon', boss=bill, friends=friends).save() + Employee(name='Funky Gibbon', boss=bill, friends=friends).save() + Employee(name='Funky Gibbon', boss=bill, friends=friends).save() + with query_counter() as q: self.assertEqual(q, 0) @@ -75,6 +102,33 @@ class FieldTest(unittest.TestCase): peter.friends self.assertEqual(q, 3) + # Document select_related + with query_counter() as q: + self.assertEqual(q, 0) + + peter = Employee.objects.with_id(peter.id).select_related() + self.assertEqual(q, 2) + + self.assertEquals(peter.boss, bill) + self.assertEqual(q, 2) + + self.assertEquals(peter.friends, friends) + self.assertEqual(q, 2) + + # Queryset select_related + with query_counter() as q: + self.assertEqual(q, 0) + + employees = Employee.objects(boss=bill).select_related() + self.assertEqual(q, 2) + + for employee in employees: + self.assertEquals(employee.boss, bill) + self.assertEqual(q, 2) + + self.assertEquals(employee.friends, friends) + self.assertEqual(q, 2) + def test_generic_reference(self): class UserA(Document): @@ -110,6 +164,9 @@ class FieldTest(unittest.TestCase): group = Group(members=members) group.save() + group = Group(members=members) + group.save() + with query_counter() as q: self.assertEqual(q, 0) @@ -125,6 +182,39 @@ class FieldTest(unittest.TestCase): for m in group_obj.members: self.assertTrue('User' in m.__class__.__name__) + # Document select_related + with query_counter() as q: + self.assertEqual(q, 0) + + group_obj = Group.objects.first().select_related() + self.assertEqual(q, 4) + + [m for m in group_obj.members] + self.assertEqual(q, 4) + + [m for m in group_obj.members] + self.assertEqual(q, 4) + + for m in group_obj.members: + self.assertTrue('User' in m.__class__.__name__) + + # Queryset select_related + with query_counter() as q: + self.assertEqual(q, 0) + + group_objs = Group.objects.select_related() + self.assertEqual(q, 4) + + for group_obj in group_objs: + [m for m in group_obj.members] + self.assertEqual(q, 4) + + [m for m in group_obj.members] + self.assertEqual(q, 4) + + for m in group_obj.members: + self.assertTrue('User' in m.__class__.__name__) + UserA.drop_collection() UserB.drop_collection() UserC.drop_collection() @@ -165,6 +255,9 @@ class FieldTest(unittest.TestCase): group = Group(members=members) group.save() + group = Group(members=members) + group.save() + with query_counter() as q: self.assertEqual(q, 0) @@ -180,6 +273,39 @@ class FieldTest(unittest.TestCase): for m in group_obj.members: self.assertTrue('User' in m.__class__.__name__) + # Document select_related + with query_counter() as q: + self.assertEqual(q, 0) + + group_obj = Group.objects.first().select_related() + self.assertEqual(q, 4) + + [m for m in group_obj.members] + self.assertEqual(q, 4) + + [m for m in group_obj.members] + self.assertEqual(q, 4) + + for m in group_obj.members: + self.assertTrue('User' in m.__class__.__name__) + + # Queryset select_related + with query_counter() as q: + self.assertEqual(q, 0) + + group_objs = Group.objects.select_related() + self.assertEqual(q, 4) + + for group_obj in group_objs: + [m for m in group_obj.members] + self.assertEqual(q, 4) + + [m for m in group_obj.members] + self.assertEqual(q, 4) + + for m in group_obj.members: + self.assertTrue('User' in m.__class__.__name__) + UserA.drop_collection() UserB.drop_collection() UserC.drop_collection() @@ -205,6 +331,9 @@ class FieldTest(unittest.TestCase): group = Group(members=dict([(str(u.id), u) for u in members])) group.save() + group = Group(members=dict([(str(u.id), u) for u in members])) + group.save() + with query_counter() as q: self.assertEqual(q, 0) @@ -217,6 +346,33 @@ class FieldTest(unittest.TestCase): for k, m in group_obj.members.iteritems(): self.assertTrue(isinstance(m, User)) + # Document select_related + with query_counter() as q: + self.assertEqual(q, 0) + + group_obj = Group.objects.first().select_related() + self.assertEqual(q, 2) + + [m for m in group_obj.members] + self.assertEqual(q, 2) + + for k, m in group_obj.members.iteritems(): + self.assertTrue(isinstance(m, User)) + + # Queryset select_related + with query_counter() as q: + self.assertEqual(q, 0) + + group_objs = Group.objects.select_related() + self.assertEqual(q, 2) + + for group_obj in group_objs: + [m for m in group_obj.members] + self.assertEqual(q, 2) + + for k, m in group_obj.members.iteritems(): + self.assertTrue(isinstance(m, User)) + User.drop_collection() Group.drop_collection() @@ -254,6 +410,8 @@ class FieldTest(unittest.TestCase): group = Group(members=dict([(str(u.id), u) for u in members])) group.save() + group = Group(members=dict([(str(u.id), u) for u in members])) + group.save() with query_counter() as q: self.assertEqual(q, 0) @@ -270,8 +428,41 @@ class FieldTest(unittest.TestCase): for k, m in group_obj.members.iteritems(): self.assertTrue('User' in m.__class__.__name__) - group.members = {} - group.save() + # Document select_related + with query_counter() as q: + self.assertEqual(q, 0) + + group_obj = Group.objects.first().select_related() + self.assertEqual(q, 4) + + [m for m in group_obj.members] + self.assertEqual(q, 4) + + [m for m in group_obj.members] + self.assertEqual(q, 4) + + for k, m in group_obj.members.iteritems(): + self.assertTrue('User' in m.__class__.__name__) + + # Queryset select_related + with query_counter() as q: + self.assertEqual(q, 0) + + group_objs = Group.objects.select_related() + self.assertEqual(q, 4) + + for group_obj in group_objs: + [m for m in group_obj.members] + self.assertEqual(q, 4) + + [m for m in group_obj.members] + self.assertEqual(q, 4) + + for k, m in group_obj.members.iteritems(): + self.assertTrue('User' in m.__class__.__name__) + + Group.objects.delete() + Group().save() with query_counter() as q: self.assertEqual(q, 0) @@ -310,6 +501,9 @@ class FieldTest(unittest.TestCase): group = Group(members=dict([(str(u.id), u) for u in members])) group.save() + group = Group(members=dict([(str(u.id), u) for u in members])) + group.save() + with query_counter() as q: self.assertEqual(q, 0) @@ -325,6 +519,39 @@ class FieldTest(unittest.TestCase): for k, m in group_obj.members.iteritems(): self.assertTrue(isinstance(m, UserA)) + # Document select_related + with query_counter() as q: + self.assertEqual(q, 0) + + group_obj = Group.objects.first().select_related() + self.assertEqual(q, 2) + + [m for m in group_obj.members] + self.assertEqual(q, 2) + + [m for m in group_obj.members] + self.assertEqual(q, 2) + + for k, m in group_obj.members.iteritems(): + self.assertTrue(isinstance(m, UserA)) + + # Queryset select_related + with query_counter() as q: + self.assertEqual(q, 0) + + group_objs = Group.objects.select_related() + self.assertEqual(q, 2) + + for group_obj in group_objs: + [m for m in group_obj.members] + self.assertEqual(q, 2) + + [m for m in group_obj.members] + self.assertEqual(q, 2) + + for k, m in group_obj.members.iteritems(): + self.assertTrue(isinstance(m, UserA)) + UserA.drop_collection() Group.drop_collection() @@ -362,6 +589,8 @@ class FieldTest(unittest.TestCase): group = Group(members=dict([(str(u.id), u) for u in members])) group.save() + group = Group(members=dict([(str(u.id), u) for u in members])) + group.save() with query_counter() as q: self.assertEqual(q, 0) @@ -378,8 +607,41 @@ class FieldTest(unittest.TestCase): for k, m in group_obj.members.iteritems(): self.assertTrue('User' in m.__class__.__name__) - group.members = {} - group.save() + # Document select_related + with query_counter() as q: + self.assertEqual(q, 0) + + group_obj = Group.objects.first().select_related() + self.assertEqual(q, 4) + + [m for m in group_obj.members] + self.assertEqual(q, 4) + + [m for m in group_obj.members] + self.assertEqual(q, 4) + + for k, m in group_obj.members.iteritems(): + self.assertTrue('User' in m.__class__.__name__) + + # Queryset select_related + with query_counter() as q: + self.assertEqual(q, 0) + + group_objs = Group.objects.select_related() + self.assertEqual(q, 4) + + for group_obj in group_objs: + [m for m in group_obj.members] + self.assertEqual(q, 4) + + [m for m in group_obj.members] + self.assertEqual(q, 4) + + for k, m in group_obj.members.iteritems(): + self.assertTrue('User' in m.__class__.__name__) + + Group.objects.delete() + Group().save() with query_counter() as q: self.assertEqual(q, 0) @@ -393,4 +655,4 @@ class FieldTest(unittest.TestCase): UserA.drop_collection() UserB.drop_collection() UserC.drop_collection() - Group.drop_collection() \ No newline at end of file + Group.drop_collection() diff --git a/tests/document.py b/tests/document.py index 28d61332..82488cf1 100644 --- a/tests/document.py +++ b/tests/document.py @@ -932,7 +932,7 @@ class DocumentTest(unittest.TestCase): list_field = ListField() embedded_field = EmbeddedDocumentField(Embedded) - Doc.drop_collection + Doc.drop_collection() doc = Doc() doc.dict_field = {'hello': 'world'} doc.list_field = ['1', 2, {'hello': 'world'}] @@ -1125,7 +1125,7 @@ class DocumentTest(unittest.TestCase): dict_field = DictField() list_field = ListField() - Doc.drop_collection + Doc.drop_collection() doc = Doc() doc.save() @@ -1180,7 +1180,7 @@ class DocumentTest(unittest.TestCase): list_field = ListField() embedded_field = EmbeddedDocumentField(Embedded) - Doc.drop_collection + Doc.drop_collection() doc = Doc() doc.save()