added global model registry and GenericReferenceField, a ReferenceField not bound to a particular model

This commit is contained in:
blackbrrr 2010-02-26 16:59:12 -06:00
parent 200e9eca92
commit 03d31b1890
3 changed files with 160 additions and 7 deletions

View File

@ -3,9 +3,13 @@ from queryset import QuerySet, QuerySetManager
import pymongo
_model_registry = {}
class ValidationError(Exception):
pass
class BaseField(object):
"""A base class for fields in a MongoDB document. Instances of this class
may be added to subclasses of `Document` to define a document's schema.
@ -158,6 +162,8 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
"""
def __new__(cls, name, bases, attrs):
global _model_registry
super_new = super(TopLevelDocumentMetaclass, cls).__new__
# Classes defined in this package are abstract and should not have
# their own metadata with DB collection, etc.
@ -246,6 +252,8 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
new_class._meta['id_field'] = 'id'
new_class.id = new_class._fields['id'] = ObjectIdField(name='_id')
_model_registry[name] = new_class
return new_class

View File

@ -1,4 +1,4 @@
from base import BaseField, ObjectIdField, ValidationError
from base import BaseField, ObjectIdField, ValidationError, _model_registry
from document import Document, EmbeddedDocument
from connection import _get_db
@ -11,7 +11,7 @@ import decimal
__all__ = ['StringField', 'IntField', 'FloatField', 'BooleanField',
'DateTimeField', 'EmbeddedDocumentField', 'ListField', 'DictField',
'ObjectIdField', 'ReferenceField', 'ValidationError',
'DecimalField', 'URLField']
'DecimalField', 'URLField', 'GenericReferenceField']
class StringField(BaseField):
@ -215,7 +215,7 @@ class ListField(BaseField):
# Document class being used rather than a document object
return self
if isinstance(self.field, ReferenceField):
if isinstance(self.field, (ReferenceField, GenericReferenceField)):
referenced_type = self.field.document_type
# Get value from document instance if available
value_list = instance._data.get(self.name)
@ -230,6 +230,20 @@ class ListField(BaseField):
deref_list.append(value)
instance._data[self.name] = deref_list
# if isinstance(self.field, GenericReferenceField):
# value_list = instance._data.get(self.name)
# if value_list:
# deref_list = []
# for value in value_list:
# # Dereference DBRefs
# if isinstance(value, pymongo.dbref.DBRef):
# value = _get_db().dereference(value)
# referenced_type = value.
# deref_list.append()
# else:
# deref_list.append(value)
# instance._data[self.name] = deref_list
return super(ListField, self).__get__(instance, owner)
def to_python(self, value):
@ -337,3 +351,41 @@ class ReferenceField(BaseField):
def lookup_member(self, member_name):
return self.document_type._fields.get(member_name)
class GenericReferenceField(BaseField):
"""A reference to *any* :class:`~mongoengine.document.Document` subclass
that will be automatically dereferenced on access (lazily).
"""
def __get__(self, instance, owner):
global _model_registry
if instance is None:
return self
value = instance._data.get(self.name)
if isinstance(value, pymongo.dbref.DBRef):
value = _get_db().dereference(value)
if value is not None:
model = _model_registry[value['_cls']]
instance._data[self.name] = model._from_son(value)
return super(GenericReferenceField, self).__get__(instance, owner)
def to_mongo(self, document):
id_field_name = document.__class__._meta['id_field']
id_field = document.__class__._fields[id_field_name]
if isinstance(document, Document):
# We need the id from the saved object to create the DBRef
id_ = document.id
if id_ is None:
raise ValidationError('You can only reference documents once '
'they have been saved to the database')
else:
id_ = document
id_ = id_field.to_mongo(id_)
collection = document._meta['collection']
return pymongo.dbref.DBRef(collection, id_)

View File

@ -2,6 +2,8 @@ import unittest
import datetime
from decimal import Decimal
import pymongo
from mongoengine import *
from mongoengine.connection import _get_db
@ -349,6 +351,97 @@ class FieldTest(unittest.TestCase):
Member.drop_collection()
BlogPost.drop_collection()
def test_generic_reference(self):
"""Ensure that a GenericReferenceField properly dereferences
relationships to *any* model.
"""
class Link(Document):
title = StringField()
class Post(Document):
title = StringField()
class Bookmark(Document):
bookmark_object = GenericReferenceField()
Link.drop_collection()
Post.drop_collection()
Bookmark.drop_collection()
link_1 = Link(title="Pitchfork")
link_1.save()
post_1 = Post(title="Behind the Scenes of the Pavement Reunion")
post_1.save()
bm = Bookmark(bookmark_object=post_1)
bm.save()
bm.reload()
self.assertEqual(bm.bookmark_object, post_1)
bm.bookmark_object = link_1
bm.save()
bm.reload()
self.assertEqual(bm.bookmark_object, link_1)
Link.drop_collection()
Post.drop_collection()
Bookmark.drop_collection()
# def test_generic_reference_list(self):
# """Ensure that a ListField properly dereferences
# relationships to *any* model via GenericReferenceField.
# """
# class Link(Document):
# title = StringField()
#
# class Post(Document):
# title = StringField()
#
# class User(Document):
# bookmarks = ListField(GenericReferenceField())
#
# Link.drop_collection()
# Post.drop_collection()
# User.drop_collection()
#
# link_1 = Link(title="Pitchfork")
# link_1.save()
#
# post_1 = Post(title="Behind the Scenes of the Pavement Reunion")
# post_1.save()
#
# user = User(bookmarks=[post_1, link_1])
# user.save()
#
# del user
#
# user = User.objects().first()
# print user.bookmarks
#
# # print dir(user)
#
# # self.assertEqual(bm.bookmark_object, post_1)
# # self.assertEqual(bm._data['bookmark_object'].__class__,
# # pymongo.dbref.DBRef)
# #
# # bm.bookmark_object = link_1
# # bm.save()
# #
# # bm.reload()
# #
# # self.assertEqual(bm.bookmark_object, link_1)
# # self.assertEqual(bm._data['bookmark_object'].__class__,
# # pymongo.dbref.DBRef)
#
# Link.drop_collection()
# Post.drop_collection()
# User.drop_collection()
if __name__ == '__main__':
unittest.main()