diff --git a/mongomap/base.py b/mongomap/base.py index be4fd324..b180231a 100644 --- a/mongomap/base.py +++ b/mongomap/base.py @@ -97,10 +97,18 @@ class DocumentMetaclass(type): return super_new(cls, name, bases, attrs) doc_fields = {} - # Include all fields present in superclasses + class_name = [name] + superclasses = {} for base in bases: + # Include all fields present in superclasses if hasattr(base, '_fields'): doc_fields.update(base._fields) + class_name.append(base._class_name) + # Get superclasses from superclass + superclasses[base._class_name] = base + superclasses.update(base._superclasses) + attrs['_class_name'] = '.'.join(reversed(class_name)) + attrs['_superclasses'] = superclasses # Add the document's fields to the _fields attribute for attr_name, attr_value in attrs.items(): @@ -164,6 +172,21 @@ class BaseDocument(object): # Use default value setattr(self, attr_name, getattr(self, attr_name, None)) + @classmethod + def _get_subclasses(cls): + """Return a dictionary of all subclasses (found recursively). + """ + try: + subclasses = cls.__subclasses__() + except: + subclasses = cls.__subclasses__(cls) + + all_subclasses = {} + for subclass in subclasses: + all_subclasses[subclass._class_name] = subclass + all_subclasses.update(subclass._get_subclasses()) + return all_subclasses + def __iter__(self): # Use _data rather than _fields as iterator only looks at names so # values don't need to be converted to Python types @@ -203,12 +226,25 @@ class BaseDocument(object): value = getattr(self, field_name, None) if value is not None: data[field_name] = field._to_mongo(value) + data['_cls'] = self._class_name + data['_types'] = self._superclasses.keys() + [self._class_name] return data @classmethod def _from_son(cls, son): """Create an instance of a Document (subclass) from a PyMongo SOM. """ + class_name = son[u'_cls'] data = dict((str(key), value) for key, value in son.items()) + del data['_cls'] + + # Return correct subclass for document type + if class_name != cls._class_name: + subclasses = cls._get_subclasses() + if class_name not in subclasses: + # Type of document is probably more generic than the class + # that has been queried to return this SON + return None + cls = subclasses[class_name] return cls(**data) diff --git a/mongomap/collection.py b/mongomap/collection.py index db4c0760..3f0ae8f8 100644 --- a/mongomap/collection.py +++ b/mongomap/collection.py @@ -76,13 +76,13 @@ class CollectionManager(object): def find(self, **query): """Query the collection for document matching the provided query. """ - if query: - query = self._transform_query(**query) + query = self._transform_query(**query) + query['_types'] = self._document._class_name return QuerySet(self._document, self._collection.find(query)) def find_one(self, **query): """Query the collection for document matching the provided query. """ - if query: - query = self._transform_query(**query) + query = self._transform_query(**query) + query['_types'] = self._document._class_name return self._document._from_son(self._collection.find_one(query)) diff --git a/mongomap/fields.py b/mongomap/fields.py index 52a91d13..f65aa38a 100644 --- a/mongomap/fields.py +++ b/mongomap/fields.py @@ -69,6 +69,7 @@ class EmbeddedDocumentField(BaseField): def _to_python(self, value): if not isinstance(value, self.document): + assert(isinstance(value, (dict, pymongo.son.SON))) return self.document._from_son(value) return value @@ -107,11 +108,6 @@ class ListField(BaseField): def _validate(self, value): """Make sure that a list of valid fields is being used. """ -# print -# print value -# print type(value) -# print isinstance(value, list) -# print if not isinstance(value, (list, tuple)): raise ValidationError('Only lists and tuples may be used in a ' 'list field') diff --git a/tests/collection.py b/tests/collection.py index 4de6cbbc..f5b69a3c 100644 --- a/tests/collection.py +++ b/tests/collection.py @@ -107,6 +107,8 @@ class CollectionManagerTest(unittest.TestCase): content = StringField() author = EmbeddedDocumentField(User) + self.db.drop_collection(BlogPost._meta['collection']) + post = BlogPost(content='Had a good coffee today...') post.author = User(name='Test User') post.save() @@ -114,6 +116,8 @@ class CollectionManagerTest(unittest.TestCase): result = BlogPost.objects.find_one() self.assertTrue(isinstance(result.author, User)) self.assertEqual(result.author.name, 'Test User') + + self.db.drop_collection(BlogPost._meta['collection']) if __name__ == '__main__': diff --git a/tests/document.py b/tests/document.py index 738fef15..e389b322 100644 --- a/tests/document.py +++ b/tests/document.py @@ -38,6 +38,75 @@ class DocumentTest(unittest.TestCase): # Ensure Document isn't treated like an actual document self.assertFalse(hasattr(Document, '_fields')) + def test_get_superclasses(self): + """Ensure that the correct list of superclasses is assembled. + """ + class Animal(Document): pass + class Fish(Animal): pass + class Mammal(Animal): pass + class Human(Mammal): pass + class Dog(Mammal): pass + + mammal_superclasses = {'Animal': Animal} + self.assertEqual(Mammal._superclasses, mammal_superclasses) + + dog_superclasses = { + 'Animal': Animal, + 'Animal.Mammal': Mammal, + } + self.assertEqual(Dog._superclasses, dog_superclasses) + + def test_get_subclasses(self): + """Ensure that the correct list of subclasses is retrieved by the + _get_subclasses method. + """ + class Animal(Document): pass + class Fish(Animal): pass + class Mammal(Animal): pass + class Human(Mammal): pass + class Dog(Mammal): pass + + mammal_subclasses = { + 'Animal.Mammal.Dog': Dog, + 'Animal.Mammal.Human': Human + } + self.assertEqual(Mammal._get_subclasses(), mammal_subclasses) + + animal_subclasses = { + 'Animal.Fish': Fish, + 'Animal.Mammal': Mammal, + 'Animal.Mammal.Dog': Dog, + 'Animal.Mammal.Human': Human + } + self.assertEqual(Animal._get_subclasses(), animal_subclasses) + + def test_polymorphic_queries(self): + """Ensure that the correct subclasses are returned from a query""" + class Animal(Document): pass + class Fish(Animal): pass + class Mammal(Animal): pass + class Human(Mammal): pass + class Dog(Mammal): pass + + self.db.drop_collection(Animal._meta['collection']) + + Animal().save() + Fish().save() + Mammal().save() + Human().save() + Dog().save() + + classes = [obj.__class__ for obj in Animal.objects.find()] + self.assertEqual(classes, [Animal, Fish, Mammal, Human, Dog]) + + classes = [obj.__class__ for obj in Mammal.objects.find()] + self.assertEqual(classes, [Mammal, Human, Dog]) + + classes = [obj.__class__ for obj in Human.objects.find()] + self.assertEqual(classes, [Human]) + + self.db.drop_collection(Animal._meta['collection']) + def test_inheritance(self): """Ensure that document may inherit fields from a superclass document. """ @@ -122,6 +191,8 @@ class DocumentTest(unittest.TestCase): comments = ListField(EmbeddedDocumentField(Comment)) tags = ListField(StringField()) + self.db.drop_collection(BlogPost._meta['collection']) + post = BlogPost(content='Went for a walk today...') post.tags = tags = ['fun', 'leisure'] comments = [Comment(content='Good for you'), Comment(content='Yay.')] @@ -134,6 +205,8 @@ class DocumentTest(unittest.TestCase): for comment_obj, comment in zip(post_obj['comments'], comments): self.assertEqual(comment_obj['content'], comment['content']) + self.db.drop_collection(BlogPost._meta['collection']) + def test_save_embedded_document(self): """Ensure that a document with an embedded document field may be saved in the database.