Compare commits
	
		
			1 Commits
		
	
	
		
			cant-save-
			...
			no-conflic
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|  | df12211c25 | 
							
								
								
									
										4
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										4
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -14,6 +14,4 @@ env/ | |||||||
| .project | .project | ||||||
| .pydevproject | .pydevproject | ||||||
| tests/test_bugfix.py | tests/test_bugfix.py | ||||||
| htmlcov/ | htmlcov/ | ||||||
| venv |  | ||||||
| venv3 |  | ||||||
| @@ -1,22 +0,0 @@ | |||||||
| pylint: |  | ||||||
|     disable: |  | ||||||
|         # We use this a lot (e.g. via document._meta) |  | ||||||
|         - protected-access |  | ||||||
|  |  | ||||||
|     options: |  | ||||||
|         additional-builtins: |  | ||||||
|             # add xrange and long as valid built-ins. In Python 3, xrange is |  | ||||||
|             # translated into range and long is translated into int via 2to3 (see |  | ||||||
|             # "use_2to3" in setup.py). This should be removed when we drop Python |  | ||||||
|             # 2 support (which probably won't happen any time soon). |  | ||||||
|             - xrange |  | ||||||
|             - long |  | ||||||
|  |  | ||||||
| pyflakes: |  | ||||||
|     disable: |  | ||||||
|         # undefined variables are already covered by pylint (and exclude |  | ||||||
|         # xrange & long) |  | ||||||
|         - F821 |  | ||||||
|  |  | ||||||
| ignore-paths: |  | ||||||
|     - benchmark.py |  | ||||||
| @@ -1,6 +1,7 @@ | |||||||
| language: python | language: python | ||||||
|  |  | ||||||
| python: | python: | ||||||
|  | - '2.6' | ||||||
| - '2.7' | - '2.7' | ||||||
| - '3.3' | - '3.3' | ||||||
| - '3.4' | - '3.4' | ||||||
| @@ -42,11 +43,7 @@ before_script: | |||||||
| script: | script: | ||||||
| - tox -e $(echo py$TRAVIS_PYTHON_VERSION-mg$PYMONGO | tr -d . | sed -e 's/pypypy/pypy/') -- --with-coverage | - tox -e $(echo py$TRAVIS_PYTHON_VERSION-mg$PYMONGO | tr -d . | sed -e 's/pypypy/pypy/') -- --with-coverage | ||||||
|  |  | ||||||
| # For now only submit coveralls for Python v2.7. Python v3.x currently shows | after_script: coveralls --verbose | ||||||
| # 0% coverage. That's caused by 'use_2to3', which builds the py3-compatible |  | ||||||
| # code in a separate dir and runs tests on that. |  | ||||||
| after_script: |  | ||||||
| - if [[ $TRAVIS_PYTHON_VERSION == '2.7' ]]; then coveralls --verbose; fi |  | ||||||
|  |  | ||||||
| notifications: | notifications: | ||||||
|   irc: irc.freenode.org#mongoengine |   irc: irc.freenode.org#mongoengine | ||||||
|   | |||||||
							
								
								
									
										1
									
								
								AUTHORS
									
									
									
									
									
								
							
							
						
						
									
										1
									
								
								AUTHORS
									
									
									
									
									
								
							| @@ -242,4 +242,3 @@ that much better: | |||||||
|  * xiaost7 (https://github.com/xiaost7) |  * xiaost7 (https://github.com/xiaost7) | ||||||
|  * Victor Varvaryuk |  * Victor Varvaryuk | ||||||
|  * Stanislav Kaledin (https://github.com/sallyruthstruik) |  * Stanislav Kaledin (https://github.com/sallyruthstruik) | ||||||
|  * Dmitry Yantsen (https://github.com/mrTable) |  | ||||||
|   | |||||||
| @@ -20,7 +20,7 @@ post to the `user group <http://groups.google.com/group/mongoengine-users>` | |||||||
| Supported Interpreters | Supported Interpreters | ||||||
| ---------------------- | ---------------------- | ||||||
|  |  | ||||||
| MongoEngine supports CPython 2.7 and newer. Language | MongoEngine supports CPython 2.6 and newer. Language | ||||||
| features not supported by all interpreters can not be used. | features not supported by all interpreters can not be used. | ||||||
| Please also ensure that your code is properly converted by | Please also ensure that your code is properly converted by | ||||||
| `2to3 <http://docs.python.org/library/2to3.html>`_ for Python 3 support. | `2to3 <http://docs.python.org/library/2to3.html>`_ for Python 3 support. | ||||||
|   | |||||||
| @@ -4,7 +4,7 @@ MongoEngine | |||||||
| :Info: MongoEngine is an ORM-like layer on top of PyMongo. | :Info: MongoEngine is an ORM-like layer on top of PyMongo. | ||||||
| :Repository: https://github.com/MongoEngine/mongoengine | :Repository: https://github.com/MongoEngine/mongoengine | ||||||
| :Author: Harry Marr (http://github.com/hmarr) | :Author: Harry Marr (http://github.com/hmarr) | ||||||
| :Maintainer: Stefan Wójcik (http://github.com/wojcikstefan) | :Maintainer: Ross Lawley (http://github.com/rozza) | ||||||
|  |  | ||||||
| .. image:: https://travis-ci.org/MongoEngine/mongoengine.svg?branch=master | .. image:: https://travis-ci.org/MongoEngine/mongoengine.svg?branch=master | ||||||
|   :target: https://travis-ci.org/MongoEngine/mongoengine |   :target: https://travis-ci.org/MongoEngine/mongoengine | ||||||
|   | |||||||
							
								
								
									
										152
									
								
								benchmark.py
									
									
									
									
									
								
							
							
						
						
									
										152
									
								
								benchmark.py
									
									
									
									
									
								
							| @@ -1,41 +1,118 @@ | |||||||
| #!/usr/bin/env python | #!/usr/bin/env python | ||||||
|  |  | ||||||
| """ |  | ||||||
| Simple benchmark comparing PyMongo and MongoEngine. |  | ||||||
|  |  | ||||||
| Sample run on a mid 2015 MacBook Pro (commit b282511): |  | ||||||
|  |  | ||||||
| Benchmarking... |  | ||||||
| ---------------------------------------------------------------------------------------------------- |  | ||||||
| Creating 10000 dictionaries - Pymongo |  | ||||||
| 2.58979988098 |  | ||||||
| ---------------------------------------------------------------------------------------------------- |  | ||||||
| Creating 10000 dictionaries - Pymongo write_concern={"w": 0} |  | ||||||
| 1.26657605171 |  | ||||||
| ---------------------------------------------------------------------------------------------------- |  | ||||||
| Creating 10000 dictionaries - MongoEngine |  | ||||||
| 8.4351580143 |  | ||||||
| ---------------------------------------------------------------------------------------------------- |  | ||||||
| Creating 10000 dictionaries without continual assign - MongoEngine |  | ||||||
| 7.20191693306 |  | ||||||
| ---------------------------------------------------------------------------------------------------- |  | ||||||
| Creating 10000 dictionaries - MongoEngine - write_concern={"w": 0}, cascade = True |  | ||||||
| 6.31104588509 |  | ||||||
| ---------------------------------------------------------------------------------------------------- |  | ||||||
| Creating 10000 dictionaries - MongoEngine, write_concern={"w": 0}, validate=False, cascade=True |  | ||||||
| 6.07083487511 |  | ||||||
| ---------------------------------------------------------------------------------------------------- |  | ||||||
| Creating 10000 dictionaries - MongoEngine, write_concern={"w": 0}, validate=False |  | ||||||
| 5.97704291344 |  | ||||||
| ---------------------------------------------------------------------------------------------------- |  | ||||||
| Creating 10000 dictionaries - MongoEngine, force_insert=True, write_concern={"w": 0}, validate=False |  | ||||||
| 5.9111430645 |  | ||||||
| """ |  | ||||||
|  |  | ||||||
| import timeit | import timeit | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def cprofile_main(): | ||||||
|  |     from pymongo import Connection | ||||||
|  |     connection = Connection() | ||||||
|  |     connection.drop_database('timeit_test') | ||||||
|  |     connection.disconnect() | ||||||
|  |  | ||||||
|  |     from mongoengine import Document, DictField, connect | ||||||
|  |     connect("timeit_test") | ||||||
|  |  | ||||||
|  |     class Noddy(Document): | ||||||
|  |         fields = DictField() | ||||||
|  |  | ||||||
|  |     for i in range(1): | ||||||
|  |         noddy = Noddy() | ||||||
|  |         for j in range(20): | ||||||
|  |             noddy.fields["key" + str(j)] = "value " + str(j) | ||||||
|  |         noddy.save() | ||||||
|  |  | ||||||
|  |  | ||||||
| def main(): | def main(): | ||||||
|  |     """ | ||||||
|  |     0.4 Performance Figures ... | ||||||
|  |  | ||||||
|  |     ---------------------------------------------------------------------------------------------------- | ||||||
|  |     Creating 10000 dictionaries - Pymongo | ||||||
|  |     3.86744189262 | ||||||
|  |     ---------------------------------------------------------------------------------------------------- | ||||||
|  |     Creating 10000 dictionaries - MongoEngine | ||||||
|  |     6.23374891281 | ||||||
|  |     ---------------------------------------------------------------------------------------------------- | ||||||
|  |     Creating 10000 dictionaries - MongoEngine, safe=False, validate=False | ||||||
|  |     5.33027005196 | ||||||
|  |     ---------------------------------------------------------------------------------------------------- | ||||||
|  |     Creating 10000 dictionaries - MongoEngine, safe=False, validate=False, cascade=False | ||||||
|  |     pass - No Cascade | ||||||
|  |  | ||||||
|  |     0.5.X | ||||||
|  |     ---------------------------------------------------------------------------------------------------- | ||||||
|  |     Creating 10000 dictionaries - Pymongo | ||||||
|  |     3.89597702026 | ||||||
|  |     ---------------------------------------------------------------------------------------------------- | ||||||
|  |     Creating 10000 dictionaries - MongoEngine | ||||||
|  |     21.7735359669 | ||||||
|  |     ---------------------------------------------------------------------------------------------------- | ||||||
|  |     Creating 10000 dictionaries - MongoEngine, safe=False, validate=False | ||||||
|  |     19.8670389652 | ||||||
|  |     ---------------------------------------------------------------------------------------------------- | ||||||
|  |     Creating 10000 dictionaries - MongoEngine, safe=False, validate=False, cascade=False | ||||||
|  |     pass - No Cascade | ||||||
|  |  | ||||||
|  |     0.6.X | ||||||
|  |     ---------------------------------------------------------------------------------------------------- | ||||||
|  |     Creating 10000 dictionaries - Pymongo | ||||||
|  |     3.81559205055 | ||||||
|  |     ---------------------------------------------------------------------------------------------------- | ||||||
|  |     Creating 10000 dictionaries - MongoEngine | ||||||
|  |     10.0446798801 | ||||||
|  |     ---------------------------------------------------------------------------------------------------- | ||||||
|  |     Creating 10000 dictionaries - MongoEngine, safe=False, validate=False | ||||||
|  |     9.51354718208 | ||||||
|  |     ---------------------------------------------------------------------------------------------------- | ||||||
|  |     Creating 10000 dictionaries - MongoEngine, safe=False, validate=False, cascade=False | ||||||
|  |     9.02567505836 | ||||||
|  |     ---------------------------------------------------------------------------------------------------- | ||||||
|  |     Creating 10000 dictionaries - MongoEngine, force=True | ||||||
|  |     8.44933390617 | ||||||
|  |  | ||||||
|  |     0.7.X | ||||||
|  |     ---------------------------------------------------------------------------------------------------- | ||||||
|  |     Creating 10000 dictionaries - Pymongo | ||||||
|  |     3.78801012039 | ||||||
|  |     ---------------------------------------------------------------------------------------------------- | ||||||
|  |     Creating 10000 dictionaries - MongoEngine | ||||||
|  |     9.73050498962 | ||||||
|  |     ---------------------------------------------------------------------------------------------------- | ||||||
|  |     Creating 10000 dictionaries - MongoEngine, safe=False, validate=False | ||||||
|  |     8.33456707001 | ||||||
|  |     ---------------------------------------------------------------------------------------------------- | ||||||
|  |     Creating 10000 dictionaries - MongoEngine, safe=False, validate=False, cascade=False | ||||||
|  |     8.37778115273 | ||||||
|  |     ---------------------------------------------------------------------------------------------------- | ||||||
|  |     Creating 10000 dictionaries - MongoEngine, force=True | ||||||
|  |     8.36906409264 | ||||||
|  |     0.8.X | ||||||
|  |     ---------------------------------------------------------------------------------------------------- | ||||||
|  |     Creating 10000 dictionaries - Pymongo | ||||||
|  |     3.69964408875 | ||||||
|  |     ---------------------------------------------------------------------------------------------------- | ||||||
|  |     Creating 10000 dictionaries - Pymongo write_concern={"w": 0} | ||||||
|  |     3.5526599884 | ||||||
|  |     ---------------------------------------------------------------------------------------------------- | ||||||
|  |     Creating 10000 dictionaries - MongoEngine | ||||||
|  |     7.00959801674 | ||||||
|  |     ---------------------------------------------------------------------------------------------------- | ||||||
|  |     Creating 10000 dictionaries without continual assign - MongoEngine | ||||||
|  |     5.60943293571 | ||||||
|  |     ---------------------------------------------------------------------------------------------------- | ||||||
|  |     Creating 10000 dictionaries - MongoEngine - write_concern={"w": 0}, cascade=True | ||||||
|  |     6.715102911 | ||||||
|  |     ---------------------------------------------------------------------------------------------------- | ||||||
|  |     Creating 10000 dictionaries - MongoEngine, write_concern={"w": 0}, validate=False, cascade=True | ||||||
|  |     5.50644683838 | ||||||
|  |     ---------------------------------------------------------------------------------------------------- | ||||||
|  |     Creating 10000 dictionaries - MongoEngine, write_concern={"w": 0}, validate=False | ||||||
|  |     4.69851183891 | ||||||
|  |     ---------------------------------------------------------------------------------------------------- | ||||||
|  |     Creating 10000 dictionaries - MongoEngine, force_insert=True, write_concern={"w": 0}, validate=False | ||||||
|  |     4.68946313858 | ||||||
|  |     ---------------------------------------------------------------------------------------------------- | ||||||
|  |     """ | ||||||
|     print("Benchmarking...") |     print("Benchmarking...") | ||||||
|  |  | ||||||
|     setup = """ |     setup = """ | ||||||
| @@ -54,7 +131,7 @@ noddy = db.noddy | |||||||
| for i in range(10000): | for i in range(10000): | ||||||
|     example = {'fields': {}} |     example = {'fields': {}} | ||||||
|     for j in range(20): |     for j in range(20): | ||||||
|         example['fields']['key' + str(j)] = 'value ' + str(j) |         example['fields']["key"+str(j)] = "value "+str(j) | ||||||
|  |  | ||||||
|     noddy.save(example) |     noddy.save(example) | ||||||
|  |  | ||||||
| @@ -69,10 +146,9 @@ myNoddys = noddy.find() | |||||||
|  |  | ||||||
|     stmt = """ |     stmt = """ | ||||||
| from pymongo import MongoClient | from pymongo import MongoClient | ||||||
| from pymongo.write_concern import WriteConcern |  | ||||||
| connection = MongoClient() | connection = MongoClient() | ||||||
|  |  | ||||||
| db = connection.get_database('timeit_test', write_concern=WriteConcern(w=0)) | db = connection.timeit_test | ||||||
| noddy = db.noddy | noddy = db.noddy | ||||||
|  |  | ||||||
| for i in range(10000): | for i in range(10000): | ||||||
| @@ -80,7 +156,7 @@ for i in range(10000): | |||||||
|     for j in range(20): |     for j in range(20): | ||||||
|         example['fields']["key"+str(j)] = "value "+str(j) |         example['fields']["key"+str(j)] = "value "+str(j) | ||||||
|  |  | ||||||
|     noddy.save(example) |     noddy.save(example, write_concern={"w": 0}) | ||||||
|  |  | ||||||
| myNoddys = noddy.find() | myNoddys = noddy.find() | ||||||
| [n for n in myNoddys] # iterate | [n for n in myNoddys] # iterate | ||||||
| @@ -95,10 +171,10 @@ myNoddys = noddy.find() | |||||||
| from pymongo import MongoClient | from pymongo import MongoClient | ||||||
| connection = MongoClient() | connection = MongoClient() | ||||||
| connection.drop_database('timeit_test') | connection.drop_database('timeit_test') | ||||||
| connection.close() | connection.disconnect() | ||||||
|  |  | ||||||
| from mongoengine import Document, DictField, connect | from mongoengine import Document, DictField, connect | ||||||
| connect('timeit_test') | connect("timeit_test") | ||||||
|  |  | ||||||
| class Noddy(Document): | class Noddy(Document): | ||||||
|     fields = DictField() |     fields = DictField() | ||||||
|   | |||||||
| @@ -2,34 +2,11 @@ | |||||||
| Changelog | Changelog | ||||||
| ========= | ========= | ||||||
|  |  | ||||||
| Development |  | ||||||
| =========== |  | ||||||
| - (Fill this out as you fix issues and develop you features). |  | ||||||
| - Fixed connecting to a replica set with PyMongo 2.x #1436 |  | ||||||
| - Fixed an obscure error message when filtering by `field__in=non_iterable`. #1237 |  | ||||||
|  |  | ||||||
| Changes in 0.11.0 |  | ||||||
| ================= |  | ||||||
| - BREAKING CHANGE: Renamed `ConnectionError` to `MongoEngineConnectionError` since the former is a built-in exception name in Python v3.x. #1428 |  | ||||||
| - BREAKING CHANGE: Dropped Python 2.6 support. #1428 |  | ||||||
| - BREAKING CHANGE: `from mongoengine.base import ErrorClass` won't work anymore for any error from `mongoengine.errors` (e.g. `ValidationError`). Use `from mongoengine.errors import ErrorClass instead`. #1428 |  | ||||||
| - Fixed absent rounding for DecimalField when `force_string` is set. #1103 |  | ||||||
|  |  | ||||||
| Changes in 0.10.8 | Changes in 0.10.8 | ||||||
| ================= | ================= | ||||||
| - Added support for QuerySet.batch_size (#1426) |  | ||||||
| - Fixed query set iteration within iteration #1427 |  | ||||||
| - Fixed an issue where specifying a MongoDB URI host would override more information than it should #1421 |  | ||||||
| - Added ability to filter the generic reference field by ObjectId and DBRef #1425 |  | ||||||
| - Fixed delete cascade for models with a custom primary key field #1247 |  | ||||||
| - Added ability to specify an authentication mechanism (e.g. X.509) #1333 | - Added ability to specify an authentication mechanism (e.g. X.509) #1333 | ||||||
| - Added support for falsey primary keys (e.g. doc.pk = 0) #1354 | - Added support for falsey primary keys (e.g. doc.pk = 0) #1354 | ||||||
| - Fixed QuerySet#sum/average for fields w/ explicit db_field #1417 | - Fixed BaseQuerySet#sum/average for fields w/ explicit db_field #1417 | ||||||
| - Fixed filtering by embedded_doc=None #1422 |  | ||||||
| - Added support for cursor.comment #1420 |  | ||||||
| - Fixed doc.get_<field>_display #1419 |  | ||||||
| - Fixed __repr__ method of the StrictDict #1424 |  | ||||||
| - Added a deprecation warning for Python 2.6 |  | ||||||
|  |  | ||||||
| Changes in 0.10.7 | Changes in 0.10.7 | ||||||
| ================= | ================= | ||||||
|   | |||||||
| @@ -479,8 +479,6 @@ operators. To use a :class:`~mongoengine.queryset.Q` object, pass it in as the | |||||||
| first positional argument to :attr:`Document.objects` when you filter it by | first positional argument to :attr:`Document.objects` when you filter it by | ||||||
| calling it with keyword arguments:: | calling it with keyword arguments:: | ||||||
|  |  | ||||||
|     from mongoengine.queryset.visitor import Q |  | ||||||
|  |  | ||||||
|     # Get published posts |     # Get published posts | ||||||
|     Post.objects(Q(published=True) | Q(publish_date__lte=datetime.now())) |     Post.objects(Q(published=True) | Q(publish_date__lte=datetime.now())) | ||||||
|  |  | ||||||
|   | |||||||
| @@ -2,39 +2,6 @@ | |||||||
| Upgrading | Upgrading | ||||||
| ######### | ######### | ||||||
|  |  | ||||||
| 0.11.0 |  | ||||||
| ****** |  | ||||||
| This release includes a major rehaul of MongoEngine's code quality and |  | ||||||
| introduces a few breaking changes. It also touches many different parts of |  | ||||||
| the package and although all the changes have been tested and scrutinized, |  | ||||||
| you're encouraged to thorougly test the upgrade. |  | ||||||
|  |  | ||||||
| First breaking change involves renaming `ConnectionError` to `MongoEngineConnectionError`. |  | ||||||
| If you import or catch this exception, you'll need to rename it in your code. |  | ||||||
|  |  | ||||||
| Second breaking change drops Python v2.6 support. If you run MongoEngine on |  | ||||||
| that Python version, you'll need to upgrade it first. |  | ||||||
|  |  | ||||||
| Third breaking change drops an old backward compatibility measure where |  | ||||||
| `from mongoengine.base import ErrorClass` would work on top of |  | ||||||
| `from mongoengine.errors import ErrorClass` (where `ErrorClass` is e.g. |  | ||||||
| `ValidationError`). If you import any exceptions from `mongoengine.base`, |  | ||||||
| change it to `mongoengine.errors`. |  | ||||||
|  |  | ||||||
| 0.10.8 |  | ||||||
| ****** |  | ||||||
| This version fixed an issue where specifying a MongoDB URI host would override |  | ||||||
| more information than it should. These changes are minor, but they still |  | ||||||
| subtly modify the connection logic and thus you're encouraged to test your |  | ||||||
| MongoDB connection before shipping v0.10.8 in production. |  | ||||||
|  |  | ||||||
| 0.10.7 |  | ||||||
| ****** |  | ||||||
|  |  | ||||||
| `QuerySet.aggregate_sum` and `QuerySet.aggregate_average` are dropped. Use |  | ||||||
| `QuerySet.sum` and `QuerySet.average` instead which use the aggreation framework |  | ||||||
| by default from now on. |  | ||||||
|  |  | ||||||
| 0.9.0 | 0.9.0 | ||||||
| ***** | ***** | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1,35 +1,25 @@ | |||||||
| # Import submodules so that we can expose their __all__ | import connection | ||||||
| from mongoengine import connection | from connection import * | ||||||
| from mongoengine import document | import document | ||||||
| from mongoengine import errors | from document import * | ||||||
| from mongoengine import fields | import errors | ||||||
| from mongoengine import queryset | from errors import * | ||||||
| from mongoengine import signals | import fields | ||||||
|  | from fields import * | ||||||
|  | import queryset | ||||||
|  | from queryset import * | ||||||
|  | import signals | ||||||
|  | from signals import * | ||||||
|  |  | ||||||
| # Import everything from each submodule so that it can be accessed via | __all__ = (list(document.__all__) + fields.__all__ + connection.__all__ + | ||||||
| # mongoengine, e.g. instead of `from mongoengine.connection import connect`, |            list(queryset.__all__) + signals.__all__ + list(errors.__all__)) | ||||||
| # users can simply use `from mongoengine import connect`, or even |  | ||||||
| # `from mongoengine import *` and then `connect('testdb')`. |  | ||||||
| from mongoengine.connection import * |  | ||||||
| from mongoengine.document import * |  | ||||||
| from mongoengine.errors import * |  | ||||||
| from mongoengine.fields import * |  | ||||||
| from mongoengine.queryset import * |  | ||||||
| from mongoengine.signals import * |  | ||||||
|  |  | ||||||
|  | VERSION = (0, 10, 7) | ||||||
| __all__ = (list(document.__all__) + list(fields.__all__) + |  | ||||||
|            list(connection.__all__) + list(queryset.__all__) + |  | ||||||
|            list(signals.__all__) + list(errors.__all__)) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| VERSION = (0, 11, 0) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_version(): | def get_version(): | ||||||
|     """Return the VERSION as a string, e.g. for VERSION == (0, 10, 7), |     if isinstance(VERSION[-1], basestring): | ||||||
|     return '0.10.7'. |         return '.'.join(map(str, VERSION[:-1])) + VERSION[-1] | ||||||
|     """ |  | ||||||
|     return '.'.join(map(str, VERSION)) |     return '.'.join(map(str, VERSION)) | ||||||
|  |  | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1,28 +1,8 @@ | |||||||
| # Base module is split into several files for convenience. Files inside of |  | ||||||
| # this module should import from a specific submodule (e.g. |  | ||||||
| # `from mongoengine.base.document import BaseDocument`), but all of the |  | ||||||
| # other modules should import directly from the top-level module (e.g. |  | ||||||
| # `from mongoengine.base import BaseDocument`). This approach is cleaner and |  | ||||||
| # also helps with cyclical import errors. |  | ||||||
| from mongoengine.base.common import * | from mongoengine.base.common import * | ||||||
| from mongoengine.base.datastructures import * | from mongoengine.base.datastructures import * | ||||||
| from mongoengine.base.document import * | from mongoengine.base.document import * | ||||||
| from mongoengine.base.fields import * | from mongoengine.base.fields import * | ||||||
| from mongoengine.base.metaclasses import * | from mongoengine.base.metaclasses import * | ||||||
|  |  | ||||||
| __all__ = ( | # Help with backwards compatibility | ||||||
|     # common | from mongoengine.errors import * | ||||||
|     'UPDATE_OPERATORS', '_document_registry', 'get_document', |  | ||||||
|  |  | ||||||
|     # datastructures |  | ||||||
|     'BaseDict', 'BaseList', 'EmbeddedDocumentList', |  | ||||||
|  |  | ||||||
|     # document |  | ||||||
|     'BaseDocument', |  | ||||||
|  |  | ||||||
|     # fields |  | ||||||
|     'BaseField', 'ComplexBaseField', 'ObjectIdField', 'GeoJsonBaseField', |  | ||||||
|  |  | ||||||
|     # metaclasses |  | ||||||
|     'DocumentMetaclass', 'TopLevelDocumentMetaclass' |  | ||||||
| ) |  | ||||||
|   | |||||||
| @@ -1,18 +1,13 @@ | |||||||
| from mongoengine.errors import NotRegistered | from mongoengine.errors import NotRegistered | ||||||
|  |  | ||||||
| __all__ = ('UPDATE_OPERATORS', 'get_document', '_document_registry') | __all__ = ('ALLOW_INHERITANCE', 'get_document', '_document_registry') | ||||||
|  |  | ||||||
|  |  | ||||||
| UPDATE_OPERATORS = set(['set', 'unset', 'inc', 'dec', 'pop', 'push', |  | ||||||
|                         'push_all', 'pull', 'pull_all', 'add_to_set', |  | ||||||
|                         'set_on_insert', 'min', 'max']) |  | ||||||
|  |  | ||||||
|  | ALLOW_INHERITANCE = False | ||||||
|  |  | ||||||
| _document_registry = {} | _document_registry = {} | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_document(name): | def get_document(name): | ||||||
|     """Get a document class by name.""" |  | ||||||
|     doc = _document_registry.get(name, None) |     doc = _document_registry.get(name, None) | ||||||
|     if not doc: |     if not doc: | ||||||
|         # Possible old style name |         # Possible old style name | ||||||
|   | |||||||
| @@ -1,16 +1,14 @@ | |||||||
| import itertools | import itertools | ||||||
| import weakref | import weakref | ||||||
|  |  | ||||||
| import six |  | ||||||
|  |  | ||||||
| from mongoengine.common import _import_class | from mongoengine.common import _import_class | ||||||
| from mongoengine.errors import DoesNotExist, MultipleObjectsReturned | from mongoengine.errors import DoesNotExist, MultipleObjectsReturned | ||||||
|  |  | ||||||
| __all__ = ('BaseDict', 'BaseList', 'EmbeddedDocumentList') | __all__ = ("BaseDict", "BaseList", "EmbeddedDocumentList") | ||||||
|  |  | ||||||
|  |  | ||||||
| class BaseDict(dict): | class BaseDict(dict): | ||||||
|     """A special dict so we can watch any changes.""" |     """A special dict so we can watch any changes""" | ||||||
|  |  | ||||||
|     _dereferenced = False |     _dereferenced = False | ||||||
|     _instance = None |     _instance = None | ||||||
| @@ -95,7 +93,8 @@ class BaseDict(dict): | |||||||
|  |  | ||||||
|  |  | ||||||
| class BaseList(list): | class BaseList(list): | ||||||
|     """A special list so we can watch any changes.""" |     """A special list so we can watch any changes | ||||||
|  |     """ | ||||||
|  |  | ||||||
|     _dereferenced = False |     _dereferenced = False | ||||||
|     _instance = None |     _instance = None | ||||||
| @@ -138,7 +137,10 @@ class BaseList(list): | |||||||
|         return super(BaseList, self).__setitem__(key, value) |         return super(BaseList, self).__setitem__(key, value) | ||||||
|  |  | ||||||
|     def __delitem__(self, key, *args, **kwargs): |     def __delitem__(self, key, *args, **kwargs): | ||||||
|         self._mark_as_changed() |         if isinstance(key, slice): | ||||||
|  |             self._mark_as_changed() | ||||||
|  |         else: | ||||||
|  |             self._mark_as_changed(key) | ||||||
|         return super(BaseList, self).__delitem__(key) |         return super(BaseList, self).__delitem__(key) | ||||||
|  |  | ||||||
|     def __setslice__(self, *args, **kwargs): |     def __setslice__(self, *args, **kwargs): | ||||||
| @@ -207,22 +209,17 @@ class BaseList(list): | |||||||
| class EmbeddedDocumentList(BaseList): | class EmbeddedDocumentList(BaseList): | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|     def __match_all(cls, embedded_doc, kwargs): |     def __match_all(cls, i, kwargs): | ||||||
|         """Return True if a given embedded doc matches all the filter |         items = kwargs.items() | ||||||
|         kwargs. If it doesn't return False. |         return all([ | ||||||
|         """ |             getattr(i, k) == v or unicode(getattr(i, k)) == v for k, v in items | ||||||
|         for key, expected_value in kwargs.items(): |         ]) | ||||||
|             doc_val = getattr(embedded_doc, key) |  | ||||||
|             if doc_val != expected_value and six.text_type(doc_val) != expected_value: |  | ||||||
|                 return False |  | ||||||
|         return True |  | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|     def __only_matches(cls, embedded_docs, kwargs): |     def __only_matches(cls, obj, kwargs): | ||||||
|         """Return embedded docs that match the filter kwargs.""" |  | ||||||
|         if not kwargs: |         if not kwargs: | ||||||
|             return embedded_docs |             return obj | ||||||
|         return [doc for doc in embedded_docs if cls.__match_all(doc, kwargs)] |         return filter(lambda i: cls.__match_all(i, kwargs), obj) | ||||||
|  |  | ||||||
|     def __init__(self, list_items, instance, name): |     def __init__(self, list_items, instance, name): | ||||||
|         super(EmbeddedDocumentList, self).__init__(list_items, instance, name) |         super(EmbeddedDocumentList, self).__init__(list_items, instance, name) | ||||||
| @@ -288,18 +285,18 @@ class EmbeddedDocumentList(BaseList): | |||||||
|         values = self.__only_matches(self, kwargs) |         values = self.__only_matches(self, kwargs) | ||||||
|         if len(values) == 0: |         if len(values) == 0: | ||||||
|             raise DoesNotExist( |             raise DoesNotExist( | ||||||
|                 '%s matching query does not exist.' % self._name |                 "%s matching query does not exist." % self._name | ||||||
|             ) |             ) | ||||||
|         elif len(values) > 1: |         elif len(values) > 1: | ||||||
|             raise MultipleObjectsReturned( |             raise MultipleObjectsReturned( | ||||||
|                 '%d items returned, instead of 1' % len(values) |                 "%d items returned, instead of 1" % len(values) | ||||||
|             ) |             ) | ||||||
|  |  | ||||||
|         return values[0] |         return values[0] | ||||||
|  |  | ||||||
|     def first(self): |     def first(self): | ||||||
|         """Return the first embedded document in the list, or ``None`` |         """ | ||||||
|         if empty. |         Returns the first embedded document in the list, or ``None`` if empty. | ||||||
|         """ |         """ | ||||||
|         if len(self) > 0: |         if len(self) > 0: | ||||||
|             return self[0] |             return self[0] | ||||||
| @@ -441,7 +438,7 @@ class StrictDict(object): | |||||||
|                 __slots__ = allowed_keys_tuple |                 __slots__ = allowed_keys_tuple | ||||||
|  |  | ||||||
|                 def __repr__(self): |                 def __repr__(self): | ||||||
|                     return '{%s}' % ', '.join('"{0!s}": {1!r}'.format(k, v) for k, v in self.items()) |                     return "{%s}" % ', '.join('"{0!s}": {0!r}'.format(k) for k in self.iterkeys()) | ||||||
|  |  | ||||||
|             cls._classes[allowed_keys] = SpecificStrictDict |             cls._classes[allowed_keys] = SpecificStrictDict | ||||||
|         return cls._classes[allowed_keys] |         return cls._classes[allowed_keys] | ||||||
|   | |||||||
| @@ -1,5 +1,6 @@ | |||||||
| import copy | import copy | ||||||
| import numbers | import numbers | ||||||
|  | import operator | ||||||
| from collections import Hashable | from collections import Hashable | ||||||
| from functools import partial | from functools import partial | ||||||
|  |  | ||||||
| @@ -7,27 +8,30 @@ from bson import ObjectId, json_util | |||||||
| from bson.dbref import DBRef | from bson.dbref import DBRef | ||||||
| from bson.son import SON | from bson.son import SON | ||||||
| import pymongo | import pymongo | ||||||
| import six |  | ||||||
|  |  | ||||||
| from mongoengine import signals | from mongoengine import signals | ||||||
| from mongoengine.base.common import get_document | from mongoengine.base.common import ALLOW_INHERITANCE, get_document | ||||||
| from mongoengine.base.datastructures import (BaseDict, BaseList, | from mongoengine.base.datastructures import ( | ||||||
|                                              EmbeddedDocumentList, |     BaseDict, | ||||||
|                                              SemiStrictDict, StrictDict) |     BaseList, | ||||||
|  |     EmbeddedDocumentList, | ||||||
|  |     SemiStrictDict, | ||||||
|  |     StrictDict | ||||||
|  | ) | ||||||
| from mongoengine.base.fields import ComplexBaseField | from mongoengine.base.fields import ComplexBaseField | ||||||
| from mongoengine.common import _import_class | from mongoengine.common import _import_class | ||||||
| from mongoengine.errors import (FieldDoesNotExist, InvalidDocumentError, | from mongoengine.errors import (FieldDoesNotExist, InvalidDocumentError, | ||||||
|                                 LookUpError, OperationError, ValidationError) |                                 LookUpError, ValidationError) | ||||||
|  | from mongoengine.python_support import PY3, txt_type | ||||||
|  |  | ||||||
| __all__ = ('BaseDocument',) | __all__ = ('BaseDocument', 'NON_FIELD_ERRORS') | ||||||
|  |  | ||||||
| NON_FIELD_ERRORS = '__all__' | NON_FIELD_ERRORS = '__all__' | ||||||
|  |  | ||||||
|  |  | ||||||
| class BaseDocument(object): | class BaseDocument(object): | ||||||
|     __slots__ = ('_changed_fields', '_initialised', '_created', '_data', |     __slots__ = ('_changed_fields', '_initialised', '_created', '_data', | ||||||
|                  '_dynamic_fields', '_auto_id_field', '_db_field_map', |                  '_dynamic_fields', '_auto_id_field', '_db_field_map', '__weakref__') | ||||||
|                  '__weakref__') |  | ||||||
|  |  | ||||||
|     _dynamic = False |     _dynamic = False | ||||||
|     _dynamic_lock = True |     _dynamic_lock = True | ||||||
| @@ -53,15 +57,15 @@ class BaseDocument(object): | |||||||
|                 name = next(field) |                 name = next(field) | ||||||
|                 if name in values: |                 if name in values: | ||||||
|                     raise TypeError( |                     raise TypeError( | ||||||
|                         'Multiple values for keyword argument "%s"' % name) |                         "Multiple values for keyword argument '" + name + "'") | ||||||
|                 values[name] = value |                 values[name] = value | ||||||
|  |  | ||||||
|         __auto_convert = values.pop('__auto_convert', True) |         __auto_convert = values.pop("__auto_convert", True) | ||||||
|  |  | ||||||
|         # 399: set default values only to fields loaded from DB |         # 399: set default values only to fields loaded from DB | ||||||
|         __only_fields = set(values.pop('__only_fields', values)) |         __only_fields = set(values.pop("__only_fields", values)) | ||||||
|  |  | ||||||
|         _created = values.pop('_created', True) |         _created = values.pop("_created", True) | ||||||
|  |  | ||||||
|         signals.pre_init.send(self.__class__, document=self, values=values) |         signals.pre_init.send(self.__class__, document=self, values=values) | ||||||
|  |  | ||||||
| @@ -72,7 +76,7 @@ class BaseDocument(object): | |||||||
|                 self._fields.keys() + ['id', 'pk', '_cls', '_text_score']) |                 self._fields.keys() + ['id', 'pk', '_cls', '_text_score']) | ||||||
|             if _undefined_fields: |             if _undefined_fields: | ||||||
|                 msg = ( |                 msg = ( | ||||||
|                     'The fields "{0}" do not exist on the document "{1}"' |                     "The fields '{0}' do not exist on the document '{1}'" | ||||||
|                 ).format(_undefined_fields, self._class_name) |                 ).format(_undefined_fields, self._class_name) | ||||||
|                 raise FieldDoesNotExist(msg) |                 raise FieldDoesNotExist(msg) | ||||||
|  |  | ||||||
| @@ -91,7 +95,7 @@ class BaseDocument(object): | |||||||
|             value = getattr(self, key, None) |             value = getattr(self, key, None) | ||||||
|             setattr(self, key, value) |             setattr(self, key, value) | ||||||
|  |  | ||||||
|         if '_cls' not in values: |         if "_cls" not in values: | ||||||
|             self._cls = self._class_name |             self._cls = self._class_name | ||||||
|  |  | ||||||
|         # Set passed values after initialisation |         # Set passed values after initialisation | ||||||
| @@ -117,7 +121,7 @@ class BaseDocument(object): | |||||||
|                 else: |                 else: | ||||||
|                     self._data[key] = value |                     self._data[key] = value | ||||||
|  |  | ||||||
|         # Set any get_<field>_display methods |         # Set any get_fieldname_display methods | ||||||
|         self.__set_field_display() |         self.__set_field_display() | ||||||
|  |  | ||||||
|         if self._dynamic: |         if self._dynamic: | ||||||
| @@ -146,7 +150,7 @@ class BaseDocument(object): | |||||||
|         if self._dynamic and not self._dynamic_lock: |         if self._dynamic and not self._dynamic_lock: | ||||||
|  |  | ||||||
|             if not hasattr(self, name) and not name.startswith('_'): |             if not hasattr(self, name) and not name.startswith('_'): | ||||||
|                 DynamicField = _import_class('DynamicField') |                 DynamicField = _import_class("DynamicField") | ||||||
|                 field = DynamicField(db_field=name) |                 field = DynamicField(db_field=name) | ||||||
|                 field.name = name |                 field.name = name | ||||||
|                 self._dynamic_fields[name] = field |                 self._dynamic_fields[name] = field | ||||||
| @@ -165,13 +169,11 @@ class BaseDocument(object): | |||||||
|         except AttributeError: |         except AttributeError: | ||||||
|             self__created = True |             self__created = True | ||||||
|  |  | ||||||
|         if ( |         if (self._is_document and not self__created and | ||||||
|             self._is_document and |                 name in self._meta.get('shard_key', tuple()) and | ||||||
|             not self__created and |                 self._data.get(name) != value): | ||||||
|             name in self._meta.get('shard_key', tuple()) and |             OperationError = _import_class('OperationError') | ||||||
|             self._data.get(name) != value |             msg = "Shard Keys are immutable. Tried to update %s" % name | ||||||
|         ): |  | ||||||
|             msg = 'Shard Keys are immutable. Tried to update %s' % name |  | ||||||
|             raise OperationError(msg) |             raise OperationError(msg) | ||||||
|  |  | ||||||
|         try: |         try: | ||||||
| @@ -195,8 +197,8 @@ class BaseDocument(object): | |||||||
|         return data |         return data | ||||||
|  |  | ||||||
|     def __setstate__(self, data): |     def __setstate__(self, data): | ||||||
|         if isinstance(data['_data'], SON): |         if isinstance(data["_data"], SON): | ||||||
|             data['_data'] = self.__class__._from_son(data['_data'])._data |             data["_data"] = self.__class__._from_son(data["_data"])._data | ||||||
|         for k in ('_changed_fields', '_initialised', '_created', '_data', |         for k in ('_changed_fields', '_initialised', '_created', '_data', | ||||||
|                   '_dynamic_fields'): |                   '_dynamic_fields'): | ||||||
|             if k in data: |             if k in data: | ||||||
| @@ -210,7 +212,7 @@ class BaseDocument(object): | |||||||
|  |  | ||||||
|         dynamic_fields = data.get('_dynamic_fields') or SON() |         dynamic_fields = data.get('_dynamic_fields') or SON() | ||||||
|         for k in dynamic_fields.keys(): |         for k in dynamic_fields.keys(): | ||||||
|             setattr(self, k, data['_data'].get(k)) |             setattr(self, k, data["_data"].get(k)) | ||||||
|  |  | ||||||
|     def __iter__(self): |     def __iter__(self): | ||||||
|         return iter(self._fields_ordered) |         return iter(self._fields_ordered) | ||||||
| @@ -252,13 +254,12 @@ class BaseDocument(object): | |||||||
|         return repr_type('<%s: %s>' % (self.__class__.__name__, u)) |         return repr_type('<%s: %s>' % (self.__class__.__name__, u)) | ||||||
|  |  | ||||||
|     def __str__(self): |     def __str__(self): | ||||||
|         # TODO this could be simpler? |  | ||||||
|         if hasattr(self, '__unicode__'): |         if hasattr(self, '__unicode__'): | ||||||
|             if six.PY3: |             if PY3: | ||||||
|                 return self.__unicode__() |                 return self.__unicode__() | ||||||
|             else: |             else: | ||||||
|                 return six.text_type(self).encode('utf-8') |                 return unicode(self).encode('utf-8') | ||||||
|         return six.text_type('%s object' % self.__class__.__name__) |         return txt_type('%s object' % self.__class__.__name__) | ||||||
|  |  | ||||||
|     def __eq__(self, other): |     def __eq__(self, other): | ||||||
|         if isinstance(other, self.__class__) and hasattr(other, 'id') and other.id is not None: |         if isinstance(other, self.__class__) and hasattr(other, 'id') and other.id is not None: | ||||||
| @@ -307,7 +308,7 @@ class BaseDocument(object): | |||||||
|             fields = [] |             fields = [] | ||||||
|  |  | ||||||
|         data = SON() |         data = SON() | ||||||
|         data['_id'] = None |         data["_id"] = None | ||||||
|         data['_cls'] = self._class_name |         data['_cls'] = self._class_name | ||||||
|  |  | ||||||
|         # only root fields ['test1.a', 'test2'] => ['test1', 'test2'] |         # only root fields ['test1.a', 'test2'] => ['test1', 'test2'] | ||||||
| @@ -350,8 +351,18 @@ class BaseDocument(object): | |||||||
|                 else: |                 else: | ||||||
|                     data[field.name] = value |                     data[field.name] = value | ||||||
|  |  | ||||||
|  |         # If "_id" has not been set, then try and set it | ||||||
|  |         Document = _import_class("Document") | ||||||
|  |         if isinstance(self, Document): | ||||||
|  |             if data["_id"] is None: | ||||||
|  |                 data["_id"] = self._data.get("id", None) | ||||||
|  |  | ||||||
|  |         if data['_id'] is None: | ||||||
|  |             data.pop('_id') | ||||||
|  |  | ||||||
|         # Only add _cls if allow_inheritance is True |         # Only add _cls if allow_inheritance is True | ||||||
|         if not self._meta.get('allow_inheritance'): |         if (not hasattr(self, '_meta') or | ||||||
|  |                 not self._meta.get('allow_inheritance', ALLOW_INHERITANCE)): | ||||||
|             data.pop('_cls') |             data.pop('_cls') | ||||||
|  |  | ||||||
|         return data |         return data | ||||||
| @@ -365,16 +376,16 @@ class BaseDocument(object): | |||||||
|         if clean: |         if clean: | ||||||
|             try: |             try: | ||||||
|                 self.clean() |                 self.clean() | ||||||
|             except ValidationError as error: |             except ValidationError, error: | ||||||
|                 errors[NON_FIELD_ERRORS] = error |                 errors[NON_FIELD_ERRORS] = error | ||||||
|  |  | ||||||
|         # Get a list of tuples of field names and their current values |         # Get a list of tuples of field names and their current values | ||||||
|         fields = [(self._fields.get(name, self._dynamic_fields.get(name)), |         fields = [(self._fields.get(name, self._dynamic_fields.get(name)), | ||||||
|                    self._data.get(name)) for name in self._fields_ordered] |                    self._data.get(name)) for name in self._fields_ordered] | ||||||
|  |  | ||||||
|         EmbeddedDocumentField = _import_class('EmbeddedDocumentField') |         EmbeddedDocumentField = _import_class("EmbeddedDocumentField") | ||||||
|         GenericEmbeddedDocumentField = _import_class( |         GenericEmbeddedDocumentField = _import_class( | ||||||
|             'GenericEmbeddedDocumentField') |             "GenericEmbeddedDocumentField") | ||||||
|  |  | ||||||
|         for field, value in fields: |         for field, value in fields: | ||||||
|             if value is not None: |             if value is not None: | ||||||
| @@ -384,21 +395,21 @@ class BaseDocument(object): | |||||||
|                         field._validate(value, clean=clean) |                         field._validate(value, clean=clean) | ||||||
|                     else: |                     else: | ||||||
|                         field._validate(value) |                         field._validate(value) | ||||||
|                 except ValidationError as error: |                 except ValidationError, error: | ||||||
|                     errors[field.name] = error.errors or error |                     errors[field.name] = error.errors or error | ||||||
|                 except (ValueError, AttributeError, AssertionError) as error: |                 except (ValueError, AttributeError, AssertionError), error: | ||||||
|                     errors[field.name] = error |                     errors[field.name] = error | ||||||
|             elif field.required and not getattr(field, '_auto_gen', False): |             elif field.required and not getattr(field, '_auto_gen', False): | ||||||
|                 errors[field.name] = ValidationError('Field is required', |                 errors[field.name] = ValidationError('Field is required', | ||||||
|                                                      field_name=field.name) |                                                      field_name=field.name) | ||||||
|  |  | ||||||
|         if errors: |         if errors: | ||||||
|             pk = 'None' |             pk = "None" | ||||||
|             if hasattr(self, 'pk'): |             if hasattr(self, 'pk'): | ||||||
|                 pk = self.pk |                 pk = self.pk | ||||||
|             elif self._instance and hasattr(self._instance, 'pk'): |             elif self._instance and hasattr(self._instance, 'pk'): | ||||||
|                 pk = self._instance.pk |                 pk = self._instance.pk | ||||||
|             message = 'ValidationError (%s:%s) ' % (self._class_name, pk) |             message = "ValidationError (%s:%s) " % (self._class_name, pk) | ||||||
|             raise ValidationError(message, errors=errors) |             raise ValidationError(message, errors=errors) | ||||||
|  |  | ||||||
|     def to_json(self, *args, **kwargs): |     def to_json(self, *args, **kwargs): | ||||||
| @@ -415,26 +426,33 @@ class BaseDocument(object): | |||||||
|         return cls._from_son(json_util.loads(json_data), created=created) |         return cls._from_son(json_util.loads(json_data), created=created) | ||||||
|  |  | ||||||
|     def __expand_dynamic_values(self, name, value): |     def __expand_dynamic_values(self, name, value): | ||||||
|         """Expand any dynamic values to their correct types / values.""" |         """expand any dynamic values to their correct types / values""" | ||||||
|         if not isinstance(value, (dict, list, tuple)): |         if not isinstance(value, (dict, list, tuple)): | ||||||
|             return value |             return value | ||||||
|  |  | ||||||
|         # If the value is a dict with '_cls' in it, turn it into a document |         EmbeddedDocumentListField = _import_class('EmbeddedDocumentListField') | ||||||
|         is_dict = isinstance(value, dict) |  | ||||||
|         if is_dict and '_cls' in value: |         is_list = False | ||||||
|  |         if not hasattr(value, 'items'): | ||||||
|  |             is_list = True | ||||||
|  |             value = dict([(k, v) for k, v in enumerate(value)]) | ||||||
|  |  | ||||||
|  |         if not is_list and '_cls' in value: | ||||||
|             cls = get_document(value['_cls']) |             cls = get_document(value['_cls']) | ||||||
|             return cls(**value) |             return cls(**value) | ||||||
|  |  | ||||||
|         if is_dict: |         data = {} | ||||||
|             value = { |         for k, v in value.items(): | ||||||
|                 k: self.__expand_dynamic_values(k, v) |             key = name if is_list else k | ||||||
|                 for k, v in value.items() |             data[k] = self.__expand_dynamic_values(key, v) | ||||||
|             } |  | ||||||
|  |         if is_list:  # Convert back to a list | ||||||
|  |             data_items = sorted(data.items(), key=operator.itemgetter(0)) | ||||||
|  |             value = [v for k, v in data_items] | ||||||
|         else: |         else: | ||||||
|             value = [self.__expand_dynamic_values(name, v) for v in value] |             value = data | ||||||
|  |  | ||||||
|         # Convert lists / values so we can watch for any changes on them |         # Convert lists / values so we can watch for any changes on them | ||||||
|         EmbeddedDocumentListField = _import_class('EmbeddedDocumentListField') |  | ||||||
|         if (isinstance(value, (list, tuple)) and |         if (isinstance(value, (list, tuple)) and | ||||||
|                 not isinstance(value, BaseList)): |                 not isinstance(value, BaseList)): | ||||||
|             if issubclass(type(self), EmbeddedDocumentListField): |             if issubclass(type(self), EmbeddedDocumentListField): | ||||||
| @@ -447,7 +465,8 @@ class BaseDocument(object): | |||||||
|         return value |         return value | ||||||
|  |  | ||||||
|     def _mark_as_changed(self, key): |     def _mark_as_changed(self, key): | ||||||
|         """Mark a key as explicitly changed by the user.""" |         """Marks a key as explicitly changed by the user | ||||||
|  |         """ | ||||||
|         if not key: |         if not key: | ||||||
|             return |             return | ||||||
|  |  | ||||||
| @@ -477,11 +496,10 @@ class BaseDocument(object): | |||||||
|                         remove(field) |                         remove(field) | ||||||
|  |  | ||||||
|     def _clear_changed_fields(self): |     def _clear_changed_fields(self): | ||||||
|         """Using _get_changed_fields iterate and remove any fields that |         """Using get_changed_fields iterate and remove any fields that are | ||||||
|         are marked as changed. |         marked as changed""" | ||||||
|         """ |  | ||||||
|         for changed in self._get_changed_fields(): |         for changed in self._get_changed_fields(): | ||||||
|             parts = changed.split('.') |             parts = changed.split(".") | ||||||
|             data = self |             data = self | ||||||
|             for part in parts: |             for part in parts: | ||||||
|                 if isinstance(data, list): |                 if isinstance(data, list): | ||||||
| @@ -493,13 +511,10 @@ class BaseDocument(object): | |||||||
|                     data = data.get(part, None) |                     data = data.get(part, None) | ||||||
|                 else: |                 else: | ||||||
|                     data = getattr(data, part, None) |                     data = getattr(data, part, None) | ||||||
|  |                 if hasattr(data, "_changed_fields"): | ||||||
|                 if hasattr(data, '_changed_fields'): |                     if hasattr(data, "_is_document") and data._is_document: | ||||||
|                     if getattr(data, '_is_document', False): |  | ||||||
|                         continue |                         continue | ||||||
|  |  | ||||||
|                     data._changed_fields = [] |                     data._changed_fields = [] | ||||||
|  |  | ||||||
|         self._changed_fields = [] |         self._changed_fields = [] | ||||||
|  |  | ||||||
|     def _nestable_types_changed_fields(self, changed_fields, key, data, inspected): |     def _nestable_types_changed_fields(self, changed_fields, key, data, inspected): | ||||||
| @@ -511,27 +526,26 @@ class BaseDocument(object): | |||||||
|             iterator = data.iteritems() |             iterator = data.iteritems() | ||||||
|  |  | ||||||
|         for index, value in iterator: |         for index, value in iterator: | ||||||
|             list_key = '%s%s.' % (key, index) |             list_key = "%s%s." % (key, index) | ||||||
|             # don't check anything lower if this key is already marked |             # don't check anything lower if this key is already marked | ||||||
|             # as changed. |             # as changed. | ||||||
|             if list_key[:-1] in changed_fields: |             if list_key[:-1] in changed_fields: | ||||||
|                 continue |                 continue | ||||||
|             if hasattr(value, '_get_changed_fields'): |             if hasattr(value, '_get_changed_fields'): | ||||||
|                 changed = value._get_changed_fields(inspected) |                 changed = value._get_changed_fields(inspected) | ||||||
|                 changed_fields += ['%s%s' % (list_key, k) |                 changed_fields += ["%s%s" % (list_key, k) | ||||||
|                                    for k in changed if k] |                                    for k in changed if k] | ||||||
|             elif isinstance(value, (list, tuple, dict)): |             elif isinstance(value, (list, tuple, dict)): | ||||||
|                 self._nestable_types_changed_fields( |                 self._nestable_types_changed_fields( | ||||||
|                     changed_fields, list_key, value, inspected) |                     changed_fields, list_key, value, inspected) | ||||||
|  |  | ||||||
|     def _get_changed_fields(self, inspected=None): |     def _get_changed_fields(self, inspected=None): | ||||||
|         """Return a list of all fields that have explicitly been changed. |         """Returns a list of all fields that have explicitly been changed. | ||||||
|         """ |         """ | ||||||
|         EmbeddedDocument = _import_class('EmbeddedDocument') |         EmbeddedDocument = _import_class("EmbeddedDocument") | ||||||
|         DynamicEmbeddedDocument = _import_class('DynamicEmbeddedDocument') |         DynamicEmbeddedDocument = _import_class("DynamicEmbeddedDocument") | ||||||
|         ReferenceField = _import_class('ReferenceField') |         ReferenceField = _import_class("ReferenceField") | ||||||
|         SortedListField = _import_class('SortedListField') |         SortedListField = _import_class("SortedListField") | ||||||
|  |  | ||||||
|         changed_fields = [] |         changed_fields = [] | ||||||
|         changed_fields += getattr(self, '_changed_fields', []) |         changed_fields += getattr(self, '_changed_fields', []) | ||||||
|  |  | ||||||
| @@ -558,7 +572,7 @@ class BaseDocument(object): | |||||||
|             ): |             ): | ||||||
|                 # Find all embedded fields that have been changed |                 # Find all embedded fields that have been changed | ||||||
|                 changed = data._get_changed_fields(inspected) |                 changed = data._get_changed_fields(inspected) | ||||||
|                 changed_fields += ['%s%s' % (key, k) for k in changed if k] |                 changed_fields += ["%s%s" % (key, k) for k in changed if k] | ||||||
|             elif (isinstance(data, (list, tuple, dict)) and |             elif (isinstance(data, (list, tuple, dict)) and | ||||||
|                     db_field_name not in changed_fields): |                     db_field_name not in changed_fields): | ||||||
|                 if (hasattr(field, 'field') and |                 if (hasattr(field, 'field') and | ||||||
| @@ -662,28 +676,21 @@ class BaseDocument(object): | |||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|     def _get_collection_name(cls): |     def _get_collection_name(cls): | ||||||
|         """Return the collection name for this class. None for abstract |         """Returns the collection name for this class. None for abstract class | ||||||
|         class. |  | ||||||
|         """ |         """ | ||||||
|         return cls._meta.get('collection', None) |         return cls._meta.get('collection', None) | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|     def _from_son(cls, son, _auto_dereference=True, only_fields=None, created=False): |     def _from_son(cls, son, _auto_dereference=True, only_fields=None, created=False): | ||||||
|         """Create an instance of a Document (subclass) from a PyMongo |         """Create an instance of a Document (subclass) from a PyMongo SON. | ||||||
|         SON. |  | ||||||
|         """ |         """ | ||||||
|         if not only_fields: |         if not only_fields: | ||||||
|             only_fields = [] |             only_fields = [] | ||||||
|  |  | ||||||
|         if son and not isinstance(son, dict): |         # get the class name from the document, falling back to the given | ||||||
|             raise ValueError("The source SON object needs to be of type 'dict'") |  | ||||||
|  |  | ||||||
|         # Get the class name from the document, falling back to the given |  | ||||||
|         # class if unavailable |         # class if unavailable | ||||||
|         class_name = son.get('_cls', cls._class_name) |         class_name = son.get('_cls', cls._class_name) | ||||||
|  |         data = dict(("%s" % key, value) for key, value in son.iteritems()) | ||||||
|         # Convert SON to a dict, making sure each key is a string |  | ||||||
|         data = {str(key): value for key, value in son.iteritems()} |  | ||||||
|  |  | ||||||
|         # Return correct subclass for document type |         # Return correct subclass for document type | ||||||
|         if class_name != cls._class_name: |         if class_name != cls._class_name: | ||||||
| @@ -705,20 +712,19 @@ class BaseDocument(object): | |||||||
|                                         else field.to_python(value)) |                                         else field.to_python(value)) | ||||||
|                     if field_name != field.db_field: |                     if field_name != field.db_field: | ||||||
|                         del data[field.db_field] |                         del data[field.db_field] | ||||||
|                 except (AttributeError, ValueError) as e: |                 except (AttributeError, ValueError), e: | ||||||
|                     errors_dict[field_name] = e |                     errors_dict[field_name] = e | ||||||
|  |  | ||||||
|         if errors_dict: |         if errors_dict: | ||||||
|             errors = '\n'.join(['%s - %s' % (k, v) |             errors = "\n".join(["%s - %s" % (k, v) | ||||||
|                                 for k, v in errors_dict.items()]) |                                 for k, v in errors_dict.items()]) | ||||||
|             msg = ('Invalid data to create a `%s` instance.\n%s' |             msg = ("Invalid data to create a `%s` instance.\n%s" | ||||||
|                    % (cls._class_name, errors)) |                    % (cls._class_name, errors)) | ||||||
|             raise InvalidDocumentError(msg) |             raise InvalidDocumentError(msg) | ||||||
|  |  | ||||||
|         # In STRICT documents, remove any keys that aren't in cls._fields |  | ||||||
|         if cls.STRICT: |         if cls.STRICT: | ||||||
|             data = {k: v for k, v in data.iteritems() if k in cls._fields} |             data = dict((k, v) | ||||||
|  |                         for k, v in data.iteritems() if k in cls._fields) | ||||||
|         obj = cls(__auto_convert=False, _created=created, __only_fields=only_fields, **data) |         obj = cls(__auto_convert=False, _created=created, __only_fields=only_fields, **data) | ||||||
|         obj._changed_fields = changed_fields |         obj._changed_fields = changed_fields | ||||||
|         if not _auto_dereference: |         if not _auto_dereference: | ||||||
| @@ -728,43 +734,37 @@ class BaseDocument(object): | |||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|     def _build_index_specs(cls, meta_indexes): |     def _build_index_specs(cls, meta_indexes): | ||||||
|         """Generate and merge the full index specs.""" |         """Generate and merge the full index specs | ||||||
|  |         """ | ||||||
|  |  | ||||||
|         geo_indices = cls._geo_indices() |         geo_indices = cls._geo_indices() | ||||||
|         unique_indices = cls._unique_with_indexes() |         unique_indices = cls._unique_with_indexes() | ||||||
|         index_specs = [cls._build_index_spec(spec) for spec in meta_indexes] |         index_specs = [cls._build_index_spec(spec) | ||||||
|  |                        for spec in meta_indexes] | ||||||
|  |  | ||||||
|         def merge_index_specs(index_specs, indices): |         def merge_index_specs(index_specs, indices): | ||||||
|             """Helper method for merging index specs.""" |  | ||||||
|             if not indices: |             if not indices: | ||||||
|                 return index_specs |                 return index_specs | ||||||
|  |  | ||||||
|             # Create a map of index fields to index spec. We're converting |             spec_fields = [v['fields'] | ||||||
|             # the fields from a list to a tuple so that it's hashable. |                            for k, v in enumerate(index_specs)] | ||||||
|             spec_fields = { |             # Merge unique_indexes with existing specs | ||||||
|                 tuple(index['fields']): index for index in index_specs |             for k, v in enumerate(indices): | ||||||
|             } |                 if v['fields'] in spec_fields: | ||||||
|  |                     index_specs[spec_fields.index(v['fields'])].update(v) | ||||||
|             # For each new index, if there's an existing index with the same |  | ||||||
|             # fields list, update the existing spec with all data from the |  | ||||||
|             # new spec. |  | ||||||
|             for new_index in indices: |  | ||||||
|                 candidate = spec_fields.get(tuple(new_index['fields'])) |  | ||||||
|                 if candidate is None: |  | ||||||
|                     index_specs.append(new_index) |  | ||||||
|                 else: |                 else: | ||||||
|                     candidate.update(new_index) |                     index_specs.append(v) | ||||||
|  |  | ||||||
|             return index_specs |             return index_specs | ||||||
|  |  | ||||||
|         # Merge geo indexes and unique_with indexes into the meta index specs. |  | ||||||
|         index_specs = merge_index_specs(index_specs, geo_indices) |         index_specs = merge_index_specs(index_specs, geo_indices) | ||||||
|         index_specs = merge_index_specs(index_specs, unique_indices) |         index_specs = merge_index_specs(index_specs, unique_indices) | ||||||
|         return index_specs |         return index_specs | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|     def _build_index_spec(cls, spec): |     def _build_index_spec(cls, spec): | ||||||
|         """Build a PyMongo index spec from a MongoEngine index spec.""" |         """Build a PyMongo index spec from a MongoEngine index spec. | ||||||
|         if isinstance(spec, six.string_types): |         """ | ||||||
|  |         if isinstance(spec, basestring): | ||||||
|             spec = {'fields': [spec]} |             spec = {'fields': [spec]} | ||||||
|         elif isinstance(spec, (list, tuple)): |         elif isinstance(spec, (list, tuple)): | ||||||
|             spec = {'fields': list(spec)} |             spec = {'fields': list(spec)} | ||||||
| @@ -775,7 +775,8 @@ class BaseDocument(object): | |||||||
|         direction = None |         direction = None | ||||||
|  |  | ||||||
|         # Check to see if we need to include _cls |         # Check to see if we need to include _cls | ||||||
|         allow_inheritance = cls._meta.get('allow_inheritance') |         allow_inheritance = cls._meta.get('allow_inheritance', | ||||||
|  |                                           ALLOW_INHERITANCE) | ||||||
|         include_cls = ( |         include_cls = ( | ||||||
|             allow_inheritance and |             allow_inheritance and | ||||||
|             not spec.get('sparse', False) and |             not spec.get('sparse', False) and | ||||||
| @@ -785,7 +786,7 @@ class BaseDocument(object): | |||||||
|  |  | ||||||
|         # 733: don't include cls if index_cls is False unless there is an explicit cls with the index |         # 733: don't include cls if index_cls is False unless there is an explicit cls with the index | ||||||
|         include_cls = include_cls and (spec.get('cls', False) or cls._meta.get('index_cls', True)) |         include_cls = include_cls and (spec.get('cls', False) or cls._meta.get('index_cls', True)) | ||||||
|         if 'cls' in spec: |         if "cls" in spec: | ||||||
|             spec.pop('cls') |             spec.pop('cls') | ||||||
|         for key in spec['fields']: |         for key in spec['fields']: | ||||||
|             # If inherited spec continue |             # If inherited spec continue | ||||||
| @@ -800,19 +801,19 @@ class BaseDocument(object): | |||||||
|             # GEOHAYSTACK from ) |             # GEOHAYSTACK from ) | ||||||
|             # GEO2D from * |             # GEO2D from * | ||||||
|             direction = pymongo.ASCENDING |             direction = pymongo.ASCENDING | ||||||
|             if key.startswith('-'): |             if key.startswith("-"): | ||||||
|                 direction = pymongo.DESCENDING |                 direction = pymongo.DESCENDING | ||||||
|             elif key.startswith('$'): |             elif key.startswith("$"): | ||||||
|                 direction = pymongo.TEXT |                 direction = pymongo.TEXT | ||||||
|             elif key.startswith('#'): |             elif key.startswith("#"): | ||||||
|                 direction = pymongo.HASHED |                 direction = pymongo.HASHED | ||||||
|             elif key.startswith('('): |             elif key.startswith("("): | ||||||
|                 direction = pymongo.GEOSPHERE |                 direction = pymongo.GEOSPHERE | ||||||
|             elif key.startswith(')'): |             elif key.startswith(")"): | ||||||
|                 direction = pymongo.GEOHAYSTACK |                 direction = pymongo.GEOHAYSTACK | ||||||
|             elif key.startswith('*'): |             elif key.startswith("*"): | ||||||
|                 direction = pymongo.GEO2D |                 direction = pymongo.GEO2D | ||||||
|             if key.startswith(('+', '-', '*', '$', '#', '(', ')')): |             if key.startswith(("+", "-", "*", "$", "#", "(", ")")): | ||||||
|                 key = key[1:] |                 key = key[1:] | ||||||
|  |  | ||||||
|             # Use real field name, do it manually because we need field |             # Use real field name, do it manually because we need field | ||||||
| @@ -825,7 +826,7 @@ class BaseDocument(object): | |||||||
|                 parts = [] |                 parts = [] | ||||||
|                 for field in fields: |                 for field in fields: | ||||||
|                     try: |                     try: | ||||||
|                         if field != '_id': |                         if field != "_id": | ||||||
|                             field = field.db_field |                             field = field.db_field | ||||||
|                     except AttributeError: |                     except AttributeError: | ||||||
|                         pass |                         pass | ||||||
| @@ -844,53 +845,49 @@ class BaseDocument(object): | |||||||
|         return spec |         return spec | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|     def _unique_with_indexes(cls, namespace=''): |     def _unique_with_indexes(cls, namespace=""): | ||||||
|         """Find unique indexes in the document schema and return them.""" |         """ | ||||||
|  |         Find and set unique indexes | ||||||
|  |         """ | ||||||
|         unique_indexes = [] |         unique_indexes = [] | ||||||
|         for field_name, field in cls._fields.items(): |         for field_name, field in cls._fields.items(): | ||||||
|             sparse = field.sparse |             sparse = field.sparse | ||||||
|  |  | ||||||
|             # Generate a list of indexes needed by uniqueness constraints |             # Generate a list of indexes needed by uniqueness constraints | ||||||
|             if field.unique: |             if field.unique: | ||||||
|                 unique_fields = [field.db_field] |                 unique_fields = [field.db_field] | ||||||
|  |  | ||||||
|                 # Add any unique_with fields to the back of the index spec |                 # Add any unique_with fields to the back of the index spec | ||||||
|                 if field.unique_with: |                 if field.unique_with: | ||||||
|                     if isinstance(field.unique_with, six.string_types): |                     if isinstance(field.unique_with, basestring): | ||||||
|                         field.unique_with = [field.unique_with] |                         field.unique_with = [field.unique_with] | ||||||
|  |  | ||||||
|                     # Convert unique_with field names to real field names |                     # Convert unique_with field names to real field names | ||||||
|                     unique_with = [] |                     unique_with = [] | ||||||
|                     for other_name in field.unique_with: |                     for other_name in field.unique_with: | ||||||
|                         parts = other_name.split('.') |                         parts = other_name.split('.') | ||||||
|  |  | ||||||
|                         # Lookup real name |                         # Lookup real name | ||||||
|                         parts = cls._lookup_field(parts) |                         parts = cls._lookup_field(parts) | ||||||
|                         name_parts = [part.db_field for part in parts] |                         name_parts = [part.db_field for part in parts] | ||||||
|                         unique_with.append('.'.join(name_parts)) |                         unique_with.append('.'.join(name_parts)) | ||||||
|  |  | ||||||
|                         # Unique field should be required |                         # Unique field should be required | ||||||
|                         parts[-1].required = True |                         parts[-1].required = True | ||||||
|                         sparse = (not sparse and |                         sparse = (not sparse and | ||||||
|                                   parts[-1].name not in cls.__dict__) |                                   parts[-1].name not in cls.__dict__) | ||||||
|  |  | ||||||
|                     unique_fields += unique_with |                     unique_fields += unique_with | ||||||
|  |  | ||||||
|                 # Add the new index to the list |                 # Add the new index to the list | ||||||
|                 fields = [ |                 fields = [("%s%s" % (namespace, f), pymongo.ASCENDING) | ||||||
|                     ('%s%s' % (namespace, f), pymongo.ASCENDING) |                           for f in unique_fields] | ||||||
|                     for f in unique_fields |  | ||||||
|                 ] |  | ||||||
|                 index = {'fields': fields, 'unique': True, 'sparse': sparse} |                 index = {'fields': fields, 'unique': True, 'sparse': sparse} | ||||||
|                 unique_indexes.append(index) |                 unique_indexes.append(index) | ||||||
|  |  | ||||||
|             if field.__class__.__name__ == 'ListField': |             if field.__class__.__name__ == "ListField": | ||||||
|                 field = field.field |                 field = field.field | ||||||
|  |  | ||||||
|             # Grab any embedded document field unique indexes |             # Grab any embedded document field unique indexes | ||||||
|             if (field.__class__.__name__ == 'EmbeddedDocumentField' and |             if (field.__class__.__name__ == "EmbeddedDocumentField" and | ||||||
|                     field.document_type != cls): |                     field.document_type != cls): | ||||||
|                 field_namespace = '%s.' % field_name |                 field_namespace = "%s." % field_name | ||||||
|                 doc_cls = field.document_type |                 doc_cls = field.document_type | ||||||
|                 unique_indexes += doc_cls._unique_with_indexes(field_namespace) |                 unique_indexes += doc_cls._unique_with_indexes(field_namespace) | ||||||
|  |  | ||||||
| @@ -902,9 +899,8 @@ class BaseDocument(object): | |||||||
|         geo_indices = [] |         geo_indices = [] | ||||||
|         inspected.append(cls) |         inspected.append(cls) | ||||||
|  |  | ||||||
|         geo_field_type_names = ('EmbeddedDocumentField', 'GeoPointField', |         geo_field_type_names = ["EmbeddedDocumentField", "GeoPointField", | ||||||
|                                 'PointField', 'LineStringField', |                                 "PointField", "LineStringField", "PolygonField"] | ||||||
|                                 'PolygonField') |  | ||||||
|  |  | ||||||
|         geo_field_types = tuple([_import_class(field) |         geo_field_types = tuple([_import_class(field) | ||||||
|                                  for field in geo_field_type_names]) |                                  for field in geo_field_type_names]) | ||||||
| @@ -912,68 +908,32 @@ class BaseDocument(object): | |||||||
|         for field in cls._fields.values(): |         for field in cls._fields.values(): | ||||||
|             if not isinstance(field, geo_field_types): |             if not isinstance(field, geo_field_types): | ||||||
|                 continue |                 continue | ||||||
|  |  | ||||||
|             if hasattr(field, 'document_type'): |             if hasattr(field, 'document_type'): | ||||||
|                 field_cls = field.document_type |                 field_cls = field.document_type | ||||||
|                 if field_cls in inspected: |                 if field_cls in inspected: | ||||||
|                     continue |                     continue | ||||||
|  |  | ||||||
|                 if hasattr(field_cls, '_geo_indices'): |                 if hasattr(field_cls, '_geo_indices'): | ||||||
|                     geo_indices += field_cls._geo_indices( |                     geo_indices += field_cls._geo_indices( | ||||||
|                         inspected, parent_field=field.db_field) |                         inspected, parent_field=field.db_field) | ||||||
|             elif field._geo_index: |             elif field._geo_index: | ||||||
|                 field_name = field.db_field |                 field_name = field.db_field | ||||||
|                 if parent_field: |                 if parent_field: | ||||||
|                     field_name = '%s.%s' % (parent_field, field_name) |                     field_name = "%s.%s" % (parent_field, field_name) | ||||||
|                 geo_indices.append({ |                 geo_indices.append({'fields': | ||||||
|                     'fields': [(field_name, field._geo_index)] |                                     [(field_name, field._geo_index)]}) | ||||||
|                 }) |  | ||||||
|  |  | ||||||
|         return geo_indices |         return geo_indices | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|     def _lookup_field(cls, parts): |     def _lookup_field(cls, parts): | ||||||
|         """Given the path to a given field, return a list containing |         """Lookup a field based on its attribute and return a list containing | ||||||
|         the Field object associated with that field and all of its parent |         the field's parents and the field. | ||||||
|         Field objects. |  | ||||||
|  |  | ||||||
|         Args: |  | ||||||
|             parts (str, list, or tuple) - path to the field. Should be a |  | ||||||
|             string for simple fields existing on this document or a list |  | ||||||
|             of strings for a field that exists deeper in embedded documents. |  | ||||||
|  |  | ||||||
|         Returns: |  | ||||||
|             A list of Field instances for fields that were found or |  | ||||||
|             strings for sub-fields that weren't. |  | ||||||
|  |  | ||||||
|         Example: |  | ||||||
|             >>> user._lookup_field('name') |  | ||||||
|             [<mongoengine.fields.StringField at 0x1119bff50>] |  | ||||||
|  |  | ||||||
|             >>> user._lookup_field('roles') |  | ||||||
|             [<mongoengine.fields.EmbeddedDocumentListField at 0x1119ec250>] |  | ||||||
|  |  | ||||||
|             >>> user._lookup_field(['roles', 'role']) |  | ||||||
|             [<mongoengine.fields.EmbeddedDocumentListField at 0x1119ec250>, |  | ||||||
|              <mongoengine.fields.StringField at 0x1119ec050>] |  | ||||||
|  |  | ||||||
|             >>> user._lookup_field('doesnt_exist') |  | ||||||
|             raises LookUpError |  | ||||||
|  |  | ||||||
|             >>> user._lookup_field(['roles', 'doesnt_exist']) |  | ||||||
|             [<mongoengine.fields.EmbeddedDocumentListField at 0x1119ec250>, |  | ||||||
|              'doesnt_exist'] |  | ||||||
|  |  | ||||||
|         """ |         """ | ||||||
|         # TODO this method is WAY too complicated. Simplify it. |  | ||||||
|         # TODO don't think returning a string for embedded non-existent fields is desired |  | ||||||
|  |  | ||||||
|         ListField = _import_class('ListField') |         ListField = _import_class("ListField") | ||||||
|         DynamicField = _import_class('DynamicField') |         DynamicField = _import_class('DynamicField') | ||||||
|  |  | ||||||
|         if not isinstance(parts, (list, tuple)): |         if not isinstance(parts, (list, tuple)): | ||||||
|             parts = [parts] |             parts = [parts] | ||||||
|  |  | ||||||
|         fields = [] |         fields = [] | ||||||
|         field = None |         field = None | ||||||
|  |  | ||||||
| @@ -983,17 +943,16 @@ class BaseDocument(object): | |||||||
|                 fields.append(field_name) |                 fields.append(field_name) | ||||||
|                 continue |                 continue | ||||||
|  |  | ||||||
|             # Look up first field from the document |  | ||||||
|             if field is None: |             if field is None: | ||||||
|  |                 # Look up first field from the document | ||||||
|                 if field_name == 'pk': |                 if field_name == 'pk': | ||||||
|                     # Deal with "primary key" alias |                     # Deal with "primary key" alias | ||||||
|                     field_name = cls._meta['id_field'] |                     field_name = cls._meta['id_field'] | ||||||
|  |  | ||||||
|                 if field_name in cls._fields: |                 if field_name in cls._fields: | ||||||
|                     field = cls._fields[field_name] |                     field = cls._fields[field_name] | ||||||
|                 elif cls._dynamic: |                 elif cls._dynamic: | ||||||
|                     field = DynamicField(db_field=field_name) |                     field = DynamicField(db_field=field_name) | ||||||
|                 elif cls._meta.get('allow_inheritance') or cls._meta.get('abstract', False): |                 elif cls._meta.get("allow_inheritance", False) or cls._meta.get("abstract", False): | ||||||
|                     # 744: in case the field is defined in a subclass |                     # 744: in case the field is defined in a subclass | ||||||
|                     for subcls in cls.__subclasses__(): |                     for subcls in cls.__subclasses__(): | ||||||
|                         try: |                         try: | ||||||
| @@ -1006,55 +965,35 @@ class BaseDocument(object): | |||||||
|                     else: |                     else: | ||||||
|                         raise LookUpError('Cannot resolve field "%s"' % field_name) |                         raise LookUpError('Cannot resolve field "%s"' % field_name) | ||||||
|                 else: |                 else: | ||||||
|                     raise LookUpError('Cannot resolve field "%s"' % field_name) |                     raise LookUpError('Cannot resolve field "%s"' | ||||||
|  |                                       % field_name) | ||||||
|             else: |             else: | ||||||
|                 ReferenceField = _import_class('ReferenceField') |                 ReferenceField = _import_class('ReferenceField') | ||||||
|                 GenericReferenceField = _import_class('GenericReferenceField') |                 GenericReferenceField = _import_class('GenericReferenceField') | ||||||
|  |  | ||||||
|                 # If previous field was a reference, throw an error (we |  | ||||||
|                 # cannot look up fields that are on references). |  | ||||||
|                 if isinstance(field, (ReferenceField, GenericReferenceField)): |                 if isinstance(field, (ReferenceField, GenericReferenceField)): | ||||||
|                     raise LookUpError('Cannot perform join in mongoDB: %s' % |                     raise LookUpError('Cannot perform join in mongoDB: %s' % | ||||||
|                                       '__'.join(parts)) |                                       '__'.join(parts)) | ||||||
|  |  | ||||||
|                 # If the parent field has a "field" attribute which has a |  | ||||||
|                 # lookup_member method, call it to find the field |  | ||||||
|                 # corresponding to this iteration. |  | ||||||
|                 if hasattr(getattr(field, 'field', None), 'lookup_member'): |                 if hasattr(getattr(field, 'field', None), 'lookup_member'): | ||||||
|                     new_field = field.field.lookup_member(field_name) |                     new_field = field.field.lookup_member(field_name) | ||||||
|  |  | ||||||
|                 # If the parent field is a DynamicField or if it's part of |  | ||||||
|                 # a DynamicDocument, mark current field as a DynamicField |  | ||||||
|                 # with db_name equal to the field name. |  | ||||||
|                 elif cls._dynamic and (isinstance(field, DynamicField) or |                 elif cls._dynamic and (isinstance(field, DynamicField) or | ||||||
|                                        getattr(getattr(field, 'document_type', None), '_dynamic', None)): |                                        getattr(getattr(field, 'document_type', None), '_dynamic', None)): | ||||||
|                     new_field = DynamicField(db_field=field_name) |                     new_field = DynamicField(db_field=field_name) | ||||||
|  |  | ||||||
|                 # Else, try to use the parent field's lookup_member method |  | ||||||
|                 # to find the subfield. |  | ||||||
|                 elif hasattr(field, 'lookup_member'): |  | ||||||
|                     new_field = field.lookup_member(field_name) |  | ||||||
|  |  | ||||||
|                 # Raise a LookUpError if all the other conditions failed. |  | ||||||
|                 else: |                 else: | ||||||
|                     raise LookUpError( |                     # Look up subfield on the previous field or raise | ||||||
|                         'Cannot resolve subfield or operator {} ' |                     try: | ||||||
|                         'on the field {}'.format(field_name, field.name) |                         new_field = field.lookup_member(field_name) | ||||||
|                     ) |                     except AttributeError: | ||||||
|  |                         raise LookUpError('Cannot resolve subfield or operator {} ' | ||||||
|                 # If current field still wasn't found and the parent field |                                           'on the field {}'.format( | ||||||
|                 # is a ComplexBaseField, add the name current field name and |                                               field_name, field.name)) | ||||||
|                 # move on. |  | ||||||
|                 if not new_field and isinstance(field, ComplexBaseField): |                 if not new_field and isinstance(field, ComplexBaseField): | ||||||
|                     fields.append(field_name) |                     fields.append(field_name) | ||||||
|                     continue |                     continue | ||||||
|                 elif not new_field: |                 elif not new_field: | ||||||
|                     raise LookUpError('Cannot resolve field "%s"' % field_name) |                     raise LookUpError('Cannot resolve field "%s"' | ||||||
|  |                                       % field_name) | ||||||
|                 field = new_field  # update field to the new field type |                 field = new_field  # update field to the new field type | ||||||
|  |  | ||||||
|             fields.append(field) |             fields.append(field) | ||||||
|  |  | ||||||
|         return fields |         return fields | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
| @@ -1066,18 +1005,19 @@ class BaseDocument(object): | |||||||
|         return '.'.join(parts) |         return '.'.join(parts) | ||||||
|  |  | ||||||
|     def __set_field_display(self): |     def __set_field_display(self): | ||||||
|         """For each field that specifies choices, create a |         """Dynamically set the display value for a field with choices""" | ||||||
|         get_<field>_display method. |         for attr_name, field in self._fields.items(): | ||||||
|         """ |             if field.choices: | ||||||
|         fields_with_choices = [(n, f) for n, f in self._fields.items() |                 if self._dynamic: | ||||||
|                                if f.choices] |                     obj = self | ||||||
|         for attr_name, field in fields_with_choices: |                 else: | ||||||
|             setattr(self, |                     obj = type(self) | ||||||
|                     'get_%s_display' % attr_name, |                 setattr(obj, | ||||||
|                     partial(self.__get_field_display, field=field)) |                         'get_%s_display' % attr_name, | ||||||
|  |                         partial(self.__get_field_display, field=field)) | ||||||
|  |  | ||||||
|     def __get_field_display(self, field): |     def __get_field_display(self, field): | ||||||
|         """Return the display value for a choice field""" |         """Returns the display value for a choice field""" | ||||||
|         value = getattr(self, field.name) |         value = getattr(self, field.name) | ||||||
|         if field.choices and isinstance(field.choices[0], (list, tuple)): |         if field.choices and isinstance(field.choices[0], (list, tuple)): | ||||||
|             return dict(field.choices).get(value, value) |             return dict(field.choices).get(value, value) | ||||||
|   | |||||||
| @@ -4,17 +4,21 @@ import weakref | |||||||
|  |  | ||||||
| from bson import DBRef, ObjectId, SON | from bson import DBRef, ObjectId, SON | ||||||
| import pymongo | import pymongo | ||||||
| import six |  | ||||||
|  |  | ||||||
| from mongoengine.base.common import UPDATE_OPERATORS | from mongoengine.base.common import ALLOW_INHERITANCE | ||||||
| from mongoengine.base.datastructures import (BaseDict, BaseList, | from mongoengine.base.datastructures import ( | ||||||
|                                              EmbeddedDocumentList) |     BaseDict, BaseList, EmbeddedDocumentList | ||||||
|  | ) | ||||||
| from mongoengine.common import _import_class | from mongoengine.common import _import_class | ||||||
| from mongoengine.errors import ValidationError | from mongoengine.errors import ValidationError | ||||||
|  |  | ||||||
|  | __all__ = ("BaseField", "ComplexBaseField", | ||||||
|  |            "ObjectIdField", "GeoJsonBaseField") | ||||||
|  |  | ||||||
| __all__ = ('BaseField', 'ComplexBaseField', 'ObjectIdField', |  | ||||||
|            'GeoJsonBaseField') | UPDATE_OPERATORS = set(['set', 'unset', 'inc', 'dec', 'pop', 'push', | ||||||
|  |                         'push_all', 'pull', 'pull_all', 'add_to_set', | ||||||
|  |                         'set_on_insert', 'min', 'max']) | ||||||
|  |  | ||||||
|  |  | ||||||
| class BaseField(object): | class BaseField(object): | ||||||
| @@ -23,6 +27,7 @@ class BaseField(object): | |||||||
|  |  | ||||||
|     .. versionchanged:: 0.5 - added verbose and help text |     .. versionchanged:: 0.5 - added verbose and help text | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     name = None |     name = None | ||||||
|     _geo_index = False |     _geo_index = False | ||||||
|     _auto_gen = False  # Call `generate` to generate a value |     _auto_gen = False  # Call `generate` to generate a value | ||||||
| @@ -68,7 +73,7 @@ class BaseField(object): | |||||||
|         self.db_field = (db_field or name) if not primary_key else '_id' |         self.db_field = (db_field or name) if not primary_key else '_id' | ||||||
|  |  | ||||||
|         if name: |         if name: | ||||||
|             msg = 'Field\'s "name" attribute deprecated in favour of "db_field"' |             msg = "Fields' 'name' attribute deprecated in favour of 'db_field'" | ||||||
|             warnings.warn(msg, DeprecationWarning) |             warnings.warn(msg, DeprecationWarning) | ||||||
|         self.required = required or primary_key |         self.required = required or primary_key | ||||||
|         self.default = default |         self.default = default | ||||||
| @@ -84,7 +89,7 @@ class BaseField(object): | |||||||
|         # Detect and report conflicts between metadata and base properties. |         # Detect and report conflicts between metadata and base properties. | ||||||
|         conflicts = set(dir(self)) & set(kwargs) |         conflicts = set(dir(self)) & set(kwargs) | ||||||
|         if conflicts: |         if conflicts: | ||||||
|             raise TypeError('%s already has attribute(s): %s' % ( |             raise TypeError("%s already has attribute(s): %s" % ( | ||||||
|                 self.__class__.__name__, ', '.join(conflicts))) |                 self.__class__.__name__, ', '.join(conflicts))) | ||||||
|  |  | ||||||
|         # Assign metadata to the instance |         # Assign metadata to the instance | ||||||
| @@ -142,21 +147,25 @@ class BaseField(object): | |||||||
|                     v._instance = weakref.proxy(instance) |                     v._instance = weakref.proxy(instance) | ||||||
|         instance._data[self.name] = value |         instance._data[self.name] = value | ||||||
|  |  | ||||||
|     def error(self, message='', errors=None, field_name=None): |     def error(self, message="", errors=None, field_name=None): | ||||||
|         """Raise a ValidationError.""" |         """Raises a ValidationError. | ||||||
|  |         """ | ||||||
|         field_name = field_name if field_name else self.name |         field_name = field_name if field_name else self.name | ||||||
|         raise ValidationError(message, errors=errors, field_name=field_name) |         raise ValidationError(message, errors=errors, field_name=field_name) | ||||||
|  |  | ||||||
|     def to_python(self, value): |     def to_python(self, value): | ||||||
|         """Convert a MongoDB-compatible type to a Python type.""" |         """Convert a MongoDB-compatible type to a Python type. | ||||||
|  |         """ | ||||||
|         return value |         return value | ||||||
|  |  | ||||||
|     def to_mongo(self, value): |     def to_mongo(self, value): | ||||||
|         """Convert a Python type to a MongoDB-compatible type.""" |         """Convert a Python type to a MongoDB-compatible type. | ||||||
|  |         """ | ||||||
|         return self.to_python(value) |         return self.to_python(value) | ||||||
|  |  | ||||||
|     def _to_mongo_safe_call(self, value, use_db_field=True, fields=None): |     def _to_mongo_safe_call(self, value, use_db_field=True, fields=None): | ||||||
|         """Helper method to call to_mongo with proper inputs.""" |         """A helper method to call to_mongo with proper inputs | ||||||
|  |         """ | ||||||
|         f_inputs = self.to_mongo.__code__.co_varnames |         f_inputs = self.to_mongo.__code__.co_varnames | ||||||
|         ex_vars = {} |         ex_vars = {} | ||||||
|         if 'fields' in f_inputs: |         if 'fields' in f_inputs: | ||||||
| @@ -168,13 +177,15 @@ class BaseField(object): | |||||||
|         return self.to_mongo(value, **ex_vars) |         return self.to_mongo(value, **ex_vars) | ||||||
|  |  | ||||||
|     def prepare_query_value(self, op, value): |     def prepare_query_value(self, op, value): | ||||||
|         """Prepare a value that is being used in a query for PyMongo.""" |         """Prepare a value that is being used in a query for PyMongo. | ||||||
|  |         """ | ||||||
|         if op in UPDATE_OPERATORS: |         if op in UPDATE_OPERATORS: | ||||||
|             self.validate(value) |             self.validate(value) | ||||||
|         return value |         return value | ||||||
|  |  | ||||||
|     def validate(self, value, clean=True): |     def validate(self, value, clean=True): | ||||||
|         """Perform validation on a value.""" |         """Perform validation on a value. | ||||||
|  |         """ | ||||||
|         pass |         pass | ||||||
|  |  | ||||||
|     def _validate_choices(self, value): |     def _validate_choices(self, value): | ||||||
| @@ -189,13 +200,11 @@ class BaseField(object): | |||||||
|         if isinstance(value, (Document, EmbeddedDocument)): |         if isinstance(value, (Document, EmbeddedDocument)): | ||||||
|             if not any(isinstance(value, c) for c in choice_list): |             if not any(isinstance(value, c) for c in choice_list): | ||||||
|                 self.error( |                 self.error( | ||||||
|                     'Value must be an instance of %s' % ( |                     'Value must be instance of %s' % unicode(choice_list) | ||||||
|                         six.text_type(choice_list) |  | ||||||
|                     ) |  | ||||||
|                 ) |                 ) | ||||||
|         # Choices which are types other than Documents |         # Choices which are types other than Documents | ||||||
|         elif value not in choice_list: |         elif value not in choice_list: | ||||||
|             self.error('Value must be one of %s' % six.text_type(choice_list)) |             self.error('Value must be one of %s' % unicode(choice_list)) | ||||||
|  |  | ||||||
|     def _validate(self, value, **kwargs): |     def _validate(self, value, **kwargs): | ||||||
|         # Check the Choices Constraint |         # Check the Choices Constraint | ||||||
| @@ -238,7 +247,8 @@ class ComplexBaseField(BaseField): | |||||||
|     field = None |     field = None | ||||||
|  |  | ||||||
|     def __get__(self, instance, owner): |     def __get__(self, instance, owner): | ||||||
|         """Descriptor to automatically dereference references.""" |         """Descriptor to automatically dereference references. | ||||||
|  |         """ | ||||||
|         if instance is None: |         if instance is None: | ||||||
|             # Document class being used rather than a document object |             # Document class being used rather than a document object | ||||||
|             return self |             return self | ||||||
| @@ -250,7 +260,7 @@ class ComplexBaseField(BaseField): | |||||||
|                        (self.field is None or isinstance(self.field, |                        (self.field is None or isinstance(self.field, | ||||||
|                                                          (GenericReferenceField, ReferenceField)))) |                                                          (GenericReferenceField, ReferenceField)))) | ||||||
|  |  | ||||||
|         _dereference = _import_class('DeReference')() |         _dereference = _import_class("DeReference")() | ||||||
|  |  | ||||||
|         self._auto_dereference = instance._fields[self.name]._auto_dereference |         self._auto_dereference = instance._fields[self.name]._auto_dereference | ||||||
|         if instance._initialised and dereference and instance._data.get(self.name): |         if instance._initialised and dereference and instance._data.get(self.name): | ||||||
| @@ -285,8 +295,9 @@ class ComplexBaseField(BaseField): | |||||||
|         return value |         return value | ||||||
|  |  | ||||||
|     def to_python(self, value): |     def to_python(self, value): | ||||||
|         """Convert a MongoDB-compatible type to a Python type.""" |         """Convert a MongoDB-compatible type to a Python type. | ||||||
|         if isinstance(value, six.string_types): |         """ | ||||||
|  |         if isinstance(value, basestring): | ||||||
|             return value |             return value | ||||||
|  |  | ||||||
|         if hasattr(value, 'to_python'): |         if hasattr(value, 'to_python'): | ||||||
| @@ -296,14 +307,14 @@ class ComplexBaseField(BaseField): | |||||||
|         if not hasattr(value, 'items'): |         if not hasattr(value, 'items'): | ||||||
|             try: |             try: | ||||||
|                 is_list = True |                 is_list = True | ||||||
|                 value = {k: v for k, v in enumerate(value)} |                 value = dict([(k, v) for k, v in enumerate(value)]) | ||||||
|             except TypeError:  # Not iterable return the value |             except TypeError:  # Not iterable return the value | ||||||
|                 return value |                 return value | ||||||
|  |  | ||||||
|         if self.field: |         if self.field: | ||||||
|             self.field._auto_dereference = self._auto_dereference |             self.field._auto_dereference = self._auto_dereference | ||||||
|             value_dict = {key: self.field.to_python(item) |             value_dict = dict([(key, self.field.to_python(item)) | ||||||
|                           for key, item in value.items()} |                                for key, item in value.items()]) | ||||||
|         else: |         else: | ||||||
|             Document = _import_class('Document') |             Document = _import_class('Document') | ||||||
|             value_dict = {} |             value_dict = {} | ||||||
| @@ -326,12 +337,13 @@ class ComplexBaseField(BaseField): | |||||||
|         return value_dict |         return value_dict | ||||||
|  |  | ||||||
|     def to_mongo(self, value, use_db_field=True, fields=None): |     def to_mongo(self, value, use_db_field=True, fields=None): | ||||||
|         """Convert a Python type to a MongoDB-compatible type.""" |         """Convert a Python type to a MongoDB-compatible type. | ||||||
|         Document = _import_class('Document') |         """ | ||||||
|         EmbeddedDocument = _import_class('EmbeddedDocument') |         Document = _import_class("Document") | ||||||
|         GenericReferenceField = _import_class('GenericReferenceField') |         EmbeddedDocument = _import_class("EmbeddedDocument") | ||||||
|  |         GenericReferenceField = _import_class("GenericReferenceField") | ||||||
|  |  | ||||||
|         if isinstance(value, six.string_types): |         if isinstance(value, basestring): | ||||||
|             return value |             return value | ||||||
|  |  | ||||||
|         if hasattr(value, 'to_mongo'): |         if hasattr(value, 'to_mongo'): | ||||||
| @@ -348,15 +360,13 @@ class ComplexBaseField(BaseField): | |||||||
|         if not hasattr(value, 'items'): |         if not hasattr(value, 'items'): | ||||||
|             try: |             try: | ||||||
|                 is_list = True |                 is_list = True | ||||||
|                 value = {k: v for k, v in enumerate(value)} |                 value = dict([(k, v) for k, v in enumerate(value)]) | ||||||
|             except TypeError:  # Not iterable return the value |             except TypeError:  # Not iterable return the value | ||||||
|                 return value |                 return value | ||||||
|  |  | ||||||
|         if self.field: |         if self.field: | ||||||
|             value_dict = { |             value_dict = dict([(key, self.field._to_mongo_safe_call(item, use_db_field, fields)) | ||||||
|                 key: self.field._to_mongo_safe_call(item, use_db_field, fields) |                                for key, item in value.iteritems()]) | ||||||
|                 for key, item in value.iteritems() |  | ||||||
|             } |  | ||||||
|         else: |         else: | ||||||
|             value_dict = {} |             value_dict = {} | ||||||
|             for k, v in value.iteritems(): |             for k, v in value.iteritems(): | ||||||
| @@ -370,7 +380,9 @@ class ComplexBaseField(BaseField): | |||||||
|                     # any _cls data so make it a generic reference allows |                     # any _cls data so make it a generic reference allows | ||||||
|                     # us to dereference |                     # us to dereference | ||||||
|                     meta = getattr(v, '_meta', {}) |                     meta = getattr(v, '_meta', {}) | ||||||
|                     allow_inheritance = meta.get('allow_inheritance') |                     allow_inheritance = ( | ||||||
|  |                         meta.get('allow_inheritance', ALLOW_INHERITANCE) | ||||||
|  |                         is True) | ||||||
|                     if not allow_inheritance and not self.field: |                     if not allow_inheritance and not self.field: | ||||||
|                         value_dict[k] = GenericReferenceField().to_mongo(v) |                         value_dict[k] = GenericReferenceField().to_mongo(v) | ||||||
|                     else: |                     else: | ||||||
| @@ -392,7 +404,8 @@ class ComplexBaseField(BaseField): | |||||||
|         return value_dict |         return value_dict | ||||||
|  |  | ||||||
|     def validate(self, value): |     def validate(self, value): | ||||||
|         """If field is provided ensure the value is valid.""" |         """If field is provided ensure the value is valid. | ||||||
|  |         """ | ||||||
|         errors = {} |         errors = {} | ||||||
|         if self.field: |         if self.field: | ||||||
|             if hasattr(value, 'iteritems') or hasattr(value, 'items'): |             if hasattr(value, 'iteritems') or hasattr(value, 'items'): | ||||||
| @@ -402,9 +415,9 @@ class ComplexBaseField(BaseField): | |||||||
|             for k, v in sequence: |             for k, v in sequence: | ||||||
|                 try: |                 try: | ||||||
|                     self.field._validate(v) |                     self.field._validate(v) | ||||||
|                 except ValidationError as error: |                 except ValidationError, error: | ||||||
|                     errors[k] = error.errors or error |                     errors[k] = error.errors or error | ||||||
|                 except (ValueError, AssertionError) as error: |                 except (ValueError, AssertionError), error: | ||||||
|                     errors[k] = error |                     errors[k] = error | ||||||
|  |  | ||||||
|             if errors: |             if errors: | ||||||
| @@ -430,7 +443,8 @@ class ComplexBaseField(BaseField): | |||||||
|  |  | ||||||
|  |  | ||||||
| class ObjectIdField(BaseField): | class ObjectIdField(BaseField): | ||||||
|     """A field wrapper around MongoDB's ObjectIds.""" |     """A field wrapper around MongoDB's ObjectIds. | ||||||
|  |     """ | ||||||
|  |  | ||||||
|     def to_python(self, value): |     def to_python(self, value): | ||||||
|         try: |         try: | ||||||
| @@ -443,10 +457,10 @@ class ObjectIdField(BaseField): | |||||||
|     def to_mongo(self, value): |     def to_mongo(self, value): | ||||||
|         if not isinstance(value, ObjectId): |         if not isinstance(value, ObjectId): | ||||||
|             try: |             try: | ||||||
|                 return ObjectId(six.text_type(value)) |                 return ObjectId(unicode(value)) | ||||||
|             except Exception as e: |             except Exception, e: | ||||||
|                 # e.message attribute has been deprecated since Python 2.6 |                 # e.message attribute has been deprecated since Python 2.6 | ||||||
|                 self.error(six.text_type(e)) |                 self.error(unicode(e)) | ||||||
|         return value |         return value | ||||||
|  |  | ||||||
|     def prepare_query_value(self, op, value): |     def prepare_query_value(self, op, value): | ||||||
| @@ -454,7 +468,7 @@ class ObjectIdField(BaseField): | |||||||
|  |  | ||||||
|     def validate(self, value): |     def validate(self, value): | ||||||
|         try: |         try: | ||||||
|             ObjectId(six.text_type(value)) |             ObjectId(unicode(value)) | ||||||
|         except Exception: |         except Exception: | ||||||
|             self.error('Invalid Object ID') |             self.error('Invalid Object ID') | ||||||
|  |  | ||||||
| @@ -466,20 +480,21 @@ class GeoJsonBaseField(BaseField): | |||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     _geo_index = pymongo.GEOSPHERE |     _geo_index = pymongo.GEOSPHERE | ||||||
|     _type = 'GeoBase' |     _type = "GeoBase" | ||||||
|  |  | ||||||
|     def __init__(self, auto_index=True, *args, **kwargs): |     def __init__(self, auto_index=True, *args, **kwargs): | ||||||
|         """ |         """ | ||||||
|         :param bool auto_index: Automatically create a '2dsphere' index.\ |         :param bool auto_index: Automatically create a "2dsphere" index.\ | ||||||
|             Defaults to `True`. |             Defaults to `True`. | ||||||
|         """ |         """ | ||||||
|         self._name = '%sField' % self._type |         self._name = "%sField" % self._type | ||||||
|         if not auto_index: |         if not auto_index: | ||||||
|             self._geo_index = False |             self._geo_index = False | ||||||
|         super(GeoJsonBaseField, self).__init__(*args, **kwargs) |         super(GeoJsonBaseField, self).__init__(*args, **kwargs) | ||||||
|  |  | ||||||
|     def validate(self, value): |     def validate(self, value): | ||||||
|         """Validate the GeoJson object based on its type.""" |         """Validate the GeoJson object based on its type | ||||||
|  |         """ | ||||||
|         if isinstance(value, dict): |         if isinstance(value, dict): | ||||||
|             if set(value.keys()) == set(['type', 'coordinates']): |             if set(value.keys()) == set(['type', 'coordinates']): | ||||||
|                 if value['type'] != self._type: |                 if value['type'] != self._type: | ||||||
| @@ -494,7 +509,7 @@ class GeoJsonBaseField(BaseField): | |||||||
|             self.error('%s can only accept lists of [x, y]' % self._name) |             self.error('%s can only accept lists of [x, y]' % self._name) | ||||||
|             return |             return | ||||||
|  |  | ||||||
|         validate = getattr(self, '_validate_%s' % self._type.lower()) |         validate = getattr(self, "_validate_%s" % self._type.lower()) | ||||||
|         error = validate(value) |         error = validate(value) | ||||||
|         if error: |         if error: | ||||||
|             self.error(error) |             self.error(error) | ||||||
| @@ -507,7 +522,7 @@ class GeoJsonBaseField(BaseField): | |||||||
|         try: |         try: | ||||||
|             value[0][0][0] |             value[0][0][0] | ||||||
|         except (TypeError, IndexError): |         except (TypeError, IndexError): | ||||||
|             return 'Invalid Polygon must contain at least one valid linestring' |             return "Invalid Polygon must contain at least one valid linestring" | ||||||
|  |  | ||||||
|         errors = [] |         errors = [] | ||||||
|         for val in value: |         for val in value: | ||||||
| @@ -518,12 +533,12 @@ class GeoJsonBaseField(BaseField): | |||||||
|                 errors.append(error) |                 errors.append(error) | ||||||
|         if errors: |         if errors: | ||||||
|             if top_level: |             if top_level: | ||||||
|                 return 'Invalid Polygon:\n%s' % ', '.join(errors) |                 return "Invalid Polygon:\n%s" % ", ".join(errors) | ||||||
|             else: |             else: | ||||||
|                 return '%s' % ', '.join(errors) |                 return "%s" % ", ".join(errors) | ||||||
|  |  | ||||||
|     def _validate_linestring(self, value, top_level=True): |     def _validate_linestring(self, value, top_level=True): | ||||||
|         """Validate a linestring.""" |         """Validates a linestring""" | ||||||
|         if not isinstance(value, (list, tuple)): |         if not isinstance(value, (list, tuple)): | ||||||
|             return 'LineStrings must contain list of coordinate pairs' |             return 'LineStrings must contain list of coordinate pairs' | ||||||
|  |  | ||||||
| @@ -531,7 +546,7 @@ class GeoJsonBaseField(BaseField): | |||||||
|         try: |         try: | ||||||
|             value[0][0] |             value[0][0] | ||||||
|         except (TypeError, IndexError): |         except (TypeError, IndexError): | ||||||
|             return 'Invalid LineString must contain at least one valid point' |             return "Invalid LineString must contain at least one valid point" | ||||||
|  |  | ||||||
|         errors = [] |         errors = [] | ||||||
|         for val in value: |         for val in value: | ||||||
| @@ -540,19 +555,19 @@ class GeoJsonBaseField(BaseField): | |||||||
|                 errors.append(error) |                 errors.append(error) | ||||||
|         if errors: |         if errors: | ||||||
|             if top_level: |             if top_level: | ||||||
|                 return 'Invalid LineString:\n%s' % ', '.join(errors) |                 return "Invalid LineString:\n%s" % ", ".join(errors) | ||||||
|             else: |             else: | ||||||
|                 return '%s' % ', '.join(errors) |                 return "%s" % ", ".join(errors) | ||||||
|  |  | ||||||
|     def _validate_point(self, value): |     def _validate_point(self, value): | ||||||
|         """Validate each set of coords""" |         """Validate each set of coords""" | ||||||
|         if not isinstance(value, (list, tuple)): |         if not isinstance(value, (list, tuple)): | ||||||
|             return 'Points must be a list of coordinate pairs' |             return 'Points must be a list of coordinate pairs' | ||||||
|         elif not len(value) == 2: |         elif not len(value) == 2: | ||||||
|             return 'Value (%s) must be a two-dimensional point' % repr(value) |             return "Value (%s) must be a two-dimensional point" % repr(value) | ||||||
|         elif (not isinstance(value[0], (float, int)) or |         elif (not isinstance(value[0], (float, int)) or | ||||||
|               not isinstance(value[1], (float, int))): |               not isinstance(value[1], (float, int))): | ||||||
|             return 'Both values (%s) in point must be float or int' % repr(value) |             return "Both values (%s) in point must be float or int" % repr(value) | ||||||
|  |  | ||||||
|     def _validate_multipoint(self, value): |     def _validate_multipoint(self, value): | ||||||
|         if not isinstance(value, (list, tuple)): |         if not isinstance(value, (list, tuple)): | ||||||
| @@ -562,7 +577,7 @@ class GeoJsonBaseField(BaseField): | |||||||
|         try: |         try: | ||||||
|             value[0][0] |             value[0][0] | ||||||
|         except (TypeError, IndexError): |         except (TypeError, IndexError): | ||||||
|             return 'Invalid MultiPoint must contain at least one valid point' |             return "Invalid MultiPoint must contain at least one valid point" | ||||||
|  |  | ||||||
|         errors = [] |         errors = [] | ||||||
|         for point in value: |         for point in value: | ||||||
| @@ -571,7 +586,7 @@ class GeoJsonBaseField(BaseField): | |||||||
|                 errors.append(error) |                 errors.append(error) | ||||||
|  |  | ||||||
|         if errors: |         if errors: | ||||||
|             return '%s' % ', '.join(errors) |             return "%s" % ", ".join(errors) | ||||||
|  |  | ||||||
|     def _validate_multilinestring(self, value, top_level=True): |     def _validate_multilinestring(self, value, top_level=True): | ||||||
|         if not isinstance(value, (list, tuple)): |         if not isinstance(value, (list, tuple)): | ||||||
| @@ -581,7 +596,7 @@ class GeoJsonBaseField(BaseField): | |||||||
|         try: |         try: | ||||||
|             value[0][0][0] |             value[0][0][0] | ||||||
|         except (TypeError, IndexError): |         except (TypeError, IndexError): | ||||||
|             return 'Invalid MultiLineString must contain at least one valid linestring' |             return "Invalid MultiLineString must contain at least one valid linestring" | ||||||
|  |  | ||||||
|         errors = [] |         errors = [] | ||||||
|         for linestring in value: |         for linestring in value: | ||||||
| @@ -591,9 +606,9 @@ class GeoJsonBaseField(BaseField): | |||||||
|  |  | ||||||
|         if errors: |         if errors: | ||||||
|             if top_level: |             if top_level: | ||||||
|                 return 'Invalid MultiLineString:\n%s' % ', '.join(errors) |                 return "Invalid MultiLineString:\n%s" % ", ".join(errors) | ||||||
|             else: |             else: | ||||||
|                 return '%s' % ', '.join(errors) |                 return "%s" % ", ".join(errors) | ||||||
|  |  | ||||||
|     def _validate_multipolygon(self, value): |     def _validate_multipolygon(self, value): | ||||||
|         if not isinstance(value, (list, tuple)): |         if not isinstance(value, (list, tuple)): | ||||||
| @@ -603,7 +618,7 @@ class GeoJsonBaseField(BaseField): | |||||||
|         try: |         try: | ||||||
|             value[0][0][0][0] |             value[0][0][0][0] | ||||||
|         except (TypeError, IndexError): |         except (TypeError, IndexError): | ||||||
|             return 'Invalid MultiPolygon must contain at least one valid Polygon' |             return "Invalid MultiPolygon must contain at least one valid Polygon" | ||||||
|  |  | ||||||
|         errors = [] |         errors = [] | ||||||
|         for polygon in value: |         for polygon in value: | ||||||
| @@ -612,9 +627,9 @@ class GeoJsonBaseField(BaseField): | |||||||
|                 errors.append(error) |                 errors.append(error) | ||||||
|  |  | ||||||
|         if errors: |         if errors: | ||||||
|             return 'Invalid MultiPolygon:\n%s' % ', '.join(errors) |             return "Invalid MultiPolygon:\n%s" % ", ".join(errors) | ||||||
|  |  | ||||||
|     def to_mongo(self, value): |     def to_mongo(self, value): | ||||||
|         if isinstance(value, dict): |         if isinstance(value, dict): | ||||||
|             return value |             return value | ||||||
|         return SON([('type', self._type), ('coordinates', value)]) |         return SON([("type", self._type), ("coordinates", value)]) | ||||||
|   | |||||||
| @@ -1,11 +1,10 @@ | |||||||
| import warnings | import warnings | ||||||
|  |  | ||||||
| import six | from mongoengine.base.common import ALLOW_INHERITANCE, _document_registry | ||||||
|  |  | ||||||
| from mongoengine.base.common import _document_registry |  | ||||||
| from mongoengine.base.fields import BaseField, ComplexBaseField, ObjectIdField | from mongoengine.base.fields import BaseField, ComplexBaseField, ObjectIdField | ||||||
| from mongoengine.common import _import_class | from mongoengine.common import _import_class | ||||||
| from mongoengine.errors import InvalidDocumentError | from mongoengine.errors import InvalidDocumentError | ||||||
|  | from mongoengine.python_support import PY3 | ||||||
| from mongoengine.queryset import (DO_NOTHING, DoesNotExist, | from mongoengine.queryset import (DO_NOTHING, DoesNotExist, | ||||||
|                                   MultipleObjectsReturned, |                                   MultipleObjectsReturned, | ||||||
|                                   QuerySetManager) |                                   QuerySetManager) | ||||||
| @@ -46,8 +45,7 @@ class DocumentMetaclass(type): | |||||||
|             attrs['_meta'] = meta |             attrs['_meta'] = meta | ||||||
|             attrs['_meta']['abstract'] = False  # 789: EmbeddedDocument shouldn't inherit abstract |             attrs['_meta']['abstract'] = False  # 789: EmbeddedDocument shouldn't inherit abstract | ||||||
|  |  | ||||||
|         # If allow_inheritance is True, add a "_cls" string field to the attrs |         if attrs['_meta'].get('allow_inheritance', ALLOW_INHERITANCE): | ||||||
|         if attrs['_meta'].get('allow_inheritance'): |  | ||||||
|             StringField = _import_class('StringField') |             StringField = _import_class('StringField') | ||||||
|             attrs['_cls'] = StringField() |             attrs['_cls'] = StringField() | ||||||
|  |  | ||||||
| @@ -89,17 +87,16 @@ class DocumentMetaclass(type): | |||||||
|         # Ensure no duplicate db_fields |         # Ensure no duplicate db_fields | ||||||
|         duplicate_db_fields = [k for k, v in field_names.items() if v > 1] |         duplicate_db_fields = [k for k, v in field_names.items() if v > 1] | ||||||
|         if duplicate_db_fields: |         if duplicate_db_fields: | ||||||
|             msg = ('Multiple db_fields defined for: %s ' % |             msg = ("Multiple db_fields defined for: %s " % | ||||||
|                    ', '.join(duplicate_db_fields)) |                    ", ".join(duplicate_db_fields)) | ||||||
|             raise InvalidDocumentError(msg) |             raise InvalidDocumentError(msg) | ||||||
|  |  | ||||||
|         # Set _fields and db_field maps |         # Set _fields and db_field maps | ||||||
|         attrs['_fields'] = doc_fields |         attrs['_fields'] = doc_fields | ||||||
|         attrs['_db_field_map'] = {k: getattr(v, 'db_field', k) |         attrs['_db_field_map'] = dict([(k, getattr(v, 'db_field', k)) | ||||||
|                                   for k, v in doc_fields.items()} |                                        for k, v in doc_fields.iteritems()]) | ||||||
|         attrs['_reverse_db_field_map'] = { |         attrs['_reverse_db_field_map'] = dict( | ||||||
|             v: k for k, v in attrs['_db_field_map'].items() |             (v, k) for k, v in attrs['_db_field_map'].iteritems()) | ||||||
|         } |  | ||||||
|  |  | ||||||
|         attrs['_fields_ordered'] = tuple(i[1] for i in sorted( |         attrs['_fields_ordered'] = tuple(i[1] for i in sorted( | ||||||
|                                          (v.creation_counter, v.name) |                                          (v.creation_counter, v.name) | ||||||
| @@ -119,8 +116,10 @@ class DocumentMetaclass(type): | |||||||
|             if hasattr(base, '_meta'): |             if hasattr(base, '_meta'): | ||||||
|                 # Warn if allow_inheritance isn't set and prevent |                 # Warn if allow_inheritance isn't set and prevent | ||||||
|                 # inheritance of classes where inheritance is set to False |                 # inheritance of classes where inheritance is set to False | ||||||
|                 allow_inheritance = base._meta.get('allow_inheritance') |                 allow_inheritance = base._meta.get('allow_inheritance', | ||||||
|                 if not allow_inheritance and not base._meta.get('abstract'): |                                                    ALLOW_INHERITANCE) | ||||||
|  |                 if (allow_inheritance is not True and | ||||||
|  |                         not base._meta.get('abstract')): | ||||||
|                     raise ValueError('Document %s may not be subclassed' % |                     raise ValueError('Document %s may not be subclassed' % | ||||||
|                                      base.__name__) |                                      base.__name__) | ||||||
|  |  | ||||||
| @@ -162,7 +161,7 @@ class DocumentMetaclass(type): | |||||||
|         # module continues to use im_func and im_self, so the code below |         # module continues to use im_func and im_self, so the code below | ||||||
|         # copies __func__ into im_func and __self__ into im_self for |         # copies __func__ into im_func and __self__ into im_self for | ||||||
|         # classmethod objects in Document derived classes. |         # classmethod objects in Document derived classes. | ||||||
|         if six.PY3: |         if PY3: | ||||||
|             for val in new_class.__dict__.values(): |             for val in new_class.__dict__.values(): | ||||||
|                 if isinstance(val, classmethod): |                 if isinstance(val, classmethod): | ||||||
|                     f = val.__get__(new_class) |                     f = val.__get__(new_class) | ||||||
| @@ -180,11 +179,11 @@ class DocumentMetaclass(type): | |||||||
|             if isinstance(f, CachedReferenceField): |             if isinstance(f, CachedReferenceField): | ||||||
|  |  | ||||||
|                 if issubclass(new_class, EmbeddedDocument): |                 if issubclass(new_class, EmbeddedDocument): | ||||||
|                     raise InvalidDocumentError('CachedReferenceFields is not ' |                     raise InvalidDocumentError( | ||||||
|                                                'allowed in EmbeddedDocuments') |                         "CachedReferenceFields is not allowed in EmbeddedDocuments") | ||||||
|                 if not f.document_type: |                 if not f.document_type: | ||||||
|                     raise InvalidDocumentError( |                     raise InvalidDocumentError( | ||||||
|                         'Document is not available to sync') |                         "Document is not available to sync") | ||||||
|  |  | ||||||
|                 if f.auto_sync: |                 if f.auto_sync: | ||||||
|                     f.start_listener() |                     f.start_listener() | ||||||
| @@ -196,8 +195,8 @@ class DocumentMetaclass(type): | |||||||
|                                       'reverse_delete_rule', |                                       'reverse_delete_rule', | ||||||
|                                       DO_NOTHING) |                                       DO_NOTHING) | ||||||
|                 if isinstance(f, DictField) and delete_rule != DO_NOTHING: |                 if isinstance(f, DictField) and delete_rule != DO_NOTHING: | ||||||
|                     msg = ('Reverse delete rules are not supported ' |                     msg = ("Reverse delete rules are not supported " | ||||||
|                            'for %s (field: %s)' % |                            "for %s (field: %s)" % | ||||||
|                            (field.__class__.__name__, field.name)) |                            (field.__class__.__name__, field.name)) | ||||||
|                     raise InvalidDocumentError(msg) |                     raise InvalidDocumentError(msg) | ||||||
|  |  | ||||||
| @@ -205,16 +204,16 @@ class DocumentMetaclass(type): | |||||||
|  |  | ||||||
|             if delete_rule != DO_NOTHING: |             if delete_rule != DO_NOTHING: | ||||||
|                 if issubclass(new_class, EmbeddedDocument): |                 if issubclass(new_class, EmbeddedDocument): | ||||||
|                     msg = ('Reverse delete rules are not supported for ' |                     msg = ("Reverse delete rules are not supported for " | ||||||
|                            'EmbeddedDocuments (field: %s)' % field.name) |                            "EmbeddedDocuments (field: %s)" % field.name) | ||||||
|                     raise InvalidDocumentError(msg) |                     raise InvalidDocumentError(msg) | ||||||
|                 f.document_type.register_delete_rule(new_class, |                 f.document_type.register_delete_rule(new_class, | ||||||
|                                                      field.name, delete_rule) |                                                      field.name, delete_rule) | ||||||
|  |  | ||||||
|             if (field.name and hasattr(Document, field.name) and |             if (field.name and hasattr(Document, field.name) and | ||||||
|                     EmbeddedDocument not in new_class.mro()): |                     EmbeddedDocument not in new_class.mro()): | ||||||
|                 msg = ('%s is a document method and not a valid ' |                 msg = ("%s is a document method and not a valid " | ||||||
|                        'field name' % field.name) |                        "field name" % field.name) | ||||||
|                 raise InvalidDocumentError(msg) |                 raise InvalidDocumentError(msg) | ||||||
|  |  | ||||||
|         return new_class |         return new_class | ||||||
| @@ -272,11 +271,6 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): | |||||||
|                 'index_drop_dups': False, |                 'index_drop_dups': False, | ||||||
|                 'index_opts': None, |                 'index_opts': None, | ||||||
|                 'delete_rules': None, |                 'delete_rules': None, | ||||||
|  |  | ||||||
|                 # allow_inheritance can be True, False, and None. True means |  | ||||||
|                 # "allow inheritance", False means "don't allow inheritance", |  | ||||||
|                 # None means "do whatever your parent does, or don't allow |  | ||||||
|                 # inheritance if you're a top-level class". |  | ||||||
|                 'allow_inheritance': None, |                 'allow_inheritance': None, | ||||||
|             } |             } | ||||||
|             attrs['_is_base_cls'] = True |             attrs['_is_base_cls'] = True | ||||||
| @@ -309,7 +303,7 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): | |||||||
|         # If parent wasn't an abstract class |         # If parent wasn't an abstract class | ||||||
|         if (parent_doc_cls and 'collection' in attrs.get('_meta', {}) and |         if (parent_doc_cls and 'collection' in attrs.get('_meta', {}) and | ||||||
|                 not parent_doc_cls._meta.get('abstract', True)): |                 not parent_doc_cls._meta.get('abstract', True)): | ||||||
|             msg = 'Trying to set a collection on a subclass (%s)' % name |             msg = "Trying to set a collection on a subclass (%s)" % name | ||||||
|             warnings.warn(msg, SyntaxWarning) |             warnings.warn(msg, SyntaxWarning) | ||||||
|             del attrs['_meta']['collection'] |             del attrs['_meta']['collection'] | ||||||
|  |  | ||||||
| @@ -317,7 +311,7 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): | |||||||
|         if attrs.get('_is_base_cls') or attrs['_meta'].get('abstract'): |         if attrs.get('_is_base_cls') or attrs['_meta'].get('abstract'): | ||||||
|             if (parent_doc_cls and |             if (parent_doc_cls and | ||||||
|                     not parent_doc_cls._meta.get('abstract', False)): |                     not parent_doc_cls._meta.get('abstract', False)): | ||||||
|                 msg = 'Abstract document cannot have non-abstract base' |                 msg = "Abstract document cannot have non-abstract base" | ||||||
|                 raise ValueError(msg) |                 raise ValueError(msg) | ||||||
|             return super_new(cls, name, bases, attrs) |             return super_new(cls, name, bases, attrs) | ||||||
|  |  | ||||||
| @@ -340,16 +334,12 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): | |||||||
|  |  | ||||||
|         meta.merge(attrs.get('_meta', {}))  # Top level meta |         meta.merge(attrs.get('_meta', {}))  # Top level meta | ||||||
|  |  | ||||||
|         # Only simple classes (i.e. direct subclasses of Document) may set |         # Only simple classes (direct subclasses of Document) | ||||||
|         # allow_inheritance to False. If the base Document allows inheritance, |         # may set allow_inheritance to False | ||||||
|         # none of its subclasses can override allow_inheritance to False. |  | ||||||
|         simple_class = all([b._meta.get('abstract') |         simple_class = all([b._meta.get('abstract') | ||||||
|                             for b in flattened_bases if hasattr(b, '_meta')]) |                             for b in flattened_bases if hasattr(b, '_meta')]) | ||||||
|         if ( |         if (not simple_class and meta['allow_inheritance'] is False and | ||||||
|             not simple_class and |                 not meta['abstract']): | ||||||
|             meta['allow_inheritance'] is False and |  | ||||||
|             not meta['abstract'] |  | ||||||
|         ): |  | ||||||
|             raise ValueError('Only direct subclasses of Document may set ' |             raise ValueError('Only direct subclasses of Document may set ' | ||||||
|                              '"allow_inheritance" to False') |                              '"allow_inheritance" to False') | ||||||
|  |  | ||||||
|   | |||||||
| @@ -34,10 +34,7 @@ def _import_class(cls_name): | |||||||
|     queryset_classes = ('OperationError',) |     queryset_classes = ('OperationError',) | ||||||
|     deref_classes = ('DeReference',) |     deref_classes = ('DeReference',) | ||||||
|  |  | ||||||
|     if cls_name == 'BaseDocument': |     if cls_name in doc_classes: | ||||||
|         from mongoengine.base import document as module |  | ||||||
|         import_classes = ['BaseDocument'] |  | ||||||
|     elif cls_name in doc_classes: |  | ||||||
|         from mongoengine import document as module |         from mongoengine import document as module | ||||||
|         import_classes = doc_classes |         import_classes = doc_classes | ||||||
|     elif cls_name in field_classes: |     elif cls_name in field_classes: | ||||||
|   | |||||||
| @@ -1,9 +1,7 @@ | |||||||
| from pymongo import MongoClient, ReadPreference, uri_parser | from pymongo import MongoClient, ReadPreference, uri_parser | ||||||
| import six | from mongoengine.python_support import (IS_PYMONGO_3, str_types) | ||||||
|  |  | ||||||
| from mongoengine.python_support import IS_PYMONGO_3 | __all__ = ['ConnectionError', 'connect', 'register_connection', | ||||||
|  |  | ||||||
| __all__ = ['MongoEngineConnectionError', 'connect', 'register_connection', |  | ||||||
|            'DEFAULT_CONNECTION_NAME'] |            'DEFAULT_CONNECTION_NAME'] | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -16,10 +14,7 @@ else: | |||||||
|     READ_PREFERENCE = False |     READ_PREFERENCE = False | ||||||
|  |  | ||||||
|  |  | ||||||
| class MongoEngineConnectionError(Exception): | class ConnectionError(Exception): | ||||||
|     """Error raised when the database connection can't be established or |  | ||||||
|     when a connection with a requested alias can't be retrieved. |  | ||||||
|     """ |  | ||||||
|     pass |     pass | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -55,6 +50,8 @@ def register_connection(alias, name=None, host=None, port=None, | |||||||
|  |  | ||||||
|     .. versionchanged:: 0.10.6 - added mongomock support |     .. versionchanged:: 0.10.6 - added mongomock support | ||||||
|     """ |     """ | ||||||
|  |     global _connection_settings | ||||||
|  |  | ||||||
|     conn_settings = { |     conn_settings = { | ||||||
|         'name': name or 'test', |         'name': name or 'test', | ||||||
|         'host': host or 'localhost', |         'host': host or 'localhost', | ||||||
| @@ -66,10 +63,10 @@ def register_connection(alias, name=None, host=None, port=None, | |||||||
|         'authentication_mechanism': authentication_mechanism |         'authentication_mechanism': authentication_mechanism | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     # Handle uri style connections | ||||||
|     conn_host = conn_settings['host'] |     conn_host = conn_settings['host'] | ||||||
|  |     # host can be a list or a string, so if string, force to a list | ||||||
|     # Host can be a list or a string, so if string, force to a list. |     if isinstance(conn_host, str_types): | ||||||
|     if isinstance(conn_host, six.string_types): |  | ||||||
|         conn_host = [conn_host] |         conn_host = [conn_host] | ||||||
|  |  | ||||||
|     resolved_hosts = [] |     resolved_hosts = [] | ||||||
| @@ -96,7 +93,7 @@ def register_connection(alias, name=None, host=None, port=None, | |||||||
|  |  | ||||||
|             uri_options = uri_dict['options'] |             uri_options = uri_dict['options'] | ||||||
|             if 'replicaset' in uri_options: |             if 'replicaset' in uri_options: | ||||||
|                 conn_settings['replicaSet'] = uri_options['replicaset'] |                 conn_settings['replicaSet'] = True | ||||||
|             if 'authsource' in uri_options: |             if 'authsource' in uri_options: | ||||||
|                 conn_settings['authentication_source'] = uri_options['authsource'] |                 conn_settings['authentication_source'] = uri_options['authsource'] | ||||||
|             if 'authmechanism' in uri_options: |             if 'authmechanism' in uri_options: | ||||||
| @@ -114,7 +111,9 @@ def register_connection(alias, name=None, host=None, port=None, | |||||||
|  |  | ||||||
|  |  | ||||||
| def disconnect(alias=DEFAULT_CONNECTION_NAME): | def disconnect(alias=DEFAULT_CONNECTION_NAME): | ||||||
|     """Close the connection with a given alias.""" |     global _connections | ||||||
|  |     global _dbs | ||||||
|  |  | ||||||
|     if alias in _connections: |     if alias in _connections: | ||||||
|         get_connection(alias=alias).close() |         get_connection(alias=alias).close() | ||||||
|         del _connections[alias] |         del _connections[alias] | ||||||
| @@ -123,99 +122,71 @@ def disconnect(alias=DEFAULT_CONNECTION_NAME): | |||||||
|  |  | ||||||
|  |  | ||||||
| def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): | def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): | ||||||
|     """Return a connection with a given alias.""" |     global _connections | ||||||
|  |  | ||||||
|     # Connect to the database if not already connected |     # Connect to the database if not already connected | ||||||
|     if reconnect: |     if reconnect: | ||||||
|         disconnect(alias) |         disconnect(alias) | ||||||
|  |  | ||||||
|     # If the requested alias already exists in the _connections list, return |     if alias not in _connections: | ||||||
|     # it immediately. |         if alias not in _connection_settings: | ||||||
|     if alias in _connections: |  | ||||||
|         return _connections[alias] |  | ||||||
|  |  | ||||||
|     # Validate that the requested alias exists in the _connection_settings. |  | ||||||
|     # Raise MongoEngineConnectionError if it doesn't. |  | ||||||
|     if alias not in _connection_settings: |  | ||||||
|         if alias == DEFAULT_CONNECTION_NAME: |  | ||||||
|             msg = 'You have not defined a default connection' |  | ||||||
|         else: |  | ||||||
|             msg = 'Connection with alias "%s" has not been defined' % alias |             msg = 'Connection with alias "%s" has not been defined' % alias | ||||||
|         raise MongoEngineConnectionError(msg) |             if alias == DEFAULT_CONNECTION_NAME: | ||||||
|  |                 msg = 'You have not defined a default connection' | ||||||
|  |             raise ConnectionError(msg) | ||||||
|  |         conn_settings = _connection_settings[alias].copy() | ||||||
|  |  | ||||||
|     def _clean_settings(settings_dict): |         conn_settings.pop('name', None) | ||||||
|         irrelevant_fields = set([ |         conn_settings.pop('username', None) | ||||||
|             'name', 'username', 'password', 'authentication_source', |         conn_settings.pop('password', None) | ||||||
|             'authentication_mechanism' |         conn_settings.pop('authentication_source', None) | ||||||
|         ]) |         conn_settings.pop('authentication_mechanism', None) | ||||||
|         return { |  | ||||||
|             k: v for k, v in settings_dict.items() |  | ||||||
|             if k not in irrelevant_fields |  | ||||||
|         } |  | ||||||
|  |  | ||||||
|     # Retrieve a copy of the connection settings associated with the requested |         is_mock = conn_settings.pop('is_mock', None) | ||||||
|     # alias and remove the database name and authentication info (we don't |         if is_mock: | ||||||
|     # care about them at this point). |             # Use MongoClient from mongomock | ||||||
|     conn_settings = _clean_settings(_connection_settings[alias].copy()) |             try: | ||||||
|  |                 import mongomock | ||||||
|     # Determine if we should use PyMongo's or mongomock's MongoClient. |             except ImportError: | ||||||
|     is_mock = conn_settings.pop('is_mock', False) |                 raise RuntimeError('You need mongomock installed ' | ||||||
|     if is_mock: |                                    'to mock MongoEngine.') | ||||||
|         try: |             connection_class = mongomock.MongoClient | ||||||
|             import mongomock |         else: | ||||||
|         except ImportError: |             # Use MongoClient from pymongo | ||||||
|             raise RuntimeError('You need mongomock installed to mock ' |             connection_class = MongoClient | ||||||
|                                'MongoEngine.') |  | ||||||
|         connection_class = mongomock.MongoClient |  | ||||||
|     else: |  | ||||||
|         connection_class = MongoClient |  | ||||||
|  |  | ||||||
|         # For replica set connections with PyMongo 2.x, use |  | ||||||
|         # MongoReplicaSetClient. |  | ||||||
|         # TODO remove this once we stop supporting PyMongo 2.x. |  | ||||||
|         if 'replicaSet' in conn_settings and not IS_PYMONGO_3: |  | ||||||
|             connection_class = MongoReplicaSetClient |  | ||||||
|             conn_settings['hosts_or_uri'] = conn_settings.pop('host', None) |  | ||||||
|  |  | ||||||
|             # hosts_or_uri has to be a string, so if 'host' was provided |  | ||||||
|             # as a list, join its parts and separate them by ',' |  | ||||||
|             if isinstance(conn_settings['hosts_or_uri'], list): |  | ||||||
|                 conn_settings['hosts_or_uri'] = ','.join( |  | ||||||
|                     conn_settings['hosts_or_uri']) |  | ||||||
|  |  | ||||||
|  |         if 'replicaSet' in conn_settings: | ||||||
|             # Discard port since it can't be used on MongoReplicaSetClient |             # Discard port since it can't be used on MongoReplicaSetClient | ||||||
|             conn_settings.pop('port', None) |             conn_settings.pop('port', None) | ||||||
|  |             # Discard replicaSet if not base string | ||||||
|  |             if not isinstance(conn_settings['replicaSet'], basestring): | ||||||
|  |                 conn_settings.pop('replicaSet', None) | ||||||
|  |             if not IS_PYMONGO_3: | ||||||
|  |                 connection_class = MongoReplicaSetClient | ||||||
|  |                 conn_settings['hosts_or_uri'] = conn_settings.pop('host', None) | ||||||
|  |  | ||||||
|     # Iterate over all of the connection settings and if a connection with |  | ||||||
|     # the same parameters is already established, use it instead of creating |  | ||||||
|     # a new one. |  | ||||||
|     existing_connection = None |  | ||||||
|     connection_settings_iterator = ( |  | ||||||
|         (db_alias, settings.copy()) |  | ||||||
|         for db_alias, settings in _connection_settings.items() |  | ||||||
|     ) |  | ||||||
|     for db_alias, connection_settings in connection_settings_iterator: |  | ||||||
|         connection_settings = _clean_settings(connection_settings) |  | ||||||
|         if conn_settings == connection_settings and _connections.get(db_alias): |  | ||||||
|             existing_connection = _connections[db_alias] |  | ||||||
|             break |  | ||||||
|  |  | ||||||
|     # If an existing connection was found, assign it to the new alias |  | ||||||
|     if existing_connection: |  | ||||||
|         _connections[alias] = existing_connection |  | ||||||
|     else: |  | ||||||
|         # Otherwise, create the new connection for this alias. Raise |  | ||||||
|         # MongoEngineConnectionError if it can't be established. |  | ||||||
|         try: |         try: | ||||||
|             _connections[alias] = connection_class(**conn_settings) |             connection = None | ||||||
|         except Exception as e: |             # check for shared connections | ||||||
|             raise MongoEngineConnectionError( |             connection_settings_iterator = ( | ||||||
|                 'Cannot connect to database %s :\n%s' % (alias, e)) |                 (db_alias, settings.copy()) for db_alias, settings in _connection_settings.iteritems()) | ||||||
|  |             for db_alias, connection_settings in connection_settings_iterator: | ||||||
|  |                 connection_settings.pop('name', None) | ||||||
|  |                 connection_settings.pop('username', None) | ||||||
|  |                 connection_settings.pop('password', None) | ||||||
|  |                 connection_settings.pop('authentication_source', None) | ||||||
|  |                 connection_settings.pop('authentication_mechanism', None) | ||||||
|  |                 if conn_settings == connection_settings and _connections.get(db_alias, None): | ||||||
|  |                     connection = _connections[db_alias] | ||||||
|  |                     break | ||||||
|  |  | ||||||
|  |             _connections[alias] = connection if connection else connection_class(**conn_settings) | ||||||
|  |         except Exception, e: | ||||||
|  |             raise ConnectionError("Cannot connect to database %s :\n%s" % (alias, e)) | ||||||
|     return _connections[alias] |     return _connections[alias] | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_db(alias=DEFAULT_CONNECTION_NAME, reconnect=False): | def get_db(alias=DEFAULT_CONNECTION_NAME, reconnect=False): | ||||||
|  |     global _dbs | ||||||
|     if reconnect: |     if reconnect: | ||||||
|         disconnect(alias) |         disconnect(alias) | ||||||
|  |  | ||||||
| @@ -246,6 +217,7 @@ def connect(db=None, alias=DEFAULT_CONNECTION_NAME, **kwargs): | |||||||
|  |  | ||||||
|     .. versionchanged:: 0.6 - added multiple database support. |     .. versionchanged:: 0.6 - added multiple database support. | ||||||
|     """ |     """ | ||||||
|  |     global _connections | ||||||
|     if alias not in _connections: |     if alias not in _connections: | ||||||
|         register_connection(alias, db, **kwargs) |         register_connection(alias, db, **kwargs) | ||||||
|  |  | ||||||
|   | |||||||
| @@ -2,12 +2,12 @@ from mongoengine.common import _import_class | |||||||
| from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db | from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db | ||||||
|  |  | ||||||
|  |  | ||||||
| __all__ = ('switch_db', 'switch_collection', 'no_dereference', | __all__ = ("switch_db", "switch_collection", "no_dereference", | ||||||
|            'no_sub_classes', 'query_counter') |            "no_sub_classes", "query_counter") | ||||||
|  |  | ||||||
|  |  | ||||||
| class switch_db(object): | class switch_db(object): | ||||||
|     """switch_db alias context manager. |     """ switch_db alias context manager. | ||||||
|  |  | ||||||
|     Example :: |     Example :: | ||||||
|  |  | ||||||
| @@ -18,14 +18,15 @@ class switch_db(object): | |||||||
|         class Group(Document): |         class Group(Document): | ||||||
|             name = StringField() |             name = StringField() | ||||||
|  |  | ||||||
|         Group(name='test').save()  # Saves in the default db |         Group(name="test").save()  # Saves in the default db | ||||||
|  |  | ||||||
|         with switch_db(Group, 'testdb-1') as Group: |         with switch_db(Group, 'testdb-1') as Group: | ||||||
|             Group(name='hello testdb!').save()  # Saves in testdb-1 |             Group(name="hello testdb!").save()  # Saves in testdb-1 | ||||||
|  |  | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     def __init__(self, cls, db_alias): |     def __init__(self, cls, db_alias): | ||||||
|         """Construct the switch_db context manager |         """ Construct the switch_db context manager | ||||||
|  |  | ||||||
|         :param cls: the class to change the registered db |         :param cls: the class to change the registered db | ||||||
|         :param db_alias: the name of the specific database to use |         :param db_alias: the name of the specific database to use | ||||||
| @@ -33,36 +34,37 @@ class switch_db(object): | |||||||
|         self.cls = cls |         self.cls = cls | ||||||
|         self.collection = cls._get_collection() |         self.collection = cls._get_collection() | ||||||
|         self.db_alias = db_alias |         self.db_alias = db_alias | ||||||
|         self.ori_db_alias = cls._meta.get('db_alias', DEFAULT_CONNECTION_NAME) |         self.ori_db_alias = cls._meta.get("db_alias", DEFAULT_CONNECTION_NAME) | ||||||
|  |  | ||||||
|     def __enter__(self): |     def __enter__(self): | ||||||
|         """Change the db_alias and clear the cached collection.""" |         """ change the db_alias and clear the cached collection """ | ||||||
|         self.cls._meta['db_alias'] = self.db_alias |         self.cls._meta["db_alias"] = self.db_alias | ||||||
|         self.cls._collection = None |         self.cls._collection = None | ||||||
|         return self.cls |         return self.cls | ||||||
|  |  | ||||||
|     def __exit__(self, t, value, traceback): |     def __exit__(self, t, value, traceback): | ||||||
|         """Reset the db_alias and collection.""" |         """ Reset the db_alias and collection """ | ||||||
|         self.cls._meta['db_alias'] = self.ori_db_alias |         self.cls._meta["db_alias"] = self.ori_db_alias | ||||||
|         self.cls._collection = self.collection |         self.cls._collection = self.collection | ||||||
|  |  | ||||||
|  |  | ||||||
| class switch_collection(object): | class switch_collection(object): | ||||||
|     """switch_collection alias context manager. |     """ switch_collection alias context manager. | ||||||
|  |  | ||||||
|     Example :: |     Example :: | ||||||
|  |  | ||||||
|         class Group(Document): |         class Group(Document): | ||||||
|             name = StringField() |             name = StringField() | ||||||
|  |  | ||||||
|         Group(name='test').save()  # Saves in the default db |         Group(name="test").save()  # Saves in the default db | ||||||
|  |  | ||||||
|         with switch_collection(Group, 'group1') as Group: |         with switch_collection(Group, 'group1') as Group: | ||||||
|             Group(name='hello testdb!').save()  # Saves in group1 collection |             Group(name="hello testdb!").save()  # Saves in group1 collection | ||||||
|  |  | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     def __init__(self, cls, collection_name): |     def __init__(self, cls, collection_name): | ||||||
|         """Construct the switch_collection context manager. |         """ Construct the switch_collection context manager | ||||||
|  |  | ||||||
|         :param cls: the class to change the registered db |         :param cls: the class to change the registered db | ||||||
|         :param collection_name: the name of the collection to use |         :param collection_name: the name of the collection to use | ||||||
| @@ -73,7 +75,7 @@ class switch_collection(object): | |||||||
|         self.collection_name = collection_name |         self.collection_name = collection_name | ||||||
|  |  | ||||||
|     def __enter__(self): |     def __enter__(self): | ||||||
|         """Change the _get_collection_name and clear the cached collection.""" |         """ change the _get_collection_name and clear the cached collection """ | ||||||
|  |  | ||||||
|         @classmethod |         @classmethod | ||||||
|         def _get_collection_name(cls): |         def _get_collection_name(cls): | ||||||
| @@ -84,23 +86,24 @@ class switch_collection(object): | |||||||
|         return self.cls |         return self.cls | ||||||
|  |  | ||||||
|     def __exit__(self, t, value, traceback): |     def __exit__(self, t, value, traceback): | ||||||
|         """Reset the collection.""" |         """ Reset the collection """ | ||||||
|         self.cls._collection = self.ori_collection |         self.cls._collection = self.ori_collection | ||||||
|         self.cls._get_collection_name = self.ori_get_collection_name |         self.cls._get_collection_name = self.ori_get_collection_name | ||||||
|  |  | ||||||
|  |  | ||||||
| class no_dereference(object): | class no_dereference(object): | ||||||
|     """no_dereference context manager. |     """ no_dereference context manager. | ||||||
|  |  | ||||||
|     Turns off all dereferencing in Documents for the duration of the context |     Turns off all dereferencing in Documents for the duration of the context | ||||||
|     manager:: |     manager:: | ||||||
|  |  | ||||||
|         with no_dereference(Group) as Group: |         with no_dereference(Group) as Group: | ||||||
|             Group.objects.find() |             Group.objects.find() | ||||||
|  |  | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     def __init__(self, cls): |     def __init__(self, cls): | ||||||
|         """Construct the no_dereference context manager. |         """ Construct the no_dereference context manager. | ||||||
|  |  | ||||||
|         :param cls: the class to turn dereferencing off on |         :param cls: the class to turn dereferencing off on | ||||||
|         """ |         """ | ||||||
| @@ -116,102 +119,103 @@ class no_dereference(object): | |||||||
|                                                ComplexBaseField))] |                                                ComplexBaseField))] | ||||||
|  |  | ||||||
|     def __enter__(self): |     def __enter__(self): | ||||||
|         """Change the objects default and _auto_dereference values.""" |         """ change the objects default and _auto_dereference values""" | ||||||
|         for field in self.deref_fields: |         for field in self.deref_fields: | ||||||
|             self.cls._fields[field]._auto_dereference = False |             self.cls._fields[field]._auto_dereference = False | ||||||
|         return self.cls |         return self.cls | ||||||
|  |  | ||||||
|     def __exit__(self, t, value, traceback): |     def __exit__(self, t, value, traceback): | ||||||
|         """Reset the default and _auto_dereference values.""" |         """ Reset the default and _auto_dereference values""" | ||||||
|         for field in self.deref_fields: |         for field in self.deref_fields: | ||||||
|             self.cls._fields[field]._auto_dereference = True |             self.cls._fields[field]._auto_dereference = True | ||||||
|         return self.cls |         return self.cls | ||||||
|  |  | ||||||
|  |  | ||||||
| class no_sub_classes(object): | class no_sub_classes(object): | ||||||
|     """no_sub_classes context manager. |     """ no_sub_classes context manager. | ||||||
|  |  | ||||||
|     Only returns instances of this class and no sub (inherited) classes:: |     Only returns instances of this class and no sub (inherited) classes:: | ||||||
|  |  | ||||||
|         with no_sub_classes(Group) as Group: |         with no_sub_classes(Group) as Group: | ||||||
|             Group.objects.find() |             Group.objects.find() | ||||||
|  |  | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     def __init__(self, cls): |     def __init__(self, cls): | ||||||
|         """Construct the no_sub_classes context manager. |         """ Construct the no_sub_classes context manager. | ||||||
|  |  | ||||||
|         :param cls: the class to turn querying sub classes on |         :param cls: the class to turn querying sub classes on | ||||||
|         """ |         """ | ||||||
|         self.cls = cls |         self.cls = cls | ||||||
|  |  | ||||||
|     def __enter__(self): |     def __enter__(self): | ||||||
|         """Change the objects default and _auto_dereference values.""" |         """ change the objects default and _auto_dereference values""" | ||||||
|         self.cls._all_subclasses = self.cls._subclasses |         self.cls._all_subclasses = self.cls._subclasses | ||||||
|         self.cls._subclasses = (self.cls,) |         self.cls._subclasses = (self.cls,) | ||||||
|         return self.cls |         return self.cls | ||||||
|  |  | ||||||
|     def __exit__(self, t, value, traceback): |     def __exit__(self, t, value, traceback): | ||||||
|         """Reset the default and _auto_dereference values.""" |         """ Reset the default and _auto_dereference values""" | ||||||
|         self.cls._subclasses = self.cls._all_subclasses |         self.cls._subclasses = self.cls._all_subclasses | ||||||
|         delattr(self.cls, '_all_subclasses') |         delattr(self.cls, '_all_subclasses') | ||||||
|         return self.cls |         return self.cls | ||||||
|  |  | ||||||
|  |  | ||||||
| class query_counter(object): | class query_counter(object): | ||||||
|     """Query_counter context manager to get the number of queries.""" |     """ Query_counter context manager to get the number of queries. """ | ||||||
|  |  | ||||||
|     def __init__(self): |     def __init__(self): | ||||||
|         """Construct the query_counter.""" |         """ Construct the query_counter. """ | ||||||
|         self.counter = 0 |         self.counter = 0 | ||||||
|         self.db = get_db() |         self.db = get_db() | ||||||
|  |  | ||||||
|     def __enter__(self): |     def __enter__(self): | ||||||
|         """On every with block we need to drop the profile collection.""" |         """ On every with block we need to drop the profile collection. """ | ||||||
|         self.db.set_profiling_level(0) |         self.db.set_profiling_level(0) | ||||||
|         self.db.system.profile.drop() |         self.db.system.profile.drop() | ||||||
|         self.db.set_profiling_level(2) |         self.db.set_profiling_level(2) | ||||||
|         return self |         return self | ||||||
|  |  | ||||||
|     def __exit__(self, t, value, traceback): |     def __exit__(self, t, value, traceback): | ||||||
|         """Reset the profiling level.""" |         """ Reset the profiling level. """ | ||||||
|         self.db.set_profiling_level(0) |         self.db.set_profiling_level(0) | ||||||
|  |  | ||||||
|     def __eq__(self, value): |     def __eq__(self, value): | ||||||
|         """== Compare querycounter.""" |         """ == Compare querycounter. """ | ||||||
|         counter = self._get_count() |         counter = self._get_count() | ||||||
|         return value == counter |         return value == counter | ||||||
|  |  | ||||||
|     def __ne__(self, value): |     def __ne__(self, value): | ||||||
|         """!= Compare querycounter.""" |         """ != Compare querycounter. """ | ||||||
|         return not self.__eq__(value) |         return not self.__eq__(value) | ||||||
|  |  | ||||||
|     def __lt__(self, value): |     def __lt__(self, value): | ||||||
|         """< Compare querycounter.""" |         """ < Compare querycounter. """ | ||||||
|         return self._get_count() < value |         return self._get_count() < value | ||||||
|  |  | ||||||
|     def __le__(self, value): |     def __le__(self, value): | ||||||
|         """<= Compare querycounter.""" |         """ <= Compare querycounter. """ | ||||||
|         return self._get_count() <= value |         return self._get_count() <= value | ||||||
|  |  | ||||||
|     def __gt__(self, value): |     def __gt__(self, value): | ||||||
|         """> Compare querycounter.""" |         """ > Compare querycounter. """ | ||||||
|         return self._get_count() > value |         return self._get_count() > value | ||||||
|  |  | ||||||
|     def __ge__(self, value): |     def __ge__(self, value): | ||||||
|         """>= Compare querycounter.""" |         """ >= Compare querycounter. """ | ||||||
|         return self._get_count() >= value |         return self._get_count() >= value | ||||||
|  |  | ||||||
|     def __int__(self): |     def __int__(self): | ||||||
|         """int representation.""" |         """ int representation. """ | ||||||
|         return self._get_count() |         return self._get_count() | ||||||
|  |  | ||||||
|     def __repr__(self): |     def __repr__(self): | ||||||
|         """repr query_counter as the number of queries.""" |         """ repr query_counter as the number of queries. """ | ||||||
|         return u"%s" % self._get_count() |         return u"%s" % self._get_count() | ||||||
|  |  | ||||||
|     def _get_count(self): |     def _get_count(self): | ||||||
|         """Get the number of queries.""" |         """ Get the number of queries. """ | ||||||
|         ignore_query = {'ns': {'$ne': '%s.system.indexes' % self.db.name}} |         ignore_query = {"ns": {"$ne": "%s.system.indexes" % self.db.name}} | ||||||
|         count = self.db.system.profile.find(ignore_query).count() - self.counter |         count = self.db.system.profile.find(ignore_query).count() - self.counter | ||||||
|         self.counter += 1 |         self.counter += 1 | ||||||
|         return count |         return count | ||||||
|   | |||||||
| @@ -1,12 +1,14 @@ | |||||||
| from bson import DBRef, SON | from bson import DBRef, SON | ||||||
| import six |  | ||||||
|  |  | ||||||
| from mongoengine.base import (BaseDict, BaseList, EmbeddedDocumentList, | from .base import ( | ||||||
|                               TopLevelDocumentMetaclass, get_document) |     BaseDict, BaseList, EmbeddedDocumentList, | ||||||
| from mongoengine.connection import get_db |     TopLevelDocumentMetaclass, get_document | ||||||
| from mongoengine.document import Document, EmbeddedDocument | ) | ||||||
| from mongoengine.fields import DictField, ListField, MapField, ReferenceField | from .connection import get_db | ||||||
| from mongoengine.queryset import QuerySet | from .document import Document, EmbeddedDocument | ||||||
|  | from .fields import DictField, ListField, MapField, ReferenceField | ||||||
|  | from .python_support import txt_type | ||||||
|  | from .queryset import QuerySet | ||||||
|  |  | ||||||
|  |  | ||||||
| class DeReference(object): | class DeReference(object): | ||||||
| @@ -23,7 +25,7 @@ class DeReference(object): | |||||||
|             :class:`~mongoengine.base.ComplexBaseField` |             :class:`~mongoengine.base.ComplexBaseField` | ||||||
|         :param get: A boolean determining if being called by __get__ |         :param get: A boolean determining if being called by __get__ | ||||||
|         """ |         """ | ||||||
|         if items is None or isinstance(items, six.string_types): |         if items is None or isinstance(items, basestring): | ||||||
|             return items |             return items | ||||||
|  |  | ||||||
|         # cheapest way to convert a queryset to a list |         # cheapest way to convert a queryset to a list | ||||||
| @@ -66,11 +68,11 @@ class DeReference(object): | |||||||
|  |  | ||||||
|                         items = _get_items(items) |                         items = _get_items(items) | ||||||
|                     else: |                     else: | ||||||
|                         items = { |                         items = dict([ | ||||||
|                             k: (v if isinstance(v, (DBRef, Document)) |                             (k, field.to_python(v)) | ||||||
|                                 else field.to_python(v)) |                             if not isinstance(v, (DBRef, Document)) else (k, v) | ||||||
|                             for k, v in items.iteritems() |                             for k, v in items.iteritems()] | ||||||
|                         } |                         ) | ||||||
|  |  | ||||||
|         self.reference_map = self._find_references(items) |         self.reference_map = self._find_references(items) | ||||||
|         self.object_map = self._fetch_objects(doc_type=doc_type) |         self.object_map = self._fetch_objects(doc_type=doc_type) | ||||||
| @@ -88,14 +90,14 @@ class DeReference(object): | |||||||
|             return reference_map |             return reference_map | ||||||
|  |  | ||||||
|         # Determine the iterator to use |         # Determine the iterator to use | ||||||
|         if isinstance(items, dict): |         if not hasattr(items, 'items'): | ||||||
|             iterator = items.values() |             iterator = enumerate(items) | ||||||
|         else: |         else: | ||||||
|             iterator = items |             iterator = items.iteritems() | ||||||
|  |  | ||||||
|         # Recursively find dbreferences |         # Recursively find dbreferences | ||||||
|         depth += 1 |         depth += 1 | ||||||
|         for item in iterator: |         for k, item in iterator: | ||||||
|             if isinstance(item, (Document, EmbeddedDocument)): |             if isinstance(item, (Document, EmbeddedDocument)): | ||||||
|                 for field_name, field in item._fields.iteritems(): |                 for field_name, field in item._fields.iteritems(): | ||||||
|                     v = item._data.get(field_name, None) |                     v = item._data.get(field_name, None) | ||||||
| @@ -149,7 +151,7 @@ class DeReference(object): | |||||||
|                     references = get_db()[collection].find({'_id': {'$in': refs}}) |                     references = get_db()[collection].find({'_id': {'$in': refs}}) | ||||||
|                     for ref in references: |                     for ref in references: | ||||||
|                         if '_cls' in ref: |                         if '_cls' in ref: | ||||||
|                             doc = get_document(ref['_cls'])._from_son(ref) |                             doc = get_document(ref["_cls"])._from_son(ref) | ||||||
|                         elif doc_type is None: |                         elif doc_type is None: | ||||||
|                             doc = get_document( |                             doc = get_document( | ||||||
|                                 ''.join(x.capitalize() |                                 ''.join(x.capitalize() | ||||||
| @@ -216,7 +218,7 @@ class DeReference(object): | |||||||
|             if k in self.object_map and not is_list: |             if k in self.object_map and not is_list: | ||||||
|                 data[k] = self.object_map[k] |                 data[k] = self.object_map[k] | ||||||
|             elif isinstance(v, (Document, EmbeddedDocument)): |             elif isinstance(v, (Document, EmbeddedDocument)): | ||||||
|                 for field_name in v._fields: |                 for field_name, field in v._fields.iteritems(): | ||||||
|                     v = data[k]._data.get(field_name, None) |                     v = data[k]._data.get(field_name, None) | ||||||
|                     if isinstance(v, DBRef): |                     if isinstance(v, DBRef): | ||||||
|                         data[k]._data[field_name] = self.object_map.get( |                         data[k]._data[field_name] = self.object_map.get( | ||||||
| @@ -225,7 +227,7 @@ class DeReference(object): | |||||||
|                         data[k]._data[field_name] = self.object_map.get( |                         data[k]._data[field_name] = self.object_map.get( | ||||||
|                             (v['_ref'].collection, v['_ref'].id), v) |                             (v['_ref'].collection, v['_ref'].id), v) | ||||||
|                     elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth: |                     elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth: | ||||||
|                         item_name = six.text_type('{0}.{1}.{2}').format(name, k, field_name) |                         item_name = txt_type("{0}.{1}.{2}").format(name, k, field_name) | ||||||
|                         data[k]._data[field_name] = self._attach_objects(v, depth, instance=instance, name=item_name) |                         data[k]._data[field_name] = self._attach_objects(v, depth, instance=instance, name=item_name) | ||||||
|             elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth: |             elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth: | ||||||
|                 item_name = '%s.%s' % (name, k) if name else name |                 item_name = '%s.%s' % (name, k) if name else name | ||||||
|   | |||||||
| @@ -4,12 +4,18 @@ import warnings | |||||||
| from bson.dbref import DBRef | from bson.dbref import DBRef | ||||||
| import pymongo | import pymongo | ||||||
| from pymongo.read_preferences import ReadPreference | from pymongo.read_preferences import ReadPreference | ||||||
| import six |  | ||||||
|  |  | ||||||
| from mongoengine import signals | from mongoengine import signals | ||||||
| from mongoengine.base import (BaseDict, BaseDocument, BaseList, | from mongoengine.base import ( | ||||||
|                               DocumentMetaclass, EmbeddedDocumentList, |     ALLOW_INHERITANCE, | ||||||
|                               TopLevelDocumentMetaclass, get_document) |     BaseDict, | ||||||
|  |     BaseDocument, | ||||||
|  |     BaseList, | ||||||
|  |     DocumentMetaclass, | ||||||
|  |     EmbeddedDocumentList, | ||||||
|  |     TopLevelDocumentMetaclass, | ||||||
|  |     get_document | ||||||
|  | ) | ||||||
| from mongoengine.common import _import_class | from mongoengine.common import _import_class | ||||||
| from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db | from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db | ||||||
| from mongoengine.context_managers import switch_collection, switch_db | from mongoengine.context_managers import switch_collection, switch_db | ||||||
| @@ -25,10 +31,12 @@ __all__ = ('Document', 'EmbeddedDocument', 'DynamicDocument', | |||||||
|  |  | ||||||
|  |  | ||||||
| def includes_cls(fields): | def includes_cls(fields): | ||||||
|     """Helper function used for ensuring and comparing indexes.""" |     """ Helper function used for ensuring and comparing indexes | ||||||
|  |     """ | ||||||
|  |  | ||||||
|     first_field = None |     first_field = None | ||||||
|     if len(fields): |     if len(fields): | ||||||
|         if isinstance(fields[0], six.string_types): |         if isinstance(fields[0], basestring): | ||||||
|             first_field = fields[0] |             first_field = fields[0] | ||||||
|         elif isinstance(fields[0], (list, tuple)) and len(fields[0]): |         elif isinstance(fields[0], (list, tuple)) and len(fields[0]): | ||||||
|             first_field = fields[0][0] |             first_field = fields[0][0] | ||||||
| @@ -49,8 +57,9 @@ class EmbeddedDocument(BaseDocument): | |||||||
|     to create a specialised version of the embedded document that will be |     to create a specialised version of the embedded document that will be | ||||||
|     stored in the same collection. To facilitate this behaviour a `_cls` |     stored in the same collection. To facilitate this behaviour a `_cls` | ||||||
|     field is added to documents (hidden though the MongoEngine interface). |     field is added to documents (hidden though the MongoEngine interface). | ||||||
|     To enable this behaviour set :attr:`allow_inheritance` to ``True`` in the |     To disable this behaviour and remove the dependence on the presence of | ||||||
|     :attr:`meta` dictionary. |     `_cls` set :attr:`allow_inheritance` to ``False`` in the :attr:`meta` | ||||||
|  |     dictionary. | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     __slots__ = ('_instance', ) |     __slots__ = ('_instance', ) | ||||||
| @@ -73,15 +82,6 @@ class EmbeddedDocument(BaseDocument): | |||||||
|     def __ne__(self, other): |     def __ne__(self, other): | ||||||
|         return not self.__eq__(other) |         return not self.__eq__(other) | ||||||
|  |  | ||||||
|     def to_mongo(self, *args, **kwargs): |  | ||||||
|         data = super(EmbeddedDocument, self).to_mongo(*args, **kwargs) |  | ||||||
|  |  | ||||||
|         # remove _id from the SON if it's in it and it's None |  | ||||||
|         if '_id' in data and data['_id'] is None: |  | ||||||
|             del data['_id'] |  | ||||||
|  |  | ||||||
|         return data |  | ||||||
|  |  | ||||||
|     def save(self, *args, **kwargs): |     def save(self, *args, **kwargs): | ||||||
|         self._instance.save(*args, **kwargs) |         self._instance.save(*args, **kwargs) | ||||||
|  |  | ||||||
| @@ -106,8 +106,9 @@ class Document(BaseDocument): | |||||||
|     create a specialised version of the document that will be stored in the |     create a specialised version of the document that will be stored in the | ||||||
|     same collection. To facilitate this behaviour a `_cls` |     same collection. To facilitate this behaviour a `_cls` | ||||||
|     field is added to documents (hidden though the MongoEngine interface). |     field is added to documents (hidden though the MongoEngine interface). | ||||||
|     To enable this behaviourset :attr:`allow_inheritance` to ``True`` in the |     To disable this behaviour and remove the dependence on the presence of | ||||||
|     :attr:`meta` dictionary. |     `_cls` set :attr:`allow_inheritance` to ``False`` in the :attr:`meta` | ||||||
|  |     dictionary. | ||||||
|  |  | ||||||
|     A :class:`~mongoengine.Document` may use a **Capped Collection** by |     A :class:`~mongoengine.Document` may use a **Capped Collection** by | ||||||
|     specifying :attr:`max_documents` and :attr:`max_size` in the :attr:`meta` |     specifying :attr:`max_documents` and :attr:`max_size` in the :attr:`meta` | ||||||
| @@ -148,22 +149,26 @@ class Document(BaseDocument): | |||||||
|  |  | ||||||
|     __slots__ = ('__objects',) |     __slots__ = ('__objects',) | ||||||
|  |  | ||||||
|     @property |     def pk(): | ||||||
|     def pk(self): |         """Primary key alias | ||||||
|         """Get the primary key.""" |         """ | ||||||
|         if 'id_field' not in self._meta: |  | ||||||
|             return None |  | ||||||
|         return getattr(self, self._meta['id_field']) |  | ||||||
|  |  | ||||||
|     @pk.setter |         def fget(self): | ||||||
|     def pk(self, value): |             if 'id_field' not in self._meta: | ||||||
|         """Set the primary key.""" |                 return None | ||||||
|         return setattr(self, self._meta['id_field'], value) |             return getattr(self, self._meta['id_field']) | ||||||
|  |  | ||||||
|  |         def fset(self, value): | ||||||
|  |             return setattr(self, self._meta['id_field'], value) | ||||||
|  |  | ||||||
|  |         return property(fget, fset) | ||||||
|  |  | ||||||
|  |     pk = pk() | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|     def _get_db(cls): |     def _get_db(cls): | ||||||
|         """Some Model using other db_alias""" |         """Some Model using other db_alias""" | ||||||
|         return get_db(cls._meta.get('db_alias', DEFAULT_CONNECTION_NAME)) |         return get_db(cls._meta.get("db_alias", DEFAULT_CONNECTION_NAME)) | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|     def _get_collection(cls): |     def _get_collection(cls): | ||||||
| @@ -206,20 +211,7 @@ class Document(BaseDocument): | |||||||
|                 cls.ensure_indexes() |                 cls.ensure_indexes() | ||||||
|         return cls._collection |         return cls._collection | ||||||
|  |  | ||||||
|     def to_mongo(self, *args, **kwargs): |     def modify(self, query={}, **update): | ||||||
|         data = super(Document, self).to_mongo(*args, **kwargs) |  | ||||||
|  |  | ||||||
|         # If '_id' is None, try and set it from self._data. If that |  | ||||||
|         # doesn't exist either, remote '_id' from the SON completely. |  | ||||||
|         if data['_id'] is None: |  | ||||||
|             if self._data.get('id') is None: |  | ||||||
|                 del data['_id'] |  | ||||||
|             else: |  | ||||||
|                 data['_id'] = self._data['id'] |  | ||||||
|  |  | ||||||
|         return data |  | ||||||
|  |  | ||||||
|     def modify(self, query=None, **update): |  | ||||||
|         """Perform an atomic update of the document in the database and reload |         """Perform an atomic update of the document in the database and reload | ||||||
|         the document object using updated version. |         the document object using updated version. | ||||||
|  |  | ||||||
| @@ -233,19 +225,17 @@ class Document(BaseDocument): | |||||||
|             database matches the query |             database matches the query | ||||||
|         :param update: Django-style update keyword arguments |         :param update: Django-style update keyword arguments | ||||||
|         """ |         """ | ||||||
|         if query is None: |  | ||||||
|             query = {} |  | ||||||
|  |  | ||||||
|         if self.pk is None: |         if self.pk is None: | ||||||
|             raise InvalidDocumentError('The document does not have a primary key.') |             raise InvalidDocumentError("The document does not have a primary key.") | ||||||
|  |  | ||||||
|         id_field = self._meta['id_field'] |         id_field = self._meta["id_field"] | ||||||
|         query = query.copy() if isinstance(query, dict) else query.to_query(self) |         query = query.copy() if isinstance(query, dict) else query.to_query(self) | ||||||
|  |  | ||||||
|         if id_field not in query: |         if id_field not in query: | ||||||
|             query[id_field] = self.pk |             query[id_field] = self.pk | ||||||
|         elif query[id_field] != self.pk: |         elif query[id_field] != self.pk: | ||||||
|             raise InvalidQueryError('Invalid document modify query: it must modify only this document.') |             raise InvalidQueryError("Invalid document modify query: it must modify only this document.") | ||||||
|  |  | ||||||
|         updated = self._qs(**query).modify(new=True, **update) |         updated = self._qs(**query).modify(new=True, **update) | ||||||
|         if updated is None: |         if updated is None: | ||||||
| @@ -313,9 +303,6 @@ class Document(BaseDocument): | |||||||
|         .. versionchanged:: 0.10.7 |         .. versionchanged:: 0.10.7 | ||||||
|             Add signal_kwargs argument |             Add signal_kwargs argument | ||||||
|         """ |         """ | ||||||
|         if self._meta.get('abstract'): |  | ||||||
|             raise InvalidDocumentError('Cannot save an abstract document.') |  | ||||||
|  |  | ||||||
|         signal_kwargs = signal_kwargs or {} |         signal_kwargs = signal_kwargs or {} | ||||||
|         signals.pre_save.send(self.__class__, document=self, **signal_kwargs) |         signals.pre_save.send(self.__class__, document=self, **signal_kwargs) | ||||||
|  |  | ||||||
| @@ -323,7 +310,7 @@ class Document(BaseDocument): | |||||||
|             self.validate(clean=clean) |             self.validate(clean=clean) | ||||||
|  |  | ||||||
|         if write_concern is None: |         if write_concern is None: | ||||||
|             write_concern = {'w': 1} |             write_concern = {"w": 1} | ||||||
|  |  | ||||||
|         doc = self.to_mongo() |         doc = self.to_mongo() | ||||||
|  |  | ||||||
| @@ -360,7 +347,7 @@ class Document(BaseDocument): | |||||||
|                 else: |                 else: | ||||||
|                     select_dict = {} |                     select_dict = {} | ||||||
|                 select_dict['_id'] = object_id |                 select_dict['_id'] = object_id | ||||||
|                 shard_key = self._meta.get('shard_key', tuple()) |                 shard_key = self.__class__._meta.get('shard_key', tuple()) | ||||||
|                 for k in shard_key: |                 for k in shard_key: | ||||||
|                     path = self._lookup_field(k.split('.')) |                     path = self._lookup_field(k.split('.')) | ||||||
|                     actual_key = [p.db_field for p in path] |                     actual_key = [p.db_field for p in path] | ||||||
| @@ -371,7 +358,7 @@ class Document(BaseDocument): | |||||||
|  |  | ||||||
|                 def is_new_object(last_error): |                 def is_new_object(last_error): | ||||||
|                     if last_error is not None: |                     if last_error is not None: | ||||||
|                         updated = last_error.get('updatedExisting') |                         updated = last_error.get("updatedExisting") | ||||||
|                         if updated is not None: |                         if updated is not None: | ||||||
|                             return not updated |                             return not updated | ||||||
|                     return created |                     return created | ||||||
| @@ -379,14 +366,14 @@ class Document(BaseDocument): | |||||||
|                 update_query = {} |                 update_query = {} | ||||||
|  |  | ||||||
|                 if updates: |                 if updates: | ||||||
|                     update_query['$set'] = updates |                     update_query["$set"] = updates | ||||||
|                 if removals: |                 if removals: | ||||||
|                     update_query['$unset'] = removals |                     update_query["$unset"] = removals | ||||||
|                 if updates or removals: |                 if updates or removals: | ||||||
|                     upsert = save_condition is None |                     upsert = save_condition is None | ||||||
|                     last_error = collection.update(select_dict, update_query, |                     last_error = collection.update(select_dict, update_query, | ||||||
|                                                    upsert=upsert, **write_concern) |                                                    upsert=upsert, **write_concern) | ||||||
|                     if not upsert and last_error['n'] == 0: |                     if not upsert and last_error["n"] == 0: | ||||||
|                         raise SaveConditionError('Race condition preventing' |                         raise SaveConditionError('Race condition preventing' | ||||||
|                                                  ' document update detected') |                                                  ' document update detected') | ||||||
|                     created = is_new_object(last_error) |                     created = is_new_object(last_error) | ||||||
| @@ -397,27 +384,26 @@ class Document(BaseDocument): | |||||||
|  |  | ||||||
|             if cascade: |             if cascade: | ||||||
|                 kwargs = { |                 kwargs = { | ||||||
|                     'force_insert': force_insert, |                     "force_insert": force_insert, | ||||||
|                     'validate': validate, |                     "validate": validate, | ||||||
|                     'write_concern': write_concern, |                     "write_concern": write_concern, | ||||||
|                     'cascade': cascade |                     "cascade": cascade | ||||||
|                 } |                 } | ||||||
|                 if cascade_kwargs:  # Allow granular control over cascades |                 if cascade_kwargs:  # Allow granular control over cascades | ||||||
|                     kwargs.update(cascade_kwargs) |                     kwargs.update(cascade_kwargs) | ||||||
|                 kwargs['_refs'] = _refs |                 kwargs['_refs'] = _refs | ||||||
|                 self.cascade_save(**kwargs) |                 self.cascade_save(**kwargs) | ||||||
|         except pymongo.errors.DuplicateKeyError as err: |         except pymongo.errors.DuplicateKeyError, err: | ||||||
|             message = u'Tried to save duplicate unique keys (%s)' |             message = u'Tried to save duplicate unique keys (%s)' | ||||||
|             raise NotUniqueError(message % six.text_type(err)) |             raise NotUniqueError(message % unicode(err)) | ||||||
|         except pymongo.errors.OperationFailure as err: |         except pymongo.errors.OperationFailure, err: | ||||||
|             message = 'Could not save document (%s)' |             message = 'Could not save document (%s)' | ||||||
|             if re.match('^E1100[01] duplicate key', six.text_type(err)): |             if re.match('^E1100[01] duplicate key', unicode(err)): | ||||||
|                 # E11000 - duplicate key error index |                 # E11000 - duplicate key error index | ||||||
|                 # E11001 - duplicate key on update |                 # E11001 - duplicate key on update | ||||||
|                 message = u'Tried to save duplicate unique keys (%s)' |                 message = u'Tried to save duplicate unique keys (%s)' | ||||||
|                 raise NotUniqueError(message % six.text_type(err)) |                 raise NotUniqueError(message % unicode(err)) | ||||||
|             raise OperationError(message % six.text_type(err)) |             raise OperationError(message % unicode(err)) | ||||||
|  |  | ||||||
|         id_field = self._meta['id_field'] |         id_field = self._meta['id_field'] | ||||||
|         if created or id_field not in self._meta.get('shard_key', []): |         if created or id_field not in self._meta.get('shard_key', []): | ||||||
|             self[id_field] = self._fields[id_field].to_python(object_id) |             self[id_field] = self._fields[id_field].to_python(object_id) | ||||||
| @@ -428,11 +414,10 @@ class Document(BaseDocument): | |||||||
|         self._created = False |         self._created = False | ||||||
|         return self |         return self | ||||||
|  |  | ||||||
|     def cascade_save(self, **kwargs): |     def cascade_save(self, *args, **kwargs): | ||||||
|         """Recursively save any references and generic references on the |         """Recursively saves any references / | ||||||
|         document. |            generic references on the document""" | ||||||
|         """ |         _refs = kwargs.get('_refs', []) or [] | ||||||
|         _refs = kwargs.get('_refs') or [] |  | ||||||
|  |  | ||||||
|         ReferenceField = _import_class('ReferenceField') |         ReferenceField = _import_class('ReferenceField') | ||||||
|         GenericReferenceField = _import_class('GenericReferenceField') |         GenericReferenceField = _import_class('GenericReferenceField') | ||||||
| @@ -458,17 +443,16 @@ class Document(BaseDocument): | |||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def _qs(self): |     def _qs(self): | ||||||
|         """Return the queryset to use for updating / reloading / deletions.""" |         """ | ||||||
|  |         Returns the queryset to use for updating / reloading / deletions | ||||||
|  |         """ | ||||||
|         if not hasattr(self, '__objects'): |         if not hasattr(self, '__objects'): | ||||||
|             self.__objects = QuerySet(self, self._get_collection()) |             self.__objects = QuerySet(self, self._get_collection()) | ||||||
|         return self.__objects |         return self.__objects | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def _object_key(self): |     def _object_key(self): | ||||||
|         """Get the query dict that can be used to fetch this object from |         """Dict to identify object in collection | ||||||
|         the database. Most of the time it's a simple PK lookup, but in |  | ||||||
|         case of a sharded collection with a compound shard key, it can |  | ||||||
|         contain a more complex query. |  | ||||||
|         """ |         """ | ||||||
|         select_dict = {'pk': self.pk} |         select_dict = {'pk': self.pk} | ||||||
|         shard_key = self.__class__._meta.get('shard_key', tuple()) |         shard_key = self.__class__._meta.get('shard_key', tuple()) | ||||||
| @@ -491,8 +475,8 @@ class Document(BaseDocument): | |||||||
|         if self.pk is None: |         if self.pk is None: | ||||||
|             if kwargs.get('upsert', False): |             if kwargs.get('upsert', False): | ||||||
|                 query = self.to_mongo() |                 query = self.to_mongo() | ||||||
|                 if '_cls' in query: |                 if "_cls" in query: | ||||||
|                     del query['_cls'] |                     del query["_cls"] | ||||||
|                 return self._qs.filter(**query).update_one(**kwargs) |                 return self._qs.filter(**query).update_one(**kwargs) | ||||||
|             else: |             else: | ||||||
|                 raise OperationError( |                 raise OperationError( | ||||||
| @@ -529,7 +513,7 @@ class Document(BaseDocument): | |||||||
|         try: |         try: | ||||||
|             self._qs.filter( |             self._qs.filter( | ||||||
|                 **self._object_key).delete(write_concern=write_concern, _from_doc_delete=True) |                 **self._object_key).delete(write_concern=write_concern, _from_doc_delete=True) | ||||||
|         except pymongo.errors.OperationFailure as err: |         except pymongo.errors.OperationFailure, err: | ||||||
|             message = u'Could not delete document (%s)' % err.message |             message = u'Could not delete document (%s)' % err.message | ||||||
|             raise OperationError(message) |             raise OperationError(message) | ||||||
|         signals.post_delete.send(self.__class__, document=self, **signal_kwargs) |         signals.post_delete.send(self.__class__, document=self, **signal_kwargs) | ||||||
| @@ -617,12 +601,11 @@ class Document(BaseDocument): | |||||||
|         if fields and isinstance(fields[0], int): |         if fields and isinstance(fields[0], int): | ||||||
|             max_depth = fields[0] |             max_depth = fields[0] | ||||||
|             fields = fields[1:] |             fields = fields[1:] | ||||||
|         elif 'max_depth' in kwargs: |         elif "max_depth" in kwargs: | ||||||
|             max_depth = kwargs['max_depth'] |             max_depth = kwargs["max_depth"] | ||||||
|  |  | ||||||
|         if self.pk is None: |         if self.pk is None: | ||||||
|             raise self.DoesNotExist('Document does not exist') |             raise self.DoesNotExist("Document does not exist") | ||||||
|  |  | ||||||
|         obj = self._qs.read_preference(ReadPreference.PRIMARY).filter( |         obj = self._qs.read_preference(ReadPreference.PRIMARY).filter( | ||||||
|             **self._object_key).only(*fields).limit( |             **self._object_key).only(*fields).limit( | ||||||
|             1).select_related(max_depth=max_depth) |             1).select_related(max_depth=max_depth) | ||||||
| @@ -630,7 +613,7 @@ class Document(BaseDocument): | |||||||
|         if obj: |         if obj: | ||||||
|             obj = obj[0] |             obj = obj[0] | ||||||
|         else: |         else: | ||||||
|             raise self.DoesNotExist('Document does not exist') |             raise self.DoesNotExist("Document does not exist") | ||||||
|  |  | ||||||
|         for field in obj._data: |         for field in obj._data: | ||||||
|             if not fields or field in fields: |             if not fields or field in fields: | ||||||
| @@ -673,7 +656,7 @@ class Document(BaseDocument): | |||||||
|         """Returns an instance of :class:`~bson.dbref.DBRef` useful in |         """Returns an instance of :class:`~bson.dbref.DBRef` useful in | ||||||
|         `__raw__` queries.""" |         `__raw__` queries.""" | ||||||
|         if self.pk is None: |         if self.pk is None: | ||||||
|             msg = 'Only saved documents can have a valid dbref' |             msg = "Only saved documents can have a valid dbref" | ||||||
|             raise OperationError(msg) |             raise OperationError(msg) | ||||||
|         return DBRef(self.__class__._get_collection_name(), self.pk) |         return DBRef(self.__class__._get_collection_name(), self.pk) | ||||||
|  |  | ||||||
| @@ -728,7 +711,7 @@ class Document(BaseDocument): | |||||||
|         fields = index_spec.pop('fields') |         fields = index_spec.pop('fields') | ||||||
|         drop_dups = kwargs.get('drop_dups', False) |         drop_dups = kwargs.get('drop_dups', False) | ||||||
|         if IS_PYMONGO_3 and drop_dups: |         if IS_PYMONGO_3 and drop_dups: | ||||||
|             msg = 'drop_dups is deprecated and is removed when using PyMongo 3+.' |             msg = "drop_dups is deprecated and is removed when using PyMongo 3+." | ||||||
|             warnings.warn(msg, DeprecationWarning) |             warnings.warn(msg, DeprecationWarning) | ||||||
|         elif not IS_PYMONGO_3: |         elif not IS_PYMONGO_3: | ||||||
|             index_spec['drop_dups'] = drop_dups |             index_spec['drop_dups'] = drop_dups | ||||||
| @@ -754,7 +737,7 @@ class Document(BaseDocument): | |||||||
|             will be removed if PyMongo3+ is used |             will be removed if PyMongo3+ is used | ||||||
|         """ |         """ | ||||||
|         if IS_PYMONGO_3 and drop_dups: |         if IS_PYMONGO_3 and drop_dups: | ||||||
|             msg = 'drop_dups is deprecated and is removed when using PyMongo 3+.' |             msg = "drop_dups is deprecated and is removed when using PyMongo 3+." | ||||||
|             warnings.warn(msg, DeprecationWarning) |             warnings.warn(msg, DeprecationWarning) | ||||||
|         elif not IS_PYMONGO_3: |         elif not IS_PYMONGO_3: | ||||||
|             kwargs.update({'drop_dups': drop_dups}) |             kwargs.update({'drop_dups': drop_dups}) | ||||||
| @@ -774,7 +757,7 @@ class Document(BaseDocument): | |||||||
|         index_opts = cls._meta.get('index_opts') or {} |         index_opts = cls._meta.get('index_opts') or {} | ||||||
|         index_cls = cls._meta.get('index_cls', True) |         index_cls = cls._meta.get('index_cls', True) | ||||||
|         if IS_PYMONGO_3 and drop_dups: |         if IS_PYMONGO_3 and drop_dups: | ||||||
|             msg = 'drop_dups is deprecated and is removed when using PyMongo 3+.' |             msg = "drop_dups is deprecated and is removed when using PyMongo 3+." | ||||||
|             warnings.warn(msg, DeprecationWarning) |             warnings.warn(msg, DeprecationWarning) | ||||||
|  |  | ||||||
|         collection = cls._get_collection() |         collection = cls._get_collection() | ||||||
| @@ -812,7 +795,8 @@ class Document(BaseDocument): | |||||||
|  |  | ||||||
|         # If _cls is being used (for polymorphism), it needs an index, |         # If _cls is being used (for polymorphism), it needs an index, | ||||||
|         # only if another index doesn't begin with _cls |         # only if another index doesn't begin with _cls | ||||||
|         if index_cls and not cls_indexed and cls._meta.get('allow_inheritance'): |         if (index_cls and not cls_indexed and | ||||||
|  |                 cls._meta.get('allow_inheritance', ALLOW_INHERITANCE) is True): | ||||||
|  |  | ||||||
|             # we shouldn't pass 'cls' to the collection.ensureIndex options |             # we shouldn't pass 'cls' to the collection.ensureIndex options | ||||||
|             # because of https://jira.mongodb.org/browse/SERVER-769 |             # because of https://jira.mongodb.org/browse/SERVER-769 | ||||||
| @@ -831,6 +815,7 @@ class Document(BaseDocument): | |||||||
|         """ Lists all of the indexes that should be created for given |         """ Lists all of the indexes that should be created for given | ||||||
|         collection. It includes all the indexes from super- and sub-classes. |         collection. It includes all the indexes from super- and sub-classes. | ||||||
|         """ |         """ | ||||||
|  |  | ||||||
|         if cls._meta.get('abstract'): |         if cls._meta.get('abstract'): | ||||||
|             return [] |             return [] | ||||||
|  |  | ||||||
| @@ -881,15 +866,16 @@ class Document(BaseDocument): | |||||||
|         # finish up by appending { '_id': 1 } and { '_cls': 1 }, if needed |         # finish up by appending { '_id': 1 } and { '_cls': 1 }, if needed | ||||||
|         if [(u'_id', 1)] not in indexes: |         if [(u'_id', 1)] not in indexes: | ||||||
|             indexes.append([(u'_id', 1)]) |             indexes.append([(u'_id', 1)]) | ||||||
|         if cls._meta.get('index_cls', True) and cls._meta.get('allow_inheritance'): |         if (cls._meta.get('index_cls', True) and | ||||||
|  |                 cls._meta.get('allow_inheritance', ALLOW_INHERITANCE) is True): | ||||||
|             indexes.append([(u'_cls', 1)]) |             indexes.append([(u'_cls', 1)]) | ||||||
|  |  | ||||||
|         return indexes |         return indexes | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|     def compare_indexes(cls): |     def compare_indexes(cls): | ||||||
|         """ Compares the indexes defined in MongoEngine with the ones |         """ Compares the indexes defined in MongoEngine with the ones existing | ||||||
|         existing in the database. Returns any missing/extra indexes. |         in the database. Returns any missing/extra indexes. | ||||||
|         """ |         """ | ||||||
|  |  | ||||||
|         required = cls.list_indexes() |         required = cls.list_indexes() | ||||||
| @@ -933,9 +919,8 @@ class DynamicDocument(Document): | |||||||
|     _dynamic = True |     _dynamic = True | ||||||
|  |  | ||||||
|     def __delattr__(self, *args, **kwargs): |     def __delattr__(self, *args, **kwargs): | ||||||
|         """Delete the attribute by setting to None and allowing _delta |         """Deletes the attribute by setting to None and allowing _delta to unset | ||||||
|         to unset it. |         it""" | ||||||
|         """ |  | ||||||
|         field_name = args[0] |         field_name = args[0] | ||||||
|         if field_name in self._dynamic_fields: |         if field_name in self._dynamic_fields: | ||||||
|             setattr(self, field_name, None) |             setattr(self, field_name, None) | ||||||
| @@ -957,9 +942,8 @@ class DynamicEmbeddedDocument(EmbeddedDocument): | |||||||
|     _dynamic = True |     _dynamic = True | ||||||
|  |  | ||||||
|     def __delattr__(self, *args, **kwargs): |     def __delattr__(self, *args, **kwargs): | ||||||
|         """Delete the attribute by setting to None and allowing _delta |         """Deletes the attribute by setting to None and allowing _delta to unset | ||||||
|         to unset it. |         it""" | ||||||
|         """ |  | ||||||
|         field_name = args[0] |         field_name = args[0] | ||||||
|         if field_name in self._fields: |         if field_name in self._fields: | ||||||
|             default = self._fields[field_name].default |             default = self._fields[field_name].default | ||||||
| @@ -1001,10 +985,10 @@ class MapReduceDocument(object): | |||||||
|             try: |             try: | ||||||
|                 self.key = id_field_type(self.key) |                 self.key = id_field_type(self.key) | ||||||
|             except Exception: |             except Exception: | ||||||
|                 raise Exception('Could not cast key as %s' % |                 raise Exception("Could not cast key as %s" % | ||||||
|                                 id_field_type.__name__) |                                 id_field_type.__name__) | ||||||
|  |  | ||||||
|         if not hasattr(self, '_key_object'): |         if not hasattr(self, "_key_object"): | ||||||
|             self._key_object = self._document.objects.with_id(self.key) |             self._key_object = self._document.objects.with_id(self.key) | ||||||
|             return self._key_object |             return self._key_object | ||||||
|         return self._key_object |         return self._key_object | ||||||
|   | |||||||
| @@ -1,6 +1,7 @@ | |||||||
| from collections import defaultdict | from collections import defaultdict | ||||||
|  |  | ||||||
| import six | from mongoengine.python_support import txt_type | ||||||
|  |  | ||||||
|  |  | ||||||
| __all__ = ('NotRegistered', 'InvalidDocumentError', 'LookUpError', | __all__ = ('NotRegistered', 'InvalidDocumentError', 'LookUpError', | ||||||
|            'DoesNotExist', 'MultipleObjectsReturned', 'InvalidQueryError', |            'DoesNotExist', 'MultipleObjectsReturned', 'InvalidQueryError', | ||||||
| @@ -70,13 +71,13 @@ class ValidationError(AssertionError): | |||||||
|     field_name = None |     field_name = None | ||||||
|     _message = None |     _message = None | ||||||
|  |  | ||||||
|     def __init__(self, message='', **kwargs): |     def __init__(self, message="", **kwargs): | ||||||
|         self.errors = kwargs.get('errors', {}) |         self.errors = kwargs.get('errors', {}) | ||||||
|         self.field_name = kwargs.get('field_name') |         self.field_name = kwargs.get('field_name') | ||||||
|         self.message = message |         self.message = message | ||||||
|  |  | ||||||
|     def __str__(self): |     def __str__(self): | ||||||
|         return six.text_type(self.message) |         return txt_type(self.message) | ||||||
|  |  | ||||||
|     def __repr__(self): |     def __repr__(self): | ||||||
|         return '%s(%s,)' % (self.__class__.__name__, self.message) |         return '%s(%s,)' % (self.__class__.__name__, self.message) | ||||||
| @@ -110,20 +111,17 @@ class ValidationError(AssertionError): | |||||||
|             errors_dict = {} |             errors_dict = {} | ||||||
|             if not source: |             if not source: | ||||||
|                 return errors_dict |                 return errors_dict | ||||||
|  |  | ||||||
|             if isinstance(source, dict): |             if isinstance(source, dict): | ||||||
|                 for field_name, error in source.iteritems(): |                 for field_name, error in source.iteritems(): | ||||||
|                     errors_dict[field_name] = build_dict(error) |                     errors_dict[field_name] = build_dict(error) | ||||||
|             elif isinstance(source, ValidationError) and source.errors: |             elif isinstance(source, ValidationError) and source.errors: | ||||||
|                 return build_dict(source.errors) |                 return build_dict(source.errors) | ||||||
|             else: |             else: | ||||||
|                 return six.text_type(source) |                 return unicode(source) | ||||||
|  |  | ||||||
|             return errors_dict |             return errors_dict | ||||||
|  |  | ||||||
|         if not self.errors: |         if not self.errors: | ||||||
|             return {} |             return {} | ||||||
|  |  | ||||||
|         return build_dict(self.errors) |         return build_dict(self.errors) | ||||||
|  |  | ||||||
|     def _format_errors(self): |     def _format_errors(self): | ||||||
| @@ -136,10 +134,10 @@ class ValidationError(AssertionError): | |||||||
|                 value = ' '.join( |                 value = ' '.join( | ||||||
|                     [generate_key(v, k) for k, v in value.iteritems()]) |                     [generate_key(v, k) for k, v in value.iteritems()]) | ||||||
|  |  | ||||||
|             results = '%s.%s' % (prefix, value) if prefix else value |             results = "%s.%s" % (prefix, value) if prefix else value | ||||||
|             return results |             return results | ||||||
|  |  | ||||||
|         error_dict = defaultdict(list) |         error_dict = defaultdict(list) | ||||||
|         for k, v in self.to_dict().iteritems(): |         for k, v in self.to_dict().iteritems(): | ||||||
|             error_dict[generate_key(v)].append(k) |             error_dict[generate_key(v)].append(k) | ||||||
|         return ' '.join(['%s: %s' % (k, v) for k, v in error_dict.iteritems()]) |         return ' '.join(["%s: %s" % (k, v) for k, v in error_dict.iteritems()]) | ||||||
|   | |||||||
| @@ -3,6 +3,7 @@ import decimal | |||||||
| import itertools | import itertools | ||||||
| import re | import re | ||||||
| import time | import time | ||||||
|  | import urllib2 | ||||||
| import uuid | import uuid | ||||||
| import warnings | import warnings | ||||||
| from operator import itemgetter | from operator import itemgetter | ||||||
| @@ -24,13 +25,13 @@ try: | |||||||
| except ImportError: | except ImportError: | ||||||
|     Int64 = long |     Int64 = long | ||||||
|  |  | ||||||
| from mongoengine.base import (BaseDocument, BaseField, ComplexBaseField, | from .base import (BaseDocument, BaseField, ComplexBaseField, GeoJsonBaseField, | ||||||
|                               GeoJsonBaseField, ObjectIdField, get_document) |                    ObjectIdField, get_document) | ||||||
| from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db | from .connection import DEFAULT_CONNECTION_NAME, get_db | ||||||
| from mongoengine.document import Document, EmbeddedDocument | from .document import Document, EmbeddedDocument | ||||||
| from mongoengine.errors import DoesNotExist, InvalidQueryError, ValidationError | from .errors import DoesNotExist, ValidationError | ||||||
| from mongoengine.python_support import StringIO | from .python_support import PY3, StringIO, bin_type, str_types, txt_type | ||||||
| from mongoengine.queryset import DO_NOTHING, QuerySet | from .queryset import DO_NOTHING, QuerySet | ||||||
|  |  | ||||||
| try: | try: | ||||||
|     from PIL import Image, ImageOps |     from PIL import Image, ImageOps | ||||||
| @@ -38,7 +39,7 @@ except ImportError: | |||||||
|     Image = None |     Image = None | ||||||
|     ImageOps = None |     ImageOps = None | ||||||
|  |  | ||||||
| __all__ = ( | __all__ = [ | ||||||
|     'StringField', 'URLField', 'EmailField', 'IntField', 'LongField', |     'StringField', 'URLField', 'EmailField', 'IntField', 'LongField', | ||||||
|     'FloatField', 'DecimalField', 'BooleanField', 'DateTimeField', |     'FloatField', 'DecimalField', 'BooleanField', 'DateTimeField', | ||||||
|     'ComplexDateTimeField', 'EmbeddedDocumentField', 'ObjectIdField', |     'ComplexDateTimeField', 'EmbeddedDocumentField', 'ObjectIdField', | ||||||
| @@ -49,14 +50,14 @@ __all__ = ( | |||||||
|     'FileField', 'ImageGridFsProxy', 'ImproperlyConfigured', 'ImageField', |     'FileField', 'ImageGridFsProxy', 'ImproperlyConfigured', 'ImageField', | ||||||
|     'GeoPointField', 'PointField', 'LineStringField', 'PolygonField', |     'GeoPointField', 'PointField', 'LineStringField', 'PolygonField', | ||||||
|     'SequenceField', 'UUIDField', 'MultiPointField', 'MultiLineStringField', |     'SequenceField', 'UUIDField', 'MultiPointField', 'MultiLineStringField', | ||||||
|     'MultiPolygonField', 'GeoJsonBaseField' |     'MultiPolygonField', 'GeoJsonBaseField'] | ||||||
| ) |  | ||||||
|  |  | ||||||
| RECURSIVE_REFERENCE_CONSTANT = 'self' | RECURSIVE_REFERENCE_CONSTANT = 'self' | ||||||
|  |  | ||||||
|  |  | ||||||
| class StringField(BaseField): | class StringField(BaseField): | ||||||
|     """A unicode string field.""" |     """A unicode string field. | ||||||
|  |     """ | ||||||
|  |  | ||||||
|     def __init__(self, regex=None, max_length=None, min_length=None, **kwargs): |     def __init__(self, regex=None, max_length=None, min_length=None, **kwargs): | ||||||
|         self.regex = re.compile(regex) if regex else None |         self.regex = re.compile(regex) if regex else None | ||||||
| @@ -65,7 +66,7 @@ class StringField(BaseField): | |||||||
|         super(StringField, self).__init__(**kwargs) |         super(StringField, self).__init__(**kwargs) | ||||||
|  |  | ||||||
|     def to_python(self, value): |     def to_python(self, value): | ||||||
|         if isinstance(value, six.text_type): |         if isinstance(value, unicode): | ||||||
|             return value |             return value | ||||||
|         try: |         try: | ||||||
|             value = value.decode('utf-8') |             value = value.decode('utf-8') | ||||||
| @@ -74,7 +75,7 @@ class StringField(BaseField): | |||||||
|         return value |         return value | ||||||
|  |  | ||||||
|     def validate(self, value): |     def validate(self, value): | ||||||
|         if not isinstance(value, six.string_types): |         if not isinstance(value, basestring): | ||||||
|             self.error('StringField only accepts string values') |             self.error('StringField only accepts string values') | ||||||
|  |  | ||||||
|         if self.max_length is not None and len(value) > self.max_length: |         if self.max_length is not None and len(value) > self.max_length: | ||||||
| @@ -90,7 +91,7 @@ class StringField(BaseField): | |||||||
|         return None |         return None | ||||||
|  |  | ||||||
|     def prepare_query_value(self, op, value): |     def prepare_query_value(self, op, value): | ||||||
|         if not isinstance(op, six.string_types): |         if not isinstance(op, basestring): | ||||||
|             return value |             return value | ||||||
|  |  | ||||||
|         if op.lstrip('i') in ('startswith', 'endswith', 'contains', 'exact'): |         if op.lstrip('i') in ('startswith', 'endswith', 'contains', 'exact'): | ||||||
| @@ -147,6 +148,17 @@ class URLField(StringField): | |||||||
|             self.error('Invalid URL: {}'.format(value)) |             self.error('Invalid URL: {}'.format(value)) | ||||||
|             return |             return | ||||||
|  |  | ||||||
|  |         if self.verify_exists: | ||||||
|  |             warnings.warn( | ||||||
|  |                 "The URLField verify_exists argument has intractable security " | ||||||
|  |                 "and performance issues. Accordingly, it has been deprecated.", | ||||||
|  |                 DeprecationWarning) | ||||||
|  |             try: | ||||||
|  |                 request = urllib2.Request(value) | ||||||
|  |                 urllib2.urlopen(request) | ||||||
|  |             except Exception, e: | ||||||
|  |                 self.error('This URL appears to be a broken link: %s' % e) | ||||||
|  |  | ||||||
|  |  | ||||||
| class EmailField(StringField): | class EmailField(StringField): | ||||||
|     """A field that validates input as an email address. |     """A field that validates input as an email address. | ||||||
| @@ -170,7 +182,8 @@ class EmailField(StringField): | |||||||
|  |  | ||||||
|  |  | ||||||
| class IntField(BaseField): | class IntField(BaseField): | ||||||
|     """32-bit integer field.""" |     """An 32-bit integer field. | ||||||
|  |     """ | ||||||
|  |  | ||||||
|     def __init__(self, min_value=None, max_value=None, **kwargs): |     def __init__(self, min_value=None, max_value=None, **kwargs): | ||||||
|         self.min_value, self.max_value = min_value, max_value |         self.min_value, self.max_value = min_value, max_value | ||||||
| @@ -203,7 +216,8 @@ class IntField(BaseField): | |||||||
|  |  | ||||||
|  |  | ||||||
| class LongField(BaseField): | class LongField(BaseField): | ||||||
|     """64-bit integer field.""" |     """An 64-bit integer field. | ||||||
|  |     """ | ||||||
|  |  | ||||||
|     def __init__(self, min_value=None, max_value=None, **kwargs): |     def __init__(self, min_value=None, max_value=None, **kwargs): | ||||||
|         self.min_value, self.max_value = min_value, max_value |         self.min_value, self.max_value = min_value, max_value | ||||||
| @@ -239,7 +253,8 @@ class LongField(BaseField): | |||||||
|  |  | ||||||
|  |  | ||||||
| class FloatField(BaseField): | class FloatField(BaseField): | ||||||
|     """Floating point number field.""" |     """An floating point number field. | ||||||
|  |     """ | ||||||
|  |  | ||||||
|     def __init__(self, min_value=None, max_value=None, **kwargs): |     def __init__(self, min_value=None, max_value=None, **kwargs): | ||||||
|         self.min_value, self.max_value = min_value, max_value |         self.min_value, self.max_value = min_value, max_value | ||||||
| @@ -276,7 +291,7 @@ class FloatField(BaseField): | |||||||
|  |  | ||||||
|  |  | ||||||
| class DecimalField(BaseField): | class DecimalField(BaseField): | ||||||
|     """Fixed-point decimal number field. |     """A fixed-point decimal number field. | ||||||
|  |  | ||||||
|     .. versionchanged:: 0.8 |     .. versionchanged:: 0.8 | ||||||
|     .. versionadded:: 0.3 |     .. versionadded:: 0.3 | ||||||
| @@ -317,25 +332,25 @@ class DecimalField(BaseField): | |||||||
|  |  | ||||||
|         # Convert to string for python 2.6 before casting to Decimal |         # Convert to string for python 2.6 before casting to Decimal | ||||||
|         try: |         try: | ||||||
|             value = decimal.Decimal('%s' % value) |             value = decimal.Decimal("%s" % value) | ||||||
|         except decimal.InvalidOperation: |         except decimal.InvalidOperation: | ||||||
|             return value |             return value | ||||||
|         return value.quantize(decimal.Decimal('.%s' % ('0' * self.precision)), rounding=self.rounding) |         return value.quantize(decimal.Decimal(".%s" % ("0" * self.precision)), rounding=self.rounding) | ||||||
|  |  | ||||||
|     def to_mongo(self, value): |     def to_mongo(self, value): | ||||||
|         if value is None: |         if value is None: | ||||||
|             return value |             return value | ||||||
|         if self.force_string: |         if self.force_string: | ||||||
|             return six.text_type(self.to_python(value)) |             return unicode(value) | ||||||
|         return float(self.to_python(value)) |         return float(self.to_python(value)) | ||||||
|  |  | ||||||
|     def validate(self, value): |     def validate(self, value): | ||||||
|         if not isinstance(value, decimal.Decimal): |         if not isinstance(value, decimal.Decimal): | ||||||
|             if not isinstance(value, six.string_types): |             if not isinstance(value, basestring): | ||||||
|                 value = six.text_type(value) |                 value = unicode(value) | ||||||
|             try: |             try: | ||||||
|                 value = decimal.Decimal(value) |                 value = decimal.Decimal(value) | ||||||
|             except Exception as exc: |             except Exception, exc: | ||||||
|                 self.error('Could not convert value to decimal: %s' % exc) |                 self.error('Could not convert value to decimal: %s' % exc) | ||||||
|  |  | ||||||
|         if self.min_value is not None and value < self.min_value: |         if self.min_value is not None and value < self.min_value: | ||||||
| @@ -349,7 +364,7 @@ class DecimalField(BaseField): | |||||||
|  |  | ||||||
|  |  | ||||||
| class BooleanField(BaseField): | class BooleanField(BaseField): | ||||||
|     """Boolean field type. |     """A boolean field type. | ||||||
|  |  | ||||||
|     .. versionadded:: 0.1.2 |     .. versionadded:: 0.1.2 | ||||||
|     """ |     """ | ||||||
| @@ -367,7 +382,7 @@ class BooleanField(BaseField): | |||||||
|  |  | ||||||
|  |  | ||||||
| class DateTimeField(BaseField): | class DateTimeField(BaseField): | ||||||
|     """Datetime field. |     """A datetime field. | ||||||
|  |  | ||||||
|     Uses the python-dateutil library if available alternatively use time.strptime |     Uses the python-dateutil library if available alternatively use time.strptime | ||||||
|     to parse the dates.  Note: python-dateutil's parser is fully featured and when |     to parse the dates.  Note: python-dateutil's parser is fully featured and when | ||||||
| @@ -395,7 +410,7 @@ class DateTimeField(BaseField): | |||||||
|         if callable(value): |         if callable(value): | ||||||
|             return value() |             return value() | ||||||
|  |  | ||||||
|         if not isinstance(value, six.string_types): |         if not isinstance(value, basestring): | ||||||
|             return None |             return None | ||||||
|  |  | ||||||
|         # Attempt to parse a datetime: |         # Attempt to parse a datetime: | ||||||
| @@ -522,19 +537,16 @@ class EmbeddedDocumentField(BaseField): | |||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     def __init__(self, document_type, **kwargs): |     def __init__(self, document_type, **kwargs): | ||||||
|         if ( |         if not isinstance(document_type, basestring): | ||||||
|             not isinstance(document_type, six.string_types) and |             if not issubclass(document_type, EmbeddedDocument): | ||||||
|             not issubclass(document_type, EmbeddedDocument) |                 self.error('Invalid embedded document class provided to an ' | ||||||
|         ): |                            'EmbeddedDocumentField') | ||||||
|             self.error('Invalid embedded document class provided to an ' |  | ||||||
|                        'EmbeddedDocumentField') |  | ||||||
|  |  | ||||||
|         self.document_type_obj = document_type |         self.document_type_obj = document_type | ||||||
|         super(EmbeddedDocumentField, self).__init__(**kwargs) |         super(EmbeddedDocumentField, self).__init__(**kwargs) | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def document_type(self): |     def document_type(self): | ||||||
|         if isinstance(self.document_type_obj, six.string_types): |         if isinstance(self.document_type_obj, basestring): | ||||||
|             if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT: |             if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT: | ||||||
|                 self.document_type_obj = self.owner_document |                 self.document_type_obj = self.owner_document | ||||||
|             else: |             else: | ||||||
| @@ -565,12 +577,8 @@ class EmbeddedDocumentField(BaseField): | |||||||
|         return self.document_type._fields.get(member_name) |         return self.document_type._fields.get(member_name) | ||||||
|  |  | ||||||
|     def prepare_query_value(self, op, value): |     def prepare_query_value(self, op, value): | ||||||
|         if value is not None and not isinstance(value, self.document_type): |         if not isinstance(value, self.document_type): | ||||||
|             try: |             value = self.document_type._from_son(value) | ||||||
|                 value = self.document_type._from_son(value) |  | ||||||
|             except ValueError: |  | ||||||
|                 raise InvalidQueryError("Querying the embedded document '%s' failed, due to an invalid query value" % |  | ||||||
|                                         (self.document_type._class_name,)) |  | ||||||
|         super(EmbeddedDocumentField, self).prepare_query_value(op, value) |         super(EmbeddedDocumentField, self).prepare_query_value(op, value) | ||||||
|         return self.to_mongo(value) |         return self.to_mongo(value) | ||||||
|  |  | ||||||
| @@ -623,7 +631,7 @@ class DynamicField(BaseField): | |||||||
|         """Convert a Python type to a MongoDB compatible type. |         """Convert a Python type to a MongoDB compatible type. | ||||||
|         """ |         """ | ||||||
|  |  | ||||||
|         if isinstance(value, six.string_types): |         if isinstance(value, basestring): | ||||||
|             return value |             return value | ||||||
|  |  | ||||||
|         if hasattr(value, 'to_mongo'): |         if hasattr(value, 'to_mongo'): | ||||||
| @@ -631,7 +639,7 @@ class DynamicField(BaseField): | |||||||
|             val = value.to_mongo(use_db_field, fields) |             val = value.to_mongo(use_db_field, fields) | ||||||
|             # If we its a document thats not inherited add _cls |             # If we its a document thats not inherited add _cls | ||||||
|             if isinstance(value, Document): |             if isinstance(value, Document): | ||||||
|                 val = {'_ref': value.to_dbref(), '_cls': cls.__name__} |                 val = {"_ref": value.to_dbref(), "_cls": cls.__name__} | ||||||
|             if isinstance(value, EmbeddedDocument): |             if isinstance(value, EmbeddedDocument): | ||||||
|                 val['_cls'] = cls.__name__ |                 val['_cls'] = cls.__name__ | ||||||
|             return val |             return val | ||||||
| @@ -642,7 +650,7 @@ class DynamicField(BaseField): | |||||||
|         is_list = False |         is_list = False | ||||||
|         if not hasattr(value, 'items'): |         if not hasattr(value, 'items'): | ||||||
|             is_list = True |             is_list = True | ||||||
|             value = {k: v for k, v in enumerate(value)} |             value = dict([(k, v) for k, v in enumerate(value)]) | ||||||
|  |  | ||||||
|         data = {} |         data = {} | ||||||
|         for k, v in value.iteritems(): |         for k, v in value.iteritems(): | ||||||
| @@ -666,12 +674,12 @@ class DynamicField(BaseField): | |||||||
|         return member_name |         return member_name | ||||||
|  |  | ||||||
|     def prepare_query_value(self, op, value): |     def prepare_query_value(self, op, value): | ||||||
|         if isinstance(value, six.string_types): |         if isinstance(value, basestring): | ||||||
|             return StringField().prepare_query_value(op, value) |             return StringField().prepare_query_value(op, value) | ||||||
|         return super(DynamicField, self).prepare_query_value(op, self.to_mongo(value)) |         return super(DynamicField, self).prepare_query_value(op, self.to_mongo(value)) | ||||||
|  |  | ||||||
|     def validate(self, value, clean=True): |     def validate(self, value, clean=True): | ||||||
|         if hasattr(value, 'validate'): |         if hasattr(value, "validate"): | ||||||
|             value.validate(clean=clean) |             value.validate(clean=clean) | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -691,27 +699,21 @@ class ListField(ComplexBaseField): | |||||||
|         super(ListField, self).__init__(**kwargs) |         super(ListField, self).__init__(**kwargs) | ||||||
|  |  | ||||||
|     def validate(self, value): |     def validate(self, value): | ||||||
|         """Make sure that a list of valid fields is being used.""" |         """Make sure that a list of valid fields is being used. | ||||||
|  |         """ | ||||||
|         if (not isinstance(value, (list, tuple, QuerySet)) or |         if (not isinstance(value, (list, tuple, QuerySet)) or | ||||||
|                 isinstance(value, six.string_types)): |                 isinstance(value, basestring)): | ||||||
|             self.error('Only lists and tuples may be used in a list field') |             self.error('Only lists and tuples may be used in a list field') | ||||||
|         super(ListField, self).validate(value) |         super(ListField, self).validate(value) | ||||||
|  |  | ||||||
|     def prepare_query_value(self, op, value): |     def prepare_query_value(self, op, value): | ||||||
|         if self.field: |         if self.field: | ||||||
|  |             if op in ('set', 'unset', None) and ( | ||||||
|             # If the value is iterable and it's not a string nor a |                     not isinstance(value, basestring) and | ||||||
|             # BaseDocument, call prepare_query_value for each of its items. |                     not isinstance(value, BaseDocument) and | ||||||
|             if ( |                     hasattr(value, '__iter__')): | ||||||
|                 op in ('set', 'unset', None) and |  | ||||||
|                 hasattr(value, '__iter__') and |  | ||||||
|                 not isinstance(value, six.string_types) and |  | ||||||
|                 not isinstance(value, BaseDocument) |  | ||||||
|             ): |  | ||||||
|                 return [self.field.prepare_query_value(op, v) for v in value] |                 return [self.field.prepare_query_value(op, v) for v in value] | ||||||
|  |  | ||||||
|             return self.field.prepare_query_value(op, value) |             return self.field.prepare_query_value(op, value) | ||||||
|  |  | ||||||
|         return super(ListField, self).prepare_query_value(op, value) |         return super(ListField, self).prepare_query_value(op, value) | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -724,6 +726,7 @@ class EmbeddedDocumentListField(ListField): | |||||||
|         :class:`~mongoengine.EmbeddedDocument`. |         :class:`~mongoengine.EmbeddedDocument`. | ||||||
|  |  | ||||||
|     .. versionadded:: 0.9 |     .. versionadded:: 0.9 | ||||||
|  |  | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     def __init__(self, document_type, **kwargs): |     def __init__(self, document_type, **kwargs): | ||||||
| @@ -772,17 +775,17 @@ class SortedListField(ListField): | |||||||
|  |  | ||||||
|  |  | ||||||
| def key_not_string(d): | def key_not_string(d): | ||||||
|     """Helper function to recursively determine if any key in a |     """ Helper function to recursively determine if any key in a dictionary is | ||||||
|     dictionary is not a string. |     not a string. | ||||||
|     """ |     """ | ||||||
|     for k, v in d.items(): |     for k, v in d.items(): | ||||||
|         if not isinstance(k, six.string_types) or (isinstance(v, dict) and key_not_string(v)): |         if not isinstance(k, basestring) or (isinstance(v, dict) and key_not_string(v)): | ||||||
|             return True |             return True | ||||||
|  |  | ||||||
|  |  | ||||||
| def key_has_dot_or_dollar(d): | def key_has_dot_or_dollar(d): | ||||||
|     """Helper function to recursively determine if any key in a |     """ Helper function to recursively determine if any key in a dictionary | ||||||
|     dictionary contains a dot or a dollar sign. |     contains a dot or a dollar sign. | ||||||
|     """ |     """ | ||||||
|     for k, v in d.items(): |     for k, v in d.items(): | ||||||
|         if ('.' in k or '$' in k) or (isinstance(v, dict) and key_has_dot_or_dollar(v)): |         if ('.' in k or '$' in k) or (isinstance(v, dict) and key_has_dot_or_dollar(v)): | ||||||
| @@ -810,13 +813,14 @@ class DictField(ComplexBaseField): | |||||||
|         super(DictField, self).__init__(*args, **kwargs) |         super(DictField, self).__init__(*args, **kwargs) | ||||||
|  |  | ||||||
|     def validate(self, value): |     def validate(self, value): | ||||||
|         """Make sure that a list of valid fields is being used.""" |         """Make sure that a list of valid fields is being used. | ||||||
|  |         """ | ||||||
|         if not isinstance(value, dict): |         if not isinstance(value, dict): | ||||||
|             self.error('Only dictionaries may be used in a DictField') |             self.error('Only dictionaries may be used in a DictField') | ||||||
|  |  | ||||||
|         if key_not_string(value): |         if key_not_string(value): | ||||||
|             msg = ('Invalid dictionary key - documents must ' |             msg = ("Invalid dictionary key - documents must " | ||||||
|                    'have only string keys') |                    "have only string keys") | ||||||
|             self.error(msg) |             self.error(msg) | ||||||
|         if key_has_dot_or_dollar(value): |         if key_has_dot_or_dollar(value): | ||||||
|             self.error('Invalid dictionary key name - keys may not contain "."' |             self.error('Invalid dictionary key name - keys may not contain "."' | ||||||
| @@ -831,15 +835,14 @@ class DictField(ComplexBaseField): | |||||||
|                            'istartswith', 'endswith', 'iendswith', |                            'istartswith', 'endswith', 'iendswith', | ||||||
|                            'exact', 'iexact'] |                            'exact', 'iexact'] | ||||||
|  |  | ||||||
|         if op in match_operators and isinstance(value, six.string_types): |         if op in match_operators and isinstance(value, basestring): | ||||||
|             return StringField().prepare_query_value(op, value) |             return StringField().prepare_query_value(op, value) | ||||||
|  |  | ||||||
|         if hasattr(self.field, 'field'): |         if hasattr(self.field, 'field'): | ||||||
|             if op in ('set', 'unset') and isinstance(value, dict): |             if op in ('set', 'unset') and isinstance(value, dict): | ||||||
|                 return { |                 return dict( | ||||||
|                     k: self.field.prepare_query_value(op, v) |                     (k, self.field.prepare_query_value(op, v)) | ||||||
|                     for k, v in value.items() |                     for k, v in value.items()) | ||||||
|                 } |  | ||||||
|             return self.field.prepare_query_value(op, value) |             return self.field.prepare_query_value(op, value) | ||||||
|  |  | ||||||
|         return super(DictField, self).prepare_query_value(op, value) |         return super(DictField, self).prepare_query_value(op, value) | ||||||
| @@ -908,12 +911,10 @@ class ReferenceField(BaseField): | |||||||
|             A reference to an abstract document type is always stored as a |             A reference to an abstract document type is always stored as a | ||||||
|             :class:`~pymongo.dbref.DBRef`, regardless of the value of `dbref`. |             :class:`~pymongo.dbref.DBRef`, regardless of the value of `dbref`. | ||||||
|         """ |         """ | ||||||
|         if ( |         if not isinstance(document_type, basestring): | ||||||
|             not isinstance(document_type, six.string_types) and |             if not issubclass(document_type, (Document, basestring)): | ||||||
|             not issubclass(document_type, Document) |                 self.error('Argument to ReferenceField constructor must be a ' | ||||||
|         ): |                            'document class or a string') | ||||||
|             self.error('Argument to ReferenceField constructor must be a ' |  | ||||||
|                        'document class or a string') |  | ||||||
|  |  | ||||||
|         self.dbref = dbref |         self.dbref = dbref | ||||||
|         self.document_type_obj = document_type |         self.document_type_obj = document_type | ||||||
| @@ -922,7 +923,7 @@ class ReferenceField(BaseField): | |||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def document_type(self): |     def document_type(self): | ||||||
|         if isinstance(self.document_type_obj, six.string_types): |         if isinstance(self.document_type_obj, basestring): | ||||||
|             if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT: |             if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT: | ||||||
|                 self.document_type_obj = self.owner_document |                 self.document_type_obj = self.owner_document | ||||||
|             else: |             else: | ||||||
| @@ -930,7 +931,8 @@ class ReferenceField(BaseField): | |||||||
|         return self.document_type_obj |         return self.document_type_obj | ||||||
|  |  | ||||||
|     def __get__(self, instance, owner): |     def __get__(self, instance, owner): | ||||||
|         """Descriptor to allow lazy dereferencing.""" |         """Descriptor to allow lazy dereferencing. | ||||||
|  |         """ | ||||||
|         if instance is None: |         if instance is None: | ||||||
|             # Document class being used rather than a document object |             # Document class being used rather than a document object | ||||||
|             return self |             return self | ||||||
| @@ -987,7 +989,8 @@ class ReferenceField(BaseField): | |||||||
|         return id_ |         return id_ | ||||||
|  |  | ||||||
|     def to_python(self, value): |     def to_python(self, value): | ||||||
|         """Convert a MongoDB-compatible type to a Python type.""" |         """Convert a MongoDB-compatible type to a Python type. | ||||||
|  |         """ | ||||||
|         if (not self.dbref and |         if (not self.dbref and | ||||||
|                 not isinstance(value, (DBRef, Document, EmbeddedDocument))): |                 not isinstance(value, (DBRef, Document, EmbeddedDocument))): | ||||||
|             collection = self.document_type._get_collection_name() |             collection = self.document_type._get_collection_name() | ||||||
| @@ -1003,7 +1006,7 @@ class ReferenceField(BaseField): | |||||||
|     def validate(self, value): |     def validate(self, value): | ||||||
|  |  | ||||||
|         if not isinstance(value, (self.document_type, DBRef)): |         if not isinstance(value, (self.document_type, DBRef)): | ||||||
|             self.error('A ReferenceField only accepts DBRef or documents') |             self.error("A ReferenceField only accepts DBRef or documents") | ||||||
|  |  | ||||||
|         if isinstance(value, Document) and value.id is None: |         if isinstance(value, Document) and value.id is None: | ||||||
|             self.error('You can only reference documents once they have been ' |             self.error('You can only reference documents once they have been ' | ||||||
| @@ -1027,19 +1030,14 @@ class CachedReferenceField(BaseField): | |||||||
|     .. versionadded:: 0.9 |     .. versionadded:: 0.9 | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     def __init__(self, document_type, fields=None, auto_sync=True, **kwargs): |     def __init__(self, document_type, fields=[], auto_sync=True, **kwargs): | ||||||
|         """Initialises the Cached Reference Field. |         """Initialises the Cached Reference Field. | ||||||
|  |  | ||||||
|         :param fields:  A list of fields to be cached in document |         :param fields:  A list of fields to be cached in document | ||||||
|         :param auto_sync: if True documents are auto updated. |         :param auto_sync: if True documents are auto updated. | ||||||
|         """ |         """ | ||||||
|         if fields is None: |         if not isinstance(document_type, basestring) and \ | ||||||
|             fields = [] |                 not issubclass(document_type, (Document, basestring)): | ||||||
|  |  | ||||||
|         if ( |  | ||||||
|             not isinstance(document_type, six.string_types) and |  | ||||||
|             not issubclass(document_type, Document) |  | ||||||
|         ): |  | ||||||
|             self.error('Argument to CachedReferenceField constructor must be a' |             self.error('Argument to CachedReferenceField constructor must be a' | ||||||
|                        ' document class or a string') |                        ' document class or a string') | ||||||
|  |  | ||||||
| @@ -1055,20 +1053,18 @@ class CachedReferenceField(BaseField): | |||||||
|                                   sender=self.document_type) |                                   sender=self.document_type) | ||||||
|  |  | ||||||
|     def on_document_pre_save(self, sender, document, created, **kwargs): |     def on_document_pre_save(self, sender, document, created, **kwargs): | ||||||
|         if created: |         if not created: | ||||||
|             return None |             update_kwargs = dict( | ||||||
|  |                 ('set__%s__%s' % (self.name, k), v) | ||||||
|  |                 for k, v in document._delta()[0].items() | ||||||
|  |                 if k in self.fields) | ||||||
|  |  | ||||||
|         update_kwargs = { |             if update_kwargs: | ||||||
|             'set__%s__%s' % (self.name, key): val |                 filter_kwargs = {} | ||||||
|             for key, val in document._delta()[0].items() |                 filter_kwargs[self.name] = document | ||||||
|             if key in self.fields |  | ||||||
|         } |  | ||||||
|         if update_kwargs: |  | ||||||
|             filter_kwargs = {} |  | ||||||
|             filter_kwargs[self.name] = document |  | ||||||
|  |  | ||||||
|             self.owner_document.objects( |                 self.owner_document.objects( | ||||||
|                 **filter_kwargs).update(**update_kwargs) |                     **filter_kwargs).update(**update_kwargs) | ||||||
|  |  | ||||||
|     def to_python(self, value): |     def to_python(self, value): | ||||||
|         if isinstance(value, dict): |         if isinstance(value, dict): | ||||||
| @@ -1081,7 +1077,7 @@ class CachedReferenceField(BaseField): | |||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def document_type(self): |     def document_type(self): | ||||||
|         if isinstance(self.document_type_obj, six.string_types): |         if isinstance(self.document_type_obj, basestring): | ||||||
|             if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT: |             if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT: | ||||||
|                 self.document_type_obj = self.owner_document |                 self.document_type_obj = self.owner_document | ||||||
|             else: |             else: | ||||||
| @@ -1121,7 +1117,7 @@ class CachedReferenceField(BaseField): | |||||||
|             # TODO: should raise here or will fail next statement |             # TODO: should raise here or will fail next statement | ||||||
|  |  | ||||||
|         value = SON(( |         value = SON(( | ||||||
|             ('_id', id_field.to_mongo(id_)), |             ("_id", id_field.to_mongo(id_)), | ||||||
|         )) |         )) | ||||||
|  |  | ||||||
|         if fields: |         if fields: | ||||||
| @@ -1147,7 +1143,7 @@ class CachedReferenceField(BaseField): | |||||||
|     def validate(self, value): |     def validate(self, value): | ||||||
|  |  | ||||||
|         if not isinstance(value, self.document_type): |         if not isinstance(value, self.document_type): | ||||||
|             self.error('A CachedReferenceField only accepts documents') |             self.error("A CachedReferenceField only accepts documents") | ||||||
|  |  | ||||||
|         if isinstance(value, Document) and value.id is None: |         if isinstance(value, Document) and value.id is None: | ||||||
|             self.error('You can only reference documents once they have been ' |             self.error('You can only reference documents once they have been ' | ||||||
| @@ -1195,13 +1191,13 @@ class GenericReferenceField(BaseField): | |||||||
|         # Keep the choices as a list of allowed Document class names |         # Keep the choices as a list of allowed Document class names | ||||||
|         if choices: |         if choices: | ||||||
|             for choice in choices: |             for choice in choices: | ||||||
|                 if isinstance(choice, six.string_types): |                 if isinstance(choice, basestring): | ||||||
|                     self.choices.append(choice) |                     self.choices.append(choice) | ||||||
|                 elif isinstance(choice, type) and issubclass(choice, Document): |                 elif isinstance(choice, type) and issubclass(choice, Document): | ||||||
|                     self.choices.append(choice._class_name) |                     self.choices.append(choice._class_name) | ||||||
|                 else: |                 else: | ||||||
|                     self.error('Invalid choices provided: must be a list of' |                     self.error('Invalid choices provided: must be a list of' | ||||||
|                                'Document subclasses and/or six.string_typess') |                                'Document subclasses and/or basestrings') | ||||||
|  |  | ||||||
|     def _validate_choices(self, value): |     def _validate_choices(self, value): | ||||||
|         if isinstance(value, dict): |         if isinstance(value, dict): | ||||||
| @@ -1253,7 +1249,7 @@ class GenericReferenceField(BaseField): | |||||||
|         if document is None: |         if document is None: | ||||||
|             return None |             return None | ||||||
|  |  | ||||||
|         if isinstance(document, (dict, SON, ObjectId, DBRef)): |         if isinstance(document, (dict, SON)): | ||||||
|             return document |             return document | ||||||
|  |  | ||||||
|         id_field_name = document.__class__._meta['id_field'] |         id_field_name = document.__class__._meta['id_field'] | ||||||
| @@ -1284,7 +1280,8 @@ class GenericReferenceField(BaseField): | |||||||
|  |  | ||||||
|  |  | ||||||
| class BinaryField(BaseField): | class BinaryField(BaseField): | ||||||
|     """A binary data field.""" |     """A binary data field. | ||||||
|  |     """ | ||||||
|  |  | ||||||
|     def __init__(self, max_bytes=None, **kwargs): |     def __init__(self, max_bytes=None, **kwargs): | ||||||
|         self.max_bytes = max_bytes |         self.max_bytes = max_bytes | ||||||
| @@ -1292,18 +1289,18 @@ class BinaryField(BaseField): | |||||||
|  |  | ||||||
|     def __set__(self, instance, value): |     def __set__(self, instance, value): | ||||||
|         """Handle bytearrays in python 3.1""" |         """Handle bytearrays in python 3.1""" | ||||||
|         if six.PY3 and isinstance(value, bytearray): |         if PY3 and isinstance(value, bytearray): | ||||||
|             value = six.binary_type(value) |             value = bin_type(value) | ||||||
|         return super(BinaryField, self).__set__(instance, value) |         return super(BinaryField, self).__set__(instance, value) | ||||||
|  |  | ||||||
|     def to_mongo(self, value): |     def to_mongo(self, value): | ||||||
|         return Binary(value) |         return Binary(value) | ||||||
|  |  | ||||||
|     def validate(self, value): |     def validate(self, value): | ||||||
|         if not isinstance(value, (six.binary_type, six.text_type, Binary)): |         if not isinstance(value, (bin_type, txt_type, Binary)): | ||||||
|             self.error('BinaryField only accepts instances of ' |             self.error("BinaryField only accepts instances of " | ||||||
|                        '(%s, %s, Binary)' % ( |                        "(%s, %s, Binary)" % ( | ||||||
|                            six.binary_type.__name__, six.text_type.__name__)) |                            bin_type.__name__, txt_type.__name__)) | ||||||
|  |  | ||||||
|         if self.max_bytes is not None and len(value) > self.max_bytes: |         if self.max_bytes is not None and len(value) > self.max_bytes: | ||||||
|             self.error('Binary value is too long') |             self.error('Binary value is too long') | ||||||
| @@ -1387,13 +1384,11 @@ class GridFSProxy(object): | |||||||
|                 get_db(self.db_alias), self.collection_name) |                 get_db(self.db_alias), self.collection_name) | ||||||
|         return self._fs |         return self._fs | ||||||
|  |  | ||||||
|     def get(self, grid_id=None): |     def get(self, id=None): | ||||||
|         if grid_id: |         if id: | ||||||
|             self.grid_id = grid_id |             self.grid_id = id | ||||||
|  |  | ||||||
|         if self.grid_id is None: |         if self.grid_id is None: | ||||||
|             return None |             return None | ||||||
|  |  | ||||||
|         try: |         try: | ||||||
|             if self.gridout is None: |             if self.gridout is None: | ||||||
|                 self.gridout = self.fs.get(self.grid_id) |                 self.gridout = self.fs.get(self.grid_id) | ||||||
| @@ -1437,7 +1432,7 @@ class GridFSProxy(object): | |||||||
|             try: |             try: | ||||||
|                 return gridout.read(size) |                 return gridout.read(size) | ||||||
|             except Exception: |             except Exception: | ||||||
|                 return '' |                 return "" | ||||||
|  |  | ||||||
|     def delete(self): |     def delete(self): | ||||||
|         # Delete file from GridFS, FileField still remains |         # Delete file from GridFS, FileField still remains | ||||||
| @@ -1469,8 +1464,9 @@ class FileField(BaseField): | |||||||
|     """ |     """ | ||||||
|     proxy_class = GridFSProxy |     proxy_class = GridFSProxy | ||||||
|  |  | ||||||
|     def __init__(self, db_alias=DEFAULT_CONNECTION_NAME, collection_name='fs', |     def __init__(self, | ||||||
|                  **kwargs): |                  db_alias=DEFAULT_CONNECTION_NAME, | ||||||
|  |                  collection_name="fs", **kwargs): | ||||||
|         super(FileField, self).__init__(**kwargs) |         super(FileField, self).__init__(**kwargs) | ||||||
|         self.collection_name = collection_name |         self.collection_name = collection_name | ||||||
|         self.db_alias = db_alias |         self.db_alias = db_alias | ||||||
| @@ -1492,10 +1488,8 @@ class FileField(BaseField): | |||||||
|  |  | ||||||
|     def __set__(self, instance, value): |     def __set__(self, instance, value): | ||||||
|         key = self.name |         key = self.name | ||||||
|         if ( |         if ((hasattr(value, 'read') and not | ||||||
|             (hasattr(value, 'read') and not isinstance(value, GridFSProxy)) or |                 isinstance(value, GridFSProxy)) or isinstance(value, str_types)): | ||||||
|             isinstance(value, (six.binary_type, six.string_types)) |  | ||||||
|         ): |  | ||||||
|             # using "FileField() = file/string" notation |             # using "FileField() = file/string" notation | ||||||
|             grid_file = instance._data.get(self.name) |             grid_file = instance._data.get(self.name) | ||||||
|             # If a file already exists, delete it |             # If a file already exists, delete it | ||||||
| @@ -1564,7 +1558,7 @@ class ImageGridFsProxy(GridFSProxy): | |||||||
|         try: |         try: | ||||||
|             img = Image.open(file_obj) |             img = Image.open(file_obj) | ||||||
|             img_format = img.format |             img_format = img.format | ||||||
|         except Exception as e: |         except Exception, e: | ||||||
|             raise ValidationError('Invalid image: %s' % e) |             raise ValidationError('Invalid image: %s' % e) | ||||||
|  |  | ||||||
|         # Progressive JPEG |         # Progressive JPEG | ||||||
| @@ -1673,10 +1667,10 @@ class ImageGridFsProxy(GridFSProxy): | |||||||
|             return self.fs.get(out.thumbnail_id) |             return self.fs.get(out.thumbnail_id) | ||||||
|  |  | ||||||
|     def write(self, *args, **kwargs): |     def write(self, *args, **kwargs): | ||||||
|         raise RuntimeError('Please use "put" method instead') |         raise RuntimeError("Please use \"put\" method instead") | ||||||
|  |  | ||||||
|     def writelines(self, *args, **kwargs): |     def writelines(self, *args, **kwargs): | ||||||
|         raise RuntimeError('Please use "put" method instead') |         raise RuntimeError("Please use \"put\" method instead") | ||||||
|  |  | ||||||
|  |  | ||||||
| class ImproperlyConfigured(Exception): | class ImproperlyConfigured(Exception): | ||||||
| @@ -1701,17 +1695,14 @@ class ImageField(FileField): | |||||||
|     def __init__(self, size=None, thumbnail_size=None, |     def __init__(self, size=None, thumbnail_size=None, | ||||||
|                  collection_name='images', **kwargs): |                  collection_name='images', **kwargs): | ||||||
|         if not Image: |         if not Image: | ||||||
|             raise ImproperlyConfigured('PIL library was not found') |             raise ImproperlyConfigured("PIL library was not found") | ||||||
|  |  | ||||||
|         params_size = ('width', 'height', 'force') |         params_size = ('width', 'height', 'force') | ||||||
|         extra_args = { |         extra_args = dict(size=size, thumbnail_size=thumbnail_size) | ||||||
|             'size': size, |  | ||||||
|             'thumbnail_size': thumbnail_size |  | ||||||
|         } |  | ||||||
|         for att_name, att in extra_args.items(): |         for att_name, att in extra_args.items(): | ||||||
|             value = None |             value = None | ||||||
|             if isinstance(att, (tuple, list)): |             if isinstance(att, (tuple, list)): | ||||||
|                 if six.PY3: |                 if PY3: | ||||||
|                     value = dict(itertools.zip_longest(params_size, att, |                     value = dict(itertools.zip_longest(params_size, att, | ||||||
|                                                        fillvalue=None)) |                                                        fillvalue=None)) | ||||||
|                 else: |                 else: | ||||||
| @@ -1772,10 +1763,10 @@ class SequenceField(BaseField): | |||||||
|         Generate and Increment the counter |         Generate and Increment the counter | ||||||
|         """ |         """ | ||||||
|         sequence_name = self.get_sequence_name() |         sequence_name = self.get_sequence_name() | ||||||
|         sequence_id = '%s.%s' % (sequence_name, self.name) |         sequence_id = "%s.%s" % (sequence_name, self.name) | ||||||
|         collection = get_db(alias=self.db_alias)[self.collection_name] |         collection = get_db(alias=self.db_alias)[self.collection_name] | ||||||
|         counter = collection.find_and_modify(query={'_id': sequence_id}, |         counter = collection.find_and_modify(query={"_id": sequence_id}, | ||||||
|                                              update={'$inc': {'next': 1}}, |                                              update={"$inc": {"next": 1}}, | ||||||
|                                              new=True, |                                              new=True, | ||||||
|                                              upsert=True) |                                              upsert=True) | ||||||
|         return self.value_decorator(counter['next']) |         return self.value_decorator(counter['next']) | ||||||
| @@ -1798,9 +1789,9 @@ class SequenceField(BaseField): | |||||||
|         as it is only fixed on set. |         as it is only fixed on set. | ||||||
|         """ |         """ | ||||||
|         sequence_name = self.get_sequence_name() |         sequence_name = self.get_sequence_name() | ||||||
|         sequence_id = '%s.%s' % (sequence_name, self.name) |         sequence_id = "%s.%s" % (sequence_name, self.name) | ||||||
|         collection = get_db(alias=self.db_alias)[self.collection_name] |         collection = get_db(alias=self.db_alias)[self.collection_name] | ||||||
|         data = collection.find_one({'_id': sequence_id}) |         data = collection.find_one({"_id": sequence_id}) | ||||||
|  |  | ||||||
|         if data: |         if data: | ||||||
|             return self.value_decorator(data['next'] + 1) |             return self.value_decorator(data['next'] + 1) | ||||||
| @@ -1870,8 +1861,8 @@ class UUIDField(BaseField): | |||||||
|         if not self._binary: |         if not self._binary: | ||||||
|             original_value = value |             original_value = value | ||||||
|             try: |             try: | ||||||
|                 if not isinstance(value, six.string_types): |                 if not isinstance(value, basestring): | ||||||
|                     value = six.text_type(value) |                     value = unicode(value) | ||||||
|                 return uuid.UUID(value) |                 return uuid.UUID(value) | ||||||
|             except Exception: |             except Exception: | ||||||
|                 return original_value |                 return original_value | ||||||
| @@ -1879,8 +1870,8 @@ class UUIDField(BaseField): | |||||||
|  |  | ||||||
|     def to_mongo(self, value): |     def to_mongo(self, value): | ||||||
|         if not self._binary: |         if not self._binary: | ||||||
|             return six.text_type(value) |             return unicode(value) | ||||||
|         elif isinstance(value, six.string_types): |         elif isinstance(value, basestring): | ||||||
|             return uuid.UUID(value) |             return uuid.UUID(value) | ||||||
|         return value |         return value | ||||||
|  |  | ||||||
| @@ -1891,11 +1882,11 @@ class UUIDField(BaseField): | |||||||
|  |  | ||||||
|     def validate(self, value): |     def validate(self, value): | ||||||
|         if not isinstance(value, uuid.UUID): |         if not isinstance(value, uuid.UUID): | ||||||
|             if not isinstance(value, six.string_types): |             if not isinstance(value, basestring): | ||||||
|                 value = str(value) |                 value = str(value) | ||||||
|             try: |             try: | ||||||
|                 uuid.UUID(value) |                 uuid.UUID(value) | ||||||
|             except Exception as exc: |             except Exception, exc: | ||||||
|                 self.error('Could not convert to UUID: %s' % exc) |                 self.error('Could not convert to UUID: %s' % exc) | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -1913,18 +1904,19 @@ class GeoPointField(BaseField): | |||||||
|     _geo_index = pymongo.GEO2D |     _geo_index = pymongo.GEO2D | ||||||
|  |  | ||||||
|     def validate(self, value): |     def validate(self, value): | ||||||
|         """Make sure that a geo-value is of type (x, y)""" |         """Make sure that a geo-value is of type (x, y) | ||||||
|  |         """ | ||||||
|         if not isinstance(value, (list, tuple)): |         if not isinstance(value, (list, tuple)): | ||||||
|             self.error('GeoPointField can only accept tuples or lists ' |             self.error('GeoPointField can only accept tuples or lists ' | ||||||
|                        'of (x, y)') |                        'of (x, y)') | ||||||
|  |  | ||||||
|         if not len(value) == 2: |         if not len(value) == 2: | ||||||
|             self.error('Value (%s) must be a two-dimensional point' % |             self.error("Value (%s) must be a two-dimensional point" % | ||||||
|                        repr(value)) |                        repr(value)) | ||||||
|         elif (not isinstance(value[0], (float, int)) or |         elif (not isinstance(value[0], (float, int)) or | ||||||
|               not isinstance(value[1], (float, int))): |               not isinstance(value[1], (float, int))): | ||||||
|             self.error( |             self.error( | ||||||
|                 'Both values (%s) in point must be float or int' % repr(value)) |                 "Both values (%s) in point must be float or int" % repr(value)) | ||||||
|  |  | ||||||
|  |  | ||||||
| class PointField(GeoJsonBaseField): | class PointField(GeoJsonBaseField): | ||||||
| @@ -1934,8 +1926,8 @@ class PointField(GeoJsonBaseField): | |||||||
|  |  | ||||||
|     .. code-block:: js |     .. code-block:: js | ||||||
|  |  | ||||||
|         {'type' : 'Point' , |         { "type" : "Point" , | ||||||
|          'coordinates' : [x, y]} |           "coordinates" : [x, y]} | ||||||
|  |  | ||||||
|     You can either pass a dict with the full information or a list |     You can either pass a dict with the full information or a list | ||||||
|     to set the value. |     to set the value. | ||||||
| @@ -1944,7 +1936,7 @@ class PointField(GeoJsonBaseField): | |||||||
|  |  | ||||||
|     .. versionadded:: 0.8 |     .. versionadded:: 0.8 | ||||||
|     """ |     """ | ||||||
|     _type = 'Point' |     _type = "Point" | ||||||
|  |  | ||||||
|  |  | ||||||
| class LineStringField(GeoJsonBaseField): | class LineStringField(GeoJsonBaseField): | ||||||
| @@ -1954,8 +1946,8 @@ class LineStringField(GeoJsonBaseField): | |||||||
|  |  | ||||||
|     .. code-block:: js |     .. code-block:: js | ||||||
|  |  | ||||||
|         {'type' : 'LineString' , |         { "type" : "LineString" , | ||||||
|          'coordinates' : [[x1, y1], [x1, y1] ... [xn, yn]]} |           "coordinates" : [[x1, y1], [x1, y1] ... [xn, yn]]} | ||||||
|  |  | ||||||
|     You can either pass a dict with the full information or a list of points. |     You can either pass a dict with the full information or a list of points. | ||||||
|  |  | ||||||
| @@ -1963,7 +1955,7 @@ class LineStringField(GeoJsonBaseField): | |||||||
|  |  | ||||||
|     .. versionadded:: 0.8 |     .. versionadded:: 0.8 | ||||||
|     """ |     """ | ||||||
|     _type = 'LineString' |     _type = "LineString" | ||||||
|  |  | ||||||
|  |  | ||||||
| class PolygonField(GeoJsonBaseField): | class PolygonField(GeoJsonBaseField): | ||||||
| @@ -1973,9 +1965,9 @@ class PolygonField(GeoJsonBaseField): | |||||||
|  |  | ||||||
|     .. code-block:: js |     .. code-block:: js | ||||||
|  |  | ||||||
|         {'type' : 'Polygon' , |         { "type" : "Polygon" , | ||||||
|          'coordinates' : [[[x1, y1], [x1, y1] ... [xn, yn]], |           "coordinates" : [[[x1, y1], [x1, y1] ... [xn, yn]], | ||||||
|                           [[x1, y1], [x1, y1] ... [xn, yn]]} |                            [[x1, y1], [x1, y1] ... [xn, yn]]} | ||||||
|  |  | ||||||
|     You can either pass a dict with the full information or a list |     You can either pass a dict with the full information or a list | ||||||
|     of LineStrings. The first LineString being the outside and the rest being |     of LineStrings. The first LineString being the outside and the rest being | ||||||
| @@ -1985,7 +1977,7 @@ class PolygonField(GeoJsonBaseField): | |||||||
|  |  | ||||||
|     .. versionadded:: 0.8 |     .. versionadded:: 0.8 | ||||||
|     """ |     """ | ||||||
|     _type = 'Polygon' |     _type = "Polygon" | ||||||
|  |  | ||||||
|  |  | ||||||
| class MultiPointField(GeoJsonBaseField): | class MultiPointField(GeoJsonBaseField): | ||||||
| @@ -1995,8 +1987,8 @@ class MultiPointField(GeoJsonBaseField): | |||||||
|  |  | ||||||
|     .. code-block:: js |     .. code-block:: js | ||||||
|  |  | ||||||
|         {'type' : 'MultiPoint' , |         { "type" : "MultiPoint" , | ||||||
|          'coordinates' : [[x1, y1], [x2, y2]]} |           "coordinates" : [[x1, y1], [x2, y2]]} | ||||||
|  |  | ||||||
|     You can either pass a dict with the full information or a list |     You can either pass a dict with the full information or a list | ||||||
|     to set the value. |     to set the value. | ||||||
| @@ -2005,7 +1997,7 @@ class MultiPointField(GeoJsonBaseField): | |||||||
|  |  | ||||||
|     .. versionadded:: 0.9 |     .. versionadded:: 0.9 | ||||||
|     """ |     """ | ||||||
|     _type = 'MultiPoint' |     _type = "MultiPoint" | ||||||
|  |  | ||||||
|  |  | ||||||
| class MultiLineStringField(GeoJsonBaseField): | class MultiLineStringField(GeoJsonBaseField): | ||||||
| @@ -2015,9 +2007,9 @@ class MultiLineStringField(GeoJsonBaseField): | |||||||
|  |  | ||||||
|     .. code-block:: js |     .. code-block:: js | ||||||
|  |  | ||||||
|         {'type' : 'MultiLineString' , |         { "type" : "MultiLineString" , | ||||||
|          'coordinates' : [[[x1, y1], [x1, y1] ... [xn, yn]], |           "coordinates" : [[[x1, y1], [x1, y1] ... [xn, yn]], | ||||||
|                           [[x1, y1], [x1, y1] ... [xn, yn]]]} |                            [[x1, y1], [x1, y1] ... [xn, yn]]]} | ||||||
|  |  | ||||||
|     You can either pass a dict with the full information or a list of points. |     You can either pass a dict with the full information or a list of points. | ||||||
|  |  | ||||||
| @@ -2025,7 +2017,7 @@ class MultiLineStringField(GeoJsonBaseField): | |||||||
|  |  | ||||||
|     .. versionadded:: 0.9 |     .. versionadded:: 0.9 | ||||||
|     """ |     """ | ||||||
|     _type = 'MultiLineString' |     _type = "MultiLineString" | ||||||
|  |  | ||||||
|  |  | ||||||
| class MultiPolygonField(GeoJsonBaseField): | class MultiPolygonField(GeoJsonBaseField): | ||||||
| @@ -2035,14 +2027,14 @@ class MultiPolygonField(GeoJsonBaseField): | |||||||
|  |  | ||||||
|     .. code-block:: js |     .. code-block:: js | ||||||
|  |  | ||||||
|         {'type' : 'MultiPolygon' , |         { "type" : "MultiPolygon" , | ||||||
|          'coordinates' : [[ |           "coordinates" : [[ | ||||||
|                [[x1, y1], [x1, y1] ... [xn, yn]], |                 [[x1, y1], [x1, y1] ... [xn, yn]], | ||||||
|                [[x1, y1], [x1, y1] ... [xn, yn]] |                 [[x1, y1], [x1, y1] ... [xn, yn]] | ||||||
|            ], [ |             ], [ | ||||||
|                [[x1, y1], [x1, y1] ... [xn, yn]], |                 [[x1, y1], [x1, y1] ... [xn, yn]], | ||||||
|                [[x1, y1], [x1, y1] ... [xn, yn]] |                 [[x1, y1], [x1, y1] ... [xn, yn]] | ||||||
|            ] |             ] | ||||||
|         } |         } | ||||||
|  |  | ||||||
|     You can either pass a dict with the full information or a list |     You can either pass a dict with the full information or a list | ||||||
| @@ -2052,4 +2044,4 @@ class MultiPolygonField(GeoJsonBaseField): | |||||||
|  |  | ||||||
|     .. versionadded:: 0.9 |     .. versionadded:: 0.9 | ||||||
|     """ |     """ | ||||||
|     _type = 'MultiPolygon' |     _type = "MultiPolygon" | ||||||
|   | |||||||
| @@ -1,9 +1,7 @@ | |||||||
| """ | """Helper functions and types to aid with Python 2.5 - 3 support.""" | ||||||
| Helper functions, constants, and types to aid with Python v2.7 - v3.x and |  | ||||||
| PyMongo v2.7 - v3.x support. | import sys | ||||||
| """ |  | ||||||
| import pymongo | import pymongo | ||||||
| import six |  | ||||||
|  |  | ||||||
|  |  | ||||||
| if pymongo.version_tuple[0] < 3: | if pymongo.version_tuple[0] < 3: | ||||||
| @@ -11,15 +9,29 @@ if pymongo.version_tuple[0] < 3: | |||||||
| else: | else: | ||||||
|     IS_PYMONGO_3 = True |     IS_PYMONGO_3 = True | ||||||
|  |  | ||||||
|  | PY3 = sys.version_info[0] == 3 | ||||||
|  |  | ||||||
| # six.BytesIO resolves to StringIO.StringIO in Py2 and io.BytesIO in Py3. | if PY3: | ||||||
| StringIO = six.BytesIO |     import codecs | ||||||
|  |     from io import BytesIO as StringIO | ||||||
|  |  | ||||||
| # Additionally for Py2, try to use the faster cStringIO, if available |     # return s converted to binary.  b('test') should be equivalent to b'test' | ||||||
| if not six.PY3: |     def b(s): | ||||||
|  |         return codecs.latin_1_encode(s)[0] | ||||||
|  |  | ||||||
|  |     bin_type = bytes | ||||||
|  |     txt_type = str | ||||||
|  | else: | ||||||
|     try: |     try: | ||||||
|         import cStringIO |         from cStringIO import StringIO | ||||||
|     except ImportError: |     except ImportError: | ||||||
|         pass |         from StringIO import StringIO | ||||||
|     else: |  | ||||||
|         StringIO = cStringIO.StringIO |     # Conversion to binary only necessary in Python 3 | ||||||
|  |     def b(s): | ||||||
|  |         return s | ||||||
|  |  | ||||||
|  |     bin_type = str | ||||||
|  |     txt_type = unicode | ||||||
|  |  | ||||||
|  | str_types = (bin_type, txt_type) | ||||||
|   | |||||||
| @@ -1,17 +1,11 @@ | |||||||
| from mongoengine.errors import * | from mongoengine.errors import (DoesNotExist, InvalidQueryError, | ||||||
|  |                                 MultipleObjectsReturned, NotUniqueError, | ||||||
|  |                                 OperationError) | ||||||
| from mongoengine.queryset.field_list import * | from mongoengine.queryset.field_list import * | ||||||
| from mongoengine.queryset.manager import * | from mongoengine.queryset.manager import * | ||||||
| from mongoengine.queryset.queryset import * | from mongoengine.queryset.queryset import * | ||||||
| from mongoengine.queryset.transform import * | from mongoengine.queryset.transform import * | ||||||
| from mongoengine.queryset.visitor import * | from mongoengine.queryset.visitor import * | ||||||
|  |  | ||||||
| # Expose just the public subset of all imported objects and constants. | __all__ = (field_list.__all__ + manager.__all__ + queryset.__all__ + | ||||||
| __all__ = ( |            transform.__all__ + visitor.__all__) | ||||||
|     'QuerySet', 'QuerySetNoCache', 'Q', 'queryset_manager', 'QuerySetManager', |  | ||||||
|     'QueryFieldList', 'DO_NOTHING', 'NULLIFY', 'CASCADE', 'DENY', 'PULL', |  | ||||||
|  |  | ||||||
|     # Errors that might be related to a queryset, mostly here for backward |  | ||||||
|     # compatibility |  | ||||||
|     'DoesNotExist', 'InvalidQueryError', 'MultipleObjectsReturned', |  | ||||||
|     'NotUniqueError', 'OperationError', |  | ||||||
| ) |  | ||||||
|   | |||||||
| @@ -12,10 +12,9 @@ from bson.code import Code | |||||||
| import pymongo | import pymongo | ||||||
| import pymongo.errors | import pymongo.errors | ||||||
| from pymongo.common import validate_read_preference | from pymongo.common import validate_read_preference | ||||||
| import six |  | ||||||
|  |  | ||||||
| from mongoengine import signals | from mongoengine import signals | ||||||
| from mongoengine.base import get_document | from mongoengine.base.common import get_document | ||||||
| from mongoengine.common import _import_class | from mongoengine.common import _import_class | ||||||
| from mongoengine.connection import get_db | from mongoengine.connection import get_db | ||||||
| from mongoengine.context_managers import switch_db | from mongoengine.context_managers import switch_db | ||||||
| @@ -74,16 +73,15 @@ class BaseQuerySet(object): | |||||||
|         # subclasses of the class being used |         # subclasses of the class being used | ||||||
|         if document._meta.get('allow_inheritance') is True: |         if document._meta.get('allow_inheritance') is True: | ||||||
|             if len(self._document._subclasses) == 1: |             if len(self._document._subclasses) == 1: | ||||||
|                 self._initial_query = {'_cls': self._document._subclasses[0]} |                 self._initial_query = {"_cls": self._document._subclasses[0]} | ||||||
|             else: |             else: | ||||||
|                 self._initial_query = { |                 self._initial_query = { | ||||||
|                     '_cls': {'$in': self._document._subclasses}} |                     "_cls": {"$in": self._document._subclasses}} | ||||||
|             self._loaded_fields = QueryFieldList(always_include=['_cls']) |             self._loaded_fields = QueryFieldList(always_include=['_cls']) | ||||||
|         self._cursor_obj = None |         self._cursor_obj = None | ||||||
|         self._limit = None |         self._limit = None | ||||||
|         self._skip = None |         self._skip = None | ||||||
|         self._hint = -1  # Using -1 as None is a valid value for hint |         self._hint = -1  # Using -1 as None is a valid value for hint | ||||||
|         self._batch_size = None |  | ||||||
|         self.only_fields = [] |         self.only_fields = [] | ||||||
|         self._max_time_ms = None |         self._max_time_ms = None | ||||||
|  |  | ||||||
| @@ -106,8 +104,8 @@ class BaseQuerySet(object): | |||||||
|         if q_obj: |         if q_obj: | ||||||
|             # make sure proper query object is passed |             # make sure proper query object is passed | ||||||
|             if not isinstance(q_obj, QNode): |             if not isinstance(q_obj, QNode): | ||||||
|                 msg = ('Not a query object: %s. ' |                 msg = ("Not a query object: %s. " | ||||||
|                        'Did you intend to use key=value?' % q_obj) |                        "Did you intend to use key=value?" % q_obj) | ||||||
|                 raise InvalidQueryError(msg) |                 raise InvalidQueryError(msg) | ||||||
|             query &= q_obj |             query &= q_obj | ||||||
|  |  | ||||||
| @@ -134,10 +132,10 @@ class BaseQuerySet(object): | |||||||
|         obj_dict = self.__dict__.copy() |         obj_dict = self.__dict__.copy() | ||||||
|  |  | ||||||
|         # don't picke collection, instead pickle collection params |         # don't picke collection, instead pickle collection params | ||||||
|         obj_dict.pop('_collection_obj') |         obj_dict.pop("_collection_obj") | ||||||
|  |  | ||||||
|         # don't pickle cursor |         # don't pickle cursor | ||||||
|         obj_dict['_cursor_obj'] = None |         obj_dict["_cursor_obj"] = None | ||||||
|  |  | ||||||
|         return obj_dict |         return obj_dict | ||||||
|  |  | ||||||
| @@ -148,7 +146,7 @@ class BaseQuerySet(object): | |||||||
|         See https://github.com/MongoEngine/mongoengine/issues/442 |         See https://github.com/MongoEngine/mongoengine/issues/442 | ||||||
|         """ |         """ | ||||||
|  |  | ||||||
|         obj_dict['_collection_obj'] = obj_dict['_document']._get_collection() |         obj_dict["_collection_obj"] = obj_dict["_document"]._get_collection() | ||||||
|  |  | ||||||
|         # update attributes |         # update attributes | ||||||
|         self.__dict__.update(obj_dict) |         self.__dict__.update(obj_dict) | ||||||
| @@ -167,7 +165,7 @@ class BaseQuerySet(object): | |||||||
|                 queryset._skip, queryset._limit = key.start, key.stop |                 queryset._skip, queryset._limit = key.start, key.stop | ||||||
|                 if key.start and key.stop: |                 if key.start and key.stop: | ||||||
|                     queryset._limit = key.stop - key.start |                     queryset._limit = key.stop - key.start | ||||||
|             except IndexError as err: |             except IndexError, err: | ||||||
|                 # PyMongo raises an error if key.start == key.stop, catch it, |                 # PyMongo raises an error if key.start == key.stop, catch it, | ||||||
|                 # bin it, kill it. |                 # bin it, kill it. | ||||||
|                 start = key.start or 0 |                 start = key.start or 0 | ||||||
| @@ -200,16 +198,19 @@ class BaseQuerySet(object): | |||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
|  |  | ||||||
|     def _has_data(self): |     def _has_data(self): | ||||||
|         """Return True if cursor has any data.""" |         """ Retrieves whether cursor has any data. """ | ||||||
|  |  | ||||||
|         queryset = self.order_by() |         queryset = self.order_by() | ||||||
|         return False if queryset.first() is None else True |         return False if queryset.first() is None else True | ||||||
|  |  | ||||||
|     def __nonzero__(self): |     def __nonzero__(self): | ||||||
|         """Avoid to open all records in an if stmt in Py2.""" |         """ Avoid to open all records in an if stmt in Py2. """ | ||||||
|  |  | ||||||
|         return self._has_data() |         return self._has_data() | ||||||
|  |  | ||||||
|     def __bool__(self): |     def __bool__(self): | ||||||
|         """Avoid to open all records in an if stmt in Py3.""" |         """ Avoid to open all records in an if stmt in Py3. """ | ||||||
|  |  | ||||||
|         return self._has_data() |         return self._has_data() | ||||||
|  |  | ||||||
|     # Core functions |     # Core functions | ||||||
| @@ -237,7 +238,7 @@ class BaseQuerySet(object): | |||||||
|         queryset = self.clone() |         queryset = self.clone() | ||||||
|         if queryset._search_text: |         if queryset._search_text: | ||||||
|             raise OperationError( |             raise OperationError( | ||||||
|                 'It is not possible to use search_text two times.') |                 "It is not possible to use search_text two times.") | ||||||
|  |  | ||||||
|         query_kwargs = SON({'$search': text}) |         query_kwargs = SON({'$search': text}) | ||||||
|         if language: |         if language: | ||||||
| @@ -266,7 +267,7 @@ class BaseQuerySet(object): | |||||||
|         try: |         try: | ||||||
|             result = queryset.next() |             result = queryset.next() | ||||||
|         except StopIteration: |         except StopIteration: | ||||||
|             msg = ('%s matching query does not exist.' |             msg = ("%s matching query does not exist." | ||||||
|                    % queryset._document._class_name) |                    % queryset._document._class_name) | ||||||
|             raise queryset._document.DoesNotExist(msg) |             raise queryset._document.DoesNotExist(msg) | ||||||
|         try: |         try: | ||||||
| @@ -274,8 +275,6 @@ class BaseQuerySet(object): | |||||||
|         except StopIteration: |         except StopIteration: | ||||||
|             return result |             return result | ||||||
|  |  | ||||||
|         # If we were able to retrieve the 2nd doc, rewind the cursor and |  | ||||||
|         # raise the MultipleObjectsReturned exception. |  | ||||||
|         queryset.rewind() |         queryset.rewind() | ||||||
|         message = u'%d items returned, instead of 1' % queryset.count() |         message = u'%d items returned, instead of 1' % queryset.count() | ||||||
|         raise queryset._document.MultipleObjectsReturned(message) |         raise queryset._document.MultipleObjectsReturned(message) | ||||||
| @@ -288,7 +287,8 @@ class BaseQuerySet(object): | |||||||
|         return self._document(**kwargs).save() |         return self._document(**kwargs).save() | ||||||
|  |  | ||||||
|     def first(self): |     def first(self): | ||||||
|         """Retrieve the first object matching the query.""" |         """Retrieve the first object matching the query. | ||||||
|  |         """ | ||||||
|         queryset = self.clone() |         queryset = self.clone() | ||||||
|         try: |         try: | ||||||
|             result = queryset[0] |             result = queryset[0] | ||||||
| @@ -337,7 +337,7 @@ class BaseQuerySet(object): | |||||||
|                        % str(self._document)) |                        % str(self._document)) | ||||||
|                 raise OperationError(msg) |                 raise OperationError(msg) | ||||||
|             if doc.pk and not doc._created: |             if doc.pk and not doc._created: | ||||||
|                 msg = 'Some documents have ObjectIds use doc.update() instead' |                 msg = "Some documents have ObjectIds use doc.update() instead" | ||||||
|                 raise OperationError(msg) |                 raise OperationError(msg) | ||||||
|  |  | ||||||
|         signal_kwargs = signal_kwargs or {} |         signal_kwargs = signal_kwargs or {} | ||||||
| @@ -347,17 +347,17 @@ class BaseQuerySet(object): | |||||||
|         raw = [doc.to_mongo() for doc in docs] |         raw = [doc.to_mongo() for doc in docs] | ||||||
|         try: |         try: | ||||||
|             ids = self._collection.insert(raw, **write_concern) |             ids = self._collection.insert(raw, **write_concern) | ||||||
|         except pymongo.errors.DuplicateKeyError as err: |         except pymongo.errors.DuplicateKeyError, err: | ||||||
|             message = 'Could not save document (%s)' |             message = 'Could not save document (%s)' | ||||||
|             raise NotUniqueError(message % six.text_type(err)) |             raise NotUniqueError(message % unicode(err)) | ||||||
|         except pymongo.errors.OperationFailure as err: |         except pymongo.errors.OperationFailure, err: | ||||||
|             message = 'Could not save document (%s)' |             message = 'Could not save document (%s)' | ||||||
|             if re.match('^E1100[01] duplicate key', six.text_type(err)): |             if re.match('^E1100[01] duplicate key', unicode(err)): | ||||||
|                 # E11000 - duplicate key error index |                 # E11000 - duplicate key error index | ||||||
|                 # E11001 - duplicate key on update |                 # E11001 - duplicate key on update | ||||||
|                 message = u'Tried to save duplicate unique keys (%s)' |                 message = u'Tried to save duplicate unique keys (%s)' | ||||||
|                 raise NotUniqueError(message % six.text_type(err)) |                 raise NotUniqueError(message % unicode(err)) | ||||||
|             raise OperationError(message % six.text_type(err)) |             raise OperationError(message % unicode(err)) | ||||||
|  |  | ||||||
|         if not load_bulk: |         if not load_bulk: | ||||||
|             signals.post_bulk_insert.send( |             signals.post_bulk_insert.send( | ||||||
| @@ -383,8 +383,7 @@ class BaseQuerySet(object): | |||||||
|             return 0 |             return 0 | ||||||
|         return self._cursor.count(with_limit_and_skip=with_limit_and_skip) |         return self._cursor.count(with_limit_and_skip=with_limit_and_skip) | ||||||
|  |  | ||||||
|     def delete(self, write_concern=None, _from_doc_delete=False, |     def delete(self, write_concern=None, _from_doc_delete=False, cascade_refs=None): | ||||||
|                cascade_refs=None): |  | ||||||
|         """Delete the documents matched by the query. |         """Delete the documents matched by the query. | ||||||
|  |  | ||||||
|         :param write_concern: Extra keyword arguments are passed down which |         :param write_concern: Extra keyword arguments are passed down which | ||||||
| @@ -407,9 +406,8 @@ class BaseQuerySet(object): | |||||||
|         # Handle deletes where skips or limits have been applied or |         # Handle deletes where skips or limits have been applied or | ||||||
|         # there is an untriggered delete signal |         # there is an untriggered delete signal | ||||||
|         has_delete_signal = signals.signals_available and ( |         has_delete_signal = signals.signals_available and ( | ||||||
|             signals.pre_delete.has_receivers_for(doc) or |             signals.pre_delete.has_receivers_for(self._document) or | ||||||
|             signals.post_delete.has_receivers_for(doc) |             signals.post_delete.has_receivers_for(self._document)) | ||||||
|         ) |  | ||||||
|  |  | ||||||
|         call_document_delete = (queryset._skip or queryset._limit or |         call_document_delete = (queryset._skip or queryset._limit or | ||||||
|                                 has_delete_signal) and not _from_doc_delete |                                 has_delete_signal) and not _from_doc_delete | ||||||
| @@ -422,44 +420,37 @@ class BaseQuerySet(object): | |||||||
|             return cnt |             return cnt | ||||||
|  |  | ||||||
|         delete_rules = doc._meta.get('delete_rules') or {} |         delete_rules = doc._meta.get('delete_rules') or {} | ||||||
|         delete_rules = list(delete_rules.items()) |  | ||||||
|  |  | ||||||
|         # Check for DENY rules before actually deleting/nullifying any other |         # Check for DENY rules before actually deleting/nullifying any other | ||||||
|         # references |         # references | ||||||
|         for rule_entry, rule in delete_rules: |         for rule_entry in delete_rules: | ||||||
|             document_cls, field_name = rule_entry |             document_cls, field_name = rule_entry | ||||||
|             if document_cls._meta.get('abstract'): |             if document_cls._meta.get('abstract'): | ||||||
|                 continue |                 continue | ||||||
|  |             rule = doc._meta['delete_rules'][rule_entry] | ||||||
|  |             if rule == DENY and document_cls.objects( | ||||||
|  |                     **{field_name + '__in': self}).count() > 0: | ||||||
|  |                 msg = ("Could not delete document (%s.%s refers to it)" | ||||||
|  |                        % (document_cls.__name__, field_name)) | ||||||
|  |                 raise OperationError(msg) | ||||||
|  |  | ||||||
|             if rule == DENY: |         for rule_entry in delete_rules: | ||||||
|                 refs = document_cls.objects(**{field_name + '__in': self}) |  | ||||||
|                 if refs.limit(1).count() > 0: |  | ||||||
|                     raise OperationError( |  | ||||||
|                         'Could not delete document (%s.%s refers to it)' |  | ||||||
|                         % (document_cls.__name__, field_name) |  | ||||||
|                     ) |  | ||||||
|  |  | ||||||
|         # Check all the other rules |  | ||||||
|         for rule_entry, rule in delete_rules: |  | ||||||
|             document_cls, field_name = rule_entry |             document_cls, field_name = rule_entry | ||||||
|             if document_cls._meta.get('abstract'): |             if document_cls._meta.get('abstract'): | ||||||
|                 continue |                 continue | ||||||
|  |             rule = doc._meta['delete_rules'][rule_entry] | ||||||
|             if rule == CASCADE: |             if rule == CASCADE: | ||||||
|                 cascade_refs = set() if cascade_refs is None else cascade_refs |                 cascade_refs = set() if cascade_refs is None else cascade_refs | ||||||
|                 # Handle recursive reference |                 # Handle recursive reference | ||||||
|                 if doc._collection == document_cls._collection: |                 if doc._collection == document_cls._collection: | ||||||
|                     for ref in queryset: |                     for ref in queryset: | ||||||
|                         cascade_refs.add(ref.id) |                         cascade_refs.add(ref.id) | ||||||
|                 refs = document_cls.objects(**{field_name + '__in': self, |                 ref_q = document_cls.objects(**{field_name + '__in': self, 'id__nin': cascade_refs}) | ||||||
|                                                'pk__nin': cascade_refs}) |                 ref_q_count = ref_q.count() | ||||||
|                 if refs.count() > 0: |                 if ref_q_count > 0: | ||||||
|                     refs.delete(write_concern=write_concern, |                     ref_q.delete(write_concern=write_concern, cascade_refs=cascade_refs) | ||||||
|                                 cascade_refs=cascade_refs) |  | ||||||
|             elif rule == NULLIFY: |             elif rule == NULLIFY: | ||||||
|                 document_cls.objects(**{field_name + '__in': self}).update( |                 document_cls.objects(**{field_name + '__in': self}).update( | ||||||
|                     write_concern=write_concern, |                     write_concern=write_concern, **{'unset__%s' % field_name: 1}) | ||||||
|                     **{'unset__%s' % field_name: 1}) |  | ||||||
|             elif rule == PULL: |             elif rule == PULL: | ||||||
|                 document_cls.objects(**{field_name + '__in': self}).update( |                 document_cls.objects(**{field_name + '__in': self}).update( | ||||||
|                     write_concern=write_concern, |                     write_concern=write_concern, | ||||||
| @@ -467,7 +458,7 @@ class BaseQuerySet(object): | |||||||
|  |  | ||||||
|         result = queryset._collection.remove(queryset._query, **write_concern) |         result = queryset._collection.remove(queryset._query, **write_concern) | ||||||
|         if result: |         if result: | ||||||
|             return result.get('n') |             return result.get("n") | ||||||
|  |  | ||||||
|     def update(self, upsert=False, multi=True, write_concern=None, |     def update(self, upsert=False, multi=True, write_concern=None, | ||||||
|                full_result=False, **update): |                full_result=False, **update): | ||||||
| @@ -488,7 +479,7 @@ class BaseQuerySet(object): | |||||||
|         .. versionadded:: 0.2 |         .. versionadded:: 0.2 | ||||||
|         """ |         """ | ||||||
|         if not update and not upsert: |         if not update and not upsert: | ||||||
|             raise OperationError('No update parameters, would remove data') |             raise OperationError("No update parameters, would remove data") | ||||||
|  |  | ||||||
|         if write_concern is None: |         if write_concern is None: | ||||||
|             write_concern = {} |             write_concern = {} | ||||||
| @@ -501,9 +492,9 @@ class BaseQuerySet(object): | |||||||
|         # then ensure we add _cls to the update operation |         # then ensure we add _cls to the update operation | ||||||
|         if upsert and '_cls' in query: |         if upsert and '_cls' in query: | ||||||
|             if '$set' in update: |             if '$set' in update: | ||||||
|                 update['$set']['_cls'] = queryset._document._class_name |                 update["$set"]["_cls"] = queryset._document._class_name | ||||||
|             else: |             else: | ||||||
|                 update['$set'] = {'_cls': queryset._document._class_name} |                 update["$set"] = {"_cls": queryset._document._class_name} | ||||||
|         try: |         try: | ||||||
|             result = queryset._collection.update(query, update, multi=multi, |             result = queryset._collection.update(query, update, multi=multi, | ||||||
|                                                  upsert=upsert, **write_concern) |                                                  upsert=upsert, **write_concern) | ||||||
| @@ -511,13 +502,13 @@ class BaseQuerySet(object): | |||||||
|                 return result |                 return result | ||||||
|             elif result: |             elif result: | ||||||
|                 return result['n'] |                 return result['n'] | ||||||
|         except pymongo.errors.DuplicateKeyError as err: |         except pymongo.errors.DuplicateKeyError, err: | ||||||
|             raise NotUniqueError(u'Update failed (%s)' % six.text_type(err)) |             raise NotUniqueError(u'Update failed (%s)' % unicode(err)) | ||||||
|         except pymongo.errors.OperationFailure as err: |         except pymongo.errors.OperationFailure, err: | ||||||
|             if six.text_type(err) == u'multi not coded yet': |             if unicode(err) == u'multi not coded yet': | ||||||
|                 message = u'update() method requires MongoDB 1.1.3+' |                 message = u'update() method requires MongoDB 1.1.3+' | ||||||
|                 raise OperationError(message) |                 raise OperationError(message) | ||||||
|             raise OperationError(u'Update failed (%s)' % six.text_type(err)) |             raise OperationError(u'Update failed (%s)' % unicode(err)) | ||||||
|  |  | ||||||
|     def upsert_one(self, write_concern=None, **update): |     def upsert_one(self, write_concern=None, **update): | ||||||
|         """Overwrite or add the first document matched by the query. |         """Overwrite or add the first document matched by the query. | ||||||
| @@ -588,11 +579,11 @@ class BaseQuerySet(object): | |||||||
|         """ |         """ | ||||||
|  |  | ||||||
|         if remove and new: |         if remove and new: | ||||||
|             raise OperationError('Conflicting parameters: remove and new') |             raise OperationError("Conflicting parameters: remove and new") | ||||||
|  |  | ||||||
|         if not update and not upsert and not remove: |         if not update and not upsert and not remove: | ||||||
|             raise OperationError( |             raise OperationError( | ||||||
|                 'No update parameters, must either update or remove') |                 "No update parameters, must either update or remove") | ||||||
|  |  | ||||||
|         queryset = self.clone() |         queryset = self.clone() | ||||||
|         query = queryset._query |         query = queryset._query | ||||||
| @@ -603,7 +594,7 @@ class BaseQuerySet(object): | |||||||
|         try: |         try: | ||||||
|             if IS_PYMONGO_3: |             if IS_PYMONGO_3: | ||||||
|                 if full_response: |                 if full_response: | ||||||
|                     msg = 'With PyMongo 3+, it is not possible anymore to get the full response.' |                     msg = "With PyMongo 3+, it is not possible anymore to get the full response." | ||||||
|                     warnings.warn(msg, DeprecationWarning) |                     warnings.warn(msg, DeprecationWarning) | ||||||
|                 if remove: |                 if remove: | ||||||
|                     result = queryset._collection.find_one_and_delete( |                     result = queryset._collection.find_one_and_delete( | ||||||
| @@ -621,14 +612,14 @@ class BaseQuerySet(object): | |||||||
|                 result = queryset._collection.find_and_modify( |                 result = queryset._collection.find_and_modify( | ||||||
|                     query, update, upsert=upsert, sort=sort, remove=remove, new=new, |                     query, update, upsert=upsert, sort=sort, remove=remove, new=new, | ||||||
|                     full_response=full_response, **self._cursor_args) |                     full_response=full_response, **self._cursor_args) | ||||||
|         except pymongo.errors.DuplicateKeyError as err: |         except pymongo.errors.DuplicateKeyError, err: | ||||||
|             raise NotUniqueError(u'Update failed (%s)' % err) |             raise NotUniqueError(u"Update failed (%s)" % err) | ||||||
|         except pymongo.errors.OperationFailure as err: |         except pymongo.errors.OperationFailure, err: | ||||||
|             raise OperationError(u'Update failed (%s)' % err) |             raise OperationError(u"Update failed (%s)" % err) | ||||||
|  |  | ||||||
|         if full_response: |         if full_response: | ||||||
|             if result['value'] is not None: |             if result["value"] is not None: | ||||||
|                 result['value'] = self._document._from_son(result['value'], only_fields=self.only_fields) |                 result["value"] = self._document._from_son(result["value"], only_fields=self.only_fields) | ||||||
|         else: |         else: | ||||||
|             if result is not None: |             if result is not None: | ||||||
|                 result = self._document._from_son(result, only_fields=self.only_fields) |                 result = self._document._from_son(result, only_fields=self.only_fields) | ||||||
| @@ -646,7 +637,7 @@ class BaseQuerySet(object): | |||||||
|         """ |         """ | ||||||
|         queryset = self.clone() |         queryset = self.clone() | ||||||
|         if not queryset._query_obj.empty: |         if not queryset._query_obj.empty: | ||||||
|             msg = 'Cannot use a filter whilst using `with_id`' |             msg = "Cannot use a filter whilst using `with_id`" | ||||||
|             raise InvalidQueryError(msg) |             raise InvalidQueryError(msg) | ||||||
|         return queryset.filter(pk=object_id).first() |         return queryset.filter(pk=object_id).first() | ||||||
|  |  | ||||||
| @@ -690,7 +681,7 @@ class BaseQuerySet(object): | |||||||
|         Only return instances of this document and not any inherited documents |         Only return instances of this document and not any inherited documents | ||||||
|         """ |         """ | ||||||
|         if self._document._meta.get('allow_inheritance') is True: |         if self._document._meta.get('allow_inheritance') is True: | ||||||
|             self._initial_query = {'_cls': self._document._class_name} |             self._initial_query = {"_cls": self._document._class_name} | ||||||
|  |  | ||||||
|         return self |         return self | ||||||
|  |  | ||||||
| @@ -790,19 +781,6 @@ class BaseQuerySet(object): | |||||||
|         queryset._hint = index |         queryset._hint = index | ||||||
|         return queryset |         return queryset | ||||||
|  |  | ||||||
|     def batch_size(self, size): |  | ||||||
|         """Limit the number of documents returned in a single batch (each |  | ||||||
|         batch requires a round trip to the server). |  | ||||||
|  |  | ||||||
|         See http://api.mongodb.com/python/current/api/pymongo/cursor.html#pymongo.cursor.Cursor.batch_size |  | ||||||
|         for details. |  | ||||||
|  |  | ||||||
|         :param size: desired size of each batch. |  | ||||||
|         """ |  | ||||||
|         queryset = self.clone() |  | ||||||
|         queryset._batch_size = size |  | ||||||
|         return queryset |  | ||||||
|  |  | ||||||
|     def distinct(self, field): |     def distinct(self, field): | ||||||
|         """Return a list of distinct values for a given field. |         """Return a list of distinct values for a given field. | ||||||
|  |  | ||||||
| @@ -816,56 +794,49 @@ class BaseQuerySet(object): | |||||||
|         .. versionchanged:: 0.6 - Improved db_field refrence handling |         .. versionchanged:: 0.6 - Improved db_field refrence handling | ||||||
|         """ |         """ | ||||||
|         queryset = self.clone() |         queryset = self.clone() | ||||||
|  |  | ||||||
|         try: |         try: | ||||||
|             field = self._fields_to_dbfields([field]).pop() |             field = self._fields_to_dbfields([field]).pop() | ||||||
|         except LookUpError: |         finally: | ||||||
|             pass |             distinct = self._dereference(queryset._cursor.distinct(field), 1, | ||||||
|  |                                          name=field, instance=self._document) | ||||||
|  |  | ||||||
|         distinct = self._dereference(queryset._cursor.distinct(field), 1, |             doc_field = self._document._fields.get(field.split('.', 1)[0]) | ||||||
|                                      name=field, instance=self._document) |             instance = False | ||||||
|  |             # We may need to cast to the correct type eg. ListField(EmbeddedDocumentField) | ||||||
|         doc_field = self._document._fields.get(field.split('.', 1)[0]) |             EmbeddedDocumentField = _import_class('EmbeddedDocumentField') | ||||||
|         instance = None |             ListField = _import_class('ListField') | ||||||
|  |             GenericEmbeddedDocumentField = _import_class('GenericEmbeddedDocumentField') | ||||||
|         # We may need to cast to the correct type eg. ListField(EmbeddedDocumentField) |             if isinstance(doc_field, ListField): | ||||||
|         EmbeddedDocumentField = _import_class('EmbeddedDocumentField') |                 doc_field = getattr(doc_field, "field", doc_field) | ||||||
|         ListField = _import_class('ListField') |             if isinstance(doc_field, (EmbeddedDocumentField, GenericEmbeddedDocumentField)): | ||||||
|         GenericEmbeddedDocumentField = _import_class('GenericEmbeddedDocumentField') |                 instance = getattr(doc_field, "document_type", False) | ||||||
|         if isinstance(doc_field, ListField): |             # handle distinct on subdocuments | ||||||
|             doc_field = getattr(doc_field, 'field', doc_field) |             if '.' in field: | ||||||
|         if isinstance(doc_field, (EmbeddedDocumentField, GenericEmbeddedDocumentField)): |                 for field_part in field.split('.')[1:]: | ||||||
|             instance = getattr(doc_field, 'document_type', None) |                     # if looping on embedded document, get the document type instance | ||||||
|  |                     if instance and isinstance(doc_field, (EmbeddedDocumentField, GenericEmbeddedDocumentField)): | ||||||
|         # handle distinct on subdocuments |                         doc_field = instance | ||||||
|         if '.' in field: |                     # now get the subdocument | ||||||
|             for field_part in field.split('.')[1:]: |                     doc_field = getattr(doc_field, field_part, doc_field) | ||||||
|                 # if looping on embedded document, get the document type instance |                     # We may need to cast to the correct type eg. ListField(EmbeddedDocumentField) | ||||||
|                 if instance and isinstance(doc_field, (EmbeddedDocumentField, GenericEmbeddedDocumentField)): |                     if isinstance(doc_field, ListField): | ||||||
|                     doc_field = instance |                         doc_field = getattr(doc_field, "field", doc_field) | ||||||
|                 # now get the subdocument |                     if isinstance(doc_field, (EmbeddedDocumentField, GenericEmbeddedDocumentField)): | ||||||
|                 doc_field = getattr(doc_field, field_part, doc_field) |                         instance = getattr(doc_field, "document_type", False) | ||||||
|                 # We may need to cast to the correct type eg. ListField(EmbeddedDocumentField) |             if instance and isinstance(doc_field, (EmbeddedDocumentField, | ||||||
|                 if isinstance(doc_field, ListField): |                                                    GenericEmbeddedDocumentField)): | ||||||
|                     doc_field = getattr(doc_field, 'field', doc_field) |                 distinct = [instance(**doc) for doc in distinct] | ||||||
|                 if isinstance(doc_field, (EmbeddedDocumentField, GenericEmbeddedDocumentField)): |             return distinct | ||||||
|                     instance = getattr(doc_field, 'document_type', None) |  | ||||||
|  |  | ||||||
|         if instance and isinstance(doc_field, (EmbeddedDocumentField, |  | ||||||
|                                                GenericEmbeddedDocumentField)): |  | ||||||
|             distinct = [instance(**doc) for doc in distinct] |  | ||||||
|  |  | ||||||
|         return distinct |  | ||||||
|  |  | ||||||
|     def only(self, *fields): |     def only(self, *fields): | ||||||
|         """Load only a subset of this document's fields. :: |         """Load only a subset of this document's fields. :: | ||||||
|  |  | ||||||
|             post = BlogPost.objects(...).only('title', 'author.name') |             post = BlogPost.objects(...).only("title", "author.name") | ||||||
|  |  | ||||||
|         .. note :: `only()` is chainable and will perform a union :: |         .. note :: `only()` is chainable and will perform a union :: | ||||||
|             So with the following it will fetch both: `title` and `author.name`:: |             So with the following it will fetch both: `title` and `author.name`:: | ||||||
|  |  | ||||||
|                 post = BlogPost.objects.only('title').only('author.name') |                 post = BlogPost.objects.only("title").only("author.name") | ||||||
|  |  | ||||||
|         :func:`~mongoengine.queryset.QuerySet.all_fields` will reset any |         :func:`~mongoengine.queryset.QuerySet.all_fields` will reset any | ||||||
|         field filters. |         field filters. | ||||||
| @@ -875,19 +846,19 @@ class BaseQuerySet(object): | |||||||
|         .. versionadded:: 0.3 |         .. versionadded:: 0.3 | ||||||
|         .. versionchanged:: 0.5 - Added subfield support |         .. versionchanged:: 0.5 - Added subfield support | ||||||
|         """ |         """ | ||||||
|         fields = {f: QueryFieldList.ONLY for f in fields} |         fields = dict([(f, QueryFieldList.ONLY) for f in fields]) | ||||||
|         self.only_fields = fields.keys() |         self.only_fields = fields.keys() | ||||||
|         return self.fields(True, **fields) |         return self.fields(True, **fields) | ||||||
|  |  | ||||||
|     def exclude(self, *fields): |     def exclude(self, *fields): | ||||||
|         """Opposite to .only(), exclude some document's fields. :: |         """Opposite to .only(), exclude some document's fields. :: | ||||||
|  |  | ||||||
|             post = BlogPost.objects(...).exclude('comments') |             post = BlogPost.objects(...).exclude("comments") | ||||||
|  |  | ||||||
|         .. note :: `exclude()` is chainable and will perform a union :: |         .. note :: `exclude()` is chainable and will perform a union :: | ||||||
|             So with the following it will exclude both: `title` and `author.name`:: |             So with the following it will exclude both: `title` and `author.name`:: | ||||||
|  |  | ||||||
|                 post = BlogPost.objects.exclude('title').exclude('author.name') |                 post = BlogPost.objects.exclude("title").exclude("author.name") | ||||||
|  |  | ||||||
|         :func:`~mongoengine.queryset.QuerySet.all_fields` will reset any |         :func:`~mongoengine.queryset.QuerySet.all_fields` will reset any | ||||||
|         field filters. |         field filters. | ||||||
| @@ -896,34 +867,28 @@ class BaseQuerySet(object): | |||||||
|  |  | ||||||
|         .. versionadded:: 0.5 |         .. versionadded:: 0.5 | ||||||
|         """ |         """ | ||||||
|         fields = {f: QueryFieldList.EXCLUDE for f in fields} |         fields = dict([(f, QueryFieldList.EXCLUDE) for f in fields]) | ||||||
|         return self.fields(**fields) |         return self.fields(**fields) | ||||||
|  |  | ||||||
|     def fields(self, _only_called=False, **kwargs): |     def fields(self, _only_called=False, **kwargs): | ||||||
|         """Manipulate how you load this document's fields. Used by `.only()` |         """Manipulate how you load this document's fields.  Used by `.only()` | ||||||
|         and `.exclude()` to manipulate which fields to retrieve. If called |         and `.exclude()` to manipulate which fields to retrieve.  Fields also | ||||||
|         directly, use a set of kwargs similar to the MongoDB projection |         allows for a greater level of control for example: | ||||||
|         document. For example: |  | ||||||
|  |  | ||||||
|         Include only a subset of fields: |         Retrieving a Subrange of Array Elements: | ||||||
|  |  | ||||||
|             posts = BlogPost.objects(...).fields(author=1, title=1) |         You can use the $slice operator to retrieve a subrange of elements in | ||||||
|  |         an array. For example to get the first 5 comments:: | ||||||
|  |  | ||||||
|         Exclude a specific field: |             post = BlogPost.objects(...).fields(slice__comments=5) | ||||||
|  |  | ||||||
|             posts = BlogPost.objects(...).fields(comments=0) |         :param kwargs: A dictionary identifying what to include | ||||||
|  |  | ||||||
|         To retrieve a subrange of array elements: |  | ||||||
|  |  | ||||||
|             posts = BlogPost.objects(...).fields(slice__comments=5) |  | ||||||
|  |  | ||||||
|         :param kwargs: A set keywors arguments identifying what to include. |  | ||||||
|  |  | ||||||
|         .. versionadded:: 0.5 |         .. versionadded:: 0.5 | ||||||
|         """ |         """ | ||||||
|  |  | ||||||
|         # Check for an operator and transform to mongo-style if there is |         # Check for an operator and transform to mongo-style if there is | ||||||
|         operators = ['slice'] |         operators = ["slice"] | ||||||
|         cleaned_fields = [] |         cleaned_fields = [] | ||||||
|         for key, value in kwargs.items(): |         for key, value in kwargs.items(): | ||||||
|             parts = key.split('__') |             parts = key.split('__') | ||||||
| @@ -947,7 +912,7 @@ class BaseQuerySet(object): | |||||||
|         """Include all fields. Reset all previously calls of .only() or |         """Include all fields. Reset all previously calls of .only() or | ||||||
|         .exclude(). :: |         .exclude(). :: | ||||||
|  |  | ||||||
|             post = BlogPost.objects.exclude('comments').all_fields() |             post = BlogPost.objects.exclude("comments").all_fields() | ||||||
|  |  | ||||||
|         .. versionadded:: 0.5 |         .. versionadded:: 0.5 | ||||||
|         """ |         """ | ||||||
| @@ -968,14 +933,6 @@ class BaseQuerySet(object): | |||||||
|         queryset._ordering = queryset._get_order_by(keys) |         queryset._ordering = queryset._get_order_by(keys) | ||||||
|         return queryset |         return queryset | ||||||
|  |  | ||||||
|     def comment(self, text): |  | ||||||
|         """Add a comment to the query. |  | ||||||
|  |  | ||||||
|         See https://docs.mongodb.com/manual/reference/method/cursor.comment/#cursor.comment |  | ||||||
|         for details. |  | ||||||
|         """ |  | ||||||
|         return self._chainable_method('comment', text) |  | ||||||
|  |  | ||||||
|     def explain(self, format=False): |     def explain(self, format=False): | ||||||
|         """Return an explain plan record for the |         """Return an explain plan record for the | ||||||
|         :class:`~mongoengine.queryset.QuerySet`\ 's cursor. |         :class:`~mongoengine.queryset.QuerySet`\ 's cursor. | ||||||
| @@ -983,15 +940,8 @@ class BaseQuerySet(object): | |||||||
|         :param format: format the plan before returning it |         :param format: format the plan before returning it | ||||||
|         """ |         """ | ||||||
|         plan = self._cursor.explain() |         plan = self._cursor.explain() | ||||||
|  |  | ||||||
|         # TODO remove this option completely - it's useless. If somebody |  | ||||||
|         # wants to pretty-print the output, they easily can. |  | ||||||
|         if format: |         if format: | ||||||
|             msg = ('"format" param of BaseQuerySet.explain has been ' |  | ||||||
|                    'deprecated and will be removed in future versions.') |  | ||||||
|             warnings.warn(msg, DeprecationWarning) |  | ||||||
|             plan = pprint.pformat(plan) |             plan = pprint.pformat(plan) | ||||||
|  |  | ||||||
|         return plan |         return plan | ||||||
|  |  | ||||||
|     # DEPRECATED. Has no more impact on PyMongo 3+ |     # DEPRECATED. Has no more impact on PyMongo 3+ | ||||||
| @@ -1004,7 +954,7 @@ class BaseQuerySet(object): | |||||||
|         .. deprecated:: Ignored with PyMongo 3+ |         .. deprecated:: Ignored with PyMongo 3+ | ||||||
|         """ |         """ | ||||||
|         if IS_PYMONGO_3: |         if IS_PYMONGO_3: | ||||||
|             msg = 'snapshot is deprecated as it has no impact when using PyMongo 3+.' |             msg = "snapshot is deprecated as it has no impact when using PyMongo 3+." | ||||||
|             warnings.warn(msg, DeprecationWarning) |             warnings.warn(msg, DeprecationWarning) | ||||||
|         queryset = self.clone() |         queryset = self.clone() | ||||||
|         queryset._snapshot = enabled |         queryset._snapshot = enabled | ||||||
| @@ -1030,7 +980,7 @@ class BaseQuerySet(object): | |||||||
|         .. deprecated:: Ignored with PyMongo 3+ |         .. deprecated:: Ignored with PyMongo 3+ | ||||||
|         """ |         """ | ||||||
|         if IS_PYMONGO_3: |         if IS_PYMONGO_3: | ||||||
|             msg = 'slave_okay is deprecated as it has no impact when using PyMongo 3+.' |             msg = "slave_okay is deprecated as it has no impact when using PyMongo 3+." | ||||||
|             warnings.warn(msg, DeprecationWarning) |             warnings.warn(msg, DeprecationWarning) | ||||||
|         queryset = self.clone() |         queryset = self.clone() | ||||||
|         queryset._slave_okay = enabled |         queryset._slave_okay = enabled | ||||||
| @@ -1092,7 +1042,7 @@ class BaseQuerySet(object): | |||||||
|  |  | ||||||
|         :param ms: the number of milliseconds before killing the query on the server |         :param ms: the number of milliseconds before killing the query on the server | ||||||
|         """ |         """ | ||||||
|         return self._chainable_method('max_time_ms', ms) |         return self._chainable_method("max_time_ms", ms) | ||||||
|  |  | ||||||
|     # JSON Helpers |     # JSON Helpers | ||||||
|  |  | ||||||
| @@ -1175,19 +1125,19 @@ class BaseQuerySet(object): | |||||||
|  |  | ||||||
|         MapReduceDocument = _import_class('MapReduceDocument') |         MapReduceDocument = _import_class('MapReduceDocument') | ||||||
|  |  | ||||||
|         if not hasattr(self._collection, 'map_reduce'): |         if not hasattr(self._collection, "map_reduce"): | ||||||
|             raise NotImplementedError('Requires MongoDB >= 1.7.1') |             raise NotImplementedError("Requires MongoDB >= 1.7.1") | ||||||
|  |  | ||||||
|         map_f_scope = {} |         map_f_scope = {} | ||||||
|         if isinstance(map_f, Code): |         if isinstance(map_f, Code): | ||||||
|             map_f_scope = map_f.scope |             map_f_scope = map_f.scope | ||||||
|             map_f = six.text_type(map_f) |             map_f = unicode(map_f) | ||||||
|         map_f = Code(queryset._sub_js_fields(map_f), map_f_scope) |         map_f = Code(queryset._sub_js_fields(map_f), map_f_scope) | ||||||
|  |  | ||||||
|         reduce_f_scope = {} |         reduce_f_scope = {} | ||||||
|         if isinstance(reduce_f, Code): |         if isinstance(reduce_f, Code): | ||||||
|             reduce_f_scope = reduce_f.scope |             reduce_f_scope = reduce_f.scope | ||||||
|             reduce_f = six.text_type(reduce_f) |             reduce_f = unicode(reduce_f) | ||||||
|         reduce_f_code = queryset._sub_js_fields(reduce_f) |         reduce_f_code = queryset._sub_js_fields(reduce_f) | ||||||
|         reduce_f = Code(reduce_f_code, reduce_f_scope) |         reduce_f = Code(reduce_f_code, reduce_f_scope) | ||||||
|  |  | ||||||
| @@ -1197,7 +1147,7 @@ class BaseQuerySet(object): | |||||||
|             finalize_f_scope = {} |             finalize_f_scope = {} | ||||||
|             if isinstance(finalize_f, Code): |             if isinstance(finalize_f, Code): | ||||||
|                 finalize_f_scope = finalize_f.scope |                 finalize_f_scope = finalize_f.scope | ||||||
|                 finalize_f = six.text_type(finalize_f) |                 finalize_f = unicode(finalize_f) | ||||||
|             finalize_f_code = queryset._sub_js_fields(finalize_f) |             finalize_f_code = queryset._sub_js_fields(finalize_f) | ||||||
|             finalize_f = Code(finalize_f_code, finalize_f_scope) |             finalize_f = Code(finalize_f_code, finalize_f_scope) | ||||||
|             mr_args['finalize'] = finalize_f |             mr_args['finalize'] = finalize_f | ||||||
| @@ -1213,7 +1163,7 @@ class BaseQuerySet(object): | |||||||
|         else: |         else: | ||||||
|             map_reduce_function = 'map_reduce' |             map_reduce_function = 'map_reduce' | ||||||
|  |  | ||||||
|             if isinstance(output, six.string_types): |             if isinstance(output, basestring): | ||||||
|                 mr_args['out'] = output |                 mr_args['out'] = output | ||||||
|  |  | ||||||
|             elif isinstance(output, dict): |             elif isinstance(output, dict): | ||||||
| @@ -1226,7 +1176,7 @@ class BaseQuerySet(object): | |||||||
|                         break |                         break | ||||||
|  |  | ||||||
|                 else: |                 else: | ||||||
|                     raise OperationError('actionData not specified for output') |                     raise OperationError("actionData not specified for output") | ||||||
|  |  | ||||||
|                 db_alias = output.get('db_alias') |                 db_alias = output.get('db_alias') | ||||||
|                 remaing_args = ['db', 'sharded', 'nonAtomic'] |                 remaing_args = ['db', 'sharded', 'nonAtomic'] | ||||||
| @@ -1456,7 +1406,7 @@ class BaseQuerySet(object): | |||||||
|             # snapshot is not handled at all by PyMongo 3+ |             # snapshot is not handled at all by PyMongo 3+ | ||||||
|             # TODO: evaluate similar possibilities using modifiers |             # TODO: evaluate similar possibilities using modifiers | ||||||
|             if self._snapshot: |             if self._snapshot: | ||||||
|                 msg = 'The snapshot option is not anymore available with PyMongo 3+' |                 msg = "The snapshot option is not anymore available with PyMongo 3+" | ||||||
|                 warnings.warn(msg, DeprecationWarning) |                 warnings.warn(msg, DeprecationWarning) | ||||||
|             cursor_args = { |             cursor_args = { | ||||||
|                 'no_cursor_timeout': not self._timeout |                 'no_cursor_timeout': not self._timeout | ||||||
| @@ -1468,7 +1418,7 @@ class BaseQuerySet(object): | |||||||
|             if fields_name not in cursor_args: |             if fields_name not in cursor_args: | ||||||
|                 cursor_args[fields_name] = {} |                 cursor_args[fields_name] = {} | ||||||
|  |  | ||||||
|             cursor_args[fields_name]['_text_score'] = {'$meta': 'textScore'} |             cursor_args[fields_name]['_text_score'] = {'$meta': "textScore"} | ||||||
|  |  | ||||||
|         return cursor_args |         return cursor_args | ||||||
|  |  | ||||||
| @@ -1509,9 +1459,6 @@ class BaseQuerySet(object): | |||||||
|             if self._hint != -1: |             if self._hint != -1: | ||||||
|                 self._cursor_obj.hint(self._hint) |                 self._cursor_obj.hint(self._hint) | ||||||
|  |  | ||||||
|             if self._batch_size is not None: |  | ||||||
|                 self._cursor_obj.batch_size(self._batch_size) |  | ||||||
|  |  | ||||||
|         return self._cursor_obj |         return self._cursor_obj | ||||||
|  |  | ||||||
|     def __deepcopy__(self, memo): |     def __deepcopy__(self, memo): | ||||||
| @@ -1523,8 +1470,8 @@ class BaseQuerySet(object): | |||||||
|         if self._mongo_query is None: |         if self._mongo_query is None: | ||||||
|             self._mongo_query = self._query_obj.to_query(self._document) |             self._mongo_query = self._query_obj.to_query(self._document) | ||||||
|             if self._class_check and self._initial_query: |             if self._class_check and self._initial_query: | ||||||
|                 if '_cls' in self._mongo_query: |                 if "_cls" in self._mongo_query: | ||||||
|                     self._mongo_query = {'$and': [self._initial_query, self._mongo_query]} |                     self._mongo_query = {"$and": [self._initial_query, self._mongo_query]} | ||||||
|                 else: |                 else: | ||||||
|                     self._mongo_query.update(self._initial_query) |                     self._mongo_query.update(self._initial_query) | ||||||
|         return self._mongo_query |         return self._mongo_query | ||||||
| @@ -1536,7 +1483,8 @@ class BaseQuerySet(object): | |||||||
|         return self.__dereference |         return self.__dereference | ||||||
|  |  | ||||||
|     def no_dereference(self): |     def no_dereference(self): | ||||||
|         """Turn off any dereferencing for the results of this queryset.""" |         """Turn off any dereferencing for the results of this queryset. | ||||||
|  |         """ | ||||||
|         queryset = self.clone() |         queryset = self.clone() | ||||||
|         queryset._auto_dereference = False |         queryset._auto_dereference = False | ||||||
|         return queryset |         return queryset | ||||||
| @@ -1565,7 +1513,7 @@ class BaseQuerySet(object): | |||||||
|                     emit(null, 1); |                     emit(null, 1); | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
|         """ % {'field': field} |         """ % dict(field=field) | ||||||
|         reduce_func = """ |         reduce_func = """ | ||||||
|             function(key, values) { |             function(key, values) { | ||||||
|                 var total = 0; |                 var total = 0; | ||||||
| @@ -1587,8 +1535,8 @@ class BaseQuerySet(object): | |||||||
|  |  | ||||||
|         if normalize: |         if normalize: | ||||||
|             count = sum(frequencies.values()) |             count = sum(frequencies.values()) | ||||||
|             frequencies = {k: float(v) / count |             frequencies = dict([(k, float(v) / count) | ||||||
|                            for k, v in frequencies.items()} |                                 for k, v in frequencies.items()]) | ||||||
|  |  | ||||||
|         return frequencies |         return frequencies | ||||||
|  |  | ||||||
| @@ -1640,10 +1588,10 @@ class BaseQuerySet(object): | |||||||
|             } |             } | ||||||
|         """ |         """ | ||||||
|         total, data, types = self.exec_js(freq_func, field) |         total, data, types = self.exec_js(freq_func, field) | ||||||
|         values = {types.get(k): int(v) for k, v in data.iteritems()} |         values = dict([(types.get(k), int(v)) for k, v in data.iteritems()]) | ||||||
|  |  | ||||||
|         if normalize: |         if normalize: | ||||||
|             values = {k: float(v) / total for k, v in values.items()} |             values = dict([(k, float(v) / total) for k, v in values.items()]) | ||||||
|  |  | ||||||
|         frequencies = {} |         frequencies = {} | ||||||
|         for k, v in values.iteritems(): |         for k, v in values.iteritems(): | ||||||
| @@ -1665,14 +1613,14 @@ class BaseQuerySet(object): | |||||||
|                           for x in document._subclasses][1:] |                           for x in document._subclasses][1:] | ||||||
|         for field in fields: |         for field in fields: | ||||||
|             try: |             try: | ||||||
|                 field = '.'.join(f.db_field for f in |                 field = ".".join(f.db_field for f in | ||||||
|                                  document._lookup_field(field.split('.'))) |                                  document._lookup_field(field.split('.'))) | ||||||
|                 ret.append(field) |                 ret.append(field) | ||||||
|             except LookUpError as err: |             except LookUpError, err: | ||||||
|                 found = False |                 found = False | ||||||
|                 for subdoc in subclasses: |                 for subdoc in subclasses: | ||||||
|                     try: |                     try: | ||||||
|                         subfield = '.'.join(f.db_field for f in |                         subfield = ".".join(f.db_field for f in | ||||||
|                                             subdoc._lookup_field(field.split('.'))) |                                             subdoc._lookup_field(field.split('.'))) | ||||||
|                         ret.append(subfield) |                         ret.append(subfield) | ||||||
|                         found = True |                         found = True | ||||||
| @@ -1685,14 +1633,15 @@ class BaseQuerySet(object): | |||||||
|         return ret |         return ret | ||||||
|  |  | ||||||
|     def _get_order_by(self, keys): |     def _get_order_by(self, keys): | ||||||
|         """Creates a list of order by fields""" |         """Creates a list of order by fields | ||||||
|  |         """ | ||||||
|         key_list = [] |         key_list = [] | ||||||
|         for key in keys: |         for key in keys: | ||||||
|             if not key: |             if not key: | ||||||
|                 continue |                 continue | ||||||
|  |  | ||||||
|             if key == '$text_score': |             if key == '$text_score': | ||||||
|                 key_list.append(('_text_score', {'$meta': 'textScore'})) |                 key_list.append(('_text_score', {'$meta': "textScore"})) | ||||||
|                 continue |                 continue | ||||||
|  |  | ||||||
|             direction = pymongo.ASCENDING |             direction = pymongo.ASCENDING | ||||||
| @@ -1764,7 +1713,7 @@ class BaseQuerySet(object): | |||||||
|                     # If we need to coerce types, we need to determine the |                     # If we need to coerce types, we need to determine the | ||||||
|                     # type of this field and use the corresponding |                     # type of this field and use the corresponding | ||||||
|                     # .to_python(...) |                     # .to_python(...) | ||||||
|                     EmbeddedDocumentField = _import_class('EmbeddedDocumentField') |                     from mongoengine.fields import EmbeddedDocumentField | ||||||
|  |  | ||||||
|                     obj = self._document |                     obj = self._document | ||||||
|                     for chunk in path.split('.'): |                     for chunk in path.split('.'): | ||||||
| @@ -1798,7 +1747,7 @@ class BaseQuerySet(object): | |||||||
|             field_name = match.group(1).split('.') |             field_name = match.group(1).split('.') | ||||||
|             fields = self._document._lookup_field(field_name) |             fields = self._document._lookup_field(field_name) | ||||||
|             # Substitute the correct name for the field into the javascript |             # Substitute the correct name for the field into the javascript | ||||||
|             return '.'.join([f.db_field for f in fields]) |             return ".".join([f.db_field for f in fields]) | ||||||
|  |  | ||||||
|         code = re.sub(u'\[\s*~([A-z_][A-z_0-9.]+?)\s*\]', field_sub, code) |         code = re.sub(u'\[\s*~([A-z_][A-z_0-9.]+?)\s*\]', field_sub, code) | ||||||
|         code = re.sub(u'\{\{\s*~([A-z_][A-z_0-9.]+?)\s*\}\}', field_path_sub, |         code = re.sub(u'\{\{\s*~([A-z_][A-z_0-9.]+?)\s*\}\}', field_path_sub, | ||||||
| @@ -1809,21 +1758,21 @@ class BaseQuerySet(object): | |||||||
|         queryset = self.clone() |         queryset = self.clone() | ||||||
|         method = getattr(queryset._cursor, method_name) |         method = getattr(queryset._cursor, method_name) | ||||||
|         method(val) |         method(val) | ||||||
|         setattr(queryset, '_' + method_name, val) |         setattr(queryset, "_" + method_name, val) | ||||||
|         return queryset |         return queryset | ||||||
|  |  | ||||||
|     # Deprecated |     # Deprecated | ||||||
|     def ensure_index(self, **kwargs): |     def ensure_index(self, **kwargs): | ||||||
|         """Deprecated use :func:`Document.ensure_index`""" |         """Deprecated use :func:`Document.ensure_index`""" | ||||||
|         msg = ('Doc.objects()._ensure_index() is deprecated. ' |         msg = ("Doc.objects()._ensure_index() is deprecated. " | ||||||
|                'Use Doc.ensure_index() instead.') |                "Use Doc.ensure_index() instead.") | ||||||
|         warnings.warn(msg, DeprecationWarning) |         warnings.warn(msg, DeprecationWarning) | ||||||
|         self._document.__class__.ensure_index(**kwargs) |         self._document.__class__.ensure_index(**kwargs) | ||||||
|         return self |         return self | ||||||
|  |  | ||||||
|     def _ensure_indexes(self): |     def _ensure_indexes(self): | ||||||
|         """Deprecated use :func:`~Document.ensure_indexes`""" |         """Deprecated use :func:`~Document.ensure_indexes`""" | ||||||
|         msg = ('Doc.objects()._ensure_indexes() is deprecated. ' |         msg = ("Doc.objects()._ensure_indexes() is deprecated. " | ||||||
|                'Use Doc.ensure_indexes() instead.') |                "Use Doc.ensure_indexes() instead.") | ||||||
|         warnings.warn(msg, DeprecationWarning) |         warnings.warn(msg, DeprecationWarning) | ||||||
|         self._document.__class__.ensure_indexes() |         self._document.__class__.ensure_indexes() | ||||||
|   | |||||||
| @@ -67,7 +67,7 @@ class QueryFieldList(object): | |||||||
|         return bool(self.fields) |         return bool(self.fields) | ||||||
|  |  | ||||||
|     def as_dict(self): |     def as_dict(self): | ||||||
|         field_list = {field: self.value for field in self.fields} |         field_list = dict((field, self.value) for field in self.fields) | ||||||
|         if self.slice: |         if self.slice: | ||||||
|             field_list.update(self.slice) |             field_list.update(self.slice) | ||||||
|         if self._id is not None: |         if self._id is not None: | ||||||
|   | |||||||
| @@ -27,10 +27,9 @@ class QuerySet(BaseQuerySet): | |||||||
|         in batches of ``ITER_CHUNK_SIZE``. |         in batches of ``ITER_CHUNK_SIZE``. | ||||||
|  |  | ||||||
|         If ``self._has_more`` the cursor hasn't been exhausted so cache then |         If ``self._has_more`` the cursor hasn't been exhausted so cache then | ||||||
|         batch. Otherwise iterate the result_cache. |         batch.  Otherwise iterate the result_cache. | ||||||
|         """ |         """ | ||||||
|         self._iter = True |         self._iter = True | ||||||
|  |  | ||||||
|         if self._has_more: |         if self._has_more: | ||||||
|             return self._iter_results() |             return self._iter_results() | ||||||
|  |  | ||||||
| @@ -43,56 +42,40 @@ class QuerySet(BaseQuerySet): | |||||||
|         """ |         """ | ||||||
|         if self._len is not None: |         if self._len is not None: | ||||||
|             return self._len |             return self._len | ||||||
|  |  | ||||||
|         # Populate the result cache with *all* of the docs in the cursor |  | ||||||
|         if self._has_more: |         if self._has_more: | ||||||
|  |             # populate the cache | ||||||
|             list(self._iter_results()) |             list(self._iter_results()) | ||||||
|  |  | ||||||
|         # Cache the length of the complete result cache and return it |  | ||||||
|         self._len = len(self._result_cache) |         self._len = len(self._result_cache) | ||||||
|         return self._len |         return self._len | ||||||
|  |  | ||||||
|     def __repr__(self): |     def __repr__(self): | ||||||
|         """Provide a string representation of the QuerySet""" |         """Provides the string representation of the QuerySet | ||||||
|  |         """ | ||||||
|         if self._iter: |         if self._iter: | ||||||
|             return '.. queryset mid-iteration ..' |             return '.. queryset mid-iteration ..' | ||||||
|  |  | ||||||
|         self._populate_cache() |         self._populate_cache() | ||||||
|         data = self._result_cache[:REPR_OUTPUT_SIZE + 1] |         data = self._result_cache[:REPR_OUTPUT_SIZE + 1] | ||||||
|         if len(data) > REPR_OUTPUT_SIZE: |         if len(data) > REPR_OUTPUT_SIZE: | ||||||
|             data[-1] = '...(remaining elements truncated)...' |             data[-1] = "...(remaining elements truncated)..." | ||||||
|         return repr(data) |         return repr(data) | ||||||
|  |  | ||||||
|     def _iter_results(self): |     def _iter_results(self): | ||||||
|         """A generator for iterating over the result cache. |         """A generator for iterating over the result cache. | ||||||
|  |  | ||||||
|         Also populates the cache if there are more possible results to |         Also populates the cache if there are more possible results to yield. | ||||||
|         yield. Raises StopIteration when there are no more results. |         Raises StopIteration when there are no more results""" | ||||||
|         """ |  | ||||||
|         if self._result_cache is None: |         if self._result_cache is None: | ||||||
|             self._result_cache = [] |             self._result_cache = [] | ||||||
|  |  | ||||||
|         pos = 0 |         pos = 0 | ||||||
|         while True: |         while True: | ||||||
|  |             upper = len(self._result_cache) | ||||||
|             # For all positions lower than the length of the current result |             while pos < upper: | ||||||
|             # cache, serve the docs straight from the cache w/o hitting the |  | ||||||
|             # database. |  | ||||||
|             # XXX it's VERY important to compute the len within the `while` |  | ||||||
|             # condition because the result cache might expand mid-iteration |  | ||||||
|             # (e.g. if we call len(qs) inside a loop that iterates over the |  | ||||||
|             # queryset). Fortunately len(list) is O(1) in Python, so this |  | ||||||
|             # doesn't cause performance issues. |  | ||||||
|             while pos < len(self._result_cache): |  | ||||||
|                 yield self._result_cache[pos] |                 yield self._result_cache[pos] | ||||||
|                 pos += 1 |                 pos += 1 | ||||||
|  |  | ||||||
|             # Raise StopIteration if we already established there were no more |  | ||||||
|             # docs in the db cursor. |  | ||||||
|             if not self._has_more: |             if not self._has_more: | ||||||
|                 raise StopIteration |                 raise StopIteration | ||||||
|  |  | ||||||
|             # Otherwise, populate more of the cache and repeat. |  | ||||||
|             if len(self._result_cache) <= pos: |             if len(self._result_cache) <= pos: | ||||||
|                 self._populate_cache() |                 self._populate_cache() | ||||||
|  |  | ||||||
| @@ -103,22 +86,12 @@ class QuerySet(BaseQuerySet): | |||||||
|         """ |         """ | ||||||
|         if self._result_cache is None: |         if self._result_cache is None: | ||||||
|             self._result_cache = [] |             self._result_cache = [] | ||||||
|  |         if self._has_more: | ||||||
|         # Skip populating the cache if we already established there are no |             try: | ||||||
|         # more docs to pull from the database. |                 for i in xrange(ITER_CHUNK_SIZE): | ||||||
|         if not self._has_more: |                     self._result_cache.append(self.next()) | ||||||
|             return |             except StopIteration: | ||||||
|  |                 self._has_more = False | ||||||
|         # Pull in ITER_CHUNK_SIZE docs from the database and store them in |  | ||||||
|         # the result cache. |  | ||||||
|         try: |  | ||||||
|             for _ in xrange(ITER_CHUNK_SIZE): |  | ||||||
|                 self._result_cache.append(self.next()) |  | ||||||
|         except StopIteration: |  | ||||||
|             # Getting this exception means there are no more docs in the |  | ||||||
|             # db cursor. Set _has_more to False so that we can use that |  | ||||||
|             # information in other places. |  | ||||||
|             self._has_more = False |  | ||||||
|  |  | ||||||
|     def count(self, with_limit_and_skip=False): |     def count(self, with_limit_and_skip=False): | ||||||
|         """Count the selected elements in the query. |         """Count the selected elements in the query. | ||||||
| @@ -141,7 +114,7 @@ class QuerySet(BaseQuerySet): | |||||||
|         .. versionadded:: 0.8.3 Convert to non caching queryset |         .. versionadded:: 0.8.3 Convert to non caching queryset | ||||||
|         """ |         """ | ||||||
|         if self._result_cache is not None: |         if self._result_cache is not None: | ||||||
|             raise OperationError('QuerySet already cached') |             raise OperationError("QuerySet already cached") | ||||||
|         return self.clone_into(QuerySetNoCache(self._document, self._collection)) |         return self.clone_into(QuerySetNoCache(self._document, self._collection)) | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -164,14 +137,13 @@ class QuerySetNoCache(BaseQuerySet): | |||||||
|             return '.. queryset mid-iteration ..' |             return '.. queryset mid-iteration ..' | ||||||
|  |  | ||||||
|         data = [] |         data = [] | ||||||
|         for _ in xrange(REPR_OUTPUT_SIZE + 1): |         for i in xrange(REPR_OUTPUT_SIZE + 1): | ||||||
|             try: |             try: | ||||||
|                 data.append(self.next()) |                 data.append(self.next()) | ||||||
|             except StopIteration: |             except StopIteration: | ||||||
|                 break |                 break | ||||||
|  |  | ||||||
|         if len(data) > REPR_OUTPUT_SIZE: |         if len(data) > REPR_OUTPUT_SIZE: | ||||||
|             data[-1] = '...(remaining elements truncated)...' |             data[-1] = "...(remaining elements truncated)..." | ||||||
|  |  | ||||||
|         self.rewind() |         self.rewind() | ||||||
|         return repr(data) |         return repr(data) | ||||||
|   | |||||||
| @@ -1,11 +1,9 @@ | |||||||
| from collections import defaultdict | from collections import defaultdict | ||||||
|  |  | ||||||
| from bson import ObjectId, SON | from bson import SON | ||||||
| from bson.dbref import DBRef |  | ||||||
| import pymongo | import pymongo | ||||||
| import six |  | ||||||
|  |  | ||||||
| from mongoengine.base import UPDATE_OPERATORS | from mongoengine.base.fields import UPDATE_OPERATORS | ||||||
| from mongoengine.common import _import_class | from mongoengine.common import _import_class | ||||||
| from mongoengine.connection import get_connection | from mongoengine.connection import get_connection | ||||||
| from mongoengine.errors import InvalidQueryError | from mongoengine.errors import InvalidQueryError | ||||||
| @@ -28,13 +26,13 @@ MATCH_OPERATORS = (COMPARISON_OPERATORS + GEO_OPERATORS + | |||||||
|                    STRING_OPERATORS + CUSTOM_OPERATORS) |                    STRING_OPERATORS + CUSTOM_OPERATORS) | ||||||
|  |  | ||||||
|  |  | ||||||
| # TODO make this less complex |  | ||||||
| def query(_doc_cls=None, **kwargs): | def query(_doc_cls=None, **kwargs): | ||||||
|     """Transform a query from Django-style format to Mongo format.""" |     """Transform a query from Django-style format to Mongo format. | ||||||
|  |     """ | ||||||
|     mongo_query = {} |     mongo_query = {} | ||||||
|     merge_query = defaultdict(list) |     merge_query = defaultdict(list) | ||||||
|     for key, value in sorted(kwargs.items()): |     for key, value in sorted(kwargs.items()): | ||||||
|         if key == '__raw__': |         if key == "__raw__": | ||||||
|             mongo_query.update(value) |             mongo_query.update(value) | ||||||
|             continue |             continue | ||||||
|  |  | ||||||
| @@ -47,7 +45,7 @@ def query(_doc_cls=None, **kwargs): | |||||||
|             op = parts.pop() |             op = parts.pop() | ||||||
|  |  | ||||||
|         # Allow to escape operator-like field name by __ |         # Allow to escape operator-like field name by __ | ||||||
|         if len(parts) > 1 and parts[-1] == '': |         if len(parts) > 1 and parts[-1] == "": | ||||||
|             parts.pop() |             parts.pop() | ||||||
|  |  | ||||||
|         negate = False |         negate = False | ||||||
| @@ -59,17 +57,16 @@ def query(_doc_cls=None, **kwargs): | |||||||
|             # Switch field names to proper names [set in Field(name='foo')] |             # Switch field names to proper names [set in Field(name='foo')] | ||||||
|             try: |             try: | ||||||
|                 fields = _doc_cls._lookup_field(parts) |                 fields = _doc_cls._lookup_field(parts) | ||||||
|             except Exception as e: |             except Exception, e: | ||||||
|                 raise InvalidQueryError(e) |                 raise InvalidQueryError(e) | ||||||
|             parts = [] |             parts = [] | ||||||
|  |  | ||||||
|             CachedReferenceField = _import_class('CachedReferenceField') |             CachedReferenceField = _import_class('CachedReferenceField') | ||||||
|             GenericReferenceField = _import_class('GenericReferenceField') |  | ||||||
|  |  | ||||||
|             cleaned_fields = [] |             cleaned_fields = [] | ||||||
|             for field in fields: |             for field in fields: | ||||||
|                 append_field = True |                 append_field = True | ||||||
|                 if isinstance(field, six.string_types): |                 if isinstance(field, basestring): | ||||||
|                     parts.append(field) |                     parts.append(field) | ||||||
|                     append_field = False |                     append_field = False | ||||||
|                 # is last and CachedReferenceField |                 # is last and CachedReferenceField | ||||||
| @@ -87,9 +84,9 @@ def query(_doc_cls=None, **kwargs): | |||||||
|             singular_ops = [None, 'ne', 'gt', 'gte', 'lt', 'lte', 'not'] |             singular_ops = [None, 'ne', 'gt', 'gte', 'lt', 'lte', 'not'] | ||||||
|             singular_ops += STRING_OPERATORS |             singular_ops += STRING_OPERATORS | ||||||
|             if op in singular_ops: |             if op in singular_ops: | ||||||
|                 if isinstance(field, six.string_types): |                 if isinstance(field, basestring): | ||||||
|                     if (op in STRING_OPERATORS and |                     if (op in STRING_OPERATORS and | ||||||
|                             isinstance(value, six.string_types)): |                             isinstance(value, basestring)): | ||||||
|                         StringField = _import_class('StringField') |                         StringField = _import_class('StringField') | ||||||
|                         value = StringField.prepare_query_value(op, value) |                         value = StringField.prepare_query_value(op, value) | ||||||
|                     else: |                     else: | ||||||
| @@ -101,31 +98,8 @@ def query(_doc_cls=None, **kwargs): | |||||||
|                         value = value['_id'] |                         value = value['_id'] | ||||||
|  |  | ||||||
|             elif op in ('in', 'nin', 'all', 'near') and not isinstance(value, dict): |             elif op in ('in', 'nin', 'all', 'near') and not isinstance(value, dict): | ||||||
|                 # Raise an error if the in/nin/all/near param is not iterable. We need a |                 # 'in', 'nin' and 'all' require a list of values | ||||||
|                 # special check for BaseDocument, because - although it's iterable - using |                 value = [field.prepare_query_value(op, v) for v in value] | ||||||
|                 # it as such in the context of this method is most definitely a mistake. |  | ||||||
|                 BaseDocument = _import_class('BaseDocument') |  | ||||||
|                 if isinstance(value, BaseDocument): |  | ||||||
|                     raise TypeError("When using the `in`, `nin`, `all`, or " |  | ||||||
|                                     "`near`-operators you can\'t use a " |  | ||||||
|                                     "`Document`, you must wrap your object " |  | ||||||
|                                     "in a list (object -> [object]).") |  | ||||||
|                 elif not hasattr(value, '__iter__'): |  | ||||||
|                     raise TypeError("The `in`, `nin`, `all`, or " |  | ||||||
|                                     "`near`-operators must be applied to an " |  | ||||||
|                                     "iterable (e.g. a list).") |  | ||||||
|                 else: |  | ||||||
|                     value = [field.prepare_query_value(op, v) for v in value] |  | ||||||
|  |  | ||||||
|             # If we're querying a GenericReferenceField, we need to alter the |  | ||||||
|             # key depending on the value: |  | ||||||
|             # * If the value is a DBRef, the key should be "field_name._ref". |  | ||||||
|             # * If the value is an ObjectId, the key should be "field_name._ref.$id". |  | ||||||
|             if isinstance(field, GenericReferenceField): |  | ||||||
|                 if isinstance(value, DBRef): |  | ||||||
|                     parts[-1] += '._ref' |  | ||||||
|                 elif isinstance(value, ObjectId): |  | ||||||
|                     parts[-1] += '._ref.$id' |  | ||||||
|  |  | ||||||
|         # if op and op not in COMPARISON_OPERATORS: |         # if op and op not in COMPARISON_OPERATORS: | ||||||
|         if op: |         if op: | ||||||
| @@ -142,10 +116,10 @@ def query(_doc_cls=None, **kwargs): | |||||||
|                     value = query(field.field.document_type, **value) |                     value = query(field.field.document_type, **value) | ||||||
|                 else: |                 else: | ||||||
|                     value = field.prepare_query_value(op, value) |                     value = field.prepare_query_value(op, value) | ||||||
|                 value = {'$elemMatch': value} |                 value = {"$elemMatch": value} | ||||||
|             elif op in CUSTOM_OPERATORS: |             elif op in CUSTOM_OPERATORS: | ||||||
|                 NotImplementedError('Custom method "%s" has not ' |                 NotImplementedError("Custom method '%s' has not " | ||||||
|                                     'been implemented' % op) |                                     "been implemented" % op) | ||||||
|             elif op not in STRING_OPERATORS: |             elif op not in STRING_OPERATORS: | ||||||
|                 value = {'$' + op: value} |                 value = {'$' + op: value} | ||||||
|  |  | ||||||
| @@ -154,13 +128,11 @@ def query(_doc_cls=None, **kwargs): | |||||||
|  |  | ||||||
|         for i, part in indices: |         for i, part in indices: | ||||||
|             parts.insert(i, part) |             parts.insert(i, part) | ||||||
|  |  | ||||||
|         key = '.'.join(parts) |         key = '.'.join(parts) | ||||||
|  |  | ||||||
|         if op is None or key not in mongo_query: |         if op is None or key not in mongo_query: | ||||||
|             mongo_query[key] = value |             mongo_query[key] = value | ||||||
|         elif key in mongo_query: |         elif key in mongo_query: | ||||||
|             if isinstance(mongo_query[key], dict): |             if key in mongo_query and isinstance(mongo_query[key], dict): | ||||||
|                 mongo_query[key].update(value) |                 mongo_query[key].update(value) | ||||||
|                 # $max/minDistance needs to come last - convert to SON |                 # $max/minDistance needs to come last - convert to SON | ||||||
|                 value_dict = mongo_query[key] |                 value_dict = mongo_query[key] | ||||||
| @@ -210,16 +182,15 @@ def query(_doc_cls=None, **kwargs): | |||||||
|  |  | ||||||
|  |  | ||||||
| def update(_doc_cls=None, **update): | def update(_doc_cls=None, **update): | ||||||
|     """Transform an update spec from Django-style format to Mongo |     """Transform an update spec from Django-style format to Mongo format. | ||||||
|     format. |  | ||||||
|     """ |     """ | ||||||
|     mongo_update = {} |     mongo_update = {} | ||||||
|     for key, value in update.items(): |     for key, value in update.items(): | ||||||
|         if key == '__raw__': |         if key == "__raw__": | ||||||
|             mongo_update.update(value) |             mongo_update.update(value) | ||||||
|             continue |             continue | ||||||
|         parts = key.split('__') |         parts = key.split('__') | ||||||
|         # if there is no operator, default to 'set' |         # if there is no operator, default to "set" | ||||||
|         if len(parts) < 3 and parts[0] not in UPDATE_OPERATORS: |         if len(parts) < 3 and parts[0] not in UPDATE_OPERATORS: | ||||||
|             parts.insert(0, 'set') |             parts.insert(0, 'set') | ||||||
|         # Check for an operator and transform to mongo-style if there is |         # Check for an operator and transform to mongo-style if there is | ||||||
| @@ -238,21 +209,21 @@ def update(_doc_cls=None, **update): | |||||||
|             elif op == 'add_to_set': |             elif op == 'add_to_set': | ||||||
|                 op = 'addToSet' |                 op = 'addToSet' | ||||||
|             elif op == 'set_on_insert': |             elif op == 'set_on_insert': | ||||||
|                 op = 'setOnInsert' |                 op = "setOnInsert" | ||||||
|  |  | ||||||
|         match = None |         match = None | ||||||
|         if parts[-1] in COMPARISON_OPERATORS: |         if parts[-1] in COMPARISON_OPERATORS: | ||||||
|             match = parts.pop() |             match = parts.pop() | ||||||
|  |  | ||||||
|         # Allow to escape operator-like field name by __ |         # Allow to escape operator-like field name by __ | ||||||
|         if len(parts) > 1 and parts[-1] == '': |         if len(parts) > 1 and parts[-1] == "": | ||||||
|             parts.pop() |             parts.pop() | ||||||
|  |  | ||||||
|         if _doc_cls: |         if _doc_cls: | ||||||
|             # Switch field names to proper names [set in Field(name='foo')] |             # Switch field names to proper names [set in Field(name='foo')] | ||||||
|             try: |             try: | ||||||
|                 fields = _doc_cls._lookup_field(parts) |                 fields = _doc_cls._lookup_field(parts) | ||||||
|             except Exception as e: |             except Exception, e: | ||||||
|                 raise InvalidQueryError(e) |                 raise InvalidQueryError(e) | ||||||
|             parts = [] |             parts = [] | ||||||
|  |  | ||||||
| @@ -260,7 +231,7 @@ def update(_doc_cls=None, **update): | |||||||
|             appended_sub_field = False |             appended_sub_field = False | ||||||
|             for field in fields: |             for field in fields: | ||||||
|                 append_field = True |                 append_field = True | ||||||
|                 if isinstance(field, six.string_types): |                 if isinstance(field, basestring): | ||||||
|                     # Convert the S operator to $ |                     # Convert the S operator to $ | ||||||
|                     if field == 'S': |                     if field == 'S': | ||||||
|                         field = '$' |                         field = '$' | ||||||
| @@ -281,7 +252,7 @@ def update(_doc_cls=None, **update): | |||||||
|             else: |             else: | ||||||
|                 field = cleaned_fields[-1] |                 field = cleaned_fields[-1] | ||||||
|  |  | ||||||
|             GeoJsonBaseField = _import_class('GeoJsonBaseField') |             GeoJsonBaseField = _import_class("GeoJsonBaseField") | ||||||
|             if isinstance(field, GeoJsonBaseField): |             if isinstance(field, GeoJsonBaseField): | ||||||
|                 value = field.to_mongo(value) |                 value = field.to_mongo(value) | ||||||
|  |  | ||||||
| @@ -295,7 +266,7 @@ def update(_doc_cls=None, **update): | |||||||
|                     value = [field.prepare_query_value(op, v) for v in value] |                     value = [field.prepare_query_value(op, v) for v in value] | ||||||
|                 elif field.required or value is not None: |                 elif field.required or value is not None: | ||||||
|                     value = field.prepare_query_value(op, value) |                     value = field.prepare_query_value(op, value) | ||||||
|             elif op == 'unset': |             elif op == "unset": | ||||||
|                 value = 1 |                 value = 1 | ||||||
|  |  | ||||||
|         if match: |         if match: | ||||||
| @@ -305,16 +276,16 @@ def update(_doc_cls=None, **update): | |||||||
|         key = '.'.join(parts) |         key = '.'.join(parts) | ||||||
|  |  | ||||||
|         if not op: |         if not op: | ||||||
|             raise InvalidQueryError('Updates must supply an operation ' |             raise InvalidQueryError("Updates must supply an operation " | ||||||
|                                     'eg: set__FIELD=value') |                                     "eg: set__FIELD=value") | ||||||
|  |  | ||||||
|         if 'pull' in op and '.' in key: |         if 'pull' in op and '.' in key: | ||||||
|             # Dot operators don't work on pull operations |             # Dot operators don't work on pull operations | ||||||
|             # unless they point to a list field |             # unless they point to a list field | ||||||
|             # Otherwise it uses nested dict syntax |             # Otherwise it uses nested dict syntax | ||||||
|             if op == 'pullAll': |             if op == 'pullAll': | ||||||
|                 raise InvalidQueryError('pullAll operations only support ' |                 raise InvalidQueryError("pullAll operations only support " | ||||||
|                                         'a single field depth') |                                         "a single field depth") | ||||||
|  |  | ||||||
|             # Look for the last list field and use dot notation until there |             # Look for the last list field and use dot notation until there | ||||||
|             field_classes = [c.__class__ for c in cleaned_fields] |             field_classes = [c.__class__ for c in cleaned_fields] | ||||||
| @@ -325,7 +296,7 @@ def update(_doc_cls=None, **update): | |||||||
|                 # Then process as normal |                 # Then process as normal | ||||||
|                 last_listField = len( |                 last_listField = len( | ||||||
|                     cleaned_fields) - field_classes.index(ListField) |                     cleaned_fields) - field_classes.index(ListField) | ||||||
|                 key = '.'.join(parts[:last_listField]) |                 key = ".".join(parts[:last_listField]) | ||||||
|                 parts = parts[last_listField:] |                 parts = parts[last_listField:] | ||||||
|                 parts.insert(0, key) |                 parts.insert(0, key) | ||||||
|  |  | ||||||
| @@ -333,7 +304,7 @@ def update(_doc_cls=None, **update): | |||||||
|             for key in parts: |             for key in parts: | ||||||
|                 value = {key: value} |                 value = {key: value} | ||||||
|         elif op == 'addToSet' and isinstance(value, list): |         elif op == 'addToSet' and isinstance(value, list): | ||||||
|             value = {key: {'$each': value}} |             value = {key: {"$each": value}} | ||||||
|         else: |         else: | ||||||
|             value = {key: value} |             value = {key: value} | ||||||
|         key = '$' + op |         key = '$' + op | ||||||
| @@ -347,82 +318,78 @@ def update(_doc_cls=None, **update): | |||||||
|  |  | ||||||
|  |  | ||||||
| def _geo_operator(field, op, value): | def _geo_operator(field, op, value): | ||||||
|     """Helper to return the query for a given geo query.""" |     """Helper to return the query for a given geo query""" | ||||||
|     if op == 'max_distance': |     if op == "max_distance": | ||||||
|         value = {'$maxDistance': value} |         value = {'$maxDistance': value} | ||||||
|     elif op == 'min_distance': |     elif op == "min_distance": | ||||||
|         value = {'$minDistance': value} |         value = {'$minDistance': value} | ||||||
|     elif field._geo_index == pymongo.GEO2D: |     elif field._geo_index == pymongo.GEO2D: | ||||||
|         if op == 'within_distance': |         if op == "within_distance": | ||||||
|             value = {'$within': {'$center': value}} |             value = {'$within': {'$center': value}} | ||||||
|         elif op == 'within_spherical_distance': |         elif op == "within_spherical_distance": | ||||||
|             value = {'$within': {'$centerSphere': value}} |             value = {'$within': {'$centerSphere': value}} | ||||||
|         elif op == 'within_polygon': |         elif op == "within_polygon": | ||||||
|             value = {'$within': {'$polygon': value}} |             value = {'$within': {'$polygon': value}} | ||||||
|         elif op == 'near': |         elif op == "near": | ||||||
|             value = {'$near': value} |             value = {'$near': value} | ||||||
|         elif op == 'near_sphere': |         elif op == "near_sphere": | ||||||
|             value = {'$nearSphere': value} |             value = {'$nearSphere': value} | ||||||
|         elif op == 'within_box': |         elif op == 'within_box': | ||||||
|             value = {'$within': {'$box': value}} |             value = {'$within': {'$box': value}} | ||||||
|         else: |         else: | ||||||
|             raise NotImplementedError('Geo method "%s" has not been ' |             raise NotImplementedError("Geo method '%s' has not " | ||||||
|                                       'implemented for a GeoPointField' % op) |                                       "been implemented for a GeoPointField" % op) | ||||||
|     else: |     else: | ||||||
|         if op == 'geo_within': |         if op == "geo_within": | ||||||
|             value = {'$geoWithin': _infer_geometry(value)} |             value = {"$geoWithin": _infer_geometry(value)} | ||||||
|         elif op == 'geo_within_box': |         elif op == "geo_within_box": | ||||||
|             value = {'$geoWithin': {'$box': value}} |             value = {"$geoWithin": {"$box": value}} | ||||||
|         elif op == 'geo_within_polygon': |         elif op == "geo_within_polygon": | ||||||
|             value = {'$geoWithin': {'$polygon': value}} |             value = {"$geoWithin": {"$polygon": value}} | ||||||
|         elif op == 'geo_within_center': |         elif op == "geo_within_center": | ||||||
|             value = {'$geoWithin': {'$center': value}} |             value = {"$geoWithin": {"$center": value}} | ||||||
|         elif op == 'geo_within_sphere': |         elif op == "geo_within_sphere": | ||||||
|             value = {'$geoWithin': {'$centerSphere': value}} |             value = {"$geoWithin": {"$centerSphere": value}} | ||||||
|         elif op == 'geo_intersects': |         elif op == "geo_intersects": | ||||||
|             value = {'$geoIntersects': _infer_geometry(value)} |             value = {"$geoIntersects": _infer_geometry(value)} | ||||||
|         elif op == 'near': |         elif op == "near": | ||||||
|             value = {'$near': _infer_geometry(value)} |             value = {'$near': _infer_geometry(value)} | ||||||
|         else: |         else: | ||||||
|             raise NotImplementedError( |             raise NotImplementedError("Geo method '%s' has not " | ||||||
|                 'Geo method "%s" has not been implemented for a %s ' |                                       "been implemented for a %s " % (op, field._name)) | ||||||
|                 % (op, field._name) |  | ||||||
|             ) |  | ||||||
|     return value |     return value | ||||||
|  |  | ||||||
|  |  | ||||||
| def _infer_geometry(value): | def _infer_geometry(value): | ||||||
|     """Helper method that tries to infer the $geometry shape for a |     """Helper method that tries to infer the $geometry shape for a given value""" | ||||||
|     given value. |  | ||||||
|     """ |  | ||||||
|     if isinstance(value, dict): |     if isinstance(value, dict): | ||||||
|         if '$geometry' in value: |         if "$geometry" in value: | ||||||
|             return value |             return value | ||||||
|         elif 'coordinates' in value and 'type' in value: |         elif 'coordinates' in value and 'type' in value: | ||||||
|             return {'$geometry': value} |             return {"$geometry": value} | ||||||
|         raise InvalidQueryError('Invalid $geometry dictionary should have ' |         raise InvalidQueryError("Invalid $geometry dictionary should have " | ||||||
|                                 'type and coordinates keys') |                                 "type and coordinates keys") | ||||||
|     elif isinstance(value, (list, set)): |     elif isinstance(value, (list, set)): | ||||||
|         # TODO: shouldn't we test value[0][0][0][0] to see if it is MultiPolygon? |         # TODO: shouldn't we test value[0][0][0][0] to see if it is MultiPolygon? | ||||||
|         # TODO: should both TypeError and IndexError be alike interpreted? |         # TODO: should both TypeError and IndexError be alike interpreted? | ||||||
|  |  | ||||||
|         try: |         try: | ||||||
|             value[0][0][0] |             value[0][0][0] | ||||||
|             return {'$geometry': {'type': 'Polygon', 'coordinates': value}} |             return {"$geometry": {"type": "Polygon", "coordinates": value}} | ||||||
|         except (TypeError, IndexError): |         except (TypeError, IndexError): | ||||||
|             pass |             pass | ||||||
|  |  | ||||||
|         try: |         try: | ||||||
|             value[0][0] |             value[0][0] | ||||||
|             return {'$geometry': {'type': 'LineString', 'coordinates': value}} |             return {"$geometry": {"type": "LineString", "coordinates": value}} | ||||||
|         except (TypeError, IndexError): |         except (TypeError, IndexError): | ||||||
|             pass |             pass | ||||||
|  |  | ||||||
|         try: |         try: | ||||||
|             value[0] |             value[0] | ||||||
|             return {'$geometry': {'type': 'Point', 'coordinates': value}} |             return {"$geometry": {"type": "Point", "coordinates": value}} | ||||||
|         except (TypeError, IndexError): |         except (TypeError, IndexError): | ||||||
|             pass |             pass | ||||||
|  |  | ||||||
|     raise InvalidQueryError('Invalid $geometry data. Can be either a ' |     raise InvalidQueryError("Invalid $geometry data. Can be either a dictionary " | ||||||
|                             'dictionary or (nested) lists of coordinate(s)') |                             "or (nested) lists of coordinate(s)") | ||||||
|   | |||||||
| @@ -69,9 +69,9 @@ class QueryCompilerVisitor(QNodeVisitor): | |||||||
|         self.document = document |         self.document = document | ||||||
|  |  | ||||||
|     def visit_combination(self, combination): |     def visit_combination(self, combination): | ||||||
|         operator = '$and' |         operator = "$and" | ||||||
|         if combination.operation == combination.OR: |         if combination.operation == combination.OR: | ||||||
|             operator = '$or' |             operator = "$or" | ||||||
|         return {operator: combination.children} |         return {operator: combination.children} | ||||||
|  |  | ||||||
|     def visit_query(self, query): |     def visit_query(self, query): | ||||||
| @@ -79,7 +79,8 @@ class QueryCompilerVisitor(QNodeVisitor): | |||||||
|  |  | ||||||
|  |  | ||||||
| class QNode(object): | class QNode(object): | ||||||
|     """Base class for nodes in query trees.""" |     """Base class for nodes in query trees. | ||||||
|  |     """ | ||||||
|  |  | ||||||
|     AND = 0 |     AND = 0 | ||||||
|     OR = 1 |     OR = 1 | ||||||
| @@ -93,8 +94,7 @@ class QNode(object): | |||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
|  |  | ||||||
|     def _combine(self, other, operation): |     def _combine(self, other, operation): | ||||||
|         """Combine this node with another node into a QCombination |         """Combine this node with another node into a QCombination object. | ||||||
|         object. |  | ||||||
|         """ |         """ | ||||||
|         if getattr(other, 'empty', True): |         if getattr(other, 'empty', True): | ||||||
|             return self |             return self | ||||||
| @@ -116,8 +116,8 @@ class QNode(object): | |||||||
|  |  | ||||||
|  |  | ||||||
| class QCombination(QNode): | class QCombination(QNode): | ||||||
|     """Represents the combination of several conditions by a given |     """Represents the combination of several conditions by a given logical | ||||||
|     logical operator. |     operator. | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     def __init__(self, operation, children): |     def __init__(self, operation, children): | ||||||
|   | |||||||
| @@ -1,5 +1,7 @@ | |||||||
| __all__ = ('pre_init', 'post_init', 'pre_save', 'pre_save_post_validation', | # -*- coding: utf-8 -*- | ||||||
|            'post_save', 'pre_delete', 'post_delete') |  | ||||||
|  | __all__ = ['pre_init', 'post_init', 'pre_save', 'pre_save_post_validation', | ||||||
|  |            'post_save', 'pre_delete', 'post_delete'] | ||||||
|  |  | ||||||
| signals_available = False | signals_available = False | ||||||
| try: | try: | ||||||
| @@ -32,7 +34,6 @@ except ImportError: | |||||||
|             temporarily_connected_to = _fail |             temporarily_connected_to = _fail | ||||||
|         del _fail |         del _fail | ||||||
|  |  | ||||||
|  |  | ||||||
| # the namespace for code signals.  If you are not mongoengine code, do | # the namespace for code signals.  If you are not mongoengine code, do | ||||||
| # not put signals in here.  Create your own namespace instead. | # not put signals in here.  Create your own namespace instead. | ||||||
| _signals = Namespace() | _signals = Namespace() | ||||||
|   | |||||||
							
								
								
									
										14
									
								
								setup.cfg
									
									
									
									
									
								
							
							
						
						
									
										14
									
								
								setup.cfg
									
									
									
									
									
								
							| @@ -1,11 +1,13 @@ | |||||||
| [nosetests] | [nosetests] | ||||||
| verbosity=2 | verbosity = 2 | ||||||
| detailed-errors=1 | detailed-errors = 1 | ||||||
| tests=tests | cover-erase = 1 | ||||||
| cover-package=mongoengine | cover-branches = 1 | ||||||
|  | cover-package = mongoengine | ||||||
|  | tests = tests | ||||||
|  |  | ||||||
| [flake8] | [flake8] | ||||||
| ignore=E501,F401,F403,F405,I201 | ignore=E501,F401,F403,F405,I201 | ||||||
| exclude=build,dist,docs,venv,venv3,.tox,.eggs,tests | exclude=build,dist,docs,venv,.tox,.eggs,tests | ||||||
| max-complexity=47 | max-complexity=42 | ||||||
| application-import-names=mongoengine,tests | application-import-names=mongoengine,tests | ||||||
|   | |||||||
							
								
								
									
										25
									
								
								setup.py
									
									
									
									
									
								
							
							
						
						
									
										25
									
								
								setup.py
									
									
									
									
									
								
							| @@ -21,9 +21,8 @@ except Exception: | |||||||
|  |  | ||||||
|  |  | ||||||
| def get_version(version_tuple): | def get_version(version_tuple): | ||||||
|     """Return the version tuple as a string, e.g. for (0, 10, 7), |     if not isinstance(version_tuple[-1], int): | ||||||
|     return '0.10.7'. |         return '.'.join(map(str, version_tuple[:-1])) + version_tuple[-1] | ||||||
|     """ |  | ||||||
|     return '.'.join(map(str, version_tuple)) |     return '.'.join(map(str, version_tuple)) | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -42,29 +41,31 @@ CLASSIFIERS = [ | |||||||
|     'Operating System :: OS Independent', |     'Operating System :: OS Independent', | ||||||
|     'Programming Language :: Python', |     'Programming Language :: Python', | ||||||
|     "Programming Language :: Python :: 2", |     "Programming Language :: Python :: 2", | ||||||
|  |     "Programming Language :: Python :: 2.6", | ||||||
|     "Programming Language :: Python :: 2.7", |     "Programming Language :: Python :: 2.7", | ||||||
|     "Programming Language :: Python :: 3", |     "Programming Language :: Python :: 3", | ||||||
|  |     "Programming Language :: Python :: 3.2", | ||||||
|     "Programming Language :: Python :: 3.3", |     "Programming Language :: Python :: 3.3", | ||||||
|     "Programming Language :: Python :: 3.4", |     "Programming Language :: Python :: 3.4", | ||||||
|     "Programming Language :: Python :: 3.5", |  | ||||||
|     "Programming Language :: Python :: Implementation :: CPython", |     "Programming Language :: Python :: Implementation :: CPython", | ||||||
|     "Programming Language :: Python :: Implementation :: PyPy", |     "Programming Language :: Python :: Implementation :: PyPy", | ||||||
|     'Topic :: Database', |     'Topic :: Database', | ||||||
|     'Topic :: Software Development :: Libraries :: Python Modules', |     'Topic :: Software Development :: Libraries :: Python Modules', | ||||||
| ] | ] | ||||||
|  |  | ||||||
| extra_opts = { | extra_opts = {"packages": find_packages(exclude=["tests", "tests.*"])} | ||||||
|     'packages': find_packages(exclude=['tests', 'tests.*']), |  | ||||||
|     'tests_require': ['nose', 'coverage==4.2', 'blinker', 'Pillow>=2.0.0'] |  | ||||||
| } |  | ||||||
| if sys.version_info[0] == 3: | if sys.version_info[0] == 3: | ||||||
|     extra_opts['use_2to3'] = True |     extra_opts['use_2to3'] = True | ||||||
|     if 'test' in sys.argv or 'nosetests' in sys.argv: |     extra_opts['tests_require'] = ['nose', 'coverage==3.7.1', 'blinker', 'Pillow>=2.0.0'] | ||||||
|  |     if "test" in sys.argv or "nosetests" in sys.argv: | ||||||
|         extra_opts['packages'] = find_packages() |         extra_opts['packages'] = find_packages() | ||||||
|         extra_opts['package_data'] = { |         extra_opts['package_data'] = {"tests": ["fields/mongoengine.png", "fields/mongodb_leaf.png"]} | ||||||
|             'tests': ['fields/mongoengine.png', 'fields/mongodb_leaf.png']} |  | ||||||
| else: | else: | ||||||
|     extra_opts['tests_require'] += ['python-dateutil'] |     # coverage 4 does not support Python 3.2 anymore | ||||||
|  |     extra_opts['tests_require'] = ['nose', 'coverage==3.7.1', 'blinker', 'Pillow>=2.0.0', 'python-dateutil'] | ||||||
|  |  | ||||||
|  |     if sys.version_info[0] == 2 and sys.version_info[1] == 6: | ||||||
|  |         extra_opts['tests_require'].append('unittest2') | ||||||
|  |  | ||||||
| setup( | setup( | ||||||
|     name='mongoengine', |     name='mongoengine', | ||||||
|   | |||||||
| @@ -2,3 +2,4 @@ from all_warnings import AllWarnings | |||||||
| from document import * | from document import * | ||||||
| from queryset import * | from queryset import * | ||||||
| from fields import * | from fields import * | ||||||
|  | from migration import * | ||||||
|   | |||||||
| @@ -3,6 +3,8 @@ This test has been put into a module.  This is because it tests warnings that | |||||||
| only get triggered on first hit.  This way we can ensure its imported into the | only get triggered on first hit.  This way we can ensure its imported into the | ||||||
| top level and called first by the test suite. | top level and called first by the test suite. | ||||||
| """ | """ | ||||||
|  | import sys | ||||||
|  | sys.path[0:0] = [""] | ||||||
| import unittest | import unittest | ||||||
| import warnings | import warnings | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1,3 +1,5 @@ | |||||||
|  | import sys | ||||||
|  | sys.path[0:0] = [""] | ||||||
| import unittest | import unittest | ||||||
|  |  | ||||||
| from class_methods import * | from class_methods import * | ||||||
|   | |||||||
| @@ -1,4 +1,6 @@ | |||||||
| # -*- coding: utf-8 -*- | # -*- coding: utf-8 -*- | ||||||
|  | import sys | ||||||
|  | sys.path[0:0] = [""] | ||||||
| import unittest | import unittest | ||||||
|  |  | ||||||
| from mongoengine import * | from mongoengine import * | ||||||
|   | |||||||
| @@ -1,4 +1,6 @@ | |||||||
| # -*- coding: utf-8 -*- | # -*- coding: utf-8 -*- | ||||||
|  | import sys | ||||||
|  | sys.path[0:0] = [""] | ||||||
| import unittest | import unittest | ||||||
|  |  | ||||||
| from bson import SON | from bson import SON | ||||||
|   | |||||||
| @@ -1,4 +1,6 @@ | |||||||
| import unittest | import unittest | ||||||
|  | import sys | ||||||
|  | sys.path[0:0] = [""] | ||||||
|  |  | ||||||
| from mongoengine import * | from mongoengine import * | ||||||
| from mongoengine.connection import get_db | from mongoengine.connection import get_db | ||||||
| @@ -141,9 +143,11 @@ class DynamicTest(unittest.TestCase): | |||||||
|  |  | ||||||
|     def test_three_level_complex_data_lookups(self): |     def test_three_level_complex_data_lookups(self): | ||||||
|         """Ensure you can query three level document dynamic fields""" |         """Ensure you can query three level document dynamic fields""" | ||||||
|         p = self.Person.objects.create( |         p = self.Person() | ||||||
|             misc={'hello': {'hello2': 'world'}} |         p.misc = {'hello': {'hello2': 'world'}} | ||||||
|         ) |         p.save() | ||||||
|  |         # from pprint import pprint as pp; import pdb; pdb.set_trace(); | ||||||
|  |         print self.Person.objects(misc__hello__hello2='world') | ||||||
|         self.assertEqual(1, self.Person.objects(misc__hello__hello2='world').count()) |         self.assertEqual(1, self.Person.objects(misc__hello__hello2='world').count()) | ||||||
|  |  | ||||||
|     def test_complex_embedded_document_validation(self): |     def test_complex_embedded_document_validation(self): | ||||||
|   | |||||||
| @@ -2,8 +2,10 @@ | |||||||
| import unittest | import unittest | ||||||
| import sys | import sys | ||||||
|  |  | ||||||
|  | sys.path[0:0] = [""] | ||||||
|  |  | ||||||
| import pymongo | import pymongo | ||||||
|  | from random import randint | ||||||
|  |  | ||||||
| from nose.plugins.skip import SkipTest | from nose.plugins.skip import SkipTest | ||||||
| from datetime import datetime | from datetime import datetime | ||||||
| @@ -15,9 +17,11 @@ __all__ = ("IndexesTest", ) | |||||||
|  |  | ||||||
|  |  | ||||||
| class IndexesTest(unittest.TestCase): | class IndexesTest(unittest.TestCase): | ||||||
|  |     _MAX_RAND = 10 ** 10 | ||||||
|  |  | ||||||
|     def setUp(self): |     def setUp(self): | ||||||
|         self.connection = connect(db='mongoenginetest') |         self.db_name = 'mongoenginetest_IndexesTest_' + str(randint(0, self._MAX_RAND)) | ||||||
|  |         self.connection = connect(db=self.db_name) | ||||||
|         self.db = get_db() |         self.db = get_db() | ||||||
|  |  | ||||||
|         class Person(Document): |         class Person(Document): | ||||||
| @@ -556,8 +560,8 @@ class IndexesTest(unittest.TestCase): | |||||||
|  |  | ||||||
|         BlogPost.drop_collection() |         BlogPost.drop_collection() | ||||||
|  |  | ||||||
|         for i in range(0, 10): |         for i in xrange(0, 10): | ||||||
|             tags = [("tag %i" % n) for n in range(0, i % 2)] |             tags = [("tag %i" % n) for n in xrange(0, i % 2)] | ||||||
|             BlogPost(tags=tags).save() |             BlogPost(tags=tags).save() | ||||||
|  |  | ||||||
|         self.assertEqual(BlogPost.objects.count(), 10) |         self.assertEqual(BlogPost.objects.count(), 10) | ||||||
|   | |||||||
| @@ -1,4 +1,6 @@ | |||||||
| # -*- coding: utf-8 -*- | # -*- coding: utf-8 -*- | ||||||
|  | import sys | ||||||
|  | sys.path[0:0] = [""] | ||||||
| import unittest | import unittest | ||||||
| import warnings | import warnings | ||||||
|  |  | ||||||
| @@ -251,17 +253,19 @@ class InheritanceTest(unittest.TestCase): | |||||||
|         self.assertEqual(classes, [Human]) |         self.assertEqual(classes, [Human]) | ||||||
|  |  | ||||||
|     def test_allow_inheritance(self): |     def test_allow_inheritance(self): | ||||||
|         """Ensure that inheritance is disabled by default on simple |         """Ensure that inheritance may be disabled on simple classes and that | ||||||
|         classes and that _cls will not be used. |         _cls and _subclasses will not be used. | ||||||
|         """ |         """ | ||||||
|  |  | ||||||
|         class Animal(Document): |         class Animal(Document): | ||||||
|             name = StringField() |             name = StringField() | ||||||
|  |  | ||||||
|         # can't inherit because Animal didn't explicitly allow inheritance |         def create_dog_class(): | ||||||
|         with self.assertRaises(ValueError): |  | ||||||
|             class Dog(Animal): |             class Dog(Animal): | ||||||
|                 pass |                 pass | ||||||
|  |  | ||||||
|  |         self.assertRaises(ValueError, create_dog_class) | ||||||
|  |  | ||||||
|         # Check that _cls etc aren't present on simple documents |         # Check that _cls etc aren't present on simple documents | ||||||
|         dog = Animal(name='dog').save() |         dog = Animal(name='dog').save() | ||||||
|         self.assertEqual(dog.to_mongo().keys(), ['_id', 'name']) |         self.assertEqual(dog.to_mongo().keys(), ['_id', 'name']) | ||||||
| @@ -271,15 +275,17 @@ class InheritanceTest(unittest.TestCase): | |||||||
|         self.assertFalse('_cls' in obj) |         self.assertFalse('_cls' in obj) | ||||||
|  |  | ||||||
|     def test_cant_turn_off_inheritance_on_subclass(self): |     def test_cant_turn_off_inheritance_on_subclass(self): | ||||||
|         """Ensure if inheritance is on in a subclass you cant turn it off. |         """Ensure if inheritance is on in a subclass you cant turn it off | ||||||
|         """ |         """ | ||||||
|  |  | ||||||
|         class Animal(Document): |         class Animal(Document): | ||||||
|             name = StringField() |             name = StringField() | ||||||
|             meta = {'allow_inheritance': True} |             meta = {'allow_inheritance': True} | ||||||
|  |  | ||||||
|         with self.assertRaises(ValueError): |         def create_mammal_class(): | ||||||
|             class Mammal(Animal): |             class Mammal(Animal): | ||||||
|                 meta = {'allow_inheritance': False} |                 meta = {'allow_inheritance': False} | ||||||
|  |         self.assertRaises(ValueError, create_mammal_class) | ||||||
|  |  | ||||||
|     def test_allow_inheritance_abstract_document(self): |     def test_allow_inheritance_abstract_document(self): | ||||||
|         """Ensure that abstract documents can set inheritance rules and that |         """Ensure that abstract documents can set inheritance rules and that | ||||||
| @@ -292,9 +298,10 @@ class InheritanceTest(unittest.TestCase): | |||||||
|         class Animal(FinalDocument): |         class Animal(FinalDocument): | ||||||
|             name = StringField() |             name = StringField() | ||||||
|  |  | ||||||
|         with self.assertRaises(ValueError): |         def create_mammal_class(): | ||||||
|             class Mammal(Animal): |             class Mammal(Animal): | ||||||
|                 pass |                 pass | ||||||
|  |         self.assertRaises(ValueError, create_mammal_class) | ||||||
|  |  | ||||||
|         # Check that _cls isn't present in simple documents |         # Check that _cls isn't present in simple documents | ||||||
|         doc = Animal(name='dog') |         doc = Animal(name='dog') | ||||||
| @@ -353,26 +360,29 @@ class InheritanceTest(unittest.TestCase): | |||||||
|         self.assertEqual(berlin.pk, berlin.auto_id_0) |         self.assertEqual(berlin.pk, berlin.auto_id_0) | ||||||
|  |  | ||||||
|     def test_abstract_document_creation_does_not_fail(self): |     def test_abstract_document_creation_does_not_fail(self): | ||||||
|  |  | ||||||
|         class City(Document): |         class City(Document): | ||||||
|             continent = StringField() |             continent = StringField() | ||||||
|             meta = {'abstract': True, |             meta = {'abstract': True, | ||||||
|                     'allow_inheritance': False} |                     'allow_inheritance': False} | ||||||
|  |  | ||||||
|         bkk = City(continent='asia') |         bkk = City(continent='asia') | ||||||
|         self.assertEqual(None, bkk.pk) |         self.assertEqual(None, bkk.pk) | ||||||
|         # TODO: expected error? Shouldn't we create a new error type? |         # TODO: expected error? Shouldn't we create a new error type? | ||||||
|         with self.assertRaises(KeyError): |         self.assertRaises(KeyError, lambda: setattr(bkk, 'pk', 1)) | ||||||
|             setattr(bkk, 'pk', 1) |  | ||||||
|  |  | ||||||
|     def test_allow_inheritance_embedded_document(self): |     def test_allow_inheritance_embedded_document(self): | ||||||
|         """Ensure embedded documents respect inheritance.""" |         """Ensure embedded documents respect inheritance | ||||||
|  |         """ | ||||||
|  |  | ||||||
|         class Comment(EmbeddedDocument): |         class Comment(EmbeddedDocument): | ||||||
|             content = StringField() |             content = StringField() | ||||||
|  |  | ||||||
|         with self.assertRaises(ValueError): |         def create_special_comment(): | ||||||
|             class SpecialComment(Comment): |             class SpecialComment(Comment): | ||||||
|                 pass |                 pass | ||||||
|  |  | ||||||
|  |         self.assertRaises(ValueError, create_special_comment) | ||||||
|  |  | ||||||
|         doc = Comment(content='test') |         doc = Comment(content='test') | ||||||
|         self.assertFalse('_cls' in doc.to_mongo()) |         self.assertFalse('_cls' in doc.to_mongo()) | ||||||
|  |  | ||||||
| @@ -444,11 +454,11 @@ class InheritanceTest(unittest.TestCase): | |||||||
|         self.assertEqual(Guppy._get_collection_name(), 'fish') |         self.assertEqual(Guppy._get_collection_name(), 'fish') | ||||||
|         self.assertEqual(Human._get_collection_name(), 'human') |         self.assertEqual(Human._get_collection_name(), 'human') | ||||||
|  |  | ||||||
|         # ensure that a subclass of a non-abstract class can't be abstract |         def create_bad_abstract(): | ||||||
|         with self.assertRaises(ValueError): |  | ||||||
|             class EvilHuman(Human): |             class EvilHuman(Human): | ||||||
|                 evil = BooleanField(default=True) |                 evil = BooleanField(default=True) | ||||||
|                 meta = {'abstract': True} |                 meta = {'abstract': True} | ||||||
|  |         self.assertRaises(ValueError, create_bad_abstract) | ||||||
|  |  | ||||||
|     def test_abstract_embedded_documents(self): |     def test_abstract_embedded_documents(self): | ||||||
|         # 789: EmbeddedDocument shouldn't inherit abstract |         # 789: EmbeddedDocument shouldn't inherit abstract | ||||||
|   | |||||||
| @@ -1,4 +1,7 @@ | |||||||
| # -*- coding: utf-8 -*- | # -*- coding: utf-8 -*- | ||||||
|  | import sys | ||||||
|  | sys.path[0:0] = [""] | ||||||
|  |  | ||||||
| import bson | import bson | ||||||
| import os | import os | ||||||
| import pickle | import pickle | ||||||
| @@ -13,12 +16,12 @@ from tests.fixtures import (PickleEmbedded, PickleTest, PickleSignalsTest, | |||||||
|                             PickleDynamicEmbedded, PickleDynamicTest) |                             PickleDynamicEmbedded, PickleDynamicTest) | ||||||
|  |  | ||||||
| from mongoengine import * | from mongoengine import * | ||||||
| from mongoengine.base import get_document, _document_registry |  | ||||||
| from mongoengine.connection import get_db |  | ||||||
| from mongoengine.errors import (NotRegistered, InvalidDocumentError, | from mongoengine.errors import (NotRegistered, InvalidDocumentError, | ||||||
|                                 InvalidQueryError, NotUniqueError, |                                 InvalidQueryError, NotUniqueError, | ||||||
|                                 FieldDoesNotExist, SaveConditionError) |                                 FieldDoesNotExist, SaveConditionError) | ||||||
| from mongoengine.queryset import NULLIFY, Q | from mongoengine.queryset import NULLIFY, Q | ||||||
|  | from mongoengine.connection import get_db | ||||||
|  | from mongoengine.base import get_document | ||||||
| from mongoengine.context_managers import switch_db, query_counter | from mongoengine.context_managers import switch_db, query_counter | ||||||
| from mongoengine import signals | from mongoengine import signals | ||||||
|  |  | ||||||
| @@ -99,18 +102,21 @@ class InstanceTest(unittest.TestCase): | |||||||
|         self.assertEqual(options['size'], 4096) |         self.assertEqual(options['size'], 4096) | ||||||
|  |  | ||||||
|         # Check that the document cannot be redefined with different options |         # Check that the document cannot be redefined with different options | ||||||
|         class Log(Document): |         def recreate_log_document(): | ||||||
|             date = DateTimeField(default=datetime.now) |             class Log(Document): | ||||||
|             meta = { |                 date = DateTimeField(default=datetime.now) | ||||||
|                 'max_documents': 11, |                 meta = { | ||||||
|             } |                     'max_documents': 11, | ||||||
|  |                 } | ||||||
|         # Accessing Document.objects creates the collection |             # Create the collection by accessing Document.objects | ||||||
|         with self.assertRaises(InvalidCollectionError): |  | ||||||
|             Log.objects |             Log.objects | ||||||
|  |         self.assertRaises(InvalidCollectionError, recreate_log_document) | ||||||
|  |  | ||||||
|  |         Log.drop_collection() | ||||||
|  |  | ||||||
|     def test_capped_collection_default(self): |     def test_capped_collection_default(self): | ||||||
|         """Ensure that capped collections defaults work properly.""" |         """Ensure that capped collections defaults work properly. | ||||||
|  |         """ | ||||||
|         class Log(Document): |         class Log(Document): | ||||||
|             date = DateTimeField(default=datetime.now) |             date = DateTimeField(default=datetime.now) | ||||||
|             meta = { |             meta = { | ||||||
| @@ -128,14 +134,16 @@ class InstanceTest(unittest.TestCase): | |||||||
|         self.assertEqual(options['size'], 10 * 2**20) |         self.assertEqual(options['size'], 10 * 2**20) | ||||||
|  |  | ||||||
|         # Check that the document with default value can be recreated |         # Check that the document with default value can be recreated | ||||||
|         class Log(Document): |         def recreate_log_document(): | ||||||
|             date = DateTimeField(default=datetime.now) |             class Log(Document): | ||||||
|             meta = { |                 date = DateTimeField(default=datetime.now) | ||||||
|                 'max_documents': 10, |                 meta = { | ||||||
|             } |                     'max_documents': 10, | ||||||
|  |                 } | ||||||
|         # Create the collection by accessing Document.objects |             # Create the collection by accessing Document.objects | ||||||
|         Log.objects |             Log.objects | ||||||
|  |         recreate_log_document() | ||||||
|  |         Log.drop_collection() | ||||||
|  |  | ||||||
|     def test_capped_collection_no_max_size_problems(self): |     def test_capped_collection_no_max_size_problems(self): | ||||||
|         """Ensure that capped collections with odd max_size work properly. |         """Ensure that capped collections with odd max_size work properly. | ||||||
| @@ -158,14 +166,16 @@ class InstanceTest(unittest.TestCase): | |||||||
|         self.assertTrue(options['size'] >= 10000) |         self.assertTrue(options['size'] >= 10000) | ||||||
|  |  | ||||||
|         # Check that the document with odd max_size value can be recreated |         # Check that the document with odd max_size value can be recreated | ||||||
|         class Log(Document): |         def recreate_log_document(): | ||||||
|             date = DateTimeField(default=datetime.now) |             class Log(Document): | ||||||
|             meta = { |                 date = DateTimeField(default=datetime.now) | ||||||
|                 'max_size': 10000, |                 meta = { | ||||||
|             } |                     'max_size': 10000, | ||||||
|  |                 } | ||||||
|         # Create the collection by accessing Document.objects |             # Create the collection by accessing Document.objects | ||||||
|         Log.objects |             Log.objects | ||||||
|  |         recreate_log_document() | ||||||
|  |         Log.drop_collection() | ||||||
|  |  | ||||||
|     def test_repr(self): |     def test_repr(self): | ||||||
|         """Ensure that unicode representation works |         """Ensure that unicode representation works | ||||||
| @@ -276,7 +286,7 @@ class InstanceTest(unittest.TestCase): | |||||||
|  |  | ||||||
|         list_stats = [] |         list_stats = [] | ||||||
|  |  | ||||||
|         for i in range(10): |         for i in xrange(10): | ||||||
|             s = Stats() |             s = Stats() | ||||||
|             s.save() |             s.save() | ||||||
|             list_stats.append(s) |             list_stats.append(s) | ||||||
| @@ -346,14 +356,14 @@ class InstanceTest(unittest.TestCase): | |||||||
|         self.assertEqual(User._fields['username'].db_field, '_id') |         self.assertEqual(User._fields['username'].db_field, '_id') | ||||||
|         self.assertEqual(User._meta['id_field'], 'username') |         self.assertEqual(User._meta['id_field'], 'username') | ||||||
|  |  | ||||||
|         # test no primary key field |         def create_invalid_user(): | ||||||
|         self.assertRaises(ValidationError, User(name='test').save) |             User(name='test').save()  # no primary key field | ||||||
|  |         self.assertRaises(ValidationError, create_invalid_user) | ||||||
|  |  | ||||||
|         # define a subclass with a different primary key field than the |         def define_invalid_user(): | ||||||
|         # parent |  | ||||||
|         with self.assertRaises(ValueError): |  | ||||||
|             class EmailUser(User): |             class EmailUser(User): | ||||||
|                 email = StringField(primary_key=True) |                 email = StringField(primary_key=True) | ||||||
|  |         self.assertRaises(ValueError, define_invalid_user) | ||||||
|  |  | ||||||
|         class EmailUser(User): |         class EmailUser(User): | ||||||
|             email = StringField() |             email = StringField() | ||||||
| @@ -401,10 +411,12 @@ class InstanceTest(unittest.TestCase): | |||||||
|  |  | ||||||
|         # Mimic Place and NicePlace definitions being in a different file |         # Mimic Place and NicePlace definitions being in a different file | ||||||
|         # and the NicePlace model not being imported in at query time. |         # and the NicePlace model not being imported in at query time. | ||||||
|  |         from mongoengine.base import _document_registry | ||||||
|         del(_document_registry['Place.NicePlace']) |         del(_document_registry['Place.NicePlace']) | ||||||
|  |  | ||||||
|         with self.assertRaises(NotRegistered): |         def query_without_importing_nice_place(): | ||||||
|             list(Place.objects.all()) |             print Place.objects.all() | ||||||
|  |         self.assertRaises(NotRegistered, query_without_importing_nice_place) | ||||||
|  |  | ||||||
|     def test_document_registry_regressions(self): |     def test_document_registry_regressions(self): | ||||||
|  |  | ||||||
| @@ -435,15 +447,6 @@ class InstanceTest(unittest.TestCase): | |||||||
|  |  | ||||||
|         person.to_dbref() |         person.to_dbref() | ||||||
|  |  | ||||||
|     def test_save_abstract_document(self): |  | ||||||
|         """Saving an abstract document should fail.""" |  | ||||||
|         class Doc(Document): |  | ||||||
|             name = StringField() |  | ||||||
|             meta = {'abstract': True} |  | ||||||
|  |  | ||||||
|         with self.assertRaises(InvalidDocumentError): |  | ||||||
|             Doc(name='aaa').save() |  | ||||||
|  |  | ||||||
|     def test_reload(self): |     def test_reload(self): | ||||||
|         """Ensure that attributes may be reloaded. |         """Ensure that attributes may be reloaded. | ||||||
|         """ |         """ | ||||||
| @@ -742,7 +745,7 @@ class InstanceTest(unittest.TestCase): | |||||||
|  |  | ||||||
|         try: |         try: | ||||||
|             t.save() |             t.save() | ||||||
|         except ValidationError as e: |         except ValidationError, e: | ||||||
|             expect_msg = "Draft entries may not have a publication date." |             expect_msg = "Draft entries may not have a publication date." | ||||||
|             self.assertTrue(expect_msg in e.message) |             self.assertTrue(expect_msg in e.message) | ||||||
|             self.assertEqual(e.to_dict(), {'__all__': expect_msg}) |             self.assertEqual(e.to_dict(), {'__all__': expect_msg}) | ||||||
| @@ -781,7 +784,7 @@ class InstanceTest(unittest.TestCase): | |||||||
|         t = TestDocument(doc=TestEmbeddedDocument(x=10, y=25, z=15)) |         t = TestDocument(doc=TestEmbeddedDocument(x=10, y=25, z=15)) | ||||||
|         try: |         try: | ||||||
|             t.save() |             t.save() | ||||||
|         except ValidationError as e: |         except ValidationError, e: | ||||||
|             expect_msg = "Value of z != x + y" |             expect_msg = "Value of z != x + y" | ||||||
|             self.assertTrue(expect_msg in e.message) |             self.assertTrue(expect_msg in e.message) | ||||||
|             self.assertEqual(e.to_dict(), {'doc': {'__all__': expect_msg}}) |             self.assertEqual(e.to_dict(), {'doc': {'__all__': expect_msg}}) | ||||||
| @@ -795,10 +798,8 @@ class InstanceTest(unittest.TestCase): | |||||||
|  |  | ||||||
|     def test_modify_empty(self): |     def test_modify_empty(self): | ||||||
|         doc = self.Person(name="bob", age=10).save() |         doc = self.Person(name="bob", age=10).save() | ||||||
|  |         self.assertRaises( | ||||||
|         with self.assertRaises(InvalidDocumentError): |             InvalidDocumentError, lambda: self.Person().modify(set__age=10)) | ||||||
|             self.Person().modify(set__age=10) |  | ||||||
|  |  | ||||||
|         self.assertDbEqual([dict(doc.to_mongo())]) |         self.assertDbEqual([dict(doc.to_mongo())]) | ||||||
|  |  | ||||||
|     def test_modify_invalid_query(self): |     def test_modify_invalid_query(self): | ||||||
| @@ -806,8 +807,9 @@ class InstanceTest(unittest.TestCase): | |||||||
|         doc2 = self.Person(name="jim", age=20).save() |         doc2 = self.Person(name="jim", age=20).save() | ||||||
|         docs = [dict(doc1.to_mongo()), dict(doc2.to_mongo())] |         docs = [dict(doc1.to_mongo()), dict(doc2.to_mongo())] | ||||||
|  |  | ||||||
|         with self.assertRaises(InvalidQueryError): |         self.assertRaises( | ||||||
|             doc1.modify({'id': doc2.id}, set__value=20) |             InvalidQueryError, | ||||||
|  |             lambda: doc1.modify(dict(id=doc2.id), set__value=20)) | ||||||
|  |  | ||||||
|         self.assertDbEqual(docs) |         self.assertDbEqual(docs) | ||||||
|  |  | ||||||
| @@ -816,7 +818,7 @@ class InstanceTest(unittest.TestCase): | |||||||
|         doc2 = self.Person(name="jim", age=20).save() |         doc2 = self.Person(name="jim", age=20).save() | ||||||
|         docs = [dict(doc1.to_mongo()), dict(doc2.to_mongo())] |         docs = [dict(doc1.to_mongo()), dict(doc2.to_mongo())] | ||||||
|  |  | ||||||
|         assert not doc1.modify({'name': doc2.name}, set__age=100) |         assert not doc1.modify(dict(name=doc2.name), set__age=100) | ||||||
|  |  | ||||||
|         self.assertDbEqual(docs) |         self.assertDbEqual(docs) | ||||||
|  |  | ||||||
| @@ -825,7 +827,7 @@ class InstanceTest(unittest.TestCase): | |||||||
|         doc2 = self.Person(id=ObjectId(), name="jim", age=20) |         doc2 = self.Person(id=ObjectId(), name="jim", age=20) | ||||||
|         docs = [dict(doc1.to_mongo())] |         docs = [dict(doc1.to_mongo())] | ||||||
|  |  | ||||||
|         assert not doc2.modify({'name': doc2.name}, set__age=100) |         assert not doc2.modify(dict(name=doc2.name), set__age=100) | ||||||
|  |  | ||||||
|         self.assertDbEqual(docs) |         self.assertDbEqual(docs) | ||||||
|  |  | ||||||
| @@ -1291,11 +1293,12 @@ class InstanceTest(unittest.TestCase): | |||||||
|  |  | ||||||
|     def test_document_update(self): |     def test_document_update(self): | ||||||
|  |  | ||||||
|         # try updating a non-saved document |         def update_not_saved_raises(): | ||||||
|         with self.assertRaises(OperationError): |  | ||||||
|             person = self.Person(name='dcrosta') |             person = self.Person(name='dcrosta') | ||||||
|             person.update(set__name='Dan Crosta') |             person.update(set__name='Dan Crosta') | ||||||
|  |  | ||||||
|  |         self.assertRaises(OperationError, update_not_saved_raises) | ||||||
|  |  | ||||||
|         author = self.Person(name='dcrosta') |         author = self.Person(name='dcrosta') | ||||||
|         author.save() |         author.save() | ||||||
|  |  | ||||||
| @@ -1305,17 +1308,19 @@ class InstanceTest(unittest.TestCase): | |||||||
|         p1 = self.Person.objects.first() |         p1 = self.Person.objects.first() | ||||||
|         self.assertEqual(p1.name, author.name) |         self.assertEqual(p1.name, author.name) | ||||||
|  |  | ||||||
|         # try sending an empty update |         def update_no_value_raises(): | ||||||
|         with self.assertRaises(OperationError): |  | ||||||
|             person = self.Person.objects.first() |             person = self.Person.objects.first() | ||||||
|             person.update() |             person.update() | ||||||
|  |  | ||||||
|         # update that doesn't explicitly specify an operator should default |         self.assertRaises(OperationError, update_no_value_raises) | ||||||
|         # to 'set__' |  | ||||||
|         person = self.Person.objects.first() |         def update_no_op_should_default_to_set(): | ||||||
|         person.update(name="Dan") |             person = self.Person.objects.first() | ||||||
|         person.reload() |             person.update(name="Dan") | ||||||
|         self.assertEqual("Dan", person.name) |             person.reload() | ||||||
|  |             return person.name | ||||||
|  |  | ||||||
|  |         self.assertEqual("Dan", update_no_op_should_default_to_set()) | ||||||
|  |  | ||||||
|     def test_update_unique_field(self): |     def test_update_unique_field(self): | ||||||
|         class Doc(Document): |         class Doc(Document): | ||||||
| @@ -1324,8 +1329,8 @@ class InstanceTest(unittest.TestCase): | |||||||
|         doc1 = Doc(name="first").save() |         doc1 = Doc(name="first").save() | ||||||
|         doc2 = Doc(name="second").save() |         doc2 = Doc(name="second").save() | ||||||
|  |  | ||||||
|         with self.assertRaises(NotUniqueError): |         self.assertRaises(NotUniqueError, lambda: | ||||||
|             doc2.update(set__name=doc1.name) |                           doc2.update(set__name=doc1.name)) | ||||||
|  |  | ||||||
|     def test_embedded_update(self): |     def test_embedded_update(self): | ||||||
|         """ |         """ | ||||||
| @@ -1843,13 +1848,15 @@ class InstanceTest(unittest.TestCase): | |||||||
|  |  | ||||||
|     def test_duplicate_db_fields_raise_invalid_document_error(self): |     def test_duplicate_db_fields_raise_invalid_document_error(self): | ||||||
|         """Ensure a InvalidDocumentError is thrown if duplicate fields |         """Ensure a InvalidDocumentError is thrown if duplicate fields | ||||||
|         declare the same db_field. |         declare the same db_field""" | ||||||
|         """ |  | ||||||
|         with self.assertRaises(InvalidDocumentError): |         def throw_invalid_document_error(): | ||||||
|             class Foo(Document): |             class Foo(Document): | ||||||
|                 name = StringField() |                 name = StringField() | ||||||
|                 name2 = StringField(db_field='name') |                 name2 = StringField(db_field='name') | ||||||
|  |  | ||||||
|  |         self.assertRaises(InvalidDocumentError, throw_invalid_document_error) | ||||||
|  |  | ||||||
|     def test_invalid_son(self): |     def test_invalid_son(self): | ||||||
|         """Raise an error if loading invalid data""" |         """Raise an error if loading invalid data""" | ||||||
|         class Occurrence(EmbeddedDocument): |         class Occurrence(EmbeddedDocument): | ||||||
| @@ -1861,17 +1868,11 @@ class InstanceTest(unittest.TestCase): | |||||||
|             forms = ListField(StringField(), default=list) |             forms = ListField(StringField(), default=list) | ||||||
|             occurs = ListField(EmbeddedDocumentField(Occurrence), default=list) |             occurs = ListField(EmbeddedDocumentField(Occurrence), default=list) | ||||||
|  |  | ||||||
|         with self.assertRaises(InvalidDocumentError): |         def raise_invalid_document(): | ||||||
|             Word._from_son({ |             Word._from_son({'stem': [1, 2, 3], 'forms': 1, 'count': 'one', | ||||||
|                 'stem': [1, 2, 3], |                             'occurs': {"hello": None}}) | ||||||
|                 'forms': 1, |  | ||||||
|                 'count': 'one', |  | ||||||
|                 'occurs': {"hello": None} |  | ||||||
|             }) |  | ||||||
|  |  | ||||||
|         # Tests for issue #1438: https://github.com/MongoEngine/mongoengine/issues/1438 |         self.assertRaises(InvalidDocumentError, raise_invalid_document) | ||||||
|         with self.assertRaises(ValueError): |  | ||||||
|             Word._from_son('this is not a valid SON dict') |  | ||||||
|  |  | ||||||
|     def test_reverse_delete_rule_cascade_and_nullify(self): |     def test_reverse_delete_rule_cascade_and_nullify(self): | ||||||
|         """Ensure that a referenced document is also deleted upon deletion. |         """Ensure that a referenced document is also deleted upon deletion. | ||||||
| @@ -2102,7 +2103,8 @@ class InstanceTest(unittest.TestCase): | |||||||
|         self.assertEqual(Bar.objects.get().foo, None) |         self.assertEqual(Bar.objects.get().foo, None) | ||||||
|  |  | ||||||
|     def test_invalid_reverse_delete_rule_raise_errors(self): |     def test_invalid_reverse_delete_rule_raise_errors(self): | ||||||
|         with self.assertRaises(InvalidDocumentError): |  | ||||||
|  |         def throw_invalid_document_error(): | ||||||
|             class Blog(Document): |             class Blog(Document): | ||||||
|                 content = StringField() |                 content = StringField() | ||||||
|                 authors = MapField(ReferenceField( |                 authors = MapField(ReferenceField( | ||||||
| @@ -2112,15 +2114,21 @@ class InstanceTest(unittest.TestCase): | |||||||
|                         self.Person, |                         self.Person, | ||||||
|                         reverse_delete_rule=NULLIFY)) |                         reverse_delete_rule=NULLIFY)) | ||||||
|  |  | ||||||
|         with self.assertRaises(InvalidDocumentError): |         self.assertRaises(InvalidDocumentError, throw_invalid_document_error) | ||||||
|  |  | ||||||
|  |         def throw_invalid_document_error_embedded(): | ||||||
|             class Parents(EmbeddedDocument): |             class Parents(EmbeddedDocument): | ||||||
|                 father = ReferenceField('Person', reverse_delete_rule=DENY) |                 father = ReferenceField('Person', reverse_delete_rule=DENY) | ||||||
|                 mother = ReferenceField('Person', reverse_delete_rule=DENY) |                 mother = ReferenceField('Person', reverse_delete_rule=DENY) | ||||||
|  |  | ||||||
|  |         self.assertRaises( | ||||||
|  |             InvalidDocumentError, throw_invalid_document_error_embedded) | ||||||
|  |  | ||||||
|     def test_reverse_delete_rule_cascade_recurs(self): |     def test_reverse_delete_rule_cascade_recurs(self): | ||||||
|         """Ensure that a chain of documents is also deleted upon cascaded |         """Ensure that a chain of documents is also deleted upon cascaded | ||||||
|         deletion. |         deletion. | ||||||
|         """ |         """ | ||||||
|  |  | ||||||
|         class BlogPost(Document): |         class BlogPost(Document): | ||||||
|             content = StringField() |             content = StringField() | ||||||
|             author = ReferenceField(self.Person, reverse_delete_rule=CASCADE) |             author = ReferenceField(self.Person, reverse_delete_rule=CASCADE) | ||||||
| @@ -2336,14 +2344,15 @@ class InstanceTest(unittest.TestCase): | |||||||
|         pickle_doc.save() |         pickle_doc.save() | ||||||
|         pickle_doc.delete() |         pickle_doc.delete() | ||||||
|  |  | ||||||
|     def test_override_method_with_field(self): |     def test_throw_invalid_document_error(self): | ||||||
|         """Test creating a field with a field name that would override |  | ||||||
|         the "validate" method. |         # test handles people trying to upsert | ||||||
|         """ |         def throw_invalid_document_error(): | ||||||
|         with self.assertRaises(InvalidDocumentError): |  | ||||||
|             class Blog(Document): |             class Blog(Document): | ||||||
|                 validate = DictField() |                 validate = DictField() | ||||||
|  |  | ||||||
|  |         self.assertRaises(InvalidDocumentError, throw_invalid_document_error) | ||||||
|  |  | ||||||
|     def test_mutating_documents(self): |     def test_mutating_documents(self): | ||||||
|  |  | ||||||
|         class B(EmbeddedDocument): |         class B(EmbeddedDocument): | ||||||
| @@ -2806,10 +2815,11 @@ class InstanceTest(unittest.TestCase): | |||||||
|         log.log = "Saving" |         log.log = "Saving" | ||||||
|         log.save() |         log.save() | ||||||
|  |  | ||||||
|         # try to change the shard key |         def change_shard_key(): | ||||||
|         with self.assertRaises(OperationError): |  | ||||||
|             log.machine = "127.0.0.1" |             log.machine = "127.0.0.1" | ||||||
|  |  | ||||||
|  |         self.assertRaises(OperationError, change_shard_key) | ||||||
|  |  | ||||||
|     def test_shard_key_in_embedded_document(self): |     def test_shard_key_in_embedded_document(self): | ||||||
|         class Foo(EmbeddedDocument): |         class Foo(EmbeddedDocument): | ||||||
|             foo = StringField() |             foo = StringField() | ||||||
| @@ -2830,11 +2840,12 @@ class InstanceTest(unittest.TestCase): | |||||||
|         bar_doc.bar = 'baz' |         bar_doc.bar = 'baz' | ||||||
|         bar_doc.save() |         bar_doc.save() | ||||||
|  |  | ||||||
|         # try to change the shard key |         def change_shard_key(): | ||||||
|         with self.assertRaises(OperationError): |  | ||||||
|             bar_doc.foo.foo = 'something' |             bar_doc.foo.foo = 'something' | ||||||
|             bar_doc.save() |             bar_doc.save() | ||||||
|  |  | ||||||
|  |         self.assertRaises(OperationError, change_shard_key) | ||||||
|  |  | ||||||
|     def test_shard_key_primary(self): |     def test_shard_key_primary(self): | ||||||
|         class LogEntry(Document): |         class LogEntry(Document): | ||||||
|             machine = StringField(primary_key=True) |             machine = StringField(primary_key=True) | ||||||
| @@ -2855,10 +2866,11 @@ class InstanceTest(unittest.TestCase): | |||||||
|         log.log = "Saving" |         log.log = "Saving" | ||||||
|         log.save() |         log.save() | ||||||
|  |  | ||||||
|         # try to change the shard key |         def change_shard_key(): | ||||||
|         with self.assertRaises(OperationError): |  | ||||||
|             log.machine = "127.0.0.1" |             log.machine = "127.0.0.1" | ||||||
|  |  | ||||||
|  |         self.assertRaises(OperationError, change_shard_key) | ||||||
|  |  | ||||||
|     def test_kwargs_simple(self): |     def test_kwargs_simple(self): | ||||||
|  |  | ||||||
|         class Embedded(EmbeddedDocument): |         class Embedded(EmbeddedDocument): | ||||||
| @@ -2943,9 +2955,11 @@ class InstanceTest(unittest.TestCase): | |||||||
|     def test_bad_mixed_creation(self): |     def test_bad_mixed_creation(self): | ||||||
|         """Ensure that document gives correct error when duplicating arguments |         """Ensure that document gives correct error when duplicating arguments | ||||||
|         """ |         """ | ||||||
|         with self.assertRaises(TypeError): |         def construct_bad_instance(): | ||||||
|             return self.Person("Test User", 42, name="Bad User") |             return self.Person("Test User", 42, name="Bad User") | ||||||
|  |  | ||||||
|  |         self.assertRaises(TypeError, construct_bad_instance) | ||||||
|  |  | ||||||
|     def test_data_contains_id_field(self): |     def test_data_contains_id_field(self): | ||||||
|         """Ensure that asking for _data returns 'id' |         """Ensure that asking for _data returns 'id' | ||||||
|         """ |         """ | ||||||
| @@ -3104,17 +3118,17 @@ class InstanceTest(unittest.TestCase): | |||||||
|         p4 = Person.objects()[0] |         p4 = Person.objects()[0] | ||||||
|         p4.save() |         p4.save() | ||||||
|         self.assertEquals(p4.height, 189) |         self.assertEquals(p4.height, 189) | ||||||
|  |          | ||||||
|         # However the default will not be fixed in DB |         # However the default will not be fixed in DB | ||||||
|         self.assertEquals(Person.objects(height=189).count(), 0) |         self.assertEquals(Person.objects(height=189).count(), 0) | ||||||
|  |          | ||||||
|         # alter DB for the new default |         # alter DB for the new default | ||||||
|         coll = Person._get_collection() |         coll = Person._get_collection() | ||||||
|         for person in Person.objects.as_pymongo(): |         for person in Person.objects.as_pymongo(): | ||||||
|             if 'height' not in person: |             if 'height' not in person: | ||||||
|                 person['height'] = 189 |                 person['height'] = 189 | ||||||
|                 coll.save(person) |                 coll.save(person) | ||||||
|  |                  | ||||||
|         self.assertEquals(Person.objects(height=189).count(), 1) |         self.assertEquals(Person.objects(height=189).count(), 1) | ||||||
|  |  | ||||||
|     def test_from_son(self): |     def test_from_son(self): | ||||||
|   | |||||||
| @@ -1,3 +1,6 @@ | |||||||
|  | import sys | ||||||
|  | sys.path[0:0] = [""] | ||||||
|  |  | ||||||
| import unittest | import unittest | ||||||
| import uuid | import uuid | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1,4 +1,7 @@ | |||||||
| # -*- coding: utf-8 -*- | # -*- coding: utf-8 -*- | ||||||
|  | import sys | ||||||
|  | sys.path[0:0] = [""] | ||||||
|  |  | ||||||
| import unittest | import unittest | ||||||
| from datetime import datetime | from datetime import datetime | ||||||
|  |  | ||||||
| @@ -57,7 +60,7 @@ class ValidatorErrorTest(unittest.TestCase): | |||||||
|  |  | ||||||
|         try: |         try: | ||||||
|             User().validate() |             User().validate() | ||||||
|         except ValidationError as e: |         except ValidationError, e: | ||||||
|             self.assertTrue("User:None" in e.message) |             self.assertTrue("User:None" in e.message) | ||||||
|             self.assertEqual(e.to_dict(), { |             self.assertEqual(e.to_dict(), { | ||||||
|                 'username': 'Field is required', |                 'username': 'Field is required', | ||||||
| @@ -67,7 +70,7 @@ class ValidatorErrorTest(unittest.TestCase): | |||||||
|         user.name = None |         user.name = None | ||||||
|         try: |         try: | ||||||
|             user.save() |             user.save() | ||||||
|         except ValidationError as e: |         except ValidationError, e: | ||||||
|             self.assertTrue("User:RossC0" in e.message) |             self.assertTrue("User:RossC0" in e.message) | ||||||
|             self.assertEqual(e.to_dict(), { |             self.assertEqual(e.to_dict(), { | ||||||
|                 'name': 'Field is required'}) |                 'name': 'Field is required'}) | ||||||
| @@ -115,7 +118,7 @@ class ValidatorErrorTest(unittest.TestCase): | |||||||
|  |  | ||||||
|         try: |         try: | ||||||
|             Doc(id="bad").validate() |             Doc(id="bad").validate() | ||||||
|         except ValidationError as e: |         except ValidationError, e: | ||||||
|             self.assertTrue("SubDoc:None" in e.message) |             self.assertTrue("SubDoc:None" in e.message) | ||||||
|             self.assertEqual(e.to_dict(), { |             self.assertEqual(e.to_dict(), { | ||||||
|                 "e": {'val': 'OK could not be converted to int'}}) |                 "e": {'val': 'OK could not be converted to int'}}) | ||||||
| @@ -133,7 +136,7 @@ class ValidatorErrorTest(unittest.TestCase): | |||||||
|         doc.e.val = "OK" |         doc.e.val = "OK" | ||||||
|         try: |         try: | ||||||
|             doc.save() |             doc.save() | ||||||
|         except ValidationError as e: |         except ValidationError, e: | ||||||
|             self.assertTrue("Doc:test" in e.message) |             self.assertTrue("Doc:test" in e.message) | ||||||
|             self.assertEqual(e.to_dict(), { |             self.assertEqual(e.to_dict(), { | ||||||
|                 "e": {'val': 'OK could not be converted to int'}}) |                 "e": {'val': 'OK could not be converted to int'}}) | ||||||
| @@ -153,14 +156,14 @@ class ValidatorErrorTest(unittest.TestCase): | |||||||
|  |  | ||||||
|         s = SubDoc() |         s = SubDoc() | ||||||
|  |  | ||||||
|         self.assertRaises(ValidationError, s.validate) |         self.assertRaises(ValidationError, lambda: s.validate()) | ||||||
|  |  | ||||||
|         d1.e = s |         d1.e = s | ||||||
|         d2.e = s |         d2.e = s | ||||||
|  |  | ||||||
|         del d1 |         del d1 | ||||||
|  |  | ||||||
|         self.assertRaises(ValidationError, d2.validate) |         self.assertRaises(ValidationError, lambda: d2.validate()) | ||||||
|  |  | ||||||
|     def test_parent_reference_in_child_document(self): |     def test_parent_reference_in_child_document(self): | ||||||
|         """ |         """ | ||||||
|   | |||||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							| @@ -1,16 +1,18 @@ | |||||||
| # -*- coding: utf-8 -*- | # -*- coding: utf-8 -*- | ||||||
|  | import sys | ||||||
|  | sys.path[0:0] = [""] | ||||||
|  |  | ||||||
| import copy | import copy | ||||||
| import os | import os | ||||||
| import unittest | import unittest | ||||||
| import tempfile | import tempfile | ||||||
|  |  | ||||||
| import gridfs | import gridfs | ||||||
| import six |  | ||||||
|  |  | ||||||
| from nose.plugins.skip import SkipTest | from nose.plugins.skip import SkipTest | ||||||
| from mongoengine import * | from mongoengine import * | ||||||
| from mongoengine.connection import get_db | from mongoengine.connection import get_db | ||||||
| from mongoengine.python_support import StringIO | from mongoengine.python_support import b, StringIO | ||||||
|  |  | ||||||
| try: | try: | ||||||
|     from PIL import Image |     from PIL import Image | ||||||
| @@ -47,7 +49,7 @@ class FileTest(unittest.TestCase): | |||||||
|  |  | ||||||
|         PutFile.drop_collection() |         PutFile.drop_collection() | ||||||
|  |  | ||||||
|         text = six.b('Hello, World!') |         text = b('Hello, World!') | ||||||
|         content_type = 'text/plain' |         content_type = 'text/plain' | ||||||
|  |  | ||||||
|         putfile = PutFile() |         putfile = PutFile() | ||||||
| @@ -86,8 +88,8 @@ class FileTest(unittest.TestCase): | |||||||
|  |  | ||||||
|         StreamFile.drop_collection() |         StreamFile.drop_collection() | ||||||
|  |  | ||||||
|         text = six.b('Hello, World!') |         text = b('Hello, World!') | ||||||
|         more_text = six.b('Foo Bar') |         more_text = b('Foo Bar') | ||||||
|         content_type = 'text/plain' |         content_type = 'text/plain' | ||||||
|  |  | ||||||
|         streamfile = StreamFile() |         streamfile = StreamFile() | ||||||
| @@ -121,8 +123,8 @@ class FileTest(unittest.TestCase): | |||||||
|  |  | ||||||
|         StreamFile.drop_collection() |         StreamFile.drop_collection() | ||||||
|  |  | ||||||
|         text = six.b('Hello, World!') |         text = b('Hello, World!') | ||||||
|         more_text = six.b('Foo Bar') |         more_text = b('Foo Bar') | ||||||
|         content_type = 'text/plain' |         content_type = 'text/plain' | ||||||
|  |  | ||||||
|         streamfile = StreamFile() |         streamfile = StreamFile() | ||||||
| @@ -153,8 +155,8 @@ class FileTest(unittest.TestCase): | |||||||
|         class SetFile(Document): |         class SetFile(Document): | ||||||
|             the_file = FileField() |             the_file = FileField() | ||||||
|  |  | ||||||
|         text = six.b('Hello, World!') |         text = b('Hello, World!') | ||||||
|         more_text = six.b('Foo Bar') |         more_text = b('Foo Bar') | ||||||
|  |  | ||||||
|         SetFile.drop_collection() |         SetFile.drop_collection() | ||||||
|  |  | ||||||
| @@ -183,7 +185,7 @@ class FileTest(unittest.TestCase): | |||||||
|         GridDocument.drop_collection() |         GridDocument.drop_collection() | ||||||
|  |  | ||||||
|         with tempfile.TemporaryFile() as f: |         with tempfile.TemporaryFile() as f: | ||||||
|             f.write(six.b("Hello World!")) |             f.write(b("Hello World!")) | ||||||
|             f.flush() |             f.flush() | ||||||
|  |  | ||||||
|             # Test without default |             # Test without default | ||||||
| @@ -200,7 +202,7 @@ class FileTest(unittest.TestCase): | |||||||
|             self.assertEqual(doc_b.the_file.grid_id, doc_c.the_file.grid_id) |             self.assertEqual(doc_b.the_file.grid_id, doc_c.the_file.grid_id) | ||||||
|  |  | ||||||
|             # Test with default |             # Test with default | ||||||
|             doc_d = GridDocument(the_file=six.b('')) |             doc_d = GridDocument(the_file=b('')) | ||||||
|             doc_d.save() |             doc_d.save() | ||||||
|  |  | ||||||
|             doc_e = GridDocument.objects.with_id(doc_d.id) |             doc_e = GridDocument.objects.with_id(doc_d.id) | ||||||
| @@ -226,7 +228,7 @@ class FileTest(unittest.TestCase): | |||||||
|         # First instance |         # First instance | ||||||
|         test_file = TestFile() |         test_file = TestFile() | ||||||
|         test_file.name = "Hello, World!" |         test_file.name = "Hello, World!" | ||||||
|         test_file.the_file.put(six.b('Hello, World!')) |         test_file.the_file.put(b('Hello, World!')) | ||||||
|         test_file.save() |         test_file.save() | ||||||
|  |  | ||||||
|         # Second instance |         # Second instance | ||||||
| @@ -280,7 +282,7 @@ class FileTest(unittest.TestCase): | |||||||
|  |  | ||||||
|         test_file = TestFile() |         test_file = TestFile() | ||||||
|         self.assertFalse(bool(test_file.the_file)) |         self.assertFalse(bool(test_file.the_file)) | ||||||
|         test_file.the_file.put(six.b('Hello, World!'), content_type='text/plain') |         test_file.the_file.put(b('Hello, World!'), content_type='text/plain') | ||||||
|         test_file.save() |         test_file.save() | ||||||
|         self.assertTrue(bool(test_file.the_file)) |         self.assertTrue(bool(test_file.the_file)) | ||||||
|  |  | ||||||
| @@ -295,66 +297,66 @@ class FileTest(unittest.TestCase): | |||||||
|         test_file = TestFile() |         test_file = TestFile() | ||||||
|         self.assertFalse(test_file.the_file in [{"test": 1}]) |         self.assertFalse(test_file.the_file in [{"test": 1}]) | ||||||
|  |  | ||||||
|     def test_file_disk_space(self): |     def test_file_disk_space(self):  | ||||||
|         """ Test disk space usage when we delete/replace a file """ |         """ Test disk space usage when we delete/replace a file """  | ||||||
|         class TestFile(Document): |         class TestFile(Document): | ||||||
|             the_file = FileField() |             the_file = FileField() | ||||||
|  |              | ||||||
|         text = six.b('Hello, World!') |         text = b('Hello, World!') | ||||||
|         content_type = 'text/plain' |         content_type = 'text/plain' | ||||||
|  |  | ||||||
|         testfile = TestFile() |         testfile = TestFile() | ||||||
|         testfile.the_file.put(text, content_type=content_type, filename="hello") |         testfile.the_file.put(text, content_type=content_type, filename="hello") | ||||||
|         testfile.save() |         testfile.save() | ||||||
|  |          | ||||||
|         # Now check fs.files and fs.chunks |         # Now check fs.files and fs.chunks  | ||||||
|         db = TestFile._get_db() |         db = TestFile._get_db() | ||||||
|  |          | ||||||
|         files = db.fs.files.find() |         files = db.fs.files.find() | ||||||
|         chunks = db.fs.chunks.find() |         chunks = db.fs.chunks.find() | ||||||
|         self.assertEquals(len(list(files)), 1) |         self.assertEquals(len(list(files)), 1) | ||||||
|         self.assertEquals(len(list(chunks)), 1) |         self.assertEquals(len(list(chunks)), 1) | ||||||
|  |  | ||||||
|         # Deleting the docoument should delete the files |         # Deleting the docoument should delete the files  | ||||||
|         testfile.delete() |         testfile.delete() | ||||||
|  |          | ||||||
|         files = db.fs.files.find() |         files = db.fs.files.find() | ||||||
|         chunks = db.fs.chunks.find() |         chunks = db.fs.chunks.find() | ||||||
|         self.assertEquals(len(list(files)), 0) |         self.assertEquals(len(list(files)), 0) | ||||||
|         self.assertEquals(len(list(chunks)), 0) |         self.assertEquals(len(list(chunks)), 0) | ||||||
|  |          | ||||||
|         # Test case where we don't store a file in the first place |         # Test case where we don't store a file in the first place  | ||||||
|         testfile = TestFile() |         testfile = TestFile() | ||||||
|         testfile.save() |         testfile.save() | ||||||
|  |          | ||||||
|         files = db.fs.files.find() |         files = db.fs.files.find() | ||||||
|         chunks = db.fs.chunks.find() |         chunks = db.fs.chunks.find() | ||||||
|         self.assertEquals(len(list(files)), 0) |         self.assertEquals(len(list(files)), 0) | ||||||
|         self.assertEquals(len(list(chunks)), 0) |         self.assertEquals(len(list(chunks)), 0) | ||||||
|  |          | ||||||
|         testfile.delete() |         testfile.delete() | ||||||
|  |          | ||||||
|         files = db.fs.files.find() |         files = db.fs.files.find() | ||||||
|         chunks = db.fs.chunks.find() |         chunks = db.fs.chunks.find() | ||||||
|         self.assertEquals(len(list(files)), 0) |         self.assertEquals(len(list(files)), 0) | ||||||
|         self.assertEquals(len(list(chunks)), 0) |         self.assertEquals(len(list(chunks)), 0) | ||||||
|  |          | ||||||
|         # Test case where we overwrite the file |         # Test case where we overwrite the file  | ||||||
|         testfile = TestFile() |         testfile = TestFile() | ||||||
|         testfile.the_file.put(text, content_type=content_type, filename="hello") |         testfile.the_file.put(text, content_type=content_type, filename="hello") | ||||||
|         testfile.save() |         testfile.save() | ||||||
|  |          | ||||||
|         text = six.b('Bonjour, World!') |         text = b('Bonjour, World!') | ||||||
|         testfile.the_file.replace(text, content_type=content_type, filename="hello") |         testfile.the_file.replace(text, content_type=content_type, filename="hello") | ||||||
|         testfile.save() |         testfile.save() | ||||||
|  |          | ||||||
|         files = db.fs.files.find() |         files = db.fs.files.find() | ||||||
|         chunks = db.fs.chunks.find() |         chunks = db.fs.chunks.find() | ||||||
|         self.assertEquals(len(list(files)), 1) |         self.assertEquals(len(list(files)), 1) | ||||||
|         self.assertEquals(len(list(chunks)), 1) |         self.assertEquals(len(list(chunks)), 1) | ||||||
|  |          | ||||||
|         testfile.delete() |         testfile.delete() | ||||||
|  |          | ||||||
|         files = db.fs.files.find() |         files = db.fs.files.find() | ||||||
|         chunks = db.fs.chunks.find() |         chunks = db.fs.chunks.find() | ||||||
|         self.assertEquals(len(list(files)), 0) |         self.assertEquals(len(list(files)), 0) | ||||||
| @@ -370,14 +372,14 @@ class FileTest(unittest.TestCase): | |||||||
|         TestImage.drop_collection() |         TestImage.drop_collection() | ||||||
|  |  | ||||||
|         with tempfile.TemporaryFile() as f: |         with tempfile.TemporaryFile() as f: | ||||||
|             f.write(six.b("Hello World!")) |             f.write(b("Hello World!")) | ||||||
|             f.flush() |             f.flush() | ||||||
|  |  | ||||||
|             t = TestImage() |             t = TestImage() | ||||||
|             try: |             try: | ||||||
|                 t.image.put(f) |                 t.image.put(f) | ||||||
|                 self.fail("Should have raised an invalidation error") |                 self.fail("Should have raised an invalidation error") | ||||||
|             except ValidationError as e: |             except ValidationError, e: | ||||||
|                 self.assertEqual("%s" % e, "Invalid image: cannot identify image file %s" % f) |                 self.assertEqual("%s" % e, "Invalid image: cannot identify image file %s" % f) | ||||||
|  |  | ||||||
|         t = TestImage() |         t = TestImage() | ||||||
| @@ -494,7 +496,7 @@ class FileTest(unittest.TestCase): | |||||||
|         # First instance |         # First instance | ||||||
|         test_file = TestFile() |         test_file = TestFile() | ||||||
|         test_file.name = "Hello, World!" |         test_file.name = "Hello, World!" | ||||||
|         test_file.the_file.put(six.b('Hello, World!'), |         test_file.the_file.put(b('Hello, World!'), | ||||||
|                           name="hello.txt") |                           name="hello.txt") | ||||||
|         test_file.save() |         test_file.save() | ||||||
|  |  | ||||||
| @@ -502,15 +504,16 @@ class FileTest(unittest.TestCase): | |||||||
|         self.assertEqual(data.get('name'), 'hello.txt') |         self.assertEqual(data.get('name'), 'hello.txt') | ||||||
|  |  | ||||||
|         test_file = TestFile.objects.first() |         test_file = TestFile.objects.first() | ||||||
|         self.assertEqual(test_file.the_file.read(), six.b('Hello, World!')) |         self.assertEqual(test_file.the_file.read(), | ||||||
|  |                           b('Hello, World!')) | ||||||
|  |  | ||||||
|         test_file = TestFile.objects.first() |         test_file = TestFile.objects.first() | ||||||
|         test_file.the_file = six.b('HELLO, WORLD!') |         test_file.the_file = b('HELLO, WORLD!') | ||||||
|         test_file.save() |         test_file.save() | ||||||
|  |  | ||||||
|         test_file = TestFile.objects.first() |         test_file = TestFile.objects.first() | ||||||
|         self.assertEqual(test_file.the_file.read(), |         self.assertEqual(test_file.the_file.read(), | ||||||
|                          six.b('HELLO, WORLD!')) |                           b('HELLO, WORLD!')) | ||||||
|  |  | ||||||
|     def test_copyable(self): |     def test_copyable(self): | ||||||
|         class PutFile(Document): |         class PutFile(Document): | ||||||
| @@ -518,7 +521,7 @@ class FileTest(unittest.TestCase): | |||||||
|  |  | ||||||
|         PutFile.drop_collection() |         PutFile.drop_collection() | ||||||
|  |  | ||||||
|         text = six.b('Hello, World!') |         text = b('Hello, World!') | ||||||
|         content_type = 'text/plain' |         content_type = 'text/plain' | ||||||
|  |  | ||||||
|         putfile = PutFile() |         putfile = PutFile() | ||||||
|   | |||||||
| @@ -1,4 +1,7 @@ | |||||||
| # -*- coding: utf-8 -*- | # -*- coding: utf-8 -*- | ||||||
|  | import sys | ||||||
|  | sys.path[0:0] = [""] | ||||||
|  |  | ||||||
| import unittest | import unittest | ||||||
|  |  | ||||||
| from mongoengine import * | from mongoengine import * | ||||||
|   | |||||||
							
								
								
									
										11
									
								
								tests/migration/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										11
									
								
								tests/migration/__init__.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,11 @@ | |||||||
|  | import unittest | ||||||
|  |  | ||||||
|  | from convert_to_new_inheritance_model import * | ||||||
|  | from decimalfield_as_float import * | ||||||
|  | from referencefield_dbref_to_object_id import * | ||||||
|  | from turn_off_inheritance import * | ||||||
|  | from uuidfield_to_binary import * | ||||||
|  |  | ||||||
|  |  | ||||||
|  | if __name__ == '__main__': | ||||||
|  |     unittest.main() | ||||||
							
								
								
									
										51
									
								
								tests/migration/convert_to_new_inheritance_model.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										51
									
								
								tests/migration/convert_to_new_inheritance_model.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,51 @@ | |||||||
|  | # -*- coding: utf-8 -*- | ||||||
|  | import unittest | ||||||
|  |  | ||||||
|  | from mongoengine import Document, connect | ||||||
|  | from mongoengine.connection import get_db | ||||||
|  | from mongoengine.fields import StringField | ||||||
|  |  | ||||||
|  | __all__ = ('ConvertToNewInheritanceModel', ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class ConvertToNewInheritanceModel(unittest.TestCase): | ||||||
|  |  | ||||||
|  |     def setUp(self): | ||||||
|  |         connect(db='mongoenginetest') | ||||||
|  |         self.db = get_db() | ||||||
|  |  | ||||||
|  |     def tearDown(self): | ||||||
|  |         for collection in self.db.collection_names(): | ||||||
|  |             if 'system.' in collection: | ||||||
|  |                 continue | ||||||
|  |             self.db.drop_collection(collection) | ||||||
|  |  | ||||||
|  |     def test_how_to_convert_to_the_new_inheritance_model(self): | ||||||
|  |         """Demonstrates migrating from 0.7 to 0.8 | ||||||
|  |         """ | ||||||
|  |  | ||||||
|  |         # 1. Declaration of the class | ||||||
|  |         class Animal(Document): | ||||||
|  |             name = StringField() | ||||||
|  |             meta = { | ||||||
|  |                 'allow_inheritance': True, | ||||||
|  |                 'indexes': ['name'] | ||||||
|  |             } | ||||||
|  |  | ||||||
|  |         # 2. Remove _types | ||||||
|  |         collection = Animal._get_collection() | ||||||
|  |         collection.update({}, {"$unset": {"_types": 1}}, multi=True) | ||||||
|  |  | ||||||
|  |         # 3. Confirm extra data is removed | ||||||
|  |         count = collection.find({'_types': {"$exists": True}}).count() | ||||||
|  |         self.assertEqual(0, count) | ||||||
|  |  | ||||||
|  |         # 4. Remove indexes | ||||||
|  |         info = collection.index_information() | ||||||
|  |         indexes_to_drop = [key for key, value in info.iteritems() | ||||||
|  |                            if '_types' in dict(value['key'])] | ||||||
|  |         for index in indexes_to_drop: | ||||||
|  |             collection.drop_index(index) | ||||||
|  |  | ||||||
|  |         # 5. Recreate indexes | ||||||
|  |         Animal.ensure_indexes() | ||||||
							
								
								
									
										50
									
								
								tests/migration/decimalfield_as_float.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										50
									
								
								tests/migration/decimalfield_as_float.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,50 @@ | |||||||
|  |  # -*- coding: utf-8 -*- | ||||||
|  | import unittest | ||||||
|  | import decimal | ||||||
|  | from decimal import Decimal | ||||||
|  |  | ||||||
|  | from mongoengine import Document, connect | ||||||
|  | from mongoengine.connection import get_db | ||||||
|  | from mongoengine.fields import StringField, DecimalField, ListField | ||||||
|  |  | ||||||
|  | __all__ = ('ConvertDecimalField', ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class ConvertDecimalField(unittest.TestCase): | ||||||
|  |  | ||||||
|  |     def setUp(self): | ||||||
|  |         connect(db='mongoenginetest') | ||||||
|  |         self.db = get_db() | ||||||
|  |  | ||||||
|  |     def test_how_to_convert_decimal_fields(self): | ||||||
|  |         """Demonstrates migrating from 0.7 to 0.8 | ||||||
|  |         """ | ||||||
|  |  | ||||||
|  |         # 1. Old definition - using dbrefs | ||||||
|  |         class Person(Document): | ||||||
|  |             name = StringField() | ||||||
|  |             money = DecimalField(force_string=True) | ||||||
|  |             monies = ListField(DecimalField(force_string=True)) | ||||||
|  |  | ||||||
|  |         Person.drop_collection() | ||||||
|  |         Person(name="Wilson Jr", money=Decimal("2.50"), | ||||||
|  |                monies=[Decimal("2.10"), Decimal("5.00")]).save() | ||||||
|  |  | ||||||
|  |         # 2. Start the migration by changing the schema | ||||||
|  |         # Change DecimalField - add precision and rounding settings | ||||||
|  |         class Person(Document): | ||||||
|  |             name = StringField() | ||||||
|  |             money = DecimalField(precision=2, rounding=decimal.ROUND_HALF_UP) | ||||||
|  |             monies = ListField(DecimalField(precision=2, | ||||||
|  |                                             rounding=decimal.ROUND_HALF_UP)) | ||||||
|  |  | ||||||
|  |         # 3. Loop all the objects and mark parent as changed | ||||||
|  |         for p in Person.objects: | ||||||
|  |             p._mark_as_changed('money') | ||||||
|  |             p._mark_as_changed('monies') | ||||||
|  |             p.save() | ||||||
|  |  | ||||||
|  |         # 4. Confirmation of the fix! | ||||||
|  |         wilson = Person.objects(name="Wilson Jr").as_pymongo()[0] | ||||||
|  |         self.assertTrue(isinstance(wilson['money'], float)) | ||||||
|  |         self.assertTrue(all([isinstance(m, float) for m in wilson['monies']])) | ||||||
							
								
								
									
										52
									
								
								tests/migration/referencefield_dbref_to_object_id.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										52
									
								
								tests/migration/referencefield_dbref_to_object_id.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,52 @@ | |||||||
|  | # -*- coding: utf-8 -*- | ||||||
|  | import unittest | ||||||
|  |  | ||||||
|  | from mongoengine import Document, connect | ||||||
|  | from mongoengine.connection import get_db | ||||||
|  | from mongoengine.fields import StringField, ReferenceField, ListField | ||||||
|  |  | ||||||
|  | __all__ = ('ConvertToObjectIdsModel', ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class ConvertToObjectIdsModel(unittest.TestCase): | ||||||
|  |  | ||||||
|  |     def setUp(self): | ||||||
|  |         connect(db='mongoenginetest') | ||||||
|  |         self.db = get_db() | ||||||
|  |  | ||||||
|  |     def test_how_to_convert_to_object_id_reference_fields(self): | ||||||
|  |         """Demonstrates migrating from 0.7 to 0.8 | ||||||
|  |         """ | ||||||
|  |  | ||||||
|  |         # 1. Old definition - using dbrefs | ||||||
|  |         class Person(Document): | ||||||
|  |             name = StringField() | ||||||
|  |             parent = ReferenceField('self', dbref=True) | ||||||
|  |             friends = ListField(ReferenceField('self', dbref=True)) | ||||||
|  |  | ||||||
|  |         Person.drop_collection() | ||||||
|  |  | ||||||
|  |         p1 = Person(name="Wilson", parent=None).save() | ||||||
|  |         f1 = Person(name="John", parent=None).save() | ||||||
|  |         f2 = Person(name="Paul", parent=None).save() | ||||||
|  |         f3 = Person(name="George", parent=None).save() | ||||||
|  |         f4 = Person(name="Ringo", parent=None).save() | ||||||
|  |         Person(name="Wilson Jr", parent=p1, friends=[f1, f2, f3, f4]).save() | ||||||
|  |  | ||||||
|  |         # 2. Start the migration by changing the schema | ||||||
|  |         # Change ReferenceField as now dbref defaults to False | ||||||
|  |         class Person(Document): | ||||||
|  |             name = StringField() | ||||||
|  |             parent = ReferenceField('self') | ||||||
|  |             friends = ListField(ReferenceField('self')) | ||||||
|  |  | ||||||
|  |         # 3. Loop all the objects and mark parent as changed | ||||||
|  |         for p in Person.objects: | ||||||
|  |             p._mark_as_changed('parent') | ||||||
|  |             p._mark_as_changed('friends') | ||||||
|  |             p.save() | ||||||
|  |  | ||||||
|  |         # 4. Confirmation of the fix! | ||||||
|  |         wilson = Person.objects(name="Wilson Jr").as_pymongo()[0] | ||||||
|  |         self.assertEqual(p1.id, wilson['parent']) | ||||||
|  |         self.assertEqual([f1.id, f2.id, f3.id, f4.id], wilson['friends']) | ||||||
							
								
								
									
										62
									
								
								tests/migration/turn_off_inheritance.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										62
									
								
								tests/migration/turn_off_inheritance.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,62 @@ | |||||||
|  | # -*- coding: utf-8 -*- | ||||||
|  | import unittest | ||||||
|  |  | ||||||
|  | from mongoengine import Document, connect | ||||||
|  | from mongoengine.connection import get_db | ||||||
|  | from mongoengine.fields import StringField | ||||||
|  |  | ||||||
|  | __all__ = ('TurnOffInheritanceTest', ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class TurnOffInheritanceTest(unittest.TestCase): | ||||||
|  |  | ||||||
|  |     def setUp(self): | ||||||
|  |         connect(db='mongoenginetest') | ||||||
|  |         self.db = get_db() | ||||||
|  |  | ||||||
|  |     def tearDown(self): | ||||||
|  |         for collection in self.db.collection_names(): | ||||||
|  |             if 'system.' in collection: | ||||||
|  |                 continue | ||||||
|  |             self.db.drop_collection(collection) | ||||||
|  |  | ||||||
|  |     def test_how_to_turn_off_inheritance(self): | ||||||
|  |         """Demonstrates migrating from allow_inheritance = True to False. | ||||||
|  |         """ | ||||||
|  |  | ||||||
|  |         # 1. Old declaration of the class | ||||||
|  |  | ||||||
|  |         class Animal(Document): | ||||||
|  |             name = StringField() | ||||||
|  |             meta = { | ||||||
|  |                 'allow_inheritance': True, | ||||||
|  |                 'indexes': ['name'] | ||||||
|  |             } | ||||||
|  |  | ||||||
|  |         # 2. Turn off inheritance | ||||||
|  |         class Animal(Document): | ||||||
|  |             name = StringField() | ||||||
|  |             meta = { | ||||||
|  |                 'allow_inheritance': False, | ||||||
|  |                 'indexes': ['name'] | ||||||
|  |             } | ||||||
|  |  | ||||||
|  |         # 3. Remove _types and _cls | ||||||
|  |         collection = Animal._get_collection() | ||||||
|  |         collection.update({}, {"$unset": {"_types": 1, "_cls": 1}}, multi=True) | ||||||
|  |  | ||||||
|  |         # 3. Confirm extra data is removed | ||||||
|  |         count = collection.find({"$or": [{'_types': {"$exists": True}}, | ||||||
|  |                                          {'_cls': {"$exists": True}}]}).count() | ||||||
|  |         assert count == 0 | ||||||
|  |  | ||||||
|  |         # 4. Remove indexes | ||||||
|  |         info = collection.index_information() | ||||||
|  |         indexes_to_drop = [key for key, value in info.iteritems() | ||||||
|  |                            if '_types' in dict(value['key']) | ||||||
|  |                               or '_cls' in dict(value['key'])] | ||||||
|  |         for index in indexes_to_drop: | ||||||
|  |             collection.drop_index(index) | ||||||
|  |  | ||||||
|  |         # 5. Recreate indexes | ||||||
|  |         Animal.ensure_indexes() | ||||||
							
								
								
									
										48
									
								
								tests/migration/uuidfield_to_binary.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										48
									
								
								tests/migration/uuidfield_to_binary.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,48 @@ | |||||||
|  | # -*- coding: utf-8 -*- | ||||||
|  | import unittest | ||||||
|  | import uuid | ||||||
|  |  | ||||||
|  | from mongoengine import Document, connect | ||||||
|  | from mongoengine.connection import get_db | ||||||
|  | from mongoengine.fields import StringField, UUIDField, ListField | ||||||
|  |  | ||||||
|  | __all__ = ('ConvertToBinaryUUID', ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class ConvertToBinaryUUID(unittest.TestCase): | ||||||
|  |  | ||||||
|  |     def setUp(self): | ||||||
|  |         connect(db='mongoenginetest') | ||||||
|  |         self.db = get_db() | ||||||
|  |  | ||||||
|  |     def test_how_to_convert_to_binary_uuid_fields(self): | ||||||
|  |         """Demonstrates migrating from 0.7 to 0.8 | ||||||
|  |         """ | ||||||
|  |  | ||||||
|  |         # 1. Old definition - using dbrefs | ||||||
|  |         class Person(Document): | ||||||
|  |             name = StringField() | ||||||
|  |             uuid = UUIDField(binary=False) | ||||||
|  |             uuids = ListField(UUIDField(binary=False)) | ||||||
|  |  | ||||||
|  |         Person.drop_collection() | ||||||
|  |         Person(name="Wilson Jr", uuid=uuid.uuid4(), | ||||||
|  |                uuids=[uuid.uuid4(), uuid.uuid4()]).save() | ||||||
|  |  | ||||||
|  |         # 2. Start the migration by changing the schema | ||||||
|  |         # Change UUIDFIeld as now binary defaults to True | ||||||
|  |         class Person(Document): | ||||||
|  |             name = StringField() | ||||||
|  |             uuid = UUIDField() | ||||||
|  |             uuids = ListField(UUIDField()) | ||||||
|  |  | ||||||
|  |         # 3. Loop all the objects and mark parent as changed | ||||||
|  |         for p in Person.objects: | ||||||
|  |             p._mark_as_changed('uuid') | ||||||
|  |             p._mark_as_changed('uuids') | ||||||
|  |             p.save() | ||||||
|  |  | ||||||
|  |         # 4. Confirmation of the fix! | ||||||
|  |         wilson = Person.objects(name="Wilson Jr").as_pymongo()[0] | ||||||
|  |         self.assertTrue(isinstance(wilson['uuid'], uuid.UUID)) | ||||||
|  |         self.assertTrue(all([isinstance(u, uuid.UUID) for u in wilson['uuids']])) | ||||||
| @@ -1,3 +1,6 @@ | |||||||
|  | import sys | ||||||
|  | sys.path[0:0] = [""] | ||||||
|  |  | ||||||
| import unittest | import unittest | ||||||
|  |  | ||||||
| from mongoengine import * | from mongoengine import * | ||||||
| @@ -92,7 +95,7 @@ class OnlyExcludeAllTest(unittest.TestCase): | |||||||
|         exclude = ['d', 'e'] |         exclude = ['d', 'e'] | ||||||
|         only = ['b', 'c'] |         only = ['b', 'c'] | ||||||
|  |  | ||||||
|         qs = MyDoc.objects.fields(**{i: 1 for i in include}) |         qs = MyDoc.objects.fields(**dict(((i, 1) for i in include))) | ||||||
|         self.assertEqual(qs._loaded_fields.as_dict(), |         self.assertEqual(qs._loaded_fields.as_dict(), | ||||||
|                          {'a': 1, 'b': 1, 'c': 1, 'd': 1, 'e': 1}) |                          {'a': 1, 'b': 1, 'c': 1, 'd': 1, 'e': 1}) | ||||||
|         qs = qs.only(*only) |         qs = qs.only(*only) | ||||||
| @@ -100,14 +103,14 @@ class OnlyExcludeAllTest(unittest.TestCase): | |||||||
|         qs = qs.exclude(*exclude) |         qs = qs.exclude(*exclude) | ||||||
|         self.assertEqual(qs._loaded_fields.as_dict(), {'b': 1, 'c': 1}) |         self.assertEqual(qs._loaded_fields.as_dict(), {'b': 1, 'c': 1}) | ||||||
|  |  | ||||||
|         qs = MyDoc.objects.fields(**{i: 1 for i in include}) |         qs = MyDoc.objects.fields(**dict(((i, 1) for i in include))) | ||||||
|         qs = qs.exclude(*exclude) |         qs = qs.exclude(*exclude) | ||||||
|         self.assertEqual(qs._loaded_fields.as_dict(), {'a': 1, 'b': 1, 'c': 1}) |         self.assertEqual(qs._loaded_fields.as_dict(), {'a': 1, 'b': 1, 'c': 1}) | ||||||
|         qs = qs.only(*only) |         qs = qs.only(*only) | ||||||
|         self.assertEqual(qs._loaded_fields.as_dict(), {'b': 1, 'c': 1}) |         self.assertEqual(qs._loaded_fields.as_dict(), {'b': 1, 'c': 1}) | ||||||
|  |  | ||||||
|         qs = MyDoc.objects.exclude(*exclude) |         qs = MyDoc.objects.exclude(*exclude) | ||||||
|         qs = qs.fields(**{i: 1 for i in include}) |         qs = qs.fields(**dict(((i, 1) for i in include))) | ||||||
|         self.assertEqual(qs._loaded_fields.as_dict(), {'a': 1, 'b': 1, 'c': 1}) |         self.assertEqual(qs._loaded_fields.as_dict(), {'a': 1, 'b': 1, 'c': 1}) | ||||||
|         qs = qs.only(*only) |         qs = qs.only(*only) | ||||||
|         self.assertEqual(qs._loaded_fields.as_dict(), {'b': 1, 'c': 1}) |         self.assertEqual(qs._loaded_fields.as_dict(), {'b': 1, 'c': 1}) | ||||||
| @@ -126,7 +129,7 @@ class OnlyExcludeAllTest(unittest.TestCase): | |||||||
|         exclude = ['d', 'e'] |         exclude = ['d', 'e'] | ||||||
|         only = ['b', 'c'] |         only = ['b', 'c'] | ||||||
|  |  | ||||||
|         qs = MyDoc.objects.fields(**{i: 1 for i in include}) |         qs = MyDoc.objects.fields(**dict(((i, 1) for i in include))) | ||||||
|         qs = qs.exclude(*exclude) |         qs = qs.exclude(*exclude) | ||||||
|         qs = qs.only(*only) |         qs = qs.only(*only) | ||||||
|         qs = qs.fields(slice__b=5) |         qs = qs.fields(slice__b=5) | ||||||
|   | |||||||
| @@ -1,5 +1,9 @@ | |||||||
| from datetime import datetime, timedelta | import sys | ||||||
|  |  | ||||||
|  | sys.path[0:0] = [""] | ||||||
|  |  | ||||||
| import unittest | import unittest | ||||||
|  | from datetime import datetime, timedelta | ||||||
|  |  | ||||||
| from pymongo.errors import OperationFailure | from pymongo.errors import OperationFailure | ||||||
| from mongoengine import * | from mongoengine import * | ||||||
|   | |||||||
| @@ -1,3 +1,6 @@ | |||||||
|  | import sys | ||||||
|  | sys.path[0:0] = [""] | ||||||
|  |  | ||||||
| import unittest | import unittest | ||||||
|  |  | ||||||
| from mongoengine import connect, Document, IntField | from mongoengine import connect, Document, IntField | ||||||
| @@ -96,4 +99,4 @@ class FindAndModifyTest(unittest.TestCase): | |||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|     unittest.main() |     unittest.main() | ||||||
| @@ -9,13 +9,13 @@ from nose.plugins.skip import SkipTest | |||||||
| import pymongo | import pymongo | ||||||
| from pymongo.errors import ConfigurationError | from pymongo.errors import ConfigurationError | ||||||
| from pymongo.read_preferences import ReadPreference | from pymongo.read_preferences import ReadPreference | ||||||
| import six |  | ||||||
|  |  | ||||||
| from mongoengine import * | from mongoengine import * | ||||||
| from mongoengine.connection import get_connection, get_db | from mongoengine.connection import get_connection, get_db | ||||||
| from mongoengine.context_managers import query_counter, switch_db | from mongoengine.context_managers import query_counter, switch_db | ||||||
| from mongoengine.errors import InvalidQueryError | from mongoengine.errors import InvalidQueryError | ||||||
| from mongoengine.python_support import IS_PYMONGO_3 | from mongoengine.python_support import IS_PYMONGO_3, PY3 | ||||||
| from mongoengine.queryset import (DoesNotExist, MultipleObjectsReturned, | from mongoengine.queryset import (DoesNotExist, MultipleObjectsReturned, | ||||||
|                                   QuerySet, QuerySetManager, queryset_manager) |                                   QuerySet, QuerySetManager, queryset_manager) | ||||||
|  |  | ||||||
| @@ -25,10 +25,7 @@ __all__ = ("QuerySetTest",) | |||||||
| class db_ops_tracker(query_counter): | class db_ops_tracker(query_counter): | ||||||
|  |  | ||||||
|     def get_ops(self): |     def get_ops(self): | ||||||
|         ignore_query = { |         ignore_query = {"ns": {"$ne": "%s.system.indexes" % self.db.name}} | ||||||
|             'ns': {'$ne': '%s.system.indexes' % self.db.name}, |  | ||||||
|             'command.count': {'$ne': 'system.profile'} |  | ||||||
|         } |  | ||||||
|         return list(self.db.system.profile.find(ignore_query)) |         return list(self.db.system.profile.find(ignore_query)) | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -97,12 +94,12 @@ class QuerySetTest(unittest.TestCase): | |||||||
|             author = ReferenceField(self.Person) |             author = ReferenceField(self.Person) | ||||||
|             author2 = GenericReferenceField() |             author2 = GenericReferenceField() | ||||||
|  |  | ||||||
|         # test addressing a field from a reference |         def test_reference(): | ||||||
|         with self.assertRaises(InvalidQueryError): |  | ||||||
|             list(BlogPost.objects(author__name="test")) |             list(BlogPost.objects(author__name="test")) | ||||||
|  |  | ||||||
|         # should fail for a generic reference as well |         self.assertRaises(InvalidQueryError, test_reference) | ||||||
|         with self.assertRaises(InvalidQueryError): |  | ||||||
|  |         def test_generic_reference(): | ||||||
|             list(BlogPost.objects(author2__name="test")) |             list(BlogPost.objects(author2__name="test")) | ||||||
|  |  | ||||||
|     def test_find(self): |     def test_find(self): | ||||||
| @@ -177,7 +174,7 @@ class QuerySetTest(unittest.TestCase): | |||||||
|  |  | ||||||
|         # Test larger slice __repr__ |         # Test larger slice __repr__ | ||||||
|         self.Person.objects.delete() |         self.Person.objects.delete() | ||||||
|         for i in range(55): |         for i in xrange(55): | ||||||
|             self.Person(name='A%s' % i, age=i).save() |             self.Person(name='A%s' % i, age=i).save() | ||||||
|  |  | ||||||
|         self.assertEqual(self.Person.objects.count(), 55) |         self.assertEqual(self.Person.objects.count(), 55) | ||||||
| @@ -221,15 +218,14 @@ class QuerySetTest(unittest.TestCase): | |||||||
|         person = self.Person.objects[1] |         person = self.Person.objects[1] | ||||||
|         self.assertEqual(person.name, "User B") |         self.assertEqual(person.name, "User B") | ||||||
|  |  | ||||||
|         with self.assertRaises(IndexError): |         self.assertRaises(IndexError, self.Person.objects.__getitem__, 2) | ||||||
|             self.Person.objects[2] |  | ||||||
|  |  | ||||||
|         # Find a document using just the object id |         # Find a document using just the object id | ||||||
|         person = self.Person.objects.with_id(person1.id) |         person = self.Person.objects.with_id(person1.id) | ||||||
|         self.assertEqual(person.name, "User A") |         self.assertEqual(person.name, "User A") | ||||||
|  |  | ||||||
|         with self.assertRaises(InvalidQueryError): |         self.assertRaises( | ||||||
|             self.Person.objects(name="User A").with_id(person1.id) |             InvalidQueryError, self.Person.objects(name="User A").with_id, person1.id) | ||||||
|  |  | ||||||
|     def test_find_only_one(self): |     def test_find_only_one(self): | ||||||
|         """Ensure that a query using ``get`` returns at most one result. |         """Ensure that a query using ``get`` returns at most one result. | ||||||
| @@ -341,37 +337,9 @@ class QuerySetTest(unittest.TestCase): | |||||||
|         query = query.filter(boolfield=True) |         query = query.filter(boolfield=True) | ||||||
|         self.assertEqual(query.count(), 1) |         self.assertEqual(query.count(), 1) | ||||||
|  |  | ||||||
|     def test_batch_size(self): |  | ||||||
|         """Ensure that batch_size works.""" |  | ||||||
|         class A(Document): |  | ||||||
|             s = StringField() |  | ||||||
|  |  | ||||||
|         A.drop_collection() |  | ||||||
|  |  | ||||||
|         for i in range(100): |  | ||||||
|             A.objects.create(s=str(i)) |  | ||||||
|  |  | ||||||
|         # test iterating over the result set |  | ||||||
|         cnt = 0 |  | ||||||
|         for a in A.objects.batch_size(10): |  | ||||||
|             cnt += 1 |  | ||||||
|         self.assertEqual(cnt, 100) |  | ||||||
|  |  | ||||||
|         # test chaining |  | ||||||
|         qs = A.objects.all() |  | ||||||
|         qs = qs.limit(10).batch_size(20).skip(91) |  | ||||||
|         cnt = 0 |  | ||||||
|         for a in qs: |  | ||||||
|             cnt += 1 |  | ||||||
|         self.assertEqual(cnt, 9) |  | ||||||
|  |  | ||||||
|         # test invalid batch size |  | ||||||
|         qs = A.objects.batch_size(-1) |  | ||||||
|         with self.assertRaises(ValueError): |  | ||||||
|             list(qs) |  | ||||||
|  |  | ||||||
|     def test_update_write_concern(self): |     def test_update_write_concern(self): | ||||||
|         """Test that passing write_concern works""" |         """Test that passing write_concern works""" | ||||||
|  |  | ||||||
|         self.Person.drop_collection() |         self.Person.drop_collection() | ||||||
|  |  | ||||||
|         write_concern = {"fsync": True} |         write_concern = {"fsync": True} | ||||||
| @@ -397,14 +365,18 @@ class QuerySetTest(unittest.TestCase): | |||||||
|         """Test to ensure that update is passed a value to update to""" |         """Test to ensure that update is passed a value to update to""" | ||||||
|         self.Person.drop_collection() |         self.Person.drop_collection() | ||||||
|  |  | ||||||
|         author = self.Person.objects.create(name='Test User') |         author = self.Person(name='Test User') | ||||||
|  |         author.save() | ||||||
|  |  | ||||||
|         with self.assertRaises(OperationError): |         def update_raises(): | ||||||
|             self.Person.objects(pk=author.pk).update({}) |             self.Person.objects(pk=author.pk).update({}) | ||||||
|  |  | ||||||
|         with self.assertRaises(OperationError): |         def update_one_raises(): | ||||||
|             self.Person.objects(pk=author.pk).update_one({}) |             self.Person.objects(pk=author.pk).update_one({}) | ||||||
|  |  | ||||||
|  |         self.assertRaises(OperationError, update_raises) | ||||||
|  |         self.assertRaises(OperationError, update_one_raises) | ||||||
|  |  | ||||||
|     def test_update_array_position(self): |     def test_update_array_position(self): | ||||||
|         """Ensure that updating by array position works. |         """Ensure that updating by array position works. | ||||||
|  |  | ||||||
| @@ -432,8 +404,8 @@ class QuerySetTest(unittest.TestCase): | |||||||
|         Blog.objects.create(posts=[post2, post1]) |         Blog.objects.create(posts=[post2, post1]) | ||||||
|  |  | ||||||
|         # Update all of the first comments of second posts of all blogs |         # Update all of the first comments of second posts of all blogs | ||||||
|         Blog.objects().update(set__posts__1__comments__0__name='testc') |         Blog.objects().update(set__posts__1__comments__0__name="testc") | ||||||
|         testc_blogs = Blog.objects(posts__1__comments__0__name='testc') |         testc_blogs = Blog.objects(posts__1__comments__0__name="testc") | ||||||
|         self.assertEqual(testc_blogs.count(), 2) |         self.assertEqual(testc_blogs.count(), 2) | ||||||
|  |  | ||||||
|         Blog.drop_collection() |         Blog.drop_collection() | ||||||
| @@ -442,13 +414,14 @@ class QuerySetTest(unittest.TestCase): | |||||||
|  |  | ||||||
|         # Update only the first blog returned by the query |         # Update only the first blog returned by the query | ||||||
|         Blog.objects().update_one( |         Blog.objects().update_one( | ||||||
|             set__posts__1__comments__1__name='testc') |             set__posts__1__comments__1__name="testc") | ||||||
|         testc_blogs = Blog.objects(posts__1__comments__1__name='testc') |         testc_blogs = Blog.objects(posts__1__comments__1__name="testc") | ||||||
|         self.assertEqual(testc_blogs.count(), 1) |         self.assertEqual(testc_blogs.count(), 1) | ||||||
|  |  | ||||||
|         # Check that using this indexing syntax on a non-list fails |         # Check that using this indexing syntax on a non-list fails | ||||||
|         with self.assertRaises(InvalidQueryError): |         def non_list_indexing(): | ||||||
|             Blog.objects().update(set__posts__1__comments__0__name__1='asdf') |             Blog.objects().update(set__posts__1__comments__0__name__1="asdf") | ||||||
|  |         self.assertRaises(InvalidQueryError, non_list_indexing) | ||||||
|  |  | ||||||
|         Blog.drop_collection() |         Blog.drop_collection() | ||||||
|  |  | ||||||
| @@ -516,12 +489,15 @@ class QuerySetTest(unittest.TestCase): | |||||||
|         self.assertEqual(simple.x, [1, 2, None, 4, 3, 2, 3, 4]) |         self.assertEqual(simple.x, [1, 2, None, 4, 3, 2, 3, 4]) | ||||||
|  |  | ||||||
|         # Nested updates arent supported yet.. |         # Nested updates arent supported yet.. | ||||||
|         with self.assertRaises(OperationError): |         def update_nested(): | ||||||
|             Simple.drop_collection() |             Simple.drop_collection() | ||||||
|             Simple(x=[{'test': [1, 2, 3, 4]}]).save() |             Simple(x=[{'test': [1, 2, 3, 4]}]).save() | ||||||
|             Simple.objects(x__test=2).update(set__x__S__test__S=3) |             Simple.objects(x__test=2).update(set__x__S__test__S=3) | ||||||
|             self.assertEqual(simple.x, [1, 2, 3, 4]) |             self.assertEqual(simple.x, [1, 2, 3, 4]) | ||||||
|  |  | ||||||
|  |         self.assertRaises(OperationError, update_nested) | ||||||
|  |         Simple.drop_collection() | ||||||
|  |  | ||||||
|     def test_update_using_positional_operator_embedded_document(self): |     def test_update_using_positional_operator_embedded_document(self): | ||||||
|         """Ensure that the embedded documents can be updated using the positional |         """Ensure that the embedded documents can be updated using the positional | ||||||
|         operator.""" |         operator.""" | ||||||
| @@ -614,11 +590,11 @@ class QuerySetTest(unittest.TestCase): | |||||||
|             members = DictField() |             members = DictField() | ||||||
|  |  | ||||||
|         club = Club() |         club = Club() | ||||||
|         club.members['John'] = {'gender': 'M', 'age': 13} |         club.members['John'] = dict(gender="M", age=13) | ||||||
|         club.save() |         club.save() | ||||||
|  |  | ||||||
|         Club.objects().update( |         Club.objects().update( | ||||||
|             set__members={"John": {'gender': 'F', 'age': 14}}) |             set__members={"John": dict(gender="F", age=14)}) | ||||||
|  |  | ||||||
|         club = Club.objects().first() |         club = Club.objects().first() | ||||||
|         self.assertEqual(club.members['John']['gender'], "F") |         self.assertEqual(club.members['John']['gender'], "F") | ||||||
| @@ -799,7 +775,7 @@ class QuerySetTest(unittest.TestCase): | |||||||
|             post2 = Post(comments=[comment2, comment2]) |             post2 = Post(comments=[comment2, comment2]) | ||||||
|  |  | ||||||
|             blogs = [] |             blogs = [] | ||||||
|             for i in range(1, 100): |             for i in xrange(1, 100): | ||||||
|                 blogs.append(Blog(title="post %s" % i, posts=[post1, post2])) |                 blogs.append(Blog(title="post %s" % i, posts=[post1, post2])) | ||||||
|  |  | ||||||
|             Blog.objects.insert(blogs, load_bulk=False) |             Blog.objects.insert(blogs, load_bulk=False) | ||||||
| @@ -836,31 +812,30 @@ class QuerySetTest(unittest.TestCase): | |||||||
|  |  | ||||||
|         self.assertEqual(Blog.objects.count(), 2) |         self.assertEqual(Blog.objects.count(), 2) | ||||||
|  |  | ||||||
|         # test inserting an existing document (shouldn't be allowed) |         # test handles people trying to upsert | ||||||
|         with self.assertRaises(OperationError): |         def throw_operation_error(): | ||||||
|             blog = Blog.objects.first() |  | ||||||
|             Blog.objects.insert(blog) |  | ||||||
|  |  | ||||||
|         # test inserting a query set |  | ||||||
|         with self.assertRaises(OperationError): |  | ||||||
|             blogs = Blog.objects |             blogs = Blog.objects | ||||||
|             Blog.objects.insert(blogs) |             Blog.objects.insert(blogs) | ||||||
|  |  | ||||||
|         # insert a new doc |         self.assertRaises(OperationError, throw_operation_error) | ||||||
|  |  | ||||||
|  |         # Test can insert new doc | ||||||
|         new_post = Blog(title="code123", id=ObjectId()) |         new_post = Blog(title="code123", id=ObjectId()) | ||||||
|         Blog.objects.insert(new_post) |         Blog.objects.insert(new_post) | ||||||
|  |  | ||||||
|         class Author(Document): |         # test handles other classes being inserted | ||||||
|             pass |         def throw_operation_error_wrong_doc(): | ||||||
|  |             class Author(Document): | ||||||
|         # try inserting a different document class |                 pass | ||||||
|         with self.assertRaises(OperationError): |  | ||||||
|             Blog.objects.insert(Author()) |             Blog.objects.insert(Author()) | ||||||
|  |  | ||||||
|         # try inserting a non-document |         self.assertRaises(OperationError, throw_operation_error_wrong_doc) | ||||||
|         with self.assertRaises(OperationError): |  | ||||||
|  |         def throw_operation_error_not_a_document(): | ||||||
|             Blog.objects.insert("HELLO WORLD") |             Blog.objects.insert("HELLO WORLD") | ||||||
|  |  | ||||||
|  |         self.assertRaises(OperationError, throw_operation_error_not_a_document) | ||||||
|  |  | ||||||
|         Blog.drop_collection() |         Blog.drop_collection() | ||||||
|  |  | ||||||
|         blog1 = Blog(title="code", posts=[post1, post2]) |         blog1 = Blog(title="code", posts=[post1, post2]) | ||||||
| @@ -880,13 +855,14 @@ class QuerySetTest(unittest.TestCase): | |||||||
|         blog3 = Blog(title="baz", posts=[post1, post2]) |         blog3 = Blog(title="baz", posts=[post1, post2]) | ||||||
|         Blog.objects.insert([blog1, blog2]) |         Blog.objects.insert([blog1, blog2]) | ||||||
|  |  | ||||||
|         with self.assertRaises(NotUniqueError): |         def throw_operation_error_not_unique(): | ||||||
|             Blog.objects.insert([blog2, blog3]) |             Blog.objects.insert([blog2, blog3]) | ||||||
|  |  | ||||||
|  |         self.assertRaises(NotUniqueError, throw_operation_error_not_unique) | ||||||
|         self.assertEqual(Blog.objects.count(), 2) |         self.assertEqual(Blog.objects.count(), 2) | ||||||
|  |  | ||||||
|         Blog.objects.insert([blog2, blog3], |         Blog.objects.insert([blog2, blog3], write_concern={"w": 0, | ||||||
|                             write_concern={"w": 0, 'continue_on_error': True}) |                                                            'continue_on_error': True}) | ||||||
|         self.assertEqual(Blog.objects.count(), 3) |         self.assertEqual(Blog.objects.count(), 3) | ||||||
|  |  | ||||||
|     def test_get_changed_fields_query_count(self): |     def test_get_changed_fields_query_count(self): | ||||||
| @@ -1019,7 +995,7 @@ class QuerySetTest(unittest.TestCase): | |||||||
|  |  | ||||||
|         Doc.drop_collection() |         Doc.drop_collection() | ||||||
|  |  | ||||||
|         for i in range(1000): |         for i in xrange(1000): | ||||||
|             Doc(number=i).save() |             Doc(number=i).save() | ||||||
|  |  | ||||||
|         docs = Doc.objects.order_by('number') |         docs = Doc.objects.order_by('number') | ||||||
| @@ -1173,7 +1149,7 @@ class QuerySetTest(unittest.TestCase): | |||||||
|         qs = list(qs) |         qs = list(qs) | ||||||
|         expected = list(expected) |         expected = list(expected) | ||||||
|         self.assertEqual(len(qs), len(expected)) |         self.assertEqual(len(qs), len(expected)) | ||||||
|         for i in range(len(qs)): |         for i in xrange(len(qs)): | ||||||
|             self.assertEqual(qs[i], expected[i]) |             self.assertEqual(qs[i], expected[i]) | ||||||
|  |  | ||||||
|     def test_ordering(self): |     def test_ordering(self): | ||||||
| @@ -1213,8 +1189,7 @@ class QuerySetTest(unittest.TestCase): | |||||||
|         self.assertSequence(qs, expected) |         self.assertSequence(qs, expected) | ||||||
|  |  | ||||||
|     def test_clear_ordering(self): |     def test_clear_ordering(self): | ||||||
|         """Ensure that the default ordering can be cleared by calling |         """ Ensure that the default ordering can be cleared by calling order_by(). | ||||||
|         order_by() w/o any arguments. |  | ||||||
|         """ |         """ | ||||||
|         class BlogPost(Document): |         class BlogPost(Document): | ||||||
|             title = StringField() |             title = StringField() | ||||||
| @@ -1230,13 +1205,12 @@ class QuerySetTest(unittest.TestCase): | |||||||
|             BlogPost.objects.filter(title='whatever').first() |             BlogPost.objects.filter(title='whatever').first() | ||||||
|             self.assertEqual(len(q.get_ops()), 1) |             self.assertEqual(len(q.get_ops()), 1) | ||||||
|             self.assertEqual( |             self.assertEqual( | ||||||
|                 q.get_ops()[0]['query']['$orderby'], |                 q.get_ops()[0]['query']['$orderby'], {u'published_date': -1}) | ||||||
|                 {'published_date': -1} |  | ||||||
|             ) |  | ||||||
|  |  | ||||||
|         with db_ops_tracker() as q: |         with db_ops_tracker() as q: | ||||||
|             BlogPost.objects.filter(title='whatever').order_by().first() |             BlogPost.objects.filter(title='whatever').order_by().first() | ||||||
|             self.assertEqual(len(q.get_ops()), 1) |             self.assertEqual(len(q.get_ops()), 1) | ||||||
|  |             print q.get_ops()[0]['query'] | ||||||
|             self.assertFalse('$orderby' in q.get_ops()[0]['query']) |             self.assertFalse('$orderby' in q.get_ops()[0]['query']) | ||||||
|  |  | ||||||
|     def test_no_ordering_for_get(self): |     def test_no_ordering_for_get(self): | ||||||
| @@ -1265,8 +1239,7 @@ class QuerySetTest(unittest.TestCase): | |||||||
|             self.assertFalse('$orderby' in q.get_ops()[0]['query']) |             self.assertFalse('$orderby' in q.get_ops()[0]['query']) | ||||||
|  |  | ||||||
|     def test_find_embedded(self): |     def test_find_embedded(self): | ||||||
|         """Ensure that an embedded document is properly returned from |         """Ensure that an embedded document is properly returned from a query. | ||||||
|         different manners of querying. |  | ||||||
|         """ |         """ | ||||||
|         class User(EmbeddedDocument): |         class User(EmbeddedDocument): | ||||||
|             name = StringField() |             name = StringField() | ||||||
| @@ -1277,45 +1250,16 @@ class QuerySetTest(unittest.TestCase): | |||||||
|  |  | ||||||
|         BlogPost.drop_collection() |         BlogPost.drop_collection() | ||||||
|  |  | ||||||
|         user = User(name='Test User') |         post = BlogPost(content='Had a good coffee today...') | ||||||
|         BlogPost.objects.create( |         post.author = User(name='Test User') | ||||||
|             author=user, |         post.save() | ||||||
|             content='Had a good coffee today...' |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|         result = BlogPost.objects.first() |         result = BlogPost.objects.first() | ||||||
|         self.assertTrue(isinstance(result.author, User)) |         self.assertTrue(isinstance(result.author, User)) | ||||||
|         self.assertEqual(result.author.name, 'Test User') |         self.assertEqual(result.author.name, 'Test User') | ||||||
|  |  | ||||||
|         result = BlogPost.objects.get(author__name=user.name) |  | ||||||
|         self.assertTrue(isinstance(result.author, User)) |  | ||||||
|         self.assertEqual(result.author.name, 'Test User') |  | ||||||
|  |  | ||||||
|         result = BlogPost.objects.get(author={'name': user.name}) |  | ||||||
|         self.assertTrue(isinstance(result.author, User)) |  | ||||||
|         self.assertEqual(result.author.name, 'Test User') |  | ||||||
|  |  | ||||||
|         # Fails, since the string is not a type that is able to represent the |  | ||||||
|         # author's document structure (should be dict) |  | ||||||
|         with self.assertRaises(InvalidQueryError): |  | ||||||
|             BlogPost.objects.get(author=user.name) |  | ||||||
|  |  | ||||||
|     def test_find_empty_embedded(self): |  | ||||||
|         """Ensure that you can save and find an empty embedded document.""" |  | ||||||
|         class User(EmbeddedDocument): |  | ||||||
|             name = StringField() |  | ||||||
|  |  | ||||||
|         class BlogPost(Document): |  | ||||||
|             content = StringField() |  | ||||||
|             author = EmbeddedDocumentField(User) |  | ||||||
|  |  | ||||||
|         BlogPost.drop_collection() |         BlogPost.drop_collection() | ||||||
|  |  | ||||||
|         BlogPost.objects.create(content='Anonymous post...') |  | ||||||
|  |  | ||||||
|         result = BlogPost.objects.get(author=None) |  | ||||||
|         self.assertEqual(result.author, None) |  | ||||||
|  |  | ||||||
|     def test_find_dict_item(self): |     def test_find_dict_item(self): | ||||||
|         """Ensure that DictField items may be found. |         """Ensure that DictField items may be found. | ||||||
|         """ |         """ | ||||||
| @@ -1723,7 +1667,7 @@ class QuerySetTest(unittest.TestCase): | |||||||
|  |  | ||||||
|         Log.drop_collection() |         Log.drop_collection() | ||||||
|  |  | ||||||
|         for i in range(10): |         for i in xrange(10): | ||||||
|             Log().save() |             Log().save() | ||||||
|  |  | ||||||
|         Log.objects()[3:5].delete() |         Log.objects()[3:5].delete() | ||||||
| @@ -1923,10 +1867,12 @@ class QuerySetTest(unittest.TestCase): | |||||||
|         Site.objects(id=s.id).update_one(pull__collaborators__user='Esteban') |         Site.objects(id=s.id).update_one(pull__collaborators__user='Esteban') | ||||||
|         self.assertEqual(Site.objects.first().collaborators, []) |         self.assertEqual(Site.objects.first().collaborators, []) | ||||||
|  |  | ||||||
|         with self.assertRaises(InvalidQueryError): |         def pull_all(): | ||||||
|             Site.objects(id=s.id).update_one( |             Site.objects(id=s.id).update_one( | ||||||
|                 pull_all__collaborators__user=['Ross']) |                 pull_all__collaborators__user=['Ross']) | ||||||
|  |  | ||||||
|  |         self.assertRaises(InvalidQueryError, pull_all) | ||||||
|  |  | ||||||
|     def test_pull_from_nested_embedded(self): |     def test_pull_from_nested_embedded(self): | ||||||
|  |  | ||||||
|         class User(EmbeddedDocument): |         class User(EmbeddedDocument): | ||||||
| @@ -1957,10 +1903,12 @@ class QuerySetTest(unittest.TestCase): | |||||||
|             pull__collaborators__unhelpful={'name': 'Frank'}) |             pull__collaborators__unhelpful={'name': 'Frank'}) | ||||||
|         self.assertEqual(Site.objects.first().collaborators['unhelpful'], []) |         self.assertEqual(Site.objects.first().collaborators['unhelpful'], []) | ||||||
|  |  | ||||||
|         with self.assertRaises(InvalidQueryError): |         def pull_all(): | ||||||
|             Site.objects(id=s.id).update_one( |             Site.objects(id=s.id).update_one( | ||||||
|                 pull_all__collaborators__helpful__name=['Ross']) |                 pull_all__collaborators__helpful__name=['Ross']) | ||||||
|  |  | ||||||
|  |         self.assertRaises(InvalidQueryError, pull_all) | ||||||
|  |  | ||||||
|     def test_pull_from_nested_mapfield(self): |     def test_pull_from_nested_mapfield(self): | ||||||
|  |  | ||||||
|         class Collaborator(EmbeddedDocument): |         class Collaborator(EmbeddedDocument): | ||||||
| @@ -1989,10 +1937,12 @@ class QuerySetTest(unittest.TestCase): | |||||||
|             pull__collaborators__unhelpful={'user': 'Frank'}) |             pull__collaborators__unhelpful={'user': 'Frank'}) | ||||||
|         self.assertEqual(Site.objects.first().collaborators['unhelpful'], []) |         self.assertEqual(Site.objects.first().collaborators['unhelpful'], []) | ||||||
|  |  | ||||||
|         with self.assertRaises(InvalidQueryError): |         def pull_all(): | ||||||
|             Site.objects(id=s.id).update_one( |             Site.objects(id=s.id).update_one( | ||||||
|                 pull_all__collaborators__helpful__user=['Ross']) |                 pull_all__collaborators__helpful__user=['Ross']) | ||||||
|  |  | ||||||
|  |         self.assertRaises(InvalidQueryError, pull_all) | ||||||
|  |  | ||||||
|     def test_update_one_pop_generic_reference(self): |     def test_update_one_pop_generic_reference(self): | ||||||
|  |  | ||||||
|         class BlogTag(Document): |         class BlogTag(Document): | ||||||
| @@ -2249,21 +2199,6 @@ class QuerySetTest(unittest.TestCase): | |||||||
|             a.author.name for a in Author.objects.order_by('-author__age')] |             a.author.name for a in Author.objects.order_by('-author__age')] | ||||||
|         self.assertEqual(names, ['User A', 'User B', 'User C']) |         self.assertEqual(names, ['User A', 'User B', 'User C']) | ||||||
|  |  | ||||||
|     def test_comment(self): |  | ||||||
|         """Make sure adding a comment to the query works.""" |  | ||||||
|         class User(Document): |  | ||||||
|             age = IntField() |  | ||||||
|  |  | ||||||
|         with db_ops_tracker() as q: |  | ||||||
|             adult = (User.objects.filter(age__gte=18) |  | ||||||
|                 .comment('looking for an adult') |  | ||||||
|                 .first()) |  | ||||||
|             ops = q.get_ops() |  | ||||||
|             self.assertEqual(len(ops), 1) |  | ||||||
|             op = ops[0] |  | ||||||
|             self.assertEqual(op['query']['$query'], {'age': {'$gte': 18}}) |  | ||||||
|             self.assertEqual(op['query']['$comment'], 'looking for an adult') |  | ||||||
|  |  | ||||||
|     def test_map_reduce(self): |     def test_map_reduce(self): | ||||||
|         """Ensure map/reduce is both mapping and reducing. |         """Ensure map/reduce is both mapping and reducing. | ||||||
|         """ |         """ | ||||||
| @@ -2617,7 +2552,7 @@ class QuerySetTest(unittest.TestCase): | |||||||
|         BlogPost(hits=2, tags=['music', 'actors']).save() |         BlogPost(hits=2, tags=['music', 'actors']).save() | ||||||
|  |  | ||||||
|         def test_assertions(f): |         def test_assertions(f): | ||||||
|             f = {key: int(val) for key, val in f.items()} |             f = dict((key, int(val)) for key, val in f.items()) | ||||||
|             self.assertEqual( |             self.assertEqual( | ||||||
|                 set(['music', 'film', 'actors', 'watch']), set(f.keys())) |                 set(['music', 'film', 'actors', 'watch']), set(f.keys())) | ||||||
|             self.assertEqual(f['music'], 3) |             self.assertEqual(f['music'], 3) | ||||||
| @@ -2632,7 +2567,7 @@ class QuerySetTest(unittest.TestCase): | |||||||
|  |  | ||||||
|         # Ensure query is taken into account |         # Ensure query is taken into account | ||||||
|         def test_assertions(f): |         def test_assertions(f): | ||||||
|             f = {key: int(val) for key, val in f.items()} |             f = dict((key, int(val)) for key, val in f.items()) | ||||||
|             self.assertEqual(set(['music', 'actors', 'watch']), set(f.keys())) |             self.assertEqual(set(['music', 'actors', 'watch']), set(f.keys())) | ||||||
|             self.assertEqual(f['music'], 2) |             self.assertEqual(f['music'], 2) | ||||||
|             self.assertEqual(f['actors'], 1) |             self.assertEqual(f['actors'], 1) | ||||||
| @@ -2696,7 +2631,7 @@ class QuerySetTest(unittest.TestCase): | |||||||
|         doc.save() |         doc.save() | ||||||
|  |  | ||||||
|         def test_assertions(f): |         def test_assertions(f): | ||||||
|             f = {key: int(val) for key, val in f.items()} |             f = dict((key, int(val)) for key, val in f.items()) | ||||||
|             self.assertEqual( |             self.assertEqual( | ||||||
|                 set(['62-3331-1656', '62-3332-1656']), set(f.keys())) |                 set(['62-3331-1656', '62-3332-1656']), set(f.keys())) | ||||||
|             self.assertEqual(f['62-3331-1656'], 2) |             self.assertEqual(f['62-3331-1656'], 2) | ||||||
| @@ -2710,7 +2645,7 @@ class QuerySetTest(unittest.TestCase): | |||||||
|  |  | ||||||
|         # Ensure query is taken into account |         # Ensure query is taken into account | ||||||
|         def test_assertions(f): |         def test_assertions(f): | ||||||
|             f = {key: int(val) for key, val in f.items()} |             f = dict((key, int(val)) for key, val in f.items()) | ||||||
|             self.assertEqual(set(['62-3331-1656']), set(f.keys())) |             self.assertEqual(set(['62-3331-1656']), set(f.keys())) | ||||||
|             self.assertEqual(f['62-3331-1656'], 2) |             self.assertEqual(f['62-3331-1656'], 2) | ||||||
|  |  | ||||||
| @@ -2817,10 +2752,10 @@ class QuerySetTest(unittest.TestCase): | |||||||
|  |  | ||||||
|         Test.drop_collection() |         Test.drop_collection() | ||||||
|  |  | ||||||
|         for i in range(50): |         for i in xrange(50): | ||||||
|             Test(val=1).save() |             Test(val=1).save() | ||||||
|  |  | ||||||
|         for i in range(20): |         for i in xrange(20): | ||||||
|             Test(val=2).save() |             Test(val=2).save() | ||||||
|  |  | ||||||
|         freqs = Test.objects.item_frequencies( |         freqs = Test.objects.item_frequencies( | ||||||
| @@ -3610,7 +3545,7 @@ class QuerySetTest(unittest.TestCase): | |||||||
|  |  | ||||||
|         Post.drop_collection() |         Post.drop_collection() | ||||||
|  |  | ||||||
|         for i in range(10): |         for i in xrange(10): | ||||||
|             Post(title="Post %s" % i).save() |             Post(title="Post %s" % i).save() | ||||||
|  |  | ||||||
|         self.assertEqual(5, Post.objects.limit(5).skip(5).count(with_limit_and_skip=True)) |         self.assertEqual(5, Post.objects.limit(5).skip(5).count(with_limit_and_skip=True)) | ||||||
| @@ -3625,7 +3560,7 @@ class QuerySetTest(unittest.TestCase): | |||||||
|             pass |             pass | ||||||
|  |  | ||||||
|         MyDoc.drop_collection() |         MyDoc.drop_collection() | ||||||
|         for i in range(0, 10): |         for i in xrange(0, 10): | ||||||
|             MyDoc().save() |             MyDoc().save() | ||||||
|  |  | ||||||
|         self.assertEqual(MyDoc.objects.count(), 10) |         self.assertEqual(MyDoc.objects.count(), 10) | ||||||
| @@ -3681,7 +3616,7 @@ class QuerySetTest(unittest.TestCase): | |||||||
|  |  | ||||||
|         Number.drop_collection() |         Number.drop_collection() | ||||||
|  |  | ||||||
|         for i in range(1, 101): |         for i in xrange(1, 101): | ||||||
|             t = Number(n=i) |             t = Number(n=i) | ||||||
|             t.save() |             t.save() | ||||||
|  |  | ||||||
| @@ -3828,9 +3763,11 @@ class QuerySetTest(unittest.TestCase): | |||||||
|         self.assertTrue(a in results) |         self.assertTrue(a in results) | ||||||
|         self.assertTrue(c in results) |         self.assertTrue(c in results) | ||||||
|  |  | ||||||
|         with self.assertRaises(TypeError): |         def invalid_where(): | ||||||
|             list(IntPair.objects.where(fielda__gte=3)) |             list(IntPair.objects.where(fielda__gte=3)) | ||||||
|  |  | ||||||
|  |         self.assertRaises(TypeError, invalid_where) | ||||||
|  |  | ||||||
|     def test_scalar(self): |     def test_scalar(self): | ||||||
|  |  | ||||||
|         class Organization(Document): |         class Organization(Document): | ||||||
| @@ -4086,7 +4023,7 @@ class QuerySetTest(unittest.TestCase): | |||||||
|  |  | ||||||
|         # Test larger slice __repr__ |         # Test larger slice __repr__ | ||||||
|         self.Person.objects.delete() |         self.Person.objects.delete() | ||||||
|         for i in range(55): |         for i in xrange(55): | ||||||
|             self.Person(name='A%s' % i, age=i).save() |             self.Person(name='A%s' % i, age=i).save() | ||||||
|  |  | ||||||
|         self.assertEqual(self.Person.objects.scalar('name').count(), 55) |         self.assertEqual(self.Person.objects.scalar('name').count(), 55) | ||||||
| @@ -4094,7 +4031,7 @@ class QuerySetTest(unittest.TestCase): | |||||||
|             "A0", "%s" % self.Person.objects.order_by('name').scalar('name').first()) |             "A0", "%s" % self.Person.objects.order_by('name').scalar('name').first()) | ||||||
|         self.assertEqual( |         self.assertEqual( | ||||||
|             "A0", "%s" % self.Person.objects.scalar('name').order_by('name')[0]) |             "A0", "%s" % self.Person.objects.scalar('name').order_by('name')[0]) | ||||||
|         if six.PY3: |         if PY3: | ||||||
|             self.assertEqual("['A1', 'A2']", "%s" % self.Person.objects.order_by( |             self.assertEqual("['A1', 'A2']", "%s" % self.Person.objects.order_by( | ||||||
|                 'age').scalar('name')[1:3]) |                 'age').scalar('name')[1:3]) | ||||||
|             self.assertEqual("['A51', 'A52']", "%s" % self.Person.objects.order_by( |             self.assertEqual("['A51', 'A52']", "%s" % self.Person.objects.order_by( | ||||||
| @@ -4112,7 +4049,7 @@ class QuerySetTest(unittest.TestCase): | |||||||
|  |  | ||||||
|         pks = self.Person.objects.order_by('age').scalar('pk')[1:3] |         pks = self.Person.objects.order_by('age').scalar('pk')[1:3] | ||||||
|         names = self.Person.objects.scalar('name').in_bulk(list(pks)).values() |         names = self.Person.objects.scalar('name').in_bulk(list(pks)).values() | ||||||
|         if six.PY3: |         if PY3: | ||||||
|             expected = "['A1', 'A2']" |             expected = "['A1', 'A2']" | ||||||
|         else: |         else: | ||||||
|             expected = "[u'A1', u'A2']" |             expected = "[u'A1', u'A2']" | ||||||
| @@ -4468,7 +4405,7 @@ class QuerySetTest(unittest.TestCase): | |||||||
|             name = StringField() |             name = StringField() | ||||||
|  |  | ||||||
|         Person.drop_collection() |         Person.drop_collection() | ||||||
|         for i in range(100): |         for i in xrange(100): | ||||||
|             Person(name="No: %s" % i).save() |             Person(name="No: %s" % i).save() | ||||||
|  |  | ||||||
|         with query_counter() as q: |         with query_counter() as q: | ||||||
| @@ -4499,7 +4436,7 @@ class QuerySetTest(unittest.TestCase): | |||||||
|             name = StringField() |             name = StringField() | ||||||
|  |  | ||||||
|         Person.drop_collection() |         Person.drop_collection() | ||||||
|         for i in range(100): |         for i in xrange(100): | ||||||
|             Person(name="No: %s" % i).save() |             Person(name="No: %s" % i).save() | ||||||
|  |  | ||||||
|         with query_counter() as q: |         with query_counter() as q: | ||||||
| @@ -4543,7 +4480,7 @@ class QuerySetTest(unittest.TestCase): | |||||||
|             fields = DictField() |             fields = DictField() | ||||||
|  |  | ||||||
|         Noddy.drop_collection() |         Noddy.drop_collection() | ||||||
|         for i in range(100): |         for i in xrange(100): | ||||||
|             noddy = Noddy() |             noddy = Noddy() | ||||||
|             for j in range(20): |             for j in range(20): | ||||||
|                 noddy.fields["key" + str(j)] = "value " + str(j) |                 noddy.fields["key" + str(j)] = "value " + str(j) | ||||||
| @@ -4555,9 +4492,7 @@ class QuerySetTest(unittest.TestCase): | |||||||
|         self.assertEqual(counter, 100) |         self.assertEqual(counter, 100) | ||||||
|  |  | ||||||
|         self.assertEqual(len(list(docs)), 100) |         self.assertEqual(len(list(docs)), 100) | ||||||
|  |         self.assertRaises(TypeError, lambda: len(docs)) | ||||||
|         with self.assertRaises(TypeError): |  | ||||||
|             len(docs) |  | ||||||
|  |  | ||||||
|         with query_counter() as q: |         with query_counter() as q: | ||||||
|             self.assertEqual(q, 0) |             self.assertEqual(q, 0) | ||||||
| @@ -4746,7 +4681,7 @@ class QuerySetTest(unittest.TestCase): | |||||||
|             name = StringField() |             name = StringField() | ||||||
|  |  | ||||||
|         Person.drop_collection() |         Person.drop_collection() | ||||||
|         for i in range(100): |         for i in xrange(100): | ||||||
|             Person(name="No: %s" % i).save() |             Person(name="No: %s" % i).save() | ||||||
|  |  | ||||||
|         with query_counter() as q: |         with query_counter() as q: | ||||||
| @@ -4870,10 +4805,10 @@ class QuerySetTest(unittest.TestCase): | |||||||
|         ]) |         ]) | ||||||
|  |  | ||||||
|     def test_delete_count(self): |     def test_delete_count(self): | ||||||
|         [self.Person(name="User {0}".format(i), age=i * 10).save() for i in range(1, 4)] |         [self.Person(name="User {0}".format(i), age=i * 10).save() for i in xrange(1, 4)] | ||||||
|         self.assertEqual(self.Person.objects().delete(), 3)  # test ordinary QuerySey delete count |         self.assertEqual(self.Person.objects().delete(), 3)  # test ordinary QuerySey delete count | ||||||
|  |  | ||||||
|         [self.Person(name="User {0}".format(i), age=i * 10).save() for i in range(1, 4)] |         [self.Person(name="User {0}".format(i), age=i * 10).save() for i in xrange(1, 4)] | ||||||
|  |  | ||||||
|         self.assertEqual(self.Person.objects().skip(1).delete(), 2)  # test Document delete with existing documents |         self.assertEqual(self.Person.objects().skip(1).delete(), 2)  # test Document delete with existing documents | ||||||
|  |  | ||||||
| @@ -4882,14 +4817,12 @@ class QuerySetTest(unittest.TestCase): | |||||||
|  |  | ||||||
|     def test_max_time_ms(self): |     def test_max_time_ms(self): | ||||||
|         # 778: max_time_ms can get only int or None as input |         # 778: max_time_ms can get only int or None as input | ||||||
|         self.assertRaises(TypeError, |         self.assertRaises(TypeError, self.Person.objects(name="name").max_time_ms, "not a number") | ||||||
|                           self.Person.objects(name="name").max_time_ms, |  | ||||||
|                           'not a number') |  | ||||||
|  |  | ||||||
|     def test_subclass_field_query(self): |     def test_subclass_field_query(self): | ||||||
|         class Animal(Document): |         class Animal(Document): | ||||||
|             is_mamal = BooleanField() |             is_mamal = BooleanField() | ||||||
|             meta = {'allow_inheritance': True} |             meta = dict(allow_inheritance=True) | ||||||
|  |  | ||||||
|         class Cat(Animal): |         class Cat(Animal): | ||||||
|             whiskers_length = FloatField() |             whiskers_length = FloatField() | ||||||
| @@ -4927,85 +4860,6 @@ class QuerySetTest(unittest.TestCase): | |||||||
|  |  | ||||||
|         self.assertEqual(1, Doc.objects(item__type__="axe").count()) |         self.assertEqual(1, Doc.objects(item__type__="axe").count()) | ||||||
|  |  | ||||||
|     def test_len_during_iteration(self): |  | ||||||
|         """Tests that calling len on a queyset during iteration doesn't |  | ||||||
|         stop paging. |  | ||||||
|         """ |  | ||||||
|         class Data(Document): |  | ||||||
|             pass |  | ||||||
|  |  | ||||||
|         for i in range(300): |  | ||||||
|             Data().save() |  | ||||||
|  |  | ||||||
|         records = Data.objects.limit(250) |  | ||||||
|  |  | ||||||
|         # This should pull all 250 docs from mongo and populate the result |  | ||||||
|         # cache |  | ||||||
|         len(records) |  | ||||||
|  |  | ||||||
|         # Assert that iterating over documents in the qs touches every |  | ||||||
|         # document even if we call len(qs) midway through the iteration. |  | ||||||
|         for i, r in enumerate(records): |  | ||||||
|             if i == 58: |  | ||||||
|                 len(records) |  | ||||||
|         self.assertEqual(i, 249) |  | ||||||
|  |  | ||||||
|         # Assert the same behavior is true even if we didn't pre-populate the |  | ||||||
|         # result cache. |  | ||||||
|         records = Data.objects.limit(250) |  | ||||||
|         for i, r in enumerate(records): |  | ||||||
|             if i == 58: |  | ||||||
|                 len(records) |  | ||||||
|         self.assertEqual(i, 249) |  | ||||||
|  |  | ||||||
|     def test_iteration_within_iteration(self): |  | ||||||
|         """You should be able to reliably iterate over all the documents |  | ||||||
|         in a given queryset even if there are multiple iterations of it |  | ||||||
|         happening at the same time. |  | ||||||
|         """ |  | ||||||
|         class Data(Document): |  | ||||||
|             pass |  | ||||||
|  |  | ||||||
|         for i in range(300): |  | ||||||
|             Data().save() |  | ||||||
|  |  | ||||||
|         qs = Data.objects.limit(250) |  | ||||||
|         for i, doc in enumerate(qs): |  | ||||||
|             for j, doc2 in enumerate(qs): |  | ||||||
|                 pass |  | ||||||
|  |  | ||||||
|         self.assertEqual(i, 249) |  | ||||||
|         self.assertEqual(j, 249) |  | ||||||
|  |  | ||||||
|     def test_in_operator_on_non_iterable(self): |  | ||||||
|         """Ensure that using the `__in` operator on a non-iterable raises an |  | ||||||
|         error. |  | ||||||
|         """ |  | ||||||
|         class User(Document): |  | ||||||
|             name = StringField() |  | ||||||
|  |  | ||||||
|         class BlogPost(Document): |  | ||||||
|             content = StringField() |  | ||||||
|             authors = ListField(ReferenceField(User)) |  | ||||||
|  |  | ||||||
|         User.drop_collection() |  | ||||||
|         BlogPost.drop_collection() |  | ||||||
|  |  | ||||||
|         author = User.objects.create(name='Test User') |  | ||||||
|         post = BlogPost.objects.create(content='Had a good coffee today...', |  | ||||||
|                                        authors=[author]) |  | ||||||
|  |  | ||||||
|         # Make sure using `__in` with a list works |  | ||||||
|         blog_posts = BlogPost.objects(authors__in=[author]) |  | ||||||
|         self.assertEqual(list(blog_posts), [post]) |  | ||||||
|  |  | ||||||
|         # Using `__in` with a non-iterable should raise a TypeError |  | ||||||
|         self.assertRaises(TypeError, BlogPost.objects(authors__in=author.pk).count) |  | ||||||
|  |  | ||||||
|         # Using `__in` with a `Document` (which is seemingly iterable but not |  | ||||||
|         # in a way we'd expect) should raise a TypeError, too |  | ||||||
|         self.assertRaises(TypeError, BlogPost.objects(authors__in=author).count) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|     unittest.main() |     unittest.main() | ||||||
|   | |||||||
| @@ -238,8 +238,7 @@ class TransformTest(unittest.TestCase): | |||||||
|         box = [(35.0, -125.0), (40.0, -100.0)] |         box = [(35.0, -125.0), (40.0, -100.0)] | ||||||
|         # I *meant* to execute location__within_box=box |         # I *meant* to execute location__within_box=box | ||||||
|         events = Event.objects(location__within=box) |         events = Event.objects(location__within=box) | ||||||
|         with self.assertRaises(InvalidQueryError): |         self.assertRaises(InvalidQueryError, lambda: events.count()) | ||||||
|             events.count() |  | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|   | |||||||
| @@ -185,7 +185,7 @@ class QTest(unittest.TestCase): | |||||||
|             x = IntField() |             x = IntField() | ||||||
|  |  | ||||||
|         TestDoc.drop_collection() |         TestDoc.drop_collection() | ||||||
|         for i in range(1, 101): |         for i in xrange(1, 101): | ||||||
|             t = TestDoc(x=i) |             t = TestDoc(x=i) | ||||||
|             t.save() |             t.save() | ||||||
|  |  | ||||||
| @@ -268,13 +268,14 @@ class QTest(unittest.TestCase): | |||||||
|         self.assertEqual(self.Person.objects(Q(age__in=[20, 30])).count(), 3) |         self.assertEqual(self.Person.objects(Q(age__in=[20, 30])).count(), 3) | ||||||
|  |  | ||||||
|         # Test invalid query objs |         # Test invalid query objs | ||||||
|         with self.assertRaises(InvalidQueryError): |         def wrong_query_objs(): | ||||||
|             self.Person.objects('user1') |             self.Person.objects('user1') | ||||||
|  |  | ||||||
|         # filter should fail, too |         def wrong_query_objs_filter(): | ||||||
|         with self.assertRaises(InvalidQueryError): |             self.Person.objects('user1') | ||||||
|             self.Person.objects.filter('user1') |  | ||||||
|  |  | ||||||
|  |         self.assertRaises(InvalidQueryError, wrong_query_objs) | ||||||
|  |         self.assertRaises(InvalidQueryError, wrong_query_objs_filter) | ||||||
|  |  | ||||||
|     def test_q_regex(self): |     def test_q_regex(self): | ||||||
|         """Ensure that Q objects can be queried using regexes. |         """Ensure that Q objects can be queried using regexes. | ||||||
|   | |||||||
| @@ -1,6 +1,9 @@ | |||||||
|  | import sys | ||||||
| import datetime | import datetime | ||||||
| from pymongo.errors import OperationFailure | from pymongo.errors import OperationFailure | ||||||
|  |  | ||||||
|  | sys.path[0:0] = [""] | ||||||
|  |  | ||||||
| try: | try: | ||||||
|     import unittest2 as unittest |     import unittest2 as unittest | ||||||
| except ImportError: | except ImportError: | ||||||
| @@ -16,8 +19,7 @@ from mongoengine import ( | |||||||
| ) | ) | ||||||
| from mongoengine.python_support import IS_PYMONGO_3 | from mongoengine.python_support import IS_PYMONGO_3 | ||||||
| import mongoengine.connection | import mongoengine.connection | ||||||
| from mongoengine.connection import (MongoEngineConnectionError, get_db, | from mongoengine.connection import get_db, get_connection, ConnectionError | ||||||
|                                     get_connection) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_tz_awareness(connection): | def get_tz_awareness(connection): | ||||||
| @@ -157,10 +159,7 @@ class ConnectionTest(unittest.TestCase): | |||||||
|         c.mongoenginetest.add_user("username", "password") |         c.mongoenginetest.add_user("username", "password") | ||||||
|  |  | ||||||
|         if not IS_PYMONGO_3: |         if not IS_PYMONGO_3: | ||||||
|             self.assertRaises( |             self.assertRaises(ConnectionError, connect, "testdb_uri_bad", host='mongodb://test:password@localhost') | ||||||
|                 MongoEngineConnectionError, connect, 'testdb_uri_bad', |  | ||||||
|                 host='mongodb://test:password@localhost' |  | ||||||
|             ) |  | ||||||
|  |  | ||||||
|         connect("testdb_uri", host='mongodb://username:password@localhost/mongoenginetest') |         connect("testdb_uri", host='mongodb://username:password@localhost/mongoenginetest') | ||||||
|  |  | ||||||
| @@ -200,19 +199,6 @@ class ConnectionTest(unittest.TestCase): | |||||||
|         self.assertTrue(isinstance(db, pymongo.database.Database)) |         self.assertTrue(isinstance(db, pymongo.database.Database)) | ||||||
|         self.assertEqual(db.name, 'test') |         self.assertEqual(db.name, 'test') | ||||||
|  |  | ||||||
|     def test_connect_uri_with_replicaset(self): |  | ||||||
|         """Ensure connect() works when specifying a replicaSet.""" |  | ||||||
|         if IS_PYMONGO_3: |  | ||||||
|             c = connect(host='mongodb://localhost/test?replicaSet=local-rs') |  | ||||||
|             db = get_db() |  | ||||||
|             self.assertTrue(isinstance(db, pymongo.database.Database)) |  | ||||||
|             self.assertEqual(db.name, 'test') |  | ||||||
|         else: |  | ||||||
|             # PyMongo < v3.x raises an exception: |  | ||||||
|             # "localhost:27017 is not a member of replica set local-rs" |  | ||||||
|             with self.assertRaises(MongoEngineConnectionError): |  | ||||||
|                 c = connect(host='mongodb://localhost/test?replicaSet=local-rs') |  | ||||||
|  |  | ||||||
|     def test_uri_without_credentials_doesnt_override_conn_settings(self): |     def test_uri_without_credentials_doesnt_override_conn_settings(self): | ||||||
|         """Ensure connect() uses the username & password params if the URI |         """Ensure connect() uses the username & password params if the URI | ||||||
|         doesn't explicitly specify them. |         doesn't explicitly specify them. | ||||||
| @@ -243,11 +229,10 @@ class ConnectionTest(unittest.TestCase): | |||||||
|             self.assertRaises(OperationFailure, test_conn.server_info) |             self.assertRaises(OperationFailure, test_conn.server_info) | ||||||
|         else: |         else: | ||||||
|             self.assertRaises( |             self.assertRaises( | ||||||
|                 MongoEngineConnectionError, connect, 'mongoenginetest', |                 ConnectionError, connect, 'mongoenginetest', alias='test1', | ||||||
|                 alias='test1', |  | ||||||
|                 host='mongodb://username2:password@localhost/mongoenginetest' |                 host='mongodb://username2:password@localhost/mongoenginetest' | ||||||
|             ) |             ) | ||||||
|             self.assertRaises(MongoEngineConnectionError, get_db, 'test1') |             self.assertRaises(ConnectionError, get_db, 'test1') | ||||||
|  |  | ||||||
|         # Authentication succeeds with "authSource" |         # Authentication succeeds with "authSource" | ||||||
|         connect( |         connect( | ||||||
| @@ -268,7 +253,7 @@ class ConnectionTest(unittest.TestCase): | |||||||
|         """ |         """ | ||||||
|         register_connection('testdb', 'mongoenginetest2') |         register_connection('testdb', 'mongoenginetest2') | ||||||
|  |  | ||||||
|         self.assertRaises(MongoEngineConnectionError, get_connection) |         self.assertRaises(ConnectionError, get_connection) | ||||||
|         conn = get_connection('testdb') |         conn = get_connection('testdb') | ||||||
|         self.assertTrue(isinstance(conn, pymongo.mongo_client.MongoClient)) |         self.assertTrue(isinstance(conn, pymongo.mongo_client.MongoClient)) | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1,3 +1,5 @@ | |||||||
|  | import sys | ||||||
|  | sys.path[0:0] = [""] | ||||||
| import unittest | import unittest | ||||||
|  |  | ||||||
| from mongoengine import * | from mongoengine import * | ||||||
| @@ -77,7 +79,7 @@ class ContextManagersTest(unittest.TestCase): | |||||||
|         User.drop_collection() |         User.drop_collection() | ||||||
|         Group.drop_collection() |         Group.drop_collection() | ||||||
|  |  | ||||||
|         for i in range(1, 51): |         for i in xrange(1, 51): | ||||||
|             User(name='user %s' % i).save() |             User(name='user %s' % i).save() | ||||||
|  |  | ||||||
|         user = User.objects.first() |         user = User.objects.first() | ||||||
| @@ -115,7 +117,7 @@ class ContextManagersTest(unittest.TestCase): | |||||||
|         User.drop_collection() |         User.drop_collection() | ||||||
|         Group.drop_collection() |         Group.drop_collection() | ||||||
|  |  | ||||||
|         for i in range(1, 51): |         for i in xrange(1, 51): | ||||||
|             User(name='user %s' % i).save() |             User(name='user %s' % i).save() | ||||||
|  |  | ||||||
|         user = User.objects.first() |         user = User.objects.first() | ||||||
| @@ -193,7 +195,7 @@ class ContextManagersTest(unittest.TestCase): | |||||||
|         with query_counter() as q: |         with query_counter() as q: | ||||||
|             self.assertEqual(0, q) |             self.assertEqual(0, q) | ||||||
|  |  | ||||||
|             for i in range(1, 51): |             for i in xrange(1, 51): | ||||||
|                 db.test.find({}).count() |                 db.test.find({}).count() | ||||||
|  |  | ||||||
|             self.assertEqual(50, q) |             self.assertEqual(50, q) | ||||||
|   | |||||||
| @@ -1,6 +1,5 @@ | |||||||
| import unittest | import unittest | ||||||
|  | from mongoengine.base.datastructures import StrictDict, SemiStrictDict  | ||||||
| from mongoengine.base.datastructures import StrictDict, SemiStrictDict |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class TestStrictDict(unittest.TestCase): | class TestStrictDict(unittest.TestCase): | ||||||
| @@ -14,18 +13,9 @@ class TestStrictDict(unittest.TestCase): | |||||||
|         d = self.dtype(a=1, b=1, c=1) |         d = self.dtype(a=1, b=1, c=1) | ||||||
|         self.assertEqual((d.a, d.b, d.c), (1, 1, 1)) |         self.assertEqual((d.a, d.b, d.c), (1, 1, 1)) | ||||||
|  |  | ||||||
|     def test_repr(self): |  | ||||||
|         d = self.dtype(a=1, b=2, c=3) |  | ||||||
|         self.assertEqual(repr(d), '{"a": 1, "b": 2, "c": 3}') |  | ||||||
|  |  | ||||||
|         # make sure quotes are escaped properly |  | ||||||
|         d = self.dtype(a='"', b="'", c="") |  | ||||||
|         self.assertEqual(repr(d), '{"a": \'"\', "b": "\'", "c": \'\'}') |  | ||||||
|  |  | ||||||
|     def test_init_fails_on_nonexisting_attrs(self): |     def test_init_fails_on_nonexisting_attrs(self): | ||||||
|         with self.assertRaises(AttributeError): |         self.assertRaises(AttributeError, lambda: self.dtype(a=1, b=2, d=3)) | ||||||
|             self.dtype(a=1, b=2, d=3) |          | ||||||
|  |  | ||||||
|     def test_eq(self): |     def test_eq(self): | ||||||
|         d = self.dtype(a=1, b=1, c=1) |         d = self.dtype(a=1, b=1, c=1) | ||||||
|         dd = self.dtype(a=1, b=1, c=1) |         dd = self.dtype(a=1, b=1, c=1) | ||||||
| @@ -34,7 +24,7 @@ class TestStrictDict(unittest.TestCase): | |||||||
|         g = self.strict_dict_class(("a", "b", "c", "d"))(a=1, b=1, c=1, d=1) |         g = self.strict_dict_class(("a", "b", "c", "d"))(a=1, b=1, c=1, d=1) | ||||||
|         h = self.strict_dict_class(("a", "c", "b"))(a=1, b=1, c=1) |         h = self.strict_dict_class(("a", "c", "b"))(a=1, b=1, c=1) | ||||||
|         i = self.strict_dict_class(("a", "c", "b"))(a=1, b=1, c=2) |         i = self.strict_dict_class(("a", "c", "b"))(a=1, b=1, c=2) | ||||||
|  |          | ||||||
|         self.assertEqual(d, dd) |         self.assertEqual(d, dd) | ||||||
|         self.assertNotEqual(d, e) |         self.assertNotEqual(d, e) | ||||||
|         self.assertNotEqual(d, f) |         self.assertNotEqual(d, f) | ||||||
| @@ -47,18 +37,20 @@ class TestStrictDict(unittest.TestCase): | |||||||
|         d = self.dtype() |         d = self.dtype() | ||||||
|         d.a = 1 |         d.a = 1 | ||||||
|         self.assertEqual(d.a, 1) |         self.assertEqual(d.a, 1) | ||||||
|         self.assertRaises(AttributeError, getattr, d, 'b') |         self.assertRaises(AttributeError, lambda: d.b) | ||||||
|  |      | ||||||
|     def test_setattr_raises_on_nonexisting_attr(self): |     def test_setattr_raises_on_nonexisting_attr(self): | ||||||
|         d = self.dtype() |         d = self.dtype() | ||||||
|         with self.assertRaises(AttributeError): |  | ||||||
|             d.x = 1 |  | ||||||
|  |  | ||||||
|  |         def _f(): | ||||||
|  |             d.x = 1 | ||||||
|  |         self.assertRaises(AttributeError, _f) | ||||||
|  |      | ||||||
|     def test_setattr_getattr_special(self): |     def test_setattr_getattr_special(self): | ||||||
|         d = self.strict_dict_class(["items"]) |         d = self.strict_dict_class(["items"]) | ||||||
|         d.items = 1 |         d.items = 1 | ||||||
|         self.assertEqual(d.items, 1) |         self.assertEqual(d.items, 1) | ||||||
|  |      | ||||||
|     def test_get(self): |     def test_get(self): | ||||||
|         d = self.dtype(a=1) |         d = self.dtype(a=1) | ||||||
|         self.assertEqual(d.get('a'), 1) |         self.assertEqual(d.get('a'), 1) | ||||||
| @@ -96,7 +88,7 @@ class TestSemiSrictDict(TestStrictDict): | |||||||
|     def test_init_succeeds_with_nonexisting_attrs(self): |     def test_init_succeeds_with_nonexisting_attrs(self): | ||||||
|         d = self.dtype(a=1, b=1, c=1, x=2) |         d = self.dtype(a=1, b=1, c=1, x=2) | ||||||
|         self.assertEqual((d.a, d.b, d.c, d.x), (1, 1, 1, 2)) |         self.assertEqual((d.a, d.b, d.c, d.x), (1, 1, 1, 2)) | ||||||
|  |     | ||||||
|     def test_iter_with_nonexisting_attrs(self): |     def test_iter_with_nonexisting_attrs(self): | ||||||
|         d = self.dtype(a=1, b=1, c=1, x=2) |         d = self.dtype(a=1, b=1, c=1, x=2) | ||||||
|         self.assertEqual(list(d), ['a', 'b', 'c', 'x']) |         self.assertEqual(list(d), ['a', 'b', 'c', 'x']) | ||||||
|   | |||||||
| @@ -1,4 +1,6 @@ | |||||||
| # -*- coding: utf-8 -*- | # -*- coding: utf-8 -*- | ||||||
|  | import sys | ||||||
|  | sys.path[0:0] = [""] | ||||||
| import unittest | import unittest | ||||||
|  |  | ||||||
| from bson import DBRef, ObjectId | from bson import DBRef, ObjectId | ||||||
| @@ -30,7 +32,7 @@ class FieldTest(unittest.TestCase): | |||||||
|         User.drop_collection() |         User.drop_collection() | ||||||
|         Group.drop_collection() |         Group.drop_collection() | ||||||
|  |  | ||||||
|         for i in range(1, 51): |         for i in xrange(1, 51): | ||||||
|             user = User(name='user %s' % i) |             user = User(name='user %s' % i) | ||||||
|             user.save() |             user.save() | ||||||
|  |  | ||||||
| @@ -88,7 +90,7 @@ class FieldTest(unittest.TestCase): | |||||||
|         User.drop_collection() |         User.drop_collection() | ||||||
|         Group.drop_collection() |         Group.drop_collection() | ||||||
|  |  | ||||||
|         for i in range(1, 51): |         for i in xrange(1, 51): | ||||||
|             user = User(name='user %s' % i) |             user = User(name='user %s' % i) | ||||||
|             user.save() |             user.save() | ||||||
|  |  | ||||||
| @@ -160,7 +162,7 @@ class FieldTest(unittest.TestCase): | |||||||
|         User.drop_collection() |         User.drop_collection() | ||||||
|         Group.drop_collection() |         Group.drop_collection() | ||||||
|  |  | ||||||
|         for i in range(1, 26): |         for i in xrange(1, 26): | ||||||
|             user = User(name='user %s' % i) |             user = User(name='user %s' % i) | ||||||
|             user.save() |             user.save() | ||||||
|  |  | ||||||
| @@ -438,7 +440,7 @@ class FieldTest(unittest.TestCase): | |||||||
|         Group.drop_collection() |         Group.drop_collection() | ||||||
|  |  | ||||||
|         members = [] |         members = [] | ||||||
|         for i in range(1, 51): |         for i in xrange(1, 51): | ||||||
|             a = UserA(name='User A %s' % i) |             a = UserA(name='User A %s' % i) | ||||||
|             a.save() |             a.save() | ||||||
|  |  | ||||||
| @@ -529,7 +531,7 @@ class FieldTest(unittest.TestCase): | |||||||
|         Group.drop_collection() |         Group.drop_collection() | ||||||
|  |  | ||||||
|         members = [] |         members = [] | ||||||
|         for i in range(1, 51): |         for i in xrange(1, 51): | ||||||
|             a = UserA(name='User A %s' % i) |             a = UserA(name='User A %s' % i) | ||||||
|             a.save() |             a.save() | ||||||
|  |  | ||||||
| @@ -612,15 +614,15 @@ class FieldTest(unittest.TestCase): | |||||||
|         Group.drop_collection() |         Group.drop_collection() | ||||||
|  |  | ||||||
|         members = [] |         members = [] | ||||||
|         for i in range(1, 51): |         for i in xrange(1, 51): | ||||||
|             user = User(name='user %s' % i) |             user = User(name='user %s' % i) | ||||||
|             user.save() |             user.save() | ||||||
|             members.append(user) |             members.append(user) | ||||||
|  |  | ||||||
|         group = Group(members={str(u.id): u for u in members}) |         group = Group(members=dict([(str(u.id), u) for u in members])) | ||||||
|         group.save() |         group.save() | ||||||
|  |  | ||||||
|         group = Group(members={str(u.id): u for u in members}) |         group = Group(members=dict([(str(u.id), u) for u in members])) | ||||||
|         group.save() |         group.save() | ||||||
|  |  | ||||||
|         with query_counter() as q: |         with query_counter() as q: | ||||||
| @@ -685,7 +687,7 @@ class FieldTest(unittest.TestCase): | |||||||
|         Group.drop_collection() |         Group.drop_collection() | ||||||
|  |  | ||||||
|         members = [] |         members = [] | ||||||
|         for i in range(1, 51): |         for i in xrange(1, 51): | ||||||
|             a = UserA(name='User A %s' % i) |             a = UserA(name='User A %s' % i) | ||||||
|             a.save() |             a.save() | ||||||
|  |  | ||||||
| @@ -697,9 +699,9 @@ class FieldTest(unittest.TestCase): | |||||||
|  |  | ||||||
|             members += [a, b, c] |             members += [a, b, c] | ||||||
|  |  | ||||||
|         group = Group(members={str(u.id): u for u in members}) |         group = Group(members=dict([(str(u.id), u) for u in members])) | ||||||
|         group.save() |         group.save() | ||||||
|         group = Group(members={str(u.id): u for u in members}) |         group = Group(members=dict([(str(u.id), u) for u in members])) | ||||||
|         group.save() |         group.save() | ||||||
|  |  | ||||||
|         with query_counter() as q: |         with query_counter() as q: | ||||||
| @@ -781,16 +783,16 @@ class FieldTest(unittest.TestCase): | |||||||
|         Group.drop_collection() |         Group.drop_collection() | ||||||
|  |  | ||||||
|         members = [] |         members = [] | ||||||
|         for i in range(1, 51): |         for i in xrange(1, 51): | ||||||
|             a = UserA(name='User A %s' % i) |             a = UserA(name='User A %s' % i) | ||||||
|             a.save() |             a.save() | ||||||
|  |  | ||||||
|             members += [a] |             members += [a] | ||||||
|  |  | ||||||
|         group = Group(members={str(u.id): u for u in members}) |         group = Group(members=dict([(str(u.id), u) for u in members])) | ||||||
|         group.save() |         group.save() | ||||||
|  |  | ||||||
|         group = Group(members={str(u.id): u for u in members}) |         group = Group(members=dict([(str(u.id), u) for u in members])) | ||||||
|         group.save() |         group.save() | ||||||
|  |  | ||||||
|         with query_counter() as q: |         with query_counter() as q: | ||||||
| @@ -864,7 +866,7 @@ class FieldTest(unittest.TestCase): | |||||||
|         Group.drop_collection() |         Group.drop_collection() | ||||||
|  |  | ||||||
|         members = [] |         members = [] | ||||||
|         for i in range(1, 51): |         for i in xrange(1, 51): | ||||||
|             a = UserA(name='User A %s' % i) |             a = UserA(name='User A %s' % i) | ||||||
|             a.save() |             a.save() | ||||||
|  |  | ||||||
| @@ -876,9 +878,9 @@ class FieldTest(unittest.TestCase): | |||||||
|  |  | ||||||
|             members += [a, b, c] |             members += [a, b, c] | ||||||
|  |  | ||||||
|         group = Group(members={str(u.id): u for u in members}) |         group = Group(members=dict([(str(u.id), u) for u in members])) | ||||||
|         group.save() |         group.save() | ||||||
|         group = Group(members={str(u.id): u for u in members}) |         group = Group(members=dict([(str(u.id), u) for u in members])) | ||||||
|         group.save() |         group.save() | ||||||
|  |  | ||||||
|         with query_counter() as q: |         with query_counter() as q: | ||||||
| @@ -1101,7 +1103,7 @@ class FieldTest(unittest.TestCase): | |||||||
|         User.drop_collection() |         User.drop_collection() | ||||||
|         Group.drop_collection() |         Group.drop_collection() | ||||||
|  |  | ||||||
|         for i in range(1, 51): |         for i in xrange(1, 51): | ||||||
|             User(name='user %s' % i).save() |             User(name='user %s' % i).save() | ||||||
|  |  | ||||||
|         Group(name="Test", members=User.objects).save() |         Group(name="Test", members=User.objects).save() | ||||||
| @@ -1130,7 +1132,7 @@ class FieldTest(unittest.TestCase): | |||||||
|         User.drop_collection() |         User.drop_collection() | ||||||
|         Group.drop_collection() |         Group.drop_collection() | ||||||
|  |  | ||||||
|         for i in range(1, 51): |         for i in xrange(1, 51): | ||||||
|             User(name='user %s' % i).save() |             User(name='user %s' % i).save() | ||||||
|  |  | ||||||
|         Group(name="Test", members=User.objects).save() |         Group(name="Test", members=User.objects).save() | ||||||
| @@ -1167,7 +1169,7 @@ class FieldTest(unittest.TestCase): | |||||||
|         Group.drop_collection() |         Group.drop_collection() | ||||||
|  |  | ||||||
|         members = [] |         members = [] | ||||||
|         for i in range(1, 51): |         for i in xrange(1, 51): | ||||||
|             a = UserA(name='User A %s' % i).save() |             a = UserA(name='User A %s' % i).save() | ||||||
|             b = UserB(name='User B %s' % i).save() |             b = UserB(name='User B %s' % i).save() | ||||||
|             c = UserC(name='User C %s' % i).save() |             c = UserC(name='User C %s' % i).save() | ||||||
|   | |||||||
| @@ -1,3 +1,6 @@ | |||||||
|  | import sys | ||||||
|  |  | ||||||
|  | sys.path[0:0] = [""] | ||||||
| import unittest | import unittest | ||||||
|  |  | ||||||
| from pymongo import ReadPreference | from pymongo import ReadPreference | ||||||
| @@ -15,7 +18,7 @@ else: | |||||||
|  |  | ||||||
| import mongoengine | import mongoengine | ||||||
| from mongoengine import * | from mongoengine import * | ||||||
| from mongoengine.connection import MongoEngineConnectionError | from mongoengine.connection import ConnectionError | ||||||
|  |  | ||||||
|  |  | ||||||
| class ConnectionTest(unittest.TestCase): | class ConnectionTest(unittest.TestCase): | ||||||
| @@ -38,7 +41,7 @@ class ConnectionTest(unittest.TestCase): | |||||||
|             conn = connect(db='mongoenginetest', |             conn = connect(db='mongoenginetest', | ||||||
|                            host="mongodb://localhost/mongoenginetest?replicaSet=rs", |                            host="mongodb://localhost/mongoenginetest?replicaSet=rs", | ||||||
|                            read_preference=READ_PREF) |                            read_preference=READ_PREF) | ||||||
|         except MongoEngineConnectionError as e: |         except ConnectionError, e: | ||||||
|             return |             return | ||||||
|  |  | ||||||
|         if not isinstance(conn, CONN_CLASS): |         if not isinstance(conn, CONN_CLASS): | ||||||
|   | |||||||
| @@ -1,4 +1,6 @@ | |||||||
| # -*- coding: utf-8 -*- | # -*- coding: utf-8 -*- | ||||||
|  | import sys | ||||||
|  | sys.path[0:0] = [""] | ||||||
| import unittest | import unittest | ||||||
|  |  | ||||||
| from mongoengine import * | from mongoengine import * | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user