Compare commits
	
		
			138 Commits
		
	
	
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|  | c6c5f85abb | ||
|  | 7b860f7739 | ||
|  | e28804c03a | ||
|  | 1b9432824b | ||
|  | 25e0f12976 | ||
|  | f168682a68 | ||
|  | d25058a46d | ||
|  | 4d0c092d9f | ||
|  | 15714ef855 | ||
|  | eb743beaa3 | ||
|  | 0007535a46 | ||
|  | 8391af026c | ||
|  | 800f656dcf | ||
|  | 088c5f49d9 | ||
|  | d8d98b6143 | ||
|  | 02fb3b9315 | ||
|  | 4f87db784e | ||
|  | 7e6287b925 | ||
|  | 999cdfd997 | ||
|  | 8d6cb087c6 | ||
|  | 2b7417c728 | ||
|  | 3c455cf1c1 | ||
|  | 5135185e31 | ||
|  | b461f26e5d | ||
|  | faef5b8570 | ||
|  | 0a20e04c10 | ||
|  | d19bb2308d | ||
|  | d8dd07d9ef | ||
|  | 36c56243cd | ||
|  | 23d06b79a6 | ||
|  | e4c4e923ee | ||
|  | 936d2f1f47 | ||
|  | 07018b5060 | ||
|  | ac90d6ae5c | ||
|  | 2141f2c4c5 | ||
|  | 81870777a9 | ||
|  | 845092dcad | ||
|  | dd473d1e1e | ||
|  | d2869bf4ed | ||
|  | 891a3f4b29 | ||
|  | 6767b50d75 | ||
|  | d9e4b562a9 | ||
|  | fb3243f1bc | ||
|  | 5fe1497c92 | ||
|  | 5446592d44 | ||
|  | 40ed9a53c9 | ||
|  | f7ac8cea90 | ||
|  | 4ef5d1f0cd | ||
|  | 6992615c98 | ||
|  | 43dabb2825 | ||
|  | 05e40e5681 | ||
|  | 2c4536e137 | ||
|  | 3dc81058a0 | ||
|  | bd84667a2b | ||
|  | e5b6a12977 | ||
|  | ca415d5d62 | ||
|  | 99b4fe7278 | ||
|  | 327e164869 | ||
|  | 25bc571f30 | ||
|  | 38c7e8a1d2 | ||
|  | ca282e28e0 | ||
|  | 5ef59c06df | ||
|  | 8f55d385d6 | ||
|  | cd2fc25c19 | ||
|  | 709983eea6 | ||
|  | 40e99b1b80 | ||
|  | 488684d960 | ||
|  | f35034b989 | ||
|  | 9d6f9b1f26 | ||
|  | 6148a608fb | ||
|  | 3fa9e70383 | ||
|  | 16fea6f009 | ||
|  | df9ed835ca | ||
|  | e394c8f0f2 | ||
|  | 21974f7288 | ||
|  | 5ef0170d77 | ||
|  | c21dcf14de | ||
|  | a8d20d4e1e | ||
|  | 8b307485b0 | ||
|  | 4544afe422 | ||
|  | 9d7eba5f70 | ||
|  | be0aee95f2 | ||
|  | 3469ed7ab9 | ||
|  | 1f223aa7e6 | ||
|  | 0a431ead5e | ||
|  | f750796444 | ||
|  | c82bcd882a | ||
|  | 7d0ec33b54 | ||
|  | 43d48b3feb | ||
|  | 2e406d2687 | ||
|  | 3f30808104 | ||
|  | ab10217c86 | ||
|  | 00430491ca | ||
|  | 109202329f | ||
|  | 3b1509f307 | ||
|  | 7ad7b08bed | ||
|  | 4650e5e8fb | ||
|  | af59d4929e | ||
|  | e34100bab4 | ||
|  | d9b3a9fb60 | ||
|  | 39eec59c90 | ||
|  | d651d0d472 | ||
|  | 87a2358a65 | ||
|  | cef4e313e1 | ||
|  | 7cc1a4eba0 | ||
|  | c6cc0133b3 | ||
|  | 7748e68440 | ||
|  | 6c2230a076 | ||
|  | 66b233eaea | ||
|  | fed58f3920 | ||
|  | 815b2be7f7 | ||
|  | f420c9fb7c | ||
|  | 01bdf10b94 | ||
|  | ddedc1ee92 | ||
|  | 9e9703183f | ||
|  | adce9e6220 | ||
|  | c499133bbe | ||
|  | 8f505c2dcc | ||
|  | b320064418 | ||
|  | a643933d16 | ||
|  | 2659ec5887 | ||
|  | 9f8327926d | ||
|  | 7a568dc118 | ||
|  | c946b06be5 | ||
|  | c65fd0e477 | ||
|  | 8f8217e928 | ||
|  | 6c9e1799c7 | ||
|  | decd70eb23 | ||
|  | feb5eed8a5 | ||
|  | acc7448dc5 | ||
|  | 35d3d3de72 | ||
|  | 9c264611cf | ||
|  | 8e7c5af16c | ||
|  | c1645ab7a7 | ||
|  | 2ae2bfdde9 | ||
|  | 3fe93968a6 | ||
|  | eb8176971c | ||
|  | 5bbfca45fa | 
							
								
								
									
										4
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										4
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -14,4 +14,6 @@ env/ | ||||
| .project | ||||
| .pydevproject | ||||
| tests/test_bugfix.py | ||||
| htmlcov/ | ||||
| htmlcov/ | ||||
| venv | ||||
| venv3 | ||||
|   | ||||
							
								
								
									
										24
									
								
								.travis.yml
									
									
									
									
									
								
							
							
						
						
									
										24
									
								
								.travis.yml
									
									
									
									
									
								
							| @@ -1,42 +1,58 @@ | ||||
| language: python | ||||
|  | ||||
| python: | ||||
| - '2.6' | ||||
| - '2.6'  # TODO remove in v0.11.0 | ||||
| - '2.7' | ||||
| - '3.2' | ||||
| - '3.3' | ||||
| - '3.4' | ||||
| - '3.5' | ||||
| - pypy | ||||
| - pypy3 | ||||
|  | ||||
| env: | ||||
| - PYMONGO=2.7 | ||||
| - PYMONGO=2.8 | ||||
| - PYMONGO=3.0 | ||||
| - PYMONGO=dev | ||||
|  | ||||
| matrix: | ||||
|   fast_finish: true | ||||
|  | ||||
| before_install: | ||||
| - travis_retry sudo apt-key adv --keyserver hkp://keyserver.ubuntu.com:80 --recv 7F0CEB10 | ||||
| - echo 'deb http://downloads-distro.mongodb.org/repo/ubuntu-upstart dist 10gen' | | ||||
|   sudo tee /etc/apt/sources.list.d/mongodb.list | ||||
| - travis_retry sudo apt-get update | ||||
| - travis_retry sudo apt-get install mongodb-org-server | ||||
|  | ||||
| install: | ||||
| - sudo apt-get install python-dev python3-dev libopenjpeg-dev zlib1g-dev libjpeg-turbo8-dev | ||||
|   libtiff4-dev libjpeg8-dev libfreetype6-dev liblcms2-dev libwebp-dev tcl8.5-dev tk8.5-dev | ||||
|   python-tk | ||||
| # virtualenv>=14.0.0 has dropped Python 3.2 support | ||||
| - travis_retry pip install "virtualenv<14.0.0" "tox>=1.9" coveralls | ||||
| - travis_retry pip install --upgrade pip | ||||
| - travis_retry pip install coveralls | ||||
| - travis_retry pip install flake8 | ||||
| - travis_retry pip install tox>=1.9 | ||||
| - travis_retry pip install "virtualenv<14.0.0"  # virtualenv>=14.0.0 has dropped Python 3.2 support (and pypy3 is based on py32) | ||||
| - travis_retry tox -e $(echo py$TRAVIS_PYTHON_VERSION-mg$PYMONGO | tr -d . | sed -e 's/pypypy/pypy/') -- -e test | ||||
|  | ||||
| # Run flake8 for py27 | ||||
| before_script: | ||||
| - if [[ $TRAVIS_PYTHON_VERSION == '2.7' ]]; then tox -e flake8; fi | ||||
|  | ||||
| script: | ||||
| - tox -e $(echo py$TRAVIS_PYTHON_VERSION-mg$PYMONGO | tr -d . | sed -e 's/pypypy/pypy/') -- --with-coverage | ||||
|  | ||||
| after_script: coveralls --verbose | ||||
|  | ||||
| notifications: | ||||
|   irc: irc.freenode.org#mongoengine | ||||
|  | ||||
| branches: | ||||
|   only: | ||||
|   - master | ||||
|   - /^v.*$/ | ||||
|  | ||||
| deploy: | ||||
|   provider: pypi | ||||
|   user: the_drow | ||||
|   | ||||
							
								
								
									
										8
									
								
								AUTHORS
									
									
									
									
									
								
							
							
						
						
									
										8
									
								
								AUTHORS
									
									
									
									
									
								
							| @@ -228,9 +228,17 @@ that much better: | ||||
|  * Vicki Donchenko (https://github.com/kivistein) | ||||
|  * Emile Caron (https://github.com/emilecaron) | ||||
|  * Amit Lichtenberg (https://github.com/amitlicht) | ||||
|  * Gang Li (https://github.com/iici-gli) | ||||
|  * Lars Butler (https://github.com/larsbutler) | ||||
|  * George Macon (https://github.com/gmacon) | ||||
|  * Ashley Whetter (https://github.com/AWhetter) | ||||
|  * Paul-Armand Verhaegen (https://github.com/paularmand) | ||||
|  * Steven Rossiter (https://github.com/BeardedSteve) | ||||
|  * Luo Peng (https://github.com/RussellLuo) | ||||
|  * Bryan Bennett (https://github.com/bbenne10) | ||||
|  * Gilb's Gilb's (https://github.com/gilbsgilbs) | ||||
|  * Joshua Nedrud (https://github.com/Neurostack) | ||||
|  * Shu Shen (https://github.com/shushen) | ||||
|  * xiaost7 (https://github.com/xiaost7) | ||||
|  * Victor Varvaryuk | ||||
|  * Stanislav Kaledin (https://github.com/sallyruthstruik) | ||||
|   | ||||
							
								
								
									
										28
									
								
								README.rst
									
									
									
									
									
								
							
							
						
						
									
										28
									
								
								README.rst
									
									
									
									
									
								
							| @@ -6,23 +6,23 @@ MongoEngine | ||||
| :Author: Harry Marr (http://github.com/hmarr) | ||||
| :Maintainer: Ross Lawley (http://github.com/rozza) | ||||
|  | ||||
| .. image:: https://secure.travis-ci.org/MongoEngine/mongoengine.png?branch=master | ||||
|   :target: http://travis-ci.org/MongoEngine/mongoengine | ||||
| .. image:: https://travis-ci.org/MongoEngine/mongoengine.svg?branch=master | ||||
|   :target: https://travis-ci.org/MongoEngine/mongoengine | ||||
|  | ||||
| .. image:: https://coveralls.io/repos/MongoEngine/mongoengine/badge.png?branch=master | ||||
|   :target: https://coveralls.io/r/MongoEngine/mongoengine?branch=master | ||||
| .. image:: https://coveralls.io/repos/github/MongoEngine/mongoengine/badge.svg?branch=master | ||||
|   :target: https://coveralls.io/github/MongoEngine/mongoengine?branch=master | ||||
|  | ||||
| .. image:: https://landscape.io/github/MongoEngine/mongoengine/master/landscape.png | ||||
|    :target: https://landscape.io/github/MongoEngine/mongoengine/master | ||||
|    :alt: Code Health | ||||
| .. image:: https://landscape.io/github/MongoEngine/mongoengine/master/landscape.svg?style=flat | ||||
|   :target: https://landscape.io/github/MongoEngine/mongoengine/master | ||||
|   :alt: Code Health | ||||
|  | ||||
| About | ||||
| ===== | ||||
| MongoEngine is a Python Object-Document Mapper for working with MongoDB. | ||||
| Documentation available at http://mongoengine-odm.rtfd.org - there is currently | ||||
| a `tutorial <http://readthedocs.org/docs/mongoengine-odm/en/latest/tutorial.html>`_, a `user guide | ||||
| <https://mongoengine-odm.readthedocs.org/en/latest/guide/index.html>`_ and an `API reference | ||||
| <http://readthedocs.org/docs/mongoengine-odm/en/latest/apireference.html>`_. | ||||
| Documentation available at https://mongoengine-odm.readthedocs.io - there is currently | ||||
| a `tutorial <https://mongoengine-odm.readthedocs.io/tutorial.html>`_, a `user guide | ||||
| <https://mongoengine-odm.readthedocs.io/guide/index.html>`_ and an `API reference | ||||
| <https://mongoengine-odm.readthedocs.io/apireference.html>`_. | ||||
|  | ||||
| Installation | ||||
| ============ | ||||
| @@ -52,10 +52,14 @@ Some simple examples of what MongoEngine code looks like: | ||||
|  | ||||
| .. code :: python | ||||
|  | ||||
|     from mongoengine import * | ||||
|     connect('mydb') | ||||
|  | ||||
|     class BlogPost(Document): | ||||
|         title = StringField(required=True, max_length=200) | ||||
|         posted = DateTimeField(default=datetime.datetime.now) | ||||
|         tags = ListField(StringField(max_length=50)) | ||||
|         meta = {'allow_inheritance': True} | ||||
|  | ||||
|     class TextPost(BlogPost): | ||||
|         content = StringField(required=True) | ||||
| @@ -99,7 +103,7 @@ Some simple examples of what MongoEngine code looks like: | ||||
| Tests | ||||
| ===== | ||||
| To run the test suite, ensure you are running a local instance of MongoDB on | ||||
| the standard port, and run: ``python setup.py nosetests``. | ||||
| the standard port and have ``nose`` installed. Then, run: ``python setup.py nosetests``. | ||||
|  | ||||
| To run the test suite on every supported Python version and every supported PyMongo version, | ||||
| you can use ``tox``. | ||||
|   | ||||
| @@ -2,11 +2,51 @@ | ||||
| Changelog | ||||
| ========= | ||||
|  | ||||
| Changes in 0.10.8 | ||||
| ================= | ||||
| - Added support for QuerySet.batch_size (#1426) | ||||
| - Fixed query set iteration within iteration #1427 | ||||
| - Fixed an issue where specifying a MongoDB URI host would override more information than it should #1421 | ||||
| - Added ability to filter the generic reference field by ObjectId and DBRef #1425 | ||||
| - Fixed delete cascade for models with a custom primary key field #1247 | ||||
| - Added ability to specify an authentication mechanism (e.g. X.509) #1333 | ||||
| - Added support for falsey primary keys (e.g. doc.pk = 0) #1354 | ||||
| - Fixed QuerySet#sum/average for fields w/ explicit db_field #1417 | ||||
| - Fixed filtering by embedded_doc=None #1422 | ||||
| - Added support for cursor.comment #1420 | ||||
| - Fixed doc.get_<field>_display #1419 | ||||
| - Fixed __repr__ method of the StrictDict #1424 | ||||
| - Added a deprecation warning for Python 2.6 | ||||
|  | ||||
| Changes in 0.10.7 | ||||
| ================= | ||||
| - Dropped Python 3.2 support #1390 | ||||
| - Fixed the bug where dynamic doc has index inside a dict field #1278 | ||||
| - Fixed: ListField minus index assignment does not work #1128 | ||||
| - Fixed cascade delete mixing among collections #1224 | ||||
| - Add `signal_kwargs` argument to `Document.save`, `Document.delete` and `BaseQuerySet.insert` to be passed to signals calls #1206 | ||||
| - Raise `OperationError` when trying to do a `drop_collection` on document with no collection set. | ||||
| - count on ListField of EmbeddedDocumentField fails. #1187 | ||||
| - Fixed long fields stored as int32 in Python 3. #1253 | ||||
| - MapField now handles unicodes keys correctly. #1267 | ||||
| - ListField now handles negative indicies correctly. #1270 | ||||
| - Fixed AttributeError when initializing EmbeddedDocument with positional args. #681 | ||||
| - Fixed no_cursor_timeout error with pymongo 3.0+ #1304 | ||||
| - Replaced map-reduce based QuerySet.sum/average with aggregation-based implementations #1336 | ||||
| - Fixed support for `__` to escape field names that match operators names in `update` #1351 | ||||
| - Fixed BaseDocument#_mark_as_changed #1369 | ||||
| - Added support for pickling QuerySet instances. #1397 | ||||
| - Fixed connecting to a list of hosts #1389 | ||||
| - Fixed a bug where accessing broken references wouldn't raise a DoesNotExist error #1334 | ||||
| - Fixed not being able to specify use_db_field=False on ListField(EmbeddedDocumentField) instances #1218 | ||||
| - Improvements to the dictionary fields docs #1383 | ||||
|  | ||||
| Changes in 0.10.6 | ||||
| ================= | ||||
| - Add support for mocking MongoEngine based on mongomock. #1151 | ||||
| - Fixed not being able to run tests on Windows. #1153 | ||||
| - Allow creation of sparse compound indexes. #1114 | ||||
| - count on ListField of EmbeddedDocumentField fails. #1187 | ||||
|  | ||||
| Changes in 0.10.5 | ||||
| ================= | ||||
| @@ -35,6 +75,8 @@ Changes in 0.10.1 | ||||
| - Document save's save_condition error raises `SaveConditionError` exception #1070 | ||||
| - Fix Document.reload for DynamicDocument. #1050 | ||||
| - StrictDict & SemiStrictDict are shadowed at init time. #1105 | ||||
| - Fix ListField minus index assignment does not work. #1119 | ||||
| - Remove code that marks field as changed when the field has default but not existed in database #1126 | ||||
| - Remove test dependencies (nose and rednose) from install dependencies list. #1079 | ||||
| - Recursively build query when using elemMatch operator. #1130 | ||||
| - Fix instance back references for lists of embedded documents. #1131 | ||||
|   | ||||
| @@ -29,7 +29,7 @@ documents are serialized based on their field order. | ||||
|  | ||||
| Dynamic document schemas | ||||
| ======================== | ||||
| One of the benefits of MongoDb is dynamic schemas for a collection, whilst data | ||||
| One of the benefits of MongoDB is dynamic schemas for a collection, whilst data | ||||
| should be planned and organised (after all explicit is better than implicit!) | ||||
| there are scenarios where having dynamic / expando style documents is desirable. | ||||
|  | ||||
| @@ -75,6 +75,7 @@ are as follows: | ||||
| * :class:`~mongoengine.fields.DynamicField` | ||||
| * :class:`~mongoengine.fields.EmailField` | ||||
| * :class:`~mongoengine.fields.EmbeddedDocumentField` | ||||
| * :class:`~mongoengine.fields.EmbeddedDocumentListField` | ||||
| * :class:`~mongoengine.fields.FileField` | ||||
| * :class:`~mongoengine.fields.FloatField` | ||||
| * :class:`~mongoengine.fields.GenericEmbeddedDocumentField` | ||||
| @@ -213,9 +214,9 @@ document class as the first argument:: | ||||
|  | ||||
| Dictionary Fields | ||||
| ----------------- | ||||
| Often, an embedded document may be used instead of a dictionary -- generally | ||||
| this is recommended as dictionaries don't support validation or custom field | ||||
| types. However, sometimes you will not know the structure of what you want to | ||||
| Often, an embedded document may be used instead of a dictionary – generally  | ||||
| embedded documents are recommended as dictionaries don’t support validation  | ||||
| or custom field types. However, sometimes you will not know the structure of what you want to | ||||
| store; in this situation a :class:`~mongoengine.fields.DictField` is appropriate:: | ||||
|  | ||||
|     class SurveyResponse(Document): | ||||
|   | ||||
| @@ -237,7 +237,7 @@ is preferred for achieving this:: | ||||
|     # All except for the first 5 people | ||||
|     users = User.objects[5:] | ||||
|  | ||||
|     # 5 users, starting from the 10th user found | ||||
|     # 5 users, starting from the 11th user found | ||||
|     users = User.objects[10:15] | ||||
|  | ||||
| You may also index the query to retrieve a single result. If an item at that | ||||
|   | ||||
| @@ -1,20 +1,20 @@ | ||||
| import document | ||||
| from document import * | ||||
| import fields | ||||
| from fields import * | ||||
| import connection | ||||
| from connection import * | ||||
| import document | ||||
| from document import * | ||||
| import errors | ||||
| from errors import * | ||||
| import fields | ||||
| from fields import * | ||||
| import queryset | ||||
| from queryset import * | ||||
| import signals | ||||
| from signals import * | ||||
| from errors import * | ||||
| import errors | ||||
|  | ||||
| __all__ = (list(document.__all__) + fields.__all__ + connection.__all__ + | ||||
|            list(queryset.__all__) + signals.__all__ + list(errors.__all__)) | ||||
|  | ||||
| VERSION = (0, 10, 6) | ||||
| VERSION = (0, 10, 7) | ||||
|  | ||||
|  | ||||
| def get_version(): | ||||
| @@ -22,4 +22,5 @@ def get_version(): | ||||
|         return '.'.join(map(str, VERSION[:-1])) + VERSION[-1] | ||||
|     return '.'.join(map(str, VERSION)) | ||||
|  | ||||
|  | ||||
| __version__ = get_version() | ||||
|   | ||||
| @@ -1,5 +1,5 @@ | ||||
| import weakref | ||||
| import itertools | ||||
| import weakref | ||||
|  | ||||
| from mongoengine.common import _import_class | ||||
| from mongoengine.errors import DoesNotExist, MultipleObjectsReturned | ||||
| @@ -199,7 +199,9 @@ class BaseList(list): | ||||
|     def _mark_as_changed(self, key=None): | ||||
|         if hasattr(self._instance, '_mark_as_changed'): | ||||
|             if key: | ||||
|                 self._instance._mark_as_changed('%s.%s' % (self._name, key)) | ||||
|                 self._instance._mark_as_changed( | ||||
|                     '%s.%s' % (self._name, key % len(self)) | ||||
|                 ) | ||||
|             else: | ||||
|                 self._instance._mark_as_changed(self._name) | ||||
|  | ||||
| @@ -210,7 +212,7 @@ class EmbeddedDocumentList(BaseList): | ||||
|     def __match_all(cls, i, kwargs): | ||||
|         items = kwargs.items() | ||||
|         return all([ | ||||
|             getattr(i, k) == v or str(getattr(i, k)) == v for k, v in items | ||||
|             getattr(i, k) == v or unicode(getattr(i, k)) == v for k, v in items | ||||
|         ]) | ||||
|  | ||||
|     @classmethod | ||||
| @@ -436,7 +438,7 @@ class StrictDict(object): | ||||
|                 __slots__ = allowed_keys_tuple | ||||
|  | ||||
|                 def __repr__(self): | ||||
|                     return "{%s}" % ', '.join('"{0!s}": {0!r}'.format(k) for k in self.iterkeys()) | ||||
|                     return "{%s}" % ', '.join('"{0!s}": {1!r}'.format(k, v) for k, v in self.items()) | ||||
|  | ||||
|             cls._classes[allowed_keys] = SpecificStrictDict | ||||
|         return cls._classes[allowed_keys] | ||||
|   | ||||
| @@ -1,28 +1,28 @@ | ||||
| import copy | ||||
| import operator | ||||
| import numbers | ||||
| import operator | ||||
| from collections import Hashable | ||||
| from functools import partial | ||||
|  | ||||
| import pymongo | ||||
| from bson import json_util, ObjectId | ||||
| from bson import ObjectId, json_util | ||||
| from bson.dbref import DBRef | ||||
| from bson.son import SON | ||||
| import pymongo | ||||
|  | ||||
| from mongoengine import signals | ||||
| from mongoengine.common import _import_class | ||||
| from mongoengine.errors import (ValidationError, InvalidDocumentError, | ||||
|                                 LookUpError, FieldDoesNotExist) | ||||
| from mongoengine.python_support import PY3, txt_type | ||||
| from mongoengine.base.common import get_document, ALLOW_INHERITANCE | ||||
| from mongoengine.base.common import ALLOW_INHERITANCE, get_document | ||||
| from mongoengine.base.datastructures import ( | ||||
|     BaseDict, | ||||
|     BaseList, | ||||
|     EmbeddedDocumentList, | ||||
|     StrictDict, | ||||
|     SemiStrictDict | ||||
|     SemiStrictDict, | ||||
|     StrictDict | ||||
| ) | ||||
| from mongoengine.base.fields import ComplexBaseField | ||||
| from mongoengine.common import _import_class | ||||
| from mongoengine.errors import (FieldDoesNotExist, InvalidDocumentError, | ||||
|                                 LookUpError, ValidationError) | ||||
| from mongoengine.python_support import PY3, txt_type | ||||
|  | ||||
| __all__ = ('BaseDocument', 'NON_FIELD_ERRORS') | ||||
|  | ||||
| @@ -51,7 +51,7 @@ class BaseDocument(object): | ||||
|             # We only want named arguments. | ||||
|             field = iter(self._fields_ordered) | ||||
|             # If its an automatic id field then skip to the first defined field | ||||
|             if self._auto_id_field: | ||||
|             if getattr(self, '_auto_id_field', False): | ||||
|                 next(field) | ||||
|             for value in args: | ||||
|                 name = next(field) | ||||
| @@ -72,12 +72,13 @@ class BaseDocument(object): | ||||
|         # Check if there are undefined fields supplied to the constructor, | ||||
|         # if so raise an Exception. | ||||
|         if not self._dynamic and (self._meta.get('strict', True) or _created): | ||||
|             for var in values.keys(): | ||||
|                 if var not in self._fields.keys() + ['id', 'pk', '_cls', '_text_score']: | ||||
|                     msg = ( | ||||
|                         "The field '{0}' does not exist on the document '{1}'" | ||||
|                     ).format(var, self._class_name) | ||||
|                     raise FieldDoesNotExist(msg) | ||||
|             _undefined_fields = set(values.keys()) - set( | ||||
|                 self._fields.keys() + ['id', 'pk', '_cls', '_text_score']) | ||||
|             if _undefined_fields: | ||||
|                 msg = ( | ||||
|                     "The fields '{0}' do not exist on the document '{1}'" | ||||
|                 ).format(_undefined_fields, self._class_name) | ||||
|                 raise FieldDoesNotExist(msg) | ||||
|  | ||||
|         if self.STRICT and not self._dynamic: | ||||
|             self._data = StrictDict.create(allowed_keys=self._fields_ordered)() | ||||
| @@ -120,7 +121,7 @@ class BaseDocument(object): | ||||
|                 else: | ||||
|                     self._data[key] = value | ||||
|  | ||||
|         # Set any get_fieldname_display methods | ||||
|         # Set any get_<field>_display methods | ||||
|         self.__set_field_display() | ||||
|  | ||||
|         if self._dynamic: | ||||
| @@ -309,7 +310,7 @@ class BaseDocument(object): | ||||
|         data = SON() | ||||
|         data["_id"] = None | ||||
|         data['_cls'] = self._class_name | ||||
|         EmbeddedDocumentField = _import_class("EmbeddedDocumentField") | ||||
|  | ||||
|         # only root fields ['test1.a', 'test2'] => ['test1', 'test2'] | ||||
|         root_fields = set([f.split('.')[0] for f in fields]) | ||||
|  | ||||
| @@ -324,21 +325,20 @@ class BaseDocument(object): | ||||
|                 field = self._dynamic_fields.get(field_name) | ||||
|  | ||||
|             if value is not None: | ||||
|                 f_inputs = field.to_mongo.__code__.co_varnames | ||||
|                 ex_vars = {} | ||||
|                 if fields and 'fields' in f_inputs: | ||||
|                     key = '%s.' % field_name | ||||
|                     embedded_fields = [ | ||||
|                         i.replace(key, '') for i in fields | ||||
|                         if i.startswith(key)] | ||||
|  | ||||
|                 if isinstance(field, EmbeddedDocumentField): | ||||
|                     if fields: | ||||
|                         key = '%s.' % field_name | ||||
|                         embedded_fields = [ | ||||
|                             i.replace(key, '') for i in fields | ||||
|                             if i.startswith(key)] | ||||
|                     ex_vars['fields'] = embedded_fields | ||||
|  | ||||
|                     else: | ||||
|                         embedded_fields = [] | ||||
|                 if 'use_db_field' in f_inputs: | ||||
|                     ex_vars['use_db_field'] = use_db_field | ||||
|  | ||||
|                     value = field.to_mongo(value, use_db_field=use_db_field, | ||||
|                                            fields=embedded_fields) | ||||
|                 else: | ||||
|                     value = field.to_mongo(value) | ||||
|                 value = field.to_mongo(value, **ex_vars) | ||||
|  | ||||
|             # Handle self generating fields | ||||
|             if value is None and field._auto_gen: | ||||
| @@ -491,7 +491,7 @@ class BaseDocument(object): | ||||
|                 # remove lower level changed fields | ||||
|                 level = '.'.join(levels[:idx]) + '.' | ||||
|                 remove = self._changed_fields.remove | ||||
|                 for field in self._changed_fields: | ||||
|                 for field in self._changed_fields[:]: | ||||
|                     if field.startswith(level): | ||||
|                         remove(field) | ||||
|  | ||||
| @@ -566,8 +566,10 @@ class BaseDocument(object): | ||||
|                     continue | ||||
|             if isinstance(field, ReferenceField): | ||||
|                 continue | ||||
|             elif (isinstance(data, (EmbeddedDocument, DynamicEmbeddedDocument)) | ||||
|                   and db_field_name not in changed_fields): | ||||
|             elif ( | ||||
|                 isinstance(data, (EmbeddedDocument, DynamicEmbeddedDocument)) and | ||||
|                 db_field_name not in changed_fields | ||||
|             ): | ||||
|                 # Find all embedded fields that have been changed | ||||
|                 changed = data._get_changed_fields(inspected) | ||||
|                 changed_fields += ["%s%s" % (key, k) for k in changed if k] | ||||
| @@ -606,7 +608,9 @@ class BaseDocument(object): | ||||
|                 for p in parts: | ||||
|                     if isinstance(d, (ObjectId, DBRef)): | ||||
|                         break | ||||
|                     elif isinstance(d, list) and p.isdigit(): | ||||
|                     elif isinstance(d, list) and p.lstrip('-').isdigit(): | ||||
|                         if p[0] == '-': | ||||
|                             p = str(len(d) + int(p)) | ||||
|                         try: | ||||
|                             d = d[int(p)] | ||||
|                         except IndexError: | ||||
| @@ -640,7 +644,9 @@ class BaseDocument(object): | ||||
|                 parts = path.split('.') | ||||
|                 db_field_name = parts.pop() | ||||
|                 for p in parts: | ||||
|                     if isinstance(d, list) and p.isdigit(): | ||||
|                     if isinstance(d, list) and p.lstrip('-').isdigit(): | ||||
|                         if p[0] == '-': | ||||
|                             p = str(len(d) + int(p)) | ||||
|                         d = d[int(p)] | ||||
|                     elif (hasattr(d, '__getattribute__') and | ||||
|                           not isinstance(d, dict)): | ||||
| @@ -708,14 +714,6 @@ class BaseDocument(object): | ||||
|                         del data[field.db_field] | ||||
|                 except (AttributeError, ValueError), e: | ||||
|                     errors_dict[field_name] = e | ||||
|             elif field.default: | ||||
|                 default = field.default | ||||
|                 if callable(default): | ||||
|                     default = default() | ||||
|                 if isinstance(default, BaseDocument): | ||||
|                     changed_fields.append(field_name) | ||||
|                 elif not only_fields or field_name in only_fields: | ||||
|                     changed_fields.append(field_name) | ||||
|  | ||||
|         if errors_dict: | ||||
|             errors = "\n".join(["%s - %s" % (k, v) | ||||
| @@ -779,8 +777,12 @@ class BaseDocument(object): | ||||
|         # Check to see if we need to include _cls | ||||
|         allow_inheritance = cls._meta.get('allow_inheritance', | ||||
|                                           ALLOW_INHERITANCE) | ||||
|         include_cls = (allow_inheritance and not spec.get('sparse', False) and | ||||
|                        spec.get('cls',  True) and '_cls' not in spec['fields']) | ||||
|         include_cls = ( | ||||
|             allow_inheritance and | ||||
|             not spec.get('sparse', False) and | ||||
|             spec.get('cls', True) and | ||||
|             '_cls' not in spec['fields'] | ||||
|         ) | ||||
|  | ||||
|         # 733: don't include cls if index_cls is False unless there is an explicit cls with the index | ||||
|         include_cls = include_cls and (spec.get('cls', False) or cls._meta.get('index_cls', True)) | ||||
| @@ -974,7 +976,7 @@ class BaseDocument(object): | ||||
|                 if hasattr(getattr(field, 'field', None), 'lookup_member'): | ||||
|                     new_field = field.field.lookup_member(field_name) | ||||
|                 elif cls._dynamic and (isinstance(field, DynamicField) or | ||||
|                                        getattr(getattr(field, 'document_type'), '_dynamic')): | ||||
|                                        getattr(getattr(field, 'document_type', None), '_dynamic', None)): | ||||
|                     new_field = DynamicField(db_field=field_name) | ||||
|                 else: | ||||
|                     # Look up subfield on the previous field or raise | ||||
| @@ -1003,19 +1005,18 @@ class BaseDocument(object): | ||||
|         return '.'.join(parts) | ||||
|  | ||||
|     def __set_field_display(self): | ||||
|         """Dynamically set the display value for a field with choices""" | ||||
|         for attr_name, field in self._fields.items(): | ||||
|             if field.choices: | ||||
|                 if self._dynamic: | ||||
|                     obj = self | ||||
|                 else: | ||||
|                     obj = type(self) | ||||
|                 setattr(obj, | ||||
|                         'get_%s_display' % attr_name, | ||||
|                         partial(self.__get_field_display, field=field)) | ||||
|         """For each field that specifies choices, create a | ||||
|         get_<field>_display method. | ||||
|         """ | ||||
|         fields_with_choices = [(n, f) for n, f in self._fields.items() | ||||
|                                if f.choices] | ||||
|         for attr_name, field in fields_with_choices: | ||||
|             setattr(self, | ||||
|                     'get_%s_display' % attr_name, | ||||
|                     partial(self.__get_field_display, field=field)) | ||||
|  | ||||
|     def __get_field_display(self, field): | ||||
|         """Returns the display value for a choice field""" | ||||
|         """Return the display value for a choice field""" | ||||
|         value = getattr(self, field.name) | ||||
|         if field.choices and isinstance(field.choices[0], (list, tuple)): | ||||
|             return dict(field.choices).get(value, value) | ||||
|   | ||||
| @@ -5,12 +5,12 @@ import weakref | ||||
| from bson import DBRef, ObjectId, SON | ||||
| import pymongo | ||||
|  | ||||
| from mongoengine.common import _import_class | ||||
| from mongoengine.errors import ValidationError | ||||
| from mongoengine.base.common import ALLOW_INHERITANCE | ||||
| from mongoengine.base.datastructures import ( | ||||
|     BaseDict, BaseList, EmbeddedDocumentList | ||||
| ) | ||||
| from mongoengine.common import _import_class | ||||
| from mongoengine.errors import ValidationError | ||||
|  | ||||
| __all__ = ("BaseField", "ComplexBaseField", | ||||
|            "ObjectIdField", "GeoJsonBaseField") | ||||
| @@ -85,13 +85,13 @@ class BaseField(object): | ||||
|         self.null = null | ||||
|         self.sparse = sparse | ||||
|         self._owner_document = None | ||||
|          | ||||
|  | ||||
|         # Detect and report conflicts between metadata and base properties. | ||||
|         conflicts = set(dir(self)) & set(kwargs) | ||||
|         if conflicts: | ||||
|             raise TypeError("%s already has attribute(s): %s" % ( | ||||
|                 self.__class__.__name__, ', '.join(conflicts) )) | ||||
|          | ||||
|                 self.__class__.__name__, ', '.join(conflicts))) | ||||
|  | ||||
|         # Assign metadata to the instance | ||||
|         # This efficient method is available because no __slots__ are defined. | ||||
|         self.__dict__.update(kwargs) | ||||
| @@ -133,7 +133,7 @@ class BaseField(object): | ||||
|                 if (self.name not in instance._data or | ||||
|                         instance._data[self.name] != value): | ||||
|                     instance._mark_as_changed(self.name) | ||||
|             except: | ||||
|             except Exception: | ||||
|                 # Values cant be compared eg: naive and tz datetimes | ||||
|                 # So mark it as changed | ||||
|                 instance._mark_as_changed(self.name) | ||||
| @@ -163,6 +163,19 @@ class BaseField(object): | ||||
|         """ | ||||
|         return self.to_python(value) | ||||
|  | ||||
|     def _to_mongo_safe_call(self, value, use_db_field=True, fields=None): | ||||
|         """A helper method to call to_mongo with proper inputs | ||||
|         """ | ||||
|         f_inputs = self.to_mongo.__code__.co_varnames | ||||
|         ex_vars = {} | ||||
|         if 'fields' in f_inputs: | ||||
|             ex_vars['fields'] = fields | ||||
|  | ||||
|         if 'use_db_field' in f_inputs: | ||||
|             ex_vars['use_db_field'] = use_db_field | ||||
|  | ||||
|         return self.to_mongo(value, **ex_vars) | ||||
|  | ||||
|     def prepare_query_value(self, op, value): | ||||
|         """Prepare a value that is being used in a query for PyMongo. | ||||
|         """ | ||||
| @@ -193,7 +206,6 @@ class BaseField(object): | ||||
|         elif value not in choice_list: | ||||
|             self.error('Value must be one of %s' % unicode(choice_list)) | ||||
|  | ||||
|  | ||||
|     def _validate(self, value, **kwargs): | ||||
|         # Check the Choices Constraint | ||||
|         if self.choices: | ||||
| @@ -285,8 +297,6 @@ class ComplexBaseField(BaseField): | ||||
|     def to_python(self, value): | ||||
|         """Convert a MongoDB-compatible type to a Python type. | ||||
|         """ | ||||
|         Document = _import_class('Document') | ||||
|  | ||||
|         if isinstance(value, basestring): | ||||
|             return value | ||||
|  | ||||
| @@ -306,6 +316,7 @@ class ComplexBaseField(BaseField): | ||||
|             value_dict = dict([(key, self.field.to_python(item)) | ||||
|                                for key, item in value.items()]) | ||||
|         else: | ||||
|             Document = _import_class('Document') | ||||
|             value_dict = {} | ||||
|             for k, v in value.items(): | ||||
|                 if isinstance(v, Document): | ||||
| @@ -325,7 +336,7 @@ class ComplexBaseField(BaseField): | ||||
|                                          key=operator.itemgetter(0))] | ||||
|         return value_dict | ||||
|  | ||||
|     def to_mongo(self, value): | ||||
|     def to_mongo(self, value, use_db_field=True, fields=None): | ||||
|         """Convert a Python type to a MongoDB-compatible type. | ||||
|         """ | ||||
|         Document = _import_class("Document") | ||||
| @@ -339,7 +350,7 @@ class ComplexBaseField(BaseField): | ||||
|             if isinstance(value, Document): | ||||
|                 return GenericReferenceField().to_mongo(value) | ||||
|             cls = value.__class__ | ||||
|             val = value.to_mongo() | ||||
|             val = value.to_mongo(use_db_field, fields) | ||||
|             # If it's a document that is not inherited add _cls | ||||
|             if isinstance(value, EmbeddedDocument): | ||||
|                 val['_cls'] = cls.__name__ | ||||
| @@ -354,7 +365,7 @@ class ComplexBaseField(BaseField): | ||||
|                 return value | ||||
|  | ||||
|         if self.field: | ||||
|             value_dict = dict([(key, self.field.to_mongo(item)) | ||||
|             value_dict = dict([(key, self.field._to_mongo_safe_call(item, use_db_field, fields)) | ||||
|                                for key, item in value.iteritems()]) | ||||
|         else: | ||||
|             value_dict = {} | ||||
| @@ -379,13 +390,13 @@ class ComplexBaseField(BaseField): | ||||
|                         value_dict[k] = DBRef(collection, v.pk) | ||||
|                 elif hasattr(v, 'to_mongo'): | ||||
|                     cls = v.__class__ | ||||
|                     val = v.to_mongo() | ||||
|                     val = v.to_mongo(use_db_field, fields) | ||||
|                     # If it's a document that is not inherited add _cls | ||||
|                     if isinstance(v, (Document, EmbeddedDocument)): | ||||
|                         val['_cls'] = cls.__name__ | ||||
|                     value_dict[k] = val | ||||
|                 else: | ||||
|                     value_dict[k] = self.to_mongo(v) | ||||
|                     value_dict[k] = self.to_mongo(v, use_db_field, fields) | ||||
|  | ||||
|         if is_list:  # Convert back to a list | ||||
|             return [v for _, v in sorted(value_dict.items(), | ||||
| @@ -439,7 +450,7 @@ class ObjectIdField(BaseField): | ||||
|         try: | ||||
|             if not isinstance(value, ObjectId): | ||||
|                 value = ObjectId(value) | ||||
|         except: | ||||
|         except Exception: | ||||
|             pass | ||||
|         return value | ||||
|  | ||||
| @@ -458,7 +469,7 @@ class ObjectIdField(BaseField): | ||||
|     def validate(self, value): | ||||
|         try: | ||||
|             ObjectId(unicode(value)) | ||||
|         except: | ||||
|         except Exception: | ||||
|             self.error('Invalid Object ID') | ||||
|  | ||||
|  | ||||
| @@ -510,7 +521,7 @@ class GeoJsonBaseField(BaseField): | ||||
|         # Quick and dirty validator | ||||
|         try: | ||||
|             value[0][0][0] | ||||
|         except: | ||||
|         except (TypeError, IndexError): | ||||
|             return "Invalid Polygon must contain at least one valid linestring" | ||||
|  | ||||
|         errors = [] | ||||
| @@ -534,7 +545,7 @@ class GeoJsonBaseField(BaseField): | ||||
|         # Quick and dirty validator | ||||
|         try: | ||||
|             value[0][0] | ||||
|         except: | ||||
|         except (TypeError, IndexError): | ||||
|             return "Invalid LineString must contain at least one valid point" | ||||
|  | ||||
|         errors = [] | ||||
| @@ -565,7 +576,7 @@ class GeoJsonBaseField(BaseField): | ||||
|         # Quick and dirty validator | ||||
|         try: | ||||
|             value[0][0] | ||||
|         except: | ||||
|         except (TypeError, IndexError): | ||||
|             return "Invalid MultiPoint must contain at least one valid point" | ||||
|  | ||||
|         errors = [] | ||||
| @@ -584,7 +595,7 @@ class GeoJsonBaseField(BaseField): | ||||
|         # Quick and dirty validator | ||||
|         try: | ||||
|             value[0][0][0] | ||||
|         except: | ||||
|         except (TypeError, IndexError): | ||||
|             return "Invalid MultiLineString must contain at least one valid linestring" | ||||
|  | ||||
|         errors = [] | ||||
| @@ -606,7 +617,7 @@ class GeoJsonBaseField(BaseField): | ||||
|         # Quick and dirty validator | ||||
|         try: | ||||
|             value[0][0][0][0] | ||||
|         except: | ||||
|         except (TypeError, IndexError): | ||||
|             return "Invalid MultiPolygon must contain at least one valid Polygon" | ||||
|  | ||||
|         errors = [] | ||||
|   | ||||
| @@ -1,5 +1,7 @@ | ||||
| import warnings | ||||
|  | ||||
| from mongoengine.base.common import ALLOW_INHERITANCE, _document_registry | ||||
| from mongoengine.base.fields import BaseField, ComplexBaseField, ObjectIdField | ||||
| from mongoengine.common import _import_class | ||||
| from mongoengine.errors import InvalidDocumentError | ||||
| from mongoengine.python_support import PY3 | ||||
| @@ -7,16 +9,14 @@ from mongoengine.queryset import (DO_NOTHING, DoesNotExist, | ||||
|                                   MultipleObjectsReturned, | ||||
|                                   QuerySetManager) | ||||
|  | ||||
| from mongoengine.base.common import _document_registry, ALLOW_INHERITANCE | ||||
| from mongoengine.base.fields import BaseField, ComplexBaseField, ObjectIdField | ||||
|  | ||||
| __all__ = ('DocumentMetaclass', 'TopLevelDocumentMetaclass') | ||||
|  | ||||
|  | ||||
| class DocumentMetaclass(type): | ||||
|     """Metaclass for all documents. | ||||
|     """ | ||||
|     """Metaclass for all documents.""" | ||||
|  | ||||
|     # TODO lower complexity of this method | ||||
|     def __new__(cls, name, bases, attrs): | ||||
|         flattened_bases = cls._get_bases(bases) | ||||
|         super_new = super(DocumentMetaclass, cls).__new__ | ||||
| @@ -162,7 +162,7 @@ class DocumentMetaclass(type): | ||||
|         # copies __func__ into im_func and __self__ into im_self for | ||||
|         # classmethod objects in Document derived classes. | ||||
|         if PY3: | ||||
|             for key, val in new_class.__dict__.items(): | ||||
|             for val in new_class.__dict__.values(): | ||||
|                 if isinstance(val, classmethod): | ||||
|                     f = val.__get__(new_class) | ||||
|                     if hasattr(f, '__func__') and not hasattr(f, 'im_func'): | ||||
|   | ||||
| @@ -1,11 +1,12 @@ | ||||
| from pymongo import MongoClient, ReadPreference, uri_parser | ||||
| from mongoengine.python_support import IS_PYMONGO_3 | ||||
| from mongoengine.python_support import (IS_PYMONGO_3, str_types) | ||||
|  | ||||
| __all__ = ['ConnectionError', 'connect', 'register_connection', | ||||
|            'DEFAULT_CONNECTION_NAME'] | ||||
|  | ||||
|  | ||||
| DEFAULT_CONNECTION_NAME = 'default' | ||||
|  | ||||
| if IS_PYMONGO_3: | ||||
|     READ_PREFERENCE = ReadPreference.PRIMARY | ||||
| else: | ||||
| @@ -24,7 +25,9 @@ _dbs = {} | ||||
|  | ||||
| def register_connection(alias, name=None, host=None, port=None, | ||||
|                         read_preference=READ_PREFERENCE, | ||||
|                         username=None, password=None, authentication_source=None, | ||||
|                         username=None, password=None, | ||||
|                         authentication_source=None, | ||||
|                         authentication_mechanism=None, | ||||
|                         **kwargs): | ||||
|     """Add a connection. | ||||
|  | ||||
| @@ -38,6 +41,9 @@ def register_connection(alias, name=None, host=None, port=None, | ||||
|     :param username: username to authenticate with | ||||
|     :param password: password to authenticate with | ||||
|     :param authentication_source: database to authenticate against | ||||
|     :param authentication_mechanism: database authentication mechanisms. | ||||
|         By default, use SCRAM-SHA-1 with MongoDB 3.0 and later, | ||||
|         MONGODB-CR (MongoDB Challenge Response protocol) for older servers. | ||||
|     :param is_mock: explicitly use mongomock for this connection | ||||
|         (can also be done by using `mongomock://` as db host prefix) | ||||
|     :param kwargs: allow ad-hoc parameters to be passed into the pymongo driver | ||||
| @@ -53,28 +59,48 @@ def register_connection(alias, name=None, host=None, port=None, | ||||
|         'read_preference': read_preference, | ||||
|         'username': username, | ||||
|         'password': password, | ||||
|         'authentication_source': authentication_source | ||||
|         'authentication_source': authentication_source, | ||||
|         'authentication_mechanism': authentication_mechanism | ||||
|     } | ||||
|  | ||||
|     # Handle uri style connections | ||||
|     conn_host = conn_settings['host'] | ||||
|     if conn_host.startswith('mongomock://'): | ||||
|         conn_settings['is_mock'] = True | ||||
|         # `mongomock://` is not a valid url prefix and must be replaced by `mongodb://` | ||||
|         conn_settings['host'] = conn_host.replace('mongomock://', 'mongodb://', 1) | ||||
|     elif '://' in conn_host: | ||||
|         uri_dict = uri_parser.parse_uri(conn_host) | ||||
|         conn_settings.update({ | ||||
|             'name': uri_dict.get('database') or name, | ||||
|             'username': uri_dict.get('username'), | ||||
|             'password': uri_dict.get('password'), | ||||
|             'read_preference': read_preference, | ||||
|         }) | ||||
|         uri_options = uri_dict['options'] | ||||
|         if 'replicaset' in uri_options: | ||||
|             conn_settings['replicaSet'] = True | ||||
|         if 'authsource' in uri_options: | ||||
|             conn_settings['authentication_source'] = uri_options['authsource'] | ||||
|     # host can be a list or a string, so if string, force to a list | ||||
|     if isinstance(conn_host, str_types): | ||||
|         conn_host = [conn_host] | ||||
|  | ||||
|     resolved_hosts = [] | ||||
|     for entity in conn_host: | ||||
|  | ||||
|         # Handle Mongomock | ||||
|         if entity.startswith('mongomock://'): | ||||
|             conn_settings['is_mock'] = True | ||||
|             # `mongomock://` is not a valid url prefix and must be replaced by `mongodb://` | ||||
|             resolved_hosts.append(entity.replace('mongomock://', 'mongodb://', 1)) | ||||
|  | ||||
|         # Handle URI style connections, only updating connection params which | ||||
|         # were explicitly specified in the URI. | ||||
|         elif '://' in entity: | ||||
|             uri_dict = uri_parser.parse_uri(entity) | ||||
|             resolved_hosts.append(entity) | ||||
|  | ||||
|             if uri_dict.get('database'): | ||||
|                 conn_settings['name'] = uri_dict.get('database') | ||||
|  | ||||
|             for param in ('read_preference', 'username', 'password'): | ||||
|                 if uri_dict.get(param): | ||||
|                     conn_settings[param] = uri_dict[param] | ||||
|  | ||||
|             uri_options = uri_dict['options'] | ||||
|             if 'replicaset' in uri_options: | ||||
|                 conn_settings['replicaSet'] = True | ||||
|             if 'authsource' in uri_options: | ||||
|                 conn_settings['authentication_source'] = uri_options['authsource'] | ||||
|             if 'authmechanism' in uri_options: | ||||
|                 conn_settings['authentication_mechanism'] = uri_options['authmechanism'] | ||||
|         else: | ||||
|             resolved_hosts.append(entity) | ||||
|     conn_settings['host'] = resolved_hosts | ||||
|  | ||||
|     # Deprecated parameters that should not be passed on | ||||
|     kwargs.pop('slaves', None) | ||||
| @@ -113,6 +139,7 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): | ||||
|         conn_settings.pop('username', None) | ||||
|         conn_settings.pop('password', None) | ||||
|         conn_settings.pop('authentication_source', None) | ||||
|         conn_settings.pop('authentication_mechanism', None) | ||||
|  | ||||
|         is_mock = conn_settings.pop('is_mock', None) | ||||
|         if is_mock: | ||||
| @@ -147,6 +174,7 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): | ||||
|                 connection_settings.pop('username', None) | ||||
|                 connection_settings.pop('password', None) | ||||
|                 connection_settings.pop('authentication_source', None) | ||||
|                 connection_settings.pop('authentication_mechanism', None) | ||||
|                 if conn_settings == connection_settings and _connections.get(db_alias, None): | ||||
|                     connection = _connections[db_alias] | ||||
|                     break | ||||
| @@ -166,11 +194,13 @@ def get_db(alias=DEFAULT_CONNECTION_NAME, reconnect=False): | ||||
|         conn = get_connection(alias) | ||||
|         conn_settings = _connection_settings[alias] | ||||
|         db = conn[conn_settings['name']] | ||||
|         auth_kwargs = {'source': conn_settings['authentication_source']} | ||||
|         if conn_settings['authentication_mechanism'] is not None: | ||||
|             auth_kwargs['mechanism'] = conn_settings['authentication_mechanism'] | ||||
|         # Authenticate if necessary | ||||
|         if conn_settings['username'] and conn_settings['password']: | ||||
|             db.authenticate(conn_settings['username'], | ||||
|                             conn_settings['password'], | ||||
|                             source=conn_settings['authentication_source']) | ||||
|         if conn_settings['username'] and (conn_settings['password'] or | ||||
|                                           conn_settings['authentication_mechanism'] == 'MONGODB-X509'): | ||||
|             db.authenticate(conn_settings['username'], conn_settings['password'], **auth_kwargs) | ||||
|         _dbs[alias] = db | ||||
|     return _dbs[alias] | ||||
|  | ||||
|   | ||||
| @@ -1,13 +1,14 @@ | ||||
| from bson import DBRef, SON | ||||
|  | ||||
| from base import ( | ||||
| from .base import ( | ||||
|     BaseDict, BaseList, EmbeddedDocumentList, | ||||
|     TopLevelDocumentMetaclass, get_document | ||||
| ) | ||||
| from fields import (ReferenceField, ListField, DictField, MapField) | ||||
| from connection import get_db | ||||
| from queryset import QuerySet | ||||
| from document import Document, EmbeddedDocument | ||||
| from .connection import get_db | ||||
| from .document import Document, EmbeddedDocument | ||||
| from .fields import DictField, ListField, MapField, ReferenceField | ||||
| from .python_support import txt_type | ||||
| from .queryset import QuerySet | ||||
|  | ||||
|  | ||||
| class DeReference(object): | ||||
| @@ -226,7 +227,7 @@ class DeReference(object): | ||||
|                         data[k]._data[field_name] = self.object_map.get( | ||||
|                             (v['_ref'].collection, v['_ref'].id), v) | ||||
|                     elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth: | ||||
|                         item_name = "{0}.{1}.{2}".format(name, k, field_name) | ||||
|                         item_name = txt_type("{0}.{1}.{2}").format(name, k, field_name) | ||||
|                         data[k]._data[field_name] = self._attach_objects(v, depth, instance=instance, name=item_name) | ||||
|             elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth: | ||||
|                 item_name = '%s.%s' % (name, k) if name else name | ||||
|   | ||||
| @@ -1,28 +1,29 @@ | ||||
| import warnings | ||||
| import pymongo | ||||
| import re | ||||
| import warnings | ||||
|  | ||||
| from pymongo.read_preferences import ReadPreference | ||||
| from bson.dbref import DBRef | ||||
| import pymongo | ||||
| from pymongo.read_preferences import ReadPreference | ||||
|  | ||||
| from mongoengine import signals | ||||
| from mongoengine.common import _import_class | ||||
| from mongoengine.base import ( | ||||
|     DocumentMetaclass, | ||||
|     TopLevelDocumentMetaclass, | ||||
|     BaseDocument, | ||||
|     BaseDict, | ||||
|     BaseList, | ||||
|     EmbeddedDocumentList, | ||||
|     ALLOW_INHERITANCE, | ||||
|     BaseDict, | ||||
|     BaseDocument, | ||||
|     BaseList, | ||||
|     DocumentMetaclass, | ||||
|     EmbeddedDocumentList, | ||||
|     TopLevelDocumentMetaclass, | ||||
|     get_document | ||||
| ) | ||||
| from mongoengine.errors import (InvalidQueryError, InvalidDocumentError, | ||||
| from mongoengine.common import _import_class | ||||
| from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db | ||||
| from mongoengine.context_managers import switch_collection, switch_db | ||||
| from mongoengine.errors import (InvalidDocumentError, InvalidQueryError, | ||||
|                                 SaveConditionError) | ||||
| from mongoengine.python_support import IS_PYMONGO_3 | ||||
| from mongoengine.queryset import (OperationError, NotUniqueError, | ||||
| from mongoengine.queryset import (NotUniqueError, OperationError, | ||||
|                                   QuerySet, transform) | ||||
| from mongoengine.connection import get_db, DEFAULT_CONNECTION_NAME | ||||
| from mongoengine.context_managers import switch_db, switch_collection | ||||
|  | ||||
| __all__ = ('Document', 'EmbeddedDocument', 'DynamicDocument', | ||||
|            'DynamicEmbeddedDocument', 'OperationError', | ||||
| @@ -250,7 +251,7 @@ class Document(BaseDocument): | ||||
|  | ||||
|     def save(self, force_insert=False, validate=True, clean=True, | ||||
|              write_concern=None, cascade=None, cascade_kwargs=None, | ||||
|              _refs=None, save_condition=None, **kwargs): | ||||
|              _refs=None, save_condition=None, signal_kwargs=None, **kwargs): | ||||
|         """Save the :class:`~mongoengine.Document` to the database. If the | ||||
|         document already exists, it will be updated, otherwise it will be | ||||
|         created. | ||||
| @@ -276,6 +277,8 @@ class Document(BaseDocument): | ||||
|         :param save_condition: only perform save if matching record in db | ||||
|             satisfies condition(s) (e.g. version number). | ||||
|             Raises :class:`OperationError` if the conditions are not satisfied | ||||
|         :parm signal_kwargs: (optional) kwargs dictionary to be passed to | ||||
|             the signal calls. | ||||
|  | ||||
|         .. versionchanged:: 0.5 | ||||
|             In existing documents it only saves changed fields using | ||||
| @@ -297,8 +300,11 @@ class Document(BaseDocument): | ||||
|             :class:`OperationError` exception raised if save_condition fails. | ||||
|         .. versionchanged:: 0.10.1 | ||||
|             :class: save_condition failure now raises a `SaveConditionError` | ||||
|         .. versionchanged:: 0.10.7 | ||||
|             Add signal_kwargs argument | ||||
|         """ | ||||
|         signals.pre_save.send(self.__class__, document=self) | ||||
|         signal_kwargs = signal_kwargs or {} | ||||
|         signals.pre_save.send(self.__class__, document=self, **signal_kwargs) | ||||
|  | ||||
|         if validate: | ||||
|             self.validate(clean=clean) | ||||
| @@ -311,7 +317,7 @@ class Document(BaseDocument): | ||||
|         created = ('_id' not in doc or self._created or force_insert) | ||||
|  | ||||
|         signals.pre_save_post_validation.send(self.__class__, document=self, | ||||
|                                               created=created) | ||||
|                                               created=created, **signal_kwargs) | ||||
|  | ||||
|         try: | ||||
|             collection = self._get_collection() | ||||
| @@ -327,8 +333,10 @@ class Document(BaseDocument): | ||||
|                     # Correct behaviour in 2.X and in 3.0.1+ versions | ||||
|                     if not object_id and pymongo.version_tuple == (3, 0): | ||||
|                         pk_as_mongo_obj = self._fields.get(self._meta['id_field']).to_mongo(self.pk) | ||||
|                         object_id = self._qs.filter(pk=pk_as_mongo_obj).first() and \ | ||||
|                                     self._qs.filter(pk=pk_as_mongo_obj).first().pk | ||||
|                         object_id = ( | ||||
|                             self._qs.filter(pk=pk_as_mongo_obj).first() and | ||||
|                             self._qs.filter(pk=pk_as_mongo_obj).first().pk | ||||
|                         )  # TODO doesn't this make 2 queries? | ||||
|             else: | ||||
|                 object_id = doc['_id'] | ||||
|                 updates, removals = self._delta() | ||||
| @@ -400,7 +408,8 @@ class Document(BaseDocument): | ||||
|         if created or id_field not in self._meta.get('shard_key', []): | ||||
|             self[id_field] = self._fields[id_field].to_python(object_id) | ||||
|  | ||||
|         signals.post_save.send(self.__class__, document=self, created=created) | ||||
|         signals.post_save.send(self.__class__, document=self, | ||||
|                                created=created, **signal_kwargs) | ||||
|         self._clear_changed_fields() | ||||
|         self._created = False | ||||
|         return self | ||||
| @@ -463,7 +472,7 @@ class Document(BaseDocument): | ||||
|         Raises :class:`OperationError` if called on an object that has not yet | ||||
|         been saved. | ||||
|         """ | ||||
|         if not self.pk: | ||||
|         if self.pk is None: | ||||
|             if kwargs.get('upsert', False): | ||||
|                 query = self.to_mongo() | ||||
|                 if "_cls" in query: | ||||
| @@ -476,23 +485,29 @@ class Document(BaseDocument): | ||||
|         # Need to add shard key to query, or you get an error | ||||
|         return self._qs.filter(**self._object_key).update_one(**kwargs) | ||||
|  | ||||
|     def delete(self, **write_concern): | ||||
|     def delete(self, signal_kwargs=None, **write_concern): | ||||
|         """Delete the :class:`~mongoengine.Document` from the database. This | ||||
|         will only take effect if the document has been previously saved. | ||||
|  | ||||
|         :parm signal_kwargs: (optional) kwargs dictionary to be passed to | ||||
|             the signal calls. | ||||
|         :param write_concern: Extra keyword arguments are passed down which | ||||
|             will be used as options for the resultant | ||||
|             ``getLastError`` command.  For example, | ||||
|             ``save(..., write_concern={w: 2, fsync: True}, ...)`` will | ||||
|             wait until at least two servers have recorded the write and | ||||
|             will force an fsync on the primary server. | ||||
|         """ | ||||
|         signals.pre_delete.send(self.__class__, document=self) | ||||
|  | ||||
|         # Delete FileFields separately  | ||||
|         .. versionchanged:: 0.10.7 | ||||
|             Add signal_kwargs argument | ||||
|         """ | ||||
|         signal_kwargs = signal_kwargs or {} | ||||
|         signals.pre_delete.send(self.__class__, document=self, **signal_kwargs) | ||||
|  | ||||
|         # Delete FileFields separately | ||||
|         FileField = _import_class('FileField') | ||||
|         for name, field in self._fields.iteritems(): | ||||
|             if isinstance(field, FileField):  | ||||
|             if isinstance(field, FileField): | ||||
|                 getattr(self, name).delete() | ||||
|  | ||||
|         try: | ||||
| @@ -501,7 +516,7 @@ class Document(BaseDocument): | ||||
|         except pymongo.errors.OperationFailure, err: | ||||
|             message = u'Could not delete document (%s)' % err.message | ||||
|             raise OperationError(message) | ||||
|         signals.post_delete.send(self.__class__, document=self) | ||||
|         signals.post_delete.send(self.__class__, document=self, **signal_kwargs) | ||||
|  | ||||
|     def switch_db(self, db_alias, keep_created=True): | ||||
|         """ | ||||
| @@ -589,7 +604,7 @@ class Document(BaseDocument): | ||||
|         elif "max_depth" in kwargs: | ||||
|             max_depth = kwargs["max_depth"] | ||||
|  | ||||
|         if not self.pk: | ||||
|         if self.pk is None: | ||||
|             raise self.DoesNotExist("Document does not exist") | ||||
|         obj = self._qs.read_preference(ReadPreference.PRIMARY).filter( | ||||
|             **self._object_key).only(*fields).limit( | ||||
| @@ -640,7 +655,7 @@ class Document(BaseDocument): | ||||
|     def to_dbref(self): | ||||
|         """Returns an instance of :class:`~bson.dbref.DBRef` useful in | ||||
|         `__raw__` queries.""" | ||||
|         if not self.pk: | ||||
|         if self.pk is None: | ||||
|             msg = "Only saved documents can have a valid dbref" | ||||
|             raise OperationError(msg) | ||||
|         return DBRef(self.__class__._get_collection_name(), self.pk) | ||||
| @@ -667,10 +682,20 @@ class Document(BaseDocument): | ||||
|     def drop_collection(cls): | ||||
|         """Drops the entire collection associated with this | ||||
|         :class:`~mongoengine.Document` type from the database. | ||||
|  | ||||
|         Raises :class:`OperationError` if the document has no collection set | ||||
|         (i.g. if it is `abstract`) | ||||
|  | ||||
|         .. versionchanged:: 0.10.7 | ||||
|             :class:`OperationError` exception raised if no collection available | ||||
|         """ | ||||
|         col_name = cls._get_collection_name() | ||||
|         if not col_name: | ||||
|             raise OperationError('Document %s has no collection defined ' | ||||
|                                  '(is it abstract ?)' % cls) | ||||
|         cls._collection = None | ||||
|         db = cls._get_db() | ||||
|         db.drop_collection(cls._get_collection_name()) | ||||
|         db.drop_collection(col_name) | ||||
|  | ||||
|     @classmethod | ||||
|     def create_index(cls, keys, background=False, **kwargs): | ||||
| @@ -959,7 +984,7 @@ class MapReduceDocument(object): | ||||
|         if not isinstance(self.key, id_field_type): | ||||
|             try: | ||||
|                 self.key = id_field_type(self.key) | ||||
|             except: | ||||
|             except Exception: | ||||
|                 raise Exception("Could not cast key as %s" % | ||||
|                                 id_field_type.__name__) | ||||
|  | ||||
|   | ||||
| @@ -8,6 +8,11 @@ import uuid | ||||
| import warnings | ||||
| from operator import itemgetter | ||||
|  | ||||
| from bson import Binary, DBRef, ObjectId, SON | ||||
| import gridfs | ||||
| import pymongo | ||||
| import six | ||||
|  | ||||
| try: | ||||
|     import dateutil | ||||
| except ImportError: | ||||
| @@ -15,18 +20,18 @@ except ImportError: | ||||
| else: | ||||
|     import dateutil.parser | ||||
|  | ||||
| import pymongo | ||||
| import gridfs | ||||
| from bson import Binary, DBRef, SON, ObjectId | ||||
| try: | ||||
|     from bson.int64 import Int64 | ||||
| except ImportError: | ||||
|     Int64 = long | ||||
|  | ||||
| from mongoengine.errors import ValidationError | ||||
| from mongoengine.python_support import (PY3, bin_type, txt_type, | ||||
|                                         str_types, StringIO) | ||||
| from base import (BaseField, ComplexBaseField, ObjectIdField, GeoJsonBaseField, | ||||
|                   get_document, BaseDocument) | ||||
| from queryset import DO_NOTHING, QuerySet | ||||
| from document import Document, EmbeddedDocument | ||||
| from connection import get_db, DEFAULT_CONNECTION_NAME | ||||
| from .base import (BaseDocument, BaseField, ComplexBaseField, GeoJsonBaseField, | ||||
|                    ObjectIdField, get_document) | ||||
| from .connection import DEFAULT_CONNECTION_NAME, get_db | ||||
| from .document import Document, EmbeddedDocument | ||||
| from .errors import DoesNotExist, ValidationError | ||||
| from .python_support import PY3, StringIO, bin_type, str_types, txt_type | ||||
| from .queryset import DO_NOTHING, QuerySet | ||||
|  | ||||
| try: | ||||
|     from PIL import Image, ImageOps | ||||
| @@ -65,7 +70,7 @@ class StringField(BaseField): | ||||
|             return value | ||||
|         try: | ||||
|             value = value.decode('utf-8') | ||||
|         except: | ||||
|         except Exception: | ||||
|             pass | ||||
|         return value | ||||
|  | ||||
| @@ -156,7 +161,7 @@ class URLField(StringField): | ||||
|  | ||||
|  | ||||
| class EmailField(StringField): | ||||
|     """A field that validates input as an E-Mail-Address. | ||||
|     """A field that validates input as an email address. | ||||
|  | ||||
|     .. versionadded:: 0.4 | ||||
|     """ | ||||
| @@ -172,7 +177,7 @@ class EmailField(StringField): | ||||
|  | ||||
|     def validate(self, value): | ||||
|         if not EmailField.EMAIL_REGEX.match(value): | ||||
|             self.error('Invalid Mail-address: %s' % value) | ||||
|             self.error('Invalid email address: %s' % value) | ||||
|         super(EmailField, self).validate(value) | ||||
|  | ||||
|  | ||||
| @@ -194,7 +199,7 @@ class IntField(BaseField): | ||||
|     def validate(self, value): | ||||
|         try: | ||||
|             value = int(value) | ||||
|         except: | ||||
|         except Exception: | ||||
|             self.error('%s could not be converted to int' % value) | ||||
|  | ||||
|         if self.min_value is not None and value < self.min_value: | ||||
| @@ -225,10 +230,13 @@ class LongField(BaseField): | ||||
|             pass | ||||
|         return value | ||||
|  | ||||
|     def to_mongo(self, value): | ||||
|         return Int64(value) | ||||
|  | ||||
|     def validate(self, value): | ||||
|         try: | ||||
|             value = long(value) | ||||
|         except: | ||||
|         except Exception: | ||||
|             self.error('%s could not be converted to long' % value) | ||||
|  | ||||
|         if self.min_value is not None and value < self.min_value: | ||||
| @@ -260,10 +268,14 @@ class FloatField(BaseField): | ||||
|         return value | ||||
|  | ||||
|     def validate(self, value): | ||||
|         if isinstance(value, int): | ||||
|             value = float(value) | ||||
|         if isinstance(value, six.integer_types): | ||||
|             try: | ||||
|                 value = float(value) | ||||
|             except OverflowError: | ||||
|                 self.error('The value is too large to be converted to float') | ||||
|  | ||||
|         if not isinstance(value, float): | ||||
|             self.error('FloatField only accepts float values') | ||||
|             self.error('FloatField only accepts float and integer values') | ||||
|  | ||||
|         if self.min_value is not None and value < self.min_value: | ||||
|             self.error('Float value is too small') | ||||
| @@ -325,7 +337,7 @@ class DecimalField(BaseField): | ||||
|             return value | ||||
|         return value.quantize(decimal.Decimal(".%s" % ("0" * self.precision)), rounding=self.rounding) | ||||
|  | ||||
|     def to_mongo(self, value, use_db_field=True): | ||||
|     def to_mongo(self, value): | ||||
|         if value is None: | ||||
|             return value | ||||
|         if self.force_string: | ||||
| @@ -508,7 +520,7 @@ class ComplexDateTimeField(StringField): | ||||
|         original_value = value | ||||
|         try: | ||||
|             return self._convert_from_string(value) | ||||
|         except: | ||||
|         except Exception: | ||||
|             return original_value | ||||
|  | ||||
|     def to_mongo(self, value): | ||||
| @@ -546,11 +558,10 @@ class EmbeddedDocumentField(BaseField): | ||||
|             return self.document_type._from_son(value, _auto_dereference=self._auto_dereference) | ||||
|         return value | ||||
|  | ||||
|     def to_mongo(self, value, use_db_field=True, fields=[]): | ||||
|     def to_mongo(self, value, use_db_field=True, fields=None): | ||||
|         if not isinstance(value, self.document_type): | ||||
|             return value | ||||
|         return self.document_type.to_mongo(value, use_db_field, | ||||
|                                            fields=fields) | ||||
|         return self.document_type.to_mongo(value, use_db_field, fields) | ||||
|  | ||||
|     def validate(self, value, clean=True): | ||||
|         """Make sure that the document instance is an instance of the | ||||
| @@ -566,7 +577,7 @@ class EmbeddedDocumentField(BaseField): | ||||
|         return self.document_type._fields.get(member_name) | ||||
|  | ||||
|     def prepare_query_value(self, op, value): | ||||
|         if not isinstance(value, self.document_type): | ||||
|         if value is not None and not isinstance(value, self.document_type): | ||||
|             value = self.document_type._from_son(value) | ||||
|         super(EmbeddedDocumentField, self).prepare_query_value(op, value) | ||||
|         return self.to_mongo(value) | ||||
| @@ -600,11 +611,11 @@ class GenericEmbeddedDocumentField(BaseField): | ||||
|  | ||||
|         value.validate(clean=clean) | ||||
|  | ||||
|     def to_mongo(self, document, use_db_field=True): | ||||
|     def to_mongo(self, document, use_db_field=True, fields=None): | ||||
|         if document is None: | ||||
|             return None | ||||
|  | ||||
|         data = document.to_mongo(use_db_field) | ||||
|         data = document.to_mongo(use_db_field, fields) | ||||
|         if '_cls' not in data: | ||||
|             data['_cls'] = document._class_name | ||||
|         return data | ||||
| @@ -616,7 +627,7 @@ class DynamicField(BaseField): | ||||
|  | ||||
|     Used by :class:`~mongoengine.DynamicDocument` to handle dynamic data""" | ||||
|  | ||||
|     def to_mongo(self, value): | ||||
|     def to_mongo(self, value, use_db_field=True, fields=None): | ||||
|         """Convert a Python type to a MongoDB compatible type. | ||||
|         """ | ||||
|  | ||||
| @@ -625,7 +636,7 @@ class DynamicField(BaseField): | ||||
|  | ||||
|         if hasattr(value, 'to_mongo'): | ||||
|             cls = value.__class__ | ||||
|             val = value.to_mongo() | ||||
|             val = value.to_mongo(use_db_field, fields) | ||||
|             # If we its a document thats not inherited add _cls | ||||
|             if isinstance(value, Document): | ||||
|                 val = {"_ref": value.to_dbref(), "_cls": cls.__name__} | ||||
| @@ -643,7 +654,7 @@ class DynamicField(BaseField): | ||||
|  | ||||
|         data = {} | ||||
|         for k, v in value.iteritems(): | ||||
|             data[k] = self.to_mongo(v) | ||||
|             data[k] = self.to_mongo(v, use_db_field, fields) | ||||
|  | ||||
|         value = data | ||||
|         if is_list:  # Convert back to a list | ||||
| @@ -697,7 +708,7 @@ class ListField(ComplexBaseField): | ||||
|  | ||||
|     def prepare_query_value(self, op, value): | ||||
|         if self.field: | ||||
|             if op in ('set', 'unset') and ( | ||||
|             if op in ('set', 'unset', None) and ( | ||||
|                     not isinstance(value, basestring) and | ||||
|                     not isinstance(value, BaseDocument) and | ||||
|                     hasattr(value, '__iter__')): | ||||
| @@ -755,8 +766,8 @@ class SortedListField(ListField): | ||||
|             self._order_reverse = kwargs.pop('reverse') | ||||
|         super(SortedListField, self).__init__(field, **kwargs) | ||||
|  | ||||
|     def to_mongo(self, value): | ||||
|         value = super(SortedListField, self).to_mongo(value) | ||||
|     def to_mongo(self, value, use_db_field=True, fields=None): | ||||
|         value = super(SortedListField, self).to_mongo(value, use_db_field, fields) | ||||
|         if self._ordering is not None: | ||||
|             return sorted(value, key=itemgetter(self._ordering), | ||||
|                           reverse=self._order_reverse) | ||||
| @@ -878,7 +889,7 @@ class ReferenceField(BaseField): | ||||
|             content = StringField() | ||||
|             foo = ReferenceField('Foo') | ||||
|  | ||||
|         Bar.register_delete_rule(Foo, 'bar', NULLIFY) | ||||
|         Foo.register_delete_rule(Bar, 'foo', NULLIFY) | ||||
|  | ||||
|     .. note :: | ||||
|         `reverse_delete_rule` does not trigger pre / post delete signals to be | ||||
| @@ -936,9 +947,11 @@ class ReferenceField(BaseField): | ||||
|                 cls = get_document(value.cls) | ||||
|             else: | ||||
|                 cls = self.document_type | ||||
|             value = cls._get_db().dereference(value) | ||||
|             if value is not None: | ||||
|                 instance._data[self.name] = cls._from_son(value) | ||||
|             dereferenced = cls._get_db().dereference(value) | ||||
|             if dereferenced is None: | ||||
|                 raise DoesNotExist('Trying to dereference unknown document %s' % value) | ||||
|             else: | ||||
|                 instance._data[self.name] = cls._from_son(dereferenced) | ||||
|  | ||||
|         return super(ReferenceField, self).__get__(instance, owner) | ||||
|  | ||||
| @@ -1001,11 +1014,10 @@ class ReferenceField(BaseField): | ||||
|  | ||||
|         if self.document_type._meta.get('abstract') and \ | ||||
|                 not isinstance(value, self.document_type): | ||||
|             self.error('%s is not an instance of abstract reference' | ||||
|                     ' type %s' % (value._class_name, | ||||
|                         self.document_type._class_name) | ||||
|                     ) | ||||
|  | ||||
|             self.error( | ||||
|                 '%s is not an instance of abstract reference type %s' % ( | ||||
|                     self.document_type._class_name) | ||||
|             ) | ||||
|  | ||||
|     def lookup_member(self, member_name): | ||||
|         return self.document_type._fields.get(member_name) | ||||
| @@ -1014,7 +1026,7 @@ class ReferenceField(BaseField): | ||||
| class CachedReferenceField(BaseField): | ||||
|     """ | ||||
|     A referencefield with cache fields to purpose pseudo-joins | ||||
|      | ||||
|  | ||||
|     .. versionadded:: 0.9 | ||||
|     """ | ||||
|  | ||||
| @@ -1082,13 +1094,15 @@ class CachedReferenceField(BaseField): | ||||
|         self._auto_dereference = instance._fields[self.name]._auto_dereference | ||||
|         # Dereference DBRefs | ||||
|         if self._auto_dereference and isinstance(value, DBRef): | ||||
|             value = self.document_type._get_db().dereference(value) | ||||
|             if value is not None: | ||||
|                 instance._data[self.name] = self.document_type._from_son(value) | ||||
|             dereferenced = self.document_type._get_db().dereference(value) | ||||
|             if dereferenced is None: | ||||
|                 raise DoesNotExist('Trying to dereference unknown document %s' % value) | ||||
|             else: | ||||
|                 instance._data[self.name] = self.document_type._from_son(dereferenced) | ||||
|  | ||||
|         return super(CachedReferenceField, self).__get__(instance, owner) | ||||
|  | ||||
|     def to_mongo(self, document): | ||||
|     def to_mongo(self, document, use_db_field=True, fields=None): | ||||
|         id_field_name = self.document_type._meta['id_field'] | ||||
|         id_field = self.document_type._fields[id_field_name] | ||||
|  | ||||
| @@ -1106,7 +1120,12 @@ class CachedReferenceField(BaseField): | ||||
|             ("_id", id_field.to_mongo(id_)), | ||||
|         )) | ||||
|  | ||||
|         value.update(dict(document.to_mongo(fields=self.fields))) | ||||
|         if fields: | ||||
|             new_fields = [f for f in self.fields if f in fields] | ||||
|         else: | ||||
|             new_fields = self.fields | ||||
|  | ||||
|         value.update(dict(document.to_mongo(use_db_field, fields=new_fields))) | ||||
|         return value | ||||
|  | ||||
|     def prepare_query_value(self, op, value): | ||||
| @@ -1197,7 +1216,11 @@ class GenericReferenceField(BaseField): | ||||
|  | ||||
|         self._auto_dereference = instance._fields[self.name]._auto_dereference | ||||
|         if self._auto_dereference and isinstance(value, (dict, SON)): | ||||
|             instance._data[self.name] = self.dereference(value) | ||||
|             dereferenced = self.dereference(value) | ||||
|             if dereferenced is None: | ||||
|                 raise DoesNotExist('Trying to dereference unknown document %s' % value) | ||||
|             else: | ||||
|                 instance._data[self.name] = dereferenced | ||||
|  | ||||
|         return super(GenericReferenceField, self).__get__(instance, owner) | ||||
|  | ||||
| @@ -1222,11 +1245,11 @@ class GenericReferenceField(BaseField): | ||||
|             doc = doc_cls._from_son(doc) | ||||
|         return doc | ||||
|  | ||||
|     def to_mongo(self, document, use_db_field=True): | ||||
|     def to_mongo(self, document): | ||||
|         if document is None: | ||||
|             return None | ||||
|  | ||||
|         if isinstance(document, (dict, SON)): | ||||
|         if isinstance(document, (dict, SON, ObjectId, DBRef)): | ||||
|             return document | ||||
|  | ||||
|         id_field_name = document.__class__._meta['id_field'] | ||||
| @@ -1370,7 +1393,7 @@ class GridFSProxy(object): | ||||
|             if self.gridout is None: | ||||
|                 self.gridout = self.fs.get(self.grid_id) | ||||
|             return self.gridout | ||||
|         except: | ||||
|         except Exception: | ||||
|             # File has been deleted | ||||
|             return None | ||||
|  | ||||
| @@ -1408,7 +1431,7 @@ class GridFSProxy(object): | ||||
|         else: | ||||
|             try: | ||||
|                 return gridout.read(size) | ||||
|             except: | ||||
|             except Exception: | ||||
|                 return "" | ||||
|  | ||||
|     def delete(self): | ||||
| @@ -1473,7 +1496,7 @@ class FileField(BaseField): | ||||
|             if grid_file: | ||||
|                 try: | ||||
|                     grid_file.delete() | ||||
|                 except: | ||||
|                 except Exception: | ||||
|                     pass | ||||
|  | ||||
|             # Create a new proxy object as we don't already have one | ||||
| @@ -1707,17 +1730,17 @@ class SequenceField(BaseField): | ||||
|     :param collection_name:  Name of the counter collection (default 'mongoengine.counters') | ||||
|     :param sequence_name: Name of the sequence in the collection (default 'ClassName.counter') | ||||
|     :param value_decorator: Any callable to use as a counter (default int) | ||||
|          | ||||
|  | ||||
|     Use any callable as `value_decorator` to transform calculated counter into | ||||
|     any value suitable for your needs, e.g. string or hexadecimal | ||||
|     representation of the default integer counter value. | ||||
|      | ||||
|  | ||||
|     .. note:: | ||||
|      | ||||
|         In case the counter is defined in the abstract document, it will be  | ||||
|         common to all inherited documents and the default sequence name will  | ||||
|  | ||||
|         In case the counter is defined in the abstract document, it will be | ||||
|         common to all inherited documents and the default sequence name will | ||||
|         be the class name of the abstract document. | ||||
|      | ||||
|  | ||||
|     .. versionadded:: 0.5 | ||||
|     .. versionchanged:: 0.8 added `value_decorator` | ||||
|     """ | ||||
| @@ -1841,7 +1864,7 @@ class UUIDField(BaseField): | ||||
|                 if not isinstance(value, basestring): | ||||
|                     value = unicode(value) | ||||
|                 return uuid.UUID(value) | ||||
|             except: | ||||
|             except Exception: | ||||
|                 return original_value | ||||
|         return value | ||||
|  | ||||
|   | ||||
| @@ -1,9 +1,22 @@ | ||||
| """Helper functions and types to aid with Python 2.5 - 3 support.""" | ||||
| """Helper functions and types to aid with Python 2.6 - 3 support.""" | ||||
|  | ||||
| import sys | ||||
| import warnings | ||||
|  | ||||
| import pymongo | ||||
|  | ||||
|  | ||||
| # Show a deprecation warning for people using Python v2.6 | ||||
| # TODO remove in mongoengine v0.11.0 | ||||
| if sys.version_info[0] == 2 and sys.version_info[1] == 6: | ||||
|     warnings.warn( | ||||
|         'Python v2.6 support is deprecated and is going to be dropped ' | ||||
|         'entirely in the upcoming v0.11.0 release. Update your Python ' | ||||
|         'version if you want to have access to the latest features and ' | ||||
|         'bug fixes in MongoEngine.', | ||||
|         DeprecationWarning | ||||
|     ) | ||||
|  | ||||
| if pymongo.version_tuple[0] < 3: | ||||
|     IS_PYMONGO_3 = False | ||||
| else: | ||||
|   | ||||
| @@ -1,6 +1,6 @@ | ||||
| from mongoengine.errors import (DoesNotExist, MultipleObjectsReturned, | ||||
|                                 InvalidQueryError, OperationError, | ||||
|                                 NotUniqueError) | ||||
| from mongoengine.errors import (DoesNotExist, InvalidQueryError, | ||||
|                                 MultipleObjectsReturned, NotUniqueError, | ||||
|                                 OperationError) | ||||
| from mongoengine.queryset.field_list import * | ||||
| from mongoengine.queryset.manager import * | ||||
| from mongoengine.queryset.queryset import * | ||||
|   | ||||
| @@ -7,20 +7,19 @@ import pprint | ||||
| import re | ||||
| import warnings | ||||
|  | ||||
| from bson import SON | ||||
| from bson import SON, json_util | ||||
| from bson.code import Code | ||||
| from bson import json_util | ||||
| import pymongo | ||||
| import pymongo.errors | ||||
| from pymongo.common import validate_read_preference | ||||
|  | ||||
| from mongoengine import signals | ||||
| from mongoengine.base.common import get_document | ||||
| from mongoengine.common import _import_class | ||||
| from mongoengine.connection import get_db | ||||
| from mongoengine.context_managers import switch_db | ||||
| from mongoengine.common import _import_class | ||||
| from mongoengine.base.common import get_document | ||||
| from mongoengine.errors import (OperationError, NotUniqueError, | ||||
|                                 InvalidQueryError, LookUpError) | ||||
| from mongoengine.errors import (InvalidQueryError, LookUpError, | ||||
|                                 NotUniqueError, OperationError) | ||||
| from mongoengine.python_support import IS_PYMONGO_3 | ||||
| from mongoengine.queryset import transform | ||||
| from mongoengine.queryset.field_list import QueryFieldList | ||||
| @@ -83,6 +82,7 @@ class BaseQuerySet(object): | ||||
|         self._limit = None | ||||
|         self._skip = None | ||||
|         self._hint = -1  # Using -1 as None is a valid value for hint | ||||
|         self._batch_size = None | ||||
|         self.only_fields = [] | ||||
|         self._max_time_ms = None | ||||
|  | ||||
| @@ -123,9 +123,40 @@ class BaseQuerySet(object): | ||||
|  | ||||
|         return queryset | ||||
|  | ||||
|     def __getitem__(self, key): | ||||
|         """Support skip and limit using getitem and slicing syntax. | ||||
|     def __getstate__(self): | ||||
|         """ | ||||
|         Need for pickling queryset | ||||
|  | ||||
|         See https://github.com/MongoEngine/mongoengine/issues/442 | ||||
|         """ | ||||
|  | ||||
|         obj_dict = self.__dict__.copy() | ||||
|  | ||||
|         # don't picke collection, instead pickle collection params | ||||
|         obj_dict.pop("_collection_obj") | ||||
|  | ||||
|         # don't pickle cursor | ||||
|         obj_dict["_cursor_obj"] = None | ||||
|  | ||||
|         return obj_dict | ||||
|  | ||||
|     def __setstate__(self, obj_dict): | ||||
|         """ | ||||
|         Need for pickling queryset | ||||
|  | ||||
|         See https://github.com/MongoEngine/mongoengine/issues/442 | ||||
|         """ | ||||
|  | ||||
|         obj_dict["_collection_obj"] = obj_dict["_document"]._get_collection() | ||||
|  | ||||
|         # update attributes | ||||
|         self.__dict__.update(obj_dict) | ||||
|  | ||||
|         # forse load cursor | ||||
|         # self._cursor | ||||
|  | ||||
|     def __getitem__(self, key): | ||||
|         """Support skip and limit using getitem and slicing syntax.""" | ||||
|         queryset = self.clone() | ||||
|  | ||||
|         # Slice provided | ||||
| @@ -245,6 +276,8 @@ class BaseQuerySet(object): | ||||
|         except StopIteration: | ||||
|             return result | ||||
|  | ||||
|         # If we were able to retrieve the 2nd doc, rewind the cursor and | ||||
|         # raise the MultipleObjectsReturned exception. | ||||
|         queryset.rewind() | ||||
|         message = u'%d items returned, instead of 1' % queryset.count() | ||||
|         raise queryset._document.MultipleObjectsReturned(message) | ||||
| @@ -266,7 +299,8 @@ class BaseQuerySet(object): | ||||
|             result = None | ||||
|         return result | ||||
|  | ||||
|     def insert(self, doc_or_docs, load_bulk=True, write_concern=None): | ||||
|     def insert(self, doc_or_docs, load_bulk=True, | ||||
|                write_concern=None, signal_kwargs=None): | ||||
|         """bulk insert documents | ||||
|  | ||||
|         :param doc_or_docs: a document or list of documents to be inserted | ||||
| @@ -279,11 +313,15 @@ class BaseQuerySet(object): | ||||
|                 ``insert(..., {w: 2, fsync: True})`` will wait until at least | ||||
|                 two servers have recorded the write and will force an fsync on | ||||
|                 each server being written to. | ||||
|         :parm signal_kwargs: (optional) kwargs dictionary to be passed to | ||||
|             the signal calls. | ||||
|  | ||||
|         By default returns document instances, set ``load_bulk`` to False to | ||||
|         return just ``ObjectIds`` | ||||
|  | ||||
|         .. versionadded:: 0.5 | ||||
|         .. versionchanged:: 0.10.7 | ||||
|             Add signal_kwargs argument | ||||
|         """ | ||||
|         Document = _import_class('Document') | ||||
|  | ||||
| @@ -296,7 +334,6 @@ class BaseQuerySet(object): | ||||
|             return_one = True | ||||
|             docs = [docs] | ||||
|  | ||||
|         raw = [] | ||||
|         for doc in docs: | ||||
|             if not isinstance(doc, self._document): | ||||
|                 msg = ("Some documents inserted aren't instances of %s" | ||||
| @@ -305,9 +342,12 @@ class BaseQuerySet(object): | ||||
|             if doc.pk and not doc._created: | ||||
|                 msg = "Some documents have ObjectIds use doc.update() instead" | ||||
|                 raise OperationError(msg) | ||||
|             raw.append(doc.to_mongo()) | ||||
|  | ||||
|         signals.pre_bulk_insert.send(self._document, documents=docs) | ||||
|         signal_kwargs = signal_kwargs or {} | ||||
|         signals.pre_bulk_insert.send(self._document, | ||||
|                                      documents=docs, **signal_kwargs) | ||||
|  | ||||
|         raw = [doc.to_mongo() for doc in docs] | ||||
|         try: | ||||
|             ids = self._collection.insert(raw, **write_concern) | ||||
|         except pymongo.errors.DuplicateKeyError, err: | ||||
| @@ -324,7 +364,7 @@ class BaseQuerySet(object): | ||||
|  | ||||
|         if not load_bulk: | ||||
|             signals.post_bulk_insert.send( | ||||
|                 self._document, documents=docs, loaded=False) | ||||
|                 self._document, documents=docs, loaded=False, **signal_kwargs) | ||||
|             return return_one and ids[0] or ids | ||||
|  | ||||
|         documents = self.in_bulk(ids) | ||||
| @@ -332,7 +372,7 @@ class BaseQuerySet(object): | ||||
|         for obj_id in ids: | ||||
|             results.append(documents.get(obj_id)) | ||||
|         signals.post_bulk_insert.send( | ||||
|             self._document, documents=results, loaded=True) | ||||
|             self._document, documents=results, loaded=True, **signal_kwargs) | ||||
|         return return_one and results[0] or results | ||||
|  | ||||
|     def count(self, with_limit_and_skip=False): | ||||
| @@ -403,9 +443,11 @@ class BaseQuerySet(object): | ||||
|             rule = doc._meta['delete_rules'][rule_entry] | ||||
|             if rule == CASCADE: | ||||
|                 cascade_refs = set() if cascade_refs is None else cascade_refs | ||||
|                 for ref in queryset: | ||||
|                     cascade_refs.add(ref.id) | ||||
|                 ref_q = document_cls.objects(**{field_name + '__in': self, 'id__nin': cascade_refs}) | ||||
|                 # Handle recursive reference | ||||
|                 if doc._collection == document_cls._collection: | ||||
|                     for ref in queryset: | ||||
|                         cascade_refs.add(ref.id) | ||||
|                 ref_q = document_cls.objects(**{field_name + '__in': self, 'pk__nin': cascade_refs}) | ||||
|                 ref_q_count = ref_q.count() | ||||
|                 if ref_q_count > 0: | ||||
|                     ref_q.delete(write_concern=write_concern, cascade_refs=cascade_refs) | ||||
| @@ -425,7 +467,7 @@ class BaseQuerySet(object): | ||||
|                full_result=False, **update): | ||||
|         """Perform an atomic update on the fields matched by the query. | ||||
|  | ||||
|         :param upsert: Any existing document with that "_id" is overwritten. | ||||
|         :param upsert: insert if document doesn't exist (default ``False``) | ||||
|         :param multi: Update multiple documents. | ||||
|         :param write_concern: Extra keyword arguments are passed down which | ||||
|             will be used as options for the resultant | ||||
| @@ -471,7 +513,6 @@ class BaseQuerySet(object): | ||||
|                 raise OperationError(message) | ||||
|             raise OperationError(u'Update failed (%s)' % unicode(err)) | ||||
|  | ||||
|  | ||||
|     def upsert_one(self, write_concern=None, **update): | ||||
|         """Overwrite or add the first document matched by the query. | ||||
|  | ||||
| @@ -488,8 +529,9 @@ class BaseQuerySet(object): | ||||
|         .. versionadded:: 0.10.2 | ||||
|         """ | ||||
|  | ||||
|         atomic_update = self.update(multi=False, upsert=True, write_concern=write_concern, | ||||
|                              full_result=True,**update) | ||||
|         atomic_update = self.update(multi=False, upsert=True, | ||||
|                                     write_concern=write_concern, | ||||
|                                     full_result=True, **update) | ||||
|  | ||||
|         if atomic_update['updatedExisting']: | ||||
|             document = self.get() | ||||
| @@ -501,7 +543,7 @@ class BaseQuerySet(object): | ||||
|         """Perform an atomic update on the fields of the first document | ||||
|         matched by the query. | ||||
|  | ||||
|         :param upsert: Any existing document with that "_id" is overwritten. | ||||
|         :param upsert: insert if document doesn't exist (default ``False``) | ||||
|         :param write_concern: Extra keyword arguments are passed down which | ||||
|             will be used as options for the resultant | ||||
|             ``getLastError`` command.  For example, | ||||
| @@ -742,6 +784,19 @@ class BaseQuerySet(object): | ||||
|         queryset._hint = index | ||||
|         return queryset | ||||
|  | ||||
|     def batch_size(self, size): | ||||
|         """Limit the number of documents returned in a single batch (each | ||||
|         batch requires a round trip to the server). | ||||
|  | ||||
|         See http://api.mongodb.com/python/current/api/pymongo/cursor.html#pymongo.cursor.Cursor.batch_size | ||||
|         for details. | ||||
|  | ||||
|         :param size: desired size of each batch. | ||||
|         """ | ||||
|         queryset = self.clone() | ||||
|         queryset._batch_size = size | ||||
|         return queryset | ||||
|  | ||||
|     def distinct(self, field): | ||||
|         """Return a list of distinct values for a given field. | ||||
|  | ||||
| @@ -894,6 +949,14 @@ class BaseQuerySet(object): | ||||
|         queryset._ordering = queryset._get_order_by(keys) | ||||
|         return queryset | ||||
|  | ||||
|     def comment(self, text): | ||||
|         """Add a comment to the query. | ||||
|  | ||||
|         See https://docs.mongodb.com/manual/reference/method/cursor.comment/#cursor.comment | ||||
|         for details. | ||||
|         """ | ||||
|         return self._chainable_method("comment", text) | ||||
|  | ||||
|     def explain(self, format=False): | ||||
|         """Return an explain plan record for the | ||||
|         :class:`~mongoengine.queryset.QuerySet`\ 's cursor. | ||||
| @@ -1229,66 +1292,29 @@ class BaseQuerySet(object): | ||||
|     def sum(self, field): | ||||
|         """Sum over the values of the specified field. | ||||
|  | ||||
|         :param field: the field to sum over; use dot-notation to refer to | ||||
|         :param field: the field to sum over; use dot notation to refer to | ||||
|             embedded document fields | ||||
|  | ||||
|         .. versionchanged:: 0.5 - updated to map_reduce as db.eval doesnt work | ||||
|             with sharding. | ||||
|         """ | ||||
|         map_func = """ | ||||
|             function() { | ||||
|                 var path = '{{~%(field)s}}'.split('.'), | ||||
|                 field = this; | ||||
|  | ||||
|                 for (p in path) { | ||||
|                     if (typeof field != 'undefined') | ||||
|                        field = field[path[p]]; | ||||
|                     else | ||||
|                        break; | ||||
|                 } | ||||
|  | ||||
|                 if (field && field.constructor == Array) { | ||||
|                     field.forEach(function(item) { | ||||
|                         emit(1, item||0); | ||||
|                     }); | ||||
|                 } else if (typeof field != 'undefined') { | ||||
|                     emit(1, field||0); | ||||
|                 } | ||||
|             } | ||||
|         """ % dict(field=field) | ||||
|  | ||||
|         reduce_func = Code(""" | ||||
|             function(key, values) { | ||||
|                 var sum = 0; | ||||
|                 for (var i in values) { | ||||
|                     sum += values[i]; | ||||
|                 } | ||||
|                 return sum; | ||||
|             } | ||||
|         """) | ||||
|  | ||||
|         for result in self.map_reduce(map_func, reduce_func, output='inline'): | ||||
|             return result.value | ||||
|         else: | ||||
|             return 0 | ||||
|  | ||||
|     def aggregate_sum(self, field): | ||||
|         """Sum over the values of the specified field. | ||||
|  | ||||
|         :param field: the field to sum over; use dot-notation to refer to | ||||
|             embedded document fields | ||||
|  | ||||
|         This method is more performant than the regular `sum`, because it uses | ||||
|         the aggregation framework instead of map-reduce. | ||||
|         """ | ||||
|         result = self._document._get_collection().aggregate([ | ||||
|         db_field = self._fields_to_dbfields([field]).pop() | ||||
|         pipeline = [ | ||||
|             {'$match': self._query}, | ||||
|             {'$group': {'_id': 'sum', 'total': {'$sum': '$' + field}}} | ||||
|         ]) | ||||
|             {'$group': {'_id': 'sum', 'total': {'$sum': '$' + db_field}}} | ||||
|         ] | ||||
|  | ||||
|         # if we're performing a sum over a list field, we sum up all the | ||||
|         # elements in the list, hence we need to $unwind the arrays first | ||||
|         ListField = _import_class('ListField') | ||||
|         field_parts = field.split('.') | ||||
|         field_instances = self._document._lookup_field(field_parts) | ||||
|         if isinstance(field_instances[-1], ListField): | ||||
|             pipeline.insert(1, {'$unwind': '$' + field}) | ||||
|  | ||||
|         result = self._document._get_collection().aggregate(pipeline) | ||||
|         if IS_PYMONGO_3: | ||||
|             result = list(result) | ||||
|             result = tuple(result) | ||||
|         else: | ||||
|             result = result.get('result') | ||||
|  | ||||
|         if result: | ||||
|             return result[0]['total'] | ||||
|         return 0 | ||||
| @@ -1296,73 +1322,27 @@ class BaseQuerySet(object): | ||||
|     def average(self, field): | ||||
|         """Average over the values of the specified field. | ||||
|  | ||||
|         :param field: the field to average over; use dot-notation to refer to | ||||
|         :param field: the field to average over; use dot notation to refer to | ||||
|             embedded document fields | ||||
|  | ||||
|         .. versionchanged:: 0.5 - updated to map_reduce as db.eval doesnt work | ||||
|             with sharding. | ||||
|         """ | ||||
|         map_func = """ | ||||
|             function() { | ||||
|                 var path = '{{~%(field)s}}'.split('.'), | ||||
|                 field = this; | ||||
|  | ||||
|                 for (p in path) { | ||||
|                     if (typeof field != 'undefined') | ||||
|                        field = field[path[p]]; | ||||
|                     else | ||||
|                        break; | ||||
|                 } | ||||
|  | ||||
|                 if (field && field.constructor == Array) { | ||||
|                     field.forEach(function(item) { | ||||
|                         emit(1, {t: item||0, c: 1}); | ||||
|                     }); | ||||
|                 } else if (typeof field != 'undefined') { | ||||
|                     emit(1, {t: field||0, c: 1}); | ||||
|                 } | ||||
|             } | ||||
|         """ % dict(field=field) | ||||
|  | ||||
|         reduce_func = Code(""" | ||||
|             function(key, values) { | ||||
|                 var out = {t: 0, c: 0}; | ||||
|                 for (var i in values) { | ||||
|                     var value = values[i]; | ||||
|                     out.t += value.t; | ||||
|                     out.c += value.c; | ||||
|                 } | ||||
|                 return out; | ||||
|             } | ||||
|         """) | ||||
|  | ||||
|         finalize_func = Code(""" | ||||
|             function(key, value) { | ||||
|                 return value.t / value.c; | ||||
|             } | ||||
|         """) | ||||
|  | ||||
|         for result in self.map_reduce(map_func, reduce_func, | ||||
|                                       finalize_f=finalize_func, output='inline'): | ||||
|             return result.value | ||||
|         else: | ||||
|             return 0 | ||||
|  | ||||
|     def aggregate_average(self, field): | ||||
|         """Average over the values of the specified field. | ||||
|  | ||||
|         :param field: the field to average over; use dot-notation to refer to | ||||
|             embedded document fields | ||||
|  | ||||
|         This method is more performant than the regular `average`, because it | ||||
|         uses the aggregation framework instead of map-reduce. | ||||
|         """ | ||||
|         result = self._document._get_collection().aggregate([ | ||||
|         db_field = self._fields_to_dbfields([field]).pop() | ||||
|         pipeline = [ | ||||
|             {'$match': self._query}, | ||||
|             {'$group': {'_id': 'avg', 'total': {'$avg': '$' + field}}} | ||||
|         ]) | ||||
|             {'$group': {'_id': 'avg', 'total': {'$avg': '$' + db_field}}} | ||||
|         ] | ||||
|  | ||||
|         # if we're performing an average over a list field, we average out | ||||
|         # all the elements in the list, hence we need to $unwind the arrays | ||||
|         # first | ||||
|         ListField = _import_class('ListField') | ||||
|         field_parts = field.split('.') | ||||
|         field_instances = self._document._lookup_field(field_parts) | ||||
|         if isinstance(field_instances[-1], ListField): | ||||
|             pipeline.insert(1, {'$unwind': '$' + field}) | ||||
|  | ||||
|         result = self._document._get_collection().aggregate(pipeline) | ||||
|         if IS_PYMONGO_3: | ||||
|             result = list(result) | ||||
|             result = tuple(result) | ||||
|         else: | ||||
|             result = result.get('result') | ||||
|         if result: | ||||
| @@ -1379,7 +1359,7 @@ class BaseQuerySet(object): | ||||
|             Can only do direct simple mappings and cannot map across | ||||
|             :class:`~mongoengine.fields.ReferenceField` or | ||||
|             :class:`~mongoengine.fields.GenericReferenceField` for more complex | ||||
|             counting a manual map reduce call would is required. | ||||
|             counting a manual map reduce call is required. | ||||
|  | ||||
|         If the field is a :class:`~mongoengine.fields.ListField`, the items within | ||||
|         each list will be counted individually. | ||||
| @@ -1453,7 +1433,7 @@ class BaseQuerySet(object): | ||||
|                 msg = "The snapshot option is not anymore available with PyMongo 3+" | ||||
|                 warnings.warn(msg, DeprecationWarning) | ||||
|             cursor_args = { | ||||
|                 'no_cursor_timeout': self._timeout | ||||
|                 'no_cursor_timeout': not self._timeout | ||||
|             } | ||||
|         if self._loaded_fields: | ||||
|             cursor_args[fields_name] = self._loaded_fields.as_dict() | ||||
| @@ -1503,6 +1483,9 @@ class BaseQuerySet(object): | ||||
|             if self._hint != -1: | ||||
|                 self._cursor_obj.hint(self._hint) | ||||
|  | ||||
|             if self._batch_size is not None: | ||||
|                 self._cursor_obj.batch_size(self._batch_size) | ||||
|  | ||||
|         return self._cursor_obj | ||||
|  | ||||
|     def __deepcopy__(self, memo): | ||||
| @@ -1696,7 +1679,7 @@ class BaseQuerySet(object): | ||||
|             key = key.replace('__', '.') | ||||
|             try: | ||||
|                 key = self._document._translate_field_name(key) | ||||
|             except: | ||||
|             except Exception: | ||||
|                 pass | ||||
|             key_list.append((key, direction)) | ||||
|  | ||||
|   | ||||
| @@ -29,7 +29,7 @@ class QuerySetManager(object): | ||||
|         Document.objects is accessed. | ||||
|         """ | ||||
|         if instance is not None: | ||||
|             # Document class being used rather than a document object | ||||
|             # Document object being used rather than a document class | ||||
|             return self | ||||
|  | ||||
|         # owner is the document that contains the QuerySetManager | ||||
|   | ||||
| @@ -1,6 +1,6 @@ | ||||
| from mongoengine.errors import OperationError | ||||
| from mongoengine.queryset.base import (BaseQuerySet, DO_NOTHING, NULLIFY, | ||||
|                                        CASCADE, DENY, PULL) | ||||
| from mongoengine.queryset.base import (BaseQuerySet, CASCADE, DENY, DO_NOTHING, | ||||
|                                        NULLIFY, PULL) | ||||
|  | ||||
| __all__ = ('QuerySet', 'QuerySetNoCache', 'DO_NOTHING', 'NULLIFY', 'CASCADE', | ||||
|            'DENY', 'PULL') | ||||
| @@ -27,9 +27,10 @@ class QuerySet(BaseQuerySet): | ||||
|         in batches of ``ITER_CHUNK_SIZE``. | ||||
|  | ||||
|         If ``self._has_more`` the cursor hasn't been exhausted so cache then | ||||
|         batch.  Otherwise iterate the result_cache. | ||||
|         batch. Otherwise iterate the result_cache. | ||||
|         """ | ||||
|         self._iter = True | ||||
|  | ||||
|         if self._has_more: | ||||
|             return self._iter_results() | ||||
|  | ||||
| @@ -38,14 +39,16 @@ class QuerySet(BaseQuerySet): | ||||
|  | ||||
|     def __len__(self): | ||||
|         """Since __len__ is called quite frequently (for example, as part of | ||||
|         list(qs) we populate the result cache and cache the length. | ||||
|         list(qs)), we populate the result cache and cache the length. | ||||
|         """ | ||||
|         if self._len is not None: | ||||
|             return self._len | ||||
|  | ||||
|         # Populate the result cache with *all* of the docs in the cursor | ||||
|         if self._has_more: | ||||
|             # populate the cache | ||||
|             list(self._iter_results()) | ||||
|  | ||||
|         # Cache the length of the complete result cache and return it | ||||
|         self._len = len(self._result_cache) | ||||
|         return self._len | ||||
|  | ||||
| @@ -64,18 +67,33 @@ class QuerySet(BaseQuerySet): | ||||
|     def _iter_results(self): | ||||
|         """A generator for iterating over the result cache. | ||||
|  | ||||
|         Also populates the cache if there are more possible results to yield. | ||||
|         Raises StopIteration when there are no more results""" | ||||
|         Also populates the cache if there are more possible results to | ||||
|         yield. Raises StopIteration when there are no more results. | ||||
|         """ | ||||
|         if self._result_cache is None: | ||||
|             self._result_cache = [] | ||||
|  | ||||
|         pos = 0 | ||||
|         while True: | ||||
|             upper = len(self._result_cache) | ||||
|             while pos < upper: | ||||
|  | ||||
|             # For all positions lower than the length of the current result | ||||
|             # cache, serve the docs straight from the cache w/o hitting the | ||||
|             # database. | ||||
|             # XXX it's VERY important to compute the len within the `while` | ||||
|             # condition because the result cache might expand mid-iteration | ||||
|             # (e.g. if we call len(qs) inside a loop that iterates over the | ||||
|             # queryset). Fortunately len(list) is O(1) in Python, so this | ||||
|             # doesn't cause performance issues. | ||||
|             while pos < len(self._result_cache): | ||||
|                 yield self._result_cache[pos] | ||||
|                 pos += 1 | ||||
|  | ||||
|             # Raise StopIteration if we already established there were no more | ||||
|             # docs in the db cursor. | ||||
|             if not self._has_more: | ||||
|                 raise StopIteration | ||||
|  | ||||
|             # Otherwise, populate more of the cache and repeat. | ||||
|             if len(self._result_cache) <= pos: | ||||
|                 self._populate_cache() | ||||
|  | ||||
| @@ -86,12 +104,22 @@ class QuerySet(BaseQuerySet): | ||||
|         """ | ||||
|         if self._result_cache is None: | ||||
|             self._result_cache = [] | ||||
|         if self._has_more: | ||||
|             try: | ||||
|                 for i in xrange(ITER_CHUNK_SIZE): | ||||
|                     self._result_cache.append(self.next()) | ||||
|             except StopIteration: | ||||
|                 self._has_more = False | ||||
|  | ||||
|         # Skip populating the cache if we already established there are no | ||||
|         # more docs to pull from the database. | ||||
|         if not self._has_more: | ||||
|             return | ||||
|  | ||||
|         # Pull in ITER_CHUNK_SIZE docs from the database and store them in | ||||
|         # the result cache. | ||||
|         try: | ||||
|             for i in xrange(ITER_CHUNK_SIZE): | ||||
|                 self._result_cache.append(self.next()) | ||||
|         except StopIteration: | ||||
|             # Getting this exception means there are no more docs in the | ||||
|             # db cursor. Set _has_more to False so that we can use that | ||||
|             # information in other places. | ||||
|             self._has_more = False | ||||
|  | ||||
|     def count(self, with_limit_and_skip=False): | ||||
|         """Count the selected elements in the query. | ||||
|   | ||||
| @@ -1,11 +1,12 @@ | ||||
| from collections import defaultdict | ||||
|  | ||||
| from bson import ObjectId, SON | ||||
| from bson.dbref import DBRef | ||||
| import pymongo | ||||
| from bson import SON | ||||
|  | ||||
| from mongoengine.base.fields import UPDATE_OPERATORS | ||||
| from mongoengine.connection import get_connection | ||||
| from mongoengine.common import _import_class | ||||
| from mongoengine.connection import get_connection | ||||
| from mongoengine.errors import InvalidQueryError | ||||
| from mongoengine.python_support import IS_PYMONGO_3 | ||||
|  | ||||
| @@ -26,6 +27,7 @@ MATCH_OPERATORS = (COMPARISON_OPERATORS + GEO_OPERATORS + | ||||
|                    STRING_OPERATORS + CUSTOM_OPERATORS) | ||||
|  | ||||
|  | ||||
| # TODO make this less complex | ||||
| def query(_doc_cls=None, **kwargs): | ||||
|     """Transform a query from Django-style format to Mongo format. | ||||
|     """ | ||||
| @@ -44,7 +46,7 @@ def query(_doc_cls=None, **kwargs): | ||||
|         if len(parts) > 1 and parts[-1] in MATCH_OPERATORS: | ||||
|             op = parts.pop() | ||||
|  | ||||
|         # Allw to escape operator-like field name by __ | ||||
|         # Allow to escape operator-like field name by __ | ||||
|         if len(parts) > 1 and parts[-1] == "": | ||||
|             parts.pop() | ||||
|  | ||||
| @@ -62,6 +64,7 @@ def query(_doc_cls=None, **kwargs): | ||||
|             parts = [] | ||||
|  | ||||
|             CachedReferenceField = _import_class('CachedReferenceField') | ||||
|             GenericReferenceField = _import_class('GenericReferenceField') | ||||
|  | ||||
|             cleaned_fields = [] | ||||
|             for field in fields: | ||||
| @@ -101,6 +104,16 @@ def query(_doc_cls=None, **kwargs): | ||||
|                 # 'in', 'nin' and 'all' require a list of values | ||||
|                 value = [field.prepare_query_value(op, v) for v in value] | ||||
|  | ||||
|             # If we're querying a GenericReferenceField, we need to alter the | ||||
|             # key depending on the value: | ||||
|             # * If the value is a DBRef, the key should be "field_name._ref". | ||||
|             # * If the value is an ObjectId, the key should be "field_name._ref.$id". | ||||
|             if isinstance(field, GenericReferenceField): | ||||
|                 if isinstance(value, DBRef): | ||||
|                     parts[-1] += '._ref' | ||||
|                 elif isinstance(value, ObjectId): | ||||
|                     parts[-1] += '._ref.$id' | ||||
|  | ||||
|         # if op and op not in COMPARISON_OPERATORS: | ||||
|         if op: | ||||
|             if op in GEO_OPERATORS: | ||||
| @@ -108,8 +121,11 @@ def query(_doc_cls=None, **kwargs): | ||||
|             elif op in ('match', 'elemMatch'): | ||||
|                 ListField = _import_class('ListField') | ||||
|                 EmbeddedDocumentField = _import_class('EmbeddedDocumentField') | ||||
|                 if (isinstance(value, dict) and isinstance(field, ListField) and | ||||
|                     isinstance(field.field, EmbeddedDocumentField)): | ||||
|                 if ( | ||||
|                     isinstance(value, dict) and | ||||
|                     isinstance(field, ListField) and | ||||
|                     isinstance(field.field, EmbeddedDocumentField) | ||||
|                 ): | ||||
|                     value = query(field.field.document_type, **value) | ||||
|                 else: | ||||
|                     value = field.prepare_query_value(op, value) | ||||
| @@ -125,11 +141,13 @@ def query(_doc_cls=None, **kwargs): | ||||
|  | ||||
|         for i, part in indices: | ||||
|             parts.insert(i, part) | ||||
|  | ||||
|         key = '.'.join(parts) | ||||
|  | ||||
|         if op is None or key not in mongo_query: | ||||
|             mongo_query[key] = value | ||||
|         elif key in mongo_query: | ||||
|             if key in mongo_query and isinstance(mongo_query[key], dict): | ||||
|             if isinstance(mongo_query[key], dict): | ||||
|                 mongo_query[key].update(value) | ||||
|                 # $max/minDistance needs to come last - convert to SON | ||||
|                 value_dict = mongo_query[key] | ||||
| @@ -212,6 +230,10 @@ def update(_doc_cls=None, **update): | ||||
|         if parts[-1] in COMPARISON_OPERATORS: | ||||
|             match = parts.pop() | ||||
|  | ||||
|         # Allow to escape operator-like field name by __ | ||||
|         if len(parts) > 1 and parts[-1] == "": | ||||
|             parts.pop() | ||||
|  | ||||
|         if _doc_cls: | ||||
|             # Switch field names to proper names [set in Field(name='foo')] | ||||
|             try: | ||||
| @@ -364,20 +386,24 @@ def _infer_geometry(value): | ||||
|                                 "type and coordinates keys") | ||||
|     elif isinstance(value, (list, set)): | ||||
|         # TODO: shouldn't we test value[0][0][0][0] to see if it is MultiPolygon? | ||||
|         # TODO: should both TypeError and IndexError be alike interpreted? | ||||
|  | ||||
|         try: | ||||
|             value[0][0][0] | ||||
|             return {"$geometry": {"type": "Polygon", "coordinates": value}} | ||||
|         except: | ||||
|         except (TypeError, IndexError): | ||||
|             pass | ||||
|  | ||||
|         try: | ||||
|             value[0][0] | ||||
|             return {"$geometry": {"type": "LineString", "coordinates": value}} | ||||
|         except: | ||||
|         except (TypeError, IndexError): | ||||
|             pass | ||||
|  | ||||
|         try: | ||||
|             value[0] | ||||
|             return {"$geometry": {"type": "Point", "coordinates": value}} | ||||
|         except: | ||||
|         except (TypeError, IndexError): | ||||
|             pass | ||||
|  | ||||
|     raise InvalidQueryError("Invalid $geometry data. Can be either a dictionary " | ||||
|   | ||||
| @@ -29,7 +29,7 @@ except ImportError: | ||||
|                                'because the blinker library is ' | ||||
|                                'not installed.') | ||||
|  | ||||
|         send = lambda *a, **kw: None | ||||
|         send = lambda *a, **kw: None  # noqa | ||||
|         connect = disconnect = has_receivers_for = receivers_for = \ | ||||
|             temporarily_connected_to = _fail | ||||
|         del _fail | ||||
|   | ||||
| @@ -1,2 +1,5 @@ | ||||
| pymongo>=2.7.1 | ||||
| nose | ||||
| pymongo>=2.7.1 | ||||
| six==1.10.0 | ||||
| flake8 | ||||
| flake8-import-order | ||||
|   | ||||
| @@ -1,8 +1,13 @@ | ||||
| [nosetests] | ||||
| rednose = 1 | ||||
| verbosity = 2 | ||||
| detailed-errors = 1 | ||||
| cover-erase = 1 | ||||
| cover-branches = 1 | ||||
| cover-package = mongoengine | ||||
| tests = tests | ||||
|  | ||||
| [flake8] | ||||
| ignore=E501,F401,F403,F405,I201 | ||||
| exclude=build,dist,docs,venv,.tox,.eggs,tests | ||||
| max-complexity=45 | ||||
| application-import-names=mongoengine,tests | ||||
|   | ||||
							
								
								
									
										57
									
								
								setup.py
									
									
									
									
									
								
							
							
						
						
									
										57
									
								
								setup.py
									
									
									
									
									
								
							| @@ -1,6 +1,6 @@ | ||||
| import os | ||||
| import sys | ||||
| from setuptools import setup, find_packages | ||||
| from setuptools import find_packages, setup | ||||
|  | ||||
| # Hack to silence atexit traceback in newer python versions | ||||
| try: | ||||
| @@ -8,13 +8,16 @@ try: | ||||
| except ImportError: | ||||
|     pass | ||||
|  | ||||
| DESCRIPTION = 'MongoEngine is a Python Object-Document ' + \ | ||||
| 'Mapper for working with MongoDB.' | ||||
| LONG_DESCRIPTION = None | ||||
| DESCRIPTION = ( | ||||
|     'MongoEngine is a Python Object-Document ' | ||||
|     'Mapper for working with MongoDB.' | ||||
| ) | ||||
|  | ||||
| try: | ||||
|     LONG_DESCRIPTION = open('README.rst').read() | ||||
| except: | ||||
|     pass | ||||
|     with open('README.rst') as fin: | ||||
|         LONG_DESCRIPTION = fin.read() | ||||
| except Exception: | ||||
|     LONG_DESCRIPTION = None | ||||
|  | ||||
|  | ||||
| def get_version(version_tuple): | ||||
| @@ -22,6 +25,7 @@ def get_version(version_tuple): | ||||
|         return '.'.join(map(str, version_tuple[:-1])) + version_tuple[-1] | ||||
|     return '.'.join(map(str, version_tuple)) | ||||
|  | ||||
|  | ||||
| # Dirty hack to get version number from monogengine/__init__.py - we can't | ||||
| # import it as it depends on PyMongo and PyMongo isn't installed until this | ||||
| # file is read | ||||
| @@ -52,32 +56,33 @@ CLASSIFIERS = [ | ||||
| extra_opts = {"packages": find_packages(exclude=["tests", "tests.*"])} | ||||
| if sys.version_info[0] == 3: | ||||
|     extra_opts['use_2to3'] = True | ||||
|     extra_opts['tests_require'] = ['nose', 'rednose', 'coverage==3.7.1', 'blinker', 'Pillow>=2.0.0'] | ||||
|     extra_opts['tests_require'] = ['nose', 'coverage==3.7.1', 'blinker', 'Pillow>=2.0.0'] | ||||
|     if "test" in sys.argv or "nosetests" in sys.argv: | ||||
|         extra_opts['packages'] = find_packages() | ||||
|         extra_opts['package_data'] = {"tests": ["fields/mongoengine.png", "fields/mongodb_leaf.png"]} | ||||
| else: | ||||
|     # coverage 4 does not support Python 3.2 anymore | ||||
|     extra_opts['tests_require'] = ['nose', 'rednose', 'coverage==3.7.1', 'blinker', 'Pillow>=2.0.0', 'python-dateutil'] | ||||
|     extra_opts['tests_require'] = ['nose', 'coverage==3.7.1', 'blinker', 'Pillow>=2.0.0', 'python-dateutil'] | ||||
|  | ||||
|     if sys.version_info[0] == 2 and sys.version_info[1] == 6: | ||||
|         extra_opts['tests_require'].append('unittest2') | ||||
|  | ||||
| setup(name='mongoengine', | ||||
|       version=VERSION, | ||||
|       author='Harry Marr', | ||||
|       author_email='harry.marr@{nospam}gmail.com', | ||||
|       maintainer="Ross Lawley", | ||||
|       maintainer_email="ross.lawley@{nospam}gmail.com", | ||||
|       url='http://mongoengine.org/', | ||||
|       download_url='https://github.com/MongoEngine/mongoengine/tarball/master', | ||||
|       license='MIT', | ||||
|       include_package_data=True, | ||||
|       description=DESCRIPTION, | ||||
|       long_description=LONG_DESCRIPTION, | ||||
|       platforms=['any'], | ||||
|       classifiers=CLASSIFIERS, | ||||
|       install_requires=['pymongo>=2.7.1'], | ||||
|       test_suite='nose.collector', | ||||
|       **extra_opts | ||||
| setup( | ||||
|     name='mongoengine', | ||||
|     version=VERSION, | ||||
|     author='Harry Marr', | ||||
|     author_email='harry.marr@{nospam}gmail.com', | ||||
|     maintainer="Ross Lawley", | ||||
|     maintainer_email="ross.lawley@{nospam}gmail.com", | ||||
|     url='http://mongoengine.org/', | ||||
|     download_url='https://github.com/MongoEngine/mongoengine/tarball/master', | ||||
|     license='MIT', | ||||
|     include_package_data=True, | ||||
|     description=DESCRIPTION, | ||||
|     long_description=LONG_DESCRIPTION, | ||||
|     platforms=['any'], | ||||
|     classifiers=CLASSIFIERS, | ||||
|     install_requires=['pymongo>=2.7.1', 'six'], | ||||
|     test_suite='nose.collector', | ||||
|     **extra_opts | ||||
| ) | ||||
|   | ||||
| @@ -2,7 +2,6 @@ | ||||
| import unittest | ||||
| import sys | ||||
|  | ||||
| sys.path[0:0] = [""] | ||||
|  | ||||
| import pymongo | ||||
|  | ||||
| @@ -32,10 +31,7 @@ class IndexesTest(unittest.TestCase): | ||||
|         self.Person = Person | ||||
|  | ||||
|     def tearDown(self): | ||||
|         for collection in self.db.collection_names(): | ||||
|             if 'system.' in collection: | ||||
|                 continue | ||||
|             self.db.drop_collection(collection) | ||||
|         self.connection.drop_database(self.db) | ||||
|  | ||||
|     def test_indexes_document(self): | ||||
|         """Ensure that indexes are used when meta[indexes] is specified for | ||||
| @@ -822,33 +818,34 @@ class IndexesTest(unittest.TestCase): | ||||
|             name = StringField(required=True) | ||||
|             term = StringField(required=True) | ||||
|  | ||||
|         class Report(Document): | ||||
|         class ReportEmbedded(Document): | ||||
|             key = EmbeddedDocumentField(CompoundKey, primary_key=True) | ||||
|             text = StringField() | ||||
|  | ||||
|         Report.drop_collection() | ||||
|  | ||||
|         my_key = CompoundKey(name="n", term="ok") | ||||
|         report = Report(text="OK", key=my_key).save() | ||||
|         report = ReportEmbedded(text="OK", key=my_key).save() | ||||
|  | ||||
|         self.assertEqual({'text': 'OK', '_id': {'term': 'ok', 'name': 'n'}}, | ||||
|                          report.to_mongo()) | ||||
|         self.assertEqual(report, Report.objects.get(pk=my_key)) | ||||
|         self.assertEqual(report, ReportEmbedded.objects.get(pk=my_key)) | ||||
|  | ||||
|     def test_compound_key_dictfield(self): | ||||
|  | ||||
|         class Report(Document): | ||||
|         class ReportDictField(Document): | ||||
|             key = DictField(primary_key=True) | ||||
|             text = StringField() | ||||
|  | ||||
|         Report.drop_collection() | ||||
|  | ||||
|         my_key = {"name": "n", "term": "ok"} | ||||
|         report = Report(text="OK", key=my_key).save() | ||||
|         report = ReportDictField(text="OK", key=my_key).save() | ||||
|  | ||||
|         self.assertEqual({'text': 'OK', '_id': {'term': 'ok', 'name': 'n'}}, | ||||
|                          report.to_mongo()) | ||||
|         self.assertEqual(report, Report.objects.get(pk=my_key)) | ||||
|  | ||||
|         # We can't directly call ReportDictField.objects.get(pk=my_key), | ||||
|         # because dicts are unordered, and if the order in MongoDB is | ||||
|         # different than the one in `my_key`, this test will fail. | ||||
|         self.assertEqual(report, ReportDictField.objects.get(pk__name=my_key['name'])) | ||||
|         self.assertEqual(report, ReportDictField.objects.get(pk__term=my_key['term'])) | ||||
|  | ||||
|     def test_string_indexes(self): | ||||
|  | ||||
| @@ -909,26 +906,38 @@ class IndexesTest(unittest.TestCase): | ||||
|  | ||||
|         Issue #812 | ||||
|         """ | ||||
|         # Use a new connection and database since dropping the database could | ||||
|         # cause concurrent tests to fail. | ||||
|         connection = connect(db='tempdatabase', | ||||
|                              alias='test_indexes_after_database_drop') | ||||
|  | ||||
|         class BlogPost(Document): | ||||
|             title = StringField() | ||||
|             slug = StringField(unique=True) | ||||
|  | ||||
|         BlogPost.drop_collection() | ||||
|             meta = {'db_alias': 'test_indexes_after_database_drop'} | ||||
|  | ||||
|         # Create Post #1 | ||||
|         post1 = BlogPost(title='test1', slug='test') | ||||
|         post1.save() | ||||
|         try: | ||||
|             BlogPost.drop_collection() | ||||
|  | ||||
|         # Drop the Database | ||||
|         self.connection.drop_database(BlogPost._get_db().name) | ||||
|             # Create Post #1 | ||||
|             post1 = BlogPost(title='test1', slug='test') | ||||
|             post1.save() | ||||
|  | ||||
|         # Re-create Post #1 | ||||
|         post1 = BlogPost(title='test1', slug='test') | ||||
|         post1.save() | ||||
|             # Drop the Database | ||||
|             connection.drop_database('tempdatabase') | ||||
|  | ||||
|             # Re-create Post #1 | ||||
|             post1 = BlogPost(title='test1', slug='test') | ||||
|             post1.save() | ||||
|  | ||||
|             # Create Post #2 | ||||
|             post2 = BlogPost(title='test2', slug='test') | ||||
|             self.assertRaises(NotUniqueError, post2.save) | ||||
|         finally: | ||||
|             # Drop the temporary database at the end | ||||
|             connection.drop_database('tempdatabase') | ||||
|  | ||||
|         # Create Post #2 | ||||
|         post2 = BlogPost(title='test2', slug='test') | ||||
|         self.assertRaises(NotUniqueError, post2.save) | ||||
|  | ||||
|     def test_index_dont_send_cls_option(self): | ||||
|         """ | ||||
|   | ||||
| @@ -411,7 +411,7 @@ class InheritanceTest(unittest.TestCase): | ||||
|         try: | ||||
|             class MyDocument(DateCreatedDocument, DateUpdatedDocument): | ||||
|                 pass | ||||
|         except: | ||||
|         except Exception: | ||||
|             self.assertTrue(False, "Couldn't create MyDocument class") | ||||
|  | ||||
|     def test_abstract_documents(self): | ||||
|   | ||||
| @@ -13,7 +13,7 @@ from datetime import datetime | ||||
| from bson import DBRef, ObjectId | ||||
| from tests import fixtures | ||||
| from tests.fixtures import (PickleEmbedded, PickleTest, PickleSignalsTest, | ||||
|                             PickleDyanmicEmbedded, PickleDynamicTest) | ||||
|                             PickleDynamicEmbedded, PickleDynamicTest) | ||||
|  | ||||
| from mongoengine import * | ||||
| from mongoengine.errors import (NotRegistered, InvalidDocumentError, | ||||
| @@ -679,6 +679,19 @@ class InstanceTest(unittest.TestCase): | ||||
|         doc = Doc.objects.get() | ||||
|         self.assertHasInstance(doc.embedded_field[0], doc) | ||||
|  | ||||
|     def test_embedded_document_complex_instance_no_use_db_field(self): | ||||
|         """Ensure that use_db_field is propagated to list of Emb Docs | ||||
|         """ | ||||
|         class Embedded(EmbeddedDocument): | ||||
|             string = StringField(db_field='s') | ||||
|  | ||||
|         class Doc(Document): | ||||
|             embedded_field = ListField(EmbeddedDocumentField(Embedded)) | ||||
|  | ||||
|         d = Doc(embedded_field=[Embedded(string="Hi")]).to_mongo( | ||||
|             use_db_field=False).to_dict() | ||||
|         self.assertEqual(d['embedded_field'], [{'string': 'Hi'}]) | ||||
|  | ||||
|     def test_instance_is_set_on_setattr(self): | ||||
|  | ||||
|         class Email(EmbeddedDocument): | ||||
| @@ -1893,6 +1906,62 @@ class InstanceTest(unittest.TestCase): | ||||
|         author.delete() | ||||
|         self.assertEqual(BlogPost.objects.count(), 0) | ||||
|  | ||||
|     def test_reverse_delete_rule_with_custom_id_field(self): | ||||
|         """Ensure that a referenced document with custom primary key | ||||
|         is also deleted upon deletion. | ||||
|         """ | ||||
|         class User(Document): | ||||
|             name = StringField(primary_key=True) | ||||
|  | ||||
|         class Book(Document): | ||||
|             author = ReferenceField(User, reverse_delete_rule=CASCADE) | ||||
|             reviewer = ReferenceField(User, reverse_delete_rule=NULLIFY) | ||||
|  | ||||
|         User.drop_collection() | ||||
|         Book.drop_collection() | ||||
|  | ||||
|         user = User(name='Mike').save() | ||||
|         reviewer = User(name='John').save() | ||||
|         book = Book(author=user, reviewer=reviewer).save() | ||||
|  | ||||
|         reviewer.delete() | ||||
|         self.assertEqual(Book.objects.count(), 1) | ||||
|         self.assertEqual(Book.objects.get().reviewer, None) | ||||
|  | ||||
|         user.delete() | ||||
|         self.assertEqual(Book.objects.count(), 0) | ||||
|  | ||||
|     def test_reverse_delete_rule_with_shared_id_among_collections(self): | ||||
|         """Ensure that cascade delete rule doesn't mix id among collections. | ||||
|         """ | ||||
|         class User(Document): | ||||
|             id = IntField(primary_key=True) | ||||
|  | ||||
|         class Book(Document): | ||||
|             id = IntField(primary_key=True) | ||||
|             author = ReferenceField(User, reverse_delete_rule=CASCADE) | ||||
|  | ||||
|         User.drop_collection() | ||||
|         Book.drop_collection() | ||||
|  | ||||
|         user_1 = User(id=1).save() | ||||
|         user_2 = User(id=2).save() | ||||
|         book_1 = Book(id=1, author=user_2).save() | ||||
|         book_2 = Book(id=2, author=user_1).save() | ||||
|  | ||||
|         user_2.delete() | ||||
|         # Deleting user_2 should also delete book_1 but not book_2 | ||||
|         self.assertEqual(Book.objects.count(), 1) | ||||
|         self.assertEqual(Book.objects.get(), book_2) | ||||
|  | ||||
|         user_3 = User(id=3).save() | ||||
|         book_3 = Book(id=3, author=user_3).save() | ||||
|  | ||||
|         user_3.delete() | ||||
|         # Deleting user_3 should also delete book_3 | ||||
|         self.assertEqual(Book.objects.count(), 1) | ||||
|         self.assertEqual(Book.objects.get(), book_2) | ||||
|  | ||||
|     def test_reverse_delete_rule_with_document_inheritance(self): | ||||
|         """Ensure that a referenced document is also deleted upon deletion | ||||
|         of a child document. | ||||
| @@ -2248,7 +2317,7 @@ class InstanceTest(unittest.TestCase): | ||||
|  | ||||
|         pickle_doc = PickleDynamicTest( | ||||
|             name="test", number=1, string="One", lists=['1', '2']) | ||||
|         pickle_doc.embedded = PickleDyanmicEmbedded(foo="Bar") | ||||
|         pickle_doc.embedded = PickleDynamicEmbedded(foo="Bar") | ||||
|         pickled_doc = pickle.dumps(pickle_doc)  # make sure pickling works even before the doc is saved | ||||
|  | ||||
|         pickle_doc.save() | ||||
| @@ -2859,6 +2928,20 @@ class InstanceTest(unittest.TestCase): | ||||
|         self.assertEqual(person.name, "Test User") | ||||
|         self.assertEqual(person.age, 42) | ||||
|  | ||||
|     def test_positional_creation_embedded(self): | ||||
|         """Ensure that embedded document may be created using positional arguments. | ||||
|         """ | ||||
|         job = self.Job("Test Job", 4) | ||||
|         self.assertEqual(job.name, "Test Job") | ||||
|         self.assertEqual(job.years, 4) | ||||
|  | ||||
|     def test_mixed_creation_embedded(self): | ||||
|         """Ensure that embedded document may be created using mixed arguments. | ||||
|         """ | ||||
|         job = self.Job("Test Job", years=4) | ||||
|         self.assertEqual(job.name, "Test Job") | ||||
|         self.assertEqual(job.years, 4) | ||||
|  | ||||
|     def test_mixed_creation_dynamic(self): | ||||
|         """Ensure that document may be created using mixed arguments. | ||||
|         """ | ||||
| @@ -3035,6 +3118,17 @@ class InstanceTest(unittest.TestCase): | ||||
|         p4 = Person.objects()[0] | ||||
|         p4.save() | ||||
|         self.assertEquals(p4.height, 189) | ||||
|          | ||||
|         # However the default will not be fixed in DB | ||||
|         self.assertEquals(Person.objects(height=189).count(), 0) | ||||
|          | ||||
|         # alter DB for the new default | ||||
|         coll = Person._get_collection() | ||||
|         for person in Person.objects.as_pymongo(): | ||||
|             if 'height' not in person: | ||||
|                 person['height'] = 189 | ||||
|                 coll.save(person) | ||||
|                  | ||||
|         self.assertEquals(Person.objects(height=189).count(), 1) | ||||
|  | ||||
|     def test_from_son(self): | ||||
| @@ -3108,5 +3202,20 @@ class InstanceTest(unittest.TestCase): | ||||
|             self.assertEqual(b._instance, a) | ||||
|         self.assertEqual(idx, 2) | ||||
|  | ||||
|     def test_falsey_pk(self): | ||||
|         """Ensure that we can create and update a document with Falsey PK. | ||||
|         """ | ||||
|         class Person(Document): | ||||
|             age = IntField(primary_key=True) | ||||
|             height = FloatField() | ||||
|  | ||||
|         person = Person() | ||||
|         person.age = 0 | ||||
|         person.height = 1.89 | ||||
|         person.save() | ||||
|  | ||||
|         person.update(set__height=2.0) | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|     unittest.main() | ||||
|   | ||||
| @@ -1,5 +1,7 @@ | ||||
| # -*- coding: utf-8 -*- | ||||
| import sys | ||||
|  | ||||
| import six | ||||
| from nose.plugins.skip import SkipTest | ||||
|  | ||||
| sys.path[0:0] = [""] | ||||
| @@ -10,6 +12,7 @@ import uuid | ||||
| import math | ||||
| import itertools | ||||
| import re | ||||
| import six | ||||
|  | ||||
| try: | ||||
|     import dateutil | ||||
| @@ -19,12 +22,16 @@ except ImportError: | ||||
| from decimal import Decimal | ||||
|  | ||||
| from bson import Binary, DBRef, ObjectId | ||||
| try: | ||||
|     from bson.int64 import Int64 | ||||
| except ImportError: | ||||
|     Int64 = long | ||||
|  | ||||
| from mongoengine import * | ||||
| from mongoengine.connection import get_db | ||||
| from mongoengine.base import _document_registry | ||||
| from mongoengine.base.datastructures import BaseDict, EmbeddedDocumentList | ||||
| from mongoengine.errors import NotRegistered | ||||
| from mongoengine.errors import NotRegistered, DoesNotExist | ||||
| from mongoengine.python_support import PY3, b, bin_type | ||||
|  | ||||
| __all__ = ("FieldTest", "EmbeddedDocumentListFieldTestCase") | ||||
| @@ -399,20 +406,37 @@ class FieldTest(unittest.TestCase): | ||||
|         class Person(Document): | ||||
|             height = FloatField(min_value=0.1, max_value=3.5) | ||||
|  | ||||
|         class BigPerson(Document): | ||||
|             height = FloatField() | ||||
|  | ||||
|         person = Person() | ||||
|         person.height = 1.89 | ||||
|         person.validate() | ||||
|  | ||||
|         person.height = '2.0' | ||||
|         self.assertRaises(ValidationError, person.validate) | ||||
|  | ||||
|         person.height = 0.01 | ||||
|         self.assertRaises(ValidationError, person.validate) | ||||
|  | ||||
|         person.height = 4.0 | ||||
|         self.assertRaises(ValidationError, person.validate) | ||||
|  | ||||
|         person_2 = Person(height='something invalid') | ||||
|         self.assertRaises(ValidationError, person_2.validate) | ||||
|  | ||||
|         big_person = BigPerson() | ||||
|  | ||||
|         for value, value_type in enumerate(six.integer_types): | ||||
|             big_person.height = value_type(value) | ||||
|             big_person.validate() | ||||
|  | ||||
|         big_person.height = 2 ** 500 | ||||
|         big_person.validate() | ||||
|  | ||||
|         big_person.height = 2 ** 100000  # Too big for a float value | ||||
|         self.assertRaises(ValidationError, big_person.validate) | ||||
|  | ||||
|     def test_decimal_validation(self): | ||||
|         """Ensure that invalid values cannot be assigned to decimal fields. | ||||
|         """ | ||||
| @@ -1022,6 +1046,54 @@ class FieldTest(unittest.TestCase): | ||||
|         self.assertEqual(BlogPost.objects(info=['1', '2', '3', '4', '1', '2', '3', '4']).count(), 1) | ||||
|         BlogPost.drop_collection() | ||||
|  | ||||
|     def test_list_assignment(self): | ||||
|         """Ensure that list field element assignment and slicing work | ||||
|         """ | ||||
|         class BlogPost(Document): | ||||
|             info = ListField() | ||||
|  | ||||
|         BlogPost.drop_collection() | ||||
|  | ||||
|         post = BlogPost() | ||||
|         post.info = ['e1', 'e2', 3, '4', 5] | ||||
|         post.save() | ||||
|  | ||||
|         post.info[0] = 1 | ||||
|         post.save() | ||||
|         post.reload() | ||||
|         self.assertEqual(post.info[0], 1) | ||||
|  | ||||
|         post.info[1:3] = ['n2', 'n3'] | ||||
|         post.save() | ||||
|         post.reload() | ||||
|         self.assertEqual(post.info, [1, 'n2', 'n3', '4', 5]) | ||||
|  | ||||
|         post.info[-1] = 'n5' | ||||
|         post.save() | ||||
|         post.reload() | ||||
|         self.assertEqual(post.info, [1, 'n2', 'n3', '4', 'n5']) | ||||
|  | ||||
|         post.info[-2] = 4 | ||||
|         post.save() | ||||
|         post.reload() | ||||
|         self.assertEqual(post.info, [1, 'n2', 'n3', 4, 'n5']) | ||||
|  | ||||
|         post.info[1:-1] = [2] | ||||
|         post.save() | ||||
|         post.reload() | ||||
|         self.assertEqual(post.info, [1, 2, 'n5']) | ||||
|  | ||||
|         post.info[:-1] = [1, 'n2', 'n3', 4] | ||||
|         post.save() | ||||
|         post.reload() | ||||
|         self.assertEqual(post.info, [1, 'n2', 'n3', 4, 'n5']) | ||||
|  | ||||
|         post.info[-4:3] = [2, 3] | ||||
|         post.save() | ||||
|         post.reload() | ||||
|         self.assertEqual(post.info, [1, 2, 3, 4, 'n5']) | ||||
|  | ||||
|  | ||||
|     def test_list_field_passed_in_value(self): | ||||
|         class Foo(Document): | ||||
|             bars = ListField(ReferenceField("Bar")) | ||||
| @@ -1136,6 +1208,19 @@ class FieldTest(unittest.TestCase): | ||||
|         simple = simple.reload() | ||||
|         self.assertEqual(simple.widgets, [4]) | ||||
|  | ||||
|     def test_list_field_with_negative_indices(self): | ||||
|  | ||||
|         class Simple(Document): | ||||
|             widgets = ListField() | ||||
|  | ||||
|         simple = Simple(widgets=[1, 2, 3, 4]).save() | ||||
|         simple.widgets[-1] = 5 | ||||
|         self.assertEqual(['widgets.3'], simple._changed_fields) | ||||
|         simple.save() | ||||
|  | ||||
|         simple = simple.reload() | ||||
|         self.assertEqual(simple.widgets, [1, 2, 3, 5]) | ||||
|  | ||||
|     def test_list_field_complex(self): | ||||
|         """Ensure that the list fields can handle the complex types.""" | ||||
|  | ||||
| @@ -1515,6 +1600,29 @@ class FieldTest(unittest.TestCase): | ||||
|             actions__friends__operation='drink', | ||||
|             actions__friends__object='beer').count()) | ||||
|  | ||||
|     def test_map_field_unicode(self): | ||||
|  | ||||
|         class Info(EmbeddedDocument): | ||||
|             description = StringField() | ||||
|             value_list = ListField(field=StringField()) | ||||
|  | ||||
|         class BlogPost(Document): | ||||
|             info_dict = MapField(field=EmbeddedDocumentField(Info)) | ||||
|  | ||||
|         BlogPost.drop_collection() | ||||
|  | ||||
|         tree = BlogPost(info_dict={ | ||||
|             u"éééé": { | ||||
|                 'description': u"VALUE: éééé" | ||||
|             } | ||||
|         }) | ||||
|  | ||||
|         tree.save() | ||||
|  | ||||
|         self.assertEqual(BlogPost.objects.get(id=tree.id).info_dict[u"éééé"].description, u"VALUE: éééé") | ||||
|  | ||||
|         BlogPost.drop_collection() | ||||
|  | ||||
|     def test_embedded_db_field(self): | ||||
|  | ||||
|         class Embedded(EmbeddedDocument): | ||||
| @@ -1551,6 +1659,8 @@ class FieldTest(unittest.TestCase): | ||||
|             name = StringField() | ||||
|             preferences = EmbeddedDocumentField(PersonPreferences) | ||||
|  | ||||
|         Person.drop_collection() | ||||
|  | ||||
|         person = Person(name='Test User') | ||||
|         person.preferences = 'My Preferences' | ||||
|         self.assertRaises(ValidationError, person.validate) | ||||
| @@ -1583,12 +1693,70 @@ class FieldTest(unittest.TestCase): | ||||
|             content = StringField() | ||||
|             author = EmbeddedDocumentField(User) | ||||
|  | ||||
|         BlogPost.drop_collection() | ||||
|  | ||||
|         post = BlogPost(content='What I did today...') | ||||
|         post.author = PowerUser(name='Test User', power=47) | ||||
|         post.save() | ||||
|  | ||||
|         self.assertEqual(47, BlogPost.objects.first().author.power) | ||||
|  | ||||
|     def test_embedded_document_inheritance_with_list(self): | ||||
|         """Ensure that nested list of subclassed embedded documents is | ||||
|         handled correctly. | ||||
|         """ | ||||
|  | ||||
|         class Group(EmbeddedDocument): | ||||
|             name = StringField() | ||||
|             content = ListField(StringField()) | ||||
|  | ||||
|         class Basedoc(Document): | ||||
|             groups = ListField(EmbeddedDocumentField(Group)) | ||||
|             meta = {'abstract': True} | ||||
|  | ||||
|         class User(Basedoc): | ||||
|             doctype = StringField(require=True, default='userdata') | ||||
|  | ||||
|         User.drop_collection() | ||||
|  | ||||
|         content = ['la', 'le', 'lu'] | ||||
|         group = Group(name='foo', content=content) | ||||
|         foobar = User(groups=[group]) | ||||
|         foobar.save() | ||||
|  | ||||
|         self.assertEqual(content, User.objects.first().groups[0].content) | ||||
|  | ||||
|     def test_reference_miss(self): | ||||
|         """Ensure an exception is raised when dereferencing unknow document | ||||
|         """ | ||||
|  | ||||
|         class Foo(Document): | ||||
|             pass | ||||
|  | ||||
|         class Bar(Document): | ||||
|             ref = ReferenceField(Foo) | ||||
|             generic_ref = GenericReferenceField() | ||||
|  | ||||
|         Foo.drop_collection() | ||||
|         Bar.drop_collection() | ||||
|  | ||||
|         foo = Foo().save() | ||||
|         bar = Bar(ref=foo, generic_ref=foo).save() | ||||
|  | ||||
|         # Reference is no longer valid | ||||
|         foo.delete() | ||||
|         bar = Bar.objects.get() | ||||
|         self.assertRaises(DoesNotExist, lambda: getattr(bar, 'ref')) | ||||
|         self.assertRaises(DoesNotExist, lambda: getattr(bar, 'generic_ref')) | ||||
|  | ||||
|         # When auto_dereference is disabled, there is no trouble returning DBRef | ||||
|         bar = Bar.objects.get() | ||||
|         expected = foo.to_dbref() | ||||
|         bar._fields['ref']._auto_dereference = False | ||||
|         self.assertEqual(bar.ref, expected) | ||||
|         bar._fields['generic_ref']._auto_dereference = False | ||||
|         self.assertEqual(bar.generic_ref, {'_ref': expected, '_cls': 'Foo'}) | ||||
|  | ||||
|     def test_reference_validation(self): | ||||
|         """Ensure that invalid docment objects cannot be assigned to reference | ||||
|         fields. | ||||
| @@ -1655,7 +1823,7 @@ class FieldTest(unittest.TestCase): | ||||
|                                'parent': "50a234ea469ac1eda42d347d"}) | ||||
|         mongoed = p1.to_mongo() | ||||
|         self.assertTrue(isinstance(mongoed['parent'], ObjectId)) | ||||
|          | ||||
|  | ||||
|     def test_cached_reference_field_get_and_save(self): | ||||
|         """ | ||||
|         Tests #1047: CachedReferenceField creates DBRefs on to_python, but can't save them on to_mongo | ||||
| @@ -1667,11 +1835,11 @@ class FieldTest(unittest.TestCase): | ||||
|         class Ocorrence(Document): | ||||
|             person = StringField() | ||||
|             animal = CachedReferenceField(Animal) | ||||
|          | ||||
|  | ||||
|         Animal.drop_collection() | ||||
|         Ocorrence.drop_collection() | ||||
|          | ||||
|         Ocorrence(person="testte",  | ||||
|  | ||||
|         Ocorrence(person="testte", | ||||
|                   animal=Animal(name="Leopard", tag="heavy").save()).save() | ||||
|         p = Ocorrence.objects.get() | ||||
|         p.person = 'new_testte' | ||||
| @@ -2281,6 +2449,16 @@ class FieldTest(unittest.TestCase): | ||||
|         Member.drop_collection() | ||||
|         BlogPost.drop_collection() | ||||
|  | ||||
|     def test_drop_abstract_document(self): | ||||
|         """Ensure that an abstract document cannot be dropped given it | ||||
|         has no underlying collection. | ||||
|         """ | ||||
|         class AbstractDoc(Document): | ||||
|             name = StringField() | ||||
|             meta = {"abstract": True} | ||||
|  | ||||
|         self.assertRaises(OperationError, AbstractDoc.drop_collection) | ||||
|  | ||||
|     def test_reference_class_with_abstract_parent(self): | ||||
|         """Ensure that a class with an abstract parent can be referenced. | ||||
|         """ | ||||
| @@ -2632,6 +2810,38 @@ class FieldTest(unittest.TestCase): | ||||
|         Post.drop_collection() | ||||
|         User.drop_collection() | ||||
|  | ||||
|     def test_generic_reference_filter_by_dbref(self): | ||||
|         """Ensure we can search for a specific generic reference by | ||||
|         providing its ObjectId. | ||||
|         """ | ||||
|         class Doc(Document): | ||||
|             ref = GenericReferenceField() | ||||
|  | ||||
|         Doc.drop_collection() | ||||
|  | ||||
|         doc1 = Doc.objects.create() | ||||
|         doc2 = Doc.objects.create(ref=doc1) | ||||
|  | ||||
|         doc = Doc.objects.get(ref=DBRef('doc', doc1.pk)) | ||||
|         self.assertEqual(doc, doc2) | ||||
|  | ||||
|     def test_generic_reference_filter_by_objectid(self): | ||||
|         """Ensure we can search for a specific generic reference by | ||||
|         providing its DBRef. | ||||
|         """ | ||||
|         class Doc(Document): | ||||
|             ref = GenericReferenceField() | ||||
|  | ||||
|         Doc.drop_collection() | ||||
|  | ||||
|         doc1 = Doc.objects.create() | ||||
|         doc2 = Doc.objects.create(ref=doc1) | ||||
|  | ||||
|         self.assertTrue(isinstance(doc1.pk, ObjectId)) | ||||
|  | ||||
|         doc = Doc.objects.get(ref=doc1.pk) | ||||
|         self.assertEqual(doc, doc2) | ||||
|  | ||||
|     def test_binary_fields(self): | ||||
|         """Ensure that binary fields can be stored and retrieved. | ||||
|         """ | ||||
| @@ -2823,28 +3033,32 @@ class FieldTest(unittest.TestCase): | ||||
|                 ('S', 'Small'), ('M', 'Medium'), ('L', 'Large'), | ||||
|                 ('XL', 'Extra Large'), ('XXL', 'Extra Extra Large'))) | ||||
|             style = StringField(max_length=3, choices=( | ||||
|                 ('S', 'Small'), ('B', 'Baggy'), ('W', 'wide')), default='S') | ||||
|                 ('S', 'Small'), ('B', 'Baggy'), ('W', 'Wide')), default='W') | ||||
|  | ||||
|         Shirt.drop_collection() | ||||
|  | ||||
|         shirt = Shirt() | ||||
|         shirt1 = Shirt() | ||||
|         shirt2 = Shirt() | ||||
|  | ||||
|         self.assertEqual(shirt.get_size_display(), None) | ||||
|         self.assertEqual(shirt.get_style_display(), 'Small') | ||||
|         # Make sure get_<field>_display returns the default value (or None) | ||||
|         self.assertEqual(shirt1.get_size_display(), None) | ||||
|         self.assertEqual(shirt1.get_style_display(), 'Wide') | ||||
|  | ||||
|         shirt.size = "XXL" | ||||
|         shirt.style = "B" | ||||
|         self.assertEqual(shirt.get_size_display(), 'Extra Extra Large') | ||||
|         self.assertEqual(shirt.get_style_display(), 'Baggy') | ||||
|         shirt1.size = 'XXL' | ||||
|         shirt1.style = 'B' | ||||
|         shirt2.size = 'M' | ||||
|         shirt2.style = 'S' | ||||
|         self.assertEqual(shirt1.get_size_display(), 'Extra Extra Large') | ||||
|         self.assertEqual(shirt1.get_style_display(), 'Baggy') | ||||
|         self.assertEqual(shirt2.get_size_display(), 'Medium') | ||||
|         self.assertEqual(shirt2.get_style_display(), 'Small') | ||||
|  | ||||
|         # Set as Z - an invalid choice | ||||
|         shirt.size = "Z" | ||||
|         shirt.style = "Z" | ||||
|         self.assertEqual(shirt.get_size_display(), 'Z') | ||||
|         self.assertEqual(shirt.get_style_display(), 'Z') | ||||
|         self.assertRaises(ValidationError, shirt.validate) | ||||
|  | ||||
|         Shirt.drop_collection() | ||||
|         shirt1.size = 'Z' | ||||
|         shirt1.style = 'Z' | ||||
|         self.assertEqual(shirt1.get_size_display(), 'Z') | ||||
|         self.assertEqual(shirt1.get_style_display(), 'Z') | ||||
|         self.assertRaises(ValidationError, shirt1.validate) | ||||
|  | ||||
|     def test_simple_choices_validation(self): | ||||
|         """Ensure that value is in a container of allowed values. | ||||
| @@ -3547,6 +3761,19 @@ class FieldTest(unittest.TestCase): | ||||
|  | ||||
|         self.assertRaises(FieldDoesNotExist, test) | ||||
|  | ||||
|     def test_long_field_is_considered_as_int64(self): | ||||
|         """ | ||||
|         Tests that long fields are stored as long in mongo, even if long value | ||||
|         is small enough to be an int. | ||||
|         """ | ||||
|         class TestLongFieldConsideredAsInt64(Document): | ||||
|             some_long = LongField() | ||||
|  | ||||
|         doc = TestLongFieldConsideredAsInt64(some_long=42).save() | ||||
|         db = get_db() | ||||
|         self.assertTrue(isinstance(db.test_long_field_considered_as_int64.find()[0]['some_long'], Int64)) | ||||
|         self.assertTrue(isinstance(doc.some_long, six.integer_types)) | ||||
|  | ||||
|  | ||||
| class EmbeddedDocumentListFieldTestCase(unittest.TestCase): | ||||
|  | ||||
| @@ -3934,6 +4161,17 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase): | ||||
|         # modified | ||||
|         self.assertEqual(number, 2) | ||||
|  | ||||
|     def test_unicode(self): | ||||
|         """ | ||||
|         Tests that unicode strings handled correctly | ||||
|         """ | ||||
|         post = self.BlogPost(comments=[ | ||||
|             self.Comments(author='user1', message=u'сообщение'), | ||||
|             self.Comments(author='user2', message=u'хабарлама') | ||||
|         ]).save() | ||||
|         self.assertEqual(post.comments.get(message=u'сообщение').author, | ||||
|                          'user1') | ||||
|  | ||||
|     def test_save(self): | ||||
|         """ | ||||
|         Tests the save method of a List of Embedded Documents. | ||||
|   | ||||
| @@ -26,7 +26,7 @@ class NewDocumentPickleTest(Document): | ||||
|     new_field = StringField() | ||||
|  | ||||
|  | ||||
| class PickleDyanmicEmbedded(DynamicEmbeddedDocument): | ||||
| class PickleDynamicEmbedded(DynamicEmbeddedDocument): | ||||
|     date = DateTimeField(default=datetime.now) | ||||
|  | ||||
|  | ||||
|   | ||||
| @@ -1,8 +1,11 @@ | ||||
| import unittest | ||||
|  | ||||
| from convert_to_new_inheritance_model import * | ||||
| from decimalfield_as_float import * | ||||
| from refrencefield_dbref_to_object_id import * | ||||
| from referencefield_dbref_to_object_id import * | ||||
| from turn_off_inheritance import * | ||||
| from uuidfield_to_binary import * | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|     unittest.main() | ||||
|   | ||||
							
								
								
									
										78
									
								
								tests/queryset/pickable.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										78
									
								
								tests/queryset/pickable.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,78 @@ | ||||
| import pickle | ||||
| import unittest | ||||
| from pymongo.mongo_client import MongoClient | ||||
| from mongoengine import Document, StringField, IntField | ||||
| from mongoengine.connection import connect | ||||
|  | ||||
| __author__ = 'stas' | ||||
|  | ||||
| class Person(Document): | ||||
|     name = StringField() | ||||
|     age = IntField() | ||||
|  | ||||
| class TestQuerysetPickable(unittest.TestCase): | ||||
|     """ | ||||
|     Test for adding pickling support for QuerySet instances | ||||
|     See issue https://github.com/MongoEngine/mongoengine/issues/442 | ||||
|     """ | ||||
|     def setUp(self): | ||||
|         super(TestQuerysetPickable, self).setUp() | ||||
|  | ||||
|         connection = connect(db="test") #type: pymongo.mongo_client.MongoClient | ||||
|  | ||||
|         connection.drop_database("test") | ||||
|  | ||||
|         self.john = Person.objects.create( | ||||
|             name="John", | ||||
|             age=21 | ||||
|         ) | ||||
|  | ||||
|  | ||||
|     def test_picke_simple_qs(self): | ||||
|  | ||||
|         qs = Person.objects.all() | ||||
|  | ||||
|         pickle.dumps(qs) | ||||
|  | ||||
|     def _get_loaded(self, qs): | ||||
|         s = pickle.dumps(qs) | ||||
|  | ||||
|         return pickle.loads(s) | ||||
|  | ||||
|     def test_unpickle(self): | ||||
|         qs = Person.objects.all() | ||||
|  | ||||
|         loadedQs = self._get_loaded(qs) | ||||
|  | ||||
|         self.assertEqual(qs.count(), loadedQs.count()) | ||||
|  | ||||
|         #can update loadedQs | ||||
|         loadedQs.update(age=23) | ||||
|  | ||||
|         #check | ||||
|         self.assertEqual(Person.objects.first().age, 23) | ||||
|  | ||||
|     def test_pickle_support_filtration(self): | ||||
|         Person.objects.create( | ||||
|             name="Alice", | ||||
|             age=22 | ||||
|         ) | ||||
|  | ||||
|         Person.objects.create( | ||||
|             name="Bob", | ||||
|             age=23 | ||||
|         ) | ||||
|  | ||||
|         qs = Person.objects.filter(age__gte=22) | ||||
|         self.assertEqual(qs.count(), 2) | ||||
|  | ||||
|         loaded = self._get_loaded(qs) | ||||
|  | ||||
|         self.assertEqual(loaded.count(), 2) | ||||
|         self.assertEqual(loaded.filter(name="Bob").first().age, 23) | ||||
|      | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
| @@ -1,28 +1,23 @@ | ||||
| # -*- coding: utf-8 -*- | ||||
|  | ||||
| import sys | ||||
| sys.path[0:0] = [""] | ||||
|  | ||||
| import datetime | ||||
| import unittest | ||||
| import uuid | ||||
|  | ||||
| from bson import DBRef, ObjectId | ||||
| from nose.plugins.skip import SkipTest | ||||
|  | ||||
| from datetime import datetime, timedelta | ||||
|  | ||||
| import pymongo | ||||
| from pymongo.errors import ConfigurationError | ||||
| from pymongo.read_preferences import ReadPreference | ||||
|  | ||||
| from bson import ObjectId, DBRef | ||||
|  | ||||
| from mongoengine import * | ||||
| from mongoengine.connection import get_connection, get_db | ||||
| from mongoengine.python_support import PY3, IS_PYMONGO_3 | ||||
| from mongoengine.context_managers import query_counter, switch_db | ||||
| from mongoengine.queryset import (QuerySet, QuerySetManager, | ||||
|                                   MultipleObjectsReturned, DoesNotExist, | ||||
|                                   queryset_manager) | ||||
| from mongoengine.errors import InvalidQueryError | ||||
| from mongoengine.python_support import IS_PYMONGO_3, PY3 | ||||
| from mongoengine.queryset import (DoesNotExist, MultipleObjectsReturned, | ||||
|                                   QuerySet, QuerySetManager, queryset_manager) | ||||
|  | ||||
| __all__ = ("QuerySetTest",) | ||||
|  | ||||
| @@ -184,12 +179,14 @@ class QuerySetTest(unittest.TestCase): | ||||
|  | ||||
|         self.assertEqual(self.Person.objects.count(), 55) | ||||
|         self.assertEqual("Person object", "%s" % self.Person.objects[0]) | ||||
|         self.assertEqual( | ||||
|             "[<Person: Person object>, <Person: Person object>]",  "%s" % self.Person.objects[1:3]) | ||||
|         self.assertEqual( | ||||
|             "[<Person: Person object>, <Person: Person object>]",  "%s" % self.Person.objects[51:53]) | ||||
|         self.assertEqual("[<Person: Person object>, <Person: Person object>]", | ||||
|                          "%s" % self.Person.objects[1:3]) | ||||
|         self.assertEqual("[<Person: Person object>, <Person: Person object>]", | ||||
|                          "%s" % self.Person.objects[51:53]) | ||||
|  | ||||
|         # Test only after limit | ||||
|         self.assertEqual(self.Person.objects().limit(2).only('name')[0].age, None) | ||||
|  | ||||
|         # Test only after skip | ||||
|         self.assertEqual(self.Person.objects().skip(2).only('name')[0].age, None) | ||||
|  | ||||
| @@ -287,6 +284,9 @@ class QuerySetTest(unittest.TestCase): | ||||
|         blog = Blog.objects(posts__0__comments__0__name='testa').get() | ||||
|         self.assertEqual(blog, blog1) | ||||
|  | ||||
|         blog = Blog.objects(posts__0__comments__0__name='testb').get() | ||||
|         self.assertEqual(blog, blog2) | ||||
|  | ||||
|         query = Blog.objects(posts__1__comments__1__name='testb') | ||||
|         self.assertEqual(query.count(), 2) | ||||
|  | ||||
| @@ -337,9 +337,36 @@ class QuerySetTest(unittest.TestCase): | ||||
|         query = query.filter(boolfield=True) | ||||
|         self.assertEqual(query.count(), 1) | ||||
|  | ||||
|     def test_batch_size(self): | ||||
|         """Ensure that batch_size works.""" | ||||
|         class A(Document): | ||||
|             s = StringField() | ||||
|  | ||||
|         A.drop_collection() | ||||
|  | ||||
|         for i in range(100): | ||||
|             A.objects.create(s=str(i)) | ||||
|  | ||||
|         # test iterating over the result set | ||||
|         cnt = 0 | ||||
|         for a in A.objects.batch_size(10): | ||||
|             cnt += 1 | ||||
|         self.assertEqual(cnt, 100) | ||||
|  | ||||
|         # test chaining | ||||
|         qs = A.objects.all() | ||||
|         qs = qs.limit(10).batch_size(20).skip(91) | ||||
|         cnt = 0 | ||||
|         for a in qs: | ||||
|             cnt += 1 | ||||
|         self.assertEqual(cnt, 9) | ||||
|  | ||||
|         # test invalid batch size | ||||
|         qs = A.objects.batch_size(-1) | ||||
|         self.assertRaises(ValueError, lambda: list(qs)) | ||||
|  | ||||
|     def test_update_write_concern(self): | ||||
|         """Test that passing write_concern works""" | ||||
|  | ||||
|         self.Person.drop_collection() | ||||
|  | ||||
|         write_concern = {"fsync": True} | ||||
| @@ -633,39 +660,39 @@ class QuerySetTest(unittest.TestCase): | ||||
|         self.assertRaises(ValidationError, Doc.objects().update, dt_f="datetime", upsert=True) | ||||
|         self.assertRaises(ValidationError, Doc.objects().update, ed_f__str_f=1, upsert=True) | ||||
|  | ||||
|     def test_update_related_models( self ): | ||||
|             class TestPerson( Document ): | ||||
|     def test_update_related_models(self): | ||||
|             class TestPerson(Document): | ||||
|                 name = StringField() | ||||
|  | ||||
|             class TestOrganization( Document ): | ||||
|             class TestOrganization(Document): | ||||
|                 name = StringField() | ||||
|                 owner = ReferenceField( TestPerson ) | ||||
|                 owner = ReferenceField(TestPerson) | ||||
|  | ||||
|             TestPerson.drop_collection() | ||||
|             TestOrganization.drop_collection() | ||||
|  | ||||
|             p = TestPerson( name='p1' ) | ||||
|             p = TestPerson(name='p1') | ||||
|             p.save() | ||||
|             o = TestOrganization( name='o1' ) | ||||
|             o = TestOrganization(name='o1') | ||||
|             o.save() | ||||
|  | ||||
|             o.owner = p | ||||
|             p.name = 'p2' | ||||
|  | ||||
|             self.assertEqual( o._get_changed_fields(), [ 'owner' ] ) | ||||
|             self.assertEqual( p._get_changed_fields(), [ 'name' ] ) | ||||
|             self.assertEqual(o._get_changed_fields(), ['owner']) | ||||
|             self.assertEqual(p._get_changed_fields(), ['name']) | ||||
|  | ||||
|             o.save() | ||||
|  | ||||
|             self.assertEqual( o._get_changed_fields(), [] ) | ||||
|             self.assertEqual( p._get_changed_fields(), [ 'name' ] ) # Fails; it's empty | ||||
|             self.assertEqual(o._get_changed_fields(), []) | ||||
|             self.assertEqual(p._get_changed_fields(), ['name'])  # Fails; it's empty | ||||
|  | ||||
|             # This will do NOTHING at all, even though we changed the name | ||||
|             p.save() | ||||
|  | ||||
|             p.reload() | ||||
|  | ||||
|             self.assertEqual( p.name, 'p2' ) # Fails; it's still `p1` | ||||
|             self.assertEqual(p.name, 'p2')  # Fails; it's still `p1` | ||||
|  | ||||
|     def test_upsert(self): | ||||
|         self.Person.drop_collection() | ||||
| @@ -694,7 +721,6 @@ class QuerySetTest(unittest.TestCase): | ||||
|         self.assertEqual(30, bobby.age) | ||||
|         self.assertEqual(bob.id, bobby.id) | ||||
|  | ||||
|  | ||||
|     def test_set_on_insert(self): | ||||
|         self.Person.drop_collection() | ||||
|  | ||||
| @@ -1113,24 +1139,29 @@ class QuerySetTest(unittest.TestCase): | ||||
|         blog_2.save() | ||||
|         blog_3.save() | ||||
|  | ||||
|         blog_post_1 = BlogPost(blog=blog_1, title="Blog Post #1", | ||||
|                                is_published=True, | ||||
|                                published_date=datetime(2010, 1, 5, 0, 0, 0)) | ||||
|         blog_post_2 = BlogPost(blog=blog_2, title="Blog Post #2", | ||||
|                                is_published=True, | ||||
|                                published_date=datetime(2010, 1, 6, 0, 0, 0)) | ||||
|         blog_post_3 = BlogPost(blog=blog_3, title="Blog Post #3", | ||||
|                                is_published=True, | ||||
|                                published_date=datetime(2010, 1, 7, 0, 0, 0)) | ||||
|  | ||||
|         blog_post_1.save() | ||||
|         blog_post_2.save() | ||||
|         blog_post_3.save() | ||||
|         BlogPost.objects.create( | ||||
|             blog=blog_1, | ||||
|             title="Blog Post #1", | ||||
|             is_published=True, | ||||
|             published_date=datetime.datetime(2010, 1, 5, 0, 0, 0) | ||||
|         ) | ||||
|         BlogPost.objects.create( | ||||
|             blog=blog_2, | ||||
|             title="Blog Post #2", | ||||
|             is_published=True, | ||||
|             published_date=datetime.datetime(2010, 1, 6, 0, 0, 0) | ||||
|         ) | ||||
|         BlogPost.objects.create( | ||||
|             blog=blog_3, | ||||
|             title="Blog Post #3", | ||||
|             is_published=True, | ||||
|             published_date=datetime.datetime(2010, 1, 7, 0, 0, 0) | ||||
|         ) | ||||
|  | ||||
|         # find all published blog posts before 2010-01-07 | ||||
|         published_posts = BlogPost.published() | ||||
|         published_posts = published_posts.filter( | ||||
|             published_date__lt=datetime(2010, 1, 7, 0, 0, 0)) | ||||
|             published_date__lt=datetime.datetime(2010, 1, 7, 0, 0, 0)) | ||||
|         self.assertEqual(published_posts.count(), 2) | ||||
|  | ||||
|         blog_posts = BlogPost.objects | ||||
| @@ -1161,16 +1192,18 @@ class QuerySetTest(unittest.TestCase): | ||||
|  | ||||
|         BlogPost.drop_collection() | ||||
|  | ||||
|         blog_post_1 = BlogPost(title="Blog Post #1", | ||||
|                                published_date=datetime(2010, 1, 5, 0, 0, 0)) | ||||
|         blog_post_2 = BlogPost(title="Blog Post #2", | ||||
|                                published_date=datetime(2010, 1, 6, 0, 0, 0)) | ||||
|         blog_post_3 = BlogPost(title="Blog Post #3", | ||||
|                                published_date=datetime(2010, 1, 7, 0, 0, 0)) | ||||
|  | ||||
|         blog_post_1.save() | ||||
|         blog_post_2.save() | ||||
|         blog_post_3.save() | ||||
|         blog_post_1 = BlogPost.objects.create( | ||||
|             title="Blog Post #1", | ||||
|             published_date=datetime.datetime(2010, 1, 5, 0, 0, 0) | ||||
|         ) | ||||
|         blog_post_2 = BlogPost.objects.create( | ||||
|             title="Blog Post #2", | ||||
|             published_date=datetime.datetime(2010, 1, 6, 0, 0, 0) | ||||
|         ) | ||||
|         blog_post_3 = BlogPost.objects.create( | ||||
|             title="Blog Post #3", | ||||
|             published_date=datetime.datetime(2010, 1, 7, 0, 0, 0) | ||||
|         ) | ||||
|  | ||||
|         # get the "first" BlogPost using default ordering | ||||
|         # from BlogPost.meta.ordering | ||||
| @@ -1219,7 +1252,7 @@ class QuerySetTest(unittest.TestCase): | ||||
|             } | ||||
|  | ||||
|         BlogPost.objects.create( | ||||
|             title='whatever', published_date=datetime.utcnow()) | ||||
|             title='whatever', published_date=datetime.datetime.utcnow()) | ||||
|  | ||||
|         with db_ops_tracker() as q: | ||||
|             BlogPost.objects.get(title='whatever') | ||||
| @@ -1233,7 +1266,8 @@ class QuerySetTest(unittest.TestCase): | ||||
|             self.assertFalse('$orderby' in q.get_ops()[0]['query']) | ||||
|  | ||||
|     def test_find_embedded(self): | ||||
|         """Ensure that an embedded document is properly returned from a query. | ||||
|         """Ensure that an embedded document is properly returned from | ||||
|         a query. | ||||
|         """ | ||||
|         class User(EmbeddedDocument): | ||||
|             name = StringField() | ||||
| @@ -1244,16 +1278,31 @@ class QuerySetTest(unittest.TestCase): | ||||
|  | ||||
|         BlogPost.drop_collection() | ||||
|  | ||||
|         post = BlogPost(content='Had a good coffee today...') | ||||
|         post.author = User(name='Test User') | ||||
|         post.save() | ||||
|         BlogPost.objects.create( | ||||
|             author=User(name='Test User'), | ||||
|             content='Had a good coffee today...' | ||||
|         ) | ||||
|  | ||||
|         result = BlogPost.objects.first() | ||||
|         self.assertTrue(isinstance(result.author, User)) | ||||
|         self.assertEqual(result.author.name, 'Test User') | ||||
|  | ||||
|     def test_find_empty_embedded(self): | ||||
|         """Ensure that you can save and find an empty embedded document.""" | ||||
|         class User(EmbeddedDocument): | ||||
|             name = StringField() | ||||
|  | ||||
|         class BlogPost(Document): | ||||
|             content = StringField() | ||||
|             author = EmbeddedDocumentField(User) | ||||
|  | ||||
|         BlogPost.drop_collection() | ||||
|  | ||||
|         BlogPost.objects.create(content='Anonymous post...') | ||||
|  | ||||
|         result = BlogPost.objects.get(author=None) | ||||
|         self.assertEqual(result.author, None) | ||||
|  | ||||
|     def test_find_dict_item(self): | ||||
|         """Ensure that DictField items may be found. | ||||
|         """ | ||||
| @@ -2082,18 +2131,22 @@ class QuerySetTest(unittest.TestCase): | ||||
|  | ||||
|         BlogPost.drop_collection() | ||||
|  | ||||
|         blog_post_3 = BlogPost(title="Blog Post #3", | ||||
|                                published_date=datetime(2010, 1, 6, 0, 0, 0)) | ||||
|         blog_post_2 = BlogPost(title="Blog Post #2", | ||||
|                                published_date=datetime(2010, 1, 5, 0, 0, 0)) | ||||
|         blog_post_4 = BlogPost(title="Blog Post #4", | ||||
|                                published_date=datetime(2010, 1, 7, 0, 0, 0)) | ||||
|         blog_post_1 = BlogPost(title="Blog Post #1", published_date=None) | ||||
|  | ||||
|         blog_post_3.save() | ||||
|         blog_post_1.save() | ||||
|         blog_post_4.save() | ||||
|         blog_post_2.save() | ||||
|         blog_post_3 = BlogPost.objects.create( | ||||
|             title="Blog Post #3", | ||||
|             published_date=datetime.datetime(2010, 1, 6, 0, 0, 0) | ||||
|         ) | ||||
|         blog_post_2 = BlogPost.objects.create( | ||||
|             title="Blog Post #2", | ||||
|             published_date=datetime.datetime(2010, 1, 5, 0, 0, 0) | ||||
|         ) | ||||
|         blog_post_4 = BlogPost.objects.create( | ||||
|             title="Blog Post #4", | ||||
|             published_date=datetime.datetime(2010, 1, 7, 0, 0, 0) | ||||
|         ) | ||||
|         blog_post_1 = BlogPost.objects.create( | ||||
|             title="Blog Post #1", | ||||
|             published_date=None | ||||
|         ) | ||||
|  | ||||
|         expected = [blog_post_1, blog_post_2, blog_post_3, blog_post_4] | ||||
|         self.assertSequence(BlogPost.objects.order_by('published_date'), | ||||
| @@ -2112,16 +2165,18 @@ class QuerySetTest(unittest.TestCase): | ||||
|  | ||||
|         BlogPost.drop_collection() | ||||
|  | ||||
|         blog_post_1 = BlogPost(title="A", | ||||
|                                published_date=datetime(2010, 1, 6, 0, 0, 0)) | ||||
|         blog_post_2 = BlogPost(title="B", | ||||
|                                published_date=datetime(2010, 1, 6, 0, 0, 0)) | ||||
|         blog_post_3 = BlogPost(title="C", | ||||
|                                published_date=datetime(2010, 1, 7, 0, 0, 0)) | ||||
|  | ||||
|         blog_post_2.save() | ||||
|         blog_post_3.save() | ||||
|         blog_post_1.save() | ||||
|         blog_post_1 = BlogPost.objects.create( | ||||
|             title="A", | ||||
|             published_date=datetime.datetime(2010, 1, 6, 0, 0, 0) | ||||
|         ) | ||||
|         blog_post_2 = BlogPost.objects.create( | ||||
|             title="B", | ||||
|             published_date=datetime.datetime(2010, 1, 6, 0, 0, 0) | ||||
|         ) | ||||
|         blog_post_3 = BlogPost.objects.create( | ||||
|             title="C", | ||||
|             published_date=datetime.datetime(2010, 1, 7, 0, 0, 0) | ||||
|         ) | ||||
|  | ||||
|         qs = BlogPost.objects.order_by('published_date', 'title') | ||||
|         expected = [blog_post_1, blog_post_2, blog_post_3] | ||||
| @@ -2187,6 +2242,21 @@ class QuerySetTest(unittest.TestCase): | ||||
|             a.author.name for a in Author.objects.order_by('-author__age')] | ||||
|         self.assertEqual(names, ['User A', 'User B', 'User C']) | ||||
|  | ||||
|     def test_comment(self): | ||||
|         """Make sure adding a comment to the query works.""" | ||||
|         class User(Document): | ||||
|             age = IntField() | ||||
|  | ||||
|         with db_ops_tracker() as q: | ||||
|             adult = (User.objects.filter(age__gte=18) | ||||
|                 .comment('looking for an adult') | ||||
|                 .first()) | ||||
|             ops = q.get_ops() | ||||
|             self.assertEqual(len(ops), 1) | ||||
|             op = ops[0] | ||||
|             self.assertEqual(op['query']['$query'], {'age': {'$gte': 18}}) | ||||
|             self.assertEqual(op['query']['$comment'], 'looking for an adult') | ||||
|  | ||||
|     def test_map_reduce(self): | ||||
|         """Ensure map/reduce is both mapping and reducing. | ||||
|         """ | ||||
| @@ -2425,7 +2495,7 @@ class QuerySetTest(unittest.TestCase): | ||||
|  | ||||
|         Link.drop_collection() | ||||
|  | ||||
|         now = datetime.utcnow() | ||||
|         now = datetime.datetime.utcnow() | ||||
|  | ||||
|         # Note: Test data taken from a custom Reddit homepage on | ||||
|         # Fri, 12 Feb 2010 14:36:00 -0600. Link ordering should | ||||
| @@ -2434,27 +2504,27 @@ class QuerySetTest(unittest.TestCase): | ||||
|         Link(title="Google Buzz auto-followed a woman's abusive ex ...", | ||||
|              up_votes=1079, | ||||
|              down_votes=553, | ||||
|              submitted=now - timedelta(hours=4)).save() | ||||
|              submitted=now - datetime.timedelta(hours=4)).save() | ||||
|         Link(title="We did it! Barbie is a computer engineer.", | ||||
|              up_votes=481, | ||||
|              down_votes=124, | ||||
|              submitted=now - timedelta(hours=2)).save() | ||||
|              submitted=now - datetime.timedelta(hours=2)).save() | ||||
|         Link(title="This Is A Mosquito Getting Killed By A Laser", | ||||
|              up_votes=1446, | ||||
|              down_votes=530, | ||||
|              submitted=now - timedelta(hours=13)).save() | ||||
|              submitted=now - datetime.timedelta(hours=13)).save() | ||||
|         Link(title="Arabic flashcards land physics student in jail.", | ||||
|              up_votes=215, | ||||
|              down_votes=105, | ||||
|              submitted=now - timedelta(hours=6)).save() | ||||
|              submitted=now - datetime.timedelta(hours=6)).save() | ||||
|         Link(title="The Burger Lab: Presenting, the Flood Burger", | ||||
|              up_votes=48, | ||||
|              down_votes=17, | ||||
|              submitted=now - timedelta(hours=5)).save() | ||||
|              submitted=now - datetime.timedelta(hours=5)).save() | ||||
|         Link(title="How to see polarization with the naked eye", | ||||
|              up_votes=74, | ||||
|              down_votes=13, | ||||
|              submitted=now - timedelta(hours=10)).save() | ||||
|              submitted=now - datetime.timedelta(hours=10)).save() | ||||
|  | ||||
|         map_f = """ | ||||
|             function() { | ||||
| @@ -2504,7 +2574,7 @@ class QuerySetTest(unittest.TestCase): | ||||
|  | ||||
|         # provide the reddit epoch (used for ranking) as a variable available | ||||
|         # to all phases of the map/reduce operation: map, reduce, and finalize. | ||||
|         reddit_epoch = mktime(datetime(2005, 12, 8, 7, 46, 43).timetuple()) | ||||
|         reddit_epoch = mktime(datetime.datetime(2005, 12, 8, 7, 46, 43).timetuple()) | ||||
|         scope = {'reddit_epoch': reddit_epoch} | ||||
|  | ||||
|         # run a map/reduce operation across all links. ordering is set | ||||
| @@ -2766,25 +2836,15 @@ class QuerySetTest(unittest.TestCase): | ||||
|  | ||||
|         avg = float(sum(ages)) / (len(ages) + 1)  # take into account the 0 | ||||
|         self.assertAlmostEqual(int(self.Person.objects.average('age')), avg) | ||||
|         self.assertAlmostEqual( | ||||
|             int(self.Person.objects.aggregate_average('age')), avg | ||||
|         ) | ||||
|  | ||||
|         self.Person(name='ageless person').save() | ||||
|         self.assertEqual(int(self.Person.objects.average('age')), avg) | ||||
|         self.assertEqual( | ||||
|             int(self.Person.objects.aggregate_average('age')), avg | ||||
|         ) | ||||
|  | ||||
|         # dot notation | ||||
|         self.Person( | ||||
|             name='person meta', person_meta=self.PersonMeta(weight=0)).save() | ||||
|         self.assertAlmostEqual( | ||||
|             int(self.Person.objects.average('person_meta.weight')), 0) | ||||
|         self.assertAlmostEqual( | ||||
|             int(self.Person.objects.aggregate_average('person_meta.weight')), | ||||
|             0 | ||||
|         ) | ||||
|  | ||||
|         for i, weight in enumerate(ages): | ||||
|             self.Person( | ||||
| @@ -2793,19 +2853,11 @@ class QuerySetTest(unittest.TestCase): | ||||
|         self.assertAlmostEqual( | ||||
|             int(self.Person.objects.average('person_meta.weight')), avg | ||||
|         ) | ||||
|         self.assertAlmostEqual( | ||||
|             int(self.Person.objects.aggregate_average('person_meta.weight')), | ||||
|             avg | ||||
|         ) | ||||
|  | ||||
|         self.Person(name='test meta none').save() | ||||
|         self.assertEqual( | ||||
|             int(self.Person.objects.average('person_meta.weight')), avg | ||||
|         ) | ||||
|         self.assertEqual( | ||||
|             int(self.Person.objects.aggregate_average('person_meta.weight')), | ||||
|             avg | ||||
|         ) | ||||
|  | ||||
|         # test summing over a filtered queryset | ||||
|         over_50 = [a for a in ages if a >= 50] | ||||
| @@ -2814,10 +2866,6 @@ class QuerySetTest(unittest.TestCase): | ||||
|             self.Person.objects.filter(age__gte=50).average('age'), | ||||
|             avg | ||||
|         ) | ||||
|         self.assertEqual( | ||||
|             self.Person.objects.filter(age__gte=50).aggregate_average('age'), | ||||
|             avg | ||||
|         ) | ||||
|  | ||||
|     def test_sum(self): | ||||
|         """Ensure that field can be summed over correctly. | ||||
| @@ -2827,15 +2875,9 @@ class QuerySetTest(unittest.TestCase): | ||||
|             self.Person(name='test%s' % i, age=age).save() | ||||
|  | ||||
|         self.assertEqual(self.Person.objects.sum('age'), sum(ages)) | ||||
|         self.assertEqual( | ||||
|             self.Person.objects.aggregate_sum('age'), sum(ages) | ||||
|         ) | ||||
|  | ||||
|         self.Person(name='ageless person').save() | ||||
|         self.assertEqual(self.Person.objects.sum('age'), sum(ages)) | ||||
|         self.assertEqual( | ||||
|             self.Person.objects.aggregate_sum('age'), sum(ages) | ||||
|         ) | ||||
|  | ||||
|         for i, age in enumerate(ages): | ||||
|             self.Person(name='test meta%s' % | ||||
| @@ -2844,26 +2886,43 @@ class QuerySetTest(unittest.TestCase): | ||||
|         self.assertEqual( | ||||
|             self.Person.objects.sum('person_meta.weight'), sum(ages) | ||||
|         ) | ||||
|         self.assertEqual( | ||||
|             self.Person.objects.aggregate_sum('person_meta.weight'), | ||||
|             sum(ages) | ||||
|         ) | ||||
|  | ||||
|         self.Person(name='weightless person').save() | ||||
|         self.assertEqual(self.Person.objects.sum('age'), sum(ages)) | ||||
|         self.assertEqual( | ||||
|             self.Person.objects.aggregate_sum('age'), sum(ages) | ||||
|         ) | ||||
|  | ||||
|         # test summing over a filtered queryset | ||||
|         self.assertEqual( | ||||
|             self.Person.objects.filter(age__gte=50).sum('age'), | ||||
|             sum([a for a in ages if a >= 50]) | ||||
|         ) | ||||
|         self.assertEqual( | ||||
|             self.Person.objects.filter(age__gte=50).aggregate_sum('age'), | ||||
|             sum([a for a in ages if a >= 50]) | ||||
|         ) | ||||
|  | ||||
|     def test_sum_over_db_field(self): | ||||
|         """Ensure that a field mapped to a db field with a different name | ||||
|         can be summed over correctly. | ||||
|         """ | ||||
|         class UserVisit(Document): | ||||
|             num_visits = IntField(db_field='visits') | ||||
|  | ||||
|         UserVisit.drop_collection() | ||||
|  | ||||
|         UserVisit.objects.create(num_visits=10) | ||||
|         UserVisit.objects.create(num_visits=5) | ||||
|  | ||||
|         self.assertEqual(UserVisit.objects.sum('num_visits'), 15) | ||||
|  | ||||
|     def test_average_over_db_field(self): | ||||
|         """Ensure that a field mapped to a db field with a different name | ||||
|         can have its average computed correctly. | ||||
|         """ | ||||
|         class UserVisit(Document): | ||||
|             num_visits = IntField(db_field='visits') | ||||
|  | ||||
|         UserVisit.drop_collection() | ||||
|  | ||||
|         UserVisit.objects.create(num_visits=20) | ||||
|         UserVisit.objects.create(num_visits=10) | ||||
|  | ||||
|         self.assertEqual(UserVisit.objects.average('num_visits'), 15) | ||||
|  | ||||
|     def test_embedded_average(self): | ||||
|         class Pay(EmbeddedDocument): | ||||
| @@ -2876,21 +2935,12 @@ class QuerySetTest(unittest.TestCase): | ||||
|  | ||||
|         Doc.drop_collection() | ||||
|  | ||||
|         Doc(name=u"Wilson Junior", | ||||
|             pay=Pay(value=150)).save() | ||||
|         Doc(name='Wilson Junior', pay=Pay(value=150)).save() | ||||
|         Doc(name='Isabella Luanna', pay=Pay(value=530)).save() | ||||
|         Doc(name='Tayza mariana', pay=Pay(value=165)).save() | ||||
|         Doc(name='Eliana Costa', pay=Pay(value=115)).save() | ||||
|  | ||||
|         Doc(name=u"Isabella Luanna", | ||||
|             pay=Pay(value=530)).save() | ||||
|  | ||||
|         Doc(name=u"Tayza mariana", | ||||
|             pay=Pay(value=165)).save() | ||||
|  | ||||
|         Doc(name=u"Eliana Costa", | ||||
|             pay=Pay(value=115)).save() | ||||
|  | ||||
|         self.assertEqual( | ||||
|             Doc.objects.average('pay.value'), | ||||
|             240) | ||||
|         self.assertEqual(Doc.objects.average('pay.value'), 240) | ||||
|  | ||||
|     def test_embedded_array_average(self): | ||||
|         class Pay(EmbeddedDocument): | ||||
| @@ -2898,26 +2948,16 @@ class QuerySetTest(unittest.TestCase): | ||||
|  | ||||
|         class Doc(Document): | ||||
|             name = StringField() | ||||
|             pay = EmbeddedDocumentField( | ||||
|                 Pay) | ||||
|             pay = EmbeddedDocumentField(Pay) | ||||
|  | ||||
|         Doc.drop_collection() | ||||
|  | ||||
|         Doc(name=u"Wilson Junior", | ||||
|             pay=Pay(values=[150, 100])).save() | ||||
|         Doc(name='Wilson Junior', pay=Pay(values=[150, 100])).save() | ||||
|         Doc(name='Isabella Luanna', pay=Pay(values=[530, 100])).save() | ||||
|         Doc(name='Tayza mariana', pay=Pay(values=[165, 100])).save() | ||||
|         Doc(name='Eliana Costa', pay=Pay(values=[115, 100])).save() | ||||
|  | ||||
|         Doc(name=u"Isabella Luanna", | ||||
|             pay=Pay(values=[530, 100])).save() | ||||
|  | ||||
|         Doc(name=u"Tayza mariana", | ||||
|             pay=Pay(values=[165, 100])).save() | ||||
|  | ||||
|         Doc(name=u"Eliana Costa", | ||||
|             pay=Pay(values=[115, 100])).save() | ||||
|  | ||||
|         self.assertEqual( | ||||
|             Doc.objects.average('pay.values'), | ||||
|             170) | ||||
|         self.assertEqual(Doc.objects.average('pay.values'), 170) | ||||
|  | ||||
|     def test_array_average(self): | ||||
|         class Doc(Document): | ||||
| @@ -2930,9 +2970,7 @@ class QuerySetTest(unittest.TestCase): | ||||
|         Doc(values=[165, 100]).save() | ||||
|         Doc(values=[115, 100]).save() | ||||
|  | ||||
|         self.assertEqual( | ||||
|             Doc.objects.average('values'), | ||||
|             170) | ||||
|         self.assertEqual(Doc.objects.average('values'), 170) | ||||
|  | ||||
|     def test_embedded_sum(self): | ||||
|         class Pay(EmbeddedDocument): | ||||
| @@ -2940,26 +2978,16 @@ class QuerySetTest(unittest.TestCase): | ||||
|  | ||||
|         class Doc(Document): | ||||
|             name = StringField() | ||||
|             pay = EmbeddedDocumentField( | ||||
|                 Pay) | ||||
|             pay = EmbeddedDocumentField(Pay) | ||||
|  | ||||
|         Doc.drop_collection() | ||||
|  | ||||
|         Doc(name=u"Wilson Junior", | ||||
|             pay=Pay(value=150)).save() | ||||
|         Doc(name='Wilson Junior', pay=Pay(value=150)).save() | ||||
|         Doc(name='Isabella Luanna', pay=Pay(value=530)).save() | ||||
|         Doc(name='Tayza mariana', pay=Pay(value=165)).save() | ||||
|         Doc(name='Eliana Costa', pay=Pay(value=115)).save() | ||||
|  | ||||
|         Doc(name=u"Isabella Luanna", | ||||
|             pay=Pay(value=530)).save() | ||||
|  | ||||
|         Doc(name=u"Tayza mariana", | ||||
|             pay=Pay(value=165)).save() | ||||
|  | ||||
|         Doc(name=u"Eliana Costa", | ||||
|             pay=Pay(value=115)).save() | ||||
|  | ||||
|         self.assertEqual( | ||||
|             Doc.objects.sum('pay.value'), | ||||
|             960) | ||||
|         self.assertEqual(Doc.objects.sum('pay.value'), 960) | ||||
|  | ||||
|     def test_embedded_array_sum(self): | ||||
|         class Pay(EmbeddedDocument): | ||||
| @@ -2967,26 +2995,16 @@ class QuerySetTest(unittest.TestCase): | ||||
|  | ||||
|         class Doc(Document): | ||||
|             name = StringField() | ||||
|             pay = EmbeddedDocumentField( | ||||
|                 Pay) | ||||
|             pay = EmbeddedDocumentField(Pay) | ||||
|  | ||||
|         Doc.drop_collection() | ||||
|  | ||||
|         Doc(name=u"Wilson Junior", | ||||
|             pay=Pay(values=[150, 100])).save() | ||||
|         Doc(name='Wilson Junior', pay=Pay(values=[150, 100])).save() | ||||
|         Doc(name='Isabella Luanna', pay=Pay(values=[530, 100])).save() | ||||
|         Doc(name='Tayza mariana', pay=Pay(values=[165, 100])).save() | ||||
|         Doc(name='Eliana Costa', pay=Pay(values=[115, 100])).save() | ||||
|  | ||||
|         Doc(name=u"Isabella Luanna", | ||||
|             pay=Pay(values=[530, 100])).save() | ||||
|  | ||||
|         Doc(name=u"Tayza mariana", | ||||
|             pay=Pay(values=[165, 100])).save() | ||||
|  | ||||
|         Doc(name=u"Eliana Costa", | ||||
|             pay=Pay(values=[115, 100])).save() | ||||
|  | ||||
|         self.assertEqual( | ||||
|             Doc.objects.sum('pay.values'), | ||||
|             1360) | ||||
|         self.assertEqual(Doc.objects.sum('pay.values'), 1360) | ||||
|  | ||||
|     def test_array_sum(self): | ||||
|         class Doc(Document): | ||||
| @@ -2999,9 +3017,7 @@ class QuerySetTest(unittest.TestCase): | ||||
|         Doc(values=[165, 100]).save() | ||||
|         Doc(values=[115, 100]).save() | ||||
|  | ||||
|         self.assertEqual( | ||||
|             Doc.objects.sum('values'), | ||||
|             1360) | ||||
|         self.assertEqual(Doc.objects.sum('values'), 1360) | ||||
|  | ||||
|     def test_distinct(self): | ||||
|         """Ensure that the QuerySet.distinct method works. | ||||
| @@ -3178,13 +3194,11 @@ class QuerySetTest(unittest.TestCase): | ||||
|         mark_twain = Author(name="Mark Twain") | ||||
|         john_tolkien = Author(name="John Ronald Reuel Tolkien") | ||||
|  | ||||
|         book = Book(title="Tom Sawyer", authors=[mark_twain]).save() | ||||
|         book = Book( | ||||
|             title="The Lord of the Rings", authors=[john_tolkien]).save() | ||||
|         book = Book( | ||||
|             title="The Stories", authors=[mark_twain, john_tolkien]).save() | ||||
|         authors = Book.objects.distinct("authors") | ||||
|         Book.objects.create(title="Tom Sawyer", authors=[mark_twain]) | ||||
|         Book.objects.create(title="The Lord of the Rings", authors=[john_tolkien]) | ||||
|         Book.objects.create(title="The Stories", authors=[mark_twain, john_tolkien]) | ||||
|  | ||||
|         authors = Book.objects.distinct("authors") | ||||
|         self.assertEqual(authors, [mark_twain, john_tolkien]) | ||||
|  | ||||
|     def test_distinct_ListField_EmbeddedDocumentField_EmbeddedDocumentField(self): | ||||
| @@ -3214,17 +3228,14 @@ class QuerySetTest(unittest.TestCase): | ||||
|         mark_twain = Author(name="Mark Twain", country=scotland) | ||||
|         john_tolkien = Author(name="John Ronald Reuel Tolkien", country=tibet) | ||||
|  | ||||
|         book = Book(title="Tom Sawyer", authors=[mark_twain]).save() | ||||
|         book = Book( | ||||
|             title="The Lord of the Rings", authors=[john_tolkien]).save() | ||||
|         book = Book( | ||||
|             title="The Stories", authors=[mark_twain, john_tolkien]).save() | ||||
|         country_list = Book.objects.distinct("authors.country") | ||||
|         Book.objects.create(title="Tom Sawyer", authors=[mark_twain]) | ||||
|         Book.objects.create(title="The Lord of the Rings", authors=[john_tolkien]) | ||||
|         Book.objects.create(title="The Stories", authors=[mark_twain, john_tolkien]) | ||||
|  | ||||
|         country_list = Book.objects.distinct("authors.country") | ||||
|         self.assertEqual(country_list, [scotland, tibet]) | ||||
|  | ||||
|         continent_list = Book.objects.distinct("authors.country.continent") | ||||
|  | ||||
|         self.assertEqual(continent_list, [europe, asia]) | ||||
|  | ||||
|     def test_distinct_ListField_ReferenceField(self): | ||||
| @@ -3256,7 +3267,7 @@ class QuerySetTest(unittest.TestCase): | ||||
|         class BlogPost(Document): | ||||
|             tags = ListField(StringField()) | ||||
|             deleted = BooleanField(default=False) | ||||
|             date = DateTimeField(default=datetime.now) | ||||
|             date = DateTimeField(default=datetime.datetime.now) | ||||
|  | ||||
|             @queryset_manager | ||||
|             def objects(cls, qryset): | ||||
| @@ -3613,6 +3624,15 @@ class QuerySetTest(unittest.TestCase): | ||||
|         self.assertEqual(MyDoc.objects.count(), 10) | ||||
|         self.assertEqual(MyDoc.objects.none().count(), 0) | ||||
|  | ||||
|     def test_count_list_embedded(self): | ||||
|         class B(EmbeddedDocument): | ||||
|             c = StringField() | ||||
|  | ||||
|         class A(Document): | ||||
|             b = ListField(EmbeddedDocumentField(B)) | ||||
|  | ||||
|         self.assertEqual(A.objects(b=[{'c': 'c'}]).count(), 0) | ||||
|  | ||||
|     def test_call_after_limits_set(self): | ||||
|         """Ensure that re-filtering after slicing works | ||||
|         """ | ||||
| @@ -4070,14 +4090,14 @@ class QuerySetTest(unittest.TestCase): | ||||
|         self.assertEqual( | ||||
|             "A0", "%s" % self.Person.objects.scalar('name').order_by('name')[0]) | ||||
|         if PY3: | ||||
|             self.assertEqual( | ||||
|                 "['A1', 'A2']",  "%s" % self.Person.objects.order_by('age').scalar('name')[1:3]) | ||||
|             self.assertEqual("['A51', 'A52']",  "%s" % self.Person.objects.order_by( | ||||
|             self.assertEqual("['A1', 'A2']", "%s" % self.Person.objects.order_by( | ||||
|                 'age').scalar('name')[1:3]) | ||||
|             self.assertEqual("['A51', 'A52']", "%s" % self.Person.objects.order_by( | ||||
|                 'age').scalar('name')[51:53]) | ||||
|         else: | ||||
|             self.assertEqual("[u'A1', u'A2']",  "%s" % self.Person.objects.order_by( | ||||
|             self.assertEqual("[u'A1', u'A2']", "%s" % self.Person.objects.order_by( | ||||
|                 'age').scalar('name')[1:3]) | ||||
|             self.assertEqual("[u'A51', u'A52']",  "%s" % self.Person.objects.order_by( | ||||
|             self.assertEqual("[u'A51', u'A52']", "%s" % self.Person.objects.order_by( | ||||
|                 'age').scalar('name')[51:53]) | ||||
|  | ||||
|         # with_id and in_bulk | ||||
| @@ -4086,12 +4106,12 @@ class QuerySetTest(unittest.TestCase): | ||||
|                          self.Person.objects.scalar('name').with_id(person.id)) | ||||
|  | ||||
|         pks = self.Person.objects.order_by('age').scalar('pk')[1:3] | ||||
|         names = self.Person.objects.scalar('name').in_bulk(list(pks)).values() | ||||
|         if PY3: | ||||
|             self.assertEqual("['A1', 'A2']",  "%s" % sorted( | ||||
|                 self.Person.objects.scalar('name').in_bulk(list(pks)).values())) | ||||
|             expected = "['A1', 'A2']" | ||||
|         else: | ||||
|             self.assertEqual("[u'A1', u'A2']",  "%s" % sorted( | ||||
|                 self.Person.objects.scalar('name').in_bulk(list(pks)).values())) | ||||
|             expected = "[u'A1', u'A2']" | ||||
|         self.assertEqual(expected, "%s" % sorted(names)) | ||||
|  | ||||
|     def test_elem_match(self): | ||||
|         class Foo(EmbeddedDocument): | ||||
| @@ -4114,6 +4134,10 @@ class QuerySetTest(unittest.TestCase): | ||||
|                       Foo(shape="circle", color="purple", thick=False)]) | ||||
|         b2.save() | ||||
|  | ||||
|         b3 = Bar(foo=[Foo(shape="square", thick=True), | ||||
|                       Foo(shape="circle", color="purple", thick=False)]) | ||||
|         b3.save() | ||||
|  | ||||
|         ak = list( | ||||
|             Bar.objects(foo__match={'shape': "square", "color": "purple"})) | ||||
|         self.assertEqual([b1], ak) | ||||
| @@ -4133,6 +4157,13 @@ class QuerySetTest(unittest.TestCase): | ||||
|             Bar.objects(foo__match={'shape': "square", "color__exists": True})) | ||||
|         self.assertEqual([b1, b2], ak) | ||||
|  | ||||
|         ak = list( | ||||
|             Bar.objects(foo__elemMatch={'shape': "square", "color__exists": False})) | ||||
|         self.assertEqual([b3], ak) | ||||
|  | ||||
|         ak = list( | ||||
|             Bar.objects(foo__match={'shape': "square", "color__exists": False})) | ||||
|         self.assertEqual([b3], ak) | ||||
|  | ||||
|     def test_upsert_includes_cls(self): | ||||
|         """Upserts should include _cls information for inheritable classes | ||||
| @@ -4177,7 +4208,7 @@ class QuerySetTest(unittest.TestCase): | ||||
|             txt = StringField() | ||||
|  | ||||
|             meta = { | ||||
|                 'indexes': [ 'txt' ] | ||||
|                 'indexes': ['txt'] | ||||
|             } | ||||
|  | ||||
|         Bar.drop_collection() | ||||
| @@ -4192,49 +4223,49 @@ class QuerySetTest(unittest.TestCase): | ||||
|  | ||||
|         # read_preference as a kwarg | ||||
|         bars = Bar.objects(read_preference=ReadPreference.SECONDARY_PREFERRED) | ||||
|         self.assertEqual( | ||||
|             bars._read_preference, ReadPreference.SECONDARY_PREFERRED) | ||||
|         self.assertEqual(bars._read_preference, | ||||
|                          ReadPreference.SECONDARY_PREFERRED) | ||||
|         self.assertEqual(bars._cursor._Cursor__read_preference, | ||||
|             ReadPreference.SECONDARY_PREFERRED) | ||||
|                          ReadPreference.SECONDARY_PREFERRED) | ||||
|  | ||||
|         # read_preference as a query set method | ||||
|         bars = Bar.objects.read_preference(ReadPreference.SECONDARY_PREFERRED) | ||||
|         self.assertEqual( | ||||
|             bars._read_preference, ReadPreference.SECONDARY_PREFERRED) | ||||
|         self.assertEqual(bars._read_preference, | ||||
|                          ReadPreference.SECONDARY_PREFERRED) | ||||
|         self.assertEqual(bars._cursor._Cursor__read_preference, | ||||
|             ReadPreference.SECONDARY_PREFERRED) | ||||
|                          ReadPreference.SECONDARY_PREFERRED) | ||||
|  | ||||
|         # read_preference after skip | ||||
|         bars = Bar.objects.skip(1) \ | ||||
|             .read_preference(ReadPreference.SECONDARY_PREFERRED) | ||||
|         self.assertEqual( | ||||
|             bars._read_preference, ReadPreference.SECONDARY_PREFERRED) | ||||
|         self.assertEqual(bars._read_preference, | ||||
|                          ReadPreference.SECONDARY_PREFERRED) | ||||
|         self.assertEqual(bars._cursor._Cursor__read_preference, | ||||
|             ReadPreference.SECONDARY_PREFERRED) | ||||
|                          ReadPreference.SECONDARY_PREFERRED) | ||||
|  | ||||
|         # read_preference after limit | ||||
|         bars = Bar.objects.limit(1) \ | ||||
|             .read_preference(ReadPreference.SECONDARY_PREFERRED) | ||||
|         self.assertEqual( | ||||
|             bars._read_preference, ReadPreference.SECONDARY_PREFERRED) | ||||
|         self.assertEqual(bars._read_preference, | ||||
|                          ReadPreference.SECONDARY_PREFERRED) | ||||
|         self.assertEqual(bars._cursor._Cursor__read_preference, | ||||
|             ReadPreference.SECONDARY_PREFERRED) | ||||
|                          ReadPreference.SECONDARY_PREFERRED) | ||||
|  | ||||
|         # read_preference after order_by | ||||
|         bars = Bar.objects.order_by('txt') \ | ||||
|             .read_preference(ReadPreference.SECONDARY_PREFERRED) | ||||
|         self.assertEqual( | ||||
|             bars._read_preference, ReadPreference.SECONDARY_PREFERRED) | ||||
|         self.assertEqual(bars._read_preference, | ||||
|                          ReadPreference.SECONDARY_PREFERRED) | ||||
|         self.assertEqual(bars._cursor._Cursor__read_preference, | ||||
|             ReadPreference.SECONDARY_PREFERRED) | ||||
|                          ReadPreference.SECONDARY_PREFERRED) | ||||
|  | ||||
|         # read_preference after hint | ||||
|         bars = Bar.objects.hint([('txt', 1)]) \ | ||||
|             .read_preference(ReadPreference.SECONDARY_PREFERRED) | ||||
|         self.assertEqual( | ||||
|             bars._read_preference, ReadPreference.SECONDARY_PREFERRED) | ||||
|         self.assertEqual(bars._read_preference, | ||||
|                          ReadPreference.SECONDARY_PREFERRED) | ||||
|         self.assertEqual(bars._cursor._Cursor__read_preference, | ||||
|             ReadPreference.SECONDARY_PREFERRED) | ||||
|                          ReadPreference.SECONDARY_PREFERRED) | ||||
|  | ||||
|     def test_json_simple(self): | ||||
|  | ||||
| @@ -4270,7 +4301,7 @@ class QuerySetTest(unittest.TestCase): | ||||
|             int_field = IntField(default=1) | ||||
|             float_field = FloatField(default=1.1) | ||||
|             boolean_field = BooleanField(default=True) | ||||
|             datetime_field = DateTimeField(default=datetime.now) | ||||
|             datetime_field = DateTimeField(default=datetime.datetime.now) | ||||
|             embedded_document_field = EmbeddedDocumentField( | ||||
|                 EmbeddedDoc, default=lambda: EmbeddedDoc()) | ||||
|             list_field = ListField(default=lambda: [1, 2, 3]) | ||||
| @@ -4280,7 +4311,7 @@ class QuerySetTest(unittest.TestCase): | ||||
|                 Simple, default=lambda: Simple().save()) | ||||
|             map_field = MapField(IntField(), default=lambda: {"simple": 1}) | ||||
|             decimal_field = DecimalField(default=1.0) | ||||
|             complex_datetime_field = ComplexDateTimeField(default=datetime.now) | ||||
|             complex_datetime_field = ComplexDateTimeField(default=datetime.datetime.now) | ||||
|             url_field = URLField(default="http://mongoengine.org") | ||||
|             dynamic_field = DynamicField(default=1) | ||||
|             generic_reference_field = GenericReferenceField( | ||||
| @@ -4627,8 +4658,7 @@ class QuerySetTest(unittest.TestCase): | ||||
|         B.drop_collection() | ||||
|  | ||||
|         a = A.objects.create(id='custom_id') | ||||
|  | ||||
|         b = B.objects.create(a=a) | ||||
|         B.objects.create(a=a) | ||||
|  | ||||
|         self.assertEqual(B.objects.count(), 1) | ||||
|         self.assertEqual(B.objects.get(a=a).a, a) | ||||
| @@ -4888,5 +4918,56 @@ class QuerySetTest(unittest.TestCase): | ||||
|  | ||||
|         self.assertEqual(1, Doc.objects(item__type__="axe").count()) | ||||
|  | ||||
|     def test_len_during_iteration(self): | ||||
|         """Tests that calling len on a queyset during iteration doesn't | ||||
|         stop paging. | ||||
|         """ | ||||
|         class Data(Document): | ||||
|             pass | ||||
|  | ||||
|         for i in xrange(300): | ||||
|             Data().save() | ||||
|  | ||||
|         records = Data.objects.limit(250) | ||||
|  | ||||
|         # This should pull all 250 docs from mongo and populate the result | ||||
|         # cache | ||||
|         len(records) | ||||
|  | ||||
|         # Assert that iterating over documents in the qs touches every | ||||
|         # document even if we call len(qs) midway through the iteration. | ||||
|         for i, r in enumerate(records): | ||||
|             if i == 58: | ||||
|                 len(records) | ||||
|         self.assertEqual(i, 249) | ||||
|  | ||||
|         # Assert the same behavior is true even if we didn't pre-populate the | ||||
|         # result cache. | ||||
|         records = Data.objects.limit(250) | ||||
|         for i, r in enumerate(records): | ||||
|             if i == 58: | ||||
|                 len(records) | ||||
|         self.assertEqual(i, 249) | ||||
|  | ||||
|     def test_iteration_within_iteration(self): | ||||
|         """You should be able to reliably iterate over all the documents | ||||
|         in a given queryset even if there are multiple iterations of it | ||||
|         happening at the same time. | ||||
|         """ | ||||
|         class Data(Document): | ||||
|             pass | ||||
|  | ||||
|         for i in xrange(300): | ||||
|             Data().save() | ||||
|  | ||||
|         qs = Data.objects.limit(250) | ||||
|         for i, doc in enumerate(qs): | ||||
|             for j, doc2 in enumerate(qs): | ||||
|                 pass | ||||
|  | ||||
|         self.assertEqual(i, 249) | ||||
|         self.assertEqual(j, 249) | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|     unittest.main() | ||||
|   | ||||
| @@ -1,11 +1,7 @@ | ||||
| import sys | ||||
| sys.path[0:0] = [""] | ||||
|  | ||||
| import unittest | ||||
|  | ||||
| from mongoengine import * | ||||
| from mongoengine.queryset import Q | ||||
| from mongoengine.queryset import transform | ||||
| from mongoengine.queryset import Q, transform | ||||
|  | ||||
| __all__ = ("TransformTest",) | ||||
|  | ||||
| @@ -41,8 +37,8 @@ class TransformTest(unittest.TestCase): | ||||
|         DicDoc.drop_collection() | ||||
|         Doc.drop_collection() | ||||
|  | ||||
|         DicDoc().save() | ||||
|         doc = Doc().save() | ||||
|         dic_doc = DicDoc().save() | ||||
|  | ||||
|         for k, v in (("set", "$set"), ("set_on_insert", "$setOnInsert"), ("push", "$push")): | ||||
|             update = transform.update(DicDoc, **{"%s__dictField__test" % k: doc}) | ||||
| @@ -55,7 +51,6 @@ class TransformTest(unittest.TestCase): | ||||
|         update = transform.update(DicDoc, pull__dictField__test=doc) | ||||
|         self.assertTrue(isinstance(update["$pull"]["dictField"]["test"], dict)) | ||||
|  | ||||
|  | ||||
|     def test_query_field_name(self): | ||||
|         """Ensure that the correct field name is used when querying. | ||||
|         """ | ||||
| @@ -156,26 +151,33 @@ class TransformTest(unittest.TestCase): | ||||
|         class Doc(Document): | ||||
|             meta = {'allow_inheritance': False} | ||||
|  | ||||
|         raw_query = Doc.objects(__raw__={'deleted': False, | ||||
|                                 'scraped': 'yes', | ||||
|                                 '$nor': [{'views.extracted': 'no'}, | ||||
|                                          {'attachments.views.extracted':'no'}] | ||||
|                                 })._query | ||||
|         raw_query = Doc.objects(__raw__={ | ||||
|             'deleted': False, | ||||
|             'scraped': 'yes', | ||||
|             '$nor': [ | ||||
|                 {'views.extracted': 'no'}, | ||||
|                 {'attachments.views.extracted': 'no'} | ||||
|             ] | ||||
|         })._query | ||||
|  | ||||
|         expected = {'deleted': False, 'scraped': 'yes', | ||||
|                     '$nor': [{'views.extracted': 'no'}, | ||||
|                              {'attachments.views.extracted': 'no'}]} | ||||
|         self.assertEqual(expected, raw_query) | ||||
|         self.assertEqual(raw_query, { | ||||
|             'deleted': False, | ||||
|             'scraped': 'yes', | ||||
|             '$nor': [ | ||||
|                 {'views.extracted': 'no'}, | ||||
|                 {'attachments.views.extracted': 'no'} | ||||
|             ] | ||||
|         }) | ||||
|  | ||||
|     def test_geojson_PointField(self): | ||||
|         class Location(Document): | ||||
|             loc = PointField() | ||||
|  | ||||
|         update = transform.update(Location, set__loc=[1, 2]) | ||||
|         self.assertEqual(update, {'$set': {'loc': {"type": "Point", "coordinates": [1,2]}}}) | ||||
|         self.assertEqual(update, {'$set': {'loc': {"type": "Point", "coordinates": [1, 2]}}}) | ||||
|  | ||||
|         update = transform.update(Location, set__loc={"type": "Point", "coordinates": [1,2]}) | ||||
|         self.assertEqual(update, {'$set': {'loc': {"type": "Point", "coordinates": [1,2]}}}) | ||||
|         update = transform.update(Location, set__loc={"type": "Point", "coordinates": [1, 2]}) | ||||
|         self.assertEqual(update, {'$set': {'loc': {"type": "Point", "coordinates": [1, 2]}}}) | ||||
|  | ||||
|     def test_geojson_LineStringField(self): | ||||
|         class Location(Document): | ||||
| @@ -224,6 +226,10 @@ class TransformTest(unittest.TestCase): | ||||
|         self.assertEqual(1, Doc.objects(item__type__="axe").count()) | ||||
|         self.assertEqual(1, Doc.objects(item__name__="Heroic axe").count()) | ||||
|  | ||||
|         Doc.objects(id=doc.id).update(set__item__type__='sword') | ||||
|         self.assertEqual(1, Doc.objects(item__type__="sword").count()) | ||||
|         self.assertEqual(0, Doc.objects(item__type__="axe").count()) | ||||
|  | ||||
|     def test_understandable_error_raised(self): | ||||
|         class Event(Document): | ||||
|             title = StringField() | ||||
| @@ -234,5 +240,6 @@ class TransformTest(unittest.TestCase): | ||||
|         events = Event.objects(location__within=box) | ||||
|         self.assertRaises(InvalidQueryError, lambda: events.count()) | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|     unittest.main() | ||||
|   | ||||
| @@ -1,14 +1,12 @@ | ||||
| import sys | ||||
| sys.path[0:0] = [""] | ||||
|  | ||||
| import datetime | ||||
| import re | ||||
| import unittest | ||||
|  | ||||
| from bson import ObjectId | ||||
| from datetime import datetime | ||||
|  | ||||
| from mongoengine import * | ||||
| from mongoengine.queryset import Q | ||||
| from mongoengine.errors import InvalidQueryError | ||||
| from mongoengine.queryset import Q | ||||
|  | ||||
| __all__ = ("QTest",) | ||||
|  | ||||
| @@ -132,12 +130,12 @@ class QTest(unittest.TestCase): | ||||
|         TestDoc(x=10).save() | ||||
|         TestDoc(y=True).save() | ||||
|  | ||||
|         self.assertEqual(query, | ||||
|         {'$and': [ | ||||
|             {'$or': [{'x': {'$gt': 0}}, {'x': {'$exists': False}}]}, | ||||
|             {'$or': [{'x': {'$lt': 100}}, {'y': True}]} | ||||
|         ]}) | ||||
|  | ||||
|         self.assertEqual(query, { | ||||
|             '$and': [ | ||||
|                 {'$or': [{'x': {'$gt': 0}}, {'x': {'$exists': False}}]}, | ||||
|                 {'$or': [{'x': {'$lt': 100}}, {'y': True}]} | ||||
|             ] | ||||
|         }) | ||||
|         self.assertEqual(2, TestDoc.objects(q1 & q2).count()) | ||||
|  | ||||
|     def test_or_and_or_combination(self): | ||||
| @@ -157,15 +155,14 @@ class QTest(unittest.TestCase): | ||||
|         q2 = (Q(x__lt=100) & (Q(y=False) | Q(y__exists=False))) | ||||
|         query = (q1 | q2).to_query(TestDoc) | ||||
|  | ||||
|         self.assertEqual(query, | ||||
|             {'$or': [ | ||||
|         self.assertEqual(query, { | ||||
|             '$or': [ | ||||
|                 {'$and': [{'x': {'$gt': 0}}, | ||||
|                           {'$or': [{'y': True}, {'y': {'$exists': False}}]}]}, | ||||
|                 {'$and': [{'x': {'$lt': 100}}, | ||||
|                           {'$or': [{'y': False}, {'y': {'$exists': False}}]}]} | ||||
|             ]} | ||||
|         ) | ||||
|  | ||||
|             ] | ||||
|         }) | ||||
|         self.assertEqual(2, TestDoc.objects(q1 | q2).count()) | ||||
|  | ||||
|     def test_multiple_occurence_in_field(self): | ||||
| @@ -215,19 +212,19 @@ class QTest(unittest.TestCase): | ||||
|  | ||||
|         BlogPost.drop_collection() | ||||
|  | ||||
|         post1 = BlogPost(title='Test 1', publish_date=datetime(2010, 1, 8), published=False) | ||||
|         post1 = BlogPost(title='Test 1', publish_date=datetime.datetime(2010, 1, 8), published=False) | ||||
|         post1.save() | ||||
|  | ||||
|         post2 = BlogPost(title='Test 2', publish_date=datetime(2010, 1, 15), published=True) | ||||
|         post2 = BlogPost(title='Test 2', publish_date=datetime.datetime(2010, 1, 15), published=True) | ||||
|         post2.save() | ||||
|  | ||||
|         post3 = BlogPost(title='Test 3', published=True) | ||||
|         post3.save() | ||||
|  | ||||
|         post4 = BlogPost(title='Test 4', publish_date=datetime(2010, 1, 8)) | ||||
|         post4 = BlogPost(title='Test 4', publish_date=datetime.datetime(2010, 1, 8)) | ||||
|         post4.save() | ||||
|  | ||||
|         post5 = BlogPost(title='Test 1', publish_date=datetime(2010, 1, 15)) | ||||
|         post5 = BlogPost(title='Test 1', publish_date=datetime.datetime(2010, 1, 15)) | ||||
|         post5.save() | ||||
|  | ||||
|         post6 = BlogPost(title='Test 1', published=False) | ||||
| @@ -250,7 +247,7 @@ class QTest(unittest.TestCase): | ||||
|         self.assertTrue(all(obj.id in posts for obj in published_posts)) | ||||
|  | ||||
|         # Check Q object combination | ||||
|         date = datetime(2010, 1, 10) | ||||
|         date = datetime.datetime(2010, 1, 10) | ||||
|         q = BlogPost.objects(Q(publish_date__lte=date) | Q(published=True)) | ||||
|         posts = [post.id for post in q] | ||||
|  | ||||
| @@ -273,8 +270,10 @@ class QTest(unittest.TestCase): | ||||
|         # Test invalid query objs | ||||
|         def wrong_query_objs(): | ||||
|             self.Person.objects('user1') | ||||
|  | ||||
|         def wrong_query_objs_filter(): | ||||
|             self.Person.objects('user1') | ||||
|  | ||||
|         self.assertRaises(InvalidQueryError, wrong_query_objs) | ||||
|         self.assertRaises(InvalidQueryError, wrong_query_objs_filter) | ||||
|  | ||||
| @@ -284,7 +283,6 @@ class QTest(unittest.TestCase): | ||||
|         person = self.Person(name='Guido van Rossum') | ||||
|         person.save() | ||||
|  | ||||
|         import re | ||||
|         obj = self.Person.objects(Q(name=re.compile('^Gui'))).first() | ||||
|         self.assertEqual(obj, person) | ||||
|         obj = self.Person.objects(Q(name=re.compile('^gui'))).first() | ||||
|   | ||||
| @@ -88,6 +88,40 @@ class ConnectionTest(unittest.TestCase): | ||||
|         conn = get_connection('testdb7') | ||||
|         self.assertTrue(isinstance(conn, mongomock.MongoClient)) | ||||
|  | ||||
|     def test_connect_with_host_list(self): | ||||
|         """Ensure that the connect() method works when host is a list | ||||
|  | ||||
|         Uses mongomock to test w/o needing multiple mongod/mongos processes | ||||
|         """ | ||||
|         try: | ||||
|             import mongomock | ||||
|         except ImportError: | ||||
|             raise SkipTest('you need mongomock installed to run this testcase') | ||||
|  | ||||
|         connect(host=['mongomock://localhost']) | ||||
|         conn = get_connection() | ||||
|         self.assertTrue(isinstance(conn, mongomock.MongoClient)) | ||||
|  | ||||
|         connect(host=['mongodb://localhost'], is_mock=True,  alias='testdb2') | ||||
|         conn = get_connection('testdb2') | ||||
|         self.assertTrue(isinstance(conn, mongomock.MongoClient)) | ||||
|  | ||||
|         connect(host=['localhost'], is_mock=True,  alias='testdb3') | ||||
|         conn = get_connection('testdb3') | ||||
|         self.assertTrue(isinstance(conn, mongomock.MongoClient)) | ||||
|  | ||||
|         connect(host=['mongomock://localhost:27017', 'mongomock://localhost:27018'], alias='testdb4') | ||||
|         conn = get_connection('testdb4') | ||||
|         self.assertTrue(isinstance(conn, mongomock.MongoClient)) | ||||
|  | ||||
|         connect(host=['mongodb://localhost:27017', 'mongodb://localhost:27018'], is_mock=True,  alias='testdb5') | ||||
|         conn = get_connection('testdb5') | ||||
|         self.assertTrue(isinstance(conn, mongomock.MongoClient)) | ||||
|  | ||||
|         connect(host=['localhost:27017', 'localhost:27018'], is_mock=True,  alias='testdb6') | ||||
|         conn = get_connection('testdb6') | ||||
|         self.assertTrue(isinstance(conn, mongomock.MongoClient)) | ||||
|  | ||||
|     def test_disconnect(self): | ||||
|         """Ensure that the disconnect() method works properly | ||||
|         """ | ||||
| @@ -140,19 +174,9 @@ class ConnectionTest(unittest.TestCase): | ||||
|         c.mongoenginetest.system.users.remove({}) | ||||
|  | ||||
|     def test_connect_uri_without_db(self): | ||||
|         """Ensure connect() method works properly with uri's without database_name | ||||
|         """Ensure connect() method works properly if the URI doesn't | ||||
|         include a database name. | ||||
|         """ | ||||
|         c = connect(db='mongoenginetest', alias='admin') | ||||
|         c.admin.system.users.remove({}) | ||||
|         c.mongoenginetest.system.users.remove({}) | ||||
|  | ||||
|         c.admin.add_user("admin", "password") | ||||
|         c.admin.authenticate("admin", "password") | ||||
|         c.mongoenginetest.add_user("username", "password") | ||||
|  | ||||
|         if not IS_PYMONGO_3: | ||||
|             self.assertRaises(ConnectionError, connect, "testdb_uri_bad", host='mongodb://test:password@localhost') | ||||
|  | ||||
|         connect("mongoenginetest", host='mongodb://localhost/') | ||||
|  | ||||
|         conn = get_connection() | ||||
| @@ -162,8 +186,31 @@ class ConnectionTest(unittest.TestCase): | ||||
|         self.assertTrue(isinstance(db, pymongo.database.Database)) | ||||
|         self.assertEqual(db.name, 'mongoenginetest') | ||||
|  | ||||
|         c.admin.system.users.remove({}) | ||||
|         c.mongoenginetest.system.users.remove({}) | ||||
|     def test_connect_uri_default_db(self): | ||||
|         """Ensure connect() defaults to the right database name if | ||||
|         the URI and the database_name don't explicitly specify it. | ||||
|         """ | ||||
|         connect(host='mongodb://localhost/') | ||||
|  | ||||
|         conn = get_connection() | ||||
|         self.assertTrue(isinstance(conn, pymongo.mongo_client.MongoClient)) | ||||
|  | ||||
|         db = get_db() | ||||
|         self.assertTrue(isinstance(db, pymongo.database.Database)) | ||||
|         self.assertEqual(db.name, 'test') | ||||
|  | ||||
|     def test_uri_without_credentials_doesnt_override_conn_settings(self): | ||||
|         """Ensure connect() uses the username & password params if the URI | ||||
|         doesn't explicitly specify them. | ||||
|         """ | ||||
|         c = connect(host='mongodb://localhost/mongoenginetest', | ||||
|                     username='user', | ||||
|                     password='pass') | ||||
|  | ||||
|         # OperationFailure means that mongoengine attempted authentication | ||||
|         # w/ the provided username/password and failed - that's the desired | ||||
|         # behavior. If the MongoDB URI would override the credentials | ||||
|         self.assertRaises(OperationFailure, get_db) | ||||
|  | ||||
|     def test_connect_uri_with_authsource(self): | ||||
|         """Ensure that the connect() method works well with | ||||
|   | ||||
| @@ -1,5 +1,6 @@ | ||||
| import unittest | ||||
| from mongoengine.base.datastructures import StrictDict, SemiStrictDict  | ||||
|  | ||||
| from mongoengine.base.datastructures import StrictDict, SemiStrictDict | ||||
|  | ||||
|  | ||||
| class TestStrictDict(unittest.TestCase): | ||||
| @@ -13,9 +14,17 @@ class TestStrictDict(unittest.TestCase): | ||||
|         d = self.dtype(a=1, b=1, c=1) | ||||
|         self.assertEqual((d.a, d.b, d.c), (1, 1, 1)) | ||||
|  | ||||
|     def test_repr(self): | ||||
|         d = self.dtype(a=1, b=2, c=3) | ||||
|         self.assertEqual(repr(d), '{"a": 1, "b": 2, "c": 3}') | ||||
|  | ||||
|         # make sure quotes are escaped properly | ||||
|         d = self.dtype(a='"', b="'", c="") | ||||
|         self.assertEqual(repr(d), '{"a": \'"\', "b": "\'", "c": \'\'}') | ||||
|  | ||||
|     def test_init_fails_on_nonexisting_attrs(self): | ||||
|         self.assertRaises(AttributeError, lambda: self.dtype(a=1, b=2, d=3)) | ||||
|          | ||||
|  | ||||
|     def test_eq(self): | ||||
|         d = self.dtype(a=1, b=1, c=1) | ||||
|         dd = self.dtype(a=1, b=1, c=1) | ||||
| @@ -24,7 +33,7 @@ class TestStrictDict(unittest.TestCase): | ||||
|         g = self.strict_dict_class(("a", "b", "c", "d"))(a=1, b=1, c=1, d=1) | ||||
|         h = self.strict_dict_class(("a", "c", "b"))(a=1, b=1, c=1) | ||||
|         i = self.strict_dict_class(("a", "c", "b"))(a=1, b=1, c=2) | ||||
|          | ||||
|  | ||||
|         self.assertEqual(d, dd) | ||||
|         self.assertNotEqual(d, e) | ||||
|         self.assertNotEqual(d, f) | ||||
| @@ -38,19 +47,19 @@ class TestStrictDict(unittest.TestCase): | ||||
|         d.a = 1 | ||||
|         self.assertEqual(d.a, 1) | ||||
|         self.assertRaises(AttributeError, lambda: d.b) | ||||
|      | ||||
|  | ||||
|     def test_setattr_raises_on_nonexisting_attr(self): | ||||
|         d = self.dtype() | ||||
|  | ||||
|         def _f(): | ||||
|             d.x = 1 | ||||
|         self.assertRaises(AttributeError, _f) | ||||
|      | ||||
|  | ||||
|     def test_setattr_getattr_special(self): | ||||
|         d = self.strict_dict_class(["items"]) | ||||
|         d.items = 1 | ||||
|         self.assertEqual(d.items, 1) | ||||
|      | ||||
|  | ||||
|     def test_get(self): | ||||
|         d = self.dtype(a=1) | ||||
|         self.assertEqual(d.get('a'), 1) | ||||
| @@ -88,7 +97,7 @@ class TestSemiSrictDict(TestStrictDict): | ||||
|     def test_init_succeeds_with_nonexisting_attrs(self): | ||||
|         d = self.dtype(a=1, b=1, c=1, x=2) | ||||
|         self.assertEqual((d.a, d.b, d.c, d.x), (1, 1, 1, 2)) | ||||
|     | ||||
|  | ||||
|     def test_iter_with_nonexisting_attrs(self): | ||||
|         d = self.dtype(a=1, b=1, c=1, x=2) | ||||
|         self.assertEqual(list(d), ['a', 'b', 'c', 'x']) | ||||
|   | ||||
| @@ -12,9 +12,13 @@ from mongoengine.context_managers import query_counter | ||||
|  | ||||
| class FieldTest(unittest.TestCase): | ||||
|  | ||||
|     def setUp(self): | ||||
|         connect(db='mongoenginetest') | ||||
|         self.db = get_db() | ||||
|     @classmethod | ||||
|     def setUpClass(cls): | ||||
|         cls.db = connect(db='mongoenginetest') | ||||
|  | ||||
|     @classmethod | ||||
|     def tearDownClass(cls): | ||||
|         cls.db.drop_database('mongoenginetest') | ||||
|  | ||||
|     def test_list_item_dereference(self): | ||||
|         """Ensure that DBRef items in ListFields are dereferenced. | ||||
| @@ -304,6 +308,7 @@ class FieldTest(unittest.TestCase): | ||||
|  | ||||
|         User.drop_collection() | ||||
|         Post.drop_collection() | ||||
|         SimpleList.drop_collection() | ||||
|  | ||||
|         u1 = User.objects.create(name='u1') | ||||
|         u2 = User.objects.create(name='u2') | ||||
|   | ||||
| @@ -25,6 +25,8 @@ class SignalTests(unittest.TestCase): | ||||
|         connect(db='mongoenginetest') | ||||
|  | ||||
|         class Author(Document): | ||||
|             # Make the id deterministic for easier testing | ||||
|             id = SequenceField(primary_key=True) | ||||
|             name = StringField() | ||||
|  | ||||
|             def __unicode__(self): | ||||
| @@ -33,7 +35,7 @@ class SignalTests(unittest.TestCase): | ||||
|             @classmethod | ||||
|             def pre_init(cls, sender, document, *args, **kwargs): | ||||
|                 signal_output.append('pre_init signal, %s' % cls.__name__) | ||||
|                 signal_output.append(str(kwargs['values'])) | ||||
|                 signal_output.append(kwargs['values']) | ||||
|  | ||||
|             @classmethod | ||||
|             def post_init(cls, sender, document, **kwargs): | ||||
| @@ -43,48 +45,55 @@ class SignalTests(unittest.TestCase): | ||||
|             @classmethod | ||||
|             def pre_save(cls, sender, document, **kwargs): | ||||
|                 signal_output.append('pre_save signal, %s' % document) | ||||
|                 signal_output.append(kwargs) | ||||
|  | ||||
|             @classmethod | ||||
|             def pre_save_post_validation(cls, sender, document, **kwargs): | ||||
|                 signal_output.append('pre_save_post_validation signal, %s' % document) | ||||
|                 if 'created' in kwargs: | ||||
|                     if kwargs['created']: | ||||
|                         signal_output.append('Is created') | ||||
|                     else: | ||||
|                         signal_output.append('Is updated') | ||||
|                 if kwargs.pop('created', False): | ||||
|                     signal_output.append('Is created') | ||||
|                 else: | ||||
|                     signal_output.append('Is updated') | ||||
|                 signal_output.append(kwargs) | ||||
|  | ||||
|             @classmethod | ||||
|             def post_save(cls, sender, document, **kwargs): | ||||
|                 dirty_keys = document._delta()[0].keys() + document._delta()[1].keys() | ||||
|                 signal_output.append('post_save signal, %s' % document) | ||||
|                 signal_output.append('post_save dirty keys, %s' % dirty_keys) | ||||
|                 if 'created' in kwargs: | ||||
|                     if kwargs['created']: | ||||
|                         signal_output.append('Is created') | ||||
|                     else: | ||||
|                         signal_output.append('Is updated') | ||||
|                 if kwargs.pop('created', False): | ||||
|                     signal_output.append('Is created') | ||||
|                 else: | ||||
|                     signal_output.append('Is updated') | ||||
|                 signal_output.append(kwargs) | ||||
|  | ||||
|             @classmethod | ||||
|             def pre_delete(cls, sender, document, **kwargs): | ||||
|                 signal_output.append('pre_delete signal, %s' % document) | ||||
|                 signal_output.append(kwargs) | ||||
|  | ||||
|             @classmethod | ||||
|             def post_delete(cls, sender, document, **kwargs): | ||||
|                 signal_output.append('post_delete signal, %s' % document) | ||||
|                 signal_output.append(kwargs) | ||||
|  | ||||
|             @classmethod | ||||
|             def pre_bulk_insert(cls, sender, documents, **kwargs): | ||||
|                 signal_output.append('pre_bulk_insert signal, %s' % documents) | ||||
|                 signal_output.append(kwargs) | ||||
|  | ||||
|             @classmethod | ||||
|             def post_bulk_insert(cls, sender, documents, **kwargs): | ||||
|                 signal_output.append('post_bulk_insert signal, %s' % documents) | ||||
|                 if kwargs.get('loaded', False): | ||||
|                 if kwargs.pop('loaded', False): | ||||
|                     signal_output.append('Is loaded') | ||||
|                 else: | ||||
|                     signal_output.append('Not loaded') | ||||
|                 signal_output.append(kwargs) | ||||
|  | ||||
|         self.Author = Author | ||||
|         Author.drop_collection() | ||||
|         Author.id.set_next_value(0) | ||||
|  | ||||
|         class Another(Document): | ||||
|  | ||||
| @@ -96,10 +105,12 @@ class SignalTests(unittest.TestCase): | ||||
|             @classmethod | ||||
|             def pre_delete(cls, sender, document, **kwargs): | ||||
|                 signal_output.append('pre_delete signal, %s' % document) | ||||
|                 signal_output.append(kwargs) | ||||
|  | ||||
|             @classmethod | ||||
|             def post_delete(cls, sender, document, **kwargs): | ||||
|                 signal_output.append('post_delete signal, %s' % document) | ||||
|                 signal_output.append(kwargs) | ||||
|  | ||||
|         self.Another = Another | ||||
|         Another.drop_collection() | ||||
| @@ -118,6 +129,41 @@ class SignalTests(unittest.TestCase): | ||||
|         self.ExplicitId = ExplicitId | ||||
|         ExplicitId.drop_collection() | ||||
|  | ||||
|         class Post(Document): | ||||
|             title = StringField() | ||||
|             content = StringField() | ||||
|             active = BooleanField(default=False) | ||||
|  | ||||
|             def __unicode__(self): | ||||
|                 return self.title | ||||
|  | ||||
|             @classmethod | ||||
|             def pre_bulk_insert(cls, sender, documents, **kwargs): | ||||
|                 signal_output.append('pre_bulk_insert signal, %s' % | ||||
|                                      [(doc, {'active': documents[n].active}) | ||||
|                                       for n, doc in enumerate(documents)]) | ||||
|  | ||||
|                 # make changes here, this is just an example - | ||||
|                 # it could be anything that needs pre-validation or looks-ups before bulk bulk inserting | ||||
|                 for document in documents: | ||||
|                     if not document.active: | ||||
|                         document.active = True | ||||
|                 signal_output.append(kwargs) | ||||
|  | ||||
|             @classmethod | ||||
|             def post_bulk_insert(cls, sender, documents, **kwargs): | ||||
|                 signal_output.append('post_bulk_insert signal, %s' % | ||||
|                                      [(doc, {'active': documents[n].active}) | ||||
|                                       for n, doc in enumerate(documents)]) | ||||
|                 if kwargs.pop('loaded', False): | ||||
|                     signal_output.append('Is loaded') | ||||
|                 else: | ||||
|                     signal_output.append('Not loaded') | ||||
|                 signal_output.append(kwargs) | ||||
|  | ||||
|         self.Post = Post | ||||
|         Post.drop_collection() | ||||
|  | ||||
|         # Save up the number of connected signals so that we can check at the | ||||
|         # end that all the signals we register get properly unregistered | ||||
|         self.pre_signals = ( | ||||
| @@ -147,6 +193,9 @@ class SignalTests(unittest.TestCase): | ||||
|  | ||||
|         signals.post_save.connect(ExplicitId.post_save, sender=ExplicitId) | ||||
|  | ||||
|         signals.pre_bulk_insert.connect(Post.pre_bulk_insert, sender=Post) | ||||
|         signals.post_bulk_insert.connect(Post.post_bulk_insert, sender=Post) | ||||
|  | ||||
|     def tearDown(self): | ||||
|         signals.pre_init.disconnect(self.Author.pre_init) | ||||
|         signals.post_init.disconnect(self.Author.post_init) | ||||
| @@ -163,6 +212,9 @@ class SignalTests(unittest.TestCase): | ||||
|  | ||||
|         signals.post_save.disconnect(self.ExplicitId.post_save) | ||||
|  | ||||
|         signals.pre_bulk_insert.disconnect(self.Post.pre_bulk_insert) | ||||
|         signals.post_bulk_insert.disconnect(self.Post.post_bulk_insert) | ||||
|  | ||||
|         # Check that all our signals got disconnected properly. | ||||
|         post_signals = ( | ||||
|             len(signals.pre_init.receivers), | ||||
| @@ -199,66 +251,121 @@ class SignalTests(unittest.TestCase): | ||||
|             a.save() | ||||
|             self.get_signal_output(lambda: None) # eliminate signal output | ||||
|             a1 = self.Author.objects(name='Bill Shakespeare')[0] | ||||
|          | ||||
|  | ||||
|         self.assertEqual(self.get_signal_output(create_author), [ | ||||
|             "pre_init signal, Author", | ||||
|             "{'name': 'Bill Shakespeare'}", | ||||
|             {'name': 'Bill Shakespeare'}, | ||||
|             "post_init signal, Bill Shakespeare, document._created = True", | ||||
|         ]) | ||||
|  | ||||
|         a1 = self.Author(name='Bill Shakespeare') | ||||
|         self.assertEqual(self.get_signal_output(a1.save), [ | ||||
|             "pre_save signal, Bill Shakespeare", | ||||
|             {}, | ||||
|             "pre_save_post_validation signal, Bill Shakespeare", | ||||
|             "Is created", | ||||
|             {}, | ||||
|             "post_save signal, Bill Shakespeare", | ||||
|             "post_save dirty keys, ['name']", | ||||
|             "Is created" | ||||
|             "Is created", | ||||
|             {} | ||||
|         ]) | ||||
|  | ||||
|         a1.reload() | ||||
|         a1.name = 'William Shakespeare' | ||||
|         self.assertEqual(self.get_signal_output(a1.save), [ | ||||
|             "pre_save signal, William Shakespeare", | ||||
|             {}, | ||||
|             "pre_save_post_validation signal, William Shakespeare", | ||||
|             "Is updated", | ||||
|             {}, | ||||
|             "post_save signal, William Shakespeare", | ||||
|             "post_save dirty keys, ['name']", | ||||
|             "Is updated" | ||||
|             "Is updated", | ||||
|             {} | ||||
|         ]) | ||||
|  | ||||
|         self.assertEqual(self.get_signal_output(a1.delete), [ | ||||
|             'pre_delete signal, William Shakespeare', | ||||
|             {}, | ||||
|             'post_delete signal, William Shakespeare', | ||||
|             {} | ||||
|         ]) | ||||
|  | ||||
|         signal_output = self.get_signal_output(load_existing_author) | ||||
|         # test signal_output lines separately, because of random ObjectID after object load | ||||
|         self.assertEqual(signal_output[0], | ||||
|         self.assertEqual(self.get_signal_output(load_existing_author), [ | ||||
|             "pre_init signal, Author", | ||||
|         ) | ||||
|         self.assertEqual(signal_output[2], | ||||
|             "post_init signal, Bill Shakespeare, document._created = False", | ||||
|         ) | ||||
|             {'id': 2, 'name': 'Bill Shakespeare'}, | ||||
|             "post_init signal, Bill Shakespeare, document._created = False" | ||||
|         ]) | ||||
|  | ||||
|  | ||||
|         signal_output = self.get_signal_output(bulk_create_author_with_load) | ||||
|  | ||||
|         # The output of this signal is not entirely deterministic. The reloaded | ||||
|         # object will have an object ID. Hence, we only check part of the output | ||||
|         self.assertEqual(signal_output[3], "pre_bulk_insert signal, [<Author: Bill Shakespeare>]" | ||||
|         ) | ||||
|         self.assertEqual(signal_output[-2:], | ||||
|             ["post_bulk_insert signal, [<Author: Bill Shakespeare>]", | ||||
|              "Is loaded",]) | ||||
|         self.assertEqual(self.get_signal_output(bulk_create_author_with_load), [ | ||||
|             'pre_init signal, Author', | ||||
|             {'name': 'Bill Shakespeare'}, | ||||
|             'post_init signal, Bill Shakespeare, document._created = True', | ||||
|             'pre_bulk_insert signal, [<Author: Bill Shakespeare>]', | ||||
|             {}, | ||||
|             'pre_init signal, Author', | ||||
|             {'id': 3, 'name': 'Bill Shakespeare'}, | ||||
|             'post_init signal, Bill Shakespeare, document._created = False', | ||||
|             'post_bulk_insert signal, [<Author: Bill Shakespeare>]', | ||||
|             'Is loaded', | ||||
|             {} | ||||
|         ]) | ||||
|  | ||||
|         self.assertEqual(self.get_signal_output(bulk_create_author_without_load), [ | ||||
|             "pre_init signal, Author", | ||||
|             "{'name': 'Bill Shakespeare'}", | ||||
|             {'name': 'Bill Shakespeare'}, | ||||
|             "post_init signal, Bill Shakespeare, document._created = True", | ||||
|             "pre_bulk_insert signal, [<Author: Bill Shakespeare>]", | ||||
|             {}, | ||||
|             "post_bulk_insert signal, [<Author: Bill Shakespeare>]", | ||||
|             "Not loaded", | ||||
|             {} | ||||
|         ]) | ||||
|  | ||||
|     def test_signal_kwargs(self): | ||||
|         """ Make sure signal_kwargs is passed to signals calls. """ | ||||
|  | ||||
|         def live_and_let_die(): | ||||
|             a = self.Author(name='Bill Shakespeare') | ||||
|             a.save(signal_kwargs={'live': True, 'die': False}) | ||||
|             a.delete(signal_kwargs={'live': False, 'die': True}) | ||||
|  | ||||
|         self.assertEqual(self.get_signal_output(live_and_let_die), [ | ||||
|             "pre_init signal, Author", | ||||
|             {'name': 'Bill Shakespeare'}, | ||||
|             "post_init signal, Bill Shakespeare, document._created = True", | ||||
|             "pre_save signal, Bill Shakespeare", | ||||
|             {'die': False, 'live': True}, | ||||
|             "pre_save_post_validation signal, Bill Shakespeare", | ||||
|             "Is created", | ||||
|             {'die': False, 'live': True}, | ||||
|             "post_save signal, Bill Shakespeare", | ||||
|             "post_save dirty keys, ['name']", | ||||
|             "Is created", | ||||
|             {'die': False, 'live': True}, | ||||
|             'pre_delete signal, Bill Shakespeare', | ||||
|             {'die': True, 'live': False}, | ||||
|             'post_delete signal, Bill Shakespeare', | ||||
|             {'die': True, 'live': False} | ||||
|         ]) | ||||
|  | ||||
|         def bulk_create_author(): | ||||
|             a1 = self.Author(name='Bill Shakespeare') | ||||
|             self.Author.objects.insert([a1], signal_kwargs={'key': True}) | ||||
|  | ||||
|         self.assertEqual(self.get_signal_output(bulk_create_author), [ | ||||
|             'pre_init signal, Author', | ||||
|             {'name': 'Bill Shakespeare'}, | ||||
|             'post_init signal, Bill Shakespeare, document._created = True', | ||||
|             'pre_bulk_insert signal, [<Author: Bill Shakespeare>]', | ||||
|             {'key': True}, | ||||
|             'pre_init signal, Author', | ||||
|             {'id': 2, 'name': 'Bill Shakespeare'}, | ||||
|             'post_init signal, Bill Shakespeare, document._created = False', | ||||
|             'post_bulk_insert signal, [<Author: Bill Shakespeare>]', | ||||
|             'Is loaded', | ||||
|             {'key': True} | ||||
|         ]) | ||||
|  | ||||
|     def test_queryset_delete_signals(self): | ||||
| @@ -267,7 +374,9 @@ class SignalTests(unittest.TestCase): | ||||
|         self.Another(name='Bill Shakespeare').save() | ||||
|         self.assertEqual(self.get_signal_output(self.Another.objects.delete), [ | ||||
|             'pre_delete signal, Bill Shakespeare', | ||||
|             {}, | ||||
|             'post_delete signal, Bill Shakespeare', | ||||
|             {} | ||||
|         ]) | ||||
|  | ||||
|     def test_signals_with_explicit_doc_ids(self): | ||||
| @@ -306,6 +415,23 @@ class SignalTests(unittest.TestCase): | ||||
|         ei.switch_db("testdb-1", keep_created=False) | ||||
|         self.assertEqual(self.get_signal_output(ei.save), ['Is created']) | ||||
|  | ||||
|     def test_signals_bulk_insert(self): | ||||
|         def bulk_set_active_post(): | ||||
|             posts = [ | ||||
|                 self.Post(title='Post 1'), | ||||
|                 self.Post(title='Post 2'), | ||||
|                 self.Post(title='Post 3') | ||||
|             ] | ||||
|             self.Post.objects.insert(posts) | ||||
|  | ||||
|         results = self.get_signal_output(bulk_set_active_post) | ||||
|         self.assertEqual(results, [ | ||||
|             "pre_bulk_insert signal, [(<Post: Post 1>, {'active': False}), (<Post: Post 2>, {'active': False}), (<Post: Post 3>, {'active': False})]", | ||||
|             {}, | ||||
|             "post_bulk_insert signal, [(<Post: Post 1>, {'active': True}), (<Post: Post 2>, {'active': True}), (<Post: Post 3>, {'active': True})]", | ||||
|             'Is loaded', | ||||
|             {} | ||||
|         ]) | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|     unittest.main() | ||||
|   | ||||
							
								
								
									
										11
									
								
								tox.ini
									
									
									
									
									
								
							
							
						
						
									
										11
									
								
								tox.ini
									
									
									
									
									
								
							| @@ -1,13 +1,11 @@ | ||||
| [tox] | ||||
| envlist = {py26,py27,py32,py33,py34,py35,pypy,pypy3}-{mg27,mg28} | ||||
| #envlist = {py26,py27,py32,py33,py34,pypy,pypy3}-{mg27,mg28,mg30,mgdev} | ||||
| envlist = {py26,py27,py33,py34,py35,pypy,pypy3}-{mg27,mg28},flake8 | ||||
|  | ||||
| [testenv] | ||||
| commands = | ||||
|     python setup.py nosetests {posargs} | ||||
| deps = | ||||
|     nose | ||||
|     rednose | ||||
|     mg27: PyMongo<2.8 | ||||
|     mg28: PyMongo>=2.8,<3.0 | ||||
|     mg30: PyMongo>=3.0 | ||||
| @@ -15,3 +13,10 @@ deps = | ||||
| setenv = | ||||
|     PYTHON_EGG_CACHE = {envdir}/python-eggs | ||||
| passenv = windir | ||||
|  | ||||
| [testenv:flake8] | ||||
| deps = | ||||
|     flake8 | ||||
|     flake8-import-order | ||||
| commands = | ||||
|    flake8 | ||||
|   | ||||
		Reference in New Issue
	
	Block a user