632 lines
24 KiB
Python
632 lines
24 KiB
Python
import operator
|
|
import warnings
|
|
import weakref
|
|
|
|
from bson import DBRef, ObjectId, SON
|
|
import pymongo
|
|
import six
|
|
|
|
from mongoengine.base.common import UPDATE_OPERATORS
|
|
from mongoengine.base.datastructures import (BaseDict, BaseList,
|
|
EmbeddedDocumentList)
|
|
from mongoengine.common import _import_class
|
|
from mongoengine.errors import ValidationError
|
|
|
|
|
|
# Public field base classes exported by this module.
__all__ = (
    'BaseField',
    'ComplexBaseField',
    'ObjectIdField',
    'GeoJsonBaseField',
)
|
|
|
|
|
|
class BaseField(object):
    """A base class for fields in a MongoDB document. Instances of this class
    may be added to subclasses of `Document` to define a document's schema.

    Fields are descriptors: reading the attribute from a document instance
    returns the value stored in ``instance._data`` under the field's name.

    .. versionchanged:: 0.5 - added verbose and help text
    """

    name = None
    _geo_index = False
    _auto_gen = False  # Call `generate` to generate a value
    _auto_dereference = True

    # These track each time a Field instance is created. Used to retain order.
    # The auto_creation_counter is used for fields that MongoEngine implicitly
    # creates, creation_counter is used for all user-specified fields.
    creation_counter = 0
    auto_creation_counter = -1

    def __init__(self, db_field=None, name=None, required=False, default=None,
                 unique=False, unique_with=None, primary_key=False,
                 validation=None, choices=None, null=False, sparse=False,
                 **kwargs):
        """
        :param db_field: The database field to store this field in
            (defaults to the name of the field)
        :param name: Deprecated - use db_field
        :param required: If the field is required. Whether it has to have a
            value or not. Defaults to False. A primary key is always required.
        :param default: (optional) The default value for this field if no value
            has been set (or if the value has been unset). It can be a
            callable.
        :param unique: Is the field value unique or not. Defaults to False.
            Implied when `unique_with` is given.
        :param unique_with: (optional) The other field this field should be
            unique with.
        :param primary_key: Mark this field as the primary key (it is then
            stored as ``_id``). Defaults to False.
        :param validation: (optional) A callable to validate the value of the
            field. Generally this is deprecated in favour of the
            `FIELD.validate` method
        :param choices: (optional) The valid choices. May be a sequence (or
            set) of values, or of ``(value, label)`` pairs.
        :param null: (optional) Whether the field value may be None. If not,
            and a default value exists, assigning None stores the default
            instead.
        :param sparse: (optional) `sparse=True` combined with `unique=True` and
            `required=False` means that uniqueness won't be enforced for
            `None` values
        :param **kwargs: (optional) Arbitrary indirection-free metadata for
            this field can be supplied as additional keyword arguments and
            accessed as attributes of the field. Must not conflict with any
            existing attributes. Common metadata includes `verbose_name` and
            `help_text`.
        :raises TypeError: if a metadata keyword clashes with an existing
            attribute of the field.
        """
        self.db_field = (db_field or name) if not primary_key else '_id'

        if name:
            msg = "Fields' 'name' attribute deprecated in favour of 'db_field'"
            warnings.warn(msg, DeprecationWarning)
        self.required = required or primary_key
        self.default = default
        self.unique = bool(unique or unique_with)
        self.unique_with = unique_with
        self.primary_key = primary_key
        self.validation = validation
        self.choices = choices
        self.null = null
        self.sparse = sparse
        self._owner_document = None

        # Detect and report conflicts between metadata and base properties.
        conflicts = set(dir(self)) & set(kwargs)
        if conflicts:
            # Sorted so the error message is deterministic (set order is not).
            raise TypeError("%s already has attribute(s): %s" % (
                self.__class__.__name__, ', '.join(sorted(conflicts))))

        # Assign metadata to the instance
        # This efficient method is available because no __slots__ are defined.
        self.__dict__.update(kwargs)

        # Adjust the appropriate creation counter, and save our local copy.
        # The implicit '_id' field counts downwards so it always sorts first.
        if self.db_field == '_id':
            self.creation_counter = BaseField.auto_creation_counter
            BaseField.auto_creation_counter -= 1
        else:
            self.creation_counter = BaseField.creation_counter
            BaseField.creation_counter += 1

    def __get__(self, instance, owner):
        """Descriptor for retrieving a value from a field in a document.

        Returns the field itself when accessed on the class, otherwise the
        value stored in ``instance._data`` (None when absent).
        """
        if instance is None:
            # Document class being used rather than a document object
            return self

        # Get value from document instance if available
        return instance._data.get(self.name)

    def __set__(self, instance, value):
        """Descriptor for assigning a value to a field in a document.

        Assigning None stores the field's default instead (calling it if it
        is callable) unless the field was declared with ``null=True``. Once
        the document is initialised, the field is marked as changed when the
        new value differs from the stored one, or when the two cannot be
        compared. Embedded documents (directly, or as items of a list/tuple)
        receive a weak proxy back to the owning instance.
        """

        # If setting to None and there is a default
        # Then set the value to the default value
        if value is None:
            if self.null:
                value = None
            elif self.default is not None:
                value = self.default
                if callable(value):
                    value = value()

        if instance._initialised:
            try:
                if (self.name not in instance._data or
                        instance._data[self.name] != value):
                    instance._mark_as_changed(self.name)
            except Exception:
                # Values cant be compared eg: naive and tz datetimes
                # So mark it as changed
                instance._mark_as_changed(self.name)

        EmbeddedDocument = _import_class('EmbeddedDocument')
        # A weak proxy avoids a strong reference cycle parent <-> child.
        if isinstance(value, EmbeddedDocument):
            value._instance = weakref.proxy(instance)
        elif isinstance(value, (list, tuple)):
            for v in value:
                if isinstance(v, EmbeddedDocument):
                    v._instance = weakref.proxy(instance)
        instance._data[self.name] = value

    def error(self, message="", errors=None, field_name=None):
        """Raises a ValidationError.

        :param field_name: reported field name; defaults to ``self.name``.
        """
        field_name = field_name if field_name else self.name
        raise ValidationError(message, errors=errors, field_name=field_name)

    def to_python(self, value):
        """Convert a MongoDB-compatible type to a Python type.
        """
        return value

    def to_mongo(self, value):
        """Convert a Python type to a MongoDB-compatible type.
        """
        return self.to_python(value)

    def _to_mongo_safe_call(self, value, use_db_field=True, fields=None):
        """A helper method to call to_mongo with proper inputs.

        Subclasses override ``to_mongo`` with differing signatures, so the
        optional ``use_db_field`` / ``fields`` keywords are forwarded only
        when the override actually declares them as parameters.
        """
        code = self.to_mongo.__code__
        # co_varnames lists the parameters first, followed by every other
        # local variable. Slice to the parameters only, so a local that merely
        # happens to be named ``fields`` is not mistaken for an argument.
        n_params = code.co_argcount + getattr(code, 'co_kwonlyargcount', 0)
        f_inputs = code.co_varnames[:n_params]
        ex_vars = {}
        if 'fields' in f_inputs:
            ex_vars['fields'] = fields

        if 'use_db_field' in f_inputs:
            ex_vars['use_db_field'] = use_db_field

        return self.to_mongo(value, **ex_vars)

    def prepare_query_value(self, op, value):
        """Prepare a value that is being used in a query for PyMongo.

        Values used with update operators are validated first.
        """
        if op in UPDATE_OPERATORS:
            self.validate(value)
        return value

    def validate(self, value, clean=True):
        """Perform validation on a value.
        """
        pass

    def _validate_choices(self, value):
        """Raise a ValidationError unless ``value`` is one of ``self.choices``.

        Choices given as ``(value, label)`` pairs are compared on the value
        part. Document / EmbeddedDocument values must be instances of one of
        the choice classes.
        """
        Document = _import_class('Document')
        EmbeddedDocument = _import_class('EmbeddedDocument')

        choice_list = self.choices
        # next(iter(...)) instead of [0] so unordered containers such as sets
        # are accepted too.
        if isinstance(next(iter(choice_list)), (list, tuple)):
            choice_list = [k for k, _ in choice_list]

        # Choices which are other types of Documents
        if isinstance(value, (Document, EmbeddedDocument)):
            if not any(isinstance(value, c) for c in choice_list):
                self.error(
                    'Value must be an instance of %s' % (
                        six.text_type(choice_list)
                    )
                )
        # Choices which are types other than Documents
        elif value not in choice_list:
            self.error('Value must be one of %s' % six.text_type(choice_list))

    def _validate(self, value, **kwargs):
        """Run the choices check, the custom ``validation`` callable and then
        the field's own ``validate`` method.

        :raises ValueError: if ``validation`` was given but is not callable.
        """
        # Check the Choices Constraint
        if self.choices:
            self._validate_choices(value)

        # check validation argument
        if self.validation is not None:
            if callable(self.validation):
                if not self.validation(value):
                    self.error('Value does not match custom validation method')
            else:
                raise ValueError('validation argument for "%s" must be a '
                                 'callable.' % self.name)

        self.validate(value, **kwargs)

    @property
    def owner_document(self):
        """The document class this field belongs to."""
        return self._owner_document

    def _set_owner_document(self, owner_document):
        # Hook so subclasses can propagate the owner to nested fields.
        self._owner_document = owner_document

    @owner_document.setter
    def owner_document(self, owner_document):
        self._set_owner_document(owner_document)
|
|
|
|
|
|
class ComplexBaseField(BaseField):
    """Handles complex fields, such as lists / dictionaries.

    Allows for nesting of embedded documents inside complex types.
    Handles the lazy dereferencing of a queryset by lazily dereferencing all
    items in a list / dict rather than one at a time.

    .. versionadded:: 0.5
    """

    # Optional inner field describing each item of the list / dict. When
    # None, items are converted generically.
    field = None

    def __get__(self, instance, owner):
        """Descriptor to automatically dereference references.

        Wraps raw lists / dicts stored on the instance in the change-tracking
        ``EmbeddedDocumentList`` / ``BaseList`` / ``BaseDict`` containers and,
        when auto-dereferencing is enabled on an initialised instance,
        dereferences the container once (flagging it ``_dereferenced``).
        """
        if instance is None:
            # Document class being used rather than a document object
            return self

        ReferenceField = _import_class('ReferenceField')
        GenericReferenceField = _import_class('GenericReferenceField')
        EmbeddedDocumentListField = _import_class('EmbeddedDocumentListField')
        dereference = (self._auto_dereference and
                       (self.field is None or isinstance(self.field,
                        (GenericReferenceField, ReferenceField))))

        _dereference = _import_class("DeReference")()

        self._auto_dereference = instance._fields[self.name]._auto_dereference
        if instance._initialised and dereference and instance._data.get(self.name):
            instance._data[self.name] = _dereference(
                instance._data.get(self.name), max_depth=1, instance=instance,
                name=self.name
            )

        value = super(ComplexBaseField, self).__get__(instance, owner)

        # Convert lists / values so we can watch for any changes on them
        if isinstance(value, (list, tuple)):
            if (issubclass(type(self), EmbeddedDocumentListField) and
                    not isinstance(value, EmbeddedDocumentList)):
                value = EmbeddedDocumentList(value, instance, self.name)
            elif not isinstance(value, BaseList):
                value = BaseList(value, instance, self.name)
            instance._data[self.name] = value
        elif isinstance(value, dict) and not isinstance(value, BaseDict):
            value = BaseDict(value, instance, self.name)
            instance._data[self.name] = value

        if (self._auto_dereference and instance._initialised and
                isinstance(value, (BaseList, BaseDict)) and
                not value._dereferenced):
            value = _dereference(
                value, max_depth=1, instance=instance, name=self.name
            )
            value._dereferenced = True
            instance._data[self.name] = value

        return value

    def to_python(self, value):
        """Convert a MongoDB-compatible type to a Python type.

        Strings are returned unchanged and objects exposing ``to_python`` are
        delegated to. Other iterables are converted item by item (through
        ``self.field`` when set). Non-mapping iterables come back as a list in
        their original order, mappings as a dict. Non-iterables are returned
        unchanged.
        """
        if isinstance(value, six.string_types):
            return value

        if hasattr(value, 'to_python'):
            return value.to_python()

        is_list = False
        if not hasattr(value, 'items'):
            try:
                is_list = True
                # Index the items so lists and dicts share the code below.
                value = dict(enumerate(value))
            except TypeError:  # Not iterable return the value
                return value

        if self.field:
            self.field._auto_dereference = self._auto_dereference
            value_dict = {key: self.field.to_python(item)
                          for key, item in value.items()}
        else:
            Document = _import_class('Document')
            value_dict = {}
            for k, v in value.items():
                if isinstance(v, Document):
                    # We need the id from the saved object to create the DBRef
                    if v.pk is None:
                        self.error('You can only reference documents once they'
                                   ' have been saved to the database')
                    collection = v._get_collection_name()
                    value_dict[k] = DBRef(collection, v.pk)
                elif hasattr(v, 'to_python'):
                    value_dict[k] = v.to_python()
                else:
                    value_dict[k] = self.to_python(v)

        if is_list:  # Convert back to a list
            return [v for _, v in sorted(value_dict.items(),
                                         key=operator.itemgetter(0))]
        return value_dict

    def to_mongo(self, value, use_db_field=True, fields=None):
        """Convert a Python type to a MongoDB-compatible type.

        Documents become generic references or DBRefs, embedded documents
        are serialised with a ``_cls`` key, and lists / dicts are converted
        item by item (through ``self.field`` when set).
        """
        Document = _import_class("Document")
        EmbeddedDocument = _import_class("EmbeddedDocument")
        GenericReferenceField = _import_class("GenericReferenceField")

        if isinstance(value, six.string_types):
            return value

        if hasattr(value, 'to_mongo'):
            if isinstance(value, Document):
                return GenericReferenceField().to_mongo(value)
            cls = value.__class__
            val = value.to_mongo(use_db_field, fields)
            # If it's a document that is not inherited add _cls
            if isinstance(value, EmbeddedDocument):
                val['_cls'] = cls.__name__
            return val

        is_list = False
        if not hasattr(value, 'items'):
            try:
                is_list = True
                # Index the items so lists and dicts share the code below.
                value = dict(enumerate(value))
            except TypeError:  # Not iterable return the value
                return value

        # .items() (not the Python 2 only .iteritems()) so this also works on
        # Python 3 dicts and on mappings that only implement items().
        if self.field:
            value_dict = {
                key: self.field._to_mongo_safe_call(item, use_db_field, fields)
                for key, item in value.items()
            }
        else:
            value_dict = {}
            for k, v in value.items():
                if isinstance(v, Document):
                    # We need the id from the saved object to create the DBRef
                    if v.pk is None:
                        self.error('You can only reference documents once they'
                                   ' have been saved to the database')

                    # If its a document that is not inheritable it won't have
                    # any _cls data so make it a generic reference allows
                    # us to dereference. (self.field is always unset in this
                    # branch, so only allow_inheritance decides.)
                    meta = getattr(v, '_meta', {})
                    allow_inheritance = meta.get('allow_inheritance')
                    if not allow_inheritance:
                        value_dict[k] = GenericReferenceField().to_mongo(v)
                    else:
                        collection = v._get_collection_name()
                        value_dict[k] = DBRef(collection, v.pk)
                elif hasattr(v, 'to_mongo'):
                    cls = v.__class__
                    val = v.to_mongo(use_db_field, fields)
                    # If it's a document that is not inherited add _cls
                    if isinstance(v, (Document, EmbeddedDocument)):
                        val['_cls'] = cls.__name__
                    value_dict[k] = val
                else:
                    value_dict[k] = self.to_mongo(v, use_db_field, fields)

        if is_list:  # Convert back to a list
            return [v for _, v in sorted(value_dict.items(),
                                         key=operator.itemgetter(0))]
        return value_dict

    def validate(self, value):
        """If field is provided ensure the value is valid.

        Every item (mapping values, or sequence elements) is validated with
        ``self.field``. Failures are collected per key / index and reported
        in a single ValidationError. A required field must also be non-empty.
        """
        errors = {}
        if self.field:
            # .items() works on both Python 2 and 3; the previous
            # .iteritems() call failed for objects that only provide items().
            if hasattr(value, 'iteritems') or hasattr(value, 'items'):
                sequence = value.items()
            else:
                sequence = enumerate(value)
            for k, v in sequence:
                try:
                    self.field._validate(v)
                except ValidationError as error:
                    errors[k] = error.errors or error
                except (ValueError, AssertionError) as error:
                    errors[k] = error

            if errors:
                field_class = self.field.__class__.__name__
                self.error('Invalid %s item (%s)' % (field_class, value),
                           errors=errors)
        # Don't allow empty values if required
        if self.required and not value:
            self.error('Field is required and cannot be empty')

    def prepare_query_value(self, op, value):
        """Convert a query value with ``to_mongo``."""
        return self.to_mongo(value)

    def lookup_member(self, member_name):
        """Delegate member lookup to the inner field, if any."""
        if self.field:
            return self.field.lookup_member(member_name)
        return None

    def _set_owner_document(self, owner_document):
        # Propagate the owner to the inner field as well.
        if self.field:
            self.field.owner_document = owner_document
        self._owner_document = owner_document
|
|
|
|
|
|
class ObjectIdField(BaseField):
    """Field storing a MongoDB ``ObjectId``.

    Values are coerced to ``ObjectId`` on the way in and out of the
    database; strings and other convertible values are accepted.
    """

    def to_python(self, value):
        """Best-effort coercion to ``ObjectId``.

        Anything that cannot be converted is handed back unchanged.
        """
        try:
            return value if isinstance(value, ObjectId) else ObjectId(value)
        except Exception:
            return value

    def to_mongo(self, value):
        """Return ``value`` as an ``ObjectId``, raising a ValidationError
        (through ``self.error``) when it cannot be converted."""
        if isinstance(value, ObjectId):
            return value
        try:
            return ObjectId(six.text_type(value))
        except Exception as exc:
            # e.message attribute has been deprecated since Python 2.6
            self.error(six.text_type(exc))
        return value

    def prepare_query_value(self, op, value):
        """Query values are converted exactly like stored values."""
        return self.to_mongo(value)

    def validate(self, value):
        """Fail validation unless the text form of ``value`` parses as an
        ObjectId."""
        try:
            ObjectId(six.text_type(value))
        except Exception:
            self.error('Invalid Object ID')
|
|
|
|
|
|
class GeoJsonBaseField(BaseField):
    """A geo json field storing a geojson style object.

    Subclasses set ``_type`` to a GeoJSON type name. Validation of the
    coordinates is dispatched to the ``_validate_<type>`` method whose name
    matches ``_type.lower()``.

    .. versionadded:: 0.8
    """

    _geo_index = pymongo.GEOSPHERE
    _type = "GeoBase"

    def __init__(self, auto_index=True, *args, **kwargs):
        """
        :param bool auto_index: Automatically create a "2dsphere" index.
            Defaults to `True`.
        """
        self._name = "%sField" % self._type
        if not auto_index:
            self._geo_index = False
        super(GeoJsonBaseField, self).__init__(*args, **kwargs)

    def validate(self, value):
        """Validate the GeoJson object based on its type.

        Accepts either a dict with exactly the keys ``type`` and
        ``coordinates`` (whose ``type`` must equal ``self._type``; the
        coordinates are then validated recursively) or a bare list / tuple
        of coordinates. Raises a ValidationError on failure.
        """
        if isinstance(value, dict):
            if set(value.keys()) == set(['type', 'coordinates']):
                if value['type'] != self._type:
                    self.error('%s type must be "%s"' %
                               (self._name, self._type))
                return self.validate(value['coordinates'])
            else:
                self.error('%s can only accept a valid GeoJson dictionary'
                           ' or lists of (x, y)' % self._name)
                return
        elif not isinstance(value, (list, tuple)):
            self.error('%s can only accept lists of [x, y]' % self._name)
            return

        # Each _validate_* helper returns an error message, or None if valid.
        validate = getattr(self, "_validate_%s" % self._type.lower())
        error = validate(value)
        if error:
            self.error(error)

    def _validate_polygon(self, value, top_level=True):
        """Return an error message for an invalid polygon, else None.

        A polygon is a list of linestrings, each of which must start and end
        at the same point. ``top_level`` adds the "Invalid Polygon:" prefix.
        """
        if not isinstance(value, (list, tuple)):
            return 'Polygons must contain list of linestrings'

        # Quick and dirty validator
        try:
            value[0][0][0]
        except (TypeError, IndexError):
            return "Invalid Polygon must contain at least one valid linestring"

        errors = []
        for val in value:
            error = self._validate_linestring(val, False)
            if not error and val[0] != val[-1]:
                error = 'LineStrings must start and end at the same point'
            if error and error not in errors:
                errors.append(error)
        if errors:
            if top_level:
                return "Invalid Polygon:\n%s" % ", ".join(errors)
            else:
                return "%s" % ", ".join(errors)

    def _validate_linestring(self, value, top_level=True):
        """Validates a linestring (a list of points).

        Returns an error message, or None when valid. ``top_level`` adds the
        "Invalid LineString:" prefix.
        """
        if not isinstance(value, (list, tuple)):
            return 'LineStrings must contain list of coordinate pairs'

        # Quick and dirty validator
        try:
            value[0][0]
        except (TypeError, IndexError):
            return "Invalid LineString must contain at least one valid point"

        errors = []
        for val in value:
            error = self._validate_point(val)
            if error and error not in errors:
                errors.append(error)
        if errors:
            if top_level:
                return "Invalid LineString:\n%s" % ", ".join(errors)
            else:
                return "%s" % ", ".join(errors)

    def _validate_point(self, value):
        """Validate each set of coords.

        A point must be a two-element list / tuple of numbers. Returns an
        error message, or None when valid.
        """
        # six.integer_types includes Python 2's ``long`` (e.g. int64 values
        # decoded from BSON); on Python 3 it is just ``int``.
        number_types = six.integer_types + (float,)
        if not isinstance(value, (list, tuple)):
            return 'Points must be a list of coordinate pairs'
        elif not len(value) == 2:
            return "Value (%s) must be a two-dimensional point" % repr(value)
        elif (not isinstance(value[0], number_types) or
              not isinstance(value[1], number_types)):
            return "Both values (%s) in point must be float or int" % repr(value)

    def _validate_multipoint(self, value):
        """Return an error message for an invalid list of points, else None."""
        if not isinstance(value, (list, tuple)):
            return 'MultiPoint must be a list of Point'

        # Quick and dirty validator
        try:
            value[0][0]
        except (TypeError, IndexError):
            return "Invalid MultiPoint must contain at least one valid point"

        errors = []
        for point in value:
            error = self._validate_point(point)
            if error and error not in errors:
                errors.append(error)

        if errors:
            return "%s" % ", ".join(errors)

    def _validate_multilinestring(self, value, top_level=True):
        """Return an error message for an invalid list of linestrings, else
        None. ``top_level`` adds the "Invalid MultiLineString:" prefix."""
        if not isinstance(value, (list, tuple)):
            return 'MultiLineString must be a list of LineString'

        # Quick and dirty validator
        try:
            value[0][0][0]
        except (TypeError, IndexError):
            return "Invalid MultiLineString must contain at least one valid linestring"

        errors = []
        for linestring in value:
            error = self._validate_linestring(linestring, False)
            if error and error not in errors:
                errors.append(error)

        if errors:
            if top_level:
                return "Invalid MultiLineString:\n%s" % ", ".join(errors)
            else:
                return "%s" % ", ".join(errors)

    def _validate_multipolygon(self, value):
        """Return an error message for an invalid list of polygons, else
        None."""
        if not isinstance(value, (list, tuple)):
            return 'MultiPolygon must be a list of Polygon'

        # Quick and dirty validator
        try:
            value[0][0][0][0]
        except (TypeError, IndexError):
            return "Invalid MultiPolygon must contain at least one valid Polygon"

        errors = []
        for polygon in value:
            error = self._validate_polygon(polygon, False)
            if error and error not in errors:
                errors.append(error)

        if errors:
            return "Invalid MultiPolygon:\n%s" % ", ".join(errors)

    def to_mongo(self, value):
        """Dicts are stored as-is; bare coordinates are wrapped in an ordered
        ``{"type": ..., "coordinates": ...}`` SON document."""
        if isinstance(value, dict):
            return value
        return SON([("type", self._type), ("coordinates", value)])
|