added global model registry and GenericReferenceField, a ReferenceField not bound to a particular model
This commit is contained in:
parent
200e9eca92
commit
03d31b1890
@ -3,9 +3,13 @@ from queryset import QuerySet, QuerySetManager
|
|||||||
import pymongo
|
import pymongo
|
||||||
|
|
||||||
|
|
||||||
|
_model_registry = {}
|
||||||
|
|
||||||
|
|
||||||
class ValidationError(Exception):
|
class ValidationError(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class BaseField(object):
|
class BaseField(object):
|
||||||
"""A base class for fields in a MongoDB document. Instances of this class
|
"""A base class for fields in a MongoDB document. Instances of this class
|
||||||
may be added to subclasses of `Document` to define a document's schema.
|
may be added to subclasses of `Document` to define a document's schema.
|
||||||
@ -158,6 +162,8 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __new__(cls, name, bases, attrs):
|
def __new__(cls, name, bases, attrs):
|
||||||
|
global _model_registry
|
||||||
|
|
||||||
super_new = super(TopLevelDocumentMetaclass, cls).__new__
|
super_new = super(TopLevelDocumentMetaclass, cls).__new__
|
||||||
# Classes defined in this package are abstract and should not have
|
# Classes defined in this package are abstract and should not have
|
||||||
# their own metadata with DB collection, etc.
|
# their own metadata with DB collection, etc.
|
||||||
@ -246,6 +252,8 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
|
|||||||
new_class._meta['id_field'] = 'id'
|
new_class._meta['id_field'] = 'id'
|
||||||
new_class.id = new_class._fields['id'] = ObjectIdField(name='_id')
|
new_class.id = new_class._fields['id'] = ObjectIdField(name='_id')
|
||||||
|
|
||||||
|
_model_registry[name] = new_class
|
||||||
|
|
||||||
return new_class
|
return new_class
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from base import BaseField, ObjectIdField, ValidationError
|
from base import BaseField, ObjectIdField, ValidationError, _model_registry
|
||||||
from document import Document, EmbeddedDocument
|
from document import Document, EmbeddedDocument
|
||||||
from connection import _get_db
|
from connection import _get_db
|
||||||
|
|
||||||
@ -8,10 +8,10 @@ import datetime
|
|||||||
import decimal
|
import decimal
|
||||||
|
|
||||||
|
|
||||||
__all__ = ['StringField', 'IntField', 'FloatField', 'BooleanField',
|
__all__ = ['StringField', 'IntField', 'FloatField', 'BooleanField',
|
||||||
'DateTimeField', 'EmbeddedDocumentField', 'ListField', 'DictField',
|
'DateTimeField', 'EmbeddedDocumentField', 'ListField', 'DictField',
|
||||||
'ObjectIdField', 'ReferenceField', 'ValidationError',
|
'ObjectIdField', 'ReferenceField', 'ValidationError',
|
||||||
'DecimalField', 'URLField']
|
'DecimalField', 'URLField', 'GenericReferenceField']
|
||||||
|
|
||||||
|
|
||||||
class StringField(BaseField):
|
class StringField(BaseField):
|
||||||
@ -215,7 +215,7 @@ class ListField(BaseField):
|
|||||||
# Document class being used rather than a document object
|
# Document class being used rather than a document object
|
||||||
return self
|
return self
|
||||||
|
|
||||||
if isinstance(self.field, ReferenceField):
|
if isinstance(self.field, (ReferenceField, GenericReferenceField)):
|
||||||
referenced_type = self.field.document_type
|
referenced_type = self.field.document_type
|
||||||
# Get value from document instance if available
|
# Get value from document instance if available
|
||||||
value_list = instance._data.get(self.name)
|
value_list = instance._data.get(self.name)
|
||||||
@ -229,7 +229,21 @@ class ListField(BaseField):
|
|||||||
else:
|
else:
|
||||||
deref_list.append(value)
|
deref_list.append(value)
|
||||||
instance._data[self.name] = deref_list
|
instance._data[self.name] = deref_list
|
||||||
|
|
||||||
|
# if isinstance(self.field, GenericReferenceField):
|
||||||
|
# value_list = instance._data.get(self.name)
|
||||||
|
# if value_list:
|
||||||
|
# deref_list = []
|
||||||
|
# for value in value_list:
|
||||||
|
# # Dereference DBRefs
|
||||||
|
# if isinstance(value, pymongo.dbref.DBRef):
|
||||||
|
# value = _get_db().dereference(value)
|
||||||
|
# referenced_type = value.
|
||||||
|
# deref_list.append()
|
||||||
|
# else:
|
||||||
|
# deref_list.append(value)
|
||||||
|
# instance._data[self.name] = deref_list
|
||||||
|
|
||||||
return super(ListField, self).__get__(instance, owner)
|
return super(ListField, self).__get__(instance, owner)
|
||||||
|
|
||||||
def to_python(self, value):
|
def to_python(self, value):
|
||||||
@ -272,10 +286,10 @@ class DictField(BaseField):
|
|||||||
"""
|
"""
|
||||||
if not isinstance(value, dict):
|
if not isinstance(value, dict):
|
||||||
raise ValidationError('Only dictionaries may be used in a '
|
raise ValidationError('Only dictionaries may be used in a '
|
||||||
'DictField')
|
'DictField')
|
||||||
|
|
||||||
if any(('.' in k or '$' in k) for k in value):
|
if any(('.' in k or '$' in k) for k in value):
|
||||||
raise ValidationError('Invalid dictionary key name - keys may not '
|
raise ValidationError('Invalid dictionary key name - keys may not '
|
||||||
'contain "." or "$" characters')
|
'contain "." or "$" characters')
|
||||||
|
|
||||||
def lookup_member(self, member_name):
|
def lookup_member(self, member_name):
|
||||||
@ -337,3 +351,41 @@ class ReferenceField(BaseField):
|
|||||||
|
|
||||||
def lookup_member(self, member_name):
|
def lookup_member(self, member_name):
|
||||||
return self.document_type._fields.get(member_name)
|
return self.document_type._fields.get(member_name)
|
||||||
|
|
||||||
|
|
||||||
|
class GenericReferenceField(BaseField):
|
||||||
|
"""A reference to *any* :class:`~mongoengine.document.Document` subclass
|
||||||
|
that will be automatically dereferenced on access (lazily).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __get__(self, instance, owner):
|
||||||
|
global _model_registry
|
||||||
|
|
||||||
|
if instance is None:
|
||||||
|
return self
|
||||||
|
|
||||||
|
value = instance._data.get(self.name)
|
||||||
|
if isinstance(value, pymongo.dbref.DBRef):
|
||||||
|
value = _get_db().dereference(value)
|
||||||
|
if value is not None:
|
||||||
|
model = _model_registry[value['_cls']]
|
||||||
|
instance._data[self.name] = model._from_son(value)
|
||||||
|
|
||||||
|
return super(GenericReferenceField, self).__get__(instance, owner)
|
||||||
|
|
||||||
|
def to_mongo(self, document):
|
||||||
|
id_field_name = document.__class__._meta['id_field']
|
||||||
|
id_field = document.__class__._fields[id_field_name]
|
||||||
|
|
||||||
|
if isinstance(document, Document):
|
||||||
|
# We need the id from the saved object to create the DBRef
|
||||||
|
id_ = document.id
|
||||||
|
if id_ is None:
|
||||||
|
raise ValidationError('You can only reference documents once '
|
||||||
|
'they have been saved to the database')
|
||||||
|
else:
|
||||||
|
id_ = document
|
||||||
|
|
||||||
|
id_ = id_field.to_mongo(id_)
|
||||||
|
collection = document._meta['collection']
|
||||||
|
return pymongo.dbref.DBRef(collection, id_)
|
||||||
|
@ -2,6 +2,8 @@ import unittest
|
|||||||
import datetime
|
import datetime
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
|
|
||||||
|
import pymongo
|
||||||
|
|
||||||
from mongoengine import *
|
from mongoengine import *
|
||||||
from mongoengine.connection import _get_db
|
from mongoengine.connection import _get_db
|
||||||
|
|
||||||
@ -348,6 +350,97 @@ class FieldTest(unittest.TestCase):
|
|||||||
|
|
||||||
Member.drop_collection()
|
Member.drop_collection()
|
||||||
BlogPost.drop_collection()
|
BlogPost.drop_collection()
|
||||||
|
|
||||||
|
def test_generic_reference(self):
|
||||||
|
"""Ensure that a GenericReferenceField properly dereferences
|
||||||
|
relationships to *any* model.
|
||||||
|
"""
|
||||||
|
class Link(Document):
|
||||||
|
title = StringField()
|
||||||
|
|
||||||
|
class Post(Document):
|
||||||
|
title = StringField()
|
||||||
|
|
||||||
|
class Bookmark(Document):
|
||||||
|
bookmark_object = GenericReferenceField()
|
||||||
|
|
||||||
|
Link.drop_collection()
|
||||||
|
Post.drop_collection()
|
||||||
|
Bookmark.drop_collection()
|
||||||
|
|
||||||
|
link_1 = Link(title="Pitchfork")
|
||||||
|
link_1.save()
|
||||||
|
|
||||||
|
post_1 = Post(title="Behind the Scenes of the Pavement Reunion")
|
||||||
|
post_1.save()
|
||||||
|
|
||||||
|
bm = Bookmark(bookmark_object=post_1)
|
||||||
|
bm.save()
|
||||||
|
|
||||||
|
bm.reload()
|
||||||
|
|
||||||
|
self.assertEqual(bm.bookmark_object, post_1)
|
||||||
|
|
||||||
|
bm.bookmark_object = link_1
|
||||||
|
bm.save()
|
||||||
|
|
||||||
|
bm.reload()
|
||||||
|
|
||||||
|
self.assertEqual(bm.bookmark_object, link_1)
|
||||||
|
|
||||||
|
Link.drop_collection()
|
||||||
|
Post.drop_collection()
|
||||||
|
Bookmark.drop_collection()
|
||||||
|
|
||||||
|
# def test_generic_reference_list(self):
|
||||||
|
# """Ensure that a ListField properly dereferences
|
||||||
|
# relationships to *any* model via GenericReferenceField.
|
||||||
|
# """
|
||||||
|
# class Link(Document):
|
||||||
|
# title = StringField()
|
||||||
|
#
|
||||||
|
# class Post(Document):
|
||||||
|
# title = StringField()
|
||||||
|
#
|
||||||
|
# class User(Document):
|
||||||
|
# bookmarks = ListField(GenericReferenceField())
|
||||||
|
#
|
||||||
|
# Link.drop_collection()
|
||||||
|
# Post.drop_collection()
|
||||||
|
# User.drop_collection()
|
||||||
|
#
|
||||||
|
# link_1 = Link(title="Pitchfork")
|
||||||
|
# link_1.save()
|
||||||
|
#
|
||||||
|
# post_1 = Post(title="Behind the Scenes of the Pavement Reunion")
|
||||||
|
# post_1.save()
|
||||||
|
#
|
||||||
|
# user = User(bookmarks=[post_1, link_1])
|
||||||
|
# user.save()
|
||||||
|
#
|
||||||
|
# del user
|
||||||
|
#
|
||||||
|
# user = User.objects().first()
|
||||||
|
# print user.bookmarks
|
||||||
|
#
|
||||||
|
# # print dir(user)
|
||||||
|
#
|
||||||
|
# # self.assertEqual(bm.bookmark_object, post_1)
|
||||||
|
# # self.assertEqual(bm._data['bookmark_object'].__class__,
|
||||||
|
# # pymongo.dbref.DBRef)
|
||||||
|
# #
|
||||||
|
# # bm.bookmark_object = link_1
|
||||||
|
# # bm.save()
|
||||||
|
# #
|
||||||
|
# # bm.reload()
|
||||||
|
# #
|
||||||
|
# # self.assertEqual(bm.bookmark_object, link_1)
|
||||||
|
# # self.assertEqual(bm._data['bookmark_object'].__class__,
|
||||||
|
# # pymongo.dbref.DBRef)
|
||||||
|
#
|
||||||
|
# Link.drop_collection()
|
||||||
|
# Post.drop_collection()
|
||||||
|
# User.drop_collection()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
Loading…
x
Reference in New Issue
Block a user