Merge branch 'master' into pr/368

This commit is contained in:
Ross Lawley
2014-06-27 11:32:19 +01:00
76 changed files with 5414 additions and 2203 deletions

View File

@@ -15,7 +15,7 @@ import django
__all__ = (list(document.__all__) + fields.__all__ + connection.__all__ +
list(queryset.__all__) + signals.__all__ + list(errors.__all__))
VERSION = (0, 8, 2)
VERSION = (0, 8, 7)
def get_version():

View File

@@ -1,4 +1,6 @@
import weakref
import functools
import itertools
from mongoengine.common import _import_class
__all__ = ("BaseDict", "BaseList")
@@ -108,6 +110,14 @@ class BaseList(list):
self._mark_as_changed()
return super(BaseList, self).__delitem__(*args, **kwargs)
def __setslice__(self, *args, **kwargs):
self._mark_as_changed()
return super(BaseList, self).__setslice__(*args, **kwargs)
def __delslice__(self, *args, **kwargs):
self._mark_as_changed()
return super(BaseList, self).__delslice__(*args, **kwargs)
def __getstate__(self):
self.instance = None
self._dereferenced = False
@@ -148,3 +158,98 @@ class BaseList(list):
def _mark_as_changed(self):
if hasattr(self._instance, '_mark_as_changed'):
self._instance._mark_as_changed(self._name)
class StrictDict(object):
__slots__ = ()
_special_fields = set(['get', 'pop', 'iteritems', 'items', 'keys', 'create'])
_classes = {}
def __init__(self, **kwargs):
for k,v in kwargs.iteritems():
setattr(self, k, v)
def __getitem__(self, key):
key = '_reserved_' + key if key in self._special_fields else key
try:
return getattr(self, key)
except AttributeError:
raise KeyError(key)
def __setitem__(self, key, value):
key = '_reserved_' + key if key in self._special_fields else key
return setattr(self, key, value)
def __contains__(self, key):
return hasattr(self, key)
def get(self, key, default=None):
try:
return self[key]
except KeyError:
return default
def pop(self, key, default=None):
v = self.get(key, default)
try:
delattr(self, key)
except AttributeError:
pass
return v
def iteritems(self):
for key in self:
yield key, self[key]
def items(self):
return [(k, self[k]) for k in iter(self)]
def keys(self):
return list(iter(self))
def __iter__(self):
return (key for key in self.__slots__ if hasattr(self, key))
def __len__(self):
return len(list(self.iteritems()))
def __eq__(self, other):
return self.items() == other.items()
def __neq__(self, other):
return self.items() != other.items()
@classmethod
def create(cls, allowed_keys):
allowed_keys_tuple = tuple(('_reserved_' + k if k in cls._special_fields else k) for k in allowed_keys)
allowed_keys = frozenset(allowed_keys_tuple)
if allowed_keys not in cls._classes:
class SpecificStrictDict(cls):
__slots__ = allowed_keys_tuple
cls._classes[allowed_keys] = SpecificStrictDict
return cls._classes[allowed_keys]
class SemiStrictDict(StrictDict):
__slots__ = ('_extras')
_classes = {}
def __getattr__(self, attr):
try:
super(SemiStrictDict, self).__getattr__(attr)
except AttributeError:
try:
return self.__getattribute__('_extras')[attr]
except KeyError as e:
raise AttributeError(e)
def __setattr__(self, attr, value):
try:
super(SemiStrictDict, self).__setattr__(attr, value)
except AttributeError:
try:
self._extras[attr] = value
except AttributeError:
self._extras = {attr: value}
def __delattr__(self, attr):
try:
super(SemiStrictDict, self).__delattr__(attr)
except AttributeError:
try:
del self._extras[attr]
except KeyError as e:
raise AttributeError(e)
def __iter__(self):
try:
extras_iter = iter(self.__getattribute__('_extras'))
except AttributeError:
extras_iter = ()
return itertools.chain(super(SemiStrictDict, self).__iter__(), extras_iter)

View File

@@ -1,10 +1,11 @@
import copy
import operator
import numbers
from collections import Hashable
from functools import partial
import pymongo
from bson import json_util
from bson import json_util, ObjectId
from bson.dbref import DBRef
from bson.son import SON
@@ -12,24 +13,23 @@ from mongoengine import signals
from mongoengine.common import _import_class
from mongoengine.errors import (ValidationError, InvalidDocumentError,
LookUpError)
from mongoengine.python_support import (PY3, UNICODE_KWARGS, txt_type,
to_str_keys_recursive)
from mongoengine.python_support import PY3, txt_type
from mongoengine.base.common import get_document, ALLOW_INHERITANCE
from mongoengine.base.datastructures import BaseDict, BaseList
from mongoengine.base.datastructures import BaseDict, BaseList, StrictDict, SemiStrictDict
from mongoengine.base.fields import ComplexBaseField
__all__ = ('BaseDocument', 'NON_FIELD_ERRORS')
NON_FIELD_ERRORS = '__all__'
class BaseDocument(object):
__slots__ = ('_changed_fields', '_initialised', '_created', '_data',
'_dynamic_fields', '_auto_id_field', '_db_field_map', '_cls', '__weakref__')
_dynamic = False
_created = True
_dynamic_lock = True
_initialised = False
STRICT = False
def __init__(self, *args, **values):
"""
@@ -38,10 +38,15 @@ class BaseDocument(object):
:param __auto_convert: Try and will cast python objects to Object types
:param values: A dictionary of values for the document
"""
self._initialised = False
self._created = True
if args:
# Combine positional arguments with named arguments.
# We only want named arguments.
field = iter(self._fields_ordered)
# If its an automatic id field then skip to the first defined field
if self._auto_id_field:
next(field)
for value in args:
name = next(field)
if name in values:
@@ -50,7 +55,12 @@ class BaseDocument(object):
__auto_convert = values.pop("__auto_convert", True)
signals.pre_init.send(self.__class__, document=self, values=values)
self._data = {}
if self.STRICT and not self._dynamic:
self._data = StrictDict.create(allowed_keys=self._fields.keys())()
else:
self._data = SemiStrictDict.create(allowed_keys=self._fields.keys())()
self._dynamic_fields = SON()
# Assign default values to instance
for key, field in self._fields.iteritems():
@@ -61,7 +71,6 @@ class BaseDocument(object):
# Set passed values after initialisation
if self._dynamic:
self._dynamic_fields = {}
dynamic_data = {}
for key, value in values.iteritems():
if key in self._fields or key == '_id':
@@ -116,6 +125,7 @@ class BaseDocument(object):
field = DynamicField(db_field=name)
field.name = name
self._dynamic_fields[name] = field
self._fields_ordered += (name,)
if not name.startswith('_'):
value = self.__expand_dynamic_values(name, value)
@@ -125,24 +135,33 @@ class BaseDocument(object):
self._data[name] = value
if hasattr(self, '_changed_fields'):
self._mark_as_changed(name)
try:
self__created = self._created
except AttributeError:
self__created = True
if (self._is_document and not self._created and
if (self._is_document and not self__created and
name in self._meta.get('shard_key', tuple()) and
self._data.get(name) != value):
OperationError = _import_class('OperationError')
msg = "Shard Keys are immutable. Tried to update %s" % name
raise OperationError(msg)
try:
self__initialised = self._initialised
except AttributeError:
self__initialised = False
# Check if the user has created a new instance of a class
if (self._is_document and self._initialised
and self._created and name == self._meta['id_field']):
if (self._is_document and self__initialised
and self__created and name == self._meta['id_field']):
super(BaseDocument, self).__setattr__('_created', False)
super(BaseDocument, self).__setattr__(name, value)
def __getstate__(self):
data = {}
for k in ('_changed_fields', '_initialised', '_created'):
for k in ('_changed_fields', '_initialised', '_created',
'_dynamic_fields', '_fields_ordered'):
if hasattr(self, k):
data[k] = getattr(self, k)
data['_data'] = self.to_mongo()
@@ -151,21 +170,24 @@ class BaseDocument(object):
def __setstate__(self, data):
if isinstance(data["_data"], SON):
data["_data"] = self.__class__._from_son(data["_data"])._data
for k in ('_changed_fields', '_initialised', '_created', '_data'):
for k in ('_changed_fields', '_initialised', '_created', '_data',
'_dynamic_fields'):
if k in data:
setattr(self, k, data[k])
if '_fields_ordered' in data:
setattr(type(self), '_fields_ordered', data['_fields_ordered'])
dynamic_fields = data.get('_dynamic_fields') or SON()
for k in dynamic_fields.keys():
setattr(self, k, data["_data"].get(k))
def __iter__(self):
if 'id' in self._fields and 'id' not in self._fields_ordered:
return iter(('id', ) + self._fields_ordered)
return iter(self._fields_ordered)
def __getitem__(self, name):
"""Dictionary-style field access, return a field's value if present.
"""
try:
if name in self._fields:
if name in self._fields_ordered:
return getattr(self, name)
except AttributeError:
pass
@@ -175,7 +197,7 @@ class BaseDocument(object):
"""Dictionary-style field access, set a field's value.
"""
# Ensure that the field exists before settings its value
if name not in self._fields:
if not self._dynamic and name not in self._fields:
raise KeyError(name)
return setattr(self, name, value)
@@ -241,6 +263,8 @@ class BaseDocument(object):
for field_name in self:
value = self._data.get(field_name, None)
field = self._fields.get(field_name)
if field is None and self._dynamic:
field = self._dynamic_fields.get(field_name)
if value is not None:
value = field.to_mongo(value)
@@ -254,8 +278,10 @@ class BaseDocument(object):
data[field.db_field] = value
# If "_id" has not been set, then try and set it
if data["_id"] is None:
data["_id"] = self._data.get("id", None)
Document = _import_class("Document")
if isinstance(self, Document):
if data["_id"] is None:
data["_id"] = self._data.get("id", None)
if data['_id'] is None:
data.pop('_id')
@@ -265,15 +291,6 @@ class BaseDocument(object):
not self._meta.get('allow_inheritance', ALLOW_INHERITANCE)):
data.pop('_cls')
if not self._dynamic:
return data
# Sort dynamic fields by key
dynamic_fields = sorted(self._dynamic_fields.iteritems(),
key=operator.itemgetter(0))
for name, field in dynamic_fields:
data[name] = field.to_mongo(self._data.get(name, None))
return data
def validate(self, clean=True):
@@ -289,11 +306,8 @@ class BaseDocument(object):
errors[NON_FIELD_ERRORS] = error
# Get a list of tuples of field names and their current values
fields = [(field, self._data.get(name))
for name, field in self._fields.items()]
if self._dynamic:
fields += [(field, self._data.get(name))
for name, field in self._dynamic_fields.items()]
fields = [(self._fields.get(name, self._dynamic_fields.get(name)),
self._data.get(name)) for name in self._fields_ordered]
EmbeddedDocumentField = _import_class("EmbeddedDocumentField")
GenericEmbeddedDocumentField = _import_class("GenericEmbeddedDocumentField")
@@ -318,14 +332,14 @@ class BaseDocument(object):
pk = "None"
if hasattr(self, 'pk'):
pk = self.pk
elif self._instance:
elif self._instance and hasattr(self._instance, 'pk'):
pk = self._instance.pk
message = "ValidationError (%s:%s) " % (self._class_name, pk)
raise ValidationError(message, errors=errors)
def to_json(self):
def to_json(self, *args, **kwargs):
"""Converts a document to JSON"""
return json_util.dumps(self.to_mongo())
return json_util.dumps(self.to_mongo(), *args, **kwargs)
@classmethod
def from_json(cls, json_data):
@@ -377,70 +391,82 @@ class BaseDocument(object):
self._changed_fields.append(key)
def _clear_changed_fields(self):
"""Using get_changed_fields iterate and remove any fields that are
marked as changed"""
for changed in self._get_changed_fields():
parts = changed.split(".")
data = self
for part in parts:
if isinstance(data, list):
try:
data = data[int(part)]
except IndexError:
data = None
elif isinstance(data, dict):
data = data.get(part, None)
else:
data = getattr(data, part, None)
if hasattr(data, "_changed_fields"):
if hasattr(data, "_is_document") and data._is_document:
continue
data._changed_fields = []
self._changed_fields = []
EmbeddedDocumentField = _import_class("EmbeddedDocumentField")
for field_name, field in self._fields.iteritems():
if (isinstance(field, ComplexBaseField) and
isinstance(field.field, EmbeddedDocumentField)):
field_value = getattr(self, field_name, None)
if field_value:
for idx in (field_value if isinstance(field_value, dict)
else xrange(len(field_value))):
field_value[idx]._clear_changed_fields()
elif isinstance(field, EmbeddedDocumentField):
field_value = getattr(self, field_name, None)
if field_value:
field_value._clear_changed_fields()
def _nestable_types_changed_fields(self, changed_fields, key, data, inspected):
# Loop list / dict fields as they contain documents
# Determine the iterator to use
if not hasattr(data, 'items'):
iterator = enumerate(data)
else:
iterator = data.iteritems()
for index, value in iterator:
list_key = "%s%s." % (key, index)
if hasattr(value, '_get_changed_fields'):
changed = value._get_changed_fields(inspected)
changed_fields += ["%s%s" % (list_key, k)
for k in changed if k]
elif isinstance(value, (list, tuple, dict)):
self._nestable_types_changed_fields(changed_fields, list_key, value, inspected)
def _get_changed_fields(self, inspected=None):
"""Returns a list of all fields that have explicitly been changed.
"""
EmbeddedDocument = _import_class("EmbeddedDocument")
DynamicEmbeddedDocument = _import_class("DynamicEmbeddedDocument")
_changed_fields = []
_changed_fields += getattr(self, '_changed_fields', [])
ReferenceField = _import_class("ReferenceField")
changed_fields = []
changed_fields += getattr(self, '_changed_fields', [])
inspected = inspected or set()
if hasattr(self, 'id'):
if hasattr(self, 'id') and isinstance(self.id, Hashable):
if self.id in inspected:
return _changed_fields
return changed_fields
inspected.add(self.id)
field_list = self._fields.copy()
if self._dynamic:
field_list.update(self._dynamic_fields)
for field_name in field_list:
for field_name in self._fields_ordered:
db_field_name = self._db_field_map.get(field_name, field_name)
key = '%s.' % db_field_name
field = self._data.get(field_name, None)
if hasattr(field, 'id'):
if field.id in inspected:
continue
inspected.add(field.id)
data = self._data.get(field_name, None)
field = self._fields.get(field_name)
if (isinstance(field, (EmbeddedDocument, DynamicEmbeddedDocument))
and db_field_name not in _changed_fields):
if hasattr(data, 'id'):
if data.id in inspected:
continue
inspected.add(data.id)
if isinstance(field, ReferenceField):
continue
elif (isinstance(data, (EmbeddedDocument, DynamicEmbeddedDocument))
and db_field_name not in changed_fields):
# Find all embedded fields that have been changed
changed = field._get_changed_fields(inspected)
_changed_fields += ["%s%s" % (key, k) for k in changed if k]
elif (isinstance(field, (list, tuple, dict)) and
db_field_name not in _changed_fields):
# Loop list / dict fields as they contain documents
# Determine the iterator to use
if not hasattr(field, 'items'):
iterator = enumerate(field)
else:
iterator = field.iteritems()
for index, value in iterator:
if not hasattr(value, '_get_changed_fields'):
continue
list_key = "%s%s." % (key, index)
changed = value._get_changed_fields(inspected)
_changed_fields += ["%s%s" % (list_key, k)
for k in changed if k]
return _changed_fields
changed = data._get_changed_fields(inspected)
changed_fields += ["%s%s" % (key, k) for k in changed if k]
elif (isinstance(data, (list, tuple, dict)) and
db_field_name not in changed_fields):
if (hasattr(field, 'field') and
isinstance(field.field, ReferenceField)):
continue
self._nestable_types_changed_fields(changed_fields, key, data, inspected)
return changed_fields
def _delta(self):
"""Returns the delta (set, unset) of the changes for a document.
@@ -450,7 +476,6 @@ class BaseDocument(object):
doc = self.to_mongo()
set_fields = self._get_changed_fields()
set_data = {}
unset_data = {}
parts = []
if hasattr(self, '_changed_fields'):
@@ -461,7 +486,7 @@ class BaseDocument(object):
d = doc
new_path = []
for p in parts:
if isinstance(d, DBRef):
if isinstance(d, (ObjectId, DBRef)):
break
elif isinstance(d, list) and p.isdigit():
d = d[int(p)]
@@ -537,10 +562,6 @@ class BaseDocument(object):
# class if unavailable
class_name = son.get('_cls', cls._class_name)
data = dict(("%s" % key, value) for key, value in son.iteritems())
if not UNICODE_KWARGS:
# python 2.6.4 and lower cannot handle unicode keys
# passed to class constructor example: cls(**data)
to_str_keys_recursive(data)
# Return correct subclass for document type
if class_name != cls._class_name:
@@ -578,6 +599,8 @@ class BaseDocument(object):
% (cls._class_name, errors))
raise InvalidDocumentError(msg)
if cls.STRICT:
data = dict((k, v) for k,v in data.iteritems() if k in cls._fields)
obj = cls(__auto_convert=False, **data)
obj._changed_fields = changed_fields
obj._created = False
@@ -630,8 +653,10 @@ class BaseDocument(object):
# Check to see if we need to include _cls
allow_inheritance = cls._meta.get('allow_inheritance',
ALLOW_INHERITANCE)
include_cls = allow_inheritance and not spec.get('sparse', False)
include_cls = (allow_inheritance and not spec.get('sparse', False) and
spec.get('cls', True))
if "cls" in spec:
spec.pop('cls')
for key in spec['fields']:
# If inherited spec continue
if isinstance(key, (list, tuple)):
@@ -754,6 +779,9 @@ class BaseDocument(object):
"""Lookup a field based on its attribute and return a list containing
the field's parents and the field.
"""
ListField = _import_class("ListField")
if not isinstance(parts, (list, tuple)):
parts = [parts]
fields = []
@@ -761,7 +789,7 @@ class BaseDocument(object):
for field_name in parts:
# Handle ListField indexing:
if field_name.isdigit():
if field_name.isdigit() and isinstance(field, ListField):
new_field = field.field
fields.append(field_name)
continue
@@ -812,7 +840,11 @@ class BaseDocument(object):
"""Dynamically set the display value for a field with choices"""
for attr_name, field in self._fields.items():
if field.choices:
setattr(self,
if self._dynamic:
obj = self
else:
obj = type(self)
setattr(obj,
'get_%s_display' % attr_name,
partial(self.__get_field_display, field=field))

View File

@@ -89,12 +89,7 @@ class BaseField(object):
return self
# Get value from document instance if available
value = instance._data.get(self.name)
EmbeddedDocument = _import_class('EmbeddedDocument')
if isinstance(value, EmbeddedDocument) and value._instance is None:
value._instance = weakref.proxy(instance)
return value
return instance._data.get(self.name)
def __set__(self, instance, value):
"""Descriptor for assigning a value to a field in a document.
@@ -116,6 +111,10 @@ class BaseField(object):
# Values cant be compared eg: naive and tz datetimes
# So mark it as changed
instance._mark_as_changed(self.name)
EmbeddedDocument = _import_class('EmbeddedDocument')
if isinstance(value, EmbeddedDocument) and value._instance is None:
value._instance = weakref.proxy(instance)
instance._data[self.name] = value
def error(self, message="", errors=None, field_name=None):
@@ -186,7 +185,6 @@ class ComplexBaseField(BaseField):
"""
field = None
__dereference = False
def __get__(self, instance, owner):
"""Descriptor to automatically dereference references.
@@ -201,9 +199,11 @@ class ComplexBaseField(BaseField):
(self.field is None or isinstance(self.field,
(GenericReferenceField, ReferenceField))))
_dereference = _import_class("DeReference")()
self._auto_dereference = instance._fields[self.name]._auto_dereference
if not self.__dereference and instance._initialised and dereference:
instance._data[self.name] = self._dereference(
if instance._initialised and dereference and instance._data.get(self.name):
instance._data[self.name] = _dereference(
instance._data.get(self.name), max_depth=1, instance=instance,
name=self.name
)
@@ -222,7 +222,7 @@ class ComplexBaseField(BaseField):
if (self._auto_dereference and instance._initialised and
isinstance(value, (BaseList, BaseDict))
and not value._dereferenced):
value = self._dereference(
value = _dereference(
value, max_depth=1, instance=instance, name=self.name
)
value._dereferenced = True
@@ -382,13 +382,6 @@ class ComplexBaseField(BaseField):
owner_document = property(_get_owner_document, _set_owner_document)
@property
def _dereference(self,):
if not self.__dereference:
DeReference = _import_class("DeReference")
self.__dereference = DeReference() # Cached
return self.__dereference
class ObjectIdField(BaseField):
"""A field wrapper around MongoDB's ObjectIds.

View File

@@ -91,11 +91,12 @@ class DocumentMetaclass(type):
attrs['_fields'] = doc_fields
attrs['_db_field_map'] = dict([(k, getattr(v, 'db_field', k))
for k, v in doc_fields.iteritems()])
attrs['_reverse_db_field_map'] = dict(
(v, k) for k, v in attrs['_db_field_map'].iteritems())
attrs['_fields_ordered'] = tuple(i[1] for i in sorted(
(v.creation_counter, v.name)
for v in doc_fields.itervalues()))
attrs['_reverse_db_field_map'] = dict(
(v, k) for k, v in attrs['_db_field_map'].iteritems())
#
# Set document hierarchy
@@ -358,12 +359,19 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
new_class.id = field
# Set primary key if not defined by the document
new_class._auto_id_field = getattr(parent_doc_cls,
'_auto_id_field', False)
if not new_class._meta.get('id_field'):
new_class._auto_id_field = True
new_class._meta['id_field'] = 'id'
new_class._fields['id'] = ObjectIdField(db_field='_id')
new_class._fields['id'].name = 'id'
new_class.id = new_class._fields['id']
# Prepend id field to _fields_ordered
if 'id' in new_class._fields and 'id' not in new_class._fields_ordered:
new_class._fields_ordered = ('id', ) + new_class._fields_ordered
# Merge in exceptions with parent hierarchy
exceptions_to_merge = (DoesNotExist, MultipleObjectsReturned)
module = attrs.get('__module__')

View File

@@ -23,8 +23,9 @@ def _import_class(cls_name):
field_classes = ('DictField', 'DynamicField', 'EmbeddedDocumentField',
'FileField', 'GenericReferenceField',
'GenericEmbeddedDocumentField', 'GeoPointField',
'PointField', 'LineStringField', 'PolygonField',
'ReferenceField', 'StringField', 'ComplexBaseField')
'PointField', 'LineStringField', 'ListField',
'PolygonField', 'ReferenceField', 'StringField',
'ComplexBaseField', 'GeoJsonBaseField')
queryset_classes = ('OperationError',)
deref_classes = ('DeReference',)

View File

@@ -18,7 +18,7 @@ _connections = {}
_dbs = {}
def register_connection(alias, name, host='localhost', port=27017,
def register_connection(alias, name, host=None, port=None,
is_slave=False, read_preference=False, slaves=None,
username=None, password=None, **kwargs):
"""Add a connection.
@@ -43,8 +43,8 @@ def register_connection(alias, name, host='localhost', port=27017,
conn_settings = {
'name': name,
'host': host,
'port': port,
'host': host or 'localhost',
'port': port or 27017,
'is_slave': is_slave,
'slaves': slaves or [],
'username': username,
@@ -53,19 +53,15 @@ def register_connection(alias, name, host='localhost', port=27017,
}
# Handle uri style connections
if "://" in host:
uri_dict = uri_parser.parse_uri(host)
if uri_dict.get('database') is None:
raise ConnectionError("If using URI style connection include "\
"database name in string")
if "://" in conn_settings['host']:
uri_dict = uri_parser.parse_uri(conn_settings['host'])
conn_settings.update({
'host': host,
'name': uri_dict.get('database'),
'name': uri_dict.get('database') or name,
'username': uri_dict.get('username'),
'password': uri_dict.get('password'),
'read_preference': read_preference,
})
if "replicaSet" in host:
if "replicaSet" in conn_settings['host']:
conn_settings['replicaSet'] = True
conn_settings.update(kwargs)
@@ -97,20 +93,11 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
raise ConnectionError(msg)
conn_settings = _connection_settings[alias].copy()
if hasattr(pymongo, 'version_tuple'): # Support for 2.1+
conn_settings.pop('name', None)
conn_settings.pop('slaves', None)
conn_settings.pop('is_slave', None)
conn_settings.pop('username', None)
conn_settings.pop('password', None)
else:
# Get all the slave connections
if 'slaves' in conn_settings:
slaves = []
for slave_alias in conn_settings['slaves']:
slaves.append(get_connection(slave_alias))
conn_settings['slaves'] = slaves
conn_settings.pop('read_preference', None)
conn_settings.pop('name', None)
conn_settings.pop('slaves', None)
conn_settings.pop('is_slave', None)
conn_settings.pop('username', None)
conn_settings.pop('password', None)
connection_class = MongoClient
if 'replicaSet' in conn_settings:
@@ -123,7 +110,19 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
connection_class = MongoReplicaSetClient
try:
_connections[alias] = connection_class(**conn_settings)
connection = None
connection_settings_iterator = ((alias, settings.copy()) for alias, settings in _connection_settings.iteritems())
for alias, connection_settings in connection_settings_iterator:
connection_settings.pop('name', None)
connection_settings.pop('slaves', None)
connection_settings.pop('is_slave', None)
connection_settings.pop('username', None)
connection_settings.pop('password', None)
if conn_settings == connection_settings and _connections.get(alias, None):
connection = _connections[alias]
break
_connections[alias] = connection if connection else connection_class(**conn_settings)
except Exception, e:
raise ConnectionError("Cannot connect to database %s :\n%s" % (alias, e))
return _connections[alias]

View File

@@ -1,6 +1,5 @@
from mongoengine.common import _import_class
from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db
from mongoengine.queryset import QuerySet
__all__ = ("switch_db", "switch_collection", "no_dereference",
@@ -162,12 +161,6 @@ class no_sub_classes(object):
return self.cls
class QuerySetNoDeRef(QuerySet):
"""Special no_dereference QuerySet"""
def __dereference(items, max_depth=1, instance=None, name=None):
return items
class query_counter(object):
""" Query_counter context manager to get the number of queries. """

View File

@@ -4,7 +4,7 @@ from base import (BaseDict, BaseList, TopLevelDocumentMetaclass, get_document)
from fields import (ReferenceField, ListField, DictField, MapField)
from connection import get_db
from queryset import QuerySet
from document import Document
from document import Document, EmbeddedDocument
class DeReference(object):
@@ -33,7 +33,8 @@ class DeReference(object):
self.max_depth = max_depth
doc_type = None
if instance and isinstance(instance, (Document, TopLevelDocumentMetaclass)):
if instance and isinstance(instance, (Document, EmbeddedDocument,
TopLevelDocumentMetaclass)):
doc_type = instance._fields.get(name)
while hasattr(doc_type, 'field'):
doc_type = doc_type.field

View File

@@ -8,6 +8,10 @@ from django.contrib import auth
from django.contrib.auth.models import AnonymousUser
from django.utils.translation import ugettext_lazy as _
from .utils import datetime_now
REDIRECT_FIELD_NAME = 'next'
try:
from django.contrib.auth.hashers import check_password, make_password
except ImportError:
@@ -33,10 +37,6 @@ except ImportError:
hash = get_hexdigest(algo, salt, raw_password)
return '%s$%s$%s' % (algo, salt, hash)
from .utils import datetime_now
REDIRECT_FIELD_NAME = 'next'
class ContentType(Document):
name = StringField(max_length=100)
@@ -230,6 +230,9 @@ class User(Document):
date_joined = DateTimeField(default=datetime_now,
verbose_name=_('date joined'))
user_permissions = ListField(ReferenceField(Permission), verbose_name=_('user permissions'),
help_text=_('Permissions for the user.'))
USERNAME_FIELD = 'username'
REQUIRED_FIELDS = ['email']
@@ -378,9 +381,10 @@ class MongoEngineBackend(object):
supports_object_permissions = False
supports_anonymous_user = False
supports_inactive_user = False
_user_doc = False
def authenticate(self, username=None, password=None):
user = User.objects(username=username).first()
user = self.user_document.objects(username=username).first()
if user:
if password and user.check_password(password):
backend = auth.get_backends()[0]
@@ -389,8 +393,14 @@ class MongoEngineBackend(object):
return None
def get_user(self, user_id):
return User.objects.with_id(user_id)
return self.user_document.objects.with_id(user_id)
@property
def user_document(self):
if self._user_doc is False:
from .mongo_auth.models import get_user_document
self._user_doc = get_user_document()
return self._user_doc
def get_user(userid):
"""Returns a User object from an id (User.id). Django's equivalent takes

View File

@@ -1,4 +1,5 @@
from django.conf import settings
from django.contrib.auth.hashers import make_password
from django.contrib.auth.models import UserManager
from django.core.exceptions import ImproperlyConfigured
from django.db import models
@@ -6,10 +7,29 @@ from django.utils.importlib import import_module
from django.utils.translation import ugettext_lazy as _
__all__ = (
'get_user_document',
)
MONGOENGINE_USER_DOCUMENT = getattr(
settings, 'MONGOENGINE_USER_DOCUMENT', 'mongoengine.django.auth.User')
def get_user_document():
"""Get the user document class used for authentication.
This is the class defined in settings.MONGOENGINE_USER_DOCUMENT, which
defaults to `mongoengine.django.auth.User`.
"""
name = MONGOENGINE_USER_DOCUMENT
dot = name.rindex('.')
module = import_module(name[:dot])
return getattr(module, name[dot + 1:])
class MongoUserManager(UserManager):
"""A User manager wich allows the use of MongoEngine documents in Django.
@@ -44,7 +64,7 @@ class MongoUserManager(UserManager):
def contribute_to_class(self, model, name):
super(MongoUserManager, self).contribute_to_class(model, name)
self.dj_model = self.model
self.model = self._get_user_document()
self.model = get_user_document()
self.dj_model.USERNAME_FIELD = self.model.USERNAME_FIELD
username = models.CharField(_('username'), max_length=30, unique=True)
@@ -55,16 +75,6 @@ class MongoUserManager(UserManager):
field = models.CharField(_(name), max_length=30)
field.contribute_to_class(self.dj_model, name)
def _get_user_document(self):
try:
name = MONGOENGINE_USER_DOCUMENT
dot = name.rindex('.')
module = import_module(name[:dot])
return getattr(module, name[dot + 1:])
except ImportError:
raise ImproperlyConfigured("Error importing %s, please check "
"settings.MONGOENGINE_USER_DOCUMENT"
% name)
def get(self, *args, **kwargs):
try:
@@ -85,5 +95,21 @@ class MongoUserManager(UserManager):
class MongoUser(models.Model):
""""Dummy user model for Django.
MongoUser is used to replace Django's UserManager with MongoUserManager.
The actual user document class is mongoengine.django.auth.User or any
other document class specified in MONGOENGINE_USER_DOCUMENT.
To get the user document class, use `get_user_document()`.
"""
objects = MongoUserManager()
class Meta:
app_label = 'mongo_auth'
def set_password(self, password):
"""Doesn't do anything, but works around the issue with Django 1.6."""
make_password(password)

View File

@@ -1,7 +1,11 @@
from bson import json_util
from django.conf import settings
from django.contrib.sessions.backends.base import SessionBase, CreateError
from django.core.exceptions import SuspiciousOperation
from django.utils.encoding import force_unicode
try:
from django.utils.encoding import force_unicode
except ImportError:
from django.utils.encoding import force_text as force_unicode
from mongoengine.document import Document
from mongoengine import fields
@@ -52,6 +56,12 @@ class SessionStore(SessionBase):
"""A MongoEngine-based session store for Django.
"""
def _get_session(self, *args, **kwargs):
sess = super(SessionStore, self)._get_session(*args, **kwargs)
if sess.get('_auth_user_id', None):
sess['_auth_user_id'] = str(sess.get('_auth_user_id'))
return sess
def load(self):
try:
s = MongoSession.objects(session_key=self.session_key,
@@ -100,3 +110,15 @@ class SessionStore(SessionBase):
return
session_key = self.session_key
MongoSession.objects(session_key=session_key).delete()
class BSONSerializer(object):
"""
Serializer that can handle BSON types (eg ObjectId).
"""
def dumps(self, obj):
return json_util.dumps(obj, separators=(',', ':')).encode('ascii')
def loads(self, data):
return json_util.loads(data.decode('ascii'))

View File

@@ -76,7 +76,7 @@ class GridFSStorage(Storage):
"""Find the documents in the store with the given name
"""
docs = self.document.objects
doc = [d for d in docs if getattr(d, self.field).name == name]
doc = [d for d in docs if hasattr(getattr(d, self.field), 'name') and getattr(d, self.field).name == name]
if doc:
return doc[0]
else:

View File

@@ -12,7 +12,9 @@ from mongoengine.common import _import_class
from mongoengine.base import (DocumentMetaclass, TopLevelDocumentMetaclass,
BaseDocument, BaseDict, BaseList,
ALLOW_INHERITANCE, get_document)
from mongoengine.queryset import OperationError, NotUniqueError, QuerySet
from mongoengine.errors import ValidationError
from mongoengine.queryset import (OperationError, NotUniqueError,
QuerySet, transform)
from mongoengine.connection import get_db, DEFAULT_CONNECTION_NAME
from mongoengine.context_managers import switch_db, switch_collection
@@ -53,20 +55,21 @@ class EmbeddedDocument(BaseDocument):
dictionary.
"""
__slots__ = ('_instance')
# The __metaclass__ attribute is removed by 2to3 when running with Python3
# my_metaclass is defined so that metaclass can be queried in Python 2 & 3
my_metaclass = DocumentMetaclass
__metaclass__ = DocumentMetaclass
_instance = None
def __init__(self, *args, **kwargs):
super(EmbeddedDocument, self).__init__(*args, **kwargs)
self._instance = None
self._changed_fields = []
def __eq__(self, other):
if isinstance(other, self.__class__):
return self._data == other._data
return self.to_mongo() == other.to_mongo()
return False
def __ne__(self, other):
@@ -124,6 +127,8 @@ class Document(BaseDocument):
my_metaclass = TopLevelDocumentMetaclass
__metaclass__ = TopLevelDocumentMetaclass
__slots__ = ('__objects' )
def pk():
"""Primary key alias
"""
@@ -179,7 +184,7 @@ class Document(BaseDocument):
def save(self, force_insert=False, validate=True, clean=True,
write_concern=None, cascade=None, cascade_kwargs=None,
_refs=None, **kwargs):
_refs=None, save_condition=None, **kwargs):
"""Save the :class:`~mongoengine.Document` to the database. If the
document already exists, it will be updated, otherwise it will be
created.
@@ -202,7 +207,8 @@ class Document(BaseDocument):
:param cascade_kwargs: (optional) kwargs dictionary to be passed throw
to cascading saves. Implies ``cascade=True``.
:param _refs: A list of processed references used in cascading saves
:param save_condition: only perform save if matching record in db
satisfies condition(s) (e.g., version number)
.. versionchanged:: 0.5
In existing documents it only saves changed fields using
set / unset. Saves are cascaded and any
@@ -216,6 +222,9 @@ class Document(BaseDocument):
meta['cascade'] = True. Also you can pass different kwargs to
the cascade save using cascade_kwargs which overwrites the
existing kwargs with custom values.
.. versionchanged:: 0.8.5
Optional save_condition that only overwrites existing documents
if the condition is satisfied in the current db record.
"""
signals.pre_save.send(self.__class__, document=self)
@@ -229,7 +238,8 @@ class Document(BaseDocument):
created = ('_id' not in doc or self._created or force_insert)
signals.pre_save_post_validation.send(self.__class__, document=self, created=created)
signals.pre_save_post_validation.send(self.__class__, document=self,
created=created)
try:
collection = self._get_collection()
@@ -242,7 +252,12 @@ class Document(BaseDocument):
object_id = doc['_id']
updates, removals = self._delta()
# Need to add shard key to query, or you get an error
select_dict = {'_id': object_id}
if save_condition is not None:
select_dict = transform.query(self.__class__,
**save_condition)
else:
select_dict = {}
select_dict['_id'] = object_id
shard_key = self.__class__._meta.get('shard_key', tuple())
for k in shard_key:
actual_key = self._db_field_map.get(k, k)
@@ -262,8 +277,9 @@ class Document(BaseDocument):
if removals:
update_query["$unset"] = removals
if updates or removals:
upsert = save_condition is None
last_error = collection.update(select_dict, update_query,
upsert=True, **write_concern)
upsert=upsert, **write_concern)
created = is_new_object(last_error)
@@ -281,7 +297,9 @@ class Document(BaseDocument):
kwargs.update(cascade_kwargs)
kwargs['_refs'] = _refs
self.cascade_save(**kwargs)
except pymongo.errors.DuplicateKeyError, err:
message = u'Tried to save duplicate unique keys (%s)'
raise NotUniqueError(message % unicode(err))
except pymongo.errors.OperationFailure, err:
message = 'Could not save document (%s)'
if re.match('^E1100[01] duplicate key', unicode(err)):
@@ -291,12 +309,12 @@ class Document(BaseDocument):
raise NotUniqueError(message % unicode(err))
raise OperationError(message % unicode(err))
id_field = self._meta['id_field']
if id_field not in self._meta.get('shard_key', []):
if created or id_field not in self._meta.get('shard_key', []):
self[id_field] = self._fields[id_field].to_python(object_id)
signals.post_save.send(self.__class__, document=self, created=created)
self._clear_changed_fields()
self._created = False
signals.post_save.send(self.__class__, document=self, created=created)
return self
def cascade_save(self, *args, **kwargs):
@@ -353,7 +371,13 @@ class Document(BaseDocument):
been saved.
"""
if not self.pk:
raise OperationError('attempt to update a document not yet saved')
if kwargs.get('upsert', False):
query = self.to_mongo()
if "_cls" in query:
del(query["_cls"])
return self._qs.filter(**query).update_one(**kwargs)
else:
raise OperationError('attempt to update a document not yet saved')
# Need to add shard key to query, or you get an error
return self._qs.filter(**self._object_key).update_one(**kwargs)
@@ -395,7 +419,7 @@ class Document(BaseDocument):
"""
with switch_db(self.__class__, db_alias) as cls:
collection = cls._get_collection()
db = cls._get_db
db = cls._get_db()
self._get_collection = lambda: collection
self._get_db = lambda: db
self._collection = collection
@@ -435,33 +459,45 @@ class Document(BaseDocument):
.. versionadded:: 0.5
"""
import dereference
self._data = dereference.DeReference()(self._data, max_depth)
DeReference = _import_class('DeReference')
DeReference()([self], max_depth + 1)
return self
def reload(self, max_depth=1):
def reload(self, *fields, **kwargs):
"""Reloads all attributes from the database.
:param fields: (optional) args list of fields to reload
:param max_depth: (optional) depth of dereferencing to follow
.. versionadded:: 0.1.2
.. versionchanged:: 0.6 Now chainable
.. versionchanged:: 0.9 Can provide specific fields to reload
"""
id_field = self._meta['id_field']
max_depth = 1
if fields and isinstance(fields[0], int):
max_depth = fields[0]
fields = fields[1:]
elif "max_depth" in kwargs:
max_depth = kwargs["max_depth"]
if not self.pk:
raise self.DoesNotExist("Document does not exist")
obj = self._qs.read_preference(ReadPreference.PRIMARY).filter(
**{id_field: self[id_field]}).limit(1).select_related(max_depth=max_depth)
**self._object_key).only(*fields).limit(1
).select_related(max_depth=max_depth)
if obj:
obj = obj[0]
else:
msg = "Reloaded document has been deleted"
raise OperationError(msg)
for field in self._fields:
setattr(self, field, self._reload(field, obj[field]))
if self._dynamic:
for name in self._dynamic_fields.keys():
setattr(self, name, self._reload(name, obj._data[name]))
raise self.DoesNotExist("Document does not exist")
for field in self._fields_ordered:
if not fields or field in fields:
setattr(self, field, self._reload(field, obj[field]))
self._changed_fields = obj._changed_fields
self._created = False
return obj
return self
def _reload(self, key, value):
"""Used by :meth:`~mongoengine.Document.reload` to ensure the
@@ -474,6 +510,7 @@ class Document(BaseDocument):
value = [self._reload(key, v) for v in value]
value = BaseList(value, self, key)
elif isinstance(value, (EmbeddedDocument, DynamicEmbeddedDocument)):
value._instance = None
value._changed_fields = []
return value
@@ -534,6 +571,8 @@ class Document(BaseDocument):
def ensure_indexes(cls):
"""Checks the document meta data and ensures all the indexes exist.
Global defaults can be set in the meta - see :doc:`guide/defining-documents`
.. note:: You can disable automatic index creation by setting
`auto_create_index` to False in the documents meta data
"""
@@ -543,6 +582,8 @@ class Document(BaseDocument):
index_cls = cls._meta.get('index_cls', True)
collection = cls._get_collection()
if collection.read_preference > 1:
return
# determine if an index which we are creating includes
# _cls as its first field; if so, we can avoid creating

View File

@@ -42,7 +42,8 @@ __all__ = ['StringField', 'URLField', 'EmailField', 'IntField', 'LongField',
'GenericReferenceField', 'BinaryField', 'GridFSError',
'GridFSProxy', 'FileField', 'ImageGridFsProxy',
'ImproperlyConfigured', 'ImageField', 'GeoPointField', 'PointField',
'LineStringField', 'PolygonField', 'SequenceField', 'UUIDField']
'LineStringField', 'PolygonField', 'SequenceField', 'UUIDField',
'GeoJsonBaseField']
RECURSIVE_REFERENCE_CONSTANT = 'self'
@@ -152,7 +153,7 @@ class EmailField(StringField):
EMAIL_REGEX = re.compile(
r"(^[-!#$%&'*+/=?^_`{}|~0-9A-Z]+(\.[-!#$%&'*+/=?^_`{}|~0-9A-Z]+)*" # dot-atom
r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]|\\[\001-011\013\014\016-\177])*"' # quoted-string
r')@(?:[A-Z0-9](?:[A-Z0-9-]{0,253}[A-Z0-9])?\.)+[A-Z]{2,6}\.?$', re.IGNORECASE # domain
r')@(?:[A-Z0-9](?:[A-Z0-9-]{0,253}[A-Z0-9])?\.)+[A-Z]{2,6}$', re.IGNORECASE # domain
)
def validate(self, value):
@@ -279,14 +280,14 @@ class DecimalField(BaseField):
:param precision: Number of decimal places to store.
:param rounding: The rounding rule from the python decimal libary:
- decimial.ROUND_CEILING (towards Infinity)
- decimial.ROUND_DOWN (towards zero)
- decimial.ROUND_FLOOR (towards -Infinity)
- decimial.ROUND_HALF_DOWN (to nearest with ties going towards zero)
- decimial.ROUND_HALF_EVEN (to nearest with ties going to nearest even integer)
- decimial.ROUND_HALF_UP (to nearest with ties going away from zero)
- decimial.ROUND_UP (away from zero)
- decimial.ROUND_05UP (away from zero if last digit after rounding towards zero would have been 0 or 5; otherwise towards zero)
- decimal.ROUND_CEILING (towards Infinity)
- decimal.ROUND_DOWN (towards zero)
- decimal.ROUND_FLOOR (towards -Infinity)
- decimal.ROUND_HALF_DOWN (to nearest with ties going towards zero)
- decimal.ROUND_HALF_EVEN (to nearest with ties going to nearest even integer)
- decimal.ROUND_HALF_UP (to nearest with ties going away from zero)
- decimal.ROUND_UP (away from zero)
- decimal.ROUND_05UP (away from zero if last digit after rounding towards zero would have been 0 or 5; otherwise towards zero)
Defaults to: ``decimal.ROUND_HALF_UP``
@@ -304,7 +305,10 @@ class DecimalField(BaseField):
return value
# Convert to string for python 2.6 before casting to Decimal
value = decimal.Decimal("%s" % value)
try:
value = decimal.Decimal("%s" % value)
except decimal.InvalidOperation:
return value
return value.quantize(self.precision, rounding=self.rounding)
def to_mongo(self, value):
@@ -387,7 +391,7 @@ class DateTimeField(BaseField):
if dateutil:
try:
return dateutil.parser.parse(value)
except ValueError:
except (TypeError, ValueError):
return None
# split usecs, because they are not recognized by strptime.
@@ -624,7 +628,9 @@ class DynamicField(BaseField):
cls = value.__class__
val = value.to_mongo()
# If we its a document thats not inherited add _cls
if (isinstance(value, (Document, EmbeddedDocument))):
if (isinstance(value, Document)):
val = {"_ref": value.to_dbref(), "_cls": cls.__name__}
if (isinstance(value, EmbeddedDocument)):
val['_cls'] = cls.__name__
return val
@@ -645,6 +651,15 @@ class DynamicField(BaseField):
value = [v for k, v in sorted(data.iteritems(), key=itemgetter(0))]
return value
def to_python(self, value):
if isinstance(value, dict) and '_cls' in value:
doc_cls = get_document(value['_cls'])
if '_ref' in value:
value = doc_cls._get_db().dereference(value['_ref'])
return doc_cls._from_son(value)
return super(DynamicField, self).to_python(value)
def lookup_member(self, member_name):
return member_name
@@ -724,13 +739,28 @@ class SortedListField(ListField):
reverse=self._order_reverse)
return sorted(value, reverse=self._order_reverse)
def key_not_string(d):
""" Helper function to recursively determine if any key in a dictionary is
not a string.
"""
for k, v in d.items():
if not isinstance(k, basestring) or (isinstance(v, dict) and key_not_string(v)):
return True
def key_has_dot_or_dollar(d):
""" Helper function to recursively determine if any key in a dictionary
contains a dot or a dollar sign.
"""
for k, v in d.items():
if ('.' in k or '$' in k) or (isinstance(v, dict) and key_has_dot_or_dollar(v)):
return True
class DictField(ComplexBaseField):
"""A dictionary field that wraps a standard Python dictionary. This is
similar to an embedded document, but the structure is not defined.
.. note::
Required means it cannot be empty - as the default for ListFields is []
Required means it cannot be empty - as the default for DictFields is {}
.. versionadded:: 0.3
.. versionchanged:: 0.5 - Can now handle complex / varying types of data
@@ -750,11 +780,11 @@ class DictField(ComplexBaseField):
if not isinstance(value, dict):
self.error('Only dictionaries may be used in a DictField')
if any(k for k in value.keys() if not isinstance(k, basestring)):
if key_not_string(value):
msg = ("Invalid dictionary key - documents must "
"have only string keys")
self.error(msg)
if any(('.' in k or '$' in k) for k in value.keys()):
if key_has_dot_or_dollar(value):
self.error('Invalid dictionary key name - keys may not contain "."'
' or "$" characters')
super(DictField, self).validate(value)
@@ -769,6 +799,10 @@ class DictField(ComplexBaseField):
if op in match_operators and isinstance(value, basestring):
return StringField().prepare_query_value(op, value)
if hasattr(self.field, 'field'):
return self.field.prepare_query_value(op, value)
return super(DictField, self).prepare_query_value(op, value)
@@ -989,7 +1023,10 @@ class GenericReferenceField(BaseField):
id_ = id_field.to_mongo(id_)
collection = document._get_collection_name()
ref = DBRef(collection, id_)
return {'_cls': document._class_name, '_ref': ref}
return SON((
('_cls', document._class_name),
('_ref', ref)
))
def prepare_query_value(self, op, value):
if value is None:
@@ -1083,6 +1120,10 @@ class GridFSProxy(object):
def __repr__(self):
return '<%s: %s>' % (self.__class__.__name__, self.grid_id)
def __str__(self):
name = getattr(self.get(), 'filename', self.grid_id) if self.get() else '(no file)'
return '<%s: %s>' % (self.__class__.__name__, name)
def __eq__(self, other):
if isinstance(other, GridFSProxy):
return ((self.grid_id == other.grid_id) and
@@ -1190,9 +1231,7 @@ class FileField(BaseField):
# Check if a file already exists for this model
grid_file = instance._data.get(self.name)
if not isinstance(grid_file, self.proxy_class):
grid_file = self.proxy_class(key=self.name, instance=instance,
db_alias=self.db_alias,
collection_name=self.collection_name)
grid_file = self.get_proxy_obj(key=self.name, instance=instance)
instance._data[self.name] = grid_file
if not grid_file.key:
@@ -1214,15 +1253,23 @@ class FileField(BaseField):
pass
# Create a new proxy object as we don't already have one
instance._data[key] = self.proxy_class(key=key, instance=instance,
db_alias=self.db_alias,
collection_name=self.collection_name)
instance._data[key] = self.get_proxy_obj(key=key, instance=instance)
instance._data[key].put(value)
else:
instance._data[key] = value
instance._mark_as_changed(key)
def get_proxy_obj(self, key, instance, db_alias=None, collection_name=None):
if db_alias is None:
db_alias = self.db_alias
if collection_name is None:
collection_name = self.collection_name
return self.proxy_class(key=key, instance=instance,
db_alias=db_alias,
collection_name=collection_name)
def to_mongo(self, value):
# Store the GridFS file id in MongoDB
if isinstance(value, self.proxy_class) and value.grid_id is not None:
@@ -1255,6 +1302,9 @@ class ImageGridFsProxy(GridFSProxy):
applying field properties (size, thumbnail_size)
"""
field = self.instance._fields[self.key]
# Handle nested fields
if hasattr(field, 'field') and isinstance(field.field, FileField):
field = field.field
try:
img = Image.open(file_obj)
@@ -1563,7 +1613,12 @@ class UUIDField(BaseField):
class GeoPointField(BaseField):
"""A list storing a latitude and longitude.
"""A list storing a longitude and latitude coordinate.
.. note:: this represents a generic point in a 2D plane and a legacy way of
representing a geo point. It admits 2d indexes but not "2dsphere" indexes
in MongoDB > 2.4 which are more natural for modeling geospatial points.
See :ref:`geospatial-indexes`
.. versionadded:: 0.4
"""
@@ -1585,7 +1640,7 @@ class GeoPointField(BaseField):
class PointField(GeoJsonBaseField):
"""A geo json field storing a latitude and longitude.
"""A GeoJSON field storing a longitude and latitude coordinate.
The data is represented as:
@@ -1604,7 +1659,7 @@ class PointField(GeoJsonBaseField):
class LineStringField(GeoJsonBaseField):
"""A geo json field storing a line of latitude and longitude coordinates.
"""A GeoJSON field storing a line of longitude and latitude coordinates.
The data is represented as:
@@ -1622,7 +1677,7 @@ class LineStringField(GeoJsonBaseField):
class PolygonField(GeoJsonBaseField):
"""A geo json field storing a polygon of latitude and longitude coordinates.
"""A GeoJSON field storing a polygon of longitude and latitude coordinates.
The data is represented as:

View File

@@ -3,8 +3,6 @@
import sys
PY3 = sys.version_info[0] == 3
PY25 = sys.version_info[:2] == (2, 5)
UNICODE_KWARGS = int(''.join([str(x) for x in sys.version_info[:3]])) > 264
if PY3:
import codecs
@@ -29,33 +27,3 @@ else:
txt_type = unicode
str_types = (bin_type, txt_type)
if PY25:
def product(*args, **kwds):
pools = map(tuple, args) * kwds.get('repeat', 1)
result = [[]]
for pool in pools:
result = [x + [y] for x in result for y in pool]
for prod in result:
yield tuple(prod)
reduce = reduce
else:
from itertools import product
from functools import reduce
# For use with Python 2.5
# converts all keys from unicode to str for d and all nested dictionaries
def to_str_keys_recursive(d):
if isinstance(d, list):
for val in d:
if isinstance(val, (dict, list)):
to_str_keys_recursive(val)
elif isinstance(d, dict):
for key, val in d.items():
if isinstance(val, (dict, list)):
to_str_keys_recursive(val)
if isinstance(key, unicode):
d[str(key)] = d.pop(key)
else:
raise ValueError("non list/dict parameter not allowed")

1624
mongoengine/queryset/base.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -55,7 +55,8 @@ class QueryFieldList(object):
if self.always_include:
if self.value is self.ONLY and self.fields:
self.fields = self.fields.union(self.always_include)
if sorted(self.slice.keys()) != sorted(self.fields):
self.fields = self.fields.union(self.always_include)
else:
self.fields -= self.always_include

File diff suppressed because it is too large Load Diff

View File

@@ -3,6 +3,7 @@ from collections import defaultdict
import pymongo
from bson import SON
from mongoengine.connection import get_connection
from mongoengine.common import _import_class
from mongoengine.errors import InvalidQueryError, LookUpError
@@ -38,16 +39,16 @@ def query(_doc_cls=None, _field_operation=False, **query):
mongo_query.update(value)
continue
parts = key.split('__')
parts = key.rsplit('__')
indices = [(i, p) for i, p in enumerate(parts) if p.isdigit()]
parts = [part for part in parts if not part.isdigit()]
# Check for an operator and transform to mongo-style if there is
op = None
if parts[-1] in MATCH_OPERATORS:
if len(parts) > 1 and parts[-1] in MATCH_OPERATORS:
op = parts.pop()
negate = False
if parts[-1] == 'not':
if len(parts) > 1 and parts[-1] == 'not':
parts.pop()
negate = True
@@ -95,6 +96,7 @@ def query(_doc_cls=None, _field_operation=False, **query):
value = _geo_operator(field, op, value)
elif op in CUSTOM_OPERATORS:
if op == 'match':
value = field.prepare_query_value(op, value)
value = {"$elemMatch": value}
else:
NotImplementedError("Custom method '%s' has not "
@@ -114,14 +116,26 @@ def query(_doc_cls=None, _field_operation=False, **query):
if key in mongo_query and isinstance(mongo_query[key], dict):
mongo_query[key].update(value)
# $maxDistance needs to come last - convert to SON
if '$maxDistance' in mongo_query[key]:
value_dict = mongo_query[key]
value_dict = mongo_query[key]
if ('$maxDistance' in value_dict and '$near' in value_dict):
value_son = SON()
for k, v in value_dict.iteritems():
if k == '$maxDistance':
continue
value_son[k] = v
value_son['$maxDistance'] = value_dict['$maxDistance']
if isinstance(value_dict['$near'], dict):
for k, v in value_dict.iteritems():
if k == '$maxDistance':
continue
value_son[k] = v
if (get_connection().max_wire_version <= 1):
value_son['$maxDistance'] = value_dict['$maxDistance']
else:
value_son['$near'] = SON(value_son['$near'])
value_son['$near']['$maxDistance'] = value_dict['$maxDistance']
else:
for k, v in value_dict.iteritems():
if k == '$maxDistance':
continue
value_son[k] = v
value_son['$maxDistance'] = value_dict['$maxDistance']
mongo_query[key] = value_son
else:
# Store for manually merging later
@@ -181,6 +195,7 @@ def update(_doc_cls=None, **update):
parts = []
cleaned_fields = []
appended_sub_field = False
for field in fields:
append_field = True
if isinstance(field, basestring):
@@ -192,21 +207,34 @@ def update(_doc_cls=None, **update):
else:
parts.append(field.db_field)
if append_field:
appended_sub_field = False
cleaned_fields.append(field)
if hasattr(field, 'field'):
cleaned_fields.append(field.field)
appended_sub_field = True
# Convert value to proper value
field = cleaned_fields[-1]
if appended_sub_field:
field = cleaned_fields[-2]
else:
field = cleaned_fields[-1]
GeoJsonBaseField = _import_class("GeoJsonBaseField")
if isinstance(field, GeoJsonBaseField):
value = field.to_mongo(value)
if op in (None, 'set', 'push', 'pull'):
if field.required or value is not None:
value = field.prepare_query_value(op, value)
elif op in ('pushAll', 'pullAll'):
value = [field.prepare_query_value(op, v) for v in value]
elif op == 'addToSet':
elif op in ('addToSet', 'setOnInsert'):
if isinstance(value, (list, tuple, set)):
value = [field.prepare_query_value(op, v) for v in value]
elif field.required or value is not None:
value = field.prepare_query_value(op, value)
elif op == "unset":
value = 1
if match:
match = '$' + match
@@ -220,11 +248,24 @@ def update(_doc_cls=None, **update):
if 'pull' in op and '.' in key:
# Dot operators don't work on pull operations
# it uses nested dict syntax
# unless they point to a list field
# Otherwise it uses nested dict syntax
if op == 'pullAll':
raise InvalidQueryError("pullAll operations only support "
"a single field depth")
# Look for the last list field and use dot notation until there
field_classes = [c.__class__ for c in cleaned_fields]
field_classes.reverse()
ListField = _import_class('ListField')
if ListField in field_classes:
# Join all fields via dot notation to the last ListField
# Then process as normal
last_listField = len(cleaned_fields) - field_classes.index(ListField)
key = ".".join(parts[:last_listField])
parts = parts[last_listField:]
parts.insert(0, key)
parts.reverse()
for key in parts:
value = {key: value}

View File

@@ -1,8 +1,9 @@
import copy
from mongoengine.errors import InvalidQueryError
from mongoengine.python_support import product, reduce
from itertools import product
from functools import reduce
from mongoengine.errors import InvalidQueryError
from mongoengine.queryset import transform
__all__ = ('Q',)