Fixes collection creation post drop_collection

Thanks to Julien Rebetez for the original patch
closes [#285]
This commit is contained in:
Ross Lawley 2011-10-11 02:26:33 -07:00
parent 5fb9d61d28
commit 0624cdd6e4
5 changed files with 38 additions and 11 deletions

View File

@ -6,6 +6,7 @@ Changelog
Changes in dev Changes in dev
============== ==============
- Fixed calling a queryset after drop_collection now recreates the collection
- Fixed tree based circular reference bug - Fixed tree based circular reference bug
- Add field name to validation exception messages - Add field name to validation exception messages
- Added UUID field - Added UUID field

View File

@ -36,7 +36,6 @@ class EmbeddedDocument(BaseDocument):
super(EmbeddedDocument, self).__delattr__(*args, **kwargs) super(EmbeddedDocument, self).__delattr__(*args, **kwargs)
class Document(BaseDocument): class Document(BaseDocument):
"""The base class used for defining the structure and properties of """The base class used for defining the structure and properties of
collections of documents stored in MongoDB. Inherit from this class, and collections of documents stored in MongoDB. Inherit from this class, and
@ -81,7 +80,6 @@ class Document(BaseDocument):
@classmethod @classmethod
def _get_collection(self): def _get_collection(self):
"""Returns the collection for the document.""" """Returns the collection for the document."""
if not hasattr(self, '_collection') or self._collection is None: if not hasattr(self, '_collection') or self._collection is None:
db = _get_db() db = _get_db()
collection_name = self._get_collection_name() collection_name = self._get_collection_name()
@ -291,8 +289,10 @@ class Document(BaseDocument):
"""Drops the entire collection associated with this """Drops the entire collection associated with this
:class:`~mongoengine.Document` type from the database. :class:`~mongoengine.Document` type from the database.
""" """
from mongoengine.queryset import QuerySet
db = _get_db() db = _get_db()
db.drop_collection(cls._get_collection_name()) db.drop_collection(cls._get_collection_name())
QuerySet._reset_already_indexed(cls)
class DynamicDocument(Document): class DynamicDocument(Document):

View File

@ -434,9 +434,11 @@ class QuerySet(object):
return spec return spec
@classmethod @classmethod
def _reset_already_indexed(cls): def _reset_already_indexed(cls, document=None):
"""Helper to reset already indexed, can be useful for testing purposes""" """Helper to reset already indexed, can be useful for testing purposes"""
cls.__already_indexed = set() if document:
cls.__already_indexed.discard(document)
cls.__already_indexed.clear()
def __call__(self, q_obj=None, class_check=True, slave_okay=False, **query): def __call__(self, q_obj=None, class_check=True, slave_okay=False, **query):
"""Filter the selected documents by calling the """Filter the selected documents by calling the
@ -476,6 +478,13 @@ class QuerySet(object):
perform operations only if the collection is accessed. perform operations only if the collection is accessed.
""" """
if self._document not in QuerySet.__already_indexed: if self._document not in QuerySet.__already_indexed:
# Ensure collection exists
db = _get_db()
if self._collection_obj.name not in db.collection_names():
self._document._collection = None
self._collection_obj = self._document._get_collection()
QuerySet.__already_indexed.add(self._document) QuerySet.__already_indexed.add(self._document)
background = self._document._meta.get('index_background', False) background = self._document._meta.get('index_background', False)

View File

@ -41,6 +41,21 @@ class DocumentTest(unittest.TestCase):
self.Person.drop_collection() self.Person.drop_collection()
self.assertFalse(collection in self.db.collection_names()) self.assertFalse(collection in self.db.collection_names())
def test_queryset_resurrects_dropped_collection(self):
self.Person.objects().item_frequencies('name')
self.Person.drop_collection()
self.assertEqual({}, self.Person.objects().item_frequencies('name'))
class Actor(self.Person):
pass
# Ensure works correctly with inhertited classes
Actor.objects().item_frequencies('name')
self.Person.drop_collection()
self.assertEqual({}, Actor.objects().item_frequencies('name'))
def test_definition(self): def test_definition(self):
"""Ensure that document may be defined using fields. """Ensure that document may be defined using fields.
""" """

View File

@ -15,7 +15,7 @@ class QuerySetTest(unittest.TestCase):
def setUp(self): def setUp(self):
connect(db='mongoenginetest') connect(db='mongoenginetest')
class Person(Document): class Person(Document):
name = StringField() name = StringField()
age = IntField() age = IntField()
@ -455,6 +455,9 @@ class QuerySetTest(unittest.TestCase):
Blog.drop_collection() Blog.drop_collection()
# Recreates the collection
self.assertEqual(0, Blog.objects.count())
with query_counter() as q: with query_counter() as q:
self.assertEqual(q, 0) self.assertEqual(q, 0)
@ -468,10 +471,10 @@ class QuerySetTest(unittest.TestCase):
blogs.append(Blog(title="post %s" % i, posts=[post1, post2])) blogs.append(Blog(title="post %s" % i, posts=[post1, post2]))
Blog.objects.insert(blogs, load_bulk=False) Blog.objects.insert(blogs, load_bulk=False)
self.assertEqual(q, 2) # 1 for the inital connection and 1 for the insert self.assertEqual(q, 1) # 1 for the insert
Blog.objects.insert(blogs) Blog.objects.insert(blogs)
self.assertEqual(q, 4) # 1 for insert, and 1 for in bulk self.assertEqual(q, 3) # 1 for insert, and 1 for in bulk fetch (3 in total)
Blog.drop_collection() Blog.drop_collection()
@ -1840,7 +1843,6 @@ class QuerySetTest(unittest.TestCase):
freq = Person.objects.item_frequencies('city', normalize=True, map_reduce=True) freq = Person.objects.item_frequencies('city', normalize=True, map_reduce=True)
self.assertEquals(freq, {'CRB': 0.5, None: 0.5}) self.assertEquals(freq, {'CRB': 0.5, None: 0.5})
def test_item_frequencies_with_null_embedded(self): def test_item_frequencies_with_null_embedded(self):
class Data(EmbeddedDocument): class Data(EmbeddedDocument):
name = StringField() name = StringField()
@ -2227,7 +2229,7 @@ class QuerySetTest(unittest.TestCase):
events = Event.objects(location__within_box=box) events = Event.objects(location__within_box=box)
self.assertEqual(events.count(), 1) self.assertEqual(events.count(), 1)
self.assertEqual(events[0].id, event2.id) self.assertEqual(events[0].id, event2.id)
# check that polygon works for users who have a server >= 1.9 # check that polygon works for users who have a server >= 1.9
server_version = tuple( server_version = tuple(
_get_connection().server_info()['version'].split('.') _get_connection().server_info()['version'].split('.')
@ -2244,7 +2246,7 @@ class QuerySetTest(unittest.TestCase):
events = Event.objects(location__within_polygon=polygon) events = Event.objects(location__within_polygon=polygon)
self.assertEqual(events.count(), 1) self.assertEqual(events.count(), 1)
self.assertEqual(events[0].id, event1.id) self.assertEqual(events[0].id, event1.id)
polygon2 = [ polygon2 = [
(54.033586,-1.742249), (54.033586,-1.742249),
(52.792797,-1.225891), (52.792797,-1.225891),
@ -2252,7 +2254,7 @@ class QuerySetTest(unittest.TestCase):
] ]
events = Event.objects(location__within_polygon=polygon2) events = Event.objects(location__within_polygon=polygon2)
self.assertEqual(events.count(), 0) self.assertEqual(events.count(), 0)
Event.drop_collection() Event.drop_collection()
def test_spherical_geospatial_operators(self): def test_spherical_geospatial_operators(self):