diff --git a/mongoengine/document.py b/mongoengine/document.py index e83dd452..315776aa 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -1,4 +1,5 @@ from base import DocumentMetaclass, TopLevelDocumentMetaclass, BaseDocument +from connection import _get_db __all__ = ['Document', 'EmbeddedDocument'] @@ -18,3 +19,11 @@ class Document(BaseDocument): it will be updated, otherwise it will be created. """ self.objects._save_document(self) + + @classmethod + def drop_collection(cls): + """Drops the entire collection associated with this Document type from + the database. + """ + db = _get_db() + db.drop_collection(cls._meta['collection']) diff --git a/tests/collection.py b/tests/collection.py index 164ba32a..d77f304c 100644 --- a/tests/collection.py +++ b/tests/collection.py @@ -2,7 +2,6 @@ import unittest import pymongo from mongoengine.collection import CollectionManager, QuerySet -from mongoengine.connection import _get_db from mongoengine import * @@ -16,9 +15,6 @@ class CollectionManagerTest(unittest.TestCase): age = IntField() self.Person = Person - self.db = _get_db() - self.db.drop_collection(self.Person._meta['collection']) - def test_initialisation(self): """Ensure that CollectionManager is correctly initialised. """ @@ -112,8 +108,6 @@ class CollectionManagerTest(unittest.TestCase): content = StringField() author = EmbeddedDocumentField(User) - self.db.drop_collection(BlogPost._meta['collection']) - post = BlogPost(content='Had a good coffee today...') post.author = User(name='Test User') post.save() @@ -122,7 +116,7 @@ class CollectionManagerTest(unittest.TestCase): self.assertTrue(isinstance(result.author, User)) self.assertEqual(result.author.name, 'Test User') - self.db.drop_collection(BlogPost._meta['collection']) + BlogPost.drop_collection() def test_delete(self): """Ensure that documents are properly deleted from the database. @@ -139,6 +133,9 @@ class CollectionManagerTest(unittest.TestCase): self.Person.objects.find().delete() self.assertEqual(self.Person.objects.find().count(), 0) + def tearDown(self): + self.Person.drop_collection() + if __name__ == '__main__': unittest.main() diff --git a/tests/document.py b/tests/document.py index 6e97cc8c..2f2fef3c 100644 --- a/tests/document.py +++ b/tests/document.py @@ -9,14 +9,23 @@ class DocumentTest(unittest.TestCase): def setUp(self): connect(db='mongoenginetest') + self.db = _get_db() class Person(Document): name = StringField() age = IntField() self.Person = Person - self.db = _get_db() - self.db.drop_collection(self.Person._meta['collection']) + def test_drop_collection(self): + """Ensure that the collection may be dropped from the database. + """ + self.Person(name='Test').save() + + collection = self.Person._meta['collection'] + self.assertTrue(collection in self.db.collection_names()) + + self.Person.drop_collection() + self.assertFalse(collection in self.db.collection_names()) def test_definition(self): """Ensure that document may be defined using fields. @@ -89,8 +98,6 @@ class DocumentTest(unittest.TestCase): class Human(Mammal): pass class Dog(Mammal): pass - self.db.drop_collection(Animal._meta['collection']) - Animal().save() Fish().save() Mammal().save() @@ -106,7 +113,7 @@ class DocumentTest(unittest.TestCase): classes = [obj.__class__ for obj in Human.objects.find()] self.assertEqual(classes, [Human]) - self.db.drop_collection(Animal._meta['collection']) + Animal.drop_collection() def test_inheritance(self): """Ensure that document may inherit fields from a superclass document. @@ -192,8 +199,6 @@ class DocumentTest(unittest.TestCase): comments = ListField(EmbeddedDocumentField(Comment)) tags = ListField(StringField()) - self.db.drop_collection(BlogPost._meta['collection']) - post = BlogPost(content='Went for a walk today...') post.tags = tags = ['fun', 'leisure'] comments = [Comment(content='Good for you'), Comment(content='Yay.')] @@ -206,7 +211,7 @@ class DocumentTest(unittest.TestCase): for comment_obj, comment in zip(post_obj['comments'], comments): self.assertEqual(comment_obj['content'], comment['content']) - self.db.drop_collection(BlogPost._meta['collection']) + BlogPost.drop_collection() def test_save_embedded_document(self): """Ensure that a document with an embedded document field may be @@ -241,8 +246,6 @@ class DocumentTest(unittest.TestCase): content = StringField() author = ReferenceField(self.Person) - self.db.drop_collection(BlogPost._meta['collection']) - author = self.Person(name='Test User') author.save() @@ -266,10 +269,10 @@ class DocumentTest(unittest.TestCase): author = self.Person.objects.find_one(name='Test User') self.assertEqual(author.age, 25) - self.db.drop_collection(BlogPost._meta['collection']) + BlogPost.drop_collection() def tearDown(self): - self.db.drop_collection(self.Person._meta['collection']) + self.Person.drop_collection() if __name__ == '__main__': diff --git a/tests/fields.py b/tests/fields.py index e5b6603e..2777a89e 100644 --- a/tests/fields.py +++ b/tests/fields.py @@ -177,8 +177,8 @@ class FieldTest(unittest.TestCase): post2.save() self.assertRaises(ValidationError, post1.__setattr__, 'author', post2) - self.db.drop_collection(User._meta['collection']) - self.db.drop_collection(BlogPost._meta['collection']) + User.drop_collection() + BlogPost.drop_collection() if __name__ == '__main__':