Fixes collection creation post drop_collection

Thanks to Julien Rebetez for the original patch
closes [#285]
This commit is contained in:
Ross Lawley 2011-10-11 02:26:33 -07:00
parent 5fb9d61d28
commit 0624cdd6e4
5 changed files with 38 additions and 11 deletions

View File

@ -6,6 +6,7 @@ Changelog
Changes in dev
==============
- Fixed calling a queryset after drop_collection now recreates the collection
- Fixed tree based circular reference bug
- Add field name to validation exception messages
- Added UUID field

View File

@ -36,7 +36,6 @@ class EmbeddedDocument(BaseDocument):
super(EmbeddedDocument, self).__delattr__(*args, **kwargs)
class Document(BaseDocument):
"""The base class used for defining the structure and properties of
collections of documents stored in MongoDB. Inherit from this class, and
@ -81,7 +80,6 @@ class Document(BaseDocument):
@classmethod
def _get_collection(self):
"""Returns the collection for the document."""
if not hasattr(self, '_collection') or self._collection is None:
db = _get_db()
collection_name = self._get_collection_name()
@ -291,8 +289,10 @@ class Document(BaseDocument):
"""Drops the entire collection associated with this
:class:`~mongoengine.Document` type from the database.
"""
from mongoengine.queryset import QuerySet
db = _get_db()
db.drop_collection(cls._get_collection_name())
QuerySet._reset_already_indexed(cls)
class DynamicDocument(Document):

View File

@ -434,9 +434,11 @@ class QuerySet(object):
return spec
@classmethod
def _reset_already_indexed(cls):
def _reset_already_indexed(cls, document=None):
"""Helper to reset already indexed, can be useful for testing purposes"""
cls.__already_indexed = set()
if document:
cls.__already_indexed.discard(document)
cls.__already_indexed.clear()
def __call__(self, q_obj=None, class_check=True, slave_okay=False, **query):
"""Filter the selected documents by calling the
@ -476,6 +478,13 @@ class QuerySet(object):
perform operations only if the collection is accessed.
"""
if self._document not in QuerySet.__already_indexed:
# Ensure collection exists
db = _get_db()
if self._collection_obj.name not in db.collection_names():
self._document._collection = None
self._collection_obj = self._document._get_collection()
QuerySet.__already_indexed.add(self._document)
background = self._document._meta.get('index_background', False)

View File

@ -41,6 +41,21 @@ class DocumentTest(unittest.TestCase):
self.Person.drop_collection()
self.assertFalse(collection in self.db.collection_names())
def test_queryset_resurrects_dropped_collection(self):
self.Person.objects().item_frequencies('name')
self.Person.drop_collection()
self.assertEqual({}, self.Person.objects().item_frequencies('name'))
class Actor(self.Person):
pass
# Ensure works correctly with inhertited classes
Actor.objects().item_frequencies('name')
self.Person.drop_collection()
self.assertEqual({}, Actor.objects().item_frequencies('name'))
def test_definition(self):
"""Ensure that document may be defined using fields.
"""

View File

@ -15,7 +15,7 @@ class QuerySetTest(unittest.TestCase):
def setUp(self):
connect(db='mongoenginetest')
class Person(Document):
name = StringField()
age = IntField()
@ -455,6 +455,9 @@ class QuerySetTest(unittest.TestCase):
Blog.drop_collection()
# Recreates the collection
self.assertEqual(0, Blog.objects.count())
with query_counter() as q:
self.assertEqual(q, 0)
@ -468,10 +471,10 @@ class QuerySetTest(unittest.TestCase):
blogs.append(Blog(title="post %s" % i, posts=[post1, post2]))
Blog.objects.insert(blogs, load_bulk=False)
self.assertEqual(q, 2) # 1 for the inital connection and 1 for the insert
self.assertEqual(q, 1) # 1 for the insert
Blog.objects.insert(blogs)
self.assertEqual(q, 4) # 1 for insert, and 1 for in bulk
self.assertEqual(q, 3) # 1 for insert, and 1 for in bulk fetch (3 in total)
Blog.drop_collection()
@ -1840,7 +1843,6 @@ class QuerySetTest(unittest.TestCase):
freq = Person.objects.item_frequencies('city', normalize=True, map_reduce=True)
self.assertEquals(freq, {'CRB': 0.5, None: 0.5})
def test_item_frequencies_with_null_embedded(self):
class Data(EmbeddedDocument):
name = StringField()
@ -2227,7 +2229,7 @@ class QuerySetTest(unittest.TestCase):
events = Event.objects(location__within_box=box)
self.assertEqual(events.count(), 1)
self.assertEqual(events[0].id, event2.id)
# check that polygon works for users who have a server >= 1.9
server_version = tuple(
_get_connection().server_info()['version'].split('.')
@ -2244,7 +2246,7 @@ class QuerySetTest(unittest.TestCase):
events = Event.objects(location__within_polygon=polygon)
self.assertEqual(events.count(), 1)
self.assertEqual(events[0].id, event1.id)
polygon2 = [
(54.033586,-1.742249),
(52.792797,-1.225891),
@ -2252,7 +2254,7 @@ class QuerySetTest(unittest.TestCase):
]
events = Event.objects(location__within_polygon=polygon2)
self.assertEqual(events.count(), 0)
Event.drop_collection()
def test_spherical_geospatial_operators(self):