Merge pull request #1 from MongoEngine/master

Pull new changes from the original repository
This commit is contained in:
iici-gli
2016-09-04 23:43:04 -04:00
committed by GitHub
37 changed files with 1228 additions and 488 deletions

View File

@@ -14,7 +14,7 @@ import errors
__all__ = (list(document.__all__) + fields.__all__ + connection.__all__ +
list(queryset.__all__) + signals.__all__ + list(errors.__all__))
VERSION = (0, 10, 0)
VERSION = (0, 10, 6)
def get_version():

View File

@@ -199,7 +199,8 @@ class BaseList(list):
def _mark_as_changed(self, key=None):
if hasattr(self._instance, '_mark_as_changed'):
if key:
self._instance._mark_as_changed('%s.%s' % (self._name, key))
self._instance._mark_as_changed('%s.%s' % (self._name,
key % len(self)))
else:
self._instance._mark_as_changed(self._name)
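A quick sketch (outside the diff) of what the `key % len(self)` normalization buys: a negative list index maps to its positive equivalent, so a change made through either form is tracked under the same key. The names below are illustrative.

tags = ['a', 'b', 'c']
idx = -1
# 2 -- so doc.tags[-1] and doc.tags[2] both mark 'tags.2'
print(idx % len(tags))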
@@ -210,7 +211,7 @@ class EmbeddedDocumentList(BaseList):
def __match_all(cls, i, kwargs):
items = kwargs.items()
return all([
getattr(i, k) == v or str(getattr(i, k)) == v for k, v in items
getattr(i, k) == v or unicode(getattr(i, k)) == v for k, v in items
])
@classmethod

View File

@@ -51,7 +51,7 @@ class BaseDocument(object):
# We only want named arguments.
field = iter(self._fields_ordered)
# If its an automatic id field then skip to the first defined field
if self._auto_id_field:
if getattr(self, '_auto_id_field', False):
next(field)
for value in args:
name = next(field)
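The `getattr` with a default is the usual guard for instances whose class was built before `_auto_id_field` existed; a minimal sketch with a made-up class:

class LegacyDocument(object):
    pass

# False instead of AttributeError when _auto_id_field was never set
print(getattr(LegacyDocument(), '_auto_id_field', False))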
@@ -325,20 +325,17 @@ class BaseDocument(object):
if value is not None:
if isinstance(field, EmbeddedDocumentField):
if fields:
key = '%s.' % field_name
embedded_fields = [
i.replace(key, '') for i in fields
if i.startswith(key)]
if fields:
key = '%s.' % field_name
embedded_fields = [
i.replace(key, '') for i in fields
if i.startswith(key)]
else:
embedded_fields = []
value = field.to_mongo(value, use_db_field=use_db_field,
fields=embedded_fields)
else:
value = field.to_mongo(value)
embedded_fields = []
value = field.to_mongo(value, use_db_field=use_db_field,
fields=embedded_fields)
# Handle self generating fields
if value is None and field._auto_gen:
@@ -835,10 +832,6 @@ class BaseDocument(object):
if index_list:
spec['fields'] = index_list
if spec.get('sparse', False) and len(spec['fields']) > 1:
raise ValueError(
'Sparse indexes can only have one field in them. '
'See https://jira.mongodb.org/browse/SERVER-2193')
return spec
@@ -974,7 +967,7 @@ class BaseDocument(object):
if hasattr(getattr(field, 'field', None), 'lookup_member'):
new_field = field.field.lookup_member(field_name)
elif cls._dynamic and (isinstance(field, DynamicField) or
getattr(getattr(field, 'document_type'), '_dynamic')):
getattr(getattr(field, 'document_type', None), '_dynamic', None)):
new_field = DynamicField(db_field=field_name)
else:
# Look up subfield on the previous field or raise

View File

@@ -41,8 +41,8 @@ class BaseField(object):
def __init__(self, db_field=None, name=None, required=False, default=None,
unique=False, unique_with=None, primary_key=False,
validation=None, choices=None, verbose_name=None,
help_text=None, null=False, sparse=False, custom_data=None):
validation=None, choices=None, null=False, sparse=False,
**kwargs):
"""
:param db_field: The database field to store this field in
(defaults to the name of the field)
@@ -60,16 +60,15 @@ class BaseField(object):
field. Generally this is deprecated in favour of the
`FIELD.validate` method
:param choices: (optional) The valid choices
:param verbose_name: (optional) The verbose name for the field.
Designed to be human readable and is often used when generating
model forms from the document model.
:param help_text: (optional) The help text for this field and is often
used when generating model forms from the document model.
:param null: (optional) Whether the field value can be null. If not and
there is a default value, then the default value is set
:param sparse: (optional) `sparse=True` combined with `unique=True` and `required=False`
means that uniqueness won't be enforced for `None` values
:param custom_data: (optional) Custom metadata for this field.
:param **kwargs: (optional) Arbitrary indirection-free metadata for
this field can be supplied as additional keyword arguments and
accessed as attributes of the field. Must not conflict with any
existing attributes. Common metadata includes `verbose_name` and
`help_text`.
"""
self.db_field = (db_field or name) if not primary_key else '_id'
@@ -83,12 +82,19 @@ class BaseField(object):
self.primary_key = primary_key
self.validation = validation
self.choices = choices
self.verbose_name = verbose_name
self.help_text = help_text
self.null = null
self.sparse = sparse
self._owner_document = None
self.custom_data = custom_data
# Detect and report conflicts between metadata and base properties.
conflicts = set(dir(self)) & set(kwargs)
if conflicts:
raise TypeError("%s already has attribute(s): %s" % (
self.__class__.__name__, ', '.join(conflicts)))
# Assign metadata to the instance
# This efficient method is available because no __slots__ are defined.
self.__dict__.update(kwargs)
# Adjust the appropriate creation counter, and save our local copy.
if self.db_field == '_id':
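With this change, ad-hoc metadata such as `verbose_name` and `help_text` travels through `**kwargs` and is read back as plain attributes, while a clash with an existing attribute raises `TypeError`. A minimal sketch; the metadata values are illustrative:

from mongoengine import StringField

name = StringField(verbose_name='Full name', help_text='Legal name')
print(name.verbose_name)  # 'Full name'
print(name.help_text)     # 'Legal name'

# StringField(error='boom') would raise TypeError, because `error` is
# already a method on BaseField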
@@ -127,7 +133,7 @@ class BaseField(object):
if (self.name not in instance._data or
instance._data[self.name] != value):
instance._mark_as_changed(self.name)
except:
except Exception:
# Values can't be compared, e.g. naive and tz-aware datetimes
# So mark it as changed
instance._mark_as_changed(self.name)
@@ -135,6 +141,10 @@ class BaseField(object):
EmbeddedDocument = _import_class('EmbeddedDocument')
if isinstance(value, EmbeddedDocument):
value._instance = weakref.proxy(instance)
elif isinstance(value, (list, tuple)):
for v in value:
if isinstance(v, EmbeddedDocument):
v._instance = weakref.proxy(instance)
instance._data[self.name] = value
def error(self, message="", errors=None, field_name=None):
@@ -148,7 +158,7 @@ class BaseField(object):
"""
return value
def to_mongo(self, value):
def to_mongo(self, value, **kwargs):
"""Convert a Python type to a MongoDB-compatible type.
"""
return self.to_python(value)
@@ -275,8 +285,6 @@ class ComplexBaseField(BaseField):
def to_python(self, value):
"""Convert a MongoDB-compatible type to a Python type.
"""
Document = _import_class('Document')
if isinstance(value, basestring):
return value
@@ -296,6 +304,7 @@ class ComplexBaseField(BaseField):
value_dict = dict([(key, self.field.to_python(item))
for key, item in value.items()])
else:
Document = _import_class('Document')
value_dict = {}
for k, v in value.items():
if isinstance(v, Document):
@@ -315,7 +324,7 @@ class ComplexBaseField(BaseField):
key=operator.itemgetter(0))]
return value_dict
def to_mongo(self, value):
def to_mongo(self, value, **kwargs):
"""Convert a Python type to a MongoDB-compatible type.
"""
Document = _import_class("Document")
@@ -327,9 +336,10 @@ class ComplexBaseField(BaseField):
if hasattr(value, 'to_mongo'):
if isinstance(value, Document):
return GenericReferenceField().to_mongo(value)
return GenericReferenceField().to_mongo(
value, **kwargs)
cls = value.__class__
val = value.to_mongo()
val = value.to_mongo(**kwargs)
# If it's a document that is not inherited add _cls
if isinstance(value, EmbeddedDocument):
val['_cls'] = cls.__name__
@@ -344,7 +354,7 @@ class ComplexBaseField(BaseField):
return value
if self.field:
value_dict = dict([(key, self.field.to_mongo(item))
value_dict = dict([(key, self.field.to_mongo(item, **kwargs))
for key, item in value.iteritems()])
else:
value_dict = {}
@@ -363,19 +373,20 @@ class ComplexBaseField(BaseField):
meta.get('allow_inheritance', ALLOW_INHERITANCE)
is True)
if not allow_inheritance and not self.field:
value_dict[k] = GenericReferenceField().to_mongo(v)
value_dict[k] = GenericReferenceField().to_mongo(
v, **kwargs)
else:
collection = v._get_collection_name()
value_dict[k] = DBRef(collection, v.pk)
elif hasattr(v, 'to_mongo'):
cls = v.__class__
val = v.to_mongo()
val = v.to_mongo(**kwargs)
# If it's a document that is not inherited add _cls
if isinstance(v, (Document, EmbeddedDocument)):
val['_cls'] = cls.__name__
value_dict[k] = val
else:
value_dict[k] = self.to_mongo(v)
value_dict[k] = self.to_mongo(v, **kwargs)
if is_list: # Convert back to a list
return [v for _, v in sorted(value_dict.items(),
@@ -429,11 +440,11 @@ class ObjectIdField(BaseField):
try:
if not isinstance(value, ObjectId):
value = ObjectId(value)
except:
except Exception:
pass
return value
def to_mongo(self, value):
def to_mongo(self, value, **kwargs):
if not isinstance(value, ObjectId):
try:
return ObjectId(unicode(value))
@@ -448,7 +459,7 @@ class ObjectIdField(BaseField):
def validate(self, value):
try:
ObjectId(unicode(value))
except:
except Exception:
self.error('Invalid Object ID')
@@ -500,7 +511,7 @@ class GeoJsonBaseField(BaseField):
# Quick and dirty validator
try:
value[0][0][0]
except:
except (TypeError, IndexError):
return "Invalid Polygon must contain at least one valid linestring"
errors = []
@@ -524,7 +535,7 @@ class GeoJsonBaseField(BaseField):
# Quick and dirty validator
try:
value[0][0]
except:
except (TypeError, IndexError):
return "Invalid LineString must contain at least one valid point"
errors = []
@@ -555,7 +566,7 @@ class GeoJsonBaseField(BaseField):
# Quick and dirty validator
try:
value[0][0]
except:
except (TypeError, IndexError):
return "Invalid MultiPoint must contain at least one valid point"
errors = []
@@ -574,7 +585,7 @@ class GeoJsonBaseField(BaseField):
# Quick and dirty validator
try:
value[0][0][0]
except:
except (TypeError, IndexError):
return "Invalid MultiLineString must contain at least one valid linestring"
errors = []
@@ -596,7 +607,7 @@ class GeoJsonBaseField(BaseField):
# Quick and dirty validator
try:
value[0][0][0][0]
except:
except (TypeError, IndexError):
return "Invalid MultiPolygon must contain at least one valid Polygon"
errors = []
@@ -608,7 +619,7 @@ class GeoJsonBaseField(BaseField):
if errors:
return "Invalid MultiPolygon:\n%s" % ", ".join(errors)
def to_mongo(self, value):
def to_mongo(self, value, **kwargs):
if isinstance(value, dict):
return value
return SON([("type", self._type), ("coordinates", value)])

View File

@@ -38,8 +38,11 @@ def register_connection(alias, name=None, host=None, port=None,
:param username: username to authenticate with
:param password: password to authenticate with
:param authentication_source: database to authenticate against
:param is_mock: explicitly use mongomock for this connection
(can also be done by using `mongomock://` as db host prefix)
:param kwargs: allow ad-hoc parameters to be passed into the pymongo driver
.. versionchanged:: 0.10.6 - added mongomock support
"""
global _connection_settings
@@ -54,8 +57,13 @@ def register_connection(alias, name=None, host=None, port=None,
}
# Handle uri style connections
if "://" in conn_settings['host']:
uri_dict = uri_parser.parse_uri(conn_settings['host'])
conn_host = conn_settings['host']
if conn_host.startswith('mongomock://'):
conn_settings['is_mock'] = True
# `mongomock://` is not a valid url prefix and must be replaced by `mongodb://`
conn_settings['host'] = conn_host.replace('mongomock://', 'mongodb://', 1)
elif '://' in conn_host:
uri_dict = uri_parser.parse_uri(conn_host)
conn_settings.update({
'name': uri_dict.get('database') or name,
'username': uri_dict.get('username'),
@@ -106,7 +114,19 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
conn_settings.pop('password', None)
conn_settings.pop('authentication_source', None)
connection_class = MongoClient
is_mock = conn_settings.pop('is_mock', None)
if is_mock:
# Use MongoClient from mongomock
try:
import mongomock
except ImportError:
raise RuntimeError('You need mongomock installed '
'to mock MongoEngine.')
connection_class = mongomock.MongoClient
else:
# Use MongoClient from pymongo
connection_class = MongoClient
if 'replicaSet' in conn_settings:
# Discard port since it can't be used on MongoReplicaSetClient
conn_settings.pop('port', None)
@@ -126,6 +146,7 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
connection_settings.pop('name', None)
connection_settings.pop('username', None)
connection_settings.pop('password', None)
connection_settings.pop('authentication_source', None)
if conn_settings == connection_settings and _connections.get(db_alias, None):
connection = _connections[db_alias]
break
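Putting the two connection changes together, either of the following should yield a mongomock-backed connection (a sketch assuming mongomock is installed; the database name is illustrative):

from mongoengine import connect

connect('testdb', host='mongomock://localhost')
# or, equivalently, via the explicit flag:
connect('testdb', host='localhost', is_mock=True)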

View File

@@ -1,5 +1,7 @@
from bson import DBRef, SON
from mongoengine.python_support import txt_type
from base import (
BaseDict, BaseList, EmbeddedDocumentList,
TopLevelDocumentMetaclass, get_document
@@ -226,7 +228,7 @@ class DeReference(object):
data[k]._data[field_name] = self.object_map.get(
(v['_ref'].collection, v['_ref'].id), v)
elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth:
item_name = "{0}.{1}.{2}".format(name, k, field_name)
item_name = txt_type("{0}.{1}.{2}").format(name, k, field_name)
data[k]._data[field_name] = self._attach_objects(v, depth, instance=instance, name=item_name)
elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth:
item_name = '%s.%s' % (name, k) if name else name

View File

@@ -217,7 +217,7 @@ class Document(BaseDocument):
Returns True if the document has been updated or False if the document
in the database doesn't match the query.
.. note:: All unsaved changes that has been made to the document are
.. note:: All unsaved changes that have been made to the document are
rejected if the method returns True.
:param query: the update will be performed only if the document in the
@@ -250,7 +250,7 @@ class Document(BaseDocument):
def save(self, force_insert=False, validate=True, clean=True,
write_concern=None, cascade=None, cascade_kwargs=None,
_refs=None, save_condition=None, **kwargs):
_refs=None, save_condition=None, signal_kwargs=None, **kwargs):
"""Save the :class:`~mongoengine.Document` to the database. If the
document already exists, it will be updated, otherwise it will be
created.
@@ -276,6 +276,8 @@ class Document(BaseDocument):
:param save_condition: only perform save if matching record in db
satisfies condition(s) (e.g. version number).
Raises :class:`OperationError` if the conditions are not satisfied
:param signal_kwargs: (optional) kwargs dictionary to be passed to
the signal calls.
.. versionchanged:: 0.5
In existing documents it only saves changed fields using
@@ -297,8 +299,11 @@ class Document(BaseDocument):
:class:`OperationError` exception raised if save_condition fails.
.. versionchanged:: 0.10.1
:class: save_condition failure now raises a `SaveConditionError`
.. versionchanged:: 0.10.7
Add signal_kwargs argument
"""
signals.pre_save.send(self.__class__, document=self)
signal_kwargs = signal_kwargs or {}
signals.pre_save.send(self.__class__, document=self, **signal_kwargs)
if validate:
self.validate(clean=clean)
@@ -311,7 +316,7 @@ class Document(BaseDocument):
created = ('_id' not in doc or self._created or force_insert)
signals.pre_save_post_validation.send(self.__class__, document=self,
created=created)
created=created, **signal_kwargs)
try:
collection = self._get_collection()
@@ -341,8 +346,12 @@ class Document(BaseDocument):
select_dict['_id'] = object_id
shard_key = self.__class__._meta.get('shard_key', tuple())
for k in shard_key:
actual_key = self._db_field_map.get(k, k)
select_dict[actual_key] = doc[actual_key]
path = self._lookup_field(k.split('.'))
actual_key = [p.db_field for p in path]
val = doc
for ak in actual_key:
val = val[ak]
select_dict['.'.join(actual_key)] = val
def is_new_object(last_error):
if last_error is not None:
@@ -396,14 +405,15 @@ class Document(BaseDocument):
if created or id_field not in self._meta.get('shard_key', []):
self[id_field] = self._fields[id_field].to_python(object_id)
signals.post_save.send(self.__class__, document=self, created=created)
signals.post_save.send(self.__class__, document=self,
created=created, **signal_kwargs)
self._clear_changed_fields()
self._created = False
return self
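A sketch of what `signal_kwargs` enables, assuming blinker is installed (the document class and the `request_id` key are illustrative): extra context flows from the `save()` call into every connected receiver.

from mongoengine import Document, StringField, signals

class Post(Document):
    title = StringField()

def on_post_save(sender, document, **kwargs):
    print('saved %s (request %s)' % (document.title, kwargs.get('request_id')))

signals.post_save.connect(on_post_save, sender=Post)
Post(title='hello').save(signal_kwargs={'request_id': 'abc123'})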
def cascade_save(self, *args, **kwargs):
"""Recursively saves any references /
generic references on an objects"""
generic references on the document"""
_refs = kwargs.get('_refs', []) or []
ReferenceField = _import_class('ReferenceField')
@@ -444,7 +454,12 @@ class Document(BaseDocument):
select_dict = {'pk': self.pk}
shard_key = self.__class__._meta.get('shard_key', tuple())
for k in shard_key:
select_dict[k] = getattr(self, k)
path = self._lookup_field(k.split('.'))
actual_key = [p.db_field for p in path]
val = self
for ak in actual_key:
val = getattr(val, ak)
select_dict['__'.join(actual_key)] = val
return select_dict
def update(self, **kwargs):
@@ -467,18 +482,24 @@ class Document(BaseDocument):
# Need to add shard key to query, or you get an error
return self._qs.filter(**self._object_key).update_one(**kwargs)
def delete(self, **write_concern):
def delete(self, signal_kwargs=None, **write_concern):
"""Delete the :class:`~mongoengine.Document` from the database. This
will only take effect if the document has been previously saved.
:param signal_kwargs: (optional) kwargs dictionary to be passed to
the signal calls.
:param write_concern: Extra keyword arguments are passed down which
will be used as options for the resultant
``getLastError`` command. For example,
``save(..., write_concern={w: 2, fsync: True}, ...)`` will
wait until at least two servers have recorded the write and
will force an fsync on the primary server.
.. versionchanged:: 0.10.7
Add signal_kwargs argument
"""
signals.pre_delete.send(self.__class__, document=self)
signal_kwargs = signal_kwargs or {}
signals.pre_delete.send(self.__class__, document=self, **signal_kwargs)
# Delete FileFields separately
FileField = _import_class('FileField')
@@ -492,7 +513,7 @@ class Document(BaseDocument):
except pymongo.errors.OperationFailure, err:
message = u'Could not delete document (%s)' % err.message
raise OperationError(message)
signals.post_delete.send(self.__class__, document=self)
signals.post_delete.send(self.__class__, document=self, **signal_kwargs)
def switch_db(self, db_alias, keep_created=True):
"""
@@ -595,11 +616,16 @@ class Document(BaseDocument):
if not fields or field in fields:
try:
setattr(self, field, self._reload(field, obj[field]))
except KeyError:
# If field is removed from the database while the object
# is in memory, a reload would cause a KeyError
# i.e. obj.update(unset__field=1) followed by obj.reload()
delattr(self, field)
except (KeyError, AttributeError):
try:
# If field is a special field, e.g. items is stored as _reserved_items,
# a KeyError is thrown. So try to retrieve the field from _data
setattr(self, field, self._reload(field, obj._data.get(field)))
except KeyError:
# If field is removed from the database while the object
# is in memory, a reload would cause a KeyError
# i.e. obj.update(unset__field=1) followed by obj.reload()
delattr(self, field)
self._changed_fields = obj._changed_fields
self._created = False
@@ -653,10 +679,20 @@ class Document(BaseDocument):
def drop_collection(cls):
"""Drops the entire collection associated with this
:class:`~mongoengine.Document` type from the database.
Raises :class:`OperationError` if the document has no collection set
(e.g. if it is `abstract`)
.. versionchanged:: 0.10.7
:class:`OperationError` exception raised if no collection available
"""
col_name = cls._get_collection_name()
if not col_name:
raise OperationError('Document %s has no collection defined '
'(is it abstract?)' % cls)
cls._collection = None
db = cls._get_db()
db.drop_collection(cls._get_collection_name())
db.drop_collection(col_name)
@classmethod
def create_index(cls, keys, background=False, **kwargs):
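A short sketch of the new guard (the class name is made up): dropping an abstract document's collection now fails fast instead of handing None to PyMongo.

from mongoengine import Document
from mongoengine.errors import OperationError

class AbstractBase(Document):
    meta = {'abstract': True}

try:
    AbstractBase.drop_collection()
except OperationError as e:
    print(e)  # Document ... has no collection defined (is it abstract?)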
@@ -945,7 +981,7 @@ class MapReduceDocument(object):
if not isinstance(self.key, id_field_type):
try:
self.key = id_field_type(self.key)
except:
except Exception:
raise Exception("Could not cast key as %s" %
id_field_type.__name__)

View File

@@ -6,7 +6,7 @@ from mongoengine.python_support import txt_type
__all__ = ('NotRegistered', 'InvalidDocumentError', 'LookUpError',
'DoesNotExist', 'MultipleObjectsReturned', 'InvalidQueryError',
'OperationError', 'NotUniqueError', 'FieldDoesNotExist',
'ValidationError')
'ValidationError', 'SaveConditionError')
class NotRegistered(Exception):

View File

@@ -8,6 +8,8 @@ import uuid
import warnings
from operator import itemgetter
import six
try:
import dateutil
except ImportError:
@@ -18,6 +20,10 @@ else:
import pymongo
import gridfs
from bson import Binary, DBRef, SON, ObjectId
try:
from bson.int64 import Int64
except ImportError:
Int64 = long
from mongoengine.errors import ValidationError
from mongoengine.python_support import (PY3, bin_type, txt_type,
@@ -65,7 +71,7 @@ class StringField(BaseField):
return value
try:
value = value.decode('utf-8')
except:
except Exception:
pass
return value
@@ -194,7 +200,7 @@ class IntField(BaseField):
def validate(self, value):
try:
value = int(value)
except:
except Exception:
self.error('%s could not be converted to int' % value)
if self.min_value is not None and value < self.min_value:
@@ -225,10 +231,13 @@ class LongField(BaseField):
pass
return value
def to_mongo(self, value, **kwargs):
return Int64(value)
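The `Int64` coercion pins the stored BSON type: without it, small Python values round-trip as 32-bit ints. A sketch of the effect, mirroring the import fallback above:

try:
    from bson.int64 import Int64
except ImportError:  # pymongo < 3
    Int64 = long

print(Int64(1))  # stored as NumberLong(1) rather than a 32-bit int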
def validate(self, value):
try:
value = long(value)
except:
except Exception:
self.error('%s could not be converted to long' % value)
if self.min_value is not None and value < self.min_value:
@@ -260,10 +269,14 @@ class FloatField(BaseField):
return value
def validate(self, value):
if isinstance(value, int):
value = float(value)
if isinstance(value, six.integer_types):
try:
value = float(value)
except OverflowError:
self.error('The value is too large to be converted to float')
if not isinstance(value, float):
self.error('FloatField only accepts float values')
self.error('FloatField only accepts float and integer values')
if self.min_value is not None and value < self.min_value:
self.error('Float value is too small')
@@ -325,7 +338,7 @@ class DecimalField(BaseField):
return value
return value.quantize(decimal.Decimal(".%s" % ("0" * self.precision)), rounding=self.rounding)
def to_mongo(self, value, use_db_field=True):
def to_mongo(self, value, **kwargs):
if value is None:
return value
if self.force_string:
@@ -388,7 +401,7 @@ class DateTimeField(BaseField):
if not isinstance(new_value, (datetime.datetime, datetime.date)):
self.error(u'cannot parse date "%s"' % value)
def to_mongo(self, value):
def to_mongo(self, value, **kwargs):
if value is None:
return value
if isinstance(value, datetime.datetime):
@@ -508,10 +521,10 @@ class ComplexDateTimeField(StringField):
original_value = value
try:
return self._convert_from_string(value)
except:
except Exception:
return original_value
def to_mongo(self, value):
def to_mongo(self, value, **kwargs):
value = self.to_python(value)
return self._convert_from_datetime(value)
@@ -546,11 +559,10 @@ class EmbeddedDocumentField(BaseField):
return self.document_type._from_son(value, _auto_dereference=self._auto_dereference)
return value
def to_mongo(self, value, use_db_field=True, fields=[]):
def to_mongo(self, value, **kwargs):
if not isinstance(value, self.document_type):
return value
return self.document_type.to_mongo(value, use_db_field,
fields=fields)
return self.document_type.to_mongo(value, **kwargs)
def validate(self, value, clean=True):
"""Make sure that the document instance is an instance of the
@@ -600,11 +612,11 @@ class GenericEmbeddedDocumentField(BaseField):
value.validate(clean=clean)
def to_mongo(self, document, use_db_field=True):
def to_mongo(self, document, **kwargs):
if document is None:
return None
data = document.to_mongo(use_db_field)
data = document.to_mongo(**kwargs)
if '_cls' not in data:
data['_cls'] = document._class_name
return data
@@ -616,7 +628,7 @@ class DynamicField(BaseField):
Used by :class:`~mongoengine.DynamicDocument` to handle dynamic data"""
def to_mongo(self, value):
def to_mongo(self, value, **kwargs):
"""Convert a Python type to a MongoDB compatible type.
"""
@@ -625,7 +637,7 @@ class DynamicField(BaseField):
if hasattr(value, 'to_mongo'):
cls = value.__class__
val = value.to_mongo()
val = value.to_mongo(**kwargs)
# If it's a document that's not inherited, add _cls
if isinstance(value, Document):
val = {"_ref": value.to_dbref(), "_cls": cls.__name__}
@@ -643,7 +655,7 @@ class DynamicField(BaseField):
data = {}
for k, v in value.iteritems():
data[k] = self.to_mongo(v)
data[k] = self.to_mongo(v, **kwargs)
value = data
if is_list: # Convert back to a list
@@ -697,7 +709,7 @@ class ListField(ComplexBaseField):
def prepare_query_value(self, op, value):
if self.field:
if op in ('set', 'unset') and (
if op in ('set', 'unset', None) and (
not isinstance(value, basestring) and
not isinstance(value, BaseDocument) and
hasattr(value, '__iter__')):
@@ -755,8 +767,8 @@ class SortedListField(ListField):
self._order_reverse = kwargs.pop('reverse')
super(SortedListField, self).__init__(field, **kwargs)
def to_mongo(self, value):
value = super(SortedListField, self).to_mongo(value)
def to_mongo(self, value, **kwargs):
value = super(SortedListField, self).to_mongo(value, **kwargs)
if self._ordering is not None:
return sorted(value, key=itemgetter(self._ordering),
reverse=self._order_reverse)
@@ -863,12 +875,11 @@ class ReferenceField(BaseField):
The options are:
* DO_NOTHING - don't do anything (default).
* NULLIFY - Updates the reference to null.
* CASCADE - Deletes the documents associated with the reference.
* DENY - Prevent the deletion of the reference object.
* PULL - Pull the reference from a :class:`~mongoengine.fields.ListField`
of references
* DO_NOTHING (0) - don't do anything (default).
* NULLIFY (1) - Updates the reference to null.
* CASCADE (2) - Deletes the documents associated with the reference.
* DENY (3) - Prevent the deletion of the reference object.
* PULL (4) - Pull the reference from a :class:`~mongoengine.fields.ListField` of references
Alternative syntax for registering delete rules (useful when implementing
bi-directional delete rules)
@@ -879,7 +890,7 @@ class ReferenceField(BaseField):
content = StringField()
foo = ReferenceField('Foo')
Bar.register_delete_rule(Foo, 'bar', NULLIFY)
Foo.register_delete_rule(Bar, 'foo', NULLIFY)
.. note ::
`reverse_delete_rule` does not trigger pre / post delete signals to be
@@ -896,6 +907,10 @@ class ReferenceField(BaseField):
or as the :class:`~pymongo.objectid.ObjectId`.id .
:param reverse_delete_rule: Determines what to do when the referring
object is deleted
.. note ::
A reference to an abstract document type is always stored as a
:class:`~pymongo.dbref.DBRef`, regardless of the value of `dbref`.
"""
if not isinstance(document_type, basestring):
if not issubclass(document_type, (Document, basestring)):
@@ -928,33 +943,46 @@ class ReferenceField(BaseField):
self._auto_dereference = instance._fields[self.name]._auto_dereference
# Dereference DBRefs
if self._auto_dereference and isinstance(value, DBRef):
value = self.document_type._get_db().dereference(value)
if hasattr(value, 'cls'):
# Dereference using the class type specified in the reference
cls = get_document(value.cls)
else:
cls = self.document_type
value = cls._get_db().dereference(value)
if value is not None:
instance._data[self.name] = self.document_type._from_son(value)
instance._data[self.name] = cls._from_son(value)
return super(ReferenceField, self).__get__(instance, owner)
def to_mongo(self, document):
def to_mongo(self, document, **kwargs):
if isinstance(document, DBRef):
if not self.dbref:
return document.id
return document
id_field_name = self.document_type._meta['id_field']
id_field = self.document_type._fields[id_field_name]
if isinstance(document, Document):
# We need the id from the saved object to create the DBRef
id_ = document.pk
if id_ is None:
self.error('You can only reference documents once they have'
' been saved to the database')
# Use the attributes from the document instance, so that they
# override the attributes of this field's document type
cls = document
else:
id_ = document
cls = self.document_type
id_ = id_field.to_mongo(id_)
if self.dbref:
collection = self.document_type._get_collection_name()
id_field_name = cls._meta['id_field']
id_field = cls._fields[id_field_name]
id_ = id_field.to_mongo(id_, **kwargs)
if self.document_type._meta.get('abstract'):
collection = cls._get_collection_name()
return DBRef(collection, id_, cls=cls._class_name)
elif self.dbref:
collection = cls._get_collection_name()
return DBRef(collection, id_)
return id_
@@ -983,6 +1011,14 @@ class ReferenceField(BaseField):
self.error('You can only reference documents once they have been '
'saved to the database')
if self.document_type._meta.get('abstract') and \
not isinstance(value, self.document_type):
self.error('%s is not an instance of abstract reference'
' type %s' % (value._class_name,
self.document_type._class_name)
)
def lookup_member(self, member_name):
return self.document_type._fields.get(member_name)
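A sketch of the abstract-reference behaviour added above (class names are illustrative): the stored DBRef carries a `cls` attribute so dereferencing can pick the concrete subclass.

from mongoengine import Document, ReferenceField, StringField

class Medium(Document):
    title = StringField()
    meta = {'abstract': True}

class Book(Medium):
    pass

class Review(Document):
    # stored as DBRef(collection, id, cls='Book'), since Medium is abstract
    subject = ReferenceField(Medium)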
@@ -990,7 +1026,7 @@ class ReferenceField(BaseField):
class CachedReferenceField(BaseField):
"""
A reference field with cached fields, used to provide pseudo-joins
.. versionadded:: 0.9
"""
@@ -1064,7 +1100,7 @@ class CachedReferenceField(BaseField):
return super(CachedReferenceField, self).__get__(instance, owner)
def to_mongo(self, document):
def to_mongo(self, document, **kwargs):
id_field_name = self.document_type._meta['id_field']
id_field = self.document_type._fields[id_field_name]
@@ -1079,10 +1115,11 @@ class CachedReferenceField(BaseField):
# TODO: should raise here, or the next statement will fail
value = SON((
("_id", id_field.to_mongo(id_)),
("_id", id_field.to_mongo(id_, **kwargs)),
))
value.update(dict(document.to_mongo(fields=self.fields)))
kwargs['fields'] = self.fields
value.update(dict(document.to_mongo(**kwargs)))
return value
def prepare_query_value(self, op, value):
@@ -1198,7 +1235,7 @@ class GenericReferenceField(BaseField):
doc = doc_cls._from_son(doc)
return doc
def to_mongo(self, document, use_db_field=True):
def to_mongo(self, document, **kwargs):
if document is None:
return None
@@ -1217,7 +1254,7 @@ class GenericReferenceField(BaseField):
else:
id_ = document
id_ = id_field.to_mongo(id_)
id_ = id_field.to_mongo(id_, **kwargs)
collection = document._get_collection_name()
ref = DBRef(collection, id_)
return SON((
@@ -1246,7 +1283,7 @@ class BinaryField(BaseField):
value = bin_type(value)
return super(BinaryField, self).__set__(instance, value)
def to_mongo(self, value):
def to_mongo(self, value, **kwargs):
return Binary(value)
def validate(self, value):
@@ -1346,7 +1383,7 @@ class GridFSProxy(object):
if self.gridout is None:
self.gridout = self.fs.get(self.grid_id)
return self.gridout
except:
except Exception:
# File has been deleted
return None
@@ -1384,7 +1421,7 @@ class GridFSProxy(object):
else:
try:
return gridout.read(size)
except:
except Exception:
return ""
def delete(self):
@@ -1449,7 +1486,7 @@ class FileField(BaseField):
if grid_file:
try:
grid_file.delete()
except:
except Exception:
pass
# Create a new proxy object as we don't already have one
@@ -1471,7 +1508,7 @@ class FileField(BaseField):
db_alias=db_alias,
collection_name=collection_name)
def to_mongo(self, value):
def to_mongo(self, value, **kwargs):
# Store the GridFS file id in MongoDB
if isinstance(value, self.proxy_class) and value.grid_id is not None:
return value.grid_id
@@ -1683,17 +1720,17 @@ class SequenceField(BaseField):
:param collection_name: Name of the counter collection (default 'mongoengine.counters')
:param sequence_name: Name of the sequence in the collection (default 'ClassName.counter')
:param value_decorator: Any callable to use as a counter (default int)
Use any callable as `value_decorator` to transform the calculated counter into
any value suitable for your needs, e.g. a string or a hexadecimal
representation of the default integer counter value.
.. note::
In case the counter is defined in the abstract document, it will be
common to all inherited documents and the default sequence name will
be the class name of the abstract document.
.. versionadded:: 0.5
.. versionchanged:: 0.8 added `value_decorator`
"""
@@ -1817,11 +1854,11 @@ class UUIDField(BaseField):
if not isinstance(value, basestring):
value = unicode(value)
return uuid.UUID(value)
except:
except Exception:
return original_value
return value
def to_mongo(self, value):
def to_mongo(self, value, **kwargs):
if not self._binary:
return unicode(value)
elif isinstance(value, basestring):

View File

@@ -266,7 +266,8 @@ class BaseQuerySet(object):
result = None
return result
def insert(self, doc_or_docs, load_bulk=True, write_concern=None):
def insert(self, doc_or_docs, load_bulk=True,
write_concern=None, signal_kwargs=None):
"""bulk insert documents
:param doc_or_docs: a document or list of documents to be inserted
@@ -279,11 +280,15 @@ class BaseQuerySet(object):
``insert(..., {w: 2, fsync: True})`` will wait until at least
two servers have recorded the write and will force an fsync on
each server being written to.
:param signal_kwargs: (optional) kwargs dictionary to be passed to
the signal calls.
By default returns document instances, set ``load_bulk`` to False to
return just ``ObjectIds``
.. versionadded:: 0.5
.. versionchanged:: 0.10.7
Add signal_kwargs argument
"""
Document = _import_class('Document')
@@ -296,7 +301,6 @@ class BaseQuerySet(object):
return_one = True
docs = [docs]
raw = []
for doc in docs:
if not isinstance(doc, self._document):
msg = ("Some documents inserted aren't instances of %s"
@@ -305,9 +309,12 @@ class BaseQuerySet(object):
if doc.pk and not doc._created:
msg = "Some documents have ObjectIds use doc.update() instead"
raise OperationError(msg)
raw.append(doc.to_mongo())
signals.pre_bulk_insert.send(self._document, documents=docs)
signal_kwargs = signal_kwargs or {}
signals.pre_bulk_insert.send(self._document,
documents=docs, **signal_kwargs)
raw = [doc.to_mongo() for doc in docs]
try:
ids = self._collection.insert(raw, **write_concern)
except pymongo.errors.DuplicateKeyError, err:
@@ -324,7 +331,7 @@ class BaseQuerySet(object):
if not load_bulk:
signals.post_bulk_insert.send(
self._document, documents=docs, loaded=False)
self._document, documents=docs, loaded=False, **signal_kwargs)
return return_one and ids[0] or ids
documents = self.in_bulk(ids)
@@ -332,7 +339,7 @@ class BaseQuerySet(object):
for obj_id in ids:
results.append(documents.get(obj_id))
signals.post_bulk_insert.send(
self._document, documents=results, loaded=True)
self._document, documents=results, loaded=True, **signal_kwargs)
return return_one and results[0] or results
def count(self, with_limit_and_skip=False):
@@ -403,8 +410,10 @@ class BaseQuerySet(object):
rule = doc._meta['delete_rules'][rule_entry]
if rule == CASCADE:
cascade_refs = set() if cascade_refs is None else cascade_refs
for ref in queryset:
cascade_refs.add(ref.id)
# Handle recursive reference
if doc._collection == document_cls._collection:
for ref in queryset:
cascade_refs.add(ref.id)
ref_q = document_cls.objects(**{field_name + '__in': self, 'id__nin': cascade_refs})
ref_q_count = ref_q.count()
if ref_q_count > 0:
@@ -425,7 +434,7 @@ class BaseQuerySet(object):
full_result=False, **update):
"""Perform an atomic update on the fields matched by the query.
:param upsert: Any existing document with that "_id" is overwritten.
:param upsert: insert if document doesn't exist (default ``False``)
:param multi: Update multiple documents.
:param write_concern: Extra keyword arguments are passed down which
will be used as options for the resultant
@@ -471,10 +480,36 @@ class BaseQuerySet(object):
raise OperationError(message)
raise OperationError(u'Update failed (%s)' % unicode(err))
def update_one(self, upsert=False, write_concern=None, **update):
"""Perform an atomic update on first field matched by the query.
def upsert_one(self, write_concern=None, **update):
"""Overwrite or add the first document matched by the query.
:param upsert: Any existing document with that "_id" is overwritten.
:param write_concern: Extra keyword arguments are passed down which
will be used as options for the resultant
``getLastError`` command. For example,
``save(..., write_concern={w: 2, fsync: True}, ...)`` will
wait until at least two servers have recorded the write and
will force an fsync on the primary server.
:param update: Django-style update keyword arguments
:returns: the new or overwritten document
.. versionadded:: 0.10.2
"""
atomic_update = self.update(multi=False, upsert=True, write_concern=write_concern,
full_result=True, **update)
if atomic_update['updatedExisting']:
document = self.get()
else:
document = self._document.objects.with_id(atomic_update['upserted'])
return document
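An illustrative use of `upsert_one`, assuming a registered connection (the Person class is made up): exactly one document matches the query afterwards, whether or not it existed before.

from mongoengine import Document, IntField, StringField

class Person(Document):
    name = StringField()
    age = IntField()

person = Person.objects(name='Ada').upsert_one(set__age=36)
print(person.age)  # 36, whether Ada existed before the call or not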
def update_one(self, upsert=False, write_concern=None, **update):
"""Perform an atomic update on the fields of the first document
matched by the query.
:param upsert: insert if document doesn't exist (default ``False``)
:param write_concern: Extra keyword arguments are passed down which
will be used as options for the resultant
``getLastError`` command. For example,
@@ -929,6 +964,7 @@ class BaseQuerySet(object):
validate_read_preference('read_preference', read_preference)
queryset = self.clone()
queryset._read_preference = read_preference
queryset._cursor_obj = None # we need to re-create the cursor object whenever we apply read_preference
return queryset
def scalar(self, *fields):
@@ -1201,66 +1237,28 @@ class BaseQuerySet(object):
def sum(self, field):
"""Sum over the values of the specified field.
:param field: the field to sum over; use dot-notation to refer to
:param field: the field to sum over; use dot notation to refer to
embedded document fields
.. versionchanged:: 0.5 - updated to map_reduce as db.eval doesn't work
with sharding.
"""
map_func = """
function() {
var path = '{{~%(field)s}}'.split('.'),
field = this;
for (p in path) {
if (typeof field != 'undefined')
field = field[path[p]];
else
break;
}
if (field && field.constructor == Array) {
field.forEach(function(item) {
emit(1, item||0);
});
} else if (typeof field != 'undefined') {
emit(1, field||0);
}
}
""" % dict(field=field)
reduce_func = Code("""
function(key, values) {
var sum = 0;
for (var i in values) {
sum += values[i];
}
return sum;
}
""")
for result in self.map_reduce(map_func, reduce_func, output='inline'):
return result.value
else:
return 0
def aggregate_sum(self, field):
"""Sum over the values of the specified field.
:param field: the field to sum over; use dot-notation to refer to
embedded document fields
This method is more performant than the regular `sum`, because it uses
the aggregation framework instead of map-reduce.
"""
result = self._document._get_collection().aggregate([
pipeline = [
{'$match': self._query},
{'$group': {'_id': 'sum', 'total': {'$sum': '$' + field}}}
])
]
# if we're performing a sum over a list field, we sum up all the
# elements in the list, hence we need to $unwind the arrays first
ListField = _import_class('ListField')
field_parts = field.split('.')
field_instances = self._document._lookup_field(field_parts)
if isinstance(field_instances[-1], ListField):
pipeline.insert(1, {'$unwind': '$' + field})
result = self._document._get_collection().aggregate(pipeline)
if IS_PYMONGO_3:
result = list(result)
result = tuple(result)
else:
result = result.get('result')
if result:
return result[0]['total']
return 0
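For a ListField named `scores` (an illustrative name) the assembled pipeline would come out as:

pipeline = [
    {'$match': {}},
    {'$unwind': '$scores'},
    {'$group': {'_id': 'sum', 'total': {'$sum': '$scores'}}},
]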
@@ -1268,73 +1266,26 @@ class BaseQuerySet(object):
def average(self, field):
"""Average over the values of the specified field.
:param field: the field to average over; use dot-notation to refer to
:param field: the field to average over; use dot notation to refer to
embedded document fields
.. versionchanged:: 0.5 - updated to map_reduce as db.eval doesn't work
with sharding.
"""
map_func = """
function() {
var path = '{{~%(field)s}}'.split('.'),
field = this;
for (p in path) {
if (typeof field != 'undefined')
field = field[path[p]];
else
break;
}
if (field && field.constructor == Array) {
field.forEach(function(item) {
emit(1, {t: item||0, c: 1});
});
} else if (typeof field != 'undefined') {
emit(1, {t: field||0, c: 1});
}
}
""" % dict(field=field)
reduce_func = Code("""
function(key, values) {
var out = {t: 0, c: 0};
for (var i in values) {
var value = values[i];
out.t += value.t;
out.c += value.c;
}
return out;
}
""")
finalize_func = Code("""
function(key, value) {
return value.t / value.c;
}
""")
for result in self.map_reduce(map_func, reduce_func,
finalize_f=finalize_func, output='inline'):
return result.value
else:
return 0
def aggregate_average(self, field):
"""Average over the values of the specified field.
:param field: the field to average over; use dot-notation to refer to
embedded document fields
This method is more performant than the regular `average`, because it
uses the aggregation framework instead of map-reduce.
"""
result = self._document._get_collection().aggregate([
pipeline = [
{'$match': self._query},
{'$group': {'_id': 'avg', 'total': {'$avg': '$' + field}}}
])
]
# if we're performing an average over a list field, we average out
# all the elements in the list, hence we need to $unwind the arrays
# first
ListField = _import_class('ListField')
field_parts = field.split('.')
field_instances = self._document._lookup_field(field_parts)
if isinstance(field_instances[-1], ListField):
pipeline.insert(1, {'$unwind': '$' + field})
result = self._document._get_collection().aggregate(pipeline)
if IS_PYMONGO_3:
result = list(result)
result = tuple(result)
else:
result = result.get('result')
if result:
@@ -1351,7 +1302,7 @@ class BaseQuerySet(object):
Can only do direct simple mappings and cannot map across
:class:`~mongoengine.fields.ReferenceField` or
:class:`~mongoengine.fields.GenericReferenceField` for more complex
counting a manual map reduce call would is required.
counting a manual map reduce call is required.
If the field is a :class:`~mongoengine.fields.ListField`, the items within
each list will be counted individually.
@@ -1425,7 +1376,7 @@ class BaseQuerySet(object):
msg = "The snapshot option is not anymore available with PyMongo 3+"
warnings.warn(msg, DeprecationWarning)
cursor_args = {
'no_cursor_timeout': self._timeout
'no_cursor_timeout': not self._timeout
}
if self._loaded_fields:
cursor_args[fields_name] = self._loaded_fields.as_dict()
@@ -1442,8 +1393,16 @@ class BaseQuerySet(object):
def _cursor(self):
if self._cursor_obj is None:
self._cursor_obj = self._collection.find(self._query,
**self._cursor_args)
# In PyMongo 3+, we define the read preference on a collection
# level, not a cursor level. Thus, we need to get a cloned
# collection object using `with_options` first.
if IS_PYMONGO_3 and self._read_preference is not None:
self._cursor_obj = self._collection\
.with_options(read_preference=self._read_preference)\
.find(self._query, **self._cursor_args)
else:
self._cursor_obj = self._collection.find(self._query,
**self._cursor_args)
# Apply where clauses to cursor
if self._where_clause:
where_clause = self._sub_js_fields(self._where_clause)
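The underlying PyMongo 3 pattern, sketched as a standalone helper (database and collection names are illustrative): the read preference is applied by cloning the collection with `with_options`, not by passing it to `find()`.

from pymongo import MongoClient, ReadPreference

def find_on_secondary(query):
    coll = MongoClient().mydb.mycoll
    secondary = coll.with_options(read_preference=ReadPreference.SECONDARY)
    return secondary.find(query)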
@@ -1660,7 +1619,7 @@ class BaseQuerySet(object):
key = key.replace('__', '.')
try:
key = self._document._translate_field_name(key)
except:
except Exception:
pass
key_list.append((key, direction))

View File

@@ -29,7 +29,7 @@ class QuerySetManager(object):
Document.objects is accessed.
"""
if instance is not None:
# Document class being used rather than a document object
# Document object being used rather than a document class
return self
# owner is the document that contains the QuerySetManager

View File

@@ -38,7 +38,7 @@ class QuerySet(BaseQuerySet):
def __len__(self):
"""Since __len__ is called quite frequently (for example, as part of
list(qs) we populate the result cache and cache the length.
list(qs)), we populate the result cache and cache the length.
"""
if self._len is not None:
return self._len

View File

@@ -26,12 +26,12 @@ MATCH_OPERATORS = (COMPARISON_OPERATORS + GEO_OPERATORS +
STRING_OPERATORS + CUSTOM_OPERATORS)
def query(_doc_cls=None, **query):
def query(_doc_cls=None, **kwargs):
"""Transform a query from Django-style format to Mongo format.
"""
mongo_query = {}
merge_query = defaultdict(list)
for key, value in sorted(query.items()):
for key, value in sorted(kwargs.items()):
if key == "__raw__":
mongo_query.update(value)
continue
@@ -44,7 +44,7 @@ def query(_doc_cls=None, **query):
if len(parts) > 1 and parts[-1] in MATCH_OPERATORS:
op = parts.pop()
# Allw to escape operator-like field name by __
# Allow to escape operator-like field name by __
if len(parts) > 1 and parts[-1] == "":
parts.pop()
@@ -105,13 +105,18 @@ def query(_doc_cls=None, **query):
if op:
if op in GEO_OPERATORS:
value = _geo_operator(field, op, value)
elif op in CUSTOM_OPERATORS:
if op in ('elem_match', 'match'):
value = field.prepare_query_value(op, value)
value = {"$elemMatch": value}
elif op in ('match', 'elemMatch'):
ListField = _import_class('ListField')
EmbeddedDocumentField = _import_class('EmbeddedDocumentField')
if (isinstance(value, dict) and isinstance(field, ListField) and
isinstance(field.field, EmbeddedDocumentField)):
value = query(field.field.document_type, **value)
else:
NotImplementedError("Custom method '%s' has not "
"been implemented" % op)
value = field.prepare_query_value(op, value)
value = {"$elemMatch": value}
elif op in CUSTOM_OPERATORS:
NotImplementedError("Custom method '%s' has not "
"been implemented" % op)
elif op not in STRING_OPERATORS:
value = {'$' + op: value}
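With the branch above, a dict passed to `__match` is itself run through `query()`, so Django-style operators nest. A sketch with hypothetical classes:

from mongoengine import (Document, EmbeddedDocument, EmbeddedDocumentField,
                         IntField, ListField, StringField)

class Comment(EmbeddedDocument):
    author = StringField()
    votes = IntField()

class Page(Document):
    comments = ListField(EmbeddedDocumentField(Comment))

# becomes {'comments': {'$elemMatch': {'author': 'Ada', 'votes': {'$gte': 3}}}}
Page.objects(comments__match={'author': 'Ada', 'votes__gte': 3})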
@@ -207,6 +212,10 @@ def update(_doc_cls=None, **update):
if parts[-1] in COMPARISON_OPERATORS:
match = parts.pop()
# Allow to escape operator-like field name by __
if len(parts) > 1 and parts[-1] == "":
parts.pop()
if _doc_cls:
# Switch field names to proper names [set in Field(name='foo')]
try:
@@ -359,20 +368,24 @@ def _infer_geometry(value):
"type and coordinates keys")
elif isinstance(value, (list, set)):
# TODO: shouldn't we test value[0][0][0][0] to see if it is MultiPolygon?
# TODO: should both TypeError and IndexError be interpreted alike?
try:
value[0][0][0]
return {"$geometry": {"type": "Polygon", "coordinates": value}}
except:
except (TypeError, IndexError):
pass
try:
value[0][0]
return {"$geometry": {"type": "LineString", "coordinates": value}}
except:
except (TypeError, IndexError):
pass
try:
value[0]
return {"$geometry": {"type": "Point", "coordinates": value}}
except:
except (TypeError, IndexError):
pass
raise InvalidQueryError("Invalid $geometry data. Can be either a dictionary "