diff --git a/mongoengine/base.py b/mongoengine/base.py index a452f3c8..a2027411 100644 --- a/mongoengine/base.py +++ b/mongoengine/base.py @@ -3,9 +3,13 @@ from queryset import QuerySet, QuerySetManager import pymongo +_model_registry = {} + + class ValidationError(Exception): pass + class BaseField(object): """A base class for fields in a MongoDB document. Instances of this class may be added to subclasses of `Document` to define a document's schema. @@ -158,6 +162,8 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): """ def __new__(cls, name, bases, attrs): + global _model_registry + super_new = super(TopLevelDocumentMetaclass, cls).__new__ # Classes defined in this package are abstract and should not have # their own metadata with DB collection, etc. @@ -246,6 +252,8 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): new_class._meta['id_field'] = 'id' new_class.id = new_class._fields['id'] = ObjectIdField(name='_id') + _model_registry[name] = new_class + return new_class diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 585a8cad..0f73adf5 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -1,4 +1,4 @@ -from base import BaseField, ObjectIdField, ValidationError +from base import BaseField, ObjectIdField, ValidationError, _model_registry from document import Document, EmbeddedDocument from connection import _get_db @@ -8,10 +8,10 @@ import datetime import decimal -__all__ = ['StringField', 'IntField', 'FloatField', 'BooleanField', +__all__ = ['StringField', 'IntField', 'FloatField', 'BooleanField', 'DateTimeField', 'EmbeddedDocumentField', 'ListField', 'DictField', 'ObjectIdField', 'ReferenceField', 'ValidationError', - 'DecimalField', 'URLField'] + 'DecimalField', 'URLField', 'GenericReferenceField'] class StringField(BaseField): @@ -241,13 +241,15 @@ class ListField(BaseField): def __get__(self, instance, owner): """Descriptor to automatically dereference references. """ + global _model_registry + if instance is None: # Document class being used rather than a document object return self if isinstance(self.field, ReferenceField): referenced_type = self.field.document_type - # Get value from document instance if available + # Get value from document instance if available value_list = instance._data.get(self.name) if value_list: deref_list = [] @@ -259,7 +261,21 @@ class ListField(BaseField): else: deref_list.append(value) instance._data[self.name] = deref_list - + + if isinstance(self.field, GenericReferenceField): + value_list = instance._data.get(self.name) + if value_list: + deref_list = [] + for value in value_list: + # Dereference DBRefs + if isinstance(value, pymongo.dbref.DBRef): + value = _get_db().dereference(value) + referenced_type = _model_registry[value['_cls']] + deref_list.append(referenced_type._from_son(value)) + else: + deref_list.append(value) + instance._data[self.name] = deref_list + return super(ListField, self).__get__(instance, owner) def to_python(self, value): @@ -302,10 +318,10 @@ class DictField(BaseField): """ if not isinstance(value, dict): raise ValidationError('Only dictionaries may be used in a ' - 'DictField') + 'DictField') if any(('.' in k or '$' in k) for k in value): - raise ValidationError('Invalid dictionary key name - keys may not ' + raise ValidationError('Invalid dictionary key name - keys may not ' 'contain "." or "$" characters') def lookup_member(self, member_name): @@ -367,3 +383,41 @@ class ReferenceField(BaseField): def lookup_member(self, member_name): return self.document_type._fields.get(member_name) + + +class GenericReferenceField(BaseField): + """A reference to *any* :class:`~mongoengine.document.Document` subclass + that will be automatically dereferenced on access (lazily). + """ + + def __get__(self, instance, owner): + global _model_registry + + if instance is None: + return self + + value = instance._data.get(self.name) + if isinstance(value, pymongo.dbref.DBRef): + value = _get_db().dereference(value) + if value is not None: + model = _model_registry[value['_cls']] + instance._data[self.name] = model._from_son(value) + + return super(GenericReferenceField, self).__get__(instance, owner) + + def to_mongo(self, document): + id_field_name = document.__class__._meta['id_field'] + id_field = document.__class__._fields[id_field_name] + + if isinstance(document, Document): + # We need the id from the saved object to create the DBRef + id_ = document.id + if id_ is None: + raise ValidationError('You can only reference documents once ' + 'they have been saved to the database') + else: + id_ = document + + id_ = id_field.to_mongo(id_) + collection = document._meta['collection'] + return pymongo.dbref.DBRef(collection, id_) diff --git a/tests/fields.py b/tests/fields.py index 388bc102..47534116 100644 --- a/tests/fields.py +++ b/tests/fields.py @@ -2,6 +2,8 @@ import unittest import datetime from decimal import Decimal +import pymongo + from mongoengine import * from mongoengine.connection import _get_db @@ -381,6 +383,82 @@ class FieldTest(unittest.TestCase): Member.drop_collection() BlogPost.drop_collection() + + def test_generic_reference(self): + """Ensure that a GenericReferenceField properly dereferences + relationships to *any* model. + """ + class Link(Document): + title = StringField() + + class Post(Document): + title = StringField() + + class Bookmark(Document): + bookmark_object = GenericReferenceField() + + Link.drop_collection() + Post.drop_collection() + Bookmark.drop_collection() + + link_1 = Link(title="Pitchfork") + link_1.save() + + post_1 = Post(title="Behind the Scenes of the Pavement Reunion") + post_1.save() + + bm = Bookmark(bookmark_object=post_1) + bm.save() + + bm.reload() + + self.assertEqual(bm.bookmark_object, post_1) + + bm.bookmark_object = link_1 + bm.save() + + bm.reload() + + self.assertEqual(bm.bookmark_object, link_1) + + Link.drop_collection() + Post.drop_collection() + Bookmark.drop_collection() + + def test_generic_reference_list(self): + """Ensure that a ListField properly dereferences + relationships to *any* model via GenericReferenceField. + """ + class Link(Document): + title = StringField() + + class Post(Document): + title = StringField() + + class User(Document): + bookmarks = ListField(GenericReferenceField()) + + Link.drop_collection() + Post.drop_collection() + User.drop_collection() + + link_1 = Link(title="Pitchfork") + link_1.save() + + post_1 = Post(title="Behind the Scenes of the Pavement Reunion") + post_1.save() + + user = User(bookmarks=[post_1, link_1]) + user.save() + + user.reload() + + self.assertEqual(user.bookmarks[0], post_1) + self.assertEqual(user.bookmarks[1], link_1) + + Link.drop_collection() + Post.drop_collection() + User.drop_collection() if __name__ == '__main__':