Merge branch 'master' into remove_save_embedded
This commit is contained in:
		| @@ -23,7 +23,7 @@ __all__ = (list(document.__all__) + list(fields.__all__) + | ||||
|            list(signals.__all__) + list(errors.__all__)) | ||||
|  | ||||
|  | ||||
| VERSION = (0, 16, 3) | ||||
| VERSION = (0, 17, 0) | ||||
|  | ||||
|  | ||||
| def get_version(): | ||||
|   | ||||
| @@ -13,7 +13,7 @@ _document_registry = {} | ||||
|  | ||||
|  | ||||
| def get_document(name): | ||||
|     """Get a document class by name.""" | ||||
|     """Get a registered Document class by name.""" | ||||
|     doc = _document_registry.get(name, None) | ||||
|     if not doc: | ||||
|         # Possible old style name | ||||
| @@ -30,3 +30,12 @@ def get_document(name): | ||||
|             been imported? | ||||
|         """.strip() % name) | ||||
|     return doc | ||||
|  | ||||
|  | ||||
| def _get_documents_by_db(connection_alias, default_connection_alias): | ||||
|     """Get all registered Documents class attached to a given database""" | ||||
|     def get_doc_alias(doc_cls): | ||||
|         return doc_cls._meta.get('db_alias', default_connection_alias) | ||||
|  | ||||
|     return [doc_cls for doc_cls in _document_registry.values() | ||||
|             if get_doc_alias(doc_cls) == connection_alias] | ||||
|   | ||||
| @@ -2,6 +2,7 @@ import weakref | ||||
|  | ||||
| from bson import DBRef | ||||
| import six | ||||
| from six import iteritems | ||||
|  | ||||
| from mongoengine.common import _import_class | ||||
| from mongoengine.errors import DoesNotExist, MultipleObjectsReturned | ||||
| @@ -363,7 +364,7 @@ class StrictDict(object): | ||||
|     _classes = {} | ||||
|  | ||||
|     def __init__(self, **kwargs): | ||||
|         for k, v in kwargs.iteritems(): | ||||
|         for k, v in iteritems(kwargs): | ||||
|             setattr(self, k, v) | ||||
|  | ||||
|     def __getitem__(self, key): | ||||
| @@ -411,7 +412,7 @@ class StrictDict(object): | ||||
|         return (key for key in self.__slots__ if hasattr(self, key)) | ||||
|  | ||||
|     def __len__(self): | ||||
|         return len(list(self.iteritems())) | ||||
|         return len(list(iteritems(self))) | ||||
|  | ||||
|     def __eq__(self, other): | ||||
|         return self.items() == other.items() | ||||
|   | ||||
| @@ -5,6 +5,7 @@ from functools import partial | ||||
| from bson import DBRef, ObjectId, SON, json_util | ||||
| import pymongo | ||||
| import six | ||||
| from six import iteritems | ||||
|  | ||||
| from mongoengine import signals | ||||
| from mongoengine.base.common import get_document | ||||
| @@ -83,7 +84,7 @@ class BaseDocument(object): | ||||
|         self._dynamic_fields = SON() | ||||
|  | ||||
|         # Assign default values to instance | ||||
|         for key, field in self._fields.iteritems(): | ||||
|         for key, field in iteritems(self._fields): | ||||
|             if self._db_field_map.get(key, key) in __only_fields: | ||||
|                 continue | ||||
|             value = getattr(self, key, None) | ||||
| @@ -95,14 +96,14 @@ class BaseDocument(object): | ||||
|         # Set passed values after initialisation | ||||
|         if self._dynamic: | ||||
|             dynamic_data = {} | ||||
|             for key, value in values.iteritems(): | ||||
|             for key, value in iteritems(values): | ||||
|                 if key in self._fields or key == '_id': | ||||
|                     setattr(self, key, value) | ||||
|                 else: | ||||
|                     dynamic_data[key] = value | ||||
|         else: | ||||
|             FileField = _import_class('FileField') | ||||
|             for key, value in values.iteritems(): | ||||
|             for key, value in iteritems(values): | ||||
|                 key = self._reverse_db_field_map.get(key, key) | ||||
|                 if key in self._fields or key in ('id', 'pk', '_cls'): | ||||
|                     if __auto_convert and value is not None: | ||||
| @@ -118,7 +119,7 @@ class BaseDocument(object): | ||||
|  | ||||
|         if self._dynamic: | ||||
|             self._dynamic_lock = False | ||||
|             for key, value in dynamic_data.iteritems(): | ||||
|             for key, value in iteritems(dynamic_data): | ||||
|                 setattr(self, key, value) | ||||
|  | ||||
|         # Flag initialised | ||||
| @@ -292,8 +293,7 @@ class BaseDocument(object): | ||||
|         """ | ||||
|         Return as SON data ready for use with MongoDB. | ||||
|         """ | ||||
|         if not fields: | ||||
|             fields = [] | ||||
|         fields = fields or [] | ||||
|  | ||||
|         data = SON() | ||||
|         data['_id'] = None | ||||
| @@ -513,7 +513,7 @@ class BaseDocument(object): | ||||
|         if not hasattr(data, 'items'): | ||||
|             iterator = enumerate(data) | ||||
|         else: | ||||
|             iterator = data.iteritems() | ||||
|             iterator = iteritems(data) | ||||
|  | ||||
|         for index_or_key, value in iterator: | ||||
|             item_key = '%s%s.' % (base_key, index_or_key) | ||||
| @@ -678,7 +678,7 @@ class BaseDocument(object): | ||||
|         # Convert SON to a data dict, making sure each key is a string and | ||||
|         # corresponds to the right db field. | ||||
|         data = {} | ||||
|         for key, value in son.iteritems(): | ||||
|         for key, value in iteritems(son): | ||||
|             key = str(key) | ||||
|             key = cls._db_field_map.get(key, key) | ||||
|             data[key] = value | ||||
| @@ -694,7 +694,7 @@ class BaseDocument(object): | ||||
|         if not _auto_dereference: | ||||
|             fields = copy.deepcopy(fields) | ||||
|  | ||||
|         for field_name, field in fields.iteritems(): | ||||
|         for field_name, field in iteritems(fields): | ||||
|             field._auto_dereference = _auto_dereference | ||||
|             if field.db_field in data: | ||||
|                 value = data[field.db_field] | ||||
| @@ -715,7 +715,7 @@ class BaseDocument(object): | ||||
|  | ||||
|         # In STRICT documents, remove any keys that aren't in cls._fields | ||||
|         if cls.STRICT: | ||||
|             data = {k: v for k, v in data.iteritems() if k in cls._fields} | ||||
|             data = {k: v for k, v in iteritems(data) if k in cls._fields} | ||||
|  | ||||
|         obj = cls(__auto_convert=False, _created=created, __only_fields=only_fields, **data) | ||||
|         obj._changed_fields = changed_fields | ||||
| @@ -882,7 +882,8 @@ class BaseDocument(object): | ||||
|                 index = {'fields': fields, 'unique': True, 'sparse': sparse} | ||||
|                 unique_indexes.append(index) | ||||
|  | ||||
|             if field.__class__.__name__ == 'ListField': | ||||
|             if field.__class__.__name__ in {'EmbeddedDocumentListField', | ||||
|                                             'ListField', 'SortedListField'}: | ||||
|                 field = field.field | ||||
|  | ||||
|             # Grab any embedded document field unique indexes | ||||
|   | ||||
| @@ -5,13 +5,13 @@ import weakref | ||||
| from bson import DBRef, ObjectId, SON | ||||
| import pymongo | ||||
| import six | ||||
| from six import iteritems | ||||
|  | ||||
| from mongoengine.base.common import UPDATE_OPERATORS | ||||
| from mongoengine.base.datastructures import (BaseDict, BaseList, | ||||
|                                              EmbeddedDocumentList) | ||||
| from mongoengine.common import _import_class | ||||
| from mongoengine.errors import ValidationError | ||||
|  | ||||
| from mongoengine.errors import DeprecatedError, ValidationError | ||||
|  | ||||
| __all__ = ('BaseField', 'ComplexBaseField', 'ObjectIdField', | ||||
|            'GeoJsonBaseField') | ||||
| @@ -52,8 +52,8 @@ class BaseField(object): | ||||
|             unique with. | ||||
|         :param primary_key: Mark this field as the primary key. Defaults to False. | ||||
|         :param validation: (optional) A callable to validate the value of the | ||||
|             field.  Generally this is deprecated in favour of the | ||||
|             `FIELD.validate` method | ||||
|             field.  The callable takes the value as parameter and should raise | ||||
|             a ValidationError if validation fails | ||||
|         :param choices: (optional) The valid choices | ||||
|         :param null: (optional) If the field value can be null. If no and there is a default value | ||||
|             then the default value is set | ||||
| @@ -225,10 +225,18 @@ class BaseField(object): | ||||
|         # check validation argument | ||||
|         if self.validation is not None: | ||||
|             if callable(self.validation): | ||||
|                 if not self.validation(value): | ||||
|                     self.error('Value does not match custom validation method') | ||||
|                 try: | ||||
|                     # breaking change of 0.18 | ||||
|                     # Get rid of True/False-type return for the validation method | ||||
|                     # in favor of having validation raising a ValidationError | ||||
|                     ret = self.validation(value) | ||||
|                     if ret is not None: | ||||
|                         raise DeprecatedError('validation argument for `%s` must not return anything, ' | ||||
|                                               'it should raise a ValidationError if validation fails' % self.name) | ||||
|                 except ValidationError as ex: | ||||
|                     self.error(str(ex)) | ||||
|             else: | ||||
|                 raise ValueError('validation argument for "%s" must be a ' | ||||
|                 raise ValueError('validation argument for `"%s"` must be a ' | ||||
|                                  'callable.' % self.name) | ||||
|  | ||||
|         self.validate(value, **kwargs) | ||||
| @@ -275,11 +283,16 @@ class ComplexBaseField(BaseField): | ||||
|  | ||||
|         _dereference = _import_class('DeReference')() | ||||
|  | ||||
|         if instance._initialised and dereference and instance._data.get(self.name): | ||||
|         if (instance._initialised and | ||||
|                 dereference and | ||||
|                 instance._data.get(self.name) and | ||||
|                 not getattr(instance._data[self.name], '_dereferenced', False)): | ||||
|             instance._data[self.name] = _dereference( | ||||
|                 instance._data.get(self.name), max_depth=1, instance=instance, | ||||
|                 name=self.name | ||||
|             ) | ||||
|             if hasattr(instance._data[self.name], '_dereferenced'): | ||||
|                 instance._data[self.name]._dereferenced = True | ||||
|  | ||||
|         value = super(ComplexBaseField, self).__get__(instance, owner) | ||||
|  | ||||
| @@ -382,11 +395,11 @@ class ComplexBaseField(BaseField): | ||||
|         if self.field: | ||||
|             value_dict = { | ||||
|                 key: self.field._to_mongo_safe_call(item, use_db_field, fields) | ||||
|                 for key, item in value.iteritems() | ||||
|                 for key, item in iteritems(value) | ||||
|             } | ||||
|         else: | ||||
|             value_dict = {} | ||||
|             for k, v in value.iteritems(): | ||||
|             for k, v in iteritems(value): | ||||
|                 if isinstance(v, Document): | ||||
|                     # We need the id from the saved object to create the DBRef | ||||
|                     if v.pk is None: | ||||
| @@ -423,7 +436,7 @@ class ComplexBaseField(BaseField): | ||||
|         errors = {} | ||||
|         if self.field: | ||||
|             if hasattr(value, 'iteritems') or hasattr(value, 'items'): | ||||
|                 sequence = value.iteritems() | ||||
|                 sequence = iteritems(value) | ||||
|             else: | ||||
|                 sequence = enumerate(value) | ||||
|             for k, v in sequence: | ||||
|   | ||||
| @@ -1,6 +1,7 @@ | ||||
| import warnings | ||||
|  | ||||
| import six | ||||
| from six import iteritems, itervalues | ||||
|  | ||||
| from mongoengine.base.common import _document_registry | ||||
| from mongoengine.base.fields import BaseField, ComplexBaseField, ObjectIdField | ||||
| @@ -62,7 +63,7 @@ class DocumentMetaclass(type): | ||||
|             # Standard object mixin - merge in any Fields | ||||
|             if not hasattr(base, '_meta'): | ||||
|                 base_fields = {} | ||||
|                 for attr_name, attr_value in base.__dict__.iteritems(): | ||||
|                 for attr_name, attr_value in iteritems(base.__dict__): | ||||
|                     if not isinstance(attr_value, BaseField): | ||||
|                         continue | ||||
|                     attr_value.name = attr_name | ||||
| @@ -74,7 +75,7 @@ class DocumentMetaclass(type): | ||||
|  | ||||
|         # Discover any document fields | ||||
|         field_names = {} | ||||
|         for attr_name, attr_value in attrs.iteritems(): | ||||
|         for attr_name, attr_value in iteritems(attrs): | ||||
|             if not isinstance(attr_value, BaseField): | ||||
|                 continue | ||||
|             attr_value.name = attr_name | ||||
| @@ -103,7 +104,7 @@ class DocumentMetaclass(type): | ||||
|  | ||||
|         attrs['_fields_ordered'] = tuple(i[1] for i in sorted( | ||||
|                                          (v.creation_counter, v.name) | ||||
|                                          for v in doc_fields.itervalues())) | ||||
|                                          for v in itervalues(doc_fields))) | ||||
|  | ||||
|         # | ||||
|         # Set document hierarchy | ||||
| @@ -173,7 +174,7 @@ class DocumentMetaclass(type): | ||||
|                         f.__dict__.update({'im_self': getattr(f, '__self__')}) | ||||
|  | ||||
|         # Handle delete rules | ||||
|         for field in new_class._fields.itervalues(): | ||||
|         for field in itervalues(new_class._fields): | ||||
|             f = field | ||||
|             if f.owner_document is None: | ||||
|                 f.owner_document = new_class | ||||
| @@ -183,9 +184,6 @@ class DocumentMetaclass(type): | ||||
|                 if issubclass(new_class, EmbeddedDocument): | ||||
|                     raise InvalidDocumentError('CachedReferenceFields is not ' | ||||
|                                                'allowed in EmbeddedDocuments') | ||||
|                 if not f.document_type: | ||||
|                     raise InvalidDocumentError( | ||||
|                         'Document is not available to sync') | ||||
|  | ||||
|                 if f.auto_sync: | ||||
|                     f.start_listener() | ||||
| @@ -375,7 +373,7 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): | ||||
|             new_class.objects = QuerySetManager() | ||||
|  | ||||
|         # Validate the fields and set primary key if needed | ||||
|         for field_name, field in new_class._fields.iteritems(): | ||||
|         for field_name, field in iteritems(new_class._fields): | ||||
|             if field.primary_key: | ||||
|                 # Ensure only one primary key is set | ||||
|                 current_pk = new_class._meta.get('id_field') | ||||
| @@ -438,7 +436,7 @@ class MetaDict(dict): | ||||
|     _merge_options = ('indexes',) | ||||
|  | ||||
|     def merge(self, new_options): | ||||
|         for k, v in new_options.iteritems(): | ||||
|         for k, v in iteritems(new_options): | ||||
|             if k in self._merge_options: | ||||
|                 self[k] = self.get(k, []) + v | ||||
|             else: | ||||
|   | ||||
| @@ -31,7 +31,6 @@ def _import_class(cls_name): | ||||
|  | ||||
|     field_classes = _field_list_cache | ||||
|  | ||||
|     queryset_classes = ('OperationError',) | ||||
|     deref_classes = ('DeReference',) | ||||
|  | ||||
|     if cls_name == 'BaseDocument': | ||||
| @@ -43,14 +42,11 @@ def _import_class(cls_name): | ||||
|     elif cls_name in field_classes: | ||||
|         from mongoengine import fields as module | ||||
|         import_classes = field_classes | ||||
|     elif cls_name in queryset_classes: | ||||
|         from mongoengine import queryset as module | ||||
|         import_classes = queryset_classes | ||||
|     elif cls_name in deref_classes: | ||||
|         from mongoengine import dereference as module | ||||
|         import_classes = deref_classes | ||||
|     else: | ||||
|         raise ValueError('No import set for: ' % cls_name) | ||||
|         raise ValueError('No import set for: %s' % cls_name) | ||||
|  | ||||
|     for cls in import_classes: | ||||
|         _class_registry_cache[cls] = getattr(module, cls) | ||||
|   | ||||
| @@ -1,19 +1,22 @@ | ||||
| from pymongo import MongoClient, ReadPreference, uri_parser | ||||
| from pymongo.database import _check_name | ||||
| import six | ||||
|  | ||||
| from mongoengine.python_support import IS_PYMONGO_3 | ||||
|  | ||||
| __all__ = ['MongoEngineConnectionError', 'connect', 'register_connection', | ||||
|            'DEFAULT_CONNECTION_NAME'] | ||||
| __all__ = ['MongoEngineConnectionError', 'connect', 'disconnect', 'disconnect_all', | ||||
|            'register_connection', 'DEFAULT_CONNECTION_NAME', 'DEFAULT_DATABASE_NAME', | ||||
|            'get_db', 'get_connection'] | ||||
|  | ||||
|  | ||||
| DEFAULT_CONNECTION_NAME = 'default' | ||||
| DEFAULT_DATABASE_NAME = 'test' | ||||
| DEFAULT_HOST = 'localhost' | ||||
| DEFAULT_PORT = 27017 | ||||
|  | ||||
| if IS_PYMONGO_3: | ||||
|     READ_PREFERENCE = ReadPreference.PRIMARY | ||||
| else: | ||||
|     from pymongo import MongoReplicaSetClient | ||||
|     READ_PREFERENCE = False | ||||
| _connection_settings = {} | ||||
| _connections = {} | ||||
| _dbs = {} | ||||
|  | ||||
| READ_PREFERENCE = ReadPreference.PRIMARY | ||||
|  | ||||
|  | ||||
| class MongoEngineConnectionError(Exception): | ||||
| @@ -23,45 +26,48 @@ class MongoEngineConnectionError(Exception): | ||||
|     pass | ||||
|  | ||||
|  | ||||
| _connection_settings = {} | ||||
| _connections = {} | ||||
| _dbs = {} | ||||
| def _check_db_name(name): | ||||
|     """Check if a database name is valid. | ||||
|     This functionality is copied from pymongo Database class constructor. | ||||
|     """ | ||||
|     if not isinstance(name, six.string_types): | ||||
|         raise TypeError('name must be an instance of %s' % six.string_types) | ||||
|     elif name != '$external': | ||||
|         _check_name(name) | ||||
|  | ||||
|  | ||||
| def register_connection(alias, db=None, name=None, host=None, port=None, | ||||
|                         read_preference=READ_PREFERENCE, | ||||
|                         username=None, password=None, | ||||
|                         authentication_source=None, | ||||
|                         authentication_mechanism=None, | ||||
|                         **kwargs): | ||||
|     """Add a connection. | ||||
| def _get_connection_settings( | ||||
|         db=None, name=None, host=None, port=None, | ||||
|         read_preference=READ_PREFERENCE, | ||||
|         username=None, password=None, | ||||
|         authentication_source=None, | ||||
|         authentication_mechanism=None, | ||||
|         **kwargs): | ||||
|     """Get the connection settings as a dict | ||||
|  | ||||
|     :param alias: the name that will be used to refer to this connection | ||||
|         throughout MongoEngine | ||||
|     :param name: the name of the specific database to use | ||||
|     :param db: the name of the database to use, for compatibility with connect | ||||
|     :param host: the host name of the :program:`mongod` instance to connect to | ||||
|     :param port: the port that the :program:`mongod` instance is running on | ||||
|     :param read_preference: The read preference for the collection | ||||
|        ** Added pymongo 2.1 | ||||
|     :param username: username to authenticate with | ||||
|     :param password: password to authenticate with | ||||
|     :param authentication_source: database to authenticate against | ||||
|     :param authentication_mechanism: database authentication mechanisms. | ||||
|     :param db: the name of the database to use, for compatibility with connect | ||||
|     :param name: the name of the specific database to use | ||||
|     :param host: the host name of the :program:`mongod` instance to connect to | ||||
|     :param port: the port that the :program:`mongod` instance is running on | ||||
|     :param read_preference: The read preference for the collection | ||||
|     :param username: username to authenticate with | ||||
|     :param password: password to authenticate with | ||||
|     :param authentication_source: database to authenticate against | ||||
|     :param authentication_mechanism: database authentication mechanisms. | ||||
|         By default, use SCRAM-SHA-1 with MongoDB 3.0 and later, | ||||
|         MONGODB-CR (MongoDB Challenge Response protocol) for older servers. | ||||
|     :param is_mock: explicitly use mongomock for this connection | ||||
|         (can also be done by using `mongomock://` as db host prefix) | ||||
|     :param kwargs: ad-hoc parameters to be passed into the pymongo driver, | ||||
|     :param is_mock: explicitly use mongomock for this connection | ||||
|         (can also be done by using `mongomock://` as db host prefix) | ||||
|     :param kwargs: ad-hoc parameters to be passed into the pymongo driver, | ||||
|         for example maxpoolsize, tz_aware, etc. See the documentation | ||||
|         for pymongo's `MongoClient` for a full list. | ||||
|  | ||||
|     .. versionchanged:: 0.10.6 - added mongomock support | ||||
|     """ | ||||
|     conn_settings = { | ||||
|         'name': name or db or 'test', | ||||
|         'host': host or 'localhost', | ||||
|         'port': port or 27017, | ||||
|         'name': name or db or DEFAULT_DATABASE_NAME, | ||||
|         'host': host or DEFAULT_HOST, | ||||
|         'port': port or DEFAULT_PORT, | ||||
|         'read_preference': read_preference, | ||||
|         'username': username, | ||||
|         'password': password, | ||||
| @@ -69,6 +75,7 @@ def register_connection(alias, db=None, name=None, host=None, port=None, | ||||
|         'authentication_mechanism': authentication_mechanism | ||||
|     } | ||||
|  | ||||
|     _check_db_name(conn_settings['name']) | ||||
|     conn_host = conn_settings['host'] | ||||
|  | ||||
|     # Host can be a list or a string, so if string, force to a list. | ||||
| @@ -104,16 +111,28 @@ def register_connection(alias, db=None, name=None, host=None, port=None, | ||||
|                 conn_settings['authentication_source'] = uri_options['authsource'] | ||||
|             if 'authmechanism' in uri_options: | ||||
|                 conn_settings['authentication_mechanism'] = uri_options['authmechanism'] | ||||
|             if IS_PYMONGO_3 and 'readpreference' in uri_options: | ||||
|             if 'readpreference' in uri_options: | ||||
|                 read_preferences = ( | ||||
|                     ReadPreference.NEAREST, | ||||
|                     ReadPreference.PRIMARY, | ||||
|                     ReadPreference.PRIMARY_PREFERRED, | ||||
|                     ReadPreference.SECONDARY, | ||||
|                     ReadPreference.SECONDARY_PREFERRED) | ||||
|                 read_pf_mode = uri_options['readpreference'].lower() | ||||
|                     ReadPreference.SECONDARY_PREFERRED, | ||||
|                 ) | ||||
|  | ||||
|                 # Starting with PyMongo v3.5, the "readpreference" option is | ||||
|                 # returned as a string (e.g. "secondaryPreferred") and not an | ||||
|                 # int (e.g. 3). | ||||
|                 # TODO simplify the code below once we drop support for | ||||
|                 # PyMongo v3.4. | ||||
|                 read_pf_mode = uri_options['readpreference'] | ||||
|                 if isinstance(read_pf_mode, six.string_types): | ||||
|                     read_pf_mode = read_pf_mode.lower() | ||||
|                 for preference in read_preferences: | ||||
|                     if preference.name.lower() == read_pf_mode: | ||||
|                     if ( | ||||
|                         preference.name.lower() == read_pf_mode or | ||||
|                         preference.mode == read_pf_mode | ||||
|                     ): | ||||
|                         conn_settings['read_preference'] = preference | ||||
|                         break | ||||
|         else: | ||||
| @@ -125,17 +144,74 @@ def register_connection(alias, db=None, name=None, host=None, port=None, | ||||
|     kwargs.pop('is_slave', None) | ||||
|  | ||||
|     conn_settings.update(kwargs) | ||||
|     return conn_settings | ||||
|  | ||||
|  | ||||
| def register_connection(alias, db=None, name=None, host=None, port=None, | ||||
|                         read_preference=READ_PREFERENCE, | ||||
|                         username=None, password=None, | ||||
|                         authentication_source=None, | ||||
|                         authentication_mechanism=None, | ||||
|                         **kwargs): | ||||
|     """Register the connection settings. | ||||
|  | ||||
|     :param alias: the name that will be used to refer to this connection | ||||
|         throughout MongoEngine | ||||
|     :param name: the name of the specific database to use | ||||
|     :param db: the name of the database to use, for compatibility with connect | ||||
|     :param host: the host name of the :program:`mongod` instance to connect to | ||||
|     :param port: the port that the :program:`mongod` instance is running on | ||||
|     :param read_preference: The read preference for the collection | ||||
|     :param username: username to authenticate with | ||||
|     :param password: password to authenticate with | ||||
|     :param authentication_source: database to authenticate against | ||||
|     :param authentication_mechanism: database authentication mechanisms. | ||||
|         By default, use SCRAM-SHA-1 with MongoDB 3.0 and later, | ||||
|         MONGODB-CR (MongoDB Challenge Response protocol) for older servers. | ||||
|     :param is_mock: explicitly use mongomock for this connection | ||||
|         (can also be done by using `mongomock://` as db host prefix) | ||||
|     :param kwargs: ad-hoc parameters to be passed into the pymongo driver, | ||||
|         for example maxpoolsize, tz_aware, etc. See the documentation | ||||
|         for pymongo's `MongoClient` for a full list. | ||||
|  | ||||
|     .. versionchanged:: 0.10.6 - added mongomock support | ||||
|     """ | ||||
|     conn_settings = _get_connection_settings( | ||||
|         db=db, name=name, host=host, port=port, | ||||
|         read_preference=read_preference, | ||||
|         username=username, password=password, | ||||
|         authentication_source=authentication_source, | ||||
|         authentication_mechanism=authentication_mechanism, | ||||
|         **kwargs) | ||||
|     _connection_settings[alias] = conn_settings | ||||
|  | ||||
|  | ||||
| def disconnect(alias=DEFAULT_CONNECTION_NAME): | ||||
|     """Close the connection with a given alias.""" | ||||
|     from mongoengine.base.common import _get_documents_by_db | ||||
|     from mongoengine import Document | ||||
|  | ||||
|     if alias in _connections: | ||||
|         get_connection(alias=alias).close() | ||||
|         del _connections[alias] | ||||
|  | ||||
|     if alias in _dbs: | ||||
|         # Detach all cached collections in Documents | ||||
|         for doc_cls in _get_documents_by_db(alias, DEFAULT_CONNECTION_NAME): | ||||
|             if issubclass(doc_cls, Document):     # Skip EmbeddedDocument | ||||
|                 doc_cls._disconnect() | ||||
|  | ||||
|         del _dbs[alias] | ||||
|  | ||||
|     if alias in _connection_settings: | ||||
|         del _connection_settings[alias] | ||||
|  | ||||
|  | ||||
| def disconnect_all(): | ||||
|     """Close all registered databases.""" | ||||
|     for alias in list(_connections.keys()): | ||||
|         disconnect(alias) | ||||
|  | ||||
|  | ||||
| def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): | ||||
|     """Return a connection with a given alias.""" | ||||
| @@ -159,7 +235,6 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): | ||||
|         raise MongoEngineConnectionError(msg) | ||||
|  | ||||
|     def _clean_settings(settings_dict): | ||||
|         # set literal more efficient than calling set function | ||||
|         irrelevant_fields_set = { | ||||
|             'name', 'username', 'password', | ||||
|             'authentication_source', 'authentication_mechanism' | ||||
| @@ -169,10 +244,12 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): | ||||
|             if k not in irrelevant_fields_set | ||||
|         } | ||||
|  | ||||
|     raw_conn_settings = _connection_settings[alias].copy() | ||||
|  | ||||
|     # Retrieve a copy of the connection settings associated with the requested | ||||
|     # alias and remove the database name and authentication info (we don't | ||||
|     # care about them at this point). | ||||
|     conn_settings = _clean_settings(_connection_settings[alias].copy()) | ||||
|     conn_settings = _clean_settings(raw_conn_settings) | ||||
|  | ||||
|     # Determine if we should use PyMongo's or mongomock's MongoClient. | ||||
|     is_mock = conn_settings.pop('is_mock', False) | ||||
| @@ -186,51 +263,60 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): | ||||
|     else: | ||||
|         connection_class = MongoClient | ||||
|  | ||||
|         # For replica set connections with PyMongo 2.x, use | ||||
|         # MongoReplicaSetClient. | ||||
|         # TODO remove this once we stop supporting PyMongo 2.x. | ||||
|         if 'replicaSet' in conn_settings and not IS_PYMONGO_3: | ||||
|             connection_class = MongoReplicaSetClient | ||||
|             conn_settings['hosts_or_uri'] = conn_settings.pop('host', None) | ||||
|  | ||||
|             # hosts_or_uri has to be a string, so if 'host' was provided | ||||
|             # as a list, join its parts and separate them by ',' | ||||
|             if isinstance(conn_settings['hosts_or_uri'], list): | ||||
|                 conn_settings['hosts_or_uri'] = ','.join( | ||||
|                     conn_settings['hosts_or_uri']) | ||||
|  | ||||
|             # Discard port since it can't be used on MongoReplicaSetClient | ||||
|             conn_settings.pop('port', None) | ||||
|  | ||||
|     # Iterate over all of the connection settings and if a connection with | ||||
|     # the same parameters is already established, use it instead of creating | ||||
|     # a new one. | ||||
|     existing_connection = None | ||||
|     connection_settings_iterator = ( | ||||
|         (db_alias, settings.copy()) | ||||
|         for db_alias, settings in _connection_settings.items() | ||||
|     ) | ||||
|     for db_alias, connection_settings in connection_settings_iterator: | ||||
|         connection_settings = _clean_settings(connection_settings) | ||||
|         if conn_settings == connection_settings and _connections.get(db_alias): | ||||
|             existing_connection = _connections[db_alias] | ||||
|             break | ||||
|     # Re-use existing connection if one is suitable | ||||
|     existing_connection = _find_existing_connection(raw_conn_settings) | ||||
|  | ||||
|     # If an existing connection was found, assign it to the new alias | ||||
|     if existing_connection: | ||||
|         _connections[alias] = existing_connection | ||||
|     else: | ||||
|         # Otherwise, create the new connection for this alias. Raise | ||||
|         # MongoEngineConnectionError if it can't be established. | ||||
|         try: | ||||
|             _connections[alias] = connection_class(**conn_settings) | ||||
|         except Exception as e: | ||||
|             raise MongoEngineConnectionError( | ||||
|                 'Cannot connect to database %s :\n%s' % (alias, e)) | ||||
|         _connections[alias] = _create_connection(alias=alias, | ||||
|                                                  connection_class=connection_class, | ||||
|                                                  **conn_settings) | ||||
|  | ||||
|     return _connections[alias] | ||||
|  | ||||
|  | ||||
| def _create_connection(alias, connection_class, **connection_settings): | ||||
|     """ | ||||
|     Create the new connection for this alias. Raise | ||||
|     MongoEngineConnectionError if it can't be established. | ||||
|     """ | ||||
|     try: | ||||
|         return connection_class(**connection_settings) | ||||
|     except Exception as e: | ||||
|         raise MongoEngineConnectionError( | ||||
|             'Cannot connect to database %s :\n%s' % (alias, e)) | ||||
|  | ||||
|  | ||||
def _find_existing_connection(connection_settings):
    """Check if an existing connection could be reused

    Iterate over all of the registered connection settings; if an already
    established connection uses the same parameters, return it.

    :param connection_settings: the settings of the new connection
    :return: An existing connection or None
    """
    def _without_name(settings_dict):
        # Only remove the name but it's important to
        # keep the username/password/authentication_source/authentication_mechanism
        # to identify if the connection could be shared (cfr https://github.com/MongoEngine/mongoengine/issues/2047)
        return {key: value for key, value in settings_dict.items()
                if key != 'name'}

    target_settings = _without_name(connection_settings)
    for db_alias in _connection_settings:
        registered_settings = _without_name(_connection_settings[db_alias].copy())
        if registered_settings == target_settings and _connections.get(db_alias):
            return _connections[db_alias]
    return None
|  | ||||
|  | ||||
| def get_db(alias=DEFAULT_CONNECTION_NAME, reconnect=False): | ||||
|     if reconnect: | ||||
|         disconnect(alias) | ||||
| @@ -258,14 +344,24 @@ def connect(db=None, alias=DEFAULT_CONNECTION_NAME, **kwargs): | ||||
|     provide username and password arguments as well. | ||||
|  | ||||
|     Multiple databases are supported by using aliases. Provide a separate | ||||
|     `alias` to connect to a different instance of :program:`mongod`. | ||||
|     `alias` to connect to a different instance of :program:`mongod`. | ||||
|  | ||||
|     In order to replace a connection identified by a given alias, you'll | ||||
|     need to call ``disconnect`` first | ||||
|  | ||||
|     See the docstring for `register_connection` for more details about all | ||||
|     supported kwargs. | ||||
|  | ||||
|     .. versionchanged:: 0.6 - added multiple database support. | ||||
|     """ | ||||
|     if alias not in _connections: | ||||
|     if alias in _connections: | ||||
|         prev_conn_setting = _connection_settings[alias] | ||||
|         new_conn_settings = _get_connection_settings(db, **kwargs) | ||||
|  | ||||
|         if new_conn_settings != prev_conn_setting: | ||||
|             raise MongoEngineConnectionError( | ||||
|                 'A different connection with alias `%s` was already registered. Use disconnect() first' % alias) | ||||
|     else: | ||||
|         register_connection(alias, db, **kwargs) | ||||
|  | ||||
|     return get_connection(alias) | ||||
|   | ||||
| @@ -1,8 +1,11 @@ | ||||
| from contextlib import contextmanager | ||||
|  | ||||
| from pymongo.write_concern import WriteConcern | ||||
| from six import iteritems | ||||
|  | ||||
| from mongoengine.common import _import_class | ||||
| from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db | ||||
|  | ||||
| from mongoengine.pymongo_support import count_documents | ||||
|  | ||||
| __all__ = ('switch_db', 'switch_collection', 'no_dereference', | ||||
|            'no_sub_classes', 'query_counter', 'set_write_concern') | ||||
| @@ -112,7 +115,7 @@ class no_dereference(object): | ||||
|         GenericReferenceField = _import_class('GenericReferenceField') | ||||
|         ComplexBaseField = _import_class('ComplexBaseField') | ||||
|  | ||||
|         self.deref_fields = [k for k, v in self.cls._fields.iteritems() | ||||
|         self.deref_fields = [k for k, v in iteritems(self.cls._fields) | ||||
|                              if isinstance(v, (ReferenceField, | ||||
|                                                GenericReferenceField, | ||||
|                                                ComplexBaseField))] | ||||
| @@ -235,7 +238,7 @@ class query_counter(object): | ||||
|         and subtracting the queries issued by this context. In fact, every time this is called, 1 query is | ||||
|         issued so we need to balance that | ||||
|         """ | ||||
|         count = self.db.system.profile.find(self._ignored_query).count() - self._ctx_query_counter | ||||
|         count = count_documents(self.db.system.profile, self._ignored_query) - self._ctx_query_counter | ||||
|         self._ctx_query_counter += 1    # Account for the query we just issued to gather the information | ||||
|         return count | ||||
|  | ||||
|   | ||||
| @@ -1,5 +1,6 @@ | ||||
| from bson import DBRef, SON | ||||
| import six | ||||
| from six import iteritems | ||||
|  | ||||
| from mongoengine.base import (BaseDict, BaseList, EmbeddedDocumentList, | ||||
|                               TopLevelDocumentMetaclass, get_document) | ||||
| @@ -71,7 +72,7 @@ class DeReference(object): | ||||
|  | ||||
|                     def _get_items_from_dict(items): | ||||
|                         new_items = {} | ||||
|                         for k, v in items.iteritems(): | ||||
|                         for k, v in iteritems(items): | ||||
|                             value = v | ||||
|                             if isinstance(v, list): | ||||
|                                 value = _get_items_from_list(v) | ||||
| @@ -112,7 +113,7 @@ class DeReference(object): | ||||
|         depth += 1 | ||||
|         for item in iterator: | ||||
|             if isinstance(item, (Document, EmbeddedDocument)): | ||||
|                 for field_name, field in item._fields.iteritems(): | ||||
|                 for field_name, field in iteritems(item._fields): | ||||
|                     v = item._data.get(field_name, None) | ||||
|                     if isinstance(v, LazyReference): | ||||
|                         # LazyReference inherits DBRef but should not be dereferenced here ! | ||||
| @@ -124,7 +125,7 @@ class DeReference(object): | ||||
|                     elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth: | ||||
|                         field_cls = getattr(getattr(field, 'field', None), 'document_type', None) | ||||
|                         references = self._find_references(v, depth) | ||||
|                         for key, refs in references.iteritems(): | ||||
|                         for key, refs in iteritems(references): | ||||
|                             if isinstance(field_cls, (Document, TopLevelDocumentMetaclass)): | ||||
|                                 key = field_cls | ||||
|                             reference_map.setdefault(key, set()).update(refs) | ||||
| @@ -137,7 +138,7 @@ class DeReference(object): | ||||
|                 reference_map.setdefault(get_document(item['_cls']), set()).add(item['_ref'].id) | ||||
|             elif isinstance(item, (dict, list, tuple)) and depth - 1 <= self.max_depth: | ||||
|                 references = self._find_references(item, depth - 1) | ||||
|                 for key, refs in references.iteritems(): | ||||
|                 for key, refs in iteritems(references): | ||||
|                     reference_map.setdefault(key, set()).update(refs) | ||||
|  | ||||
|         return reference_map | ||||
| @@ -146,7 +147,7 @@ class DeReference(object): | ||||
|         """Fetch all references and convert to their document objects | ||||
|         """ | ||||
|         object_map = {} | ||||
|         for collection, dbrefs in self.reference_map.iteritems(): | ||||
|         for collection, dbrefs in iteritems(self.reference_map): | ||||
|  | ||||
|             # we use getattr instead of hasattr because hasattr swallows any exception under python2 | ||||
|             # so it could hide nasty things without raising exceptions (cfr bug #1688)) | ||||
| @@ -157,7 +158,7 @@ class DeReference(object): | ||||
|                 refs = [dbref for dbref in dbrefs | ||||
|                         if (col_name, dbref) not in object_map] | ||||
|                 references = collection.objects.in_bulk(refs) | ||||
|                 for key, doc in references.iteritems(): | ||||
|                 for key, doc in iteritems(references): | ||||
|                     object_map[(col_name, key)] = doc | ||||
|             else:  # Generic reference: use the refs data to convert to document | ||||
|                 if isinstance(doc_type, (ListField, DictField, MapField)): | ||||
| @@ -229,7 +230,7 @@ class DeReference(object): | ||||
|             data = [] | ||||
|         else: | ||||
|             is_list = False | ||||
|             iterator = items.iteritems() | ||||
|             iterator = iteritems(items) | ||||
|             data = {} | ||||
|  | ||||
|         depth += 1 | ||||
|   | ||||
| @@ -5,6 +5,7 @@ from bson.dbref import DBRef | ||||
| import pymongo | ||||
| from pymongo.read_preferences import ReadPreference | ||||
| import six | ||||
| from six import iteritems | ||||
|  | ||||
| from mongoengine import signals | ||||
| from mongoengine.base import (BaseDict, BaseDocument, BaseList, | ||||
| @@ -17,7 +18,7 @@ from mongoengine.context_managers import (set_write_concern, | ||||
|                                           switch_db) | ||||
| from mongoengine.errors import (InvalidDocumentError, InvalidQueryError, | ||||
|                                 SaveConditionError) | ||||
| from mongoengine.python_support import IS_PYMONGO_3 | ||||
| from mongoengine.pymongo_support import list_collection_names | ||||
| from mongoengine.queryset import (NotUniqueError, OperationError, | ||||
|                                   QuerySet, transform) | ||||
|  | ||||
| @@ -175,10 +176,16 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): | ||||
|         return get_db(cls._meta.get('db_alias', DEFAULT_CONNECTION_NAME)) | ||||
|  | ||||
|     @classmethod | ||||
|     def _get_collection(cls): | ||||
|         """Return a PyMongo collection for the document.""" | ||||
|         if not hasattr(cls, '_collection') or cls._collection is None: | ||||
|     def _disconnect(cls): | ||||
|         """Detach the Document class from the (cached) database collection""" | ||||
|         cls._collection = None | ||||
|  | ||||
|     @classmethod | ||||
|     def _get_collection(cls): | ||||
|         """Return the corresponding PyMongo collection of this document. | ||||
|         Upon the first call, it will ensure that indexes get created. The returned collection is then cached | ||||
|         """ | ||||
|         if not hasattr(cls, '_collection') or cls._collection is None: | ||||
|             # Get the collection, either capped or regular. | ||||
|             if cls._meta.get('max_size') or cls._meta.get('max_documents'): | ||||
|                 cls._collection = cls._get_capped_collection() | ||||
| @@ -215,7 +222,7 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): | ||||
|  | ||||
|         # If the collection already exists and has different options | ||||
|         # (i.e. isn't capped or has different max/size), raise an error. | ||||
|         if collection_name in db.collection_names(): | ||||
|         if collection_name in list_collection_names(db, include_system_collections=True): | ||||
|             collection = db[collection_name] | ||||
|             options = collection.options() | ||||
|             if ( | ||||
| @@ -240,7 +247,7 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): | ||||
|         data = super(Document, self).to_mongo(*args, **kwargs) | ||||
|  | ||||
|         # If '_id' is None, try and set it from self._data. If that | ||||
|         # doesn't exist either, remote '_id' from the SON completely. | ||||
|         # doesn't exist either, remove '_id' from the SON completely. | ||||
|         if data['_id'] is None: | ||||
|             if self._data.get('id') is None: | ||||
|                 del data['_id'] | ||||
| @@ -346,21 +353,21 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): | ||||
|         .. versionchanged:: 0.10.7 | ||||
|             Add signal_kwargs argument | ||||
|         """ | ||||
|         signal_kwargs = signal_kwargs or {} | ||||
|  | ||||
|         if self._meta.get('abstract'): | ||||
|             raise InvalidDocumentError('Cannot save an abstract document.') | ||||
|  | ||||
|         signal_kwargs = signal_kwargs or {} | ||||
|         signals.pre_save.send(self.__class__, document=self, **signal_kwargs) | ||||
|  | ||||
|         if validate: | ||||
|             self.validate(clean=clean) | ||||
|  | ||||
|         if write_concern is None: | ||||
|             write_concern = {'w': 1} | ||||
|             write_concern = {} | ||||
|  | ||||
|         doc = self.to_mongo() | ||||
|  | ||||
|         created = ('_id' not in doc or self._created or force_insert) | ||||
|         doc_id = self.to_mongo(fields=['id']) | ||||
|         created = ('_id' not in doc_id or self._created or force_insert) | ||||
|  | ||||
|         signals.pre_save_post_validation.send(self.__class__, document=self, | ||||
|                                               created=created, **signal_kwargs) | ||||
| @@ -438,16 +445,6 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): | ||||
|  | ||||
|             object_id = wc_collection.insert_one(doc).inserted_id | ||||
|  | ||||
|         # In PyMongo 3.0, the save() call calls internally the _update() call | ||||
|         # but they forget to return the _id value passed back, therefore getting it back here | ||||
|         # Correct behaviour in 2.X and in 3.0.1+ versions | ||||
|         if not object_id and pymongo.version_tuple == (3, 0): | ||||
|             pk_as_mongo_obj = self._fields.get(self._meta['id_field']).to_mongo(self.pk) | ||||
|             object_id = ( | ||||
|                 self._qs.filter(pk=pk_as_mongo_obj).first() and | ||||
|                 self._qs.filter(pk=pk_as_mongo_obj).first().pk | ||||
|             )  # TODO doesn't this make 2 queries? | ||||
|  | ||||
|         return object_id | ||||
|  | ||||
|     def _get_update_doc(self): | ||||
| @@ -493,8 +490,12 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): | ||||
|         update_doc = self._get_update_doc() | ||||
|         if update_doc: | ||||
|             upsert = save_condition is None | ||||
|             last_error = collection.update(select_dict, update_doc, | ||||
|                                            upsert=upsert, **write_concern) | ||||
|             with set_write_concern(collection, write_concern) as wc_collection: | ||||
|                 last_error = wc_collection.update_one( | ||||
|                     select_dict, | ||||
|                     update_doc, | ||||
|                     upsert=upsert | ||||
|                 ).raw_result | ||||
|             if not upsert and last_error['n'] == 0: | ||||
|                 raise SaveConditionError('Race condition preventing' | ||||
|                                          ' document update detected') | ||||
| @@ -601,7 +602,7 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): | ||||
|  | ||||
|         # Delete FileFields separately | ||||
|         FileField = _import_class('FileField') | ||||
|         for name, field in self._fields.iteritems(): | ||||
|         for name, field in iteritems(self._fields): | ||||
|             if isinstance(field, FileField): | ||||
|                 getattr(self, name).delete() | ||||
|  | ||||
| @@ -786,13 +787,13 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): | ||||
|         .. versionchanged:: 0.10.7 | ||||
|             :class:`OperationError` exception raised if no collection available | ||||
|         """ | ||||
|         col_name = cls._get_collection_name() | ||||
|         if not col_name: | ||||
|         coll_name = cls._get_collection_name() | ||||
|         if not coll_name: | ||||
|             raise OperationError('Document %s has no collection defined ' | ||||
|                                  '(is it abstract ?)' % cls) | ||||
|         cls._collection = None | ||||
|         db = cls._get_db() | ||||
|         db.drop_collection(col_name) | ||||
|         db.drop_collection(coll_name) | ||||
|  | ||||
|     @classmethod | ||||
|     def create_index(cls, keys, background=False, **kwargs): | ||||
| @@ -807,18 +808,13 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): | ||||
|         index_spec = index_spec.copy() | ||||
|         fields = index_spec.pop('fields') | ||||
|         drop_dups = kwargs.get('drop_dups', False) | ||||
|         if IS_PYMONGO_3 and drop_dups: | ||||
|         if drop_dups: | ||||
|             msg = 'drop_dups is deprecated and is removed when using PyMongo 3+.' | ||||
|             warnings.warn(msg, DeprecationWarning) | ||||
|         elif not IS_PYMONGO_3: | ||||
|             index_spec['drop_dups'] = drop_dups | ||||
|         index_spec['background'] = background | ||||
|         index_spec.update(kwargs) | ||||
|  | ||||
|         if IS_PYMONGO_3: | ||||
|             return cls._get_collection().create_index(fields, **index_spec) | ||||
|         else: | ||||
|             return cls._get_collection().ensure_index(fields, **index_spec) | ||||
|         return cls._get_collection().create_index(fields, **index_spec) | ||||
|  | ||||
|     @classmethod | ||||
|     def ensure_index(cls, key_or_list, drop_dups=False, background=False, | ||||
| @@ -833,11 +829,9 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): | ||||
|         :param drop_dups: Was removed/ignored with MongoDB >2.7.5. The value | ||||
|             will be removed if PyMongo3+ is used | ||||
|         """ | ||||
|         if IS_PYMONGO_3 and drop_dups: | ||||
|         if drop_dups: | ||||
|             msg = 'drop_dups is deprecated and is removed when using PyMongo 3+.' | ||||
|             warnings.warn(msg, DeprecationWarning) | ||||
|         elif not IS_PYMONGO_3: | ||||
|             kwargs.update({'drop_dups': drop_dups}) | ||||
|         return cls.create_index(key_or_list, background=background, **kwargs) | ||||
|  | ||||
|     @classmethod | ||||
| @@ -853,7 +847,7 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): | ||||
|         drop_dups = cls._meta.get('index_drop_dups', False) | ||||
|         index_opts = cls._meta.get('index_opts') or {} | ||||
|         index_cls = cls._meta.get('index_cls', True) | ||||
|         if IS_PYMONGO_3 and drop_dups: | ||||
|         if drop_dups: | ||||
|             msg = 'drop_dups is deprecated and is removed when using PyMongo 3+.' | ||||
|             warnings.warn(msg, DeprecationWarning) | ||||
|  | ||||
| @@ -884,11 +878,7 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): | ||||
|                 if 'cls' in opts: | ||||
|                     del opts['cls'] | ||||
|  | ||||
|                 if IS_PYMONGO_3: | ||||
|                     collection.create_index(fields, background=background, **opts) | ||||
|                 else: | ||||
|                     collection.ensure_index(fields, background=background, | ||||
|                                             drop_dups=drop_dups, **opts) | ||||
|                 collection.create_index(fields, background=background, **opts) | ||||
|  | ||||
|         # If _cls is being used (for polymorphism), it needs an index, | ||||
|         # only if another index doesn't begin with _cls | ||||
| @@ -899,12 +889,8 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): | ||||
|             if 'cls' in index_opts: | ||||
|                 del index_opts['cls'] | ||||
|  | ||||
|             if IS_PYMONGO_3: | ||||
|                 collection.create_index('_cls', background=background, | ||||
|                                         **index_opts) | ||||
|             else: | ||||
|                 collection.ensure_index('_cls', background=background, | ||||
|                                         **index_opts) | ||||
|             collection.create_index('_cls', background=background, | ||||
|                                     **index_opts) | ||||
|  | ||||
|     @classmethod | ||||
|     def list_indexes(cls): | ||||
|   | ||||
| @@ -1,11 +1,12 @@ | ||||
| from collections import defaultdict | ||||
|  | ||||
| import six | ||||
| from six import iteritems | ||||
|  | ||||
| __all__ = ('NotRegistered', 'InvalidDocumentError', 'LookUpError', | ||||
|            'DoesNotExist', 'MultipleObjectsReturned', 'InvalidQueryError', | ||||
|            'OperationError', 'NotUniqueError', 'FieldDoesNotExist', | ||||
|            'ValidationError', 'SaveConditionError') | ||||
|            'ValidationError', 'SaveConditionError', 'DeprecatedError') | ||||
|  | ||||
|  | ||||
| class NotRegistered(Exception): | ||||
| @@ -109,11 +110,8 @@ class ValidationError(AssertionError): | ||||
|  | ||||
|         def build_dict(source): | ||||
|             errors_dict = {} | ||||
|             if not source: | ||||
|                 return errors_dict | ||||
|  | ||||
|             if isinstance(source, dict): | ||||
|                 for field_name, error in source.iteritems(): | ||||
|                 for field_name, error in iteritems(source): | ||||
|                     errors_dict[field_name] = build_dict(error) | ||||
|             elif isinstance(source, ValidationError) and source.errors: | ||||
|                 return build_dict(source.errors) | ||||
| @@ -135,12 +133,17 @@ class ValidationError(AssertionError): | ||||
|                 value = ' '.join([generate_key(k) for k in value]) | ||||
|             elif isinstance(value, dict): | ||||
|                 value = ' '.join( | ||||
|                     [generate_key(v, k) for k, v in value.iteritems()]) | ||||
|                     [generate_key(v, k) for k, v in iteritems(value)]) | ||||
|  | ||||
|             results = '%s.%s' % (prefix, value) if prefix else value | ||||
|             return results | ||||
|  | ||||
|         error_dict = defaultdict(list) | ||||
|         for k, v in self.to_dict().iteritems(): | ||||
|         for k, v in iteritems(self.to_dict()): | ||||
|             error_dict[generate_key(v)].append(k) | ||||
|         return ' '.join(['%s: %s' % (k, v) for k, v in error_dict.iteritems()]) | ||||
|         return ' '.join(['%s: %s' % (k, v) for k, v in iteritems(error_dict)]) | ||||
|  | ||||
|  | ||||
| class DeprecatedError(Exception): | ||||
|     """Raise when a user uses a feature that has been Deprecated""" | ||||
|     pass | ||||
|   | ||||
| @@ -11,6 +11,7 @@ from bson import Binary, DBRef, ObjectId, SON | ||||
| import gridfs | ||||
| import pymongo | ||||
| import six | ||||
| from six import iteritems | ||||
|  | ||||
| try: | ||||
|     import dateutil | ||||
| @@ -36,6 +37,7 @@ from mongoengine.errors import DoesNotExist, InvalidQueryError, ValidationError | ||||
| from mongoengine.python_support import StringIO | ||||
| from mongoengine.queryset import DO_NOTHING | ||||
| from mongoengine.queryset.base import BaseQuerySet | ||||
| from mongoengine.queryset.transform import STRING_OPERATORS | ||||
|  | ||||
| try: | ||||
|     from PIL import Image, ImageOps | ||||
| @@ -105,11 +107,11 @@ class StringField(BaseField): | ||||
|         if not isinstance(op, six.string_types): | ||||
|             return value | ||||
|  | ||||
|         if op.lstrip('i') in ('startswith', 'endswith', 'contains', 'exact'): | ||||
|             flags = 0 | ||||
|             if op.startswith('i'): | ||||
|                 flags = re.IGNORECASE | ||||
|                 op = op.lstrip('i') | ||||
|         if op in STRING_OPERATORS: | ||||
|             case_insensitive = op.startswith('i') | ||||
|             op = op.lstrip('i') | ||||
|  | ||||
|             flags = re.IGNORECASE if case_insensitive else 0 | ||||
|  | ||||
|             regex = r'%s' | ||||
|             if op == 'startswith': | ||||
| @@ -151,12 +153,10 @@ class URLField(StringField): | ||||
|         scheme = value.split('://')[0].lower() | ||||
|         if scheme not in self.schemes: | ||||
|             self.error(u'Invalid scheme {} in URL: {}'.format(scheme, value)) | ||||
|             return | ||||
|  | ||||
|         # Then check full URL | ||||
|         if not self.url_regex.match(value): | ||||
|             self.error(u'Invalid URL: {}'.format(value)) | ||||
|             return | ||||
|  | ||||
|  | ||||
| class EmailField(StringField): | ||||
| @@ -258,10 +258,10 @@ class EmailField(StringField): | ||||
|             try: | ||||
|                 domain_part = domain_part.encode('idna').decode('ascii') | ||||
|             except UnicodeError: | ||||
|                 self.error(self.error_msg % value) | ||||
|                 self.error("%s %s" % (self.error_msg % value, "(domain failed IDN encoding)")) | ||||
|             else: | ||||
|                 if not self.validate_domain_part(domain_part): | ||||
|                     self.error(self.error_msg % value) | ||||
|                     self.error("%s %s" % (self.error_msg % value, "(domain validation failed)")) | ||||
|  | ||||
|  | ||||
| class IntField(BaseField): | ||||
| @@ -498,15 +498,18 @@ class DateTimeField(BaseField): | ||||
|         if not isinstance(value, six.string_types): | ||||
|             return None | ||||
|  | ||||
|         return self._parse_datetime(value) | ||||
|  | ||||
|     def _parse_datetime(self, value): | ||||
|         # Attempt to parse a datetime from a string | ||||
|         value = value.strip() | ||||
|         if not value: | ||||
|             return None | ||||
|  | ||||
|         # Attempt to parse a datetime: | ||||
|         if dateutil: | ||||
|             try: | ||||
|                 return dateutil.parser.parse(value) | ||||
|             except (TypeError, ValueError): | ||||
|             except (TypeError, ValueError, OverflowError): | ||||
|                 return None | ||||
|  | ||||
|         # split usecs, because they are not recognized by strptime. | ||||
| @@ -699,7 +702,11 @@ class EmbeddedDocumentField(BaseField): | ||||
|         self.document_type.validate(value, clean) | ||||
|  | ||||
|     def lookup_member(self, member_name): | ||||
|         return self.document_type._fields.get(member_name) | ||||
|         doc_and_subclasses = [self.document_type] + self.document_type.__subclasses__() | ||||
|         for doc_type in doc_and_subclasses: | ||||
|             field = doc_type._fields.get(member_name) | ||||
|             if field: | ||||
|                 return field | ||||
|  | ||||
|     def prepare_query_value(self, op, value): | ||||
|         if value is not None and not isinstance(value, self.document_type): | ||||
| @@ -746,12 +753,13 @@ class GenericEmbeddedDocumentField(BaseField): | ||||
|         value.validate(clean=clean) | ||||
|  | ||||
|     def lookup_member(self, member_name): | ||||
|         if self.choices: | ||||
|             for choice in self.choices: | ||||
|                 field = choice._fields.get(member_name) | ||||
|         document_choices = self.choices or [] | ||||
|         for document_choice in document_choices: | ||||
|             doc_and_subclasses = [document_choice] + document_choice.__subclasses__() | ||||
|             for doc_type in doc_and_subclasses: | ||||
|                 field = doc_type._fields.get(member_name) | ||||
|                 if field: | ||||
|                     return field | ||||
|         return None | ||||
|  | ||||
|     def to_mongo(self, document, use_db_field=True, fields=None): | ||||
|         if document is None: | ||||
| @@ -794,12 +802,12 @@ class DynamicField(BaseField): | ||||
|             value = {k: v for k, v in enumerate(value)} | ||||
|  | ||||
|         data = {} | ||||
|         for k, v in value.iteritems(): | ||||
|         for k, v in iteritems(value): | ||||
|             data[k] = self.to_mongo(v, use_db_field, fields) | ||||
|  | ||||
|         value = data | ||||
|         if is_list:  # Convert back to a list | ||||
|             value = [v for k, v in sorted(data.iteritems(), key=itemgetter(0))] | ||||
|             value = [v for k, v in sorted(iteritems(data), key=itemgetter(0))] | ||||
|         return value | ||||
|  | ||||
|     def to_python(self, value): | ||||
|   | ||||
							
								
								
									
										19
									
								
								mongoengine/mongodb_support.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										19
									
								
								mongoengine/mongodb_support.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,19 @@ | ||||
| """ | ||||
| Helper functions, constants, and types to aid with MongoDB version support | ||||
| """ | ||||
| from mongoengine.connection import get_connection | ||||
|  | ||||
|  | ||||
| # Constant that can be used to compare the version retrieved with | ||||
| # get_mongodb_version() | ||||
| MONGODB_34 = (3, 4) | ||||
| MONGODB_36 = (3, 6) | ||||
|  | ||||
|  | ||||
def get_mongodb_version():
    """Return the version of the connected mongoDB (first 2 digits)

    :return: tuple(int, int)
    """
    # server_info()['versionArray'] is e.g. [3, 2, 0, 0]; keep major/minor only.
    return tuple(get_connection().server_info()['versionArray'][:2])
							
								
								
									
										32
									
								
								mongoengine/pymongo_support.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										32
									
								
								mongoengine/pymongo_support.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,32 @@ | ||||
| """ | ||||
| Helper functions, constants, and types to aid with PyMongo v2.7 - v3.x support. | ||||
| """ | ||||
| import pymongo | ||||
|  | ||||
# First PyMongo release that deprecates count()/collection_names() in
# favour of count_documents()/list_collection_names().
_PYMONGO_37 = (3, 7)

# (major, minor) of the PyMongo driver actually installed.
PYMONGO_VERSION = tuple(pymongo.version_tuple[:2])

# True when the installed driver exposes the PyMongo 3.7+ API surface.
IS_PYMONGO_GTE_37 = PYMONGO_VERSION >= _PYMONGO_37
|  | ||||
|  | ||||
def count_documents(collection, filter):
    """Pymongo>3.7 deprecates count in favour of count_documents

    :param collection: the PyMongo collection to count documents in
    :param filter: the query filter that documents must match
    :return: the number of matching documents
    """
    if IS_PYMONGO_GTE_37:
        return collection.count_documents(filter)
    # Older drivers only offer the (since-deprecated) cursor count().
    return collection.find(filter).count()
|  | ||||
|  | ||||
def list_collection_names(db, include_system_collections=False):
    """Pymongo>3.7 deprecates collection_names in favour of list_collection_names

    :param db: the PyMongo database to list collections of
    :param include_system_collections: when False (the default), drop any
        collection whose name starts with ``system.``
    :return: list of collection names
    """
    if IS_PYMONGO_GTE_37:
        names = db.list_collection_names()
    else:
        names = db.collection_names()

    if include_system_collections:
        return names
    return [name for name in names if not name.startswith('system.')]
| @@ -1,13 +1,8 @@ | ||||
| """ | ||||
| Helper functions, constants, and types to aid with Python v2.7 - v3.x and | ||||
| PyMongo v2.7 - v3.x support. | ||||
| Helper functions, constants, and types to aid with Python v2.7 - v3.x support | ||||
| """ | ||||
| import pymongo | ||||
| import six | ||||
|  | ||||
|  | ||||
| IS_PYMONGO_3 = pymongo.version_tuple[0] >= 3 | ||||
|  | ||||
| # six.BytesIO resolves to StringIO.StringIO in Py2 and io.BytesIO in Py3. | ||||
| StringIO = six.BytesIO | ||||
|  | ||||
|   | ||||
| @@ -10,8 +10,10 @@ from bson import SON, json_util | ||||
| from bson.code import Code | ||||
| import pymongo | ||||
| import pymongo.errors | ||||
| from pymongo.collection import ReturnDocument | ||||
| from pymongo.common import validate_read_preference | ||||
| import six | ||||
| from six import iteritems | ||||
|  | ||||
| from mongoengine import signals | ||||
| from mongoengine.base import get_document | ||||
| @@ -20,14 +22,10 @@ from mongoengine.connection import get_db | ||||
| from mongoengine.context_managers import set_write_concern, switch_db | ||||
| from mongoengine.errors import (InvalidQueryError, LookUpError, | ||||
|                                 NotUniqueError, OperationError) | ||||
| from mongoengine.python_support import IS_PYMONGO_3 | ||||
| from mongoengine.queryset import transform | ||||
| from mongoengine.queryset.field_list import QueryFieldList | ||||
| from mongoengine.queryset.visitor import Q, QNode | ||||
|  | ||||
| if IS_PYMONGO_3: | ||||
|     from pymongo.collection import ReturnDocument | ||||
|  | ||||
|  | ||||
| __all__ = ('BaseQuerySet', 'DO_NOTHING', 'NULLIFY', 'CASCADE', 'DENY', 'PULL') | ||||
|  | ||||
| @@ -196,7 +194,7 @@ class BaseQuerySet(object): | ||||
|                 only_fields=self.only_fields | ||||
|             ) | ||||
|  | ||||
|         raise AttributeError('Provide a slice or an integer index') | ||||
|         raise TypeError('Provide a slice or an integer index') | ||||
|  | ||||
|     def __iter__(self): | ||||
|         raise NotImplementedError | ||||
| @@ -498,11 +496,12 @@ class BaseQuerySet(object): | ||||
|             ``save(..., write_concern={w: 2, fsync: True}, ...)`` will | ||||
|             wait until at least two servers have recorded the write and | ||||
|             will force an fsync on the primary server. | ||||
|         :param full_result: Return the full result dictionary rather than just the number | ||||
|             updated, e.g. return | ||||
|             ``{'n': 2, 'nModified': 2, 'ok': 1.0, 'updatedExisting': True}``. | ||||
|         :param full_result: Return the associated ``pymongo.UpdateResult`` rather than just the number | ||||
|             updated items | ||||
|         :param update: Django-style update keyword arguments | ||||
|  | ||||
|         :returns the number of updated documents (unless ``full_result`` is True) | ||||
|  | ||||
|         .. versionadded:: 0.2 | ||||
|         """ | ||||
|         if not update and not upsert: | ||||
| @@ -566,7 +565,7 @@ class BaseQuerySet(object): | ||||
|             document = self._document.objects.with_id(atomic_update.upserted_id) | ||||
|         return document | ||||
|  | ||||
|     def update_one(self, upsert=False, write_concern=None, **update): | ||||
|     def update_one(self, upsert=False, write_concern=None, full_result=False, **update): | ||||
|         """Perform an atomic update on the fields of the first document | ||||
|         matched by the query. | ||||
|  | ||||
| @@ -577,12 +576,19 @@ class BaseQuerySet(object): | ||||
|             ``save(..., write_concern={w: 2, fsync: True}, ...)`` will | ||||
|             wait until at least two servers have recorded the write and | ||||
|             will force an fsync on the primary server. | ||||
|         :param full_result: Return the associated ``pymongo.UpdateResult`` rather than just the number | ||||
|             updated items | ||||
|         :param update: Django-style update keyword arguments | ||||
|  | ||||
|             full_result | ||||
|         :returns the number of updated documents (unless ``full_result`` is True) | ||||
|         .. versionadded:: 0.2 | ||||
|         """ | ||||
|         return self.update( | ||||
|             upsert=upsert, multi=False, write_concern=write_concern, **update) | ||||
|             upsert=upsert, | ||||
|             multi=False, | ||||
|             write_concern=write_concern, | ||||
|             full_result=full_result, | ||||
|             **update) | ||||
|  | ||||
|     def modify(self, upsert=False, full_response=False, remove=False, new=False, **update): | ||||
|         """Update and return the updated document. | ||||
| @@ -617,31 +623,25 @@ class BaseQuerySet(object): | ||||
|  | ||||
|         queryset = self.clone() | ||||
|         query = queryset._query | ||||
|         if not IS_PYMONGO_3 or not remove: | ||||
|         if not remove: | ||||
|             update = transform.update(queryset._document, **update) | ||||
|         sort = queryset._ordering | ||||
|  | ||||
|         try: | ||||
|             if IS_PYMONGO_3: | ||||
|                 if full_response: | ||||
|                     msg = 'With PyMongo 3+, it is not possible anymore to get the full response.' | ||||
|                     warnings.warn(msg, DeprecationWarning) | ||||
|                 if remove: | ||||
|                     result = queryset._collection.find_one_and_delete( | ||||
|                         query, sort=sort, **self._cursor_args) | ||||
|                 else: | ||||
|                     if new: | ||||
|                         return_doc = ReturnDocument.AFTER | ||||
|                     else: | ||||
|                         return_doc = ReturnDocument.BEFORE | ||||
|                     result = queryset._collection.find_one_and_update( | ||||
|                         query, update, upsert=upsert, sort=sort, return_document=return_doc, | ||||
|                         **self._cursor_args) | ||||
|  | ||||
|             if full_response: | ||||
|                 msg = 'With PyMongo 3+, it is not possible anymore to get the full response.' | ||||
|                 warnings.warn(msg, DeprecationWarning) | ||||
|             if remove: | ||||
|                 result = queryset._collection.find_one_and_delete( | ||||
|                     query, sort=sort, **self._cursor_args) | ||||
|             else: | ||||
|                 result = queryset._collection.find_and_modify( | ||||
|                     query, update, upsert=upsert, sort=sort, remove=remove, new=new, | ||||
|                     full_response=full_response, **self._cursor_args) | ||||
|                 if new: | ||||
|                     return_doc = ReturnDocument.AFTER | ||||
|                 else: | ||||
|                     return_doc = ReturnDocument.BEFORE | ||||
|                 result = queryset._collection.find_one_and_update( | ||||
|                     query, update, upsert=upsert, sort=sort, return_document=return_doc, | ||||
|                     **self._cursor_args) | ||||
|         except pymongo.errors.DuplicateKeyError as err: | ||||
|             raise NotUniqueError(u'Update failed (%s)' % err) | ||||
|         except pymongo.errors.OperationFailure as err: | ||||
| @@ -748,7 +748,7 @@ class BaseQuerySet(object): | ||||
|                       '_read_preference', '_iter', '_scalar', '_as_pymongo', | ||||
|                       '_limit', '_skip', '_hint', '_auto_dereference', | ||||
|                       '_search_text', 'only_fields', '_max_time_ms', | ||||
|                       '_comment') | ||||
|                       '_comment', '_batch_size') | ||||
|  | ||||
|         for prop in copy_props: | ||||
|             val = getattr(self, prop) | ||||
| @@ -1073,15 +1073,14 @@ class BaseQuerySet(object): | ||||
|         ..versionchanged:: 0.5 - made chainable | ||||
|         .. deprecated:: Ignored with PyMongo 3+ | ||||
|         """ | ||||
|         if IS_PYMONGO_3: | ||||
|             msg = 'snapshot is deprecated as it has no impact when using PyMongo 3+.' | ||||
|             warnings.warn(msg, DeprecationWarning) | ||||
|         msg = 'snapshot is deprecated as it has no impact when using PyMongo 3+.' | ||||
|         warnings.warn(msg, DeprecationWarning) | ||||
|         queryset = self.clone() | ||||
|         queryset._snapshot = enabled | ||||
|         return queryset | ||||
|  | ||||
|     def timeout(self, enabled): | ||||
|         """Enable or disable the default mongod timeout when querying. | ||||
|         """Enable or disable the default mongod timeout when querying. (no_cursor_timeout option) | ||||
|  | ||||
|         :param enabled: whether or not the timeout is used | ||||
|  | ||||
| @@ -1099,9 +1098,8 @@ class BaseQuerySet(object): | ||||
|  | ||||
|         .. deprecated:: Ignored with PyMongo 3+ | ||||
|         """ | ||||
|         if IS_PYMONGO_3: | ||||
|             msg = 'slave_okay is deprecated as it has no impact when using PyMongo 3+.' | ||||
|             warnings.warn(msg, DeprecationWarning) | ||||
|         msg = 'slave_okay is deprecated as it has no impact when using PyMongo 3+.' | ||||
|         warnings.warn(msg, DeprecationWarning) | ||||
|         queryset = self.clone() | ||||
|         queryset._slave_okay = enabled | ||||
|         return queryset | ||||
| @@ -1191,14 +1189,18 @@ class BaseQuerySet(object): | ||||
|             initial_pipeline.append({'$sort': dict(self._ordering)}) | ||||
|  | ||||
|         if self._limit is not None: | ||||
|             initial_pipeline.append({'$limit': self._limit}) | ||||
|             # As per MongoDB Documentation (https://docs.mongodb.com/manual/reference/operator/aggregation/limit/), | ||||
|             # keeping limit stage right after sort stage is more efficient. But this leads to wrong set of documents | ||||
|             # for a skip stage that might succeed these. So we need to maintain more documents in memory in such a | ||||
|             # case (https://stackoverflow.com/a/24161461). | ||||
|             initial_pipeline.append({'$limit': self._limit + (self._skip or 0)}) | ||||
|  | ||||
|         if self._skip is not None: | ||||
|             initial_pipeline.append({'$skip': self._skip}) | ||||
|  | ||||
|         pipeline = initial_pipeline + list(pipeline) | ||||
|  | ||||
|         if IS_PYMONGO_3 and self._read_preference is not None: | ||||
|         if self._read_preference is not None: | ||||
|             return self._collection.with_options(read_preference=self._read_preference) \ | ||||
|                        .aggregate(pipeline, cursor={}, **kwargs) | ||||
|  | ||||
| @@ -1408,11 +1410,7 @@ class BaseQuerySet(object): | ||||
|         if isinstance(field_instances[-1], ListField): | ||||
|             pipeline.insert(1, {'$unwind': '$' + field}) | ||||
|  | ||||
|         result = self._document._get_collection().aggregate(pipeline) | ||||
|         if IS_PYMONGO_3: | ||||
|             result = tuple(result) | ||||
|         else: | ||||
|             result = result.get('result') | ||||
|         result = tuple(self._document._get_collection().aggregate(pipeline)) | ||||
|  | ||||
|         if result: | ||||
|             return result[0]['total'] | ||||
| @@ -1439,11 +1437,7 @@ class BaseQuerySet(object): | ||||
|         if isinstance(field_instances[-1], ListField): | ||||
|             pipeline.insert(1, {'$unwind': '$' + field}) | ||||
|  | ||||
|         result = self._document._get_collection().aggregate(pipeline) | ||||
|         if IS_PYMONGO_3: | ||||
|             result = tuple(result) | ||||
|         else: | ||||
|             result = result.get('result') | ||||
|         result = tuple(self._document._get_collection().aggregate(pipeline)) | ||||
|         if result: | ||||
|             return result[0]['total'] | ||||
|         return 0 | ||||
| @@ -1518,26 +1512,16 @@ class BaseQuerySet(object): | ||||
|  | ||||
|     @property | ||||
|     def _cursor_args(self): | ||||
|         if not IS_PYMONGO_3: | ||||
|             fields_name = 'fields' | ||||
|             cursor_args = { | ||||
|                 'timeout': self._timeout, | ||||
|                 'snapshot': self._snapshot | ||||
|             } | ||||
|             if self._read_preference is not None: | ||||
|                 cursor_args['read_preference'] = self._read_preference | ||||
|             else: | ||||
|                 cursor_args['slave_okay'] = self._slave_okay | ||||
|         else: | ||||
|             fields_name = 'projection' | ||||
|             # snapshot is not handled at all by PyMongo 3+ | ||||
|             # TODO: evaluate similar possibilities using modifiers | ||||
|             if self._snapshot: | ||||
|                 msg = 'The snapshot option is not anymore available with PyMongo 3+' | ||||
|                 warnings.warn(msg, DeprecationWarning) | ||||
|             cursor_args = { | ||||
|                 'no_cursor_timeout': not self._timeout | ||||
|             } | ||||
|         fields_name = 'projection' | ||||
|         # snapshot is not handled at all by PyMongo 3+ | ||||
|         # TODO: evaluate similar possibilities using modifiers | ||||
|         if self._snapshot: | ||||
|             msg = 'The snapshot option is not anymore available with PyMongo 3+' | ||||
|             warnings.warn(msg, DeprecationWarning) | ||||
|         cursor_args = { | ||||
|             'no_cursor_timeout': not self._timeout | ||||
|         } | ||||
|  | ||||
|         if self._loaded_fields: | ||||
|             cursor_args[fields_name] = self._loaded_fields.as_dict() | ||||
|  | ||||
| @@ -1561,7 +1545,7 @@ class BaseQuerySet(object): | ||||
|         # XXX In PyMongo 3+, we define the read preference on a collection | ||||
|         # level, not a cursor level. Thus, we need to get a cloned collection | ||||
|         # object using `with_options` first. | ||||
|         if IS_PYMONGO_3 and self._read_preference is not None: | ||||
|         if self._read_preference is not None: | ||||
|             self._cursor_obj = self._collection\ | ||||
|                 .with_options(read_preference=self._read_preference)\ | ||||
|                 .find(self._query, **self._cursor_args) | ||||
| @@ -1731,13 +1715,13 @@ class BaseQuerySet(object): | ||||
|             } | ||||
|         """ | ||||
|         total, data, types = self.exec_js(freq_func, field) | ||||
|         values = {types.get(k): int(v) for k, v in data.iteritems()} | ||||
|         values = {types.get(k): int(v) for k, v in iteritems(data)} | ||||
|  | ||||
|         if normalize: | ||||
|             values = {k: float(v) / total for k, v in values.items()} | ||||
|  | ||||
|         frequencies = {} | ||||
|         for k, v in values.iteritems(): | ||||
|         for k, v in iteritems(values): | ||||
|             if isinstance(k, float): | ||||
|                 if int(k) == k: | ||||
|                     k = int(k) | ||||
|   | ||||
| @@ -4,12 +4,11 @@ from bson import ObjectId, SON | ||||
| from bson.dbref import DBRef | ||||
| import pymongo | ||||
| import six | ||||
| from six import iteritems | ||||
|  | ||||
| from mongoengine.base import UPDATE_OPERATORS | ||||
| from mongoengine.common import _import_class | ||||
| from mongoengine.connection import get_connection | ||||
| from mongoengine.errors import InvalidQueryError | ||||
| from mongoengine.python_support import IS_PYMONGO_3 | ||||
|  | ||||
| __all__ = ('query', 'update') | ||||
|  | ||||
| @@ -87,18 +86,10 @@ def query(_doc_cls=None, **kwargs): | ||||
|             singular_ops = [None, 'ne', 'gt', 'gte', 'lt', 'lte', 'not'] | ||||
|             singular_ops += STRING_OPERATORS | ||||
|             if op in singular_ops: | ||||
|                 if isinstance(field, six.string_types): | ||||
|                     if (op in STRING_OPERATORS and | ||||
|                             isinstance(value, six.string_types)): | ||||
|                         StringField = _import_class('StringField') | ||||
|                         value = StringField.prepare_query_value(op, value) | ||||
|                     else: | ||||
|                         value = field | ||||
|                 else: | ||||
|                     value = field.prepare_query_value(op, value) | ||||
|                 value = field.prepare_query_value(op, value) | ||||
|  | ||||
|                     if isinstance(field, CachedReferenceField) and value: | ||||
|                         value = value['_id'] | ||||
|                 if isinstance(field, CachedReferenceField) and value: | ||||
|                     value = value['_id'] | ||||
|  | ||||
|             elif op in ('in', 'nin', 'all', 'near') and not isinstance(value, dict): | ||||
|                 # Raise an error if the in/nin/all/near param is not iterable. | ||||
| @@ -154,7 +145,7 @@ def query(_doc_cls=None, **kwargs): | ||||
|                 if ('$maxDistance' in value_dict or '$minDistance' in value_dict) and \ | ||||
|                         ('$near' in value_dict or '$nearSphere' in value_dict): | ||||
|                     value_son = SON() | ||||
|                     for k, v in value_dict.iteritems(): | ||||
|                     for k, v in iteritems(value_dict): | ||||
|                         if k == '$maxDistance' or k == '$minDistance': | ||||
|                             continue | ||||
|                         value_son[k] = v | ||||
| @@ -162,16 +153,14 @@ def query(_doc_cls=None, **kwargs): | ||||
|                     # PyMongo 3+ and MongoDB < 2.6 | ||||
|                     near_embedded = False | ||||
|                     for near_op in ('$near', '$nearSphere'): | ||||
|                         if isinstance(value_dict.get(near_op), dict) and ( | ||||
|                                 IS_PYMONGO_3 or get_connection().max_wire_version > 1): | ||||
|                         if isinstance(value_dict.get(near_op), dict): | ||||
|                             value_son[near_op] = SON(value_son[near_op]) | ||||
|                             if '$maxDistance' in value_dict: | ||||
|                                 value_son[near_op][ | ||||
|                                     '$maxDistance'] = value_dict['$maxDistance'] | ||||
|                                 value_son[near_op]['$maxDistance'] = value_dict['$maxDistance'] | ||||
|                             if '$minDistance' in value_dict: | ||||
|                                 value_son[near_op][ | ||||
|                                     '$minDistance'] = value_dict['$minDistance'] | ||||
|                                 value_son[near_op]['$minDistance'] = value_dict['$minDistance'] | ||||
|                             near_embedded = True | ||||
|  | ||||
|                     if not near_embedded: | ||||
|                         if '$maxDistance' in value_dict: | ||||
|                             value_son['$maxDistance'] = value_dict['$maxDistance'] | ||||
| @@ -280,7 +269,7 @@ def update(_doc_cls=None, **update): | ||||
|  | ||||
|             if op == 'pull': | ||||
|                 if field.required or value is not None: | ||||
|                     if match == 'in' and not isinstance(value, dict): | ||||
|                     if match in ('in', 'nin') and not isinstance(value, dict): | ||||
|                         value = _prepare_query_for_iterable(field, op, value) | ||||
|                     else: | ||||
|                         value = field.prepare_query_value(op, value) | ||||
| @@ -307,10 +296,6 @@ def update(_doc_cls=None, **update): | ||||
|  | ||||
|         key = '.'.join(parts) | ||||
|  | ||||
|         if not op: | ||||
|             raise InvalidQueryError('Updates must supply an operation ' | ||||
|                                     'eg: set__FIELD=value') | ||||
|  | ||||
|         if 'pull' in op and '.' in key: | ||||
|             # Dot operators don't work on pull operations | ||||
|             # unless they point to a list field | ||||
|   | ||||
		Reference in New Issue
	
	Block a user