Merge branch 'master' of github.com:MongoEngine/mongoengine into freyr_binaryfield_query_value_type
@@ -23,7 +23,7 @@ __all__ = (list(document.__all__) + list(fields.__all__) +
            list(signals.__all__) + list(errors.__all__))


-VERSION = (0, 14, 0)
+VERSION = (0, 16, 3)


 def get_version():
@@ -15,7 +15,7 @@ __all__ = (
     'UPDATE_OPERATORS', '_document_registry', 'get_document',

     # datastructures
-    'BaseDict', 'BaseList', 'EmbeddedDocumentList',
+    'BaseDict', 'BaseList', 'EmbeddedDocumentList', 'LazyReference',

     # document
     'BaseDocument',
@@ -3,9 +3,10 @@ from mongoengine.errors import NotRegistered
 __all__ = ('UPDATE_OPERATORS', 'get_document', '_document_registry')


-UPDATE_OPERATORS = set(['set', 'unset', 'inc', 'dec', 'pop', 'push',
-                        'push_all', 'pull', 'pull_all', 'add_to_set',
-                        'set_on_insert', 'min', 'max', 'rename'])
+UPDATE_OPERATORS = {'set', 'unset', 'inc', 'dec', 'mul',
+                    'pop', 'push', 'push_all', 'pull',
+                    'pull_all', 'add_to_set', 'set_on_insert',
+                    'min', 'max', 'rename'}


 _document_registry = {}

@@ -18,7 +19,7 @@ def get_document(name):
         # Possible old style name
         single_end = name.split('.')[-1]
         compound_end = '.%s' % single_end
-        possible_match = [k for k in _document_registry.keys()
+        possible_match = [k for k in _document_registry
                           if k.endswith(compound_end) or k == single_end]
         if len(possible_match) == 1:
             doc = _document_registry.get(possible_match.pop(), None)
@@ -1,12 +1,30 @@
-import itertools
 import weakref

+from bson import DBRef
 import six

 from mongoengine.common import _import_class
 from mongoengine.errors import DoesNotExist, MultipleObjectsReturned

-__all__ = ('BaseDict', 'BaseList', 'EmbeddedDocumentList')
+__all__ = ('BaseDict', 'StrictDict', 'BaseList', 'EmbeddedDocumentList', 'LazyReference')


+def mark_as_changed_wrapper(parent_method):
+    """Decorator that ensures the _mark_as_changed method gets called."""
+    def wrapper(self, *args, **kwargs):
+        result = parent_method(self, *args, **kwargs)  # Can't use super() in the decorator
+        self._mark_as_changed()
+        return result
+    return wrapper
+
+
+def mark_key_as_changed_wrapper(parent_method):
+    """Decorator that ensures the _mark_as_changed method gets called with the key argument."""
+    def wrapper(self, key, *args, **kwargs):
+        result = parent_method(self, key, *args, **kwargs)  # Can't use super() in the decorator
+        self._mark_as_changed(key)
+        return result
+    return wrapper
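These two wrappers are what let the BaseDict/BaseList rewrites below collapse a dozen hand-written method overrides into one-line assignments. A minimal self-contained sketch of the pattern, using a hypothetical TrackedList rather than MongoEngine's real classes:

def track_changes(parent_method):
    # Same idea as mark_as_changed_wrapper: run the built-in method,
    # then flag the container as dirty.
    def wrapper(self, *args, **kwargs):
        result = parent_method(self, *args, **kwargs)
        self._mark_as_changed()
        return result
    return wrapper


class TrackedList(list):
    dirty = False

    def _mark_as_changed(self):
        self.dirty = True

    append = track_changes(list.append)
    extend = track_changes(list.extend)


items = TrackedList([1, 2])
items.append(3)
assert items.dirty  # list.append ran first, then the change was recorded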
 class BaseDict(dict):
@@ -17,46 +35,36 @@ class BaseDict(dict):
     _name = None

     def __init__(self, dict_items, instance, name):
-        Document = _import_class('Document')
-        EmbeddedDocument = _import_class('EmbeddedDocument')
+        BaseDocument = _import_class('BaseDocument')

-        if isinstance(instance, (Document, EmbeddedDocument)):
+        if isinstance(instance, BaseDocument):
             self._instance = weakref.proxy(instance)
         self._name = name
         super(BaseDict, self).__init__(dict_items)

-    def __getitem__(self, key, *args, **kwargs):
+    def get(self, key, default=None):
+        # get does not use __getitem__ by default so we must override it as well
+        try:
+            return self.__getitem__(key)
+        except KeyError:
+            return default
+
+    def __getitem__(self, key):
         value = super(BaseDict, self).__getitem__(key)

         EmbeddedDocument = _import_class('EmbeddedDocument')
         if isinstance(value, EmbeddedDocument) and value._instance is None:
             value._instance = self._instance
-        elif not isinstance(value, BaseDict) and isinstance(value, dict):
+        elif isinstance(value, dict) and not isinstance(value, BaseDict):
             value = BaseDict(value, None, '%s.%s' % (self._name, key))
             super(BaseDict, self).__setitem__(key, value)
             value._instance = self._instance
-        elif not isinstance(value, BaseList) and isinstance(value, list):
+        elif isinstance(value, list) and not isinstance(value, BaseList):
             value = BaseList(value, None, '%s.%s' % (self._name, key))
             super(BaseDict, self).__setitem__(key, value)
             value._instance = self._instance
         return value

-    def __setitem__(self, key, value, *args, **kwargs):
-        self._mark_as_changed(key)
-        return super(BaseDict, self).__setitem__(key, value)
-
-    def __delete__(self, *args, **kwargs):
-        self._mark_as_changed()
-        return super(BaseDict, self).__delete__(*args, **kwargs)
-
-    def __delitem__(self, key, *args, **kwargs):
-        self._mark_as_changed(key)
-        return super(BaseDict, self).__delitem__(key)
-
-    def __delattr__(self, key, *args, **kwargs):
-        self._mark_as_changed(key)
-        return super(BaseDict, self).__delattr__(key)
-
     def __getstate__(self):
         self.instance = None
         self._dereferenced = False
@@ -66,25 +74,14 @@ class BaseDict(dict):
         self = state
         return self

-    def clear(self, *args, **kwargs):
-        self._mark_as_changed()
-        return super(BaseDict, self).clear()
-
-    def pop(self, *args, **kwargs):
-        self._mark_as_changed()
-        return super(BaseDict, self).pop(*args, **kwargs)
-
-    def popitem(self, *args, **kwargs):
-        self._mark_as_changed()
-        return super(BaseDict, self).popitem()
-
-    def setdefault(self, *args, **kwargs):
-        self._mark_as_changed()
-        return super(BaseDict, self).setdefault(*args, **kwargs)
-
-    def update(self, *args, **kwargs):
-        self._mark_as_changed()
-        return super(BaseDict, self).update(*args, **kwargs)
+    __setitem__ = mark_key_as_changed_wrapper(dict.__setitem__)
+    __delattr__ = mark_key_as_changed_wrapper(dict.__delattr__)
+    __delitem__ = mark_key_as_changed_wrapper(dict.__delitem__)
+    pop = mark_as_changed_wrapper(dict.pop)
+    clear = mark_as_changed_wrapper(dict.clear)
+    update = mark_as_changed_wrapper(dict.update)
+    popitem = mark_as_changed_wrapper(dict.popitem)
+    setdefault = mark_as_changed_wrapper(dict.setdefault)

     def _mark_as_changed(self, key=None):
         if hasattr(self._instance, '_mark_as_changed'):
@@ -102,52 +99,39 @@ class BaseList(list):
     _name = None

     def __init__(self, list_items, instance, name):
-        Document = _import_class('Document')
-        EmbeddedDocument = _import_class('EmbeddedDocument')
+        BaseDocument = _import_class('BaseDocument')

-        if isinstance(instance, (Document, EmbeddedDocument)):
+        if isinstance(instance, BaseDocument):
             self._instance = weakref.proxy(instance)
         self._name = name
         super(BaseList, self).__init__(list_items)

-    def __getitem__(self, key, *args, **kwargs):
+    def __getitem__(self, key):
         value = super(BaseList, self).__getitem__(key)

+        if isinstance(key, slice):
+            # When receiving a slice operator, we don't convert the structure and bind
+            # to parent's instance. This is buggy for now but would require more work to be handled properly
+            return value
+
         EmbeddedDocument = _import_class('EmbeddedDocument')
         if isinstance(value, EmbeddedDocument) and value._instance is None:
             value._instance = self._instance
-        elif not isinstance(value, BaseDict) and isinstance(value, dict):
+        elif isinstance(value, dict) and not isinstance(value, BaseDict):
+            # Replace dict by BaseDict
             value = BaseDict(value, None, '%s.%s' % (self._name, key))
             super(BaseList, self).__setitem__(key, value)
             value._instance = self._instance
-        elif not isinstance(value, BaseList) and isinstance(value, list):
+        elif isinstance(value, list) and not isinstance(value, BaseList):
+            # Replace list by BaseList
             value = BaseList(value, None, '%s.%s' % (self._name, key))
             super(BaseList, self).__setitem__(key, value)
             value._instance = self._instance
         return value

     def __iter__(self):
-        for i in xrange(self.__len__()):
-            yield self[i]
-
-    def __setitem__(self, key, value, *args, **kwargs):
-        if isinstance(key, slice):
-            self._mark_as_changed()
-        else:
-            self._mark_as_changed(key)
-        return super(BaseList, self).__setitem__(key, value)
-
-    def __delitem__(self, key, *args, **kwargs):
-        self._mark_as_changed()
-        return super(BaseList, self).__delitem__(key)
-
-    def __setslice__(self, *args, **kwargs):
-        self._mark_as_changed()
-        return super(BaseList, self).__setslice__(*args, **kwargs)
-
-    def __delslice__(self, *args, **kwargs):
-        self._mark_as_changed()
-        return super(BaseList, self).__delslice__(*args, **kwargs)
+        for v in super(BaseList, self).__iter__():
+            yield v

     def __getstate__(self):
         self.instance = None
@@ -158,41 +142,40 @@ class BaseList(list):
         self = state
         return self

-    def __iadd__(self, other):
-        self._mark_as_changed()
-        return super(BaseList, self).__iadd__(other)
+    def __setitem__(self, key, value):
+        changed_key = key
+        if isinstance(key, slice):
+            # In case of slice, we don't bother to identify the exact elements being updated;
+            # instead, we simply mark the whole list as changed
+            changed_key = None

-    def __imul__(self, other):
-        self._mark_as_changed()
-        return super(BaseList, self).__imul__(other)
+        result = super(BaseList, self).__setitem__(key, value)
+        self._mark_as_changed(changed_key)
+        return result

-    def append(self, *args, **kwargs):
-        self._mark_as_changed()
-        return super(BaseList, self).append(*args, **kwargs)
+    append = mark_as_changed_wrapper(list.append)
+    extend = mark_as_changed_wrapper(list.extend)
+    insert = mark_as_changed_wrapper(list.insert)
+    pop = mark_as_changed_wrapper(list.pop)
+    remove = mark_as_changed_wrapper(list.remove)
+    reverse = mark_as_changed_wrapper(list.reverse)
+    sort = mark_as_changed_wrapper(list.sort)
+    __delitem__ = mark_as_changed_wrapper(list.__delitem__)
+    __iadd__ = mark_as_changed_wrapper(list.__iadd__)
+    __imul__ = mark_as_changed_wrapper(list.__imul__)

-    def extend(self, *args, **kwargs):
-        self._mark_as_changed()
-        return super(BaseList, self).extend(*args, **kwargs)
+    if six.PY2:
+        # Under py3 __setslice__, __delslice__ and __getslice__
+        # are replaced by __setitem__, __delitem__ and __getitem__ with a slice as parameter
+        # so we mimic this under python 2
+        def __setslice__(self, i, j, sequence):
+            return self.__setitem__(slice(i, j), sequence)

-    def insert(self, *args, **kwargs):
-        self._mark_as_changed()
-        return super(BaseList, self).insert(*args, **kwargs)
+        def __delslice__(self, i, j):
+            return self.__delitem__(slice(i, j))

-    def pop(self, *args, **kwargs):
-        self._mark_as_changed()
-        return super(BaseList, self).pop(*args, **kwargs)
-
-    def remove(self, *args, **kwargs):
-        self._mark_as_changed()
-        return super(BaseList, self).remove(*args, **kwargs)
-
-    def reverse(self, *args, **kwargs):
-        self._mark_as_changed()
-        return super(BaseList, self).reverse()
-
-    def sort(self, *args, **kwargs):
-        self._mark_as_changed()
-        return super(BaseList, self).sort(*args, **kwargs)
+        def __getslice__(self, i, j):
+            return self.__getitem__(slice(i, j))

     def _mark_as_changed(self, key=None):
         if hasattr(self._instance, '_mark_as_changed'):
@@ -206,6 +189,10 @@ class BaseList(list):

 class EmbeddedDocumentList(BaseList):

+    def __init__(self, list_items, instance, name):
+        super(EmbeddedDocumentList, self).__init__(list_items, instance, name)
+        self._instance = instance
+
     @classmethod
     def __match_all(cls, embedded_doc, kwargs):
         """Return True if a given embedded doc matches all the filter
@@ -224,15 +211,14 @@ class EmbeddedDocumentList(BaseList):
             return embedded_docs
         return [doc for doc in embedded_docs if cls.__match_all(doc, kwargs)]

-    def __init__(self, list_items, instance, name):
-        super(EmbeddedDocumentList, self).__init__(list_items, instance, name)
-        self._instance = instance
-
     def filter(self, **kwargs):
         """
         Filters the list by only including embedded documents with the
         given keyword arguments.

+        This method only supports simple comparison (e.g. .filter(name='John Doe'))
+        and does not support operators like __gte, __lte or __icontains the way
+        queryset.filter does.
+
         :param kwargs: The keyword arguments corresponding to the fields to
             filter on. *Multiple arguments are treated as if they are ANDed
             together.*
@@ -350,7 +336,8 @@ class EmbeddedDocumentList(BaseList):

     def update(self, **update):
         """
-        Updates the embedded documents with the given update values.
+        Updates the embedded documents with the given replacement values. This
+        function does not support MongoDB update operators such as ``inc__``.

         .. note::
             The embedded document changes are not automatically saved
@@ -372,7 +359,7 @@ class EmbeddedDocumentList(BaseList):

 class StrictDict(object):
     __slots__ = ()
-    _special_fields = set(['get', 'pop', 'iteritems', 'items', 'keys', 'create'])
+    _special_fields = {'get', 'pop', 'iteritems', 'items', 'keys', 'create'}
     _classes = {}

     def __init__(self, **kwargs):
@@ -447,40 +434,40 @@ class StrictDict(object):
         return cls._classes[allowed_keys]


-class SemiStrictDict(StrictDict):
-    __slots__ = ('_extras', )
-    _classes = {}
+class LazyReference(DBRef):
+    __slots__ = ('_cached_doc', 'passthrough', 'document_type')

-    def __getattr__(self, attr):
-        try:
-            super(SemiStrictDict, self).__getattr__(attr)
-        except AttributeError:
-            try:
-                return self.__getattribute__('_extras')[attr]
-            except KeyError as e:
-                raise AttributeError(e)
+    def fetch(self, force=False):
+        if not self._cached_doc or force:
+            self._cached_doc = self.document_type.objects.get(pk=self.pk)
+            if not self._cached_doc:
+                raise DoesNotExist('Trying to dereference unknown document %s' % (self))
+        return self._cached_doc

-    def __setattr__(self, attr, value):
-        try:
-            super(SemiStrictDict, self).__setattr__(attr, value)
-        except AttributeError:
-            try:
-                self._extras[attr] = value
-            except AttributeError:
-                self._extras = {attr: value}
+    @property
+    def pk(self):
+        return self.id

-    def __delattr__(self, attr):
-        try:
-            super(SemiStrictDict, self).__delattr__(attr)
-        except AttributeError:
-            try:
-                del self._extras[attr]
-            except KeyError as e:
-                raise AttributeError(e)
+    def __init__(self, document_type, pk, cached_doc=None, passthrough=False):
+        self.document_type = document_type
+        self._cached_doc = cached_doc
+        self.passthrough = passthrough
+        super(LazyReference, self).__init__(self.document_type._get_collection_name(), pk)

-    def __iter__(self):
-        try:
-            extras_iter = iter(self.__getattribute__('_extras'))
-        except AttributeError:
-            extras_iter = ()
-        return itertools.chain(super(SemiStrictDict, self).__iter__(), extras_iter)
+    def __getitem__(self, name):
+        if not self.passthrough:
+            raise KeyError()
+        document = self.fetch()
+        return document[name]
+
+    def __getattr__(self, name):
+        if not object.__getattribute__(self, 'passthrough'):
+            raise AttributeError()
+        document = self.fetch()
+        try:
+            return document[name]
+        except KeyError:
+            raise AttributeError()
+
+    def __repr__(self):
+        return "<LazyReference(%s, %r)>" % (self.document_type, self.pk)
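LazyReference defers the database round-trip until the referenced document is actually needed. A usage sketch, assuming a LazyReferenceField-based model (the classes and database name here are illustrative, not part of the diff):

from mongoengine import Document, LazyReferenceField, StringField, connect

connect('example_db')


class Author(Document):
    name = StringField()


class Book(Document):
    author = LazyReferenceField(Author)


author = Author(name='Ada').save()
Book(author=author).save()

lazy = Book.objects.first().author  # a LazyReference: no query issued yet
print(lazy.pk)                      # the pk is known without fetching
print(lazy.fetch().name)            # fetch() queries once and caches the document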
@@ -1,11 +1,8 @@
 import copy
 import numbers
-from collections import Hashable
 from functools import partial

-from bson import ObjectId, json_util
-from bson.dbref import DBRef
-from bson.son import SON
+from bson import DBRef, ObjectId, SON, json_util
 import pymongo
 import six

@@ -13,13 +10,15 @@ from mongoengine import signals
 from mongoengine.base.common import get_document
 from mongoengine.base.datastructures import (BaseDict, BaseList,
                                              EmbeddedDocumentList,
-                                             SemiStrictDict, StrictDict)
+                                             LazyReference,
+                                             StrictDict)
 from mongoengine.base.fields import ComplexBaseField
 from mongoengine.common import _import_class
 from mongoengine.errors import (FieldDoesNotExist, InvalidDocumentError,
                                 LookUpError, OperationError, ValidationError)
+from mongoengine.python_support import Hashable

-__all__ = ('BaseDocument',)
+__all__ = ('BaseDocument', 'NON_FIELD_ERRORS')
+
+NON_FIELD_ERRORS = '__all__'

@@ -79,8 +78,7 @@ class BaseDocument(object):
         if self.STRICT and not self._dynamic:
             self._data = StrictDict.create(allowed_keys=self._fields_ordered)()
         else:
-            self._data = SemiStrictDict.create(
-                allowed_keys=self._fields_ordered)()
+            self._data = {}

         self._dynamic_fields = SON()

@@ -100,13 +98,11 @@ class BaseDocument(object):
             for key, value in values.iteritems():
                 if key in self._fields or key == '_id':
                     setattr(self, key, value)
-                elif self._dynamic:
+                else:
                     dynamic_data[key] = value
         else:
             FileField = _import_class('FileField')
             for key, value in values.iteritems():
                 if key == '__auto_convert':
                     continue
                 key = self._reverse_db_field_map.get(key, key)
                 if key in self._fields or key in ('id', 'pk', '_cls'):
                     if __auto_convert and value is not None:
@@ -147,7 +143,7 @@ class BaseDocument(object):

         if not hasattr(self, name) and not name.startswith('_'):
             DynamicField = _import_class('DynamicField')
-            field = DynamicField(db_field=name)
+            field = DynamicField(db_field=name, null=True)
             field.name = name
             self._dynamic_fields[name] = field
             self._fields_ordered += (name,)
@@ -304,7 +300,7 @@ class BaseDocument(object):
             data['_cls'] = self._class_name

         # only root fields ['test1.a', 'test2'] => ['test1', 'test2']
-        root_fields = set([f.split('.')[0] for f in fields])
+        root_fields = {f.split('.')[0] for f in fields}

         for field_name in self:
             if root_fields and field_name not in root_fields:
@@ -337,7 +333,7 @@ class BaseDocument(object):
                 value = field.generate()
                 self._data[field_name] = value

-            if value is not None:
+            if (value is not None) or (field.null):
                 if use_db_field:
                     data[field.db_field] = value
                 else:
@@ -406,7 +402,15 @@ class BaseDocument(object):

     @classmethod
     def from_json(cls, json_data, created=False):
-        """Converts json data to an unsaved document instance"""
+        """Converts json data to a Document instance.
+
+        :param json_data: The json data to load into the Document
+        :param created: If True, the document will be considered as a brand new document.
+            If False and an id is provided, it will consider that the data being
+            loaded corresponds to what's already in the database (this has an impact
+            on subsequent calls to .save()).
+            If False and no id is provided, it will consider the data as a new document
+            (default ``False``).
+        """
         return cls._from_son(json_util.loads(json_data), created=created)
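The created flag only matters when the payload carries an _id: it decides whether a subsequent .save() inserts or updates. A sketch of the two behaviours (the model, database name and id are illustrative):

import json

from mongoengine import Document, StringField, connect

connect('example_db')


class User(Document):
    name = StringField()


payload = json.dumps({'_id': {'$oid': '5f2d8f0e9b1e8a3f6c0d4b21'}, 'name': 'Ada'})

existing = User.from_json(payload)                 # created=False: assumed persisted,
existing.save()                                    # so save() issues an update

brand_new = User.from_json(payload, created=True)  # created=True: treated as new,
brand_new.save()                                   # so save() performs an insert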
     def __expand_dynamic_values(self, name, value):
@@ -489,7 +493,7 @@ class BaseDocument(object):
             else:
                 data = getattr(data, part, None)

-            if hasattr(data, '_changed_fields'):
+            if not isinstance(data, LazyReference) and hasattr(data, '_changed_fields'):
                 if getattr(data, '_is_document', False):
                     continue

@@ -497,7 +501,13 @@ class BaseDocument(object):

         self._changed_fields = []

-    def _nestable_types_changed_fields(self, changed_fields, key, data, inspected):
+    def _nestable_types_changed_fields(self, changed_fields, base_key, data):
+        """Inspect nested data for changed fields
+
+        :param changed_fields: Previously collected changed fields
+        :param base_key: The base key that must be used to prepend changes to this data
+        :param data: data to inspect for changes
+        """
-        # Loop list / dict fields as they contain documents
+        # Determine the iterator to use
         if not hasattr(data, 'items'):
@@ -505,68 +515,60 @@ class BaseDocument(object):
         else:
             iterator = data.iteritems()

-        for index, value in iterator:
-            list_key = '%s%s.' % (key, index)
+        for index_or_key, value in iterator:
+            item_key = '%s%s.' % (base_key, index_or_key)
             # don't check anything lower if this key is already marked
             # as changed.
-            if list_key[:-1] in changed_fields:
+            if item_key[:-1] in changed_fields:
                 continue

             if hasattr(value, '_get_changed_fields'):
-                changed = value._get_changed_fields(inspected)
-                changed_fields += ['%s%s' % (list_key, k)
-                                   for k in changed if k]
+                changed = value._get_changed_fields()
+                changed_fields += ['%s%s' % (item_key, k) for k in changed if k]
             elif isinstance(value, (list, tuple, dict)):
                 self._nestable_types_changed_fields(
-                    changed_fields, list_key, value, inspected)
+                    changed_fields, item_key, value)

-    def _get_changed_fields(self, inspected=None):
+    def _get_changed_fields(self):
         """Return a list of all fields that have explicitly been changed.
         """
         EmbeddedDocument = _import_class('EmbeddedDocument')
-        DynamicEmbeddedDocument = _import_class('DynamicEmbeddedDocument')
         ReferenceField = _import_class('ReferenceField')
+        GenericReferenceField = _import_class('GenericReferenceField')
         SortedListField = _import_class('SortedListField')

         changed_fields = []
         changed_fields += getattr(self, '_changed_fields', [])

-        inspected = inspected or set()
-        if hasattr(self, 'id') and isinstance(self.id, Hashable):
-            if self.id in inspected:
-                return changed_fields
-            inspected.add(self.id)

         for field_name in self._fields_ordered:
             db_field_name = self._db_field_map.get(field_name, field_name)
             key = '%s.' % db_field_name
             data = self._data.get(field_name, None)
             field = self._fields.get(field_name)

-            if hasattr(data, 'id'):
-                if data.id in inspected:
-                    continue
-            if isinstance(field, ReferenceField):
+            if db_field_name in changed_fields:
+                # Whole field already marked as changed, no need to go further
                 continue
-            elif (
-                isinstance(data, (EmbeddedDocument, DynamicEmbeddedDocument)) and
-                db_field_name not in changed_fields
-            ):

+            if isinstance(field, ReferenceField):  # Don't follow referenced documents
+                continue
+
+            if isinstance(data, EmbeddedDocument):
                 # Find all embedded fields that have been changed
-                changed = data._get_changed_fields(inspected)
+                changed = data._get_changed_fields()
                 changed_fields += ['%s%s' % (key, k) for k in changed if k]
-            elif (isinstance(data, (list, tuple, dict)) and
-                    db_field_name not in changed_fields):
+            elif isinstance(data, (list, tuple, dict)):
                 if (hasattr(field, 'field') and
-                        isinstance(field.field, ReferenceField)):
+                        isinstance(field.field, (ReferenceField, GenericReferenceField))):
                     continue
                 elif isinstance(field, SortedListField) and field._ordering:
                     # if ordering is affected whole list is changed
-                    if any(map(lambda d: field._ordering in d._changed_fields, data)):
+                    if any(field._ordering in d._changed_fields for d in data):
                         changed_fields.append(db_field_name)
                         continue

                 self._nestable_types_changed_fields(
-                    changed_fields, key, data, inspected)
+                    changed_fields, key, data)
         return changed_fields

     def _delta(self):
@@ -578,7 +580,6 @@ class BaseDocument(object):

         set_fields = self._get_changed_fields()
         unset_data = {}
-        parts = []
         if hasattr(self, '_changed_fields'):
             set_data = {}
             # Fetch each set item from its path
@@ -588,15 +589,13 @@ class BaseDocument(object):
             new_path = []
             for p in parts:
                 if isinstance(d, (ObjectId, DBRef)):
+                    # Don't dig in the references
                     break
-                elif isinstance(d, list) and p.lstrip('-').isdigit():
-                    if p[0] == '-':
-                        p = str(len(d) + int(p))
-                    try:
-                        d = d[int(p)]
-                    except IndexError:
-                        d = None
+                elif isinstance(d, list) and p.isdigit():
+                    # An item of a list (identified by its index) is updated
+                    d = d[int(p)]
                 elif hasattr(d, 'get'):
+                    # dict-like (dict, embedded document)
                     d = d.get(p)
                 new_path.append(p)
             path = '.'.join(new_path)
@@ -608,26 +607,26 @@ class BaseDocument(object):

         # Determine if any changed items were actually unset.
         for path, value in set_data.items():
-            if value or isinstance(value, (numbers.Number, bool)):
+            if value or isinstance(value, (numbers.Number, bool)):  # Account for 0 and True that are truthy
                 continue

-            # If we've set a value that ain't the default value don't unset it.
-            default = None
+            parts = path.split('.')

             if (self._dynamic and len(parts) and parts[0] in
                     self._dynamic_fields):
                 del set_data[path]
                 unset_data[path] = 1
                 continue
-            elif path in self._fields:

+            # If we've set a value that ain't the default value don't unset it.
+            default = None
+            if path in self._fields:
                 default = self._fields[path].default
             else:  # Perform a full lookup for lists / embedded lookups
                 d = self
-                parts = path.split('.')
                 db_field_name = parts.pop()
                 for p in parts:
-                    if isinstance(d, list) and p.lstrip('-').isdigit():
-                        if p[0] == '-':
-                            p = str(len(d) + int(p))
+                    if isinstance(d, list) and p.isdigit():
                         d = d[int(p)]
                     elif (hasattr(d, '__getattribute__') and
                           not isinstance(d, dict)):
@@ -645,10 +644,9 @@ class BaseDocument(object):
                         default = None

             if default is not None:
-                if callable(default):
-                    default = default()
+                default = default() if callable(default) else default

-                if default != value:
+                if value != default:
                     continue

             del set_data[path]
@@ -694,7 +692,7 @@ class BaseDocument(object):

         fields = cls._fields
         if not _auto_dereference:
-            fields = copy.copy(fields)
+            fields = copy.deepcopy(fields)

         for field_name, field in fields.iteritems():
             field._auto_dereference = _auto_dereference
@@ -1080,5 +1078,11 @@ class BaseDocument(object):
         """Return the display value for a choice field"""
         value = getattr(self, field.name)
         if field.choices and isinstance(field.choices[0], (list, tuple)):
-            return dict(field.choices).get(value, value)
+            if value is None:
+                return None
+            sep = getattr(field, 'display_sep', ' ')
+            values = value if field.__class__.__name__ in ('ListField', 'SortedListField') else [value]
+            return sep.join([
+                six.text_type(dict(field.choices).get(val, val))
+                for val in values or []])
         return value
@@ -55,7 +55,7 @@ class BaseField(object):
         field. Generally this is deprecated in favour of the
         `FIELD.validate` method
     :param choices: (optional) The valid choices
-    :param null: (optional) Is the field value can be null. If no and there is a default value
+    :param null: (optional) If the field value can be null. If no and there is a default value
         then the default value is set
     :param sparse: (optional) `sparse=True` combined with `unique=True` and `required=False`
         means that uniqueness won't be enforced for `None` values
@@ -130,7 +130,6 @@ class BaseField(object):
     def __set__(self, instance, value):
         """Descriptor for assigning a value to a field in a document.
         """
-
         # If setting to None and there is a default
         # Then set the value to the default value
         if value is None:
@@ -213,8 +212,10 @@ class BaseField(object):
                     )
                 )
         # Choices which are types other than Documents
-        elif value not in choice_list:
-            self.error('Value must be one of %s' % six.text_type(choice_list))
+        else:
+            values = value if isinstance(value, (list, tuple)) else [value]
+            if len(set(values) - set(choice_list)):
+                self.error('Value must be one of %s' % six.text_type(choice_list))

     def _validate(self, value, **kwargs):
         # Check the Choices Constraint
@@ -265,13 +266,15 @@ class ComplexBaseField(BaseField):
         ReferenceField = _import_class('ReferenceField')
         GenericReferenceField = _import_class('GenericReferenceField')
         EmbeddedDocumentListField = _import_class('EmbeddedDocumentListField')
-        dereference = (self._auto_dereference and
+
+        auto_dereference = instance._fields[self.name]._auto_dereference
+
+        dereference = (auto_dereference and
                        (self.field is None or isinstance(self.field,
                                                          (GenericReferenceField, ReferenceField))))

         _dereference = _import_class('DeReference')()

-        self._auto_dereference = instance._fields[self.name]._auto_dereference
         if instance._initialised and dereference and instance._data.get(self.name):
             instance._data[self.name] = _dereference(
                 instance._data.get(self.name), max_depth=1, instance=instance,
@@ -292,7 +295,7 @@ class ComplexBaseField(BaseField):
                 value = BaseDict(value, instance, self.name)
                 instance._data[self.name] = value

-        if (self._auto_dereference and instance._initialised and
+        if (auto_dereference and instance._initialised and
                 isinstance(value, (BaseList, BaseDict)) and
                 not value._dereferenced):
             value = _dereference(
@@ -311,11 +314,16 @@ class ComplexBaseField(BaseField):
         if hasattr(value, 'to_python'):
             return value.to_python()

+        BaseDocument = _import_class('BaseDocument')
+        if isinstance(value, BaseDocument):
+            # Something is wrong, return the value as it is
+            return value
+
         is_list = False
         if not hasattr(value, 'items'):
             try:
                 is_list = True
-                value = {k: v for k, v in enumerate(value)}
+                value = {idx: v for idx, v in enumerate(value)}
             except TypeError:  # Not iterable return the value
                 return value

@@ -500,7 +508,7 @@ class GeoJsonBaseField(BaseField):
     def validate(self, value):
         """Validate the GeoJson object based on its type."""
         if isinstance(value, dict):
-            if set(value.keys()) == set(['type', 'coordinates']):
+            if set(value.keys()) == {'type', 'coordinates'}:
                 if value['type'] != self._type:
                     self.error('%s type must be "%s"' %
                                (self._name, self._type))
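The new else branch validates list and tuple values against the declared choices in a single set-difference check. A behaviour sketch (the model is illustrative; validate() needs no live database):

from mongoengine import Document, ListField, StringField
from mongoengine.errors import ValidationError


class Post(Document):
    tags = ListField(StringField(), choices=('python', 'mongodb'))


Post(tags=['python', 'mongodb']).validate()  # all values within choices: passes

try:
    Post(tags=['python', 'perl']).validate()
except ValidationError:
    pass  # 'perl' is rejected by the new set-difference check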
@@ -18,14 +18,14 @@ class DocumentMetaclass(type):
     """Metaclass for all documents."""

     # TODO lower complexity of this method
-    def __new__(cls, name, bases, attrs):
-        flattened_bases = cls._get_bases(bases)
-        super_new = super(DocumentMetaclass, cls).__new__
+    def __new__(mcs, name, bases, attrs):
+        flattened_bases = mcs._get_bases(bases)
+        super_new = super(DocumentMetaclass, mcs).__new__

         # If a base class just call super
         metaclass = attrs.get('my_metaclass')
         if metaclass and issubclass(metaclass, DocumentMetaclass):
-            return super_new(cls, name, bases, attrs)
+            return super_new(mcs, name, bases, attrs)

         attrs['_is_document'] = attrs.get('_is_document', False)
         attrs['_cached_reference_fields'] = []
@@ -121,7 +121,8 @@ class DocumentMetaclass(type):
                 # inheritance of classes where inheritance is set to False
                 allow_inheritance = base._meta.get('allow_inheritance')
                 if not allow_inheritance and not base._meta.get('abstract'):
-                    raise ValueError('Document %s may not be subclassed' %
+                    raise ValueError('Document %s may not be subclassed. '
+                                     'To enable inheritance, use the "allow_inheritance" meta attribute.' %
                                      base.__name__)

         # Get superclasses from last base superclass
@@ -138,7 +139,7 @@ class DocumentMetaclass(type):
             attrs['_types'] = attrs['_subclasses']  # TODO depreciate _types

         # Create the new_class
-        new_class = super_new(cls, name, bases, attrs)
+        new_class = super_new(mcs, name, bases, attrs)

         # Set _subclasses
         for base in document_bases:
@@ -147,7 +148,7 @@ class DocumentMetaclass(type):
             base._types = base._subclasses  # TODO depreciate _types

         (Document, EmbeddedDocument, DictField,
-         CachedReferenceField) = cls._import_classes()
+         CachedReferenceField) = mcs._import_classes()

         if issubclass(new_class, Document):
             new_class._collection = None
@@ -219,29 +220,26 @@ class DocumentMetaclass(type):

         return new_class

-    def add_to_class(self, name, value):
-        setattr(self, name, value)
-
     @classmethod
-    def _get_bases(cls, bases):
+    def _get_bases(mcs, bases):
         if isinstance(bases, BasesTuple):
             return bases
         seen = []
-        bases = cls.__get_bases(bases)
+        bases = mcs.__get_bases(bases)
         unique_bases = (b for b in bases if not (b in seen or seen.append(b)))
         return BasesTuple(unique_bases)

     @classmethod
-    def __get_bases(cls, bases):
+    def __get_bases(mcs, bases):
         for base in bases:
             if base is object:
                 continue
             yield base
-            for child_base in cls.__get_bases(base.__bases__):
+            for child_base in mcs.__get_bases(base.__bases__):
                 yield child_base

     @classmethod
-    def _import_classes(cls):
+    def _import_classes(mcs):
         Document = _import_class('Document')
         EmbeddedDocument = _import_class('EmbeddedDocument')
         DictField = _import_class('DictField')
@@ -254,9 +252,9 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
     collection in the database.
     """

-    def __new__(cls, name, bases, attrs):
-        flattened_bases = cls._get_bases(bases)
-        super_new = super(TopLevelDocumentMetaclass, cls).__new__
+    def __new__(mcs, name, bases, attrs):
+        flattened_bases = mcs._get_bases(bases)
+        super_new = super(TopLevelDocumentMetaclass, mcs).__new__

         # Set default _meta data if base class, otherwise get user defined meta
         if attrs.get('my_metaclass') == TopLevelDocumentMetaclass:
@@ -319,7 +317,7 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
                 not parent_doc_cls._meta.get('abstract', False)):
             msg = 'Abstract document cannot have non-abstract base'
             raise ValueError(msg)
-        return super_new(cls, name, bases, attrs)
+        return super_new(mcs, name, bases, attrs)

         # Merge base class metas.
         # Uses a special MetaDict that handles various merging rules
@@ -360,7 +358,7 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
         attrs['_meta'] = meta

         # Call super and get the new class
-        new_class = super_new(cls, name, bases, attrs)
+        new_class = super_new(mcs, name, bases, attrs)

         meta = new_class._meta

@@ -394,7 +392,7 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
                                        '_auto_id_field', False)
         if not new_class._meta.get('id_field'):
             # After 0.10, find not existing names, instead of overwriting
-            id_name, id_db_name = cls.get_auto_id_names(new_class)
+            id_name, id_db_name = mcs.get_auto_id_names(new_class)
             new_class._auto_id_field = True
             new_class._meta['id_field'] = id_name
             new_class._fields[id_name] = ObjectIdField(db_field=id_db_name)
@@ -419,7 +417,7 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
         return new_class

     @classmethod
-    def get_auto_id_names(cls, new_class):
+    def get_auto_id_names(mcs, new_class):
         id_name, id_db_name = ('id', '_id')
         if id_name not in new_class._fields and \
                 id_db_name not in (v.db_field for v in new_class._fields.values()):
mongoengine/base/utils.py (new file, 22 lines)
@@ -0,0 +1,22 @@
+import re
+
+
+class LazyRegexCompiler(object):
+    """Descriptor to allow lazy compilation of regex"""
+
+    def __init__(self, pattern, flags=0):
+        self._pattern = pattern
+        self._flags = flags
+        self._compiled_regex = None
+
+    @property
+    def compiled_regex(self):
+        if self._compiled_regex is None:
+            self._compiled_regex = re.compile(self._pattern, self._flags)
+        return self._compiled_regex
+
+    def __get__(self, instance, owner):
+        return self.compiled_regex
+
+    def __set__(self, instance, value):
+        raise AttributeError("Can not set attribute LazyRegexCompiler")
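The descriptor postpones re.compile until the first attribute access, so modules that declare many validation patterns stay cheap to import. A small usage sketch (the Validator class and pattern are illustrative):

class Validator(object):
    # Not compiled when the class is defined; compiled and cached on first use.
    email_regex = LazyRegexCompiler(r'^[^@\s]+@[^@\s]+$')


v = Validator()
assert v.email_regex.match('user@example.com')  # first access triggers re.compile
assert Validator.email_regex is v.email_regex   # the same compiled object is reused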
@@ -28,7 +28,7 @@ _connections = {}
 _dbs = {}


-def register_connection(alias, name=None, host=None, port=None,
+def register_connection(alias, db=None, name=None, host=None, port=None,
                         read_preference=READ_PREFERENCE,
                         username=None, password=None,
                         authentication_source=None,
@@ -39,6 +39,7 @@ def register_connection(alias, name=None, host=None, port=None,
     :param alias: the name that will be used to refer to this connection
         throughout MongoEngine
     :param name: the name of the specific database to use
+    :param db: the name of the database to use, for compatibility with connect
     :param host: the host name of the :program:`mongod` instance to connect to
     :param port: the port that the :program:`mongod` instance is running on
     :param read_preference: The read preference for the collection
@@ -58,7 +59,7 @@ def register_connection(alias, name=None, host=None, port=None,
     .. versionchanged:: 0.10.6 - added mongomock support
     """
     conn_settings = {
-        'name': name or 'test',
+        'name': name or db or 'test',
         'host': host or 'localhost',
         'port': port or 27017,
         'read_preference': read_preference,
@@ -103,6 +104,18 @@ def register_connection(alias, name=None, host=None, port=None,
                 conn_settings['authentication_source'] = uri_options['authsource']
             if 'authmechanism' in uri_options:
                 conn_settings['authentication_mechanism'] = uri_options['authmechanism']
+            if IS_PYMONGO_3 and 'readpreference' in uri_options:
+                read_preferences = (
+                    ReadPreference.NEAREST,
+                    ReadPreference.PRIMARY,
+                    ReadPreference.PRIMARY_PREFERRED,
+                    ReadPreference.SECONDARY,
+                    ReadPreference.SECONDARY_PREFERRED)
+                read_pf_mode = uri_options['readpreference'].lower()
+                for preference in read_preferences:
+                    if preference.name.lower() == read_pf_mode:
+                        conn_settings['read_preference'] = preference
+                        break
         else:
             resolved_hosts.append(entity)
     conn_settings['host'] = resolved_hosts
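With the new db keyword, register_connection accepts the same spelling as connect, so the two calls below configure the same database (alias and database names are illustrative):

from mongoengine import connect, register_connection

register_connection('reporting', db='reports_db', host='localhost', port=27017)
register_connection('reporting', name='reports_db', host='localhost', port=27017)

# connect() has always used `db`; the change aligns register_connection with it:
connect(db='reports_db', alias='reporting')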
@@ -146,13 +159,14 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
         raise MongoEngineConnectionError(msg)

     def _clean_settings(settings_dict):
-        irrelevant_fields = set([
-            'name', 'username', 'password', 'authentication_source',
-            'authentication_mechanism'
-        ])
+        # a set literal is more efficient than calling the set function
+        irrelevant_fields_set = {
+            'name', 'username', 'password',
+            'authentication_source', 'authentication_mechanism'
+        }
         return {
             k: v for k, v in settings_dict.items()
-            if k not in irrelevant_fields
+            if k not in irrelevant_fields_set
         }

     # Retrieve a copy of the connection settings associated with the requested
@@ -1,9 +1,11 @@
+from contextlib import contextmanager
+from pymongo.write_concern import WriteConcern
 from mongoengine.common import _import_class
 from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db


 __all__ = ('switch_db', 'switch_collection', 'no_dereference',
-           'no_sub_classes', 'query_counter')
+           'no_sub_classes', 'query_counter', 'set_write_concern')


 class switch_db(object):
@@ -143,66 +145,85 @@ class no_sub_classes(object):
         :param cls: the class to turn querying sub classes on
         """
         self.cls = cls
+        self.cls_initial_subclasses = None

     def __enter__(self):
-        """Change the objects default and _auto_dereference values."""
-        self.cls._all_subclasses = self.cls._subclasses
-        self.cls._subclasses = (self.cls,)
+        self.cls_initial_subclasses = self.cls._subclasses
+        self.cls._subclasses = (self.cls._class_name,)
         return self.cls

     def __exit__(self, t, value, traceback):
-        """Reset the default and _auto_dereference values."""
-        self.cls._subclasses = self.cls._all_subclasses
-        delattr(self.cls, '_all_subclasses')
-        return self.cls
+        self.cls._subclasses = self.cls_initial_subclasses
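The context manager now stores the original _subclasses tuple and restores it on exit instead of stashing it in an ad hoc _all_subclasses attribute. Typical use (the models and database name are illustrative):

from mongoengine import Document, StringField, connect
from mongoengine.context_managers import no_sub_classes

connect('example_db')


class Animal(Document):
    name = StringField()
    meta = {'allow_inheritance': True}


class Dog(Animal):
    pass


Animal(name='generic').save()
Dog(name='rex').save()

with no_sub_classes(Animal) as animal_cls:
    assert animal_cls.objects.count() == 1  # Dog documents are excluded here
assert Animal.objects.count() == 2          # subclass querying restored on exit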
 class query_counter(object):
-    """Query_counter context manager to get the number of queries."""
+    """Query_counter context manager to get the number of queries.
+
+    This works by updating the `profiling_level` of the database so that all queries get logged,
+    resetting the db.system.profile collection at the beginning of the context and counting the new entries.
+
+    This was designed for debugging purposes. In fact it is a global counter, so queries issued by other
+    threads/processes can interfere with it.
+
+    Be aware that:
+    - Iterating over a large amount of documents (>101) makes pymongo issue `getmore` queries to fetch the
+      next batch of documents (https://docs.mongodb.com/manual/tutorial/iterate-a-cursor/#cursor-batches)
+    - Some queries are ignored by default by the counter (killcursors, db.system.indexes)
+    """

     def __init__(self):
-        """Construct the query_counter."""
-        self.counter = 0
+        """Construct the query_counter
+        """
         self.db = get_db()
+        self.initial_profiling_level = None
+        self._ctx_query_counter = 0  # number of queries issued by the context

-    def __enter__(self):
-        """On every with block we need to drop the profile collection."""
+        self._ignored_query = {
+            'ns':
+                {'$ne': '%s.system.indexes' % self.db.name},
+            'op':  # MONGODB < 3.2
+                {'$ne': 'killcursors'},
+            'command.killCursors':  # MONGODB >= 3.2
+                {'$exists': False}
+        }

+    def _turn_on_profiling(self):
+        self.initial_profiling_level = self.db.profiling_level()
         self.db.set_profiling_level(0)
         self.db.system.profile.drop()
         self.db.set_profiling_level(2)

+    def _resets_profiling(self):
+        self.db.set_profiling_level(self.initial_profiling_level)
+
+    def __enter__(self):
+        self._turn_on_profiling()
         return self

     def __exit__(self, t, value, traceback):
-        """Reset the profiling level."""
-        self.db.set_profiling_level(0)
+        self._resets_profiling()

     def __eq__(self, value):
         """== Compare querycounter."""
         counter = self._get_count()
         return value == counter

     def __ne__(self, value):
         """!= Compare querycounter."""
         return not self.__eq__(value)

     def __lt__(self, value):
         """< Compare querycounter."""
         return self._get_count() < value

     def __le__(self, value):
         """<= Compare querycounter."""
         return self._get_count() <= value

     def __gt__(self, value):
         """> Compare querycounter."""
         return self._get_count() > value

     def __ge__(self, value):
         """>= Compare querycounter."""
         return self._get_count() >= value

     def __int__(self):
         """int representation."""
         return self._get_count()

     def __repr__(self):
@@ -210,8 +231,17 @@ class query_counter(object):
         return u"%s" % self._get_count()

     def _get_count(self):
-        """Get the number of queries."""
-        ignore_query = {'ns': {'$ne': '%s.system.indexes' % self.db.name}}
-        count = self.db.system.profile.find(ignore_query).count() - self.counter
-        self.counter += 1
+        """Get the number of queries by counting the current number of entries in db.system.profile
+        and subtracting the queries issued by this context. In fact, every time this is called,
+        1 query is issued, so we need to balance that.
+        """
+        count = self.db.system.profile.find(self._ignored_query).count() - self._ctx_query_counter
+        self._ctx_query_counter += 1  # Account for the query we just issued to gather the information
         return count
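Thanks to the comparison operators above, the counter is used by comparing it directly against integers inside the with block. A sketch (the model and database name are illustrative):

from mongoengine import Document, StringField, connect
from mongoengine.context_managers import query_counter

connect('example_db')


class Event(Document):
    name = StringField()


with query_counter() as q:
    assert q == 0
    Event(name='a').save()   # one insert logged by the profiler
    Event.objects.first()    # one find logged by the profiler
    assert q == 2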
+@contextmanager
+def set_write_concern(collection, write_concerns):
+    combined_concerns = dict(collection.write_concern.document.items())
+    combined_concerns.update(write_concerns)
+    yield collection.with_options(write_concern=WriteConcern(**combined_concerns))
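set_write_concern merges the collection's current write concern with the per-call overrides and yields a clone carrying the combined options; PyMongo collections are immutable with respect to options, hence with_options. A minimal usage sketch (database and collection names are illustrative):

from pymongo import MongoClient

collection = MongoClient().example_db.events
with set_write_concern(collection, {'w': 'majority'}) as wc_collection:
    # wc_collection is `collection` with WriteConcern(w='majority') applied
    wc_collection.insert_one({'name': 'durable event'})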
@@ -3,6 +3,7 @@ import six

 from mongoengine.base import (BaseDict, BaseList, EmbeddedDocumentList,
                               TopLevelDocumentMetaclass, get_document)
+from mongoengine.base.datastructures import LazyReference
 from mongoengine.connection import get_db
 from mongoengine.document import Document, EmbeddedDocument
 from mongoengine.fields import DictField, ListField, MapField, ReferenceField
@@ -51,26 +52,40 @@ class DeReference(object):
                 [i.__class__ == doc_type for i in items.values()]):
             return items
         elif not field.dbref:
+            # We must turn the ObjectIds into DBRefs
+
+            # Recursively dig into the sub items of a list/dict
+            # to turn the ObjectIds into DBRefs
+            def _get_items_from_list(items):
+                new_items = []
+                for v in items:
+                    value = v
+                    if isinstance(v, dict):
+                        value = _get_items_from_dict(v)
+                    elif isinstance(v, list):
+                        value = _get_items_from_list(v)
+                    elif not isinstance(v, (DBRef, Document)):
+                        value = field.to_python(v)
+                    new_items.append(value)
+                return new_items
+
+            def _get_items_from_dict(items):
+                new_items = {}
+                for k, v in items.iteritems():
+                    value = v
+                    if isinstance(v, list):
+                        value = _get_items_from_list(v)
+                    elif isinstance(v, dict):
+                        value = _get_items_from_dict(v)
+                    elif not isinstance(v, (DBRef, Document)):
+                        value = field.to_python(v)
+                    new_items[k] = value
+                return new_items
+
             if not hasattr(items, 'items'):
-
-                def _get_items(items):
-                    new_items = []
-                    for v in items:
-                        if isinstance(v, list):
-                            new_items.append(_get_items(v))
-                        elif not isinstance(v, (DBRef, Document)):
-                            new_items.append(field.to_python(v))
-                        else:
-                            new_items.append(v)
-                    return new_items
-
-                items = _get_items(items)
+                items = _get_items_from_list(items)
             else:
-                items = {
-                    k: (v if isinstance(v, (DBRef, Document))
-                        else field.to_python(v))
-                    for k, v in items.iteritems()
-                }
+                items = _get_items_from_dict(items)

         self.reference_map = self._find_references(items)
         self.object_map = self._fetch_objects(doc_type=doc_type)
@@ -99,7 +114,10 @@ class DeReference(object):
             if isinstance(item, (Document, EmbeddedDocument)):
                 for field_name, field in item._fields.iteritems():
                     v = item._data.get(field_name, None)
-                    if isinstance(v, DBRef):
+                    if isinstance(v, LazyReference):
+                        # LazyReference inherits DBRef but should not be dereferenced here!
+                        continue
+                    elif isinstance(v, DBRef):
                         reference_map.setdefault(field.document_type, set()).add(v.id)
                     elif isinstance(v, (dict, SON)) and '_ref' in v:
                         reference_map.setdefault(get_document(v['_cls']), set()).add(v['_ref'].id)
@@ -110,6 +128,9 @@ class DeReference(object):
                     if isinstance(field_cls, (Document, TopLevelDocumentMetaclass)):
                         key = field_cls
                     reference_map.setdefault(key, set()).update(refs)
+            elif isinstance(item, LazyReference):
+                # LazyReference inherits DBRef but should not be dereferenced here!
+                continue
             elif isinstance(item, DBRef):
                 reference_map.setdefault(item.collection, set()).add(item.id)
             elif isinstance(item, (dict, SON)) and '_ref' in item:
@@ -126,7 +147,12 @@ class DeReference(object):
         """
         object_map = {}
         for collection, dbrefs in self.reference_map.iteritems():
-            if hasattr(collection, 'objects'):  # We have a document class for the refs
+
+            # we use getattr instead of hasattr because hasattr swallows any exception under python2,
+            # so it could hide nasty things without raising exceptions (cf. bug #1688)
+            ref_document_cls_exists = (getattr(collection, 'objects', None) is not None)
+
+            if ref_document_cls_exists:
                 col_name = collection._get_collection_name()
                 refs = [dbref for dbref in dbrefs
                         if (col_name, dbref) not in object_map]
@@ -134,7 +160,7 @@ class DeReference(object):
                 for key, doc in references.iteritems():
                     object_map[(col_name, key)] = doc
             else:  # Generic reference: use the refs data to convert to document
-                if isinstance(doc_type, (ListField, DictField, MapField,)):
+                if isinstance(doc_type, (ListField, DictField, MapField)):
                     continue

                 refs = [dbref for dbref in dbrefs
@@ -230,7 +256,7 @@ class DeReference(object):
                 elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth:
                     item_name = '%s.%s' % (name, k) if name else name
                     data[k] = self._attach_objects(v, depth - 1, instance=instance, name=item_name)
-                elif hasattr(v, 'id'):
+                elif isinstance(v, DBRef) and hasattr(v, 'id'):
                     data[k] = self.object_map.get((v.collection, v.id), v)

         if instance and name:
@@ -12,7 +12,9 @@ from mongoengine.base import (BaseDict, BaseDocument, BaseList,
|
||||
TopLevelDocumentMetaclass, get_document)
|
||||
from mongoengine.common import _import_class
|
||||
from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db
|
||||
from mongoengine.context_managers import switch_collection, switch_db
|
||||
from mongoengine.context_managers import (set_write_concern,
|
||||
switch_collection,
|
||||
switch_db)
|
||||
from mongoengine.errors import (InvalidDocumentError, InvalidQueryError,
|
||||
SaveConditionError)
|
||||
from mongoengine.python_support import IS_PYMONGO_3
|
||||
@@ -39,7 +41,7 @@ class InvalidCollectionError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class EmbeddedDocument(BaseDocument):
|
||||
class EmbeddedDocument(six.with_metaclass(DocumentMetaclass, BaseDocument)):
|
||||
"""A :class:`~mongoengine.Document` that isn't stored in its own
|
||||
collection. :class:`~mongoengine.EmbeddedDocument`\ s should be used as
|
||||
fields on :class:`~mongoengine.Document`\ s through the
|
||||
@@ -58,7 +60,6 @@ class EmbeddedDocument(BaseDocument):
|
||||
# The __metaclass__ attribute is removed by 2to3 when running with Python3
|
||||
# my_metaclass is defined so that metaclass can be queried in Python 2 & 3
|
||||
my_metaclass = DocumentMetaclass
|
||||
__metaclass__ = DocumentMetaclass
|
||||
|
||||
# A generic embedded document doesn't have any immutable properties
|
||||
# that describe it uniquely, hence it shouldn't be hashable. You can
|
||||
@@ -95,7 +96,7 @@ class EmbeddedDocument(BaseDocument):
|
||||
self._instance.reload(*args, **kwargs)
|
||||
|
||||
|
||||
class Document(BaseDocument):
|
||||
class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
|
||||
"""The base class used for defining the structure and properties of
|
||||
collections of documents stored in MongoDB. Inherit from this class, and
|
||||
add fields as class attributes to define a document's structure.
|
||||
@@ -150,7 +151,6 @@ class Document(BaseDocument):
|
||||
# The __metaclass__ attribute is removed by 2to3 when running with Python3
|
||||
# my_metaclass is defined so that metaclass can be queried in Python 2 & 3
|
||||
my_metaclass = TopLevelDocumentMetaclass
|
||||
__metaclass__ = TopLevelDocumentMetaclass
|
||||
|
||||
__slots__ = ('__objects',)
|
||||
|
||||
@@ -172,8 +172,8 @@ class Document(BaseDocument):
|
||||
"""
|
||||
if self.pk is None:
|
||||
return super(BaseDocument, self).__hash__()
|
||||
else:
|
||||
return hash(self.pk)
|
||||
|
||||
return hash(self.pk)
|
||||
|
||||
@classmethod
|
||||
def _get_db(cls):
|
||||
@@ -195,7 +195,10 @@ class Document(BaseDocument):
|
||||
|
||||
# Ensure indexes on the collection unless auto_create_index was
|
||||
# set to False.
|
||||
if cls._meta.get('auto_create_index', True):
|
||||
# Also there is no need to ensure indexes on slave.
|
||||
db = cls._get_db()
|
||||
if cls._meta.get('auto_create_index', True) and\
|
||||
db.client.is_primary:
|
||||
cls.ensure_indexes()
|
||||
|
||||
return cls._collection
|
||||
@@ -280,6 +283,9 @@ class Document(BaseDocument):
|
||||
elif query[id_field] != self.pk:
|
||||
raise InvalidQueryError('Invalid document modify query: it must modify only this document.')
|
||||
|
||||
# Need to add shard key to query, or you get an error
|
||||
query.update(self._object_key)
|
||||
|
||||
updated = self._qs(**query).modify(new=True, **update)
|
||||
if updated is None:
|
||||
return False
|
||||
@@ -364,6 +370,8 @@ class Document(BaseDocument):
|
||||
|
||||
signals.pre_save_post_validation.send(self.__class__, document=self,
|
||||
created=created, **signal_kwargs)
|
||||
# it might be refreshed by the pre_save_post_validation hook, e.g., for etag generation
|
||||
doc = self.to_mongo()
|
||||
|
||||
if self._meta.get('auto_create_index', True):
|
||||
self.ensure_indexes()
|
||||
@@ -423,11 +431,18 @@ class Document(BaseDocument):
Helper method, should only be used inside save().
"""
collection = self._get_collection()
with set_write_concern(collection, write_concern) as wc_collection:
if force_insert:
return wc_collection.insert_one(doc).inserted_id
# insert_one will raise a UniqueError where save() does not,
# so we need to catch it and call replace_one instead.
if '_id' in doc:
raw_object = wc_collection.find_one_and_replace(
{'_id': doc['_id']}, doc)
if raw_object:
return doc['_id']

if force_insert:
return collection.insert(doc, **write_concern)

object_id = collection.save(doc, **write_concern)
object_id = wc_collection.insert_one(doc).inserted_id

# In PyMongo 3.0, save() internally calls _update(),
# but it forgets to return the _id value passed back, so we fetch it back here
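
A condensed sketch of the PyMongo 3 write path this hunk introduces (here doc is the SON produced by to_mongo(); error handling and the returned object_id plumbing are omitted):

    with set_write_concern(collection, write_concern) as wc_collection:
        if force_insert:
            object_id = wc_collection.insert_one(doc).inserted_id
        elif '_id' in doc:
            # save() semantics: replace the existing document in place
            wc_collection.find_one_and_replace({'_id': doc['_id']}, doc)
            object_id = doc['_id']
        else:
            object_id = wc_collection.insert_one(doc).inserted_id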
@@ -576,12 +591,11 @@ class Document(BaseDocument):
"""Delete the :class:`~mongoengine.Document` from the database. This
will only take effect if the document has been previously saved.

:parm signal_kwargs: (optional) kwargs dictionary to be passed to
:param signal_kwargs: (optional) kwargs dictionary to be passed to
the signal calls.
:param write_concern: Extra keyword arguments are passed down which
will be used as options for the resultant
``getLastError`` command. For example,
``save(..., write_concern={w: 2, fsync: True}, ...)`` will
will be used as options for the resultant ``getLastError`` command.
For example, ``save(..., w: 2, fsync: True)`` will
wait until at least two servers have recorded the write and
will force an fsync on the primary server.

@@ -702,7 +716,6 @@ class Document(BaseDocument):
obj = obj[0]
else:
raise self.DoesNotExist('Document does not exist')

for field in obj._data:
if not fields or field in fields:
try:
@@ -710,7 +723,7 @@ class Document(BaseDocument):
except (KeyError, AttributeError):
try:
# If field is a special field, e.g. items is stored as _reserved_items,
# an KeyError is thrown. So try to retrieve the field from _data
# a KeyError is thrown. So try to retrieve the field from _data
setattr(self, field, self._reload(field, obj._data.get(field)))
except KeyError:
# If field is removed from the database while the object
@@ -718,7 +731,9 @@ class Document(BaseDocument):
# i.e. obj.update(unset__field=1) followed by obj.reload()
delattr(self, field)

self._changed_fields = obj._changed_fields
self._changed_fields = list(
set(self._changed_fields) - set(fields)
) if fields else obj._changed_fields
self._created = False
return self
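
Hedged usage sketch of the reload() behaviour changed above: with an explicit field list, only those fields are refreshed and cleared from _changed_fields (the Page class and field names are illustrative):

    page = Page.objects.first()
    page.title = 'draft'     # marks 'title' as changed
    page.reload('views')     # refreshes 'views' only
    # 'title' remains in page._changed_fields and will still be persisted on save()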
@@ -964,8 +979,16 @@ class Document(BaseDocument):
"""

required = cls.list_indexes()
existing = [info['key']
for info in cls._get_collection().index_information().values()]

existing = []
for info in cls._get_collection().index_information().values():
if '_fts' in info['key'][0]:
index_type = info['key'][0][1]
text_index_fields = info.get('weights').keys()
existing.append(
[(key, index_type) for key in text_index_fields])
else:
existing.append(info['key'])
missing = [index for index in required if index not in existing]
extra = [index for index in existing if index not in required]
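
An illustrative sketch of what the text-index branch above normalises (class and field names are assumptions for the example): MongoDB reports a text index key as ('_fts', 'text') and keeps the actual indexed fields under 'weights', so a definition like

    class Article(Document):
        title = StringField()
        body = StringField()
        meta = {'indexes': [{'fields': ['$title', '$body']}]}

    # index_information() reports key [('_fts', 'text'), ('_ftsx', 1)] with
    # weights {'title': 1, 'body': 1}; compare_indexes() rebuilds it as
    # [('title', 'text'), ('body', 'text')] before diffing against list_indexes().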
@@ -982,10 +1005,10 @@ class Document(BaseDocument):
return {'missing': missing, 'extra': extra}


class DynamicDocument(Document):
class DynamicDocument(six.with_metaclass(TopLevelDocumentMetaclass, Document)):
"""A Dynamic Document class allowing flexible, expandable and uncontrolled
schemas. As a :class:`~mongoengine.Document` subclass, acts in the same
way as an ordinary document but has expando style properties. Any data
way as an ordinary document but has expanded style properties. Any data
passed or set against the :class:`~mongoengine.DynamicDocument` that is
not a field is automatically converted into a
:class:`~mongoengine.fields.DynamicField` and data can be attributed to that
@@ -999,7 +1022,6 @@ class DynamicDocument(Document):
# The __metaclass__ attribute is removed by 2to3 when running with Python3
# my_metaclass is defined so that metaclass can be queried in Python 2 & 3
my_metaclass = TopLevelDocumentMetaclass
__metaclass__ = TopLevelDocumentMetaclass

_dynamic = True

@@ -1010,11 +1032,12 @@ class DynamicDocument(Document):
field_name = args[0]
if field_name in self._dynamic_fields:
setattr(self, field_name, None)
self._dynamic_fields[field_name].null = False
else:
super(DynamicDocument, self).__delattr__(*args, **kwargs)


class DynamicEmbeddedDocument(EmbeddedDocument):
class DynamicEmbeddedDocument(six.with_metaclass(DocumentMetaclass, EmbeddedDocument)):
"""A Dynamic Embedded Document class allowing flexible, expandable and
uncontrolled schemas. See :class:`~mongoengine.DynamicDocument` for more
information about dynamic documents.
@@ -1023,7 +1046,6 @@ class DynamicEmbeddedDocument(EmbeddedDocument):
# The __metaclass__ attribute is removed by 2to3 when running with Python3
# my_metaclass is defined so that metaclass can be queried in Python 2 & 3
my_metaclass = DocumentMetaclass
__metaclass__ = DocumentMetaclass

_dynamic = True


@@ -71,6 +71,7 @@ class ValidationError(AssertionError):
_message = None

def __init__(self, message='', **kwargs):
super(ValidationError, self).__init__(message)
self.errors = kwargs.get('errors', {})
self.field_name = kwargs.get('field_name')
self.message = message

@@ -5,7 +5,6 @@ import re
import socket
import time
import uuid
import warnings
from operator import itemgetter

from bson import Binary, DBRef, ObjectId, SON
@@ -25,13 +24,18 @@ try:
except ImportError:
Int64 = long


from mongoengine.base import (BaseDocument, BaseField, ComplexBaseField,
GeoJsonBaseField, ObjectIdField, get_document)
GeoJsonBaseField, LazyReference, ObjectIdField,
get_document)
from mongoengine.base.utils import LazyRegexCompiler
from mongoengine.common import _import_class
from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db
from mongoengine.document import Document, EmbeddedDocument
from mongoengine.errors import DoesNotExist, InvalidQueryError, ValidationError
from mongoengine.python_support import StringIO
from mongoengine.queryset import DO_NOTHING, QuerySet
from mongoengine.queryset import DO_NOTHING
from mongoengine.queryset.base import BaseQuerySet

try:
from PIL import Image, ImageOps
@@ -39,13 +43,20 @@ except ImportError:
Image = None
ImageOps = None

if six.PY3:
# Useless as long as 2to3 gets executed
# as it turns `long` into `int` blindly
long = int


__all__ = (
'StringField', 'URLField', 'EmailField', 'IntField', 'LongField',
'FloatField', 'DecimalField', 'BooleanField', 'DateTimeField',
'FloatField', 'DecimalField', 'BooleanField', 'DateTimeField', 'DateField',
'ComplexDateTimeField', 'EmbeddedDocumentField', 'ObjectIdField',
'GenericEmbeddedDocumentField', 'DynamicField', 'ListField',
'SortedListField', 'EmbeddedDocumentListField', 'DictField',
'MapField', 'ReferenceField', 'CachedReferenceField',
'LazyReferenceField', 'GenericLazyReferenceField',
'GenericReferenceField', 'BinaryField', 'GridFSError', 'GridFSProxy',
'FileField', 'ImageGridFsProxy', 'ImproperlyConfigured', 'ImageField',
'GeoPointField', 'PointField', 'LineStringField', 'PolygonField',
@@ -120,9 +131,9 @@ class URLField(StringField):
.. versionadded:: 0.3
"""

_URL_REGEX = re.compile(
_URL_REGEX = LazyRegexCompiler(
r'^(?:[a-z0-9\.\-]*)://'  # scheme is validated separately
r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}(?<!-)\.?)|'  # domain...
r'(?:(?:[A-Z0-9](?:[A-Z0-9-_]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}(?<!-)\.?)|'  # domain...
r'localhost|'  # localhost...
r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}|'  # ...or ipv4
r'\[?[A-F0-9]*:[A-F0-9:]+\]?)'  # ...or ipv6
@@ -130,8 +141,7 @@ class URLField(StringField):
r'(?:/?|[/?]\S+)$', re.IGNORECASE)
_URL_SCHEMES = ['http', 'https', 'ftp', 'ftps']

def __init__(self, verify_exists=False, url_regex=None, schemes=None, **kwargs):
self.verify_exists = verify_exists
def __init__(self, url_regex=None, schemes=None, **kwargs):
self.url_regex = url_regex or self._URL_REGEX
self.schemes = schemes or self._URL_SCHEMES
super(URLField, self).__init__(**kwargs)
@@ -154,7 +164,7 @@ class EmailField(StringField):

.. versionadded:: 0.4
"""
USER_REGEX = re.compile(
USER_REGEX = LazyRegexCompiler(
# `dot-atom` defined in RFC 5322 Section 3.2.3.
r"(^[-!#$%&'*+/=?^_`{}|~0-9A-Z]+(\.[-!#$%&'*+/=?^_`{}|~0-9A-Z]+)*\Z"
# `quoted-string` defined in RFC 5322 Section 3.2.4.
@@ -162,7 +172,7 @@ class EmailField(StringField):
re.IGNORECASE
)

UTF8_USER_REGEX = re.compile(
UTF8_USER_REGEX = LazyRegexCompiler(
six.u(
# RFC 6531 Section 3.3 extends `atext` (used by dot-atom) to
# include `UTF8-non-ascii`.
@@ -172,7 +182,7 @@ class EmailField(StringField):
), re.IGNORECASE | re.UNICODE
)

DOMAIN_REGEX = re.compile(
DOMAIN_REGEX = LazyRegexCompiler(
r'((?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+)(?:[A-Z0-9-]{2,63}(?<!-))\Z',
re.IGNORECASE
)
@@ -264,14 +274,14 @@ class IntField(BaseField):
def to_python(self, value):
try:
value = int(value)
except ValueError:
except (TypeError, ValueError):
pass
return value

def validate(self, value):
try:
value = int(value)
except Exception:
except (TypeError, ValueError):
self.error('%s could not be converted to int' % value)

if self.min_value is not None and value < self.min_value:
@@ -297,7 +307,7 @@ class LongField(BaseField):
def to_python(self, value):
try:
value = long(value)
except ValueError:
except (TypeError, ValueError):
pass
return value

@@ -307,7 +317,7 @@ class LongField(BaseField):
def validate(self, value):
try:
value = long(value)
except Exception:
except (TypeError, ValueError):
self.error('%s could not be converted to long' % value)

if self.min_value is not None and value < self.min_value:
@@ -361,7 +371,8 @@ class FloatField(BaseField):


class DecimalField(BaseField):
"""Fixed-point decimal number field.
"""Fixed-point decimal number field. Stores the value as a float by default unless `force_string` is used.
If using floats, beware of Decimal to float conversion (potential precision loss)

.. versionchanged:: 0.8
.. versionadded:: 0.3
@@ -372,7 +383,9 @@ class DecimalField(BaseField):
"""
:param min_value: Validation rule for the minimum acceptable value.
:param max_value: Validation rule for the maximum acceptable value.
:param force_string: Store as a string.
:param force_string: Store the value as a string (instead of a float).
Be aware that this affects query sorting and operations like lte, gte (as string comparison is applied)
and some query operators won't work (e.g. inc, dec)
:param precision: Number of decimal places to store.
:param rounding: The rounding rule from the python decimal library:

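A short, hedged usage sketch of the force_string trade-off described above (the Product class and field name are assumptions for the example):

    from decimal import Decimal

    class Product(Document):
        price = DecimalField(precision=2, force_string=True)

    Product(price=Decimal('10.50')).save()   # stored as the string '10.50'
    # Range queries now compare lexicographically, so e.g.
    # Product.objects(price__gte=Decimal('9.00')) can miss '10.50' ('1' < '9'),
    # and atomic operators like inc/dec are unavailable on the field.
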
@@ -403,7 +416,7 @@ class DecimalField(BaseField):
# Convert to string for python 2.6 before casting to Decimal
try:
value = decimal.Decimal('%s' % value)
except decimal.InvalidOperation:
except (TypeError, ValueError, decimal.InvalidOperation):
return value
return value.quantize(decimal.Decimal('.%s' % ('0' * self.precision)), rounding=self.rounding)

@@ -420,7 +433,7 @@ class DecimalField(BaseField):
value = six.text_type(value)
try:
value = decimal.Decimal(value)
except Exception as exc:
except (TypeError, ValueError, decimal.InvalidOperation) as exc:
self.error('Could not convert value to decimal: %s' % exc)

if self.min_value is not None and value < self.min_value:
@@ -459,6 +472,8 @@ class DateTimeField(BaseField):
installed you can utilise it to convert varying types of date formats into valid
python datetime objects.

Note: To default the field to the current datetime, use: DateTimeField(default=datetime.utcnow)

Note: Microseconds are rounded to the nearest millisecond.
Pre UTC microsecond support is effectively broken.
Use :class:`~mongoengine.fields.ComplexDateTimeField` if you
@@ -522,6 +537,22 @@ class DateTimeField(BaseField):
return super(DateTimeField, self).prepare_query_value(op, self.to_mongo(value))


class DateField(DateTimeField):
def to_mongo(self, value):
value = super(DateField, self).to_mongo(value)
# drop hours, minutes, seconds
if isinstance(value, datetime.datetime):
value = datetime.datetime(value.year, value.month, value.day)
return value

def to_python(self, value):
value = super(DateField, self).to_python(value)
# convert datetime to date
if isinstance(value, datetime.datetime):
value = datetime.date(value.year, value.month, value.day)
return value

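Hedged usage sketch of the new DateField: values are truncated to midnight in MongoDB and surface as datetime.date in Python (Event is an assumed example class):

    import datetime

    class Event(Document):
        day = DateField()

    e = Event(day=datetime.datetime(2018, 3, 1, 14, 30))
    e.save()
    e.reload()
    assert e.day == datetime.date(2018, 3, 1)  # time-of-day component dropped
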

class ComplexDateTimeField(StringField):
"""
ComplexDateTimeField handles microseconds exactly instead of rounding
@@ -538,11 +569,15 @@ class ComplexDateTimeField(StringField):
The `,` as the separator can be easily modified by passing the `separator`
keyword when initializing the field.

Note: To default the field to the current datetime, use: DateTimeField(default=datetime.utcnow)

.. versionadded:: 0.5
"""

def __init__(self, separator=',', **kwargs):
self.names = ['year', 'month', 'day', 'hour', 'minute', 'second', 'microsecond']
"""
:param separator: Allows customizing the separator used for storage (default ``,``)
"""
self.separator = separator
self.format = separator.join(['%Y', '%m', '%d', '%H', '%M', '%S', '%f'])
super(ComplexDateTimeField, self).__init__(**kwargs)
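
For illustration, a custom separator simply changes the stored string format (LogEntry is an assumed example class):

    class LogEntry(Document):
        timestamp = ComplexDateTimeField(separator='.')

    # 2011-06-08 20:26:24.092284 is stored as the string
    # '2011.06.08.20.26.24.092284' instead of '2011,06,08,20,26,24,092284'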
@@ -569,20 +604,24 @@ class ComplexDateTimeField(StringField):
>>> ComplexDateTimeField()._convert_from_string(a)
datetime.datetime(2011, 6, 8, 20, 26, 24, 92284)
"""
values = map(int, data.split(self.separator))
values = [int(d) for d in data.split(self.separator)]
return datetime.datetime(*values)

def __get__(self, instance, owner):
if instance is None:
return self

data = super(ComplexDateTimeField, self).__get__(instance, owner)
if data is None:
return None if self.null else datetime.datetime.now()
if isinstance(data, datetime.datetime):

if isinstance(data, datetime.datetime) or data is None:
return data
return self._convert_from_string(data)

def __set__(self, instance, value):
value = self._convert_from_datetime(value) if value else value
return super(ComplexDateTimeField, self).__set__(instance, value)
super(ComplexDateTimeField, self).__set__(instance, value)
value = instance._data[self.name]
if value is not None:
instance._data[self.name] = self._convert_from_datetime(value)

def validate(self, value):
value = self.to_python(value)
@@ -611,6 +650,7 @@ class EmbeddedDocumentField(BaseField):
"""

def __init__(self, document_type, **kwargs):
# XXX ValidationError raised outside of the "validate" method.
if not (
isinstance(document_type, six.string_types) or
issubclass(document_type, EmbeddedDocument)
@@ -625,9 +665,17 @@ class EmbeddedDocumentField(BaseField):
def document_type(self):
if isinstance(self.document_type_obj, six.string_types):
if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT:
self.document_type_obj = self.owner_document
resolved_document_type = self.owner_document
else:
self.document_type_obj = get_document(self.document_type_obj)
resolved_document_type = get_document(self.document_type_obj)

if not issubclass(resolved_document_type, EmbeddedDocument):
# Due to the late resolution of the document_type
# there is a chance that it won't be an EmbeddedDocument (#1661)
self.error('Invalid embedded document class provided to an '
'EmbeddedDocumentField')
self.document_type_obj = resolved_document_type

return self.document_type_obj

def to_python(self, value):
@@ -686,16 +734,28 @@ class GenericEmbeddedDocumentField(BaseField):
return value

def validate(self, value, clean=True):
if self.choices and isinstance(value, SON):
for choice in self.choices:
if value['_cls'] == choice._class_name:
return True

if not isinstance(value, EmbeddedDocument):
self.error('Invalid embedded document instance provided to a '
'GenericEmbeddedDocumentField')

value.validate(clean=clean)

def lookup_member(self, member_name):
if self.choices:
for choice in self.choices:
field = choice._fields.get(member_name)
if field:
return field
return None

def to_mongo(self, document, use_db_field=True, fields=None):
if document is None:
return None

data = document.to_mongo(use_db_field, fields)
if '_cls' not in data:
data['_cls'] = document._class_name
@@ -779,10 +839,20 @@ class ListField(ComplexBaseField):
kwargs.setdefault('default', lambda: [])
super(ListField, self).__init__(**kwargs)

def __get__(self, instance, owner):
if instance is None:
# Document class being used rather than a document object
return self
value = instance._data.get(self.name)
LazyReferenceField = _import_class('LazyReferenceField')
GenericLazyReferenceField = _import_class('GenericLazyReferenceField')
if isinstance(self.field, (LazyReferenceField, GenericLazyReferenceField)) and value:
instance._data[self.name] = [self.field.build_lazyref(x) for x in value]
return super(ListField, self).__get__(instance, owner)

def validate(self, value):
"""Make sure that a list of valid fields is being used."""
if (not isinstance(value, (list, tuple, QuerySet)) or
isinstance(value, six.string_types)):
if not isinstance(value, (list, tuple, BaseQuerySet)):
self.error('Only lists and tuples may be used in a list field')
super(ListField, self).validate(value)

@@ -874,7 +944,7 @@ def key_has_dot_or_dollar(d):
dictionary contains a dot or a dollar sign.
"""
for k, v in d.items():
if ('.' in k or '$' in k) or (isinstance(v, dict) and key_has_dot_or_dollar(v)):
if ('.' in k or k.startswith('$')) or (isinstance(v, dict) and key_has_dot_or_dollar(v)):
return True


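The behavioural change above is easiest to see with a few hedged examples of key_has_dot_or_dollar:

    key_has_dot_or_dollar({'$set': 1})         # True: '$' still rejected as a key prefix
    key_has_dot_or_dollar({'a$b': 1})          # now False: an embedded '$' is allowed
    key_has_dot_or_dollar({'a.b': 1})          # True: dots anywhere are still rejected
    key_has_dot_or_dollar({'a': {'$in': []}})  # True: nested dicts are checked too
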
@@ -889,12 +959,10 @@ class DictField(ComplexBaseField):
.. versionchanged:: 0.5 - Can now handle complex / varying types of data
"""

def __init__(self, basecls=None, field=None, *args, **kwargs):
def __init__(self, field=None, *args, **kwargs):
self.field = field
self._auto_dereference = False
self.basecls = basecls or BaseField
if not issubclass(self.basecls, BaseField):
self.error('DictField only accepts dict values')

kwargs.setdefault('default', lambda: {})
super(DictField, self).__init__(*args, **kwargs)

@@ -909,11 +977,11 @@ class DictField(ComplexBaseField):
self.error(msg)
if key_has_dot_or_dollar(value):
self.error('Invalid dictionary key name - keys may not contain "."'
' or "$" characters')
' or start with "$" characters')
super(DictField, self).validate(value)

def lookup_member(self, member_name):
return DictField(basecls=self.basecls, db_field=member_name)
return DictField(db_field=member_name)

def prepare_query_value(self, op, value):
match_operators = ['contains', 'icontains', 'startswith',
@@ -923,7 +991,7 @@ class DictField(ComplexBaseField):
if op in match_operators and isinstance(value, six.string_types):
return StringField().prepare_query_value(op, value)

if hasattr(self.field, 'field'):
if hasattr(self.field, 'field'):  # Used for instance when using DictField(ListField(IntField()))
if op in ('set', 'unset') and isinstance(value, dict):
return {
k: self.field.prepare_query_value(op, v)
@@ -943,6 +1011,7 @@ class MapField(DictField):
"""

def __init__(self, field=None, *args, **kwargs):
# XXX ValidationError raised outside of the "validate" method.
if not isinstance(field, BaseField):
self.error('Argument to MapField constructor must be a valid '
'field')
@@ -953,6 +1022,15 @@ class ReferenceField(BaseField):
"""A reference to a document that will be automatically dereferenced on
access (lazily).

Note this means you will get a database I/O access every time you access
this field. This is necessary because the field returns a :class:`~mongoengine.Document`
whose precise type can depend on the value of the `_cls` field present in the
document in the database.
In short, using this type of field can lead to poor performance (especially
if you access this field only to retrieve its `pk` field, which is already
known before dereference). To solve this you should consider using the
:class:`~mongoengine.fields.LazyReferenceField`.

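A hedged sketch contrasting the two access patterns described above (Author and Book are assumed example classes):

    class Author(Document):
        name = StringField()

    class Book(Document):
        author = ReferenceField(Author)

    book = Book.objects.first()
    book.author.pk        # triggers a full dereference query just to read the pk
    # With author = LazyReferenceField(Author) instead:
    # book.author.pk      # no query; pk is read from the stored reference
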
Use the `reverse_delete_rule` to handle what should happen if the document
the field is referencing is deleted. EmbeddedDocuments, DictFields and
MapFields do not support reverse_delete_rule and an `InvalidDocumentError`
@@ -971,11 +1049,13 @@ class ReferenceField(BaseField):

.. code-block:: python

class Bar(Document):
content = StringField()
foo = ReferenceField('Foo')
class Org(Document):
owner = ReferenceField('User')

Foo.register_delete_rule(Bar, 'foo', NULLIFY)
class User(Document):
org = ReferenceField('Org', reverse_delete_rule=CASCADE)

User.register_delete_rule(Org, 'owner', DENY)

.. versionchanged:: 0.5 added `reverse_delete_rule`
"""
@@ -993,6 +1073,7 @@ class ReferenceField(BaseField):
A reference to an abstract document type is always stored as a
:class:`~pymongo.dbref.DBRef`, regardless of the value of `dbref`.
"""
# XXX ValidationError raised outside of the "validate" method.
if (
not isinstance(document_type, six.string_types) and
not issubclass(document_type, Document)
@@ -1022,9 +1103,9 @@ class ReferenceField(BaseField):

# Get value from document instance if available
value = instance._data.get(self.name)
self._auto_dereference = instance._fields[self.name]._auto_dereference
auto_dereference = instance._fields[self.name]._auto_dereference
# Dereference DBRefs
if self._auto_dereference and isinstance(value, DBRef):
if auto_dereference and isinstance(value, DBRef):
if hasattr(value, 'cls'):
# Dereference using the class type specified in the reference
cls = get_document(value.cls)
@@ -1047,6 +1128,8 @@ class ReferenceField(BaseField):
if isinstance(document, Document):
# We need the id from the saved object to create the DBRef
id_ = document.pk

# XXX ValidationError raised outside of the "validate" method.
if id_ is None:
self.error('You can only reference documents once they have'
' been saved to the database')
@@ -1086,21 +1169,13 @@ class ReferenceField(BaseField):
return self.to_mongo(value)

def validate(self, value):

if not isinstance(value, (self.document_type, DBRef, ObjectId)):
self.error('A ReferenceField only accepts DBRef, ObjectId or documents')
if not isinstance(value, (self.document_type, LazyReference, DBRef, ObjectId)):
self.error('A ReferenceField only accepts DBRef, LazyReference, ObjectId or documents')

if isinstance(value, Document) and value.id is None:
self.error('You can only reference documents once they have been '
'saved to the database')

if self.document_type._meta.get('abstract') and \
not isinstance(value, self.document_type):
self.error(
'%s is not an instance of abstract reference type %s' % (
self.document_type._class_name)
)

def lookup_member(self, member_name):
return self.document_type._fields.get(member_name)

@@ -1121,6 +1196,7 @@ class CachedReferenceField(BaseField):
if fields is None:
fields = []

# XXX ValidationError raised outside of the "validate" method.
if (
not isinstance(document_type, six.string_types) and
not issubclass(document_type, Document)
@@ -1180,9 +1256,10 @@ class CachedReferenceField(BaseField):

# Get value from document instance if available
value = instance._data.get(self.name)
self._auto_dereference = instance._fields[self.name]._auto_dereference
auto_dereference = instance._fields[self.name]._auto_dereference

# Dereference DBRefs
if self._auto_dereference and isinstance(value, DBRef):
if auto_dereference and isinstance(value, DBRef):
dereferenced = self.document_type._get_db().dereference(value)
if dereferenced is None:
raise DoesNotExist('Trying to dereference unknown document %s' % value)
@@ -1195,6 +1272,7 @@ class CachedReferenceField(BaseField):
id_field_name = self.document_type._meta['id_field']
id_field = self.document_type._fields[id_field_name]

# XXX ValidationError raised outside of the "validate" method.
if isinstance(document, Document):
# We need the id from the saved object to create the DBRef
id_ = document.pk
@@ -1203,7 +1281,6 @@ class CachedReferenceField(BaseField):
' been saved to the database')
else:
self.error('Only accept a document object')
# TODO: should raise here or will fail next statement

value = SON((
('_id', id_field.to_mongo(id_)),
@@ -1221,16 +1298,20 @@ class CachedReferenceField(BaseField):
if value is None:
return None

# XXX ValidationError raised outside of the "validate" method.
if isinstance(value, Document):
if value.pk is None:
self.error('You can only reference documents once they have'
' been saved to the database')
return {'_id': value.pk}
value_dict = {'_id': value.pk}
for field in self.fields:
value_dict.update({field: value[field]})

return value_dict

raise NotImplementedError

def validate(self, value):

if not isinstance(value, self.document_type):
self.error('A CachedReferenceField only accepts documents')

@@ -1263,6 +1344,12 @@ class GenericReferenceField(BaseField):
"""A reference to *any* :class:`~mongoengine.document.Document` subclass
that will be automatically dereferenced on access (lazily).

Note this field works the same way as :class:`~mongoengine.document.ReferenceField`,
doing database I/O access the first time it is accessed (even if it's to access
its ``pk`` or ``id`` field).
To solve this you should consider using the
:class:`~mongoengine.fields.GenericLazyReferenceField`.

.. note ::
* Any documents used as a generic reference must be registered in the
document registry. Importing the model will automatically register
@@ -1285,6 +1372,8 @@ class GenericReferenceField(BaseField):
elif isinstance(choice, type) and issubclass(choice, Document):
self.choices.append(choice._class_name)
else:
# XXX ValidationError raised outside of the "validate"
# method.
self.error('Invalid choices provided: must be a list of'
'Document subclasses and/or six.string_typess')

@@ -1303,8 +1392,8 @@ class GenericReferenceField(BaseField):

value = instance._data.get(self.name)

self._auto_dereference = instance._fields[self.name]._auto_dereference
if self._auto_dereference and isinstance(value, (dict, SON)):
auto_dereference = instance._fields[self.name]._auto_dereference
if auto_dereference and isinstance(value, (dict, SON)):
dereferenced = self.dereference(value)
if dereferenced is None:
raise DoesNotExist('Trying to dereference unknown document %s' % value)
@@ -1348,6 +1437,7 @@ class GenericReferenceField(BaseField):
# We need the id from the saved object to create the DBRef
id_ = document.id
if id_ is None:
# XXX ValidationError raised outside of the "validate" method.
self.error('You can only reference documents once they have'
' been saved to the database')
else:
@@ -1385,10 +1475,10 @@ class BinaryField(BaseField):
return Binary(value)

def validate(self, value):
if not isinstance(value, (six.binary_type, six.text_type, Binary)):
if not isinstance(value, (six.binary_type, Binary)):
self.error('BinaryField only accepts instances of '
'(%s, %s, Binary)' % (
six.binary_type.__name__, six.text_type.__name__))
six.binary_type.__name__, Binary.__name__))

if self.max_bytes is not None and len(value) > self.max_bytes:
self.error('Binary value is too long')
@@ -1439,9 +1529,11 @@ class GridFSProxy(object):
def __get__(self, instance, value):
return self

def __nonzero__(self):
def __bool__(self):
return bool(self.grid_id)

__nonzero__ = __bool__  # For Py2 support

def __getstate__(self):
self_dict = self.__dict__
self_dict['_fs'] = None
@@ -1459,9 +1551,9 @@ class GridFSProxy(object):
return '<%s: %s>' % (self.__class__.__name__, self.grid_id)

def __str__(self):
name = getattr(
self.get(), 'filename', self.grid_id) if self.get() else '(no file)'
return '<%s: %s>' % (self.__class__.__name__, name)
gridout = self.get()
filename = getattr(gridout, 'filename') if gridout else '<no file>'
return '<%s: %s (%s)>' % (self.__class__.__name__, filename, self.grid_id)

def __eq__(self, other):
if isinstance(other, GridFSProxy):
@@ -1471,6 +1563,9 @@ class GridFSProxy(object):
else:
return False

def __ne__(self, other):
return not self == other

@property
def fs(self):
if not self._fs:
@@ -1778,12 +1873,9 @@ class ImageField(FileField):
"""
A Image File storage field.

@size (width, height, force):
max size to store images, if larger will be automatically resized
ex: size=(800, 600, True)

@thumbnail (width, height, force):
size to generate a thumbnail
:param size: max size to store images, provided as (width, height, force)
if larger, it will be automatically resized (ex: size=(800, 600, True))
:param thumbnail_size: size to generate a thumbnail, provided as (width, height, force)

.. versionadded:: 0.6
"""
@@ -1854,8 +1946,7 @@ class SequenceField(BaseField):
self.collection_name = collection_name or self.COLLECTION_NAME
self.db_alias = db_alias or DEFAULT_CONNECTION_NAME
self.sequence_name = sequence_name
self.value_decorator = (callable(value_decorator) and
value_decorator or self.VALUE_DECORATOR)
self.value_decorator = value_decorator if callable(value_decorator) else self.VALUE_DECORATOR
super(SequenceField, self).__init__(*args, **kwargs)

def generate(self):
@@ -1964,7 +2055,7 @@ class UUIDField(BaseField):
if not isinstance(value, six.string_types):
value = six.text_type(value)
return uuid.UUID(value)
except Exception:
except (ValueError, TypeError, AttributeError):
return original_value
return value

@@ -1986,7 +2077,7 @@ class UUIDField(BaseField):
value = str(value)
try:
uuid.UUID(value)
except Exception as exc:
except (ValueError, TypeError, AttributeError) as exc:
self.error('Could not convert to UUID: %s' % exc)


@@ -2144,3 +2235,201 @@ class MultiPolygonField(GeoJsonBaseField):
.. versionadded:: 0.9
"""
_type = 'MultiPolygon'


class LazyReferenceField(BaseField):
"""A really lazy reference to a document.
Unlike the :class:`~mongoengine.fields.ReferenceField` it will
**not** be automatically (lazily) dereferenced on access.
Instead, access will return a :class:`~mongoengine.base.LazyReference` class
instance, allowing access to `pk` or manual dereference by using
the ``fetch()`` method.

.. versionadded:: 0.15
"""

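Hedged usage sketch of LazyReferenceField (Author and Book are assumed example classes):

    class Author(Document):
        name = StringField()

    class Book(Document):
        author = LazyReferenceField(Author)

    book = Book.objects.first()
    ref = book.author       # a LazyReference; no database query yet
    ref.pk                  # the primary key, available without I/O
    author = ref.fetch()    # explicit dereference: one query, returns an Author
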
def __init__(self, document_type, passthrough=False, dbref=False,
reverse_delete_rule=DO_NOTHING, **kwargs):
"""Initialises the Reference Field.

:param dbref: Store the reference as :class:`~pymongo.dbref.DBRef`
or as the :class:`~pymongo.objectid.ObjectId`.id .
:param reverse_delete_rule: Determines what to do when the referring
object is deleted
:param passthrough: When trying to access unknown fields, the
:class:`~mongoengine.base.datastructure.LazyReference` instance will
automatically call `fetch()` and try to retrieve the field on the fetched
document. Note this only works for getting fields (not setting or deleting).
"""
# XXX ValidationError raised outside of the "validate" method.
if (
not isinstance(document_type, six.string_types) and
not issubclass(document_type, Document)
):
self.error('Argument to LazyReferenceField constructor must be a '
'document class or a string')

self.dbref = dbref
self.passthrough = passthrough
self.document_type_obj = document_type
self.reverse_delete_rule = reverse_delete_rule
super(LazyReferenceField, self).__init__(**kwargs)

@property
def document_type(self):
if isinstance(self.document_type_obj, six.string_types):
if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT:
self.document_type_obj = self.owner_document
else:
self.document_type_obj = get_document(self.document_type_obj)
return self.document_type_obj

def build_lazyref(self, value):
if isinstance(value, LazyReference):
if value.passthrough != self.passthrough:
value = LazyReference(value.document_type, value.pk, passthrough=self.passthrough)
elif value is not None:
if isinstance(value, self.document_type):
value = LazyReference(self.document_type, value.pk, passthrough=self.passthrough)
elif isinstance(value, DBRef):
value = LazyReference(self.document_type, value.id, passthrough=self.passthrough)
else:
# value is the primary key of the referenced document
value = LazyReference(self.document_type, value, passthrough=self.passthrough)
return value

def __get__(self, instance, owner):
"""Descriptor to allow lazy dereferencing."""
if instance is None:
# Document class being used rather than a document object
return self

value = self.build_lazyref(instance._data.get(self.name))
if value:
instance._data[self.name] = value

return super(LazyReferenceField, self).__get__(instance, owner)

def to_mongo(self, value):
if isinstance(value, LazyReference):
pk = value.pk
elif isinstance(value, self.document_type):
pk = value.pk
elif isinstance(value, DBRef):
pk = value.id
else:
# value is the primary key of the referenced document
pk = value
id_field_name = self.document_type._meta['id_field']
id_field = self.document_type._fields[id_field_name]
pk = id_field.to_mongo(pk)
if self.dbref:
return DBRef(self.document_type._get_collection_name(), pk)
else:
return pk

def validate(self, value):
if isinstance(value, LazyReference):
if value.collection != self.document_type._get_collection_name():
self.error('Reference must be on a `%s` document.' % self.document_type)
pk = value.pk
elif isinstance(value, self.document_type):
pk = value.pk
elif isinstance(value, DBRef):
# TODO: check collection ?
collection = self.document_type._get_collection_name()
if value.collection != collection:
self.error("DBRef on bad collection (must be on `%s`)" % collection)
pk = value.id
else:
# value is the primary key of the referenced document
id_field_name = self.document_type._meta['id_field']
id_field = getattr(self.document_type, id_field_name)
pk = value
try:
id_field.validate(pk)
except ValidationError:
self.error(
"value should be `{0}` document, LazyReference or DBRef on `{0}` "
"or `{0}`'s primary key (i.e. `{1}`)".format(
self.document_type.__name__, type(id_field).__name__))

if pk is None:
self.error('You can only reference documents once they have been '
'saved to the database')

def prepare_query_value(self, op, value):
if value is None:
return None
super(LazyReferenceField, self).prepare_query_value(op, value)
return self.to_mongo(value)

def lookup_member(self, member_name):
return self.document_type._fields.get(member_name)


class GenericLazyReferenceField(GenericReferenceField):
"""A reference to *any* :class:`~mongoengine.document.Document` subclass.
Unlike the :class:`~mongoengine.fields.GenericReferenceField` it will
**not** be automatically (lazily) dereferenced on access.
Instead, access will return a :class:`~mongoengine.base.LazyReference` class
instance, allowing access to `pk` or manual dereference by using
the ``fetch()`` method.

.. note ::
* Any documents used as a generic reference must be registered in the
document registry. Importing the model will automatically register
it.

* You can use the choices param to limit the acceptable Document types

.. versionadded:: 0.15
"""

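Hedged sketch of GenericLazyReferenceField: the stored value keeps both the class name and the DBRef, and access stays lazy (Post, Comment and Bookmark are assumed example classes):

    class Post(Document):
        title = StringField()

    class Comment(Document):
        text = StringField()

    class Bookmark(Document):
        target = GenericLazyReferenceField(choices=[Post, Comment])

    bm = Bookmark.objects.first()
    bm.target.pk             # no query; read from the stored {'_cls', '_ref'} SON
    doc = bm.target.fetch()  # one query; returns a Post or Comment instance
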
def __init__(self, *args, **kwargs):
self.passthrough = kwargs.pop('passthrough', False)
super(GenericLazyReferenceField, self).__init__(*args, **kwargs)

def _validate_choices(self, value):
if isinstance(value, LazyReference):
value = value.document_type._class_name
super(GenericLazyReferenceField, self)._validate_choices(value)

def build_lazyref(self, value):
if isinstance(value, LazyReference):
if value.passthrough != self.passthrough:
value = LazyReference(value.document_type, value.pk, passthrough=self.passthrough)
elif value is not None:
if isinstance(value, (dict, SON)):
value = LazyReference(get_document(value['_cls']), value['_ref'].id, passthrough=self.passthrough)
elif isinstance(value, Document):
value = LazyReference(type(value), value.pk, passthrough=self.passthrough)
return value

def __get__(self, instance, owner):
if instance is None:
return self

value = self.build_lazyref(instance._data.get(self.name))
if value:
instance._data[self.name] = value

return super(GenericLazyReferenceField, self).__get__(instance, owner)

def validate(self, value):
if isinstance(value, LazyReference) and value.pk is None:
self.error('You can only reference documents once they have been'
' saved to the database')
return super(GenericLazyReferenceField, self).validate(value)

def to_mongo(self, document):
if document is None:
return None

if isinstance(document, LazyReference):
return SON((
('_cls', document.document_type._class_name),
('_ref', DBRef(document.document_type._get_collection_name(), document.pk))
))
else:
return super(GenericLazyReferenceField, self).to_mongo(document)

@@ -6,11 +6,7 @@ import pymongo
import six


if pymongo.version_tuple[0] < 3:
IS_PYMONGO_3 = False
else:
IS_PYMONGO_3 = True

IS_PYMONGO_3 = pymongo.version_tuple[0] >= 3

# six.BytesIO resolves to StringIO.StringIO in Py2 and io.BytesIO in Py3.
StringIO = six.BytesIO
@@ -23,3 +19,10 @@ if not six.PY3:
pass
else:
StringIO = cStringIO.StringIO


if six.PY3:
from collections.abc import Hashable
else:
# raises DeprecationWarnings in Python >=3.7
from collections import Hashable

@@ -2,7 +2,6 @@ from __future__ import absolute_import

import copy
import itertools
import operator
import pprint
import re
import warnings
@@ -18,7 +17,7 @@ from mongoengine import signals
from mongoengine.base import get_document
from mongoengine.common import _import_class
from mongoengine.connection import get_db
from mongoengine.context_managers import switch_db
from mongoengine.context_managers import set_write_concern, switch_db
from mongoengine.errors import (InvalidQueryError, LookUpError,
NotUniqueError, OperationError)
from mongoengine.python_support import IS_PYMONGO_3
@@ -39,8 +38,6 @@ CASCADE = 2
DENY = 3
PULL = 4

RE_TYPE = type(re.compile(''))


class BaseQuerySet(object):
"""A set of results returned from a query. Wraps a MongoDB cursor,
@@ -191,7 +188,7 @@ class BaseQuerySet(object):
)

if queryset._as_pymongo:
return queryset._get_as_pymongo(queryset._cursor[key])
return queryset._cursor[key]

return queryset._document._from_son(
queryset._cursor[key],
@@ -209,18 +206,16 @@ class BaseQuerySet(object):
queryset = self.order_by()
return False if queryset.first() is None else True

def __nonzero__(self):
"""Avoid to open all records in an if stmt in Py2."""
return self._has_data()

def __bool__(self):
"""Avoid to open all records in an if stmt in Py3."""
return self._has_data()

__nonzero__ = __bool__  # For Py2 support

# Core functions

def all(self):
"""Returns all documents."""
"""Returns a copy of the current QuerySet."""
return self.__call__()

def filter(self, *q_objs, **query):
@@ -269,13 +264,13 @@ class BaseQuerySet(object):
queryset = queryset.filter(*q_objs, **query)

try:
result = queryset.next()
result = six.next(queryset)
except StopIteration:
msg = ('%s matching query does not exist.'
% queryset._document._class_name)
raise queryset._document.DoesNotExist(msg)
try:
queryset.next()
six.next(queryset)
except StopIteration:
return result

@@ -350,11 +345,24 @@ class BaseQuerySet(object):
documents=docs, **signal_kwargs)

raw = [doc.to_mongo() for doc in docs]

with set_write_concern(self._collection, write_concern) as collection:
insert_func = collection.insert_many
if return_one:
raw = raw[0]
insert_func = collection.insert_one

try:
ids = self._collection.insert(raw, **write_concern)
inserted_result = insert_func(raw)
ids = [inserted_result.inserted_id] if return_one else inserted_result.inserted_ids
except pymongo.errors.DuplicateKeyError as err:
message = 'Could not save document (%s)'
raise NotUniqueError(message % six.text_type(err))
except pymongo.errors.BulkWriteError as err:
# inserting documents that already have an _id field will
# cause a huge performance hit or raise an error
message = u'Document must not have _id value before bulk write (%s)'
raise NotUniqueError(message % six.text_type(err))
except pymongo.errors.OperationFailure as err:
message = 'Could not save document (%s)'
if re.match('^E1100[01] duplicate key', six.text_type(err)):
@@ -364,18 +372,20 @@ class BaseQuerySet(object):
raise NotUniqueError(message % six.text_type(err))
raise OperationError(message % six.text_type(err))

# Apply inserted_ids to documents
for doc, doc_id in zip(docs, ids):
doc.pk = doc_id

if not load_bulk:
signals.post_bulk_insert.send(
self._document, documents=docs, loaded=False, **signal_kwargs)
return return_one and ids[0] or ids
return ids[0] if return_one else ids

documents = self.in_bulk(ids)
results = []
for obj_id in ids:
results.append(documents.get(obj_id))
results = [documents.get(obj_id) for obj_id in ids]
signals.post_bulk_insert.send(
self._document, documents=results, loaded=True, **signal_kwargs)
return return_one and results[0] or results
return results[0] if return_one else results

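Hedged usage sketch of the insert() path rewritten above on top of insert_one/insert_many (Comment is an assumed example class):

    comments = [Comment(text='a'), Comment(text='b')]
    Comment.objects.insert(comments)                  # bulk: insert_many under the hood
    one = Comment.objects.insert(Comment(text='c'))   # single: insert_one, returns the doc
    # With load_bulk=False the raw inserted ids are returned instead of documents.
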
def count(self, with_limit_and_skip=False):
"""Count the selected elements in the query.
@@ -384,9 +394,11 @@ class BaseQuerySet(object):
:meth:`skip` that has been applied to this cursor into account when
getting the count
"""
if self._limit == 0 and with_limit_and_skip or self._none:
if self._limit == 0 and with_limit_and_skip is False or self._none:
return 0
return self._cursor.count(with_limit_and_skip=with_limit_and_skip)
count = self._cursor.count(with_limit_and_skip=with_limit_and_skip)
self._cursor_obj = None
return count

def delete(self, write_concern=None, _from_doc_delete=False,
cascade_refs=None):
@@ -486,8 +498,9 @@ class BaseQuerySet(object):
``save(..., write_concern={w: 2, fsync: True}, ...)`` will
wait until at least two servers have recorded the write and
will force an fsync on the primary server.
:param full_result: Return the full result rather than just the number
updated.
:param full_result: Return the full result dictionary rather than just the number
updated, e.g. return
``{'n': 2, 'nModified': 2, 'ok': 1.0, 'updatedExisting': True}``.
:param update: Django-style update keyword arguments

.. versionadded:: 0.2
@@ -510,12 +523,15 @@ class BaseQuerySet(object):
else:
update['$set'] = {'_cls': queryset._document._class_name}
try:
result = queryset._collection.update(query, update, multi=multi,
upsert=upsert, **write_concern)
with set_write_concern(queryset._collection, write_concern) as collection:
update_func = collection.update_one
if multi:
update_func = collection.update_many
result = update_func(query, update, upsert=upsert)
if full_result:
return result
elif result:
return result['n']
elif result.raw_result:
return result.raw_result['n']
except pymongo.errors.DuplicateKeyError as err:
raise NotUniqueError(u'Update failed (%s)' % six.text_type(err))
except pymongo.errors.OperationFailure as err:
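
With this change, update(..., full_result=True) returns a PyMongo UpdateResult rather than a plain dict; a hedged usage sketch (Page is an assumed example class):

    result = Page.objects(views=0).update(full_result=True, set__hidden=True)
    result.raw_result    # e.g. {'n': 2, 'nModified': 2, 'ok': 1.0, 'updatedExisting': True}
    result.upserted_id   # populated on upserts, replacing raw_result['upserted']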
@@ -544,10 +560,10 @@ class BaseQuerySet(object):
write_concern=write_concern,
full_result=True, **update)

if atomic_update['updatedExisting']:
if atomic_update.raw_result['updatedExisting']:
document = self.get()
else:
document = self._document.objects.with_id(atomic_update['upserted'])
document = self._document.objects.with_id(atomic_update.upserted_id)
return document

def update_one(self, upsert=False, write_concern=None, **update):
@@ -674,7 +690,7 @@ class BaseQuerySet(object):
self._document._from_son(doc, only_fields=self.only_fields))
elif self._as_pymongo:
for doc in docs:
doc_map[doc['_id']] = self._get_as_pymongo(doc)
doc_map[doc['_id']] = doc
else:
for doc in docs:
doc_map[doc['_id']] = self._document._from_son(
@@ -759,10 +775,11 @@ class BaseQuerySet(object):
"""Limit the number of returned documents to `n`. This may also be
achieved using array-slicing syntax (e.g. ``User.objects[:5]``).

:param n: the maximum number of objects to return
:param n: the maximum number of objects to return if n is greater than 0.
When 0 is passed, returns all the documents in the cursor
"""
queryset = self.clone()
queryset._limit = n if n != 0 else 1
queryset._limit = n

# If a cursor object has already been created, apply the limit to it.
if queryset._cursor_obj:
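
The semantics change above in a hedged example: limit(0) used to be coerced to 1; it now means "no limit", matching PyMongo:

    User.objects.limit(5)   # at most five documents
    User.objects.limit(0)   # previously one document; now returns all documents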
@@ -960,11 +977,10 @@ class BaseQuerySet(object):
# explicitly included, and then more complicated operators such as
# $slice.
def _sort_key(field_tuple):
key, value = field_tuple
if isinstance(value, (int)):
_, value = field_tuple
if isinstance(value, int):
return value  # 0 for exclusion, 1 for inclusion
else:
return 2  # so that complex values appear last
return 2  # so that complex values appear last

fields = sorted(cleaned_fields, key=_sort_key)

@@ -1182,6 +1198,10 @@ class BaseQuerySet(object):

pipeline = initial_pipeline + list(pipeline)

if IS_PYMONGO_3 and self._read_preference is not None:
return self._collection.with_options(read_preference=self._read_preference) \
.aggregate(pipeline, cursor={}, **kwargs)

return self._collection.aggregate(pipeline, cursor={}, **kwargs)

# JS functionality
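
A hedged sketch of the read-preference-aware aggregate path added above (User is an assumed example class; the queryset read_preference() helper is assumed to set _read_preference):

    from pymongo import ReadPreference

    pipeline = [{'$group': {'_id': '$status', 'n': {'$sum': 1}}}]
    User.objects.read_preference(ReadPreference.SECONDARY).aggregate(*pipeline)
    # runs the aggregation on a secondary via collection.with_options(...)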
@@ -1457,16 +1477,16 @@ class BaseQuerySet(object):
|
||||
|
||||
# Iterator helpers
|
||||
|
||||
def next(self):
|
||||
def __next__(self):
|
||||
"""Wrap the result in a :class:`~mongoengine.Document` object.
|
||||
"""
|
||||
if self._limit == 0 or self._none:
|
||||
raise StopIteration
|
||||
|
||||
raw_doc = self._cursor.next()
|
||||
raw_doc = six.next(self._cursor)
|
||||
|
||||
if self._as_pymongo:
|
||||
return self._get_as_pymongo(raw_doc)
|
||||
return raw_doc
|
||||
|
||||
doc = self._document._from_son(
|
||||
raw_doc, _auto_dereference=self._auto_dereference,
|
||||
@@ -1477,6 +1497,8 @@ class BaseQuerySet(object):
|
||||
|
||||
return doc
|
||||
|
||||
next = __next__ # For Python2 support
|
||||
|
||||
def rewind(self):
|
||||
"""Rewind the cursor to its unevaluated state.
|
||||
|
||||
@@ -1578,6 +1600,9 @@ class BaseQuerySet(object):
|
||||
if self._batch_size is not None:
|
||||
self._cursor_obj.batch_size(self._batch_size)
|
||||
|
||||
if self._comment is not None:
|
||||
self._cursor_obj.comment(self._comment)
|
||||
|
||||
return self._cursor_obj
|
||||
|
||||
def __deepcopy__(self, memo):
|
||||
@@ -1722,25 +1747,33 @@ class BaseQuerySet(object):
|
||||
return frequencies
|
||||
|
||||
def _fields_to_dbfields(self, fields):
|
||||
"""Translate fields paths to its db equivalents"""
|
||||
ret = []
|
||||
"""Translate fields' paths to their db equivalents."""
|
||||
subclasses = []
|
||||
document = self._document
|
||||
if document._meta['allow_inheritance']:
|
||||
if self._document._meta['allow_inheritance']:
|
||||
subclasses = [get_document(x)
|
||||
for x in document._subclasses][1:]
|
||||
for x in self._document._subclasses][1:]
|
||||
|
||||
db_field_paths = []
|
||||
for field in fields:
|
||||
field_parts = field.split('.')
|
||||
try:
|
||||
field = '.'.join(f.db_field for f in
|
||||
document._lookup_field(field.split('.')))
|
||||
ret.append(field)
|
||||
field = '.'.join(
|
||||
f if isinstance(f, six.string_types) else f.db_field
|
||||
for f in self._document._lookup_field(field_parts)
|
||||
)
|
||||
db_field_paths.append(field)
|
||||
except LookUpError as err:
|
||||
found = False
|
||||
|
||||
# If a field path wasn't found on the main document, go
|
||||
# through its subclasses and see if it exists on any of them.
|
||||
for subdoc in subclasses:
|
||||
try:
|
||||
subfield = '.'.join(f.db_field for f in
|
||||
subdoc._lookup_field(field.split('.')))
|
||||
ret.append(subfield)
|
||||
subfield = '.'.join(
|
||||
f if isinstance(f, six.string_types) else f.db_field
|
||||
for f in subdoc._lookup_field(field_parts)
|
||||
)
|
||||
db_field_paths.append(subfield)
|
||||
found = True
|
||||
break
|
||||
except LookUpError:
|
||||
@@ -1748,7 +1781,8 @@ class BaseQuerySet(object):
|
||||
|
||||
if not found:
|
||||
raise err
|
||||
return ret
|
||||
|
||||
return db_field_paths
|
||||
|
||||
def _get_order_by(self, keys):
|
||||
"""Given a list of MongoEngine-style sort keys, return a list
|
||||
@@ -1799,26 +1833,6 @@ class BaseQuerySet(object):
|
||||
|
||||
return tuple(data)
|
||||
|
||||
def _get_as_pymongo(self, doc):
|
||||
"""Clean up a PyMongo doc, removing fields that were only fetched
|
||||
for the sake of MongoEngine's implementation, and return it.
|
||||
"""
|
||||
# Always remove _cls as a MongoEngine's implementation detail.
|
||||
if '_cls' in doc:
|
||||
del doc['_cls']
|
||||
|
||||
# If the _id was not included in a .only or was excluded in a .exclude,
|
||||
# remove it from the doc (we always fetch it so that we can properly
|
||||
# construct documents).
|
||||
fields = self._loaded_fields
|
||||
if fields and '_id' in doc and (
|
||||
(fields.value == QueryFieldList.ONLY and '_id' not in fields.fields) or
|
||||
(fields.value == QueryFieldList.EXCLUDE and '_id' in fields.fields)
|
||||
):
|
||||
del doc['_id']
|
||||
|
||||
return doc
|
||||
|
||||
def _sub_js_fields(self, code):
|
||||
"""When fields are specified with [~fieldname] syntax, where
|
||||
*fieldname* is the Python name of a field, *fieldname* will be
|
||||
@@ -1840,8 +1854,8 @@ class BaseQuerySet(object):
|
||||
# Substitute the correct name for the field into the javascript
|
||||
return '.'.join([f.db_field for f in fields])
|
||||
|
||||
code = re.sub(u'\[\s*~([A-z_][A-z_0-9.]+?)\s*\]', field_sub, code)
|
||||
code = re.sub(u'\{\{\s*~([A-z_][A-z_0-9.]+?)\s*\}\}', field_path_sub,
|
||||
code = re.sub(r'\[\s*~([A-z_][A-z_0-9.]+?)\s*\]', field_sub, code)
|
||||
code = re.sub(r'\{\{\s*~([A-z_][A-z_0-9.]+?)\s*\}\}', field_path_sub,
|
||||
code)
|
||||
return code
|
||||
|
||||
|
||||
@@ -63,9 +63,11 @@ class QueryFieldList(object):
|
||||
self._only_called = True
|
||||
return self
|
||||
|
||||
def __nonzero__(self):
|
||||
def __bool__(self):
|
||||
return bool(self.fields)
|
||||
|
||||
__nonzero__ = __bool__ # For Py2 support
|
||||
|
||||
def as_dict(self):
|
||||
field_list = {field: self.value for field in self.fields}
|
||||
if self.slice:
|
||||
|
||||
@@ -36,7 +36,7 @@ class QuerySetManager(object):
        queryset_class = owner._meta.get('queryset_class', self.default)
        queryset = queryset_class(owner, owner._get_collection())
        if self.get_queryset:
            arg_count = self.get_queryset.func_code.co_argcount
            arg_count = self.get_queryset.__code__.co_argcount
            if arg_count == 1:
                queryset = self.get_queryset(queryset)
            elif arg_count == 2:
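func_code is Python 2 only; __code__ works on both 2 and 3. A quick sketch of the introspection the manager relies on:

def get_queryset_one(queryset):
    return queryset

def get_queryset_two(doc_cls, queryset):
    return queryset

# co_argcount counts declared positional parameters, which is what decides
# how a custom get_queryset gets called above.
assert get_queryset_one.__code__.co_argcount == 1
assert get_queryset_two.__code__.co_argcount == 2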
@@ -1,3 +1,5 @@
import six

from mongoengine.errors import OperationError
from mongoengine.queryset.base import (BaseQuerySet, CASCADE, DENY, DO_NOTHING,
                                       NULLIFY, PULL)
@@ -87,10 +89,10 @@ class QuerySet(BaseQuerySet):
                yield self._result_cache[pos]
                pos += 1

            # Raise StopIteration if we already established there were no more
            # return if we already established there were no more
            # docs in the db cursor.
            if not self._has_more:
                raise StopIteration
                return

            # Otherwise, populate more of the cache and repeat.
            if len(self._result_cache) <= pos:
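The raise-to-return change follows PEP 479: from Python 3.7 on, a StopIteration escaping a generator body is converted into a RuntimeError, so a generator must end with a plain return instead. A minimal illustration:

def take(cache):
    for item in cache:
        yield item
    return  # the PEP 479-safe way to say "no more docs"

assert list(take(['a', 'b'])) == ['a', 'b']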
@@ -112,8 +114,8 @@ class QuerySet(BaseQuerySet):
            # Pull in ITER_CHUNK_SIZE docs from the database and store them in
            # the result cache.
            try:
                for _ in xrange(ITER_CHUNK_SIZE):
                    self._result_cache.append(self.next())
                for _ in six.moves.range(ITER_CHUNK_SIZE):
                    self._result_cache.append(six.next(self))
            except StopIteration:
                # Getting this exception means there are no more docs in the
                # db cursor. Set _has_more to False so that we can use that
@@ -166,9 +168,9 @@ class QuerySetNoCache(BaseQuerySet):
            return '.. queryset mid-iteration ..'

        data = []
        for _ in xrange(REPR_OUTPUT_SIZE + 1):
        for _ in six.moves.range(REPR_OUTPUT_SIZE + 1):
            try:
                data.append(self.next())
                data.append(six.next(self))
            except StopIteration:
                break

@@ -184,10 +186,3 @@ class QuerySetNoCache(BaseQuerySet):
        queryset = self.clone()
        queryset.rewind()
        return queryset


class QuerySetNoDeRef(QuerySet):
    """Special no_dereference QuerySet"""

    def __dereference(items, max_depth=1, instance=None, name=None):
        return items
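xrange and obj.next() are Python 2 only; the six shims pick the right spelling per interpreter. A tiny sketch:

import six

# range on Python 3, xrange on Python 2
for _ in six.moves.range(2):
    pass

# calls __next__() on Python 3, next() on Python 2
it = iter(['a', 'b'])
assert six.next(it) == 'a'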
@@ -101,21 +101,8 @@ def query(_doc_cls=None, **kwargs):
                    value = value['_id']

            elif op in ('in', 'nin', 'all', 'near') and not isinstance(value, dict):
                # Raise an error if the in/nin/all/near param is not iterable. We need a
                # special check for BaseDocument, because - although it's iterable - using
                # it as such in the context of this method is most definitely a mistake.
                BaseDocument = _import_class('BaseDocument')
                if isinstance(value, BaseDocument):
                    raise TypeError("When using the `in`, `nin`, `all`, or "
                                    "`near`-operators you can\'t use a "
                                    "`Document`, you must wrap your object "
                                    "in a list (object -> [object]).")
                elif not hasattr(value, '__iter__'):
                    raise TypeError("The `in`, `nin`, `all`, or "
                                    "`near`-operators must be applied to an "
                                    "iterable (e.g. a list).")
                else:
                    value = [field.prepare_query_value(op, v) for v in value]
                # Raise an error if the in/nin/all/near param is not iterable.
                value = _prepare_query_for_iterable(field, op, value)

            # If we're querying a GenericReferenceField, we need to alter the
            # key depending on the value:
@@ -160,7 +147,7 @@ def query(_doc_cls=None, **kwargs):
        if op is None or key not in mongo_query:
            mongo_query[key] = value
        elif key in mongo_query:
            if isinstance(mongo_query[key], dict):
            if isinstance(mongo_query[key], dict) and isinstance(value, dict):
                mongo_query[key].update(value)
                # $max/minDistance needs to come last - convert to SON
                value_dict = mongo_query[key]
@@ -214,30 +201,37 @@ def update(_doc_cls=None, **update):
    format.
    """
    mongo_update = {}

    for key, value in update.items():
        if key == '__raw__':
            mongo_update.update(value)
            continue

        parts = key.split('__')

        # if there is no operator, default to 'set'
        if len(parts) < 3 and parts[0] not in UPDATE_OPERATORS:
            parts.insert(0, 'set')

        # Check for an operator and transform to mongo-style if there is
        op = None
        if parts[0] in UPDATE_OPERATORS:
            op = parts.pop(0)
            # Convert Pythonic names to Mongo equivalents
            if op in ('push_all', 'pull_all'):
                op = op.replace('_all', 'All')
            elif op == 'dec':
            operator_map = {
                'push_all': 'pushAll',
                'pull_all': 'pullAll',
                'dec': 'inc',
                'add_to_set': 'addToSet',
                'set_on_insert': 'setOnInsert'
            }
            if op == 'dec':
                # Support decrement by flipping a positive value's sign
                # and using 'inc'
                op = 'inc'
                value = -value
            elif op == 'add_to_set':
                op = 'addToSet'
            elif op == 'set_on_insert':
                op = 'setOnInsert'
            # If the operator isn't in the operator map, the op value
            # stays unchanged
            op = operator_map.get(op, op)

        match = None
        if parts[-1] in COMPARISON_OPERATORS:
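A hedged sketch of what the mapping above produces, using the update() transform from this diff with a hypothetical model (building the modifier needs no database access):

from mongoengine import Document, IntField
from mongoengine.queryset import transform

class Counter(Document):
    visits = IntField()

# 'dec' flips the sign and reuses $inc; a bare field name defaults to $set.
assert transform.update(Counter, dec__visits=3) == {'$inc': {'visits': -3}}
assert transform.update(Counter, visits=10) == {'$set': {'visits': 10}}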
@@ -284,7 +278,15 @@ def update(_doc_cls=None, **update):
        if isinstance(field, GeoJsonBaseField):
            value = field.to_mongo(value)

        if op in (None, 'set', 'push', 'pull'):
        if op == 'pull':
            if field.required or value is not None:
                if match == 'in' and not isinstance(value, dict):
                    value = _prepare_query_for_iterable(field, op, value)
                else:
                    value = field.prepare_query_value(op, value)
        elif op == 'push' and isinstance(value, (list, tuple, set)):
            value = [field.prepare_query_value(op, v) for v in value]
        elif op in (None, 'set', 'push'):
            if field.required or value is not None:
                value = field.prepare_query_value(op, value)
        elif op in ('pushAll', 'pullAll'):
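A sketch of the new pull-with-'in' path (hypothetical model; transform.update builds the modifier without touching the database): each member of the iterable is prepared individually and wrapped in $in.

from mongoengine import Document, ListField, StringField
from mongoengine.queryset import transform

class Post(Document):
    tags = ListField(StringField())

assert transform.update(Post, pull__tags__in=['old', 'stale']) == \
    {'$pull': {'tags': {'$in': ['old', 'stale']}}}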
@@ -296,6 +298,8 @@ def update(_doc_cls=None, **update):
            value = field.prepare_query_value(op, value)
        elif op == 'unset':
            value = 1
        elif op == 'inc':
            value = field.prepare_query_value(op, value)

        if match:
            match = '$' + match
@@ -319,11 +323,17 @@ def update(_doc_cls=None, **update):
            field_classes = [c.__class__ for c in cleaned_fields]
            field_classes.reverse()
            ListField = _import_class('ListField')
            if ListField in field_classes:
                # Join all fields via dot notation to the last ListField
            EmbeddedDocumentListField = _import_class('EmbeddedDocumentListField')
            if ListField in field_classes or EmbeddedDocumentListField in field_classes:
                # Join all fields via dot notation to the last ListField or EmbeddedDocumentListField
                # Then process as normal
                if ListField in field_classes:
                    _check_field = ListField
                else:
                    _check_field = EmbeddedDocumentListField

                last_listField = len(
                    cleaned_fields) - field_classes.index(ListField)
                    cleaned_fields) - field_classes.index(_check_field)
                key = '.'.join(parts[:last_listField])
                parts = parts[last_listField:]
                parts.insert(0, key)
@@ -333,10 +343,26 @@ def update(_doc_cls=None, **update):
                value = {key: value}
        elif op == 'addToSet' and isinstance(value, list):
            value = {key: {'$each': value}}
        elif op in ('push', 'pushAll'):
            if parts[-1].isdigit():
                key = '.'.join(parts[0:-1])
                position = int(parts[-1])
                # $position expects an iterable. If pushing a single value,
                # wrap it in a list.
                if not isinstance(value, (set, tuple, list)):
                    value = [value]
                value = {key: {'$each': value, '$position': position}}
            else:
                if op == 'pushAll':
                    op = 'push'  # convert to non-deprecated keyword
                if not isinstance(value, (set, tuple, list)):
                    value = [value]
                value = {key: {'$each': value}}
        else:
            value = {key: value}
        else:
            value = {key: value}
        key = '$' + op

        if key not in mongo_update:
            mongo_update[key] = value
        elif key in mongo_update and isinstance(mongo_update[key], dict):
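A sketch of the $position branch above (hypothetical model): a trailing digit in the key becomes the insertion index, and a single pushed value is wrapped in a list because $position requires $each.

from mongoengine import Document, ListField, StringField
from mongoengine.queryset import transform

class Playlist(Document):
    songs = ListField(StringField())

assert transform.update(Playlist, push__songs__0='intro') == \
    {'$push': {'songs': {'$each': ['intro'], '$position': 0}}}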
@@ -403,7 +429,6 @@ def _infer_geometry(value):
                                'type and coordinates keys')
    elif isinstance(value, (list, set)):
        # TODO: shouldn't we test value[0][0][0][0] to see if it is MultiPolygon?
        # TODO: should both TypeError and IndexError be alike interpreted?

        try:
            value[0][0][0]
@@ -425,3 +450,22 @@ def _infer_geometry(value):

    raise InvalidQueryError('Invalid $geometry data. Can be either a '
                            'dictionary or (nested) lists of coordinate(s)')


def _prepare_query_for_iterable(field, op, value):
    # We need a special check for BaseDocument, because - although it's iterable - using
    # it as such in the context of this method is most definitely a mistake.
    BaseDocument = _import_class('BaseDocument')

    if isinstance(value, BaseDocument):
        raise TypeError("When using the `in`, `nin`, `all`, or "
                        "`near`-operators you can\'t use a "
                        "`Document`, you must wrap your object "
                        "in a list (object -> [object]).")

    if not hasattr(value, '__iter__'):
        raise TypeError("The `in`, `nin`, `all`, or "
                        "`near`-operators must be applied to an "
                        "iterable (e.g. a list).")

    return [field.prepare_query_value(op, v) for v in value]
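A sketch of the two paths guarded by _prepare_query_for_iterable (hypothetical model; transform.query builds the filter without a database round trip):

from mongoengine import Document, StringField
from mongoengine.queryset import transform

class Person(Document):
    name = StringField()

# A real iterable is prepared member by member...
assert transform.query(Person, name__in=['Ross', 'Demi']) == \
    {'name': {'$in': ['Ross', 'Demi']}}

# ...while a bare Document - iterable, but surely a mistake - is rejected.
try:
    transform.query(Person, name__in=Person(name='Ross'))
except TypeError as exc:
    print(exc)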
@@ -3,7 +3,7 @@ import copy
from mongoengine.errors import InvalidQueryError
from mongoengine.queryset import transform

__all__ = ('Q',)
__all__ = ('Q', 'QNode')


class QNodeVisitor(object):
@@ -131,6 +131,10 @@ class QCombination(QNode):
        else:
            self.children.append(node)

    def __repr__(self):
        op = ' & ' if self.operation is self.AND else ' | '
        return '(%s)' % op.join([repr(node) for node in self.children])

    def accept(self, visitor):
        for i in range(len(self.children)):
            if isinstance(self.children[i], QNode):
@@ -151,6 +155,9 @@ class Q(QNode):
    def __init__(self, **query):
        self.query = query

    def __repr__(self):
        return 'Q(**%s)' % repr(self.query)

    def accept(self, visitor):
        return visitor.visit_query(self)
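The added __repr__ methods make composed query nodes inspectable, as this small sketch shows (no document class or connection needed):

from mongoengine import Q

q = Q(name='Ross') & (Q(age__gte=18) | Q(admin=True))
print(repr(q))
# roughly: (Q(**{'name': 'Ross'}) & (Q(**{'age__gte': 18}) | Q(**{'admin': True})))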