Changed how GenericReferenceFields are stored / queried
This commit is contained in:
parent
81dd5adccf
commit
95a7b33fb4
@ -3,7 +3,10 @@ from queryset import QuerySet, QuerySetManager
|
|||||||
import pymongo
|
import pymongo
|
||||||
|
|
||||||
|
|
||||||
_model_registry = {}
|
_document_registry = {}
|
||||||
|
|
||||||
|
def get_document(name):
|
||||||
|
return _document_registry[name]
|
||||||
|
|
||||||
|
|
||||||
class ValidationError(Exception):
|
class ValidationError(Exception):
|
||||||
@ -153,7 +156,11 @@ class DocumentMetaclass(type):
|
|||||||
doc_fields[attr_name] = attr_value
|
doc_fields[attr_name] = attr_value
|
||||||
attrs['_fields'] = doc_fields
|
attrs['_fields'] = doc_fields
|
||||||
|
|
||||||
return super_new(cls, name, bases, attrs)
|
new_class = super_new(cls, name, bases, attrs)
|
||||||
|
for field in new_class._fields.values():
|
||||||
|
field.owner_document = new_class
|
||||||
|
|
||||||
|
return new_class
|
||||||
|
|
||||||
|
|
||||||
class TopLevelDocumentMetaclass(DocumentMetaclass):
|
class TopLevelDocumentMetaclass(DocumentMetaclass):
|
||||||
@ -162,7 +169,7 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __new__(cls, name, bases, attrs):
|
def __new__(cls, name, bases, attrs):
|
||||||
global _model_registry
|
global _document_registry
|
||||||
|
|
||||||
super_new = super(TopLevelDocumentMetaclass, cls).__new__
|
super_new = super(TopLevelDocumentMetaclass, cls).__new__
|
||||||
# Classes defined in this package are abstract and should not have
|
# Classes defined in this package are abstract and should not have
|
||||||
@ -252,7 +259,7 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
|
|||||||
new_class._meta['id_field'] = 'id'
|
new_class._meta['id_field'] = 'id'
|
||||||
new_class.id = new_class._fields['id'] = ObjectIdField(name='_id')
|
new_class.id = new_class._fields['id'] = ObjectIdField(name='_id')
|
||||||
|
|
||||||
_model_registry[name] = new_class
|
_document_registry[name] = new_class
|
||||||
|
|
||||||
return new_class
|
return new_class
|
||||||
|
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from base import BaseField, ObjectIdField, ValidationError, _model_registry
|
from base import BaseField, ObjectIdField, ValidationError, get_document
|
||||||
from document import Document, EmbeddedDocument
|
from document import Document, EmbeddedDocument
|
||||||
from connection import _get_db
|
from connection import _get_db
|
||||||
|
|
||||||
@ -241,8 +241,6 @@ class ListField(BaseField):
|
|||||||
def __get__(self, instance, owner):
|
def __get__(self, instance, owner):
|
||||||
"""Descriptor to automatically dereference references.
|
"""Descriptor to automatically dereference references.
|
||||||
"""
|
"""
|
||||||
global _model_registry
|
|
||||||
|
|
||||||
if instance is None:
|
if instance is None:
|
||||||
# Document class being used rather than a document object
|
# Document class being used rather than a document object
|
||||||
return self
|
return self
|
||||||
@ -268,10 +266,8 @@ class ListField(BaseField):
|
|||||||
deref_list = []
|
deref_list = []
|
||||||
for value in value_list:
|
for value in value_list:
|
||||||
# Dereference DBRefs
|
# Dereference DBRefs
|
||||||
if isinstance(value, pymongo.dbref.DBRef):
|
if isinstance(value, (dict, pymongo.son.SON)):
|
||||||
value = _get_db().dereference(value)
|
deref_list.append(self.field.dereference(value))
|
||||||
referenced_type = _model_registry[value['_cls']]
|
|
||||||
deref_list.append(referenced_type._from_son(value))
|
|
||||||
else:
|
else:
|
||||||
deref_list.append(value)
|
deref_list.append(value)
|
||||||
instance._data[self.name] = deref_list
|
instance._data[self.name] = deref_list
|
||||||
@ -334,9 +330,10 @@ class ReferenceField(BaseField):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, document_type, **kwargs):
|
def __init__(self, document_type, **kwargs):
|
||||||
if not issubclass(document_type, Document):
|
if not isinstance(document_type, basestring):
|
||||||
|
if not issubclass(document_type, (Document, basestring)):
|
||||||
raise ValidationError('Argument to ReferenceField constructor '
|
raise ValidationError('Argument to ReferenceField constructor '
|
||||||
'must be a top level document class')
|
'must be a document class or a string')
|
||||||
self.document_type = document_type
|
self.document_type = document_type
|
||||||
self.document_obj = None
|
self.document_obj = None
|
||||||
super(ReferenceField, self).__init__(**kwargs)
|
super(ReferenceField, self).__init__(**kwargs)
|
||||||
@ -391,20 +388,23 @@ class GenericReferenceField(BaseField):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __get__(self, instance, owner):
|
def __get__(self, instance, owner):
|
||||||
global _model_registry
|
|
||||||
|
|
||||||
if instance is None:
|
if instance is None:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
value = instance._data.get(self.name)
|
value = instance._data.get(self.name)
|
||||||
if isinstance(value, pymongo.dbref.DBRef):
|
if isinstance(value, (dict, pymongo.son.SON)):
|
||||||
value = _get_db().dereference(value)
|
instance._data[self.name] = self.dereference(value)
|
||||||
if value is not None:
|
|
||||||
model = _model_registry[value['_cls']]
|
|
||||||
instance._data[self.name] = model._from_son(value)
|
|
||||||
|
|
||||||
return super(GenericReferenceField, self).__get__(instance, owner)
|
return super(GenericReferenceField, self).__get__(instance, owner)
|
||||||
|
|
||||||
|
def dereference(self, value):
|
||||||
|
doc_cls = get_document(value['_cls'])
|
||||||
|
reference = value['_ref']
|
||||||
|
doc = _get_db().dereference(reference)
|
||||||
|
if doc is not None:
|
||||||
|
doc = doc_cls._from_son(doc)
|
||||||
|
return doc
|
||||||
|
|
||||||
def to_mongo(self, document):
|
def to_mongo(self, document):
|
||||||
id_field_name = document.__class__._meta['id_field']
|
id_field_name = document.__class__._meta['id_field']
|
||||||
id_field = document.__class__._fields[id_field_name]
|
id_field = document.__class__._fields[id_field_name]
|
||||||
@ -420,4 +420,8 @@ class GenericReferenceField(BaseField):
|
|||||||
|
|
||||||
id_ = id_field.to_mongo(id_)
|
id_ = id_field.to_mongo(id_)
|
||||||
collection = document._meta['collection']
|
collection = document._meta['collection']
|
||||||
return pymongo.dbref.DBRef(collection, id_)
|
ref = pymongo.dbref.DBRef(collection, id_)
|
||||||
|
return {'_cls': document.__class__.__name__, '_ref': ref}
|
||||||
|
|
||||||
|
def prepare_query_value(self, op, value):
|
||||||
|
return self.to_mongo(value)['_ref']
|
||||||
|
@ -322,6 +322,9 @@ class QuerySet(object):
|
|||||||
# 'in', 'nin' and 'all' require a list of values
|
# 'in', 'nin' and 'all' require a list of values
|
||||||
value = [field.prepare_query_value(op, v) for v in value]
|
value = [field.prepare_query_value(op, v) for v in value]
|
||||||
|
|
||||||
|
if field.__class__.__name__ == 'GenericReferenceField':
|
||||||
|
parts.append('_ref')
|
||||||
|
|
||||||
if op and op not in match_operators:
|
if op and op not in match_operators:
|
||||||
value = {'$' + op: value}
|
value = {'$' + op: value}
|
||||||
|
|
||||||
|
@ -375,21 +375,21 @@ class FieldTest(unittest.TestCase):
|
|||||||
post2 = BlogPost(title='post 2', author=m2)
|
post2 = BlogPost(title='post 2', author=m2)
|
||||||
post2.save()
|
post2.save()
|
||||||
|
|
||||||
post = BlogPost.objects(author=m1.id).first()
|
post = BlogPost.objects(author=m1).first()
|
||||||
self.assertEqual(post.id, post1.id)
|
self.assertEqual(post.id, post1.id)
|
||||||
|
|
||||||
post = BlogPost.objects(author=m2.id).first()
|
post = BlogPost.objects(author=m2).first()
|
||||||
self.assertEqual(post.id, post2.id)
|
self.assertEqual(post.id, post2.id)
|
||||||
|
|
||||||
Member.drop_collection()
|
Member.drop_collection()
|
||||||
BlogPost.drop_collection()
|
BlogPost.drop_collection()
|
||||||
|
|
||||||
def test_generic_reference(self):
|
def test_generic_reference(self):
|
||||||
"""Ensure that a GenericReferenceField properly dereferences
|
"""Ensure that a GenericReferenceField properly dereferences items.
|
||||||
relationships to *any* model.
|
|
||||||
"""
|
"""
|
||||||
class Link(Document):
|
class Link(Document):
|
||||||
title = StringField()
|
title = StringField()
|
||||||
|
meta = {'allow_inheritance': False}
|
||||||
|
|
||||||
class Post(Document):
|
class Post(Document):
|
||||||
title = StringField()
|
title = StringField()
|
||||||
@ -410,7 +410,7 @@ class FieldTest(unittest.TestCase):
|
|||||||
bm = Bookmark(bookmark_object=post_1)
|
bm = Bookmark(bookmark_object=post_1)
|
||||||
bm.save()
|
bm.save()
|
||||||
|
|
||||||
bm.reload()
|
bm = Bookmark.objects(bookmark_object=post_1).first()
|
||||||
|
|
||||||
self.assertEqual(bm.bookmark_object, post_1)
|
self.assertEqual(bm.bookmark_object, post_1)
|
||||||
self.assertTrue(isinstance(bm.bookmark_object, Post))
|
self.assertTrue(isinstance(bm.bookmark_object, Post))
|
||||||
@ -418,7 +418,7 @@ class FieldTest(unittest.TestCase):
|
|||||||
bm.bookmark_object = link_1
|
bm.bookmark_object = link_1
|
||||||
bm.save()
|
bm.save()
|
||||||
|
|
||||||
bm.reload()
|
bm = Bookmark.objects(bookmark_object=link_1).first()
|
||||||
|
|
||||||
self.assertEqual(bm.bookmark_object, link_1)
|
self.assertEqual(bm.bookmark_object, link_1)
|
||||||
self.assertTrue(isinstance(bm.bookmark_object, Link))
|
self.assertTrue(isinstance(bm.bookmark_object, Link))
|
||||||
@ -428,8 +428,7 @@ class FieldTest(unittest.TestCase):
|
|||||||
Bookmark.drop_collection()
|
Bookmark.drop_collection()
|
||||||
|
|
||||||
def test_generic_reference_list(self):
|
def test_generic_reference_list(self):
|
||||||
"""Ensure that a ListField properly dereferences
|
"""Ensure that a ListField properly dereferences generic references.
|
||||||
relationships to *any* model via GenericReferenceField.
|
|
||||||
"""
|
"""
|
||||||
class Link(Document):
|
class Link(Document):
|
||||||
title = StringField()
|
title = StringField()
|
||||||
@ -453,7 +452,7 @@ class FieldTest(unittest.TestCase):
|
|||||||
user = User(bookmarks=[post_1, link_1])
|
user = User(bookmarks=[post_1, link_1])
|
||||||
user.save()
|
user.save()
|
||||||
|
|
||||||
user.reload()
|
user = User.objects(bookmarks__all=[post_1, link_1]).first()
|
||||||
|
|
||||||
self.assertEqual(user.bookmarks[0], post_1)
|
self.assertEqual(user.bookmarks[0], post_1)
|
||||||
self.assertEqual(user.bookmarks[1], link_1)
|
self.assertEqual(user.bookmarks[1], link_1)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user