diff --git a/mongoengine/connection.py b/mongoengine/connection.py index 38ebb243..c0cfde31 100644 --- a/mongoengine/connection.py +++ b/mongoengine/connection.py @@ -1,10 +1,10 @@ from pymongo import MongoClient, ReadPreference, uri_parser import six -from mongoengine.python_support import IS_PYMONGO_3 +from mongoengine.pymongo_support import IS_PYMONGO_3 __all__ = ['MongoEngineConnectionError', 'connect', 'register_connection', - 'DEFAULT_CONNECTION_NAME'] + 'DEFAULT_CONNECTION_NAME', 'get_db'] DEFAULT_CONNECTION_NAME = 'default' diff --git a/mongoengine/context_managers.py b/mongoengine/context_managers.py index d1e5d9ef..98bd897b 100644 --- a/mongoengine/context_managers.py +++ b/mongoengine/context_managers.py @@ -5,6 +5,7 @@ from six import iteritems from mongoengine.common import _import_class from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db +from mongoengine.pymongo_support import count_documents __all__ = ('switch_db', 'switch_collection', 'no_dereference', 'no_sub_classes', 'query_counter', 'set_write_concern') @@ -237,7 +238,7 @@ class query_counter(object): and substracting the queries issued by this context. In fact everytime this is called, 1 query is issued so we need to balance that """ - count = self.db.system.profile.find(self._ignored_query).count() - self._ctx_query_counter + count = count_documents(self.db.system.profile, self._ignored_query) - self._ctx_query_counter self._ctx_query_counter += 1 # Account for the query we just issued to gather the information return count diff --git a/mongoengine/document.py b/mongoengine/document.py index 84c1d699..5981d8d1 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -18,7 +18,7 @@ from mongoengine.context_managers import (set_write_concern, switch_db) from mongoengine.errors import (InvalidDocumentError, InvalidQueryError, SaveConditionError) -from mongoengine.python_support import IS_PYMONGO_3 +from mongoengine.pymongo_support import IS_PYMONGO_3, list_collection_names from mongoengine.queryset import (NotUniqueError, OperationError, QuerySet, transform) @@ -228,7 +228,7 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): # If the collection already exists and has different options # (i.e. isn't capped or has different max/size), raise an error. - if collection_name in db.collection_names(): + if collection_name in list_collection_names(db, include_system_collections=True): collection = db[collection_name] options = collection.options() if ( diff --git a/mongoengine/pymongo_support.py b/mongoengine/pymongo_support.py new file mode 100644 index 00000000..0d607162 --- /dev/null +++ b/mongoengine/pymongo_support.py @@ -0,0 +1,33 @@ +""" +Helper functions, constants, and types to aid with PyMongo v2.7 - v3.x support. +""" +import pymongo + +_PYMONGO_37 = (3, 7) + +PYMONGO_VERSION = tuple(pymongo.version_tuple[:2]) + +IS_PYMONGO_3 = PYMONGO_VERSION[0] >= 3 +IS_PYMONGO_GTE_37 = PYMONGO_VERSION >= _PYMONGO_37 + + +def count_documents(collection, filter): + """Pymongo>3.7 deprecates count in favour of count_documents""" + if IS_PYMONGO_GTE_37: + return collection.count_documents(filter) + else: + count = collection.find(filter).count() + return count + + +def list_collection_names(db, include_system_collections=False): + """Pymongo>3.7 deprecates collection_names in favour of list_collection_names""" + if IS_PYMONGO_GTE_37: + collections = db.list_collection_names() + else: + collections = db.collection_names() + + if not include_system_collections: + collections = [c for c in collections if not c.startswith('system.')] + + return collections diff --git a/mongoengine/python_support.py b/mongoengine/python_support.py index 7e8e108f..57e467db 100644 --- a/mongoengine/python_support.py +++ b/mongoengine/python_support.py @@ -1,13 +1,8 @@ """ -Helper functions, constants, and types to aid with Python v2.7 - v3.x and -PyMongo v2.7 - v3.x support. +Helper functions, constants, and types to aid with Python v2.7 - v3.x support """ -import pymongo import six - -IS_PYMONGO_3 = pymongo.version_tuple[0] >= 3 - # six.BytesIO resolves to StringIO.StringIO in Py2 and io.BytesIO in Py3. StringIO = six.BytesIO diff --git a/mongoengine/queryset/base.py b/mongoengine/queryset/base.py index 8c22c5b9..9ddfeab2 100644 --- a/mongoengine/queryset/base.py +++ b/mongoengine/queryset/base.py @@ -21,7 +21,7 @@ from mongoengine.connection import get_db from mongoengine.context_managers import set_write_concern, switch_db from mongoengine.errors import (InvalidQueryError, LookUpError, NotUniqueError, OperationError) -from mongoengine.python_support import IS_PYMONGO_3 +from mongoengine.pymongo_support import IS_PYMONGO_3 from mongoengine.queryset import transform from mongoengine.queryset.field_list import QueryFieldList from mongoengine.queryset.visitor import Q, QNode diff --git a/mongoengine/queryset/transform.py b/mongoengine/queryset/transform.py index c00271f3..3de10a69 100644 --- a/mongoengine/queryset/transform.py +++ b/mongoengine/queryset/transform.py @@ -10,7 +10,7 @@ from mongoengine.base import UPDATE_OPERATORS from mongoengine.common import _import_class from mongoengine.connection import get_connection from mongoengine.errors import InvalidQueryError -from mongoengine.python_support import IS_PYMONGO_3 +from mongoengine.pymongo_support import IS_PYMONGO_3 __all__ = ('query', 'update') diff --git a/tests/document/class_methods.py b/tests/document/class_methods.py index 88937ec8..421618e4 100644 --- a/tests/document/class_methods.py +++ b/tests/document/class_methods.py @@ -2,6 +2,7 @@ import unittest from mongoengine import * +from mongoengine.pymongo_support import list_collection_names from mongoengine.queryset import NULLIFY, PULL from mongoengine.connection import get_db @@ -27,9 +28,7 @@ class ClassMethodsTest(unittest.TestCase): self.Person = Person def tearDown(self): - for collection in self.db.collection_names(): - if 'system.' in collection: - continue + for collection in list_collection_names(self.db): self.db.drop_collection(collection) def test_definition(self): @@ -66,10 +65,10 @@ class ClassMethodsTest(unittest.TestCase): """ collection_name = 'person' self.Person(name='Test').save() - self.assertIn(collection_name, self.db.collection_names()) + self.assertIn(collection_name, list_collection_names(self.db)) self.Person.drop_collection() - self.assertNotIn(collection_name, self.db.collection_names()) + self.assertNotIn(collection_name, list_collection_names(self.db)) def test_register_delete_rule(self): """Ensure that register delete rule adds a delete rule to the document @@ -340,7 +339,7 @@ class ClassMethodsTest(unittest.TestCase): meta = {'collection': collection_name} Person(name="Test User").save() - self.assertIn(collection_name, self.db.collection_names()) + self.assertIn(collection_name, list_collection_names(self.db)) user_obj = self.db[collection_name].find_one() self.assertEqual(user_obj['name'], "Test User") @@ -349,7 +348,7 @@ class ClassMethodsTest(unittest.TestCase): self.assertEqual(user_obj.name, "Test User") Person.drop_collection() - self.assertNotIn(collection_name, self.db.collection_names()) + self.assertNotIn(collection_name, list_collection_names(self.db)) def test_collection_name_and_primary(self): """Ensure that a collection with a specified name may be used. diff --git a/tests/document/delta.py b/tests/document/delta.py index 942e3a0a..504c1707 100644 --- a/tests/document/delta.py +++ b/tests/document/delta.py @@ -3,16 +3,14 @@ import unittest from bson import SON from mongoengine import * -from mongoengine.connection import get_db - -__all__ = ("DeltaTest",) +from mongoengine.pymongo_support import list_collection_names +from tests.utils import MongoDBTestCase -class DeltaTest(unittest.TestCase): +class DeltaTest(MongoDBTestCase): def setUp(self): - connect(db='mongoenginetest') - self.db = get_db() + super(DeltaTest, self).setUp() class Person(Document): name = StringField() @@ -25,9 +23,7 @@ class DeltaTest(unittest.TestCase): self.Person = Person def tearDown(self): - for collection in self.db.collection_names(): - if 'system.' in collection: - continue + for collection in list_collection_names(self.db): self.db.drop_collection(collection) def test_delta(self): diff --git a/tests/document/inheritance.py b/tests/document/inheritance.py index 83c2a80a..d81039f4 100644 --- a/tests/document/inheritance.py +++ b/tests/document/inheritance.py @@ -6,23 +6,18 @@ from six import iteritems from mongoengine import (BooleanField, Document, EmbeddedDocument, EmbeddedDocumentField, GenericReferenceField, - IntField, ReferenceField, StringField, connect) -from mongoengine.connection import get_db + IntField, ReferenceField, StringField) +from mongoengine.pymongo_support import list_collection_names +from tests.utils import MongoDBTestCase from tests.fixtures import Base __all__ = ('InheritanceTest', ) -class InheritanceTest(unittest.TestCase): - - def setUp(self): - connect(db='mongoenginetest') - self.db = get_db() +class InheritanceTest(MongoDBTestCase): def tearDown(self): - for collection in self.db.collection_names(): - if 'system.' in collection: - continue + for collection in list_collection_names(self.db): self.db.drop_collection(collection) def test_constructor_cls(self): diff --git a/tests/document/instance.py b/tests/document/instance.py index 051eda68..9b28f1b4 100644 --- a/tests/document/instance.py +++ b/tests/document/instance.py @@ -12,6 +12,7 @@ from bson import DBRef, ObjectId from pymongo.errors import DuplicateKeyError from six import iteritems +from mongoengine.pymongo_support import list_collection_names from tests import fixtures from tests.fixtures import (PickleEmbedded, PickleTest, PickleSignalsTest, PickleDynamicEmbedded, PickleDynamicTest) @@ -55,9 +56,7 @@ class InstanceTest(MongoDBTestCase): self.Job = Job def tearDown(self): - for collection in self.db.collection_names(): - if 'system.' in collection: - continue + for collection in list_collection_names(self.db): self.db.drop_collection(collection) def assertDbEqual(self, docs): @@ -572,7 +571,7 @@ class InstanceTest(MongoDBTestCase): Post.drop_collection() - Post._get_collection().insert({ + Post._get_collection().insert_one({ "title": "Items eclipse", "items": ["more lorem", "even more ipsum"] }) @@ -3217,8 +3216,7 @@ class InstanceTest(MongoDBTestCase): coll = Person._get_collection() for person in Person.objects.as_pymongo(): if 'height' not in person: - person['height'] = 189 - coll.save(person) + coll.update_one({'_id': person['_id']}, {'$set': {'height': 189}}) self.assertEquals(Person.objects(height=189).count(), 1) diff --git a/tests/fields/file_tests.py b/tests/fields/file_tests.py index 76e20bb9..4ff6865b 100644 --- a/tests/fields/file_tests.py +++ b/tests/fields/file_tests.py @@ -24,6 +24,16 @@ TEST_IMAGE_PATH = os.path.join(os.path.dirname(__file__), 'mongoengine.png') TEST_IMAGE2_PATH = os.path.join(os.path.dirname(__file__), 'mongodb_leaf.png') +def get_file(path): + """Use a BytesIO instead of a file to allow + to have a one-liner and avoid that the file remains opened""" + bytes_io = StringIO() + with open(path, 'rb') as f: + bytes_io.write(f.read()) + bytes_io.seek(0) + return bytes_io + + class FileTest(MongoDBTestCase): def tearDown(self): @@ -247,8 +257,8 @@ class FileTest(MongoDBTestCase): Animal.drop_collection() marmot = Animal(genus='Marmota', family='Sciuridae') - marmot_photo = open(TEST_IMAGE_PATH, 'rb') # Retrieve a photo from disk - marmot.photo.put(marmot_photo, content_type='image/jpeg', foo='bar') + marmot_photo_content = get_file(TEST_IMAGE_PATH) # Retrieve a photo from disk + marmot.photo.put(marmot_photo_content, content_type='image/jpeg', foo='bar') marmot.photo.close() marmot.save() @@ -261,11 +271,11 @@ class FileTest(MongoDBTestCase): the_file = FileField() TestFile.drop_collection() - test_file = TestFile(the_file=open(TEST_IMAGE_PATH, 'rb')).save() + test_file = TestFile(the_file=get_file(TEST_IMAGE_PATH)).save() self.assertEqual(test_file.the_file.get().length, 8313) test_file = TestFile.objects.first() - test_file.the_file = open(TEST_IMAGE2_PATH, 'rb') + test_file.the_file = get_file(TEST_IMAGE2_PATH) test_file.save() self.assertEqual(test_file.the_file.get().length, 4971) @@ -379,7 +389,7 @@ class FileTest(MongoDBTestCase): self.assertEqual("%s" % e, "Invalid image: cannot identify image file %s" % f) t = TestImage() - t.image.put(open(TEST_IMAGE_PATH, 'rb')) + t.image.put(get_file(TEST_IMAGE_PATH)) t.save() t = TestImage.objects.first() @@ -400,11 +410,11 @@ class FileTest(MongoDBTestCase): the_file = ImageField() TestFile.drop_collection() - test_file = TestFile(the_file=open(TEST_IMAGE_PATH, 'rb')).save() + test_file = TestFile(the_file=get_file(TEST_IMAGE_PATH)).save() self.assertEqual(test_file.the_file.size, (371, 76)) test_file = TestFile.objects.first() - test_file.the_file = open(TEST_IMAGE2_PATH, 'rb') + test_file.the_file = get_file(TEST_IMAGE2_PATH) test_file.save() self.assertEqual(test_file.the_file.size, (45, 101)) @@ -418,7 +428,7 @@ class FileTest(MongoDBTestCase): TestImage.drop_collection() t = TestImage() - t.image.put(open(TEST_IMAGE_PATH, 'rb')) + t.image.put(get_file(TEST_IMAGE_PATH)) t.save() t = TestImage.objects.first() @@ -441,7 +451,7 @@ class FileTest(MongoDBTestCase): TestImage.drop_collection() t = TestImage() - t.image.put(open(TEST_IMAGE_PATH, 'rb')) + t.image.put(get_file(TEST_IMAGE_PATH)) t.save() t = TestImage.objects.first() @@ -464,7 +474,7 @@ class FileTest(MongoDBTestCase): TestImage.drop_collection() t = TestImage() - t.image.put(open(TEST_IMAGE_PATH, 'rb')) + t.image.put(get_file(TEST_IMAGE_PATH)) t.save() t = TestImage.objects.first() @@ -542,8 +552,8 @@ class FileTest(MongoDBTestCase): TestImage.drop_collection() t = TestImage() - t.image1.put(open(TEST_IMAGE_PATH, 'rb')) - t.image2.put(open(TEST_IMAGE2_PATH, 'rb')) + t.image1.put(get_file(TEST_IMAGE_PATH)) + t.image2.put(get_file(TEST_IMAGE2_PATH)) t.save() test = TestImage.objects.first() @@ -563,12 +573,10 @@ class FileTest(MongoDBTestCase): Animal.drop_collection() marmot = Animal(genus='Marmota', family='Sciuridae') - marmot_photo = open(TEST_IMAGE_PATH, 'rb') # Retrieve a photo from disk - - photos_field = marmot._fields['photos'].field - new_proxy = photos_field.get_proxy_obj('photos', marmot) - new_proxy.put(marmot_photo, content_type='image/jpeg', foo='bar') - marmot_photo.close() + with open(TEST_IMAGE_PATH, 'rb') as marmot_photo: # Retrieve a photo from disk + photos_field = marmot._fields['photos'].field + new_proxy = photos_field.get_proxy_obj('photos', marmot) + new_proxy.put(marmot_photo, content_type='image/jpeg', foo='bar') marmot.photos.append(new_proxy) marmot.save() diff --git a/tests/queryset/queryset.py b/tests/queryset/queryset.py index 05a7ca75..0d8d6285 100644 --- a/tests/queryset/queryset.py +++ b/tests/queryset/queryset.py @@ -19,14 +19,11 @@ from mongoengine.connection import get_connection, get_db from mongoengine.context_managers import query_counter, switch_db from mongoengine.errors import InvalidQueryError from mongoengine.mongodb_support import get_mongodb_version, MONGODB_32 -from mongoengine.python_support import IS_PYMONGO_3 +from mongoengine.pymongo_support import IS_PYMONGO_3 from mongoengine.queryset import (DoesNotExist, MultipleObjectsReturned, QuerySet, QuerySetManager, queryset_manager) - from tests.utils import requires_mongodb_gte_26, skip_pymongo3 -__all__ = ("QuerySetTest",) - class db_ops_tracker(query_counter): @@ -4052,7 +4049,7 @@ class QuerySetTest(unittest.TestCase): fielda = IntField() fieldb = IntField() - IntPair.objects._collection.remove() + IntPair.drop_collection() a = IntPair(fielda=1, fieldb=1) b = IntPair(fielda=1, fieldb=2) @@ -5387,7 +5384,7 @@ class QuerySetTest(unittest.TestCase): Person.drop_collection() - Person._get_collection().insert({'name': 'a', 'id': ''}) + Person._get_collection().insert_one({'name': 'a', 'id': ''}) for p in Person.objects(): self.assertEqual(p.name, 'a') diff --git a/tests/test_connection.py b/tests/test_connection.py index 7c4fc4cf..fafef9d4 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -14,7 +14,7 @@ from mongoengine import ( connect, register_connection, Document, DateTimeField ) -from mongoengine.python_support import IS_PYMONGO_3 +from mongoengine.pymongo_support import IS_PYMONGO_3 import mongoengine.connection from mongoengine.connection import (MongoEngineConnectionError, get_db, get_connection) @@ -147,12 +147,12 @@ class ConnectionTest(unittest.TestCase): def test_connect_uri(self): """Ensure that the connect() method works properly with URIs.""" c = connect(db='mongoenginetest', alias='admin') - c.admin.system.users.remove({}) - c.mongoenginetest.system.users.remove({}) + c.admin.system.users.delete_many({}) + c.mongoenginetest.system.users.delete_many({}) - c.admin.add_user("admin", "password") + c.admin.command("createUser", "admin", pwd="password", roles=["root"]) c.admin.authenticate("admin", "password") - c.mongoenginetest.add_user("username", "password") + c.admin.command("createUser", "username", pwd="password", roles=["dbOwner"]) if not IS_PYMONGO_3: self.assertRaises( @@ -169,8 +169,8 @@ class ConnectionTest(unittest.TestCase): self.assertIsInstance(db, pymongo.database.Database) self.assertEqual(db.name, 'mongoenginetest') - c.admin.system.users.remove({}) - c.mongoenginetest.system.users.remove({}) + c.admin.system.users.delete_many({}) + c.mongoenginetest.system.users.delete_many({}) def test_connect_uri_without_db(self): """Ensure connect() method works properly if the URI doesn't @@ -217,8 +217,9 @@ class ConnectionTest(unittest.TestCase): """ # Create users c = connect('mongoenginetest') - c.admin.system.users.remove({}) - c.admin.add_user('username2', 'password') + + c.admin.system.users.delete_many({}) + c.admin.command("createUser", "username2", pwd="password", roles=["dbOwner"]) # Authentication fails without "authSource" if IS_PYMONGO_3: @@ -246,7 +247,7 @@ class ConnectionTest(unittest.TestCase): self.assertEqual(db.name, 'mongoenginetest') # Clear all users - authd_conn.admin.system.users.remove({}) + authd_conn.admin.system.users.delete_many({}) def test_register_connection(self): """Ensure that connections with different aliases may be registered. diff --git a/tests/test_context_managers.py b/tests/test_context_managers.py index 8207cd89..227031e0 100644 --- a/tests/test_context_managers.py +++ b/tests/test_context_managers.py @@ -5,6 +5,7 @@ from mongoengine.connection import get_db from mongoengine.context_managers import (switch_db, switch_collection, no_sub_classes, no_dereference, query_counter) +from mongoengine.pymongo_support import count_documents class ContextManagersTest(unittest.TestCase): @@ -240,7 +241,7 @@ class ContextManagersTest(unittest.TestCase): collection.drop() def issue_1_count_query(): - collection.find({}).count() + count_documents(collection, {}) def issue_1_insert_query(): collection.insert_one({'test': 'garbage'}) diff --git a/tests/test_replicaset_connection.py b/tests/test_replicaset_connection.py index 4aa647d6..81fdfb64 100644 --- a/tests/test_replicaset_connection.py +++ b/tests/test_replicaset_connection.py @@ -2,7 +2,7 @@ import unittest from pymongo import ReadPreference -from mongoengine.python_support import IS_PYMONGO_3 +from mongoengine.pymongo_support import IS_PYMONGO_3 if IS_PYMONGO_3: from pymongo import MongoClient diff --git a/tests/utils.py b/tests/utils.py index e94e4a80..3c41f07d 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -6,7 +6,7 @@ from nose.plugins.skip import SkipTest from mongoengine import connect from mongoengine.connection import get_db from mongoengine.mongodb_support import get_mongodb_version, MONGODB_26, MONGODB_3, MONGODB_32, MONGODB_34 -from mongoengine.python_support import IS_PYMONGO_3 +from mongoengine.pymongo_support import IS_PYMONGO_3 MONGO_TEST_DB = 'mongoenginetest' # standard name for the test database