Compare commits
	
		
			181 Commits
		
	
	
		
			v0.10.2
			...
			fix-iterat
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|  | 894678da39 | ||
|  | 0a66a4b8a9 | ||
|  | 15714ef855 | ||
|  | eb743beaa3 | ||
|  | 0007535a46 | ||
|  | 8391af026c | ||
|  | 800f656dcf | ||
|  | 088c5f49d9 | ||
|  | d8d98b6143 | ||
|  | 02fb3b9315 | ||
|  | 4f87db784e | ||
|  | 7e6287b925 | ||
|  | 999cdfd997 | ||
|  | 8d6cb087c6 | ||
|  | 2b7417c728 | ||
|  | 3c455cf1c1 | ||
|  | 5135185e31 | ||
|  | b461f26e5d | ||
|  | faef5b8570 | ||
|  | 0a20e04c10 | ||
|  | d19bb2308d | ||
|  | d8dd07d9ef | ||
|  | 36c56243cd | ||
|  | 23d06b79a6 | ||
|  | e4c4e923ee | ||
|  | 936d2f1f47 | ||
|  | 07018b5060 | ||
|  | ac90d6ae5c | ||
|  | 2141f2c4c5 | ||
|  | 81870777a9 | ||
|  | 845092dcad | ||
|  | dd473d1e1e | ||
|  | d2869bf4ed | ||
|  | 891a3f4b29 | ||
|  | 6767b50d75 | ||
|  | d9e4b562a9 | ||
|  | fb3243f1bc | ||
|  | 5fe1497c92 | ||
|  | 5446592d44 | ||
|  | 40ed9a53c9 | ||
|  | f7ac8cea90 | ||
|  | 4ef5d1f0cd | ||
|  | 6992615c98 | ||
|  | 43dabb2825 | ||
|  | 05e40e5681 | ||
|  | 2c4536e137 | ||
|  | 3dc81058a0 | ||
|  | bd84667a2b | ||
|  | e5b6a12977 | ||
|  | ca415d5d62 | ||
|  | 99b4fe7278 | ||
|  | 327e164869 | ||
|  | 25bc571f30 | ||
|  | 38c7e8a1d2 | ||
|  | ca282e28e0 | ||
|  | 5ef59c06df | ||
|  | 8f55d385d6 | ||
|  | cd2fc25c19 | ||
|  | 709983eea6 | ||
|  | 40e99b1b80 | ||
|  | 488684d960 | ||
|  | f35034b989 | ||
|  | 9d6f9b1f26 | ||
|  | 6148a608fb | ||
|  | 3fa9e70383 | ||
|  | 16fea6f009 | ||
|  | df9ed835ca | ||
|  | e394c8f0f2 | ||
|  | 21974f7288 | ||
|  | 5ef0170d77 | ||
|  | c21dcf14de | ||
|  | a8d20d4e1e | ||
|  | 8b307485b0 | ||
|  | 4544afe422 | ||
|  | 9d7eba5f70 | ||
|  | be0aee95f2 | ||
|  | 3469ed7ab9 | ||
|  | 1f223aa7e6 | ||
|  | 0a431ead5e | ||
|  | f750796444 | ||
|  | c82bcd882a | ||
|  | 7d0ec33b54 | ||
|  | 43d48b3feb | ||
|  | 2e406d2687 | ||
|  | 3f30808104 | ||
|  | ab10217c86 | ||
|  | 00430491ca | ||
|  | 109202329f | ||
|  | 3b1509f307 | ||
|  | 7ad7b08bed | ||
|  | 4650e5e8fb | ||
|  | af59d4929e | ||
|  | e34100bab4 | ||
|  | d9b3a9fb60 | ||
|  | 39eec59c90 | ||
|  | d651d0d472 | ||
|  | 87a2358a65 | ||
|  | cef4e313e1 | ||
|  | 7cc1a4eba0 | ||
|  | c6cc0133b3 | ||
|  | 7748e68440 | ||
|  | 6c2230a076 | ||
|  | 66b233eaea | ||
|  | fed58f3920 | ||
|  | 815b2be7f7 | ||
|  | f420c9fb7c | ||
|  | 01bdf10b94 | ||
|  | ddedc1ee92 | ||
|  | 9e9703183f | ||
|  | adce9e6220 | ||
|  | c499133bbe | ||
|  | 8f505c2dcc | ||
|  | b320064418 | ||
|  | a643933d16 | ||
|  | 2659ec5887 | ||
|  | 9f8327926d | ||
|  | 7a568dc118 | ||
|  | c946b06be5 | ||
|  | c65fd0e477 | ||
|  | 8f8217e928 | ||
|  | 6c9e1799c7 | ||
|  | decd70eb23 | ||
|  | a20d40618f | ||
|  | b4af8ec751 | ||
|  | feb5eed8a5 | ||
|  | f4fa39c70e | ||
|  | 7b7165f5d8 | ||
|  | 13897db6d3 | ||
|  | c4afdb7198 | ||
|  | 0284975f3f | ||
|  | 269e3d1303 | ||
|  | 8c81f7ece9 | ||
|  | f6e0593774 | ||
|  | 3d80e549cb | ||
|  | acc7448dc5 | ||
|  | 35d3d3de72 | ||
|  | 0372e07eb0 | ||
|  | 00221e3410 | ||
|  | 9c264611cf | ||
|  | 31d7f70e27 | ||
|  | 04e8b83d45 | ||
|  | e87bf71f20 | ||
|  | 2dd70c8d62 | ||
|  | a3886702a3 | ||
|  | 713af133a0 | ||
|  | 057ffffbf2 | ||
|  | a81d6d124b | ||
|  | 23f07fde5e | ||
|  | b42b760393 | ||
|  | bf6f4c48c0 | ||
|  | 6133f04841 | ||
|  | 3c18f79ea4 | ||
|  | 2af8342fea | ||
|  | fc3db7942d | ||
|  | 164e2b2678 | ||
|  | b7b28390df | ||
|  | a6e996d921 | ||
|  | 07e666345d | ||
|  | 007f10d29d | ||
|  | f9284d20ca | ||
|  | 9050869781 | ||
|  | 54975de0f3 | ||
|  | a7aead5138 | ||
|  | 6868f66f24 | ||
|  | 04497aec36 | ||
|  | aa9d596930 | ||
|  | f96e68cd11 | ||
|  | 013227323d | ||
|  | 0a1ba7c434 | ||
|  | cceef33fef | ||
|  | ed8174fe36 | ||
|  | 3c8906494f | ||
|  | 6e745e9882 | ||
|  | fb4e9c3772 | ||
|  | 8e7c5af16c | ||
|  | c1645ab7a7 | ||
|  | 2ae2bfdde9 | ||
|  | 3fe93968a6 | ||
|  | eb8176971c | ||
|  | 5bbfca45fa | ||
|  | 11024deaae | 
							
								
								
									
										21
									
								
								.travis.yml
									
									
									
									
									
								
							
							
						
						
									
										21
									
								
								.travis.yml
									
									
									
									
									
								
							| @@ -1,41 +1,58 @@ | |||||||
| language: python | language: python | ||||||
|  |  | ||||||
| python: | python: | ||||||
| - '2.6' | - '2.6' | ||||||
| - '2.7' | - '2.7' | ||||||
| - '3.2' |  | ||||||
| - '3.3' | - '3.3' | ||||||
| - '3.4' | - '3.4' | ||||||
| - '3.5' | - '3.5' | ||||||
| - pypy | - pypy | ||||||
| - pypy3 | - pypy3 | ||||||
|  |  | ||||||
| env: | env: | ||||||
| - PYMONGO=2.7 | - PYMONGO=2.7 | ||||||
| - PYMONGO=2.8 | - PYMONGO=2.8 | ||||||
| - PYMONGO=3.0 | - PYMONGO=3.0 | ||||||
| - PYMONGO=dev | - PYMONGO=dev | ||||||
|  |  | ||||||
| matrix: | matrix: | ||||||
|   fast_finish: true |   fast_finish: true | ||||||
|  |  | ||||||
| before_install: | before_install: | ||||||
| - travis_retry sudo apt-key adv --keyserver hkp://keyserver.ubuntu.com:80 --recv 7F0CEB10 | - travis_retry sudo apt-key adv --keyserver hkp://keyserver.ubuntu.com:80 --recv 7F0CEB10 | ||||||
| - echo 'deb http://downloads-distro.mongodb.org/repo/ubuntu-upstart dist 10gen' | | - echo 'deb http://downloads-distro.mongodb.org/repo/ubuntu-upstart dist 10gen' | | ||||||
|   sudo tee /etc/apt/sources.list.d/mongodb.list |   sudo tee /etc/apt/sources.list.d/mongodb.list | ||||||
| - travis_retry sudo apt-get update | - travis_retry sudo apt-get update | ||||||
| - travis_retry sudo apt-get install mongodb-org-server | - travis_retry sudo apt-get install mongodb-org-server | ||||||
|  |  | ||||||
| install: | install: | ||||||
| - sudo apt-get install python-dev python3-dev libopenjpeg-dev zlib1g-dev libjpeg-turbo8-dev | - sudo apt-get install python-dev python3-dev libopenjpeg-dev zlib1g-dev libjpeg-turbo8-dev | ||||||
|   libtiff4-dev libjpeg8-dev libfreetype6-dev liblcms2-dev libwebp-dev tcl8.5-dev tk8.5-dev |   libtiff4-dev libjpeg8-dev libfreetype6-dev liblcms2-dev libwebp-dev tcl8.5-dev tk8.5-dev | ||||||
|   python-tk |   python-tk | ||||||
| - travis_retry pip install tox>=1.9 coveralls | - travis_retry pip install --upgrade pip | ||||||
|  | - travis_retry pip install coveralls | ||||||
|  | - travis_retry pip install flake8 | ||||||
|  | - travis_retry pip install tox>=1.9 | ||||||
|  | - travis_retry pip install "virtualenv<14.0.0"  # virtualenv>=14.0.0 has dropped Python 3.2 support (and pypy3 is based on py32) | ||||||
| - travis_retry tox -e $(echo py$TRAVIS_PYTHON_VERSION-mg$PYMONGO | tr -d . | sed -e 's/pypypy/pypy/') -- -e test | - travis_retry tox -e $(echo py$TRAVIS_PYTHON_VERSION-mg$PYMONGO | tr -d . | sed -e 's/pypypy/pypy/') -- -e test | ||||||
|  |  | ||||||
|  | # Run flake8 for py27 | ||||||
|  | before_script: | ||||||
|  | - if [[ $TRAVIS_PYTHON_VERSION == '2.7' ]]; then tox -e flake8; fi | ||||||
|  |  | ||||||
| script: | script: | ||||||
| - tox -e $(echo py$TRAVIS_PYTHON_VERSION-mg$PYMONGO | tr -d . | sed -e 's/pypypy/pypy/') -- --with-coverage | - tox -e $(echo py$TRAVIS_PYTHON_VERSION-mg$PYMONGO | tr -d . | sed -e 's/pypypy/pypy/') -- --with-coverage | ||||||
|  |  | ||||||
| after_script: coveralls --verbose | after_script: coveralls --verbose | ||||||
|  |  | ||||||
| notifications: | notifications: | ||||||
|   irc: irc.freenode.org#mongoengine |   irc: irc.freenode.org#mongoengine | ||||||
|  |  | ||||||
| branches: | branches: | ||||||
|   only: |   only: | ||||||
|   - master |   - master | ||||||
|   - /^v.*$/ |   - /^v.*$/ | ||||||
|  |  | ||||||
| deploy: | deploy: | ||||||
|   provider: pypi |   provider: pypi | ||||||
|   user: the_drow |   user: the_drow | ||||||
|   | |||||||
							
								
								
									
										12
									
								
								AUTHORS
									
									
									
									
									
								
							
							
						
						
									
										12
									
								
								AUTHORS
									
									
									
									
									
								
							| @@ -228,5 +228,17 @@ that much better: | |||||||
|  * Vicki Donchenko (https://github.com/kivistein) |  * Vicki Donchenko (https://github.com/kivistein) | ||||||
|  * Emile Caron (https://github.com/emilecaron) |  * Emile Caron (https://github.com/emilecaron) | ||||||
|  * Amit Lichtenberg (https://github.com/amitlicht) |  * Amit Lichtenberg (https://github.com/amitlicht) | ||||||
|  |  * Gang Li (https://github.com/iici-gli) | ||||||
|  * Lars Butler (https://github.com/larsbutler) |  * Lars Butler (https://github.com/larsbutler) | ||||||
|  * George Macon (https://github.com/gmacon) |  * George Macon (https://github.com/gmacon) | ||||||
|  |  * Ashley Whetter (https://github.com/AWhetter) | ||||||
|  |  * Paul-Armand Verhaegen (https://github.com/paularmand) | ||||||
|  |  * Steven Rossiter (https://github.com/BeardedSteve) | ||||||
|  |  * Luo Peng (https://github.com/RussellLuo) | ||||||
|  |  * Bryan Bennett (https://github.com/bbenne10) | ||||||
|  |  * Gilb's Gilb's (https://github.com/gilbsgilbs) | ||||||
|  |  * Joshua Nedrud (https://github.com/Neurostack) | ||||||
|  |  * Shu Shen (https://github.com/shushen) | ||||||
|  |  * xiaost7 (https://github.com/xiaost7) | ||||||
|  |  * Victor Varvaryuk | ||||||
|  |  * Stanislav Kaledin (https://github.com/sallyruthstruik) | ||||||
|   | |||||||
							
								
								
									
										28
									
								
								README.rst
									
									
									
									
									
								
							
							
						
						
									
										28
									
								
								README.rst
									
									
									
									
									
								
							| @@ -6,23 +6,23 @@ MongoEngine | |||||||
| :Author: Harry Marr (http://github.com/hmarr) | :Author: Harry Marr (http://github.com/hmarr) | ||||||
| :Maintainer: Ross Lawley (http://github.com/rozza) | :Maintainer: Ross Lawley (http://github.com/rozza) | ||||||
|  |  | ||||||
| .. image:: https://secure.travis-ci.org/MongoEngine/mongoengine.png?branch=master | .. image:: https://travis-ci.org/MongoEngine/mongoengine.svg?branch=master | ||||||
|   :target: http://travis-ci.org/MongoEngine/mongoengine |   :target: https://travis-ci.org/MongoEngine/mongoengine | ||||||
|  |  | ||||||
| .. image:: https://coveralls.io/repos/MongoEngine/mongoengine/badge.png?branch=master | .. image:: https://coveralls.io/repos/github/MongoEngine/mongoengine/badge.svg?branch=master | ||||||
|   :target: https://coveralls.io/r/MongoEngine/mongoengine?branch=master |   :target: https://coveralls.io/github/MongoEngine/mongoengine?branch=master | ||||||
|  |  | ||||||
| .. image:: https://landscape.io/github/MongoEngine/mongoengine/master/landscape.png | .. image:: https://landscape.io/github/MongoEngine/mongoengine/master/landscape.svg?style=flat | ||||||
|   :target: https://landscape.io/github/MongoEngine/mongoengine/master |   :target: https://landscape.io/github/MongoEngine/mongoengine/master | ||||||
|   :alt: Code Health |   :alt: Code Health | ||||||
|  |  | ||||||
| About | About | ||||||
| ===== | ===== | ||||||
| MongoEngine is a Python Object-Document Mapper for working with MongoDB. | MongoEngine is a Python Object-Document Mapper for working with MongoDB. | ||||||
| Documentation available at http://mongoengine-odm.rtfd.org - there is currently | Documentation available at https://mongoengine-odm.readthedocs.io - there is currently | ||||||
| a `tutorial <http://readthedocs.org/docs/mongoengine-odm/en/latest/tutorial.html>`_, a `user guide | a `tutorial <https://mongoengine-odm.readthedocs.io/tutorial.html>`_, a `user guide | ||||||
| <https://mongoengine-odm.readthedocs.org/en/latest/guide/index.html>`_ and an `API reference | <https://mongoengine-odm.readthedocs.io/guide/index.html>`_ and an `API reference | ||||||
| <http://readthedocs.org/docs/mongoengine-odm/en/latest/apireference.html>`_. | <https://mongoengine-odm.readthedocs.io/apireference.html>`_. | ||||||
|  |  | ||||||
| Installation | Installation | ||||||
| ============ | ============ | ||||||
| @@ -48,12 +48,18 @@ Optional Dependencies | |||||||
|  |  | ||||||
| Examples | Examples | ||||||
| ======== | ======== | ||||||
| Some simple examples of what MongoEngine code looks like:: | Some simple examples of what MongoEngine code looks like: | ||||||
|  |  | ||||||
|  | .. code :: python | ||||||
|  |  | ||||||
|  |     from mongoengine import * | ||||||
|  |     connect('mydb') | ||||||
|  |  | ||||||
|     class BlogPost(Document): |     class BlogPost(Document): | ||||||
|         title = StringField(required=True, max_length=200) |         title = StringField(required=True, max_length=200) | ||||||
|         posted = DateTimeField(default=datetime.datetime.now) |         posted = DateTimeField(default=datetime.datetime.now) | ||||||
|         tags = ListField(StringField(max_length=50)) |         tags = ListField(StringField(max_length=50)) | ||||||
|  |         meta = {'allow_inheritance': True} | ||||||
|  |  | ||||||
|     class TextPost(BlogPost): |     class TextPost(BlogPost): | ||||||
|         content = StringField(required=True) |         content = StringField(required=True) | ||||||
| @@ -97,7 +103,7 @@ Some simple examples of what MongoEngine code looks like:: | |||||||
| Tests | Tests | ||||||
| ===== | ===== | ||||||
| To run the test suite, ensure you are running a local instance of MongoDB on | To run the test suite, ensure you are running a local instance of MongoDB on | ||||||
| the standard port, and run: ``python setup.py nosetests``. | the standard port and have ``nose`` installed. Then, run: ``python setup.py nosetests``. | ||||||
|  |  | ||||||
| To run the test suite on every supported Python version and every supported PyMongo version, | To run the test suite on every supported Python version and every supported PyMongo version, | ||||||
| you can use ``tox``. | you can use ``tox``. | ||||||
|   | |||||||
| @@ -2,19 +2,71 @@ | |||||||
| Changelog | Changelog | ||||||
| ========= | ========= | ||||||
|  |  | ||||||
|  | Changes in 0.10.8 | ||||||
|  | ================= | ||||||
|  | - Added ability to specify an authentication mechanism (e.g. X.509) #1333 | ||||||
|  | - Added support for falsey primary keys (e.g. doc.pk = 0) #1354 | ||||||
|  | - Fixed BaseQuerySet#sum/average for fields w/ explicit db_field #1417 | ||||||
|  |  | ||||||
|  | Changes in 0.10.7 | ||||||
|  | ================= | ||||||
|  | - Dropped Python 3.2 support #1390 | ||||||
|  | - Fixed the bug where dynamic doc has index inside a dict field #1278 | ||||||
|  | - Fixed: ListField minus index assignment does not work #1128 | ||||||
|  | - Fixed cascade delete mixing among collections #1224 | ||||||
|  | - Add `signal_kwargs` argument to `Document.save`, `Document.delete` and `BaseQuerySet.insert` to be passed to signals calls #1206 | ||||||
|  | - Raise `OperationError` when trying to do a `drop_collection` on document with no collection set. | ||||||
|  | - count on ListField of EmbeddedDocumentField fails. #1187 | ||||||
|  | - Fixed long fields stored as int32 in Python 3. #1253 | ||||||
|  | - MapField now handles unicodes keys correctly. #1267 | ||||||
|  | - ListField now handles negative indicies correctly. #1270 | ||||||
|  | - Fixed AttributeError when initializing EmbeddedDocument with positional args. #681 | ||||||
|  | - Fixed no_cursor_timeout error with pymongo 3.0+ #1304 | ||||||
|  | - Replaced map-reduce based QuerySet.sum/average with aggregation-based implementations #1336 | ||||||
|  | - Fixed support for `__` to escape field names that match operators names in `update` #1351 | ||||||
|  | - Fixed BaseDocument#_mark_as_changed #1369 | ||||||
|  | - Added support for pickling QuerySet instances. #1397 | ||||||
|  | - Fixed connecting to a list of hosts #1389 | ||||||
|  | - Fixed a bug where accessing broken references wouldn't raise a DoesNotExist error #1334 | ||||||
|  | - Fixed not being able to specify use_db_field=False on ListField(EmbeddedDocumentField) instances #1218 | ||||||
|  | - Improvements to the dictionary fields docs #1383 | ||||||
|  |  | ||||||
|  | Changes in 0.10.6 | ||||||
|  | ================= | ||||||
|  | - Add support for mocking MongoEngine based on mongomock. #1151 | ||||||
|  | - Fixed not being able to run tests on Windows. #1153 | ||||||
|  | - Allow creation of sparse compound indexes. #1114 | ||||||
|  | - count on ListField of EmbeddedDocumentField fails. #1187 | ||||||
|  |  | ||||||
|  | Changes in 0.10.5 | ||||||
|  | ================= | ||||||
|  | - Fix for reloading of strict with special fields. #1156 | ||||||
|  |  | ||||||
|  | Changes in 0.10.4 | ||||||
|  | ================= | ||||||
|  | - SaveConditionError is now importable from the top level package. #1165 | ||||||
|  | - upsert_one method added. #1157 | ||||||
|  |  | ||||||
|  | Changes in 0.10.3 | ||||||
|  | ================= | ||||||
|  | - Fix `read_preference` (it had chaining issues with PyMongo 2.x and it didn't work at all with PyMongo 3.x) #1042 | ||||||
|  |  | ||||||
| Changes in 0.10.2 | Changes in 0.10.2 | ||||||
| ================= | ================= | ||||||
| - Allow shard key to point to a field in an embedded document. #551 | - Allow shard key to point to a field in an embedded document. #551 | ||||||
| - Allow arbirary metadata in fields. #1129 | - Allow arbirary metadata in fields. #1129 | ||||||
|  | - ReferenceFields now support abstract document types. #837 | ||||||
|  |  | ||||||
| Changes in 0.10.1 | Changes in 0.10.1 | ||||||
| ======================= | ================= | ||||||
| - Fix infinite recursion with CASCADE delete rules under specific conditions. #1046 | - Fix infinite recursion with CASCADE delete rules under specific conditions. #1046 | ||||||
| - Fix CachedReferenceField bug when loading cached docs as DBRef but failing to save them. #1047 | - Fix CachedReferenceField bug when loading cached docs as DBRef but failing to save them. #1047 | ||||||
| - Fix ignored chained options #842 | - Fix ignored chained options #842 | ||||||
| - Document save's save_condition error raises `SaveConditionError` exception #1070 | - Document save's save_condition error raises `SaveConditionError` exception #1070 | ||||||
| - Fix Document.reload for DynamicDocument. #1050 | - Fix Document.reload for DynamicDocument. #1050 | ||||||
| - StrictDict & SemiStrictDict are shadowed at init time. #1105 | - StrictDict & SemiStrictDict are shadowed at init time. #1105 | ||||||
|  | - Fix ListField minus index assignment does not work. #1119 | ||||||
|  | - Remove code that marks field as changed when the field has default but not existed in database #1126 | ||||||
| - Remove test dependencies (nose and rednose) from install dependencies list. #1079 | - Remove test dependencies (nose and rednose) from install dependencies list. #1079 | ||||||
| - Recursively build query when using elemMatch operator. #1130 | - Recursively build query when using elemMatch operator. #1130 | ||||||
| - Fix instance back references for lists of embedded documents. #1131 | - Fix instance back references for lists of embedded documents. #1131 | ||||||
|   | |||||||
| @@ -17,6 +17,10 @@ class Post(Document): | |||||||
|     tags = ListField(StringField(max_length=30)) |     tags = ListField(StringField(max_length=30)) | ||||||
|     comments = ListField(EmbeddedDocumentField(Comment)) |     comments = ListField(EmbeddedDocumentField(Comment)) | ||||||
|  |  | ||||||
|  |     # bugfix | ||||||
|  |     meta = {'allow_inheritance': True} | ||||||
|  |  | ||||||
|  |  | ||||||
| class TextPost(Post): | class TextPost(Post): | ||||||
|     content = StringField() |     content = StringField() | ||||||
|  |  | ||||||
| @@ -45,7 +49,8 @@ print 'ALL POSTS' | |||||||
| print | print | ||||||
| for post in Post.objects: | for post in Post.objects: | ||||||
|     print post.title |     print post.title | ||||||
|     print '=' * post.title.count() |     #print '=' * post.title.count() | ||||||
|  |     print "=" * 20 | ||||||
|  |  | ||||||
|     if isinstance(post, TextPost): |     if isinstance(post, TextPost): | ||||||
|         print post.content |         print post.content | ||||||
|   | |||||||
| @@ -29,7 +29,7 @@ documents are serialized based on their field order. | |||||||
|  |  | ||||||
| Dynamic document schemas | Dynamic document schemas | ||||||
| ======================== | ======================== | ||||||
| One of the benefits of MongoDb is dynamic schemas for a collection, whilst data | One of the benefits of MongoDB is dynamic schemas for a collection, whilst data | ||||||
| should be planned and organised (after all explicit is better than implicit!) | should be planned and organised (after all explicit is better than implicit!) | ||||||
| there are scenarios where having dynamic / expando style documents is desirable. | there are scenarios where having dynamic / expando style documents is desirable. | ||||||
|  |  | ||||||
| @@ -75,6 +75,7 @@ are as follows: | |||||||
| * :class:`~mongoengine.fields.DynamicField` | * :class:`~mongoengine.fields.DynamicField` | ||||||
| * :class:`~mongoengine.fields.EmailField` | * :class:`~mongoengine.fields.EmailField` | ||||||
| * :class:`~mongoengine.fields.EmbeddedDocumentField` | * :class:`~mongoengine.fields.EmbeddedDocumentField` | ||||||
|  | * :class:`~mongoengine.fields.EmbeddedDocumentListField` | ||||||
| * :class:`~mongoengine.fields.FileField` | * :class:`~mongoengine.fields.FileField` | ||||||
| * :class:`~mongoengine.fields.FloatField` | * :class:`~mongoengine.fields.FloatField` | ||||||
| * :class:`~mongoengine.fields.GenericEmbeddedDocumentField` | * :class:`~mongoengine.fields.GenericEmbeddedDocumentField` | ||||||
| @@ -213,9 +214,9 @@ document class as the first argument:: | |||||||
|  |  | ||||||
| Dictionary Fields | Dictionary Fields | ||||||
| ----------------- | ----------------- | ||||||
| Often, an embedded document may be used instead of a dictionary -- generally | Often, an embedded document may be used instead of a dictionary – generally  | ||||||
| this is recommended as dictionaries don't support validation or custom field | embedded documents are recommended as dictionaries don’t support validation  | ||||||
| types. However, sometimes you will not know the structure of what you want to | or custom field types. However, sometimes you will not know the structure of what you want to | ||||||
| store; in this situation a :class:`~mongoengine.fields.DictField` is appropriate:: | store; in this situation a :class:`~mongoengine.fields.DictField` is appropriate:: | ||||||
|  |  | ||||||
|     class SurveyResponse(Document): |     class SurveyResponse(Document): | ||||||
|   | |||||||
| @@ -13,3 +13,4 @@ User Guide | |||||||
|    gridfs |    gridfs | ||||||
|    signals |    signals | ||||||
|    text-indexes |    text-indexes | ||||||
|  |    mongomock | ||||||
|   | |||||||
							
								
								
									
										21
									
								
								docs/guide/mongomock.rst
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								docs/guide/mongomock.rst
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,21 @@ | |||||||
|  | ============================== | ||||||
|  | Use mongomock for testing | ||||||
|  | ============================== | ||||||
|  |  | ||||||
|  | `mongomock <https://github.com/vmalloc/mongomock/>`_ is a package to do just  | ||||||
|  | what the name implies, mocking a mongo database. | ||||||
|  |  | ||||||
|  | To use with mongoengine, simply specify mongomock when connecting with  | ||||||
|  | mongoengine: | ||||||
|  |  | ||||||
|  | .. code-block:: python | ||||||
|  |  | ||||||
|  |     connect('mongoenginetest', host='mongomock://localhost') | ||||||
|  |     conn = get_connection() | ||||||
|  |  | ||||||
|  | or with an alias: | ||||||
|  |  | ||||||
|  | .. code-block:: python | ||||||
|  |  | ||||||
|  |     connect('mongoenginetest', host='mongomock://localhost', alias='testdb') | ||||||
|  |     conn = get_connection('testdb') | ||||||
| @@ -237,7 +237,7 @@ is preferred for achieving this:: | |||||||
|     # All except for the first 5 people |     # All except for the first 5 people | ||||||
|     users = User.objects[5:] |     users = User.objects[5:] | ||||||
|  |  | ||||||
|     # 5 users, starting from the 10th user found |     # 5 users, starting from the 11th user found | ||||||
|     users = User.objects[10:15] |     users = User.objects[10:15] | ||||||
|  |  | ||||||
| You may also index the query to retrieve a single result. If an item at that | You may also index the query to retrieve a single result. If an item at that | ||||||
|   | |||||||
| @@ -1,20 +1,20 @@ | |||||||
| import document |  | ||||||
| from document import * |  | ||||||
| import fields |  | ||||||
| from fields import * |  | ||||||
| import connection | import connection | ||||||
| from connection import * | from connection import * | ||||||
|  | import document | ||||||
|  | from document import * | ||||||
|  | import errors | ||||||
|  | from errors import * | ||||||
|  | import fields | ||||||
|  | from fields import * | ||||||
| import queryset | import queryset | ||||||
| from queryset import * | from queryset import * | ||||||
| import signals | import signals | ||||||
| from signals import * | from signals import * | ||||||
| from errors import * |  | ||||||
| import errors |  | ||||||
|  |  | ||||||
| __all__ = (list(document.__all__) + fields.__all__ + connection.__all__ + | __all__ = (list(document.__all__) + fields.__all__ + connection.__all__ + | ||||||
|            list(queryset.__all__) + signals.__all__ + list(errors.__all__)) |            list(queryset.__all__) + signals.__all__ + list(errors.__all__)) | ||||||
|  |  | ||||||
| VERSION = (0, 10, 1) | VERSION = (0, 10, 7) | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_version(): | def get_version(): | ||||||
| @@ -22,4 +22,5 @@ def get_version(): | |||||||
|         return '.'.join(map(str, VERSION[:-1])) + VERSION[-1] |         return '.'.join(map(str, VERSION[:-1])) + VERSION[-1] | ||||||
|     return '.'.join(map(str, VERSION)) |     return '.'.join(map(str, VERSION)) | ||||||
|  |  | ||||||
|  |  | ||||||
| __version__ = get_version() | __version__ = get_version() | ||||||
|   | |||||||
| @@ -1,5 +1,5 @@ | |||||||
| import weakref |  | ||||||
| import itertools | import itertools | ||||||
|  | import weakref | ||||||
|  |  | ||||||
| from mongoengine.common import _import_class | from mongoengine.common import _import_class | ||||||
| from mongoengine.errors import DoesNotExist, MultipleObjectsReturned | from mongoengine.errors import DoesNotExist, MultipleObjectsReturned | ||||||
| @@ -199,7 +199,9 @@ class BaseList(list): | |||||||
|     def _mark_as_changed(self, key=None): |     def _mark_as_changed(self, key=None): | ||||||
|         if hasattr(self._instance, '_mark_as_changed'): |         if hasattr(self._instance, '_mark_as_changed'): | ||||||
|             if key: |             if key: | ||||||
|                 self._instance._mark_as_changed('%s.%s' % (self._name, key)) |                 self._instance._mark_as_changed( | ||||||
|  |                     '%s.%s' % (self._name, key % len(self)) | ||||||
|  |                 ) | ||||||
|             else: |             else: | ||||||
|                 self._instance._mark_as_changed(self._name) |                 self._instance._mark_as_changed(self._name) | ||||||
|  |  | ||||||
| @@ -210,7 +212,7 @@ class EmbeddedDocumentList(BaseList): | |||||||
|     def __match_all(cls, i, kwargs): |     def __match_all(cls, i, kwargs): | ||||||
|         items = kwargs.items() |         items = kwargs.items() | ||||||
|         return all([ |         return all([ | ||||||
|             getattr(i, k) == v or str(getattr(i, k)) == v for k, v in items |             getattr(i, k) == v or unicode(getattr(i, k)) == v for k, v in items | ||||||
|         ]) |         ]) | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
| @@ -436,7 +438,7 @@ class StrictDict(object): | |||||||
|                 __slots__ = allowed_keys_tuple |                 __slots__ = allowed_keys_tuple | ||||||
|  |  | ||||||
|                 def __repr__(self): |                 def __repr__(self): | ||||||
|                     return "{%s}" % ', '.join('"{0!s}": {0!r}'.format(k) for k in self.iterkeys()) |                     return "{%s}" % ', '.join('"{0!s}": {1!r}'.format(k, v) for k, v in self.items()) | ||||||
|  |  | ||||||
|             cls._classes[allowed_keys] = SpecificStrictDict |             cls._classes[allowed_keys] = SpecificStrictDict | ||||||
|         return cls._classes[allowed_keys] |         return cls._classes[allowed_keys] | ||||||
|   | |||||||
| @@ -1,28 +1,28 @@ | |||||||
| import copy | import copy | ||||||
| import operator |  | ||||||
| import numbers | import numbers | ||||||
|  | import operator | ||||||
| from collections import Hashable | from collections import Hashable | ||||||
| from functools import partial | from functools import partial | ||||||
|  |  | ||||||
| import pymongo | from bson import ObjectId, json_util | ||||||
| from bson import json_util, ObjectId |  | ||||||
| from bson.dbref import DBRef | from bson.dbref import DBRef | ||||||
| from bson.son import SON | from bson.son import SON | ||||||
|  | import pymongo | ||||||
|  |  | ||||||
| from mongoengine import signals | from mongoengine import signals | ||||||
| from mongoengine.common import _import_class | from mongoengine.base.common import ALLOW_INHERITANCE, get_document | ||||||
| from mongoengine.errors import (ValidationError, InvalidDocumentError, |  | ||||||
|                                 LookUpError, FieldDoesNotExist) |  | ||||||
| from mongoengine.python_support import PY3, txt_type |  | ||||||
| from mongoengine.base.common import get_document, ALLOW_INHERITANCE |  | ||||||
| from mongoengine.base.datastructures import ( | from mongoengine.base.datastructures import ( | ||||||
|     BaseDict, |     BaseDict, | ||||||
|     BaseList, |     BaseList, | ||||||
|     EmbeddedDocumentList, |     EmbeddedDocumentList, | ||||||
|     StrictDict, |     SemiStrictDict, | ||||||
|     SemiStrictDict |     StrictDict | ||||||
| ) | ) | ||||||
| from mongoengine.base.fields import ComplexBaseField | from mongoengine.base.fields import ComplexBaseField | ||||||
|  | from mongoengine.common import _import_class | ||||||
|  | from mongoengine.errors import (FieldDoesNotExist, InvalidDocumentError, | ||||||
|  |                                 LookUpError, ValidationError) | ||||||
|  | from mongoengine.python_support import PY3, txt_type | ||||||
|  |  | ||||||
| __all__ = ('BaseDocument', 'NON_FIELD_ERRORS') | __all__ = ('BaseDocument', 'NON_FIELD_ERRORS') | ||||||
|  |  | ||||||
| @@ -51,7 +51,7 @@ class BaseDocument(object): | |||||||
|             # We only want named arguments. |             # We only want named arguments. | ||||||
|             field = iter(self._fields_ordered) |             field = iter(self._fields_ordered) | ||||||
|             # If its an automatic id field then skip to the first defined field |             # If its an automatic id field then skip to the first defined field | ||||||
|             if self._auto_id_field: |             if getattr(self, '_auto_id_field', False): | ||||||
|                 next(field) |                 next(field) | ||||||
|             for value in args: |             for value in args: | ||||||
|                 name = next(field) |                 name = next(field) | ||||||
| @@ -72,11 +72,12 @@ class BaseDocument(object): | |||||||
|         # Check if there are undefined fields supplied to the constructor, |         # Check if there are undefined fields supplied to the constructor, | ||||||
|         # if so raise an Exception. |         # if so raise an Exception. | ||||||
|         if not self._dynamic and (self._meta.get('strict', True) or _created): |         if not self._dynamic and (self._meta.get('strict', True) or _created): | ||||||
|             for var in values.keys(): |             _undefined_fields = set(values.keys()) - set( | ||||||
|                 if var not in self._fields.keys() + ['id', 'pk', '_cls', '_text_score']: |                 self._fields.keys() + ['id', 'pk', '_cls', '_text_score']) | ||||||
|  |             if _undefined_fields: | ||||||
|                 msg = ( |                 msg = ( | ||||||
|                         "The field '{0}' does not exist on the document '{1}'" |                     "The fields '{0}' do not exist on the document '{1}'" | ||||||
|                     ).format(var, self._class_name) |                 ).format(_undefined_fields, self._class_name) | ||||||
|                 raise FieldDoesNotExist(msg) |                 raise FieldDoesNotExist(msg) | ||||||
|  |  | ||||||
|         if self.STRICT and not self._dynamic: |         if self.STRICT and not self._dynamic: | ||||||
| @@ -120,7 +121,7 @@ class BaseDocument(object): | |||||||
|                 else: |                 else: | ||||||
|                     self._data[key] = value |                     self._data[key] = value | ||||||
|  |  | ||||||
|         # Set any get_fieldname_display methods |         # Set any get_<field>_display methods | ||||||
|         self.__set_field_display() |         self.__set_field_display() | ||||||
|  |  | ||||||
|         if self._dynamic: |         if self._dynamic: | ||||||
| @@ -309,7 +310,7 @@ class BaseDocument(object): | |||||||
|         data = SON() |         data = SON() | ||||||
|         data["_id"] = None |         data["_id"] = None | ||||||
|         data['_cls'] = self._class_name |         data['_cls'] = self._class_name | ||||||
|         EmbeddedDocumentField = _import_class("EmbeddedDocumentField") |  | ||||||
|         # only root fields ['test1.a', 'test2'] => ['test1', 'test2'] |         # only root fields ['test1.a', 'test2'] => ['test1', 'test2'] | ||||||
|         root_fields = set([f.split('.')[0] for f in fields]) |         root_fields = set([f.split('.')[0] for f in fields]) | ||||||
|  |  | ||||||
| @@ -324,21 +325,20 @@ class BaseDocument(object): | |||||||
|                 field = self._dynamic_fields.get(field_name) |                 field = self._dynamic_fields.get(field_name) | ||||||
|  |  | ||||||
|             if value is not None: |             if value is not None: | ||||||
|  |                 f_inputs = field.to_mongo.__code__.co_varnames | ||||||
|                 if isinstance(field, EmbeddedDocumentField): |                 ex_vars = {} | ||||||
|                     if fields: |                 if fields and 'fields' in f_inputs: | ||||||
|                     key = '%s.' % field_name |                     key = '%s.' % field_name | ||||||
|                     embedded_fields = [ |                     embedded_fields = [ | ||||||
|                         i.replace(key, '') for i in fields |                         i.replace(key, '') for i in fields | ||||||
|                         if i.startswith(key)] |                         if i.startswith(key)] | ||||||
|  |  | ||||||
|                     else: |                     ex_vars['fields'] = embedded_fields | ||||||
|                         embedded_fields = [] |  | ||||||
|  |  | ||||||
|                     value = field.to_mongo(value, use_db_field=use_db_field, |                 if 'use_db_field' in f_inputs: | ||||||
|                                            fields=embedded_fields) |                     ex_vars['use_db_field'] = use_db_field | ||||||
|                 else: |  | ||||||
|                     value = field.to_mongo(value) |                 value = field.to_mongo(value, **ex_vars) | ||||||
|  |  | ||||||
|             # Handle self generating fields |             # Handle self generating fields | ||||||
|             if value is None and field._auto_gen: |             if value is None and field._auto_gen: | ||||||
| @@ -491,7 +491,7 @@ class BaseDocument(object): | |||||||
|                 # remove lower level changed fields |                 # remove lower level changed fields | ||||||
|                 level = '.'.join(levels[:idx]) + '.' |                 level = '.'.join(levels[:idx]) + '.' | ||||||
|                 remove = self._changed_fields.remove |                 remove = self._changed_fields.remove | ||||||
|                 for field in self._changed_fields: |                 for field in self._changed_fields[:]: | ||||||
|                     if field.startswith(level): |                     if field.startswith(level): | ||||||
|                         remove(field) |                         remove(field) | ||||||
|  |  | ||||||
| @@ -566,8 +566,10 @@ class BaseDocument(object): | |||||||
|                     continue |                     continue | ||||||
|             if isinstance(field, ReferenceField): |             if isinstance(field, ReferenceField): | ||||||
|                 continue |                 continue | ||||||
|             elif (isinstance(data, (EmbeddedDocument, DynamicEmbeddedDocument)) |             elif ( | ||||||
|                   and db_field_name not in changed_fields): |                 isinstance(data, (EmbeddedDocument, DynamicEmbeddedDocument)) and | ||||||
|  |                 db_field_name not in changed_fields | ||||||
|  |             ): | ||||||
|                 # Find all embedded fields that have been changed |                 # Find all embedded fields that have been changed | ||||||
|                 changed = data._get_changed_fields(inspected) |                 changed = data._get_changed_fields(inspected) | ||||||
|                 changed_fields += ["%s%s" % (key, k) for k in changed if k] |                 changed_fields += ["%s%s" % (key, k) for k in changed if k] | ||||||
| @@ -606,7 +608,9 @@ class BaseDocument(object): | |||||||
|                 for p in parts: |                 for p in parts: | ||||||
|                     if isinstance(d, (ObjectId, DBRef)): |                     if isinstance(d, (ObjectId, DBRef)): | ||||||
|                         break |                         break | ||||||
|                     elif isinstance(d, list) and p.isdigit(): |                     elif isinstance(d, list) and p.lstrip('-').isdigit(): | ||||||
|  |                         if p[0] == '-': | ||||||
|  |                             p = str(len(d) + int(p)) | ||||||
|                         try: |                         try: | ||||||
|                             d = d[int(p)] |                             d = d[int(p)] | ||||||
|                         except IndexError: |                         except IndexError: | ||||||
| @@ -640,7 +644,9 @@ class BaseDocument(object): | |||||||
|                 parts = path.split('.') |                 parts = path.split('.') | ||||||
|                 db_field_name = parts.pop() |                 db_field_name = parts.pop() | ||||||
|                 for p in parts: |                 for p in parts: | ||||||
|                     if isinstance(d, list) and p.isdigit(): |                     if isinstance(d, list) and p.lstrip('-').isdigit(): | ||||||
|  |                         if p[0] == '-': | ||||||
|  |                             p = str(len(d) + int(p)) | ||||||
|                         d = d[int(p)] |                         d = d[int(p)] | ||||||
|                     elif (hasattr(d, '__getattribute__') and |                     elif (hasattr(d, '__getattribute__') and | ||||||
|                           not isinstance(d, dict)): |                           not isinstance(d, dict)): | ||||||
| @@ -708,14 +714,6 @@ class BaseDocument(object): | |||||||
|                         del data[field.db_field] |                         del data[field.db_field] | ||||||
|                 except (AttributeError, ValueError), e: |                 except (AttributeError, ValueError), e: | ||||||
|                     errors_dict[field_name] = e |                     errors_dict[field_name] = e | ||||||
|             elif field.default: |  | ||||||
|                 default = field.default |  | ||||||
|                 if callable(default): |  | ||||||
|                     default = default() |  | ||||||
|                 if isinstance(default, BaseDocument): |  | ||||||
|                     changed_fields.append(field_name) |  | ||||||
|                 elif not only_fields or field_name in only_fields: |  | ||||||
|                     changed_fields.append(field_name) |  | ||||||
|  |  | ||||||
|         if errors_dict: |         if errors_dict: | ||||||
|             errors = "\n".join(["%s - %s" % (k, v) |             errors = "\n".join(["%s - %s" % (k, v) | ||||||
| @@ -779,8 +777,12 @@ class BaseDocument(object): | |||||||
|         # Check to see if we need to include _cls |         # Check to see if we need to include _cls | ||||||
|         allow_inheritance = cls._meta.get('allow_inheritance', |         allow_inheritance = cls._meta.get('allow_inheritance', | ||||||
|                                           ALLOW_INHERITANCE) |                                           ALLOW_INHERITANCE) | ||||||
|         include_cls = (allow_inheritance and not spec.get('sparse', False) and |         include_cls = ( | ||||||
|                        spec.get('cls',  True) and '_cls' not in spec['fields']) |             allow_inheritance and | ||||||
|  |             not spec.get('sparse', False) and | ||||||
|  |             spec.get('cls', True) and | ||||||
|  |             '_cls' not in spec['fields'] | ||||||
|  |         ) | ||||||
|  |  | ||||||
|         # 733: don't include cls if index_cls is False unless there is an explicit cls with the index |         # 733: don't include cls if index_cls is False unless there is an explicit cls with the index | ||||||
|         include_cls = include_cls and (spec.get('cls', False) or cls._meta.get('index_cls', True)) |         include_cls = include_cls and (spec.get('cls', False) or cls._meta.get('index_cls', True)) | ||||||
| @@ -839,10 +841,6 @@ class BaseDocument(object): | |||||||
|  |  | ||||||
|         if index_list: |         if index_list: | ||||||
|             spec['fields'] = index_list |             spec['fields'] = index_list | ||||||
|         if spec.get('sparse', False) and len(spec['fields']) > 1: |  | ||||||
|             raise ValueError( |  | ||||||
|                 'Sparse indexes can only have one field in them. ' |  | ||||||
|                 'See https://jira.mongodb.org/browse/SERVER-2193') |  | ||||||
|  |  | ||||||
|         return spec |         return spec | ||||||
|  |  | ||||||
| @@ -978,7 +976,7 @@ class BaseDocument(object): | |||||||
|                 if hasattr(getattr(field, 'field', None), 'lookup_member'): |                 if hasattr(getattr(field, 'field', None), 'lookup_member'): | ||||||
|                     new_field = field.field.lookup_member(field_name) |                     new_field = field.field.lookup_member(field_name) | ||||||
|                 elif cls._dynamic and (isinstance(field, DynamicField) or |                 elif cls._dynamic and (isinstance(field, DynamicField) or | ||||||
|                                        getattr(getattr(field, 'document_type'), '_dynamic')): |                                        getattr(getattr(field, 'document_type', None), '_dynamic', None)): | ||||||
|                     new_field = DynamicField(db_field=field_name) |                     new_field = DynamicField(db_field=field_name) | ||||||
|                 else: |                 else: | ||||||
|                     # Look up subfield on the previous field or raise |                     # Look up subfield on the previous field or raise | ||||||
| @@ -1007,19 +1005,18 @@ class BaseDocument(object): | |||||||
|         return '.'.join(parts) |         return '.'.join(parts) | ||||||
|  |  | ||||||
|     def __set_field_display(self): |     def __set_field_display(self): | ||||||
|         """Dynamically set the display value for a field with choices""" |         """For each field that specifies choices, create a | ||||||
|         for attr_name, field in self._fields.items(): |         get_<field>_display method. | ||||||
|             if field.choices: |         """ | ||||||
|                 if self._dynamic: |         fields_with_choices = [(n, f) for n, f in self._fields.items() | ||||||
|                     obj = self |                                if f.choices] | ||||||
|                 else: |         for attr_name, field in fields_with_choices: | ||||||
|                     obj = type(self) |             setattr(self, | ||||||
|                 setattr(obj, |  | ||||||
|                     'get_%s_display' % attr_name, |                     'get_%s_display' % attr_name, | ||||||
|                     partial(self.__get_field_display, field=field)) |                     partial(self.__get_field_display, field=field)) | ||||||
|  |  | ||||||
|     def __get_field_display(self, field): |     def __get_field_display(self, field): | ||||||
|         """Returns the display value for a choice field""" |         """Return the display value for a choice field""" | ||||||
|         value = getattr(self, field.name) |         value = getattr(self, field.name) | ||||||
|         if field.choices and isinstance(field.choices[0], (list, tuple)): |         if field.choices and isinstance(field.choices[0], (list, tuple)): | ||||||
|             return dict(field.choices).get(value, value) |             return dict(field.choices).get(value, value) | ||||||
|   | |||||||
| @@ -5,12 +5,12 @@ import weakref | |||||||
| from bson import DBRef, ObjectId, SON | from bson import DBRef, ObjectId, SON | ||||||
| import pymongo | import pymongo | ||||||
|  |  | ||||||
| from mongoengine.common import _import_class |  | ||||||
| from mongoengine.errors import ValidationError |  | ||||||
| from mongoengine.base.common import ALLOW_INHERITANCE | from mongoengine.base.common import ALLOW_INHERITANCE | ||||||
| from mongoengine.base.datastructures import ( | from mongoengine.base.datastructures import ( | ||||||
|     BaseDict, BaseList, EmbeddedDocumentList |     BaseDict, BaseList, EmbeddedDocumentList | ||||||
| ) | ) | ||||||
|  | from mongoengine.common import _import_class | ||||||
|  | from mongoengine.errors import ValidationError | ||||||
|  |  | ||||||
| __all__ = ("BaseField", "ComplexBaseField", | __all__ = ("BaseField", "ComplexBaseField", | ||||||
|            "ObjectIdField", "GeoJsonBaseField") |            "ObjectIdField", "GeoJsonBaseField") | ||||||
| @@ -133,7 +133,7 @@ class BaseField(object): | |||||||
|                 if (self.name not in instance._data or |                 if (self.name not in instance._data or | ||||||
|                         instance._data[self.name] != value): |                         instance._data[self.name] != value): | ||||||
|                     instance._mark_as_changed(self.name) |                     instance._mark_as_changed(self.name) | ||||||
|             except: |             except Exception: | ||||||
|                 # Values cant be compared eg: naive and tz datetimes |                 # Values cant be compared eg: naive and tz datetimes | ||||||
|                 # So mark it as changed |                 # So mark it as changed | ||||||
|                 instance._mark_as_changed(self.name) |                 instance._mark_as_changed(self.name) | ||||||
| @@ -163,6 +163,19 @@ class BaseField(object): | |||||||
|         """ |         """ | ||||||
|         return self.to_python(value) |         return self.to_python(value) | ||||||
|  |  | ||||||
|  |     def _to_mongo_safe_call(self, value, use_db_field=True, fields=None): | ||||||
|  |         """A helper method to call to_mongo with proper inputs | ||||||
|  |         """ | ||||||
|  |         f_inputs = self.to_mongo.__code__.co_varnames | ||||||
|  |         ex_vars = {} | ||||||
|  |         if 'fields' in f_inputs: | ||||||
|  |             ex_vars['fields'] = fields | ||||||
|  |  | ||||||
|  |         if 'use_db_field' in f_inputs: | ||||||
|  |             ex_vars['use_db_field'] = use_db_field | ||||||
|  |  | ||||||
|  |         return self.to_mongo(value, **ex_vars) | ||||||
|  |  | ||||||
|     def prepare_query_value(self, op, value): |     def prepare_query_value(self, op, value): | ||||||
|         """Prepare a value that is being used in a query for PyMongo. |         """Prepare a value that is being used in a query for PyMongo. | ||||||
|         """ |         """ | ||||||
| @@ -193,7 +206,6 @@ class BaseField(object): | |||||||
|         elif value not in choice_list: |         elif value not in choice_list: | ||||||
|             self.error('Value must be one of %s' % unicode(choice_list)) |             self.error('Value must be one of %s' % unicode(choice_list)) | ||||||
|  |  | ||||||
|  |  | ||||||
|     def _validate(self, value, **kwargs): |     def _validate(self, value, **kwargs): | ||||||
|         # Check the Choices Constraint |         # Check the Choices Constraint | ||||||
|         if self.choices: |         if self.choices: | ||||||
| @@ -285,8 +297,6 @@ class ComplexBaseField(BaseField): | |||||||
|     def to_python(self, value): |     def to_python(self, value): | ||||||
|         """Convert a MongoDB-compatible type to a Python type. |         """Convert a MongoDB-compatible type to a Python type. | ||||||
|         """ |         """ | ||||||
|         Document = _import_class('Document') |  | ||||||
|  |  | ||||||
|         if isinstance(value, basestring): |         if isinstance(value, basestring): | ||||||
|             return value |             return value | ||||||
|  |  | ||||||
| @@ -306,6 +316,7 @@ class ComplexBaseField(BaseField): | |||||||
|             value_dict = dict([(key, self.field.to_python(item)) |             value_dict = dict([(key, self.field.to_python(item)) | ||||||
|                                for key, item in value.items()]) |                                for key, item in value.items()]) | ||||||
|         else: |         else: | ||||||
|  |             Document = _import_class('Document') | ||||||
|             value_dict = {} |             value_dict = {} | ||||||
|             for k, v in value.items(): |             for k, v in value.items(): | ||||||
|                 if isinstance(v, Document): |                 if isinstance(v, Document): | ||||||
| @@ -325,7 +336,7 @@ class ComplexBaseField(BaseField): | |||||||
|                                          key=operator.itemgetter(0))] |                                          key=operator.itemgetter(0))] | ||||||
|         return value_dict |         return value_dict | ||||||
|  |  | ||||||
|     def to_mongo(self, value): |     def to_mongo(self, value, use_db_field=True, fields=None): | ||||||
|         """Convert a Python type to a MongoDB-compatible type. |         """Convert a Python type to a MongoDB-compatible type. | ||||||
|         """ |         """ | ||||||
|         Document = _import_class("Document") |         Document = _import_class("Document") | ||||||
| @@ -339,7 +350,7 @@ class ComplexBaseField(BaseField): | |||||||
|             if isinstance(value, Document): |             if isinstance(value, Document): | ||||||
|                 return GenericReferenceField().to_mongo(value) |                 return GenericReferenceField().to_mongo(value) | ||||||
|             cls = value.__class__ |             cls = value.__class__ | ||||||
|             val = value.to_mongo() |             val = value.to_mongo(use_db_field, fields) | ||||||
|             # If it's a document that is not inherited add _cls |             # If it's a document that is not inherited add _cls | ||||||
|             if isinstance(value, EmbeddedDocument): |             if isinstance(value, EmbeddedDocument): | ||||||
|                 val['_cls'] = cls.__name__ |                 val['_cls'] = cls.__name__ | ||||||
| @@ -354,7 +365,7 @@ class ComplexBaseField(BaseField): | |||||||
|                 return value |                 return value | ||||||
|  |  | ||||||
|         if self.field: |         if self.field: | ||||||
|             value_dict = dict([(key, self.field.to_mongo(item)) |             value_dict = dict([(key, self.field._to_mongo_safe_call(item, use_db_field, fields)) | ||||||
|                                for key, item in value.iteritems()]) |                                for key, item in value.iteritems()]) | ||||||
|         else: |         else: | ||||||
|             value_dict = {} |             value_dict = {} | ||||||
| @@ -379,13 +390,13 @@ class ComplexBaseField(BaseField): | |||||||
|                         value_dict[k] = DBRef(collection, v.pk) |                         value_dict[k] = DBRef(collection, v.pk) | ||||||
|                 elif hasattr(v, 'to_mongo'): |                 elif hasattr(v, 'to_mongo'): | ||||||
|                     cls = v.__class__ |                     cls = v.__class__ | ||||||
|                     val = v.to_mongo() |                     val = v.to_mongo(use_db_field, fields) | ||||||
|                     # If it's a document that is not inherited add _cls |                     # If it's a document that is not inherited add _cls | ||||||
|                     if isinstance(v, (Document, EmbeddedDocument)): |                     if isinstance(v, (Document, EmbeddedDocument)): | ||||||
|                         val['_cls'] = cls.__name__ |                         val['_cls'] = cls.__name__ | ||||||
|                     value_dict[k] = val |                     value_dict[k] = val | ||||||
|                 else: |                 else: | ||||||
|                     value_dict[k] = self.to_mongo(v) |                     value_dict[k] = self.to_mongo(v, use_db_field, fields) | ||||||
|  |  | ||||||
|         if is_list:  # Convert back to a list |         if is_list:  # Convert back to a list | ||||||
|             return [v for _, v in sorted(value_dict.items(), |             return [v for _, v in sorted(value_dict.items(), | ||||||
| @@ -439,7 +450,7 @@ class ObjectIdField(BaseField): | |||||||
|         try: |         try: | ||||||
|             if not isinstance(value, ObjectId): |             if not isinstance(value, ObjectId): | ||||||
|                 value = ObjectId(value) |                 value = ObjectId(value) | ||||||
|         except: |         except Exception: | ||||||
|             pass |             pass | ||||||
|         return value |         return value | ||||||
|  |  | ||||||
| @@ -458,7 +469,7 @@ class ObjectIdField(BaseField): | |||||||
|     def validate(self, value): |     def validate(self, value): | ||||||
|         try: |         try: | ||||||
|             ObjectId(unicode(value)) |             ObjectId(unicode(value)) | ||||||
|         except: |         except Exception: | ||||||
|             self.error('Invalid Object ID') |             self.error('Invalid Object ID') | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -510,7 +521,7 @@ class GeoJsonBaseField(BaseField): | |||||||
|         # Quick and dirty validator |         # Quick and dirty validator | ||||||
|         try: |         try: | ||||||
|             value[0][0][0] |             value[0][0][0] | ||||||
|         except: |         except (TypeError, IndexError): | ||||||
|             return "Invalid Polygon must contain at least one valid linestring" |             return "Invalid Polygon must contain at least one valid linestring" | ||||||
|  |  | ||||||
|         errors = [] |         errors = [] | ||||||
| @@ -534,7 +545,7 @@ class GeoJsonBaseField(BaseField): | |||||||
|         # Quick and dirty validator |         # Quick and dirty validator | ||||||
|         try: |         try: | ||||||
|             value[0][0] |             value[0][0] | ||||||
|         except: |         except (TypeError, IndexError): | ||||||
|             return "Invalid LineString must contain at least one valid point" |             return "Invalid LineString must contain at least one valid point" | ||||||
|  |  | ||||||
|         errors = [] |         errors = [] | ||||||
| @@ -565,7 +576,7 @@ class GeoJsonBaseField(BaseField): | |||||||
|         # Quick and dirty validator |         # Quick and dirty validator | ||||||
|         try: |         try: | ||||||
|             value[0][0] |             value[0][0] | ||||||
|         except: |         except (TypeError, IndexError): | ||||||
|             return "Invalid MultiPoint must contain at least one valid point" |             return "Invalid MultiPoint must contain at least one valid point" | ||||||
|  |  | ||||||
|         errors = [] |         errors = [] | ||||||
| @@ -584,7 +595,7 @@ class GeoJsonBaseField(BaseField): | |||||||
|         # Quick and dirty validator |         # Quick and dirty validator | ||||||
|         try: |         try: | ||||||
|             value[0][0][0] |             value[0][0][0] | ||||||
|         except: |         except (TypeError, IndexError): | ||||||
|             return "Invalid MultiLineString must contain at least one valid linestring" |             return "Invalid MultiLineString must contain at least one valid linestring" | ||||||
|  |  | ||||||
|         errors = [] |         errors = [] | ||||||
| @@ -606,7 +617,7 @@ class GeoJsonBaseField(BaseField): | |||||||
|         # Quick and dirty validator |         # Quick and dirty validator | ||||||
|         try: |         try: | ||||||
|             value[0][0][0][0] |             value[0][0][0][0] | ||||||
|         except: |         except (TypeError, IndexError): | ||||||
|             return "Invalid MultiPolygon must contain at least one valid Polygon" |             return "Invalid MultiPolygon must contain at least one valid Polygon" | ||||||
|  |  | ||||||
|         errors = [] |         errors = [] | ||||||
|   | |||||||
| @@ -1,5 +1,7 @@ | |||||||
| import warnings | import warnings | ||||||
|  |  | ||||||
|  | from mongoengine.base.common import ALLOW_INHERITANCE, _document_registry | ||||||
|  | from mongoengine.base.fields import BaseField, ComplexBaseField, ObjectIdField | ||||||
| from mongoengine.common import _import_class | from mongoengine.common import _import_class | ||||||
| from mongoengine.errors import InvalidDocumentError | from mongoengine.errors import InvalidDocumentError | ||||||
| from mongoengine.python_support import PY3 | from mongoengine.python_support import PY3 | ||||||
| @@ -7,16 +9,14 @@ from mongoengine.queryset import (DO_NOTHING, DoesNotExist, | |||||||
|                                   MultipleObjectsReturned, |                                   MultipleObjectsReturned, | ||||||
|                                   QuerySetManager) |                                   QuerySetManager) | ||||||
|  |  | ||||||
| from mongoengine.base.common import _document_registry, ALLOW_INHERITANCE |  | ||||||
| from mongoengine.base.fields import BaseField, ComplexBaseField, ObjectIdField |  | ||||||
|  |  | ||||||
| __all__ = ('DocumentMetaclass', 'TopLevelDocumentMetaclass') | __all__ = ('DocumentMetaclass', 'TopLevelDocumentMetaclass') | ||||||
|  |  | ||||||
|  |  | ||||||
| class DocumentMetaclass(type): | class DocumentMetaclass(type): | ||||||
|     """Metaclass for all documents. |     """Metaclass for all documents.""" | ||||||
|     """ |  | ||||||
|  |  | ||||||
|  |     # TODO lower complexity of this method | ||||||
|     def __new__(cls, name, bases, attrs): |     def __new__(cls, name, bases, attrs): | ||||||
|         flattened_bases = cls._get_bases(bases) |         flattened_bases = cls._get_bases(bases) | ||||||
|         super_new = super(DocumentMetaclass, cls).__new__ |         super_new = super(DocumentMetaclass, cls).__new__ | ||||||
| @@ -162,7 +162,7 @@ class DocumentMetaclass(type): | |||||||
|         # copies __func__ into im_func and __self__ into im_self for |         # copies __func__ into im_func and __self__ into im_self for | ||||||
|         # classmethod objects in Document derived classes. |         # classmethod objects in Document derived classes. | ||||||
|         if PY3: |         if PY3: | ||||||
|             for key, val in new_class.__dict__.items(): |             for val in new_class.__dict__.values(): | ||||||
|                 if isinstance(val, classmethod): |                 if isinstance(val, classmethod): | ||||||
|                     f = val.__get__(new_class) |                     f = val.__get__(new_class) | ||||||
|                     if hasattr(f, '__func__') and not hasattr(f, 'im_func'): |                     if hasattr(f, '__func__') and not hasattr(f, 'im_func'): | ||||||
|   | |||||||
| @@ -1,11 +1,12 @@ | |||||||
| from pymongo import MongoClient, ReadPreference, uri_parser | from pymongo import MongoClient, ReadPreference, uri_parser | ||||||
| from mongoengine.python_support import IS_PYMONGO_3 | from mongoengine.python_support import (IS_PYMONGO_3, str_types) | ||||||
|  |  | ||||||
| __all__ = ['ConnectionError', 'connect', 'register_connection', | __all__ = ['ConnectionError', 'connect', 'register_connection', | ||||||
|            'DEFAULT_CONNECTION_NAME'] |            'DEFAULT_CONNECTION_NAME'] | ||||||
|  |  | ||||||
|  |  | ||||||
| DEFAULT_CONNECTION_NAME = 'default' | DEFAULT_CONNECTION_NAME = 'default' | ||||||
|  |  | ||||||
| if IS_PYMONGO_3: | if IS_PYMONGO_3: | ||||||
|     READ_PREFERENCE = ReadPreference.PRIMARY |     READ_PREFERENCE = ReadPreference.PRIMARY | ||||||
| else: | else: | ||||||
| @@ -25,6 +26,7 @@ _dbs = {} | |||||||
| def register_connection(alias, name=None, host=None, port=None, | def register_connection(alias, name=None, host=None, port=None, | ||||||
|                         read_preference=READ_PREFERENCE, |                         read_preference=READ_PREFERENCE, | ||||||
|                         username=None, password=None, authentication_source=None, |                         username=None, password=None, authentication_source=None, | ||||||
|  |                         authentication_mechanism=None, | ||||||
|                         **kwargs): |                         **kwargs): | ||||||
|     """Add a connection. |     """Add a connection. | ||||||
|  |  | ||||||
| @@ -38,8 +40,14 @@ def register_connection(alias, name=None, host=None, port=None, | |||||||
|     :param username: username to authenticate with |     :param username: username to authenticate with | ||||||
|     :param password: password to authenticate with |     :param password: password to authenticate with | ||||||
|     :param authentication_source: database to authenticate against |     :param authentication_source: database to authenticate against | ||||||
|  |     :param authentication_mechanism: database authentication mechanisms. | ||||||
|  |         By default, use SCRAM-SHA-1 with MongoDB 3.0 and later, | ||||||
|  |         MONGODB-CR (MongoDB Challenge Response protocol) for older servers. | ||||||
|  |     :param is_mock: explicitly use mongomock for this connection | ||||||
|  |         (can also be done by using `mongomock://` as db host prefix) | ||||||
|     :param kwargs: allow ad-hoc parameters to be passed into the pymongo driver |     :param kwargs: allow ad-hoc parameters to be passed into the pymongo driver | ||||||
|  |  | ||||||
|  |     .. versionchanged:: 0.10.6 - added mongomock support | ||||||
|     """ |     """ | ||||||
|     global _connection_settings |     global _connection_settings | ||||||
|  |  | ||||||
| @@ -50,12 +58,26 @@ def register_connection(alias, name=None, host=None, port=None, | |||||||
|         'read_preference': read_preference, |         'read_preference': read_preference, | ||||||
|         'username': username, |         'username': username, | ||||||
|         'password': password, |         'password': password, | ||||||
|         'authentication_source': authentication_source |         'authentication_source': authentication_source, | ||||||
|  |         'authentication_mechanism': authentication_mechanism | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     # Handle uri style connections |     # Handle uri style connections | ||||||
|     if "://" in conn_settings['host']: |     conn_host = conn_settings['host'] | ||||||
|         uri_dict = uri_parser.parse_uri(conn_settings['host']) |     # host can be a list or a string, so if string, force to a list | ||||||
|  |     if isinstance(conn_host, str_types): | ||||||
|  |         conn_host = [conn_host] | ||||||
|  |  | ||||||
|  |     resolved_hosts = [] | ||||||
|  |     for entity in conn_host: | ||||||
|  |         # Handle uri style connections | ||||||
|  |         if entity.startswith('mongomock://'): | ||||||
|  |             conn_settings['is_mock'] = True | ||||||
|  |             # `mongomock://` is not a valid url prefix and must be replaced by `mongodb://` | ||||||
|  |             resolved_hosts.append(entity.replace('mongomock://', 'mongodb://', 1)) | ||||||
|  |         elif '://' in entity: | ||||||
|  |             uri_dict = uri_parser.parse_uri(entity) | ||||||
|  |             resolved_hosts.append(entity) | ||||||
|             conn_settings.update({ |             conn_settings.update({ | ||||||
|                 'name': uri_dict.get('database') or name, |                 'name': uri_dict.get('database') or name, | ||||||
|                 'username': uri_dict.get('username'), |                 'username': uri_dict.get('username'), | ||||||
| @@ -67,6 +89,11 @@ def register_connection(alias, name=None, host=None, port=None, | |||||||
|                 conn_settings['replicaSet'] = True |                 conn_settings['replicaSet'] = True | ||||||
|             if 'authsource' in uri_options: |             if 'authsource' in uri_options: | ||||||
|                 conn_settings['authentication_source'] = uri_options['authsource'] |                 conn_settings['authentication_source'] = uri_options['authsource'] | ||||||
|  |             if 'authmechanism' in uri_options: | ||||||
|  |                 conn_settings['authentication_mechanism'] = uri_options['authmechanism'] | ||||||
|  |         else: | ||||||
|  |             resolved_hosts.append(entity) | ||||||
|  |     conn_settings['host'] = resolved_hosts | ||||||
|  |  | ||||||
|     # Deprecated parameters that should not be passed on |     # Deprecated parameters that should not be passed on | ||||||
|     kwargs.pop('slaves', None) |     kwargs.pop('slaves', None) | ||||||
| @@ -105,8 +132,21 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): | |||||||
|         conn_settings.pop('username', None) |         conn_settings.pop('username', None) | ||||||
|         conn_settings.pop('password', None) |         conn_settings.pop('password', None) | ||||||
|         conn_settings.pop('authentication_source', None) |         conn_settings.pop('authentication_source', None) | ||||||
|  |         conn_settings.pop('authentication_mechanism', None) | ||||||
|  |  | ||||||
|  |         is_mock = conn_settings.pop('is_mock', None) | ||||||
|  |         if is_mock: | ||||||
|  |             # Use MongoClient from mongomock | ||||||
|  |             try: | ||||||
|  |                 import mongomock | ||||||
|  |             except ImportError: | ||||||
|  |                 raise RuntimeError('You need mongomock installed ' | ||||||
|  |                                    'to mock MongoEngine.') | ||||||
|  |             connection_class = mongomock.MongoClient | ||||||
|  |         else: | ||||||
|  |             # Use MongoClient from pymongo | ||||||
|             connection_class = MongoClient |             connection_class = MongoClient | ||||||
|  |  | ||||||
|         if 'replicaSet' in conn_settings: |         if 'replicaSet' in conn_settings: | ||||||
|             # Discard port since it can't be used on MongoReplicaSetClient |             # Discard port since it can't be used on MongoReplicaSetClient | ||||||
|             conn_settings.pop('port', None) |             conn_settings.pop('port', None) | ||||||
| @@ -126,6 +166,8 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): | |||||||
|                 connection_settings.pop('name', None) |                 connection_settings.pop('name', None) | ||||||
|                 connection_settings.pop('username', None) |                 connection_settings.pop('username', None) | ||||||
|                 connection_settings.pop('password', None) |                 connection_settings.pop('password', None) | ||||||
|  |                 connection_settings.pop('authentication_source', None) | ||||||
|  |                 connection_settings.pop('authentication_mechanism', None) | ||||||
|                 if conn_settings == connection_settings and _connections.get(db_alias, None): |                 if conn_settings == connection_settings and _connections.get(db_alias, None): | ||||||
|                     connection = _connections[db_alias] |                     connection = _connections[db_alias] | ||||||
|                     break |                     break | ||||||
| @@ -145,11 +187,13 @@ def get_db(alias=DEFAULT_CONNECTION_NAME, reconnect=False): | |||||||
|         conn = get_connection(alias) |         conn = get_connection(alias) | ||||||
|         conn_settings = _connection_settings[alias] |         conn_settings = _connection_settings[alias] | ||||||
|         db = conn[conn_settings['name']] |         db = conn[conn_settings['name']] | ||||||
|  |         auth_kwargs = {'source': conn_settings['authentication_source']} | ||||||
|  |         if conn_settings['authentication_mechanism'] is not None: | ||||||
|  |             auth_kwargs['mechanism'] = conn_settings['authentication_mechanism'] | ||||||
|         # Authenticate if necessary |         # Authenticate if necessary | ||||||
|         if conn_settings['username'] and conn_settings['password']: |         if conn_settings['username'] and (conn_settings['password'] or | ||||||
|             db.authenticate(conn_settings['username'], |                                           conn_settings['authentication_mechanism'] == 'MONGODB-X509'): | ||||||
|                             conn_settings['password'], |             db.authenticate(conn_settings['username'], conn_settings['password'], **auth_kwargs) | ||||||
|                             source=conn_settings['authentication_source']) |  | ||||||
|         _dbs[alias] = db |         _dbs[alias] = db | ||||||
|     return _dbs[alias] |     return _dbs[alias] | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1,13 +1,14 @@ | |||||||
| from bson import DBRef, SON | from bson import DBRef, SON | ||||||
|  |  | ||||||
| from base import ( | from .base import ( | ||||||
|     BaseDict, BaseList, EmbeddedDocumentList, |     BaseDict, BaseList, EmbeddedDocumentList, | ||||||
|     TopLevelDocumentMetaclass, get_document |     TopLevelDocumentMetaclass, get_document | ||||||
| ) | ) | ||||||
| from fields import (ReferenceField, ListField, DictField, MapField) | from .connection import get_db | ||||||
| from connection import get_db | from .document import Document, EmbeddedDocument | ||||||
| from queryset import QuerySet | from .fields import DictField, ListField, MapField, ReferenceField | ||||||
| from document import Document, EmbeddedDocument | from .python_support import txt_type | ||||||
|  | from .queryset import QuerySet | ||||||
|  |  | ||||||
|  |  | ||||||
| class DeReference(object): | class DeReference(object): | ||||||
| @@ -226,7 +227,7 @@ class DeReference(object): | |||||||
|                         data[k]._data[field_name] = self.object_map.get( |                         data[k]._data[field_name] = self.object_map.get( | ||||||
|                             (v['_ref'].collection, v['_ref'].id), v) |                             (v['_ref'].collection, v['_ref'].id), v) | ||||||
|                     elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth: |                     elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth: | ||||||
|                         item_name = "{0}.{1}.{2}".format(name, k, field_name) |                         item_name = txt_type("{0}.{1}.{2}").format(name, k, field_name) | ||||||
|                         data[k]._data[field_name] = self._attach_objects(v, depth, instance=instance, name=item_name) |                         data[k]._data[field_name] = self._attach_objects(v, depth, instance=instance, name=item_name) | ||||||
|             elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth: |             elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth: | ||||||
|                 item_name = '%s.%s' % (name, k) if name else name |                 item_name = '%s.%s' % (name, k) if name else name | ||||||
|   | |||||||
| @@ -1,28 +1,29 @@ | |||||||
| import warnings |  | ||||||
| import pymongo |  | ||||||
| import re | import re | ||||||
|  | import warnings | ||||||
|  |  | ||||||
| from pymongo.read_preferences import ReadPreference |  | ||||||
| from bson.dbref import DBRef | from bson.dbref import DBRef | ||||||
|  | import pymongo | ||||||
|  | from pymongo.read_preferences import ReadPreference | ||||||
|  |  | ||||||
| from mongoengine import signals | from mongoengine import signals | ||||||
| from mongoengine.common import _import_class |  | ||||||
| from mongoengine.base import ( | from mongoengine.base import ( | ||||||
|     DocumentMetaclass, |  | ||||||
|     TopLevelDocumentMetaclass, |  | ||||||
|     BaseDocument, |  | ||||||
|     BaseDict, |  | ||||||
|     BaseList, |  | ||||||
|     EmbeddedDocumentList, |  | ||||||
|     ALLOW_INHERITANCE, |     ALLOW_INHERITANCE, | ||||||
|  |     BaseDict, | ||||||
|  |     BaseDocument, | ||||||
|  |     BaseList, | ||||||
|  |     DocumentMetaclass, | ||||||
|  |     EmbeddedDocumentList, | ||||||
|  |     TopLevelDocumentMetaclass, | ||||||
|     get_document |     get_document | ||||||
| ) | ) | ||||||
| from mongoengine.errors import (InvalidQueryError, InvalidDocumentError, | from mongoengine.common import _import_class | ||||||
|  | from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db | ||||||
|  | from mongoengine.context_managers import switch_collection, switch_db | ||||||
|  | from mongoengine.errors import (InvalidDocumentError, InvalidQueryError, | ||||||
|                                 SaveConditionError) |                                 SaveConditionError) | ||||||
| from mongoengine.python_support import IS_PYMONGO_3 | from mongoengine.python_support import IS_PYMONGO_3 | ||||||
| from mongoengine.queryset import (OperationError, NotUniqueError, | from mongoengine.queryset import (NotUniqueError, OperationError, | ||||||
|                                   QuerySet, transform) |                                   QuerySet, transform) | ||||||
| from mongoengine.connection import get_db, DEFAULT_CONNECTION_NAME |  | ||||||
| from mongoengine.context_managers import switch_db, switch_collection |  | ||||||
|  |  | ||||||
| __all__ = ('Document', 'EmbeddedDocument', 'DynamicDocument', | __all__ = ('Document', 'EmbeddedDocument', 'DynamicDocument', | ||||||
|            'DynamicEmbeddedDocument', 'OperationError', |            'DynamicEmbeddedDocument', 'OperationError', | ||||||
| @@ -217,7 +218,7 @@ class Document(BaseDocument): | |||||||
|         Returns True if the document has been updated or False if the document |         Returns True if the document has been updated or False if the document | ||||||
|         in the database doesn't match the query. |         in the database doesn't match the query. | ||||||
|  |  | ||||||
|         .. note:: All unsaved changes that has been made to the document are |         .. note:: All unsaved changes that have been made to the document are | ||||||
|             rejected if the method returns True. |             rejected if the method returns True. | ||||||
|  |  | ||||||
|         :param query: the update will be performed only if the document in the |         :param query: the update will be performed only if the document in the | ||||||
| @@ -250,7 +251,7 @@ class Document(BaseDocument): | |||||||
|  |  | ||||||
|     def save(self, force_insert=False, validate=True, clean=True, |     def save(self, force_insert=False, validate=True, clean=True, | ||||||
|              write_concern=None, cascade=None, cascade_kwargs=None, |              write_concern=None, cascade=None, cascade_kwargs=None, | ||||||
|              _refs=None, save_condition=None, **kwargs): |              _refs=None, save_condition=None, signal_kwargs=None, **kwargs): | ||||||
|         """Save the :class:`~mongoengine.Document` to the database. If the |         """Save the :class:`~mongoengine.Document` to the database. If the | ||||||
|         document already exists, it will be updated, otherwise it will be |         document already exists, it will be updated, otherwise it will be | ||||||
|         created. |         created. | ||||||
| @@ -276,6 +277,8 @@ class Document(BaseDocument): | |||||||
|         :param save_condition: only perform save if matching record in db |         :param save_condition: only perform save if matching record in db | ||||||
|             satisfies condition(s) (e.g. version number). |             satisfies condition(s) (e.g. version number). | ||||||
|             Raises :class:`OperationError` if the conditions are not satisfied |             Raises :class:`OperationError` if the conditions are not satisfied | ||||||
|  |         :parm signal_kwargs: (optional) kwargs dictionary to be passed to | ||||||
|  |             the signal calls. | ||||||
|  |  | ||||||
|         .. versionchanged:: 0.5 |         .. versionchanged:: 0.5 | ||||||
|             In existing documents it only saves changed fields using |             In existing documents it only saves changed fields using | ||||||
| @@ -297,8 +300,11 @@ class Document(BaseDocument): | |||||||
|             :class:`OperationError` exception raised if save_condition fails. |             :class:`OperationError` exception raised if save_condition fails. | ||||||
|         .. versionchanged:: 0.10.1 |         .. versionchanged:: 0.10.1 | ||||||
|             :class: save_condition failure now raises a `SaveConditionError` |             :class: save_condition failure now raises a `SaveConditionError` | ||||||
|  |         .. versionchanged:: 0.10.7 | ||||||
|  |             Add signal_kwargs argument | ||||||
|         """ |         """ | ||||||
|         signals.pre_save.send(self.__class__, document=self) |         signal_kwargs = signal_kwargs or {} | ||||||
|  |         signals.pre_save.send(self.__class__, document=self, **signal_kwargs) | ||||||
|  |  | ||||||
|         if validate: |         if validate: | ||||||
|             self.validate(clean=clean) |             self.validate(clean=clean) | ||||||
| @@ -311,7 +317,7 @@ class Document(BaseDocument): | |||||||
|         created = ('_id' not in doc or self._created or force_insert) |         created = ('_id' not in doc or self._created or force_insert) | ||||||
|  |  | ||||||
|         signals.pre_save_post_validation.send(self.__class__, document=self, |         signals.pre_save_post_validation.send(self.__class__, document=self, | ||||||
|                                               created=created) |                                               created=created, **signal_kwargs) | ||||||
|  |  | ||||||
|         try: |         try: | ||||||
|             collection = self._get_collection() |             collection = self._get_collection() | ||||||
| @@ -327,8 +333,10 @@ class Document(BaseDocument): | |||||||
|                     # Correct behaviour in 2.X and in 3.0.1+ versions |                     # Correct behaviour in 2.X and in 3.0.1+ versions | ||||||
|                     if not object_id and pymongo.version_tuple == (3, 0): |                     if not object_id and pymongo.version_tuple == (3, 0): | ||||||
|                         pk_as_mongo_obj = self._fields.get(self._meta['id_field']).to_mongo(self.pk) |                         pk_as_mongo_obj = self._fields.get(self._meta['id_field']).to_mongo(self.pk) | ||||||
|                         object_id = self._qs.filter(pk=pk_as_mongo_obj).first() and \ |                         object_id = ( | ||||||
|  |                             self._qs.filter(pk=pk_as_mongo_obj).first() and | ||||||
|                             self._qs.filter(pk=pk_as_mongo_obj).first().pk |                             self._qs.filter(pk=pk_as_mongo_obj).first().pk | ||||||
|  |                         )  # TODO doesn't this make 2 queries? | ||||||
|             else: |             else: | ||||||
|                 object_id = doc['_id'] |                 object_id = doc['_id'] | ||||||
|                 updates, removals = self._delta() |                 updates, removals = self._delta() | ||||||
| @@ -400,14 +408,15 @@ class Document(BaseDocument): | |||||||
|         if created or id_field not in self._meta.get('shard_key', []): |         if created or id_field not in self._meta.get('shard_key', []): | ||||||
|             self[id_field] = self._fields[id_field].to_python(object_id) |             self[id_field] = self._fields[id_field].to_python(object_id) | ||||||
|  |  | ||||||
|         signals.post_save.send(self.__class__, document=self, created=created) |         signals.post_save.send(self.__class__, document=self, | ||||||
|  |                                created=created, **signal_kwargs) | ||||||
|         self._clear_changed_fields() |         self._clear_changed_fields() | ||||||
|         self._created = False |         self._created = False | ||||||
|         return self |         return self | ||||||
|  |  | ||||||
|     def cascade_save(self, *args, **kwargs): |     def cascade_save(self, *args, **kwargs): | ||||||
|         """Recursively saves any references / |         """Recursively saves any references / | ||||||
|            generic references on an objects""" |            generic references on the document""" | ||||||
|         _refs = kwargs.get('_refs', []) or [] |         _refs = kwargs.get('_refs', []) or [] | ||||||
|  |  | ||||||
|         ReferenceField = _import_class('ReferenceField') |         ReferenceField = _import_class('ReferenceField') | ||||||
| @@ -463,7 +472,7 @@ class Document(BaseDocument): | |||||||
|         Raises :class:`OperationError` if called on an object that has not yet |         Raises :class:`OperationError` if called on an object that has not yet | ||||||
|         been saved. |         been saved. | ||||||
|         """ |         """ | ||||||
|         if not self.pk: |         if self.pk is None: | ||||||
|             if kwargs.get('upsert', False): |             if kwargs.get('upsert', False): | ||||||
|                 query = self.to_mongo() |                 query = self.to_mongo() | ||||||
|                 if "_cls" in query: |                 if "_cls" in query: | ||||||
| @@ -476,18 +485,24 @@ class Document(BaseDocument): | |||||||
|         # Need to add shard key to query, or you get an error |         # Need to add shard key to query, or you get an error | ||||||
|         return self._qs.filter(**self._object_key).update_one(**kwargs) |         return self._qs.filter(**self._object_key).update_one(**kwargs) | ||||||
|  |  | ||||||
|     def delete(self, **write_concern): |     def delete(self, signal_kwargs=None, **write_concern): | ||||||
|         """Delete the :class:`~mongoengine.Document` from the database. This |         """Delete the :class:`~mongoengine.Document` from the database. This | ||||||
|         will only take effect if the document has been previously saved. |         will only take effect if the document has been previously saved. | ||||||
|  |  | ||||||
|  |         :parm signal_kwargs: (optional) kwargs dictionary to be passed to | ||||||
|  |             the signal calls. | ||||||
|         :param write_concern: Extra keyword arguments are passed down which |         :param write_concern: Extra keyword arguments are passed down which | ||||||
|             will be used as options for the resultant |             will be used as options for the resultant | ||||||
|             ``getLastError`` command.  For example, |             ``getLastError`` command.  For example, | ||||||
|             ``save(..., write_concern={w: 2, fsync: True}, ...)`` will |             ``save(..., write_concern={w: 2, fsync: True}, ...)`` will | ||||||
|             wait until at least two servers have recorded the write and |             wait until at least two servers have recorded the write and | ||||||
|             will force an fsync on the primary server. |             will force an fsync on the primary server. | ||||||
|  |  | ||||||
|  |         .. versionchanged:: 0.10.7 | ||||||
|  |             Add signal_kwargs argument | ||||||
|         """ |         """ | ||||||
|         signals.pre_delete.send(self.__class__, document=self) |         signal_kwargs = signal_kwargs or {} | ||||||
|  |         signals.pre_delete.send(self.__class__, document=self, **signal_kwargs) | ||||||
|  |  | ||||||
|         # Delete FileFields separately |         # Delete FileFields separately | ||||||
|         FileField = _import_class('FileField') |         FileField = _import_class('FileField') | ||||||
| @@ -501,7 +516,7 @@ class Document(BaseDocument): | |||||||
|         except pymongo.errors.OperationFailure, err: |         except pymongo.errors.OperationFailure, err: | ||||||
|             message = u'Could not delete document (%s)' % err.message |             message = u'Could not delete document (%s)' % err.message | ||||||
|             raise OperationError(message) |             raise OperationError(message) | ||||||
|         signals.post_delete.send(self.__class__, document=self) |         signals.post_delete.send(self.__class__, document=self, **signal_kwargs) | ||||||
|  |  | ||||||
|     def switch_db(self, db_alias, keep_created=True): |     def switch_db(self, db_alias, keep_created=True): | ||||||
|         """ |         """ | ||||||
| @@ -589,7 +604,7 @@ class Document(BaseDocument): | |||||||
|         elif "max_depth" in kwargs: |         elif "max_depth" in kwargs: | ||||||
|             max_depth = kwargs["max_depth"] |             max_depth = kwargs["max_depth"] | ||||||
|  |  | ||||||
|         if not self.pk: |         if self.pk is None: | ||||||
|             raise self.DoesNotExist("Document does not exist") |             raise self.DoesNotExist("Document does not exist") | ||||||
|         obj = self._qs.read_preference(ReadPreference.PRIMARY).filter( |         obj = self._qs.read_preference(ReadPreference.PRIMARY).filter( | ||||||
|             **self._object_key).only(*fields).limit( |             **self._object_key).only(*fields).limit( | ||||||
| @@ -604,6 +619,11 @@ class Document(BaseDocument): | |||||||
|             if not fields or field in fields: |             if not fields or field in fields: | ||||||
|                 try: |                 try: | ||||||
|                     setattr(self, field, self._reload(field, obj[field])) |                     setattr(self, field, self._reload(field, obj[field])) | ||||||
|  |                 except (KeyError, AttributeError): | ||||||
|  |                     try: | ||||||
|  |                         # If field is a special field, e.g. items is stored as _reserved_items, | ||||||
|  |                         # an KeyError is thrown. So try to retrieve the field from _data | ||||||
|  |                         setattr(self, field, self._reload(field, obj._data.get(field))) | ||||||
|                     except KeyError: |                     except KeyError: | ||||||
|                         # If field is removed from the database while the object |                         # If field is removed from the database while the object | ||||||
|                         # is in memory, a reload would cause a KeyError |                         # is in memory, a reload would cause a KeyError | ||||||
| @@ -635,7 +655,7 @@ class Document(BaseDocument): | |||||||
|     def to_dbref(self): |     def to_dbref(self): | ||||||
|         """Returns an instance of :class:`~bson.dbref.DBRef` useful in |         """Returns an instance of :class:`~bson.dbref.DBRef` useful in | ||||||
|         `__raw__` queries.""" |         `__raw__` queries.""" | ||||||
|         if not self.pk: |         if self.pk is None: | ||||||
|             msg = "Only saved documents can have a valid dbref" |             msg = "Only saved documents can have a valid dbref" | ||||||
|             raise OperationError(msg) |             raise OperationError(msg) | ||||||
|         return DBRef(self.__class__._get_collection_name(), self.pk) |         return DBRef(self.__class__._get_collection_name(), self.pk) | ||||||
| @@ -662,10 +682,20 @@ class Document(BaseDocument): | |||||||
|     def drop_collection(cls): |     def drop_collection(cls): | ||||||
|         """Drops the entire collection associated with this |         """Drops the entire collection associated with this | ||||||
|         :class:`~mongoengine.Document` type from the database. |         :class:`~mongoengine.Document` type from the database. | ||||||
|  |  | ||||||
|  |         Raises :class:`OperationError` if the document has no collection set | ||||||
|  |         (i.g. if it is `abstract`) | ||||||
|  |  | ||||||
|  |         .. versionchanged:: 0.10.7 | ||||||
|  |             :class:`OperationError` exception raised if no collection available | ||||||
|         """ |         """ | ||||||
|  |         col_name = cls._get_collection_name() | ||||||
|  |         if not col_name: | ||||||
|  |             raise OperationError('Document %s has no collection defined ' | ||||||
|  |                                  '(is it abstract ?)' % cls) | ||||||
|         cls._collection = None |         cls._collection = None | ||||||
|         db = cls._get_db() |         db = cls._get_db() | ||||||
|         db.drop_collection(cls._get_collection_name()) |         db.drop_collection(col_name) | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|     def create_index(cls, keys, background=False, **kwargs): |     def create_index(cls, keys, background=False, **kwargs): | ||||||
| @@ -954,7 +984,7 @@ class MapReduceDocument(object): | |||||||
|         if not isinstance(self.key, id_field_type): |         if not isinstance(self.key, id_field_type): | ||||||
|             try: |             try: | ||||||
|                 self.key = id_field_type(self.key) |                 self.key = id_field_type(self.key) | ||||||
|             except: |             except Exception: | ||||||
|                 raise Exception("Could not cast key as %s" % |                 raise Exception("Could not cast key as %s" % | ||||||
|                                 id_field_type.__name__) |                                 id_field_type.__name__) | ||||||
|  |  | ||||||
|   | |||||||
| @@ -6,7 +6,7 @@ from mongoengine.python_support import txt_type | |||||||
| __all__ = ('NotRegistered', 'InvalidDocumentError', 'LookUpError', | __all__ = ('NotRegistered', 'InvalidDocumentError', 'LookUpError', | ||||||
|            'DoesNotExist', 'MultipleObjectsReturned', 'InvalidQueryError', |            'DoesNotExist', 'MultipleObjectsReturned', 'InvalidQueryError', | ||||||
|            'OperationError', 'NotUniqueError', 'FieldDoesNotExist', |            'OperationError', 'NotUniqueError', 'FieldDoesNotExist', | ||||||
|            'ValidationError') |            'ValidationError', 'SaveConditionError') | ||||||
|  |  | ||||||
|  |  | ||||||
| class NotRegistered(Exception): | class NotRegistered(Exception): | ||||||
|   | |||||||
| @@ -8,6 +8,11 @@ import uuid | |||||||
| import warnings | import warnings | ||||||
| from operator import itemgetter | from operator import itemgetter | ||||||
|  |  | ||||||
|  | from bson import Binary, DBRef, ObjectId, SON | ||||||
|  | import gridfs | ||||||
|  | import pymongo | ||||||
|  | import six | ||||||
|  |  | ||||||
| try: | try: | ||||||
|     import dateutil |     import dateutil | ||||||
| except ImportError: | except ImportError: | ||||||
| @@ -15,18 +20,18 @@ except ImportError: | |||||||
| else: | else: | ||||||
|     import dateutil.parser |     import dateutil.parser | ||||||
|  |  | ||||||
| import pymongo | try: | ||||||
| import gridfs |     from bson.int64 import Int64 | ||||||
| from bson import Binary, DBRef, SON, ObjectId | except ImportError: | ||||||
|  |     Int64 = long | ||||||
|  |  | ||||||
| from mongoengine.errors import ValidationError | from .base import (BaseDocument, BaseField, ComplexBaseField, GeoJsonBaseField, | ||||||
| from mongoengine.python_support import (PY3, bin_type, txt_type, |                    ObjectIdField, get_document) | ||||||
|                                         str_types, StringIO) | from .connection import DEFAULT_CONNECTION_NAME, get_db | ||||||
| from base import (BaseField, ComplexBaseField, ObjectIdField, GeoJsonBaseField, | from .document import Document, EmbeddedDocument | ||||||
|                   get_document, BaseDocument) | from .errors import DoesNotExist, ValidationError | ||||||
| from queryset import DO_NOTHING, QuerySet | from .python_support import PY3, StringIO, bin_type, str_types, txt_type | ||||||
| from document import Document, EmbeddedDocument | from .queryset import DO_NOTHING, QuerySet | ||||||
| from connection import get_db, DEFAULT_CONNECTION_NAME |  | ||||||
|  |  | ||||||
| try: | try: | ||||||
|     from PIL import Image, ImageOps |     from PIL import Image, ImageOps | ||||||
| @@ -65,7 +70,7 @@ class StringField(BaseField): | |||||||
|             return value |             return value | ||||||
|         try: |         try: | ||||||
|             value = value.decode('utf-8') |             value = value.decode('utf-8') | ||||||
|         except: |         except Exception: | ||||||
|             pass |             pass | ||||||
|         return value |         return value | ||||||
|  |  | ||||||
| @@ -156,7 +161,7 @@ class URLField(StringField): | |||||||
|  |  | ||||||
|  |  | ||||||
| class EmailField(StringField): | class EmailField(StringField): | ||||||
|     """A field that validates input as an E-Mail-Address. |     """A field that validates input as an email address. | ||||||
|  |  | ||||||
|     .. versionadded:: 0.4 |     .. versionadded:: 0.4 | ||||||
|     """ |     """ | ||||||
| @@ -172,7 +177,7 @@ class EmailField(StringField): | |||||||
|  |  | ||||||
|     def validate(self, value): |     def validate(self, value): | ||||||
|         if not EmailField.EMAIL_REGEX.match(value): |         if not EmailField.EMAIL_REGEX.match(value): | ||||||
|             self.error('Invalid Mail-address: %s' % value) |             self.error('Invalid email address: %s' % value) | ||||||
|         super(EmailField, self).validate(value) |         super(EmailField, self).validate(value) | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -194,7 +199,7 @@ class IntField(BaseField): | |||||||
|     def validate(self, value): |     def validate(self, value): | ||||||
|         try: |         try: | ||||||
|             value = int(value) |             value = int(value) | ||||||
|         except: |         except Exception: | ||||||
|             self.error('%s could not be converted to int' % value) |             self.error('%s could not be converted to int' % value) | ||||||
|  |  | ||||||
|         if self.min_value is not None and value < self.min_value: |         if self.min_value is not None and value < self.min_value: | ||||||
| @@ -225,10 +230,13 @@ class LongField(BaseField): | |||||||
|             pass |             pass | ||||||
|         return value |         return value | ||||||
|  |  | ||||||
|  |     def to_mongo(self, value): | ||||||
|  |         return Int64(value) | ||||||
|  |  | ||||||
|     def validate(self, value): |     def validate(self, value): | ||||||
|         try: |         try: | ||||||
|             value = long(value) |             value = long(value) | ||||||
|         except: |         except Exception: | ||||||
|             self.error('%s could not be converted to long' % value) |             self.error('%s could not be converted to long' % value) | ||||||
|  |  | ||||||
|         if self.min_value is not None and value < self.min_value: |         if self.min_value is not None and value < self.min_value: | ||||||
| @@ -260,10 +268,14 @@ class FloatField(BaseField): | |||||||
|         return value |         return value | ||||||
|  |  | ||||||
|     def validate(self, value): |     def validate(self, value): | ||||||
|         if isinstance(value, int): |         if isinstance(value, six.integer_types): | ||||||
|  |             try: | ||||||
|                 value = float(value) |                 value = float(value) | ||||||
|  |             except OverflowError: | ||||||
|  |                 self.error('The value is too large to be converted to float') | ||||||
|  |  | ||||||
|         if not isinstance(value, float): |         if not isinstance(value, float): | ||||||
|             self.error('FloatField only accepts float values') |             self.error('FloatField only accepts float and integer values') | ||||||
|  |  | ||||||
|         if self.min_value is not None and value < self.min_value: |         if self.min_value is not None and value < self.min_value: | ||||||
|             self.error('Float value is too small') |             self.error('Float value is too small') | ||||||
| @@ -325,7 +337,7 @@ class DecimalField(BaseField): | |||||||
|             return value |             return value | ||||||
|         return value.quantize(decimal.Decimal(".%s" % ("0" * self.precision)), rounding=self.rounding) |         return value.quantize(decimal.Decimal(".%s" % ("0" * self.precision)), rounding=self.rounding) | ||||||
|  |  | ||||||
|     def to_mongo(self, value, use_db_field=True): |     def to_mongo(self, value): | ||||||
|         if value is None: |         if value is None: | ||||||
|             return value |             return value | ||||||
|         if self.force_string: |         if self.force_string: | ||||||
| @@ -508,7 +520,7 @@ class ComplexDateTimeField(StringField): | |||||||
|         original_value = value |         original_value = value | ||||||
|         try: |         try: | ||||||
|             return self._convert_from_string(value) |             return self._convert_from_string(value) | ||||||
|         except: |         except Exception: | ||||||
|             return original_value |             return original_value | ||||||
|  |  | ||||||
|     def to_mongo(self, value): |     def to_mongo(self, value): | ||||||
| @@ -546,11 +558,10 @@ class EmbeddedDocumentField(BaseField): | |||||||
|             return self.document_type._from_son(value, _auto_dereference=self._auto_dereference) |             return self.document_type._from_son(value, _auto_dereference=self._auto_dereference) | ||||||
|         return value |         return value | ||||||
|  |  | ||||||
|     def to_mongo(self, value, use_db_field=True, fields=[]): |     def to_mongo(self, value, use_db_field=True, fields=None): | ||||||
|         if not isinstance(value, self.document_type): |         if not isinstance(value, self.document_type): | ||||||
|             return value |             return value | ||||||
|         return self.document_type.to_mongo(value, use_db_field, |         return self.document_type.to_mongo(value, use_db_field, fields) | ||||||
|                                            fields=fields) |  | ||||||
|  |  | ||||||
|     def validate(self, value, clean=True): |     def validate(self, value, clean=True): | ||||||
|         """Make sure that the document instance is an instance of the |         """Make sure that the document instance is an instance of the | ||||||
| @@ -566,7 +577,7 @@ class EmbeddedDocumentField(BaseField): | |||||||
|         return self.document_type._fields.get(member_name) |         return self.document_type._fields.get(member_name) | ||||||
|  |  | ||||||
|     def prepare_query_value(self, op, value): |     def prepare_query_value(self, op, value): | ||||||
|         if not isinstance(value, self.document_type): |         if value is not None and not isinstance(value, self.document_type): | ||||||
|             value = self.document_type._from_son(value) |             value = self.document_type._from_son(value) | ||||||
|         super(EmbeddedDocumentField, self).prepare_query_value(op, value) |         super(EmbeddedDocumentField, self).prepare_query_value(op, value) | ||||||
|         return self.to_mongo(value) |         return self.to_mongo(value) | ||||||
| @@ -600,11 +611,11 @@ class GenericEmbeddedDocumentField(BaseField): | |||||||
|  |  | ||||||
|         value.validate(clean=clean) |         value.validate(clean=clean) | ||||||
|  |  | ||||||
|     def to_mongo(self, document, use_db_field=True): |     def to_mongo(self, document, use_db_field=True, fields=None): | ||||||
|         if document is None: |         if document is None: | ||||||
|             return None |             return None | ||||||
|  |  | ||||||
|         data = document.to_mongo(use_db_field) |         data = document.to_mongo(use_db_field, fields) | ||||||
|         if '_cls' not in data: |         if '_cls' not in data: | ||||||
|             data['_cls'] = document._class_name |             data['_cls'] = document._class_name | ||||||
|         return data |         return data | ||||||
| @@ -616,7 +627,7 @@ class DynamicField(BaseField): | |||||||
|  |  | ||||||
|     Used by :class:`~mongoengine.DynamicDocument` to handle dynamic data""" |     Used by :class:`~mongoengine.DynamicDocument` to handle dynamic data""" | ||||||
|  |  | ||||||
|     def to_mongo(self, value): |     def to_mongo(self, value, use_db_field=True, fields=None): | ||||||
|         """Convert a Python type to a MongoDB compatible type. |         """Convert a Python type to a MongoDB compatible type. | ||||||
|         """ |         """ | ||||||
|  |  | ||||||
| @@ -625,7 +636,7 @@ class DynamicField(BaseField): | |||||||
|  |  | ||||||
|         if hasattr(value, 'to_mongo'): |         if hasattr(value, 'to_mongo'): | ||||||
|             cls = value.__class__ |             cls = value.__class__ | ||||||
|             val = value.to_mongo() |             val = value.to_mongo(use_db_field, fields) | ||||||
|             # If we its a document thats not inherited add _cls |             # If we its a document thats not inherited add _cls | ||||||
|             if isinstance(value, Document): |             if isinstance(value, Document): | ||||||
|                 val = {"_ref": value.to_dbref(), "_cls": cls.__name__} |                 val = {"_ref": value.to_dbref(), "_cls": cls.__name__} | ||||||
| @@ -643,7 +654,7 @@ class DynamicField(BaseField): | |||||||
|  |  | ||||||
|         data = {} |         data = {} | ||||||
|         for k, v in value.iteritems(): |         for k, v in value.iteritems(): | ||||||
|             data[k] = self.to_mongo(v) |             data[k] = self.to_mongo(v, use_db_field, fields) | ||||||
|  |  | ||||||
|         value = data |         value = data | ||||||
|         if is_list:  # Convert back to a list |         if is_list:  # Convert back to a list | ||||||
| @@ -697,7 +708,7 @@ class ListField(ComplexBaseField): | |||||||
|  |  | ||||||
|     def prepare_query_value(self, op, value): |     def prepare_query_value(self, op, value): | ||||||
|         if self.field: |         if self.field: | ||||||
|             if op in ('set', 'unset') and ( |             if op in ('set', 'unset', None) and ( | ||||||
|                     not isinstance(value, basestring) and |                     not isinstance(value, basestring) and | ||||||
|                     not isinstance(value, BaseDocument) and |                     not isinstance(value, BaseDocument) and | ||||||
|                     hasattr(value, '__iter__')): |                     hasattr(value, '__iter__')): | ||||||
| @@ -755,8 +766,8 @@ class SortedListField(ListField): | |||||||
|             self._order_reverse = kwargs.pop('reverse') |             self._order_reverse = kwargs.pop('reverse') | ||||||
|         super(SortedListField, self).__init__(field, **kwargs) |         super(SortedListField, self).__init__(field, **kwargs) | ||||||
|  |  | ||||||
|     def to_mongo(self, value): |     def to_mongo(self, value, use_db_field=True, fields=None): | ||||||
|         value = super(SortedListField, self).to_mongo(value) |         value = super(SortedListField, self).to_mongo(value, use_db_field, fields) | ||||||
|         if self._ordering is not None: |         if self._ordering is not None: | ||||||
|             return sorted(value, key=itemgetter(self._ordering), |             return sorted(value, key=itemgetter(self._ordering), | ||||||
|                           reverse=self._order_reverse) |                           reverse=self._order_reverse) | ||||||
| @@ -878,7 +889,7 @@ class ReferenceField(BaseField): | |||||||
|             content = StringField() |             content = StringField() | ||||||
|             foo = ReferenceField('Foo') |             foo = ReferenceField('Foo') | ||||||
|  |  | ||||||
|         Bar.register_delete_rule(Foo, 'bar', NULLIFY) |         Foo.register_delete_rule(Bar, 'foo', NULLIFY) | ||||||
|  |  | ||||||
|     .. note :: |     .. note :: | ||||||
|         `reverse_delete_rule` does not trigger pre / post delete signals to be |         `reverse_delete_rule` does not trigger pre / post delete signals to be | ||||||
| @@ -895,6 +906,10 @@ class ReferenceField(BaseField): | |||||||
|           or as the :class:`~pymongo.objectid.ObjectId`.id . |           or as the :class:`~pymongo.objectid.ObjectId`.id . | ||||||
|         :param reverse_delete_rule: Determines what to do when the referring |         :param reverse_delete_rule: Determines what to do when the referring | ||||||
|           object is deleted |           object is deleted | ||||||
|  |  | ||||||
|  |         .. note :: | ||||||
|  |             A reference to an abstract document type is always stored as a | ||||||
|  |             :class:`~pymongo.dbref.DBRef`, regardless of the value of `dbref`. | ||||||
|         """ |         """ | ||||||
|         if not isinstance(document_type, basestring): |         if not isinstance(document_type, basestring): | ||||||
|             if not issubclass(document_type, (Document, basestring)): |             if not issubclass(document_type, (Document, basestring)): | ||||||
| @@ -927,9 +942,16 @@ class ReferenceField(BaseField): | |||||||
|         self._auto_dereference = instance._fields[self.name]._auto_dereference |         self._auto_dereference = instance._fields[self.name]._auto_dereference | ||||||
|         # Dereference DBRefs |         # Dereference DBRefs | ||||||
|         if self._auto_dereference and isinstance(value, DBRef): |         if self._auto_dereference and isinstance(value, DBRef): | ||||||
|             value = self.document_type._get_db().dereference(value) |             if hasattr(value, 'cls'): | ||||||
|             if value is not None: |                 # Dereference using the class type specified in the reference | ||||||
|                 instance._data[self.name] = self.document_type._from_son(value) |                 cls = get_document(value.cls) | ||||||
|  |             else: | ||||||
|  |                 cls = self.document_type | ||||||
|  |             dereferenced = cls._get_db().dereference(value) | ||||||
|  |             if dereferenced is None: | ||||||
|  |                 raise DoesNotExist('Trying to dereference unknown document %s' % value) | ||||||
|  |             else: | ||||||
|  |                 instance._data[self.name] = cls._from_son(dereferenced) | ||||||
|  |  | ||||||
|         return super(ReferenceField, self).__get__(instance, owner) |         return super(ReferenceField, self).__get__(instance, owner) | ||||||
|  |  | ||||||
| @@ -939,21 +961,29 @@ class ReferenceField(BaseField): | |||||||
|                 return document.id |                 return document.id | ||||||
|             return document |             return document | ||||||
|  |  | ||||||
|         id_field_name = self.document_type._meta['id_field'] |  | ||||||
|         id_field = self.document_type._fields[id_field_name] |  | ||||||
|  |  | ||||||
|         if isinstance(document, Document): |         if isinstance(document, Document): | ||||||
|             # We need the id from the saved object to create the DBRef |             # We need the id from the saved object to create the DBRef | ||||||
|             id_ = document.pk |             id_ = document.pk | ||||||
|             if id_ is None: |             if id_ is None: | ||||||
|                 self.error('You can only reference documents once they have' |                 self.error('You can only reference documents once they have' | ||||||
|                            ' been saved to the database') |                            ' been saved to the database') | ||||||
|  |  | ||||||
|  |             # Use the attributes from the document instance, so that they | ||||||
|  |             # override the attributes of this field's document type | ||||||
|  |             cls = document | ||||||
|         else: |         else: | ||||||
|             id_ = document |             id_ = document | ||||||
|  |             cls = self.document_type | ||||||
|  |  | ||||||
|  |         id_field_name = cls._meta['id_field'] | ||||||
|  |         id_field = cls._fields[id_field_name] | ||||||
|  |  | ||||||
|         id_ = id_field.to_mongo(id_) |         id_ = id_field.to_mongo(id_) | ||||||
|         if self.dbref: |         if self.document_type._meta.get('abstract'): | ||||||
|             collection = self.document_type._get_collection_name() |             collection = cls._get_collection_name() | ||||||
|  |             return DBRef(collection, id_, cls=cls._class_name) | ||||||
|  |         elif self.dbref: | ||||||
|  |             collection = cls._get_collection_name() | ||||||
|             return DBRef(collection, id_) |             return DBRef(collection, id_) | ||||||
|  |  | ||||||
|         return id_ |         return id_ | ||||||
| @@ -982,6 +1012,13 @@ class ReferenceField(BaseField): | |||||||
|             self.error('You can only reference documents once they have been ' |             self.error('You can only reference documents once they have been ' | ||||||
|                        'saved to the database') |                        'saved to the database') | ||||||
|  |  | ||||||
|  |         if self.document_type._meta.get('abstract') and \ | ||||||
|  |                 not isinstance(value, self.document_type): | ||||||
|  |             self.error( | ||||||
|  |                 '%s is not an instance of abstract reference type %s' % ( | ||||||
|  |                     self.document_type._class_name) | ||||||
|  |             ) | ||||||
|  |  | ||||||
|     def lookup_member(self, member_name): |     def lookup_member(self, member_name): | ||||||
|         return self.document_type._fields.get(member_name) |         return self.document_type._fields.get(member_name) | ||||||
|  |  | ||||||
| @@ -1057,13 +1094,15 @@ class CachedReferenceField(BaseField): | |||||||
|         self._auto_dereference = instance._fields[self.name]._auto_dereference |         self._auto_dereference = instance._fields[self.name]._auto_dereference | ||||||
|         # Dereference DBRefs |         # Dereference DBRefs | ||||||
|         if self._auto_dereference and isinstance(value, DBRef): |         if self._auto_dereference and isinstance(value, DBRef): | ||||||
|             value = self.document_type._get_db().dereference(value) |             dereferenced = self.document_type._get_db().dereference(value) | ||||||
|             if value is not None: |             if dereferenced is None: | ||||||
|                 instance._data[self.name] = self.document_type._from_son(value) |                 raise DoesNotExist('Trying to dereference unknown document %s' % value) | ||||||
|  |             else: | ||||||
|  |                 instance._data[self.name] = self.document_type._from_son(dereferenced) | ||||||
|  |  | ||||||
|         return super(CachedReferenceField, self).__get__(instance, owner) |         return super(CachedReferenceField, self).__get__(instance, owner) | ||||||
|  |  | ||||||
|     def to_mongo(self, document): |     def to_mongo(self, document, use_db_field=True, fields=None): | ||||||
|         id_field_name = self.document_type._meta['id_field'] |         id_field_name = self.document_type._meta['id_field'] | ||||||
|         id_field = self.document_type._fields[id_field_name] |         id_field = self.document_type._fields[id_field_name] | ||||||
|  |  | ||||||
| @@ -1081,7 +1120,12 @@ class CachedReferenceField(BaseField): | |||||||
|             ("_id", id_field.to_mongo(id_)), |             ("_id", id_field.to_mongo(id_)), | ||||||
|         )) |         )) | ||||||
|  |  | ||||||
|         value.update(dict(document.to_mongo(fields=self.fields))) |         if fields: | ||||||
|  |             new_fields = [f for f in self.fields if f in fields] | ||||||
|  |         else: | ||||||
|  |             new_fields = self.fields | ||||||
|  |  | ||||||
|  |         value.update(dict(document.to_mongo(use_db_field, fields=new_fields))) | ||||||
|         return value |         return value | ||||||
|  |  | ||||||
|     def prepare_query_value(self, op, value): |     def prepare_query_value(self, op, value): | ||||||
| @@ -1172,7 +1216,11 @@ class GenericReferenceField(BaseField): | |||||||
|  |  | ||||||
|         self._auto_dereference = instance._fields[self.name]._auto_dereference |         self._auto_dereference = instance._fields[self.name]._auto_dereference | ||||||
|         if self._auto_dereference and isinstance(value, (dict, SON)): |         if self._auto_dereference and isinstance(value, (dict, SON)): | ||||||
|             instance._data[self.name] = self.dereference(value) |             dereferenced = self.dereference(value) | ||||||
|  |             if dereferenced is None: | ||||||
|  |                 raise DoesNotExist('Trying to dereference unknown document %s' % value) | ||||||
|  |             else: | ||||||
|  |                 instance._data[self.name] = dereferenced | ||||||
|  |  | ||||||
|         return super(GenericReferenceField, self).__get__(instance, owner) |         return super(GenericReferenceField, self).__get__(instance, owner) | ||||||
|  |  | ||||||
| @@ -1197,7 +1245,7 @@ class GenericReferenceField(BaseField): | |||||||
|             doc = doc_cls._from_son(doc) |             doc = doc_cls._from_son(doc) | ||||||
|         return doc |         return doc | ||||||
|  |  | ||||||
|     def to_mongo(self, document, use_db_field=True): |     def to_mongo(self, document): | ||||||
|         if document is None: |         if document is None: | ||||||
|             return None |             return None | ||||||
|  |  | ||||||
| @@ -1345,7 +1393,7 @@ class GridFSProxy(object): | |||||||
|             if self.gridout is None: |             if self.gridout is None: | ||||||
|                 self.gridout = self.fs.get(self.grid_id) |                 self.gridout = self.fs.get(self.grid_id) | ||||||
|             return self.gridout |             return self.gridout | ||||||
|         except: |         except Exception: | ||||||
|             # File has been deleted |             # File has been deleted | ||||||
|             return None |             return None | ||||||
|  |  | ||||||
| @@ -1383,7 +1431,7 @@ class GridFSProxy(object): | |||||||
|         else: |         else: | ||||||
|             try: |             try: | ||||||
|                 return gridout.read(size) |                 return gridout.read(size) | ||||||
|             except: |             except Exception: | ||||||
|                 return "" |                 return "" | ||||||
|  |  | ||||||
|     def delete(self): |     def delete(self): | ||||||
| @@ -1448,7 +1496,7 @@ class FileField(BaseField): | |||||||
|             if grid_file: |             if grid_file: | ||||||
|                 try: |                 try: | ||||||
|                     grid_file.delete() |                     grid_file.delete() | ||||||
|                 except: |                 except Exception: | ||||||
|                     pass |                     pass | ||||||
|  |  | ||||||
|             # Create a new proxy object as we don't already have one |             # Create a new proxy object as we don't already have one | ||||||
| @@ -1816,7 +1864,7 @@ class UUIDField(BaseField): | |||||||
|                 if not isinstance(value, basestring): |                 if not isinstance(value, basestring): | ||||||
|                     value = unicode(value) |                     value = unicode(value) | ||||||
|                 return uuid.UUID(value) |                 return uuid.UUID(value) | ||||||
|             except: |             except Exception: | ||||||
|                 return original_value |                 return original_value | ||||||
|         return value |         return value | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1,6 +1,6 @@ | |||||||
| from mongoengine.errors import (DoesNotExist, MultipleObjectsReturned, | from mongoengine.errors import (DoesNotExist, InvalidQueryError, | ||||||
|                                 InvalidQueryError, OperationError, |                                 MultipleObjectsReturned, NotUniqueError, | ||||||
|                                 NotUniqueError) |                                 OperationError) | ||||||
| from mongoengine.queryset.field_list import * | from mongoengine.queryset.field_list import * | ||||||
| from mongoengine.queryset.manager import * | from mongoengine.queryset.manager import * | ||||||
| from mongoengine.queryset.queryset import * | from mongoengine.queryset.queryset import * | ||||||
|   | |||||||
| @@ -7,20 +7,19 @@ import pprint | |||||||
| import re | import re | ||||||
| import warnings | import warnings | ||||||
|  |  | ||||||
| from bson import SON | from bson import SON, json_util | ||||||
| from bson.code import Code | from bson.code import Code | ||||||
| from bson import json_util |  | ||||||
| import pymongo | import pymongo | ||||||
| import pymongo.errors | import pymongo.errors | ||||||
| from pymongo.common import validate_read_preference | from pymongo.common import validate_read_preference | ||||||
|  |  | ||||||
| from mongoengine import signals | from mongoengine import signals | ||||||
|  | from mongoengine.base.common import get_document | ||||||
|  | from mongoengine.common import _import_class | ||||||
| from mongoengine.connection import get_db | from mongoengine.connection import get_db | ||||||
| from mongoengine.context_managers import switch_db | from mongoengine.context_managers import switch_db | ||||||
| from mongoengine.common import _import_class | from mongoengine.errors import (InvalidQueryError, LookUpError, | ||||||
| from mongoengine.base.common import get_document |                                 NotUniqueError, OperationError) | ||||||
| from mongoengine.errors import (OperationError, NotUniqueError, |  | ||||||
|                                 InvalidQueryError, LookUpError) |  | ||||||
| from mongoengine.python_support import IS_PYMONGO_3 | from mongoengine.python_support import IS_PYMONGO_3 | ||||||
| from mongoengine.queryset import transform | from mongoengine.queryset import transform | ||||||
| from mongoengine.queryset.field_list import QueryFieldList | from mongoengine.queryset.field_list import QueryFieldList | ||||||
| @@ -123,9 +122,40 @@ class BaseQuerySet(object): | |||||||
|  |  | ||||||
|         return queryset |         return queryset | ||||||
|  |  | ||||||
|     def __getitem__(self, key): |     def __getstate__(self): | ||||||
|         """Support skip and limit using getitem and slicing syntax. |  | ||||||
|         """ |         """ | ||||||
|  |         Need for pickling queryset | ||||||
|  |  | ||||||
|  |         See https://github.com/MongoEngine/mongoengine/issues/442 | ||||||
|  |         """ | ||||||
|  |  | ||||||
|  |         obj_dict = self.__dict__.copy() | ||||||
|  |  | ||||||
|  |         # don't picke collection, instead pickle collection params | ||||||
|  |         obj_dict.pop("_collection_obj") | ||||||
|  |  | ||||||
|  |         # don't pickle cursor | ||||||
|  |         obj_dict["_cursor_obj"] = None | ||||||
|  |  | ||||||
|  |         return obj_dict | ||||||
|  |  | ||||||
|  |     def __setstate__(self, obj_dict): | ||||||
|  |         """ | ||||||
|  |         Need for pickling queryset | ||||||
|  |  | ||||||
|  |         See https://github.com/MongoEngine/mongoengine/issues/442 | ||||||
|  |         """ | ||||||
|  |  | ||||||
|  |         obj_dict["_collection_obj"] = obj_dict["_document"]._get_collection() | ||||||
|  |  | ||||||
|  |         # update attributes | ||||||
|  |         self.__dict__.update(obj_dict) | ||||||
|  |  | ||||||
|  |         # forse load cursor | ||||||
|  |         # self._cursor | ||||||
|  |  | ||||||
|  |     def __getitem__(self, key): | ||||||
|  |         """Support skip and limit using getitem and slicing syntax.""" | ||||||
|         queryset = self.clone() |         queryset = self.clone() | ||||||
|  |  | ||||||
|         # Slice provided |         # Slice provided | ||||||
| @@ -245,6 +275,8 @@ class BaseQuerySet(object): | |||||||
|         except StopIteration: |         except StopIteration: | ||||||
|             return result |             return result | ||||||
|  |  | ||||||
|  |         # If we were able to retrieve the 2nd doc, rewind the cursor and | ||||||
|  |         # raise the MultipleObjectsReturned exception. | ||||||
|         queryset.rewind() |         queryset.rewind() | ||||||
|         message = u'%d items returned, instead of 1' % queryset.count() |         message = u'%d items returned, instead of 1' % queryset.count() | ||||||
|         raise queryset._document.MultipleObjectsReturned(message) |         raise queryset._document.MultipleObjectsReturned(message) | ||||||
| @@ -266,7 +298,8 @@ class BaseQuerySet(object): | |||||||
|             result = None |             result = None | ||||||
|         return result |         return result | ||||||
|  |  | ||||||
|     def insert(self, doc_or_docs, load_bulk=True, write_concern=None): |     def insert(self, doc_or_docs, load_bulk=True, | ||||||
|  |                write_concern=None, signal_kwargs=None): | ||||||
|         """bulk insert documents |         """bulk insert documents | ||||||
|  |  | ||||||
|         :param doc_or_docs: a document or list of documents to be inserted |         :param doc_or_docs: a document or list of documents to be inserted | ||||||
| @@ -279,11 +312,15 @@ class BaseQuerySet(object): | |||||||
|                 ``insert(..., {w: 2, fsync: True})`` will wait until at least |                 ``insert(..., {w: 2, fsync: True})`` will wait until at least | ||||||
|                 two servers have recorded the write and will force an fsync on |                 two servers have recorded the write and will force an fsync on | ||||||
|                 each server being written to. |                 each server being written to. | ||||||
|  |         :parm signal_kwargs: (optional) kwargs dictionary to be passed to | ||||||
|  |             the signal calls. | ||||||
|  |  | ||||||
|         By default returns document instances, set ``load_bulk`` to False to |         By default returns document instances, set ``load_bulk`` to False to | ||||||
|         return just ``ObjectIds`` |         return just ``ObjectIds`` | ||||||
|  |  | ||||||
|         .. versionadded:: 0.5 |         .. versionadded:: 0.5 | ||||||
|  |         .. versionchanged:: 0.10.7 | ||||||
|  |             Add signal_kwargs argument | ||||||
|         """ |         """ | ||||||
|         Document = _import_class('Document') |         Document = _import_class('Document') | ||||||
|  |  | ||||||
| @@ -296,7 +333,6 @@ class BaseQuerySet(object): | |||||||
|             return_one = True |             return_one = True | ||||||
|             docs = [docs] |             docs = [docs] | ||||||
|  |  | ||||||
|         raw = [] |  | ||||||
|         for doc in docs: |         for doc in docs: | ||||||
|             if not isinstance(doc, self._document): |             if not isinstance(doc, self._document): | ||||||
|                 msg = ("Some documents inserted aren't instances of %s" |                 msg = ("Some documents inserted aren't instances of %s" | ||||||
| @@ -305,9 +341,12 @@ class BaseQuerySet(object): | |||||||
|             if doc.pk and not doc._created: |             if doc.pk and not doc._created: | ||||||
|                 msg = "Some documents have ObjectIds use doc.update() instead" |                 msg = "Some documents have ObjectIds use doc.update() instead" | ||||||
|                 raise OperationError(msg) |                 raise OperationError(msg) | ||||||
|             raw.append(doc.to_mongo()) |  | ||||||
|  |  | ||||||
|         signals.pre_bulk_insert.send(self._document, documents=docs) |         signal_kwargs = signal_kwargs or {} | ||||||
|  |         signals.pre_bulk_insert.send(self._document, | ||||||
|  |                                      documents=docs, **signal_kwargs) | ||||||
|  |  | ||||||
|  |         raw = [doc.to_mongo() for doc in docs] | ||||||
|         try: |         try: | ||||||
|             ids = self._collection.insert(raw, **write_concern) |             ids = self._collection.insert(raw, **write_concern) | ||||||
|         except pymongo.errors.DuplicateKeyError, err: |         except pymongo.errors.DuplicateKeyError, err: | ||||||
| @@ -324,7 +363,7 @@ class BaseQuerySet(object): | |||||||
|  |  | ||||||
|         if not load_bulk: |         if not load_bulk: | ||||||
|             signals.post_bulk_insert.send( |             signals.post_bulk_insert.send( | ||||||
|                 self._document, documents=docs, loaded=False) |                 self._document, documents=docs, loaded=False, **signal_kwargs) | ||||||
|             return return_one and ids[0] or ids |             return return_one and ids[0] or ids | ||||||
|  |  | ||||||
|         documents = self.in_bulk(ids) |         documents = self.in_bulk(ids) | ||||||
| @@ -332,7 +371,7 @@ class BaseQuerySet(object): | |||||||
|         for obj_id in ids: |         for obj_id in ids: | ||||||
|             results.append(documents.get(obj_id)) |             results.append(documents.get(obj_id)) | ||||||
|         signals.post_bulk_insert.send( |         signals.post_bulk_insert.send( | ||||||
|             self._document, documents=results, loaded=True) |             self._document, documents=results, loaded=True, **signal_kwargs) | ||||||
|         return return_one and results[0] or results |         return return_one and results[0] or results | ||||||
|  |  | ||||||
|     def count(self, with_limit_and_skip=False): |     def count(self, with_limit_and_skip=False): | ||||||
| @@ -403,6 +442,8 @@ class BaseQuerySet(object): | |||||||
|             rule = doc._meta['delete_rules'][rule_entry] |             rule = doc._meta['delete_rules'][rule_entry] | ||||||
|             if rule == CASCADE: |             if rule == CASCADE: | ||||||
|                 cascade_refs = set() if cascade_refs is None else cascade_refs |                 cascade_refs = set() if cascade_refs is None else cascade_refs | ||||||
|  |                 # Handle recursive reference | ||||||
|  |                 if doc._collection == document_cls._collection: | ||||||
|                     for ref in queryset: |                     for ref in queryset: | ||||||
|                         cascade_refs.add(ref.id) |                         cascade_refs.add(ref.id) | ||||||
|                 ref_q = document_cls.objects(**{field_name + '__in': self, 'id__nin': cascade_refs}) |                 ref_q = document_cls.objects(**{field_name + '__in': self, 'id__nin': cascade_refs}) | ||||||
| @@ -425,7 +466,7 @@ class BaseQuerySet(object): | |||||||
|                full_result=False, **update): |                full_result=False, **update): | ||||||
|         """Perform an atomic update on the fields matched by the query. |         """Perform an atomic update on the fields matched by the query. | ||||||
|  |  | ||||||
|         :param upsert: Any existing document with that "_id" is overwritten. |         :param upsert: insert if document doesn't exist (default ``False``) | ||||||
|         :param multi: Update multiple documents. |         :param multi: Update multiple documents. | ||||||
|         :param write_concern: Extra keyword arguments are passed down which |         :param write_concern: Extra keyword arguments are passed down which | ||||||
|             will be used as options for the resultant |             will be used as options for the resultant | ||||||
| @@ -471,11 +512,37 @@ class BaseQuerySet(object): | |||||||
|                 raise OperationError(message) |                 raise OperationError(message) | ||||||
|             raise OperationError(u'Update failed (%s)' % unicode(err)) |             raise OperationError(u'Update failed (%s)' % unicode(err)) | ||||||
|  |  | ||||||
|  |     def upsert_one(self, write_concern=None, **update): | ||||||
|  |         """Overwrite or add the first document matched by the query. | ||||||
|  |  | ||||||
|  |         :param write_concern: Extra keyword arguments are passed down which | ||||||
|  |             will be used as options for the resultant | ||||||
|  |             ``getLastError`` command.  For example, | ||||||
|  |             ``save(..., write_concern={w: 2, fsync: True}, ...)`` will | ||||||
|  |             wait until at least two servers have recorded the write and | ||||||
|  |             will force an fsync on the primary server. | ||||||
|  |         :param update: Django-style update keyword arguments | ||||||
|  |  | ||||||
|  |         :returns the new or overwritten document | ||||||
|  |  | ||||||
|  |         .. versionadded:: 0.10.2 | ||||||
|  |         """ | ||||||
|  |  | ||||||
|  |         atomic_update = self.update(multi=False, upsert=True, | ||||||
|  |                                     write_concern=write_concern, | ||||||
|  |                                     full_result=True, **update) | ||||||
|  |  | ||||||
|  |         if atomic_update['updatedExisting']: | ||||||
|  |             document = self.get() | ||||||
|  |         else: | ||||||
|  |             document = self._document.objects.with_id(atomic_update['upserted']) | ||||||
|  |         return document | ||||||
|  |  | ||||||
|     def update_one(self, upsert=False, write_concern=None, **update): |     def update_one(self, upsert=False, write_concern=None, **update): | ||||||
|         """Perform an atomic update on the fields of the first document |         """Perform an atomic update on the fields of the first document | ||||||
|         matched by the query. |         matched by the query. | ||||||
|  |  | ||||||
|         :param upsert: Any existing document with that "_id" is overwritten. |         :param upsert: insert if document doesn't exist (default ``False``) | ||||||
|         :param write_concern: Extra keyword arguments are passed down which |         :param write_concern: Extra keyword arguments are passed down which | ||||||
|             will be used as options for the resultant |             will be used as options for the resultant | ||||||
|             ``getLastError`` command.  For example, |             ``getLastError`` command.  For example, | ||||||
| @@ -868,6 +935,14 @@ class BaseQuerySet(object): | |||||||
|         queryset._ordering = queryset._get_order_by(keys) |         queryset._ordering = queryset._get_order_by(keys) | ||||||
|         return queryset |         return queryset | ||||||
|  |  | ||||||
|  |     def comment(self, text): | ||||||
|  |         """Add a comment to the query. | ||||||
|  |  | ||||||
|  |         See https://docs.mongodb.com/manual/reference/method/cursor.comment/#cursor.comment | ||||||
|  |         for details. | ||||||
|  |         """ | ||||||
|  |         return self._chainable_method("comment", text) | ||||||
|  |  | ||||||
|     def explain(self, format=False): |     def explain(self, format=False): | ||||||
|         """Return an explain plan record for the |         """Return an explain plan record for the | ||||||
|         :class:`~mongoengine.queryset.QuerySet`\ 's cursor. |         :class:`~mongoengine.queryset.QuerySet`\ 's cursor. | ||||||
| @@ -930,6 +1005,7 @@ class BaseQuerySet(object): | |||||||
|         validate_read_preference('read_preference', read_preference) |         validate_read_preference('read_preference', read_preference) | ||||||
|         queryset = self.clone() |         queryset = self.clone() | ||||||
|         queryset._read_preference = read_preference |         queryset._read_preference = read_preference | ||||||
|  |         queryset._cursor_obj = None  # we need to re-create the cursor object whenever we apply read_preference | ||||||
|         return queryset |         return queryset | ||||||
|  |  | ||||||
|     def scalar(self, *fields): |     def scalar(self, *fields): | ||||||
| @@ -1202,66 +1278,29 @@ class BaseQuerySet(object): | |||||||
|     def sum(self, field): |     def sum(self, field): | ||||||
|         """Sum over the values of the specified field. |         """Sum over the values of the specified field. | ||||||
|  |  | ||||||
|         :param field: the field to sum over; use dot-notation to refer to |         :param field: the field to sum over; use dot notation to refer to | ||||||
|             embedded document fields |             embedded document fields | ||||||
|  |  | ||||||
|         .. versionchanged:: 0.5 - updated to map_reduce as db.eval doesnt work |  | ||||||
|             with sharding. |  | ||||||
|         """ |         """ | ||||||
|         map_func = """ |         db_field = self._fields_to_dbfields([field]).pop() | ||||||
|             function() { |         pipeline = [ | ||||||
|                 var path = '{{~%(field)s}}'.split('.'), |  | ||||||
|                 field = this; |  | ||||||
|  |  | ||||||
|                 for (p in path) { |  | ||||||
|                     if (typeof field != 'undefined') |  | ||||||
|                        field = field[path[p]]; |  | ||||||
|                     else |  | ||||||
|                        break; |  | ||||||
|                 } |  | ||||||
|  |  | ||||||
|                 if (field && field.constructor == Array) { |  | ||||||
|                     field.forEach(function(item) { |  | ||||||
|                         emit(1, item||0); |  | ||||||
|                     }); |  | ||||||
|                 } else if (typeof field != 'undefined') { |  | ||||||
|                     emit(1, field||0); |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
|         """ % dict(field=field) |  | ||||||
|  |  | ||||||
|         reduce_func = Code(""" |  | ||||||
|             function(key, values) { |  | ||||||
|                 var sum = 0; |  | ||||||
|                 for (var i in values) { |  | ||||||
|                     sum += values[i]; |  | ||||||
|                 } |  | ||||||
|                 return sum; |  | ||||||
|             } |  | ||||||
|         """) |  | ||||||
|  |  | ||||||
|         for result in self.map_reduce(map_func, reduce_func, output='inline'): |  | ||||||
|             return result.value |  | ||||||
|         else: |  | ||||||
|             return 0 |  | ||||||
|  |  | ||||||
|     def aggregate_sum(self, field): |  | ||||||
|         """Sum over the values of the specified field. |  | ||||||
|  |  | ||||||
|         :param field: the field to sum over; use dot-notation to refer to |  | ||||||
|             embedded document fields |  | ||||||
|  |  | ||||||
|         This method is more performant than the regular `sum`, because it uses |  | ||||||
|         the aggregation framework instead of map-reduce. |  | ||||||
|         """ |  | ||||||
|         result = self._document._get_collection().aggregate([ |  | ||||||
|             {'$match': self._query}, |             {'$match': self._query}, | ||||||
|             {'$group': {'_id': 'sum', 'total': {'$sum': '$' + field}}} |             {'$group': {'_id': 'sum', 'total': {'$sum': '$' + db_field}}} | ||||||
|         ]) |         ] | ||||||
|  |  | ||||||
|  |         # if we're performing a sum over a list field, we sum up all the | ||||||
|  |         # elements in the list, hence we need to $unwind the arrays first | ||||||
|  |         ListField = _import_class('ListField') | ||||||
|  |         field_parts = field.split('.') | ||||||
|  |         field_instances = self._document._lookup_field(field_parts) | ||||||
|  |         if isinstance(field_instances[-1], ListField): | ||||||
|  |             pipeline.insert(1, {'$unwind': '$' + field}) | ||||||
|  |  | ||||||
|  |         result = self._document._get_collection().aggregate(pipeline) | ||||||
|         if IS_PYMONGO_3: |         if IS_PYMONGO_3: | ||||||
|             result = list(result) |             result = tuple(result) | ||||||
|         else: |         else: | ||||||
|             result = result.get('result') |             result = result.get('result') | ||||||
|  |  | ||||||
|         if result: |         if result: | ||||||
|             return result[0]['total'] |             return result[0]['total'] | ||||||
|         return 0 |         return 0 | ||||||
| @@ -1269,73 +1308,27 @@ class BaseQuerySet(object): | |||||||
|     def average(self, field): |     def average(self, field): | ||||||
|         """Average over the values of the specified field. |         """Average over the values of the specified field. | ||||||
|  |  | ||||||
|         :param field: the field to average over; use dot-notation to refer to |         :param field: the field to average over; use dot notation to refer to | ||||||
|             embedded document fields |             embedded document fields | ||||||
|  |  | ||||||
|         .. versionchanged:: 0.5 - updated to map_reduce as db.eval doesnt work |  | ||||||
|             with sharding. |  | ||||||
|         """ |         """ | ||||||
|         map_func = """ |         db_field = self._fields_to_dbfields([field]).pop() | ||||||
|             function() { |         pipeline = [ | ||||||
|                 var path = '{{~%(field)s}}'.split('.'), |  | ||||||
|                 field = this; |  | ||||||
|  |  | ||||||
|                 for (p in path) { |  | ||||||
|                     if (typeof field != 'undefined') |  | ||||||
|                        field = field[path[p]]; |  | ||||||
|                     else |  | ||||||
|                        break; |  | ||||||
|                 } |  | ||||||
|  |  | ||||||
|                 if (field && field.constructor == Array) { |  | ||||||
|                     field.forEach(function(item) { |  | ||||||
|                         emit(1, {t: item||0, c: 1}); |  | ||||||
|                     }); |  | ||||||
|                 } else if (typeof field != 'undefined') { |  | ||||||
|                     emit(1, {t: field||0, c: 1}); |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
|         """ % dict(field=field) |  | ||||||
|  |  | ||||||
|         reduce_func = Code(""" |  | ||||||
|             function(key, values) { |  | ||||||
|                 var out = {t: 0, c: 0}; |  | ||||||
|                 for (var i in values) { |  | ||||||
|                     var value = values[i]; |  | ||||||
|                     out.t += value.t; |  | ||||||
|                     out.c += value.c; |  | ||||||
|                 } |  | ||||||
|                 return out; |  | ||||||
|             } |  | ||||||
|         """) |  | ||||||
|  |  | ||||||
|         finalize_func = Code(""" |  | ||||||
|             function(key, value) { |  | ||||||
|                 return value.t / value.c; |  | ||||||
|             } |  | ||||||
|         """) |  | ||||||
|  |  | ||||||
|         for result in self.map_reduce(map_func, reduce_func, |  | ||||||
|                                       finalize_f=finalize_func, output='inline'): |  | ||||||
|             return result.value |  | ||||||
|         else: |  | ||||||
|             return 0 |  | ||||||
|  |  | ||||||
|     def aggregate_average(self, field): |  | ||||||
|         """Average over the values of the specified field. |  | ||||||
|  |  | ||||||
|         :param field: the field to average over; use dot-notation to refer to |  | ||||||
|             embedded document fields |  | ||||||
|  |  | ||||||
|         This method is more performant than the regular `average`, because it |  | ||||||
|         uses the aggregation framework instead of map-reduce. |  | ||||||
|         """ |  | ||||||
|         result = self._document._get_collection().aggregate([ |  | ||||||
|             {'$match': self._query}, |             {'$match': self._query}, | ||||||
|             {'$group': {'_id': 'avg', 'total': {'$avg': '$' + field}}} |             {'$group': {'_id': 'avg', 'total': {'$avg': '$' + db_field}}} | ||||||
|         ]) |         ] | ||||||
|  |  | ||||||
|  |         # if we're performing an average over a list field, we average out | ||||||
|  |         # all the elements in the list, hence we need to $unwind the arrays | ||||||
|  |         # first | ||||||
|  |         ListField = _import_class('ListField') | ||||||
|  |         field_parts = field.split('.') | ||||||
|  |         field_instances = self._document._lookup_field(field_parts) | ||||||
|  |         if isinstance(field_instances[-1], ListField): | ||||||
|  |             pipeline.insert(1, {'$unwind': '$' + field}) | ||||||
|  |  | ||||||
|  |         result = self._document._get_collection().aggregate(pipeline) | ||||||
|         if IS_PYMONGO_3: |         if IS_PYMONGO_3: | ||||||
|             result = list(result) |             result = tuple(result) | ||||||
|         else: |         else: | ||||||
|             result = result.get('result') |             result = result.get('result') | ||||||
|         if result: |         if result: | ||||||
| @@ -1352,7 +1345,7 @@ class BaseQuerySet(object): | |||||||
|             Can only do direct simple mappings and cannot map across |             Can only do direct simple mappings and cannot map across | ||||||
|             :class:`~mongoengine.fields.ReferenceField` or |             :class:`~mongoengine.fields.ReferenceField` or | ||||||
|             :class:`~mongoengine.fields.GenericReferenceField` for more complex |             :class:`~mongoengine.fields.GenericReferenceField` for more complex | ||||||
|             counting a manual map reduce call would is required. |             counting a manual map reduce call is required. | ||||||
|  |  | ||||||
|         If the field is a :class:`~mongoengine.fields.ListField`, the items within |         If the field is a :class:`~mongoengine.fields.ListField`, the items within | ||||||
|         each list will be counted individually. |         each list will be counted individually. | ||||||
| @@ -1426,7 +1419,7 @@ class BaseQuerySet(object): | |||||||
|                 msg = "The snapshot option is not anymore available with PyMongo 3+" |                 msg = "The snapshot option is not anymore available with PyMongo 3+" | ||||||
|                 warnings.warn(msg, DeprecationWarning) |                 warnings.warn(msg, DeprecationWarning) | ||||||
|             cursor_args = { |             cursor_args = { | ||||||
|                 'no_cursor_timeout': self._timeout |                 'no_cursor_timeout': not self._timeout | ||||||
|             } |             } | ||||||
|         if self._loaded_fields: |         if self._loaded_fields: | ||||||
|             cursor_args[fields_name] = self._loaded_fields.as_dict() |             cursor_args[fields_name] = self._loaded_fields.as_dict() | ||||||
| @@ -1443,6 +1436,14 @@ class BaseQuerySet(object): | |||||||
|     def _cursor(self): |     def _cursor(self): | ||||||
|         if self._cursor_obj is None: |         if self._cursor_obj is None: | ||||||
|  |  | ||||||
|  |             # In PyMongo 3+, we define the read preference on a collection | ||||||
|  |             # level, not a cursor level. Thus, we need to get a cloned | ||||||
|  |             # collection object using `with_options` first. | ||||||
|  |             if IS_PYMONGO_3 and self._read_preference is not None: | ||||||
|  |                 self._cursor_obj = self._collection\ | ||||||
|  |                     .with_options(read_preference=self._read_preference)\ | ||||||
|  |                     .find(self._query, **self._cursor_args) | ||||||
|  |             else: | ||||||
|                 self._cursor_obj = self._collection.find(self._query, |                 self._cursor_obj = self._collection.find(self._query, | ||||||
|                                                          **self._cursor_args) |                                                          **self._cursor_args) | ||||||
|             # Apply where clauses to cursor |             # Apply where clauses to cursor | ||||||
| @@ -1661,7 +1662,7 @@ class BaseQuerySet(object): | |||||||
|             key = key.replace('__', '.') |             key = key.replace('__', '.') | ||||||
|             try: |             try: | ||||||
|                 key = self._document._translate_field_name(key) |                 key = self._document._translate_field_name(key) | ||||||
|             except: |             except Exception: | ||||||
|                 pass |                 pass | ||||||
|             key_list.append((key, direction)) |             key_list.append((key, direction)) | ||||||
|  |  | ||||||
|   | |||||||
| @@ -29,7 +29,7 @@ class QuerySetManager(object): | |||||||
|         Document.objects is accessed. |         Document.objects is accessed. | ||||||
|         """ |         """ | ||||||
|         if instance is not None: |         if instance is not None: | ||||||
|             # Document class being used rather than a document object |             # Document object being used rather than a document class | ||||||
|             return self |             return self | ||||||
|  |  | ||||||
|         # owner is the document that contains the QuerySetManager |         # owner is the document that contains the QuerySetManager | ||||||
|   | |||||||
| @@ -1,6 +1,6 @@ | |||||||
| from mongoengine.errors import OperationError | from mongoengine.errors import OperationError | ||||||
| from mongoengine.queryset.base import (BaseQuerySet, DO_NOTHING, NULLIFY, | from mongoengine.queryset.base import (BaseQuerySet, CASCADE, DENY, DO_NOTHING, | ||||||
|                                        CASCADE, DENY, PULL) |                                        NULLIFY, PULL) | ||||||
|  |  | ||||||
| __all__ = ('QuerySet', 'QuerySetNoCache', 'DO_NOTHING', 'NULLIFY', 'CASCADE', | __all__ = ('QuerySet', 'QuerySetNoCache', 'DO_NOTHING', 'NULLIFY', 'CASCADE', | ||||||
|            'DENY', 'PULL') |            'DENY', 'PULL') | ||||||
| @@ -30,6 +30,7 @@ class QuerySet(BaseQuerySet): | |||||||
|         batch. Otherwise iterate the result_cache. |         batch. Otherwise iterate the result_cache. | ||||||
|         """ |         """ | ||||||
|         self._iter = True |         self._iter = True | ||||||
|  |  | ||||||
|         if self._has_more: |         if self._has_more: | ||||||
|             return self._iter_results() |             return self._iter_results() | ||||||
|  |  | ||||||
| @@ -38,14 +39,16 @@ class QuerySet(BaseQuerySet): | |||||||
|  |  | ||||||
|     def __len__(self): |     def __len__(self): | ||||||
|         """Since __len__ is called quite frequently (for example, as part of |         """Since __len__ is called quite frequently (for example, as part of | ||||||
|         list(qs) we populate the result cache and cache the length. |         list(qs)), we populate the result cache and cache the length. | ||||||
|         """ |         """ | ||||||
|         if self._len is not None: |         if self._len is not None: | ||||||
|             return self._len |             return self._len | ||||||
|  |  | ||||||
|  |         # Populate the result cache with *all* of the docs in the cursor | ||||||
|         if self._has_more: |         if self._has_more: | ||||||
|             # populate the cache |  | ||||||
|             list(self._iter_results()) |             list(self._iter_results()) | ||||||
|  |  | ||||||
|  |         # Cache the length of the complete result cache and return it | ||||||
|         self._len = len(self._result_cache) |         self._len = len(self._result_cache) | ||||||
|         return self._len |         return self._len | ||||||
|  |  | ||||||
| @@ -64,18 +67,33 @@ class QuerySet(BaseQuerySet): | |||||||
|     def _iter_results(self): |     def _iter_results(self): | ||||||
|         """A generator for iterating over the result cache. |         """A generator for iterating over the result cache. | ||||||
|  |  | ||||||
|         Also populates the cache if there are more possible results to yield. |         Also populates the cache if there are more possible results to | ||||||
|         Raises StopIteration when there are no more results""" |         yield. Raises StopIteration when there are no more results. | ||||||
|  |         """ | ||||||
|         if self._result_cache is None: |         if self._result_cache is None: | ||||||
|             self._result_cache = [] |             self._result_cache = [] | ||||||
|  |  | ||||||
|         pos = 0 |         pos = 0 | ||||||
|         while True: |         while True: | ||||||
|             upper = len(self._result_cache) |  | ||||||
|             while pos < upper: |             # For all positions lower than the length of the current result | ||||||
|  |             # cache, serve the docs straight from the cache w/o hitting the | ||||||
|  |             # database. | ||||||
|  |             # XXX it's VERY important to compute the len within the `while` | ||||||
|  |             # condition because the result cache might expand mid-iteration | ||||||
|  |             # (e.g. if we call len(qs) inside a loop that iterates over the | ||||||
|  |             # queryset). Fortunately len(list) is O(1) in Python, so this | ||||||
|  |             # doesn't cause performance issues. | ||||||
|  |             while pos < len(self._result_cache): | ||||||
|                 yield self._result_cache[pos] |                 yield self._result_cache[pos] | ||||||
|                 pos += 1 |                 pos += 1 | ||||||
|  |  | ||||||
|  |             # Raise StopIteration if we already established there were no more | ||||||
|  |             # docs in the db cursor. | ||||||
|             if not self._has_more: |             if not self._has_more: | ||||||
|                 raise StopIteration |                 raise StopIteration | ||||||
|  |  | ||||||
|  |             # Otherwise, populate more of the cache and repeat. | ||||||
|             if len(self._result_cache) <= pos: |             if len(self._result_cache) <= pos: | ||||||
|                 self._populate_cache() |                 self._populate_cache() | ||||||
|  |  | ||||||
| @@ -86,11 +104,21 @@ class QuerySet(BaseQuerySet): | |||||||
|         """ |         """ | ||||||
|         if self._result_cache is None: |         if self._result_cache is None: | ||||||
|             self._result_cache = [] |             self._result_cache = [] | ||||||
|         if self._has_more: |  | ||||||
|  |         # Skip populating the cache if we already established there are no | ||||||
|  |         # more docs to pull from the database. | ||||||
|  |         if not self._has_more: | ||||||
|  |             return | ||||||
|  |  | ||||||
|  |         # Pull in ITER_CHUNK_SIZE docs from the database and store them in | ||||||
|  |         # the result cache. | ||||||
|         try: |         try: | ||||||
|             for i in xrange(ITER_CHUNK_SIZE): |             for i in xrange(ITER_CHUNK_SIZE): | ||||||
|                 self._result_cache.append(self.next()) |                 self._result_cache.append(self.next()) | ||||||
|         except StopIteration: |         except StopIteration: | ||||||
|  |             # Getting this exception means there are no more docs in the | ||||||
|  |             # db cursor. Set _has_more to False so that we can use that | ||||||
|  |             # information in other places. | ||||||
|             self._has_more = False |             self._has_more = False | ||||||
|  |  | ||||||
|     def count(self, with_limit_and_skip=False): |     def count(self, with_limit_and_skip=False): | ||||||
|   | |||||||
| @@ -1,11 +1,11 @@ | |||||||
| from collections import defaultdict | from collections import defaultdict | ||||||
|  |  | ||||||
| import pymongo |  | ||||||
| from bson import SON | from bson import SON | ||||||
|  | import pymongo | ||||||
|  |  | ||||||
| from mongoengine.base.fields import UPDATE_OPERATORS | from mongoengine.base.fields import UPDATE_OPERATORS | ||||||
| from mongoengine.connection import get_connection |  | ||||||
| from mongoengine.common import _import_class | from mongoengine.common import _import_class | ||||||
|  | from mongoengine.connection import get_connection | ||||||
| from mongoengine.errors import InvalidQueryError | from mongoengine.errors import InvalidQueryError | ||||||
| from mongoengine.python_support import IS_PYMONGO_3 | from mongoengine.python_support import IS_PYMONGO_3 | ||||||
|  |  | ||||||
| @@ -44,7 +44,7 @@ def query(_doc_cls=None, **kwargs): | |||||||
|         if len(parts) > 1 and parts[-1] in MATCH_OPERATORS: |         if len(parts) > 1 and parts[-1] in MATCH_OPERATORS: | ||||||
|             op = parts.pop() |             op = parts.pop() | ||||||
|  |  | ||||||
|         # Allw to escape operator-like field name by __ |         # Allow to escape operator-like field name by __ | ||||||
|         if len(parts) > 1 and parts[-1] == "": |         if len(parts) > 1 and parts[-1] == "": | ||||||
|             parts.pop() |             parts.pop() | ||||||
|  |  | ||||||
| @@ -108,8 +108,11 @@ def query(_doc_cls=None, **kwargs): | |||||||
|             elif op in ('match', 'elemMatch'): |             elif op in ('match', 'elemMatch'): | ||||||
|                 ListField = _import_class('ListField') |                 ListField = _import_class('ListField') | ||||||
|                 EmbeddedDocumentField = _import_class('EmbeddedDocumentField') |                 EmbeddedDocumentField = _import_class('EmbeddedDocumentField') | ||||||
|                 if (isinstance(value, dict) and isinstance(field, ListField) and |                 if ( | ||||||
|                     isinstance(field.field, EmbeddedDocumentField)): |                     isinstance(value, dict) and | ||||||
|  |                     isinstance(field, ListField) and | ||||||
|  |                     isinstance(field.field, EmbeddedDocumentField) | ||||||
|  |                 ): | ||||||
|                     value = query(field.field.document_type, **value) |                     value = query(field.field.document_type, **value) | ||||||
|                 else: |                 else: | ||||||
|                     value = field.prepare_query_value(op, value) |                     value = field.prepare_query_value(op, value) | ||||||
| @@ -212,6 +215,10 @@ def update(_doc_cls=None, **update): | |||||||
|         if parts[-1] in COMPARISON_OPERATORS: |         if parts[-1] in COMPARISON_OPERATORS: | ||||||
|             match = parts.pop() |             match = parts.pop() | ||||||
|  |  | ||||||
|  |         # Allow to escape operator-like field name by __ | ||||||
|  |         if len(parts) > 1 and parts[-1] == "": | ||||||
|  |             parts.pop() | ||||||
|  |  | ||||||
|         if _doc_cls: |         if _doc_cls: | ||||||
|             # Switch field names to proper names [set in Field(name='foo')] |             # Switch field names to proper names [set in Field(name='foo')] | ||||||
|             try: |             try: | ||||||
| @@ -364,20 +371,24 @@ def _infer_geometry(value): | |||||||
|                                 "type and coordinates keys") |                                 "type and coordinates keys") | ||||||
|     elif isinstance(value, (list, set)): |     elif isinstance(value, (list, set)): | ||||||
|         # TODO: shouldn't we test value[0][0][0][0] to see if it is MultiPolygon? |         # TODO: shouldn't we test value[0][0][0][0] to see if it is MultiPolygon? | ||||||
|  |         # TODO: should both TypeError and IndexError be alike interpreted? | ||||||
|  |  | ||||||
|         try: |         try: | ||||||
|             value[0][0][0] |             value[0][0][0] | ||||||
|             return {"$geometry": {"type": "Polygon", "coordinates": value}} |             return {"$geometry": {"type": "Polygon", "coordinates": value}} | ||||||
|         except: |         except (TypeError, IndexError): | ||||||
|             pass |             pass | ||||||
|  |  | ||||||
|         try: |         try: | ||||||
|             value[0][0] |             value[0][0] | ||||||
|             return {"$geometry": {"type": "LineString", "coordinates": value}} |             return {"$geometry": {"type": "LineString", "coordinates": value}} | ||||||
|         except: |         except (TypeError, IndexError): | ||||||
|             pass |             pass | ||||||
|  |  | ||||||
|         try: |         try: | ||||||
|             value[0] |             value[0] | ||||||
|             return {"$geometry": {"type": "Point", "coordinates": value}} |             return {"$geometry": {"type": "Point", "coordinates": value}} | ||||||
|         except: |         except (TypeError, IndexError): | ||||||
|             pass |             pass | ||||||
|  |  | ||||||
|     raise InvalidQueryError("Invalid $geometry data. Can be either a dictionary " |     raise InvalidQueryError("Invalid $geometry data. Can be either a dictionary " | ||||||
|   | |||||||
| @@ -29,7 +29,7 @@ except ImportError: | |||||||
|                                'because the blinker library is ' |                                'because the blinker library is ' | ||||||
|                                'not installed.') |                                'not installed.') | ||||||
|  |  | ||||||
|         send = lambda *a, **kw: None |         send = lambda *a, **kw: None  # noqa | ||||||
|         connect = disconnect = has_receivers_for = receivers_for = \ |         connect = disconnect = has_receivers_for = receivers_for = \ | ||||||
|             temporarily_connected_to = _fail |             temporarily_connected_to = _fail | ||||||
|         del _fail |         del _fail | ||||||
|   | |||||||
| @@ -1,2 +1,5 @@ | |||||||
| pymongo>=2.7.1 |  | ||||||
| nose | nose | ||||||
|  | pymongo>=2.7.1 | ||||||
|  | six==1.10.0 | ||||||
|  | flake8 | ||||||
|  | flake8-import-order | ||||||
|   | |||||||
| @@ -1,8 +1,13 @@ | |||||||
| [nosetests] | [nosetests] | ||||||
| rednose = 1 |  | ||||||
| verbosity = 2 | verbosity = 2 | ||||||
| detailed-errors = 1 | detailed-errors = 1 | ||||||
| cover-erase = 1 | cover-erase = 1 | ||||||
| cover-branches = 1 | cover-branches = 1 | ||||||
| cover-package = mongoengine | cover-package = mongoengine | ||||||
| tests = tests | tests = tests | ||||||
|  |  | ||||||
|  | [flake8] | ||||||
|  | ignore=E501,F401,F403,F405,I201 | ||||||
|  | exclude=build,dist,docs,venv,.tox,.eggs,tests | ||||||
|  | max-complexity=42 | ||||||
|  | application-import-names=mongoengine,tests | ||||||
|   | |||||||
							
								
								
									
										25
									
								
								setup.py
									
									
									
									
									
								
							
							
						
						
									
										25
									
								
								setup.py
									
									
									
									
									
								
							| @@ -1,6 +1,6 @@ | |||||||
| import os | import os | ||||||
| import sys | import sys | ||||||
| from setuptools import setup, find_packages | from setuptools import find_packages, setup | ||||||
|  |  | ||||||
| # Hack to silence atexit traceback in newer python versions | # Hack to silence atexit traceback in newer python versions | ||||||
| try: | try: | ||||||
| @@ -8,13 +8,16 @@ try: | |||||||
| except ImportError: | except ImportError: | ||||||
|     pass |     pass | ||||||
|  |  | ||||||
| DESCRIPTION = 'MongoEngine is a Python Object-Document ' + \ | DESCRIPTION = ( | ||||||
|  |     'MongoEngine is a Python Object-Document ' | ||||||
|     'Mapper for working with MongoDB.' |     'Mapper for working with MongoDB.' | ||||||
| LONG_DESCRIPTION = None | ) | ||||||
|  |  | ||||||
| try: | try: | ||||||
|     LONG_DESCRIPTION = open('README.rst').read() |     with open('README.rst') as fin: | ||||||
| except: |         LONG_DESCRIPTION = fin.read() | ||||||
|     pass | except Exception: | ||||||
|  |     LONG_DESCRIPTION = None | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_version(version_tuple): | def get_version(version_tuple): | ||||||
| @@ -22,6 +25,7 @@ def get_version(version_tuple): | |||||||
|         return '.'.join(map(str, version_tuple[:-1])) + version_tuple[-1] |         return '.'.join(map(str, version_tuple[:-1])) + version_tuple[-1] | ||||||
|     return '.'.join(map(str, version_tuple)) |     return '.'.join(map(str, version_tuple)) | ||||||
|  |  | ||||||
|  |  | ||||||
| # Dirty hack to get version number from monogengine/__init__.py - we can't | # Dirty hack to get version number from monogengine/__init__.py - we can't | ||||||
| # import it as it depends on PyMongo and PyMongo isn't installed until this | # import it as it depends on PyMongo and PyMongo isn't installed until this | ||||||
| # file is read | # file is read | ||||||
| @@ -52,18 +56,19 @@ CLASSIFIERS = [ | |||||||
| extra_opts = {"packages": find_packages(exclude=["tests", "tests.*"])} | extra_opts = {"packages": find_packages(exclude=["tests", "tests.*"])} | ||||||
| if sys.version_info[0] == 3: | if sys.version_info[0] == 3: | ||||||
|     extra_opts['use_2to3'] = True |     extra_opts['use_2to3'] = True | ||||||
|     extra_opts['tests_require'] = ['nose', 'rednose', 'coverage==3.7.1', 'blinker', 'Pillow>=2.0.0'] |     extra_opts['tests_require'] = ['nose', 'coverage==3.7.1', 'blinker', 'Pillow>=2.0.0'] | ||||||
|     if "test" in sys.argv or "nosetests" in sys.argv: |     if "test" in sys.argv or "nosetests" in sys.argv: | ||||||
|         extra_opts['packages'] = find_packages() |         extra_opts['packages'] = find_packages() | ||||||
|         extra_opts['package_data'] = {"tests": ["fields/mongoengine.png", "fields/mongodb_leaf.png"]} |         extra_opts['package_data'] = {"tests": ["fields/mongoengine.png", "fields/mongodb_leaf.png"]} | ||||||
| else: | else: | ||||||
|     # coverage 4 does not support Python 3.2 anymore |     # coverage 4 does not support Python 3.2 anymore | ||||||
|     extra_opts['tests_require'] = ['nose', 'rednose', 'coverage==3.7.1', 'blinker', 'Pillow>=2.0.0', 'python-dateutil'] |     extra_opts['tests_require'] = ['nose', 'coverage==3.7.1', 'blinker', 'Pillow>=2.0.0', 'python-dateutil'] | ||||||
|  |  | ||||||
|     if sys.version_info[0] == 2 and sys.version_info[1] == 6: |     if sys.version_info[0] == 2 and sys.version_info[1] == 6: | ||||||
|         extra_opts['tests_require'].append('unittest2') |         extra_opts['tests_require'].append('unittest2') | ||||||
|  |  | ||||||
| setup(name='mongoengine', | setup( | ||||||
|  |     name='mongoengine', | ||||||
|     version=VERSION, |     version=VERSION, | ||||||
|     author='Harry Marr', |     author='Harry Marr', | ||||||
|     author_email='harry.marr@{nospam}gmail.com', |     author_email='harry.marr@{nospam}gmail.com', | ||||||
| @@ -77,7 +82,7 @@ setup(name='mongoengine', | |||||||
|     long_description=LONG_DESCRIPTION, |     long_description=LONG_DESCRIPTION, | ||||||
|     platforms=['any'], |     platforms=['any'], | ||||||
|     classifiers=CLASSIFIERS, |     classifiers=CLASSIFIERS, | ||||||
|       install_requires=['pymongo>=2.7.1'], |     install_requires=['pymongo>=2.7.1', 'six'], | ||||||
|     test_suite='nose.collector', |     test_suite='nose.collector', | ||||||
|     **extra_opts |     **extra_opts | ||||||
| ) | ) | ||||||
|   | |||||||
| @@ -2,7 +2,6 @@ | |||||||
| import unittest | import unittest | ||||||
| import sys | import sys | ||||||
|  |  | ||||||
| sys.path[0:0] = [""] |  | ||||||
|  |  | ||||||
| import pymongo | import pymongo | ||||||
|  |  | ||||||
| @@ -32,10 +31,7 @@ class IndexesTest(unittest.TestCase): | |||||||
|         self.Person = Person |         self.Person = Person | ||||||
|  |  | ||||||
|     def tearDown(self): |     def tearDown(self): | ||||||
|         for collection in self.db.collection_names(): |         self.connection.drop_database(self.db) | ||||||
|             if 'system.' in collection: |  | ||||||
|                 continue |  | ||||||
|             self.db.drop_collection(collection) |  | ||||||
|  |  | ||||||
|     def test_indexes_document(self): |     def test_indexes_document(self): | ||||||
|         """Ensure that indexes are used when meta[indexes] is specified for |         """Ensure that indexes are used when meta[indexes] is specified for | ||||||
| @@ -822,33 +818,34 @@ class IndexesTest(unittest.TestCase): | |||||||
|             name = StringField(required=True) |             name = StringField(required=True) | ||||||
|             term = StringField(required=True) |             term = StringField(required=True) | ||||||
|  |  | ||||||
|         class Report(Document): |         class ReportEmbedded(Document): | ||||||
|             key = EmbeddedDocumentField(CompoundKey, primary_key=True) |             key = EmbeddedDocumentField(CompoundKey, primary_key=True) | ||||||
|             text = StringField() |             text = StringField() | ||||||
|  |  | ||||||
|         Report.drop_collection() |  | ||||||
|  |  | ||||||
|         my_key = CompoundKey(name="n", term="ok") |         my_key = CompoundKey(name="n", term="ok") | ||||||
|         report = Report(text="OK", key=my_key).save() |         report = ReportEmbedded(text="OK", key=my_key).save() | ||||||
|  |  | ||||||
|         self.assertEqual({'text': 'OK', '_id': {'term': 'ok', 'name': 'n'}}, |         self.assertEqual({'text': 'OK', '_id': {'term': 'ok', 'name': 'n'}}, | ||||||
|                          report.to_mongo()) |                          report.to_mongo()) | ||||||
|         self.assertEqual(report, Report.objects.get(pk=my_key)) |         self.assertEqual(report, ReportEmbedded.objects.get(pk=my_key)) | ||||||
|  |  | ||||||
|     def test_compound_key_dictfield(self): |     def test_compound_key_dictfield(self): | ||||||
|  |  | ||||||
|         class Report(Document): |         class ReportDictField(Document): | ||||||
|             key = DictField(primary_key=True) |             key = DictField(primary_key=True) | ||||||
|             text = StringField() |             text = StringField() | ||||||
|  |  | ||||||
|         Report.drop_collection() |  | ||||||
|  |  | ||||||
|         my_key = {"name": "n", "term": "ok"} |         my_key = {"name": "n", "term": "ok"} | ||||||
|         report = Report(text="OK", key=my_key).save() |         report = ReportDictField(text="OK", key=my_key).save() | ||||||
|  |  | ||||||
|         self.assertEqual({'text': 'OK', '_id': {'term': 'ok', 'name': 'n'}}, |         self.assertEqual({'text': 'OK', '_id': {'term': 'ok', 'name': 'n'}}, | ||||||
|                          report.to_mongo()) |                          report.to_mongo()) | ||||||
|         self.assertEqual(report, Report.objects.get(pk=my_key)) |  | ||||||
|  |         # We can't directly call ReportDictField.objects.get(pk=my_key), | ||||||
|  |         # because dicts are unordered, and if the order in MongoDB is | ||||||
|  |         # different than the one in `my_key`, this test will fail. | ||||||
|  |         self.assertEqual(report, ReportDictField.objects.get(pk__name=my_key['name'])) | ||||||
|  |         self.assertEqual(report, ReportDictField.objects.get(pk__term=my_key['term'])) | ||||||
|  |  | ||||||
|     def test_string_indexes(self): |     def test_string_indexes(self): | ||||||
|  |  | ||||||
| @@ -863,6 +860,20 @@ class IndexesTest(unittest.TestCase): | |||||||
|         self.assertTrue([('provider_ids.foo', 1)] in info) |         self.assertTrue([('provider_ids.foo', 1)] in info) | ||||||
|         self.assertTrue([('provider_ids.bar', 1)] in info) |         self.assertTrue([('provider_ids.bar', 1)] in info) | ||||||
|  |  | ||||||
|  |     def test_sparse_compound_indexes(self): | ||||||
|  |  | ||||||
|  |         class MyDoc(Document): | ||||||
|  |             provider_ids = DictField() | ||||||
|  |             meta = { | ||||||
|  |                 "indexes": [{'fields': ("provider_ids.foo", "provider_ids.bar"), | ||||||
|  |                              'sparse': True}], | ||||||
|  |             } | ||||||
|  |  | ||||||
|  |         info = MyDoc.objects._collection.index_information() | ||||||
|  |         self.assertEqual([('provider_ids.foo', 1), ('provider_ids.bar', 1)], | ||||||
|  |                          info['provider_ids.foo_1_provider_ids.bar_1']['key']) | ||||||
|  |         self.assertTrue(info['provider_ids.foo_1_provider_ids.bar_1']['sparse']) | ||||||
|  |  | ||||||
|     def test_text_indexes(self): |     def test_text_indexes(self): | ||||||
|  |  | ||||||
|         class Book(Document): |         class Book(Document): | ||||||
| @@ -895,10 +906,18 @@ class IndexesTest(unittest.TestCase): | |||||||
|  |  | ||||||
|         Issue #812 |         Issue #812 | ||||||
|         """ |         """ | ||||||
|  |         # Use a new connection and database since dropping the database could | ||||||
|  |         # cause concurrent tests to fail. | ||||||
|  |         connection = connect(db='tempdatabase', | ||||||
|  |                              alias='test_indexes_after_database_drop') | ||||||
|  |  | ||||||
|         class BlogPost(Document): |         class BlogPost(Document): | ||||||
|             title = StringField() |             title = StringField() | ||||||
|             slug = StringField(unique=True) |             slug = StringField(unique=True) | ||||||
|  |  | ||||||
|  |             meta = {'db_alias': 'test_indexes_after_database_drop'} | ||||||
|  |  | ||||||
|  |         try: | ||||||
|             BlogPost.drop_collection() |             BlogPost.drop_collection() | ||||||
|  |  | ||||||
|             # Create Post #1 |             # Create Post #1 | ||||||
| @@ -906,7 +925,7 @@ class IndexesTest(unittest.TestCase): | |||||||
|             post1.save() |             post1.save() | ||||||
|  |  | ||||||
|             # Drop the Database |             # Drop the Database | ||||||
|         self.connection.drop_database(BlogPost._get_db().name) |             connection.drop_database('tempdatabase') | ||||||
|  |  | ||||||
|             # Re-create Post #1 |             # Re-create Post #1 | ||||||
|             post1 = BlogPost(title='test1', slug='test') |             post1 = BlogPost(title='test1', slug='test') | ||||||
| @@ -915,6 +934,10 @@ class IndexesTest(unittest.TestCase): | |||||||
|             # Create Post #2 |             # Create Post #2 | ||||||
|             post2 = BlogPost(title='test2', slug='test') |             post2 = BlogPost(title='test2', slug='test') | ||||||
|             self.assertRaises(NotUniqueError, post2.save) |             self.assertRaises(NotUniqueError, post2.save) | ||||||
|  |         finally: | ||||||
|  |             # Drop the temporary database at the end | ||||||
|  |             connection.drop_database('tempdatabase') | ||||||
|  |  | ||||||
|  |  | ||||||
|     def test_index_dont_send_cls_option(self): |     def test_index_dont_send_cls_option(self): | ||||||
|         """ |         """ | ||||||
|   | |||||||
| @@ -411,7 +411,7 @@ class InheritanceTest(unittest.TestCase): | |||||||
|         try: |         try: | ||||||
|             class MyDocument(DateCreatedDocument, DateUpdatedDocument): |             class MyDocument(DateCreatedDocument, DateUpdatedDocument): | ||||||
|                 pass |                 pass | ||||||
|         except: |         except Exception: | ||||||
|             self.assertTrue(False, "Couldn't create MyDocument class") |             self.assertTrue(False, "Couldn't create MyDocument class") | ||||||
|  |  | ||||||
|     def test_abstract_documents(self): |     def test_abstract_documents(self): | ||||||
|   | |||||||
| @@ -13,7 +13,7 @@ from datetime import datetime | |||||||
| from bson import DBRef, ObjectId | from bson import DBRef, ObjectId | ||||||
| from tests import fixtures | from tests import fixtures | ||||||
| from tests.fixtures import (PickleEmbedded, PickleTest, PickleSignalsTest, | from tests.fixtures import (PickleEmbedded, PickleTest, PickleSignalsTest, | ||||||
|                             PickleDyanmicEmbedded, PickleDynamicTest) |                             PickleDynamicEmbedded, PickleDynamicTest) | ||||||
|  |  | ||||||
| from mongoengine import * | from mongoengine import * | ||||||
| from mongoengine.errors import (NotRegistered, InvalidDocumentError, | from mongoengine.errors import (NotRegistered, InvalidDocumentError, | ||||||
| @@ -571,6 +571,28 @@ class InstanceTest(unittest.TestCase): | |||||||
|         except Exception: |         except Exception: | ||||||
|             self.assertFalse("Threw wrong exception") |             self.assertFalse("Threw wrong exception") | ||||||
|  |  | ||||||
|  |     def test_reload_of_non_strict_with_special_field_name(self): | ||||||
|  |         """Ensures reloading works for documents with meta strict == False | ||||||
|  |         """ | ||||||
|  |         class Post(Document): | ||||||
|  |             meta = { | ||||||
|  |                 'strict': False | ||||||
|  |             } | ||||||
|  |             title = StringField() | ||||||
|  |             items = ListField() | ||||||
|  |  | ||||||
|  |         Post.drop_collection() | ||||||
|  |  | ||||||
|  |         Post._get_collection().insert({ | ||||||
|  |             "title": "Items eclipse", | ||||||
|  |             "items": ["more lorem", "even more ipsum"] | ||||||
|  |         }) | ||||||
|  |  | ||||||
|  |         post = Post.objects.first() | ||||||
|  |         post.reload() | ||||||
|  |         self.assertEqual(post.title, "Items eclipse") | ||||||
|  |         self.assertEqual(post.items, ["more lorem", "even more ipsum"]) | ||||||
|  |  | ||||||
|     def test_dictionary_access(self): |     def test_dictionary_access(self): | ||||||
|         """Ensure that dictionary-style field access works properly. |         """Ensure that dictionary-style field access works properly. | ||||||
|         """ |         """ | ||||||
| @@ -657,6 +679,19 @@ class InstanceTest(unittest.TestCase): | |||||||
|         doc = Doc.objects.get() |         doc = Doc.objects.get() | ||||||
|         self.assertHasInstance(doc.embedded_field[0], doc) |         self.assertHasInstance(doc.embedded_field[0], doc) | ||||||
|  |  | ||||||
|  |     def test_embedded_document_complex_instance_no_use_db_field(self): | ||||||
|  |         """Ensure that use_db_field is propagated to list of Emb Docs | ||||||
|  |         """ | ||||||
|  |         class Embedded(EmbeddedDocument): | ||||||
|  |             string = StringField(db_field='s') | ||||||
|  |  | ||||||
|  |         class Doc(Document): | ||||||
|  |             embedded_field = ListField(EmbeddedDocumentField(Embedded)) | ||||||
|  |  | ||||||
|  |         d = Doc(embedded_field=[Embedded(string="Hi")]).to_mongo( | ||||||
|  |             use_db_field=False).to_dict() | ||||||
|  |         self.assertEqual(d['embedded_field'], [{'string': 'Hi'}]) | ||||||
|  |  | ||||||
|     def test_instance_is_set_on_setattr(self): |     def test_instance_is_set_on_setattr(self): | ||||||
|  |  | ||||||
|         class Email(EmbeddedDocument): |         class Email(EmbeddedDocument): | ||||||
| @@ -1871,6 +1906,62 @@ class InstanceTest(unittest.TestCase): | |||||||
|         author.delete() |         author.delete() | ||||||
|         self.assertEqual(BlogPost.objects.count(), 0) |         self.assertEqual(BlogPost.objects.count(), 0) | ||||||
|  |  | ||||||
|  |     def test_reverse_delete_rule_with_custom_id_field(self): | ||||||
|  |         """Ensure that a referenced document with custom primary key | ||||||
|  |         is also deleted upon deletion. | ||||||
|  |         """ | ||||||
|  |         class User(Document): | ||||||
|  |             name = StringField(primary_key=True) | ||||||
|  |  | ||||||
|  |         class Book(Document): | ||||||
|  |             author = ReferenceField(User, reverse_delete_rule=CASCADE) | ||||||
|  |             reviewer = ReferenceField(User, reverse_delete_rule=NULLIFY) | ||||||
|  |  | ||||||
|  |         User.drop_collection() | ||||||
|  |         Book.drop_collection() | ||||||
|  |  | ||||||
|  |         user = User(name='Mike').save() | ||||||
|  |         reviewer = User(name='John').save() | ||||||
|  |         book = Book(author=user, reviewer=reviewer).save() | ||||||
|  |  | ||||||
|  |         reviewer.delete() | ||||||
|  |         self.assertEqual(Book.objects.count(), 1) | ||||||
|  |         self.assertEqual(Book.objects.get().reviewer, None) | ||||||
|  |  | ||||||
|  |         user.delete() | ||||||
|  |         self.assertEqual(Book.objects.count(), 0) | ||||||
|  |  | ||||||
|  |     def test_reverse_delete_rule_with_shared_id_among_collections(self): | ||||||
|  |         """Ensure that cascade delete rule doesn't mix id among collections. | ||||||
|  |         """ | ||||||
|  |         class User(Document): | ||||||
|  |             id = IntField(primary_key=True) | ||||||
|  |  | ||||||
|  |         class Book(Document): | ||||||
|  |             id = IntField(primary_key=True) | ||||||
|  |             author = ReferenceField(User, reverse_delete_rule=CASCADE) | ||||||
|  |  | ||||||
|  |         User.drop_collection() | ||||||
|  |         Book.drop_collection() | ||||||
|  |  | ||||||
|  |         user_1 = User(id=1).save() | ||||||
|  |         user_2 = User(id=2).save() | ||||||
|  |         book_1 = Book(id=1, author=user_2).save() | ||||||
|  |         book_2 = Book(id=2, author=user_1).save() | ||||||
|  |  | ||||||
|  |         user_2.delete() | ||||||
|  |         # Deleting user_2 should also delete book_1 but not book_2 | ||||||
|  |         self.assertEqual(Book.objects.count(), 1) | ||||||
|  |         self.assertEqual(Book.objects.get(), book_2) | ||||||
|  |  | ||||||
|  |         user_3 = User(id=3).save() | ||||||
|  |         book_3 = Book(id=3, author=user_3).save() | ||||||
|  |  | ||||||
|  |         user_3.delete() | ||||||
|  |         # Deleting user_3 should also delete book_3 | ||||||
|  |         self.assertEqual(Book.objects.count(), 1) | ||||||
|  |         self.assertEqual(Book.objects.get(), book_2) | ||||||
|  |  | ||||||
|     def test_reverse_delete_rule_with_document_inheritance(self): |     def test_reverse_delete_rule_with_document_inheritance(self): | ||||||
|         """Ensure that a referenced document is also deleted upon deletion |         """Ensure that a referenced document is also deleted upon deletion | ||||||
|         of a child document. |         of a child document. | ||||||
| @@ -2226,7 +2317,7 @@ class InstanceTest(unittest.TestCase): | |||||||
|  |  | ||||||
|         pickle_doc = PickleDynamicTest( |         pickle_doc = PickleDynamicTest( | ||||||
|             name="test", number=1, string="One", lists=['1', '2']) |             name="test", number=1, string="One", lists=['1', '2']) | ||||||
|         pickle_doc.embedded = PickleDyanmicEmbedded(foo="Bar") |         pickle_doc.embedded = PickleDynamicEmbedded(foo="Bar") | ||||||
|         pickled_doc = pickle.dumps(pickle_doc)  # make sure pickling works even before the doc is saved |         pickled_doc = pickle.dumps(pickle_doc)  # make sure pickling works even before the doc is saved | ||||||
|  |  | ||||||
|         pickle_doc.save() |         pickle_doc.save() | ||||||
| @@ -2837,6 +2928,20 @@ class InstanceTest(unittest.TestCase): | |||||||
|         self.assertEqual(person.name, "Test User") |         self.assertEqual(person.name, "Test User") | ||||||
|         self.assertEqual(person.age, 42) |         self.assertEqual(person.age, 42) | ||||||
|  |  | ||||||
|  |     def test_positional_creation_embedded(self): | ||||||
|  |         """Ensure that embedded document may be created using positional arguments. | ||||||
|  |         """ | ||||||
|  |         job = self.Job("Test Job", 4) | ||||||
|  |         self.assertEqual(job.name, "Test Job") | ||||||
|  |         self.assertEqual(job.years, 4) | ||||||
|  |  | ||||||
|  |     def test_mixed_creation_embedded(self): | ||||||
|  |         """Ensure that embedded document may be created using mixed arguments. | ||||||
|  |         """ | ||||||
|  |         job = self.Job("Test Job", years=4) | ||||||
|  |         self.assertEqual(job.name, "Test Job") | ||||||
|  |         self.assertEqual(job.years, 4) | ||||||
|  |  | ||||||
|     def test_mixed_creation_dynamic(self): |     def test_mixed_creation_dynamic(self): | ||||||
|         """Ensure that document may be created using mixed arguments. |         """Ensure that document may be created using mixed arguments. | ||||||
|         """ |         """ | ||||||
| @@ -3013,6 +3118,17 @@ class InstanceTest(unittest.TestCase): | |||||||
|         p4 = Person.objects()[0] |         p4 = Person.objects()[0] | ||||||
|         p4.save() |         p4.save() | ||||||
|         self.assertEquals(p4.height, 189) |         self.assertEquals(p4.height, 189) | ||||||
|  |          | ||||||
|  |         # However the default will not be fixed in DB | ||||||
|  |         self.assertEquals(Person.objects(height=189).count(), 0) | ||||||
|  |          | ||||||
|  |         # alter DB for the new default | ||||||
|  |         coll = Person._get_collection() | ||||||
|  |         for person in Person.objects.as_pymongo(): | ||||||
|  |             if 'height' not in person: | ||||||
|  |                 person['height'] = 189 | ||||||
|  |                 coll.save(person) | ||||||
|  |                  | ||||||
|         self.assertEquals(Person.objects(height=189).count(), 1) |         self.assertEquals(Person.objects(height=189).count(), 1) | ||||||
|  |  | ||||||
|     def test_from_son(self): |     def test_from_son(self): | ||||||
| @@ -3086,5 +3202,20 @@ class InstanceTest(unittest.TestCase): | |||||||
|             self.assertEqual(b._instance, a) |             self.assertEqual(b._instance, a) | ||||||
|         self.assertEqual(idx, 2) |         self.assertEqual(idx, 2) | ||||||
|  |  | ||||||
|  |     def test_falsey_pk(self): | ||||||
|  |         """Ensure that we can create and update a document with Falsey PK. | ||||||
|  |         """ | ||||||
|  |         class Person(Document): | ||||||
|  |             age = IntField(primary_key=True) | ||||||
|  |             height = FloatField() | ||||||
|  |  | ||||||
|  |         person = Person() | ||||||
|  |         person.age = 0 | ||||||
|  |         person.height = 1.89 | ||||||
|  |         person.save() | ||||||
|  |  | ||||||
|  |         person.update(set__height=2.0) | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|     unittest.main() |     unittest.main() | ||||||
|   | |||||||
| @@ -1,5 +1,7 @@ | |||||||
| # -*- coding: utf-8 -*- | # -*- coding: utf-8 -*- | ||||||
| import sys | import sys | ||||||
|  |  | ||||||
|  | import six | ||||||
| from nose.plugins.skip import SkipTest | from nose.plugins.skip import SkipTest | ||||||
|  |  | ||||||
| sys.path[0:0] = [""] | sys.path[0:0] = [""] | ||||||
| @@ -10,6 +12,7 @@ import uuid | |||||||
| import math | import math | ||||||
| import itertools | import itertools | ||||||
| import re | import re | ||||||
|  | import six | ||||||
|  |  | ||||||
| try: | try: | ||||||
|     import dateutil |     import dateutil | ||||||
| @@ -19,12 +22,16 @@ except ImportError: | |||||||
| from decimal import Decimal | from decimal import Decimal | ||||||
|  |  | ||||||
| from bson import Binary, DBRef, ObjectId | from bson import Binary, DBRef, ObjectId | ||||||
|  | try: | ||||||
|  |     from bson.int64 import Int64 | ||||||
|  | except ImportError: | ||||||
|  |     Int64 = long | ||||||
|  |  | ||||||
| from mongoengine import * | from mongoengine import * | ||||||
| from mongoengine.connection import get_db | from mongoengine.connection import get_db | ||||||
| from mongoengine.base import _document_registry | from mongoengine.base import _document_registry | ||||||
| from mongoengine.base.datastructures import BaseDict, EmbeddedDocumentList | from mongoengine.base.datastructures import BaseDict, EmbeddedDocumentList | ||||||
| from mongoengine.errors import NotRegistered | from mongoengine.errors import NotRegistered, DoesNotExist | ||||||
| from mongoengine.python_support import PY3, b, bin_type | from mongoengine.python_support import PY3, b, bin_type | ||||||
|  |  | ||||||
| __all__ = ("FieldTest", "EmbeddedDocumentListFieldTestCase") | __all__ = ("FieldTest", "EmbeddedDocumentListFieldTestCase") | ||||||
| @@ -399,20 +406,37 @@ class FieldTest(unittest.TestCase): | |||||||
|         class Person(Document): |         class Person(Document): | ||||||
|             height = FloatField(min_value=0.1, max_value=3.5) |             height = FloatField(min_value=0.1, max_value=3.5) | ||||||
|  |  | ||||||
|  |         class BigPerson(Document): | ||||||
|  |             height = FloatField() | ||||||
|  |  | ||||||
|         person = Person() |         person = Person() | ||||||
|         person.height = 1.89 |         person.height = 1.89 | ||||||
|         person.validate() |         person.validate() | ||||||
|  |  | ||||||
|         person.height = '2.0' |         person.height = '2.0' | ||||||
|         self.assertRaises(ValidationError, person.validate) |         self.assertRaises(ValidationError, person.validate) | ||||||
|  |  | ||||||
|         person.height = 0.01 |         person.height = 0.01 | ||||||
|         self.assertRaises(ValidationError, person.validate) |         self.assertRaises(ValidationError, person.validate) | ||||||
|  |  | ||||||
|         person.height = 4.0 |         person.height = 4.0 | ||||||
|         self.assertRaises(ValidationError, person.validate) |         self.assertRaises(ValidationError, person.validate) | ||||||
|  |  | ||||||
|         person_2 = Person(height='something invalid') |         person_2 = Person(height='something invalid') | ||||||
|         self.assertRaises(ValidationError, person_2.validate) |         self.assertRaises(ValidationError, person_2.validate) | ||||||
|  |  | ||||||
|  |         big_person = BigPerson() | ||||||
|  |  | ||||||
|  |         for value, value_type in enumerate(six.integer_types): | ||||||
|  |             big_person.height = value_type(value) | ||||||
|  |             big_person.validate() | ||||||
|  |  | ||||||
|  |         big_person.height = 2 ** 500 | ||||||
|  |         big_person.validate() | ||||||
|  |  | ||||||
|  |         big_person.height = 2 ** 100000  # Too big for a float value | ||||||
|  |         self.assertRaises(ValidationError, big_person.validate) | ||||||
|  |  | ||||||
|     def test_decimal_validation(self): |     def test_decimal_validation(self): | ||||||
|         """Ensure that invalid values cannot be assigned to decimal fields. |         """Ensure that invalid values cannot be assigned to decimal fields. | ||||||
|         """ |         """ | ||||||
| @@ -1022,6 +1046,54 @@ class FieldTest(unittest.TestCase): | |||||||
|         self.assertEqual(BlogPost.objects(info=['1', '2', '3', '4', '1', '2', '3', '4']).count(), 1) |         self.assertEqual(BlogPost.objects(info=['1', '2', '3', '4', '1', '2', '3', '4']).count(), 1) | ||||||
|         BlogPost.drop_collection() |         BlogPost.drop_collection() | ||||||
|  |  | ||||||
|  |     def test_list_assignment(self): | ||||||
|  |         """Ensure that list field element assignment and slicing work | ||||||
|  |         """ | ||||||
|  |         class BlogPost(Document): | ||||||
|  |             info = ListField() | ||||||
|  |  | ||||||
|  |         BlogPost.drop_collection() | ||||||
|  |  | ||||||
|  |         post = BlogPost() | ||||||
|  |         post.info = ['e1', 'e2', 3, '4', 5] | ||||||
|  |         post.save() | ||||||
|  |  | ||||||
|  |         post.info[0] = 1 | ||||||
|  |         post.save() | ||||||
|  |         post.reload() | ||||||
|  |         self.assertEqual(post.info[0], 1) | ||||||
|  |  | ||||||
|  |         post.info[1:3] = ['n2', 'n3'] | ||||||
|  |         post.save() | ||||||
|  |         post.reload() | ||||||
|  |         self.assertEqual(post.info, [1, 'n2', 'n3', '4', 5]) | ||||||
|  |  | ||||||
|  |         post.info[-1] = 'n5' | ||||||
|  |         post.save() | ||||||
|  |         post.reload() | ||||||
|  |         self.assertEqual(post.info, [1, 'n2', 'n3', '4', 'n5']) | ||||||
|  |  | ||||||
|  |         post.info[-2] = 4 | ||||||
|  |         post.save() | ||||||
|  |         post.reload() | ||||||
|  |         self.assertEqual(post.info, [1, 'n2', 'n3', 4, 'n5']) | ||||||
|  |  | ||||||
|  |         post.info[1:-1] = [2] | ||||||
|  |         post.save() | ||||||
|  |         post.reload() | ||||||
|  |         self.assertEqual(post.info, [1, 2, 'n5']) | ||||||
|  |  | ||||||
|  |         post.info[:-1] = [1, 'n2', 'n3', 4] | ||||||
|  |         post.save() | ||||||
|  |         post.reload() | ||||||
|  |         self.assertEqual(post.info, [1, 'n2', 'n3', 4, 'n5']) | ||||||
|  |  | ||||||
|  |         post.info[-4:3] = [2, 3] | ||||||
|  |         post.save() | ||||||
|  |         post.reload() | ||||||
|  |         self.assertEqual(post.info, [1, 2, 3, 4, 'n5']) | ||||||
|  |  | ||||||
|  |  | ||||||
|     def test_list_field_passed_in_value(self): |     def test_list_field_passed_in_value(self): | ||||||
|         class Foo(Document): |         class Foo(Document): | ||||||
|             bars = ListField(ReferenceField("Bar")) |             bars = ListField(ReferenceField("Bar")) | ||||||
| @@ -1136,6 +1208,19 @@ class FieldTest(unittest.TestCase): | |||||||
|         simple = simple.reload() |         simple = simple.reload() | ||||||
|         self.assertEqual(simple.widgets, [4]) |         self.assertEqual(simple.widgets, [4]) | ||||||
|  |  | ||||||
|  |     def test_list_field_with_negative_indices(self): | ||||||
|  |  | ||||||
|  |         class Simple(Document): | ||||||
|  |             widgets = ListField() | ||||||
|  |  | ||||||
|  |         simple = Simple(widgets=[1, 2, 3, 4]).save() | ||||||
|  |         simple.widgets[-1] = 5 | ||||||
|  |         self.assertEqual(['widgets.3'], simple._changed_fields) | ||||||
|  |         simple.save() | ||||||
|  |  | ||||||
|  |         simple = simple.reload() | ||||||
|  |         self.assertEqual(simple.widgets, [1, 2, 3, 5]) | ||||||
|  |  | ||||||
|     def test_list_field_complex(self): |     def test_list_field_complex(self): | ||||||
|         """Ensure that the list fields can handle the complex types.""" |         """Ensure that the list fields can handle the complex types.""" | ||||||
|  |  | ||||||
| @@ -1515,6 +1600,29 @@ class FieldTest(unittest.TestCase): | |||||||
|             actions__friends__operation='drink', |             actions__friends__operation='drink', | ||||||
|             actions__friends__object='beer').count()) |             actions__friends__object='beer').count()) | ||||||
|  |  | ||||||
|  |     def test_map_field_unicode(self): | ||||||
|  |  | ||||||
|  |         class Info(EmbeddedDocument): | ||||||
|  |             description = StringField() | ||||||
|  |             value_list = ListField(field=StringField()) | ||||||
|  |  | ||||||
|  |         class BlogPost(Document): | ||||||
|  |             info_dict = MapField(field=EmbeddedDocumentField(Info)) | ||||||
|  |  | ||||||
|  |         BlogPost.drop_collection() | ||||||
|  |  | ||||||
|  |         tree = BlogPost(info_dict={ | ||||||
|  |             u"éééé": { | ||||||
|  |                 'description': u"VALUE: éééé" | ||||||
|  |             } | ||||||
|  |         }) | ||||||
|  |  | ||||||
|  |         tree.save() | ||||||
|  |  | ||||||
|  |         self.assertEqual(BlogPost.objects.get(id=tree.id).info_dict[u"éééé"].description, u"VALUE: éééé") | ||||||
|  |  | ||||||
|  |         BlogPost.drop_collection() | ||||||
|  |  | ||||||
|     def test_embedded_db_field(self): |     def test_embedded_db_field(self): | ||||||
|  |  | ||||||
|         class Embedded(EmbeddedDocument): |         class Embedded(EmbeddedDocument): | ||||||
| @@ -1551,6 +1659,8 @@ class FieldTest(unittest.TestCase): | |||||||
|             name = StringField() |             name = StringField() | ||||||
|             preferences = EmbeddedDocumentField(PersonPreferences) |             preferences = EmbeddedDocumentField(PersonPreferences) | ||||||
|  |  | ||||||
|  |         Person.drop_collection() | ||||||
|  |  | ||||||
|         person = Person(name='Test User') |         person = Person(name='Test User') | ||||||
|         person.preferences = 'My Preferences' |         person.preferences = 'My Preferences' | ||||||
|         self.assertRaises(ValidationError, person.validate) |         self.assertRaises(ValidationError, person.validate) | ||||||
| @@ -1583,12 +1693,70 @@ class FieldTest(unittest.TestCase): | |||||||
|             content = StringField() |             content = StringField() | ||||||
|             author = EmbeddedDocumentField(User) |             author = EmbeddedDocumentField(User) | ||||||
|  |  | ||||||
|  |         BlogPost.drop_collection() | ||||||
|  |  | ||||||
|         post = BlogPost(content='What I did today...') |         post = BlogPost(content='What I did today...') | ||||||
|         post.author = PowerUser(name='Test User', power=47) |         post.author = PowerUser(name='Test User', power=47) | ||||||
|         post.save() |         post.save() | ||||||
|  |  | ||||||
|         self.assertEqual(47, BlogPost.objects.first().author.power) |         self.assertEqual(47, BlogPost.objects.first().author.power) | ||||||
|  |  | ||||||
|  |     def test_embedded_document_inheritance_with_list(self): | ||||||
|  |         """Ensure that nested list of subclassed embedded documents is | ||||||
|  |         handled correctly. | ||||||
|  |         """ | ||||||
|  |  | ||||||
|  |         class Group(EmbeddedDocument): | ||||||
|  |             name = StringField() | ||||||
|  |             content = ListField(StringField()) | ||||||
|  |  | ||||||
|  |         class Basedoc(Document): | ||||||
|  |             groups = ListField(EmbeddedDocumentField(Group)) | ||||||
|  |             meta = {'abstract': True} | ||||||
|  |  | ||||||
|  |         class User(Basedoc): | ||||||
|  |             doctype = StringField(require=True, default='userdata') | ||||||
|  |  | ||||||
|  |         User.drop_collection() | ||||||
|  |  | ||||||
|  |         content = ['la', 'le', 'lu'] | ||||||
|  |         group = Group(name='foo', content=content) | ||||||
|  |         foobar = User(groups=[group]) | ||||||
|  |         foobar.save() | ||||||
|  |  | ||||||
|  |         self.assertEqual(content, User.objects.first().groups[0].content) | ||||||
|  |  | ||||||
|  |     def test_reference_miss(self): | ||||||
|  |         """Ensure an exception is raised when dereferencing unknow document | ||||||
|  |         """ | ||||||
|  |  | ||||||
|  |         class Foo(Document): | ||||||
|  |             pass | ||||||
|  |  | ||||||
|  |         class Bar(Document): | ||||||
|  |             ref = ReferenceField(Foo) | ||||||
|  |             generic_ref = GenericReferenceField() | ||||||
|  |  | ||||||
|  |         Foo.drop_collection() | ||||||
|  |         Bar.drop_collection() | ||||||
|  |  | ||||||
|  |         foo = Foo().save() | ||||||
|  |         bar = Bar(ref=foo, generic_ref=foo).save() | ||||||
|  |  | ||||||
|  |         # Reference is no longer valid | ||||||
|  |         foo.delete() | ||||||
|  |         bar = Bar.objects.get() | ||||||
|  |         self.assertRaises(DoesNotExist, lambda: getattr(bar, 'ref')) | ||||||
|  |         self.assertRaises(DoesNotExist, lambda: getattr(bar, 'generic_ref')) | ||||||
|  |  | ||||||
|  |         # When auto_dereference is disabled, there is no trouble returning DBRef | ||||||
|  |         bar = Bar.objects.get() | ||||||
|  |         expected = foo.to_dbref() | ||||||
|  |         bar._fields['ref']._auto_dereference = False | ||||||
|  |         self.assertEqual(bar.ref, expected) | ||||||
|  |         bar._fields['generic_ref']._auto_dereference = False | ||||||
|  |         self.assertEqual(bar.generic_ref, {'_ref': expected, '_cls': 'Foo'}) | ||||||
|  |  | ||||||
|     def test_reference_validation(self): |     def test_reference_validation(self): | ||||||
|         """Ensure that invalid docment objects cannot be assigned to reference |         """Ensure that invalid docment objects cannot be assigned to reference | ||||||
|         fields. |         fields. | ||||||
| @@ -2281,6 +2449,91 @@ class FieldTest(unittest.TestCase): | |||||||
|         Member.drop_collection() |         Member.drop_collection() | ||||||
|         BlogPost.drop_collection() |         BlogPost.drop_collection() | ||||||
|  |  | ||||||
|  |     def test_drop_abstract_document(self): | ||||||
|  |         """Ensure that an abstract document cannot be dropped given it | ||||||
|  |         has no underlying collection. | ||||||
|  |         """ | ||||||
|  |         class AbstractDoc(Document): | ||||||
|  |             name = StringField() | ||||||
|  |             meta = {"abstract": True} | ||||||
|  |  | ||||||
|  |         self.assertRaises(OperationError, AbstractDoc.drop_collection) | ||||||
|  |  | ||||||
|  |     def test_reference_class_with_abstract_parent(self): | ||||||
|  |         """Ensure that a class with an abstract parent can be referenced. | ||||||
|  |         """ | ||||||
|  |         class Sibling(Document): | ||||||
|  |             name = StringField() | ||||||
|  |             meta = {"abstract": True} | ||||||
|  |  | ||||||
|  |         class Sister(Sibling): | ||||||
|  |             pass | ||||||
|  |  | ||||||
|  |         class Brother(Sibling): | ||||||
|  |             sibling = ReferenceField(Sibling) | ||||||
|  |  | ||||||
|  |         Sister.drop_collection() | ||||||
|  |         Brother.drop_collection() | ||||||
|  |  | ||||||
|  |         sister = Sister(name="Alice") | ||||||
|  |         sister.save() | ||||||
|  |         brother = Brother(name="Bob", sibling=sister) | ||||||
|  |         brother.save() | ||||||
|  |  | ||||||
|  |         self.assertEquals(Brother.objects[0].sibling.name, sister.name) | ||||||
|  |  | ||||||
|  |         Sister.drop_collection() | ||||||
|  |         Brother.drop_collection() | ||||||
|  |  | ||||||
|  |     def test_reference_abstract_class(self): | ||||||
|  |         """Ensure that an abstract class instance cannot be used in the | ||||||
|  |         reference of that abstract class. | ||||||
|  |         """ | ||||||
|  |         class Sibling(Document): | ||||||
|  |             name = StringField() | ||||||
|  |             meta = {"abstract": True} | ||||||
|  |  | ||||||
|  |         class Sister(Sibling): | ||||||
|  |             pass | ||||||
|  |  | ||||||
|  |         class Brother(Sibling): | ||||||
|  |             sibling = ReferenceField(Sibling) | ||||||
|  |  | ||||||
|  |         Sister.drop_collection() | ||||||
|  |         Brother.drop_collection() | ||||||
|  |  | ||||||
|  |         sister = Sibling(name="Alice") | ||||||
|  |         brother = Brother(name="Bob", sibling=sister) | ||||||
|  |         self.assertRaises(ValidationError, brother.save) | ||||||
|  |  | ||||||
|  |         Sister.drop_collection() | ||||||
|  |         Brother.drop_collection() | ||||||
|  |  | ||||||
|  |     def test_abstract_reference_base_type(self): | ||||||
|  |         """Ensure that an an abstract reference fails validation when given a | ||||||
|  |         Document that does not inherit from the abstract type. | ||||||
|  |         """ | ||||||
|  |         class Sibling(Document): | ||||||
|  |             name = StringField() | ||||||
|  |             meta = {"abstract": True} | ||||||
|  |  | ||||||
|  |         class Brother(Sibling): | ||||||
|  |             sibling = ReferenceField(Sibling) | ||||||
|  |  | ||||||
|  |         class Mother(Document): | ||||||
|  |             name = StringField() | ||||||
|  |  | ||||||
|  |         Brother.drop_collection() | ||||||
|  |         Mother.drop_collection() | ||||||
|  |  | ||||||
|  |         mother = Mother(name="Carol") | ||||||
|  |         mother.save() | ||||||
|  |         brother = Brother(name="Bob", sibling=mother) | ||||||
|  |         self.assertRaises(ValidationError, brother.save) | ||||||
|  |  | ||||||
|  |         Brother.drop_collection() | ||||||
|  |         Mother.drop_collection() | ||||||
|  |  | ||||||
|     def test_generic_reference(self): |     def test_generic_reference(self): | ||||||
|         """Ensure that a GenericReferenceField properly dereferences items. |         """Ensure that a GenericReferenceField properly dereferences items. | ||||||
|         """ |         """ | ||||||
| @@ -2748,28 +3001,32 @@ class FieldTest(unittest.TestCase): | |||||||
|                 ('S', 'Small'), ('M', 'Medium'), ('L', 'Large'), |                 ('S', 'Small'), ('M', 'Medium'), ('L', 'Large'), | ||||||
|                 ('XL', 'Extra Large'), ('XXL', 'Extra Extra Large'))) |                 ('XL', 'Extra Large'), ('XXL', 'Extra Extra Large'))) | ||||||
|             style = StringField(max_length=3, choices=( |             style = StringField(max_length=3, choices=( | ||||||
|                 ('S', 'Small'), ('B', 'Baggy'), ('W', 'wide')), default='S') |                 ('S', 'Small'), ('B', 'Baggy'), ('W', 'Wide')), default='W') | ||||||
|  |  | ||||||
|         Shirt.drop_collection() |         Shirt.drop_collection() | ||||||
|  |  | ||||||
|         shirt = Shirt() |         shirt1 = Shirt() | ||||||
|  |         shirt2 = Shirt() | ||||||
|  |  | ||||||
|         self.assertEqual(shirt.get_size_display(), None) |         # Make sure get_<field>_display returns the default value (or None) | ||||||
|         self.assertEqual(shirt.get_style_display(), 'Small') |         self.assertEqual(shirt1.get_size_display(), None) | ||||||
|  |         self.assertEqual(shirt1.get_style_display(), 'Wide') | ||||||
|  |  | ||||||
|         shirt.size = "XXL" |         shirt1.size = 'XXL' | ||||||
|         shirt.style = "B" |         shirt1.style = 'B' | ||||||
|         self.assertEqual(shirt.get_size_display(), 'Extra Extra Large') |         shirt2.size = 'M' | ||||||
|         self.assertEqual(shirt.get_style_display(), 'Baggy') |         shirt2.style = 'S' | ||||||
|  |         self.assertEqual(shirt1.get_size_display(), 'Extra Extra Large') | ||||||
|  |         self.assertEqual(shirt1.get_style_display(), 'Baggy') | ||||||
|  |         self.assertEqual(shirt2.get_size_display(), 'Medium') | ||||||
|  |         self.assertEqual(shirt2.get_style_display(), 'Small') | ||||||
|  |  | ||||||
|         # Set as Z - an invalid choice |         # Set as Z - an invalid choice | ||||||
|         shirt.size = "Z" |         shirt1.size = 'Z' | ||||||
|         shirt.style = "Z" |         shirt1.style = 'Z' | ||||||
|         self.assertEqual(shirt.get_size_display(), 'Z') |         self.assertEqual(shirt1.get_size_display(), 'Z') | ||||||
|         self.assertEqual(shirt.get_style_display(), 'Z') |         self.assertEqual(shirt1.get_style_display(), 'Z') | ||||||
|         self.assertRaises(ValidationError, shirt.validate) |         self.assertRaises(ValidationError, shirt1.validate) | ||||||
|  |  | ||||||
|         Shirt.drop_collection() |  | ||||||
|  |  | ||||||
|     def test_simple_choices_validation(self): |     def test_simple_choices_validation(self): | ||||||
|         """Ensure that value is in a container of allowed values. |         """Ensure that value is in a container of allowed values. | ||||||
| @@ -3472,6 +3729,19 @@ class FieldTest(unittest.TestCase): | |||||||
|  |  | ||||||
|         self.assertRaises(FieldDoesNotExist, test) |         self.assertRaises(FieldDoesNotExist, test) | ||||||
|  |  | ||||||
|  |     def test_long_field_is_considered_as_int64(self): | ||||||
|  |         """ | ||||||
|  |         Tests that long fields are stored as long in mongo, even if long value | ||||||
|  |         is small enough to be an int. | ||||||
|  |         """ | ||||||
|  |         class TestLongFieldConsideredAsInt64(Document): | ||||||
|  |             some_long = LongField() | ||||||
|  |  | ||||||
|  |         doc = TestLongFieldConsideredAsInt64(some_long=42).save() | ||||||
|  |         db = get_db() | ||||||
|  |         self.assertTrue(isinstance(db.test_long_field_considered_as_int64.find()[0]['some_long'], Int64)) | ||||||
|  |         self.assertTrue(isinstance(doc.some_long, six.integer_types)) | ||||||
|  |  | ||||||
|  |  | ||||||
| class EmbeddedDocumentListFieldTestCase(unittest.TestCase): | class EmbeddedDocumentListFieldTestCase(unittest.TestCase): | ||||||
|  |  | ||||||
| @@ -3859,6 +4129,17 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase): | |||||||
|         # modified |         # modified | ||||||
|         self.assertEqual(number, 2) |         self.assertEqual(number, 2) | ||||||
|  |  | ||||||
|  |     def test_unicode(self): | ||||||
|  |         """ | ||||||
|  |         Tests that unicode strings handled correctly | ||||||
|  |         """ | ||||||
|  |         post = self.BlogPost(comments=[ | ||||||
|  |             self.Comments(author='user1', message=u'сообщение'), | ||||||
|  |             self.Comments(author='user2', message=u'хабарлама') | ||||||
|  |         ]).save() | ||||||
|  |         self.assertEqual(post.comments.get(message=u'сообщение').author, | ||||||
|  |                          'user1') | ||||||
|  |  | ||||||
|     def test_save(self): |     def test_save(self): | ||||||
|         """ |         """ | ||||||
|         Tests the save method of a List of Embedded Documents. |         Tests the save method of a List of Embedded Documents. | ||||||
|   | |||||||
| @@ -26,7 +26,7 @@ class NewDocumentPickleTest(Document): | |||||||
|     new_field = StringField() |     new_field = StringField() | ||||||
|  |  | ||||||
|  |  | ||||||
| class PickleDyanmicEmbedded(DynamicEmbeddedDocument): | class PickleDynamicEmbedded(DynamicEmbeddedDocument): | ||||||
|     date = DateTimeField(default=datetime.now) |     date = DateTimeField(default=datetime.now) | ||||||
|  |  | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1,8 +1,11 @@ | |||||||
|  | import unittest | ||||||
|  |  | ||||||
| from convert_to_new_inheritance_model import * | from convert_to_new_inheritance_model import * | ||||||
| from decimalfield_as_float import * | from decimalfield_as_float import * | ||||||
| from refrencefield_dbref_to_object_id import * | from referencefield_dbref_to_object_id import * | ||||||
| from turn_off_inheritance import * | from turn_off_inheritance import * | ||||||
| from uuidfield_to_binary import * | from uuidfield_to_binary import * | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|     unittest.main() |     unittest.main() | ||||||
|   | |||||||
							
								
								
									
										78
									
								
								tests/queryset/pickable.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										78
									
								
								tests/queryset/pickable.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,78 @@ | |||||||
|  | import pickle | ||||||
|  | import unittest | ||||||
|  | from pymongo.mongo_client import MongoClient | ||||||
|  | from mongoengine import Document, StringField, IntField | ||||||
|  | from mongoengine.connection import connect | ||||||
|  |  | ||||||
|  | __author__ = 'stas' | ||||||
|  |  | ||||||
|  | class Person(Document): | ||||||
|  |     name = StringField() | ||||||
|  |     age = IntField() | ||||||
|  |  | ||||||
|  | class TestQuerysetPickable(unittest.TestCase): | ||||||
|  |     """ | ||||||
|  |     Test for adding pickling support for QuerySet instances | ||||||
|  |     See issue https://github.com/MongoEngine/mongoengine/issues/442 | ||||||
|  |     """ | ||||||
|  |     def setUp(self): | ||||||
|  |         super(TestQuerysetPickable, self).setUp() | ||||||
|  |  | ||||||
|  |         connection = connect(db="test") #type: pymongo.mongo_client.MongoClient | ||||||
|  |  | ||||||
|  |         connection.drop_database("test") | ||||||
|  |  | ||||||
|  |         self.john = Person.objects.create( | ||||||
|  |             name="John", | ||||||
|  |             age=21 | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  |     def test_picke_simple_qs(self): | ||||||
|  |  | ||||||
|  |         qs = Person.objects.all() | ||||||
|  |  | ||||||
|  |         pickle.dumps(qs) | ||||||
|  |  | ||||||
|  |     def _get_loaded(self, qs): | ||||||
|  |         s = pickle.dumps(qs) | ||||||
|  |  | ||||||
|  |         return pickle.loads(s) | ||||||
|  |  | ||||||
|  |     def test_unpickle(self): | ||||||
|  |         qs = Person.objects.all() | ||||||
|  |  | ||||||
|  |         loadedQs = self._get_loaded(qs) | ||||||
|  |  | ||||||
|  |         self.assertEqual(qs.count(), loadedQs.count()) | ||||||
|  |  | ||||||
|  |         #can update loadedQs | ||||||
|  |         loadedQs.update(age=23) | ||||||
|  |  | ||||||
|  |         #check | ||||||
|  |         self.assertEqual(Person.objects.first().age, 23) | ||||||
|  |  | ||||||
|  |     def test_pickle_support_filtration(self): | ||||||
|  |         Person.objects.create( | ||||||
|  |             name="Alice", | ||||||
|  |             age=22 | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         Person.objects.create( | ||||||
|  |             name="Bob", | ||||||
|  |             age=23 | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         qs = Person.objects.filter(age__gte=22) | ||||||
|  |         self.assertEqual(qs.count(), 2) | ||||||
|  |  | ||||||
|  |         loaded = self._get_loaded(qs) | ||||||
|  |  | ||||||
|  |         self.assertEqual(loaded.count(), 2) | ||||||
|  |         self.assertEqual(loaded.filter(name="Bob").first().age, 23) | ||||||
|  |      | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -1,28 +1,23 @@ | |||||||
| # -*- coding: utf-8 -*- | # -*- coding: utf-8 -*- | ||||||
|  |  | ||||||
| import sys | import datetime | ||||||
| sys.path[0:0] = [""] |  | ||||||
|  |  | ||||||
| import unittest | import unittest | ||||||
| import uuid | import uuid | ||||||
|  |  | ||||||
|  | from bson import DBRef, ObjectId | ||||||
| from nose.plugins.skip import SkipTest | from nose.plugins.skip import SkipTest | ||||||
|  |  | ||||||
| from datetime import datetime, timedelta |  | ||||||
|  |  | ||||||
| import pymongo | import pymongo | ||||||
| from pymongo.errors import ConfigurationError | from pymongo.errors import ConfigurationError | ||||||
| from pymongo.read_preferences import ReadPreference | from pymongo.read_preferences import ReadPreference | ||||||
|  |  | ||||||
| from bson import ObjectId, DBRef |  | ||||||
|  |  | ||||||
| from mongoengine import * | from mongoengine import * | ||||||
| from mongoengine.connection import get_connection, get_db | from mongoengine.connection import get_connection, get_db | ||||||
| from mongoengine.python_support import PY3, IS_PYMONGO_3 |  | ||||||
| from mongoengine.context_managers import query_counter, switch_db | from mongoengine.context_managers import query_counter, switch_db | ||||||
| from mongoengine.queryset import (QuerySet, QuerySetManager, |  | ||||||
|                                   MultipleObjectsReturned, DoesNotExist, |  | ||||||
|                                   queryset_manager) |  | ||||||
| from mongoengine.errors import InvalidQueryError | from mongoengine.errors import InvalidQueryError | ||||||
|  | from mongoengine.python_support import IS_PYMONGO_3, PY3 | ||||||
|  | from mongoengine.queryset import (DoesNotExist, MultipleObjectsReturned, | ||||||
|  |                                   QuerySet, QuerySetManager, queryset_manager) | ||||||
|  |  | ||||||
| __all__ = ("QuerySetTest",) | __all__ = ("QuerySetTest",) | ||||||
|  |  | ||||||
| @@ -184,12 +179,14 @@ class QuerySetTest(unittest.TestCase): | |||||||
|  |  | ||||||
|         self.assertEqual(self.Person.objects.count(), 55) |         self.assertEqual(self.Person.objects.count(), 55) | ||||||
|         self.assertEqual("Person object", "%s" % self.Person.objects[0]) |         self.assertEqual("Person object", "%s" % self.Person.objects[0]) | ||||||
|         self.assertEqual( |         self.assertEqual("[<Person: Person object>, <Person: Person object>]", | ||||||
|             "[<Person: Person object>, <Person: Person object>]",  "%s" % self.Person.objects[1:3]) |                          "%s" % self.Person.objects[1:3]) | ||||||
|         self.assertEqual( |         self.assertEqual("[<Person: Person object>, <Person: Person object>]", | ||||||
|             "[<Person: Person object>, <Person: Person object>]",  "%s" % self.Person.objects[51:53]) |                          "%s" % self.Person.objects[51:53]) | ||||||
|  |  | ||||||
|         # Test only after limit |         # Test only after limit | ||||||
|         self.assertEqual(self.Person.objects().limit(2).only('name')[0].age, None) |         self.assertEqual(self.Person.objects().limit(2).only('name')[0].age, None) | ||||||
|  |  | ||||||
|         # Test only after skip |         # Test only after skip | ||||||
|         self.assertEqual(self.Person.objects().skip(2).only('name')[0].age, None) |         self.assertEqual(self.Person.objects().skip(2).only('name')[0].age, None) | ||||||
|  |  | ||||||
| @@ -287,6 +284,9 @@ class QuerySetTest(unittest.TestCase): | |||||||
|         blog = Blog.objects(posts__0__comments__0__name='testa').get() |         blog = Blog.objects(posts__0__comments__0__name='testa').get() | ||||||
|         self.assertEqual(blog, blog1) |         self.assertEqual(blog, blog1) | ||||||
|  |  | ||||||
|  |         blog = Blog.objects(posts__0__comments__0__name='testb').get() | ||||||
|  |         self.assertEqual(blog, blog2) | ||||||
|  |  | ||||||
|         query = Blog.objects(posts__1__comments__1__name='testb') |         query = Blog.objects(posts__1__comments__1__name='testb') | ||||||
|         self.assertEqual(query.count(), 2) |         self.assertEqual(query.count(), 2) | ||||||
|  |  | ||||||
| @@ -339,7 +339,6 @@ class QuerySetTest(unittest.TestCase): | |||||||
|  |  | ||||||
|     def test_update_write_concern(self): |     def test_update_write_concern(self): | ||||||
|         """Test that passing write_concern works""" |         """Test that passing write_concern works""" | ||||||
|  |  | ||||||
|         self.Person.drop_collection() |         self.Person.drop_collection() | ||||||
|  |  | ||||||
|         write_concern = {"fsync": True} |         write_concern = {"fsync": True} | ||||||
| @@ -680,12 +679,20 @@ class QuerySetTest(unittest.TestCase): | |||||||
|     def test_upsert_one(self): |     def test_upsert_one(self): | ||||||
|         self.Person.drop_collection() |         self.Person.drop_collection() | ||||||
|  |  | ||||||
|         self.Person.objects(name="Bob", age=30).update_one(upsert=True) |         bob = self.Person.objects(name="Bob", age=30).upsert_one() | ||||||
|  |  | ||||||
|         bob = self.Person.objects.first() |  | ||||||
|         self.assertEqual("Bob", bob.name) |         self.assertEqual("Bob", bob.name) | ||||||
|         self.assertEqual(30, bob.age) |         self.assertEqual(30, bob.age) | ||||||
|  |  | ||||||
|  |         bob.name = "Bobby" | ||||||
|  |         bob.save() | ||||||
|  |  | ||||||
|  |         bobby = self.Person.objects(name="Bobby", age=30).upsert_one() | ||||||
|  |  | ||||||
|  |         self.assertEqual("Bobby", bobby.name) | ||||||
|  |         self.assertEqual(30, bobby.age) | ||||||
|  |         self.assertEqual(bob.id, bobby.id) | ||||||
|  |  | ||||||
|     def test_set_on_insert(self): |     def test_set_on_insert(self): | ||||||
|         self.Person.drop_collection() |         self.Person.drop_collection() | ||||||
|  |  | ||||||
| @@ -1104,24 +1111,29 @@ class QuerySetTest(unittest.TestCase): | |||||||
|         blog_2.save() |         blog_2.save() | ||||||
|         blog_3.save() |         blog_3.save() | ||||||
|  |  | ||||||
|         blog_post_1 = BlogPost(blog=blog_1, title="Blog Post #1", |         BlogPost.objects.create( | ||||||
|  |             blog=blog_1, | ||||||
|  |             title="Blog Post #1", | ||||||
|             is_published=True, |             is_published=True, | ||||||
|                                published_date=datetime(2010, 1, 5, 0, 0, 0)) |             published_date=datetime.datetime(2010, 1, 5, 0, 0, 0) | ||||||
|         blog_post_2 = BlogPost(blog=blog_2, title="Blog Post #2", |         ) | ||||||
|  |         BlogPost.objects.create( | ||||||
|  |             blog=blog_2, | ||||||
|  |             title="Blog Post #2", | ||||||
|             is_published=True, |             is_published=True, | ||||||
|                                published_date=datetime(2010, 1, 6, 0, 0, 0)) |             published_date=datetime.datetime(2010, 1, 6, 0, 0, 0) | ||||||
|         blog_post_3 = BlogPost(blog=blog_3, title="Blog Post #3", |         ) | ||||||
|  |         BlogPost.objects.create( | ||||||
|  |             blog=blog_3, | ||||||
|  |             title="Blog Post #3", | ||||||
|             is_published=True, |             is_published=True, | ||||||
|                                published_date=datetime(2010, 1, 7, 0, 0, 0)) |             published_date=datetime.datetime(2010, 1, 7, 0, 0, 0) | ||||||
|  |         ) | ||||||
|         blog_post_1.save() |  | ||||||
|         blog_post_2.save() |  | ||||||
|         blog_post_3.save() |  | ||||||
|  |  | ||||||
|         # find all published blog posts before 2010-01-07 |         # find all published blog posts before 2010-01-07 | ||||||
|         published_posts = BlogPost.published() |         published_posts = BlogPost.published() | ||||||
|         published_posts = published_posts.filter( |         published_posts = published_posts.filter( | ||||||
|             published_date__lt=datetime(2010, 1, 7, 0, 0, 0)) |             published_date__lt=datetime.datetime(2010, 1, 7, 0, 0, 0)) | ||||||
|         self.assertEqual(published_posts.count(), 2) |         self.assertEqual(published_posts.count(), 2) | ||||||
|  |  | ||||||
|         blog_posts = BlogPost.objects |         blog_posts = BlogPost.objects | ||||||
| @@ -1152,16 +1164,18 @@ class QuerySetTest(unittest.TestCase): | |||||||
|  |  | ||||||
|         BlogPost.drop_collection() |         BlogPost.drop_collection() | ||||||
|  |  | ||||||
|         blog_post_1 = BlogPost(title="Blog Post #1", |         blog_post_1 = BlogPost.objects.create( | ||||||
|                                published_date=datetime(2010, 1, 5, 0, 0, 0)) |             title="Blog Post #1", | ||||||
|         blog_post_2 = BlogPost(title="Blog Post #2", |             published_date=datetime.datetime(2010, 1, 5, 0, 0, 0) | ||||||
|                                published_date=datetime(2010, 1, 6, 0, 0, 0)) |         ) | ||||||
|         blog_post_3 = BlogPost(title="Blog Post #3", |         blog_post_2 = BlogPost.objects.create( | ||||||
|                                published_date=datetime(2010, 1, 7, 0, 0, 0)) |             title="Blog Post #2", | ||||||
|  |             published_date=datetime.datetime(2010, 1, 6, 0, 0, 0) | ||||||
|         blog_post_1.save() |         ) | ||||||
|         blog_post_2.save() |         blog_post_3 = BlogPost.objects.create( | ||||||
|         blog_post_3.save() |             title="Blog Post #3", | ||||||
|  |             published_date=datetime.datetime(2010, 1, 7, 0, 0, 0) | ||||||
|  |         ) | ||||||
|  |  | ||||||
|         # get the "first" BlogPost using default ordering |         # get the "first" BlogPost using default ordering | ||||||
|         # from BlogPost.meta.ordering |         # from BlogPost.meta.ordering | ||||||
| @@ -1210,7 +1224,7 @@ class QuerySetTest(unittest.TestCase): | |||||||
|             } |             } | ||||||
|  |  | ||||||
|         BlogPost.objects.create( |         BlogPost.objects.create( | ||||||
|             title='whatever', published_date=datetime.utcnow()) |             title='whatever', published_date=datetime.datetime.utcnow()) | ||||||
|  |  | ||||||
|         with db_ops_tracker() as q: |         with db_ops_tracker() as q: | ||||||
|             BlogPost.objects.get(title='whatever') |             BlogPost.objects.get(title='whatever') | ||||||
| @@ -1224,7 +1238,8 @@ class QuerySetTest(unittest.TestCase): | |||||||
|             self.assertFalse('$orderby' in q.get_ops()[0]['query']) |             self.assertFalse('$orderby' in q.get_ops()[0]['query']) | ||||||
|  |  | ||||||
|     def test_find_embedded(self): |     def test_find_embedded(self): | ||||||
|         """Ensure that an embedded document is properly returned from a query. |         """Ensure that an embedded document is properly returned from | ||||||
|  |         a query. | ||||||
|         """ |         """ | ||||||
|         class User(EmbeddedDocument): |         class User(EmbeddedDocument): | ||||||
|             name = StringField() |             name = StringField() | ||||||
| @@ -1235,16 +1250,31 @@ class QuerySetTest(unittest.TestCase): | |||||||
|  |  | ||||||
|         BlogPost.drop_collection() |         BlogPost.drop_collection() | ||||||
|  |  | ||||||
|         post = BlogPost(content='Had a good coffee today...') |         BlogPost.objects.create( | ||||||
|         post.author = User(name='Test User') |             author=User(name='Test User'), | ||||||
|         post.save() |             content='Had a good coffee today...' | ||||||
|  |         ) | ||||||
|  |  | ||||||
|         result = BlogPost.objects.first() |         result = BlogPost.objects.first() | ||||||
|         self.assertTrue(isinstance(result.author, User)) |         self.assertTrue(isinstance(result.author, User)) | ||||||
|         self.assertEqual(result.author.name, 'Test User') |         self.assertEqual(result.author.name, 'Test User') | ||||||
|  |  | ||||||
|  |     def test_find_empty_embedded(self): | ||||||
|  |         """Ensure that you can save and find an empty embedded document.""" | ||||||
|  |         class User(EmbeddedDocument): | ||||||
|  |             name = StringField() | ||||||
|  |  | ||||||
|  |         class BlogPost(Document): | ||||||
|  |             content = StringField() | ||||||
|  |             author = EmbeddedDocumentField(User) | ||||||
|  |  | ||||||
|         BlogPost.drop_collection() |         BlogPost.drop_collection() | ||||||
|  |  | ||||||
|  |         BlogPost.objects.create(content='Anonymous post...') | ||||||
|  |  | ||||||
|  |         result = BlogPost.objects.get(author=None) | ||||||
|  |         self.assertEqual(result.author, None) | ||||||
|  |  | ||||||
|     def test_find_dict_item(self): |     def test_find_dict_item(self): | ||||||
|         """Ensure that DictField items may be found. |         """Ensure that DictField items may be found. | ||||||
|         """ |         """ | ||||||
| @@ -2073,18 +2103,22 @@ class QuerySetTest(unittest.TestCase): | |||||||
|  |  | ||||||
|         BlogPost.drop_collection() |         BlogPost.drop_collection() | ||||||
|  |  | ||||||
|         blog_post_3 = BlogPost(title="Blog Post #3", |         blog_post_3 = BlogPost.objects.create( | ||||||
|                                published_date=datetime(2010, 1, 6, 0, 0, 0)) |             title="Blog Post #3", | ||||||
|         blog_post_2 = BlogPost(title="Blog Post #2", |             published_date=datetime.datetime(2010, 1, 6, 0, 0, 0) | ||||||
|                                published_date=datetime(2010, 1, 5, 0, 0, 0)) |         ) | ||||||
|         blog_post_4 = BlogPost(title="Blog Post #4", |         blog_post_2 = BlogPost.objects.create( | ||||||
|                                published_date=datetime(2010, 1, 7, 0, 0, 0)) |             title="Blog Post #2", | ||||||
|         blog_post_1 = BlogPost(title="Blog Post #1", published_date=None) |             published_date=datetime.datetime(2010, 1, 5, 0, 0, 0) | ||||||
|  |         ) | ||||||
|         blog_post_3.save() |         blog_post_4 = BlogPost.objects.create( | ||||||
|         blog_post_1.save() |             title="Blog Post #4", | ||||||
|         blog_post_4.save() |             published_date=datetime.datetime(2010, 1, 7, 0, 0, 0) | ||||||
|         blog_post_2.save() |         ) | ||||||
|  |         blog_post_1 = BlogPost.objects.create( | ||||||
|  |             title="Blog Post #1", | ||||||
|  |             published_date=None | ||||||
|  |         ) | ||||||
|  |  | ||||||
|         expected = [blog_post_1, blog_post_2, blog_post_3, blog_post_4] |         expected = [blog_post_1, blog_post_2, blog_post_3, blog_post_4] | ||||||
|         self.assertSequence(BlogPost.objects.order_by('published_date'), |         self.assertSequence(BlogPost.objects.order_by('published_date'), | ||||||
| @@ -2103,16 +2137,18 @@ class QuerySetTest(unittest.TestCase): | |||||||
|  |  | ||||||
|         BlogPost.drop_collection() |         BlogPost.drop_collection() | ||||||
|  |  | ||||||
|         blog_post_1 = BlogPost(title="A", |         blog_post_1 = BlogPost.objects.create( | ||||||
|                                published_date=datetime(2010, 1, 6, 0, 0, 0)) |             title="A", | ||||||
|         blog_post_2 = BlogPost(title="B", |             published_date=datetime.datetime(2010, 1, 6, 0, 0, 0) | ||||||
|                                published_date=datetime(2010, 1, 6, 0, 0, 0)) |         ) | ||||||
|         blog_post_3 = BlogPost(title="C", |         blog_post_2 = BlogPost.objects.create( | ||||||
|                                published_date=datetime(2010, 1, 7, 0, 0, 0)) |             title="B", | ||||||
|  |             published_date=datetime.datetime(2010, 1, 6, 0, 0, 0) | ||||||
|         blog_post_2.save() |         ) | ||||||
|         blog_post_3.save() |         blog_post_3 = BlogPost.objects.create( | ||||||
|         blog_post_1.save() |             title="C", | ||||||
|  |             published_date=datetime.datetime(2010, 1, 7, 0, 0, 0) | ||||||
|  |         ) | ||||||
|  |  | ||||||
|         qs = BlogPost.objects.order_by('published_date', 'title') |         qs = BlogPost.objects.order_by('published_date', 'title') | ||||||
|         expected = [blog_post_1, blog_post_2, blog_post_3] |         expected = [blog_post_1, blog_post_2, blog_post_3] | ||||||
| @@ -2178,6 +2214,21 @@ class QuerySetTest(unittest.TestCase): | |||||||
|             a.author.name for a in Author.objects.order_by('-author__age')] |             a.author.name for a in Author.objects.order_by('-author__age')] | ||||||
|         self.assertEqual(names, ['User A', 'User B', 'User C']) |         self.assertEqual(names, ['User A', 'User B', 'User C']) | ||||||
|  |  | ||||||
|  |     def test_comment(self): | ||||||
|  |         """Make sure adding a comment to the query works.""" | ||||||
|  |         class User(Document): | ||||||
|  |             age = IntField() | ||||||
|  |  | ||||||
|  |         with db_ops_tracker() as q: | ||||||
|  |             adult = (User.objects.filter(age__gte=18) | ||||||
|  |                 .comment('looking for an adult') | ||||||
|  |                 .first()) | ||||||
|  |             ops = q.get_ops() | ||||||
|  |             self.assertEqual(len(ops), 1) | ||||||
|  |             op = ops[0] | ||||||
|  |             self.assertEqual(op['query']['$query'], {'age': {'$gte': 18}}) | ||||||
|  |             self.assertEqual(op['query']['$comment'], 'looking for an adult') | ||||||
|  |  | ||||||
|     def test_map_reduce(self): |     def test_map_reduce(self): | ||||||
|         """Ensure map/reduce is both mapping and reducing. |         """Ensure map/reduce is both mapping and reducing. | ||||||
|         """ |         """ | ||||||
| @@ -2416,7 +2467,7 @@ class QuerySetTest(unittest.TestCase): | |||||||
|  |  | ||||||
|         Link.drop_collection() |         Link.drop_collection() | ||||||
|  |  | ||||||
|         now = datetime.utcnow() |         now = datetime.datetime.utcnow() | ||||||
|  |  | ||||||
|         # Note: Test data taken from a custom Reddit homepage on |         # Note: Test data taken from a custom Reddit homepage on | ||||||
|         # Fri, 12 Feb 2010 14:36:00 -0600. Link ordering should |         # Fri, 12 Feb 2010 14:36:00 -0600. Link ordering should | ||||||
| @@ -2425,27 +2476,27 @@ class QuerySetTest(unittest.TestCase): | |||||||
|         Link(title="Google Buzz auto-followed a woman's abusive ex ...", |         Link(title="Google Buzz auto-followed a woman's abusive ex ...", | ||||||
|              up_votes=1079, |              up_votes=1079, | ||||||
|              down_votes=553, |              down_votes=553, | ||||||
|              submitted=now - timedelta(hours=4)).save() |              submitted=now - datetime.timedelta(hours=4)).save() | ||||||
|         Link(title="We did it! Barbie is a computer engineer.", |         Link(title="We did it! Barbie is a computer engineer.", | ||||||
|              up_votes=481, |              up_votes=481, | ||||||
|              down_votes=124, |              down_votes=124, | ||||||
|              submitted=now - timedelta(hours=2)).save() |              submitted=now - datetime.timedelta(hours=2)).save() | ||||||
|         Link(title="This Is A Mosquito Getting Killed By A Laser", |         Link(title="This Is A Mosquito Getting Killed By A Laser", | ||||||
|              up_votes=1446, |              up_votes=1446, | ||||||
|              down_votes=530, |              down_votes=530, | ||||||
|              submitted=now - timedelta(hours=13)).save() |              submitted=now - datetime.timedelta(hours=13)).save() | ||||||
|         Link(title="Arabic flashcards land physics student in jail.", |         Link(title="Arabic flashcards land physics student in jail.", | ||||||
|              up_votes=215, |              up_votes=215, | ||||||
|              down_votes=105, |              down_votes=105, | ||||||
|              submitted=now - timedelta(hours=6)).save() |              submitted=now - datetime.timedelta(hours=6)).save() | ||||||
|         Link(title="The Burger Lab: Presenting, the Flood Burger", |         Link(title="The Burger Lab: Presenting, the Flood Burger", | ||||||
|              up_votes=48, |              up_votes=48, | ||||||
|              down_votes=17, |              down_votes=17, | ||||||
|              submitted=now - timedelta(hours=5)).save() |              submitted=now - datetime.timedelta(hours=5)).save() | ||||||
|         Link(title="How to see polarization with the naked eye", |         Link(title="How to see polarization with the naked eye", | ||||||
|              up_votes=74, |              up_votes=74, | ||||||
|              down_votes=13, |              down_votes=13, | ||||||
|              submitted=now - timedelta(hours=10)).save() |              submitted=now - datetime.timedelta(hours=10)).save() | ||||||
|  |  | ||||||
|         map_f = """ |         map_f = """ | ||||||
|             function() { |             function() { | ||||||
| @@ -2495,7 +2546,7 @@ class QuerySetTest(unittest.TestCase): | |||||||
|  |  | ||||||
|         # provide the reddit epoch (used for ranking) as a variable available |         # provide the reddit epoch (used for ranking) as a variable available | ||||||
|         # to all phases of the map/reduce operation: map, reduce, and finalize. |         # to all phases of the map/reduce operation: map, reduce, and finalize. | ||||||
|         reddit_epoch = mktime(datetime(2005, 12, 8, 7, 46, 43).timetuple()) |         reddit_epoch = mktime(datetime.datetime(2005, 12, 8, 7, 46, 43).timetuple()) | ||||||
|         scope = {'reddit_epoch': reddit_epoch} |         scope = {'reddit_epoch': reddit_epoch} | ||||||
|  |  | ||||||
|         # run a map/reduce operation across all links. ordering is set |         # run a map/reduce operation across all links. ordering is set | ||||||
| @@ -2757,25 +2808,15 @@ class QuerySetTest(unittest.TestCase): | |||||||
|  |  | ||||||
|         avg = float(sum(ages)) / (len(ages) + 1)  # take into account the 0 |         avg = float(sum(ages)) / (len(ages) + 1)  # take into account the 0 | ||||||
|         self.assertAlmostEqual(int(self.Person.objects.average('age')), avg) |         self.assertAlmostEqual(int(self.Person.objects.average('age')), avg) | ||||||
|         self.assertAlmostEqual( |  | ||||||
|             int(self.Person.objects.aggregate_average('age')), avg |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|         self.Person(name='ageless person').save() |         self.Person(name='ageless person').save() | ||||||
|         self.assertEqual(int(self.Person.objects.average('age')), avg) |         self.assertEqual(int(self.Person.objects.average('age')), avg) | ||||||
|         self.assertEqual( |  | ||||||
|             int(self.Person.objects.aggregate_average('age')), avg |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|         # dot notation |         # dot notation | ||||||
|         self.Person( |         self.Person( | ||||||
|             name='person meta', person_meta=self.PersonMeta(weight=0)).save() |             name='person meta', person_meta=self.PersonMeta(weight=0)).save() | ||||||
|         self.assertAlmostEqual( |         self.assertAlmostEqual( | ||||||
|             int(self.Person.objects.average('person_meta.weight')), 0) |             int(self.Person.objects.average('person_meta.weight')), 0) | ||||||
|         self.assertAlmostEqual( |  | ||||||
|             int(self.Person.objects.aggregate_average('person_meta.weight')), |  | ||||||
|             0 |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|         for i, weight in enumerate(ages): |         for i, weight in enumerate(ages): | ||||||
|             self.Person( |             self.Person( | ||||||
| @@ -2784,19 +2825,11 @@ class QuerySetTest(unittest.TestCase): | |||||||
|         self.assertAlmostEqual( |         self.assertAlmostEqual( | ||||||
|             int(self.Person.objects.average('person_meta.weight')), avg |             int(self.Person.objects.average('person_meta.weight')), avg | ||||||
|         ) |         ) | ||||||
|         self.assertAlmostEqual( |  | ||||||
|             int(self.Person.objects.aggregate_average('person_meta.weight')), |  | ||||||
|             avg |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|         self.Person(name='test meta none').save() |         self.Person(name='test meta none').save() | ||||||
|         self.assertEqual( |         self.assertEqual( | ||||||
|             int(self.Person.objects.average('person_meta.weight')), avg |             int(self.Person.objects.average('person_meta.weight')), avg | ||||||
|         ) |         ) | ||||||
|         self.assertEqual( |  | ||||||
|             int(self.Person.objects.aggregate_average('person_meta.weight')), |  | ||||||
|             avg |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|         # test summing over a filtered queryset |         # test summing over a filtered queryset | ||||||
|         over_50 = [a for a in ages if a >= 50] |         over_50 = [a for a in ages if a >= 50] | ||||||
| @@ -2805,10 +2838,6 @@ class QuerySetTest(unittest.TestCase): | |||||||
|             self.Person.objects.filter(age__gte=50).average('age'), |             self.Person.objects.filter(age__gte=50).average('age'), | ||||||
|             avg |             avg | ||||||
|         ) |         ) | ||||||
|         self.assertEqual( |  | ||||||
|             self.Person.objects.filter(age__gte=50).aggregate_average('age'), |  | ||||||
|             avg |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     def test_sum(self): |     def test_sum(self): | ||||||
|         """Ensure that field can be summed over correctly. |         """Ensure that field can be summed over correctly. | ||||||
| @@ -2818,15 +2847,9 @@ class QuerySetTest(unittest.TestCase): | |||||||
|             self.Person(name='test%s' % i, age=age).save() |             self.Person(name='test%s' % i, age=age).save() | ||||||
|  |  | ||||||
|         self.assertEqual(self.Person.objects.sum('age'), sum(ages)) |         self.assertEqual(self.Person.objects.sum('age'), sum(ages)) | ||||||
|         self.assertEqual( |  | ||||||
|             self.Person.objects.aggregate_sum('age'), sum(ages) |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|         self.Person(name='ageless person').save() |         self.Person(name='ageless person').save() | ||||||
|         self.assertEqual(self.Person.objects.sum('age'), sum(ages)) |         self.assertEqual(self.Person.objects.sum('age'), sum(ages)) | ||||||
|         self.assertEqual( |  | ||||||
|             self.Person.objects.aggregate_sum('age'), sum(ages) |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|         for i, age in enumerate(ages): |         for i, age in enumerate(ages): | ||||||
|             self.Person(name='test meta%s' % |             self.Person(name='test meta%s' % | ||||||
| @@ -2835,26 +2858,43 @@ class QuerySetTest(unittest.TestCase): | |||||||
|         self.assertEqual( |         self.assertEqual( | ||||||
|             self.Person.objects.sum('person_meta.weight'), sum(ages) |             self.Person.objects.sum('person_meta.weight'), sum(ages) | ||||||
|         ) |         ) | ||||||
|         self.assertEqual( |  | ||||||
|             self.Person.objects.aggregate_sum('person_meta.weight'), |  | ||||||
|             sum(ages) |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|         self.Person(name='weightless person').save() |         self.Person(name='weightless person').save() | ||||||
|         self.assertEqual(self.Person.objects.sum('age'), sum(ages)) |         self.assertEqual(self.Person.objects.sum('age'), sum(ages)) | ||||||
|         self.assertEqual( |  | ||||||
|             self.Person.objects.aggregate_sum('age'), sum(ages) |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|         # test summing over a filtered queryset |         # test summing over a filtered queryset | ||||||
|         self.assertEqual( |         self.assertEqual( | ||||||
|             self.Person.objects.filter(age__gte=50).sum('age'), |             self.Person.objects.filter(age__gte=50).sum('age'), | ||||||
|             sum([a for a in ages if a >= 50]) |             sum([a for a in ages if a >= 50]) | ||||||
|         ) |         ) | ||||||
|         self.assertEqual( |  | ||||||
|             self.Person.objects.filter(age__gte=50).aggregate_sum('age'), |     def test_sum_over_db_field(self): | ||||||
|             sum([a for a in ages if a >= 50]) |         """Ensure that a field mapped to a db field with a different name | ||||||
|         ) |         can be summed over correctly. | ||||||
|  |         """ | ||||||
|  |         class UserVisit(Document): | ||||||
|  |             num_visits = IntField(db_field='visits') | ||||||
|  |  | ||||||
|  |         UserVisit.drop_collection() | ||||||
|  |  | ||||||
|  |         UserVisit.objects.create(num_visits=10) | ||||||
|  |         UserVisit.objects.create(num_visits=5) | ||||||
|  |  | ||||||
|  |         self.assertEqual(UserVisit.objects.sum('num_visits'), 15) | ||||||
|  |  | ||||||
|  |     def test_average_over_db_field(self): | ||||||
|  |         """Ensure that a field mapped to a db field with a different name | ||||||
|  |         can have its average computed correctly. | ||||||
|  |         """ | ||||||
|  |         class UserVisit(Document): | ||||||
|  |             num_visits = IntField(db_field='visits') | ||||||
|  |  | ||||||
|  |         UserVisit.drop_collection() | ||||||
|  |  | ||||||
|  |         UserVisit.objects.create(num_visits=20) | ||||||
|  |         UserVisit.objects.create(num_visits=10) | ||||||
|  |  | ||||||
|  |         self.assertEqual(UserVisit.objects.average('num_visits'), 15) | ||||||
|  |  | ||||||
|     def test_embedded_average(self): |     def test_embedded_average(self): | ||||||
|         class Pay(EmbeddedDocument): |         class Pay(EmbeddedDocument): | ||||||
| @@ -2867,21 +2907,12 @@ class QuerySetTest(unittest.TestCase): | |||||||
|  |  | ||||||
|         Doc.drop_collection() |         Doc.drop_collection() | ||||||
|  |  | ||||||
|         Doc(name=u"Wilson Junior", |         Doc(name='Wilson Junior', pay=Pay(value=150)).save() | ||||||
|             pay=Pay(value=150)).save() |         Doc(name='Isabella Luanna', pay=Pay(value=530)).save() | ||||||
|  |         Doc(name='Tayza mariana', pay=Pay(value=165)).save() | ||||||
|  |         Doc(name='Eliana Costa', pay=Pay(value=115)).save() | ||||||
|  |  | ||||||
|         Doc(name=u"Isabella Luanna", |         self.assertEqual(Doc.objects.average('pay.value'), 240) | ||||||
|             pay=Pay(value=530)).save() |  | ||||||
|  |  | ||||||
|         Doc(name=u"Tayza mariana", |  | ||||||
|             pay=Pay(value=165)).save() |  | ||||||
|  |  | ||||||
|         Doc(name=u"Eliana Costa", |  | ||||||
|             pay=Pay(value=115)).save() |  | ||||||
|  |  | ||||||
|         self.assertEqual( |  | ||||||
|             Doc.objects.average('pay.value'), |  | ||||||
|             240) |  | ||||||
|  |  | ||||||
|     def test_embedded_array_average(self): |     def test_embedded_array_average(self): | ||||||
|         class Pay(EmbeddedDocument): |         class Pay(EmbeddedDocument): | ||||||
| @@ -2889,26 +2920,16 @@ class QuerySetTest(unittest.TestCase): | |||||||
|  |  | ||||||
|         class Doc(Document): |         class Doc(Document): | ||||||
|             name = StringField() |             name = StringField() | ||||||
|             pay = EmbeddedDocumentField( |             pay = EmbeddedDocumentField(Pay) | ||||||
|                 Pay) |  | ||||||
|  |  | ||||||
|         Doc.drop_collection() |         Doc.drop_collection() | ||||||
|  |  | ||||||
|         Doc(name=u"Wilson Junior", |         Doc(name='Wilson Junior', pay=Pay(values=[150, 100])).save() | ||||||
|             pay=Pay(values=[150, 100])).save() |         Doc(name='Isabella Luanna', pay=Pay(values=[530, 100])).save() | ||||||
|  |         Doc(name='Tayza mariana', pay=Pay(values=[165, 100])).save() | ||||||
|  |         Doc(name='Eliana Costa', pay=Pay(values=[115, 100])).save() | ||||||
|  |  | ||||||
|         Doc(name=u"Isabella Luanna", |         self.assertEqual(Doc.objects.average('pay.values'), 170) | ||||||
|             pay=Pay(values=[530, 100])).save() |  | ||||||
|  |  | ||||||
|         Doc(name=u"Tayza mariana", |  | ||||||
|             pay=Pay(values=[165, 100])).save() |  | ||||||
|  |  | ||||||
|         Doc(name=u"Eliana Costa", |  | ||||||
|             pay=Pay(values=[115, 100])).save() |  | ||||||
|  |  | ||||||
|         self.assertEqual( |  | ||||||
|             Doc.objects.average('pay.values'), |  | ||||||
|             170) |  | ||||||
|  |  | ||||||
|     def test_array_average(self): |     def test_array_average(self): | ||||||
|         class Doc(Document): |         class Doc(Document): | ||||||
| @@ -2921,9 +2942,7 @@ class QuerySetTest(unittest.TestCase): | |||||||
|         Doc(values=[165, 100]).save() |         Doc(values=[165, 100]).save() | ||||||
|         Doc(values=[115, 100]).save() |         Doc(values=[115, 100]).save() | ||||||
|  |  | ||||||
|         self.assertEqual( |         self.assertEqual(Doc.objects.average('values'), 170) | ||||||
|             Doc.objects.average('values'), |  | ||||||
|             170) |  | ||||||
|  |  | ||||||
|     def test_embedded_sum(self): |     def test_embedded_sum(self): | ||||||
|         class Pay(EmbeddedDocument): |         class Pay(EmbeddedDocument): | ||||||
| @@ -2931,26 +2950,16 @@ class QuerySetTest(unittest.TestCase): | |||||||
|  |  | ||||||
|         class Doc(Document): |         class Doc(Document): | ||||||
|             name = StringField() |             name = StringField() | ||||||
|             pay = EmbeddedDocumentField( |             pay = EmbeddedDocumentField(Pay) | ||||||
|                 Pay) |  | ||||||
|  |  | ||||||
|         Doc.drop_collection() |         Doc.drop_collection() | ||||||
|  |  | ||||||
|         Doc(name=u"Wilson Junior", |         Doc(name='Wilson Junior', pay=Pay(value=150)).save() | ||||||
|             pay=Pay(value=150)).save() |         Doc(name='Isabella Luanna', pay=Pay(value=530)).save() | ||||||
|  |         Doc(name='Tayza mariana', pay=Pay(value=165)).save() | ||||||
|  |         Doc(name='Eliana Costa', pay=Pay(value=115)).save() | ||||||
|  |  | ||||||
|         Doc(name=u"Isabella Luanna", |         self.assertEqual(Doc.objects.sum('pay.value'), 960) | ||||||
|             pay=Pay(value=530)).save() |  | ||||||
|  |  | ||||||
|         Doc(name=u"Tayza mariana", |  | ||||||
|             pay=Pay(value=165)).save() |  | ||||||
|  |  | ||||||
|         Doc(name=u"Eliana Costa", |  | ||||||
|             pay=Pay(value=115)).save() |  | ||||||
|  |  | ||||||
|         self.assertEqual( |  | ||||||
|             Doc.objects.sum('pay.value'), |  | ||||||
|             960) |  | ||||||
|  |  | ||||||
|     def test_embedded_array_sum(self): |     def test_embedded_array_sum(self): | ||||||
|         class Pay(EmbeddedDocument): |         class Pay(EmbeddedDocument): | ||||||
| @@ -2958,26 +2967,16 @@ class QuerySetTest(unittest.TestCase): | |||||||
|  |  | ||||||
|         class Doc(Document): |         class Doc(Document): | ||||||
|             name = StringField() |             name = StringField() | ||||||
|             pay = EmbeddedDocumentField( |             pay = EmbeddedDocumentField(Pay) | ||||||
|                 Pay) |  | ||||||
|  |  | ||||||
|         Doc.drop_collection() |         Doc.drop_collection() | ||||||
|  |  | ||||||
|         Doc(name=u"Wilson Junior", |         Doc(name='Wilson Junior', pay=Pay(values=[150, 100])).save() | ||||||
|             pay=Pay(values=[150, 100])).save() |         Doc(name='Isabella Luanna', pay=Pay(values=[530, 100])).save() | ||||||
|  |         Doc(name='Tayza mariana', pay=Pay(values=[165, 100])).save() | ||||||
|  |         Doc(name='Eliana Costa', pay=Pay(values=[115, 100])).save() | ||||||
|  |  | ||||||
|         Doc(name=u"Isabella Luanna", |         self.assertEqual(Doc.objects.sum('pay.values'), 1360) | ||||||
|             pay=Pay(values=[530, 100])).save() |  | ||||||
|  |  | ||||||
|         Doc(name=u"Tayza mariana", |  | ||||||
|             pay=Pay(values=[165, 100])).save() |  | ||||||
|  |  | ||||||
|         Doc(name=u"Eliana Costa", |  | ||||||
|             pay=Pay(values=[115, 100])).save() |  | ||||||
|  |  | ||||||
|         self.assertEqual( |  | ||||||
|             Doc.objects.sum('pay.values'), |  | ||||||
|             1360) |  | ||||||
|  |  | ||||||
|     def test_array_sum(self): |     def test_array_sum(self): | ||||||
|         class Doc(Document): |         class Doc(Document): | ||||||
| @@ -2990,9 +2989,7 @@ class QuerySetTest(unittest.TestCase): | |||||||
|         Doc(values=[165, 100]).save() |         Doc(values=[165, 100]).save() | ||||||
|         Doc(values=[115, 100]).save() |         Doc(values=[115, 100]).save() | ||||||
|  |  | ||||||
|         self.assertEqual( |         self.assertEqual(Doc.objects.sum('values'), 1360) | ||||||
|             Doc.objects.sum('values'), |  | ||||||
|             1360) |  | ||||||
|  |  | ||||||
|     def test_distinct(self): |     def test_distinct(self): | ||||||
|         """Ensure that the QuerySet.distinct method works. |         """Ensure that the QuerySet.distinct method works. | ||||||
| @@ -3169,13 +3166,11 @@ class QuerySetTest(unittest.TestCase): | |||||||
|         mark_twain = Author(name="Mark Twain") |         mark_twain = Author(name="Mark Twain") | ||||||
|         john_tolkien = Author(name="John Ronald Reuel Tolkien") |         john_tolkien = Author(name="John Ronald Reuel Tolkien") | ||||||
|  |  | ||||||
|         book = Book(title="Tom Sawyer", authors=[mark_twain]).save() |         Book.objects.create(title="Tom Sawyer", authors=[mark_twain]) | ||||||
|         book = Book( |         Book.objects.create(title="The Lord of the Rings", authors=[john_tolkien]) | ||||||
|             title="The Lord of the Rings", authors=[john_tolkien]).save() |         Book.objects.create(title="The Stories", authors=[mark_twain, john_tolkien]) | ||||||
|         book = Book( |  | ||||||
|             title="The Stories", authors=[mark_twain, john_tolkien]).save() |  | ||||||
|         authors = Book.objects.distinct("authors") |  | ||||||
|  |  | ||||||
|  |         authors = Book.objects.distinct("authors") | ||||||
|         self.assertEqual(authors, [mark_twain, john_tolkien]) |         self.assertEqual(authors, [mark_twain, john_tolkien]) | ||||||
|  |  | ||||||
|     def test_distinct_ListField_EmbeddedDocumentField_EmbeddedDocumentField(self): |     def test_distinct_ListField_EmbeddedDocumentField_EmbeddedDocumentField(self): | ||||||
| @@ -3205,17 +3200,14 @@ class QuerySetTest(unittest.TestCase): | |||||||
|         mark_twain = Author(name="Mark Twain", country=scotland) |         mark_twain = Author(name="Mark Twain", country=scotland) | ||||||
|         john_tolkien = Author(name="John Ronald Reuel Tolkien", country=tibet) |         john_tolkien = Author(name="John Ronald Reuel Tolkien", country=tibet) | ||||||
|  |  | ||||||
|         book = Book(title="Tom Sawyer", authors=[mark_twain]).save() |         Book.objects.create(title="Tom Sawyer", authors=[mark_twain]) | ||||||
|         book = Book( |         Book.objects.create(title="The Lord of the Rings", authors=[john_tolkien]) | ||||||
|             title="The Lord of the Rings", authors=[john_tolkien]).save() |         Book.objects.create(title="The Stories", authors=[mark_twain, john_tolkien]) | ||||||
|         book = Book( |  | ||||||
|             title="The Stories", authors=[mark_twain, john_tolkien]).save() |  | ||||||
|         country_list = Book.objects.distinct("authors.country") |  | ||||||
|  |  | ||||||
|  |         country_list = Book.objects.distinct("authors.country") | ||||||
|         self.assertEqual(country_list, [scotland, tibet]) |         self.assertEqual(country_list, [scotland, tibet]) | ||||||
|  |  | ||||||
|         continent_list = Book.objects.distinct("authors.country.continent") |         continent_list = Book.objects.distinct("authors.country.continent") | ||||||
|  |  | ||||||
|         self.assertEqual(continent_list, [europe, asia]) |         self.assertEqual(continent_list, [europe, asia]) | ||||||
|  |  | ||||||
|     def test_distinct_ListField_ReferenceField(self): |     def test_distinct_ListField_ReferenceField(self): | ||||||
| @@ -3247,7 +3239,7 @@ class QuerySetTest(unittest.TestCase): | |||||||
|         class BlogPost(Document): |         class BlogPost(Document): | ||||||
|             tags = ListField(StringField()) |             tags = ListField(StringField()) | ||||||
|             deleted = BooleanField(default=False) |             deleted = BooleanField(default=False) | ||||||
|             date = DateTimeField(default=datetime.now) |             date = DateTimeField(default=datetime.datetime.now) | ||||||
|  |  | ||||||
|             @queryset_manager |             @queryset_manager | ||||||
|             def objects(cls, qryset): |             def objects(cls, qryset): | ||||||
| @@ -3604,6 +3596,15 @@ class QuerySetTest(unittest.TestCase): | |||||||
|         self.assertEqual(MyDoc.objects.count(), 10) |         self.assertEqual(MyDoc.objects.count(), 10) | ||||||
|         self.assertEqual(MyDoc.objects.none().count(), 0) |         self.assertEqual(MyDoc.objects.none().count(), 0) | ||||||
|  |  | ||||||
|  |     def test_count_list_embedded(self): | ||||||
|  |         class B(EmbeddedDocument): | ||||||
|  |             c = StringField() | ||||||
|  |  | ||||||
|  |         class A(Document): | ||||||
|  |             b = ListField(EmbeddedDocumentField(B)) | ||||||
|  |  | ||||||
|  |         self.assertEqual(A.objects(b=[{'c': 'c'}]).count(), 0) | ||||||
|  |  | ||||||
|     def test_call_after_limits_set(self): |     def test_call_after_limits_set(self): | ||||||
|         """Ensure that re-filtering after slicing works |         """Ensure that re-filtering after slicing works | ||||||
|         """ |         """ | ||||||
| @@ -4061,8 +4062,8 @@ class QuerySetTest(unittest.TestCase): | |||||||
|         self.assertEqual( |         self.assertEqual( | ||||||
|             "A0", "%s" % self.Person.objects.scalar('name').order_by('name')[0]) |             "A0", "%s" % self.Person.objects.scalar('name').order_by('name')[0]) | ||||||
|         if PY3: |         if PY3: | ||||||
|             self.assertEqual( |             self.assertEqual("['A1', 'A2']", "%s" % self.Person.objects.order_by( | ||||||
|                 "['A1', 'A2']",  "%s" % self.Person.objects.order_by('age').scalar('name')[1:3]) |                 'age').scalar('name')[1:3]) | ||||||
|             self.assertEqual("['A51', 'A52']", "%s" % self.Person.objects.order_by( |             self.assertEqual("['A51', 'A52']", "%s" % self.Person.objects.order_by( | ||||||
|                 'age').scalar('name')[51:53]) |                 'age').scalar('name')[51:53]) | ||||||
|         else: |         else: | ||||||
| @@ -4077,12 +4078,12 @@ class QuerySetTest(unittest.TestCase): | |||||||
|                          self.Person.objects.scalar('name').with_id(person.id)) |                          self.Person.objects.scalar('name').with_id(person.id)) | ||||||
|  |  | ||||||
|         pks = self.Person.objects.order_by('age').scalar('pk')[1:3] |         pks = self.Person.objects.order_by('age').scalar('pk')[1:3] | ||||||
|  |         names = self.Person.objects.scalar('name').in_bulk(list(pks)).values() | ||||||
|         if PY3: |         if PY3: | ||||||
|             self.assertEqual("['A1', 'A2']",  "%s" % sorted( |             expected = "['A1', 'A2']" | ||||||
|                 self.Person.objects.scalar('name').in_bulk(list(pks)).values())) |  | ||||||
|         else: |         else: | ||||||
|             self.assertEqual("[u'A1', u'A2']",  "%s" % sorted( |             expected = "[u'A1', u'A2']" | ||||||
|                 self.Person.objects.scalar('name').in_bulk(list(pks)).values())) |         self.assertEqual(expected, "%s" % sorted(names)) | ||||||
|  |  | ||||||
|     def test_elem_match(self): |     def test_elem_match(self): | ||||||
|         class Foo(EmbeddedDocument): |         class Foo(EmbeddedDocument): | ||||||
| @@ -4105,6 +4106,10 @@ class QuerySetTest(unittest.TestCase): | |||||||
|                       Foo(shape="circle", color="purple", thick=False)]) |                       Foo(shape="circle", color="purple", thick=False)]) | ||||||
|         b2.save() |         b2.save() | ||||||
|  |  | ||||||
|  |         b3 = Bar(foo=[Foo(shape="square", thick=True), | ||||||
|  |                       Foo(shape="circle", color="purple", thick=False)]) | ||||||
|  |         b3.save() | ||||||
|  |  | ||||||
|         ak = list( |         ak = list( | ||||||
|             Bar.objects(foo__match={'shape': "square", "color": "purple"})) |             Bar.objects(foo__match={'shape': "square", "color": "purple"})) | ||||||
|         self.assertEqual([b1], ak) |         self.assertEqual([b1], ak) | ||||||
| @@ -4124,6 +4129,13 @@ class QuerySetTest(unittest.TestCase): | |||||||
|             Bar.objects(foo__match={'shape': "square", "color__exists": True})) |             Bar.objects(foo__match={'shape': "square", "color__exists": True})) | ||||||
|         self.assertEqual([b1, b2], ak) |         self.assertEqual([b1, b2], ak) | ||||||
|  |  | ||||||
|  |         ak = list( | ||||||
|  |             Bar.objects(foo__elemMatch={'shape': "square", "color__exists": False})) | ||||||
|  |         self.assertEqual([b3], ak) | ||||||
|  |  | ||||||
|  |         ak = list( | ||||||
|  |             Bar.objects(foo__match={'shape': "square", "color__exists": False})) | ||||||
|  |         self.assertEqual([b3], ak) | ||||||
|  |  | ||||||
|     def test_upsert_includes_cls(self): |     def test_upsert_includes_cls(self): | ||||||
|         """Upserts should include _cls information for inheritable classes |         """Upserts should include _cls information for inheritable classes | ||||||
| @@ -4165,7 +4177,11 @@ class QuerySetTest(unittest.TestCase): | |||||||
|  |  | ||||||
|     def test_read_preference(self): |     def test_read_preference(self): | ||||||
|         class Bar(Document): |         class Bar(Document): | ||||||
|             pass |             txt = StringField() | ||||||
|  |  | ||||||
|  |             meta = { | ||||||
|  |                 'indexes': ['txt'] | ||||||
|  |             } | ||||||
|  |  | ||||||
|         Bar.drop_collection() |         Bar.drop_collection() | ||||||
|         bars = list(Bar.objects(read_preference=ReadPreference.PRIMARY)) |         bars = list(Bar.objects(read_preference=ReadPreference.PRIMARY)) | ||||||
| @@ -4177,9 +4193,51 @@ class QuerySetTest(unittest.TestCase): | |||||||
|             error_class = TypeError |             error_class = TypeError | ||||||
|         self.assertRaises(error_class, Bar.objects, read_preference='Primary') |         self.assertRaises(error_class, Bar.objects, read_preference='Primary') | ||||||
|  |  | ||||||
|  |         # read_preference as a kwarg | ||||||
|         bars = Bar.objects(read_preference=ReadPreference.SECONDARY_PREFERRED) |         bars = Bar.objects(read_preference=ReadPreference.SECONDARY_PREFERRED) | ||||||
|         self.assertEqual( |         self.assertEqual(bars._read_preference, | ||||||
|             bars._read_preference, ReadPreference.SECONDARY_PREFERRED) |                          ReadPreference.SECONDARY_PREFERRED) | ||||||
|  |         self.assertEqual(bars._cursor._Cursor__read_preference, | ||||||
|  |                          ReadPreference.SECONDARY_PREFERRED) | ||||||
|  |  | ||||||
|  |         # read_preference as a query set method | ||||||
|  |         bars = Bar.objects.read_preference(ReadPreference.SECONDARY_PREFERRED) | ||||||
|  |         self.assertEqual(bars._read_preference, | ||||||
|  |                          ReadPreference.SECONDARY_PREFERRED) | ||||||
|  |         self.assertEqual(bars._cursor._Cursor__read_preference, | ||||||
|  |                          ReadPreference.SECONDARY_PREFERRED) | ||||||
|  |  | ||||||
|  |         # read_preference after skip | ||||||
|  |         bars = Bar.objects.skip(1) \ | ||||||
|  |             .read_preference(ReadPreference.SECONDARY_PREFERRED) | ||||||
|  |         self.assertEqual(bars._read_preference, | ||||||
|  |                          ReadPreference.SECONDARY_PREFERRED) | ||||||
|  |         self.assertEqual(bars._cursor._Cursor__read_preference, | ||||||
|  |                          ReadPreference.SECONDARY_PREFERRED) | ||||||
|  |  | ||||||
|  |         # read_preference after limit | ||||||
|  |         bars = Bar.objects.limit(1) \ | ||||||
|  |             .read_preference(ReadPreference.SECONDARY_PREFERRED) | ||||||
|  |         self.assertEqual(bars._read_preference, | ||||||
|  |                          ReadPreference.SECONDARY_PREFERRED) | ||||||
|  |         self.assertEqual(bars._cursor._Cursor__read_preference, | ||||||
|  |                          ReadPreference.SECONDARY_PREFERRED) | ||||||
|  |  | ||||||
|  |         # read_preference after order_by | ||||||
|  |         bars = Bar.objects.order_by('txt') \ | ||||||
|  |             .read_preference(ReadPreference.SECONDARY_PREFERRED) | ||||||
|  |         self.assertEqual(bars._read_preference, | ||||||
|  |                          ReadPreference.SECONDARY_PREFERRED) | ||||||
|  |         self.assertEqual(bars._cursor._Cursor__read_preference, | ||||||
|  |                          ReadPreference.SECONDARY_PREFERRED) | ||||||
|  |  | ||||||
|  |         # read_preference after hint | ||||||
|  |         bars = Bar.objects.hint([('txt', 1)]) \ | ||||||
|  |             .read_preference(ReadPreference.SECONDARY_PREFERRED) | ||||||
|  |         self.assertEqual(bars._read_preference, | ||||||
|  |                          ReadPreference.SECONDARY_PREFERRED) | ||||||
|  |         self.assertEqual(bars._cursor._Cursor__read_preference, | ||||||
|  |                          ReadPreference.SECONDARY_PREFERRED) | ||||||
|  |  | ||||||
|     def test_json_simple(self): |     def test_json_simple(self): | ||||||
|  |  | ||||||
| @@ -4215,7 +4273,7 @@ class QuerySetTest(unittest.TestCase): | |||||||
|             int_field = IntField(default=1) |             int_field = IntField(default=1) | ||||||
|             float_field = FloatField(default=1.1) |             float_field = FloatField(default=1.1) | ||||||
|             boolean_field = BooleanField(default=True) |             boolean_field = BooleanField(default=True) | ||||||
|             datetime_field = DateTimeField(default=datetime.now) |             datetime_field = DateTimeField(default=datetime.datetime.now) | ||||||
|             embedded_document_field = EmbeddedDocumentField( |             embedded_document_field = EmbeddedDocumentField( | ||||||
|                 EmbeddedDoc, default=lambda: EmbeddedDoc()) |                 EmbeddedDoc, default=lambda: EmbeddedDoc()) | ||||||
|             list_field = ListField(default=lambda: [1, 2, 3]) |             list_field = ListField(default=lambda: [1, 2, 3]) | ||||||
| @@ -4225,7 +4283,7 @@ class QuerySetTest(unittest.TestCase): | |||||||
|                 Simple, default=lambda: Simple().save()) |                 Simple, default=lambda: Simple().save()) | ||||||
|             map_field = MapField(IntField(), default=lambda: {"simple": 1}) |             map_field = MapField(IntField(), default=lambda: {"simple": 1}) | ||||||
|             decimal_field = DecimalField(default=1.0) |             decimal_field = DecimalField(default=1.0) | ||||||
|             complex_datetime_field = ComplexDateTimeField(default=datetime.now) |             complex_datetime_field = ComplexDateTimeField(default=datetime.datetime.now) | ||||||
|             url_field = URLField(default="http://mongoengine.org") |             url_field = URLField(default="http://mongoengine.org") | ||||||
|             dynamic_field = DynamicField(default=1) |             dynamic_field = DynamicField(default=1) | ||||||
|             generic_reference_field = GenericReferenceField( |             generic_reference_field = GenericReferenceField( | ||||||
| @@ -4572,8 +4630,7 @@ class QuerySetTest(unittest.TestCase): | |||||||
|         B.drop_collection() |         B.drop_collection() | ||||||
|  |  | ||||||
|         a = A.objects.create(id='custom_id') |         a = A.objects.create(id='custom_id') | ||||||
|  |         B.objects.create(a=a) | ||||||
|         b = B.objects.create(a=a) |  | ||||||
|  |  | ||||||
|         self.assertEqual(B.objects.count(), 1) |         self.assertEqual(B.objects.count(), 1) | ||||||
|         self.assertEqual(B.objects.get(a=a).a, a) |         self.assertEqual(B.objects.get(a=a).a, a) | ||||||
| @@ -4833,5 +4890,56 @@ class QuerySetTest(unittest.TestCase): | |||||||
|  |  | ||||||
|         self.assertEqual(1, Doc.objects(item__type__="axe").count()) |         self.assertEqual(1, Doc.objects(item__type__="axe").count()) | ||||||
|  |  | ||||||
|  |     def test_len_during_iteration(self): | ||||||
|  |         """Tests that calling len on a queyset during iteration doesn't | ||||||
|  |         stop paging. | ||||||
|  |         """ | ||||||
|  |         class Data(Document): | ||||||
|  |             pass | ||||||
|  |  | ||||||
|  |         for i in xrange(300): | ||||||
|  |             Data().save() | ||||||
|  |  | ||||||
|  |         records = Data.objects.limit(250) | ||||||
|  |  | ||||||
|  |         # This should pull all 250 docs from mongo and populate the result | ||||||
|  |         # cache | ||||||
|  |         len(records) | ||||||
|  |  | ||||||
|  |         # Assert that iterating over documents in the qs touches every | ||||||
|  |         # document even if we call len(qs) midway through the iteration. | ||||||
|  |         for i, r in enumerate(records): | ||||||
|  |             if i == 58: | ||||||
|  |                 len(records) | ||||||
|  |         self.assertEqual(i, 249) | ||||||
|  |  | ||||||
|  |         # Assert the same behavior is true even if we didn't pre-populate the | ||||||
|  |         # result cache. | ||||||
|  |         records = Data.objects.limit(250) | ||||||
|  |         for i, r in enumerate(records): | ||||||
|  |             if i == 58: | ||||||
|  |                 len(records) | ||||||
|  |         self.assertEqual(i, 249) | ||||||
|  |  | ||||||
|  |     def test_iteration_within_iteration(self): | ||||||
|  |         """You should be able to reliably iterate over all the documents | ||||||
|  |         in a given queryset even if there are multiple iterations of it | ||||||
|  |         happening at the same time. | ||||||
|  |         """ | ||||||
|  |         class Data(Document): | ||||||
|  |             pass | ||||||
|  |  | ||||||
|  |         for i in xrange(300): | ||||||
|  |             Data().save() | ||||||
|  |  | ||||||
|  |         qs = Data.objects.limit(250) | ||||||
|  |         for i, doc in enumerate(qs): | ||||||
|  |             for j, doc2 in enumerate(qs): | ||||||
|  |                 pass | ||||||
|  |  | ||||||
|  |         self.assertEqual(i, 249) | ||||||
|  |         self.assertEqual(j, 249) | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|     unittest.main() |     unittest.main() | ||||||
|   | |||||||
| @@ -1,11 +1,7 @@ | |||||||
| import sys |  | ||||||
| sys.path[0:0] = [""] |  | ||||||
|  |  | ||||||
| import unittest | import unittest | ||||||
|  |  | ||||||
| from mongoengine import * | from mongoengine import * | ||||||
| from mongoengine.queryset import Q | from mongoengine.queryset import Q, transform | ||||||
| from mongoengine.queryset import transform |  | ||||||
|  |  | ||||||
| __all__ = ("TransformTest",) | __all__ = ("TransformTest",) | ||||||
|  |  | ||||||
| @@ -41,8 +37,8 @@ class TransformTest(unittest.TestCase): | |||||||
|         DicDoc.drop_collection() |         DicDoc.drop_collection() | ||||||
|         Doc.drop_collection() |         Doc.drop_collection() | ||||||
|  |  | ||||||
|  |         DicDoc().save() | ||||||
|         doc = Doc().save() |         doc = Doc().save() | ||||||
|         dic_doc = DicDoc().save() |  | ||||||
|  |  | ||||||
|         for k, v in (("set", "$set"), ("set_on_insert", "$setOnInsert"), ("push", "$push")): |         for k, v in (("set", "$set"), ("set_on_insert", "$setOnInsert"), ("push", "$push")): | ||||||
|             update = transform.update(DicDoc, **{"%s__dictField__test" % k: doc}) |             update = transform.update(DicDoc, **{"%s__dictField__test" % k: doc}) | ||||||
| @@ -55,7 +51,6 @@ class TransformTest(unittest.TestCase): | |||||||
|         update = transform.update(DicDoc, pull__dictField__test=doc) |         update = transform.update(DicDoc, pull__dictField__test=doc) | ||||||
|         self.assertTrue(isinstance(update["$pull"]["dictField"]["test"], dict)) |         self.assertTrue(isinstance(update["$pull"]["dictField"]["test"], dict)) | ||||||
|  |  | ||||||
|  |  | ||||||
|     def test_query_field_name(self): |     def test_query_field_name(self): | ||||||
|         """Ensure that the correct field name is used when querying. |         """Ensure that the correct field name is used when querying. | ||||||
|         """ |         """ | ||||||
| @@ -156,16 +151,23 @@ class TransformTest(unittest.TestCase): | |||||||
|         class Doc(Document): |         class Doc(Document): | ||||||
|             meta = {'allow_inheritance': False} |             meta = {'allow_inheritance': False} | ||||||
|  |  | ||||||
|         raw_query = Doc.objects(__raw__={'deleted': False, |         raw_query = Doc.objects(__raw__={ | ||||||
|  |             'deleted': False, | ||||||
|             'scraped': 'yes', |             'scraped': 'yes', | ||||||
|                                 '$nor': [{'views.extracted': 'no'}, |             '$nor': [ | ||||||
|                                          {'attachments.views.extracted':'no'}] |                 {'views.extracted': 'no'}, | ||||||
|  |                 {'attachments.views.extracted': 'no'} | ||||||
|  |             ] | ||||||
|         })._query |         })._query | ||||||
|  |  | ||||||
|         expected = {'deleted': False, 'scraped': 'yes', |         self.assertEqual(raw_query, { | ||||||
|                     '$nor': [{'views.extracted': 'no'}, |             'deleted': False, | ||||||
|                              {'attachments.views.extracted': 'no'}]} |             'scraped': 'yes', | ||||||
|         self.assertEqual(expected, raw_query) |             '$nor': [ | ||||||
|  |                 {'views.extracted': 'no'}, | ||||||
|  |                 {'attachments.views.extracted': 'no'} | ||||||
|  |             ] | ||||||
|  |         }) | ||||||
|  |  | ||||||
|     def test_geojson_PointField(self): |     def test_geojson_PointField(self): | ||||||
|         class Location(Document): |         class Location(Document): | ||||||
| @@ -224,6 +226,10 @@ class TransformTest(unittest.TestCase): | |||||||
|         self.assertEqual(1, Doc.objects(item__type__="axe").count()) |         self.assertEqual(1, Doc.objects(item__type__="axe").count()) | ||||||
|         self.assertEqual(1, Doc.objects(item__name__="Heroic axe").count()) |         self.assertEqual(1, Doc.objects(item__name__="Heroic axe").count()) | ||||||
|  |  | ||||||
|  |         Doc.objects(id=doc.id).update(set__item__type__='sword') | ||||||
|  |         self.assertEqual(1, Doc.objects(item__type__="sword").count()) | ||||||
|  |         self.assertEqual(0, Doc.objects(item__type__="axe").count()) | ||||||
|  |  | ||||||
|     def test_understandable_error_raised(self): |     def test_understandable_error_raised(self): | ||||||
|         class Event(Document): |         class Event(Document): | ||||||
|             title = StringField() |             title = StringField() | ||||||
| @@ -234,5 +240,6 @@ class TransformTest(unittest.TestCase): | |||||||
|         events = Event.objects(location__within=box) |         events = Event.objects(location__within=box) | ||||||
|         self.assertRaises(InvalidQueryError, lambda: events.count()) |         self.assertRaises(InvalidQueryError, lambda: events.count()) | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|     unittest.main() |     unittest.main() | ||||||
|   | |||||||
| @@ -1,14 +1,12 @@ | |||||||
| import sys | import datetime | ||||||
| sys.path[0:0] = [""] | import re | ||||||
|  |  | ||||||
| import unittest | import unittest | ||||||
|  |  | ||||||
| from bson import ObjectId | from bson import ObjectId | ||||||
| from datetime import datetime |  | ||||||
|  |  | ||||||
| from mongoengine import * | from mongoengine import * | ||||||
| from mongoengine.queryset import Q |  | ||||||
| from mongoengine.errors import InvalidQueryError | from mongoengine.errors import InvalidQueryError | ||||||
|  | from mongoengine.queryset import Q | ||||||
|  |  | ||||||
| __all__ = ("QTest",) | __all__ = ("QTest",) | ||||||
|  |  | ||||||
| @@ -132,12 +130,12 @@ class QTest(unittest.TestCase): | |||||||
|         TestDoc(x=10).save() |         TestDoc(x=10).save() | ||||||
|         TestDoc(y=True).save() |         TestDoc(y=True).save() | ||||||
|  |  | ||||||
|         self.assertEqual(query, |         self.assertEqual(query, { | ||||||
|         {'$and': [ |             '$and': [ | ||||||
|                 {'$or': [{'x': {'$gt': 0}}, {'x': {'$exists': False}}]}, |                 {'$or': [{'x': {'$gt': 0}}, {'x': {'$exists': False}}]}, | ||||||
|                 {'$or': [{'x': {'$lt': 100}}, {'y': True}]} |                 {'$or': [{'x': {'$lt': 100}}, {'y': True}]} | ||||||
|         ]}) |             ] | ||||||
|  |         }) | ||||||
|         self.assertEqual(2, TestDoc.objects(q1 & q2).count()) |         self.assertEqual(2, TestDoc.objects(q1 & q2).count()) | ||||||
|  |  | ||||||
|     def test_or_and_or_combination(self): |     def test_or_and_or_combination(self): | ||||||
| @@ -157,15 +155,14 @@ class QTest(unittest.TestCase): | |||||||
|         q2 = (Q(x__lt=100) & (Q(y=False) | Q(y__exists=False))) |         q2 = (Q(x__lt=100) & (Q(y=False) | Q(y__exists=False))) | ||||||
|         query = (q1 | q2).to_query(TestDoc) |         query = (q1 | q2).to_query(TestDoc) | ||||||
|  |  | ||||||
|         self.assertEqual(query, |         self.assertEqual(query, { | ||||||
|             {'$or': [ |             '$or': [ | ||||||
|                 {'$and': [{'x': {'$gt': 0}}, |                 {'$and': [{'x': {'$gt': 0}}, | ||||||
|                           {'$or': [{'y': True}, {'y': {'$exists': False}}]}]}, |                           {'$or': [{'y': True}, {'y': {'$exists': False}}]}]}, | ||||||
|                 {'$and': [{'x': {'$lt': 100}}, |                 {'$and': [{'x': {'$lt': 100}}, | ||||||
|                           {'$or': [{'y': False}, {'y': {'$exists': False}}]}]} |                           {'$or': [{'y': False}, {'y': {'$exists': False}}]}]} | ||||||
|             ]} |             ] | ||||||
|         ) |         }) | ||||||
|  |  | ||||||
|         self.assertEqual(2, TestDoc.objects(q1 | q2).count()) |         self.assertEqual(2, TestDoc.objects(q1 | q2).count()) | ||||||
|  |  | ||||||
|     def test_multiple_occurence_in_field(self): |     def test_multiple_occurence_in_field(self): | ||||||
| @@ -215,19 +212,19 @@ class QTest(unittest.TestCase): | |||||||
|  |  | ||||||
|         BlogPost.drop_collection() |         BlogPost.drop_collection() | ||||||
|  |  | ||||||
|         post1 = BlogPost(title='Test 1', publish_date=datetime(2010, 1, 8), published=False) |         post1 = BlogPost(title='Test 1', publish_date=datetime.datetime(2010, 1, 8), published=False) | ||||||
|         post1.save() |         post1.save() | ||||||
|  |  | ||||||
|         post2 = BlogPost(title='Test 2', publish_date=datetime(2010, 1, 15), published=True) |         post2 = BlogPost(title='Test 2', publish_date=datetime.datetime(2010, 1, 15), published=True) | ||||||
|         post2.save() |         post2.save() | ||||||
|  |  | ||||||
|         post3 = BlogPost(title='Test 3', published=True) |         post3 = BlogPost(title='Test 3', published=True) | ||||||
|         post3.save() |         post3.save() | ||||||
|  |  | ||||||
|         post4 = BlogPost(title='Test 4', publish_date=datetime(2010, 1, 8)) |         post4 = BlogPost(title='Test 4', publish_date=datetime.datetime(2010, 1, 8)) | ||||||
|         post4.save() |         post4.save() | ||||||
|  |  | ||||||
|         post5 = BlogPost(title='Test 1', publish_date=datetime(2010, 1, 15)) |         post5 = BlogPost(title='Test 1', publish_date=datetime.datetime(2010, 1, 15)) | ||||||
|         post5.save() |         post5.save() | ||||||
|  |  | ||||||
|         post6 = BlogPost(title='Test 1', published=False) |         post6 = BlogPost(title='Test 1', published=False) | ||||||
| @@ -250,7 +247,7 @@ class QTest(unittest.TestCase): | |||||||
|         self.assertTrue(all(obj.id in posts for obj in published_posts)) |         self.assertTrue(all(obj.id in posts for obj in published_posts)) | ||||||
|  |  | ||||||
|         # Check Q object combination |         # Check Q object combination | ||||||
|         date = datetime(2010, 1, 10) |         date = datetime.datetime(2010, 1, 10) | ||||||
|         q = BlogPost.objects(Q(publish_date__lte=date) | Q(published=True)) |         q = BlogPost.objects(Q(publish_date__lte=date) | Q(published=True)) | ||||||
|         posts = [post.id for post in q] |         posts = [post.id for post in q] | ||||||
|  |  | ||||||
| @@ -273,8 +270,10 @@ class QTest(unittest.TestCase): | |||||||
|         # Test invalid query objs |         # Test invalid query objs | ||||||
|         def wrong_query_objs(): |         def wrong_query_objs(): | ||||||
|             self.Person.objects('user1') |             self.Person.objects('user1') | ||||||
|  |  | ||||||
|         def wrong_query_objs_filter(): |         def wrong_query_objs_filter(): | ||||||
|             self.Person.objects('user1') |             self.Person.objects('user1') | ||||||
|  |  | ||||||
|         self.assertRaises(InvalidQueryError, wrong_query_objs) |         self.assertRaises(InvalidQueryError, wrong_query_objs) | ||||||
|         self.assertRaises(InvalidQueryError, wrong_query_objs_filter) |         self.assertRaises(InvalidQueryError, wrong_query_objs_filter) | ||||||
|  |  | ||||||
| @@ -284,7 +283,6 @@ class QTest(unittest.TestCase): | |||||||
|         person = self.Person(name='Guido van Rossum') |         person = self.Person(name='Guido van Rossum') | ||||||
|         person.save() |         person.save() | ||||||
|  |  | ||||||
|         import re |  | ||||||
|         obj = self.Person.objects(Q(name=re.compile('^Gui'))).first() |         obj = self.Person.objects(Q(name=re.compile('^Gui'))).first() | ||||||
|         self.assertEqual(obj, person) |         self.assertEqual(obj, person) | ||||||
|         obj = self.Person.objects(Q(name=re.compile('^gui'))).first() |         obj = self.Person.objects(Q(name=re.compile('^gui'))).first() | ||||||
|   | |||||||
| @@ -8,6 +8,7 @@ try: | |||||||
|     import unittest2 as unittest |     import unittest2 as unittest | ||||||
| except ImportError: | except ImportError: | ||||||
|     import unittest |     import unittest | ||||||
|  | from nose.plugins.skip import SkipTest | ||||||
|  |  | ||||||
| import pymongo | import pymongo | ||||||
| from bson.tz_util import utc | from bson.tz_util import utc | ||||||
| @@ -51,6 +52,76 @@ class ConnectionTest(unittest.TestCase): | |||||||
|         conn = get_connection('testdb') |         conn = get_connection('testdb') | ||||||
|         self.assertTrue(isinstance(conn, pymongo.mongo_client.MongoClient)) |         self.assertTrue(isinstance(conn, pymongo.mongo_client.MongoClient)) | ||||||
|  |  | ||||||
|  |     def test_connect_in_mocking(self): | ||||||
|  |         """Ensure that the connect() method works properly in mocking. | ||||||
|  |         """ | ||||||
|  |         try: | ||||||
|  |             import mongomock | ||||||
|  |         except ImportError: | ||||||
|  |             raise SkipTest('you need mongomock installed to run this testcase') | ||||||
|  |  | ||||||
|  |         connect('mongoenginetest', host='mongomock://localhost') | ||||||
|  |         conn = get_connection() | ||||||
|  |         self.assertTrue(isinstance(conn, mongomock.MongoClient)) | ||||||
|  |  | ||||||
|  |         connect('mongoenginetest2', host='mongomock://localhost', alias='testdb2') | ||||||
|  |         conn = get_connection('testdb2') | ||||||
|  |         self.assertTrue(isinstance(conn, mongomock.MongoClient)) | ||||||
|  |  | ||||||
|  |         connect('mongoenginetest3', host='mongodb://localhost', is_mock=True, alias='testdb3') | ||||||
|  |         conn = get_connection('testdb3') | ||||||
|  |         self.assertTrue(isinstance(conn, mongomock.MongoClient)) | ||||||
|  |  | ||||||
|  |         connect('mongoenginetest4', is_mock=True, alias='testdb4') | ||||||
|  |         conn = get_connection('testdb4') | ||||||
|  |         self.assertTrue(isinstance(conn, mongomock.MongoClient)) | ||||||
|  |  | ||||||
|  |         connect(host='mongodb://localhost:27017/mongoenginetest5', is_mock=True, alias='testdb5') | ||||||
|  |         conn = get_connection('testdb5') | ||||||
|  |         self.assertTrue(isinstance(conn, mongomock.MongoClient)) | ||||||
|  |  | ||||||
|  |         connect(host='mongomock://localhost:27017/mongoenginetest6', alias='testdb6') | ||||||
|  |         conn = get_connection('testdb6') | ||||||
|  |         self.assertTrue(isinstance(conn, mongomock.MongoClient)) | ||||||
|  |  | ||||||
|  |         connect(host='mongomock://localhost:27017/mongoenginetest7', is_mock=True, alias='testdb7') | ||||||
|  |         conn = get_connection('testdb7') | ||||||
|  |         self.assertTrue(isinstance(conn, mongomock.MongoClient)) | ||||||
|  |  | ||||||
|  |     def test_connect_with_host_list(self): | ||||||
|  |         """Ensure that the connect() method works when host is a list | ||||||
|  |  | ||||||
|  |         Uses mongomock to test w/o needing multiple mongod/mongos processes | ||||||
|  |         """ | ||||||
|  |         try: | ||||||
|  |             import mongomock | ||||||
|  |         except ImportError: | ||||||
|  |             raise SkipTest('you need mongomock installed to run this testcase') | ||||||
|  |  | ||||||
|  |         connect(host=['mongomock://localhost']) | ||||||
|  |         conn = get_connection() | ||||||
|  |         self.assertTrue(isinstance(conn, mongomock.MongoClient)) | ||||||
|  |  | ||||||
|  |         connect(host=['mongodb://localhost'], is_mock=True,  alias='testdb2') | ||||||
|  |         conn = get_connection('testdb2') | ||||||
|  |         self.assertTrue(isinstance(conn, mongomock.MongoClient)) | ||||||
|  |  | ||||||
|  |         connect(host=['localhost'], is_mock=True,  alias='testdb3') | ||||||
|  |         conn = get_connection('testdb3') | ||||||
|  |         self.assertTrue(isinstance(conn, mongomock.MongoClient)) | ||||||
|  |  | ||||||
|  |         connect(host=['mongomock://localhost:27017', 'mongomock://localhost:27018'], alias='testdb4') | ||||||
|  |         conn = get_connection('testdb4') | ||||||
|  |         self.assertTrue(isinstance(conn, mongomock.MongoClient)) | ||||||
|  |  | ||||||
|  |         connect(host=['mongodb://localhost:27017', 'mongodb://localhost:27018'], is_mock=True,  alias='testdb5') | ||||||
|  |         conn = get_connection('testdb5') | ||||||
|  |         self.assertTrue(isinstance(conn, mongomock.MongoClient)) | ||||||
|  |  | ||||||
|  |         connect(host=['localhost:27017', 'localhost:27018'], is_mock=True,  alias='testdb6') | ||||||
|  |         conn = get_connection('testdb6') | ||||||
|  |         self.assertTrue(isinstance(conn, mongomock.MongoClient)) | ||||||
|  |  | ||||||
|     def test_disconnect(self): |     def test_disconnect(self): | ||||||
|         """Ensure that the disconnect() method works properly |         """Ensure that the disconnect() method works properly | ||||||
|         """ |         """ | ||||||
| @@ -151,7 +222,7 @@ class ConnectionTest(unittest.TestCase): | |||||||
|             self.assertRaises(ConnectionError, get_db, 'test1') |             self.assertRaises(ConnectionError, get_db, 'test1') | ||||||
|  |  | ||||||
|         # Authentication succeeds with "authSource" |         # Authentication succeeds with "authSource" | ||||||
|         test_conn2 = connect( |         connect( | ||||||
|             'mongoenginetest', alias='test2', |             'mongoenginetest', alias='test2', | ||||||
|             host=('mongodb://username2:password@localhost/' |             host=('mongodb://username2:password@localhost/' | ||||||
|                   'mongoenginetest?authSource=admin') |                   'mongoenginetest?authSource=admin') | ||||||
|   | |||||||
| @@ -1,4 +1,5 @@ | |||||||
| import unittest | import unittest | ||||||
|  |  | ||||||
| from mongoengine.base.datastructures import StrictDict, SemiStrictDict | from mongoengine.base.datastructures import StrictDict, SemiStrictDict | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -13,6 +14,14 @@ class TestStrictDict(unittest.TestCase): | |||||||
|         d = self.dtype(a=1, b=1, c=1) |         d = self.dtype(a=1, b=1, c=1) | ||||||
|         self.assertEqual((d.a, d.b, d.c), (1, 1, 1)) |         self.assertEqual((d.a, d.b, d.c), (1, 1, 1)) | ||||||
|  |  | ||||||
|  |     def test_repr(self): | ||||||
|  |         d = self.dtype(a=1, b=2, c=3) | ||||||
|  |         self.assertEqual(repr(d), '{"a": 1, "b": 2, "c": 3}') | ||||||
|  |  | ||||||
|  |         # make sure quotes are escaped properly | ||||||
|  |         d = self.dtype(a='"', b="'", c="") | ||||||
|  |         self.assertEqual(repr(d), '{"a": \'"\', "b": "\'", "c": \'\'}') | ||||||
|  |  | ||||||
|     def test_init_fails_on_nonexisting_attrs(self): |     def test_init_fails_on_nonexisting_attrs(self): | ||||||
|         self.assertRaises(AttributeError, lambda: self.dtype(a=1, b=2, d=3)) |         self.assertRaises(AttributeError, lambda: self.dtype(a=1, b=2, d=3)) | ||||||
|  |  | ||||||
|   | |||||||
| @@ -12,9 +12,13 @@ from mongoengine.context_managers import query_counter | |||||||
|  |  | ||||||
| class FieldTest(unittest.TestCase): | class FieldTest(unittest.TestCase): | ||||||
|  |  | ||||||
|     def setUp(self): |     @classmethod | ||||||
|         connect(db='mongoenginetest') |     def setUpClass(cls): | ||||||
|         self.db = get_db() |         cls.db = connect(db='mongoenginetest') | ||||||
|  |  | ||||||
|  |     @classmethod | ||||||
|  |     def tearDownClass(cls): | ||||||
|  |         cls.db.drop_database('mongoenginetest') | ||||||
|  |  | ||||||
|     def test_list_item_dereference(self): |     def test_list_item_dereference(self): | ||||||
|         """Ensure that DBRef items in ListFields are dereferenced. |         """Ensure that DBRef items in ListFields are dereferenced. | ||||||
| @@ -304,6 +308,7 @@ class FieldTest(unittest.TestCase): | |||||||
|  |  | ||||||
|         User.drop_collection() |         User.drop_collection() | ||||||
|         Post.drop_collection() |         Post.drop_collection() | ||||||
|  |         SimpleList.drop_collection() | ||||||
|  |  | ||||||
|         u1 = User.objects.create(name='u1') |         u1 = User.objects.create(name='u1') | ||||||
|         u2 = User.objects.create(name='u2') |         u2 = User.objects.create(name='u2') | ||||||
|   | |||||||
| @@ -25,6 +25,8 @@ class SignalTests(unittest.TestCase): | |||||||
|         connect(db='mongoenginetest') |         connect(db='mongoenginetest') | ||||||
|  |  | ||||||
|         class Author(Document): |         class Author(Document): | ||||||
|  |             # Make the id deterministic for easier testing | ||||||
|  |             id = SequenceField(primary_key=True) | ||||||
|             name = StringField() |             name = StringField() | ||||||
|  |  | ||||||
|             def __unicode__(self): |             def __unicode__(self): | ||||||
| @@ -33,7 +35,7 @@ class SignalTests(unittest.TestCase): | |||||||
|             @classmethod |             @classmethod | ||||||
|             def pre_init(cls, sender, document, *args, **kwargs): |             def pre_init(cls, sender, document, *args, **kwargs): | ||||||
|                 signal_output.append('pre_init signal, %s' % cls.__name__) |                 signal_output.append('pre_init signal, %s' % cls.__name__) | ||||||
|                 signal_output.append(str(kwargs['values'])) |                 signal_output.append(kwargs['values']) | ||||||
|  |  | ||||||
|             @classmethod |             @classmethod | ||||||
|             def post_init(cls, sender, document, **kwargs): |             def post_init(cls, sender, document, **kwargs): | ||||||
| @@ -43,48 +45,55 @@ class SignalTests(unittest.TestCase): | |||||||
|             @classmethod |             @classmethod | ||||||
|             def pre_save(cls, sender, document, **kwargs): |             def pre_save(cls, sender, document, **kwargs): | ||||||
|                 signal_output.append('pre_save signal, %s' % document) |                 signal_output.append('pre_save signal, %s' % document) | ||||||
|  |                 signal_output.append(kwargs) | ||||||
|  |  | ||||||
|             @classmethod |             @classmethod | ||||||
|             def pre_save_post_validation(cls, sender, document, **kwargs): |             def pre_save_post_validation(cls, sender, document, **kwargs): | ||||||
|                 signal_output.append('pre_save_post_validation signal, %s' % document) |                 signal_output.append('pre_save_post_validation signal, %s' % document) | ||||||
|                 if 'created' in kwargs: |                 if kwargs.pop('created', False): | ||||||
|                     if kwargs['created']: |  | ||||||
|                     signal_output.append('Is created') |                     signal_output.append('Is created') | ||||||
|                 else: |                 else: | ||||||
|                     signal_output.append('Is updated') |                     signal_output.append('Is updated') | ||||||
|  |                 signal_output.append(kwargs) | ||||||
|  |  | ||||||
|             @classmethod |             @classmethod | ||||||
|             def post_save(cls, sender, document, **kwargs): |             def post_save(cls, sender, document, **kwargs): | ||||||
|                 dirty_keys = document._delta()[0].keys() + document._delta()[1].keys() |                 dirty_keys = document._delta()[0].keys() + document._delta()[1].keys() | ||||||
|                 signal_output.append('post_save signal, %s' % document) |                 signal_output.append('post_save signal, %s' % document) | ||||||
|                 signal_output.append('post_save dirty keys, %s' % dirty_keys) |                 signal_output.append('post_save dirty keys, %s' % dirty_keys) | ||||||
|                 if 'created' in kwargs: |                 if kwargs.pop('created', False): | ||||||
|                     if kwargs['created']: |  | ||||||
|                     signal_output.append('Is created') |                     signal_output.append('Is created') | ||||||
|                 else: |                 else: | ||||||
|                     signal_output.append('Is updated') |                     signal_output.append('Is updated') | ||||||
|  |                 signal_output.append(kwargs) | ||||||
|  |  | ||||||
|             @classmethod |             @classmethod | ||||||
|             def pre_delete(cls, sender, document, **kwargs): |             def pre_delete(cls, sender, document, **kwargs): | ||||||
|                 signal_output.append('pre_delete signal, %s' % document) |                 signal_output.append('pre_delete signal, %s' % document) | ||||||
|  |                 signal_output.append(kwargs) | ||||||
|  |  | ||||||
|             @classmethod |             @classmethod | ||||||
|             def post_delete(cls, sender, document, **kwargs): |             def post_delete(cls, sender, document, **kwargs): | ||||||
|                 signal_output.append('post_delete signal, %s' % document) |                 signal_output.append('post_delete signal, %s' % document) | ||||||
|  |                 signal_output.append(kwargs) | ||||||
|  |  | ||||||
|             @classmethod |             @classmethod | ||||||
|             def pre_bulk_insert(cls, sender, documents, **kwargs): |             def pre_bulk_insert(cls, sender, documents, **kwargs): | ||||||
|                 signal_output.append('pre_bulk_insert signal, %s' % documents) |                 signal_output.append('pre_bulk_insert signal, %s' % documents) | ||||||
|  |                 signal_output.append(kwargs) | ||||||
|  |  | ||||||
|             @classmethod |             @classmethod | ||||||
|             def post_bulk_insert(cls, sender, documents, **kwargs): |             def post_bulk_insert(cls, sender, documents, **kwargs): | ||||||
|                 signal_output.append('post_bulk_insert signal, %s' % documents) |                 signal_output.append('post_bulk_insert signal, %s' % documents) | ||||||
|                 if kwargs.get('loaded', False): |                 if kwargs.pop('loaded', False): | ||||||
|                     signal_output.append('Is loaded') |                     signal_output.append('Is loaded') | ||||||
|                 else: |                 else: | ||||||
|                     signal_output.append('Not loaded') |                     signal_output.append('Not loaded') | ||||||
|  |                 signal_output.append(kwargs) | ||||||
|  |  | ||||||
|         self.Author = Author |         self.Author = Author | ||||||
|         Author.drop_collection() |         Author.drop_collection() | ||||||
|  |         Author.id.set_next_value(0) | ||||||
|  |  | ||||||
|         class Another(Document): |         class Another(Document): | ||||||
|  |  | ||||||
| @@ -96,10 +105,12 @@ class SignalTests(unittest.TestCase): | |||||||
|             @classmethod |             @classmethod | ||||||
|             def pre_delete(cls, sender, document, **kwargs): |             def pre_delete(cls, sender, document, **kwargs): | ||||||
|                 signal_output.append('pre_delete signal, %s' % document) |                 signal_output.append('pre_delete signal, %s' % document) | ||||||
|  |                 signal_output.append(kwargs) | ||||||
|  |  | ||||||
|             @classmethod |             @classmethod | ||||||
|             def post_delete(cls, sender, document, **kwargs): |             def post_delete(cls, sender, document, **kwargs): | ||||||
|                 signal_output.append('post_delete signal, %s' % document) |                 signal_output.append('post_delete signal, %s' % document) | ||||||
|  |                 signal_output.append(kwargs) | ||||||
|  |  | ||||||
|         self.Another = Another |         self.Another = Another | ||||||
|         Another.drop_collection() |         Another.drop_collection() | ||||||
| @@ -118,6 +129,41 @@ class SignalTests(unittest.TestCase): | |||||||
|         self.ExplicitId = ExplicitId |         self.ExplicitId = ExplicitId | ||||||
|         ExplicitId.drop_collection() |         ExplicitId.drop_collection() | ||||||
|  |  | ||||||
|  |         class Post(Document): | ||||||
|  |             title = StringField() | ||||||
|  |             content = StringField() | ||||||
|  |             active = BooleanField(default=False) | ||||||
|  |  | ||||||
|  |             def __unicode__(self): | ||||||
|  |                 return self.title | ||||||
|  |  | ||||||
|  |             @classmethod | ||||||
|  |             def pre_bulk_insert(cls, sender, documents, **kwargs): | ||||||
|  |                 signal_output.append('pre_bulk_insert signal, %s' % | ||||||
|  |                                      [(doc, {'active': documents[n].active}) | ||||||
|  |                                       for n, doc in enumerate(documents)]) | ||||||
|  |  | ||||||
|  |                 # make changes here, this is just an example - | ||||||
|  |                 # it could be anything that needs pre-validation or looks-ups before bulk bulk inserting | ||||||
|  |                 for document in documents: | ||||||
|  |                     if not document.active: | ||||||
|  |                         document.active = True | ||||||
|  |                 signal_output.append(kwargs) | ||||||
|  |  | ||||||
|  |             @classmethod | ||||||
|  |             def post_bulk_insert(cls, sender, documents, **kwargs): | ||||||
|  |                 signal_output.append('post_bulk_insert signal, %s' % | ||||||
|  |                                      [(doc, {'active': documents[n].active}) | ||||||
|  |                                       for n, doc in enumerate(documents)]) | ||||||
|  |                 if kwargs.pop('loaded', False): | ||||||
|  |                     signal_output.append('Is loaded') | ||||||
|  |                 else: | ||||||
|  |                     signal_output.append('Not loaded') | ||||||
|  |                 signal_output.append(kwargs) | ||||||
|  |  | ||||||
|  |         self.Post = Post | ||||||
|  |         Post.drop_collection() | ||||||
|  |  | ||||||
|         # Save up the number of connected signals so that we can check at the |         # Save up the number of connected signals so that we can check at the | ||||||
|         # end that all the signals we register get properly unregistered |         # end that all the signals we register get properly unregistered | ||||||
|         self.pre_signals = ( |         self.pre_signals = ( | ||||||
| @@ -147,6 +193,9 @@ class SignalTests(unittest.TestCase): | |||||||
|  |  | ||||||
|         signals.post_save.connect(ExplicitId.post_save, sender=ExplicitId) |         signals.post_save.connect(ExplicitId.post_save, sender=ExplicitId) | ||||||
|  |  | ||||||
|  |         signals.pre_bulk_insert.connect(Post.pre_bulk_insert, sender=Post) | ||||||
|  |         signals.post_bulk_insert.connect(Post.post_bulk_insert, sender=Post) | ||||||
|  |  | ||||||
|     def tearDown(self): |     def tearDown(self): | ||||||
|         signals.pre_init.disconnect(self.Author.pre_init) |         signals.pre_init.disconnect(self.Author.pre_init) | ||||||
|         signals.post_init.disconnect(self.Author.post_init) |         signals.post_init.disconnect(self.Author.post_init) | ||||||
| @@ -163,6 +212,9 @@ class SignalTests(unittest.TestCase): | |||||||
|  |  | ||||||
|         signals.post_save.disconnect(self.ExplicitId.post_save) |         signals.post_save.disconnect(self.ExplicitId.post_save) | ||||||
|  |  | ||||||
|  |         signals.pre_bulk_insert.disconnect(self.Post.pre_bulk_insert) | ||||||
|  |         signals.post_bulk_insert.disconnect(self.Post.post_bulk_insert) | ||||||
|  |  | ||||||
|         # Check that all our signals got disconnected properly. |         # Check that all our signals got disconnected properly. | ||||||
|         post_signals = ( |         post_signals = ( | ||||||
|             len(signals.pre_init.receivers), |             len(signals.pre_init.receivers), | ||||||
| @@ -202,63 +254,118 @@ class SignalTests(unittest.TestCase): | |||||||
|  |  | ||||||
|         self.assertEqual(self.get_signal_output(create_author), [ |         self.assertEqual(self.get_signal_output(create_author), [ | ||||||
|             "pre_init signal, Author", |             "pre_init signal, Author", | ||||||
|             "{'name': 'Bill Shakespeare'}", |             {'name': 'Bill Shakespeare'}, | ||||||
|             "post_init signal, Bill Shakespeare, document._created = True", |             "post_init signal, Bill Shakespeare, document._created = True", | ||||||
|         ]) |         ]) | ||||||
|  |  | ||||||
|         a1 = self.Author(name='Bill Shakespeare') |         a1 = self.Author(name='Bill Shakespeare') | ||||||
|         self.assertEqual(self.get_signal_output(a1.save), [ |         self.assertEqual(self.get_signal_output(a1.save), [ | ||||||
|             "pre_save signal, Bill Shakespeare", |             "pre_save signal, Bill Shakespeare", | ||||||
|  |             {}, | ||||||
|             "pre_save_post_validation signal, Bill Shakespeare", |             "pre_save_post_validation signal, Bill Shakespeare", | ||||||
|             "Is created", |             "Is created", | ||||||
|  |             {}, | ||||||
|             "post_save signal, Bill Shakespeare", |             "post_save signal, Bill Shakespeare", | ||||||
|             "post_save dirty keys, ['name']", |             "post_save dirty keys, ['name']", | ||||||
|             "Is created" |             "Is created", | ||||||
|  |             {} | ||||||
|         ]) |         ]) | ||||||
|  |  | ||||||
|         a1.reload() |         a1.reload() | ||||||
|         a1.name = 'William Shakespeare' |         a1.name = 'William Shakespeare' | ||||||
|         self.assertEqual(self.get_signal_output(a1.save), [ |         self.assertEqual(self.get_signal_output(a1.save), [ | ||||||
|             "pre_save signal, William Shakespeare", |             "pre_save signal, William Shakespeare", | ||||||
|  |             {}, | ||||||
|             "pre_save_post_validation signal, William Shakespeare", |             "pre_save_post_validation signal, William Shakespeare", | ||||||
|             "Is updated", |             "Is updated", | ||||||
|  |             {}, | ||||||
|             "post_save signal, William Shakespeare", |             "post_save signal, William Shakespeare", | ||||||
|             "post_save dirty keys, ['name']", |             "post_save dirty keys, ['name']", | ||||||
|             "Is updated" |             "Is updated", | ||||||
|  |             {} | ||||||
|         ]) |         ]) | ||||||
|  |  | ||||||
|         self.assertEqual(self.get_signal_output(a1.delete), [ |         self.assertEqual(self.get_signal_output(a1.delete), [ | ||||||
|             'pre_delete signal, William Shakespeare', |             'pre_delete signal, William Shakespeare', | ||||||
|  |             {}, | ||||||
|             'post_delete signal, William Shakespeare', |             'post_delete signal, William Shakespeare', | ||||||
|  |             {} | ||||||
|         ]) |         ]) | ||||||
|  |  | ||||||
|         signal_output = self.get_signal_output(load_existing_author) |         self.assertEqual(self.get_signal_output(load_existing_author), [ | ||||||
|         # test signal_output lines separately, because of random ObjectID after object load |  | ||||||
|         self.assertEqual(signal_output[0], |  | ||||||
|             "pre_init signal, Author", |             "pre_init signal, Author", | ||||||
|         ) |             {'id': 2, 'name': 'Bill Shakespeare'}, | ||||||
|         self.assertEqual(signal_output[2], |             "post_init signal, Bill Shakespeare, document._created = False" | ||||||
|             "post_init signal, Bill Shakespeare, document._created = False", |         ]) | ||||||
|         ) |  | ||||||
|  |  | ||||||
|  |         self.assertEqual(self.get_signal_output(bulk_create_author_with_load), [ | ||||||
|         signal_output = self.get_signal_output(bulk_create_author_with_load) |             'pre_init signal, Author', | ||||||
|  |             {'name': 'Bill Shakespeare'}, | ||||||
|         # The output of this signal is not entirely deterministic. The reloaded |             'post_init signal, Bill Shakespeare, document._created = True', | ||||||
|         # object will have an object ID. Hence, we only check part of the output |             'pre_bulk_insert signal, [<Author: Bill Shakespeare>]', | ||||||
|         self.assertEqual(signal_output[3], "pre_bulk_insert signal, [<Author: Bill Shakespeare>]" |             {}, | ||||||
|         ) |             'pre_init signal, Author', | ||||||
|         self.assertEqual(signal_output[-2:], |             {'id': 3, 'name': 'Bill Shakespeare'}, | ||||||
|             ["post_bulk_insert signal, [<Author: Bill Shakespeare>]", |             'post_init signal, Bill Shakespeare, document._created = False', | ||||||
|              "Is loaded",]) |             'post_bulk_insert signal, [<Author: Bill Shakespeare>]', | ||||||
|  |             'Is loaded', | ||||||
|  |             {} | ||||||
|  |         ]) | ||||||
|  |  | ||||||
|         self.assertEqual(self.get_signal_output(bulk_create_author_without_load), [ |         self.assertEqual(self.get_signal_output(bulk_create_author_without_load), [ | ||||||
|             "pre_init signal, Author", |             "pre_init signal, Author", | ||||||
|             "{'name': 'Bill Shakespeare'}", |             {'name': 'Bill Shakespeare'}, | ||||||
|             "post_init signal, Bill Shakespeare, document._created = True", |             "post_init signal, Bill Shakespeare, document._created = True", | ||||||
|             "pre_bulk_insert signal, [<Author: Bill Shakespeare>]", |             "pre_bulk_insert signal, [<Author: Bill Shakespeare>]", | ||||||
|  |             {}, | ||||||
|             "post_bulk_insert signal, [<Author: Bill Shakespeare>]", |             "post_bulk_insert signal, [<Author: Bill Shakespeare>]", | ||||||
|             "Not loaded", |             "Not loaded", | ||||||
|  |             {} | ||||||
|  |         ]) | ||||||
|  |  | ||||||
|  |     def test_signal_kwargs(self): | ||||||
|  |         """ Make sure signal_kwargs is passed to signals calls. """ | ||||||
|  |  | ||||||
|  |         def live_and_let_die(): | ||||||
|  |             a = self.Author(name='Bill Shakespeare') | ||||||
|  |             a.save(signal_kwargs={'live': True, 'die': False}) | ||||||
|  |             a.delete(signal_kwargs={'live': False, 'die': True}) | ||||||
|  |  | ||||||
|  |         self.assertEqual(self.get_signal_output(live_and_let_die), [ | ||||||
|  |             "pre_init signal, Author", | ||||||
|  |             {'name': 'Bill Shakespeare'}, | ||||||
|  |             "post_init signal, Bill Shakespeare, document._created = True", | ||||||
|  |             "pre_save signal, Bill Shakespeare", | ||||||
|  |             {'die': False, 'live': True}, | ||||||
|  |             "pre_save_post_validation signal, Bill Shakespeare", | ||||||
|  |             "Is created", | ||||||
|  |             {'die': False, 'live': True}, | ||||||
|  |             "post_save signal, Bill Shakespeare", | ||||||
|  |             "post_save dirty keys, ['name']", | ||||||
|  |             "Is created", | ||||||
|  |             {'die': False, 'live': True}, | ||||||
|  |             'pre_delete signal, Bill Shakespeare', | ||||||
|  |             {'die': True, 'live': False}, | ||||||
|  |             'post_delete signal, Bill Shakespeare', | ||||||
|  |             {'die': True, 'live': False} | ||||||
|  |         ]) | ||||||
|  |  | ||||||
|  |         def bulk_create_author(): | ||||||
|  |             a1 = self.Author(name='Bill Shakespeare') | ||||||
|  |             self.Author.objects.insert([a1], signal_kwargs={'key': True}) | ||||||
|  |  | ||||||
|  |         self.assertEqual(self.get_signal_output(bulk_create_author), [ | ||||||
|  |             'pre_init signal, Author', | ||||||
|  |             {'name': 'Bill Shakespeare'}, | ||||||
|  |             'post_init signal, Bill Shakespeare, document._created = True', | ||||||
|  |             'pre_bulk_insert signal, [<Author: Bill Shakespeare>]', | ||||||
|  |             {'key': True}, | ||||||
|  |             'pre_init signal, Author', | ||||||
|  |             {'id': 2, 'name': 'Bill Shakespeare'}, | ||||||
|  |             'post_init signal, Bill Shakespeare, document._created = False', | ||||||
|  |             'post_bulk_insert signal, [<Author: Bill Shakespeare>]', | ||||||
|  |             'Is loaded', | ||||||
|  |             {'key': True} | ||||||
|         ]) |         ]) | ||||||
|  |  | ||||||
|     def test_queryset_delete_signals(self): |     def test_queryset_delete_signals(self): | ||||||
| @@ -267,7 +374,9 @@ class SignalTests(unittest.TestCase): | |||||||
|         self.Another(name='Bill Shakespeare').save() |         self.Another(name='Bill Shakespeare').save() | ||||||
|         self.assertEqual(self.get_signal_output(self.Another.objects.delete), [ |         self.assertEqual(self.get_signal_output(self.Another.objects.delete), [ | ||||||
|             'pre_delete signal, Bill Shakespeare', |             'pre_delete signal, Bill Shakespeare', | ||||||
|  |             {}, | ||||||
|             'post_delete signal, Bill Shakespeare', |             'post_delete signal, Bill Shakespeare', | ||||||
|  |             {} | ||||||
|         ]) |         ]) | ||||||
|  |  | ||||||
|     def test_signals_with_explicit_doc_ids(self): |     def test_signals_with_explicit_doc_ids(self): | ||||||
| @@ -306,6 +415,23 @@ class SignalTests(unittest.TestCase): | |||||||
|         ei.switch_db("testdb-1", keep_created=False) |         ei.switch_db("testdb-1", keep_created=False) | ||||||
|         self.assertEqual(self.get_signal_output(ei.save), ['Is created']) |         self.assertEqual(self.get_signal_output(ei.save), ['Is created']) | ||||||
|  |  | ||||||
|  |     def test_signals_bulk_insert(self): | ||||||
|  |         def bulk_set_active_post(): | ||||||
|  |             posts = [ | ||||||
|  |                 self.Post(title='Post 1'), | ||||||
|  |                 self.Post(title='Post 2'), | ||||||
|  |                 self.Post(title='Post 3') | ||||||
|  |             ] | ||||||
|  |             self.Post.objects.insert(posts) | ||||||
|  |  | ||||||
|  |         results = self.get_signal_output(bulk_set_active_post) | ||||||
|  |         self.assertEqual(results, [ | ||||||
|  |             "pre_bulk_insert signal, [(<Post: Post 1>, {'active': False}), (<Post: Post 2>, {'active': False}), (<Post: Post 3>, {'active': False})]", | ||||||
|  |             {}, | ||||||
|  |             "post_bulk_insert signal, [(<Post: Post 1>, {'active': True}), (<Post: Post 2>, {'active': True}), (<Post: Post 3>, {'active': True})]", | ||||||
|  |             'Is loaded', | ||||||
|  |             {} | ||||||
|  |         ]) | ||||||
|  |  | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|     unittest.main() |     unittest.main() | ||||||
|   | |||||||
							
								
								
									
										14
									
								
								tox.ini
									
									
									
									
									
								
							
							
						
						
									
										14
									
								
								tox.ini
									
									
									
									
									
								
							| @@ -1,14 +1,22 @@ | |||||||
| [tox] | [tox] | ||||||
| envlist = {py26,py27,py32,py33,py34,py35,pypy,pypy3}-{mg27,mg28} | envlist = {py26,py27,py33,py34,py35,pypy,pypy3}-{mg27,mg28},flake8 | ||||||
| #envlist = {py26,py27,py32,py33,py34,pypy,pypy3}-{mg27,mg28,mg30,mgdev} |  | ||||||
|  |  | ||||||
| [testenv] | [testenv] | ||||||
| commands = | commands = | ||||||
|     python setup.py nosetests {posargs} |     python setup.py nosetests {posargs} | ||||||
| deps = | deps = | ||||||
|     nose |     nose | ||||||
|     rednose |  | ||||||
|     mg27: PyMongo<2.8 |     mg27: PyMongo<2.8 | ||||||
|     mg28: PyMongo>=2.8,<3.0 |     mg28: PyMongo>=2.8,<3.0 | ||||||
|     mg30: PyMongo>=3.0 |     mg30: PyMongo>=3.0 | ||||||
|     mgdev: https://github.com/mongodb/mongo-python-driver/tarball/master |     mgdev: https://github.com/mongodb/mongo-python-driver/tarball/master | ||||||
|  | setenv = | ||||||
|  |     PYTHON_EGG_CACHE = {envdir}/python-eggs | ||||||
|  | passenv = windir | ||||||
|  |  | ||||||
|  | [testenv:flake8] | ||||||
|  | deps = | ||||||
|  |     flake8 | ||||||
|  |     flake8-import-order | ||||||
|  | commands = | ||||||
|  |    flake8 | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user