Changed how GenericReferenceFields are stored / queried

This commit is contained in:
Harry Marr 2010-02-28 23:15:21 +00:00
parent 81dd5adccf
commit 95a7b33fb4
4 changed files with 44 additions and 31 deletions

View File

@ -3,7 +3,10 @@ from queryset import QuerySet, QuerySetManager
import pymongo import pymongo
_model_registry = {} _document_registry = {}
def get_document(name):
return _document_registry[name]
class ValidationError(Exception): class ValidationError(Exception):
@ -153,7 +156,11 @@ class DocumentMetaclass(type):
doc_fields[attr_name] = attr_value doc_fields[attr_name] = attr_value
attrs['_fields'] = doc_fields attrs['_fields'] = doc_fields
return super_new(cls, name, bases, attrs) new_class = super_new(cls, name, bases, attrs)
for field in new_class._fields.values():
field.owner_document = new_class
return new_class
class TopLevelDocumentMetaclass(DocumentMetaclass): class TopLevelDocumentMetaclass(DocumentMetaclass):
@ -162,7 +169,7 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
""" """
def __new__(cls, name, bases, attrs): def __new__(cls, name, bases, attrs):
global _model_registry global _document_registry
super_new = super(TopLevelDocumentMetaclass, cls).__new__ super_new = super(TopLevelDocumentMetaclass, cls).__new__
# Classes defined in this package are abstract and should not have # Classes defined in this package are abstract and should not have
@ -252,7 +259,7 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
new_class._meta['id_field'] = 'id' new_class._meta['id_field'] = 'id'
new_class.id = new_class._fields['id'] = ObjectIdField(name='_id') new_class.id = new_class._fields['id'] = ObjectIdField(name='_id')
_model_registry[name] = new_class _document_registry[name] = new_class
return new_class return new_class

View File

@ -1,4 +1,4 @@
from base import BaseField, ObjectIdField, ValidationError, _model_registry from base import BaseField, ObjectIdField, ValidationError, get_document
from document import Document, EmbeddedDocument from document import Document, EmbeddedDocument
from connection import _get_db from connection import _get_db
@ -241,8 +241,6 @@ class ListField(BaseField):
def __get__(self, instance, owner): def __get__(self, instance, owner):
"""Descriptor to automatically dereference references. """Descriptor to automatically dereference references.
""" """
global _model_registry
if instance is None: if instance is None:
# Document class being used rather than a document object # Document class being used rather than a document object
return self return self
@ -268,10 +266,8 @@ class ListField(BaseField):
deref_list = [] deref_list = []
for value in value_list: for value in value_list:
# Dereference DBRefs # Dereference DBRefs
if isinstance(value, pymongo.dbref.DBRef): if isinstance(value, (dict, pymongo.son.SON)):
value = _get_db().dereference(value) deref_list.append(self.field.dereference(value))
referenced_type = _model_registry[value['_cls']]
deref_list.append(referenced_type._from_son(value))
else: else:
deref_list.append(value) deref_list.append(value)
instance._data[self.name] = deref_list instance._data[self.name] = deref_list
@ -334,9 +330,10 @@ class ReferenceField(BaseField):
""" """
def __init__(self, document_type, **kwargs): def __init__(self, document_type, **kwargs):
if not issubclass(document_type, Document): if not isinstance(document_type, basestring):
raise ValidationError('Argument to ReferenceField constructor ' if not issubclass(document_type, (Document, basestring)):
'must be a top level document class') raise ValidationError('Argument to ReferenceField constructor '
'must be a document class or a string')
self.document_type = document_type self.document_type = document_type
self.document_obj = None self.document_obj = None
super(ReferenceField, self).__init__(**kwargs) super(ReferenceField, self).__init__(**kwargs)
@ -391,20 +388,23 @@ class GenericReferenceField(BaseField):
""" """
def __get__(self, instance, owner): def __get__(self, instance, owner):
global _model_registry
if instance is None: if instance is None:
return self return self
value = instance._data.get(self.name) value = instance._data.get(self.name)
if isinstance(value, pymongo.dbref.DBRef): if isinstance(value, (dict, pymongo.son.SON)):
value = _get_db().dereference(value) instance._data[self.name] = self.dereference(value)
if value is not None:
model = _model_registry[value['_cls']]
instance._data[self.name] = model._from_son(value)
return super(GenericReferenceField, self).__get__(instance, owner) return super(GenericReferenceField, self).__get__(instance, owner)
def dereference(self, value):
doc_cls = get_document(value['_cls'])
reference = value['_ref']
doc = _get_db().dereference(reference)
if doc is not None:
doc = doc_cls._from_son(doc)
return doc
def to_mongo(self, document): def to_mongo(self, document):
id_field_name = document.__class__._meta['id_field'] id_field_name = document.__class__._meta['id_field']
id_field = document.__class__._fields[id_field_name] id_field = document.__class__._fields[id_field_name]
@ -420,4 +420,8 @@ class GenericReferenceField(BaseField):
id_ = id_field.to_mongo(id_) id_ = id_field.to_mongo(id_)
collection = document._meta['collection'] collection = document._meta['collection']
return pymongo.dbref.DBRef(collection, id_) ref = pymongo.dbref.DBRef(collection, id_)
return {'_cls': document.__class__.__name__, '_ref': ref}
def prepare_query_value(self, op, value):
return self.to_mongo(value)['_ref']

View File

@ -322,6 +322,9 @@ class QuerySet(object):
# 'in', 'nin' and 'all' require a list of values # 'in', 'nin' and 'all' require a list of values
value = [field.prepare_query_value(op, v) for v in value] value = [field.prepare_query_value(op, v) for v in value]
if field.__class__.__name__ == 'GenericReferenceField':
parts.append('_ref')
if op and op not in match_operators: if op and op not in match_operators:
value = {'$' + op: value} value = {'$' + op: value}

View File

@ -375,21 +375,21 @@ class FieldTest(unittest.TestCase):
post2 = BlogPost(title='post 2', author=m2) post2 = BlogPost(title='post 2', author=m2)
post2.save() post2.save()
post = BlogPost.objects(author=m1.id).first() post = BlogPost.objects(author=m1).first()
self.assertEqual(post.id, post1.id) self.assertEqual(post.id, post1.id)
post = BlogPost.objects(author=m2.id).first() post = BlogPost.objects(author=m2).first()
self.assertEqual(post.id, post2.id) self.assertEqual(post.id, post2.id)
Member.drop_collection() Member.drop_collection()
BlogPost.drop_collection() BlogPost.drop_collection()
def test_generic_reference(self): def test_generic_reference(self):
"""Ensure that a GenericReferenceField properly dereferences """Ensure that a GenericReferenceField properly dereferences items.
relationships to *any* model.
""" """
class Link(Document): class Link(Document):
title = StringField() title = StringField()
meta = {'allow_inheritance': False}
class Post(Document): class Post(Document):
title = StringField() title = StringField()
@ -410,7 +410,7 @@ class FieldTest(unittest.TestCase):
bm = Bookmark(bookmark_object=post_1) bm = Bookmark(bookmark_object=post_1)
bm.save() bm.save()
bm.reload() bm = Bookmark.objects(bookmark_object=post_1).first()
self.assertEqual(bm.bookmark_object, post_1) self.assertEqual(bm.bookmark_object, post_1)
self.assertTrue(isinstance(bm.bookmark_object, Post)) self.assertTrue(isinstance(bm.bookmark_object, Post))
@ -418,7 +418,7 @@ class FieldTest(unittest.TestCase):
bm.bookmark_object = link_1 bm.bookmark_object = link_1
bm.save() bm.save()
bm.reload() bm = Bookmark.objects(bookmark_object=link_1).first()
self.assertEqual(bm.bookmark_object, link_1) self.assertEqual(bm.bookmark_object, link_1)
self.assertTrue(isinstance(bm.bookmark_object, Link)) self.assertTrue(isinstance(bm.bookmark_object, Link))
@ -428,8 +428,7 @@ class FieldTest(unittest.TestCase):
Bookmark.drop_collection() Bookmark.drop_collection()
def test_generic_reference_list(self): def test_generic_reference_list(self):
"""Ensure that a ListField properly dereferences """Ensure that a ListField properly dereferences generic references.
relationships to *any* model via GenericReferenceField.
""" """
class Link(Document): class Link(Document):
title = StringField() title = StringField()
@ -453,7 +452,7 @@ class FieldTest(unittest.TestCase):
user = User(bookmarks=[post_1, link_1]) user = User(bookmarks=[post_1, link_1])
user.save() user.save()
user.reload() user = User.objects(bookmarks__all=[post_1, link_1]).first()
self.assertEqual(user.bookmarks[0], post_1) self.assertEqual(user.bookmarks[0], post_1)
self.assertEqual(user.bookmarks[1], link_1) self.assertEqual(user.bookmarks[1], link_1)