Prefer ' over " + minor docstring tweaks

This commit is contained in:
Stefan Wojcik 2016-12-08 22:44:02 -05:00
parent 76219901db
commit 5b7b65a750
13 changed files with 387 additions and 417 deletions

View File

@ -10,7 +10,7 @@ __all__ = ('BaseDict', 'BaseList', 'EmbeddedDocumentList')
class BaseDict(dict): class BaseDict(dict):
"""A special dict so we can watch any changes""" """A special dict so we can watch any changes."""
_dereferenced = False _dereferenced = False
_instance = None _instance = None
@ -95,8 +95,7 @@ class BaseDict(dict):
class BaseList(list): class BaseList(list):
"""A special list so we can watch any changes """A special list so we can watch any changes."""
"""
_dereferenced = False _dereferenced = False
_instance = None _instance = None
@ -213,7 +212,7 @@ class EmbeddedDocumentList(BaseList):
@classmethod @classmethod
def __match_all(cls, embedded_doc, kwargs): def __match_all(cls, embedded_doc, kwargs):
"""Return True if a given embedded doc matches all the filter """Return True if a given embedded doc matches all the filter
kwargs. If it doesn't return False kwargs. If it doesn't return False.
""" """
for key, expected_value in kwargs.items(): for key, expected_value in kwargs.items():
doc_val = getattr(embedded_doc, key) doc_val = getattr(embedded_doc, key)
@ -292,18 +291,18 @@ class EmbeddedDocumentList(BaseList):
values = self.__only_matches(self, kwargs) values = self.__only_matches(self, kwargs)
if len(values) == 0: if len(values) == 0:
raise DoesNotExist( raise DoesNotExist(
"%s matching query does not exist." % self._name '%s matching query does not exist.' % self._name
) )
elif len(values) > 1: elif len(values) > 1:
raise MultipleObjectsReturned( raise MultipleObjectsReturned(
"%d items returned, instead of 1" % len(values) '%d items returned, instead of 1' % len(values)
) )
return values[0] return values[0]
def first(self): def first(self):
""" """Return the first embedded document in the list, or ``None``
Returns the first embedded document in the list, or ``None`` if empty. if empty.
""" """
if len(self) > 0: if len(self) > 0:
return self[0] return self[0]
@ -445,7 +444,7 @@ class StrictDict(object):
__slots__ = allowed_keys_tuple __slots__ = allowed_keys_tuple
def __repr__(self): def __repr__(self):
return "{%s}" % ', '.join('"{0!s}": {1!r}'.format(k, v) for k, v in self.items()) return '{%s}' % ', '.join('"{0!s}": {1!r}'.format(k, v) for k, v in self.items())
cls._classes[allowed_keys] = SpecificStrictDict cls._classes[allowed_keys] = SpecificStrictDict
return cls._classes[allowed_keys] return cls._classes[allowed_keys]

View File

@ -54,15 +54,15 @@ class BaseDocument(object):
name = next(field) name = next(field)
if name in values: if name in values:
raise TypeError( raise TypeError(
"Multiple values for keyword argument '" + name + "'") 'Multiple values for keyword argument "%s"' % name)
values[name] = value values[name] = value
__auto_convert = values.pop("__auto_convert", True) __auto_convert = values.pop('__auto_convert', True)
# 399: set default values only to fields loaded from DB # 399: set default values only to fields loaded from DB
__only_fields = set(values.pop("__only_fields", values)) __only_fields = set(values.pop('__only_fields', values))
_created = values.pop("_created", True) _created = values.pop('_created', True)
signals.pre_init.send(self.__class__, document=self, values=values) signals.pre_init.send(self.__class__, document=self, values=values)
@ -73,7 +73,7 @@ class BaseDocument(object):
self._fields.keys() + ['id', 'pk', '_cls', '_text_score']) self._fields.keys() + ['id', 'pk', '_cls', '_text_score'])
if _undefined_fields: if _undefined_fields:
msg = ( msg = (
"The fields '{0}' do not exist on the document '{1}'" 'The fields "{0}" do not exist on the document "{1}"'
).format(_undefined_fields, self._class_name) ).format(_undefined_fields, self._class_name)
raise FieldDoesNotExist(msg) raise FieldDoesNotExist(msg)
@ -92,7 +92,7 @@ class BaseDocument(object):
value = getattr(self, key, None) value = getattr(self, key, None)
setattr(self, key, value) setattr(self, key, value)
if "_cls" not in values: if '_cls' not in values:
self._cls = self._class_name self._cls = self._class_name
# Set passed values after initialisation # Set passed values after initialisation
@ -147,7 +147,7 @@ class BaseDocument(object):
if self._dynamic and not self._dynamic_lock: if self._dynamic and not self._dynamic_lock:
if not hasattr(self, name) and not name.startswith('_'): if not hasattr(self, name) and not name.startswith('_'):
DynamicField = _import_class("DynamicField") DynamicField = _import_class('DynamicField')
field = DynamicField(db_field=name) field = DynamicField(db_field=name)
field.name = name field.name = name
self._dynamic_fields[name] = field self._dynamic_fields[name] = field
@ -172,7 +172,7 @@ class BaseDocument(object):
name in self._meta.get('shard_key', tuple()) and name in self._meta.get('shard_key', tuple()) and
self._data.get(name) != value self._data.get(name) != value
): ):
msg = "Shard Keys are immutable. Tried to update %s" % name msg = 'Shard Keys are immutable. Tried to update %s' % name
raise OperationError(msg) raise OperationError(msg)
try: try:
@ -196,8 +196,8 @@ class BaseDocument(object):
return data return data
def __setstate__(self, data): def __setstate__(self, data):
if isinstance(data["_data"], SON): if isinstance(data['_data'], SON):
data["_data"] = self.__class__._from_son(data["_data"])._data data['_data'] = self.__class__._from_son(data['_data'])._data
for k in ('_changed_fields', '_initialised', '_created', '_data', for k in ('_changed_fields', '_initialised', '_created', '_data',
'_dynamic_fields'): '_dynamic_fields'):
if k in data: if k in data:
@ -211,7 +211,7 @@ class BaseDocument(object):
dynamic_fields = data.get('_dynamic_fields') or SON() dynamic_fields = data.get('_dynamic_fields') or SON()
for k in dynamic_fields.keys(): for k in dynamic_fields.keys():
setattr(self, k, data["_data"].get(k)) setattr(self, k, data['_data'].get(k))
def __iter__(self): def __iter__(self):
return iter(self._fields_ordered) return iter(self._fields_ordered)
@ -373,9 +373,9 @@ class BaseDocument(object):
fields = [(self._fields.get(name, self._dynamic_fields.get(name)), fields = [(self._fields.get(name, self._dynamic_fields.get(name)),
self._data.get(name)) for name in self._fields_ordered] self._data.get(name)) for name in self._fields_ordered]
EmbeddedDocumentField = _import_class("EmbeddedDocumentField") EmbeddedDocumentField = _import_class('EmbeddedDocumentField')
GenericEmbeddedDocumentField = _import_class( GenericEmbeddedDocumentField = _import_class(
"GenericEmbeddedDocumentField") 'GenericEmbeddedDocumentField')
for field, value in fields: for field, value in fields:
if value is not None: if value is not None:
@ -394,12 +394,12 @@ class BaseDocument(object):
field_name=field.name) field_name=field.name)
if errors: if errors:
pk = "None" pk = 'None'
if hasattr(self, 'pk'): if hasattr(self, 'pk'):
pk = self.pk pk = self.pk
elif self._instance and hasattr(self._instance, 'pk'): elif self._instance and hasattr(self._instance, 'pk'):
pk = self._instance.pk pk = self._instance.pk
message = "ValidationError (%s:%s) " % (self._class_name, pk) message = 'ValidationError (%s:%s) ' % (self._class_name, pk)
raise ValidationError(message, errors=errors) raise ValidationError(message, errors=errors)
def to_json(self, *args, **kwargs): def to_json(self, *args, **kwargs):
@ -455,8 +455,7 @@ class BaseDocument(object):
return value return value
def _mark_as_changed(self, key): def _mark_as_changed(self, key):
"""Marks a key as explicitly changed by the user """Mark a key as explicitly changed by the user."""
"""
if not key: if not key:
return return
@ -489,7 +488,7 @@ class BaseDocument(object):
"""Using get_changed_fields iterate and remove any fields that are """Using get_changed_fields iterate and remove any fields that are
marked as changed""" marked as changed"""
for changed in self._get_changed_fields(): for changed in self._get_changed_fields():
parts = changed.split(".") parts = changed.split('.')
data = self data = self
for part in parts: for part in parts:
if isinstance(data, list): if isinstance(data, list):
@ -501,8 +500,8 @@ class BaseDocument(object):
data = data.get(part, None) data = data.get(part, None)
else: else:
data = getattr(data, part, None) data = getattr(data, part, None)
if hasattr(data, "_changed_fields"): if hasattr(data, '_changed_fields'):
if hasattr(data, "_is_document") and data._is_document: if hasattr(data, '_is_document') and data._is_document:
continue continue
data._changed_fields = [] data._changed_fields = []
self._changed_fields = [] self._changed_fields = []
@ -516,26 +515,26 @@ class BaseDocument(object):
iterator = data.iteritems() iterator = data.iteritems()
for index, value in iterator: for index, value in iterator:
list_key = "%s%s." % (key, index) list_key = '%s%s.' % (key, index)
# don't check anything lower if this key is already marked # don't check anything lower if this key is already marked
# as changed. # as changed.
if list_key[:-1] in changed_fields: if list_key[:-1] in changed_fields:
continue continue
if hasattr(value, '_get_changed_fields'): if hasattr(value, '_get_changed_fields'):
changed = value._get_changed_fields(inspected) changed = value._get_changed_fields(inspected)
changed_fields += ["%s%s" % (list_key, k) changed_fields += ['%s%s' % (list_key, k)
for k in changed if k] for k in changed if k]
elif isinstance(value, (list, tuple, dict)): elif isinstance(value, (list, tuple, dict)):
self._nestable_types_changed_fields( self._nestable_types_changed_fields(
changed_fields, list_key, value, inspected) changed_fields, list_key, value, inspected)
def _get_changed_fields(self, inspected=None): def _get_changed_fields(self, inspected=None):
"""Returns a list of all fields that have explicitly been changed. """Return a list of all fields that have explicitly been changed.
""" """
EmbeddedDocument = _import_class("EmbeddedDocument") EmbeddedDocument = _import_class('EmbeddedDocument')
DynamicEmbeddedDocument = _import_class("DynamicEmbeddedDocument") DynamicEmbeddedDocument = _import_class('DynamicEmbeddedDocument')
ReferenceField = _import_class("ReferenceField") ReferenceField = _import_class('ReferenceField')
SortedListField = _import_class("SortedListField") SortedListField = _import_class('SortedListField')
changed_fields = [] changed_fields = []
changed_fields += getattr(self, '_changed_fields', []) changed_fields += getattr(self, '_changed_fields', [])
@ -563,7 +562,7 @@ class BaseDocument(object):
): ):
# Find all embedded fields that have been changed # Find all embedded fields that have been changed
changed = data._get_changed_fields(inspected) changed = data._get_changed_fields(inspected)
changed_fields += ["%s%s" % (key, k) for k in changed if k] changed_fields += ['%s%s' % (key, k) for k in changed if k]
elif (isinstance(data, (list, tuple, dict)) and elif (isinstance(data, (list, tuple, dict)) and
db_field_name not in changed_fields): db_field_name not in changed_fields):
if (hasattr(field, 'field') and if (hasattr(field, 'field') and
@ -667,13 +666,15 @@ class BaseDocument(object):
@classmethod @classmethod
def _get_collection_name(cls): def _get_collection_name(cls):
"""Returns the collection name for this class. None for abstract class """Return the collection name for this class. None for abstract
class.
""" """
return cls._meta.get('collection', None) return cls._meta.get('collection', None)
@classmethod @classmethod
def _from_son(cls, son, _auto_dereference=True, only_fields=None, created=False): def _from_son(cls, son, _auto_dereference=True, only_fields=None, created=False):
"""Create an instance of a Document (subclass) from a PyMongo SON. """Create an instance of a Document (subclass) from a PyMongo
SON.
""" """
if not only_fields: if not only_fields:
only_fields = [] only_fields = []
@ -681,7 +682,7 @@ class BaseDocument(object):
# get the class name from the document, falling back to the given # get the class name from the document, falling back to the given
# class if unavailable # class if unavailable
class_name = son.get('_cls', cls._class_name) class_name = son.get('_cls', cls._class_name)
data = dict(("%s" % key, value) for key, value in son.iteritems()) data = dict(('%s' % key, value) for key, value in son.iteritems())
# Return correct subclass for document type # Return correct subclass for document type
if class_name != cls._class_name: if class_name != cls._class_name:
@ -707,9 +708,9 @@ class BaseDocument(object):
errors_dict[field_name] = e errors_dict[field_name] = e
if errors_dict: if errors_dict:
errors = "\n".join(["%s - %s" % (k, v) errors = '\n'.join(['%s - %s' % (k, v)
for k, v in errors_dict.items()]) for k, v in errors_dict.items()])
msg = ("Invalid data to create a `%s` instance.\n%s" msg = ('Invalid data to create a `%s` instance.\n%s'
% (cls._class_name, errors)) % (cls._class_name, errors))
raise InvalidDocumentError(msg) raise InvalidDocumentError(msg)
@ -782,7 +783,7 @@ class BaseDocument(object):
# 733: don't include cls if index_cls is False unless there is an explicit cls with the index # 733: don't include cls if index_cls is False unless there is an explicit cls with the index
include_cls = include_cls and (spec.get('cls', False) or cls._meta.get('index_cls', True)) include_cls = include_cls and (spec.get('cls', False) or cls._meta.get('index_cls', True))
if "cls" in spec: if 'cls' in spec:
spec.pop('cls') spec.pop('cls')
for key in spec['fields']: for key in spec['fields']:
# If inherited spec continue # If inherited spec continue
@ -797,19 +798,19 @@ class BaseDocument(object):
# GEOHAYSTACK from ) # GEOHAYSTACK from )
# GEO2D from * # GEO2D from *
direction = pymongo.ASCENDING direction = pymongo.ASCENDING
if key.startswith("-"): if key.startswith('-'):
direction = pymongo.DESCENDING direction = pymongo.DESCENDING
elif key.startswith("$"): elif key.startswith('$'):
direction = pymongo.TEXT direction = pymongo.TEXT
elif key.startswith("#"): elif key.startswith('#'):
direction = pymongo.HASHED direction = pymongo.HASHED
elif key.startswith("("): elif key.startswith('('):
direction = pymongo.GEOSPHERE direction = pymongo.GEOSPHERE
elif key.startswith(")"): elif key.startswith(')'):
direction = pymongo.GEOHAYSTACK direction = pymongo.GEOHAYSTACK
elif key.startswith("*"): elif key.startswith('*'):
direction = pymongo.GEO2D direction = pymongo.GEO2D
if key.startswith(("+", "-", "*", "$", "#", "(", ")")): if key.startswith(('+', '-', '*', '$', '#', '(', ')')):
key = key[1:] key = key[1:]
# Use real field name, do it manually because we need field # Use real field name, do it manually because we need field
@ -822,7 +823,7 @@ class BaseDocument(object):
parts = [] parts = []
for field in fields: for field in fields:
try: try:
if field != "_id": if field != '_id':
field = field.db_field field = field.db_field
except AttributeError: except AttributeError:
pass pass
@ -841,7 +842,7 @@ class BaseDocument(object):
return spec return spec
@classmethod @classmethod
def _unique_with_indexes(cls, namespace=""): def _unique_with_indexes(cls, namespace=''):
"""Find unique indexes in the document schema and return them.""" """Find unique indexes in the document schema and return them."""
unique_indexes = [] unique_indexes = []
for field_name, field in cls._fields.items(): for field_name, field in cls._fields.items():
@ -875,7 +876,7 @@ class BaseDocument(object):
# Add the new index to the list # Add the new index to the list
fields = [ fields = [
("%s%s" % (namespace, f), pymongo.ASCENDING) ('%s%s' % (namespace, f), pymongo.ASCENDING)
for f in unique_fields for f in unique_fields
] ]
index = {'fields': fields, 'unique': True, 'sparse': sparse} index = {'fields': fields, 'unique': True, 'sparse': sparse}
@ -887,7 +888,7 @@ class BaseDocument(object):
# Grab any embedded document field unique indexes # Grab any embedded document field unique indexes
if (field.__class__.__name__ == 'EmbeddedDocumentField' and if (field.__class__.__name__ == 'EmbeddedDocumentField' and
field.document_type != cls): field.document_type != cls):
field_namespace = "%s." % field_name field_namespace = '%s.' % field_name
doc_cls = field.document_type doc_cls = field.document_type
unique_indexes += doc_cls._unique_with_indexes(field_namespace) unique_indexes += doc_cls._unique_with_indexes(field_namespace)
@ -921,7 +922,7 @@ class BaseDocument(object):
elif field._geo_index: elif field._geo_index:
field_name = field.db_field field_name = field.db_field
if parent_field: if parent_field:
field_name = "%s.%s" % (parent_field, field_name) field_name = '%s.%s' % (parent_field, field_name)
geo_indices.append({ geo_indices.append({
'fields': [(field_name, field._geo_index)] 'fields': [(field_name, field._geo_index)]
}) })
@ -965,7 +966,7 @@ class BaseDocument(object):
# TODO this method is WAY too complicated. Simplify it. # TODO this method is WAY too complicated. Simplify it.
# TODO don't think returning a string for embedded non-existent fields is desired # TODO don't think returning a string for embedded non-existent fields is desired
ListField = _import_class("ListField") ListField = _import_class('ListField')
DynamicField = _import_class('DynamicField') DynamicField = _import_class('DynamicField')
if not isinstance(parts, (list, tuple)): if not isinstance(parts, (list, tuple)):

View File

@ -69,7 +69,7 @@ class BaseField(object):
self.db_field = (db_field or name) if not primary_key else '_id' self.db_field = (db_field or name) if not primary_key else '_id'
if name: if name:
msg = "Fields' 'name' attribute deprecated in favour of 'db_field'" msg = 'Field\'s "name" attribute deprecated in favour of "db_field"'
warnings.warn(msg, DeprecationWarning) warnings.warn(msg, DeprecationWarning)
self.required = required or primary_key self.required = required or primary_key
self.default = default self.default = default
@ -85,7 +85,7 @@ class BaseField(object):
# Detect and report conflicts between metadata and base properties. # Detect and report conflicts between metadata and base properties.
conflicts = set(dir(self)) & set(kwargs) conflicts = set(dir(self)) & set(kwargs)
if conflicts: if conflicts:
raise TypeError("%s already has attribute(s): %s" % ( raise TypeError('%s already has attribute(s): %s' % (
self.__class__.__name__, ', '.join(conflicts))) self.__class__.__name__, ', '.join(conflicts)))
# Assign metadata to the instance # Assign metadata to the instance
@ -143,25 +143,21 @@ class BaseField(object):
v._instance = weakref.proxy(instance) v._instance = weakref.proxy(instance)
instance._data[self.name] = value instance._data[self.name] = value
def error(self, message="", errors=None, field_name=None): def error(self, message='', errors=None, field_name=None):
"""Raises a ValidationError. """Raise a ValidationError."""
"""
field_name = field_name if field_name else self.name field_name = field_name if field_name else self.name
raise ValidationError(message, errors=errors, field_name=field_name) raise ValidationError(message, errors=errors, field_name=field_name)
def to_python(self, value): def to_python(self, value):
"""Convert a MongoDB-compatible type to a Python type. """Convert a MongoDB-compatible type to a Python type."""
"""
return value return value
def to_mongo(self, value): def to_mongo(self, value):
"""Convert a Python type to a MongoDB-compatible type. """Convert a Python type to a MongoDB-compatible type."""
"""
return self.to_python(value) return self.to_python(value)
def _to_mongo_safe_call(self, value, use_db_field=True, fields=None): def _to_mongo_safe_call(self, value, use_db_field=True, fields=None):
"""A helper method to call to_mongo with proper inputs """Helper method to call to_mongo with proper inputs."""
"""
f_inputs = self.to_mongo.__code__.co_varnames f_inputs = self.to_mongo.__code__.co_varnames
ex_vars = {} ex_vars = {}
if 'fields' in f_inputs: if 'fields' in f_inputs:
@ -173,15 +169,13 @@ class BaseField(object):
return self.to_mongo(value, **ex_vars) return self.to_mongo(value, **ex_vars)
def prepare_query_value(self, op, value): def prepare_query_value(self, op, value):
"""Prepare a value that is being used in a query for PyMongo. """Prepare a value that is being used in a query for PyMongo."""
"""
if op in UPDATE_OPERATORS: if op in UPDATE_OPERATORS:
self.validate(value) self.validate(value)
return value return value
def validate(self, value, clean=True): def validate(self, value, clean=True):
"""Perform validation on a value. """Perform validation on a value."""
"""
pass pass
def _validate_choices(self, value): def _validate_choices(self, value):
@ -245,8 +239,7 @@ class ComplexBaseField(BaseField):
field = None field = None
def __get__(self, instance, owner): def __get__(self, instance, owner):
"""Descriptor to automatically dereference references. """Descriptor to automatically dereference references."""
"""
if instance is None: if instance is None:
# Document class being used rather than a document object # Document class being used rather than a document object
return self return self
@ -258,7 +251,7 @@ class ComplexBaseField(BaseField):
(self.field is None or isinstance(self.field, (self.field is None or isinstance(self.field,
(GenericReferenceField, ReferenceField)))) (GenericReferenceField, ReferenceField))))
_dereference = _import_class("DeReference")() _dereference = _import_class('DeReference')()
self._auto_dereference = instance._fields[self.name]._auto_dereference self._auto_dereference = instance._fields[self.name]._auto_dereference
if instance._initialised and dereference and instance._data.get(self.name): if instance._initialised and dereference and instance._data.get(self.name):
@ -293,8 +286,7 @@ class ComplexBaseField(BaseField):
return value return value
def to_python(self, value): def to_python(self, value):
"""Convert a MongoDB-compatible type to a Python type. """Convert a MongoDB-compatible type to a Python type."""
"""
if isinstance(value, six.string_types): if isinstance(value, six.string_types):
return value return value
@ -335,11 +327,10 @@ class ComplexBaseField(BaseField):
return value_dict return value_dict
def to_mongo(self, value, use_db_field=True, fields=None): def to_mongo(self, value, use_db_field=True, fields=None):
"""Convert a Python type to a MongoDB-compatible type. """Convert a Python type to a MongoDB-compatible type."""
""" Document = _import_class('Document')
Document = _import_class("Document") EmbeddedDocument = _import_class('EmbeddedDocument')
EmbeddedDocument = _import_class("EmbeddedDocument") GenericReferenceField = _import_class('GenericReferenceField')
GenericReferenceField = _import_class("GenericReferenceField")
if isinstance(value, six.string_types): if isinstance(value, six.string_types):
return value return value
@ -400,8 +391,7 @@ class ComplexBaseField(BaseField):
return value_dict return value_dict
def validate(self, value): def validate(self, value):
"""If field is provided ensure the value is valid. """If field is provided ensure the value is valid."""
"""
errors = {} errors = {}
if self.field: if self.field:
if hasattr(value, 'iteritems') or hasattr(value, 'items'): if hasattr(value, 'iteritems') or hasattr(value, 'items'):
@ -439,8 +429,7 @@ class ComplexBaseField(BaseField):
class ObjectIdField(BaseField): class ObjectIdField(BaseField):
"""A field wrapper around MongoDB's ObjectIds. """A field wrapper around MongoDB's ObjectIds."""
"""
def to_python(self, value): def to_python(self, value):
try: try:
@ -476,21 +465,20 @@ class GeoJsonBaseField(BaseField):
""" """
_geo_index = pymongo.GEOSPHERE _geo_index = pymongo.GEOSPHERE
_type = "GeoBase" _type = 'GeoBase'
def __init__(self, auto_index=True, *args, **kwargs): def __init__(self, auto_index=True, *args, **kwargs):
""" """
:param bool auto_index: Automatically create a "2dsphere" index.\ :param bool auto_index: Automatically create a '2dsphere' index.\
Defaults to `True`. Defaults to `True`.
""" """
self._name = "%sField" % self._type self._name = '%sField' % self._type
if not auto_index: if not auto_index:
self._geo_index = False self._geo_index = False
super(GeoJsonBaseField, self).__init__(*args, **kwargs) super(GeoJsonBaseField, self).__init__(*args, **kwargs)
def validate(self, value): def validate(self, value):
"""Validate the GeoJson object based on its type """Validate the GeoJson object based on its type."""
"""
if isinstance(value, dict): if isinstance(value, dict):
if set(value.keys()) == set(['type', 'coordinates']): if set(value.keys()) == set(['type', 'coordinates']):
if value['type'] != self._type: if value['type'] != self._type:
@ -505,7 +493,7 @@ class GeoJsonBaseField(BaseField):
self.error('%s can only accept lists of [x, y]' % self._name) self.error('%s can only accept lists of [x, y]' % self._name)
return return
validate = getattr(self, "_validate_%s" % self._type.lower()) validate = getattr(self, '_validate_%s' % self._type.lower())
error = validate(value) error = validate(value)
if error: if error:
self.error(error) self.error(error)
@ -518,7 +506,7 @@ class GeoJsonBaseField(BaseField):
try: try:
value[0][0][0] value[0][0][0]
except (TypeError, IndexError): except (TypeError, IndexError):
return "Invalid Polygon must contain at least one valid linestring" return 'Invalid Polygon must contain at least one valid linestring'
errors = [] errors = []
for val in value: for val in value:
@ -529,12 +517,12 @@ class GeoJsonBaseField(BaseField):
errors.append(error) errors.append(error)
if errors: if errors:
if top_level: if top_level:
return "Invalid Polygon:\n%s" % ", ".join(errors) return 'Invalid Polygon:\n%s' % ', '.join(errors)
else: else:
return "%s" % ", ".join(errors) return '%s' % ', '.join(errors)
def _validate_linestring(self, value, top_level=True): def _validate_linestring(self, value, top_level=True):
"""Validates a linestring""" """Validate a linestring."""
if not isinstance(value, (list, tuple)): if not isinstance(value, (list, tuple)):
return 'LineStrings must contain list of coordinate pairs' return 'LineStrings must contain list of coordinate pairs'
@ -542,7 +530,7 @@ class GeoJsonBaseField(BaseField):
try: try:
value[0][0] value[0][0]
except (TypeError, IndexError): except (TypeError, IndexError):
return "Invalid LineString must contain at least one valid point" return 'Invalid LineString must contain at least one valid point'
errors = [] errors = []
for val in value: for val in value:
@ -551,19 +539,19 @@ class GeoJsonBaseField(BaseField):
errors.append(error) errors.append(error)
if errors: if errors:
if top_level: if top_level:
return "Invalid LineString:\n%s" % ", ".join(errors) return 'Invalid LineString:\n%s' % ', '.join(errors)
else: else:
return "%s" % ", ".join(errors) return '%s' % ', '.join(errors)
def _validate_point(self, value): def _validate_point(self, value):
"""Validate each set of coords""" """Validate each set of coords"""
if not isinstance(value, (list, tuple)): if not isinstance(value, (list, tuple)):
return 'Points must be a list of coordinate pairs' return 'Points must be a list of coordinate pairs'
elif not len(value) == 2: elif not len(value) == 2:
return "Value (%s) must be a two-dimensional point" % repr(value) return 'Value (%s) must be a two-dimensional point' % repr(value)
elif (not isinstance(value[0], (float, int)) or elif (not isinstance(value[0], (float, int)) or
not isinstance(value[1], (float, int))): not isinstance(value[1], (float, int))):
return "Both values (%s) in point must be float or int" % repr(value) return 'Both values (%s) in point must be float or int' % repr(value)
def _validate_multipoint(self, value): def _validate_multipoint(self, value):
if not isinstance(value, (list, tuple)): if not isinstance(value, (list, tuple)):
@ -573,7 +561,7 @@ class GeoJsonBaseField(BaseField):
try: try:
value[0][0] value[0][0]
except (TypeError, IndexError): except (TypeError, IndexError):
return "Invalid MultiPoint must contain at least one valid point" return 'Invalid MultiPoint must contain at least one valid point'
errors = [] errors = []
for point in value: for point in value:
@ -582,7 +570,7 @@ class GeoJsonBaseField(BaseField):
errors.append(error) errors.append(error)
if errors: if errors:
return "%s" % ", ".join(errors) return '%s' % ', '.join(errors)
def _validate_multilinestring(self, value, top_level=True): def _validate_multilinestring(self, value, top_level=True):
if not isinstance(value, (list, tuple)): if not isinstance(value, (list, tuple)):
@ -592,7 +580,7 @@ class GeoJsonBaseField(BaseField):
try: try:
value[0][0][0] value[0][0][0]
except (TypeError, IndexError): except (TypeError, IndexError):
return "Invalid MultiLineString must contain at least one valid linestring" return 'Invalid MultiLineString must contain at least one valid linestring'
errors = [] errors = []
for linestring in value: for linestring in value:
@ -602,9 +590,9 @@ class GeoJsonBaseField(BaseField):
if errors: if errors:
if top_level: if top_level:
return "Invalid MultiLineString:\n%s" % ", ".join(errors) return 'Invalid MultiLineString:\n%s' % ', '.join(errors)
else: else:
return "%s" % ", ".join(errors) return '%s' % ', '.join(errors)
def _validate_multipolygon(self, value): def _validate_multipolygon(self, value):
if not isinstance(value, (list, tuple)): if not isinstance(value, (list, tuple)):
@ -614,7 +602,7 @@ class GeoJsonBaseField(BaseField):
try: try:
value[0][0][0][0] value[0][0][0][0]
except (TypeError, IndexError): except (TypeError, IndexError):
return "Invalid MultiPolygon must contain at least one valid Polygon" return 'Invalid MultiPolygon must contain at least one valid Polygon'
errors = [] errors = []
for polygon in value: for polygon in value:
@ -623,9 +611,9 @@ class GeoJsonBaseField(BaseField):
errors.append(error) errors.append(error)
if errors: if errors:
return "Invalid MultiPolygon:\n%s" % ", ".join(errors) return 'Invalid MultiPolygon:\n%s' % ', '.join(errors)
def to_mongo(self, value): def to_mongo(self, value):
if isinstance(value, dict): if isinstance(value, dict):
return value return value
return SON([("type", self._type), ("coordinates", value)]) return SON([('type', self._type), ('coordinates', value)])

View File

@ -88,8 +88,8 @@ class DocumentMetaclass(type):
# Ensure no duplicate db_fields # Ensure no duplicate db_fields
duplicate_db_fields = [k for k, v in field_names.items() if v > 1] duplicate_db_fields = [k for k, v in field_names.items() if v > 1]
if duplicate_db_fields: if duplicate_db_fields:
msg = ("Multiple db_fields defined for: %s " % msg = ('Multiple db_fields defined for: %s ' %
", ".join(duplicate_db_fields)) ', '.join(duplicate_db_fields))
raise InvalidDocumentError(msg) raise InvalidDocumentError(msg)
# Set _fields and db_field maps # Set _fields and db_field maps
@ -178,11 +178,11 @@ class DocumentMetaclass(type):
if isinstance(f, CachedReferenceField): if isinstance(f, CachedReferenceField):
if issubclass(new_class, EmbeddedDocument): if issubclass(new_class, EmbeddedDocument):
raise InvalidDocumentError( raise InvalidDocumentError('CachedReferenceFields is not '
"CachedReferenceFields is not allowed in EmbeddedDocuments") 'allowed in EmbeddedDocuments')
if not f.document_type: if not f.document_type:
raise InvalidDocumentError( raise InvalidDocumentError(
"Document is not available to sync") 'Document is not available to sync')
if f.auto_sync: if f.auto_sync:
f.start_listener() f.start_listener()
@ -194,8 +194,8 @@ class DocumentMetaclass(type):
'reverse_delete_rule', 'reverse_delete_rule',
DO_NOTHING) DO_NOTHING)
if isinstance(f, DictField) and delete_rule != DO_NOTHING: if isinstance(f, DictField) and delete_rule != DO_NOTHING:
msg = ("Reverse delete rules are not supported " msg = ('Reverse delete rules are not supported '
"for %s (field: %s)" % 'for %s (field: %s)' %
(field.__class__.__name__, field.name)) (field.__class__.__name__, field.name))
raise InvalidDocumentError(msg) raise InvalidDocumentError(msg)
@ -203,16 +203,16 @@ class DocumentMetaclass(type):
if delete_rule != DO_NOTHING: if delete_rule != DO_NOTHING:
if issubclass(new_class, EmbeddedDocument): if issubclass(new_class, EmbeddedDocument):
msg = ("Reverse delete rules are not supported for " msg = ('Reverse delete rules are not supported for '
"EmbeddedDocuments (field: %s)" % field.name) 'EmbeddedDocuments (field: %s)' % field.name)
raise InvalidDocumentError(msg) raise InvalidDocumentError(msg)
f.document_type.register_delete_rule(new_class, f.document_type.register_delete_rule(new_class,
field.name, delete_rule) field.name, delete_rule)
if (field.name and hasattr(Document, field.name) and if (field.name and hasattr(Document, field.name) and
EmbeddedDocument not in new_class.mro()): EmbeddedDocument not in new_class.mro()):
msg = ("%s is a document method and not a valid " msg = ('%s is a document method and not a valid '
"field name" % field.name) 'field name' % field.name)
raise InvalidDocumentError(msg) raise InvalidDocumentError(msg)
return new_class return new_class
@ -302,7 +302,7 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
# If parent wasn't an abstract class # If parent wasn't an abstract class
if (parent_doc_cls and 'collection' in attrs.get('_meta', {}) and if (parent_doc_cls and 'collection' in attrs.get('_meta', {}) and
not parent_doc_cls._meta.get('abstract', True)): not parent_doc_cls._meta.get('abstract', True)):
msg = "Trying to set a collection on a subclass (%s)" % name msg = 'Trying to set a collection on a subclass (%s)' % name
warnings.warn(msg, SyntaxWarning) warnings.warn(msg, SyntaxWarning)
del attrs['_meta']['collection'] del attrs['_meta']['collection']
@ -310,7 +310,7 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
if attrs.get('_is_base_cls') or attrs['_meta'].get('abstract'): if attrs.get('_is_base_cls') or attrs['_meta'].get('abstract'):
if (parent_doc_cls and if (parent_doc_cls and
not parent_doc_cls._meta.get('abstract', False)): not parent_doc_cls._meta.get('abstract', False)):
msg = "Abstract document cannot have non-abstract base" msg = 'Abstract document cannot have non-abstract base'
raise ValueError(msg) raise ValueError(msg)
return super_new(cls, name, bases, attrs) return super_new(cls, name, bases, attrs)

View File

@ -2,12 +2,12 @@ from mongoengine.common import _import_class
from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db
__all__ = ("switch_db", "switch_collection", "no_dereference", __all__ = ('switch_db', 'switch_collection', 'no_dereference',
"no_sub_classes", "query_counter") 'no_sub_classes', 'query_counter')
class switch_db(object): class switch_db(object):
""" switch_db alias context manager. """switch_db alias context manager.
Example :: Example ::
@ -18,15 +18,14 @@ class switch_db(object):
class Group(Document): class Group(Document):
name = StringField() name = StringField()
Group(name="test").save() # Saves in the default db Group(name='test').save() # Saves in the default db
with switch_db(Group, 'testdb-1') as Group: with switch_db(Group, 'testdb-1') as Group:
Group(name="hello testdb!").save() # Saves in testdb-1 Group(name='hello testdb!').save() # Saves in testdb-1
""" """
def __init__(self, cls, db_alias): def __init__(self, cls, db_alias):
""" Construct the switch_db context manager """Construct the switch_db context manager
:param cls: the class to change the registered db :param cls: the class to change the registered db
:param db_alias: the name of the specific database to use :param db_alias: the name of the specific database to use
@ -34,37 +33,36 @@ class switch_db(object):
self.cls = cls self.cls = cls
self.collection = cls._get_collection() self.collection = cls._get_collection()
self.db_alias = db_alias self.db_alias = db_alias
self.ori_db_alias = cls._meta.get("db_alias", DEFAULT_CONNECTION_NAME) self.ori_db_alias = cls._meta.get('db_alias', DEFAULT_CONNECTION_NAME)
def __enter__(self): def __enter__(self):
""" change the db_alias and clear the cached collection """ """Change the db_alias and clear the cached collection."""
self.cls._meta["db_alias"] = self.db_alias self.cls._meta['db_alias'] = self.db_alias
self.cls._collection = None self.cls._collection = None
return self.cls return self.cls
def __exit__(self, t, value, traceback): def __exit__(self, t, value, traceback):
""" Reset the db_alias and collection """ """Reset the db_alias and collection."""
self.cls._meta["db_alias"] = self.ori_db_alias self.cls._meta['db_alias'] = self.ori_db_alias
self.cls._collection = self.collection self.cls._collection = self.collection
class switch_collection(object): class switch_collection(object):
""" switch_collection alias context manager. """switch_collection alias context manager.
Example :: Example ::
class Group(Document): class Group(Document):
name = StringField() name = StringField()
Group(name="test").save() # Saves in the default db Group(name='test').save() # Saves in the default db
with switch_collection(Group, 'group1') as Group: with switch_collection(Group, 'group1') as Group:
Group(name="hello testdb!").save() # Saves in group1 collection Group(name='hello testdb!').save() # Saves in group1 collection
""" """
def __init__(self, cls, collection_name): def __init__(self, cls, collection_name):
""" Construct the switch_collection context manager """Construct the switch_collection context manager.
:param cls: the class to change the registered db :param cls: the class to change the registered db
:param collection_name: the name of the collection to use :param collection_name: the name of the collection to use
@ -75,7 +73,7 @@ class switch_collection(object):
self.collection_name = collection_name self.collection_name = collection_name
def __enter__(self): def __enter__(self):
""" change the _get_collection_name and clear the cached collection """ """Change the _get_collection_name and clear the cached collection."""
@classmethod @classmethod
def _get_collection_name(cls): def _get_collection_name(cls):
@ -86,24 +84,23 @@ class switch_collection(object):
return self.cls return self.cls
def __exit__(self, t, value, traceback): def __exit__(self, t, value, traceback):
""" Reset the collection """ """Reset the collection."""
self.cls._collection = self.ori_collection self.cls._collection = self.ori_collection
self.cls._get_collection_name = self.ori_get_collection_name self.cls._get_collection_name = self.ori_get_collection_name
class no_dereference(object): class no_dereference(object):
""" no_dereference context manager. """no_dereference context manager.
Turns off all dereferencing in Documents for the duration of the context Turns off all dereferencing in Documents for the duration of the context
manager:: manager::
with no_dereference(Group) as Group: with no_dereference(Group) as Group:
Group.objects.find() Group.objects.find()
""" """
def __init__(self, cls): def __init__(self, cls):
""" Construct the no_dereference context manager. """Construct the no_dereference context manager.
:param cls: the class to turn dereferencing off on :param cls: the class to turn dereferencing off on
""" """
@ -119,103 +116,102 @@ class no_dereference(object):
ComplexBaseField))] ComplexBaseField))]
def __enter__(self): def __enter__(self):
""" change the objects default and _auto_dereference values""" """Change the objects default and _auto_dereference values."""
for field in self.deref_fields: for field in self.deref_fields:
self.cls._fields[field]._auto_dereference = False self.cls._fields[field]._auto_dereference = False
return self.cls return self.cls
def __exit__(self, t, value, traceback): def __exit__(self, t, value, traceback):
""" Reset the default and _auto_dereference values""" """Reset the default and _auto_dereference values."""
for field in self.deref_fields: for field in self.deref_fields:
self.cls._fields[field]._auto_dereference = True self.cls._fields[field]._auto_dereference = True
return self.cls return self.cls
class no_sub_classes(object): class no_sub_classes(object):
""" no_sub_classes context manager. """no_sub_classes context manager.
Only returns instances of this class and no sub (inherited) classes:: Only returns instances of this class and no sub (inherited) classes::
with no_sub_classes(Group) as Group: with no_sub_classes(Group) as Group:
Group.objects.find() Group.objects.find()
""" """
def __init__(self, cls): def __init__(self, cls):
""" Construct the no_sub_classes context manager. """Construct the no_sub_classes context manager.
:param cls: the class to turn querying sub classes on :param cls: the class to turn querying sub classes on
""" """
self.cls = cls self.cls = cls
def __enter__(self): def __enter__(self):
""" change the objects default and _auto_dereference values""" """Change the objects default and _auto_dereference values."""
self.cls._all_subclasses = self.cls._subclasses self.cls._all_subclasses = self.cls._subclasses
self.cls._subclasses = (self.cls,) self.cls._subclasses = (self.cls,)
return self.cls return self.cls
def __exit__(self, t, value, traceback): def __exit__(self, t, value, traceback):
""" Reset the default and _auto_dereference values""" """Reset the default and _auto_dereference values."""
self.cls._subclasses = self.cls._all_subclasses self.cls._subclasses = self.cls._all_subclasses
delattr(self.cls, '_all_subclasses') delattr(self.cls, '_all_subclasses')
return self.cls return self.cls
class query_counter(object): class query_counter(object):
""" Query_counter context manager to get the number of queries. """ """Query_counter context manager to get the number of queries."""
def __init__(self): def __init__(self):
""" Construct the query_counter. """ """Construct the query_counter."""
self.counter = 0 self.counter = 0
self.db = get_db() self.db = get_db()
def __enter__(self): def __enter__(self):
""" On every with block we need to drop the profile collection. """ """On every with block we need to drop the profile collection."""
self.db.set_profiling_level(0) self.db.set_profiling_level(0)
self.db.system.profile.drop() self.db.system.profile.drop()
self.db.set_profiling_level(2) self.db.set_profiling_level(2)
return self return self
def __exit__(self, t, value, traceback): def __exit__(self, t, value, traceback):
""" Reset the profiling level. """ """Reset the profiling level."""
self.db.set_profiling_level(0) self.db.set_profiling_level(0)
def __eq__(self, value): def __eq__(self, value):
""" == Compare querycounter. """ """== Compare querycounter."""
counter = self._get_count() counter = self._get_count()
return value == counter return value == counter
def __ne__(self, value): def __ne__(self, value):
""" != Compare querycounter. """ """!= Compare querycounter."""
return not self.__eq__(value) return not self.__eq__(value)
def __lt__(self, value): def __lt__(self, value):
""" < Compare querycounter. """ """< Compare querycounter."""
return self._get_count() < value return self._get_count() < value
def __le__(self, value): def __le__(self, value):
""" <= Compare querycounter. """ """<= Compare querycounter."""
return self._get_count() <= value return self._get_count() <= value
def __gt__(self, value): def __gt__(self, value):
""" > Compare querycounter. """ """> Compare querycounter."""
return self._get_count() > value return self._get_count() > value
def __ge__(self, value): def __ge__(self, value):
""" >= Compare querycounter. """ """>= Compare querycounter."""
return self._get_count() >= value return self._get_count() >= value
def __int__(self): def __int__(self):
""" int representation. """ """int representation."""
return self._get_count() return self._get_count()
def __repr__(self): def __repr__(self):
""" repr query_counter as the number of queries. """ """repr query_counter as the number of queries."""
return u"%s" % self._get_count() return u"%s" % self._get_count()
def _get_count(self): def _get_count(self):
""" Get the number of queries. """ """Get the number of queries."""
ignore_query = {"ns": {"$ne": "%s.system.indexes" % self.db.name}} ignore_query = {'ns': {'$ne': '%s.system.indexes' % self.db.name}}
count = self.db.system.profile.find(ignore_query).count() - self.counter count = self.db.system.profile.find(ignore_query).count() - self.counter
self.counter += 1 self.counter += 1
return count return count

View File

@ -149,7 +149,7 @@ class DeReference(object):
references = get_db()[collection].find({'_id': {'$in': refs}}) references = get_db()[collection].find({'_id': {'$in': refs}})
for ref in references: for ref in references:
if '_cls' in ref: if '_cls' in ref:
doc = get_document(ref["_cls"])._from_son(ref) doc = get_document(ref['_cls'])._from_son(ref)
elif doc_type is None: elif doc_type is None:
doc = get_document( doc = get_document(
''.join(x.capitalize() ''.join(x.capitalize()
@ -225,7 +225,7 @@ class DeReference(object):
data[k]._data[field_name] = self.object_map.get( data[k]._data[field_name] = self.object_map.get(
(v['_ref'].collection, v['_ref'].id), v) (v['_ref'].collection, v['_ref'].id), v)
elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth: elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth:
item_name = six.text_type("{0}.{1}.{2}").format(name, k, field_name) item_name = six.text_type('{0}.{1}.{2}').format(name, k, field_name)
data[k]._data[field_name] = self._attach_objects(v, depth, instance=instance, name=item_name) data[k]._data[field_name] = self._attach_objects(v, depth, instance=instance, name=item_name)
elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth: elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth:
item_name = '%s.%s' % (name, k) if name else name item_name = '%s.%s' % (name, k) if name else name

View File

@ -25,9 +25,7 @@ __all__ = ('Document', 'EmbeddedDocument', 'DynamicDocument',
def includes_cls(fields): def includes_cls(fields):
""" Helper function used for ensuring and comparing indexes """Helper function used for ensuring and comparing indexes."""
"""
first_field = None first_field = None
if len(fields): if len(fields):
if isinstance(fields[0], six.string_types): if isinstance(fields[0], six.string_types):
@ -167,7 +165,7 @@ class Document(BaseDocument):
@classmethod @classmethod
def _get_db(cls): def _get_db(cls):
"""Some Model using other db_alias""" """Some Model using other db_alias"""
return get_db(cls._meta.get("db_alias", DEFAULT_CONNECTION_NAME)) return get_db(cls._meta.get('db_alias', DEFAULT_CONNECTION_NAME))
@classmethod @classmethod
def _get_collection(cls): def _get_collection(cls):
@ -241,15 +239,15 @@ class Document(BaseDocument):
query = {} query = {}
if self.pk is None: if self.pk is None:
raise InvalidDocumentError("The document does not have a primary key.") raise InvalidDocumentError('The document does not have a primary key.')
id_field = self._meta["id_field"] id_field = self._meta['id_field']
query = query.copy() if isinstance(query, dict) else query.to_query(self) query = query.copy() if isinstance(query, dict) else query.to_query(self)
if id_field not in query: if id_field not in query:
query[id_field] = self.pk query[id_field] = self.pk
elif query[id_field] != self.pk: elif query[id_field] != self.pk:
raise InvalidQueryError("Invalid document modify query: it must modify only this document.") raise InvalidQueryError('Invalid document modify query: it must modify only this document.')
updated = self._qs(**query).modify(new=True, **update) updated = self._qs(**query).modify(new=True, **update)
if updated is None: if updated is None:
@ -324,7 +322,7 @@ class Document(BaseDocument):
self.validate(clean=clean) self.validate(clean=clean)
if write_concern is None: if write_concern is None:
write_concern = {"w": 1} write_concern = {'w': 1}
doc = self.to_mongo() doc = self.to_mongo()
@ -372,7 +370,7 @@ class Document(BaseDocument):
def is_new_object(last_error): def is_new_object(last_error):
if last_error is not None: if last_error is not None:
updated = last_error.get("updatedExisting") updated = last_error.get('updatedExisting')
if updated is not None: if updated is not None:
return not updated return not updated
return created return created
@ -380,14 +378,14 @@ class Document(BaseDocument):
update_query = {} update_query = {}
if updates: if updates:
update_query["$set"] = updates update_query['$set'] = updates
if removals: if removals:
update_query["$unset"] = removals update_query['$unset'] = removals
if updates or removals: if updates or removals:
upsert = save_condition is None upsert = save_condition is None
last_error = collection.update(select_dict, update_query, last_error = collection.update(select_dict, update_query,
upsert=upsert, **write_concern) upsert=upsert, **write_concern)
if not upsert and last_error["n"] == 0: if not upsert and last_error['n'] == 0:
raise SaveConditionError('Race condition preventing' raise SaveConditionError('Race condition preventing'
' document update detected') ' document update detected')
created = is_new_object(last_error) created = is_new_object(last_error)
@ -398,10 +396,10 @@ class Document(BaseDocument):
if cascade: if cascade:
kwargs = { kwargs = {
"force_insert": force_insert, 'force_insert': force_insert,
"validate": validate, 'validate': validate,
"write_concern": write_concern, 'write_concern': write_concern,
"cascade": cascade 'cascade': cascade
} }
if cascade_kwargs: # Allow granular control over cascades if cascade_kwargs: # Allow granular control over cascades
kwargs.update(cascade_kwargs) kwargs.update(cascade_kwargs)
@ -492,8 +490,8 @@ class Document(BaseDocument):
if self.pk is None: if self.pk is None:
if kwargs.get('upsert', False): if kwargs.get('upsert', False):
query = self.to_mongo() query = self.to_mongo()
if "_cls" in query: if '_cls' in query:
del query["_cls"] del query['_cls']
return self._qs.filter(**query).update_one(**kwargs) return self._qs.filter(**query).update_one(**kwargs)
else: else:
raise OperationError( raise OperationError(
@ -618,11 +616,12 @@ class Document(BaseDocument):
if fields and isinstance(fields[0], int): if fields and isinstance(fields[0], int):
max_depth = fields[0] max_depth = fields[0]
fields = fields[1:] fields = fields[1:]
elif "max_depth" in kwargs: elif 'max_depth' in kwargs:
max_depth = kwargs["max_depth"] max_depth = kwargs['max_depth']
if self.pk is None: if self.pk is None:
raise self.DoesNotExist("Document does not exist") raise self.DoesNotExist('Document does not exist')
obj = self._qs.read_preference(ReadPreference.PRIMARY).filter( obj = self._qs.read_preference(ReadPreference.PRIMARY).filter(
**self._object_key).only(*fields).limit( **self._object_key).only(*fields).limit(
1).select_related(max_depth=max_depth) 1).select_related(max_depth=max_depth)
@ -630,7 +629,7 @@ class Document(BaseDocument):
if obj: if obj:
obj = obj[0] obj = obj[0]
else: else:
raise self.DoesNotExist("Document does not exist") raise self.DoesNotExist('Document does not exist')
for field in obj._data: for field in obj._data:
if not fields or field in fields: if not fields or field in fields:
@ -673,7 +672,7 @@ class Document(BaseDocument):
"""Returns an instance of :class:`~bson.dbref.DBRef` useful in """Returns an instance of :class:`~bson.dbref.DBRef` useful in
`__raw__` queries.""" `__raw__` queries."""
if self.pk is None: if self.pk is None:
msg = "Only saved documents can have a valid dbref" msg = 'Only saved documents can have a valid dbref'
raise OperationError(msg) raise OperationError(msg)
return DBRef(self.__class__._get_collection_name(), self.pk) return DBRef(self.__class__._get_collection_name(), self.pk)
@ -728,7 +727,7 @@ class Document(BaseDocument):
fields = index_spec.pop('fields') fields = index_spec.pop('fields')
drop_dups = kwargs.get('drop_dups', False) drop_dups = kwargs.get('drop_dups', False)
if IS_PYMONGO_3 and drop_dups: if IS_PYMONGO_3 and drop_dups:
msg = "drop_dups is deprecated and is removed when using PyMongo 3+." msg = 'drop_dups is deprecated and is removed when using PyMongo 3+.'
warnings.warn(msg, DeprecationWarning) warnings.warn(msg, DeprecationWarning)
elif not IS_PYMONGO_3: elif not IS_PYMONGO_3:
index_spec['drop_dups'] = drop_dups index_spec['drop_dups'] = drop_dups
@ -754,7 +753,7 @@ class Document(BaseDocument):
will be removed if PyMongo3+ is used will be removed if PyMongo3+ is used
""" """
if IS_PYMONGO_3 and drop_dups: if IS_PYMONGO_3 and drop_dups:
msg = "drop_dups is deprecated and is removed when using PyMongo 3+." msg = 'drop_dups is deprecated and is removed when using PyMongo 3+.'
warnings.warn(msg, DeprecationWarning) warnings.warn(msg, DeprecationWarning)
elif not IS_PYMONGO_3: elif not IS_PYMONGO_3:
kwargs.update({'drop_dups': drop_dups}) kwargs.update({'drop_dups': drop_dups})
@ -774,7 +773,7 @@ class Document(BaseDocument):
index_opts = cls._meta.get('index_opts') or {} index_opts = cls._meta.get('index_opts') or {}
index_cls = cls._meta.get('index_cls', True) index_cls = cls._meta.get('index_cls', True)
if IS_PYMONGO_3 and drop_dups: if IS_PYMONGO_3 and drop_dups:
msg = "drop_dups is deprecated and is removed when using PyMongo 3+." msg = 'drop_dups is deprecated and is removed when using PyMongo 3+.'
warnings.warn(msg, DeprecationWarning) warnings.warn(msg, DeprecationWarning)
collection = cls._get_collection() collection = cls._get_collection()
@ -889,8 +888,8 @@ class Document(BaseDocument):
@classmethod @classmethod
def compare_indexes(cls): def compare_indexes(cls):
""" Compares the indexes defined in MongoEngine with the ones existing """ Compares the indexes defined in MongoEngine with the ones
in the database. Returns any missing/extra indexes. existing in the database. Returns any missing/extra indexes.
""" """
required = cls.list_indexes() required = cls.list_indexes()
@ -934,8 +933,9 @@ class DynamicDocument(Document):
_dynamic = True _dynamic = True
def __delattr__(self, *args, **kwargs): def __delattr__(self, *args, **kwargs):
"""Deletes the attribute by setting to None and allowing _delta to unset """Delete the attribute by setting to None and allowing _delta
it""" to unset it.
"""
field_name = args[0] field_name = args[0]
if field_name in self._dynamic_fields: if field_name in self._dynamic_fields:
setattr(self, field_name, None) setattr(self, field_name, None)
@ -957,8 +957,9 @@ class DynamicEmbeddedDocument(EmbeddedDocument):
_dynamic = True _dynamic = True
def __delattr__(self, *args, **kwargs): def __delattr__(self, *args, **kwargs):
"""Deletes the attribute by setting to None and allowing _delta to unset """Delete the attribute by setting to None and allowing _delta
it""" to unset it.
"""
field_name = args[0] field_name = args[0]
if field_name in self._fields: if field_name in self._fields:
default = self._fields[field_name].default default = self._fields[field_name].default
@ -1000,10 +1001,10 @@ class MapReduceDocument(object):
try: try:
self.key = id_field_type(self.key) self.key = id_field_type(self.key)
except Exception: except Exception:
raise Exception("Could not cast key as %s" % raise Exception('Could not cast key as %s' %
id_field_type.__name__) id_field_type.__name__)
if not hasattr(self, "_key_object"): if not hasattr(self, '_key_object'):
self._key_object = self._document.objects.with_id(self.key) self._key_object = self._document.objects.with_id(self.key)
return self._key_object return self._key_object
return self._key_object return self._key_object

View File

@ -70,7 +70,7 @@ class ValidationError(AssertionError):
field_name = None field_name = None
_message = None _message = None
def __init__(self, message="", **kwargs): def __init__(self, message='', **kwargs):
self.errors = kwargs.get('errors', {}) self.errors = kwargs.get('errors', {})
self.field_name = kwargs.get('field_name') self.field_name = kwargs.get('field_name')
self.message = message self.message = message
@ -136,10 +136,10 @@ class ValidationError(AssertionError):
value = ' '.join( value = ' '.join(
[generate_key(v, k) for k, v in value.iteritems()]) [generate_key(v, k) for k, v in value.iteritems()])
results = "%s.%s" % (prefix, value) if prefix else value results = '%s.%s' % (prefix, value) if prefix else value
return results return results
error_dict = defaultdict(list) error_dict = defaultdict(list)
for k, v in self.to_dict().iteritems(): for k, v in self.to_dict().iteritems():
error_dict[generate_key(v)].append(k) error_dict[generate_key(v)].append(k)
return ' '.join(["%s: %s" % (k, v) for k, v in error_dict.iteritems()]) return ' '.join(['%s: %s' % (k, v) for k, v in error_dict.iteritems()])

View File

@ -57,8 +57,7 @@ RECURSIVE_REFERENCE_CONSTANT = 'self'
class StringField(BaseField): class StringField(BaseField):
"""A unicode string field. """A unicode string field."""
"""
def __init__(self, regex=None, max_length=None, min_length=None, **kwargs): def __init__(self, regex=None, max_length=None, min_length=None, **kwargs):
self.regex = re.compile(regex) if regex else None self.regex = re.compile(regex) if regex else None
@ -151,8 +150,8 @@ class URLField(StringField):
if self.verify_exists: if self.verify_exists:
warnings.warn( warnings.warn(
"The URLField verify_exists argument has intractable security " 'The URLField verify_exists argument has intractable security '
"and performance issues. Accordingly, it has been deprecated.", 'and performance issues. Accordingly, it has been deprecated.',
DeprecationWarning) DeprecationWarning)
try: try:
request = urllib2.Request(value) request = urllib2.Request(value)
@ -183,8 +182,7 @@ class EmailField(StringField):
class IntField(BaseField): class IntField(BaseField):
"""An 32-bit integer field. """32-bit integer field."""
"""
def __init__(self, min_value=None, max_value=None, **kwargs): def __init__(self, min_value=None, max_value=None, **kwargs):
self.min_value, self.max_value = min_value, max_value self.min_value, self.max_value = min_value, max_value
@ -217,8 +215,7 @@ class IntField(BaseField):
class LongField(BaseField): class LongField(BaseField):
"""An 64-bit integer field. """64-bit integer field."""
"""
def __init__(self, min_value=None, max_value=None, **kwargs): def __init__(self, min_value=None, max_value=None, **kwargs):
self.min_value, self.max_value = min_value, max_value self.min_value, self.max_value = min_value, max_value
@ -254,8 +251,7 @@ class LongField(BaseField):
class FloatField(BaseField): class FloatField(BaseField):
"""An floating point number field. """Floating point number field."""
"""
def __init__(self, min_value=None, max_value=None, **kwargs): def __init__(self, min_value=None, max_value=None, **kwargs):
self.min_value, self.max_value = min_value, max_value self.min_value, self.max_value = min_value, max_value
@ -292,7 +288,7 @@ class FloatField(BaseField):
class DecimalField(BaseField): class DecimalField(BaseField):
"""A fixed-point decimal number field. """Fixed-point decimal number field.
.. versionchanged:: 0.8 .. versionchanged:: 0.8
.. versionadded:: 0.3 .. versionadded:: 0.3
@ -333,10 +329,10 @@ class DecimalField(BaseField):
# Convert to string for python 2.6 before casting to Decimal # Convert to string for python 2.6 before casting to Decimal
try: try:
value = decimal.Decimal("%s" % value) value = decimal.Decimal('%s' % value)
except decimal.InvalidOperation: except decimal.InvalidOperation:
return value return value
return value.quantize(decimal.Decimal(".%s" % ("0" * self.precision)), rounding=self.rounding) return value.quantize(decimal.Decimal('.%s' % ('0' * self.precision)), rounding=self.rounding)
def to_mongo(self, value): def to_mongo(self, value):
if value is None: if value is None:
@ -365,7 +361,7 @@ class DecimalField(BaseField):
class BooleanField(BaseField): class BooleanField(BaseField):
"""A boolean field type. """Boolean field type.
.. versionadded:: 0.1.2 .. versionadded:: 0.1.2
""" """
@ -383,7 +379,7 @@ class BooleanField(BaseField):
class DateTimeField(BaseField): class DateTimeField(BaseField):
"""A datetime field. """Datetime field.
Uses the python-dateutil library if available alternatively use time.strptime Uses the python-dateutil library if available alternatively use time.strptime
to parse the dates. Note: python-dateutil's parser is fully featured and when to parse the dates. Note: python-dateutil's parser is fully featured and when
@ -643,7 +639,7 @@ class DynamicField(BaseField):
val = value.to_mongo(use_db_field, fields) val = value.to_mongo(use_db_field, fields)
# If we its a document thats not inherited add _cls # If we its a document thats not inherited add _cls
if isinstance(value, Document): if isinstance(value, Document):
val = {"_ref": value.to_dbref(), "_cls": cls.__name__} val = {'_ref': value.to_dbref(), '_cls': cls.__name__}
if isinstance(value, EmbeddedDocument): if isinstance(value, EmbeddedDocument):
val['_cls'] = cls.__name__ val['_cls'] = cls.__name__
return val return val
@ -683,7 +679,7 @@ class DynamicField(BaseField):
return super(DynamicField, self).prepare_query_value(op, self.to_mongo(value)) return super(DynamicField, self).prepare_query_value(op, self.to_mongo(value))
def validate(self, value, clean=True): def validate(self, value, clean=True):
if hasattr(value, "validate"): if hasattr(value, 'validate'):
value.validate(clean=clean) value.validate(clean=clean)
@ -703,8 +699,7 @@ class ListField(ComplexBaseField):
super(ListField, self).__init__(**kwargs) super(ListField, self).__init__(**kwargs)
def validate(self, value): def validate(self, value):
"""Make sure that a list of valid fields is being used. """Make sure that a list of valid fields is being used."""
"""
if (not isinstance(value, (list, tuple, QuerySet)) or if (not isinstance(value, (list, tuple, QuerySet)) or
isinstance(value, six.string_types)): isinstance(value, six.string_types)):
self.error('Only lists and tuples may be used in a list field') self.error('Only lists and tuples may be used in a list field')
@ -737,7 +732,6 @@ class EmbeddedDocumentListField(ListField):
:class:`~mongoengine.EmbeddedDocument`. :class:`~mongoengine.EmbeddedDocument`.
.. versionadded:: 0.9 .. versionadded:: 0.9
""" """
def __init__(self, document_type, **kwargs): def __init__(self, document_type, **kwargs):
@ -786,8 +780,8 @@ class SortedListField(ListField):
def key_not_string(d): def key_not_string(d):
""" Helper function to recursively determine if any key in a dictionary is """Helper function to recursively determine if any key in a
not a string. dictionary is not a string.
""" """
for k, v in d.items(): for k, v in d.items():
if not isinstance(k, six.string_types) or (isinstance(v, dict) and key_not_string(v)): if not isinstance(k, six.string_types) or (isinstance(v, dict) and key_not_string(v)):
@ -795,8 +789,8 @@ def key_not_string(d):
def key_has_dot_or_dollar(d): def key_has_dot_or_dollar(d):
""" Helper function to recursively determine if any key in a dictionary """Helper function to recursively determine if any key in a
contains a dot or a dollar sign. dictionary contains a dot or a dollar sign.
""" """
for k, v in d.items(): for k, v in d.items():
if ('.' in k or '$' in k) or (isinstance(v, dict) and key_has_dot_or_dollar(v)): if ('.' in k or '$' in k) or (isinstance(v, dict) and key_has_dot_or_dollar(v)):
@ -824,14 +818,13 @@ class DictField(ComplexBaseField):
super(DictField, self).__init__(*args, **kwargs) super(DictField, self).__init__(*args, **kwargs)
def validate(self, value): def validate(self, value):
"""Make sure that a list of valid fields is being used. """Make sure that a list of valid fields is being used."""
"""
if not isinstance(value, dict): if not isinstance(value, dict):
self.error('Only dictionaries may be used in a DictField') self.error('Only dictionaries may be used in a DictField')
if key_not_string(value): if key_not_string(value):
msg = ("Invalid dictionary key - documents must " msg = ('Invalid dictionary key - documents must '
"have only string keys") 'have only string keys')
self.error(msg) self.error(msg)
if key_has_dot_or_dollar(value): if key_has_dot_or_dollar(value):
self.error('Invalid dictionary key name - keys may not contain "."' self.error('Invalid dictionary key name - keys may not contain "."'
@ -944,8 +937,7 @@ class ReferenceField(BaseField):
return self.document_type_obj return self.document_type_obj
def __get__(self, instance, owner): def __get__(self, instance, owner):
"""Descriptor to allow lazy dereferencing. """Descriptor to allow lazy dereferencing."""
"""
if instance is None: if instance is None:
# Document class being used rather than a document object # Document class being used rather than a document object
return self return self
@ -1002,8 +994,7 @@ class ReferenceField(BaseField):
return id_ return id_
def to_python(self, value): def to_python(self, value):
"""Convert a MongoDB-compatible type to a Python type. """Convert a MongoDB-compatible type to a Python type."""
"""
if (not self.dbref and if (not self.dbref and
not isinstance(value, (DBRef, Document, EmbeddedDocument))): not isinstance(value, (DBRef, Document, EmbeddedDocument))):
collection = self.document_type._get_collection_name() collection = self.document_type._get_collection_name()
@ -1019,7 +1010,7 @@ class ReferenceField(BaseField):
def validate(self, value): def validate(self, value):
if not isinstance(value, (self.document_type, DBRef)): if not isinstance(value, (self.document_type, DBRef)):
self.error("A ReferenceField only accepts DBRef or documents") self.error('A ReferenceField only accepts DBRef or documents')
if isinstance(value, Document) and value.id is None: if isinstance(value, Document) and value.id is None:
self.error('You can only reference documents once they have been ' self.error('You can only reference documents once they have been '
@ -1135,7 +1126,7 @@ class CachedReferenceField(BaseField):
# TODO: should raise here or will fail next statement # TODO: should raise here or will fail next statement
value = SON(( value = SON((
("_id", id_field.to_mongo(id_)), ('_id', id_field.to_mongo(id_)),
)) ))
if fields: if fields:
@ -1161,7 +1152,7 @@ class CachedReferenceField(BaseField):
def validate(self, value): def validate(self, value):
if not isinstance(value, self.document_type): if not isinstance(value, self.document_type):
self.error("A CachedReferenceField only accepts documents") self.error('A CachedReferenceField only accepts documents')
if isinstance(value, Document) and value.id is None: if isinstance(value, Document) and value.id is None:
self.error('You can only reference documents once they have been ' self.error('You can only reference documents once they have been '
@ -1298,8 +1289,7 @@ class GenericReferenceField(BaseField):
class BinaryField(BaseField): class BinaryField(BaseField):
"""A binary data field. """A binary data field."""
"""
def __init__(self, max_bytes=None, **kwargs): def __init__(self, max_bytes=None, **kwargs):
self.max_bytes = max_bytes self.max_bytes = max_bytes
@ -1316,8 +1306,8 @@ class BinaryField(BaseField):
def validate(self, value): def validate(self, value):
if not isinstance(value, (six.binary_type, six.text_type, Binary)): if not isinstance(value, (six.binary_type, six.text_type, Binary)):
self.error("BinaryField only accepts instances of " self.error('BinaryField only accepts instances of '
"(%s, %s, Binary)" % ( '(%s, %s, Binary)' % (
six.binary_type.__name__, six.text_type.__name__)) six.binary_type.__name__, six.text_type.__name__))
if self.max_bytes is not None and len(value) > self.max_bytes: if self.max_bytes is not None and len(value) > self.max_bytes:
@ -1450,7 +1440,7 @@ class GridFSProxy(object):
try: try:
return gridout.read(size) return gridout.read(size)
except Exception: except Exception:
return "" return ''
def delete(self): def delete(self):
# Delete file from GridFS, FileField still remains # Delete file from GridFS, FileField still remains
@ -1482,9 +1472,8 @@ class FileField(BaseField):
""" """
proxy_class = GridFSProxy proxy_class = GridFSProxy
def __init__(self, def __init__(self, db_alias=DEFAULT_CONNECTION_NAME, collection_name='fs',
db_alias=DEFAULT_CONNECTION_NAME, **kwargs):
collection_name="fs", **kwargs):
super(FileField, self).__init__(**kwargs) super(FileField, self).__init__(**kwargs)
self.collection_name = collection_name self.collection_name = collection_name
self.db_alias = db_alias self.db_alias = db_alias
@ -1687,10 +1676,10 @@ class ImageGridFsProxy(GridFSProxy):
return self.fs.get(out.thumbnail_id) return self.fs.get(out.thumbnail_id)
def write(self, *args, **kwargs): def write(self, *args, **kwargs):
raise RuntimeError("Please use \"put\" method instead") raise RuntimeError('Please use "put" method instead')
def writelines(self, *args, **kwargs): def writelines(self, *args, **kwargs):
raise RuntimeError("Please use \"put\" method instead") raise RuntimeError('Please use "put" method instead')
class ImproperlyConfigured(Exception): class ImproperlyConfigured(Exception):
@ -1715,7 +1704,7 @@ class ImageField(FileField):
def __init__(self, size=None, thumbnail_size=None, def __init__(self, size=None, thumbnail_size=None,
collection_name='images', **kwargs): collection_name='images', **kwargs):
if not Image: if not Image:
raise ImproperlyConfigured("PIL library was not found") raise ImproperlyConfigured('PIL library was not found')
params_size = ('width', 'height', 'force') params_size = ('width', 'height', 'force')
extra_args = dict(size=size, thumbnail_size=thumbnail_size) extra_args = dict(size=size, thumbnail_size=thumbnail_size)
@ -1783,10 +1772,10 @@ class SequenceField(BaseField):
Generate and Increment the counter Generate and Increment the counter
""" """
sequence_name = self.get_sequence_name() sequence_name = self.get_sequence_name()
sequence_id = "%s.%s" % (sequence_name, self.name) sequence_id = '%s.%s' % (sequence_name, self.name)
collection = get_db(alias=self.db_alias)[self.collection_name] collection = get_db(alias=self.db_alias)[self.collection_name]
counter = collection.find_and_modify(query={"_id": sequence_id}, counter = collection.find_and_modify(query={'_id': sequence_id},
update={"$inc": {"next": 1}}, update={'$inc': {'next': 1}},
new=True, new=True,
upsert=True) upsert=True)
return self.value_decorator(counter['next']) return self.value_decorator(counter['next'])
@ -1809,9 +1798,9 @@ class SequenceField(BaseField):
as it is only fixed on set. as it is only fixed on set.
""" """
sequence_name = self.get_sequence_name() sequence_name = self.get_sequence_name()
sequence_id = "%s.%s" % (sequence_name, self.name) sequence_id = '%s.%s' % (sequence_name, self.name)
collection = get_db(alias=self.db_alias)[self.collection_name] collection = get_db(alias=self.db_alias)[self.collection_name]
data = collection.find_one({"_id": sequence_id}) data = collection.find_one({'_id': sequence_id})
if data: if data:
return self.value_decorator(data['next'] + 1) return self.value_decorator(data['next'] + 1)
@ -1924,19 +1913,18 @@ class GeoPointField(BaseField):
_geo_index = pymongo.GEO2D _geo_index = pymongo.GEO2D
def validate(self, value): def validate(self, value):
"""Make sure that a geo-value is of type (x, y) """Make sure that a geo-value is of type (x, y)"""
"""
if not isinstance(value, (list, tuple)): if not isinstance(value, (list, tuple)):
self.error('GeoPointField can only accept tuples or lists ' self.error('GeoPointField can only accept tuples or lists '
'of (x, y)') 'of (x, y)')
if not len(value) == 2: if not len(value) == 2:
self.error("Value (%s) must be a two-dimensional point" % self.error('Value (%s) must be a two-dimensional point' %
repr(value)) repr(value))
elif (not isinstance(value[0], (float, int)) or elif (not isinstance(value[0], (float, int)) or
not isinstance(value[1], (float, int))): not isinstance(value[1], (float, int))):
self.error( self.error(
"Both values (%s) in point must be float or int" % repr(value)) 'Both values (%s) in point must be float or int' % repr(value))
class PointField(GeoJsonBaseField): class PointField(GeoJsonBaseField):
@ -1946,8 +1934,8 @@ class PointField(GeoJsonBaseField):
.. code-block:: js .. code-block:: js
{ "type" : "Point" , {'type' : 'Point' ,
"coordinates" : [x, y]} 'coordinates' : [x, y]}
You can either pass a dict with the full information or a list You can either pass a dict with the full information or a list
to set the value. to set the value.
@ -1956,7 +1944,7 @@ class PointField(GeoJsonBaseField):
.. versionadded:: 0.8 .. versionadded:: 0.8
""" """
_type = "Point" _type = 'Point'
class LineStringField(GeoJsonBaseField): class LineStringField(GeoJsonBaseField):
@ -1966,8 +1954,8 @@ class LineStringField(GeoJsonBaseField):
.. code-block:: js .. code-block:: js
{ "type" : "LineString" , {'type' : 'LineString' ,
"coordinates" : [[x1, y1], [x1, y1] ... [xn, yn]]} 'coordinates' : [[x1, y1], [x1, y1] ... [xn, yn]]}
You can either pass a dict with the full information or a list of points. You can either pass a dict with the full information or a list of points.
@ -1975,7 +1963,7 @@ class LineStringField(GeoJsonBaseField):
.. versionadded:: 0.8 .. versionadded:: 0.8
""" """
_type = "LineString" _type = 'LineString'
class PolygonField(GeoJsonBaseField): class PolygonField(GeoJsonBaseField):
@ -1985,9 +1973,9 @@ class PolygonField(GeoJsonBaseField):
.. code-block:: js .. code-block:: js
{ "type" : "Polygon" , {'type' : 'Polygon' ,
"coordinates" : [[[x1, y1], [x1, y1] ... [xn, yn]], 'coordinates' : [[[x1, y1], [x1, y1] ... [xn, yn]],
[[x1, y1], [x1, y1] ... [xn, yn]]} [[x1, y1], [x1, y1] ... [xn, yn]]}
You can either pass a dict with the full information or a list You can either pass a dict with the full information or a list
of LineStrings. The first LineString being the outside and the rest being of LineStrings. The first LineString being the outside and the rest being
@ -1997,7 +1985,7 @@ class PolygonField(GeoJsonBaseField):
.. versionadded:: 0.8 .. versionadded:: 0.8
""" """
_type = "Polygon" _type = 'Polygon'
class MultiPointField(GeoJsonBaseField): class MultiPointField(GeoJsonBaseField):
@ -2007,8 +1995,8 @@ class MultiPointField(GeoJsonBaseField):
.. code-block:: js .. code-block:: js
{ "type" : "MultiPoint" , {'type' : 'MultiPoint' ,
"coordinates" : [[x1, y1], [x2, y2]]} 'coordinates' : [[x1, y1], [x2, y2]]}
You can either pass a dict with the full information or a list You can either pass a dict with the full information or a list
to set the value. to set the value.
@ -2017,7 +2005,7 @@ class MultiPointField(GeoJsonBaseField):
.. versionadded:: 0.9 .. versionadded:: 0.9
""" """
_type = "MultiPoint" _type = 'MultiPoint'
class MultiLineStringField(GeoJsonBaseField): class MultiLineStringField(GeoJsonBaseField):
@ -2027,9 +2015,9 @@ class MultiLineStringField(GeoJsonBaseField):
.. code-block:: js .. code-block:: js
{ "type" : "MultiLineString" , {'type' : 'MultiLineString' ,
"coordinates" : [[[x1, y1], [x1, y1] ... [xn, yn]], 'coordinates' : [[[x1, y1], [x1, y1] ... [xn, yn]],
[[x1, y1], [x1, y1] ... [xn, yn]]]} [[x1, y1], [x1, y1] ... [xn, yn]]]}
You can either pass a dict with the full information or a list of points. You can either pass a dict with the full information or a list of points.
@ -2037,7 +2025,7 @@ class MultiLineStringField(GeoJsonBaseField):
.. versionadded:: 0.9 .. versionadded:: 0.9
""" """
_type = "MultiLineString" _type = 'MultiLineString'
class MultiPolygonField(GeoJsonBaseField): class MultiPolygonField(GeoJsonBaseField):
@ -2047,14 +2035,14 @@ class MultiPolygonField(GeoJsonBaseField):
.. code-block:: js .. code-block:: js
{ "type" : "MultiPolygon" , {'type' : 'MultiPolygon' ,
"coordinates" : [[ 'coordinates' : [[
[[x1, y1], [x1, y1] ... [xn, yn]], [[x1, y1], [x1, y1] ... [xn, yn]],
[[x1, y1], [x1, y1] ... [xn, yn]] [[x1, y1], [x1, y1] ... [xn, yn]]
], [ ], [
[[x1, y1], [x1, y1] ... [xn, yn]], [[x1, y1], [x1, y1] ... [xn, yn]],
[[x1, y1], [x1, y1] ... [xn, yn]] [[x1, y1], [x1, y1] ... [xn, yn]]
] ]
} }
You can either pass a dict with the full information or a list You can either pass a dict with the full information or a list
@ -2064,4 +2052,4 @@ class MultiPolygonField(GeoJsonBaseField):
.. versionadded:: 0.9 .. versionadded:: 0.9
""" """
_type = "MultiPolygon" _type = 'MultiPolygon'

View File

@ -74,10 +74,10 @@ class BaseQuerySet(object):
# subclasses of the class being used # subclasses of the class being used
if document._meta.get('allow_inheritance') is True: if document._meta.get('allow_inheritance') is True:
if len(self._document._subclasses) == 1: if len(self._document._subclasses) == 1:
self._initial_query = {"_cls": self._document._subclasses[0]} self._initial_query = {'_cls': self._document._subclasses[0]}
else: else:
self._initial_query = { self._initial_query = {
"_cls": {"$in": self._document._subclasses}} '_cls': {'$in': self._document._subclasses}}
self._loaded_fields = QueryFieldList(always_include=['_cls']) self._loaded_fields = QueryFieldList(always_include=['_cls'])
self._cursor_obj = None self._cursor_obj = None
self._limit = None self._limit = None
@ -106,8 +106,8 @@ class BaseQuerySet(object):
if q_obj: if q_obj:
# make sure proper query object is passed # make sure proper query object is passed
if not isinstance(q_obj, QNode): if not isinstance(q_obj, QNode):
msg = ("Not a query object: %s. " msg = ('Not a query object: %s. '
"Did you intend to use key=value?" % q_obj) 'Did you intend to use key=value?' % q_obj)
raise InvalidQueryError(msg) raise InvalidQueryError(msg)
query &= q_obj query &= q_obj
@ -134,10 +134,10 @@ class BaseQuerySet(object):
obj_dict = self.__dict__.copy() obj_dict = self.__dict__.copy()
# don't picke collection, instead pickle collection params # don't picke collection, instead pickle collection params
obj_dict.pop("_collection_obj") obj_dict.pop('_collection_obj')
# don't pickle cursor # don't pickle cursor
obj_dict["_cursor_obj"] = None obj_dict['_cursor_obj'] = None
return obj_dict return obj_dict
@ -148,7 +148,7 @@ class BaseQuerySet(object):
See https://github.com/MongoEngine/mongoengine/issues/442 See https://github.com/MongoEngine/mongoengine/issues/442
""" """
obj_dict["_collection_obj"] = obj_dict["_document"]._get_collection() obj_dict['_collection_obj'] = obj_dict['_document']._get_collection()
# update attributes # update attributes
self.__dict__.update(obj_dict) self.__dict__.update(obj_dict)
@ -200,19 +200,16 @@ class BaseQuerySet(object):
raise NotImplementedError raise NotImplementedError
def _has_data(self): def _has_data(self):
""" Retrieves whether cursor has any data. """ """Return True if cursor has any data."""
queryset = self.order_by() queryset = self.order_by()
return False if queryset.first() is None else True return False if queryset.first() is None else True
def __nonzero__(self): def __nonzero__(self):
""" Avoid to open all records in an if stmt in Py2. """ """Avoid to open all records in an if stmt in Py2."""
return self._has_data() return self._has_data()
def __bool__(self): def __bool__(self):
""" Avoid to open all records in an if stmt in Py3. """ """Avoid to open all records in an if stmt in Py3."""
return self._has_data() return self._has_data()
# Core functions # Core functions
@ -240,7 +237,7 @@ class BaseQuerySet(object):
queryset = self.clone() queryset = self.clone()
if queryset._search_text: if queryset._search_text:
raise OperationError( raise OperationError(
"It is not possible to use search_text two times.") 'It is not possible to use search_text two times.')
query_kwargs = SON({'$search': text}) query_kwargs = SON({'$search': text})
if language: if language:
@ -269,7 +266,7 @@ class BaseQuerySet(object):
try: try:
result = queryset.next() result = queryset.next()
except StopIteration: except StopIteration:
msg = ("%s matching query does not exist." msg = ('%s matching query does not exist.'
% queryset._document._class_name) % queryset._document._class_name)
raise queryset._document.DoesNotExist(msg) raise queryset._document.DoesNotExist(msg)
try: try:
@ -291,8 +288,7 @@ class BaseQuerySet(object):
return self._document(**kwargs).save() return self._document(**kwargs).save()
def first(self): def first(self):
"""Retrieve the first object matching the query. """Retrieve the first object matching the query."""
"""
queryset = self.clone() queryset = self.clone()
try: try:
result = queryset[0] result = queryset[0]
@ -341,7 +337,7 @@ class BaseQuerySet(object):
% str(self._document)) % str(self._document))
raise OperationError(msg) raise OperationError(msg)
if doc.pk and not doc._created: if doc.pk and not doc._created:
msg = "Some documents have ObjectIds use doc.update() instead" msg = 'Some documents have ObjectIds use doc.update() instead'
raise OperationError(msg) raise OperationError(msg)
signal_kwargs = signal_kwargs or {} signal_kwargs = signal_kwargs or {}
@ -433,7 +429,7 @@ class BaseQuerySet(object):
rule = doc._meta['delete_rules'][rule_entry] rule = doc._meta['delete_rules'][rule_entry]
if rule == DENY and document_cls.objects( if rule == DENY and document_cls.objects(
**{field_name + '__in': self}).count() > 0: **{field_name + '__in': self}).count() > 0:
msg = ("Could not delete document (%s.%s refers to it)" msg = ('Could not delete document (%s.%s refers to it)'
% (document_cls.__name__, field_name)) % (document_cls.__name__, field_name))
raise OperationError(msg) raise OperationError(msg)
@ -462,7 +458,7 @@ class BaseQuerySet(object):
result = queryset._collection.remove(queryset._query, **write_concern) result = queryset._collection.remove(queryset._query, **write_concern)
if result: if result:
return result.get("n") return result.get('n')
def update(self, upsert=False, multi=True, write_concern=None, def update(self, upsert=False, multi=True, write_concern=None,
full_result=False, **update): full_result=False, **update):
@ -483,7 +479,7 @@ class BaseQuerySet(object):
.. versionadded:: 0.2 .. versionadded:: 0.2
""" """
if not update and not upsert: if not update and not upsert:
raise OperationError("No update parameters, would remove data") raise OperationError('No update parameters, would remove data')
if write_concern is None: if write_concern is None:
write_concern = {} write_concern = {}
@ -496,9 +492,9 @@ class BaseQuerySet(object):
# then ensure we add _cls to the update operation # then ensure we add _cls to the update operation
if upsert and '_cls' in query: if upsert and '_cls' in query:
if '$set' in update: if '$set' in update:
update["$set"]["_cls"] = queryset._document._class_name update['$set']['_cls'] = queryset._document._class_name
else: else:
update["$set"] = {"_cls": queryset._document._class_name} update['$set'] = {'_cls': queryset._document._class_name}
try: try:
result = queryset._collection.update(query, update, multi=multi, result = queryset._collection.update(query, update, multi=multi,
upsert=upsert, **write_concern) upsert=upsert, **write_concern)
@ -583,11 +579,11 @@ class BaseQuerySet(object):
""" """
if remove and new: if remove and new:
raise OperationError("Conflicting parameters: remove and new") raise OperationError('Conflicting parameters: remove and new')
if not update and not upsert and not remove: if not update and not upsert and not remove:
raise OperationError( raise OperationError(
"No update parameters, must either update or remove") 'No update parameters, must either update or remove')
queryset = self.clone() queryset = self.clone()
query = queryset._query query = queryset._query
@ -598,7 +594,7 @@ class BaseQuerySet(object):
try: try:
if IS_PYMONGO_3: if IS_PYMONGO_3:
if full_response: if full_response:
msg = "With PyMongo 3+, it is not possible anymore to get the full response." msg = 'With PyMongo 3+, it is not possible anymore to get the full response.'
warnings.warn(msg, DeprecationWarning) warnings.warn(msg, DeprecationWarning)
if remove: if remove:
result = queryset._collection.find_one_and_delete( result = queryset._collection.find_one_and_delete(
@ -617,13 +613,13 @@ class BaseQuerySet(object):
query, update, upsert=upsert, sort=sort, remove=remove, new=new, query, update, upsert=upsert, sort=sort, remove=remove, new=new,
full_response=full_response, **self._cursor_args) full_response=full_response, **self._cursor_args)
except pymongo.errors.DuplicateKeyError as err: except pymongo.errors.DuplicateKeyError as err:
raise NotUniqueError(u"Update failed (%s)" % err) raise NotUniqueError(u'Update failed (%s)' % err)
except pymongo.errors.OperationFailure as err: except pymongo.errors.OperationFailure as err:
raise OperationError(u"Update failed (%s)" % err) raise OperationError(u'Update failed (%s)' % err)
if full_response: if full_response:
if result["value"] is not None: if result['value'] is not None:
result["value"] = self._document._from_son(result["value"], only_fields=self.only_fields) result['value'] = self._document._from_son(result['value'], only_fields=self.only_fields)
else: else:
if result is not None: if result is not None:
result = self._document._from_son(result, only_fields=self.only_fields) result = self._document._from_son(result, only_fields=self.only_fields)
@ -641,7 +637,7 @@ class BaseQuerySet(object):
""" """
queryset = self.clone() queryset = self.clone()
if not queryset._query_obj.empty: if not queryset._query_obj.empty:
msg = "Cannot use a filter whilst using `with_id`" msg = 'Cannot use a filter whilst using `with_id`'
raise InvalidQueryError(msg) raise InvalidQueryError(msg)
return queryset.filter(pk=object_id).first() return queryset.filter(pk=object_id).first()
@ -685,7 +681,7 @@ class BaseQuerySet(object):
Only return instances of this document and not any inherited documents Only return instances of this document and not any inherited documents
""" """
if self._document._meta.get('allow_inheritance') is True: if self._document._meta.get('allow_inheritance') is True:
self._initial_query = {"_cls": self._document._class_name} self._initial_query = {'_cls': self._document._class_name}
return self return self
@ -824,9 +820,9 @@ class BaseQuerySet(object):
ListField = _import_class('ListField') ListField = _import_class('ListField')
GenericEmbeddedDocumentField = _import_class('GenericEmbeddedDocumentField') GenericEmbeddedDocumentField = _import_class('GenericEmbeddedDocumentField')
if isinstance(doc_field, ListField): if isinstance(doc_field, ListField):
doc_field = getattr(doc_field, "field", doc_field) doc_field = getattr(doc_field, 'field', doc_field)
if isinstance(doc_field, (EmbeddedDocumentField, GenericEmbeddedDocumentField)): if isinstance(doc_field, (EmbeddedDocumentField, GenericEmbeddedDocumentField)):
instance = getattr(doc_field, "document_type", False) instance = getattr(doc_field, 'document_type', False)
# handle distinct on subdocuments # handle distinct on subdocuments
if '.' in field: if '.' in field:
for field_part in field.split('.')[1:]: for field_part in field.split('.')[1:]:
@ -837,9 +833,9 @@ class BaseQuerySet(object):
doc_field = getattr(doc_field, field_part, doc_field) doc_field = getattr(doc_field, field_part, doc_field)
# We may need to cast to the correct type eg. ListField(EmbeddedDocumentField) # We may need to cast to the correct type eg. ListField(EmbeddedDocumentField)
if isinstance(doc_field, ListField): if isinstance(doc_field, ListField):
doc_field = getattr(doc_field, "field", doc_field) doc_field = getattr(doc_field, 'field', doc_field)
if isinstance(doc_field, (EmbeddedDocumentField, GenericEmbeddedDocumentField)): if isinstance(doc_field, (EmbeddedDocumentField, GenericEmbeddedDocumentField)):
instance = getattr(doc_field, "document_type", False) instance = getattr(doc_field, 'document_type', False)
if instance and isinstance(doc_field, (EmbeddedDocumentField, if instance and isinstance(doc_field, (EmbeddedDocumentField,
GenericEmbeddedDocumentField)): GenericEmbeddedDocumentField)):
distinct = [instance(**doc) for doc in distinct] distinct = [instance(**doc) for doc in distinct]
@ -848,12 +844,12 @@ class BaseQuerySet(object):
def only(self, *fields): def only(self, *fields):
"""Load only a subset of this document's fields. :: """Load only a subset of this document's fields. ::
post = BlogPost.objects(...).only("title", "author.name") post = BlogPost.objects(...).only('title', 'author.name')
.. note :: `only()` is chainable and will perform a union :: .. note :: `only()` is chainable and will perform a union ::
So with the following it will fetch both: `title` and `author.name`:: So with the following it will fetch both: `title` and `author.name`::
post = BlogPost.objects.only("title").only("author.name") post = BlogPost.objects.only('title').only('author.name')
:func:`~mongoengine.queryset.QuerySet.all_fields` will reset any :func:`~mongoengine.queryset.QuerySet.all_fields` will reset any
field filters. field filters.
@ -870,12 +866,12 @@ class BaseQuerySet(object):
def exclude(self, *fields): def exclude(self, *fields):
"""Opposite to .only(), exclude some document's fields. :: """Opposite to .only(), exclude some document's fields. ::
post = BlogPost.objects(...).exclude("comments") post = BlogPost.objects(...).exclude('comments')
.. note :: `exclude()` is chainable and will perform a union :: .. note :: `exclude()` is chainable and will perform a union ::
So with the following it will exclude both: `title` and `author.name`:: So with the following it will exclude both: `title` and `author.name`::
post = BlogPost.objects.exclude("title").exclude("author.name") post = BlogPost.objects.exclude('title').exclude('author.name')
:func:`~mongoengine.queryset.QuerySet.all_fields` will reset any :func:`~mongoengine.queryset.QuerySet.all_fields` will reset any
field filters. field filters.
@ -905,7 +901,7 @@ class BaseQuerySet(object):
""" """
# Check for an operator and transform to mongo-style if there is # Check for an operator and transform to mongo-style if there is
operators = ["slice"] operators = ['slice']
cleaned_fields = [] cleaned_fields = []
for key, value in kwargs.items(): for key, value in kwargs.items():
parts = key.split('__') parts = key.split('__')
@ -929,7 +925,7 @@ class BaseQuerySet(object):
"""Include all fields. Reset all previously calls of .only() or """Include all fields. Reset all previously calls of .only() or
.exclude(). :: .exclude(). ::
post = BlogPost.objects.exclude("comments").all_fields() post = BlogPost.objects.exclude('comments').all_fields()
.. versionadded:: 0.5 .. versionadded:: 0.5
""" """
@ -956,7 +952,7 @@ class BaseQuerySet(object):
See https://docs.mongodb.com/manual/reference/method/cursor.comment/#cursor.comment See https://docs.mongodb.com/manual/reference/method/cursor.comment/#cursor.comment
for details. for details.
""" """
return self._chainable_method("comment", text) return self._chainable_method('comment', text)
def explain(self, format=False): def explain(self, format=False):
"""Return an explain plan record for the """Return an explain plan record for the
@ -979,7 +975,7 @@ class BaseQuerySet(object):
.. deprecated:: Ignored with PyMongo 3+ .. deprecated:: Ignored with PyMongo 3+
""" """
if IS_PYMONGO_3: if IS_PYMONGO_3:
msg = "snapshot is deprecated as it has no impact when using PyMongo 3+." msg = 'snapshot is deprecated as it has no impact when using PyMongo 3+.'
warnings.warn(msg, DeprecationWarning) warnings.warn(msg, DeprecationWarning)
queryset = self.clone() queryset = self.clone()
queryset._snapshot = enabled queryset._snapshot = enabled
@ -1005,7 +1001,7 @@ class BaseQuerySet(object):
.. deprecated:: Ignored with PyMongo 3+ .. deprecated:: Ignored with PyMongo 3+
""" """
if IS_PYMONGO_3: if IS_PYMONGO_3:
msg = "slave_okay is deprecated as it has no impact when using PyMongo 3+." msg = 'slave_okay is deprecated as it has no impact when using PyMongo 3+.'
warnings.warn(msg, DeprecationWarning) warnings.warn(msg, DeprecationWarning)
queryset = self.clone() queryset = self.clone()
queryset._slave_okay = enabled queryset._slave_okay = enabled
@ -1067,7 +1063,7 @@ class BaseQuerySet(object):
:param ms: the number of milliseconds before killing the query on the server :param ms: the number of milliseconds before killing the query on the server
""" """
return self._chainable_method("max_time_ms", ms) return self._chainable_method('max_time_ms', ms)
# JSON Helpers # JSON Helpers
@ -1150,8 +1146,8 @@ class BaseQuerySet(object):
MapReduceDocument = _import_class('MapReduceDocument') MapReduceDocument = _import_class('MapReduceDocument')
if not hasattr(self._collection, "map_reduce"): if not hasattr(self._collection, 'map_reduce'):
raise NotImplementedError("Requires MongoDB >= 1.7.1") raise NotImplementedError('Requires MongoDB >= 1.7.1')
map_f_scope = {} map_f_scope = {}
if isinstance(map_f, Code): if isinstance(map_f, Code):
@ -1201,7 +1197,7 @@ class BaseQuerySet(object):
break break
else: else:
raise OperationError("actionData not specified for output") raise OperationError('actionData not specified for output')
db_alias = output.get('db_alias') db_alias = output.get('db_alias')
remaing_args = ['db', 'sharded', 'nonAtomic'] remaing_args = ['db', 'sharded', 'nonAtomic']
@ -1431,7 +1427,7 @@ class BaseQuerySet(object):
# snapshot is not handled at all by PyMongo 3+ # snapshot is not handled at all by PyMongo 3+
# TODO: evaluate similar possibilities using modifiers # TODO: evaluate similar possibilities using modifiers
if self._snapshot: if self._snapshot:
msg = "The snapshot option is not anymore available with PyMongo 3+" msg = 'The snapshot option is not anymore available with PyMongo 3+'
warnings.warn(msg, DeprecationWarning) warnings.warn(msg, DeprecationWarning)
cursor_args = { cursor_args = {
'no_cursor_timeout': not self._timeout 'no_cursor_timeout': not self._timeout
@ -1443,7 +1439,7 @@ class BaseQuerySet(object):
if fields_name not in cursor_args: if fields_name not in cursor_args:
cursor_args[fields_name] = {} cursor_args[fields_name] = {}
cursor_args[fields_name]['_text_score'] = {'$meta': "textScore"} cursor_args[fields_name]['_text_score'] = {'$meta': 'textScore'}
return cursor_args return cursor_args
@ -1498,8 +1494,8 @@ class BaseQuerySet(object):
if self._mongo_query is None: if self._mongo_query is None:
self._mongo_query = self._query_obj.to_query(self._document) self._mongo_query = self._query_obj.to_query(self._document)
if self._class_check and self._initial_query: if self._class_check and self._initial_query:
if "_cls" in self._mongo_query: if '_cls' in self._mongo_query:
self._mongo_query = {"$and": [self._initial_query, self._mongo_query]} self._mongo_query = {'$and': [self._initial_query, self._mongo_query]}
else: else:
self._mongo_query.update(self._initial_query) self._mongo_query.update(self._initial_query)
return self._mongo_query return self._mongo_query
@ -1511,8 +1507,7 @@ class BaseQuerySet(object):
return self.__dereference return self.__dereference
def no_dereference(self): def no_dereference(self):
"""Turn off any dereferencing for the results of this queryset. """Turn off any dereferencing for the results of this queryset."""
"""
queryset = self.clone() queryset = self.clone()
queryset._auto_dereference = False queryset._auto_dereference = False
return queryset return queryset
@ -1641,14 +1636,14 @@ class BaseQuerySet(object):
for x in document._subclasses][1:] for x in document._subclasses][1:]
for field in fields: for field in fields:
try: try:
field = ".".join(f.db_field for f in field = '.'.join(f.db_field for f in
document._lookup_field(field.split('.'))) document._lookup_field(field.split('.')))
ret.append(field) ret.append(field)
except LookUpError as err: except LookUpError as err:
found = False found = False
for subdoc in subclasses: for subdoc in subclasses:
try: try:
subfield = ".".join(f.db_field for f in subfield = '.'.join(f.db_field for f in
subdoc._lookup_field(field.split('.'))) subdoc._lookup_field(field.split('.')))
ret.append(subfield) ret.append(subfield)
found = True found = True
@ -1661,15 +1656,14 @@ class BaseQuerySet(object):
return ret return ret
def _get_order_by(self, keys): def _get_order_by(self, keys):
"""Creates a list of order by fields """Creates a list of order by fields"""
"""
key_list = [] key_list = []
for key in keys: for key in keys:
if not key: if not key:
continue continue
if key == '$text_score': if key == '$text_score':
key_list.append(('_text_score', {'$meta': "textScore"})) key_list.append(('_text_score', {'$meta': 'textScore'}))
continue continue
direction = pymongo.ASCENDING direction = pymongo.ASCENDING
@ -1775,7 +1769,7 @@ class BaseQuerySet(object):
field_name = match.group(1).split('.') field_name = match.group(1).split('.')
fields = self._document._lookup_field(field_name) fields = self._document._lookup_field(field_name)
# Substitute the correct name for the field into the javascript # Substitute the correct name for the field into the javascript
return ".".join([f.db_field for f in fields]) return '.'.join([f.db_field for f in fields])
code = re.sub(u'\[\s*~([A-z_][A-z_0-9.]+?)\s*\]', field_sub, code) code = re.sub(u'\[\s*~([A-z_][A-z_0-9.]+?)\s*\]', field_sub, code)
code = re.sub(u'\{\{\s*~([A-z_][A-z_0-9.]+?)\s*\}\}', field_path_sub, code = re.sub(u'\{\{\s*~([A-z_][A-z_0-9.]+?)\s*\}\}', field_path_sub,
@ -1786,21 +1780,21 @@ class BaseQuerySet(object):
queryset = self.clone() queryset = self.clone()
method = getattr(queryset._cursor, method_name) method = getattr(queryset._cursor, method_name)
method(val) method(val)
setattr(queryset, "_" + method_name, val) setattr(queryset, '_' + method_name, val)
return queryset return queryset
# Deprecated # Deprecated
def ensure_index(self, **kwargs): def ensure_index(self, **kwargs):
"""Deprecated use :func:`Document.ensure_index`""" """Deprecated use :func:`Document.ensure_index`"""
msg = ("Doc.objects()._ensure_index() is deprecated. " msg = ('Doc.objects()._ensure_index() is deprecated. '
"Use Doc.ensure_index() instead.") 'Use Doc.ensure_index() instead.')
warnings.warn(msg, DeprecationWarning) warnings.warn(msg, DeprecationWarning)
self._document.__class__.ensure_index(**kwargs) self._document.__class__.ensure_index(**kwargs)
return self return self
def _ensure_indexes(self): def _ensure_indexes(self):
"""Deprecated use :func:`~Document.ensure_indexes`""" """Deprecated use :func:`~Document.ensure_indexes`"""
msg = ("Doc.objects()._ensure_indexes() is deprecated. " msg = ('Doc.objects()._ensure_indexes() is deprecated. '
"Use Doc.ensure_indexes() instead.") 'Use Doc.ensure_indexes() instead.')
warnings.warn(msg, DeprecationWarning) warnings.warn(msg, DeprecationWarning)
self._document.__class__.ensure_indexes() self._document.__class__.ensure_indexes()

View File

@ -53,15 +53,14 @@ class QuerySet(BaseQuerySet):
return self._len return self._len
def __repr__(self): def __repr__(self):
"""Provides the string representation of the QuerySet """Provide a string representation of the QuerySet"""
"""
if self._iter: if self._iter:
return '.. queryset mid-iteration ..' return '.. queryset mid-iteration ..'
self._populate_cache() self._populate_cache()
data = self._result_cache[:REPR_OUTPUT_SIZE + 1] data = self._result_cache[:REPR_OUTPUT_SIZE + 1]
if len(data) > REPR_OUTPUT_SIZE: if len(data) > REPR_OUTPUT_SIZE:
data[-1] = "...(remaining elements truncated)..." data[-1] = '...(remaining elements truncated)...'
return repr(data) return repr(data)
def _iter_results(self): def _iter_results(self):
@ -142,7 +141,7 @@ class QuerySet(BaseQuerySet):
.. versionadded:: 0.8.3 Convert to non caching queryset .. versionadded:: 0.8.3 Convert to non caching queryset
""" """
if self._result_cache is not None: if self._result_cache is not None:
raise OperationError("QuerySet already cached") raise OperationError('QuerySet already cached')
return self.clone_into(QuerySetNoCache(self._document, self._collection)) return self.clone_into(QuerySetNoCache(self._document, self._collection))
@ -171,7 +170,7 @@ class QuerySetNoCache(BaseQuerySet):
except StopIteration: except StopIteration:
break break
if len(data) > REPR_OUTPUT_SIZE: if len(data) > REPR_OUTPUT_SIZE:
data[-1] = "...(remaining elements truncated)..." data[-1] = '...(remaining elements truncated)...'
self.rewind() self.rewind()
return repr(data) return repr(data)

View File

@ -28,12 +28,11 @@ MATCH_OPERATORS = (COMPARISON_OPERATORS + GEO_OPERATORS +
def query(_doc_cls=None, **kwargs): def query(_doc_cls=None, **kwargs):
"""Transform a query from Django-style format to Mongo format. """Transform a query from Django-style format to Mongo format."""
"""
mongo_query = {} mongo_query = {}
merge_query = defaultdict(list) merge_query = defaultdict(list)
for key, value in sorted(kwargs.items()): for key, value in sorted(kwargs.items()):
if key == "__raw__": if key == '__raw__':
mongo_query.update(value) mongo_query.update(value)
continue continue
@ -46,7 +45,7 @@ def query(_doc_cls=None, **kwargs):
op = parts.pop() op = parts.pop()
# Allow to escape operator-like field name by __ # Allow to escape operator-like field name by __
if len(parts) > 1 and parts[-1] == "": if len(parts) > 1 and parts[-1] == '':
parts.pop() parts.pop()
negate = False negate = False
@ -117,10 +116,10 @@ def query(_doc_cls=None, **kwargs):
value = query(field.field.document_type, **value) value = query(field.field.document_type, **value)
else: else:
value = field.prepare_query_value(op, value) value = field.prepare_query_value(op, value)
value = {"$elemMatch": value} value = {'$elemMatch': value}
elif op in CUSTOM_OPERATORS: elif op in CUSTOM_OPERATORS:
NotImplementedError("Custom method '%s' has not " NotImplementedError('Custom method "%s" has not '
"been implemented" % op) 'been implemented' % op)
elif op not in STRING_OPERATORS: elif op not in STRING_OPERATORS:
value = {'$' + op: value} value = {'$' + op: value}
@ -183,15 +182,16 @@ def query(_doc_cls=None, **kwargs):
def update(_doc_cls=None, **update): def update(_doc_cls=None, **update):
"""Transform an update spec from Django-style format to Mongo format. """Transform an update spec from Django-style format to Mongo
format.
""" """
mongo_update = {} mongo_update = {}
for key, value in update.items(): for key, value in update.items():
if key == "__raw__": if key == '__raw__':
mongo_update.update(value) mongo_update.update(value)
continue continue
parts = key.split('__') parts = key.split('__')
# if there is no operator, default to "set" # if there is no operator, default to 'set'
if len(parts) < 3 and parts[0] not in UPDATE_OPERATORS: if len(parts) < 3 and parts[0] not in UPDATE_OPERATORS:
parts.insert(0, 'set') parts.insert(0, 'set')
# Check for an operator and transform to mongo-style if there is # Check for an operator and transform to mongo-style if there is
@ -210,14 +210,14 @@ def update(_doc_cls=None, **update):
elif op == 'add_to_set': elif op == 'add_to_set':
op = 'addToSet' op = 'addToSet'
elif op == 'set_on_insert': elif op == 'set_on_insert':
op = "setOnInsert" op = 'setOnInsert'
match = None match = None
if parts[-1] in COMPARISON_OPERATORS: if parts[-1] in COMPARISON_OPERATORS:
match = parts.pop() match = parts.pop()
# Allow to escape operator-like field name by __ # Allow to escape operator-like field name by __
if len(parts) > 1 and parts[-1] == "": if len(parts) > 1 and parts[-1] == '':
parts.pop() parts.pop()
if _doc_cls: if _doc_cls:
@ -253,7 +253,7 @@ def update(_doc_cls=None, **update):
else: else:
field = cleaned_fields[-1] field = cleaned_fields[-1]
GeoJsonBaseField = _import_class("GeoJsonBaseField") GeoJsonBaseField = _import_class('GeoJsonBaseField')
if isinstance(field, GeoJsonBaseField): if isinstance(field, GeoJsonBaseField):
value = field.to_mongo(value) value = field.to_mongo(value)
@ -267,7 +267,7 @@ def update(_doc_cls=None, **update):
value = [field.prepare_query_value(op, v) for v in value] value = [field.prepare_query_value(op, v) for v in value]
elif field.required or value is not None: elif field.required or value is not None:
value = field.prepare_query_value(op, value) value = field.prepare_query_value(op, value)
elif op == "unset": elif op == 'unset':
value = 1 value = 1
if match: if match:
@ -277,16 +277,16 @@ def update(_doc_cls=None, **update):
key = '.'.join(parts) key = '.'.join(parts)
if not op: if not op:
raise InvalidQueryError("Updates must supply an operation " raise InvalidQueryError('Updates must supply an operation '
"eg: set__FIELD=value") 'eg: set__FIELD=value')
if 'pull' in op and '.' in key: if 'pull' in op and '.' in key:
# Dot operators don't work on pull operations # Dot operators don't work on pull operations
# unless they point to a list field # unless they point to a list field
# Otherwise it uses nested dict syntax # Otherwise it uses nested dict syntax
if op == 'pullAll': if op == 'pullAll':
raise InvalidQueryError("pullAll operations only support " raise InvalidQueryError('pullAll operations only support '
"a single field depth") 'a single field depth')
# Look for the last list field and use dot notation until there # Look for the last list field and use dot notation until there
field_classes = [c.__class__ for c in cleaned_fields] field_classes = [c.__class__ for c in cleaned_fields]
@ -297,7 +297,7 @@ def update(_doc_cls=None, **update):
# Then process as normal # Then process as normal
last_listField = len( last_listField = len(
cleaned_fields) - field_classes.index(ListField) cleaned_fields) - field_classes.index(ListField)
key = ".".join(parts[:last_listField]) key = '.'.join(parts[:last_listField])
parts = parts[last_listField:] parts = parts[last_listField:]
parts.insert(0, key) parts.insert(0, key)
@ -305,7 +305,7 @@ def update(_doc_cls=None, **update):
for key in parts: for key in parts:
value = {key: value} value = {key: value}
elif op == 'addToSet' and isinstance(value, list): elif op == 'addToSet' and isinstance(value, list):
value = {key: {"$each": value}} value = {key: {'$each': value}}
else: else:
value = {key: value} value = {key: value}
key = '$' + op key = '$' + op
@ -319,78 +319,82 @@ def update(_doc_cls=None, **update):
def _geo_operator(field, op, value): def _geo_operator(field, op, value):
"""Helper to return the query for a given geo query""" """Helper to return the query for a given geo query."""
if op == "max_distance": if op == 'max_distance':
value = {'$maxDistance': value} value = {'$maxDistance': value}
elif op == "min_distance": elif op == 'min_distance':
value = {'$minDistance': value} value = {'$minDistance': value}
elif field._geo_index == pymongo.GEO2D: elif field._geo_index == pymongo.GEO2D:
if op == "within_distance": if op == 'within_distance':
value = {'$within': {'$center': value}} value = {'$within': {'$center': value}}
elif op == "within_spherical_distance": elif op == 'within_spherical_distance':
value = {'$within': {'$centerSphere': value}} value = {'$within': {'$centerSphere': value}}
elif op == "within_polygon": elif op == 'within_polygon':
value = {'$within': {'$polygon': value}} value = {'$within': {'$polygon': value}}
elif op == "near": elif op == 'near':
value = {'$near': value} value = {'$near': value}
elif op == "near_sphere": elif op == 'near_sphere':
value = {'$nearSphere': value} value = {'$nearSphere': value}
elif op == 'within_box': elif op == 'within_box':
value = {'$within': {'$box': value}} value = {'$within': {'$box': value}}
else: else:
raise NotImplementedError("Geo method '%s' has not " raise NotImplementedError('Geo method "%s" has not been '
"been implemented for a GeoPointField" % op) 'implemented for a GeoPointField' % op)
else: else:
if op == "geo_within": if op == 'geo_within':
value = {"$geoWithin": _infer_geometry(value)} value = {'$geoWithin': _infer_geometry(value)}
elif op == "geo_within_box": elif op == 'geo_within_box':
value = {"$geoWithin": {"$box": value}} value = {'$geoWithin': {'$box': value}}
elif op == "geo_within_polygon": elif op == 'geo_within_polygon':
value = {"$geoWithin": {"$polygon": value}} value = {'$geoWithin': {'$polygon': value}}
elif op == "geo_within_center": elif op == 'geo_within_center':
value = {"$geoWithin": {"$center": value}} value = {'$geoWithin': {'$center': value}}
elif op == "geo_within_sphere": elif op == 'geo_within_sphere':
value = {"$geoWithin": {"$centerSphere": value}} value = {'$geoWithin': {'$centerSphere': value}}
elif op == "geo_intersects": elif op == 'geo_intersects':
value = {"$geoIntersects": _infer_geometry(value)} value = {'$geoIntersects': _infer_geometry(value)}
elif op == "near": elif op == 'near':
value = {'$near': _infer_geometry(value)} value = {'$near': _infer_geometry(value)}
else: else:
raise NotImplementedError("Geo method '%s' has not " raise NotImplementedError(
"been implemented for a %s " % (op, field._name)) 'Geo method "%s" has not been implemented for a %s '
% (op, field._name)
)
return value return value
def _infer_geometry(value): def _infer_geometry(value):
"""Helper method that tries to infer the $geometry shape for a given value""" """Helper method that tries to infer the $geometry shape for a
given value.
"""
if isinstance(value, dict): if isinstance(value, dict):
if "$geometry" in value: if '$geometry' in value:
return value return value
elif 'coordinates' in value and 'type' in value: elif 'coordinates' in value and 'type' in value:
return {"$geometry": value} return {'$geometry': value}
raise InvalidQueryError("Invalid $geometry dictionary should have " raise InvalidQueryError('Invalid $geometry dictionary should have '
"type and coordinates keys") 'type and coordinates keys')
elif isinstance(value, (list, set)): elif isinstance(value, (list, set)):
# TODO: shouldn't we test value[0][0][0][0] to see if it is MultiPolygon? # TODO: shouldn't we test value[0][0][0][0] to see if it is MultiPolygon?
# TODO: should both TypeError and IndexError be alike interpreted? # TODO: should both TypeError and IndexError be alike interpreted?
try: try:
value[0][0][0] value[0][0][0]
return {"$geometry": {"type": "Polygon", "coordinates": value}} return {'$geometry': {'type': 'Polygon', 'coordinates': value}}
except (TypeError, IndexError): except (TypeError, IndexError):
pass pass
try: try:
value[0][0] value[0][0]
return {"$geometry": {"type": "LineString", "coordinates": value}} return {'$geometry': {'type': 'LineString', 'coordinates': value}}
except (TypeError, IndexError): except (TypeError, IndexError):
pass pass
try: try:
value[0] value[0]
return {"$geometry": {"type": "Point", "coordinates": value}} return {'$geometry': {'type': 'Point', 'coordinates': value}}
except (TypeError, IndexError): except (TypeError, IndexError):
pass pass
raise InvalidQueryError("Invalid $geometry data. Can be either a dictionary " raise InvalidQueryError('Invalid $geometry data. Can be either a '
"or (nested) lists of coordinate(s)") 'dictionary or (nested) lists of coordinate(s)')

View File

@ -69,9 +69,9 @@ class QueryCompilerVisitor(QNodeVisitor):
self.document = document self.document = document
def visit_combination(self, combination): def visit_combination(self, combination):
operator = "$and" operator = '$and'
if combination.operation == combination.OR: if combination.operation == combination.OR:
operator = "$or" operator = '$or'
return {operator: combination.children} return {operator: combination.children}
def visit_query(self, query): def visit_query(self, query):
@ -79,8 +79,7 @@ class QueryCompilerVisitor(QNodeVisitor):
class QNode(object): class QNode(object):
"""Base class for nodes in query trees. """Base class for nodes in query trees."""
"""
AND = 0 AND = 0
OR = 1 OR = 1
@ -94,7 +93,8 @@ class QNode(object):
raise NotImplementedError raise NotImplementedError
def _combine(self, other, operation): def _combine(self, other, operation):
"""Combine this node with another node into a QCombination object. """Combine this node with another node into a QCombination
object.
""" """
if getattr(other, 'empty', True): if getattr(other, 'empty', True):
return self return self
@ -116,8 +116,8 @@ class QNode(object):
class QCombination(QNode): class QCombination(QNode):
"""Represents the combination of several conditions by a given logical """Represents the combination of several conditions by a given
operator. logical operator.
""" """
def __init__(self, operation, children): def __init__(self, operation, children):