Merge branch 'master' into remove_save_embedded

This commit is contained in:
Erdenezul Batmunkh
2019-06-11 12:41:11 +02:00
73 changed files with 5437 additions and 3881 deletions

View File

@@ -23,7 +23,7 @@ __all__ = (list(document.__all__) + list(fields.__all__) +
list(signals.__all__) + list(errors.__all__))
VERSION = (0, 16, 3)
VERSION = (0, 17, 0)
def get_version():

View File

@@ -13,7 +13,7 @@ _document_registry = {}
def get_document(name):
"""Get a document class by name."""
"""Get a registered Document class by name."""
doc = _document_registry.get(name, None)
if not doc:
# Possible old style name
@@ -30,3 +30,12 @@ def get_document(name):
been imported?
""".strip() % name)
return doc
def _get_documents_by_db(connection_alias, default_connection_alias):
"""Get all registered Documents class attached to a given database"""
def get_doc_alias(doc_cls):
return doc_cls._meta.get('db_alias', default_connection_alias)
return [doc_cls for doc_cls in _document_registry.values()
if get_doc_alias(doc_cls) == connection_alias]

View File

@@ -2,6 +2,7 @@ import weakref
from bson import DBRef
import six
from six import iteritems
from mongoengine.common import _import_class
from mongoengine.errors import DoesNotExist, MultipleObjectsReturned
@@ -363,7 +364,7 @@ class StrictDict(object):
_classes = {}
def __init__(self, **kwargs):
for k, v in kwargs.iteritems():
for k, v in iteritems(kwargs):
setattr(self, k, v)
def __getitem__(self, key):
@@ -411,7 +412,7 @@ class StrictDict(object):
return (key for key in self.__slots__ if hasattr(self, key))
def __len__(self):
return len(list(self.iteritems()))
return len(list(iteritems(self)))
def __eq__(self, other):
return self.items() == other.items()

View File

@@ -5,6 +5,7 @@ from functools import partial
from bson import DBRef, ObjectId, SON, json_util
import pymongo
import six
from six import iteritems
from mongoengine import signals
from mongoengine.base.common import get_document
@@ -83,7 +84,7 @@ class BaseDocument(object):
self._dynamic_fields = SON()
# Assign default values to instance
for key, field in self._fields.iteritems():
for key, field in iteritems(self._fields):
if self._db_field_map.get(key, key) in __only_fields:
continue
value = getattr(self, key, None)
@@ -95,14 +96,14 @@ class BaseDocument(object):
# Set passed values after initialisation
if self._dynamic:
dynamic_data = {}
for key, value in values.iteritems():
for key, value in iteritems(values):
if key in self._fields or key == '_id':
setattr(self, key, value)
else:
dynamic_data[key] = value
else:
FileField = _import_class('FileField')
for key, value in values.iteritems():
for key, value in iteritems(values):
key = self._reverse_db_field_map.get(key, key)
if key in self._fields or key in ('id', 'pk', '_cls'):
if __auto_convert and value is not None:
@@ -118,7 +119,7 @@ class BaseDocument(object):
if self._dynamic:
self._dynamic_lock = False
for key, value in dynamic_data.iteritems():
for key, value in iteritems(dynamic_data):
setattr(self, key, value)
# Flag initialised
@@ -292,8 +293,7 @@ class BaseDocument(object):
"""
Return as SON data ready for use with MongoDB.
"""
if not fields:
fields = []
fields = fields or []
data = SON()
data['_id'] = None
@@ -513,7 +513,7 @@ class BaseDocument(object):
if not hasattr(data, 'items'):
iterator = enumerate(data)
else:
iterator = data.iteritems()
iterator = iteritems(data)
for index_or_key, value in iterator:
item_key = '%s%s.' % (base_key, index_or_key)
@@ -678,7 +678,7 @@ class BaseDocument(object):
# Convert SON to a data dict, making sure each key is a string and
# corresponds to the right db field.
data = {}
for key, value in son.iteritems():
for key, value in iteritems(son):
key = str(key)
key = cls._db_field_map.get(key, key)
data[key] = value
@@ -694,7 +694,7 @@ class BaseDocument(object):
if not _auto_dereference:
fields = copy.deepcopy(fields)
for field_name, field in fields.iteritems():
for field_name, field in iteritems(fields):
field._auto_dereference = _auto_dereference
if field.db_field in data:
value = data[field.db_field]
@@ -715,7 +715,7 @@ class BaseDocument(object):
# In STRICT documents, remove any keys that aren't in cls._fields
if cls.STRICT:
data = {k: v for k, v in data.iteritems() if k in cls._fields}
data = {k: v for k, v in iteritems(data) if k in cls._fields}
obj = cls(__auto_convert=False, _created=created, __only_fields=only_fields, **data)
obj._changed_fields = changed_fields
@@ -882,7 +882,8 @@ class BaseDocument(object):
index = {'fields': fields, 'unique': True, 'sparse': sparse}
unique_indexes.append(index)
if field.__class__.__name__ == 'ListField':
if field.__class__.__name__ in {'EmbeddedDocumentListField',
'ListField', 'SortedListField'}:
field = field.field
# Grab any embedded document field unique indexes

View File

@@ -5,13 +5,13 @@ import weakref
from bson import DBRef, ObjectId, SON
import pymongo
import six
from six import iteritems
from mongoengine.base.common import UPDATE_OPERATORS
from mongoengine.base.datastructures import (BaseDict, BaseList,
EmbeddedDocumentList)
from mongoengine.common import _import_class
from mongoengine.errors import ValidationError
from mongoengine.errors import DeprecatedError, ValidationError
__all__ = ('BaseField', 'ComplexBaseField', 'ObjectIdField',
'GeoJsonBaseField')
@@ -52,8 +52,8 @@ class BaseField(object):
unique with.
:param primary_key: Mark this field as the primary key. Defaults to False.
:param validation: (optional) A callable to validate the value of the
field. Generally this is deprecated in favour of the
`FIELD.validate` method
field. The callable takes the value as parameter and should raise
a ValidationError if validation fails
:param choices: (optional) The valid choices
:param null: (optional) If the field value can be null. If no and there is a default value
then the default value is set
@@ -225,10 +225,18 @@ class BaseField(object):
# check validation argument
if self.validation is not None:
if callable(self.validation):
if not self.validation(value):
self.error('Value does not match custom validation method')
try:
# breaking change of 0.18
# Get rid of True/False-type return for the validation method
# in favor of having validation raising a ValidationError
ret = self.validation(value)
if ret is not None:
raise DeprecatedError('validation argument for `%s` must not return anything, '
'it should raise a ValidationError if validation fails' % self.name)
except ValidationError as ex:
self.error(str(ex))
else:
raise ValueError('validation argument for "%s" must be a '
raise ValueError('validation argument for `"%s"` must be a '
'callable.' % self.name)
self.validate(value, **kwargs)
@@ -275,11 +283,16 @@ class ComplexBaseField(BaseField):
_dereference = _import_class('DeReference')()
if instance._initialised and dereference and instance._data.get(self.name):
if (instance._initialised and
dereference and
instance._data.get(self.name) and
not getattr(instance._data[self.name], '_dereferenced', False)):
instance._data[self.name] = _dereference(
instance._data.get(self.name), max_depth=1, instance=instance,
name=self.name
)
if hasattr(instance._data[self.name], '_dereferenced'):
instance._data[self.name]._dereferenced = True
value = super(ComplexBaseField, self).__get__(instance, owner)
@@ -382,11 +395,11 @@ class ComplexBaseField(BaseField):
if self.field:
value_dict = {
key: self.field._to_mongo_safe_call(item, use_db_field, fields)
for key, item in value.iteritems()
for key, item in iteritems(value)
}
else:
value_dict = {}
for k, v in value.iteritems():
for k, v in iteritems(value):
if isinstance(v, Document):
# We need the id from the saved object to create the DBRef
if v.pk is None:
@@ -423,7 +436,7 @@ class ComplexBaseField(BaseField):
errors = {}
if self.field:
if hasattr(value, 'iteritems') or hasattr(value, 'items'):
sequence = value.iteritems()
sequence = iteritems(value)
else:
sequence = enumerate(value)
for k, v in sequence:

View File

@@ -1,6 +1,7 @@
import warnings
import six
from six import iteritems, itervalues
from mongoengine.base.common import _document_registry
from mongoengine.base.fields import BaseField, ComplexBaseField, ObjectIdField
@@ -62,7 +63,7 @@ class DocumentMetaclass(type):
# Standard object mixin - merge in any Fields
if not hasattr(base, '_meta'):
base_fields = {}
for attr_name, attr_value in base.__dict__.iteritems():
for attr_name, attr_value in iteritems(base.__dict__):
if not isinstance(attr_value, BaseField):
continue
attr_value.name = attr_name
@@ -74,7 +75,7 @@ class DocumentMetaclass(type):
# Discover any document fields
field_names = {}
for attr_name, attr_value in attrs.iteritems():
for attr_name, attr_value in iteritems(attrs):
if not isinstance(attr_value, BaseField):
continue
attr_value.name = attr_name
@@ -103,7 +104,7 @@ class DocumentMetaclass(type):
attrs['_fields_ordered'] = tuple(i[1] for i in sorted(
(v.creation_counter, v.name)
for v in doc_fields.itervalues()))
for v in itervalues(doc_fields)))
#
# Set document hierarchy
@@ -173,7 +174,7 @@ class DocumentMetaclass(type):
f.__dict__.update({'im_self': getattr(f, '__self__')})
# Handle delete rules
for field in new_class._fields.itervalues():
for field in itervalues(new_class._fields):
f = field
if f.owner_document is None:
f.owner_document = new_class
@@ -183,9 +184,6 @@ class DocumentMetaclass(type):
if issubclass(new_class, EmbeddedDocument):
raise InvalidDocumentError('CachedReferenceFields is not '
'allowed in EmbeddedDocuments')
if not f.document_type:
raise InvalidDocumentError(
'Document is not available to sync')
if f.auto_sync:
f.start_listener()
@@ -375,7 +373,7 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
new_class.objects = QuerySetManager()
# Validate the fields and set primary key if needed
for field_name, field in new_class._fields.iteritems():
for field_name, field in iteritems(new_class._fields):
if field.primary_key:
# Ensure only one primary key is set
current_pk = new_class._meta.get('id_field')
@@ -438,7 +436,7 @@ class MetaDict(dict):
_merge_options = ('indexes',)
def merge(self, new_options):
for k, v in new_options.iteritems():
for k, v in iteritems(new_options):
if k in self._merge_options:
self[k] = self.get(k, []) + v
else:

View File

@@ -31,7 +31,6 @@ def _import_class(cls_name):
field_classes = _field_list_cache
queryset_classes = ('OperationError',)
deref_classes = ('DeReference',)
if cls_name == 'BaseDocument':
@@ -43,14 +42,11 @@ def _import_class(cls_name):
elif cls_name in field_classes:
from mongoengine import fields as module
import_classes = field_classes
elif cls_name in queryset_classes:
from mongoengine import queryset as module
import_classes = queryset_classes
elif cls_name in deref_classes:
from mongoengine import dereference as module
import_classes = deref_classes
else:
raise ValueError('No import set for: ' % cls_name)
raise ValueError('No import set for: %s' % cls_name)
for cls in import_classes:
_class_registry_cache[cls] = getattr(module, cls)

View File

@@ -1,19 +1,22 @@
from pymongo import MongoClient, ReadPreference, uri_parser
from pymongo.database import _check_name
import six
from mongoengine.python_support import IS_PYMONGO_3
__all__ = ['MongoEngineConnectionError', 'connect', 'register_connection',
'DEFAULT_CONNECTION_NAME']
__all__ = ['MongoEngineConnectionError', 'connect', 'disconnect', 'disconnect_all',
'register_connection', 'DEFAULT_CONNECTION_NAME', 'DEFAULT_DATABASE_NAME',
'get_db', 'get_connection']
DEFAULT_CONNECTION_NAME = 'default'
DEFAULT_DATABASE_NAME = 'test'
DEFAULT_HOST = 'localhost'
DEFAULT_PORT = 27017
if IS_PYMONGO_3:
READ_PREFERENCE = ReadPreference.PRIMARY
else:
from pymongo import MongoReplicaSetClient
READ_PREFERENCE = False
_connection_settings = {}
_connections = {}
_dbs = {}
READ_PREFERENCE = ReadPreference.PRIMARY
class MongoEngineConnectionError(Exception):
@@ -23,45 +26,48 @@ class MongoEngineConnectionError(Exception):
pass
_connection_settings = {}
_connections = {}
_dbs = {}
def _check_db_name(name):
"""Check if a database name is valid.
This functionality is copied from pymongo Database class constructor.
"""
if not isinstance(name, six.string_types):
raise TypeError('name must be an instance of %s' % six.string_types)
elif name != '$external':
_check_name(name)
def register_connection(alias, db=None, name=None, host=None, port=None,
read_preference=READ_PREFERENCE,
username=None, password=None,
authentication_source=None,
authentication_mechanism=None,
**kwargs):
"""Add a connection.
def _get_connection_settings(
db=None, name=None, host=None, port=None,
read_preference=READ_PREFERENCE,
username=None, password=None,
authentication_source=None,
authentication_mechanism=None,
**kwargs):
"""Get the connection settings as a dict
:param alias: the name that will be used to refer to this connection
throughout MongoEngine
:param name: the name of the specific database to use
:param db: the name of the database to use, for compatibility with connect
:param host: the host name of the :program:`mongod` instance to connect to
:param port: the port that the :program:`mongod` instance is running on
:param read_preference: The read preference for the collection
** Added pymongo 2.1
:param username: username to authenticate with
:param password: password to authenticate with
:param authentication_source: database to authenticate against
:param authentication_mechanism: database authentication mechanisms.
: param db: the name of the database to use, for compatibility with connect
: param name: the name of the specific database to use
: param host: the host name of the: program: `mongod` instance to connect to
: param port: the port that the: program: `mongod` instance is running on
: param read_preference: The read preference for the collection
: param username: username to authenticate with
: param password: password to authenticate with
: param authentication_source: database to authenticate against
: param authentication_mechanism: database authentication mechanisms.
By default, use SCRAM-SHA-1 with MongoDB 3.0 and later,
MONGODB-CR (MongoDB Challenge Response protocol) for older servers.
:param is_mock: explicitly use mongomock for this connection
(can also be done by using `mongomock://` as db host prefix)
:param kwargs: ad-hoc parameters to be passed into the pymongo driver,
: param is_mock: explicitly use mongomock for this connection
(can also be done by using `mongomock: // ` as db host prefix)
: param kwargs: ad-hoc parameters to be passed into the pymongo driver,
for example maxpoolsize, tz_aware, etc. See the documentation
for pymongo's `MongoClient` for a full list.
.. versionchanged:: 0.10.6 - added mongomock support
"""
conn_settings = {
'name': name or db or 'test',
'host': host or 'localhost',
'port': port or 27017,
'name': name or db or DEFAULT_DATABASE_NAME,
'host': host or DEFAULT_HOST,
'port': port or DEFAULT_PORT,
'read_preference': read_preference,
'username': username,
'password': password,
@@ -69,6 +75,7 @@ def register_connection(alias, db=None, name=None, host=None, port=None,
'authentication_mechanism': authentication_mechanism
}
_check_db_name(conn_settings['name'])
conn_host = conn_settings['host']
# Host can be a list or a string, so if string, force to a list.
@@ -104,16 +111,28 @@ def register_connection(alias, db=None, name=None, host=None, port=None,
conn_settings['authentication_source'] = uri_options['authsource']
if 'authmechanism' in uri_options:
conn_settings['authentication_mechanism'] = uri_options['authmechanism']
if IS_PYMONGO_3 and 'readpreference' in uri_options:
if 'readpreference' in uri_options:
read_preferences = (
ReadPreference.NEAREST,
ReadPreference.PRIMARY,
ReadPreference.PRIMARY_PREFERRED,
ReadPreference.SECONDARY,
ReadPreference.SECONDARY_PREFERRED)
read_pf_mode = uri_options['readpreference'].lower()
ReadPreference.SECONDARY_PREFERRED,
)
# Starting with PyMongo v3.5, the "readpreference" option is
# returned as a string (e.g. "secondaryPreferred") and not an
# int (e.g. 3).
# TODO simplify the code below once we drop support for
# PyMongo v3.4.
read_pf_mode = uri_options['readpreference']
if isinstance(read_pf_mode, six.string_types):
read_pf_mode = read_pf_mode.lower()
for preference in read_preferences:
if preference.name.lower() == read_pf_mode:
if (
preference.name.lower() == read_pf_mode or
preference.mode == read_pf_mode
):
conn_settings['read_preference'] = preference
break
else:
@@ -125,17 +144,74 @@ def register_connection(alias, db=None, name=None, host=None, port=None,
kwargs.pop('is_slave', None)
conn_settings.update(kwargs)
return conn_settings
def register_connection(alias, db=None, name=None, host=None, port=None,
read_preference=READ_PREFERENCE,
username=None, password=None,
authentication_source=None,
authentication_mechanism=None,
**kwargs):
"""Register the connection settings.
: param alias: the name that will be used to refer to this connection
throughout MongoEngine
: param name: the name of the specific database to use
: param db: the name of the database to use, for compatibility with connect
: param host: the host name of the: program: `mongod` instance to connect to
: param port: the port that the: program: `mongod` instance is running on
: param read_preference: The read preference for the collection
: param username: username to authenticate with
: param password: password to authenticate with
: param authentication_source: database to authenticate against
: param authentication_mechanism: database authentication mechanisms.
By default, use SCRAM-SHA-1 with MongoDB 3.0 and later,
MONGODB-CR (MongoDB Challenge Response protocol) for older servers.
: param is_mock: explicitly use mongomock for this connection
(can also be done by using `mongomock: // ` as db host prefix)
: param kwargs: ad-hoc parameters to be passed into the pymongo driver,
for example maxpoolsize, tz_aware, etc. See the documentation
for pymongo's `MongoClient` for a full list.
.. versionchanged:: 0.10.6 - added mongomock support
"""
conn_settings = _get_connection_settings(
db=db, name=name, host=host, port=port,
read_preference=read_preference,
username=username, password=password,
authentication_source=authentication_source,
authentication_mechanism=authentication_mechanism,
**kwargs)
_connection_settings[alias] = conn_settings
def disconnect(alias=DEFAULT_CONNECTION_NAME):
"""Close the connection with a given alias."""
from mongoengine.base.common import _get_documents_by_db
from mongoengine import Document
if alias in _connections:
get_connection(alias=alias).close()
del _connections[alias]
if alias in _dbs:
# Detach all cached collections in Documents
for doc_cls in _get_documents_by_db(alias, DEFAULT_CONNECTION_NAME):
if issubclass(doc_cls, Document): # Skip EmbeddedDocument
doc_cls._disconnect()
del _dbs[alias]
if alias in _connection_settings:
del _connection_settings[alias]
def disconnect_all():
"""Close all registered database."""
for alias in list(_connections.keys()):
disconnect(alias)
def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
"""Return a connection with a given alias."""
@@ -159,7 +235,6 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
raise MongoEngineConnectionError(msg)
def _clean_settings(settings_dict):
# set literal more efficient than calling set function
irrelevant_fields_set = {
'name', 'username', 'password',
'authentication_source', 'authentication_mechanism'
@@ -169,10 +244,12 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
if k not in irrelevant_fields_set
}
raw_conn_settings = _connection_settings[alias].copy()
# Retrieve a copy of the connection settings associated with the requested
# alias and remove the database name and authentication info (we don't
# care about them at this point).
conn_settings = _clean_settings(_connection_settings[alias].copy())
conn_settings = _clean_settings(raw_conn_settings)
# Determine if we should use PyMongo's or mongomock's MongoClient.
is_mock = conn_settings.pop('is_mock', False)
@@ -186,51 +263,60 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
else:
connection_class = MongoClient
# For replica set connections with PyMongo 2.x, use
# MongoReplicaSetClient.
# TODO remove this once we stop supporting PyMongo 2.x.
if 'replicaSet' in conn_settings and not IS_PYMONGO_3:
connection_class = MongoReplicaSetClient
conn_settings['hosts_or_uri'] = conn_settings.pop('host', None)
# hosts_or_uri has to be a string, so if 'host' was provided
# as a list, join its parts and separate them by ','
if isinstance(conn_settings['hosts_or_uri'], list):
conn_settings['hosts_or_uri'] = ','.join(
conn_settings['hosts_or_uri'])
# Discard port since it can't be used on MongoReplicaSetClient
conn_settings.pop('port', None)
# Iterate over all of the connection settings and if a connection with
# the same parameters is already established, use it instead of creating
# a new one.
existing_connection = None
connection_settings_iterator = (
(db_alias, settings.copy())
for db_alias, settings in _connection_settings.items()
)
for db_alias, connection_settings in connection_settings_iterator:
connection_settings = _clean_settings(connection_settings)
if conn_settings == connection_settings and _connections.get(db_alias):
existing_connection = _connections[db_alias]
break
# Re-use existing connection if one is suitable
existing_connection = _find_existing_connection(raw_conn_settings)
# If an existing connection was found, assign it to the new alias
if existing_connection:
_connections[alias] = existing_connection
else:
# Otherwise, create the new connection for this alias. Raise
# MongoEngineConnectionError if it can't be established.
try:
_connections[alias] = connection_class(**conn_settings)
except Exception as e:
raise MongoEngineConnectionError(
'Cannot connect to database %s :\n%s' % (alias, e))
_connections[alias] = _create_connection(alias=alias,
connection_class=connection_class,
**conn_settings)
return _connections[alias]
def _create_connection(alias, connection_class, **connection_settings):
"""
Create the new connection for this alias. Raise
MongoEngineConnectionError if it can't be established.
"""
try:
return connection_class(**connection_settings)
except Exception as e:
raise MongoEngineConnectionError(
'Cannot connect to database %s :\n%s' % (alias, e))
def _find_existing_connection(connection_settings):
"""
Check if an existing connection could be reused
Iterate over all of the connection settings and if an existing connection
with the same parameters is suitable, return it
:param connection_settings: the settings of the new connection
:return: An existing connection or None
"""
connection_settings_bis = (
(db_alias, settings.copy())
for db_alias, settings in _connection_settings.items()
)
def _clean_settings(settings_dict):
# Only remove the name but it's important to
# keep the username/password/authentication_source/authentication_mechanism
# to identify if the connection could be shared (cfr https://github.com/MongoEngine/mongoengine/issues/2047)
return {k: v for k, v in settings_dict.items() if k != 'name'}
cleaned_conn_settings = _clean_settings(connection_settings)
for db_alias, connection_settings in connection_settings_bis:
db_conn_settings = _clean_settings(connection_settings)
if cleaned_conn_settings == db_conn_settings and _connections.get(db_alias):
return _connections[db_alias]
def get_db(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
if reconnect:
disconnect(alias)
@@ -258,14 +344,24 @@ def connect(db=None, alias=DEFAULT_CONNECTION_NAME, **kwargs):
provide username and password arguments as well.
Multiple databases are supported by using aliases. Provide a separate
`alias` to connect to a different instance of :program:`mongod`.
`alias` to connect to a different instance of: program: `mongod`.
In order to replace a connection identified by a given alias, you'll
need to call ``disconnect`` first
See the docstring for `register_connection` for more details about all
supported kwargs.
.. versionchanged:: 0.6 - added multiple database support.
"""
if alias not in _connections:
if alias in _connections:
prev_conn_setting = _connection_settings[alias]
new_conn_settings = _get_connection_settings(db, **kwargs)
if new_conn_settings != prev_conn_setting:
raise MongoEngineConnectionError(
'A different connection with alias `%s` was already registered. Use disconnect() first' % alias)
else:
register_connection(alias, db, **kwargs)
return get_connection(alias)

View File

@@ -1,8 +1,11 @@
from contextlib import contextmanager
from pymongo.write_concern import WriteConcern
from six import iteritems
from mongoengine.common import _import_class
from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db
from mongoengine.pymongo_support import count_documents
__all__ = ('switch_db', 'switch_collection', 'no_dereference',
'no_sub_classes', 'query_counter', 'set_write_concern')
@@ -112,7 +115,7 @@ class no_dereference(object):
GenericReferenceField = _import_class('GenericReferenceField')
ComplexBaseField = _import_class('ComplexBaseField')
self.deref_fields = [k for k, v in self.cls._fields.iteritems()
self.deref_fields = [k for k, v in iteritems(self.cls._fields)
if isinstance(v, (ReferenceField,
GenericReferenceField,
ComplexBaseField))]
@@ -235,7 +238,7 @@ class query_counter(object):
and substracting the queries issued by this context. In fact everytime this is called, 1 query is
issued so we need to balance that
"""
count = self.db.system.profile.find(self._ignored_query).count() - self._ctx_query_counter
count = count_documents(self.db.system.profile, self._ignored_query) - self._ctx_query_counter
self._ctx_query_counter += 1 # Account for the query we just issued to gather the information
return count

View File

@@ -1,5 +1,6 @@
from bson import DBRef, SON
import six
from six import iteritems
from mongoengine.base import (BaseDict, BaseList, EmbeddedDocumentList,
TopLevelDocumentMetaclass, get_document)
@@ -71,7 +72,7 @@ class DeReference(object):
def _get_items_from_dict(items):
new_items = {}
for k, v in items.iteritems():
for k, v in iteritems(items):
value = v
if isinstance(v, list):
value = _get_items_from_list(v)
@@ -112,7 +113,7 @@ class DeReference(object):
depth += 1
for item in iterator:
if isinstance(item, (Document, EmbeddedDocument)):
for field_name, field in item._fields.iteritems():
for field_name, field in iteritems(item._fields):
v = item._data.get(field_name, None)
if isinstance(v, LazyReference):
# LazyReference inherits DBRef but should not be dereferenced here !
@@ -124,7 +125,7 @@ class DeReference(object):
elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth:
field_cls = getattr(getattr(field, 'field', None), 'document_type', None)
references = self._find_references(v, depth)
for key, refs in references.iteritems():
for key, refs in iteritems(references):
if isinstance(field_cls, (Document, TopLevelDocumentMetaclass)):
key = field_cls
reference_map.setdefault(key, set()).update(refs)
@@ -137,7 +138,7 @@ class DeReference(object):
reference_map.setdefault(get_document(item['_cls']), set()).add(item['_ref'].id)
elif isinstance(item, (dict, list, tuple)) and depth - 1 <= self.max_depth:
references = self._find_references(item, depth - 1)
for key, refs in references.iteritems():
for key, refs in iteritems(references):
reference_map.setdefault(key, set()).update(refs)
return reference_map
@@ -146,7 +147,7 @@ class DeReference(object):
"""Fetch all references and convert to their document objects
"""
object_map = {}
for collection, dbrefs in self.reference_map.iteritems():
for collection, dbrefs in iteritems(self.reference_map):
# we use getattr instead of hasattr because hasattr swallows any exception under python2
# so it could hide nasty things without raising exceptions (cfr bug #1688))
@@ -157,7 +158,7 @@ class DeReference(object):
refs = [dbref for dbref in dbrefs
if (col_name, dbref) not in object_map]
references = collection.objects.in_bulk(refs)
for key, doc in references.iteritems():
for key, doc in iteritems(references):
object_map[(col_name, key)] = doc
else: # Generic reference: use the refs data to convert to document
if isinstance(doc_type, (ListField, DictField, MapField)):
@@ -229,7 +230,7 @@ class DeReference(object):
data = []
else:
is_list = False
iterator = items.iteritems()
iterator = iteritems(items)
data = {}
depth += 1

View File

@@ -5,6 +5,7 @@ from bson.dbref import DBRef
import pymongo
from pymongo.read_preferences import ReadPreference
import six
from six import iteritems
from mongoengine import signals
from mongoengine.base import (BaseDict, BaseDocument, BaseList,
@@ -17,7 +18,7 @@ from mongoengine.context_managers import (set_write_concern,
switch_db)
from mongoengine.errors import (InvalidDocumentError, InvalidQueryError,
SaveConditionError)
from mongoengine.python_support import IS_PYMONGO_3
from mongoengine.pymongo_support import list_collection_names
from mongoengine.queryset import (NotUniqueError, OperationError,
QuerySet, transform)
@@ -175,10 +176,16 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
return get_db(cls._meta.get('db_alias', DEFAULT_CONNECTION_NAME))
@classmethod
def _get_collection(cls):
"""Return a PyMongo collection for the document."""
if not hasattr(cls, '_collection') or cls._collection is None:
def _disconnect(cls):
"""Detach the Document class from the (cached) database collection"""
cls._collection = None
@classmethod
def _get_collection(cls):
"""Return the corresponding PyMongo collection of this document.
Upon the first call, it will ensure that indexes gets created. The returned collection then gets cached
"""
if not hasattr(cls, '_collection') or cls._collection is None:
# Get the collection, either capped or regular.
if cls._meta.get('max_size') or cls._meta.get('max_documents'):
cls._collection = cls._get_capped_collection()
@@ -215,7 +222,7 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
# If the collection already exists and has different options
# (i.e. isn't capped or has different max/size), raise an error.
if collection_name in db.collection_names():
if collection_name in list_collection_names(db, include_system_collections=True):
collection = db[collection_name]
options = collection.options()
if (
@@ -240,7 +247,7 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
data = super(Document, self).to_mongo(*args, **kwargs)
# If '_id' is None, try and set it from self._data. If that
# doesn't exist either, remote '_id' from the SON completely.
# doesn't exist either, remove '_id' from the SON completely.
if data['_id'] is None:
if self._data.get('id') is None:
del data['_id']
@@ -346,21 +353,21 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
.. versionchanged:: 0.10.7
Add signal_kwargs argument
"""
signal_kwargs = signal_kwargs or {}
if self._meta.get('abstract'):
raise InvalidDocumentError('Cannot save an abstract document.')
signal_kwargs = signal_kwargs or {}
signals.pre_save.send(self.__class__, document=self, **signal_kwargs)
if validate:
self.validate(clean=clean)
if write_concern is None:
write_concern = {'w': 1}
write_concern = {}
doc = self.to_mongo()
created = ('_id' not in doc or self._created or force_insert)
doc_id = self.to_mongo(fields=['id'])
created = ('_id' not in doc_id or self._created or force_insert)
signals.pre_save_post_validation.send(self.__class__, document=self,
created=created, **signal_kwargs)
@@ -438,16 +445,6 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
object_id = wc_collection.insert_one(doc).inserted_id
# In PyMongo 3.0, the save() call calls internally the _update() call
# but they forget to return the _id value passed back, therefore getting it back here
# Correct behaviour in 2.X and in 3.0.1+ versions
if not object_id and pymongo.version_tuple == (3, 0):
pk_as_mongo_obj = self._fields.get(self._meta['id_field']).to_mongo(self.pk)
object_id = (
self._qs.filter(pk=pk_as_mongo_obj).first() and
self._qs.filter(pk=pk_as_mongo_obj).first().pk
) # TODO doesn't this make 2 queries?
return object_id
def _get_update_doc(self):
@@ -493,8 +490,12 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
update_doc = self._get_update_doc()
if update_doc:
upsert = save_condition is None
last_error = collection.update(select_dict, update_doc,
upsert=upsert, **write_concern)
with set_write_concern(collection, write_concern) as wc_collection:
last_error = wc_collection.update_one(
select_dict,
update_doc,
upsert=upsert
).raw_result
if not upsert and last_error['n'] == 0:
raise SaveConditionError('Race condition preventing'
' document update detected')
@@ -601,7 +602,7 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
# Delete FileFields separately
FileField = _import_class('FileField')
for name, field in self._fields.iteritems():
for name, field in iteritems(self._fields):
if isinstance(field, FileField):
getattr(self, name).delete()
@@ -786,13 +787,13 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
.. versionchanged:: 0.10.7
:class:`OperationError` exception raised if no collection available
"""
col_name = cls._get_collection_name()
if not col_name:
coll_name = cls._get_collection_name()
if not coll_name:
raise OperationError('Document %s has no collection defined '
'(is it abstract ?)' % cls)
cls._collection = None
db = cls._get_db()
db.drop_collection(col_name)
db.drop_collection(coll_name)
@classmethod
def create_index(cls, keys, background=False, **kwargs):
@@ -807,18 +808,13 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
index_spec = index_spec.copy()
fields = index_spec.pop('fields')
drop_dups = kwargs.get('drop_dups', False)
if IS_PYMONGO_3 and drop_dups:
if drop_dups:
msg = 'drop_dups is deprecated and is removed when using PyMongo 3+.'
warnings.warn(msg, DeprecationWarning)
elif not IS_PYMONGO_3:
index_spec['drop_dups'] = drop_dups
index_spec['background'] = background
index_spec.update(kwargs)
if IS_PYMONGO_3:
return cls._get_collection().create_index(fields, **index_spec)
else:
return cls._get_collection().ensure_index(fields, **index_spec)
return cls._get_collection().create_index(fields, **index_spec)
@classmethod
def ensure_index(cls, key_or_list, drop_dups=False, background=False,
@@ -833,11 +829,9 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
:param drop_dups: Was removed/ignored with MongoDB >2.7.5. The value
will be removed if PyMongo3+ is used
"""
if IS_PYMONGO_3 and drop_dups:
if drop_dups:
msg = 'drop_dups is deprecated and is removed when using PyMongo 3+.'
warnings.warn(msg, DeprecationWarning)
elif not IS_PYMONGO_3:
kwargs.update({'drop_dups': drop_dups})
return cls.create_index(key_or_list, background=background, **kwargs)
@classmethod
@@ -853,7 +847,7 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
drop_dups = cls._meta.get('index_drop_dups', False)
index_opts = cls._meta.get('index_opts') or {}
index_cls = cls._meta.get('index_cls', True)
if IS_PYMONGO_3 and drop_dups:
if drop_dups:
msg = 'drop_dups is deprecated and is removed when using PyMongo 3+.'
warnings.warn(msg, DeprecationWarning)
@@ -884,11 +878,7 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
if 'cls' in opts:
del opts['cls']
if IS_PYMONGO_3:
collection.create_index(fields, background=background, **opts)
else:
collection.ensure_index(fields, background=background,
drop_dups=drop_dups, **opts)
collection.create_index(fields, background=background, **opts)
# If _cls is being used (for polymorphism), it needs an index,
# only if another index doesn't begin with _cls
@@ -899,12 +889,8 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
if 'cls' in index_opts:
del index_opts['cls']
if IS_PYMONGO_3:
collection.create_index('_cls', background=background,
**index_opts)
else:
collection.ensure_index('_cls', background=background,
**index_opts)
collection.create_index('_cls', background=background,
**index_opts)
@classmethod
def list_indexes(cls):

View File

@@ -1,11 +1,12 @@
from collections import defaultdict
import six
from six import iteritems
__all__ = ('NotRegistered', 'InvalidDocumentError', 'LookUpError',
'DoesNotExist', 'MultipleObjectsReturned', 'InvalidQueryError',
'OperationError', 'NotUniqueError', 'FieldDoesNotExist',
'ValidationError', 'SaveConditionError')
'ValidationError', 'SaveConditionError', 'DeprecatedError')
class NotRegistered(Exception):
@@ -109,11 +110,8 @@ class ValidationError(AssertionError):
def build_dict(source):
errors_dict = {}
if not source:
return errors_dict
if isinstance(source, dict):
for field_name, error in source.iteritems():
for field_name, error in iteritems(source):
errors_dict[field_name] = build_dict(error)
elif isinstance(source, ValidationError) and source.errors:
return build_dict(source.errors)
@@ -135,12 +133,17 @@ class ValidationError(AssertionError):
value = ' '.join([generate_key(k) for k in value])
elif isinstance(value, dict):
value = ' '.join(
[generate_key(v, k) for k, v in value.iteritems()])
[generate_key(v, k) for k, v in iteritems(value)])
results = '%s.%s' % (prefix, value) if prefix else value
return results
error_dict = defaultdict(list)
for k, v in self.to_dict().iteritems():
for k, v in iteritems(self.to_dict()):
error_dict[generate_key(v)].append(k)
return ' '.join(['%s: %s' % (k, v) for k, v in error_dict.iteritems()])
return ' '.join(['%s: %s' % (k, v) for k, v in iteritems(error_dict)])
class DeprecatedError(Exception):
"""Raise when a user uses a feature that has been Deprecated"""
pass

View File

@@ -11,6 +11,7 @@ from bson import Binary, DBRef, ObjectId, SON
import gridfs
import pymongo
import six
from six import iteritems
try:
import dateutil
@@ -36,6 +37,7 @@ from mongoengine.errors import DoesNotExist, InvalidQueryError, ValidationError
from mongoengine.python_support import StringIO
from mongoengine.queryset import DO_NOTHING
from mongoengine.queryset.base import BaseQuerySet
from mongoengine.queryset.transform import STRING_OPERATORS
try:
from PIL import Image, ImageOps
@@ -105,11 +107,11 @@ class StringField(BaseField):
if not isinstance(op, six.string_types):
return value
if op.lstrip('i') in ('startswith', 'endswith', 'contains', 'exact'):
flags = 0
if op.startswith('i'):
flags = re.IGNORECASE
op = op.lstrip('i')
if op in STRING_OPERATORS:
case_insensitive = op.startswith('i')
op = op.lstrip('i')
flags = re.IGNORECASE if case_insensitive else 0
regex = r'%s'
if op == 'startswith':
@@ -151,12 +153,10 @@ class URLField(StringField):
scheme = value.split('://')[0].lower()
if scheme not in self.schemes:
self.error(u'Invalid scheme {} in URL: {}'.format(scheme, value))
return
# Then check full URL
if not self.url_regex.match(value):
self.error(u'Invalid URL: {}'.format(value))
return
class EmailField(StringField):
@@ -258,10 +258,10 @@ class EmailField(StringField):
try:
domain_part = domain_part.encode('idna').decode('ascii')
except UnicodeError:
self.error(self.error_msg % value)
self.error("%s %s" % (self.error_msg % value, "(domain failed IDN encoding)"))
else:
if not self.validate_domain_part(domain_part):
self.error(self.error_msg % value)
self.error("%s %s" % (self.error_msg % value, "(domain validation failed)"))
class IntField(BaseField):
@@ -498,15 +498,18 @@ class DateTimeField(BaseField):
if not isinstance(value, six.string_types):
return None
return self._parse_datetime(value)
def _parse_datetime(self, value):
# Attempt to parse a datetime from a string
value = value.strip()
if not value:
return None
# Attempt to parse a datetime:
if dateutil:
try:
return dateutil.parser.parse(value)
except (TypeError, ValueError):
except (TypeError, ValueError, OverflowError):
return None
# split usecs, because they are not recognized by strptime.
@@ -699,7 +702,11 @@ class EmbeddedDocumentField(BaseField):
self.document_type.validate(value, clean)
def lookup_member(self, member_name):
return self.document_type._fields.get(member_name)
doc_and_subclasses = [self.document_type] + self.document_type.__subclasses__()
for doc_type in doc_and_subclasses:
field = doc_type._fields.get(member_name)
if field:
return field
def prepare_query_value(self, op, value):
if value is not None and not isinstance(value, self.document_type):
@@ -746,12 +753,13 @@ class GenericEmbeddedDocumentField(BaseField):
value.validate(clean=clean)
def lookup_member(self, member_name):
if self.choices:
for choice in self.choices:
field = choice._fields.get(member_name)
document_choices = self.choices or []
for document_choice in document_choices:
doc_and_subclasses = [document_choice] + document_choice.__subclasses__()
for doc_type in doc_and_subclasses:
field = doc_type._fields.get(member_name)
if field:
return field
return None
def to_mongo(self, document, use_db_field=True, fields=None):
if document is None:
@@ -794,12 +802,12 @@ class DynamicField(BaseField):
value = {k: v for k, v in enumerate(value)}
data = {}
for k, v in value.iteritems():
for k, v in iteritems(value):
data[k] = self.to_mongo(v, use_db_field, fields)
value = data
if is_list: # Convert back to a list
value = [v for k, v in sorted(data.iteritems(), key=itemgetter(0))]
value = [v for k, v in sorted(iteritems(data), key=itemgetter(0))]
return value
def to_python(self, value):

View File

@@ -0,0 +1,19 @@
"""
Helper functions, constants, and types to aid with MongoDB version support
"""
from mongoengine.connection import get_connection
# Constant that can be used to compare the version retrieved with
# get_mongodb_version()
MONGODB_34 = (3, 4)
MONGODB_36 = (3, 6)
def get_mongodb_version():
"""Return the version of the connected mongoDB (first 2 digits)
:return: tuple(int, int)
"""
version_list = get_connection().server_info()['versionArray'][:2] # e.g: (3, 2)
return tuple(version_list)

View File

@@ -0,0 +1,32 @@
"""
Helper functions, constants, and types to aid with PyMongo v2.7 - v3.x support.
"""
import pymongo
_PYMONGO_37 = (3, 7)
PYMONGO_VERSION = tuple(pymongo.version_tuple[:2])
IS_PYMONGO_GTE_37 = PYMONGO_VERSION >= _PYMONGO_37
def count_documents(collection, filter):
"""Pymongo>3.7 deprecates count in favour of count_documents"""
if IS_PYMONGO_GTE_37:
return collection.count_documents(filter)
else:
count = collection.find(filter).count()
return count
def list_collection_names(db, include_system_collections=False):
"""Pymongo>3.7 deprecates collection_names in favour of list_collection_names"""
if IS_PYMONGO_GTE_37:
collections = db.list_collection_names()
else:
collections = db.collection_names()
if not include_system_collections:
collections = [c for c in collections if not c.startswith('system.')]
return collections

View File

@@ -1,13 +1,8 @@
"""
Helper functions, constants, and types to aid with Python v2.7 - v3.x and
PyMongo v2.7 - v3.x support.
Helper functions, constants, and types to aid with Python v2.7 - v3.x support
"""
import pymongo
import six
IS_PYMONGO_3 = pymongo.version_tuple[0] >= 3
# six.BytesIO resolves to StringIO.StringIO in Py2 and io.BytesIO in Py3.
StringIO = six.BytesIO

View File

@@ -10,8 +10,10 @@ from bson import SON, json_util
from bson.code import Code
import pymongo
import pymongo.errors
from pymongo.collection import ReturnDocument
from pymongo.common import validate_read_preference
import six
from six import iteritems
from mongoengine import signals
from mongoengine.base import get_document
@@ -20,14 +22,10 @@ from mongoengine.connection import get_db
from mongoengine.context_managers import set_write_concern, switch_db
from mongoengine.errors import (InvalidQueryError, LookUpError,
NotUniqueError, OperationError)
from mongoengine.python_support import IS_PYMONGO_3
from mongoengine.queryset import transform
from mongoengine.queryset.field_list import QueryFieldList
from mongoengine.queryset.visitor import Q, QNode
if IS_PYMONGO_3:
from pymongo.collection import ReturnDocument
__all__ = ('BaseQuerySet', 'DO_NOTHING', 'NULLIFY', 'CASCADE', 'DENY', 'PULL')
@@ -196,7 +194,7 @@ class BaseQuerySet(object):
only_fields=self.only_fields
)
raise AttributeError('Provide a slice or an integer index')
raise TypeError('Provide a slice or an integer index')
def __iter__(self):
raise NotImplementedError
@@ -498,11 +496,12 @@ class BaseQuerySet(object):
``save(..., write_concern={w: 2, fsync: True}, ...)`` will
wait until at least two servers have recorded the write and
will force an fsync on the primary server.
:param full_result: Return the full result dictionary rather than just the number
updated, e.g. return
``{'n': 2, 'nModified': 2, 'ok': 1.0, 'updatedExisting': True}``.
:param full_result: Return the associated ``pymongo.UpdateResult`` rather than just the number
updated items
:param update: Django-style update keyword arguments
:returns the number of updated documents (unless ``full_result`` is True)
.. versionadded:: 0.2
"""
if not update and not upsert:
@@ -566,7 +565,7 @@ class BaseQuerySet(object):
document = self._document.objects.with_id(atomic_update.upserted_id)
return document
def update_one(self, upsert=False, write_concern=None, **update):
def update_one(self, upsert=False, write_concern=None, full_result=False, **update):
"""Perform an atomic update on the fields of the first document
matched by the query.
@@ -577,12 +576,19 @@ class BaseQuerySet(object):
``save(..., write_concern={w: 2, fsync: True}, ...)`` will
wait until at least two servers have recorded the write and
will force an fsync on the primary server.
:param full_result: Return the associated ``pymongo.UpdateResult`` rather than just the number
updated items
:param update: Django-style update keyword arguments
full_result
:returns the number of updated documents (unless ``full_result`` is True)
.. versionadded:: 0.2
"""
return self.update(
upsert=upsert, multi=False, write_concern=write_concern, **update)
upsert=upsert,
multi=False,
write_concern=write_concern,
full_result=full_result,
**update)
def modify(self, upsert=False, full_response=False, remove=False, new=False, **update):
"""Update and return the updated document.
@@ -617,31 +623,25 @@ class BaseQuerySet(object):
queryset = self.clone()
query = queryset._query
if not IS_PYMONGO_3 or not remove:
if not remove:
update = transform.update(queryset._document, **update)
sort = queryset._ordering
try:
if IS_PYMONGO_3:
if full_response:
msg = 'With PyMongo 3+, it is not possible anymore to get the full response.'
warnings.warn(msg, DeprecationWarning)
if remove:
result = queryset._collection.find_one_and_delete(
query, sort=sort, **self._cursor_args)
else:
if new:
return_doc = ReturnDocument.AFTER
else:
return_doc = ReturnDocument.BEFORE
result = queryset._collection.find_one_and_update(
query, update, upsert=upsert, sort=sort, return_document=return_doc,
**self._cursor_args)
if full_response:
msg = 'With PyMongo 3+, it is not possible anymore to get the full response.'
warnings.warn(msg, DeprecationWarning)
if remove:
result = queryset._collection.find_one_and_delete(
query, sort=sort, **self._cursor_args)
else:
result = queryset._collection.find_and_modify(
query, update, upsert=upsert, sort=sort, remove=remove, new=new,
full_response=full_response, **self._cursor_args)
if new:
return_doc = ReturnDocument.AFTER
else:
return_doc = ReturnDocument.BEFORE
result = queryset._collection.find_one_and_update(
query, update, upsert=upsert, sort=sort, return_document=return_doc,
**self._cursor_args)
except pymongo.errors.DuplicateKeyError as err:
raise NotUniqueError(u'Update failed (%s)' % err)
except pymongo.errors.OperationFailure as err:
@@ -748,7 +748,7 @@ class BaseQuerySet(object):
'_read_preference', '_iter', '_scalar', '_as_pymongo',
'_limit', '_skip', '_hint', '_auto_dereference',
'_search_text', 'only_fields', '_max_time_ms',
'_comment')
'_comment', '_batch_size')
for prop in copy_props:
val = getattr(self, prop)
@@ -1073,15 +1073,14 @@ class BaseQuerySet(object):
..versionchanged:: 0.5 - made chainable
.. deprecated:: Ignored with PyMongo 3+
"""
if IS_PYMONGO_3:
msg = 'snapshot is deprecated as it has no impact when using PyMongo 3+.'
warnings.warn(msg, DeprecationWarning)
msg = 'snapshot is deprecated as it has no impact when using PyMongo 3+.'
warnings.warn(msg, DeprecationWarning)
queryset = self.clone()
queryset._snapshot = enabled
return queryset
def timeout(self, enabled):
"""Enable or disable the default mongod timeout when querying.
"""Enable or disable the default mongod timeout when querying. (no_cursor_timeout option)
:param enabled: whether or not the timeout is used
@@ -1099,9 +1098,8 @@ class BaseQuerySet(object):
.. deprecated:: Ignored with PyMongo 3+
"""
if IS_PYMONGO_3:
msg = 'slave_okay is deprecated as it has no impact when using PyMongo 3+.'
warnings.warn(msg, DeprecationWarning)
msg = 'slave_okay is deprecated as it has no impact when using PyMongo 3+.'
warnings.warn(msg, DeprecationWarning)
queryset = self.clone()
queryset._slave_okay = enabled
return queryset
@@ -1191,14 +1189,18 @@ class BaseQuerySet(object):
initial_pipeline.append({'$sort': dict(self._ordering)})
if self._limit is not None:
initial_pipeline.append({'$limit': self._limit})
# As per MongoDB Documentation (https://docs.mongodb.com/manual/reference/operator/aggregation/limit/),
# keeping limit stage right after sort stage is more efficient. But this leads to wrong set of documents
# for a skip stage that might succeed these. So we need to maintain more documents in memory in such a
# case (https://stackoverflow.com/a/24161461).
initial_pipeline.append({'$limit': self._limit + (self._skip or 0)})
if self._skip is not None:
initial_pipeline.append({'$skip': self._skip})
pipeline = initial_pipeline + list(pipeline)
if IS_PYMONGO_3 and self._read_preference is not None:
if self._read_preference is not None:
return self._collection.with_options(read_preference=self._read_preference) \
.aggregate(pipeline, cursor={}, **kwargs)
@@ -1408,11 +1410,7 @@ class BaseQuerySet(object):
if isinstance(field_instances[-1], ListField):
pipeline.insert(1, {'$unwind': '$' + field})
result = self._document._get_collection().aggregate(pipeline)
if IS_PYMONGO_3:
result = tuple(result)
else:
result = result.get('result')
result = tuple(self._document._get_collection().aggregate(pipeline))
if result:
return result[0]['total']
@@ -1439,11 +1437,7 @@ class BaseQuerySet(object):
if isinstance(field_instances[-1], ListField):
pipeline.insert(1, {'$unwind': '$' + field})
result = self._document._get_collection().aggregate(pipeline)
if IS_PYMONGO_3:
result = tuple(result)
else:
result = result.get('result')
result = tuple(self._document._get_collection().aggregate(pipeline))
if result:
return result[0]['total']
return 0
@@ -1518,26 +1512,16 @@ class BaseQuerySet(object):
@property
def _cursor_args(self):
if not IS_PYMONGO_3:
fields_name = 'fields'
cursor_args = {
'timeout': self._timeout,
'snapshot': self._snapshot
}
if self._read_preference is not None:
cursor_args['read_preference'] = self._read_preference
else:
cursor_args['slave_okay'] = self._slave_okay
else:
fields_name = 'projection'
# snapshot is not handled at all by PyMongo 3+
# TODO: evaluate similar possibilities using modifiers
if self._snapshot:
msg = 'The snapshot option is not anymore available with PyMongo 3+'
warnings.warn(msg, DeprecationWarning)
cursor_args = {
'no_cursor_timeout': not self._timeout
}
fields_name = 'projection'
# snapshot is not handled at all by PyMongo 3+
# TODO: evaluate similar possibilities using modifiers
if self._snapshot:
msg = 'The snapshot option is not anymore available with PyMongo 3+'
warnings.warn(msg, DeprecationWarning)
cursor_args = {
'no_cursor_timeout': not self._timeout
}
if self._loaded_fields:
cursor_args[fields_name] = self._loaded_fields.as_dict()
@@ -1561,7 +1545,7 @@ class BaseQuerySet(object):
# XXX In PyMongo 3+, we define the read preference on a collection
# level, not a cursor level. Thus, we need to get a cloned collection
# object using `with_options` first.
if IS_PYMONGO_3 and self._read_preference is not None:
if self._read_preference is not None:
self._cursor_obj = self._collection\
.with_options(read_preference=self._read_preference)\
.find(self._query, **self._cursor_args)
@@ -1731,13 +1715,13 @@ class BaseQuerySet(object):
}
"""
total, data, types = self.exec_js(freq_func, field)
values = {types.get(k): int(v) for k, v in data.iteritems()}
values = {types.get(k): int(v) for k, v in iteritems(data)}
if normalize:
values = {k: float(v) / total for k, v in values.items()}
frequencies = {}
for k, v in values.iteritems():
for k, v in iteritems(values):
if isinstance(k, float):
if int(k) == k:
k = int(k)

View File

@@ -4,12 +4,11 @@ from bson import ObjectId, SON
from bson.dbref import DBRef
import pymongo
import six
from six import iteritems
from mongoengine.base import UPDATE_OPERATORS
from mongoengine.common import _import_class
from mongoengine.connection import get_connection
from mongoengine.errors import InvalidQueryError
from mongoengine.python_support import IS_PYMONGO_3
__all__ = ('query', 'update')
@@ -87,18 +86,10 @@ def query(_doc_cls=None, **kwargs):
singular_ops = [None, 'ne', 'gt', 'gte', 'lt', 'lte', 'not']
singular_ops += STRING_OPERATORS
if op in singular_ops:
if isinstance(field, six.string_types):
if (op in STRING_OPERATORS and
isinstance(value, six.string_types)):
StringField = _import_class('StringField')
value = StringField.prepare_query_value(op, value)
else:
value = field
else:
value = field.prepare_query_value(op, value)
value = field.prepare_query_value(op, value)
if isinstance(field, CachedReferenceField) and value:
value = value['_id']
if isinstance(field, CachedReferenceField) and value:
value = value['_id']
elif op in ('in', 'nin', 'all', 'near') and not isinstance(value, dict):
# Raise an error if the in/nin/all/near param is not iterable.
@@ -154,7 +145,7 @@ def query(_doc_cls=None, **kwargs):
if ('$maxDistance' in value_dict or '$minDistance' in value_dict) and \
('$near' in value_dict or '$nearSphere' in value_dict):
value_son = SON()
for k, v in value_dict.iteritems():
for k, v in iteritems(value_dict):
if k == '$maxDistance' or k == '$minDistance':
continue
value_son[k] = v
@@ -162,16 +153,14 @@ def query(_doc_cls=None, **kwargs):
# PyMongo 3+ and MongoDB < 2.6
near_embedded = False
for near_op in ('$near', '$nearSphere'):
if isinstance(value_dict.get(near_op), dict) and (
IS_PYMONGO_3 or get_connection().max_wire_version > 1):
if isinstance(value_dict.get(near_op), dict):
value_son[near_op] = SON(value_son[near_op])
if '$maxDistance' in value_dict:
value_son[near_op][
'$maxDistance'] = value_dict['$maxDistance']
value_son[near_op]['$maxDistance'] = value_dict['$maxDistance']
if '$minDistance' in value_dict:
value_son[near_op][
'$minDistance'] = value_dict['$minDistance']
value_son[near_op]['$minDistance'] = value_dict['$minDistance']
near_embedded = True
if not near_embedded:
if '$maxDistance' in value_dict:
value_son['$maxDistance'] = value_dict['$maxDistance']
@@ -280,7 +269,7 @@ def update(_doc_cls=None, **update):
if op == 'pull':
if field.required or value is not None:
if match == 'in' and not isinstance(value, dict):
if match in ('in', 'nin') and not isinstance(value, dict):
value = _prepare_query_for_iterable(field, op, value)
else:
value = field.prepare_query_value(op, value)
@@ -307,10 +296,6 @@ def update(_doc_cls=None, **update):
key = '.'.join(parts)
if not op:
raise InvalidQueryError('Updates must supply an operation '
'eg: set__FIELD=value')
if 'pull' in op and '.' in key:
# Dot operators don't work on pull operations
# unless they point to a list field