Merge pull request #946 from MRigal/fix/pymongo3-connection

fixes #946
This commit is contained in:
David Bordeynik
2015-05-11 15:51:51 +03:00
18 changed files with 293 additions and 91 deletions

View File

@@ -1,5 +1,4 @@
import weakref
import functools
import itertools
from mongoengine.common import _import_class
from mongoengine.errors import DoesNotExist, MultipleObjectsReturned

View File

@@ -1,13 +1,11 @@
import warnings
import pymongo
from mongoengine.common import _import_class
from mongoengine.errors import InvalidDocumentError
from mongoengine.python_support import PY3
from mongoengine.queryset import (DO_NOTHING, DoesNotExist,
MultipleObjectsReturned,
QuerySet, QuerySetManager)
QuerySetManager)
from mongoengine.base.common import _document_registry, ALLOW_INHERITANCE
from mongoengine.base.fields import BaseField, ComplexBaseField, ObjectIdField

View File

@@ -1,11 +1,16 @@
from pymongo import MongoClient, MongoReplicaSetClient, uri_parser
from pymongo import MongoClient, ReadPreference, uri_parser
from mongoengine.python_support import IS_PYMONGO_3
__all__ = ['ConnectionError', 'connect', 'register_connection',
'DEFAULT_CONNECTION_NAME']
DEFAULT_CONNECTION_NAME = 'default'
if IS_PYMONGO_3:
READ_PREFERENCE = ReadPreference.PRIMARY
else:
from pymongo import MongoReplicaSetClient
READ_PREFERENCE = False
class ConnectionError(Exception):
@@ -18,7 +23,7 @@ _dbs = {}
def register_connection(alias, name=None, host=None, port=None,
read_preference=False,
read_preference=READ_PREFERENCE,
username=None, password=None, authentication_source=None,
**kwargs):
"""Add a connection.
@@ -109,7 +114,8 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
# Discard replicaSet if not base string
if not isinstance(conn_settings['replicaSet'], basestring):
conn_settings.pop('replicaSet', None)
connection_class = MongoReplicaSetClient
if not IS_PYMONGO_3:
connection_class = MongoReplicaSetClient
try:
connection = None

View File

@@ -1,11 +1,8 @@
import warnings
import hashlib
import pymongo
import re
from pymongo.read_preferences import ReadPreference
from bson import ObjectId
from bson.dbref import DBRef
from mongoengine import signals
from mongoengine.common import _import_class
@@ -19,7 +16,7 @@ from mongoengine.base import (
ALLOW_INHERITANCE,
get_document
)
from mongoengine.errors import ValidationError, InvalidQueryError, InvalidDocumentError
from mongoengine.errors import InvalidQueryError, InvalidDocumentError
from mongoengine.queryset import (OperationError, NotUniqueError,
QuerySet, transform)
from mongoengine.connection import get_db, DEFAULT_CONNECTION_NAME
@@ -169,6 +166,7 @@ class Document(BaseDocument):
@classmethod
def _get_collection(cls):
"""Returns the collection for the document."""
# TODO: use new get_collection() with PyMongo3 ?
if not hasattr(cls, '_collection') or cls._collection is None:
db = cls._get_db()
collection_name = cls._get_collection_name()
@@ -310,6 +308,13 @@ class Document(BaseDocument):
object_id = collection.insert(doc, **write_concern)
else:
object_id = collection.save(doc, **write_concern)
# In PyMongo 3.0, the save() call calls internally the _update() call
# but they forget to return the _id value passed back, therefore getting it back here
# Correct behaviour in 2.X and in 3.0.1+ versions
if not object_id and pymongo.version_tuple == (3, 0):
pk_as_mongo_obj = self._fields.get(self._meta['id_field']).to_mongo(self.pk)
object_id = self._qs.filter(pk=pk_as_mongo_obj).first() and \
self._qs.filter(pk=pk_as_mongo_obj).first().pk
else:
object_id = doc['_id']
updates, removals = self._delta()

View File

@@ -1,6 +1,13 @@
"""Helper functions and types to aid with Python 2.5 - 3 support."""
import sys
import pymongo
if pymongo.version_tuple[0] < 3:
IS_PYMONGO_3 = False
else:
IS_PYMONGO_3 = True
PY3 = sys.version_info[0] == 3
@@ -12,7 +19,7 @@ if PY3:
return codecs.latin_1_encode(s)[0]
bin_type = bytes
txt_type = str
txt_type = str
else:
try:
from cStringIO import StringIO

View File

@@ -21,10 +21,14 @@ from mongoengine.common import _import_class
from mongoengine.base.common import get_document
from mongoengine.errors import (OperationError, NotUniqueError,
InvalidQueryError, LookUpError)
from mongoengine.python_support import IS_PYMONGO_3
from mongoengine.queryset import transform
from mongoengine.queryset.field_list import QueryFieldList
from mongoengine.queryset.visitor import Q, QNode
if IS_PYMONGO_3:
from pymongo.collection import ReturnDocument
__all__ = ('BaseQuerySet', 'DO_NOTHING', 'NULLIFY', 'CASCADE', 'DENY', 'PULL')
@@ -158,7 +162,8 @@ class BaseQuerySet(object):
if queryset._as_pymongo:
return queryset._get_as_pymongo(queryset._cursor[key])
return queryset._document._from_son(queryset._cursor[key],
_auto_dereference=self._auto_dereference, only_fields=self.only_fields)
_auto_dereference=self._auto_dereference,
only_fields=self.only_fields)
raise AttributeError
@@ -423,7 +428,7 @@ class BaseQuerySet(object):
if call_document_delete:
cnt = 0
for doc in queryset:
doc.delete(write_concern=write_concern)
doc.delete(**write_concern)
cnt += 1
return cnt
@@ -545,7 +550,7 @@ class BaseQuerySet(object):
:param upsert: insert if document doesn't exist (default ``False``)
:param full_response: return the entire response object from the
server (default ``False``)
server (default ``False``, not available for PyMongo 3+)
:param remove: remove rather than updating (default ``False``)
:param new: return updated rather than original document
(default ``False``)
@@ -563,13 +568,31 @@ class BaseQuerySet(object):
queryset = self.clone()
query = queryset._query
update = transform.update(queryset._document, **update)
if not IS_PYMONGO_3 or not remove:
update = transform.update(queryset._document, **update)
sort = queryset._ordering
try:
result = queryset._collection.find_and_modify(
query, update, upsert=upsert, sort=sort, remove=remove, new=new,
full_response=full_response, **self._cursor_args)
if IS_PYMONGO_3:
if full_response:
msg = ("With PyMongo 3+, it is not possible anymore to get the full response.")
warnings.warn(msg, DeprecationWarning)
if remove:
result = queryset._collection.find_one_and_delete(
query, sort=sort, **self._cursor_args)
else:
if new:
return_doc = ReturnDocument.AFTER
else:
return_doc = ReturnDocument.BEFORE
result = queryset._collection.find_one_and_update(
query, update, upsert=upsert, sort=sort, return_document=return_doc,
**self._cursor_args)
else:
result = queryset._collection.find_and_modify(
query, update, upsert=upsert, sort=sort, remove=remove, new=new,
full_response=full_response, **self._cursor_args)
except pymongo.errors.DuplicateKeyError, err:
raise NotUniqueError(u"Update failed (%s)" % err)
except pymongo.errors.OperationFailure, err:
@@ -907,13 +930,18 @@ class BaseQuerySet(object):
plan = pprint.pformat(plan)
return plan
# DEPRECATED. Has no more impact on PyMongo 3+
def snapshot(self, enabled):
"""Enable or disable snapshot mode when querying.
:param enabled: whether or not snapshot mode is enabled
..versionchanged:: 0.5 - made chainable
.. deprecated:: Ignored with PyMongo 3+
"""
if IS_PYMONGO_3:
msg = "snapshot is deprecated as it has no impact when using PyMongo 3+."
warnings.warn(msg, DeprecationWarning)
queryset = self.clone()
queryset._snapshot = enabled
return queryset
@@ -929,11 +957,17 @@ class BaseQuerySet(object):
queryset._timeout = enabled
return queryset
# DEPRECATED. Has no more impact on PyMongo 3+
def slave_okay(self, enabled):
"""Enable or disable the slave_okay when querying.
:param enabled: whether or not the slave_okay is enabled
.. deprecated:: Ignored with PyMongo 3+
"""
if IS_PYMONGO_3:
msg = "slave_okay is deprecated as it has no impact when using PyMongo 3+."
warnings.warn(msg, DeprecationWarning)
queryset = self.clone()
queryset._slave_okay = enabled
return queryset
@@ -1383,22 +1417,34 @@ class BaseQuerySet(object):
@property
def _cursor_args(self):
cursor_args = {
'snapshot': self._snapshot,
'timeout': self._timeout
}
if self._read_preference is not None:
cursor_args['read_preference'] = self._read_preference
if not IS_PYMONGO_3:
fields_name = 'fields'
cursor_args = {
'timeout': self._timeout,
'snapshot': self._snapshot
}
if self._read_preference is not None:
cursor_args['read_preference'] = self._read_preference
else:
cursor_args['slave_okay'] = self._slave_okay
else:
cursor_args['slave_okay'] = self._slave_okay
fields_name = 'projection'
# snapshot is not handled at all by PyMongo 3+
# TODO: evaluate similar possibilities using modifiers
if self._snapshot:
msg = "The snapshot option is not anymore available with PyMongo 3+"
warnings.warn(msg, DeprecationWarning)
cursor_args = {
'no_cursor_timeout': self._timeout
}
if self._loaded_fields:
cursor_args['fields'] = self._loaded_fields.as_dict()
cursor_args[fields_name] = self._loaded_fields.as_dict()
if self._search_text:
if 'fields' not in cursor_args:
cursor_args['fields'] = {}
if fields_name not in cursor_args:
cursor_args[fields_name] = {}
cursor_args['fields']['_text_score'] = {'$meta': "textScore"}
cursor_args[fields_name]['_text_score'] = {'$meta': "textScore"}
return cursor_args

View File

@@ -6,7 +6,7 @@ from bson import SON
from mongoengine.base.fields import UPDATE_OPERATORS
from mongoengine.connection import get_connection
from mongoengine.common import _import_class
from mongoengine.errors import InvalidQueryError, LookUpError
from mongoengine.errors import InvalidQueryError
__all__ = ('query', 'update')
@@ -128,20 +128,15 @@ def query(_doc_cls=None, _field_operation=False, **query):
mongo_query[key].update(value)
# $maxDistance needs to come last - convert to SON
value_dict = mongo_query[key]
if ('$maxDistance' in value_dict and '$near' in value_dict):
if '$maxDistance' in value_dict and '$near' in value_dict:
value_son = SON()
if isinstance(value_dict['$near'], dict):
for k, v in value_dict.iteritems():
if k == '$maxDistance':
continue
value_son[k] = v
if (get_connection().max_wire_version <= 1):
value_son['$maxDistance'] = value_dict[
'$maxDistance']
else:
value_son['$near'] = SON(value_son['$near'])
value_son['$near'][
'$maxDistance'] = value_dict['$maxDistance']
value_son['$near'] = SON(value_son['$near'])
value_son['$near']['$maxDistance'] = value_dict['$maxDistance']
else:
for k, v in value_dict.iteritems():
if k == '$maxDistance':

View File

@@ -1,8 +1,5 @@
import copy
from itertools import product
from functools import reduce
from mongoengine.errors import InvalidQueryError
from mongoengine.queryset import transform