Merge branch 'master' into pr/592
This commit is contained in:
		| @@ -1,12 +1,13 @@ | ||||
| import weakref | ||||
| import functools | ||||
| import itertools | ||||
| from mongoengine.common import _import_class | ||||
|  | ||||
| __all__ = ("BaseDict", "BaseList") | ||||
|  | ||||
|  | ||||
| class BaseDict(dict): | ||||
|     """A special dict so we can watch any changes | ||||
|     """ | ||||
|     """A special dict so we can watch any changes""" | ||||
|  | ||||
|     _dereferenced = False | ||||
|     _instance = None | ||||
| @@ -21,29 +22,37 @@ class BaseDict(dict): | ||||
|         self._name = name | ||||
|         return super(BaseDict, self).__init__(dict_items) | ||||
|  | ||||
|     def __getitem__(self, *args, **kwargs): | ||||
|         value = super(BaseDict, self).__getitem__(*args, **kwargs) | ||||
|     def __getitem__(self, key, *args, **kwargs): | ||||
|         value = super(BaseDict, self).__getitem__(key) | ||||
|  | ||||
|         EmbeddedDocument = _import_class('EmbeddedDocument') | ||||
|         if isinstance(value, EmbeddedDocument) and value._instance is None: | ||||
|             value._instance = self._instance | ||||
|         elif not isinstance(value, BaseDict) and isinstance(value, dict): | ||||
|             value = BaseDict(value, None, '%s.%s' % (self._name, key)) | ||||
|             super(BaseDict, self).__setitem__(key, value) | ||||
|             value._instance = self._instance | ||||
|         elif not isinstance(value, BaseList) and isinstance(value, list): | ||||
|             value = BaseList(value, None, '%s.%s' % (self._name, key)) | ||||
|             super(BaseDict, self).__setitem__(key, value) | ||||
|             value._instance = self._instance | ||||
|         return value | ||||
|  | ||||
|     def __setitem__(self, *args, **kwargs): | ||||
|         self._mark_as_changed() | ||||
|         return super(BaseDict, self).__setitem__(*args, **kwargs) | ||||
|     def __setitem__(self, key, value, *args, **kwargs): | ||||
|         self._mark_as_changed(key) | ||||
|         return super(BaseDict, self).__setitem__(key, value) | ||||
|  | ||||
|     def __delete__(self, *args, **kwargs): | ||||
|         self._mark_as_changed() | ||||
|         return super(BaseDict, self).__delete__(*args, **kwargs) | ||||
|  | ||||
|     def __delitem__(self, *args, **kwargs): | ||||
|         self._mark_as_changed() | ||||
|         return super(BaseDict, self).__delitem__(*args, **kwargs) | ||||
|     def __delitem__(self, key, *args, **kwargs): | ||||
|         self._mark_as_changed(key) | ||||
|         return super(BaseDict, self).__delitem__(key) | ||||
|  | ||||
|     def __delattr__(self, *args, **kwargs): | ||||
|         self._mark_as_changed() | ||||
|         return super(BaseDict, self).__delattr__(*args, **kwargs) | ||||
|     def __delattr__(self, key, *args, **kwargs): | ||||
|         self._mark_as_changed(key) | ||||
|         return super(BaseDict, self).__delattr__(key) | ||||
|  | ||||
|     def __getstate__(self): | ||||
|         self.instance = None | ||||
| @@ -70,9 +79,12 @@ class BaseDict(dict): | ||||
|         self._mark_as_changed() | ||||
|         return super(BaseDict, self).update(*args, **kwargs) | ||||
|  | ||||
|     def _mark_as_changed(self): | ||||
|     def _mark_as_changed(self, key=None): | ||||
|         if hasattr(self._instance, '_mark_as_changed'): | ||||
|             self._instance._mark_as_changed(self._name) | ||||
|             if key: | ||||
|                 self._instance._mark_as_changed('%s.%s' % (self._name, key)) | ||||
|             else: | ||||
|                 self._instance._mark_as_changed(self._name) | ||||
|  | ||||
|  | ||||
| class BaseList(list): | ||||
| @@ -92,21 +104,35 @@ class BaseList(list): | ||||
|         self._name = name | ||||
|         return super(BaseList, self).__init__(list_items) | ||||
|  | ||||
|     def __getitem__(self, *args, **kwargs): | ||||
|         value = super(BaseList, self).__getitem__(*args, **kwargs) | ||||
|     def __getitem__(self, key, *args, **kwargs): | ||||
|         value = super(BaseList, self).__getitem__(key) | ||||
|  | ||||
|         EmbeddedDocument = _import_class('EmbeddedDocument') | ||||
|         if isinstance(value, EmbeddedDocument) and value._instance is None: | ||||
|             value._instance = self._instance | ||||
|         elif not isinstance(value, BaseDict) and isinstance(value, dict): | ||||
|             value = BaseDict(value, None, '%s.%s' % (self._name, key)) | ||||
|             super(BaseList, self).__setitem__(key, value) | ||||
|             value._instance = self._instance | ||||
|         elif not isinstance(value, BaseList) and isinstance(value, list): | ||||
|             value = BaseList(value, None, '%s.%s' % (self._name, key)) | ||||
|             super(BaseList, self).__setitem__(key, value) | ||||
|             value._instance = self._instance | ||||
|         return value | ||||
|  | ||||
|     def __setitem__(self, *args, **kwargs): | ||||
|         self._mark_as_changed() | ||||
|         return super(BaseList, self).__setitem__(*args, **kwargs) | ||||
|     def __setitem__(self, key, value, *args, **kwargs): | ||||
|         if isinstance(key, slice): | ||||
|             self._mark_as_changed() | ||||
|         else: | ||||
|             self._mark_as_changed(key) | ||||
|         return super(BaseList, self).__setitem__(key, value) | ||||
|  | ||||
|     def __delitem__(self, *args, **kwargs): | ||||
|         self._mark_as_changed() | ||||
|         return super(BaseList, self).__delitem__(*args, **kwargs) | ||||
|     def __delitem__(self, key, *args, **kwargs): | ||||
|         if isinstance(key, slice): | ||||
|             self._mark_as_changed() | ||||
|         else: | ||||
|             self._mark_as_changed(key) | ||||
|         return super(BaseList, self).__delitem__(key) | ||||
|  | ||||
|     def __setslice__(self, *args, **kwargs): | ||||
|         self._mark_as_changed() | ||||
| @@ -153,6 +179,103 @@ class BaseList(list): | ||||
|         self._mark_as_changed() | ||||
|         return super(BaseList, self).sort(*args, **kwargs) | ||||
|  | ||||
|     def _mark_as_changed(self): | ||||
|     def _mark_as_changed(self, key=None): | ||||
|         if hasattr(self._instance, '_mark_as_changed'): | ||||
|             self._instance._mark_as_changed(self._name) | ||||
|             if key: | ||||
|                 self._instance._mark_as_changed('%s.%s' % (self._name, key)) | ||||
|             else: | ||||
|                 self._instance._mark_as_changed(self._name) | ||||
|  | ||||
|  | ||||
| class StrictDict(object): | ||||
|     __slots__ = () | ||||
|     _special_fields = set(['get', 'pop', 'iteritems', 'items', 'keys', 'create']) | ||||
|     _classes = {} | ||||
|     def __init__(self, **kwargs): | ||||
|         for k,v in kwargs.iteritems(): | ||||
|             setattr(self, k, v) | ||||
|     def __getitem__(self, key): | ||||
|         key = '_reserved_' + key if key in self._special_fields else key | ||||
|         try: | ||||
|             return getattr(self, key) | ||||
|         except AttributeError: | ||||
|             raise KeyError(key) | ||||
|     def __setitem__(self, key, value): | ||||
|         key = '_reserved_' + key if key in self._special_fields else key | ||||
|         return setattr(self, key, value) | ||||
|     def __contains__(self, key): | ||||
|         return hasattr(self, key) | ||||
|     def get(self, key, default=None): | ||||
|         try: | ||||
|             return self[key] | ||||
|         except KeyError: | ||||
|             return default | ||||
|     def pop(self, key, default=None): | ||||
|         v = self.get(key, default) | ||||
|         try: | ||||
|             delattr(self, key) | ||||
|         except AttributeError: | ||||
|             pass | ||||
|         return v | ||||
|     def iteritems(self): | ||||
|         for key in self: | ||||
|             yield key, self[key] | ||||
|     def items(self): | ||||
|         return [(k, self[k]) for k in iter(self)] | ||||
|     def keys(self): | ||||
|         return list(iter(self)) | ||||
|     def __iter__(self): | ||||
|         return (key for key in self.__slots__ if hasattr(self, key)) | ||||
|     def __len__(self): | ||||
|         return len(list(self.iteritems())) | ||||
|     def __eq__(self, other): | ||||
|         return self.items() == other.items() | ||||
|     def __neq__(self, other): | ||||
|         return self.items() != other.items() | ||||
|  | ||||
|     @classmethod | ||||
|     def create(cls, allowed_keys): | ||||
|         allowed_keys_tuple = tuple(('_reserved_' + k if k in cls._special_fields else k) for k in allowed_keys) | ||||
|         allowed_keys = frozenset(allowed_keys_tuple) | ||||
|         if allowed_keys not in cls._classes: | ||||
|             class SpecificStrictDict(cls): | ||||
|                 __slots__ = allowed_keys_tuple | ||||
|             cls._classes[allowed_keys] = SpecificStrictDict | ||||
|         return cls._classes[allowed_keys] | ||||
|  | ||||
|  | ||||
| class SemiStrictDict(StrictDict): | ||||
|     __slots__ = ('_extras') | ||||
|     _classes = {} | ||||
|     def __getattr__(self, attr): | ||||
|         try: | ||||
|             super(SemiStrictDict, self).__getattr__(attr) | ||||
|         except AttributeError: | ||||
|             try: | ||||
|                 return self.__getattribute__('_extras')[attr] | ||||
|             except KeyError as e: | ||||
|                 raise AttributeError(e) | ||||
|     def __setattr__(self, attr, value): | ||||
|         try: | ||||
|             super(SemiStrictDict, self).__setattr__(attr, value) | ||||
|         except AttributeError: | ||||
|             try: | ||||
|                 self._extras[attr] = value | ||||
|             except AttributeError: | ||||
|                 self._extras = {attr: value} | ||||
|  | ||||
|     def __delattr__(self, attr): | ||||
|         try: | ||||
|             super(SemiStrictDict, self).__delattr__(attr) | ||||
|         except AttributeError: | ||||
|             try: | ||||
|                 del self._extras[attr] | ||||
|             except KeyError as e: | ||||
|                 raise AttributeError(e) | ||||
|  | ||||
|     def __iter__(self): | ||||
|         try: | ||||
|             extras_iter = iter(self.__getattribute__('_extras')) | ||||
|         except AttributeError: | ||||
|             extras_iter = () | ||||
|         return itertools.chain(super(SemiStrictDict, self).__iter__(), extras_iter) | ||||
|   | ||||
| @@ -13,24 +13,23 @@ from mongoengine import signals | ||||
| from mongoengine.common import _import_class | ||||
| from mongoengine.errors import (ValidationError, InvalidDocumentError, | ||||
|                                 LookUpError) | ||||
| from mongoengine.python_support import (PY3, UNICODE_KWARGS, txt_type, | ||||
|                                         to_str_keys_recursive) | ||||
| from mongoengine.python_support import PY3, txt_type | ||||
|  | ||||
| from mongoengine.base.common import get_document, ALLOW_INHERITANCE | ||||
| from mongoengine.base.datastructures import BaseDict, BaseList | ||||
| from mongoengine.base.datastructures import BaseDict, BaseList, StrictDict, SemiStrictDict | ||||
| from mongoengine.base.fields import ComplexBaseField | ||||
|  | ||||
| __all__ = ('BaseDocument', 'NON_FIELD_ERRORS') | ||||
|  | ||||
| NON_FIELD_ERRORS = '__all__' | ||||
|  | ||||
|  | ||||
| class BaseDocument(object): | ||||
|     __slots__ = ('_changed_fields', '_initialised', '_created', '_data', | ||||
|                   '_dynamic_fields', '_auto_id_field', '_db_field_map', '_cls', '__weakref__') | ||||
|  | ||||
|     _dynamic = False | ||||
|     _created = True | ||||
|     _dynamic_lock = True | ||||
|     _initialised = False | ||||
|     STRICT = False | ||||
|  | ||||
|     def __init__(self, *args, **values): | ||||
|         """ | ||||
| @@ -39,6 +38,8 @@ class BaseDocument(object): | ||||
|         :param __auto_convert: Try and will cast python objects to Object types | ||||
|         :param values: A dictionary of values for the document | ||||
|         """ | ||||
|         self._initialised = False | ||||
|         self._created = True | ||||
|         if args: | ||||
|             # Combine positional arguments with named arguments. | ||||
|             # We only want named arguments. | ||||
| @@ -54,7 +55,11 @@ class BaseDocument(object): | ||||
|         __auto_convert = values.pop("__auto_convert", True) | ||||
|         signals.pre_init.send(self.__class__, document=self, values=values) | ||||
|  | ||||
|         self._data = {} | ||||
|         if self.STRICT and not self._dynamic: | ||||
|             self._data = StrictDict.create(allowed_keys=self._fields.keys())() | ||||
|         else: | ||||
|             self._data = SemiStrictDict.create(allowed_keys=self._fields.keys())() | ||||
|  | ||||
|         self._dynamic_fields = SON() | ||||
|  | ||||
|         # Assign default values to instance | ||||
| @@ -130,17 +135,25 @@ class BaseDocument(object): | ||||
|                 self._data[name] = value | ||||
|                 if hasattr(self, '_changed_fields'): | ||||
|                     self._mark_as_changed(name) | ||||
|         try: | ||||
|             self__created = self._created | ||||
|         except AttributeError: | ||||
|             self__created = True | ||||
|  | ||||
|         if (self._is_document and not self._created and | ||||
|         if (self._is_document and not self__created and | ||||
|            name in self._meta.get('shard_key', tuple()) and | ||||
|            self._data.get(name) != value): | ||||
|             OperationError = _import_class('OperationError') | ||||
|             msg = "Shard Keys are immutable. Tried to update %s" % name | ||||
|             raise OperationError(msg) | ||||
|  | ||||
|         try: | ||||
|             self__initialised = self._initialised | ||||
|         except AttributeError: | ||||
|             self__initialised = False | ||||
|         # Check if the user has created a new instance of a class | ||||
|         if (self._is_document and self._initialised | ||||
|            and self._created and name == self._meta['id_field']): | ||||
|         if (self._is_document and self__initialised | ||||
|            and self__created and name == self._meta['id_field']): | ||||
|                 super(BaseDocument, self).__setattr__('_created', False) | ||||
|  | ||||
|         super(BaseDocument, self).__setattr__(name, value) | ||||
| @@ -158,9 +171,11 @@ class BaseDocument(object): | ||||
|         if isinstance(data["_data"], SON): | ||||
|             data["_data"] = self.__class__._from_son(data["_data"])._data | ||||
|         for k in ('_changed_fields', '_initialised', '_created', '_data', | ||||
|                   '_fields_ordered', '_dynamic_fields'): | ||||
|                    '_dynamic_fields'): | ||||
|             if k in data: | ||||
|                 setattr(self, k, data[k]) | ||||
|         if '_fields_ordered' in data: | ||||
|             setattr(type(self), '_fields_ordered', data['_fields_ordered']) | ||||
|         dynamic_fields = data.get('_dynamic_fields') or SON() | ||||
|         for k in dynamic_fields.keys(): | ||||
|             setattr(self, k, data["_data"].get(k)) | ||||
| @@ -182,7 +197,7 @@ class BaseDocument(object): | ||||
|         """Dictionary-style field access, set a field's value. | ||||
|         """ | ||||
|         # Ensure that the field exists before settings its value | ||||
|         if name not in self._fields: | ||||
|         if not self._dynamic and name not in self._fields: | ||||
|             raise KeyError(name) | ||||
|         return setattr(self, name, value) | ||||
|  | ||||
| @@ -214,8 +229,9 @@ class BaseDocument(object): | ||||
|  | ||||
|     def __eq__(self, other): | ||||
|         if isinstance(other, self.__class__) and hasattr(other, 'id'): | ||||
|             if self.id == other.id: | ||||
|                 return True | ||||
|             return self.id == other.id | ||||
|         if isinstance(other, DBRef): | ||||
|             return self._get_collection_name() == other.collection and self.id == other.id | ||||
|         return False | ||||
|  | ||||
|     def __ne__(self, other): | ||||
| @@ -317,7 +333,7 @@ class BaseDocument(object): | ||||
|             pk = "None" | ||||
|             if hasattr(self, 'pk'): | ||||
|                 pk = self.pk | ||||
|             elif self._instance: | ||||
|             elif self._instance and hasattr(self._instance, 'pk'): | ||||
|                 pk = self._instance.pk | ||||
|             message = "ValidationError (%s:%s) " % (self._class_name, pk) | ||||
|             raise ValidationError(message, errors=errors) | ||||
| @@ -370,9 +386,18 @@ class BaseDocument(object): | ||||
|         """ | ||||
|         if not key: | ||||
|             return | ||||
|         key = self._db_field_map.get(key, key) | ||||
|         if (hasattr(self, '_changed_fields') and | ||||
|            key not in self._changed_fields): | ||||
|  | ||||
|         if not hasattr(self, '_changed_fields'): | ||||
|             return | ||||
|  | ||||
|         if '.' in key: | ||||
|             key, rest = key.split('.', 1) | ||||
|             key = self._db_field_map.get(key, key) | ||||
|             key = '%s.%s' % (key, rest) | ||||
|         else: | ||||
|             key = self._db_field_map.get(key, key) | ||||
|  | ||||
|         if key not in self._changed_fields: | ||||
|             self._changed_fields.append(key) | ||||
|  | ||||
|     def _clear_changed_fields(self): | ||||
| @@ -392,6 +417,8 @@ class BaseDocument(object): | ||||
|                 else: | ||||
|                     data = getattr(data, part, None) | ||||
|                 if hasattr(data, "_changed_fields"): | ||||
|                     if hasattr(data, "_is_document") and data._is_document: | ||||
|                         continue | ||||
|                     data._changed_fields = [] | ||||
|         self._changed_fields = [] | ||||
|  | ||||
| @@ -405,6 +432,10 @@ class BaseDocument(object): | ||||
|  | ||||
|         for index, value in iterator: | ||||
|             list_key = "%s%s." % (key, index) | ||||
|             # don't check anything lower if this key is already marked | ||||
|             # as changed. | ||||
|             if list_key[:-1] in changed_fields: | ||||
|                 continue | ||||
|             if hasattr(value, '_get_changed_fields'): | ||||
|                 changed = value._get_changed_fields(inspected) | ||||
|                 changed_fields += ["%s%s" % (list_key, k) | ||||
| @@ -420,6 +451,7 @@ class BaseDocument(object): | ||||
|         ReferenceField = _import_class("ReferenceField") | ||||
|         changed_fields = [] | ||||
|         changed_fields += getattr(self, '_changed_fields', []) | ||||
|  | ||||
|         inspected = inspected or set() | ||||
|         if hasattr(self, 'id') and isinstance(self.id, Hashable): | ||||
|             if self.id in inspected: | ||||
| @@ -472,7 +504,10 @@ class BaseDocument(object): | ||||
|                     if isinstance(d, (ObjectId, DBRef)): | ||||
|                         break | ||||
|                     elif isinstance(d, list) and p.isdigit(): | ||||
|                         d = d[int(p)] | ||||
|                         try: | ||||
|                             d = d[int(p)] | ||||
|                         except IndexError: | ||||
|                             d = None | ||||
|                     elif hasattr(d, 'get'): | ||||
|                         d = d.get(p) | ||||
|                     new_path.append(p) | ||||
| @@ -545,10 +580,6 @@ class BaseDocument(object): | ||||
|         # class if unavailable | ||||
|         class_name = son.get('_cls', cls._class_name) | ||||
|         data = dict(("%s" % key, value) for key, value in son.iteritems()) | ||||
|         if not UNICODE_KWARGS: | ||||
|             # python 2.6.4 and lower cannot handle unicode keys | ||||
|             # passed to class constructor example: cls(**data) | ||||
|             to_str_keys_recursive(data) | ||||
|  | ||||
|         # Return correct subclass for document type | ||||
|         if class_name != cls._class_name: | ||||
| @@ -586,6 +617,8 @@ class BaseDocument(object): | ||||
|                    % (cls._class_name, errors)) | ||||
|             raise InvalidDocumentError(msg) | ||||
|  | ||||
|         if cls.STRICT: | ||||
|             data = dict((k, v) for k,v in data.iteritems() if k in cls._fields) | ||||
|         obj = cls(__auto_convert=False, **data) | ||||
|         obj._changed_fields = changed_fields | ||||
|         obj._created = False | ||||
| @@ -804,8 +837,17 @@ class BaseDocument(object): | ||||
|                    # Look up subfield on the previous field | ||||
|                     new_field = field.lookup_member(field_name) | ||||
|                 if not new_field and isinstance(field, ComplexBaseField): | ||||
|                     fields.append(field_name) | ||||
|                     continue | ||||
|                     if hasattr(field.field, 'document_type') and cls._dynamic \ | ||||
|                             and field.field.document_type._dynamic: | ||||
|                         DynamicField = _import_class('DynamicField') | ||||
|                         new_field = DynamicField(db_field=field_name) | ||||
|                     else: | ||||
|                         fields.append(field_name) | ||||
|                         continue | ||||
|                 elif not new_field and hasattr(field, 'document_type') and cls._dynamic \ | ||||
|                         and field.document_type._dynamic: | ||||
|                     DynamicField = _import_class('DynamicField') | ||||
|                     new_field = DynamicField(db_field=field_name) | ||||
|                 elif not new_field: | ||||
|                     raise LookUpError('Cannot resolve field "%s"' | ||||
|                                       % field_name) | ||||
| @@ -825,7 +867,11 @@ class BaseDocument(object): | ||||
|         """Dynamically set the display value for a field with choices""" | ||||
|         for attr_name, field in self._fields.items(): | ||||
|             if field.choices: | ||||
|                 setattr(self, | ||||
|                 if self._dynamic: | ||||
|                     obj = self | ||||
|                 else: | ||||
|                     obj = type(self) | ||||
|                 setattr(obj, | ||||
|                         'get_%s_display' % attr_name, | ||||
|                         partial(self.__get_field_display, field=field)) | ||||
|  | ||||
|   | ||||
| @@ -359,7 +359,8 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): | ||||
|                     new_class.id = field | ||||
|  | ||||
|         # Set primary key if not defined by the document | ||||
|         new_class._auto_id_field = False | ||||
|         new_class._auto_id_field = getattr(parent_doc_cls, | ||||
|                                            '_auto_id_field', False) | ||||
|         if not new_class._meta.get('id_field'): | ||||
|             new_class._auto_id_field = True | ||||
|             new_class._meta['id_field'] = 'id' | ||||
|   | ||||
| @@ -20,7 +20,8 @@ _dbs = {} | ||||
|  | ||||
| def register_connection(alias, name, host=None, port=None, | ||||
|                         is_slave=False, read_preference=False, slaves=None, | ||||
|                         username=None, password=None, **kwargs): | ||||
|                         username=None, password=None, authentication_source=None, | ||||
|                         **kwargs): | ||||
|     """Add a connection. | ||||
|  | ||||
|     :param alias: the name that will be used to refer to this connection | ||||
| @@ -36,6 +37,7 @@ def register_connection(alias, name, host=None, port=None, | ||||
|         be a registered connection that has :attr:`is_slave` set to ``True`` | ||||
|     :param username: username to authenticate with | ||||
|     :param password: password to authenticate with | ||||
|     :param authentication_source: database to authenticate against | ||||
|     :param kwargs: allow ad-hoc parameters to be passed into the pymongo driver | ||||
|  | ||||
|     """ | ||||
| @@ -46,10 +48,11 @@ def register_connection(alias, name, host=None, port=None, | ||||
|         'host': host or 'localhost', | ||||
|         'port': port or 27017, | ||||
|         'is_slave': is_slave, | ||||
|         'read_preference': read_preference, | ||||
|         'slaves': slaves or [], | ||||
|         'username': username, | ||||
|         'password': password, | ||||
|         'read_preference': read_preference | ||||
|         'authentication_source': authentication_source | ||||
|     } | ||||
|  | ||||
|     # Handle uri style connections | ||||
| @@ -93,20 +96,12 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): | ||||
|             raise ConnectionError(msg) | ||||
|         conn_settings = _connection_settings[alias].copy() | ||||
|  | ||||
|         if hasattr(pymongo, 'version_tuple'):  # Support for 2.1+ | ||||
|             conn_settings.pop('name', None) | ||||
|             conn_settings.pop('slaves', None) | ||||
|             conn_settings.pop('is_slave', None) | ||||
|             conn_settings.pop('username', None) | ||||
|             conn_settings.pop('password', None) | ||||
|         else: | ||||
|             # Get all the slave connections | ||||
|             if 'slaves' in conn_settings: | ||||
|                 slaves = [] | ||||
|                 for slave_alias in conn_settings['slaves']: | ||||
|                     slaves.append(get_connection(slave_alias)) | ||||
|                 conn_settings['slaves'] = slaves | ||||
|                 conn_settings.pop('read_preference', None) | ||||
|         conn_settings.pop('name', None) | ||||
|         conn_settings.pop('slaves', None) | ||||
|         conn_settings.pop('is_slave', None) | ||||
|         conn_settings.pop('username', None) | ||||
|         conn_settings.pop('password', None) | ||||
|         conn_settings.pop('authentication_source', None) | ||||
|  | ||||
|         connection_class = MongoClient | ||||
|         if 'replicaSet' in conn_settings: | ||||
| @@ -119,7 +114,19 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): | ||||
|             connection_class = MongoReplicaSetClient | ||||
|  | ||||
|         try: | ||||
|             _connections[alias] = connection_class(**conn_settings) | ||||
|             connection = None | ||||
|             connection_settings_iterator = ((alias, settings.copy()) for alias, settings in _connection_settings.iteritems()) | ||||
|             for alias, connection_settings in connection_settings_iterator: | ||||
|                 connection_settings.pop('name', None) | ||||
|                 connection_settings.pop('slaves', None) | ||||
|                 connection_settings.pop('is_slave', None) | ||||
|                 connection_settings.pop('username', None) | ||||
|                 connection_settings.pop('password', None) | ||||
|                 if conn_settings == connection_settings and _connections.get(alias, None): | ||||
|                     connection = _connections[alias] | ||||
|                     break | ||||
|  | ||||
|             _connections[alias] = connection if connection else connection_class(**conn_settings) | ||||
|         except Exception, e: | ||||
|             raise ConnectionError("Cannot connect to database %s :\n%s" % (alias, e)) | ||||
|     return _connections[alias] | ||||
| @@ -137,7 +144,8 @@ def get_db(alias=DEFAULT_CONNECTION_NAME, reconnect=False): | ||||
|         # Authenticate if necessary | ||||
|         if conn_settings['username'] and conn_settings['password']: | ||||
|             db.authenticate(conn_settings['username'], | ||||
|                             conn_settings['password']) | ||||
|                             conn_settings['password'], | ||||
|                             source=conn_settings['authentication_source']) | ||||
|         _dbs[alias] = db | ||||
|     return _dbs[alias] | ||||
|  | ||||
|   | ||||
| @@ -1,6 +1,5 @@ | ||||
| from mongoengine.common import _import_class | ||||
| from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db | ||||
| from mongoengine.queryset import QuerySet | ||||
|  | ||||
|  | ||||
| __all__ = ("switch_db", "switch_collection", "no_dereference", | ||||
| @@ -162,12 +161,6 @@ class no_sub_classes(object): | ||||
|         return self.cls | ||||
|  | ||||
|  | ||||
| class QuerySetNoDeRef(QuerySet): | ||||
|     """Special no_dereference QuerySet""" | ||||
|     def __dereference(items, max_depth=1, instance=None, name=None): | ||||
|             return items | ||||
|  | ||||
|  | ||||
| class query_counter(object): | ||||
|     """ Query_counter context manager to get the number of queries. """ | ||||
|  | ||||
|   | ||||
| @@ -36,7 +36,7 @@ class DeReference(object): | ||||
|         if instance and isinstance(instance, (Document, EmbeddedDocument, | ||||
|                                               TopLevelDocumentMetaclass)): | ||||
|             doc_type = instance._fields.get(name) | ||||
|             if hasattr(doc_type, 'field'): | ||||
|             while hasattr(doc_type, 'field'): | ||||
|                 doc_type = doc_type.field | ||||
|  | ||||
|             if isinstance(doc_type, ReferenceField): | ||||
| @@ -51,9 +51,19 @@ class DeReference(object): | ||||
|                     return items | ||||
|                 elif not field.dbref: | ||||
|                     if not hasattr(items, 'items'): | ||||
|                         items = [field.to_python(v) | ||||
|                              if not isinstance(v, (DBRef, Document)) else v | ||||
|                              for v in items] | ||||
|  | ||||
|                         def _get_items(items): | ||||
|                             new_items = [] | ||||
|                             for v in items: | ||||
|                                 if isinstance(v, list): | ||||
|                                     new_items.append(_get_items(v)) | ||||
|                                 elif not isinstance(v, (DBRef, Document)): | ||||
|                                     new_items.append(field.to_python(v)) | ||||
|                                 else: | ||||
|                                     new_items.append(v) | ||||
|                             return new_items | ||||
|  | ||||
|                         items = _get_items(items) | ||||
|                     else: | ||||
|                         items = dict([ | ||||
|                             (k, field.to_python(v)) | ||||
| @@ -114,11 +124,11 @@ class DeReference(object): | ||||
|         """Fetch all references and convert to their document objects | ||||
|         """ | ||||
|         object_map = {} | ||||
|         for col, dbrefs in self.reference_map.iteritems(): | ||||
|         for collection, dbrefs in self.reference_map.iteritems(): | ||||
|             keys = object_map.keys() | ||||
|             refs = list(set([dbref for dbref in dbrefs if unicode(dbref).encode('utf-8') not in keys])) | ||||
|             if hasattr(col, 'objects'):  # We have a document class for the refs | ||||
|                 references = col.objects.in_bulk(refs) | ||||
|             if hasattr(collection, 'objects'):  # We have a document class for the refs | ||||
|                 references = collection.objects.in_bulk(refs) | ||||
|                 for key, doc in references.iteritems(): | ||||
|                     object_map[key] = doc | ||||
|             else:  # Generic reference: use the refs data to convert to document | ||||
| @@ -126,19 +136,19 @@ class DeReference(object): | ||||
|                     continue | ||||
|  | ||||
|                 if doc_type: | ||||
|                     references = doc_type._get_db()[col].find({'_id': {'$in': refs}}) | ||||
|                     references = doc_type._get_db()[collection].find({'_id': {'$in': refs}}) | ||||
|                     for ref in references: | ||||
|                         doc = doc_type._from_son(ref) | ||||
|                         object_map[doc.id] = doc | ||||
|                 else: | ||||
|                     references = get_db()[col].find({'_id': {'$in': refs}}) | ||||
|                     references = get_db()[collection].find({'_id': {'$in': refs}}) | ||||
|                     for ref in references: | ||||
|                         if '_cls' in ref: | ||||
|                             doc = get_document(ref["_cls"])._from_son(ref) | ||||
|                         elif doc_type is None: | ||||
|                             doc = get_document( | ||||
|                                 ''.join(x.capitalize() | ||||
|                                     for x in col.split('_')))._from_son(ref) | ||||
|                                     for x in collection.split('_')))._from_son(ref) | ||||
|                         else: | ||||
|                             doc = doc_type._from_son(ref) | ||||
|                         object_map[doc.id] = doc | ||||
| @@ -204,7 +214,8 @@ class DeReference(object): | ||||
|                     elif isinstance(v, (list, tuple)) and depth <= self.max_depth: | ||||
|                         data[k]._data[field_name] = self._attach_objects(v, depth, instance=instance, name=name) | ||||
|             elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth: | ||||
|                 data[k] = self._attach_objects(v, depth - 1, instance=instance, name=name) | ||||
|                 item_name = '%s.%s' % (name, k) if name else name | ||||
|                 data[k] = self._attach_objects(v, depth - 1, instance=instance, name=item_name) | ||||
|             elif hasattr(v, 'id'): | ||||
|                 data[k] = self.object_map.get(v.id, v) | ||||
|  | ||||
|   | ||||
| @@ -13,7 +13,8 @@ from mongoengine.base import (DocumentMetaclass, TopLevelDocumentMetaclass, | ||||
|                               BaseDocument, BaseDict, BaseList, | ||||
|                               ALLOW_INHERITANCE, get_document) | ||||
| from mongoengine.errors import ValidationError | ||||
| from mongoengine.queryset import OperationError, NotUniqueError, QuerySet | ||||
| from mongoengine.queryset import (OperationError, NotUniqueError, | ||||
|                                   QuerySet, transform) | ||||
| from mongoengine.connection import get_db, DEFAULT_CONNECTION_NAME | ||||
| from mongoengine.context_managers import switch_db, switch_collection | ||||
|  | ||||
| @@ -54,20 +55,21 @@ class EmbeddedDocument(BaseDocument): | ||||
|     dictionary. | ||||
|     """ | ||||
|  | ||||
|     __slots__ = ('_instance') | ||||
|  | ||||
|     # The __metaclass__ attribute is removed by 2to3 when running with Python3 | ||||
|     # my_metaclass is defined so that metaclass can be queried in Python 2 & 3 | ||||
|     my_metaclass  = DocumentMetaclass | ||||
|     __metaclass__ = DocumentMetaclass | ||||
|  | ||||
|     _instance = None | ||||
|  | ||||
|     def __init__(self, *args, **kwargs): | ||||
|         super(EmbeddedDocument, self).__init__(*args, **kwargs) | ||||
|         self._instance = None | ||||
|         self._changed_fields = [] | ||||
|  | ||||
|     def __eq__(self, other): | ||||
|         if isinstance(other, self.__class__): | ||||
|             return self.to_mongo() == other.to_mongo() | ||||
|             return self._data == other._data | ||||
|         return False | ||||
|  | ||||
|     def __ne__(self, other): | ||||
| @@ -125,6 +127,8 @@ class Document(BaseDocument): | ||||
|     my_metaclass  = TopLevelDocumentMetaclass | ||||
|     __metaclass__ = TopLevelDocumentMetaclass | ||||
|  | ||||
|     __slots__ = ('__objects' ) | ||||
|  | ||||
|     def pk(): | ||||
|         """Primary key alias | ||||
|         """ | ||||
| @@ -180,7 +184,7 @@ class Document(BaseDocument): | ||||
|  | ||||
|     def save(self, force_insert=False, validate=True, clean=True, | ||||
|              write_concern=None,  cascade=None, cascade_kwargs=None, | ||||
|              _refs=None, **kwargs): | ||||
|              _refs=None, save_condition=None, **kwargs): | ||||
|         """Save the :class:`~mongoengine.Document` to the database. If the | ||||
|         document already exists, it will be updated, otherwise it will be | ||||
|         created. | ||||
| @@ -203,7 +207,8 @@ class Document(BaseDocument): | ||||
|         :param cascade_kwargs: (optional) kwargs dictionary to be passed throw | ||||
|             to cascading saves.  Implies ``cascade=True``. | ||||
|         :param _refs: A list of processed references used in cascading saves | ||||
|  | ||||
|         :param save_condition: only perform save if matching record in db | ||||
|             satisfies condition(s) (e.g., version number) | ||||
|         .. versionchanged:: 0.5 | ||||
|             In existing documents it only saves changed fields using | ||||
|             set / unset.  Saves are cascaded and any | ||||
| @@ -217,6 +222,9 @@ class Document(BaseDocument): | ||||
|             meta['cascade'] = True.  Also you can pass different kwargs to | ||||
|             the cascade save using cascade_kwargs which overwrites the | ||||
|             existing kwargs with custom values. | ||||
|         .. versionchanged:: 0.8.5 | ||||
|             Optional save_condition that only overwrites existing documents | ||||
|             if the condition is satisfied in the current db record. | ||||
|         """ | ||||
|         signals.pre_save.send(self.__class__, document=self) | ||||
|  | ||||
| @@ -230,7 +238,8 @@ class Document(BaseDocument): | ||||
|  | ||||
|         created = ('_id' not in doc or self._created or force_insert) | ||||
|  | ||||
|         signals.pre_save_post_validation.send(self.__class__, document=self, created=created) | ||||
|         signals.pre_save_post_validation.send(self.__class__, document=self, | ||||
|                                               created=created) | ||||
|  | ||||
|         try: | ||||
|             collection = self._get_collection() | ||||
| @@ -243,7 +252,12 @@ class Document(BaseDocument): | ||||
|                 object_id = doc['_id'] | ||||
|                 updates, removals = self._delta() | ||||
|                 # Need to add shard key to query, or you get an error | ||||
|                 select_dict = {'_id': object_id} | ||||
|                 if save_condition is not None: | ||||
|                     select_dict = transform.query(self.__class__, | ||||
|                                                   **save_condition) | ||||
|                 else: | ||||
|                     select_dict = {} | ||||
|                 select_dict['_id'] = object_id | ||||
|                 shard_key = self.__class__._meta.get('shard_key', tuple()) | ||||
|                 for k in shard_key: | ||||
|                     actual_key = self._db_field_map.get(k, k) | ||||
| @@ -263,10 +277,12 @@ class Document(BaseDocument): | ||||
|                 if removals: | ||||
|                     update_query["$unset"] = removals | ||||
|                 if updates or removals: | ||||
|                     upsert = save_condition is None | ||||
|                     last_error = collection.update(select_dict, update_query, | ||||
|                                                    upsert=True, **write_concern) | ||||
|                                                    upsert=upsert, **write_concern) | ||||
|                     created = is_new_object(last_error) | ||||
|  | ||||
|  | ||||
|             if cascade is None: | ||||
|                 cascade = self._meta.get('cascade', False) or cascade_kwargs is not None | ||||
|  | ||||
| @@ -293,12 +309,12 @@ class Document(BaseDocument): | ||||
|                 raise NotUniqueError(message % unicode(err)) | ||||
|             raise OperationError(message % unicode(err)) | ||||
|         id_field = self._meta['id_field'] | ||||
|         if id_field not in self._meta.get('shard_key', []): | ||||
|         if created or id_field not in self._meta.get('shard_key', []): | ||||
|             self[id_field] = self._fields[id_field].to_python(object_id) | ||||
|  | ||||
|         signals.post_save.send(self.__class__, document=self, created=created) | ||||
|         self._clear_changed_fields() | ||||
|         self._created = False | ||||
|         signals.post_save.send(self.__class__, document=self, created=created) | ||||
|         return self | ||||
|  | ||||
|     def cascade_save(self, *args, **kwargs): | ||||
| @@ -447,27 +463,41 @@ class Document(BaseDocument): | ||||
|         DeReference()([self], max_depth + 1) | ||||
|         return self | ||||
|  | ||||
|     def reload(self, max_depth=1): | ||||
|     def reload(self, *fields, **kwargs): | ||||
|         """Reloads all attributes from the database. | ||||
|  | ||||
|         :param fields: (optional) args list of fields to reload | ||||
|         :param max_depth: (optional) depth of dereferencing to follow | ||||
|  | ||||
|         .. versionadded:: 0.1.2 | ||||
|         .. versionchanged:: 0.6  Now chainable | ||||
|         .. versionchanged:: 0.9  Can provide specific fields to reload | ||||
|         """ | ||||
|         max_depth = 1 | ||||
|         if fields and isinstance(fields[0], int): | ||||
|             max_depth = fields[0] | ||||
|             fields = fields[1:] | ||||
|         elif "max_depth" in kwargs: | ||||
|             max_depth = kwargs["max_depth"] | ||||
|  | ||||
|         if not self.pk: | ||||
|             raise self.DoesNotExist("Document does not exist") | ||||
|         obj = self._qs.read_preference(ReadPreference.PRIMARY).filter( | ||||
|                     **self._object_key).limit(1).select_related(max_depth=max_depth) | ||||
|  | ||||
|                     **self._object_key).only(*fields).limit(1 | ||||
|                     ).select_related(max_depth=max_depth) | ||||
|  | ||||
|         if obj: | ||||
|             obj = obj[0] | ||||
|         else: | ||||
|             raise self.DoesNotExist("Document does not exist") | ||||
|  | ||||
|         for field in self._fields_ordered: | ||||
|             setattr(self, field, self._reload(field, obj[field])) | ||||
|             if not fields or field in fields: | ||||
|                 setattr(self, field, self._reload(field, obj[field])) | ||||
|  | ||||
|         self._changed_fields = obj._changed_fields | ||||
|         self._created = False | ||||
|         return obj | ||||
|         return self | ||||
|  | ||||
|     def _reload(self, key, value): | ||||
|         """Used by :meth:`~mongoengine.Document.reload` to ensure the | ||||
|   | ||||
| @@ -391,7 +391,7 @@ class DateTimeField(BaseField): | ||||
|         if dateutil: | ||||
|             try: | ||||
|                 return dateutil.parser.parse(value) | ||||
|             except ValueError: | ||||
|             except (TypeError, ValueError): | ||||
|                 return None | ||||
|  | ||||
|         # split usecs, because they are not recognized by strptime. | ||||
| @@ -760,7 +760,7 @@ class DictField(ComplexBaseField): | ||||
|     similar to an embedded document, but the structure is not defined. | ||||
|  | ||||
|     .. note:: | ||||
|         Required means it cannot be empty - as the default for ListFields is [] | ||||
|         Required means it cannot be empty - as the default for DictFields is {} | ||||
|  | ||||
|     .. versionadded:: 0.3 | ||||
|     .. versionchanged:: 0.5 - Can now handle complex / varying types of data | ||||
| @@ -1554,6 +1554,14 @@ class SequenceField(BaseField): | ||||
|  | ||||
|         return super(SequenceField, self).__set__(instance, value) | ||||
|  | ||||
|     def prepare_query_value(self, op, value): | ||||
|         """ | ||||
|         This method is overriden in order to convert the query value into to required | ||||
|         type. We need to do this in order to be able to successfully compare query    | ||||
|         values passed as string, the base implementation returns the value as is. | ||||
|         """ | ||||
|         return self.value_decorator(value) | ||||
|  | ||||
|     def to_python(self, value): | ||||
|         if value is None: | ||||
|             value = self.generate() | ||||
| @@ -1613,7 +1621,12 @@ class UUIDField(BaseField): | ||||
|  | ||||
|  | ||||
| class GeoPointField(BaseField): | ||||
|     """A list storing a latitude and longitude. | ||||
|     """A list storing a longitude and latitude coordinate.  | ||||
|  | ||||
|     .. note:: this represents a generic point in a 2D plane and a legacy way of  | ||||
|         representing a geo point. It admits 2d indexes but not "2dsphere" indexes  | ||||
|         in MongoDB > 2.4 which are more natural for modeling geospatial points.  | ||||
|         See :ref:`geospatial-indexes`  | ||||
|  | ||||
|     .. versionadded:: 0.4 | ||||
|     """ | ||||
| @@ -1635,7 +1648,7 @@ class GeoPointField(BaseField): | ||||
|  | ||||
|  | ||||
| class PointField(GeoJsonBaseField): | ||||
|     """A geo json field storing a latitude and longitude. | ||||
|     """A GeoJSON field storing a longitude and latitude coordinate. | ||||
|  | ||||
|     The data is represented as: | ||||
|  | ||||
| @@ -1654,7 +1667,7 @@ class PointField(GeoJsonBaseField): | ||||
|  | ||||
|  | ||||
| class LineStringField(GeoJsonBaseField): | ||||
|     """A geo json field storing a line of latitude and longitude coordinates. | ||||
|     """A GeoJSON field storing a line of longitude and latitude coordinates. | ||||
|  | ||||
|     The data is represented as: | ||||
|  | ||||
| @@ -1672,7 +1685,7 @@ class LineStringField(GeoJsonBaseField): | ||||
|  | ||||
|  | ||||
| class PolygonField(GeoJsonBaseField): | ||||
|     """A geo json field storing a polygon of latitude and longitude coordinates. | ||||
|     """A GeoJSON field storing a polygon of longitude and latitude coordinates. | ||||
|  | ||||
|     The data is represented as: | ||||
|  | ||||
|   | ||||
| @@ -3,8 +3,6 @@ | ||||
| import sys | ||||
|  | ||||
| PY3 = sys.version_info[0] == 3 | ||||
| PY25 = sys.version_info[:2] == (2, 5) | ||||
| UNICODE_KWARGS = int(''.join([str(x) for x in sys.version_info[:3]])) > 264 | ||||
|  | ||||
| if PY3: | ||||
|     import codecs | ||||
| @@ -29,33 +27,3 @@ else: | ||||
|     txt_type = unicode | ||||
|  | ||||
| str_types = (bin_type, txt_type) | ||||
|  | ||||
| if PY25: | ||||
|     def product(*args, **kwds): | ||||
|         pools = map(tuple, args) * kwds.get('repeat', 1) | ||||
|         result = [[]] | ||||
|         for pool in pools: | ||||
|             result = [x + [y] for x in result for y in pool] | ||||
|         for prod in result: | ||||
|             yield tuple(prod) | ||||
|     reduce = reduce | ||||
| else: | ||||
|     from itertools import product | ||||
|     from functools import reduce | ||||
|  | ||||
|  | ||||
| # For use with Python 2.5 | ||||
| # converts all keys from unicode to str for d and all nested dictionaries | ||||
| def to_str_keys_recursive(d): | ||||
|     if isinstance(d, list): | ||||
|         for val in d: | ||||
|             if isinstance(val, (dict, list)): | ||||
|                 to_str_keys_recursive(val) | ||||
|     elif isinstance(d, dict): | ||||
|         for key, val in d.items(): | ||||
|             if isinstance(val, (dict, list)): | ||||
|                 to_str_keys_recursive(val) | ||||
|             if isinstance(key, unicode): | ||||
|                 d[str(key)] = d.pop(key) | ||||
|     else: | ||||
|         raise ValueError("non list/dict parameter not allowed") | ||||
|   | ||||
| @@ -7,17 +7,20 @@ import pprint | ||||
| import re | ||||
| import warnings | ||||
|  | ||||
| from bson import SON | ||||
| from bson.code import Code | ||||
| from bson import json_util | ||||
| import pymongo | ||||
| import pymongo.errors | ||||
| from pymongo.common import validate_read_preference | ||||
|  | ||||
| from mongoengine import signals | ||||
| from mongoengine.connection import get_db | ||||
| from mongoengine.context_managers import switch_db | ||||
| from mongoengine.common import _import_class | ||||
| from mongoengine.base.common import get_document | ||||
| from mongoengine.errors import (OperationError, NotUniqueError, | ||||
|                                 InvalidQueryError, LookUpError) | ||||
|  | ||||
| from mongoengine.queryset import transform | ||||
| from mongoengine.queryset.field_list import QueryFieldList | ||||
| from mongoengine.queryset.visitor import Q, QNode | ||||
| @@ -50,7 +53,7 @@ class BaseQuerySet(object): | ||||
|         self._initial_query = {} | ||||
|         self._where_clause = None | ||||
|         self._loaded_fields = QueryFieldList() | ||||
|         self._ordering = [] | ||||
|         self._ordering = None | ||||
|         self._snapshot = False | ||||
|         self._timeout = True | ||||
|         self._class_check = True | ||||
| @@ -146,7 +149,7 @@ class BaseQuerySet(object): | ||||
|                     queryset._document._from_son(queryset._cursor[key], | ||||
|                                                  _auto_dereference=self._auto_dereference)) | ||||
|             if queryset._as_pymongo: | ||||
|                 return queryset._get_as_pymongo(queryset._cursor.next()) | ||||
|                 return queryset._get_as_pymongo(queryset._cursor[key]) | ||||
|             return queryset._document._from_son(queryset._cursor[key], | ||||
|                                                 _auto_dereference=self._auto_dereference) | ||||
|         raise AttributeError | ||||
| @@ -154,6 +157,22 @@ class BaseQuerySet(object): | ||||
|     def __iter__(self): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def _has_data(self): | ||||
|         """ Retrieves whether cursor has any data. """ | ||||
|  | ||||
|         queryset = self.order_by() | ||||
|         return False if queryset.first() is None else True | ||||
|  | ||||
|     def __nonzero__(self): | ||||
|         """ Avoid to open all records in an if stmt in Py2. """ | ||||
|  | ||||
|         return self._has_data() | ||||
|  | ||||
|     def __bool__(self): | ||||
|         """ Avoid to open all records in an if stmt in Py3. """ | ||||
|  | ||||
|         return self._has_data() | ||||
|  | ||||
|     # Core functions | ||||
|  | ||||
|     def all(self): | ||||
| @@ -175,7 +194,7 @@ class BaseQuerySet(object): | ||||
|         .. versionadded:: 0.3 | ||||
|         """ | ||||
|         queryset = self.clone() | ||||
|         queryset = queryset.limit(2) | ||||
|         queryset = queryset.order_by().limit(2) | ||||
|         queryset = queryset.filter(*q_objs, **query) | ||||
|  | ||||
|         try: | ||||
| @@ -389,7 +408,7 @@ class BaseQuerySet(object): | ||||
|                 ref_q = document_cls.objects(**{field_name + '__in': self}) | ||||
|                 ref_q_count = ref_q.count() | ||||
|                 if (doc != document_cls and ref_q_count > 0 | ||||
|                    or (doc == document_cls and ref_q_count > 0)): | ||||
|                     or (doc == document_cls and ref_q_count > 0)): | ||||
|                     ref_q.delete(write_concern=write_concern) | ||||
|             elif rule == NULLIFY: | ||||
|                 document_cls.objects(**{field_name + '__in': self}).update( | ||||
| @@ -443,6 +462,8 @@ class BaseQuerySet(object): | ||||
|                 return result | ||||
|             elif result: | ||||
|                 return result['n'] | ||||
|         except pymongo.errors.DuplicateKeyError, err: | ||||
|             raise NotUniqueError(u'Update failed (%s)' % unicode(err)) | ||||
|         except pymongo.errors.OperationFailure, err: | ||||
|             if unicode(err) == u'multi not coded yet': | ||||
|                 message = u'update() method requires MongoDB 1.1.3+' | ||||
| @@ -466,6 +487,59 @@ class BaseQuerySet(object): | ||||
|         return self.update( | ||||
|             upsert=upsert, multi=False, write_concern=write_concern, **update) | ||||
|  | ||||
|     def modify(self, upsert=False, full_response=False, remove=False, new=False, **update): | ||||
|         """Update and return the updated document. | ||||
|  | ||||
|         Returns either the document before or after modification based on `new` | ||||
|         parameter. If no documents match the query and `upsert` is false, | ||||
|         returns ``None``. If upserting and `new` is false, returns ``None``. | ||||
|  | ||||
|         If the full_response parameter is ``True``, the return value will be | ||||
|         the entire response object from the server, including the 'ok' and | ||||
|         'lastErrorObject' fields, rather than just the modified document. | ||||
|         This is useful mainly because the 'lastErrorObject' document holds | ||||
|         information about the command's execution. | ||||
|  | ||||
|         :param upsert: insert if document doesn't exist (default ``False``) | ||||
|         :param full_response: return the entire response object from the | ||||
|             server (default ``False``) | ||||
|         :param remove: remove rather than updating (default ``False``) | ||||
|         :param new: return updated rather than original document | ||||
|             (default ``False``) | ||||
|         :param update: Django-style update keyword arguments | ||||
|  | ||||
|         .. versionadded:: 0.9 | ||||
|         """ | ||||
|  | ||||
|         if remove and new: | ||||
|             raise OperationError("Conflicting parameters: remove and new") | ||||
|  | ||||
|         if not update and not upsert and not remove: | ||||
|             raise OperationError("No update parameters, must either update or remove") | ||||
|  | ||||
|         queryset = self.clone() | ||||
|         query = queryset._query | ||||
|         update = transform.update(queryset._document, **update) | ||||
|         sort = queryset._ordering | ||||
|  | ||||
|         try: | ||||
|             result = queryset._collection.find_and_modify( | ||||
|                 query, update, upsert=upsert, sort=sort, remove=remove, new=new, | ||||
|                 full_response=full_response, **self._cursor_args) | ||||
|         except pymongo.errors.DuplicateKeyError, err: | ||||
|             raise NotUniqueError(u"Update failed (%s)" % err) | ||||
|         except pymongo.errors.OperationFailure, err: | ||||
|             raise OperationError(u"Update failed (%s)" % err) | ||||
|  | ||||
|         if full_response: | ||||
|             if result["value"] is not None: | ||||
|                 result["value"] = self._document._from_son(result["value"]) | ||||
|         else: | ||||
|             if result is not None: | ||||
|                 result = self._document._from_son(result) | ||||
|  | ||||
|         return result | ||||
|  | ||||
|     def with_id(self, object_id): | ||||
|         """Retrieve the object matching the id provided.  Uses `object_id` only | ||||
|         and raises InvalidQueryError if a filter has been applied. Returns | ||||
| @@ -522,6 +596,19 @@ class BaseQuerySet(object): | ||||
|  | ||||
|         return self | ||||
|  | ||||
|     def using(self, alias): | ||||
|         """This method is for controlling which database the QuerySet will be evaluated against if you are using more than one database. | ||||
|  | ||||
|         :param alias: The database alias | ||||
|  | ||||
|         .. versionadded:: 0.8 | ||||
|         """ | ||||
|  | ||||
|         with switch_db(self._document, alias) as cls: | ||||
|             collection = cls._get_collection() | ||||
|  | ||||
|         return self.clone_into(self.__class__(self._document, collection)) | ||||
|  | ||||
|     def clone(self): | ||||
|         """Creates a copy of the current | ||||
|           :class:`~mongoengine.queryset.QuerySet` | ||||
| @@ -630,7 +717,10 @@ class BaseQuerySet(object): | ||||
|             # We may need to cast to the correct type eg. ListField(EmbeddedDocumentField) | ||||
|             doc_field = getattr(self._document._fields.get(field), "field", None) | ||||
|             instance = getattr(doc_field, "document_type", False) | ||||
|             if instance: | ||||
|             EmbeddedDocumentField = _import_class('EmbeddedDocumentField') | ||||
|             GenericEmbeddedDocumentField = _import_class('GenericEmbeddedDocumentField') | ||||
|             if instance and isinstance(doc_field, (EmbeddedDocumentField, | ||||
|                                                    GenericEmbeddedDocumentField)): | ||||
|                 distinct = [instance(**doc) for doc in distinct] | ||||
|             return distinct | ||||
|  | ||||
| @@ -923,10 +1013,39 @@ class BaseQuerySet(object): | ||||
|             map_reduce_function = 'inline_map_reduce' | ||||
|         else: | ||||
|             map_reduce_function = 'map_reduce' | ||||
|             mr_args['out'] = output | ||||
|  | ||||
|             if isinstance(output, basestring): | ||||
|                 mr_args['out'] = output | ||||
|  | ||||
|             elif isinstance(output, dict): | ||||
|                 ordered_output = [] | ||||
|  | ||||
|                 for part in ('replace', 'merge', 'reduce'): | ||||
|                     value = output.get(part) | ||||
|                     if value: | ||||
|                         ordered_output.append((part, value)) | ||||
|                         break | ||||
|  | ||||
|                 else: | ||||
|                     raise OperationError("actionData not specified for output") | ||||
|  | ||||
|                 db_alias = output.get('db_alias') | ||||
|                 remaing_args = ['db', 'sharded', 'nonAtomic'] | ||||
|  | ||||
|                 if db_alias: | ||||
|                     ordered_output.append(('db', get_db(db_alias).name)) | ||||
|                     del remaing_args[0] | ||||
|  | ||||
|  | ||||
|                 for part in remaing_args: | ||||
|                     value = output.get(part) | ||||
|                     if value: | ||||
|                         ordered_output.append((part, value)) | ||||
|  | ||||
|                 mr_args['out'] = SON(ordered_output) | ||||
|  | ||||
|         results = getattr(queryset._collection, map_reduce_function)( | ||||
|                           map_f, reduce_f, **mr_args) | ||||
|             map_f, reduce_f, **mr_args) | ||||
|  | ||||
|         if map_reduce_function == 'map_reduce': | ||||
|             results = results.find() | ||||
| @@ -1189,8 +1308,9 @@ class BaseQuerySet(object): | ||||
|             if self._ordering: | ||||
|                 # Apply query ordering | ||||
|                 self._cursor_obj.sort(self._ordering) | ||||
|             elif self._document._meta['ordering']: | ||||
|                 # Otherwise, apply the ordering from the document model | ||||
|             elif self._ordering is None and self._document._meta['ordering']: | ||||
|                 # Otherwise, apply the ordering from the document model, unless | ||||
|                 # it's been explicitly cleared via order_by with no arguments | ||||
|                 order = self._get_order_by(self._document._meta['ordering']) | ||||
|                 self._cursor_obj.sort(order) | ||||
|  | ||||
| @@ -1362,7 +1482,7 @@ class BaseQuerySet(object): | ||||
|                 for subdoc in subclasses: | ||||
|                     try: | ||||
|                         subfield = ".".join(f.db_field for f in | ||||
|                                         subdoc._lookup_field(field.split('.'))) | ||||
|                                             subdoc._lookup_field(field.split('.'))) | ||||
|                         ret.append(subfield) | ||||
|                         found = True | ||||
|                         break | ||||
| @@ -1392,7 +1512,7 @@ class BaseQuerySet(object): | ||||
|                 pass | ||||
|             key_list.append((key, direction)) | ||||
|  | ||||
|         if self._cursor_obj: | ||||
|         if self._cursor_obj and key_list: | ||||
|             self._cursor_obj.sort(key_list) | ||||
|         return key_list | ||||
|  | ||||
| @@ -1450,6 +1570,7 @@ class BaseQuerySet(object): | ||||
|                     # type of this field and use the corresponding | ||||
|                     # .to_python(...) | ||||
|                     from mongoengine.fields import EmbeddedDocumentField | ||||
|  | ||||
|                     obj = self._document | ||||
|                     for chunk in path.split('.'): | ||||
|                         obj = getattr(obj, chunk, None) | ||||
| @@ -1460,6 +1581,7 @@ class BaseQuerySet(object): | ||||
|                     if obj and data is not None: | ||||
|                         data = obj.to_python(data) | ||||
|             return data | ||||
|  | ||||
|         return clean(row) | ||||
|  | ||||
|     def _sub_js_fields(self, code): | ||||
| @@ -1468,6 +1590,7 @@ class BaseQuerySet(object): | ||||
|         substituted for the MongoDB name of the field (specified using the | ||||
|         :attr:`name` keyword argument in a field's constructor). | ||||
|         """ | ||||
|  | ||||
|         def field_sub(match): | ||||
|             # Extract just the field name, and look up the field objects | ||||
|             field_name = match.group(1).split('.') | ||||
|   | ||||
| @@ -155,3 +155,10 @@ class QuerySetNoCache(BaseQuerySet): | ||||
|             queryset = self.clone() | ||||
|         queryset.rewind() | ||||
|         return queryset | ||||
|  | ||||
|  | ||||
| class QuerySetNoDeRef(QuerySet): | ||||
|     """Special no_dereference QuerySet""" | ||||
|  | ||||
|     def __dereference(items, max_depth=1, instance=None, name=None): | ||||
|         return items | ||||
| @@ -3,6 +3,7 @@ from collections import defaultdict | ||||
| import pymongo | ||||
| from bson import SON | ||||
|  | ||||
| from mongoengine.connection import get_connection | ||||
| from mongoengine.common import _import_class | ||||
| from mongoengine.errors import InvalidQueryError, LookUpError | ||||
|  | ||||
| @@ -38,7 +39,7 @@ def query(_doc_cls=None, _field_operation=False, **query): | ||||
|             mongo_query.update(value) | ||||
|             continue | ||||
|  | ||||
|         parts = key.split('__') | ||||
|         parts = key.rsplit('__') | ||||
|         indices = [(i, p) for i, p in enumerate(parts) if p.isdigit()] | ||||
|         parts = [part for part in parts if not part.isdigit()] | ||||
|         # Check for an operator and transform to mongo-style if there is | ||||
| @@ -115,14 +116,26 @@ def query(_doc_cls=None, _field_operation=False, **query): | ||||
|             if key in mongo_query and isinstance(mongo_query[key], dict): | ||||
|                 mongo_query[key].update(value) | ||||
|                 # $maxDistance needs to come last - convert to SON | ||||
|                 if '$maxDistance' in mongo_query[key]: | ||||
|                     value_dict = mongo_query[key] | ||||
|                 value_dict = mongo_query[key] | ||||
|                 if ('$maxDistance' in value_dict and '$near' in value_dict): | ||||
|                     value_son = SON() | ||||
|                     for k, v in value_dict.iteritems(): | ||||
|                         if k == '$maxDistance': | ||||
|                             continue | ||||
|                         value_son[k] = v | ||||
|                     value_son['$maxDistance'] = value_dict['$maxDistance'] | ||||
|                     if isinstance(value_dict['$near'], dict): | ||||
|                         for k, v in value_dict.iteritems(): | ||||
|                             if k == '$maxDistance': | ||||
|                                 continue | ||||
|                             value_son[k] = v | ||||
|                         if (get_connection().max_wire_version <= 1): | ||||
|                             value_son['$maxDistance'] = value_dict['$maxDistance'] | ||||
|                         else: | ||||
|                             value_son['$near'] = SON(value_son['$near']) | ||||
|                             value_son['$near']['$maxDistance'] = value_dict['$maxDistance'] | ||||
|                     else: | ||||
|                         for k, v in value_dict.iteritems(): | ||||
|                             if k == '$maxDistance': | ||||
|                                 continue | ||||
|                             value_son[k] = v | ||||
|                         value_son['$maxDistance'] = value_dict['$maxDistance'] | ||||
|  | ||||
|                     mongo_query[key] = value_son | ||||
|             else: | ||||
|                 # Store for manually merging later | ||||
|   | ||||
| @@ -1,8 +1,9 @@ | ||||
| import copy | ||||
|  | ||||
| from mongoengine.errors import InvalidQueryError | ||||
| from mongoengine.python_support import product, reduce | ||||
| from itertools import product | ||||
| from functools import reduce | ||||
|  | ||||
| from mongoengine.errors import InvalidQueryError | ||||
| from mongoengine.queryset import transform | ||||
|  | ||||
| __all__ = ('Q',) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user