Fixes collection creation post drop_collection
Thanks to Julien Rebetez for the original patch closes [#285]
This commit is contained in:
parent
5fb9d61d28
commit
0624cdd6e4
@ -6,6 +6,7 @@ Changelog
|
|||||||
Changes in dev
|
Changes in dev
|
||||||
==============
|
==============
|
||||||
|
|
||||||
|
- Fixed calling a queryset after drop_collection now recreates the collection
|
||||||
- Fixed tree based circular reference bug
|
- Fixed tree based circular reference bug
|
||||||
- Add field name to validation exception messages
|
- Add field name to validation exception messages
|
||||||
- Added UUID field
|
- Added UUID field
|
||||||
|
@ -36,7 +36,6 @@ class EmbeddedDocument(BaseDocument):
|
|||||||
super(EmbeddedDocument, self).__delattr__(*args, **kwargs)
|
super(EmbeddedDocument, self).__delattr__(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Document(BaseDocument):
|
class Document(BaseDocument):
|
||||||
"""The base class used for defining the structure and properties of
|
"""The base class used for defining the structure and properties of
|
||||||
collections of documents stored in MongoDB. Inherit from this class, and
|
collections of documents stored in MongoDB. Inherit from this class, and
|
||||||
@ -81,7 +80,6 @@ class Document(BaseDocument):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def _get_collection(self):
|
def _get_collection(self):
|
||||||
"""Returns the collection for the document."""
|
"""Returns the collection for the document."""
|
||||||
|
|
||||||
if not hasattr(self, '_collection') or self._collection is None:
|
if not hasattr(self, '_collection') or self._collection is None:
|
||||||
db = _get_db()
|
db = _get_db()
|
||||||
collection_name = self._get_collection_name()
|
collection_name = self._get_collection_name()
|
||||||
@ -291,8 +289,10 @@ class Document(BaseDocument):
|
|||||||
"""Drops the entire collection associated with this
|
"""Drops the entire collection associated with this
|
||||||
:class:`~mongoengine.Document` type from the database.
|
:class:`~mongoengine.Document` type from the database.
|
||||||
"""
|
"""
|
||||||
|
from mongoengine.queryset import QuerySet
|
||||||
db = _get_db()
|
db = _get_db()
|
||||||
db.drop_collection(cls._get_collection_name())
|
db.drop_collection(cls._get_collection_name())
|
||||||
|
QuerySet._reset_already_indexed(cls)
|
||||||
|
|
||||||
|
|
||||||
class DynamicDocument(Document):
|
class DynamicDocument(Document):
|
||||||
|
@ -434,9 +434,11 @@ class QuerySet(object):
|
|||||||
return spec
|
return spec
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _reset_already_indexed(cls):
|
def _reset_already_indexed(cls, document=None):
|
||||||
"""Helper to reset already indexed, can be useful for testing purposes"""
|
"""Helper to reset already indexed, can be useful for testing purposes"""
|
||||||
cls.__already_indexed = set()
|
if document:
|
||||||
|
cls.__already_indexed.discard(document)
|
||||||
|
cls.__already_indexed.clear()
|
||||||
|
|
||||||
def __call__(self, q_obj=None, class_check=True, slave_okay=False, **query):
|
def __call__(self, q_obj=None, class_check=True, slave_okay=False, **query):
|
||||||
"""Filter the selected documents by calling the
|
"""Filter the selected documents by calling the
|
||||||
@ -476,6 +478,13 @@ class QuerySet(object):
|
|||||||
perform operations only if the collection is accessed.
|
perform operations only if the collection is accessed.
|
||||||
"""
|
"""
|
||||||
if self._document not in QuerySet.__already_indexed:
|
if self._document not in QuerySet.__already_indexed:
|
||||||
|
|
||||||
|
# Ensure collection exists
|
||||||
|
db = _get_db()
|
||||||
|
if self._collection_obj.name not in db.collection_names():
|
||||||
|
self._document._collection = None
|
||||||
|
self._collection_obj = self._document._get_collection()
|
||||||
|
|
||||||
QuerySet.__already_indexed.add(self._document)
|
QuerySet.__already_indexed.add(self._document)
|
||||||
|
|
||||||
background = self._document._meta.get('index_background', False)
|
background = self._document._meta.get('index_background', False)
|
||||||
|
@ -41,6 +41,21 @@ class DocumentTest(unittest.TestCase):
|
|||||||
self.Person.drop_collection()
|
self.Person.drop_collection()
|
||||||
self.assertFalse(collection in self.db.collection_names())
|
self.assertFalse(collection in self.db.collection_names())
|
||||||
|
|
||||||
|
def test_queryset_resurrects_dropped_collection(self):
|
||||||
|
|
||||||
|
self.Person.objects().item_frequencies('name')
|
||||||
|
self.Person.drop_collection()
|
||||||
|
|
||||||
|
self.assertEqual({}, self.Person.objects().item_frequencies('name'))
|
||||||
|
|
||||||
|
class Actor(self.Person):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Ensure works correctly with inhertited classes
|
||||||
|
Actor.objects().item_frequencies('name')
|
||||||
|
self.Person.drop_collection()
|
||||||
|
self.assertEqual({}, Actor.objects().item_frequencies('name'))
|
||||||
|
|
||||||
def test_definition(self):
|
def test_definition(self):
|
||||||
"""Ensure that document may be defined using fields.
|
"""Ensure that document may be defined using fields.
|
||||||
"""
|
"""
|
||||||
|
@ -15,7 +15,7 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
connect(db='mongoenginetest')
|
connect(db='mongoenginetest')
|
||||||
|
|
||||||
class Person(Document):
|
class Person(Document):
|
||||||
name = StringField()
|
name = StringField()
|
||||||
age = IntField()
|
age = IntField()
|
||||||
@ -455,6 +455,9 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
|
|
||||||
Blog.drop_collection()
|
Blog.drop_collection()
|
||||||
|
|
||||||
|
# Recreates the collection
|
||||||
|
self.assertEqual(0, Blog.objects.count())
|
||||||
|
|
||||||
with query_counter() as q:
|
with query_counter() as q:
|
||||||
self.assertEqual(q, 0)
|
self.assertEqual(q, 0)
|
||||||
|
|
||||||
@ -468,10 +471,10 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
blogs.append(Blog(title="post %s" % i, posts=[post1, post2]))
|
blogs.append(Blog(title="post %s" % i, posts=[post1, post2]))
|
||||||
|
|
||||||
Blog.objects.insert(blogs, load_bulk=False)
|
Blog.objects.insert(blogs, load_bulk=False)
|
||||||
self.assertEqual(q, 2) # 1 for the inital connection and 1 for the insert
|
self.assertEqual(q, 1) # 1 for the insert
|
||||||
|
|
||||||
Blog.objects.insert(blogs)
|
Blog.objects.insert(blogs)
|
||||||
self.assertEqual(q, 4) # 1 for insert, and 1 for in bulk
|
self.assertEqual(q, 3) # 1 for insert, and 1 for in bulk fetch (3 in total)
|
||||||
|
|
||||||
Blog.drop_collection()
|
Blog.drop_collection()
|
||||||
|
|
||||||
@ -1840,7 +1843,6 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
freq = Person.objects.item_frequencies('city', normalize=True, map_reduce=True)
|
freq = Person.objects.item_frequencies('city', normalize=True, map_reduce=True)
|
||||||
self.assertEquals(freq, {'CRB': 0.5, None: 0.5})
|
self.assertEquals(freq, {'CRB': 0.5, None: 0.5})
|
||||||
|
|
||||||
|
|
||||||
def test_item_frequencies_with_null_embedded(self):
|
def test_item_frequencies_with_null_embedded(self):
|
||||||
class Data(EmbeddedDocument):
|
class Data(EmbeddedDocument):
|
||||||
name = StringField()
|
name = StringField()
|
||||||
@ -2227,7 +2229,7 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
events = Event.objects(location__within_box=box)
|
events = Event.objects(location__within_box=box)
|
||||||
self.assertEqual(events.count(), 1)
|
self.assertEqual(events.count(), 1)
|
||||||
self.assertEqual(events[0].id, event2.id)
|
self.assertEqual(events[0].id, event2.id)
|
||||||
|
|
||||||
# check that polygon works for users who have a server >= 1.9
|
# check that polygon works for users who have a server >= 1.9
|
||||||
server_version = tuple(
|
server_version = tuple(
|
||||||
_get_connection().server_info()['version'].split('.')
|
_get_connection().server_info()['version'].split('.')
|
||||||
@ -2244,7 +2246,7 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
events = Event.objects(location__within_polygon=polygon)
|
events = Event.objects(location__within_polygon=polygon)
|
||||||
self.assertEqual(events.count(), 1)
|
self.assertEqual(events.count(), 1)
|
||||||
self.assertEqual(events[0].id, event1.id)
|
self.assertEqual(events[0].id, event1.id)
|
||||||
|
|
||||||
polygon2 = [
|
polygon2 = [
|
||||||
(54.033586,-1.742249),
|
(54.033586,-1.742249),
|
||||||
(52.792797,-1.225891),
|
(52.792797,-1.225891),
|
||||||
@ -2252,7 +2254,7 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
]
|
]
|
||||||
events = Event.objects(location__within_polygon=polygon2)
|
events = Event.objects(location__within_polygon=polygon2)
|
||||||
self.assertEqual(events.count(), 0)
|
self.assertEqual(events.count(), 0)
|
||||||
|
|
||||||
Event.drop_collection()
|
Event.drop_collection()
|
||||||
|
|
||||||
def test_spherical_geospatial_operators(self):
|
def test_spherical_geospatial_operators(self):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user