diff --git a/mongoengine/base.py b/mongoengine/base.py index 02c3e661..1921cba7 100644 --- a/mongoengine/base.py +++ b/mongoengine/base.py @@ -3,7 +3,10 @@ from queryset import QuerySet, QuerySetManager import pymongo -_model_registry = {} +_document_registry = {} + +def get_document(name): + return _document_registry[name] class ValidationError(Exception): @@ -153,7 +156,11 @@ class DocumentMetaclass(type): doc_fields[attr_name] = attr_value attrs['_fields'] = doc_fields - return super_new(cls, name, bases, attrs) + new_class = super_new(cls, name, bases, attrs) + for field in new_class._fields.values(): + field.owner_document = new_class + + return new_class class TopLevelDocumentMetaclass(DocumentMetaclass): @@ -162,7 +169,7 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): """ def __new__(cls, name, bases, attrs): - global _model_registry + global _document_registry super_new = super(TopLevelDocumentMetaclass, cls).__new__ # Classes defined in this package are abstract and should not have @@ -252,7 +259,7 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): new_class._meta['id_field'] = 'id' new_class.id = new_class._fields['id'] = ObjectIdField(name='_id') - _model_registry[name] = new_class + _document_registry[name] = new_class return new_class diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 0f73adf5..a4ee7c13 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -1,4 +1,4 @@ -from base import BaseField, ObjectIdField, ValidationError, _model_registry +from base import BaseField, ObjectIdField, ValidationError, get_document from document import Document, EmbeddedDocument from connection import _get_db @@ -241,8 +241,6 @@ class ListField(BaseField): def __get__(self, instance, owner): """Descriptor to automatically dereference references. """ - global _model_registry - if instance is None: # Document class being used rather than a document object return self @@ -268,10 +266,8 @@ class ListField(BaseField): deref_list = [] for value in value_list: # Dereference DBRefs - if isinstance(value, pymongo.dbref.DBRef): - value = _get_db().dereference(value) - referenced_type = _model_registry[value['_cls']] - deref_list.append(referenced_type._from_son(value)) + if isinstance(value, (dict, pymongo.son.SON)): + deref_list.append(self.field.dereference(value)) else: deref_list.append(value) instance._data[self.name] = deref_list @@ -334,9 +330,10 @@ class ReferenceField(BaseField): """ def __init__(self, document_type, **kwargs): - if not issubclass(document_type, Document): - raise ValidationError('Argument to ReferenceField constructor ' - 'must be a top level document class') + if not isinstance(document_type, basestring): + if not issubclass(document_type, (Document, basestring)): + raise ValidationError('Argument to ReferenceField constructor ' + 'must be a document class or a string') self.document_type = document_type self.document_obj = None super(ReferenceField, self).__init__(**kwargs) @@ -391,20 +388,23 @@ class GenericReferenceField(BaseField): """ def __get__(self, instance, owner): - global _model_registry - if instance is None: return self value = instance._data.get(self.name) - if isinstance(value, pymongo.dbref.DBRef): - value = _get_db().dereference(value) - if value is not None: - model = _model_registry[value['_cls']] - instance._data[self.name] = model._from_son(value) + if isinstance(value, (dict, pymongo.son.SON)): + instance._data[self.name] = self.dereference(value) return super(GenericReferenceField, self).__get__(instance, owner) + def dereference(self, value): + doc_cls = get_document(value['_cls']) + reference = value['_ref'] + doc = _get_db().dereference(reference) + if doc is not None: + doc = doc_cls._from_son(doc) + return doc + def to_mongo(self, document): id_field_name = document.__class__._meta['id_field'] id_field = document.__class__._fields[id_field_name] @@ -420,4 +420,8 @@ class GenericReferenceField(BaseField): id_ = id_field.to_mongo(id_) collection = document._meta['collection'] - return pymongo.dbref.DBRef(collection, id_) + ref = pymongo.dbref.DBRef(collection, id_) + return {'_cls': document.__class__.__name__, '_ref': ref} + + def prepare_query_value(self, op, value): + return self.to_mongo(value)['_ref'] diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index 0acb8696..8c13dc93 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -322,6 +322,9 @@ class QuerySet(object): # 'in', 'nin' and 'all' require a list of values value = [field.prepare_query_value(op, v) for v in value] + if field.__class__.__name__ == 'GenericReferenceField': + parts.append('_ref') + if op and op not in match_operators: value = {'$' + op: value} diff --git a/tests/fields.py b/tests/fields.py index 5a1a22fc..382cd455 100644 --- a/tests/fields.py +++ b/tests/fields.py @@ -375,21 +375,21 @@ class FieldTest(unittest.TestCase): post2 = BlogPost(title='post 2', author=m2) post2.save() - post = BlogPost.objects(author=m1.id).first() + post = BlogPost.objects(author=m1).first() self.assertEqual(post.id, post1.id) - post = BlogPost.objects(author=m2.id).first() + post = BlogPost.objects(author=m2).first() self.assertEqual(post.id, post2.id) Member.drop_collection() BlogPost.drop_collection() def test_generic_reference(self): - """Ensure that a GenericReferenceField properly dereferences - relationships to *any* model. + """Ensure that a GenericReferenceField properly dereferences items. """ class Link(Document): title = StringField() + meta = {'allow_inheritance': False} class Post(Document): title = StringField() @@ -410,7 +410,7 @@ class FieldTest(unittest.TestCase): bm = Bookmark(bookmark_object=post_1) bm.save() - bm.reload() + bm = Bookmark.objects(bookmark_object=post_1).first() self.assertEqual(bm.bookmark_object, post_1) self.assertTrue(isinstance(bm.bookmark_object, Post)) @@ -418,7 +418,7 @@ class FieldTest(unittest.TestCase): bm.bookmark_object = link_1 bm.save() - bm.reload() + bm = Bookmark.objects(bookmark_object=link_1).first() self.assertEqual(bm.bookmark_object, link_1) self.assertTrue(isinstance(bm.bookmark_object, Link)) @@ -428,8 +428,7 @@ class FieldTest(unittest.TestCase): Bookmark.drop_collection() def test_generic_reference_list(self): - """Ensure that a ListField properly dereferences - relationships to *any* model via GenericReferenceField. + """Ensure that a ListField properly dereferences generic references. """ class Link(Document): title = StringField() @@ -453,7 +452,7 @@ class FieldTest(unittest.TestCase): user = User(bookmarks=[post_1, link_1]) user.save() - user.reload() + user = User.objects(bookmarks__all=[post_1, link_1]).first() self.assertEqual(user.bookmarks[0], post_1) self.assertEqual(user.bookmarks[1], link_1)