Merge branch 'master' into pr/625
		| @@ -15,7 +15,7 @@ import django | ||||
| __all__ = (list(document.__all__) + fields.__all__ + connection.__all__ + | ||||
|            list(queryset.__all__) + signals.__all__ + list(errors.__all__)) | ||||
|  | ||||
| VERSION = (0, 8, 4) | ||||
| VERSION = (0, 8, 7) | ||||
|  | ||||
|  | ||||
| def get_version(): | ||||
|   | ||||
| @@ -1,6 +1,7 @@ | ||||
| import copy | ||||
| import operator | ||||
| import numbers | ||||
| from collections import Hashable | ||||
| from functools import partial | ||||
|  | ||||
| import pymongo | ||||
| @@ -12,8 +13,7 @@ from mongoengine import signals | ||||
| from mongoengine.common import _import_class | ||||
| from mongoengine.errors import (ValidationError, InvalidDocumentError, | ||||
|                                 LookUpError) | ||||
| from mongoengine.python_support import (PY3, UNICODE_KWARGS, txt_type, | ||||
|                                         to_str_keys_recursive) | ||||
| from mongoengine.python_support import PY3, txt_type | ||||
|  | ||||
| from mongoengine.base.common import get_document, ALLOW_INHERITANCE | ||||
| from mongoengine.base.datastructures import BaseDict, BaseList, StrictDict, SemiStrictDict  | ||||
| @@ -197,7 +197,7 @@ class BaseDocument(object): | ||||
|         """Dictionary-style field access, set a field's value. | ||||
|         """ | ||||
|         # Ensure that the field exists before setting its value | ||||
|         if name not in self._fields: | ||||
|         if not self._dynamic and name not in self._fields: | ||||
|             raise KeyError(name) | ||||
|         return setattr(self, name, value) | ||||
|  | ||||
| @@ -391,20 +391,41 @@ class BaseDocument(object): | ||||
|             self._changed_fields.append(key) | ||||
|  | ||||
|     def _clear_changed_fields(self): | ||||
|         """Using get_changed_fields iterate and remove any fields that are | ||||
|         marked as changed""" | ||||
|         for changed in self._get_changed_fields(): | ||||
|             parts = changed.split(".") | ||||
|             data = self | ||||
|             for part in parts: | ||||
|                 if isinstance(data, list): | ||||
|                     try: | ||||
|                         data = data[int(part)] | ||||
|                     except IndexError: | ||||
|                         data = None | ||||
|                 elif isinstance(data, dict): | ||||
|                     data = data.get(part, None) | ||||
|                 else: | ||||
|                     data = getattr(data, part, None) | ||||
|                 if hasattr(data, "_changed_fields"): | ||||
|                     data._changed_fields = [] | ||||
|         self._changed_fields = [] | ||||
|         EmbeddedDocumentField = _import_class("EmbeddedDocumentField") | ||||
|         for field_name, field in self._fields.iteritems(): | ||||
|             if (isinstance(field, ComplexBaseField) and | ||||
|                isinstance(field.field, EmbeddedDocumentField)): | ||||
|                 field_value = getattr(self, field_name, None) | ||||
|                 if field_value: | ||||
|                     for idx in (field_value if isinstance(field_value, dict) | ||||
|                                 else xrange(len(field_value))): | ||||
|                         field_value[idx]._clear_changed_fields() | ||||
|             elif isinstance(field, EmbeddedDocumentField): | ||||
|                 field_value = getattr(self, field_name, None) | ||||
|                 if field_value: | ||||
|                     field_value._clear_changed_fields() | ||||
|  | ||||
|     def _nestable_types_changed_fields(self, changed_fields, key, data, inspected): | ||||
|         # Loop list / dict fields as they contain documents | ||||
|         # Determine the iterator to use | ||||
|         if not hasattr(data, 'items'): | ||||
|             iterator = enumerate(data) | ||||
|         else: | ||||
|             iterator = data.iteritems() | ||||
|  | ||||
|         for index, value in iterator: | ||||
|             list_key = "%s%s." % (key, index) | ||||
|             if hasattr(value, '_get_changed_fields'): | ||||
|                 changed = value._get_changed_fields(inspected) | ||||
|                 changed_fields += ["%s%s" % (list_key, k) | ||||
|                                     for k in changed if k] | ||||
|             elif isinstance(value, (list, tuple, dict)): | ||||
|                 self._nestable_types_changed_fields(changed_fields, list_key, value, inspected) | ||||
|  | ||||
|     def _get_changed_fields(self, inspected=None): | ||||
|         """Returns a list of all fields that have explicitly been changed. | ||||
| @@ -412,13 +433,12 @@ class BaseDocument(object): | ||||
|         EmbeddedDocument = _import_class("EmbeddedDocument") | ||||
|         DynamicEmbeddedDocument = _import_class("DynamicEmbeddedDocument") | ||||
|         ReferenceField = _import_class("ReferenceField") | ||||
|         _changed_fields = [] | ||||
|         _changed_fields += getattr(self, '_changed_fields', []) | ||||
|  | ||||
|         changed_fields = [] | ||||
|         changed_fields += getattr(self, '_changed_fields', []) | ||||
|         inspected = inspected or set() | ||||
|         if hasattr(self, 'id'): | ||||
|         if hasattr(self, 'id') and isinstance(self.id, Hashable): | ||||
|             if self.id in inspected: | ||||
|                 return _changed_fields | ||||
|                 return changed_fields | ||||
|             inspected.add(self.id) | ||||
|  | ||||
|         for field_name in self._fields_ordered: | ||||
| @@ -434,29 +454,17 @@ class BaseDocument(object): | ||||
|             if isinstance(field, ReferenceField): | ||||
|                 continue | ||||
|             elif (isinstance(data, (EmbeddedDocument, DynamicEmbeddedDocument)) | ||||
|                and db_field_name not in _changed_fields): | ||||
|                and db_field_name not in changed_fields): | ||||
|                  # Find all embedded fields that have been changed | ||||
|                 changed = data._get_changed_fields(inspected) | ||||
|                 _changed_fields += ["%s%s" % (key, k) for k in changed if k] | ||||
|                 changed_fields += ["%s%s" % (key, k) for k in changed if k] | ||||
|             elif (isinstance(data, (list, tuple, dict)) and | ||||
|                     db_field_name not in _changed_fields): | ||||
|                 # Loop list / dict fields as they contain documents | ||||
|                 # Determine the iterator to use | ||||
|                 if not hasattr(data, 'items'): | ||||
|                     iterator = enumerate(data) | ||||
|                 else: | ||||
|                     iterator = data.iteritems() | ||||
|                 for index, value in iterator: | ||||
|                     if not hasattr(value, '_get_changed_fields'): | ||||
|                         continue | ||||
|                     if (hasattr(field, 'field') and | ||||
|                         isinstance(field.field, ReferenceField)): | ||||
|                         continue | ||||
|                     list_key = "%s%s." % (key, index) | ||||
|                     changed = value._get_changed_fields(inspected) | ||||
|                     _changed_fields += ["%s%s" % (list_key, k) | ||||
|                                         for k in changed if k] | ||||
|         return _changed_fields | ||||
|                     db_field_name not in changed_fields): | ||||
|                 if (hasattr(field, 'field') and | ||||
|                     isinstance(field.field, ReferenceField)): | ||||
|                     continue | ||||
|                 self._nestable_types_changed_fields(changed_fields, key, data, inspected) | ||||
|         return changed_fields | ||||
|  | ||||
|     def _delta(self): | ||||
|         """Returns the delta (set, unset) of the changes for a document. | ||||
| @@ -552,10 +560,6 @@ class BaseDocument(object): | ||||
|         # class if unavailable | ||||
|         class_name = son.get('_cls', cls._class_name) | ||||
|         data = dict(("%s" % key, value) for key, value in son.iteritems()) | ||||
|         if not UNICODE_KWARGS: | ||||
|             # python 2.6.4 and lower cannot handle unicode keys | ||||
|             # passed to class constructor example: cls(**data) | ||||
|             to_str_keys_recursive(data) | ||||
|  | ||||
|         # Return correct subclass for document type | ||||
|         if class_name != cls._class_name: | ||||
| @@ -773,6 +777,9 @@ class BaseDocument(object): | ||||
|         """Lookup a field based on its attribute and return a list containing | ||||
|         the field's parents and the field. | ||||
|         """ | ||||
|  | ||||
|         ListField = _import_class("ListField") | ||||
|  | ||||
|         if not isinstance(parts, (list, tuple)): | ||||
|             parts = [parts] | ||||
|         fields = [] | ||||
| @@ -780,7 +787,7 @@ class BaseDocument(object): | ||||
|  | ||||
|         for field_name in parts: | ||||
|             # Handle ListField indexing: | ||||
|             if field_name.isdigit() and hasattr(field, 'field'): | ||||
|             if field_name.isdigit() and isinstance(field, ListField): | ||||
|                 new_field = field.field | ||||
|                 fields.append(field_name) | ||||
|                 continue | ||||
|   | ||||
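A note on the `__setitem__` hunk above: relaxing the field-existence check means item-style assignment can create new attributes on dynamic documents, while unknown fields on regular documents still raise `KeyError`. A minimal sketch of that behaviour with made-up models (no database connection is needed, since nothing is saved):

    from mongoengine import Document, DynamicDocument, StringField

    class Person(DynamicDocument):
        name = StringField()

    class StrictPerson(Document):
        name = StringField()

    p = Person(name='Ada')
    p['nickname'] = 'The Countess'   # dynamic documents now accept unknown keys
    print(p.nickname)

    sp = StrictPerson(name='Alan')
    try:
        sp['nickname'] = 'Prof'      # non-dynamic documents still raise KeyError
    except KeyError as exc:
        print("rejected: %s" % exc)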
| @@ -89,12 +89,7 @@ class BaseField(object): | ||||
|             return self | ||||
|  | ||||
|         # Get value from document instance if available | ||||
|         value = instance._data.get(self.name) | ||||
|  | ||||
|         EmbeddedDocument = _import_class('EmbeddedDocument') | ||||
|         if isinstance(value, EmbeddedDocument) and value._instance is None: | ||||
|             value._instance = weakref.proxy(instance) | ||||
|         return value | ||||
|         return instance._data.get(self.name) | ||||
|  | ||||
|     def __set__(self, instance, value): | ||||
|         """Descriptor for assigning a value to a field in a document. | ||||
| @@ -116,6 +111,10 @@ class BaseField(object): | ||||
|                 # Values can't be compared, e.g. naive and tz-aware datetimes | ||||
|                 # So mark it as changed | ||||
|                 instance._mark_as_changed(self.name) | ||||
|  | ||||
|         EmbeddedDocument = _import_class('EmbeddedDocument') | ||||
|         if isinstance(value, EmbeddedDocument) and value._instance is None: | ||||
|             value._instance = weakref.proxy(instance) | ||||
|         instance._data[self.name] = value | ||||
|  | ||||
|     def error(self, message="", errors=None, field_name=None): | ||||
| @@ -203,7 +202,7 @@ class ComplexBaseField(BaseField): | ||||
|         _dereference = _import_class("DeReference")() | ||||
|  | ||||
|         self._auto_dereference = instance._fields[self.name]._auto_dereference | ||||
|         if instance._initialised and dereference: | ||||
|         if instance._initialised and dereference and instance._data.get(self.name): | ||||
|             instance._data[self.name] = _dereference( | ||||
|                 instance._data.get(self.name), max_depth=1, instance=instance, | ||||
|                 name=self.name | ||||
|   | ||||
| @@ -25,7 +25,7 @@ def _import_class(cls_name): | ||||
|                      'GenericEmbeddedDocumentField', 'GeoPointField', | ||||
|                      'PointField', 'LineStringField', 'ListField', | ||||
|                      'PolygonField', 'ReferenceField', 'StringField', | ||||
|                      'ComplexBaseField') | ||||
|                      'ComplexBaseField', 'GeoJsonBaseField') | ||||
|     queryset_classes = ('OperationError',) | ||||
|     deref_classes = ('DeReference',) | ||||
|  | ||||
|   | ||||
| @@ -18,7 +18,7 @@ _connections = {} | ||||
| _dbs = {} | ||||
|  | ||||
|  | ||||
| def register_connection(alias, name, host='localhost', port=27017, | ||||
| def register_connection(alias, name, host=None, port=None, | ||||
|                         is_slave=False, read_preference=False, slaves=None, | ||||
|                         username=None, password=None, **kwargs): | ||||
|     """Add a connection. | ||||
| @@ -43,8 +43,8 @@ def register_connection(alias, name, host='localhost', port=27017, | ||||
|  | ||||
|     conn_settings = { | ||||
|         'name': name, | ||||
|         'host': host, | ||||
|         'port': port, | ||||
|         'host': host or 'localhost', | ||||
|         'port': port or 27017, | ||||
|         'is_slave': is_slave, | ||||
|         'slaves': slaves or [], | ||||
|         'username': username, | ||||
| @@ -53,16 +53,15 @@ def register_connection(alias, name, host='localhost', port=27017, | ||||
|     } | ||||
|  | ||||
|     # Handle uri style connections | ||||
|     if "://" in host: | ||||
|         uri_dict = uri_parser.parse_uri(host) | ||||
|     if "://" in conn_settings['host']: | ||||
|         uri_dict = uri_parser.parse_uri(conn_settings['host']) | ||||
|         conn_settings.update({ | ||||
|             'host': host, | ||||
|             'name': uri_dict.get('database') or name, | ||||
|             'username': uri_dict.get('username'), | ||||
|             'password': uri_dict.get('password'), | ||||
|             'read_preference': read_preference, | ||||
|         }) | ||||
|         if "replicaSet" in host: | ||||
|         if "replicaSet" in conn_settings['host']: | ||||
|             conn_settings['replicaSet'] = True | ||||
|  | ||||
|     conn_settings.update(kwargs) | ||||
| @@ -94,20 +93,11 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): | ||||
|             raise ConnectionError(msg) | ||||
|         conn_settings = _connection_settings[alias].copy() | ||||
|  | ||||
|         if hasattr(pymongo, 'version_tuple'):  # Support for 2.1+ | ||||
|             conn_settings.pop('name', None) | ||||
|             conn_settings.pop('slaves', None) | ||||
|             conn_settings.pop('is_slave', None) | ||||
|             conn_settings.pop('username', None) | ||||
|             conn_settings.pop('password', None) | ||||
|         else: | ||||
|             # Get all the slave connections | ||||
|             if 'slaves' in conn_settings: | ||||
|                 slaves = [] | ||||
|                 for slave_alias in conn_settings['slaves']: | ||||
|                     slaves.append(get_connection(slave_alias)) | ||||
|                 conn_settings['slaves'] = slaves | ||||
|                 conn_settings.pop('read_preference', None) | ||||
|         conn_settings.pop('name', None) | ||||
|         conn_settings.pop('slaves', None) | ||||
|         conn_settings.pop('is_slave', None) | ||||
|         conn_settings.pop('username', None) | ||||
|         conn_settings.pop('password', None) | ||||
|  | ||||
|         connection_class = MongoClient | ||||
|         if 'replicaSet' in conn_settings: | ||||
| @@ -120,7 +110,19 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): | ||||
|             connection_class = MongoReplicaSetClient | ||||
|  | ||||
|         try: | ||||
|             _connections[alias] = connection_class(**conn_settings) | ||||
|             connection = None | ||||
|             connection_settings_iterator = ((alias, settings.copy()) for alias, settings in _connection_settings.iteritems()) | ||||
|             for alias, connection_settings in connection_settings_iterator: | ||||
|                 connection_settings.pop('name', None) | ||||
|                 connection_settings.pop('slaves', None) | ||||
|                 connection_settings.pop('is_slave', None) | ||||
|                 connection_settings.pop('username', None) | ||||
|                 connection_settings.pop('password', None) | ||||
|                 if conn_settings == connection_settings and _connections.get(alias, None): | ||||
|                     connection = _connections[alias] | ||||
|                     break | ||||
|  | ||||
|             _connections[alias] = connection if connection else connection_class(**conn_settings) | ||||
|         except Exception, e: | ||||
|             raise ConnectionError("Cannot connect to database %s :\n%s" % (alias, e)) | ||||
|     return _connections[alias] | ||||
|   | ||||
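The connection.py changes above default `host`/`port` to None, read URI-style hosts out of `conn_settings`, and make `get_connection` reuse an existing client when two aliases end up with identical settings. A hedged sketch of the registrations this supports (aliases, database names, and the URI are illustrative; nothing connects until `get_connection`/`get_db` is called):

    from mongoengine import register_connection

    # host/port may now be omitted; they fall back to 'localhost' and 27017.
    register_connection('default', name='app_db')

    # URI-style hosts are parsed from conn_settings: the database name and
    # credentials come from the URI, and a 'replicaSet' option switches the
    # client class when the connection is actually opened.
    register_connection(
        'reporting', name='unused_when_uri_names_a_db',
        host='mongodb://report_user:secret@db.example.com:27017/reports')

    # With the new reuse logic, a second alias whose cleaned-up settings match
    # an already-open connection shares that client instead of reconnecting.
    register_connection(
        'reporting_ro', name='unused_when_uri_names_a_db',
        host='mongodb://report_user:secret@db.example.com:27017/reports')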
| @@ -1,6 +1,5 @@ | ||||
| from mongoengine.common import _import_class | ||||
| from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db | ||||
| from mongoengine.queryset import QuerySet | ||||
|  | ||||
|  | ||||
| __all__ = ("switch_db", "switch_collection", "no_dereference", | ||||
| @@ -162,12 +161,6 @@ class no_sub_classes(object): | ||||
|         return self.cls | ||||
|  | ||||
|  | ||||
| class QuerySetNoDeRef(QuerySet): | ||||
|     """Special no_dereference QuerySet""" | ||||
|     def __dereference(items, max_depth=1, instance=None, name=None): | ||||
|             return items | ||||
|  | ||||
|  | ||||
| class query_counter(object): | ||||
|     """ Query_counter context manager to get the number of queries. """ | ||||
|  | ||||
|   | ||||
| @@ -8,6 +8,10 @@ from django.contrib import auth | ||||
| from django.contrib.auth.models import AnonymousUser | ||||
| from django.utils.translation import ugettext_lazy as _ | ||||
|  | ||||
| from .utils import datetime_now | ||||
|  | ||||
| REDIRECT_FIELD_NAME = 'next' | ||||
|  | ||||
| try: | ||||
|     from django.contrib.auth.hashers import check_password, make_password | ||||
| except ImportError: | ||||
| @@ -33,10 +37,6 @@ except ImportError: | ||||
|         hash = get_hexdigest(algo, salt, raw_password) | ||||
|         return '%s$%s$%s' % (algo, salt, hash) | ||||
|  | ||||
| from .utils import datetime_now | ||||
|  | ||||
| REDIRECT_FIELD_NAME = 'next' | ||||
|  | ||||
|  | ||||
| class ContentType(Document): | ||||
|     name = StringField(max_length=100) | ||||
| @@ -230,6 +230,9 @@ class User(Document): | ||||
|     date_joined = DateTimeField(default=datetime_now, | ||||
|                                 verbose_name=_('date joined')) | ||||
|  | ||||
|     user_permissions = ListField(ReferenceField(Permission), verbose_name=_('user permissions'), | ||||
|                                                 help_text=_('Permissions for the user.')) | ||||
|  | ||||
|     USERNAME_FIELD = 'username' | ||||
|     REQUIRED_FIELDS = ['email'] | ||||
|  | ||||
| @@ -378,9 +381,10 @@ class MongoEngineBackend(object): | ||||
|     supports_object_permissions = False | ||||
|     supports_anonymous_user = False | ||||
|     supports_inactive_user = False | ||||
|     _user_doc = False | ||||
|  | ||||
|     def authenticate(self, username=None, password=None): | ||||
|         user = User.objects(username=username).first() | ||||
|         user = self.user_document.objects(username=username).first() | ||||
|         if user: | ||||
|             if password and user.check_password(password): | ||||
|                 backend = auth.get_backends()[0] | ||||
| @@ -389,8 +393,14 @@ class MongoEngineBackend(object): | ||||
|         return None | ||||
|  | ||||
|     def get_user(self, user_id): | ||||
|         return User.objects.with_id(user_id) | ||||
|         return self.user_document.objects.with_id(user_id) | ||||
|  | ||||
|     @property | ||||
|     def user_document(self): | ||||
|         if self._user_doc is False: | ||||
|             from .mongo_auth.models import get_user_document | ||||
|             self._user_doc = get_user_document() | ||||
|         return self._user_doc | ||||
|  | ||||
| def get_user(userid): | ||||
|     """Returns a User object from an id (User.id). Django's equivalent takes | ||||
|   | ||||
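With `user_document` resolved lazily through `get_user_document()`, the backend can authenticate against a custom user document instead of the bundled `User`. A sketch of the Django settings this is aimed at; `MONGOENGINE_USER_DOCUMENT` is the setting `get_user_document` is expected to read here, and the dotted paths are illustrative:

    # settings.py (illustrative)
    AUTHENTICATION_BACKENDS = (
        'mongoengine.django.auth.MongoEngineBackend',
    )

    # Dotted path to the Document class the backend should query in
    # authenticate() and get_user(); defaults to the bundled User document.
    MONGOENGINE_USER_DOCUMENT = 'mongoengine.django.auth.User'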
| @@ -1,4 +1,5 @@ | ||||
| from django.conf import settings | ||||
| from django.contrib.auth.hashers import make_password | ||||
| from django.contrib.auth.models import UserManager | ||||
| from django.core.exceptions import ImproperlyConfigured | ||||
| from django.db import models | ||||
| @@ -105,3 +106,10 @@ class MongoUser(models.Model): | ||||
|     """ | ||||
|  | ||||
|     objects = MongoUserManager() | ||||
|  | ||||
|     class Meta: | ||||
|         app_label = 'mongo_auth' | ||||
|  | ||||
|     def set_password(self, password): | ||||
|         """Doesn't do anything, but works around the issue with Django 1.6.""" | ||||
|         make_password(password) | ||||
|   | ||||
| @@ -1,3 +1,4 @@ | ||||
| from bson import json_util | ||||
| from django.conf import settings | ||||
| from django.contrib.sessions.backends.base import SessionBase, CreateError | ||||
| from django.core.exceptions import SuspiciousOperation | ||||
| @@ -55,6 +56,12 @@ class SessionStore(SessionBase): | ||||
|     """A MongoEngine-based session store for Django. | ||||
|     """ | ||||
|  | ||||
|     def _get_session(self, *args, **kwargs): | ||||
|         sess = super(SessionStore, self)._get_session(*args, **kwargs) | ||||
|         if sess.get('_auth_user_id', None): | ||||
|             sess['_auth_user_id'] = str(sess.get('_auth_user_id')) | ||||
|         return sess | ||||
|  | ||||
|     def load(self): | ||||
|         try: | ||||
|             s = MongoSession.objects(session_key=self.session_key, | ||||
| @@ -103,3 +110,15 @@ class SessionStore(SessionBase): | ||||
|                 return | ||||
|             session_key = self.session_key | ||||
|         MongoSession.objects(session_key=session_key).delete() | ||||
|  | ||||
|  | ||||
| class BSONSerializer(object): | ||||
|     """ | ||||
|     Serializer that can handle BSON types (eg ObjectId). | ||||
|     """ | ||||
|     def dumps(self, obj): | ||||
|         return json_util.dumps(obj, separators=(',', ':')).encode('ascii') | ||||
|  | ||||
|     def loads(self, data): | ||||
|         return json_util.loads(data.decode('ascii')) | ||||
|  | ||||
|   | ||||
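The `_get_session` override coerces `_auth_user_id` to a string, and `BSONSerializer` exists so session payloads containing BSON types (an `ObjectId`, for instance) survive Django 1.6's serializer-based sessions. A hedged example of wiring it up via Django's `SESSION_SERIALIZER` setting:

    # settings.py (illustrative)
    SESSION_ENGINE = 'mongoengine.django.sessions'
    SESSION_SERIALIZER = 'mongoengine.django.sessions.BSONSerializer'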
| @@ -76,7 +76,7 @@ class GridFSStorage(Storage): | ||||
|         """Find the documents in the store with the given name | ||||
|         """ | ||||
|         docs = self.document.objects | ||||
|         doc = [d for d in docs if getattr(d, self.field).name == name] | ||||
|         doc = [d for d in docs if hasattr(getattr(d, self.field), 'name') and getattr(d, self.field).name == name] | ||||
|         if doc: | ||||
|             return doc[0] | ||||
|         else: | ||||
|   | ||||
| @@ -12,7 +12,9 @@ from mongoengine.common import _import_class | ||||
| from mongoengine.base import (DocumentMetaclass, TopLevelDocumentMetaclass, | ||||
|                               BaseDocument, BaseDict, BaseList, | ||||
|                               ALLOW_INHERITANCE, get_document) | ||||
| from mongoengine.queryset import OperationError, NotUniqueError, QuerySet | ||||
| from mongoengine.errors import ValidationError | ||||
| from mongoengine.queryset import (OperationError, NotUniqueError, | ||||
|                                   QuerySet, transform) | ||||
| from mongoengine.connection import get_db, DEFAULT_CONNECTION_NAME | ||||
| from mongoengine.context_managers import switch_db, switch_collection | ||||
|  | ||||
| @@ -67,7 +69,7 @@ class EmbeddedDocument(BaseDocument): | ||||
|  | ||||
|     def __eq__(self, other): | ||||
|         if isinstance(other, self.__class__): | ||||
|             return self._data == other._data | ||||
|             return self.to_mongo() == other.to_mongo() | ||||
|         return False | ||||
|  | ||||
|     def __ne__(self, other): | ||||
| @@ -182,7 +184,7 @@ class Document(BaseDocument): | ||||
|  | ||||
|     def save(self, force_insert=False, validate=True, clean=True, | ||||
|              write_concern=None,  cascade=None, cascade_kwargs=None, | ||||
|              _refs=None, **kwargs): | ||||
|              _refs=None, save_condition=None, **kwargs): | ||||
|         """Save the :class:`~mongoengine.Document` to the database. If the | ||||
|         document already exists, it will be updated, otherwise it will be | ||||
|         created. | ||||
| @@ -205,7 +207,8 @@ class Document(BaseDocument): | ||||
|         :param cascade_kwargs: (optional) kwargs dictionary to be passed through | ||||
|             to cascading saves.  Implies ``cascade=True``. | ||||
|         :param _refs: A list of processed references used in cascading saves | ||||
|  | ||||
|         :param save_condition: only perform save if matching record in db | ||||
|             satisfies condition(s) (e.g., version number) | ||||
|         .. versionchanged:: 0.5 | ||||
|             In existing documents it only saves changed fields using | ||||
|             set / unset.  Saves are cascaded and any | ||||
| @@ -219,6 +222,9 @@ class Document(BaseDocument): | ||||
|             meta['cascade'] = True.  Also you can pass different kwargs to | ||||
|             the cascade save using cascade_kwargs which overwrites the | ||||
|             existing kwargs with custom values. | ||||
|         .. versionchanged:: 0.8.5 | ||||
|             Optional save_condition that only overwrites existing documents | ||||
|             if the condition is satisfied in the current db record. | ||||
|         """ | ||||
|         signals.pre_save.send(self.__class__, document=self) | ||||
|  | ||||
| @@ -232,7 +238,8 @@ class Document(BaseDocument): | ||||
|  | ||||
|         created = ('_id' not in doc or self._created or force_insert) | ||||
|  | ||||
|         signals.pre_save_post_validation.send(self.__class__, document=self, created=created) | ||||
|         signals.pre_save_post_validation.send(self.__class__, document=self, | ||||
|                                               created=created) | ||||
|  | ||||
|         try: | ||||
|             collection = self._get_collection() | ||||
| @@ -245,7 +252,12 @@ class Document(BaseDocument): | ||||
|                 object_id = doc['_id'] | ||||
|                 updates, removals = self._delta() | ||||
|                 # Need to add shard key to query, or you get an error | ||||
|                 select_dict = {'_id': object_id} | ||||
|                 if save_condition is not None: | ||||
|                     select_dict = transform.query(self.__class__, | ||||
|                                                   **save_condition) | ||||
|                 else: | ||||
|                     select_dict = {} | ||||
|                 select_dict['_id'] = object_id | ||||
|                 shard_key = self.__class__._meta.get('shard_key', tuple()) | ||||
|                 for k in shard_key: | ||||
|                     actual_key = self._db_field_map.get(k, k) | ||||
| @@ -265,10 +277,12 @@ class Document(BaseDocument): | ||||
|                 if removals: | ||||
|                     update_query["$unset"] = removals | ||||
|                 if updates or removals: | ||||
|                     upsert = save_condition is None | ||||
|                     last_error = collection.update(select_dict, update_query, | ||||
|                                                    upsert=True, **write_concern) | ||||
|                                                    upsert=upsert, **write_concern) | ||||
|                     created = is_new_object(last_error) | ||||
|  | ||||
|  | ||||
|             if cascade is None: | ||||
|                 cascade = self._meta.get('cascade', False) or cascade_kwargs is not None | ||||
|  | ||||
| @@ -283,7 +297,9 @@ class Document(BaseDocument): | ||||
|                     kwargs.update(cascade_kwargs) | ||||
|                 kwargs['_refs'] = _refs | ||||
|                 self.cascade_save(**kwargs) | ||||
|  | ||||
|         except pymongo.errors.DuplicateKeyError, err: | ||||
|             message = u'Tried to save duplicate unique keys (%s)' | ||||
|             raise NotUniqueError(message % unicode(err)) | ||||
|         except pymongo.errors.OperationFailure, err: | ||||
|             message = 'Could not save document (%s)' | ||||
|             if re.match('^E1100[01] duplicate key', unicode(err)): | ||||
| @@ -453,14 +469,16 @@ class Document(BaseDocument): | ||||
|         .. versionadded:: 0.1.2 | ||||
|         .. versionchanged:: 0.6  Now chainable | ||||
|         """ | ||||
|         if not self.pk: | ||||
|             raise self.DoesNotExist("Document does not exist") | ||||
|         obj = self._qs.read_preference(ReadPreference.PRIMARY).filter( | ||||
|                 **self._object_key).limit(1).select_related(max_depth=max_depth) | ||||
|                     **self._object_key).limit(1).select_related(max_depth=max_depth) | ||||
|  | ||||
|  | ||||
|         if obj: | ||||
|             obj = obj[0] | ||||
|         else: | ||||
|             msg = "Reloaded document has been deleted" | ||||
|             raise OperationError(msg) | ||||
|             raise self.DoesNotExist("Document does not exist") | ||||
|         for field in self._fields_ordered: | ||||
|             setattr(self, field, self._reload(field, obj[field])) | ||||
|         self._changed_fields = obj._changed_fields | ||||
| @@ -550,6 +568,8 @@ class Document(BaseDocument): | ||||
|         index_cls = cls._meta.get('index_cls', True) | ||||
|  | ||||
|         collection = cls._get_collection() | ||||
|         if collection.read_preference > 1: | ||||
|             return | ||||
|  | ||||
|         # determine if an index which we are creating includes | ||||
|         # _cls as its first field; if so, we can avoid creating | ||||
|   | ||||
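The new `save_condition` argument swaps the update's `{'_id': ...}` selector for an arbitrary query and disables the upsert, which makes it usable for optimistic concurrency control: the write only lands if the stored document still satisfies the condition. A minimal sketch, assuming a running local mongod and a made-up `Page` model:

    from mongoengine import Document, IntField, StringField, connect

    connect('example_db')  # assumption: a local mongod is available

    class Page(Document):
        body = StringField()
        version = IntField(default=1)

    page = Page(body='first draft').save()

    # Only overwrite if nobody else bumped `version` in the meantime; because
    # upsert is disabled when save_condition is given, a lost race is a silent
    # no-op rather than a duplicate insert.
    page.body = 'second draft'
    page.version += 1
    page.save(save_condition={'version': 1})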
| @@ -42,7 +42,8 @@ __all__ = ['StringField',  'URLField',  'EmailField',  'IntField',  'LongField', | ||||
|            'GenericReferenceField',  'BinaryField',  'GridFSError', | ||||
|            'GridFSProxy',  'FileField',  'ImageGridFsProxy', | ||||
|            'ImproperlyConfigured',  'ImageField',  'GeoPointField', 'PointField', | ||||
|            'LineStringField', 'PolygonField', 'SequenceField',  'UUIDField'] | ||||
|            'LineStringField', 'PolygonField', 'SequenceField',  'UUIDField', | ||||
|            'GeoJsonBaseField'] | ||||
|  | ||||
|  | ||||
| RECURSIVE_REFERENCE_CONSTANT = 'self' | ||||
| @@ -152,7 +153,7 @@ class EmailField(StringField): | ||||
|     EMAIL_REGEX = re.compile( | ||||
|         r"(^[-!#$%&'*+/=?^_`{}|~0-9A-Z]+(\.[-!#$%&'*+/=?^_`{}|~0-9A-Z]+)*"  # dot-atom | ||||
|         r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]|\\[\001-011\013\014\016-\177])*"'  # quoted-string | ||||
|         r')@(?:[A-Z0-9](?:[A-Z0-9-]{0,253}[A-Z0-9])?\.)+[A-Z]{2,6}\.?$', re.IGNORECASE  # domain | ||||
|         r')@(?:[A-Z0-9](?:[A-Z0-9-]{0,253}[A-Z0-9])?\.)+[A-Z]{2,6}$', re.IGNORECASE  # domain | ||||
|     ) | ||||
|  | ||||
|     def validate(self, value): | ||||
| @@ -304,7 +305,10 @@ class DecimalField(BaseField): | ||||
|             return value | ||||
|  | ||||
|         # Convert to string for python 2.6 before casting to Decimal | ||||
|         value = decimal.Decimal("%s" % value) | ||||
|         try: | ||||
|             value = decimal.Decimal("%s" % value) | ||||
|         except decimal.InvalidOperation: | ||||
|             return value | ||||
|         return value.quantize(self.precision, rounding=self.rounding) | ||||
|  | ||||
|     def to_mongo(self, value): | ||||
| @@ -387,7 +391,7 @@ class DateTimeField(BaseField): | ||||
|         if dateutil: | ||||
|             try: | ||||
|                 return dateutil.parser.parse(value) | ||||
|             except ValueError: | ||||
|             except (TypeError, ValueError): | ||||
|                 return None | ||||
|  | ||||
|         # split usecs, because they are not recognized by strptime. | ||||
| @@ -735,13 +739,28 @@ class SortedListField(ListField): | ||||
|                           reverse=self._order_reverse) | ||||
|         return sorted(value, reverse=self._order_reverse) | ||||
|  | ||||
| def key_not_string(d): | ||||
|     """ Helper function to recursively determine if any key in a dictionary is | ||||
|     not a string. | ||||
|     """ | ||||
|     for k, v in d.items(): | ||||
|         if not isinstance(k, basestring) or (isinstance(v, dict) and key_not_string(v)): | ||||
|             return True | ||||
|  | ||||
| def key_has_dot_or_dollar(d): | ||||
|     """ Helper function to recursively determine if any key in a dictionary | ||||
|     contains a dot or a dollar sign. | ||||
|     """ | ||||
|     for k, v in d.items(): | ||||
|         if ('.' in k or '$' in k) or (isinstance(v, dict) and key_has_dot_or_dollar(v)): | ||||
|             return True | ||||
|  | ||||
| class DictField(ComplexBaseField): | ||||
|     """A dictionary field that wraps a standard Python dictionary. This is | ||||
|     similar to an embedded document, but the structure is not defined. | ||||
|  | ||||
|     .. note:: | ||||
|         Required means it cannot be empty - as the default for ListFields is [] | ||||
|         Required means it cannot be empty - as the default for DictFields is {} | ||||
|  | ||||
|     .. versionadded:: 0.3 | ||||
|     .. versionchanged:: 0.5 - Can now handle complex / varying types of data | ||||
| @@ -761,11 +780,11 @@ class DictField(ComplexBaseField): | ||||
|         if not isinstance(value, dict): | ||||
|             self.error('Only dictionaries may be used in a DictField') | ||||
|  | ||||
|         if any(k for k in value.keys() if not isinstance(k, basestring)): | ||||
|         if key_not_string(value): | ||||
|             msg = ("Invalid dictionary key - documents must " | ||||
|                    "have only string keys") | ||||
|             self.error(msg) | ||||
|         if any(('.' in k or '$' in k) for k in value.keys()): | ||||
|         if key_has_dot_or_dollar(value): | ||||
|             self.error('Invalid dictionary key name - keys may not contain "."' | ||||
|                        ' or "$" characters') | ||||
|         super(DictField, self).validate(value) | ||||
| @@ -1004,7 +1023,10 @@ class GenericReferenceField(BaseField): | ||||
|         id_ = id_field.to_mongo(id_) | ||||
|         collection = document._get_collection_name() | ||||
|         ref = DBRef(collection, id_) | ||||
|         return {'_cls': document._class_name, '_ref': ref} | ||||
|         return SON(( | ||||
|             ('_cls', document._class_name), | ||||
|             ('_ref', ref) | ||||
|         )) | ||||
|  | ||||
|     def prepare_query_value(self, op, value): | ||||
|         if value is None: | ||||
| @@ -1591,7 +1613,12 @@ class UUIDField(BaseField): | ||||
|  | ||||
|  | ||||
| class GeoPointField(BaseField): | ||||
|     """A list storing a latitude and longitude. | ||||
|     """A list storing a longitude and latitude coordinate.  | ||||
|  | ||||
|     .. note:: this represents a generic point in a 2D plane and is the legacy way | ||||
|         of representing a geo point. It admits 2d indexes but not "2dsphere" indexes | ||||
|         in MongoDB > 2.4, which are more natural for modeling geospatial points. | ||||
|         See :ref:`geospatial-indexes` | ||||
|  | ||||
|     .. versionadded:: 0.4 | ||||
|     """ | ||||
| @@ -1613,7 +1640,7 @@ class GeoPointField(BaseField): | ||||
|  | ||||
|  | ||||
| class PointField(GeoJsonBaseField): | ||||
|     """A geo json field storing a latitude and longitude. | ||||
|     """A GeoJSON field storing a longitude and latitude coordinate. | ||||
|  | ||||
|     The data is represented as: | ||||
|  | ||||
| @@ -1632,7 +1659,7 @@ class PointField(GeoJsonBaseField): | ||||
|  | ||||
|  | ||||
| class LineStringField(GeoJsonBaseField): | ||||
|     """A geo json field storing a line of latitude and longitude coordinates. | ||||
|     """A GeoJSON field storing a line of longitude and latitude coordinates. | ||||
|  | ||||
|     The data is represented as: | ||||
|  | ||||
| @@ -1650,7 +1677,7 @@ class LineStringField(GeoJsonBaseField): | ||||
|  | ||||
|  | ||||
| class PolygonField(GeoJsonBaseField): | ||||
|     """A geo json field storing a polygon of latitude and longitude coordinates. | ||||
|     """A GeoJSON field storing a polygon of longitude and latitude coordinates. | ||||
|  | ||||
|     The data is represented as: | ||||
|  | ||||
|   | ||||
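Because `key_not_string` and `key_has_dot_or_dollar` recurse, `DictField.validate` now rejects bad keys nested anywhere in the value, not just at the top level. A small sketch with an illustrative model; `validate()` runs entirely client-side, so no connection is needed:

    from mongoengine import DictField, Document, ValidationError

    class Settings(Document):
        options = DictField()

    doc = Settings(options={'ui': {'theme.dark': True}})
    try:
        doc.validate()  # nested 'theme.dark' contains '.', which is now caught
    except ValidationError as exc:
        print("rejected: %s" % exc)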
| @@ -3,8 +3,6 @@ | ||||
| import sys | ||||
|  | ||||
| PY3 = sys.version_info[0] == 3 | ||||
| PY25 = sys.version_info[:2] == (2, 5) | ||||
| UNICODE_KWARGS = int(''.join([str(x) for x in sys.version_info[:3]])) > 264 | ||||
|  | ||||
| if PY3: | ||||
|     import codecs | ||||
| @@ -29,33 +27,3 @@ else: | ||||
|     txt_type = unicode | ||||
|  | ||||
| str_types = (bin_type, txt_type) | ||||
|  | ||||
| if PY25: | ||||
|     def product(*args, **kwds): | ||||
|         pools = map(tuple, args) * kwds.get('repeat', 1) | ||||
|         result = [[]] | ||||
|         for pool in pools: | ||||
|             result = [x + [y] for x in result for y in pool] | ||||
|         for prod in result: | ||||
|             yield tuple(prod) | ||||
|     reduce = reduce | ||||
| else: | ||||
|     from itertools import product | ||||
|     from functools import reduce | ||||
|  | ||||
|  | ||||
| # For use with Python 2.5 | ||||
| # converts all keys from unicode to str for d and all nested dictionaries | ||||
| def to_str_keys_recursive(d): | ||||
|     if isinstance(d, list): | ||||
|         for val in d: | ||||
|             if isinstance(val, (dict, list)): | ||||
|                 to_str_keys_recursive(val) | ||||
|     elif isinstance(d, dict): | ||||
|         for key, val in d.items(): | ||||
|             if isinstance(val, (dict, list)): | ||||
|                 to_str_keys_recursive(val) | ||||
|             if isinstance(key, unicode): | ||||
|                 d[str(key)] = d.pop(key) | ||||
|     else: | ||||
|         raise ValueError("non list/dict parameter not allowed") | ||||
|   | ||||
| @@ -10,14 +10,15 @@ import warnings | ||||
| from bson.code import Code | ||||
| from bson import json_util | ||||
| import pymongo | ||||
| import pymongo.errors | ||||
| from pymongo.common import validate_read_preference | ||||
|  | ||||
| from mongoengine import signals | ||||
| from mongoengine.context_managers import switch_db | ||||
| from mongoengine.common import _import_class | ||||
| from mongoengine.base.common import get_document | ||||
| from mongoengine.errors import (OperationError, NotUniqueError, | ||||
|                                 InvalidQueryError, LookUpError) | ||||
|  | ||||
| from mongoengine.queryset import transform | ||||
| from mongoengine.queryset.field_list import QueryFieldList | ||||
| from mongoengine.queryset.visitor import Q, QNode | ||||
| @@ -50,7 +51,7 @@ class BaseQuerySet(object): | ||||
|         self._initial_query = {} | ||||
|         self._where_clause = None | ||||
|         self._loaded_fields = QueryFieldList() | ||||
|         self._ordering = [] | ||||
|         self._ordering = None | ||||
|         self._snapshot = False | ||||
|         self._timeout = True | ||||
|         self._class_check = True | ||||
| @@ -154,6 +155,22 @@ class BaseQuerySet(object): | ||||
|     def __iter__(self): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def _has_data(self): | ||||
|         """ Return whether the cursor has any data. """ | ||||
|  | ||||
|         queryset = self.order_by() | ||||
|         return False if queryset.first() is None else True | ||||
|  | ||||
|     def __nonzero__(self): | ||||
|         """ Avoid loading all records when evaluated as a boolean (Python 2). """ | ||||
|  | ||||
|         return self._has_data() | ||||
|  | ||||
|     def __bool__(self): | ||||
|         """ Avoid loading all records when evaluated as a boolean (Python 3). """ | ||||
|  | ||||
|         return self._has_data() | ||||
|  | ||||
|     # Core functions | ||||
|  | ||||
|     def all(self): | ||||
| @@ -302,8 +319,11 @@ class BaseQuerySet(object): | ||||
|         signals.pre_bulk_insert.send(self._document, documents=docs) | ||||
|         try: | ||||
|             ids = self._collection.insert(raw, **write_concern) | ||||
|         except pymongo.errors.DuplicateKeyError, err: | ||||
|             message = 'Could not save document (%s)'; | ||||
|             raise NotUniqueError(message % unicode(err)) | ||||
|         except pymongo.errors.OperationFailure, err: | ||||
|             message = 'Could not save document (%s)' | ||||
|             message = 'Could not save document (%s)'; | ||||
|             if re.match('^E1100[01] duplicate key', unicode(err)): | ||||
|                 # E11000 - duplicate key error index | ||||
|                 # E11001 - duplicate key on update | ||||
| @@ -331,7 +351,7 @@ class BaseQuerySet(object): | ||||
|             :meth:`skip` that has been applied to this cursor into account when | ||||
|             getting the count | ||||
|         """ | ||||
|         if self._limit == 0 and with_limit_and_skip: | ||||
|         if self._limit == 0 and with_limit_and_skip or self._none: | ||||
|             return 0 | ||||
|         return self._cursor.count(with_limit_and_skip=with_limit_and_skip) | ||||
|  | ||||
| @@ -386,7 +406,7 @@ class BaseQuerySet(object): | ||||
|                 ref_q = document_cls.objects(**{field_name + '__in': self}) | ||||
|                 ref_q_count = ref_q.count() | ||||
|                 if (doc != document_cls and ref_q_count > 0 | ||||
|                    or (doc == document_cls and ref_q_count > 0)): | ||||
|                     or (doc == document_cls and ref_q_count > 0)): | ||||
|                     ref_q.delete(write_concern=write_concern) | ||||
|             elif rule == NULLIFY: | ||||
|                 document_cls.objects(**{field_name + '__in': self}).update( | ||||
| @@ -440,6 +460,8 @@ class BaseQuerySet(object): | ||||
|                 return result | ||||
|             elif result: | ||||
|                 return result['n'] | ||||
|         except pymongo.errors.DuplicateKeyError, err: | ||||
|             raise NotUniqueError(u'Update failed (%s)' % unicode(err)) | ||||
|         except pymongo.errors.OperationFailure, err: | ||||
|             if unicode(err) == u'multi not coded yet': | ||||
|                 message = u'update() method requires MongoDB 1.1.3+' | ||||
| @@ -463,6 +485,59 @@ class BaseQuerySet(object): | ||||
|         return self.update( | ||||
|             upsert=upsert, multi=False, write_concern=write_concern, **update) | ||||
|  | ||||
|     def modify(self, upsert=False, full_response=False, remove=False, new=False, **update): | ||||
|         """Update and return the updated document. | ||||
|  | ||||
|         Returns either the document before or after modification, based on the | ||||
|         `new` parameter. If no documents match the query and `upsert` is false, | ||||
|         returns ``None``. If upserting and `new` is false, returns ``None``. | ||||
|  | ||||
|         If the full_response parameter is ``True``, the return value will be | ||||
|         the entire response object from the server, including the 'ok' and | ||||
|         'lastErrorObject' fields, rather than just the modified document. | ||||
|         This is useful mainly because the 'lastErrorObject' document holds | ||||
|         information about the command's execution. | ||||
|  | ||||
|         :param upsert: insert if document doesn't exist (default ``False``) | ||||
|         :param full_response: return the entire response object from the | ||||
|             server (default ``False``) | ||||
|         :param remove: remove rather than updating (default ``False``) | ||||
|         :param new: return updated rather than original document | ||||
|             (default ``False``) | ||||
|         :param update: Django-style update keyword arguments | ||||
|  | ||||
|         .. versionadded:: 0.9 | ||||
|         """ | ||||
|  | ||||
|         if remove and new: | ||||
|             raise OperationError("Conflicting parameters: remove and new") | ||||
|  | ||||
|         if not update and not upsert and not remove: | ||||
|             raise OperationError("No update parameters, must either update or remove") | ||||
|  | ||||
|         queryset = self.clone() | ||||
|         query = queryset._query | ||||
|         update = transform.update(queryset._document, **update) | ||||
|         sort = queryset._ordering | ||||
|  | ||||
|         try: | ||||
|             result = queryset._collection.find_and_modify( | ||||
|                 query, update, upsert=upsert, sort=sort, remove=remove, new=new, | ||||
|                 full_response=full_response, **self._cursor_args) | ||||
|         except pymongo.errors.DuplicateKeyError, err: | ||||
|             raise NotUniqueError(u"Update failed (%s)" % err) | ||||
|         except pymongo.errors.OperationFailure, err: | ||||
|             raise OperationError(u"Update failed (%s)" % err) | ||||
|  | ||||
|         if full_response: | ||||
|             if result["value"] is not None: | ||||
|                 result["value"] = self._document._from_son(result["value"]) | ||||
|         else: | ||||
|             if result is not None: | ||||
|                 result = self._document._from_son(result) | ||||
|  | ||||
|         return result | ||||
|  | ||||
|     def with_id(self, object_id): | ||||
|         """Retrieve the object matching the id provided.  Uses `object_id` only | ||||
|         and raises InvalidQueryError if a filter has been applied. Returns | ||||
| @@ -519,6 +594,19 @@ class BaseQuerySet(object): | ||||
|  | ||||
|         return self | ||||
|  | ||||
|     def using(self, alias): | ||||
|         """This method is for controlling which database the QuerySet will be | ||||
|         evaluated against if you are using more than one database. | ||||
|  | ||||
|         :param alias: The database alias | ||||
|  | ||||
|         .. versionadded:: 0.8 | ||||
|         """ | ||||
|  | ||||
|         with switch_db(self._document, alias) as cls: | ||||
|             collection = cls._get_collection() | ||||
|  | ||||
|         return self.clone_into(self.__class__(self._document, collection)) | ||||
|  | ||||
|     def clone(self): | ||||
|         """Creates a copy of the current | ||||
|           :class:`~mongoengine.queryset.QuerySet` | ||||
| @@ -621,8 +709,15 @@ class BaseQuerySet(object): | ||||
|         try: | ||||
|             field = self._fields_to_dbfields([field]).pop() | ||||
|         finally: | ||||
|             return self._dereference(queryset._cursor.distinct(field), 1, | ||||
|                                      name=field, instance=self._document) | ||||
|             distinct = self._dereference(queryset._cursor.distinct(field), 1, | ||||
|                                          name=field, instance=self._document) | ||||
|  | ||||
|             # We may need to cast to the correct type eg. ListField(EmbeddedDocumentField) | ||||
|             doc_field = getattr(self._document._fields.get(field), "field", None) | ||||
|             instance = getattr(doc_field, "document_type", False) | ||||
|             if instance: | ||||
|                 distinct = [instance(**doc) for doc in distinct] | ||||
|             return distinct | ||||
|  | ||||
|     def only(self, *fields): | ||||
|         """Load only a subset of this document's fields. :: | ||||
| @@ -850,7 +945,7 @@ class BaseQuerySet(object): | ||||
|         :param output: output collection name, if set to 'inline' will try to | ||||
|            use :class:`~pymongo.collection.Collection.inline_map_reduce` | ||||
|            This can also be a dictionary containing output options | ||||
|            see: http://docs.mongodb.org/manual/reference/commands/#mapReduce | ||||
|            see: http://docs.mongodb.org/manual/reference/command/mapReduce/#dbcmd.mapReduce | ||||
|         :param finalize_f: finalize function, an optional function that | ||||
|                            performs any post-reduction processing. | ||||
|         :param scope: values to insert into map/reduce global scope. Optional. | ||||
| @@ -916,7 +1011,7 @@ class BaseQuerySet(object): | ||||
|             mr_args['out'] = output | ||||
|  | ||||
|         results = getattr(queryset._collection, map_reduce_function)( | ||||
|                           map_f, reduce_f, **mr_args) | ||||
|             map_f, reduce_f, **mr_args) | ||||
|  | ||||
|         if map_reduce_function == 'map_reduce': | ||||
|             results = results.find() | ||||
| @@ -1179,8 +1274,9 @@ class BaseQuerySet(object): | ||||
|             if self._ordering: | ||||
|                 # Apply query ordering | ||||
|                 self._cursor_obj.sort(self._ordering) | ||||
|             elif self._document._meta['ordering']: | ||||
|                 # Otherwise, apply the ordering from the document model | ||||
|             elif self._ordering is None and self._document._meta['ordering']: | ||||
|                 # Otherwise, apply the ordering from the document model, unless | ||||
|                 # it's been explicitly cleared via order_by with no arguments | ||||
|                 order = self._get_order_by(self._document._meta['ordering']) | ||||
|                 self._cursor_obj.sort(order) | ||||
|  | ||||
| @@ -1352,7 +1448,7 @@ class BaseQuerySet(object): | ||||
|                 for subdoc in subclasses: | ||||
|                     try: | ||||
|                         subfield = ".".join(f.db_field for f in | ||||
|                                         subdoc._lookup_field(field.split('.'))) | ||||
|                                             subdoc._lookup_field(field.split('.'))) | ||||
|                         ret.append(subfield) | ||||
|                         found = True | ||||
|                         break | ||||
| @@ -1382,7 +1478,7 @@ class BaseQuerySet(object): | ||||
|                 pass | ||||
|             key_list.append((key, direction)) | ||||
|  | ||||
|         if self._cursor_obj: | ||||
|         if self._cursor_obj and key_list: | ||||
|             self._cursor_obj.sort(key_list) | ||||
|         return key_list | ||||
|  | ||||
| @@ -1440,6 +1536,7 @@ class BaseQuerySet(object): | ||||
|                     # type of this field and use the corresponding | ||||
|                     # .to_python(...) | ||||
|                     from mongoengine.fields import EmbeddedDocumentField | ||||
|  | ||||
|                     obj = self._document | ||||
|                     for chunk in path.split('.'): | ||||
|                         obj = getattr(obj, chunk, None) | ||||
| @@ -1450,6 +1547,7 @@ class BaseQuerySet(object): | ||||
|                     if obj and data is not None: | ||||
|                         data = obj.to_python(data) | ||||
|             return data | ||||
|  | ||||
|         return clean(row) | ||||
|  | ||||
|     def _sub_js_fields(self, code): | ||||
| @@ -1458,6 +1556,7 @@ class BaseQuerySet(object): | ||||
|         substituted for the MongoDB name of the field (specified using the | ||||
|         :attr:`name` keyword argument in a field's constructor). | ||||
|         """ | ||||
|  | ||||
|         def field_sub(match): | ||||
|             # Extract just the field name, and look up the field objects | ||||
|             field_name = match.group(1).split('.') | ||||
| @@ -1491,4 +1590,4 @@ class BaseQuerySet(object): | ||||
|         msg = ("Doc.objects()._ensure_indexes() is deprecated. " | ||||
|                "Use Doc.ensure_indexes() instead.") | ||||
|         warnings.warn(msg, DeprecationWarning) | ||||
|         self._document.__class__.ensure_indexes() | ||||
|         self._document.__class__.ensure_indexes() | ||||
|   | ||||
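The new `modify()` method wraps PyMongo's `find_and_modify`, reusing the queryset's filter and ordering and returning either the pre- or post-update document. A hedged usage sketch, assuming a local mongod and a made-up `Counter` model:

    from mongoengine import Document, IntField, StringField, connect

    connect('example_db')  # assumption: a local mongod is available

    class Counter(Document):
        name = StringField(unique=True)
        value = IntField(default=0)

    Counter.drop_collection()
    Counter(name='downloads').save()

    # Atomically increment and fetch the updated document in one round trip.
    fresh = Counter.objects(name='downloads').modify(inc__value=1, new=True)
    print(fresh.value)  # -> 1

    # remove=True deletes the matched document and returns it.
    gone = Counter.objects(name='downloads').modify(remove=True)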
| @@ -155,3 +155,10 @@ class QuerySetNoCache(BaseQuerySet): | ||||
|             queryset = self.clone() | ||||
|         queryset.rewind() | ||||
|         return queryset | ||||
|  | ||||
|  | ||||
| class QuerySetNoDeRef(QuerySet): | ||||
|     """Special no_dereference QuerySet""" | ||||
|  | ||||
|     def __dereference(items, max_depth=1, instance=None, name=None): | ||||
|         return items | ||||
| @@ -38,7 +38,7 @@ def query(_doc_cls=None, _field_operation=False, **query): | ||||
|             mongo_query.update(value) | ||||
|             continue | ||||
|  | ||||
|         parts = key.split('__') | ||||
|         parts = key.rsplit('__') | ||||
|         indices = [(i, p) for i, p in enumerate(parts) if p.isdigit()] | ||||
|         parts = [part for part in parts if not part.isdigit()] | ||||
|         # Check for an operator and transform to mongo-style if there is | ||||
| @@ -206,6 +206,10 @@ def update(_doc_cls=None, **update): | ||||
|             else: | ||||
|                 field = cleaned_fields[-1] | ||||
|  | ||||
|             GeoJsonBaseField = _import_class("GeoJsonBaseField") | ||||
|             if isinstance(field, GeoJsonBaseField): | ||||
|                 value = field.to_mongo(value) | ||||
|  | ||||
|             if op in (None, 'set', 'push', 'pull'): | ||||
|                 if field.required or value is not None: | ||||
|                     value = field.prepare_query_value(op, value) | ||||
|   | ||||
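The `transform.update` hunk converts values aimed at GeoJSON fields with `to_mongo`, so atomic updates can take the same bare coordinates a normal assignment takes. An illustrative sketch with a hypothetical `Place` model, again assuming a local mongod:

    from mongoengine import Document, PointField, StringField, connect

    connect('example_db')  # assumption: a local mongod is available

    class Place(Document):
        name = StringField()
        location = PointField()

    Place(name='office', location=[-0.1278, 51.5074]).save()

    # The plain [longitude, latitude] pair is wrapped into
    # {'type': 'Point', 'coordinates': [...]} before being sent to the server.
    Place.objects(name='office').update_one(set__location=[2.3522, 48.8566])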
| @@ -1,8 +1,9 @@ | ||||
| import copy | ||||
|  | ||||
| from mongoengine.errors import InvalidQueryError | ||||
| from mongoengine.python_support import product, reduce | ||||
| from itertools import product | ||||
| from functools import reduce | ||||
|  | ||||
| from mongoengine.errors import InvalidQueryError | ||||
| from mongoengine.queryset import transform | ||||
|  | ||||
| __all__ = ('Q',) | ||||
|   | ||||