Prefer ' over " + minor docstring tweaks

This commit is contained in:
Stefan Wojcik 2016-12-08 22:44:02 -05:00
parent 76219901db
commit 5b7b65a750
13 changed files with 387 additions and 417 deletions

View File

@ -10,7 +10,7 @@ __all__ = ('BaseDict', 'BaseList', 'EmbeddedDocumentList')
class BaseDict(dict):
"""A special dict so we can watch any changes"""
"""A special dict so we can watch any changes."""
_dereferenced = False
_instance = None
@ -95,8 +95,7 @@ class BaseDict(dict):
class BaseList(list):
"""A special list so we can watch any changes
"""
"""A special list so we can watch any changes."""
_dereferenced = False
_instance = None
@ -213,7 +212,7 @@ class EmbeddedDocumentList(BaseList):
@classmethod
def __match_all(cls, embedded_doc, kwargs):
"""Return True if a given embedded doc matches all the filter
kwargs. If it doesn't return False
kwargs. If it doesn't return False.
"""
for key, expected_value in kwargs.items():
doc_val = getattr(embedded_doc, key)
@ -292,18 +291,18 @@ class EmbeddedDocumentList(BaseList):
values = self.__only_matches(self, kwargs)
if len(values) == 0:
raise DoesNotExist(
"%s matching query does not exist." % self._name
'%s matching query does not exist.' % self._name
)
elif len(values) > 1:
raise MultipleObjectsReturned(
"%d items returned, instead of 1" % len(values)
'%d items returned, instead of 1' % len(values)
)
return values[0]
def first(self):
"""
Returns the first embedded document in the list, or ``None`` if empty.
"""Return the first embedded document in the list, or ``None``
if empty.
"""
if len(self) > 0:
return self[0]
@ -445,7 +444,7 @@ class StrictDict(object):
__slots__ = allowed_keys_tuple
def __repr__(self):
return "{%s}" % ', '.join('"{0!s}": {1!r}'.format(k, v) for k, v in self.items())
return '{%s}' % ', '.join('"{0!s}": {1!r}'.format(k, v) for k, v in self.items())
cls._classes[allowed_keys] = SpecificStrictDict
return cls._classes[allowed_keys]

View File

@ -54,15 +54,15 @@ class BaseDocument(object):
name = next(field)
if name in values:
raise TypeError(
"Multiple values for keyword argument '" + name + "'")
'Multiple values for keyword argument "%s"' % name)
values[name] = value
__auto_convert = values.pop("__auto_convert", True)
__auto_convert = values.pop('__auto_convert', True)
# 399: set default values only to fields loaded from DB
__only_fields = set(values.pop("__only_fields", values))
__only_fields = set(values.pop('__only_fields', values))
_created = values.pop("_created", True)
_created = values.pop('_created', True)
signals.pre_init.send(self.__class__, document=self, values=values)
@ -73,7 +73,7 @@ class BaseDocument(object):
self._fields.keys() + ['id', 'pk', '_cls', '_text_score'])
if _undefined_fields:
msg = (
"The fields '{0}' do not exist on the document '{1}'"
'The fields "{0}" do not exist on the document "{1}"'
).format(_undefined_fields, self._class_name)
raise FieldDoesNotExist(msg)
@ -92,7 +92,7 @@ class BaseDocument(object):
value = getattr(self, key, None)
setattr(self, key, value)
if "_cls" not in values:
if '_cls' not in values:
self._cls = self._class_name
# Set passed values after initialisation
@ -147,7 +147,7 @@ class BaseDocument(object):
if self._dynamic and not self._dynamic_lock:
if not hasattr(self, name) and not name.startswith('_'):
DynamicField = _import_class("DynamicField")
DynamicField = _import_class('DynamicField')
field = DynamicField(db_field=name)
field.name = name
self._dynamic_fields[name] = field
@ -172,7 +172,7 @@ class BaseDocument(object):
name in self._meta.get('shard_key', tuple()) and
self._data.get(name) != value
):
msg = "Shard Keys are immutable. Tried to update %s" % name
msg = 'Shard Keys are immutable. Tried to update %s' % name
raise OperationError(msg)
try:
@ -196,8 +196,8 @@ class BaseDocument(object):
return data
def __setstate__(self, data):
if isinstance(data["_data"], SON):
data["_data"] = self.__class__._from_son(data["_data"])._data
if isinstance(data['_data'], SON):
data['_data'] = self.__class__._from_son(data['_data'])._data
for k in ('_changed_fields', '_initialised', '_created', '_data',
'_dynamic_fields'):
if k in data:
@ -211,7 +211,7 @@ class BaseDocument(object):
dynamic_fields = data.get('_dynamic_fields') or SON()
for k in dynamic_fields.keys():
setattr(self, k, data["_data"].get(k))
setattr(self, k, data['_data'].get(k))
def __iter__(self):
return iter(self._fields_ordered)
@ -373,9 +373,9 @@ class BaseDocument(object):
fields = [(self._fields.get(name, self._dynamic_fields.get(name)),
self._data.get(name)) for name in self._fields_ordered]
EmbeddedDocumentField = _import_class("EmbeddedDocumentField")
EmbeddedDocumentField = _import_class('EmbeddedDocumentField')
GenericEmbeddedDocumentField = _import_class(
"GenericEmbeddedDocumentField")
'GenericEmbeddedDocumentField')
for field, value in fields:
if value is not None:
@ -394,12 +394,12 @@ class BaseDocument(object):
field_name=field.name)
if errors:
pk = "None"
pk = 'None'
if hasattr(self, 'pk'):
pk = self.pk
elif self._instance and hasattr(self._instance, 'pk'):
pk = self._instance.pk
message = "ValidationError (%s:%s) " % (self._class_name, pk)
message = 'ValidationError (%s:%s) ' % (self._class_name, pk)
raise ValidationError(message, errors=errors)
def to_json(self, *args, **kwargs):
@ -455,8 +455,7 @@ class BaseDocument(object):
return value
def _mark_as_changed(self, key):
"""Marks a key as explicitly changed by the user
"""
"""Mark a key as explicitly changed by the user."""
if not key:
return
@ -489,7 +488,7 @@ class BaseDocument(object):
"""Using get_changed_fields iterate and remove any fields that are
marked as changed"""
for changed in self._get_changed_fields():
parts = changed.split(".")
parts = changed.split('.')
data = self
for part in parts:
if isinstance(data, list):
@ -501,8 +500,8 @@ class BaseDocument(object):
data = data.get(part, None)
else:
data = getattr(data, part, None)
if hasattr(data, "_changed_fields"):
if hasattr(data, "_is_document") and data._is_document:
if hasattr(data, '_changed_fields'):
if hasattr(data, '_is_document') and data._is_document:
continue
data._changed_fields = []
self._changed_fields = []
@ -516,26 +515,26 @@ class BaseDocument(object):
iterator = data.iteritems()
for index, value in iterator:
list_key = "%s%s." % (key, index)
list_key = '%s%s.' % (key, index)
# don't check anything lower if this key is already marked
# as changed.
if list_key[:-1] in changed_fields:
continue
if hasattr(value, '_get_changed_fields'):
changed = value._get_changed_fields(inspected)
changed_fields += ["%s%s" % (list_key, k)
changed_fields += ['%s%s' % (list_key, k)
for k in changed if k]
elif isinstance(value, (list, tuple, dict)):
self._nestable_types_changed_fields(
changed_fields, list_key, value, inspected)
def _get_changed_fields(self, inspected=None):
"""Returns a list of all fields that have explicitly been changed.
"""Return a list of all fields that have explicitly been changed.
"""
EmbeddedDocument = _import_class("EmbeddedDocument")
DynamicEmbeddedDocument = _import_class("DynamicEmbeddedDocument")
ReferenceField = _import_class("ReferenceField")
SortedListField = _import_class("SortedListField")
EmbeddedDocument = _import_class('EmbeddedDocument')
DynamicEmbeddedDocument = _import_class('DynamicEmbeddedDocument')
ReferenceField = _import_class('ReferenceField')
SortedListField = _import_class('SortedListField')
changed_fields = []
changed_fields += getattr(self, '_changed_fields', [])
@ -563,7 +562,7 @@ class BaseDocument(object):
):
# Find all embedded fields that have been changed
changed = data._get_changed_fields(inspected)
changed_fields += ["%s%s" % (key, k) for k in changed if k]
changed_fields += ['%s%s' % (key, k) for k in changed if k]
elif (isinstance(data, (list, tuple, dict)) and
db_field_name not in changed_fields):
if (hasattr(field, 'field') and
@ -667,13 +666,15 @@ class BaseDocument(object):
@classmethod
def _get_collection_name(cls):
"""Returns the collection name for this class. None for abstract class
"""Return the collection name for this class. None for abstract
class.
"""
return cls._meta.get('collection', None)
@classmethod
def _from_son(cls, son, _auto_dereference=True, only_fields=None, created=False):
"""Create an instance of a Document (subclass) from a PyMongo SON.
"""Create an instance of a Document (subclass) from a PyMongo
SON.
"""
if not only_fields:
only_fields = []
@ -681,7 +682,7 @@ class BaseDocument(object):
# get the class name from the document, falling back to the given
# class if unavailable
class_name = son.get('_cls', cls._class_name)
data = dict(("%s" % key, value) for key, value in son.iteritems())
data = dict(('%s' % key, value) for key, value in son.iteritems())
# Return correct subclass for document type
if class_name != cls._class_name:
@ -707,9 +708,9 @@ class BaseDocument(object):
errors_dict[field_name] = e
if errors_dict:
errors = "\n".join(["%s - %s" % (k, v)
errors = '\n'.join(['%s - %s' % (k, v)
for k, v in errors_dict.items()])
msg = ("Invalid data to create a `%s` instance.\n%s"
msg = ('Invalid data to create a `%s` instance.\n%s'
% (cls._class_name, errors))
raise InvalidDocumentError(msg)
@ -782,7 +783,7 @@ class BaseDocument(object):
# 733: don't include cls if index_cls is False unless there is an explicit cls with the index
include_cls = include_cls and (spec.get('cls', False) or cls._meta.get('index_cls', True))
if "cls" in spec:
if 'cls' in spec:
spec.pop('cls')
for key in spec['fields']:
# If inherited spec continue
@ -797,19 +798,19 @@ class BaseDocument(object):
# GEOHAYSTACK from )
# GEO2D from *
direction = pymongo.ASCENDING
if key.startswith("-"):
if key.startswith('-'):
direction = pymongo.DESCENDING
elif key.startswith("$"):
elif key.startswith('$'):
direction = pymongo.TEXT
elif key.startswith("#"):
elif key.startswith('#'):
direction = pymongo.HASHED
elif key.startswith("("):
elif key.startswith('('):
direction = pymongo.GEOSPHERE
elif key.startswith(")"):
elif key.startswith(')'):
direction = pymongo.GEOHAYSTACK
elif key.startswith("*"):
elif key.startswith('*'):
direction = pymongo.GEO2D
if key.startswith(("+", "-", "*", "$", "#", "(", ")")):
if key.startswith(('+', '-', '*', '$', '#', '(', ')')):
key = key[1:]
# Use real field name, do it manually because we need field
@ -822,7 +823,7 @@ class BaseDocument(object):
parts = []
for field in fields:
try:
if field != "_id":
if field != '_id':
field = field.db_field
except AttributeError:
pass
@ -841,7 +842,7 @@ class BaseDocument(object):
return spec
@classmethod
def _unique_with_indexes(cls, namespace=""):
def _unique_with_indexes(cls, namespace=''):
"""Find unique indexes in the document schema and return them."""
unique_indexes = []
for field_name, field in cls._fields.items():
@ -875,7 +876,7 @@ class BaseDocument(object):
# Add the new index to the list
fields = [
("%s%s" % (namespace, f), pymongo.ASCENDING)
('%s%s' % (namespace, f), pymongo.ASCENDING)
for f in unique_fields
]
index = {'fields': fields, 'unique': True, 'sparse': sparse}
@ -887,7 +888,7 @@ class BaseDocument(object):
# Grab any embedded document field unique indexes
if (field.__class__.__name__ == 'EmbeddedDocumentField' and
field.document_type != cls):
field_namespace = "%s." % field_name
field_namespace = '%s.' % field_name
doc_cls = field.document_type
unique_indexes += doc_cls._unique_with_indexes(field_namespace)
@ -921,7 +922,7 @@ class BaseDocument(object):
elif field._geo_index:
field_name = field.db_field
if parent_field:
field_name = "%s.%s" % (parent_field, field_name)
field_name = '%s.%s' % (parent_field, field_name)
geo_indices.append({
'fields': [(field_name, field._geo_index)]
})
@ -965,7 +966,7 @@ class BaseDocument(object):
# TODO this method is WAY too complicated. Simplify it.
# TODO don't think returning a string for embedded non-existent fields is desired
ListField = _import_class("ListField")
ListField = _import_class('ListField')
DynamicField = _import_class('DynamicField')
if not isinstance(parts, (list, tuple)):

View File

@ -69,7 +69,7 @@ class BaseField(object):
self.db_field = (db_field or name) if not primary_key else '_id'
if name:
msg = "Fields' 'name' attribute deprecated in favour of 'db_field'"
msg = 'Field\'s "name" attribute deprecated in favour of "db_field"'
warnings.warn(msg, DeprecationWarning)
self.required = required or primary_key
self.default = default
@ -85,7 +85,7 @@ class BaseField(object):
# Detect and report conflicts between metadata and base properties.
conflicts = set(dir(self)) & set(kwargs)
if conflicts:
raise TypeError("%s already has attribute(s): %s" % (
raise TypeError('%s already has attribute(s): %s' % (
self.__class__.__name__, ', '.join(conflicts)))
# Assign metadata to the instance
@ -143,25 +143,21 @@ class BaseField(object):
v._instance = weakref.proxy(instance)
instance._data[self.name] = value
def error(self, message="", errors=None, field_name=None):
"""Raises a ValidationError.
"""
def error(self, message='', errors=None, field_name=None):
"""Raise a ValidationError."""
field_name = field_name if field_name else self.name
raise ValidationError(message, errors=errors, field_name=field_name)
def to_python(self, value):
"""Convert a MongoDB-compatible type to a Python type.
"""
"""Convert a MongoDB-compatible type to a Python type."""
return value
def to_mongo(self, value):
"""Convert a Python type to a MongoDB-compatible type.
"""
"""Convert a Python type to a MongoDB-compatible type."""
return self.to_python(value)
def _to_mongo_safe_call(self, value, use_db_field=True, fields=None):
"""A helper method to call to_mongo with proper inputs
"""
"""Helper method to call to_mongo with proper inputs."""
f_inputs = self.to_mongo.__code__.co_varnames
ex_vars = {}
if 'fields' in f_inputs:
@ -173,15 +169,13 @@ class BaseField(object):
return self.to_mongo(value, **ex_vars)
def prepare_query_value(self, op, value):
"""Prepare a value that is being used in a query for PyMongo.
"""
"""Prepare a value that is being used in a query for PyMongo."""
if op in UPDATE_OPERATORS:
self.validate(value)
return value
def validate(self, value, clean=True):
"""Perform validation on a value.
"""
"""Perform validation on a value."""
pass
def _validate_choices(self, value):
@ -245,8 +239,7 @@ class ComplexBaseField(BaseField):
field = None
def __get__(self, instance, owner):
"""Descriptor to automatically dereference references.
"""
"""Descriptor to automatically dereference references."""
if instance is None:
# Document class being used rather than a document object
return self
@ -258,7 +251,7 @@ class ComplexBaseField(BaseField):
(self.field is None or isinstance(self.field,
(GenericReferenceField, ReferenceField))))
_dereference = _import_class("DeReference")()
_dereference = _import_class('DeReference')()
self._auto_dereference = instance._fields[self.name]._auto_dereference
if instance._initialised and dereference and instance._data.get(self.name):
@ -293,8 +286,7 @@ class ComplexBaseField(BaseField):
return value
def to_python(self, value):
"""Convert a MongoDB-compatible type to a Python type.
"""
"""Convert a MongoDB-compatible type to a Python type."""
if isinstance(value, six.string_types):
return value
@ -335,11 +327,10 @@ class ComplexBaseField(BaseField):
return value_dict
def to_mongo(self, value, use_db_field=True, fields=None):
"""Convert a Python type to a MongoDB-compatible type.
"""
Document = _import_class("Document")
EmbeddedDocument = _import_class("EmbeddedDocument")
GenericReferenceField = _import_class("GenericReferenceField")
"""Convert a Python type to a MongoDB-compatible type."""
Document = _import_class('Document')
EmbeddedDocument = _import_class('EmbeddedDocument')
GenericReferenceField = _import_class('GenericReferenceField')
if isinstance(value, six.string_types):
return value
@ -400,8 +391,7 @@ class ComplexBaseField(BaseField):
return value_dict
def validate(self, value):
"""If field is provided ensure the value is valid.
"""
"""If field is provided ensure the value is valid."""
errors = {}
if self.field:
if hasattr(value, 'iteritems') or hasattr(value, 'items'):
@ -439,8 +429,7 @@ class ComplexBaseField(BaseField):
class ObjectIdField(BaseField):
"""A field wrapper around MongoDB's ObjectIds.
"""
"""A field wrapper around MongoDB's ObjectIds."""
def to_python(self, value):
try:
@ -476,21 +465,20 @@ class GeoJsonBaseField(BaseField):
"""
_geo_index = pymongo.GEOSPHERE
_type = "GeoBase"
_type = 'GeoBase'
def __init__(self, auto_index=True, *args, **kwargs):
"""
:param bool auto_index: Automatically create a "2dsphere" index.\
:param bool auto_index: Automatically create a '2dsphere' index.\
Defaults to `True`.
"""
self._name = "%sField" % self._type
self._name = '%sField' % self._type
if not auto_index:
self._geo_index = False
super(GeoJsonBaseField, self).__init__(*args, **kwargs)
def validate(self, value):
"""Validate the GeoJson object based on its type
"""
"""Validate the GeoJson object based on its type."""
if isinstance(value, dict):
if set(value.keys()) == set(['type', 'coordinates']):
if value['type'] != self._type:
@ -505,7 +493,7 @@ class GeoJsonBaseField(BaseField):
self.error('%s can only accept lists of [x, y]' % self._name)
return
validate = getattr(self, "_validate_%s" % self._type.lower())
validate = getattr(self, '_validate_%s' % self._type.lower())
error = validate(value)
if error:
self.error(error)
@ -518,7 +506,7 @@ class GeoJsonBaseField(BaseField):
try:
value[0][0][0]
except (TypeError, IndexError):
return "Invalid Polygon must contain at least one valid linestring"
return 'Invalid Polygon must contain at least one valid linestring'
errors = []
for val in value:
@ -529,12 +517,12 @@ class GeoJsonBaseField(BaseField):
errors.append(error)
if errors:
if top_level:
return "Invalid Polygon:\n%s" % ", ".join(errors)
return 'Invalid Polygon:\n%s' % ', '.join(errors)
else:
return "%s" % ", ".join(errors)
return '%s' % ', '.join(errors)
def _validate_linestring(self, value, top_level=True):
"""Validates a linestring"""
"""Validate a linestring."""
if not isinstance(value, (list, tuple)):
return 'LineStrings must contain list of coordinate pairs'
@ -542,7 +530,7 @@ class GeoJsonBaseField(BaseField):
try:
value[0][0]
except (TypeError, IndexError):
return "Invalid LineString must contain at least one valid point"
return 'Invalid LineString must contain at least one valid point'
errors = []
for val in value:
@ -551,19 +539,19 @@ class GeoJsonBaseField(BaseField):
errors.append(error)
if errors:
if top_level:
return "Invalid LineString:\n%s" % ", ".join(errors)
return 'Invalid LineString:\n%s' % ', '.join(errors)
else:
return "%s" % ", ".join(errors)
return '%s' % ', '.join(errors)
def _validate_point(self, value):
"""Validate each set of coords"""
if not isinstance(value, (list, tuple)):
return 'Points must be a list of coordinate pairs'
elif not len(value) == 2:
return "Value (%s) must be a two-dimensional point" % repr(value)
return 'Value (%s) must be a two-dimensional point' % repr(value)
elif (not isinstance(value[0], (float, int)) or
not isinstance(value[1], (float, int))):
return "Both values (%s) in point must be float or int" % repr(value)
return 'Both values (%s) in point must be float or int' % repr(value)
def _validate_multipoint(self, value):
if not isinstance(value, (list, tuple)):
@ -573,7 +561,7 @@ class GeoJsonBaseField(BaseField):
try:
value[0][0]
except (TypeError, IndexError):
return "Invalid MultiPoint must contain at least one valid point"
return 'Invalid MultiPoint must contain at least one valid point'
errors = []
for point in value:
@ -582,7 +570,7 @@ class GeoJsonBaseField(BaseField):
errors.append(error)
if errors:
return "%s" % ", ".join(errors)
return '%s' % ', '.join(errors)
def _validate_multilinestring(self, value, top_level=True):
if not isinstance(value, (list, tuple)):
@ -592,7 +580,7 @@ class GeoJsonBaseField(BaseField):
try:
value[0][0][0]
except (TypeError, IndexError):
return "Invalid MultiLineString must contain at least one valid linestring"
return 'Invalid MultiLineString must contain at least one valid linestring'
errors = []
for linestring in value:
@ -602,9 +590,9 @@ class GeoJsonBaseField(BaseField):
if errors:
if top_level:
return "Invalid MultiLineString:\n%s" % ", ".join(errors)
return 'Invalid MultiLineString:\n%s' % ', '.join(errors)
else:
return "%s" % ", ".join(errors)
return '%s' % ', '.join(errors)
def _validate_multipolygon(self, value):
if not isinstance(value, (list, tuple)):
@ -614,7 +602,7 @@ class GeoJsonBaseField(BaseField):
try:
value[0][0][0][0]
except (TypeError, IndexError):
return "Invalid MultiPolygon must contain at least one valid Polygon"
return 'Invalid MultiPolygon must contain at least one valid Polygon'
errors = []
for polygon in value:
@ -623,9 +611,9 @@ class GeoJsonBaseField(BaseField):
errors.append(error)
if errors:
return "Invalid MultiPolygon:\n%s" % ", ".join(errors)
return 'Invalid MultiPolygon:\n%s' % ', '.join(errors)
def to_mongo(self, value):
if isinstance(value, dict):
return value
return SON([("type", self._type), ("coordinates", value)])
return SON([('type', self._type), ('coordinates', value)])

View File

@ -88,8 +88,8 @@ class DocumentMetaclass(type):
# Ensure no duplicate db_fields
duplicate_db_fields = [k for k, v in field_names.items() if v > 1]
if duplicate_db_fields:
msg = ("Multiple db_fields defined for: %s " %
", ".join(duplicate_db_fields))
msg = ('Multiple db_fields defined for: %s ' %
', '.join(duplicate_db_fields))
raise InvalidDocumentError(msg)
# Set _fields and db_field maps
@ -178,11 +178,11 @@ class DocumentMetaclass(type):
if isinstance(f, CachedReferenceField):
if issubclass(new_class, EmbeddedDocument):
raise InvalidDocumentError(
"CachedReferenceFields is not allowed in EmbeddedDocuments")
raise InvalidDocumentError('CachedReferenceFields is not '
'allowed in EmbeddedDocuments')
if not f.document_type:
raise InvalidDocumentError(
"Document is not available to sync")
'Document is not available to sync')
if f.auto_sync:
f.start_listener()
@ -194,8 +194,8 @@ class DocumentMetaclass(type):
'reverse_delete_rule',
DO_NOTHING)
if isinstance(f, DictField) and delete_rule != DO_NOTHING:
msg = ("Reverse delete rules are not supported "
"for %s (field: %s)" %
msg = ('Reverse delete rules are not supported '
'for %s (field: %s)' %
(field.__class__.__name__, field.name))
raise InvalidDocumentError(msg)
@ -203,16 +203,16 @@ class DocumentMetaclass(type):
if delete_rule != DO_NOTHING:
if issubclass(new_class, EmbeddedDocument):
msg = ("Reverse delete rules are not supported for "
"EmbeddedDocuments (field: %s)" % field.name)
msg = ('Reverse delete rules are not supported for '
'EmbeddedDocuments (field: %s)' % field.name)
raise InvalidDocumentError(msg)
f.document_type.register_delete_rule(new_class,
field.name, delete_rule)
if (field.name and hasattr(Document, field.name) and
EmbeddedDocument not in new_class.mro()):
msg = ("%s is a document method and not a valid "
"field name" % field.name)
msg = ('%s is a document method and not a valid '
'field name' % field.name)
raise InvalidDocumentError(msg)
return new_class
@ -302,7 +302,7 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
# If parent wasn't an abstract class
if (parent_doc_cls and 'collection' in attrs.get('_meta', {}) and
not parent_doc_cls._meta.get('abstract', True)):
msg = "Trying to set a collection on a subclass (%s)" % name
msg = 'Trying to set a collection on a subclass (%s)' % name
warnings.warn(msg, SyntaxWarning)
del attrs['_meta']['collection']
@ -310,7 +310,7 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
if attrs.get('_is_base_cls') or attrs['_meta'].get('abstract'):
if (parent_doc_cls and
not parent_doc_cls._meta.get('abstract', False)):
msg = "Abstract document cannot have non-abstract base"
msg = 'Abstract document cannot have non-abstract base'
raise ValueError(msg)
return super_new(cls, name, bases, attrs)

View File

@ -2,12 +2,12 @@ from mongoengine.common import _import_class
from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db
__all__ = ("switch_db", "switch_collection", "no_dereference",
"no_sub_classes", "query_counter")
__all__ = ('switch_db', 'switch_collection', 'no_dereference',
'no_sub_classes', 'query_counter')
class switch_db(object):
""" switch_db alias context manager.
"""switch_db alias context manager.
Example ::
@ -18,15 +18,14 @@ class switch_db(object):
class Group(Document):
name = StringField()
Group(name="test").save() # Saves in the default db
Group(name='test').save() # Saves in the default db
with switch_db(Group, 'testdb-1') as Group:
Group(name="hello testdb!").save() # Saves in testdb-1
Group(name='hello testdb!').save() # Saves in testdb-1
"""
def __init__(self, cls, db_alias):
""" Construct the switch_db context manager
"""Construct the switch_db context manager
:param cls: the class to change the registered db
:param db_alias: the name of the specific database to use
@ -34,37 +33,36 @@ class switch_db(object):
self.cls = cls
self.collection = cls._get_collection()
self.db_alias = db_alias
self.ori_db_alias = cls._meta.get("db_alias", DEFAULT_CONNECTION_NAME)
self.ori_db_alias = cls._meta.get('db_alias', DEFAULT_CONNECTION_NAME)
def __enter__(self):
""" change the db_alias and clear the cached collection """
self.cls._meta["db_alias"] = self.db_alias
"""Change the db_alias and clear the cached collection."""
self.cls._meta['db_alias'] = self.db_alias
self.cls._collection = None
return self.cls
def __exit__(self, t, value, traceback):
""" Reset the db_alias and collection """
self.cls._meta["db_alias"] = self.ori_db_alias
"""Reset the db_alias and collection."""
self.cls._meta['db_alias'] = self.ori_db_alias
self.cls._collection = self.collection
class switch_collection(object):
""" switch_collection alias context manager.
"""switch_collection alias context manager.
Example ::
class Group(Document):
name = StringField()
Group(name="test").save() # Saves in the default db
Group(name='test').save() # Saves in the default db
with switch_collection(Group, 'group1') as Group:
Group(name="hello testdb!").save() # Saves in group1 collection
Group(name='hello testdb!').save() # Saves in group1 collection
"""
def __init__(self, cls, collection_name):
""" Construct the switch_collection context manager
"""Construct the switch_collection context manager.
:param cls: the class to change the registered db
:param collection_name: the name of the collection to use
@ -75,7 +73,7 @@ class switch_collection(object):
self.collection_name = collection_name
def __enter__(self):
""" change the _get_collection_name and clear the cached collection """
"""Change the _get_collection_name and clear the cached collection."""
@classmethod
def _get_collection_name(cls):
@ -86,24 +84,23 @@ class switch_collection(object):
return self.cls
def __exit__(self, t, value, traceback):
""" Reset the collection """
"""Reset the collection."""
self.cls._collection = self.ori_collection
self.cls._get_collection_name = self.ori_get_collection_name
class no_dereference(object):
""" no_dereference context manager.
"""no_dereference context manager.
Turns off all dereferencing in Documents for the duration of the context
manager::
with no_dereference(Group) as Group:
Group.objects.find()
"""
def __init__(self, cls):
""" Construct the no_dereference context manager.
"""Construct the no_dereference context manager.
:param cls: the class to turn dereferencing off on
"""
@ -119,103 +116,102 @@ class no_dereference(object):
ComplexBaseField))]
def __enter__(self):
""" change the objects default and _auto_dereference values"""
"""Change the objects default and _auto_dereference values."""
for field in self.deref_fields:
self.cls._fields[field]._auto_dereference = False
return self.cls
def __exit__(self, t, value, traceback):
""" Reset the default and _auto_dereference values"""
"""Reset the default and _auto_dereference values."""
for field in self.deref_fields:
self.cls._fields[field]._auto_dereference = True
return self.cls
class no_sub_classes(object):
""" no_sub_classes context manager.
"""no_sub_classes context manager.
Only returns instances of this class and no sub (inherited) classes::
with no_sub_classes(Group) as Group:
Group.objects.find()
"""
def __init__(self, cls):
""" Construct the no_sub_classes context manager.
"""Construct the no_sub_classes context manager.
:param cls: the class to turn querying sub classes on
"""
self.cls = cls
def __enter__(self):
""" change the objects default and _auto_dereference values"""
"""Change the objects default and _auto_dereference values."""
self.cls._all_subclasses = self.cls._subclasses
self.cls._subclasses = (self.cls,)
return self.cls
def __exit__(self, t, value, traceback):
""" Reset the default and _auto_dereference values"""
"""Reset the default and _auto_dereference values."""
self.cls._subclasses = self.cls._all_subclasses
delattr(self.cls, '_all_subclasses')
return self.cls
class query_counter(object):
""" Query_counter context manager to get the number of queries. """
"""Query_counter context manager to get the number of queries."""
def __init__(self):
""" Construct the query_counter. """
"""Construct the query_counter."""
self.counter = 0
self.db = get_db()
def __enter__(self):
""" On every with block we need to drop the profile collection. """
"""On every with block we need to drop the profile collection."""
self.db.set_profiling_level(0)
self.db.system.profile.drop()
self.db.set_profiling_level(2)
return self
def __exit__(self, t, value, traceback):
""" Reset the profiling level. """
"""Reset the profiling level."""
self.db.set_profiling_level(0)
def __eq__(self, value):
""" == Compare querycounter. """
"""== Compare querycounter."""
counter = self._get_count()
return value == counter
def __ne__(self, value):
""" != Compare querycounter. """
"""!= Compare querycounter."""
return not self.__eq__(value)
def __lt__(self, value):
""" < Compare querycounter. """
"""< Compare querycounter."""
return self._get_count() < value
def __le__(self, value):
""" <= Compare querycounter. """
"""<= Compare querycounter."""
return self._get_count() <= value
def __gt__(self, value):
""" > Compare querycounter. """
"""> Compare querycounter."""
return self._get_count() > value
def __ge__(self, value):
""" >= Compare querycounter. """
""">= Compare querycounter."""
return self._get_count() >= value
def __int__(self):
""" int representation. """
"""int representation."""
return self._get_count()
def __repr__(self):
""" repr query_counter as the number of queries. """
"""repr query_counter as the number of queries."""
return u"%s" % self._get_count()
def _get_count(self):
""" Get the number of queries. """
ignore_query = {"ns": {"$ne": "%s.system.indexes" % self.db.name}}
"""Get the number of queries."""
ignore_query = {'ns': {'$ne': '%s.system.indexes' % self.db.name}}
count = self.db.system.profile.find(ignore_query).count() - self.counter
self.counter += 1
return count

View File

@ -149,7 +149,7 @@ class DeReference(object):
references = get_db()[collection].find({'_id': {'$in': refs}})
for ref in references:
if '_cls' in ref:
doc = get_document(ref["_cls"])._from_son(ref)
doc = get_document(ref['_cls'])._from_son(ref)
elif doc_type is None:
doc = get_document(
''.join(x.capitalize()
@ -225,7 +225,7 @@ class DeReference(object):
data[k]._data[field_name] = self.object_map.get(
(v['_ref'].collection, v['_ref'].id), v)
elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth:
item_name = six.text_type("{0}.{1}.{2}").format(name, k, field_name)
item_name = six.text_type('{0}.{1}.{2}').format(name, k, field_name)
data[k]._data[field_name] = self._attach_objects(v, depth, instance=instance, name=item_name)
elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth:
item_name = '%s.%s' % (name, k) if name else name

View File

@ -25,9 +25,7 @@ __all__ = ('Document', 'EmbeddedDocument', 'DynamicDocument',
def includes_cls(fields):
""" Helper function used for ensuring and comparing indexes
"""
"""Helper function used for ensuring and comparing indexes."""
first_field = None
if len(fields):
if isinstance(fields[0], six.string_types):
@ -167,7 +165,7 @@ class Document(BaseDocument):
@classmethod
def _get_db(cls):
"""Some Model using other db_alias"""
return get_db(cls._meta.get("db_alias", DEFAULT_CONNECTION_NAME))
return get_db(cls._meta.get('db_alias', DEFAULT_CONNECTION_NAME))
@classmethod
def _get_collection(cls):
@ -241,15 +239,15 @@ class Document(BaseDocument):
query = {}
if self.pk is None:
raise InvalidDocumentError("The document does not have a primary key.")
raise InvalidDocumentError('The document does not have a primary key.')
id_field = self._meta["id_field"]
id_field = self._meta['id_field']
query = query.copy() if isinstance(query, dict) else query.to_query(self)
if id_field not in query:
query[id_field] = self.pk
elif query[id_field] != self.pk:
raise InvalidQueryError("Invalid document modify query: it must modify only this document.")
raise InvalidQueryError('Invalid document modify query: it must modify only this document.')
updated = self._qs(**query).modify(new=True, **update)
if updated is None:
@ -324,7 +322,7 @@ class Document(BaseDocument):
self.validate(clean=clean)
if write_concern is None:
write_concern = {"w": 1}
write_concern = {'w': 1}
doc = self.to_mongo()
@ -372,7 +370,7 @@ class Document(BaseDocument):
def is_new_object(last_error):
if last_error is not None:
updated = last_error.get("updatedExisting")
updated = last_error.get('updatedExisting')
if updated is not None:
return not updated
return created
@ -380,14 +378,14 @@ class Document(BaseDocument):
update_query = {}
if updates:
update_query["$set"] = updates
update_query['$set'] = updates
if removals:
update_query["$unset"] = removals
update_query['$unset'] = removals
if updates or removals:
upsert = save_condition is None
last_error = collection.update(select_dict, update_query,
upsert=upsert, **write_concern)
if not upsert and last_error["n"] == 0:
if not upsert and last_error['n'] == 0:
raise SaveConditionError('Race condition preventing'
' document update detected')
created = is_new_object(last_error)
@ -398,10 +396,10 @@ class Document(BaseDocument):
if cascade:
kwargs = {
"force_insert": force_insert,
"validate": validate,
"write_concern": write_concern,
"cascade": cascade
'force_insert': force_insert,
'validate': validate,
'write_concern': write_concern,
'cascade': cascade
}
if cascade_kwargs: # Allow granular control over cascades
kwargs.update(cascade_kwargs)
@ -492,8 +490,8 @@ class Document(BaseDocument):
if self.pk is None:
if kwargs.get('upsert', False):
query = self.to_mongo()
if "_cls" in query:
del query["_cls"]
if '_cls' in query:
del query['_cls']
return self._qs.filter(**query).update_one(**kwargs)
else:
raise OperationError(
@ -618,11 +616,12 @@ class Document(BaseDocument):
if fields and isinstance(fields[0], int):
max_depth = fields[0]
fields = fields[1:]
elif "max_depth" in kwargs:
max_depth = kwargs["max_depth"]
elif 'max_depth' in kwargs:
max_depth = kwargs['max_depth']
if self.pk is None:
raise self.DoesNotExist("Document does not exist")
raise self.DoesNotExist('Document does not exist')
obj = self._qs.read_preference(ReadPreference.PRIMARY).filter(
**self._object_key).only(*fields).limit(
1).select_related(max_depth=max_depth)
@ -630,7 +629,7 @@ class Document(BaseDocument):
if obj:
obj = obj[0]
else:
raise self.DoesNotExist("Document does not exist")
raise self.DoesNotExist('Document does not exist')
for field in obj._data:
if not fields or field in fields:
@ -673,7 +672,7 @@ class Document(BaseDocument):
"""Returns an instance of :class:`~bson.dbref.DBRef` useful in
`__raw__` queries."""
if self.pk is None:
msg = "Only saved documents can have a valid dbref"
msg = 'Only saved documents can have a valid dbref'
raise OperationError(msg)
return DBRef(self.__class__._get_collection_name(), self.pk)
@ -728,7 +727,7 @@ class Document(BaseDocument):
fields = index_spec.pop('fields')
drop_dups = kwargs.get('drop_dups', False)
if IS_PYMONGO_3 and drop_dups:
msg = "drop_dups is deprecated and is removed when using PyMongo 3+."
msg = 'drop_dups is deprecated and is removed when using PyMongo 3+.'
warnings.warn(msg, DeprecationWarning)
elif not IS_PYMONGO_3:
index_spec['drop_dups'] = drop_dups
@ -754,7 +753,7 @@ class Document(BaseDocument):
will be removed if PyMongo3+ is used
"""
if IS_PYMONGO_3 and drop_dups:
msg = "drop_dups is deprecated and is removed when using PyMongo 3+."
msg = 'drop_dups is deprecated and is removed when using PyMongo 3+.'
warnings.warn(msg, DeprecationWarning)
elif not IS_PYMONGO_3:
kwargs.update({'drop_dups': drop_dups})
@ -774,7 +773,7 @@ class Document(BaseDocument):
index_opts = cls._meta.get('index_opts') or {}
index_cls = cls._meta.get('index_cls', True)
if IS_PYMONGO_3 and drop_dups:
msg = "drop_dups is deprecated and is removed when using PyMongo 3+."
msg = 'drop_dups is deprecated and is removed when using PyMongo 3+.'
warnings.warn(msg, DeprecationWarning)
collection = cls._get_collection()
@ -889,8 +888,8 @@ class Document(BaseDocument):
@classmethod
def compare_indexes(cls):
""" Compares the indexes defined in MongoEngine with the ones existing
in the database. Returns any missing/extra indexes.
""" Compares the indexes defined in MongoEngine with the ones
existing in the database. Returns any missing/extra indexes.
"""
required = cls.list_indexes()
@ -934,8 +933,9 @@ class DynamicDocument(Document):
_dynamic = True
def __delattr__(self, *args, **kwargs):
"""Deletes the attribute by setting to None and allowing _delta to unset
it"""
"""Delete the attribute by setting to None and allowing _delta
to unset it.
"""
field_name = args[0]
if field_name in self._dynamic_fields:
setattr(self, field_name, None)
@ -957,8 +957,9 @@ class DynamicEmbeddedDocument(EmbeddedDocument):
_dynamic = True
def __delattr__(self, *args, **kwargs):
"""Deletes the attribute by setting to None and allowing _delta to unset
it"""
"""Delete the attribute by setting to None and allowing _delta
to unset it.
"""
field_name = args[0]
if field_name in self._fields:
default = self._fields[field_name].default
@ -1000,10 +1001,10 @@ class MapReduceDocument(object):
try:
self.key = id_field_type(self.key)
except Exception:
raise Exception("Could not cast key as %s" %
raise Exception('Could not cast key as %s' %
id_field_type.__name__)
if not hasattr(self, "_key_object"):
if not hasattr(self, '_key_object'):
self._key_object = self._document.objects.with_id(self.key)
return self._key_object
return self._key_object

View File

@ -70,7 +70,7 @@ class ValidationError(AssertionError):
field_name = None
_message = None
def __init__(self, message="", **kwargs):
def __init__(self, message='', **kwargs):
self.errors = kwargs.get('errors', {})
self.field_name = kwargs.get('field_name')
self.message = message
@ -136,10 +136,10 @@ class ValidationError(AssertionError):
value = ' '.join(
[generate_key(v, k) for k, v in value.iteritems()])
results = "%s.%s" % (prefix, value) if prefix else value
results = '%s.%s' % (prefix, value) if prefix else value
return results
error_dict = defaultdict(list)
for k, v in self.to_dict().iteritems():
error_dict[generate_key(v)].append(k)
return ' '.join(["%s: %s" % (k, v) for k, v in error_dict.iteritems()])
return ' '.join(['%s: %s' % (k, v) for k, v in error_dict.iteritems()])

View File

@ -57,8 +57,7 @@ RECURSIVE_REFERENCE_CONSTANT = 'self'
class StringField(BaseField):
"""A unicode string field.
"""
"""A unicode string field."""
def __init__(self, regex=None, max_length=None, min_length=None, **kwargs):
self.regex = re.compile(regex) if regex else None
@ -151,8 +150,8 @@ class URLField(StringField):
if self.verify_exists:
warnings.warn(
"The URLField verify_exists argument has intractable security "
"and performance issues. Accordingly, it has been deprecated.",
'The URLField verify_exists argument has intractable security '
'and performance issues. Accordingly, it has been deprecated.',
DeprecationWarning)
try:
request = urllib2.Request(value)
@ -183,8 +182,7 @@ class EmailField(StringField):
class IntField(BaseField):
"""An 32-bit integer field.
"""
"""32-bit integer field."""
def __init__(self, min_value=None, max_value=None, **kwargs):
self.min_value, self.max_value = min_value, max_value
@ -217,8 +215,7 @@ class IntField(BaseField):
class LongField(BaseField):
"""An 64-bit integer field.
"""
"""64-bit integer field."""
def __init__(self, min_value=None, max_value=None, **kwargs):
self.min_value, self.max_value = min_value, max_value
@ -254,8 +251,7 @@ class LongField(BaseField):
class FloatField(BaseField):
"""An floating point number field.
"""
"""Floating point number field."""
def __init__(self, min_value=None, max_value=None, **kwargs):
self.min_value, self.max_value = min_value, max_value
@ -292,7 +288,7 @@ class FloatField(BaseField):
class DecimalField(BaseField):
"""A fixed-point decimal number field.
"""Fixed-point decimal number field.
.. versionchanged:: 0.8
.. versionadded:: 0.3
@ -333,10 +329,10 @@ class DecimalField(BaseField):
# Convert to string for python 2.6 before casting to Decimal
try:
value = decimal.Decimal("%s" % value)
value = decimal.Decimal('%s' % value)
except decimal.InvalidOperation:
return value
return value.quantize(decimal.Decimal(".%s" % ("0" * self.precision)), rounding=self.rounding)
return value.quantize(decimal.Decimal('.%s' % ('0' * self.precision)), rounding=self.rounding)
def to_mongo(self, value):
if value is None:
@ -365,7 +361,7 @@ class DecimalField(BaseField):
class BooleanField(BaseField):
"""A boolean field type.
"""Boolean field type.
.. versionadded:: 0.1.2
"""
@ -383,7 +379,7 @@ class BooleanField(BaseField):
class DateTimeField(BaseField):
"""A datetime field.
"""Datetime field.
Uses the python-dateutil library if available alternatively use time.strptime
to parse the dates. Note: python-dateutil's parser is fully featured and when
@ -643,7 +639,7 @@ class DynamicField(BaseField):
val = value.to_mongo(use_db_field, fields)
# If we its a document thats not inherited add _cls
if isinstance(value, Document):
val = {"_ref": value.to_dbref(), "_cls": cls.__name__}
val = {'_ref': value.to_dbref(), '_cls': cls.__name__}
if isinstance(value, EmbeddedDocument):
val['_cls'] = cls.__name__
return val
@ -683,7 +679,7 @@ class DynamicField(BaseField):
return super(DynamicField, self).prepare_query_value(op, self.to_mongo(value))
def validate(self, value, clean=True):
if hasattr(value, "validate"):
if hasattr(value, 'validate'):
value.validate(clean=clean)
@ -703,8 +699,7 @@ class ListField(ComplexBaseField):
super(ListField, self).__init__(**kwargs)
def validate(self, value):
"""Make sure that a list of valid fields is being used.
"""
"""Make sure that a list of valid fields is being used."""
if (not isinstance(value, (list, tuple, QuerySet)) or
isinstance(value, six.string_types)):
self.error('Only lists and tuples may be used in a list field')
@ -737,7 +732,6 @@ class EmbeddedDocumentListField(ListField):
:class:`~mongoengine.EmbeddedDocument`.
.. versionadded:: 0.9
"""
def __init__(self, document_type, **kwargs):
@ -786,8 +780,8 @@ class SortedListField(ListField):
def key_not_string(d):
""" Helper function to recursively determine if any key in a dictionary is
not a string.
"""Helper function to recursively determine if any key in a
dictionary is not a string.
"""
for k, v in d.items():
if not isinstance(k, six.string_types) or (isinstance(v, dict) and key_not_string(v)):
@ -795,8 +789,8 @@ def key_not_string(d):
def key_has_dot_or_dollar(d):
""" Helper function to recursively determine if any key in a dictionary
contains a dot or a dollar sign.
"""Helper function to recursively determine if any key in a
dictionary contains a dot or a dollar sign.
"""
for k, v in d.items():
if ('.' in k or '$' in k) or (isinstance(v, dict) and key_has_dot_or_dollar(v)):
@ -824,14 +818,13 @@ class DictField(ComplexBaseField):
super(DictField, self).__init__(*args, **kwargs)
def validate(self, value):
"""Make sure that a list of valid fields is being used.
"""
"""Make sure that a list of valid fields is being used."""
if not isinstance(value, dict):
self.error('Only dictionaries may be used in a DictField')
if key_not_string(value):
msg = ("Invalid dictionary key - documents must "
"have only string keys")
msg = ('Invalid dictionary key - documents must '
'have only string keys')
self.error(msg)
if key_has_dot_or_dollar(value):
self.error('Invalid dictionary key name - keys may not contain "."'
@ -944,8 +937,7 @@ class ReferenceField(BaseField):
return self.document_type_obj
def __get__(self, instance, owner):
"""Descriptor to allow lazy dereferencing.
"""
"""Descriptor to allow lazy dereferencing."""
if instance is None:
# Document class being used rather than a document object
return self
@ -1002,8 +994,7 @@ class ReferenceField(BaseField):
return id_
def to_python(self, value):
"""Convert a MongoDB-compatible type to a Python type.
"""
"""Convert a MongoDB-compatible type to a Python type."""
if (not self.dbref and
not isinstance(value, (DBRef, Document, EmbeddedDocument))):
collection = self.document_type._get_collection_name()
@ -1019,7 +1010,7 @@ class ReferenceField(BaseField):
def validate(self, value):
if not isinstance(value, (self.document_type, DBRef)):
self.error("A ReferenceField only accepts DBRef or documents")
self.error('A ReferenceField only accepts DBRef or documents')
if isinstance(value, Document) and value.id is None:
self.error('You can only reference documents once they have been '
@ -1135,7 +1126,7 @@ class CachedReferenceField(BaseField):
# TODO: should raise here or will fail next statement
value = SON((
("_id", id_field.to_mongo(id_)),
('_id', id_field.to_mongo(id_)),
))
if fields:
@ -1161,7 +1152,7 @@ class CachedReferenceField(BaseField):
def validate(self, value):
if not isinstance(value, self.document_type):
self.error("A CachedReferenceField only accepts documents")
self.error('A CachedReferenceField only accepts documents')
if isinstance(value, Document) and value.id is None:
self.error('You can only reference documents once they have been '
@ -1298,8 +1289,7 @@ class GenericReferenceField(BaseField):
class BinaryField(BaseField):
"""A binary data field.
"""
"""A binary data field."""
def __init__(self, max_bytes=None, **kwargs):
self.max_bytes = max_bytes
@ -1316,8 +1306,8 @@ class BinaryField(BaseField):
def validate(self, value):
if not isinstance(value, (six.binary_type, six.text_type, Binary)):
self.error("BinaryField only accepts instances of "
"(%s, %s, Binary)" % (
self.error('BinaryField only accepts instances of '
'(%s, %s, Binary)' % (
six.binary_type.__name__, six.text_type.__name__))
if self.max_bytes is not None and len(value) > self.max_bytes:
@ -1450,7 +1440,7 @@ class GridFSProxy(object):
try:
return gridout.read(size)
except Exception:
return ""
return ''
def delete(self):
# Delete file from GridFS, FileField still remains
@ -1482,9 +1472,8 @@ class FileField(BaseField):
"""
proxy_class = GridFSProxy
def __init__(self,
db_alias=DEFAULT_CONNECTION_NAME,
collection_name="fs", **kwargs):
def __init__(self, db_alias=DEFAULT_CONNECTION_NAME, collection_name='fs',
**kwargs):
super(FileField, self).__init__(**kwargs)
self.collection_name = collection_name
self.db_alias = db_alias
@ -1687,10 +1676,10 @@ class ImageGridFsProxy(GridFSProxy):
return self.fs.get(out.thumbnail_id)
def write(self, *args, **kwargs):
raise RuntimeError("Please use \"put\" method instead")
raise RuntimeError('Please use "put" method instead')
def writelines(self, *args, **kwargs):
raise RuntimeError("Please use \"put\" method instead")
raise RuntimeError('Please use "put" method instead')
class ImproperlyConfigured(Exception):
@ -1715,7 +1704,7 @@ class ImageField(FileField):
def __init__(self, size=None, thumbnail_size=None,
collection_name='images', **kwargs):
if not Image:
raise ImproperlyConfigured("PIL library was not found")
raise ImproperlyConfigured('PIL library was not found')
params_size = ('width', 'height', 'force')
extra_args = dict(size=size, thumbnail_size=thumbnail_size)
@ -1783,10 +1772,10 @@ class SequenceField(BaseField):
Generate and Increment the counter
"""
sequence_name = self.get_sequence_name()
sequence_id = "%s.%s" % (sequence_name, self.name)
sequence_id = '%s.%s' % (sequence_name, self.name)
collection = get_db(alias=self.db_alias)[self.collection_name]
counter = collection.find_and_modify(query={"_id": sequence_id},
update={"$inc": {"next": 1}},
counter = collection.find_and_modify(query={'_id': sequence_id},
update={'$inc': {'next': 1}},
new=True,
upsert=True)
return self.value_decorator(counter['next'])
@ -1809,9 +1798,9 @@ class SequenceField(BaseField):
as it is only fixed on set.
"""
sequence_name = self.get_sequence_name()
sequence_id = "%s.%s" % (sequence_name, self.name)
sequence_id = '%s.%s' % (sequence_name, self.name)
collection = get_db(alias=self.db_alias)[self.collection_name]
data = collection.find_one({"_id": sequence_id})
data = collection.find_one({'_id': sequence_id})
if data:
return self.value_decorator(data['next'] + 1)
@ -1924,19 +1913,18 @@ class GeoPointField(BaseField):
_geo_index = pymongo.GEO2D
def validate(self, value):
"""Make sure that a geo-value is of type (x, y)
"""
"""Make sure that a geo-value is of type (x, y)"""
if not isinstance(value, (list, tuple)):
self.error('GeoPointField can only accept tuples or lists '
'of (x, y)')
if not len(value) == 2:
self.error("Value (%s) must be a two-dimensional point" %
self.error('Value (%s) must be a two-dimensional point' %
repr(value))
elif (not isinstance(value[0], (float, int)) or
not isinstance(value[1], (float, int))):
self.error(
"Both values (%s) in point must be float or int" % repr(value))
'Both values (%s) in point must be float or int' % repr(value))
class PointField(GeoJsonBaseField):
@ -1946,8 +1934,8 @@ class PointField(GeoJsonBaseField):
.. code-block:: js
{ "type" : "Point" ,
"coordinates" : [x, y]}
{'type' : 'Point' ,
'coordinates' : [x, y]}
You can either pass a dict with the full information or a list
to set the value.
@ -1956,7 +1944,7 @@ class PointField(GeoJsonBaseField):
.. versionadded:: 0.8
"""
_type = "Point"
_type = 'Point'
class LineStringField(GeoJsonBaseField):
@ -1966,8 +1954,8 @@ class LineStringField(GeoJsonBaseField):
.. code-block:: js
{ "type" : "LineString" ,
"coordinates" : [[x1, y1], [x1, y1] ... [xn, yn]]}
{'type' : 'LineString' ,
'coordinates' : [[x1, y1], [x1, y1] ... [xn, yn]]}
You can either pass a dict with the full information or a list of points.
@ -1975,7 +1963,7 @@ class LineStringField(GeoJsonBaseField):
.. versionadded:: 0.8
"""
_type = "LineString"
_type = 'LineString'
class PolygonField(GeoJsonBaseField):
@ -1985,8 +1973,8 @@ class PolygonField(GeoJsonBaseField):
.. code-block:: js
{ "type" : "Polygon" ,
"coordinates" : [[[x1, y1], [x1, y1] ... [xn, yn]],
{'type' : 'Polygon' ,
'coordinates' : [[[x1, y1], [x1, y1] ... [xn, yn]],
[[x1, y1], [x1, y1] ... [xn, yn]]}
You can either pass a dict with the full information or a list
@ -1997,7 +1985,7 @@ class PolygonField(GeoJsonBaseField):
.. versionadded:: 0.8
"""
_type = "Polygon"
_type = 'Polygon'
class MultiPointField(GeoJsonBaseField):
@ -2007,8 +1995,8 @@ class MultiPointField(GeoJsonBaseField):
.. code-block:: js
{ "type" : "MultiPoint" ,
"coordinates" : [[x1, y1], [x2, y2]]}
{'type' : 'MultiPoint' ,
'coordinates' : [[x1, y1], [x2, y2]]}
You can either pass a dict with the full information or a list
to set the value.
@ -2017,7 +2005,7 @@ class MultiPointField(GeoJsonBaseField):
.. versionadded:: 0.9
"""
_type = "MultiPoint"
_type = 'MultiPoint'
class MultiLineStringField(GeoJsonBaseField):
@ -2027,8 +2015,8 @@ class MultiLineStringField(GeoJsonBaseField):
.. code-block:: js
{ "type" : "MultiLineString" ,
"coordinates" : [[[x1, y1], [x1, y1] ... [xn, yn]],
{'type' : 'MultiLineString' ,
'coordinates' : [[[x1, y1], [x1, y1] ... [xn, yn]],
[[x1, y1], [x1, y1] ... [xn, yn]]]}
You can either pass a dict with the full information or a list of points.
@ -2037,7 +2025,7 @@ class MultiLineStringField(GeoJsonBaseField):
.. versionadded:: 0.9
"""
_type = "MultiLineString"
_type = 'MultiLineString'
class MultiPolygonField(GeoJsonBaseField):
@ -2047,8 +2035,8 @@ class MultiPolygonField(GeoJsonBaseField):
.. code-block:: js
{ "type" : "MultiPolygon" ,
"coordinates" : [[
{'type' : 'MultiPolygon' ,
'coordinates' : [[
[[x1, y1], [x1, y1] ... [xn, yn]],
[[x1, y1], [x1, y1] ... [xn, yn]]
], [
@ -2064,4 +2052,4 @@ class MultiPolygonField(GeoJsonBaseField):
.. versionadded:: 0.9
"""
_type = "MultiPolygon"
_type = 'MultiPolygon'

View File

@ -74,10 +74,10 @@ class BaseQuerySet(object):
# subclasses of the class being used
if document._meta.get('allow_inheritance') is True:
if len(self._document._subclasses) == 1:
self._initial_query = {"_cls": self._document._subclasses[0]}
self._initial_query = {'_cls': self._document._subclasses[0]}
else:
self._initial_query = {
"_cls": {"$in": self._document._subclasses}}
'_cls': {'$in': self._document._subclasses}}
self._loaded_fields = QueryFieldList(always_include=['_cls'])
self._cursor_obj = None
self._limit = None
@ -106,8 +106,8 @@ class BaseQuerySet(object):
if q_obj:
# make sure proper query object is passed
if not isinstance(q_obj, QNode):
msg = ("Not a query object: %s. "
"Did you intend to use key=value?" % q_obj)
msg = ('Not a query object: %s. '
'Did you intend to use key=value?' % q_obj)
raise InvalidQueryError(msg)
query &= q_obj
@ -134,10 +134,10 @@ class BaseQuerySet(object):
obj_dict = self.__dict__.copy()
# don't picke collection, instead pickle collection params
obj_dict.pop("_collection_obj")
obj_dict.pop('_collection_obj')
# don't pickle cursor
obj_dict["_cursor_obj"] = None
obj_dict['_cursor_obj'] = None
return obj_dict
@ -148,7 +148,7 @@ class BaseQuerySet(object):
See https://github.com/MongoEngine/mongoengine/issues/442
"""
obj_dict["_collection_obj"] = obj_dict["_document"]._get_collection()
obj_dict['_collection_obj'] = obj_dict['_document']._get_collection()
# update attributes
self.__dict__.update(obj_dict)
@ -200,19 +200,16 @@ class BaseQuerySet(object):
raise NotImplementedError
def _has_data(self):
""" Retrieves whether cursor has any data. """
"""Return True if cursor has any data."""
queryset = self.order_by()
return False if queryset.first() is None else True
def __nonzero__(self):
""" Avoid to open all records in an if stmt in Py2. """
"""Avoid to open all records in an if stmt in Py2."""
return self._has_data()
def __bool__(self):
""" Avoid to open all records in an if stmt in Py3. """
"""Avoid to open all records in an if stmt in Py3."""
return self._has_data()
# Core functions
@ -240,7 +237,7 @@ class BaseQuerySet(object):
queryset = self.clone()
if queryset._search_text:
raise OperationError(
"It is not possible to use search_text two times.")
'It is not possible to use search_text two times.')
query_kwargs = SON({'$search': text})
if language:
@ -269,7 +266,7 @@ class BaseQuerySet(object):
try:
result = queryset.next()
except StopIteration:
msg = ("%s matching query does not exist."
msg = ('%s matching query does not exist.'
% queryset._document._class_name)
raise queryset._document.DoesNotExist(msg)
try:
@ -291,8 +288,7 @@ class BaseQuerySet(object):
return self._document(**kwargs).save()
def first(self):
"""Retrieve the first object matching the query.
"""
"""Retrieve the first object matching the query."""
queryset = self.clone()
try:
result = queryset[0]
@ -341,7 +337,7 @@ class BaseQuerySet(object):
% str(self._document))
raise OperationError(msg)
if doc.pk and not doc._created:
msg = "Some documents have ObjectIds use doc.update() instead"
msg = 'Some documents have ObjectIds use doc.update() instead'
raise OperationError(msg)
signal_kwargs = signal_kwargs or {}
@ -433,7 +429,7 @@ class BaseQuerySet(object):
rule = doc._meta['delete_rules'][rule_entry]
if rule == DENY and document_cls.objects(
**{field_name + '__in': self}).count() > 0:
msg = ("Could not delete document (%s.%s refers to it)"
msg = ('Could not delete document (%s.%s refers to it)'
% (document_cls.__name__, field_name))
raise OperationError(msg)
@ -462,7 +458,7 @@ class BaseQuerySet(object):
result = queryset._collection.remove(queryset._query, **write_concern)
if result:
return result.get("n")
return result.get('n')
def update(self, upsert=False, multi=True, write_concern=None,
full_result=False, **update):
@ -483,7 +479,7 @@ class BaseQuerySet(object):
.. versionadded:: 0.2
"""
if not update and not upsert:
raise OperationError("No update parameters, would remove data")
raise OperationError('No update parameters, would remove data')
if write_concern is None:
write_concern = {}
@ -496,9 +492,9 @@ class BaseQuerySet(object):
# then ensure we add _cls to the update operation
if upsert and '_cls' in query:
if '$set' in update:
update["$set"]["_cls"] = queryset._document._class_name
update['$set']['_cls'] = queryset._document._class_name
else:
update["$set"] = {"_cls": queryset._document._class_name}
update['$set'] = {'_cls': queryset._document._class_name}
try:
result = queryset._collection.update(query, update, multi=multi,
upsert=upsert, **write_concern)
@ -583,11 +579,11 @@ class BaseQuerySet(object):
"""
if remove and new:
raise OperationError("Conflicting parameters: remove and new")
raise OperationError('Conflicting parameters: remove and new')
if not update and not upsert and not remove:
raise OperationError(
"No update parameters, must either update or remove")
'No update parameters, must either update or remove')
queryset = self.clone()
query = queryset._query
@ -598,7 +594,7 @@ class BaseQuerySet(object):
try:
if IS_PYMONGO_3:
if full_response:
msg = "With PyMongo 3+, it is not possible anymore to get the full response."
msg = 'With PyMongo 3+, it is not possible anymore to get the full response.'
warnings.warn(msg, DeprecationWarning)
if remove:
result = queryset._collection.find_one_and_delete(
@ -617,13 +613,13 @@ class BaseQuerySet(object):
query, update, upsert=upsert, sort=sort, remove=remove, new=new,
full_response=full_response, **self._cursor_args)
except pymongo.errors.DuplicateKeyError as err:
raise NotUniqueError(u"Update failed (%s)" % err)
raise NotUniqueError(u'Update failed (%s)' % err)
except pymongo.errors.OperationFailure as err:
raise OperationError(u"Update failed (%s)" % err)
raise OperationError(u'Update failed (%s)' % err)
if full_response:
if result["value"] is not None:
result["value"] = self._document._from_son(result["value"], only_fields=self.only_fields)
if result['value'] is not None:
result['value'] = self._document._from_son(result['value'], only_fields=self.only_fields)
else:
if result is not None:
result = self._document._from_son(result, only_fields=self.only_fields)
@ -641,7 +637,7 @@ class BaseQuerySet(object):
"""
queryset = self.clone()
if not queryset._query_obj.empty:
msg = "Cannot use a filter whilst using `with_id`"
msg = 'Cannot use a filter whilst using `with_id`'
raise InvalidQueryError(msg)
return queryset.filter(pk=object_id).first()
@ -685,7 +681,7 @@ class BaseQuerySet(object):
Only return instances of this document and not any inherited documents
"""
if self._document._meta.get('allow_inheritance') is True:
self._initial_query = {"_cls": self._document._class_name}
self._initial_query = {'_cls': self._document._class_name}
return self
@ -824,9 +820,9 @@ class BaseQuerySet(object):
ListField = _import_class('ListField')
GenericEmbeddedDocumentField = _import_class('GenericEmbeddedDocumentField')
if isinstance(doc_field, ListField):
doc_field = getattr(doc_field, "field", doc_field)
doc_field = getattr(doc_field, 'field', doc_field)
if isinstance(doc_field, (EmbeddedDocumentField, GenericEmbeddedDocumentField)):
instance = getattr(doc_field, "document_type", False)
instance = getattr(doc_field, 'document_type', False)
# handle distinct on subdocuments
if '.' in field:
for field_part in field.split('.')[1:]:
@ -837,9 +833,9 @@ class BaseQuerySet(object):
doc_field = getattr(doc_field, field_part, doc_field)
# We may need to cast to the correct type eg. ListField(EmbeddedDocumentField)
if isinstance(doc_field, ListField):
doc_field = getattr(doc_field, "field", doc_field)
doc_field = getattr(doc_field, 'field', doc_field)
if isinstance(doc_field, (EmbeddedDocumentField, GenericEmbeddedDocumentField)):
instance = getattr(doc_field, "document_type", False)
instance = getattr(doc_field, 'document_type', False)
if instance and isinstance(doc_field, (EmbeddedDocumentField,
GenericEmbeddedDocumentField)):
distinct = [instance(**doc) for doc in distinct]
@ -848,12 +844,12 @@ class BaseQuerySet(object):
def only(self, *fields):
"""Load only a subset of this document's fields. ::
post = BlogPost.objects(...).only("title", "author.name")
post = BlogPost.objects(...).only('title', 'author.name')
.. note :: `only()` is chainable and will perform a union ::
So with the following it will fetch both: `title` and `author.name`::
post = BlogPost.objects.only("title").only("author.name")
post = BlogPost.objects.only('title').only('author.name')
:func:`~mongoengine.queryset.QuerySet.all_fields` will reset any
field filters.
@ -870,12 +866,12 @@ class BaseQuerySet(object):
def exclude(self, *fields):
"""Opposite to .only(), exclude some document's fields. ::
post = BlogPost.objects(...).exclude("comments")
post = BlogPost.objects(...).exclude('comments')
.. note :: `exclude()` is chainable and will perform a union ::
So with the following it will exclude both: `title` and `author.name`::
post = BlogPost.objects.exclude("title").exclude("author.name")
post = BlogPost.objects.exclude('title').exclude('author.name')
:func:`~mongoengine.queryset.QuerySet.all_fields` will reset any
field filters.
@ -905,7 +901,7 @@ class BaseQuerySet(object):
"""
# Check for an operator and transform to mongo-style if there is
operators = ["slice"]
operators = ['slice']
cleaned_fields = []
for key, value in kwargs.items():
parts = key.split('__')
@ -929,7 +925,7 @@ class BaseQuerySet(object):
"""Include all fields. Reset all previously calls of .only() or
.exclude(). ::
post = BlogPost.objects.exclude("comments").all_fields()
post = BlogPost.objects.exclude('comments').all_fields()
.. versionadded:: 0.5
"""
@ -956,7 +952,7 @@ class BaseQuerySet(object):
See https://docs.mongodb.com/manual/reference/method/cursor.comment/#cursor.comment
for details.
"""
return self._chainable_method("comment", text)
return self._chainable_method('comment', text)
def explain(self, format=False):
"""Return an explain plan record for the
@ -979,7 +975,7 @@ class BaseQuerySet(object):
.. deprecated:: Ignored with PyMongo 3+
"""
if IS_PYMONGO_3:
msg = "snapshot is deprecated as it has no impact when using PyMongo 3+."
msg = 'snapshot is deprecated as it has no impact when using PyMongo 3+.'
warnings.warn(msg, DeprecationWarning)
queryset = self.clone()
queryset._snapshot = enabled
@ -1005,7 +1001,7 @@ class BaseQuerySet(object):
.. deprecated:: Ignored with PyMongo 3+
"""
if IS_PYMONGO_3:
msg = "slave_okay is deprecated as it has no impact when using PyMongo 3+."
msg = 'slave_okay is deprecated as it has no impact when using PyMongo 3+.'
warnings.warn(msg, DeprecationWarning)
queryset = self.clone()
queryset._slave_okay = enabled
@ -1067,7 +1063,7 @@ class BaseQuerySet(object):
:param ms: the number of milliseconds before killing the query on the server
"""
return self._chainable_method("max_time_ms", ms)
return self._chainable_method('max_time_ms', ms)
# JSON Helpers
@ -1150,8 +1146,8 @@ class BaseQuerySet(object):
MapReduceDocument = _import_class('MapReduceDocument')
if not hasattr(self._collection, "map_reduce"):
raise NotImplementedError("Requires MongoDB >= 1.7.1")
if not hasattr(self._collection, 'map_reduce'):
raise NotImplementedError('Requires MongoDB >= 1.7.1')
map_f_scope = {}
if isinstance(map_f, Code):
@ -1201,7 +1197,7 @@ class BaseQuerySet(object):
break
else:
raise OperationError("actionData not specified for output")
raise OperationError('actionData not specified for output')
db_alias = output.get('db_alias')
remaing_args = ['db', 'sharded', 'nonAtomic']
@ -1431,7 +1427,7 @@ class BaseQuerySet(object):
# snapshot is not handled at all by PyMongo 3+
# TODO: evaluate similar possibilities using modifiers
if self._snapshot:
msg = "The snapshot option is not anymore available with PyMongo 3+"
msg = 'The snapshot option is not anymore available with PyMongo 3+'
warnings.warn(msg, DeprecationWarning)
cursor_args = {
'no_cursor_timeout': not self._timeout
@ -1443,7 +1439,7 @@ class BaseQuerySet(object):
if fields_name not in cursor_args:
cursor_args[fields_name] = {}
cursor_args[fields_name]['_text_score'] = {'$meta': "textScore"}
cursor_args[fields_name]['_text_score'] = {'$meta': 'textScore'}
return cursor_args
@ -1498,8 +1494,8 @@ class BaseQuerySet(object):
if self._mongo_query is None:
self._mongo_query = self._query_obj.to_query(self._document)
if self._class_check and self._initial_query:
if "_cls" in self._mongo_query:
self._mongo_query = {"$and": [self._initial_query, self._mongo_query]}
if '_cls' in self._mongo_query:
self._mongo_query = {'$and': [self._initial_query, self._mongo_query]}
else:
self._mongo_query.update(self._initial_query)
return self._mongo_query
@ -1511,8 +1507,7 @@ class BaseQuerySet(object):
return self.__dereference
def no_dereference(self):
"""Turn off any dereferencing for the results of this queryset.
"""
"""Turn off any dereferencing for the results of this queryset."""
queryset = self.clone()
queryset._auto_dereference = False
return queryset
@ -1641,14 +1636,14 @@ class BaseQuerySet(object):
for x in document._subclasses][1:]
for field in fields:
try:
field = ".".join(f.db_field for f in
field = '.'.join(f.db_field for f in
document._lookup_field(field.split('.')))
ret.append(field)
except LookUpError as err:
found = False
for subdoc in subclasses:
try:
subfield = ".".join(f.db_field for f in
subfield = '.'.join(f.db_field for f in
subdoc._lookup_field(field.split('.')))
ret.append(subfield)
found = True
@ -1661,15 +1656,14 @@ class BaseQuerySet(object):
return ret
def _get_order_by(self, keys):
"""Creates a list of order by fields
"""
"""Creates a list of order by fields"""
key_list = []
for key in keys:
if not key:
continue
if key == '$text_score':
key_list.append(('_text_score', {'$meta': "textScore"}))
key_list.append(('_text_score', {'$meta': 'textScore'}))
continue
direction = pymongo.ASCENDING
@ -1775,7 +1769,7 @@ class BaseQuerySet(object):
field_name = match.group(1).split('.')
fields = self._document._lookup_field(field_name)
# Substitute the correct name for the field into the javascript
return ".".join([f.db_field for f in fields])
return '.'.join([f.db_field for f in fields])
code = re.sub(u'\[\s*~([A-z_][A-z_0-9.]+?)\s*\]', field_sub, code)
code = re.sub(u'\{\{\s*~([A-z_][A-z_0-9.]+?)\s*\}\}', field_path_sub,
@ -1786,21 +1780,21 @@ class BaseQuerySet(object):
queryset = self.clone()
method = getattr(queryset._cursor, method_name)
method(val)
setattr(queryset, "_" + method_name, val)
setattr(queryset, '_' + method_name, val)
return queryset
# Deprecated
def ensure_index(self, **kwargs):
"""Deprecated use :func:`Document.ensure_index`"""
msg = ("Doc.objects()._ensure_index() is deprecated. "
"Use Doc.ensure_index() instead.")
msg = ('Doc.objects()._ensure_index() is deprecated. '
'Use Doc.ensure_index() instead.')
warnings.warn(msg, DeprecationWarning)
self._document.__class__.ensure_index(**kwargs)
return self
def _ensure_indexes(self):
"""Deprecated use :func:`~Document.ensure_indexes`"""
msg = ("Doc.objects()._ensure_indexes() is deprecated. "
"Use Doc.ensure_indexes() instead.")
msg = ('Doc.objects()._ensure_indexes() is deprecated. '
'Use Doc.ensure_indexes() instead.')
warnings.warn(msg, DeprecationWarning)
self._document.__class__.ensure_indexes()

View File

@ -53,15 +53,14 @@ class QuerySet(BaseQuerySet):
return self._len
def __repr__(self):
"""Provides the string representation of the QuerySet
"""
"""Provide a string representation of the QuerySet"""
if self._iter:
return '.. queryset mid-iteration ..'
self._populate_cache()
data = self._result_cache[:REPR_OUTPUT_SIZE + 1]
if len(data) > REPR_OUTPUT_SIZE:
data[-1] = "...(remaining elements truncated)..."
data[-1] = '...(remaining elements truncated)...'
return repr(data)
def _iter_results(self):
@ -142,7 +141,7 @@ class QuerySet(BaseQuerySet):
.. versionadded:: 0.8.3 Convert to non caching queryset
"""
if self._result_cache is not None:
raise OperationError("QuerySet already cached")
raise OperationError('QuerySet already cached')
return self.clone_into(QuerySetNoCache(self._document, self._collection))
@ -171,7 +170,7 @@ class QuerySetNoCache(BaseQuerySet):
except StopIteration:
break
if len(data) > REPR_OUTPUT_SIZE:
data[-1] = "...(remaining elements truncated)..."
data[-1] = '...(remaining elements truncated)...'
self.rewind()
return repr(data)

View File

@ -28,12 +28,11 @@ MATCH_OPERATORS = (COMPARISON_OPERATORS + GEO_OPERATORS +
def query(_doc_cls=None, **kwargs):
"""Transform a query from Django-style format to Mongo format.
"""
"""Transform a query from Django-style format to Mongo format."""
mongo_query = {}
merge_query = defaultdict(list)
for key, value in sorted(kwargs.items()):
if key == "__raw__":
if key == '__raw__':
mongo_query.update(value)
continue
@ -46,7 +45,7 @@ def query(_doc_cls=None, **kwargs):
op = parts.pop()
# Allow to escape operator-like field name by __
if len(parts) > 1 and parts[-1] == "":
if len(parts) > 1 and parts[-1] == '':
parts.pop()
negate = False
@ -117,10 +116,10 @@ def query(_doc_cls=None, **kwargs):
value = query(field.field.document_type, **value)
else:
value = field.prepare_query_value(op, value)
value = {"$elemMatch": value}
value = {'$elemMatch': value}
elif op in CUSTOM_OPERATORS:
NotImplementedError("Custom method '%s' has not "
"been implemented" % op)
NotImplementedError('Custom method "%s" has not '
'been implemented' % op)
elif op not in STRING_OPERATORS:
value = {'$' + op: value}
@ -183,15 +182,16 @@ def query(_doc_cls=None, **kwargs):
def update(_doc_cls=None, **update):
"""Transform an update spec from Django-style format to Mongo format.
"""Transform an update spec from Django-style format to Mongo
format.
"""
mongo_update = {}
for key, value in update.items():
if key == "__raw__":
if key == '__raw__':
mongo_update.update(value)
continue
parts = key.split('__')
# if there is no operator, default to "set"
# if there is no operator, default to 'set'
if len(parts) < 3 and parts[0] not in UPDATE_OPERATORS:
parts.insert(0, 'set')
# Check for an operator and transform to mongo-style if there is
@ -210,14 +210,14 @@ def update(_doc_cls=None, **update):
elif op == 'add_to_set':
op = 'addToSet'
elif op == 'set_on_insert':
op = "setOnInsert"
op = 'setOnInsert'
match = None
if parts[-1] in COMPARISON_OPERATORS:
match = parts.pop()
# Allow to escape operator-like field name by __
if len(parts) > 1 and parts[-1] == "":
if len(parts) > 1 and parts[-1] == '':
parts.pop()
if _doc_cls:
@ -253,7 +253,7 @@ def update(_doc_cls=None, **update):
else:
field = cleaned_fields[-1]
GeoJsonBaseField = _import_class("GeoJsonBaseField")
GeoJsonBaseField = _import_class('GeoJsonBaseField')
if isinstance(field, GeoJsonBaseField):
value = field.to_mongo(value)
@ -267,7 +267,7 @@ def update(_doc_cls=None, **update):
value = [field.prepare_query_value(op, v) for v in value]
elif field.required or value is not None:
value = field.prepare_query_value(op, value)
elif op == "unset":
elif op == 'unset':
value = 1
if match:
@ -277,16 +277,16 @@ def update(_doc_cls=None, **update):
key = '.'.join(parts)
if not op:
raise InvalidQueryError("Updates must supply an operation "
"eg: set__FIELD=value")
raise InvalidQueryError('Updates must supply an operation '
'eg: set__FIELD=value')
if 'pull' in op and '.' in key:
# Dot operators don't work on pull operations
# unless they point to a list field
# Otherwise it uses nested dict syntax
if op == 'pullAll':
raise InvalidQueryError("pullAll operations only support "
"a single field depth")
raise InvalidQueryError('pullAll operations only support '
'a single field depth')
# Look for the last list field and use dot notation until there
field_classes = [c.__class__ for c in cleaned_fields]
@ -297,7 +297,7 @@ def update(_doc_cls=None, **update):
# Then process as normal
last_listField = len(
cleaned_fields) - field_classes.index(ListField)
key = ".".join(parts[:last_listField])
key = '.'.join(parts[:last_listField])
parts = parts[last_listField:]
parts.insert(0, key)
@ -305,7 +305,7 @@ def update(_doc_cls=None, **update):
for key in parts:
value = {key: value}
elif op == 'addToSet' and isinstance(value, list):
value = {key: {"$each": value}}
value = {key: {'$each': value}}
else:
value = {key: value}
key = '$' + op
@ -319,78 +319,82 @@ def update(_doc_cls=None, **update):
def _geo_operator(field, op, value):
"""Helper to return the query for a given geo query"""
if op == "max_distance":
"""Helper to return the query for a given geo query."""
if op == 'max_distance':
value = {'$maxDistance': value}
elif op == "min_distance":
elif op == 'min_distance':
value = {'$minDistance': value}
elif field._geo_index == pymongo.GEO2D:
if op == "within_distance":
if op == 'within_distance':
value = {'$within': {'$center': value}}
elif op == "within_spherical_distance":
elif op == 'within_spherical_distance':
value = {'$within': {'$centerSphere': value}}
elif op == "within_polygon":
elif op == 'within_polygon':
value = {'$within': {'$polygon': value}}
elif op == "near":
elif op == 'near':
value = {'$near': value}
elif op == "near_sphere":
elif op == 'near_sphere':
value = {'$nearSphere': value}
elif op == 'within_box':
value = {'$within': {'$box': value}}
else:
raise NotImplementedError("Geo method '%s' has not "
"been implemented for a GeoPointField" % op)
raise NotImplementedError('Geo method "%s" has not been '
'implemented for a GeoPointField' % op)
else:
if op == "geo_within":
value = {"$geoWithin": _infer_geometry(value)}
elif op == "geo_within_box":
value = {"$geoWithin": {"$box": value}}
elif op == "geo_within_polygon":
value = {"$geoWithin": {"$polygon": value}}
elif op == "geo_within_center":
value = {"$geoWithin": {"$center": value}}
elif op == "geo_within_sphere":
value = {"$geoWithin": {"$centerSphere": value}}
elif op == "geo_intersects":
value = {"$geoIntersects": _infer_geometry(value)}
elif op == "near":
if op == 'geo_within':
value = {'$geoWithin': _infer_geometry(value)}
elif op == 'geo_within_box':
value = {'$geoWithin': {'$box': value}}
elif op == 'geo_within_polygon':
value = {'$geoWithin': {'$polygon': value}}
elif op == 'geo_within_center':
value = {'$geoWithin': {'$center': value}}
elif op == 'geo_within_sphere':
value = {'$geoWithin': {'$centerSphere': value}}
elif op == 'geo_intersects':
value = {'$geoIntersects': _infer_geometry(value)}
elif op == 'near':
value = {'$near': _infer_geometry(value)}
else:
raise NotImplementedError("Geo method '%s' has not "
"been implemented for a %s " % (op, field._name))
raise NotImplementedError(
'Geo method "%s" has not been implemented for a %s '
% (op, field._name)
)
return value
def _infer_geometry(value):
"""Helper method that tries to infer the $geometry shape for a given value"""
"""Helper method that tries to infer the $geometry shape for a
given value.
"""
if isinstance(value, dict):
if "$geometry" in value:
if '$geometry' in value:
return value
elif 'coordinates' in value and 'type' in value:
return {"$geometry": value}
raise InvalidQueryError("Invalid $geometry dictionary should have "
"type and coordinates keys")
return {'$geometry': value}
raise InvalidQueryError('Invalid $geometry dictionary should have '
'type and coordinates keys')
elif isinstance(value, (list, set)):
# TODO: shouldn't we test value[0][0][0][0] to see if it is MultiPolygon?
# TODO: should both TypeError and IndexError be alike interpreted?
try:
value[0][0][0]
return {"$geometry": {"type": "Polygon", "coordinates": value}}
return {'$geometry': {'type': 'Polygon', 'coordinates': value}}
except (TypeError, IndexError):
pass
try:
value[0][0]
return {"$geometry": {"type": "LineString", "coordinates": value}}
return {'$geometry': {'type': 'LineString', 'coordinates': value}}
except (TypeError, IndexError):
pass
try:
value[0]
return {"$geometry": {"type": "Point", "coordinates": value}}
return {'$geometry': {'type': 'Point', 'coordinates': value}}
except (TypeError, IndexError):
pass
raise InvalidQueryError("Invalid $geometry data. Can be either a dictionary "
"or (nested) lists of coordinate(s)")
raise InvalidQueryError('Invalid $geometry data. Can be either a '
'dictionary or (nested) lists of coordinate(s)')

View File

@ -69,9 +69,9 @@ class QueryCompilerVisitor(QNodeVisitor):
self.document = document
def visit_combination(self, combination):
operator = "$and"
operator = '$and'
if combination.operation == combination.OR:
operator = "$or"
operator = '$or'
return {operator: combination.children}
def visit_query(self, query):
@ -79,8 +79,7 @@ class QueryCompilerVisitor(QNodeVisitor):
class QNode(object):
"""Base class for nodes in query trees.
"""
"""Base class for nodes in query trees."""
AND = 0
OR = 1
@ -94,7 +93,8 @@ class QNode(object):
raise NotImplementedError
def _combine(self, other, operation):
"""Combine this node with another node into a QCombination object.
"""Combine this node with another node into a QCombination
object.
"""
if getattr(other, 'empty', True):
return self
@ -116,8 +116,8 @@ class QNode(object):
class QCombination(QNode):
"""Represents the combination of several conditions by a given logical
operator.
"""Represents the combination of several conditions by a given
logical operator.
"""
def __init__(self, operation, children):