Compare commits
	
		
			32 Commits
		
	
	
		
			improve-he
			...
			cleaner-sa
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|  | a8889b6dfb | ||
|  | d05301b3a1 | ||
|  | a120eae5ae | ||
|  | 3d75573889 | ||
|  | c6240ca415 | ||
|  | 2ee8984b44 | ||
|  | b7ec587e5b | ||
|  | 47c58bce2b | ||
|  | 96e95ac533 | ||
|  | b013a065f7 | ||
|  | 74b37d11cf | ||
|  | c6cc013617 | ||
|  | f4e1d80a87 | ||
|  | 91dad4060f | ||
|  | e07cb82c15 | ||
|  | 2770cec187 | ||
|  | 5c3928190a | ||
|  | 9f4b04ea0f | ||
|  | 96d20756ca | ||
|  | b8454c7f5b | ||
|  | c84f703f92 | ||
|  | 57c2e867d8 | ||
|  | 553f496d84 | ||
|  | b1d8aca46a | ||
|  | 8e884fd3ea | ||
|  | 76524b7498 | ||
|  | 65914fb2b2 | ||
|  | a4d0da0085 | ||
|  | c9d496e9a0 | ||
|  | 88a951ba4f | ||
|  | 403ceb19dc | ||
|  | 835d3c3d18 | 
| @@ -20,7 +20,7 @@ post to the `user group <http://groups.google.com/group/mongoengine-users>` | ||||
| Supported Interpreters | ||||
| ---------------------- | ||||
|  | ||||
| MongoEngine supports CPython 2.6 and newer. Language | ||||
| MongoEngine supports CPython 2.7 and newer. Language | ||||
| features not supported by all interpreters can not be used. | ||||
| Please also ensure that your code is properly converted by | ||||
| `2to3 <http://docs.python.org/library/2to3.html>`_ for Python 3 support. | ||||
|   | ||||
							
								
								
									
										20
									
								
								README.rst
									
									
									
									
									
								
							
							
						
						
									
										20
									
								
								README.rst
									
									
									
									
									
								
							| @@ -4,7 +4,7 @@ MongoEngine | ||||
| :Info: MongoEngine is an ORM-like layer on top of PyMongo. | ||||
| :Repository: https://github.com/MongoEngine/mongoengine | ||||
| :Author: Harry Marr (http://github.com/hmarr) | ||||
| :Maintainer: Ross Lawley (http://github.com/rozza) | ||||
| :Maintainer: Stefan Wójcik (http://github.com/wojcikstefan) | ||||
|  | ||||
| .. image:: https://travis-ci.org/MongoEngine/mongoengine.svg?branch=master | ||||
|   :target: https://travis-ci.org/MongoEngine/mongoengine | ||||
| @@ -57,7 +57,7 @@ Some simple examples of what MongoEngine code looks like: | ||||
|  | ||||
|     class BlogPost(Document): | ||||
|         title = StringField(required=True, max_length=200) | ||||
|         posted = DateTimeField(default=datetime.datetime.now) | ||||
|         posted = DateTimeField(default=datetime.datetime.utcnow) | ||||
|         tags = ListField(StringField(max_length=50)) | ||||
|         meta = {'allow_inheritance': True} | ||||
|  | ||||
| @@ -87,17 +87,18 @@ Some simple examples of what MongoEngine code looks like: | ||||
|     ...     print | ||||
|     ... | ||||
|  | ||||
|     >>> len(BlogPost.objects) | ||||
|     # Count all blog posts and its subtypes | ||||
|     >>> BlogPost.objects.count() | ||||
|     2 | ||||
|     >>> len(TextPost.objects) | ||||
|     >>> TextPost.objects.count() | ||||
|     1 | ||||
|     >>> len(LinkPost.objects) | ||||
|     >>> LinkPost.objects.count() | ||||
|     1 | ||||
|  | ||||
|     # Find tagged posts | ||||
|     >>> len(BlogPost.objects(tags='mongoengine')) | ||||
|     # Count tagged posts | ||||
|     >>> BlogPost.objects(tags='mongoengine').count() | ||||
|     2 | ||||
|     >>> len(BlogPost.objects(tags='mongodb')) | ||||
|     >>> BlogPost.objects(tags='mongodb').count() | ||||
|     1 | ||||
|  | ||||
| Tests | ||||
| @@ -130,8 +131,7 @@ Community | ||||
|   <http://groups.google.com/group/mongoengine-users>`_ | ||||
| - `MongoEngine Developers mailing list | ||||
|   <http://groups.google.com/group/mongoengine-dev>`_ | ||||
| - `#mongoengine IRC channel <http://webchat.freenode.net/?channels=mongoengine>`_ | ||||
|  | ||||
| Contributing | ||||
| ============ | ||||
| We welcome contributions! see  the `Contribution guidelines <https://github.com/MongoEngine/mongoengine/blob/master/CONTRIBUTING.rst>`_ | ||||
| We welcome contributions! See the `Contribution guidelines <https://github.com/MongoEngine/mongoengine/blob/master/CONTRIBUTING.rst>`_ | ||||
|   | ||||
| @@ -5,6 +5,8 @@ Changelog | ||||
| Development | ||||
| =========== | ||||
| - (Fill this out as you fix issues and develop you features). | ||||
| - Fixed connecting to a replica set with PyMongo 2.x #1436 | ||||
| - Fixed an obscure error message when filtering by `field__in=non_iterable`. #1237 | ||||
|  | ||||
| Changes in 0.11.0 | ||||
| ================= | ||||
|   | ||||
| @@ -33,7 +33,7 @@ the :attr:`host` to | ||||
|     corresponding parameters in :func:`~mongoengine.connect`: :: | ||||
|  | ||||
|         connect( | ||||
|             name='test', | ||||
|             db='test', | ||||
|             username='user', | ||||
|             password='12345', | ||||
|             host='mongodb://admin:qwerty@localhost/production' | ||||
|   | ||||
| @@ -479,6 +479,8 @@ operators. To use a :class:`~mongoengine.queryset.Q` object, pass it in as the | ||||
| first positional argument to :attr:`Document.objects` when you filter it by | ||||
| calling it with keyword arguments:: | ||||
|  | ||||
|     from mongoengine.queryset.visitor import Q | ||||
|  | ||||
|     # Get published posts | ||||
|     Post.objects(Q(published=True) | Q(publish_date__lte=datetime.now())) | ||||
|  | ||||
|   | ||||
| @@ -23,7 +23,7 @@ __all__ = (list(document.__all__) + list(fields.__all__) + | ||||
|            list(signals.__all__) + list(errors.__all__)) | ||||
|  | ||||
|  | ||||
| VERSION = (0, 10, 9) | ||||
| VERSION = (0, 11, 0) | ||||
|  | ||||
|  | ||||
| def get_version(): | ||||
|   | ||||
| @@ -5,7 +5,7 @@ __all__ = ('UPDATE_OPERATORS', 'get_document', '_document_registry') | ||||
|  | ||||
| UPDATE_OPERATORS = set(['set', 'unset', 'inc', 'dec', 'pop', 'push', | ||||
|                         'push_all', 'pull', 'pull_all', 'add_to_set', | ||||
|                         'set_on_insert', 'min', 'max']) | ||||
|                         'set_on_insert', 'min', 'max', 'rename']) | ||||
|  | ||||
|  | ||||
| _document_registry = {} | ||||
|   | ||||
| @@ -138,10 +138,7 @@ class BaseList(list): | ||||
|         return super(BaseList, self).__setitem__(key, value) | ||||
|  | ||||
|     def __delitem__(self, key, *args, **kwargs): | ||||
|         if isinstance(key, slice): | ||||
|         self._mark_as_changed() | ||||
|         else: | ||||
|             self._mark_as_changed(key) | ||||
|         return super(BaseList, self).__delitem__(key) | ||||
|  | ||||
|     def __setslice__(self, *args, **kwargs): | ||||
|   | ||||
| @@ -402,9 +402,11 @@ class BaseDocument(object): | ||||
|             raise ValidationError(message, errors=errors) | ||||
|  | ||||
|     def to_json(self, *args, **kwargs): | ||||
|         """Converts a document to JSON. | ||||
|         :param use_db_field: Set to True by default but enables the output of the json structure with the field names | ||||
|             and not the mongodb store db_names in case of set to False | ||||
|         """Convert this document to JSON. | ||||
|  | ||||
|         :param use_db_field: Serialize field names as they appear in | ||||
|             MongoDB (as opposed to attribute names on this document). | ||||
|             Defaults to True. | ||||
|         """ | ||||
|         use_db_field = kwargs.pop('use_db_field', True) | ||||
|         return json_util.dumps(self.to_mongo(use_db_field), *args, **kwargs) | ||||
| @@ -675,6 +677,9 @@ class BaseDocument(object): | ||||
|         if not only_fields: | ||||
|             only_fields = [] | ||||
|  | ||||
|         if son and not isinstance(son, dict): | ||||
|             raise ValueError("The source SON object needs to be of type 'dict'") | ||||
|  | ||||
|         # Get the class name from the document, falling back to the given | ||||
|         # class if unavailable | ||||
|         class_name = son.get('_cls', cls._class_name) | ||||
|   | ||||
| @@ -23,7 +23,6 @@ class BaseField(object): | ||||
|  | ||||
|     .. versionchanged:: 0.5 - added verbose and help text | ||||
|     """ | ||||
|  | ||||
|     name = None | ||||
|     _geo_index = False | ||||
|     _auto_gen = False  # Call `generate` to generate a value | ||||
| @@ -42,7 +41,7 @@ class BaseField(object): | ||||
|         """ | ||||
|         :param db_field: The database field to store this field in | ||||
|             (defaults to the name of the field) | ||||
|         :param name: Depreciated - use db_field | ||||
|         :param name: Deprecated - use db_field | ||||
|         :param required: If the field is required. Whether it has to have a | ||||
|             value or not. Defaults to False. | ||||
|         :param default: (optional) The default value for this field if no value | ||||
| @@ -82,6 +81,17 @@ class BaseField(object): | ||||
|         self.sparse = sparse | ||||
|         self._owner_document = None | ||||
|  | ||||
|         # Validate the db_field | ||||
|         if isinstance(self.db_field, six.string_types) and ( | ||||
|             '.' in self.db_field or | ||||
|             '\0' in self.db_field or | ||||
|             self.db_field.startswith('$') | ||||
|         ): | ||||
|             raise ValueError( | ||||
|                 'field names cannot contain dots (".") or null characters ' | ||||
|                 '("\\0"), and they must not start with a dollar sign ("$").' | ||||
|             ) | ||||
|  | ||||
|         # Detect and report conflicts between metadata and base properties. | ||||
|         conflicts = set(dir(self)) & set(kwargs) | ||||
|         if conflicts: | ||||
|   | ||||
| @@ -34,7 +34,10 @@ def _import_class(cls_name): | ||||
|     queryset_classes = ('OperationError',) | ||||
|     deref_classes = ('DeReference',) | ||||
|  | ||||
|     if cls_name in doc_classes: | ||||
|     if cls_name == 'BaseDocument': | ||||
|         from mongoengine.base import document as module | ||||
|         import_classes = ['BaseDocument'] | ||||
|     elif cls_name in doc_classes: | ||||
|         from mongoengine import document as module | ||||
|         import_classes = doc_classes | ||||
|     elif cls_name in field_classes: | ||||
|   | ||||
| @@ -66,9 +66,9 @@ def register_connection(alias, name=None, host=None, port=None, | ||||
|         'authentication_mechanism': authentication_mechanism | ||||
|     } | ||||
|  | ||||
|     # Handle uri style connections | ||||
|     conn_host = conn_settings['host'] | ||||
|     # host can be a list or a string, so if string, force to a list | ||||
|  | ||||
|     # Host can be a list or a string, so if string, force to a list. | ||||
|     if isinstance(conn_host, six.string_types): | ||||
|         conn_host = [conn_host] | ||||
|  | ||||
| @@ -96,7 +96,7 @@ def register_connection(alias, name=None, host=None, port=None, | ||||
|  | ||||
|             uri_options = uri_dict['options'] | ||||
|             if 'replicaset' in uri_options: | ||||
|                 conn_settings['replicaSet'] = True | ||||
|                 conn_settings['replicaSet'] = uri_options['replicaset'] | ||||
|             if 'authsource' in uri_options: | ||||
|                 conn_settings['authentication_source'] = uri_options['authsource'] | ||||
|             if 'authmechanism' in uri_options: | ||||
| @@ -170,23 +170,22 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): | ||||
|     else: | ||||
|         connection_class = MongoClient | ||||
|  | ||||
|         # Handle replica set connections | ||||
|         if 'replicaSet' in conn_settings: | ||||
|  | ||||
|             # Discard port since it can't be used on MongoReplicaSetClient | ||||
|             conn_settings.pop('port', None) | ||||
|  | ||||
|             # Discard replicaSet if it's not a string | ||||
|             if not isinstance(conn_settings['replicaSet'], six.string_types): | ||||
|                 del conn_settings['replicaSet'] | ||||
|  | ||||
|         # For replica set connections with PyMongo 2.x, use | ||||
|         # MongoReplicaSetClient. | ||||
|         # TODO remove this once we stop supporting PyMongo 2.x. | ||||
|             if not IS_PYMONGO_3: | ||||
|         if 'replicaSet' in conn_settings and not IS_PYMONGO_3: | ||||
|             connection_class = MongoReplicaSetClient | ||||
|             conn_settings['hosts_or_uri'] = conn_settings.pop('host', None) | ||||
|  | ||||
|             # hosts_or_uri has to be a string, so if 'host' was provided | ||||
|             # as a list, join its parts and separate them by ',' | ||||
|             if isinstance(conn_settings['hosts_or_uri'], list): | ||||
|                 conn_settings['hosts_or_uri'] = ','.join( | ||||
|                     conn_settings['hosts_or_uri']) | ||||
|  | ||||
|             # Discard port since it can't be used on MongoReplicaSetClient | ||||
|             conn_settings.pop('port', None) | ||||
|  | ||||
|     # Iterate over all of the connection settings and if a connection with | ||||
|     # the same parameters is already established, use it instead of creating | ||||
|     # a new one. | ||||
|   | ||||
| @@ -313,6 +313,9 @@ class Document(BaseDocument): | ||||
|         .. versionchanged:: 0.10.7 | ||||
|             Add signal_kwargs argument | ||||
|         """ | ||||
|         if self._meta.get('abstract'): | ||||
|             raise InvalidDocumentError('Cannot save an abstract document.') | ||||
|  | ||||
|         signal_kwargs = signal_kwargs or {} | ||||
|         signals.pre_save.send(self.__class__, document=self, **signal_kwargs) | ||||
|  | ||||
| @@ -329,68 +332,20 @@ class Document(BaseDocument): | ||||
|         signals.pre_save_post_validation.send(self.__class__, document=self, | ||||
|                                               created=created, **signal_kwargs) | ||||
|  | ||||
|         try: | ||||
|             collection = self._get_collection() | ||||
|         if self._meta.get('auto_create_index', True): | ||||
|             self.ensure_indexes() | ||||
|  | ||||
|         try: | ||||
|             # Save a new document or update an existing one | ||||
|             if created: | ||||
|                 if force_insert: | ||||
|                     object_id = collection.insert(doc, **write_concern) | ||||
|                 object_id = self._save_create(doc, force_insert, write_concern) | ||||
|             else: | ||||
|                     object_id = collection.save(doc, **write_concern) | ||||
|                     # In PyMongo 3.0, the save() call calls internally the _update() call | ||||
|                     # but they forget to return the _id value passed back, therefore getting it back here | ||||
|                     # Correct behaviour in 2.X and in 3.0.1+ versions | ||||
|                     if not object_id and pymongo.version_tuple == (3, 0): | ||||
|                         pk_as_mongo_obj = self._fields.get(self._meta['id_field']).to_mongo(self.pk) | ||||
|                         object_id = ( | ||||
|                             self._qs.filter(pk=pk_as_mongo_obj).first() and | ||||
|                             self._qs.filter(pk=pk_as_mongo_obj).first().pk | ||||
|                         )  # TODO doesn't this make 2 queries? | ||||
|             else: | ||||
|                 object_id = doc['_id'] | ||||
|                 updates, removals = self._delta() | ||||
|                 # Need to add shard key to query, or you get an error | ||||
|                 if save_condition is not None: | ||||
|                     select_dict = transform.query(self.__class__, | ||||
|                                                   **save_condition) | ||||
|                 else: | ||||
|                     select_dict = {} | ||||
|                 select_dict['_id'] = object_id | ||||
|                 shard_key = self._meta.get('shard_key', tuple()) | ||||
|                 for k in shard_key: | ||||
|                     path = self._lookup_field(k.split('.')) | ||||
|                     actual_key = [p.db_field for p in path] | ||||
|                     val = doc | ||||
|                     for ak in actual_key: | ||||
|                         val = val[ak] | ||||
|                     select_dict['.'.join(actual_key)] = val | ||||
|  | ||||
|                 def is_new_object(last_error): | ||||
|                     if last_error is not None: | ||||
|                         updated = last_error.get('updatedExisting') | ||||
|                         if updated is not None: | ||||
|                             return not updated | ||||
|                     return created | ||||
|  | ||||
|                 update_query = {} | ||||
|  | ||||
|                 if updates: | ||||
|                     update_query['$set'] = updates | ||||
|                 if removals: | ||||
|                     update_query['$unset'] = removals | ||||
|                 if updates or removals: | ||||
|                     upsert = save_condition is None | ||||
|                     last_error = collection.update(select_dict, update_query, | ||||
|                                                    upsert=upsert, **write_concern) | ||||
|                     if not upsert and last_error['n'] == 0: | ||||
|                         raise SaveConditionError('Race condition preventing' | ||||
|                                                  ' document update detected') | ||||
|                     created = is_new_object(last_error) | ||||
|                 object_id, created = self._save_update(doc, save_condition, | ||||
|                                                        write_concern) | ||||
|  | ||||
|             if cascade is None: | ||||
|                 cascade = self._meta.get( | ||||
|                     'cascade', False) or cascade_kwargs is not None | ||||
|                 cascade = (self._meta.get('cascade', False) or | ||||
|                            cascade_kwargs is not None) | ||||
|  | ||||
|             if cascade: | ||||
|                 kwargs = { | ||||
| @@ -403,6 +358,7 @@ class Document(BaseDocument): | ||||
|                     kwargs.update(cascade_kwargs) | ||||
|                 kwargs['_refs'] = _refs | ||||
|                 self.cascade_save(**kwargs) | ||||
|  | ||||
|         except pymongo.errors.DuplicateKeyError as err: | ||||
|             message = u'Tried to save duplicate unique keys (%s)' | ||||
|             raise NotUniqueError(message % six.text_type(err)) | ||||
| @@ -415,16 +371,91 @@ class Document(BaseDocument): | ||||
|                 raise NotUniqueError(message % six.text_type(err)) | ||||
|             raise OperationError(message % six.text_type(err)) | ||||
|  | ||||
|         # Make sure we store the PK on this document now that it's saved | ||||
|         id_field = self._meta['id_field'] | ||||
|         if created or id_field not in self._meta.get('shard_key', []): | ||||
|             self[id_field] = self._fields[id_field].to_python(object_id) | ||||
|  | ||||
|         signals.post_save.send(self.__class__, document=self, | ||||
|                                created=created, **signal_kwargs) | ||||
|  | ||||
|         self._clear_changed_fields() | ||||
|         self._created = False | ||||
|  | ||||
|         return self | ||||
|  | ||||
|     def _save_create(self, doc, force_insert, write_concern): | ||||
|         """Save a new document. | ||||
|  | ||||
|         Helper method, should only be used inside save(). | ||||
|         """ | ||||
|         collection = self._get_collection() | ||||
|  | ||||
|         if force_insert: | ||||
|             return collection.insert(doc, **write_concern) | ||||
|  | ||||
|         object_id = collection.save(doc, **write_concern) | ||||
|  | ||||
|         # In PyMongo 3.0, the save() call calls internally the _update() call | ||||
|         # but they forget to return the _id value passed back, therefore getting it back here | ||||
|         # Correct behaviour in 2.X and in 3.0.1+ versions | ||||
|         if not object_id and pymongo.version_tuple == (3, 0): | ||||
|             pk_as_mongo_obj = self._fields.get(self._meta['id_field']).to_mongo(self.pk) | ||||
|             object_id = ( | ||||
|                 self._qs.filter(pk=pk_as_mongo_obj).first() and | ||||
|                 self._qs.filter(pk=pk_as_mongo_obj).first().pk | ||||
|             )  # TODO doesn't this make 2 queries? | ||||
|  | ||||
|         return object_id | ||||
|  | ||||
|     def _save_update(self, doc, save_condition, write_concern): | ||||
|         """Update an existing document. | ||||
|  | ||||
|         Helper method, should only be used inside save(). | ||||
|         """ | ||||
|         collection = self._get_collection() | ||||
|         object_id = doc['_id'] | ||||
|         created = False | ||||
|  | ||||
|         select_dict = {} | ||||
|         if save_condition is not None: | ||||
|             select_dict = transform.query(self.__class__, **save_condition) | ||||
|  | ||||
|         select_dict['_id'] = object_id | ||||
|  | ||||
|         # Need to add shard key to query, or you get an error | ||||
|         shard_key = self._meta.get('shard_key', tuple()) | ||||
|         for k in shard_key: | ||||
|             path = self._lookup_field(k.split('.')) | ||||
|             actual_key = [p.db_field for p in path] | ||||
|             val = doc | ||||
|             for ak in actual_key: | ||||
|                 val = val[ak] | ||||
|             select_dict['.'.join(actual_key)] = val | ||||
|  | ||||
|         updates, removals = self._delta() | ||||
|         update_query = {} | ||||
|         if updates: | ||||
|             update_query['$set'] = updates | ||||
|         if removals: | ||||
|             update_query['$unset'] = removals | ||||
|         if updates or removals: | ||||
|             upsert = save_condition is None | ||||
|             last_error = collection.update(select_dict, update_query, | ||||
|                                            upsert=upsert, **write_concern) | ||||
|             if not upsert and last_error['n'] == 0: | ||||
|                 raise SaveConditionError('Race condition preventing' | ||||
|                                          ' document update detected') | ||||
|             if last_error is not None: | ||||
|                 updated_existing = last_error.get('updatedExisting') | ||||
|                 if updated_existing is False: | ||||
|                     created = True | ||||
|                     # !!! This is bad, means we accidentally created a new, | ||||
|                     # potentially corrupted document. See | ||||
|                     # https://github.com/MongoEngine/mongoengine/issues/564 | ||||
|  | ||||
|         return object_id, created | ||||
|  | ||||
|     def cascade_save(self, **kwargs): | ||||
|         """Recursively save any references and generic references on the | ||||
|         document. | ||||
| @@ -828,7 +859,6 @@ class Document(BaseDocument): | ||||
|         """ Lists all of the indexes that should be created for given | ||||
|         collection. It includes all the indexes from super- and sub-classes. | ||||
|         """ | ||||
|  | ||||
|         if cls._meta.get('abstract'): | ||||
|             return [] | ||||
|  | ||||
|   | ||||
| @@ -28,7 +28,7 @@ from mongoengine.base import (BaseDocument, BaseField, ComplexBaseField, | ||||
|                               GeoJsonBaseField, ObjectIdField, get_document) | ||||
| from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db | ||||
| from mongoengine.document import Document, EmbeddedDocument | ||||
| from mongoengine.errors import DoesNotExist, ValidationError | ||||
| from mongoengine.errors import DoesNotExist, InvalidQueryError, ValidationError | ||||
| from mongoengine.python_support import StringIO | ||||
| from mongoengine.queryset import DO_NOTHING, QuerySet | ||||
|  | ||||
| @@ -566,7 +566,11 @@ class EmbeddedDocumentField(BaseField): | ||||
|  | ||||
|     def prepare_query_value(self, op, value): | ||||
|         if value is not None and not isinstance(value, self.document_type): | ||||
|             try: | ||||
|                 value = self.document_type._from_son(value) | ||||
|             except ValueError: | ||||
|                 raise InvalidQueryError("Querying the embedded document '%s' failed, due to an invalid query value" % | ||||
|                                         (self.document_type._class_name,)) | ||||
|         super(EmbeddedDocumentField, self).prepare_query_value(op, value) | ||||
|         return self.to_mongo(value) | ||||
|  | ||||
|   | ||||
| @@ -901,17 +901,23 @@ class BaseQuerySet(object): | ||||
|  | ||||
|     def fields(self, _only_called=False, **kwargs): | ||||
|         """Manipulate how you load this document's fields. Used by `.only()` | ||||
|         and `.exclude()` to manipulate which fields to retrieve.  Fields also | ||||
|         allows for a greater level of control for example: | ||||
|         and `.exclude()` to manipulate which fields to retrieve. If called | ||||
|         directly, use a set of kwargs similar to the MongoDB projection | ||||
|         document. For example: | ||||
|  | ||||
|         Retrieving a Subrange of Array Elements: | ||||
|         Include only a subset of fields: | ||||
|  | ||||
|         You can use the $slice operator to retrieve a subrange of elements in | ||||
|         an array. For example to get the first 5 comments:: | ||||
|             posts = BlogPost.objects(...).fields(author=1, title=1) | ||||
|  | ||||
|             post = BlogPost.objects(...).fields(slice__comments=5) | ||||
|         Exclude a specific field: | ||||
|  | ||||
|         :param kwargs: A dictionary identifying what to include | ||||
|             posts = BlogPost.objects(...).fields(comments=0) | ||||
|  | ||||
|         To retrieve a subrange of array elements: | ||||
|  | ||||
|             posts = BlogPost.objects(...).fields(slice__comments=5) | ||||
|  | ||||
|         :param kwargs: A set keywors arguments identifying what to include. | ||||
|  | ||||
|         .. versionadded:: 0.5 | ||||
|         """ | ||||
| @@ -927,7 +933,20 @@ class BaseQuerySet(object): | ||||
|             key = '.'.join(parts) | ||||
|             cleaned_fields.append((key, value)) | ||||
|  | ||||
|         fields = sorted(cleaned_fields, key=operator.itemgetter(1)) | ||||
|         # Sort fields by their values, explicitly excluded fields first, then | ||||
|         # explicitly included, and then more complicated operators such as | ||||
|         # $slice. | ||||
|         def _sort_key(field_tuple): | ||||
|             key, value = field_tuple | ||||
|             if isinstance(value, (int)): | ||||
|                 return value  # 0 for exclusion, 1 for inclusion | ||||
|             else: | ||||
|                 return 2  # so that complex values appear last | ||||
|  | ||||
|         fields = sorted(cleaned_fields, key=_sort_key) | ||||
|  | ||||
|         # Clone the queryset, group all fields by their value, convert | ||||
|         # each of them to db_fields, and set the queryset's _loaded_fields | ||||
|         queryset = self.clone() | ||||
|         for value, group in itertools.groupby(fields, lambda x: x[1]): | ||||
|             fields = [field for field, value in group] | ||||
|   | ||||
| @@ -101,7 +101,20 @@ def query(_doc_cls=None, **kwargs): | ||||
|                         value = value['_id'] | ||||
|  | ||||
|             elif op in ('in', 'nin', 'all', 'near') and not isinstance(value, dict): | ||||
|                 # 'in', 'nin' and 'all' require a list of values | ||||
|                 # Raise an error if the in/nin/all/near param is not iterable. We need a | ||||
|                 # special check for BaseDocument, because - although it's iterable - using | ||||
|                 # it as such in the context of this method is most definitely a mistake. | ||||
|                 BaseDocument = _import_class('BaseDocument') | ||||
|                 if isinstance(value, BaseDocument): | ||||
|                     raise TypeError("When using the `in`, `nin`, `all`, or " | ||||
|                                     "`near`-operators you can\'t use a " | ||||
|                                     "`Document`, you must wrap your object " | ||||
|                                     "in a list (object -> [object]).") | ||||
|                 elif not hasattr(value, '__iter__'): | ||||
|                     raise TypeError("The `in`, `nin`, `all`, or " | ||||
|                                     "`near`-operators must be applied to an " | ||||
|                                     "iterable (e.g. a list).") | ||||
|                 else: | ||||
|                     value = [field.prepare_query_value(op, v) for v in value] | ||||
|  | ||||
|             # If we're querying a GenericReferenceField, we need to alter the | ||||
| @@ -220,7 +233,6 @@ def update(_doc_cls=None, **update): | ||||
|                 # Support decrement by flipping a positive value's sign | ||||
|                 # and using 'inc' | ||||
|                 op = 'inc' | ||||
|                 if value > 0: | ||||
|                 value = -value | ||||
|             elif op == 'add_to_set': | ||||
|                 op = 'addToSet' | ||||
|   | ||||
| @@ -7,5 +7,5 @@ cover-package=mongoengine | ||||
| [flake8] | ||||
| ignore=E501,F401,F403,F405,I201 | ||||
| exclude=build,dist,docs,venv,venv3,.tox,.eggs,tests | ||||
| max-complexity=45 | ||||
| max-complexity=47 | ||||
| application-import-names=mongoengine,tests | ||||
|   | ||||
| @@ -435,6 +435,15 @@ class InstanceTest(unittest.TestCase): | ||||
|  | ||||
|         person.to_dbref() | ||||
|  | ||||
|     def test_save_abstract_document(self): | ||||
|         """Saving an abstract document should fail.""" | ||||
|         class Doc(Document): | ||||
|             name = StringField() | ||||
|             meta = {'abstract': True} | ||||
|  | ||||
|         with self.assertRaises(InvalidDocumentError): | ||||
|             Doc(name='aaa').save() | ||||
|  | ||||
|     def test_reload(self): | ||||
|         """Ensure that attributes may be reloaded. | ||||
|         """ | ||||
| @@ -1223,6 +1232,19 @@ class InstanceTest(unittest.TestCase): | ||||
|         self.assertEqual(person.name, None) | ||||
|         self.assertEqual(person.age, None) | ||||
|  | ||||
|     def test_update_rename_operator(self): | ||||
|         """Test the $rename operator.""" | ||||
|         coll = self.Person._get_collection() | ||||
|         doc = self.Person(name='John').save() | ||||
|         raw_doc = coll.find_one({'_id': doc.pk}) | ||||
|         self.assertEqual(set(raw_doc.keys()), set(['_id', '_cls', 'name'])) | ||||
|  | ||||
|         doc.update(rename__name='first_name') | ||||
|         raw_doc = coll.find_one({'_id': doc.pk}) | ||||
|         self.assertEqual(set(raw_doc.keys()), | ||||
|                          set(['_id', '_cls', 'first_name'])) | ||||
|         self.assertEqual(raw_doc['first_name'], 'John') | ||||
|  | ||||
|     def test_inserts_if_you_set_the_pk(self): | ||||
|         p1 = self.Person(name='p1', id=bson.ObjectId()).save() | ||||
|         p2 = self.Person(name='p2') | ||||
| @@ -1860,6 +1882,10 @@ class InstanceTest(unittest.TestCase): | ||||
|                 'occurs': {"hello": None} | ||||
|             }) | ||||
|  | ||||
|         # Tests for issue #1438: https://github.com/MongoEngine/mongoengine/issues/1438 | ||||
|         with self.assertRaises(ValueError): | ||||
|             Word._from_son('this is not a valid SON dict') | ||||
|  | ||||
|     def test_reverse_delete_rule_cascade_and_nullify(self): | ||||
|         """Ensure that a referenced document is also deleted upon deletion. | ||||
|         """ | ||||
|   | ||||
| @@ -306,6 +306,24 @@ class FieldTest(unittest.TestCase): | ||||
|         person.id = '497ce96f395f2f052a494fd4' | ||||
|         person.validate() | ||||
|  | ||||
|     def test_db_field_validation(self): | ||||
|         """Ensure that db_field doesn't accept invalid values.""" | ||||
|  | ||||
|         # dot in the name | ||||
|         with self.assertRaises(ValueError): | ||||
|             class User(Document): | ||||
|                 name = StringField(db_field='user.name') | ||||
|  | ||||
|         # name starting with $ | ||||
|         with self.assertRaises(ValueError): | ||||
|             class User(Document): | ||||
|                 name = StringField(db_field='$name') | ||||
|  | ||||
|         # name containing a null character | ||||
|         with self.assertRaises(ValueError): | ||||
|             class User(Document): | ||||
|                 name = StringField(db_field='name\0') | ||||
|  | ||||
|     def test_string_validation(self): | ||||
|         """Ensure that invalid values cannot be assigned to string fields. | ||||
|         """ | ||||
| @@ -1042,6 +1060,7 @@ class FieldTest(unittest.TestCase): | ||||
|         self.assertEqual( | ||||
|             BlogPost.objects.filter(info__100__test__exact='test').count(), 0) | ||||
|  | ||||
|         # test queries by list | ||||
|         post = BlogPost() | ||||
|         post.info = ['1', '2'] | ||||
|         post.save() | ||||
| @@ -1053,6 +1072,248 @@ class FieldTest(unittest.TestCase): | ||||
|         post.info *= 2 | ||||
|         post.save() | ||||
|         self.assertEqual(BlogPost.objects(info=['1', '2', '3', '4', '1', '2', '3', '4']).count(), 1) | ||||
|  | ||||
|         BlogPost.drop_collection() | ||||
|  | ||||
|     def test_list_field_manipulative_operators(self): | ||||
|         """Ensure that ListField works with standard list operators that manipulate the list. | ||||
|         """ | ||||
|         class BlogPost(Document): | ||||
|             ref = StringField() | ||||
|             info = ListField(StringField()) | ||||
|  | ||||
|         BlogPost.drop_collection() | ||||
|  | ||||
|         post = BlogPost() | ||||
|         post.ref = "1234" | ||||
|         post.info = ['0', '1', '2', '3', '4', '5'] | ||||
|         post.save() | ||||
|  | ||||
|         def reset_post(): | ||||
|             post.info = ['0', '1', '2', '3', '4', '5'] | ||||
|             post.save() | ||||
|  | ||||
|         # '__add__(listB)' | ||||
|         # listA+listB | ||||
|         # operator.add(listA, listB) | ||||
|         reset_post() | ||||
|         temp = ['a', 'b'] | ||||
|         post.info = post.info + temp | ||||
|         self.assertEqual(post.info, ['0', '1', '2', '3', '4', '5', 'a', 'b']) | ||||
|         post.save() | ||||
|         post.reload() | ||||
|         self.assertEqual(post.info, ['0', '1', '2', '3', '4', '5', 'a', 'b']) | ||||
|  | ||||
|         # '__delitem__(index)' | ||||
|         # aka 'del list[index]' | ||||
|         # aka 'operator.delitem(list, index)' | ||||
|         reset_post() | ||||
|         del post.info[2] # del from middle ('2') | ||||
|         self.assertEqual(post.info, ['0', '1', '3', '4', '5']) | ||||
|         post.save() | ||||
|         post.reload() | ||||
|         self.assertEqual(post.info, ['0', '1', '3', '4', '5']) | ||||
|  | ||||
|         # '__delitem__(slice(i, j))' | ||||
|         # aka 'del list[i:j]' | ||||
|         # aka 'operator.delitem(list, slice(i,j))' | ||||
|         reset_post() | ||||
|         del post.info[1:3] # removes '1', '2' | ||||
|         self.assertEqual(post.info, ['0', '3', '4', '5']) | ||||
|         post.save() | ||||
|         post.reload() | ||||
|         self.assertEqual(post.info, ['0', '3', '4', '5']) | ||||
|  | ||||
|         # '__iadd__' | ||||
|         # aka 'list += list' | ||||
|         reset_post() | ||||
|         temp = ['a', 'b'] | ||||
|         post.info += temp | ||||
|         self.assertEqual(post.info, ['0', '1', '2', '3', '4', '5', 'a', 'b']) | ||||
|         post.save() | ||||
|         post.reload() | ||||
|         self.assertEqual(post.info, ['0', '1', '2', '3', '4', '5', 'a', 'b']) | ||||
|  | ||||
|         # '__imul__' | ||||
|         # aka 'list *= number' | ||||
|         reset_post() | ||||
|         post.info *= 2 | ||||
|         self.assertEqual(post.info, ['0', '1', '2', '3', '4', '5', '0', '1', '2', '3', '4', '5']) | ||||
|         post.save() | ||||
|         post.reload() | ||||
|         self.assertEqual(post.info, ['0', '1', '2', '3', '4', '5', '0', '1', '2', '3', '4', '5']) | ||||
|  | ||||
|         # '__mul__' | ||||
|         # aka 'listA*listB' | ||||
|         reset_post() | ||||
|         post.info = post.info * 2 | ||||
|         self.assertEqual(post.info, ['0', '1', '2', '3', '4', '5', '0', '1', '2', '3', '4', '5']) | ||||
|         post.save() | ||||
|         post.reload() | ||||
|         self.assertEqual(post.info, ['0', '1', '2', '3', '4', '5', '0', '1', '2', '3', '4', '5']) | ||||
|  | ||||
|         # '__rmul__' | ||||
|         # aka 'listB*listA' | ||||
|         reset_post() | ||||
|         post.info = 2 * post.info | ||||
|         self.assertEqual(post.info, ['0', '1', '2', '3', '4', '5', '0', '1', '2', '3', '4', '5']) | ||||
|         post.save() | ||||
|         post.reload() | ||||
|         self.assertEqual(post.info, ['0', '1', '2', '3', '4', '5', '0', '1', '2', '3', '4', '5']) | ||||
|  | ||||
|         # '__setitem__(index, value)' | ||||
|         # aka 'list[index]=value' | ||||
|         # aka 'setitem(list, value)' | ||||
|         reset_post() | ||||
|         post.info[4] = 'a' | ||||
|         self.assertEqual(post.info, ['0', '1', '2', '3', 'a', '5']) | ||||
|         post.save() | ||||
|         post.reload() | ||||
|         self.assertEqual(post.info, ['0', '1', '2', '3', 'a', '5']) | ||||
|  | ||||
|         # '__setitem__(slice(i, j), listB)' | ||||
|         # aka 'listA[i:j] = listB' | ||||
|         # aka 'setitem(listA, slice(i, j), listB)' | ||||
|         reset_post() | ||||
|         post.info[1:3] = ['h', 'e', 'l', 'l', 'o'] | ||||
|         self.assertEqual(post.info, ['0', 'h', 'e', 'l', 'l', 'o', '3', '4', '5']) | ||||
|         post.save() | ||||
|         post.reload() | ||||
|         self.assertEqual(post.info, ['0', 'h', 'e', 'l', 'l', 'o', '3', '4', '5']) | ||||
|  | ||||
|         # 'append' | ||||
|         reset_post() | ||||
|         post.info.append('h') | ||||
|         self.assertEqual(post.info, ['0', '1', '2', '3', '4', '5', 'h']) | ||||
|         post.save() | ||||
|         post.reload() | ||||
|         self.assertEqual(post.info, ['0', '1', '2', '3', '4', '5', 'h']) | ||||
|  | ||||
|         # 'extend' | ||||
|         reset_post() | ||||
|         post.info.extend(['h', 'e', 'l', 'l', 'o']) | ||||
|         self.assertEqual(post.info, ['0', '1', '2', '3', '4', '5', 'h', 'e', 'l', 'l', 'o']) | ||||
|         post.save() | ||||
|         post.reload() | ||||
|         self.assertEqual(post.info, ['0', '1', '2', '3', '4', '5', 'h', 'e', 'l', 'l', 'o']) | ||||
|         # 'insert' | ||||
|  | ||||
|         # 'pop' | ||||
|         reset_post() | ||||
|         x = post.info.pop(2) | ||||
|         y = post.info.pop() | ||||
|         self.assertEqual(post.info, ['0', '1', '3', '4']) | ||||
|         self.assertEqual(x, '2') | ||||
|         self.assertEqual(y, '5') | ||||
|         post.save() | ||||
|         post.reload() | ||||
|         self.assertEqual(post.info, ['0', '1', '3', '4']) | ||||
|  | ||||
|         # 'remove' | ||||
|         reset_post() | ||||
|         post.info.remove('2') | ||||
|         self.assertEqual(post.info, ['0', '1', '3', '4', '5']) | ||||
|         post.save() | ||||
|         post.reload() | ||||
|         self.assertEqual(post.info, ['0', '1', '3', '4', '5']) | ||||
|  | ||||
|         # 'reverse' | ||||
|         reset_post() | ||||
|         post.info.reverse() | ||||
|         self.assertEqual(post.info, ['5', '4', '3', '2', '1', '0']) | ||||
|         post.save() | ||||
|         post.reload() | ||||
|         self.assertEqual(post.info, ['5', '4', '3', '2', '1', '0']) | ||||
|  | ||||
|         # 'sort': though this operator method does manipulate the list, it is tested in | ||||
|         #     the 'test_list_field_lexicograpic_operators' function | ||||
|         BlogPost.drop_collection() | ||||
|  | ||||
|     def test_list_field_invalid_operators(self): | ||||
|         class BlogPost(Document): | ||||
|             ref = StringField() | ||||
|             info = ListField(StringField()) | ||||
|         post = BlogPost() | ||||
|         post.ref = "1234" | ||||
|         post.info = ['0', '1', '2', '3', '4', '5'] | ||||
|         # '__hash__' | ||||
|         # aka 'hash(list)' | ||||
|         # # assert TypeError | ||||
|         self.assertRaises(TypeError, lambda: hash(post.info)) | ||||
|  | ||||
|     def test_list_field_lexicographic_operators(self): | ||||
|         """Ensure that ListField works with standard list operators that do lexigraphic ordering. | ||||
|         """ | ||||
|         class BlogPost(Document): | ||||
|             ref = StringField() | ||||
|             text_info = ListField(StringField()) | ||||
|             oid_info = ListField(ObjectIdField()) | ||||
|             bool_info = ListField(BooleanField()) | ||||
|         BlogPost.drop_collection() | ||||
|  | ||||
|         blogSmall = BlogPost(ref="small") | ||||
|         blogSmall.text_info = ["a", "a", "a"] | ||||
|         blogSmall.bool_info = [False, False] | ||||
|         blogSmall.save() | ||||
|         blogSmall.reload() | ||||
|  | ||||
|         blogLargeA = BlogPost(ref="big") | ||||
|         blogLargeA.text_info = ["a", "z", "j"] | ||||
|         blogLargeA.bool_info = [False, True] | ||||
|         blogLargeA.save() | ||||
|         blogLargeA.reload() | ||||
|  | ||||
|         blogLargeB = BlogPost(ref="big2") | ||||
|         blogLargeB.text_info = ["a", "z", "j"] | ||||
|         blogLargeB.oid_info = [ | ||||
|             "54495ad94c934721ede76f90", | ||||
|             "54495ad94c934721ede76d23", | ||||
|             "54495ad94c934721ede76d00" | ||||
|         ] | ||||
|         blogLargeB.bool_info = [False, True] | ||||
|         blogLargeB.save() | ||||
|         blogLargeB.reload() | ||||
|         # '__eq__' aka '==' | ||||
|         self.assertEqual(blogLargeA.text_info, blogLargeB.text_info) | ||||
|         self.assertEqual(blogLargeA.bool_info, blogLargeB.bool_info) | ||||
|         # '__ge__' aka '>=' | ||||
|         self.assertGreaterEqual(blogLargeA.text_info, blogSmall.text_info) | ||||
|         self.assertGreaterEqual(blogLargeA.text_info, blogLargeB.text_info) | ||||
|         self.assertGreaterEqual(blogLargeA.bool_info, blogSmall.bool_info) | ||||
|         self.assertGreaterEqual(blogLargeA.bool_info, blogLargeB.bool_info) | ||||
|         # '__gt__' aka '>' | ||||
|         self.assertGreaterEqual(blogLargeA.text_info, blogSmall.text_info) | ||||
|         self.assertGreaterEqual(blogLargeA.bool_info, blogSmall.bool_info) | ||||
|         # '__le__' aka '<=' | ||||
|         self.assertLessEqual(blogSmall.text_info, blogLargeB.text_info) | ||||
|         self.assertLessEqual(blogLargeA.text_info, blogLargeB.text_info) | ||||
|         self.assertLessEqual(blogSmall.bool_info, blogLargeB.bool_info) | ||||
|         self.assertLessEqual(blogLargeA.bool_info, blogLargeB.bool_info) | ||||
|         # '__lt__' aka '<' | ||||
|         self.assertLess(blogSmall.text_info, blogLargeB.text_info) | ||||
|         self.assertLess(blogSmall.bool_info, blogLargeB.bool_info) | ||||
|         # '__ne__' aka '!=' | ||||
|         self.assertNotEqual(blogSmall.text_info, blogLargeB.text_info) | ||||
|         self.assertNotEqual(blogSmall.bool_info, blogLargeB.bool_info) | ||||
|         # 'sort' | ||||
|         blogLargeB.bool_info = [True, False, True, False] | ||||
|         blogLargeB.text_info.sort() | ||||
|         blogLargeB.oid_info.sort() | ||||
|         blogLargeB.bool_info.sort() | ||||
|         sorted_target_list = [ | ||||
|             ObjectId("54495ad94c934721ede76d00"), | ||||
|             ObjectId("54495ad94c934721ede76d23"), | ||||
|             ObjectId("54495ad94c934721ede76f90") | ||||
|         ] | ||||
|         self.assertEqual(blogLargeB.text_info, ["a", "j", "z"]) | ||||
|         self.assertEqual(blogLargeB.oid_info, sorted_target_list) | ||||
|         self.assertEqual(blogLargeB.bool_info, [False, False, True, True]) | ||||
|         blogLargeB.save() | ||||
|         blogLargeB.reload() | ||||
|         self.assertEqual(blogLargeB.text_info, ["a", "j", "z"]) | ||||
|         self.assertEqual(blogLargeB.oid_info, sorted_target_list) | ||||
|         self.assertEqual(blogLargeB.bool_info, [False, False, True, True]) | ||||
|  | ||||
|         BlogPost.drop_collection() | ||||
|  | ||||
|     def test_list_assignment(self): | ||||
| @@ -1102,7 +1363,6 @@ class FieldTest(unittest.TestCase): | ||||
|         post.reload() | ||||
|         self.assertEqual(post.info, [1, 2, 3, 4, 'n5']) | ||||
|  | ||||
|  | ||||
|     def test_list_field_passed_in_value(self): | ||||
|         class Foo(Document): | ||||
|             bars = ListField(ReferenceField("Bar")) | ||||
| @@ -3731,30 +3991,25 @@ class FieldTest(unittest.TestCase): | ||||
|         """Tests if a `FieldDoesNotExist` exception is raised when trying to | ||||
|         instanciate a document with a field that's not defined. | ||||
|         """ | ||||
|  | ||||
|         class Doc(Document): | ||||
|             foo = StringField(db_field='f') | ||||
|             foo = StringField() | ||||
|  | ||||
|         def test(): | ||||
|         with self.assertRaises(FieldDoesNotExist): | ||||
|             Doc(bar='test') | ||||
|  | ||||
|         self.assertRaises(FieldDoesNotExist, test) | ||||
|  | ||||
|     def test_undefined_field_exception_with_strict(self): | ||||
|         """Tests if a `FieldDoesNotExist` exception is raised when trying to | ||||
|         instanciate a document with a field that's not defined, | ||||
|         even when strict is set to False. | ||||
|         """ | ||||
|  | ||||
|         class Doc(Document): | ||||
|             foo = StringField(db_field='f') | ||||
|             foo = StringField() | ||||
|             meta = {'strict': False} | ||||
|  | ||||
|         def test(): | ||||
|         with self.assertRaises(FieldDoesNotExist): | ||||
|             Doc(bar='test') | ||||
|  | ||||
|         self.assertRaises(FieldDoesNotExist, test) | ||||
|  | ||||
|     def test_long_field_is_considered_as_int64(self): | ||||
|         """ | ||||
|         Tests that long fields are stored as long in mongo, even if long value | ||||
|   | ||||
| @@ -141,6 +141,16 @@ class OnlyExcludeAllTest(unittest.TestCase): | ||||
|         self.assertEqual(qs._loaded_fields.as_dict(), | ||||
|                          {'b': {'$slice': 5}}) | ||||
|  | ||||
|     def test_mix_slice_with_other_fields(self): | ||||
|         class MyDoc(Document): | ||||
|             a = ListField() | ||||
|             b = ListField() | ||||
|             c = ListField() | ||||
|  | ||||
|         qs = MyDoc.objects.fields(a=1, b=0, slice__c=2) | ||||
|         self.assertEqual(qs._loaded_fields.as_dict(), | ||||
|                          {'c': {'$slice': 2}, 'a': 1}) | ||||
|  | ||||
|     def test_only(self): | ||||
|         """Ensure that QuerySet.only only returns the requested fields. | ||||
|         """ | ||||
|   | ||||
| @@ -1266,7 +1266,7 @@ class QuerySetTest(unittest.TestCase): | ||||
|  | ||||
|     def test_find_embedded(self): | ||||
|         """Ensure that an embedded document is properly returned from | ||||
|         a query. | ||||
|         different manners of querying. | ||||
|         """ | ||||
|         class User(EmbeddedDocument): | ||||
|             name = StringField() | ||||
| @@ -1277,8 +1277,9 @@ class QuerySetTest(unittest.TestCase): | ||||
|  | ||||
|         BlogPost.drop_collection() | ||||
|  | ||||
|         user = User(name='Test User') | ||||
|         BlogPost.objects.create( | ||||
|             author=User(name='Test User'), | ||||
|             author=user, | ||||
|             content='Had a good coffee today...' | ||||
|         ) | ||||
|  | ||||
| @@ -1286,6 +1287,19 @@ class QuerySetTest(unittest.TestCase): | ||||
|         self.assertTrue(isinstance(result.author, User)) | ||||
|         self.assertEqual(result.author.name, 'Test User') | ||||
|  | ||||
|         result = BlogPost.objects.get(author__name=user.name) | ||||
|         self.assertTrue(isinstance(result.author, User)) | ||||
|         self.assertEqual(result.author.name, 'Test User') | ||||
|  | ||||
|         result = BlogPost.objects.get(author={'name': user.name}) | ||||
|         self.assertTrue(isinstance(result.author, User)) | ||||
|         self.assertEqual(result.author.name, 'Test User') | ||||
|  | ||||
|         # Fails, since the string is not a type that is able to represent the | ||||
|         # author's document structure (should be dict) | ||||
|         with self.assertRaises(InvalidQueryError): | ||||
|             BlogPost.objects.get(author=user.name) | ||||
|  | ||||
|     def test_find_empty_embedded(self): | ||||
|         """Ensure that you can save and find an empty embedded document.""" | ||||
|         class User(EmbeddedDocument): | ||||
| @@ -1812,6 +1826,11 @@ class QuerySetTest(unittest.TestCase): | ||||
|         post.reload() | ||||
|         self.assertEqual(post.hits, 10) | ||||
|  | ||||
|         # Negative dec operator is equal to a positive inc operator | ||||
|         BlogPost.objects.update_one(dec__hits=-1) | ||||
|         post.reload() | ||||
|         self.assertEqual(post.hits, 11) | ||||
|  | ||||
|         BlogPost.objects.update(push__tags='mongo') | ||||
|         post.reload() | ||||
|         self.assertTrue('mongo' in post.tags) | ||||
| @@ -4963,6 +4982,35 @@ class QuerySetTest(unittest.TestCase): | ||||
|         self.assertEqual(i, 249) | ||||
|         self.assertEqual(j, 249) | ||||
|  | ||||
|     def test_in_operator_on_non_iterable(self): | ||||
|         """Ensure that using the `__in` operator on a non-iterable raises an | ||||
|         error. | ||||
|         """ | ||||
|         class User(Document): | ||||
|             name = StringField() | ||||
|  | ||||
|         class BlogPost(Document): | ||||
|             content = StringField() | ||||
|             authors = ListField(ReferenceField(User)) | ||||
|  | ||||
|         User.drop_collection() | ||||
|         BlogPost.drop_collection() | ||||
|  | ||||
|         author = User.objects.create(name='Test User') | ||||
|         post = BlogPost.objects.create(content='Had a good coffee today...', | ||||
|                                        authors=[author]) | ||||
|  | ||||
|         # Make sure using `__in` with a list works | ||||
|         blog_posts = BlogPost.objects(authors__in=[author]) | ||||
|         self.assertEqual(list(blog_posts), [post]) | ||||
|  | ||||
|         # Using `__in` with a non-iterable should raise a TypeError | ||||
|         self.assertRaises(TypeError, BlogPost.objects(authors__in=author.pk).count) | ||||
|  | ||||
|         # Using `__in` with a `Document` (which is seemingly iterable but not | ||||
|         # in a way we'd expect) should raise a TypeError, too | ||||
|         self.assertRaises(TypeError, BlogPost.objects(authors__in=author).count) | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|     unittest.main() | ||||
|   | ||||
| @@ -200,6 +200,19 @@ class ConnectionTest(unittest.TestCase): | ||||
|         self.assertTrue(isinstance(db, pymongo.database.Database)) | ||||
|         self.assertEqual(db.name, 'test') | ||||
|  | ||||
|     def test_connect_uri_with_replicaset(self): | ||||
|         """Ensure connect() works when specifying a replicaSet.""" | ||||
|         if IS_PYMONGO_3: | ||||
|             c = connect(host='mongodb://localhost/test?replicaSet=local-rs') | ||||
|             db = get_db() | ||||
|             self.assertTrue(isinstance(db, pymongo.database.Database)) | ||||
|             self.assertEqual(db.name, 'test') | ||||
|         else: | ||||
|             # PyMongo < v3.x raises an exception: | ||||
|             # "localhost:27017 is not a member of replica set local-rs" | ||||
|             with self.assertRaises(MongoEngineConnectionError): | ||||
|                 c = connect(host='mongodb://localhost/test?replicaSet=local-rs') | ||||
|  | ||||
|     def test_uri_without_credentials_doesnt_override_conn_settings(self): | ||||
|         """Ensure connect() uses the username & password params if the URI | ||||
|         doesn't explicitly specify them. | ||||
| @@ -283,6 +296,19 @@ class ConnectionTest(unittest.TestCase): | ||||
|         conn = get_connection('t2') | ||||
|         self.assertFalse(get_tz_awareness(conn)) | ||||
|  | ||||
|     def test_write_concern(self): | ||||
|         """Ensure write concern can be specified in connect() via | ||||
|         a kwarg or as part of the connection URI. | ||||
|         """ | ||||
|         conn1 = connect(alias='conn1', host='mongodb://localhost/testing?w=1&j=true') | ||||
|         conn2 = connect('testing', alias='conn2', w=1, j=True) | ||||
|         if IS_PYMONGO_3: | ||||
|             self.assertEqual(conn1.write_concern.document, {'w': 1, 'j': True}) | ||||
|             self.assertEqual(conn2.write_concern.document, {'w': 1, 'j': True}) | ||||
|         else: | ||||
|             self.assertEqual(dict(conn1.write_concern), {'w': 1, 'j': True}) | ||||
|             self.assertEqual(dict(conn2.write_concern), {'w': 1, 'j': True}) | ||||
|  | ||||
|     def test_datetime(self): | ||||
|         connect('mongoenginetest', tz_aware=True) | ||||
|         d = datetime.datetime(2010, 5, 5, tzinfo=utc) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user