Merge branch 'master' into remove_save_embedded
This commit is contained in:
@@ -23,7 +23,7 @@ __all__ = (list(document.__all__) + list(fields.__all__) +
|
||||
list(signals.__all__) + list(errors.__all__))
|
||||
|
||||
|
||||
VERSION = (0, 16, 3)
|
||||
VERSION = (0, 17, 0)
|
||||
|
||||
|
||||
def get_version():
|
||||
|
||||
@@ -13,7 +13,7 @@ _document_registry = {}
|
||||
|
||||
|
||||
def get_document(name):
|
||||
"""Get a document class by name."""
|
||||
"""Get a registered Document class by name."""
|
||||
doc = _document_registry.get(name, None)
|
||||
if not doc:
|
||||
# Possible old style name
|
||||
@@ -30,3 +30,12 @@ def get_document(name):
|
||||
been imported?
|
||||
""".strip() % name)
|
||||
return doc
|
||||
|
||||
|
||||
def _get_documents_by_db(connection_alias, default_connection_alias):
|
||||
"""Get all registered Documents class attached to a given database"""
|
||||
def get_doc_alias(doc_cls):
|
||||
return doc_cls._meta.get('db_alias', default_connection_alias)
|
||||
|
||||
return [doc_cls for doc_cls in _document_registry.values()
|
||||
if get_doc_alias(doc_cls) == connection_alias]
|
||||
|
||||
@@ -2,6 +2,7 @@ import weakref
|
||||
|
||||
from bson import DBRef
|
||||
import six
|
||||
from six import iteritems
|
||||
|
||||
from mongoengine.common import _import_class
|
||||
from mongoengine.errors import DoesNotExist, MultipleObjectsReturned
|
||||
@@ -363,7 +364,7 @@ class StrictDict(object):
|
||||
_classes = {}
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
for k, v in kwargs.iteritems():
|
||||
for k, v in iteritems(kwargs):
|
||||
setattr(self, k, v)
|
||||
|
||||
def __getitem__(self, key):
|
||||
@@ -411,7 +412,7 @@ class StrictDict(object):
|
||||
return (key for key in self.__slots__ if hasattr(self, key))
|
||||
|
||||
def __len__(self):
|
||||
return len(list(self.iteritems()))
|
||||
return len(list(iteritems(self)))
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.items() == other.items()
|
||||
|
||||
@@ -5,6 +5,7 @@ from functools import partial
|
||||
from bson import DBRef, ObjectId, SON, json_util
|
||||
import pymongo
|
||||
import six
|
||||
from six import iteritems
|
||||
|
||||
from mongoengine import signals
|
||||
from mongoengine.base.common import get_document
|
||||
@@ -83,7 +84,7 @@ class BaseDocument(object):
|
||||
self._dynamic_fields = SON()
|
||||
|
||||
# Assign default values to instance
|
||||
for key, field in self._fields.iteritems():
|
||||
for key, field in iteritems(self._fields):
|
||||
if self._db_field_map.get(key, key) in __only_fields:
|
||||
continue
|
||||
value = getattr(self, key, None)
|
||||
@@ -95,14 +96,14 @@ class BaseDocument(object):
|
||||
# Set passed values after initialisation
|
||||
if self._dynamic:
|
||||
dynamic_data = {}
|
||||
for key, value in values.iteritems():
|
||||
for key, value in iteritems(values):
|
||||
if key in self._fields or key == '_id':
|
||||
setattr(self, key, value)
|
||||
else:
|
||||
dynamic_data[key] = value
|
||||
else:
|
||||
FileField = _import_class('FileField')
|
||||
for key, value in values.iteritems():
|
||||
for key, value in iteritems(values):
|
||||
key = self._reverse_db_field_map.get(key, key)
|
||||
if key in self._fields or key in ('id', 'pk', '_cls'):
|
||||
if __auto_convert and value is not None:
|
||||
@@ -118,7 +119,7 @@ class BaseDocument(object):
|
||||
|
||||
if self._dynamic:
|
||||
self._dynamic_lock = False
|
||||
for key, value in dynamic_data.iteritems():
|
||||
for key, value in iteritems(dynamic_data):
|
||||
setattr(self, key, value)
|
||||
|
||||
# Flag initialised
|
||||
@@ -292,8 +293,7 @@ class BaseDocument(object):
|
||||
"""
|
||||
Return as SON data ready for use with MongoDB.
|
||||
"""
|
||||
if not fields:
|
||||
fields = []
|
||||
fields = fields or []
|
||||
|
||||
data = SON()
|
||||
data['_id'] = None
|
||||
@@ -513,7 +513,7 @@ class BaseDocument(object):
|
||||
if not hasattr(data, 'items'):
|
||||
iterator = enumerate(data)
|
||||
else:
|
||||
iterator = data.iteritems()
|
||||
iterator = iteritems(data)
|
||||
|
||||
for index_or_key, value in iterator:
|
||||
item_key = '%s%s.' % (base_key, index_or_key)
|
||||
@@ -678,7 +678,7 @@ class BaseDocument(object):
|
||||
# Convert SON to a data dict, making sure each key is a string and
|
||||
# corresponds to the right db field.
|
||||
data = {}
|
||||
for key, value in son.iteritems():
|
||||
for key, value in iteritems(son):
|
||||
key = str(key)
|
||||
key = cls._db_field_map.get(key, key)
|
||||
data[key] = value
|
||||
@@ -694,7 +694,7 @@ class BaseDocument(object):
|
||||
if not _auto_dereference:
|
||||
fields = copy.deepcopy(fields)
|
||||
|
||||
for field_name, field in fields.iteritems():
|
||||
for field_name, field in iteritems(fields):
|
||||
field._auto_dereference = _auto_dereference
|
||||
if field.db_field in data:
|
||||
value = data[field.db_field]
|
||||
@@ -715,7 +715,7 @@ class BaseDocument(object):
|
||||
|
||||
# In STRICT documents, remove any keys that aren't in cls._fields
|
||||
if cls.STRICT:
|
||||
data = {k: v for k, v in data.iteritems() if k in cls._fields}
|
||||
data = {k: v for k, v in iteritems(data) if k in cls._fields}
|
||||
|
||||
obj = cls(__auto_convert=False, _created=created, __only_fields=only_fields, **data)
|
||||
obj._changed_fields = changed_fields
|
||||
@@ -882,7 +882,8 @@ class BaseDocument(object):
|
||||
index = {'fields': fields, 'unique': True, 'sparse': sparse}
|
||||
unique_indexes.append(index)
|
||||
|
||||
if field.__class__.__name__ == 'ListField':
|
||||
if field.__class__.__name__ in {'EmbeddedDocumentListField',
|
||||
'ListField', 'SortedListField'}:
|
||||
field = field.field
|
||||
|
||||
# Grab any embedded document field unique indexes
|
||||
|
||||
@@ -5,13 +5,13 @@ import weakref
|
||||
from bson import DBRef, ObjectId, SON
|
||||
import pymongo
|
||||
import six
|
||||
from six import iteritems
|
||||
|
||||
from mongoengine.base.common import UPDATE_OPERATORS
|
||||
from mongoengine.base.datastructures import (BaseDict, BaseList,
|
||||
EmbeddedDocumentList)
|
||||
from mongoengine.common import _import_class
|
||||
from mongoengine.errors import ValidationError
|
||||
|
||||
from mongoengine.errors import DeprecatedError, ValidationError
|
||||
|
||||
__all__ = ('BaseField', 'ComplexBaseField', 'ObjectIdField',
|
||||
'GeoJsonBaseField')
|
||||
@@ -52,8 +52,8 @@ class BaseField(object):
|
||||
unique with.
|
||||
:param primary_key: Mark this field as the primary key. Defaults to False.
|
||||
:param validation: (optional) A callable to validate the value of the
|
||||
field. Generally this is deprecated in favour of the
|
||||
`FIELD.validate` method
|
||||
field. The callable takes the value as parameter and should raise
|
||||
a ValidationError if validation fails
|
||||
:param choices: (optional) The valid choices
|
||||
:param null: (optional) If the field value can be null. If no and there is a default value
|
||||
then the default value is set
|
||||
@@ -225,10 +225,18 @@ class BaseField(object):
|
||||
# check validation argument
|
||||
if self.validation is not None:
|
||||
if callable(self.validation):
|
||||
if not self.validation(value):
|
||||
self.error('Value does not match custom validation method')
|
||||
try:
|
||||
# breaking change of 0.18
|
||||
# Get rid of True/False-type return for the validation method
|
||||
# in favor of having validation raising a ValidationError
|
||||
ret = self.validation(value)
|
||||
if ret is not None:
|
||||
raise DeprecatedError('validation argument for `%s` must not return anything, '
|
||||
'it should raise a ValidationError if validation fails' % self.name)
|
||||
except ValidationError as ex:
|
||||
self.error(str(ex))
|
||||
else:
|
||||
raise ValueError('validation argument for "%s" must be a '
|
||||
raise ValueError('validation argument for `"%s"` must be a '
|
||||
'callable.' % self.name)
|
||||
|
||||
self.validate(value, **kwargs)
|
||||
@@ -275,11 +283,16 @@ class ComplexBaseField(BaseField):
|
||||
|
||||
_dereference = _import_class('DeReference')()
|
||||
|
||||
if instance._initialised and dereference and instance._data.get(self.name):
|
||||
if (instance._initialised and
|
||||
dereference and
|
||||
instance._data.get(self.name) and
|
||||
not getattr(instance._data[self.name], '_dereferenced', False)):
|
||||
instance._data[self.name] = _dereference(
|
||||
instance._data.get(self.name), max_depth=1, instance=instance,
|
||||
name=self.name
|
||||
)
|
||||
if hasattr(instance._data[self.name], '_dereferenced'):
|
||||
instance._data[self.name]._dereferenced = True
|
||||
|
||||
value = super(ComplexBaseField, self).__get__(instance, owner)
|
||||
|
||||
@@ -382,11 +395,11 @@ class ComplexBaseField(BaseField):
|
||||
if self.field:
|
||||
value_dict = {
|
||||
key: self.field._to_mongo_safe_call(item, use_db_field, fields)
|
||||
for key, item in value.iteritems()
|
||||
for key, item in iteritems(value)
|
||||
}
|
||||
else:
|
||||
value_dict = {}
|
||||
for k, v in value.iteritems():
|
||||
for k, v in iteritems(value):
|
||||
if isinstance(v, Document):
|
||||
# We need the id from the saved object to create the DBRef
|
||||
if v.pk is None:
|
||||
@@ -423,7 +436,7 @@ class ComplexBaseField(BaseField):
|
||||
errors = {}
|
||||
if self.field:
|
||||
if hasattr(value, 'iteritems') or hasattr(value, 'items'):
|
||||
sequence = value.iteritems()
|
||||
sequence = iteritems(value)
|
||||
else:
|
||||
sequence = enumerate(value)
|
||||
for k, v in sequence:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import warnings
|
||||
|
||||
import six
|
||||
from six import iteritems, itervalues
|
||||
|
||||
from mongoengine.base.common import _document_registry
|
||||
from mongoengine.base.fields import BaseField, ComplexBaseField, ObjectIdField
|
||||
@@ -62,7 +63,7 @@ class DocumentMetaclass(type):
|
||||
# Standard object mixin - merge in any Fields
|
||||
if not hasattr(base, '_meta'):
|
||||
base_fields = {}
|
||||
for attr_name, attr_value in base.__dict__.iteritems():
|
||||
for attr_name, attr_value in iteritems(base.__dict__):
|
||||
if not isinstance(attr_value, BaseField):
|
||||
continue
|
||||
attr_value.name = attr_name
|
||||
@@ -74,7 +75,7 @@ class DocumentMetaclass(type):
|
||||
|
||||
# Discover any document fields
|
||||
field_names = {}
|
||||
for attr_name, attr_value in attrs.iteritems():
|
||||
for attr_name, attr_value in iteritems(attrs):
|
||||
if not isinstance(attr_value, BaseField):
|
||||
continue
|
||||
attr_value.name = attr_name
|
||||
@@ -103,7 +104,7 @@ class DocumentMetaclass(type):
|
||||
|
||||
attrs['_fields_ordered'] = tuple(i[1] for i in sorted(
|
||||
(v.creation_counter, v.name)
|
||||
for v in doc_fields.itervalues()))
|
||||
for v in itervalues(doc_fields)))
|
||||
|
||||
#
|
||||
# Set document hierarchy
|
||||
@@ -173,7 +174,7 @@ class DocumentMetaclass(type):
|
||||
f.__dict__.update({'im_self': getattr(f, '__self__')})
|
||||
|
||||
# Handle delete rules
|
||||
for field in new_class._fields.itervalues():
|
||||
for field in itervalues(new_class._fields):
|
||||
f = field
|
||||
if f.owner_document is None:
|
||||
f.owner_document = new_class
|
||||
@@ -183,9 +184,6 @@ class DocumentMetaclass(type):
|
||||
if issubclass(new_class, EmbeddedDocument):
|
||||
raise InvalidDocumentError('CachedReferenceFields is not '
|
||||
'allowed in EmbeddedDocuments')
|
||||
if not f.document_type:
|
||||
raise InvalidDocumentError(
|
||||
'Document is not available to sync')
|
||||
|
||||
if f.auto_sync:
|
||||
f.start_listener()
|
||||
@@ -375,7 +373,7 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
|
||||
new_class.objects = QuerySetManager()
|
||||
|
||||
# Validate the fields and set primary key if needed
|
||||
for field_name, field in new_class._fields.iteritems():
|
||||
for field_name, field in iteritems(new_class._fields):
|
||||
if field.primary_key:
|
||||
# Ensure only one primary key is set
|
||||
current_pk = new_class._meta.get('id_field')
|
||||
@@ -438,7 +436,7 @@ class MetaDict(dict):
|
||||
_merge_options = ('indexes',)
|
||||
|
||||
def merge(self, new_options):
|
||||
for k, v in new_options.iteritems():
|
||||
for k, v in iteritems(new_options):
|
||||
if k in self._merge_options:
|
||||
self[k] = self.get(k, []) + v
|
||||
else:
|
||||
|
||||
@@ -31,7 +31,6 @@ def _import_class(cls_name):
|
||||
|
||||
field_classes = _field_list_cache
|
||||
|
||||
queryset_classes = ('OperationError',)
|
||||
deref_classes = ('DeReference',)
|
||||
|
||||
if cls_name == 'BaseDocument':
|
||||
@@ -43,14 +42,11 @@ def _import_class(cls_name):
|
||||
elif cls_name in field_classes:
|
||||
from mongoengine import fields as module
|
||||
import_classes = field_classes
|
||||
elif cls_name in queryset_classes:
|
||||
from mongoengine import queryset as module
|
||||
import_classes = queryset_classes
|
||||
elif cls_name in deref_classes:
|
||||
from mongoengine import dereference as module
|
||||
import_classes = deref_classes
|
||||
else:
|
||||
raise ValueError('No import set for: ' % cls_name)
|
||||
raise ValueError('No import set for: %s' % cls_name)
|
||||
|
||||
for cls in import_classes:
|
||||
_class_registry_cache[cls] = getattr(module, cls)
|
||||
|
||||
@@ -1,19 +1,22 @@
|
||||
from pymongo import MongoClient, ReadPreference, uri_parser
|
||||
from pymongo.database import _check_name
|
||||
import six
|
||||
|
||||
from mongoengine.python_support import IS_PYMONGO_3
|
||||
|
||||
__all__ = ['MongoEngineConnectionError', 'connect', 'register_connection',
|
||||
'DEFAULT_CONNECTION_NAME']
|
||||
__all__ = ['MongoEngineConnectionError', 'connect', 'disconnect', 'disconnect_all',
|
||||
'register_connection', 'DEFAULT_CONNECTION_NAME', 'DEFAULT_DATABASE_NAME',
|
||||
'get_db', 'get_connection']
|
||||
|
||||
|
||||
DEFAULT_CONNECTION_NAME = 'default'
|
||||
DEFAULT_DATABASE_NAME = 'test'
|
||||
DEFAULT_HOST = 'localhost'
|
||||
DEFAULT_PORT = 27017
|
||||
|
||||
if IS_PYMONGO_3:
|
||||
READ_PREFERENCE = ReadPreference.PRIMARY
|
||||
else:
|
||||
from pymongo import MongoReplicaSetClient
|
||||
READ_PREFERENCE = False
|
||||
_connection_settings = {}
|
||||
_connections = {}
|
||||
_dbs = {}
|
||||
|
||||
READ_PREFERENCE = ReadPreference.PRIMARY
|
||||
|
||||
|
||||
class MongoEngineConnectionError(Exception):
|
||||
@@ -23,45 +26,48 @@ class MongoEngineConnectionError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
_connection_settings = {}
|
||||
_connections = {}
|
||||
_dbs = {}
|
||||
def _check_db_name(name):
|
||||
"""Check if a database name is valid.
|
||||
This functionality is copied from pymongo Database class constructor.
|
||||
"""
|
||||
if not isinstance(name, six.string_types):
|
||||
raise TypeError('name must be an instance of %s' % six.string_types)
|
||||
elif name != '$external':
|
||||
_check_name(name)
|
||||
|
||||
|
||||
def register_connection(alias, db=None, name=None, host=None, port=None,
|
||||
read_preference=READ_PREFERENCE,
|
||||
username=None, password=None,
|
||||
authentication_source=None,
|
||||
authentication_mechanism=None,
|
||||
**kwargs):
|
||||
"""Add a connection.
|
||||
def _get_connection_settings(
|
||||
db=None, name=None, host=None, port=None,
|
||||
read_preference=READ_PREFERENCE,
|
||||
username=None, password=None,
|
||||
authentication_source=None,
|
||||
authentication_mechanism=None,
|
||||
**kwargs):
|
||||
"""Get the connection settings as a dict
|
||||
|
||||
:param alias: the name that will be used to refer to this connection
|
||||
throughout MongoEngine
|
||||
:param name: the name of the specific database to use
|
||||
:param db: the name of the database to use, for compatibility with connect
|
||||
:param host: the host name of the :program:`mongod` instance to connect to
|
||||
:param port: the port that the :program:`mongod` instance is running on
|
||||
:param read_preference: The read preference for the collection
|
||||
** Added pymongo 2.1
|
||||
:param username: username to authenticate with
|
||||
:param password: password to authenticate with
|
||||
:param authentication_source: database to authenticate against
|
||||
:param authentication_mechanism: database authentication mechanisms.
|
||||
: param db: the name of the database to use, for compatibility with connect
|
||||
: param name: the name of the specific database to use
|
||||
: param host: the host name of the: program: `mongod` instance to connect to
|
||||
: param port: the port that the: program: `mongod` instance is running on
|
||||
: param read_preference: The read preference for the collection
|
||||
: param username: username to authenticate with
|
||||
: param password: password to authenticate with
|
||||
: param authentication_source: database to authenticate against
|
||||
: param authentication_mechanism: database authentication mechanisms.
|
||||
By default, use SCRAM-SHA-1 with MongoDB 3.0 and later,
|
||||
MONGODB-CR (MongoDB Challenge Response protocol) for older servers.
|
||||
:param is_mock: explicitly use mongomock for this connection
|
||||
(can also be done by using `mongomock://` as db host prefix)
|
||||
:param kwargs: ad-hoc parameters to be passed into the pymongo driver,
|
||||
: param is_mock: explicitly use mongomock for this connection
|
||||
(can also be done by using `mongomock: // ` as db host prefix)
|
||||
: param kwargs: ad-hoc parameters to be passed into the pymongo driver,
|
||||
for example maxpoolsize, tz_aware, etc. See the documentation
|
||||
for pymongo's `MongoClient` for a full list.
|
||||
|
||||
.. versionchanged:: 0.10.6 - added mongomock support
|
||||
"""
|
||||
conn_settings = {
|
||||
'name': name or db or 'test',
|
||||
'host': host or 'localhost',
|
||||
'port': port or 27017,
|
||||
'name': name or db or DEFAULT_DATABASE_NAME,
|
||||
'host': host or DEFAULT_HOST,
|
||||
'port': port or DEFAULT_PORT,
|
||||
'read_preference': read_preference,
|
||||
'username': username,
|
||||
'password': password,
|
||||
@@ -69,6 +75,7 @@ def register_connection(alias, db=None, name=None, host=None, port=None,
|
||||
'authentication_mechanism': authentication_mechanism
|
||||
}
|
||||
|
||||
_check_db_name(conn_settings['name'])
|
||||
conn_host = conn_settings['host']
|
||||
|
||||
# Host can be a list or a string, so if string, force to a list.
|
||||
@@ -104,16 +111,28 @@ def register_connection(alias, db=None, name=None, host=None, port=None,
|
||||
conn_settings['authentication_source'] = uri_options['authsource']
|
||||
if 'authmechanism' in uri_options:
|
||||
conn_settings['authentication_mechanism'] = uri_options['authmechanism']
|
||||
if IS_PYMONGO_3 and 'readpreference' in uri_options:
|
||||
if 'readpreference' in uri_options:
|
||||
read_preferences = (
|
||||
ReadPreference.NEAREST,
|
||||
ReadPreference.PRIMARY,
|
||||
ReadPreference.PRIMARY_PREFERRED,
|
||||
ReadPreference.SECONDARY,
|
||||
ReadPreference.SECONDARY_PREFERRED)
|
||||
read_pf_mode = uri_options['readpreference'].lower()
|
||||
ReadPreference.SECONDARY_PREFERRED,
|
||||
)
|
||||
|
||||
# Starting with PyMongo v3.5, the "readpreference" option is
|
||||
# returned as a string (e.g. "secondaryPreferred") and not an
|
||||
# int (e.g. 3).
|
||||
# TODO simplify the code below once we drop support for
|
||||
# PyMongo v3.4.
|
||||
read_pf_mode = uri_options['readpreference']
|
||||
if isinstance(read_pf_mode, six.string_types):
|
||||
read_pf_mode = read_pf_mode.lower()
|
||||
for preference in read_preferences:
|
||||
if preference.name.lower() == read_pf_mode:
|
||||
if (
|
||||
preference.name.lower() == read_pf_mode or
|
||||
preference.mode == read_pf_mode
|
||||
):
|
||||
conn_settings['read_preference'] = preference
|
||||
break
|
||||
else:
|
||||
@@ -125,17 +144,74 @@ def register_connection(alias, db=None, name=None, host=None, port=None,
|
||||
kwargs.pop('is_slave', None)
|
||||
|
||||
conn_settings.update(kwargs)
|
||||
return conn_settings
|
||||
|
||||
|
||||
def register_connection(alias, db=None, name=None, host=None, port=None,
|
||||
read_preference=READ_PREFERENCE,
|
||||
username=None, password=None,
|
||||
authentication_source=None,
|
||||
authentication_mechanism=None,
|
||||
**kwargs):
|
||||
"""Register the connection settings.
|
||||
|
||||
: param alias: the name that will be used to refer to this connection
|
||||
throughout MongoEngine
|
||||
: param name: the name of the specific database to use
|
||||
: param db: the name of the database to use, for compatibility with connect
|
||||
: param host: the host name of the: program: `mongod` instance to connect to
|
||||
: param port: the port that the: program: `mongod` instance is running on
|
||||
: param read_preference: The read preference for the collection
|
||||
: param username: username to authenticate with
|
||||
: param password: password to authenticate with
|
||||
: param authentication_source: database to authenticate against
|
||||
: param authentication_mechanism: database authentication mechanisms.
|
||||
By default, use SCRAM-SHA-1 with MongoDB 3.0 and later,
|
||||
MONGODB-CR (MongoDB Challenge Response protocol) for older servers.
|
||||
: param is_mock: explicitly use mongomock for this connection
|
||||
(can also be done by using `mongomock: // ` as db host prefix)
|
||||
: param kwargs: ad-hoc parameters to be passed into the pymongo driver,
|
||||
for example maxpoolsize, tz_aware, etc. See the documentation
|
||||
for pymongo's `MongoClient` for a full list.
|
||||
|
||||
.. versionchanged:: 0.10.6 - added mongomock support
|
||||
"""
|
||||
conn_settings = _get_connection_settings(
|
||||
db=db, name=name, host=host, port=port,
|
||||
read_preference=read_preference,
|
||||
username=username, password=password,
|
||||
authentication_source=authentication_source,
|
||||
authentication_mechanism=authentication_mechanism,
|
||||
**kwargs)
|
||||
_connection_settings[alias] = conn_settings
|
||||
|
||||
|
||||
def disconnect(alias=DEFAULT_CONNECTION_NAME):
|
||||
"""Close the connection with a given alias."""
|
||||
from mongoengine.base.common import _get_documents_by_db
|
||||
from mongoengine import Document
|
||||
|
||||
if alias in _connections:
|
||||
get_connection(alias=alias).close()
|
||||
del _connections[alias]
|
||||
|
||||
if alias in _dbs:
|
||||
# Detach all cached collections in Documents
|
||||
for doc_cls in _get_documents_by_db(alias, DEFAULT_CONNECTION_NAME):
|
||||
if issubclass(doc_cls, Document): # Skip EmbeddedDocument
|
||||
doc_cls._disconnect()
|
||||
|
||||
del _dbs[alias]
|
||||
|
||||
if alias in _connection_settings:
|
||||
del _connection_settings[alias]
|
||||
|
||||
|
||||
def disconnect_all():
|
||||
"""Close all registered database."""
|
||||
for alias in list(_connections.keys()):
|
||||
disconnect(alias)
|
||||
|
||||
|
||||
def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
|
||||
"""Return a connection with a given alias."""
|
||||
@@ -159,7 +235,6 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
|
||||
raise MongoEngineConnectionError(msg)
|
||||
|
||||
def _clean_settings(settings_dict):
|
||||
# set literal more efficient than calling set function
|
||||
irrelevant_fields_set = {
|
||||
'name', 'username', 'password',
|
||||
'authentication_source', 'authentication_mechanism'
|
||||
@@ -169,10 +244,12 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
|
||||
if k not in irrelevant_fields_set
|
||||
}
|
||||
|
||||
raw_conn_settings = _connection_settings[alias].copy()
|
||||
|
||||
# Retrieve a copy of the connection settings associated with the requested
|
||||
# alias and remove the database name and authentication info (we don't
|
||||
# care about them at this point).
|
||||
conn_settings = _clean_settings(_connection_settings[alias].copy())
|
||||
conn_settings = _clean_settings(raw_conn_settings)
|
||||
|
||||
# Determine if we should use PyMongo's or mongomock's MongoClient.
|
||||
is_mock = conn_settings.pop('is_mock', False)
|
||||
@@ -186,51 +263,60 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
|
||||
else:
|
||||
connection_class = MongoClient
|
||||
|
||||
# For replica set connections with PyMongo 2.x, use
|
||||
# MongoReplicaSetClient.
|
||||
# TODO remove this once we stop supporting PyMongo 2.x.
|
||||
if 'replicaSet' in conn_settings and not IS_PYMONGO_3:
|
||||
connection_class = MongoReplicaSetClient
|
||||
conn_settings['hosts_or_uri'] = conn_settings.pop('host', None)
|
||||
|
||||
# hosts_or_uri has to be a string, so if 'host' was provided
|
||||
# as a list, join its parts and separate them by ','
|
||||
if isinstance(conn_settings['hosts_or_uri'], list):
|
||||
conn_settings['hosts_or_uri'] = ','.join(
|
||||
conn_settings['hosts_or_uri'])
|
||||
|
||||
# Discard port since it can't be used on MongoReplicaSetClient
|
||||
conn_settings.pop('port', None)
|
||||
|
||||
# Iterate over all of the connection settings and if a connection with
|
||||
# the same parameters is already established, use it instead of creating
|
||||
# a new one.
|
||||
existing_connection = None
|
||||
connection_settings_iterator = (
|
||||
(db_alias, settings.copy())
|
||||
for db_alias, settings in _connection_settings.items()
|
||||
)
|
||||
for db_alias, connection_settings in connection_settings_iterator:
|
||||
connection_settings = _clean_settings(connection_settings)
|
||||
if conn_settings == connection_settings and _connections.get(db_alias):
|
||||
existing_connection = _connections[db_alias]
|
||||
break
|
||||
# Re-use existing connection if one is suitable
|
||||
existing_connection = _find_existing_connection(raw_conn_settings)
|
||||
|
||||
# If an existing connection was found, assign it to the new alias
|
||||
if existing_connection:
|
||||
_connections[alias] = existing_connection
|
||||
else:
|
||||
# Otherwise, create the new connection for this alias. Raise
|
||||
# MongoEngineConnectionError if it can't be established.
|
||||
try:
|
||||
_connections[alias] = connection_class(**conn_settings)
|
||||
except Exception as e:
|
||||
raise MongoEngineConnectionError(
|
||||
'Cannot connect to database %s :\n%s' % (alias, e))
|
||||
_connections[alias] = _create_connection(alias=alias,
|
||||
connection_class=connection_class,
|
||||
**conn_settings)
|
||||
|
||||
return _connections[alias]
|
||||
|
||||
|
||||
def _create_connection(alias, connection_class, **connection_settings):
|
||||
"""
|
||||
Create the new connection for this alias. Raise
|
||||
MongoEngineConnectionError if it can't be established.
|
||||
"""
|
||||
try:
|
||||
return connection_class(**connection_settings)
|
||||
except Exception as e:
|
||||
raise MongoEngineConnectionError(
|
||||
'Cannot connect to database %s :\n%s' % (alias, e))
|
||||
|
||||
|
||||
def _find_existing_connection(connection_settings):
|
||||
"""
|
||||
Check if an existing connection could be reused
|
||||
|
||||
Iterate over all of the connection settings and if an existing connection
|
||||
with the same parameters is suitable, return it
|
||||
|
||||
:param connection_settings: the settings of the new connection
|
||||
:return: An existing connection or None
|
||||
"""
|
||||
connection_settings_bis = (
|
||||
(db_alias, settings.copy())
|
||||
for db_alias, settings in _connection_settings.items()
|
||||
)
|
||||
|
||||
def _clean_settings(settings_dict):
|
||||
# Only remove the name but it's important to
|
||||
# keep the username/password/authentication_source/authentication_mechanism
|
||||
# to identify if the connection could be shared (cfr https://github.com/MongoEngine/mongoengine/issues/2047)
|
||||
return {k: v for k, v in settings_dict.items() if k != 'name'}
|
||||
|
||||
cleaned_conn_settings = _clean_settings(connection_settings)
|
||||
for db_alias, connection_settings in connection_settings_bis:
|
||||
db_conn_settings = _clean_settings(connection_settings)
|
||||
if cleaned_conn_settings == db_conn_settings and _connections.get(db_alias):
|
||||
return _connections[db_alias]
|
||||
|
||||
|
||||
def get_db(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
|
||||
if reconnect:
|
||||
disconnect(alias)
|
||||
@@ -258,14 +344,24 @@ def connect(db=None, alias=DEFAULT_CONNECTION_NAME, **kwargs):
|
||||
provide username and password arguments as well.
|
||||
|
||||
Multiple databases are supported by using aliases. Provide a separate
|
||||
`alias` to connect to a different instance of :program:`mongod`.
|
||||
`alias` to connect to a different instance of: program: `mongod`.
|
||||
|
||||
In order to replace a connection identified by a given alias, you'll
|
||||
need to call ``disconnect`` first
|
||||
|
||||
See the docstring for `register_connection` for more details about all
|
||||
supported kwargs.
|
||||
|
||||
.. versionchanged:: 0.6 - added multiple database support.
|
||||
"""
|
||||
if alias not in _connections:
|
||||
if alias in _connections:
|
||||
prev_conn_setting = _connection_settings[alias]
|
||||
new_conn_settings = _get_connection_settings(db, **kwargs)
|
||||
|
||||
if new_conn_settings != prev_conn_setting:
|
||||
raise MongoEngineConnectionError(
|
||||
'A different connection with alias `%s` was already registered. Use disconnect() first' % alias)
|
||||
else:
|
||||
register_connection(alias, db, **kwargs)
|
||||
|
||||
return get_connection(alias)
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
from contextlib import contextmanager
|
||||
|
||||
from pymongo.write_concern import WriteConcern
|
||||
from six import iteritems
|
||||
|
||||
from mongoengine.common import _import_class
|
||||
from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db
|
||||
|
||||
from mongoengine.pymongo_support import count_documents
|
||||
|
||||
__all__ = ('switch_db', 'switch_collection', 'no_dereference',
|
||||
'no_sub_classes', 'query_counter', 'set_write_concern')
|
||||
@@ -112,7 +115,7 @@ class no_dereference(object):
|
||||
GenericReferenceField = _import_class('GenericReferenceField')
|
||||
ComplexBaseField = _import_class('ComplexBaseField')
|
||||
|
||||
self.deref_fields = [k for k, v in self.cls._fields.iteritems()
|
||||
self.deref_fields = [k for k, v in iteritems(self.cls._fields)
|
||||
if isinstance(v, (ReferenceField,
|
||||
GenericReferenceField,
|
||||
ComplexBaseField))]
|
||||
@@ -235,7 +238,7 @@ class query_counter(object):
|
||||
and substracting the queries issued by this context. In fact everytime this is called, 1 query is
|
||||
issued so we need to balance that
|
||||
"""
|
||||
count = self.db.system.profile.find(self._ignored_query).count() - self._ctx_query_counter
|
||||
count = count_documents(self.db.system.profile, self._ignored_query) - self._ctx_query_counter
|
||||
self._ctx_query_counter += 1 # Account for the query we just issued to gather the information
|
||||
return count
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from bson import DBRef, SON
|
||||
import six
|
||||
from six import iteritems
|
||||
|
||||
from mongoengine.base import (BaseDict, BaseList, EmbeddedDocumentList,
|
||||
TopLevelDocumentMetaclass, get_document)
|
||||
@@ -71,7 +72,7 @@ class DeReference(object):
|
||||
|
||||
def _get_items_from_dict(items):
|
||||
new_items = {}
|
||||
for k, v in items.iteritems():
|
||||
for k, v in iteritems(items):
|
||||
value = v
|
||||
if isinstance(v, list):
|
||||
value = _get_items_from_list(v)
|
||||
@@ -112,7 +113,7 @@ class DeReference(object):
|
||||
depth += 1
|
||||
for item in iterator:
|
||||
if isinstance(item, (Document, EmbeddedDocument)):
|
||||
for field_name, field in item._fields.iteritems():
|
||||
for field_name, field in iteritems(item._fields):
|
||||
v = item._data.get(field_name, None)
|
||||
if isinstance(v, LazyReference):
|
||||
# LazyReference inherits DBRef but should not be dereferenced here !
|
||||
@@ -124,7 +125,7 @@ class DeReference(object):
|
||||
elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth:
|
||||
field_cls = getattr(getattr(field, 'field', None), 'document_type', None)
|
||||
references = self._find_references(v, depth)
|
||||
for key, refs in references.iteritems():
|
||||
for key, refs in iteritems(references):
|
||||
if isinstance(field_cls, (Document, TopLevelDocumentMetaclass)):
|
||||
key = field_cls
|
||||
reference_map.setdefault(key, set()).update(refs)
|
||||
@@ -137,7 +138,7 @@ class DeReference(object):
|
||||
reference_map.setdefault(get_document(item['_cls']), set()).add(item['_ref'].id)
|
||||
elif isinstance(item, (dict, list, tuple)) and depth - 1 <= self.max_depth:
|
||||
references = self._find_references(item, depth - 1)
|
||||
for key, refs in references.iteritems():
|
||||
for key, refs in iteritems(references):
|
||||
reference_map.setdefault(key, set()).update(refs)
|
||||
|
||||
return reference_map
|
||||
@@ -146,7 +147,7 @@ class DeReference(object):
|
||||
"""Fetch all references and convert to their document objects
|
||||
"""
|
||||
object_map = {}
|
||||
for collection, dbrefs in self.reference_map.iteritems():
|
||||
for collection, dbrefs in iteritems(self.reference_map):
|
||||
|
||||
# we use getattr instead of hasattr because hasattr swallows any exception under python2
|
||||
# so it could hide nasty things without raising exceptions (cfr bug #1688))
|
||||
@@ -157,7 +158,7 @@ class DeReference(object):
|
||||
refs = [dbref for dbref in dbrefs
|
||||
if (col_name, dbref) not in object_map]
|
||||
references = collection.objects.in_bulk(refs)
|
||||
for key, doc in references.iteritems():
|
||||
for key, doc in iteritems(references):
|
||||
object_map[(col_name, key)] = doc
|
||||
else: # Generic reference: use the refs data to convert to document
|
||||
if isinstance(doc_type, (ListField, DictField, MapField)):
|
||||
@@ -229,7 +230,7 @@ class DeReference(object):
|
||||
data = []
|
||||
else:
|
||||
is_list = False
|
||||
iterator = items.iteritems()
|
||||
iterator = iteritems(items)
|
||||
data = {}
|
||||
|
||||
depth += 1
|
||||
|
||||
@@ -5,6 +5,7 @@ from bson.dbref import DBRef
|
||||
import pymongo
|
||||
from pymongo.read_preferences import ReadPreference
|
||||
import six
|
||||
from six import iteritems
|
||||
|
||||
from mongoengine import signals
|
||||
from mongoengine.base import (BaseDict, BaseDocument, BaseList,
|
||||
@@ -17,7 +18,7 @@ from mongoengine.context_managers import (set_write_concern,
|
||||
switch_db)
|
||||
from mongoengine.errors import (InvalidDocumentError, InvalidQueryError,
|
||||
SaveConditionError)
|
||||
from mongoengine.python_support import IS_PYMONGO_3
|
||||
from mongoengine.pymongo_support import list_collection_names
|
||||
from mongoengine.queryset import (NotUniqueError, OperationError,
|
||||
QuerySet, transform)
|
||||
|
||||
@@ -175,10 +176,16 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
|
||||
return get_db(cls._meta.get('db_alias', DEFAULT_CONNECTION_NAME))
|
||||
|
||||
@classmethod
|
||||
def _get_collection(cls):
|
||||
"""Return a PyMongo collection for the document."""
|
||||
if not hasattr(cls, '_collection') or cls._collection is None:
|
||||
def _disconnect(cls):
|
||||
"""Detach the Document class from the (cached) database collection"""
|
||||
cls._collection = None
|
||||
|
||||
@classmethod
|
||||
def _get_collection(cls):
|
||||
"""Return the corresponding PyMongo collection of this document.
|
||||
Upon the first call, it will ensure that indexes gets created. The returned collection then gets cached
|
||||
"""
|
||||
if not hasattr(cls, '_collection') or cls._collection is None:
|
||||
# Get the collection, either capped or regular.
|
||||
if cls._meta.get('max_size') or cls._meta.get('max_documents'):
|
||||
cls._collection = cls._get_capped_collection()
|
||||
@@ -215,7 +222,7 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
|
||||
|
||||
# If the collection already exists and has different options
|
||||
# (i.e. isn't capped or has different max/size), raise an error.
|
||||
if collection_name in db.collection_names():
|
||||
if collection_name in list_collection_names(db, include_system_collections=True):
|
||||
collection = db[collection_name]
|
||||
options = collection.options()
|
||||
if (
|
||||
@@ -240,7 +247,7 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
|
||||
data = super(Document, self).to_mongo(*args, **kwargs)
|
||||
|
||||
# If '_id' is None, try and set it from self._data. If that
|
||||
# doesn't exist either, remote '_id' from the SON completely.
|
||||
# doesn't exist either, remove '_id' from the SON completely.
|
||||
if data['_id'] is None:
|
||||
if self._data.get('id') is None:
|
||||
del data['_id']
|
||||
@@ -346,21 +353,21 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
|
||||
.. versionchanged:: 0.10.7
|
||||
Add signal_kwargs argument
|
||||
"""
|
||||
signal_kwargs = signal_kwargs or {}
|
||||
|
||||
if self._meta.get('abstract'):
|
||||
raise InvalidDocumentError('Cannot save an abstract document.')
|
||||
|
||||
signal_kwargs = signal_kwargs or {}
|
||||
signals.pre_save.send(self.__class__, document=self, **signal_kwargs)
|
||||
|
||||
if validate:
|
||||
self.validate(clean=clean)
|
||||
|
||||
if write_concern is None:
|
||||
write_concern = {'w': 1}
|
||||
write_concern = {}
|
||||
|
||||
doc = self.to_mongo()
|
||||
|
||||
created = ('_id' not in doc or self._created or force_insert)
|
||||
doc_id = self.to_mongo(fields=['id'])
|
||||
created = ('_id' not in doc_id or self._created or force_insert)
|
||||
|
||||
signals.pre_save_post_validation.send(self.__class__, document=self,
|
||||
created=created, **signal_kwargs)
|
||||
@@ -438,16 +445,6 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
|
||||
|
||||
object_id = wc_collection.insert_one(doc).inserted_id
|
||||
|
||||
# In PyMongo 3.0, the save() call calls internally the _update() call
|
||||
# but they forget to return the _id value passed back, therefore getting it back here
|
||||
# Correct behaviour in 2.X and in 3.0.1+ versions
|
||||
if not object_id and pymongo.version_tuple == (3, 0):
|
||||
pk_as_mongo_obj = self._fields.get(self._meta['id_field']).to_mongo(self.pk)
|
||||
object_id = (
|
||||
self._qs.filter(pk=pk_as_mongo_obj).first() and
|
||||
self._qs.filter(pk=pk_as_mongo_obj).first().pk
|
||||
) # TODO doesn't this make 2 queries?
|
||||
|
||||
return object_id
|
||||
|
||||
def _get_update_doc(self):
|
||||
@@ -493,8 +490,12 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
|
||||
update_doc = self._get_update_doc()
|
||||
if update_doc:
|
||||
upsert = save_condition is None
|
||||
last_error = collection.update(select_dict, update_doc,
|
||||
upsert=upsert, **write_concern)
|
||||
with set_write_concern(collection, write_concern) as wc_collection:
|
||||
last_error = wc_collection.update_one(
|
||||
select_dict,
|
||||
update_doc,
|
||||
upsert=upsert
|
||||
).raw_result
|
||||
if not upsert and last_error['n'] == 0:
|
||||
raise SaveConditionError('Race condition preventing'
|
||||
' document update detected')
|
||||
@@ -601,7 +602,7 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
|
||||
|
||||
# Delete FileFields separately
|
||||
FileField = _import_class('FileField')
|
||||
for name, field in self._fields.iteritems():
|
||||
for name, field in iteritems(self._fields):
|
||||
if isinstance(field, FileField):
|
||||
getattr(self, name).delete()
|
||||
|
||||
@@ -786,13 +787,13 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
|
||||
.. versionchanged:: 0.10.7
|
||||
:class:`OperationError` exception raised if no collection available
|
||||
"""
|
||||
col_name = cls._get_collection_name()
|
||||
if not col_name:
|
||||
coll_name = cls._get_collection_name()
|
||||
if not coll_name:
|
||||
raise OperationError('Document %s has no collection defined '
|
||||
'(is it abstract ?)' % cls)
|
||||
cls._collection = None
|
||||
db = cls._get_db()
|
||||
db.drop_collection(col_name)
|
||||
db.drop_collection(coll_name)
|
||||
|
||||
@classmethod
|
||||
def create_index(cls, keys, background=False, **kwargs):
|
||||
@@ -807,18 +808,13 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
|
||||
index_spec = index_spec.copy()
|
||||
fields = index_spec.pop('fields')
|
||||
drop_dups = kwargs.get('drop_dups', False)
|
||||
if IS_PYMONGO_3 and drop_dups:
|
||||
if drop_dups:
|
||||
msg = 'drop_dups is deprecated and is removed when using PyMongo 3+.'
|
||||
warnings.warn(msg, DeprecationWarning)
|
||||
elif not IS_PYMONGO_3:
|
||||
index_spec['drop_dups'] = drop_dups
|
||||
index_spec['background'] = background
|
||||
index_spec.update(kwargs)
|
||||
|
||||
if IS_PYMONGO_3:
|
||||
return cls._get_collection().create_index(fields, **index_spec)
|
||||
else:
|
||||
return cls._get_collection().ensure_index(fields, **index_spec)
|
||||
return cls._get_collection().create_index(fields, **index_spec)
|
||||
|
||||
@classmethod
|
||||
def ensure_index(cls, key_or_list, drop_dups=False, background=False,
|
||||
@@ -833,11 +829,9 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
|
||||
:param drop_dups: Was removed/ignored with MongoDB >2.7.5. The value
|
||||
will be removed if PyMongo3+ is used
|
||||
"""
|
||||
if IS_PYMONGO_3 and drop_dups:
|
||||
if drop_dups:
|
||||
msg = 'drop_dups is deprecated and is removed when using PyMongo 3+.'
|
||||
warnings.warn(msg, DeprecationWarning)
|
||||
elif not IS_PYMONGO_3:
|
||||
kwargs.update({'drop_dups': drop_dups})
|
||||
return cls.create_index(key_or_list, background=background, **kwargs)
|
||||
|
||||
@classmethod
|
||||
@@ -853,7 +847,7 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
|
||||
drop_dups = cls._meta.get('index_drop_dups', False)
|
||||
index_opts = cls._meta.get('index_opts') or {}
|
||||
index_cls = cls._meta.get('index_cls', True)
|
||||
if IS_PYMONGO_3 and drop_dups:
|
||||
if drop_dups:
|
||||
msg = 'drop_dups is deprecated and is removed when using PyMongo 3+.'
|
||||
warnings.warn(msg, DeprecationWarning)
|
||||
|
||||
@@ -884,11 +878,7 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
|
||||
if 'cls' in opts:
|
||||
del opts['cls']
|
||||
|
||||
if IS_PYMONGO_3:
|
||||
collection.create_index(fields, background=background, **opts)
|
||||
else:
|
||||
collection.ensure_index(fields, background=background,
|
||||
drop_dups=drop_dups, **opts)
|
||||
collection.create_index(fields, background=background, **opts)
|
||||
|
||||
# If _cls is being used (for polymorphism), it needs an index,
|
||||
# only if another index doesn't begin with _cls
|
||||
@@ -899,12 +889,8 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
|
||||
if 'cls' in index_opts:
|
||||
del index_opts['cls']
|
||||
|
||||
if IS_PYMONGO_3:
|
||||
collection.create_index('_cls', background=background,
|
||||
**index_opts)
|
||||
else:
|
||||
collection.ensure_index('_cls', background=background,
|
||||
**index_opts)
|
||||
collection.create_index('_cls', background=background,
|
||||
**index_opts)
|
||||
|
||||
@classmethod
|
||||
def list_indexes(cls):
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
from collections import defaultdict
|
||||
|
||||
import six
|
||||
from six import iteritems
|
||||
|
||||
__all__ = ('NotRegistered', 'InvalidDocumentError', 'LookUpError',
|
||||
'DoesNotExist', 'MultipleObjectsReturned', 'InvalidQueryError',
|
||||
'OperationError', 'NotUniqueError', 'FieldDoesNotExist',
|
||||
'ValidationError', 'SaveConditionError')
|
||||
'ValidationError', 'SaveConditionError', 'DeprecatedError')
|
||||
|
||||
|
||||
class NotRegistered(Exception):
|
||||
@@ -109,11 +110,8 @@ class ValidationError(AssertionError):
|
||||
|
||||
def build_dict(source):
|
||||
errors_dict = {}
|
||||
if not source:
|
||||
return errors_dict
|
||||
|
||||
if isinstance(source, dict):
|
||||
for field_name, error in source.iteritems():
|
||||
for field_name, error in iteritems(source):
|
||||
errors_dict[field_name] = build_dict(error)
|
||||
elif isinstance(source, ValidationError) and source.errors:
|
||||
return build_dict(source.errors)
|
||||
@@ -135,12 +133,17 @@ class ValidationError(AssertionError):
|
||||
value = ' '.join([generate_key(k) for k in value])
|
||||
elif isinstance(value, dict):
|
||||
value = ' '.join(
|
||||
[generate_key(v, k) for k, v in value.iteritems()])
|
||||
[generate_key(v, k) for k, v in iteritems(value)])
|
||||
|
||||
results = '%s.%s' % (prefix, value) if prefix else value
|
||||
return results
|
||||
|
||||
error_dict = defaultdict(list)
|
||||
for k, v in self.to_dict().iteritems():
|
||||
for k, v in iteritems(self.to_dict()):
|
||||
error_dict[generate_key(v)].append(k)
|
||||
return ' '.join(['%s: %s' % (k, v) for k, v in error_dict.iteritems()])
|
||||
return ' '.join(['%s: %s' % (k, v) for k, v in iteritems(error_dict)])
|
||||
|
||||
|
||||
class DeprecatedError(Exception):
|
||||
"""Raise when a user uses a feature that has been Deprecated"""
|
||||
pass
|
||||
|
||||
@@ -11,6 +11,7 @@ from bson import Binary, DBRef, ObjectId, SON
|
||||
import gridfs
|
||||
import pymongo
|
||||
import six
|
||||
from six import iteritems
|
||||
|
||||
try:
|
||||
import dateutil
|
||||
@@ -36,6 +37,7 @@ from mongoengine.errors import DoesNotExist, InvalidQueryError, ValidationError
|
||||
from mongoengine.python_support import StringIO
|
||||
from mongoengine.queryset import DO_NOTHING
|
||||
from mongoengine.queryset.base import BaseQuerySet
|
||||
from mongoengine.queryset.transform import STRING_OPERATORS
|
||||
|
||||
try:
|
||||
from PIL import Image, ImageOps
|
||||
@@ -105,11 +107,11 @@ class StringField(BaseField):
|
||||
if not isinstance(op, six.string_types):
|
||||
return value
|
||||
|
||||
if op.lstrip('i') in ('startswith', 'endswith', 'contains', 'exact'):
|
||||
flags = 0
|
||||
if op.startswith('i'):
|
||||
flags = re.IGNORECASE
|
||||
op = op.lstrip('i')
|
||||
if op in STRING_OPERATORS:
|
||||
case_insensitive = op.startswith('i')
|
||||
op = op.lstrip('i')
|
||||
|
||||
flags = re.IGNORECASE if case_insensitive else 0
|
||||
|
||||
regex = r'%s'
|
||||
if op == 'startswith':
|
||||
@@ -151,12 +153,10 @@ class URLField(StringField):
|
||||
scheme = value.split('://')[0].lower()
|
||||
if scheme not in self.schemes:
|
||||
self.error(u'Invalid scheme {} in URL: {}'.format(scheme, value))
|
||||
return
|
||||
|
||||
# Then check full URL
|
||||
if not self.url_regex.match(value):
|
||||
self.error(u'Invalid URL: {}'.format(value))
|
||||
return
|
||||
|
||||
|
||||
class EmailField(StringField):
|
||||
@@ -258,10 +258,10 @@ class EmailField(StringField):
|
||||
try:
|
||||
domain_part = domain_part.encode('idna').decode('ascii')
|
||||
except UnicodeError:
|
||||
self.error(self.error_msg % value)
|
||||
self.error("%s %s" % (self.error_msg % value, "(domain failed IDN encoding)"))
|
||||
else:
|
||||
if not self.validate_domain_part(domain_part):
|
||||
self.error(self.error_msg % value)
|
||||
self.error("%s %s" % (self.error_msg % value, "(domain validation failed)"))
|
||||
|
||||
|
||||
class IntField(BaseField):
|
||||
@@ -498,15 +498,18 @@ class DateTimeField(BaseField):
|
||||
if not isinstance(value, six.string_types):
|
||||
return None
|
||||
|
||||
return self._parse_datetime(value)
|
||||
|
||||
def _parse_datetime(self, value):
|
||||
# Attempt to parse a datetime from a string
|
||||
value = value.strip()
|
||||
if not value:
|
||||
return None
|
||||
|
||||
# Attempt to parse a datetime:
|
||||
if dateutil:
|
||||
try:
|
||||
return dateutil.parser.parse(value)
|
||||
except (TypeError, ValueError):
|
||||
except (TypeError, ValueError, OverflowError):
|
||||
return None
|
||||
|
||||
# split usecs, because they are not recognized by strptime.
|
||||
@@ -699,7 +702,11 @@ class EmbeddedDocumentField(BaseField):
|
||||
self.document_type.validate(value, clean)
|
||||
|
||||
def lookup_member(self, member_name):
|
||||
return self.document_type._fields.get(member_name)
|
||||
doc_and_subclasses = [self.document_type] + self.document_type.__subclasses__()
|
||||
for doc_type in doc_and_subclasses:
|
||||
field = doc_type._fields.get(member_name)
|
||||
if field:
|
||||
return field
|
||||
|
||||
def prepare_query_value(self, op, value):
|
||||
if value is not None and not isinstance(value, self.document_type):
|
||||
@@ -746,12 +753,13 @@ class GenericEmbeddedDocumentField(BaseField):
|
||||
value.validate(clean=clean)
|
||||
|
||||
def lookup_member(self, member_name):
|
||||
if self.choices:
|
||||
for choice in self.choices:
|
||||
field = choice._fields.get(member_name)
|
||||
document_choices = self.choices or []
|
||||
for document_choice in document_choices:
|
||||
doc_and_subclasses = [document_choice] + document_choice.__subclasses__()
|
||||
for doc_type in doc_and_subclasses:
|
||||
field = doc_type._fields.get(member_name)
|
||||
if field:
|
||||
return field
|
||||
return None
|
||||
|
||||
def to_mongo(self, document, use_db_field=True, fields=None):
|
||||
if document is None:
|
||||
@@ -794,12 +802,12 @@ class DynamicField(BaseField):
|
||||
value = {k: v for k, v in enumerate(value)}
|
||||
|
||||
data = {}
|
||||
for k, v in value.iteritems():
|
||||
for k, v in iteritems(value):
|
||||
data[k] = self.to_mongo(v, use_db_field, fields)
|
||||
|
||||
value = data
|
||||
if is_list: # Convert back to a list
|
||||
value = [v for k, v in sorted(data.iteritems(), key=itemgetter(0))]
|
||||
value = [v for k, v in sorted(iteritems(data), key=itemgetter(0))]
|
||||
return value
|
||||
|
||||
def to_python(self, value):
|
||||
|
||||
19
mongoengine/mongodb_support.py
Normal file
19
mongoengine/mongodb_support.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""
|
||||
Helper functions, constants, and types to aid with MongoDB version support
|
||||
"""
|
||||
from mongoengine.connection import get_connection
|
||||
|
||||
|
||||
# Constant that can be used to compare the version retrieved with
|
||||
# get_mongodb_version()
|
||||
MONGODB_34 = (3, 4)
|
||||
MONGODB_36 = (3, 6)
|
||||
|
||||
|
||||
def get_mongodb_version():
|
||||
"""Return the version of the connected mongoDB (first 2 digits)
|
||||
|
||||
:return: tuple(int, int)
|
||||
"""
|
||||
version_list = get_connection().server_info()['versionArray'][:2] # e.g: (3, 2)
|
||||
return tuple(version_list)
|
||||
32
mongoengine/pymongo_support.py
Normal file
32
mongoengine/pymongo_support.py
Normal file
@@ -0,0 +1,32 @@
|
||||
"""
|
||||
Helper functions, constants, and types to aid with PyMongo v2.7 - v3.x support.
|
||||
"""
|
||||
import pymongo
|
||||
|
||||
_PYMONGO_37 = (3, 7)
|
||||
|
||||
PYMONGO_VERSION = tuple(pymongo.version_tuple[:2])
|
||||
|
||||
IS_PYMONGO_GTE_37 = PYMONGO_VERSION >= _PYMONGO_37
|
||||
|
||||
|
||||
def count_documents(collection, filter):
|
||||
"""Pymongo>3.7 deprecates count in favour of count_documents"""
|
||||
if IS_PYMONGO_GTE_37:
|
||||
return collection.count_documents(filter)
|
||||
else:
|
||||
count = collection.find(filter).count()
|
||||
return count
|
||||
|
||||
|
||||
def list_collection_names(db, include_system_collections=False):
|
||||
"""Pymongo>3.7 deprecates collection_names in favour of list_collection_names"""
|
||||
if IS_PYMONGO_GTE_37:
|
||||
collections = db.list_collection_names()
|
||||
else:
|
||||
collections = db.collection_names()
|
||||
|
||||
if not include_system_collections:
|
||||
collections = [c for c in collections if not c.startswith('system.')]
|
||||
|
||||
return collections
|
||||
@@ -1,13 +1,8 @@
|
||||
"""
|
||||
Helper functions, constants, and types to aid with Python v2.7 - v3.x and
|
||||
PyMongo v2.7 - v3.x support.
|
||||
Helper functions, constants, and types to aid with Python v2.7 - v3.x support
|
||||
"""
|
||||
import pymongo
|
||||
import six
|
||||
|
||||
|
||||
IS_PYMONGO_3 = pymongo.version_tuple[0] >= 3
|
||||
|
||||
# six.BytesIO resolves to StringIO.StringIO in Py2 and io.BytesIO in Py3.
|
||||
StringIO = six.BytesIO
|
||||
|
||||
|
||||
@@ -10,8 +10,10 @@ from bson import SON, json_util
|
||||
from bson.code import Code
|
||||
import pymongo
|
||||
import pymongo.errors
|
||||
from pymongo.collection import ReturnDocument
|
||||
from pymongo.common import validate_read_preference
|
||||
import six
|
||||
from six import iteritems
|
||||
|
||||
from mongoengine import signals
|
||||
from mongoengine.base import get_document
|
||||
@@ -20,14 +22,10 @@ from mongoengine.connection import get_db
|
||||
from mongoengine.context_managers import set_write_concern, switch_db
|
||||
from mongoengine.errors import (InvalidQueryError, LookUpError,
|
||||
NotUniqueError, OperationError)
|
||||
from mongoengine.python_support import IS_PYMONGO_3
|
||||
from mongoengine.queryset import transform
|
||||
from mongoengine.queryset.field_list import QueryFieldList
|
||||
from mongoengine.queryset.visitor import Q, QNode
|
||||
|
||||
if IS_PYMONGO_3:
|
||||
from pymongo.collection import ReturnDocument
|
||||
|
||||
|
||||
__all__ = ('BaseQuerySet', 'DO_NOTHING', 'NULLIFY', 'CASCADE', 'DENY', 'PULL')
|
||||
|
||||
@@ -196,7 +194,7 @@ class BaseQuerySet(object):
|
||||
only_fields=self.only_fields
|
||||
)
|
||||
|
||||
raise AttributeError('Provide a slice or an integer index')
|
||||
raise TypeError('Provide a slice or an integer index')
|
||||
|
||||
def __iter__(self):
|
||||
raise NotImplementedError
|
||||
@@ -498,11 +496,12 @@ class BaseQuerySet(object):
|
||||
``save(..., write_concern={w: 2, fsync: True}, ...)`` will
|
||||
wait until at least two servers have recorded the write and
|
||||
will force an fsync on the primary server.
|
||||
:param full_result: Return the full result dictionary rather than just the number
|
||||
updated, e.g. return
|
||||
``{'n': 2, 'nModified': 2, 'ok': 1.0, 'updatedExisting': True}``.
|
||||
:param full_result: Return the associated ``pymongo.UpdateResult`` rather than just the number
|
||||
updated items
|
||||
:param update: Django-style update keyword arguments
|
||||
|
||||
:returns the number of updated documents (unless ``full_result`` is True)
|
||||
|
||||
.. versionadded:: 0.2
|
||||
"""
|
||||
if not update and not upsert:
|
||||
@@ -566,7 +565,7 @@ class BaseQuerySet(object):
|
||||
document = self._document.objects.with_id(atomic_update.upserted_id)
|
||||
return document
|
||||
|
||||
def update_one(self, upsert=False, write_concern=None, **update):
|
||||
def update_one(self, upsert=False, write_concern=None, full_result=False, **update):
|
||||
"""Perform an atomic update on the fields of the first document
|
||||
matched by the query.
|
||||
|
||||
@@ -577,12 +576,19 @@ class BaseQuerySet(object):
|
||||
``save(..., write_concern={w: 2, fsync: True}, ...)`` will
|
||||
wait until at least two servers have recorded the write and
|
||||
will force an fsync on the primary server.
|
||||
:param full_result: Return the associated ``pymongo.UpdateResult`` rather than just the number
|
||||
updated items
|
||||
:param update: Django-style update keyword arguments
|
||||
|
||||
full_result
|
||||
:returns the number of updated documents (unless ``full_result`` is True)
|
||||
.. versionadded:: 0.2
|
||||
"""
|
||||
return self.update(
|
||||
upsert=upsert, multi=False, write_concern=write_concern, **update)
|
||||
upsert=upsert,
|
||||
multi=False,
|
||||
write_concern=write_concern,
|
||||
full_result=full_result,
|
||||
**update)
|
||||
|
||||
def modify(self, upsert=False, full_response=False, remove=False, new=False, **update):
|
||||
"""Update and return the updated document.
|
||||
@@ -617,31 +623,25 @@ class BaseQuerySet(object):
|
||||
|
||||
queryset = self.clone()
|
||||
query = queryset._query
|
||||
if not IS_PYMONGO_3 or not remove:
|
||||
if not remove:
|
||||
update = transform.update(queryset._document, **update)
|
||||
sort = queryset._ordering
|
||||
|
||||
try:
|
||||
if IS_PYMONGO_3:
|
||||
if full_response:
|
||||
msg = 'With PyMongo 3+, it is not possible anymore to get the full response.'
|
||||
warnings.warn(msg, DeprecationWarning)
|
||||
if remove:
|
||||
result = queryset._collection.find_one_and_delete(
|
||||
query, sort=sort, **self._cursor_args)
|
||||
else:
|
||||
if new:
|
||||
return_doc = ReturnDocument.AFTER
|
||||
else:
|
||||
return_doc = ReturnDocument.BEFORE
|
||||
result = queryset._collection.find_one_and_update(
|
||||
query, update, upsert=upsert, sort=sort, return_document=return_doc,
|
||||
**self._cursor_args)
|
||||
|
||||
if full_response:
|
||||
msg = 'With PyMongo 3+, it is not possible anymore to get the full response.'
|
||||
warnings.warn(msg, DeprecationWarning)
|
||||
if remove:
|
||||
result = queryset._collection.find_one_and_delete(
|
||||
query, sort=sort, **self._cursor_args)
|
||||
else:
|
||||
result = queryset._collection.find_and_modify(
|
||||
query, update, upsert=upsert, sort=sort, remove=remove, new=new,
|
||||
full_response=full_response, **self._cursor_args)
|
||||
if new:
|
||||
return_doc = ReturnDocument.AFTER
|
||||
else:
|
||||
return_doc = ReturnDocument.BEFORE
|
||||
result = queryset._collection.find_one_and_update(
|
||||
query, update, upsert=upsert, sort=sort, return_document=return_doc,
|
||||
**self._cursor_args)
|
||||
except pymongo.errors.DuplicateKeyError as err:
|
||||
raise NotUniqueError(u'Update failed (%s)' % err)
|
||||
except pymongo.errors.OperationFailure as err:
|
||||
@@ -748,7 +748,7 @@ class BaseQuerySet(object):
|
||||
'_read_preference', '_iter', '_scalar', '_as_pymongo',
|
||||
'_limit', '_skip', '_hint', '_auto_dereference',
|
||||
'_search_text', 'only_fields', '_max_time_ms',
|
||||
'_comment')
|
||||
'_comment', '_batch_size')
|
||||
|
||||
for prop in copy_props:
|
||||
val = getattr(self, prop)
|
||||
@@ -1073,15 +1073,14 @@ class BaseQuerySet(object):
|
||||
..versionchanged:: 0.5 - made chainable
|
||||
.. deprecated:: Ignored with PyMongo 3+
|
||||
"""
|
||||
if IS_PYMONGO_3:
|
||||
msg = 'snapshot is deprecated as it has no impact when using PyMongo 3+.'
|
||||
warnings.warn(msg, DeprecationWarning)
|
||||
msg = 'snapshot is deprecated as it has no impact when using PyMongo 3+.'
|
||||
warnings.warn(msg, DeprecationWarning)
|
||||
queryset = self.clone()
|
||||
queryset._snapshot = enabled
|
||||
return queryset
|
||||
|
||||
def timeout(self, enabled):
|
||||
"""Enable or disable the default mongod timeout when querying.
|
||||
"""Enable or disable the default mongod timeout when querying. (no_cursor_timeout option)
|
||||
|
||||
:param enabled: whether or not the timeout is used
|
||||
|
||||
@@ -1099,9 +1098,8 @@ class BaseQuerySet(object):
|
||||
|
||||
.. deprecated:: Ignored with PyMongo 3+
|
||||
"""
|
||||
if IS_PYMONGO_3:
|
||||
msg = 'slave_okay is deprecated as it has no impact when using PyMongo 3+.'
|
||||
warnings.warn(msg, DeprecationWarning)
|
||||
msg = 'slave_okay is deprecated as it has no impact when using PyMongo 3+.'
|
||||
warnings.warn(msg, DeprecationWarning)
|
||||
queryset = self.clone()
|
||||
queryset._slave_okay = enabled
|
||||
return queryset
|
||||
@@ -1191,14 +1189,18 @@ class BaseQuerySet(object):
|
||||
initial_pipeline.append({'$sort': dict(self._ordering)})
|
||||
|
||||
if self._limit is not None:
|
||||
initial_pipeline.append({'$limit': self._limit})
|
||||
# As per MongoDB Documentation (https://docs.mongodb.com/manual/reference/operator/aggregation/limit/),
|
||||
# keeping limit stage right after sort stage is more efficient. But this leads to wrong set of documents
|
||||
# for a skip stage that might succeed these. So we need to maintain more documents in memory in such a
|
||||
# case (https://stackoverflow.com/a/24161461).
|
||||
initial_pipeline.append({'$limit': self._limit + (self._skip or 0)})
|
||||
|
||||
if self._skip is not None:
|
||||
initial_pipeline.append({'$skip': self._skip})
|
||||
|
||||
pipeline = initial_pipeline + list(pipeline)
|
||||
|
||||
if IS_PYMONGO_3 and self._read_preference is not None:
|
||||
if self._read_preference is not None:
|
||||
return self._collection.with_options(read_preference=self._read_preference) \
|
||||
.aggregate(pipeline, cursor={}, **kwargs)
|
||||
|
||||
@@ -1408,11 +1410,7 @@ class BaseQuerySet(object):
|
||||
if isinstance(field_instances[-1], ListField):
|
||||
pipeline.insert(1, {'$unwind': '$' + field})
|
||||
|
||||
result = self._document._get_collection().aggregate(pipeline)
|
||||
if IS_PYMONGO_3:
|
||||
result = tuple(result)
|
||||
else:
|
||||
result = result.get('result')
|
||||
result = tuple(self._document._get_collection().aggregate(pipeline))
|
||||
|
||||
if result:
|
||||
return result[0]['total']
|
||||
@@ -1439,11 +1437,7 @@ class BaseQuerySet(object):
|
||||
if isinstance(field_instances[-1], ListField):
|
||||
pipeline.insert(1, {'$unwind': '$' + field})
|
||||
|
||||
result = self._document._get_collection().aggregate(pipeline)
|
||||
if IS_PYMONGO_3:
|
||||
result = tuple(result)
|
||||
else:
|
||||
result = result.get('result')
|
||||
result = tuple(self._document._get_collection().aggregate(pipeline))
|
||||
if result:
|
||||
return result[0]['total']
|
||||
return 0
|
||||
@@ -1518,26 +1512,16 @@ class BaseQuerySet(object):
|
||||
|
||||
@property
|
||||
def _cursor_args(self):
|
||||
if not IS_PYMONGO_3:
|
||||
fields_name = 'fields'
|
||||
cursor_args = {
|
||||
'timeout': self._timeout,
|
||||
'snapshot': self._snapshot
|
||||
}
|
||||
if self._read_preference is not None:
|
||||
cursor_args['read_preference'] = self._read_preference
|
||||
else:
|
||||
cursor_args['slave_okay'] = self._slave_okay
|
||||
else:
|
||||
fields_name = 'projection'
|
||||
# snapshot is not handled at all by PyMongo 3+
|
||||
# TODO: evaluate similar possibilities using modifiers
|
||||
if self._snapshot:
|
||||
msg = 'The snapshot option is not anymore available with PyMongo 3+'
|
||||
warnings.warn(msg, DeprecationWarning)
|
||||
cursor_args = {
|
||||
'no_cursor_timeout': not self._timeout
|
||||
}
|
||||
fields_name = 'projection'
|
||||
# snapshot is not handled at all by PyMongo 3+
|
||||
# TODO: evaluate similar possibilities using modifiers
|
||||
if self._snapshot:
|
||||
msg = 'The snapshot option is not anymore available with PyMongo 3+'
|
||||
warnings.warn(msg, DeprecationWarning)
|
||||
cursor_args = {
|
||||
'no_cursor_timeout': not self._timeout
|
||||
}
|
||||
|
||||
if self._loaded_fields:
|
||||
cursor_args[fields_name] = self._loaded_fields.as_dict()
|
||||
|
||||
@@ -1561,7 +1545,7 @@ class BaseQuerySet(object):
|
||||
# XXX In PyMongo 3+, we define the read preference on a collection
|
||||
# level, not a cursor level. Thus, we need to get a cloned collection
|
||||
# object using `with_options` first.
|
||||
if IS_PYMONGO_3 and self._read_preference is not None:
|
||||
if self._read_preference is not None:
|
||||
self._cursor_obj = self._collection\
|
||||
.with_options(read_preference=self._read_preference)\
|
||||
.find(self._query, **self._cursor_args)
|
||||
@@ -1731,13 +1715,13 @@ class BaseQuerySet(object):
|
||||
}
|
||||
"""
|
||||
total, data, types = self.exec_js(freq_func, field)
|
||||
values = {types.get(k): int(v) for k, v in data.iteritems()}
|
||||
values = {types.get(k): int(v) for k, v in iteritems(data)}
|
||||
|
||||
if normalize:
|
||||
values = {k: float(v) / total for k, v in values.items()}
|
||||
|
||||
frequencies = {}
|
||||
for k, v in values.iteritems():
|
||||
for k, v in iteritems(values):
|
||||
if isinstance(k, float):
|
||||
if int(k) == k:
|
||||
k = int(k)
|
||||
|
||||
@@ -4,12 +4,11 @@ from bson import ObjectId, SON
|
||||
from bson.dbref import DBRef
|
||||
import pymongo
|
||||
import six
|
||||
from six import iteritems
|
||||
|
||||
from mongoengine.base import UPDATE_OPERATORS
|
||||
from mongoengine.common import _import_class
|
||||
from mongoengine.connection import get_connection
|
||||
from mongoengine.errors import InvalidQueryError
|
||||
from mongoengine.python_support import IS_PYMONGO_3
|
||||
|
||||
__all__ = ('query', 'update')
|
||||
|
||||
@@ -87,18 +86,10 @@ def query(_doc_cls=None, **kwargs):
|
||||
singular_ops = [None, 'ne', 'gt', 'gte', 'lt', 'lte', 'not']
|
||||
singular_ops += STRING_OPERATORS
|
||||
if op in singular_ops:
|
||||
if isinstance(field, six.string_types):
|
||||
if (op in STRING_OPERATORS and
|
||||
isinstance(value, six.string_types)):
|
||||
StringField = _import_class('StringField')
|
||||
value = StringField.prepare_query_value(op, value)
|
||||
else:
|
||||
value = field
|
||||
else:
|
||||
value = field.prepare_query_value(op, value)
|
||||
value = field.prepare_query_value(op, value)
|
||||
|
||||
if isinstance(field, CachedReferenceField) and value:
|
||||
value = value['_id']
|
||||
if isinstance(field, CachedReferenceField) and value:
|
||||
value = value['_id']
|
||||
|
||||
elif op in ('in', 'nin', 'all', 'near') and not isinstance(value, dict):
|
||||
# Raise an error if the in/nin/all/near param is not iterable.
|
||||
@@ -154,7 +145,7 @@ def query(_doc_cls=None, **kwargs):
|
||||
if ('$maxDistance' in value_dict or '$minDistance' in value_dict) and \
|
||||
('$near' in value_dict or '$nearSphere' in value_dict):
|
||||
value_son = SON()
|
||||
for k, v in value_dict.iteritems():
|
||||
for k, v in iteritems(value_dict):
|
||||
if k == '$maxDistance' or k == '$minDistance':
|
||||
continue
|
||||
value_son[k] = v
|
||||
@@ -162,16 +153,14 @@ def query(_doc_cls=None, **kwargs):
|
||||
# PyMongo 3+ and MongoDB < 2.6
|
||||
near_embedded = False
|
||||
for near_op in ('$near', '$nearSphere'):
|
||||
if isinstance(value_dict.get(near_op), dict) and (
|
||||
IS_PYMONGO_3 or get_connection().max_wire_version > 1):
|
||||
if isinstance(value_dict.get(near_op), dict):
|
||||
value_son[near_op] = SON(value_son[near_op])
|
||||
if '$maxDistance' in value_dict:
|
||||
value_son[near_op][
|
||||
'$maxDistance'] = value_dict['$maxDistance']
|
||||
value_son[near_op]['$maxDistance'] = value_dict['$maxDistance']
|
||||
if '$minDistance' in value_dict:
|
||||
value_son[near_op][
|
||||
'$minDistance'] = value_dict['$minDistance']
|
||||
value_son[near_op]['$minDistance'] = value_dict['$minDistance']
|
||||
near_embedded = True
|
||||
|
||||
if not near_embedded:
|
||||
if '$maxDistance' in value_dict:
|
||||
value_son['$maxDistance'] = value_dict['$maxDistance']
|
||||
@@ -280,7 +269,7 @@ def update(_doc_cls=None, **update):
|
||||
|
||||
if op == 'pull':
|
||||
if field.required or value is not None:
|
||||
if match == 'in' and not isinstance(value, dict):
|
||||
if match in ('in', 'nin') and not isinstance(value, dict):
|
||||
value = _prepare_query_for_iterable(field, op, value)
|
||||
else:
|
||||
value = field.prepare_query_value(op, value)
|
||||
@@ -307,10 +296,6 @@ def update(_doc_cls=None, **update):
|
||||
|
||||
key = '.'.join(parts)
|
||||
|
||||
if not op:
|
||||
raise InvalidQueryError('Updates must supply an operation '
|
||||
'eg: set__FIELD=value')
|
||||
|
||||
if 'pull' in op and '.' in key:
|
||||
# Dot operators don't work on pull operations
|
||||
# unless they point to a list field
|
||||
|
||||
Reference in New Issue
Block a user