Merge branch 'master' into pr/539

Conflicts:
	mongoengine/base/datastructures.py
This commit is contained in:
Ross Lawley
2014-06-27 12:20:44 +01:00
34 changed files with 1404 additions and 208 deletions

View File

@@ -1,4 +1,6 @@
import weakref
import functools
import itertools
from mongoengine.common import _import_class
__all__ = ("BaseDict", "BaseList")
@@ -183,3 +185,97 @@ class BaseList(list):
self._instance._mark_as_changed('%s.%s' % (self._name, key))
else:
self._instance._mark_as_changed(self._name)
class StrictDict(object):
__slots__ = ()
_special_fields = set(['get', 'pop', 'iteritems', 'items', 'keys', 'create'])
_classes = {}
def __init__(self, **kwargs):
for k,v in kwargs.iteritems():
setattr(self, k, v)
def __getitem__(self, key):
key = '_reserved_' + key if key in self._special_fields else key
try:
return getattr(self, key)
except AttributeError:
raise KeyError(key)
def __setitem__(self, key, value):
key = '_reserved_' + key if key in self._special_fields else key
return setattr(self, key, value)
def __contains__(self, key):
return hasattr(self, key)
def get(self, key, default=None):
try:
return self[key]
except KeyError:
return default
def pop(self, key, default=None):
v = self.get(key, default)
try:
delattr(self, key)
except AttributeError:
pass
return v
def iteritems(self):
for key in self:
yield key, self[key]
def items(self):
return [(k, self[k]) for k in iter(self)]
def keys(self):
return list(iter(self))
def __iter__(self):
return (key for key in self.__slots__ if hasattr(self, key))
def __len__(self):
return len(list(self.iteritems()))
def __eq__(self, other):
return self.items() == other.items()
def __neq__(self, other):
return self.items() != other.items()
@classmethod
def create(cls, allowed_keys):
allowed_keys_tuple = tuple(('_reserved_' + k if k in cls._special_fields else k) for k in allowed_keys)
allowed_keys = frozenset(allowed_keys_tuple)
if allowed_keys not in cls._classes:
class SpecificStrictDict(cls):
__slots__ = allowed_keys_tuple
cls._classes[allowed_keys] = SpecificStrictDict
return cls._classes[allowed_keys]
class SemiStrictDict(StrictDict):
__slots__ = ('_extras')
_classes = {}
def __getattr__(self, attr):
try:
super(SemiStrictDict, self).__getattr__(attr)
except AttributeError:
try:
return self.__getattribute__('_extras')[attr]
except KeyError as e:
raise AttributeError(e)
def __setattr__(self, attr, value):
try:
super(SemiStrictDict, self).__setattr__(attr, value)
except AttributeError:
try:
self._extras[attr] = value
except AttributeError:
self._extras = {attr: value}
def __delattr__(self, attr):
try:
super(SemiStrictDict, self).__delattr__(attr)
except AttributeError:
try:
del self._extras[attr]
except KeyError as e:
raise AttributeError(e)
def __iter__(self):
try:
extras_iter = iter(self.__getattribute__('_extras'))
except AttributeError:
extras_iter = ()
return itertools.chain(super(SemiStrictDict, self).__iter__(), extras_iter)

View File

@@ -13,24 +13,23 @@ from mongoengine import signals
from mongoengine.common import _import_class
from mongoengine.errors import (ValidationError, InvalidDocumentError,
LookUpError)
from mongoengine.python_support import (PY3, UNICODE_KWARGS, txt_type,
to_str_keys_recursive)
from mongoengine.python_support import PY3, txt_type
from mongoengine.base.common import get_document, ALLOW_INHERITANCE
from mongoengine.base.datastructures import BaseDict, BaseList
from mongoengine.base.datastructures import BaseDict, BaseList, StrictDict, SemiStrictDict
from mongoengine.base.fields import ComplexBaseField
__all__ = ('BaseDocument', 'NON_FIELD_ERRORS')
NON_FIELD_ERRORS = '__all__'
class BaseDocument(object):
__slots__ = ('_changed_fields', '_initialised', '_created', '_data',
'_dynamic_fields', '_auto_id_field', '_db_field_map', '_cls', '__weakref__')
_dynamic = False
_created = True
_dynamic_lock = True
_initialised = False
STRICT = False
def __init__(self, *args, **values):
"""
@@ -39,6 +38,8 @@ class BaseDocument(object):
:param __auto_convert: Try and will cast python objects to Object types
:param values: A dictionary of values for the document
"""
self._initialised = False
self._created = True
if args:
# Combine positional arguments with named arguments.
# We only want named arguments.
@@ -54,7 +55,11 @@ class BaseDocument(object):
__auto_convert = values.pop("__auto_convert", True)
signals.pre_init.send(self.__class__, document=self, values=values)
self._data = {}
if self.STRICT and not self._dynamic:
self._data = StrictDict.create(allowed_keys=self._fields.keys())()
else:
self._data = SemiStrictDict.create(allowed_keys=self._fields.keys())()
self._dynamic_fields = SON()
# Assign default values to instance
@@ -130,17 +135,25 @@ class BaseDocument(object):
self._data[name] = value
if hasattr(self, '_changed_fields'):
self._mark_as_changed(name)
try:
self__created = self._created
except AttributeError:
self__created = True
if (self._is_document and not self._created and
if (self._is_document and not self__created and
name in self._meta.get('shard_key', tuple()) and
self._data.get(name) != value):
OperationError = _import_class('OperationError')
msg = "Shard Keys are immutable. Tried to update %s" % name
raise OperationError(msg)
try:
self__initialised = self._initialised
except AttributeError:
self__initialised = False
# Check if the user has created a new instance of a class
if (self._is_document and self._initialised
and self._created and name == self._meta['id_field']):
if (self._is_document and self__initialised
and self__created and name == self._meta['id_field']):
super(BaseDocument, self).__setattr__('_created', False)
super(BaseDocument, self).__setattr__(name, value)
@@ -158,9 +171,11 @@ class BaseDocument(object):
if isinstance(data["_data"], SON):
data["_data"] = self.__class__._from_son(data["_data"])._data
for k in ('_changed_fields', '_initialised', '_created', '_data',
'_fields_ordered', '_dynamic_fields'):
'_dynamic_fields'):
if k in data:
setattr(self, k, data[k])
if '_fields_ordered' in data:
setattr(type(self), '_fields_ordered', data['_fields_ordered'])
dynamic_fields = data.get('_dynamic_fields') or SON()
for k in dynamic_fields.keys():
setattr(self, k, data["_data"].get(k))
@@ -182,7 +197,7 @@ class BaseDocument(object):
"""Dictionary-style field access, set a field's value.
"""
# Ensure that the field exists before settings its value
if name not in self._fields:
if not self._dynamic and name not in self._fields:
raise KeyError(name)
return setattr(self, name, value)
@@ -214,8 +229,9 @@ class BaseDocument(object):
def __eq__(self, other):
if isinstance(other, self.__class__) and hasattr(other, 'id'):
if self.id == other.id:
return True
return self.id == other.id
if isinstance(other, DBRef):
return self._get_collection_name() == other.collection and self.id == other.id
return False
def __ne__(self, other):
@@ -317,7 +333,7 @@ class BaseDocument(object):
pk = "None"
if hasattr(self, 'pk'):
pk = self.pk
elif self._instance:
elif self._instance and hasattr(self._instance, 'pk'):
pk = self._instance.pk
message = "ValidationError (%s:%s) " % (self._class_name, pk)
raise ValidationError(message, errors=errors)
@@ -401,6 +417,8 @@ class BaseDocument(object):
else:
data = getattr(data, part, None)
if hasattr(data, "_changed_fields"):
if hasattr(data, "_is_document") and data._is_document:
continue
data._changed_fields = []
self._changed_fields = []
@@ -562,10 +580,6 @@ class BaseDocument(object):
# class if unavailable
class_name = son.get('_cls', cls._class_name)
data = dict(("%s" % key, value) for key, value in son.iteritems())
if not UNICODE_KWARGS:
# python 2.6.4 and lower cannot handle unicode keys
# passed to class constructor example: cls(**data)
to_str_keys_recursive(data)
# Return correct subclass for document type
if class_name != cls._class_name:
@@ -603,6 +617,8 @@ class BaseDocument(object):
% (cls._class_name, errors))
raise InvalidDocumentError(msg)
if cls.STRICT:
data = dict((k, v) for k,v in data.iteritems() if k in cls._fields)
obj = cls(__auto_convert=False, **data)
obj._changed_fields = changed_fields
obj._created = False
@@ -821,8 +837,17 @@ class BaseDocument(object):
# Look up subfield on the previous field
new_field = field.lookup_member(field_name)
if not new_field and isinstance(field, ComplexBaseField):
fields.append(field_name)
continue
if hasattr(field.field, 'document_type') and cls._dynamic \
and field.field.document_type._dynamic:
DynamicField = _import_class('DynamicField')
new_field = DynamicField(db_field=field_name)
else:
fields.append(field_name)
continue
elif not new_field and hasattr(field, 'document_type') and cls._dynamic \
and field.document_type._dynamic:
DynamicField = _import_class('DynamicField')
new_field = DynamicField(db_field=field_name)
elif not new_field:
raise LookUpError('Cannot resolve field "%s"'
% field_name)
@@ -842,7 +867,11 @@ class BaseDocument(object):
"""Dynamically set the display value for a field with choices"""
for attr_name, field in self._fields.items():
if field.choices:
setattr(self,
if self._dynamic:
obj = self
else:
obj = type(self)
setattr(obj,
'get_%s_display' % attr_name,
partial(self.__get_field_display, field=field))

View File

@@ -359,7 +359,8 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
new_class.id = field
# Set primary key if not defined by the document
new_class._auto_id_field = False
new_class._auto_id_field = getattr(parent_doc_cls,
'_auto_id_field', False)
if not new_class._meta.get('id_field'):
new_class._auto_id_field = True
new_class._meta['id_field'] = 'id'

View File

@@ -20,7 +20,8 @@ _dbs = {}
def register_connection(alias, name, host=None, port=None,
is_slave=False, read_preference=False, slaves=None,
username=None, password=None, **kwargs):
username=None, password=None, authentication_source=None,
**kwargs):
"""Add a connection.
:param alias: the name that will be used to refer to this connection
@@ -36,6 +37,7 @@ def register_connection(alias, name, host=None, port=None,
be a registered connection that has :attr:`is_slave` set to ``True``
:param username: username to authenticate with
:param password: password to authenticate with
:param authentication_source: database to authenticate against
:param kwargs: allow ad-hoc parameters to be passed into the pymongo driver
"""
@@ -46,10 +48,11 @@ def register_connection(alias, name, host=None, port=None,
'host': host or 'localhost',
'port': port or 27017,
'is_slave': is_slave,
'read_preference': read_preference,
'slaves': slaves or [],
'username': username,
'password': password,
'read_preference': read_preference
'authentication_source': authentication_source
}
# Handle uri style connections
@@ -93,20 +96,12 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
raise ConnectionError(msg)
conn_settings = _connection_settings[alias].copy()
if hasattr(pymongo, 'version_tuple'): # Support for 2.1+
conn_settings.pop('name', None)
conn_settings.pop('slaves', None)
conn_settings.pop('is_slave', None)
conn_settings.pop('username', None)
conn_settings.pop('password', None)
else:
# Get all the slave connections
if 'slaves' in conn_settings:
slaves = []
for slave_alias in conn_settings['slaves']:
slaves.append(get_connection(slave_alias))
conn_settings['slaves'] = slaves
conn_settings.pop('read_preference', None)
conn_settings.pop('name', None)
conn_settings.pop('slaves', None)
conn_settings.pop('is_slave', None)
conn_settings.pop('username', None)
conn_settings.pop('password', None)
conn_settings.pop('authentication_source', None)
connection_class = MongoClient
if 'replicaSet' in conn_settings:
@@ -119,7 +114,19 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
connection_class = MongoReplicaSetClient
try:
_connections[alias] = connection_class(**conn_settings)
connection = None
connection_settings_iterator = ((alias, settings.copy()) for alias, settings in _connection_settings.iteritems())
for alias, connection_settings in connection_settings_iterator:
connection_settings.pop('name', None)
connection_settings.pop('slaves', None)
connection_settings.pop('is_slave', None)
connection_settings.pop('username', None)
connection_settings.pop('password', None)
if conn_settings == connection_settings and _connections.get(alias, None):
connection = _connections[alias]
break
_connections[alias] = connection if connection else connection_class(**conn_settings)
except Exception, e:
raise ConnectionError("Cannot connect to database %s :\n%s" % (alias, e))
return _connections[alias]
@@ -137,7 +144,8 @@ def get_db(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
# Authenticate if necessary
if conn_settings['username'] and conn_settings['password']:
db.authenticate(conn_settings['username'],
conn_settings['password'])
conn_settings['password'],
source=conn_settings['authentication_source'])
_dbs[alias] = db
return _dbs[alias]

View File

@@ -1,6 +1,5 @@
from mongoengine.common import _import_class
from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db
from mongoengine.queryset import QuerySet
__all__ = ("switch_db", "switch_collection", "no_dereference",
@@ -162,12 +161,6 @@ class no_sub_classes(object):
return self.cls
class QuerySetNoDeRef(QuerySet):
"""Special no_dereference QuerySet"""
def __dereference(items, max_depth=1, instance=None, name=None):
return items
class query_counter(object):
""" Query_counter context manager to get the number of queries. """

View File

@@ -36,7 +36,7 @@ class DeReference(object):
if instance and isinstance(instance, (Document, EmbeddedDocument,
TopLevelDocumentMetaclass)):
doc_type = instance._fields.get(name)
if hasattr(doc_type, 'field'):
while hasattr(doc_type, 'field'):
doc_type = doc_type.field
if isinstance(doc_type, ReferenceField):
@@ -51,9 +51,19 @@ class DeReference(object):
return items
elif not field.dbref:
if not hasattr(items, 'items'):
items = [field.to_python(v)
if not isinstance(v, (DBRef, Document)) else v
for v in items]
def _get_items(items):
new_items = []
for v in items:
if isinstance(v, list):
new_items.append(_get_items(v))
elif not isinstance(v, (DBRef, Document)):
new_items.append(field.to_python(v))
else:
new_items.append(v)
return new_items
items = _get_items(items)
else:
items = dict([
(k, field.to_python(v))
@@ -114,11 +124,11 @@ class DeReference(object):
"""Fetch all references and convert to their document objects
"""
object_map = {}
for col, dbrefs in self.reference_map.iteritems():
for collection, dbrefs in self.reference_map.iteritems():
keys = object_map.keys()
refs = list(set([dbref for dbref in dbrefs if unicode(dbref).encode('utf-8') not in keys]))
if hasattr(col, 'objects'): # We have a document class for the refs
references = col.objects.in_bulk(refs)
if hasattr(collection, 'objects'): # We have a document class for the refs
references = collection.objects.in_bulk(refs)
for key, doc in references.iteritems():
object_map[key] = doc
else: # Generic reference: use the refs data to convert to document
@@ -126,19 +136,19 @@ class DeReference(object):
continue
if doc_type:
references = doc_type._get_db()[col].find({'_id': {'$in': refs}})
references = doc_type._get_db()[collection].find({'_id': {'$in': refs}})
for ref in references:
doc = doc_type._from_son(ref)
object_map[doc.id] = doc
else:
references = get_db()[col].find({'_id': {'$in': refs}})
references = get_db()[collection].find({'_id': {'$in': refs}})
for ref in references:
if '_cls' in ref:
doc = get_document(ref["_cls"])._from_son(ref)
elif doc_type is None:
doc = get_document(
''.join(x.capitalize()
for x in col.split('_')))._from_son(ref)
for x in collection.split('_')))._from_son(ref)
else:
doc = doc_type._from_son(ref)
object_map[doc.id] = doc

View File

@@ -13,7 +13,8 @@ from mongoengine.base import (DocumentMetaclass, TopLevelDocumentMetaclass,
BaseDocument, BaseDict, BaseList,
ALLOW_INHERITANCE, get_document)
from mongoengine.errors import ValidationError
from mongoengine.queryset import OperationError, NotUniqueError, QuerySet
from mongoengine.queryset import (OperationError, NotUniqueError,
QuerySet, transform)
from mongoengine.connection import get_db, DEFAULT_CONNECTION_NAME
from mongoengine.context_managers import switch_db, switch_collection
@@ -54,20 +55,21 @@ class EmbeddedDocument(BaseDocument):
dictionary.
"""
__slots__ = ('_instance')
# The __metaclass__ attribute is removed by 2to3 when running with Python3
# my_metaclass is defined so that metaclass can be queried in Python 2 & 3
my_metaclass = DocumentMetaclass
__metaclass__ = DocumentMetaclass
_instance = None
def __init__(self, *args, **kwargs):
super(EmbeddedDocument, self).__init__(*args, **kwargs)
self._instance = None
self._changed_fields = []
def __eq__(self, other):
if isinstance(other, self.__class__):
return self.to_mongo() == other.to_mongo()
return self._data == other._data
return False
def __ne__(self, other):
@@ -125,6 +127,8 @@ class Document(BaseDocument):
my_metaclass = TopLevelDocumentMetaclass
__metaclass__ = TopLevelDocumentMetaclass
__slots__ = ('__objects' )
def pk():
"""Primary key alias
"""
@@ -180,7 +184,7 @@ class Document(BaseDocument):
def save(self, force_insert=False, validate=True, clean=True,
write_concern=None, cascade=None, cascade_kwargs=None,
_refs=None, **kwargs):
_refs=None, save_condition=None, **kwargs):
"""Save the :class:`~mongoengine.Document` to the database. If the
document already exists, it will be updated, otherwise it will be
created.
@@ -203,7 +207,8 @@ class Document(BaseDocument):
:param cascade_kwargs: (optional) kwargs dictionary to be passed throw
to cascading saves. Implies ``cascade=True``.
:param _refs: A list of processed references used in cascading saves
:param save_condition: only perform save if matching record in db
satisfies condition(s) (e.g., version number)
.. versionchanged:: 0.5
In existing documents it only saves changed fields using
set / unset. Saves are cascaded and any
@@ -217,6 +222,9 @@ class Document(BaseDocument):
meta['cascade'] = True. Also you can pass different kwargs to
the cascade save using cascade_kwargs which overwrites the
existing kwargs with custom values.
.. versionchanged:: 0.8.5
Optional save_condition that only overwrites existing documents
if the condition is satisfied in the current db record.
"""
signals.pre_save.send(self.__class__, document=self)
@@ -230,7 +238,8 @@ class Document(BaseDocument):
created = ('_id' not in doc or self._created or force_insert)
signals.pre_save_post_validation.send(self.__class__, document=self, created=created)
signals.pre_save_post_validation.send(self.__class__, document=self,
created=created)
try:
collection = self._get_collection()
@@ -243,7 +252,12 @@ class Document(BaseDocument):
object_id = doc['_id']
updates, removals = self._delta()
# Need to add shard key to query, or you get an error
select_dict = {'_id': object_id}
if save_condition is not None:
select_dict = transform.query(self.__class__,
**save_condition)
else:
select_dict = {}
select_dict['_id'] = object_id
shard_key = self.__class__._meta.get('shard_key', tuple())
for k in shard_key:
actual_key = self._db_field_map.get(k, k)
@@ -263,10 +277,12 @@ class Document(BaseDocument):
if removals:
update_query["$unset"] = removals
if updates or removals:
upsert = save_condition is None
last_error = collection.update(select_dict, update_query,
upsert=True, **write_concern)
upsert=upsert, **write_concern)
created = is_new_object(last_error)
if cascade is None:
cascade = self._meta.get('cascade', False) or cascade_kwargs is not None
@@ -293,12 +309,12 @@ class Document(BaseDocument):
raise NotUniqueError(message % unicode(err))
raise OperationError(message % unicode(err))
id_field = self._meta['id_field']
if id_field not in self._meta.get('shard_key', []):
if created or id_field not in self._meta.get('shard_key', []):
self[id_field] = self._fields[id_field].to_python(object_id)
signals.post_save.send(self.__class__, document=self, created=created)
self._clear_changed_fields()
self._created = False
signals.post_save.send(self.__class__, document=self, created=created)
return self
def cascade_save(self, *args, **kwargs):
@@ -447,27 +463,41 @@ class Document(BaseDocument):
DeReference()([self], max_depth + 1)
return self
def reload(self, max_depth=1):
def reload(self, *fields, **kwargs):
"""Reloads all attributes from the database.
:param fields: (optional) args list of fields to reload
:param max_depth: (optional) depth of dereferencing to follow
.. versionadded:: 0.1.2
.. versionchanged:: 0.6 Now chainable
.. versionchanged:: 0.9 Can provide specific fields to reload
"""
max_depth = 1
if fields and isinstance(fields[0], int):
max_depth = fields[0]
fields = fields[1:]
elif "max_depth" in kwargs:
max_depth = kwargs["max_depth"]
if not self.pk:
raise self.DoesNotExist("Document does not exist")
obj = self._qs.read_preference(ReadPreference.PRIMARY).filter(
**self._object_key).limit(1).select_related(max_depth=max_depth)
**self._object_key).only(*fields).limit(1
).select_related(max_depth=max_depth)
if obj:
obj = obj[0]
else:
raise self.DoesNotExist("Document does not exist")
for field in self._fields_ordered:
setattr(self, field, self._reload(field, obj[field]))
if not fields or field in fields:
setattr(self, field, self._reload(field, obj[field]))
self._changed_fields = obj._changed_fields
self._created = False
return obj
return self
def _reload(self, key, value):
"""Used by :meth:`~mongoengine.Document.reload` to ensure the

View File

@@ -760,7 +760,7 @@ class DictField(ComplexBaseField):
similar to an embedded document, but the structure is not defined.
.. note::
Required means it cannot be empty - as the default for ListFields is []
Required means it cannot be empty - as the default for DictFields is {}
.. versionadded:: 0.3
.. versionchanged:: 0.5 - Can now handle complex / varying types of data
@@ -1554,6 +1554,14 @@ class SequenceField(BaseField):
return super(SequenceField, self).__set__(instance, value)
def prepare_query_value(self, op, value):
"""
This method is overriden in order to convert the query value into to required
type. We need to do this in order to be able to successfully compare query
values passed as string, the base implementation returns the value as is.
"""
return self.value_decorator(value)
def to_python(self, value):
if value is None:
value = self.generate()
@@ -1613,7 +1621,12 @@ class UUIDField(BaseField):
class GeoPointField(BaseField):
"""A list storing a latitude and longitude.
"""A list storing a longitude and latitude coordinate.
.. note:: this represents a generic point in a 2D plane and a legacy way of
representing a geo point. It admits 2d indexes but not "2dsphere" indexes
in MongoDB > 2.4 which are more natural for modeling geospatial points.
See :ref:`geospatial-indexes`
.. versionadded:: 0.4
"""
@@ -1635,7 +1648,7 @@ class GeoPointField(BaseField):
class PointField(GeoJsonBaseField):
"""A geo json field storing a latitude and longitude.
"""A GeoJSON field storing a longitude and latitude coordinate.
The data is represented as:
@@ -1654,7 +1667,7 @@ class PointField(GeoJsonBaseField):
class LineStringField(GeoJsonBaseField):
"""A geo json field storing a line of latitude and longitude coordinates.
"""A GeoJSON field storing a line of longitude and latitude coordinates.
The data is represented as:
@@ -1672,7 +1685,7 @@ class LineStringField(GeoJsonBaseField):
class PolygonField(GeoJsonBaseField):
"""A geo json field storing a polygon of latitude and longitude coordinates.
"""A GeoJSON field storing a polygon of longitude and latitude coordinates.
The data is represented as:

View File

@@ -3,8 +3,6 @@
import sys
PY3 = sys.version_info[0] == 3
PY25 = sys.version_info[:2] == (2, 5)
UNICODE_KWARGS = int(''.join([str(x) for x in sys.version_info[:3]])) > 264
if PY3:
import codecs
@@ -29,33 +27,3 @@ else:
txt_type = unicode
str_types = (bin_type, txt_type)
if PY25:
def product(*args, **kwds):
pools = map(tuple, args) * kwds.get('repeat', 1)
result = [[]]
for pool in pools:
result = [x + [y] for x in result for y in pool]
for prod in result:
yield tuple(prod)
reduce = reduce
else:
from itertools import product
from functools import reduce
# For use with Python 2.5
# converts all keys from unicode to str for d and all nested dictionaries
def to_str_keys_recursive(d):
if isinstance(d, list):
for val in d:
if isinstance(val, (dict, list)):
to_str_keys_recursive(val)
elif isinstance(d, dict):
for key, val in d.items():
if isinstance(val, (dict, list)):
to_str_keys_recursive(val)
if isinstance(key, unicode):
d[str(key)] = d.pop(key)
else:
raise ValueError("non list/dict parameter not allowed")

View File

@@ -7,17 +7,20 @@ import pprint
import re
import warnings
from bson import SON
from bson.code import Code
from bson import json_util
import pymongo
import pymongo.errors
from pymongo.common import validate_read_preference
from mongoengine import signals
from mongoengine.connection import get_db
from mongoengine.context_managers import switch_db
from mongoengine.common import _import_class
from mongoengine.base.common import get_document
from mongoengine.errors import (OperationError, NotUniqueError,
InvalidQueryError, LookUpError)
from mongoengine.queryset import transform
from mongoengine.queryset.field_list import QueryFieldList
from mongoengine.queryset.visitor import Q, QNode
@@ -50,7 +53,7 @@ class BaseQuerySet(object):
self._initial_query = {}
self._where_clause = None
self._loaded_fields = QueryFieldList()
self._ordering = []
self._ordering = None
self._snapshot = False
self._timeout = True
self._class_check = True
@@ -146,7 +149,7 @@ class BaseQuerySet(object):
queryset._document._from_son(queryset._cursor[key],
_auto_dereference=self._auto_dereference))
if queryset._as_pymongo:
return queryset._get_as_pymongo(queryset._cursor.next())
return queryset._get_as_pymongo(queryset._cursor[key])
return queryset._document._from_son(queryset._cursor[key],
_auto_dereference=self._auto_dereference)
raise AttributeError
@@ -154,6 +157,22 @@ class BaseQuerySet(object):
def __iter__(self):
raise NotImplementedError
def _has_data(self):
""" Retrieves whether cursor has any data. """
queryset = self.order_by()
return False if queryset.first() is None else True
def __nonzero__(self):
""" Avoid to open all records in an if stmt in Py2. """
return self._has_data()
def __bool__(self):
""" Avoid to open all records in an if stmt in Py3. """
return self._has_data()
# Core functions
def all(self):
@@ -175,7 +194,7 @@ class BaseQuerySet(object):
.. versionadded:: 0.3
"""
queryset = self.clone()
queryset = queryset.limit(2)
queryset = queryset.order_by().limit(2)
queryset = queryset.filter(*q_objs, **query)
try:
@@ -389,7 +408,7 @@ class BaseQuerySet(object):
ref_q = document_cls.objects(**{field_name + '__in': self})
ref_q_count = ref_q.count()
if (doc != document_cls and ref_q_count > 0
or (doc == document_cls and ref_q_count > 0)):
or (doc == document_cls and ref_q_count > 0)):
ref_q.delete(write_concern=write_concern)
elif rule == NULLIFY:
document_cls.objects(**{field_name + '__in': self}).update(
@@ -443,6 +462,8 @@ class BaseQuerySet(object):
return result
elif result:
return result['n']
except pymongo.errors.DuplicateKeyError, err:
raise NotUniqueError(u'Update failed (%s)' % unicode(err))
except pymongo.errors.OperationFailure, err:
if unicode(err) == u'multi not coded yet':
message = u'update() method requires MongoDB 1.1.3+'
@@ -466,6 +487,59 @@ class BaseQuerySet(object):
return self.update(
upsert=upsert, multi=False, write_concern=write_concern, **update)
def modify(self, upsert=False, full_response=False, remove=False, new=False, **update):
"""Update and return the updated document.
Returns either the document before or after modification based on `new`
parameter. If no documents match the query and `upsert` is false,
returns ``None``. If upserting and `new` is false, returns ``None``.
If the full_response parameter is ``True``, the return value will be
the entire response object from the server, including the 'ok' and
'lastErrorObject' fields, rather than just the modified document.
This is useful mainly because the 'lastErrorObject' document holds
information about the command's execution.
:param upsert: insert if document doesn't exist (default ``False``)
:param full_response: return the entire response object from the
server (default ``False``)
:param remove: remove rather than updating (default ``False``)
:param new: return updated rather than original document
(default ``False``)
:param update: Django-style update keyword arguments
.. versionadded:: 0.9
"""
if remove and new:
raise OperationError("Conflicting parameters: remove and new")
if not update and not upsert and not remove:
raise OperationError("No update parameters, must either update or remove")
queryset = self.clone()
query = queryset._query
update = transform.update(queryset._document, **update)
sort = queryset._ordering
try:
result = queryset._collection.find_and_modify(
query, update, upsert=upsert, sort=sort, remove=remove, new=new,
full_response=full_response, **self._cursor_args)
except pymongo.errors.DuplicateKeyError, err:
raise NotUniqueError(u"Update failed (%s)" % err)
except pymongo.errors.OperationFailure, err:
raise OperationError(u"Update failed (%s)" % err)
if full_response:
if result["value"] is not None:
result["value"] = self._document._from_son(result["value"])
else:
if result is not None:
result = self._document._from_son(result)
return result
def with_id(self, object_id):
"""Retrieve the object matching the id provided. Uses `object_id` only
and raises InvalidQueryError if a filter has been applied. Returns
@@ -522,6 +596,19 @@ class BaseQuerySet(object):
return self
def using(self, alias):
"""This method is for controlling which database the QuerySet will be evaluated against if you are using more than one database.
:param alias: The database alias
.. versionadded:: 0.8
"""
with switch_db(self._document, alias) as cls:
collection = cls._get_collection()
return self.clone_into(self.__class__(self._document, collection))
def clone(self):
"""Creates a copy of the current
:class:`~mongoengine.queryset.QuerySet`
@@ -923,10 +1010,39 @@ class BaseQuerySet(object):
map_reduce_function = 'inline_map_reduce'
else:
map_reduce_function = 'map_reduce'
mr_args['out'] = output
if isinstance(output, basestring):
mr_args['out'] = output
elif isinstance(output, dict):
ordered_output = []
for part in ('replace', 'merge', 'reduce'):
value = output.get(part)
if value:
ordered_output.append((part, value))
break
else:
raise OperationError("actionData not specified for output")
db_alias = output.get('db_alias')
remaing_args = ['db', 'sharded', 'nonAtomic']
if db_alias:
ordered_output.append(('db', get_db(db_alias).name))
del remaing_args[0]
for part in remaing_args:
value = output.get(part)
if value:
ordered_output.append((part, value))
mr_args['out'] = SON(ordered_output)
results = getattr(queryset._collection, map_reduce_function)(
map_f, reduce_f, **mr_args)
map_f, reduce_f, **mr_args)
if map_reduce_function == 'map_reduce':
results = results.find()
@@ -1189,8 +1305,9 @@ class BaseQuerySet(object):
if self._ordering:
# Apply query ordering
self._cursor_obj.sort(self._ordering)
elif self._document._meta['ordering']:
# Otherwise, apply the ordering from the document model
elif self._ordering is None and self._document._meta['ordering']:
# Otherwise, apply the ordering from the document model, unless
# it's been explicitly cleared via order_by with no arguments
order = self._get_order_by(self._document._meta['ordering'])
self._cursor_obj.sort(order)
@@ -1362,7 +1479,7 @@ class BaseQuerySet(object):
for subdoc in subclasses:
try:
subfield = ".".join(f.db_field for f in
subdoc._lookup_field(field.split('.')))
subdoc._lookup_field(field.split('.')))
ret.append(subfield)
found = True
break
@@ -1392,7 +1509,7 @@ class BaseQuerySet(object):
pass
key_list.append((key, direction))
if self._cursor_obj:
if self._cursor_obj and key_list:
self._cursor_obj.sort(key_list)
return key_list
@@ -1450,6 +1567,7 @@ class BaseQuerySet(object):
# type of this field and use the corresponding
# .to_python(...)
from mongoengine.fields import EmbeddedDocumentField
obj = self._document
for chunk in path.split('.'):
obj = getattr(obj, chunk, None)
@@ -1460,6 +1578,7 @@ class BaseQuerySet(object):
if obj and data is not None:
data = obj.to_python(data)
return data
return clean(row)
def _sub_js_fields(self, code):
@@ -1468,6 +1587,7 @@ class BaseQuerySet(object):
substituted for the MongoDB name of the field (specified using the
:attr:`name` keyword argument in a field's constructor).
"""
def field_sub(match):
# Extract just the field name, and look up the field objects
field_name = match.group(1).split('.')

View File

@@ -155,3 +155,10 @@ class QuerySetNoCache(BaseQuerySet):
queryset = self.clone()
queryset.rewind()
return queryset
class QuerySetNoDeRef(QuerySet):
"""Special no_dereference QuerySet"""
def __dereference(items, max_depth=1, instance=None, name=None):
return items

View File

@@ -3,6 +3,7 @@ from collections import defaultdict
import pymongo
from bson import SON
from mongoengine.connection import get_connection
from mongoengine.common import _import_class
from mongoengine.errors import InvalidQueryError, LookUpError
@@ -38,7 +39,7 @@ def query(_doc_cls=None, _field_operation=False, **query):
mongo_query.update(value)
continue
parts = key.split('__')
parts = key.rsplit('__')
indices = [(i, p) for i, p in enumerate(parts) if p.isdigit()]
parts = [part for part in parts if not part.isdigit()]
# Check for an operator and transform to mongo-style if there is
@@ -115,14 +116,26 @@ def query(_doc_cls=None, _field_operation=False, **query):
if key in mongo_query and isinstance(mongo_query[key], dict):
mongo_query[key].update(value)
# $maxDistance needs to come last - convert to SON
if '$maxDistance' in mongo_query[key]:
value_dict = mongo_query[key]
value_dict = mongo_query[key]
if ('$maxDistance' in value_dict and '$near' in value_dict):
value_son = SON()
for k, v in value_dict.iteritems():
if k == '$maxDistance':
continue
value_son[k] = v
value_son['$maxDistance'] = value_dict['$maxDistance']
if isinstance(value_dict['$near'], dict):
for k, v in value_dict.iteritems():
if k == '$maxDistance':
continue
value_son[k] = v
if (get_connection().max_wire_version <= 1):
value_son['$maxDistance'] = value_dict['$maxDistance']
else:
value_son['$near'] = SON(value_son['$near'])
value_son['$near']['$maxDistance'] = value_dict['$maxDistance']
else:
for k, v in value_dict.iteritems():
if k == '$maxDistance':
continue
value_son[k] = v
value_son['$maxDistance'] = value_dict['$maxDistance']
mongo_query[key] = value_son
else:
# Store for manually merging later

View File

@@ -1,8 +1,9 @@
import copy
from mongoengine.errors import InvalidQueryError
from mongoengine.python_support import product, reduce
from itertools import product
from functools import reduce
from mongoengine.errors import InvalidQueryError
from mongoengine.queryset import transform
__all__ = ('Q',)