Merge branch 'master' into pr/625
This commit is contained in:
@@ -15,7 +15,7 @@ import django
|
||||
__all__ = (list(document.__all__) + fields.__all__ + connection.__all__ +
|
||||
list(queryset.__all__) + signals.__all__ + list(errors.__all__))
|
||||
|
||||
VERSION = (0, 8, 4)
|
||||
VERSION = (0, 8, 7)
|
||||
|
||||
|
||||
def get_version():
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import copy
|
||||
import operator
|
||||
import numbers
|
||||
from collections import Hashable
|
||||
from functools import partial
|
||||
|
||||
import pymongo
|
||||
@@ -12,8 +13,7 @@ from mongoengine import signals
|
||||
from mongoengine.common import _import_class
|
||||
from mongoengine.errors import (ValidationError, InvalidDocumentError,
|
||||
LookUpError)
|
||||
from mongoengine.python_support import (PY3, UNICODE_KWARGS, txt_type,
|
||||
to_str_keys_recursive)
|
||||
from mongoengine.python_support import PY3, txt_type
|
||||
|
||||
from mongoengine.base.common import get_document, ALLOW_INHERITANCE
|
||||
from mongoengine.base.datastructures import BaseDict, BaseList, StrictDict, SemiStrictDict
|
||||
@@ -197,7 +197,7 @@ class BaseDocument(object):
|
||||
"""Dictionary-style field access, set a field's value.
|
||||
"""
|
||||
# Ensure that the field exists before settings its value
|
||||
if name not in self._fields:
|
||||
if not self._dynamic and name not in self._fields:
|
||||
raise KeyError(name)
|
||||
return setattr(self, name, value)
|
||||
|
||||
@@ -391,20 +391,41 @@ class BaseDocument(object):
|
||||
self._changed_fields.append(key)
|
||||
|
||||
def _clear_changed_fields(self):
|
||||
"""Using get_changed_fields iterate and remove any fields that are
|
||||
marked as changed"""
|
||||
for changed in self._get_changed_fields():
|
||||
parts = changed.split(".")
|
||||
data = self
|
||||
for part in parts:
|
||||
if isinstance(data, list):
|
||||
try:
|
||||
data = data[int(part)]
|
||||
except IndexError:
|
||||
data = None
|
||||
elif isinstance(data, dict):
|
||||
data = data.get(part, None)
|
||||
else:
|
||||
data = getattr(data, part, None)
|
||||
if hasattr(data, "_changed_fields"):
|
||||
data._changed_fields = []
|
||||
self._changed_fields = []
|
||||
EmbeddedDocumentField = _import_class("EmbeddedDocumentField")
|
||||
for field_name, field in self._fields.iteritems():
|
||||
if (isinstance(field, ComplexBaseField) and
|
||||
isinstance(field.field, EmbeddedDocumentField)):
|
||||
field_value = getattr(self, field_name, None)
|
||||
if field_value:
|
||||
for idx in (field_value if isinstance(field_value, dict)
|
||||
else xrange(len(field_value))):
|
||||
field_value[idx]._clear_changed_fields()
|
||||
elif isinstance(field, EmbeddedDocumentField):
|
||||
field_value = getattr(self, field_name, None)
|
||||
if field_value:
|
||||
field_value._clear_changed_fields()
|
||||
|
||||
def _nestable_types_changed_fields(self, changed_fields, key, data, inspected):
|
||||
# Loop list / dict fields as they contain documents
|
||||
# Determine the iterator to use
|
||||
if not hasattr(data, 'items'):
|
||||
iterator = enumerate(data)
|
||||
else:
|
||||
iterator = data.iteritems()
|
||||
|
||||
for index, value in iterator:
|
||||
list_key = "%s%s." % (key, index)
|
||||
if hasattr(value, '_get_changed_fields'):
|
||||
changed = value._get_changed_fields(inspected)
|
||||
changed_fields += ["%s%s" % (list_key, k)
|
||||
for k in changed if k]
|
||||
elif isinstance(value, (list, tuple, dict)):
|
||||
self._nestable_types_changed_fields(changed_fields, list_key, value, inspected)
|
||||
|
||||
def _get_changed_fields(self, inspected=None):
|
||||
"""Returns a list of all fields that have explicitly been changed.
|
||||
@@ -412,13 +433,12 @@ class BaseDocument(object):
|
||||
EmbeddedDocument = _import_class("EmbeddedDocument")
|
||||
DynamicEmbeddedDocument = _import_class("DynamicEmbeddedDocument")
|
||||
ReferenceField = _import_class("ReferenceField")
|
||||
_changed_fields = []
|
||||
_changed_fields += getattr(self, '_changed_fields', [])
|
||||
|
||||
changed_fields = []
|
||||
changed_fields += getattr(self, '_changed_fields', [])
|
||||
inspected = inspected or set()
|
||||
if hasattr(self, 'id'):
|
||||
if hasattr(self, 'id') and isinstance(self.id, Hashable):
|
||||
if self.id in inspected:
|
||||
return _changed_fields
|
||||
return changed_fields
|
||||
inspected.add(self.id)
|
||||
|
||||
for field_name in self._fields_ordered:
|
||||
@@ -434,29 +454,17 @@ class BaseDocument(object):
|
||||
if isinstance(field, ReferenceField):
|
||||
continue
|
||||
elif (isinstance(data, (EmbeddedDocument, DynamicEmbeddedDocument))
|
||||
and db_field_name not in _changed_fields):
|
||||
and db_field_name not in changed_fields):
|
||||
# Find all embedded fields that have been changed
|
||||
changed = data._get_changed_fields(inspected)
|
||||
_changed_fields += ["%s%s" % (key, k) for k in changed if k]
|
||||
changed_fields += ["%s%s" % (key, k) for k in changed if k]
|
||||
elif (isinstance(data, (list, tuple, dict)) and
|
||||
db_field_name not in _changed_fields):
|
||||
# Loop list / dict fields as they contain documents
|
||||
# Determine the iterator to use
|
||||
if not hasattr(data, 'items'):
|
||||
iterator = enumerate(data)
|
||||
else:
|
||||
iterator = data.iteritems()
|
||||
for index, value in iterator:
|
||||
if not hasattr(value, '_get_changed_fields'):
|
||||
continue
|
||||
if (hasattr(field, 'field') and
|
||||
isinstance(field.field, ReferenceField)):
|
||||
continue
|
||||
list_key = "%s%s." % (key, index)
|
||||
changed = value._get_changed_fields(inspected)
|
||||
_changed_fields += ["%s%s" % (list_key, k)
|
||||
for k in changed if k]
|
||||
return _changed_fields
|
||||
db_field_name not in changed_fields):
|
||||
if (hasattr(field, 'field') and
|
||||
isinstance(field.field, ReferenceField)):
|
||||
continue
|
||||
self._nestable_types_changed_fields(changed_fields, key, data, inspected)
|
||||
return changed_fields
|
||||
|
||||
def _delta(self):
|
||||
"""Returns the delta (set, unset) of the changes for a document.
|
||||
@@ -552,10 +560,6 @@ class BaseDocument(object):
|
||||
# class if unavailable
|
||||
class_name = son.get('_cls', cls._class_name)
|
||||
data = dict(("%s" % key, value) for key, value in son.iteritems())
|
||||
if not UNICODE_KWARGS:
|
||||
# python 2.6.4 and lower cannot handle unicode keys
|
||||
# passed to class constructor example: cls(**data)
|
||||
to_str_keys_recursive(data)
|
||||
|
||||
# Return correct subclass for document type
|
||||
if class_name != cls._class_name:
|
||||
@@ -773,6 +777,9 @@ class BaseDocument(object):
|
||||
"""Lookup a field based on its attribute and return a list containing
|
||||
the field's parents and the field.
|
||||
"""
|
||||
|
||||
ListField = _import_class("ListField")
|
||||
|
||||
if not isinstance(parts, (list, tuple)):
|
||||
parts = [parts]
|
||||
fields = []
|
||||
@@ -780,7 +787,7 @@ class BaseDocument(object):
|
||||
|
||||
for field_name in parts:
|
||||
# Handle ListField indexing:
|
||||
if field_name.isdigit() and hasattr(field, 'field'):
|
||||
if field_name.isdigit() and isinstance(field, ListField):
|
||||
new_field = field.field
|
||||
fields.append(field_name)
|
||||
continue
|
||||
|
||||
@@ -89,12 +89,7 @@ class BaseField(object):
|
||||
return self
|
||||
|
||||
# Get value from document instance if available
|
||||
value = instance._data.get(self.name)
|
||||
|
||||
EmbeddedDocument = _import_class('EmbeddedDocument')
|
||||
if isinstance(value, EmbeddedDocument) and value._instance is None:
|
||||
value._instance = weakref.proxy(instance)
|
||||
return value
|
||||
return instance._data.get(self.name)
|
||||
|
||||
def __set__(self, instance, value):
|
||||
"""Descriptor for assigning a value to a field in a document.
|
||||
@@ -116,6 +111,10 @@ class BaseField(object):
|
||||
# Values cant be compared eg: naive and tz datetimes
|
||||
# So mark it as changed
|
||||
instance._mark_as_changed(self.name)
|
||||
|
||||
EmbeddedDocument = _import_class('EmbeddedDocument')
|
||||
if isinstance(value, EmbeddedDocument) and value._instance is None:
|
||||
value._instance = weakref.proxy(instance)
|
||||
instance._data[self.name] = value
|
||||
|
||||
def error(self, message="", errors=None, field_name=None):
|
||||
@@ -203,7 +202,7 @@ class ComplexBaseField(BaseField):
|
||||
_dereference = _import_class("DeReference")()
|
||||
|
||||
self._auto_dereference = instance._fields[self.name]._auto_dereference
|
||||
if instance._initialised and dereference:
|
||||
if instance._initialised and dereference and instance._data.get(self.name):
|
||||
instance._data[self.name] = _dereference(
|
||||
instance._data.get(self.name), max_depth=1, instance=instance,
|
||||
name=self.name
|
||||
|
||||
@@ -25,7 +25,7 @@ def _import_class(cls_name):
|
||||
'GenericEmbeddedDocumentField', 'GeoPointField',
|
||||
'PointField', 'LineStringField', 'ListField',
|
||||
'PolygonField', 'ReferenceField', 'StringField',
|
||||
'ComplexBaseField')
|
||||
'ComplexBaseField', 'GeoJsonBaseField')
|
||||
queryset_classes = ('OperationError',)
|
||||
deref_classes = ('DeReference',)
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ _connections = {}
|
||||
_dbs = {}
|
||||
|
||||
|
||||
def register_connection(alias, name, host='localhost', port=27017,
|
||||
def register_connection(alias, name, host=None, port=None,
|
||||
is_slave=False, read_preference=False, slaves=None,
|
||||
username=None, password=None, **kwargs):
|
||||
"""Add a connection.
|
||||
@@ -43,8 +43,8 @@ def register_connection(alias, name, host='localhost', port=27017,
|
||||
|
||||
conn_settings = {
|
||||
'name': name,
|
||||
'host': host,
|
||||
'port': port,
|
||||
'host': host or 'localhost',
|
||||
'port': port or 27017,
|
||||
'is_slave': is_slave,
|
||||
'slaves': slaves or [],
|
||||
'username': username,
|
||||
@@ -53,16 +53,15 @@ def register_connection(alias, name, host='localhost', port=27017,
|
||||
}
|
||||
|
||||
# Handle uri style connections
|
||||
if "://" in host:
|
||||
uri_dict = uri_parser.parse_uri(host)
|
||||
if "://" in conn_settings['host']:
|
||||
uri_dict = uri_parser.parse_uri(conn_settings['host'])
|
||||
conn_settings.update({
|
||||
'host': host,
|
||||
'name': uri_dict.get('database') or name,
|
||||
'username': uri_dict.get('username'),
|
||||
'password': uri_dict.get('password'),
|
||||
'read_preference': read_preference,
|
||||
})
|
||||
if "replicaSet" in host:
|
||||
if "replicaSet" in conn_settings['host']:
|
||||
conn_settings['replicaSet'] = True
|
||||
|
||||
conn_settings.update(kwargs)
|
||||
@@ -94,20 +93,11 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
|
||||
raise ConnectionError(msg)
|
||||
conn_settings = _connection_settings[alias].copy()
|
||||
|
||||
if hasattr(pymongo, 'version_tuple'): # Support for 2.1+
|
||||
conn_settings.pop('name', None)
|
||||
conn_settings.pop('slaves', None)
|
||||
conn_settings.pop('is_slave', None)
|
||||
conn_settings.pop('username', None)
|
||||
conn_settings.pop('password', None)
|
||||
else:
|
||||
# Get all the slave connections
|
||||
if 'slaves' in conn_settings:
|
||||
slaves = []
|
||||
for slave_alias in conn_settings['slaves']:
|
||||
slaves.append(get_connection(slave_alias))
|
||||
conn_settings['slaves'] = slaves
|
||||
conn_settings.pop('read_preference', None)
|
||||
conn_settings.pop('name', None)
|
||||
conn_settings.pop('slaves', None)
|
||||
conn_settings.pop('is_slave', None)
|
||||
conn_settings.pop('username', None)
|
||||
conn_settings.pop('password', None)
|
||||
|
||||
connection_class = MongoClient
|
||||
if 'replicaSet' in conn_settings:
|
||||
@@ -120,7 +110,19 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
|
||||
connection_class = MongoReplicaSetClient
|
||||
|
||||
try:
|
||||
_connections[alias] = connection_class(**conn_settings)
|
||||
connection = None
|
||||
connection_settings_iterator = ((alias, settings.copy()) for alias, settings in _connection_settings.iteritems())
|
||||
for alias, connection_settings in connection_settings_iterator:
|
||||
connection_settings.pop('name', None)
|
||||
connection_settings.pop('slaves', None)
|
||||
connection_settings.pop('is_slave', None)
|
||||
connection_settings.pop('username', None)
|
||||
connection_settings.pop('password', None)
|
||||
if conn_settings == connection_settings and _connections.get(alias, None):
|
||||
connection = _connections[alias]
|
||||
break
|
||||
|
||||
_connections[alias] = connection if connection else connection_class(**conn_settings)
|
||||
except Exception, e:
|
||||
raise ConnectionError("Cannot connect to database %s :\n%s" % (alias, e))
|
||||
return _connections[alias]
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from mongoengine.common import _import_class
|
||||
from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db
|
||||
from mongoengine.queryset import QuerySet
|
||||
|
||||
|
||||
__all__ = ("switch_db", "switch_collection", "no_dereference",
|
||||
@@ -162,12 +161,6 @@ class no_sub_classes(object):
|
||||
return self.cls
|
||||
|
||||
|
||||
class QuerySetNoDeRef(QuerySet):
|
||||
"""Special no_dereference QuerySet"""
|
||||
def __dereference(items, max_depth=1, instance=None, name=None):
|
||||
return items
|
||||
|
||||
|
||||
class query_counter(object):
|
||||
""" Query_counter context manager to get the number of queries. """
|
||||
|
||||
|
||||
@@ -8,6 +8,10 @@ from django.contrib import auth
|
||||
from django.contrib.auth.models import AnonymousUser
|
||||
from django.utils.translation import ugettext_lazy as _
|
||||
|
||||
from .utils import datetime_now
|
||||
|
||||
REDIRECT_FIELD_NAME = 'next'
|
||||
|
||||
try:
|
||||
from django.contrib.auth.hashers import check_password, make_password
|
||||
except ImportError:
|
||||
@@ -33,10 +37,6 @@ except ImportError:
|
||||
hash = get_hexdigest(algo, salt, raw_password)
|
||||
return '%s$%s$%s' % (algo, salt, hash)
|
||||
|
||||
from .utils import datetime_now
|
||||
|
||||
REDIRECT_FIELD_NAME = 'next'
|
||||
|
||||
|
||||
class ContentType(Document):
|
||||
name = StringField(max_length=100)
|
||||
@@ -230,6 +230,9 @@ class User(Document):
|
||||
date_joined = DateTimeField(default=datetime_now,
|
||||
verbose_name=_('date joined'))
|
||||
|
||||
user_permissions = ListField(ReferenceField(Permission), verbose_name=_('user permissions'),
|
||||
help_text=_('Permissions for the user.'))
|
||||
|
||||
USERNAME_FIELD = 'username'
|
||||
REQUIRED_FIELDS = ['email']
|
||||
|
||||
@@ -378,9 +381,10 @@ class MongoEngineBackend(object):
|
||||
supports_object_permissions = False
|
||||
supports_anonymous_user = False
|
||||
supports_inactive_user = False
|
||||
_user_doc = False
|
||||
|
||||
def authenticate(self, username=None, password=None):
|
||||
user = User.objects(username=username).first()
|
||||
user = self.user_document.objects(username=username).first()
|
||||
if user:
|
||||
if password and user.check_password(password):
|
||||
backend = auth.get_backends()[0]
|
||||
@@ -389,8 +393,14 @@ class MongoEngineBackend(object):
|
||||
return None
|
||||
|
||||
def get_user(self, user_id):
|
||||
return User.objects.with_id(user_id)
|
||||
return self.user_document.objects.with_id(user_id)
|
||||
|
||||
@property
|
||||
def user_document(self):
|
||||
if self._user_doc is False:
|
||||
from .mongo_auth.models import get_user_document
|
||||
self._user_doc = get_user_document()
|
||||
return self._user_doc
|
||||
|
||||
def get_user(userid):
|
||||
"""Returns a User object from an id (User.id). Django's equivalent takes
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from django.conf import settings
|
||||
from django.contrib.auth.hashers import make_password
|
||||
from django.contrib.auth.models import UserManager
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from django.db import models
|
||||
@@ -105,3 +106,10 @@ class MongoUser(models.Model):
|
||||
"""
|
||||
|
||||
objects = MongoUserManager()
|
||||
|
||||
class Meta:
|
||||
app_label = 'mongo_auth'
|
||||
|
||||
def set_password(self, password):
|
||||
"""Doesn't do anything, but works around the issue with Django 1.6."""
|
||||
make_password(password)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from bson import json_util
|
||||
from django.conf import settings
|
||||
from django.contrib.sessions.backends.base import SessionBase, CreateError
|
||||
from django.core.exceptions import SuspiciousOperation
|
||||
@@ -55,6 +56,12 @@ class SessionStore(SessionBase):
|
||||
"""A MongoEngine-based session store for Django.
|
||||
"""
|
||||
|
||||
def _get_session(self, *args, **kwargs):
|
||||
sess = super(SessionStore, self)._get_session(*args, **kwargs)
|
||||
if sess.get('_auth_user_id', None):
|
||||
sess['_auth_user_id'] = str(sess.get('_auth_user_id'))
|
||||
return sess
|
||||
|
||||
def load(self):
|
||||
try:
|
||||
s = MongoSession.objects(session_key=self.session_key,
|
||||
@@ -103,3 +110,15 @@ class SessionStore(SessionBase):
|
||||
return
|
||||
session_key = self.session_key
|
||||
MongoSession.objects(session_key=session_key).delete()
|
||||
|
||||
|
||||
class BSONSerializer(object):
|
||||
"""
|
||||
Serializer that can handle BSON types (eg ObjectId).
|
||||
"""
|
||||
def dumps(self, obj):
|
||||
return json_util.dumps(obj, separators=(',', ':')).encode('ascii')
|
||||
|
||||
def loads(self, data):
|
||||
return json_util.loads(data.decode('ascii'))
|
||||
|
||||
|
||||
@@ -76,7 +76,7 @@ class GridFSStorage(Storage):
|
||||
"""Find the documents in the store with the given name
|
||||
"""
|
||||
docs = self.document.objects
|
||||
doc = [d for d in docs if getattr(d, self.field).name == name]
|
||||
doc = [d for d in docs if hasattr(getattr(d, self.field), 'name') and getattr(d, self.field).name == name]
|
||||
if doc:
|
||||
return doc[0]
|
||||
else:
|
||||
|
||||
@@ -12,7 +12,9 @@ from mongoengine.common import _import_class
|
||||
from mongoengine.base import (DocumentMetaclass, TopLevelDocumentMetaclass,
|
||||
BaseDocument, BaseDict, BaseList,
|
||||
ALLOW_INHERITANCE, get_document)
|
||||
from mongoengine.queryset import OperationError, NotUniqueError, QuerySet
|
||||
from mongoengine.errors import ValidationError
|
||||
from mongoengine.queryset import (OperationError, NotUniqueError,
|
||||
QuerySet, transform)
|
||||
from mongoengine.connection import get_db, DEFAULT_CONNECTION_NAME
|
||||
from mongoengine.context_managers import switch_db, switch_collection
|
||||
|
||||
@@ -67,7 +69,7 @@ class EmbeddedDocument(BaseDocument):
|
||||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, self.__class__):
|
||||
return self._data == other._data
|
||||
return self.to_mongo() == other.to_mongo()
|
||||
return False
|
||||
|
||||
def __ne__(self, other):
|
||||
@@ -182,7 +184,7 @@ class Document(BaseDocument):
|
||||
|
||||
def save(self, force_insert=False, validate=True, clean=True,
|
||||
write_concern=None, cascade=None, cascade_kwargs=None,
|
||||
_refs=None, **kwargs):
|
||||
_refs=None, save_condition=None, **kwargs):
|
||||
"""Save the :class:`~mongoengine.Document` to the database. If the
|
||||
document already exists, it will be updated, otherwise it will be
|
||||
created.
|
||||
@@ -205,7 +207,8 @@ class Document(BaseDocument):
|
||||
:param cascade_kwargs: (optional) kwargs dictionary to be passed throw
|
||||
to cascading saves. Implies ``cascade=True``.
|
||||
:param _refs: A list of processed references used in cascading saves
|
||||
|
||||
:param save_condition: only perform save if matching record in db
|
||||
satisfies condition(s) (e.g., version number)
|
||||
.. versionchanged:: 0.5
|
||||
In existing documents it only saves changed fields using
|
||||
set / unset. Saves are cascaded and any
|
||||
@@ -219,6 +222,9 @@ class Document(BaseDocument):
|
||||
meta['cascade'] = True. Also you can pass different kwargs to
|
||||
the cascade save using cascade_kwargs which overwrites the
|
||||
existing kwargs with custom values.
|
||||
.. versionchanged:: 0.8.5
|
||||
Optional save_condition that only overwrites existing documents
|
||||
if the condition is satisfied in the current db record.
|
||||
"""
|
||||
signals.pre_save.send(self.__class__, document=self)
|
||||
|
||||
@@ -232,7 +238,8 @@ class Document(BaseDocument):
|
||||
|
||||
created = ('_id' not in doc or self._created or force_insert)
|
||||
|
||||
signals.pre_save_post_validation.send(self.__class__, document=self, created=created)
|
||||
signals.pre_save_post_validation.send(self.__class__, document=self,
|
||||
created=created)
|
||||
|
||||
try:
|
||||
collection = self._get_collection()
|
||||
@@ -245,7 +252,12 @@ class Document(BaseDocument):
|
||||
object_id = doc['_id']
|
||||
updates, removals = self._delta()
|
||||
# Need to add shard key to query, or you get an error
|
||||
select_dict = {'_id': object_id}
|
||||
if save_condition is not None:
|
||||
select_dict = transform.query(self.__class__,
|
||||
**save_condition)
|
||||
else:
|
||||
select_dict = {}
|
||||
select_dict['_id'] = object_id
|
||||
shard_key = self.__class__._meta.get('shard_key', tuple())
|
||||
for k in shard_key:
|
||||
actual_key = self._db_field_map.get(k, k)
|
||||
@@ -265,10 +277,12 @@ class Document(BaseDocument):
|
||||
if removals:
|
||||
update_query["$unset"] = removals
|
||||
if updates or removals:
|
||||
upsert = save_condition is None
|
||||
last_error = collection.update(select_dict, update_query,
|
||||
upsert=True, **write_concern)
|
||||
upsert=upsert, **write_concern)
|
||||
created = is_new_object(last_error)
|
||||
|
||||
|
||||
if cascade is None:
|
||||
cascade = self._meta.get('cascade', False) or cascade_kwargs is not None
|
||||
|
||||
@@ -283,7 +297,9 @@ class Document(BaseDocument):
|
||||
kwargs.update(cascade_kwargs)
|
||||
kwargs['_refs'] = _refs
|
||||
self.cascade_save(**kwargs)
|
||||
|
||||
except pymongo.errors.DuplicateKeyError, err:
|
||||
message = u'Tried to save duplicate unique keys (%s)'
|
||||
raise NotUniqueError(message % unicode(err))
|
||||
except pymongo.errors.OperationFailure, err:
|
||||
message = 'Could not save document (%s)'
|
||||
if re.match('^E1100[01] duplicate key', unicode(err)):
|
||||
@@ -453,14 +469,16 @@ class Document(BaseDocument):
|
||||
.. versionadded:: 0.1.2
|
||||
.. versionchanged:: 0.6 Now chainable
|
||||
"""
|
||||
if not self.pk:
|
||||
raise self.DoesNotExist("Document does not exist")
|
||||
obj = self._qs.read_preference(ReadPreference.PRIMARY).filter(
|
||||
**self._object_key).limit(1).select_related(max_depth=max_depth)
|
||||
**self._object_key).limit(1).select_related(max_depth=max_depth)
|
||||
|
||||
|
||||
if obj:
|
||||
obj = obj[0]
|
||||
else:
|
||||
msg = "Reloaded document has been deleted"
|
||||
raise OperationError(msg)
|
||||
raise self.DoesNotExist("Document does not exist")
|
||||
for field in self._fields_ordered:
|
||||
setattr(self, field, self._reload(field, obj[field]))
|
||||
self._changed_fields = obj._changed_fields
|
||||
@@ -550,6 +568,8 @@ class Document(BaseDocument):
|
||||
index_cls = cls._meta.get('index_cls', True)
|
||||
|
||||
collection = cls._get_collection()
|
||||
if collection.read_preference > 1:
|
||||
return
|
||||
|
||||
# determine if an index which we are creating includes
|
||||
# _cls as its first field; if so, we can avoid creating
|
||||
|
||||
@@ -42,7 +42,8 @@ __all__ = ['StringField', 'URLField', 'EmailField', 'IntField', 'LongField',
|
||||
'GenericReferenceField', 'BinaryField', 'GridFSError',
|
||||
'GridFSProxy', 'FileField', 'ImageGridFsProxy',
|
||||
'ImproperlyConfigured', 'ImageField', 'GeoPointField', 'PointField',
|
||||
'LineStringField', 'PolygonField', 'SequenceField', 'UUIDField']
|
||||
'LineStringField', 'PolygonField', 'SequenceField', 'UUIDField',
|
||||
'GeoJsonBaseField']
|
||||
|
||||
|
||||
RECURSIVE_REFERENCE_CONSTANT = 'self'
|
||||
@@ -152,7 +153,7 @@ class EmailField(StringField):
|
||||
EMAIL_REGEX = re.compile(
|
||||
r"(^[-!#$%&'*+/=?^_`{}|~0-9A-Z]+(\.[-!#$%&'*+/=?^_`{}|~0-9A-Z]+)*" # dot-atom
|
||||
r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]|\\[\001-011\013\014\016-\177])*"' # quoted-string
|
||||
r')@(?:[A-Z0-9](?:[A-Z0-9-]{0,253}[A-Z0-9])?\.)+[A-Z]{2,6}\.?$', re.IGNORECASE # domain
|
||||
r')@(?:[A-Z0-9](?:[A-Z0-9-]{0,253}[A-Z0-9])?\.)+[A-Z]{2,6}$', re.IGNORECASE # domain
|
||||
)
|
||||
|
||||
def validate(self, value):
|
||||
@@ -304,7 +305,10 @@ class DecimalField(BaseField):
|
||||
return value
|
||||
|
||||
# Convert to string for python 2.6 before casting to Decimal
|
||||
value = decimal.Decimal("%s" % value)
|
||||
try:
|
||||
value = decimal.Decimal("%s" % value)
|
||||
except decimal.InvalidOperation:
|
||||
return value
|
||||
return value.quantize(self.precision, rounding=self.rounding)
|
||||
|
||||
def to_mongo(self, value):
|
||||
@@ -387,7 +391,7 @@ class DateTimeField(BaseField):
|
||||
if dateutil:
|
||||
try:
|
||||
return dateutil.parser.parse(value)
|
||||
except ValueError:
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
# split usecs, because they are not recognized by strptime.
|
||||
@@ -735,13 +739,28 @@ class SortedListField(ListField):
|
||||
reverse=self._order_reverse)
|
||||
return sorted(value, reverse=self._order_reverse)
|
||||
|
||||
def key_not_string(d):
|
||||
""" Helper function to recursively determine if any key in a dictionary is
|
||||
not a string.
|
||||
"""
|
||||
for k, v in d.items():
|
||||
if not isinstance(k, basestring) or (isinstance(v, dict) and key_not_string(v)):
|
||||
return True
|
||||
|
||||
def key_has_dot_or_dollar(d):
|
||||
""" Helper function to recursively determine if any key in a dictionary
|
||||
contains a dot or a dollar sign.
|
||||
"""
|
||||
for k, v in d.items():
|
||||
if ('.' in k or '$' in k) or (isinstance(v, dict) and key_has_dot_or_dollar(v)):
|
||||
return True
|
||||
|
||||
class DictField(ComplexBaseField):
|
||||
"""A dictionary field that wraps a standard Python dictionary. This is
|
||||
similar to an embedded document, but the structure is not defined.
|
||||
|
||||
.. note::
|
||||
Required means it cannot be empty - as the default for ListFields is []
|
||||
Required means it cannot be empty - as the default for DictFields is {}
|
||||
|
||||
.. versionadded:: 0.3
|
||||
.. versionchanged:: 0.5 - Can now handle complex / varying types of data
|
||||
@@ -761,11 +780,11 @@ class DictField(ComplexBaseField):
|
||||
if not isinstance(value, dict):
|
||||
self.error('Only dictionaries may be used in a DictField')
|
||||
|
||||
if any(k for k in value.keys() if not isinstance(k, basestring)):
|
||||
if key_not_string(value):
|
||||
msg = ("Invalid dictionary key - documents must "
|
||||
"have only string keys")
|
||||
self.error(msg)
|
||||
if any(('.' in k or '$' in k) for k in value.keys()):
|
||||
if key_has_dot_or_dollar(value):
|
||||
self.error('Invalid dictionary key name - keys may not contain "."'
|
||||
' or "$" characters')
|
||||
super(DictField, self).validate(value)
|
||||
@@ -1004,7 +1023,10 @@ class GenericReferenceField(BaseField):
|
||||
id_ = id_field.to_mongo(id_)
|
||||
collection = document._get_collection_name()
|
||||
ref = DBRef(collection, id_)
|
||||
return {'_cls': document._class_name, '_ref': ref}
|
||||
return SON((
|
||||
('_cls', document._class_name),
|
||||
('_ref', ref)
|
||||
))
|
||||
|
||||
def prepare_query_value(self, op, value):
|
||||
if value is None:
|
||||
@@ -1591,7 +1613,12 @@ class UUIDField(BaseField):
|
||||
|
||||
|
||||
class GeoPointField(BaseField):
|
||||
"""A list storing a latitude and longitude.
|
||||
"""A list storing a longitude and latitude coordinate.
|
||||
|
||||
.. note:: this represents a generic point in a 2D plane and a legacy way of
|
||||
representing a geo point. It admits 2d indexes but not "2dsphere" indexes
|
||||
in MongoDB > 2.4 which are more natural for modeling geospatial points.
|
||||
See :ref:`geospatial-indexes`
|
||||
|
||||
.. versionadded:: 0.4
|
||||
"""
|
||||
@@ -1613,7 +1640,7 @@ class GeoPointField(BaseField):
|
||||
|
||||
|
||||
class PointField(GeoJsonBaseField):
|
||||
"""A geo json field storing a latitude and longitude.
|
||||
"""A GeoJSON field storing a longitude and latitude coordinate.
|
||||
|
||||
The data is represented as:
|
||||
|
||||
@@ -1632,7 +1659,7 @@ class PointField(GeoJsonBaseField):
|
||||
|
||||
|
||||
class LineStringField(GeoJsonBaseField):
|
||||
"""A geo json field storing a line of latitude and longitude coordinates.
|
||||
"""A GeoJSON field storing a line of longitude and latitude coordinates.
|
||||
|
||||
The data is represented as:
|
||||
|
||||
@@ -1650,7 +1677,7 @@ class LineStringField(GeoJsonBaseField):
|
||||
|
||||
|
||||
class PolygonField(GeoJsonBaseField):
|
||||
"""A geo json field storing a polygon of latitude and longitude coordinates.
|
||||
"""A GeoJSON field storing a polygon of longitude and latitude coordinates.
|
||||
|
||||
The data is represented as:
|
||||
|
||||
|
||||
@@ -3,8 +3,6 @@
|
||||
import sys
|
||||
|
||||
PY3 = sys.version_info[0] == 3
|
||||
PY25 = sys.version_info[:2] == (2, 5)
|
||||
UNICODE_KWARGS = int(''.join([str(x) for x in sys.version_info[:3]])) > 264
|
||||
|
||||
if PY3:
|
||||
import codecs
|
||||
@@ -29,33 +27,3 @@ else:
|
||||
txt_type = unicode
|
||||
|
||||
str_types = (bin_type, txt_type)
|
||||
|
||||
if PY25:
|
||||
def product(*args, **kwds):
|
||||
pools = map(tuple, args) * kwds.get('repeat', 1)
|
||||
result = [[]]
|
||||
for pool in pools:
|
||||
result = [x + [y] for x in result for y in pool]
|
||||
for prod in result:
|
||||
yield tuple(prod)
|
||||
reduce = reduce
|
||||
else:
|
||||
from itertools import product
|
||||
from functools import reduce
|
||||
|
||||
|
||||
# For use with Python 2.5
|
||||
# converts all keys from unicode to str for d and all nested dictionaries
|
||||
def to_str_keys_recursive(d):
|
||||
if isinstance(d, list):
|
||||
for val in d:
|
||||
if isinstance(val, (dict, list)):
|
||||
to_str_keys_recursive(val)
|
||||
elif isinstance(d, dict):
|
||||
for key, val in d.items():
|
||||
if isinstance(val, (dict, list)):
|
||||
to_str_keys_recursive(val)
|
||||
if isinstance(key, unicode):
|
||||
d[str(key)] = d.pop(key)
|
||||
else:
|
||||
raise ValueError("non list/dict parameter not allowed")
|
||||
|
||||
@@ -10,14 +10,15 @@ import warnings
|
||||
from bson.code import Code
|
||||
from bson import json_util
|
||||
import pymongo
|
||||
import pymongo.errors
|
||||
from pymongo.common import validate_read_preference
|
||||
|
||||
from mongoengine import signals
|
||||
from mongoengine.context_managers import switch_db
|
||||
from mongoengine.common import _import_class
|
||||
from mongoengine.base.common import get_document
|
||||
from mongoengine.errors import (OperationError, NotUniqueError,
|
||||
InvalidQueryError, LookUpError)
|
||||
|
||||
from mongoengine.queryset import transform
|
||||
from mongoengine.queryset.field_list import QueryFieldList
|
||||
from mongoengine.queryset.visitor import Q, QNode
|
||||
@@ -50,7 +51,7 @@ class BaseQuerySet(object):
|
||||
self._initial_query = {}
|
||||
self._where_clause = None
|
||||
self._loaded_fields = QueryFieldList()
|
||||
self._ordering = []
|
||||
self._ordering = None
|
||||
self._snapshot = False
|
||||
self._timeout = True
|
||||
self._class_check = True
|
||||
@@ -154,6 +155,22 @@ class BaseQuerySet(object):
|
||||
def __iter__(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def _has_data(self):
|
||||
""" Retrieves whether cursor has any data. """
|
||||
|
||||
queryset = self.order_by()
|
||||
return False if queryset.first() is None else True
|
||||
|
||||
def __nonzero__(self):
|
||||
""" Avoid to open all records in an if stmt in Py2. """
|
||||
|
||||
return self._has_data()
|
||||
|
||||
def __bool__(self):
|
||||
""" Avoid to open all records in an if stmt in Py3. """
|
||||
|
||||
return self._has_data()
|
||||
|
||||
# Core functions
|
||||
|
||||
def all(self):
|
||||
@@ -302,8 +319,11 @@ class BaseQuerySet(object):
|
||||
signals.pre_bulk_insert.send(self._document, documents=docs)
|
||||
try:
|
||||
ids = self._collection.insert(raw, **write_concern)
|
||||
except pymongo.errors.DuplicateKeyError, err:
|
||||
message = 'Could not save document (%s)';
|
||||
raise NotUniqueError(message % unicode(err))
|
||||
except pymongo.errors.OperationFailure, err:
|
||||
message = 'Could not save document (%s)'
|
||||
message = 'Could not save document (%s)';
|
||||
if re.match('^E1100[01] duplicate key', unicode(err)):
|
||||
# E11000 - duplicate key error index
|
||||
# E11001 - duplicate key on update
|
||||
@@ -331,7 +351,7 @@ class BaseQuerySet(object):
|
||||
:meth:`skip` that has been applied to this cursor into account when
|
||||
getting the count
|
||||
"""
|
||||
if self._limit == 0 and with_limit_and_skip:
|
||||
if self._limit == 0 and with_limit_and_skip or self._none:
|
||||
return 0
|
||||
return self._cursor.count(with_limit_and_skip=with_limit_and_skip)
|
||||
|
||||
@@ -386,7 +406,7 @@ class BaseQuerySet(object):
|
||||
ref_q = document_cls.objects(**{field_name + '__in': self})
|
||||
ref_q_count = ref_q.count()
|
||||
if (doc != document_cls and ref_q_count > 0
|
||||
or (doc == document_cls and ref_q_count > 0)):
|
||||
or (doc == document_cls and ref_q_count > 0)):
|
||||
ref_q.delete(write_concern=write_concern)
|
||||
elif rule == NULLIFY:
|
||||
document_cls.objects(**{field_name + '__in': self}).update(
|
||||
@@ -440,6 +460,8 @@ class BaseQuerySet(object):
|
||||
return result
|
||||
elif result:
|
||||
return result['n']
|
||||
except pymongo.errors.DuplicateKeyError, err:
|
||||
raise NotUniqueError(u'Update failed (%s)' % unicode(err))
|
||||
except pymongo.errors.OperationFailure, err:
|
||||
if unicode(err) == u'multi not coded yet':
|
||||
message = u'update() method requires MongoDB 1.1.3+'
|
||||
@@ -463,6 +485,59 @@ class BaseQuerySet(object):
|
||||
return self.update(
|
||||
upsert=upsert, multi=False, write_concern=write_concern, **update)
|
||||
|
||||
def modify(self, upsert=False, full_response=False, remove=False, new=False, **update):
|
||||
"""Update and return the updated document.
|
||||
|
||||
Returns either the document before or after modification based on `new`
|
||||
parameter. If no documents match the query and `upsert` is false,
|
||||
returns ``None``. If upserting and `new` is false, returns ``None``.
|
||||
|
||||
If the full_response parameter is ``True``, the return value will be
|
||||
the entire response object from the server, including the 'ok' and
|
||||
'lastErrorObject' fields, rather than just the modified document.
|
||||
This is useful mainly because the 'lastErrorObject' document holds
|
||||
information about the command's execution.
|
||||
|
||||
:param upsert: insert if document doesn't exist (default ``False``)
|
||||
:param full_response: return the entire response object from the
|
||||
server (default ``False``)
|
||||
:param remove: remove rather than updating (default ``False``)
|
||||
:param new: return updated rather than original document
|
||||
(default ``False``)
|
||||
:param update: Django-style update keyword arguments
|
||||
|
||||
.. versionadded:: 0.9
|
||||
"""
|
||||
|
||||
if remove and new:
|
||||
raise OperationError("Conflicting parameters: remove and new")
|
||||
|
||||
if not update and not upsert and not remove:
|
||||
raise OperationError("No update parameters, must either update or remove")
|
||||
|
||||
queryset = self.clone()
|
||||
query = queryset._query
|
||||
update = transform.update(queryset._document, **update)
|
||||
sort = queryset._ordering
|
||||
|
||||
try:
|
||||
result = queryset._collection.find_and_modify(
|
||||
query, update, upsert=upsert, sort=sort, remove=remove, new=new,
|
||||
full_response=full_response, **self._cursor_args)
|
||||
except pymongo.errors.DuplicateKeyError, err:
|
||||
raise NotUniqueError(u"Update failed (%s)" % err)
|
||||
except pymongo.errors.OperationFailure, err:
|
||||
raise OperationError(u"Update failed (%s)" % err)
|
||||
|
||||
if full_response:
|
||||
if result["value"] is not None:
|
||||
result["value"] = self._document._from_son(result["value"])
|
||||
else:
|
||||
if result is not None:
|
||||
result = self._document._from_son(result)
|
||||
|
||||
return result
|
||||
|
||||
def with_id(self, object_id):
|
||||
"""Retrieve the object matching the id provided. Uses `object_id` only
|
||||
and raises InvalidQueryError if a filter has been applied. Returns
|
||||
@@ -519,6 +594,19 @@ class BaseQuerySet(object):
|
||||
|
||||
return self
|
||||
|
||||
def using(self, alias):
|
||||
"""This method is for controlling which database the QuerySet will be evaluated against if you are using more than one database.
|
||||
|
||||
:param alias: The database alias
|
||||
|
||||
.. versionadded:: 0.8
|
||||
"""
|
||||
|
||||
with switch_db(self._document, alias) as cls:
|
||||
collection = cls._get_collection()
|
||||
|
||||
return self.clone_into(self.__class__(self._document, collection))
|
||||
|
||||
def clone(self):
|
||||
"""Creates a copy of the current
|
||||
:class:`~mongoengine.queryset.QuerySet`
|
||||
@@ -621,8 +709,15 @@ class BaseQuerySet(object):
|
||||
try:
|
||||
field = self._fields_to_dbfields([field]).pop()
|
||||
finally:
|
||||
return self._dereference(queryset._cursor.distinct(field), 1,
|
||||
name=field, instance=self._document)
|
||||
distinct = self._dereference(queryset._cursor.distinct(field), 1,
|
||||
name=field, instance=self._document)
|
||||
|
||||
# We may need to cast to the correct type eg. ListField(EmbeddedDocumentField)
|
||||
doc_field = getattr(self._document._fields.get(field), "field", None)
|
||||
instance = getattr(doc_field, "document_type", False)
|
||||
if instance:
|
||||
distinct = [instance(**doc) for doc in distinct]
|
||||
return distinct
|
||||
|
||||
def only(self, *fields):
|
||||
"""Load only a subset of this document's fields. ::
|
||||
@@ -850,7 +945,7 @@ class BaseQuerySet(object):
|
||||
:param output: output collection name, if set to 'inline' will try to
|
||||
use :class:`~pymongo.collection.Collection.inline_map_reduce`
|
||||
This can also be a dictionary containing output options
|
||||
see: http://docs.mongodb.org/manual/reference/commands/#mapReduce
|
||||
see: http://docs.mongodb.org/manual/reference/command/mapReduce/#dbcmd.mapReduce
|
||||
:param finalize_f: finalize function, an optional function that
|
||||
performs any post-reduction processing.
|
||||
:param scope: values to insert into map/reduce global scope. Optional.
|
||||
@@ -916,7 +1011,7 @@ class BaseQuerySet(object):
|
||||
mr_args['out'] = output
|
||||
|
||||
results = getattr(queryset._collection, map_reduce_function)(
|
||||
map_f, reduce_f, **mr_args)
|
||||
map_f, reduce_f, **mr_args)
|
||||
|
||||
if map_reduce_function == 'map_reduce':
|
||||
results = results.find()
|
||||
@@ -1179,8 +1274,9 @@ class BaseQuerySet(object):
|
||||
if self._ordering:
|
||||
# Apply query ordering
|
||||
self._cursor_obj.sort(self._ordering)
|
||||
elif self._document._meta['ordering']:
|
||||
# Otherwise, apply the ordering from the document model
|
||||
elif self._ordering is None and self._document._meta['ordering']:
|
||||
# Otherwise, apply the ordering from the document model, unless
|
||||
# it's been explicitly cleared via order_by with no arguments
|
||||
order = self._get_order_by(self._document._meta['ordering'])
|
||||
self._cursor_obj.sort(order)
|
||||
|
||||
@@ -1352,7 +1448,7 @@ class BaseQuerySet(object):
|
||||
for subdoc in subclasses:
|
||||
try:
|
||||
subfield = ".".join(f.db_field for f in
|
||||
subdoc._lookup_field(field.split('.')))
|
||||
subdoc._lookup_field(field.split('.')))
|
||||
ret.append(subfield)
|
||||
found = True
|
||||
break
|
||||
@@ -1382,7 +1478,7 @@ class BaseQuerySet(object):
|
||||
pass
|
||||
key_list.append((key, direction))
|
||||
|
||||
if self._cursor_obj:
|
||||
if self._cursor_obj and key_list:
|
||||
self._cursor_obj.sort(key_list)
|
||||
return key_list
|
||||
|
||||
@@ -1440,6 +1536,7 @@ class BaseQuerySet(object):
|
||||
# type of this field and use the corresponding
|
||||
# .to_python(...)
|
||||
from mongoengine.fields import EmbeddedDocumentField
|
||||
|
||||
obj = self._document
|
||||
for chunk in path.split('.'):
|
||||
obj = getattr(obj, chunk, None)
|
||||
@@ -1450,6 +1547,7 @@ class BaseQuerySet(object):
|
||||
if obj and data is not None:
|
||||
data = obj.to_python(data)
|
||||
return data
|
||||
|
||||
return clean(row)
|
||||
|
||||
def _sub_js_fields(self, code):
|
||||
@@ -1458,6 +1556,7 @@ class BaseQuerySet(object):
|
||||
substituted for the MongoDB name of the field (specified using the
|
||||
:attr:`name` keyword argument in a field's constructor).
|
||||
"""
|
||||
|
||||
def field_sub(match):
|
||||
# Extract just the field name, and look up the field objects
|
||||
field_name = match.group(1).split('.')
|
||||
@@ -1491,4 +1590,4 @@ class BaseQuerySet(object):
|
||||
msg = ("Doc.objects()._ensure_indexes() is deprecated. "
|
||||
"Use Doc.ensure_indexes() instead.")
|
||||
warnings.warn(msg, DeprecationWarning)
|
||||
self._document.__class__.ensure_indexes()
|
||||
self._document.__class__.ensure_indexes()
|
||||
|
||||
@@ -155,3 +155,10 @@ class QuerySetNoCache(BaseQuerySet):
|
||||
queryset = self.clone()
|
||||
queryset.rewind()
|
||||
return queryset
|
||||
|
||||
|
||||
class QuerySetNoDeRef(QuerySet):
|
||||
"""Special no_dereference QuerySet"""
|
||||
|
||||
def __dereference(items, max_depth=1, instance=None, name=None):
|
||||
return items
|
||||
@@ -38,7 +38,7 @@ def query(_doc_cls=None, _field_operation=False, **query):
|
||||
mongo_query.update(value)
|
||||
continue
|
||||
|
||||
parts = key.split('__')
|
||||
parts = key.rsplit('__')
|
||||
indices = [(i, p) for i, p in enumerate(parts) if p.isdigit()]
|
||||
parts = [part for part in parts if not part.isdigit()]
|
||||
# Check for an operator and transform to mongo-style if there is
|
||||
@@ -206,6 +206,10 @@ def update(_doc_cls=None, **update):
|
||||
else:
|
||||
field = cleaned_fields[-1]
|
||||
|
||||
GeoJsonBaseField = _import_class("GeoJsonBaseField")
|
||||
if isinstance(field, GeoJsonBaseField):
|
||||
value = field.to_mongo(value)
|
||||
|
||||
if op in (None, 'set', 'push', 'pull'):
|
||||
if field.required or value is not None:
|
||||
value = field.prepare_query_value(op, value)
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import copy
|
||||
|
||||
from mongoengine.errors import InvalidQueryError
|
||||
from mongoengine.python_support import product, reduce
|
||||
from itertools import product
|
||||
from functools import reduce
|
||||
|
||||
from mongoengine.errors import InvalidQueryError
|
||||
from mongoengine.queryset import transform
|
||||
|
||||
__all__ = ('Q',)
|
||||
|
||||
Reference in New Issue
Block a user