Merge branch 'master' into pr/539
Conflicts: mongoengine/base/datastructures.py
This commit is contained in:
@@ -1,4 +1,6 @@
|
||||
import weakref
|
||||
import functools
|
||||
import itertools
|
||||
from mongoengine.common import _import_class
|
||||
|
||||
__all__ = ("BaseDict", "BaseList")
|
||||
@@ -183,3 +185,97 @@ class BaseList(list):
|
||||
self._instance._mark_as_changed('%s.%s' % (self._name, key))
|
||||
else:
|
||||
self._instance._mark_as_changed(self._name)
|
||||
|
||||
|
||||
class StrictDict(object):
|
||||
__slots__ = ()
|
||||
_special_fields = set(['get', 'pop', 'iteritems', 'items', 'keys', 'create'])
|
||||
_classes = {}
|
||||
def __init__(self, **kwargs):
|
||||
for k,v in kwargs.iteritems():
|
||||
setattr(self, k, v)
|
||||
def __getitem__(self, key):
|
||||
key = '_reserved_' + key if key in self._special_fields else key
|
||||
try:
|
||||
return getattr(self, key)
|
||||
except AttributeError:
|
||||
raise KeyError(key)
|
||||
def __setitem__(self, key, value):
|
||||
key = '_reserved_' + key if key in self._special_fields else key
|
||||
return setattr(self, key, value)
|
||||
def __contains__(self, key):
|
||||
return hasattr(self, key)
|
||||
def get(self, key, default=None):
|
||||
try:
|
||||
return self[key]
|
||||
except KeyError:
|
||||
return default
|
||||
def pop(self, key, default=None):
|
||||
v = self.get(key, default)
|
||||
try:
|
||||
delattr(self, key)
|
||||
except AttributeError:
|
||||
pass
|
||||
return v
|
||||
def iteritems(self):
|
||||
for key in self:
|
||||
yield key, self[key]
|
||||
def items(self):
|
||||
return [(k, self[k]) for k in iter(self)]
|
||||
def keys(self):
|
||||
return list(iter(self))
|
||||
def __iter__(self):
|
||||
return (key for key in self.__slots__ if hasattr(self, key))
|
||||
def __len__(self):
|
||||
return len(list(self.iteritems()))
|
||||
def __eq__(self, other):
|
||||
return self.items() == other.items()
|
||||
def __neq__(self, other):
|
||||
return self.items() != other.items()
|
||||
|
||||
@classmethod
|
||||
def create(cls, allowed_keys):
|
||||
allowed_keys_tuple = tuple(('_reserved_' + k if k in cls._special_fields else k) for k in allowed_keys)
|
||||
allowed_keys = frozenset(allowed_keys_tuple)
|
||||
if allowed_keys not in cls._classes:
|
||||
class SpecificStrictDict(cls):
|
||||
__slots__ = allowed_keys_tuple
|
||||
cls._classes[allowed_keys] = SpecificStrictDict
|
||||
return cls._classes[allowed_keys]
|
||||
|
||||
|
||||
class SemiStrictDict(StrictDict):
|
||||
__slots__ = ('_extras')
|
||||
_classes = {}
|
||||
def __getattr__(self, attr):
|
||||
try:
|
||||
super(SemiStrictDict, self).__getattr__(attr)
|
||||
except AttributeError:
|
||||
try:
|
||||
return self.__getattribute__('_extras')[attr]
|
||||
except KeyError as e:
|
||||
raise AttributeError(e)
|
||||
def __setattr__(self, attr, value):
|
||||
try:
|
||||
super(SemiStrictDict, self).__setattr__(attr, value)
|
||||
except AttributeError:
|
||||
try:
|
||||
self._extras[attr] = value
|
||||
except AttributeError:
|
||||
self._extras = {attr: value}
|
||||
|
||||
def __delattr__(self, attr):
|
||||
try:
|
||||
super(SemiStrictDict, self).__delattr__(attr)
|
||||
except AttributeError:
|
||||
try:
|
||||
del self._extras[attr]
|
||||
except KeyError as e:
|
||||
raise AttributeError(e)
|
||||
|
||||
def __iter__(self):
|
||||
try:
|
||||
extras_iter = iter(self.__getattribute__('_extras'))
|
||||
except AttributeError:
|
||||
extras_iter = ()
|
||||
return itertools.chain(super(SemiStrictDict, self).__iter__(), extras_iter)
|
||||
|
||||
@@ -13,24 +13,23 @@ from mongoengine import signals
|
||||
from mongoengine.common import _import_class
|
||||
from mongoengine.errors import (ValidationError, InvalidDocumentError,
|
||||
LookUpError)
|
||||
from mongoengine.python_support import (PY3, UNICODE_KWARGS, txt_type,
|
||||
to_str_keys_recursive)
|
||||
from mongoengine.python_support import PY3, txt_type
|
||||
|
||||
from mongoengine.base.common import get_document, ALLOW_INHERITANCE
|
||||
from mongoengine.base.datastructures import BaseDict, BaseList
|
||||
from mongoengine.base.datastructures import BaseDict, BaseList, StrictDict, SemiStrictDict
|
||||
from mongoengine.base.fields import ComplexBaseField
|
||||
|
||||
__all__ = ('BaseDocument', 'NON_FIELD_ERRORS')
|
||||
|
||||
NON_FIELD_ERRORS = '__all__'
|
||||
|
||||
|
||||
class BaseDocument(object):
|
||||
__slots__ = ('_changed_fields', '_initialised', '_created', '_data',
|
||||
'_dynamic_fields', '_auto_id_field', '_db_field_map', '_cls', '__weakref__')
|
||||
|
||||
_dynamic = False
|
||||
_created = True
|
||||
_dynamic_lock = True
|
||||
_initialised = False
|
||||
STRICT = False
|
||||
|
||||
def __init__(self, *args, **values):
|
||||
"""
|
||||
@@ -39,6 +38,8 @@ class BaseDocument(object):
|
||||
:param __auto_convert: Try and will cast python objects to Object types
|
||||
:param values: A dictionary of values for the document
|
||||
"""
|
||||
self._initialised = False
|
||||
self._created = True
|
||||
if args:
|
||||
# Combine positional arguments with named arguments.
|
||||
# We only want named arguments.
|
||||
@@ -54,7 +55,11 @@ class BaseDocument(object):
|
||||
__auto_convert = values.pop("__auto_convert", True)
|
||||
signals.pre_init.send(self.__class__, document=self, values=values)
|
||||
|
||||
self._data = {}
|
||||
if self.STRICT and not self._dynamic:
|
||||
self._data = StrictDict.create(allowed_keys=self._fields.keys())()
|
||||
else:
|
||||
self._data = SemiStrictDict.create(allowed_keys=self._fields.keys())()
|
||||
|
||||
self._dynamic_fields = SON()
|
||||
|
||||
# Assign default values to instance
|
||||
@@ -130,17 +135,25 @@ class BaseDocument(object):
|
||||
self._data[name] = value
|
||||
if hasattr(self, '_changed_fields'):
|
||||
self._mark_as_changed(name)
|
||||
try:
|
||||
self__created = self._created
|
||||
except AttributeError:
|
||||
self__created = True
|
||||
|
||||
if (self._is_document and not self._created and
|
||||
if (self._is_document and not self__created and
|
||||
name in self._meta.get('shard_key', tuple()) and
|
||||
self._data.get(name) != value):
|
||||
OperationError = _import_class('OperationError')
|
||||
msg = "Shard Keys are immutable. Tried to update %s" % name
|
||||
raise OperationError(msg)
|
||||
|
||||
try:
|
||||
self__initialised = self._initialised
|
||||
except AttributeError:
|
||||
self__initialised = False
|
||||
# Check if the user has created a new instance of a class
|
||||
if (self._is_document and self._initialised
|
||||
and self._created and name == self._meta['id_field']):
|
||||
if (self._is_document and self__initialised
|
||||
and self__created and name == self._meta['id_field']):
|
||||
super(BaseDocument, self).__setattr__('_created', False)
|
||||
|
||||
super(BaseDocument, self).__setattr__(name, value)
|
||||
@@ -158,9 +171,11 @@ class BaseDocument(object):
|
||||
if isinstance(data["_data"], SON):
|
||||
data["_data"] = self.__class__._from_son(data["_data"])._data
|
||||
for k in ('_changed_fields', '_initialised', '_created', '_data',
|
||||
'_fields_ordered', '_dynamic_fields'):
|
||||
'_dynamic_fields'):
|
||||
if k in data:
|
||||
setattr(self, k, data[k])
|
||||
if '_fields_ordered' in data:
|
||||
setattr(type(self), '_fields_ordered', data['_fields_ordered'])
|
||||
dynamic_fields = data.get('_dynamic_fields') or SON()
|
||||
for k in dynamic_fields.keys():
|
||||
setattr(self, k, data["_data"].get(k))
|
||||
@@ -182,7 +197,7 @@ class BaseDocument(object):
|
||||
"""Dictionary-style field access, set a field's value.
|
||||
"""
|
||||
# Ensure that the field exists before settings its value
|
||||
if name not in self._fields:
|
||||
if not self._dynamic and name not in self._fields:
|
||||
raise KeyError(name)
|
||||
return setattr(self, name, value)
|
||||
|
||||
@@ -214,8 +229,9 @@ class BaseDocument(object):
|
||||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, self.__class__) and hasattr(other, 'id'):
|
||||
if self.id == other.id:
|
||||
return True
|
||||
return self.id == other.id
|
||||
if isinstance(other, DBRef):
|
||||
return self._get_collection_name() == other.collection and self.id == other.id
|
||||
return False
|
||||
|
||||
def __ne__(self, other):
|
||||
@@ -317,7 +333,7 @@ class BaseDocument(object):
|
||||
pk = "None"
|
||||
if hasattr(self, 'pk'):
|
||||
pk = self.pk
|
||||
elif self._instance:
|
||||
elif self._instance and hasattr(self._instance, 'pk'):
|
||||
pk = self._instance.pk
|
||||
message = "ValidationError (%s:%s) " % (self._class_name, pk)
|
||||
raise ValidationError(message, errors=errors)
|
||||
@@ -401,6 +417,8 @@ class BaseDocument(object):
|
||||
else:
|
||||
data = getattr(data, part, None)
|
||||
if hasattr(data, "_changed_fields"):
|
||||
if hasattr(data, "_is_document") and data._is_document:
|
||||
continue
|
||||
data._changed_fields = []
|
||||
self._changed_fields = []
|
||||
|
||||
@@ -562,10 +580,6 @@ class BaseDocument(object):
|
||||
# class if unavailable
|
||||
class_name = son.get('_cls', cls._class_name)
|
||||
data = dict(("%s" % key, value) for key, value in son.iteritems())
|
||||
if not UNICODE_KWARGS:
|
||||
# python 2.6.4 and lower cannot handle unicode keys
|
||||
# passed to class constructor example: cls(**data)
|
||||
to_str_keys_recursive(data)
|
||||
|
||||
# Return correct subclass for document type
|
||||
if class_name != cls._class_name:
|
||||
@@ -603,6 +617,8 @@ class BaseDocument(object):
|
||||
% (cls._class_name, errors))
|
||||
raise InvalidDocumentError(msg)
|
||||
|
||||
if cls.STRICT:
|
||||
data = dict((k, v) for k,v in data.iteritems() if k in cls._fields)
|
||||
obj = cls(__auto_convert=False, **data)
|
||||
obj._changed_fields = changed_fields
|
||||
obj._created = False
|
||||
@@ -821,8 +837,17 @@ class BaseDocument(object):
|
||||
# Look up subfield on the previous field
|
||||
new_field = field.lookup_member(field_name)
|
||||
if not new_field and isinstance(field, ComplexBaseField):
|
||||
fields.append(field_name)
|
||||
continue
|
||||
if hasattr(field.field, 'document_type') and cls._dynamic \
|
||||
and field.field.document_type._dynamic:
|
||||
DynamicField = _import_class('DynamicField')
|
||||
new_field = DynamicField(db_field=field_name)
|
||||
else:
|
||||
fields.append(field_name)
|
||||
continue
|
||||
elif not new_field and hasattr(field, 'document_type') and cls._dynamic \
|
||||
and field.document_type._dynamic:
|
||||
DynamicField = _import_class('DynamicField')
|
||||
new_field = DynamicField(db_field=field_name)
|
||||
elif not new_field:
|
||||
raise LookUpError('Cannot resolve field "%s"'
|
||||
% field_name)
|
||||
@@ -842,7 +867,11 @@ class BaseDocument(object):
|
||||
"""Dynamically set the display value for a field with choices"""
|
||||
for attr_name, field in self._fields.items():
|
||||
if field.choices:
|
||||
setattr(self,
|
||||
if self._dynamic:
|
||||
obj = self
|
||||
else:
|
||||
obj = type(self)
|
||||
setattr(obj,
|
||||
'get_%s_display' % attr_name,
|
||||
partial(self.__get_field_display, field=field))
|
||||
|
||||
|
||||
@@ -359,7 +359,8 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
|
||||
new_class.id = field
|
||||
|
||||
# Set primary key if not defined by the document
|
||||
new_class._auto_id_field = False
|
||||
new_class._auto_id_field = getattr(parent_doc_cls,
|
||||
'_auto_id_field', False)
|
||||
if not new_class._meta.get('id_field'):
|
||||
new_class._auto_id_field = True
|
||||
new_class._meta['id_field'] = 'id'
|
||||
|
||||
@@ -20,7 +20,8 @@ _dbs = {}
|
||||
|
||||
def register_connection(alias, name, host=None, port=None,
|
||||
is_slave=False, read_preference=False, slaves=None,
|
||||
username=None, password=None, **kwargs):
|
||||
username=None, password=None, authentication_source=None,
|
||||
**kwargs):
|
||||
"""Add a connection.
|
||||
|
||||
:param alias: the name that will be used to refer to this connection
|
||||
@@ -36,6 +37,7 @@ def register_connection(alias, name, host=None, port=None,
|
||||
be a registered connection that has :attr:`is_slave` set to ``True``
|
||||
:param username: username to authenticate with
|
||||
:param password: password to authenticate with
|
||||
:param authentication_source: database to authenticate against
|
||||
:param kwargs: allow ad-hoc parameters to be passed into the pymongo driver
|
||||
|
||||
"""
|
||||
@@ -46,10 +48,11 @@ def register_connection(alias, name, host=None, port=None,
|
||||
'host': host or 'localhost',
|
||||
'port': port or 27017,
|
||||
'is_slave': is_slave,
|
||||
'read_preference': read_preference,
|
||||
'slaves': slaves or [],
|
||||
'username': username,
|
||||
'password': password,
|
||||
'read_preference': read_preference
|
||||
'authentication_source': authentication_source
|
||||
}
|
||||
|
||||
# Handle uri style connections
|
||||
@@ -93,20 +96,12 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
|
||||
raise ConnectionError(msg)
|
||||
conn_settings = _connection_settings[alias].copy()
|
||||
|
||||
if hasattr(pymongo, 'version_tuple'): # Support for 2.1+
|
||||
conn_settings.pop('name', None)
|
||||
conn_settings.pop('slaves', None)
|
||||
conn_settings.pop('is_slave', None)
|
||||
conn_settings.pop('username', None)
|
||||
conn_settings.pop('password', None)
|
||||
else:
|
||||
# Get all the slave connections
|
||||
if 'slaves' in conn_settings:
|
||||
slaves = []
|
||||
for slave_alias in conn_settings['slaves']:
|
||||
slaves.append(get_connection(slave_alias))
|
||||
conn_settings['slaves'] = slaves
|
||||
conn_settings.pop('read_preference', None)
|
||||
conn_settings.pop('name', None)
|
||||
conn_settings.pop('slaves', None)
|
||||
conn_settings.pop('is_slave', None)
|
||||
conn_settings.pop('username', None)
|
||||
conn_settings.pop('password', None)
|
||||
conn_settings.pop('authentication_source', None)
|
||||
|
||||
connection_class = MongoClient
|
||||
if 'replicaSet' in conn_settings:
|
||||
@@ -119,7 +114,19 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
|
||||
connection_class = MongoReplicaSetClient
|
||||
|
||||
try:
|
||||
_connections[alias] = connection_class(**conn_settings)
|
||||
connection = None
|
||||
connection_settings_iterator = ((alias, settings.copy()) for alias, settings in _connection_settings.iteritems())
|
||||
for alias, connection_settings in connection_settings_iterator:
|
||||
connection_settings.pop('name', None)
|
||||
connection_settings.pop('slaves', None)
|
||||
connection_settings.pop('is_slave', None)
|
||||
connection_settings.pop('username', None)
|
||||
connection_settings.pop('password', None)
|
||||
if conn_settings == connection_settings and _connections.get(alias, None):
|
||||
connection = _connections[alias]
|
||||
break
|
||||
|
||||
_connections[alias] = connection if connection else connection_class(**conn_settings)
|
||||
except Exception, e:
|
||||
raise ConnectionError("Cannot connect to database %s :\n%s" % (alias, e))
|
||||
return _connections[alias]
|
||||
@@ -137,7 +144,8 @@ def get_db(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
|
||||
# Authenticate if necessary
|
||||
if conn_settings['username'] and conn_settings['password']:
|
||||
db.authenticate(conn_settings['username'],
|
||||
conn_settings['password'])
|
||||
conn_settings['password'],
|
||||
source=conn_settings['authentication_source'])
|
||||
_dbs[alias] = db
|
||||
return _dbs[alias]
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from mongoengine.common import _import_class
|
||||
from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db
|
||||
from mongoengine.queryset import QuerySet
|
||||
|
||||
|
||||
__all__ = ("switch_db", "switch_collection", "no_dereference",
|
||||
@@ -162,12 +161,6 @@ class no_sub_classes(object):
|
||||
return self.cls
|
||||
|
||||
|
||||
class QuerySetNoDeRef(QuerySet):
|
||||
"""Special no_dereference QuerySet"""
|
||||
def __dereference(items, max_depth=1, instance=None, name=None):
|
||||
return items
|
||||
|
||||
|
||||
class query_counter(object):
|
||||
""" Query_counter context manager to get the number of queries. """
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@ class DeReference(object):
|
||||
if instance and isinstance(instance, (Document, EmbeddedDocument,
|
||||
TopLevelDocumentMetaclass)):
|
||||
doc_type = instance._fields.get(name)
|
||||
if hasattr(doc_type, 'field'):
|
||||
while hasattr(doc_type, 'field'):
|
||||
doc_type = doc_type.field
|
||||
|
||||
if isinstance(doc_type, ReferenceField):
|
||||
@@ -51,9 +51,19 @@ class DeReference(object):
|
||||
return items
|
||||
elif not field.dbref:
|
||||
if not hasattr(items, 'items'):
|
||||
items = [field.to_python(v)
|
||||
if not isinstance(v, (DBRef, Document)) else v
|
||||
for v in items]
|
||||
|
||||
def _get_items(items):
|
||||
new_items = []
|
||||
for v in items:
|
||||
if isinstance(v, list):
|
||||
new_items.append(_get_items(v))
|
||||
elif not isinstance(v, (DBRef, Document)):
|
||||
new_items.append(field.to_python(v))
|
||||
else:
|
||||
new_items.append(v)
|
||||
return new_items
|
||||
|
||||
items = _get_items(items)
|
||||
else:
|
||||
items = dict([
|
||||
(k, field.to_python(v))
|
||||
@@ -114,11 +124,11 @@ class DeReference(object):
|
||||
"""Fetch all references and convert to their document objects
|
||||
"""
|
||||
object_map = {}
|
||||
for col, dbrefs in self.reference_map.iteritems():
|
||||
for collection, dbrefs in self.reference_map.iteritems():
|
||||
keys = object_map.keys()
|
||||
refs = list(set([dbref for dbref in dbrefs if unicode(dbref).encode('utf-8') not in keys]))
|
||||
if hasattr(col, 'objects'): # We have a document class for the refs
|
||||
references = col.objects.in_bulk(refs)
|
||||
if hasattr(collection, 'objects'): # We have a document class for the refs
|
||||
references = collection.objects.in_bulk(refs)
|
||||
for key, doc in references.iteritems():
|
||||
object_map[key] = doc
|
||||
else: # Generic reference: use the refs data to convert to document
|
||||
@@ -126,19 +136,19 @@ class DeReference(object):
|
||||
continue
|
||||
|
||||
if doc_type:
|
||||
references = doc_type._get_db()[col].find({'_id': {'$in': refs}})
|
||||
references = doc_type._get_db()[collection].find({'_id': {'$in': refs}})
|
||||
for ref in references:
|
||||
doc = doc_type._from_son(ref)
|
||||
object_map[doc.id] = doc
|
||||
else:
|
||||
references = get_db()[col].find({'_id': {'$in': refs}})
|
||||
references = get_db()[collection].find({'_id': {'$in': refs}})
|
||||
for ref in references:
|
||||
if '_cls' in ref:
|
||||
doc = get_document(ref["_cls"])._from_son(ref)
|
||||
elif doc_type is None:
|
||||
doc = get_document(
|
||||
''.join(x.capitalize()
|
||||
for x in col.split('_')))._from_son(ref)
|
||||
for x in collection.split('_')))._from_son(ref)
|
||||
else:
|
||||
doc = doc_type._from_son(ref)
|
||||
object_map[doc.id] = doc
|
||||
|
||||
@@ -13,7 +13,8 @@ from mongoengine.base import (DocumentMetaclass, TopLevelDocumentMetaclass,
|
||||
BaseDocument, BaseDict, BaseList,
|
||||
ALLOW_INHERITANCE, get_document)
|
||||
from mongoengine.errors import ValidationError
|
||||
from mongoengine.queryset import OperationError, NotUniqueError, QuerySet
|
||||
from mongoengine.queryset import (OperationError, NotUniqueError,
|
||||
QuerySet, transform)
|
||||
from mongoengine.connection import get_db, DEFAULT_CONNECTION_NAME
|
||||
from mongoengine.context_managers import switch_db, switch_collection
|
||||
|
||||
@@ -54,20 +55,21 @@ class EmbeddedDocument(BaseDocument):
|
||||
dictionary.
|
||||
"""
|
||||
|
||||
__slots__ = ('_instance')
|
||||
|
||||
# The __metaclass__ attribute is removed by 2to3 when running with Python3
|
||||
# my_metaclass is defined so that metaclass can be queried in Python 2 & 3
|
||||
my_metaclass = DocumentMetaclass
|
||||
__metaclass__ = DocumentMetaclass
|
||||
|
||||
_instance = None
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(EmbeddedDocument, self).__init__(*args, **kwargs)
|
||||
self._instance = None
|
||||
self._changed_fields = []
|
||||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, self.__class__):
|
||||
return self.to_mongo() == other.to_mongo()
|
||||
return self._data == other._data
|
||||
return False
|
||||
|
||||
def __ne__(self, other):
|
||||
@@ -125,6 +127,8 @@ class Document(BaseDocument):
|
||||
my_metaclass = TopLevelDocumentMetaclass
|
||||
__metaclass__ = TopLevelDocumentMetaclass
|
||||
|
||||
__slots__ = ('__objects' )
|
||||
|
||||
def pk():
|
||||
"""Primary key alias
|
||||
"""
|
||||
@@ -180,7 +184,7 @@ class Document(BaseDocument):
|
||||
|
||||
def save(self, force_insert=False, validate=True, clean=True,
|
||||
write_concern=None, cascade=None, cascade_kwargs=None,
|
||||
_refs=None, **kwargs):
|
||||
_refs=None, save_condition=None, **kwargs):
|
||||
"""Save the :class:`~mongoengine.Document` to the database. If the
|
||||
document already exists, it will be updated, otherwise it will be
|
||||
created.
|
||||
@@ -203,7 +207,8 @@ class Document(BaseDocument):
|
||||
:param cascade_kwargs: (optional) kwargs dictionary to be passed throw
|
||||
to cascading saves. Implies ``cascade=True``.
|
||||
:param _refs: A list of processed references used in cascading saves
|
||||
|
||||
:param save_condition: only perform save if matching record in db
|
||||
satisfies condition(s) (e.g., version number)
|
||||
.. versionchanged:: 0.5
|
||||
In existing documents it only saves changed fields using
|
||||
set / unset. Saves are cascaded and any
|
||||
@@ -217,6 +222,9 @@ class Document(BaseDocument):
|
||||
meta['cascade'] = True. Also you can pass different kwargs to
|
||||
the cascade save using cascade_kwargs which overwrites the
|
||||
existing kwargs with custom values.
|
||||
.. versionchanged:: 0.8.5
|
||||
Optional save_condition that only overwrites existing documents
|
||||
if the condition is satisfied in the current db record.
|
||||
"""
|
||||
signals.pre_save.send(self.__class__, document=self)
|
||||
|
||||
@@ -230,7 +238,8 @@ class Document(BaseDocument):
|
||||
|
||||
created = ('_id' not in doc or self._created or force_insert)
|
||||
|
||||
signals.pre_save_post_validation.send(self.__class__, document=self, created=created)
|
||||
signals.pre_save_post_validation.send(self.__class__, document=self,
|
||||
created=created)
|
||||
|
||||
try:
|
||||
collection = self._get_collection()
|
||||
@@ -243,7 +252,12 @@ class Document(BaseDocument):
|
||||
object_id = doc['_id']
|
||||
updates, removals = self._delta()
|
||||
# Need to add shard key to query, or you get an error
|
||||
select_dict = {'_id': object_id}
|
||||
if save_condition is not None:
|
||||
select_dict = transform.query(self.__class__,
|
||||
**save_condition)
|
||||
else:
|
||||
select_dict = {}
|
||||
select_dict['_id'] = object_id
|
||||
shard_key = self.__class__._meta.get('shard_key', tuple())
|
||||
for k in shard_key:
|
||||
actual_key = self._db_field_map.get(k, k)
|
||||
@@ -263,10 +277,12 @@ class Document(BaseDocument):
|
||||
if removals:
|
||||
update_query["$unset"] = removals
|
||||
if updates or removals:
|
||||
upsert = save_condition is None
|
||||
last_error = collection.update(select_dict, update_query,
|
||||
upsert=True, **write_concern)
|
||||
upsert=upsert, **write_concern)
|
||||
created = is_new_object(last_error)
|
||||
|
||||
|
||||
if cascade is None:
|
||||
cascade = self._meta.get('cascade', False) or cascade_kwargs is not None
|
||||
|
||||
@@ -293,12 +309,12 @@ class Document(BaseDocument):
|
||||
raise NotUniqueError(message % unicode(err))
|
||||
raise OperationError(message % unicode(err))
|
||||
id_field = self._meta['id_field']
|
||||
if id_field not in self._meta.get('shard_key', []):
|
||||
if created or id_field not in self._meta.get('shard_key', []):
|
||||
self[id_field] = self._fields[id_field].to_python(object_id)
|
||||
|
||||
signals.post_save.send(self.__class__, document=self, created=created)
|
||||
self._clear_changed_fields()
|
||||
self._created = False
|
||||
signals.post_save.send(self.__class__, document=self, created=created)
|
||||
return self
|
||||
|
||||
def cascade_save(self, *args, **kwargs):
|
||||
@@ -447,27 +463,41 @@ class Document(BaseDocument):
|
||||
DeReference()([self], max_depth + 1)
|
||||
return self
|
||||
|
||||
def reload(self, max_depth=1):
|
||||
def reload(self, *fields, **kwargs):
|
||||
"""Reloads all attributes from the database.
|
||||
|
||||
:param fields: (optional) args list of fields to reload
|
||||
:param max_depth: (optional) depth of dereferencing to follow
|
||||
|
||||
.. versionadded:: 0.1.2
|
||||
.. versionchanged:: 0.6 Now chainable
|
||||
.. versionchanged:: 0.9 Can provide specific fields to reload
|
||||
"""
|
||||
max_depth = 1
|
||||
if fields and isinstance(fields[0], int):
|
||||
max_depth = fields[0]
|
||||
fields = fields[1:]
|
||||
elif "max_depth" in kwargs:
|
||||
max_depth = kwargs["max_depth"]
|
||||
|
||||
if not self.pk:
|
||||
raise self.DoesNotExist("Document does not exist")
|
||||
obj = self._qs.read_preference(ReadPreference.PRIMARY).filter(
|
||||
**self._object_key).limit(1).select_related(max_depth=max_depth)
|
||||
|
||||
**self._object_key).only(*fields).limit(1
|
||||
).select_related(max_depth=max_depth)
|
||||
|
||||
if obj:
|
||||
obj = obj[0]
|
||||
else:
|
||||
raise self.DoesNotExist("Document does not exist")
|
||||
|
||||
for field in self._fields_ordered:
|
||||
setattr(self, field, self._reload(field, obj[field]))
|
||||
if not fields or field in fields:
|
||||
setattr(self, field, self._reload(field, obj[field]))
|
||||
|
||||
self._changed_fields = obj._changed_fields
|
||||
self._created = False
|
||||
return obj
|
||||
return self
|
||||
|
||||
def _reload(self, key, value):
|
||||
"""Used by :meth:`~mongoengine.Document.reload` to ensure the
|
||||
|
||||
@@ -760,7 +760,7 @@ class DictField(ComplexBaseField):
|
||||
similar to an embedded document, but the structure is not defined.
|
||||
|
||||
.. note::
|
||||
Required means it cannot be empty - as the default for ListFields is []
|
||||
Required means it cannot be empty - as the default for DictFields is {}
|
||||
|
||||
.. versionadded:: 0.3
|
||||
.. versionchanged:: 0.5 - Can now handle complex / varying types of data
|
||||
@@ -1554,6 +1554,14 @@ class SequenceField(BaseField):
|
||||
|
||||
return super(SequenceField, self).__set__(instance, value)
|
||||
|
||||
def prepare_query_value(self, op, value):
|
||||
"""
|
||||
This method is overriden in order to convert the query value into to required
|
||||
type. We need to do this in order to be able to successfully compare query
|
||||
values passed as string, the base implementation returns the value as is.
|
||||
"""
|
||||
return self.value_decorator(value)
|
||||
|
||||
def to_python(self, value):
|
||||
if value is None:
|
||||
value = self.generate()
|
||||
@@ -1613,7 +1621,12 @@ class UUIDField(BaseField):
|
||||
|
||||
|
||||
class GeoPointField(BaseField):
|
||||
"""A list storing a latitude and longitude.
|
||||
"""A list storing a longitude and latitude coordinate.
|
||||
|
||||
.. note:: this represents a generic point in a 2D plane and a legacy way of
|
||||
representing a geo point. It admits 2d indexes but not "2dsphere" indexes
|
||||
in MongoDB > 2.4 which are more natural for modeling geospatial points.
|
||||
See :ref:`geospatial-indexes`
|
||||
|
||||
.. versionadded:: 0.4
|
||||
"""
|
||||
@@ -1635,7 +1648,7 @@ class GeoPointField(BaseField):
|
||||
|
||||
|
||||
class PointField(GeoJsonBaseField):
|
||||
"""A geo json field storing a latitude and longitude.
|
||||
"""A GeoJSON field storing a longitude and latitude coordinate.
|
||||
|
||||
The data is represented as:
|
||||
|
||||
@@ -1654,7 +1667,7 @@ class PointField(GeoJsonBaseField):
|
||||
|
||||
|
||||
class LineStringField(GeoJsonBaseField):
|
||||
"""A geo json field storing a line of latitude and longitude coordinates.
|
||||
"""A GeoJSON field storing a line of longitude and latitude coordinates.
|
||||
|
||||
The data is represented as:
|
||||
|
||||
@@ -1672,7 +1685,7 @@ class LineStringField(GeoJsonBaseField):
|
||||
|
||||
|
||||
class PolygonField(GeoJsonBaseField):
|
||||
"""A geo json field storing a polygon of latitude and longitude coordinates.
|
||||
"""A GeoJSON field storing a polygon of longitude and latitude coordinates.
|
||||
|
||||
The data is represented as:
|
||||
|
||||
|
||||
@@ -3,8 +3,6 @@
|
||||
import sys
|
||||
|
||||
PY3 = sys.version_info[0] == 3
|
||||
PY25 = sys.version_info[:2] == (2, 5)
|
||||
UNICODE_KWARGS = int(''.join([str(x) for x in sys.version_info[:3]])) > 264
|
||||
|
||||
if PY3:
|
||||
import codecs
|
||||
@@ -29,33 +27,3 @@ else:
|
||||
txt_type = unicode
|
||||
|
||||
str_types = (bin_type, txt_type)
|
||||
|
||||
if PY25:
|
||||
def product(*args, **kwds):
|
||||
pools = map(tuple, args) * kwds.get('repeat', 1)
|
||||
result = [[]]
|
||||
for pool in pools:
|
||||
result = [x + [y] for x in result for y in pool]
|
||||
for prod in result:
|
||||
yield tuple(prod)
|
||||
reduce = reduce
|
||||
else:
|
||||
from itertools import product
|
||||
from functools import reduce
|
||||
|
||||
|
||||
# For use with Python 2.5
|
||||
# converts all keys from unicode to str for d and all nested dictionaries
|
||||
def to_str_keys_recursive(d):
|
||||
if isinstance(d, list):
|
||||
for val in d:
|
||||
if isinstance(val, (dict, list)):
|
||||
to_str_keys_recursive(val)
|
||||
elif isinstance(d, dict):
|
||||
for key, val in d.items():
|
||||
if isinstance(val, (dict, list)):
|
||||
to_str_keys_recursive(val)
|
||||
if isinstance(key, unicode):
|
||||
d[str(key)] = d.pop(key)
|
||||
else:
|
||||
raise ValueError("non list/dict parameter not allowed")
|
||||
|
||||
@@ -7,17 +7,20 @@ import pprint
|
||||
import re
|
||||
import warnings
|
||||
|
||||
from bson import SON
|
||||
from bson.code import Code
|
||||
from bson import json_util
|
||||
import pymongo
|
||||
import pymongo.errors
|
||||
from pymongo.common import validate_read_preference
|
||||
|
||||
from mongoengine import signals
|
||||
from mongoengine.connection import get_db
|
||||
from mongoengine.context_managers import switch_db
|
||||
from mongoengine.common import _import_class
|
||||
from mongoengine.base.common import get_document
|
||||
from mongoengine.errors import (OperationError, NotUniqueError,
|
||||
InvalidQueryError, LookUpError)
|
||||
|
||||
from mongoengine.queryset import transform
|
||||
from mongoengine.queryset.field_list import QueryFieldList
|
||||
from mongoengine.queryset.visitor import Q, QNode
|
||||
@@ -50,7 +53,7 @@ class BaseQuerySet(object):
|
||||
self._initial_query = {}
|
||||
self._where_clause = None
|
||||
self._loaded_fields = QueryFieldList()
|
||||
self._ordering = []
|
||||
self._ordering = None
|
||||
self._snapshot = False
|
||||
self._timeout = True
|
||||
self._class_check = True
|
||||
@@ -146,7 +149,7 @@ class BaseQuerySet(object):
|
||||
queryset._document._from_son(queryset._cursor[key],
|
||||
_auto_dereference=self._auto_dereference))
|
||||
if queryset._as_pymongo:
|
||||
return queryset._get_as_pymongo(queryset._cursor.next())
|
||||
return queryset._get_as_pymongo(queryset._cursor[key])
|
||||
return queryset._document._from_son(queryset._cursor[key],
|
||||
_auto_dereference=self._auto_dereference)
|
||||
raise AttributeError
|
||||
@@ -154,6 +157,22 @@ class BaseQuerySet(object):
|
||||
def __iter__(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def _has_data(self):
|
||||
""" Retrieves whether cursor has any data. """
|
||||
|
||||
queryset = self.order_by()
|
||||
return False if queryset.first() is None else True
|
||||
|
||||
def __nonzero__(self):
|
||||
""" Avoid to open all records in an if stmt in Py2. """
|
||||
|
||||
return self._has_data()
|
||||
|
||||
def __bool__(self):
|
||||
""" Avoid to open all records in an if stmt in Py3. """
|
||||
|
||||
return self._has_data()
|
||||
|
||||
# Core functions
|
||||
|
||||
def all(self):
|
||||
@@ -175,7 +194,7 @@ class BaseQuerySet(object):
|
||||
.. versionadded:: 0.3
|
||||
"""
|
||||
queryset = self.clone()
|
||||
queryset = queryset.limit(2)
|
||||
queryset = queryset.order_by().limit(2)
|
||||
queryset = queryset.filter(*q_objs, **query)
|
||||
|
||||
try:
|
||||
@@ -389,7 +408,7 @@ class BaseQuerySet(object):
|
||||
ref_q = document_cls.objects(**{field_name + '__in': self})
|
||||
ref_q_count = ref_q.count()
|
||||
if (doc != document_cls and ref_q_count > 0
|
||||
or (doc == document_cls and ref_q_count > 0)):
|
||||
or (doc == document_cls and ref_q_count > 0)):
|
||||
ref_q.delete(write_concern=write_concern)
|
||||
elif rule == NULLIFY:
|
||||
document_cls.objects(**{field_name + '__in': self}).update(
|
||||
@@ -443,6 +462,8 @@ class BaseQuerySet(object):
|
||||
return result
|
||||
elif result:
|
||||
return result['n']
|
||||
except pymongo.errors.DuplicateKeyError, err:
|
||||
raise NotUniqueError(u'Update failed (%s)' % unicode(err))
|
||||
except pymongo.errors.OperationFailure, err:
|
||||
if unicode(err) == u'multi not coded yet':
|
||||
message = u'update() method requires MongoDB 1.1.3+'
|
||||
@@ -466,6 +487,59 @@ class BaseQuerySet(object):
|
||||
return self.update(
|
||||
upsert=upsert, multi=False, write_concern=write_concern, **update)
|
||||
|
||||
def modify(self, upsert=False, full_response=False, remove=False, new=False, **update):
|
||||
"""Update and return the updated document.
|
||||
|
||||
Returns either the document before or after modification based on `new`
|
||||
parameter. If no documents match the query and `upsert` is false,
|
||||
returns ``None``. If upserting and `new` is false, returns ``None``.
|
||||
|
||||
If the full_response parameter is ``True``, the return value will be
|
||||
the entire response object from the server, including the 'ok' and
|
||||
'lastErrorObject' fields, rather than just the modified document.
|
||||
This is useful mainly because the 'lastErrorObject' document holds
|
||||
information about the command's execution.
|
||||
|
||||
:param upsert: insert if document doesn't exist (default ``False``)
|
||||
:param full_response: return the entire response object from the
|
||||
server (default ``False``)
|
||||
:param remove: remove rather than updating (default ``False``)
|
||||
:param new: return updated rather than original document
|
||||
(default ``False``)
|
||||
:param update: Django-style update keyword arguments
|
||||
|
||||
.. versionadded:: 0.9
|
||||
"""
|
||||
|
||||
if remove and new:
|
||||
raise OperationError("Conflicting parameters: remove and new")
|
||||
|
||||
if not update and not upsert and not remove:
|
||||
raise OperationError("No update parameters, must either update or remove")
|
||||
|
||||
queryset = self.clone()
|
||||
query = queryset._query
|
||||
update = transform.update(queryset._document, **update)
|
||||
sort = queryset._ordering
|
||||
|
||||
try:
|
||||
result = queryset._collection.find_and_modify(
|
||||
query, update, upsert=upsert, sort=sort, remove=remove, new=new,
|
||||
full_response=full_response, **self._cursor_args)
|
||||
except pymongo.errors.DuplicateKeyError, err:
|
||||
raise NotUniqueError(u"Update failed (%s)" % err)
|
||||
except pymongo.errors.OperationFailure, err:
|
||||
raise OperationError(u"Update failed (%s)" % err)
|
||||
|
||||
if full_response:
|
||||
if result["value"] is not None:
|
||||
result["value"] = self._document._from_son(result["value"])
|
||||
else:
|
||||
if result is not None:
|
||||
result = self._document._from_son(result)
|
||||
|
||||
return result
|
||||
|
||||
def with_id(self, object_id):
|
||||
"""Retrieve the object matching the id provided. Uses `object_id` only
|
||||
and raises InvalidQueryError if a filter has been applied. Returns
|
||||
@@ -522,6 +596,19 @@ class BaseQuerySet(object):
|
||||
|
||||
return self
|
||||
|
||||
def using(self, alias):
|
||||
"""This method is for controlling which database the QuerySet will be evaluated against if you are using more than one database.
|
||||
|
||||
:param alias: The database alias
|
||||
|
||||
.. versionadded:: 0.8
|
||||
"""
|
||||
|
||||
with switch_db(self._document, alias) as cls:
|
||||
collection = cls._get_collection()
|
||||
|
||||
return self.clone_into(self.__class__(self._document, collection))
|
||||
|
||||
def clone(self):
|
||||
"""Creates a copy of the current
|
||||
:class:`~mongoengine.queryset.QuerySet`
|
||||
@@ -923,10 +1010,39 @@ class BaseQuerySet(object):
|
||||
map_reduce_function = 'inline_map_reduce'
|
||||
else:
|
||||
map_reduce_function = 'map_reduce'
|
||||
mr_args['out'] = output
|
||||
|
||||
if isinstance(output, basestring):
|
||||
mr_args['out'] = output
|
||||
|
||||
elif isinstance(output, dict):
|
||||
ordered_output = []
|
||||
|
||||
for part in ('replace', 'merge', 'reduce'):
|
||||
value = output.get(part)
|
||||
if value:
|
||||
ordered_output.append((part, value))
|
||||
break
|
||||
|
||||
else:
|
||||
raise OperationError("actionData not specified for output")
|
||||
|
||||
db_alias = output.get('db_alias')
|
||||
remaing_args = ['db', 'sharded', 'nonAtomic']
|
||||
|
||||
if db_alias:
|
||||
ordered_output.append(('db', get_db(db_alias).name))
|
||||
del remaing_args[0]
|
||||
|
||||
|
||||
for part in remaing_args:
|
||||
value = output.get(part)
|
||||
if value:
|
||||
ordered_output.append((part, value))
|
||||
|
||||
mr_args['out'] = SON(ordered_output)
|
||||
|
||||
results = getattr(queryset._collection, map_reduce_function)(
|
||||
map_f, reduce_f, **mr_args)
|
||||
map_f, reduce_f, **mr_args)
|
||||
|
||||
if map_reduce_function == 'map_reduce':
|
||||
results = results.find()
|
||||
@@ -1189,8 +1305,9 @@ class BaseQuerySet(object):
|
||||
if self._ordering:
|
||||
# Apply query ordering
|
||||
self._cursor_obj.sort(self._ordering)
|
||||
elif self._document._meta['ordering']:
|
||||
# Otherwise, apply the ordering from the document model
|
||||
elif self._ordering is None and self._document._meta['ordering']:
|
||||
# Otherwise, apply the ordering from the document model, unless
|
||||
# it's been explicitly cleared via order_by with no arguments
|
||||
order = self._get_order_by(self._document._meta['ordering'])
|
||||
self._cursor_obj.sort(order)
|
||||
|
||||
@@ -1362,7 +1479,7 @@ class BaseQuerySet(object):
|
||||
for subdoc in subclasses:
|
||||
try:
|
||||
subfield = ".".join(f.db_field for f in
|
||||
subdoc._lookup_field(field.split('.')))
|
||||
subdoc._lookup_field(field.split('.')))
|
||||
ret.append(subfield)
|
||||
found = True
|
||||
break
|
||||
@@ -1392,7 +1509,7 @@ class BaseQuerySet(object):
|
||||
pass
|
||||
key_list.append((key, direction))
|
||||
|
||||
if self._cursor_obj:
|
||||
if self._cursor_obj and key_list:
|
||||
self._cursor_obj.sort(key_list)
|
||||
return key_list
|
||||
|
||||
@@ -1450,6 +1567,7 @@ class BaseQuerySet(object):
|
||||
# type of this field and use the corresponding
|
||||
# .to_python(...)
|
||||
from mongoengine.fields import EmbeddedDocumentField
|
||||
|
||||
obj = self._document
|
||||
for chunk in path.split('.'):
|
||||
obj = getattr(obj, chunk, None)
|
||||
@@ -1460,6 +1578,7 @@ class BaseQuerySet(object):
|
||||
if obj and data is not None:
|
||||
data = obj.to_python(data)
|
||||
return data
|
||||
|
||||
return clean(row)
|
||||
|
||||
def _sub_js_fields(self, code):
|
||||
@@ -1468,6 +1587,7 @@ class BaseQuerySet(object):
|
||||
substituted for the MongoDB name of the field (specified using the
|
||||
:attr:`name` keyword argument in a field's constructor).
|
||||
"""
|
||||
|
||||
def field_sub(match):
|
||||
# Extract just the field name, and look up the field objects
|
||||
field_name = match.group(1).split('.')
|
||||
|
||||
@@ -155,3 +155,10 @@ class QuerySetNoCache(BaseQuerySet):
|
||||
queryset = self.clone()
|
||||
queryset.rewind()
|
||||
return queryset
|
||||
|
||||
|
||||
class QuerySetNoDeRef(QuerySet):
|
||||
"""Special no_dereference QuerySet"""
|
||||
|
||||
def __dereference(items, max_depth=1, instance=None, name=None):
|
||||
return items
|
||||
@@ -3,6 +3,7 @@ from collections import defaultdict
|
||||
import pymongo
|
||||
from bson import SON
|
||||
|
||||
from mongoengine.connection import get_connection
|
||||
from mongoengine.common import _import_class
|
||||
from mongoengine.errors import InvalidQueryError, LookUpError
|
||||
|
||||
@@ -38,7 +39,7 @@ def query(_doc_cls=None, _field_operation=False, **query):
|
||||
mongo_query.update(value)
|
||||
continue
|
||||
|
||||
parts = key.split('__')
|
||||
parts = key.rsplit('__')
|
||||
indices = [(i, p) for i, p in enumerate(parts) if p.isdigit()]
|
||||
parts = [part for part in parts if not part.isdigit()]
|
||||
# Check for an operator and transform to mongo-style if there is
|
||||
@@ -115,14 +116,26 @@ def query(_doc_cls=None, _field_operation=False, **query):
|
||||
if key in mongo_query and isinstance(mongo_query[key], dict):
|
||||
mongo_query[key].update(value)
|
||||
# $maxDistance needs to come last - convert to SON
|
||||
if '$maxDistance' in mongo_query[key]:
|
||||
value_dict = mongo_query[key]
|
||||
value_dict = mongo_query[key]
|
||||
if ('$maxDistance' in value_dict and '$near' in value_dict):
|
||||
value_son = SON()
|
||||
for k, v in value_dict.iteritems():
|
||||
if k == '$maxDistance':
|
||||
continue
|
||||
value_son[k] = v
|
||||
value_son['$maxDistance'] = value_dict['$maxDistance']
|
||||
if isinstance(value_dict['$near'], dict):
|
||||
for k, v in value_dict.iteritems():
|
||||
if k == '$maxDistance':
|
||||
continue
|
||||
value_son[k] = v
|
||||
if (get_connection().max_wire_version <= 1):
|
||||
value_son['$maxDistance'] = value_dict['$maxDistance']
|
||||
else:
|
||||
value_son['$near'] = SON(value_son['$near'])
|
||||
value_son['$near']['$maxDistance'] = value_dict['$maxDistance']
|
||||
else:
|
||||
for k, v in value_dict.iteritems():
|
||||
if k == '$maxDistance':
|
||||
continue
|
||||
value_son[k] = v
|
||||
value_son['$maxDistance'] = value_dict['$maxDistance']
|
||||
|
||||
mongo_query[key] = value_son
|
||||
else:
|
||||
# Store for manually merging later
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import copy
|
||||
|
||||
from mongoengine.errors import InvalidQueryError
|
||||
from mongoengine.python_support import product, reduce
|
||||
from itertools import product
|
||||
from functools import reduce
|
||||
|
||||
from mongoengine.errors import InvalidQueryError
|
||||
from mongoengine.queryset import transform
|
||||
|
||||
__all__ = ('Q',)
|
||||
|
||||
Reference in New Issue
Block a user