Merge branch 'master' of https://github.com/MongoEngine/mongoengine into swat_url_fix

This commit is contained in:
Bastien Gérard
2018-09-16 23:03:16 +02:00
87 changed files with 5596 additions and 3458 deletions

View File

@@ -23,7 +23,7 @@ __all__ = (list(document.__all__) + list(fields.__all__) +
list(signals.__all__) + list(errors.__all__))
VERSION = (0, 11, 0)
VERSION = (0, 15, 3)
def get_version():

View File

@@ -15,7 +15,7 @@ __all__ = (
'UPDATE_OPERATORS', '_document_registry', 'get_document',
# datastructures
'BaseDict', 'BaseList', 'EmbeddedDocumentList',
'BaseDict', 'BaseList', 'EmbeddedDocumentList', 'LazyReference',
# document
'BaseDocument',

View File

@@ -3,9 +3,10 @@ from mongoengine.errors import NotRegistered
__all__ = ('UPDATE_OPERATORS', 'get_document', '_document_registry')
UPDATE_OPERATORS = set(['set', 'unset', 'inc', 'dec', 'pop', 'push',
'push_all', 'pull', 'pull_all', 'add_to_set',
'set_on_insert', 'min', 'max'])
UPDATE_OPERATORS = {'set', 'unset', 'inc', 'dec', 'mul',
'pop', 'push', 'push_all', 'pull',
'pull_all', 'add_to_set', 'set_on_insert',
'min', 'max', 'rename'}
_document_registry = {}

View File

@@ -1,12 +1,13 @@
import itertools
import weakref
from bson import DBRef
import six
from mongoengine.common import _import_class
from mongoengine.errors import DoesNotExist, MultipleObjectsReturned
__all__ = ('BaseDict', 'BaseList', 'EmbeddedDocumentList')
__all__ = ('BaseDict', 'BaseList', 'EmbeddedDocumentList', 'LazyReference')
class BaseDict(dict):
@@ -127,8 +128,8 @@ class BaseList(list):
return value
def __iter__(self):
for i in xrange(self.__len__()):
yield self[i]
for v in super(BaseList, self).__iter__():
yield v
def __setitem__(self, key, value, *args, **kwargs):
if isinstance(key, slice):
@@ -137,11 +138,8 @@ class BaseList(list):
self._mark_as_changed(key)
return super(BaseList, self).__setitem__(key, value)
def __delitem__(self, key, *args, **kwargs):
if isinstance(key, slice):
self._mark_as_changed()
else:
self._mark_as_changed(key)
def __delitem__(self, key):
self._mark_as_changed()
return super(BaseList, self).__delitem__(key)
def __setslice__(self, *args, **kwargs):
@@ -189,7 +187,7 @@ class BaseList(list):
self._mark_as_changed()
return super(BaseList, self).remove(*args, **kwargs)
def reverse(self, *args, **kwargs):
def reverse(self):
self._mark_as_changed()
return super(BaseList, self).reverse()
@@ -236,6 +234,9 @@ class EmbeddedDocumentList(BaseList):
Filters the list by only including embedded documents with the
given keyword arguments.
This method only supports simple comparison (e.g: .filter(name='John Doe'))
and does not support operators like __gte, __lte, __icontains like queryset.filter does
:param kwargs: The keyword arguments corresponding to the fields to
filter on. *Multiple arguments are treated as if they are ANDed
together.*
@@ -353,7 +354,8 @@ class EmbeddedDocumentList(BaseList):
def update(self, **update):
"""
Updates the embedded documents with the given update values.
Updates the embedded documents with the given replacement values. This
function does not support mongoDB update operators such as ``inc__``.
.. note::
The embedded document changes are not automatically saved
@@ -375,7 +377,7 @@ class EmbeddedDocumentList(BaseList):
class StrictDict(object):
__slots__ = ()
_special_fields = set(['get', 'pop', 'iteritems', 'items', 'keys', 'create'])
_special_fields = {'get', 'pop', 'iteritems', 'items', 'keys', 'create'}
_classes = {}
def __init__(self, **kwargs):
@@ -432,7 +434,7 @@ class StrictDict(object):
def __eq__(self, other):
return self.items() == other.items()
def __neq__(self, other):
def __ne__(self, other):
return self.items() != other.items()
@classmethod
@@ -450,40 +452,40 @@ class StrictDict(object):
return cls._classes[allowed_keys]
class SemiStrictDict(StrictDict):
__slots__ = ('_extras', )
_classes = {}
class LazyReference(DBRef):
__slots__ = ('_cached_doc', 'passthrough', 'document_type')
def __getattr__(self, attr):
try:
super(SemiStrictDict, self).__getattr__(attr)
except AttributeError:
try:
return self.__getattribute__('_extras')[attr]
except KeyError as e:
raise AttributeError(e)
def fetch(self, force=False):
if not self._cached_doc or force:
self._cached_doc = self.document_type.objects.get(pk=self.pk)
if not self._cached_doc:
raise DoesNotExist('Trying to dereference unknown document %s' % (self))
return self._cached_doc
def __setattr__(self, attr, value):
try:
super(SemiStrictDict, self).__setattr__(attr, value)
except AttributeError:
try:
self._extras[attr] = value
except AttributeError:
self._extras = {attr: value}
@property
def pk(self):
return self.id
def __delattr__(self, attr):
try:
super(SemiStrictDict, self).__delattr__(attr)
except AttributeError:
try:
del self._extras[attr]
except KeyError as e:
raise AttributeError(e)
def __init__(self, document_type, pk, cached_doc=None, passthrough=False):
self.document_type = document_type
self._cached_doc = cached_doc
self.passthrough = passthrough
super(LazyReference, self).__init__(self.document_type._get_collection_name(), pk)
def __iter__(self):
def __getitem__(self, name):
if not self.passthrough:
raise KeyError()
document = self.fetch()
return document[name]
def __getattr__(self, name):
if not object.__getattribute__(self, 'passthrough'):
raise AttributeError()
document = self.fetch()
try:
extras_iter = iter(self.__getattribute__('_extras'))
except AttributeError:
extras_iter = ()
return itertools.chain(super(SemiStrictDict, self).__iter__(), extras_iter)
return document[name]
except KeyError:
raise AttributeError()
def __repr__(self):
return "<LazyReference(%s, %r)>" % (self.document_type, self.pk)

View File

@@ -13,13 +13,14 @@ from mongoengine import signals
from mongoengine.base.common import get_document
from mongoengine.base.datastructures import (BaseDict, BaseList,
EmbeddedDocumentList,
SemiStrictDict, StrictDict)
LazyReference,
StrictDict)
from mongoengine.base.fields import ComplexBaseField
from mongoengine.common import _import_class
from mongoengine.errors import (FieldDoesNotExist, InvalidDocumentError,
LookUpError, OperationError, ValidationError)
__all__ = ('BaseDocument',)
__all__ = ('BaseDocument', 'NON_FIELD_ERRORS')
NON_FIELD_ERRORS = '__all__'
@@ -79,8 +80,7 @@ class BaseDocument(object):
if self.STRICT and not self._dynamic:
self._data = StrictDict.create(allowed_keys=self._fields_ordered)()
else:
self._data = SemiStrictDict.create(
allowed_keys=self._fields_ordered)()
self._data = {}
self._dynamic_fields = SON()
@@ -100,13 +100,11 @@ class BaseDocument(object):
for key, value in values.iteritems():
if key in self._fields or key == '_id':
setattr(self, key, value)
elif self._dynamic:
else:
dynamic_data[key] = value
else:
FileField = _import_class('FileField')
for key, value in values.iteritems():
if key == '__auto_convert':
continue
key = self._reverse_db_field_map.get(key, key)
if key in self._fields or key in ('id', 'pk', '_cls'):
if __auto_convert and value is not None:
@@ -147,7 +145,7 @@ class BaseDocument(object):
if not hasattr(self, name) and not name.startswith('_'):
DynamicField = _import_class('DynamicField')
field = DynamicField(db_field=name)
field = DynamicField(db_field=name, null=True)
field.name = name
self._dynamic_fields[name] = field
self._fields_ordered += (name,)
@@ -272,13 +270,6 @@ class BaseDocument(object):
def __ne__(self, other):
return not self.__eq__(other)
def __hash__(self):
if getattr(self, 'pk', None) is None:
# For new object
return super(BaseDocument, self).__hash__()
else:
return hash(self.pk)
def clean(self):
"""
Hook for doing document level data cleaning before validation is run.
@@ -311,7 +302,7 @@ class BaseDocument(object):
data['_cls'] = self._class_name
# only root fields ['test1.a', 'test2'] => ['test1', 'test2']
root_fields = set([f.split('.')[0] for f in fields])
root_fields = {f.split('.')[0] for f in fields}
for field_name in self:
if root_fields and field_name not in root_fields:
@@ -344,7 +335,7 @@ class BaseDocument(object):
value = field.generate()
self._data[field_name] = value
if value is not None:
if (value is not None) or (field.null):
if use_db_field:
data[field.db_field] = value
else:
@@ -402,16 +393,26 @@ class BaseDocument(object):
raise ValidationError(message, errors=errors)
def to_json(self, *args, **kwargs):
"""Converts a document to JSON.
:param use_db_field: Set to True by default but enables the output of the json structure with the field names
and not the mongodb store db_names in case of set to False
"""Convert this document to JSON.
:param use_db_field: Serialize field names as they appear in
MongoDB (as opposed to attribute names on this document).
Defaults to True.
"""
use_db_field = kwargs.pop('use_db_field', True)
return json_util.dumps(self.to_mongo(use_db_field), *args, **kwargs)
@classmethod
def from_json(cls, json_data, created=False):
"""Converts json data to an unsaved document instance"""
"""Converts json data to a Document instance
:param json_data: The json data to load into the Document
:param created: If True, the document will be considered as a brand new document
If False and an id is provided, it will consider that the data being
loaded corresponds to what's already in the database (This has an impact of subsequent call to .save())
If False and no id is provided, it will consider the data as a new document
(default ``False``)
"""
return cls._from_son(json_util.loads(json_data), created=created)
def __expand_dynamic_values(self, name, value):
@@ -494,7 +495,7 @@ class BaseDocument(object):
else:
data = getattr(data, part, None)
if hasattr(data, '_changed_fields'):
if not isinstance(data, LazyReference) and hasattr(data, '_changed_fields'):
if getattr(data, '_is_document', False):
continue
@@ -566,7 +567,7 @@ class BaseDocument(object):
continue
elif isinstance(field, SortedListField) and field._ordering:
# if ordering is affected whole list is changed
if any(map(lambda d: field._ordering in d._changed_fields, data)):
if any(field._ordering in d._changed_fields for d in data):
changed_fields.append(db_field_name)
continue
@@ -675,12 +676,20 @@ class BaseDocument(object):
if not only_fields:
only_fields = []
if son and not isinstance(son, dict):
raise ValueError("The source SON object needs to be of type 'dict'")
# Get the class name from the document, falling back to the given
# class if unavailable
class_name = son.get('_cls', cls._class_name)
# Convert SON to a dict, making sure each key is a string
data = {str(key): value for key, value in son.iteritems()}
# Convert SON to a data dict, making sure each key is a string and
# corresponds to the right db field.
data = {}
for key, value in son.iteritems():
key = str(key)
key = cls._db_field_map.get(key, key)
data[key] = value
# Return correct subclass for document type
if class_name != cls._class_name:
@@ -1077,5 +1086,11 @@ class BaseDocument(object):
"""Return the display value for a choice field"""
value = getattr(self, field.name)
if field.choices and isinstance(field.choices[0], (list, tuple)):
return dict(field.choices).get(value, value)
if value is None:
return None
sep = getattr(field, 'display_sep', ' ')
values = value if field.__class__.__name__ in ('ListField', 'SortedListField') else [value]
return sep.join([
six.text_type(dict(field.choices).get(val, val))
for val in values or []])
return value

View File

@@ -41,7 +41,7 @@ class BaseField(object):
"""
:param db_field: The database field to store this field in
(defaults to the name of the field)
:param name: Depreciated - use db_field
:param name: Deprecated - use db_field
:param required: If the field is required. Whether it has to have a
value or not. Defaults to False.
:param default: (optional) The default value for this field if no value
@@ -55,7 +55,7 @@ class BaseField(object):
field. Generally this is deprecated in favour of the
`FIELD.validate` method
:param choices: (optional) The valid choices
:param null: (optional) Is the field value can be null. If no and there is a default value
:param null: (optional) If the field value can be null. If no and there is a default value
then the default value is set
:param sparse: (optional) `sparse=True` combined with `unique=True` and `required=False`
means that uniqueness won't be enforced for `None` values
@@ -81,6 +81,24 @@ class BaseField(object):
self.sparse = sparse
self._owner_document = None
# Make sure db_field is a string (if it's explicitly defined).
if (
self.db_field is not None and
not isinstance(self.db_field, six.string_types)
):
raise TypeError('db_field should be a string.')
# Make sure db_field doesn't contain any forbidden characters.
if isinstance(self.db_field, six.string_types) and (
'.' in self.db_field or
'\0' in self.db_field or
self.db_field.startswith('$')
):
raise ValueError(
'field names cannot contain dots (".") or null characters '
'("\\0"), and they must not start with a dollar sign ("$").'
)
# Detect and report conflicts between metadata and base properties.
conflicts = set(dir(self)) & set(kwargs)
if conflicts:
@@ -112,7 +130,6 @@ class BaseField(object):
def __set__(self, instance, value):
"""Descriptor for assigning a value to a field in a document.
"""
# If setting to None and there is a default
# Then set the value to the default value
if value is None:
@@ -182,7 +199,8 @@ class BaseField(object):
EmbeddedDocument = _import_class('EmbeddedDocument')
choice_list = self.choices
if isinstance(choice_list[0], (list, tuple)):
if isinstance(next(iter(choice_list)), (list, tuple)):
# next(iter) is useful for sets
choice_list = [k for k, _ in choice_list]
# Choices which are other types of Documents
@@ -194,8 +212,10 @@ class BaseField(object):
)
)
# Choices which are types other than Documents
elif value not in choice_list:
self.error('Value must be one of %s' % six.text_type(choice_list))
else:
values = value if isinstance(value, (list, tuple)) else [value]
if len(set(values) - set(choice_list)):
self.error('Value must be one of %s' % six.text_type(choice_list))
def _validate(self, value, **kwargs):
# Check the Choices Constraint
@@ -481,7 +501,7 @@ class GeoJsonBaseField(BaseField):
def validate(self, value):
"""Validate the GeoJson object based on its type."""
if isinstance(value, dict):
if set(value.keys()) == set(['type', 'coordinates']):
if set(value.keys()) == {'type', 'coordinates'}:
if value['type'] != self._type:
self.error('%s type must be "%s"' %
(self._name, self._type))

22
mongoengine/base/utils.py Normal file
View File

@@ -0,0 +1,22 @@
import re
class LazyRegexCompiler(object):
"""Descriptor to allow lazy compilation of regex"""
def __init__(self, pattern, flags=0):
self._pattern = pattern
self._flags = flags
self._compiled_regex = None
@property
def compiled_regex(self):
if self._compiled_regex is None:
self._compiled_regex = re.compile(self._pattern, self._flags)
return self._compiled_regex
def __get__(self, obj, objtype):
return self.compiled_regex
def __set__(self, instance, value):
raise AttributeError("Can not set attribute LazyRegexCompiler")

View File

@@ -34,7 +34,10 @@ def _import_class(cls_name):
queryset_classes = ('OperationError',)
deref_classes = ('DeReference',)
if cls_name in doc_classes:
if cls_name == 'BaseDocument':
from mongoengine.base import document as module
import_classes = ['BaseDocument']
elif cls_name in doc_classes:
from mongoengine import document as module
import_classes = doc_classes
elif cls_name in field_classes:

View File

@@ -28,7 +28,7 @@ _connections = {}
_dbs = {}
def register_connection(alias, name=None, host=None, port=None,
def register_connection(alias, db=None, name=None, host=None, port=None,
read_preference=READ_PREFERENCE,
username=None, password=None,
authentication_source=None,
@@ -39,6 +39,7 @@ def register_connection(alias, name=None, host=None, port=None,
:param alias: the name that will be used to refer to this connection
throughout MongoEngine
:param name: the name of the specific database to use
:param db: the name of the database to use, for compatibility with connect
:param host: the host name of the :program:`mongod` instance to connect to
:param port: the port that the :program:`mongod` instance is running on
:param read_preference: The read preference for the collection
@@ -51,12 +52,14 @@ def register_connection(alias, name=None, host=None, port=None,
MONGODB-CR (MongoDB Challenge Response protocol) for older servers.
:param is_mock: explicitly use mongomock for this connection
(can also be done by using `mongomock://` as db host prefix)
:param kwargs: allow ad-hoc parameters to be passed into the pymongo driver
:param kwargs: ad-hoc parameters to be passed into the pymongo driver,
for example maxpoolsize, tz_aware, etc. See the documentation
for pymongo's `MongoClient` for a full list.
.. versionchanged:: 0.10.6 - added mongomock support
"""
conn_settings = {
'name': name or 'test',
'name': name or db or 'test',
'host': host or 'localhost',
'port': port or 27017,
'read_preference': read_preference,
@@ -66,9 +69,9 @@ def register_connection(alias, name=None, host=None, port=None,
'authentication_mechanism': authentication_mechanism
}
# Handle uri style connections
conn_host = conn_settings['host']
# host can be a list or a string, so if string, force to a list
# Host can be a list or a string, so if string, force to a list.
if isinstance(conn_host, six.string_types):
conn_host = [conn_host]
@@ -96,11 +99,23 @@ def register_connection(alias, name=None, host=None, port=None,
uri_options = uri_dict['options']
if 'replicaset' in uri_options:
conn_settings['replicaSet'] = True
conn_settings['replicaSet'] = uri_options['replicaset']
if 'authsource' in uri_options:
conn_settings['authentication_source'] = uri_options['authsource']
if 'authmechanism' in uri_options:
conn_settings['authentication_mechanism'] = uri_options['authmechanism']
if IS_PYMONGO_3 and 'readpreference' in uri_options:
read_preferences = (
ReadPreference.NEAREST,
ReadPreference.PRIMARY,
ReadPreference.PRIMARY_PREFERRED,
ReadPreference.SECONDARY,
ReadPreference.SECONDARY_PREFERRED)
read_pf_mode = uri_options['readpreference'].lower()
for preference in read_preferences:
if preference.name.lower() == read_pf_mode:
conn_settings['read_preference'] = preference
break
else:
resolved_hosts.append(entity)
conn_settings['host'] = resolved_hosts
@@ -144,13 +159,14 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
raise MongoEngineConnectionError(msg)
def _clean_settings(settings_dict):
irrelevant_fields = set([
'name', 'username', 'password', 'authentication_source',
'authentication_mechanism'
])
# set literal more efficient than calling set function
irrelevant_fields_set = {
'name', 'username', 'password',
'authentication_source', 'authentication_mechanism'
}
return {
k: v for k, v in settings_dict.items()
if k not in irrelevant_fields
if k not in irrelevant_fields_set
}
# Retrieve a copy of the connection settings associated with the requested
@@ -170,23 +186,22 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
else:
connection_class = MongoClient
# Handle replica set connections
if 'replicaSet' in conn_settings:
# For replica set connections with PyMongo 2.x, use
# MongoReplicaSetClient.
# TODO remove this once we stop supporting PyMongo 2.x.
if 'replicaSet' in conn_settings and not IS_PYMONGO_3:
connection_class = MongoReplicaSetClient
conn_settings['hosts_or_uri'] = conn_settings.pop('host', None)
# hosts_or_uri has to be a string, so if 'host' was provided
# as a list, join its parts and separate them by ','
if isinstance(conn_settings['hosts_or_uri'], list):
conn_settings['hosts_or_uri'] = ','.join(
conn_settings['hosts_or_uri'])
# Discard port since it can't be used on MongoReplicaSetClient
conn_settings.pop('port', None)
# Discard replicaSet if it's not a string
if not isinstance(conn_settings['replicaSet'], six.string_types):
del conn_settings['replicaSet']
# For replica set connections with PyMongo 2.x, use
# MongoReplicaSetClient.
# TODO remove this once we stop supporting PyMongo 2.x.
if not IS_PYMONGO_3:
connection_class = MongoReplicaSetClient
conn_settings['hosts_or_uri'] = conn_settings.pop('host', None)
# Iterate over all of the connection settings and if a connection with
# the same parameters is already established, use it instead of creating
# a new one.
@@ -242,9 +257,12 @@ def connect(db=None, alias=DEFAULT_CONNECTION_NAME, **kwargs):
running on the default port on localhost. If authentication is needed,
provide username and password arguments as well.
Multiple databases are supported by using aliases. Provide a separate
Multiple databases are supported by using aliases. Provide a separate
`alias` to connect to a different instance of :program:`mongod`.
See the docstring for `register_connection` for more details about all
supported kwargs.
.. versionchanged:: 0.6 - added multiple database support.
"""
if alias not in _connections:

View File

@@ -1,9 +1,11 @@
from contextlib import contextmanager
from pymongo.write_concern import WriteConcern
from mongoengine.common import _import_class
from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db
__all__ = ('switch_db', 'switch_collection', 'no_dereference',
'no_sub_classes', 'query_counter')
'no_sub_classes', 'query_counter', 'set_write_concern')
class switch_db(object):
@@ -143,66 +145,83 @@ class no_sub_classes(object):
:param cls: the class to turn querying sub classes on
"""
self.cls = cls
self.cls_initial_subclasses = None
def __enter__(self):
"""Change the objects default and _auto_dereference values."""
self.cls._all_subclasses = self.cls._subclasses
self.cls._subclasses = (self.cls,)
self.cls_initial_subclasses = self.cls._subclasses
self.cls._subclasses = (self.cls._class_name,)
return self.cls
def __exit__(self, t, value, traceback):
"""Reset the default and _auto_dereference values."""
self.cls._subclasses = self.cls._all_subclasses
delattr(self.cls, '_all_subclasses')
return self.cls
self.cls._subclasses = self.cls_initial_subclasses
class query_counter(object):
"""Query_counter context manager to get the number of queries."""
"""Query_counter context manager to get the number of queries.
This works by updating the `profiling_level` of the database so that all queries get logged,
resetting the db.system.profile collection at the beginnig of the context and counting the new entries.
This was designed for debugging purpose. In fact it is a global counter so queries issued by other threads/processes
can interfere with it
Be aware that:
- Iterating over large amount of documents (>101) makes pymongo issue `getmore` queries to fetch the next batch of
documents (https://docs.mongodb.com/manual/tutorial/iterate-a-cursor/#cursor-batches)
- Some queries are ignored by default by the counter (killcursors, db.system.indexes)
"""
def __init__(self):
"""Construct the query_counter."""
self.counter = 0
"""Construct the query_counter
"""
self.db = get_db()
self.initial_profiling_level = None
self._ctx_query_counter = 0 # number of queries issued by the context
def __enter__(self):
"""On every with block we need to drop the profile collection."""
self._ignored_query = {
'ns':
{'$ne': '%s.system.indexes' % self.db.name},
'op':
{'$ne': 'killcursors'}
}
def _turn_on_profiling(self):
self.initial_profiling_level = self.db.profiling_level()
self.db.set_profiling_level(0)
self.db.system.profile.drop()
self.db.set_profiling_level(2)
def _resets_profiling(self):
self.db.set_profiling_level(self.initial_profiling_level)
def __enter__(self):
self._turn_on_profiling()
return self
def __exit__(self, t, value, traceback):
"""Reset the profiling level."""
self.db.set_profiling_level(0)
self._resets_profiling()
def __eq__(self, value):
"""== Compare querycounter."""
counter = self._get_count()
return value == counter
def __ne__(self, value):
"""!= Compare querycounter."""
return not self.__eq__(value)
def __lt__(self, value):
"""< Compare querycounter."""
return self._get_count() < value
def __le__(self, value):
"""<= Compare querycounter."""
return self._get_count() <= value
def __gt__(self, value):
"""> Compare querycounter."""
return self._get_count() > value
def __ge__(self, value):
""">= Compare querycounter."""
return self._get_count() >= value
def __int__(self):
"""int representation."""
return self._get_count()
def __repr__(self):
@@ -210,8 +229,17 @@ class query_counter(object):
return u"%s" % self._get_count()
def _get_count(self):
"""Get the number of queries."""
ignore_query = {'ns': {'$ne': '%s.system.indexes' % self.db.name}}
count = self.db.system.profile.find(ignore_query).count() - self.counter
self.counter += 1
"""Get the number of queries by counting the current number of entries in db.system.profile
and substracting the queries issued by this context. In fact everytime this is called, 1 query is
issued so we need to balance that
"""
count = self.db.system.profile.find(self._ignored_query).count() - self._ctx_query_counter
self._ctx_query_counter += 1 # Account for the query we just issued to gather the information
return count
@contextmanager
def set_write_concern(collection, write_concerns):
combined_concerns = dict(collection.write_concern.document.items())
combined_concerns.update(write_concerns)
yield collection.with_options(write_concern=WriteConcern(**combined_concerns))

View File

@@ -3,6 +3,7 @@ import six
from mongoengine.base import (BaseDict, BaseList, EmbeddedDocumentList,
TopLevelDocumentMetaclass, get_document)
from mongoengine.base.datastructures import LazyReference
from mongoengine.connection import get_db
from mongoengine.document import Document, EmbeddedDocument
from mongoengine.fields import DictField, ListField, MapField, ReferenceField
@@ -99,7 +100,10 @@ class DeReference(object):
if isinstance(item, (Document, EmbeddedDocument)):
for field_name, field in item._fields.iteritems():
v = item._data.get(field_name, None)
if isinstance(v, DBRef):
if isinstance(v, LazyReference):
# LazyReference inherits DBRef but should not be dereferenced here !
continue
elif isinstance(v, DBRef):
reference_map.setdefault(field.document_type, set()).add(v.id)
elif isinstance(v, (dict, SON)) and '_ref' in v:
reference_map.setdefault(get_document(v['_cls']), set()).add(v['_ref'].id)
@@ -110,6 +114,9 @@ class DeReference(object):
if isinstance(field_cls, (Document, TopLevelDocumentMetaclass)):
key = field_cls
reference_map.setdefault(key, set()).update(refs)
elif isinstance(item, LazyReference):
# LazyReference inherits DBRef but should not be dereferenced here !
continue
elif isinstance(item, DBRef):
reference_map.setdefault(item.collection, set()).add(item.id)
elif isinstance(item, (dict, SON)) and '_ref' in item:
@@ -126,7 +133,12 @@ class DeReference(object):
"""
object_map = {}
for collection, dbrefs in self.reference_map.iteritems():
if hasattr(collection, 'objects'): # We have a document class for the refs
# we use getattr instead of hasattr because as hasattr swallows any exception under python2
# so it could hide nasty things without raising exceptions (cfr bug #1688))
ref_document_cls_exists = (getattr(collection, 'objects', None) is not None)
if ref_document_cls_exists:
col_name = collection._get_collection_name()
refs = [dbref for dbref in dbrefs
if (col_name, dbref) not in object_map]
@@ -134,7 +146,7 @@ class DeReference(object):
for key, doc in references.iteritems():
object_map[(col_name, key)] = doc
else: # Generic reference: use the refs data to convert to document
if isinstance(doc_type, (ListField, DictField, MapField,)):
if isinstance(doc_type, (ListField, DictField, MapField)):
continue
refs = [dbref for dbref in dbrefs
@@ -230,7 +242,7 @@ class DeReference(object):
elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth:
item_name = '%s.%s' % (name, k) if name else name
data[k] = self._attach_objects(v, depth - 1, instance=instance, name=item_name)
elif hasattr(v, 'id'):
elif isinstance(v, DBRef) and hasattr(v, 'id'):
data[k] = self.object_map.get((v.collection, v.id), v)
if instance and name:

View File

@@ -39,7 +39,7 @@ class InvalidCollectionError(Exception):
pass
class EmbeddedDocument(BaseDocument):
class EmbeddedDocument(six.with_metaclass(DocumentMetaclass, BaseDocument)):
"""A :class:`~mongoengine.Document` that isn't stored in its own
collection. :class:`~mongoengine.EmbeddedDocument`\ s should be used as
fields on :class:`~mongoengine.Document`\ s through the
@@ -58,7 +58,12 @@ class EmbeddedDocument(BaseDocument):
# The __metaclass__ attribute is removed by 2to3 when running with Python3
# my_metaclass is defined so that metaclass can be queried in Python 2 & 3
my_metaclass = DocumentMetaclass
__metaclass__ = DocumentMetaclass
# A generic embedded document doesn't have any immutable properties
# that describe it uniquely, hence it shouldn't be hashable. You can
# define your own __hash__ method on a subclass if you need your
# embedded documents to be hashable.
__hash__ = None
def __init__(self, *args, **kwargs):
super(EmbeddedDocument, self).__init__(*args, **kwargs)
@@ -89,7 +94,7 @@ class EmbeddedDocument(BaseDocument):
self._instance.reload(*args, **kwargs)
class Document(BaseDocument):
class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
"""The base class used for defining the structure and properties of
collections of documents stored in MongoDB. Inherit from this class, and
add fields as class attributes to define a document's structure.
@@ -144,7 +149,6 @@ class Document(BaseDocument):
# The __metaclass__ attribute is removed by 2to3 when running with Python3
# my_metaclass is defined so that metaclass can be queried in Python 2 & 3
my_metaclass = TopLevelDocumentMetaclass
__metaclass__ = TopLevelDocumentMetaclass
__slots__ = ('__objects',)
@@ -160,6 +164,15 @@ class Document(BaseDocument):
"""Set the primary key."""
return setattr(self, self._meta['id_field'], value)
def __hash__(self):
"""Return the hash based on the PK of this document. If it's new
and doesn't have a PK yet, return the default object hash instead.
"""
if self.pk is None:
return super(BaseDocument, self).__hash__()
else:
return hash(self.pk)
@classmethod
def _get_db(cls):
"""Some Model using other db_alias"""
@@ -167,45 +180,66 @@ class Document(BaseDocument):
@classmethod
def _get_collection(cls):
"""Returns the collection for the document."""
# TODO: use new get_collection() with PyMongo3 ?
"""Return a PyMongo collection for the document."""
if not hasattr(cls, '_collection') or cls._collection is None:
db = cls._get_db()
collection_name = cls._get_collection_name()
# Create collection as a capped collection if specified
if cls._meta.get('max_size') or cls._meta.get('max_documents'):
# Get max document limit and max byte size from meta
max_size = cls._meta.get('max_size') or 10 * 2 ** 20 # 10MB default
max_documents = cls._meta.get('max_documents')
# Round up to next 256 bytes as MongoDB would do it to avoid exception
if max_size % 256:
max_size = (max_size // 256 + 1) * 256
if collection_name in db.collection_names():
cls._collection = db[collection_name]
# The collection already exists, check if its capped
# options match the specified capped options
options = cls._collection.options()
if options.get('max') != max_documents or \
options.get('size') != max_size:
msg = (('Cannot create collection "%s" as a capped '
'collection as it already exists')
% cls._collection)
raise InvalidCollectionError(msg)
else:
# Create the collection as a capped collection
opts = {'capped': True, 'size': max_size}
if max_documents:
opts['max'] = max_documents
cls._collection = db.create_collection(
collection_name, **opts
)
# Get the collection, either capped or regular.
if cls._meta.get('max_size') or cls._meta.get('max_documents'):
cls._collection = cls._get_capped_collection()
else:
db = cls._get_db()
collection_name = cls._get_collection_name()
cls._collection = db[collection_name]
if cls._meta.get('auto_create_index', True):
# Ensure indexes on the collection unless auto_create_index was
# set to False.
# Also there is no need to ensure indexes on slave.
db = cls._get_db()
if cls._meta.get('auto_create_index', True) and\
db.client.is_primary:
cls.ensure_indexes()
return cls._collection
@classmethod
def _get_capped_collection(cls):
"""Create a new or get an existing capped PyMongo collection."""
db = cls._get_db()
collection_name = cls._get_collection_name()
# Get max document limit and max byte size from meta.
max_size = cls._meta.get('max_size') or 10 * 2 ** 20 # 10MB default
max_documents = cls._meta.get('max_documents')
# MongoDB will automatically raise the size to make it a multiple of
# 256 bytes. We raise it here ourselves to be able to reliably compare
# the options below.
if max_size % 256:
max_size = (max_size // 256 + 1) * 256
# If the collection already exists and has different options
# (i.e. isn't capped or has different max/size), raise an error.
if collection_name in db.collection_names():
collection = db[collection_name]
options = collection.options()
if (
options.get('max') != max_documents or
options.get('size') != max_size
):
raise InvalidCollectionError(
'Cannot create collection "{}" as a capped '
'collection as it already exists'.format(cls._collection)
)
return collection
# Create a new capped collection.
opts = {'capped': True, 'size': max_size}
if max_documents:
opts['max'] = max_documents
return db.create_collection(collection_name, **opts)
def to_mongo(self, *args, **kwargs):
data = super(Document, self).to_mongo(*args, **kwargs)
@@ -247,6 +281,9 @@ class Document(BaseDocument):
elif query[id_field] != self.pk:
raise InvalidQueryError('Invalid document modify query: it must modify only this document.')
# Need to add shard key to query, or you get an error
query.update(self._object_key)
updated = self._qs(**query).modify(new=True, **update)
if updated is None:
return False
@@ -267,7 +304,7 @@ class Document(BaseDocument):
created.
:param force_insert: only try to create a new document, don't allow
updates of existing documents
updates of existing documents.
:param validate: validates the document; set to ``False`` to skip.
:param clean: call the document clean method, requires `validate` to be
True.
@@ -287,7 +324,7 @@ class Document(BaseDocument):
:param save_condition: only perform save if matching record in db
satisfies condition(s) (e.g. version number).
Raises :class:`OperationError` if the conditions are not satisfied
:parm signal_kwargs: (optional) kwargs dictionary to be passed to
:param signal_kwargs: (optional) kwargs dictionary to be passed to
the signal calls.
.. versionchanged:: 0.5
@@ -313,6 +350,9 @@ class Document(BaseDocument):
.. versionchanged:: 0.10.7
Add signal_kwargs argument
"""
if self._meta.get('abstract'):
raise InvalidDocumentError('Cannot save an abstract document.')
signal_kwargs = signal_kwargs or {}
signals.pre_save.send(self.__class__, document=self, **signal_kwargs)
@@ -329,68 +369,20 @@ class Document(BaseDocument):
signals.pre_save_post_validation.send(self.__class__, document=self,
created=created, **signal_kwargs)
if self._meta.get('auto_create_index', True):
self.ensure_indexes()
try:
collection = self._get_collection()
if self._meta.get('auto_create_index', True):
self.ensure_indexes()
# Save a new document or update an existing one
if created:
if force_insert:
object_id = collection.insert(doc, **write_concern)
else:
object_id = collection.save(doc, **write_concern)
# In PyMongo 3.0, the save() call calls internally the _update() call
# but they forget to return the _id value passed back, therefore getting it back here
# Correct behaviour in 2.X and in 3.0.1+ versions
if not object_id and pymongo.version_tuple == (3, 0):
pk_as_mongo_obj = self._fields.get(self._meta['id_field']).to_mongo(self.pk)
object_id = (
self._qs.filter(pk=pk_as_mongo_obj).first() and
self._qs.filter(pk=pk_as_mongo_obj).first().pk
) # TODO doesn't this make 2 queries?
object_id = self._save_create(doc, force_insert, write_concern)
else:
object_id = doc['_id']
updates, removals = self._delta()
# Need to add shard key to query, or you get an error
if save_condition is not None:
select_dict = transform.query(self.__class__,
**save_condition)
else:
select_dict = {}
select_dict['_id'] = object_id
shard_key = self._meta.get('shard_key', tuple())
for k in shard_key:
path = self._lookup_field(k.split('.'))
actual_key = [p.db_field for p in path]
val = doc
for ak in actual_key:
val = val[ak]
select_dict['.'.join(actual_key)] = val
def is_new_object(last_error):
if last_error is not None:
updated = last_error.get('updatedExisting')
if updated is not None:
return not updated
return created
update_query = {}
if updates:
update_query['$set'] = updates
if removals:
update_query['$unset'] = removals
if updates or removals:
upsert = save_condition is None
last_error = collection.update(select_dict, update_query,
upsert=upsert, **write_concern)
if not upsert and last_error['n'] == 0:
raise SaveConditionError('Race condition preventing'
' document update detected')
created = is_new_object(last_error)
object_id, created = self._save_update(doc, save_condition,
write_concern)
if cascade is None:
cascade = self._meta.get(
'cascade', False) or cascade_kwargs is not None
cascade = (self._meta.get('cascade', False) or
cascade_kwargs is not None)
if cascade:
kwargs = {
@@ -403,6 +395,7 @@ class Document(BaseDocument):
kwargs.update(cascade_kwargs)
kwargs['_refs'] = _refs
self.cascade_save(**kwargs)
except pymongo.errors.DuplicateKeyError as err:
message = u'Tried to save duplicate unique keys (%s)'
raise NotUniqueError(message % six.text_type(err))
@@ -415,16 +408,101 @@ class Document(BaseDocument):
raise NotUniqueError(message % six.text_type(err))
raise OperationError(message % six.text_type(err))
# Make sure we store the PK on this document now that it's saved
id_field = self._meta['id_field']
if created or id_field not in self._meta.get('shard_key', []):
self[id_field] = self._fields[id_field].to_python(object_id)
signals.post_save.send(self.__class__, document=self,
created=created, **signal_kwargs)
self._clear_changed_fields()
self._created = False
return self
def _save_create(self, doc, force_insert, write_concern):
"""Save a new document.
Helper method, should only be used inside save().
"""
collection = self._get_collection()
if force_insert:
return collection.insert(doc, **write_concern)
object_id = collection.save(doc, **write_concern)
# In PyMongo 3.0, the save() call calls internally the _update() call
# but they forget to return the _id value passed back, therefore getting it back here
# Correct behaviour in 2.X and in 3.0.1+ versions
if not object_id and pymongo.version_tuple == (3, 0):
pk_as_mongo_obj = self._fields.get(self._meta['id_field']).to_mongo(self.pk)
object_id = (
self._qs.filter(pk=pk_as_mongo_obj).first() and
self._qs.filter(pk=pk_as_mongo_obj).first().pk
) # TODO doesn't this make 2 queries?
return object_id
def _get_update_doc(self):
"""Return a dict containing all the $set and $unset operations
that should be sent to MongoDB based on the changes made to this
Document.
"""
updates, removals = self._delta()
update_doc = {}
if updates:
update_doc['$set'] = updates
if removals:
update_doc['$unset'] = removals
return update_doc
def _save_update(self, doc, save_condition, write_concern):
"""Update an existing document.
Helper method, should only be used inside save().
"""
collection = self._get_collection()
object_id = doc['_id']
created = False
select_dict = {}
if save_condition is not None:
select_dict = transform.query(self.__class__, **save_condition)
select_dict['_id'] = object_id
# Need to add shard key to query, or you get an error
shard_key = self._meta.get('shard_key', tuple())
for k in shard_key:
path = self._lookup_field(k.split('.'))
actual_key = [p.db_field for p in path]
val = doc
for ak in actual_key:
val = val[ak]
select_dict['.'.join(actual_key)] = val
update_doc = self._get_update_doc()
if update_doc:
upsert = save_condition is None
last_error = collection.update(select_dict, update_doc,
upsert=upsert, **write_concern)
if not upsert and last_error['n'] == 0:
raise SaveConditionError('Race condition preventing'
' document update detected')
if last_error is not None:
updated_existing = last_error.get('updatedExisting')
if updated_existing is False:
created = True
# !!! This is bad, means we accidentally created a new,
# potentially corrupted document. See
# https://github.com/MongoEngine/mongoengine/issues/564
return object_id, created
def cascade_save(self, **kwargs):
"""Recursively save any references and generic references on the
document.
@@ -502,12 +580,11 @@ class Document(BaseDocument):
"""Delete the :class:`~mongoengine.Document` from the database. This
will only take effect if the document has been previously saved.
:parm signal_kwargs: (optional) kwargs dictionary to be passed to
:param signal_kwargs: (optional) kwargs dictionary to be passed to
the signal calls.
:param write_concern: Extra keyword arguments are passed down which
will be used as options for the resultant
``getLastError`` command. For example,
``save(..., write_concern={w: 2, fsync: True}, ...)`` will
will be used as options for the resultant ``getLastError`` command.
For example, ``save(..., w: 2, fsync: True)`` will
wait until at least two servers have recorded the write and
will force an fsync on the primary server.
@@ -628,7 +705,6 @@ class Document(BaseDocument):
obj = obj[0]
else:
raise self.DoesNotExist('Document does not exist')
for field in obj._data:
if not fields or field in fields:
try:
@@ -636,7 +712,7 @@ class Document(BaseDocument):
except (KeyError, AttributeError):
try:
# If field is a special field, e.g. items is stored as _reserved_items,
# an KeyError is thrown. So try to retrieve the field from _data
# a KeyError is thrown. So try to retrieve the field from _data
setattr(self, field, self._reload(field, obj._data.get(field)))
except KeyError:
# If field is removed from the database while the object
@@ -644,7 +720,9 @@ class Document(BaseDocument):
# i.e. obj.update(unset__field=1) followed by obj.reload()
delattr(self, field)
self._changed_fields = obj._changed_fields
self._changed_fields = list(
set(self._changed_fields) - set(fields)
) if fields else obj._changed_fields
self._created = False
return self
@@ -828,7 +906,6 @@ class Document(BaseDocument):
""" Lists all of the indexes that should be created for given
collection. It includes all the indexes from super- and sub-classes.
"""
if cls._meta.get('abstract'):
return []
@@ -891,8 +968,16 @@ class Document(BaseDocument):
"""
required = cls.list_indexes()
existing = [info['key']
for info in cls._get_collection().index_information().values()]
existing = []
for info in cls._get_collection().index_information().values():
if '_fts' in info['key'][0]:
index_type = info['key'][0][1]
text_index_fields = info.get('weights').keys()
existing.append(
[(key, index_type) for key in text_index_fields])
else:
existing.append(info['key'])
missing = [index for index in required if index not in existing]
extra = [index for index in existing if index not in required]
@@ -909,10 +994,10 @@ class Document(BaseDocument):
return {'missing': missing, 'extra': extra}
class DynamicDocument(Document):
class DynamicDocument(six.with_metaclass(TopLevelDocumentMetaclass, Document)):
"""A Dynamic Document class allowing flexible, expandable and uncontrolled
schemas. As a :class:`~mongoengine.Document` subclass, acts in the same
way as an ordinary document but has expando style properties. Any data
way as an ordinary document but has expanded style properties. Any data
passed or set against the :class:`~mongoengine.DynamicDocument` that is
not a field is automatically converted into a
:class:`~mongoengine.fields.DynamicField` and data can be attributed to that
@@ -926,7 +1011,6 @@ class DynamicDocument(Document):
# The __metaclass__ attribute is removed by 2to3 when running with Python3
# my_metaclass is defined so that metaclass can be queried in Python 2 & 3
my_metaclass = TopLevelDocumentMetaclass
__metaclass__ = TopLevelDocumentMetaclass
_dynamic = True
@@ -937,11 +1021,12 @@ class DynamicDocument(Document):
field_name = args[0]
if field_name in self._dynamic_fields:
setattr(self, field_name, None)
self._dynamic_fields[field_name].null = False
else:
super(DynamicDocument, self).__delattr__(*args, **kwargs)
class DynamicEmbeddedDocument(EmbeddedDocument):
class DynamicEmbeddedDocument(six.with_metaclass(DocumentMetaclass, EmbeddedDocument)):
"""A Dynamic Embedded Document class allowing flexible, expandable and
uncontrolled schemas. See :class:`~mongoengine.DynamicDocument` for more
information about dynamic documents.
@@ -950,7 +1035,6 @@ class DynamicEmbeddedDocument(EmbeddedDocument):
# The __metaclass__ attribute is removed by 2to3 when running with Python3
# my_metaclass is defined so that metaclass can be queried in Python 2 & 3
my_metaclass = DocumentMetaclass
__metaclass__ = DocumentMetaclass
_dynamic = True

View File

@@ -50,8 +50,8 @@ class FieldDoesNotExist(Exception):
or an :class:`~mongoengine.EmbeddedDocument`.
To avoid this behavior on data loading,
you should the :attr:`strict` to ``False``
in the :attr:`meta` dictionnary.
you should set the :attr:`strict` to ``False``
in the :attr:`meta` dictionary.
"""

View File

@@ -2,9 +2,9 @@ import datetime
import decimal
import itertools
import re
import socket
import time
import uuid
import warnings
from operator import itemgetter
from bson import Binary, DBRef, ObjectId, SON
@@ -24,11 +24,15 @@ try:
except ImportError:
Int64 = long
from mongoengine.base import (BaseDocument, BaseField, ComplexBaseField,
GeoJsonBaseField, ObjectIdField, get_document)
GeoJsonBaseField, LazyReference, ObjectIdField,
get_document)
from mongoengine.base.utils import LazyRegexCompiler
from mongoengine.common import _import_class
from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db
from mongoengine.document import Document, EmbeddedDocument
from mongoengine.errors import DoesNotExist, ValidationError
from mongoengine.errors import DoesNotExist, InvalidQueryError, ValidationError
from mongoengine.python_support import StringIO
from mongoengine.queryset import DO_NOTHING, QuerySet
@@ -38,13 +42,20 @@ except ImportError:
Image = None
ImageOps = None
if six.PY3:
# Useless as long as 2to3 gets executed
# as it turns `long` into `int` blindly
long = int
__all__ = (
'StringField', 'URLField', 'EmailField', 'IntField', 'LongField',
'FloatField', 'DecimalField', 'BooleanField', 'DateTimeField',
'FloatField', 'DecimalField', 'BooleanField', 'DateTimeField', 'DateField',
'ComplexDateTimeField', 'EmbeddedDocumentField', 'ObjectIdField',
'GenericEmbeddedDocumentField', 'DynamicField', 'ListField',
'SortedListField', 'EmbeddedDocumentListField', 'DictField',
'MapField', 'ReferenceField', 'CachedReferenceField',
'LazyReferenceField', 'GenericLazyReferenceField',
'GenericReferenceField', 'BinaryField', 'GridFSError', 'GridFSProxy',
'FileField', 'ImageGridFsProxy', 'ImproperlyConfigured', 'ImageField',
'GeoPointField', 'PointField', 'LineStringField', 'PolygonField',
@@ -119,7 +130,7 @@ class URLField(StringField):
.. versionadded:: 0.3
"""
_URL_REGEX = re.compile(
_URL_REGEX = LazyRegexCompiler(
r'^(?:[a-z0-9\.\-]*)://' # scheme is validated separately
r'(?:(?:[A-Z0-9](?:[A-Z0-9-_]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}(?<!-)\.?)|' # domain...
r'localhost|' # localhost...
@@ -139,12 +150,12 @@ class URLField(StringField):
# Check first if the scheme is valid
scheme = value.split('://')[0].lower()
if scheme not in self.schemes:
self.error('Invalid scheme {} in URL: {}'.format(scheme, value))
self.error(u'Invalid scheme {} in URL: {}'.format(scheme, value))
return
# Then check full URL
if not self.url_regex.match(value):
self.error('Invalid URL: {}'.format(value))
self.error(u'Invalid URL: {}'.format(value))
return
@@ -153,21 +164,105 @@ class EmailField(StringField):
.. versionadded:: 0.4
"""
EMAIL_REGEX = re.compile(
# dot-atom
r"(^[-!#$%&'*+/=?^_`{}|~0-9A-Z]+(\.[-!#$%&'*+/=?^_`{}|~0-9A-Z]+)*"
# quoted-string
r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]|\\[\001-011\013\014\016-\177])*"'
# domain (max length of an ICAAN TLD is 22 characters)
r')@(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}|[A-Z0-9-]{2,}(?<!-))$', re.IGNORECASE
USER_REGEX = LazyRegexCompiler(
# `dot-atom` defined in RFC 5322 Section 3.2.3.
r"(^[-!#$%&'*+/=?^_`{}|~0-9A-Z]+(\.[-!#$%&'*+/=?^_`{}|~0-9A-Z]+)*\Z"
# `quoted-string` defined in RFC 5322 Section 3.2.4.
r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]|\\[\001-\011\013\014\016-\177])*"\Z)',
re.IGNORECASE
)
UTF8_USER_REGEX = LazyRegexCompiler(
six.u(
# RFC 6531 Section 3.3 extends `atext` (used by dot-atom) to
# include `UTF8-non-ascii`.
r"(^[-!#$%&'*+/=?^_`{}|~0-9A-Z\u0080-\U0010FFFF]+(\.[-!#$%&'*+/=?^_`{}|~0-9A-Z\u0080-\U0010FFFF]+)*\Z"
# `quoted-string`
r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]|\\[\001-\011\013\014\016-\177])*"\Z)'
), re.IGNORECASE | re.UNICODE
)
DOMAIN_REGEX = LazyRegexCompiler(
r'((?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+)(?:[A-Z0-9-]{2,63}(?<!-))\Z',
re.IGNORECASE
)
error_msg = u'Invalid email address: %s'
def __init__(self, domain_whitelist=None, allow_utf8_user=False,
allow_ip_domain=False, *args, **kwargs):
"""Initialize the EmailField.
Args:
domain_whitelist (list) - list of otherwise invalid domain
names which you'd like to support.
allow_utf8_user (bool) - if True, the user part of the email
address can contain UTF8 characters.
False by default.
allow_ip_domain (bool) - if True, the domain part of the email
can be a valid IPv4 or IPv6 address.
"""
self.domain_whitelist = domain_whitelist or []
self.allow_utf8_user = allow_utf8_user
self.allow_ip_domain = allow_ip_domain
super(EmailField, self).__init__(*args, **kwargs)
def validate_user_part(self, user_part):
"""Validate the user part of the email address. Return True if
valid and False otherwise.
"""
if self.allow_utf8_user:
return self.UTF8_USER_REGEX.match(user_part)
return self.USER_REGEX.match(user_part)
def validate_domain_part(self, domain_part):
"""Validate the domain part of the email address. Return True if
valid and False otherwise.
"""
# Skip domain validation if it's in the whitelist.
if domain_part in self.domain_whitelist:
return True
if self.DOMAIN_REGEX.match(domain_part):
return True
# Validate IPv4/IPv6, e.g. user@[192.168.0.1]
if (
self.allow_ip_domain and
domain_part[0] == '[' and
domain_part[-1] == ']'
):
for addr_family in (socket.AF_INET, socket.AF_INET6):
try:
socket.inet_pton(addr_family, domain_part[1:-1])
return True
except (socket.error, UnicodeEncodeError):
pass
return False
def validate(self, value):
if not EmailField.EMAIL_REGEX.match(value):
self.error('Invalid email address: %s' % value)
super(EmailField, self).validate(value)
if '@' not in value:
self.error(self.error_msg % value)
user_part, domain_part = value.rsplit('@', 1)
# Validate the user part.
if not self.validate_user_part(user_part):
self.error(self.error_msg % value)
# Validate the domain and, if invalid, see if it's IDN-encoded.
if not self.validate_domain_part(domain_part):
try:
domain_part = domain_part.encode('idna').decode('ascii')
except UnicodeError:
self.error(self.error_msg % value)
else:
if not self.validate_domain_part(domain_part):
self.error(self.error_msg % value)
class IntField(BaseField):
"""32-bit integer field."""
@@ -276,7 +371,8 @@ class FloatField(BaseField):
class DecimalField(BaseField):
"""Fixed-point decimal number field.
"""Fixed-point decimal number field. Stores the value as a float by default unless `force_string` is used.
If using floats, beware of Decimal to float conversion (potential precision loss)
.. versionchanged:: 0.8
.. versionadded:: 0.3
@@ -287,7 +383,9 @@ class DecimalField(BaseField):
"""
:param min_value: Validation rule for the minimum acceptable value.
:param max_value: Validation rule for the maximum acceptable value.
:param force_string: Store as a string.
:param force_string: Store the value as a string (instead of a float).
Be aware that this affects query sorting and operation like lte, gte (as string comparison is applied)
and some query operator won't work (e.g: inc, dec)
:param precision: Number of decimal places to store.
:param rounding: The rounding rule from the python decimal library:
@@ -374,6 +472,8 @@ class DateTimeField(BaseField):
installed you can utilise it to convert varying types of date formats into valid
python datetime objects.
Note: To default the field to the current datetime, use: DateTimeField(default=datetime.utcnow)
Note: Microseconds are rounded to the nearest millisecond.
Pre UTC microsecond support is effectively broken.
Use :class:`~mongoengine.fields.ComplexDateTimeField` if you
@@ -398,6 +498,10 @@ class DateTimeField(BaseField):
if not isinstance(value, six.string_types):
return None
value = value.strip()
if not value:
return None
# Attempt to parse a datetime:
if dateutil:
try:
@@ -433,6 +537,22 @@ class DateTimeField(BaseField):
return super(DateTimeField, self).prepare_query_value(op, self.to_mongo(value))
class DateField(DateTimeField):
def to_mongo(self, value):
value = super(DateField, self).to_mongo(value)
# drop hours, minutes, seconds
if isinstance(value, datetime.datetime):
value = datetime.datetime(value.year, value.month, value.day)
return value
def to_python(self, value):
value = super(DateField, self).to_python(value)
# convert datetime to date
if isinstance(value, datetime.datetime):
value = datetime.date(value.year, value.month, value.day)
return value
class ComplexDateTimeField(StringField):
"""
ComplexDateTimeField handles microseconds exactly instead of rounding
@@ -449,11 +569,15 @@ class ComplexDateTimeField(StringField):
The `,` as the separator can be easily modified by passing the `separator`
keyword when initializing the field.
Note: To default the field to the current datetime, use: DateTimeField(default=datetime.utcnow)
.. versionadded:: 0.5
"""
def __init__(self, separator=',', **kwargs):
self.names = ['year', 'month', 'day', 'hour', 'minute', 'second', 'microsecond']
"""
:param separator: Allows to customize the separator used for storage (default ``,``)
"""
self.separator = separator
self.format = separator.join(['%Y', '%m', '%d', '%H', '%M', '%S', '%f'])
super(ComplexDateTimeField, self).__init__(**kwargs)
@@ -480,20 +604,24 @@ class ComplexDateTimeField(StringField):
>>> ComplexDateTimeField()._convert_from_string(a)
datetime.datetime(2011, 6, 8, 20, 26, 24, 92284)
"""
values = map(int, data.split(self.separator))
values = [int(d) for d in data.split(self.separator)]
return datetime.datetime(*values)
def __get__(self, instance, owner):
if instance is None:
return self
data = super(ComplexDateTimeField, self).__get__(instance, owner)
if data is None:
return None if self.null else datetime.datetime.now()
if isinstance(data, datetime.datetime):
if isinstance(data, datetime.datetime) or data is None:
return data
return self._convert_from_string(data)
def __set__(self, instance, value):
value = self._convert_from_datetime(value) if value else value
return super(ComplexDateTimeField, self).__set__(instance, value)
super(ComplexDateTimeField, self).__set__(instance, value)
value = instance._data[self.name]
if value is not None:
instance._data[self.name] = self._convert_from_datetime(value)
def validate(self, value):
value = self.to_python(value)
@@ -522,9 +650,10 @@ class EmbeddedDocumentField(BaseField):
"""
def __init__(self, document_type, **kwargs):
if (
not isinstance(document_type, six.string_types) and
not issubclass(document_type, EmbeddedDocument)
# XXX ValidationError raised outside of the "validate" method.
if not (
isinstance(document_type, six.string_types) or
issubclass(document_type, EmbeddedDocument)
):
self.error('Invalid embedded document class provided to an '
'EmbeddedDocumentField')
@@ -536,9 +665,17 @@ class EmbeddedDocumentField(BaseField):
def document_type(self):
if isinstance(self.document_type_obj, six.string_types):
if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT:
self.document_type_obj = self.owner_document
resolved_document_type = self.owner_document
else:
self.document_type_obj = get_document(self.document_type_obj)
resolved_document_type = get_document(self.document_type_obj)
if not issubclass(resolved_document_type, EmbeddedDocument):
# Due to the late resolution of the document_type
# There is a chance that it won't be an EmbeddedDocument (#1661)
self.error('Invalid embedded document class provided to an '
'EmbeddedDocumentField')
self.document_type_obj = resolved_document_type
return self.document_type_obj
def to_python(self, value):
@@ -566,7 +703,11 @@ class EmbeddedDocumentField(BaseField):
def prepare_query_value(self, op, value):
if value is not None and not isinstance(value, self.document_type):
value = self.document_type._from_son(value)
try:
value = self.document_type._from_son(value)
except ValueError:
raise InvalidQueryError("Querying the embedded document '%s' failed, due to an invalid query value" %
(self.document_type._class_name,))
super(EmbeddedDocumentField, self).prepare_query_value(op, value)
return self.to_mongo(value)
@@ -593,16 +734,28 @@ class GenericEmbeddedDocumentField(BaseField):
return value
def validate(self, value, clean=True):
if self.choices and isinstance(value, SON):
for choice in self.choices:
if value['_cls'] == choice._class_name:
return True
if not isinstance(value, EmbeddedDocument):
self.error('Invalid embedded document instance provided to an '
'GenericEmbeddedDocumentField')
value.validate(clean=clean)
def lookup_member(self, member_name):
if self.choices:
for choice in self.choices:
field = choice._fields.get(member_name)
if field:
return field
return None
def to_mongo(self, document, use_db_field=True, fields=None):
if document is None:
return None
data = document.to_mongo(use_db_field, fields)
if '_cls' not in data:
data['_cls'] = document._class_name
@@ -686,6 +839,17 @@ class ListField(ComplexBaseField):
kwargs.setdefault('default', lambda: [])
super(ListField, self).__init__(**kwargs)
def __get__(self, instance, owner):
if instance is None:
# Document class being used rather than a document object
return self
value = instance._data.get(self.name)
LazyReferenceField = _import_class('LazyReferenceField')
GenericLazyReferenceField = _import_class('GenericLazyReferenceField')
if isinstance(self.field, (LazyReferenceField, GenericLazyReferenceField)) and value:
instance._data[self.name] = [self.field.build_lazyref(x) for x in value]
return super(ListField, self).__get__(instance, owner)
def validate(self, value):
"""Make sure that a list of valid fields is being used."""
if (not isinstance(value, (list, tuple, QuerySet)) or
@@ -796,12 +960,10 @@ class DictField(ComplexBaseField):
.. versionchanged:: 0.5 - Can now handle complex / varying types of data
"""
def __init__(self, basecls=None, field=None, *args, **kwargs):
def __init__(self, field=None, *args, **kwargs):
self.field = field
self._auto_dereference = False
self.basecls = basecls or BaseField
if not issubclass(self.basecls, BaseField):
self.error('DictField only accepts dict values')
kwargs.setdefault('default', lambda: {})
super(DictField, self).__init__(*args, **kwargs)
@@ -820,7 +982,7 @@ class DictField(ComplexBaseField):
super(DictField, self).validate(value)
def lookup_member(self, member_name):
return DictField(basecls=self.basecls, db_field=member_name)
return DictField(db_field=member_name)
def prepare_query_value(self, op, value):
match_operators = ['contains', 'icontains', 'startswith',
@@ -850,6 +1012,7 @@ class MapField(DictField):
"""
def __init__(self, field=None, *args, **kwargs):
# XXX ValidationError raised outside of the "validate" method.
if not isinstance(field, BaseField):
self.error('Argument to MapField constructor must be a valid '
'field')
@@ -860,6 +1023,15 @@ class ReferenceField(BaseField):
"""A reference to a document that will be automatically dereferenced on
access (lazily).
Note this means you will get a database I/O access everytime you access
this field. This is necessary because the field returns a :class:`~mongoengine.Document`
which precise type can depend of the value of the `_cls` field present in the
document in database.
In short, using this type of field can lead to poor performances (especially
if you access this field only to retrieve it `pk` field which is already
known before dereference). To solve this you should consider using the
:class:`~mongoengine.fields.LazyReferenceField`.
Use the `reverse_delete_rule` to handle what should happen if the document
the field is referencing is deleted. EmbeddedDocuments, DictFields and
MapFields does not support reverse_delete_rule and an `InvalidDocumentError`
@@ -878,15 +1050,13 @@ class ReferenceField(BaseField):
.. code-block:: python
class Bar(Document):
content = StringField()
foo = ReferenceField('Foo')
class Org(Document):
owner = ReferenceField('User')
Foo.register_delete_rule(Bar, 'foo', NULLIFY)
class User(Document):
org = ReferenceField('Org', reverse_delete_rule=CASCADE)
.. note ::
`reverse_delete_rule` does not trigger pre / post delete signals to be
triggered.
User.register_delete_rule(Org, 'owner', DENY)
.. versionchanged:: 0.5 added `reverse_delete_rule`
"""
@@ -904,6 +1074,7 @@ class ReferenceField(BaseField):
A reference to an abstract document type is always stored as a
:class:`~pymongo.dbref.DBRef`, regardless of the value of `dbref`.
"""
# XXX ValidationError raised outside of the "validate" method.
if (
not isinstance(document_type, six.string_types) and
not issubclass(document_type, Document)
@@ -958,6 +1129,8 @@ class ReferenceField(BaseField):
if isinstance(document, Document):
# We need the id from the saved object to create the DBRef
id_ = document.pk
# XXX ValidationError raised outside of the "validate" method.
if id_ is None:
self.error('You can only reference documents once they have'
' been saved to the database')
@@ -997,19 +1170,20 @@ class ReferenceField(BaseField):
return self.to_mongo(value)
def validate(self, value):
if not isinstance(value, (self.document_type, DBRef)):
self.error('A ReferenceField only accepts DBRef or documents')
if not isinstance(value, (self.document_type, LazyReference, DBRef, ObjectId)):
self.error('A ReferenceField only accepts DBRef, LazyReference, ObjectId or documents')
if isinstance(value, Document) and value.id is None:
self.error('You can only reference documents once they have been '
'saved to the database')
if self.document_type._meta.get('abstract') and \
not isinstance(value, self.document_type):
if (
self.document_type._meta.get('abstract') and
not isinstance(value, self.document_type)
):
self.error(
'%s is not an instance of abstract reference type %s' % (
self.document_type._class_name)
value, self.document_type._class_name)
)
def lookup_member(self, member_name):
@@ -1032,6 +1206,7 @@ class CachedReferenceField(BaseField):
if fields is None:
fields = []
# XXX ValidationError raised outside of the "validate" method.
if (
not isinstance(document_type, six.string_types) and
not issubclass(document_type, Document)
@@ -1106,6 +1281,7 @@ class CachedReferenceField(BaseField):
id_field_name = self.document_type._meta['id_field']
id_field = self.document_type._fields[id_field_name]
# XXX ValidationError raised outside of the "validate" method.
if isinstance(document, Document):
# We need the id from the saved object to create the DBRef
id_ = document.pk
@@ -1114,7 +1290,6 @@ class CachedReferenceField(BaseField):
' been saved to the database')
else:
self.error('Only accept a document object')
# TODO: should raise here or will fail next statement
value = SON((
('_id', id_field.to_mongo(id_)),
@@ -1132,16 +1307,20 @@ class CachedReferenceField(BaseField):
if value is None:
return None
# XXX ValidationError raised outside of the "validate" method.
if isinstance(value, Document):
if value.pk is None:
self.error('You can only reference documents once they have'
' been saved to the database')
return {'_id': value.pk}
value_dict = {'_id': value.pk}
for field in self.fields:
value_dict.update({field: value[field]})
return value_dict
raise NotImplementedError
def validate(self, value):
if not isinstance(value, self.document_type):
self.error('A CachedReferenceField only accepts documents')
@@ -1174,6 +1353,12 @@ class GenericReferenceField(BaseField):
"""A reference to *any* :class:`~mongoengine.document.Document` subclass
that will be automatically dereferenced on access (lazily).
Note this field works the same way as :class:`~mongoengine.document.ReferenceField`,
doing database I/O access the first time it is accessed (even if it's to access
it ``pk`` or ``id`` field).
To solve this you should consider using the
:class:`~mongoengine.fields.GenericLazyReferenceField`.
.. note ::
* Any documents used as a generic reference must be registered in the
document registry. Importing the model will automatically register
@@ -1196,6 +1381,8 @@ class GenericReferenceField(BaseField):
elif isinstance(choice, type) and issubclass(choice, Document):
self.choices.append(choice._class_name)
else:
# XXX ValidationError raised outside of the "validate"
# method.
self.error('Invalid choices provided: must be a list of'
'Document subclasses and/or six.string_typess')
@@ -1259,6 +1446,7 @@ class GenericReferenceField(BaseField):
# We need the id from the saved object to create the DBRef
id_ = document.id
if id_ is None:
# XXX ValidationError raised outside of the "validate" method.
self.error('You can only reference documents once they have'
' been saved to the database')
else:
@@ -1344,9 +1532,11 @@ class GridFSProxy(object):
def __get__(self, instance, value):
return self
def __nonzero__(self):
def __bool__(self):
return bool(self.grid_id)
__nonzero__ = __bool__ # For Py2 support
def __getstate__(self):
self_dict = self.__dict__
self_dict['_fs'] = None
@@ -1364,9 +1554,9 @@ class GridFSProxy(object):
return '<%s: %s>' % (self.__class__.__name__, self.grid_id)
def __str__(self):
name = getattr(
self.get(), 'filename', self.grid_id) if self.get() else '(no file)'
return '<%s: %s>' % (self.__class__.__name__, name)
gridout = self.get()
filename = getattr(gridout, 'filename') if gridout else '<no file>'
return '<%s: %s (%s)>' % (self.__class__.__name__, filename, self.grid_id)
def __eq__(self, other):
if isinstance(other, GridFSProxy):
@@ -1376,6 +1566,9 @@ class GridFSProxy(object):
else:
return False
def __ne__(self, other):
return not self == other
@property
def fs(self):
if not self._fs:
@@ -2049,3 +2242,201 @@ class MultiPolygonField(GeoJsonBaseField):
.. versionadded:: 0.9
"""
_type = 'MultiPolygon'
class LazyReferenceField(BaseField):
"""A really lazy reference to a document.
Unlike the :class:`~mongoengine.fields.ReferenceField` it will
**not** be automatically (lazily) dereferenced on access.
Instead, access will return a :class:`~mongoengine.base.LazyReference` class
instance, allowing access to `pk` or manual dereference by using
``fetch()`` method.
.. versionadded:: 0.15
"""
def __init__(self, document_type, passthrough=False, dbref=False,
reverse_delete_rule=DO_NOTHING, **kwargs):
"""Initialises the Reference Field.
:param dbref: Store the reference as :class:`~pymongo.dbref.DBRef`
or as the :class:`~pymongo.objectid.ObjectId`.id .
:param reverse_delete_rule: Determines what to do when the referring
object is deleted
:param passthrough: When trying to access unknown fields, the
:class:`~mongoengine.base.datastructure.LazyReference` instance will
automatically call `fetch()` and try to retrive the field on the fetched
document. Note this only work getting field (not setting or deleting).
"""
# XXX ValidationError raised outside of the "validate" method.
if (
not isinstance(document_type, six.string_types) and
not issubclass(document_type, Document)
):
self.error('Argument to LazyReferenceField constructor must be a '
'document class or a string')
self.dbref = dbref
self.passthrough = passthrough
self.document_type_obj = document_type
self.reverse_delete_rule = reverse_delete_rule
super(LazyReferenceField, self).__init__(**kwargs)
@property
def document_type(self):
if isinstance(self.document_type_obj, six.string_types):
if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT:
self.document_type_obj = self.owner_document
else:
self.document_type_obj = get_document(self.document_type_obj)
return self.document_type_obj
def build_lazyref(self, value):
if isinstance(value, LazyReference):
if value.passthrough != self.passthrough:
value = LazyReference(value.document_type, value.pk, passthrough=self.passthrough)
elif value is not None:
if isinstance(value, self.document_type):
value = LazyReference(self.document_type, value.pk, passthrough=self.passthrough)
elif isinstance(value, DBRef):
value = LazyReference(self.document_type, value.id, passthrough=self.passthrough)
else:
# value is the primary key of the referenced document
value = LazyReference(self.document_type, value, passthrough=self.passthrough)
return value
def __get__(self, instance, owner):
"""Descriptor to allow lazy dereferencing."""
if instance is None:
# Document class being used rather than a document object
return self
value = self.build_lazyref(instance._data.get(self.name))
if value:
instance._data[self.name] = value
return super(LazyReferenceField, self).__get__(instance, owner)
def to_mongo(self, value):
if isinstance(value, LazyReference):
pk = value.pk
elif isinstance(value, self.document_type):
pk = value.pk
elif isinstance(value, DBRef):
pk = value.id
else:
# value is the primary key of the referenced document
pk = value
id_field_name = self.document_type._meta['id_field']
id_field = self.document_type._fields[id_field_name]
pk = id_field.to_mongo(pk)
if self.dbref:
return DBRef(self.document_type._get_collection_name(), pk)
else:
return pk
def validate(self, value):
if isinstance(value, LazyReference):
if value.collection != self.document_type._get_collection_name():
self.error('Reference must be on a `%s` document.' % self.document_type)
pk = value.pk
elif isinstance(value, self.document_type):
pk = value.pk
elif isinstance(value, DBRef):
# TODO: check collection ?
collection = self.document_type._get_collection_name()
if value.collection != collection:
self.error("DBRef on bad collection (must be on `%s`)" % collection)
pk = value.id
else:
# value is the primary key of the referenced document
id_field_name = self.document_type._meta['id_field']
id_field = getattr(self.document_type, id_field_name)
pk = value
try:
id_field.validate(pk)
except ValidationError:
self.error(
"value should be `{0}` document, LazyReference or DBRef on `{0}` "
"or `{0}`'s primary key (i.e. `{1}`)".format(
self.document_type.__name__, type(id_field).__name__))
if pk is None:
self.error('You can only reference documents once they have been '
'saved to the database')
def prepare_query_value(self, op, value):
if value is None:
return None
super(LazyReferenceField, self).prepare_query_value(op, value)
return self.to_mongo(value)
def lookup_member(self, member_name):
return self.document_type._fields.get(member_name)
class GenericLazyReferenceField(GenericReferenceField):
"""A reference to *any* :class:`~mongoengine.document.Document` subclass.
Unlike the :class:`~mongoengine.fields.GenericReferenceField` it will
**not** be automatically (lazily) dereferenced on access.
Instead, access will return a :class:`~mongoengine.base.LazyReference` class
instance, allowing access to `pk` or manual dereference by using
``fetch()`` method.
.. note ::
* Any documents used as a generic reference must be registered in the
document registry. Importing the model will automatically register
it.
* You can use the choices param to limit the acceptable Document types
.. versionadded:: 0.15
"""
def __init__(self, *args, **kwargs):
self.passthrough = kwargs.pop('passthrough', False)
super(GenericLazyReferenceField, self).__init__(*args, **kwargs)
def _validate_choices(self, value):
if isinstance(value, LazyReference):
value = value.document_type._class_name
super(GenericLazyReferenceField, self)._validate_choices(value)
def build_lazyref(self, value):
if isinstance(value, LazyReference):
if value.passthrough != self.passthrough:
value = LazyReference(value.document_type, value.pk, passthrough=self.passthrough)
elif value is not None:
if isinstance(value, (dict, SON)):
value = LazyReference(get_document(value['_cls']), value['_ref'].id, passthrough=self.passthrough)
elif isinstance(value, Document):
value = LazyReference(type(value), value.pk, passthrough=self.passthrough)
return value
def __get__(self, instance, owner):
if instance is None:
return self
value = self.build_lazyref(instance._data.get(self.name))
if value:
instance._data[self.name] = value
return super(GenericLazyReferenceField, self).__get__(instance, owner)
def validate(self, value):
if isinstance(value, LazyReference) and value.pk is None:
self.error('You can only reference documents once they have been'
' saved to the database')
return super(GenericLazyReferenceField, self).validate(value)
def to_mongo(self, document):
if document is None:
return None
if isinstance(document, LazyReference):
return SON((
('_cls', document.document_type._class_name),
('_ref', DBRef(document.document_type._get_collection_name(), document.pk))
))
else:
return super(GenericLazyReferenceField, self).to_mongo(document)

View File

@@ -6,11 +6,7 @@ import pymongo
import six
if pymongo.version_tuple[0] < 3:
IS_PYMONGO_3 = False
else:
IS_PYMONGO_3 = True
IS_PYMONGO_3 = pymongo.version_tuple[0] >= 3
# six.BytesIO resolves to StringIO.StringIO in Py2 and io.BytesIO in Py3.
StringIO = six.BytesIO

View File

@@ -2,7 +2,6 @@ from __future__ import absolute_import
import copy
import itertools
import operator
import pprint
import re
import warnings
@@ -18,7 +17,7 @@ from mongoengine import signals
from mongoengine.base import get_document
from mongoengine.common import _import_class
from mongoengine.connection import get_db
from mongoengine.context_managers import switch_db
from mongoengine.context_managers import set_write_concern, switch_db
from mongoengine.errors import (InvalidQueryError, LookUpError,
NotUniqueError, OperationError)
from mongoengine.python_support import IS_PYMONGO_3
@@ -67,7 +66,6 @@ class BaseQuerySet(object):
self._scalar = []
self._none = False
self._as_pymongo = False
self._as_pymongo_coerce = False
self._search_text = None
# If inheritance is allowed, only return instances and instances of
@@ -86,6 +84,7 @@ class BaseQuerySet(object):
self._batch_size = None
self.only_fields = []
self._max_time_ms = None
self._comment = None
def __call__(self, q_obj=None, class_check=True, read_preference=None,
**query):
@@ -157,44 +156,49 @@ class BaseQuerySet(object):
# self._cursor
def __getitem__(self, key):
"""Support skip and limit using getitem and slicing syntax."""
"""Return a document instance corresponding to a given index if
the key is an integer. If the key is a slice, translate its
bounds into a skip and a limit, and return a cloned queryset
with that skip/limit applied. For example:
>>> User.objects[0]
<User: User object>
>>> User.objects[1:3]
[<User: User object>, <User: User object>]
"""
queryset = self.clone()
# Slice provided
# Handle a slice
if isinstance(key, slice):
try:
queryset._cursor_obj = queryset._cursor[key]
queryset._skip, queryset._limit = key.start, key.stop
if key.start and key.stop:
queryset._limit = key.stop - key.start
except IndexError as err:
# PyMongo raises an error if key.start == key.stop, catch it,
# bin it, kill it.
start = key.start or 0
if start >= 0 and key.stop >= 0 and key.step is None:
if start == key.stop:
queryset.limit(0)
queryset._skip = key.start
queryset._limit = key.stop - start
return queryset
raise err
queryset._cursor_obj = queryset._cursor[key]
queryset._skip, queryset._limit = key.start, key.stop
if key.start and key.stop:
queryset._limit = key.stop - key.start
# Allow further QuerySet modifications to be performed
return queryset
# Integer index provided
# Handle an index
elif isinstance(key, int):
if queryset._scalar:
return queryset._get_scalar(
queryset._document._from_son(queryset._cursor[key],
_auto_dereference=self._auto_dereference,
only_fields=self.only_fields))
queryset._document._from_son(
queryset._cursor[key],
_auto_dereference=self._auto_dereference,
only_fields=self.only_fields
)
)
if queryset._as_pymongo:
return queryset._get_as_pymongo(queryset._cursor[key])
return queryset._document._from_son(queryset._cursor[key],
_auto_dereference=self._auto_dereference,
only_fields=self.only_fields)
raise AttributeError
return queryset._document._from_son(
queryset._cursor[key],
_auto_dereference=self._auto_dereference,
only_fields=self.only_fields
)
raise AttributeError('Provide a slice or an integer index')
def __iter__(self):
raise NotImplementedError
@@ -204,14 +208,12 @@ class BaseQuerySet(object):
queryset = self.order_by()
return False if queryset.first() is None else True
def __nonzero__(self):
"""Avoid to open all records in an if stmt in Py2."""
return self._has_data()
def __bool__(self):
"""Avoid to open all records in an if stmt in Py3."""
return self._has_data()
__nonzero__ = __bool__ # For Py2 support
# Core functions
def all(self):
@@ -264,13 +266,13 @@ class BaseQuerySet(object):
queryset = queryset.filter(*q_objs, **query)
try:
result = queryset.next()
result = six.next(queryset)
except StopIteration:
msg = ('%s matching query does not exist.'
% queryset._document._class_name)
raise queryset._document.DoesNotExist(msg)
try:
queryset.next()
six.next(queryset)
except StopIteration:
return result
@@ -285,7 +287,7 @@ class BaseQuerySet(object):
.. versionadded:: 0.4
"""
return self._document(**kwargs).save()
return self._document(**kwargs).save(force_insert=True)
def first(self):
"""Retrieve the first object matching the query."""
@@ -345,11 +347,24 @@ class BaseQuerySet(object):
documents=docs, **signal_kwargs)
raw = [doc.to_mongo() for doc in docs]
with set_write_concern(self._collection, write_concern) as collection:
insert_func = collection.insert_many
if return_one:
raw = raw[0]
insert_func = collection.insert_one
try:
ids = self._collection.insert(raw, **write_concern)
inserted_result = insert_func(raw)
ids = return_one and [inserted_result.inserted_id] or inserted_result.inserted_ids
except pymongo.errors.DuplicateKeyError as err:
message = 'Could not save document (%s)'
raise NotUniqueError(message % six.text_type(err))
except pymongo.errors.BulkWriteError as err:
# inserting documents that already have an _id field will
# give huge performance debt or raise
message = u'Document must not have _id value before bulk write (%s)'
raise NotUniqueError(message % six.text_type(err))
except pymongo.errors.OperationFailure as err:
message = 'Could not save document (%s)'
if re.match('^E1100[01] duplicate key', six.text_type(err)):
@@ -363,7 +378,6 @@ class BaseQuerySet(object):
signals.post_bulk_insert.send(
self._document, documents=docs, loaded=False, **signal_kwargs)
return return_one and ids[0] or ids
documents = self.in_bulk(ids)
results = []
for obj_id in ids:
@@ -379,7 +393,7 @@ class BaseQuerySet(object):
:meth:`skip` that has been applied to this cursor into account when
getting the count
"""
if self._limit == 0 and with_limit_and_skip or self._none:
if self._limit == 0 and with_limit_and_skip is False or self._none:
return 0
return self._cursor.count(with_limit_and_skip=with_limit_and_skip)
@@ -481,8 +495,9 @@ class BaseQuerySet(object):
``save(..., write_concern={w: 2, fsync: True}, ...)`` will
wait until at least two servers have recorded the write and
will force an fsync on the primary server.
:param full_result: Return the full result rather than just the number
updated.
:param full_result: Return the full result dictionary rather than just the number
updated, e.g. return
``{'n': 2, 'nModified': 2, 'ok': 1.0, 'updatedExisting': True}``.
:param update: Django-style update keyword arguments
.. versionadded:: 0.2
@@ -505,12 +520,15 @@ class BaseQuerySet(object):
else:
update['$set'] = {'_cls': queryset._document._class_name}
try:
result = queryset._collection.update(query, update, multi=multi,
upsert=upsert, **write_concern)
with set_write_concern(queryset._collection, write_concern) as collection:
update_func = collection.update_one
if multi:
update_func = collection.update_many
result = update_func(query, update, upsert=upsert)
if full_result:
return result
elif result:
return result['n']
elif result.raw_result:
return result.raw_result['n']
except pymongo.errors.DuplicateKeyError as err:
raise NotUniqueError(u'Update failed (%s)' % six.text_type(err))
except pymongo.errors.OperationFailure as err:
@@ -539,10 +557,10 @@ class BaseQuerySet(object):
write_concern=write_concern,
full_result=True, **update)
if atomic_update['updatedExisting']:
if atomic_update.raw_result['updatedExisting']:
document = self.get()
else:
document = self._document.objects.with_id(atomic_update['upserted'])
document = self._document.objects.with_id(atomic_update.upserted_id)
return document
def update_one(self, upsert=False, write_concern=None, **update):
@@ -706,39 +724,37 @@ class BaseQuerySet(object):
with switch_db(self._document, alias) as cls:
collection = cls._get_collection()
return self.clone_into(self.__class__(self._document, collection))
return self._clone_into(self.__class__(self._document, collection))
def clone(self):
"""Creates a copy of the current
:class:`~mongoengine.queryset.QuerySet`
"""Create a copy of the current queryset."""
return self._clone_into(self.__class__(self._document, self._collection_obj))
.. versionadded:: 0.5
def _clone_into(self, new_qs):
"""Copy all of the relevant properties of this queryset to
a new queryset (which has to be an instance of
:class:`~mongoengine.queryset.base.BaseQuerySet`).
"""
return self.clone_into(self.__class__(self._document, self._collection_obj))
def clone_into(self, cls):
"""Creates a copy of the current
:class:`~mongoengine.queryset.base.BaseQuerySet` into another child class
"""
if not isinstance(cls, BaseQuerySet):
if not isinstance(new_qs, BaseQuerySet):
raise OperationError(
'%s is not a subclass of BaseQuerySet' % cls.__name__)
'%s is not a subclass of BaseQuerySet' % new_qs.__name__)
copy_props = ('_mongo_query', '_initial_query', '_none', '_query_obj',
'_where_clause', '_loaded_fields', '_ordering', '_snapshot',
'_timeout', '_class_check', '_slave_okay', '_read_preference',
'_iter', '_scalar', '_as_pymongo', '_as_pymongo_coerce',
'_where_clause', '_loaded_fields', '_ordering',
'_snapshot', '_timeout', '_class_check', '_slave_okay',
'_read_preference', '_iter', '_scalar', '_as_pymongo',
'_limit', '_skip', '_hint', '_auto_dereference',
'_search_text', 'only_fields', '_max_time_ms')
'_search_text', 'only_fields', '_max_time_ms',
'_comment')
for prop in copy_props:
val = getattr(self, prop)
setattr(cls, prop, copy.copy(val))
setattr(new_qs, prop, copy.copy(val))
if self._cursor_obj:
cls._cursor_obj = self._cursor_obj.clone()
new_qs._cursor_obj = self._cursor_obj.clone()
return cls
return new_qs
def select_related(self, max_depth=1):
"""Handles dereferencing of :class:`~bson.dbref.DBRef` objects or
@@ -756,11 +772,16 @@ class BaseQuerySet(object):
"""Limit the number of returned documents to `n`. This may also be
achieved using array-slicing syntax (e.g. ``User.objects[:5]``).
:param n: the maximum number of objects to return
:param n: the maximum number of objects to return if n is greater than 0.
When 0 is passed, returns all the documents in the cursor
"""
queryset = self.clone()
queryset._limit = n if n != 0 else 1
# Return self to allow chaining
queryset._limit = n
# If a cursor object has already been created, apply the limit to it.
if queryset._cursor_obj:
queryset._cursor_obj.limit(queryset._limit)
return queryset
def skip(self, n):
@@ -771,6 +792,11 @@ class BaseQuerySet(object):
"""
queryset = self.clone()
queryset._skip = n
# If a cursor object has already been created, apply the skip to it.
if queryset._cursor_obj:
queryset._cursor_obj.skip(queryset._skip)
return queryset
def hint(self, index=None):
@@ -788,6 +814,11 @@ class BaseQuerySet(object):
"""
queryset = self.clone()
queryset._hint = index
# If a cursor object has already been created, apply the hint to it.
if queryset._cursor_obj:
queryset._cursor_obj.hint(queryset._hint)
return queryset
def batch_size(self, size):
@@ -801,6 +832,11 @@ class BaseQuerySet(object):
"""
queryset = self.clone()
queryset._batch_size = size
# If a cursor object has already been created, apply the batch size to it.
if queryset._cursor_obj:
queryset._cursor_obj.batch_size(queryset._batch_size)
return queryset
def distinct(self, field):
@@ -900,18 +936,25 @@ class BaseQuerySet(object):
return self.fields(**fields)
def fields(self, _only_called=False, **kwargs):
"""Manipulate how you load this document's fields. Used by `.only()`
and `.exclude()` to manipulate which fields to retrieve. Fields also
allows for a greater level of control for example:
"""Manipulate how you load this document's fields. Used by `.only()`
and `.exclude()` to manipulate which fields to retrieve. If called
directly, use a set of kwargs similar to the MongoDB projection
document. For example:
Retrieving a Subrange of Array Elements:
Include only a subset of fields:
You can use the $slice operator to retrieve a subrange of elements in
an array. For example to get the first 5 comments::
posts = BlogPost.objects(...).fields(author=1, title=1)
post = BlogPost.objects(...).fields(slice__comments=5)
Exclude a specific field:
:param kwargs: A dictionary identifying what to include
posts = BlogPost.objects(...).fields(comments=0)
To retrieve a subrange of array elements:
posts = BlogPost.objects(...).fields(slice__comments=5)
:param kwargs: A set of keyword arguments identifying what to
include, exclude, or slice.
.. versionadded:: 0.5
"""
@@ -927,7 +970,20 @@ class BaseQuerySet(object):
key = '.'.join(parts)
cleaned_fields.append((key, value))
fields = sorted(cleaned_fields, key=operator.itemgetter(1))
# Sort fields by their values, explicitly excluded fields first, then
# explicitly included, and then more complicated operators such as
# $slice.
def _sort_key(field_tuple):
key, value = field_tuple
if isinstance(value, (int)):
return value # 0 for exclusion, 1 for inclusion
else:
return 2 # so that complex values appear last
fields = sorted(cleaned_fields, key=_sort_key)
# Clone the queryset, group all fields by their value, convert
# each of them to db_fields, and set the queryset's _loaded_fields
queryset = self.clone()
for value, group in itertools.groupby(fields, lambda x: x[1]):
fields = [field for field, value in group]
@@ -953,13 +1009,31 @@ class BaseQuerySet(object):
def order_by(self, *keys):
"""Order the :class:`~mongoengine.queryset.QuerySet` by the keys. The
order may be specified by prepending each of the keys by a + or a -.
Ascending order is assumed.
Ascending order is assumed. If no keys are passed, existing ordering
is cleared instead.
:param keys: fields to order the query results by; keys may be
prefixed with **+** or **-** to determine the ordering direction
"""
queryset = self.clone()
queryset._ordering = queryset._get_order_by(keys)
old_ordering = queryset._ordering
new_ordering = queryset._get_order_by(keys)
if queryset._cursor_obj:
# If a cursor object has already been created, apply the sort to it
if new_ordering:
queryset._cursor_obj.sort(new_ordering)
# If we're trying to clear a previous explicit ordering, we need
# to clear the cursor entirely (because PyMongo doesn't allow
# clearing an existing sort on a cursor).
elif old_ordering:
queryset._cursor_obj = None
queryset._ordering = new_ordering
return queryset
def comment(self, text):
@@ -1069,16 +1143,15 @@ class BaseQuerySet(object):
"""An alias for scalar"""
return self.scalar(*fields)
def as_pymongo(self, coerce_types=False):
def as_pymongo(self):
"""Instead of returning Document instances, return raw values from
pymongo.
:param coerce_types: Field types (if applicable) would be use to
coerce types.
This method is particularly useful if you don't need dereferencing
and care primarily about the speed of data retrieval.
"""
queryset = self.clone()
queryset._as_pymongo = True
queryset._as_pymongo_coerce = coerce_types
return queryset
def max_time_ms(self, ms):
@@ -1123,6 +1196,10 @@ class BaseQuerySet(object):
pipeline = initial_pipeline + list(pipeline)
if IS_PYMONGO_3 and self._read_preference is not None:
return self._collection.with_options(read_preference=self._read_preference) \
.aggregate(pipeline, cursor={}, **kwargs)
return self._collection.aggregate(pipeline, cursor={}, **kwargs)
# JS functionality
@@ -1398,27 +1475,31 @@ class BaseQuerySet(object):
# Iterator helpers
def next(self):
def __next__(self):
"""Wrap the result in a :class:`~mongoengine.Document` object.
"""
if self._limit == 0 or self._none:
raise StopIteration
raw_doc = self._cursor.next()
raw_doc = six.next(self._cursor)
if self._as_pymongo:
return self._get_as_pymongo(raw_doc)
doc = self._document._from_son(raw_doc,
_auto_dereference=self._auto_dereference, only_fields=self.only_fields)
doc = self._document._from_son(
raw_doc, _auto_dereference=self._auto_dereference,
only_fields=self.only_fields)
if self._scalar:
return self._get_scalar(doc)
return doc
next = __next__ # For Python2 support
def rewind(self):
"""Rewind the cursor to its unevaluated state.
.. versionadded:: 0.3
"""
self._iter = False
@@ -1468,43 +1549,57 @@ class BaseQuerySet(object):
@property
def _cursor(self):
if self._cursor_obj is None:
"""Return a PyMongo cursor object corresponding to this queryset."""
# In PyMongo 3+, we define the read preference on a collection
# level, not a cursor level. Thus, we need to get a cloned
# collection object using `with_options` first.
if IS_PYMONGO_3 and self._read_preference is not None:
self._cursor_obj = self._collection\
.with_options(read_preference=self._read_preference)\
.find(self._query, **self._cursor_args)
else:
self._cursor_obj = self._collection.find(self._query,
**self._cursor_args)
# Apply where clauses to cursor
if self._where_clause:
where_clause = self._sub_js_fields(self._where_clause)
self._cursor_obj.where(where_clause)
# If _cursor_obj already exists, return it immediately.
if self._cursor_obj is not None:
return self._cursor_obj
if self._ordering:
# Apply query ordering
self._cursor_obj.sort(self._ordering)
elif self._ordering is None and self._document._meta['ordering']:
# Otherwise, apply the ordering from the document model, unless
# it's been explicitly cleared via order_by with no arguments
order = self._get_order_by(self._document._meta['ordering'])
self._cursor_obj.sort(order)
# Create a new PyMongo cursor.
# XXX In PyMongo 3+, we define the read preference on a collection
# level, not a cursor level. Thus, we need to get a cloned collection
# object using `with_options` first.
if IS_PYMONGO_3 and self._read_preference is not None:
self._cursor_obj = self._collection\
.with_options(read_preference=self._read_preference)\
.find(self._query, **self._cursor_args)
else:
self._cursor_obj = self._collection.find(self._query,
**self._cursor_args)
# Apply "where" clauses to cursor
if self._where_clause:
where_clause = self._sub_js_fields(self._where_clause)
self._cursor_obj.where(where_clause)
if self._limit is not None:
self._cursor_obj.limit(self._limit)
# Apply ordering to the cursor.
# XXX self._ordering can be equal to:
# * None if we didn't explicitly call order_by on this queryset.
# * A list of PyMongo-style sorting tuples.
# * An empty list if we explicitly called order_by() without any
# arguments. This indicates that we want to clear the default
# ordering.
if self._ordering:
# explicit ordering
self._cursor_obj.sort(self._ordering)
elif self._ordering is None and self._document._meta['ordering']:
# default ordering
order = self._get_order_by(self._document._meta['ordering'])
self._cursor_obj.sort(order)
if self._skip is not None:
self._cursor_obj.skip(self._skip)
if self._limit is not None:
self._cursor_obj.limit(self._limit)
if self._hint != -1:
self._cursor_obj.hint(self._hint)
if self._skip is not None:
self._cursor_obj.skip(self._skip)
if self._batch_size is not None:
self._cursor_obj.batch_size(self._batch_size)
if self._hint != -1:
self._cursor_obj.hint(self._hint)
if self._batch_size is not None:
self._cursor_obj.batch_size(self._batch_size)
if self._comment is not None:
self._cursor_obj.comment(self._comment)
return self._cursor_obj
@@ -1650,25 +1745,33 @@ class BaseQuerySet(object):
return frequencies
def _fields_to_dbfields(self, fields):
"""Translate fields paths to its db equivalents"""
ret = []
"""Translate fields' paths to their db equivalents."""
subclasses = []
document = self._document
if document._meta['allow_inheritance']:
if self._document._meta['allow_inheritance']:
subclasses = [get_document(x)
for x in document._subclasses][1:]
for x in self._document._subclasses][1:]
db_field_paths = []
for field in fields:
field_parts = field.split('.')
try:
field = '.'.join(f.db_field for f in
document._lookup_field(field.split('.')))
ret.append(field)
field = '.'.join(
f if isinstance(f, six.string_types) else f.db_field
for f in self._document._lookup_field(field_parts)
)
db_field_paths.append(field)
except LookUpError as err:
found = False
# If a field path wasn't found on the main document, go
# through its subclasses and see if it exists on any of them.
for subdoc in subclasses:
try:
subfield = '.'.join(f.db_field for f in
subdoc._lookup_field(field.split('.')))
ret.append(subfield)
subfield = '.'.join(
f if isinstance(f, six.string_types) else f.db_field
for f in subdoc._lookup_field(field_parts)
)
db_field_paths.append(subfield)
found = True
break
except LookUpError:
@@ -1676,10 +1779,17 @@ class BaseQuerySet(object):
if not found:
raise err
return ret
return db_field_paths
def _get_order_by(self, keys):
"""Creates a list of order by fields"""
"""Given a list of MongoEngine-style sort keys, return a list
of sorting tuples that can be applied to a PyMongo cursor. For
example:
>>> qs._get_order_by(['-last_name', 'first_name'])
[('last_name', -1), ('first_name', 1)]
"""
key_list = []
for key in keys:
if not key:
@@ -1692,17 +1802,19 @@ class BaseQuerySet(object):
direction = pymongo.ASCENDING
if key[0] == '-':
direction = pymongo.DESCENDING
if key[0] in ('-', '+'):
key = key[1:]
key = key.replace('__', '.')
try:
key = self._document._translate_field_name(key)
except Exception:
# TODO this exception should be more specific
pass
key_list.append((key, direction))
if self._cursor_obj and key_list:
self._cursor_obj.sort(key_list)
return key_list
def _get_scalar(self, doc):
@@ -1719,59 +1831,25 @@ class BaseQuerySet(object):
return tuple(data)
def _get_as_pymongo(self, row):
# Extract which fields paths we should follow if .fields(...) was
# used. If not, handle all fields.
if not getattr(self, '__as_pymongo_fields', None):
self.__as_pymongo_fields = []
def _get_as_pymongo(self, doc):
"""Clean up a PyMongo doc, removing fields that were only fetched
for the sake of MongoEngine's implementation, and return it.
"""
# Always remove _cls as a MongoEngine's implementation detail.
if '_cls' in doc:
del doc['_cls']
for field in self._loaded_fields.fields - set(['_cls']):
self.__as_pymongo_fields.append(field)
while '.' in field:
field, _ = field.rsplit('.', 1)
self.__as_pymongo_fields.append(field)
# If the _id was not included in a .only or was excluded in a .exclude,
# remove it from the doc (we always fetch it so that we can properly
# construct documents).
fields = self._loaded_fields
if fields and '_id' in doc and (
(fields.value == QueryFieldList.ONLY and '_id' not in fields.fields) or
(fields.value == QueryFieldList.EXCLUDE and '_id' in fields.fields)
):
del doc['_id']
all_fields = not self.__as_pymongo_fields
def clean(data, path=None):
path = path or ''
if isinstance(data, dict):
new_data = {}
for key, value in data.iteritems():
new_path = '%s.%s' % (path, key) if path else key
if all_fields:
include_field = True
elif self._loaded_fields.value == QueryFieldList.ONLY:
include_field = new_path in self.__as_pymongo_fields
else:
include_field = new_path not in self.__as_pymongo_fields
if include_field:
new_data[key] = clean(value, path=new_path)
data = new_data
elif isinstance(data, list):
data = [clean(d, path=path) for d in data]
else:
if self._as_pymongo_coerce:
# If we need to coerce types, we need to determine the
# type of this field and use the corresponding
# .to_python(...)
EmbeddedDocumentField = _import_class('EmbeddedDocumentField')
obj = self._document
for chunk in path.split('.'):
obj = getattr(obj, chunk, None)
if obj is None:
break
elif isinstance(obj, EmbeddedDocumentField):
obj = obj.document_type
if obj and data is not None:
data = obj.to_python(data)
return data
return clean(row)
return doc
def _sub_js_fields(self, code):
"""When fields are specified with [~fieldname] syntax, where
@@ -1800,10 +1878,21 @@ class BaseQuerySet(object):
return code
def _chainable_method(self, method_name, val):
"""Call a particular method on the PyMongo cursor call a particular chainable method
with the provided value.
"""
queryset = self.clone()
method = getattr(queryset._cursor, method_name)
method(val)
# Get an existing cursor object or create a new one
cursor = queryset._cursor
# Find the requested method on the cursor and call it with the
# provided value
getattr(cursor, method_name)(val)
# Cache the value on the queryset._{method_name}
setattr(queryset, '_' + method_name, val)
return queryset
# Deprecated

View File

@@ -63,9 +63,11 @@ class QueryFieldList(object):
self._only_called = True
return self
def __nonzero__(self):
def __bool__(self):
return bool(self.fields)
__nonzero__ = __bool__ # For Py2 support
def as_dict(self):
field_list = {field: self.value for field in self.fields}
if self.slice:

View File

@@ -36,7 +36,7 @@ class QuerySetManager(object):
queryset_class = owner._meta.get('queryset_class', self.default)
queryset = queryset_class(owner, owner._get_collection())
if self.get_queryset:
arg_count = self.get_queryset.func_code.co_argcount
arg_count = self.get_queryset.__code__.co_argcount
if arg_count == 1:
queryset = self.get_queryset(queryset)
elif arg_count == 2:

View File

@@ -1,3 +1,5 @@
import six
from mongoengine.errors import OperationError
from mongoengine.queryset.base import (BaseQuerySet, CASCADE, DENY, DO_NOTHING,
NULLIFY, PULL)
@@ -87,10 +89,10 @@ class QuerySet(BaseQuerySet):
yield self._result_cache[pos]
pos += 1
# Raise StopIteration if we already established there were no more
# return if we already established there were no more
# docs in the db cursor.
if not self._has_more:
raise StopIteration
return
# Otherwise, populate more of the cache and repeat.
if len(self._result_cache) <= pos:
@@ -112,8 +114,8 @@ class QuerySet(BaseQuerySet):
# Pull in ITER_CHUNK_SIZE docs from the database and store them in
# the result cache.
try:
for _ in xrange(ITER_CHUNK_SIZE):
self._result_cache.append(self.next())
for _ in six.moves.range(ITER_CHUNK_SIZE):
self._result_cache.append(six.next(self))
except StopIteration:
# Getting this exception means there are no more docs in the
# db cursor. Set _has_more to False so that we can use that
@@ -136,13 +138,15 @@ class QuerySet(BaseQuerySet):
return self._len
def no_cache(self):
"""Convert to a non_caching queryset
"""Convert to a non-caching queryset
.. versionadded:: 0.8.3 Convert to non caching queryset
"""
if self._result_cache is not None:
raise OperationError('QuerySet already cached')
return self.clone_into(QuerySetNoCache(self._document, self._collection))
return self._clone_into(QuerySetNoCache(self._document,
self._collection))
class QuerySetNoCache(BaseQuerySet):
@@ -153,7 +157,7 @@ class QuerySetNoCache(BaseQuerySet):
.. versionadded:: 0.8.3 Convert to caching queryset
"""
return self.clone_into(QuerySet(self._document, self._collection))
return self._clone_into(QuerySet(self._document, self._collection))
def __repr__(self):
"""Provides the string representation of the QuerySet
@@ -164,9 +168,9 @@ class QuerySetNoCache(BaseQuerySet):
return '.. queryset mid-iteration ..'
data = []
for _ in xrange(REPR_OUTPUT_SIZE + 1):
for _ in six.moves.range(REPR_OUTPUT_SIZE + 1):
try:
data.append(self.next())
data.append(six.next(self))
except StopIteration:
break

View File

@@ -101,8 +101,8 @@ def query(_doc_cls=None, **kwargs):
value = value['_id']
elif op in ('in', 'nin', 'all', 'near') and not isinstance(value, dict):
# 'in', 'nin' and 'all' require a list of values
value = [field.prepare_query_value(op, v) for v in value]
# Raise an error if the in/nin/all/near param is not iterable.
value = _prepare_query_for_iterable(field, op, value)
# If we're querying a GenericReferenceField, we need to alter the
# key depending on the value:
@@ -147,7 +147,7 @@ def query(_doc_cls=None, **kwargs):
if op is None or key not in mongo_query:
mongo_query[key] = value
elif key in mongo_query:
if isinstance(mongo_query[key], dict):
if isinstance(mongo_query[key], dict) and isinstance(value, dict):
mongo_query[key].update(value)
# $max/minDistance needs to come last - convert to SON
value_dict = mongo_query[key]
@@ -201,31 +201,37 @@ def update(_doc_cls=None, **update):
format.
"""
mongo_update = {}
for key, value in update.items():
if key == '__raw__':
mongo_update.update(value)
continue
parts = key.split('__')
# if there is no operator, default to 'set'
if len(parts) < 3 and parts[0] not in UPDATE_OPERATORS:
parts.insert(0, 'set')
# Check for an operator and transform to mongo-style if there is
op = None
if parts[0] in UPDATE_OPERATORS:
op = parts.pop(0)
# Convert Pythonic names to Mongo equivalents
if op in ('push_all', 'pull_all'):
op = op.replace('_all', 'All')
elif op == 'dec':
operator_map = {
'push_all': 'pushAll',
'pull_all': 'pullAll',
'dec': 'inc',
'add_to_set': 'addToSet',
'set_on_insert': 'setOnInsert'
}
if op == 'dec':
# Support decrement by flipping a positive value's sign
# and using 'inc'
op = 'inc'
if value > 0:
value = -value
elif op == 'add_to_set':
op = 'addToSet'
elif op == 'set_on_insert':
op = 'setOnInsert'
value = -value
# If the operator doesn't found from operator map, the op value
# will stay unchanged
op = operator_map.get(op, op)
match = None
if parts[-1] in COMPARISON_OPERATORS:
@@ -272,7 +278,15 @@ def update(_doc_cls=None, **update):
if isinstance(field, GeoJsonBaseField):
value = field.to_mongo(value)
if op in (None, 'set', 'push', 'pull'):
if op == 'pull':
if field.required or value is not None:
if match == 'in' and not isinstance(value, dict):
value = _prepare_query_for_iterable(field, op, value)
else:
value = field.prepare_query_value(op, value)
elif op == 'push' and isinstance(value, (list, tuple, set)):
value = [field.prepare_query_value(op, v) for v in value]
elif op in (None, 'set', 'push'):
if field.required or value is not None:
value = field.prepare_query_value(op, value)
elif op in ('pushAll', 'pullAll'):
@@ -284,6 +298,8 @@ def update(_doc_cls=None, **update):
value = field.prepare_query_value(op, value)
elif op == 'unset':
value = 1
elif op == 'inc':
value = field.prepare_query_value(op, value)
if match:
match = '$' + match
@@ -307,11 +323,17 @@ def update(_doc_cls=None, **update):
field_classes = [c.__class__ for c in cleaned_fields]
field_classes.reverse()
ListField = _import_class('ListField')
if ListField in field_classes:
# Join all fields via dot notation to the last ListField
EmbeddedDocumentListField = _import_class('EmbeddedDocumentListField')
if ListField in field_classes or EmbeddedDocumentListField in field_classes:
# Join all fields via dot notation to the last ListField or EmbeddedDocumentListField
# Then process as normal
if ListField in field_classes:
_check_field = ListField
else:
_check_field = EmbeddedDocumentListField
last_listField = len(
cleaned_fields) - field_classes.index(ListField)
cleaned_fields) - field_classes.index(_check_field)
key = '.'.join(parts[:last_listField])
parts = parts[last_listField:]
parts.insert(0, key)
@@ -321,10 +343,26 @@ def update(_doc_cls=None, **update):
value = {key: value}
elif op == 'addToSet' and isinstance(value, list):
value = {key: {'$each': value}}
elif op in ('push', 'pushAll'):
if parts[-1].isdigit():
key = parts[0]
position = int(parts[-1])
# $position expects an iterable. If pushing a single value,
# wrap it in a list.
if not isinstance(value, (set, tuple, list)):
value = [value]
value = {key: {'$each': value, '$position': position}}
else:
if op == 'pushAll':
op = 'push' # convert to non-deprecated keyword
if not isinstance(value, (set, tuple, list)):
value = [value]
value = {key: {'$each': value}}
else:
value = {key: value}
else:
value = {key: value}
key = '$' + op
if key not in mongo_update:
mongo_update[key] = value
elif key in mongo_update and isinstance(mongo_update[key], dict):
@@ -413,3 +451,22 @@ def _infer_geometry(value):
raise InvalidQueryError('Invalid $geometry data. Can be either a '
'dictionary or (nested) lists of coordinate(s)')
def _prepare_query_for_iterable(field, op, value):
# We need a special check for BaseDocument, because - although it's iterable - using
# it as such in the context of this method is most definitely a mistake.
BaseDocument = _import_class('BaseDocument')
if isinstance(value, BaseDocument):
raise TypeError("When using the `in`, `nin`, `all`, or "
"`near`-operators you can\'t use a "
"`Document`, you must wrap your object "
"in a list (object -> [object]).")
if not hasattr(value, '__iter__'):
raise TypeError("The `in`, `nin`, `all`, or "
"`near`-operators must be applied to an "
"iterable (e.g. a list).")
return [field.prepare_query_value(op, v) for v in value]