Merge branch 'master' into pr/625

This commit is contained in:
Ross Lawley
2014-06-26 16:48:12 +01:00
66 changed files with 2157 additions and 532 deletions

View File

@@ -15,7 +15,7 @@ import django
__all__ = (list(document.__all__) + fields.__all__ + connection.__all__ +
list(queryset.__all__) + signals.__all__ + list(errors.__all__))
VERSION = (0, 8, 4)
VERSION = (0, 8, 7)
def get_version():

View File

@@ -1,6 +1,7 @@
import copy
import operator
import numbers
from collections import Hashable
from functools import partial
import pymongo
@@ -12,8 +13,7 @@ from mongoengine import signals
from mongoengine.common import _import_class
from mongoengine.errors import (ValidationError, InvalidDocumentError,
LookUpError)
from mongoengine.python_support import (PY3, UNICODE_KWARGS, txt_type,
to_str_keys_recursive)
from mongoengine.python_support import PY3, txt_type
from mongoengine.base.common import get_document, ALLOW_INHERITANCE
from mongoengine.base.datastructures import BaseDict, BaseList, StrictDict, SemiStrictDict
@@ -197,7 +197,7 @@ class BaseDocument(object):
"""Dictionary-style field access, set a field's value.
"""
# Ensure that the field exists before settings its value
if name not in self._fields:
if not self._dynamic and name not in self._fields:
raise KeyError(name)
return setattr(self, name, value)
@@ -391,20 +391,41 @@ class BaseDocument(object):
self._changed_fields.append(key)
def _clear_changed_fields(self):
"""Using get_changed_fields iterate and remove any fields that are
marked as changed"""
for changed in self._get_changed_fields():
parts = changed.split(".")
data = self
for part in parts:
if isinstance(data, list):
try:
data = data[int(part)]
except IndexError:
data = None
elif isinstance(data, dict):
data = data.get(part, None)
else:
data = getattr(data, part, None)
if hasattr(data, "_changed_fields"):
data._changed_fields = []
self._changed_fields = []
EmbeddedDocumentField = _import_class("EmbeddedDocumentField")
for field_name, field in self._fields.iteritems():
if (isinstance(field, ComplexBaseField) and
isinstance(field.field, EmbeddedDocumentField)):
field_value = getattr(self, field_name, None)
if field_value:
for idx in (field_value if isinstance(field_value, dict)
else xrange(len(field_value))):
field_value[idx]._clear_changed_fields()
elif isinstance(field, EmbeddedDocumentField):
field_value = getattr(self, field_name, None)
if field_value:
field_value._clear_changed_fields()
def _nestable_types_changed_fields(self, changed_fields, key, data, inspected):
# Loop list / dict fields as they contain documents
# Determine the iterator to use
if not hasattr(data, 'items'):
iterator = enumerate(data)
else:
iterator = data.iteritems()
for index, value in iterator:
list_key = "%s%s." % (key, index)
if hasattr(value, '_get_changed_fields'):
changed = value._get_changed_fields(inspected)
changed_fields += ["%s%s" % (list_key, k)
for k in changed if k]
elif isinstance(value, (list, tuple, dict)):
self._nestable_types_changed_fields(changed_fields, list_key, value, inspected)
def _get_changed_fields(self, inspected=None):
"""Returns a list of all fields that have explicitly been changed.
@@ -412,13 +433,12 @@ class BaseDocument(object):
EmbeddedDocument = _import_class("EmbeddedDocument")
DynamicEmbeddedDocument = _import_class("DynamicEmbeddedDocument")
ReferenceField = _import_class("ReferenceField")
_changed_fields = []
_changed_fields += getattr(self, '_changed_fields', [])
changed_fields = []
changed_fields += getattr(self, '_changed_fields', [])
inspected = inspected or set()
if hasattr(self, 'id'):
if hasattr(self, 'id') and isinstance(self.id, Hashable):
if self.id in inspected:
return _changed_fields
return changed_fields
inspected.add(self.id)
for field_name in self._fields_ordered:
@@ -434,29 +454,17 @@ class BaseDocument(object):
if isinstance(field, ReferenceField):
continue
elif (isinstance(data, (EmbeddedDocument, DynamicEmbeddedDocument))
and db_field_name not in _changed_fields):
and db_field_name not in changed_fields):
# Find all embedded fields that have been changed
changed = data._get_changed_fields(inspected)
_changed_fields += ["%s%s" % (key, k) for k in changed if k]
changed_fields += ["%s%s" % (key, k) for k in changed if k]
elif (isinstance(data, (list, tuple, dict)) and
db_field_name not in _changed_fields):
# Loop list / dict fields as they contain documents
# Determine the iterator to use
if not hasattr(data, 'items'):
iterator = enumerate(data)
else:
iterator = data.iteritems()
for index, value in iterator:
if not hasattr(value, '_get_changed_fields'):
continue
if (hasattr(field, 'field') and
isinstance(field.field, ReferenceField)):
continue
list_key = "%s%s." % (key, index)
changed = value._get_changed_fields(inspected)
_changed_fields += ["%s%s" % (list_key, k)
for k in changed if k]
return _changed_fields
db_field_name not in changed_fields):
if (hasattr(field, 'field') and
isinstance(field.field, ReferenceField)):
continue
self._nestable_types_changed_fields(changed_fields, key, data, inspected)
return changed_fields
def _delta(self):
"""Returns the delta (set, unset) of the changes for a document.
@@ -552,10 +560,6 @@ class BaseDocument(object):
# class if unavailable
class_name = son.get('_cls', cls._class_name)
data = dict(("%s" % key, value) for key, value in son.iteritems())
if not UNICODE_KWARGS:
# python 2.6.4 and lower cannot handle unicode keys
# passed to class constructor example: cls(**data)
to_str_keys_recursive(data)
# Return correct subclass for document type
if class_name != cls._class_name:
@@ -773,6 +777,9 @@ class BaseDocument(object):
"""Lookup a field based on its attribute and return a list containing
the field's parents and the field.
"""
ListField = _import_class("ListField")
if not isinstance(parts, (list, tuple)):
parts = [parts]
fields = []
@@ -780,7 +787,7 @@ class BaseDocument(object):
for field_name in parts:
# Handle ListField indexing:
if field_name.isdigit() and hasattr(field, 'field'):
if field_name.isdigit() and isinstance(field, ListField):
new_field = field.field
fields.append(field_name)
continue

View File

@@ -89,12 +89,7 @@ class BaseField(object):
return self
# Get value from document instance if available
value = instance._data.get(self.name)
EmbeddedDocument = _import_class('EmbeddedDocument')
if isinstance(value, EmbeddedDocument) and value._instance is None:
value._instance = weakref.proxy(instance)
return value
return instance._data.get(self.name)
def __set__(self, instance, value):
"""Descriptor for assigning a value to a field in a document.
@@ -116,6 +111,10 @@ class BaseField(object):
# Values cant be compared eg: naive and tz datetimes
# So mark it as changed
instance._mark_as_changed(self.name)
EmbeddedDocument = _import_class('EmbeddedDocument')
if isinstance(value, EmbeddedDocument) and value._instance is None:
value._instance = weakref.proxy(instance)
instance._data[self.name] = value
def error(self, message="", errors=None, field_name=None):
@@ -203,7 +202,7 @@ class ComplexBaseField(BaseField):
_dereference = _import_class("DeReference")()
self._auto_dereference = instance._fields[self.name]._auto_dereference
if instance._initialised and dereference:
if instance._initialised and dereference and instance._data.get(self.name):
instance._data[self.name] = _dereference(
instance._data.get(self.name), max_depth=1, instance=instance,
name=self.name

View File

@@ -25,7 +25,7 @@ def _import_class(cls_name):
'GenericEmbeddedDocumentField', 'GeoPointField',
'PointField', 'LineStringField', 'ListField',
'PolygonField', 'ReferenceField', 'StringField',
'ComplexBaseField')
'ComplexBaseField', 'GeoJsonBaseField')
queryset_classes = ('OperationError',)
deref_classes = ('DeReference',)

View File

@@ -18,7 +18,7 @@ _connections = {}
_dbs = {}
def register_connection(alias, name, host='localhost', port=27017,
def register_connection(alias, name, host=None, port=None,
is_slave=False, read_preference=False, slaves=None,
username=None, password=None, **kwargs):
"""Add a connection.
@@ -43,8 +43,8 @@ def register_connection(alias, name, host='localhost', port=27017,
conn_settings = {
'name': name,
'host': host,
'port': port,
'host': host or 'localhost',
'port': port or 27017,
'is_slave': is_slave,
'slaves': slaves or [],
'username': username,
@@ -53,16 +53,15 @@ def register_connection(alias, name, host='localhost', port=27017,
}
# Handle uri style connections
if "://" in host:
uri_dict = uri_parser.parse_uri(host)
if "://" in conn_settings['host']:
uri_dict = uri_parser.parse_uri(conn_settings['host'])
conn_settings.update({
'host': host,
'name': uri_dict.get('database') or name,
'username': uri_dict.get('username'),
'password': uri_dict.get('password'),
'read_preference': read_preference,
})
if "replicaSet" in host:
if "replicaSet" in conn_settings['host']:
conn_settings['replicaSet'] = True
conn_settings.update(kwargs)
@@ -94,20 +93,11 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
raise ConnectionError(msg)
conn_settings = _connection_settings[alias].copy()
if hasattr(pymongo, 'version_tuple'): # Support for 2.1+
conn_settings.pop('name', None)
conn_settings.pop('slaves', None)
conn_settings.pop('is_slave', None)
conn_settings.pop('username', None)
conn_settings.pop('password', None)
else:
# Get all the slave connections
if 'slaves' in conn_settings:
slaves = []
for slave_alias in conn_settings['slaves']:
slaves.append(get_connection(slave_alias))
conn_settings['slaves'] = slaves
conn_settings.pop('read_preference', None)
conn_settings.pop('name', None)
conn_settings.pop('slaves', None)
conn_settings.pop('is_slave', None)
conn_settings.pop('username', None)
conn_settings.pop('password', None)
connection_class = MongoClient
if 'replicaSet' in conn_settings:
@@ -120,7 +110,19 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
connection_class = MongoReplicaSetClient
try:
_connections[alias] = connection_class(**conn_settings)
connection = None
connection_settings_iterator = ((alias, settings.copy()) for alias, settings in _connection_settings.iteritems())
for alias, connection_settings in connection_settings_iterator:
connection_settings.pop('name', None)
connection_settings.pop('slaves', None)
connection_settings.pop('is_slave', None)
connection_settings.pop('username', None)
connection_settings.pop('password', None)
if conn_settings == connection_settings and _connections.get(alias, None):
connection = _connections[alias]
break
_connections[alias] = connection if connection else connection_class(**conn_settings)
except Exception, e:
raise ConnectionError("Cannot connect to database %s :\n%s" % (alias, e))
return _connections[alias]

View File

@@ -1,6 +1,5 @@
from mongoengine.common import _import_class
from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db
from mongoengine.queryset import QuerySet
__all__ = ("switch_db", "switch_collection", "no_dereference",
@@ -162,12 +161,6 @@ class no_sub_classes(object):
return self.cls
class QuerySetNoDeRef(QuerySet):
"""Special no_dereference QuerySet"""
def __dereference(items, max_depth=1, instance=None, name=None):
return items
class query_counter(object):
""" Query_counter context manager to get the number of queries. """

View File

@@ -8,6 +8,10 @@ from django.contrib import auth
from django.contrib.auth.models import AnonymousUser
from django.utils.translation import ugettext_lazy as _
from .utils import datetime_now
REDIRECT_FIELD_NAME = 'next'
try:
from django.contrib.auth.hashers import check_password, make_password
except ImportError:
@@ -33,10 +37,6 @@ except ImportError:
hash = get_hexdigest(algo, salt, raw_password)
return '%s$%s$%s' % (algo, salt, hash)
from .utils import datetime_now
REDIRECT_FIELD_NAME = 'next'
class ContentType(Document):
name = StringField(max_length=100)
@@ -230,6 +230,9 @@ class User(Document):
date_joined = DateTimeField(default=datetime_now,
verbose_name=_('date joined'))
user_permissions = ListField(ReferenceField(Permission), verbose_name=_('user permissions'),
help_text=_('Permissions for the user.'))
USERNAME_FIELD = 'username'
REQUIRED_FIELDS = ['email']
@@ -378,9 +381,10 @@ class MongoEngineBackend(object):
supports_object_permissions = False
supports_anonymous_user = False
supports_inactive_user = False
_user_doc = False
def authenticate(self, username=None, password=None):
user = User.objects(username=username).first()
user = self.user_document.objects(username=username).first()
if user:
if password and user.check_password(password):
backend = auth.get_backends()[0]
@@ -389,8 +393,14 @@ class MongoEngineBackend(object):
return None
def get_user(self, user_id):
return User.objects.with_id(user_id)
return self.user_document.objects.with_id(user_id)
@property
def user_document(self):
if self._user_doc is False:
from .mongo_auth.models import get_user_document
self._user_doc = get_user_document()
return self._user_doc
def get_user(userid):
"""Returns a User object from an id (User.id). Django's equivalent takes

View File

@@ -1,4 +1,5 @@
from django.conf import settings
from django.contrib.auth.hashers import make_password
from django.contrib.auth.models import UserManager
from django.core.exceptions import ImproperlyConfigured
from django.db import models
@@ -105,3 +106,10 @@ class MongoUser(models.Model):
"""
objects = MongoUserManager()
class Meta:
app_label = 'mongo_auth'
def set_password(self, password):
"""Doesn't do anything, but works around the issue with Django 1.6."""
make_password(password)

View File

@@ -1,3 +1,4 @@
from bson import json_util
from django.conf import settings
from django.contrib.sessions.backends.base import SessionBase, CreateError
from django.core.exceptions import SuspiciousOperation
@@ -55,6 +56,12 @@ class SessionStore(SessionBase):
"""A MongoEngine-based session store for Django.
"""
def _get_session(self, *args, **kwargs):
sess = super(SessionStore, self)._get_session(*args, **kwargs)
if sess.get('_auth_user_id', None):
sess['_auth_user_id'] = str(sess.get('_auth_user_id'))
return sess
def load(self):
try:
s = MongoSession.objects(session_key=self.session_key,
@@ -103,3 +110,15 @@ class SessionStore(SessionBase):
return
session_key = self.session_key
MongoSession.objects(session_key=session_key).delete()
class BSONSerializer(object):
"""
Serializer that can handle BSON types (eg ObjectId).
"""
def dumps(self, obj):
return json_util.dumps(obj, separators=(',', ':')).encode('ascii')
def loads(self, data):
return json_util.loads(data.decode('ascii'))

View File

@@ -76,7 +76,7 @@ class GridFSStorage(Storage):
"""Find the documents in the store with the given name
"""
docs = self.document.objects
doc = [d for d in docs if getattr(d, self.field).name == name]
doc = [d for d in docs if hasattr(getattr(d, self.field), 'name') and getattr(d, self.field).name == name]
if doc:
return doc[0]
else:

View File

@@ -12,7 +12,9 @@ from mongoengine.common import _import_class
from mongoengine.base import (DocumentMetaclass, TopLevelDocumentMetaclass,
BaseDocument, BaseDict, BaseList,
ALLOW_INHERITANCE, get_document)
from mongoengine.queryset import OperationError, NotUniqueError, QuerySet
from mongoengine.errors import ValidationError
from mongoengine.queryset import (OperationError, NotUniqueError,
QuerySet, transform)
from mongoengine.connection import get_db, DEFAULT_CONNECTION_NAME
from mongoengine.context_managers import switch_db, switch_collection
@@ -67,7 +69,7 @@ class EmbeddedDocument(BaseDocument):
def __eq__(self, other):
if isinstance(other, self.__class__):
return self._data == other._data
return self.to_mongo() == other.to_mongo()
return False
def __ne__(self, other):
@@ -182,7 +184,7 @@ class Document(BaseDocument):
def save(self, force_insert=False, validate=True, clean=True,
write_concern=None, cascade=None, cascade_kwargs=None,
_refs=None, **kwargs):
_refs=None, save_condition=None, **kwargs):
"""Save the :class:`~mongoengine.Document` to the database. If the
document already exists, it will be updated, otherwise it will be
created.
@@ -205,7 +207,8 @@ class Document(BaseDocument):
:param cascade_kwargs: (optional) kwargs dictionary to be passed throw
to cascading saves. Implies ``cascade=True``.
:param _refs: A list of processed references used in cascading saves
:param save_condition: only perform save if matching record in db
satisfies condition(s) (e.g., version number)
.. versionchanged:: 0.5
In existing documents it only saves changed fields using
set / unset. Saves are cascaded and any
@@ -219,6 +222,9 @@ class Document(BaseDocument):
meta['cascade'] = True. Also you can pass different kwargs to
the cascade save using cascade_kwargs which overwrites the
existing kwargs with custom values.
.. versionchanged:: 0.8.5
Optional save_condition that only overwrites existing documents
if the condition is satisfied in the current db record.
"""
signals.pre_save.send(self.__class__, document=self)
@@ -232,7 +238,8 @@ class Document(BaseDocument):
created = ('_id' not in doc or self._created or force_insert)
signals.pre_save_post_validation.send(self.__class__, document=self, created=created)
signals.pre_save_post_validation.send(self.__class__, document=self,
created=created)
try:
collection = self._get_collection()
@@ -245,7 +252,12 @@ class Document(BaseDocument):
object_id = doc['_id']
updates, removals = self._delta()
# Need to add shard key to query, or you get an error
select_dict = {'_id': object_id}
if save_condition is not None:
select_dict = transform.query(self.__class__,
**save_condition)
else:
select_dict = {}
select_dict['_id'] = object_id
shard_key = self.__class__._meta.get('shard_key', tuple())
for k in shard_key:
actual_key = self._db_field_map.get(k, k)
@@ -265,10 +277,12 @@ class Document(BaseDocument):
if removals:
update_query["$unset"] = removals
if updates or removals:
upsert = save_condition is None
last_error = collection.update(select_dict, update_query,
upsert=True, **write_concern)
upsert=upsert, **write_concern)
created = is_new_object(last_error)
if cascade is None:
cascade = self._meta.get('cascade', False) or cascade_kwargs is not None
@@ -283,7 +297,9 @@ class Document(BaseDocument):
kwargs.update(cascade_kwargs)
kwargs['_refs'] = _refs
self.cascade_save(**kwargs)
except pymongo.errors.DuplicateKeyError, err:
message = u'Tried to save duplicate unique keys (%s)'
raise NotUniqueError(message % unicode(err))
except pymongo.errors.OperationFailure, err:
message = 'Could not save document (%s)'
if re.match('^E1100[01] duplicate key', unicode(err)):
@@ -453,14 +469,16 @@ class Document(BaseDocument):
.. versionadded:: 0.1.2
.. versionchanged:: 0.6 Now chainable
"""
if not self.pk:
raise self.DoesNotExist("Document does not exist")
obj = self._qs.read_preference(ReadPreference.PRIMARY).filter(
**self._object_key).limit(1).select_related(max_depth=max_depth)
**self._object_key).limit(1).select_related(max_depth=max_depth)
if obj:
obj = obj[0]
else:
msg = "Reloaded document has been deleted"
raise OperationError(msg)
raise self.DoesNotExist("Document does not exist")
for field in self._fields_ordered:
setattr(self, field, self._reload(field, obj[field]))
self._changed_fields = obj._changed_fields
@@ -550,6 +568,8 @@ class Document(BaseDocument):
index_cls = cls._meta.get('index_cls', True)
collection = cls._get_collection()
if collection.read_preference > 1:
return
# determine if an index which we are creating includes
# _cls as its first field; if so, we can avoid creating

View File

@@ -42,7 +42,8 @@ __all__ = ['StringField', 'URLField', 'EmailField', 'IntField', 'LongField',
'GenericReferenceField', 'BinaryField', 'GridFSError',
'GridFSProxy', 'FileField', 'ImageGridFsProxy',
'ImproperlyConfigured', 'ImageField', 'GeoPointField', 'PointField',
'LineStringField', 'PolygonField', 'SequenceField', 'UUIDField']
'LineStringField', 'PolygonField', 'SequenceField', 'UUIDField',
'GeoJsonBaseField']
RECURSIVE_REFERENCE_CONSTANT = 'self'
@@ -152,7 +153,7 @@ class EmailField(StringField):
EMAIL_REGEX = re.compile(
r"(^[-!#$%&'*+/=?^_`{}|~0-9A-Z]+(\.[-!#$%&'*+/=?^_`{}|~0-9A-Z]+)*" # dot-atom
r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]|\\[\001-011\013\014\016-\177])*"' # quoted-string
r')@(?:[A-Z0-9](?:[A-Z0-9-]{0,253}[A-Z0-9])?\.)+[A-Z]{2,6}\.?$', re.IGNORECASE # domain
r')@(?:[A-Z0-9](?:[A-Z0-9-]{0,253}[A-Z0-9])?\.)+[A-Z]{2,6}$', re.IGNORECASE # domain
)
def validate(self, value):
@@ -304,7 +305,10 @@ class DecimalField(BaseField):
return value
# Convert to string for python 2.6 before casting to Decimal
value = decimal.Decimal("%s" % value)
try:
value = decimal.Decimal("%s" % value)
except decimal.InvalidOperation:
return value
return value.quantize(self.precision, rounding=self.rounding)
def to_mongo(self, value):
@@ -387,7 +391,7 @@ class DateTimeField(BaseField):
if dateutil:
try:
return dateutil.parser.parse(value)
except ValueError:
except (TypeError, ValueError):
return None
# split usecs, because they are not recognized by strptime.
@@ -735,13 +739,28 @@ class SortedListField(ListField):
reverse=self._order_reverse)
return sorted(value, reverse=self._order_reverse)
def key_not_string(d):
""" Helper function to recursively determine if any key in a dictionary is
not a string.
"""
for k, v in d.items():
if not isinstance(k, basestring) or (isinstance(v, dict) and key_not_string(v)):
return True
def key_has_dot_or_dollar(d):
""" Helper function to recursively determine if any key in a dictionary
contains a dot or a dollar sign.
"""
for k, v in d.items():
if ('.' in k or '$' in k) or (isinstance(v, dict) and key_has_dot_or_dollar(v)):
return True
class DictField(ComplexBaseField):
"""A dictionary field that wraps a standard Python dictionary. This is
similar to an embedded document, but the structure is not defined.
.. note::
Required means it cannot be empty - as the default for ListFields is []
Required means it cannot be empty - as the default for DictFields is {}
.. versionadded:: 0.3
.. versionchanged:: 0.5 - Can now handle complex / varying types of data
@@ -761,11 +780,11 @@ class DictField(ComplexBaseField):
if not isinstance(value, dict):
self.error('Only dictionaries may be used in a DictField')
if any(k for k in value.keys() if not isinstance(k, basestring)):
if key_not_string(value):
msg = ("Invalid dictionary key - documents must "
"have only string keys")
self.error(msg)
if any(('.' in k or '$' in k) for k in value.keys()):
if key_has_dot_or_dollar(value):
self.error('Invalid dictionary key name - keys may not contain "."'
' or "$" characters')
super(DictField, self).validate(value)
@@ -1004,7 +1023,10 @@ class GenericReferenceField(BaseField):
id_ = id_field.to_mongo(id_)
collection = document._get_collection_name()
ref = DBRef(collection, id_)
return {'_cls': document._class_name, '_ref': ref}
return SON((
('_cls', document._class_name),
('_ref', ref)
))
def prepare_query_value(self, op, value):
if value is None:
@@ -1591,7 +1613,12 @@ class UUIDField(BaseField):
class GeoPointField(BaseField):
"""A list storing a latitude and longitude.
"""A list storing a longitude and latitude coordinate.
.. note:: this represents a generic point in a 2D plane and a legacy way of
representing a geo point. It admits 2d indexes but not "2dsphere" indexes
in MongoDB > 2.4 which are more natural for modeling geospatial points.
See :ref:`geospatial-indexes`
.. versionadded:: 0.4
"""
@@ -1613,7 +1640,7 @@ class GeoPointField(BaseField):
class PointField(GeoJsonBaseField):
"""A geo json field storing a latitude and longitude.
"""A GeoJSON field storing a longitude and latitude coordinate.
The data is represented as:
@@ -1632,7 +1659,7 @@ class PointField(GeoJsonBaseField):
class LineStringField(GeoJsonBaseField):
"""A geo json field storing a line of latitude and longitude coordinates.
"""A GeoJSON field storing a line of longitude and latitude coordinates.
The data is represented as:
@@ -1650,7 +1677,7 @@ class LineStringField(GeoJsonBaseField):
class PolygonField(GeoJsonBaseField):
"""A geo json field storing a polygon of latitude and longitude coordinates.
"""A GeoJSON field storing a polygon of longitude and latitude coordinates.
The data is represented as:

View File

@@ -3,8 +3,6 @@
import sys
PY3 = sys.version_info[0] == 3
PY25 = sys.version_info[:2] == (2, 5)
UNICODE_KWARGS = int(''.join([str(x) for x in sys.version_info[:3]])) > 264
if PY3:
import codecs
@@ -29,33 +27,3 @@ else:
txt_type = unicode
str_types = (bin_type, txt_type)
if PY25:
def product(*args, **kwds):
pools = map(tuple, args) * kwds.get('repeat', 1)
result = [[]]
for pool in pools:
result = [x + [y] for x in result for y in pool]
for prod in result:
yield tuple(prod)
reduce = reduce
else:
from itertools import product
from functools import reduce
# For use with Python 2.5
# converts all keys from unicode to str for d and all nested dictionaries
def to_str_keys_recursive(d):
if isinstance(d, list):
for val in d:
if isinstance(val, (dict, list)):
to_str_keys_recursive(val)
elif isinstance(d, dict):
for key, val in d.items():
if isinstance(val, (dict, list)):
to_str_keys_recursive(val)
if isinstance(key, unicode):
d[str(key)] = d.pop(key)
else:
raise ValueError("non list/dict parameter not allowed")

View File

@@ -10,14 +10,15 @@ import warnings
from bson.code import Code
from bson import json_util
import pymongo
import pymongo.errors
from pymongo.common import validate_read_preference
from mongoengine import signals
from mongoengine.context_managers import switch_db
from mongoengine.common import _import_class
from mongoengine.base.common import get_document
from mongoengine.errors import (OperationError, NotUniqueError,
InvalidQueryError, LookUpError)
from mongoengine.queryset import transform
from mongoengine.queryset.field_list import QueryFieldList
from mongoengine.queryset.visitor import Q, QNode
@@ -50,7 +51,7 @@ class BaseQuerySet(object):
self._initial_query = {}
self._where_clause = None
self._loaded_fields = QueryFieldList()
self._ordering = []
self._ordering = None
self._snapshot = False
self._timeout = True
self._class_check = True
@@ -154,6 +155,22 @@ class BaseQuerySet(object):
def __iter__(self):
raise NotImplementedError
def _has_data(self):
""" Retrieves whether cursor has any data. """
queryset = self.order_by()
return False if queryset.first() is None else True
def __nonzero__(self):
""" Avoid to open all records in an if stmt in Py2. """
return self._has_data()
def __bool__(self):
""" Avoid to open all records in an if stmt in Py3. """
return self._has_data()
# Core functions
def all(self):
@@ -302,8 +319,11 @@ class BaseQuerySet(object):
signals.pre_bulk_insert.send(self._document, documents=docs)
try:
ids = self._collection.insert(raw, **write_concern)
except pymongo.errors.DuplicateKeyError, err:
message = 'Could not save document (%s)';
raise NotUniqueError(message % unicode(err))
except pymongo.errors.OperationFailure, err:
message = 'Could not save document (%s)'
message = 'Could not save document (%s)';
if re.match('^E1100[01] duplicate key', unicode(err)):
# E11000 - duplicate key error index
# E11001 - duplicate key on update
@@ -331,7 +351,7 @@ class BaseQuerySet(object):
:meth:`skip` that has been applied to this cursor into account when
getting the count
"""
if self._limit == 0 and with_limit_and_skip:
if self._limit == 0 and with_limit_and_skip or self._none:
return 0
return self._cursor.count(with_limit_and_skip=with_limit_and_skip)
@@ -386,7 +406,7 @@ class BaseQuerySet(object):
ref_q = document_cls.objects(**{field_name + '__in': self})
ref_q_count = ref_q.count()
if (doc != document_cls and ref_q_count > 0
or (doc == document_cls and ref_q_count > 0)):
or (doc == document_cls and ref_q_count > 0)):
ref_q.delete(write_concern=write_concern)
elif rule == NULLIFY:
document_cls.objects(**{field_name + '__in': self}).update(
@@ -440,6 +460,8 @@ class BaseQuerySet(object):
return result
elif result:
return result['n']
except pymongo.errors.DuplicateKeyError, err:
raise NotUniqueError(u'Update failed (%s)' % unicode(err))
except pymongo.errors.OperationFailure, err:
if unicode(err) == u'multi not coded yet':
message = u'update() method requires MongoDB 1.1.3+'
@@ -463,6 +485,59 @@ class BaseQuerySet(object):
return self.update(
upsert=upsert, multi=False, write_concern=write_concern, **update)
def modify(self, upsert=False, full_response=False, remove=False, new=False, **update):
"""Update and return the updated document.
Returns either the document before or after modification based on `new`
parameter. If no documents match the query and `upsert` is false,
returns ``None``. If upserting and `new` is false, returns ``None``.
If the full_response parameter is ``True``, the return value will be
the entire response object from the server, including the 'ok' and
'lastErrorObject' fields, rather than just the modified document.
This is useful mainly because the 'lastErrorObject' document holds
information about the command's execution.
:param upsert: insert if document doesn't exist (default ``False``)
:param full_response: return the entire response object from the
server (default ``False``)
:param remove: remove rather than updating (default ``False``)
:param new: return updated rather than original document
(default ``False``)
:param update: Django-style update keyword arguments
.. versionadded:: 0.9
"""
if remove and new:
raise OperationError("Conflicting parameters: remove and new")
if not update and not upsert and not remove:
raise OperationError("No update parameters, must either update or remove")
queryset = self.clone()
query = queryset._query
update = transform.update(queryset._document, **update)
sort = queryset._ordering
try:
result = queryset._collection.find_and_modify(
query, update, upsert=upsert, sort=sort, remove=remove, new=new,
full_response=full_response, **self._cursor_args)
except pymongo.errors.DuplicateKeyError, err:
raise NotUniqueError(u"Update failed (%s)" % err)
except pymongo.errors.OperationFailure, err:
raise OperationError(u"Update failed (%s)" % err)
if full_response:
if result["value"] is not None:
result["value"] = self._document._from_son(result["value"])
else:
if result is not None:
result = self._document._from_son(result)
return result
def with_id(self, object_id):
"""Retrieve the object matching the id provided. Uses `object_id` only
and raises InvalidQueryError if a filter has been applied. Returns
@@ -519,6 +594,19 @@ class BaseQuerySet(object):
return self
def using(self, alias):
"""This method is for controlling which database the QuerySet will be evaluated against if you are using more than one database.
:param alias: The database alias
.. versionadded:: 0.8
"""
with switch_db(self._document, alias) as cls:
collection = cls._get_collection()
return self.clone_into(self.__class__(self._document, collection))
def clone(self):
"""Creates a copy of the current
:class:`~mongoengine.queryset.QuerySet`
@@ -621,8 +709,15 @@ class BaseQuerySet(object):
try:
field = self._fields_to_dbfields([field]).pop()
finally:
return self._dereference(queryset._cursor.distinct(field), 1,
name=field, instance=self._document)
distinct = self._dereference(queryset._cursor.distinct(field), 1,
name=field, instance=self._document)
# We may need to cast to the correct type eg. ListField(EmbeddedDocumentField)
doc_field = getattr(self._document._fields.get(field), "field", None)
instance = getattr(doc_field, "document_type", False)
if instance:
distinct = [instance(**doc) for doc in distinct]
return distinct
def only(self, *fields):
"""Load only a subset of this document's fields. ::
@@ -850,7 +945,7 @@ class BaseQuerySet(object):
:param output: output collection name, if set to 'inline' will try to
use :class:`~pymongo.collection.Collection.inline_map_reduce`
This can also be a dictionary containing output options
see: http://docs.mongodb.org/manual/reference/commands/#mapReduce
see: http://docs.mongodb.org/manual/reference/command/mapReduce/#dbcmd.mapReduce
:param finalize_f: finalize function, an optional function that
performs any post-reduction processing.
:param scope: values to insert into map/reduce global scope. Optional.
@@ -916,7 +1011,7 @@ class BaseQuerySet(object):
mr_args['out'] = output
results = getattr(queryset._collection, map_reduce_function)(
map_f, reduce_f, **mr_args)
map_f, reduce_f, **mr_args)
if map_reduce_function == 'map_reduce':
results = results.find()
@@ -1179,8 +1274,9 @@ class BaseQuerySet(object):
if self._ordering:
# Apply query ordering
self._cursor_obj.sort(self._ordering)
elif self._document._meta['ordering']:
# Otherwise, apply the ordering from the document model
elif self._ordering is None and self._document._meta['ordering']:
# Otherwise, apply the ordering from the document model, unless
# it's been explicitly cleared via order_by with no arguments
order = self._get_order_by(self._document._meta['ordering'])
self._cursor_obj.sort(order)
@@ -1352,7 +1448,7 @@ class BaseQuerySet(object):
for subdoc in subclasses:
try:
subfield = ".".join(f.db_field for f in
subdoc._lookup_field(field.split('.')))
subdoc._lookup_field(field.split('.')))
ret.append(subfield)
found = True
break
@@ -1382,7 +1478,7 @@ class BaseQuerySet(object):
pass
key_list.append((key, direction))
if self._cursor_obj:
if self._cursor_obj and key_list:
self._cursor_obj.sort(key_list)
return key_list
@@ -1440,6 +1536,7 @@ class BaseQuerySet(object):
# type of this field and use the corresponding
# .to_python(...)
from mongoengine.fields import EmbeddedDocumentField
obj = self._document
for chunk in path.split('.'):
obj = getattr(obj, chunk, None)
@@ -1450,6 +1547,7 @@ class BaseQuerySet(object):
if obj and data is not None:
data = obj.to_python(data)
return data
return clean(row)
def _sub_js_fields(self, code):
@@ -1458,6 +1556,7 @@ class BaseQuerySet(object):
substituted for the MongoDB name of the field (specified using the
:attr:`name` keyword argument in a field's constructor).
"""
def field_sub(match):
# Extract just the field name, and look up the field objects
field_name = match.group(1).split('.')
@@ -1491,4 +1590,4 @@ class BaseQuerySet(object):
msg = ("Doc.objects()._ensure_indexes() is deprecated. "
"Use Doc.ensure_indexes() instead.")
warnings.warn(msg, DeprecationWarning)
self._document.__class__.ensure_indexes()
self._document.__class__.ensure_indexes()

View File

@@ -155,3 +155,10 @@ class QuerySetNoCache(BaseQuerySet):
queryset = self.clone()
queryset.rewind()
return queryset
class QuerySetNoDeRef(QuerySet):
"""Special no_dereference QuerySet"""
def __dereference(items, max_depth=1, instance=None, name=None):
return items

View File

@@ -38,7 +38,7 @@ def query(_doc_cls=None, _field_operation=False, **query):
mongo_query.update(value)
continue
parts = key.split('__')
parts = key.rsplit('__')
indices = [(i, p) for i, p in enumerate(parts) if p.isdigit()]
parts = [part for part in parts if not part.isdigit()]
# Check for an operator and transform to mongo-style if there is
@@ -206,6 +206,10 @@ def update(_doc_cls=None, **update):
else:
field = cleaned_fields[-1]
GeoJsonBaseField = _import_class("GeoJsonBaseField")
if isinstance(field, GeoJsonBaseField):
value = field.to_mongo(value)
if op in (None, 'set', 'push', 'pull'):
if field.required or value is not None:
value = field.prepare_query_value(op, value)

View File

@@ -1,8 +1,9 @@
import copy
from mongoengine.errors import InvalidQueryError
from mongoengine.python_support import product, reduce
from itertools import product
from functools import reduce
from mongoengine.errors import InvalidQueryError
from mongoengine.queryset import transform
__all__ = ('Q',)