Queries now return correct subclasses of Documents
This commit is contained in:
parent
0674e3c013
commit
744077b150
@ -97,10 +97,18 @@ class DocumentMetaclass(type):
|
|||||||
return super_new(cls, name, bases, attrs)
|
return super_new(cls, name, bases, attrs)
|
||||||
|
|
||||||
doc_fields = {}
|
doc_fields = {}
|
||||||
# Include all fields present in superclasses
|
class_name = [name]
|
||||||
|
superclasses = {}
|
||||||
for base in bases:
|
for base in bases:
|
||||||
|
# Include all fields present in superclasses
|
||||||
if hasattr(base, '_fields'):
|
if hasattr(base, '_fields'):
|
||||||
doc_fields.update(base._fields)
|
doc_fields.update(base._fields)
|
||||||
|
class_name.append(base._class_name)
|
||||||
|
# Get superclasses from superclass
|
||||||
|
superclasses[base._class_name] = base
|
||||||
|
superclasses.update(base._superclasses)
|
||||||
|
attrs['_class_name'] = '.'.join(reversed(class_name))
|
||||||
|
attrs['_superclasses'] = superclasses
|
||||||
|
|
||||||
# Add the document's fields to the _fields attribute
|
# Add the document's fields to the _fields attribute
|
||||||
for attr_name, attr_value in attrs.items():
|
for attr_name, attr_value in attrs.items():
|
||||||
@ -164,6 +172,21 @@ class BaseDocument(object):
|
|||||||
# Use default value
|
# Use default value
|
||||||
setattr(self, attr_name, getattr(self, attr_name, None))
|
setattr(self, attr_name, getattr(self, attr_name, None))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_subclasses(cls):
|
||||||
|
"""Return a dictionary of all subclasses (found recursively).
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
subclasses = cls.__subclasses__()
|
||||||
|
except:
|
||||||
|
subclasses = cls.__subclasses__(cls)
|
||||||
|
|
||||||
|
all_subclasses = {}
|
||||||
|
for subclass in subclasses:
|
||||||
|
all_subclasses[subclass._class_name] = subclass
|
||||||
|
all_subclasses.update(subclass._get_subclasses())
|
||||||
|
return all_subclasses
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
# Use _data rather than _fields as iterator only looks at names so
|
# Use _data rather than _fields as iterator only looks at names so
|
||||||
# values don't need to be converted to Python types
|
# values don't need to be converted to Python types
|
||||||
@ -203,12 +226,25 @@ class BaseDocument(object):
|
|||||||
value = getattr(self, field_name, None)
|
value = getattr(self, field_name, None)
|
||||||
if value is not None:
|
if value is not None:
|
||||||
data[field_name] = field._to_mongo(value)
|
data[field_name] = field._to_mongo(value)
|
||||||
|
data['_cls'] = self._class_name
|
||||||
|
data['_types'] = self._superclasses.keys() + [self._class_name]
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _from_son(cls, son):
|
def _from_son(cls, son):
|
||||||
"""Create an instance of a Document (subclass) from a PyMongo SOM.
|
"""Create an instance of a Document (subclass) from a PyMongo SOM.
|
||||||
"""
|
"""
|
||||||
|
class_name = son[u'_cls']
|
||||||
data = dict((str(key), value) for key, value in son.items())
|
data = dict((str(key), value) for key, value in son.items())
|
||||||
|
del data['_cls']
|
||||||
|
|
||||||
|
# Return correct subclass for document type
|
||||||
|
if class_name != cls._class_name:
|
||||||
|
subclasses = cls._get_subclasses()
|
||||||
|
if class_name not in subclasses:
|
||||||
|
# Type of document is probably more generic than the class
|
||||||
|
# that has been queried to return this SON
|
||||||
|
return None
|
||||||
|
cls = subclasses[class_name]
|
||||||
return cls(**data)
|
return cls(**data)
|
||||||
|
|
||||||
|
@ -76,13 +76,13 @@ class CollectionManager(object):
|
|||||||
def find(self, **query):
|
def find(self, **query):
|
||||||
"""Query the collection for document matching the provided query.
|
"""Query the collection for document matching the provided query.
|
||||||
"""
|
"""
|
||||||
if query:
|
|
||||||
query = self._transform_query(**query)
|
query = self._transform_query(**query)
|
||||||
|
query['_types'] = self._document._class_name
|
||||||
return QuerySet(self._document, self._collection.find(query))
|
return QuerySet(self._document, self._collection.find(query))
|
||||||
|
|
||||||
def find_one(self, **query):
|
def find_one(self, **query):
|
||||||
"""Query the collection for document matching the provided query.
|
"""Query the collection for document matching the provided query.
|
||||||
"""
|
"""
|
||||||
if query:
|
|
||||||
query = self._transform_query(**query)
|
query = self._transform_query(**query)
|
||||||
|
query['_types'] = self._document._class_name
|
||||||
return self._document._from_son(self._collection.find_one(query))
|
return self._document._from_son(self._collection.find_one(query))
|
||||||
|
@ -69,6 +69,7 @@ class EmbeddedDocumentField(BaseField):
|
|||||||
|
|
||||||
def _to_python(self, value):
|
def _to_python(self, value):
|
||||||
if not isinstance(value, self.document):
|
if not isinstance(value, self.document):
|
||||||
|
assert(isinstance(value, (dict, pymongo.son.SON)))
|
||||||
return self.document._from_son(value)
|
return self.document._from_son(value)
|
||||||
return value
|
return value
|
||||||
|
|
||||||
@ -107,11 +108,6 @@ class ListField(BaseField):
|
|||||||
def _validate(self, value):
|
def _validate(self, value):
|
||||||
"""Make sure that a list of valid fields is being used.
|
"""Make sure that a list of valid fields is being used.
|
||||||
"""
|
"""
|
||||||
# print
|
|
||||||
# print value
|
|
||||||
# print type(value)
|
|
||||||
# print isinstance(value, list)
|
|
||||||
# print
|
|
||||||
if not isinstance(value, (list, tuple)):
|
if not isinstance(value, (list, tuple)):
|
||||||
raise ValidationError('Only lists and tuples may be used in a '
|
raise ValidationError('Only lists and tuples may be used in a '
|
||||||
'list field')
|
'list field')
|
||||||
|
@ -107,6 +107,8 @@ class CollectionManagerTest(unittest.TestCase):
|
|||||||
content = StringField()
|
content = StringField()
|
||||||
author = EmbeddedDocumentField(User)
|
author = EmbeddedDocumentField(User)
|
||||||
|
|
||||||
|
self.db.drop_collection(BlogPost._meta['collection'])
|
||||||
|
|
||||||
post = BlogPost(content='Had a good coffee today...')
|
post = BlogPost(content='Had a good coffee today...')
|
||||||
post.author = User(name='Test User')
|
post.author = User(name='Test User')
|
||||||
post.save()
|
post.save()
|
||||||
@ -115,6 +117,8 @@ class CollectionManagerTest(unittest.TestCase):
|
|||||||
self.assertTrue(isinstance(result.author, User))
|
self.assertTrue(isinstance(result.author, User))
|
||||||
self.assertEqual(result.author.name, 'Test User')
|
self.assertEqual(result.author.name, 'Test User')
|
||||||
|
|
||||||
|
self.db.drop_collection(BlogPost._meta['collection'])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
@ -38,6 +38,75 @@ class DocumentTest(unittest.TestCase):
|
|||||||
# Ensure Document isn't treated like an actual document
|
# Ensure Document isn't treated like an actual document
|
||||||
self.assertFalse(hasattr(Document, '_fields'))
|
self.assertFalse(hasattr(Document, '_fields'))
|
||||||
|
|
||||||
|
def test_get_superclasses(self):
|
||||||
|
"""Ensure that the correct list of superclasses is assembled.
|
||||||
|
"""
|
||||||
|
class Animal(Document): pass
|
||||||
|
class Fish(Animal): pass
|
||||||
|
class Mammal(Animal): pass
|
||||||
|
class Human(Mammal): pass
|
||||||
|
class Dog(Mammal): pass
|
||||||
|
|
||||||
|
mammal_superclasses = {'Animal': Animal}
|
||||||
|
self.assertEqual(Mammal._superclasses, mammal_superclasses)
|
||||||
|
|
||||||
|
dog_superclasses = {
|
||||||
|
'Animal': Animal,
|
||||||
|
'Animal.Mammal': Mammal,
|
||||||
|
}
|
||||||
|
self.assertEqual(Dog._superclasses, dog_superclasses)
|
||||||
|
|
||||||
|
def test_get_subclasses(self):
|
||||||
|
"""Ensure that the correct list of subclasses is retrieved by the
|
||||||
|
_get_subclasses method.
|
||||||
|
"""
|
||||||
|
class Animal(Document): pass
|
||||||
|
class Fish(Animal): pass
|
||||||
|
class Mammal(Animal): pass
|
||||||
|
class Human(Mammal): pass
|
||||||
|
class Dog(Mammal): pass
|
||||||
|
|
||||||
|
mammal_subclasses = {
|
||||||
|
'Animal.Mammal.Dog': Dog,
|
||||||
|
'Animal.Mammal.Human': Human
|
||||||
|
}
|
||||||
|
self.assertEqual(Mammal._get_subclasses(), mammal_subclasses)
|
||||||
|
|
||||||
|
animal_subclasses = {
|
||||||
|
'Animal.Fish': Fish,
|
||||||
|
'Animal.Mammal': Mammal,
|
||||||
|
'Animal.Mammal.Dog': Dog,
|
||||||
|
'Animal.Mammal.Human': Human
|
||||||
|
}
|
||||||
|
self.assertEqual(Animal._get_subclasses(), animal_subclasses)
|
||||||
|
|
||||||
|
def test_polymorphic_queries(self):
|
||||||
|
"""Ensure that the correct subclasses are returned from a query"""
|
||||||
|
class Animal(Document): pass
|
||||||
|
class Fish(Animal): pass
|
||||||
|
class Mammal(Animal): pass
|
||||||
|
class Human(Mammal): pass
|
||||||
|
class Dog(Mammal): pass
|
||||||
|
|
||||||
|
self.db.drop_collection(Animal._meta['collection'])
|
||||||
|
|
||||||
|
Animal().save()
|
||||||
|
Fish().save()
|
||||||
|
Mammal().save()
|
||||||
|
Human().save()
|
||||||
|
Dog().save()
|
||||||
|
|
||||||
|
classes = [obj.__class__ for obj in Animal.objects.find()]
|
||||||
|
self.assertEqual(classes, [Animal, Fish, Mammal, Human, Dog])
|
||||||
|
|
||||||
|
classes = [obj.__class__ for obj in Mammal.objects.find()]
|
||||||
|
self.assertEqual(classes, [Mammal, Human, Dog])
|
||||||
|
|
||||||
|
classes = [obj.__class__ for obj in Human.objects.find()]
|
||||||
|
self.assertEqual(classes, [Human])
|
||||||
|
|
||||||
|
self.db.drop_collection(Animal._meta['collection'])
|
||||||
|
|
||||||
def test_inheritance(self):
|
def test_inheritance(self):
|
||||||
"""Ensure that document may inherit fields from a superclass document.
|
"""Ensure that document may inherit fields from a superclass document.
|
||||||
"""
|
"""
|
||||||
@ -122,6 +191,8 @@ class DocumentTest(unittest.TestCase):
|
|||||||
comments = ListField(EmbeddedDocumentField(Comment))
|
comments = ListField(EmbeddedDocumentField(Comment))
|
||||||
tags = ListField(StringField())
|
tags = ListField(StringField())
|
||||||
|
|
||||||
|
self.db.drop_collection(BlogPost._meta['collection'])
|
||||||
|
|
||||||
post = BlogPost(content='Went for a walk today...')
|
post = BlogPost(content='Went for a walk today...')
|
||||||
post.tags = tags = ['fun', 'leisure']
|
post.tags = tags = ['fun', 'leisure']
|
||||||
comments = [Comment(content='Good for you'), Comment(content='Yay.')]
|
comments = [Comment(content='Good for you'), Comment(content='Yay.')]
|
||||||
@ -134,6 +205,8 @@ class DocumentTest(unittest.TestCase):
|
|||||||
for comment_obj, comment in zip(post_obj['comments'], comments):
|
for comment_obj, comment in zip(post_obj['comments'], comments):
|
||||||
self.assertEqual(comment_obj['content'], comment['content'])
|
self.assertEqual(comment_obj['content'], comment['content'])
|
||||||
|
|
||||||
|
self.db.drop_collection(BlogPost._meta['collection'])
|
||||||
|
|
||||||
def test_save_embedded_document(self):
|
def test_save_embedded_document(self):
|
||||||
"""Ensure that a document with an embedded document field may be
|
"""Ensure that a document with an embedded document field may be
|
||||||
saved in the database.
|
saved in the database.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user