Compare commits
97 Commits
Author | SHA1 | Date | |
---|---|---|---|
|
aa8a991d20 | ||
|
40ba51ac43 | ||
|
d20430a778 | ||
|
f08f749cd9 | ||
|
a6c04f4f9a | ||
|
15b6c1590f | ||
|
4a8985278d | ||
|
996618a495 | ||
|
1f02d5fbbd | ||
|
c58b9f00f0 | ||
|
f131b18cbe | ||
|
118a998138 | ||
|
7ad6f036e7 | ||
|
1d29b824a8 | ||
|
3caf2dce28 | ||
|
1fc5b954f2 | ||
|
31d99c0bd2 | ||
|
0ac59c67ea | ||
|
8e8c74c621 | ||
|
f996f3df74 | ||
|
9499c97e18 | ||
|
c1c81fc07b | ||
|
072e86a2f0 | ||
|
70d6e763b0 | ||
|
15f4d4fee6 | ||
|
82e28dec43 | ||
|
b407c0e6c6 | ||
|
27ea01ee05 | ||
|
7ed5829b2c | ||
|
5bf1dd55b1 | ||
|
36aebffcc0 | ||
|
84c42ed58c | ||
|
9634e44343 | ||
|
048a045966 | ||
|
a18c8c0eb4 | ||
|
5fb0f46e3f | ||
|
962997ed16 | ||
|
daca0ebc14 | ||
|
9ae8fe7c2d | ||
|
1907133f99 | ||
|
4334955e39 | ||
|
f00c9dc4d6 | ||
|
7d0687ec73 | ||
|
da3773bfe8 | ||
|
6e1c132ee8 | ||
|
24ba35d76f | ||
|
64b63e9d52 | ||
|
7848a82a1c | ||
|
6a843cc8b2 | ||
|
ecdb0785a4 | ||
|
9a55caed75 | ||
|
2e01eb87db | ||
|
597b962ad5 | ||
|
7531f533e0 | ||
|
6b9d71554e | ||
|
bb1089e03d | ||
|
c82f0c937d | ||
|
00d2fd685a | ||
|
f28e1b8c90 | ||
|
2b17985a11 | ||
|
b392e3102e | ||
|
58b0b18ddd | ||
|
6a9ef319d0 | ||
|
cf38ef70cb | ||
|
ac64ade10f | ||
|
ee85af34d8 | ||
|
9d53ad53e5 | ||
|
9cdc3ebee6 | ||
|
14a5e05d64 | ||
|
f7b7d0f79e | ||
|
d98f36ceff | ||
|
abfabc30c9 | ||
|
c1aff7a248 | ||
|
e44f71eeb1 | ||
|
cb578c84e2 | ||
|
565e1dc0ed | ||
|
b1e28d02f7 | ||
|
d1467c2f73 | ||
|
c439150431 | ||
|
9bb3dfd639 | ||
|
4caa58b9ec | ||
|
b5213097e8 | ||
|
61081651e4 | ||
|
4ccfdf051d | ||
|
9f2a9d9cda | ||
|
827de76345 | ||
|
fdcaca42ae | ||
|
0744892244 | ||
|
b70ffc69df | ||
|
73b12cc32f | ||
|
ba6a37f315 | ||
|
6f8be8c8ac | ||
|
68497542b3 | ||
|
3d762fed10 | ||
|
48b849c031 | ||
|
9b02867293 | ||
|
99a5f2cd9d |
@@ -3,30 +3,20 @@
|
||||
sudo apt-get remove mongodb-org-server
|
||||
sudo apt-key adv --keyserver hkp://keyserver.ubuntu.com:80 --recv 7F0CEB10
|
||||
|
||||
if [ "$MONGODB" = "2.6" ]; then
|
||||
echo "deb http://downloads-distro.mongodb.org/repo/ubuntu-upstart dist 10gen" | sudo tee /etc/apt/sources.list.d/mongodb.list
|
||||
sudo apt-get update
|
||||
sudo apt-get install mongodb-org-server=2.6.12
|
||||
# service should be started automatically
|
||||
elif [ "$MONGODB" = "3.0" ]; then
|
||||
echo "deb http://repo.mongodb.org/apt/ubuntu trusty/mongodb-org/3.0 multiverse" | sudo tee /etc/apt/sources.list.d/mongodb.list
|
||||
sudo apt-get update
|
||||
sudo apt-get install mongodb-org-server=3.0.14
|
||||
# service should be started automatically
|
||||
elif [ "$MONGODB" = "3.2" ]; then
|
||||
sudo apt-key adv --keyserver keyserver.ubuntu.com --recv EA312927
|
||||
echo "deb http://repo.mongodb.org/apt/ubuntu trusty/mongodb-org/3.2 multiverse" | sudo tee /etc/apt/sources.list.d/mongodb-org-3.2.list
|
||||
sudo apt-get update
|
||||
sudo apt-get install mongodb-org-server=3.2.20
|
||||
# service should be started automatically
|
||||
elif [ "$MONGODB" = "3.4" ]; then
|
||||
if [ "$MONGODB" = "3.4" ]; then
|
||||
sudo apt-key adv --keyserver keyserver.ubuntu.com:80 --recv 0C49F3730359A14518585931BC711F9BA15703C6
|
||||
echo "deb http://repo.mongodb.org/apt/ubuntu trusty/mongodb-org/3.4 multiverse" | sudo tee /etc/apt/sources.list.d/mongodb-org-3.4.list
|
||||
sudo apt-get update
|
||||
sudo apt-get install mongodb-org-server=3.4.17
|
||||
# service should be started automatically
|
||||
elif [ "$MONGODB" = "3.6" ]; then
|
||||
sudo apt-key adv --keyserver hkp://keyserver.ubuntu.com:80 --recv 2930ADAE8CAF5059EE73BB4B58712A2291FA4AD5
|
||||
echo "deb http://repo.mongodb.org/apt/ubuntu trusty/mongodb-org/3.6 multiverse" | sudo tee /etc/apt/sources.list.d/mongodb-org-3.6.list
|
||||
sudo apt-get update
|
||||
sudo apt-get install mongodb-org-server=3.6.12
|
||||
# service should be started automatically
|
||||
else
|
||||
echo "Invalid MongoDB version, expected 2.6, 3.0, 3.2 or 3.4."
|
||||
echo "Invalid MongoDB version, expected 2.6, 3.0, 3.2, 3.4 or 3.6."
|
||||
exit 1
|
||||
fi;
|
||||
|
||||
|
47
.travis.yml
47
.travis.yml
@@ -2,11 +2,18 @@
|
||||
# PyMongo combinations. However, that would result in an overly long build
|
||||
# with a very large number of jobs, hence we only test a subset of all the
|
||||
# combinations:
|
||||
# * MongoDB v2.6 is currently the "main" version tested against Python v2.7,
|
||||
# v3.5, v3.6, PyPy, and PyMongo v3.x.
|
||||
# * MongoDB v3.0 & v3.2 are tested against Python v2.7, v3.5 & v3.6
|
||||
# and Pymongo v3.5 & v3.x
|
||||
# * MongoDB v3.4 is tested against v3.6 and Pymongo v3.x
|
||||
# * MongoDB v3.4 & the latest PyMongo v3.x is currently the "main" setup,
|
||||
# tested against Python v2.7, v3.5, v3.6, and PyPy.
|
||||
# * Besides that, we test the lowest actively supported Python/MongoDB/PyMongo
|
||||
# combination: MongoDB v3.4, PyMongo v3.4, Python v2.7.
|
||||
# * MongoDB v3.6 is tested against Python v3.6, and PyMongo v3.6, v3.7, v3.8.
|
||||
#
|
||||
# We should periodically check MongoDB Server versions supported by MongoDB
|
||||
# Inc., add newly released versions to the test matrix, and remove versions
|
||||
# which have reached their End of Life. See:
|
||||
# 1. https://www.mongodb.com/support-policy.
|
||||
# 2. https://docs.mongodb.com/ecosystem/drivers/driver-compatibility-reference/#python-driver-compatibility
|
||||
#
|
||||
# Reminder: Update README.rst if you change MongoDB versions we test.
|
||||
|
||||
language: python
|
||||
@@ -18,7 +25,7 @@ python:
|
||||
- pypy
|
||||
|
||||
env:
|
||||
- MONGODB=2.6 PYMONGO=3.x
|
||||
- MONGODB=3.4 PYMONGO=3.x
|
||||
|
||||
matrix:
|
||||
# Finish the build as soon as one job fails
|
||||
@@ -26,19 +33,13 @@ matrix:
|
||||
|
||||
include:
|
||||
- python: 2.7
|
||||
env: MONGODB=3.0 PYMONGO=3.5
|
||||
- python: 3.5
|
||||
env: MONGODB=3.2 PYMONGO=3.x
|
||||
env: MONGODB=3.4 PYMONGO=3.4.x
|
||||
- python: 3.6
|
||||
env: MONGODB=3.0 PYMONGO=3.5
|
||||
- python: 3.6
|
||||
env: MONGODB=3.2 PYMONGO=3.x
|
||||
- python: 3.6
|
||||
env: MONGODB=3.4 PYMONGO=3.x
|
||||
env: MONGODB=3.6 PYMONGO=3.x
|
||||
|
||||
before_install:
|
||||
- bash .install_mongodb_on_travis.sh
|
||||
- sleep 15 # https://docs.travis-ci.com/user/database-setup/#MongoDB-does-not-immediately-accept-connections
|
||||
- sleep 20 # https://docs.travis-ci.com/user/database-setup/#mongodb-does-not-immediately-accept-connections
|
||||
- mongo --eval 'db.version();'
|
||||
|
||||
install:
|
||||
@@ -48,8 +49,8 @@ install:
|
||||
- travis_retry pip install --upgrade pip
|
||||
- travis_retry pip install coveralls
|
||||
- travis_retry pip install flake8 flake8-import-order
|
||||
- travis_retry pip install tox>=1.9
|
||||
- travis_retry pip install "virtualenv<14.0.0" # virtualenv>=14.0.0 has dropped Python 3.2 support (and pypy3 is based on py32)
|
||||
- travis_retry pip install "tox" # tox 3.11.0 has requirement virtualenv>=14.0.0
|
||||
- travis_retry pip install "virtualenv" # virtualenv>=14.0.0 has dropped Python 3.2 support (and pypy3 is based on py32)
|
||||
- travis_retry tox -e $(echo py$TRAVIS_PYTHON_VERSION-mg$PYMONGO | tr -d . | sed -e 's/pypypy/pypy/') -- -e test
|
||||
|
||||
# Cache dependencies installed via pip
|
||||
@@ -84,15 +85,15 @@ deploy:
|
||||
password:
|
||||
secure: QMyatmWBnC6ZN3XLW2+fTBDU4LQcp1m/LjR2/0uamyeUzWKdlOoh/Wx5elOgLwt/8N9ppdPeG83ose1jOz69l5G0MUMjv8n/RIcMFSpCT59tGYqn3kh55b0cIZXFT9ar+5cxlif6a5rS72IHm5li7QQyxexJIII6Uxp0kpvUmek=
|
||||
|
||||
# create a source distribution and a pure python wheel for faster installs
|
||||
# Create a source distribution and a pure python wheel for faster installs.
|
||||
distributions: "sdist bdist_wheel"
|
||||
|
||||
# only deploy on tagged commits (aka GitHub releases) and only for the
|
||||
# parent repo's builds running Python 2.7 along with PyMongo v3.x (we run
|
||||
# Travis against many different Python and PyMongo versions and we don't
|
||||
# want the deploy to occur multiple times).
|
||||
# Only deploy on tagged commits (aka GitHub releases) and only for the parent
|
||||
# repo's builds running Python v2.7 along with PyMongo v3.x and MongoDB v3.4.
|
||||
# We run Travis against many different Python, PyMongo, and MongoDB versions
|
||||
# and we don't want the deploy to occur multiple times).
|
||||
on:
|
||||
tags: true
|
||||
repo: MongoEngine/mongoengine
|
||||
condition: "$PYMONGO = 3.x"
|
||||
condition: ($PYMONGO = 3.x) && ($MONGODB = 3.4)
|
||||
python: 2.7
|
||||
|
5
AUTHORS
5
AUTHORS
@@ -248,4 +248,7 @@ that much better:
|
||||
* Andy Yankovsky (https://github.com/werat)
|
||||
* Bastien Gérard (https://github.com/bagerard)
|
||||
* Trevor Hall (https://github.com/tjhall13)
|
||||
* Gleb Voropaev (https://github.com/buggyspace)
|
||||
* Gleb Voropaev (https://github.com/buggyspace)
|
||||
* Paulo Amaral (https://github.com/pauloAmaral)
|
||||
* Gaurav Dadhania (https://github.com/GVRV)
|
||||
* Yurii Andrieiev (https://github.com/yandrieiev)
|
||||
|
10
README.rst
10
README.rst
@@ -26,10 +26,10 @@ an `API reference <https://mongoengine-odm.readthedocs.io/apireference.html>`_.
|
||||
|
||||
Supported MongoDB Versions
|
||||
==========================
|
||||
MongoEngine is currently tested against MongoDB v2.6, v3.0, v3.2 and v3.4. Future
|
||||
versions should be supported as well, but aren't actively tested at the moment.
|
||||
Make sure to open an issue or submit a pull request if you experience any
|
||||
problems with MongoDB v3.4+.
|
||||
MongoEngine is currently tested against MongoDB v3.4 and v3.6. Future versions
|
||||
should be supported as well, but aren't actively tested at the moment. Make
|
||||
sure to open an issue or submit a pull request if you experience any problems
|
||||
with MongoDB version > 3.6.
|
||||
|
||||
Installation
|
||||
============
|
||||
@@ -47,7 +47,7 @@ Dependencies
|
||||
All of the dependencies can easily be installed via `pip <https://pip.pypa.io/>`_.
|
||||
At the very least, you'll need these two packages to use MongoEngine:
|
||||
|
||||
- pymongo>=2.7.1
|
||||
- pymongo>=3.5
|
||||
- six>=1.10.0
|
||||
|
||||
If you utilize a ``DateTimeField``, you might also use a more flexible date parser:
|
||||
|
@@ -13,6 +13,7 @@ Documents
|
||||
|
||||
.. autoclass:: mongoengine.Document
|
||||
:members:
|
||||
:inherited-members:
|
||||
|
||||
.. attribute:: objects
|
||||
|
||||
@@ -21,15 +22,18 @@ Documents
|
||||
|
||||
.. autoclass:: mongoengine.EmbeddedDocument
|
||||
:members:
|
||||
:inherited-members:
|
||||
|
||||
.. autoclass:: mongoengine.DynamicDocument
|
||||
:members:
|
||||
:inherited-members:
|
||||
|
||||
.. autoclass:: mongoengine.DynamicEmbeddedDocument
|
||||
:members:
|
||||
:inherited-members:
|
||||
|
||||
.. autoclass:: mongoengine.document.MapReduceDocument
|
||||
:members:
|
||||
:members:
|
||||
|
||||
.. autoclass:: mongoengine.ValidationError
|
||||
:members:
|
||||
|
@@ -6,6 +6,29 @@ Development
|
||||
===========
|
||||
- (Fill this out as you fix issues and develop your features).
|
||||
|
||||
Changes in 0.18.0
|
||||
=================
|
||||
- Drop support for EOL'd MongoDB v2.6, v3.0, and v3.2.
|
||||
- MongoEngine now requires PyMongo >= v3.4. Travis CI now tests against MongoDB v3.4 – v3.6 and PyMongo v3.4 – v3.6 (#2017 #2066).
|
||||
- Improve performance by avoiding a call to `to_mongo` in `Document.save()` #2049
|
||||
- Connection/disconnection improvements:
|
||||
- Expose `mongoengine.connection.disconnect` and `mongoengine.connection.disconnect_all`
|
||||
- Fix disconnecting #566 #1599 #605 #607 #1213 #565
|
||||
- Improve documentation of `connect`/`disconnect`
|
||||
- Fix issue when using multiple connections to the same mongo with different credentials #2047
|
||||
- `connect` fails immediately when db name contains invalid characters #2031 #1718
|
||||
- Fix the default write concern of `Document.save` that was overwriting the connection write concern #568
|
||||
- Fix querying on `List(EmbeddedDocument)` subclasses fields #1961 #1492
|
||||
- Fix querying on `(Generic)EmbeddedDocument` subclasses fields #475
|
||||
- Fix `QuerySet.aggregate` so that it takes limit and skip value into account #2029
|
||||
- Generate unique indices for `SortedListField` and `EmbeddedDocumentListFields` #2020
|
||||
- BREAKING CHANGE: Changed the behavior of a custom field validator (i.e `validation` parameter of a `Field`). It is now expected to raise a `ValidationError` instead of returning True/False #2050
|
||||
- BREAKING CHANGES (associated with connect/disconnect fixes):
|
||||
- Calling `connect` 2 times with the same alias and different parameter will raise an error (should call `disconnect` first).
|
||||
- `disconnect` now clears `mongoengine.connection._connection_settings`.
|
||||
- `disconnect` now clears the cached attribute `Document._collection`.
|
||||
- BREAKING CHANGE: `EmbeddedDocument.save` & `.reload` is no longier exist #1552
|
||||
|
||||
Changes in 0.17.0
|
||||
=================
|
||||
- Fix .only() working improperly after using .count() of the same instance of QuerySet
|
||||
@@ -15,6 +38,7 @@ Changes in 0.17.0
|
||||
- Fix InvalidStringData error when using modify on a BinaryField #1127
|
||||
- DEPRECATION: `EmbeddedDocument.save` & `.reload` are marked as deprecated and will be removed in a next version of mongoengine #1552
|
||||
- Fix test suite and CI to support MongoDB 3.4 #1445
|
||||
- Fix reference fields querying the database on each access if value contains orphan DBRefs
|
||||
|
||||
=================
|
||||
Changes in 0.16.3
|
||||
|
@@ -4,9 +4,11 @@
|
||||
Connecting to MongoDB
|
||||
=====================
|
||||
|
||||
To connect to a running instance of :program:`mongod`, use the
|
||||
:func:`~mongoengine.connect` function. The first argument is the name of the
|
||||
database to connect to::
|
||||
Connections in MongoEngine are registered globally and are identified with aliases.
|
||||
If no `alias` is provided during the connection, it will use "default" as alias.
|
||||
|
||||
To connect to a running instance of :program:`mongod`, use the :func:`~mongoengine.connect`
|
||||
function. The first argument is the name of the database to connect to::
|
||||
|
||||
from mongoengine import connect
|
||||
connect('project1')
|
||||
@@ -42,6 +44,9 @@ the :attr:`host` to
|
||||
will establish connection to ``production`` database using
|
||||
``admin`` username and ``qwerty`` password.
|
||||
|
||||
.. note:: Calling :func:`~mongoengine.connect` without argument will establish
|
||||
a connection to the "test" database by default
|
||||
|
||||
Replica Sets
|
||||
============
|
||||
|
||||
@@ -71,28 +76,61 @@ is used.
|
||||
In the background this uses :func:`~mongoengine.register_connection` to
|
||||
store the data and you can register all aliases up front if required.
|
||||
|
||||
Individual documents can also support multiple databases by providing a
|
||||
Documents defined in different database
|
||||
---------------------------------------
|
||||
Individual documents can be attached to different databases by providing a
|
||||
`db_alias` in their meta data. This allows :class:`~pymongo.dbref.DBRef`
|
||||
objects to point across databases and collections. Below is an example schema,
|
||||
using 3 different databases to store data::
|
||||
|
||||
connect(alias='user-db-alias', db='user-db')
|
||||
connect(alias='book-db-alias', db='book-db')
|
||||
connect(alias='users-books-db-alias', db='users-books-db')
|
||||
|
||||
class User(Document):
|
||||
name = StringField()
|
||||
|
||||
meta = {'db_alias': 'user-db'}
|
||||
meta = {'db_alias': 'user-db-alias'}
|
||||
|
||||
class Book(Document):
|
||||
name = StringField()
|
||||
|
||||
meta = {'db_alias': 'book-db'}
|
||||
meta = {'db_alias': 'book-db-alias'}
|
||||
|
||||
class AuthorBooks(Document):
|
||||
author = ReferenceField(User)
|
||||
book = ReferenceField(Book)
|
||||
|
||||
meta = {'db_alias': 'users-books-db'}
|
||||
meta = {'db_alias': 'users-books-db-alias'}
|
||||
|
||||
|
||||
Disconnecting an existing connection
|
||||
------------------------------------
|
||||
The function :func:`~mongoengine.disconnect` can be used to
|
||||
disconnect a particular connection. This can be used to change a
|
||||
connection globally::
|
||||
|
||||
from mongoengine import connect, disconnect
|
||||
connect('a_db', alias='db1')
|
||||
|
||||
class User(Document):
|
||||
name = StringField()
|
||||
meta = {'db_alias': 'db1'}
|
||||
|
||||
disconnect(alias='db1')
|
||||
|
||||
connect('another_db', alias='db1')
|
||||
|
||||
.. note:: Calling :func:`~mongoengine.disconnect` without argument
|
||||
will disconnect the "default" connection
|
||||
|
||||
.. note:: Since connections gets registered globally, it is important
|
||||
to use the `disconnect` function from MongoEngine and not the
|
||||
`disconnect()` method of an existing connection (pymongo.MongoClient)
|
||||
|
||||
.. note:: :class:`~mongoengine.Document` are caching the pymongo collection.
|
||||
using `disconnect` ensures that it gets cleaned as well
|
||||
|
||||
Context Managers
|
||||
================
|
||||
Sometimes you may want to switch the database or collection to query against.
|
||||
@@ -119,7 +157,7 @@ access to the same User document across databases::
|
||||
|
||||
Switch Collection
|
||||
-----------------
|
||||
The :class:`~mongoengine.context_managers.switch_collection` context manager
|
||||
The :func:`~mongoengine.context_managers.switch_collection` context manager
|
||||
allows you to change the collection for a given class allowing quick and easy
|
||||
access to the same Group document across collection::
|
||||
|
||||
|
@@ -176,6 +176,21 @@ arguments can be set on all fields:
|
||||
class Shirt(Document):
|
||||
size = StringField(max_length=3, choices=SIZE)
|
||||
|
||||
:attr:`validation` (Optional)
|
||||
A callable to validate the value of the field.
|
||||
The callable takes the value as parameter and should raise a ValidationError
|
||||
if validation fails
|
||||
|
||||
e.g ::
|
||||
|
||||
def _not_empty(val):
|
||||
if not val:
|
||||
raise ValidationError('value can not be empty')
|
||||
|
||||
class Person(Document):
|
||||
name = StringField(validation=_not_empty)
|
||||
|
||||
|
||||
:attr:`**kwargs` (Optional)
|
||||
You can supply additional metadata as arbitrary additional keyword
|
||||
arguments. You can not override existing attributes, however. Common
|
||||
|
@@ -19,3 +19,30 @@ or with an alias:
|
||||
|
||||
connect('mongoenginetest', host='mongomock://localhost', alias='testdb')
|
||||
conn = get_connection('testdb')
|
||||
|
||||
Example of test file:
|
||||
--------
|
||||
.. code-block:: python
|
||||
|
||||
import unittest
|
||||
from mongoengine import connect, disconnect
|
||||
|
||||
class Person(Document):
|
||||
name = StringField()
|
||||
|
||||
class TestPerson(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
connect('mongoenginetest', host='mongomock://localhost')
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
disconnect()
|
||||
|
||||
def test_thing(self):
|
||||
pers = Person(name='John')
|
||||
pers.save()
|
||||
|
||||
fresh_pers = Person.objects().first()
|
||||
self.assertEqual(fresh_pers.name, 'John')
|
||||
|
@@ -64,7 +64,7 @@ Available operators are as follows:
|
||||
* ``gt`` -- greater than
|
||||
* ``gte`` -- greater than or equal to
|
||||
* ``not`` -- negate a standard check, may be used before other operators (e.g.
|
||||
``Q(age__not__mod=5)``)
|
||||
``Q(age__not__mod=(5, 0))``)
|
||||
* ``in`` -- value is in list (a list of values should be provided)
|
||||
* ``nin`` -- value is not in list (a list of values should be provided)
|
||||
* ``mod`` -- ``value % x == y``, where ``x`` and ``y`` are two provided values
|
||||
|
@@ -23,12 +23,13 @@ __all__ = (list(document.__all__) + list(fields.__all__) +
|
||||
list(signals.__all__) + list(errors.__all__))
|
||||
|
||||
|
||||
VERSION = (0, 17, 0)
|
||||
VERSION = (0, 18, 0)
|
||||
|
||||
|
||||
def get_version():
|
||||
"""Return the VERSION as a string, e.g. for VERSION == (0, 10, 7),
|
||||
return '0.10.7'.
|
||||
"""Return the VERSION as a string.
|
||||
|
||||
For example, if `VERSION == (0, 10, 7)`, return '0.10.7'.
|
||||
"""
|
||||
return '.'.join(map(str, VERSION))
|
||||
|
||||
|
@@ -13,7 +13,7 @@ _document_registry = {}
|
||||
|
||||
|
||||
def get_document(name):
|
||||
"""Get a document class by name."""
|
||||
"""Get a registered Document class by name."""
|
||||
doc = _document_registry.get(name, None)
|
||||
if not doc:
|
||||
# Possible old style name
|
||||
@@ -30,3 +30,12 @@ def get_document(name):
|
||||
been imported?
|
||||
""".strip() % name)
|
||||
return doc
|
||||
|
||||
|
||||
def _get_documents_by_db(connection_alias, default_connection_alias):
|
||||
"""Get all registered Documents class attached to a given database"""
|
||||
def get_doc_alias(doc_cls):
|
||||
return doc_cls._meta.get('db_alias', default_connection_alias)
|
||||
|
||||
return [doc_cls for doc_cls in _document_registry.values()
|
||||
if get_doc_alias(doc_cls) == connection_alias]
|
||||
|
@@ -293,8 +293,7 @@ class BaseDocument(object):
|
||||
"""
|
||||
Return as SON data ready for use with MongoDB.
|
||||
"""
|
||||
if not fields:
|
||||
fields = []
|
||||
fields = fields or []
|
||||
|
||||
data = SON()
|
||||
data['_id'] = None
|
||||
@@ -349,6 +348,9 @@ class BaseDocument(object):
|
||||
def validate(self, clean=True):
|
||||
"""Ensure that all fields' values are valid and that required fields
|
||||
are present.
|
||||
|
||||
Raises :class:`ValidationError` if any of the fields' values are found
|
||||
to be invalid.
|
||||
"""
|
||||
# Ensure that each field is matched to a valid value
|
||||
errors = {}
|
||||
@@ -883,7 +885,8 @@ class BaseDocument(object):
|
||||
index = {'fields': fields, 'unique': True, 'sparse': sparse}
|
||||
unique_indexes.append(index)
|
||||
|
||||
if field.__class__.__name__ == 'ListField':
|
||||
if field.__class__.__name__ in {'EmbeddedDocumentListField',
|
||||
'ListField', 'SortedListField'}:
|
||||
field = field.field
|
||||
|
||||
# Grab any embedded document field unique indexes
|
||||
|
@@ -11,8 +11,7 @@ from mongoengine.base.common import UPDATE_OPERATORS
|
||||
from mongoengine.base.datastructures import (BaseDict, BaseList,
|
||||
EmbeddedDocumentList)
|
||||
from mongoengine.common import _import_class
|
||||
from mongoengine.errors import ValidationError
|
||||
|
||||
from mongoengine.errors import DeprecatedError, ValidationError
|
||||
|
||||
__all__ = ('BaseField', 'ComplexBaseField', 'ObjectIdField',
|
||||
'GeoJsonBaseField')
|
||||
@@ -53,8 +52,8 @@ class BaseField(object):
|
||||
unique with.
|
||||
:param primary_key: Mark this field as the primary key. Defaults to False.
|
||||
:param validation: (optional) A callable to validate the value of the
|
||||
field. Generally this is deprecated in favour of the
|
||||
`FIELD.validate` method
|
||||
field. The callable takes the value as parameter and should raise
|
||||
a ValidationError if validation fails
|
||||
:param choices: (optional) The valid choices
|
||||
:param null: (optional) If the field value can be null. If no and there is a default value
|
||||
then the default value is set
|
||||
@@ -226,10 +225,18 @@ class BaseField(object):
|
||||
# check validation argument
|
||||
if self.validation is not None:
|
||||
if callable(self.validation):
|
||||
if not self.validation(value):
|
||||
self.error('Value does not match custom validation method')
|
||||
try:
|
||||
# breaking change of 0.18
|
||||
# Get rid of True/False-type return for the validation method
|
||||
# in favor of having validation raising a ValidationError
|
||||
ret = self.validation(value)
|
||||
if ret is not None:
|
||||
raise DeprecatedError('validation argument for `%s` must not return anything, '
|
||||
'it should raise a ValidationError if validation fails' % self.name)
|
||||
except ValidationError as ex:
|
||||
self.error(str(ex))
|
||||
else:
|
||||
raise ValueError('validation argument for "%s" must be a '
|
||||
raise ValueError('validation argument for `"%s"` must be a '
|
||||
'callable.' % self.name)
|
||||
|
||||
self.validate(value, **kwargs)
|
||||
@@ -276,11 +283,16 @@ class ComplexBaseField(BaseField):
|
||||
|
||||
_dereference = _import_class('DeReference')()
|
||||
|
||||
if instance._initialised and dereference and instance._data.get(self.name):
|
||||
if (instance._initialised and
|
||||
dereference and
|
||||
instance._data.get(self.name) and
|
||||
not getattr(instance._data[self.name], '_dereferenced', False)):
|
||||
instance._data[self.name] = _dereference(
|
||||
instance._data.get(self.name), max_depth=1, instance=instance,
|
||||
name=self.name
|
||||
)
|
||||
if hasattr(instance._data[self.name], '_dereferenced'):
|
||||
instance._data[self.name]._dereferenced = True
|
||||
|
||||
value = super(ComplexBaseField, self).__get__(instance, owner)
|
||||
|
||||
|
@@ -184,9 +184,6 @@ class DocumentMetaclass(type):
|
||||
if issubclass(new_class, EmbeddedDocument):
|
||||
raise InvalidDocumentError('CachedReferenceFields is not '
|
||||
'allowed in EmbeddedDocuments')
|
||||
if not f.document_type:
|
||||
raise InvalidDocumentError(
|
||||
'Document is not available to sync')
|
||||
|
||||
if f.auto_sync:
|
||||
f.start_listener()
|
||||
|
@@ -31,7 +31,6 @@ def _import_class(cls_name):
|
||||
|
||||
field_classes = _field_list_cache
|
||||
|
||||
queryset_classes = ('OperationError',)
|
||||
deref_classes = ('DeReference',)
|
||||
|
||||
if cls_name == 'BaseDocument':
|
||||
@@ -43,14 +42,11 @@ def _import_class(cls_name):
|
||||
elif cls_name in field_classes:
|
||||
from mongoengine import fields as module
|
||||
import_classes = field_classes
|
||||
elif cls_name in queryset_classes:
|
||||
from mongoengine import queryset as module
|
||||
import_classes = queryset_classes
|
||||
elif cls_name in deref_classes:
|
||||
from mongoengine import dereference as module
|
||||
import_classes = deref_classes
|
||||
else:
|
||||
raise ValueError('No import set for: ' % cls_name)
|
||||
raise ValueError('No import set for: %s' % cls_name)
|
||||
|
||||
for cls in import_classes:
|
||||
_class_registry_cache[cls] = getattr(module, cls)
|
||||
|
@@ -1,19 +1,22 @@
|
||||
from pymongo import MongoClient, ReadPreference, uri_parser
|
||||
from pymongo.database import _check_name
|
||||
import six
|
||||
|
||||
from mongoengine.pymongo_support import IS_PYMONGO_3
|
||||
|
||||
__all__ = ['MongoEngineConnectionError', 'connect', 'register_connection',
|
||||
'DEFAULT_CONNECTION_NAME', 'get_db']
|
||||
__all__ = ['MongoEngineConnectionError', 'connect', 'disconnect', 'disconnect_all',
|
||||
'register_connection', 'DEFAULT_CONNECTION_NAME', 'DEFAULT_DATABASE_NAME',
|
||||
'get_db', 'get_connection']
|
||||
|
||||
|
||||
DEFAULT_CONNECTION_NAME = 'default'
|
||||
DEFAULT_DATABASE_NAME = 'test'
|
||||
DEFAULT_HOST = 'localhost'
|
||||
DEFAULT_PORT = 27017
|
||||
|
||||
if IS_PYMONGO_3:
|
||||
READ_PREFERENCE = ReadPreference.PRIMARY
|
||||
else:
|
||||
from pymongo import MongoReplicaSetClient
|
||||
READ_PREFERENCE = False
|
||||
_connection_settings = {}
|
||||
_connections = {}
|
||||
_dbs = {}
|
||||
|
||||
READ_PREFERENCE = ReadPreference.PRIMARY
|
||||
|
||||
|
||||
class MongoEngineConnectionError(Exception):
|
||||
@@ -23,45 +26,48 @@ class MongoEngineConnectionError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
_connection_settings = {}
|
||||
_connections = {}
|
||||
_dbs = {}
|
||||
def _check_db_name(name):
|
||||
"""Check if a database name is valid.
|
||||
This functionality is copied from pymongo Database class constructor.
|
||||
"""
|
||||
if not isinstance(name, six.string_types):
|
||||
raise TypeError('name must be an instance of %s' % six.string_types)
|
||||
elif name != '$external':
|
||||
_check_name(name)
|
||||
|
||||
|
||||
def register_connection(alias, db=None, name=None, host=None, port=None,
|
||||
read_preference=READ_PREFERENCE,
|
||||
username=None, password=None,
|
||||
authentication_source=None,
|
||||
authentication_mechanism=None,
|
||||
**kwargs):
|
||||
"""Add a connection.
|
||||
def _get_connection_settings(
|
||||
db=None, name=None, host=None, port=None,
|
||||
read_preference=READ_PREFERENCE,
|
||||
username=None, password=None,
|
||||
authentication_source=None,
|
||||
authentication_mechanism=None,
|
||||
**kwargs):
|
||||
"""Get the connection settings as a dict
|
||||
|
||||
:param alias: the name that will be used to refer to this connection
|
||||
throughout MongoEngine
|
||||
:param name: the name of the specific database to use
|
||||
:param db: the name of the database to use, for compatibility with connect
|
||||
:param host: the host name of the :program:`mongod` instance to connect to
|
||||
:param port: the port that the :program:`mongod` instance is running on
|
||||
:param read_preference: The read preference for the collection
|
||||
** Added pymongo 2.1
|
||||
:param username: username to authenticate with
|
||||
:param password: password to authenticate with
|
||||
:param authentication_source: database to authenticate against
|
||||
:param authentication_mechanism: database authentication mechanisms.
|
||||
: param db: the name of the database to use, for compatibility with connect
|
||||
: param name: the name of the specific database to use
|
||||
: param host: the host name of the: program: `mongod` instance to connect to
|
||||
: param port: the port that the: program: `mongod` instance is running on
|
||||
: param read_preference: The read preference for the collection
|
||||
: param username: username to authenticate with
|
||||
: param password: password to authenticate with
|
||||
: param authentication_source: database to authenticate against
|
||||
: param authentication_mechanism: database authentication mechanisms.
|
||||
By default, use SCRAM-SHA-1 with MongoDB 3.0 and later,
|
||||
MONGODB-CR (MongoDB Challenge Response protocol) for older servers.
|
||||
:param is_mock: explicitly use mongomock for this connection
|
||||
(can also be done by using `mongomock://` as db host prefix)
|
||||
:param kwargs: ad-hoc parameters to be passed into the pymongo driver,
|
||||
: param is_mock: explicitly use mongomock for this connection
|
||||
(can also be done by using `mongomock: // ` as db host prefix)
|
||||
: param kwargs: ad-hoc parameters to be passed into the pymongo driver,
|
||||
for example maxpoolsize, tz_aware, etc. See the documentation
|
||||
for pymongo's `MongoClient` for a full list.
|
||||
|
||||
.. versionchanged:: 0.10.6 - added mongomock support
|
||||
"""
|
||||
conn_settings = {
|
||||
'name': name or db or 'test',
|
||||
'host': host or 'localhost',
|
||||
'port': port or 27017,
|
||||
'name': name or db or DEFAULT_DATABASE_NAME,
|
||||
'host': host or DEFAULT_HOST,
|
||||
'port': port or DEFAULT_PORT,
|
||||
'read_preference': read_preference,
|
||||
'username': username,
|
||||
'password': password,
|
||||
@@ -69,6 +75,7 @@ def register_connection(alias, db=None, name=None, host=None, port=None,
|
||||
'authentication_mechanism': authentication_mechanism
|
||||
}
|
||||
|
||||
_check_db_name(conn_settings['name'])
|
||||
conn_host = conn_settings['host']
|
||||
|
||||
# Host can be a list or a string, so if string, force to a list.
|
||||
@@ -104,16 +111,28 @@ def register_connection(alias, db=None, name=None, host=None, port=None,
|
||||
conn_settings['authentication_source'] = uri_options['authsource']
|
||||
if 'authmechanism' in uri_options:
|
||||
conn_settings['authentication_mechanism'] = uri_options['authmechanism']
|
||||
if IS_PYMONGO_3 and 'readpreference' in uri_options:
|
||||
if 'readpreference' in uri_options:
|
||||
read_preferences = (
|
||||
ReadPreference.NEAREST,
|
||||
ReadPreference.PRIMARY,
|
||||
ReadPreference.PRIMARY_PREFERRED,
|
||||
ReadPreference.SECONDARY,
|
||||
ReadPreference.SECONDARY_PREFERRED)
|
||||
read_pf_mode = uri_options['readpreference'].lower()
|
||||
ReadPreference.SECONDARY_PREFERRED,
|
||||
)
|
||||
|
||||
# Starting with PyMongo v3.5, the "readpreference" option is
|
||||
# returned as a string (e.g. "secondaryPreferred") and not an
|
||||
# int (e.g. 3).
|
||||
# TODO simplify the code below once we drop support for
|
||||
# PyMongo v3.4.
|
||||
read_pf_mode = uri_options['readpreference']
|
||||
if isinstance(read_pf_mode, six.string_types):
|
||||
read_pf_mode = read_pf_mode.lower()
|
||||
for preference in read_preferences:
|
||||
if preference.name.lower() == read_pf_mode:
|
||||
if (
|
||||
preference.name.lower() == read_pf_mode or
|
||||
preference.mode == read_pf_mode
|
||||
):
|
||||
conn_settings['read_preference'] = preference
|
||||
break
|
||||
else:
|
||||
@@ -125,17 +144,74 @@ def register_connection(alias, db=None, name=None, host=None, port=None,
|
||||
kwargs.pop('is_slave', None)
|
||||
|
||||
conn_settings.update(kwargs)
|
||||
return conn_settings
|
||||
|
||||
|
||||
def register_connection(alias, db=None, name=None, host=None, port=None,
|
||||
read_preference=READ_PREFERENCE,
|
||||
username=None, password=None,
|
||||
authentication_source=None,
|
||||
authentication_mechanism=None,
|
||||
**kwargs):
|
||||
"""Register the connection settings.
|
||||
|
||||
: param alias: the name that will be used to refer to this connection
|
||||
throughout MongoEngine
|
||||
: param name: the name of the specific database to use
|
||||
: param db: the name of the database to use, for compatibility with connect
|
||||
: param host: the host name of the: program: `mongod` instance to connect to
|
||||
: param port: the port that the: program: `mongod` instance is running on
|
||||
: param read_preference: The read preference for the collection
|
||||
: param username: username to authenticate with
|
||||
: param password: password to authenticate with
|
||||
: param authentication_source: database to authenticate against
|
||||
: param authentication_mechanism: database authentication mechanisms.
|
||||
By default, use SCRAM-SHA-1 with MongoDB 3.0 and later,
|
||||
MONGODB-CR (MongoDB Challenge Response protocol) for older servers.
|
||||
: param is_mock: explicitly use mongomock for this connection
|
||||
(can also be done by using `mongomock: // ` as db host prefix)
|
||||
: param kwargs: ad-hoc parameters to be passed into the pymongo driver,
|
||||
for example maxpoolsize, tz_aware, etc. See the documentation
|
||||
for pymongo's `MongoClient` for a full list.
|
||||
|
||||
.. versionchanged:: 0.10.6 - added mongomock support
|
||||
"""
|
||||
conn_settings = _get_connection_settings(
|
||||
db=db, name=name, host=host, port=port,
|
||||
read_preference=read_preference,
|
||||
username=username, password=password,
|
||||
authentication_source=authentication_source,
|
||||
authentication_mechanism=authentication_mechanism,
|
||||
**kwargs)
|
||||
_connection_settings[alias] = conn_settings
|
||||
|
||||
|
||||
def disconnect(alias=DEFAULT_CONNECTION_NAME):
|
||||
"""Close the connection with a given alias."""
|
||||
from mongoengine.base.common import _get_documents_by_db
|
||||
from mongoengine import Document
|
||||
|
||||
if alias in _connections:
|
||||
get_connection(alias=alias).close()
|
||||
del _connections[alias]
|
||||
|
||||
if alias in _dbs:
|
||||
# Detach all cached collections in Documents
|
||||
for doc_cls in _get_documents_by_db(alias, DEFAULT_CONNECTION_NAME):
|
||||
if issubclass(doc_cls, Document): # Skip EmbeddedDocument
|
||||
doc_cls._disconnect()
|
||||
|
||||
del _dbs[alias]
|
||||
|
||||
if alias in _connection_settings:
|
||||
del _connection_settings[alias]
|
||||
|
||||
|
||||
def disconnect_all():
|
||||
"""Close all registered database."""
|
||||
for alias in list(_connections.keys()):
|
||||
disconnect(alias)
|
||||
|
||||
|
||||
def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
|
||||
"""Return a connection with a given alias."""
|
||||
@@ -159,7 +235,6 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
|
||||
raise MongoEngineConnectionError(msg)
|
||||
|
||||
def _clean_settings(settings_dict):
|
||||
# set literal more efficient than calling set function
|
||||
irrelevant_fields_set = {
|
||||
'name', 'username', 'password',
|
||||
'authentication_source', 'authentication_mechanism'
|
||||
@@ -169,10 +244,12 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
|
||||
if k not in irrelevant_fields_set
|
||||
}
|
||||
|
||||
raw_conn_settings = _connection_settings[alias].copy()
|
||||
|
||||
# Retrieve a copy of the connection settings associated with the requested
|
||||
# alias and remove the database name and authentication info (we don't
|
||||
# care about them at this point).
|
||||
conn_settings = _clean_settings(_connection_settings[alias].copy())
|
||||
conn_settings = _clean_settings(raw_conn_settings)
|
||||
|
||||
# Determine if we should use PyMongo's or mongomock's MongoClient.
|
||||
is_mock = conn_settings.pop('is_mock', False)
|
||||
@@ -186,51 +263,60 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
|
||||
else:
|
||||
connection_class = MongoClient
|
||||
|
||||
# For replica set connections with PyMongo 2.x, use
|
||||
# MongoReplicaSetClient.
|
||||
# TODO remove this once we stop supporting PyMongo 2.x.
|
||||
if 'replicaSet' in conn_settings and not IS_PYMONGO_3:
|
||||
connection_class = MongoReplicaSetClient
|
||||
conn_settings['hosts_or_uri'] = conn_settings.pop('host', None)
|
||||
|
||||
# hosts_or_uri has to be a string, so if 'host' was provided
|
||||
# as a list, join its parts and separate them by ','
|
||||
if isinstance(conn_settings['hosts_or_uri'], list):
|
||||
conn_settings['hosts_or_uri'] = ','.join(
|
||||
conn_settings['hosts_or_uri'])
|
||||
|
||||
# Discard port since it can't be used on MongoReplicaSetClient
|
||||
conn_settings.pop('port', None)
|
||||
|
||||
# Iterate over all of the connection settings and if a connection with
|
||||
# the same parameters is already established, use it instead of creating
|
||||
# a new one.
|
||||
existing_connection = None
|
||||
connection_settings_iterator = (
|
||||
(db_alias, settings.copy())
|
||||
for db_alias, settings in _connection_settings.items()
|
||||
)
|
||||
for db_alias, connection_settings in connection_settings_iterator:
|
||||
connection_settings = _clean_settings(connection_settings)
|
||||
if conn_settings == connection_settings and _connections.get(db_alias):
|
||||
existing_connection = _connections[db_alias]
|
||||
break
|
||||
# Re-use existing connection if one is suitable
|
||||
existing_connection = _find_existing_connection(raw_conn_settings)
|
||||
|
||||
# If an existing connection was found, assign it to the new alias
|
||||
if existing_connection:
|
||||
_connections[alias] = existing_connection
|
||||
else:
|
||||
# Otherwise, create the new connection for this alias. Raise
|
||||
# MongoEngineConnectionError if it can't be established.
|
||||
try:
|
||||
_connections[alias] = connection_class(**conn_settings)
|
||||
except Exception as e:
|
||||
raise MongoEngineConnectionError(
|
||||
'Cannot connect to database %s :\n%s' % (alias, e))
|
||||
_connections[alias] = _create_connection(alias=alias,
|
||||
connection_class=connection_class,
|
||||
**conn_settings)
|
||||
|
||||
return _connections[alias]
|
||||
|
||||
|
||||
def _create_connection(alias, connection_class, **connection_settings):
|
||||
"""
|
||||
Create the new connection for this alias. Raise
|
||||
MongoEngineConnectionError if it can't be established.
|
||||
"""
|
||||
try:
|
||||
return connection_class(**connection_settings)
|
||||
except Exception as e:
|
||||
raise MongoEngineConnectionError(
|
||||
'Cannot connect to database %s :\n%s' % (alias, e))
|
||||
|
||||
|
||||
def _find_existing_connection(connection_settings):
|
||||
"""
|
||||
Check if an existing connection could be reused
|
||||
|
||||
Iterate over all of the connection settings and if an existing connection
|
||||
with the same parameters is suitable, return it
|
||||
|
||||
:param connection_settings: the settings of the new connection
|
||||
:return: An existing connection or None
|
||||
"""
|
||||
connection_settings_bis = (
|
||||
(db_alias, settings.copy())
|
||||
for db_alias, settings in _connection_settings.items()
|
||||
)
|
||||
|
||||
def _clean_settings(settings_dict):
|
||||
# Only remove the name but it's important to
|
||||
# keep the username/password/authentication_source/authentication_mechanism
|
||||
# to identify if the connection could be shared (cfr https://github.com/MongoEngine/mongoengine/issues/2047)
|
||||
return {k: v for k, v in settings_dict.items() if k != 'name'}
|
||||
|
||||
cleaned_conn_settings = _clean_settings(connection_settings)
|
||||
for db_alias, connection_settings in connection_settings_bis:
|
||||
db_conn_settings = _clean_settings(connection_settings)
|
||||
if cleaned_conn_settings == db_conn_settings and _connections.get(db_alias):
|
||||
return _connections[db_alias]
|
||||
|
||||
|
||||
def get_db(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
|
||||
if reconnect:
|
||||
disconnect(alias)
|
||||
@@ -258,14 +344,24 @@ def connect(db=None, alias=DEFAULT_CONNECTION_NAME, **kwargs):
|
||||
provide username and password arguments as well.
|
||||
|
||||
Multiple databases are supported by using aliases. Provide a separate
|
||||
`alias` to connect to a different instance of :program:`mongod`.
|
||||
`alias` to connect to a different instance of: program: `mongod`.
|
||||
|
||||
In order to replace a connection identified by a given alias, you'll
|
||||
need to call ``disconnect`` first
|
||||
|
||||
See the docstring for `register_connection` for more details about all
|
||||
supported kwargs.
|
||||
|
||||
.. versionchanged:: 0.6 - added multiple database support.
|
||||
"""
|
||||
if alias not in _connections:
|
||||
if alias in _connections:
|
||||
prev_conn_setting = _connection_settings[alias]
|
||||
new_conn_settings = _get_connection_settings(db, **kwargs)
|
||||
|
||||
if new_conn_settings != prev_conn_setting:
|
||||
raise MongoEngineConnectionError(
|
||||
'A different connection with alias `%s` was already registered. Use disconnect() first' % alias)
|
||||
else:
|
||||
register_connection(alias, db, **kwargs)
|
||||
|
||||
return get_connection(alias)
|
||||
|
@@ -18,7 +18,7 @@ from mongoengine.context_managers import (set_write_concern,
|
||||
switch_db)
|
||||
from mongoengine.errors import (InvalidDocumentError, InvalidQueryError,
|
||||
SaveConditionError)
|
||||
from mongoengine.pymongo_support import IS_PYMONGO_3, list_collection_names
|
||||
from mongoengine.pymongo_support import list_collection_names
|
||||
from mongoengine.queryset import (NotUniqueError, OperationError,
|
||||
QuerySet, transform)
|
||||
|
||||
@@ -90,18 +90,6 @@ class EmbeddedDocument(six.with_metaclass(DocumentMetaclass, BaseDocument)):
|
||||
|
||||
return data
|
||||
|
||||
def save(self, *args, **kwargs):
|
||||
warnings.warn("EmbeddedDocument.save is deprecated and will be removed in a next version of mongoengine."
|
||||
"Use the parent document's .save() or ._instance.save()",
|
||||
DeprecationWarning, stacklevel=2)
|
||||
self._instance.save(*args, **kwargs)
|
||||
|
||||
def reload(self, *args, **kwargs):
|
||||
warnings.warn("EmbeddedDocument.reload is deprecated and will be removed in a next version of mongoengine."
|
||||
"Use the parent document's .reload() or ._instance.reload()",
|
||||
DeprecationWarning, stacklevel=2)
|
||||
self._instance.reload(*args, **kwargs)
|
||||
|
||||
|
||||
class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
|
||||
"""The base class used for defining the structure and properties of
|
||||
@@ -188,10 +176,16 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
|
||||
return get_db(cls._meta.get('db_alias', DEFAULT_CONNECTION_NAME))
|
||||
|
||||
@classmethod
|
||||
def _get_collection(cls):
|
||||
"""Return a PyMongo collection for the document."""
|
||||
if not hasattr(cls, '_collection') or cls._collection is None:
|
||||
def _disconnect(cls):
|
||||
"""Detach the Document class from the (cached) database collection"""
|
||||
cls._collection = None
|
||||
|
||||
@classmethod
|
||||
def _get_collection(cls):
|
||||
"""Return the corresponding PyMongo collection of this document.
|
||||
Upon the first call, it will ensure that indexes gets created. The returned collection then gets cached
|
||||
"""
|
||||
if not hasattr(cls, '_collection') or cls._collection is None:
|
||||
# Get the collection, either capped or regular.
|
||||
if cls._meta.get('max_size') or cls._meta.get('max_documents'):
|
||||
cls._collection = cls._get_capped_collection()
|
||||
@@ -253,7 +247,7 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
|
||||
data = super(Document, self).to_mongo(*args, **kwargs)
|
||||
|
||||
# If '_id' is None, try and set it from self._data. If that
|
||||
# doesn't exist either, remote '_id' from the SON completely.
|
||||
# doesn't exist either, remove '_id' from the SON completely.
|
||||
if data['_id'] is None:
|
||||
if self._data.get('id') is None:
|
||||
del data['_id']
|
||||
@@ -359,21 +353,21 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
|
||||
.. versionchanged:: 0.10.7
|
||||
Add signal_kwargs argument
|
||||
"""
|
||||
signal_kwargs = signal_kwargs or {}
|
||||
|
||||
if self._meta.get('abstract'):
|
||||
raise InvalidDocumentError('Cannot save an abstract document.')
|
||||
|
||||
signal_kwargs = signal_kwargs or {}
|
||||
signals.pre_save.send(self.__class__, document=self, **signal_kwargs)
|
||||
|
||||
if validate:
|
||||
self.validate(clean=clean)
|
||||
|
||||
if write_concern is None:
|
||||
write_concern = {'w': 1}
|
||||
write_concern = {}
|
||||
|
||||
doc = self.to_mongo()
|
||||
|
||||
created = ('_id' not in doc or self._created or force_insert)
|
||||
doc_id = self.to_mongo(fields=['id'])
|
||||
created = ('_id' not in doc_id or self._created or force_insert)
|
||||
|
||||
signals.pre_save_post_validation.send(self.__class__, document=self,
|
||||
created=created, **signal_kwargs)
|
||||
@@ -451,16 +445,6 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
|
||||
|
||||
object_id = wc_collection.insert_one(doc).inserted_id
|
||||
|
||||
# In PyMongo 3.0, the save() call calls internally the _update() call
|
||||
# but they forget to return the _id value passed back, therefore getting it back here
|
||||
# Correct behaviour in 2.X and in 3.0.1+ versions
|
||||
if not object_id and pymongo.version_tuple == (3, 0):
|
||||
pk_as_mongo_obj = self._fields.get(self._meta['id_field']).to_mongo(self.pk)
|
||||
object_id = (
|
||||
self._qs.filter(pk=pk_as_mongo_obj).first() and
|
||||
self._qs.filter(pk=pk_as_mongo_obj).first().pk
|
||||
) # TODO doesn't this make 2 queries?
|
||||
|
||||
return object_id
|
||||
|
||||
def _get_update_doc(self):
|
||||
@@ -506,8 +490,12 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
|
||||
update_doc = self._get_update_doc()
|
||||
if update_doc:
|
||||
upsert = save_condition is None
|
||||
last_error = collection.update(select_dict, update_doc,
|
||||
upsert=upsert, **write_concern)
|
||||
with set_write_concern(collection, write_concern) as wc_collection:
|
||||
last_error = wc_collection.update_one(
|
||||
select_dict,
|
||||
update_doc,
|
||||
upsert=upsert
|
||||
).raw_result
|
||||
if not upsert and last_error['n'] == 0:
|
||||
raise SaveConditionError('Race condition preventing'
|
||||
' document update detected')
|
||||
@@ -799,13 +787,13 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
|
||||
.. versionchanged:: 0.10.7
|
||||
:class:`OperationError` exception raised if no collection available
|
||||
"""
|
||||
col_name = cls._get_collection_name()
|
||||
if not col_name:
|
||||
coll_name = cls._get_collection_name()
|
||||
if not coll_name:
|
||||
raise OperationError('Document %s has no collection defined '
|
||||
'(is it abstract ?)' % cls)
|
||||
cls._collection = None
|
||||
db = cls._get_db()
|
||||
db.drop_collection(col_name)
|
||||
db.drop_collection(coll_name)
|
||||
|
||||
@classmethod
|
||||
def create_index(cls, keys, background=False, **kwargs):
|
||||
@@ -820,18 +808,13 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
|
||||
index_spec = index_spec.copy()
|
||||
fields = index_spec.pop('fields')
|
||||
drop_dups = kwargs.get('drop_dups', False)
|
||||
if IS_PYMONGO_3 and drop_dups:
|
||||
if drop_dups:
|
||||
msg = 'drop_dups is deprecated and is removed when using PyMongo 3+.'
|
||||
warnings.warn(msg, DeprecationWarning)
|
||||
elif not IS_PYMONGO_3:
|
||||
index_spec['drop_dups'] = drop_dups
|
||||
index_spec['background'] = background
|
||||
index_spec.update(kwargs)
|
||||
|
||||
if IS_PYMONGO_3:
|
||||
return cls._get_collection().create_index(fields, **index_spec)
|
||||
else:
|
||||
return cls._get_collection().ensure_index(fields, **index_spec)
|
||||
return cls._get_collection().create_index(fields, **index_spec)
|
||||
|
||||
@classmethod
|
||||
def ensure_index(cls, key_or_list, drop_dups=False, background=False,
|
||||
@@ -846,11 +829,9 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
|
||||
:param drop_dups: Was removed/ignored with MongoDB >2.7.5. The value
|
||||
will be removed if PyMongo3+ is used
|
||||
"""
|
||||
if IS_PYMONGO_3 and drop_dups:
|
||||
if drop_dups:
|
||||
msg = 'drop_dups is deprecated and is removed when using PyMongo 3+.'
|
||||
warnings.warn(msg, DeprecationWarning)
|
||||
elif not IS_PYMONGO_3:
|
||||
kwargs.update({'drop_dups': drop_dups})
|
||||
return cls.create_index(key_or_list, background=background, **kwargs)
|
||||
|
||||
@classmethod
|
||||
@@ -866,7 +847,7 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
|
||||
drop_dups = cls._meta.get('index_drop_dups', False)
|
||||
index_opts = cls._meta.get('index_opts') or {}
|
||||
index_cls = cls._meta.get('index_cls', True)
|
||||
if IS_PYMONGO_3 and drop_dups:
|
||||
if drop_dups:
|
||||
msg = 'drop_dups is deprecated and is removed when using PyMongo 3+.'
|
||||
warnings.warn(msg, DeprecationWarning)
|
||||
|
||||
@@ -897,11 +878,7 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
|
||||
if 'cls' in opts:
|
||||
del opts['cls']
|
||||
|
||||
if IS_PYMONGO_3:
|
||||
collection.create_index(fields, background=background, **opts)
|
||||
else:
|
||||
collection.ensure_index(fields, background=background,
|
||||
drop_dups=drop_dups, **opts)
|
||||
collection.create_index(fields, background=background, **opts)
|
||||
|
||||
# If _cls is being used (for polymorphism), it needs an index,
|
||||
# only if another index doesn't begin with _cls
|
||||
@@ -912,12 +889,8 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
|
||||
if 'cls' in index_opts:
|
||||
del index_opts['cls']
|
||||
|
||||
if IS_PYMONGO_3:
|
||||
collection.create_index('_cls', background=background,
|
||||
**index_opts)
|
||||
else:
|
||||
collection.ensure_index('_cls', background=background,
|
||||
**index_opts)
|
||||
collection.create_index('_cls', background=background,
|
||||
**index_opts)
|
||||
|
||||
@classmethod
|
||||
def list_indexes(cls):
|
||||
|
@@ -6,7 +6,7 @@ from six import iteritems
|
||||
__all__ = ('NotRegistered', 'InvalidDocumentError', 'LookUpError',
|
||||
'DoesNotExist', 'MultipleObjectsReturned', 'InvalidQueryError',
|
||||
'OperationError', 'NotUniqueError', 'FieldDoesNotExist',
|
||||
'ValidationError', 'SaveConditionError')
|
||||
'ValidationError', 'SaveConditionError', 'DeprecatedError')
|
||||
|
||||
|
||||
class NotRegistered(Exception):
|
||||
@@ -110,9 +110,6 @@ class ValidationError(AssertionError):
|
||||
|
||||
def build_dict(source):
|
||||
errors_dict = {}
|
||||
if not source:
|
||||
return errors_dict
|
||||
|
||||
if isinstance(source, dict):
|
||||
for field_name, error in iteritems(source):
|
||||
errors_dict[field_name] = build_dict(error)
|
||||
@@ -145,3 +142,8 @@ class ValidationError(AssertionError):
|
||||
for k, v in iteritems(self.to_dict()):
|
||||
error_dict[generate_key(v)].append(k)
|
||||
return ' '.join(['%s: %s' % (k, v) for k, v in iteritems(error_dict)])
|
||||
|
||||
|
||||
class DeprecatedError(Exception):
|
||||
"""Raise when a user uses a feature that has been Deprecated"""
|
||||
pass
|
||||
|
@@ -37,6 +37,7 @@ from mongoengine.errors import DoesNotExist, InvalidQueryError, ValidationError
|
||||
from mongoengine.python_support import StringIO
|
||||
from mongoengine.queryset import DO_NOTHING
|
||||
from mongoengine.queryset.base import BaseQuerySet
|
||||
from mongoengine.queryset.transform import STRING_OPERATORS
|
||||
|
||||
try:
|
||||
from PIL import Image, ImageOps
|
||||
@@ -106,11 +107,11 @@ class StringField(BaseField):
|
||||
if not isinstance(op, six.string_types):
|
||||
return value
|
||||
|
||||
if op.lstrip('i') in ('startswith', 'endswith', 'contains', 'exact'):
|
||||
flags = 0
|
||||
if op.startswith('i'):
|
||||
flags = re.IGNORECASE
|
||||
op = op.lstrip('i')
|
||||
if op in STRING_OPERATORS:
|
||||
case_insensitive = op.startswith('i')
|
||||
op = op.lstrip('i')
|
||||
|
||||
flags = re.IGNORECASE if case_insensitive else 0
|
||||
|
||||
regex = r'%s'
|
||||
if op == 'startswith':
|
||||
@@ -152,12 +153,10 @@ class URLField(StringField):
|
||||
scheme = value.split('://')[0].lower()
|
||||
if scheme not in self.schemes:
|
||||
self.error(u'Invalid scheme {} in URL: {}'.format(scheme, value))
|
||||
return
|
||||
|
||||
# Then check full URL
|
||||
if not self.url_regex.match(value):
|
||||
self.error(u'Invalid URL: {}'.format(value))
|
||||
return
|
||||
|
||||
|
||||
class EmailField(StringField):
|
||||
@@ -259,10 +258,10 @@ class EmailField(StringField):
|
||||
try:
|
||||
domain_part = domain_part.encode('idna').decode('ascii')
|
||||
except UnicodeError:
|
||||
self.error(self.error_msg % value)
|
||||
self.error("%s %s" % (self.error_msg % value, "(domain failed IDN encoding)"))
|
||||
else:
|
||||
if not self.validate_domain_part(domain_part):
|
||||
self.error(self.error_msg % value)
|
||||
self.error("%s %s" % (self.error_msg % value, "(domain validation failed)"))
|
||||
|
||||
|
||||
class IntField(BaseField):
|
||||
@@ -499,15 +498,18 @@ class DateTimeField(BaseField):
|
||||
if not isinstance(value, six.string_types):
|
||||
return None
|
||||
|
||||
return self._parse_datetime(value)
|
||||
|
||||
def _parse_datetime(self, value):
|
||||
# Attempt to parse a datetime from a string
|
||||
value = value.strip()
|
||||
if not value:
|
||||
return None
|
||||
|
||||
# Attempt to parse a datetime:
|
||||
if dateutil:
|
||||
try:
|
||||
return dateutil.parser.parse(value)
|
||||
except (TypeError, ValueError):
|
||||
except (TypeError, ValueError, OverflowError):
|
||||
return None
|
||||
|
||||
# split usecs, because they are not recognized by strptime.
|
||||
@@ -700,7 +702,11 @@ class EmbeddedDocumentField(BaseField):
|
||||
self.document_type.validate(value, clean)
|
||||
|
||||
def lookup_member(self, member_name):
|
||||
return self.document_type._fields.get(member_name)
|
||||
doc_and_subclasses = [self.document_type] + self.document_type.__subclasses__()
|
||||
for doc_type in doc_and_subclasses:
|
||||
field = doc_type._fields.get(member_name)
|
||||
if field:
|
||||
return field
|
||||
|
||||
def prepare_query_value(self, op, value):
|
||||
if value is not None and not isinstance(value, self.document_type):
|
||||
@@ -747,12 +753,13 @@ class GenericEmbeddedDocumentField(BaseField):
|
||||
value.validate(clean=clean)
|
||||
|
||||
def lookup_member(self, member_name):
|
||||
if self.choices:
|
||||
for choice in self.choices:
|
||||
field = choice._fields.get(member_name)
|
||||
document_choices = self.choices or []
|
||||
for document_choice in document_choices:
|
||||
doc_and_subclasses = [document_choice] + document_choice.__subclasses__()
|
||||
for doc_type in doc_and_subclasses:
|
||||
field = doc_type._fields.get(member_name)
|
||||
if field:
|
||||
return field
|
||||
return None
|
||||
|
||||
def to_mongo(self, document, use_db_field=True, fields=None):
|
||||
if document is None:
|
||||
|
@@ -7,9 +7,7 @@ from mongoengine.connection import get_connection
|
||||
# Constant that can be used to compare the version retrieved with
|
||||
# get_mongodb_version()
|
||||
MONGODB_34 = (3, 4)
|
||||
MONGODB_32 = (3, 2)
|
||||
MONGODB_3 = (3, 0)
|
||||
MONGODB_26 = (2, 6)
|
||||
MONGODB_36 = (3, 6)
|
||||
|
||||
|
||||
def get_mongodb_version():
|
||||
|
@@ -7,7 +7,6 @@ _PYMONGO_37 = (3, 7)
|
||||
|
||||
PYMONGO_VERSION = tuple(pymongo.version_tuple[:2])
|
||||
|
||||
IS_PYMONGO_3 = PYMONGO_VERSION[0] >= 3
|
||||
IS_PYMONGO_GTE_37 = PYMONGO_VERSION >= _PYMONGO_37
|
||||
|
||||
|
||||
|
@@ -10,6 +10,7 @@ from bson import SON, json_util
|
||||
from bson.code import Code
|
||||
import pymongo
|
||||
import pymongo.errors
|
||||
from pymongo.collection import ReturnDocument
|
||||
from pymongo.common import validate_read_preference
|
||||
import six
|
||||
from six import iteritems
|
||||
@@ -21,14 +22,10 @@ from mongoengine.connection import get_db
|
||||
from mongoengine.context_managers import set_write_concern, switch_db
|
||||
from mongoengine.errors import (InvalidQueryError, LookUpError,
|
||||
NotUniqueError, OperationError)
|
||||
from mongoengine.pymongo_support import IS_PYMONGO_3
|
||||
from mongoengine.queryset import transform
|
||||
from mongoengine.queryset.field_list import QueryFieldList
|
||||
from mongoengine.queryset.visitor import Q, QNode
|
||||
|
||||
if IS_PYMONGO_3:
|
||||
from pymongo.collection import ReturnDocument
|
||||
|
||||
|
||||
__all__ = ('BaseQuerySet', 'DO_NOTHING', 'NULLIFY', 'CASCADE', 'DENY', 'PULL')
|
||||
|
||||
@@ -197,7 +194,7 @@ class BaseQuerySet(object):
|
||||
only_fields=self.only_fields
|
||||
)
|
||||
|
||||
raise AttributeError('Provide a slice or an integer index')
|
||||
raise TypeError('Provide a slice or an integer index')
|
||||
|
||||
def __iter__(self):
|
||||
raise NotImplementedError
|
||||
@@ -338,7 +335,7 @@ class BaseQuerySet(object):
|
||||
% str(self._document))
|
||||
raise OperationError(msg)
|
||||
if doc.pk and not doc._created:
|
||||
msg = 'Some documents have ObjectIds use doc.update() instead'
|
||||
msg = 'Some documents have ObjectIds, use doc.update() instead'
|
||||
raise OperationError(msg)
|
||||
|
||||
signal_kwargs = signal_kwargs or {}
|
||||
@@ -626,31 +623,25 @@ class BaseQuerySet(object):
|
||||
|
||||
queryset = self.clone()
|
||||
query = queryset._query
|
||||
if not IS_PYMONGO_3 or not remove:
|
||||
if not remove:
|
||||
update = transform.update(queryset._document, **update)
|
||||
sort = queryset._ordering
|
||||
|
||||
try:
|
||||
if IS_PYMONGO_3:
|
||||
if full_response:
|
||||
msg = 'With PyMongo 3+, it is not possible anymore to get the full response.'
|
||||
warnings.warn(msg, DeprecationWarning)
|
||||
if remove:
|
||||
result = queryset._collection.find_one_and_delete(
|
||||
query, sort=sort, **self._cursor_args)
|
||||
else:
|
||||
if new:
|
||||
return_doc = ReturnDocument.AFTER
|
||||
else:
|
||||
return_doc = ReturnDocument.BEFORE
|
||||
result = queryset._collection.find_one_and_update(
|
||||
query, update, upsert=upsert, sort=sort, return_document=return_doc,
|
||||
**self._cursor_args)
|
||||
|
||||
if full_response:
|
||||
msg = 'With PyMongo 3+, it is not possible anymore to get the full response.'
|
||||
warnings.warn(msg, DeprecationWarning)
|
||||
if remove:
|
||||
result = queryset._collection.find_one_and_delete(
|
||||
query, sort=sort, **self._cursor_args)
|
||||
else:
|
||||
result = queryset._collection.find_and_modify(
|
||||
query, update, upsert=upsert, sort=sort, remove=remove, new=new,
|
||||
full_response=full_response, **self._cursor_args)
|
||||
if new:
|
||||
return_doc = ReturnDocument.AFTER
|
||||
else:
|
||||
return_doc = ReturnDocument.BEFORE
|
||||
result = queryset._collection.find_one_and_update(
|
||||
query, update, upsert=upsert, sort=sort, return_document=return_doc,
|
||||
**self._cursor_args)
|
||||
except pymongo.errors.DuplicateKeyError as err:
|
||||
raise NotUniqueError(u'Update failed (%s)' % err)
|
||||
except pymongo.errors.OperationFailure as err:
|
||||
@@ -1082,15 +1073,14 @@ class BaseQuerySet(object):
|
||||
..versionchanged:: 0.5 - made chainable
|
||||
.. deprecated:: Ignored with PyMongo 3+
|
||||
"""
|
||||
if IS_PYMONGO_3:
|
||||
msg = 'snapshot is deprecated as it has no impact when using PyMongo 3+.'
|
||||
warnings.warn(msg, DeprecationWarning)
|
||||
msg = 'snapshot is deprecated as it has no impact when using PyMongo 3+.'
|
||||
warnings.warn(msg, DeprecationWarning)
|
||||
queryset = self.clone()
|
||||
queryset._snapshot = enabled
|
||||
return queryset
|
||||
|
||||
def timeout(self, enabled):
|
||||
"""Enable or disable the default mongod timeout when querying.
|
||||
"""Enable or disable the default mongod timeout when querying. (no_cursor_timeout option)
|
||||
|
||||
:param enabled: whether or not the timeout is used
|
||||
|
||||
@@ -1108,9 +1098,8 @@ class BaseQuerySet(object):
|
||||
|
||||
.. deprecated:: Ignored with PyMongo 3+
|
||||
"""
|
||||
if IS_PYMONGO_3:
|
||||
msg = 'slave_okay is deprecated as it has no impact when using PyMongo 3+.'
|
||||
warnings.warn(msg, DeprecationWarning)
|
||||
msg = 'slave_okay is deprecated as it has no impact when using PyMongo 3+.'
|
||||
warnings.warn(msg, DeprecationWarning)
|
||||
queryset = self.clone()
|
||||
queryset._slave_okay = enabled
|
||||
return queryset
|
||||
@@ -1200,14 +1189,18 @@ class BaseQuerySet(object):
|
||||
initial_pipeline.append({'$sort': dict(self._ordering)})
|
||||
|
||||
if self._limit is not None:
|
||||
initial_pipeline.append({'$limit': self._limit})
|
||||
# As per MongoDB Documentation (https://docs.mongodb.com/manual/reference/operator/aggregation/limit/),
|
||||
# keeping limit stage right after sort stage is more efficient. But this leads to wrong set of documents
|
||||
# for a skip stage that might succeed these. So we need to maintain more documents in memory in such a
|
||||
# case (https://stackoverflow.com/a/24161461).
|
||||
initial_pipeline.append({'$limit': self._limit + (self._skip or 0)})
|
||||
|
||||
if self._skip is not None:
|
||||
initial_pipeline.append({'$skip': self._skip})
|
||||
|
||||
pipeline = initial_pipeline + list(pipeline)
|
||||
|
||||
if IS_PYMONGO_3 and self._read_preference is not None:
|
||||
if self._read_preference is not None:
|
||||
return self._collection.with_options(read_preference=self._read_preference) \
|
||||
.aggregate(pipeline, cursor={}, **kwargs)
|
||||
|
||||
@@ -1417,11 +1410,7 @@ class BaseQuerySet(object):
|
||||
if isinstance(field_instances[-1], ListField):
|
||||
pipeline.insert(1, {'$unwind': '$' + field})
|
||||
|
||||
result = self._document._get_collection().aggregate(pipeline)
|
||||
if IS_PYMONGO_3:
|
||||
result = tuple(result)
|
||||
else:
|
||||
result = result.get('result')
|
||||
result = tuple(self._document._get_collection().aggregate(pipeline))
|
||||
|
||||
if result:
|
||||
return result[0]['total']
|
||||
@@ -1448,11 +1437,7 @@ class BaseQuerySet(object):
|
||||
if isinstance(field_instances[-1], ListField):
|
||||
pipeline.insert(1, {'$unwind': '$' + field})
|
||||
|
||||
result = self._document._get_collection().aggregate(pipeline)
|
||||
if IS_PYMONGO_3:
|
||||
result = tuple(result)
|
||||
else:
|
||||
result = result.get('result')
|
||||
result = tuple(self._document._get_collection().aggregate(pipeline))
|
||||
if result:
|
||||
return result[0]['total']
|
||||
return 0
|
||||
@@ -1527,26 +1512,16 @@ class BaseQuerySet(object):
|
||||
|
||||
@property
|
||||
def _cursor_args(self):
|
||||
if not IS_PYMONGO_3:
|
||||
fields_name = 'fields'
|
||||
cursor_args = {
|
||||
'timeout': self._timeout,
|
||||
'snapshot': self._snapshot
|
||||
}
|
||||
if self._read_preference is not None:
|
||||
cursor_args['read_preference'] = self._read_preference
|
||||
else:
|
||||
cursor_args['slave_okay'] = self._slave_okay
|
||||
else:
|
||||
fields_name = 'projection'
|
||||
# snapshot is not handled at all by PyMongo 3+
|
||||
# TODO: evaluate similar possibilities using modifiers
|
||||
if self._snapshot:
|
||||
msg = 'The snapshot option is not anymore available with PyMongo 3+'
|
||||
warnings.warn(msg, DeprecationWarning)
|
||||
cursor_args = {
|
||||
'no_cursor_timeout': not self._timeout
|
||||
}
|
||||
fields_name = 'projection'
|
||||
# snapshot is not handled at all by PyMongo 3+
|
||||
# TODO: evaluate similar possibilities using modifiers
|
||||
if self._snapshot:
|
||||
msg = 'The snapshot option is not anymore available with PyMongo 3+'
|
||||
warnings.warn(msg, DeprecationWarning)
|
||||
cursor_args = {
|
||||
'no_cursor_timeout': not self._timeout
|
||||
}
|
||||
|
||||
if self._loaded_fields:
|
||||
cursor_args[fields_name] = self._loaded_fields.as_dict()
|
||||
|
||||
@@ -1570,7 +1545,7 @@ class BaseQuerySet(object):
|
||||
# XXX In PyMongo 3+, we define the read preference on a collection
|
||||
# level, not a cursor level. Thus, we need to get a cloned collection
|
||||
# object using `with_options` first.
|
||||
if IS_PYMONGO_3 and self._read_preference is not None:
|
||||
if self._read_preference is not None:
|
||||
self._cursor_obj = self._collection\
|
||||
.with_options(read_preference=self._read_preference)\
|
||||
.find(self._query, **self._cursor_args)
|
||||
|
@@ -8,9 +8,7 @@ from six import iteritems
|
||||
|
||||
from mongoengine.base import UPDATE_OPERATORS
|
||||
from mongoengine.common import _import_class
|
||||
from mongoengine.connection import get_connection
|
||||
from mongoengine.errors import InvalidQueryError
|
||||
from mongoengine.pymongo_support import IS_PYMONGO_3
|
||||
|
||||
__all__ = ('query', 'update')
|
||||
|
||||
@@ -88,18 +86,10 @@ def query(_doc_cls=None, **kwargs):
|
||||
singular_ops = [None, 'ne', 'gt', 'gte', 'lt', 'lte', 'not']
|
||||
singular_ops += STRING_OPERATORS
|
||||
if op in singular_ops:
|
||||
if isinstance(field, six.string_types):
|
||||
if (op in STRING_OPERATORS and
|
||||
isinstance(value, six.string_types)):
|
||||
StringField = _import_class('StringField')
|
||||
value = StringField.prepare_query_value(op, value)
|
||||
else:
|
||||
value = field
|
||||
else:
|
||||
value = field.prepare_query_value(op, value)
|
||||
value = field.prepare_query_value(op, value)
|
||||
|
||||
if isinstance(field, CachedReferenceField) and value:
|
||||
value = value['_id']
|
||||
if isinstance(field, CachedReferenceField) and value:
|
||||
value = value['_id']
|
||||
|
||||
elif op in ('in', 'nin', 'all', 'near') and not isinstance(value, dict):
|
||||
# Raise an error if the in/nin/all/near param is not iterable.
|
||||
@@ -163,16 +153,14 @@ def query(_doc_cls=None, **kwargs):
|
||||
# PyMongo 3+ and MongoDB < 2.6
|
||||
near_embedded = False
|
||||
for near_op in ('$near', '$nearSphere'):
|
||||
if isinstance(value_dict.get(near_op), dict) and (
|
||||
IS_PYMONGO_3 or get_connection().max_wire_version > 1):
|
||||
if isinstance(value_dict.get(near_op), dict):
|
||||
value_son[near_op] = SON(value_son[near_op])
|
||||
if '$maxDistance' in value_dict:
|
||||
value_son[near_op][
|
||||
'$maxDistance'] = value_dict['$maxDistance']
|
||||
value_son[near_op]['$maxDistance'] = value_dict['$maxDistance']
|
||||
if '$minDistance' in value_dict:
|
||||
value_son[near_op][
|
||||
'$minDistance'] = value_dict['$minDistance']
|
||||
value_son[near_op]['$minDistance'] = value_dict['$minDistance']
|
||||
near_embedded = True
|
||||
|
||||
if not near_embedded:
|
||||
if '$maxDistance' in value_dict:
|
||||
value_son['$maxDistance'] = value_dict['$maxDistance']
|
||||
@@ -281,7 +269,7 @@ def update(_doc_cls=None, **update):
|
||||
|
||||
if op == 'pull':
|
||||
if field.required or value is not None:
|
||||
if match == 'in' and not isinstance(value, dict):
|
||||
if match in ('in', 'nin') and not isinstance(value, dict):
|
||||
value = _prepare_query_for_iterable(field, op, value)
|
||||
else:
|
||||
value = field.prepare_query_value(op, value)
|
||||
@@ -308,10 +296,6 @@ def update(_doc_cls=None, **update):
|
||||
|
||||
key = '.'.join(parts)
|
||||
|
||||
if not op:
|
||||
raise InvalidQueryError('Updates must supply an operation '
|
||||
'eg: set__FIELD=value')
|
||||
|
||||
if 'pull' in op and '.' in key:
|
||||
# Dot operators don't work on pull operations
|
||||
# unless they point to a list field
|
||||
|
@@ -1,5 +1,5 @@
|
||||
nose
|
||||
pymongo>=2.7.1
|
||||
pymongo>=3.4
|
||||
six==1.10.0
|
||||
flake8
|
||||
flake8-import-order
|
||||
|
2
setup.py
2
setup.py
@@ -80,7 +80,7 @@ setup(
|
||||
long_description=LONG_DESCRIPTION,
|
||||
platforms=['any'],
|
||||
classifiers=CLASSIFIERS,
|
||||
install_requires=['pymongo>=2.7.1', 'six'],
|
||||
install_requires=['pymongo>=3.4', 'six'],
|
||||
test_suite='nose.collector',
|
||||
**extra_opts
|
||||
)
|
||||
|
@@ -6,7 +6,6 @@ from mongoengine.pymongo_support import list_collection_names
|
||||
|
||||
from mongoengine.queryset import NULLIFY, PULL
|
||||
from mongoengine.connection import get_db
|
||||
from tests.utils import requires_mongodb_gte_26
|
||||
|
||||
__all__ = ("ClassMethodsTest", )
|
||||
|
||||
@@ -187,7 +186,6 @@ class ClassMethodsTest(unittest.TestCase):
|
||||
self.assertEqual(BlogPostWithTags.compare_indexes(), {'missing': [], 'extra': []})
|
||||
self.assertEqual(BlogPostWithCustomField.compare_indexes(), {'missing': [], 'extra': []})
|
||||
|
||||
@requires_mongodb_gte_26
|
||||
def test_compare_indexes_for_text_indexes(self):
|
||||
""" Ensure that compare_indexes behaves correctly for text indexes """
|
||||
|
||||
|
@@ -9,8 +9,6 @@ from six import iteritems
|
||||
|
||||
from mongoengine import *
|
||||
from mongoengine.connection import get_db
|
||||
from mongoengine.mongodb_support import get_mongodb_version, MONGODB_32, MONGODB_3
|
||||
from tests.utils import requires_mongodb_gte_26, requires_mongodb_lte_32, requires_mongodb_gte_34
|
||||
|
||||
__all__ = ("IndexesTest", )
|
||||
|
||||
@@ -20,7 +18,6 @@ class IndexesTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.connection = connect(db='mongoenginetest')
|
||||
self.db = get_db()
|
||||
self.mongodb_version = get_mongodb_version()
|
||||
|
||||
class Person(Document):
|
||||
name = StringField()
|
||||
@@ -409,7 +406,7 @@ class IndexesTest(unittest.TestCase):
|
||||
|
||||
self.assertEqual(2, User.objects.count())
|
||||
info = User.objects._collection.index_information()
|
||||
self.assertEqual(info.keys(), ['_id_'])
|
||||
self.assertEqual(list(info.keys()), ['_id_'])
|
||||
|
||||
User.ensure_indexes()
|
||||
info = User.objects._collection.index_information()
|
||||
@@ -478,8 +475,6 @@ class IndexesTest(unittest.TestCase):
|
||||
def test_covered_index(self):
|
||||
"""Ensure that covered indexes can be used
|
||||
"""
|
||||
IS_MONGODB_3 = get_mongodb_version() >= MONGODB_3
|
||||
|
||||
class Test(Document):
|
||||
a = IntField()
|
||||
b = IntField()
|
||||
@@ -497,33 +492,38 @@ class IndexesTest(unittest.TestCase):
|
||||
# Need to be explicit about covered indexes as mongoDB doesn't know if
|
||||
# the documents returned might have more keys in that here.
|
||||
query_plan = Test.objects(id=obj.id).exclude('a').explain()
|
||||
if not IS_MONGODB_3:
|
||||
self.assertFalse(query_plan['indexOnly'])
|
||||
else:
|
||||
self.assertEqual(query_plan.get('queryPlanner').get('winningPlan').get('inputStage').get('stage'), 'IDHACK')
|
||||
self.assertEqual(
|
||||
query_plan.get('queryPlanner').get('winningPlan').get('inputStage').get('stage'),
|
||||
'IDHACK'
|
||||
)
|
||||
|
||||
query_plan = Test.objects(id=obj.id).only('id').explain()
|
||||
if not IS_MONGODB_3:
|
||||
self.assertTrue(query_plan['indexOnly'])
|
||||
else:
|
||||
self.assertEqual(query_plan.get('queryPlanner').get('winningPlan').get('inputStage').get('stage'), 'IDHACK')
|
||||
self.assertEqual(
|
||||
query_plan.get('queryPlanner').get('winningPlan').get('inputStage').get('stage'),
|
||||
'IDHACK'
|
||||
)
|
||||
|
||||
query_plan = Test.objects(a=1).only('a').exclude('id').explain()
|
||||
if not IS_MONGODB_3:
|
||||
self.assertTrue(query_plan['indexOnly'])
|
||||
else:
|
||||
self.assertEqual(query_plan.get('queryPlanner').get('winningPlan').get('inputStage').get('stage'), 'IXSCAN')
|
||||
self.assertEqual(query_plan.get('queryPlanner').get('winningPlan').get('stage'), 'PROJECTION')
|
||||
self.assertEqual(
|
||||
query_plan.get('queryPlanner').get('winningPlan').get('inputStage').get('stage'),
|
||||
'IXSCAN'
|
||||
)
|
||||
self.assertEqual(
|
||||
query_plan.get('queryPlanner').get('winningPlan').get('stage'),
|
||||
'PROJECTION'
|
||||
)
|
||||
|
||||
query_plan = Test.objects(a=1).explain()
|
||||
if not IS_MONGODB_3:
|
||||
self.assertFalse(query_plan['indexOnly'])
|
||||
else:
|
||||
self.assertEqual(query_plan.get('queryPlanner').get('winningPlan').get('inputStage').get('stage'), 'IXSCAN')
|
||||
self.assertEqual(query_plan.get('queryPlanner').get('winningPlan').get('stage'), 'FETCH')
|
||||
self.assertEqual(
|
||||
query_plan.get('queryPlanner').get('winningPlan').get('inputStage').get('stage'),
|
||||
'IXSCAN'
|
||||
)
|
||||
self.assertEqual(
|
||||
query_plan.get('queryPlanner').get('winningPlan').get('stage'),
|
||||
'FETCH'
|
||||
)
|
||||
|
||||
def test_index_on_id(self):
|
||||
|
||||
class BlogPost(Document):
|
||||
meta = {
|
||||
'indexes': [
|
||||
@@ -542,9 +542,8 @@ class IndexesTest(unittest.TestCase):
|
||||
[('categories', 1), ('_id', 1)])
|
||||
|
||||
def test_hint(self):
|
||||
MONGO_VER = self.mongodb_version
|
||||
|
||||
TAGS_INDEX_NAME = 'tags_1'
|
||||
|
||||
class BlogPost(Document):
|
||||
tags = ListField(StringField())
|
||||
meta = {
|
||||
@@ -562,25 +561,27 @@ class IndexesTest(unittest.TestCase):
|
||||
tags = [("tag %i" % n) for n in range(i % 2)]
|
||||
BlogPost(tags=tags).save()
|
||||
|
||||
self.assertEqual(BlogPost.objects.count(), 10)
|
||||
self.assertEqual(BlogPost.objects.hint().count(), 10)
|
||||
|
||||
# PyMongo 3.0 bug only, works correctly with 2.X and 3.0.1+ versions
|
||||
if pymongo.version != '3.0':
|
||||
self.assertEqual(BlogPost.objects.hint([('tags', 1)]).count(), 10)
|
||||
|
||||
if MONGO_VER >= MONGODB_32:
|
||||
# Mongo32 throws an error if an index exists (i.e `tags` in our case)
|
||||
# and you use hint on an index name that does not exist
|
||||
with self.assertRaises(OperationFailure):
|
||||
BlogPost.objects.hint([('ZZ', 1)]).count()
|
||||
else:
|
||||
self.assertEqual(BlogPost.objects.hint([('ZZ', 1)]).count(), 10)
|
||||
# Hinting by shape should work.
|
||||
self.assertEqual(BlogPost.objects.hint([('tags', 1)]).count(), 10)
|
||||
|
||||
# Hinting by index name should work.
|
||||
self.assertEqual(BlogPost.objects.hint(TAGS_INDEX_NAME).count(), 10)
|
||||
|
||||
with self.assertRaises(Exception):
|
||||
BlogPost.objects.hint(('tags', 1)).next()
|
||||
# Clearing the hint should work fine.
|
||||
self.assertEqual(BlogPost.objects.hint().count(), 10)
|
||||
self.assertEqual(BlogPost.objects.hint([('ZZ', 1)]).hint().count(), 10)
|
||||
|
||||
# Hinting on a non-existent index shape should fail.
|
||||
with self.assertRaises(OperationFailure):
|
||||
BlogPost.objects.hint([('ZZ', 1)]).count()
|
||||
|
||||
# Hinting on a non-existent index name should fail.
|
||||
with self.assertRaises(OperationFailure):
|
||||
BlogPost.objects.hint('Bad Name').count()
|
||||
|
||||
# Invalid shape argument (missing list brackets) should fail.
|
||||
with self.assertRaises(ValueError):
|
||||
BlogPost.objects.hint(('tags', 1)).count()
|
||||
|
||||
def test_unique(self):
|
||||
"""Ensure that uniqueness constraints are applied to fields.
|
||||
@@ -597,13 +598,13 @@ class IndexesTest(unittest.TestCase):
|
||||
# Two posts with the same slug is not allowed
|
||||
post2 = BlogPost(title='test2', slug='test')
|
||||
self.assertRaises(NotUniqueError, post2.save)
|
||||
self.assertRaises(NotUniqueError, BlogPost.objects.insert, post2)
|
||||
|
||||
# Ensure backwards compatibilty for errors
|
||||
# Ensure backwards compatibility for errors
|
||||
self.assertRaises(OperationError, post2.save)
|
||||
|
||||
@requires_mongodb_gte_34
|
||||
def test_primary_key_unique_not_working_under_mongo_34(self):
|
||||
"""Relates to #1445"""
|
||||
def test_primary_key_unique_not_working(self):
|
||||
"""Relates to #1445"""
|
||||
class Blog(Document):
|
||||
id = StringField(primary_key=True, unique=True)
|
||||
|
||||
@@ -611,21 +612,17 @@ class IndexesTest(unittest.TestCase):
|
||||
|
||||
with self.assertRaises(OperationFailure) as ctx_err:
|
||||
Blog(id='garbage').save()
|
||||
try:
|
||||
self.assertIn("The field 'unique' is not valid for an _id index specification", str(ctx_err.exception))
|
||||
except AssertionError:
|
||||
# error is slightly different on python 3.6
|
||||
self.assertIn("The field 'background' is not valid for an _id index specification", str(ctx_err.exception))
|
||||
|
||||
@requires_mongodb_lte_32
|
||||
def test_primary_key_unique_working_under_mongo_32(self):
|
||||
"""Relates to #1445"""
|
||||
class Blog(Document):
|
||||
id = StringField(primary_key=True, unique=True)
|
||||
|
||||
Blog.drop_collection()
|
||||
|
||||
Blog(id='garbage').save()
|
||||
# One of the errors below should happen. Which one depends on the
|
||||
# PyMongo version and dict order.
|
||||
err_msg = str(ctx_err.exception)
|
||||
self.assertTrue(
|
||||
any([
|
||||
"The field 'unique' is not valid for an _id index specification" in err_msg,
|
||||
"The field 'background' is not valid for an _id index specification" in err_msg,
|
||||
"The field 'sparse' is not valid for an _id index specification" in err_msg,
|
||||
])
|
||||
)
|
||||
|
||||
def test_unique_with(self):
|
||||
"""Ensure that unique_with constraints are applied to fields.
|
||||
@@ -708,6 +705,77 @@ class IndexesTest(unittest.TestCase):
|
||||
|
||||
self.assertRaises(NotUniqueError, post2.save)
|
||||
|
||||
def test_unique_embedded_document_in_sorted_list(self):
|
||||
"""
|
||||
Ensure that the uniqueness constraints are applied to fields in
|
||||
embedded documents, even when the embedded documents in a sorted list
|
||||
field.
|
||||
"""
|
||||
class SubDocument(EmbeddedDocument):
|
||||
year = IntField()
|
||||
slug = StringField(unique=True)
|
||||
|
||||
class BlogPost(Document):
|
||||
title = StringField()
|
||||
subs = SortedListField(EmbeddedDocumentField(SubDocument),
|
||||
ordering='year')
|
||||
|
||||
BlogPost.drop_collection()
|
||||
|
||||
post1 = BlogPost(
|
||||
title='test1', subs=[
|
||||
SubDocument(year=2009, slug='conflict'),
|
||||
SubDocument(year=2009, slug='conflict')
|
||||
]
|
||||
)
|
||||
post1.save()
|
||||
|
||||
# confirm that the unique index is created
|
||||
indexes = BlogPost._get_collection().index_information()
|
||||
self.assertIn('subs.slug_1', indexes)
|
||||
self.assertTrue(indexes['subs.slug_1']['unique'])
|
||||
|
||||
post2 = BlogPost(
|
||||
title='test2', subs=[SubDocument(year=2014, slug='conflict')]
|
||||
)
|
||||
|
||||
self.assertRaises(NotUniqueError, post2.save)
|
||||
|
||||
def test_unique_embedded_document_in_embedded_document_list(self):
|
||||
"""
|
||||
Ensure that the uniqueness constraints are applied to fields in
|
||||
embedded documents, even when the embedded documents in an embedded
|
||||
list field.
|
||||
"""
|
||||
class SubDocument(EmbeddedDocument):
|
||||
year = IntField()
|
||||
slug = StringField(unique=True)
|
||||
|
||||
class BlogPost(Document):
|
||||
title = StringField()
|
||||
subs = EmbeddedDocumentListField(SubDocument)
|
||||
|
||||
BlogPost.drop_collection()
|
||||
|
||||
post1 = BlogPost(
|
||||
title='test1', subs=[
|
||||
SubDocument(year=2009, slug='conflict'),
|
||||
SubDocument(year=2009, slug='conflict')
|
||||
]
|
||||
)
|
||||
post1.save()
|
||||
|
||||
# confirm that the unique index is created
|
||||
indexes = BlogPost._get_collection().index_information()
|
||||
self.assertIn('subs.slug_1', indexes)
|
||||
self.assertTrue(indexes['subs.slug_1']['unique'])
|
||||
|
||||
post2 = BlogPost(
|
||||
title='test2', subs=[SubDocument(year=2014, slug='conflict')]
|
||||
)
|
||||
|
||||
self.assertRaises(NotUniqueError, post2.save)
|
||||
|
||||
def test_unique_with_embedded_document_and_embedded_unique(self):
|
||||
"""Ensure that uniqueness constraints are applied to fields on
|
||||
embedded documents. And work with unique_with as well.
|
||||
@@ -759,6 +827,18 @@ class IndexesTest(unittest.TestCase):
|
||||
self.assertEqual(3600,
|
||||
info['created_1']['expireAfterSeconds'])
|
||||
|
||||
def test_index_drop_dups_silently_ignored(self):
|
||||
class Customer(Document):
|
||||
cust_id = IntField(unique=True, required=True)
|
||||
meta = {
|
||||
'indexes': ['cust_id'],
|
||||
'index_drop_dups': True,
|
||||
'allow_inheritance': False,
|
||||
}
|
||||
|
||||
Customer.drop_collection()
|
||||
Customer.objects.first()
|
||||
|
||||
def test_unique_and_indexes(self):
|
||||
"""Ensure that 'unique' constraints aren't overridden by
|
||||
meta.indexes.
|
||||
@@ -775,11 +855,16 @@ class IndexesTest(unittest.TestCase):
|
||||
cust.save()
|
||||
|
||||
cust_dupe = Customer(cust_id=1)
|
||||
try:
|
||||
with self.assertRaises(NotUniqueError):
|
||||
cust_dupe.save()
|
||||
raise AssertionError("We saved a dupe!")
|
||||
except NotUniqueError:
|
||||
pass
|
||||
|
||||
cust = Customer(cust_id=2)
|
||||
cust.save()
|
||||
|
||||
# duplicate key on update
|
||||
with self.assertRaises(NotUniqueError):
|
||||
cust.cust_id = 1
|
||||
cust.save()
|
||||
|
||||
def test_primary_save_duplicate_update_existing_object(self):
|
||||
"""If you set a field as primary, then unexpected behaviour can occur.
|
||||
@@ -899,7 +984,6 @@ class IndexesTest(unittest.TestCase):
|
||||
info['provider_ids.foo_1_provider_ids.bar_1']['key'])
|
||||
self.assertTrue(info['provider_ids.foo_1_provider_ids.bar_1']['sparse'])
|
||||
|
||||
@requires_mongodb_gte_26
|
||||
def test_text_indexes(self):
|
||||
class Book(Document):
|
||||
title = DictField()
|
||||
|
@@ -12,11 +12,12 @@ from bson import DBRef, ObjectId
|
||||
from pymongo.errors import DuplicateKeyError
|
||||
from six import iteritems
|
||||
|
||||
from mongoengine.mongodb_support import get_mongodb_version, MONGODB_36, MONGODB_34
|
||||
from mongoengine.pymongo_support import list_collection_names
|
||||
from tests import fixtures
|
||||
from tests.fixtures import (PickleEmbedded, PickleTest, PickleSignalsTest,
|
||||
PickleDynamicEmbedded, PickleDynamicTest)
|
||||
from tests.utils import MongoDBTestCase
|
||||
from tests.utils import MongoDBTestCase, get_as_pymongo
|
||||
|
||||
from mongoengine import *
|
||||
from mongoengine.base import get_document, _document_registry
|
||||
@@ -28,8 +29,6 @@ from mongoengine.queryset import NULLIFY, Q
|
||||
from mongoengine.context_managers import switch_db, query_counter
|
||||
from mongoengine import signals
|
||||
|
||||
from tests.utils import requires_mongodb_gte_26
|
||||
|
||||
TEST_IMAGE_PATH = os.path.join(os.path.dirname(__file__),
|
||||
'../fields/mongoengine.png')
|
||||
|
||||
@@ -420,6 +419,12 @@ class InstanceTest(MongoDBTestCase):
|
||||
person.save()
|
||||
person.to_dbref()
|
||||
|
||||
def test_key_like_attribute_access(self):
|
||||
person = self.Person(age=30)
|
||||
self.assertEqual(person['age'], 30)
|
||||
with self.assertRaises(KeyError):
|
||||
person['unknown_attr']
|
||||
|
||||
def test_save_abstract_document(self):
|
||||
"""Saving an abstract document should fail."""
|
||||
class Doc(Document):
|
||||
@@ -462,7 +467,16 @@ class InstanceTest(MongoDBTestCase):
|
||||
Animal.drop_collection()
|
||||
doc = Animal(superphylum='Deuterostomia')
|
||||
doc.save()
|
||||
doc.reload()
|
||||
|
||||
mongo_db = get_mongodb_version()
|
||||
CMD_QUERY_KEY = 'command' if mongo_db >= MONGODB_36 else 'query'
|
||||
|
||||
with query_counter() as q:
|
||||
doc.reload()
|
||||
query_op = q.db.system.profile.find({'ns': 'mongoenginetest.animal'})[0]
|
||||
self.assertEqual(set(query_op[CMD_QUERY_KEY]['filter'].keys()), set(['_id', 'superphylum']))
|
||||
|
||||
Animal.drop_collection()
|
||||
|
||||
def test_reload_sharded_nested(self):
|
||||
class SuperPhylum(EmbeddedDocument):
|
||||
@@ -476,6 +490,34 @@ class InstanceTest(MongoDBTestCase):
|
||||
doc = Animal(superphylum=SuperPhylum(name='Deuterostomia'))
|
||||
doc.save()
|
||||
doc.reload()
|
||||
Animal.drop_collection()
|
||||
|
||||
def test_update_shard_key_routing(self):
|
||||
"""Ensures updating a doc with a specified shard_key includes it in
|
||||
the query.
|
||||
"""
|
||||
class Animal(Document):
|
||||
is_mammal = BooleanField()
|
||||
name = StringField()
|
||||
meta = {'shard_key': ('is_mammal', 'id')}
|
||||
|
||||
Animal.drop_collection()
|
||||
doc = Animal(is_mammal=True, name='Dog')
|
||||
doc.save()
|
||||
|
||||
mongo_db = get_mongodb_version()
|
||||
|
||||
with query_counter() as q:
|
||||
doc.name = 'Cat'
|
||||
doc.save()
|
||||
query_op = q.db.system.profile.find({'ns': 'mongoenginetest.animal'})[0]
|
||||
self.assertEqual(query_op['op'], 'update')
|
||||
if mongo_db <= MONGODB_34:
|
||||
self.assertEqual(set(query_op['query'].keys()), set(['_id', 'is_mammal']))
|
||||
else:
|
||||
self.assertEqual(set(query_op['command']['q'].keys()), set(['_id', 'is_mammal']))
|
||||
|
||||
Animal.drop_collection()
|
||||
|
||||
def test_reload_with_changed_fields(self):
|
||||
"""Ensures reloading will not affect changed fields"""
|
||||
@@ -711,39 +753,78 @@ class InstanceTest(MongoDBTestCase):
|
||||
acc1 = Account.objects.first()
|
||||
self.assertHasInstance(acc1._data["emails"][0], acc1)
|
||||
|
||||
def test_save_checks_that_clean_is_called(self):
|
||||
class CustomError(Exception):
|
||||
pass
|
||||
|
||||
class TestDocument(Document):
|
||||
def clean(self):
|
||||
raise CustomError()
|
||||
|
||||
with self.assertRaises(CustomError):
|
||||
TestDocument().save()
|
||||
|
||||
TestDocument().save(clean=False)
|
||||
|
||||
def test_save_signal_pre_save_post_validation_makes_change_to_doc(self):
|
||||
class BlogPost(Document):
|
||||
content = StringField()
|
||||
|
||||
@classmethod
|
||||
def pre_save_post_validation(cls, sender, document, **kwargs):
|
||||
document.content = 'checked'
|
||||
|
||||
signals.pre_save_post_validation.connect(BlogPost.pre_save_post_validation, sender=BlogPost)
|
||||
|
||||
BlogPost.drop_collection()
|
||||
|
||||
post = BlogPost(content='unchecked').save()
|
||||
self.assertEqual(post.content, 'checked')
|
||||
# Make sure pre_save_post_validation changes makes it to the db
|
||||
raw_doc = get_as_pymongo(post)
|
||||
self.assertEqual(
|
||||
raw_doc,
|
||||
{
|
||||
'content': 'checked',
|
||||
'_id': post.id
|
||||
})
|
||||
|
||||
# Important to disconnect as it could cause some assertions in test_signals
|
||||
# to fail (due to the garbage collection timing of this signal)
|
||||
signals.pre_save_post_validation.disconnect(BlogPost.pre_save_post_validation)
|
||||
|
||||
def test_document_clean(self):
|
||||
class TestDocument(Document):
|
||||
status = StringField()
|
||||
pub_date = DateTimeField()
|
||||
cleaned = BooleanField(default=False)
|
||||
|
||||
def clean(self):
|
||||
if self.status == 'draft' and self.pub_date is not None:
|
||||
msg = 'Draft entries may not have a publication date.'
|
||||
raise ValidationError(msg)
|
||||
# Set the pub_date for published items if not set.
|
||||
if self.status == 'published' and self.pub_date is None:
|
||||
self.pub_date = datetime.now()
|
||||
self.cleaned = True
|
||||
|
||||
TestDocument.drop_collection()
|
||||
|
||||
t = TestDocument(status="draft", pub_date=datetime.now())
|
||||
|
||||
with self.assertRaises(ValidationError) as cm:
|
||||
t.save()
|
||||
|
||||
expected_msg = "Draft entries may not have a publication date."
|
||||
self.assertIn(expected_msg, cm.exception.message)
|
||||
self.assertEqual(cm.exception.to_dict(), {'__all__': expected_msg})
|
||||
t = TestDocument(status="draft")
|
||||
|
||||
# Ensure clean=False prevent call to clean
|
||||
t = TestDocument(status="published")
|
||||
t.save(clean=False)
|
||||
|
||||
self.assertEqual(t.pub_date, None)
|
||||
self.assertEqual(t.status, "published")
|
||||
self.assertEqual(t.cleaned, False)
|
||||
|
||||
t = TestDocument(status="published")
|
||||
self.assertEqual(t.cleaned, False)
|
||||
t.save(clean=True)
|
||||
|
||||
self.assertEqual(type(t.pub_date), datetime)
|
||||
self.assertEqual(t.status, "published")
|
||||
self.assertEqual(t.cleaned, True)
|
||||
raw_doc = get_as_pymongo(t)
|
||||
# Make sure clean changes makes it to the db
|
||||
self.assertEqual(
|
||||
raw_doc,
|
||||
{
|
||||
'status': 'published',
|
||||
'cleaned': True,
|
||||
'_id': t.id
|
||||
})
|
||||
|
||||
def test_document_embedded_clean(self):
|
||||
class TestEmbeddedDocument(EmbeddedDocument):
|
||||
@@ -844,7 +925,6 @@ class InstanceTest(MongoDBTestCase):
|
||||
|
||||
self.assertDbEqual([dict(other_doc.to_mongo()), dict(doc.to_mongo())])
|
||||
|
||||
@requires_mongodb_gte_26
|
||||
def test_modify_with_positional_push(self):
|
||||
class Content(EmbeddedDocument):
|
||||
keywords = ListField(StringField())
|
||||
@@ -884,19 +964,39 @@ class InstanceTest(MongoDBTestCase):
|
||||
person.save()
|
||||
|
||||
# Ensure that the object is in the database
|
||||
collection = self.db[self.Person._get_collection_name()]
|
||||
person_obj = collection.find_one({'name': 'Test User'})
|
||||
self.assertEqual(person_obj['name'], 'Test User')
|
||||
self.assertEqual(person_obj['age'], 30)
|
||||
self.assertEqual(person_obj['_id'], person.id)
|
||||
raw_doc = get_as_pymongo(person)
|
||||
self.assertEqual(
|
||||
raw_doc,
|
||||
{
|
||||
'_cls': 'Person',
|
||||
'name': 'Test User',
|
||||
'age': 30,
|
||||
'_id': person.id
|
||||
})
|
||||
|
||||
# Test skipping validation on save
|
||||
def test_save_skip_validation(self):
|
||||
class Recipient(Document):
|
||||
email = EmailField(required=True)
|
||||
|
||||
recipient = Recipient(email='not-an-email')
|
||||
self.assertRaises(ValidationError, recipient.save)
|
||||
with self.assertRaises(ValidationError):
|
||||
recipient.save()
|
||||
|
||||
recipient.save(validate=False)
|
||||
raw_doc = get_as_pymongo(recipient)
|
||||
self.assertEqual(
|
||||
raw_doc,
|
||||
{
|
||||
'email': 'not-an-email',
|
||||
'_id': recipient.id
|
||||
})
|
||||
|
||||
def test_save_with_bad_id(self):
|
||||
class Clown(Document):
|
||||
id = IntField(primary_key=True)
|
||||
|
||||
with self.assertRaises(ValidationError):
|
||||
Clown(id="not_an_int").save()
|
||||
|
||||
def test_save_to_a_value_that_equates_to_false(self):
|
||||
class Thing(EmbeddedDocument):
|
||||
@@ -3089,24 +3189,6 @@ class InstanceTest(MongoDBTestCase):
|
||||
"UNDEFINED",
|
||||
system.nodes["node"].parameters["param"].macros["test"].value)
|
||||
|
||||
def test_embedded_document_save_reload_warning(self):
|
||||
"""Relates to #1570"""
|
||||
class Embedded(EmbeddedDocument):
|
||||
pass
|
||||
|
||||
class Doc(Document):
|
||||
emb = EmbeddedDocumentField(Embedded)
|
||||
|
||||
doc = Doc(emb=Embedded()).save()
|
||||
doc.emb.save() # Make sure its still working
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("error", DeprecationWarning)
|
||||
with self.assertRaises(DeprecationWarning):
|
||||
doc.emb.save()
|
||||
|
||||
with self.assertRaises(DeprecationWarning):
|
||||
doc.emb.reload()
|
||||
|
||||
def test_embedded_document_equality(self):
|
||||
class Test(Document):
|
||||
field = StringField(required=True)
|
||||
@@ -3198,7 +3280,7 @@ class InstanceTest(MongoDBTestCase):
|
||||
p2.name = 'alon2'
|
||||
p2.save()
|
||||
p3 = Person.objects().only('created_on')[0]
|
||||
self.assertEquals(orig_created_on, p3.created_on)
|
||||
self.assertEqual(orig_created_on, p3.created_on)
|
||||
|
||||
class Person(Document):
|
||||
created_on = DateTimeField(default=lambda: datetime.utcnow())
|
||||
@@ -3207,10 +3289,10 @@ class InstanceTest(MongoDBTestCase):
|
||||
|
||||
p4 = Person.objects()[0]
|
||||
p4.save()
|
||||
self.assertEquals(p4.height, 189)
|
||||
self.assertEqual(p4.height, 189)
|
||||
|
||||
# However the default will not be fixed in DB
|
||||
self.assertEquals(Person.objects(height=189).count(), 0)
|
||||
self.assertEqual(Person.objects(height=189).count(), 0)
|
||||
|
||||
# alter DB for the new default
|
||||
coll = Person._get_collection()
|
||||
@@ -3218,17 +3300,17 @@ class InstanceTest(MongoDBTestCase):
|
||||
if 'height' not in person:
|
||||
coll.update_one({'_id': person['_id']}, {'$set': {'height': 189}})
|
||||
|
||||
self.assertEquals(Person.objects(height=189).count(), 1)
|
||||
self.assertEqual(Person.objects(height=189).count(), 1)
|
||||
|
||||
def test_from_son(self):
|
||||
# 771
|
||||
class MyPerson(self.Person):
|
||||
meta = dict(shard_key=["id"])
|
||||
p = MyPerson.from_json('{"name": "name", "age": 27}', created=True)
|
||||
self.assertEquals(p.id, None)
|
||||
self.assertEqual(p.id, None)
|
||||
p.id = "12345" # in case it is not working: "OperationError: Shard Keys are immutable..." will be raised here
|
||||
p = MyPerson._from_son({"name": "name", "age": 27}, created=True)
|
||||
self.assertEquals(p.id, None)
|
||||
self.assertEqual(p.id, None)
|
||||
p.id = "12345" # in case it is not working: "OperationError: Shard Keys are immutable..." will be raised here
|
||||
|
||||
def test_from_son_created_False_without_id(self):
|
||||
@@ -3306,7 +3388,7 @@ class InstanceTest(MongoDBTestCase):
|
||||
u_from_db = User.objects.get(name='user')
|
||||
u_from_db.height = None
|
||||
u_from_db.save()
|
||||
self.assertEquals(u_from_db.height, None)
|
||||
self.assertEqual(u_from_db.height, None)
|
||||
# 864
|
||||
self.assertEqual(u_from_db.str_fld, None)
|
||||
self.assertEqual(u_from_db.int_fld, None)
|
||||
@@ -3320,7 +3402,7 @@ class InstanceTest(MongoDBTestCase):
|
||||
u.save()
|
||||
User.objects(name='user').update_one(set__height=None, upsert=True)
|
||||
u_from_db = User.objects.get(name='user')
|
||||
self.assertEquals(u_from_db.height, None)
|
||||
self.assertEqual(u_from_db.height, None)
|
||||
|
||||
def test_not_saved_eq(self):
|
||||
"""Ensure we can compare documents not saved.
|
||||
@@ -3362,7 +3444,6 @@ class InstanceTest(MongoDBTestCase):
|
||||
|
||||
person.update(set__height=2.0)
|
||||
|
||||
@requires_mongodb_gte_26
|
||||
def test_push_with_position(self):
|
||||
"""Ensure that push with position works properly for an instance."""
|
||||
class BlogPost(Document):
|
||||
|
@@ -61,10 +61,6 @@ class TestJson(unittest.TestCase):
|
||||
self.assertEqual(doc, Doc.from_json(doc.to_json()))
|
||||
|
||||
def test_json_complex(self):
|
||||
|
||||
if pymongo.version_tuple[0] <= 2 and pymongo.version_tuple[1] <= 3:
|
||||
raise SkipTest("Need pymongo 2.4 as has a fix for DBRefs")
|
||||
|
||||
class EmbeddedDoc(EmbeddedDocument):
|
||||
pass
|
||||
|
||||
|
@@ -8,11 +8,11 @@ from bson import DBRef, ObjectId, SON
|
||||
|
||||
from mongoengine import Document, StringField, IntField, DateTimeField, DateField, ValidationError, \
|
||||
ComplexDateTimeField, FloatField, ListField, ReferenceField, DictField, EmbeddedDocument, EmbeddedDocumentField, \
|
||||
GenericReferenceField, DoesNotExist, NotRegistered, GenericEmbeddedDocumentField, OperationError, DynamicField, \
|
||||
FieldDoesNotExist, EmbeddedDocumentListField, MultipleObjectsReturned, NotUniqueError, BooleanField, ObjectIdField, \
|
||||
SortedListField, GenericLazyReferenceField, LazyReferenceField, DynamicDocument
|
||||
from mongoengine.base import (BaseField, EmbeddedDocumentList,
|
||||
_document_registry)
|
||||
GenericReferenceField, DoesNotExist, NotRegistered, OperationError, DynamicField, \
|
||||
FieldDoesNotExist, EmbeddedDocumentListField, MultipleObjectsReturned, NotUniqueError, BooleanField,\
|
||||
ObjectIdField, SortedListField, GenericLazyReferenceField, LazyReferenceField, DynamicDocument
|
||||
from mongoengine.base import (BaseField, EmbeddedDocumentList, _document_registry)
|
||||
from mongoengine.errors import DeprecatedError
|
||||
|
||||
from tests.utils import MongoDBTestCase
|
||||
|
||||
@@ -57,6 +57,48 @@ class FieldTest(MongoDBTestCase):
|
||||
self.assertEqual(
|
||||
data_to_be_saved, ['age', 'created', 'day', 'name', 'userid'])
|
||||
|
||||
def test_custom_field_validation_raise_deprecated_error_when_validation_return_something(self):
|
||||
# Covers introduction of a breaking change in the validation parameter (0.18)
|
||||
def _not_empty(z):
|
||||
return bool(z)
|
||||
|
||||
class Person(Document):
|
||||
name = StringField(validation=_not_empty)
|
||||
|
||||
Person.drop_collection()
|
||||
|
||||
error = ("validation argument for `name` must not return anything, "
|
||||
"it should raise a ValidationError if validation fails")
|
||||
|
||||
with self.assertRaises(DeprecatedError) as ctx_err:
|
||||
Person(name="").validate()
|
||||
self.assertEqual(str(ctx_err.exception), error)
|
||||
|
||||
with self.assertRaises(DeprecatedError) as ctx_err:
|
||||
Person(name="").save()
|
||||
self.assertEqual(str(ctx_err.exception), error)
|
||||
|
||||
def test_custom_field_validation_raise_validation_error(self):
|
||||
def _not_empty(z):
|
||||
if not z:
|
||||
raise ValidationError('cantbeempty')
|
||||
|
||||
class Person(Document):
|
||||
name = StringField(validation=_not_empty)
|
||||
|
||||
Person.drop_collection()
|
||||
|
||||
with self.assertRaises(ValidationError) as ctx_err:
|
||||
Person(name="").validate()
|
||||
self.assertEqual("ValidationError (Person:None) (cantbeempty: ['name'])", str(ctx_err.exception))
|
||||
|
||||
with self.assertRaises(ValidationError):
|
||||
Person(name="").save()
|
||||
self.assertEqual("ValidationError (Person:None) (cantbeempty: ['name'])", str(ctx_err.exception))
|
||||
|
||||
Person(name="garbage").validate()
|
||||
Person(name="garbage").save()
|
||||
|
||||
def test_default_values_set_to_None(self):
|
||||
"""Ensure that default field values are used even when
|
||||
we explcitly initialize the doc with None values.
|
||||
@@ -1373,7 +1415,7 @@ class FieldTest(MongoDBTestCase):
|
||||
brother = Brother(name="Bob", sibling=sister)
|
||||
brother.save()
|
||||
|
||||
self.assertEquals(Brother.objects[0].sibling.name, sister.name)
|
||||
self.assertEqual(Brother.objects[0].sibling.name, sister.name)
|
||||
|
||||
def test_reference_abstract_class(self):
|
||||
"""Ensure that an abstract class instance cannot be used in the
|
||||
@@ -1769,79 +1811,6 @@ class FieldTest(MongoDBTestCase):
|
||||
with self.assertRaises(ValidationError):
|
||||
shirt.validate()
|
||||
|
||||
def test_choices_validation_documents(self):
|
||||
"""
|
||||
Ensure fields with document choices validate given a valid choice.
|
||||
"""
|
||||
class UserComments(EmbeddedDocument):
|
||||
author = StringField()
|
||||
message = StringField()
|
||||
|
||||
class BlogPost(Document):
|
||||
comments = ListField(
|
||||
GenericEmbeddedDocumentField(choices=(UserComments,))
|
||||
)
|
||||
|
||||
# Ensure Validation Passes
|
||||
BlogPost(comments=[
|
||||
UserComments(author='user2', message='message2'),
|
||||
]).save()
|
||||
|
||||
def test_choices_validation_documents_invalid(self):
|
||||
"""
|
||||
Ensure fields with document choices validate given an invalid choice.
|
||||
This should throw a ValidationError exception.
|
||||
"""
|
||||
class UserComments(EmbeddedDocument):
|
||||
author = StringField()
|
||||
message = StringField()
|
||||
|
||||
class ModeratorComments(EmbeddedDocument):
|
||||
author = StringField()
|
||||
message = StringField()
|
||||
|
||||
class BlogPost(Document):
|
||||
comments = ListField(
|
||||
GenericEmbeddedDocumentField(choices=(UserComments,))
|
||||
)
|
||||
|
||||
# Single Entry Failure
|
||||
post = BlogPost(comments=[
|
||||
ModeratorComments(author='mod1', message='message1'),
|
||||
])
|
||||
self.assertRaises(ValidationError, post.save)
|
||||
|
||||
# Mixed Entry Failure
|
||||
post = BlogPost(comments=[
|
||||
ModeratorComments(author='mod1', message='message1'),
|
||||
UserComments(author='user2', message='message2'),
|
||||
])
|
||||
self.assertRaises(ValidationError, post.save)
|
||||
|
||||
def test_choices_validation_documents_inheritance(self):
|
||||
"""
|
||||
Ensure fields with document choices validate given subclass of choice.
|
||||
"""
|
||||
class Comments(EmbeddedDocument):
|
||||
meta = {
|
||||
'abstract': True
|
||||
}
|
||||
author = StringField()
|
||||
message = StringField()
|
||||
|
||||
class UserComments(Comments):
|
||||
pass
|
||||
|
||||
class BlogPost(Document):
|
||||
comments = ListField(
|
||||
GenericEmbeddedDocumentField(choices=(Comments,))
|
||||
)
|
||||
|
||||
# Save Valid EmbeddedDocument Type
|
||||
BlogPost(comments=[
|
||||
UserComments(author='user2', message='message2'),
|
||||
]).save()
|
||||
|
||||
def test_choices_get_field_display(self):
|
||||
"""Test dynamic helper for returning the display value of a choices
|
||||
field.
|
||||
@@ -1958,85 +1927,6 @@ class FieldTest(MongoDBTestCase):
|
||||
self.assertEqual(error_dict['size'], SIZE_MESSAGE)
|
||||
self.assertEqual(error_dict['color'], COLOR_MESSAGE)
|
||||
|
||||
def test_generic_embedded_document(self):
|
||||
class Car(EmbeddedDocument):
|
||||
name = StringField()
|
||||
|
||||
class Dish(EmbeddedDocument):
|
||||
food = StringField(required=True)
|
||||
number = IntField()
|
||||
|
||||
class Person(Document):
|
||||
name = StringField()
|
||||
like = GenericEmbeddedDocumentField()
|
||||
|
||||
Person.drop_collection()
|
||||
|
||||
person = Person(name='Test User')
|
||||
person.like = Car(name='Fiat')
|
||||
person.save()
|
||||
|
||||
person = Person.objects.first()
|
||||
self.assertIsInstance(person.like, Car)
|
||||
|
||||
person.like = Dish(food="arroz", number=15)
|
||||
person.save()
|
||||
|
||||
person = Person.objects.first()
|
||||
self.assertIsInstance(person.like, Dish)
|
||||
|
||||
def test_generic_embedded_document_choices(self):
|
||||
"""Ensure you can limit GenericEmbeddedDocument choices."""
|
||||
class Car(EmbeddedDocument):
|
||||
name = StringField()
|
||||
|
||||
class Dish(EmbeddedDocument):
|
||||
food = StringField(required=True)
|
||||
number = IntField()
|
||||
|
||||
class Person(Document):
|
||||
name = StringField()
|
||||
like = GenericEmbeddedDocumentField(choices=(Dish,))
|
||||
|
||||
Person.drop_collection()
|
||||
|
||||
person = Person(name='Test User')
|
||||
person.like = Car(name='Fiat')
|
||||
self.assertRaises(ValidationError, person.validate)
|
||||
|
||||
person.like = Dish(food="arroz", number=15)
|
||||
person.save()
|
||||
|
||||
person = Person.objects.first()
|
||||
self.assertIsInstance(person.like, Dish)
|
||||
|
||||
def test_generic_list_embedded_document_choices(self):
|
||||
"""Ensure you can limit GenericEmbeddedDocument choices inside
|
||||
a list field.
|
||||
"""
|
||||
class Car(EmbeddedDocument):
|
||||
name = StringField()
|
||||
|
||||
class Dish(EmbeddedDocument):
|
||||
food = StringField(required=True)
|
||||
number = IntField()
|
||||
|
||||
class Person(Document):
|
||||
name = StringField()
|
||||
likes = ListField(GenericEmbeddedDocumentField(choices=(Dish,)))
|
||||
|
||||
Person.drop_collection()
|
||||
|
||||
person = Person(name='Test User')
|
||||
person.likes = [Car(name='Fiat')]
|
||||
self.assertRaises(ValidationError, person.validate)
|
||||
|
||||
person.likes = [Dish(food="arroz", number=15)]
|
||||
person.save()
|
||||
|
||||
person = Person.objects.first()
|
||||
self.assertIsInstance(person.likes[0], Dish)
|
||||
|
||||
def test_recursive_validation(self):
|
||||
"""Ensure that a validation result to_dict is available."""
|
||||
class Author(EmbeddedDocument):
|
||||
@@ -2198,8 +2088,8 @@ class FieldTest(MongoDBTestCase):
|
||||
Dog().save()
|
||||
Fish().save()
|
||||
Human().save()
|
||||
self.assertEquals(Animal.objects(_cls__in=["Animal.Mammal.Dog", "Animal.Fish"]).count(), 2)
|
||||
self.assertEquals(Animal.objects(_cls__in=["Animal.Fish.Guppy"]).count(), 0)
|
||||
self.assertEqual(Animal.objects(_cls__in=["Animal.Mammal.Dog", "Animal.Fish"]).count(), 2)
|
||||
self.assertEqual(Animal.objects(_cls__in=["Animal.Fish.Guppy"]).count(), 0)
|
||||
|
||||
def test_sparse_field(self):
|
||||
class Doc(Document):
|
||||
@@ -2702,44 +2592,5 @@ class EmbeddedDocumentListFieldTestCase(MongoDBTestCase):
|
||||
self.assertEqual(custom_data['a'], CustomData.c_field.custom_data['a'])
|
||||
|
||||
|
||||
class TestEmbeddedDocumentField(MongoDBTestCase):
|
||||
def test___init___(self):
|
||||
class MyDoc(EmbeddedDocument):
|
||||
name = StringField()
|
||||
|
||||
field = EmbeddedDocumentField(MyDoc)
|
||||
self.assertEqual(field.document_type_obj, MyDoc)
|
||||
|
||||
field2 = EmbeddedDocumentField('MyDoc')
|
||||
self.assertEqual(field2.document_type_obj, 'MyDoc')
|
||||
|
||||
def test___init___throw_error_if_document_type_is_not_EmbeddedDocument(self):
|
||||
with self.assertRaises(ValidationError):
|
||||
EmbeddedDocumentField(dict)
|
||||
|
||||
def test_document_type_throw_error_if_not_EmbeddedDocument_subclass(self):
|
||||
|
||||
class MyDoc(Document):
|
||||
name = StringField()
|
||||
|
||||
emb = EmbeddedDocumentField('MyDoc')
|
||||
with self.assertRaises(ValidationError) as ctx:
|
||||
emb.document_type
|
||||
self.assertIn('Invalid embedded document class provided to an EmbeddedDocumentField', str(ctx.exception))
|
||||
|
||||
def test_embedded_document_field_only_allow_subclasses_of_embedded_document(self):
|
||||
# Relates to #1661
|
||||
class MyDoc(Document):
|
||||
name = StringField()
|
||||
|
||||
with self.assertRaises(ValidationError):
|
||||
class MyFailingDoc(Document):
|
||||
emb = EmbeddedDocumentField(MyDoc)
|
||||
|
||||
with self.assertRaises(ValidationError):
|
||||
class MyFailingdoc2(Document):
|
||||
emb = EmbeddedDocumentField('MyDoc')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
@@ -320,16 +320,16 @@ class FileTest(MongoDBTestCase):
|
||||
|
||||
files = db.fs.files.find()
|
||||
chunks = db.fs.chunks.find()
|
||||
self.assertEquals(len(list(files)), 1)
|
||||
self.assertEquals(len(list(chunks)), 1)
|
||||
self.assertEqual(len(list(files)), 1)
|
||||
self.assertEqual(len(list(chunks)), 1)
|
||||
|
||||
# Deleting the docoument should delete the files
|
||||
testfile.delete()
|
||||
|
||||
files = db.fs.files.find()
|
||||
chunks = db.fs.chunks.find()
|
||||
self.assertEquals(len(list(files)), 0)
|
||||
self.assertEquals(len(list(chunks)), 0)
|
||||
self.assertEqual(len(list(files)), 0)
|
||||
self.assertEqual(len(list(chunks)), 0)
|
||||
|
||||
# Test case where we don't store a file in the first place
|
||||
testfile = TestFile()
|
||||
@@ -337,15 +337,15 @@ class FileTest(MongoDBTestCase):
|
||||
|
||||
files = db.fs.files.find()
|
||||
chunks = db.fs.chunks.find()
|
||||
self.assertEquals(len(list(files)), 0)
|
||||
self.assertEquals(len(list(chunks)), 0)
|
||||
self.assertEqual(len(list(files)), 0)
|
||||
self.assertEqual(len(list(chunks)), 0)
|
||||
|
||||
testfile.delete()
|
||||
|
||||
files = db.fs.files.find()
|
||||
chunks = db.fs.chunks.find()
|
||||
self.assertEquals(len(list(files)), 0)
|
||||
self.assertEquals(len(list(chunks)), 0)
|
||||
self.assertEqual(len(list(files)), 0)
|
||||
self.assertEqual(len(list(chunks)), 0)
|
||||
|
||||
# Test case where we overwrite the file
|
||||
testfile = TestFile()
|
||||
@@ -358,15 +358,15 @@ class FileTest(MongoDBTestCase):
|
||||
|
||||
files = db.fs.files.find()
|
||||
chunks = db.fs.chunks.find()
|
||||
self.assertEquals(len(list(files)), 1)
|
||||
self.assertEquals(len(list(chunks)), 1)
|
||||
self.assertEqual(len(list(files)), 1)
|
||||
self.assertEqual(len(list(chunks)), 1)
|
||||
|
||||
testfile.delete()
|
||||
|
||||
files = db.fs.files.find()
|
||||
chunks = db.fs.chunks.find()
|
||||
self.assertEquals(len(list(files)), 0)
|
||||
self.assertEquals(len(list(chunks)), 0)
|
||||
self.assertEqual(len(list(files)), 0)
|
||||
self.assertEqual(len(list(chunks)), 0)
|
||||
|
||||
def test_image_field(self):
|
||||
if not HAS_PIL:
|
||||
|
@@ -40,6 +40,11 @@ class GeoFieldTest(unittest.TestCase):
|
||||
expected = "Both values (%s) in point must be float or int" % repr(coord)
|
||||
self._test_for_expected_error(Location, coord, expected)
|
||||
|
||||
invalid_coords = [21, 4, 'a']
|
||||
for coord in invalid_coords:
|
||||
expected = "GeoPointField can only accept tuples or lists of (x, y)"
|
||||
self._test_for_expected_error(Location, coord, expected)
|
||||
|
||||
def test_point_validation(self):
|
||||
class Location(Document):
|
||||
loc = PointField()
|
||||
|
@@ -208,10 +208,7 @@ class TestCachedReferenceField(MongoDBTestCase):
|
||||
('pj', "PJ")
|
||||
)
|
||||
name = StringField()
|
||||
tp = StringField(
|
||||
choices=TYPES
|
||||
)
|
||||
|
||||
tp = StringField(choices=TYPES)
|
||||
father = CachedReferenceField('self', fields=('tp',))
|
||||
|
||||
Person.drop_collection()
|
||||
@@ -222,6 +219,9 @@ class TestCachedReferenceField(MongoDBTestCase):
|
||||
a2 = Person(name='Wilson Junior', tp='pf', father=a1)
|
||||
a2.save()
|
||||
|
||||
a2 = Person.objects.with_id(a2.id)
|
||||
self.assertEqual(a2.father.tp, a1.tp)
|
||||
|
||||
self.assertEqual(dict(a2.to_mongo()), {
|
||||
"_id": a2.pk,
|
||||
"name": u"Wilson Junior",
|
||||
@@ -374,6 +374,9 @@ class TestCachedReferenceField(MongoDBTestCase):
|
||||
self.assertEqual(o.to_mongo()['animal']['tag'], 'heavy')
|
||||
self.assertEqual(o.to_mongo()['animal']['owner']['t'], 'u')
|
||||
|
||||
# Check to_mongo with fields
|
||||
self.assertNotIn('animal', o.to_mongo(fields=['person']))
|
||||
|
||||
# counts
|
||||
Ocorrence(person="teste 2").save()
|
||||
Ocorrence(person="teste 3").save()
|
||||
|
@@ -1,5 +1,5 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import datetime
|
||||
import datetime as dt
|
||||
import six
|
||||
|
||||
try:
|
||||
@@ -41,13 +41,13 @@ class TestDateTimeField(MongoDBTestCase):
|
||||
a document.
|
||||
"""
|
||||
class Person(Document):
|
||||
created = DateTimeField(default=datetime.datetime.utcnow)
|
||||
created = DateTimeField(default=dt.datetime.utcnow)
|
||||
|
||||
utcnow = datetime.datetime.utcnow()
|
||||
utcnow = dt.datetime.utcnow()
|
||||
person = Person()
|
||||
person.validate()
|
||||
person_created_t0 = person.created
|
||||
self.assertLess(person.created - utcnow, datetime.timedelta(seconds=1))
|
||||
self.assertLess(person.created - utcnow, dt.timedelta(seconds=1))
|
||||
self.assertEqual(person_created_t0, person.created) # make sure it does not change
|
||||
self.assertEqual(person._data['created'], person.created)
|
||||
|
||||
@@ -65,15 +65,15 @@ class TestDateTimeField(MongoDBTestCase):
|
||||
|
||||
# Test can save dates
|
||||
log = LogEntry()
|
||||
log.date = datetime.date.today()
|
||||
log.date = dt.date.today()
|
||||
log.save()
|
||||
log.reload()
|
||||
self.assertEqual(log.date.date(), datetime.date.today())
|
||||
self.assertEqual(log.date.date(), dt.date.today())
|
||||
|
||||
# Post UTC - microseconds are rounded (down) nearest millisecond and
|
||||
# dropped
|
||||
d1 = datetime.datetime(1970, 1, 1, 0, 0, 1, 999)
|
||||
d2 = datetime.datetime(1970, 1, 1, 0, 0, 1)
|
||||
d1 = dt.datetime(1970, 1, 1, 0, 0, 1, 999)
|
||||
d2 = dt.datetime(1970, 1, 1, 0, 0, 1)
|
||||
log = LogEntry()
|
||||
log.date = d1
|
||||
log.save()
|
||||
@@ -82,8 +82,8 @@ class TestDateTimeField(MongoDBTestCase):
|
||||
self.assertEqual(log.date, d2)
|
||||
|
||||
# Post UTC - microseconds are rounded (down) nearest millisecond
|
||||
d1 = datetime.datetime(1970, 1, 1, 0, 0, 1, 9999)
|
||||
d2 = datetime.datetime(1970, 1, 1, 0, 0, 1, 9000)
|
||||
d1 = dt.datetime(1970, 1, 1, 0, 0, 1, 9999)
|
||||
d2 = dt.datetime(1970, 1, 1, 0, 0, 1, 9000)
|
||||
log.date = d1
|
||||
log.save()
|
||||
log.reload()
|
||||
@@ -93,8 +93,8 @@ class TestDateTimeField(MongoDBTestCase):
|
||||
if not six.PY3:
|
||||
# Pre UTC dates microseconds below 1000 are dropped
|
||||
# This does not seem to be true in PY3
|
||||
d1 = datetime.datetime(1969, 12, 31, 23, 59, 59, 999)
|
||||
d2 = datetime.datetime(1969, 12, 31, 23, 59, 59)
|
||||
d1 = dt.datetime(1969, 12, 31, 23, 59, 59, 999)
|
||||
d2 = dt.datetime(1969, 12, 31, 23, 59, 59)
|
||||
log.date = d1
|
||||
log.save()
|
||||
log.reload()
|
||||
@@ -108,7 +108,7 @@ class TestDateTimeField(MongoDBTestCase):
|
||||
|
||||
LogEntry.drop_collection()
|
||||
|
||||
d1 = datetime.datetime(1970, 1, 1, 0, 0, 1)
|
||||
d1 = dt.datetime(1970, 1, 1, 0, 0, 1)
|
||||
log = LogEntry()
|
||||
log.date = d1
|
||||
log.validate()
|
||||
@@ -124,7 +124,7 @@ class TestDateTimeField(MongoDBTestCase):
|
||||
|
||||
# create additional 19 log entries for a total of 20
|
||||
for i in range(1971, 1990):
|
||||
d = datetime.datetime(i, 1, 1, 0, 0, 1)
|
||||
d = dt.datetime(i, 1, 1, 0, 0, 1)
|
||||
LogEntry(date=d).save()
|
||||
|
||||
self.assertEqual(LogEntry.objects.count(), 20)
|
||||
@@ -143,15 +143,15 @@ class TestDateTimeField(MongoDBTestCase):
|
||||
i += 1
|
||||
|
||||
# Test searching
|
||||
logs = LogEntry.objects.filter(date__gte=datetime.datetime(1980, 1, 1))
|
||||
logs = LogEntry.objects.filter(date__gte=dt.datetime(1980, 1, 1))
|
||||
self.assertEqual(logs.count(), 10)
|
||||
|
||||
logs = LogEntry.objects.filter(date__lte=datetime.datetime(1980, 1, 1))
|
||||
logs = LogEntry.objects.filter(date__lte=dt.datetime(1980, 1, 1))
|
||||
self.assertEqual(logs.count(), 10)
|
||||
|
||||
logs = LogEntry.objects.filter(
|
||||
date__lte=datetime.datetime(1980, 1, 1),
|
||||
date__gte=datetime.datetime(1975, 1, 1),
|
||||
date__lte=dt.datetime(1980, 1, 1),
|
||||
date__gte=dt.datetime(1975, 1, 1),
|
||||
)
|
||||
self.assertEqual(logs.count(), 5)
|
||||
|
||||
@@ -163,23 +163,51 @@ class TestDateTimeField(MongoDBTestCase):
|
||||
time = DateTimeField()
|
||||
|
||||
log = LogEntry()
|
||||
log.time = datetime.datetime.now()
|
||||
log.time = dt.datetime.now()
|
||||
log.validate()
|
||||
|
||||
log.time = datetime.date.today()
|
||||
log.time = dt.date.today()
|
||||
log.validate()
|
||||
|
||||
log.time = datetime.datetime.now().isoformat(' ')
|
||||
log.time = dt.datetime.now().isoformat(' ')
|
||||
log.validate()
|
||||
|
||||
log.time = '2019-05-16 21:42:57.897847'
|
||||
log.validate()
|
||||
|
||||
if dateutil:
|
||||
log.time = datetime.datetime.now().isoformat('T')
|
||||
log.time = dt.datetime.now().isoformat('T')
|
||||
log.validate()
|
||||
|
||||
log.time = -1
|
||||
self.assertRaises(ValidationError, log.validate)
|
||||
log.time = 'ABC'
|
||||
self.assertRaises(ValidationError, log.validate)
|
||||
log.time = '2019-05-16 21:GARBAGE:12'
|
||||
self.assertRaises(ValidationError, log.validate)
|
||||
log.time = '2019-05-16 21:42:57.GARBAGE'
|
||||
self.assertRaises(ValidationError, log.validate)
|
||||
log.time = '2019-05-16 21:42:57.123.456'
|
||||
self.assertRaises(ValidationError, log.validate)
|
||||
|
||||
def test_parse_datetime_as_str(self):
|
||||
class DTDoc(Document):
|
||||
date = DateTimeField()
|
||||
|
||||
date_str = '2019-03-02 22:26:01'
|
||||
|
||||
# make sure that passing a parsable datetime works
|
||||
dtd = DTDoc()
|
||||
dtd.date = date_str
|
||||
self.assertIsInstance(dtd.date, six.string_types)
|
||||
dtd.save()
|
||||
dtd.reload()
|
||||
|
||||
self.assertIsInstance(dtd.date, dt.datetime)
|
||||
self.assertEqual(str(dtd.date), date_str)
|
||||
|
||||
dtd.date = 'January 1st, 9999999999'
|
||||
self.assertRaises(ValidationError, dtd.validate)
|
||||
|
||||
|
||||
class TestDateTimeTzAware(MongoDBTestCase):
|
||||
@@ -196,8 +224,8 @@ class TestDateTimeTzAware(MongoDBTestCase):
|
||||
|
||||
LogEntry.drop_collection()
|
||||
|
||||
LogEntry(time=datetime.datetime(2013, 1, 1, 0, 0, 0)).save()
|
||||
LogEntry(time=dt.datetime(2013, 1, 1, 0, 0, 0)).save()
|
||||
|
||||
log = LogEntry.objects.first()
|
||||
log.time = datetime.datetime(2013, 1, 1, 0, 0, 0)
|
||||
log.time = dt.datetime(2013, 1, 1, 0, 0, 0)
|
||||
self.assertEqual(['time'], log._changed_fields)
|
||||
|
@@ -75,6 +75,16 @@ class TestEmailField(MongoDBTestCase):
|
||||
user = User(email='me@localhost')
|
||||
user.validate()
|
||||
|
||||
def test_email_domain_validation_fails_if_invalid_idn(self):
|
||||
class User(Document):
|
||||
email = EmailField()
|
||||
|
||||
invalid_idn = '.google.com'
|
||||
user = User(email='me@%s' % invalid_idn)
|
||||
with self.assertRaises(ValidationError) as ctx_err:
|
||||
user.validate()
|
||||
self.assertIn("domain failed IDN encoding", str(ctx_err.exception))
|
||||
|
||||
def test_email_field_ip_domain(self):
|
||||
class User(Document):
|
||||
email = EmailField()
|
||||
|
344
tests/fields/test_embedded_document_field.py
Normal file
344
tests/fields/test_embedded_document_field.py
Normal file
@@ -0,0 +1,344 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from mongoengine import Document, StringField, ValidationError, EmbeddedDocument, EmbeddedDocumentField, \
|
||||
InvalidQueryError, LookUpError, IntField, GenericEmbeddedDocumentField, ListField, EmbeddedDocumentListField, \
|
||||
ReferenceField
|
||||
|
||||
from tests.utils import MongoDBTestCase
|
||||
|
||||
|
||||
class TestEmbeddedDocumentField(MongoDBTestCase):
|
||||
def test___init___(self):
|
||||
class MyDoc(EmbeddedDocument):
|
||||
name = StringField()
|
||||
|
||||
field = EmbeddedDocumentField(MyDoc)
|
||||
self.assertEqual(field.document_type_obj, MyDoc)
|
||||
|
||||
field2 = EmbeddedDocumentField('MyDoc')
|
||||
self.assertEqual(field2.document_type_obj, 'MyDoc')
|
||||
|
||||
def test___init___throw_error_if_document_type_is_not_EmbeddedDocument(self):
|
||||
with self.assertRaises(ValidationError):
|
||||
EmbeddedDocumentField(dict)
|
||||
|
||||
def test_document_type_throw_error_if_not_EmbeddedDocument_subclass(self):
|
||||
|
||||
class MyDoc(Document):
|
||||
name = StringField()
|
||||
|
||||
emb = EmbeddedDocumentField('MyDoc')
|
||||
with self.assertRaises(ValidationError) as ctx:
|
||||
emb.document_type
|
||||
self.assertIn('Invalid embedded document class provided to an EmbeddedDocumentField', str(ctx.exception))
|
||||
|
||||
def test_embedded_document_field_only_allow_subclasses_of_embedded_document(self):
|
||||
# Relates to #1661
|
||||
class MyDoc(Document):
|
||||
name = StringField()
|
||||
|
||||
with self.assertRaises(ValidationError):
|
||||
class MyFailingDoc(Document):
|
||||
emb = EmbeddedDocumentField(MyDoc)
|
||||
|
||||
with self.assertRaises(ValidationError):
|
||||
class MyFailingdoc2(Document):
|
||||
emb = EmbeddedDocumentField('MyDoc')
|
||||
|
||||
def test_query_embedded_document_attribute(self):
|
||||
class AdminSettings(EmbeddedDocument):
|
||||
foo1 = StringField()
|
||||
foo2 = StringField()
|
||||
|
||||
class Person(Document):
|
||||
settings = EmbeddedDocumentField(AdminSettings)
|
||||
name = StringField()
|
||||
|
||||
Person.drop_collection()
|
||||
|
||||
p = Person(
|
||||
settings=AdminSettings(foo1='bar1', foo2='bar2'),
|
||||
name='John',
|
||||
).save()
|
||||
|
||||
# Test non exiting attribute
|
||||
with self.assertRaises(InvalidQueryError) as ctx_err:
|
||||
Person.objects(settings__notexist='bar').first()
|
||||
self.assertEqual(unicode(ctx_err.exception), u'Cannot resolve field "notexist"')
|
||||
|
||||
with self.assertRaises(LookUpError):
|
||||
Person.objects.only('settings.notexist')
|
||||
|
||||
# Test existing attribute
|
||||
self.assertEqual(Person.objects(settings__foo1='bar1').first().id, p.id)
|
||||
only_p = Person.objects.only('settings.foo1').first()
|
||||
self.assertEqual(only_p.settings.foo1, p.settings.foo1)
|
||||
self.assertIsNone(only_p.settings.foo2)
|
||||
self.assertIsNone(only_p.name)
|
||||
|
||||
exclude_p = Person.objects.exclude('settings.foo1').first()
|
||||
self.assertIsNone(exclude_p.settings.foo1)
|
||||
self.assertEqual(exclude_p.settings.foo2, p.settings.foo2)
|
||||
self.assertEqual(exclude_p.name, p.name)
|
||||
|
||||
def test_query_embedded_document_attribute_with_inheritance(self):
|
||||
class BaseSettings(EmbeddedDocument):
|
||||
meta = {'allow_inheritance': True}
|
||||
base_foo = StringField()
|
||||
|
||||
class AdminSettings(BaseSettings):
|
||||
sub_foo = StringField()
|
||||
|
||||
class Person(Document):
|
||||
settings = EmbeddedDocumentField(BaseSettings)
|
||||
|
||||
Person.drop_collection()
|
||||
|
||||
p = Person(settings=AdminSettings(base_foo='basefoo', sub_foo='subfoo'))
|
||||
p.save()
|
||||
|
||||
# Test non exiting attribute
|
||||
with self.assertRaises(InvalidQueryError) as ctx_err:
|
||||
self.assertEqual(Person.objects(settings__notexist='bar').first().id, p.id)
|
||||
self.assertEqual(unicode(ctx_err.exception), u'Cannot resolve field "notexist"')
|
||||
|
||||
# Test existing attribute
|
||||
self.assertEqual(Person.objects(settings__base_foo='basefoo').first().id, p.id)
|
||||
self.assertEqual(Person.objects(settings__sub_foo='subfoo').first().id, p.id)
|
||||
|
||||
only_p = Person.objects.only('settings.base_foo', 'settings._cls').first()
|
||||
self.assertEqual(only_p.settings.base_foo, 'basefoo')
|
||||
self.assertIsNone(only_p.settings.sub_foo)
|
||||
|
||||
def test_query_list_embedded_document_with_inheritance(self):
|
||||
class Post(EmbeddedDocument):
|
||||
title = StringField(max_length=120, required=True)
|
||||
meta = {'allow_inheritance': True}
|
||||
|
||||
class TextPost(Post):
|
||||
content = StringField()
|
||||
|
||||
class MoviePost(Post):
|
||||
author = StringField()
|
||||
|
||||
class Record(Document):
|
||||
posts = ListField(EmbeddedDocumentField(Post))
|
||||
|
||||
record_movie = Record(posts=[MoviePost(author='John', title='foo')]).save()
|
||||
record_text = Record(posts=[TextPost(content='a', title='foo')]).save()
|
||||
|
||||
records = list(Record.objects(posts__author=record_movie.posts[0].author))
|
||||
self.assertEqual(len(records), 1)
|
||||
self.assertEqual(records[0].id, record_movie.id)
|
||||
|
||||
records = list(Record.objects(posts__content=record_text.posts[0].content))
|
||||
self.assertEqual(len(records), 1)
|
||||
self.assertEqual(records[0].id, record_text.id)
|
||||
|
||||
self.assertEqual(Record.objects(posts__title='foo').count(), 2)
|
||||
|
||||
|
||||
class TestGenericEmbeddedDocumentField(MongoDBTestCase):
|
||||
|
||||
def test_generic_embedded_document(self):
|
||||
class Car(EmbeddedDocument):
|
||||
name = StringField()
|
||||
|
||||
class Dish(EmbeddedDocument):
|
||||
food = StringField(required=True)
|
||||
number = IntField()
|
||||
|
||||
class Person(Document):
|
||||
name = StringField()
|
||||
like = GenericEmbeddedDocumentField()
|
||||
|
||||
Person.drop_collection()
|
||||
|
||||
person = Person(name='Test User')
|
||||
person.like = Car(name='Fiat')
|
||||
person.save()
|
||||
|
||||
person = Person.objects.first()
|
||||
self.assertIsInstance(person.like, Car)
|
||||
|
||||
person.like = Dish(food="arroz", number=15)
|
||||
person.save()
|
||||
|
||||
person = Person.objects.first()
|
||||
self.assertIsInstance(person.like, Dish)
|
||||
|
||||
def test_generic_embedded_document_choices(self):
|
||||
"""Ensure you can limit GenericEmbeddedDocument choices."""
|
||||
class Car(EmbeddedDocument):
|
||||
name = StringField()
|
||||
|
||||
class Dish(EmbeddedDocument):
|
||||
food = StringField(required=True)
|
||||
number = IntField()
|
||||
|
||||
class Person(Document):
|
||||
name = StringField()
|
||||
like = GenericEmbeddedDocumentField(choices=(Dish,))
|
||||
|
||||
Person.drop_collection()
|
||||
|
||||
person = Person(name='Test User')
|
||||
person.like = Car(name='Fiat')
|
||||
self.assertRaises(ValidationError, person.validate)
|
||||
|
||||
person.like = Dish(food="arroz", number=15)
|
||||
person.save()
|
||||
|
||||
person = Person.objects.first()
|
||||
self.assertIsInstance(person.like, Dish)
|
||||
|
||||
def test_generic_list_embedded_document_choices(self):
|
||||
"""Ensure you can limit GenericEmbeddedDocument choices inside
|
||||
a list field.
|
||||
"""
|
||||
class Car(EmbeddedDocument):
|
||||
name = StringField()
|
||||
|
||||
class Dish(EmbeddedDocument):
|
||||
food = StringField(required=True)
|
||||
number = IntField()
|
||||
|
||||
class Person(Document):
|
||||
name = StringField()
|
||||
likes = ListField(GenericEmbeddedDocumentField(choices=(Dish,)))
|
||||
|
||||
Person.drop_collection()
|
||||
|
||||
person = Person(name='Test User')
|
||||
person.likes = [Car(name='Fiat')]
|
||||
self.assertRaises(ValidationError, person.validate)
|
||||
|
||||
person.likes = [Dish(food="arroz", number=15)]
|
||||
person.save()
|
||||
|
||||
person = Person.objects.first()
|
||||
self.assertIsInstance(person.likes[0], Dish)
|
||||
|
||||
def test_choices_validation_documents(self):
|
||||
"""
|
||||
Ensure fields with document choices validate given a valid choice.
|
||||
"""
|
||||
class UserComments(EmbeddedDocument):
|
||||
author = StringField()
|
||||
message = StringField()
|
||||
|
||||
class BlogPost(Document):
|
||||
comments = ListField(
|
||||
GenericEmbeddedDocumentField(choices=(UserComments,))
|
||||
)
|
||||
|
||||
# Ensure Validation Passes
|
||||
BlogPost(comments=[
|
||||
UserComments(author='user2', message='message2'),
|
||||
]).save()
|
||||
|
||||
def test_choices_validation_documents_invalid(self):
|
||||
"""
|
||||
Ensure fields with document choices validate given an invalid choice.
|
||||
This should throw a ValidationError exception.
|
||||
"""
|
||||
class UserComments(EmbeddedDocument):
|
||||
author = StringField()
|
||||
message = StringField()
|
||||
|
||||
class ModeratorComments(EmbeddedDocument):
|
||||
author = StringField()
|
||||
message = StringField()
|
||||
|
||||
class BlogPost(Document):
|
||||
comments = ListField(
|
||||
GenericEmbeddedDocumentField(choices=(UserComments,))
|
||||
)
|
||||
|
||||
# Single Entry Failure
|
||||
post = BlogPost(comments=[
|
||||
ModeratorComments(author='mod1', message='message1'),
|
||||
])
|
||||
self.assertRaises(ValidationError, post.save)
|
||||
|
||||
# Mixed Entry Failure
|
||||
post = BlogPost(comments=[
|
||||
ModeratorComments(author='mod1', message='message1'),
|
||||
UserComments(author='user2', message='message2'),
|
||||
])
|
||||
self.assertRaises(ValidationError, post.save)
|
||||
|
||||
def test_choices_validation_documents_inheritance(self):
|
||||
"""
|
||||
Ensure fields with document choices validate given subclass of choice.
|
||||
"""
|
||||
class Comments(EmbeddedDocument):
|
||||
meta = {
|
||||
'abstract': True
|
||||
}
|
||||
author = StringField()
|
||||
message = StringField()
|
||||
|
||||
class UserComments(Comments):
|
||||
pass
|
||||
|
||||
class BlogPost(Document):
|
||||
comments = ListField(
|
||||
GenericEmbeddedDocumentField(choices=(Comments,))
|
||||
)
|
||||
|
||||
# Save Valid EmbeddedDocument Type
|
||||
BlogPost(comments=[
|
||||
UserComments(author='user2', message='message2'),
|
||||
]).save()
|
||||
|
||||
def test_query_generic_embedded_document_attribute(self):
|
||||
class AdminSettings(EmbeddedDocument):
|
||||
foo1 = StringField()
|
||||
|
||||
class NonAdminSettings(EmbeddedDocument):
|
||||
foo2 = StringField()
|
||||
|
||||
class Person(Document):
|
||||
settings = GenericEmbeddedDocumentField(choices=(AdminSettings, NonAdminSettings))
|
||||
|
||||
Person.drop_collection()
|
||||
|
||||
p1 = Person(settings=AdminSettings(foo1='bar1')).save()
|
||||
p2 = Person(settings=NonAdminSettings(foo2='bar2')).save()
|
||||
|
||||
# Test non exiting attribute
|
||||
with self.assertRaises(InvalidQueryError) as ctx_err:
|
||||
Person.objects(settings__notexist='bar').first()
|
||||
self.assertEqual(unicode(ctx_err.exception), u'Cannot resolve field "notexist"')
|
||||
|
||||
with self.assertRaises(LookUpError):
|
||||
Person.objects.only('settings.notexist')
|
||||
|
||||
# Test existing attribute
|
||||
self.assertEqual(Person.objects(settings__foo1='bar1').first().id, p1.id)
|
||||
self.assertEqual(Person.objects(settings__foo2='bar2').first().id, p2.id)
|
||||
|
||||
def test_query_generic_embedded_document_attribute_with_inheritance(self):
|
||||
class BaseSettings(EmbeddedDocument):
|
||||
meta = {'allow_inheritance': True}
|
||||
base_foo = StringField()
|
||||
|
||||
class AdminSettings(BaseSettings):
|
||||
sub_foo = StringField()
|
||||
|
||||
class Person(Document):
|
||||
settings = GenericEmbeddedDocumentField(choices=[BaseSettings])
|
||||
|
||||
Person.drop_collection()
|
||||
|
||||
p = Person(settings=AdminSettings(base_foo='basefoo', sub_foo='subfoo'))
|
||||
p.save()
|
||||
|
||||
# Test non exiting attribute
|
||||
with self.assertRaises(InvalidQueryError) as ctx_err:
|
||||
self.assertEqual(Person.objects(settings__notexist='bar').first().id, p.id)
|
||||
self.assertEqual(unicode(ctx_err.exception), u'Cannot resolve field "notexist"')
|
||||
|
||||
# Test existing attribute
|
||||
self.assertEqual(Person.objects(settings__base_foo='basefoo').first().id, p.id)
|
||||
self.assertEqual(Person.objects(settings__sub_foo='subfoo').first().id, p.id)
|
@@ -13,6 +13,35 @@ class TestLazyReferenceField(MongoDBTestCase):
|
||||
# with a document class name.
|
||||
self.assertRaises(ValidationError, LazyReferenceField, EmbeddedDocument)
|
||||
|
||||
def test___repr__(self):
|
||||
class Animal(Document):
|
||||
pass
|
||||
|
||||
class Ocurrence(Document):
|
||||
animal = LazyReferenceField(Animal)
|
||||
|
||||
Animal.drop_collection()
|
||||
Ocurrence.drop_collection()
|
||||
|
||||
animal = Animal()
|
||||
oc = Ocurrence(animal=animal)
|
||||
self.assertIn('LazyReference', repr(oc.animal))
|
||||
|
||||
def test___getattr___unknown_attr_raises_attribute_error(self):
|
||||
class Animal(Document):
|
||||
pass
|
||||
|
||||
class Ocurrence(Document):
|
||||
animal = LazyReferenceField(Animal)
|
||||
|
||||
Animal.drop_collection()
|
||||
Ocurrence.drop_collection()
|
||||
|
||||
animal = Animal().save()
|
||||
oc = Ocurrence(animal=animal)
|
||||
with self.assertRaises(AttributeError):
|
||||
oc.animal.not_exist
|
||||
|
||||
def test_lazy_reference_simple(self):
|
||||
class Animal(Document):
|
||||
name = StringField()
|
||||
@@ -479,6 +508,23 @@ class TestGenericLazyReferenceField(MongoDBTestCase):
|
||||
p = Ocurrence.objects.get()
|
||||
self.assertIs(p.animal, None)
|
||||
|
||||
def test_generic_lazy_reference_accepts_string_instead_of_class(self):
|
||||
class Animal(Document):
|
||||
name = StringField()
|
||||
tag = StringField()
|
||||
|
||||
class Ocurrence(Document):
|
||||
person = StringField()
|
||||
animal = GenericLazyReferenceField('Animal')
|
||||
|
||||
Animal.drop_collection()
|
||||
Ocurrence.drop_collection()
|
||||
|
||||
animal = Animal().save()
|
||||
Ocurrence(animal=animal).save()
|
||||
p = Ocurrence.objects.get()
|
||||
self.assertEqual(p.animal, animal)
|
||||
|
||||
def test_generic_lazy_reference_embedded(self):
|
||||
class Animal(Document):
|
||||
name = StringField()
|
||||
|
@@ -39,9 +39,9 @@ class TestLongField(MongoDBTestCase):
|
||||
|
||||
doc.value = -1
|
||||
self.assertRaises(ValidationError, doc.validate)
|
||||
doc.age = 120
|
||||
doc.value = 120
|
||||
self.assertRaises(ValidationError, doc.validate)
|
||||
doc.age = 'ten'
|
||||
doc.value = 'ten'
|
||||
self.assertRaises(ValidationError, doc.validate)
|
||||
|
||||
def test_long_ne_operator(self):
|
||||
|
@@ -3,7 +3,7 @@ import unittest
|
||||
|
||||
from mongoengine import *
|
||||
|
||||
from tests.utils import MongoDBTestCase, requires_mongodb_gte_3
|
||||
from tests.utils import MongoDBTestCase
|
||||
|
||||
|
||||
__all__ = ("GeoQueriesTest",)
|
||||
@@ -70,9 +70,6 @@ class GeoQueriesTest(MongoDBTestCase):
|
||||
self.assertEqual(events.count(), 1)
|
||||
self.assertEqual(events[0], event2)
|
||||
|
||||
# $minDistance was added in MongoDB v2.6, but continued being buggy
|
||||
# until v3.0; skip for older versions
|
||||
@requires_mongodb_gte_3
|
||||
def test_near_and_min_distance(self):
|
||||
"""Ensure the "min_distance" operator works alongside the "near"
|
||||
operator.
|
||||
@@ -243,9 +240,6 @@ class GeoQueriesTest(MongoDBTestCase):
|
||||
events = self.Event.objects(location__geo_within_polygon=polygon2)
|
||||
self.assertEqual(events.count(), 0)
|
||||
|
||||
# $minDistance was added in MongoDB v2.6, but continued being buggy
|
||||
# until v3.0; skip for older versions
|
||||
@requires_mongodb_gte_3
|
||||
def test_2dsphere_near_and_min_max_distance(self):
|
||||
"""Ensure "min_distace" and "max_distance" operators work well
|
||||
together with the "near" operator in a 2dsphere index.
|
||||
@@ -328,8 +322,6 @@ class GeoQueriesTest(MongoDBTestCase):
|
||||
"""Make sure PointField works properly in an embedded document."""
|
||||
self._test_embedded(point_field_class=PointField)
|
||||
|
||||
# Needs MongoDB > 2.6.4 https://jira.mongodb.org/browse/SERVER-14039
|
||||
@requires_mongodb_gte_3
|
||||
def test_spherical_geospatial_operators(self):
|
||||
"""Ensure that spherical geospatial queries are working."""
|
||||
class Point(Document):
|
||||
|
@@ -2,8 +2,6 @@ import unittest
|
||||
|
||||
from mongoengine import connect, Document, IntField, StringField, ListField
|
||||
|
||||
from tests.utils import requires_mongodb_gte_26
|
||||
|
||||
__all__ = ("FindAndModifyTest",)
|
||||
|
||||
|
||||
@@ -96,7 +94,6 @@ class FindAndModifyTest(unittest.TestCase):
|
||||
self.assertEqual(old_doc.to_mongo(), {"_id": 1})
|
||||
self.assertDbEqual([{"_id": 0, "value": 0}, {"_id": 1, "value": -1}])
|
||||
|
||||
@requires_mongodb_gte_26
|
||||
def test_modify_with_push(self):
|
||||
class BlogPost(Document):
|
||||
tags = ListField(StringField())
|
||||
|
@@ -6,7 +6,6 @@ import uuid
|
||||
from decimal import Decimal
|
||||
|
||||
from bson import DBRef, ObjectId
|
||||
from nose.plugins.skip import SkipTest
|
||||
import pymongo
|
||||
from pymongo.errors import ConfigurationError
|
||||
from pymongo.read_preferences import ReadPreference
|
||||
@@ -18,11 +17,9 @@ from mongoengine import *
|
||||
from mongoengine.connection import get_connection, get_db
|
||||
from mongoengine.context_managers import query_counter, switch_db
|
||||
from mongoengine.errors import InvalidQueryError
|
||||
from mongoengine.mongodb_support import get_mongodb_version, MONGODB_32
|
||||
from mongoengine.pymongo_support import IS_PYMONGO_3
|
||||
from mongoengine.mongodb_support import get_mongodb_version, MONGODB_36
|
||||
from mongoengine.queryset import (DoesNotExist, MultipleObjectsReturned,
|
||||
QuerySet, QuerySetManager, queryset_manager)
|
||||
from tests.utils import requires_mongodb_gte_26, skip_pymongo3
|
||||
|
||||
|
||||
class db_ops_tracker(query_counter):
|
||||
@@ -33,6 +30,12 @@ class db_ops_tracker(query_counter):
|
||||
return list(self.db.system.profile.find(ignore_query))
|
||||
|
||||
|
||||
def get_key_compat(mongo_ver):
|
||||
ORDER_BY_KEY = 'sort'
|
||||
CMD_QUERY_KEY = 'command' if mongo_ver >= MONGODB_36 else 'query'
|
||||
return ORDER_BY_KEY, CMD_QUERY_KEY
|
||||
|
||||
|
||||
class QuerySetTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
@@ -87,7 +90,7 @@ class QuerySetTest(unittest.TestCase):
|
||||
results = list(people)
|
||||
|
||||
self.assertIsInstance(results[0], self.Person)
|
||||
self.assertIsInstance(results[0].id, (ObjectId, str, unicode))
|
||||
self.assertIsInstance(results[0].id, ObjectId)
|
||||
|
||||
self.assertEqual(results[0], user_a)
|
||||
self.assertEqual(results[0].name, 'User A')
|
||||
@@ -158,6 +161,11 @@ class QuerySetTest(unittest.TestCase):
|
||||
self.assertEqual(person.name, 'User B')
|
||||
self.assertEqual(person.age, None)
|
||||
|
||||
def test___getitem___invalid_index(self):
|
||||
"""Ensure slicing a queryset works as expected."""
|
||||
with self.assertRaises(TypeError):
|
||||
self.Person.objects()['a']
|
||||
|
||||
def test_slice(self):
|
||||
"""Ensure slicing a queryset works as expected."""
|
||||
user_a = self.Person.objects.create(name='User A', age=20)
|
||||
@@ -589,7 +597,6 @@ class QuerySetTest(unittest.TestCase):
|
||||
self.assertEqual(post.comments[0].by, 'joe')
|
||||
self.assertEqual(post.comments[0].votes.score, 4)
|
||||
|
||||
@requires_mongodb_gte_26
|
||||
def test_update_min_max(self):
|
||||
class Scores(Document):
|
||||
high_score = IntField()
|
||||
@@ -607,7 +614,6 @@ class QuerySetTest(unittest.TestCase):
|
||||
Scores.objects(id=scores.id).update(max__high_score=500)
|
||||
self.assertEqual(Scores.objects.get(id=scores.id).high_score, 1000)
|
||||
|
||||
@requires_mongodb_gte_26
|
||||
def test_update_multiple(self):
|
||||
class Product(Document):
|
||||
item = StringField()
|
||||
@@ -859,11 +865,7 @@ class QuerySetTest(unittest.TestCase):
|
||||
with query_counter() as q:
|
||||
self.assertEqual(q, 0)
|
||||
Blog.objects.insert(blogs, load_bulk=False)
|
||||
|
||||
if MONGO_VER >= MONGODB_32:
|
||||
self.assertEqual(q, 1) # 1 entry containing the list of inserts
|
||||
else:
|
||||
self.assertEqual(q, len(blogs)) # 1 entry per doc inserted
|
||||
self.assertEqual(q, 1) # 1 entry containing the list of inserts
|
||||
|
||||
self.assertEqual(Blog.objects.count(), len(blogs))
|
||||
|
||||
@@ -876,11 +878,7 @@ class QuerySetTest(unittest.TestCase):
|
||||
with query_counter() as q:
|
||||
self.assertEqual(q, 0)
|
||||
Blog.objects.insert(blogs)
|
||||
|
||||
if MONGO_VER >= MONGODB_32:
|
||||
self.assertEqual(q, 2) # 1 for insert 1 for fetch
|
||||
else:
|
||||
self.assertEqual(q, len(blogs)+1) # + 1 to fetch all docs
|
||||
self.assertEqual(q, 2) # 1 for insert 1 for fetch
|
||||
|
||||
Blog.drop_collection()
|
||||
|
||||
@@ -900,13 +898,19 @@ class QuerySetTest(unittest.TestCase):
|
||||
with self.assertRaises(OperationError) as cm:
|
||||
blog = Blog.objects.first()
|
||||
Blog.objects.insert(blog)
|
||||
self.assertEqual(str(cm.exception), 'Some documents have ObjectIds use doc.update() instead')
|
||||
self.assertEqual(
|
||||
str(cm.exception),
|
||||
'Some documents have ObjectIds, use doc.update() instead'
|
||||
)
|
||||
|
||||
# test inserting a query set
|
||||
with self.assertRaises(OperationError) as cm:
|
||||
blogs_qs = Blog.objects
|
||||
Blog.objects.insert(blogs_qs)
|
||||
self.assertEqual(str(cm.exception), 'Some documents have ObjectIds use doc.update() instead')
|
||||
self.assertEqual(
|
||||
str(cm.exception),
|
||||
'Some documents have ObjectIds, use doc.update() instead'
|
||||
)
|
||||
|
||||
# insert 1 new doc
|
||||
new_post = Blog(title="code123", id=ObjectId())
|
||||
@@ -986,6 +990,29 @@ class QuerySetTest(unittest.TestCase):
|
||||
inserted_comment_id = Comment.objects.insert(comment, load_bulk=False)
|
||||
self.assertEqual(comment.id, inserted_comment_id)
|
||||
|
||||
def test_bulk_insert_accepts_doc_with_ids(self):
|
||||
class Comment(Document):
|
||||
id = IntField(primary_key=True)
|
||||
|
||||
Comment.drop_collection()
|
||||
|
||||
com1 = Comment(id=0)
|
||||
com2 = Comment(id=1)
|
||||
Comment.objects.insert([com1, com2])
|
||||
|
||||
def test_insert_raise_if_duplicate_in_constraint(self):
|
||||
class Comment(Document):
|
||||
id = IntField(primary_key=True)
|
||||
|
||||
Comment.drop_collection()
|
||||
|
||||
com1 = Comment(id=0)
|
||||
|
||||
Comment.objects.insert(com1)
|
||||
|
||||
with self.assertRaises(NotUniqueError):
|
||||
Comment.objects.insert(com1)
|
||||
|
||||
def test_get_changed_fields_query_count(self):
|
||||
"""Make sure we don't perform unnecessary db operations when
|
||||
none of document's fields were updated.
|
||||
@@ -1047,48 +1074,6 @@ class QuerySetTest(unittest.TestCase):
|
||||
org.save() # saves the org
|
||||
self.assertEqual(q, 2)
|
||||
|
||||
@skip_pymongo3
|
||||
def test_slave_okay(self):
|
||||
"""Ensures that a query can take slave_okay syntax.
|
||||
Useless with PyMongo 3+ as well as with MongoDB 3+.
|
||||
"""
|
||||
person1 = self.Person(name="User A", age=20)
|
||||
person1.save()
|
||||
person2 = self.Person(name="User B", age=30)
|
||||
person2.save()
|
||||
|
||||
# Retrieve the first person from the database
|
||||
person = self.Person.objects.slave_okay(True).first()
|
||||
self.assertIsInstance(person, self.Person)
|
||||
self.assertEqual(person.name, "User A")
|
||||
self.assertEqual(person.age, 20)
|
||||
|
||||
@requires_mongodb_gte_26
|
||||
@skip_pymongo3
|
||||
def test_cursor_args(self):
|
||||
"""Ensures the cursor args can be set as expected
|
||||
"""
|
||||
p = self.Person.objects
|
||||
# Check default
|
||||
self.assertEqual(p._cursor_args,
|
||||
{'snapshot': False, 'slave_okay': False, 'timeout': True})
|
||||
|
||||
p = p.snapshot(False).slave_okay(False).timeout(False)
|
||||
self.assertEqual(p._cursor_args,
|
||||
{'snapshot': False, 'slave_okay': False, 'timeout': False})
|
||||
|
||||
p = p.snapshot(True).slave_okay(False).timeout(False)
|
||||
self.assertEqual(p._cursor_args,
|
||||
{'snapshot': True, 'slave_okay': False, 'timeout': False})
|
||||
|
||||
p = p.snapshot(True).slave_okay(True).timeout(False)
|
||||
self.assertEqual(p._cursor_args,
|
||||
{'snapshot': True, 'slave_okay': True, 'timeout': False})
|
||||
|
||||
p = p.snapshot(True).slave_okay(True).timeout(True)
|
||||
self.assertEqual(p._cursor_args,
|
||||
{'snapshot': True, 'slave_okay': True, 'timeout': True})
|
||||
|
||||
def test_repeated_iteration(self):
|
||||
"""Ensure that QuerySet rewinds itself one iteration finishes.
|
||||
"""
|
||||
@@ -1323,8 +1308,7 @@ class QuerySetTest(unittest.TestCase):
|
||||
"""Ensure that the default ordering can be cleared by calling
|
||||
order_by() w/o any arguments.
|
||||
"""
|
||||
MONGO_VER = self.mongodb_version
|
||||
ORDER_BY_KEY = 'sort' if MONGO_VER >= MONGODB_32 else '$orderby'
|
||||
ORDER_BY_KEY, CMD_QUERY_KEY = get_key_compat(self.mongodb_version)
|
||||
|
||||
class BlogPost(Document):
|
||||
title = StringField()
|
||||
@@ -1341,7 +1325,7 @@ class QuerySetTest(unittest.TestCase):
|
||||
BlogPost.objects.filter(title='whatever').first()
|
||||
self.assertEqual(len(q.get_ops()), 1)
|
||||
self.assertEqual(
|
||||
q.get_ops()[0]['query'][ORDER_BY_KEY],
|
||||
q.get_ops()[0][CMD_QUERY_KEY][ORDER_BY_KEY],
|
||||
{'published_date': -1}
|
||||
)
|
||||
|
||||
@@ -1349,14 +1333,14 @@ class QuerySetTest(unittest.TestCase):
|
||||
with db_ops_tracker() as q:
|
||||
BlogPost.objects.filter(title='whatever').order_by().first()
|
||||
self.assertEqual(len(q.get_ops()), 1)
|
||||
self.assertNotIn(ORDER_BY_KEY, q.get_ops()[0]['query'])
|
||||
self.assertNotIn(ORDER_BY_KEY, q.get_ops()[0][CMD_QUERY_KEY])
|
||||
|
||||
# calling an explicit order_by should use a specified sort
|
||||
with db_ops_tracker() as q:
|
||||
BlogPost.objects.filter(title='whatever').order_by('published_date').first()
|
||||
self.assertEqual(len(q.get_ops()), 1)
|
||||
self.assertEqual(
|
||||
q.get_ops()[0]['query'][ORDER_BY_KEY],
|
||||
q.get_ops()[0][CMD_QUERY_KEY][ORDER_BY_KEY],
|
||||
{'published_date': 1}
|
||||
)
|
||||
|
||||
@@ -1365,13 +1349,12 @@ class QuerySetTest(unittest.TestCase):
|
||||
qs = BlogPost.objects.filter(title='whatever').order_by('published_date')
|
||||
qs.order_by().first()
|
||||
self.assertEqual(len(q.get_ops()), 1)
|
||||
self.assertNotIn(ORDER_BY_KEY, q.get_ops()[0]['query'])
|
||||
self.assertNotIn(ORDER_BY_KEY, q.get_ops()[0][CMD_QUERY_KEY])
|
||||
|
||||
def test_no_ordering_for_get(self):
|
||||
""" Ensure that Doc.objects.get doesn't use any ordering.
|
||||
"""
|
||||
MONGO_VER = self.mongodb_version
|
||||
ORDER_BY_KEY = 'sort' if MONGO_VER == MONGODB_32 else '$orderby'
|
||||
ORDER_BY_KEY, CMD_QUERY_KEY = get_key_compat(self.mongodb_version)
|
||||
|
||||
class BlogPost(Document):
|
||||
title = StringField()
|
||||
@@ -1387,13 +1370,13 @@ class QuerySetTest(unittest.TestCase):
|
||||
with db_ops_tracker() as q:
|
||||
BlogPost.objects.get(title='whatever')
|
||||
self.assertEqual(len(q.get_ops()), 1)
|
||||
self.assertNotIn(ORDER_BY_KEY, q.get_ops()[0]['query'])
|
||||
self.assertNotIn(ORDER_BY_KEY, q.get_ops()[0][CMD_QUERY_KEY])
|
||||
|
||||
# Ordering should be ignored for .get even if we set it explicitly
|
||||
with db_ops_tracker() as q:
|
||||
BlogPost.objects.order_by('-title').get(title='whatever')
|
||||
self.assertEqual(len(q.get_ops()), 1)
|
||||
self.assertNotIn(ORDER_BY_KEY, q.get_ops()[0]['query'])
|
||||
self.assertNotIn(ORDER_BY_KEY, q.get_ops()[0][CMD_QUERY_KEY])
|
||||
|
||||
def test_find_embedded(self):
|
||||
"""Ensure that an embedded document is properly returned from
|
||||
@@ -2042,7 +2025,6 @@ class QuerySetTest(unittest.TestCase):
|
||||
pymongo_doc = BlogPost.objects.as_pymongo().first()
|
||||
self.assertNotIn('title', pymongo_doc)
|
||||
|
||||
@requires_mongodb_gte_26
|
||||
def test_update_push_with_position(self):
|
||||
"""Ensure that the 'push' update with position works properly.
|
||||
"""
|
||||
@@ -2193,6 +2175,40 @@ class QuerySetTest(unittest.TestCase):
|
||||
Site.objects(id=s.id).update_one(
|
||||
pull_all__collaborators__helpful__name=['Ross'])
|
||||
|
||||
def test_pull_from_nested_embedded_using_in_nin(self):
|
||||
"""Ensure that the 'pull' update operation works on embedded documents using 'in' and 'nin' operators.
|
||||
"""
|
||||
|
||||
class User(EmbeddedDocument):
|
||||
name = StringField()
|
||||
|
||||
def __unicode__(self):
|
||||
return '%s' % self.name
|
||||
|
||||
class Collaborator(EmbeddedDocument):
|
||||
helpful = ListField(EmbeddedDocumentField(User))
|
||||
unhelpful = ListField(EmbeddedDocumentField(User))
|
||||
|
||||
class Site(Document):
|
||||
name = StringField(max_length=75, unique=True, required=True)
|
||||
collaborators = EmbeddedDocumentField(Collaborator)
|
||||
|
||||
Site.drop_collection()
|
||||
|
||||
a = User(name='Esteban')
|
||||
b = User(name='Frank')
|
||||
x = User(name='Harry')
|
||||
y = User(name='John')
|
||||
|
||||
s = Site(name="test", collaborators=Collaborator(
|
||||
helpful=[a, b], unhelpful=[x, y])).save()
|
||||
|
||||
Site.objects(id=s.id).update_one(pull__collaborators__helpful__name__in=['Esteban']) # Pull a
|
||||
self.assertEqual(Site.objects.first().collaborators['helpful'], [b])
|
||||
|
||||
Site.objects(id=s.id).update_one(pull__collaborators__unhelpful__name__nin=['John']) # Pull x
|
||||
self.assertEqual(Site.objects.first().collaborators['unhelpful'], [y])
|
||||
|
||||
def test_pull_from_nested_mapfield(self):
|
||||
|
||||
class Collaborator(EmbeddedDocument):
|
||||
@@ -2532,8 +2548,9 @@ class QuerySetTest(unittest.TestCase):
|
||||
def test_comment(self):
|
||||
"""Make sure adding a comment to the query gets added to the query"""
|
||||
MONGO_VER = self.mongodb_version
|
||||
QUERY_KEY = 'filter' if MONGO_VER >= MONGODB_32 else '$query'
|
||||
COMMENT_KEY = 'comment' if MONGO_VER >= MONGODB_32 else '$comment'
|
||||
_, CMD_QUERY_KEY = get_key_compat(MONGO_VER)
|
||||
QUERY_KEY = 'filter'
|
||||
COMMENT_KEY = 'comment'
|
||||
|
||||
class User(Document):
|
||||
age = IntField()
|
||||
@@ -2550,8 +2567,8 @@ class QuerySetTest(unittest.TestCase):
|
||||
ops = q.get_ops()
|
||||
self.assertEqual(len(ops), 2)
|
||||
for op in ops:
|
||||
self.assertEqual(op['query'][QUERY_KEY], {'age': {'$gte': 18}})
|
||||
self.assertEqual(op['query'][COMMENT_KEY], 'looking for an adult')
|
||||
self.assertEqual(op[CMD_QUERY_KEY][QUERY_KEY], {'age': {'$gte': 18}})
|
||||
self.assertEqual(op[CMD_QUERY_KEY][COMMENT_KEY], 'looking for an adult')
|
||||
|
||||
def test_map_reduce(self):
|
||||
"""Ensure map/reduce is both mapping and reducing.
|
||||
@@ -3347,7 +3364,6 @@ class QuerySetTest(unittest.TestCase):
|
||||
|
||||
self.assertEqual(Foo.objects.distinct("bar"), [bar])
|
||||
|
||||
@requires_mongodb_gte_26
|
||||
def test_text_indexes(self):
|
||||
class News(Document):
|
||||
title = StringField()
|
||||
@@ -3415,10 +3431,7 @@ class QuerySetTest(unittest.TestCase):
|
||||
self.assertEqual(query.count(), 3)
|
||||
self.assertEqual(query._query, {'$text': {'$search': 'brasil'}})
|
||||
cursor_args = query._cursor_args
|
||||
if not IS_PYMONGO_3:
|
||||
cursor_args_fields = cursor_args['fields']
|
||||
else:
|
||||
cursor_args_fields = cursor_args['projection']
|
||||
cursor_args_fields = cursor_args['projection']
|
||||
self.assertEqual(
|
||||
cursor_args_fields, {'_text_score': {'$meta': 'textScore'}})
|
||||
|
||||
@@ -3434,7 +3447,6 @@ class QuerySetTest(unittest.TestCase):
|
||||
'brasil').order_by('$text_score').first()
|
||||
self.assertEqual(item.get_text_score(), max_text_score)
|
||||
|
||||
@requires_mongodb_gte_26
|
||||
def test_distinct_handles_references_to_alias(self):
|
||||
register_connection('testdb', 'mongoenginetest2')
|
||||
|
||||
@@ -3570,6 +3582,11 @@ class QuerySetTest(unittest.TestCase):
|
||||
opts = {"deleted": False}
|
||||
return qryset(**opts)
|
||||
|
||||
@queryset_manager
|
||||
def objects_1_arg(qryset):
|
||||
opts = {"deleted": False}
|
||||
return qryset(**opts)
|
||||
|
||||
@queryset_manager
|
||||
def music_posts(doc_cls, queryset, deleted=False):
|
||||
return queryset(tags='music',
|
||||
@@ -3584,6 +3601,8 @@ class QuerySetTest(unittest.TestCase):
|
||||
|
||||
self.assertEqual([p.id for p in BlogPost.objects()],
|
||||
[post1.id, post2.id, post3.id])
|
||||
self.assertEqual([p.id for p in BlogPost.objects_1_arg()],
|
||||
[post1.id, post2.id, post3.id])
|
||||
self.assertEqual([p.id for p in BlogPost.music_posts()],
|
||||
[post1.id, post2.id])
|
||||
|
||||
@@ -4511,11 +4530,7 @@ class QuerySetTest(unittest.TestCase):
|
||||
bars = list(Bar.objects(read_preference=ReadPreference.PRIMARY))
|
||||
self.assertEqual([], bars)
|
||||
|
||||
if not IS_PYMONGO_3:
|
||||
error_class = ConfigurationError
|
||||
else:
|
||||
error_class = TypeError
|
||||
self.assertRaises(error_class, Bar.objects, read_preference='Primary')
|
||||
self.assertRaises(TypeError, Bar.objects, read_preference='Primary')
|
||||
|
||||
# read_preference as a kwarg
|
||||
bars = Bar.objects(read_preference=ReadPreference.SECONDARY_PREFERRED)
|
||||
@@ -4563,7 +4578,6 @@ class QuerySetTest(unittest.TestCase):
|
||||
self.assertEqual(bars._cursor._Cursor__read_preference,
|
||||
ReadPreference.SECONDARY_PREFERRED)
|
||||
|
||||
@requires_mongodb_gte_26
|
||||
def test_read_preference_aggregation_framework(self):
|
||||
class Bar(Document):
|
||||
txt = StringField()
|
||||
@@ -4575,12 +4589,8 @@ class QuerySetTest(unittest.TestCase):
|
||||
bars = Bar.objects \
|
||||
.read_preference(ReadPreference.SECONDARY_PREFERRED) \
|
||||
.aggregate()
|
||||
if IS_PYMONGO_3:
|
||||
self.assertEqual(bars._CommandCursor__collection.read_preference,
|
||||
ReadPreference.SECONDARY_PREFERRED)
|
||||
else:
|
||||
self.assertNotEqual(bars._CommandCursor__collection.read_preference,
|
||||
ReadPreference.SECONDARY_PREFERRED)
|
||||
self.assertEqual(bars._CommandCursor__collection.read_preference,
|
||||
ReadPreference.SECONDARY_PREFERRED)
|
||||
|
||||
def test_json_simple(self):
|
||||
|
||||
@@ -4602,9 +4612,6 @@ class QuerySetTest(unittest.TestCase):
|
||||
self.assertEqual(doc_objects, Doc.objects.from_json(json_data))
|
||||
|
||||
def test_json_complex(self):
|
||||
if pymongo.version_tuple[0] <= 2 and pymongo.version_tuple[1] <= 3:
|
||||
raise SkipTest("Need pymongo 2.4 as has a fix for DBRefs")
|
||||
|
||||
class EmbeddedDoc(EmbeddedDocument):
|
||||
pass
|
||||
|
||||
@@ -4971,6 +4978,38 @@ class QuerySetTest(unittest.TestCase):
|
||||
people.count()
|
||||
self.assertEqual(q, 3)
|
||||
|
||||
def test_no_cached_queryset__repr__(self):
|
||||
class Person(Document):
|
||||
name = StringField()
|
||||
|
||||
Person.drop_collection()
|
||||
qs = Person.objects.no_cache()
|
||||
self.assertEqual(repr(qs), '[]')
|
||||
|
||||
def test_no_cached_on_a_cached_queryset_raise_error(self):
|
||||
class Person(Document):
|
||||
name = StringField()
|
||||
|
||||
Person.drop_collection()
|
||||
Person(name='a').save()
|
||||
qs = Person.objects()
|
||||
_ = list(qs)
|
||||
with self.assertRaises(OperationError) as ctx_err:
|
||||
qs.no_cache()
|
||||
self.assertEqual("QuerySet already cached", str(ctx_err.exception))
|
||||
|
||||
def test_no_cached_queryset_no_cache_back_to_cache(self):
|
||||
class Person(Document):
|
||||
name = StringField()
|
||||
|
||||
Person.drop_collection()
|
||||
qs = Person.objects()
|
||||
self.assertIsInstance(qs, QuerySet)
|
||||
qs = qs.no_cache()
|
||||
self.assertIsInstance(qs, QuerySetNoCache)
|
||||
qs = qs.cache()
|
||||
self.assertIsInstance(qs, QuerySet)
|
||||
|
||||
def test_cache_not_cloned(self):
|
||||
|
||||
class User(Document):
|
||||
@@ -5243,8 +5282,7 @@ class QuerySetTest(unittest.TestCase):
|
||||
self.assertEqual(op['nreturned'], 1)
|
||||
|
||||
def test_bool_with_ordering(self):
|
||||
MONGO_VER = self.mongodb_version
|
||||
ORDER_BY_KEY = 'sort' if MONGO_VER >= MONGODB_32 else '$orderby'
|
||||
ORDER_BY_KEY, CMD_QUERY_KEY = get_key_compat(self.mongodb_version)
|
||||
|
||||
class Person(Document):
|
||||
name = StringField()
|
||||
@@ -5263,21 +5301,22 @@ class QuerySetTest(unittest.TestCase):
|
||||
op = q.db.system.profile.find({"ns":
|
||||
{"$ne": "%s.system.indexes" % q.db.name}})[0]
|
||||
|
||||
self.assertNotIn(ORDER_BY_KEY, op['query'])
|
||||
self.assertNotIn(ORDER_BY_KEY, op[CMD_QUERY_KEY])
|
||||
|
||||
# Check that normal query uses orderby
|
||||
qs2 = Person.objects.order_by('name')
|
||||
with query_counter() as p:
|
||||
with query_counter() as q:
|
||||
|
||||
for x in qs2:
|
||||
pass
|
||||
|
||||
op = p.db.system.profile.find({"ns":
|
||||
op = q.db.system.profile.find({"ns":
|
||||
{"$ne": "%s.system.indexes" % q.db.name}})[0]
|
||||
|
||||
self.assertIn(ORDER_BY_KEY, op['query'])
|
||||
self.assertIn(ORDER_BY_KEY, op[CMD_QUERY_KEY])
|
||||
|
||||
def test_bool_with_ordering_from_meta_dict(self):
|
||||
ORDER_BY_KEY, CMD_QUERY_KEY = get_key_compat(self.mongodb_version)
|
||||
|
||||
class Person(Document):
|
||||
name = StringField()
|
||||
@@ -5299,14 +5338,13 @@ class QuerySetTest(unittest.TestCase):
|
||||
op = q.db.system.profile.find({"ns":
|
||||
{"$ne": "%s.system.indexes" % q.db.name}})[0]
|
||||
|
||||
self.assertNotIn('$orderby', op['query'],
|
||||
self.assertNotIn('$orderby', op[CMD_QUERY_KEY],
|
||||
'BaseQuerySet must remove orderby from meta in boolen test')
|
||||
|
||||
self.assertEqual(Person.objects.first().name, 'A')
|
||||
self.assertTrue(Person.objects._has_data(),
|
||||
'Cursor has data and returned False')
|
||||
|
||||
@requires_mongodb_gte_26
|
||||
def test_queryset_aggregation_framework(self):
|
||||
class Person(Document):
|
||||
name = StringField()
|
||||
@@ -5315,13 +5353,9 @@ class QuerySetTest(unittest.TestCase):
|
||||
Person.drop_collection()
|
||||
|
||||
p1 = Person(name="Isabella Luanna", age=16)
|
||||
p1.save()
|
||||
|
||||
p2 = Person(name="Wilson Junior", age=21)
|
||||
p2.save()
|
||||
|
||||
p3 = Person(name="Sandra Mara", age=37)
|
||||
p3.save()
|
||||
Person.objects.insert([p1, p2, p3])
|
||||
|
||||
data = Person.objects(age__lte=22).aggregate(
|
||||
{'$project': {'name': {'$toUpper': '$name'}}}
|
||||
@@ -5352,6 +5386,179 @@ class QuerySetTest(unittest.TestCase):
|
||||
{'_id': None, 'avg': 29, 'total': 2}
|
||||
])
|
||||
|
||||
def test_queryset_aggregation_with_skip(self):
|
||||
class Person(Document):
|
||||
name = StringField()
|
||||
age = IntField()
|
||||
|
||||
Person.drop_collection()
|
||||
|
||||
p1 = Person(name="Isabella Luanna", age=16)
|
||||
p2 = Person(name="Wilson Junior", age=21)
|
||||
p3 = Person(name="Sandra Mara", age=37)
|
||||
Person.objects.insert([p1, p2, p3])
|
||||
|
||||
data = Person.objects.skip(1).aggregate(
|
||||
{'$project': {'name': {'$toUpper': '$name'}}}
|
||||
)
|
||||
|
||||
self.assertEqual(list(data), [
|
||||
{'_id': p2.pk, 'name': "WILSON JUNIOR"},
|
||||
{'_id': p3.pk, 'name': "SANDRA MARA"}
|
||||
])
|
||||
|
||||
def test_queryset_aggregation_with_limit(self):
|
||||
class Person(Document):
|
||||
name = StringField()
|
||||
age = IntField()
|
||||
|
||||
Person.drop_collection()
|
||||
|
||||
p1 = Person(name="Isabella Luanna", age=16)
|
||||
p2 = Person(name="Wilson Junior", age=21)
|
||||
p3 = Person(name="Sandra Mara", age=37)
|
||||
Person.objects.insert([p1, p2, p3])
|
||||
|
||||
data = Person.objects.limit(1).aggregate(
|
||||
{'$project': {'name': {'$toUpper': '$name'}}}
|
||||
)
|
||||
|
||||
self.assertEqual(list(data), [
|
||||
{'_id': p1.pk, 'name': "ISABELLA LUANNA"}
|
||||
])
|
||||
|
||||
def test_queryset_aggregation_with_sort(self):
|
||||
class Person(Document):
|
||||
name = StringField()
|
||||
age = IntField()
|
||||
|
||||
Person.drop_collection()
|
||||
|
||||
p1 = Person(name="Isabella Luanna", age=16)
|
||||
p2 = Person(name="Wilson Junior", age=21)
|
||||
p3 = Person(name="Sandra Mara", age=37)
|
||||
Person.objects.insert([p1, p2, p3])
|
||||
|
||||
data = Person.objects.order_by('name').aggregate(
|
||||
{'$project': {'name': {'$toUpper': '$name'}}}
|
||||
)
|
||||
|
||||
self.assertEqual(list(data), [
|
||||
{'_id': p1.pk, 'name': "ISABELLA LUANNA"},
|
||||
{'_id': p3.pk, 'name': "SANDRA MARA"},
|
||||
{'_id': p2.pk, 'name': "WILSON JUNIOR"}
|
||||
])
|
||||
|
||||
def test_queryset_aggregation_with_skip_with_limit(self):
|
||||
class Person(Document):
|
||||
name = StringField()
|
||||
age = IntField()
|
||||
|
||||
Person.drop_collection()
|
||||
|
||||
p1 = Person(name="Isabella Luanna", age=16)
|
||||
p2 = Person(name="Wilson Junior", age=21)
|
||||
p3 = Person(name="Sandra Mara", age=37)
|
||||
Person.objects.insert([p1, p2, p3])
|
||||
|
||||
data = list(
|
||||
Person.objects.skip(1).limit(1).aggregate(
|
||||
{'$project': {'name': {'$toUpper': '$name'}}}
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(list(data), [
|
||||
{'_id': p2.pk, 'name': "WILSON JUNIOR"},
|
||||
])
|
||||
|
||||
# Make sure limit/skip chaining order has no impact
|
||||
data2 = Person.objects.limit(1).skip(1).aggregate(
|
||||
{'$project': {'name': {'$toUpper': '$name'}}}
|
||||
)
|
||||
|
||||
self.assertEqual(data, list(data2))
|
||||
|
||||
def test_queryset_aggregation_with_sort_with_limit(self):
|
||||
class Person(Document):
|
||||
name = StringField()
|
||||
age = IntField()
|
||||
|
||||
Person.drop_collection()
|
||||
|
||||
p1 = Person(name="Isabella Luanna", age=16)
|
||||
p2 = Person(name="Wilson Junior", age=21)
|
||||
p3 = Person(name="Sandra Mara", age=37)
|
||||
Person.objects.insert([p1, p2, p3])
|
||||
|
||||
data = Person.objects.order_by('name').limit(2).aggregate(
|
||||
{'$project': {'name': {'$toUpper': '$name'}}}
|
||||
)
|
||||
|
||||
self.assertEqual(list(data), [
|
||||
{'_id': p1.pk, 'name': "ISABELLA LUANNA"},
|
||||
{'_id': p3.pk, 'name': "SANDRA MARA"}
|
||||
])
|
||||
|
||||
# Verify adding limit/skip steps works as expected
|
||||
data = Person.objects.order_by('name').limit(2).aggregate(
|
||||
{'$project': {'name': {'$toUpper': '$name'}}},
|
||||
{'$limit': 1},
|
||||
)
|
||||
|
||||
self.assertEqual(list(data), [
|
||||
{'_id': p1.pk, 'name': "ISABELLA LUANNA"},
|
||||
])
|
||||
|
||||
data = Person.objects.order_by('name').limit(2).aggregate(
|
||||
{'$project': {'name': {'$toUpper': '$name'}}},
|
||||
{'$skip': 1},
|
||||
{'$limit': 1},
|
||||
)
|
||||
|
||||
self.assertEqual(list(data), [
|
||||
{'_id': p3.pk, 'name': "SANDRA MARA"},
|
||||
])
|
||||
|
||||
def test_queryset_aggregation_with_sort_with_skip(self):
|
||||
class Person(Document):
|
||||
name = StringField()
|
||||
age = IntField()
|
||||
|
||||
Person.drop_collection()
|
||||
|
||||
p1 = Person(name="Isabella Luanna", age=16)
|
||||
p2 = Person(name="Wilson Junior", age=21)
|
||||
p3 = Person(name="Sandra Mara", age=37)
|
||||
Person.objects.insert([p1, p2, p3])
|
||||
|
||||
data = Person.objects.order_by('name').skip(2).aggregate(
|
||||
{'$project': {'name': {'$toUpper': '$name'}}}
|
||||
)
|
||||
|
||||
self.assertEqual(list(data), [
|
||||
{'_id': p2.pk, 'name': "WILSON JUNIOR"}
|
||||
])
|
||||
|
||||
def test_queryset_aggregation_with_sort_with_skip_with_limit(self):
|
||||
class Person(Document):
|
||||
name = StringField()
|
||||
age = IntField()
|
||||
|
||||
Person.drop_collection()
|
||||
|
||||
p1 = Person(name="Isabella Luanna", age=16)
|
||||
p2 = Person(name="Wilson Junior", age=21)
|
||||
p3 = Person(name="Sandra Mara", age=37)
|
||||
Person.objects.insert([p1, p2, p3])
|
||||
|
||||
data = Person.objects.order_by('name').skip(1).limit(1).aggregate(
|
||||
{'$project': {'name': {'$toUpper': '$name'}}}
|
||||
)
|
||||
|
||||
self.assertEqual(list(data), [
|
||||
{'_id': p3.pk, 'name': "SANDRA MARA"}
|
||||
])
|
||||
|
||||
def test_delete_count(self):
|
||||
[self.Person(name="User {0}".format(i), age=i * 10).save() for i in range(1, 4)]
|
||||
self.assertEqual(self.Person.objects().delete(), 3) # test ordinary QuerySey delete count
|
||||
@@ -5385,8 +5592,8 @@ class QuerySetTest(unittest.TestCase):
|
||||
Animal(is_mamal=False).save()
|
||||
Cat(is_mamal=True, whiskers_length=5.1).save()
|
||||
ScottishCat(is_mamal=True, folded_ears=True).save()
|
||||
self.assertEquals(Animal.objects(folded_ears=True).count(), 1)
|
||||
self.assertEquals(Animal.objects(whiskers_length=5.1).count(), 1)
|
||||
self.assertEqual(Animal.objects(folded_ears=True).count(), 1)
|
||||
self.assertEqual(Animal.objects(whiskers_length=5.1).count(), 1)
|
||||
|
||||
def test_loop_over_invalid_id_does_not_crash(self):
|
||||
class Person(Document):
|
||||
|
@@ -71,6 +71,14 @@ class TransformTest(unittest.TestCase):
|
||||
update = transform.update(BlogPost, push_all__tags=['mongo', 'db'])
|
||||
self.assertEqual(update, {'$push': {'tags': {'$each': ['mongo', 'db']}}})
|
||||
|
||||
def test_transform_update_no_operator_default_to_set(self):
|
||||
"""Ensure the differences in behvaior between 'push' and 'push_all'"""
|
||||
class BlogPost(Document):
|
||||
tags = ListField(StringField())
|
||||
|
||||
update = transform.update(BlogPost, tags=['mongo', 'db'])
|
||||
self.assertEqual(update, {'$set': {'tags': ['mongo', 'db']}})
|
||||
|
||||
def test_query_field_name(self):
|
||||
"""Ensure that the correct field name is used when querying.
|
||||
"""
|
||||
@@ -283,6 +291,11 @@ class TransformTest(unittest.TestCase):
|
||||
update = transform.update(MainDoc, pull__content__heading='xyz')
|
||||
self.assertEqual(update, {'$pull': {'content.heading': 'xyz'}})
|
||||
|
||||
update = transform.update(MainDoc, pull__content__text__word__in=['foo', 'bar'])
|
||||
self.assertEqual(update, {'$pull': {'content.text': {'word': {'$in': ['foo', 'bar']}}}})
|
||||
|
||||
update = transform.update(MainDoc, pull__content__text__word__nin=['foo', 'bar'])
|
||||
self.assertEqual(update, {'$pull': {'content.text': {'word': {'$nin': ['foo', 'bar']}}}})
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
15
tests/test_common.py
Normal file
15
tests/test_common.py
Normal file
@@ -0,0 +1,15 @@
|
||||
import unittest
|
||||
|
||||
from mongoengine.common import _import_class
|
||||
from mongoengine import Document
|
||||
|
||||
|
||||
class TestCommon(unittest.TestCase):
|
||||
|
||||
def test__import_class(self):
|
||||
doc_cls = _import_class("Document")
|
||||
self.assertIs(doc_cls, Document)
|
||||
|
||||
def test__import_class_raise_if_not_known(self):
|
||||
with self.assertRaises(ValueError):
|
||||
_import_class("UnknownClass")
|
@@ -1,5 +1,8 @@
|
||||
import datetime
|
||||
from pymongo.errors import OperationFailure
|
||||
|
||||
from pymongo import MongoClient
|
||||
from pymongo.errors import OperationFailure, InvalidName
|
||||
from pymongo import ReadPreference
|
||||
|
||||
try:
|
||||
import unittest2 as unittest
|
||||
@@ -12,23 +15,27 @@ from bson.tz_util import utc
|
||||
|
||||
from mongoengine import (
|
||||
connect, register_connection,
|
||||
Document, DateTimeField
|
||||
)
|
||||
from mongoengine.pymongo_support import IS_PYMONGO_3
|
||||
Document, DateTimeField,
|
||||
disconnect_all, StringField)
|
||||
import mongoengine.connection
|
||||
from mongoengine.connection import (MongoEngineConnectionError, get_db,
|
||||
get_connection)
|
||||
get_connection, disconnect, DEFAULT_DATABASE_NAME)
|
||||
|
||||
|
||||
def get_tz_awareness(connection):
|
||||
if not IS_PYMONGO_3:
|
||||
return connection.tz_aware
|
||||
else:
|
||||
return connection.codec_options.tz_aware
|
||||
return connection.codec_options.tz_aware
|
||||
|
||||
|
||||
class ConnectionTest(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
disconnect_all()
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
disconnect_all()
|
||||
|
||||
def tearDown(self):
|
||||
mongoengine.connection._connection_settings = {}
|
||||
mongoengine.connection._connections = {}
|
||||
@@ -49,6 +56,147 @@ class ConnectionTest(unittest.TestCase):
|
||||
conn = get_connection('testdb')
|
||||
self.assertIsInstance(conn, pymongo.mongo_client.MongoClient)
|
||||
|
||||
def test_connect_disconnect_works_properly(self):
|
||||
class History1(Document):
|
||||
name = StringField()
|
||||
meta = {'db_alias': 'db1'}
|
||||
|
||||
class History2(Document):
|
||||
name = StringField()
|
||||
meta = {'db_alias': 'db2'}
|
||||
|
||||
connect('db1', alias='db1')
|
||||
connect('db2', alias='db2')
|
||||
|
||||
History1.drop_collection()
|
||||
History2.drop_collection()
|
||||
|
||||
h = History1(name='default').save()
|
||||
h1 = History2(name='db1').save()
|
||||
|
||||
self.assertEqual(list(History1.objects().as_pymongo()),
|
||||
[{'_id': h.id, 'name': 'default'}])
|
||||
self.assertEqual(list(History2.objects().as_pymongo()),
|
||||
[{'_id': h1.id, 'name': 'db1'}])
|
||||
|
||||
disconnect('db1')
|
||||
disconnect('db2')
|
||||
|
||||
with self.assertRaises(MongoEngineConnectionError):
|
||||
list(History1.objects().as_pymongo())
|
||||
|
||||
with self.assertRaises(MongoEngineConnectionError):
|
||||
list(History2.objects().as_pymongo())
|
||||
|
||||
connect('db1', alias='db1')
|
||||
connect('db2', alias='db2')
|
||||
|
||||
self.assertEqual(list(History1.objects().as_pymongo()),
|
||||
[{'_id': h.id, 'name': 'default'}])
|
||||
self.assertEqual(list(History2.objects().as_pymongo()),
|
||||
[{'_id': h1.id, 'name': 'db1'}])
|
||||
|
||||
def test_connect_different_documents_to_different_database(self):
|
||||
class History(Document):
|
||||
name = StringField()
|
||||
|
||||
class History1(Document):
|
||||
name = StringField()
|
||||
meta = {'db_alias': 'db1'}
|
||||
|
||||
class History2(Document):
|
||||
name = StringField()
|
||||
meta = {'db_alias': 'db2'}
|
||||
|
||||
connect()
|
||||
connect('db1', alias='db1')
|
||||
connect('db2', alias='db2')
|
||||
|
||||
History.drop_collection()
|
||||
History1.drop_collection()
|
||||
History2.drop_collection()
|
||||
|
||||
h = History(name='default').save()
|
||||
h1 = History1(name='db1').save()
|
||||
h2 = History2(name='db2').save()
|
||||
|
||||
self.assertEqual(History._collection.database.name, DEFAULT_DATABASE_NAME)
|
||||
self.assertEqual(History1._collection.database.name, 'db1')
|
||||
self.assertEqual(History2._collection.database.name, 'db2')
|
||||
|
||||
self.assertEqual(list(History.objects().as_pymongo()),
|
||||
[{'_id': h.id, 'name': 'default'}])
|
||||
self.assertEqual(list(History1.objects().as_pymongo()),
|
||||
[{'_id': h1.id, 'name': 'db1'}])
|
||||
self.assertEqual(list(History2.objects().as_pymongo()),
|
||||
[{'_id': h2.id, 'name': 'db2'}])
|
||||
|
||||
def test_connect_fails_if_connect_2_times_with_default_alias(self):
|
||||
connect('mongoenginetest')
|
||||
|
||||
with self.assertRaises(MongoEngineConnectionError) as ctx_err:
|
||||
connect('mongoenginetest2')
|
||||
self.assertEqual("A different connection with alias `default` was already registered. Use disconnect() first", str(ctx_err.exception))
|
||||
|
||||
def test_connect_fails_if_connect_2_times_with_custom_alias(self):
|
||||
connect('mongoenginetest', alias='alias1')
|
||||
|
||||
with self.assertRaises(MongoEngineConnectionError) as ctx_err:
|
||||
connect('mongoenginetest2', alias='alias1')
|
||||
|
||||
self.assertEqual("A different connection with alias `alias1` was already registered. Use disconnect() first", str(ctx_err.exception))
|
||||
|
||||
def test_connect_fails_if_similar_connection_settings_arent_defined_the_same_way(self):
|
||||
"""Intended to keep the detecton function simple but robust"""
|
||||
db_name = 'mongoenginetest'
|
||||
db_alias = 'alias1'
|
||||
connect(db=db_name, alias=db_alias, host='localhost', port=27017)
|
||||
|
||||
with self.assertRaises(MongoEngineConnectionError):
|
||||
connect(host='mongodb://localhost:27017/%s' % db_name, alias=db_alias)
|
||||
|
||||
def test_connect_passes_silently_connect_multiple_times_with_same_config(self):
|
||||
# test default connection to `test`
|
||||
connect()
|
||||
connect()
|
||||
self.assertEqual(len(mongoengine.connection._connections), 1)
|
||||
connect('test01', alias='test01')
|
||||
connect('test01', alias='test01')
|
||||
self.assertEqual(len(mongoengine.connection._connections), 2)
|
||||
connect(host='mongodb://localhost:27017/mongoenginetest02', alias='test02')
|
||||
connect(host='mongodb://localhost:27017/mongoenginetest02', alias='test02')
|
||||
self.assertEqual(len(mongoengine.connection._connections), 3)
|
||||
|
||||
def test_connect_with_invalid_db_name(self):
|
||||
"""Ensure that connect() method fails fast if db name is invalid
|
||||
"""
|
||||
with self.assertRaises(InvalidName):
|
||||
connect('mongomock://localhost')
|
||||
|
||||
def test_connect_with_db_name_external(self):
|
||||
"""Ensure that connect() works if db name is $external
|
||||
"""
|
||||
"""Ensure that the connect() method works properly."""
|
||||
connect('$external')
|
||||
|
||||
conn = get_connection()
|
||||
self.assertIsInstance(conn, pymongo.mongo_client.MongoClient)
|
||||
|
||||
db = get_db()
|
||||
self.assertIsInstance(db, pymongo.database.Database)
|
||||
self.assertEqual(db.name, '$external')
|
||||
|
||||
connect('$external', alias='testdb')
|
||||
conn = get_connection('testdb')
|
||||
self.assertIsInstance(conn, pymongo.mongo_client.MongoClient)
|
||||
|
||||
def test_connect_with_invalid_db_name_type(self):
|
||||
"""Ensure that connect() method fails fast if db name has invalid type
|
||||
"""
|
||||
with self.assertRaises(TypeError):
|
||||
non_string_db_name = ['e. g. list instead of a string']
|
||||
connect(non_string_db_name)
|
||||
|
||||
def test_connect_in_mocking(self):
|
||||
"""Ensure that the connect() method works properly in mocking.
|
||||
"""
|
||||
@@ -119,13 +267,133 @@ class ConnectionTest(unittest.TestCase):
|
||||
conn = get_connection('testdb6')
|
||||
self.assertIsInstance(conn, mongomock.MongoClient)
|
||||
|
||||
def test_disconnect(self):
|
||||
"""Ensure that the disconnect() method works properly
|
||||
"""
|
||||
def test_disconnect_cleans_globals(self):
|
||||
"""Ensure that the disconnect() method cleans the globals objects"""
|
||||
connections = mongoengine.connection._connections
|
||||
dbs = mongoengine.connection._dbs
|
||||
connection_settings = mongoengine.connection._connection_settings
|
||||
|
||||
connect('mongoenginetest')
|
||||
|
||||
self.assertEqual(len(connections), 1)
|
||||
self.assertEqual(len(dbs), 0)
|
||||
self.assertEqual(len(connection_settings), 1)
|
||||
|
||||
class TestDoc(Document):
|
||||
pass
|
||||
|
||||
TestDoc.drop_collection() # triggers the db
|
||||
self.assertEqual(len(dbs), 1)
|
||||
|
||||
disconnect()
|
||||
self.assertEqual(len(connections), 0)
|
||||
self.assertEqual(len(dbs), 0)
|
||||
self.assertEqual(len(connection_settings), 0)
|
||||
|
||||
def test_disconnect_cleans_cached_collection_attribute_in_document(self):
|
||||
"""Ensure that the disconnect() method works properly"""
|
||||
conn1 = connect('mongoenginetest')
|
||||
mongoengine.connection.disconnect()
|
||||
conn2 = connect('mongoenginetest')
|
||||
self.assertTrue(conn1 is not conn2)
|
||||
|
||||
class History(Document):
|
||||
pass
|
||||
|
||||
self.assertIsNone(History._collection)
|
||||
|
||||
History.drop_collection()
|
||||
|
||||
History.objects.first() # will trigger the caching of _collection attribute
|
||||
self.assertIsNotNone(History._collection)
|
||||
|
||||
disconnect()
|
||||
|
||||
self.assertIsNone(History._collection)
|
||||
|
||||
with self.assertRaises(MongoEngineConnectionError) as ctx_err:
|
||||
History.objects.first()
|
||||
self.assertEqual("You have not defined a default connection", str(ctx_err.exception))
|
||||
|
||||
def test_connect_disconnect_works_on_same_document(self):
|
||||
"""Ensure that the connect/disconnect works properly with a single Document"""
|
||||
db1 = 'db1'
|
||||
db2 = 'db2'
|
||||
|
||||
# Ensure freshness of the 2 databases through pymongo
|
||||
client = MongoClient('localhost', 27017)
|
||||
client.drop_database(db1)
|
||||
client.drop_database(db2)
|
||||
|
||||
# Save in db1
|
||||
connect(db1)
|
||||
|
||||
class User(Document):
|
||||
name = StringField(required=True)
|
||||
|
||||
user1 = User(name='John is in db1').save()
|
||||
disconnect()
|
||||
|
||||
# Make sure save doesnt work at this stage
|
||||
with self.assertRaises(MongoEngineConnectionError):
|
||||
User(name='Wont work').save()
|
||||
|
||||
# Save in db2
|
||||
connect(db2)
|
||||
user2 = User(name='Bob is in db2').save()
|
||||
disconnect()
|
||||
|
||||
db1_users = list(client[db1].user.find())
|
||||
self.assertEqual(db1_users, [{'_id': user1.id, 'name': 'John is in db1'}])
|
||||
db2_users = list(client[db2].user.find())
|
||||
self.assertEqual(db2_users, [{'_id': user2.id, 'name': 'Bob is in db2'}])
|
||||
|
||||
def test_disconnect_silently_pass_if_alias_does_not_exist(self):
|
||||
connections = mongoengine.connection._connections
|
||||
self.assertEqual(len(connections), 0)
|
||||
disconnect(alias='not_exist')
|
||||
|
||||
def test_disconnect_all(self):
|
||||
connections = mongoengine.connection._connections
|
||||
dbs = mongoengine.connection._dbs
|
||||
connection_settings = mongoengine.connection._connection_settings
|
||||
|
||||
connect('mongoenginetest')
|
||||
connect('mongoenginetest2', alias='db1')
|
||||
|
||||
class History(Document):
|
||||
pass
|
||||
|
||||
class History1(Document):
|
||||
name = StringField()
|
||||
meta = {'db_alias': 'db1'}
|
||||
|
||||
History.drop_collection() # will trigger the caching of _collection attribute
|
||||
History.objects.first()
|
||||
History1.drop_collection()
|
||||
History1.objects.first()
|
||||
|
||||
self.assertIsNotNone(History._collection)
|
||||
self.assertIsNotNone(History1._collection)
|
||||
|
||||
self.assertEqual(len(connections), 2)
|
||||
self.assertEqual(len(dbs), 2)
|
||||
self.assertEqual(len(connection_settings), 2)
|
||||
|
||||
disconnect_all()
|
||||
|
||||
self.assertIsNone(History._collection)
|
||||
self.assertIsNone(History1._collection)
|
||||
|
||||
self.assertEqual(len(connections), 0)
|
||||
self.assertEqual(len(dbs), 0)
|
||||
self.assertEqual(len(connection_settings), 0)
|
||||
|
||||
with self.assertRaises(MongoEngineConnectionError):
|
||||
History.objects.first()
|
||||
|
||||
with self.assertRaises(MongoEngineConnectionError):
|
||||
History1.objects.first()
|
||||
|
||||
def test_disconnect_all_silently_pass_if_no_connection_exist(self):
|
||||
disconnect_all()
|
||||
|
||||
def test_sharing_connections(self):
|
||||
"""Ensure that connections are shared when the connection settings are exactly the same
|
||||
@@ -136,11 +404,7 @@ class ConnectionTest(unittest.TestCase):
|
||||
connect('mongoenginetests', alias='testdb2')
|
||||
actual_connection = get_connection('testdb2')
|
||||
|
||||
# Handle PyMongo 3+ Async Connection
|
||||
if IS_PYMONGO_3:
|
||||
# Ensure we are connected, throws ServerSelectionTimeoutError otherwise.
|
||||
# Purposely not catching exception to fail test if thrown.
|
||||
expected_connection.server_info()
|
||||
expected_connection.server_info()
|
||||
|
||||
self.assertEqual(expected_connection, actual_connection)
|
||||
|
||||
@@ -154,12 +418,6 @@ class ConnectionTest(unittest.TestCase):
|
||||
c.admin.authenticate("admin", "password")
|
||||
c.admin.command("createUser", "username", pwd="password", roles=["dbOwner"])
|
||||
|
||||
if not IS_PYMONGO_3:
|
||||
self.assertRaises(
|
||||
MongoEngineConnectionError, connect, 'testdb_uri_bad',
|
||||
host='mongodb://test:password@localhost'
|
||||
)
|
||||
|
||||
connect("testdb_uri", host='mongodb://username:password@localhost/mongoenginetest')
|
||||
|
||||
conn = get_connection()
|
||||
@@ -222,19 +480,11 @@ class ConnectionTest(unittest.TestCase):
|
||||
c.admin.command("createUser", "username2", pwd="password", roles=["dbOwner"])
|
||||
|
||||
# Authentication fails without "authSource"
|
||||
if IS_PYMONGO_3:
|
||||
test_conn = connect(
|
||||
'mongoenginetest', alias='test1',
|
||||
host='mongodb://username2:password@localhost/mongoenginetest'
|
||||
)
|
||||
self.assertRaises(OperationFailure, test_conn.server_info)
|
||||
else:
|
||||
self.assertRaises(
|
||||
MongoEngineConnectionError,
|
||||
connect, 'mongoenginetest', alias='test1',
|
||||
host='mongodb://username2:password@localhost/mongoenginetest'
|
||||
)
|
||||
self.assertRaises(MongoEngineConnectionError, get_db, 'test1')
|
||||
test_conn = connect(
|
||||
'mongoenginetest', alias='test1',
|
||||
host='mongodb://username2:password@localhost/mongoenginetest'
|
||||
)
|
||||
self.assertRaises(OperationFailure, test_conn.server_info)
|
||||
|
||||
# Authentication succeeds with "authSource"
|
||||
authd_conn = connect(
|
||||
@@ -285,14 +535,7 @@ class ConnectionTest(unittest.TestCase):
|
||||
"""Ensure we can specify a max connection pool size using
|
||||
a connection kwarg.
|
||||
"""
|
||||
# Use "max_pool_size" or "maxpoolsize" depending on PyMongo version
|
||||
# (former was changed to the latter as described in
|
||||
# https://jira.mongodb.org/browse/PYTHON-854).
|
||||
# TODO remove once PyMongo < 3.0 support is dropped
|
||||
if pymongo.version_tuple[0] >= 3:
|
||||
pool_size_kwargs = {'maxpoolsize': 100}
|
||||
else:
|
||||
pool_size_kwargs = {'max_pool_size': 100}
|
||||
pool_size_kwargs = {'maxpoolsize': 100}
|
||||
|
||||
conn = connect('mongoenginetest', alias='max_pool_size_via_kwarg', **pool_size_kwargs)
|
||||
self.assertEqual(conn.max_pool_size, 100)
|
||||
@@ -301,9 +544,6 @@ class ConnectionTest(unittest.TestCase):
|
||||
"""Ensure we can specify a max connection pool size using
|
||||
an option in a connection URI.
|
||||
"""
|
||||
if pymongo.version_tuple[0] == 2 and pymongo.version_tuple[1] < 9:
|
||||
raise SkipTest('maxpoolsize as a URI option is only supported in PyMongo v2.9+')
|
||||
|
||||
conn = connect(host='mongodb://localhost/test?maxpoolsize=100', alias='max_pool_size_via_uri')
|
||||
self.assertEqual(conn.max_pool_size, 100)
|
||||
|
||||
@@ -313,46 +553,30 @@ class ConnectionTest(unittest.TestCase):
|
||||
"""
|
||||
conn1 = connect(alias='conn1', host='mongodb://localhost/testing?w=1&j=true')
|
||||
conn2 = connect('testing', alias='conn2', w=1, j=True)
|
||||
if IS_PYMONGO_3:
|
||||
self.assertEqual(conn1.write_concern.document, {'w': 1, 'j': True})
|
||||
self.assertEqual(conn2.write_concern.document, {'w': 1, 'j': True})
|
||||
else:
|
||||
self.assertEqual(dict(conn1.write_concern), {'w': 1, 'j': True})
|
||||
self.assertEqual(dict(conn2.write_concern), {'w': 1, 'j': True})
|
||||
self.assertEqual(conn1.write_concern.document, {'w': 1, 'j': True})
|
||||
self.assertEqual(conn2.write_concern.document, {'w': 1, 'j': True})
|
||||
|
||||
def test_connect_with_replicaset_via_uri(self):
|
||||
"""Ensure connect() works when specifying a replicaSet via the
|
||||
MongoDB URI.
|
||||
"""
|
||||
if IS_PYMONGO_3:
|
||||
c = connect(host='mongodb://localhost/test?replicaSet=local-rs')
|
||||
db = get_db()
|
||||
self.assertIsInstance(db, pymongo.database.Database)
|
||||
self.assertEqual(db.name, 'test')
|
||||
else:
|
||||
# PyMongo < v3.x raises an exception:
|
||||
# "localhost:27017 is not a member of replica set local-rs"
|
||||
with self.assertRaises(MongoEngineConnectionError):
|
||||
c = connect(host='mongodb://localhost/test?replicaSet=local-rs')
|
||||
c = connect(host='mongodb://localhost/test?replicaSet=local-rs')
|
||||
db = get_db()
|
||||
self.assertIsInstance(db, pymongo.database.Database)
|
||||
self.assertEqual(db.name, 'test')
|
||||
|
||||
def test_connect_with_replicaset_via_kwargs(self):
|
||||
"""Ensure connect() works when specifying a replicaSet via the
|
||||
connection kwargs
|
||||
"""
|
||||
if IS_PYMONGO_3:
|
||||
c = connect(replicaset='local-rs')
|
||||
self.assertEqual(c._MongoClient__options.replica_set_name,
|
||||
'local-rs')
|
||||
db = get_db()
|
||||
self.assertIsInstance(db, pymongo.database.Database)
|
||||
self.assertEqual(db.name, 'test')
|
||||
else:
|
||||
# PyMongo < v3.x raises an exception:
|
||||
# "localhost:27017 is not a member of replica set local-rs"
|
||||
with self.assertRaises(MongoEngineConnectionError):
|
||||
c = connect(replicaset='local-rs')
|
||||
c = connect(replicaset='local-rs')
|
||||
self.assertEqual(c._MongoClient__options.replica_set_name,
|
||||
'local-rs')
|
||||
db = get_db()
|
||||
self.assertIsInstance(db, pymongo.database.Database)
|
||||
self.assertEqual(db.name, 'test')
|
||||
|
||||
def test_datetime(self):
|
||||
def test_connect_tz_aware(self):
|
||||
connect('mongoenginetest', tz_aware=True)
|
||||
d = datetime.datetime(2010, 5, 5, tzinfo=utc)
|
||||
|
||||
@@ -366,10 +590,8 @@ class ConnectionTest(unittest.TestCase):
|
||||
self.assertEqual(d, date_doc.the_date)
|
||||
|
||||
def test_read_preference_from_parse(self):
|
||||
if IS_PYMONGO_3:
|
||||
from pymongo import ReadPreference
|
||||
conn = connect(host="mongodb://a1.vpc,a2.vpc,a3.vpc/prod?readPreference=secondaryPreferred")
|
||||
self.assertEqual(conn.read_preference, ReadPreference.SECONDARY_PREFERRED)
|
||||
conn = connect(host="mongodb://a1.vpc,a2.vpc,a3.vpc/prod?readPreference=secondaryPreferred")
|
||||
self.assertEqual(conn.read_preference, ReadPreference.SECONDARY_PREFERRED)
|
||||
|
||||
def test_multiple_connection_settings(self):
|
||||
connect('mongoenginetest', alias='t1', host="localhost")
|
||||
@@ -380,17 +602,24 @@ class ConnectionTest(unittest.TestCase):
|
||||
self.assertEqual(len(mongo_connections.items()), 2)
|
||||
self.assertIn('t1', mongo_connections.keys())
|
||||
self.assertIn('t2', mongo_connections.keys())
|
||||
if not IS_PYMONGO_3:
|
||||
self.assertEqual(mongo_connections['t1'].host, 'localhost')
|
||||
self.assertEqual(mongo_connections['t2'].host, '127.0.0.1')
|
||||
else:
|
||||
# Handle PyMongo 3+ Async Connection
|
||||
# Ensure we are connected, throws ServerSelectionTimeoutError otherwise.
|
||||
# Purposely not catching exception to fail test if thrown.
|
||||
mongo_connections['t1'].server_info()
|
||||
mongo_connections['t2'].server_info()
|
||||
self.assertEqual(mongo_connections['t1'].address[0], 'localhost')
|
||||
self.assertEqual(mongo_connections['t2'].address[0], '127.0.0.1')
|
||||
|
||||
# Handle PyMongo 3+ Async Connection
|
||||
# Ensure we are connected, throws ServerSelectionTimeoutError otherwise.
|
||||
# Purposely not catching exception to fail test if thrown.
|
||||
mongo_connections['t1'].server_info()
|
||||
mongo_connections['t2'].server_info()
|
||||
self.assertEqual(mongo_connections['t1'].address[0], 'localhost')
|
||||
self.assertEqual(mongo_connections['t2'].address[0], '127.0.0.1')
|
||||
|
||||
def test_connect_2_databases_uses_same_client_if_only_dbname_differs(self):
|
||||
c1 = connect(alias='testdb1', db='testdb1')
|
||||
c2 = connect(alias='testdb2', db='testdb2')
|
||||
self.assertIs(c1, c2)
|
||||
|
||||
def test_connect_2_databases_uses_different_client_if_different_parameters(self):
|
||||
c1 = connect(alias='testdb1', db='testdb1', username='u1')
|
||||
c2 = connect(alias='testdb2', db='testdb2', username='u2')
|
||||
self.assertIsNot(c1, c2)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@@ -37,14 +37,15 @@ class ContextManagersTest(unittest.TestCase):
|
||||
|
||||
def test_switch_collection_context_manager(self):
|
||||
connect('mongoenginetest')
|
||||
register_connection('testdb-1', 'mongoenginetest2')
|
||||
register_connection(alias='testdb-1', db='mongoenginetest2')
|
||||
|
||||
class Group(Document):
|
||||
name = StringField()
|
||||
|
||||
Group.drop_collection()
|
||||
Group.drop_collection() # drops in default
|
||||
|
||||
with switch_collection(Group, 'group1') as Group:
|
||||
Group.drop_collection()
|
||||
Group.drop_collection() # drops in group1
|
||||
|
||||
Group(name="hello - group").save()
|
||||
self.assertEqual(1, Group.objects.count())
|
||||
@@ -269,6 +270,14 @@ class ContextManagersTest(unittest.TestCase):
|
||||
counter += 1
|
||||
self.assertEqual(q, counter)
|
||||
|
||||
self.assertEqual(int(q), counter) # test __int__
|
||||
self.assertEqual(repr(q), str(int(q))) # test __repr__
|
||||
self.assertGreater(q, -1) # test __gt__
|
||||
self.assertGreaterEqual(q, int(q)) # test __gte__
|
||||
self.assertNotEqual(q, -1)
|
||||
self.assertLess(q, 1000)
|
||||
self.assertLessEqual(q, int(q))
|
||||
|
||||
def test_query_counter_counts_getmore_queries(self):
|
||||
connect('mongoenginetest')
|
||||
db = get_db()
|
||||
|
@@ -1,4 +1,5 @@
|
||||
import unittest
|
||||
from six import iterkeys
|
||||
|
||||
from mongoengine import Document
|
||||
from mongoengine.base.datastructures import StrictDict, BaseList, BaseDict
|
||||
@@ -368,6 +369,20 @@ class TestStrictDict(unittest.TestCase):
|
||||
d = self.dtype(a=1, b=1, c=1)
|
||||
self.assertEqual((d.a, d.b, d.c), (1, 1, 1))
|
||||
|
||||
def test_iterkeys(self):
|
||||
d = self.dtype(a=1)
|
||||
self.assertEqual(list(iterkeys(d)), ['a'])
|
||||
|
||||
def test_len(self):
|
||||
d = self.dtype(a=1)
|
||||
self.assertEqual(len(d), 1)
|
||||
|
||||
def test_pop(self):
|
||||
d = self.dtype(a=1)
|
||||
self.assertIn('a', d)
|
||||
d.pop('a')
|
||||
self.assertNotIn('a', d)
|
||||
|
||||
def test_repr(self):
|
||||
d = self.dtype(a=1, b=2, c=3)
|
||||
self.assertEqual(repr(d), '{"a": 1, "b": 2, "c": 3}')
|
||||
|
@@ -105,6 +105,14 @@ class FieldTest(unittest.TestCase):
|
||||
|
||||
[m for m in group_obj.members]
|
||||
self.assertEqual(q, 2)
|
||||
self.assertTrue(group_obj._data['members']._dereferenced)
|
||||
|
||||
# verifies that no additional queries gets executed
|
||||
# if we re-iterate over the ListField once it is
|
||||
# dereferenced
|
||||
[m for m in group_obj.members]
|
||||
self.assertEqual(q, 2)
|
||||
self.assertTrue(group_obj._data['members']._dereferenced)
|
||||
|
||||
# Document select_related
|
||||
with query_counter() as q:
|
||||
@@ -125,6 +133,46 @@ class FieldTest(unittest.TestCase):
|
||||
[m for m in group_obj.members]
|
||||
self.assertEqual(q, 2)
|
||||
|
||||
def test_list_item_dereference_orphan_dbref(self):
|
||||
"""Ensure that orphan DBRef items in ListFields are dereferenced.
|
||||
"""
|
||||
class User(Document):
|
||||
name = StringField()
|
||||
|
||||
class Group(Document):
|
||||
members = ListField(ReferenceField(User, dbref=False))
|
||||
|
||||
User.drop_collection()
|
||||
Group.drop_collection()
|
||||
|
||||
for i in range(1, 51):
|
||||
user = User(name='user %s' % i)
|
||||
user.save()
|
||||
|
||||
group = Group(members=User.objects)
|
||||
group.save()
|
||||
group.reload() # Confirm reload works
|
||||
|
||||
# Delete one User so one of the references in the
|
||||
# Group.members list is an orphan DBRef
|
||||
User.objects[0].delete()
|
||||
with query_counter() as q:
|
||||
self.assertEqual(q, 0)
|
||||
|
||||
group_obj = Group.objects.first()
|
||||
self.assertEqual(q, 1)
|
||||
|
||||
[m for m in group_obj.members]
|
||||
self.assertEqual(q, 2)
|
||||
self.assertTrue(group_obj._data['members']._dereferenced)
|
||||
|
||||
# verifies that no additional queries gets executed
|
||||
# if we re-iterate over the ListField once it is
|
||||
# dereferenced
|
||||
[m for m in group_obj.members]
|
||||
self.assertEqual(q, 2)
|
||||
self.assertTrue(group_obj._data['members']._dereferenced)
|
||||
|
||||
User.drop_collection()
|
||||
Group.drop_collection()
|
||||
|
||||
@@ -505,6 +553,61 @@ class FieldTest(unittest.TestCase):
|
||||
for m in group_obj.members:
|
||||
self.assertIn('User', m.__class__.__name__)
|
||||
|
||||
|
||||
def test_generic_reference_orphan_dbref(self):
|
||||
"""Ensure that generic orphan DBRef items in ListFields are dereferenced.
|
||||
"""
|
||||
|
||||
class UserA(Document):
|
||||
name = StringField()
|
||||
|
||||
class UserB(Document):
|
||||
name = StringField()
|
||||
|
||||
class UserC(Document):
|
||||
name = StringField()
|
||||
|
||||
class Group(Document):
|
||||
members = ListField(GenericReferenceField())
|
||||
|
||||
UserA.drop_collection()
|
||||
UserB.drop_collection()
|
||||
UserC.drop_collection()
|
||||
Group.drop_collection()
|
||||
|
||||
members = []
|
||||
for i in range(1, 51):
|
||||
a = UserA(name='User A %s' % i)
|
||||
a.save()
|
||||
|
||||
b = UserB(name='User B %s' % i)
|
||||
b.save()
|
||||
|
||||
c = UserC(name='User C %s' % i)
|
||||
c.save()
|
||||
|
||||
members += [a, b, c]
|
||||
|
||||
group = Group(members=members)
|
||||
group.save()
|
||||
|
||||
# Delete one UserA instance so that there is
|
||||
# an orphan DBRef in the GenericReference ListField
|
||||
UserA.objects[0].delete()
|
||||
with query_counter() as q:
|
||||
self.assertEqual(q, 0)
|
||||
|
||||
group_obj = Group.objects.first()
|
||||
self.assertEqual(q, 1)
|
||||
|
||||
[m for m in group_obj.members]
|
||||
self.assertEqual(q, 4)
|
||||
self.assertTrue(group_obj._data['members']._dereferenced)
|
||||
|
||||
[m for m in group_obj.members]
|
||||
self.assertEqual(q, 4)
|
||||
self.assertTrue(group_obj._data['members']._dereferenced)
|
||||
|
||||
UserA.drop_collection()
|
||||
UserB.drop_collection()
|
||||
UserC.drop_collection()
|
||||
|
@@ -1,23 +1,16 @@
|
||||
import unittest
|
||||
|
||||
from pymongo import ReadPreference
|
||||
|
||||
from mongoengine.pymongo_support import IS_PYMONGO_3
|
||||
|
||||
if IS_PYMONGO_3:
|
||||
from pymongo import MongoClient
|
||||
CONN_CLASS = MongoClient
|
||||
READ_PREF = ReadPreference.SECONDARY
|
||||
else:
|
||||
from pymongo import ReplicaSetConnection
|
||||
CONN_CLASS = ReplicaSetConnection
|
||||
READ_PREF = ReadPreference.SECONDARY_ONLY
|
||||
from pymongo import MongoClient
|
||||
|
||||
import mongoengine
|
||||
from mongoengine import *
|
||||
from mongoengine.connection import MongoEngineConnectionError
|
||||
|
||||
|
||||
CONN_CLASS = MongoClient
|
||||
READ_PREF = ReadPreference.SECONDARY
|
||||
|
||||
|
||||
class ConnectionTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
@@ -35,7 +28,7 @@ class ConnectionTest(unittest.TestCase):
|
||||
"""
|
||||
|
||||
try:
|
||||
conn = connect(db='mongoenginetest',
|
||||
conn = mongoengine.connect(db='mongoenginetest',
|
||||
host="mongodb://localhost/mongoenginetest?replicaSet=rs",
|
||||
read_preference=READ_PREF)
|
||||
except MongoEngineConnectionError as e:
|
||||
|
@@ -227,6 +227,9 @@ class SignalTests(unittest.TestCase):
|
||||
|
||||
self.ExplicitId.objects.delete()
|
||||
|
||||
# Note that there is a chance that the following assert fails in case
|
||||
# some receivers (eventually created in other tests)
|
||||
# gets garbage collected (https://pythonhosted.org/blinker/#blinker.base.Signal.connect)
|
||||
self.assertEqual(self.pre_signals, post_signals)
|
||||
|
||||
def test_model_signals(self):
|
||||
|
@@ -4,9 +4,8 @@ import unittest
|
||||
from nose.plugins.skip import SkipTest
|
||||
|
||||
from mongoengine import connect
|
||||
from mongoengine.connection import get_db
|
||||
from mongoengine.mongodb_support import get_mongodb_version, MONGODB_26, MONGODB_3, MONGODB_32, MONGODB_34
|
||||
from mongoengine.pymongo_support import IS_PYMONGO_3
|
||||
from mongoengine.connection import get_db, disconnect_all
|
||||
from mongoengine.mongodb_support import get_mongodb_version
|
||||
|
||||
|
||||
MONGO_TEST_DB = 'mongoenginetest' # standard name for the test database
|
||||
@@ -19,6 +18,7 @@ class MongoDBTestCase(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
disconnect_all()
|
||||
cls._connection = connect(db=MONGO_TEST_DB)
|
||||
cls._connection.drop_database(MONGO_TEST_DB)
|
||||
cls.db = get_db()
|
||||
@@ -26,6 +26,7 @@ class MongoDBTestCase(unittest.TestCase):
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
cls._connection.drop_database(MONGO_TEST_DB)
|
||||
disconnect_all()
|
||||
|
||||
|
||||
def get_as_pymongo(doc):
|
||||
@@ -34,8 +35,20 @@ def get_as_pymongo(doc):
|
||||
|
||||
|
||||
def _decorated_with_ver_requirement(func, mongo_version_req, oper):
|
||||
"""Return a given function decorated with the version requirement
|
||||
for a particular MongoDB version tuple.
|
||||
"""Return a MongoDB version requirement decorator.
|
||||
|
||||
The resulting decorator will raise a SkipTest exception if the current
|
||||
MongoDB version doesn't match the provided version/operator.
|
||||
|
||||
For example, if you define a decorator like so:
|
||||
|
||||
def requires_mongodb_gte_36(func):
|
||||
return _decorated_with_ver_requirement(
|
||||
func, (3.6), oper=operator.ge
|
||||
)
|
||||
|
||||
Then tests decorated with @requires_mongodb_gte_36 will be skipped if
|
||||
ran against MongoDB < v3.6.
|
||||
|
||||
:param mongo_version_req: The mongodb version requirement (tuple(int, int))
|
||||
:param oper: The operator to apply (e.g: operator.ge)
|
||||
@@ -50,47 +63,3 @@ def _decorated_with_ver_requirement(func, mongo_version_req, oper):
|
||||
_inner.__name__ = func.__name__
|
||||
_inner.__doc__ = func.__doc__
|
||||
return _inner
|
||||
|
||||
|
||||
def requires_mongodb_gte_34(func):
|
||||
"""Raise a SkipTest exception if we're working with MongoDB version
|
||||
lower than v3.4
|
||||
"""
|
||||
return _decorated_with_ver_requirement(func, MONGODB_34, oper=operator.ge)
|
||||
|
||||
|
||||
def requires_mongodb_lte_32(func):
|
||||
"""Raise a SkipTest exception if we're working with MongoDB version
|
||||
greater than v3.2.
|
||||
"""
|
||||
return _decorated_with_ver_requirement(func, MONGODB_32, oper=operator.le)
|
||||
|
||||
|
||||
def requires_mongodb_gte_26(func):
|
||||
"""Raise a SkipTest exception if we're working with MongoDB version
|
||||
lower than v2.6.
|
||||
"""
|
||||
return _decorated_with_ver_requirement(func, MONGODB_26, oper=operator.ge)
|
||||
|
||||
|
||||
def requires_mongodb_gte_3(func):
|
||||
"""Raise a SkipTest exception if we're working with MongoDB version
|
||||
lower than v3.0.
|
||||
"""
|
||||
return _decorated_with_ver_requirement(func, MONGODB_3, oper=operator.ge)
|
||||
|
||||
|
||||
def skip_pymongo3(f):
|
||||
"""Raise a SkipTest exception if we're running a test against
|
||||
PyMongo v3.x.
|
||||
"""
|
||||
def _inner(*args, **kwargs):
|
||||
if IS_PYMONGO_3:
|
||||
raise SkipTest("Useless with PyMongo 3+")
|
||||
return f(*args, **kwargs)
|
||||
|
||||
_inner.__name__ = f.__name__
|
||||
_inner.__doc__ = f.__doc__
|
||||
|
||||
return _inner
|
||||
|
||||
|
Reference in New Issue
Block a user