From 44b86e29c696cd59a36b869a7f9c586a0a8f0151 Mon Sep 17 00:00:00 2001 From: Stefan Wojcik Date: Thu, 8 Dec 2016 19:27:57 -0500 Subject: [PATCH] more cleanup --- mongoengine/base/document.py | 20 +++++++++++++------- mongoengine/base/fields.py | 4 +++- mongoengine/connection.py | 36 +++++++++++++++++++++--------------- mongoengine/document.py | 2 +- 4 files changed, 38 insertions(+), 24 deletions(-) diff --git a/mongoengine/base/document.py b/mongoengine/base/document.py index efa185a5..65f84005 100644 --- a/mongoengine/base/document.py +++ b/mongoengine/base/document.py @@ -728,9 +728,9 @@ class BaseDocument(object): """Generate and merge the full index specs.""" geo_indices = cls._geo_indices() unique_indices = cls._unique_with_indexes() - index_specs = [cls._build_index_spec(spec) - for spec in meta_indexes] + index_specs = [cls._build_index_spec(spec) for spec in meta_indexes] + # Merge geo indexes and unique_with indexes into the meta index specs def merge_index_specs(index_specs, indices): if not indices: return index_specs @@ -834,12 +834,11 @@ class BaseDocument(object): @classmethod def _unique_with_indexes(cls, namespace=""): - """ - Find and set unique indexes - """ + """Find unique indexes in the document schema and return them.""" unique_indexes = [] for field_name, field in cls._fields.items(): sparse = field.sparse + # Generate a list of indexes needed by uniqueness constraints if field.unique: unique_fields = [field.db_field] @@ -853,14 +852,17 @@ class BaseDocument(object): unique_with = [] for other_name in field.unique_with: parts = other_name.split('.') + # Lookup real name parts = cls._lookup_field(parts) name_parts = [part.db_field for part in parts] unique_with.append('.'.join(name_parts)) + # Unique field should be required parts[-1].required = True sparse = (not sparse and parts[-1].name not in cls.__dict__) + unique_fields += unique_with # Add the new index to the list @@ -896,10 +898,12 @@ class BaseDocument(object): for field in cls._fields.values(): if not isinstance(field, geo_field_types): continue + if hasattr(field, 'document_type'): field_cls = field.document_type if field_cls in inspected: continue + if hasattr(field_cls, '_geo_indices'): geo_indices += field_cls._geo_indices( inspected, parent_field=field.db_field) @@ -907,8 +911,10 @@ class BaseDocument(object): field_name = field.db_field if parent_field: field_name = "%s.%s" % (parent_field, field_name) - geo_indices.append({'fields': - [(field_name, field._geo_index)]}) + geo_indices.append({ + 'fields': [(field_name, field._geo_index)] + }) + return geo_indices @classmethod diff --git a/mongoengine/base/fields.py b/mongoengine/base/fields.py index 4ab22d87..81dfaa4f 100644 --- a/mongoengine/base/fields.py +++ b/mongoengine/base/fields.py @@ -196,7 +196,9 @@ class BaseField(object): if isinstance(value, (Document, EmbeddedDocument)): if not any(isinstance(value, c) for c in choice_list): self.error( - 'Value must be instance of %s' % six.text_type(choice_list) + 'Value must be an instance of %s' % ( + six.text_type(choice_list) + ) ) # Choices which are types other than Documents elif value not in choice_list: diff --git a/mongoengine/connection.py b/mongoengine/connection.py index 396ca61b..7a016a71 100644 --- a/mongoengine/connection.py +++ b/mongoengine/connection.py @@ -141,16 +141,17 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): # Validate that the requested alias exists in the _connection_settings. # Raise MongoEngineConnectionError if it doesn't. if alias not in _connection_settings: - msg = 'Connection with alias "%s" has not been defined' % alias if alias == DEFAULT_CONNECTION_NAME: msg = 'You have not defined a default connection' + else: + msg = 'Connection with alias "%s" has not been defined' % alias raise MongoEngineConnectionError(msg) def _clean_settings(settings_dict): - irrelevant_fields = ( + irrelevant_fields = set([ 'name', 'username', 'password', 'authentication_source', 'authentication_mechanism' - ) + ]) return dict( (k, v) for k, v in settings_dict.items() if k not in irrelevant_fields @@ -162,7 +163,7 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): conn_settings = _clean_settings(_connection_settings[alias].copy()) # Determine if we should use PyMongo's or mongomock's MongoClient. - is_mock = conn_settings.pop('is_mock', None) + is_mock = conn_settings.pop('is_mock', False) if is_mock: try: import mongomock @@ -173,17 +174,22 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): else: connection_class = MongoClient - # For replica set connections with PyMongo 2.x, use MongoReplicaSetClient - # TODO remove this block once we stop supporting PyMongo 2.x. - if 'replicaSet' in conn_settings: - # Discard port since it can't be used on MongoReplicaSetClient - conn_settings.pop('port', None) - # Discard replicaSet if it's not a string - if not isinstance(conn_settings['replicaSet'], six.string_types): - conn_settings.pop('replicaSet', None) - if not IS_PYMONGO_3: - connection_class = MongoReplicaSetClient - conn_settings['hosts_or_uri'] = conn_settings.pop('host', None) + # Handle replica set connections + if 'replicaSet' in conn_settings: + + # Discard port since it can't be used on MongoReplicaSetClient + conn_settings.pop('port', None) + + # Discard replicaSet if it's not a string + if not isinstance(conn_settings['replicaSet'], six.string_types): + del conn_settings['replicaSet'] + + # For replica set connections with PyMongo 2.x, use + # MongoReplicaSetClient. + # TODO remove this once we stop supporting PyMongo 2.x. + if not IS_PYMONGO_3: + connection_class = MongoReplicaSetClient + conn_settings['hosts_or_uri'] = conn_settings.pop('host', None) # Iterate over all of the connection settings and if a connection with # the same parameters is already established, use it instead of creating diff --git a/mongoengine/document.py b/mongoengine/document.py index 1793bc20..572100e9 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -219,7 +219,7 @@ class Document(BaseDocument): if self._data.get('id') is None: del data['_id'] else: - data["_id"] = self._data['id'] + data['_id'] = self._data['id'] return data