Merge pull request #946 from MRigal/fix/pymongo3-connection

fixes #946
This commit is contained in:
David Bordeynik 2015-05-11 15:51:51 +03:00
commit 94eac1e79d
18 changed files with 293 additions and 91 deletions

View File

@ -10,8 +10,8 @@ python:
env:
- PYMONGO=2.7
- PYMONGO=2.8
# - PYMONGO=3.0
# - PYMONGO=dev
- PYMONGO=3.0
- PYMONGO=dev
matrix:
fast_finish: true
before_install:

View File

@ -221,3 +221,4 @@ that much better:
* Eremeev Danil (https://github.com/elephanter)
* Catstyle Lee (https://github.com/Catstyle)
* Kiryl Yermakou (https://github.com/rma4ok)
* Matthieu Rigal (https://github.com/MRigal)

View File

@ -17,6 +17,7 @@ Changes in 0.9.X - DEV
- Don't send a "cls" option to ensureIndex (related to https://jira.mongodb.org/browse/SERVER-769)
- Fix for updating sorting in SortedListField. #978
- Added __ support to escape field name in fields lookup keywords that match operators names #949
- Support for PyMongo 3+ #946
Changes in 0.9.0
================

View File

@ -1,5 +1,4 @@
import weakref
import functools
import itertools
from mongoengine.common import _import_class
from mongoengine.errors import DoesNotExist, MultipleObjectsReturned

View File

@ -1,13 +1,11 @@
import warnings
import pymongo
from mongoengine.common import _import_class
from mongoengine.errors import InvalidDocumentError
from mongoengine.python_support import PY3
from mongoengine.queryset import (DO_NOTHING, DoesNotExist,
MultipleObjectsReturned,
QuerySet, QuerySetManager)
QuerySetManager)
from mongoengine.base.common import _document_registry, ALLOW_INHERITANCE
from mongoengine.base.fields import BaseField, ComplexBaseField, ObjectIdField

View File

@ -1,11 +1,16 @@
from pymongo import MongoClient, MongoReplicaSetClient, uri_parser
from pymongo import MongoClient, ReadPreference, uri_parser
from mongoengine.python_support import IS_PYMONGO_3
__all__ = ['ConnectionError', 'connect', 'register_connection',
'DEFAULT_CONNECTION_NAME']
DEFAULT_CONNECTION_NAME = 'default'
if IS_PYMONGO_3:
READ_PREFERENCE = ReadPreference.PRIMARY
else:
from pymongo import MongoReplicaSetClient
READ_PREFERENCE = False
class ConnectionError(Exception):
@ -18,7 +23,7 @@ _dbs = {}
def register_connection(alias, name=None, host=None, port=None,
read_preference=False,
read_preference=READ_PREFERENCE,
username=None, password=None, authentication_source=None,
**kwargs):
"""Add a connection.
@ -109,6 +114,7 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
# Discard replicaSet if not base string
if not isinstance(conn_settings['replicaSet'], basestring):
conn_settings.pop('replicaSet', None)
if not IS_PYMONGO_3:
connection_class = MongoReplicaSetClient
try:

View File

@ -1,11 +1,8 @@
import warnings
import hashlib
import pymongo
import re
from pymongo.read_preferences import ReadPreference
from bson import ObjectId
from bson.dbref import DBRef
from mongoengine import signals
from mongoengine.common import _import_class
@ -19,7 +16,7 @@ from mongoengine.base import (
ALLOW_INHERITANCE,
get_document
)
from mongoengine.errors import ValidationError, InvalidQueryError, InvalidDocumentError
from mongoengine.errors import InvalidQueryError, InvalidDocumentError
from mongoengine.queryset import (OperationError, NotUniqueError,
QuerySet, transform)
from mongoengine.connection import get_db, DEFAULT_CONNECTION_NAME
@ -169,6 +166,7 @@ class Document(BaseDocument):
@classmethod
def _get_collection(cls):
"""Returns the collection for the document."""
# TODO: use new get_collection() with PyMongo3 ?
if not hasattr(cls, '_collection') or cls._collection is None:
db = cls._get_db()
collection_name = cls._get_collection_name()
@ -310,6 +308,13 @@ class Document(BaseDocument):
object_id = collection.insert(doc, **write_concern)
else:
object_id = collection.save(doc, **write_concern)
# In PyMongo 3.0, the save() call calls internally the _update() call
# but they forget to return the _id value passed back, therefore getting it back here
# Correct behaviour in 2.X and in 3.0.1+ versions
if not object_id and pymongo.version_tuple == (3, 0):
pk_as_mongo_obj = self._fields.get(self._meta['id_field']).to_mongo(self.pk)
object_id = self._qs.filter(pk=pk_as_mongo_obj).first() and \
self._qs.filter(pk=pk_as_mongo_obj).first().pk
else:
object_id = doc['_id']
updates, removals = self._delta()

View File

@ -1,6 +1,13 @@
"""Helper functions and types to aid with Python 2.5 - 3 support."""
import sys
import pymongo
if pymongo.version_tuple[0] < 3:
IS_PYMONGO_3 = False
else:
IS_PYMONGO_3 = True
PY3 = sys.version_info[0] == 3

View File

@ -21,10 +21,14 @@ from mongoengine.common import _import_class
from mongoengine.base.common import get_document
from mongoengine.errors import (OperationError, NotUniqueError,
InvalidQueryError, LookUpError)
from mongoengine.python_support import IS_PYMONGO_3
from mongoengine.queryset import transform
from mongoengine.queryset.field_list import QueryFieldList
from mongoengine.queryset.visitor import Q, QNode
if IS_PYMONGO_3:
from pymongo.collection import ReturnDocument
__all__ = ('BaseQuerySet', 'DO_NOTHING', 'NULLIFY', 'CASCADE', 'DENY', 'PULL')
@ -158,7 +162,8 @@ class BaseQuerySet(object):
if queryset._as_pymongo:
return queryset._get_as_pymongo(queryset._cursor[key])
return queryset._document._from_son(queryset._cursor[key],
_auto_dereference=self._auto_dereference, only_fields=self.only_fields)
_auto_dereference=self._auto_dereference,
only_fields=self.only_fields)
raise AttributeError
@ -423,7 +428,7 @@ class BaseQuerySet(object):
if call_document_delete:
cnt = 0
for doc in queryset:
doc.delete(write_concern=write_concern)
doc.delete(**write_concern)
cnt += 1
return cnt
@ -545,7 +550,7 @@ class BaseQuerySet(object):
:param upsert: insert if document doesn't exist (default ``False``)
:param full_response: return the entire response object from the
server (default ``False``)
server (default ``False``, not available for PyMongo 3+)
:param remove: remove rather than updating (default ``False``)
:param new: return updated rather than original document
(default ``False``)
@ -563,10 +568,28 @@ class BaseQuerySet(object):
queryset = self.clone()
query = queryset._query
if not IS_PYMONGO_3 or not remove:
update = transform.update(queryset._document, **update)
sort = queryset._ordering
try:
if IS_PYMONGO_3:
if full_response:
msg = ("With PyMongo 3+, it is not possible anymore to get the full response.")
warnings.warn(msg, DeprecationWarning)
if remove:
result = queryset._collection.find_one_and_delete(
query, sort=sort, **self._cursor_args)
else:
if new:
return_doc = ReturnDocument.AFTER
else:
return_doc = ReturnDocument.BEFORE
result = queryset._collection.find_one_and_update(
query, update, upsert=upsert, sort=sort, return_document=return_doc,
**self._cursor_args)
else:
result = queryset._collection.find_and_modify(
query, update, upsert=upsert, sort=sort, remove=remove, new=new,
full_response=full_response, **self._cursor_args)
@ -907,13 +930,18 @@ class BaseQuerySet(object):
plan = pprint.pformat(plan)
return plan
# DEPRECATED. Has no more impact on PyMongo 3+
def snapshot(self, enabled):
"""Enable or disable snapshot mode when querying.
:param enabled: whether or not snapshot mode is enabled
..versionchanged:: 0.5 - made chainable
.. deprecated:: Ignored with PyMongo 3+
"""
if IS_PYMONGO_3:
msg = "snapshot is deprecated as it has no impact when using PyMongo 3+."
warnings.warn(msg, DeprecationWarning)
queryset = self.clone()
queryset._snapshot = enabled
return queryset
@ -929,11 +957,17 @@ class BaseQuerySet(object):
queryset._timeout = enabled
return queryset
# DEPRECATED. Has no more impact on PyMongo 3+
def slave_okay(self, enabled):
"""Enable or disable the slave_okay when querying.
:param enabled: whether or not the slave_okay is enabled
.. deprecated:: Ignored with PyMongo 3+
"""
if IS_PYMONGO_3:
msg = "slave_okay is deprecated as it has no impact when using PyMongo 3+."
warnings.warn(msg, DeprecationWarning)
queryset = self.clone()
queryset._slave_okay = enabled
return queryset
@ -1383,22 +1417,34 @@ class BaseQuerySet(object):
@property
def _cursor_args(self):
if not IS_PYMONGO_3:
fields_name = 'fields'
cursor_args = {
'snapshot': self._snapshot,
'timeout': self._timeout
'timeout': self._timeout,
'snapshot': self._snapshot
}
if self._read_preference is not None:
cursor_args['read_preference'] = self._read_preference
else:
cursor_args['slave_okay'] = self._slave_okay
else:
fields_name = 'projection'
# snapshot is not handled at all by PyMongo 3+
# TODO: evaluate similar possibilities using modifiers
if self._snapshot:
msg = "The snapshot option is not anymore available with PyMongo 3+"
warnings.warn(msg, DeprecationWarning)
cursor_args = {
'no_cursor_timeout': self._timeout
}
if self._loaded_fields:
cursor_args['fields'] = self._loaded_fields.as_dict()
cursor_args[fields_name] = self._loaded_fields.as_dict()
if self._search_text:
if 'fields' not in cursor_args:
cursor_args['fields'] = {}
if fields_name not in cursor_args:
cursor_args[fields_name] = {}
cursor_args['fields']['_text_score'] = {'$meta': "textScore"}
cursor_args[fields_name]['_text_score'] = {'$meta': "textScore"}
return cursor_args

View File

@ -6,7 +6,7 @@ from bson import SON
from mongoengine.base.fields import UPDATE_OPERATORS
from mongoengine.connection import get_connection
from mongoengine.common import _import_class
from mongoengine.errors import InvalidQueryError, LookUpError
from mongoengine.errors import InvalidQueryError
__all__ = ('query', 'update')
@ -128,20 +128,15 @@ def query(_doc_cls=None, _field_operation=False, **query):
mongo_query[key].update(value)
# $maxDistance needs to come last - convert to SON
value_dict = mongo_query[key]
if ('$maxDistance' in value_dict and '$near' in value_dict):
if '$maxDistance' in value_dict and '$near' in value_dict:
value_son = SON()
if isinstance(value_dict['$near'], dict):
for k, v in value_dict.iteritems():
if k == '$maxDistance':
continue
value_son[k] = v
if (get_connection().max_wire_version <= 1):
value_son['$maxDistance'] = value_dict[
'$maxDistance']
else:
value_son['$near'] = SON(value_son['$near'])
value_son['$near'][
'$maxDistance'] = value_dict['$maxDistance']
value_son['$near']['$maxDistance'] = value_dict['$maxDistance']
else:
for k, v in value_dict.iteritems():
if k == '$maxDistance':

View File

@ -1,8 +1,5 @@
import copy
from itertools import product
from functools import reduce
from mongoengine.errors import InvalidQueryError
from mongoengine.queryset import transform

View File

@ -1,9 +1,9 @@
# -*- coding: utf-8 -*-
import unittest
import sys
sys.path[0:0] = [""]
import os
import pymongo
from nose.plugins.skip import SkipTest
@ -432,6 +432,7 @@ class IndexesTest(unittest.TestCase):
class Test(Document):
a = IntField()
b = IntField()
meta = {
'indexes': ['a'],
@ -443,16 +444,36 @@ class IndexesTest(unittest.TestCase):
obj = Test(a=1)
obj.save()
connection = get_connection()
IS_MONGODB_3 = connection.server_info()['versionArray'][0] >= 3
# Need to be explicit about covered indexes as mongoDB doesn't know if
# the documents returned might have more keys in that here.
query_plan = Test.objects(id=obj.id).exclude('a').explain()
if not IS_MONGODB_3:
self.assertFalse(query_plan['indexOnly'])
else:
self.assertEqual(query_plan.get('queryPlanner').get('winningPlan').get('inputStage').get('stage'), 'IDHACK')
query_plan = Test.objects(id=obj.id).only('id').explain()
if not IS_MONGODB_3:
self.assertTrue(query_plan['indexOnly'])
else:
self.assertEqual(query_plan.get('queryPlanner').get('winningPlan').get('inputStage').get('stage'), 'IDHACK')
query_plan = Test.objects(a=1).only('a').exclude('id').explain()
if not IS_MONGODB_3:
self.assertTrue(query_plan['indexOnly'])
else:
self.assertEqual(query_plan.get('queryPlanner').get('winningPlan').get('inputStage').get('stage'), 'IXSCAN')
self.assertEqual(query_plan.get('queryPlanner').get('winningPlan').get('stage'), 'PROJECTION')
query_plan = Test.objects(a=1).explain()
if not IS_MONGODB_3:
self.assertFalse(query_plan['indexOnly'])
else:
self.assertEqual(query_plan.get('queryPlanner').get('winningPlan').get('inputStage').get('stage'), 'IXSCAN')
self.assertEqual(query_plan.get('queryPlanner').get('winningPlan').get('stage'), 'FETCH')
def test_index_on_id(self):
@ -491,6 +512,9 @@ class IndexesTest(unittest.TestCase):
self.assertEqual(BlogPost.objects.count(), 10)
self.assertEqual(BlogPost.objects.hint().count(), 10)
# PyMongo 3.0 bug only, works correctly with 2.X and 3.0.1+ versions
if pymongo.version != '3.0':
self.assertEqual(BlogPost.objects.hint([('tags', 1)]).count(), 10)
self.assertEqual(BlogPost.objects.hint([('ZZ', 1)]).count(), 10)
@ -862,11 +886,14 @@ class IndexesTest(unittest.TestCase):
index_info = TestDoc._get_collection().index_information()
for key in index_info:
del index_info[key]['v'] # drop the index version - we don't care about that here
if 'ns' in index_info[key]:
del index_info[key]['ns'] # drop the index namespace - we don't care about that here, MongoDB 3+
if 'dropDups' in index_info[key]:
del index_info[key]['dropDups'] # drop the index dropDups - it is deprecated in MongoDB 3+
self.assertEqual(index_info, {
'txt_1': {
'key': [('txt', 1)],
'dropDups': False,
'background': False
},
'_id_': {
@ -874,7 +901,6 @@ class IndexesTest(unittest.TestCase):
},
'txt2_1': {
'key': [('txt2', 1)],
'dropDups': False,
'background': False
},
'_cls_1': {

View File

@ -1,5 +1,7 @@
# -*- coding: utf-8 -*-
import sys
from nose.plugins.skip import SkipTest
sys.path[0:0] = [""]
import datetime
@ -2488,10 +2490,29 @@ class FieldTest(unittest.TestCase):
id = BinaryField(primary_key=True)
Attachment.drop_collection()
att = Attachment(id=uuid.uuid4().bytes).save()
binary_id = uuid.uuid4().bytes
att = Attachment(id=binary_id).save()
self.assertEqual(1, Attachment.objects.count())
self.assertEqual(1, Attachment.objects.filter(id=att.id).count())
# TODO use assertIsNotNone once Python 2.6 support is dropped
self.assertTrue(Attachment.objects.filter(id=att.id).first() is not None)
att.delete()
self.assertEqual(0, Attachment.objects.count())
def test_binary_field_primary_filter_by_binary_pk_as_str(self):
raise SkipTest("Querying by id as string is not currently supported")
class Attachment(Document):
id = BinaryField(primary_key=True)
Attachment.drop_collection()
binary_id = uuid.uuid4().bytes
att = Attachment(id=binary_id).save()
self.assertEqual(1, Attachment.objects.filter(id=binary_id).count())
# TODO use assertIsNotNone once Python 2.6 support is dropped
self.assertTrue(Attachment.objects.filter(id=binary_id).first() is not None)
att.delete()
self.assertEqual(0, Attachment.objects.count())
def test_choices_validation(self):

View File

@ -336,12 +336,11 @@ class GeoFieldTest(unittest.TestCase):
Location.drop_collection()
Parent.drop_collection()
list(Parent.objects)
collection = Parent._get_collection()
info = collection.index_information()
Parent(name='Berlin').save()
info = Parent._get_collection().index_information()
self.assertFalse('location_2d' in info)
info = Location._get_collection().index_information()
self.assertTrue('location_2d' in info)
self.assertEqual(len(Parent._geo_indices()), 0)
self.assertEqual(len(Location._geo_indices()), 1)

View File

@ -1,12 +1,16 @@
import sys
sys.path[0:0] = [""]
import unittest
from datetime import datetime, timedelta
from mongoengine import *
from pymongo.errors import OperationFailure
from mongoengine import *
from mongoengine.connection import get_connection
from nose.plugins.skip import SkipTest
__all__ = ("GeoQueriesTest",)
@ -141,7 +145,13 @@ class GeoQueriesTest(unittest.TestCase):
def test_spherical_geospatial_operators(self):
"""Ensure that spherical geospatial queries are working
"""
raise SkipTest("https://jira.mongodb.org/browse/SERVER-14039")
# Needs MongoDB > 2.6.4 https://jira.mongodb.org/browse/SERVER-14039
connection = get_connection()
info = connection.test.command('buildInfo')
mongodb_version = tuple([int(i) for i in info['version'].split('.')])
if mongodb_version < (2, 6, 4):
raise SkipTest("Need MongoDB version 2.6.4+")
class Point(Document):
location = GeoPointField()
@ -167,6 +177,13 @@ class GeoQueriesTest(unittest.TestCase):
points = Point.objects(location__near_sphere=[-122, 37.5],
location__max_distance=60 / earth_radius)
# This test is sometimes failing with Mongo internals non-sense.
# See https://travis-ci.org/MongoEngine/mongoengine/builds/58729101
try:
points.count()
except OperationFailure:
raise SkipTest("Sometimes MongoDB ignores its capacities on maxDistance")
self.assertEqual(points.count(), 2)
# Finds both points, but orders the north point first because it's

View File

@ -17,7 +17,7 @@ from bson import ObjectId
from mongoengine import *
from mongoengine.connection import get_connection, get_db
from mongoengine.python_support import PY3
from mongoengine.python_support import PY3, IS_PYMONGO_3
from mongoengine.context_managers import query_counter, switch_db
from mongoengine.queryset import (QuerySet, QuerySetManager,
MultipleObjectsReturned, DoesNotExist,
@ -51,6 +51,20 @@ def skip_older_mongodb(f):
return _inner
def skip_pymongo3(f):
def _inner(*args, **kwargs):
if IS_PYMONGO_3:
raise SkipTest("Useless with PyMongo 3+")
return f(*args, **kwargs)
_inner.__name__ = f.__name__
_inner.__doc__ = f.__doc__
return _inner
class QuerySetTest(unittest.TestCase):
def setUp(self):
@ -694,6 +708,11 @@ class QuerySetTest(unittest.TestCase):
Blog.drop_collection()
# get MongoDB version info
connection = get_connection()
info = connection.test.command('buildInfo')
mongodb_version = tuple([int(i) for i in info['version'].split('.')])
# Recreates the collection
self.assertEqual(0, Blog.objects.count())
@ -710,7 +729,7 @@ class QuerySetTest(unittest.TestCase):
blogs.append(Blog(title="post %s" % i, posts=[post1, post2]))
Blog.objects.insert(blogs, load_bulk=False)
if (get_connection().max_wire_version <= 1):
if mongodb_version < (2, 6):
self.assertEqual(q, 1)
else:
# profiling logs each doc now in the bulk op
@ -723,7 +742,7 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(q, 0)
Blog.objects.insert(blogs)
if (get_connection().max_wire_version <= 1):
if mongodb_version < (2, 6):
self.assertEqual(q, 2) # 1 for insert, and 1 for in bulk fetch
else:
# 99 for insert, and 1 for in bulk fetch
@ -855,8 +874,10 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(q, 3)
@skip_pymongo3
def test_slave_okay(self):
"""Ensures that a query can take slave_okay syntax
"""Ensures that a query can take slave_okay syntax.
Useless with PyMongo 3+ as well as with MongoDB 3+.
"""
person1 = self.Person(name="User A", age=20)
person1.save()
@ -869,6 +890,8 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(person.name, "User A")
self.assertEqual(person.age, 20)
@skip_older_mongodb
@skip_pymongo3
def test_cursor_args(self):
"""Ensures the cursor args can be set as expected
"""
@ -2926,8 +2949,12 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(query.count(), 3)
self.assertEqual(query._query, {'$text': {'$search': 'brasil'}})
cursor_args = query._cursor_args
if not IS_PYMONGO_3:
cursor_args_fields = cursor_args['fields']
else:
cursor_args_fields = cursor_args['projection']
self.assertEqual(
cursor_args['fields'], {'_text_score': {'$meta': 'textScore'}})
cursor_args_fields, {'_text_score': {'$meta': 'textScore'}})
text_scores = [i.get_text_score() for i in query]
self.assertEqual(len(text_scores), 3)
@ -3992,8 +4019,11 @@ class QuerySetTest(unittest.TestCase):
bars = list(Bar.objects(read_preference=ReadPreference.PRIMARY))
self.assertEqual([], bars)
self.assertRaises(ConfigurationError, Bar.objects,
read_preference='Primary')
if not IS_PYMONGO_3:
error_class = ConfigurationError
else:
error_class = TypeError
self.assertRaises(error_class, Bar.objects, read_preference='Primary')
bars = Bar.objects(read_preference=ReadPreference.SECONDARY_PREFERRED)
self.assertEqual(

View File

@ -1,4 +1,7 @@
import sys
import datetime
from pymongo.errors import OperationFailure
sys.path[0:0] = [""]
try:
@ -6,8 +9,6 @@ try:
except ImportError:
import unittest
import datetime
import pymongo
from bson.tz_util import utc
@ -15,10 +16,18 @@ from mongoengine import (
connect, register_connection,
Document, DateTimeField
)
from mongoengine.python_support import IS_PYMONGO_3
import mongoengine.connection
from mongoengine.connection import get_db, get_connection, ConnectionError
def get_tz_awareness(connection):
if not IS_PYMONGO_3:
return connection.tz_aware
else:
return connection.codec_options.tz_aware
class ConnectionTest(unittest.TestCase):
def tearDown(self):
@ -51,6 +60,13 @@ class ConnectionTest(unittest.TestCase):
connect('mongoenginetest', alias='testdb2')
actual_connection = get_connection('testdb2')
# Handle PyMongo 3+ Async Connection
if IS_PYMONGO_3:
# Ensure we are connected, throws ServerSelectionTimeoutError otherwise.
# Purposely not catching exception to fail test if thrown.
expected_connection.server_info()
self.assertEqual(expected_connection, actual_connection)
def test_connect_uri(self):
@ -64,6 +80,7 @@ class ConnectionTest(unittest.TestCase):
c.admin.authenticate("admin", "password")
c.mongoenginetest.add_user("username", "password")
if not IS_PYMONGO_3:
self.assertRaises(ConnectionError, connect, "testdb_uri_bad", host='mongodb://test:password@localhost')
connect("testdb_uri", host='mongodb://username:password@localhost/mongoenginetest')
@ -90,6 +107,7 @@ class ConnectionTest(unittest.TestCase):
c.admin.authenticate("admin", "password")
c.mongoenginetest.add_user("username", "password")
if not IS_PYMONGO_3:
self.assertRaises(ConnectionError, connect, "testdb_uri_bad", host='mongodb://test:password@localhost')
connect("mongoenginetest", host='mongodb://localhost/')
@ -107,6 +125,7 @@ class ConnectionTest(unittest.TestCase):
def test_connect_uri_with_authsource(self):
"""Ensure that the connect() method works well with
the option `authSource` in URI.
This feature was introduced in MongoDB 2.4 and removed in 2.6
"""
# Create users
c = connect('mongoenginetest')
@ -114,6 +133,11 @@ class ConnectionTest(unittest.TestCase):
c.admin.add_user('username', 'password')
# Authentication fails without "authSource"
if IS_PYMONGO_3:
test_conn = connect('mongoenginetest', alias='test2',
host='mongodb://username:password@localhost/mongoenginetest')
self.assertRaises(OperationFailure, test_conn.server_info)
else:
self.assertRaises(
ConnectionError, connect, 'mongoenginetest', alias='test1',
host='mongodb://username:password@localhost/mongoenginetest'
@ -121,11 +145,13 @@ class ConnectionTest(unittest.TestCase):
self.assertRaises(ConnectionError, get_db, 'test1')
# Authentication succeeds with "authSource"
connect(
test_conn2 = connect(
'mongoenginetest', alias='test2',
host=('mongodb://username:password@localhost/'
'mongoenginetest?authSource=admin')
)
# This will fail starting from MongoDB 2.6+
# test_conn2.server_info()
db = get_db('test2')
self.assertTrue(isinstance(db, pymongo.database.Database))
self.assertEqual(db.name, 'mongoenginetest')
@ -160,11 +186,11 @@ class ConnectionTest(unittest.TestCase):
connect('mongoenginetest', alias='t1', tz_aware=True)
conn = get_connection('t1')
self.assertTrue(conn.tz_aware)
self.assertTrue(get_tz_awareness(conn))
connect('mongoenginetest2', alias='t2')
conn = get_connection('t2')
self.assertFalse(conn.tz_aware)
self.assertFalse(get_tz_awareness(conn))
def test_datetime(self):
connect('mongoenginetest', tz_aware=True)
@ -188,8 +214,17 @@ class ConnectionTest(unittest.TestCase):
self.assertEqual(len(mongo_connections.items()), 2)
self.assertTrue('t1' in mongo_connections.keys())
self.assertTrue('t2' in mongo_connections.keys())
if not IS_PYMONGO_3:
self.assertEqual(mongo_connections['t1'].host, 'localhost')
self.assertEqual(mongo_connections['t2'].host, '127.0.0.1')
else:
# Handle PyMongo 3+ Async Connection
# Ensure we are connected, throws ServerSelectionTimeoutError otherwise.
# Purposely not catching exception to fail test if thrown.
mongo_connections['t1'].server_info()
mongo_connections['t2'].server_info()
self.assertEqual(mongo_connections['t1'].address[0], 'localhost')
self.assertEqual(mongo_connections['t2'].address[0], '127.0.0.1')
if __name__ == '__main__':

View File

@ -1,17 +1,33 @@
import sys
sys.path[0:0] = [""]
import unittest
import pymongo
from pymongo import ReadPreference, ReplicaSetConnection
from pymongo import ReadPreference
from mongoengine.python_support import IS_PYMONGO_3
if IS_PYMONGO_3:
from pymongo import MongoClient
CONN_CLASS = MongoClient
READ_PREF = ReadPreference.SECONDARY
else:
from pymongo import ReplicaSetConnection
CONN_CLASS = ReplicaSetConnection
READ_PREF = ReadPreference.SECONDARY_ONLY
import mongoengine
from mongoengine import *
from mongoengine.connection import get_db, get_connection, ConnectionError
from mongoengine.connection import ConnectionError
class ConnectionTest(unittest.TestCase):
def setUp(self):
mongoengine.connection._connection_settings = {}
mongoengine.connection._connections = {}
mongoengine.connection._dbs = {}
def tearDown(self):
mongoengine.connection._connection_settings = {}
mongoengine.connection._connections = {}
@ -22,14 +38,17 @@ class ConnectionTest(unittest.TestCase):
"""
try:
conn = connect(db='mongoenginetest', host="mongodb://localhost/mongoenginetest?replicaSet=rs", read_preference=ReadPreference.SECONDARY_ONLY)
conn = connect(db='mongoenginetest',
host="mongodb://localhost/mongoenginetest?replicaSet=rs",
read_preference=READ_PREF)
except ConnectionError, e:
return
if not isinstance(conn, ReplicaSetConnection):
if not isinstance(conn, CONN_CLASS):
# really???
return
self.assertEqual(conn.read_preference, ReadPreference.SECONDARY_ONLY)
self.assertEqual(conn.read_preference, READ_PREF)
if __name__ == '__main__':
unittest.main()