more cleanup

This commit is contained in:
Stefan Wojcik 2016-12-08 19:27:57 -05:00
parent f1f999a570
commit 44b86e29c6
4 changed files with 38 additions and 24 deletions

View File

@ -728,9 +728,9 @@ class BaseDocument(object):
"""Generate and merge the full index specs.""" """Generate and merge the full index specs."""
geo_indices = cls._geo_indices() geo_indices = cls._geo_indices()
unique_indices = cls._unique_with_indexes() unique_indices = cls._unique_with_indexes()
index_specs = [cls._build_index_spec(spec) index_specs = [cls._build_index_spec(spec) for spec in meta_indexes]
for spec in meta_indexes]
# Merge geo indexes and unique_with indexes into the meta index specs
def merge_index_specs(index_specs, indices): def merge_index_specs(index_specs, indices):
if not indices: if not indices:
return index_specs return index_specs
@ -834,12 +834,11 @@ class BaseDocument(object):
@classmethod @classmethod
def _unique_with_indexes(cls, namespace=""): def _unique_with_indexes(cls, namespace=""):
""" """Find unique indexes in the document schema and return them."""
Find and set unique indexes
"""
unique_indexes = [] unique_indexes = []
for field_name, field in cls._fields.items(): for field_name, field in cls._fields.items():
sparse = field.sparse sparse = field.sparse
# Generate a list of indexes needed by uniqueness constraints # Generate a list of indexes needed by uniqueness constraints
if field.unique: if field.unique:
unique_fields = [field.db_field] unique_fields = [field.db_field]
@ -853,14 +852,17 @@ class BaseDocument(object):
unique_with = [] unique_with = []
for other_name in field.unique_with: for other_name in field.unique_with:
parts = other_name.split('.') parts = other_name.split('.')
# Lookup real name # Lookup real name
parts = cls._lookup_field(parts) parts = cls._lookup_field(parts)
name_parts = [part.db_field for part in parts] name_parts = [part.db_field for part in parts]
unique_with.append('.'.join(name_parts)) unique_with.append('.'.join(name_parts))
# Unique field should be required # Unique field should be required
parts[-1].required = True parts[-1].required = True
sparse = (not sparse and sparse = (not sparse and
parts[-1].name not in cls.__dict__) parts[-1].name not in cls.__dict__)
unique_fields += unique_with unique_fields += unique_with
# Add the new index to the list # Add the new index to the list
@ -896,10 +898,12 @@ class BaseDocument(object):
for field in cls._fields.values(): for field in cls._fields.values():
if not isinstance(field, geo_field_types): if not isinstance(field, geo_field_types):
continue continue
if hasattr(field, 'document_type'): if hasattr(field, 'document_type'):
field_cls = field.document_type field_cls = field.document_type
if field_cls in inspected: if field_cls in inspected:
continue continue
if hasattr(field_cls, '_geo_indices'): if hasattr(field_cls, '_geo_indices'):
geo_indices += field_cls._geo_indices( geo_indices += field_cls._geo_indices(
inspected, parent_field=field.db_field) inspected, parent_field=field.db_field)
@ -907,8 +911,10 @@ class BaseDocument(object):
field_name = field.db_field field_name = field.db_field
if parent_field: if parent_field:
field_name = "%s.%s" % (parent_field, field_name) field_name = "%s.%s" % (parent_field, field_name)
geo_indices.append({'fields': geo_indices.append({
[(field_name, field._geo_index)]}) 'fields': [(field_name, field._geo_index)]
})
return geo_indices return geo_indices
@classmethod @classmethod

View File

@ -196,7 +196,9 @@ class BaseField(object):
if isinstance(value, (Document, EmbeddedDocument)): if isinstance(value, (Document, EmbeddedDocument)):
if not any(isinstance(value, c) for c in choice_list): if not any(isinstance(value, c) for c in choice_list):
self.error( self.error(
'Value must be instance of %s' % six.text_type(choice_list) 'Value must be an instance of %s' % (
six.text_type(choice_list)
)
) )
# Choices which are types other than Documents # Choices which are types other than Documents
elif value not in choice_list: elif value not in choice_list:

View File

@ -141,16 +141,17 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
# Validate that the requested alias exists in the _connection_settings. # Validate that the requested alias exists in the _connection_settings.
# Raise MongoEngineConnectionError if it doesn't. # Raise MongoEngineConnectionError if it doesn't.
if alias not in _connection_settings: if alias not in _connection_settings:
msg = 'Connection with alias "%s" has not been defined' % alias
if alias == DEFAULT_CONNECTION_NAME: if alias == DEFAULT_CONNECTION_NAME:
msg = 'You have not defined a default connection' msg = 'You have not defined a default connection'
else:
msg = 'Connection with alias "%s" has not been defined' % alias
raise MongoEngineConnectionError(msg) raise MongoEngineConnectionError(msg)
def _clean_settings(settings_dict): def _clean_settings(settings_dict):
irrelevant_fields = ( irrelevant_fields = set([
'name', 'username', 'password', 'authentication_source', 'name', 'username', 'password', 'authentication_source',
'authentication_mechanism' 'authentication_mechanism'
) ])
return dict( return dict(
(k, v) for k, v in settings_dict.items() (k, v) for k, v in settings_dict.items()
if k not in irrelevant_fields if k not in irrelevant_fields
@ -162,7 +163,7 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
conn_settings = _clean_settings(_connection_settings[alias].copy()) conn_settings = _clean_settings(_connection_settings[alias].copy())
# Determine if we should use PyMongo's or mongomock's MongoClient. # Determine if we should use PyMongo's or mongomock's MongoClient.
is_mock = conn_settings.pop('is_mock', None) is_mock = conn_settings.pop('is_mock', False)
if is_mock: if is_mock:
try: try:
import mongomock import mongomock
@ -173,17 +174,22 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
else: else:
connection_class = MongoClient connection_class = MongoClient
# For replica set connections with PyMongo 2.x, use MongoReplicaSetClient # Handle replica set connections
# TODO remove this block once we stop supporting PyMongo 2.x. if 'replicaSet' in conn_settings:
if 'replicaSet' in conn_settings:
# Discard port since it can't be used on MongoReplicaSetClient # Discard port since it can't be used on MongoReplicaSetClient
conn_settings.pop('port', None) conn_settings.pop('port', None)
# Discard replicaSet if it's not a string
if not isinstance(conn_settings['replicaSet'], six.string_types): # Discard replicaSet if it's not a string
conn_settings.pop('replicaSet', None) if not isinstance(conn_settings['replicaSet'], six.string_types):
if not IS_PYMONGO_3: del conn_settings['replicaSet']
connection_class = MongoReplicaSetClient
conn_settings['hosts_or_uri'] = conn_settings.pop('host', None) # For replica set connections with PyMongo 2.x, use
# MongoReplicaSetClient.
# TODO remove this once we stop supporting PyMongo 2.x.
if not IS_PYMONGO_3:
connection_class = MongoReplicaSetClient
conn_settings['hosts_or_uri'] = conn_settings.pop('host', None)
# Iterate over all of the connection settings and if a connection with # Iterate over all of the connection settings and if a connection with
# the same parameters is already established, use it instead of creating # the same parameters is already established, use it instead of creating

View File

@ -219,7 +219,7 @@ class Document(BaseDocument):
if self._data.get('id') is None: if self._data.get('id') is None:
del data['_id'] del data['_id']
else: else:
data["_id"] = self._data['id'] data['_id'] = self._data['id']
return data return data