From 0624cdd6e419bf717ceb3c5c3bb108c82a953cc8 Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Tue, 11 Oct 2011 02:26:33 -0700 Subject: [PATCH] Fixes collection creation post drop_collection Thanks to Julien Rebetez for the original patch closes [#285] --- docs/changelog.rst | 1 + mongoengine/document.py | 4 ++-- mongoengine/queryset.py | 13 +++++++++++-- tests/document.py | 15 +++++++++++++++ tests/queryset.py | 16 +++++++++------- 5 files changed, 38 insertions(+), 11 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index abbc1e4c..fab2041e 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -6,6 +6,7 @@ Changelog Changes in dev ============== +- Fixed calling a queryset after drop_collection now recreates the collection - Fixed tree based circular reference bug - Add field name to validation exception messages - Added UUID field diff --git a/mongoengine/document.py b/mongoengine/document.py index ce001d2a..a87f460e 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -36,7 +36,6 @@ class EmbeddedDocument(BaseDocument): super(EmbeddedDocument, self).__delattr__(*args, **kwargs) - class Document(BaseDocument): """The base class used for defining the structure and properties of collections of documents stored in MongoDB. Inherit from this class, and @@ -81,7 +80,6 @@ class Document(BaseDocument): @classmethod def _get_collection(self): """Returns the collection for the document.""" - if not hasattr(self, '_collection') or self._collection is None: db = _get_db() collection_name = self._get_collection_name() @@ -291,8 +289,10 @@ class Document(BaseDocument): """Drops the entire collection associated with this :class:`~mongoengine.Document` type from the database. """ + from mongoengine.queryset import QuerySet db = _get_db() db.drop_collection(cls._get_collection_name()) + QuerySet._reset_already_indexed(cls) class DynamicDocument(Document): diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index e3d2b473..4c88ba24 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -434,9 +434,11 @@ class QuerySet(object): return spec @classmethod - def _reset_already_indexed(cls): + def _reset_already_indexed(cls, document=None): """Helper to reset already indexed, can be useful for testing purposes""" - cls.__already_indexed = set() + if document: + cls.__already_indexed.discard(document) + cls.__already_indexed.clear() def __call__(self, q_obj=None, class_check=True, slave_okay=False, **query): """Filter the selected documents by calling the @@ -476,6 +478,13 @@ class QuerySet(object): perform operations only if the collection is accessed. """ if self._document not in QuerySet.__already_indexed: + + # Ensure collection exists + db = _get_db() + if self._collection_obj.name not in db.collection_names(): + self._document._collection = None + self._collection_obj = self._document._get_collection() + QuerySet.__already_indexed.add(self._document) background = self._document._meta.get('index_background', False) diff --git a/tests/document.py b/tests/document.py index 1eeda46e..816bb498 100644 --- a/tests/document.py +++ b/tests/document.py @@ -41,6 +41,21 @@ class DocumentTest(unittest.TestCase): self.Person.drop_collection() self.assertFalse(collection in self.db.collection_names()) + def test_queryset_resurrects_dropped_collection(self): + + self.Person.objects().item_frequencies('name') + self.Person.drop_collection() + + self.assertEqual({}, self.Person.objects().item_frequencies('name')) + + class Actor(self.Person): + pass + + # Ensure works correctly with inhertited classes + Actor.objects().item_frequencies('name') + self.Person.drop_collection() + self.assertEqual({}, Actor.objects().item_frequencies('name')) + def test_definition(self): """Ensure that document may be defined using fields. """ diff --git a/tests/queryset.py b/tests/queryset.py index de023877..92be1d3b 100644 --- a/tests/queryset.py +++ b/tests/queryset.py @@ -15,7 +15,7 @@ class QuerySetTest(unittest.TestCase): def setUp(self): connect(db='mongoenginetest') - + class Person(Document): name = StringField() age = IntField() @@ -455,6 +455,9 @@ class QuerySetTest(unittest.TestCase): Blog.drop_collection() + # Recreates the collection + self.assertEqual(0, Blog.objects.count()) + with query_counter() as q: self.assertEqual(q, 0) @@ -468,10 +471,10 @@ class QuerySetTest(unittest.TestCase): blogs.append(Blog(title="post %s" % i, posts=[post1, post2])) Blog.objects.insert(blogs, load_bulk=False) - self.assertEqual(q, 2) # 1 for the inital connection and 1 for the insert + self.assertEqual(q, 1) # 1 for the insert Blog.objects.insert(blogs) - self.assertEqual(q, 4) # 1 for insert, and 1 for in bulk + self.assertEqual(q, 3) # 1 for insert, and 1 for in bulk fetch (3 in total) Blog.drop_collection() @@ -1840,7 +1843,6 @@ class QuerySetTest(unittest.TestCase): freq = Person.objects.item_frequencies('city', normalize=True, map_reduce=True) self.assertEquals(freq, {'CRB': 0.5, None: 0.5}) - def test_item_frequencies_with_null_embedded(self): class Data(EmbeddedDocument): name = StringField() @@ -2227,7 +2229,7 @@ class QuerySetTest(unittest.TestCase): events = Event.objects(location__within_box=box) self.assertEqual(events.count(), 1) self.assertEqual(events[0].id, event2.id) - + # check that polygon works for users who have a server >= 1.9 server_version = tuple( _get_connection().server_info()['version'].split('.') @@ -2244,7 +2246,7 @@ class QuerySetTest(unittest.TestCase): events = Event.objects(location__within_polygon=polygon) self.assertEqual(events.count(), 1) self.assertEqual(events[0].id, event1.id) - + polygon2 = [ (54.033586,-1.742249), (52.792797,-1.225891), @@ -2252,7 +2254,7 @@ class QuerySetTest(unittest.TestCase): ] events = Event.objects(location__within_polygon=polygon2) self.assertEqual(events.count(), 0) - + Event.drop_collection() def test_spherical_geospatial_operators(self):