Compare commits: cleanup-fi...batch-size (3 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | 34ba527e6d |  |
|  | ea9027755f |  |
|  | 43668a93a2 |  |
.gitignore | 3
@@ -15,6 +15,3 @@ env/
.pydevproject
tests/test_bugfix.py
htmlcov/
venv
venv3
scratchpad

@@ -1,23 +0,0 @@
#!/bin/bash

sudo apt-key adv --keyserver hkp://keyserver.ubuntu.com:80 --recv 7F0CEB10

if [ "$MONGODB" = "2.4" ]; then
    echo "deb http://downloads-distro.mongodb.org/repo/ubuntu-upstart dist 10gen" | sudo tee /etc/apt/sources.list.d/mongodb.list
    sudo apt-get update
    sudo apt-get install mongodb-10gen=2.4.14
    sudo service mongodb start
elif [ "$MONGODB" = "2.6" ]; then
    echo "deb http://downloads-distro.mongodb.org/repo/ubuntu-upstart dist 10gen" | sudo tee /etc/apt/sources.list.d/mongodb.list
    sudo apt-get update
    sudo apt-get install mongodb-org-server=2.6.12
    # service should be started automatically
elif [ "$MONGODB" = "3.0" ]; then
    echo "deb http://repo.mongodb.org/apt/ubuntu precise/mongodb-org/3.0 multiverse" | sudo tee /etc/apt/sources.list.d/mongodb.list
    sudo apt-get update
    sudo apt-get install mongodb-org-server=3.0.14
    # service should be started automatically
else
    echo "Invalid MongoDB version, expected 2.4, 2.6, or 3.0."
    exit 1
fi;
@@ -1,22 +0,0 @@
pylint:
    disable:
        # We use this a lot (e.g. via document._meta)
        - protected-access

    options:
        additional-builtins:
            # add xrange and long as valid built-ins. In Python 3, xrange is
            # translated into range and long is translated into int via 2to3 (see
            # "use_2to3" in setup.py). This should be removed when we drop Python
            # 2 support (which probably won't happen any time soon).
            - xrange
            - long

pyflakes:
    disable:
        # undefined variables are already covered by pylint (and exclude
        # xrange & long)
        - F821

ignore-paths:
    - benchmark.py
.travis.yml | 72
@@ -1,48 +1,29 @@
# For full coverage, we'd have to test all supported Python, MongoDB, and
# PyMongo combinations. However, that would result in an overly long build
# with a very large number of jobs, hence we only test a subset of all the
# combinations:
# * MongoDB v2.4 & v3.0 are only tested against Python v2.7 & v3.5.
# * MongoDB v2.4 is tested against PyMongo v2.7 & v3.x.
# * MongoDB v3.0 is tested against PyMongo v3.x.
# * MongoDB v2.6 is currently the "main" version tested against Python v2.7,
#   v3.5, PyPy & PyPy3, and PyMongo v2.7, v2.8 & v3.x.
#
# Reminder: Update README.rst if you change MongoDB versions we test.

language: python

python:
- 2.7
- 3.5
- '2.6'
- '2.7'
- '3.3'
- '3.4'
- '3.5'
- pypy
- pypy3

env:
- MONGODB=2.6 PYMONGO=2.7
- MONGODB=2.6 PYMONGO=2.8
- MONGODB=2.6 PYMONGO=3.0
- PYMONGO=2.7
- PYMONGO=2.8
- PYMONGO=3.0
- PYMONGO=dev

matrix:
  # Finish the build as soon as one job fails
  fast_finish: true

  include:
  - python: 2.7
    env: MONGODB=2.4 PYMONGO=2.7
  - python: 2.7
    env: MONGODB=2.4 PYMONGO=3.0
  - python: 2.7
    env: MONGODB=3.0 PYMONGO=3.0
  - python: 3.5
    env: MONGODB=2.4 PYMONGO=2.7
  - python: 3.5
    env: MONGODB=2.4 PYMONGO=3.0
  - python: 3.5
    env: MONGODB=3.0 PYMONGO=3.0

before_install:
- bash .install_mongodb_on_travis.sh
- travis_retry sudo apt-key adv --keyserver hkp://keyserver.ubuntu.com:80 --recv 7F0CEB10
- echo 'deb http://downloads-distro.mongodb.org/repo/ubuntu-upstart dist 10gen' |
  sudo tee /etc/apt/sources.list.d/mongodb.list
- travis_retry sudo apt-get update
- travis_retry sudo apt-get install mongodb-org-server

install:
- sudo apt-get install python-dev python3-dev libopenjpeg-dev zlib1g-dev libjpeg-turbo8-dev
@@ -50,52 +31,33 @@ install:
  python-tk
- travis_retry pip install --upgrade pip
- travis_retry pip install coveralls
- travis_retry pip install flake8 flake8-import-order
- travis_retry pip install flake8
- travis_retry pip install tox>=1.9
- travis_retry pip install "virtualenv<14.0.0"  # virtualenv>=14.0.0 has dropped Python 3.2 support (and pypy3 is based on py32)
- travis_retry tox -e $(echo py$TRAVIS_PYTHON_VERSION-mg$PYMONGO | tr -d . | sed -e 's/pypypy/pypy/') -- -e test

# Cache dependencies installed via pip
cache: pip

# Run flake8 for py27
before_script:
- if [[ $TRAVIS_PYTHON_VERSION == '2.7' ]]; then flake8 .; else echo "flake8 only runs on py27"; fi
- if [[ $TRAVIS_PYTHON_VERSION == '2.7' ]]; then tox -e flake8; fi

script:
- tox -e $(echo py$TRAVIS_PYTHON_VERSION-mg$PYMONGO | tr -d . | sed -e 's/pypypy/pypy/') -- --with-coverage

# For now only submit coveralls for Python v2.7. Python v3.x currently shows
# 0% coverage. That's caused by 'use_2to3', which builds the py3-compatible
# code in a separate dir and runs tests on that.
after_success:
- if [[ $TRAVIS_PYTHON_VERSION == '2.7' ]]; then coveralls --verbose; fi
after_script: coveralls --verbose

notifications:
  irc: irc.freenode.org#mongoengine

# Only run builds on the master branch and GitHub releases (tagged as vX.Y.Z)
branches:
  only:
  - master
  - /^v.*$/

# Whenever a new release is created via GitHub, publish it on PyPI.
deploy:
  provider: pypi
  user: the_drow
  password:
    secure: QMyatmWBnC6ZN3XLW2+fTBDU4LQcp1m/LjR2/0uamyeUzWKdlOoh/Wx5elOgLwt/8N9ppdPeG83ose1jOz69l5G0MUMjv8n/RIcMFSpCT59tGYqn3kh55b0cIZXFT9ar+5cxlif6a5rS72IHm5li7QQyxexJIII6Uxp0kpvUmek=

  # create a source distribution and a pure python wheel for faster installs
  distributions: "sdist bdist_wheel"

  # only deploy on tagged commits (aka GitHub releases) and only for the
  # parent repo's builds running Python 2.7 along with dev PyMongo (we run
  # Travis against many different Python and PyMongo versions and we don't
  # want the deploy to occur multiple times).
  on:
    tags: true
    repo: MongoEngine/mongoengine
    condition: "$PYMONGO = 3.0"
    python: 2.7

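The ``script`` step above derives the tox environment name from the build
matrix variables with a shell pipeline (``tr -d .`` plus a ``pypypy`` to
``pypy`` fixup). A minimal Python sketch of the same mapping, with example
values that are ours rather than from the config::

    def tox_env(python_version, pymongo_version):
        # Mirror: echo py$TRAVIS_PYTHON_VERSION-mg$PYMONGO | tr -d . | sed -e 's/pypypy/pypy/'
        name = 'py{}-mg{}'.format(python_version, pymongo_version)
        return name.replace('.', '').replace('pypypy', 'pypy')

    print(tox_env('2.7', '3.0'))   # py27-mg30
    print(tox_env('pypy', 'dev'))  # pypy-mgdev ('pypypy' collapses to 'pypy')
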
AUTHORS | 1
@@ -242,4 +242,3 @@ that much better:
 * xiaost7 (https://github.com/xiaost7)
 * Victor Varvaryuk
 * Stanislav Kaledin (https://github.com/sallyruthstruik)
 * Dmitry Yantsen (https://github.com/mrTable)

@@ -14,13 +14,13 @@ Before starting to write code, look for existing `tickets
<https://github.com/MongoEngine/mongoengine/issues?state=open>`_ or `create one
<https://github.com/MongoEngine/mongoengine/issues>`_ for your specific
issue or feature request. That way you avoid working on something
that might not be of interest or that has already been addressed. If in doubt
that might not be of interest or that has already been addressed.  If in doubt
post to the `user group <http://groups.google.com/group/mongoengine-users>`

Supported Interpreters
----------------------

MongoEngine supports CPython 2.7 and newer. Language
MongoEngine supports CPython 2.6 and newer. Language
features not supported by all interpreters can not be used.
Please also ensure that your code is properly converted by
`2to3 <http://docs.python.org/library/2to3.html>`_ for Python 3 support.
@@ -29,20 +29,19 @@ Style Guide
-----------

MongoEngine aims to follow `PEP8 <http://www.python.org/dev/peps/pep-0008/>`_
including 4 space indents. When possible we try to stick to 79 character line
limits. However, screens got bigger and an ORM has a strong focus on
readability and if it can help, we accept 119 as maximum line length, in a
similar way as `django does
<https://docs.djangoproject.com/en/dev/internals/contributing/writing-code/coding-style/#python-style>`_
including 4 space indents. When possible we try to stick to 79 character line limits.
However, screens got bigger and an ORM has a strong focus on readability and
if it can help, we accept 119 as maximum line length, in a similar way as
`django does <https://docs.djangoproject.com/en/dev/internals/contributing/writing-code/coding-style/#python-style>`_

Testing
-------

All tests are run on `Travis <http://travis-ci.org/MongoEngine/mongoengine>`_
and any pull requests are automatically tested. Any pull requests without
tests will take longer to be integrated and might be refused.
and any pull requests are automatically tested by Travis. Any pull requests
without tests will take longer to be integrated and might be refused.

You may also submit a simple failing test as a pull request if you don't know
You may also submit a simple failing test as a PullRequest if you don't know
how to fix it, it will be easier for other people to work on it and it may get
fixed faster.

@@ -50,18 +49,13 @@ General Guidelines
------------------

- Avoid backward breaking changes if at all possible.
- If you *have* to introduce a breaking change, make it very clear in your
  pull request's description. Also, describe how users of this package
  should adapt to the breaking change in docs/upgrade.rst.
- Write inline documentation for new classes and methods.
- Write tests and make sure they pass (make sure you have a mongod
  running on the default port, then execute ``python setup.py nosetests``
  from the cmd line to run the test suite).
- Ensure tests pass on all supported Python, PyMongo, and MongoDB versions.
  You can test various Python and PyMongo versions locally by executing
  ``tox``. For different MongoDB versions, you can rely on our automated
  Travis tests.
- Add enhancements or problematic bug fixes to docs/changelog.rst.
- Ensure tests pass on every Python and PyMongo versions.
  You can test on these versions locally by executing ``tox``
- Add enhancements or problematic bug fixes to docs/changelog.rst
- Add yourself to AUTHORS :)

Documentation
@@ -75,6 +69,3 @@ just make your changes to the inline documentation of the appropriate
branch and submit a `pull request <https://help.github.com/articles/using-pull-requests>`_.
You might also use the github `Edit <https://github.com/blog/844-forking-with-the-edit-button>`_
button.

If you want to test your documentation changes locally, you need to install
the ``sphinx`` package.

README.rst | 73
@@ -4,7 +4,7 @@ MongoEngine
:Info: MongoEngine is an ORM-like layer on top of PyMongo.
:Repository: https://github.com/MongoEngine/mongoengine
:Author: Harry Marr (http://github.com/hmarr)
:Maintainer: Stefan Wójcik (http://github.com/wojcikstefan)
:Maintainer: Ross Lawley (http://github.com/rozza)

.. image:: https://travis-ci.org/MongoEngine/mongoengine.svg?branch=master
  :target: https://travis-ci.org/MongoEngine/mongoengine
@@ -19,42 +19,32 @@ MongoEngine
About
=====
MongoEngine is a Python Object-Document Mapper for working with MongoDB.
Documentation is available at https://mongoengine-odm.readthedocs.io - there
is currently a `tutorial <https://mongoengine-odm.readthedocs.io/tutorial.html>`_,
a `user guide <https://mongoengine-odm.readthedocs.io/guide/index.html>`_, and
an `API reference <https://mongoengine-odm.readthedocs.io/apireference.html>`_.

Supported MongoDB Versions
==========================
MongoEngine is currently tested against MongoDB v2.4, v2.6, and v3.0. Future
versions should be supported as well, but aren't actively tested at the moment.
Make sure to open an issue or submit a pull request if you experience any
problems with MongoDB v3.2+.
Documentation available at https://mongoengine-odm.readthedocs.io - there is currently
a `tutorial <https://mongoengine-odm.readthedocs.io/tutorial.html>`_, a `user guide
<https://mongoengine-odm.readthedocs.io/guide/index.html>`_ and an `API reference
<https://mongoengine-odm.readthedocs.io/apireference.html>`_.

Installation
============
We recommend the use of `virtualenv <https://virtualenv.pypa.io/>`_ and of
`pip <https://pip.pypa.io/>`_. You can then use ``pip install -U mongoengine``.
You may also have `setuptools <http://peak.telecommunity.com/DevCenter/setuptools>`_
and thus you can use ``easy_install -U mongoengine``. Otherwise, you can download the
You may also have `setuptools <http://peak.telecommunity.com/DevCenter/setuptools>`_ and thus
you can use ``easy_install -U mongoengine``. Otherwise, you can download the
source from `GitHub <http://github.com/MongoEngine/mongoengine>`_ and run ``python
setup.py install``.

Dependencies
============
All of the dependencies can easily be installed via `pip <https://pip.pypa.io/>`_.
At the very least, you'll need these two packages to use MongoEngine:

- pymongo>=2.7.1
- six>=1.10.0

If you utilize a ``DateTimeField``, you might also use a more flexible date parser:
- sphinx (optional - for documentation generation)

Optional Dependencies
---------------------
- **Image Fields**: Pillow>=2.0.0
- dateutil>=2.1.0

If you need to use an ``ImageField`` or ``ImageGridFsProxy``:

- Pillow>=2.0.0
.. note
   MongoEngine always runs it's test suite against the latest patch version of each dependecy. e.g.: PyMongo 3.0.1

Examples
========
@@ -67,7 +57,7 @@ Some simple examples of what MongoEngine code looks like:

    class BlogPost(Document):
        title = StringField(required=True, max_length=200)
        posted = DateTimeField(default=datetime.datetime.utcnow)
        posted = DateTimeField(default=datetime.datetime.now)
        tags = ListField(StringField(max_length=50))
        meta = {'allow_inheritance': True}

@@ -97,28 +87,27 @@ Some simple examples of what MongoEngine code looks like:
    ...     print
    ...

    # Count all blog posts and its subtypes
    >>> BlogPost.objects.count()
    >>> len(BlogPost.objects)
    2
    >>> TextPost.objects.count()
    >>> len(TextPost.objects)
    1
    >>> LinkPost.objects.count()
    >>> len(LinkPost.objects)
    1

    # Count tagged posts
    >>> BlogPost.objects(tags='mongoengine').count()
    # Find tagged posts
    >>> len(BlogPost.objects(tags='mongoengine'))
    2
    >>> BlogPost.objects(tags='mongodb').count()
    >>> len(BlogPost.objects(tags='mongodb'))
    1

Tests
=====
To run the test suite, ensure you are running a local instance of MongoDB on
the standard port and have ``nose`` installed. Then, run ``python setup.py nosetests``.
the standard port and have ``nose`` installed. Then, run: ``python setup.py nosetests``.

To run the test suite on every supported Python and PyMongo version, you can
use ``tox``. You'll need to make sure you have each supported Python version
installed in your environment and then:
To run the test suite on every supported Python version and every supported PyMongo version,
you can use ``tox``.
tox and each supported Python version should be installed in your environment:

.. code-block:: shell

@@ -127,16 +116,13 @@ installed in your environment and then:
    # Run the test suites
    $ tox

If you wish to run a subset of tests, use the nosetests convention:
If you wish to run one single or selected tests, use the nosetest convention. It will find the folder,
eventually the file, go to the TestClass specified after the colon and eventually right to the single test.
Also use the -s argument if you want to print out whatever or access pdb while testing.

.. code-block:: shell

    # Run all the tests in a particular test file
    $ python setup.py nosetests --tests tests/fields/fields.py
    # Run only particular test class in that file
    $ python setup.py nosetests --tests tests/fields/fields.py:FieldTest
    # Use the -s option if you want to print some debug statements or use pdb
    $ python setup.py nosetests --tests tests/fields/fields.py:FieldTest -s
    $ python setup.py nosetests --tests tests/fields/fields.py:FieldTest.test_cls_field -s

Community
=========
@@ -144,7 +130,8 @@ Community
  <http://groups.google.com/group/mongoengine-users>`_
- `MongoEngine Developers mailing list
  <http://groups.google.com/group/mongoengine-dev>`_
- `#mongoengine IRC channel <http://webchat.freenode.net/?channels=mongoengine>`_

Contributing
============
We welcome contributions! See the `Contribution guidelines <https://github.com/MongoEngine/mongoengine/blob/master/CONTRIBUTING.rst>`_
We welcome contributions! see  the `Contribution guidelines <https://github.com/MongoEngine/mongoengine/blob/master/CONTRIBUTING.rst>`_

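The counting hunk above is the core of this README change: one side of the
diff counts with ``BlogPost.objects.count()``, the other with
``len(BlogPost.objects)``. A minimal runnable sketch of both spellings,
assuming a local mongod on the default port and reusing the README's
``BlogPost`` model::

    from mongoengine import Document, ListField, StringField, connect

    connect('tumblelog')  # assumes a mongod on the default port

    class BlogPost(Document):
        title = StringField(required=True, max_length=200)
        tags = ListField(StringField(max_length=50))
        meta = {'allow_inheritance': True}

    BlogPost(title='Fun with MongoEngine', tags=['mongoengine']).save()

    # Both spellings from the hunk return the same number.
    assert BlogPost.objects(tags='mongoengine').count() == \
        len(BlogPost.objects(tags='mongoengine'))
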
benchmark.py | 152
@@ -1,41 +1,118 @@
#!/usr/bin/env python

"""
Simple benchmark comparing PyMongo and MongoEngine.

Sample run on a mid 2015 MacBook Pro (commit b282511):

Benchmarking...
----------------------------------------------------------------------------------------------------
Creating 10000 dictionaries - Pymongo
2.58979988098
----------------------------------------------------------------------------------------------------
Creating 10000 dictionaries - Pymongo write_concern={"w": 0}
1.26657605171
----------------------------------------------------------------------------------------------------
Creating 10000 dictionaries - MongoEngine
8.4351580143
----------------------------------------------------------------------------------------------------
Creating 10000 dictionaries without continual assign - MongoEngine
7.20191693306
----------------------------------------------------------------------------------------------------
Creating 10000 dictionaries - MongoEngine - write_concern={"w": 0}, cascade = True
6.31104588509
----------------------------------------------------------------------------------------------------
Creating 10000 dictionaries - MongoEngine, write_concern={"w": 0}, validate=False, cascade=True
6.07083487511
----------------------------------------------------------------------------------------------------
Creating 10000 dictionaries - MongoEngine, write_concern={"w": 0}, validate=False
5.97704291344
----------------------------------------------------------------------------------------------------
Creating 10000 dictionaries - MongoEngine, force_insert=True, write_concern={"w": 0}, validate=False
5.9111430645
"""

import timeit


def cprofile_main():
    from pymongo import Connection
    connection = Connection()
    connection.drop_database('timeit_test')
    connection.disconnect()

    from mongoengine import Document, DictField, connect
    connect("timeit_test")

    class Noddy(Document):
        fields = DictField()

    for i in range(1):
        noddy = Noddy()
        for j in range(20):
            noddy.fields["key" + str(j)] = "value " + str(j)
        noddy.save()


def main():
    """
    0.4 Performance Figures ...

    ----------------------------------------------------------------------------------------------------
    Creating 10000 dictionaries - Pymongo
    3.86744189262
    ----------------------------------------------------------------------------------------------------
    Creating 10000 dictionaries - MongoEngine
    6.23374891281
    ----------------------------------------------------------------------------------------------------
    Creating 10000 dictionaries - MongoEngine, safe=False, validate=False
    5.33027005196
    ----------------------------------------------------------------------------------------------------
    Creating 10000 dictionaries - MongoEngine, safe=False, validate=False, cascade=False
    pass - No Cascade

    0.5.X
    ----------------------------------------------------------------------------------------------------
    Creating 10000 dictionaries - Pymongo
    3.89597702026
    ----------------------------------------------------------------------------------------------------
    Creating 10000 dictionaries - MongoEngine
    21.7735359669
    ----------------------------------------------------------------------------------------------------
    Creating 10000 dictionaries - MongoEngine, safe=False, validate=False
    19.8670389652
    ----------------------------------------------------------------------------------------------------
    Creating 10000 dictionaries - MongoEngine, safe=False, validate=False, cascade=False
    pass - No Cascade

    0.6.X
    ----------------------------------------------------------------------------------------------------
    Creating 10000 dictionaries - Pymongo
    3.81559205055
    ----------------------------------------------------------------------------------------------------
    Creating 10000 dictionaries - MongoEngine
    10.0446798801
    ----------------------------------------------------------------------------------------------------
    Creating 10000 dictionaries - MongoEngine, safe=False, validate=False
    9.51354718208
    ----------------------------------------------------------------------------------------------------
    Creating 10000 dictionaries - MongoEngine, safe=False, validate=False, cascade=False
    9.02567505836
    ----------------------------------------------------------------------------------------------------
    Creating 10000 dictionaries - MongoEngine, force=True
    8.44933390617

    0.7.X
    ----------------------------------------------------------------------------------------------------
    Creating 10000 dictionaries - Pymongo
    3.78801012039
    ----------------------------------------------------------------------------------------------------
    Creating 10000 dictionaries - MongoEngine
    9.73050498962
    ----------------------------------------------------------------------------------------------------
    Creating 10000 dictionaries - MongoEngine, safe=False, validate=False
    8.33456707001
    ----------------------------------------------------------------------------------------------------
    Creating 10000 dictionaries - MongoEngine, safe=False, validate=False, cascade=False
    8.37778115273
    ----------------------------------------------------------------------------------------------------
    Creating 10000 dictionaries - MongoEngine, force=True
    8.36906409264
    0.8.X
    ----------------------------------------------------------------------------------------------------
    Creating 10000 dictionaries - Pymongo
    3.69964408875
    ----------------------------------------------------------------------------------------------------
    Creating 10000 dictionaries - Pymongo write_concern={"w": 0}
    3.5526599884
    ----------------------------------------------------------------------------------------------------
    Creating 10000 dictionaries - MongoEngine
    7.00959801674
    ----------------------------------------------------------------------------------------------------
    Creating 10000 dictionaries without continual assign - MongoEngine
    5.60943293571
    ----------------------------------------------------------------------------------------------------
    Creating 10000 dictionaries - MongoEngine - write_concern={"w": 0}, cascade=True
    6.715102911
    ----------------------------------------------------------------------------------------------------
    Creating 10000 dictionaries - MongoEngine, write_concern={"w": 0}, validate=False, cascade=True
    5.50644683838
    ----------------------------------------------------------------------------------------------------
    Creating 10000 dictionaries - MongoEngine, write_concern={"w": 0}, validate=False
    4.69851183891
    ----------------------------------------------------------------------------------------------------
    Creating 10000 dictionaries - MongoEngine, force_insert=True, write_concern={"w": 0}, validate=False
    4.68946313858
    ----------------------------------------------------------------------------------------------------
    """
    print("Benchmarking...")

    setup = """
@@ -54,7 +131,7 @@ noddy = db.noddy
for i in range(10000):
    example = {'fields': {}}
    for j in range(20):
        example['fields']['key' + str(j)] = 'value ' + str(j)
        example['fields']["key"+str(j)] = "value "+str(j)

    noddy.save(example)

@@ -69,10 +146,9 @@ myNoddys = noddy.find()

    stmt = """
from pymongo import MongoClient
from pymongo.write_concern import WriteConcern
connection = MongoClient()

db = connection.get_database('timeit_test', write_concern=WriteConcern(w=0))
db = connection.timeit_test
noddy = db.noddy

for i in range(10000):
@@ -80,7 +156,7 @@ for i in range(10000):
    for j in range(20):
        example['fields']["key"+str(j)] = "value "+str(j)

    noddy.save(example)
    noddy.save(example, write_concern={"w": 0})

myNoddys = noddy.find()
[n for n in myNoddys] # iterate
@@ -95,10 +171,10 @@ myNoddys = noddy.find()
from pymongo import MongoClient
connection = MongoClient()
connection.drop_database('timeit_test')
connection.close()
connection.disconnect()

from mongoengine import Document, DictField, connect
connect('timeit_test')
connect("timeit_test")

class Noddy(Document):
    fields = DictField()

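The hunks above show only the ``setup``/``stmt`` strings; the driver that
times them falls outside the diff window. A hypothetical reduction of how
such a pair is typically fed to ``timeit`` (the strings below are trimmed
stand-ins for the ones in the diff, and ``noddy.save`` is PyMongo's legacy
``Collection.save`` used there)::

    import textwrap
    import timeit

    setup = textwrap.dedent("""
        from pymongo import MongoClient
        connection = MongoClient()
        db = connection.timeit_test
        noddy = db.noddy
    """)

    stmt = textwrap.dedent("""
        example = {'fields': {}}
        for j in range(20):
            example['fields']['key' + str(j)] = 'value ' + str(j)
        noddy.save(example)
    """)

    # Run the statement once against the prepared collection and print the time.
    print(timeit.Timer(stmt=stmt, setup=setup).timeit(1))
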
@@ -2,38 +2,11 @@
Changelog
=========

Development
===========
- (Fill this out as you fix issues and develop your features).
- Fixed using sets in field choices #1481
- POTENTIAL BREAKING CHANGE: Fixed limit/skip/hint/batch_size chaining #1476
- POTENTIAL BREAKING CHANGE: Changed a public `QuerySet.clone_into` method to a private `QuerySet._clone_into` #1476
- Fixed connecting to a replica set with PyMongo 2.x #1436
- Fixed an obscure error message when filtering by `field__in=non_iterable`. #1237

Changes in 0.11.0
=================
- BREAKING CHANGE: Renamed `ConnectionError` to `MongoEngineConnectionError` since the former is a built-in exception name in Python v3.x. #1428
- BREAKING CHANGE: Dropped Python 2.6 support. #1428
- BREAKING CHANGE: `from mongoengine.base import ErrorClass` won't work anymore for any error from `mongoengine.errors` (e.g. `ValidationError`). Use `from mongoengine.errors import ErrorClass instead`. #1428
- BREAKING CHANGE: Accessing a broken reference will raise a `DoesNotExist` error. In the past it used to return `None`. #1334
- Fixed absent rounding for DecimalField when `force_string` is set. #1103

Changes in 0.10.8
=================
- Added support for QuerySet.batch_size (#1426)
- Fixed query set iteration within iteration #1427
- Fixed an issue where specifying a MongoDB URI host would override more information than it should #1421
- Added ability to filter the generic reference field by ObjectId and DBRef #1425
- Fixed delete cascade for models with a custom primary key field #1247
- Added ability to specify an authentication mechanism (e.g. X.509) #1333
- Added support for falsey primary keys (e.g. doc.pk = 0) #1354
- Fixed QuerySet#sum/average for fields w/ explicit db_field #1417
- Fixed filtering by embedded_doc=None #1422
- Added support for cursor.comment #1420
- Fixed doc.get_<field>_display #1419
- Fixed __repr__ method of the StrictDict #1424
- Added a deprecation warning for Python 2.6
- Fixed BaseQuerySet#sum/average for fields w/ explicit db_field #1417

Changes in 0.10.7
=================

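Two of the ``Development`` entries above (#1476, #1426) concern chaining
``limit``/``skip``/``batch_size`` on an existing queryset. A small sketch of
the call pattern those fixes target; the model is ours::

    from mongoengine import Document, StringField, connect

    connect('testdb')

    class BlogPost(Document):
        title = StringField()

    # Each call returns a new queryset; after #1476 the settings are applied
    # to the underlying PyMongo cursor even when chained onto an existing
    # queryset, and batch_size (#1426) controls the cursor's fetch batches.
    for post in BlogPost.objects.skip(10).limit(5).batch_size(100):
        print(post.title)
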
@@ -33,7 +33,7 @@ the :attr:`host` to
    corresponding parameters in :func:`~mongoengine.connect`: ::

        connect(
            db='test',
            name='test',
            username='user',
            password='12345',
            host='mongodb://admin:qwerty@localhost/production'
@@ -42,18 +42,13 @@ the :attr:`host` to
    will establish connection to ``production`` database using
    ``admin`` username and ``qwerty`` password.

Replica Sets
============
ReplicaSets
===========

MongoEngine supports connecting to replica sets::

    from mongoengine import connect

    # Regular connect
    connect('dbname', replicaset='rs-name')

    # MongoDB URI-style connect
    connect(host='mongodb://localhost/dbname?replicaSet=rs-name')
MongoEngine supports
:class:`~pymongo.mongo_replica_set_client.MongoReplicaSetClient`. To use them,
please use an URI style connection and provide the ``replicaSet`` name
in the connection kwargs.

Read preferences are supported through the connection or via individual
queries by passing the read_preference ::
@@ -64,74 +59,76 @@ queries by passing the read_preference ::
Multiple Databases
==================

To use multiple databases you can use :func:`~mongoengine.connect` and provide
an `alias` name for the connection - if no `alias` is provided then "default"
is used.
Multiple database support was added in MongoEngine 0.6. To use multiple
databases you can use :func:`~mongoengine.connect` and provide an `alias` name
for the connection - if no `alias` is provided then "default" is used.

In the background this uses :func:`~mongoengine.register_connection` to
store the data and you can register all aliases up front if required.

Individual documents can also support multiple databases by providing a
`db_alias` in their meta data. This allows :class:`~pymongo.dbref.DBRef`
objects to point across databases and collections. Below is an example schema,
using 3 different databases to store data::
`db_alias` in their meta data.  This allows :class:`~pymongo.dbref.DBRef` objects
to point across databases and collections.  Below is an example schema, using
3 different databases to store data::

        class User(Document):
            name = StringField()

            meta = {'db_alias': 'user-db'}
            meta = {"db_alias": "user-db"}

        class Book(Document):
            name = StringField()

            meta = {'db_alias': 'book-db'}
            meta = {"db_alias": "book-db"}

        class AuthorBooks(Document):
            author = ReferenceField(User)
            book = ReferenceField(Book)

            meta = {'db_alias': 'users-books-db'}
            meta = {"db_alias": "users-books-db"}


Context Managers
================
Sometimes you may want to switch the database or collection to query against.
Sometimes you may want to switch the database or collection to query against
for a class.
For example, archiving older data into a separate database for performance
reasons or writing functions that dynamically choose collections to write
a document to.
document to.

Switch Database
---------------
The :class:`~mongoengine.context_managers.switch_db` context manager allows
you to change the database alias for a given class allowing quick and easy
access to the same User document across databases::
access the same User document across databases::

    from mongoengine.context_managers import switch_db

    class User(Document):
        name = StringField()

        meta = {'db_alias': 'user-db'}
        meta = {"db_alias": "user-db"}

    with switch_db(User, 'archive-user-db') as User:
        User(name='Ross').save()  # Saves the 'archive-user-db'
        User(name="Ross").save()  # Saves the 'archive-user-db'


Switch Collection
-----------------
The :class:`~mongoengine.context_managers.switch_collection` context manager
allows you to change the collection for a given class allowing quick and easy
access to the same Group document across collection::
access the same Group document across collection::

        from mongoengine.context_managers import switch_collection

        class Group(Document):
            name = StringField()

        Group(name='test').save()  # Saves in the default db
        Group(name="test").save()  # Saves in the default db

        with switch_collection(Group, 'group2000') as Group:
            Group(name='hello Group 2000 collection!').save()  # Saves in group2000 collection
            Group(name="hello Group 2000 collection!").save()  # Saves in group2000 collection



.. note:: Make sure any aliases have been registered with

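The ``Multiple Databases`` text above notes that ``connect`` delegates to
``register_connection`` and that aliases can be registered up front. A minimal
sketch of doing so for the example schema's aliases; the database names are
ours, and ``register_connection``'s keyword set varies across MongoEngine
versions::

    from mongoengine import connect, register_connection

    connect('tumblelog')  # the "default" alias

    # Register the three aliases used by User, Book and AuthorBooks up front.
    register_connection('user-db', 'tumblelog_users')
    register_connection('book-db', 'tumblelog_books')
    register_connection('users-books-db', 'tumblelog_users_books')
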
@@ -150,7 +150,7 @@ arguments can be set on all fields:
    .. note:: If set, this field is also accessible through the `pk` field.

:attr:`choices` (Default: None)
    An iterable (e.g. list, tuple or set) of choices to which the value of this
    An iterable (e.g. a list or tuple) of choices to which the value of this
    field should be limited.

    Can be either be a nested tuples of value (stored in mongo) and a
@@ -361,6 +361,11 @@ Its value can take any of the following constants:
   In Django, be sure to put all apps that have such delete rule declarations in
   their :file:`models.py` in the :const:`INSTALLED_APPS` tuple.


.. warning::
   Signals are not triggered when doing cascading updates / deletes - if this
   is required you must manually handle the update / delete.

Generic reference fields
''''''''''''''''''''''''
A second kind of reference field also exists,

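The warning above refers to delete rules declared through ``ReferenceField``'s
``reverse_delete_rule``. A minimal sketch of such a declaration (the model
names are ours)::

    from mongoengine import CASCADE, Document, ReferenceField, StringField

    class Author(Document):
        name = StringField()

    class Article(Document):
        title = StringField()
        # Deleting an Author also deletes their Articles; per the warning
        # above, no delete signals fire for the cascaded Article deletes.
        author = ReferenceField(Author, reverse_delete_rule=CASCADE)
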
@@ -2,13 +2,13 @@
Installing MongoEngine
======================

To use MongoEngine, you will need to download `MongoDB <http://mongodb.com/>`_
To use MongoEngine, you will need to download `MongoDB <http://mongodb.org/>`_
and ensure it is running in an accessible location. You will also need
`PyMongo <http://api.mongodb.org/python>`_ to use MongoEngine, but if you
install MongoEngine using setuptools, then the dependencies will be handled for
you.

MongoEngine is available on PyPI, so you can use :program:`pip`:
MongoEngine is available on PyPI, so to use it you can use :program:`pip`:

.. code-block:: console

@@ -479,8 +479,6 @@ operators. To use a :class:`~mongoengine.queryset.Q` object, pass it in as the
first positional argument to :attr:`Document.objects` when you filter it by
calling it with keyword arguments::

    from mongoengine.queryset.visitor import Q

    # Get published posts
    Post.objects(Q(published=True) | Q(publish_date__lte=datetime.now()))

@@ -142,4 +142,11 @@ cleaner looking while still allowing manual execution of the callback::
        modified = DateTimeField()


ReferenceFields and Signals
---------------------------

Currently `reverse_delete_rule` does not trigger signals on the other part of
the relationship.  If this is required you must manually handle the
reverse deletion.

.. _blinker: http://pypi.python.org/pypi/blinker

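Manually handling the reverse deletion, as the section above suggests, can be
done with the signals this page documents. A sketch, with model and handler
names that are ours::

    from mongoengine import Document, ReferenceField, StringField, signals

    class User(Document):
        name = StringField()

    class Post(Document):
        title = StringField()
        author = ReferenceField(User)

    def delete_users_posts(sender, document, **kwargs):
        # Delete related posts explicitly so their own signals still fire.
        Post.objects(author=document).delete()

    signals.pre_delete.connect(delete_users_posts, sender=User)
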
@@ -3,10 +3,11 @@ Tutorial
========

This tutorial introduces **MongoEngine** by means of example --- we will walk
through how to create a simple **Tumblelog** application. A tumblelog is a
blog that supports mixed media content, including text, images, links, video,
audio, etc. For simplicity's sake, we'll stick to text, image, and link
entries. As the purpose of this tutorial is to introduce MongoEngine, we'll
through how to create a simple **Tumblelog** application. A Tumblelog is a type
of blog where posts are not constrained to being conventional text-based posts.
As well as text-based entries, users may post images, links, videos, etc. For
simplicity's sake, we'll stick to text, image and link entries in our
application. As the purpose of this tutorial is to introduce MongoEngine, we'll
focus on the data-modelling side of the application, leaving out a user
interface.

@@ -15,14 +16,14 @@ Getting started

Before we start, make sure that a copy of MongoDB is running in an accessible
location --- running it locally will be easier, but if that is not an option
then it may be run on a remote server. If you haven't installed MongoEngine,
then it may be run on a remote server. If you haven't installed mongoengine,
simply use pip to install it like so::

    $ pip install mongoengine

Before we can start using MongoEngine, we need to tell it how to connect to our
instance of :program:`mongod`. For this we use the :func:`~mongoengine.connect`
function. If running locally, the only argument we need to provide is the name
function. If running locally the only argument we need to provide is the name
of the MongoDB database to use::

    from mongoengine import *
@@ -38,18 +39,18 @@ Defining our documents
MongoDB is *schemaless*, which means that no schema is enforced by the database
--- we may add and remove fields however we want and MongoDB won't complain.
This makes life a lot easier in many regards, especially when there is a change
to the data model. However, defining schemas for our documents can help to iron
out bugs involving incorrect types or missing fields, and also allow us to
to the data model. However, defining schemata for our documents can help to
iron out bugs involving incorrect types or missing fields, and also allow us to
define utility methods on our documents in the same way that traditional
:abbr:`ORMs (Object-Relational Mappers)` do.

In our Tumblelog application we need to store several different types of
information. We will need to have a collection of **users**, so that we may
information.  We will need to have a collection of **users**, so that we may
link posts to an individual. We also need to store our different types of
**posts** (eg: text, image and link) in the database. To aid navigation of our
Tumblelog, posts may have **tags** associated with them, so that the list of
posts shown to the user may be limited to posts that have been assigned a
specific tag. Finally, it would be nice if **comments** could be added to
specific tag.  Finally, it would be nice if **comments** could be added to
posts. We'll start with **users**, as the other document models are slightly
more involved.

@@ -77,7 +78,7 @@ Now we'll think about how to store the rest of the information. If we were
using a relational database, we would most likely have a table of **posts**, a
table of **comments** and a table of **tags**.  To associate the comments with
individual posts, we would put a column in the comments table that contained a
foreign key to the posts table. We'd also need a link table to provide the
foreign key to the posts table.  We'd also need a link table to provide the
many-to-many relationship between posts and tags. Then we'd need to address the
problem of storing the specialised post-types (text, image and link). There are
several ways we can achieve this, but each of them have their problems --- none
@@ -95,7 +96,7 @@ using* the new fields we need to support video posts. This fits with the
Object-Oriented principle of *inheritance* nicely. We can think of
:class:`Post` as a base class, and :class:`TextPost`, :class:`ImagePost` and
:class:`LinkPost` as subclasses of :class:`Post`. In fact, MongoEngine supports
this kind of modeling out of the box --- all you need do is turn on inheritance
this kind of modelling out of the box --- all you need do is turn on inheritance
by setting :attr:`allow_inheritance` to True in the :attr:`meta`::

    class Post(Document):
@@ -127,8 +128,8 @@ link table, we can just store a list of tags in each post. So, for both
efficiency and simplicity's sake, we'll store the tags as strings directly
within the post, rather than storing references to tags in a separate
collection. Especially as tags are generally very short (often even shorter
than a document's id), this denormalization won't impact the size of the
database very strongly. Let's take a look at the code of our modified
than a document's id), this denormalisation won't impact very strongly on the
size of our database. So let's take a look that the code our modified
:class:`Post` class::

    class Post(Document):
@@ -140,7 +141,7 @@ The :class:`~mongoengine.fields.ListField` object that is used to define a Post'
takes a field object as its first argument --- this means that you can have
lists of any type of field (including lists).

.. note:: We don't need to modify the specialized post types as they all
.. note:: We don't need to modify the specialised post types as they all
    inherit from :class:`Post`.

Comments
@@ -148,7 +149,7 @@ Comments

A comment is typically associated with *one* post. In a relational database, to
display a post with its comments, we would have to retrieve the post from the
database and then query the database again for the comments associated with the
database, then query the database again for the comments associated with the
post. This works, but there is no real reason to be storing the comments
separately from their associated posts, other than to work around the
relational model. Using MongoDB we can store the comments as a list of
@@ -218,8 +219,8 @@ Now that we've got our user in the database, let's add a couple of posts::
    post2.tags = ['mongoengine']
    post2.save()

.. note:: If you change a field on an object that has already been saved and
    then call :meth:`save` again, the document will be updated.
.. note:: If you change a field on a object that has already been saved, then
    call :meth:`save` again, the document will be updated.

Accessing our data
==================
@@ -231,17 +232,17 @@ used to access the documents in the database collection associated with that
class. So let's see how we can get our posts' titles::

    for post in Post.objects:
        print(post.title)
        print post.title

Retrieving type-specific information
------------------------------------

This will print the titles of our posts, one on each line. But what if we want
This will print the titles of our posts, one on each line. But What if we want
to access the type-specific data (link_url, content, etc.)? One way is simply
to use the :attr:`objects` attribute of a subclass of :class:`Post`::

    for post in TextPost.objects:
        print(post.content)
        print post.content

Using TextPost's :attr:`objects` attribute only returns documents that were
created using :class:`TextPost`. Actually, there is a more general rule here:
@@ -258,14 +259,16 @@ instances of :class:`Post` --- they were instances of the subclass of
practice::

    for post in Post.objects:
        print(post.title)
        print('=' * len(post.title))
        print post.title
        print '=' * len(post.title)

        if isinstance(post, TextPost):
            print(post.content)
            print post.content

        if isinstance(post, LinkPost):
            print('Link: {}'.format(post.link_url))
            print 'Link:', post.link_url

        print

This would print the title of each post, followed by the content if it was a
text post, and "Link: <url>" if it was a link post.
@@ -280,7 +283,7 @@ your query.  Let's adjust our query so that only posts with the tag "mongodb"
are returned::

    for post in Post.objects(tags='mongodb'):
        print(post.title)
        print post.title

There are also methods available on :class:`~mongoengine.queryset.QuerySet`
objects that allow different results to be returned, for example, calling
@@ -289,11 +292,11 @@ the first matched by the query you provide. Aggregation functions may also be
used on :class:`~mongoengine.queryset.QuerySet` objects::

    num_posts = Post.objects(tags='mongodb').count()
    print('Found {} posts with tag "mongodb"'.format(num_posts))
    print 'Found %d posts with tag "mongodb"' % num_posts

Learning more about MongoEngine
Learning more about mongoengine
-------------------------------

If you got this far you've made a great start, so well done! The next step on
your MongoEngine journey is the `full user guide <guide/index.html>`_, where
you can learn in-depth about how to use MongoEngine and MongoDB.
If you got this far you've made a great start, so well done!  The next step on
your mongoengine journey is the `full user guide <guide/index.html>`_, where you
can learn indepth about how to use mongoengine and mongodb.

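The tutorial's rule above, that querying a subclass returns only that
subclass's documents while querying the base class returns everything, in one
compact sketch built from the tutorial's models::

    from mongoengine import Document, StringField, connect

    connect('tumblelog')

    class Post(Document):
        title = StringField(required=True)
        meta = {'allow_inheritance': True}

    class TextPost(Post):
        content = StringField()

    class LinkPost(Post):
        link_url = StringField()

    TextPost(title='Fun with MongoEngine', content='Took a look at...').save()
    LinkPost(title='MongoEngine Documentation',
             link_url='http://docs.mongoengine.com/').save()

    print(Post.objects.count())      # 2 -- the base class sees both subtypes
    print(TextPost.objects.count())  # 1 -- a subclass sees only its own documents
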
@@ -2,53 +2,6 @@
Upgrading
#########

Development
***********
(Fill this out whenever you introduce breaking changes to MongoEngine)

This release includes various fixes for the `BaseQuerySet` methods and how they
are chained together. Since version 0.10.1 applying limit/skip/hint/batch_size
to an already-existing queryset wouldn't modify the underlying PyMongo cursor.
This has been fixed now, so you'll need to make sure that your code didn't rely
on the broken implementation.

Additionally, a public `BaseQuerySet.clone_into` has been renamed to a private
`_clone_into`. If you directly used that method in your code, you'll need to
rename its occurrences.

0.11.0
******
This release includes a major rehaul of MongoEngine's code quality and
introduces a few breaking changes. It also touches many different parts of
the package and although all the changes have been tested and scrutinized,
you're encouraged to thorougly test the upgrade.

First breaking change involves renaming `ConnectionError` to `MongoEngineConnectionError`.
If you import or catch this exception, you'll need to rename it in your code.

Second breaking change drops Python v2.6 support. If you run MongoEngine on
that Python version, you'll need to upgrade it first.

Third breaking change drops an old backward compatibility measure where
`from mongoengine.base import ErrorClass` would work on top of
`from mongoengine.errors import ErrorClass` (where `ErrorClass` is e.g.
`ValidationError`). If you import any exceptions from `mongoengine.base`,
change it to `mongoengine.errors`.

0.10.8
******
This version fixed an issue where specifying a MongoDB URI host would override
more information than it should. These changes are minor, but they still
subtly modify the connection logic and thus you're encouraged to test your
MongoDB connection before shipping v0.10.8 in production.

0.10.7
******

`QuerySet.aggregate_sum` and `QuerySet.aggregate_average` are dropped. Use
`QuerySet.sum` and `QuerySet.average` instead which use the aggreation framework
by default from now on.

0.9.0
*****

@@ -1,35 +1,25 @@
# Import submodules so that we can expose their __all__
from mongoengine import connection
from mongoengine import document
from mongoengine import errors
from mongoengine import fields
from mongoengine import queryset
from mongoengine import signals
import connection
from connection import *
import document
from document import *
import errors
from errors import *
import fields
from fields import *
import queryset
from queryset import *
import signals
from signals import *

# Import everything from each submodule so that it can be accessed via
# mongoengine, e.g. instead of `from mongoengine.connection import connect`,
# users can simply use `from mongoengine import connect`, or even
# `from mongoengine import *` and then `connect('testdb')`.
from mongoengine.connection import *
from mongoengine.document import *
from mongoengine.errors import *
from mongoengine.fields import *
from mongoengine.queryset import *
from mongoengine.signals import *
__all__ = (list(document.__all__) + fields.__all__ + connection.__all__ +
           list(queryset.__all__) + signals.__all__ + list(errors.__all__))


__all__ = (list(document.__all__) + list(fields.__all__) +
           list(connection.__all__) + list(queryset.__all__) +
           list(signals.__all__) + list(errors.__all__))


VERSION = (0, 11, 0)
VERSION = (0, 10, 7)


def get_version():
    """Return the VERSION as a string, e.g. for VERSION == (0, 10, 7),
    return '0.10.7'.
    """
    if isinstance(VERSION[-1], basestring):
        return '.'.join(map(str, VERSION[:-1])) + VERSION[-1]
    return '.'.join(map(str, VERSION))

| @@ -1,28 +1,8 @@ | ||||
| # Base module is split into several files for convenience. Files inside of | ||||
| # this module should import from a specific submodule (e.g. | ||||
| # `from mongoengine.base.document import BaseDocument`), but all of the | ||||
| # other modules should import directly from the top-level module (e.g. | ||||
| # `from mongoengine.base import BaseDocument`). This approach is cleaner and | ||||
| # also helps with cyclical import errors. | ||||
| from mongoengine.base.common import * | ||||
| from mongoengine.base.datastructures import * | ||||
| from mongoengine.base.document import * | ||||
| from mongoengine.base.fields import * | ||||
| from mongoengine.base.metaclasses import * | ||||
|  | ||||
| __all__ = ( | ||||
|     # common | ||||
|     'UPDATE_OPERATORS', '_document_registry', 'get_document', | ||||
|  | ||||
|     # datastructures | ||||
|     'BaseDict', 'BaseList', 'EmbeddedDocumentList', | ||||
|  | ||||
|     # document | ||||
|     'BaseDocument', | ||||
|  | ||||
|     # fields | ||||
|     'BaseField', 'ComplexBaseField', 'ObjectIdField', 'GeoJsonBaseField', | ||||
|  | ||||
|     # metaclasses | ||||
|     'DocumentMetaclass', 'TopLevelDocumentMetaclass' | ||||
| ) | ||||
| # Help with backwards compatibility | ||||
| from mongoengine.errors import * | ||||
|   | ||||
| @@ -1,18 +1,13 @@ | ||||
| from mongoengine.errors import NotRegistered | ||||
|  | ||||
| __all__ = ('UPDATE_OPERATORS', 'get_document', '_document_registry') | ||||
|  | ||||
|  | ||||
| UPDATE_OPERATORS = set(['set', 'unset', 'inc', 'dec', 'pop', 'push', | ||||
|                         'push_all', 'pull', 'pull_all', 'add_to_set', | ||||
|                         'set_on_insert', 'min', 'max', 'rename']) | ||||
| __all__ = ('ALLOW_INHERITANCE', 'get_document', '_document_registry') | ||||
|  | ||||
| ALLOW_INHERITANCE = False | ||||
|  | ||||
| _document_registry = {} | ||||
|  | ||||
|  | ||||
| def get_document(name): | ||||
|     """Get a document class by name.""" | ||||
|     doc = _document_registry.get(name, None) | ||||
|     if not doc: | ||||
|         # Possible old style name | ||||
|   | ||||
| @@ -1,16 +1,14 @@ | ||||
| import itertools | ||||
| import weakref | ||||
|  | ||||
| import six | ||||
|  | ||||
| from mongoengine.common import _import_class | ||||
| from mongoengine.errors import DoesNotExist, MultipleObjectsReturned | ||||
|  | ||||
| __all__ = ('BaseDict', 'BaseList', 'EmbeddedDocumentList') | ||||
| __all__ = ("BaseDict", "BaseList", "EmbeddedDocumentList") | ||||
|  | ||||
|  | ||||
| class BaseDict(dict): | ||||
|     """A special dict so we can watch any changes.""" | ||||
|     """A special dict so we can watch any changes""" | ||||
|  | ||||
|     _dereferenced = False | ||||
|     _instance = None | ||||
| @@ -95,7 +93,8 @@ class BaseDict(dict): | ||||
|  | ||||
|  | ||||
| class BaseList(list): | ||||
|     """A special list so we can watch any changes.""" | ||||
|     """A special list so we can watch any changes | ||||
|     """ | ||||
|  | ||||
|     _dereferenced = False | ||||
|     _instance = None | ||||
| @@ -138,7 +137,10 @@ class BaseList(list): | ||||
|         return super(BaseList, self).__setitem__(key, value) | ||||
|  | ||||
|     def __delitem__(self, key, *args, **kwargs): | ||||
|         self._mark_as_changed() | ||||
|         if isinstance(key, slice): | ||||
|             self._mark_as_changed() | ||||
|         else: | ||||
|             self._mark_as_changed(key) | ||||
|         return super(BaseList, self).__delitem__(key) | ||||
|  | ||||
|     def __setslice__(self, *args, **kwargs): | ||||
| @@ -207,22 +209,17 @@ class BaseList(list): | ||||
| class EmbeddedDocumentList(BaseList): | ||||
|  | ||||
|     @classmethod | ||||
|     def __match_all(cls, embedded_doc, kwargs): | ||||
|         """Return True if a given embedded doc matches all the filter | ||||
|         kwargs. If it doesn't return False. | ||||
|         """ | ||||
|         for key, expected_value in kwargs.items(): | ||||
|             doc_val = getattr(embedded_doc, key) | ||||
|             if doc_val != expected_value and six.text_type(doc_val) != expected_value: | ||||
|                 return False | ||||
|         return True | ||||
|     def __match_all(cls, i, kwargs): | ||||
|         items = kwargs.items() | ||||
|         return all([ | ||||
|             getattr(i, k) == v or unicode(getattr(i, k)) == v for k, v in items | ||||
|         ]) | ||||
|  | ||||
|     @classmethod | ||||
|     def __only_matches(cls, embedded_docs, kwargs): | ||||
|         """Return embedded docs that match the filter kwargs.""" | ||||
|     def __only_matches(cls, obj, kwargs): | ||||
|         if not kwargs: | ||||
|             return embedded_docs | ||||
|         return [doc for doc in embedded_docs if cls.__match_all(doc, kwargs)] | ||||
|             return obj | ||||
|         return filter(lambda i: cls.__match_all(i, kwargs), obj) | ||||
|  | ||||
|     def __init__(self, list_items, instance, name): | ||||
|         super(EmbeddedDocumentList, self).__init__(list_items, instance, name) | ||||
| @@ -288,18 +285,18 @@ class EmbeddedDocumentList(BaseList): | ||||
|         values = self.__only_matches(self, kwargs) | ||||
|         if len(values) == 0: | ||||
|             raise DoesNotExist( | ||||
|                 '%s matching query does not exist.' % self._name | ||||
|                 "%s matching query does not exist." % self._name | ||||
|             ) | ||||
|         elif len(values) > 1: | ||||
|             raise MultipleObjectsReturned( | ||||
|                 '%d items returned, instead of 1' % len(values) | ||||
|                 "%d items returned, instead of 1" % len(values) | ||||
|             ) | ||||
|  | ||||
|         return values[0] | ||||
|  | ||||
|     def first(self): | ||||
|         """Return the first embedded document in the list, or ``None`` | ||||
|         if empty. | ||||
|         """ | ||||
|         Returns the first embedded document in the list, or ``None`` if empty. | ||||
|         """ | ||||
|         if len(self) > 0: | ||||
|             return self[0] | ||||
| @@ -429,7 +426,7 @@ class StrictDict(object): | ||||
|     def __eq__(self, other): | ||||
|         return self.items() == other.items() | ||||
|  | ||||
|     def __ne__(self, other): | ||||
|     def __neq__(self, other): | ||||
|         return self.items() != other.items() | ||||
|  | ||||
|     @classmethod | ||||
| @@ -441,7 +438,7 @@ class StrictDict(object): | ||||
|                 __slots__ = allowed_keys_tuple | ||||
|  | ||||
|                 def __repr__(self): | ||||
|                     return '{%s}' % ', '.join('"{0!s}": {1!r}'.format(k, v) for k, v in self.items()) | ||||
|                     return "{%s}" % ', '.join('"{0!s}": {1!r}'.format(k, v) for k, v in self.items()) | ||||
|  | ||||
|             cls._classes[allowed_keys] = SpecificStrictDict | ||||
|         return cls._classes[allowed_keys] | ||||
|   | ||||
| @@ -1,5 +1,6 @@ | ||||
| import copy | ||||
| import numbers | ||||
| import operator | ||||
| from collections import Hashable | ||||
| from functools import partial | ||||
|  | ||||
| @@ -7,27 +8,30 @@ from bson import ObjectId, json_util | ||||
| from bson.dbref import DBRef | ||||
| from bson.son import SON | ||||
| import pymongo | ||||
| import six | ||||
|  | ||||
| from mongoengine import signals | ||||
| from mongoengine.base.common import get_document | ||||
| from mongoengine.base.datastructures import (BaseDict, BaseList, | ||||
|                                              EmbeddedDocumentList, | ||||
|                                              SemiStrictDict, StrictDict) | ||||
| from mongoengine.base.common import ALLOW_INHERITANCE, get_document | ||||
| from mongoengine.base.datastructures import ( | ||||
|     BaseDict, | ||||
|     BaseList, | ||||
|     EmbeddedDocumentList, | ||||
|     SemiStrictDict, | ||||
|     StrictDict | ||||
| ) | ||||
| from mongoengine.base.fields import ComplexBaseField | ||||
| from mongoengine.common import _import_class | ||||
| from mongoengine.errors import (FieldDoesNotExist, InvalidDocumentError, | ||||
|                                 LookUpError, OperationError, ValidationError) | ||||
|                                 LookUpError, ValidationError) | ||||
| from mongoengine.python_support import PY3, txt_type | ||||
|  | ||||
| __all__ = ('BaseDocument',) | ||||
| __all__ = ('BaseDocument', 'NON_FIELD_ERRORS') | ||||
|  | ||||
| NON_FIELD_ERRORS = '__all__' | ||||
|  | ||||
|  | ||||
| class BaseDocument(object): | ||||
|     __slots__ = ('_changed_fields', '_initialised', '_created', '_data', | ||||
|                  '_dynamic_fields', '_auto_id_field', '_db_field_map', | ||||
|                  '__weakref__') | ||||
|                  '_dynamic_fields', '_auto_id_field', '_db_field_map', '__weakref__') | ||||
|  | ||||
|     _dynamic = False | ||||
|     _dynamic_lock = True | ||||
| @@ -53,15 +57,15 @@ class BaseDocument(object): | ||||
|                 name = next(field) | ||||
|                 if name in values: | ||||
|                     raise TypeError( | ||||
|                         'Multiple values for keyword argument "%s"' % name) | ||||
|                         "Multiple values for keyword argument '" + name + "'") | ||||
|                 values[name] = value | ||||
|  | ||||
|         __auto_convert = values.pop('__auto_convert', True) | ||||
|         __auto_convert = values.pop("__auto_convert", True) | ||||
|  | ||||
|         # 399: set default values only to fields loaded from DB | ||||
|         __only_fields = set(values.pop('__only_fields', values)) | ||||
|         __only_fields = set(values.pop("__only_fields", values)) | ||||
|  | ||||
|         _created = values.pop('_created', True) | ||||
|         _created = values.pop("_created", True) | ||||
|  | ||||
|         signals.pre_init.send(self.__class__, document=self, values=values) | ||||
|  | ||||
| @@ -72,7 +76,7 @@ class BaseDocument(object): | ||||
|                 self._fields.keys() + ['id', 'pk', '_cls', '_text_score']) | ||||
|             if _undefined_fields: | ||||
|                 msg = ( | ||||
|                     'The fields "{0}" do not exist on the document "{1}"' | ||||
|                     "The fields '{0}' do not exist on the document '{1}'" | ||||
|                 ).format(_undefined_fields, self._class_name) | ||||
|                 raise FieldDoesNotExist(msg) | ||||
|  | ||||
| @@ -91,7 +95,7 @@ class BaseDocument(object): | ||||
|             value = getattr(self, key, None) | ||||
|             setattr(self, key, value) | ||||
|  | ||||
|         if '_cls' not in values: | ||||
|         if "_cls" not in values: | ||||
|             self._cls = self._class_name | ||||
|  | ||||
|         # Set passed values after initialisation | ||||
| @@ -146,7 +150,7 @@ class BaseDocument(object): | ||||
|         if self._dynamic and not self._dynamic_lock: | ||||
|  | ||||
|             if not hasattr(self, name) and not name.startswith('_'): | ||||
|                 DynamicField = _import_class('DynamicField') | ||||
|                 DynamicField = _import_class("DynamicField") | ||||
|                 field = DynamicField(db_field=name) | ||||
|                 field.name = name | ||||
|                 self._dynamic_fields[name] = field | ||||
| @@ -165,13 +169,11 @@ class BaseDocument(object): | ||||
|         except AttributeError: | ||||
|             self__created = True | ||||
|  | ||||
|         if ( | ||||
|             self._is_document and | ||||
|             not self__created and | ||||
|             name in self._meta.get('shard_key', tuple()) and | ||||
|             self._data.get(name) != value | ||||
|         ): | ||||
|             msg = 'Shard Keys are immutable. Tried to update %s' % name | ||||
|         if (self._is_document and not self__created and | ||||
|                 name in self._meta.get('shard_key', tuple()) and | ||||
|                 self._data.get(name) != value): | ||||
|             OperationError = _import_class('OperationError') | ||||
|             msg = "Shard Keys are immutable. Tried to update %s" % name | ||||
|             raise OperationError(msg) | ||||
|  | ||||
|         try: | ||||
| @@ -195,8 +197,8 @@ class BaseDocument(object): | ||||
|         return data | ||||
|  | ||||
|     def __setstate__(self, data): | ||||
|         if isinstance(data['_data'], SON): | ||||
|             data['_data'] = self.__class__._from_son(data['_data'])._data | ||||
|         if isinstance(data["_data"], SON): | ||||
|             data["_data"] = self.__class__._from_son(data["_data"])._data | ||||
|         for k in ('_changed_fields', '_initialised', '_created', '_data', | ||||
|                   '_dynamic_fields'): | ||||
|             if k in data: | ||||
| @@ -210,7 +212,7 @@ class BaseDocument(object): | ||||
|  | ||||
|         dynamic_fields = data.get('_dynamic_fields') or SON() | ||||
|         for k in dynamic_fields.keys(): | ||||
|             setattr(self, k, data['_data'].get(k)) | ||||
|             setattr(self, k, data["_data"].get(k)) | ||||
|  | ||||
|     def __iter__(self): | ||||
|         return iter(self._fields_ordered) | ||||
| @@ -252,13 +254,12 @@ class BaseDocument(object): | ||||
|         return repr_type('<%s: %s>' % (self.__class__.__name__, u)) | ||||
|  | ||||
|     def __str__(self): | ||||
|         # TODO this could be simpler? | ||||
|         if hasattr(self, '__unicode__'): | ||||
|             if six.PY3: | ||||
|             if PY3: | ||||
|                 return self.__unicode__() | ||||
|             else: | ||||
|                 return six.text_type(self).encode('utf-8') | ||||
|         return six.text_type('%s object' % self.__class__.__name__) | ||||
|                 return unicode(self).encode('utf-8') | ||||
|         return txt_type('%s object' % self.__class__.__name__) | ||||
|  | ||||
|     def __eq__(self, other): | ||||
|         if isinstance(other, self.__class__) and hasattr(other, 'id') and other.id is not None: | ||||
| @@ -307,7 +308,7 @@ class BaseDocument(object): | ||||
|             fields = [] | ||||
|  | ||||
|         data = SON() | ||||
|         data['_id'] = None | ||||
|         data["_id"] = None | ||||
|         data['_cls'] = self._class_name | ||||
|  | ||||
|         # only root fields ['test1.a', 'test2'] => ['test1', 'test2'] | ||||
| @@ -350,8 +351,18 @@ class BaseDocument(object): | ||||
|                 else: | ||||
|                     data[field.name] = value | ||||
|  | ||||
|         # If "_id" has not been set, then try and set it | ||||
|         Document = _import_class("Document") | ||||
|         if isinstance(self, Document): | ||||
|             if data["_id"] is None: | ||||
|                 data["_id"] = self._data.get("id", None) | ||||
|  | ||||
|         if data['_id'] is None: | ||||
|             data.pop('_id') | ||||
|  | ||||
|         # Only add _cls if allow_inheritance is True | ||||
|         if not self._meta.get('allow_inheritance'): | ||||
|         if (not hasattr(self, '_meta') or | ||||
|                 not self._meta.get('allow_inheritance', ALLOW_INHERITANCE)): | ||||
|             data.pop('_cls') | ||||
|  | ||||
|         return data | ||||
| @@ -365,16 +376,16 @@ class BaseDocument(object): | ||||
|         if clean: | ||||
|             try: | ||||
|                 self.clean() | ||||
|             except ValidationError as error: | ||||
|             except ValidationError, error: | ||||
|                 errors[NON_FIELD_ERRORS] = error | ||||
|  | ||||
|         # Get a list of tuples of field names and their current values | ||||
|         fields = [(self._fields.get(name, self._dynamic_fields.get(name)), | ||||
|                    self._data.get(name)) for name in self._fields_ordered] | ||||
|  | ||||
|         EmbeddedDocumentField = _import_class('EmbeddedDocumentField') | ||||
|         EmbeddedDocumentField = _import_class("EmbeddedDocumentField") | ||||
|         GenericEmbeddedDocumentField = _import_class( | ||||
|             'GenericEmbeddedDocumentField') | ||||
|             "GenericEmbeddedDocumentField") | ||||
|  | ||||
|         for field, value in fields: | ||||
|             if value is not None: | ||||
| @@ -384,29 +395,27 @@ class BaseDocument(object): | ||||
|                         field._validate(value, clean=clean) | ||||
|                     else: | ||||
|                         field._validate(value) | ||||
|                 except ValidationError as error: | ||||
|                 except ValidationError, error: | ||||
|                     errors[field.name] = error.errors or error | ||||
|                 except (ValueError, AttributeError, AssertionError) as error: | ||||
|                 except (ValueError, AttributeError, AssertionError), error: | ||||
|                     errors[field.name] = error | ||||
|             elif field.required and not getattr(field, '_auto_gen', False): | ||||
|                 errors[field.name] = ValidationError('Field is required', | ||||
|                                                      field_name=field.name) | ||||
|  | ||||
|         if errors: | ||||
|             pk = 'None' | ||||
|             pk = "None" | ||||
|             if hasattr(self, 'pk'): | ||||
|                 pk = self.pk | ||||
|             elif self._instance and hasattr(self._instance, 'pk'): | ||||
|                 pk = self._instance.pk | ||||
|             message = 'ValidationError (%s:%s) ' % (self._class_name, pk) | ||||
|             message = "ValidationError (%s:%s) " % (self._class_name, pk) | ||||
|             raise ValidationError(message, errors=errors) | ||||
|  | ||||
|     def to_json(self, *args, **kwargs): | ||||
|         """Convert this document to JSON. | ||||
|  | ||||
|         :param use_db_field: Serialize field names as they appear in | ||||
|             MongoDB (as opposed to attribute names on this document). | ||||
|             Defaults to True. | ||||
|         """Converts a document to JSON. | ||||
|         :param use_db_field: Set to True by default but enables the output of the json structure with the field names | ||||
|             and not the mongodb store db_names in case of set to False | ||||
|         """ | ||||
|         use_db_field = kwargs.pop('use_db_field', True) | ||||
|         return json_util.dumps(self.to_mongo(use_db_field), *args, **kwargs) | ||||
| @@ -417,26 +426,33 @@ class BaseDocument(object): | ||||
|         return cls._from_son(json_util.loads(json_data), created=created) | ||||
|  | ||||
|     def __expand_dynamic_values(self, name, value): | ||||
|         """Expand any dynamic values to their correct types / values.""" | ||||
|         """expand any dynamic values to their correct types / values""" | ||||
|         if not isinstance(value, (dict, list, tuple)): | ||||
|             return value | ||||
|  | ||||
|         # If the value is a dict with '_cls' in it, turn it into a document | ||||
|         is_dict = isinstance(value, dict) | ||||
|         if is_dict and '_cls' in value: | ||||
|         EmbeddedDocumentListField = _import_class('EmbeddedDocumentListField') | ||||
|  | ||||
|         is_list = False | ||||
|         if not hasattr(value, 'items'): | ||||
|             is_list = True | ||||
|             value = dict([(k, v) for k, v in enumerate(value)]) | ||||
|  | ||||
|         if not is_list and '_cls' in value: | ||||
|             cls = get_document(value['_cls']) | ||||
|             return cls(**value) | ||||
|  | ||||
|         if is_dict: | ||||
|             value = { | ||||
|                 k: self.__expand_dynamic_values(k, v) | ||||
|                 for k, v in value.items() | ||||
|             } | ||||
|         data = {} | ||||
|         for k, v in value.items(): | ||||
|             key = name if is_list else k | ||||
|             data[k] = self.__expand_dynamic_values(key, v) | ||||
|  | ||||
|         if is_list:  # Convert back to a list | ||||
|             data_items = sorted(data.items(), key=operator.itemgetter(0)) | ||||
|             value = [v for k, v in data_items] | ||||
|         else: | ||||
|             value = [self.__expand_dynamic_values(name, v) for v in value] | ||||
|             value = data | ||||
|  | ||||
|         # Convert lists / values so we can watch for any changes on them | ||||
|         EmbeddedDocumentListField = _import_class('EmbeddedDocumentListField') | ||||
|         if (isinstance(value, (list, tuple)) and | ||||
|                 not isinstance(value, BaseList)): | ||||
|             if issubclass(type(self), EmbeddedDocumentListField): | ||||
| @@ -449,7 +465,8 @@ class BaseDocument(object): | ||||
|         return value | ||||
|  | ||||
|     def _mark_as_changed(self, key): | ||||
|         """Mark a key as explicitly changed by the user.""" | ||||
|         """Marks a key as explicitly changed by the user | ||||
|         """ | ||||
|         if not key: | ||||
|             return | ||||
|  | ||||
| @@ -479,11 +496,10 @@ class BaseDocument(object): | ||||
|                         remove(field) | ||||
|  | ||||
|     def _clear_changed_fields(self): | ||||
|         """Using _get_changed_fields iterate and remove any fields that | ||||
|         are marked as changed. | ||||
|         """ | ||||
|         """Using get_changed_fields iterate and remove any fields that are | ||||
|         marked as changed""" | ||||
|         for changed in self._get_changed_fields(): | ||||
|             parts = changed.split('.') | ||||
|             parts = changed.split(".") | ||||
|             data = self | ||||
|             for part in parts: | ||||
|                 if isinstance(data, list): | ||||
| @@ -495,13 +511,10 @@ class BaseDocument(object): | ||||
|                     data = data.get(part, None) | ||||
|                 else: | ||||
|                     data = getattr(data, part, None) | ||||
|  | ||||
|                 if hasattr(data, '_changed_fields'): | ||||
|                     if getattr(data, '_is_document', False): | ||||
|                 if hasattr(data, "_changed_fields"): | ||||
|                     if hasattr(data, "_is_document") and data._is_document: | ||||
|                         continue | ||||
|  | ||||
|                     data._changed_fields = [] | ||||
|  | ||||
|         self._changed_fields = [] | ||||
|  | ||||
|     def _nestable_types_changed_fields(self, changed_fields, key, data, inspected): | ||||
| @@ -513,27 +526,26 @@ class BaseDocument(object): | ||||
|             iterator = data.iteritems() | ||||
|  | ||||
|         for index, value in iterator: | ||||
|             list_key = '%s%s.' % (key, index) | ||||
|             list_key = "%s%s." % (key, index) | ||||
|             # don't check anything lower if this key is already marked | ||||
|             # as changed. | ||||
|             if list_key[:-1] in changed_fields: | ||||
|                 continue | ||||
|             if hasattr(value, '_get_changed_fields'): | ||||
|                 changed = value._get_changed_fields(inspected) | ||||
|                 changed_fields += ['%s%s' % (list_key, k) | ||||
|                 changed_fields += ["%s%s" % (list_key, k) | ||||
|                                    for k in changed if k] | ||||
|             elif isinstance(value, (list, tuple, dict)): | ||||
|                 self._nestable_types_changed_fields( | ||||
|                     changed_fields, list_key, value, inspected) | ||||
|  | ||||
|     def _get_changed_fields(self, inspected=None): | ||||
|         """Return a list of all fields that have explicitly been changed. | ||||
|         """Returns a list of all fields that have explicitly been changed. | ||||
|         """ | ||||
|         EmbeddedDocument = _import_class('EmbeddedDocument') | ||||
|         DynamicEmbeddedDocument = _import_class('DynamicEmbeddedDocument') | ||||
|         ReferenceField = _import_class('ReferenceField') | ||||
|         SortedListField = _import_class('SortedListField') | ||||
|  | ||||
|         EmbeddedDocument = _import_class("EmbeddedDocument") | ||||
|         DynamicEmbeddedDocument = _import_class("DynamicEmbeddedDocument") | ||||
|         ReferenceField = _import_class("ReferenceField") | ||||
|         SortedListField = _import_class("SortedListField") | ||||
|         changed_fields = [] | ||||
|         changed_fields += getattr(self, '_changed_fields', []) | ||||
|  | ||||
| @@ -560,7 +572,7 @@ class BaseDocument(object): | ||||
|             ): | ||||
|                 # Find all embedded fields that have been changed | ||||
|                 changed = data._get_changed_fields(inspected) | ||||
|                 changed_fields += ['%s%s' % (key, k) for k in changed if k] | ||||
|                 changed_fields += ["%s%s" % (key, k) for k in changed if k] | ||||
|             elif (isinstance(data, (list, tuple, dict)) and | ||||
|                     db_field_name not in changed_fields): | ||||
|                 if (hasattr(field, 'field') and | ||||
| @@ -664,28 +676,21 @@ class BaseDocument(object): | ||||
|  | ||||
|     @classmethod | ||||
|     def _get_collection_name(cls): | ||||
|         """Return the collection name for this class. None for abstract | ||||
|         class. | ||||
|         """Returns the collection name for this class. None for abstract class | ||||
|         """ | ||||
|         return cls._meta.get('collection', None) | ||||
|  | ||||
|     @classmethod | ||||
|     def _from_son(cls, son, _auto_dereference=True, only_fields=None, created=False): | ||||
|         """Create an instance of a Document (subclass) from a PyMongo | ||||
|         SON. | ||||
|         """Create an instance of a Document (subclass) from a PyMongo SON. | ||||
|         """ | ||||
|         if not only_fields: | ||||
|             only_fields = [] | ||||
|  | ||||
|         if son and not isinstance(son, dict): | ||||
|             raise ValueError("The source SON object needs to be of type 'dict'") | ||||
|  | ||||
|         # Get the class name from the document, falling back to the given | ||||
|         # get the class name from the document, falling back to the given | ||||
|         # class if unavailable | ||||
|         class_name = son.get('_cls', cls._class_name) | ||||
|  | ||||
|         # Convert SON to a dict, making sure each key is a string | ||||
|         data = {str(key): value for key, value in son.iteritems()} | ||||
|         data = dict(("%s" % key, value) for key, value in son.iteritems()) | ||||
|  | ||||
|         # Return correct subclass for document type | ||||
|         if class_name != cls._class_name: | ||||
| @@ -707,20 +712,19 @@ class BaseDocument(object): | ||||
|                                         else field.to_python(value)) | ||||
|                     if field_name != field.db_field: | ||||
|                         del data[field.db_field] | ||||
|                 except (AttributeError, ValueError) as e: | ||||
|                 except (AttributeError, ValueError), e: | ||||
|                     errors_dict[field_name] = e | ||||
|  | ||||
|         if errors_dict: | ||||
|             errors = '\n'.join(['%s - %s' % (k, v) | ||||
|             errors = "\n".join(["%s - %s" % (k, v) | ||||
|                                 for k, v in errors_dict.items()]) | ||||
|             msg = ('Invalid data to create a `%s` instance.\n%s' | ||||
|             msg = ("Invalid data to create a `%s` instance.\n%s" | ||||
|                    % (cls._class_name, errors)) | ||||
|             raise InvalidDocumentError(msg) | ||||
|  | ||||
|         # In STRICT documents, remove any keys that aren't in cls._fields | ||||
|         if cls.STRICT: | ||||
|             data = {k: v for k, v in data.iteritems() if k in cls._fields} | ||||
|  | ||||
|             data = dict((k, v) | ||||
|                         for k, v in data.iteritems() if k in cls._fields) | ||||
|         obj = cls(__auto_convert=False, _created=created, __only_fields=only_fields, **data) | ||||
|         obj._changed_fields = changed_fields | ||||
|         if not _auto_dereference: | ||||
| @@ -730,43 +734,37 @@ class BaseDocument(object): | ||||
|  | ||||
|     @classmethod | ||||
|     def _build_index_specs(cls, meta_indexes): | ||||
|         """Generate and merge the full index specs.""" | ||||
|         """Generate and merge the full index specs | ||||
|         """ | ||||
|  | ||||
|         geo_indices = cls._geo_indices() | ||||
|         unique_indices = cls._unique_with_indexes() | ||||
|         index_specs = [cls._build_index_spec(spec) for spec in meta_indexes] | ||||
|         index_specs = [cls._build_index_spec(spec) | ||||
|                        for spec in meta_indexes] | ||||
|  | ||||
|         def merge_index_specs(index_specs, indices): | ||||
|             """Helper method for merging index specs.""" | ||||
|             if not indices: | ||||
|                 return index_specs | ||||
|  | ||||
|             # Create a map of index fields to index spec. We're converting | ||||
|             # the fields from a list to a tuple so that it's hashable. | ||||
|             spec_fields = { | ||||
|                 tuple(index['fields']): index for index in index_specs | ||||
|             } | ||||
|  | ||||
|             # For each new index, if there's an existing index with the same | ||||
|             # fields list, update the existing spec with all data from the | ||||
|             # new spec. | ||||
|             for new_index in indices: | ||||
|                 candidate = spec_fields.get(tuple(new_index['fields'])) | ||||
|                 if candidate is None: | ||||
|                     index_specs.append(new_index) | ||||
|             spec_fields = [v['fields'] | ||||
|                            for k, v in enumerate(index_specs)] | ||||
|             # Merge unique_indexes with existing specs | ||||
|             for k, v in enumerate(indices): | ||||
|                 if v['fields'] in spec_fields: | ||||
|                     index_specs[spec_fields.index(v['fields'])].update(v) | ||||
|                 else: | ||||
|                     candidate.update(new_index) | ||||
|  | ||||
|                     index_specs.append(v) | ||||
|             return index_specs | ||||
|  | ||||
|         # Merge geo indexes and unique_with indexes into the meta index specs. | ||||
|         index_specs = merge_index_specs(index_specs, geo_indices) | ||||
|         index_specs = merge_index_specs(index_specs, unique_indices) | ||||
|         return index_specs | ||||
|  | ||||
|     @classmethod | ||||
|     def _build_index_spec(cls, spec): | ||||
|         """Build a PyMongo index spec from a MongoEngine index spec.""" | ||||
|         if isinstance(spec, six.string_types): | ||||
|         """Build a PyMongo index spec from a MongoEngine index spec. | ||||
|         """ | ||||
|         if isinstance(spec, basestring): | ||||
|             spec = {'fields': [spec]} | ||||
|         elif isinstance(spec, (list, tuple)): | ||||
|             spec = {'fields': list(spec)} | ||||
| @@ -777,7 +775,8 @@ class BaseDocument(object): | ||||
|         direction = None | ||||
|  | ||||
|         # Check to see if we need to include _cls | ||||
|         allow_inheritance = cls._meta.get('allow_inheritance') | ||||
|         allow_inheritance = cls._meta.get('allow_inheritance', | ||||
|                                           ALLOW_INHERITANCE) | ||||
|         include_cls = ( | ||||
|             allow_inheritance and | ||||
|             not spec.get('sparse', False) and | ||||
| @@ -787,7 +786,7 @@ class BaseDocument(object): | ||||
|  | ||||
|         # 733: don't include cls if index_cls is False unless there is an explicit cls with the index | ||||
|         include_cls = include_cls and (spec.get('cls', False) or cls._meta.get('index_cls', True)) | ||||
|         if 'cls' in spec: | ||||
|         if "cls" in spec: | ||||
|             spec.pop('cls') | ||||
|         for key in spec['fields']: | ||||
|             # If inherited spec continue | ||||
| @@ -802,19 +801,19 @@ class BaseDocument(object): | ||||
|             # GEOHAYSTACK from ) | ||||
|             # GEO2D from * | ||||
|             direction = pymongo.ASCENDING | ||||
|             if key.startswith('-'): | ||||
|             if key.startswith("-"): | ||||
|                 direction = pymongo.DESCENDING | ||||
|             elif key.startswith('$'): | ||||
|             elif key.startswith("$"): | ||||
|                 direction = pymongo.TEXT | ||||
|             elif key.startswith('#'): | ||||
|             elif key.startswith("#"): | ||||
|                 direction = pymongo.HASHED | ||||
|             elif key.startswith('('): | ||||
|             elif key.startswith("("): | ||||
|                 direction = pymongo.GEOSPHERE | ||||
|             elif key.startswith(')'): | ||||
|             elif key.startswith(")"): | ||||
|                 direction = pymongo.GEOHAYSTACK | ||||
|             elif key.startswith('*'): | ||||
|             elif key.startswith("*"): | ||||
|                 direction = pymongo.GEO2D | ||||
|             if key.startswith(('+', '-', '*', '$', '#', '(', ')')): | ||||
|             if key.startswith(("+", "-", "*", "$", "#", "(", ")")): | ||||
|                 key = key[1:] | ||||
|  | ||||
|             # Use real field name, do it manually because we need field | ||||
| @@ -827,7 +826,7 @@ class BaseDocument(object): | ||||
|                 parts = [] | ||||
|                 for field in fields: | ||||
|                     try: | ||||
|                         if field != '_id': | ||||
|                         if field != "_id": | ||||
|                             field = field.db_field | ||||
|                     except AttributeError: | ||||
|                         pass | ||||
| @@ -846,53 +845,49 @@ class BaseDocument(object): | ||||
|         return spec | ||||
|  | ||||
|     @classmethod | ||||
|     def _unique_with_indexes(cls, namespace=''): | ||||
|         """Find unique indexes in the document schema and return them.""" | ||||
|     def _unique_with_indexes(cls, namespace=""): | ||||
|         """ | ||||
|         Find and set unique indexes | ||||
|         """ | ||||
|         unique_indexes = [] | ||||
|         for field_name, field in cls._fields.items(): | ||||
|             sparse = field.sparse | ||||
|  | ||||
|             # Generate a list of indexes needed by uniqueness constraints | ||||
|             if field.unique: | ||||
|                 unique_fields = [field.db_field] | ||||
|  | ||||
|                 # Add any unique_with fields to the back of the index spec | ||||
|                 if field.unique_with: | ||||
|                     if isinstance(field.unique_with, six.string_types): | ||||
|                     if isinstance(field.unique_with, basestring): | ||||
|                         field.unique_with = [field.unique_with] | ||||
|  | ||||
|                     # Convert unique_with field names to real field names | ||||
|                     unique_with = [] | ||||
|                     for other_name in field.unique_with: | ||||
|                         parts = other_name.split('.') | ||||
|  | ||||
|                         # Lookup real name | ||||
|                         parts = cls._lookup_field(parts) | ||||
|                         name_parts = [part.db_field for part in parts] | ||||
|                         unique_with.append('.'.join(name_parts)) | ||||
|  | ||||
|                         # Unique field should be required | ||||
|                         parts[-1].required = True | ||||
|                         sparse = (not sparse and | ||||
|                                   parts[-1].name not in cls.__dict__) | ||||
|  | ||||
|                     unique_fields += unique_with | ||||
|  | ||||
|                 # Add the new index to the list | ||||
|                 fields = [ | ||||
|                     ('%s%s' % (namespace, f), pymongo.ASCENDING) | ||||
|                     for f in unique_fields | ||||
|                 ] | ||||
|                 fields = [("%s%s" % (namespace, f), pymongo.ASCENDING) | ||||
|                           for f in unique_fields] | ||||
|                 index = {'fields': fields, 'unique': True, 'sparse': sparse} | ||||
|                 unique_indexes.append(index) | ||||
|  | ||||
|             if field.__class__.__name__ == 'ListField': | ||||
|             if field.__class__.__name__ == "ListField": | ||||
|                 field = field.field | ||||
|  | ||||
|             # Grab any embedded document field unique indexes | ||||
|             if (field.__class__.__name__ == 'EmbeddedDocumentField' and | ||||
|             if (field.__class__.__name__ == "EmbeddedDocumentField" and | ||||
|                     field.document_type != cls): | ||||
|                 field_namespace = '%s.' % field_name | ||||
|                 field_namespace = "%s." % field_name | ||||
|                 doc_cls = field.document_type | ||||
|                 unique_indexes += doc_cls._unique_with_indexes(field_namespace) | ||||
|  | ||||
| @@ -904,9 +899,8 @@ class BaseDocument(object): | ||||
|         geo_indices = [] | ||||
|         inspected.append(cls) | ||||
|  | ||||
|         geo_field_type_names = ('EmbeddedDocumentField', 'GeoPointField', | ||||
|                                 'PointField', 'LineStringField', | ||||
|                                 'PolygonField') | ||||
|         geo_field_type_names = ["EmbeddedDocumentField", "GeoPointField", | ||||
|                                 "PointField", "LineStringField", "PolygonField"] | ||||
|  | ||||
|         geo_field_types = tuple([_import_class(field) | ||||
|                                  for field in geo_field_type_names]) | ||||
| @@ -914,68 +908,32 @@ class BaseDocument(object): | ||||
|         for field in cls._fields.values(): | ||||
|             if not isinstance(field, geo_field_types): | ||||
|                 continue | ||||
|  | ||||
|             if hasattr(field, 'document_type'): | ||||
|                 field_cls = field.document_type | ||||
|                 if field_cls in inspected: | ||||
|                     continue | ||||
|  | ||||
|                 if hasattr(field_cls, '_geo_indices'): | ||||
|                     geo_indices += field_cls._geo_indices( | ||||
|                         inspected, parent_field=field.db_field) | ||||
|             elif field._geo_index: | ||||
|                 field_name = field.db_field | ||||
|                 if parent_field: | ||||
|                     field_name = '%s.%s' % (parent_field, field_name) | ||||
|                 geo_indices.append({ | ||||
|                     'fields': [(field_name, field._geo_index)] | ||||
|                 }) | ||||
|  | ||||
|                     field_name = "%s.%s" % (parent_field, field_name) | ||||
|                 geo_indices.append({'fields': | ||||
|                                     [(field_name, field._geo_index)]}) | ||||
|         return geo_indices | ||||
|  | ||||
|     @classmethod | ||||
|     def _lookup_field(cls, parts): | ||||
|         """Given the path to a given field, return a list containing | ||||
|         the Field object associated with that field and all of its parent | ||||
|         Field objects. | ||||
|  | ||||
|         Args: | ||||
|             parts (str, list, or tuple) - path to the field. Should be a | ||||
|             string for simple fields existing on this document or a list | ||||
|             of strings for a field that exists deeper in embedded documents. | ||||
|  | ||||
|         Returns: | ||||
|             A list of Field instances for fields that were found or | ||||
|             strings for sub-fields that weren't. | ||||
|  | ||||
|         Example: | ||||
|             >>> user._lookup_field('name') | ||||
|             [<mongoengine.fields.StringField at 0x1119bff50>] | ||||
|  | ||||
|             >>> user._lookup_field('roles') | ||||
|             [<mongoengine.fields.EmbeddedDocumentListField at 0x1119ec250>] | ||||
|  | ||||
|             >>> user._lookup_field(['roles', 'role']) | ||||
|             [<mongoengine.fields.EmbeddedDocumentListField at 0x1119ec250>, | ||||
|              <mongoengine.fields.StringField at 0x1119ec050>] | ||||
|  | ||||
|             >>> user._lookup_field('doesnt_exist') | ||||
|             raises LookUpError | ||||
|  | ||||
|             >>> user._lookup_field(['roles', 'doesnt_exist']) | ||||
|             [<mongoengine.fields.EmbeddedDocumentListField at 0x1119ec250>, | ||||
|              'doesnt_exist'] | ||||
|  | ||||
|         """Lookup a field based on its attribute and return a list containing | ||||
|         the field's parents and the field. | ||||
|         """ | ||||
|         # TODO this method is WAY too complicated. Simplify it. | ||||
|         # TODO don't think returning a string for embedded non-existent fields is desired | ||||
|  | ||||
|         ListField = _import_class('ListField') | ||||
|         ListField = _import_class("ListField") | ||||
|         DynamicField = _import_class('DynamicField') | ||||
|  | ||||
|         if not isinstance(parts, (list, tuple)): | ||||
|             parts = [parts] | ||||
|  | ||||
|         fields = [] | ||||
|         field = None | ||||
|  | ||||
| @@ -985,17 +943,16 @@ class BaseDocument(object): | ||||
|                 fields.append(field_name) | ||||
|                 continue | ||||
|  | ||||
|             # Look up first field from the document | ||||
|             if field is None: | ||||
|                 # Look up first field from the document | ||||
|                 if field_name == 'pk': | ||||
|                     # Deal with "primary key" alias | ||||
|                     field_name = cls._meta['id_field'] | ||||
|  | ||||
|                 if field_name in cls._fields: | ||||
|                     field = cls._fields[field_name] | ||||
|                 elif cls._dynamic: | ||||
|                     field = DynamicField(db_field=field_name) | ||||
|                 elif cls._meta.get('allow_inheritance') or cls._meta.get('abstract', False): | ||||
|                 elif cls._meta.get("allow_inheritance", False) or cls._meta.get("abstract", False): | ||||
|                     # 744: in case the field is defined in a subclass | ||||
|                     for subcls in cls.__subclasses__(): | ||||
|                         try: | ||||
| @@ -1008,55 +965,35 @@ class BaseDocument(object): | ||||
|                     else: | ||||
|                         raise LookUpError('Cannot resolve field "%s"' % field_name) | ||||
|                 else: | ||||
|                     raise LookUpError('Cannot resolve field "%s"' % field_name) | ||||
|                     raise LookUpError('Cannot resolve field "%s"' | ||||
|                                       % field_name) | ||||
|             else: | ||||
|                 ReferenceField = _import_class('ReferenceField') | ||||
|                 GenericReferenceField = _import_class('GenericReferenceField') | ||||
|  | ||||
|                 # If previous field was a reference, throw an error (we | ||||
|                 # cannot look up fields that are on references). | ||||
|                 if isinstance(field, (ReferenceField, GenericReferenceField)): | ||||
|                     raise LookUpError('Cannot perform join in mongoDB: %s' % | ||||
|                                       '__'.join(parts)) | ||||
|  | ||||
|                 # If the parent field has a "field" attribute which has a | ||||
|                 # lookup_member method, call it to find the field | ||||
|                 # corresponding to this iteration. | ||||
|                 if hasattr(getattr(field, 'field', None), 'lookup_member'): | ||||
|                     new_field = field.field.lookup_member(field_name) | ||||
|  | ||||
|                 # If the parent field is a DynamicField or if it's part of | ||||
|                 # a DynamicDocument, mark current field as a DynamicField | ||||
|                 # with db_name equal to the field name. | ||||
|                 elif cls._dynamic and (isinstance(field, DynamicField) or | ||||
|                                        getattr(getattr(field, 'document_type', None), '_dynamic', None)): | ||||
|                     new_field = DynamicField(db_field=field_name) | ||||
|  | ||||
|                 # Else, try to use the parent field's lookup_member method | ||||
|                 # to find the subfield. | ||||
|                 elif hasattr(field, 'lookup_member'): | ||||
|                     new_field = field.lookup_member(field_name) | ||||
|  | ||||
|                 # Raise a LookUpError if all the other conditions failed. | ||||
|                 else: | ||||
|                     raise LookUpError( | ||||
|                         'Cannot resolve subfield or operator {} ' | ||||
|                         'on the field {}'.format(field_name, field.name) | ||||
|                     ) | ||||
|  | ||||
|                 # If current field still wasn't found and the parent field | ||||
|                 # is a ComplexBaseField, add the name current field name and | ||||
|                 # move on. | ||||
|                     # Look up subfield on the previous field or raise | ||||
|                     try: | ||||
|                         new_field = field.lookup_member(field_name) | ||||
|                     except AttributeError: | ||||
|                         raise LookUpError('Cannot resolve subfield or operator {} ' | ||||
|                                           'on the field {}'.format( | ||||
|                                               field_name, field.name)) | ||||
|                 if not new_field and isinstance(field, ComplexBaseField): | ||||
|                     fields.append(field_name) | ||||
|                     continue | ||||
|                 elif not new_field: | ||||
|                     raise LookUpError('Cannot resolve field "%s"' % field_name) | ||||
|  | ||||
|                     raise LookUpError('Cannot resolve field "%s"' | ||||
|                                       % field_name) | ||||
|                 field = new_field  # update field to the new field type | ||||
|  | ||||
|             fields.append(field) | ||||
|  | ||||
|         return fields | ||||
|  | ||||
|     @classmethod | ||||
|   | ||||
| @@ -4,17 +4,21 @@ import weakref | ||||
|  | ||||
| from bson import DBRef, ObjectId, SON | ||||
| import pymongo | ||||
| import six | ||||
|  | ||||
| from mongoengine.base.common import UPDATE_OPERATORS | ||||
| from mongoengine.base.datastructures import (BaseDict, BaseList, | ||||
|                                              EmbeddedDocumentList) | ||||
| from mongoengine.base.common import ALLOW_INHERITANCE | ||||
| from mongoengine.base.datastructures import ( | ||||
|     BaseDict, BaseList, EmbeddedDocumentList | ||||
| ) | ||||
| from mongoengine.common import _import_class | ||||
| from mongoengine.errors import ValidationError | ||||
|  | ||||
| __all__ = ("BaseField", "ComplexBaseField", | ||||
|            "ObjectIdField", "GeoJsonBaseField") | ||||
|  | ||||
| __all__ = ('BaseField', 'ComplexBaseField', 'ObjectIdField', | ||||
|            'GeoJsonBaseField') | ||||
|  | ||||
| UPDATE_OPERATORS = set(['set', 'unset', 'inc', 'dec', 'pop', 'push', | ||||
|                         'push_all', 'pull', 'pull_all', 'add_to_set', | ||||
|                         'set_on_insert', 'min', 'max']) | ||||
|  | ||||
|  | ||||
| class BaseField(object): | ||||
| @@ -23,6 +27,7 @@ class BaseField(object): | ||||
|  | ||||
|     .. versionchanged:: 0.5 - added verbose and help text | ||||
|     """ | ||||
|  | ||||
|     name = None | ||||
|     _geo_index = False | ||||
|     _auto_gen = False  # Call `generate` to generate a value | ||||
| @@ -41,7 +46,7 @@ class BaseField(object): | ||||
|         """ | ||||
|         :param db_field: The database field to store this field in | ||||
|             (defaults to the name of the field) | ||||
|         :param name: Deprecated - use db_field | ||||
|         :param name: Depreciated - use db_field | ||||
|         :param required: If the field is required. Whether it has to have a | ||||
|             value or not. Defaults to False. | ||||
|         :param default: (optional) The default value for this field if no value | ||||
| @@ -68,7 +73,7 @@ class BaseField(object): | ||||
|         self.db_field = (db_field or name) if not primary_key else '_id' | ||||
|  | ||||
|         if name: | ||||
|             msg = 'Field\'s "name" attribute deprecated in favour of "db_field"' | ||||
|             msg = "Fields' 'name' attribute deprecated in favour of 'db_field'" | ||||
|             warnings.warn(msg, DeprecationWarning) | ||||
|         self.required = required or primary_key | ||||
|         self.default = default | ||||
| @@ -81,21 +86,10 @@ class BaseField(object): | ||||
|         self.sparse = sparse | ||||
|         self._owner_document = None | ||||
|  | ||||
|         # Validate the db_field | ||||
|         if isinstance(self.db_field, six.string_types) and ( | ||||
|             '.' in self.db_field or | ||||
|             '\0' in self.db_field or | ||||
|             self.db_field.startswith('$') | ||||
|         ): | ||||
|             raise ValueError( | ||||
|                 'field names cannot contain dots (".") or null characters ' | ||||
|                 '("\\0"), and they must not start with a dollar sign ("$").' | ||||
|             ) | ||||
|  | ||||
|         # Detect and report conflicts between metadata and base properties. | ||||
|         conflicts = set(dir(self)) & set(kwargs) | ||||
|         if conflicts: | ||||
|             raise TypeError('%s already has attribute(s): %s' % ( | ||||
|             raise TypeError("%s already has attribute(s): %s" % ( | ||||
|                 self.__class__.__name__, ', '.join(conflicts))) | ||||
|  | ||||
|         # Assign metadata to the instance | ||||
| @@ -153,21 +147,25 @@ class BaseField(object): | ||||
|                     v._instance = weakref.proxy(instance) | ||||
|         instance._data[self.name] = value | ||||
|  | ||||
|     def error(self, message='', errors=None, field_name=None): | ||||
|         """Raise a ValidationError.""" | ||||
|     def error(self, message="", errors=None, field_name=None): | ||||
|         """Raises a ValidationError. | ||||
|         """ | ||||
|         field_name = field_name if field_name else self.name | ||||
|         raise ValidationError(message, errors=errors, field_name=field_name) | ||||
|  | ||||
|     def to_python(self, value): | ||||
|         """Convert a MongoDB-compatible type to a Python type.""" | ||||
|         """Convert a MongoDB-compatible type to a Python type. | ||||
|         """ | ||||
|         return value | ||||
|  | ||||
|     def to_mongo(self, value): | ||||
|         """Convert a Python type to a MongoDB-compatible type.""" | ||||
|         """Convert a Python type to a MongoDB-compatible type. | ||||
|         """ | ||||
|         return self.to_python(value) | ||||
|  | ||||
|     def _to_mongo_safe_call(self, value, use_db_field=True, fields=None): | ||||
|         """Helper method to call to_mongo with proper inputs.""" | ||||
|         """A helper method to call to_mongo with proper inputs | ||||
|         """ | ||||
|         f_inputs = self.to_mongo.__code__.co_varnames | ||||
|         ex_vars = {} | ||||
|         if 'fields' in f_inputs: | ||||
| @@ -179,13 +177,15 @@ class BaseField(object): | ||||
|         return self.to_mongo(value, **ex_vars) | ||||
|  | ||||
|     def prepare_query_value(self, op, value): | ||||
|         """Prepare a value that is being used in a query for PyMongo.""" | ||||
|         """Prepare a value that is being used in a query for PyMongo. | ||||
|         """ | ||||
|         if op in UPDATE_OPERATORS: | ||||
|             self.validate(value) | ||||
|         return value | ||||
|  | ||||
|     def validate(self, value, clean=True): | ||||
|         """Perform validation on a value.""" | ||||
|         """Perform validation on a value. | ||||
|         """ | ||||
|         pass | ||||
|  | ||||
|     def _validate_choices(self, value): | ||||
| @@ -193,21 +193,18 @@ class BaseField(object): | ||||
|         EmbeddedDocument = _import_class('EmbeddedDocument') | ||||
|  | ||||
|         choice_list = self.choices | ||||
|         if isinstance(next(iter(choice_list)), (list, tuple)): | ||||
|             # next(iter) is useful for sets | ||||
|         if isinstance(choice_list[0], (list, tuple)): | ||||
|             choice_list = [k for k, _ in choice_list] | ||||
|  | ||||
|         # Choices which are other types of Documents | ||||
|         if isinstance(value, (Document, EmbeddedDocument)): | ||||
|             if not any(isinstance(value, c) for c in choice_list): | ||||
|                 self.error( | ||||
|                     'Value must be an instance of %s' % ( | ||||
|                         six.text_type(choice_list) | ||||
|                     ) | ||||
|                     'Value must be instance of %s' % unicode(choice_list) | ||||
|                 ) | ||||
|         # Choices which are types other than Documents | ||||
|         elif value not in choice_list: | ||||
|             self.error('Value must be one of %s' % six.text_type(choice_list)) | ||||
|             self.error('Value must be one of %s' % unicode(choice_list)) | ||||
|  | ||||
|     def _validate(self, value, **kwargs): | ||||
|         # Check the Choices Constraint | ||||
| @@ -250,7 +247,8 @@ class ComplexBaseField(BaseField): | ||||
|     field = None | ||||
|  | ||||
|     def __get__(self, instance, owner): | ||||
|         """Descriptor to automatically dereference references.""" | ||||
|         """Descriptor to automatically dereference references. | ||||
|         """ | ||||
|         if instance is None: | ||||
|             # Document class being used rather than a document object | ||||
|             return self | ||||
| @@ -262,7 +260,7 @@ class ComplexBaseField(BaseField): | ||||
|                        (self.field is None or isinstance(self.field, | ||||
|                                                          (GenericReferenceField, ReferenceField)))) | ||||
|  | ||||
|         _dereference = _import_class('DeReference')() | ||||
|         _dereference = _import_class("DeReference")() | ||||
|  | ||||
|         self._auto_dereference = instance._fields[self.name]._auto_dereference | ||||
|         if instance._initialised and dereference and instance._data.get(self.name): | ||||
| @@ -297,8 +295,9 @@ class ComplexBaseField(BaseField): | ||||
|         return value | ||||
|  | ||||
|     def to_python(self, value): | ||||
|         """Convert a MongoDB-compatible type to a Python type.""" | ||||
|         if isinstance(value, six.string_types): | ||||
|         """Convert a MongoDB-compatible type to a Python type. | ||||
|         """ | ||||
|         if isinstance(value, basestring): | ||||
|             return value | ||||
|  | ||||
|         if hasattr(value, 'to_python'): | ||||
| @@ -308,14 +307,14 @@ class ComplexBaseField(BaseField): | ||||
|         if not hasattr(value, 'items'): | ||||
|             try: | ||||
|                 is_list = True | ||||
|                 value = {k: v for k, v in enumerate(value)} | ||||
|                 value = dict([(k, v) for k, v in enumerate(value)]) | ||||
|             except TypeError:  # Not iterable return the value | ||||
|                 return value | ||||
|  | ||||
|         if self.field: | ||||
|             self.field._auto_dereference = self._auto_dereference | ||||
|             value_dict = {key: self.field.to_python(item) | ||||
|                           for key, item in value.items()} | ||||
|             value_dict = dict([(key, self.field.to_python(item)) | ||||
|                                for key, item in value.items()]) | ||||
|         else: | ||||
|             Document = _import_class('Document') | ||||
|             value_dict = {} | ||||
| @@ -338,12 +337,13 @@ class ComplexBaseField(BaseField): | ||||
|         return value_dict | ||||
|  | ||||
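Both `to_python` and `to_mongo` lean on the same normalization trick: a list is temporarily turned into an index-keyed dict so a single dict-shaped code path can convert either container, then a list is rebuilt at the end. The trick in isolation, as a plain-Python sketch independent of MongoEngine:

```python
def convert_container(value, convert):
    """Apply `convert` to each item of a list/tuple or dict uniformly."""
    is_list = False
    if not hasattr(value, 'items'):
        try:
            # Turn ['a', 'b'] into {0: 'a', 1: 'b'} so dict logic applies.
            is_list = True
            value = dict((k, v) for k, v in enumerate(value))
        except TypeError:  # not iterable; return the scalar untouched
            return value

    converted = dict((k, convert(v)) for k, v in value.items())

    # Restore the original container type, keeping list order by index.
    if is_list:
        return [v for _, v in sorted(converted.items())]
    return converted


print(convert_container(['1', '2'], int))  # [1, 2]
print(convert_container({'a': '1'}, int))  # {'a': 1}
print(convert_container(42, int))          # 42 (not iterable)
```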
|     def to_mongo(self, value, use_db_field=True, fields=None): | ||||
|         """Convert a Python type to a MongoDB-compatible type.""" | ||||
|         Document = _import_class('Document') | ||||
|         EmbeddedDocument = _import_class('EmbeddedDocument') | ||||
|         GenericReferenceField = _import_class('GenericReferenceField') | ||||
|         """Convert a Python type to a MongoDB-compatible type. | ||||
|         """ | ||||
|         Document = _import_class("Document") | ||||
|         EmbeddedDocument = _import_class("EmbeddedDocument") | ||||
|         GenericReferenceField = _import_class("GenericReferenceField") | ||||
|  | ||||
|         if isinstance(value, six.string_types): | ||||
|         if isinstance(value, basestring): | ||||
|             return value | ||||
|  | ||||
|         if hasattr(value, 'to_mongo'): | ||||
| @@ -360,15 +360,13 @@ class ComplexBaseField(BaseField): | ||||
|         if not hasattr(value, 'items'): | ||||
|             try: | ||||
|                 is_list = True | ||||
|                 value = {k: v for k, v in enumerate(value)} | ||||
|                 value = dict([(k, v) for k, v in enumerate(value)]) | ||||
|             except TypeError:  # Not iterable; return the value | ||||
|                 return value | ||||
|  | ||||
|         if self.field: | ||||
|             value_dict = { | ||||
|                 key: self.field._to_mongo_safe_call(item, use_db_field, fields) | ||||
|                 for key, item in value.iteritems() | ||||
|             } | ||||
|             value_dict = dict([(key, self.field._to_mongo_safe_call(item, use_db_field, fields)) | ||||
|                                for key, item in value.iteritems()]) | ||||
|         else: | ||||
|             value_dict = {} | ||||
|             for k, v in value.iteritems(): | ||||
| @@ -382,7 +380,9 @@ class ComplexBaseField(BaseField): | ||||
|                     # any _cls data, so making it a generic reference allows | ||||
|                     # us to dereference it | ||||
|                     meta = getattr(v, '_meta', {}) | ||||
|                     allow_inheritance = meta.get('allow_inheritance') | ||||
|                     allow_inheritance = ( | ||||
|                         meta.get('allow_inheritance', ALLOW_INHERITANCE) | ||||
|                         is True) | ||||
|                     if not allow_inheritance and not self.field: | ||||
|                         value_dict[k] = GenericReferenceField().to_mongo(v) | ||||
|                     else: | ||||
| @@ -404,7 +404,8 @@ class ComplexBaseField(BaseField): | ||||
|         return value_dict | ||||
|  | ||||
|     def validate(self, value): | ||||
|         """If field is provided ensure the value is valid.""" | ||||
|         """If field is provided ensure the value is valid. | ||||
|         """ | ||||
|         errors = {} | ||||
|         if self.field: | ||||
|             if hasattr(value, 'iteritems') or hasattr(value, 'items'): | ||||
| @@ -414,9 +415,9 @@ class ComplexBaseField(BaseField): | ||||
|             for k, v in sequence: | ||||
|                 try: | ||||
|                     self.field._validate(v) | ||||
|                 except ValidationError as error: | ||||
|                 except ValidationError, error: | ||||
|                     errors[k] = error.errors or error | ||||
|                 except (ValueError, AssertionError) as error: | ||||
|                 except (ValueError, AssertionError), error: | ||||
|                     errors[k] = error | ||||
|  | ||||
|             if errors: | ||||
| @@ -442,7 +443,8 @@ class ComplexBaseField(BaseField): | ||||
|  | ||||
|  | ||||
| class ObjectIdField(BaseField): | ||||
|     """A field wrapper around MongoDB's ObjectIds.""" | ||||
|     """A field wrapper around MongoDB's ObjectIds. | ||||
|     """ | ||||
|  | ||||
|     def to_python(self, value): | ||||
|         try: | ||||
| @@ -455,10 +457,10 @@ class ObjectIdField(BaseField): | ||||
|     def to_mongo(self, value): | ||||
|         if not isinstance(value, ObjectId): | ||||
|             try: | ||||
|                 return ObjectId(six.text_type(value)) | ||||
|             except Exception as e: | ||||
|                 return ObjectId(unicode(value)) | ||||
|             except Exception, e: | ||||
|                 # e.message attribute has been deprecated since Python 2.6 | ||||
|                 self.error(six.text_type(e)) | ||||
|                 self.error(unicode(e)) | ||||
|         return value | ||||
|  | ||||
|     def prepare_query_value(self, op, value): | ||||
| @@ -466,7 +468,7 @@ class ObjectIdField(BaseField): | ||||
|  | ||||
|     def validate(self, value): | ||||
|         try: | ||||
|             ObjectId(six.text_type(value)) | ||||
|             ObjectId(unicode(value)) | ||||
|         except Exception: | ||||
|             self.error('Invalid Object ID') | ||||
|  | ||||
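In other words, `ObjectIdField` accepts anything whose text form is a valid 24-character hex id and coerces it on the way to the database. The same behaviour shown with `bson` directly:

```python
from bson import ObjectId

raw = '5f43a1b2c9e77a001f000001'      # 24 hex characters: valid
oid = ObjectId(raw)                   # what to_mongo() effectively does
assert ObjectId(str(oid)) == oid      # round-trips through its text form

try:
    ObjectId('not-an-object-id')      # what validate() guards against
except Exception as exc:
    print('Invalid Object ID: %s' % exc)
```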
| @@ -478,20 +480,21 @@ class GeoJsonBaseField(BaseField): | ||||
|     """ | ||||
|  | ||||
|     _geo_index = pymongo.GEOSPHERE | ||||
|     _type = 'GeoBase' | ||||
|     _type = "GeoBase" | ||||
|  | ||||
|     def __init__(self, auto_index=True, *args, **kwargs): | ||||
|         """ | ||||
|         :param bool auto_index: Automatically create a '2dsphere' index.\ | ||||
|         :param bool auto_index: Automatically create a "2dsphere" index.\ | ||||
|             Defaults to `True`. | ||||
|         """ | ||||
|         self._name = '%sField' % self._type | ||||
|         self._name = "%sField" % self._type | ||||
|         if not auto_index: | ||||
|             self._geo_index = False | ||||
|         super(GeoJsonBaseField, self).__init__(*args, **kwargs) | ||||
|  | ||||
|     def validate(self, value): | ||||
|         """Validate the GeoJson object based on its type.""" | ||||
|         """Validate the GeoJson object based on its type | ||||
|         """ | ||||
|         if isinstance(value, dict): | ||||
|             if set(value.keys()) == set(['type', 'coordinates']): | ||||
|                 if value['type'] != self._type: | ||||
| @@ -506,7 +509,7 @@ class GeoJsonBaseField(BaseField): | ||||
|             self.error('%s can only accept lists of [x, y]' % self._name) | ||||
|             return | ||||
|  | ||||
|         validate = getattr(self, '_validate_%s' % self._type.lower()) | ||||
|         validate = getattr(self, "_validate_%s" % self._type.lower()) | ||||
|         error = validate(value) | ||||
|         if error: | ||||
|             self.error(error) | ||||
| @@ -519,7 +522,7 @@ class GeoJsonBaseField(BaseField): | ||||
|         try: | ||||
|             value[0][0][0] | ||||
|         except (TypeError, IndexError): | ||||
|             return 'Invalid Polygon must contain at least one valid linestring' | ||||
|             return "Invalid Polygon must contain at least one valid linestring" | ||||
|  | ||||
|         errors = [] | ||||
|         for val in value: | ||||
| @@ -530,12 +533,12 @@ class GeoJsonBaseField(BaseField): | ||||
|                 errors.append(error) | ||||
|         if errors: | ||||
|             if top_level: | ||||
|                 return 'Invalid Polygon:\n%s' % ', '.join(errors) | ||||
|                 return "Invalid Polygon:\n%s" % ", ".join(errors) | ||||
|             else: | ||||
|                 return '%s' % ', '.join(errors) | ||||
|                 return "%s" % ", ".join(errors) | ||||
|  | ||||
|     def _validate_linestring(self, value, top_level=True): | ||||
|         """Validate a linestring.""" | ||||
|         """Validates a linestring""" | ||||
|         if not isinstance(value, (list, tuple)): | ||||
|             return 'LineStrings must contain list of coordinate pairs' | ||||
|  | ||||
| @@ -543,7 +546,7 @@ class GeoJsonBaseField(BaseField): | ||||
|         try: | ||||
|             value[0][0] | ||||
|         except (TypeError, IndexError): | ||||
|             return 'Invalid LineString must contain at least one valid point' | ||||
|             return "Invalid LineString must contain at least one valid point" | ||||
|  | ||||
|         errors = [] | ||||
|         for val in value: | ||||
| @@ -552,19 +555,19 @@ class GeoJsonBaseField(BaseField): | ||||
|                 errors.append(error) | ||||
|         if errors: | ||||
|             if top_level: | ||||
|                 return 'Invalid LineString:\n%s' % ', '.join(errors) | ||||
|                 return "Invalid LineString:\n%s" % ", ".join(errors) | ||||
|             else: | ||||
|                 return '%s' % ', '.join(errors) | ||||
|                 return "%s" % ", ".join(errors) | ||||
|  | ||||
|     def _validate_point(self, value): | ||||
|         """Validate each set of coords""" | ||||
|         if not isinstance(value, (list, tuple)): | ||||
|             return 'Points must be a list of coordinate pairs' | ||||
|         elif not len(value) == 2: | ||||
|             return 'Value (%s) must be a two-dimensional point' % repr(value) | ||||
|             return "Value (%s) must be a two-dimensional point" % repr(value) | ||||
|         elif (not isinstance(value[0], (float, int)) or | ||||
|               not isinstance(value[1], (float, int))): | ||||
|             return 'Both values (%s) in point must be float or int' % repr(value) | ||||
|             return "Both values (%s) in point must be float or int" % repr(value) | ||||
|  | ||||
|     def _validate_multipoint(self, value): | ||||
|         if not isinstance(value, (list, tuple)): | ||||
| @@ -574,7 +577,7 @@ class GeoJsonBaseField(BaseField): | ||||
|         try: | ||||
|             value[0][0] | ||||
|         except (TypeError, IndexError): | ||||
|             return 'Invalid MultiPoint must contain at least one valid point' | ||||
|             return "Invalid MultiPoint must contain at least one valid point" | ||||
|  | ||||
|         errors = [] | ||||
|         for point in value: | ||||
| @@ -583,7 +586,7 @@ class GeoJsonBaseField(BaseField): | ||||
|                 errors.append(error) | ||||
|  | ||||
|         if errors: | ||||
|             return '%s' % ', '.join(errors) | ||||
|             return "%s" % ", ".join(errors) | ||||
|  | ||||
|     def _validate_multilinestring(self, value, top_level=True): | ||||
|         if not isinstance(value, (list, tuple)): | ||||
| @@ -593,7 +596,7 @@ class GeoJsonBaseField(BaseField): | ||||
|         try: | ||||
|             value[0][0][0] | ||||
|         except (TypeError, IndexError): | ||||
|             return 'Invalid MultiLineString must contain at least one valid linestring' | ||||
|             return "Invalid MultiLineString must contain at least one valid linestring" | ||||
|  | ||||
|         errors = [] | ||||
|         for linestring in value: | ||||
| @@ -603,9 +606,9 @@ class GeoJsonBaseField(BaseField): | ||||
|  | ||||
|         if errors: | ||||
|             if top_level: | ||||
|                 return 'Invalid MultiLineString:\n%s' % ', '.join(errors) | ||||
|                 return "Invalid MultiLineString:\n%s" % ", ".join(errors) | ||||
|             else: | ||||
|                 return '%s' % ', '.join(errors) | ||||
|                 return "%s" % ", ".join(errors) | ||||
|  | ||||
|     def _validate_multipolygon(self, value): | ||||
|         if not isinstance(value, (list, tuple)): | ||||
| @@ -615,7 +618,7 @@ class GeoJsonBaseField(BaseField): | ||||
|         try: | ||||
|             value[0][0][0][0] | ||||
|         except (TypeError, IndexError): | ||||
|             return 'Invalid MultiPolygon must contain at least one valid Polygon' | ||||
|             return "Invalid MultiPolygon must contain at least one valid Polygon" | ||||
|  | ||||
|         errors = [] | ||||
|         for polygon in value: | ||||
| @@ -624,9 +627,9 @@ class GeoJsonBaseField(BaseField): | ||||
|                 errors.append(error) | ||||
|  | ||||
|         if errors: | ||||
|             return 'Invalid MultiPolygon:\n%s' % ', '.join(errors) | ||||
|             return "Invalid MultiPolygon:\n%s" % ", ".join(errors) | ||||
|  | ||||
|     def to_mongo(self, value): | ||||
|         if isinstance(value, dict): | ||||
|             return value | ||||
|         return SON([('type', self._type), ('coordinates', value)]) | ||||
|         return SON([("type", self._type), ("coordinates", value)]) | ||||
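The validators above accept either a bare coordinate structure or a full GeoJSON mapping with `type` and `coordinates` keys, and `to_mongo` wraps bare coordinates into a `SON` of that shape. A sketch using the concrete geo fields built on this base (model names hypothetical; validation runs without a server connection):

```python
from mongoengine import Document, PointField, PolygonField


class Venue(Document):
    location = PointField()     # [x, y], i.e. [longitude, latitude]
    footprint = PolygonField()  # a list of linear rings of [x, y] pairs


venue = Venue(
    location=[-0.1278, 51.5074],
    footprint=[[[0, 0], [0, 1], [1, 1], [0, 0]]],
)
venue.validate()  # both structures satisfy the checks above

print(venue.to_mongo()['location'])
# SON([('type', 'Point'), ('coordinates', [-0.1278, 51.5074])])
```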
|   | ||||
| @@ -1,11 +1,10 @@ | ||||
| import warnings | ||||
|  | ||||
| import six | ||||
|  | ||||
| from mongoengine.base.common import _document_registry | ||||
| from mongoengine.base.common import ALLOW_INHERITANCE, _document_registry | ||||
| from mongoengine.base.fields import BaseField, ComplexBaseField, ObjectIdField | ||||
| from mongoengine.common import _import_class | ||||
| from mongoengine.errors import InvalidDocumentError | ||||
| from mongoengine.python_support import PY3 | ||||
| from mongoengine.queryset import (DO_NOTHING, DoesNotExist, | ||||
|                                   MultipleObjectsReturned, | ||||
|                                   QuerySetManager) | ||||
| @@ -46,8 +45,7 @@ class DocumentMetaclass(type): | ||||
|             attrs['_meta'] = meta | ||||
|             attrs['_meta']['abstract'] = False  # 789: EmbeddedDocument shouldn't inherit abstract | ||||
|  | ||||
|         # If allow_inheritance is True, add a "_cls" string field to the attrs | ||||
|         if attrs['_meta'].get('allow_inheritance'): | ||||
|         if attrs['_meta'].get('allow_inheritance', ALLOW_INHERITANCE): | ||||
|             StringField = _import_class('StringField') | ||||
|             attrs['_cls'] = StringField() | ||||
|  | ||||
| @@ -89,17 +87,16 @@ class DocumentMetaclass(type): | ||||
|         # Ensure no duplicate db_fields | ||||
|         duplicate_db_fields = [k for k, v in field_names.items() if v > 1] | ||||
|         if duplicate_db_fields: | ||||
|             msg = ('Multiple db_fields defined for: %s ' % | ||||
|                    ', '.join(duplicate_db_fields)) | ||||
|             msg = ("Multiple db_fields defined for: %s " % | ||||
|                    ", ".join(duplicate_db_fields)) | ||||
|             raise InvalidDocumentError(msg) | ||||
|  | ||||
|         # Set _fields and db_field maps | ||||
|         attrs['_fields'] = doc_fields | ||||
|         attrs['_db_field_map'] = {k: getattr(v, 'db_field', k) | ||||
|                                   for k, v in doc_fields.items()} | ||||
|         attrs['_reverse_db_field_map'] = { | ||||
|             v: k for k, v in attrs['_db_field_map'].items() | ||||
|         } | ||||
|         attrs['_db_field_map'] = dict([(k, getattr(v, 'db_field', k)) | ||||
|                                        for k, v in doc_fields.iteritems()]) | ||||
|         attrs['_reverse_db_field_map'] = dict( | ||||
|             (v, k) for k, v in attrs['_db_field_map'].iteritems()) | ||||
|  | ||||
|         attrs['_fields_ordered'] = tuple(i[1] for i in sorted( | ||||
|                                          (v.creation_counter, v.name) | ||||
| @@ -119,8 +116,10 @@ class DocumentMetaclass(type): | ||||
|             if hasattr(base, '_meta'): | ||||
|                 # Warn if allow_inheritance isn't set and prevent | ||||
|                 # inheritance of classes where inheritance is set to False | ||||
|                 allow_inheritance = base._meta.get('allow_inheritance') | ||||
|                 if not allow_inheritance and not base._meta.get('abstract'): | ||||
|                 allow_inheritance = base._meta.get('allow_inheritance', | ||||
|                                                    ALLOW_INHERITANCE) | ||||
|                 if (allow_inheritance is not True and | ||||
|                         not base._meta.get('abstract')): | ||||
|                     raise ValueError('Document %s may not be subclassed' % | ||||
|                                      base.__name__) | ||||
|  | ||||
| @@ -162,7 +161,7 @@ class DocumentMetaclass(type): | ||||
|         # module continues to use im_func and im_self, so the code below | ||||
|         # copies __func__ into im_func and __self__ into im_self for | ||||
|         # classmethod objects in Document derived classes. | ||||
|         if six.PY3: | ||||
|         if PY3: | ||||
|             for val in new_class.__dict__.values(): | ||||
|                 if isinstance(val, classmethod): | ||||
|                     f = val.__get__(new_class) | ||||
| @@ -180,11 +179,11 @@ class DocumentMetaclass(type): | ||||
|             if isinstance(f, CachedReferenceField): | ||||
|  | ||||
|                 if issubclass(new_class, EmbeddedDocument): | ||||
|                     raise InvalidDocumentError('CachedReferenceFields is not ' | ||||
|                                                'allowed in EmbeddedDocuments') | ||||
|                     raise InvalidDocumentError( | ||||
|                         "CachedReferenceFields is not allowed in EmbeddedDocuments") | ||||
|                 if not f.document_type: | ||||
|                     raise InvalidDocumentError( | ||||
|                         'Document is not available to sync') | ||||
|                         "Document is not available to sync") | ||||
|  | ||||
|                 if f.auto_sync: | ||||
|                     f.start_listener() | ||||
| @@ -196,8 +195,8 @@ class DocumentMetaclass(type): | ||||
|                                       'reverse_delete_rule', | ||||
|                                       DO_NOTHING) | ||||
|                 if isinstance(f, DictField) and delete_rule != DO_NOTHING: | ||||
|                     msg = ('Reverse delete rules are not supported ' | ||||
|                            'for %s (field: %s)' % | ||||
|                     msg = ("Reverse delete rules are not supported " | ||||
|                            "for %s (field: %s)" % | ||||
|                            (field.__class__.__name__, field.name)) | ||||
|                     raise InvalidDocumentError(msg) | ||||
|  | ||||
| @@ -205,16 +204,16 @@ class DocumentMetaclass(type): | ||||
|  | ||||
|             if delete_rule != DO_NOTHING: | ||||
|                 if issubclass(new_class, EmbeddedDocument): | ||||
|                     msg = ('Reverse delete rules are not supported for ' | ||||
|                            'EmbeddedDocuments (field: %s)' % field.name) | ||||
|                     msg = ("Reverse delete rules are not supported for " | ||||
|                            "EmbeddedDocuments (field: %s)" % field.name) | ||||
|                     raise InvalidDocumentError(msg) | ||||
|                 f.document_type.register_delete_rule(new_class, | ||||
|                                                      field.name, delete_rule) | ||||
|  | ||||
|             if (field.name and hasattr(Document, field.name) and | ||||
|                     EmbeddedDocument not in new_class.mro()): | ||||
|                 msg = ('%s is a document method and not a valid ' | ||||
|                        'field name' % field.name) | ||||
|                 msg = ("%s is a document method and not a valid " | ||||
|                        "field name" % field.name) | ||||
|                 raise InvalidDocumentError(msg) | ||||
|  | ||||
|         return new_class | ||||
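The `register_delete_rule` wiring above is what powers `reverse_delete_rule` on reference fields: at class-creation time the referenced document records what should happen to referring documents when it is deleted. A sketch (models hypothetical):

```python
from mongoengine import CASCADE, Document, ReferenceField, StringField


class Author(Document):
    name = StringField()


class Book(Document):
    title = StringField()
    # Registers a delete rule on Author at class creation: deleting an
    # Author also deletes every Book that references it.
    author = ReferenceField(Author, reverse_delete_rule=CASCADE)
```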
| @@ -272,11 +271,6 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): | ||||
|                 'index_drop_dups': False, | ||||
|                 'index_opts': None, | ||||
|                 'delete_rules': None, | ||||
|  | ||||
|                 # allow_inheritance can be True, False, and None. True means | ||||
|                 # "allow inheritance", False means "don't allow inheritance", | ||||
|                 # None means "do whatever your parent does, or don't allow | ||||
|                 # inheritance if you're a top-level class". | ||||
|                 'allow_inheritance': None, | ||||
|             } | ||||
|             attrs['_is_base_cls'] = True | ||||
| @@ -309,7 +303,7 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): | ||||
|         # If parent wasn't an abstract class | ||||
|         if (parent_doc_cls and 'collection' in attrs.get('_meta', {}) and | ||||
|                 not parent_doc_cls._meta.get('abstract', True)): | ||||
|             msg = 'Trying to set a collection on a subclass (%s)' % name | ||||
|             msg = "Trying to set a collection on a subclass (%s)" % name | ||||
|             warnings.warn(msg, SyntaxWarning) | ||||
|             del attrs['_meta']['collection'] | ||||
|  | ||||
| @@ -317,7 +311,7 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): | ||||
|         if attrs.get('_is_base_cls') or attrs['_meta'].get('abstract'): | ||||
|             if (parent_doc_cls and | ||||
|                     not parent_doc_cls._meta.get('abstract', False)): | ||||
|                 msg = 'Abstract document cannot have non-abstract base' | ||||
|                 msg = "Abstract document cannot have non-abstract base" | ||||
|                 raise ValueError(msg) | ||||
|             return super_new(cls, name, bases, attrs) | ||||
|  | ||||
| @@ -340,16 +334,12 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): | ||||
|  | ||||
|         meta.merge(attrs.get('_meta', {}))  # Top level meta | ||||
|  | ||||
|         # Only simple classes (i.e. direct subclasses of Document) may set | ||||
|         # allow_inheritance to False. If the base Document allows inheritance, | ||||
|         # none of its subclasses can override allow_inheritance to False. | ||||
|         # Only simple classes (direct subclasses of Document) | ||||
|         # may set allow_inheritance to False | ||||
|         simple_class = all([b._meta.get('abstract') | ||||
|                             for b in flattened_bases if hasattr(b, '_meta')]) | ||||
|         if ( | ||||
|             not simple_class and | ||||
|             meta['allow_inheritance'] is False and | ||||
|             not meta['abstract'] | ||||
|         ): | ||||
|         if (not simple_class and meta['allow_inheritance'] is False and | ||||
|                 not meta['abstract']): | ||||
|             raise ValueError('Only direct subclasses of Document may set ' | ||||
|                              '"allow_inheritance" to False') | ||||
|  | ||||
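In practice these metaclass checks mean inheritance is opted into on the base class, and only direct (or abstract) subclasses of `Document` may turn it off. For example:

```python
from mongoengine import Document, StringField


class Page(Document):
    title = StringField()
    meta = {'allow_inheritance': True}  # adds the hidden _cls field


class BlogPage(Page):  # fine: Page explicitly allows inheritance
    body = StringField()


class Note(Document):  # allow_inheritance not enabled here
    text = StringField()

# class StickyNote(Note): ...  # would raise ValueError:
#                              # 'Document Note may not be subclassed'
```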
|   | ||||
| @@ -34,10 +34,7 @@ def _import_class(cls_name): | ||||
|     queryset_classes = ('OperationError',) | ||||
|     deref_classes = ('DeReference',) | ||||
|  | ||||
|     if cls_name == 'BaseDocument': | ||||
|         from mongoengine.base import document as module | ||||
|         import_classes = ['BaseDocument'] | ||||
|     elif cls_name in doc_classes: | ||||
|     if cls_name in doc_classes: | ||||
|         from mongoengine import document as module | ||||
|         import_classes = doc_classes | ||||
|     elif cls_name in field_classes: | ||||
|   | ||||
| @@ -1,9 +1,7 @@ | ||||
| from pymongo import MongoClient, ReadPreference, uri_parser | ||||
| import six | ||||
| from mongoengine.python_support import (IS_PYMONGO_3, str_types) | ||||
|  | ||||
| from mongoengine.python_support import IS_PYMONGO_3 | ||||
|  | ||||
| __all__ = ['MongoEngineConnectionError', 'connect', 'register_connection', | ||||
| __all__ = ['ConnectionError', 'connect', 'register_connection', | ||||
|            'DEFAULT_CONNECTION_NAME'] | ||||
|  | ||||
|  | ||||
| @@ -16,10 +14,7 @@ else: | ||||
|     READ_PREFERENCE = False | ||||
|  | ||||
|  | ||||
| class MongoEngineConnectionError(Exception): | ||||
|     """Error raised when the database connection can't be established or | ||||
|     when a connection with a requested alias can't be retrieved. | ||||
|     """ | ||||
| class ConnectionError(Exception): | ||||
|     pass | ||||
|  | ||||
|  | ||||
| @@ -30,8 +25,7 @@ _dbs = {} | ||||
|  | ||||
| def register_connection(alias, name=None, host=None, port=None, | ||||
|                         read_preference=READ_PREFERENCE, | ||||
|                         username=None, password=None, | ||||
|                         authentication_source=None, | ||||
|                         username=None, password=None, authentication_source=None, | ||||
|                         authentication_mechanism=None, | ||||
|                         **kwargs): | ||||
|     """Add a connection. | ||||
| @@ -51,12 +45,12 @@ def register_connection(alias, name=None, host=None, port=None, | ||||
|         MONGODB-CR (MongoDB Challenge Response protocol) for older servers. | ||||
|     :param is_mock: explicitly use mongomock for this connection | ||||
|         (can also be done by using `mongomock://` as db host prefix) | ||||
|     :param kwargs: ad-hoc parameters to be passed into the pymongo driver, | ||||
|         for example maxpoolsize, tz_aware, etc. See the documentation | ||||
|         for pymongo's `MongoClient` for a full list. | ||||
|     :param kwargs: allow ad-hoc parameters to be passed into the pymongo driver | ||||
|  | ||||
|     .. versionchanged:: 0.10.6 - added mongomock support | ||||
|     """ | ||||
|     global _connection_settings | ||||
|  | ||||
|     conn_settings = { | ||||
|         'name': name or 'test', | ||||
|         'host': host or 'localhost', | ||||
| @@ -68,37 +62,31 @@ def register_connection(alias, name=None, host=None, port=None, | ||||
|         'authentication_mechanism': authentication_mechanism | ||||
|     } | ||||
|  | ||||
|     # Handle uri style connections | ||||
|     conn_host = conn_settings['host'] | ||||
|  | ||||
|     # Host can be a list or a string, so if string, force to a list. | ||||
|     if isinstance(conn_host, six.string_types): | ||||
|     # host can be a list or a string, so if string, force to a list | ||||
|     if isinstance(conn_host, str_types): | ||||
|         conn_host = [conn_host] | ||||
|  | ||||
|     resolved_hosts = [] | ||||
|     for entity in conn_host: | ||||
|  | ||||
|         # Handle Mongomock | ||||
|         # Handle uri style connections | ||||
|         if entity.startswith('mongomock://'): | ||||
|             conn_settings['is_mock'] = True | ||||
|             # `mongomock://` is not a valid url prefix and must be replaced by `mongodb://` | ||||
|             resolved_hosts.append(entity.replace('mongomock://', 'mongodb://', 1)) | ||||
|  | ||||
|         # Handle URI style connections, only updating connection params which | ||||
|         # were explicitly specified in the URI. | ||||
|         elif '://' in entity: | ||||
|             uri_dict = uri_parser.parse_uri(entity) | ||||
|             resolved_hosts.append(entity) | ||||
|  | ||||
|             if uri_dict.get('database'): | ||||
|                 conn_settings['name'] = uri_dict.get('database') | ||||
|  | ||||
|             for param in ('read_preference', 'username', 'password'): | ||||
|                 if uri_dict.get(param): | ||||
|                     conn_settings[param] = uri_dict[param] | ||||
|  | ||||
|             conn_settings.update({ | ||||
|                 'name': uri_dict.get('database') or name, | ||||
|                 'username': uri_dict.get('username'), | ||||
|                 'password': uri_dict.get('password'), | ||||
|                 'read_preference': read_preference, | ||||
|             }) | ||||
|             uri_options = uri_dict['options'] | ||||
|             if 'replicaset' in uri_options: | ||||
|                 conn_settings['replicaSet'] = uri_options['replicaset'] | ||||
|                 conn_settings['replicaSet'] = True | ||||
|             if 'authsource' in uri_options: | ||||
|                 conn_settings['authentication_source'] = uri_options['authsource'] | ||||
|             if 'authmechanism' in uri_options: | ||||
| @@ -116,7 +104,9 @@ def register_connection(alias, name=None, host=None, port=None, | ||||
|  | ||||
|  | ||||
| def disconnect(alias=DEFAULT_CONNECTION_NAME): | ||||
|     """Close the connection with a given alias.""" | ||||
|     global _connections | ||||
|     global _dbs | ||||
|  | ||||
|     if alias in _connections: | ||||
|         get_connection(alias=alias).close() | ||||
|         del _connections[alias] | ||||
| @@ -125,99 +115,71 @@ def disconnect(alias=DEFAULT_CONNECTION_NAME): | ||||
|  | ||||
|  | ||||
| def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): | ||||
|     """Return a connection with a given alias.""" | ||||
|  | ||||
|     global _connections | ||||
|     # Connect to the database if not already connected | ||||
|     if reconnect: | ||||
|         disconnect(alias) | ||||
|  | ||||
|     # If the requested alias already exists in the _connections list, return | ||||
|     # it immediately. | ||||
|     if alias in _connections: | ||||
|         return _connections[alias] | ||||
|  | ||||
|     # Validate that the requested alias exists in the _connection_settings. | ||||
|     # Raise MongoEngineConnectionError if it doesn't. | ||||
|     if alias not in _connection_settings: | ||||
|         if alias == DEFAULT_CONNECTION_NAME: | ||||
|             msg = 'You have not defined a default connection' | ||||
|         else: | ||||
|     if alias not in _connections: | ||||
|         if alias not in _connection_settings: | ||||
|             msg = 'Connection with alias "%s" has not been defined' % alias | ||||
|         raise MongoEngineConnectionError(msg) | ||||
|             if alias == DEFAULT_CONNECTION_NAME: | ||||
|                 msg = 'You have not defined a default connection' | ||||
|             raise ConnectionError(msg) | ||||
|         conn_settings = _connection_settings[alias].copy() | ||||
|  | ||||
|     def _clean_settings(settings_dict): | ||||
|         irrelevant_fields = set([ | ||||
|             'name', 'username', 'password', 'authentication_source', | ||||
|             'authentication_mechanism' | ||||
|         ]) | ||||
|         return { | ||||
|             k: v for k, v in settings_dict.items() | ||||
|             if k not in irrelevant_fields | ||||
|         } | ||||
|         conn_settings.pop('name', None) | ||||
|         conn_settings.pop('username', None) | ||||
|         conn_settings.pop('password', None) | ||||
|         conn_settings.pop('authentication_source', None) | ||||
|         conn_settings.pop('authentication_mechanism', None) | ||||
|  | ||||
|     # Retrieve a copy of the connection settings associated with the requested | ||||
|     # alias and remove the database name and authentication info (we don't | ||||
|     # care about them at this point). | ||||
|     conn_settings = _clean_settings(_connection_settings[alias].copy()) | ||||
|  | ||||
|     # Determine if we should use PyMongo's or mongomock's MongoClient. | ||||
|     is_mock = conn_settings.pop('is_mock', False) | ||||
|     if is_mock: | ||||
|         try: | ||||
|             import mongomock | ||||
|         except ImportError: | ||||
|             raise RuntimeError('You need mongomock installed to mock ' | ||||
|                                'MongoEngine.') | ||||
|         connection_class = mongomock.MongoClient | ||||
|     else: | ||||
|         connection_class = MongoClient | ||||
|  | ||||
|         # For replica set connections with PyMongo 2.x, use | ||||
|         # MongoReplicaSetClient. | ||||
|         # TODO remove this once we stop supporting PyMongo 2.x. | ||||
|         if 'replicaSet' in conn_settings and not IS_PYMONGO_3: | ||||
|             connection_class = MongoReplicaSetClient | ||||
|             conn_settings['hosts_or_uri'] = conn_settings.pop('host', None) | ||||
|  | ||||
|             # hosts_or_uri has to be a string, so if 'host' was provided | ||||
|             # as a list, join its parts and separate them by ',' | ||||
|             if isinstance(conn_settings['hosts_or_uri'], list): | ||||
|                 conn_settings['hosts_or_uri'] = ','.join( | ||||
|                     conn_settings['hosts_or_uri']) | ||||
|         is_mock = conn_settings.pop('is_mock', None) | ||||
|         if is_mock: | ||||
|             # Use MongoClient from mongomock | ||||
|             try: | ||||
|                 import mongomock | ||||
|             except ImportError: | ||||
|                 raise RuntimeError('You need mongomock installed ' | ||||
|                                    'to mock MongoEngine.') | ||||
|             connection_class = mongomock.MongoClient | ||||
|         else: | ||||
|             # Use MongoClient from pymongo | ||||
|             connection_class = MongoClient | ||||
|  | ||||
|         if 'replicaSet' in conn_settings: | ||||
|             # Discard port since it can't be used on MongoReplicaSetClient | ||||
|             conn_settings.pop('port', None) | ||||
|             # Discard replicaSet if it is not a string | ||||
|             if not isinstance(conn_settings['replicaSet'], basestring): | ||||
|                 conn_settings.pop('replicaSet', None) | ||||
|             if not IS_PYMONGO_3: | ||||
|                 connection_class = MongoReplicaSetClient | ||||
|                 conn_settings['hosts_or_uri'] = conn_settings.pop('host', None) | ||||
|  | ||||
|     # Iterate over all of the connection settings and if a connection with | ||||
|     # the same parameters is already established, use it instead of creating | ||||
|     # a new one. | ||||
|     existing_connection = None | ||||
|     connection_settings_iterator = ( | ||||
|         (db_alias, settings.copy()) | ||||
|         for db_alias, settings in _connection_settings.items() | ||||
|     ) | ||||
|     for db_alias, connection_settings in connection_settings_iterator: | ||||
|         connection_settings = _clean_settings(connection_settings) | ||||
|         if conn_settings == connection_settings and _connections.get(db_alias): | ||||
|             existing_connection = _connections[db_alias] | ||||
|             break | ||||
|  | ||||
|     # If an existing connection was found, assign it to the new alias | ||||
|     if existing_connection: | ||||
|         _connections[alias] = existing_connection | ||||
|     else: | ||||
|         # Otherwise, create the new connection for this alias. Raise | ||||
|         # MongoEngineConnectionError if it can't be established. | ||||
|         try: | ||||
|             _connections[alias] = connection_class(**conn_settings) | ||||
|         except Exception as e: | ||||
|             raise MongoEngineConnectionError( | ||||
|                 'Cannot connect to database %s :\n%s' % (alias, e)) | ||||
|             connection = None | ||||
|             # check for shared connections | ||||
|             connection_settings_iterator = ( | ||||
|                 (db_alias, settings.copy()) for db_alias, settings in _connection_settings.iteritems()) | ||||
|             for db_alias, connection_settings in connection_settings_iterator: | ||||
|                 connection_settings.pop('name', None) | ||||
|                 connection_settings.pop('username', None) | ||||
|                 connection_settings.pop('password', None) | ||||
|                 connection_settings.pop('authentication_source', None) | ||||
|                 connection_settings.pop('authentication_mechanism', None) | ||||
|                 if conn_settings == connection_settings and _connections.get(db_alias, None): | ||||
|                     connection = _connections[db_alias] | ||||
|                     break | ||||
|  | ||||
|             _connections[alias] = connection if connection else connection_class(**conn_settings) | ||||
|         except Exception, e: | ||||
|             raise ConnectionError("Cannot connect to database %s :\n%s" % (alias, e)) | ||||
|     return _connections[alias] | ||||
|  | ||||
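One consequence of the sharing logic above, in either formulation: two aliases whose settings differ only in database name or credentials end up reusing the same underlying client. A sketch, assuming a local `mongod` (the alias and database names are arbitrary):

```python
from mongoengine.connection import get_connection, register_connection

# Same host and port, different database names.
register_connection('default', name='app_db', host='localhost', port=27017)
register_connection('analytics', name='stats_db', host='localhost', port=27017)

conn_a = get_connection('default')
conn_b = get_connection('analytics')
assert conn_a is conn_b  # one MongoClient shared by both aliases
```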
|  | ||||
| def get_db(alias=DEFAULT_CONNECTION_NAME, reconnect=False): | ||||
|     global _dbs | ||||
|     if reconnect: | ||||
|         disconnect(alias) | ||||
|  | ||||
| @@ -243,14 +205,12 @@ def connect(db=None, alias=DEFAULT_CONNECTION_NAME, **kwargs): | ||||
|     running on the default port on localhost. If authentication is needed, | ||||
|     provide username and password arguments as well. | ||||
|  | ||||
|     Multiple databases are supported by using aliases. Provide a separate | ||||
|     Multiple databases are supported by using aliases.  Provide a separate | ||||
|     `alias` to connect to a different instance of :program:`mongod`. | ||||
|  | ||||
|     See the docstring for `register_connection` for more details about all | ||||
|     supported kwargs. | ||||
|  | ||||
|     .. versionchanged:: 0.6 - added multiple database support. | ||||
|     """ | ||||
|     global _connections | ||||
|     if alias not in _connections: | ||||
|         register_connection(alias, db, **kwargs) | ||||
|  | ||||
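Tying `connect` to the host handling earlier in this file, a short usage sketch (database and alias names arbitrary; the mocked variant needs the `mongomock` package installed):

```python
from mongoengine import connect

# Default alias: a real database on localhost.
connect('app_db')

# A second, independently aliased database.
connect('reporting_db', alias='reporting')

# An in-memory mock, e.g. for tests: the mongomock:// prefix is
# rewritten to mongodb:// and mongomock.MongoClient is used instead.
connect('test_db', alias='testing', host='mongomock://localhost')
```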
|   | ||||
| @@ -2,12 +2,12 @@ from mongoengine.common import _import_class | ||||
| from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db | ||||
|  | ||||
|  | ||||
| __all__ = ('switch_db', 'switch_collection', 'no_dereference', | ||||
|            'no_sub_classes', 'query_counter') | ||||
| __all__ = ("switch_db", "switch_collection", "no_dereference", | ||||
|            "no_sub_classes", "query_counter") | ||||
|  | ||||
|  | ||||
| class switch_db(object): | ||||
|     """switch_db alias context manager. | ||||
|     """ switch_db alias context manager. | ||||
|  | ||||
|     Example :: | ||||
|  | ||||
| @@ -18,14 +18,15 @@ class switch_db(object): | ||||
|         class Group(Document): | ||||
|             name = StringField() | ||||
|  | ||||
|         Group(name='test').save()  # Saves in the default db | ||||
|         Group(name="test").save()  # Saves in the default db | ||||
|  | ||||
|         with switch_db(Group, 'testdb-1') as Group: | ||||
|             Group(name='hello testdb!').save()  # Saves in testdb-1 | ||||
|             Group(name="hello testdb!").save()  # Saves in testdb-1 | ||||
|  | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, cls, db_alias): | ||||
|         """Construct the switch_db context manager | ||||
|         """ Construct the switch_db context manager | ||||
|  | ||||
|         :param cls: the class to change the registered db | ||||
|         :param db_alias: the name of the specific database to use | ||||
| @@ -33,36 +34,37 @@ class switch_db(object): | ||||
|         self.cls = cls | ||||
|         self.collection = cls._get_collection() | ||||
|         self.db_alias = db_alias | ||||
|         self.ori_db_alias = cls._meta.get('db_alias', DEFAULT_CONNECTION_NAME) | ||||
|         self.ori_db_alias = cls._meta.get("db_alias", DEFAULT_CONNECTION_NAME) | ||||
|  | ||||
|     def __enter__(self): | ||||
|         """Change the db_alias and clear the cached collection.""" | ||||
|         self.cls._meta['db_alias'] = self.db_alias | ||||
|         """ change the db_alias and clear the cached collection """ | ||||
|         self.cls._meta["db_alias"] = self.db_alias | ||||
|         self.cls._collection = None | ||||
|         return self.cls | ||||
|  | ||||
|     def __exit__(self, t, value, traceback): | ||||
|         """Reset the db_alias and collection.""" | ||||
|         self.cls._meta['db_alias'] = self.ori_db_alias | ||||
|         """ Reset the db_alias and collection """ | ||||
|         self.cls._meta["db_alias"] = self.ori_db_alias | ||||
|         self.cls._collection = self.collection | ||||
|  | ||||
|  | ||||
| class switch_collection(object): | ||||
|     """switch_collection alias context manager. | ||||
|     """ switch_collection alias context manager. | ||||
|  | ||||
|     Example :: | ||||
|  | ||||
|         class Group(Document): | ||||
|             name = StringField() | ||||
|  | ||||
|         Group(name='test').save()  # Saves in the default db | ||||
|         Group(name="test").save()  # Saves in the default db | ||||
|  | ||||
|         with switch_collection(Group, 'group1') as Group: | ||||
|             Group(name='hello testdb!').save()  # Saves in group1 collection | ||||
|             Group(name="hello testdb!").save()  # Saves in group1 collection | ||||
|  | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, cls, collection_name): | ||||
|         """Construct the switch_collection context manager. | ||||
|         """ Construct the switch_collection context manager | ||||
|  | ||||
|         :param cls: the class to change the registered db | ||||
|         :param collection_name: the name of the collection to use | ||||
| @@ -73,7 +75,7 @@ class switch_collection(object): | ||||
|         self.collection_name = collection_name | ||||
|  | ||||
|     def __enter__(self): | ||||
|         """Change the _get_collection_name and clear the cached collection.""" | ||||
|         """ change the _get_collection_name and clear the cached collection """ | ||||
|  | ||||
|         @classmethod | ||||
|         def _get_collection_name(cls): | ||||
| @@ -84,23 +86,24 @@ class switch_collection(object): | ||||
|         return self.cls | ||||
|  | ||||
|     def __exit__(self, t, value, traceback): | ||||
|         """Reset the collection.""" | ||||
|         """ Reset the collection """ | ||||
|         self.cls._collection = self.ori_collection | ||||
|         self.cls._get_collection_name = self.ori_get_collection_name | ||||
|  | ||||
|  | ||||
| class no_dereference(object): | ||||
|     """no_dereference context manager. | ||||
|     """ no_dereference context manager. | ||||
|  | ||||
|     Turns off all dereferencing in Documents for the duration of the context | ||||
|     manager:: | ||||
|  | ||||
|         with no_dereference(Group) as Group: | ||||
|             Group.objects.find() | ||||
|  | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, cls): | ||||
|         """Construct the no_dereference context manager. | ||||
|         """ Construct the no_dereference context manager. | ||||
|  | ||||
|         :param cls: the class to turn dereferencing off for | ||||
|         """ | ||||
| @@ -116,102 +119,103 @@ class no_dereference(object): | ||||
|                                                ComplexBaseField))] | ||||
|  | ||||
|     def __enter__(self): | ||||
|         """Change the objects default and _auto_dereference values.""" | ||||
|         """ change the objects default and _auto_dereference values""" | ||||
|         for field in self.deref_fields: | ||||
|             self.cls._fields[field]._auto_dereference = False | ||||
|         return self.cls | ||||
|  | ||||
|     def __exit__(self, t, value, traceback): | ||||
|         """Reset the default and _auto_dereference values.""" | ||||
|         """ Reset the default and _auto_dereference values""" | ||||
|         for field in self.deref_fields: | ||||
|             self.cls._fields[field]._auto_dereference = True | ||||
|         return self.cls | ||||
|  | ||||
|  | ||||
| class no_sub_classes(object): | ||||
|     """no_sub_classes context manager. | ||||
|     """ no_sub_classes context manager. | ||||
|  | ||||
|     Only returns instances of this class and no sub (inherited) classes:: | ||||
|  | ||||
|         with no_sub_classes(Group) as Group: | ||||
|             Group.objects.find() | ||||
|  | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, cls): | ||||
|         """Construct the no_sub_classes context manager. | ||||
|         """ Construct the no_sub_classes context manager. | ||||
|  | ||||
|         :param cls: the class to turn off subclass querying for | ||||
|         """ | ||||
|         self.cls = cls | ||||
|  | ||||
|     def __enter__(self): | ||||
|         """Change the objects default and _auto_dereference values.""" | ||||
|         """ change the objects default and _auto_dereference values""" | ||||
|         self.cls._all_subclasses = self.cls._subclasses | ||||
|         self.cls._subclasses = (self.cls,) | ||||
|         return self.cls | ||||
|  | ||||
|     def __exit__(self, t, value, traceback): | ||||
|         """Reset the default and _auto_dereference values.""" | ||||
|         """ Reset the default and _auto_dereference values""" | ||||
|         self.cls._subclasses = self.cls._all_subclasses | ||||
|         delattr(self.cls, '_all_subclasses') | ||||
|         return self.cls | ||||
|  | ||||
|  | ||||
| class query_counter(object): | ||||
|     """Query_counter context manager to get the number of queries.""" | ||||
|     """ Query_counter context manager to get the number of queries. """ | ||||
|  | ||||
|     def __init__(self): | ||||
|         """Construct the query_counter.""" | ||||
|         """ Construct the query_counter. """ | ||||
|         self.counter = 0 | ||||
|         self.db = get_db() | ||||
|  | ||||
|     def __enter__(self): | ||||
|         """On every with block we need to drop the profile collection.""" | ||||
|         """ On every with block we need to drop the profile collection. """ | ||||
|         self.db.set_profiling_level(0) | ||||
|         self.db.system.profile.drop() | ||||
|         self.db.set_profiling_level(2) | ||||
|         return self | ||||
|  | ||||
|     def __exit__(self, t, value, traceback): | ||||
|         """Reset the profiling level.""" | ||||
|         """ Reset the profiling level. """ | ||||
|         self.db.set_profiling_level(0) | ||||
|  | ||||
|     def __eq__(self, value): | ||||
|         """== Compare querycounter.""" | ||||
|         """ == Compare querycounter. """ | ||||
|         counter = self._get_count() | ||||
|         return value == counter | ||||
|  | ||||
|     def __ne__(self, value): | ||||
|         """!= Compare querycounter.""" | ||||
|         """ != Compare querycounter. """ | ||||
|         return not self.__eq__(value) | ||||
|  | ||||
|     def __lt__(self, value): | ||||
|         """< Compare querycounter.""" | ||||
|         """ < Compare querycounter. """ | ||||
|         return self._get_count() < value | ||||
|  | ||||
|     def __le__(self, value): | ||||
|         """<= Compare querycounter.""" | ||||
|         """ <= Compare querycounter. """ | ||||
|         return self._get_count() <= value | ||||
|  | ||||
|     def __gt__(self, value): | ||||
|         """> Compare querycounter.""" | ||||
|         """ > Compare querycounter. """ | ||||
|         return self._get_count() > value | ||||
|  | ||||
|     def __ge__(self, value): | ||||
|         """>= Compare querycounter.""" | ||||
|         """ >= Compare querycounter. """ | ||||
|         return self._get_count() >= value | ||||
|  | ||||
|     def __int__(self): | ||||
|         """int representation.""" | ||||
|         """ int representation. """ | ||||
|         return self._get_count() | ||||
|  | ||||
|     def __repr__(self): | ||||
|         """repr query_counter as the number of queries.""" | ||||
|         """ repr query_counter as the number of queries. """ | ||||
|         return u"%s" % self._get_count() | ||||
|  | ||||
|     def _get_count(self): | ||||
|         """Get the number of queries.""" | ||||
|         ignore_query = {'ns': {'$ne': '%s.system.indexes' % self.db.name}} | ||||
|         """ Get the number of queries. """ | ||||
|         ignore_query = {"ns": {"$ne": "%s.system.indexes" % self.db.name}} | ||||
|         count = self.db.system.profile.find(ignore_query).count() - self.counter | ||||
|         self.counter += 1 | ||||
|         return count | ||||
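`query_counter` works by raising the database's profiling level to 2 (log every operation) and counting documents in `system.profile`; each comparison issues a query of its own, hence the `self.counter` bookkeeping. A usage sketch, assuming a connected default database (exact counts can vary with index creation and server version, so this only prints the total):

```python
from mongoengine import Document, StringField, connect
from mongoengine.context_managers import query_counter

connect('profiling_demo')


class Item(Document):
    name = StringField()


with query_counter() as q:
    assert q == 0             # nothing profiled yet
    Item(name='a').save()     # one insert
    Item.objects.first()      # one query
    print(int(q))             # number of profiled operations so far
```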
|   | ||||
| @@ -1,12 +1,14 @@ | ||||
| from bson import DBRef, SON | ||||
| import six | ||||
|  | ||||
| from mongoengine.base import (BaseDict, BaseList, EmbeddedDocumentList, | ||||
|                               TopLevelDocumentMetaclass, get_document) | ||||
| from mongoengine.connection import get_db | ||||
| from mongoengine.document import Document, EmbeddedDocument | ||||
| from mongoengine.fields import DictField, ListField, MapField, ReferenceField | ||||
| from mongoengine.queryset import QuerySet | ||||
| from .base import ( | ||||
|     BaseDict, BaseList, EmbeddedDocumentList, | ||||
|     TopLevelDocumentMetaclass, get_document | ||||
| ) | ||||
| from .connection import get_db | ||||
| from .document import Document, EmbeddedDocument | ||||
| from .fields import DictField, ListField, MapField, ReferenceField | ||||
| from .python_support import txt_type | ||||
| from .queryset import QuerySet | ||||
|  | ||||
|  | ||||
| class DeReference(object): | ||||
| @@ -23,7 +25,7 @@ class DeReference(object): | ||||
|             :class:`~mongoengine.base.ComplexBaseField` | ||||
|         :param get: A boolean determining if being called by __get__ | ||||
|         """ | ||||
|         if items is None or isinstance(items, six.string_types): | ||||
|         if items is None or isinstance(items, basestring): | ||||
|             return items | ||||
|  | ||||
|         # cheapest way to convert a queryset to a list | ||||
| @@ -66,11 +68,11 @@ class DeReference(object): | ||||
|  | ||||
|                         items = _get_items(items) | ||||
|                     else: | ||||
|                         items = { | ||||
|                             k: (v if isinstance(v, (DBRef, Document)) | ||||
|                                 else field.to_python(v)) | ||||
|                             for k, v in items.iteritems() | ||||
|                         } | ||||
|                         items = dict([ | ||||
|                             (k, field.to_python(v)) | ||||
|                             if not isinstance(v, (DBRef, Document)) else (k, v) | ||||
|                             for k, v in items.iteritems()] | ||||
|                         ) | ||||
|  | ||||
|         self.reference_map = self._find_references(items) | ||||
|         self.object_map = self._fetch_objects(doc_type=doc_type) | ||||
| @@ -88,14 +90,14 @@ class DeReference(object): | ||||
|             return reference_map | ||||
|  | ||||
|         # Determine the iterator to use | ||||
|         if isinstance(items, dict): | ||||
|             iterator = items.values() | ||||
|         if not hasattr(items, 'items'): | ||||
|             iterator = enumerate(items) | ||||
|         else: | ||||
|             iterator = items | ||||
|             iterator = items.iteritems() | ||||
|  | ||||
|         # Recursively find dbreferences | ||||
|         depth += 1 | ||||
|         for item in iterator: | ||||
|         for k, item in iterator: | ||||
|             if isinstance(item, (Document, EmbeddedDocument)): | ||||
|                 for field_name, field in item._fields.iteritems(): | ||||
|                     v = item._data.get(field_name, None) | ||||
| @@ -149,7 +151,7 @@ class DeReference(object): | ||||
|                     references = get_db()[collection].find({'_id': {'$in': refs}}) | ||||
|                     for ref in references: | ||||
|                         if '_cls' in ref: | ||||
|                             doc = get_document(ref['_cls'])._from_son(ref) | ||||
|                             doc = get_document(ref["_cls"])._from_son(ref) | ||||
|                         elif doc_type is None: | ||||
|                             doc = get_document( | ||||
|                                 ''.join(x.capitalize() | ||||
| @@ -216,7 +218,7 @@ class DeReference(object): | ||||
|             if k in self.object_map and not is_list: | ||||
|                 data[k] = self.object_map[k] | ||||
|             elif isinstance(v, (Document, EmbeddedDocument)): | ||||
|                 for field_name in v._fields: | ||||
|                 for field_name, field in v._fields.iteritems(): | ||||
|                     v = data[k]._data.get(field_name, None) | ||||
|                     if isinstance(v, DBRef): | ||||
|                         data[k]._data[field_name] = self.object_map.get( | ||||
| @@ -225,7 +227,7 @@ class DeReference(object): | ||||
|                         data[k]._data[field_name] = self.object_map.get( | ||||
|                             (v['_ref'].collection, v['_ref'].id), v) | ||||
|                     elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth: | ||||
|                         item_name = six.text_type('{0}.{1}.{2}').format(name, k, field_name) | ||||
|                         item_name = txt_type("{0}.{1}.{2}").format(name, k, field_name) | ||||
|                         data[k]._data[field_name] = self._attach_objects(v, depth, instance=instance, name=item_name) | ||||
|             elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth: | ||||
|                 item_name = '%s.%s' % (name, k) if name else name | ||||
|   | ||||
| @@ -4,12 +4,18 @@ import warnings | ||||
| from bson.dbref import DBRef | ||||
| import pymongo | ||||
| from pymongo.read_preferences import ReadPreference | ||||
| import six | ||||
|  | ||||
| from mongoengine import signals | ||||
| from mongoengine.base import (BaseDict, BaseDocument, BaseList, | ||||
|                               DocumentMetaclass, EmbeddedDocumentList, | ||||
|                               TopLevelDocumentMetaclass, get_document) | ||||
| from mongoengine.base import ( | ||||
|     ALLOW_INHERITANCE, | ||||
|     BaseDict, | ||||
|     BaseDocument, | ||||
|     BaseList, | ||||
|     DocumentMetaclass, | ||||
|     EmbeddedDocumentList, | ||||
|     TopLevelDocumentMetaclass, | ||||
|     get_document | ||||
| ) | ||||
| from mongoengine.common import _import_class | ||||
| from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db | ||||
| from mongoengine.context_managers import switch_collection, switch_db | ||||
| @@ -25,10 +31,12 @@ __all__ = ('Document', 'EmbeddedDocument', 'DynamicDocument', | ||||
|  | ||||
|  | ||||
| def includes_cls(fields): | ||||
|     """Helper function used for ensuring and comparing indexes.""" | ||||
|     """ Helper function used for ensuring and comparing indexes | ||||
|     """ | ||||
|  | ||||
|     first_field = None | ||||
|     if len(fields): | ||||
|         if isinstance(fields[0], six.string_types): | ||||
|         if isinstance(fields[0], basestring): | ||||
|             first_field = fields[0] | ||||
|         elif isinstance(fields[0], (list, tuple)) and len(fields[0]): | ||||
|             first_field = fields[0][0] | ||||
| @@ -49,8 +57,9 @@ class EmbeddedDocument(BaseDocument): | ||||
|     to create a specialised version of the embedded document that will be | ||||
|     stored in the same collection. To facilitate this behaviour a `_cls` | ||||
|     field is added to documents (hidden through the MongoEngine interface). | ||||
|     To enable this behaviour set :attr:`allow_inheritance` to ``True`` in the | ||||
|     :attr:`meta` dictionary. | ||||
|     To disable this behaviour and remove the dependence on the presence of | ||||
|     `_cls` set :attr:`allow_inheritance` to ``False`` in the :attr:`meta` | ||||
|     dictionary. | ||||
|     """ | ||||
|  | ||||
|     __slots__ = ('_instance', ) | ||||
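Concretely, an inheritable embedded document stores a `_cls` marker so the right subclass is rebuilt on load; disabling inheritance drops that marker. A sketch (model names hypothetical):

```python
from mongoengine import (Document, EmbeddedDocument,
                         EmbeddedDocumentField, StringField)


class Comment(EmbeddedDocument):
    text = StringField()
    meta = {'allow_inheritance': True}


class VideoComment(Comment):  # specialised version, same storage
    timestamp = StringField()


class Post(Document):
    comment = EmbeddedDocumentField(Comment)


post = Post(comment=VideoComment(text='nice', timestamp='0:42'))
print(post.to_mongo()['comment']['_cls'])  # 'Comment.VideoComment'
```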
| @@ -73,15 +82,6 @@ class EmbeddedDocument(BaseDocument): | ||||
|     def __ne__(self, other): | ||||
|         return not self.__eq__(other) | ||||
|  | ||||
|     def to_mongo(self, *args, **kwargs): | ||||
|         data = super(EmbeddedDocument, self).to_mongo(*args, **kwargs) | ||||
|  | ||||
|         # remove _id from the SON if it's in it and it's None | ||||
|         if '_id' in data and data['_id'] is None: | ||||
|             del data['_id'] | ||||
|  | ||||
|         return data | ||||
|  | ||||
|     def save(self, *args, **kwargs): | ||||
|         self._instance.save(*args, **kwargs) | ||||
|  | ||||
| @@ -106,8 +106,9 @@ class Document(BaseDocument): | ||||
|     create a specialised version of the document that will be stored in the | ||||
|     same collection. To facilitate this behaviour a `_cls` | ||||
|     field is added to documents (hidden through the MongoEngine interface). | ||||
|     To enable this behaviour set :attr:`allow_inheritance` to ``True`` in the | ||||
|     :attr:`meta` dictionary. | ||||
|     To disable this behaviour and remove the dependence on the presence of | ||||
|     `_cls` set :attr:`allow_inheritance` to ``False`` in the :attr:`meta` | ||||
|     dictionary. | ||||
|  | ||||
|     A :class:`~mongoengine.Document` may use a **Capped Collection** by | ||||
|     specifying :attr:`max_documents` and :attr:`max_size` in the :attr:`meta` | ||||
| @@ -148,22 +149,26 @@ class Document(BaseDocument): | ||||
|  | ||||
|     __slots__ = ('__objects',) | ||||
|  | ||||
|     @property | ||||
|     def pk(self): | ||||
|         """Get the primary key.""" | ||||
|         if 'id_field' not in self._meta: | ||||
|             return None | ||||
|         return getattr(self, self._meta['id_field']) | ||||
|     def pk(): | ||||
|         """Primary key alias | ||||
|         """ | ||||
|  | ||||
|     @pk.setter | ||||
|     def pk(self, value): | ||||
|         """Set the primary key.""" | ||||
|         return setattr(self, self._meta['id_field'], value) | ||||
|         def fget(self): | ||||
|             if 'id_field' not in self._meta: | ||||
|                 return None | ||||
|             return getattr(self, self._meta['id_field']) | ||||
|  | ||||
|         def fset(self, value): | ||||
|             return setattr(self, self._meta['id_field'], value) | ||||
|  | ||||
|         return property(fget, fset) | ||||
|  | ||||
|     pk = pk() | ||||
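The hunk above swaps the decorator-based property for the function-call idiom; both define the same read/write alias. A standalone sketch of the idiom, for reference:

```python
class Example(object):
    _meta = {'id_field': 'id'}
    id = None

    def pk():
        # build getter/setter in a throwaway scope, then rebind the
        # name to the resulting property object
        def fget(self):
            return getattr(self, self._meta['id_field'])

        def fset(self, value):
            return setattr(self, self._meta['id_field'], value)

        return property(fget, fset)

    pk = pk()

e = Example()
e.pk = 42
assert e.id == 42  # pk reads and writes the configured id field
```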
|  | ||||
|     @classmethod | ||||
|     def _get_db(cls): | ||||
|         """Some Model using other db_alias""" | ||||
|         return get_db(cls._meta.get('db_alias', DEFAULT_CONNECTION_NAME)) | ||||
|         return get_db(cls._meta.get("db_alias", DEFAULT_CONNECTION_NAME)) | ||||
|  | ||||
|     @classmethod | ||||
|     def _get_collection(cls): | ||||
| @@ -206,20 +211,7 @@ class Document(BaseDocument): | ||||
|                 cls.ensure_indexes() | ||||
|         return cls._collection | ||||
|  | ||||
|     def to_mongo(self, *args, **kwargs): | ||||
|         data = super(Document, self).to_mongo(*args, **kwargs) | ||||
|  | ||||
|         # If '_id' is None, try and set it from self._data. If that | ||||
|         # doesn't exist either, remove '_id' from the SON completely. | ||||
|         if data['_id'] is None: | ||||
|             if self._data.get('id') is None: | ||||
|                 del data['_id'] | ||||
|             else: | ||||
|                 data['_id'] = self._data['id'] | ||||
|  | ||||
|         return data | ||||
|  | ||||
|     def modify(self, query=None, **update): | ||||
|     def modify(self, query={}, **update): | ||||
|         """Perform an atomic update of the document in the database and reload | ||||
|         the document object using the updated version. | ||||
|  | ||||
| @@ -233,19 +225,17 @@ class Document(BaseDocument): | ||||
|             database matches the query | ||||
|         :param update: Django-style update keyword arguments | ||||
|         """ | ||||
|         if query is None: | ||||
|             query = {} | ||||
|  | ||||
|         if self.pk is None: | ||||
|             raise InvalidDocumentError('The document does not have a primary key.') | ||||
|             raise InvalidDocumentError("The document does not have a primary key.") | ||||
|  | ||||
|         id_field = self._meta['id_field'] | ||||
|         id_field = self._meta["id_field"] | ||||
|         query = query.copy() if isinstance(query, dict) else query.to_query(self) | ||||
|  | ||||
|         if id_field not in query: | ||||
|             query[id_field] = self.pk | ||||
|         elif query[id_field] != self.pk: | ||||
|             raise InvalidQueryError('Invalid document modify query: it must modify only this document.') | ||||
|             raise InvalidQueryError("Invalid document modify query: it must modify only this document.") | ||||
|  | ||||
|         updated = self._qs(**query).modify(new=True, **update) | ||||
|         if updated is None: | ||||
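For context, a hedged usage sketch of `modify()` (the `Account` model and the `connect()` target are assumptions, not part of this diff):

```python
from mongoengine import Document, IntField, StringField, connect

connect('example_db')  # assumes a MongoDB instance is running

class Account(Document):
    owner = StringField()
    balance = IntField()

acct = Account(owner='ann', balance=10).save()

# atomically decrement only if the stored balance is still 10; on
# success the local object is refreshed with the updated state
if acct.modify(query={'balance': 10}, dec__balance=3):
    assert acct.balance == 7
```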
| @@ -313,9 +303,6 @@ class Document(BaseDocument): | ||||
|         .. versionchanged:: 0.10.7 | ||||
|             Add signal_kwargs argument | ||||
|         """ | ||||
|         if self._meta.get('abstract'): | ||||
|             raise InvalidDocumentError('Cannot save an abstract document.') | ||||
|  | ||||
|         signal_kwargs = signal_kwargs or {} | ||||
|         signals.pre_save.send(self.__class__, document=self, **signal_kwargs) | ||||
|  | ||||
| @@ -323,7 +310,7 @@ class Document(BaseDocument): | ||||
|             self.validate(clean=clean) | ||||
|  | ||||
|         if write_concern is None: | ||||
|             write_concern = {'w': 1} | ||||
|             write_concern = {"w": 1} | ||||
|  | ||||
|         doc = self.to_mongo() | ||||
|  | ||||
| @@ -332,135 +319,105 @@ class Document(BaseDocument): | ||||
|         signals.pre_save_post_validation.send(self.__class__, document=self, | ||||
|                                               created=created, **signal_kwargs) | ||||
|  | ||||
|         if self._meta.get('auto_create_index', True): | ||||
|             self.ensure_indexes() | ||||
|  | ||||
|         try: | ||||
|             # Save a new document or update an existing one | ||||
|             collection = self._get_collection() | ||||
|             if self._meta.get('auto_create_index', True): | ||||
|                 self.ensure_indexes() | ||||
|             if created: | ||||
|                 object_id = self._save_create(doc, force_insert, write_concern) | ||||
|                 if force_insert: | ||||
|                     object_id = collection.insert(doc, **write_concern) | ||||
|                 else: | ||||
|                     object_id = collection.save(doc, **write_concern) | ||||
|                     # In PyMongo 3.0, save() internally calls _update(), which doesn't | ||||
|                     # return the _id value that was passed in, so we fetch it back here. | ||||
|                     # Correct behaviour in 2.X and in 3.0.1+ versions | ||||
|                     if not object_id and pymongo.version_tuple == (3, 0): | ||||
|                         pk_as_mongo_obj = self._fields.get(self._meta['id_field']).to_mongo(self.pk) | ||||
|                         object_id = ( | ||||
|                             self._qs.filter(pk=pk_as_mongo_obj).first() and | ||||
|                             self._qs.filter(pk=pk_as_mongo_obj).first().pk | ||||
|                         )  # TODO doesn't this make 2 queries? | ||||
|             else: | ||||
|                 object_id, created = self._save_update(doc, save_condition, | ||||
|                                                        write_concern) | ||||
|                 object_id = doc['_id'] | ||||
|                 updates, removals = self._delta() | ||||
|                 # Need to add shard key to query, or you get an error | ||||
|                 if save_condition is not None: | ||||
|                     select_dict = transform.query(self.__class__, | ||||
|                                                   **save_condition) | ||||
|                 else: | ||||
|                     select_dict = {} | ||||
|                 select_dict['_id'] = object_id | ||||
|                 shard_key = self.__class__._meta.get('shard_key', tuple()) | ||||
|                 for k in shard_key: | ||||
|                     path = self._lookup_field(k.split('.')) | ||||
|                     actual_key = [p.db_field for p in path] | ||||
|                     val = doc | ||||
|                     for ak in actual_key: | ||||
|                         val = val[ak] | ||||
|                     select_dict['.'.join(actual_key)] = val | ||||
|  | ||||
|                 def is_new_object(last_error): | ||||
|                     if last_error is not None: | ||||
|                         updated = last_error.get("updatedExisting") | ||||
|                         if updated is not None: | ||||
|                             return not updated | ||||
|                     return created | ||||
|  | ||||
|                 update_query = {} | ||||
|  | ||||
|                 if updates: | ||||
|                     update_query["$set"] = updates | ||||
|                 if removals: | ||||
|                     update_query["$unset"] = removals | ||||
|                 if updates or removals: | ||||
|                     upsert = save_condition is None | ||||
|                     last_error = collection.update(select_dict, update_query, | ||||
|                                                    upsert=upsert, **write_concern) | ||||
|                     if not upsert and last_error["n"] == 0: | ||||
|                         raise SaveConditionError('Race condition preventing' | ||||
|                                                  ' document update detected') | ||||
|                     created = is_new_object(last_error) | ||||
|  | ||||
|             if cascade is None: | ||||
|                 cascade = (self._meta.get('cascade', False) or | ||||
|                            cascade_kwargs is not None) | ||||
|                 cascade = self._meta.get( | ||||
|                     'cascade', False) or cascade_kwargs is not None | ||||
|  | ||||
|             if cascade: | ||||
|                 kwargs = { | ||||
|                     'force_insert': force_insert, | ||||
|                     'validate': validate, | ||||
|                     'write_concern': write_concern, | ||||
|                     'cascade': cascade | ||||
|                     "force_insert": force_insert, | ||||
|                     "validate": validate, | ||||
|                     "write_concern": write_concern, | ||||
|                     "cascade": cascade | ||||
|                 } | ||||
|                 if cascade_kwargs:  # Allow granular control over cascades | ||||
|                     kwargs.update(cascade_kwargs) | ||||
|                 kwargs['_refs'] = _refs | ||||
|                 self.cascade_save(**kwargs) | ||||
|  | ||||
|         except pymongo.errors.DuplicateKeyError as err: | ||||
|         except pymongo.errors.DuplicateKeyError, err: | ||||
|             message = u'Tried to save duplicate unique keys (%s)' | ||||
|             raise NotUniqueError(message % six.text_type(err)) | ||||
|         except pymongo.errors.OperationFailure as err: | ||||
|             raise NotUniqueError(message % unicode(err)) | ||||
|         except pymongo.errors.OperationFailure, err: | ||||
|             message = 'Could not save document (%s)' | ||||
|             if re.match('^E1100[01] duplicate key', six.text_type(err)): | ||||
|             if re.match('^E1100[01] duplicate key', unicode(err)): | ||||
|                 # E11000 - duplicate key error index | ||||
|                 # E11001 - duplicate key on update | ||||
|                 message = u'Tried to save duplicate unique keys (%s)' | ||||
|                 raise NotUniqueError(message % six.text_type(err)) | ||||
|             raise OperationError(message % six.text_type(err)) | ||||
|  | ||||
|         # Make sure we store the PK on this document now that it's saved | ||||
|                 raise NotUniqueError(message % unicode(err)) | ||||
|             raise OperationError(message % unicode(err)) | ||||
|         id_field = self._meta['id_field'] | ||||
|         if created or id_field not in self._meta.get('shard_key', []): | ||||
|             self[id_field] = self._fields[id_field].to_python(object_id) | ||||
|  | ||||
|         signals.post_save.send(self.__class__, document=self, | ||||
|                                created=created, **signal_kwargs) | ||||
|  | ||||
|         self._clear_changed_fields() | ||||
|         self._created = False | ||||
|  | ||||
|         return self | ||||
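A short sketch of the signal flow `save()` drives (model and handler names are hypothetical):

```python
from mongoengine import Document, StringField, signals

class Page(Document):
    title = StringField(required=True)

def on_post_save(sender, document, **kwargs):
    # kwargs['created'] is True on first insert, False on update
    pass

signals.post_save.connect(on_post_save, sender=Page)

page = Page(title='Home')
page.save()                        # validate -> pre_save hooks -> insert
page.title = 'Start'
page.save(write_concern={'w': 1})  # update path: only deltas are $set
```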
|  | ||||
|     def _save_create(self, doc, force_insert, write_concern): | ||||
|         """Save a new document. | ||||
|  | ||||
|         Helper method, should only be used inside save(). | ||||
|         """ | ||||
|         collection = self._get_collection() | ||||
|  | ||||
|         if force_insert: | ||||
|             return collection.insert(doc, **write_concern) | ||||
|  | ||||
|         object_id = collection.save(doc, **write_concern) | ||||
|  | ||||
|         # In PyMongo 3.0, save() internally calls _update(), which doesn't | ||||
|         # return the _id value that was passed in, so we fetch it back here. | ||||
|         # Correct behaviour in 2.X and in 3.0.1+ versions | ||||
|         if not object_id and pymongo.version_tuple == (3, 0): | ||||
|             pk_as_mongo_obj = self._fields.get(self._meta['id_field']).to_mongo(self.pk) | ||||
|             object_id = ( | ||||
|                 self._qs.filter(pk=pk_as_mongo_obj).first() and | ||||
|                 self._qs.filter(pk=pk_as_mongo_obj).first().pk | ||||
|             )  # TODO doesn't this make 2 queries? | ||||
|  | ||||
|         return object_id | ||||
|  | ||||
|     def _save_update(self, doc, save_condition, write_concern): | ||||
|         """Update an existing document. | ||||
|  | ||||
|         Helper method, should only be used inside save(). | ||||
|         """ | ||||
|         collection = self._get_collection() | ||||
|         object_id = doc['_id'] | ||||
|         created = False | ||||
|  | ||||
|         select_dict = {} | ||||
|         if save_condition is not None: | ||||
|             select_dict = transform.query(self.__class__, **save_condition) | ||||
|  | ||||
|         select_dict['_id'] = object_id | ||||
|  | ||||
|         # Need to add shard key to query, or you get an error | ||||
|         shard_key = self._meta.get('shard_key', tuple()) | ||||
|         for k in shard_key: | ||||
|             path = self._lookup_field(k.split('.')) | ||||
|             actual_key = [p.db_field for p in path] | ||||
|             val = doc | ||||
|             for ak in actual_key: | ||||
|                 val = val[ak] | ||||
|             select_dict['.'.join(actual_key)] = val | ||||
|  | ||||
|         updates, removals = self._delta() | ||||
|         update_query = {} | ||||
|         if updates: | ||||
|             update_query['$set'] = updates | ||||
|         if removals: | ||||
|             update_query['$unset'] = removals | ||||
|         if updates or removals: | ||||
|             upsert = save_condition is None | ||||
|             last_error = collection.update(select_dict, update_query, | ||||
|                                            upsert=upsert, **write_concern) | ||||
|             if not upsert and last_error['n'] == 0: | ||||
|                 raise SaveConditionError('Race condition preventing' | ||||
|                                          ' document update detected') | ||||
|             if last_error is not None: | ||||
|                 updated_existing = last_error.get('updatedExisting') | ||||
|                 if updated_existing is False: | ||||
|                     created = True | ||||
|                     # !!! This is bad, means we accidentally created a new, | ||||
|                     # potentially corrupted document. See | ||||
|                     # https://github.com/MongoEngine/mongoengine/issues/564 | ||||
|  | ||||
|         return object_id, created | ||||
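To illustrate the `save_condition` branch above (an optimistic-concurrency guard), a sketch with a hypothetical `Counter` model:

```python
from mongoengine import Document, IntField, connect
from mongoengine.errors import SaveConditionError

connect('example_db')

class Counter(Document):
    value = IntField(default=0)
    version = IntField(default=0)

doc = Counter().save()
doc.value += 1
doc.version += 1
try:
    # only applies the delta if the stored document still has version 0;
    # otherwise the update matches nothing and SaveConditionError is raised
    doc.save(save_condition={'version': 0})
except SaveConditionError:
    pass  # a concurrent writer got there first
```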
|  | ||||
|     def cascade_save(self, **kwargs): | ||||
|         """Recursively save any references and generic references on the | ||||
|         document. | ||||
|         """ | ||||
|         _refs = kwargs.get('_refs') or [] | ||||
|     def cascade_save(self, *args, **kwargs): | ||||
|         """Recursively saves any references / | ||||
|            generic references on the document""" | ||||
|         _refs = kwargs.get('_refs', []) or [] | ||||
|  | ||||
|         ReferenceField = _import_class('ReferenceField') | ||||
|         GenericReferenceField = _import_class('GenericReferenceField') | ||||
| @@ -486,17 +443,16 @@ class Document(BaseDocument): | ||||
|  | ||||
|     @property | ||||
|     def _qs(self): | ||||
|         """Return the queryset to use for updating / reloading / deletions.""" | ||||
|         """ | ||||
|         Returns the queryset to use for updating / reloading / deletions | ||||
|         """ | ||||
|         if not hasattr(self, '__objects'): | ||||
|             self.__objects = QuerySet(self, self._get_collection()) | ||||
|         return self.__objects | ||||
|  | ||||
|     @property | ||||
|     def _object_key(self): | ||||
|         """Get the query dict that can be used to fetch this object from | ||||
|         the database. Most of the time it's a simple PK lookup, but in | ||||
|         case of a sharded collection with a compound shard key, it can | ||||
|         contain a more complex query. | ||||
|         """Dict to identify object in collection | ||||
|         """ | ||||
|         select_dict = {'pk': self.pk} | ||||
|         shard_key = self.__class__._meta.get('shard_key', tuple()) | ||||
| @@ -519,8 +475,8 @@ class Document(BaseDocument): | ||||
|         if self.pk is None: | ||||
|             if kwargs.get('upsert', False): | ||||
|                 query = self.to_mongo() | ||||
|                 if '_cls' in query: | ||||
|                     del query['_cls'] | ||||
|                 if "_cls" in query: | ||||
|                     del query["_cls"] | ||||
|                 return self._qs.filter(**query).update_one(**kwargs) | ||||
|             else: | ||||
|                 raise OperationError( | ||||
| @@ -557,7 +513,7 @@ class Document(BaseDocument): | ||||
|         try: | ||||
|             self._qs.filter( | ||||
|                 **self._object_key).delete(write_concern=write_concern, _from_doc_delete=True) | ||||
|         except pymongo.errors.OperationFailure as err: | ||||
|         except pymongo.errors.OperationFailure, err: | ||||
|             message = u'Could not delete document (%s)' % err.message | ||||
|             raise OperationError(message) | ||||
|         signals.post_delete.send(self.__class__, document=self, **signal_kwargs) | ||||
| @@ -645,12 +601,11 @@ class Document(BaseDocument): | ||||
|         if fields and isinstance(fields[0], int): | ||||
|             max_depth = fields[0] | ||||
|             fields = fields[1:] | ||||
|         elif 'max_depth' in kwargs: | ||||
|             max_depth = kwargs['max_depth'] | ||||
|         elif "max_depth" in kwargs: | ||||
|             max_depth = kwargs["max_depth"] | ||||
|  | ||||
|         if self.pk is None: | ||||
|             raise self.DoesNotExist('Document does not exist') | ||||
|  | ||||
|             raise self.DoesNotExist("Document does not exist") | ||||
|         obj = self._qs.read_preference(ReadPreference.PRIMARY).filter( | ||||
|             **self._object_key).only(*fields).limit( | ||||
|             1).select_related(max_depth=max_depth) | ||||
| @@ -658,7 +613,7 @@ class Document(BaseDocument): | ||||
|         if obj: | ||||
|             obj = obj[0] | ||||
|         else: | ||||
|             raise self.DoesNotExist('Document does not exist') | ||||
|             raise self.DoesNotExist("Document does not exist") | ||||
|  | ||||
|         for field in obj._data: | ||||
|             if not fields or field in fields: | ||||
| @@ -701,7 +656,7 @@ class Document(BaseDocument): | ||||
|         """Returns an instance of :class:`~bson.dbref.DBRef` useful in | ||||
|         `__raw__` queries.""" | ||||
|         if self.pk is None: | ||||
|             msg = 'Only saved documents can have a valid dbref' | ||||
|             msg = "Only saved documents can have a valid dbref" | ||||
|             raise OperationError(msg) | ||||
|         return DBRef(self.__class__._get_collection_name(), self.pk) | ||||
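A usage sketch for `to_dbref()` in a `__raw__` query (both models are hypothetical; `dbref=True` makes the stored value a full DBRef so it matches):

```python
from mongoengine import Document, ReferenceField, StringField, connect

connect('example_db')

class Article(Document):
    title = StringField()

class Comment(Document):
    article = ReferenceField(Article, dbref=True)  # store full DBRefs
    body = StringField()

art = Article(title='News').save()
Comment(article=art, body='First!').save()

# to_dbref() yields the same BSON value the reference stores
hits = Comment.objects(__raw__={'article': art.to_dbref()})
```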
|  | ||||
| @@ -756,7 +711,7 @@ class Document(BaseDocument): | ||||
|         fields = index_spec.pop('fields') | ||||
|         drop_dups = kwargs.get('drop_dups', False) | ||||
|         if IS_PYMONGO_3 and drop_dups: | ||||
|             msg = 'drop_dups is deprecated and is removed when using PyMongo 3+.' | ||||
|             msg = "drop_dups is deprecated and is removed when using PyMongo 3+." | ||||
|             warnings.warn(msg, DeprecationWarning) | ||||
|         elif not IS_PYMONGO_3: | ||||
|             index_spec['drop_dups'] = drop_dups | ||||
| @@ -782,7 +737,7 @@ class Document(BaseDocument): | ||||
|             will be removed if PyMongo3+ is used | ||||
|         """ | ||||
|         if IS_PYMONGO_3 and drop_dups: | ||||
|             msg = 'drop_dups is deprecated and is removed when using PyMongo 3+.' | ||||
|             msg = "drop_dups is deprecated and is removed when using PyMongo 3+." | ||||
|             warnings.warn(msg, DeprecationWarning) | ||||
|         elif not IS_PYMONGO_3: | ||||
|             kwargs.update({'drop_dups': drop_dups}) | ||||
| @@ -802,7 +757,7 @@ class Document(BaseDocument): | ||||
|         index_opts = cls._meta.get('index_opts') or {} | ||||
|         index_cls = cls._meta.get('index_cls', True) | ||||
|         if IS_PYMONGO_3 and drop_dups: | ||||
|             msg = 'drop_dups is deprecated and is removed when using PyMongo 3+.' | ||||
|             msg = "drop_dups is deprecated and is removed when using PyMongo 3+." | ||||
|             warnings.warn(msg, DeprecationWarning) | ||||
|  | ||||
|         collection = cls._get_collection() | ||||
| @@ -840,7 +795,8 @@ class Document(BaseDocument): | ||||
|  | ||||
|         # If _cls is being used (for polymorphism), it needs an index, | ||||
|         # only if another index doesn't begin with _cls | ||||
|         if index_cls and not cls_indexed and cls._meta.get('allow_inheritance'): | ||||
|         if (index_cls and not cls_indexed and | ||||
|                 cls._meta.get('allow_inheritance', ALLOW_INHERITANCE) is True): | ||||
|  | ||||
|             # we shouldn't pass 'cls' to the collection.ensureIndex options | ||||
|             # because of https://jira.mongodb.org/browse/SERVER-769 | ||||
| @@ -859,6 +815,7 @@ class Document(BaseDocument): | ||||
|         """ Lists all of the indexes that should be created for given | ||||
|         collection. It includes all the indexes from super- and sub-classes. | ||||
|         """ | ||||
|  | ||||
|         if cls._meta.get('abstract'): | ||||
|             return [] | ||||
|  | ||||
| @@ -909,15 +866,16 @@ class Document(BaseDocument): | ||||
|         # finish up by appending { '_id': 1 } and { '_cls': 1 }, if needed | ||||
|         if [(u'_id', 1)] not in indexes: | ||||
|             indexes.append([(u'_id', 1)]) | ||||
|         if cls._meta.get('index_cls', True) and cls._meta.get('allow_inheritance'): | ||||
|         if (cls._meta.get('index_cls', True) and | ||||
|                 cls._meta.get('allow_inheritance', ALLOW_INHERITANCE) is True): | ||||
|             indexes.append([(u'_cls', 1)]) | ||||
|  | ||||
|         return indexes | ||||
|  | ||||
|     @classmethod | ||||
|     def compare_indexes(cls): | ||||
|         """ Compares the indexes defined in MongoEngine with the ones | ||||
|         existing in the database. Returns any missing/extra indexes. | ||||
|         """ Compares the indexes defined in MongoEngine with the ones existing | ||||
|         in the database. Returns any missing/extra indexes. | ||||
|         """ | ||||
|  | ||||
|         required = cls.list_indexes() | ||||
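For reference, a hedged sketch of how `list_indexes()` / `compare_indexes()` are typically exercised (`Book` is a made-up model):

```python
from mongoengine import Document, StringField, connect

connect('example_db')

class Book(Document):
    title = StringField()
    meta = {'indexes': ['title']}

Book.ensure_indexes()          # create everything list_indexes() reports
diff = Book.compare_indexes()  # {'missing': [...], 'extra': [...]}
assert diff['missing'] == [] and diff['extra'] == []
```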
| @@ -961,9 +919,8 @@ class DynamicDocument(Document): | ||||
|     _dynamic = True | ||||
|  | ||||
|     def __delattr__(self, *args, **kwargs): | ||||
|         """Delete the attribute by setting to None and allowing _delta | ||||
|         to unset it. | ||||
|         """ | ||||
|         """Deletes the attribute by setting to None and allowing _delta to unset | ||||
|         it""" | ||||
|         field_name = args[0] | ||||
|         if field_name in self._dynamic_fields: | ||||
|             setattr(self, field_name, None) | ||||
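A small sketch of the deletion behaviour described in this docstring (hypothetical `Profile` model):

```python
from mongoengine import DynamicDocument, StringField, connect

connect('example_db')

class Profile(DynamicDocument):
    name = StringField()

p = Profile(name='ann')
p.nickname = 'a'   # dynamic field, no schema change needed
p.save()

del p.nickname     # sets the field to None so _delta() emits an $unset
p.save()
```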
| @@ -985,9 +942,8 @@ class DynamicEmbeddedDocument(EmbeddedDocument): | ||||
|     _dynamic = True | ||||
|  | ||||
|     def __delattr__(self, *args, **kwargs): | ||||
|         """Delete the attribute by setting to None and allowing _delta | ||||
|         to unset it. | ||||
|         """ | ||||
|         """Deletes the attribute by setting to None and allowing _delta to unset | ||||
|         it""" | ||||
|         field_name = args[0] | ||||
|         if field_name in self._fields: | ||||
|             default = self._fields[field_name].default | ||||
| @@ -1029,10 +985,10 @@ class MapReduceDocument(object): | ||||
|             try: | ||||
|                 self.key = id_field_type(self.key) | ||||
|             except Exception: | ||||
|                 raise Exception('Could not cast key as %s' % | ||||
|                 raise Exception("Could not cast key as %s" % | ||||
|                                 id_field_type.__name__) | ||||
|  | ||||
|         if not hasattr(self, '_key_object'): | ||||
|         if not hasattr(self, "_key_object"): | ||||
|             self._key_object = self._document.objects.with_id(self.key) | ||||
|             return self._key_object | ||||
|         return self._key_object | ||||
|   | ||||
| @@ -1,6 +1,7 @@ | ||||
| from collections import defaultdict | ||||
|  | ||||
| import six | ||||
| from mongoengine.python_support import txt_type | ||||
|  | ||||
|  | ||||
| __all__ = ('NotRegistered', 'InvalidDocumentError', 'LookUpError', | ||||
|            'DoesNotExist', 'MultipleObjectsReturned', 'InvalidQueryError', | ||||
| @@ -50,8 +51,8 @@ class FieldDoesNotExist(Exception): | ||||
|     or an :class:`~mongoengine.EmbeddedDocument`. | ||||
|  | ||||
|     To avoid this behavior on data loading, | ||||
|     you should set the :attr:`strict` to ``False`` | ||||
|     in the :attr:`meta` dictionary. | ||||
|     you should set the :attr:`strict` to ``False`` | ||||
|     in the :attr:`meta` dictionary. | ||||
|     """ | ||||
|  | ||||
|  | ||||
| @@ -70,13 +71,13 @@ class ValidationError(AssertionError): | ||||
|     field_name = None | ||||
|     _message = None | ||||
|  | ||||
|     def __init__(self, message='', **kwargs): | ||||
|     def __init__(self, message="", **kwargs): | ||||
|         self.errors = kwargs.get('errors', {}) | ||||
|         self.field_name = kwargs.get('field_name') | ||||
|         self.message = message | ||||
|  | ||||
|     def __str__(self): | ||||
|         return six.text_type(self.message) | ||||
|         return txt_type(self.message) | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return '%s(%s,)' % (self.__class__.__name__, self.message) | ||||
| @@ -110,20 +111,17 @@ class ValidationError(AssertionError): | ||||
|             errors_dict = {} | ||||
|             if not source: | ||||
|                 return errors_dict | ||||
|  | ||||
|             if isinstance(source, dict): | ||||
|                 for field_name, error in source.iteritems(): | ||||
|                     errors_dict[field_name] = build_dict(error) | ||||
|             elif isinstance(source, ValidationError) and source.errors: | ||||
|                 return build_dict(source.errors) | ||||
|             else: | ||||
|                 return six.text_type(source) | ||||
|  | ||||
|                 return unicode(source) | ||||
|             return errors_dict | ||||
|  | ||||
|         if not self.errors: | ||||
|             return {} | ||||
|  | ||||
|         return build_dict(self.errors) | ||||
|  | ||||
|     def _format_errors(self): | ||||
| @@ -136,10 +134,10 @@ class ValidationError(AssertionError): | ||||
|                 value = ' '.join( | ||||
|                     [generate_key(v, k) for k, v in value.iteritems()]) | ||||
|  | ||||
|             results = '%s.%s' % (prefix, value) if prefix else value | ||||
|             results = "%s.%s" % (prefix, value) if prefix else value | ||||
|             return results | ||||
|  | ||||
|         error_dict = defaultdict(list) | ||||
|         for k, v in self.to_dict().iteritems(): | ||||
|             error_dict[generate_key(v)].append(k) | ||||
|         return ' '.join(['%s: %s' % (k, v) for k, v in error_dict.iteritems()]) | ||||
|         return ' '.join(["%s: %s" % (k, v) for k, v in error_dict.iteritems()]) | ||||
|   | ||||
| @@ -3,6 +3,7 @@ import decimal | ||||
| import itertools | ||||
| import re | ||||
| import time | ||||
| import urllib2 | ||||
| import uuid | ||||
| import warnings | ||||
| from operator import itemgetter | ||||
| @@ -24,13 +25,13 @@ try: | ||||
| except ImportError: | ||||
|     Int64 = long | ||||
|  | ||||
| from mongoengine.base import (BaseDocument, BaseField, ComplexBaseField, | ||||
|                               GeoJsonBaseField, ObjectIdField, get_document) | ||||
| from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db | ||||
| from mongoengine.document import Document, EmbeddedDocument | ||||
| from mongoengine.errors import DoesNotExist, InvalidQueryError, ValidationError | ||||
| from mongoengine.python_support import StringIO | ||||
| from mongoengine.queryset import DO_NOTHING, QuerySet | ||||
| from .base import (BaseDocument, BaseField, ComplexBaseField, GeoJsonBaseField, | ||||
|                    ObjectIdField, get_document) | ||||
| from .connection import DEFAULT_CONNECTION_NAME, get_db | ||||
| from .document import Document, EmbeddedDocument | ||||
| from .errors import DoesNotExist, ValidationError | ||||
| from .python_support import PY3, StringIO, bin_type, str_types, txt_type | ||||
| from .queryset import DO_NOTHING, QuerySet | ||||
|  | ||||
| try: | ||||
|     from PIL import Image, ImageOps | ||||
| @@ -38,7 +39,7 @@ except ImportError: | ||||
|     Image = None | ||||
|     ImageOps = None | ||||
|  | ||||
| __all__ = ( | ||||
| __all__ = [ | ||||
|     'StringField', 'URLField', 'EmailField', 'IntField', 'LongField', | ||||
|     'FloatField', 'DecimalField', 'BooleanField', 'DateTimeField', | ||||
|     'ComplexDateTimeField', 'EmbeddedDocumentField', 'ObjectIdField', | ||||
| @@ -49,14 +50,14 @@ __all__ = ( | ||||
|     'FileField', 'ImageGridFsProxy', 'ImproperlyConfigured', 'ImageField', | ||||
|     'GeoPointField', 'PointField', 'LineStringField', 'PolygonField', | ||||
|     'SequenceField', 'UUIDField', 'MultiPointField', 'MultiLineStringField', | ||||
|     'MultiPolygonField', 'GeoJsonBaseField' | ||||
| ) | ||||
|     'MultiPolygonField', 'GeoJsonBaseField'] | ||||
|  | ||||
| RECURSIVE_REFERENCE_CONSTANT = 'self' | ||||
|  | ||||
|  | ||||
| class StringField(BaseField): | ||||
|     """A unicode string field.""" | ||||
|     """A unicode string field. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, regex=None, max_length=None, min_length=None, **kwargs): | ||||
|         self.regex = re.compile(regex) if regex else None | ||||
| @@ -65,7 +66,7 @@ class StringField(BaseField): | ||||
|         super(StringField, self).__init__(**kwargs) | ||||
|  | ||||
|     def to_python(self, value): | ||||
|         if isinstance(value, six.text_type): | ||||
|         if isinstance(value, unicode): | ||||
|             return value | ||||
|         try: | ||||
|             value = value.decode('utf-8') | ||||
| @@ -74,7 +75,7 @@ class StringField(BaseField): | ||||
|         return value | ||||
|  | ||||
|     def validate(self, value): | ||||
|         if not isinstance(value, six.string_types): | ||||
|         if not isinstance(value, basestring): | ||||
|             self.error('StringField only accepts string values') | ||||
|  | ||||
|         if self.max_length is not None and len(value) > self.max_length: | ||||
| @@ -90,7 +91,7 @@ class StringField(BaseField): | ||||
|         return None | ||||
|  | ||||
|     def prepare_query_value(self, op, value): | ||||
|         if not isinstance(op, six.string_types): | ||||
|         if not isinstance(op, basestring): | ||||
|             return value | ||||
|  | ||||
|         if op.lstrip('i') in ('startswith', 'endswith', 'contains', 'exact'): | ||||
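The operator handling above rewrites string lookups into anchored regexes; a quick usage sketch (the `City` model is an assumption):

```python
from mongoengine import Document, StringField, connect

connect('example_db')

class City(Document):
    name = StringField(max_length=50)

City.objects(name__startswith='San')   # matches /^San/
City.objects(name__icontains='york')   # matches /york/i
City.objects(name__exact='Paris')      # matches /^Paris$/
```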
| @@ -139,14 +140,25 @@ class URLField(StringField): | ||||
|         # Check first if the scheme is valid | ||||
|         scheme = value.split('://')[0].lower() | ||||
|         if scheme not in self.schemes: | ||||
|             self.error(u'Invalid scheme {} in URL: {}'.format(scheme, value)) | ||||
|             self.error('Invalid scheme {} in URL: {}'.format(scheme, value)) | ||||
|             return | ||||
|  | ||||
|         # Then check full URL | ||||
|         if not self.url_regex.match(value): | ||||
|             self.error(u'Invalid URL: {}'.format(value)) | ||||
|             self.error('Invalid URL: {}'.format(value)) | ||||
|             return | ||||
|  | ||||
|         if self.verify_exists: | ||||
|             warnings.warn( | ||||
|                 "The URLField verify_exists argument has intractable security " | ||||
|                 "and performance issues. Accordingly, it has been deprecated.", | ||||
|                 DeprecationWarning) | ||||
|             try: | ||||
|                 request = urllib2.Request(value) | ||||
|                 urllib2.urlopen(request) | ||||
|             except Exception, e: | ||||
|                 self.error('This URL appears to be a broken link: %s' % e) | ||||
|  | ||||
|  | ||||
| class EmailField(StringField): | ||||
|     """A field that validates input as an email address. | ||||
| @@ -170,7 +182,8 @@ class EmailField(StringField): | ||||
|  | ||||
|  | ||||
| class IntField(BaseField): | ||||
|     """32-bit integer field.""" | ||||
|     """An 32-bit integer field. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, min_value=None, max_value=None, **kwargs): | ||||
|         self.min_value, self.max_value = min_value, max_value | ||||
| @@ -203,7 +216,8 @@ class IntField(BaseField): | ||||
|  | ||||
|  | ||||
| class LongField(BaseField): | ||||
|     """64-bit integer field.""" | ||||
|     """An 64-bit integer field. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, min_value=None, max_value=None, **kwargs): | ||||
|         self.min_value, self.max_value = min_value, max_value | ||||
| @@ -239,7 +253,8 @@ class LongField(BaseField): | ||||
|  | ||||
|  | ||||
| class FloatField(BaseField): | ||||
|     """Floating point number field.""" | ||||
|     """An floating point number field. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, min_value=None, max_value=None, **kwargs): | ||||
|         self.min_value, self.max_value = min_value, max_value | ||||
| @@ -276,7 +291,7 @@ class FloatField(BaseField): | ||||
|  | ||||
|  | ||||
| class DecimalField(BaseField): | ||||
|     """Fixed-point decimal number field. | ||||
|     """A fixed-point decimal number field. | ||||
|  | ||||
|     .. versionchanged:: 0.8 | ||||
|     .. versionadded:: 0.3 | ||||
| @@ -317,25 +332,25 @@ class DecimalField(BaseField): | ||||
|  | ||||
|         # Convert to string for python 2.6 before casting to Decimal | ||||
|         try: | ||||
|             value = decimal.Decimal('%s' % value) | ||||
|             value = decimal.Decimal("%s" % value) | ||||
|         except decimal.InvalidOperation: | ||||
|             return value | ||||
|         return value.quantize(decimal.Decimal('.%s' % ('0' * self.precision)), rounding=self.rounding) | ||||
|         return value.quantize(decimal.Decimal(".%s" % ("0" * self.precision)), rounding=self.rounding) | ||||
|  | ||||
|     def to_mongo(self, value): | ||||
|         if value is None: | ||||
|             return value | ||||
|         if self.force_string: | ||||
|             return six.text_type(self.to_python(value)) | ||||
|             return unicode(value) | ||||
|         return float(self.to_python(value)) | ||||
|  | ||||
|     def validate(self, value): | ||||
|         if not isinstance(value, decimal.Decimal): | ||||
|             if not isinstance(value, six.string_types): | ||||
|                 value = six.text_type(value) | ||||
|             if not isinstance(value, basestring): | ||||
|                 value = unicode(value) | ||||
|             try: | ||||
|                 value = decimal.Decimal(value) | ||||
|             except Exception as exc: | ||||
|             except Exception, exc: | ||||
|                 self.error('Could not convert value to decimal: %s' % exc) | ||||
|  | ||||
|         if self.min_value is not None and value < self.min_value: | ||||
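A quick sketch of the precision/force_string options handled here (`Invoice` is hypothetical):

```python
from decimal import Decimal
from mongoengine import Document, DecimalField

class Invoice(Document):
    # quantized to two places on conversion; stored as a string to
    # avoid float rounding in MongoDB
    total = DecimalField(precision=2, force_string=True)

inv = Invoice(total=Decimal('19.999'))
# with the default rounding the converted value is Decimal('20.00')
```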
| @@ -349,7 +364,7 @@ class DecimalField(BaseField): | ||||
|  | ||||
|  | ||||
| class BooleanField(BaseField): | ||||
|     """Boolean field type. | ||||
|     """A boolean field type. | ||||
|  | ||||
|     .. versionadded:: 0.1.2 | ||||
|     """ | ||||
| @@ -367,7 +382,7 @@ class BooleanField(BaseField): | ||||
|  | ||||
|  | ||||
| class DateTimeField(BaseField): | ||||
|     """Datetime field. | ||||
|     """A datetime field. | ||||
|  | ||||
|     Uses the python-dateutil library if available, alternatively uses time.strptime | ||||
|     to parse the dates.  Note: python-dateutil's parser is fully featured and when | ||||
| @@ -395,7 +410,7 @@ class DateTimeField(BaseField): | ||||
|         if callable(value): | ||||
|             return value() | ||||
|  | ||||
|         if not isinstance(value, six.string_types): | ||||
|         if not isinstance(value, basestring): | ||||
|             return None | ||||
|  | ||||
|         # Attempt to parse a datetime: | ||||
| @@ -522,19 +537,16 @@ class EmbeddedDocumentField(BaseField): | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, document_type, **kwargs): | ||||
|         if ( | ||||
|             not isinstance(document_type, six.string_types) and | ||||
|             not issubclass(document_type, EmbeddedDocument) | ||||
|         ): | ||||
|             self.error('Invalid embedded document class provided to an ' | ||||
|                        'EmbeddedDocumentField') | ||||
|  | ||||
|         if not isinstance(document_type, basestring): | ||||
|             if not issubclass(document_type, EmbeddedDocument): | ||||
|                 self.error('Invalid embedded document class provided to an ' | ||||
|                            'EmbeddedDocumentField') | ||||
|         self.document_type_obj = document_type | ||||
|         super(EmbeddedDocumentField, self).__init__(**kwargs) | ||||
|  | ||||
|     @property | ||||
|     def document_type(self): | ||||
|         if isinstance(self.document_type_obj, six.string_types): | ||||
|         if isinstance(self.document_type_obj, basestring): | ||||
|             if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT: | ||||
|                 self.document_type_obj = self.owner_document | ||||
|             else: | ||||
| @@ -566,11 +578,7 @@ class EmbeddedDocumentField(BaseField): | ||||
|  | ||||
|     def prepare_query_value(self, op, value): | ||||
|         if value is not None and not isinstance(value, self.document_type): | ||||
|             try: | ||||
|                 value = self.document_type._from_son(value) | ||||
|             except ValueError: | ||||
|                 raise InvalidQueryError("Querying the embedded document '%s' failed, due to an invalid query value" % | ||||
|                                         (self.document_type._class_name,)) | ||||
|             value = self.document_type._from_son(value) | ||||
|         super(EmbeddedDocumentField, self).prepare_query_value(op, value) | ||||
|         return self.to_mongo(value) | ||||
|  | ||||
| @@ -623,7 +631,7 @@ class DynamicField(BaseField): | ||||
|         """Convert a Python type to a MongoDB compatible type. | ||||
|         """ | ||||
|  | ||||
|         if isinstance(value, six.string_types): | ||||
|         if isinstance(value, basestring): | ||||
|             return value | ||||
|  | ||||
|         if hasattr(value, 'to_mongo'): | ||||
| @@ -631,7 +639,7 @@ class DynamicField(BaseField): | ||||
|             val = value.to_mongo(use_db_field, fields) | ||||
|             # If it's a document that's not inherited, add _cls | ||||
|             if isinstance(value, Document): | ||||
|                 val = {'_ref': value.to_dbref(), '_cls': cls.__name__} | ||||
|                 val = {"_ref": value.to_dbref(), "_cls": cls.__name__} | ||||
|             if isinstance(value, EmbeddedDocument): | ||||
|                 val['_cls'] = cls.__name__ | ||||
|             return val | ||||
| @@ -642,7 +650,7 @@ class DynamicField(BaseField): | ||||
|         is_list = False | ||||
|         if not hasattr(value, 'items'): | ||||
|             is_list = True | ||||
|             value = {k: v for k, v in enumerate(value)} | ||||
|             value = dict([(k, v) for k, v in enumerate(value)]) | ||||
|  | ||||
|         data = {} | ||||
|         for k, v in value.iteritems(): | ||||
| @@ -666,12 +674,12 @@ class DynamicField(BaseField): | ||||
|         return member_name | ||||
|  | ||||
|     def prepare_query_value(self, op, value): | ||||
|         if isinstance(value, six.string_types): | ||||
|         if isinstance(value, basestring): | ||||
|             return StringField().prepare_query_value(op, value) | ||||
|         return super(DynamicField, self).prepare_query_value(op, self.to_mongo(value)) | ||||
|  | ||||
|     def validate(self, value, clean=True): | ||||
|         if hasattr(value, 'validate'): | ||||
|         if hasattr(value, "validate"): | ||||
|             value.validate(clean=clean) | ||||
|  | ||||
|  | ||||
| @@ -691,27 +699,21 @@ class ListField(ComplexBaseField): | ||||
|         super(ListField, self).__init__(**kwargs) | ||||
|  | ||||
|     def validate(self, value): | ||||
|         """Make sure that a list of valid fields is being used.""" | ||||
|         """Make sure that a list of valid fields is being used. | ||||
|         """ | ||||
|         if (not isinstance(value, (list, tuple, QuerySet)) or | ||||
|                 isinstance(value, six.string_types)): | ||||
|                 isinstance(value, basestring)): | ||||
|             self.error('Only lists and tuples may be used in a list field') | ||||
|         super(ListField, self).validate(value) | ||||
|  | ||||
|     def prepare_query_value(self, op, value): | ||||
|         if self.field: | ||||
|  | ||||
|             # If the value is iterable and it's not a string nor a | ||||
|             # BaseDocument, call prepare_query_value for each of its items. | ||||
|             if ( | ||||
|                 op in ('set', 'unset', None) and | ||||
|                 hasattr(value, '__iter__') and | ||||
|                 not isinstance(value, six.string_types) and | ||||
|                 not isinstance(value, BaseDocument) | ||||
|             ): | ||||
|             if op in ('set', 'unset', None) and ( | ||||
|                     not isinstance(value, basestring) and | ||||
|                     not isinstance(value, BaseDocument) and | ||||
|                     hasattr(value, '__iter__')): | ||||
|                 return [self.field.prepare_query_value(op, v) for v in value] | ||||
|  | ||||
|             return self.field.prepare_query_value(op, value) | ||||
|  | ||||
|         return super(ListField, self).prepare_query_value(op, value) | ||||
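To make the per-item `prepare_query_value` branch concrete, a sketch with an assumed `Post` model:

```python
from mongoengine import Document, ListField, StringField, connect

connect('example_db')

class Post(Document):
    tags = ListField(StringField(max_length=30))

post = Post(tags=['db']).save()

# a whole-list $set runs prepare_query_value once per item:
Post.objects(id=post.id).update(set__tags=['db', 'python'])

# a membership lookup prepares the single value instead:
Post.objects(tags='db')
```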
|  | ||||
|  | ||||
| @@ -724,6 +726,7 @@ class EmbeddedDocumentListField(ListField): | ||||
|         :class:`~mongoengine.EmbeddedDocument`. | ||||
|  | ||||
|     .. versionadded:: 0.9 | ||||
|  | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, document_type, **kwargs): | ||||
| @@ -772,17 +775,17 @@ class SortedListField(ListField): | ||||
|  | ||||
|  | ||||
| def key_not_string(d): | ||||
|     """Helper function to recursively determine if any key in a | ||||
|     dictionary is not a string. | ||||
|     """ Helper function to recursively determine if any key in a dictionary is | ||||
|     not a string. | ||||
|     """ | ||||
|     for k, v in d.items(): | ||||
|         if not isinstance(k, six.string_types) or (isinstance(v, dict) and key_not_string(v)): | ||||
|         if not isinstance(k, basestring) or (isinstance(v, dict) and key_not_string(v)): | ||||
|             return True | ||||
|  | ||||
|  | ||||
| def key_has_dot_or_dollar(d): | ||||
|     """Helper function to recursively determine if any key in a | ||||
|     dictionary contains a dot or a dollar sign. | ||||
|     """ Helper function to recursively determine if any key in a dictionary | ||||
|     contains a dot or a dollar sign. | ||||
|     """ | ||||
|     for k, v in d.items(): | ||||
|         if ('.' in k or '$' in k) or (isinstance(v, dict) and key_has_dot_or_dollar(v)): | ||||
| @@ -810,13 +813,14 @@ class DictField(ComplexBaseField): | ||||
|         super(DictField, self).__init__(*args, **kwargs) | ||||
|  | ||||
|     def validate(self, value): | ||||
|         """Make sure that a list of valid fields is being used.""" | ||||
|         """Make sure that a list of valid fields is being used. | ||||
|         """ | ||||
|         if not isinstance(value, dict): | ||||
|             self.error('Only dictionaries may be used in a DictField') | ||||
|  | ||||
|         if key_not_string(value): | ||||
|             msg = ('Invalid dictionary key - documents must ' | ||||
|                    'have only string keys') | ||||
|             msg = ("Invalid dictionary key - documents must " | ||||
|                    "have only string keys") | ||||
|             self.error(msg) | ||||
|         if key_has_dot_or_dollar(value): | ||||
|             self.error('Invalid dictionary key name - keys may not contain "."' | ||||
| @@ -831,15 +835,14 @@ class DictField(ComplexBaseField): | ||||
|                            'istartswith', 'endswith', 'iendswith', | ||||
|                            'exact', 'iexact'] | ||||
|  | ||||
|         if op in match_operators and isinstance(value, six.string_types): | ||||
|         if op in match_operators and isinstance(value, basestring): | ||||
|             return StringField().prepare_query_value(op, value) | ||||
|  | ||||
|         if hasattr(self.field, 'field'): | ||||
|             if op in ('set', 'unset') and isinstance(value, dict): | ||||
|                 return { | ||||
|                     k: self.field.prepare_query_value(op, v) | ||||
|                     for k, v in value.items() | ||||
|                 } | ||||
|                 return dict( | ||||
|                     (k, self.field.prepare_query_value(op, v)) | ||||
|                     for k, v in value.items()) | ||||
|             return self.field.prepare_query_value(op, value) | ||||
|  | ||||
|         return super(DictField, self).prepare_query_value(op, value) | ||||
| @@ -888,6 +891,10 @@ class ReferenceField(BaseField): | ||||
|  | ||||
|         Foo.register_delete_rule(Bar, 'foo', NULLIFY) | ||||
|  | ||||
|     .. note :: | ||||
|         `reverse_delete_rule` does not cause pre / post delete signals to be | ||||
|         triggered. | ||||
|  | ||||
|     .. versionchanged:: 0.5 added `reverse_delete_rule` | ||||
|     """ | ||||
|  | ||||
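A runnable sketch of `reverse_delete_rule` (models hypothetical; note, per the warning above, that the cascade bypasses the referenced documents' delete signals):

```python
from mongoengine import (CASCADE, Document, ReferenceField,
                         StringField, connect)

connect('example_db')

class Author(Document):
    name = StringField()

class Book(Document):
    # deleting an Author also deletes that Author's Books
    author = ReferenceField(Author, reverse_delete_rule=CASCADE)

a = Author(name='Iain').save()
Book(author=a).save()
a.delete()
assert Book.objects.count() == 0
```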
| @@ -904,12 +911,10 @@ class ReferenceField(BaseField): | ||||
|             A reference to an abstract document type is always stored as a | ||||
|             :class:`~pymongo.dbref.DBRef`, regardless of the value of `dbref`. | ||||
|         """ | ||||
|         if ( | ||||
|             not isinstance(document_type, six.string_types) and | ||||
|             not issubclass(document_type, Document) | ||||
|         ): | ||||
|             self.error('Argument to ReferenceField constructor must be a ' | ||||
|                        'document class or a string') | ||||
|         if not isinstance(document_type, basestring): | ||||
|             if not issubclass(document_type, (Document, basestring)): | ||||
|                 self.error('Argument to ReferenceField constructor must be a ' | ||||
|                            'document class or a string') | ||||
|  | ||||
|         self.dbref = dbref | ||||
|         self.document_type_obj = document_type | ||||
| @@ -918,7 +923,7 @@ class ReferenceField(BaseField): | ||||
|  | ||||
|     @property | ||||
|     def document_type(self): | ||||
|         if isinstance(self.document_type_obj, six.string_types): | ||||
|         if isinstance(self.document_type_obj, basestring): | ||||
|             if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT: | ||||
|                 self.document_type_obj = self.owner_document | ||||
|             else: | ||||
| @@ -926,7 +931,8 @@ class ReferenceField(BaseField): | ||||
|         return self.document_type_obj | ||||
|  | ||||
|     def __get__(self, instance, owner): | ||||
|         """Descriptor to allow lazy dereferencing.""" | ||||
|         """Descriptor to allow lazy dereferencing. | ||||
|         """ | ||||
|         if instance is None: | ||||
|             # Document class being used rather than a document object | ||||
|             return self | ||||
| @@ -983,7 +989,8 @@ class ReferenceField(BaseField): | ||||
|         return id_ | ||||
|  | ||||
|     def to_python(self, value): | ||||
|         """Convert a MongoDB-compatible type to a Python type.""" | ||||
|         """Convert a MongoDB-compatible type to a Python type. | ||||
|         """ | ||||
|         if (not self.dbref and | ||||
|                 not isinstance(value, (DBRef, Document, EmbeddedDocument))): | ||||
|             collection = self.document_type._get_collection_name() | ||||
| @@ -999,7 +1006,7 @@ class ReferenceField(BaseField): | ||||
|     def validate(self, value): | ||||
|  | ||||
|         if not isinstance(value, (self.document_type, DBRef)): | ||||
|             self.error('A ReferenceField only accepts DBRef or documents') | ||||
|             self.error("A ReferenceField only accepts DBRef or documents") | ||||
|  | ||||
|         if isinstance(value, Document) and value.id is None: | ||||
|             self.error('You can only reference documents once they have been ' | ||||
| @@ -1023,19 +1030,14 @@ class CachedReferenceField(BaseField): | ||||
|     .. versionadded:: 0.9 | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, document_type, fields=None, auto_sync=True, **kwargs): | ||||
|     def __init__(self, document_type, fields=[], auto_sync=True, **kwargs): | ||||
|         """Initialises the Cached Reference Field. | ||||
|  | ||||
|         :param fields:  A list of fields to be cached in document | ||||
|         :param auto_sync: if True documents are auto updated. | ||||
|         """ | ||||
|         if fields is None: | ||||
|             fields = [] | ||||
|  | ||||
|         if ( | ||||
|             not isinstance(document_type, six.string_types) and | ||||
|             not issubclass(document_type, Document) | ||||
|         ): | ||||
|         if not isinstance(document_type, basestring) and \ | ||||
|                 not issubclass(document_type, (Document, basestring)): | ||||
|             self.error('Argument to CachedReferenceField constructor must be a' | ||||
|                        ' document class or a string') | ||||
|  | ||||
| @@ -1051,20 +1053,18 @@ class CachedReferenceField(BaseField): | ||||
|                                   sender=self.document_type) | ||||
|  | ||||
|     def on_document_pre_save(self, sender, document, created, **kwargs): | ||||
|         if created: | ||||
|             return None | ||||
|         if not created: | ||||
|             update_kwargs = dict( | ||||
|                 ('set__%s__%s' % (self.name, k), v) | ||||
|                 for k, v in document._delta()[0].items() | ||||
|                 if k in self.fields) | ||||
|  | ||||
|         update_kwargs = { | ||||
|             'set__%s__%s' % (self.name, key): val | ||||
|             for key, val in document._delta()[0].items() | ||||
|             if key in self.fields | ||||
|         } | ||||
|         if update_kwargs: | ||||
|             filter_kwargs = {} | ||||
|             filter_kwargs[self.name] = document | ||||
|             if update_kwargs: | ||||
|                 filter_kwargs = {} | ||||
|                 filter_kwargs[self.name] = document | ||||
|  | ||||
|             self.owner_document.objects( | ||||
|                 **filter_kwargs).update(**update_kwargs) | ||||
|                 self.owner_document.objects( | ||||
|                     **filter_kwargs).update(**update_kwargs) | ||||
|  | ||||
|     def to_python(self, value): | ||||
|         if isinstance(value, dict): | ||||
| @@ -1077,7 +1077,7 @@ class CachedReferenceField(BaseField): | ||||
|  | ||||
|     @property | ||||
|     def document_type(self): | ||||
|         if isinstance(self.document_type_obj, six.string_types): | ||||
|         if isinstance(self.document_type_obj, basestring): | ||||
|             if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT: | ||||
|                 self.document_type_obj = self.owner_document | ||||
|             else: | ||||
| @@ -1117,7 +1117,7 @@ class CachedReferenceField(BaseField): | ||||
|             # TODO: should raise here, otherwise the next statement will fail | ||||
|  | ||||
|         value = SON(( | ||||
|             ('_id', id_field.to_mongo(id_)), | ||||
|             ("_id", id_field.to_mongo(id_)), | ||||
|         )) | ||||
|  | ||||
|         if fields: | ||||
| @@ -1143,7 +1143,7 @@ class CachedReferenceField(BaseField): | ||||
|     def validate(self, value): | ||||
|  | ||||
|         if not isinstance(value, self.document_type): | ||||
|             self.error('A CachedReferenceField only accepts documents') | ||||
|             self.error("A CachedReferenceField only accepts documents") | ||||
|  | ||||
|         if isinstance(value, Document) and value.id is None: | ||||
|             self.error('You can only reference documents once they have been ' | ||||
| @@ -1191,13 +1191,13 @@ class GenericReferenceField(BaseField): | ||||
|         # Keep the choices as a list of allowed Document class names | ||||
|         if choices: | ||||
|             for choice in choices: | ||||
|                 if isinstance(choice, six.string_types): | ||||
|                 if isinstance(choice, basestring): | ||||
|                     self.choices.append(choice) | ||||
|                 elif isinstance(choice, type) and issubclass(choice, Document): | ||||
|                     self.choices.append(choice._class_name) | ||||
|                 else: | ||||
|                     self.error('Invalid choices provided: must be a list of' | ||||
|                                'Document subclasses and/or strings') | ||||
|                                'Document subclasses and/or basestrings') | ||||
|  | ||||
|     def _validate_choices(self, value): | ||||
|         if isinstance(value, dict): | ||||
| @@ -1249,7 +1249,7 @@ class GenericReferenceField(BaseField): | ||||
|         if document is None: | ||||
|             return None | ||||
|  | ||||
|         if isinstance(document, (dict, SON, ObjectId, DBRef)): | ||||
|         if isinstance(document, (dict, SON)): | ||||
|             return document | ||||
|  | ||||
|         id_field_name = document.__class__._meta['id_field'] | ||||
| @@ -1280,7 +1280,8 @@ class GenericReferenceField(BaseField): | ||||
|  | ||||
|  | ||||
| class BinaryField(BaseField): | ||||
|     """A binary data field.""" | ||||
|     """A binary data field. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, max_bytes=None, **kwargs): | ||||
|         self.max_bytes = max_bytes | ||||
| @@ -1288,18 +1289,18 @@ class BinaryField(BaseField): | ||||
|  | ||||
|     def __set__(self, instance, value): | ||||
|         """Handle bytearrays in python 3.1""" | ||||
|         if six.PY3 and isinstance(value, bytearray): | ||||
|             value = six.binary_type(value) | ||||
|         if PY3 and isinstance(value, bytearray): | ||||
|             value = bin_type(value) | ||||
|         return super(BinaryField, self).__set__(instance, value) | ||||
|  | ||||
|     def to_mongo(self, value): | ||||
|         return Binary(value) | ||||
|  | ||||
|     def validate(self, value): | ||||
|         if not isinstance(value, (six.binary_type, six.text_type, Binary)): | ||||
|             self.error('BinaryField only accepts instances of ' | ||||
|                        '(%s, %s, Binary)' % ( | ||||
|                            six.binary_type.__name__, six.text_type.__name__)) | ||||
|         if not isinstance(value, (bin_type, txt_type, Binary)): | ||||
|             self.error("BinaryField only accepts instances of " | ||||
|                        "(%s, %s, Binary)" % ( | ||||
|                            bin_type.__name__, txt_type.__name__)) | ||||
|  | ||||
|         if self.max_bytes is not None and len(value) > self.max_bytes: | ||||
|             self.error('Binary value is too long') | ||||
| @@ -1383,13 +1384,11 @@ class GridFSProxy(object): | ||||
|                 get_db(self.db_alias), self.collection_name) | ||||
|         return self._fs | ||||
|  | ||||
|     def get(self, grid_id=None): | ||||
|         if grid_id: | ||||
|             self.grid_id = grid_id | ||||
|  | ||||
|     def get(self, id=None): | ||||
|         if id: | ||||
|             self.grid_id = id | ||||
|         if self.grid_id is None: | ||||
|             return None | ||||
|  | ||||
|         try: | ||||
|             if self.gridout is None: | ||||
|                 self.gridout = self.fs.get(self.grid_id) | ||||
| @@ -1433,7 +1432,7 @@ class GridFSProxy(object): | ||||
|             try: | ||||
|                 return gridout.read(size) | ||||
|             except Exception: | ||||
|                 return '' | ||||
|                 return "" | ||||
|  | ||||
|     def delete(self): | ||||
|         # Delete file from GridFS, FileField still remains | ||||
| @@ -1465,8 +1464,9 @@ class FileField(BaseField): | ||||
|     """ | ||||
|     proxy_class = GridFSProxy | ||||
|  | ||||
|     def __init__(self, db_alias=DEFAULT_CONNECTION_NAME, collection_name='fs', | ||||
|                  **kwargs): | ||||
|     def __init__(self, | ||||
|                  db_alias=DEFAULT_CONNECTION_NAME, | ||||
|                  collection_name="fs", **kwargs): | ||||
|         super(FileField, self).__init__(**kwargs) | ||||
|         self.collection_name = collection_name | ||||
|         self.db_alias = db_alias | ||||
| @@ -1488,10 +1488,8 @@ class FileField(BaseField): | ||||
|  | ||||
|     def __set__(self, instance, value): | ||||
|         key = self.name | ||||
|         if ( | ||||
|             (hasattr(value, 'read') and not isinstance(value, GridFSProxy)) or | ||||
|             isinstance(value, (six.binary_type, six.string_types)) | ||||
|         ): | ||||
|         if ((hasattr(value, 'read') and not | ||||
|                 isinstance(value, GridFSProxy)) or isinstance(value, str_types)): | ||||
|             # using "FileField() = file/string" notation | ||||
|             grid_file = instance._data.get(self.name) | ||||
|             # If a file already exists, delete it | ||||
| @@ -1560,7 +1558,7 @@ class ImageGridFsProxy(GridFSProxy): | ||||
|         try: | ||||
|             img = Image.open(file_obj) | ||||
|             img_format = img.format | ||||
|         except Exception as e: | ||||
|         except Exception, e: | ||||
|             raise ValidationError('Invalid image: %s' % e) | ||||
|  | ||||
|         # Progressive JPEG | ||||
| @@ -1669,10 +1667,10 @@ class ImageGridFsProxy(GridFSProxy): | ||||
|             return self.fs.get(out.thumbnail_id) | ||||
|  | ||||
|     def write(self, *args, **kwargs): | ||||
|         raise RuntimeError('Please use "put" method instead') | ||||
|         raise RuntimeError("Please use \"put\" method instead") | ||||
|  | ||||
|     def writelines(self, *args, **kwargs): | ||||
|         raise RuntimeError('Please use "put" method instead') | ||||
|         raise RuntimeError("Please use \"put\" method instead") | ||||
|  | ||||
|  | ||||
| class ImproperlyConfigured(Exception): | ||||
| @@ -1697,17 +1695,14 @@ class ImageField(FileField): | ||||
|     def __init__(self, size=None, thumbnail_size=None, | ||||
|                  collection_name='images', **kwargs): | ||||
|         if not Image: | ||||
|             raise ImproperlyConfigured('PIL library was not found') | ||||
|             raise ImproperlyConfigured("PIL library was not found") | ||||
|  | ||||
|         params_size = ('width', 'height', 'force') | ||||
|         extra_args = { | ||||
|             'size': size, | ||||
|             'thumbnail_size': thumbnail_size | ||||
|         } | ||||
|         extra_args = dict(size=size, thumbnail_size=thumbnail_size) | ||||
|         for att_name, att in extra_args.items(): | ||||
|             value = None | ||||
|             if isinstance(att, (tuple, list)): | ||||
|                 if six.PY3: | ||||
|                 if PY3: | ||||
|                     value = dict(itertools.zip_longest(params_size, att, | ||||
|                                                        fillvalue=None)) | ||||
|                 else: | ||||
| @@ -1768,10 +1763,10 @@ class SequenceField(BaseField): | ||||
|         Generate and Increment the counter | ||||
|         """ | ||||
|         sequence_name = self.get_sequence_name() | ||||
|         sequence_id = '%s.%s' % (sequence_name, self.name) | ||||
|         sequence_id = "%s.%s" % (sequence_name, self.name) | ||||
|         collection = get_db(alias=self.db_alias)[self.collection_name] | ||||
|         counter = collection.find_and_modify(query={'_id': sequence_id}, | ||||
|                                              update={'$inc': {'next': 1}}, | ||||
|         counter = collection.find_and_modify(query={"_id": sequence_id}, | ||||
|                                              update={"$inc": {"next": 1}}, | ||||
|                                              new=True, | ||||
|                                              upsert=True) | ||||
|         return self.value_decorator(counter['next']) | ||||
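The generate() body above is an atomic counter: a single find_and_modify with $inc and upsert=True both creates the counter document on first use and increments it, so concurrent writers can never draw the same value. A standalone sketch of the same pattern in raw PyMongo (find_and_modify is the PyMongo 2.x spelling; the database and sequence names here are illustrative and assume a local mongod):

    import pymongo

    counters = pymongo.MongoClient()['test_db']['mongoengine.counters']

    def next_value(sequence_id):
        # new=True returns the post-increment document; upsert=True
        # creates {'_id': sequence_id, 'next': 1} on the first call.
        counter = counters.find_and_modify(query={'_id': sequence_id},
                                           update={'$inc': {'next': 1}},
                                           new=True,
                                           upsert=True)
        return counter['next']

    print(next_value('person.id'))  # 1 on the first call, then 2, 3, ...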
| @@ -1794,9 +1789,9 @@ class SequenceField(BaseField): | ||||
|         as it is only fixed on set. | ||||
|         """ | ||||
|         sequence_name = self.get_sequence_name() | ||||
|         sequence_id = '%s.%s' % (sequence_name, self.name) | ||||
|         sequence_id = "%s.%s" % (sequence_name, self.name) | ||||
|         collection = get_db(alias=self.db_alias)[self.collection_name] | ||||
|         data = collection.find_one({'_id': sequence_id}) | ||||
|         data = collection.find_one({"_id": sequence_id}) | ||||
|  | ||||
|         if data: | ||||
|             return self.value_decorator(data['next'] + 1) | ||||
| @@ -1866,8 +1861,8 @@ class UUIDField(BaseField): | ||||
|         if not self._binary: | ||||
|             original_value = value | ||||
|             try: | ||||
|                 if not isinstance(value, six.string_types): | ||||
|                     value = six.text_type(value) | ||||
|                 if not isinstance(value, basestring): | ||||
|                     value = unicode(value) | ||||
|                 return uuid.UUID(value) | ||||
|             except Exception: | ||||
|                 return original_value | ||||
| @@ -1875,8 +1870,8 @@ class UUIDField(BaseField): | ||||
|  | ||||
|     def to_mongo(self, value): | ||||
|         if not self._binary: | ||||
|             return six.text_type(value) | ||||
|         elif isinstance(value, six.string_types): | ||||
|             return unicode(value) | ||||
|         elif isinstance(value, basestring): | ||||
|             return uuid.UUID(value) | ||||
|         return value | ||||
|  | ||||
| @@ -1887,11 +1882,11 @@ class UUIDField(BaseField): | ||||
|  | ||||
|     def validate(self, value): | ||||
|         if not isinstance(value, uuid.UUID): | ||||
|             if not isinstance(value, six.string_types): | ||||
|             if not isinstance(value, basestring): | ||||
|                 value = str(value) | ||||
|             try: | ||||
|                 uuid.UUID(value) | ||||
|             except Exception as exc: | ||||
|             except Exception, exc: | ||||
|                 self.error('Could not convert to UUID: %s' % exc) | ||||
|  | ||||
|  | ||||
| @@ -1909,18 +1904,19 @@ class GeoPointField(BaseField): | ||||
|     _geo_index = pymongo.GEO2D | ||||
|  | ||||
|     def validate(self, value): | ||||
|         """Make sure that a geo-value is of type (x, y)""" | ||||
|         """Make sure that a geo-value is of type (x, y) | ||||
|         """ | ||||
|         if not isinstance(value, (list, tuple)): | ||||
|             self.error('GeoPointField can only accept tuples or lists ' | ||||
|                        'of (x, y)') | ||||
|  | ||||
|         if not len(value) == 2: | ||||
|             self.error('Value (%s) must be a two-dimensional point' % | ||||
|             self.error("Value (%s) must be a two-dimensional point" % | ||||
|                        repr(value)) | ||||
|         elif (not isinstance(value[0], (float, int)) or | ||||
|               not isinstance(value[1], (float, int))): | ||||
|             self.error( | ||||
|                 'Both values (%s) in point must be float or int' % repr(value)) | ||||
|                 "Both values (%s) in point must be float or int" % repr(value)) | ||||
|  | ||||
|  | ||||
| class PointField(GeoJsonBaseField): | ||||
| @@ -1930,8 +1926,8 @@ class PointField(GeoJsonBaseField): | ||||
|  | ||||
|     .. code-block:: js | ||||
|  | ||||
|         {'type' : 'Point' , | ||||
|          'coordinates' : [x, y]} | ||||
|         { "type" : "Point" , | ||||
|           "coordinates" : [x, y]} | ||||
|  | ||||
|     You can either pass a dict with the full information or a list | ||||
|     to set the value. | ||||
| @@ -1940,7 +1936,7 @@ class PointField(GeoJsonBaseField): | ||||
|  | ||||
|     .. versionadded:: 0.8 | ||||
|     """ | ||||
|     _type = 'Point' | ||||
|     _type = "Point" | ||||
|  | ||||
|  | ||||
| class LineStringField(GeoJsonBaseField): | ||||
| @@ -1950,8 +1946,8 @@ class LineStringField(GeoJsonBaseField): | ||||
|  | ||||
|     .. code-block:: js | ||||
|  | ||||
|         {'type' : 'LineString' , | ||||
|          'coordinates' : [[x1, y1], [x1, y1] ... [xn, yn]]} | ||||
|         { "type" : "LineString" , | ||||
|           "coordinates" : [[x1, y1], [x1, y1] ... [xn, yn]]} | ||||
|  | ||||
|     You can either pass a dict with the full information or a list of points. | ||||
|  | ||||
| @@ -1959,7 +1955,7 @@ class LineStringField(GeoJsonBaseField): | ||||
|  | ||||
|     .. versionadded:: 0.8 | ||||
|     """ | ||||
|     _type = 'LineString' | ||||
|     _type = "LineString" | ||||
|  | ||||
|  | ||||
| class PolygonField(GeoJsonBaseField): | ||||
| @@ -1969,9 +1965,9 @@ class PolygonField(GeoJsonBaseField): | ||||
|  | ||||
|     .. code-block:: js | ||||
|  | ||||
|         {'type' : 'Polygon' , | ||||
|          'coordinates' : [[[x1, y1], [x1, y1] ... [xn, yn]], | ||||
|                           [[x1, y1], [x1, y1] ... [xn, yn]]} | ||||
|         { "type" : "Polygon" , | ||||
|           "coordinates" : [[[x1, y1], [x1, y1] ... [xn, yn]], | ||||
|                            [[x1, y1], [x1, y1] ... [xn, yn]]} | ||||
|  | ||||
|     You can either pass a dict with the full information or a list | ||||
|     of LineStrings. The first LineString being the outside and the rest being | ||||
| @@ -1981,7 +1977,7 @@ class PolygonField(GeoJsonBaseField): | ||||
|  | ||||
|     .. versionadded:: 0.8 | ||||
|     """ | ||||
|     _type = 'Polygon' | ||||
|     _type = "Polygon" | ||||
|  | ||||
|  | ||||
| class MultiPointField(GeoJsonBaseField): | ||||
| @@ -1991,8 +1987,8 @@ class MultiPointField(GeoJsonBaseField): | ||||
|  | ||||
|     .. code-block:: js | ||||
|  | ||||
|         {'type' : 'MultiPoint' , | ||||
|          'coordinates' : [[x1, y1], [x2, y2]]} | ||||
|         { "type" : "MultiPoint" , | ||||
|           "coordinates" : [[x1, y1], [x2, y2]]} | ||||
|  | ||||
|     You can either pass a dict with the full information or a list | ||||
|     to set the value. | ||||
| @@ -2001,7 +1997,7 @@ class MultiPointField(GeoJsonBaseField): | ||||
|  | ||||
|     .. versionadded:: 0.9 | ||||
|     """ | ||||
|     _type = 'MultiPoint' | ||||
|     _type = "MultiPoint" | ||||
|  | ||||
|  | ||||
| class MultiLineStringField(GeoJsonBaseField): | ||||
| @@ -2011,9 +2007,9 @@ class MultiLineStringField(GeoJsonBaseField): | ||||
|  | ||||
|     .. code-block:: js | ||||
|  | ||||
|         {'type' : 'MultiLineString' , | ||||
|          'coordinates' : [[[x1, y1], [x1, y1] ... [xn, yn]], | ||||
|                           [[x1, y1], [x1, y1] ... [xn, yn]]]} | ||||
|         { "type" : "MultiLineString" , | ||||
|           "coordinates" : [[[x1, y1], [x1, y1] ... [xn, yn]], | ||||
|                            [[x1, y1], [x1, y1] ... [xn, yn]]]} | ||||
|  | ||||
|     You can either pass a dict with the full information or a list of points. | ||||
|  | ||||
| @@ -2021,7 +2017,7 @@ class MultiLineStringField(GeoJsonBaseField): | ||||
|  | ||||
|     .. versionadded:: 0.9 | ||||
|     """ | ||||
|     _type = 'MultiLineString' | ||||
|     _type = "MultiLineString" | ||||
|  | ||||
|  | ||||
| class MultiPolygonField(GeoJsonBaseField): | ||||
| @@ -2031,14 +2027,14 @@ class MultiPolygonField(GeoJsonBaseField): | ||||
|  | ||||
|     .. code-block:: js | ||||
|  | ||||
|         {'type' : 'MultiPolygon' , | ||||
|          'coordinates' : [[ | ||||
|                [[x1, y1], [x1, y1] ... [xn, yn]], | ||||
|                [[x1, y1], [x1, y1] ... [xn, yn]] | ||||
|            ], [ | ||||
|                [[x1, y1], [x1, y1] ... [xn, yn]], | ||||
|                [[x1, y1], [x1, y1] ... [xn, yn]] | ||||
|            ] | ||||
|         { "type" : "MultiPolygon" , | ||||
|           "coordinates" : [[ | ||||
|                 [[x1, y1], [x1, y1] ... [xn, yn]], | ||||
|                 [[x1, y1], [x1, y1] ... [xn, yn]] | ||||
|             ], [ | ||||
|                 [[x1, y1], [x1, y1] ... [xn, yn]], | ||||
|                 [[x1, y1], [x1, y1] ... [xn, yn]] | ||||
|             ] | ||||
|         } | ||||
|  | ||||
|     You can either pass a dict with the full information or a list | ||||
| @@ -2048,4 +2044,4 @@ class MultiPolygonField(GeoJsonBaseField): | ||||
|  | ||||
|     .. versionadded:: 0.9 | ||||
|     """ | ||||
|     _type = 'MultiPolygon' | ||||
|     _type = "MultiPolygon" | ||||
|   | ||||
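All of the GeoJSON fields above share the storage format shown in their docstrings. For orientation, a brief sketch of declaring and querying one of them, assuming an illustrative Place model and a running local mongod:

    from mongoengine import Document, StringField, PointField, connect

    connect('test_db')

    class Place(Document):
        name = StringField()
        # Stored as {"type": "Point", "coordinates": [x, y]};
        # the field registers a 2dsphere index automatically.
        location = PointField()

    Place(name='office', location=[-0.1, 51.5]).save()
    nearby = Place.objects(location__near=[-0.1, 51.5])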
| @@ -1,9 +1,7 @@ | ||||
| """ | ||||
| Helper functions, constants, and types to aid with Python v2.7 - v3.x and | ||||
| PyMongo v2.7 - v3.x support. | ||||
| """ | ||||
| """Helper functions and types to aid with Python 2.5 - 3 support.""" | ||||
|  | ||||
| import sys | ||||
| import pymongo | ||||
| import six | ||||
|  | ||||
|  | ||||
| if pymongo.version_tuple[0] < 3: | ||||
| @@ -11,15 +9,29 @@ if pymongo.version_tuple[0] < 3: | ||||
| else: | ||||
|     IS_PYMONGO_3 = True | ||||
|  | ||||
| PY3 = sys.version_info[0] == 3 | ||||
|  | ||||
| # six.BytesIO resolves to StringIO.StringIO in Py2 and io.BytesIO in Py3. | ||||
| StringIO = six.BytesIO | ||||
| if PY3: | ||||
|     import codecs | ||||
|     from io import BytesIO as StringIO | ||||
|  | ||||
| # Additionally for Py2, try to use the faster cStringIO, if available | ||||
| if not six.PY3: | ||||
|     # return s converted to binary.  b('test') should be equivalent to b'test' | ||||
|     def b(s): | ||||
|         return codecs.latin_1_encode(s)[0] | ||||
|  | ||||
|     bin_type = bytes | ||||
|     txt_type = str | ||||
| else: | ||||
|     try: | ||||
|         import cStringIO | ||||
|         from cStringIO import StringIO | ||||
|     except ImportError: | ||||
|         pass | ||||
|     else: | ||||
|         StringIO = cStringIO.StringIO | ||||
|         from StringIO import StringIO | ||||
|  | ||||
|     # Conversion to binary only necessary in Python 3 | ||||
|     def b(s): | ||||
|         return s | ||||
|  | ||||
|     bin_type = str | ||||
|     txt_type = unicode | ||||
|  | ||||
| str_types = (bin_type, txt_type) | ||||
|   | ||||
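Condensed, the hunk above replaces six with a few hand-maintained compatibility names. A self-contained sketch of the same logic, runnable on either interpreter:

    import sys

    PY3 = sys.version_info[0] == 3

    if PY3:
        import codecs
        from io import BytesIO as StringIO

        def b(s):
            # return s converted to binary: b('test') == b'test'
            return codecs.latin_1_encode(s)[0]

        bin_type = bytes
        txt_type = str
    else:
        try:
            # Prefer the faster cStringIO where available.
            from cStringIO import StringIO
        except ImportError:
            from StringIO import StringIO

        def b(s):
            return s  # str is already binary under Python 2

        bin_type = str
        txt_type = unicode

    str_types = (bin_type, txt_type)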
| @@ -1,17 +1,11 @@ | ||||
| from mongoengine.errors import * | ||||
| from mongoengine.errors import (DoesNotExist, InvalidQueryError, | ||||
|                                 MultipleObjectsReturned, NotUniqueError, | ||||
|                                 OperationError) | ||||
| from mongoengine.queryset.field_list import * | ||||
| from mongoengine.queryset.manager import * | ||||
| from mongoengine.queryset.queryset import * | ||||
| from mongoengine.queryset.transform import * | ||||
| from mongoengine.queryset.visitor import * | ||||
|  | ||||
| # Expose just the public subset of all imported objects and constants. | ||||
| __all__ = ( | ||||
|     'QuerySet', 'QuerySetNoCache', 'Q', 'queryset_manager', 'QuerySetManager', | ||||
|     'QueryFieldList', 'DO_NOTHING', 'NULLIFY', 'CASCADE', 'DENY', 'PULL', | ||||
|  | ||||
|     # Errors that might be related to a queryset, mostly here for backward | ||||
|     # compatibility | ||||
|     'DoesNotExist', 'InvalidQueryError', 'MultipleObjectsReturned', | ||||
|     'NotUniqueError', 'OperationError', | ||||
| ) | ||||
| __all__ = (field_list.__all__ + manager.__all__ + queryset.__all__ + | ||||
|            transform.__all__ + visitor.__all__) | ||||
|   | ||||
File diff suppressed because it is too large
							| @@ -67,7 +67,7 @@ class QueryFieldList(object): | ||||
|         return bool(self.fields) | ||||
|  | ||||
|     def as_dict(self): | ||||
|         field_list = {field: self.value for field in self.fields} | ||||
|         field_list = dict((field, self.value) for field in self.fields) | ||||
|         if self.slice: | ||||
|             field_list.update(self.slice) | ||||
|         if self._id is not None: | ||||
|   | ||||
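The as_dict() change above is a Python 2.6 concession: dict comprehensions only arrived in Python 2.7, while dict() over a generator expression builds the same mapping everywhere. For example:

    fields, value = ['name', 'age'], 1

    # Python 2.7+/3.x only: {'name': 1, 'age': 1}
    projection = {field: value for field in fields}

    # Equivalent, and also valid on Python 2.6:
    projection = dict((field, value) for field in fields)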
| @@ -27,10 +27,9 @@ class QuerySet(BaseQuerySet): | ||||
|         in batches of ``ITER_CHUNK_SIZE``. | ||||
|  | ||||
|         If ``self._has_more`` the cursor hasn't been exhausted so cache then | ||||
|         batch. Otherwise iterate the result_cache. | ||||
|         batch.  Otherwise iterate the result_cache. | ||||
|         """ | ||||
|         self._iter = True | ||||
|  | ||||
|         if self._has_more: | ||||
|             return self._iter_results() | ||||
|  | ||||
| @@ -43,56 +42,40 @@ class QuerySet(BaseQuerySet): | ||||
|         """ | ||||
|         if self._len is not None: | ||||
|             return self._len | ||||
|  | ||||
|         # Populate the result cache with *all* of the docs in the cursor | ||||
|         if self._has_more: | ||||
|             # populate the cache | ||||
|             list(self._iter_results()) | ||||
|  | ||||
|         # Cache the length of the complete result cache and return it | ||||
|         self._len = len(self._result_cache) | ||||
|         return self._len | ||||
|  | ||||
|     def __repr__(self): | ||||
|         """Provide a string representation of the QuerySet""" | ||||
|         """Provides the string representation of the QuerySet | ||||
|         """ | ||||
|         if self._iter: | ||||
|             return '.. queryset mid-iteration ..' | ||||
|  | ||||
|         self._populate_cache() | ||||
|         data = self._result_cache[:REPR_OUTPUT_SIZE + 1] | ||||
|         if len(data) > REPR_OUTPUT_SIZE: | ||||
|             data[-1] = '...(remaining elements truncated)...' | ||||
|             data[-1] = "...(remaining elements truncated)..." | ||||
|         return repr(data) | ||||
|  | ||||
|     def _iter_results(self): | ||||
|         """A generator for iterating over the result cache. | ||||
|  | ||||
|         Also populates the cache if there are more possible results to | ||||
|         yield. Raises StopIteration when there are no more results. | ||||
|         """ | ||||
|         Also populates the cache if there are more possible results to yield. | ||||
|         Raises StopIteration when there are no more results""" | ||||
|         if self._result_cache is None: | ||||
|             self._result_cache = [] | ||||
|  | ||||
|         pos = 0 | ||||
|         while True: | ||||
|  | ||||
|             # For all positions lower than the length of the current result | ||||
|             # cache, serve the docs straight from the cache w/o hitting the | ||||
|             # database. | ||||
|             # XXX it's VERY important to compute the len within the `while` | ||||
|             # condition because the result cache might expand mid-iteration | ||||
|             # (e.g. if we call len(qs) inside a loop that iterates over the | ||||
|             # queryset). Fortunately len(list) is O(1) in Python, so this | ||||
|             # doesn't cause performance issues. | ||||
|             while pos < len(self._result_cache): | ||||
|             upper = len(self._result_cache) | ||||
|             while pos < upper: | ||||
|                 yield self._result_cache[pos] | ||||
|                 pos += 1 | ||||
|  | ||||
|             # Raise StopIteration if we already established there were no more | ||||
|             # docs in the db cursor. | ||||
|             if not self._has_more: | ||||
|                 raise StopIteration | ||||
|  | ||||
|             # Otherwise, populate more of the cache and repeat. | ||||
|             if len(self._result_cache) <= pos: | ||||
|                 self._populate_cache() | ||||
|  | ||||
| @@ -103,22 +86,12 @@ class QuerySet(BaseQuerySet): | ||||
|         """ | ||||
|         if self._result_cache is None: | ||||
|             self._result_cache = [] | ||||
|  | ||||
|         # Skip populating the cache if we already established there are no | ||||
|         # more docs to pull from the database. | ||||
|         if not self._has_more: | ||||
|             return | ||||
|  | ||||
|         # Pull in ITER_CHUNK_SIZE docs from the database and store them in | ||||
|         # the result cache. | ||||
|         try: | ||||
|             for _ in xrange(ITER_CHUNK_SIZE): | ||||
|                 self._result_cache.append(self.next()) | ||||
|         except StopIteration: | ||||
|             # Getting this exception means there are no more docs in the | ||||
|             # db cursor. Set _has_more to False so that we can use that | ||||
|             # information in other places. | ||||
|             self._has_more = False | ||||
|         if self._has_more: | ||||
|             try: | ||||
|                 for i in xrange(ITER_CHUNK_SIZE): | ||||
|                     self._result_cache.append(self.next()) | ||||
|             except StopIteration: | ||||
|                 self._has_more = False | ||||
|  | ||||
|     def count(self, with_limit_and_skip=False): | ||||
|         """Count the selected elements in the query. | ||||
| @@ -136,15 +109,13 @@ class QuerySet(BaseQuerySet): | ||||
|         return self._len | ||||
|  | ||||
|     def no_cache(self): | ||||
|         """Convert to a non-caching queryset | ||||
|         """Convert to a non_caching queryset | ||||
|  | ||||
|         .. versionadded:: 0.8.3 Convert to non caching queryset | ||||
|         """ | ||||
|         if self._result_cache is not None: | ||||
|             raise OperationError('QuerySet already cached') | ||||
|  | ||||
|         return self._clone_into(QuerySetNoCache(self._document, | ||||
|                                                 self._collection)) | ||||
|             raise OperationError("QuerySet already cached") | ||||
|         return self.clone_into(QuerySetNoCache(self._document, self._collection)) | ||||
|  | ||||
|  | ||||
| class QuerySetNoCache(BaseQuerySet): | ||||
| @@ -155,7 +126,7 @@ class QuerySetNoCache(BaseQuerySet): | ||||
|  | ||||
|         .. versionadded:: 0.8.3 Convert to caching queryset | ||||
|         """ | ||||
|         return self._clone_into(QuerySet(self._document, self._collection)) | ||||
|         return self.clone_into(QuerySet(self._document, self._collection)) | ||||
|  | ||||
|     def __repr__(self): | ||||
|         """Provides the string representation of the QuerySet | ||||
| @@ -166,14 +137,13 @@ class QuerySetNoCache(BaseQuerySet): | ||||
|             return '.. queryset mid-iteration ..' | ||||
|  | ||||
|         data = [] | ||||
|         for _ in xrange(REPR_OUTPUT_SIZE + 1): | ||||
|         for i in xrange(REPR_OUTPUT_SIZE + 1): | ||||
|             try: | ||||
|                 data.append(self.next()) | ||||
|             except StopIteration: | ||||
|                 break | ||||
|  | ||||
|         if len(data) > REPR_OUTPUT_SIZE: | ||||
|             data[-1] = '...(remaining elements truncated)...' | ||||
|             data[-1] = "...(remaining elements truncated)..." | ||||
|  | ||||
|         self.rewind() | ||||
|         return repr(data) | ||||
|   | ||||
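Taken together, the queryset hunks above preserve one invariant: iteration first drains _result_cache, then tops it up ITER_CHUNK_SIZE documents at a time until the cursor is exhausted and _has_more goes False, so len() and repeated iteration never re-query MongoDB. A stripped-down sketch of that pattern over a plain iterator, with illustrative names:

    ITER_CHUNK_SIZE = 100

    class CachedResults(object):
        def __init__(self, cursor):
            self._cursor = iter(cursor)
            self._cache = []
            self._has_more = True

        def __iter__(self):
            pos = 0
            while True:
                # Serve everything already cached...
                while pos < len(self._cache):
                    yield self._cache[pos]
                    pos += 1
                if not self._has_more:
                    return
                # ...then pull the next chunk from the cursor.
                try:
                    for _ in range(ITER_CHUNK_SIZE):
                        self._cache.append(next(self._cursor))
                except StopIteration:
                    self._has_more = False

    results = CachedResults(iter(range(250)))
    assert list(results) == list(results)  # the second pass is cache-only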
| @@ -1,11 +1,9 @@ | ||||
| from collections import defaultdict | ||||
|  | ||||
| from bson import ObjectId, SON | ||||
| from bson.dbref import DBRef | ||||
| from bson import SON | ||||
| import pymongo | ||||
| import six | ||||
|  | ||||
| from mongoengine.base import UPDATE_OPERATORS | ||||
| from mongoengine.base.fields import UPDATE_OPERATORS | ||||
| from mongoengine.common import _import_class | ||||
| from mongoengine.connection import get_connection | ||||
| from mongoengine.errors import InvalidQueryError | ||||
| @@ -28,13 +26,13 @@ MATCH_OPERATORS = (COMPARISON_OPERATORS + GEO_OPERATORS + | ||||
|                    STRING_OPERATORS + CUSTOM_OPERATORS) | ||||
|  | ||||
|  | ||||
| # TODO make this less complex | ||||
| def query(_doc_cls=None, **kwargs): | ||||
|     """Transform a query from Django-style format to Mongo format.""" | ||||
|     """Transform a query from Django-style format to Mongo format. | ||||
|     """ | ||||
|     mongo_query = {} | ||||
|     merge_query = defaultdict(list) | ||||
|     for key, value in sorted(kwargs.items()): | ||||
|         if key == '__raw__': | ||||
|         if key == "__raw__": | ||||
|             mongo_query.update(value) | ||||
|             continue | ||||
|  | ||||
| @@ -47,7 +45,7 @@ def query(_doc_cls=None, **kwargs): | ||||
|             op = parts.pop() | ||||
|  | ||||
|         # Allow to escape operator-like field name by __ | ||||
|         if len(parts) > 1 and parts[-1] == '': | ||||
|         if len(parts) > 1 and parts[-1] == "": | ||||
|             parts.pop() | ||||
|  | ||||
|         negate = False | ||||
| @@ -59,17 +57,16 @@ def query(_doc_cls=None, **kwargs): | ||||
|             # Switch field names to proper names [set in Field(name='foo')] | ||||
|             try: | ||||
|                 fields = _doc_cls._lookup_field(parts) | ||||
|             except Exception as e: | ||||
|             except Exception, e: | ||||
|                 raise InvalidQueryError(e) | ||||
|             parts = [] | ||||
|  | ||||
|             CachedReferenceField = _import_class('CachedReferenceField') | ||||
|             GenericReferenceField = _import_class('GenericReferenceField') | ||||
|  | ||||
|             cleaned_fields = [] | ||||
|             for field in fields: | ||||
|                 append_field = True | ||||
|                 if isinstance(field, six.string_types): | ||||
|                 if isinstance(field, basestring): | ||||
|                     parts.append(field) | ||||
|                     append_field = False | ||||
|                 # is last and CachedReferenceField | ||||
| @@ -87,9 +84,9 @@ def query(_doc_cls=None, **kwargs): | ||||
|             singular_ops = [None, 'ne', 'gt', 'gte', 'lt', 'lte', 'not'] | ||||
|             singular_ops += STRING_OPERATORS | ||||
|             if op in singular_ops: | ||||
|                 if isinstance(field, six.string_types): | ||||
|                 if isinstance(field, basestring): | ||||
|                     if (op in STRING_OPERATORS and | ||||
|                             isinstance(value, six.string_types)): | ||||
|                             isinstance(value, basestring)): | ||||
|                         StringField = _import_class('StringField') | ||||
|                         value = StringField.prepare_query_value(op, value) | ||||
|                     else: | ||||
| @@ -101,31 +98,8 @@ def query(_doc_cls=None, **kwargs): | ||||
|                         value = value['_id'] | ||||
|  | ||||
|             elif op in ('in', 'nin', 'all', 'near') and not isinstance(value, dict): | ||||
|                 # Raise an error if the in/nin/all/near param is not iterable. We need a | ||||
|                 # special check for BaseDocument, because - although it's iterable - using | ||||
|                 # it as such in the context of this method is most definitely a mistake. | ||||
|                 BaseDocument = _import_class('BaseDocument') | ||||
|                 if isinstance(value, BaseDocument): | ||||
|                     raise TypeError("When using the `in`, `nin`, `all`, or " | ||||
|                                     "`near`-operators you can\'t use a " | ||||
|                                     "`Document`, you must wrap your object " | ||||
|                                     "in a list (object -> [object]).") | ||||
|                 elif not hasattr(value, '__iter__'): | ||||
|                     raise TypeError("The `in`, `nin`, `all`, or " | ||||
|                                     "`near`-operators must be applied to an " | ||||
|                                     "iterable (e.g. a list).") | ||||
|                 else: | ||||
|                     value = [field.prepare_query_value(op, v) for v in value] | ||||
|  | ||||
|             # If we're querying a GenericReferenceField, we need to alter the | ||||
|             # key depending on the value: | ||||
|             # * If the value is a DBRef, the key should be "field_name._ref". | ||||
|             # * If the value is an ObjectId, the key should be "field_name._ref.$id". | ||||
|             if isinstance(field, GenericReferenceField): | ||||
|                 if isinstance(value, DBRef): | ||||
|                     parts[-1] += '._ref' | ||||
|                 elif isinstance(value, ObjectId): | ||||
|                     parts[-1] += '._ref.$id' | ||||
|                 # 'in', 'nin' and 'all' require a list of values | ||||
|                 value = [field.prepare_query_value(op, v) for v in value] | ||||
|  | ||||
|         # if op and op not in COMPARISON_OPERATORS: | ||||
|         if op: | ||||
| @@ -142,10 +116,10 @@ def query(_doc_cls=None, **kwargs): | ||||
|                     value = query(field.field.document_type, **value) | ||||
|                 else: | ||||
|                     value = field.prepare_query_value(op, value) | ||||
|                 value = {'$elemMatch': value} | ||||
|                 value = {"$elemMatch": value} | ||||
|             elif op in CUSTOM_OPERATORS: | ||||
|                 NotImplementedError('Custom method "%s" has not ' | ||||
|                                     'been implemented' % op) | ||||
|                 NotImplementedError("Custom method '%s' has not " | ||||
|                                     "been implemented" % op) | ||||
|             elif op not in STRING_OPERATORS: | ||||
|                 value = {'$' + op: value} | ||||
|  | ||||
| @@ -154,13 +128,11 @@ def query(_doc_cls=None, **kwargs): | ||||
|  | ||||
|         for i, part in indices: | ||||
|             parts.insert(i, part) | ||||
|  | ||||
|         key = '.'.join(parts) | ||||
|  | ||||
|         if op is None or key not in mongo_query: | ||||
|             mongo_query[key] = value | ||||
|         elif key in mongo_query: | ||||
|             if isinstance(mongo_query[key], dict): | ||||
|             if key in mongo_query and isinstance(mongo_query[key], dict): | ||||
|                 mongo_query[key].update(value) | ||||
|                 # $max/minDistance needs to come last - convert to SON | ||||
|                 value_dict = mongo_query[key] | ||||
| @@ -210,16 +182,15 @@ def query(_doc_cls=None, **kwargs): | ||||
|  | ||||
|  | ||||
| def update(_doc_cls=None, **update): | ||||
|     """Transform an update spec from Django-style format to Mongo | ||||
|     format. | ||||
|     """Transform an update spec from Django-style format to Mongo format. | ||||
|     """ | ||||
|     mongo_update = {} | ||||
|     for key, value in update.items(): | ||||
|         if key == '__raw__': | ||||
|         if key == "__raw__": | ||||
|             mongo_update.update(value) | ||||
|             continue | ||||
|         parts = key.split('__') | ||||
|         # if there is no operator, default to 'set' | ||||
|         # if there is no operator, default to "set" | ||||
|         if len(parts) < 3 and parts[0] not in UPDATE_OPERATORS: | ||||
|             parts.insert(0, 'set') | ||||
|         # Check for an operator and transform to mongo-style if there is | ||||
| @@ -233,25 +204,26 @@ def update(_doc_cls=None, **update): | ||||
|                 # Support decrement by flipping a positive value's sign | ||||
|                 # and using 'inc' | ||||
|                 op = 'inc' | ||||
|                 value = -value | ||||
|                 if value > 0: | ||||
|                     value = -value | ||||
|             elif op == 'add_to_set': | ||||
|                 op = 'addToSet' | ||||
|             elif op == 'set_on_insert': | ||||
|                 op = 'setOnInsert' | ||||
|                 op = "setOnInsert" | ||||
|  | ||||
|         match = None | ||||
|         if parts[-1] in COMPARISON_OPERATORS: | ||||
|             match = parts.pop() | ||||
|  | ||||
|         # Allow to escape operator-like field name by __ | ||||
|         if len(parts) > 1 and parts[-1] == '': | ||||
|         if len(parts) > 1 and parts[-1] == "": | ||||
|             parts.pop() | ||||
|  | ||||
|         if _doc_cls: | ||||
|             # Switch field names to proper names [set in Field(name='foo')] | ||||
|             try: | ||||
|                 fields = _doc_cls._lookup_field(parts) | ||||
|             except Exception as e: | ||||
|             except Exception, e: | ||||
|                 raise InvalidQueryError(e) | ||||
|             parts = [] | ||||
|  | ||||
| @@ -259,7 +231,7 @@ def update(_doc_cls=None, **update): | ||||
|             appended_sub_field = False | ||||
|             for field in fields: | ||||
|                 append_field = True | ||||
|                 if isinstance(field, six.string_types): | ||||
|                 if isinstance(field, basestring): | ||||
|                     # Convert the S operator to $ | ||||
|                     if field == 'S': | ||||
|                         field = '$' | ||||
| @@ -280,7 +252,7 @@ def update(_doc_cls=None, **update): | ||||
|             else: | ||||
|                 field = cleaned_fields[-1] | ||||
|  | ||||
|             GeoJsonBaseField = _import_class('GeoJsonBaseField') | ||||
|             GeoJsonBaseField = _import_class("GeoJsonBaseField") | ||||
|             if isinstance(field, GeoJsonBaseField): | ||||
|                 value = field.to_mongo(value) | ||||
|  | ||||
| @@ -294,7 +266,7 @@ def update(_doc_cls=None, **update): | ||||
|                     value = [field.prepare_query_value(op, v) for v in value] | ||||
|                 elif field.required or value is not None: | ||||
|                     value = field.prepare_query_value(op, value) | ||||
|             elif op == 'unset': | ||||
|             elif op == "unset": | ||||
|                 value = 1 | ||||
|  | ||||
|         if match: | ||||
| @@ -304,16 +276,16 @@ def update(_doc_cls=None, **update): | ||||
|         key = '.'.join(parts) | ||||
|  | ||||
|         if not op: | ||||
|             raise InvalidQueryError('Updates must supply an operation ' | ||||
|                                     'eg: set__FIELD=value') | ||||
|             raise InvalidQueryError("Updates must supply an operation " | ||||
|                                     "eg: set__FIELD=value") | ||||
|  | ||||
|         if 'pull' in op and '.' in key: | ||||
|             # Dot operators don't work on pull operations | ||||
|             # unless they point to a list field | ||||
|             # Otherwise it uses nested dict syntax | ||||
|             if op == 'pullAll': | ||||
|                 raise InvalidQueryError('pullAll operations only support ' | ||||
|                                         'a single field depth') | ||||
|                 raise InvalidQueryError("pullAll operations only support " | ||||
|                                         "a single field depth") | ||||
|  | ||||
|             # Look for the last list field and use dot notation until there | ||||
|             field_classes = [c.__class__ for c in cleaned_fields] | ||||
| @@ -324,7 +296,7 @@ def update(_doc_cls=None, **update): | ||||
|                 # Then process as normal | ||||
|                 last_listField = len( | ||||
|                     cleaned_fields) - field_classes.index(ListField) | ||||
|                 key = '.'.join(parts[:last_listField]) | ||||
|                 key = ".".join(parts[:last_listField]) | ||||
|                 parts = parts[last_listField:] | ||||
|                 parts.insert(0, key) | ||||
|  | ||||
| @@ -332,7 +304,7 @@ def update(_doc_cls=None, **update): | ||||
|             for key in parts: | ||||
|                 value = {key: value} | ||||
|         elif op == 'addToSet' and isinstance(value, list): | ||||
|             value = {key: {'$each': value}} | ||||
|             value = {key: {"$each": value}} | ||||
|         else: | ||||
|             value = {key: value} | ||||
|         key = '$' + op | ||||
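To make the mapping above concrete: each keyword is split on __, the leading part is translated to a MongoDB update operator (with the special cases dec -> $inc with a negated value, add_to_set -> $addToSet with $each for lists, and unset -> value 1), and the remainder becomes the dotted key. Assuming the module imports as shown, hypothetical calls without a document class:

    from mongoengine.queryset.transform import update

    print(update(inc__views=1))      # {'$inc': {'views': 1}}
    print(update(dec__views=1))      # {'$inc': {'views': -1}}
    print(update(set__name='Ross'))  # {'$set': {'name': 'Ross'}}
    print(update(unset__name=1))     # {'$unset': {'name': 1}}
    print(update(add_to_set__tags=['db', 'geo']))
    # {'$addToSet': {'tags': {'$each': ['db', 'geo']}}}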
| @@ -346,82 +318,78 @@ def update(_doc_cls=None, **update): | ||||
|  | ||||
|  | ||||
| def _geo_operator(field, op, value): | ||||
|     """Helper to return the query for a given geo query.""" | ||||
|     if op == 'max_distance': | ||||
|     """Helper to return the query for a given geo query""" | ||||
|     if op == "max_distance": | ||||
|         value = {'$maxDistance': value} | ||||
|     elif op == 'min_distance': | ||||
|     elif op == "min_distance": | ||||
|         value = {'$minDistance': value} | ||||
|     elif field._geo_index == pymongo.GEO2D: | ||||
|         if op == 'within_distance': | ||||
|         if op == "within_distance": | ||||
|             value = {'$within': {'$center': value}} | ||||
|         elif op == 'within_spherical_distance': | ||||
|         elif op == "within_spherical_distance": | ||||
|             value = {'$within': {'$centerSphere': value}} | ||||
|         elif op == 'within_polygon': | ||||
|         elif op == "within_polygon": | ||||
|             value = {'$within': {'$polygon': value}} | ||||
|         elif op == 'near': | ||||
|         elif op == "near": | ||||
|             value = {'$near': value} | ||||
|         elif op == 'near_sphere': | ||||
|         elif op == "near_sphere": | ||||
|             value = {'$nearSphere': value} | ||||
|         elif op == 'within_box': | ||||
|             value = {'$within': {'$box': value}} | ||||
|         else: | ||||
|             raise NotImplementedError('Geo method "%s" has not been ' | ||||
|                                       'implemented for a GeoPointField' % op) | ||||
|             raise NotImplementedError("Geo method '%s' has not " | ||||
|                                       "been implemented for a GeoPointField" % op) | ||||
|     else: | ||||
|         if op == 'geo_within': | ||||
|             value = {'$geoWithin': _infer_geometry(value)} | ||||
|         elif op == 'geo_within_box': | ||||
|             value = {'$geoWithin': {'$box': value}} | ||||
|         elif op == 'geo_within_polygon': | ||||
|             value = {'$geoWithin': {'$polygon': value}} | ||||
|         elif op == 'geo_within_center': | ||||
|             value = {'$geoWithin': {'$center': value}} | ||||
|         elif op == 'geo_within_sphere': | ||||
|             value = {'$geoWithin': {'$centerSphere': value}} | ||||
|         elif op == 'geo_intersects': | ||||
|             value = {'$geoIntersects': _infer_geometry(value)} | ||||
|         elif op == 'near': | ||||
|         if op == "geo_within": | ||||
|             value = {"$geoWithin": _infer_geometry(value)} | ||||
|         elif op == "geo_within_box": | ||||
|             value = {"$geoWithin": {"$box": value}} | ||||
|         elif op == "geo_within_polygon": | ||||
|             value = {"$geoWithin": {"$polygon": value}} | ||||
|         elif op == "geo_within_center": | ||||
|             value = {"$geoWithin": {"$center": value}} | ||||
|         elif op == "geo_within_sphere": | ||||
|             value = {"$geoWithin": {"$centerSphere": value}} | ||||
|         elif op == "geo_intersects": | ||||
|             value = {"$geoIntersects": _infer_geometry(value)} | ||||
|         elif op == "near": | ||||
|             value = {'$near': _infer_geometry(value)} | ||||
|         else: | ||||
|             raise NotImplementedError( | ||||
|                 'Geo method "%s" has not been implemented for a %s ' | ||||
|                 % (op, field._name) | ||||
|             ) | ||||
|             raise NotImplementedError("Geo method '%s' has not " | ||||
|                                       "been implemented for a %s " % (op, field._name)) | ||||
|     return value | ||||
|  | ||||
|  | ||||
| def _infer_geometry(value): | ||||
|     """Helper method that tries to infer the $geometry shape for a | ||||
|     given value. | ||||
|     """ | ||||
|     """Helper method that tries to infer the $geometry shape for a given value""" | ||||
|     if isinstance(value, dict): | ||||
|         if '$geometry' in value: | ||||
|         if "$geometry" in value: | ||||
|             return value | ||||
|         elif 'coordinates' in value and 'type' in value: | ||||
|             return {'$geometry': value} | ||||
|         raise InvalidQueryError('Invalid $geometry dictionary should have ' | ||||
|                                 'type and coordinates keys') | ||||
|             return {"$geometry": value} | ||||
|         raise InvalidQueryError("Invalid $geometry dictionary should have " | ||||
|                                 "type and coordinates keys") | ||||
|     elif isinstance(value, (list, set)): | ||||
|         # TODO: shouldn't we test value[0][0][0][0] to see if it is MultiPolygon? | ||||
|         # TODO: should both TypeError and IndexError be alike interpreted? | ||||
|  | ||||
|         try: | ||||
|             value[0][0][0] | ||||
|             return {'$geometry': {'type': 'Polygon', 'coordinates': value}} | ||||
|             return {"$geometry": {"type": "Polygon", "coordinates": value}} | ||||
|         except (TypeError, IndexError): | ||||
|             pass | ||||
|  | ||||
|         try: | ||||
|             value[0][0] | ||||
|             return {'$geometry': {'type': 'LineString', 'coordinates': value}} | ||||
|             return {"$geometry": {"type": "LineString", "coordinates": value}} | ||||
|         except (TypeError, IndexError): | ||||
|             pass | ||||
|  | ||||
|         try: | ||||
|             value[0] | ||||
|             return {'$geometry': {'type': 'Point', 'coordinates': value}} | ||||
|             return {"$geometry": {"type": "Point", "coordinates": value}} | ||||
|         except (TypeError, IndexError): | ||||
|             pass | ||||
|  | ||||
|     raise InvalidQueryError('Invalid $geometry data. Can be either a ' | ||||
|                             'dictionary or (nested) lists of coordinate(s)') | ||||
|     raise InvalidQueryError("Invalid $geometry data. Can be either a dictionary " | ||||
|                             "or (nested) lists of coordinate(s)") | ||||
|   | ||||
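The probing in _infer_geometry above encodes a simple rule: a Point is one list deep, a LineString two, a Polygon three. The same idea as a standalone sketch (illustrative, not the library function itself):

    def infer_geometry(value):
        """Guess a $geometry shape from how deeply a list nests."""
        for depth, shape in ((3, 'Polygon'), (2, 'LineString'), (1, 'Point')):
            try:
                probe = value
                for _ in range(depth):
                    probe = probe[0]  # raises if the nesting is shallower
                return {'$geometry': {'type': shape, 'coordinates': value}}
            except (TypeError, IndexError):
                continue
        raise ValueError('cannot infer a geometry from %r' % (value,))

    print(infer_geometry([1, 2]))                               # Point
    print(infer_geometry([[1, 2], [3, 4]]))                     # LineString
    print(infer_geometry([[[1, 2], [3, 4], [5, 6], [1, 2]]]))   # Polygon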
| @@ -69,9 +69,9 @@ class QueryCompilerVisitor(QNodeVisitor): | ||||
|         self.document = document | ||||
|  | ||||
|     def visit_combination(self, combination): | ||||
|         operator = '$and' | ||||
|         operator = "$and" | ||||
|         if combination.operation == combination.OR: | ||||
|             operator = '$or' | ||||
|             operator = "$or" | ||||
|         return {operator: combination.children} | ||||
|  | ||||
|     def visit_query(self, query): | ||||
| @@ -79,7 +79,8 @@ class QueryCompilerVisitor(QNodeVisitor): | ||||
|  | ||||
|  | ||||
| class QNode(object): | ||||
|     """Base class for nodes in query trees.""" | ||||
|     """Base class for nodes in query trees. | ||||
|     """ | ||||
|  | ||||
|     AND = 0 | ||||
|     OR = 1 | ||||
| @@ -93,8 +94,7 @@ class QNode(object): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def _combine(self, other, operation): | ||||
|         """Combine this node with another node into a QCombination | ||||
|         object. | ||||
|         """Combine this node with another node into a QCombination object. | ||||
|         """ | ||||
|         if getattr(other, 'empty', True): | ||||
|             return self | ||||
| @@ -116,8 +116,8 @@ class QNode(object): | ||||
|  | ||||
|  | ||||
| class QCombination(QNode): | ||||
|     """Represents the combination of several conditions by a given | ||||
|     logical operator. | ||||
|     """Represents the combination of several conditions by a given logical | ||||
|     operator. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, operation, children): | ||||
|   | ||||
| @@ -1,5 +1,7 @@ | ||||
| __all__ = ('pre_init', 'post_init', 'pre_save', 'pre_save_post_validation', | ||||
|            'post_save', 'pre_delete', 'post_delete') | ||||
| # -*- coding: utf-8 -*- | ||||
|  | ||||
| __all__ = ['pre_init', 'post_init', 'pre_save', 'pre_save_post_validation', | ||||
|            'post_save', 'pre_delete', 'post_delete'] | ||||
|  | ||||
| signals_available = False | ||||
| try: | ||||
| @@ -32,7 +34,6 @@ except ImportError: | ||||
|             temporarily_connected_to = _fail | ||||
|         del _fail | ||||
|  | ||||
|  | ||||
| # the namespace for code signals.  If you are not mongoengine code, do | ||||
| # not put signals in here.  Create your own namespace instead. | ||||
| _signals = Namespace() | ||||
|   | ||||
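For context on the module touched above: blinker is an optional dependency, and when it is missing the signals are replaced with stubs whose connect raises while send is a no-op. A condensed sketch of that fallback pattern (the stub is an illustrative simplification of the _fail wiring in the real module):

    signals_available = False
    try:
        from blinker import Namespace
        signals_available = True
    except ImportError:
        class Namespace(object):
            def signal(self, name, doc=None):
                return _FakeSignal(name, doc)

        class _FakeSignal(object):
            """Looks like a signal but refuses to connect receivers."""
            def __init__(self, name, doc=None):
                self.name, self.__doc__ = name, doc

            def send(self, *args, **kwargs):
                pass  # sending without blinker is a silent no-op

            def connect(self, *args, **kwargs):
                raise RuntimeError('signalling support requires blinker')

    _signals = Namespace()
    pre_save = _signals.signal('pre_save')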
setup.cfg | 14

							| @@ -1,11 +1,13 @@ | ||||
| [nosetests] | ||||
| verbosity=2 | ||||
| detailed-errors=1 | ||||
| tests=tests | ||||
| cover-package=mongoengine | ||||
| verbosity = 2 | ||||
| detailed-errors = 1 | ||||
| cover-erase = 1 | ||||
| cover-branches = 1 | ||||
| cover-package = mongoengine | ||||
| tests = tests | ||||
|  | ||||
| [flake8] | ||||
| ignore=E501,F401,F403,F405,I201 | ||||
| exclude=build,dist,docs,venv,venv3,.tox,.eggs,tests | ||||
| max-complexity=47 | ||||
| exclude=build,dist,docs,venv,.tox,.eggs,tests | ||||
| max-complexity=42 | ||||
| application-import-names=mongoengine,tests | ||||
|   | ||||
setup.py | 25
							| @@ -21,9 +21,8 @@ except Exception: | ||||
|  | ||||
|  | ||||
| def get_version(version_tuple): | ||||
|     """Return the version tuple as a string, e.g. for (0, 10, 7), | ||||
|     return '0.10.7'. | ||||
|     """ | ||||
|     if not isinstance(version_tuple[-1], int): | ||||
|         return '.'.join(map(str, version_tuple[:-1])) + version_tuple[-1] | ||||
|     return '.'.join(map(str, version_tuple)) | ||||
|  | ||||
|  | ||||
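A worked example of the get_version() behaviour shown above: a trailing non-int element is treated as a pre-release suffix and concatenated without a dot. Self-contained, with the function restated from the hunk:

    def get_version(version_tuple):
        if not isinstance(version_tuple[-1], int):
            return '.'.join(map(str, version_tuple[:-1])) + version_tuple[-1]
        return '.'.join(map(str, version_tuple))

    print(get_version((0, 10, 7)))           # '0.10.7'
    print(get_version((0, 10, 7, '.dev1')))  # '0.10.7.dev1'
    print(get_version((0, 10, 7, 'rc1')))    # '0.10.7rc1'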
| @@ -42,29 +41,31 @@ CLASSIFIERS = [ | ||||
|     'Operating System :: OS Independent', | ||||
|     'Programming Language :: Python', | ||||
|     "Programming Language :: Python :: 2", | ||||
|     "Programming Language :: Python :: 2.6", | ||||
|     "Programming Language :: Python :: 2.7", | ||||
|     "Programming Language :: Python :: 3", | ||||
|     "Programming Language :: Python :: 3.2", | ||||
|     "Programming Language :: Python :: 3.3", | ||||
|     "Programming Language :: Python :: 3.4", | ||||
|     "Programming Language :: Python :: 3.5", | ||||
|     "Programming Language :: Python :: Implementation :: CPython", | ||||
|     "Programming Language :: Python :: Implementation :: PyPy", | ||||
|     'Topic :: Database', | ||||
|     'Topic :: Software Development :: Libraries :: Python Modules', | ||||
| ] | ||||
|  | ||||
| extra_opts = { | ||||
|     'packages': find_packages(exclude=['tests', 'tests.*']), | ||||
|     'tests_require': ['nose', 'coverage==4.2', 'blinker', 'Pillow>=2.0.0'] | ||||
| } | ||||
| extra_opts = {"packages": find_packages(exclude=["tests", "tests.*"])} | ||||
| if sys.version_info[0] == 3: | ||||
|     extra_opts['use_2to3'] = True | ||||
|     if 'test' in sys.argv or 'nosetests' in sys.argv: | ||||
|     extra_opts['tests_require'] = ['nose', 'coverage==3.7.1', 'blinker', 'Pillow>=2.0.0'] | ||||
|     if "test" in sys.argv or "nosetests" in sys.argv: | ||||
|         extra_opts['packages'] = find_packages() | ||||
|         extra_opts['package_data'] = { | ||||
|             'tests': ['fields/mongoengine.png', 'fields/mongodb_leaf.png']} | ||||
|         extra_opts['package_data'] = {"tests": ["fields/mongoengine.png", "fields/mongodb_leaf.png"]} | ||||
| else: | ||||
|     extra_opts['tests_require'] += ['python-dateutil'] | ||||
|     # coverage 4 does not support Python 3.2 anymore | ||||
|     extra_opts['tests_require'] = ['nose', 'coverage==3.7.1', 'blinker', 'Pillow>=2.0.0', 'python-dateutil'] | ||||
|  | ||||
|     if sys.version_info[0] == 2 and sys.version_info[1] == 6: | ||||
|         extra_opts['tests_require'].append('unittest2') | ||||
|  | ||||
| setup( | ||||
|     name='mongoengine', | ||||
|   | ||||
| @@ -2,3 +2,4 @@ from all_warnings import AllWarnings | ||||
| from document import * | ||||
| from queryset import * | ||||
| from fields import * | ||||
| from migration import * | ||||
|   | ||||
| @@ -3,6 +3,8 @@ This test has been put into a module.  This is because it tests warnings that | ||||
| only get triggered on first hit.  This way we can ensure its imported into the | ||||
| top level and called first by the test suite. | ||||
| """ | ||||
| import sys | ||||
| sys.path[0:0] = [""] | ||||
| import unittest | ||||
| import warnings | ||||
|  | ||||
|   | ||||
| @@ -1,3 +1,5 @@ | ||||
| import sys | ||||
| sys.path[0:0] = [""] | ||||
| import unittest | ||||
|  | ||||
| from class_methods import * | ||||
|   | ||||
| @@ -1,4 +1,6 @@ | ||||
| # -*- coding: utf-8 -*- | ||||
| import sys | ||||
| sys.path[0:0] = [""] | ||||
| import unittest | ||||
|  | ||||
| from mongoengine import * | ||||
|   | ||||
| @@ -1,4 +1,6 @@ | ||||
| # -*- coding: utf-8 -*- | ||||
| import sys | ||||
| sys.path[0:0] = [""] | ||||
| import unittest | ||||
|  | ||||
| from bson import SON | ||||
|   | ||||
| @@ -1,4 +1,6 @@ | ||||
| import unittest | ||||
| import sys | ||||
| sys.path[0:0] = [""] | ||||
|  | ||||
| from mongoengine import * | ||||
| from mongoengine.connection import get_db | ||||
| @@ -141,9 +143,11 @@ class DynamicTest(unittest.TestCase): | ||||
|  | ||||
|     def test_three_level_complex_data_lookups(self): | ||||
|         """Ensure you can query three level document dynamic fields""" | ||||
|         p = self.Person.objects.create( | ||||
|             misc={'hello': {'hello2': 'world'}} | ||||
|         ) | ||||
|         p = self.Person() | ||||
|         p.misc = {'hello': {'hello2': 'world'}} | ||||
|         p.save() | ||||
|         # from pprint import pprint as pp; import pdb; pdb.set_trace(); | ||||
|         print self.Person.objects(misc__hello__hello2='world') | ||||
|         self.assertEqual(1, self.Person.objects(misc__hello__hello2='world').count()) | ||||
|  | ||||
|     def test_complex_embedded_document_validation(self): | ||||
|   | ||||
| @@ -2,14 +2,14 @@ | ||||
| import unittest | ||||
| import sys | ||||
|  | ||||
| from nose.plugins.skip import SkipTest | ||||
| from datetime import datetime | ||||
|  | ||||
| import pymongo | ||||
|  | ||||
| from mongoengine import * | ||||
| from mongoengine.connection import get_db | ||||
| from nose.plugins.skip import SkipTest | ||||
| from datetime import datetime | ||||
|  | ||||
| from tests.utils import get_mongodb_version, needs_mongodb_v26 | ||||
| from mongoengine import * | ||||
| from mongoengine.connection import get_db, get_connection | ||||
|  | ||||
| __all__ = ("IndexesTest", ) | ||||
|  | ||||
| @@ -412,6 +412,7 @@ class IndexesTest(unittest.TestCase): | ||||
|         User.ensure_indexes() | ||||
|         info = User.objects._collection.index_information() | ||||
|         self.assertEqual(sorted(info.keys()), ['_cls_1_user_guid_1', '_id_']) | ||||
|         User.drop_collection() | ||||
|  | ||||
|     def test_embedded_document_index(self): | ||||
|         """Tests settings an index on an embedded document | ||||
| @@ -433,6 +434,7 @@ class IndexesTest(unittest.TestCase): | ||||
|  | ||||
|         info = BlogPost.objects._collection.index_information() | ||||
|         self.assertEqual(sorted(info.keys()), ['_id_', 'date.yr_-1']) | ||||
|         BlogPost.drop_collection() | ||||
|  | ||||
|     def test_list_embedded_document_index(self): | ||||
|         """Ensure list embedded documents can be indexed | ||||
| @@ -459,6 +461,7 @@ class IndexesTest(unittest.TestCase): | ||||
|         post1 = BlogPost(title="Embedded Indexes tests in place", | ||||
|                          tags=[Tag(name="about"), Tag(name="time")]) | ||||
|         post1.save() | ||||
|         BlogPost.drop_collection() | ||||
|  | ||||
|     def test_recursive_embedded_objects_dont_break_indexes(self): | ||||
|  | ||||
| @@ -491,7 +494,8 @@ class IndexesTest(unittest.TestCase): | ||||
|         obj = Test(a=1) | ||||
|         obj.save() | ||||
|  | ||||
|         IS_MONGODB_3 = get_mongodb_version()[0] >= 3 | ||||
|         connection = get_connection() | ||||
|         IS_MONGODB_3 = connection.server_info()['versionArray'][0] >= 3 | ||||
|  | ||||
|         # Need to be explicit about covered indexes as mongoDB doesn't know if | ||||
|         # the documents returned might have more keys in that here. | ||||
| @@ -552,8 +556,8 @@ class IndexesTest(unittest.TestCase): | ||||
|  | ||||
|         BlogPost.drop_collection() | ||||
|  | ||||
|         for i in range(0, 10): | ||||
|             tags = [("tag %i" % n) for n in range(0, i % 2)] | ||||
|         for i in xrange(0, 10): | ||||
|             tags = [("tag %i" % n) for n in xrange(0, i % 2)] | ||||
|             BlogPost(tags=tags).save() | ||||
|  | ||||
|         self.assertEqual(BlogPost.objects.count(), 10) | ||||
| @@ -619,6 +623,8 @@ class IndexesTest(unittest.TestCase): | ||||
|         post3 = BlogPost(title='test3', date=Date(year=2010), slug='test') | ||||
|         self.assertRaises(OperationError, post3.save) | ||||
|  | ||||
|         BlogPost.drop_collection() | ||||
|  | ||||
|     def test_unique_embedded_document(self): | ||||
|         """Ensure that uniqueness constraints are applied to fields on embedded documents. | ||||
|         """ | ||||
| @@ -646,6 +652,8 @@ class IndexesTest(unittest.TestCase): | ||||
|                          sub=SubDocument(year=2010, slug='test')) | ||||
|         self.assertRaises(NotUniqueError, post3.save) | ||||
|  | ||||
|         BlogPost.drop_collection() | ||||
|  | ||||
|     def test_unique_embedded_document_in_list(self): | ||||
|         """ | ||||
|         Ensure that the uniqueness constraints are applied to fields in | ||||
| @@ -676,6 +684,8 @@ class IndexesTest(unittest.TestCase): | ||||
|  | ||||
|         self.assertRaises(NotUniqueError, post2.save) | ||||
|  | ||||
|         BlogPost.drop_collection() | ||||
|  | ||||
|     def test_unique_with_embedded_document_and_embedded_unique(self): | ||||
|         """Ensure that uniqueness constraints are applied to fields on | ||||
|         embedded documents.  And work with unique_with as well. | ||||
| @@ -709,6 +719,8 @@ class IndexesTest(unittest.TestCase): | ||||
|                          sub=SubDocument(year=2009, slug='test-1')) | ||||
|         self.assertRaises(NotUniqueError, post3.save) | ||||
|  | ||||
|         BlogPost.drop_collection() | ||||
|  | ||||
|     def test_ttl_indexes(self): | ||||
|  | ||||
|         class Log(Document): | ||||
| @@ -721,6 +733,14 @@ class IndexesTest(unittest.TestCase): | ||||
|  | ||||
|         Log.drop_collection() | ||||
|  | ||||
|         if pymongo.version_tuple[0] < 2 and pymongo.version_tuple[1] < 3: | ||||
|             raise SkipTest('pymongo needs to be 2.3 or higher for this test') | ||||
|  | ||||
|         connection = get_connection() | ||||
|         version_array = connection.server_info()['versionArray'] | ||||
|         if version_array[0] < 2 and version_array[1] < 2: | ||||
|             raise SkipTest('MongoDB needs to be 2.2 or higher for this test') | ||||
|  | ||||
|         # Indexes are lazy so use list() to perform query | ||||
|         list(Log.objects) | ||||
|         info = Log.objects._collection.index_information() | ||||
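The skips added above gate the TTL test on PyMongo >= 2.3 and MongoDB >= 2.2, the versions that introduced expireAfterSeconds. For reference, a minimal sketch of declaring a TTL index in mongoengine (model and expiry are illustrative):

    from datetime import datetime
    from mongoengine import Document, DateTimeField

    class Log(Document):
        created = DateTimeField(default=datetime.now)
        meta = {
            # mongod's TTL monitor removes documents ~3600s after `created`.
            'indexes': [{'fields': ['created'], 'expireAfterSeconds': 3600}],
        }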
| @@ -748,11 +768,13 @@ class IndexesTest(unittest.TestCase): | ||||
|             raise AssertionError("We saved a dupe!") | ||||
|         except NotUniqueError: | ||||
|             pass | ||||
|         Customer.drop_collection() | ||||
|  | ||||
|     def test_unique_and_primary(self): | ||||
|         """If you set a field as primary, then unexpected behaviour can occur. | ||||
|         You won't create a duplicate but you will update an existing document. | ||||
|         """ | ||||
|  | ||||
|         class User(Document): | ||||
|             name = StringField(primary_key=True, unique=True) | ||||
|             password = StringField() | ||||
| @@ -768,23 +790,8 @@ class IndexesTest(unittest.TestCase): | ||||
|         self.assertEqual(User.objects.count(), 1) | ||||
|         self.assertEqual(User.objects.get().password, 'secret2') | ||||
|  | ||||
|     def test_unique_and_primary_create(self): | ||||
|         """Create a new record with a duplicate primary key | ||||
|         throws an exception | ||||
|         """ | ||||
|         class User(Document): | ||||
|             name = StringField(primary_key=True) | ||||
|             password = StringField() | ||||
|  | ||||
|         User.drop_collection() | ||||
|  | ||||
|         User.objects.create(name='huangz', password='secret') | ||||
|         with self.assertRaises(NotUniqueError): | ||||
|             User.objects.create(name='huangz', password='secret2') | ||||
|  | ||||
|         self.assertEqual(User.objects.count(), 1) | ||||
|         self.assertEqual(User.objects.get().password, 'secret') | ||||
|  | ||||
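Taken together, `test_unique_and_primary` (kept) and `test_unique_and_primary_create` (deleted above) document two sides of one contract: `save()` on a document whose primary key already exists quietly updates the stored record, while `QuerySet.create()` performs a plain insert (it saves with `force_insert=True`) and raises `NotUniqueError` on a duplicate key. A condensed sketch of both behaviours, assuming an active connection as in the surrounding tests:

    from mongoengine import *

    class User(Document):
        name = StringField(primary_key=True)
        password = StringField()

    User.drop_collection()

    User(name='huangz', password='secret').save()    # plain insert
    User(name='huangz', password='secret2').save()   # same pk: updates in place
    assert User.objects.count() == 1
    assert User.objects.get().password == 'secret2'

    try:
        User.objects.create(name='huangz', password='secret3')  # insert only
    except NotUniqueError:
        pass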
|     def test_index_with_pk(self): | ||||
|         """Ensure you can use `pk` as part of a query""" | ||||
|  | ||||
| @@ -867,8 +874,8 @@ class IndexesTest(unittest.TestCase): | ||||
|                          info['provider_ids.foo_1_provider_ids.bar_1']['key']) | ||||
|         self.assertTrue(info['provider_ids.foo_1_provider_ids.bar_1']['sparse']) | ||||
|  | ||||
|     @needs_mongodb_v26 | ||||
|     def test_text_indexes(self): | ||||
|  | ||||
|         class Book(Document): | ||||
|             title = DictField() | ||||
|             meta = { | ||||
|   | ||||
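The `@needs_mongodb_v26` decorator removed above gated `test_text_indexes` on MongoDB 2.6, the first release with production text search. For reference, a sketch of how a text index is typically declared and queried in mongoengine, assuming the documented `$`-prefix index syntax and `search_text` queryset helper are available in the version under test:

    from mongoengine import *

    class Book(Document):
        title = StringField()
        meta = {'indexes': [
            {'fields': ['$title']},   # '$' marks the field as text-indexed
        ]}

    Book.drop_collection()
    Book(title='MongoDB in Action').save()

    # full-text query, best matches first
    results = Book.objects.search_text('mongodb').order_by('$text_score')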
| @@ -1,4 +1,6 @@ | ||||
| # -*- coding: utf-8 -*- | ||||
| import sys | ||||
| sys.path[0:0] = [""] | ||||
| import unittest | ||||
| import warnings | ||||
|  | ||||
| @@ -251,17 +253,19 @@ class InheritanceTest(unittest.TestCase): | ||||
|         self.assertEqual(classes, [Human]) | ||||
|  | ||||
|     def test_allow_inheritance(self): | ||||
|         """Ensure that inheritance is disabled by default on simple | ||||
|         classes and that _cls will not be used. | ||||
|         """Ensure that inheritance may be disabled on simple classes and that | ||||
|         _cls and _subclasses will not be used. | ||||
|         """ | ||||
|  | ||||
|         class Animal(Document): | ||||
|             name = StringField() | ||||
|  | ||||
|         # can't inherit because Animal didn't explicitly allow inheritance | ||||
|         with self.assertRaises(ValueError): | ||||
|         def create_dog_class(): | ||||
|             class Dog(Animal): | ||||
|                 pass | ||||
|  | ||||
|         self.assertRaises(ValueError, create_dog_class) | ||||
|  | ||||
|         # Check that _cls etc aren't present on simple documents | ||||
|         dog = Animal(name='dog').save() | ||||
|         self.assertEqual(dog.to_mongo().keys(), ['_id', 'name']) | ||||
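The hunk above pins down the default: a plain `Document` subclass cannot itself be subclassed, and no `_cls` marker is stored. A minimal sketch of opting in, matching the documented `allow_inheritance` meta flag:

    from mongoengine import *

    class Animal(Document):
        name = StringField()
        meta = {'allow_inheritance': True}   # opt in explicitly

    class Dog(Animal):   # legal once the parent allows inheritance
        pass

    dog = Dog(name='fido').save()
    assert '_cls' in dog.to_mongo()   # stored as 'Animal.Dog'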
| @@ -271,15 +275,17 @@ class InheritanceTest(unittest.TestCase): | ||||
|         self.assertFalse('_cls' in obj) | ||||
|  | ||||
|     def test_cant_turn_off_inheritance_on_subclass(self): | ||||
|         """Ensure if inheritance is on in a subclass you cant turn it off. | ||||
|         """Ensure if inheritance is on in a subclass you cant turn it off | ||||
|         """ | ||||
|  | ||||
|         class Animal(Document): | ||||
|             name = StringField() | ||||
|             meta = {'allow_inheritance': True} | ||||
|  | ||||
|         with self.assertRaises(ValueError): | ||||
|         def create_mammal_class(): | ||||
|             class Mammal(Animal): | ||||
|                 meta = {'allow_inheritance': False} | ||||
|         self.assertRaises(ValueError, create_mammal_class) | ||||
|  | ||||
|     def test_allow_inheritance_abstract_document(self): | ||||
|         """Ensure that abstract documents can set inheritance rules and that | ||||
| @@ -292,9 +298,10 @@ class InheritanceTest(unittest.TestCase): | ||||
|         class Animal(FinalDocument): | ||||
|             name = StringField() | ||||
|  | ||||
|         with self.assertRaises(ValueError): | ||||
|         def create_mammal_class(): | ||||
|             class Mammal(Animal): | ||||
|                 pass | ||||
|         self.assertRaises(ValueError, create_mammal_class) | ||||
|  | ||||
|         # Check that _cls isn't present in simple documents | ||||
|         doc = Animal(name='dog') | ||||
| @@ -353,26 +360,29 @@ class InheritanceTest(unittest.TestCase): | ||||
|         self.assertEqual(berlin.pk, berlin.auto_id_0) | ||||
|  | ||||
|     def test_abstract_document_creation_does_not_fail(self): | ||||
|  | ||||
|         class City(Document): | ||||
|             continent = StringField() | ||||
|             meta = {'abstract': True, | ||||
|                     'allow_inheritance': False} | ||||
|  | ||||
|         bkk = City(continent='asia') | ||||
|         self.assertEqual(None, bkk.pk) | ||||
|         # TODO: expected error? Shouldn't we create a new error type? | ||||
|         with self.assertRaises(KeyError): | ||||
|             setattr(bkk, 'pk', 1) | ||||
|         self.assertRaises(KeyError, lambda: setattr(bkk, 'pk', 1)) | ||||
|  | ||||
|     def test_allow_inheritance_embedded_document(self): | ||||
|         """Ensure embedded documents respect inheritance.""" | ||||
|         """Ensure embedded documents respect inheritance | ||||
|         """ | ||||
|  | ||||
|         class Comment(EmbeddedDocument): | ||||
|             content = StringField() | ||||
|  | ||||
|         with self.assertRaises(ValueError): | ||||
|         def create_special_comment(): | ||||
|             class SpecialComment(Comment): | ||||
|                 pass | ||||
|  | ||||
|         self.assertRaises(ValueError, create_special_comment) | ||||
|  | ||||
|         doc = Comment(content='test') | ||||
|         self.assertFalse('_cls' in doc.to_mongo()) | ||||
|  | ||||
| @@ -444,11 +454,11 @@ class InheritanceTest(unittest.TestCase): | ||||
|         self.assertEqual(Guppy._get_collection_name(), 'fish') | ||||
|         self.assertEqual(Human._get_collection_name(), 'human') | ||||
|  | ||||
|         # ensure that a subclass of a non-abstract class can't be abstract | ||||
|         with self.assertRaises(ValueError): | ||||
|         def create_bad_abstract(): | ||||
|             class EvilHuman(Human): | ||||
|                 evil = BooleanField(default=True) | ||||
|                 meta = {'abstract': True} | ||||
|         self.assertRaises(ValueError, create_bad_abstract) | ||||
|  | ||||
|     def test_abstract_embedded_documents(self): | ||||
|         # 789: EmbeddedDocument shouldn't inherit abstract | ||||
|   | ||||
| @@ -1,4 +1,7 @@ | ||||
| # -*- coding: utf-8 -*- | ||||
| import sys | ||||
| sys.path[0:0] = [""] | ||||
|  | ||||
| import bson | ||||
| import os | ||||
| import pickle | ||||
| @@ -13,12 +16,12 @@ from tests.fixtures import (PickleEmbedded, PickleTest, PickleSignalsTest, | ||||
|                             PickleDynamicEmbedded, PickleDynamicTest) | ||||
|  | ||||
| from mongoengine import * | ||||
| from mongoengine.base import get_document, _document_registry | ||||
| from mongoengine.connection import get_db | ||||
| from mongoengine.errors import (NotRegistered, InvalidDocumentError, | ||||
|                                 InvalidQueryError, NotUniqueError, | ||||
|                                 FieldDoesNotExist, SaveConditionError) | ||||
| from mongoengine.queryset import NULLIFY, Q | ||||
| from mongoengine.connection import get_db | ||||
| from mongoengine.base import get_document | ||||
| from mongoengine.context_managers import switch_db, query_counter | ||||
| from mongoengine import signals | ||||
|  | ||||
| @@ -99,18 +102,21 @@ class InstanceTest(unittest.TestCase): | ||||
|         self.assertEqual(options['size'], 4096) | ||||
|  | ||||
|         # Check that the document cannot be redefined with different options | ||||
|         class Log(Document): | ||||
|             date = DateTimeField(default=datetime.now) | ||||
|             meta = { | ||||
|                 'max_documents': 11, | ||||
|             } | ||||
|  | ||||
|         # Accessing Document.objects creates the collection | ||||
|         with self.assertRaises(InvalidCollectionError): | ||||
|         def recreate_log_document(): | ||||
|             class Log(Document): | ||||
|                 date = DateTimeField(default=datetime.now) | ||||
|                 meta = { | ||||
|                     'max_documents': 11, | ||||
|                 } | ||||
|             # Create the collection by accessing Document.objects | ||||
|             Log.objects | ||||
|         self.assertRaises(InvalidCollectionError, recreate_log_document) | ||||
|  | ||||
|         Log.drop_collection() | ||||
|  | ||||
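These capped-collection hunks all revolve around the same meta mapping: `max_documents` becomes the collection's `max` option and `max_size` its `size` (defaulting to 10 MB when omitted), and redefining a `Document` whose collection already exists with different options raises `InvalidCollectionError`. A short sketch of reading the options back — illustrative only; `options()` is pymongo's `Collection.options`:

    from datetime import datetime
    from mongoengine import *

    class Log(Document):
        date = DateTimeField(default=datetime.now)
        meta = {'max_documents': 10, 'max_size': 4096}

    Log.drop_collection()
    Log.objects   # touching the queryset creates the capped collection

    options = Log._get_collection().options()
    assert options['capped']
    assert options['max'] == 10
    assert options['size'] >= 4096   # the server may round the size up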
|     def test_capped_collection_default(self): | ||||
|         """Ensure that capped collections defaults work properly.""" | ||||
|         """Ensure that capped collections defaults work properly. | ||||
|         """ | ||||
|         class Log(Document): | ||||
|             date = DateTimeField(default=datetime.now) | ||||
|             meta = { | ||||
| @@ -128,14 +134,16 @@ class InstanceTest(unittest.TestCase): | ||||
|         self.assertEqual(options['size'], 10 * 2**20) | ||||
|  | ||||
|         # Check that the document with default value can be recreated | ||||
|         class Log(Document): | ||||
|             date = DateTimeField(default=datetime.now) | ||||
|             meta = { | ||||
|                 'max_documents': 10, | ||||
|             } | ||||
|  | ||||
|         # Create the collection by accessing Document.objects | ||||
|         Log.objects | ||||
|         def recreate_log_document(): | ||||
|             class Log(Document): | ||||
|                 date = DateTimeField(default=datetime.now) | ||||
|                 meta = { | ||||
|                     'max_documents': 10, | ||||
|                 } | ||||
|             # Create the collection by accessing Document.objects | ||||
|             Log.objects | ||||
|         recreate_log_document() | ||||
|         Log.drop_collection() | ||||
|  | ||||
|     def test_capped_collection_no_max_size_problems(self): | ||||
|         """Ensure that capped collections with odd max_size work properly. | ||||
| @@ -158,14 +166,16 @@ class InstanceTest(unittest.TestCase): | ||||
|         self.assertTrue(options['size'] >= 10000) | ||||
|  | ||||
|         # Check that the document with odd max_size value can be recreated | ||||
|         class Log(Document): | ||||
|             date = DateTimeField(default=datetime.now) | ||||
|             meta = { | ||||
|                 'max_size': 10000, | ||||
|             } | ||||
|  | ||||
|         # Create the collection by accessing Document.objects | ||||
|         Log.objects | ||||
|         def recreate_log_document(): | ||||
|             class Log(Document): | ||||
|                 date = DateTimeField(default=datetime.now) | ||||
|                 meta = { | ||||
|                     'max_size': 10000, | ||||
|                 } | ||||
|             # Create the collection by accessing Document.objects | ||||
|             Log.objects | ||||
|         recreate_log_document() | ||||
|         Log.drop_collection() | ||||
|  | ||||
|     def test_repr(self): | ||||
|         """Ensure that unicode representation works | ||||
| @@ -276,7 +286,7 @@ class InstanceTest(unittest.TestCase): | ||||
|  | ||||
|         list_stats = [] | ||||
|  | ||||
|         for i in range(10): | ||||
|         for i in xrange(10): | ||||
|             s = Stats() | ||||
|             s.save() | ||||
|             list_stats.append(s) | ||||
| @@ -346,14 +356,14 @@ class InstanceTest(unittest.TestCase): | ||||
|         self.assertEqual(User._fields['username'].db_field, '_id') | ||||
|         self.assertEqual(User._meta['id_field'], 'username') | ||||
|  | ||||
|         # test no primary key field | ||||
|         self.assertRaises(ValidationError, User(name='test').save) | ||||
|         def create_invalid_user(): | ||||
|             User(name='test').save()  # no primary key field | ||||
|         self.assertRaises(ValidationError, create_invalid_user) | ||||
|  | ||||
|         # define a subclass with a different primary key field than the | ||||
|         # parent | ||||
|         with self.assertRaises(ValueError): | ||||
|         def define_invalid_user(): | ||||
|             class EmailUser(User): | ||||
|                 email = StringField(primary_key=True) | ||||
|         self.assertRaises(ValueError, define_invalid_user) | ||||
|  | ||||
|         class EmailUser(User): | ||||
|             email = StringField() | ||||
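The hunk above asserts that a `primary_key=True` field is stored under `_id`. A compact sketch of the mapping, illustrative only:

    from mongoengine import *

    class User(Document):
        username = StringField(primary_key=True)
        name = StringField()

    assert User._fields['username'].db_field == '_id'

    User.drop_collection()
    User(username='mongo', name='Mongo').save()
    raw = User._get_collection().find_one()
    assert raw['_id'] == 'mongo'   # the primary key value is the document's _id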
| @@ -401,10 +411,12 @@ class InstanceTest(unittest.TestCase): | ||||
|  | ||||
|         # Mimic Place and NicePlace definitions being in a different file | ||||
|         # and the NicePlace model not being imported in at query time. | ||||
|         from mongoengine.base import _document_registry | ||||
|         del(_document_registry['Place.NicePlace']) | ||||
|  | ||||
|         with self.assertRaises(NotRegistered): | ||||
|             list(Place.objects.all()) | ||||
|         def query_without_importing_nice_place(): | ||||
|             print Place.objects.all() | ||||
|         self.assertRaises(NotRegistered, query_without_importing_nice_place) | ||||
|  | ||||
|     def test_document_registry_regressions(self): | ||||
|  | ||||
| @@ -435,15 +447,6 @@ class InstanceTest(unittest.TestCase): | ||||
|  | ||||
|         person.to_dbref() | ||||
|  | ||||
|     def test_save_abstract_document(self): | ||||
|         """Saving an abstract document should fail.""" | ||||
|         class Doc(Document): | ||||
|             name = StringField() | ||||
|             meta = {'abstract': True} | ||||
|  | ||||
|         with self.assertRaises(InvalidDocumentError): | ||||
|             Doc(name='aaa').save() | ||||
|  | ||||
|     def test_reload(self): | ||||
|         """Ensure that attributes may be reloaded. | ||||
|         """ | ||||
| @@ -742,7 +745,7 @@ class InstanceTest(unittest.TestCase): | ||||
|  | ||||
|         try: | ||||
|             t.save() | ||||
|         except ValidationError as e: | ||||
|         except ValidationError, e: | ||||
|             expect_msg = "Draft entries may not have a publication date." | ||||
|             self.assertTrue(expect_msg in e.message) | ||||
|             self.assertEqual(e.to_dict(), {'__all__': expect_msg}) | ||||
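Besides reverting `except ... as e` to the Python-2-only comma form, these hunks exercise document-level validation through a `clean()` hook, which runs as part of `validate()` and `save()`. A minimal sketch of the pattern the test relies on:

    from datetime import datetime
    from mongoengine import *

    class Essay(Document):
        status = StringField(choices=('Published', 'Draft'), required=True)
        pub_date = DateTimeField()

        def clean(self):
            # document-level rule, checked on every validate()/save()
            if self.status == 'Draft' and self.pub_date is not None:
                raise ValidationError(
                    'Draft entries may not have a publication date.')

    essay = Essay(status='Draft', pub_date=datetime.now())
    try:
        essay.validate()
    except ValidationError as e:
        assert 'publication date' in str(e)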
| @@ -781,7 +784,7 @@ class InstanceTest(unittest.TestCase): | ||||
|         t = TestDocument(doc=TestEmbeddedDocument(x=10, y=25, z=15)) | ||||
|         try: | ||||
|             t.save() | ||||
|         except ValidationError as e: | ||||
|         except ValidationError, e: | ||||
|             expect_msg = "Value of z != x + y" | ||||
|             self.assertTrue(expect_msg in e.message) | ||||
|             self.assertEqual(e.to_dict(), {'doc': {'__all__': expect_msg}}) | ||||
| @@ -795,10 +798,8 @@ class InstanceTest(unittest.TestCase): | ||||
|  | ||||
|     def test_modify_empty(self): | ||||
|         doc = self.Person(name="bob", age=10).save() | ||||
|  | ||||
|         with self.assertRaises(InvalidDocumentError): | ||||
|             self.Person().modify(set__age=10) | ||||
|  | ||||
|         self.assertRaises( | ||||
|             InvalidDocumentError, lambda: self.Person().modify(set__age=10)) | ||||
|         self.assertDbEqual([dict(doc.to_mongo())]) | ||||
|  | ||||
|     def test_modify_invalid_query(self): | ||||
| @@ -806,8 +807,9 @@ class InstanceTest(unittest.TestCase): | ||||
|         doc2 = self.Person(name="jim", age=20).save() | ||||
|         docs = [dict(doc1.to_mongo()), dict(doc2.to_mongo())] | ||||
|  | ||||
|         with self.assertRaises(InvalidQueryError): | ||||
|             doc1.modify({'id': doc2.id}, set__value=20) | ||||
|         self.assertRaises( | ||||
|             InvalidQueryError, | ||||
|             lambda: doc1.modify(dict(id=doc2.id), set__value=20)) | ||||
|  | ||||
|         self.assertDbEqual(docs) | ||||
|  | ||||
| @@ -816,7 +818,7 @@ class InstanceTest(unittest.TestCase): | ||||
|         doc2 = self.Person(name="jim", age=20).save() | ||||
|         docs = [dict(doc1.to_mongo()), dict(doc2.to_mongo())] | ||||
|  | ||||
|         assert not doc1.modify({'name': doc2.name}, set__age=100) | ||||
|         assert not doc1.modify(dict(name=doc2.name), set__age=100) | ||||
|  | ||||
|         self.assertDbEqual(docs) | ||||
|  | ||||
| @@ -825,7 +827,7 @@ class InstanceTest(unittest.TestCase): | ||||
|         doc2 = self.Person(id=ObjectId(), name="jim", age=20) | ||||
|         docs = [dict(doc1.to_mongo())] | ||||
|  | ||||
|         assert not doc2.modify({'name': doc2.name}, set__age=100) | ||||
|         assert not doc2.modify(dict(name=doc2.name), set__age=100) | ||||
|  | ||||
|         self.assertDbEqual(docs) | ||||
|  | ||||
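The `modify()` hunks above encode its contract: `Document.modify(query, **update)` issues an atomic find-and-modify scoped to the current document, returning True and refreshing the instance when the combined query matches, and False (writing nothing) when it does not; a query addressing a different `id` raises `InvalidQueryError`. A short sketch, reusing the `Person` document (with `name` and `age` fields) defined in this test module:

    person = Person(name='bob', age=10).save()

    # matches this document: update applied, instance refreshed, True returned
    assert person.modify({'name': 'bob'}, set__age=11)
    assert person.age == 11

    # does not match this document: nothing written, False returned
    assert not person.modify({'name': 'someone else'}, set__age=99)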
| @@ -1232,19 +1234,6 @@ class InstanceTest(unittest.TestCase): | ||||
|         self.assertEqual(person.name, None) | ||||
|         self.assertEqual(person.age, None) | ||||
|  | ||||
|     def test_update_rename_operator(self): | ||||
|         """Test the $rename operator.""" | ||||
|         coll = self.Person._get_collection() | ||||
|         doc = self.Person(name='John').save() | ||||
|         raw_doc = coll.find_one({'_id': doc.pk}) | ||||
|         self.assertEqual(set(raw_doc.keys()), set(['_id', '_cls', 'name'])) | ||||
|  | ||||
|         doc.update(rename__name='first_name') | ||||
|         raw_doc = coll.find_one({'_id': doc.pk}) | ||||
|         self.assertEqual(set(raw_doc.keys()), | ||||
|                          set(['_id', '_cls', 'first_name'])) | ||||
|         self.assertEqual(raw_doc['first_name'], 'John') | ||||
|  | ||||
|     def test_inserts_if_you_set_the_pk(self): | ||||
|         p1 = self.Person(name='p1', id=bson.ObjectId()).save() | ||||
|         p2 = self.Person(name='p2') | ||||
| @@ -1304,11 +1293,12 @@ class InstanceTest(unittest.TestCase): | ||||
|  | ||||
|     def test_document_update(self): | ||||
|  | ||||
|         # try updating a non-saved document | ||||
|         with self.assertRaises(OperationError): | ||||
|         def update_not_saved_raises(): | ||||
|             person = self.Person(name='dcrosta') | ||||
|             person.update(set__name='Dan Crosta') | ||||
|  | ||||
|         self.assertRaises(OperationError, update_not_saved_raises) | ||||
|  | ||||
|         author = self.Person(name='dcrosta') | ||||
|         author.save() | ||||
|  | ||||
| @@ -1318,17 +1308,19 @@ class InstanceTest(unittest.TestCase): | ||||
|         p1 = self.Person.objects.first() | ||||
|         self.assertEqual(p1.name, author.name) | ||||
|  | ||||
|         # try sending an empty update | ||||
|         with self.assertRaises(OperationError): | ||||
|         def update_no_value_raises(): | ||||
|             person = self.Person.objects.first() | ||||
|             person.update() | ||||
|  | ||||
|         # update that doesn't explicitly specify an operator should default | ||||
|         # to 'set__' | ||||
|         person = self.Person.objects.first() | ||||
|         person.update(name="Dan") | ||||
|         person.reload() | ||||
|         self.assertEqual("Dan", person.name) | ||||
|         self.assertRaises(OperationError, update_no_value_raises) | ||||
|  | ||||
|         def update_no_op_should_default_to_set(): | ||||
|             person = self.Person.objects.first() | ||||
|             person.update(name="Dan") | ||||
|             person.reload() | ||||
|             return person.name | ||||
|  | ||||
|         self.assertEqual("Dan", update_no_op_should_default_to_set()) | ||||
|  | ||||
|     def test_update_unique_field(self): | ||||
|         class Doc(Document): | ||||
| @@ -1337,8 +1329,8 @@ class InstanceTest(unittest.TestCase): | ||||
|         doc1 = Doc(name="first").save() | ||||
|         doc2 = Doc(name="second").save() | ||||
|  | ||||
|         with self.assertRaises(NotUniqueError): | ||||
|             doc2.update(set__name=doc1.name) | ||||
|         self.assertRaises(NotUniqueError, lambda: | ||||
|                           doc2.update(set__name=doc1.name)) | ||||
|  | ||||
|     def test_embedded_update(self): | ||||
|         """ | ||||
| @@ -1856,13 +1848,15 @@ class InstanceTest(unittest.TestCase): | ||||
|  | ||||
|     def test_duplicate_db_fields_raise_invalid_document_error(self): | ||||
|         """Ensure a InvalidDocumentError is thrown if duplicate fields | ||||
|         declare the same db_field. | ||||
|         """ | ||||
|         with self.assertRaises(InvalidDocumentError): | ||||
|         declare the same db_field""" | ||||
|  | ||||
|         def throw_invalid_document_error(): | ||||
|             class Foo(Document): | ||||
|                 name = StringField() | ||||
|                 name2 = StringField(db_field='name') | ||||
|  | ||||
|         self.assertRaises(InvalidDocumentError, throw_invalid_document_error) | ||||
|  | ||||
|     def test_invalid_son(self): | ||||
|         """Raise an error if loading invalid data""" | ||||
|         class Occurrence(EmbeddedDocument): | ||||
| @@ -1874,17 +1868,11 @@ class InstanceTest(unittest.TestCase): | ||||
|             forms = ListField(StringField(), default=list) | ||||
|             occurs = ListField(EmbeddedDocumentField(Occurrence), default=list) | ||||
|  | ||||
|         with self.assertRaises(InvalidDocumentError): | ||||
|             Word._from_son({ | ||||
|                 'stem': [1, 2, 3], | ||||
|                 'forms': 1, | ||||
|                 'count': 'one', | ||||
|                 'occurs': {"hello": None} | ||||
|             }) | ||||
|         def raise_invalid_document(): | ||||
|             Word._from_son({'stem': [1, 2, 3], 'forms': 1, 'count': 'one', | ||||
|                             'occurs': {"hello": None}}) | ||||
|  | ||||
|         # Tests for issue #1438: https://github.com/MongoEngine/mongoengine/issues/1438 | ||||
|         with self.assertRaises(ValueError): | ||||
|             Word._from_son('this is not a valid SON dict') | ||||
|         self.assertRaises(InvalidDocumentError, raise_invalid_document) | ||||
|  | ||||
|     def test_reverse_delete_rule_cascade_and_nullify(self): | ||||
|         """Ensure that a referenced document is also deleted upon deletion. | ||||
| @@ -2115,7 +2103,8 @@ class InstanceTest(unittest.TestCase): | ||||
|         self.assertEqual(Bar.objects.get().foo, None) | ||||
|  | ||||
|     def test_invalid_reverse_delete_rule_raise_errors(self): | ||||
|         with self.assertRaises(InvalidDocumentError): | ||||
|  | ||||
|         def throw_invalid_document_error(): | ||||
|             class Blog(Document): | ||||
|                 content = StringField() | ||||
|                 authors = MapField(ReferenceField( | ||||
| @@ -2125,15 +2114,21 @@ class InstanceTest(unittest.TestCase): | ||||
|                         self.Person, | ||||
|                         reverse_delete_rule=NULLIFY)) | ||||
|  | ||||
|         with self.assertRaises(InvalidDocumentError): | ||||
|         self.assertRaises(InvalidDocumentError, throw_invalid_document_error) | ||||
|  | ||||
|         def throw_invalid_document_error_embedded(): | ||||
|             class Parents(EmbeddedDocument): | ||||
|                 father = ReferenceField('Person', reverse_delete_rule=DENY) | ||||
|                 mother = ReferenceField('Person', reverse_delete_rule=DENY) | ||||
|  | ||||
|         self.assertRaises( | ||||
|             InvalidDocumentError, throw_invalid_document_error_embedded) | ||||
|  | ||||
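Both helper functions above pin down the boundary: `reverse_delete_rule` is only valid on a `ReferenceField` declared directly on a `Document`; declaring it inside a `MapField` or on an `EmbeddedDocument` raises `InvalidDocumentError`. A sketch of the supported form, illustrative only:

    from mongoengine import *

    class Author(Document):
        name = StringField()

    class BlogPost(Document):
        content = StringField()
        # valid: the rule sits directly on a top-level ReferenceField
        author = ReferenceField(Author, reverse_delete_rule=CASCADE)

    Author.drop_collection()
    BlogPost.drop_collection()

    author = Author(name='Ross').save()
    BlogPost(content='...', author=author).save()
    author.delete()   # cascades to the referencing post
    assert BlogPost.objects.count() == 0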
|     def test_reverse_delete_rule_cascade_recurs(self): | ||||
|         """Ensure that a chain of documents is also deleted upon cascaded | ||||
|         deletion. | ||||
|         """ | ||||
|  | ||||
|         class BlogPost(Document): | ||||
|             content = StringField() | ||||
|             author = ReferenceField(self.Person, reverse_delete_rule=CASCADE) | ||||
| @@ -2349,14 +2344,15 @@ class InstanceTest(unittest.TestCase): | ||||
|         pickle_doc.save() | ||||
|         pickle_doc.delete() | ||||
|  | ||||
|     def test_override_method_with_field(self): | ||||
|         """Test creating a field with a field name that would override | ||||
|         the "validate" method. | ||||
|         """ | ||||
|         with self.assertRaises(InvalidDocumentError): | ||||
|     def test_throw_invalid_document_error(self): | ||||
|  | ||||
|         # defining a field that shadows an existing Document method should fail | ||||
|         def throw_invalid_document_error(): | ||||
|             class Blog(Document): | ||||
|                 validate = DictField() | ||||
|  | ||||
|         self.assertRaises(InvalidDocumentError, throw_invalid_document_error) | ||||
|  | ||||
|     def test_mutating_documents(self): | ||||
|  | ||||
|         class B(EmbeddedDocument): | ||||
| @@ -2819,10 +2815,11 @@ class InstanceTest(unittest.TestCase): | ||||
|         log.log = "Saving" | ||||
|         log.save() | ||||
|  | ||||
|         # try to change the shard key | ||||
|         with self.assertRaises(OperationError): | ||||
|         def change_shard_key(): | ||||
|             log.machine = "127.0.0.1" | ||||
|  | ||||
|         self.assertRaises(OperationError, change_shard_key) | ||||
|  | ||||
|     def test_shard_key_in_embedded_document(self): | ||||
|         class Foo(EmbeddedDocument): | ||||
|             foo = StringField() | ||||
| @@ -2843,11 +2840,12 @@ class InstanceTest(unittest.TestCase): | ||||
|         bar_doc.bar = 'baz' | ||||
|         bar_doc.save() | ||||
|  | ||||
|         # try to change the shard key | ||||
|         with self.assertRaises(OperationError): | ||||
|         def change_shard_key(): | ||||
|             bar_doc.foo.foo = 'something' | ||||
|             bar_doc.save() | ||||
|  | ||||
|         self.assertRaises(OperationError, change_shard_key) | ||||
|  | ||||
|     def test_shard_key_primary(self): | ||||
|         class LogEntry(Document): | ||||
|             machine = StringField(primary_key=True) | ||||
| @@ -2868,10 +2866,11 @@ class InstanceTest(unittest.TestCase): | ||||
|         log.log = "Saving" | ||||
|         log.save() | ||||
|  | ||||
|         # try to change the shard key | ||||
|         with self.assertRaises(OperationError): | ||||
|         def change_shard_key(): | ||||
|             log.machine = "127.0.0.1" | ||||
|  | ||||
|         self.assertRaises(OperationError, change_shard_key) | ||||
|  | ||||
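The three shard-key hunks all encode the same rule: fields listed in `meta['shard_key']` are frozen once the document has been saved, and assigning to one raises `OperationError` immediately. A minimal sketch, illustrative only:

    from mongoengine import *

    class LogEntry(Document):
        machine = StringField()
        log = StringField()
        meta = {'shard_key': ('machine',)}

    LogEntry.drop_collection()

    entry = LogEntry(machine='Localhost', log='hello').save()
    entry.log = 'still writable'   # non-shard-key fields stay mutable
    entry.save()

    try:
        entry.machine = '127.0.0.1'   # shard-key field: immutable after save
    except OperationError:
        pass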
|     def test_kwargs_simple(self): | ||||
|  | ||||
|         class Embedded(EmbeddedDocument): | ||||
| @@ -2956,9 +2955,11 @@ class InstanceTest(unittest.TestCase): | ||||
|     def test_bad_mixed_creation(self): | ||||
|         """Ensure that document gives correct error when duplicating arguments | ||||
|         """ | ||||
|         with self.assertRaises(TypeError): | ||||
|         def construct_bad_instance(): | ||||
|             return self.Person("Test User", 42, name="Bad User") | ||||
|  | ||||
|         self.assertRaises(TypeError, construct_bad_instance) | ||||
|  | ||||
|     def test_data_contains_id_field(self): | ||||
|         """Ensure that asking for _data returns 'id' | ||||
|         """ | ||||
|   | ||||
| @@ -1,3 +1,6 @@ | ||||
| import sys | ||||
| sys.path[0:0] = [""] | ||||
|  | ||||
| import unittest | ||||
| import uuid | ||||
|  | ||||
|   | ||||
| @@ -1,4 +1,7 @@ | ||||
| # -*- coding: utf-8 -*- | ||||
| import sys | ||||
| sys.path[0:0] = [""] | ||||
|  | ||||
| import unittest | ||||
| from datetime import datetime | ||||
|  | ||||
| @@ -57,7 +60,7 @@ class ValidatorErrorTest(unittest.TestCase): | ||||
|  | ||||
|         try: | ||||
|             User().validate() | ||||
|         except ValidationError as e: | ||||
|         except ValidationError, e: | ||||
|             self.assertTrue("User:None" in e.message) | ||||
|             self.assertEqual(e.to_dict(), { | ||||
|                 'username': 'Field is required', | ||||
| @@ -67,7 +70,7 @@ class ValidatorErrorTest(unittest.TestCase): | ||||
|         user.name = None | ||||
|         try: | ||||
|             user.save() | ||||
|         except ValidationError as e: | ||||
|         except ValidationError, e: | ||||
|             self.assertTrue("User:RossC0" in e.message) | ||||
|             self.assertEqual(e.to_dict(), { | ||||
|                 'name': 'Field is required'}) | ||||
| @@ -115,7 +118,7 @@ class ValidatorErrorTest(unittest.TestCase): | ||||
|  | ||||
|         try: | ||||
|             Doc(id="bad").validate() | ||||
|         except ValidationError as e: | ||||
|         except ValidationError, e: | ||||
|             self.assertTrue("SubDoc:None" in e.message) | ||||
|             self.assertEqual(e.to_dict(), { | ||||
|                 "e": {'val': 'OK could not be converted to int'}}) | ||||
| @@ -133,7 +136,7 @@ class ValidatorErrorTest(unittest.TestCase): | ||||
|         doc.e.val = "OK" | ||||
|         try: | ||||
|             doc.save() | ||||
|         except ValidationError as e: | ||||
|         except ValidationError, e: | ||||
|             self.assertTrue("Doc:test" in e.message) | ||||
|             self.assertEqual(e.to_dict(), { | ||||
|                 "e": {'val': 'OK could not be converted to int'}}) | ||||
| @@ -153,14 +156,14 @@ class ValidatorErrorTest(unittest.TestCase): | ||||
|  | ||||
|         s = SubDoc() | ||||
|  | ||||
|         self.assertRaises(ValidationError, s.validate) | ||||
|         self.assertRaises(ValidationError, lambda: s.validate()) | ||||
|  | ||||
|         d1.e = s | ||||
|         d2.e = s | ||||
|  | ||||
|         del d1 | ||||
|  | ||||
|         self.assertRaises(ValidationError, d2.validate) | ||||
|         self.assertRaises(ValidationError, lambda: d2.validate()) | ||||
|  | ||||
|     def test_parent_reference_in_child_document(self): | ||||
|         """ | ||||
|   | ||||
										
											
[File diff suppressed because it is too large]
											
										
									
								
							| @@ -1,16 +1,18 @@ | ||||
| # -*- coding: utf-8 -*- | ||||
| import sys | ||||
| sys.path[0:0] = [""] | ||||
|  | ||||
| import copy | ||||
| import os | ||||
| import unittest | ||||
| import tempfile | ||||
|  | ||||
| import gridfs | ||||
| import six | ||||
|  | ||||
| from nose.plugins.skip import SkipTest | ||||
| from mongoengine import * | ||||
| from mongoengine.connection import get_db | ||||
| from mongoengine.python_support import StringIO | ||||
| from mongoengine.python_support import b, StringIO | ||||
|  | ||||
| try: | ||||
|     from PIL import Image | ||||
| @@ -18,13 +20,15 @@ try: | ||||
| except ImportError: | ||||
|     HAS_PIL = False | ||||
|  | ||||
| from tests.utils import MongoDBTestCase | ||||
|  | ||||
| TEST_IMAGE_PATH = os.path.join(os.path.dirname(__file__), 'mongoengine.png') | ||||
| TEST_IMAGE2_PATH = os.path.join(os.path.dirname(__file__), 'mongodb_leaf.png') | ||||
|  | ||||
|  | ||||
| class FileTest(MongoDBTestCase): | ||||
| class FileTest(unittest.TestCase): | ||||
|  | ||||
|     def setUp(self): | ||||
|         connect(db='mongoenginetest') | ||||
|         self.db = get_db() | ||||
|  | ||||
|     def tearDown(self): | ||||
|         self.db.drop_collection('fs.files') | ||||
| @@ -45,7 +49,7 @@ class FileTest(MongoDBTestCase): | ||||
|  | ||||
|         PutFile.drop_collection() | ||||
|  | ||||
|         text = six.b('Hello, World!') | ||||
|         text = b('Hello, World!') | ||||
|         content_type = 'text/plain' | ||||
|  | ||||
|         putfile = PutFile() | ||||
| @@ -84,8 +88,8 @@ class FileTest(MongoDBTestCase): | ||||
|  | ||||
|         StreamFile.drop_collection() | ||||
|  | ||||
|         text = six.b('Hello, World!') | ||||
|         more_text = six.b('Foo Bar') | ||||
|         text = b('Hello, World!') | ||||
|         more_text = b('Foo Bar') | ||||
|         content_type = 'text/plain' | ||||
|  | ||||
|         streamfile = StreamFile() | ||||
| @@ -119,8 +123,8 @@ class FileTest(MongoDBTestCase): | ||||
|  | ||||
|         StreamFile.drop_collection() | ||||
|  | ||||
|         text = six.b('Hello, World!') | ||||
|         more_text = six.b('Foo Bar') | ||||
|         text = b('Hello, World!') | ||||
|         more_text = b('Foo Bar') | ||||
|         content_type = 'text/plain' | ||||
|  | ||||
|         streamfile = StreamFile() | ||||
| @@ -151,8 +155,8 @@ class FileTest(MongoDBTestCase): | ||||
|         class SetFile(Document): | ||||
|             the_file = FileField() | ||||
|  | ||||
|         text = six.b('Hello, World!') | ||||
|         more_text = six.b('Foo Bar') | ||||
|         text = b('Hello, World!') | ||||
|         more_text = b('Foo Bar') | ||||
|  | ||||
|         SetFile.drop_collection() | ||||
|  | ||||
| @@ -181,7 +185,7 @@ class FileTest(MongoDBTestCase): | ||||
|         GridDocument.drop_collection() | ||||
|  | ||||
|         with tempfile.TemporaryFile() as f: | ||||
|             f.write(six.b("Hello World!")) | ||||
|             f.write(b("Hello World!")) | ||||
|             f.flush() | ||||
|  | ||||
|             # Test without default | ||||
| @@ -198,7 +202,7 @@ class FileTest(MongoDBTestCase): | ||||
|             self.assertEqual(doc_b.the_file.grid_id, doc_c.the_file.grid_id) | ||||
|  | ||||
|             # Test with default | ||||
|             doc_d = GridDocument(the_file=six.b('')) | ||||
|             doc_d = GridDocument(the_file=b('')) | ||||
|             doc_d.save() | ||||
|  | ||||
|             doc_e = GridDocument.objects.with_id(doc_d.id) | ||||
| @@ -224,7 +228,7 @@ class FileTest(MongoDBTestCase): | ||||
|         # First instance | ||||
|         test_file = TestFile() | ||||
|         test_file.name = "Hello, World!" | ||||
|         test_file.the_file.put(six.b('Hello, World!')) | ||||
|         test_file.the_file.put(b('Hello, World!')) | ||||
|         test_file.save() | ||||
|  | ||||
|         # Second instance | ||||
| @@ -278,7 +282,7 @@ class FileTest(MongoDBTestCase): | ||||
|  | ||||
|         test_file = TestFile() | ||||
|         self.assertFalse(bool(test_file.the_file)) | ||||
|         test_file.the_file.put(six.b('Hello, World!'), content_type='text/plain') | ||||
|         test_file.the_file.put(b('Hello, World!'), content_type='text/plain') | ||||
|         test_file.save() | ||||
|         self.assertTrue(bool(test_file.the_file)) | ||||
|  | ||||
| @@ -298,7 +302,7 @@ class FileTest(MongoDBTestCase): | ||||
|         class TestFile(Document): | ||||
|             the_file = FileField() | ||||
|              | ||||
|         text = six.b('Hello, World!') | ||||
|         text = b('Hello, World!') | ||||
|         content_type = 'text/plain' | ||||
|  | ||||
|         testfile = TestFile() | ||||
| @@ -342,7 +346,7 @@ class FileTest(MongoDBTestCase): | ||||
|         testfile.the_file.put(text, content_type=content_type, filename="hello") | ||||
|         testfile.save() | ||||
|          | ||||
|         text = six.b('Bonjour, World!') | ||||
|         text = b('Bonjour, World!') | ||||
|         testfile.the_file.replace(text, content_type=content_type, filename="hello") | ||||
|         testfile.save() | ||||
|          | ||||
| @@ -368,14 +372,14 @@ class FileTest(MongoDBTestCase): | ||||
|         TestImage.drop_collection() | ||||
|  | ||||
|         with tempfile.TemporaryFile() as f: | ||||
|             f.write(six.b("Hello World!")) | ||||
|             f.write(b("Hello World!")) | ||||
|             f.flush() | ||||
|  | ||||
|             t = TestImage() | ||||
|             try: | ||||
|                 t.image.put(f) | ||||
|                 self.fail("Should have raised an invalidation error") | ||||
|             except ValidationError as e: | ||||
|             except ValidationError, e: | ||||
|                 self.assertEqual("%s" % e, "Invalid image: cannot identify image file %s" % f) | ||||
|  | ||||
|         t = TestImage() | ||||
| @@ -492,7 +496,7 @@ class FileTest(MongoDBTestCase): | ||||
|         # First instance | ||||
|         test_file = TestFile() | ||||
|         test_file.name = "Hello, World!" | ||||
|         test_file.the_file.put(six.b('Hello, World!'), | ||||
|         test_file.the_file.put(b('Hello, World!'), | ||||
|                           name="hello.txt") | ||||
|         test_file.save() | ||||
|  | ||||
| @@ -500,15 +504,16 @@ class FileTest(MongoDBTestCase): | ||||
|         self.assertEqual(data.get('name'), 'hello.txt') | ||||
|  | ||||
|         test_file = TestFile.objects.first() | ||||
|         self.assertEqual(test_file.the_file.read(), six.b('Hello, World!')) | ||||
|         self.assertEqual(test_file.the_file.read(), | ||||
|                           b('Hello, World!')) | ||||
|  | ||||
|         test_file = TestFile.objects.first() | ||||
|         test_file.the_file = six.b('HELLO, WORLD!') | ||||
|         test_file.the_file = b('HELLO, WORLD!') | ||||
|         test_file.save() | ||||
|  | ||||
|         test_file = TestFile.objects.first() | ||||
|         self.assertEqual(test_file.the_file.read(), | ||||
|                          six.b('HELLO, WORLD!')) | ||||
|                           b('HELLO, WORLD!')) | ||||
|  | ||||
|     def test_copyable(self): | ||||
|         class PutFile(Document): | ||||
| @@ -516,7 +521,7 @@ class FileTest(MongoDBTestCase): | ||||
|  | ||||
|         PutFile.drop_collection() | ||||
|  | ||||
|         text = six.b('Hello, World!') | ||||
|         text = b('Hello, World!') | ||||
|         content_type = 'text/plain' | ||||
|  | ||||
|         putfile = PutFile() | ||||
|   | ||||
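The file-field hunks above swap `six.b(...)` for the older `b(...)` helper from `mongoengine.python_support`; both exist to spell a bytes literal that works on Python 2 and 3. A rough equivalent of what such a helper does — a sketch, not either library's actual source:

    def b(s):
        # on Python 3, encode the text to bytes; on Python 2, str is already bytes
        return s if isinstance(s, bytes) else s.encode('latin-1')

    assert b('Hello, World!') == 'Hello, World!'.encode('latin-1')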
| @@ -1,4 +1,7 @@ | ||||
| # -*- coding: utf-8 -*- | ||||
| import sys | ||||
| sys.path[0:0] = [""] | ||||
|  | ||||
| import unittest | ||||
|  | ||||
| from mongoengine import * | ||||
|   | ||||
							
								
								
									
tests/migration/__init__.py (new file, 11 lines)
									
								
							| @@ -0,0 +1,11 @@ | ||||
| import unittest | ||||
|  | ||||
| from convert_to_new_inheritance_model import * | ||||
| from decimalfield_as_float import * | ||||
| from referencefield_dbref_to_object_id import * | ||||
| from turn_off_inheritance import * | ||||
| from uuidfield_to_binary import * | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|     unittest.main() | ||||
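Note that these bare module names rely on Python 2's implicit relative imports; a sketch of the equivalent package `__init__.py` that also runs on Python 3 — hypothetical, not part of the diff:

    from __future__ import absolute_import
    import unittest

    from .convert_to_new_inheritance_model import *
    from .decimalfield_as_float import *
    from .referencefield_dbref_to_object_id import *
    from .turn_off_inheritance import *
    from .uuidfield_to_binary import *

    if __name__ == '__main__':
        unittest.main()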
							
								
								
									
tests/migration/convert_to_new_inheritance_model.py (new file, 51 lines)
									
								
							| @@ -0,0 +1,51 @@ | ||||
| # -*- coding: utf-8 -*- | ||||
| import unittest | ||||
|  | ||||
| from mongoengine import Document, connect | ||||
| from mongoengine.connection import get_db | ||||
| from mongoengine.fields import StringField | ||||
|  | ||||
| __all__ = ('ConvertToNewInheritanceModel', ) | ||||
|  | ||||
|  | ||||
| class ConvertToNewInheritanceModel(unittest.TestCase): | ||||
|  | ||||
|     def setUp(self): | ||||
|         connect(db='mongoenginetest') | ||||
|         self.db = get_db() | ||||
|  | ||||
|     def tearDown(self): | ||||
|         for collection in self.db.collection_names(): | ||||
|             if 'system.' in collection: | ||||
|                 continue | ||||
|             self.db.drop_collection(collection) | ||||
|  | ||||
|     def test_how_to_convert_to_the_new_inheritance_model(self): | ||||
|         """Demonstrates migrating from 0.7 to 0.8 | ||||
|         """ | ||||
|  | ||||
|         # 1. Declaration of the class | ||||
|         class Animal(Document): | ||||
|             name = StringField() | ||||
|             meta = { | ||||
|                 'allow_inheritance': True, | ||||
|                 'indexes': ['name'] | ||||
|             } | ||||
|  | ||||
|         # 2. Remove _types | ||||
|         collection = Animal._get_collection() | ||||
|         collection.update({}, {"$unset": {"_types": 1}}, multi=True) | ||||
|  | ||||
|         # 3. Confirm extra data is removed | ||||
|         count = collection.find({'_types': {"$exists": True}}).count() | ||||
|         self.assertEqual(0, count) | ||||
|  | ||||
|         # 4. Remove indexes | ||||
|         info = collection.index_information() | ||||
|         indexes_to_drop = [key for key, value in info.iteritems() | ||||
|                            if '_types' in dict(value['key'])] | ||||
|         for index in indexes_to_drop: | ||||
|             collection.drop_index(index) | ||||
|  | ||||
|         # 5. Recreate indexes | ||||
|         Animal.ensure_indexes() | ||||
							
								
								
									
tests/migration/decimalfield_as_float.py (new file, 50 lines)
									
								
							| @@ -0,0 +1,50 @@ | ||||
| # -*- coding: utf-8 -*- | ||||
| import unittest | ||||
| import decimal | ||||
| from decimal import Decimal | ||||
|  | ||||
| from mongoengine import Document, connect | ||||
| from mongoengine.connection import get_db | ||||
| from mongoengine.fields import StringField, DecimalField, ListField | ||||
|  | ||||
| __all__ = ('ConvertDecimalField', ) | ||||
|  | ||||
|  | ||||
| class ConvertDecimalField(unittest.TestCase): | ||||
|  | ||||
|     def setUp(self): | ||||
|         connect(db='mongoenginetest') | ||||
|         self.db = get_db() | ||||
|  | ||||
|     def test_how_to_convert_decimal_fields(self): | ||||
|         """Demonstrates migrating from 0.7 to 0.8 | ||||
|         """ | ||||
|  | ||||
|         # 1. Old definition - decimals stored as strings | ||||
|         class Person(Document): | ||||
|             name = StringField() | ||||
|             money = DecimalField(force_string=True) | ||||
|             monies = ListField(DecimalField(force_string=True)) | ||||
|  | ||||
|         Person.drop_collection() | ||||
|         Person(name="Wilson Jr", money=Decimal("2.50"), | ||||
|                monies=[Decimal("2.10"), Decimal("5.00")]).save() | ||||
|  | ||||
|         # 2. Start the migration by changing the schema | ||||
|         # Change DecimalField - add precision and rounding settings | ||||
|         class Person(Document): | ||||
|             name = StringField() | ||||
|             money = DecimalField(precision=2, rounding=decimal.ROUND_HALF_UP) | ||||
|             monies = ListField(DecimalField(precision=2, | ||||
|                                             rounding=decimal.ROUND_HALF_UP)) | ||||
|  | ||||
|         # 3. Loop over all the objects and mark the decimal fields as changed | ||||
|         for p in Person.objects: | ||||
|             p._mark_as_changed('money') | ||||
|             p._mark_as_changed('monies') | ||||
|             p.save() | ||||
|  | ||||
|         # 4. Confirmation of the fix! | ||||
|         wilson = Person.objects(name="Wilson Jr").as_pymongo()[0] | ||||
|         self.assertTrue(isinstance(wilson['money'], float)) | ||||
|         self.assertTrue(all([isinstance(m, float) for m in wilson['monies']])) | ||||
							
								
								
									
tests/migration/referencefield_dbref_to_object_id.py (new file, 52 lines)
									
								
							| @@ -0,0 +1,52 @@ | ||||
| # -*- coding: utf-8 -*- | ||||
| import unittest | ||||
|  | ||||
| from mongoengine import Document, connect | ||||
| from mongoengine.connection import get_db | ||||
| from mongoengine.fields import StringField, ReferenceField, ListField | ||||
|  | ||||
| __all__ = ('ConvertToObjectIdsModel', ) | ||||
|  | ||||
|  | ||||
| class ConvertToObjectIdsModel(unittest.TestCase): | ||||
|  | ||||
|     def setUp(self): | ||||
|         connect(db='mongoenginetest') | ||||
|         self.db = get_db() | ||||
|  | ||||
|     def test_how_to_convert_to_object_id_reference_fields(self): | ||||
|         """Demonstrates migrating from 0.7 to 0.8 | ||||
|         """ | ||||
|  | ||||
|         # 1. Old definition - using dbrefs | ||||
|         class Person(Document): | ||||
|             name = StringField() | ||||
|             parent = ReferenceField('self', dbref=True) | ||||
|             friends = ListField(ReferenceField('self', dbref=True)) | ||||
|  | ||||
|         Person.drop_collection() | ||||
|  | ||||
|         p1 = Person(name="Wilson", parent=None).save() | ||||
|         f1 = Person(name="John", parent=None).save() | ||||
|         f2 = Person(name="Paul", parent=None).save() | ||||
|         f3 = Person(name="George", parent=None).save() | ||||
|         f4 = Person(name="Ringo", parent=None).save() | ||||
|         Person(name="Wilson Jr", parent=p1, friends=[f1, f2, f3, f4]).save() | ||||
|  | ||||
|         # 2. Start the migration by changing the schema | ||||
|         # Change ReferenceField, as dbref now defaults to False | ||||
|         class Person(Document): | ||||
|             name = StringField() | ||||
|             parent = ReferenceField('self') | ||||
|             friends = ListField(ReferenceField('self')) | ||||
|  | ||||
|         # 3. Loop over all the objects and mark parent and friends as changed | ||||
|         for p in Person.objects: | ||||
|             p._mark_as_changed('parent') | ||||
|             p._mark_as_changed('friends') | ||||
|             p.save() | ||||
|  | ||||
|         # 4. Confirmation of the fix! | ||||
|         wilson = Person.objects(name="Wilson Jr").as_pymongo()[0] | ||||
|         self.assertEqual(p1.id, wilson['parent']) | ||||
|         self.assertEqual([f1.id, f2.id, f3.id, f4.id], wilson['friends']) | ||||
							
								
								
									
tests/migration/turn_off_inheritance.py (new file, 62 lines)
									
								
							| @@ -0,0 +1,62 @@ | ||||
| # -*- coding: utf-8 -*- | ||||
| import unittest | ||||
|  | ||||
| from mongoengine import Document, connect | ||||
| from mongoengine.connection import get_db | ||||
| from mongoengine.fields import StringField | ||||
|  | ||||
| __all__ = ('TurnOffInheritanceTest', ) | ||||
|  | ||||
|  | ||||
| class TurnOffInheritanceTest(unittest.TestCase): | ||||
|  | ||||
|     def setUp(self): | ||||
|         connect(db='mongoenginetest') | ||||
|         self.db = get_db() | ||||
|  | ||||
|     def tearDown(self): | ||||
|         for collection in self.db.collection_names(): | ||||
|             if 'system.' in collection: | ||||
|                 continue | ||||
|             self.db.drop_collection(collection) | ||||
|  | ||||
|     def test_how_to_turn_off_inheritance(self): | ||||
|         """Demonstrates migrating from allow_inheritance = True to False. | ||||
|         """ | ||||
|  | ||||
|         # 1. Old declaration of the class | ||||
|  | ||||
|         class Animal(Document): | ||||
|             name = StringField() | ||||
|             meta = { | ||||
|                 'allow_inheritance': True, | ||||
|                 'indexes': ['name'] | ||||
|             } | ||||
|  | ||||
|         # 2. Turn off inheritance | ||||
|         class Animal(Document): | ||||
|             name = StringField() | ||||
|             meta = { | ||||
|                 'allow_inheritance': False, | ||||
|                 'indexes': ['name'] | ||||
|             } | ||||
|  | ||||
|         # 3. Remove _types and _cls | ||||
|         collection = Animal._get_collection() | ||||
|         collection.update({}, {"$unset": {"_types": 1, "_cls": 1}}, multi=True) | ||||
|  | ||||
|         # 4. Confirm extra data is removed | ||||
|         count = collection.find({"$or": [{'_types': {"$exists": True}}, | ||||
|                                          {'_cls': {"$exists": True}}]}).count() | ||||
|         assert count == 0 | ||||
|  | ||||
|         # 5. Remove indexes | ||||
|         info = collection.index_information() | ||||
|         indexes_to_drop = [key for key, value in info.iteritems() | ||||
|                            if '_types' in dict(value['key']) | ||||
|                               or '_cls' in dict(value['key'])] | ||||
|         for index in indexes_to_drop: | ||||
|             collection.drop_index(index) | ||||
|  | ||||
|         # 6. Recreate indexes | ||||
|         Animal.ensure_indexes() | ||||
							
								
								
									
tests/migration/uuidfield_to_binary.py (new file, 48 lines)
									
								
							| @@ -0,0 +1,48 @@ | ||||
| # -*- coding: utf-8 -*- | ||||
| import unittest | ||||
| import uuid | ||||
|  | ||||
| from mongoengine import Document, connect | ||||
| from mongoengine.connection import get_db | ||||
| from mongoengine.fields import StringField, UUIDField, ListField | ||||
|  | ||||
| __all__ = ('ConvertToBinaryUUID', ) | ||||
|  | ||||
|  | ||||
| class ConvertToBinaryUUID(unittest.TestCase): | ||||
|  | ||||
|     def setUp(self): | ||||
|         connect(db='mongoenginetest') | ||||
|         self.db = get_db() | ||||
|  | ||||
|     def test_how_to_convert_to_binary_uuid_fields(self): | ||||
|         """Demonstrates migrating from 0.7 to 0.8 | ||||
|         """ | ||||
|  | ||||
|         # 1. Old definition - UUIDs stored as strings (binary=False) | ||||
|         class Person(Document): | ||||
|             name = StringField() | ||||
|             uuid = UUIDField(binary=False) | ||||
|             uuids = ListField(UUIDField(binary=False)) | ||||
|  | ||||
|         Person.drop_collection() | ||||
|         Person(name="Wilson Jr", uuid=uuid.uuid4(), | ||||
|                uuids=[uuid.uuid4(), uuid.uuid4()]).save() | ||||
|  | ||||
|         # 2. Start the migration by changing the schema | ||||
|         # Change UUIDField, as binary now defaults to True | ||||
|         class Person(Document): | ||||
|             name = StringField() | ||||
|             uuid = UUIDField() | ||||
|             uuids = ListField(UUIDField()) | ||||
|  | ||||
|         # 3. Loop over all the objects and mark the UUID fields as changed | ||||
|         for p in Person.objects: | ||||
|             p._mark_as_changed('uuid') | ||||
|             p._mark_as_changed('uuids') | ||||
|             p.save() | ||||
|  | ||||
|         # 4. Confirmation of the fix! | ||||
|         wilson = Person.objects(name="Wilson Jr").as_pymongo()[0] | ||||
|         self.assertTrue(isinstance(wilson['uuid'], uuid.UUID)) | ||||
|         self.assertTrue(all([isinstance(u, uuid.UUID) for u in wilson['uuids']])) | ||||
| @@ -1,3 +1,6 @@ | ||||
| import sys | ||||
| sys.path[0:0] = [""] | ||||
|  | ||||
| import unittest | ||||
|  | ||||
| from mongoengine import * | ||||
| @@ -92,7 +95,7 @@ class OnlyExcludeAllTest(unittest.TestCase): | ||||
|         exclude = ['d', 'e'] | ||||
|         only = ['b', 'c'] | ||||
|  | ||||
|         qs = MyDoc.objects.fields(**{i: 1 for i in include}) | ||||
|         qs = MyDoc.objects.fields(**dict(((i, 1) for i in include))) | ||||
|         self.assertEqual(qs._loaded_fields.as_dict(), | ||||
|                          {'a': 1, 'b': 1, 'c': 1, 'd': 1, 'e': 1}) | ||||
|         qs = qs.only(*only) | ||||
| @@ -100,14 +103,14 @@ class OnlyExcludeAllTest(unittest.TestCase): | ||||
|         qs = qs.exclude(*exclude) | ||||
|         self.assertEqual(qs._loaded_fields.as_dict(), {'b': 1, 'c': 1}) | ||||
|  | ||||
|         qs = MyDoc.objects.fields(**{i: 1 for i in include}) | ||||
|         qs = MyDoc.objects.fields(**dict(((i, 1) for i in include))) | ||||
|         qs = qs.exclude(*exclude) | ||||
|         self.assertEqual(qs._loaded_fields.as_dict(), {'a': 1, 'b': 1, 'c': 1}) | ||||
|         qs = qs.only(*only) | ||||
|         self.assertEqual(qs._loaded_fields.as_dict(), {'b': 1, 'c': 1}) | ||||
|  | ||||
|         qs = MyDoc.objects.exclude(*exclude) | ||||
|         qs = qs.fields(**{i: 1 for i in include}) | ||||
|         qs = qs.fields(**dict(((i, 1) for i in include))) | ||||
|         self.assertEqual(qs._loaded_fields.as_dict(), {'a': 1, 'b': 1, 'c': 1}) | ||||
|         qs = qs.only(*only) | ||||
|         self.assertEqual(qs._loaded_fields.as_dict(), {'b': 1, 'c': 1}) | ||||
| @@ -126,7 +129,7 @@ class OnlyExcludeAllTest(unittest.TestCase): | ||||
|         exclude = ['d', 'e'] | ||||
|         only = ['b', 'c'] | ||||
|  | ||||
|         qs = MyDoc.objects.fields(**{i: 1 for i in include}) | ||||
|         qs = MyDoc.objects.fields(**dict(((i, 1) for i in include))) | ||||
|         qs = qs.exclude(*exclude) | ||||
|         qs = qs.only(*only) | ||||
|         qs = qs.fields(slice__b=5) | ||||
| @@ -141,16 +144,6 @@ class OnlyExcludeAllTest(unittest.TestCase): | ||||
|         self.assertEqual(qs._loaded_fields.as_dict(), | ||||
|                          {'b': {'$slice': 5}}) | ||||
|  | ||||
|     def test_mix_slice_with_other_fields(self): | ||||
|         class MyDoc(Document): | ||||
|             a = ListField() | ||||
|             b = ListField() | ||||
|             c = ListField() | ||||
|  | ||||
|         qs = MyDoc.objects.fields(a=1, b=0, slice__c=2) | ||||
|         self.assertEqual(qs._loaded_fields.as_dict(), | ||||
|                          {'c': {'$slice': 2}, 'a': 1}) | ||||
|  | ||||
|     def test_only(self): | ||||
|         """Ensure that QuerySet.only only returns the requested fields. | ||||
|         """ | ||||
|   | ||||
| @@ -1,139 +1,109 @@ | ||||
| import datetime | ||||
| import sys | ||||
|  | ||||
| sys.path[0:0] = [""] | ||||
|  | ||||
| import unittest | ||||
| from datetime import datetime, timedelta | ||||
|  | ||||
| from pymongo.errors import OperationFailure | ||||
| from mongoengine import * | ||||
|  | ||||
| from tests.utils import MongoDBTestCase, needs_mongodb_v3 | ||||
| from mongoengine.connection import get_connection | ||||
| from nose.plugins.skip import SkipTest | ||||
|  | ||||
|  | ||||
| __all__ = ("GeoQueriesTest",) | ||||
|  | ||||
|  | ||||
| class GeoQueriesTest(MongoDBTestCase): | ||||
| class GeoQueriesTest(unittest.TestCase): | ||||
|  | ||||
|     def _create_event_data(self, point_field_class=GeoPointField): | ||||
|         """Create some sample data re-used in many of the tests below.""" | ||||
|     def setUp(self): | ||||
|         connect(db='mongoenginetest') | ||||
|  | ||||
|     def test_geospatial_operators(self): | ||||
|         """Ensure that geospatial queries are working. | ||||
|         """ | ||||
|         class Event(Document): | ||||
|             title = StringField() | ||||
|             date = DateTimeField() | ||||
|             location = point_field_class() | ||||
|             location = GeoPointField() | ||||
|  | ||||
|             def __unicode__(self): | ||||
|                 return self.title | ||||
|  | ||||
|         self.Event = Event | ||||
|  | ||||
|         Event.drop_collection() | ||||
|  | ||||
|         event1 = Event.objects.create( | ||||
|             title="Coltrane Motion @ Double Door", | ||||
|             date=datetime.datetime.now() - datetime.timedelta(days=1), | ||||
|             location=[-87.677137, 41.909889]) | ||||
|         event2 = Event.objects.create( | ||||
|             title="Coltrane Motion @ Bottom of the Hill", | ||||
|             date=datetime.datetime.now() - datetime.timedelta(days=10), | ||||
|             location=[-122.4194155, 37.7749295]) | ||||
|         event3 = Event.objects.create( | ||||
|             title="Coltrane Motion @ Empty Bottle", | ||||
|             date=datetime.datetime.now(), | ||||
|             location=[-87.686638, 41.900474]) | ||||
|  | ||||
|         return event1, event2, event3 | ||||
|  | ||||
|     def test_near(self): | ||||
|         """Make sure the "near" operator works.""" | ||||
|         event1, event2, event3 = self._create_event_data() | ||||
|         event1 = Event(title="Coltrane Motion @ Double Door", | ||||
|                        date=datetime.now() - timedelta(days=1), | ||||
|                        location=[-87.677137, 41.909889]).save() | ||||
|         event2 = Event(title="Coltrane Motion @ Bottom of the Hill", | ||||
|                        date=datetime.now() - timedelta(days=10), | ||||
|                        location=[-122.4194155, 37.7749295]).save() | ||||
|         event3 = Event(title="Coltrane Motion @ Empty Bottle", | ||||
|                        date=datetime.now(), | ||||
|                        location=[-87.686638, 41.900474]).save() | ||||
|  | ||||
|         # find all events "near" pitchfork office, chicago. | ||||
|         # note that "near" will show the san francisco event, too, | ||||
|         # although it sorts to last. | ||||
|         events = self.Event.objects(location__near=[-87.67892, 41.9120459]) | ||||
|         events = Event.objects(location__near=[-87.67892, 41.9120459]) | ||||
|         self.assertEqual(events.count(), 3) | ||||
|         self.assertEqual(list(events), [event1, event3, event2]) | ||||
|  | ||||
|         # ensure ordering is respected by "near" | ||||
|         events = self.Event.objects(location__near=[-87.67892, 41.9120459]) | ||||
|         events = events.order_by("-date") | ||||
|         self.assertEqual(events.count(), 3) | ||||
|         self.assertEqual(list(events), [event3, event1, event2]) | ||||
|  | ||||
|     def test_near_and_max_distance(self): | ||||
|         """Ensure the "max_distance" operator works alongside the "near" | ||||
|         operator. | ||||
|         """ | ||||
|         event1, event2, event3 = self._create_event_data() | ||||
|  | ||||
|         # find events within 10 degrees of san francisco | ||||
|         point = [-122.415579, 37.7566023] | ||||
|         events = self.Event.objects(location__near=point, | ||||
|                                     location__max_distance=10) | ||||
|         self.assertEqual(events.count(), 1) | ||||
|         self.assertEqual(events[0], event2) | ||||
|  | ||||
|     # $minDistance was added in MongoDB v2.6, but continued being buggy | ||||
|     # until v3.0; skip for older versions | ||||
|     @needs_mongodb_v3 | ||||
|     def test_near_and_min_distance(self): | ||||
|         """Ensure the "min_distance" operator works alongside the "near" | ||||
|         operator. | ||||
|         """ | ||||
|         event1, event2, event3 = self._create_event_data() | ||||
|  | ||||
|         # find events at least 10 degrees away from san francisco | ||
|         point = [-122.415579, 37.7566023] | ||||
|         events = self.Event.objects(location__near=point, | ||||
|                                     location__min_distance=10) | ||||
|         self.assertEqual(events.count(), 2) | ||||
|  | ||||
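|     # For reference, a minimal sketch of what the needs_mongodb_v3 decorator | ||
|     # imported from tests.utils might look like (an assumption -- its real | ||
|     # implementation is not shown in this diff); it mirrors the version-check | ||
|     # pattern of the skip_older_mongodb helper that appears later in this diff: | ||
|     # | ||
|     #     def needs_mongodb_v3(f): | ||
|     #         def _inner(*args, **kwargs): | ||
|     #             info = get_connection().test.command('buildInfo') | ||
|     #             version = tuple(int(i) for i in info['version'].split('.')) | ||
|     #             if version < (3, 0): | ||
|     #                 raise SkipTest("Need MongoDB version 3.0+") | ||
|     #             return f(*args, **kwargs) | ||
|     #         _inner.__name__ = f.__name__ | ||
|     #         _inner.__doc__ = f.__doc__ | ||
|     #         return _inner | ||
|  | ||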
|     def test_within_distance(self): | ||||
|         """Make sure the "within_distance" operator works.""" | ||||
|         event1, event2, event3 = self._create_event_data() | ||||
|  | ||||
|         # find events within 5 degrees of pitchfork office, chicago | ||||
|         point_and_distance = [[-87.67892, 41.9120459], 5] | ||||
|         events = self.Event.objects( | ||||
|             location__within_distance=point_and_distance) | ||||
|         events = Event.objects(location__within_distance=point_and_distance) | ||||
|         self.assertEqual(events.count(), 2) | ||||
|         events = list(events) | ||||
|         self.assertTrue(event2 not in events) | ||||
|         self.assertTrue(event1 in events) | ||||
|         self.assertTrue(event3 in events) | ||||
|  | ||||
|         # ensure ordering is respected by "near" | ||||
|         events = Event.objects(location__near=[-87.67892, 41.9120459]) | ||||
|         events = events.order_by("-date") | ||||
|         self.assertEqual(events.count(), 3) | ||||
|         self.assertEqual(list(events), [event3, event1, event2]) | ||||
|  | ||||
|         # find events within 10 degrees of san francisco | ||||
|         point = [-122.415579, 37.7566023] | ||||
|         events = Event.objects(location__near=point, location__max_distance=10) | ||||
|         self.assertEqual(events.count(), 1) | ||||
|         self.assertEqual(events[0], event2) | ||||
|  | ||||
|         # find events at least 10 degrees away from san francisco | ||
|         point = [-122.415579, 37.7566023] | ||||
|         events = Event.objects(location__near=point, location__min_distance=10) | ||||
|         # The following assertion passes on MongoDB 3+, but minDistance | ||
|         # appears to be buggy on older MongoDB versions | ||
|         if get_connection().server_info()['versionArray'][0] > 2: | ||||
|             self.assertEqual(events.count(), 2) | ||||
|         else: | ||||
|             self.assertTrue(events.count() >= 2) | ||||
|  | ||||
|         # find events within 10 degrees of san francisco | ||||
|         point_and_distance = [[-122.415579, 37.7566023], 10] | ||||
|         events = self.Event.objects( | ||||
|             location__within_distance=point_and_distance) | ||||
|         events = Event.objects(location__within_distance=point_and_distance) | ||||
|         self.assertEqual(events.count(), 1) | ||||
|         self.assertEqual(events[0], event2) | ||||
|  | ||||
|         # find events within 1 degree of greenpoint, brooklyn, nyc, ny | ||
|         point_and_distance = [[-73.9509714, 40.7237134], 1] | ||||
|         events = self.Event.objects( | ||||
|             location__within_distance=point_and_distance) | ||||
|         events = Event.objects(location__within_distance=point_and_distance) | ||||
|         self.assertEqual(events.count(), 0) | ||||
|  | ||||
|         # ensure ordering is respected by "within_distance" | ||||
|         point_and_distance = [[-87.67892, 41.9120459], 10] | ||||
|         events = self.Event.objects( | ||||
|             location__within_distance=point_and_distance) | ||||
|         events = Event.objects(location__within_distance=point_and_distance) | ||||
|         events = events.order_by("-date") | ||||
|         self.assertEqual(events.count(), 2) | ||||
|         self.assertEqual(events[0], event3) | ||||
|  | ||||
|     def test_within_box(self): | ||||
|         """Ensure the "within_box" operator works.""" | ||||
|         event1, event2, event3 = self._create_event_data() | ||||
|  | ||||
|         # check that within_box works | ||||
|         box = [(-125.0, 35.0), (-100.0, 40.0)] | ||||
|         events = self.Event.objects(location__within_box=box) | ||||
|         events = Event.objects(location__within_box=box) | ||||
|         self.assertEqual(events.count(), 1) | ||||
|         self.assertEqual(events[0].id, event2.id) | ||||
|  | ||||
|     def test_within_polygon(self): | ||||
|         """Ensure the "within_polygon" operator works.""" | ||||
|         event1, event2, event3 = self._create_event_data() | ||||
|  | ||||
|         polygon = [ | ||||
|             (-87.694445, 41.912114), | ||||
|             (-87.69084, 41.919395), | ||||
| @@ -141,7 +111,7 @@ class GeoQueriesTest(MongoDBTestCase): | ||||
|             (-87.654276, 41.911731), | ||||
|             (-87.656164, 41.898061), | ||||
|         ] | ||||
|         events = self.Event.objects(location__within_polygon=polygon) | ||||
|         events = Event.objects(location__within_polygon=polygon) | ||||
|         self.assertEqual(events.count(), 1) | ||||
|         self.assertEqual(events[0].id, event1.id) | ||||
|  | ||||
| @@ -150,151 +120,13 @@ class GeoQueriesTest(MongoDBTestCase): | ||||
|             (-1.225891, 52.792797), | ||||
|             (-4.40094, 53.389881) | ||||
|         ] | ||||
|         events = self.Event.objects(location__within_polygon=polygon2) | ||||
|         events = Event.objects(location__within_polygon=polygon2) | ||||
|         self.assertEqual(events.count(), 0) | ||||
|  | ||||
|     def test_2dsphere_near(self): | ||||
|         """Make sure the "near" operator works with a PointField, which | ||||
|         corresponds to a 2dsphere index. | ||||
|         """ | ||||
|         event1, event2, event3 = self._create_event_data( | ||||
|             point_field_class=PointField | ||||
|         ) | ||||
|     def test_geo_spatial_embedded(self): | ||||
|  | ||||
|         # find all events "near" pitchfork office, chicago. | ||||
|         # note that "near" will show the san francisco event, too, | ||||
|         # although it sorts to last. | ||||
|         events = self.Event.objects(location__near=[-87.67892, 41.9120459]) | ||||
|         self.assertEqual(events.count(), 3) | ||||
|         self.assertEqual(list(events), [event1, event3, event2]) | ||||
|  | ||||
|         # ensure ordering is respected by "near" | ||||
|         events = self.Event.objects(location__near=[-87.67892, 41.9120459]) | ||||
|         events = events.order_by("-date") | ||||
|         self.assertEqual(events.count(), 3) | ||||
|         self.assertEqual(list(events), [event3, event1, event2]) | ||||
|  | ||||
|     def test_2dsphere_near_and_max_distance(self): | ||||
|         """Ensure the "max_distance" operator works alongside the "near" | ||||
|         operator with a 2dsphere index. | ||||
|         """ | ||||
|         event1, event2, event3 = self._create_event_data( | ||||
|             point_field_class=PointField | ||||
|         ) | ||||
|  | ||||
|         # find events within 10km of san francisco | ||||
|         point = [-122.415579, 37.7566023] | ||||
|         events = self.Event.objects(location__near=point, | ||||
|                                     location__max_distance=10000) | ||||
|         self.assertEqual(events.count(), 1) | ||||
|         self.assertEqual(events[0], event2) | ||||
|  | ||||
|         # find events within 1km of greenpoint, brooklyn, nyc, ny | ||
|         events = self.Event.objects(location__near=[-73.9509714, 40.7237134], | ||||
|                                     location__max_distance=1000) | ||||
|         self.assertEqual(events.count(), 0) | ||||
|  | ||||
|         # ensure ordering is respected by "near" | ||||
|         events = self.Event.objects( | ||||
|             location__near=[-87.67892, 41.9120459], | ||||
|             location__max_distance=10000 | ||||
|         ).order_by("-date") | ||||
|         self.assertEqual(events.count(), 2) | ||||
|         self.assertEqual(events[0], event3) | ||||
|  | ||||
|     def test_2dsphere_geo_within_box(self): | ||||
|         """Ensure the "geo_within_box" operator works with a 2dsphere | ||||
|         index. | ||||
|         """ | ||||
|         event1, event2, event3 = self._create_event_data( | ||||
|             point_field_class=PointField | ||||
|         ) | ||||
|  | ||||
|         # check that within_box works | ||||
|         box = [(-125.0, 35.0), (-100.0, 40.0)] | ||||
|         events = self.Event.objects(location__geo_within_box=box) | ||||
|         self.assertEqual(events.count(), 1) | ||||
|         self.assertEqual(events[0].id, event2.id) | ||||
|  | ||||
|     def test_2dsphere_geo_within_polygon(self): | ||||
|         """Ensure the "geo_within_polygon" operator works with a | ||||
|         2dsphere index. | ||||
|         """ | ||||
|         event1, event2, event3 = self._create_event_data( | ||||
|             point_field_class=PointField | ||||
|         ) | ||||
|  | ||||
|         polygon = [ | ||||
|             (-87.694445, 41.912114), | ||||
|             (-87.69084, 41.919395), | ||||
|             (-87.681742, 41.927186), | ||||
|             (-87.654276, 41.911731), | ||||
|             (-87.656164, 41.898061), | ||||
|         ] | ||||
|         events = self.Event.objects(location__geo_within_polygon=polygon) | ||||
|         self.assertEqual(events.count(), 1) | ||||
|         self.assertEqual(events[0].id, event1.id) | ||||
|  | ||||
|         polygon2 = [ | ||||
|             (-1.742249, 54.033586), | ||||
|             (-1.225891, 52.792797), | ||||
|             (-4.40094, 53.389881) | ||||
|         ] | ||||
|         events = self.Event.objects(location__geo_within_polygon=polygon2) | ||||
|         self.assertEqual(events.count(), 0) | ||||
|  | ||||
|     # $minDistance was added in MongoDB v2.6, but continued being buggy | ||||
|     # until v3.0; skip for older versions | ||||
|     @needs_mongodb_v3 | ||||
|     def test_2dsphere_near_and_min_max_distance(self): | ||||
|         """Ensure "min_distace" and "max_distance" operators work well | ||||
|         together with the "near" operator in a 2dsphere index. | ||||
|         """ | ||||
|         event1, event2, event3 = self._create_event_data( | ||||
|             point_field_class=PointField | ||||
|         ) | ||||
|  | ||||
|         # ensure min_distance and max_distance combine well | ||||
|         events = self.Event.objects( | ||||
|             location__near=[-87.67892, 41.9120459], | ||||
|             location__min_distance=1000, | ||||
|             location__max_distance=10000 | ||||
|         ).order_by("-date") | ||||
|         self.assertEqual(events.count(), 1) | ||||
|         self.assertEqual(events[0], event3) | ||||
|  | ||||
|         # ensure ordering is respected by "near" with "min_distance" | ||||
|         events = self.Event.objects( | ||||
|             location__near=[-87.67892, 41.9120459], | ||||
|             location__min_distance=10000 | ||||
|         ).order_by("-date") | ||||
|         self.assertEqual(events.count(), 1) | ||||
|         self.assertEqual(events[0], event2) | ||||
|  | ||||
|     def test_2dsphere_geo_within_center(self): | ||||
|         """Make sure the "geo_within_center" operator works with a | ||||
|         2dsphere index. | ||||
|         """ | ||||
|         event1, event2, event3 = self._create_event_data( | ||||
|             point_field_class=PointField | ||||
|         ) | ||||
|  | ||||
|         # find events within 2 degrees of pitchfork office, chicago | ||
|         point_and_distance = [[-87.67892, 41.9120459], 2] | ||||
|         events = self.Event.objects( | ||||
|             location__geo_within_center=point_and_distance) | ||||
|         self.assertEqual(events.count(), 2) | ||||
|         events = list(events) | ||||
|         self.assertTrue(event2 not in events) | ||||
|         self.assertTrue(event1 in events) | ||||
|         self.assertTrue(event3 in events) | ||||
|  | ||||
|     def _test_embedded(self, point_field_class): | ||||
|         """Helper test method ensuring given point field class works | ||||
|         well in an embedded document. | ||||
|         """ | ||||
|         class Venue(EmbeddedDocument): | ||||
|             location = point_field_class() | ||||
|             location = GeoPointField() | ||||
|             name = StringField() | ||||
|  | ||||
|         class Event(Document): | ||||
| @@ -320,18 +152,16 @@ class GeoQueriesTest(MongoDBTestCase): | ||||
|         self.assertEqual(events.count(), 3) | ||||
|         self.assertEqual(list(events), [event1, event3, event2]) | ||||
|  | ||||
|     def test_geo_spatial_embedded(self): | ||||
|         """Make sure GeoPointField works properly in an embedded document.""" | ||||
|         self._test_embedded(point_field_class=GeoPointField) | ||||
|  | ||||
|     def test_2dsphere_point_embedded(self): | ||||
|         """Make sure PointField works properly in an embedded document.""" | ||||
|         self._test_embedded(point_field_class=PointField) | ||||
|  | ||||
|     # Needs MongoDB > 2.6.4 https://jira.mongodb.org/browse/SERVER-14039 | ||||
|     @needs_mongodb_v3 | ||||
|     def test_spherical_geospatial_operators(self): | ||||
|         """Ensure that spherical geospatial queries are working.""" | ||||
|         """Ensure that spherical geospatial queries are working | ||||
|         """ | ||||
|         # Needs MongoDB > 2.6.4 https://jira.mongodb.org/browse/SERVER-14039 | ||||
|         connection = get_connection() | ||||
|         info = connection.test.command('buildInfo') | ||||
|         mongodb_version = tuple([int(i) for i in info['version'].split('.')]) | ||||
|         if mongodb_version < (2, 6, 4): | ||||
|             raise SkipTest("Need MongoDB version 2.6.4+") | ||||
|  | ||||
|         class Point(Document): | ||||
|             location = GeoPointField() | ||||
|  | ||||
| @@ -351,10 +181,7 @@ class GeoQueriesTest(MongoDBTestCase): | ||||
|  | ||||
|         # Same behavior for _within_spherical_distance | ||||
|         points = Point.objects( | ||||
|             location__within_spherical_distance=[ | ||||
|                 [-122, 37.5], | ||||
|                 60 / earth_radius | ||||
|             ] | ||||
|             location__within_spherical_distance=[[-122, 37.5], 60 / earth_radius] | ||||
|         ) | ||||
|         self.assertEqual(points.count(), 2) | ||||
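|  | ||
|         # The 60 / earth_radius expressions above convert a 60 km radius into | ||
|         # radians, the unit $centerSphere and $nearSphere expect; a worked | ||
|         # example, assuming earth_radius is defined earlier as roughly 6371 km: | ||
|         # | ||
|         #     earth_radius = 6371.0              # km | ||
|         #     radius = 60 / earth_radius         # ~0.0094 radians | ||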
|  | ||||
| @@ -371,9 +198,14 @@ class GeoQueriesTest(MongoDBTestCase): | ||||
|         # Test query works with min_distance, being farther from one point | ||
|         points = Point.objects(location__near_sphere=[-122, 37.8], | ||||
|                                location__min_distance=60 / earth_radius) | ||||
|         self.assertEqual(points.count(), 1) | ||||
|         far_point = points.first() | ||||
|         self.assertNotEqual(close_point, far_point) | ||||
|         # The following assertion passes on MongoDB 3+, but minDistance | ||
|         # appears to be buggy on older MongoDB versions | ||
|         if get_connection().server_info()['versionArray'][0] > 2: | ||||
|             self.assertEqual(points.count(), 1) | ||||
|             far_point = points.first() | ||||
|             self.assertNotEqual(close_point, far_point) | ||||
|         else: | ||||
|             self.assertTrue(points.count() >= 1) | ||||
|  | ||||
|         # Finds both points, but orders the north point first because it's | ||||
|         # closer to the reference point to the north. | ||||
| @@ -392,15 +224,141 @@ class GeoQueriesTest(MongoDBTestCase): | ||||
|         # Finds only one point because only the first point is within 60km of | ||||
|         # the reference point to the south. | ||||
|         points = Point.objects( | ||||
|             location__within_spherical_distance=[ | ||||
|                 [-122, 36.5], | ||||
|                 60 / earth_radius | ||||
|             ] | ||||
|         ) | ||||
|             location__within_spherical_distance=[[-122, 36.5], 60 / earth_radius]) | ||
|         self.assertEqual(points.count(), 1) | ||||
|         self.assertEqual(points[0].id, south_point.id) | ||||
|  | ||||
|     def test_2dsphere_point(self): | ||||
|  | ||||
|         class Event(Document): | ||||
|             title = StringField() | ||||
|             date = DateTimeField() | ||||
|             location = PointField() | ||||
|  | ||||
|             def __unicode__(self): | ||||
|                 return self.title | ||||
|  | ||||
|         Event.drop_collection() | ||||
|  | ||||
|         event1 = Event(title="Coltrane Motion @ Double Door", | ||||
|                        date=datetime.now() - timedelta(days=1), | ||||
|                        location=[-87.677137, 41.909889]) | ||||
|         event1.save() | ||||
|         event2 = Event(title="Coltrane Motion @ Bottom of the Hill", | ||||
|                        date=datetime.now() - timedelta(days=10), | ||||
|                        location=[-122.4194155, 37.7749295]).save() | ||||
|         event3 = Event(title="Coltrane Motion @ Empty Bottle", | ||||
|                        date=datetime.now(), | ||||
|                        location=[-87.686638, 41.900474]).save() | ||||
|  | ||||
|         # find all events "near" pitchfork office, chicago. | ||||
|         # note that "near" will show the san francisco event, too, | ||||
|         # although it sorts to last. | ||||
|         events = Event.objects(location__near=[-87.67892, 41.9120459]) | ||||
|         self.assertEqual(events.count(), 3) | ||||
|         self.assertEqual(list(events), [event1, event3, event2]) | ||||
|  | ||||
|         # find events within 2 degrees of pitchfork office, chicago | ||
|         point_and_distance = [[-87.67892, 41.9120459], 2] | ||||
|         events = Event.objects(location__geo_within_center=point_and_distance) | ||||
|         self.assertEqual(events.count(), 2) | ||||
|         events = list(events) | ||||
|         self.assertTrue(event2 not in events) | ||||
|         self.assertTrue(event1 in events) | ||||
|         self.assertTrue(event3 in events) | ||||
|  | ||||
|         # ensure ordering is respected by "near" | ||||
|         events = Event.objects(location__near=[-87.67892, 41.9120459]) | ||||
|         events = events.order_by("-date") | ||||
|         self.assertEqual(events.count(), 3) | ||||
|         self.assertEqual(list(events), [event3, event1, event2]) | ||||
|  | ||||
|         # find events within 10km of san francisco | ||||
|         point = [-122.415579, 37.7566023] | ||||
|         events = Event.objects(location__near=point, location__max_distance=10000) | ||||
|         self.assertEqual(events.count(), 1) | ||||
|         self.assertEqual(events[0], event2) | ||||
|  | ||||
|         # find events within 1km of greenpoint, brooklyn, nyc, ny | ||
|         events = Event.objects(location__near=[-73.9509714, 40.7237134], location__max_distance=1000) | ||||
|         self.assertEqual(events.count(), 0) | ||||
|  | ||||
|         # ensure ordering is respected by "near" | ||||
|         events = Event.objects(location__near=[-87.67892, 41.9120459], | ||||
|                                location__max_distance=10000).order_by("-date") | ||||
|         self.assertEqual(events.count(), 2) | ||||
|         self.assertEqual(events[0], event3) | ||||
|  | ||||
|         # ensure min_distance and max_distance combine well | ||||
|         events = Event.objects(location__near=[-87.67892, 41.9120459], | ||||
|                                location__min_distance=1000, | ||||
|                                location__max_distance=10000).order_by("-date") | ||||
|         self.assertEqual(events.count(), 1) | ||||
|         self.assertEqual(events[0], event3) | ||||
|  | ||||
|         # ensure ordering is respected by "near" with "min_distance" | ||
|         events = Event.objects(location__near=[-87.67892, 41.9120459], | ||
|                                location__min_distance=10000).order_by("-date") | ||
|         self.assertEqual(events.count(), 1) | ||||
|         self.assertEqual(events[0], event2) | ||||
|  | ||||
|         # check that within_box works | ||||
|         box = [(-125.0, 35.0), (-100.0, 40.0)] | ||||
|         events = Event.objects(location__geo_within_box=box) | ||||
|         self.assertEqual(events.count(), 1) | ||||
|         self.assertEqual(events[0].id, event2.id) | ||||
|  | ||||
|         polygon = [ | ||||
|             (-87.694445, 41.912114), | ||||
|             (-87.69084, 41.919395), | ||||
|             (-87.681742, 41.927186), | ||||
|             (-87.654276, 41.911731), | ||||
|             (-87.656164, 41.898061), | ||||
|         ] | ||||
|         events = Event.objects(location__geo_within_polygon=polygon) | ||||
|         self.assertEqual(events.count(), 1) | ||||
|         self.assertEqual(events[0].id, event1.id) | ||||
|  | ||||
|         polygon2 = [ | ||||
|             (-1.742249, 54.033586), | ||||
|             (-1.225891, 52.792797), | ||||
|             (-4.40094, 53.389881) | ||||
|         ] | ||||
|         events = Event.objects(location__geo_within_polygon=polygon2) | ||||
|         self.assertEqual(events.count(), 0) | ||||
|  | ||||
|     def test_2dsphere_point_embedded(self): | ||||
|  | ||||
|         class Venue(EmbeddedDocument): | ||||
|             location = GeoPointField() | ||||
|             name = StringField() | ||||
|  | ||||
|         class Event(Document): | ||||
|             title = StringField() | ||||
|             venue = EmbeddedDocumentField(Venue) | ||||
|  | ||||
|         Event.drop_collection() | ||||
|  | ||||
|         venue1 = Venue(name="The Rock", location=[-87.677137, 41.909889]) | ||||
|         venue2 = Venue(name="The Bridge", location=[-122.4194155, 37.7749295]) | ||||
|  | ||||
|         event1 = Event(title="Coltrane Motion @ Double Door", | ||||
|                        venue=venue1).save() | ||||
|         event2 = Event(title="Coltrane Motion @ Bottom of the Hill", | ||||
|                        venue=venue2).save() | ||||
|         event3 = Event(title="Coltrane Motion @ Empty Bottle", | ||||
|                        venue=venue1).save() | ||||
|  | ||||
|         # find all events "near" pitchfork office, chicago. | ||||
|         # note that "near" will show the san francisco event, too, | ||||
|         # although it sorts to last. | ||||
|         events = Event.objects(venue__location__near=[-87.67892, 41.9120459]) | ||||
|         self.assertEqual(events.count(), 3) | ||||
|         self.assertEqual(list(events), [event1, event3, event2]) | ||||
|  | ||||
|     def test_linestring(self): | ||||
|  | ||||
|         class Road(Document): | ||||
|             name = StringField() | ||||
|             line = LineStringField() | ||||
| @@ -456,6 +414,7 @@ class GeoQueriesTest(MongoDBTestCase): | ||||
|         self.assertEqual(1, roads) | ||||
|  | ||||
|     def test_polygon(self): | ||||
|  | ||||
|         class Road(Document): | ||||
|             name = StringField() | ||||
|             poly = PolygonField() | ||||
| @@ -552,6 +511,5 @@ class GeoQueriesTest(MongoDBTestCase): | ||||
|         loc = Location.objects.as_pymongo()[0] | ||||
|         self.assertEqual(loc["poly"], {"type": "Polygon", "coordinates": [[[40, 4], [40, 6], [41, 6], [40, 4]]]}) | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|     unittest.main() | ||||
|   | ||||
| @@ -1,3 +1,6 @@ | ||||
| import sys | ||||
| sys.path[0:0] = [""] | ||||
|  | ||||
| import unittest | ||||
|  | ||||
| from mongoengine import connect, Document, IntField | ||||
|   | ||||
| @@ -9,32 +9,57 @@ from nose.plugins.skip import SkipTest | ||||
| import pymongo | ||||
| from pymongo.errors import ConfigurationError | ||||
| from pymongo.read_preferences import ReadPreference | ||||
| import six | ||||
|  | ||||
|  | ||||
| from mongoengine import * | ||||
| from mongoengine.connection import get_connection, get_db | ||||
| from mongoengine.context_managers import query_counter, switch_db | ||||
| from mongoengine.errors import InvalidQueryError | ||||
| from mongoengine.python_support import IS_PYMONGO_3 | ||||
| from mongoengine.python_support import IS_PYMONGO_3, PY3 | ||||
| from mongoengine.queryset import (DoesNotExist, MultipleObjectsReturned, | ||||
|                                   QuerySet, QuerySetManager, queryset_manager) | ||||
|  | ||||
| from tests.utils import needs_mongodb_v26, skip_pymongo3 | ||||
|  | ||||
|  | ||||
| __all__ = ("QuerySetTest",) | ||||
|  | ||||
|  | ||||
| class db_ops_tracker(query_counter): | ||||
|  | ||||
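|     # Relies on query_counter (the base class) switching the database into | ||
|     # profiling mode, so executed operations can be read back from the | ||
|     # system.profile collection queried below. | ||
|  | ||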
|     def get_ops(self): | ||||
|         ignore_query = { | ||||
|             'ns': {'$ne': '%s.system.indexes' % self.db.name}, | ||||
|             'command.count': {'$ne': 'system.profile'} | ||||
|         } | ||||
|         ignore_query = {"ns": {"$ne": "%s.system.indexes" % self.db.name}} | ||||
|         return list(self.db.system.profile.find(ignore_query)) | ||||
|  | ||||
|  | ||||
| def skip_older_mongodb(f): | ||||
|     def _inner(*args, **kwargs): | ||||
|         connection = get_connection() | ||||
|         info = connection.test.command('buildInfo') | ||||
|         mongodb_version = tuple([int(i) for i in info['version'].split('.')]) | ||||
|  | ||||
|         if mongodb_version < (2, 6): | ||||
|             raise SkipTest("Need MongoDB version 2.6+") | ||||
|  | ||||
|         return f(*args, **kwargs) | ||||
|  | ||||
|     _inner.__name__ = f.__name__ | ||||
|     _inner.__doc__ = f.__doc__ | ||||
|  | ||||
|     return _inner | ||||
|  | ||||
|  | ||||
| def skip_pymongo3(f): | ||||
|     def _inner(*args, **kwargs): | ||||
|  | ||||
|         if IS_PYMONGO_3: | ||||
|             raise SkipTest("Useless with PyMongo 3+") | ||||
|  | ||||
|         return f(*args, **kwargs) | ||||
|  | ||||
|     _inner.__name__ = f.__name__ | ||||
|     _inner.__doc__ = f.__doc__ | ||||
|  | ||||
|     return _inner | ||||
|  | ||||
|  | ||||
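| # A sketch of a more idiomatic variant of the two skip decorators above, | ||
| # using functools.wraps instead of copying __name__/__doc__ by hand; this | ||
| # helper is an illustration only, not part of the original test suite: | ||
| import functools | ||
|  | ||
|  | ||
| def skip_older_mongodb_wraps(f): | ||
|     @functools.wraps(f) | ||
|     def _inner(*args, **kwargs): | ||
|         info = get_connection().test.command('buildInfo') | ||
|         mongodb_version = tuple(int(i) for i in info['version'].split('.')) | ||
|         if mongodb_version < (2, 6): | ||
|             raise SkipTest("Need MongoDB version 2.6+") | ||
|         return f(*args, **kwargs) | ||
|     return _inner | ||
|  | ||
|  | ||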
| class QuerySetTest(unittest.TestCase): | ||||
|  | ||||
|     def setUp(self): | ||||
| @@ -69,120 +94,67 @@ class QuerySetTest(unittest.TestCase): | ||||
|             author = ReferenceField(self.Person) | ||||
|             author2 = GenericReferenceField() | ||||
|  | ||||
|         # test addressing a field from a reference | ||||
|         with self.assertRaises(InvalidQueryError): | ||||
|         def test_reference(): | ||||
|             list(BlogPost.objects(author__name="test")) | ||||
|  | ||||
|         # should fail for a generic reference as well | ||||
|         with self.assertRaises(InvalidQueryError): | ||||
|         self.assertRaises(InvalidQueryError, test_reference) | ||||
|  | ||||
|         def test_generic_reference(): | ||||
|             list(BlogPost.objects(author2__name="test")) | ||||
|  | ||||
|     def test_find(self): | ||||
|         """Ensure that a query returns a valid set of results.""" | ||||
|         user_a = self.Person.objects.create(name='User A', age=20) | ||||
|         user_b = self.Person.objects.create(name='User B', age=30) | ||||
|         """Ensure that a query returns a valid set of results. | ||||
|         """ | ||||
|         self.Person(name="User A", age=20).save() | ||||
|         self.Person(name="User B", age=30).save() | ||||
|  | ||||
|         # Find all people in the collection | ||||
|         people = self.Person.objects | ||||
|         self.assertEqual(people.count(), 2) | ||||
|         results = list(people) | ||||
|  | ||||
|         self.assertTrue(isinstance(results[0], self.Person)) | ||||
|         self.assertTrue(isinstance(results[0].id, (ObjectId, str, unicode))) | ||||
|  | ||||
|         self.assertEqual(results[0], user_a) | ||||
|         self.assertEqual(results[0].name, 'User A') | ||||
|         self.assertEqual(results[0].name, "User A") | ||||
|         self.assertEqual(results[0].age, 20) | ||||
|  | ||||
|         self.assertEqual(results[1], user_b) | ||||
|         self.assertEqual(results[1].name, 'User B') | ||||
|         self.assertEqual(results[1].name, "User B") | ||||
|         self.assertEqual(results[1].age, 30) | ||||
|  | ||||
|         # Filter people by age | ||||
|         # Use a query to filter the people found to just person1 | ||||
|         people = self.Person.objects(age=20) | ||||
|         self.assertEqual(people.count(), 1) | ||||
|         person = people.next() | ||||
|         self.assertEqual(person, user_a) | ||||
|         self.assertEqual(person.name, "User A") | ||||
|         self.assertEqual(person.age, 20) | ||||
|  | ||||
|     def test_limit(self): | ||||
|         """Ensure that QuerySet.limit works as expected.""" | ||||
|         user_a = self.Person.objects.create(name='User A', age=20) | ||||
|         user_b = self.Person.objects.create(name='User B', age=30) | ||||
|  | ||||
|         # Test limit on a new queryset | ||||
|         # Test limit | ||||
|         people = list(self.Person.objects.limit(1)) | ||||
|         self.assertEqual(len(people), 1) | ||||
|         self.assertEqual(people[0], user_a) | ||||
|         self.assertEqual(people[0].name, 'User A') | ||||
|  | ||||
|         # Test limit on an existing queryset | ||||
|         people = self.Person.objects | ||||
|         self.assertEqual(len(people), 2) | ||||
|         people2 = people.limit(1) | ||||
|         self.assertEqual(len(people), 2) | ||||
|         self.assertEqual(len(people2), 1) | ||||
|         self.assertEqual(people2[0], user_a) | ||||
|  | ||||
|         # Test chaining of only after limit | ||||
|         person = self.Person.objects().limit(1).only('name').first() | ||||
|         self.assertEqual(person, user_a) | ||||
|         self.assertEqual(person.name, 'User A') | ||||
|         self.assertEqual(person.age, None) | ||||
|  | ||||
|     def test_skip(self): | ||||
|         """Ensure that QuerySet.skip works as expected.""" | ||||
|         user_a = self.Person.objects.create(name='User A', age=20) | ||||
|         user_b = self.Person.objects.create(name='User B', age=30) | ||||
|  | ||||
|         # Test skip on a new queryset | ||||
|         # Test skip | ||||
|         people = list(self.Person.objects.skip(1)) | ||||
|         self.assertEqual(len(people), 1) | ||||
|         self.assertEqual(people[0], user_b) | ||||
|         self.assertEqual(people[0].name, 'User B') | ||||
|  | ||||
|         # Test skip on an existing queryset | ||||
|         people = self.Person.objects | ||||
|         self.assertEqual(len(people), 2) | ||||
|         people2 = people.skip(1) | ||||
|         self.assertEqual(len(people), 2) | ||||
|         self.assertEqual(len(people2), 1) | ||||
|         self.assertEqual(people2[0], user_b) | ||||
|  | ||||
|         # Test chaining of only after skip | ||||
|         person = self.Person.objects().skip(1).only('name').first() | ||||
|         self.assertEqual(person, user_b) | ||||
|         self.assertEqual(person.name, 'User B') | ||||
|         self.assertEqual(person.age, None) | ||||
|  | ||||
|     def test_slice(self): | ||||
|         """Ensure slicing a queryset works as expected.""" | ||||
|         user_a = self.Person.objects.create(name='User A', age=20) | ||||
|         user_b = self.Person.objects.create(name='User B', age=30) | ||||
|         user_c = self.Person.objects.create(name="User C", age=40) | ||||
|         person3 = self.Person(name="User C", age=40) | ||||
|         person3.save() | ||||
|  | ||||
|         # Test slice limit | ||||
|         people = list(self.Person.objects[:2]) | ||||
|         self.assertEqual(len(people), 2) | ||||
|         self.assertEqual(people[0], user_a) | ||||
|         self.assertEqual(people[1], user_b) | ||||
|         self.assertEqual(people[0].name, 'User A') | ||||
|         self.assertEqual(people[1].name, 'User B') | ||||
|  | ||||
|         # Test slice skip | ||||
|         people = list(self.Person.objects[1:]) | ||||
|         self.assertEqual(len(people), 2) | ||||
|         self.assertEqual(people[0], user_b) | ||||
|         self.assertEqual(people[1], user_c) | ||||
|         self.assertEqual(people[0].name, 'User B') | ||||
|         self.assertEqual(people[1].name, 'User C') | ||||
|  | ||||
|         # Test slice limit and skip | ||||
|         people = list(self.Person.objects[1:2]) | ||||
|         self.assertEqual(len(people), 1) | ||||
|         self.assertEqual(people[0], user_b) | ||||
|  | ||||
|         # Test slice limit and skip on an existing queryset | ||||
|         people = self.Person.objects | ||||
|         self.assertEqual(len(people), 3) | ||||
|         people2 = people[1:2] | ||||
|         self.assertEqual(len(people2), 1) | ||||
|         self.assertEqual(people2[0], user_b) | ||||
|         self.assertEqual(people[0].name, 'User B') | ||||
|  | ||||
|         # Test slice limit and skip cursor reset | ||||
|         qs = self.Person.objects[1:2] | ||||
| @@ -193,7 +165,6 @@ class QuerySetTest(unittest.TestCase): | ||||
|         self.assertEqual(len(people), 1) | ||||
|         self.assertEqual(people[0].name, 'User B') | ||||
|  | ||||
|         # Test empty slice | ||||
|         people = list(self.Person.objects[1:1]) | ||||
|         self.assertEqual(len(people), 0) | ||||
|  | ||||
| @@ -203,7 +174,7 @@ class QuerySetTest(unittest.TestCase): | ||||
|  | ||||
|         # Test larger slice __repr__ | ||||
|         self.Person.objects.delete() | ||||
|         for i in range(55): | ||||
|         for i in xrange(55): | ||||
|             self.Person(name='A%s' % i, age=i).save() | ||||
|  | ||||
|         self.assertEqual(self.Person.objects.count(), 55) | ||||
| @@ -213,6 +184,12 @@ class QuerySetTest(unittest.TestCase): | ||||
|         self.assertEqual("[<Person: Person object>, <Person: Person object>]", | ||||
|                          "%s" % self.Person.objects[51:53]) | ||||
|  | ||||
|         # Test only after limit | ||||
|         self.assertEqual(self.Person.objects().limit(2).only('name')[0].age, None) | ||||
|  | ||||
|         # Test only after skip | ||||
|         self.assertEqual(self.Person.objects().skip(2).only('name')[0].age, None) | ||||
|  | ||||
|     def test_find_one(self): | ||||
|         """Ensure that a query using find_one returns a valid result. | ||||
|         """ | ||||
| @@ -241,15 +218,14 @@ class QuerySetTest(unittest.TestCase): | ||||
|         person = self.Person.objects[1] | ||||
|         self.assertEqual(person.name, "User B") | ||||
|  | ||||
|         with self.assertRaises(IndexError): | ||||
|             self.Person.objects[2] | ||||
|         self.assertRaises(IndexError, self.Person.objects.__getitem__, 2) | ||||
|  | ||||
|         # Find a document using just the object id | ||||
|         person = self.Person.objects.with_id(person1.id) | ||||
|         self.assertEqual(person.name, "User A") | ||||
|  | ||||
|         with self.assertRaises(InvalidQueryError): | ||||
|             self.Person.objects(name="User A").with_id(person1.id) | ||||
|         self.assertRaises( | ||||
|             InvalidQueryError, self.Person.objects(name="User A").with_id, person1.id) | ||||
|  | ||||
|     def test_find_only_one(self): | ||||
|         """Ensure that a query using ``get`` returns at most one result. | ||||
| @@ -387,8 +363,7 @@ class QuerySetTest(unittest.TestCase): | ||||
|  | ||||
|         # test invalid batch size | ||||
|         qs = A.objects.batch_size(-1) | ||||
|         with self.assertRaises(ValueError): | ||||
|             list(qs) | ||||
|         self.assertRaises(ValueError, lambda: list(qs)) | ||||
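|  | ||
|         # for contrast, a valid batch size just caps how many documents the | ||
|         # driver fetches per getMore round trip (a sketch; the value 10 is | ||
|         # illustrative): | ||
|         # | ||
|         #     for a in A.objects.batch_size(10): | ||
|         #         pass  # cursor yields documents in batches of 10 | ||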
|  | ||||
|     def test_update_write_concern(self): | ||||
|         """Test that passing write_concern works""" | ||||
| @@ -417,14 +392,18 @@ class QuerySetTest(unittest.TestCase): | ||||
|         """Test to ensure that update is passed a value to update to""" | ||||
|         self.Person.drop_collection() | ||||
|  | ||||
|         author = self.Person.objects.create(name='Test User') | ||||
|         author = self.Person(name='Test User') | ||||
|         author.save() | ||||
|  | ||||
|         with self.assertRaises(OperationError): | ||||
|         def update_raises(): | ||||
|             self.Person.objects(pk=author.pk).update({}) | ||||
|  | ||||
|         with self.assertRaises(OperationError): | ||||
|         def update_one_raises(): | ||||
|             self.Person.objects(pk=author.pk).update_one({}) | ||||
|  | ||||
|         self.assertRaises(OperationError, update_raises) | ||||
|         self.assertRaises(OperationError, update_one_raises) | ||||
|  | ||||
|     def test_update_array_position(self): | ||||
|         """Ensure that updating by array position works. | ||||
|  | ||||
| @@ -452,8 +431,8 @@ class QuerySetTest(unittest.TestCase): | ||||
|         Blog.objects.create(posts=[post2, post1]) | ||||
|  | ||||
|         # Update all of the first comments of second posts of all blogs | ||||
|         Blog.objects().update(set__posts__1__comments__0__name='testc') | ||||
|         testc_blogs = Blog.objects(posts__1__comments__0__name='testc') | ||||
|         Blog.objects().update(set__posts__1__comments__0__name="testc") | ||||
|         testc_blogs = Blog.objects(posts__1__comments__0__name="testc") | ||||
|         self.assertEqual(testc_blogs.count(), 2) | ||||
|  | ||||
|         Blog.drop_collection() | ||||
| @@ -462,13 +441,14 @@ class QuerySetTest(unittest.TestCase): | ||||
|  | ||||
|         # Update only the first blog returned by the query | ||||
|         Blog.objects().update_one( | ||||
|             set__posts__1__comments__1__name='testc') | ||||
|         testc_blogs = Blog.objects(posts__1__comments__1__name='testc') | ||||
|             set__posts__1__comments__1__name="testc") | ||||
|         testc_blogs = Blog.objects(posts__1__comments__1__name="testc") | ||||
|         self.assertEqual(testc_blogs.count(), 1) | ||||
|  | ||||
|         # Check that using this indexing syntax on a non-list fails | ||||
|         with self.assertRaises(InvalidQueryError): | ||||
|             Blog.objects().update(set__posts__1__comments__0__name__1='asdf') | ||||
|         def non_list_indexing(): | ||||
|             Blog.objects().update(set__posts__1__comments__0__name__1="asdf") | ||||
|         self.assertRaises(InvalidQueryError, non_list_indexing) | ||||
|  | ||||
|         Blog.drop_collection() | ||||
|  | ||||
| @@ -536,12 +516,15 @@ class QuerySetTest(unittest.TestCase): | ||||
|         self.assertEqual(simple.x, [1, 2, None, 4, 3, 2, 3, 4]) | ||||
|  | ||||
|         # Nested updates aren't supported yet. | ||
|         with self.assertRaises(OperationError): | ||||
|         def update_nested(): | ||||
|             Simple.drop_collection() | ||||
|             Simple(x=[{'test': [1, 2, 3, 4]}]).save() | ||||
|             Simple.objects(x__test=2).update(set__x__S__test__S=3) | ||||
|             self.assertEqual(simple.x, [1, 2, 3, 4]) | ||||
|  | ||||
|         self.assertRaises(OperationError, update_nested) | ||||
|         Simple.drop_collection() | ||||
|  | ||||
|     def test_update_using_positional_operator_embedded_document(self): | ||||
|         """Ensure that the embedded documents can be updated using the positional | ||||
|         operator.""" | ||||
| @@ -571,23 +554,16 @@ class QuerySetTest(unittest.TestCase): | ||||
|         self.assertEqual(post.comments[0].by, 'joe') | ||||
|         self.assertEqual(post.comments[0].votes.score, 4) | ||||
|  | ||||
|     @needs_mongodb_v26 | ||||
|     def test_update_min_max(self): | ||||
|         class Scores(Document): | ||||
|             high_score = IntField() | ||||
|             low_score = IntField() | ||||
|  | ||||
|         scores = Scores.objects.create(high_score=800, low_score=200) | ||||
|  | ||||
|         scores = Scores(high_score=800, low_score=200) | ||||
|         scores.save() | ||||
|         Scores.objects(id=scores.id).update(min__low_score=150) | ||||
|         self.assertEqual(Scores.objects.get(id=scores.id).low_score, 150) | ||||
|         self.assertEqual(Scores.objects(id=scores.id).get().low_score, 150) | ||||
|         Scores.objects(id=scores.id).update(min__low_score=250) | ||||
|         self.assertEqual(Scores.objects.get(id=scores.id).low_score, 150) | ||||
|  | ||||
|         Scores.objects(id=scores.id).update(max__high_score=1000) | ||||
|         self.assertEqual(Scores.objects.get(id=scores.id).high_score, 1000) | ||||
|         Scores.objects(id=scores.id).update(max__high_score=500) | ||||
|         self.assertEqual(Scores.objects.get(id=scores.id).high_score, 1000) | ||||
|         self.assertEqual(Scores.objects(id=scores.id).get().low_score, 150) | ||||
|  | ||||
|     def test_updates_can_have_match_operators(self): | ||||
|  | ||||
| @@ -641,11 +617,11 @@ class QuerySetTest(unittest.TestCase): | ||||
|             members = DictField() | ||||
|  | ||||
|         club = Club() | ||||
|         club.members['John'] = {'gender': 'M', 'age': 13} | ||||
|         club.members['John'] = dict(gender="M", age=13) | ||||
|         club.save() | ||||
|  | ||||
|         Club.objects().update( | ||||
|             set__members={"John": {'gender': 'F', 'age': 14}}) | ||||
|             set__members={"John": dict(gender="F", age=14)}) | ||||
|  | ||||
|         club = Club.objects().first() | ||||
|         self.assertEqual(club.members['John']['gender'], "F") | ||||
| @@ -826,7 +802,7 @@ class QuerySetTest(unittest.TestCase): | ||||
|             post2 = Post(comments=[comment2, comment2]) | ||||
|  | ||||
|             blogs = [] | ||||
|             for i in range(1, 100): | ||||
|             for i in xrange(1, 100): | ||||
|                 blogs.append(Blog(title="post %s" % i, posts=[post1, post2])) | ||||
|  | ||||
|             Blog.objects.insert(blogs, load_bulk=False) | ||||
| @@ -863,31 +839,30 @@ class QuerySetTest(unittest.TestCase): | ||||
|  | ||||
|         self.assertEqual(Blog.objects.count(), 2) | ||||
|  | ||||
|         # test inserting an existing document (shouldn't be allowed) | ||||
|         with self.assertRaises(OperationError): | ||||
|             blog = Blog.objects.first() | ||||
|             Blog.objects.insert(blog) | ||||
|  | ||||
|         # test inserting a query set | ||||
|         with self.assertRaises(OperationError): | ||||
|         # test handles people trying to upsert | ||||
|         def throw_operation_error(): | ||||
|             blogs = Blog.objects | ||||
|             Blog.objects.insert(blogs) | ||||
|  | ||||
|         # insert a new doc | ||||
|         self.assertRaises(OperationError, throw_operation_error) | ||||
|  | ||||
|         # Test can insert new doc | ||||
|         new_post = Blog(title="code123", id=ObjectId()) | ||||
|         Blog.objects.insert(new_post) | ||||
|  | ||||
|         class Author(Document): | ||||
|             pass | ||||
|  | ||||
|         # try inserting a different document class | ||||
|         with self.assertRaises(OperationError): | ||||
|         # test handles other classes being inserted | ||||
|         def throw_operation_error_wrong_doc(): | ||||
|             class Author(Document): | ||||
|                 pass | ||||
|             Blog.objects.insert(Author()) | ||||
|  | ||||
|         # try inserting a non-document | ||||
|         with self.assertRaises(OperationError): | ||||
|         self.assertRaises(OperationError, throw_operation_error_wrong_doc) | ||||
|  | ||||
|         def throw_operation_error_not_a_document(): | ||||
|             Blog.objects.insert("HELLO WORLD") | ||||
|  | ||||
|         self.assertRaises(OperationError, throw_operation_error_not_a_document) | ||||
|  | ||||
|         Blog.drop_collection() | ||||
|  | ||||
|         blog1 = Blog(title="code", posts=[post1, post2]) | ||||
| @@ -907,13 +882,14 @@ class QuerySetTest(unittest.TestCase): | ||||
|         blog3 = Blog(title="baz", posts=[post1, post2]) | ||||
|         Blog.objects.insert([blog1, blog2]) | ||||
|  | ||||
|         with self.assertRaises(NotUniqueError): | ||||
|         def throw_operation_error_not_unique(): | ||||
|             Blog.objects.insert([blog2, blog3]) | ||||
|  | ||||
|         self.assertRaises(NotUniqueError, throw_operation_error_not_unique) | ||||
|         self.assertEqual(Blog.objects.count(), 2) | ||||
|  | ||||
|         Blog.objects.insert([blog2, blog3], | ||||
|                             write_concern={"w": 0, 'continue_on_error': True}) | ||||
|         Blog.objects.insert([blog2, blog3], write_concern={"w": 0, | ||||
|                                                            'continue_on_error': True}) | ||||
|         self.assertEqual(Blog.objects.count(), 3) | ||||
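|         # (the count reaches 3 because w=0 makes the insert unacknowledged | ||
|         # and continue_on_error lets the batch skip past blog2's duplicate | ||
|         # key, so blog3 still gets written) | ||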
|  | ||||
|     def test_get_changed_fields_query_count(self): | ||||
| @@ -991,7 +967,7 @@ class QuerySetTest(unittest.TestCase): | ||||
|         self.assertEqual(person.name, "User A") | ||||
|         self.assertEqual(person.age, 20) | ||||
|  | ||||
|     @needs_mongodb_v26 | ||||
|     @skip_older_mongodb | ||||
|     @skip_pymongo3 | ||||
|     def test_cursor_args(self): | ||||
|         """Ensures the cursor args can be set as expected | ||||
| @@ -1046,7 +1022,7 @@ class QuerySetTest(unittest.TestCase): | ||||
|  | ||||
|         Doc.drop_collection() | ||||
|  | ||||
|         for i in range(1000): | ||||
|         for i in xrange(1000): | ||||
|             Doc(number=i).save() | ||||
|  | ||||
|         docs = Doc.objects.order_by('number') | ||||
| @@ -1200,7 +1176,7 @@ class QuerySetTest(unittest.TestCase): | ||||
|         qs = list(qs) | ||||
|         expected = list(expected) | ||||
|         self.assertEqual(len(qs), len(expected)) | ||||
|         for i in range(len(qs)): | ||||
|         for i in xrange(len(qs)): | ||||
|             self.assertEqual(qs[i], expected[i]) | ||||
|  | ||||
|     def test_ordering(self): | ||||
| @@ -1240,8 +1216,7 @@ class QuerySetTest(unittest.TestCase): | ||||
|         self.assertSequence(qs, expected) | ||||
|  | ||||
|     def test_clear_ordering(self): | ||||
|         """Ensure that the default ordering can be cleared by calling | ||||
|         order_by() w/o any arguments. | ||||
|         """ Ensure that the default ordering can be cleared by calling order_by(). | ||||
|         """ | ||||
|         class BlogPost(Document): | ||||
|             title = StringField() | ||||
| @@ -1253,35 +1228,16 @@ class QuerySetTest(unittest.TestCase): | ||||
|  | ||||
|         BlogPost.drop_collection() | ||||
|  | ||||
|         # default ordering should be used by default | ||||
|         with db_ops_tracker() as q: | ||||
|             BlogPost.objects.filter(title='whatever').first() | ||||
|             self.assertEqual(len(q.get_ops()), 1) | ||||
|             self.assertEqual( | ||||
|                 q.get_ops()[0]['query']['$orderby'], | ||||
|                 {'published_date': -1} | ||||
|             ) | ||||
|                 q.get_ops()[0]['query']['$orderby'], {u'published_date': -1}) | ||||
|  | ||||
|         # calling order_by() should clear the default ordering | ||||
|         with db_ops_tracker() as q: | ||||
|             BlogPost.objects.filter(title='whatever').order_by().first() | ||||
|             self.assertEqual(len(q.get_ops()), 1) | ||||
|             self.assertFalse('$orderby' in q.get_ops()[0]['query']) | ||||
|  | ||||
|         # calling an explicit order_by should use a specified sort | ||||
|         with db_ops_tracker() as q: | ||||
|             BlogPost.objects.filter(title='whatever').order_by('published_date').first() | ||||
|             self.assertEqual(len(q.get_ops()), 1) | ||||
|             self.assertEqual( | ||||
|                 q.get_ops()[0]['query']['$orderby'], | ||||
|                 {'published_date': 1} | ||||
|             ) | ||||
|  | ||||
|         # calling order_by() after an explicit sort should clear it | ||||
|         with db_ops_tracker() as q: | ||||
|             qs = BlogPost.objects.filter(title='whatever').order_by('published_date') | ||||
|             qs.order_by().first() | ||||
|             self.assertEqual(len(q.get_ops()), 1) | ||||
|             self.assertFalse('$orderby' in q.get_ops()[0]['query']) | ||||
|  | ||||
|     def test_no_ordering_for_get(self): | ||||
| @@ -1311,7 +1267,7 @@ class QuerySetTest(unittest.TestCase): | ||||
|  | ||||
|     def test_find_embedded(self): | ||||
|         """Ensure that an embedded document is properly returned from | ||||
|         different manners of querying. | ||||
|         a query. | ||||
|         """ | ||||
|         class User(EmbeddedDocument): | ||||
|             name = StringField() | ||||
| @@ -1322,9 +1278,8 @@ class QuerySetTest(unittest.TestCase): | ||||
|  | ||||
|         BlogPost.drop_collection() | ||||
|  | ||||
|         user = User(name='Test User') | ||||
|         BlogPost.objects.create( | ||||
|             author=user, | ||||
|             author=User(name='Test User'), | ||||
|             content='Had a good coffee today...' | ||||
|         ) | ||||
|  | ||||
| @@ -1332,19 +1287,6 @@ class QuerySetTest(unittest.TestCase): | ||||
|         self.assertTrue(isinstance(result.author, User)) | ||||
|         self.assertEqual(result.author.name, 'Test User') | ||||
|  | ||||
|         result = BlogPost.objects.get(author__name=user.name) | ||||
|         self.assertTrue(isinstance(result.author, User)) | ||||
|         self.assertEqual(result.author.name, 'Test User') | ||||
|  | ||||
|         result = BlogPost.objects.get(author={'name': user.name}) | ||||
|         self.assertTrue(isinstance(result.author, User)) | ||||
|         self.assertEqual(result.author.name, 'Test User') | ||||
|  | ||||
|         # Fails because a string cannot represent the author's embedded | ||
|         # document structure (a dict is required) | ||
|         with self.assertRaises(InvalidQueryError): | ||||
|             BlogPost.objects.get(author=user.name) | ||||
|  | ||||
|     def test_find_empty_embedded(self): | ||||
|         """Ensure that you can save and find an empty embedded document.""" | ||||
|         class User(EmbeddedDocument): | ||||
| @@ -1768,7 +1710,7 @@ class QuerySetTest(unittest.TestCase): | ||||
|  | ||||
|         Log.drop_collection() | ||||
|  | ||||
|         for i in range(10): | ||||
|         for i in xrange(10): | ||||
|             Log().save() | ||||
|  | ||||
|         Log.objects()[3:5].delete() | ||||
| @@ -1871,11 +1813,6 @@ class QuerySetTest(unittest.TestCase): | ||||
|         post.reload() | ||||
|         self.assertEqual(post.hits, 10) | ||||
|  | ||||
|         # Negative dec operator is equal to a positive inc operator | ||||
|         BlogPost.objects.update_one(dec__hits=-1) | ||||
|         post.reload() | ||||
|         self.assertEqual(post.hits, 11) | ||||
|  | ||||
|         BlogPost.objects.update(push__tags='mongo') | ||||
|         post.reload() | ||||
|         self.assertTrue('mongo' in post.tags) | ||||
| @@ -1973,10 +1910,12 @@ class QuerySetTest(unittest.TestCase): | ||||
|         Site.objects(id=s.id).update_one(pull__collaborators__user='Esteban') | ||||
|         self.assertEqual(Site.objects.first().collaborators, []) | ||||
|  | ||||
|         with self.assertRaises(InvalidQueryError): | ||||
|         def pull_all(): | ||||
|             Site.objects(id=s.id).update_one( | ||||
|                 pull_all__collaborators__user=['Ross']) | ||||
|  | ||||
|         self.assertRaises(InvalidQueryError, pull_all) | ||||
|  | ||||
|     def test_pull_from_nested_embedded(self): | ||||
|  | ||||
|         class User(EmbeddedDocument): | ||||
| @@ -2007,10 +1946,12 @@ class QuerySetTest(unittest.TestCase): | ||||
|             pull__collaborators__unhelpful={'name': 'Frank'}) | ||||
|         self.assertEqual(Site.objects.first().collaborators['unhelpful'], []) | ||||
|  | ||||
|         with self.assertRaises(InvalidQueryError): | ||||
|         def pull_all(): | ||||
|             Site.objects(id=s.id).update_one( | ||||
|                 pull_all__collaborators__helpful__name=['Ross']) | ||||
|  | ||||
|         self.assertRaises(InvalidQueryError, pull_all) | ||||
|  | ||||
|     def test_pull_from_nested_mapfield(self): | ||||
|  | ||||
|         class Collaborator(EmbeddedDocument): | ||||
| @@ -2039,10 +1980,12 @@ class QuerySetTest(unittest.TestCase): | ||||
|             pull__collaborators__unhelpful={'user': 'Frank'}) | ||||
|         self.assertEqual(Site.objects.first().collaborators['unhelpful'], []) | ||||
|  | ||||
|         with self.assertRaises(InvalidQueryError): | ||||
|         def pull_all(): | ||||
|             Site.objects(id=s.id).update_one( | ||||
|                 pull_all__collaborators__helpful__user=['Ross']) | ||||
|  | ||||
|         self.assertRaises(InvalidQueryError, pull_all) | ||||
|  | ||||
|     def test_update_one_pop_generic_reference(self): | ||||
|  | ||||
|         class BlogTag(Document): | ||||
| @@ -2667,7 +2610,7 @@ class QuerySetTest(unittest.TestCase): | ||||
|         BlogPost(hits=2, tags=['music', 'actors']).save() | ||||
|  | ||||
|         def test_assertions(f): | ||||
|             f = {key: int(val) for key, val in f.items()} | ||||
|             f = dict((key, int(val)) for key, val in f.items()) | ||||
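|             # item_frequencies here is backed by map-reduce, and JavaScript | ||
|             # numbers are doubles, hence the int() normalisation above | ||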
|             self.assertEqual( | ||||
|                 set(['music', 'film', 'actors', 'watch']), set(f.keys())) | ||||
|             self.assertEqual(f['music'], 3) | ||||
| @@ -2682,7 +2625,7 @@ class QuerySetTest(unittest.TestCase): | ||||
|  | ||||
|         # Ensure query is taken into account | ||||
|         def test_assertions(f): | ||||
|             f = {key: int(val) for key, val in f.items()} | ||||
|             f = dict((key, int(val)) for key, val in f.items()) | ||||
|             self.assertEqual(set(['music', 'actors', 'watch']), set(f.keys())) | ||||
|             self.assertEqual(f['music'], 2) | ||||
|             self.assertEqual(f['actors'], 1) | ||||
| @@ -2746,7 +2689,7 @@ class QuerySetTest(unittest.TestCase): | ||||
|         doc.save() | ||||
|  | ||||
|         def test_assertions(f): | ||||
|             f = {key: int(val) for key, val in f.items()} | ||||
|             f = dict((key, int(val)) for key, val in f.items()) | ||||
|             self.assertEqual( | ||||
|                 set(['62-3331-1656', '62-3332-1656']), set(f.keys())) | ||||
|             self.assertEqual(f['62-3331-1656'], 2) | ||||
| @@ -2760,7 +2703,7 @@ class QuerySetTest(unittest.TestCase): | ||||
|  | ||||
|         # Ensure query is taken into account | ||||
|         def test_assertions(f): | ||||
|             f = {key: int(val) for key, val in f.items()} | ||||
|             f = dict((key, int(val)) for key, val in f.items()) | ||||
|             self.assertEqual(set(['62-3331-1656']), set(f.keys())) | ||||
|             self.assertEqual(f['62-3331-1656'], 2) | ||||
|  | ||||
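The hunks above swap dict comprehensions for the `dict()` constructor; dict comprehensions were only added in Python 2.7, so this too looks like a 2.6 compatibility change. Both spellings build the same mapping (illustrative values):

    pairs = [('music', '3'), ('film', '1')]

    # Python 2.7+/3.x only:
    as_comprehension = {key: int(val) for key, val in pairs}

    # Works on Python 2.6 as well:
    as_constructor = dict((key, int(val)) for key, val in pairs)

    assert as_comprehension == as_constructor == {'music': 3, 'film': 1}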
| @@ -2867,10 +2810,10 @@ class QuerySetTest(unittest.TestCase): | ||||
|  | ||||
|         Test.drop_collection() | ||||
|  | ||||
|         for i in range(50): | ||||
|         for i in xrange(50): | ||||
|             Test(val=1).save() | ||||
|  | ||||
|         for i in range(20): | ||||
|         for i in xrange(20): | ||||
|             Test(val=2).save() | ||||
|  | ||||
|         freqs = Test.objects.item_frequencies( | ||||
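`range` is rewritten to `xrange` throughout the remaining hunks. On Python 2, `xrange` yields lazily instead of materializing a list, and the project's 2to3 setup translates it back to `range` on Python 3. Roughly:

    import sys

    if sys.version_info[0] >= 3:
        xrange = range  # what 2to3's xrange fixer effectively does

    # On Python 2, xrange(50) is a lazy sequence; range(50) builds a list.
    assert sum(1 for _ in xrange(50)) == 50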
| @@ -3108,7 +3051,7 @@ class QuerySetTest(unittest.TestCase): | ||||
|  | ||||
|         self.assertEqual(Foo.objects.distinct("bar"), [bar]) | ||||
|  | ||||
|     @needs_mongodb_v26 | ||||
|     @skip_older_mongodb | ||||
|     def test_text_indexes(self): | ||||
|         class News(Document): | ||||
|             title = StringField() | ||||
| @@ -3195,7 +3138,7 @@ class QuerySetTest(unittest.TestCase): | ||||
|             'brasil').order_by('$text_score').first() | ||||
|         self.assertEqual(item.get_text_score(), max_text_score) | ||||
|  | ||||
|     @needs_mongodb_v26 | ||||
|     @skip_older_mongodb | ||||
|     def test_distinct_handles_references_to_alias(self): | ||||
|         register_connection('testdb', 'mongoenginetest2') | ||||
|  | ||||
| @@ -3660,7 +3603,7 @@ class QuerySetTest(unittest.TestCase): | ||||
|  | ||||
|         Post.drop_collection() | ||||
|  | ||||
|         for i in range(10): | ||||
|         for i in xrange(10): | ||||
|             Post(title="Post %s" % i).save() | ||||
|  | ||||
|         self.assertEqual(5, Post.objects.limit(5).skip(5).count(with_limit_and_skip=True)) | ||||
| @@ -3675,7 +3618,7 @@ class QuerySetTest(unittest.TestCase): | ||||
|             pass | ||||
|  | ||||
|         MyDoc.drop_collection() | ||||
|         for i in range(0, 10): | ||||
|         for i in xrange(0, 10): | ||||
|             MyDoc().save() | ||||
|  | ||||
|         self.assertEqual(MyDoc.objects.count(), 10) | ||||
| @@ -3731,7 +3674,7 @@ class QuerySetTest(unittest.TestCase): | ||||
|  | ||||
|         Number.drop_collection() | ||||
|  | ||||
|         for i in range(1, 101): | ||||
|         for i in xrange(1, 101): | ||||
|             t = Number(n=i) | ||||
|             t.save() | ||||
|  | ||||
| @@ -3878,9 +3821,11 @@ class QuerySetTest(unittest.TestCase): | ||||
|         self.assertTrue(a in results) | ||||
|         self.assertTrue(c in results) | ||||
|  | ||||
|         with self.assertRaises(TypeError): | ||||
|         def invalid_where(): | ||||
|             list(IntPair.objects.where(fielda__gte=3)) | ||||
|  | ||||
|         self.assertRaises(TypeError, invalid_where) | ||||
|  | ||||
|     def test_scalar(self): | ||||
|  | ||||
|         class Organization(Document): | ||||
| @@ -4136,7 +4081,7 @@ class QuerySetTest(unittest.TestCase): | ||||
|  | ||||
|         # Test larger slice __repr__ | ||||
|         self.Person.objects.delete() | ||||
|         for i in range(55): | ||||
|         for i in xrange(55): | ||||
|             self.Person(name='A%s' % i, age=i).save() | ||||
|  | ||||
|         self.assertEqual(self.Person.objects.scalar('name').count(), 55) | ||||
| @@ -4144,7 +4089,7 @@ class QuerySetTest(unittest.TestCase): | ||||
|             "A0", "%s" % self.Person.objects.order_by('name').scalar('name').first()) | ||||
|         self.assertEqual( | ||||
|             "A0", "%s" % self.Person.objects.scalar('name').order_by('name')[0]) | ||||
|         if six.PY3: | ||||
|         if PY3: | ||||
|             self.assertEqual("['A1', 'A2']", "%s" % self.Person.objects.order_by( | ||||
|                 'age').scalar('name')[1:3]) | ||||
|             self.assertEqual("['A51', 'A52']", "%s" % self.Person.objects.order_by( | ||||
| @@ -4162,7 +4107,7 @@ class QuerySetTest(unittest.TestCase): | ||||
|  | ||||
|         pks = self.Person.objects.order_by('age').scalar('pk')[1:3] | ||||
|         names = self.Person.objects.scalar('name').in_bulk(list(pks)).values() | ||||
|         if six.PY3: | ||||
|         if PY3: | ||||
|             expected = "['A1', 'A2']" | ||||
|         else: | ||||
|             expected = "[u'A1', u'A2']" | ||||
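`six.PY3` gives way to a bare `PY3` here, presumably imported from mongoengine's own python_support module in this branch rather than from the six library (the import source is an assumption). The flag itself is one line either way:

    import sys

    PY3 = sys.version_info[0] == 3  # the flag six.PY3 provides

    if PY3:
        expected = "['A1', 'A2']"
    else:
        expected = "[u'A1', u'A2']"  # Python 2 repr() marks unicode with u''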
| @@ -4518,7 +4463,7 @@ class QuerySetTest(unittest.TestCase): | ||||
|             name = StringField() | ||||
|  | ||||
|         Person.drop_collection() | ||||
|         for i in range(100): | ||||
|         for i in xrange(100): | ||||
|             Person(name="No: %s" % i).save() | ||||
|  | ||||
|         with query_counter() as q: | ||||
| @@ -4549,7 +4494,7 @@ class QuerySetTest(unittest.TestCase): | ||||
|             name = StringField() | ||||
|  | ||||
|         Person.drop_collection() | ||||
|         for i in range(100): | ||||
|         for i in xrange(100): | ||||
|             Person(name="No: %s" % i).save() | ||||
|  | ||||
|         with query_counter() as q: | ||||
| @@ -4593,7 +4538,7 @@ class QuerySetTest(unittest.TestCase): | ||||
|             fields = DictField() | ||||
|  | ||||
|         Noddy.drop_collection() | ||||
|         for i in range(100): | ||||
|         for i in xrange(100): | ||||
|             noddy = Noddy() | ||||
|             for j in range(20): | ||||
|                 noddy.fields["key" + str(j)] = "value " + str(j) | ||||
| @@ -4605,9 +4550,7 @@ class QuerySetTest(unittest.TestCase): | ||||
|         self.assertEqual(counter, 100) | ||||
|  | ||||
|         self.assertEqual(len(list(docs)), 100) | ||||
|  | ||||
|         with self.assertRaises(TypeError): | ||||
|             len(docs) | ||||
|         self.assertRaises(TypeError, lambda: len(docs)) | ||||
|  | ||||
|         with query_counter() as q: | ||||
|             self.assertEqual(q, 0) | ||||
| @@ -4796,7 +4739,7 @@ class QuerySetTest(unittest.TestCase): | ||||
|             name = StringField() | ||||
|  | ||||
|         Person.drop_collection() | ||||
|         for i in range(100): | ||||
|         for i in xrange(100): | ||||
|             Person(name="No: %s" % i).save() | ||||
|  | ||||
|         with query_counter() as q: | ||||
| @@ -4870,7 +4813,6 @@ class QuerySetTest(unittest.TestCase): | ||||
|             self.assertTrue(Person.objects._has_data(), | ||||
|                             'Cursor has data and returned False') | ||||
|  | ||||
|     @needs_mongodb_v26 | ||||
|     def test_queryset_aggregation_framework(self): | ||||
|         class Person(Document): | ||||
|             name = StringField() | ||||
| @@ -4905,22 +4847,26 @@ class QuerySetTest(unittest.TestCase): | ||||
|             {'_id': p1.pk, 'name': "ISABELLA LUANNA"} | ||||
|         ]) | ||||
|  | ||||
|         data = Person.objects(age__gte=17, age__lte=40).order_by('-age').aggregate({ | ||||
|             '$group': { | ||||
|                 '_id': None, | ||||
|                 'total': {'$sum': 1}, | ||||
|                 'avg': {'$avg': '$age'} | ||||
|             } | ||||
|         }) | ||||
|         data = Person.objects( | ||||
|             age__gte=17, age__lte=40).order_by('-age').aggregate( | ||||
|                 {'$group': { | ||||
|                     '_id': None, | ||||
|                     'total': {'$sum': 1}, | ||||
|                     'avg': {'$avg': '$age'} | ||||
|                 } | ||||
|                 } | ||||
|  | ||||
|         ) | ||||
|  | ||||
|         self.assertEqual(list(data), [ | ||||
|             {'_id': None, 'avg': 29, 'total': 2} | ||||
|         ]) | ||||
|  | ||||
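For context on the API under test: `QuerySet.aggregate()` hands raw pipeline stages to PyMongo after prepending stages built from the queryset's own filter (and, as this test relies on, its ordering). A hedged usage sketch; the database name and documents are illustrative:

    from mongoengine import Document, StringField, IntField, connect

    class Person(Document):
        name = StringField()
        age = IntField()

    connect('aggregation_example')  # illustrative database name

    Person.drop_collection()
    Person(name='Isabella', age=16).save()
    Person(name='Wilson', age=32).save()

    # The age filter is applied as a $match stage before the custom $group.
    data = Person.objects(age__gte=17).aggregate(
        {'$group': {'_id': None, 'total': {'$sum': 1}, 'avg': {'$avg': '$age'}}}
    )
    print(list(data))  # [{'_id': None, 'total': 1, 'avg': 32.0}]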
|     def test_delete_count(self): | ||||
|         [self.Person(name="User {0}".format(i), age=i * 10).save() for i in range(1, 4)] | ||||
|         [self.Person(name="User {0}".format(i), age=i * 10).save() for i in xrange(1, 4)] | ||||
|         self.assertEqual(self.Person.objects().delete(), 3)  # test ordinary QuerySet delete count | ||||
|  | ||||
|         [self.Person(name="User {0}".format(i), age=i * 10).save() for i in range(1, 4)] | ||||
|         [self.Person(name="User {0}".format(i), age=i * 10).save() for i in xrange(1, 4)] | ||||
|  | ||||
|         self.assertEqual(self.Person.objects().skip(1).delete(), 2)  # test Document delete with existing documents | ||||
|  | ||||
| @@ -4929,14 +4875,12 @@ class QuerySetTest(unittest.TestCase): | ||||
|  | ||||
|     def test_max_time_ms(self): | ||||
|         # 778: max_time_ms can get only int or None as input | ||||
|         self.assertRaises(TypeError, | ||||
|                           self.Person.objects(name="name").max_time_ms, | ||||
|                           'not a number') | ||||
|         self.assertRaises(TypeError, self.Person.objects(name="name").max_time_ms, "not a number") | ||||
|  | ||||
|     def test_subclass_field_query(self): | ||||
|         class Animal(Document): | ||||
|             is_mamal = BooleanField() | ||||
|             meta = {'allow_inheritance': True} | ||||
|             meta = dict(allow_inheritance=True) | ||||
|  | ||||
|         class Cat(Animal): | ||||
|             whiskers_length = FloatField() | ||||
| @@ -4952,13 +4896,11 @@ class QuerySetTest(unittest.TestCase): | ||||
|         self.assertEquals(Animal.objects(folded_ears=True).count(), 1) | ||||
|         self.assertEquals(Animal.objects(whiskers_length=5.1).count(), 1) | ||||
|  | ||||
|     def test_loop_over_invalid_id_does_not_crash(self): | ||||
|     def test_loop_via_invalid_id_does_not_crash(self): | ||||
|         class Person(Document): | ||||
|             name = StringField() | ||||
|  | ||||
|         Person.drop_collection() | ||||
|  | ||||
|         Person._get_collection().insert({'name': 'a', 'id': ''}) | ||||
|         Person.objects.delete() | ||||
|         Person._get_collection().update({"name": "a"}, {"$set": {"_id": ""}}, upsert=True) | ||||
|         for p in Person.objects(): | ||||
|             self.assertEqual(p.name, 'a') | ||||
|  | ||||
| @@ -4976,85 +4918,6 @@ class QuerySetTest(unittest.TestCase): | ||||
|  | ||||
|         self.assertEqual(1, Doc.objects(item__type__="axe").count()) | ||||
|  | ||||
|     def test_len_during_iteration(self): | ||||
|         """Tests that calling len on a queyset during iteration doesn't | ||||
|         stop paging. | ||||
|         """ | ||||
|         class Data(Document): | ||||
|             pass | ||||
|  | ||||
|         for i in range(300): | ||||
|             Data().save() | ||||
|  | ||||
|         records = Data.objects.limit(250) | ||||
|  | ||||
|         # This should pull all 250 docs from mongo and populate the result | ||||
|         # cache | ||||
|         len(records) | ||||
|  | ||||
|         # Assert that iterating over documents in the qs touches every | ||||
|         # document even if we call len(qs) midway through the iteration. | ||||
|         for i, r in enumerate(records): | ||||
|             if i == 58: | ||||
|                 len(records) | ||||
|         self.assertEqual(i, 249) | ||||
|  | ||||
|         # Assert the same behavior is true even if we didn't pre-populate the | ||||
|         # result cache. | ||||
|         records = Data.objects.limit(250) | ||||
|         for i, r in enumerate(records): | ||||
|             if i == 58: | ||||
|                 len(records) | ||||
|         self.assertEqual(i, 249) | ||||
|  | ||||
|     def test_iteration_within_iteration(self): | ||||
|         """You should be able to reliably iterate over all the documents | ||||
|         in a given queryset even if there are multiple iterations of it | ||||
|         happening at the same time. | ||||
|         """ | ||||
|         class Data(Document): | ||||
|             pass | ||||
|  | ||||
|         for i in range(300): | ||||
|             Data().save() | ||||
|  | ||||
|         qs = Data.objects.limit(250) | ||||
|         for i, doc in enumerate(qs): | ||||
|             for j, doc2 in enumerate(qs): | ||||
|                 pass | ||||
|  | ||||
|         self.assertEqual(i, 249) | ||||
|         self.assertEqual(j, 249) | ||||
|  | ||||
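Both deleted tests above pinned down the queryset's result-cache contract: `len(qs)` materializes every matching document into the cache, and neither that nor a nested iteration may disturb an iteration already in flight. Condensed, reusing the `Data` document from the deleted test:

    records = Data.objects.limit(250)

    len(records)                    # fills the result cache with all 250 docs
    seen = sum(1 for _ in records)  # iteration must still visit every one
    assert seen == 250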
|     def test_in_operator_on_non_iterable(self): | ||||
|         """Ensure that using the `__in` operator on a non-iterable raises an | ||||
|         error. | ||||
|         """ | ||||
|         class User(Document): | ||||
|             name = StringField() | ||||
|  | ||||
|         class BlogPost(Document): | ||||
|             content = StringField() | ||||
|             authors = ListField(ReferenceField(User)) | ||||
|  | ||||
|         User.drop_collection() | ||||
|         BlogPost.drop_collection() | ||||
|  | ||||
|         author = User.objects.create(name='Test User') | ||||
|         post = BlogPost.objects.create(content='Had a good coffee today...', | ||||
|                                        authors=[author]) | ||||
|  | ||||
|         # Make sure using `__in` with a list works | ||||
|         blog_posts = BlogPost.objects(authors__in=[author]) | ||||
|         self.assertEqual(list(blog_posts), [post]) | ||||
|  | ||||
|         # Using `__in` with a non-iterable should raise a TypeError | ||||
|         self.assertRaises(TypeError, BlogPost.objects(authors__in=author.pk).count) | ||||
|  | ||||
|         # Using `__in` with a `Document` (which is seemingly iterable but not | ||||
|         # in a way we'd expect) should raise a TypeError, too | ||||
|         self.assertRaises(TypeError, BlogPost.objects(authors__in=author).count) | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|     unittest.main() | ||||
|   | ||||
| @@ -238,8 +238,7 @@ class TransformTest(unittest.TestCase): | ||||
|         box = [(35.0, -125.0), (40.0, -100.0)] | ||||
|         # I *meant* to execute location__within_box=box | ||||
|         events = Event.objects(location__within=box) | ||||
|         with self.assertRaises(InvalidQueryError): | ||||
|             events.count() | ||||
|         self.assertRaises(InvalidQueryError, lambda: events.count()) | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|   | ||||
| @@ -185,7 +185,7 @@ class QTest(unittest.TestCase): | ||||
|             x = IntField() | ||||
|  | ||||
|         TestDoc.drop_collection() | ||||
|         for i in range(1, 101): | ||||
|         for i in xrange(1, 101): | ||||
|             t = TestDoc(x=i) | ||||
|             t.save() | ||||
|  | ||||
| @@ -268,13 +268,14 @@ class QTest(unittest.TestCase): | ||||
|         self.assertEqual(self.Person.objects(Q(age__in=[20, 30])).count(), 3) | ||||
|  | ||||
|         # Test invalid query objs | ||||
|         with self.assertRaises(InvalidQueryError): | ||||
|         def wrong_query_objs(): | ||||
|             self.Person.objects('user1') | ||||
|  | ||||
|         # filter should fail, too | ||||
|         with self.assertRaises(InvalidQueryError): | ||||
|             self.Person.objects.filter('user1') | ||||
|         def wrong_query_objs_filter(): | ||||
|             self.Person.objects.filter('user1') | ||||
|  | ||||
|         self.assertRaises(InvalidQueryError, wrong_query_objs) | ||||
|         self.assertRaises(InvalidQueryError, wrong_query_objs_filter) | ||||
|  | ||||
|     def test_q_regex(self): | ||||
|         """Ensure that Q objects can be queried using regexes. | ||||
|   | ||||
| @@ -1,6 +1,9 @@ | ||||
| import sys | ||||
| import datetime | ||||
| from pymongo.errors import OperationFailure | ||||
|  | ||||
| sys.path[0:0] = [""] | ||||
|  | ||||
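The reinstated `sys.path[0:0] = [""]` idiom prepends the empty string, i.e. the current working directory, to the import path, so running a test file directly resolves the in-tree mongoengine package ahead of any installed copy. It is shorthand for:

    import sys

    # Put the current working directory first on the module search path so
    # `import mongoengine` finds the checkout, not a site-packages install.
    sys.path.insert(0, "")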
| try: | ||||
|     import unittest2 as unittest | ||||
| except ImportError: | ||||
| @@ -16,8 +19,7 @@ from mongoengine import ( | ||||
| ) | ||||
| from mongoengine.python_support import IS_PYMONGO_3 | ||||
| import mongoengine.connection | ||||
| from mongoengine.connection import (MongoEngineConnectionError, get_db, | ||||
|                                     get_connection) | ||||
| from mongoengine.connection import get_db, get_connection, ConnectionError | ||||
|  | ||||
|  | ||||
| def get_tz_awareness(connection): | ||||
| @@ -35,7 +37,8 @@ class ConnectionTest(unittest.TestCase): | ||||
|         mongoengine.connection._dbs = {} | ||||
|  | ||||
|     def test_connect(self): | ||||
|         """Ensure that the connect() method works properly.""" | ||||
|         """Ensure that the connect() method works properly. | ||||
|         """ | ||||
|         connect('mongoenginetest') | ||||
|  | ||||
|         conn = get_connection() | ||||
| @@ -145,7 +148,8 @@ class ConnectionTest(unittest.TestCase): | ||||
|         self.assertEqual(expected_connection, actual_connection) | ||||
|  | ||||
|     def test_connect_uri(self): | ||||
|         """Ensure that the connect() method works properly with URIs.""" | ||||
|         """Ensure that the connect() method works properly with uri's | ||||
|         """ | ||||
|         c = connect(db='mongoenginetest', alias='admin') | ||||
|         c.admin.system.users.remove({}) | ||||
|         c.mongoenginetest.system.users.remove({}) | ||||
| @@ -155,10 +159,7 @@ class ConnectionTest(unittest.TestCase): | ||||
|         c.mongoenginetest.add_user("username", "password") | ||||
|  | ||||
|         if not IS_PYMONGO_3: | ||||
|             self.assertRaises( | ||||
|                 MongoEngineConnectionError, connect, 'testdb_uri_bad', | ||||
|                 host='mongodb://test:password@localhost' | ||||
|             ) | ||||
|             self.assertRaises(ConnectionError, connect, "testdb_uri_bad", host='mongodb://test:password@localhost') | ||||
|  | ||||
|         connect("testdb_uri", host='mongodb://username:password@localhost/mongoenginetest') | ||||
|  | ||||
| @@ -173,9 +174,19 @@ class ConnectionTest(unittest.TestCase): | ||||
|         c.mongoenginetest.system.users.remove({}) | ||||
|  | ||||
|     def test_connect_uri_without_db(self): | ||||
|         """Ensure connect() method works properly if the URI doesn't | ||||
|         include a database name. | ||||
|         """Ensure connect() method works properly with uri's without database_name | ||||
|         """ | ||||
|         c = connect(db='mongoenginetest', alias='admin') | ||||
|         c.admin.system.users.remove({}) | ||||
|         c.mongoenginetest.system.users.remove({}) | ||||
|  | ||||
|         c.admin.add_user("admin", "password") | ||||
|         c.admin.authenticate("admin", "password") | ||||
|         c.mongoenginetest.add_user("username", "password") | ||||
|  | ||||
|         if not IS_PYMONGO_3: | ||||
|             self.assertRaises(ConnectionError, connect, "testdb_uri_bad", host='mongodb://test:password@localhost') | ||||
|  | ||||
|         connect("mongoenginetest", host='mongodb://localhost/') | ||||
|  | ||||
|         conn = get_connection() | ||||
| @@ -185,35 +196,13 @@ class ConnectionTest(unittest.TestCase): | ||||
|         self.assertTrue(isinstance(db, pymongo.database.Database)) | ||||
|         self.assertEqual(db.name, 'mongoenginetest') | ||||
|  | ||||
|     def test_connect_uri_default_db(self): | ||||
|         """Ensure connect() defaults to the right database name if | ||||
|         the URI and the database_name don't explicitly specify it. | ||||
|         """ | ||||
|         connect(host='mongodb://localhost/') | ||||
|  | ||||
|         conn = get_connection() | ||||
|         self.assertTrue(isinstance(conn, pymongo.mongo_client.MongoClient)) | ||||
|  | ||||
|         db = get_db() | ||||
|         self.assertTrue(isinstance(db, pymongo.database.Database)) | ||||
|         self.assertEqual(db.name, 'test') | ||||
|  | ||||
|     def test_uri_without_credentials_doesnt_override_conn_settings(self): | ||||
|         """Ensure connect() uses the username & password params if the URI | ||||
|         doesn't explicitly specify them. | ||||
|         """ | ||||
|         c = connect(host='mongodb://localhost/mongoenginetest', | ||||
|                     username='user', | ||||
|                     password='pass') | ||||
|  | ||||
|         # OperationFailure means that mongoengine attempted authentication | ||||
|         # w/ the provided username/password and failed - that's the desired | ||||
|         # behavior. If the MongoDB URI would override the credentials | ||||
|         self.assertRaises(OperationFailure, get_db) | ||||
|         c.admin.system.users.remove({}) | ||||
|         c.mongoenginetest.system.users.remove({}) | ||||
|  | ||||
|     def test_connect_uri_with_authsource(self): | ||||
|         """Ensure that the connect() method works well with `authSource` | ||||
|         option in the URI. | ||||
|         """Ensure that the connect() method works well with | ||||
|         the option `authSource` in URI. | ||||
|         This feature was introduced in MongoDB 2.4 and removed in 2.6 | ||||
|         """ | ||||
|         # Create users | ||||
|         c = connect('mongoenginetest') | ||||
| @@ -222,38 +211,36 @@ class ConnectionTest(unittest.TestCase): | ||||
|  | ||||
|         # Authentication fails without "authSource" | ||||
|         if IS_PYMONGO_3: | ||||
|             test_conn = connect( | ||||
|                 'mongoenginetest', alias='test1', | ||||
|                 host='mongodb://username2:password@localhost/mongoenginetest' | ||||
|             ) | ||||
|             test_conn = connect('mongoenginetest', alias='test1', | ||||
|                                 host='mongodb://username2:password@localhost/mongoenginetest') | ||||
|             self.assertRaises(OperationFailure, test_conn.server_info) | ||||
|         else: | ||||
|             self.assertRaises( | ||||
|                 MongoEngineConnectionError, | ||||
|                 connect, 'mongoenginetest', alias='test1', | ||||
|                 ConnectionError, connect, 'mongoenginetest', alias='test1', | ||||
|                 host='mongodb://username2:password@localhost/mongoenginetest' | ||||
|             ) | ||||
|             self.assertRaises(MongoEngineConnectionError, get_db, 'test1') | ||||
|             self.assertRaises(ConnectionError, get_db, 'test1') | ||||
|  | ||||
|         # Authentication succeeds with "authSource" | ||||
|         authd_conn = connect( | ||||
|         connect( | ||||
|             'mongoenginetest', alias='test2', | ||||
|             host=('mongodb://username2:password@localhost/' | ||||
|                   'mongoenginetest?authSource=admin') | ||||
|         ) | ||||
|         # This will fail starting from MongoDB 2.6+ | ||||
|         db = get_db('test2') | ||||
|         self.assertTrue(isinstance(db, pymongo.database.Database)) | ||||
|         self.assertEqual(db.name, 'mongoenginetest') | ||||
|  | ||||
|         # Clear all users | ||||
|         authd_conn.admin.system.users.remove({}) | ||||
|         c.admin.system.users.remove({}) | ||||
|  | ||||
|     def test_register_connection(self): | ||||
|         """Ensure that connections with different aliases may be registered. | ||||
|         """ | ||||
|         register_connection('testdb', 'mongoenginetest2') | ||||
|  | ||||
|         self.assertRaises(MongoEngineConnectionError, get_connection) | ||||
|         self.assertRaises(ConnectionError, get_connection) | ||||
|         conn = get_connection('testdb') | ||||
|         self.assertTrue(isinstance(conn, pymongo.mongo_client.MongoClient)) | ||||
|  | ||||
| @@ -270,7 +257,8 @@ class ConnectionTest(unittest.TestCase): | ||||
|         self.assertTrue(isinstance(conn, pymongo.mongo_client.MongoClient)) | ||||
|  | ||||
|     def test_connection_kwargs(self): | ||||
|         """Ensure that connection kwargs get passed to pymongo.""" | ||||
|         """Ensure that connection kwargs get passed to pymongo. | ||||
|         """ | ||||
|         connect('mongoenginetest', alias='t1', tz_aware=True) | ||||
|         conn = get_connection('t1') | ||||
|  | ||||
| @@ -280,77 +268,6 @@ class ConnectionTest(unittest.TestCase): | ||||
|         conn = get_connection('t2') | ||||
|         self.assertFalse(get_tz_awareness(conn)) | ||||
|  | ||||
|     def test_connection_pool_via_kwarg(self): | ||||
|         """Ensure we can specify a max connection pool size using | ||||
|         a connection kwarg. | ||||
|         """ | ||||
|         # Use "max_pool_size" or "maxpoolsize" depending on PyMongo version | ||||
|         # (former was changed to the latter as described in | ||||
|         # https://jira.mongodb.org/browse/PYTHON-854). | ||||
|         # TODO remove once PyMongo < 3.0 support is dropped | ||||
|         if pymongo.version_tuple[0] >= 3: | ||||
|             pool_size_kwargs = {'maxpoolsize': 100} | ||||
|         else: | ||||
|             pool_size_kwargs = {'max_pool_size': 100} | ||||
|  | ||||
|         conn = connect('mongoenginetest', alias='max_pool_size_via_kwarg', **pool_size_kwargs) | ||||
|         self.assertEqual(conn.max_pool_size, 100) | ||||
|  | ||||
|     def test_connection_pool_via_uri(self): | ||||
|         """Ensure we can specify a max connection pool size using | ||||
|         an option in a connection URI. | ||||
|         """ | ||||
|         if pymongo.version_tuple[0] == 2 and pymongo.version_tuple[1] < 9: | ||||
|             raise SkipTest('maxpoolsize as a URI option is only supported in PyMongo v2.9+') | ||||
|  | ||||
|         conn = connect(host='mongodb://localhost/test?maxpoolsize=100', alias='max_pool_size_via_uri') | ||||
|         self.assertEqual(conn.max_pool_size, 100) | ||||
|  | ||||
|     def test_write_concern(self): | ||||
|         """Ensure write concern can be specified in connect() via | ||||
|         a kwarg or as part of the connection URI. | ||||
|         """ | ||||
|         conn1 = connect(alias='conn1', host='mongodb://localhost/testing?w=1&j=true') | ||||
|         conn2 = connect('testing', alias='conn2', w=1, j=True) | ||||
|         if IS_PYMONGO_3: | ||||
|             self.assertEqual(conn1.write_concern.document, {'w': 1, 'j': True}) | ||||
|             self.assertEqual(conn2.write_concern.document, {'w': 1, 'j': True}) | ||||
|         else: | ||||
|             self.assertEqual(dict(conn1.write_concern), {'w': 1, 'j': True}) | ||||
|             self.assertEqual(dict(conn2.write_concern), {'w': 1, 'j': True}) | ||||
|  | ||||
|     def test_connect_with_replicaset_via_uri(self): | ||||
|         """Ensure connect() works when specifying a replicaSet via the | ||||
|         MongoDB URI. | ||||
|         """ | ||||
|         if IS_PYMONGO_3: | ||||
|             c = connect(host='mongodb://localhost/test?replicaSet=local-rs') | ||||
|             db = get_db() | ||||
|             self.assertTrue(isinstance(db, pymongo.database.Database)) | ||||
|             self.assertEqual(db.name, 'test') | ||||
|         else: | ||||
|             # PyMongo < v3.x raises an exception: | ||||
|             # "localhost:27017 is not a member of replica set local-rs" | ||||
|             with self.assertRaises(MongoEngineConnectionError): | ||||
|                 c = connect(host='mongodb://localhost/test?replicaSet=local-rs') | ||||
|  | ||||
|     def test_connect_with_replicaset_via_kwargs(self): | ||||
|         """Ensure connect() works when specifying a replicaSet via the | ||||
|         connection kwargs | ||||
|         """ | ||||
|         if IS_PYMONGO_3: | ||||
|             c = connect(replicaset='local-rs') | ||||
|             self.assertEqual(c._MongoClient__options.replica_set_name, | ||||
|                              'local-rs') | ||||
|             db = get_db() | ||||
|             self.assertTrue(isinstance(db, pymongo.database.Database)) | ||||
|             self.assertEqual(db.name, 'test') | ||||
|         else: | ||||
|             # PyMongo < v3.x raises an exception: | ||||
|             # "localhost:27017 is not a member of replica set local-rs" | ||||
|             with self.assertRaises(MongoEngineConnectionError): | ||||
|                 c = connect(replicaset='local-rs') | ||||
|  | ||||
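The deleted connection tests exercised `connect()` kwargs that map one-to-one onto PyMongo client options. A hedged sketch of the URI/kwarg equivalence they covered for write concern (aliases are illustrative):

    from mongoengine import connect

    # As URI options...
    conn1 = connect(host='mongodb://localhost/testing?w=1&j=true', alias='c1')

    # ...or as keyword arguments; both should yield a {'w': 1, 'j': True}
    # write concern on the resulting client.
    conn2 = connect('testing', alias='c2', w=1, j=True)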
|     def test_datetime(self): | ||||
|         connect('mongoenginetest', tz_aware=True) | ||||
|         d = datetime.datetime(2010, 5, 5, tzinfo=utc) | ||||
|   | ||||
| @@ -1,3 +1,5 @@ | ||||
| import sys | ||||
| sys.path[0:0] = [""] | ||||
| import unittest | ||||
|  | ||||
| from mongoengine import * | ||||
| @@ -77,7 +79,7 @@ class ContextManagersTest(unittest.TestCase): | ||||
|         User.drop_collection() | ||||
|         Group.drop_collection() | ||||
|  | ||||
|         for i in range(1, 51): | ||||
|         for i in xrange(1, 51): | ||||
|             User(name='user %s' % i).save() | ||||
|  | ||||
|         user = User.objects.first() | ||||
| @@ -115,7 +117,7 @@ class ContextManagersTest(unittest.TestCase): | ||||
|         User.drop_collection() | ||||
|         Group.drop_collection() | ||||
|  | ||||
|         for i in range(1, 51): | ||||
|         for i in xrange(1, 51): | ||||
|             User(name='user %s' % i).save() | ||||
|  | ||||
|         user = User.objects.first() | ||||
| @@ -193,7 +195,7 @@ class ContextManagersTest(unittest.TestCase): | ||||
|         with query_counter() as q: | ||||
|             self.assertEqual(0, q) | ||||
|  | ||||
|             for i in range(1, 51): | ||||
|             for i in xrange(1, 51): | ||||
|                 db.test.find({}).count() | ||||
|  | ||||
|             self.assertEqual(50, q) | ||||
|   | ||||
| @@ -23,8 +23,7 @@ class TestStrictDict(unittest.TestCase): | ||||
|         self.assertEqual(repr(d), '{"a": \'"\', "b": "\'", "c": \'\'}') | ||||
|  | ||||
|     def test_init_fails_on_nonexisting_attrs(self): | ||||
|         with self.assertRaises(AttributeError): | ||||
|             self.dtype(a=1, b=2, d=3) | ||||
|         self.assertRaises(AttributeError, lambda: self.dtype(a=1, b=2, d=3)) | ||||
|  | ||||
|     def test_eq(self): | ||||
|         d = self.dtype(a=1, b=1, c=1) | ||||
| @@ -47,12 +46,14 @@ class TestStrictDict(unittest.TestCase): | ||||
|         d = self.dtype() | ||||
|         d.a = 1 | ||||
|         self.assertEqual(d.a, 1) | ||||
|         self.assertRaises(AttributeError, getattr, d, 'b') | ||||
|         self.assertRaises(AttributeError, lambda: d.b) | ||||
|  | ||||
|     def test_setattr_raises_on_nonexisting_attr(self): | ||||
|         d = self.dtype() | ||||
|         with self.assertRaises(AttributeError): | ||||
|  | ||||
|         def _f(): | ||||
|             d.x = 1 | ||||
|         self.assertRaises(AttributeError, _f) | ||||
|  | ||||
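The rewritten test needs a named helper because a lambda body cannot contain an assignment statement. Where a one-liner is wanted, routing the write through `setattr` works too (sketch, mirroring the test above):

    d = self.dtype()

    def _set():
        d.x = 1
    self.assertRaises(AttributeError, _set)

    # Equivalent one-liner: setattr is a plain call, so it fits assertRaises.
    self.assertRaises(AttributeError, setattr, d, 'x', 1)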
|     def test_setattr_getattr_special(self): | ||||
|         d = self.strict_dict_class(["items"]) | ||||
|   | ||||
| @@ -1,4 +1,6 @@ | ||||
| # -*- coding: utf-8 -*- | ||||
| import sys | ||||
| sys.path[0:0] = [""] | ||||
| import unittest | ||||
|  | ||||
| from bson import DBRef, ObjectId | ||||
| @@ -30,7 +32,7 @@ class FieldTest(unittest.TestCase): | ||||
|         User.drop_collection() | ||||
|         Group.drop_collection() | ||||
|  | ||||
|         for i in range(1, 51): | ||||
|         for i in xrange(1, 51): | ||||
|             user = User(name='user %s' % i) | ||||
|             user.save() | ||||
|  | ||||
| @@ -88,7 +90,7 @@ class FieldTest(unittest.TestCase): | ||||
|         User.drop_collection() | ||||
|         Group.drop_collection() | ||||
|  | ||||
|         for i in range(1, 51): | ||||
|         for i in xrange(1, 51): | ||||
|             user = User(name='user %s' % i) | ||||
|             user.save() | ||||
|  | ||||
| @@ -160,7 +162,7 @@ class FieldTest(unittest.TestCase): | ||||
|         User.drop_collection() | ||||
|         Group.drop_collection() | ||||
|  | ||||
|         for i in range(1, 26): | ||||
|         for i in xrange(1, 26): | ||||
|             user = User(name='user %s' % i) | ||||
|             user.save() | ||||
|  | ||||
| @@ -438,7 +440,7 @@ class FieldTest(unittest.TestCase): | ||||
|         Group.drop_collection() | ||||
|  | ||||
|         members = [] | ||||
|         for i in range(1, 51): | ||||
|         for i in xrange(1, 51): | ||||
|             a = UserA(name='User A %s' % i) | ||||
|             a.save() | ||||
|  | ||||
| @@ -529,7 +531,7 @@ class FieldTest(unittest.TestCase): | ||||
|         Group.drop_collection() | ||||
|  | ||||
|         members = [] | ||||
|         for i in range(1, 51): | ||||
|         for i in xrange(1, 51): | ||||
|             a = UserA(name='User A %s' % i) | ||||
|             a.save() | ||||
|  | ||||
| @@ -612,15 +614,15 @@ class FieldTest(unittest.TestCase): | ||||
|         Group.drop_collection() | ||||
|  | ||||
|         members = [] | ||||
|         for i in range(1, 51): | ||||
|         for i in xrange(1, 51): | ||||
|             user = User(name='user %s' % i) | ||||
|             user.save() | ||||
|             members.append(user) | ||||
|  | ||||
|         group = Group(members={str(u.id): u for u in members}) | ||||
|         group = Group(members=dict([(str(u.id), u) for u in members])) | ||||
|         group.save() | ||||
|  | ||||
|         group = Group(members={str(u.id): u for u in members}) | ||||
|         group = Group(members=dict([(str(u.id), u) for u in members])) | ||||
|         group.save() | ||||
|  | ||||
|         with query_counter() as q: | ||||
| @@ -685,7 +687,7 @@ class FieldTest(unittest.TestCase): | ||||
|         Group.drop_collection() | ||||
|  | ||||
|         members = [] | ||||
|         for i in range(1, 51): | ||||
|         for i in xrange(1, 51): | ||||
|             a = UserA(name='User A %s' % i) | ||||
|             a.save() | ||||
|  | ||||
| @@ -697,9 +699,9 @@ class FieldTest(unittest.TestCase): | ||||
|  | ||||
|             members += [a, b, c] | ||||
|  | ||||
|         group = Group(members={str(u.id): u for u in members}) | ||||
|         group = Group(members=dict([(str(u.id), u) for u in members])) | ||||
|         group.save() | ||||
|         group = Group(members={str(u.id): u for u in members}) | ||||
|         group = Group(members=dict([(str(u.id), u) for u in members])) | ||||
|         group.save() | ||||
|  | ||||
|         with query_counter() as q: | ||||
| @@ -781,16 +783,16 @@ class FieldTest(unittest.TestCase): | ||||
|         Group.drop_collection() | ||||
|  | ||||
|         members = [] | ||||
|         for i in range(1, 51): | ||||
|         for i in xrange(1, 51): | ||||
|             a = UserA(name='User A %s' % i) | ||||
|             a.save() | ||||
|  | ||||
|             members += [a] | ||||
|  | ||||
|         group = Group(members={str(u.id): u for u in members}) | ||||
|         group = Group(members=dict([(str(u.id), u) for u in members])) | ||||
|         group.save() | ||||
|  | ||||
|         group = Group(members={str(u.id): u for u in members}) | ||||
|         group = Group(members=dict([(str(u.id), u) for u in members])) | ||||
|         group.save() | ||||
|  | ||||
|         with query_counter() as q: | ||||
| @@ -864,7 +866,7 @@ class FieldTest(unittest.TestCase): | ||||
|         Group.drop_collection() | ||||
|  | ||||
|         members = [] | ||||
|         for i in range(1, 51): | ||||
|         for i in xrange(1, 51): | ||||
|             a = UserA(name='User A %s' % i) | ||||
|             a.save() | ||||
|  | ||||
| @@ -876,9 +878,9 @@ class FieldTest(unittest.TestCase): | ||||
|  | ||||
|             members += [a, b, c] | ||||
|  | ||||
|         group = Group(members={str(u.id): u for u in members}) | ||||
|         group = Group(members=dict([(str(u.id), u) for u in members])) | ||||
|         group.save() | ||||
|         group = Group(members={str(u.id): u for u in members}) | ||||
|         group = Group(members=dict([(str(u.id), u) for u in members])) | ||||
|         group.save() | ||||
|  | ||||
|         with query_counter() as q: | ||||
| @@ -1101,7 +1103,7 @@ class FieldTest(unittest.TestCase): | ||||
|         User.drop_collection() | ||||
|         Group.drop_collection() | ||||
|  | ||||
|         for i in range(1, 51): | ||||
|         for i in xrange(1, 51): | ||||
|             User(name='user %s' % i).save() | ||||
|  | ||||
|         Group(name="Test", members=User.objects).save() | ||||
| @@ -1130,7 +1132,7 @@ class FieldTest(unittest.TestCase): | ||||
|         User.drop_collection() | ||||
|         Group.drop_collection() | ||||
|  | ||||
|         for i in range(1, 51): | ||||
|         for i in xrange(1, 51): | ||||
|             User(name='user %s' % i).save() | ||||
|  | ||||
|         Group(name="Test", members=User.objects).save() | ||||
| @@ -1167,7 +1169,7 @@ class FieldTest(unittest.TestCase): | ||||
|         Group.drop_collection() | ||||
|  | ||||
|         members = [] | ||||
|         for i in range(1, 51): | ||||
|         for i in xrange(1, 51): | ||||
|             a = UserA(name='User A %s' % i).save() | ||||
|             b = UserB(name='User B %s' % i).save() | ||||
|             c = UserC(name='User C %s' % i).save() | ||||
|   | ||||
| @@ -1,3 +1,6 @@ | ||||
| import sys | ||||
|  | ||||
| sys.path[0:0] = [""] | ||||
| import unittest | ||||
|  | ||||
| from pymongo import ReadPreference | ||||
| @@ -15,7 +18,7 @@ else: | ||||
|  | ||||
| import mongoengine | ||||
| from mongoengine import * | ||||
| from mongoengine.connection import MongoEngineConnectionError | ||||
| from mongoengine.connection import ConnectionError | ||||
|  | ||||
|  | ||||
| class ConnectionTest(unittest.TestCase): | ||||
| @@ -38,7 +41,7 @@ class ConnectionTest(unittest.TestCase): | ||||
|             conn = connect(db='mongoenginetest', | ||||
|                            host="mongodb://localhost/mongoenginetest?replicaSet=rs", | ||||
|                            read_preference=READ_PREF) | ||||
|         except MongoEngineConnectionError as e: | ||||
|         except ConnectionError, e: | ||||
|             return | ||||
|  | ||||
|         if not isinstance(conn, CONN_CLASS): | ||||
|   | ||||
| @@ -1,4 +1,6 @@ | ||||
| # -*- coding: utf-8 -*- | ||||
| import sys | ||||
| sys.path[0:0] = [""] | ||||
| import unittest | ||||
|  | ||||
| from mongoengine import * | ||||
|   | ||||
| @@ -1,78 +0,0 @@ | ||||
| import unittest | ||||
|  | ||||
| from nose.plugins.skip import SkipTest | ||||
|  | ||||
| from mongoengine import connect | ||||
| from mongoengine.connection import get_db, get_connection | ||||
| from mongoengine.python_support import IS_PYMONGO_3 | ||||
|  | ||||
|  | ||||
| MONGO_TEST_DB = 'mongoenginetest' | ||||
|  | ||||
|  | ||||
| class MongoDBTestCase(unittest.TestCase): | ||||
|     """Base class for tests that need a mongodb connection | ||||
|     db is being dropped automatically | ||||
|     """ | ||||
|  | ||||
|     @classmethod | ||||
|     def setUpClass(cls): | ||||
|         cls._connection = connect(db=MONGO_TEST_DB) | ||||
|         cls._connection.drop_database(MONGO_TEST_DB) | ||||
|         cls.db = get_db() | ||||
|  | ||||
|     @classmethod | ||||
|     def tearDownClass(cls): | ||||
|         cls._connection.drop_database(MONGO_TEST_DB) | ||||
|  | ||||
|  | ||||
| def get_mongodb_version(): | ||||
|     """Return the version tuple of the MongoDB server that the default | ||||
|     connection is connected to. | ||||
|     """ | ||||
|     return tuple(get_connection().server_info()['versionArray']) | ||||
|  | ||||
| def _decorated_with_ver_requirement(func, ver_tuple): | ||||
|     """Return a given function decorated with the version requirement | ||||
|     for a particular MongoDB version tuple. | ||||
|     """ | ||||
|     def _inner(*args, **kwargs): | ||||
|         mongodb_ver = get_mongodb_version() | ||||
|         if mongodb_ver >= ver_tuple: | ||||
|             return func(*args, **kwargs) | ||||
|  | ||||
|         raise SkipTest('Needs MongoDB v{}+'.format( | ||||
|             '.'.join([str(v) for v in ver_tuple]) | ||||
|         )) | ||||
|  | ||||
|     _inner.__name__ = func.__name__ | ||||
|     _inner.__doc__ = func.__doc__ | ||||
|  | ||||
|     return _inner | ||||
|  | ||||
| def needs_mongodb_v26(func): | ||||
|     """Raise a SkipTest exception if we're working with MongoDB version | ||||
|     lower than v2.6. | ||||
|     """ | ||||
|     return _decorated_with_ver_requirement(func, (2, 6)) | ||||
|  | ||||
| def needs_mongodb_v3(func): | ||||
|     """Raise a SkipTest exception if we're working with MongoDB version | ||||
|     lower than v3.0. | ||||
|     """ | ||||
|     return _decorated_with_ver_requirement(func, (3, 0)) | ||||
|  | ||||
| def skip_pymongo3(f): | ||||
|     """Raise a SkipTest exception if we're running a test against | ||||
|     PyMongo v3.x. | ||||
|     """ | ||||
|     def _inner(*args, **kwargs): | ||||
|         if IS_PYMONGO_3: | ||||
|             raise SkipTest("Useless with PyMongo 3+") | ||||
|         return f(*args, **kwargs) | ||||
|  | ||||
|     _inner.__name__ = f.__name__ | ||||
|     _inner.__doc__ = f.__doc__ | ||||
|  | ||||
|     return _inner | ||||
|  | ||||
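The deleted tests/utils.py supplied the version-gated decorators that earlier hunks replace with `skip_older_mongodb`. The pattern is a small closure over the server version; a sketch of what the replacement presumably looks like (the 2.6 cut-off is an assumption based on the decorator it replaces):

    from nose.plugins.skip import SkipTest

    def skip_older_mongodb(func):
        def _inner(*args, **kwargs):
            # Assumes a get_mongodb_version() helper like the one removed above.
            if get_mongodb_version() < (2, 6):  # assumed cut-off
                raise SkipTest('Needs MongoDB v2.6+')
            return func(*args, **kwargs)
        _inner.__name__ = func.__name__
        _inner.__doc__ = func.__doc__
        return _inner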
							
								
								
									
13 tox.ini
							| @@ -1,5 +1,5 @@ | ||||
| [tox] | ||||
| envlist = {py27,py35,pypy,pypy3}-{mg27,mg28,mg30} | ||||
| envlist = {py26,py27,py33,py34,py35,pypy,pypy3}-{mg27,mg28},flake8 | ||||
|  | ||||
| [testenv] | ||||
| commands = | ||||
| @@ -7,7 +7,16 @@ commands = | ||||
| deps = | ||||
|     nose | ||||
|     mg27: PyMongo<2.8 | ||||
|     mg28: PyMongo>=2.8,<2.9 | ||||
|     mg28: PyMongo>=2.8,<3.0 | ||||
|     mg30: PyMongo>=3.0 | ||||
|     mgdev: https://github.com/mongodb/mongo-python-driver/tarball/master | ||||
| setenv = | ||||
|     PYTHON_EGG_CACHE = {envdir}/python-eggs | ||||
| passenv = windir | ||||
|  | ||||
| [testenv:flake8] | ||||
| deps = | ||||
|     flake8 | ||||
|     flake8-import-order | ||||
| commands = | ||||
|    flake8 | ||||
|   | ||||
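The envlist change widens the interpreter matrix back out to Python 2.6 through 3.5 and drops the mg30 factor from the default run, while the deps section keeps the mg30 and mgdev pins available. tox expands the brace groups into a cross product; an illustrative subset:

    envlist = {py26,py27}-{mg27,mg28},flake8
    # expands to: py26-mg27, py26-mg28, py27-mg27, py27-mg28, flake8
    # each mgXX factor then selects its PyMongo pin from the deps list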