diff --git a/.install_mongodb_on_travis.sh b/.install_mongodb_on_travis.sh index 057ccf74..0be02655 100644 --- a/.install_mongodb_on_travis.sh +++ b/.install_mongodb_on_travis.sh @@ -3,24 +3,20 @@ sudo apt-get remove mongodb-org-server sudo apt-key adv --keyserver hkp://keyserver.ubuntu.com:80 --recv 7F0CEB10 -if [ "$MONGODB" = "2.6" ]; then - echo "deb http://downloads-distro.mongodb.org/repo/ubuntu-upstart dist 10gen" | sudo tee /etc/apt/sources.list.d/mongodb.list +if [ "$MONGODB" = "3.4" ]; then + sudo apt-key adv --keyserver keyserver.ubuntu.com:80 --recv 0C49F3730359A14518585931BC711F9BA15703C6 + echo "deb http://repo.mongodb.org/apt/ubuntu trusty/mongodb-org/3.4 multiverse" | sudo tee /etc/apt/sources.list.d/mongodb-org-3.4.list sudo apt-get update - sudo apt-get install mongodb-org-server=2.6.12 + sudo apt-get install mongodb-org-server=3.4.17 # service should be started automatically -elif [ "$MONGODB" = "3.0" ]; then - echo "deb http://repo.mongodb.org/apt/ubuntu trusty/mongodb-org/3.0 multiverse" | sudo tee /etc/apt/sources.list.d/mongodb.list +elif [ "$MONGODB" = "3.6" ]; then + sudo apt-key adv --keyserver hkp://keyserver.ubuntu.com:80 --recv 2930ADAE8CAF5059EE73BB4B58712A2291FA4AD5 + echo "deb http://repo.mongodb.org/apt/ubuntu trusty/mongodb-org/3.6 multiverse" | sudo tee /etc/apt/sources.list.d/mongodb-org-3.6.list sudo apt-get update - sudo apt-get install mongodb-org-server=3.0.14 - # service should be started automatically -elif [ "$MONGODB" = "3.2" ]; then - sudo apt-key adv --keyserver keyserver.ubuntu.com --recv EA312927 - echo "deb http://repo.mongodb.org/apt/ubuntu trusty/mongodb-org/3.2 multiverse" | sudo tee /etc/apt/sources.list.d/mongodb-org-3.2.list - sudo apt-get update - sudo apt-get install mongodb-org-server=3.2.20 + sudo apt-get install mongodb-org-server=3.6.12 # service should be started automatically else - echo "Invalid MongoDB version, expected 2.6, 3.0, or 3.2" + echo "Invalid MongoDB version, expected 2.6, 3.0, 3.2, 3.4 or 3.6." exit 1 fi; diff --git a/.travis.yml b/.travis.yml index 4f77f4e0..3186ea1c 100644 --- a/.travis.yml +++ b/.travis.yml @@ -2,10 +2,18 @@ # PyMongo combinations. However, that would result in an overly long build # with a very large number of jobs, hence we only test a subset of all the # combinations: -# * MongoDB v2.6 is currently the "main" version tested against Python v2.7, -# v3.5, v3.6, PyPy, and PyMongo v3.x. -# * MongoDB v3.0 & v3.2 are tested against Python v2.7, v3.5 & v3.6 -# and Pymongo v3.5 & v3.x +# * MongoDB v3.4 & the latest PyMongo v3.x is currently the "main" setup, +# tested against Python v2.7, v3.5, v3.6, and PyPy. +# * Besides that, we test the lowest actively supported Python/MongoDB/PyMongo +# combination: MongoDB v3.4, PyMongo v3.4, Python v2.7. +# * MongoDB v3.6 is tested against Python v3.6, and PyMongo v3.6, v3.7, v3.8. +# +# We should periodically check MongoDB Server versions supported by MongoDB +# Inc., add newly released versions to the test matrix, and remove versions +# which have reached their End of Life. See: +# 1. https://www.mongodb.com/support-policy. +# 2. https://docs.mongodb.com/ecosystem/drivers/driver-compatibility-reference/#python-driver-compatibility +# # Reminder: Update README.rst if you change MongoDB versions we test. language: python @@ -17,7 +25,7 @@ python: - pypy env: -- MONGODB=2.6 PYMONGO=3.x +- MONGODB=3.4 PYMONGO=3.x matrix: # Finish the build as soon as one job fails @@ -25,17 +33,9 @@ matrix: include: - python: 2.7 - env: MONGODB=3.0 PYMONGO=3.5 - - python: 2.7 - env: MONGODB=3.2 PYMONGO=3.x - - python: 3.5 - env: MONGODB=3.0 PYMONGO=3.5 - - python: 3.5 - env: MONGODB=3.2 PYMONGO=3.x + env: MONGODB=3.4 PYMONGO=3.4.x - python: 3.6 - env: MONGODB=3.0 PYMONGO=3.5 - - python: 3.6 - env: MONGODB=3.2 PYMONGO=3.x + env: MONGODB=3.6 PYMONGO=3.x before_install: - bash .install_mongodb_on_travis.sh @@ -49,8 +49,8 @@ install: - travis_retry pip install --upgrade pip - travis_retry pip install coveralls - travis_retry pip install flake8 flake8-import-order -- travis_retry pip install tox>=1.9 -- travis_retry pip install "virtualenv<14.0.0" # virtualenv>=14.0.0 has dropped Python 3.2 support (and pypy3 is based on py32) +- travis_retry pip install "tox" # tox 3.11.0 has requirement virtualenv>=14.0.0 +- travis_retry pip install "virtualenv" # virtualenv>=14.0.0 has dropped Python 3.2 support (and pypy3 is based on py32) - travis_retry tox -e $(echo py$TRAVIS_PYTHON_VERSION-mg$PYMONGO | tr -d . | sed -e 's/pypypy/pypy/') -- -e test # Cache dependencies installed via pip @@ -85,15 +85,15 @@ deploy: password: secure: QMyatmWBnC6ZN3XLW2+fTBDU4LQcp1m/LjR2/0uamyeUzWKdlOoh/Wx5elOgLwt/8N9ppdPeG83ose1jOz69l5G0MUMjv8n/RIcMFSpCT59tGYqn3kh55b0cIZXFT9ar+5cxlif6a5rS72IHm5li7QQyxexJIII6Uxp0kpvUmek= - # create a source distribution and a pure python wheel for faster installs + # Create a source distribution and a pure python wheel for faster installs. distributions: "sdist bdist_wheel" - # only deploy on tagged commits (aka GitHub releases) and only for the - # parent repo's builds running Python 2.7 along with PyMongo v3.x (we run - # Travis against many different Python and PyMongo versions and we don't - # want the deploy to occur multiple times). + # Only deploy on tagged commits (aka GitHub releases) and only for the parent + # repo's builds running Python v2.7 along with PyMongo v3.x and MongoDB v3.4. + # We run Travis against many different Python, PyMongo, and MongoDB versions + # and we don't want the deploy to occur multiple times). on: tags: true repo: MongoEngine/mongoengine - condition: "$PYMONGO = 3.x" + condition: ($PYMONGO = 3.x) AND ($MONGODB = 3.4) python: 2.7 diff --git a/AUTHORS b/AUTHORS index 880dfad1..45a754cc 100644 --- a/AUTHORS +++ b/AUTHORS @@ -248,4 +248,7 @@ that much better: * Andy Yankovsky (https://github.com/werat) * Bastien Gérard (https://github.com/bagerard) * Trevor Hall (https://github.com/tjhall13) - * Gleb Voropaev (https://github.com/buggyspace) \ No newline at end of file + * Gleb Voropaev (https://github.com/buggyspace) + * Paulo Amaral (https://github.com/pauloAmaral) + * Gaurav Dadhania (https://github.com/GVRV) + * Yurii Andrieiev (https://github.com/yandrieiev) diff --git a/README.rst b/README.rst index 4e186a85..fe5f5f22 100644 --- a/README.rst +++ b/README.rst @@ -26,17 +26,17 @@ an `API reference `_. Supported MongoDB Versions ========================== -MongoEngine is currently tested against MongoDB v2.6, v3.0 and v3.2. Future -versions should be supported as well, but aren't actively tested at the moment. -Make sure to open an issue or submit a pull request if you experience any -problems with MongoDB v3.4+. +MongoEngine is currently tested against MongoDB v3.4 and v3.6. Future versions +should be supported as well, but aren't actively tested at the moment. Make +sure to open an issue or submit a pull request if you experience any problems +with MongoDB version > 3.6. Installation ============ We recommend the use of `virtualenv `_ and of `pip `_. You can then use ``pip install -U mongoengine``. You may also have `setuptools `_ -and thus you can use ``easy_install -U mongoengine``. Another option is +and thus you can use ``easy_install -U mongoengine``. Another option is `pipenv `_. You can then use ``pipenv install mongoengine`` to both create the virtual environment and install the package. Otherwise, you can download the source from `GitHub `_ and @@ -47,7 +47,7 @@ Dependencies All of the dependencies can easily be installed via `pip `_. At the very least, you'll need these two packages to use MongoEngine: -- pymongo>=2.7.1 +- pymongo>=3.5 - six>=1.10.0 If you utilize a ``DateTimeField``, you might also use a more flexible date parser: diff --git a/docs/changelog.rst b/docs/changelog.rst index dbd328d8..9ef90afe 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -4,11 +4,37 @@ Changelog Development =========== +- Drop support for EOL'd MongoDB v2.6, v3.0, and v3.2. +- MongoEngine now requires PyMongo >= v3.4. Travis CI now tests against MongoDB v3.4 – v3.6 and PyMongo v3.4 – v3.6 (#2017 #2066). +- Improve performance by avoiding a call to `to_mongo` in `Document.save()` #2049 +- Connection/disconnection improvements: + - Expose `mongoengine.connection.disconnect` and `mongoengine.connection.disconnect_all` + - Fix disconnecting #566 #1599 #605 #607 #1213 #565 + - Improve documentation of `connect`/`disconnect` + - Fix issue when using multiple connections to the same mongo with different credentials #2047 + - `connect` fails immediately when db name contains invalid characters #2031 #1718 +- Fix the default write concern of `Document.save` that was overwriting the connection write concern #568 +- Fix querying on `List(EmbeddedDocument)` subclasses fields #1961 #1492 +- Fix querying on `(Generic)EmbeddedDocument` subclasses fields #475 +- Generate unique indices for `SortedListField` and `EmbeddedDocumentListFields` #2020 +- BREAKING CHANGE: Changed the behavior of a custom field validator (i.e `validation` parameter of a `Field`). It is now expected to raise a `ValidationError` instead of returning True/False #2050 +- BREAKING CHANGE: `QuerySet.aggregate` now takes limit and skip value into account #2029 +- BREAKING CHANGES (associated with connect/disconnect fixes): + - Calling `connect` 2 times with the same alias and different parameter will raise an error (should call `disconnect` first). + - `disconnect` now clears `mongoengine.connection._connection_settings`. + - `disconnect` now clears the cached attribute `Document._collection`. - (Fill this out as you fix issues and develop your features). + +Changes in 0.17.0 +================= - Fix .only() working improperly after using .count() of the same instance of QuerySet +- Fix batch_size that was not copied when cloning a queryset object #2011 - POTENTIAL BREAKING CHANGE: All result fields are now passed, including internal fields (_cls, _id) when using `QuerySet.as_pymongo` #1976 +- Document a BREAKING CHANGE introduced in 0.15.3 and not reported at that time (#1995) - Fix InvalidStringData error when using modify on a BinaryField #1127 - DEPRECATION: `EmbeddedDocument.save` & `.reload` are marked as deprecated and will be removed in a next version of mongoengine #1552 +- Fix test suite and CI to support MongoDB 3.4 #1445 +- Fix reference fields querying the database on each access if value contains orphan DBRefs ================= Changes in 0.16.3 @@ -66,6 +92,7 @@ Changes in 0.16.0 Changes in 0.15.3 ================= +- BREAKING CHANGES: `Queryset.update/update_one` methods now returns an UpdateResult when `full_result=True` is provided and no longer a dict (relates to #1491) - Subfield resolve error in generic_emdedded_document query #1651 #1652 - use each modifier only with $position #1673 #1675 - Improve LazyReferenceField and GenericLazyReferenceField with nested fields #1704 diff --git a/docs/guide/connecting.rst b/docs/guide/connecting.rst index 5dac6ae9..aac13902 100644 --- a/docs/guide/connecting.rst +++ b/docs/guide/connecting.rst @@ -4,9 +4,11 @@ Connecting to MongoDB ===================== -To connect to a running instance of :program:`mongod`, use the -:func:`~mongoengine.connect` function. The first argument is the name of the -database to connect to:: +Connections in MongoEngine are registered globally and are identified with aliases. +If no `alias` is provided during the connection, it will use "default" as alias. + +To connect to a running instance of :program:`mongod`, use the :func:`~mongoengine.connect` +function. The first argument is the name of the database to connect to:: from mongoengine import connect connect('project1') @@ -42,6 +44,9 @@ the :attr:`host` to will establish connection to ``production`` database using ``admin`` username and ``qwerty`` password. +.. note:: Calling :func:`~mongoengine.connect` without argument will establish + a connection to the "test" database by default + Replica Sets ============ @@ -71,28 +76,61 @@ is used. In the background this uses :func:`~mongoengine.register_connection` to store the data and you can register all aliases up front if required. -Individual documents can also support multiple databases by providing a +Documents defined in different database +--------------------------------------- +Individual documents can be attached to different databases by providing a `db_alias` in their meta data. This allows :class:`~pymongo.dbref.DBRef` objects to point across databases and collections. Below is an example schema, using 3 different databases to store data:: + connect(alias='user-db-alias', db='user-db') + connect(alias='book-db-alias', db='book-db') + connect(alias='users-books-db-alias', db='users-books-db') + class User(Document): name = StringField() - meta = {'db_alias': 'user-db'} + meta = {'db_alias': 'user-db-alias'} class Book(Document): name = StringField() - meta = {'db_alias': 'book-db'} + meta = {'db_alias': 'book-db-alias'} class AuthorBooks(Document): author = ReferenceField(User) book = ReferenceField(Book) - meta = {'db_alias': 'users-books-db'} + meta = {'db_alias': 'users-books-db-alias'} +Disconnecting an existing connection +------------------------------------ +The function :func:`~mongoengine.disconnect` can be used to +disconnect a particular connection. This can be used to change a +connection globally:: + + from mongoengine import connect, disconnect + connect('a_db', alias='db1') + + class User(Document): + name = StringField() + meta = {'db_alias': 'db1'} + + disconnect(alias='db1') + + connect('another_db', alias='db1') + +.. note:: Calling :func:`~mongoengine.disconnect` without argument + will disconnect the "default" connection + +.. note:: Since connections gets registered globally, it is important + to use the `disconnect` function from MongoEngine and not the + `disconnect()` method of an existing connection (pymongo.MongoClient) + +.. note:: :class:`~mongoengine.Document` are caching the pymongo collection. + using `disconnect` ensures that it gets cleaned as well + Context Managers ================ Sometimes you may want to switch the database or collection to query against. @@ -119,7 +157,7 @@ access to the same User document across databases:: Switch Collection ----------------- -The :class:`~mongoengine.context_managers.switch_collection` context manager +The :func:`~mongoengine.context_managers.switch_collection` context manager allows you to change the collection for a given class allowing quick and easy access to the same Group document across collection:: diff --git a/docs/guide/defining-documents.rst b/docs/guide/defining-documents.rst index 911de36d..ae9d3b36 100644 --- a/docs/guide/defining-documents.rst +++ b/docs/guide/defining-documents.rst @@ -176,6 +176,21 @@ arguments can be set on all fields: class Shirt(Document): size = StringField(max_length=3, choices=SIZE) +:attr:`validation` (Optional) + A callable to validate the value of the field. + The callable takes the value as parameter and should raise a ValidationError + if validation fails + + e.g :: + + def _not_empty(val): + if not val: + raise ValidationError('value can not be empty') + + class Person(Document): + name = StringField(validation=_not_empty) + + :attr:`**kwargs` (Optional) You can supply additional metadata as arbitrary additional keyword arguments. You can not override existing attributes, however. Common diff --git a/docs/guide/mongomock.rst b/docs/guide/mongomock.rst index 1d5227ec..d70ee6a6 100644 --- a/docs/guide/mongomock.rst +++ b/docs/guide/mongomock.rst @@ -19,3 +19,30 @@ or with an alias: connect('mongoenginetest', host='mongomock://localhost', alias='testdb') conn = get_connection('testdb') + +Example of test file: +-------- +.. code-block:: python + + import unittest + from mongoengine import connect, disconnect + + class Person(Document): + name = StringField() + + class TestPerson(unittest.TestCase): + + @classmethod + def setUpClass(cls): + connect('mongoenginetest', host='mongomock://localhost') + + @classmethod + def tearDownClass(cls): + disconnect() + + def test_thing(self): + pers = Person(name='John') + pers.save() + + fresh_pers = Person.objects().first() + self.assertEqual(fresh_pers.name, 'John') diff --git a/docs/guide/querying.rst b/docs/guide/querying.rst index 08987835..151855a6 100644 --- a/docs/guide/querying.rst +++ b/docs/guide/querying.rst @@ -64,7 +64,7 @@ Available operators are as follows: * ``gt`` -- greater than * ``gte`` -- greater than or equal to * ``not`` -- negate a standard check, may be used before other operators (e.g. - ``Q(age__not__mod=5)``) + ``Q(age__not__mod=(5, 0))``) * ``in`` -- value is in list (a list of values should be provided) * ``nin`` -- value is not in list (a list of values should be provided) * ``mod`` -- ``value % x == y``, where ``x`` and ``y`` are two provided values diff --git a/mongoengine/__init__.py b/mongoengine/__init__.py index 2b78d4e6..b94efab9 100644 --- a/mongoengine/__init__.py +++ b/mongoengine/__init__.py @@ -23,7 +23,7 @@ __all__ = (list(document.__all__) + list(fields.__all__) + list(signals.__all__) + list(errors.__all__)) -VERSION = (0, 16, 3) +VERSION = (0, 17, 0) def get_version(): diff --git a/mongoengine/base/common.py b/mongoengine/base/common.py index d747c8cc..999fd23a 100644 --- a/mongoengine/base/common.py +++ b/mongoengine/base/common.py @@ -13,7 +13,7 @@ _document_registry = {} def get_document(name): - """Get a document class by name.""" + """Get a registered Document class by name.""" doc = _document_registry.get(name, None) if not doc: # Possible old style name @@ -30,3 +30,12 @@ def get_document(name): been imported? """.strip() % name) return doc + + +def _get_documents_by_db(connection_alias, default_connection_alias): + """Get all registered Documents class attached to a given database""" + def get_doc_alias(doc_cls): + return doc_cls._meta.get('db_alias', default_connection_alias) + + return [doc_cls for doc_cls in _document_registry.values() + if get_doc_alias(doc_cls) == connection_alias] diff --git a/mongoengine/base/datastructures.py b/mongoengine/base/datastructures.py index 808332b9..fafc08b7 100644 --- a/mongoengine/base/datastructures.py +++ b/mongoengine/base/datastructures.py @@ -2,6 +2,7 @@ import weakref from bson import DBRef import six +from six import iteritems from mongoengine.common import _import_class from mongoengine.errors import DoesNotExist, MultipleObjectsReturned @@ -363,7 +364,7 @@ class StrictDict(object): _classes = {} def __init__(self, **kwargs): - for k, v in kwargs.iteritems(): + for k, v in iteritems(kwargs): setattr(self, k, v) def __getitem__(self, key): @@ -411,7 +412,7 @@ class StrictDict(object): return (key for key in self.__slots__ if hasattr(self, key)) def __len__(self): - return len(list(self.iteritems())) + return len(list(iteritems(self))) def __eq__(self, other): return self.items() == other.items() diff --git a/mongoengine/base/document.py b/mongoengine/base/document.py index 6a4c6bd9..2e8dd9f1 100644 --- a/mongoengine/base/document.py +++ b/mongoengine/base/document.py @@ -5,6 +5,7 @@ from functools import partial from bson import DBRef, ObjectId, SON, json_util import pymongo import six +from six import iteritems from mongoengine import signals from mongoengine.base.common import get_document @@ -83,7 +84,7 @@ class BaseDocument(object): self._dynamic_fields = SON() # Assign default values to instance - for key, field in self._fields.iteritems(): + for key, field in iteritems(self._fields): if self._db_field_map.get(key, key) in __only_fields: continue value = getattr(self, key, None) @@ -95,14 +96,14 @@ class BaseDocument(object): # Set passed values after initialisation if self._dynamic: dynamic_data = {} - for key, value in values.iteritems(): + for key, value in iteritems(values): if key in self._fields or key == '_id': setattr(self, key, value) else: dynamic_data[key] = value else: FileField = _import_class('FileField') - for key, value in values.iteritems(): + for key, value in iteritems(values): key = self._reverse_db_field_map.get(key, key) if key in self._fields or key in ('id', 'pk', '_cls'): if __auto_convert and value is not None: @@ -118,7 +119,7 @@ class BaseDocument(object): if self._dynamic: self._dynamic_lock = False - for key, value in dynamic_data.iteritems(): + for key, value in iteritems(dynamic_data): setattr(self, key, value) # Flag initialised @@ -292,8 +293,7 @@ class BaseDocument(object): """ Return as SON data ready for use with MongoDB. """ - if not fields: - fields = [] + fields = fields or [] data = SON() data['_id'] = None @@ -513,7 +513,7 @@ class BaseDocument(object): if not hasattr(data, 'items'): iterator = enumerate(data) else: - iterator = data.iteritems() + iterator = iteritems(data) for index_or_key, value in iterator: item_key = '%s%s.' % (base_key, index_or_key) @@ -678,7 +678,7 @@ class BaseDocument(object): # Convert SON to a data dict, making sure each key is a string and # corresponds to the right db field. data = {} - for key, value in son.iteritems(): + for key, value in iteritems(son): key = str(key) key = cls._db_field_map.get(key, key) data[key] = value @@ -694,7 +694,7 @@ class BaseDocument(object): if not _auto_dereference: fields = copy.deepcopy(fields) - for field_name, field in fields.iteritems(): + for field_name, field in iteritems(fields): field._auto_dereference = _auto_dereference if field.db_field in data: value = data[field.db_field] @@ -715,7 +715,7 @@ class BaseDocument(object): # In STRICT documents, remove any keys that aren't in cls._fields if cls.STRICT: - data = {k: v for k, v in data.iteritems() if k in cls._fields} + data = {k: v for k, v in iteritems(data) if k in cls._fields} obj = cls(__auto_convert=False, _created=created, __only_fields=only_fields, **data) obj._changed_fields = changed_fields @@ -882,7 +882,8 @@ class BaseDocument(object): index = {'fields': fields, 'unique': True, 'sparse': sparse} unique_indexes.append(index) - if field.__class__.__name__ == 'ListField': + if field.__class__.__name__ in {'EmbeddedDocumentListField', + 'ListField', 'SortedListField'}: field = field.field # Grab any embedded document field unique indexes diff --git a/mongoengine/base/fields.py b/mongoengine/base/fields.py index a32544d8..fe96f15b 100644 --- a/mongoengine/base/fields.py +++ b/mongoengine/base/fields.py @@ -5,13 +5,13 @@ import weakref from bson import DBRef, ObjectId, SON import pymongo import six +from six import iteritems from mongoengine.base.common import UPDATE_OPERATORS from mongoengine.base.datastructures import (BaseDict, BaseList, EmbeddedDocumentList) from mongoengine.common import _import_class -from mongoengine.errors import ValidationError - +from mongoengine.errors import DeprecatedError, ValidationError __all__ = ('BaseField', 'ComplexBaseField', 'ObjectIdField', 'GeoJsonBaseField') @@ -52,8 +52,8 @@ class BaseField(object): unique with. :param primary_key: Mark this field as the primary key. Defaults to False. :param validation: (optional) A callable to validate the value of the - field. Generally this is deprecated in favour of the - `FIELD.validate` method + field. The callable takes the value as parameter and should raise + a ValidationError if validation fails :param choices: (optional) The valid choices :param null: (optional) If the field value can be null. If no and there is a default value then the default value is set @@ -225,10 +225,18 @@ class BaseField(object): # check validation argument if self.validation is not None: if callable(self.validation): - if not self.validation(value): - self.error('Value does not match custom validation method') + try: + # breaking change of 0.18 + # Get rid of True/False-type return for the validation method + # in favor of having validation raising a ValidationError + ret = self.validation(value) + if ret is not None: + raise DeprecatedError('validation argument for `%s` must not return anything, ' + 'it should raise a ValidationError if validation fails' % self.name) + except ValidationError as ex: + self.error(str(ex)) else: - raise ValueError('validation argument for "%s" must be a ' + raise ValueError('validation argument for `"%s"` must be a ' 'callable.' % self.name) self.validate(value, **kwargs) @@ -275,11 +283,16 @@ class ComplexBaseField(BaseField): _dereference = _import_class('DeReference')() - if instance._initialised and dereference and instance._data.get(self.name): + if (instance._initialised and + dereference and + instance._data.get(self.name) and + not getattr(instance._data[self.name], '_dereferenced', False)): instance._data[self.name] = _dereference( instance._data.get(self.name), max_depth=1, instance=instance, name=self.name ) + if hasattr(instance._data[self.name], '_dereferenced'): + instance._data[self.name]._dereferenced = True value = super(ComplexBaseField, self).__get__(instance, owner) @@ -382,11 +395,11 @@ class ComplexBaseField(BaseField): if self.field: value_dict = { key: self.field._to_mongo_safe_call(item, use_db_field, fields) - for key, item in value.iteritems() + for key, item in iteritems(value) } else: value_dict = {} - for k, v in value.iteritems(): + for k, v in iteritems(value): if isinstance(v, Document): # We need the id from the saved object to create the DBRef if v.pk is None: @@ -423,7 +436,7 @@ class ComplexBaseField(BaseField): errors = {} if self.field: if hasattr(value, 'iteritems') or hasattr(value, 'items'): - sequence = value.iteritems() + sequence = iteritems(value) else: sequence = enumerate(value) for k, v in sequence: diff --git a/mongoengine/base/metaclasses.py b/mongoengine/base/metaclasses.py index 8eb10008..6f507eaa 100644 --- a/mongoengine/base/metaclasses.py +++ b/mongoengine/base/metaclasses.py @@ -1,6 +1,7 @@ import warnings import six +from six import iteritems, itervalues from mongoengine.base.common import _document_registry from mongoengine.base.fields import BaseField, ComplexBaseField, ObjectIdField @@ -62,7 +63,7 @@ class DocumentMetaclass(type): # Standard object mixin - merge in any Fields if not hasattr(base, '_meta'): base_fields = {} - for attr_name, attr_value in base.__dict__.iteritems(): + for attr_name, attr_value in iteritems(base.__dict__): if not isinstance(attr_value, BaseField): continue attr_value.name = attr_name @@ -74,7 +75,7 @@ class DocumentMetaclass(type): # Discover any document fields field_names = {} - for attr_name, attr_value in attrs.iteritems(): + for attr_name, attr_value in iteritems(attrs): if not isinstance(attr_value, BaseField): continue attr_value.name = attr_name @@ -103,7 +104,7 @@ class DocumentMetaclass(type): attrs['_fields_ordered'] = tuple(i[1] for i in sorted( (v.creation_counter, v.name) - for v in doc_fields.itervalues())) + for v in itervalues(doc_fields))) # # Set document hierarchy @@ -173,7 +174,7 @@ class DocumentMetaclass(type): f.__dict__.update({'im_self': getattr(f, '__self__')}) # Handle delete rules - for field in new_class._fields.itervalues(): + for field in itervalues(new_class._fields): f = field if f.owner_document is None: f.owner_document = new_class @@ -183,9 +184,6 @@ class DocumentMetaclass(type): if issubclass(new_class, EmbeddedDocument): raise InvalidDocumentError('CachedReferenceFields is not ' 'allowed in EmbeddedDocuments') - if not f.document_type: - raise InvalidDocumentError( - 'Document is not available to sync') if f.auto_sync: f.start_listener() @@ -375,7 +373,7 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): new_class.objects = QuerySetManager() # Validate the fields and set primary key if needed - for field_name, field in new_class._fields.iteritems(): + for field_name, field in iteritems(new_class._fields): if field.primary_key: # Ensure only one primary key is set current_pk = new_class._meta.get('id_field') @@ -438,7 +436,7 @@ class MetaDict(dict): _merge_options = ('indexes',) def merge(self, new_options): - for k, v in new_options.iteritems(): + for k, v in iteritems(new_options): if k in self._merge_options: self[k] = self.get(k, []) + v else: diff --git a/mongoengine/common.py b/mongoengine/common.py index bde7e78c..bcdea194 100644 --- a/mongoengine/common.py +++ b/mongoengine/common.py @@ -31,7 +31,6 @@ def _import_class(cls_name): field_classes = _field_list_cache - queryset_classes = ('OperationError',) deref_classes = ('DeReference',) if cls_name == 'BaseDocument': @@ -43,14 +42,11 @@ def _import_class(cls_name): elif cls_name in field_classes: from mongoengine import fields as module import_classes = field_classes - elif cls_name in queryset_classes: - from mongoengine import queryset as module - import_classes = queryset_classes elif cls_name in deref_classes: from mongoengine import dereference as module import_classes = deref_classes else: - raise ValueError('No import set for: ' % cls_name) + raise ValueError('No import set for: %s' % cls_name) for cls in import_classes: _class_registry_cache[cls] = getattr(module, cls) diff --git a/mongoengine/connection.py b/mongoengine/connection.py index 38ebb243..9d4f25fc 100644 --- a/mongoengine/connection.py +++ b/mongoengine/connection.py @@ -1,19 +1,22 @@ from pymongo import MongoClient, ReadPreference, uri_parser +from pymongo.database import _check_name import six -from mongoengine.python_support import IS_PYMONGO_3 - -__all__ = ['MongoEngineConnectionError', 'connect', 'register_connection', - 'DEFAULT_CONNECTION_NAME'] +__all__ = ['MongoEngineConnectionError', 'connect', 'disconnect', 'disconnect_all', + 'register_connection', 'DEFAULT_CONNECTION_NAME', 'DEFAULT_DATABASE_NAME', + 'get_db', 'get_connection'] DEFAULT_CONNECTION_NAME = 'default' +DEFAULT_DATABASE_NAME = 'test' +DEFAULT_HOST = 'localhost' +DEFAULT_PORT = 27017 -if IS_PYMONGO_3: - READ_PREFERENCE = ReadPreference.PRIMARY -else: - from pymongo import MongoReplicaSetClient - READ_PREFERENCE = False +_connection_settings = {} +_connections = {} +_dbs = {} + +READ_PREFERENCE = ReadPreference.PRIMARY class MongoEngineConnectionError(Exception): @@ -23,45 +26,48 @@ class MongoEngineConnectionError(Exception): pass -_connection_settings = {} -_connections = {} -_dbs = {} +def _check_db_name(name): + """Check if a database name is valid. + This functionality is copied from pymongo Database class constructor. + """ + if not isinstance(name, six.string_types): + raise TypeError('name must be an instance of %s' % six.string_types) + elif name != '$external': + _check_name(name) -def register_connection(alias, db=None, name=None, host=None, port=None, - read_preference=READ_PREFERENCE, - username=None, password=None, - authentication_source=None, - authentication_mechanism=None, - **kwargs): - """Add a connection. +def _get_connection_settings( + db=None, name=None, host=None, port=None, + read_preference=READ_PREFERENCE, + username=None, password=None, + authentication_source=None, + authentication_mechanism=None, + **kwargs): + """Get the connection settings as a dict - :param alias: the name that will be used to refer to this connection - throughout MongoEngine - :param name: the name of the specific database to use - :param db: the name of the database to use, for compatibility with connect - :param host: the host name of the :program:`mongod` instance to connect to - :param port: the port that the :program:`mongod` instance is running on - :param read_preference: The read preference for the collection - ** Added pymongo 2.1 - :param username: username to authenticate with - :param password: password to authenticate with - :param authentication_source: database to authenticate against - :param authentication_mechanism: database authentication mechanisms. + : param db: the name of the database to use, for compatibility with connect + : param name: the name of the specific database to use + : param host: the host name of the: program: `mongod` instance to connect to + : param port: the port that the: program: `mongod` instance is running on + : param read_preference: The read preference for the collection + : param username: username to authenticate with + : param password: password to authenticate with + : param authentication_source: database to authenticate against + : param authentication_mechanism: database authentication mechanisms. By default, use SCRAM-SHA-1 with MongoDB 3.0 and later, MONGODB-CR (MongoDB Challenge Response protocol) for older servers. - :param is_mock: explicitly use mongomock for this connection - (can also be done by using `mongomock://` as db host prefix) - :param kwargs: ad-hoc parameters to be passed into the pymongo driver, + : param is_mock: explicitly use mongomock for this connection + (can also be done by using `mongomock: // ` as db host prefix) + : param kwargs: ad-hoc parameters to be passed into the pymongo driver, for example maxpoolsize, tz_aware, etc. See the documentation for pymongo's `MongoClient` for a full list. .. versionchanged:: 0.10.6 - added mongomock support """ conn_settings = { - 'name': name or db or 'test', - 'host': host or 'localhost', - 'port': port or 27017, + 'name': name or db or DEFAULT_DATABASE_NAME, + 'host': host or DEFAULT_HOST, + 'port': port or DEFAULT_PORT, 'read_preference': read_preference, 'username': username, 'password': password, @@ -69,6 +75,7 @@ def register_connection(alias, db=None, name=None, host=None, port=None, 'authentication_mechanism': authentication_mechanism } + _check_db_name(conn_settings['name']) conn_host = conn_settings['host'] # Host can be a list or a string, so if string, force to a list. @@ -104,16 +111,28 @@ def register_connection(alias, db=None, name=None, host=None, port=None, conn_settings['authentication_source'] = uri_options['authsource'] if 'authmechanism' in uri_options: conn_settings['authentication_mechanism'] = uri_options['authmechanism'] - if IS_PYMONGO_3 and 'readpreference' in uri_options: + if 'readpreference' in uri_options: read_preferences = ( ReadPreference.NEAREST, ReadPreference.PRIMARY, ReadPreference.PRIMARY_PREFERRED, ReadPreference.SECONDARY, - ReadPreference.SECONDARY_PREFERRED) - read_pf_mode = uri_options['readpreference'].lower() + ReadPreference.SECONDARY_PREFERRED, + ) + + # Starting with PyMongo v3.5, the "readpreference" option is + # returned as a string (e.g. "secondaryPreferred") and not an + # int (e.g. 3). + # TODO simplify the code below once we drop support for + # PyMongo v3.4. + read_pf_mode = uri_options['readpreference'] + if isinstance(read_pf_mode, six.string_types): + read_pf_mode = read_pf_mode.lower() for preference in read_preferences: - if preference.name.lower() == read_pf_mode: + if ( + preference.name.lower() == read_pf_mode or + preference.mode == read_pf_mode + ): conn_settings['read_preference'] = preference break else: @@ -125,17 +144,74 @@ def register_connection(alias, db=None, name=None, host=None, port=None, kwargs.pop('is_slave', None) conn_settings.update(kwargs) + return conn_settings + + +def register_connection(alias, db=None, name=None, host=None, port=None, + read_preference=READ_PREFERENCE, + username=None, password=None, + authentication_source=None, + authentication_mechanism=None, + **kwargs): + """Register the connection settings. + + : param alias: the name that will be used to refer to this connection + throughout MongoEngine + : param name: the name of the specific database to use + : param db: the name of the database to use, for compatibility with connect + : param host: the host name of the: program: `mongod` instance to connect to + : param port: the port that the: program: `mongod` instance is running on + : param read_preference: The read preference for the collection + : param username: username to authenticate with + : param password: password to authenticate with + : param authentication_source: database to authenticate against + : param authentication_mechanism: database authentication mechanisms. + By default, use SCRAM-SHA-1 with MongoDB 3.0 and later, + MONGODB-CR (MongoDB Challenge Response protocol) for older servers. + : param is_mock: explicitly use mongomock for this connection + (can also be done by using `mongomock: // ` as db host prefix) + : param kwargs: ad-hoc parameters to be passed into the pymongo driver, + for example maxpoolsize, tz_aware, etc. See the documentation + for pymongo's `MongoClient` for a full list. + + .. versionchanged:: 0.10.6 - added mongomock support + """ + conn_settings = _get_connection_settings( + db=db, name=name, host=host, port=port, + read_preference=read_preference, + username=username, password=password, + authentication_source=authentication_source, + authentication_mechanism=authentication_mechanism, + **kwargs) _connection_settings[alias] = conn_settings def disconnect(alias=DEFAULT_CONNECTION_NAME): """Close the connection with a given alias.""" + from mongoengine.base.common import _get_documents_by_db + from mongoengine import Document + if alias in _connections: get_connection(alias=alias).close() del _connections[alias] + if alias in _dbs: + # Detach all cached collections in Documents + for doc_cls in _get_documents_by_db(alias, DEFAULT_CONNECTION_NAME): + if issubclass(doc_cls, Document): # Skip EmbeddedDocument + doc_cls._disconnect() + del _dbs[alias] + if alias in _connection_settings: + del _connection_settings[alias] + + +def disconnect_all(): + """Close all registered database.""" + for alias in list(_connections.keys()): + disconnect(alias) + def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): """Return a connection with a given alias.""" @@ -159,7 +235,6 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): raise MongoEngineConnectionError(msg) def _clean_settings(settings_dict): - # set literal more efficient than calling set function irrelevant_fields_set = { 'name', 'username', 'password', 'authentication_source', 'authentication_mechanism' @@ -169,10 +244,12 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): if k not in irrelevant_fields_set } + raw_conn_settings = _connection_settings[alias].copy() + # Retrieve a copy of the connection settings associated with the requested # alias and remove the database name and authentication info (we don't # care about them at this point). - conn_settings = _clean_settings(_connection_settings[alias].copy()) + conn_settings = _clean_settings(raw_conn_settings) # Determine if we should use PyMongo's or mongomock's MongoClient. is_mock = conn_settings.pop('is_mock', False) @@ -186,51 +263,60 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): else: connection_class = MongoClient - # For replica set connections with PyMongo 2.x, use - # MongoReplicaSetClient. - # TODO remove this once we stop supporting PyMongo 2.x. - if 'replicaSet' in conn_settings and not IS_PYMONGO_3: - connection_class = MongoReplicaSetClient - conn_settings['hosts_or_uri'] = conn_settings.pop('host', None) - - # hosts_or_uri has to be a string, so if 'host' was provided - # as a list, join its parts and separate them by ',' - if isinstance(conn_settings['hosts_or_uri'], list): - conn_settings['hosts_or_uri'] = ','.join( - conn_settings['hosts_or_uri']) - - # Discard port since it can't be used on MongoReplicaSetClient - conn_settings.pop('port', None) - - # Iterate over all of the connection settings and if a connection with - # the same parameters is already established, use it instead of creating - # a new one. - existing_connection = None - connection_settings_iterator = ( - (db_alias, settings.copy()) - for db_alias, settings in _connection_settings.items() - ) - for db_alias, connection_settings in connection_settings_iterator: - connection_settings = _clean_settings(connection_settings) - if conn_settings == connection_settings and _connections.get(db_alias): - existing_connection = _connections[db_alias] - break + # Re-use existing connection if one is suitable + existing_connection = _find_existing_connection(raw_conn_settings) # If an existing connection was found, assign it to the new alias if existing_connection: _connections[alias] = existing_connection else: - # Otherwise, create the new connection for this alias. Raise - # MongoEngineConnectionError if it can't be established. - try: - _connections[alias] = connection_class(**conn_settings) - except Exception as e: - raise MongoEngineConnectionError( - 'Cannot connect to database %s :\n%s' % (alias, e)) + _connections[alias] = _create_connection(alias=alias, + connection_class=connection_class, + **conn_settings) return _connections[alias] +def _create_connection(alias, connection_class, **connection_settings): + """ + Create the new connection for this alias. Raise + MongoEngineConnectionError if it can't be established. + """ + try: + return connection_class(**connection_settings) + except Exception as e: + raise MongoEngineConnectionError( + 'Cannot connect to database %s :\n%s' % (alias, e)) + + +def _find_existing_connection(connection_settings): + """ + Check if an existing connection could be reused + + Iterate over all of the connection settings and if an existing connection + with the same parameters is suitable, return it + + :param connection_settings: the settings of the new connection + :return: An existing connection or None + """ + connection_settings_bis = ( + (db_alias, settings.copy()) + for db_alias, settings in _connection_settings.items() + ) + + def _clean_settings(settings_dict): + # Only remove the name but it's important to + # keep the username/password/authentication_source/authentication_mechanism + # to identify if the connection could be shared (cfr https://github.com/MongoEngine/mongoengine/issues/2047) + return {k: v for k, v in settings_dict.items() if k != 'name'} + + cleaned_conn_settings = _clean_settings(connection_settings) + for db_alias, connection_settings in connection_settings_bis: + db_conn_settings = _clean_settings(connection_settings) + if cleaned_conn_settings == db_conn_settings and _connections.get(db_alias): + return _connections[db_alias] + + def get_db(alias=DEFAULT_CONNECTION_NAME, reconnect=False): if reconnect: disconnect(alias) @@ -258,14 +344,24 @@ def connect(db=None, alias=DEFAULT_CONNECTION_NAME, **kwargs): provide username and password arguments as well. Multiple databases are supported by using aliases. Provide a separate - `alias` to connect to a different instance of :program:`mongod`. + `alias` to connect to a different instance of: program: `mongod`. + + In order to replace a connection identified by a given alias, you'll + need to call ``disconnect`` first See the docstring for `register_connection` for more details about all supported kwargs. .. versionchanged:: 0.6 - added multiple database support. """ - if alias not in _connections: + if alias in _connections: + prev_conn_setting = _connection_settings[alias] + new_conn_settings = _get_connection_settings(db, **kwargs) + + if new_conn_settings != prev_conn_setting: + raise MongoEngineConnectionError( + 'A different connection with alias `%s` was already registered. Use disconnect() first' % alias) + else: register_connection(alias, db, **kwargs) return get_connection(alias) diff --git a/mongoengine/context_managers.py b/mongoengine/context_managers.py index c26b0a79..98bd897b 100644 --- a/mongoengine/context_managers.py +++ b/mongoengine/context_managers.py @@ -1,8 +1,11 @@ from contextlib import contextmanager + from pymongo.write_concern import WriteConcern +from six import iteritems + from mongoengine.common import _import_class from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db - +from mongoengine.pymongo_support import count_documents __all__ = ('switch_db', 'switch_collection', 'no_dereference', 'no_sub_classes', 'query_counter', 'set_write_concern') @@ -112,7 +115,7 @@ class no_dereference(object): GenericReferenceField = _import_class('GenericReferenceField') ComplexBaseField = _import_class('ComplexBaseField') - self.deref_fields = [k for k, v in self.cls._fields.iteritems() + self.deref_fields = [k for k, v in iteritems(self.cls._fields) if isinstance(v, (ReferenceField, GenericReferenceField, ComplexBaseField))] @@ -235,7 +238,7 @@ class query_counter(object): and substracting the queries issued by this context. In fact everytime this is called, 1 query is issued so we need to balance that """ - count = self.db.system.profile.find(self._ignored_query).count() - self._ctx_query_counter + count = count_documents(self.db.system.profile, self._ignored_query) - self._ctx_query_counter self._ctx_query_counter += 1 # Account for the query we just issued to gather the information return count diff --git a/mongoengine/dereference.py b/mongoengine/dereference.py index 619b5d1f..eaebb56f 100644 --- a/mongoengine/dereference.py +++ b/mongoengine/dereference.py @@ -1,5 +1,6 @@ from bson import DBRef, SON import six +from six import iteritems from mongoengine.base import (BaseDict, BaseList, EmbeddedDocumentList, TopLevelDocumentMetaclass, get_document) @@ -71,7 +72,7 @@ class DeReference(object): def _get_items_from_dict(items): new_items = {} - for k, v in items.iteritems(): + for k, v in iteritems(items): value = v if isinstance(v, list): value = _get_items_from_list(v) @@ -112,7 +113,7 @@ class DeReference(object): depth += 1 for item in iterator: if isinstance(item, (Document, EmbeddedDocument)): - for field_name, field in item._fields.iteritems(): + for field_name, field in iteritems(item._fields): v = item._data.get(field_name, None) if isinstance(v, LazyReference): # LazyReference inherits DBRef but should not be dereferenced here ! @@ -124,7 +125,7 @@ class DeReference(object): elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth: field_cls = getattr(getattr(field, 'field', None), 'document_type', None) references = self._find_references(v, depth) - for key, refs in references.iteritems(): + for key, refs in iteritems(references): if isinstance(field_cls, (Document, TopLevelDocumentMetaclass)): key = field_cls reference_map.setdefault(key, set()).update(refs) @@ -137,7 +138,7 @@ class DeReference(object): reference_map.setdefault(get_document(item['_cls']), set()).add(item['_ref'].id) elif isinstance(item, (dict, list, tuple)) and depth - 1 <= self.max_depth: references = self._find_references(item, depth - 1) - for key, refs in references.iteritems(): + for key, refs in iteritems(references): reference_map.setdefault(key, set()).update(refs) return reference_map @@ -146,7 +147,7 @@ class DeReference(object): """Fetch all references and convert to their document objects """ object_map = {} - for collection, dbrefs in self.reference_map.iteritems(): + for collection, dbrefs in iteritems(self.reference_map): # we use getattr instead of hasattr because hasattr swallows any exception under python2 # so it could hide nasty things without raising exceptions (cfr bug #1688)) @@ -157,7 +158,7 @@ class DeReference(object): refs = [dbref for dbref in dbrefs if (col_name, dbref) not in object_map] references = collection.objects.in_bulk(refs) - for key, doc in references.iteritems(): + for key, doc in iteritems(references): object_map[(col_name, key)] = doc else: # Generic reference: use the refs data to convert to document if isinstance(doc_type, (ListField, DictField, MapField)): @@ -229,7 +230,7 @@ class DeReference(object): data = [] else: is_list = False - iterator = items.iteritems() + iterator = iteritems(items) data = {} depth += 1 diff --git a/mongoengine/document.py b/mongoengine/document.py index 4f91401c..341a41ba 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -5,6 +5,7 @@ from bson.dbref import DBRef import pymongo from pymongo.read_preferences import ReadPreference import six +from six import iteritems from mongoengine import signals from mongoengine.base import (BaseDict, BaseDocument, BaseList, @@ -17,7 +18,7 @@ from mongoengine.context_managers import (set_write_concern, switch_db) from mongoengine.errors import (InvalidDocumentError, InvalidQueryError, SaveConditionError) -from mongoengine.python_support import IS_PYMONGO_3 +from mongoengine.pymongo_support import list_collection_names from mongoengine.queryset import (NotUniqueError, OperationError, QuerySet, transform) @@ -175,10 +176,16 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): return get_db(cls._meta.get('db_alias', DEFAULT_CONNECTION_NAME)) @classmethod - def _get_collection(cls): - """Return a PyMongo collection for the document.""" - if not hasattr(cls, '_collection') or cls._collection is None: + def _disconnect(cls): + """Detach the Document class from the (cached) database collection""" + cls._collection = None + @classmethod + def _get_collection(cls): + """Return the corresponding PyMongo collection of this document. + Upon the first call, it will ensure that indexes gets created. The returned collection then gets cached + """ + if not hasattr(cls, '_collection') or cls._collection is None: # Get the collection, either capped or regular. if cls._meta.get('max_size') or cls._meta.get('max_documents'): cls._collection = cls._get_capped_collection() @@ -215,7 +222,7 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): # If the collection already exists and has different options # (i.e. isn't capped or has different max/size), raise an error. - if collection_name in db.collection_names(): + if collection_name in list_collection_names(db, include_system_collections=True): collection = db[collection_name] options = collection.options() if ( @@ -240,7 +247,7 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): data = super(Document, self).to_mongo(*args, **kwargs) # If '_id' is None, try and set it from self._data. If that - # doesn't exist either, remote '_id' from the SON completely. + # doesn't exist either, remove '_id' from the SON completely. if data['_id'] is None: if self._data.get('id') is None: del data['_id'] @@ -346,21 +353,21 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): .. versionchanged:: 0.10.7 Add signal_kwargs argument """ + signal_kwargs = signal_kwargs or {} + if self._meta.get('abstract'): raise InvalidDocumentError('Cannot save an abstract document.') - signal_kwargs = signal_kwargs or {} signals.pre_save.send(self.__class__, document=self, **signal_kwargs) if validate: self.validate(clean=clean) if write_concern is None: - write_concern = {'w': 1} + write_concern = {} - doc = self.to_mongo() - - created = ('_id' not in doc or self._created or force_insert) + doc_id = self.to_mongo(fields=['id']) + created = ('_id' not in doc_id or self._created or force_insert) signals.pre_save_post_validation.send(self.__class__, document=self, created=created, **signal_kwargs) @@ -438,16 +445,6 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): object_id = wc_collection.insert_one(doc).inserted_id - # In PyMongo 3.0, the save() call calls internally the _update() call - # but they forget to return the _id value passed back, therefore getting it back here - # Correct behaviour in 2.X and in 3.0.1+ versions - if not object_id and pymongo.version_tuple == (3, 0): - pk_as_mongo_obj = self._fields.get(self._meta['id_field']).to_mongo(self.pk) - object_id = ( - self._qs.filter(pk=pk_as_mongo_obj).first() and - self._qs.filter(pk=pk_as_mongo_obj).first().pk - ) # TODO doesn't this make 2 queries? - return object_id def _get_update_doc(self): @@ -493,8 +490,12 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): update_doc = self._get_update_doc() if update_doc: upsert = save_condition is None - last_error = collection.update(select_dict, update_doc, - upsert=upsert, **write_concern) + with set_write_concern(collection, write_concern) as wc_collection: + last_error = wc_collection.update_one( + select_dict, + update_doc, + upsert=upsert + ).raw_result if not upsert and last_error['n'] == 0: raise SaveConditionError('Race condition preventing' ' document update detected') @@ -601,7 +602,7 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): # Delete FileFields separately FileField = _import_class('FileField') - for name, field in self._fields.iteritems(): + for name, field in iteritems(self._fields): if isinstance(field, FileField): getattr(self, name).delete() @@ -786,13 +787,13 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): .. versionchanged:: 0.10.7 :class:`OperationError` exception raised if no collection available """ - col_name = cls._get_collection_name() - if not col_name: + coll_name = cls._get_collection_name() + if not coll_name: raise OperationError('Document %s has no collection defined ' '(is it abstract ?)' % cls) cls._collection = None db = cls._get_db() - db.drop_collection(col_name) + db.drop_collection(coll_name) @classmethod def create_index(cls, keys, background=False, **kwargs): @@ -807,18 +808,13 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): index_spec = index_spec.copy() fields = index_spec.pop('fields') drop_dups = kwargs.get('drop_dups', False) - if IS_PYMONGO_3 and drop_dups: + if drop_dups: msg = 'drop_dups is deprecated and is removed when using PyMongo 3+.' warnings.warn(msg, DeprecationWarning) - elif not IS_PYMONGO_3: - index_spec['drop_dups'] = drop_dups index_spec['background'] = background index_spec.update(kwargs) - if IS_PYMONGO_3: - return cls._get_collection().create_index(fields, **index_spec) - else: - return cls._get_collection().ensure_index(fields, **index_spec) + return cls._get_collection().create_index(fields, **index_spec) @classmethod def ensure_index(cls, key_or_list, drop_dups=False, background=False, @@ -833,11 +829,9 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): :param drop_dups: Was removed/ignored with MongoDB >2.7.5. The value will be removed if PyMongo3+ is used """ - if IS_PYMONGO_3 and drop_dups: + if drop_dups: msg = 'drop_dups is deprecated and is removed when using PyMongo 3+.' warnings.warn(msg, DeprecationWarning) - elif not IS_PYMONGO_3: - kwargs.update({'drop_dups': drop_dups}) return cls.create_index(key_or_list, background=background, **kwargs) @classmethod @@ -853,7 +847,7 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): drop_dups = cls._meta.get('index_drop_dups', False) index_opts = cls._meta.get('index_opts') or {} index_cls = cls._meta.get('index_cls', True) - if IS_PYMONGO_3 and drop_dups: + if drop_dups: msg = 'drop_dups is deprecated and is removed when using PyMongo 3+.' warnings.warn(msg, DeprecationWarning) @@ -884,11 +878,7 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): if 'cls' in opts: del opts['cls'] - if IS_PYMONGO_3: - collection.create_index(fields, background=background, **opts) - else: - collection.ensure_index(fields, background=background, - drop_dups=drop_dups, **opts) + collection.create_index(fields, background=background, **opts) # If _cls is being used (for polymorphism), it needs an index, # only if another index doesn't begin with _cls @@ -899,12 +889,8 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): if 'cls' in index_opts: del index_opts['cls'] - if IS_PYMONGO_3: - collection.create_index('_cls', background=background, - **index_opts) - else: - collection.ensure_index('_cls', background=background, - **index_opts) + collection.create_index('_cls', background=background, + **index_opts) @classmethod def list_indexes(cls): diff --git a/mongoengine/errors.py b/mongoengine/errors.py index 986ebf73..bea1d3dc 100644 --- a/mongoengine/errors.py +++ b/mongoengine/errors.py @@ -1,11 +1,12 @@ from collections import defaultdict import six +from six import iteritems __all__ = ('NotRegistered', 'InvalidDocumentError', 'LookUpError', 'DoesNotExist', 'MultipleObjectsReturned', 'InvalidQueryError', 'OperationError', 'NotUniqueError', 'FieldDoesNotExist', - 'ValidationError', 'SaveConditionError') + 'ValidationError', 'SaveConditionError', 'DeprecatedError') class NotRegistered(Exception): @@ -109,11 +110,8 @@ class ValidationError(AssertionError): def build_dict(source): errors_dict = {} - if not source: - return errors_dict - if isinstance(source, dict): - for field_name, error in source.iteritems(): + for field_name, error in iteritems(source): errors_dict[field_name] = build_dict(error) elif isinstance(source, ValidationError) and source.errors: return build_dict(source.errors) @@ -135,12 +133,17 @@ class ValidationError(AssertionError): value = ' '.join([generate_key(k) for k in value]) elif isinstance(value, dict): value = ' '.join( - [generate_key(v, k) for k, v in value.iteritems()]) + [generate_key(v, k) for k, v in iteritems(value)]) results = '%s.%s' % (prefix, value) if prefix else value return results error_dict = defaultdict(list) - for k, v in self.to_dict().iteritems(): + for k, v in iteritems(self.to_dict()): error_dict[generate_key(v)].append(k) - return ' '.join(['%s: %s' % (k, v) for k, v in error_dict.iteritems()]) + return ' '.join(['%s: %s' % (k, v) for k, v in iteritems(error_dict)]) + + +class DeprecatedError(Exception): + """Raise when a user uses a feature that has been Deprecated""" + pass diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 0055bcab..aa5aa805 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -11,6 +11,7 @@ from bson import Binary, DBRef, ObjectId, SON import gridfs import pymongo import six +from six import iteritems try: import dateutil @@ -36,6 +37,7 @@ from mongoengine.errors import DoesNotExist, InvalidQueryError, ValidationError from mongoengine.python_support import StringIO from mongoengine.queryset import DO_NOTHING from mongoengine.queryset.base import BaseQuerySet +from mongoengine.queryset.transform import STRING_OPERATORS try: from PIL import Image, ImageOps @@ -105,11 +107,11 @@ class StringField(BaseField): if not isinstance(op, six.string_types): return value - if op.lstrip('i') in ('startswith', 'endswith', 'contains', 'exact'): - flags = 0 - if op.startswith('i'): - flags = re.IGNORECASE - op = op.lstrip('i') + if op in STRING_OPERATORS: + case_insensitive = op.startswith('i') + op = op.lstrip('i') + + flags = re.IGNORECASE if case_insensitive else 0 regex = r'%s' if op == 'startswith': @@ -151,12 +153,10 @@ class URLField(StringField): scheme = value.split('://')[0].lower() if scheme not in self.schemes: self.error(u'Invalid scheme {} in URL: {}'.format(scheme, value)) - return # Then check full URL if not self.url_regex.match(value): self.error(u'Invalid URL: {}'.format(value)) - return class EmailField(StringField): @@ -258,10 +258,10 @@ class EmailField(StringField): try: domain_part = domain_part.encode('idna').decode('ascii') except UnicodeError: - self.error(self.error_msg % value) + self.error("%s %s" % (self.error_msg % value, "(domain failed IDN encoding)")) else: if not self.validate_domain_part(domain_part): - self.error(self.error_msg % value) + self.error("%s %s" % (self.error_msg % value, "(domain validation failed)")) class IntField(BaseField): @@ -498,15 +498,18 @@ class DateTimeField(BaseField): if not isinstance(value, six.string_types): return None + return self._parse_datetime(value) + + def _parse_datetime(self, value): + # Attempt to parse a datetime from a string value = value.strip() if not value: return None - # Attempt to parse a datetime: if dateutil: try: return dateutil.parser.parse(value) - except (TypeError, ValueError): + except (TypeError, ValueError, OverflowError): return None # split usecs, because they are not recognized by strptime. @@ -699,7 +702,11 @@ class EmbeddedDocumentField(BaseField): self.document_type.validate(value, clean) def lookup_member(self, member_name): - return self.document_type._fields.get(member_name) + doc_and_subclasses = [self.document_type] + self.document_type.__subclasses__() + for doc_type in doc_and_subclasses: + field = doc_type._fields.get(member_name) + if field: + return field def prepare_query_value(self, op, value): if value is not None and not isinstance(value, self.document_type): @@ -746,12 +753,13 @@ class GenericEmbeddedDocumentField(BaseField): value.validate(clean=clean) def lookup_member(self, member_name): - if self.choices: - for choice in self.choices: - field = choice._fields.get(member_name) + document_choices = self.choices or [] + for document_choice in document_choices: + doc_and_subclasses = [document_choice] + document_choice.__subclasses__() + for doc_type in doc_and_subclasses: + field = doc_type._fields.get(member_name) if field: return field - return None def to_mongo(self, document, use_db_field=True, fields=None): if document is None: @@ -794,12 +802,12 @@ class DynamicField(BaseField): value = {k: v for k, v in enumerate(value)} data = {} - for k, v in value.iteritems(): + for k, v in iteritems(value): data[k] = self.to_mongo(v, use_db_field, fields) value = data if is_list: # Convert back to a list - value = [v for k, v in sorted(data.iteritems(), key=itemgetter(0))] + value = [v for k, v in sorted(iteritems(data), key=itemgetter(0))] return value def to_python(self, value): diff --git a/mongoengine/mongodb_support.py b/mongoengine/mongodb_support.py new file mode 100644 index 00000000..b20ebc1e --- /dev/null +++ b/mongoengine/mongodb_support.py @@ -0,0 +1,19 @@ +""" +Helper functions, constants, and types to aid with MongoDB version support +""" +from mongoengine.connection import get_connection + + +# Constant that can be used to compare the version retrieved with +# get_mongodb_version() +MONGODB_34 = (3, 4) +MONGODB_36 = (3, 6) + + +def get_mongodb_version(): + """Return the version of the connected mongoDB (first 2 digits) + + :return: tuple(int, int) + """ + version_list = get_connection().server_info()['versionArray'][:2] # e.g: (3, 2) + return tuple(version_list) diff --git a/mongoengine/pymongo_support.py b/mongoengine/pymongo_support.py new file mode 100644 index 00000000..f66c038e --- /dev/null +++ b/mongoengine/pymongo_support.py @@ -0,0 +1,32 @@ +""" +Helper functions, constants, and types to aid with PyMongo v2.7 - v3.x support. +""" +import pymongo + +_PYMONGO_37 = (3, 7) + +PYMONGO_VERSION = tuple(pymongo.version_tuple[:2]) + +IS_PYMONGO_GTE_37 = PYMONGO_VERSION >= _PYMONGO_37 + + +def count_documents(collection, filter): + """Pymongo>3.7 deprecates count in favour of count_documents""" + if IS_PYMONGO_GTE_37: + return collection.count_documents(filter) + else: + count = collection.find(filter).count() + return count + + +def list_collection_names(db, include_system_collections=False): + """Pymongo>3.7 deprecates collection_names in favour of list_collection_names""" + if IS_PYMONGO_GTE_37: + collections = db.list_collection_names() + else: + collections = db.collection_names() + + if not include_system_collections: + collections = [c for c in collections if not c.startswith('system.')] + + return collections diff --git a/mongoengine/python_support.py b/mongoengine/python_support.py index 7e8e108f..57e467db 100644 --- a/mongoengine/python_support.py +++ b/mongoengine/python_support.py @@ -1,13 +1,8 @@ """ -Helper functions, constants, and types to aid with Python v2.7 - v3.x and -PyMongo v2.7 - v3.x support. +Helper functions, constants, and types to aid with Python v2.7 - v3.x support """ -import pymongo import six - -IS_PYMONGO_3 = pymongo.version_tuple[0] >= 3 - # six.BytesIO resolves to StringIO.StringIO in Py2 and io.BytesIO in Py3. StringIO = six.BytesIO diff --git a/mongoengine/queryset/base.py b/mongoengine/queryset/base.py index 0ebeafa6..c6244825 100644 --- a/mongoengine/queryset/base.py +++ b/mongoengine/queryset/base.py @@ -10,8 +10,10 @@ from bson import SON, json_util from bson.code import Code import pymongo import pymongo.errors +from pymongo.collection import ReturnDocument from pymongo.common import validate_read_preference import six +from six import iteritems from mongoengine import signals from mongoengine.base import get_document @@ -20,14 +22,10 @@ from mongoengine.connection import get_db from mongoengine.context_managers import set_write_concern, switch_db from mongoengine.errors import (InvalidQueryError, LookUpError, NotUniqueError, OperationError) -from mongoengine.python_support import IS_PYMONGO_3 from mongoengine.queryset import transform from mongoengine.queryset.field_list import QueryFieldList from mongoengine.queryset.visitor import Q, QNode -if IS_PYMONGO_3: - from pymongo.collection import ReturnDocument - __all__ = ('BaseQuerySet', 'DO_NOTHING', 'NULLIFY', 'CASCADE', 'DENY', 'PULL') @@ -196,7 +194,7 @@ class BaseQuerySet(object): only_fields=self.only_fields ) - raise AttributeError('Provide a slice or an integer index') + raise TypeError('Provide a slice or an integer index') def __iter__(self): raise NotImplementedError @@ -498,11 +496,12 @@ class BaseQuerySet(object): ``save(..., write_concern={w: 2, fsync: True}, ...)`` will wait until at least two servers have recorded the write and will force an fsync on the primary server. - :param full_result: Return the full result dictionary rather than just the number - updated, e.g. return - ``{'n': 2, 'nModified': 2, 'ok': 1.0, 'updatedExisting': True}``. + :param full_result: Return the associated ``pymongo.UpdateResult`` rather than just the number + updated items :param update: Django-style update keyword arguments + :returns the number of updated documents (unless ``full_result`` is True) + .. versionadded:: 0.2 """ if not update and not upsert: @@ -566,7 +565,7 @@ class BaseQuerySet(object): document = self._document.objects.with_id(atomic_update.upserted_id) return document - def update_one(self, upsert=False, write_concern=None, **update): + def update_one(self, upsert=False, write_concern=None, full_result=False, **update): """Perform an atomic update on the fields of the first document matched by the query. @@ -577,12 +576,19 @@ class BaseQuerySet(object): ``save(..., write_concern={w: 2, fsync: True}, ...)`` will wait until at least two servers have recorded the write and will force an fsync on the primary server. + :param full_result: Return the associated ``pymongo.UpdateResult`` rather than just the number + updated items :param update: Django-style update keyword arguments - + full_result + :returns the number of updated documents (unless ``full_result`` is True) .. versionadded:: 0.2 """ return self.update( - upsert=upsert, multi=False, write_concern=write_concern, **update) + upsert=upsert, + multi=False, + write_concern=write_concern, + full_result=full_result, + **update) def modify(self, upsert=False, full_response=False, remove=False, new=False, **update): """Update and return the updated document. @@ -617,31 +623,25 @@ class BaseQuerySet(object): queryset = self.clone() query = queryset._query - if not IS_PYMONGO_3 or not remove: + if not remove: update = transform.update(queryset._document, **update) sort = queryset._ordering try: - if IS_PYMONGO_3: - if full_response: - msg = 'With PyMongo 3+, it is not possible anymore to get the full response.' - warnings.warn(msg, DeprecationWarning) - if remove: - result = queryset._collection.find_one_and_delete( - query, sort=sort, **self._cursor_args) - else: - if new: - return_doc = ReturnDocument.AFTER - else: - return_doc = ReturnDocument.BEFORE - result = queryset._collection.find_one_and_update( - query, update, upsert=upsert, sort=sort, return_document=return_doc, - **self._cursor_args) - + if full_response: + msg = 'With PyMongo 3+, it is not possible anymore to get the full response.' + warnings.warn(msg, DeprecationWarning) + if remove: + result = queryset._collection.find_one_and_delete( + query, sort=sort, **self._cursor_args) else: - result = queryset._collection.find_and_modify( - query, update, upsert=upsert, sort=sort, remove=remove, new=new, - full_response=full_response, **self._cursor_args) + if new: + return_doc = ReturnDocument.AFTER + else: + return_doc = ReturnDocument.BEFORE + result = queryset._collection.find_one_and_update( + query, update, upsert=upsert, sort=sort, return_document=return_doc, + **self._cursor_args) except pymongo.errors.DuplicateKeyError as err: raise NotUniqueError(u'Update failed (%s)' % err) except pymongo.errors.OperationFailure as err: @@ -748,7 +748,7 @@ class BaseQuerySet(object): '_read_preference', '_iter', '_scalar', '_as_pymongo', '_limit', '_skip', '_hint', '_auto_dereference', '_search_text', 'only_fields', '_max_time_ms', - '_comment') + '_comment', '_batch_size') for prop in copy_props: val = getattr(self, prop) @@ -1073,15 +1073,14 @@ class BaseQuerySet(object): ..versionchanged:: 0.5 - made chainable .. deprecated:: Ignored with PyMongo 3+ """ - if IS_PYMONGO_3: - msg = 'snapshot is deprecated as it has no impact when using PyMongo 3+.' - warnings.warn(msg, DeprecationWarning) + msg = 'snapshot is deprecated as it has no impact when using PyMongo 3+.' + warnings.warn(msg, DeprecationWarning) queryset = self.clone() queryset._snapshot = enabled return queryset def timeout(self, enabled): - """Enable or disable the default mongod timeout when querying. + """Enable or disable the default mongod timeout when querying. (no_cursor_timeout option) :param enabled: whether or not the timeout is used @@ -1099,9 +1098,8 @@ class BaseQuerySet(object): .. deprecated:: Ignored with PyMongo 3+ """ - if IS_PYMONGO_3: - msg = 'slave_okay is deprecated as it has no impact when using PyMongo 3+.' - warnings.warn(msg, DeprecationWarning) + msg = 'slave_okay is deprecated as it has no impact when using PyMongo 3+.' + warnings.warn(msg, DeprecationWarning) queryset = self.clone() queryset._slave_okay = enabled return queryset @@ -1191,14 +1189,18 @@ class BaseQuerySet(object): initial_pipeline.append({'$sort': dict(self._ordering)}) if self._limit is not None: - initial_pipeline.append({'$limit': self._limit}) + # As per MongoDB Documentation (https://docs.mongodb.com/manual/reference/operator/aggregation/limit/), + # keeping limit stage right after sort stage is more efficient. But this leads to wrong set of documents + # for a skip stage that might succeed these. So we need to maintain more documents in memory in such a + # case (https://stackoverflow.com/a/24161461). + initial_pipeline.append({'$limit': self._limit + (self._skip or 0)}) if self._skip is not None: initial_pipeline.append({'$skip': self._skip}) pipeline = initial_pipeline + list(pipeline) - if IS_PYMONGO_3 and self._read_preference is not None: + if self._read_preference is not None: return self._collection.with_options(read_preference=self._read_preference) \ .aggregate(pipeline, cursor={}, **kwargs) @@ -1408,11 +1410,7 @@ class BaseQuerySet(object): if isinstance(field_instances[-1], ListField): pipeline.insert(1, {'$unwind': '$' + field}) - result = self._document._get_collection().aggregate(pipeline) - if IS_PYMONGO_3: - result = tuple(result) - else: - result = result.get('result') + result = tuple(self._document._get_collection().aggregate(pipeline)) if result: return result[0]['total'] @@ -1439,11 +1437,7 @@ class BaseQuerySet(object): if isinstance(field_instances[-1], ListField): pipeline.insert(1, {'$unwind': '$' + field}) - result = self._document._get_collection().aggregate(pipeline) - if IS_PYMONGO_3: - result = tuple(result) - else: - result = result.get('result') + result = tuple(self._document._get_collection().aggregate(pipeline)) if result: return result[0]['total'] return 0 @@ -1518,26 +1512,16 @@ class BaseQuerySet(object): @property def _cursor_args(self): - if not IS_PYMONGO_3: - fields_name = 'fields' - cursor_args = { - 'timeout': self._timeout, - 'snapshot': self._snapshot - } - if self._read_preference is not None: - cursor_args['read_preference'] = self._read_preference - else: - cursor_args['slave_okay'] = self._slave_okay - else: - fields_name = 'projection' - # snapshot is not handled at all by PyMongo 3+ - # TODO: evaluate similar possibilities using modifiers - if self._snapshot: - msg = 'The snapshot option is not anymore available with PyMongo 3+' - warnings.warn(msg, DeprecationWarning) - cursor_args = { - 'no_cursor_timeout': not self._timeout - } + fields_name = 'projection' + # snapshot is not handled at all by PyMongo 3+ + # TODO: evaluate similar possibilities using modifiers + if self._snapshot: + msg = 'The snapshot option is not anymore available with PyMongo 3+' + warnings.warn(msg, DeprecationWarning) + cursor_args = { + 'no_cursor_timeout': not self._timeout + } + if self._loaded_fields: cursor_args[fields_name] = self._loaded_fields.as_dict() @@ -1561,7 +1545,7 @@ class BaseQuerySet(object): # XXX In PyMongo 3+, we define the read preference on a collection # level, not a cursor level. Thus, we need to get a cloned collection # object using `with_options` first. - if IS_PYMONGO_3 and self._read_preference is not None: + if self._read_preference is not None: self._cursor_obj = self._collection\ .with_options(read_preference=self._read_preference)\ .find(self._query, **self._cursor_args) @@ -1731,13 +1715,13 @@ class BaseQuerySet(object): } """ total, data, types = self.exec_js(freq_func, field) - values = {types.get(k): int(v) for k, v in data.iteritems()} + values = {types.get(k): int(v) for k, v in iteritems(data)} if normalize: values = {k: float(v) / total for k, v in values.items()} frequencies = {} - for k, v in values.iteritems(): + for k, v in iteritems(values): if isinstance(k, float): if int(k) == k: k = int(k) diff --git a/mongoengine/queryset/transform.py b/mongoengine/queryset/transform.py index 2d22c350..128a4e44 100644 --- a/mongoengine/queryset/transform.py +++ b/mongoengine/queryset/transform.py @@ -4,12 +4,11 @@ from bson import ObjectId, SON from bson.dbref import DBRef import pymongo import six +from six import iteritems from mongoengine.base import UPDATE_OPERATORS from mongoengine.common import _import_class -from mongoengine.connection import get_connection from mongoengine.errors import InvalidQueryError -from mongoengine.python_support import IS_PYMONGO_3 __all__ = ('query', 'update') @@ -87,18 +86,10 @@ def query(_doc_cls=None, **kwargs): singular_ops = [None, 'ne', 'gt', 'gte', 'lt', 'lte', 'not'] singular_ops += STRING_OPERATORS if op in singular_ops: - if isinstance(field, six.string_types): - if (op in STRING_OPERATORS and - isinstance(value, six.string_types)): - StringField = _import_class('StringField') - value = StringField.prepare_query_value(op, value) - else: - value = field - else: - value = field.prepare_query_value(op, value) + value = field.prepare_query_value(op, value) - if isinstance(field, CachedReferenceField) and value: - value = value['_id'] + if isinstance(field, CachedReferenceField) and value: + value = value['_id'] elif op in ('in', 'nin', 'all', 'near') and not isinstance(value, dict): # Raise an error if the in/nin/all/near param is not iterable. @@ -154,7 +145,7 @@ def query(_doc_cls=None, **kwargs): if ('$maxDistance' in value_dict or '$minDistance' in value_dict) and \ ('$near' in value_dict or '$nearSphere' in value_dict): value_son = SON() - for k, v in value_dict.iteritems(): + for k, v in iteritems(value_dict): if k == '$maxDistance' or k == '$minDistance': continue value_son[k] = v @@ -162,16 +153,14 @@ def query(_doc_cls=None, **kwargs): # PyMongo 3+ and MongoDB < 2.6 near_embedded = False for near_op in ('$near', '$nearSphere'): - if isinstance(value_dict.get(near_op), dict) and ( - IS_PYMONGO_3 or get_connection().max_wire_version > 1): + if isinstance(value_dict.get(near_op), dict): value_son[near_op] = SON(value_son[near_op]) if '$maxDistance' in value_dict: - value_son[near_op][ - '$maxDistance'] = value_dict['$maxDistance'] + value_son[near_op]['$maxDistance'] = value_dict['$maxDistance'] if '$minDistance' in value_dict: - value_son[near_op][ - '$minDistance'] = value_dict['$minDistance'] + value_son[near_op]['$minDistance'] = value_dict['$minDistance'] near_embedded = True + if not near_embedded: if '$maxDistance' in value_dict: value_son['$maxDistance'] = value_dict['$maxDistance'] @@ -280,7 +269,7 @@ def update(_doc_cls=None, **update): if op == 'pull': if field.required or value is not None: - if match == 'in' and not isinstance(value, dict): + if match in ('in', 'nin') and not isinstance(value, dict): value = _prepare_query_for_iterable(field, op, value) else: value = field.prepare_query_value(op, value) @@ -307,10 +296,6 @@ def update(_doc_cls=None, **update): key = '.'.join(parts) - if not op: - raise InvalidQueryError('Updates must supply an operation ' - 'eg: set__FIELD=value') - if 'pull' in op and '.' in key: # Dot operators don't work on pull operations # unless they point to a list field diff --git a/requirements.txt b/requirements.txt index 4e3ea940..9bb319a5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ nose -pymongo>=2.7.1 +pymongo>=3.4 six==1.10.0 flake8 flake8-import-order diff --git a/setup.py b/setup.py index c7632ce3..f1f5dea7 100644 --- a/setup.py +++ b/setup.py @@ -80,7 +80,7 @@ setup( long_description=LONG_DESCRIPTION, platforms=['any'], classifiers=CLASSIFIERS, - install_requires=['pymongo>=2.7.1', 'six'], + install_requires=['pymongo>=3.4', 'six'], test_suite='nose.collector', **extra_opts ) diff --git a/tests/document/class_methods.py b/tests/document/class_methods.py index 2632d38f..4fc648b7 100644 --- a/tests/document/class_methods.py +++ b/tests/document/class_methods.py @@ -2,10 +2,10 @@ import unittest from mongoengine import * +from mongoengine.pymongo_support import list_collection_names from mongoengine.queryset import NULLIFY, PULL from mongoengine.connection import get_db -from tests.utils import requires_mongodb_gte_26 __all__ = ("ClassMethodsTest", ) @@ -27,9 +27,7 @@ class ClassMethodsTest(unittest.TestCase): self.Person = Person def tearDown(self): - for collection in self.db.collection_names(): - if 'system.' in collection: - continue + for collection in list_collection_names(self.db): self.db.drop_collection(collection) def test_definition(self): @@ -66,10 +64,10 @@ class ClassMethodsTest(unittest.TestCase): """ collection_name = 'person' self.Person(name='Test').save() - self.assertIn(collection_name, self.db.collection_names()) + self.assertIn(collection_name, list_collection_names(self.db)) self.Person.drop_collection() - self.assertNotIn(collection_name, self.db.collection_names()) + self.assertNotIn(collection_name, list_collection_names(self.db)) def test_register_delete_rule(self): """Ensure that register delete rule adds a delete rule to the document @@ -102,16 +100,16 @@ class ClassMethodsTest(unittest.TestCase): BlogPost.drop_collection() BlogPost.ensure_indexes() - self.assertEqual(BlogPost.compare_indexes(), { 'missing': [], 'extra': [] }) + self.assertEqual(BlogPost.compare_indexes(), {'missing': [], 'extra': []}) BlogPost.ensure_index(['author', 'description']) - self.assertEqual(BlogPost.compare_indexes(), { 'missing': [], 'extra': [[('author', 1), ('description', 1)]] }) + self.assertEqual(BlogPost.compare_indexes(), {'missing': [], 'extra': [[('author', 1), ('description', 1)]]}) BlogPost._get_collection().drop_index('author_1_description_1') - self.assertEqual(BlogPost.compare_indexes(), { 'missing': [], 'extra': [] }) + self.assertEqual(BlogPost.compare_indexes(), {'missing': [], 'extra': []}) BlogPost._get_collection().drop_index('author_1_title_1') - self.assertEqual(BlogPost.compare_indexes(), { 'missing': [[('author', 1), ('title', 1)]], 'extra': [] }) + self.assertEqual(BlogPost.compare_indexes(), {'missing': [[('author', 1), ('title', 1)]], 'extra': []}) def test_compare_indexes_inheritance(self): """ Ensure that the indexes are properly created and that @@ -140,16 +138,16 @@ class ClassMethodsTest(unittest.TestCase): BlogPost.ensure_indexes() BlogPostWithTags.ensure_indexes() - self.assertEqual(BlogPost.compare_indexes(), { 'missing': [], 'extra': [] }) + self.assertEqual(BlogPost.compare_indexes(), {'missing': [], 'extra': []}) BlogPostWithTags.ensure_index(['author', 'tag_list']) - self.assertEqual(BlogPost.compare_indexes(), { 'missing': [], 'extra': [[('_cls', 1), ('author', 1), ('tag_list', 1)]] }) + self.assertEqual(BlogPost.compare_indexes(), {'missing': [], 'extra': [[('_cls', 1), ('author', 1), ('tag_list', 1)]]}) BlogPostWithTags._get_collection().drop_index('_cls_1_author_1_tag_list_1') - self.assertEqual(BlogPost.compare_indexes(), { 'missing': [], 'extra': [] }) + self.assertEqual(BlogPost.compare_indexes(), {'missing': [], 'extra': []}) BlogPostWithTags._get_collection().drop_index('_cls_1_author_1_tags_1') - self.assertEqual(BlogPost.compare_indexes(), { 'missing': [[('_cls', 1), ('author', 1), ('tags', 1)]], 'extra': [] }) + self.assertEqual(BlogPost.compare_indexes(), {'missing': [[('_cls', 1), ('author', 1), ('tags', 1)]], 'extra': []}) def test_compare_indexes_multiple_subclasses(self): """ Ensure that compare_indexes behaves correctly if called from a @@ -184,11 +182,10 @@ class ClassMethodsTest(unittest.TestCase): BlogPostWithTags.ensure_indexes() BlogPostWithCustomField.ensure_indexes() - self.assertEqual(BlogPost.compare_indexes(), { 'missing': [], 'extra': [] }) - self.assertEqual(BlogPostWithTags.compare_indexes(), { 'missing': [], 'extra': [] }) - self.assertEqual(BlogPostWithCustomField.compare_indexes(), { 'missing': [], 'extra': [] }) + self.assertEqual(BlogPost.compare_indexes(), {'missing': [], 'extra': []}) + self.assertEqual(BlogPostWithTags.compare_indexes(), {'missing': [], 'extra': []}) + self.assertEqual(BlogPostWithCustomField.compare_indexes(), {'missing': [], 'extra': []}) - @requires_mongodb_gte_26 def test_compare_indexes_for_text_indexes(self): """ Ensure that compare_indexes behaves correctly for text indexes """ @@ -340,7 +337,7 @@ class ClassMethodsTest(unittest.TestCase): meta = {'collection': collection_name} Person(name="Test User").save() - self.assertIn(collection_name, self.db.collection_names()) + self.assertIn(collection_name, list_collection_names(self.db)) user_obj = self.db[collection_name].find_one() self.assertEqual(user_obj['name'], "Test User") @@ -349,7 +346,7 @@ class ClassMethodsTest(unittest.TestCase): self.assertEqual(user_obj.name, "Test User") Person.drop_collection() - self.assertNotIn(collection_name, self.db.collection_names()) + self.assertNotIn(collection_name, list_collection_names(self.db)) def test_collection_name_and_primary(self): """Ensure that a collection with a specified name may be used. diff --git a/tests/document/delta.py b/tests/document/delta.py index 30296956..504c1707 100644 --- a/tests/document/delta.py +++ b/tests/document/delta.py @@ -3,16 +3,14 @@ import unittest from bson import SON from mongoengine import * -from mongoengine.connection import get_db - -__all__ = ("DeltaTest",) +from mongoengine.pymongo_support import list_collection_names +from tests.utils import MongoDBTestCase -class DeltaTest(unittest.TestCase): +class DeltaTest(MongoDBTestCase): def setUp(self): - connect(db='mongoenginetest') - self.db = get_db() + super(DeltaTest, self).setUp() class Person(Document): name = StringField() @@ -25,9 +23,7 @@ class DeltaTest(unittest.TestCase): self.Person = Person def tearDown(self): - for collection in self.db.collection_names(): - if 'system.' in collection: - continue + for collection in list_collection_names(self.db): self.db.drop_collection(collection) def test_delta(self): @@ -863,5 +859,6 @@ class DeltaTest(unittest.TestCase): self.assertEqual('oops', delta[0]["users.007.rolist"][0]["type"]) self.assertEqual(uinfo.id, delta[0]["users.007.info"]) + if __name__ == '__main__': unittest.main() diff --git a/tests/document/indexes.py b/tests/document/indexes.py index 757d8037..764ef0c5 100644 --- a/tests/document/indexes.py +++ b/tests/document/indexes.py @@ -5,10 +5,10 @@ from datetime import datetime from nose.plugins.skip import SkipTest from pymongo.errors import OperationFailure import pymongo +from six import iteritems from mongoengine import * from mongoengine.connection import get_db -from tests.utils import get_mongodb_version, requires_mongodb_gte_26, MONGODB_32, MONGODB_3 __all__ = ("IndexesTest", ) @@ -18,7 +18,6 @@ class IndexesTest(unittest.TestCase): def setUp(self): self.connection = connect(db='mongoenginetest') self.db = get_db() - self.mongodb_version = get_mongodb_version() class Person(Document): name = StringField() @@ -68,7 +67,7 @@ class IndexesTest(unittest.TestCase): info = BlogPost.objects._collection.index_information() # _id, '-date', 'tags', ('cat', 'date') self.assertEqual(len(info), 4) - info = [value['key'] for key, value in info.iteritems()] + info = [value['key'] for key, value in iteritems(info)] for expected in expected_specs: self.assertIn(expected['fields'], info) @@ -100,7 +99,7 @@ class IndexesTest(unittest.TestCase): # the indices on -date and tags will both contain # _cls as first element in the key self.assertEqual(len(info), 4) - info = [value['key'] for key, value in info.iteritems()] + info = [value['key'] for key, value in iteritems(info)] for expected in expected_specs: self.assertIn(expected['fields'], info) @@ -115,7 +114,7 @@ class IndexesTest(unittest.TestCase): ExtendedBlogPost.ensure_indexes() info = ExtendedBlogPost.objects._collection.index_information() - info = [value['key'] for key, value in info.iteritems()] + info = [value['key'] for key, value in iteritems(info)] for expected in expected_specs: self.assertIn(expected['fields'], info) @@ -225,7 +224,7 @@ class IndexesTest(unittest.TestCase): # Indexes are lazy so use list() to perform query list(Person.objects) info = Person.objects._collection.index_information() - info = [value['key'] for key, value in info.iteritems()] + info = [value['key'] for key, value in iteritems(info)] self.assertIn([('rank.title', 1)], info) def test_explicit_geo2d_index(self): @@ -245,7 +244,7 @@ class IndexesTest(unittest.TestCase): Place.ensure_indexes() info = Place._get_collection().index_information() - info = [value['key'] for key, value in info.iteritems()] + info = [value['key'] for key, value in iteritems(info)] self.assertIn([('location.point', '2d')], info) def test_explicit_geo2d_index_embedded(self): @@ -268,7 +267,7 @@ class IndexesTest(unittest.TestCase): Place.ensure_indexes() info = Place._get_collection().index_information() - info = [value['key'] for key, value in info.iteritems()] + info = [value['key'] for key, value in iteritems(info)] self.assertIn([('current.location.point', '2d')], info) def test_explicit_geosphere_index(self): @@ -288,7 +287,7 @@ class IndexesTest(unittest.TestCase): Place.ensure_indexes() info = Place._get_collection().index_information() - info = [value['key'] for key, value in info.iteritems()] + info = [value['key'] for key, value in iteritems(info)] self.assertIn([('location.point', '2dsphere')], info) def test_explicit_geohaystack_index(self): @@ -310,7 +309,7 @@ class IndexesTest(unittest.TestCase): Place.ensure_indexes() info = Place._get_collection().index_information() - info = [value['key'] for key, value in info.iteritems()] + info = [value['key'] for key, value in iteritems(info)] self.assertIn([('location.point', 'geoHaystack')], info) def test_create_geohaystack_index(self): @@ -322,7 +321,7 @@ class IndexesTest(unittest.TestCase): Place.create_index({'fields': (')location.point', 'name')}, bucketSize=10) info = Place._get_collection().index_information() - info = [value['key'] for key, value in info.iteritems()] + info = [value['key'] for key, value in iteritems(info)] self.assertIn([('location.point', 'geoHaystack'), ('name', 1)], info) def test_dictionary_indexes(self): @@ -355,7 +354,7 @@ class IndexesTest(unittest.TestCase): info = [(value['key'], value.get('unique', False), value.get('sparse', False)) - for key, value in info.iteritems()] + for key, value in iteritems(info)] self.assertIn(([('addDate', -1)], True, True), info) BlogPost.drop_collection() @@ -407,7 +406,7 @@ class IndexesTest(unittest.TestCase): self.assertEqual(2, User.objects.count()) info = User.objects._collection.index_information() - self.assertEqual(info.keys(), ['_id_']) + self.assertEqual(list(info.keys()), ['_id_']) User.ensure_indexes() info = User.objects._collection.index_information() @@ -476,7 +475,6 @@ class IndexesTest(unittest.TestCase): def test_covered_index(self): """Ensure that covered indexes can be used """ - class Test(Document): a = IntField() b = IntField() @@ -491,38 +489,41 @@ class IndexesTest(unittest.TestCase): obj = Test(a=1) obj.save() - IS_MONGODB_3 = get_mongodb_version() >= MONGODB_3 - # Need to be explicit about covered indexes as mongoDB doesn't know if # the documents returned might have more keys in that here. query_plan = Test.objects(id=obj.id).exclude('a').explain() - if not IS_MONGODB_3: - self.assertFalse(query_plan['indexOnly']) - else: - self.assertEqual(query_plan.get('queryPlanner').get('winningPlan').get('inputStage').get('stage'), 'IDHACK') + self.assertEqual( + query_plan.get('queryPlanner').get('winningPlan').get('inputStage').get('stage'), + 'IDHACK' + ) query_plan = Test.objects(id=obj.id).only('id').explain() - if not IS_MONGODB_3: - self.assertTrue(query_plan['indexOnly']) - else: - self.assertEqual(query_plan.get('queryPlanner').get('winningPlan').get('inputStage').get('stage'), 'IDHACK') + self.assertEqual( + query_plan.get('queryPlanner').get('winningPlan').get('inputStage').get('stage'), + 'IDHACK' + ) query_plan = Test.objects(a=1).only('a').exclude('id').explain() - if not IS_MONGODB_3: - self.assertTrue(query_plan['indexOnly']) - else: - self.assertEqual(query_plan.get('queryPlanner').get('winningPlan').get('inputStage').get('stage'), 'IXSCAN') - self.assertEqual(query_plan.get('queryPlanner').get('winningPlan').get('stage'), 'PROJECTION') + self.assertEqual( + query_plan.get('queryPlanner').get('winningPlan').get('inputStage').get('stage'), + 'IXSCAN' + ) + self.assertEqual( + query_plan.get('queryPlanner').get('winningPlan').get('stage'), + 'PROJECTION' + ) query_plan = Test.objects(a=1).explain() - if not IS_MONGODB_3: - self.assertFalse(query_plan['indexOnly']) - else: - self.assertEqual(query_plan.get('queryPlanner').get('winningPlan').get('inputStage').get('stage'), 'IXSCAN') - self.assertEqual(query_plan.get('queryPlanner').get('winningPlan').get('stage'), 'FETCH') + self.assertEqual( + query_plan.get('queryPlanner').get('winningPlan').get('inputStage').get('stage'), + 'IXSCAN' + ) + self.assertEqual( + query_plan.get('queryPlanner').get('winningPlan').get('stage'), + 'FETCH' + ) def test_index_on_id(self): - class BlogPost(Document): meta = { 'indexes': [ @@ -541,9 +542,8 @@ class IndexesTest(unittest.TestCase): [('categories', 1), ('_id', 1)]) def test_hint(self): - MONGO_VER = self.mongodb_version - TAGS_INDEX_NAME = 'tags_1' + class BlogPost(Document): tags = ListField(StringField()) meta = { @@ -561,25 +561,27 @@ class IndexesTest(unittest.TestCase): tags = [("tag %i" % n) for n in range(i % 2)] BlogPost(tags=tags).save() - self.assertEqual(BlogPost.objects.count(), 10) + # Hinting by shape should work. + self.assertEqual(BlogPost.objects.hint([('tags', 1)]).count(), 10) + + # Hinting by index name should work. + self.assertEqual(BlogPost.objects.hint(TAGS_INDEX_NAME).count(), 10) + + # Clearing the hint should work fine. self.assertEqual(BlogPost.objects.hint().count(), 10) + self.assertEqual(BlogPost.objects.hint([('ZZ', 1)]).hint().count(), 10) - # PyMongo 3.0 bug only, works correctly with 2.X and 3.0.1+ versions - if pymongo.version != '3.0': - self.assertEqual(BlogPost.objects.hint([('tags', 1)]).count(), 10) + # Hinting on a non-existent index shape should fail. + with self.assertRaises(OperationFailure): + BlogPost.objects.hint([('ZZ', 1)]).count() - if MONGO_VER == MONGODB_32: - # Mongo32 throws an error if an index exists (i.e `tags` in our case) - # and you use hint on an index name that does not exist - with self.assertRaises(OperationFailure): - BlogPost.objects.hint([('ZZ', 1)]).count() - else: - self.assertEqual(BlogPost.objects.hint([('ZZ', 1)]).count(), 10) + # Hinting on a non-existent index name should fail. + with self.assertRaises(OperationFailure): + BlogPost.objects.hint('Bad Name').count() - self.assertEqual(BlogPost.objects.hint(TAGS_INDEX_NAME ).count(), 10) - - with self.assertRaises(Exception): - BlogPost.objects.hint(('tags', 1)).next() + # Invalid shape argument (missing list brackets) should fail. + with self.assertRaises(ValueError): + BlogPost.objects.hint(('tags', 1)).count() def test_unique(self): """Ensure that uniqueness constraints are applied to fields. @@ -596,10 +598,32 @@ class IndexesTest(unittest.TestCase): # Two posts with the same slug is not allowed post2 = BlogPost(title='test2', slug='test') self.assertRaises(NotUniqueError, post2.save) + self.assertRaises(NotUniqueError, BlogPost.objects.insert, post2) - # Ensure backwards compatibilty for errors + # Ensure backwards compatibility for errors self.assertRaises(OperationError, post2.save) + def test_primary_key_unique_not_working(self): + """Relates to #1445""" + class Blog(Document): + id = StringField(primary_key=True, unique=True) + + Blog.drop_collection() + + with self.assertRaises(OperationFailure) as ctx_err: + Blog(id='garbage').save() + + # One of the errors below should happen. Which one depends on the + # PyMongo version and dict order. + err_msg = str(ctx_err.exception) + self.assertTrue( + any([ + "The field 'unique' is not valid for an _id index specification" in err_msg, + "The field 'background' is not valid for an _id index specification" in err_msg, + "The field 'sparse' is not valid for an _id index specification" in err_msg, + ]) + ) + def test_unique_with(self): """Ensure that unique_with constraints are applied to fields. """ @@ -681,6 +705,77 @@ class IndexesTest(unittest.TestCase): self.assertRaises(NotUniqueError, post2.save) + def test_unique_embedded_document_in_sorted_list(self): + """ + Ensure that the uniqueness constraints are applied to fields in + embedded documents, even when the embedded documents in a sorted list + field. + """ + class SubDocument(EmbeddedDocument): + year = IntField() + slug = StringField(unique=True) + + class BlogPost(Document): + title = StringField() + subs = SortedListField(EmbeddedDocumentField(SubDocument), + ordering='year') + + BlogPost.drop_collection() + + post1 = BlogPost( + title='test1', subs=[ + SubDocument(year=2009, slug='conflict'), + SubDocument(year=2009, slug='conflict') + ] + ) + post1.save() + + # confirm that the unique index is created + indexes = BlogPost._get_collection().index_information() + self.assertIn('subs.slug_1', indexes) + self.assertTrue(indexes['subs.slug_1']['unique']) + + post2 = BlogPost( + title='test2', subs=[SubDocument(year=2014, slug='conflict')] + ) + + self.assertRaises(NotUniqueError, post2.save) + + def test_unique_embedded_document_in_embedded_document_list(self): + """ + Ensure that the uniqueness constraints are applied to fields in + embedded documents, even when the embedded documents in an embedded + list field. + """ + class SubDocument(EmbeddedDocument): + year = IntField() + slug = StringField(unique=True) + + class BlogPost(Document): + title = StringField() + subs = EmbeddedDocumentListField(SubDocument) + + BlogPost.drop_collection() + + post1 = BlogPost( + title='test1', subs=[ + SubDocument(year=2009, slug='conflict'), + SubDocument(year=2009, slug='conflict') + ] + ) + post1.save() + + # confirm that the unique index is created + indexes = BlogPost._get_collection().index_information() + self.assertIn('subs.slug_1', indexes) + self.assertTrue(indexes['subs.slug_1']['unique']) + + post2 = BlogPost( + title='test2', subs=[SubDocument(year=2014, slug='conflict')] + ) + + self.assertRaises(NotUniqueError, post2.save) + def test_unique_with_embedded_document_and_embedded_unique(self): """Ensure that uniqueness constraints are applied to fields on embedded documents. And work with unique_with as well. @@ -732,6 +827,18 @@ class IndexesTest(unittest.TestCase): self.assertEqual(3600, info['created_1']['expireAfterSeconds']) + def test_index_drop_dups_silently_ignored(self): + class Customer(Document): + cust_id = IntField(unique=True, required=True) + meta = { + 'indexes': ['cust_id'], + 'index_drop_dups': True, + 'allow_inheritance': False, + } + + Customer.drop_collection() + Customer.objects.first() + def test_unique_and_indexes(self): """Ensure that 'unique' constraints aren't overridden by meta.indexes. @@ -748,18 +855,23 @@ class IndexesTest(unittest.TestCase): cust.save() cust_dupe = Customer(cust_id=1) - try: + with self.assertRaises(NotUniqueError): cust_dupe.save() - raise AssertionError("We saved a dupe!") - except NotUniqueError: - pass + + cust = Customer(cust_id=2) + cust.save() + + # duplicate key on update + with self.assertRaises(NotUniqueError): + cust.cust_id = 1 + cust.save() def test_primary_save_duplicate_update_existing_object(self): """If you set a field as primary, then unexpected behaviour can occur. You won't create a duplicate but you will update an existing document. """ class User(Document): - name = StringField(primary_key=True, unique=True) + name = StringField(primary_key=True) password = StringField() User.drop_collection() @@ -806,7 +918,7 @@ class IndexesTest(unittest.TestCase): self.fail('Unbound local error at index + pk definition') info = BlogPost.objects._collection.index_information() - info = [value['key'] for key, value in info.iteritems()] + info = [value['key'] for key, value in iteritems(info)] index_item = [('_id', 1), ('comments.comment_id', 1)] self.assertIn(index_item, info) @@ -854,7 +966,7 @@ class IndexesTest(unittest.TestCase): } info = MyDoc.objects._collection.index_information() - info = [value['key'] for key, value in info.iteritems()] + info = [value['key'] for key, value in iteritems(info)] self.assertIn([('provider_ids.foo', 1)], info) self.assertIn([('provider_ids.bar', 1)], info) @@ -872,7 +984,6 @@ class IndexesTest(unittest.TestCase): info['provider_ids.foo_1_provider_ids.bar_1']['key']) self.assertTrue(info['provider_ids.foo_1_provider_ids.bar_1']['sparse']) - @requires_mongodb_gte_26 def test_text_indexes(self): class Book(Document): title = DictField() @@ -936,7 +1047,6 @@ class IndexesTest(unittest.TestCase): # Drop the temporary database at the end connection.drop_database('tempdatabase') - def test_index_dont_send_cls_option(self): """ Ensure that 'cls' option is not sent through ensureIndex. We shouldn't diff --git a/tests/document/inheritance.py b/tests/document/inheritance.py index 32e3ed29..d81039f4 100644 --- a/tests/document/inheritance.py +++ b/tests/document/inheritance.py @@ -2,25 +2,22 @@ import unittest import warnings +from six import iteritems + from mongoengine import (BooleanField, Document, EmbeddedDocument, EmbeddedDocumentField, GenericReferenceField, - IntField, ReferenceField, StringField, connect) -from mongoengine.connection import get_db + IntField, ReferenceField, StringField) +from mongoengine.pymongo_support import list_collection_names +from tests.utils import MongoDBTestCase from tests.fixtures import Base __all__ = ('InheritanceTest', ) -class InheritanceTest(unittest.TestCase): - - def setUp(self): - connect(db='mongoenginetest') - self.db = get_db() +class InheritanceTest(MongoDBTestCase): def tearDown(self): - for collection in self.db.collection_names(): - if 'system.' in collection: - continue + for collection in list_collection_names(self.db): self.db.drop_collection(collection) def test_constructor_cls(self): @@ -36,12 +33,12 @@ class InheritanceTest(unittest.TestCase): meta = {'allow_inheritance': True} test_doc = DataDoc(name='test', embed=EmbedData(data='data')) - assert test_doc._cls == 'DataDoc' - assert test_doc.embed._cls == 'EmbedData' + self.assertEqual(test_doc._cls, 'DataDoc') + self.assertEqual(test_doc.embed._cls, 'EmbedData') test_doc.save() saved_doc = DataDoc.objects.with_id(test_doc.id) - assert test_doc._cls == saved_doc._cls - assert test_doc.embed._cls == saved_doc.embed._cls + self.assertEqual(test_doc._cls, saved_doc._cls) + self.assertEqual(test_doc.embed._cls, saved_doc.embed._cls) test_doc.delete() def test_superclasses(self): @@ -485,7 +482,7 @@ class InheritanceTest(unittest.TestCase): meta = {'abstract': True} class Human(Mammal): pass - for k, v in defaults.iteritems(): + for k, v in iteritems(defaults): for cls in [Animal, Fish, Guppy]: self.assertEqual(cls._meta[k], v) diff --git a/tests/document/instance.py b/tests/document/instance.py index 39e47524..0f2f0c0f 100644 --- a/tests/document/instance.py +++ b/tests/document/instance.py @@ -4,18 +4,20 @@ import os import pickle import unittest import uuid +import warnings import weakref - from datetime import datetime -import warnings from bson import DBRef, ObjectId from pymongo.errors import DuplicateKeyError +from six import iteritems +from mongoengine.mongodb_support import get_mongodb_version, MONGODB_36, MONGODB_34 +from mongoengine.pymongo_support import list_collection_names from tests import fixtures from tests.fixtures import (PickleEmbedded, PickleTest, PickleSignalsTest, PickleDynamicEmbedded, PickleDynamicTest) -from tests.utils import MongoDBTestCase +from tests.utils import MongoDBTestCase, get_as_pymongo from mongoengine import * from mongoengine.base import get_document, _document_registry @@ -27,8 +29,6 @@ from mongoengine.queryset import NULLIFY, Q from mongoengine.context_managers import switch_db, query_counter from mongoengine import signals -from tests.utils import requires_mongodb_gte_26 - TEST_IMAGE_PATH = os.path.join(os.path.dirname(__file__), '../fields/mongoengine.png') @@ -55,9 +55,7 @@ class InstanceTest(MongoDBTestCase): self.Job = Job def tearDown(self): - for collection in self.db.collection_names(): - if 'system.' in collection: - continue + for collection in list_collection_names(self.db): self.db.drop_collection(collection) def assertDbEqual(self, docs): @@ -421,6 +419,12 @@ class InstanceTest(MongoDBTestCase): person.save() person.to_dbref() + def test_key_like_attribute_access(self): + person = self.Person(age=30) + self.assertEqual(person['age'], 30) + with self.assertRaises(KeyError): + person['unknown_attr'] + def test_save_abstract_document(self): """Saving an abstract document should fail.""" class Doc(Document): @@ -463,7 +467,16 @@ class InstanceTest(MongoDBTestCase): Animal.drop_collection() doc = Animal(superphylum='Deuterostomia') doc.save() - doc.reload() + + mongo_db = get_mongodb_version() + CMD_QUERY_KEY = 'command' if mongo_db >= MONGODB_36 else 'query' + + with query_counter() as q: + doc.reload() + query_op = q.db.system.profile.find({'ns': 'mongoenginetest.animal'})[0] + self.assertEqual(set(query_op[CMD_QUERY_KEY]['filter'].keys()), set(['_id', 'superphylum'])) + + Animal.drop_collection() def test_reload_sharded_nested(self): class SuperPhylum(EmbeddedDocument): @@ -477,6 +490,34 @@ class InstanceTest(MongoDBTestCase): doc = Animal(superphylum=SuperPhylum(name='Deuterostomia')) doc.save() doc.reload() + Animal.drop_collection() + + def test_update_shard_key_routing(self): + """Ensures updating a doc with a specified shard_key includes it in + the query. + """ + class Animal(Document): + is_mammal = BooleanField() + name = StringField() + meta = {'shard_key': ('is_mammal', 'id')} + + Animal.drop_collection() + doc = Animal(is_mammal=True, name='Dog') + doc.save() + + mongo_db = get_mongodb_version() + + with query_counter() as q: + doc.name = 'Cat' + doc.save() + query_op = q.db.system.profile.find({'ns': 'mongoenginetest.animal'})[0] + self.assertEqual(query_op['op'], 'update') + if mongo_db == MONGODB_34: + self.assertEqual(set(query_op['query'].keys()), set(['_id', 'is_mammal'])) + else: + self.assertEqual(set(query_op['command']['q'].keys()), set(['_id', 'is_mammal'])) + + Animal.drop_collection() def test_reload_with_changed_fields(self): """Ensures reloading will not affect changed fields""" @@ -572,7 +613,7 @@ class InstanceTest(MongoDBTestCase): Post.drop_collection() - Post._get_collection().insert({ + Post._get_collection().insert_one({ "title": "Items eclipse", "items": ["more lorem", "even more ipsum"] }) @@ -712,39 +753,78 @@ class InstanceTest(MongoDBTestCase): acc1 = Account.objects.first() self.assertHasInstance(acc1._data["emails"][0], acc1) + def test_save_checks_that_clean_is_called(self): + class CustomError(Exception): + pass + + class TestDocument(Document): + def clean(self): + raise CustomError() + + with self.assertRaises(CustomError): + TestDocument().save() + + TestDocument().save(clean=False) + + def test_save_signal_pre_save_post_validation_makes_change_to_doc(self): + class BlogPost(Document): + content = StringField() + + @classmethod + def pre_save_post_validation(cls, sender, document, **kwargs): + document.content = 'checked' + + signals.pre_save_post_validation.connect(BlogPost.pre_save_post_validation, sender=BlogPost) + + BlogPost.drop_collection() + + post = BlogPost(content='unchecked').save() + self.assertEqual(post.content, 'checked') + # Make sure pre_save_post_validation changes makes it to the db + raw_doc = get_as_pymongo(post) + self.assertEqual( + raw_doc, + { + 'content': 'checked', + '_id': post.id + }) + + # Important to disconnect as it could cause some assertions in test_signals + # to fail (due to the garbage collection timing of this signal) + signals.pre_save_post_validation.disconnect(BlogPost.pre_save_post_validation) + def test_document_clean(self): class TestDocument(Document): status = StringField() - pub_date = DateTimeField() + cleaned = BooleanField(default=False) def clean(self): - if self.status == 'draft' and self.pub_date is not None: - msg = 'Draft entries may not have a publication date.' - raise ValidationError(msg) - # Set the pub_date for published items if not set. - if self.status == 'published' and self.pub_date is None: - self.pub_date = datetime.now() + self.cleaned = True TestDocument.drop_collection() - t = TestDocument(status="draft", pub_date=datetime.now()) - - with self.assertRaises(ValidationError) as cm: - t.save() - - expected_msg = "Draft entries may not have a publication date." - self.assertIn(expected_msg, cm.exception.message) - self.assertEqual(cm.exception.to_dict(), {'__all__': expected_msg}) + t = TestDocument(status="draft") + # Ensure clean=False prevent call to clean t = TestDocument(status="published") t.save(clean=False) - - self.assertEqual(t.pub_date, None) + self.assertEqual(t.status, "published") + self.assertEqual(t.cleaned, False) t = TestDocument(status="published") + self.assertEqual(t.cleaned, False) t.save(clean=True) - - self.assertEqual(type(t.pub_date), datetime) + self.assertEqual(t.status, "published") + self.assertEqual(t.cleaned, True) + raw_doc = get_as_pymongo(t) + # Make sure clean changes makes it to the db + self.assertEqual( + raw_doc, + { + 'status': 'published', + 'cleaned': True, + '_id': t.id + }) def test_document_embedded_clean(self): class TestEmbeddedDocument(EmbeddedDocument): @@ -806,7 +886,8 @@ class InstanceTest(MongoDBTestCase): doc2 = self.Person(name="jim", age=20).save() docs = [dict(doc1.to_mongo()), dict(doc2.to_mongo())] - assert not doc1.modify({'name': doc2.name}, set__age=100) + n_modified = doc1.modify({'name': doc2.name}, set__age=100) + self.assertEqual(n_modified, 0) self.assertDbEqual(docs) @@ -815,7 +896,8 @@ class InstanceTest(MongoDBTestCase): doc2 = self.Person(id=ObjectId(), name="jim", age=20) docs = [dict(doc1.to_mongo())] - assert not doc2.modify({'name': doc2.name}, set__age=100) + n_modified = doc2.modify({'name': doc2.name}, set__age=100) + self.assertEqual(n_modified, 0) self.assertDbEqual(docs) @@ -831,18 +913,18 @@ class InstanceTest(MongoDBTestCase): doc.job.name = "Google" doc.job.years = 3 - assert doc.modify( + n_modified = doc.modify( set__age=21, set__job__name="MongoDB", unset__job__years=True) + self.assertEqual(n_modified, 1) doc_copy.age = 21 doc_copy.job.name = "MongoDB" del doc_copy.job.years - assert doc.to_json() == doc_copy.to_json() - assert doc._get_changed_fields() == [] + self.assertEqual(doc.to_json(), doc_copy.to_json()) + self.assertEqual(doc._get_changed_fields(), []) self.assertDbEqual([dict(other_doc.to_mongo()), dict(doc.to_mongo())]) - @requires_mongodb_gte_26 def test_modify_with_positional_push(self): class Content(EmbeddedDocument): keywords = ListField(StringField()) @@ -882,19 +964,39 @@ class InstanceTest(MongoDBTestCase): person.save() # Ensure that the object is in the database - collection = self.db[self.Person._get_collection_name()] - person_obj = collection.find_one({'name': 'Test User'}) - self.assertEqual(person_obj['name'], 'Test User') - self.assertEqual(person_obj['age'], 30) - self.assertEqual(person_obj['_id'], person.id) + raw_doc = get_as_pymongo(person) + self.assertEqual( + raw_doc, + { + '_cls': 'Person', + 'name': 'Test User', + 'age': 30, + '_id': person.id + }) - # Test skipping validation on save + def test_save_skip_validation(self): class Recipient(Document): email = EmailField(required=True) recipient = Recipient(email='not-an-email') - self.assertRaises(ValidationError, recipient.save) + with self.assertRaises(ValidationError): + recipient.save() + recipient.save(validate=False) + raw_doc = get_as_pymongo(recipient) + self.assertEqual( + raw_doc, + { + 'email': 'not-an-email', + '_id': recipient.id + }) + + def test_save_with_bad_id(self): + class Clown(Document): + id = IntField(primary_key=True) + + with self.assertRaises(ValidationError): + Clown(id="not_an_int").save() def test_save_to_a_value_that_equates_to_false(self): class Thing(EmbeddedDocument): @@ -2758,7 +2860,7 @@ class InstanceTest(MongoDBTestCase): User.drop_collection() - User._get_collection().save({ + User._get_collection().insert_one({ 'name': 'John', 'foo': 'Bar', 'data': [1, 2, 3] @@ -2774,7 +2876,7 @@ class InstanceTest(MongoDBTestCase): User.drop_collection() - User._get_collection().save({ + User._get_collection().insert_one({ 'name': 'John', 'foo': 'Bar', 'data': [1, 2, 3] @@ -2797,7 +2899,7 @@ class InstanceTest(MongoDBTestCase): User.drop_collection() - User._get_collection().save({ + User._get_collection().insert_one({ 'name': 'John', 'thing': { 'name': 'My thing', @@ -2820,7 +2922,7 @@ class InstanceTest(MongoDBTestCase): User.drop_collection() - User._get_collection().save({ + User._get_collection().insert_one({ 'name': 'John', 'thing': { 'name': 'My thing', @@ -2843,7 +2945,7 @@ class InstanceTest(MongoDBTestCase): User.drop_collection() - User._get_collection().save({ + User._get_collection().insert_one({ 'name': 'John', 'thing': { 'name': 'My thing', @@ -3060,7 +3162,7 @@ class InstanceTest(MongoDBTestCase): def expand(self): self.flattened_parameter = {} - for parameter_name, parameter in self.parameters.iteritems(): + for parameter_name, parameter in iteritems(self.parameters): parameter.expand() class NodesSystem(Document): @@ -3068,7 +3170,7 @@ class InstanceTest(MongoDBTestCase): nodes = MapField(ReferenceField(Node, dbref=False)) def save(self, *args, **kwargs): - for node_name, node in self.nodes.iteritems(): + for node_name, node in iteritems(self.nodes): node.expand() node.save(*args, **kwargs) super(NodesSystem, self).save(*args, **kwargs) @@ -3196,7 +3298,7 @@ class InstanceTest(MongoDBTestCase): p2.name = 'alon2' p2.save() p3 = Person.objects().only('created_on')[0] - self.assertEquals(orig_created_on, p3.created_on) + self.assertEqual(orig_created_on, p3.created_on) class Person(Document): created_on = DateTimeField(default=lambda: datetime.utcnow()) @@ -3205,29 +3307,28 @@ class InstanceTest(MongoDBTestCase): p4 = Person.objects()[0] p4.save() - self.assertEquals(p4.height, 189) + self.assertEqual(p4.height, 189) # However the default will not be fixed in DB - self.assertEquals(Person.objects(height=189).count(), 0) + self.assertEqual(Person.objects(height=189).count(), 0) # alter DB for the new default coll = Person._get_collection() for person in Person.objects.as_pymongo(): if 'height' not in person: - person['height'] = 189 - coll.save(person) + coll.update_one({'_id': person['_id']}, {'$set': {'height': 189}}) - self.assertEquals(Person.objects(height=189).count(), 1) + self.assertEqual(Person.objects(height=189).count(), 1) def test_from_son(self): # 771 class MyPerson(self.Person): meta = dict(shard_key=["id"]) p = MyPerson.from_json('{"name": "name", "age": 27}', created=True) - self.assertEquals(p.id, None) + self.assertEqual(p.id, None) p.id = "12345" # in case it is not working: "OperationError: Shard Keys are immutable..." will be raised here p = MyPerson._from_son({"name": "name", "age": 27}, created=True) - self.assertEquals(p.id, None) + self.assertEqual(p.id, None) p.id = "12345" # in case it is not working: "OperationError: Shard Keys are immutable..." will be raised here def test_from_son_created_False_without_id(self): @@ -3305,7 +3406,7 @@ class InstanceTest(MongoDBTestCase): u_from_db = User.objects.get(name='user') u_from_db.height = None u_from_db.save() - self.assertEquals(u_from_db.height, None) + self.assertEqual(u_from_db.height, None) # 864 self.assertEqual(u_from_db.str_fld, None) self.assertEqual(u_from_db.int_fld, None) @@ -3319,7 +3420,7 @@ class InstanceTest(MongoDBTestCase): u.save() User.objects(name='user').update_one(set__height=None, upsert=True) u_from_db = User.objects.get(name='user') - self.assertEquals(u_from_db.height, None) + self.assertEqual(u_from_db.height, None) def test_not_saved_eq(self): """Ensure we can compare documents not saved. @@ -3361,7 +3462,6 @@ class InstanceTest(MongoDBTestCase): person.update(set__height=2.0) - @requires_mongodb_gte_26 def test_push_with_position(self): """Ensure that push with position works properly for an instance.""" class BlogPost(Document): diff --git a/tests/document/json_serialisation.py b/tests/document/json_serialisation.py index 110f1e14..251b65a2 100644 --- a/tests/document/json_serialisation.py +++ b/tests/document/json_serialisation.py @@ -32,12 +32,12 @@ class TestJson(unittest.TestCase): string = StringField(db_field='s') embedded = EmbeddedDocumentField(Embedded, db_field='e') - doc = Doc( string="Hello", embedded=Embedded(string="Inner Hello")) - doc_json = doc.to_json(sort_keys=True, use_db_field=False,separators=(',', ':')) + doc = Doc(string="Hello", embedded=Embedded(string="Inner Hello")) + doc_json = doc.to_json(sort_keys=True, use_db_field=False, separators=(',', ':')) expected_json = """{"embedded":{"string":"Inner Hello"},"string":"Hello"}""" - self.assertEqual( doc_json, expected_json) + self.assertEqual(doc_json, expected_json) def test_json_simple(self): @@ -61,10 +61,6 @@ class TestJson(unittest.TestCase): self.assertEqual(doc, Doc.from_json(doc.to_json())) def test_json_complex(self): - - if pymongo.version_tuple[0] <= 2 and pymongo.version_tuple[1] <= 3: - raise SkipTest("Need pymongo 2.4 as has a fix for DBRefs") - class EmbeddedDoc(EmbeddedDocument): pass diff --git a/tests/fields/fields.py b/tests/fields/fields.py index 05810f2c..68baab46 100644 --- a/tests/fields/fields.py +++ b/tests/fields/fields.py @@ -1,84 +1,24 @@ # -*- coding: utf-8 -*- import datetime import unittest -import uuid -import math -import itertools -import re -import sys from nose.plugins.skip import SkipTest -import six - -try: - import dateutil -except ImportError: - dateutil = None - -from decimal import Decimal from bson import DBRef, ObjectId, SON -try: - from bson.int64 import Int64 -except ImportError: - Int64 = long -from mongoengine import * -from mongoengine.connection import get_db -from mongoengine.base import (BaseDict, BaseField, EmbeddedDocumentList, - _document_registry, LazyReference) +from mongoengine import Document, StringField, IntField, DateTimeField, DateField, ValidationError, \ + ComplexDateTimeField, FloatField, ListField, ReferenceField, DictField, EmbeddedDocument, EmbeddedDocumentField, \ + GenericReferenceField, DoesNotExist, NotRegistered, OperationError, DynamicField, \ + FieldDoesNotExist, EmbeddedDocumentListField, MultipleObjectsReturned, NotUniqueError, BooleanField,\ + ObjectIdField, SortedListField, GenericLazyReferenceField, LazyReferenceField, DynamicDocument +from mongoengine.base import (BaseField, EmbeddedDocumentList, _document_registry) +from mongoengine.errors import DeprecatedError from tests.utils import MongoDBTestCase -__all__ = ("FieldTest", "EmbeddedDocumentListFieldTestCase") - class FieldTest(MongoDBTestCase): - def test_datetime_from_empty_string(self): - """ - Ensure an exception is raised when trying to - cast an empty string to datetime. - """ - class MyDoc(Document): - dt = DateTimeField() - - md = MyDoc(dt='') - self.assertRaises(ValidationError, md.save) - - def test_date_from_empty_string(self): - """ - Ensure an exception is raised when trying to - cast an empty string to datetime. - """ - class MyDoc(Document): - dt = DateField() - - md = MyDoc(dt='') - self.assertRaises(ValidationError, md.save) - - def test_datetime_from_whitespace_string(self): - """ - Ensure an exception is raised when trying to - cast a whitespace-only string to datetime. - """ - class MyDoc(Document): - dt = DateTimeField() - - md = MyDoc(dt=' ') - self.assertRaises(ValidationError, md.save) - - def test_date_from_whitespace_string(self): - """ - Ensure an exception is raised when trying to - cast a whitespace-only string to datetime. - """ - class MyDoc(Document): - dt = DateField() - - md = MyDoc(dt=' ') - self.assertRaises(ValidationError, md.save) - def test_default_values_nothing_set(self): """Ensure that default field values are used when creating a document. @@ -117,6 +57,48 @@ class FieldTest(MongoDBTestCase): self.assertEqual( data_to_be_saved, ['age', 'created', 'day', 'name', 'userid']) + def test_custom_field_validation_raise_deprecated_error_when_validation_return_something(self): + # Covers introduction of a breaking change in the validation parameter (0.18) + def _not_empty(z): + return bool(z) + + class Person(Document): + name = StringField(validation=_not_empty) + + Person.drop_collection() + + error = ("validation argument for `name` must not return anything, " + "it should raise a ValidationError if validation fails") + + with self.assertRaises(DeprecatedError) as ctx_err: + Person(name="").validate() + self.assertEqual(str(ctx_err.exception), error) + + with self.assertRaises(DeprecatedError) as ctx_err: + Person(name="").save() + self.assertEqual(str(ctx_err.exception), error) + + def test_custom_field_validation_raise_validation_error(self): + def _not_empty(z): + if not z: + raise ValidationError('cantbeempty') + + class Person(Document): + name = StringField(validation=_not_empty) + + Person.drop_collection() + + with self.assertRaises(ValidationError) as ctx_err: + Person(name="").validate() + self.assertEqual("ValidationError (Person:None) (cantbeempty: ['name'])", str(ctx_err.exception)) + + with self.assertRaises(ValidationError): + Person(name="").save() + self.assertEqual("ValidationError (Person:None) (cantbeempty: ['name'])", str(ctx_err.exception)) + + Person(name="garbage").validate() + Person(name="garbage").save() + def test_default_values_set_to_None(self): """Ensure that default field values are used even when we explcitly initialize the doc with None values. @@ -335,31 +317,7 @@ class FieldTest(MongoDBTestCase): # attempted. self.assertRaises(ValidationError, ret.validate) - def test_int_and_float_ne_operator(self): - class TestDocument(Document): - int_fld = IntField() - float_fld = FloatField() - - TestDocument.drop_collection() - - TestDocument(int_fld=None, float_fld=None).save() - TestDocument(int_fld=1, float_fld=1).save() - - self.assertEqual(1, TestDocument.objects(int_fld__ne=None).count()) - self.assertEqual(1, TestDocument.objects(float_fld__ne=None).count()) - - def test_long_ne_operator(self): - class TestDocument(Document): - long_fld = LongField() - - TestDocument.drop_collection() - - TestDocument(long_fld=None).save() - TestDocument(long_fld=1).save() - - self.assertEqual(1, TestDocument.objects(long_fld__ne=None).count()) - - def test_object_id_validation(self): + def test_default_id_validation_as_objectid(self): """Ensure that invalid values cannot be assigned to an ObjectIdField. """ @@ -375,7 +333,7 @@ class FieldTest(MongoDBTestCase): person.id = 'abc' self.assertRaises(ValidationError, person.validate) - person.id = '497ce96f395f2f052a494fd4' + person.id = str(ObjectId()) person.validate() def test_string_validation(self): @@ -402,162 +360,6 @@ class FieldTest(MongoDBTestCase): person.name = 'Shorter name' person.validate() - def test_url_validation(self): - """Ensure that URLFields validate urls properly.""" - class Link(Document): - url = URLField() - - link = Link() - link.url = 'google' - self.assertRaises(ValidationError, link.validate) - - link.url = 'http://www.google.com:8080' - link.validate() - - def test_unicode_url_validation(self): - """Ensure unicode URLs are validated properly.""" - class Link(Document): - url = URLField() - - link = Link() - link.url = u'http://привет.com' - - # TODO fix URL validation - this *IS* a valid URL - # For now we just want to make sure that the error message is correct - try: - link.validate() - self.assertTrue(False) - except ValidationError as e: - self.assertEqual( - unicode(e), - u"ValidationError (Link:None) (Invalid URL: http://\u043f\u0440\u0438\u0432\u0435\u0442.com: ['url'])" - ) - - def test_url_scheme_validation(self): - """Ensure that URLFields validate urls with specific schemes properly. - """ - class Link(Document): - url = URLField() - - class SchemeLink(Document): - url = URLField(schemes=['ws', 'irc']) - - link = Link() - link.url = 'ws://google.com' - self.assertRaises(ValidationError, link.validate) - - scheme_link = SchemeLink() - scheme_link.url = 'ws://google.com' - scheme_link.validate() - - def test_url_allowed_domains(self): - """Allow underscore in domain names. - """ - class Link(Document): - url = URLField() - - link = Link() - link.url = 'https://san_leandro-ca.geebo.com' - link.validate() - - def test_int_validation(self): - """Ensure that invalid values cannot be assigned to int fields. - """ - class Person(Document): - age = IntField(min_value=0, max_value=110) - - person = Person() - person.age = 50 - person.validate() - - person.age = -1 - self.assertRaises(ValidationError, person.validate) - person.age = 120 - self.assertRaises(ValidationError, person.validate) - person.age = 'ten' - self.assertRaises(ValidationError, person.validate) - - def test_long_validation(self): - """Ensure that invalid values cannot be assigned to long fields. - """ - class TestDocument(Document): - value = LongField(min_value=0, max_value=110) - - doc = TestDocument() - doc.value = 50 - doc.validate() - - doc.value = -1 - self.assertRaises(ValidationError, doc.validate) - doc.age = 120 - self.assertRaises(ValidationError, doc.validate) - doc.age = 'ten' - self.assertRaises(ValidationError, doc.validate) - - def test_float_validation(self): - """Ensure that invalid values cannot be assigned to float fields. - """ - class Person(Document): - height = FloatField(min_value=0.1, max_value=3.5) - - class BigPerson(Document): - height = FloatField() - - person = Person() - person.height = 1.89 - person.validate() - - person.height = '2.0' - self.assertRaises(ValidationError, person.validate) - - person.height = 0.01 - self.assertRaises(ValidationError, person.validate) - - person.height = 4.0 - self.assertRaises(ValidationError, person.validate) - - person_2 = Person(height='something invalid') - self.assertRaises(ValidationError, person_2.validate) - - big_person = BigPerson() - - for value, value_type in enumerate(six.integer_types): - big_person.height = value_type(value) - big_person.validate() - - big_person.height = 2 ** 500 - big_person.validate() - - big_person.height = 2 ** 100000 # Too big for a float value - self.assertRaises(ValidationError, big_person.validate) - - def test_decimal_validation(self): - """Ensure that invalid values cannot be assigned to decimal fields. - """ - class Person(Document): - height = DecimalField(min_value=Decimal('0.1'), - max_value=Decimal('3.5')) - - Person.drop_collection() - - Person(height=Decimal('1.89')).save() - person = Person.objects.first() - self.assertEqual(person.height, Decimal('1.89')) - - person.height = '2.0' - person.save() - person.height = 0.01 - self.assertRaises(ValidationError, person.validate) - person.height = Decimal('0.01') - self.assertRaises(ValidationError, person.validate) - person.height = Decimal('4.0') - self.assertRaises(ValidationError, person.validate) - person.height = 'something invalid' - self.assertRaises(ValidationError, person.validate) - - person_2 = Person(height='something invalid') - self.assertRaises(ValidationError, person_2.validate) - def test_db_field_validation(self): """Ensure that db_field doesn't accept invalid values.""" @@ -576,395 +378,9 @@ class FieldTest(MongoDBTestCase): class User(Document): name = StringField(db_field='name\0') - def test_decimal_comparison(self): - class Person(Document): - money = DecimalField() - - Person.drop_collection() - - Person(money=6).save() - Person(money=8).save() - Person(money=10).save() - - self.assertEqual(2, Person.objects(money__gt=Decimal("7")).count()) - self.assertEqual(2, Person.objects(money__gt=7).count()) - self.assertEqual(2, Person.objects(money__gt="7").count()) - - def test_decimal_storage(self): - class Person(Document): - float_value = DecimalField(precision=4) - string_value = DecimalField(precision=4, force_string=True) - - Person.drop_collection() - values_to_store = [10, 10.1, 10.11, "10.111", Decimal("10.1111"), Decimal("10.11111")] - for store_at_creation in [True, False]: - for value in values_to_store: - # to_python is called explicitly if values were sent in the kwargs of __init__ - if store_at_creation: - Person(float_value=value, string_value=value).save() - else: - person = Person.objects.create() - person.float_value = value - person.string_value = value - person.save() - - # How its stored - expected = [ - {'float_value': 10.0, 'string_value': '10.0000'}, - {'float_value': 10.1, 'string_value': '10.1000'}, - {'float_value': 10.11, 'string_value': '10.1100'}, - {'float_value': 10.111, 'string_value': '10.1110'}, - {'float_value': 10.1111, 'string_value': '10.1111'}, - {'float_value': 10.1111, 'string_value': '10.1111'}] - expected.extend(expected) - actual = list(Person.objects.exclude('id').as_pymongo()) - self.assertEqual(expected, actual) - - # How it comes out locally - expected = [Decimal('10.0000'), Decimal('10.1000'), Decimal('10.1100'), - Decimal('10.1110'), Decimal('10.1111'), Decimal('10.1111')] - expected.extend(expected) - for field_name in ['float_value', 'string_value']: - actual = list(Person.objects().scalar(field_name)) - self.assertEqual(expected, actual) - - def test_boolean_validation(self): - """Ensure that invalid values cannot be assigned to boolean - fields. - """ - class Person(Document): - admin = BooleanField() - - person = Person() - person.admin = True - person.validate() - - person.admin = 2 - self.assertRaises(ValidationError, person.validate) - person.admin = 'Yes' - self.assertRaises(ValidationError, person.validate) - person.admin = 'False' - self.assertRaises(ValidationError, person.validate) - - def test_uuid_field_string(self): - """Test UUID fields storing as String - """ - class Person(Document): - api_key = UUIDField(binary=False) - - Person.drop_collection() - - uu = uuid.uuid4() - Person(api_key=uu).save() - self.assertEqual(1, Person.objects(api_key=uu).count()) - self.assertEqual(uu, Person.objects.first().api_key) - - person = Person() - valid = (uuid.uuid4(), uuid.uuid1()) - for api_key in valid: - person.api_key = api_key - person.validate() - - invalid = ('9d159858-549b-4975-9f98-dd2f987c113g', - '9d159858-549b-4975-9f98-dd2f987c113') - for api_key in invalid: - person.api_key = api_key - self.assertRaises(ValidationError, person.validate) - - def test_uuid_field_binary(self): - """Test UUID fields storing as Binary object.""" - class Person(Document): - api_key = UUIDField(binary=True) - - Person.drop_collection() - - uu = uuid.uuid4() - Person(api_key=uu).save() - self.assertEqual(1, Person.objects(api_key=uu).count()) - self.assertEqual(uu, Person.objects.first().api_key) - - person = Person() - valid = (uuid.uuid4(), uuid.uuid1()) - for api_key in valid: - person.api_key = api_key - person.validate() - - invalid = ('9d159858-549b-4975-9f98-dd2f987c113g', - '9d159858-549b-4975-9f98-dd2f987c113') - for api_key in invalid: - person.api_key = api_key - self.assertRaises(ValidationError, person.validate) - - def test_datetime_validation(self): - """Ensure that invalid values cannot be assigned to datetime - fields. - """ - class LogEntry(Document): - time = DateTimeField() - - log = LogEntry() - log.time = datetime.datetime.now() - log.validate() - - log.time = datetime.date.today() - log.validate() - - log.time = datetime.datetime.now().isoformat(' ') - log.validate() - - if dateutil: - log.time = datetime.datetime.now().isoformat('T') - log.validate() - - log.time = -1 - self.assertRaises(ValidationError, log.validate) - log.time = 'ABC' - self.assertRaises(ValidationError, log.validate) - - def test_date_validation(self): - """Ensure that invalid values cannot be assigned to datetime - fields. - """ - class LogEntry(Document): - time = DateField() - - log = LogEntry() - log.time = datetime.datetime.now() - log.validate() - - log.time = datetime.date.today() - log.validate() - - log.time = datetime.datetime.now().isoformat(' ') - log.validate() - - if dateutil: - log.time = datetime.datetime.now().isoformat('T') - log.validate() - - log.time = -1 - self.assertRaises(ValidationError, log.validate) - log.time = 'ABC' - self.assertRaises(ValidationError, log.validate) - - def test_datetime_tz_aware_mark_as_changed(self): - from mongoengine import connection - - # Reset the connections - connection._connection_settings = {} - connection._connections = {} - connection._dbs = {} - - connect(db='mongoenginetest', tz_aware=True) - - class LogEntry(Document): - time = DateTimeField() - - LogEntry.drop_collection() - - LogEntry(time=datetime.datetime(2013, 1, 1, 0, 0, 0)).save() - - log = LogEntry.objects.first() - log.time = datetime.datetime(2013, 1, 1, 0, 0, 0) - self.assertEqual(['time'], log._changed_fields) - - def test_datetime(self): - """Tests showing pymongo datetime fields handling of microseconds. - Microseconds are rounded to the nearest millisecond and pre UTC - handling is wonky. - - See: http://api.mongodb.org/python/current/api/bson/son.html#dt - """ - class LogEntry(Document): - date = DateTimeField() - - LogEntry.drop_collection() - - # Test can save dates - log = LogEntry() - log.date = datetime.date.today() - log.save() - log.reload() - self.assertEqual(log.date.date(), datetime.date.today()) - - # Post UTC - microseconds are rounded (down) nearest millisecond and - # dropped - d1 = datetime.datetime(1970, 1, 1, 0, 0, 1, 999) - d2 = datetime.datetime(1970, 1, 1, 0, 0, 1) - log = LogEntry() - log.date = d1 - log.save() - log.reload() - self.assertNotEqual(log.date, d1) - self.assertEqual(log.date, d2) - - # Post UTC - microseconds are rounded (down) nearest millisecond - d1 = datetime.datetime(1970, 1, 1, 0, 0, 1, 9999) - d2 = datetime.datetime(1970, 1, 1, 0, 0, 1, 9000) - log.date = d1 - log.save() - log.reload() - self.assertNotEqual(log.date, d1) - self.assertEqual(log.date, d2) - - if not six.PY3: - # Pre UTC dates microseconds below 1000 are dropped - # This does not seem to be true in PY3 - d1 = datetime.datetime(1969, 12, 31, 23, 59, 59, 999) - d2 = datetime.datetime(1969, 12, 31, 23, 59, 59) - log.date = d1 - log.save() - log.reload() - self.assertNotEqual(log.date, d1) - self.assertEqual(log.date, d2) - - def test_date(self): - """Tests showing pymongo date fields - - See: http://api.mongodb.org/python/current/api/bson/son.html#dt - """ - class LogEntry(Document): - date = DateField() - - LogEntry.drop_collection() - - # Test can save dates - log = LogEntry() - log.date = datetime.date.today() - log.save() - log.reload() - self.assertEqual(log.date, datetime.date.today()) - - d1 = datetime.datetime(1970, 1, 1, 0, 0, 1, 999) - d2 = datetime.datetime(1970, 1, 1, 0, 0, 1) - log = LogEntry() - log.date = d1 - log.save() - log.reload() - self.assertEqual(log.date, d1.date()) - self.assertEqual(log.date, d2.date()) - - d1 = datetime.datetime(1970, 1, 1, 0, 0, 1, 9999) - d2 = datetime.datetime(1970, 1, 1, 0, 0, 1, 9000) - log.date = d1 - log.save() - log.reload() - self.assertEqual(log.date, d1.date()) - self.assertEqual(log.date, d2.date()) - - if not six.PY3: - # Pre UTC dates microseconds below 1000 are dropped - # This does not seem to be true in PY3 - d1 = datetime.datetime(1969, 12, 31, 23, 59, 59, 999) - d2 = datetime.datetime(1969, 12, 31, 23, 59, 59) - log.date = d1 - log.save() - log.reload() - self.assertEqual(log.date, d1.date()) - self.assertEqual(log.date, d2.date()) - - def test_datetime_usage(self): - """Tests for regular datetime fields""" - class LogEntry(Document): - date = DateTimeField() - - LogEntry.drop_collection() - - d1 = datetime.datetime(1970, 1, 1, 0, 0, 1) - log = LogEntry() - log.date = d1 - log.validate() - log.save() - - for query in (d1, d1.isoformat(' ')): - log1 = LogEntry.objects.get(date=query) - self.assertEqual(log, log1) - - if dateutil: - log1 = LogEntry.objects.get(date=d1.isoformat('T')) - self.assertEqual(log, log1) - - # create additional 19 log entries for a total of 20 - for i in range(1971, 1990): - d = datetime.datetime(i, 1, 1, 0, 0, 1) - LogEntry(date=d).save() - - self.assertEqual(LogEntry.objects.count(), 20) - - # Test ordering - logs = LogEntry.objects.order_by("date") - i = 0 - while i < 19: - self.assertTrue(logs[i].date <= logs[i + 1].date) - i += 1 - - logs = LogEntry.objects.order_by("-date") - i = 0 - while i < 19: - self.assertTrue(logs[i].date >= logs[i + 1].date) - i += 1 - - # Test searching - logs = LogEntry.objects.filter(date__gte=datetime.datetime(1980, 1, 1)) - self.assertEqual(logs.count(), 10) - - logs = LogEntry.objects.filter(date__lte=datetime.datetime(1980, 1, 1)) - self.assertEqual(logs.count(), 10) - - logs = LogEntry.objects.filter( - date__lte=datetime.datetime(1980, 1, 1), - date__gte=datetime.datetime(1975, 1, 1), - ) - self.assertEqual(logs.count(), 5) - - def test_date_usage(self): - """Tests for regular datetime fields""" - class LogEntry(Document): - date = DateField() - - LogEntry.drop_collection() - - d1 = datetime.datetime(1970, 1, 1, 0, 0, 1) - log = LogEntry() - log.date = d1 - log.validate() - log.save() - - for query in (d1, d1.isoformat(' ')): - log1 = LogEntry.objects.get(date=query) - self.assertEqual(log, log1) - - if dateutil: - log1 = LogEntry.objects.get(date=d1.isoformat('T')) - self.assertEqual(log, log1) - - # create additional 19 log entries for a total of 20 - for i in range(1971, 1990): - d = datetime.datetime(i, 1, 1, 0, 0, 1) - LogEntry(date=d).save() - - self.assertEqual(LogEntry.objects.count(), 20) - - # Test ordering - logs = LogEntry.objects.order_by("date") - i = 0 - while i < 19: - self.assertTrue(logs[i].date <= logs[i + 1].date) - i += 1 - - logs = LogEntry.objects.order_by("-date") - i = 0 - while i < 19: - self.assertTrue(logs[i].date >= logs[i + 1].date) - i += 1 - - # Test searching - logs = LogEntry.objects.filter(date__gte=datetime.datetime(1980, 1, 1)) - self.assertEqual(logs.count(), 10) - def test_list_validation(self): """Ensure that a list field only accepts lists with valid elements.""" - AccessLevelChoices = ( + access_level_choices = ( ('a', u'Administration'), ('b', u'Manager'), ('c', u'Staff'), @@ -984,7 +400,7 @@ class FieldTest(MongoDBTestCase): authors_as_lazy = ListField(LazyReferenceField(User)) generic = ListField(GenericReferenceField()) generic_as_lazy = ListField(GenericLazyReferenceField()) - access_list = ListField(choices=AccessLevelChoices, display_sep=', ') + access_list = ListField(choices=access_level_choices, display_sep=', ') User.drop_collection() BlogPost.drop_collection() @@ -1666,374 +1082,6 @@ class FieldTest(MongoDBTestCase): self.assertEqual( Simple.objects.filter(mapping__2__list__1__value='Boo').count(), 1) - def test_dict_field(self): - """Ensure that dict types work as expected.""" - class BlogPost(Document): - info = DictField() - - BlogPost.drop_collection() - - post = BlogPost() - post.info = 'my post' - self.assertRaises(ValidationError, post.validate) - - post.info = ['test', 'test'] - self.assertRaises(ValidationError, post.validate) - - post.info = {'$title': 'test'} - self.assertRaises(ValidationError, post.validate) - - post.info = {'nested': {'$title': 'test'}} - self.assertRaises(ValidationError, post.validate) - - post.info = {'the.title': 'test'} - self.assertRaises(ValidationError, post.validate) - - post.info = {'nested': {'the.title': 'test'}} - self.assertRaises(ValidationError, post.validate) - - post.info = {1: 'test'} - self.assertRaises(ValidationError, post.validate) - - post.info = {'title': 'test'} - post.save() - - post = BlogPost() - post.info = {'title': 'dollar_sign', 'details': {'te$t': 'test'}} - post.save() - - post = BlogPost() - post.info = {'details': {'test': 'test'}} - post.save() - - post = BlogPost() - post.info = {'details': {'test': 3}} - post.save() - - self.assertEqual(BlogPost.objects.count(), 4) - self.assertEqual( - BlogPost.objects.filter(info__title__exact='test').count(), 1) - self.assertEqual( - BlogPost.objects.filter(info__details__test__exact='test').count(), 1) - - post = BlogPost.objects.filter(info__title__exact='dollar_sign').first() - self.assertIn('te$t', post['info']['details']) - - # Confirm handles non strings or non existing keys - self.assertEqual( - BlogPost.objects.filter(info__details__test__exact=5).count(), 0) - self.assertEqual( - BlogPost.objects.filter(info__made_up__test__exact='test').count(), 0) - - post = BlogPost.objects.create(info={'title': 'original'}) - post.info.update({'title': 'updated'}) - post.save() - post.reload() - self.assertEqual('updated', post.info['title']) - - post.info.setdefault('authors', []) - post.save() - post.reload() - self.assertEqual([], post.info['authors']) - - def test_dictfield_dump_document(self): - """Ensure a DictField can handle another document's dump.""" - class Doc(Document): - field = DictField() - - class ToEmbed(Document): - id = IntField(primary_key=True, default=1) - recursive = DictField() - - class ToEmbedParent(Document): - id = IntField(primary_key=True, default=1) - recursive = DictField() - - meta = {'allow_inheritance': True} - - class ToEmbedChild(ToEmbedParent): - pass - - to_embed_recursive = ToEmbed(id=1).save() - to_embed = ToEmbed( - id=2, recursive=to_embed_recursive.to_mongo().to_dict()).save() - doc = Doc(field=to_embed.to_mongo().to_dict()) - doc.save() - assert isinstance(doc.field, dict) - assert doc.field == {'_id': 2, 'recursive': {'_id': 1, 'recursive': {}}} - # Same thing with a Document with a _cls field - to_embed_recursive = ToEmbedChild(id=1).save() - to_embed_child = ToEmbedChild( - id=2, recursive=to_embed_recursive.to_mongo().to_dict()).save() - doc = Doc(field=to_embed_child.to_mongo().to_dict()) - doc.save() - assert isinstance(doc.field, dict) - assert doc.field == { - '_id': 2, '_cls': 'ToEmbedParent.ToEmbedChild', - 'recursive': {'_id': 1, '_cls': 'ToEmbedParent.ToEmbedChild', 'recursive': {}} - } - - def test_dictfield_strict(self): - """Ensure that dict field handles validation if provided a strict field type.""" - class Simple(Document): - mapping = DictField(field=IntField()) - - Simple.drop_collection() - - e = Simple() - e.mapping['someint'] = 1 - e.save() - - # try creating an invalid mapping - with self.assertRaises(ValidationError): - e.mapping['somestring'] = "abc" - e.save() - - def test_dictfield_complex(self): - """Ensure that the dict field can handle the complex types.""" - class SettingBase(EmbeddedDocument): - meta = {'allow_inheritance': True} - - class StringSetting(SettingBase): - value = StringField() - - class IntegerSetting(SettingBase): - value = IntField() - - class Simple(Document): - mapping = DictField() - - Simple.drop_collection() - - e = Simple() - e.mapping['somestring'] = StringSetting(value='foo') - e.mapping['someint'] = IntegerSetting(value=42) - e.mapping['nested_dict'] = {'number': 1, 'string': 'Hi!', - 'float': 1.001, - 'complex': IntegerSetting(value=42), - 'list': [IntegerSetting(value=42), - StringSetting(value='foo')]} - e.save() - - e2 = Simple.objects.get(id=e.id) - self.assertIsInstance(e2.mapping['somestring'], StringSetting) - self.assertIsInstance(e2.mapping['someint'], IntegerSetting) - - # Test querying - self.assertEqual( - Simple.objects.filter(mapping__someint__value=42).count(), 1) - self.assertEqual( - Simple.objects.filter(mapping__nested_dict__number=1).count(), 1) - self.assertEqual( - Simple.objects.filter(mapping__nested_dict__complex__value=42).count(), 1) - self.assertEqual( - Simple.objects.filter(mapping__nested_dict__list__0__value=42).count(), 1) - self.assertEqual( - Simple.objects.filter(mapping__nested_dict__list__1__value='foo').count(), 1) - - # Confirm can update - Simple.objects().update( - set__mapping={"someint": IntegerSetting(value=10)}) - Simple.objects().update( - set__mapping__nested_dict__list__1=StringSetting(value='Boo')) - self.assertEqual( - Simple.objects.filter(mapping__nested_dict__list__1__value='foo').count(), 0) - self.assertEqual( - Simple.objects.filter(mapping__nested_dict__list__1__value='Boo').count(), 1) - - def test_atomic_update_dict_field(self): - """Ensure that the entire DictField can be atomically updated.""" - class Simple(Document): - mapping = DictField(field=ListField(IntField(required=True))) - - Simple.drop_collection() - - e = Simple() - e.mapping['someints'] = [1, 2] - e.save() - e.update(set__mapping={"ints": [3, 4]}) - e.reload() - self.assertEqual(BaseDict, type(e.mapping)) - self.assertEqual({"ints": [3, 4]}, e.mapping) - - # try creating an invalid mapping - with self.assertRaises(ValueError): - e.update(set__mapping={"somestrings": ["foo", "bar", ]}) - - def test_dictfield_with_referencefield_complex_nesting_cases(self): - """Ensure complex nesting inside DictField handles dereferencing of ReferenceField(dbref=True | False)""" - # Relates to Issue #1453 - class Doc(Document): - s = StringField() - - class Simple(Document): - mapping0 = DictField(ReferenceField(Doc, dbref=True)) - mapping1 = DictField(ReferenceField(Doc, dbref=False)) - mapping2 = DictField(ListField(ReferenceField(Doc, dbref=True))) - mapping3 = DictField(ListField(ReferenceField(Doc, dbref=False))) - mapping4 = DictField(DictField(field=ReferenceField(Doc, dbref=True))) - mapping5 = DictField(DictField(field=ReferenceField(Doc, dbref=False))) - mapping6 = DictField(ListField(DictField(ReferenceField(Doc, dbref=True)))) - mapping7 = DictField(ListField(DictField(ReferenceField(Doc, dbref=False)))) - mapping8 = DictField(ListField(DictField(ListField(ReferenceField(Doc, dbref=True))))) - mapping9 = DictField(ListField(DictField(ListField(ReferenceField(Doc, dbref=False))))) - - Doc.drop_collection() - Simple.drop_collection() - - d = Doc(s='aa').save() - e = Simple() - e.mapping0['someint'] = e.mapping1['someint'] = d - e.mapping2['someint'] = e.mapping3['someint'] = [d] - e.mapping4['someint'] = e.mapping5['someint'] = {'d': d} - e.mapping6['someint'] = e.mapping7['someint'] = [{'d': d}] - e.mapping8['someint'] = e.mapping9['someint'] = [{'d': [d]}] - e.save() - - s = Simple.objects.first() - self.assertIsInstance(s.mapping0['someint'], Doc) - self.assertIsInstance(s.mapping1['someint'], Doc) - self.assertIsInstance(s.mapping2['someint'][0], Doc) - self.assertIsInstance(s.mapping3['someint'][0], Doc) - self.assertIsInstance(s.mapping4['someint']['d'], Doc) - self.assertIsInstance(s.mapping5['someint']['d'], Doc) - self.assertIsInstance(s.mapping6['someint'][0]['d'], Doc) - self.assertIsInstance(s.mapping7['someint'][0]['d'], Doc) - self.assertIsInstance(s.mapping8['someint'][0]['d'][0], Doc) - self.assertIsInstance(s.mapping9['someint'][0]['d'][0], Doc) - - def test_mapfield(self): - """Ensure that the MapField handles the declared type.""" - class Simple(Document): - mapping = MapField(IntField()) - - Simple.drop_collection() - - e = Simple() - e.mapping['someint'] = 1 - e.save() - - with self.assertRaises(ValidationError): - e.mapping['somestring'] = "abc" - e.save() - - with self.assertRaises(ValidationError): - class NoDeclaredType(Document): - mapping = MapField() - - def test_complex_mapfield(self): - """Ensure that the MapField can handle complex declared types.""" - class SettingBase(EmbeddedDocument): - meta = {"allow_inheritance": True} - - class StringSetting(SettingBase): - value = StringField() - - class IntegerSetting(SettingBase): - value = IntField() - - class Extensible(Document): - mapping = MapField(EmbeddedDocumentField(SettingBase)) - - Extensible.drop_collection() - - e = Extensible() - e.mapping['somestring'] = StringSetting(value='foo') - e.mapping['someint'] = IntegerSetting(value=42) - e.save() - - e2 = Extensible.objects.get(id=e.id) - self.assertIsInstance(e2.mapping['somestring'], StringSetting) - self.assertIsInstance(e2.mapping['someint'], IntegerSetting) - - with self.assertRaises(ValidationError): - e.mapping['someint'] = 123 - e.save() - - def test_embedded_mapfield_db_field(self): - class Embedded(EmbeddedDocument): - number = IntField(default=0, db_field='i') - - class Test(Document): - my_map = MapField(field=EmbeddedDocumentField(Embedded), - db_field='x') - - Test.drop_collection() - - test = Test() - test.my_map['DICTIONARY_KEY'] = Embedded(number=1) - test.save() - - Test.objects.update_one(inc__my_map__DICTIONARY_KEY__number=1) - - test = Test.objects.get() - self.assertEqual(test.my_map['DICTIONARY_KEY'].number, 2) - doc = self.db.test.find_one() - self.assertEqual(doc['x']['DICTIONARY_KEY']['i'], 2) - - def test_mapfield_numerical_index(self): - """Ensure that MapField accept numeric strings as indexes.""" - class Embedded(EmbeddedDocument): - name = StringField() - - class Test(Document): - my_map = MapField(EmbeddedDocumentField(Embedded)) - - Test.drop_collection() - - test = Test() - test.my_map['1'] = Embedded(name='test') - test.save() - test.my_map['1'].name = 'test updated' - test.save() - - def test_map_field_lookup(self): - """Ensure MapField lookups succeed on Fields without a lookup - method. - """ - class Action(EmbeddedDocument): - operation = StringField() - object = StringField() - - class Log(Document): - name = StringField() - visited = MapField(DateTimeField()) - actions = MapField(EmbeddedDocumentField(Action)) - - Log.drop_collection() - Log(name="wilson", visited={'friends': datetime.datetime.now()}, - actions={'friends': Action(operation='drink', object='beer')}).save() - - self.assertEqual(1, Log.objects( - visited__friends__exists=True).count()) - - self.assertEqual(1, Log.objects( - actions__friends__operation='drink', - actions__friends__object='beer').count()) - - def test_map_field_unicode(self): - class Info(EmbeddedDocument): - description = StringField() - value_list = ListField(field=StringField()) - - class BlogPost(Document): - info_dict = MapField(field=EmbeddedDocumentField(Info)) - - BlogPost.drop_collection() - - tree = BlogPost(info_dict={ - u"éééé": { - 'description': u"VALUE: éééé" - } - }) - - tree.save() - - self.assertEqual( - BlogPost.objects.get(id=tree.id).info_dict[u"éééé"].description, - u"VALUE: éééé" - ) - def test_embedded_db_field(self): class Embedded(EmbeddedDocument): number = IntField(default=0, db_field='i') @@ -2220,121 +1268,6 @@ class FieldTest(MongoDBTestCase): bar._fields['generic_ref']._auto_dereference = False self.assertEqual(bar.generic_ref, {'_ref': expected, '_cls': 'Foo'}) - def test_reference_validation(self): - """Ensure that invalid document objects cannot be assigned to - reference fields. - """ - class User(Document): - name = StringField() - - class BlogPost(Document): - content = StringField() - author = ReferenceField(User) - - User.drop_collection() - BlogPost.drop_collection() - - # Make sure ReferenceField only accepts a document class or a string - # with a document class name. - self.assertRaises(ValidationError, ReferenceField, EmbeddedDocument) - - user = User(name='Test User') - - # Ensure that the referenced object must have been saved - post1 = BlogPost(content='Chips and gravy taste good.') - post1.author = user - self.assertRaises(ValidationError, post1.save) - - # Check that an invalid object type cannot be used - post2 = BlogPost(content='Chips and chilli taste good.') - post1.author = post2 - self.assertRaises(ValidationError, post1.validate) - - # Ensure ObjectID's are accepted as references - user_object_id = user.pk - post3 = BlogPost(content="Chips and curry sauce taste good.") - post3.author = user_object_id - post3.save() - - # Make sure referencing a saved document of the right type works - user.save() - post1.author = user - post1.save() - - # Make sure referencing a saved document of the *wrong* type fails - post2.save() - post1.author = post2 - self.assertRaises(ValidationError, post1.validate) - - def test_objectid_reference_fields(self): - """Make sure storing Object ID references works.""" - class Person(Document): - name = StringField() - parent = ReferenceField('self') - - Person.drop_collection() - - p1 = Person(name="John").save() - Person(name="Ross", parent=p1.pk).save() - - p = Person.objects.get(name="Ross") - self.assertEqual(p.parent, p1) - - def test_dbref_reference_fields(self): - """Make sure storing references as bson.dbref.DBRef works.""" - class Person(Document): - name = StringField() - parent = ReferenceField('self', dbref=True) - - Person.drop_collection() - - p1 = Person(name="John").save() - Person(name="Ross", parent=p1).save() - - self.assertEqual( - Person._get_collection().find_one({'name': 'Ross'})['parent'], - DBRef('person', p1.pk) - ) - - p = Person.objects.get(name="Ross") - self.assertEqual(p.parent, p1) - - def test_dbref_to_mongo(self): - """Make sure that calling to_mongo on a ReferenceField which - has dbref=False, but actually actually contains a DBRef returns - an ID of that DBRef. - """ - class Person(Document): - name = StringField() - parent = ReferenceField('self', dbref=False) - - p = Person( - name='Steve', - parent=DBRef('person', 'abcdefghijklmnop') - ) - self.assertEqual(p.to_mongo(), SON([ - ('name', u'Steve'), - ('parent', 'abcdefghijklmnop') - ])) - - def test_objectid_reference_fields(self): - - class Person(Document): - name = StringField() - parent = ReferenceField('self', dbref=False) - - Person.drop_collection() - - p1 = Person(name="John").save() - Person(name="Ross", parent=p1).save() - - col = Person._get_collection() - data = col.find_one({'name': 'Ross'}) - self.assertEqual(data['parent'], p1.pk) - - p = Person.objects.get(name="Ross") - self.assertEqual(p.parent, p1) - def test_list_item_dereference(self): """Ensure that DBRef items in ListFields are dereferenced. """ @@ -2451,99 +1384,6 @@ class FieldTest(MongoDBTestCase): self.assertEqual(tree.children[0].children[0].name, second_child.name) self.assertEqual(tree.children[0].children[1].name, third_child.name) - def test_undefined_reference(self): - """Ensure that ReferenceFields may reference undefined Documents. - """ - class Product(Document): - name = StringField() - company = ReferenceField('Company') - - class Company(Document): - name = StringField() - - Product.drop_collection() - Company.drop_collection() - - ten_gen = Company(name='10gen') - ten_gen.save() - mongodb = Product(name='MongoDB', company=ten_gen) - mongodb.save() - - me = Product(name='MongoEngine') - me.save() - - obj = Product.objects(company=ten_gen).first() - self.assertEqual(obj, mongodb) - self.assertEqual(obj.company, ten_gen) - - obj = Product.objects(company=None).first() - self.assertEqual(obj, me) - - obj = Product.objects.get(company=None) - self.assertEqual(obj, me) - - def test_reference_query_conversion(self): - """Ensure that ReferenceFields can be queried using objects and values - of the type of the primary key of the referenced object. - """ - class Member(Document): - user_num = IntField(primary_key=True) - - class BlogPost(Document): - title = StringField() - author = ReferenceField(Member, dbref=False) - - Member.drop_collection() - BlogPost.drop_collection() - - m1 = Member(user_num=1) - m1.save() - m2 = Member(user_num=2) - m2.save() - - post1 = BlogPost(title='post 1', author=m1) - post1.save() - - post2 = BlogPost(title='post 2', author=m2) - post2.save() - - post = BlogPost.objects(author=m1).first() - self.assertEqual(post.id, post1.id) - - post = BlogPost.objects(author=m2).first() - self.assertEqual(post.id, post2.id) - - def test_reference_query_conversion_dbref(self): - """Ensure that ReferenceFields can be queried using objects and values - of the type of the primary key of the referenced object. - """ - class Member(Document): - user_num = IntField(primary_key=True) - - class BlogPost(Document): - title = StringField() - author = ReferenceField(Member, dbref=True) - - Member.drop_collection() - BlogPost.drop_collection() - - m1 = Member(user_num=1) - m1.save() - m2 = Member(user_num=2) - m2.save() - - post1 = BlogPost(title='post 1', author=m1) - post1.save() - - post2 = BlogPost(title='post 2', author=m2) - post2.save() - - post = BlogPost.objects(author=m1).first() - self.assertEqual(post.id, post1.id) - - post = BlogPost.objects(author=m2).first() - self.assertEqual(post.id, post2.id) - def test_drop_abstract_document(self): """Ensure that an abstract document cannot be dropped given it has no underlying collection. @@ -2575,7 +1415,7 @@ class FieldTest(MongoDBTestCase): brother = Brother(name="Bob", sibling=sister) brother.save() - self.assertEquals(Brother.objects[0].sibling.name, sister.name) + self.assertEqual(Brother.objects[0].sibling.name, sister.name) def test_reference_abstract_class(self): """Ensure that an abstract class instance cannot be used in the @@ -2971,79 +1811,6 @@ class FieldTest(MongoDBTestCase): with self.assertRaises(ValidationError): shirt.validate() - def test_choices_validation_documents(self): - """ - Ensure fields with document choices validate given a valid choice. - """ - class UserComments(EmbeddedDocument): - author = StringField() - message = StringField() - - class BlogPost(Document): - comments = ListField( - GenericEmbeddedDocumentField(choices=(UserComments,)) - ) - - # Ensure Validation Passes - BlogPost(comments=[ - UserComments(author='user2', message='message2'), - ]).save() - - def test_choices_validation_documents_invalid(self): - """ - Ensure fields with document choices validate given an invalid choice. - This should throw a ValidationError exception. - """ - class UserComments(EmbeddedDocument): - author = StringField() - message = StringField() - - class ModeratorComments(EmbeddedDocument): - author = StringField() - message = StringField() - - class BlogPost(Document): - comments = ListField( - GenericEmbeddedDocumentField(choices=(UserComments,)) - ) - - # Single Entry Failure - post = BlogPost(comments=[ - ModeratorComments(author='mod1', message='message1'), - ]) - self.assertRaises(ValidationError, post.save) - - # Mixed Entry Failure - post = BlogPost(comments=[ - ModeratorComments(author='mod1', message='message1'), - UserComments(author='user2', message='message2'), - ]) - self.assertRaises(ValidationError, post.save) - - def test_choices_validation_documents_inheritance(self): - """ - Ensure fields with document choices validate given subclass of choice. - """ - class Comments(EmbeddedDocument): - meta = { - 'abstract': True - } - author = StringField() - message = StringField() - - class UserComments(Comments): - pass - - class BlogPost(Document): - comments = ListField( - GenericEmbeddedDocumentField(choices=(Comments,)) - ) - - # Save Valid EmbeddedDocument Type - BlogPost(comments=[ - UserComments(author='user2', message='message2'), - ]).save() - def test_choices_get_field_display(self): """Test dynamic helper for returning the display value of a choices field. @@ -3160,362 +1927,6 @@ class FieldTest(MongoDBTestCase): self.assertEqual(error_dict['size'], SIZE_MESSAGE) self.assertEqual(error_dict['color'], COLOR_MESSAGE) - def test_ensure_unique_default_instances(self): - """Ensure that every field has it's own unique default instance.""" - class D(Document): - data = DictField() - data2 = DictField(default=lambda: {}) - - d1 = D() - d1.data['foo'] = 'bar' - d1.data2['foo'] = 'bar' - d2 = D() - self.assertEqual(d2.data, {}) - self.assertEqual(d2.data2, {}) - - def test_sequence_field(self): - class Person(Document): - id = SequenceField(primary_key=True) - name = StringField() - - self.db['mongoengine.counters'].drop() - Person.drop_collection() - - for x in range(10): - Person(name="Person %s" % x).save() - - c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) - self.assertEqual(c['next'], 10) - - ids = [i.id for i in Person.objects] - self.assertEqual(ids, range(1, 11)) - - c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) - self.assertEqual(c['next'], 10) - - Person.id.set_next_value(1000) - c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) - self.assertEqual(c['next'], 1000) - - def test_sequence_field_get_next_value(self): - class Person(Document): - id = SequenceField(primary_key=True) - name = StringField() - - self.db['mongoengine.counters'].drop() - Person.drop_collection() - - for x in range(10): - Person(name="Person %s" % x).save() - - self.assertEqual(Person.id.get_next_value(), 11) - self.db['mongoengine.counters'].drop() - - self.assertEqual(Person.id.get_next_value(), 1) - - class Person(Document): - id = SequenceField(primary_key=True, value_decorator=str) - name = StringField() - - self.db['mongoengine.counters'].drop() - Person.drop_collection() - - for x in range(10): - Person(name="Person %s" % x).save() - - self.assertEqual(Person.id.get_next_value(), '11') - self.db['mongoengine.counters'].drop() - - self.assertEqual(Person.id.get_next_value(), '1') - - def test_sequence_field_sequence_name(self): - class Person(Document): - id = SequenceField(primary_key=True, sequence_name='jelly') - name = StringField() - - self.db['mongoengine.counters'].drop() - Person.drop_collection() - - for x in range(10): - Person(name="Person %s" % x).save() - - c = self.db['mongoengine.counters'].find_one({'_id': 'jelly.id'}) - self.assertEqual(c['next'], 10) - - ids = [i.id for i in Person.objects] - self.assertEqual(ids, range(1, 11)) - - c = self.db['mongoengine.counters'].find_one({'_id': 'jelly.id'}) - self.assertEqual(c['next'], 10) - - Person.id.set_next_value(1000) - c = self.db['mongoengine.counters'].find_one({'_id': 'jelly.id'}) - self.assertEqual(c['next'], 1000) - - def test_multiple_sequence_fields(self): - class Person(Document): - id = SequenceField(primary_key=True) - counter = SequenceField() - name = StringField() - - self.db['mongoengine.counters'].drop() - Person.drop_collection() - - for x in range(10): - Person(name="Person %s" % x).save() - - c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) - self.assertEqual(c['next'], 10) - - ids = [i.id for i in Person.objects] - self.assertEqual(ids, range(1, 11)) - - counters = [i.counter for i in Person.objects] - self.assertEqual(counters, range(1, 11)) - - c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) - self.assertEqual(c['next'], 10) - - Person.id.set_next_value(1000) - c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) - self.assertEqual(c['next'], 1000) - - Person.counter.set_next_value(999) - c = self.db['mongoengine.counters'].find_one({'_id': 'person.counter'}) - self.assertEqual(c['next'], 999) - - def test_sequence_fields_reload(self): - class Animal(Document): - counter = SequenceField() - name = StringField() - - self.db['mongoengine.counters'].drop() - Animal.drop_collection() - - a = Animal(name="Boi").save() - - self.assertEqual(a.counter, 1) - a.reload() - self.assertEqual(a.counter, 1) - - a.counter = None - self.assertEqual(a.counter, 2) - a.save() - - self.assertEqual(a.counter, 2) - - a = Animal.objects.first() - self.assertEqual(a.counter, 2) - a.reload() - self.assertEqual(a.counter, 2) - - def test_multiple_sequence_fields_on_docs(self): - class Animal(Document): - id = SequenceField(primary_key=True) - name = StringField() - - class Person(Document): - id = SequenceField(primary_key=True) - name = StringField() - - self.db['mongoengine.counters'].drop() - Animal.drop_collection() - Person.drop_collection() - - for x in range(10): - Animal(name="Animal %s" % x).save() - Person(name="Person %s" % x).save() - - c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) - self.assertEqual(c['next'], 10) - - c = self.db['mongoengine.counters'].find_one({'_id': 'animal.id'}) - self.assertEqual(c['next'], 10) - - ids = [i.id for i in Person.objects] - self.assertEqual(ids, range(1, 11)) - - id = [i.id for i in Animal.objects] - self.assertEqual(id, range(1, 11)) - - c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) - self.assertEqual(c['next'], 10) - - c = self.db['mongoengine.counters'].find_one({'_id': 'animal.id'}) - self.assertEqual(c['next'], 10) - - def test_sequence_field_value_decorator(self): - class Person(Document): - id = SequenceField(primary_key=True, value_decorator=str) - name = StringField() - - self.db['mongoengine.counters'].drop() - Person.drop_collection() - - for x in range(10): - p = Person(name="Person %s" % x) - p.save() - - c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) - self.assertEqual(c['next'], 10) - - ids = [i.id for i in Person.objects] - self.assertEqual(ids, map(str, range(1, 11))) - - c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) - self.assertEqual(c['next'], 10) - - def test_embedded_sequence_field(self): - class Comment(EmbeddedDocument): - id = SequenceField() - content = StringField(required=True) - - class Post(Document): - title = StringField(required=True) - comments = ListField(EmbeddedDocumentField(Comment)) - - self.db['mongoengine.counters'].drop() - Post.drop_collection() - - Post(title="MongoEngine", - comments=[Comment(content="NoSQL Rocks"), - Comment(content="MongoEngine Rocks")]).save() - c = self.db['mongoengine.counters'].find_one({'_id': 'comment.id'}) - self.assertEqual(c['next'], 2) - post = Post.objects.first() - self.assertEqual(1, post.comments[0].id) - self.assertEqual(2, post.comments[1].id) - - def test_inherited_sequencefield(self): - class Base(Document): - name = StringField() - counter = SequenceField() - meta = {'abstract': True} - - class Foo(Base): - pass - - class Bar(Base): - pass - - bar = Bar(name='Bar') - bar.save() - - foo = Foo(name='Foo') - foo.save() - - self.assertTrue('base.counter' in - self.db['mongoengine.counters'].find().distinct('_id')) - self.assertFalse(('foo.counter' or 'bar.counter') in - self.db['mongoengine.counters'].find().distinct('_id')) - self.assertNotEqual(foo.counter, bar.counter) - self.assertEqual(foo._fields['counter'].owner_document, Base) - self.assertEqual(bar._fields['counter'].owner_document, Base) - - def test_no_inherited_sequencefield(self): - class Base(Document): - name = StringField() - meta = {'abstract': True} - - class Foo(Base): - counter = SequenceField() - - class Bar(Base): - counter = SequenceField() - - bar = Bar(name='Bar') - bar.save() - - foo = Foo(name='Foo') - foo.save() - - self.assertFalse('base.counter' in - self.db['mongoengine.counters'].find().distinct('_id')) - self.assertTrue(('foo.counter' and 'bar.counter') in - self.db['mongoengine.counters'].find().distinct('_id')) - self.assertEqual(foo.counter, bar.counter) - self.assertEqual(foo._fields['counter'].owner_document, Foo) - self.assertEqual(bar._fields['counter'].owner_document, Bar) - - def test_generic_embedded_document(self): - class Car(EmbeddedDocument): - name = StringField() - - class Dish(EmbeddedDocument): - food = StringField(required=True) - number = IntField() - - class Person(Document): - name = StringField() - like = GenericEmbeddedDocumentField() - - Person.drop_collection() - - person = Person(name='Test User') - person.like = Car(name='Fiat') - person.save() - - person = Person.objects.first() - self.assertIsInstance(person.like, Car) - - person.like = Dish(food="arroz", number=15) - person.save() - - person = Person.objects.first() - self.assertIsInstance(person.like, Dish) - - def test_generic_embedded_document_choices(self): - """Ensure you can limit GenericEmbeddedDocument choices.""" - class Car(EmbeddedDocument): - name = StringField() - - class Dish(EmbeddedDocument): - food = StringField(required=True) - number = IntField() - - class Person(Document): - name = StringField() - like = GenericEmbeddedDocumentField(choices=(Dish,)) - - Person.drop_collection() - - person = Person(name='Test User') - person.like = Car(name='Fiat') - self.assertRaises(ValidationError, person.validate) - - person.like = Dish(food="arroz", number=15) - person.save() - - person = Person.objects.first() - self.assertIsInstance(person.like, Dish) - - def test_generic_list_embedded_document_choices(self): - """Ensure you can limit GenericEmbeddedDocument choices inside - a list field. - """ - class Car(EmbeddedDocument): - name = StringField() - - class Dish(EmbeddedDocument): - food = StringField(required=True) - number = IntField() - - class Person(Document): - name = StringField() - likes = ListField(GenericEmbeddedDocumentField(choices=(Dish,))) - - Person.drop_collection() - - person = Person(name='Test User') - person.likes = [Car(name='Fiat')] - self.assertRaises(ValidationError, person.validate) - - person.likes = [Dish(food="arroz", number=15)] - person.save() - - person = Person.objects.first() - self.assertIsInstance(person.likes[0], Dish) - def test_recursive_validation(self): """Ensure that a validation result to_dict is available.""" class Author(EmbeddedDocument): @@ -3557,117 +1968,6 @@ class FieldTest(MongoDBTestCase): post.comments[1].content = 'here we go' post.validate() - def test_email_field(self): - class User(Document): - email = EmailField() - - user = User(email='ross@example.com') - user.validate() - - user = User(email='ross@example.co.uk') - user.validate() - - user = User(email=('Kofq@rhom0e4klgauOhpbpNdogawnyIKvQS0wk2mjqrgGQ5S' - 'aJIazqqWkm7.net')) - user.validate() - - user = User(email='new-tld@example.technology') - user.validate() - - user = User(email='ross@example.com.') - self.assertRaises(ValidationError, user.validate) - - # unicode domain - user = User(email=u'user@пример.рф') - user.validate() - - # invalid unicode domain - user = User(email=u'user@пример') - self.assertRaises(ValidationError, user.validate) - - # invalid data type - user = User(email=123) - self.assertRaises(ValidationError, user.validate) - - def test_email_field_unicode_user(self): - # Don't run this test on pypy3, which doesn't support unicode regex: - # https://bitbucket.org/pypy/pypy/issues/1821/regular-expression-doesnt-find-unicode - if sys.version_info[:2] == (3, 2): - raise SkipTest('unicode email addresses are not supported on PyPy 3') - - class User(Document): - email = EmailField() - - # unicode user shouldn't validate by default... - user = User(email=u'Dörte@Sörensen.example.com') - self.assertRaises(ValidationError, user.validate) - - # ...but it should be fine with allow_utf8_user set to True - class User(Document): - email = EmailField(allow_utf8_user=True) - - user = User(email=u'Dörte@Sörensen.example.com') - user.validate() - - def test_email_field_domain_whitelist(self): - class User(Document): - email = EmailField() - - # localhost domain shouldn't validate by default... - user = User(email='me@localhost') - self.assertRaises(ValidationError, user.validate) - - # ...but it should be fine if it's whitelisted - class User(Document): - email = EmailField(domain_whitelist=['localhost']) - - user = User(email='me@localhost') - user.validate() - - def test_email_field_ip_domain(self): - class User(Document): - email = EmailField() - - valid_ipv4 = 'email@[127.0.0.1]' - valid_ipv6 = 'email@[2001:dB8::1]' - invalid_ip = 'email@[324.0.0.1]' - - # IP address as a domain shouldn't validate by default... - user = User(email=valid_ipv4) - self.assertRaises(ValidationError, user.validate) - - user = User(email=valid_ipv6) - self.assertRaises(ValidationError, user.validate) - - user = User(email=invalid_ip) - self.assertRaises(ValidationError, user.validate) - - # ...but it should be fine with allow_ip_domain set to True - class User(Document): - email = EmailField(allow_ip_domain=True) - - user = User(email=valid_ipv4) - user.validate() - - user = User(email=valid_ipv6) - user.validate() - - # invalid IP should still fail validation - user = User(email=invalid_ip) - self.assertRaises(ValidationError, user.validate) - - def test_email_field_honors_regex(self): - class User(Document): - email = EmailField(regex=r'\w+@example.com') - - # Fails regex validation - user = User(email='me@foo.com') - self.assertRaises(ValidationError, user.validate) - - # Passes regex validation - user = User(email='me@example.com') - self.assertIsNone(user.validate()) - def test_tuples_as_tuples(self): """Ensure that tuples remain tuples when they are inside a ComplexBaseField. @@ -3703,7 +2003,7 @@ class FieldTest(MongoDBTestCase): field_1 = StringField(db_field='f') class Doc(Document): - my_id = IntField(required=True, unique=True, primary_key=True) + my_id = IntField(primary_key=True) embed_me = DynamicField(db_field='e') field_x = StringField(db_field='x') @@ -3725,7 +2025,7 @@ class FieldTest(MongoDBTestCase): field_1 = StringField(db_field='f') class Doc(Document): - my_id = IntField(required=True, unique=True, primary_key=True) + my_id = IntField(primary_key=True) embed_me = DynamicField(db_field='e') field_x = StringField(db_field='x') @@ -3758,45 +2058,15 @@ class FieldTest(MongoDBTestCase): to_embed = ToEmbed(id=2, recursive=to_embed_recursive).save() doc = Doc(field=to_embed) doc.save() - assert isinstance(doc.field, ToEmbed) - assert doc.field == to_embed + self.assertIsInstance(doc.field, ToEmbed) + self.assertEqual(doc.field, to_embed) # Same thing with a Document with a _cls field to_embed_recursive = ToEmbedChild(id=1).save() to_embed_child = ToEmbedChild(id=2, recursive=to_embed_recursive).save() doc = Doc(field=to_embed_child) doc.save() - assert isinstance(doc.field, ToEmbedChild) - assert doc.field == to_embed_child - - def test_dict_field_invalid_dict_value(self): - class DictFieldTest(Document): - dictionary = DictField(required=True) - - DictFieldTest.drop_collection() - - test = DictFieldTest(dictionary=None) - test.dictionary # Just access to test getter - self.assertRaises(ValidationError, test.validate) - - test = DictFieldTest(dictionary=False) - test.dictionary # Just access to test getter - self.assertRaises(ValidationError, test.validate) - - def test_dict_field_raises_validation_error_if_wrongly_assign_embedded_doc(self): - class DictFieldTest(Document): - dictionary = DictField(required=True) - - DictFieldTest.drop_collection() - - class Embedded(EmbeddedDocument): - name = StringField() - - embed = Embedded(name='garbage') - doc = DictFieldTest(dictionary=embed) - with self.assertRaises(ValidationError) as ctx_err: - doc.validate() - self.assertIn("'dictionary'", str(ctx_err.exception)) - self.assertIn('Only dictionaries may be used in a DictField', str(ctx_err.exception)) + self.assertIsInstance(doc.field, ToEmbedChild) + self.assertEqual(doc.field, to_embed_child) def test_cls_field(self): class Animal(Document): @@ -3818,8 +2088,8 @@ class FieldTest(MongoDBTestCase): Dog().save() Fish().save() Human().save() - self.assertEquals(Animal.objects(_cls__in=["Animal.Mammal.Dog", "Animal.Fish"]).count(), 2) - self.assertEquals(Animal.objects(_cls__in=["Animal.Fish.Guppy"]).count(), 0) + self.assertEqual(Animal.objects(_cls__in=["Animal.Mammal.Dog", "Animal.Fish"]).count(), 2) + self.assertEqual(Animal.objects(_cls__in=["Animal.Fish.Guppy"]).count(), 0) def test_sparse_field(self): class Doc(Document): @@ -3852,19 +2122,6 @@ class FieldTest(MongoDBTestCase): with self.assertRaises(FieldDoesNotExist): Doc(bar='test') - def test_long_field_is_considered_as_int64(self): - """ - Tests that long fields are stored as long in mongo, even if long - value is small enough to be an int. - """ - class TestLongFieldConsideredAsInt64(Document): - some_long = LongField() - - doc = TestLongFieldConsideredAsInt64(some_long=42).save() - db = get_db() - self.assertIsInstance(db.test_long_field_considered_as_int64.find()[0]['some_long'], Int64) - self.assertIsInstance(doc.some_long, six.integer_types) - class EmbeddedDocumentListFieldTestCase(MongoDBTestCase): @@ -4335,1173 +2592,5 @@ class EmbeddedDocumentListFieldTestCase(MongoDBTestCase): self.assertEqual(custom_data['a'], CustomData.c_field.custom_data['a']) -class TestEmbeddedDocumentField(MongoDBTestCase): - def test___init___(self): - class MyDoc(EmbeddedDocument): - name = StringField() - - field = EmbeddedDocumentField(MyDoc) - self.assertEqual(field.document_type_obj, MyDoc) - - field2 = EmbeddedDocumentField('MyDoc') - self.assertEqual(field2.document_type_obj, 'MyDoc') - - def test___init___throw_error_if_document_type_is_not_EmbeddedDocument(self): - with self.assertRaises(ValidationError): - EmbeddedDocumentField(dict) - - def test_document_type_throw_error_if_not_EmbeddedDocument_subclass(self): - - class MyDoc(Document): - name = StringField() - - emb = EmbeddedDocumentField('MyDoc') - with self.assertRaises(ValidationError) as ctx: - emb.document_type - self.assertIn('Invalid embedded document class provided to an EmbeddedDocumentField', str(ctx.exception)) - - def test_embedded_document_field_only_allow_subclasses_of_embedded_document(self): - # Relates to #1661 - class MyDoc(Document): - name = StringField() - - with self.assertRaises(ValidationError): - class MyFailingDoc(Document): - emb = EmbeddedDocumentField(MyDoc) - - with self.assertRaises(ValidationError): - class MyFailingdoc2(Document): - emb = EmbeddedDocumentField('MyDoc') - - -class CachedReferenceFieldTest(MongoDBTestCase): - - def test_cached_reference_field_get_and_save(self): - """ - Tests #1047: CachedReferenceField creates DBRefs on to_python, - but can't save them on to_mongo. - """ - class Animal(Document): - name = StringField() - tag = StringField() - - class Ocorrence(Document): - person = StringField() - animal = CachedReferenceField(Animal) - - Animal.drop_collection() - Ocorrence.drop_collection() - - Ocorrence(person="testte", - animal=Animal(name="Leopard", tag="heavy").save()).save() - p = Ocorrence.objects.get() - p.person = 'new_testte' - p.save() - - def test_cached_reference_fields(self): - class Animal(Document): - name = StringField() - tag = StringField() - - class Ocorrence(Document): - person = StringField() - animal = CachedReferenceField( - Animal, fields=['tag']) - - Animal.drop_collection() - Ocorrence.drop_collection() - - a = Animal(name="Leopard", tag="heavy") - a.save() - - self.assertEqual(Animal._cached_reference_fields, [Ocorrence.animal]) - o = Ocorrence(person="teste", animal=a) - o.save() - - p = Ocorrence(person="Wilson") - p.save() - - self.assertEqual(Ocorrence.objects(animal=None).count(), 1) - - self.assertEqual( - a.to_mongo(fields=['tag']), {'tag': 'heavy', "_id": a.pk}) - - self.assertEqual(o.to_mongo()['animal']['tag'], 'heavy') - - # counts - Ocorrence(person="teste 2").save() - Ocorrence(person="teste 3").save() - - count = Ocorrence.objects(animal__tag='heavy').count() - self.assertEqual(count, 1) - - ocorrence = Ocorrence.objects(animal__tag='heavy').first() - self.assertEqual(ocorrence.person, "teste") - self.assertIsInstance(ocorrence.animal, Animal) - - def test_cached_reference_field_decimal(self): - class PersonAuto(Document): - name = StringField() - salary = DecimalField() - - class SocialTest(Document): - group = StringField() - person = CachedReferenceField( - PersonAuto, - fields=('salary',)) - - PersonAuto.drop_collection() - SocialTest.drop_collection() - - p = PersonAuto(name="Alberto", salary=Decimal('7000.00')) - p.save() - - s = SocialTest(group="dev", person=p) - s.save() - - self.assertEqual( - SocialTest.objects._collection.find_one({'person.salary': 7000.00}), { - '_id': s.pk, - 'group': s.group, - 'person': { - '_id': p.pk, - 'salary': 7000.00 - } - }) - - def test_cached_reference_field_reference(self): - class Group(Document): - name = StringField() - - class Person(Document): - name = StringField() - group = ReferenceField(Group) - - class SocialData(Document): - obs = StringField() - tags = ListField( - StringField()) - person = CachedReferenceField( - Person, - fields=('group',)) - - Group.drop_collection() - Person.drop_collection() - SocialData.drop_collection() - - g1 = Group(name='dev') - g1.save() - - g2 = Group(name="designers") - g2.save() - - p1 = Person(name="Alberto", group=g1) - p1.save() - - p2 = Person(name="Andre", group=g1) - p2.save() - - p3 = Person(name="Afro design", group=g2) - p3.save() - - s1 = SocialData(obs="testing 123", person=p1, tags=['tag1', 'tag2']) - s1.save() - - s2 = SocialData(obs="testing 321", person=p3, tags=['tag3', 'tag4']) - s2.save() - - self.assertEqual(SocialData.objects._collection.find_one( - {'tags': 'tag2'}), { - '_id': s1.pk, - 'obs': 'testing 123', - 'tags': ['tag1', 'tag2'], - 'person': { - '_id': p1.pk, - 'group': g1.pk - } - }) - - self.assertEqual(SocialData.objects(person__group=g2).count(), 1) - self.assertEqual(SocialData.objects(person__group=g2).first(), s2) - - def test_cached_reference_field_push_with_fields(self): - class Product(Document): - name = StringField() - - Product.drop_collection() - - class Basket(Document): - products = ListField(CachedReferenceField(Product, fields=['name'])) - - Basket.drop_collection() - product1 = Product(name='abc').save() - product2 = Product(name='def').save() - basket = Basket(products=[product1]).save() - self.assertEqual( - Basket.objects._collection.find_one(), - { - '_id': basket.pk, - 'products': [ - { - '_id': product1.pk, - 'name': product1.name - } - ] - } - ) - # push to list - basket.update(push__products=product2) - basket.reload() - self.assertEqual( - Basket.objects._collection.find_one(), - { - '_id': basket.pk, - 'products': [ - { - '_id': product1.pk, - 'name': product1.name - }, - { - '_id': product2.pk, - 'name': product2.name - } - ] - } - ) - - def test_cached_reference_field_update_all(self): - class Person(Document): - TYPES = ( - ('pf', "PF"), - ('pj', "PJ") - ) - name = StringField() - tp = StringField( - choices=TYPES - ) - - father = CachedReferenceField('self', fields=('tp',)) - - Person.drop_collection() - - a1 = Person(name="Wilson Father", tp="pj") - a1.save() - - a2 = Person(name='Wilson Junior', tp='pf', father=a1) - a2.save() - - self.assertEqual(dict(a2.to_mongo()), { - "_id": a2.pk, - "name": u"Wilson Junior", - "tp": u"pf", - "father": { - "_id": a1.pk, - "tp": u"pj" - } - }) - - self.assertEqual(Person.objects(father=a1)._query, { - 'father._id': a1.pk - }) - self.assertEqual(Person.objects(father=a1).count(), 1) - - Person.objects.update(set__tp="pf") - Person.father.sync_all() - - a2.reload() - self.assertEqual(dict(a2.to_mongo()), { - "_id": a2.pk, - "name": u"Wilson Junior", - "tp": u"pf", - "father": { - "_id": a1.pk, - "tp": u"pf" - } - }) - - def test_cached_reference_fields_on_embedded_documents(self): - with self.assertRaises(InvalidDocumentError): - class Test(Document): - name = StringField() - - type('WrongEmbeddedDocument', ( - EmbeddedDocument,), { - 'test': CachedReferenceField(Test) - }) - - def test_cached_reference_auto_sync(self): - class Person(Document): - TYPES = ( - ('pf', "PF"), - ('pj', "PJ") - ) - name = StringField() - tp = StringField( - choices=TYPES - ) - - father = CachedReferenceField('self', fields=('tp',)) - - Person.drop_collection() - - a1 = Person(name="Wilson Father", tp="pj") - a1.save() - - a2 = Person(name='Wilson Junior', tp='pf', father=a1) - a2.save() - - a1.tp = 'pf' - a1.save() - - a2.reload() - self.assertEqual(dict(a2.to_mongo()), { - '_id': a2.pk, - 'name': 'Wilson Junior', - 'tp': 'pf', - 'father': { - '_id': a1.pk, - 'tp': 'pf' - } - }) - - def test_cached_reference_auto_sync_disabled(self): - class Persone(Document): - TYPES = ( - ('pf', "PF"), - ('pj', "PJ") - ) - name = StringField() - tp = StringField( - choices=TYPES - ) - - father = CachedReferenceField( - 'self', fields=('tp',), auto_sync=False) - - Persone.drop_collection() - - a1 = Persone(name="Wilson Father", tp="pj") - a1.save() - - a2 = Persone(name='Wilson Junior', tp='pf', father=a1) - a2.save() - - a1.tp = 'pf' - a1.save() - - self.assertEqual(Persone.objects._collection.find_one({'_id': a2.pk}), { - '_id': a2.pk, - 'name': 'Wilson Junior', - 'tp': 'pf', - 'father': { - '_id': a1.pk, - 'tp': 'pj' - } - }) - - def test_cached_reference_embedded_fields(self): - class Owner(EmbeddedDocument): - TPS = ( - ('n', "Normal"), - ('u', "Urgent") - ) - name = StringField() - tp = StringField( - verbose_name="Type", - db_field="t", - choices=TPS) - - class Animal(Document): - name = StringField() - tag = StringField() - - owner = EmbeddedDocumentField(Owner) - - class Ocorrence(Document): - person = StringField() - animal = CachedReferenceField( - Animal, fields=['tag', 'owner.tp']) - - Animal.drop_collection() - Ocorrence.drop_collection() - - a = Animal(name="Leopard", tag="heavy", - owner=Owner(tp='u', name="Wilson Júnior") - ) - a.save() - - o = Ocorrence(person="teste", animal=a) - o.save() - self.assertEqual(dict(a.to_mongo(fields=['tag', 'owner.tp'])), { - '_id': a.pk, - 'tag': 'heavy', - 'owner': { - 't': 'u' - } - }) - self.assertEqual(o.to_mongo()['animal']['tag'], 'heavy') - self.assertEqual(o.to_mongo()['animal']['owner']['t'], 'u') - - # counts - Ocorrence(person="teste 2").save() - Ocorrence(person="teste 3").save() - - count = Ocorrence.objects( - animal__tag='heavy', animal__owner__tp='u').count() - self.assertEqual(count, 1) - - ocorrence = Ocorrence.objects( - animal__tag='heavy', - animal__owner__tp='u').first() - self.assertEqual(ocorrence.person, "teste") - self.assertIsInstance(ocorrence.animal, Animal) - - def test_cached_reference_embedded_list_fields(self): - class Owner(EmbeddedDocument): - name = StringField() - tags = ListField(StringField()) - - class Animal(Document): - name = StringField() - tag = StringField() - - owner = EmbeddedDocumentField(Owner) - - class Ocorrence(Document): - person = StringField() - animal = CachedReferenceField( - Animal, fields=['tag', 'owner.tags']) - - Animal.drop_collection() - Ocorrence.drop_collection() - - a = Animal(name="Leopard", tag="heavy", - owner=Owner(tags=['cool', 'funny'], - name="Wilson Júnior") - ) - a.save() - - o = Ocorrence(person="teste 2", animal=a) - o.save() - self.assertEqual(dict(a.to_mongo(fields=['tag', 'owner.tags'])), { - '_id': a.pk, - 'tag': 'heavy', - 'owner': { - 'tags': ['cool', 'funny'] - } - }) - - self.assertEqual(o.to_mongo()['animal']['tag'], 'heavy') - self.assertEqual(o.to_mongo()['animal']['owner']['tags'], - ['cool', 'funny']) - - # counts - Ocorrence(person="teste 2").save() - Ocorrence(person="teste 3").save() - - query = Ocorrence.objects( - animal__tag='heavy', animal__owner__tags='cool')._query - self.assertEqual( - query, {'animal.owner.tags': 'cool', 'animal.tag': 'heavy'}) - - ocorrence = Ocorrence.objects( - animal__tag='heavy', - animal__owner__tags='cool').first() - self.assertEqual(ocorrence.person, "teste 2") - self.assertIsInstance(ocorrence.animal, Animal) - - -class LazyReferenceFieldTest(MongoDBTestCase): - def test_lazy_reference_config(self): - # Make sure ReferenceField only accepts a document class or a string - # with a document class name. - self.assertRaises(ValidationError, LazyReferenceField, EmbeddedDocument) - - def test_lazy_reference_simple(self): - class Animal(Document): - name = StringField() - tag = StringField() - - class Ocurrence(Document): - person = StringField() - animal = LazyReferenceField(Animal) - - Animal.drop_collection() - Ocurrence.drop_collection() - - animal = Animal(name="Leopard", tag="heavy").save() - Ocurrence(person="test", animal=animal).save() - p = Ocurrence.objects.get() - self.assertIsInstance(p.animal, LazyReference) - fetched_animal = p.animal.fetch() - self.assertEqual(fetched_animal, animal) - # `fetch` keep cache on referenced document by default... - animal.tag = "not so heavy" - animal.save() - double_fetch = p.animal.fetch() - self.assertIs(fetched_animal, double_fetch) - self.assertEqual(double_fetch.tag, "heavy") - # ...unless specified otherwise - fetch_force = p.animal.fetch(force=True) - self.assertIsNot(fetch_force, fetched_animal) - self.assertEqual(fetch_force.tag, "not so heavy") - - def test_lazy_reference_fetch_invalid_ref(self): - class Animal(Document): - name = StringField() - tag = StringField() - - class Ocurrence(Document): - person = StringField() - animal = LazyReferenceField(Animal) - - Animal.drop_collection() - Ocurrence.drop_collection() - - animal = Animal(name="Leopard", tag="heavy").save() - Ocurrence(person="test", animal=animal).save() - animal.delete() - p = Ocurrence.objects.get() - self.assertIsInstance(p.animal, LazyReference) - with self.assertRaises(DoesNotExist): - p.animal.fetch() - - def test_lazy_reference_set(self): - class Animal(Document): - meta = {'allow_inheritance': True} - - name = StringField() - tag = StringField() - - class Ocurrence(Document): - person = StringField() - animal = LazyReferenceField(Animal) - - Animal.drop_collection() - Ocurrence.drop_collection() - - class SubAnimal(Animal): - nick = StringField() - - animal = Animal(name="Leopard", tag="heavy").save() - sub_animal = SubAnimal(nick='doggo', name='dog').save() - for ref in ( - animal, - animal.pk, - DBRef(animal._get_collection_name(), animal.pk), - LazyReference(Animal, animal.pk), - - sub_animal, - sub_animal.pk, - DBRef(sub_animal._get_collection_name(), sub_animal.pk), - LazyReference(SubAnimal, sub_animal.pk), - ): - p = Ocurrence(person="test", animal=ref).save() - p.reload() - self.assertIsInstance(p.animal, LazyReference) - p.animal.fetch() - - def test_lazy_reference_bad_set(self): - class Animal(Document): - name = StringField() - tag = StringField() - - class Ocurrence(Document): - person = StringField() - animal = LazyReferenceField(Animal) - - Animal.drop_collection() - Ocurrence.drop_collection() - - class BadDoc(Document): - pass - - animal = Animal(name="Leopard", tag="heavy").save() - baddoc = BadDoc().save() - for bad in ( - 42, - 'foo', - baddoc, - DBRef(baddoc._get_collection_name(), animal.pk), - LazyReference(BadDoc, animal.pk) - ): - with self.assertRaises(ValidationError): - p = Ocurrence(person="test", animal=bad).save() - - def test_lazy_reference_query_conversion(self): - """Ensure that LazyReferenceFields can be queried using objects and values - of the type of the primary key of the referenced object. - """ - class Member(Document): - user_num = IntField(primary_key=True) - - class BlogPost(Document): - title = StringField() - author = LazyReferenceField(Member, dbref=False) - - Member.drop_collection() - BlogPost.drop_collection() - - m1 = Member(user_num=1) - m1.save() - m2 = Member(user_num=2) - m2.save() - - post1 = BlogPost(title='post 1', author=m1) - post1.save() - - post2 = BlogPost(title='post 2', author=m2) - post2.save() - - post = BlogPost.objects(author=m1).first() - self.assertEqual(post.id, post1.id) - - post = BlogPost.objects(author=m2).first() - self.assertEqual(post.id, post2.id) - - # Same thing by passing a LazyReference instance - post = BlogPost.objects(author=LazyReference(Member, m2.pk)).first() - self.assertEqual(post.id, post2.id) - - def test_lazy_reference_query_conversion_dbref(self): - """Ensure that LazyReferenceFields can be queried using objects and values - of the type of the primary key of the referenced object. - """ - class Member(Document): - user_num = IntField(primary_key=True) - - class BlogPost(Document): - title = StringField() - author = LazyReferenceField(Member, dbref=True) - - Member.drop_collection() - BlogPost.drop_collection() - - m1 = Member(user_num=1) - m1.save() - m2 = Member(user_num=2) - m2.save() - - post1 = BlogPost(title='post 1', author=m1) - post1.save() - - post2 = BlogPost(title='post 2', author=m2) - post2.save() - - post = BlogPost.objects(author=m1).first() - self.assertEqual(post.id, post1.id) - - post = BlogPost.objects(author=m2).first() - self.assertEqual(post.id, post2.id) - - # Same thing by passing a LazyReference instance - post = BlogPost.objects(author=LazyReference(Member, m2.pk)).first() - self.assertEqual(post.id, post2.id) - - def test_lazy_reference_passthrough(self): - class Animal(Document): - name = StringField() - tag = StringField() - - class Ocurrence(Document): - animal = LazyReferenceField(Animal, passthrough=False) - animal_passthrough = LazyReferenceField(Animal, passthrough=True) - - Animal.drop_collection() - Ocurrence.drop_collection() - - animal = Animal(name="Leopard", tag="heavy").save() - Ocurrence(animal=animal, animal_passthrough=animal).save() - p = Ocurrence.objects.get() - self.assertIsInstance(p.animal, LazyReference) - with self.assertRaises(KeyError): - p.animal['name'] - with self.assertRaises(AttributeError): - p.animal.name - self.assertEqual(p.animal.pk, animal.pk) - - self.assertEqual(p.animal_passthrough.name, "Leopard") - self.assertEqual(p.animal_passthrough['name'], "Leopard") - - # Should not be able to access referenced document's methods - with self.assertRaises(AttributeError): - p.animal.save - with self.assertRaises(KeyError): - p.animal['save'] - - def test_lazy_reference_not_set(self): - class Animal(Document): - name = StringField() - tag = StringField() - - class Ocurrence(Document): - person = StringField() - animal = LazyReferenceField(Animal) - - Animal.drop_collection() - Ocurrence.drop_collection() - - Ocurrence(person='foo').save() - p = Ocurrence.objects.get() - self.assertIs(p.animal, None) - - def test_lazy_reference_equality(self): - class Animal(Document): - name = StringField() - tag = StringField() - - Animal.drop_collection() - - animal = Animal(name="Leopard", tag="heavy").save() - animalref = LazyReference(Animal, animal.pk) - self.assertEqual(animal, animalref) - self.assertEqual(animalref, animal) - - other_animalref = LazyReference(Animal, ObjectId("54495ad94c934721ede76f90")) - self.assertNotEqual(animal, other_animalref) - self.assertNotEqual(other_animalref, animal) - - def test_lazy_reference_embedded(self): - class Animal(Document): - name = StringField() - tag = StringField() - - class EmbeddedOcurrence(EmbeddedDocument): - in_list = ListField(LazyReferenceField(Animal)) - direct = LazyReferenceField(Animal) - - class Ocurrence(Document): - in_list = ListField(LazyReferenceField(Animal)) - in_embedded = EmbeddedDocumentField(EmbeddedOcurrence) - direct = LazyReferenceField(Animal) - - Animal.drop_collection() - Ocurrence.drop_collection() - - animal1 = Animal('doggo').save() - animal2 = Animal('cheeta').save() - - def check_fields_type(occ): - self.assertIsInstance(occ.direct, LazyReference) - for elem in occ.in_list: - self.assertIsInstance(elem, LazyReference) - self.assertIsInstance(occ.in_embedded.direct, LazyReference) - for elem in occ.in_embedded.in_list: - self.assertIsInstance(elem, LazyReference) - - occ = Ocurrence( - in_list=[animal1, animal2], - in_embedded={'in_list': [animal1, animal2], 'direct': animal1}, - direct=animal1 - ).save() - check_fields_type(occ) - occ.reload() - check_fields_type(occ) - occ.direct = animal1.id - occ.in_list = [animal1.id, animal2.id] - occ.in_embedded.direct = animal1.id - occ.in_embedded.in_list = [animal1.id, animal2.id] - check_fields_type(occ) - - -class GenericLazyReferenceFieldTest(MongoDBTestCase): - def test_generic_lazy_reference_simple(self): - class Animal(Document): - name = StringField() - tag = StringField() - - class Ocurrence(Document): - person = StringField() - animal = GenericLazyReferenceField() - - Animal.drop_collection() - Ocurrence.drop_collection() - - animal = Animal(name="Leopard", tag="heavy").save() - Ocurrence(person="test", animal=animal).save() - p = Ocurrence.objects.get() - self.assertIsInstance(p.animal, LazyReference) - fetched_animal = p.animal.fetch() - self.assertEqual(fetched_animal, animal) - # `fetch` keep cache on referenced document by default... - animal.tag = "not so heavy" - animal.save() - double_fetch = p.animal.fetch() - self.assertIs(fetched_animal, double_fetch) - self.assertEqual(double_fetch.tag, "heavy") - # ...unless specified otherwise - fetch_force = p.animal.fetch(force=True) - self.assertIsNot(fetch_force, fetched_animal) - self.assertEqual(fetch_force.tag, "not so heavy") - - def test_generic_lazy_reference_choices(self): - class Animal(Document): - name = StringField() - - class Vegetal(Document): - name = StringField() - - class Mineral(Document): - name = StringField() - - class Ocurrence(Document): - living_thing = GenericLazyReferenceField(choices=[Animal, Vegetal]) - thing = GenericLazyReferenceField() - - Animal.drop_collection() - Vegetal.drop_collection() - Mineral.drop_collection() - Ocurrence.drop_collection() - - animal = Animal(name="Leopard").save() - vegetal = Vegetal(name="Oak").save() - mineral = Mineral(name="Granite").save() - - occ_animal = Ocurrence(living_thing=animal, thing=animal).save() - occ_vegetal = Ocurrence(living_thing=vegetal, thing=vegetal).save() - with self.assertRaises(ValidationError): - Ocurrence(living_thing=mineral).save() - - occ = Ocurrence.objects.get(living_thing=animal) - self.assertEqual(occ, occ_animal) - self.assertIsInstance(occ.thing, LazyReference) - self.assertIsInstance(occ.living_thing, LazyReference) - - occ.thing = vegetal - occ.living_thing = vegetal - occ.save() - - occ.thing = mineral - occ.living_thing = mineral - with self.assertRaises(ValidationError): - occ.save() - - def test_generic_lazy_reference_set(self): - class Animal(Document): - meta = {'allow_inheritance': True} - - name = StringField() - tag = StringField() - - class Ocurrence(Document): - person = StringField() - animal = GenericLazyReferenceField() - - Animal.drop_collection() - Ocurrence.drop_collection() - - class SubAnimal(Animal): - nick = StringField() - - animal = Animal(name="Leopard", tag="heavy").save() - sub_animal = SubAnimal(nick='doggo', name='dog').save() - for ref in ( - animal, - LazyReference(Animal, animal.pk), - {'_cls': 'Animal', '_ref': DBRef(animal._get_collection_name(), animal.pk)}, - - sub_animal, - LazyReference(SubAnimal, sub_animal.pk), - {'_cls': 'SubAnimal', '_ref': DBRef(sub_animal._get_collection_name(), sub_animal.pk)}, - ): - p = Ocurrence(person="test", animal=ref).save() - p.reload() - self.assertIsInstance(p.animal, (LazyReference, Document)) - p.animal.fetch() - - def test_generic_lazy_reference_bad_set(self): - class Animal(Document): - name = StringField() - tag = StringField() - - class Ocurrence(Document): - person = StringField() - animal = GenericLazyReferenceField(choices=['Animal']) - - Animal.drop_collection() - Ocurrence.drop_collection() - - class BadDoc(Document): - pass - - animal = Animal(name="Leopard", tag="heavy").save() - baddoc = BadDoc().save() - for bad in ( - 42, - 'foo', - baddoc, - LazyReference(BadDoc, animal.pk) - ): - with self.assertRaises(ValidationError): - p = Ocurrence(person="test", animal=bad).save() - - def test_generic_lazy_reference_query_conversion(self): - class Member(Document): - user_num = IntField(primary_key=True) - - class BlogPost(Document): - title = StringField() - author = GenericLazyReferenceField() - - Member.drop_collection() - BlogPost.drop_collection() - - m1 = Member(user_num=1) - m1.save() - m2 = Member(user_num=2) - m2.save() - - post1 = BlogPost(title='post 1', author=m1) - post1.save() - - post2 = BlogPost(title='post 2', author=m2) - post2.save() - - post = BlogPost.objects(author=m1).first() - self.assertEqual(post.id, post1.id) - - post = BlogPost.objects(author=m2).first() - self.assertEqual(post.id, post2.id) - - # Same thing by passing a LazyReference instance - post = BlogPost.objects(author=LazyReference(Member, m2.pk)).first() - self.assertEqual(post.id, post2.id) - - def test_generic_lazy_reference_not_set(self): - class Animal(Document): - name = StringField() - tag = StringField() - - class Ocurrence(Document): - person = StringField() - animal = GenericLazyReferenceField() - - Animal.drop_collection() - Ocurrence.drop_collection() - - Ocurrence(person='foo').save() - p = Ocurrence.objects.get() - self.assertIs(p.animal, None) - - def test_generic_lazy_reference_embedded(self): - class Animal(Document): - name = StringField() - tag = StringField() - - class EmbeddedOcurrence(EmbeddedDocument): - in_list = ListField(GenericLazyReferenceField()) - direct = GenericLazyReferenceField() - - class Ocurrence(Document): - in_list = ListField(GenericLazyReferenceField()) - in_embedded = EmbeddedDocumentField(EmbeddedOcurrence) - direct = GenericLazyReferenceField() - - Animal.drop_collection() - Ocurrence.drop_collection() - - animal1 = Animal('doggo').save() - animal2 = Animal('cheeta').save() - - def check_fields_type(occ): - self.assertIsInstance(occ.direct, LazyReference) - for elem in occ.in_list: - self.assertIsInstance(elem, LazyReference) - self.assertIsInstance(occ.in_embedded.direct, LazyReference) - for elem in occ.in_embedded.in_list: - self.assertIsInstance(elem, LazyReference) - - occ = Ocurrence( - in_list=[animal1, animal2], - in_embedded={'in_list': [animal1, animal2], 'direct': animal1}, - direct=animal1 - ).save() - check_fields_type(occ) - occ.reload() - check_fields_type(occ) - animal1_ref = {'_cls': 'Animal', '_ref': DBRef(animal1._get_collection_name(), animal1.pk)} - animal2_ref = {'_cls': 'Animal', '_ref': DBRef(animal2._get_collection_name(), animal2.pk)} - occ.direct = animal1_ref - occ.in_list = [animal1_ref, animal2_ref] - occ.in_embedded.direct = animal1_ref - occ.in_embedded.in_list = [animal1_ref, animal2_ref] - check_fields_type(occ) - - -class ComplexDateTimeFieldTest(MongoDBTestCase): - def test_complexdatetime_storage(self): - """Tests for complex datetime fields - which can handle - microseconds without rounding. - """ - class LogEntry(Document): - date = ComplexDateTimeField() - date_with_dots = ComplexDateTimeField(separator='.') - - LogEntry.drop_collection() - - # Post UTC - microseconds are rounded (down) nearest millisecond and - # dropped - with default datetimefields - d1 = datetime.datetime(1970, 1, 1, 0, 0, 1, 999) - log = LogEntry() - log.date = d1 - log.save() - log.reload() - self.assertEqual(log.date, d1) - - # Post UTC - microseconds are rounded (down) nearest millisecond - with - # default datetimefields - d1 = datetime.datetime(1970, 1, 1, 0, 0, 1, 9999) - log.date = d1 - log.save() - log.reload() - self.assertEqual(log.date, d1) - - # Pre UTC dates microseconds below 1000 are dropped - with default - # datetimefields - d1 = datetime.datetime(1969, 12, 31, 23, 59, 59, 999) - log.date = d1 - log.save() - log.reload() - self.assertEqual(log.date, d1) - - # Pre UTC microseconds above 1000 is wonky - with default datetimefields - # log.date has an invalid microsecond value so I can't construct - # a date to compare. - for i in range(1001, 3113, 33): - d1 = datetime.datetime(1969, 12, 31, 23, 59, 59, i) - log.date = d1 - log.save() - log.reload() - self.assertEqual(log.date, d1) - log1 = LogEntry.objects.get(date=d1) - self.assertEqual(log, log1) - - # Test string padding - microsecond = map(int, [math.pow(10, x) for x in range(6)]) - mm = dd = hh = ii = ss = [1, 10] - - for values in itertools.product([2014], mm, dd, hh, ii, ss, microsecond): - stored = LogEntry(date=datetime.datetime(*values)).to_mongo()['date'] - self.assertTrue(re.match('^\d{4},\d{2},\d{2},\d{2},\d{2},\d{2},\d{6}$', stored) is not None) - - # Test separator - stored = LogEntry(date_with_dots=datetime.datetime(2014, 1, 1)).to_mongo()['date_with_dots'] - self.assertTrue(re.match('^\d{4}.\d{2}.\d{2}.\d{2}.\d{2}.\d{2}.\d{6}$', stored) is not None) - - def test_complexdatetime_usage(self): - """Tests for complex datetime fields - which can handle - microseconds without rounding. - """ - class LogEntry(Document): - date = ComplexDateTimeField() - - LogEntry.drop_collection() - - d1 = datetime.datetime(1950, 1, 1, 0, 0, 1, 999) - log = LogEntry() - log.date = d1 - log.save() - - log1 = LogEntry.objects.get(date=d1) - self.assertEqual(log, log1) - - # create extra 59 log entries for a total of 60 - for i in range(1951, 2010): - d = datetime.datetime(i, 1, 1, 0, 0, 1, 999) - LogEntry(date=d).save() - - self.assertEqual(LogEntry.objects.count(), 60) - - # Test ordering - logs = LogEntry.objects.order_by("date") - i = 0 - while i < 59: - self.assertTrue(logs[i].date <= logs[i + 1].date) - i += 1 - - logs = LogEntry.objects.order_by("-date") - i = 0 - while i < 59: - self.assertTrue(logs[i].date >= logs[i + 1].date) - i += 1 - - # Test searching - logs = LogEntry.objects.filter(date__gte=datetime.datetime(1980, 1, 1)) - self.assertEqual(logs.count(), 30) - - logs = LogEntry.objects.filter(date__lte=datetime.datetime(1980, 1, 1)) - self.assertEqual(logs.count(), 30) - - logs = LogEntry.objects.filter( - date__lte=datetime.datetime(2011, 1, 1), - date__gte=datetime.datetime(2000, 1, 1), - ) - self.assertEqual(logs.count(), 10) - - LogEntry.drop_collection() - - # Test microsecond-level ordering/filtering - for microsecond in (99, 999, 9999, 10000): - LogEntry( - date=datetime.datetime(2015, 1, 1, 0, 0, 0, microsecond) - ).save() - - logs = list(LogEntry.objects.order_by('date')) - for next_idx, log in enumerate(logs[:-1], start=1): - next_log = logs[next_idx] - self.assertTrue(log.date < next_log.date) - - logs = list(LogEntry.objects.order_by('-date')) - for next_idx, log in enumerate(logs[:-1], start=1): - next_log = logs[next_idx] - self.assertTrue(log.date > next_log.date) - - logs = LogEntry.objects.filter( - date__lte=datetime.datetime(2015, 1, 1, 0, 0, 0, 10000)) - self.assertEqual(logs.count(), 4) - - def test_no_default_value(self): - class Log(Document): - timestamp = ComplexDateTimeField() - - Log.drop_collection() - - log = Log() - self.assertIsNone(log.timestamp) - log.save() - - fetched_log = Log.objects.with_id(log.id) - self.assertIsNone(fetched_log.timestamp) - - def test_default_static_value(self): - NOW = datetime.datetime.utcnow() - class Log(Document): - timestamp = ComplexDateTimeField(default=NOW) - - Log.drop_collection() - - log = Log() - self.assertEqual(log.timestamp, NOW) - log.save() - - fetched_log = Log.objects.with_id(log.id) - self.assertEqual(fetched_log.timestamp, NOW) - - def test_default_callable(self): - NOW = datetime.datetime.utcnow() - - class Log(Document): - timestamp = ComplexDateTimeField(default=datetime.datetime.utcnow) - - Log.drop_collection() - - log = Log() - self.assertGreaterEqual(log.timestamp, NOW) - log.save() - - fetched_log = Log.objects.with_id(log.id) - self.assertGreaterEqual(fetched_log.timestamp, NOW) - - if __name__ == '__main__': unittest.main() diff --git a/tests/fields/file_tests.py b/tests/fields/file_tests.py index 213e889c..a7722458 100644 --- a/tests/fields/file_tests.py +++ b/tests/fields/file_tests.py @@ -24,6 +24,16 @@ TEST_IMAGE_PATH = os.path.join(os.path.dirname(__file__), 'mongoengine.png') TEST_IMAGE2_PATH = os.path.join(os.path.dirname(__file__), 'mongodb_leaf.png') +def get_file(path): + """Use a BytesIO instead of a file to allow + to have a one-liner and avoid that the file remains opened""" + bytes_io = StringIO() + with open(path, 'rb') as f: + bytes_io.write(f.read()) + bytes_io.seek(0) + return bytes_io + + class FileTest(MongoDBTestCase): def tearDown(self): @@ -247,8 +257,8 @@ class FileTest(MongoDBTestCase): Animal.drop_collection() marmot = Animal(genus='Marmota', family='Sciuridae') - marmot_photo = open(TEST_IMAGE_PATH, 'rb') # Retrieve a photo from disk - marmot.photo.put(marmot_photo, content_type='image/jpeg', foo='bar') + marmot_photo_content = get_file(TEST_IMAGE_PATH) # Retrieve a photo from disk + marmot.photo.put(marmot_photo_content, content_type='image/jpeg', foo='bar') marmot.photo.close() marmot.save() @@ -261,11 +271,11 @@ class FileTest(MongoDBTestCase): the_file = FileField() TestFile.drop_collection() - test_file = TestFile(the_file=open(TEST_IMAGE_PATH, 'rb')).save() + test_file = TestFile(the_file=get_file(TEST_IMAGE_PATH)).save() self.assertEqual(test_file.the_file.get().length, 8313) test_file = TestFile.objects.first() - test_file.the_file = open(TEST_IMAGE2_PATH, 'rb') + test_file.the_file = get_file(TEST_IMAGE2_PATH) test_file.save() self.assertEqual(test_file.the_file.get().length, 4971) @@ -310,16 +320,16 @@ class FileTest(MongoDBTestCase): files = db.fs.files.find() chunks = db.fs.chunks.find() - self.assertEquals(len(list(files)), 1) - self.assertEquals(len(list(chunks)), 1) + self.assertEqual(len(list(files)), 1) + self.assertEqual(len(list(chunks)), 1) # Deleting the docoument should delete the files testfile.delete() files = db.fs.files.find() chunks = db.fs.chunks.find() - self.assertEquals(len(list(files)), 0) - self.assertEquals(len(list(chunks)), 0) + self.assertEqual(len(list(files)), 0) + self.assertEqual(len(list(chunks)), 0) # Test case where we don't store a file in the first place testfile = TestFile() @@ -327,15 +337,15 @@ class FileTest(MongoDBTestCase): files = db.fs.files.find() chunks = db.fs.chunks.find() - self.assertEquals(len(list(files)), 0) - self.assertEquals(len(list(chunks)), 0) + self.assertEqual(len(list(files)), 0) + self.assertEqual(len(list(chunks)), 0) testfile.delete() files = db.fs.files.find() chunks = db.fs.chunks.find() - self.assertEquals(len(list(files)), 0) - self.assertEquals(len(list(chunks)), 0) + self.assertEqual(len(list(files)), 0) + self.assertEqual(len(list(chunks)), 0) # Test case where we overwrite the file testfile = TestFile() @@ -348,15 +358,15 @@ class FileTest(MongoDBTestCase): files = db.fs.files.find() chunks = db.fs.chunks.find() - self.assertEquals(len(list(files)), 1) - self.assertEquals(len(list(chunks)), 1) + self.assertEqual(len(list(files)), 1) + self.assertEqual(len(list(chunks)), 1) testfile.delete() files = db.fs.files.find() chunks = db.fs.chunks.find() - self.assertEquals(len(list(files)), 0) - self.assertEquals(len(list(chunks)), 0) + self.assertEqual(len(list(files)), 0) + self.assertEqual(len(list(chunks)), 0) def test_image_field(self): if not HAS_PIL: @@ -379,7 +389,7 @@ class FileTest(MongoDBTestCase): self.assertEqual("%s" % e, "Invalid image: cannot identify image file %s" % f) t = TestImage() - t.image.put(open(TEST_IMAGE_PATH, 'rb')) + t.image.put(get_file(TEST_IMAGE_PATH)) t.save() t = TestImage.objects.first() @@ -400,11 +410,11 @@ class FileTest(MongoDBTestCase): the_file = ImageField() TestFile.drop_collection() - test_file = TestFile(the_file=open(TEST_IMAGE_PATH, 'rb')).save() + test_file = TestFile(the_file=get_file(TEST_IMAGE_PATH)).save() self.assertEqual(test_file.the_file.size, (371, 76)) test_file = TestFile.objects.first() - test_file.the_file = open(TEST_IMAGE2_PATH, 'rb') + test_file.the_file = get_file(TEST_IMAGE2_PATH) test_file.save() self.assertEqual(test_file.the_file.size, (45, 101)) @@ -418,7 +428,7 @@ class FileTest(MongoDBTestCase): TestImage.drop_collection() t = TestImage() - t.image.put(open(TEST_IMAGE_PATH, 'rb')) + t.image.put(get_file(TEST_IMAGE_PATH)) t.save() t = TestImage.objects.first() @@ -441,7 +451,7 @@ class FileTest(MongoDBTestCase): TestImage.drop_collection() t = TestImage() - t.image.put(open(TEST_IMAGE_PATH, 'rb')) + t.image.put(get_file(TEST_IMAGE_PATH)) t.save() t = TestImage.objects.first() @@ -464,7 +474,7 @@ class FileTest(MongoDBTestCase): TestImage.drop_collection() t = TestImage() - t.image.put(open(TEST_IMAGE_PATH, 'rb')) + t.image.put(get_file(TEST_IMAGE_PATH)) t.save() t = TestImage.objects.first() @@ -542,8 +552,8 @@ class FileTest(MongoDBTestCase): TestImage.drop_collection() t = TestImage() - t.image1.put(open(TEST_IMAGE_PATH, 'rb')) - t.image2.put(open(TEST_IMAGE2_PATH, 'rb')) + t.image1.put(get_file(TEST_IMAGE_PATH)) + t.image2.put(get_file(TEST_IMAGE2_PATH)) t.save() test = TestImage.objects.first() @@ -563,12 +573,10 @@ class FileTest(MongoDBTestCase): Animal.drop_collection() marmot = Animal(genus='Marmota', family='Sciuridae') - marmot_photo = open(TEST_IMAGE_PATH, 'rb') # Retrieve a photo from disk - - photos_field = marmot._fields['photos'].field - new_proxy = photos_field.get_proxy_obj('photos', marmot) - new_proxy.put(marmot_photo, content_type='image/jpeg', foo='bar') - marmot_photo.close() + with open(TEST_IMAGE_PATH, 'rb') as marmot_photo: # Retrieve a photo from disk + photos_field = marmot._fields['photos'].field + new_proxy = photos_field.get_proxy_obj('photos', marmot) + new_proxy.put(marmot_photo, content_type='image/jpeg', foo='bar') marmot.photos.append(new_proxy) marmot.save() @@ -578,5 +586,6 @@ class FileTest(MongoDBTestCase): self.assertEqual(marmot.photos[0].foo, 'bar') self.assertEqual(marmot.photos[0].get().length, 8313) + if __name__ == '__main__': unittest.main() diff --git a/tests/fields/geo.py b/tests/fields/geo.py index 754f4203..37ed97f5 100644 --- a/tests/fields/geo.py +++ b/tests/fields/geo.py @@ -40,6 +40,11 @@ class GeoFieldTest(unittest.TestCase): expected = "Both values (%s) in point must be float or int" % repr(coord) self._test_for_expected_error(Location, coord, expected) + invalid_coords = [21, 4, 'a'] + for coord in invalid_coords: + expected = "GeoPointField can only accept tuples or lists of (x, y)" + self._test_for_expected_error(Location, coord, expected) + def test_point_validation(self): class Location(Document): loc = PointField() diff --git a/tests/fields/test_boolean_field.py b/tests/fields/test_boolean_field.py new file mode 100644 index 00000000..7a2a3db6 --- /dev/null +++ b/tests/fields/test_boolean_field.py @@ -0,0 +1,49 @@ +# -*- coding: utf-8 -*- +from mongoengine import * + +from tests.utils import MongoDBTestCase, get_as_pymongo + + +class TestBooleanField(MongoDBTestCase): + def test_storage(self): + class Person(Document): + admin = BooleanField() + + person = Person(admin=True) + person.save() + self.assertEqual( + get_as_pymongo(person), + {'_id': person.id, + 'admin': True}) + + def test_validation(self): + """Ensure that invalid values cannot be assigned to boolean + fields. + """ + class Person(Document): + admin = BooleanField() + + person = Person() + person.admin = True + person.validate() + + person.admin = 2 + self.assertRaises(ValidationError, person.validate) + person.admin = 'Yes' + self.assertRaises(ValidationError, person.validate) + person.admin = 'False' + self.assertRaises(ValidationError, person.validate) + + def test_weirdness_constructor(self): + """When attribute is set in contructor, it gets cast into a bool + which causes some weird behavior. We dont necessarily want to maintain this behavior + but its a known issue + """ + class Person(Document): + admin = BooleanField() + + new_person = Person(admin='False') + self.assertTrue(new_person.admin) + + new_person = Person(admin='0') + self.assertTrue(new_person.admin) diff --git a/tests/fields/test_cached_reference_field.py b/tests/fields/test_cached_reference_field.py new file mode 100644 index 00000000..470ecc5d --- /dev/null +++ b/tests/fields/test_cached_reference_field.py @@ -0,0 +1,446 @@ +# -*- coding: utf-8 -*- +from decimal import Decimal + +from mongoengine import * + +from tests.utils import MongoDBTestCase + + +class TestCachedReferenceField(MongoDBTestCase): + + def test_get_and_save(self): + """ + Tests #1047: CachedReferenceField creates DBRefs on to_python, + but can't save them on to_mongo. + """ + class Animal(Document): + name = StringField() + tag = StringField() + + class Ocorrence(Document): + person = StringField() + animal = CachedReferenceField(Animal) + + Animal.drop_collection() + Ocorrence.drop_collection() + + Ocorrence(person="testte", + animal=Animal(name="Leopard", tag="heavy").save()).save() + p = Ocorrence.objects.get() + p.person = 'new_testte' + p.save() + + def test_general_things(self): + class Animal(Document): + name = StringField() + tag = StringField() + + class Ocorrence(Document): + person = StringField() + animal = CachedReferenceField( + Animal, fields=['tag']) + + Animal.drop_collection() + Ocorrence.drop_collection() + + a = Animal(name="Leopard", tag="heavy") + a.save() + + self.assertEqual(Animal._cached_reference_fields, [Ocorrence.animal]) + o = Ocorrence(person="teste", animal=a) + o.save() + + p = Ocorrence(person="Wilson") + p.save() + + self.assertEqual(Ocorrence.objects(animal=None).count(), 1) + + self.assertEqual( + a.to_mongo(fields=['tag']), {'tag': 'heavy', "_id": a.pk}) + + self.assertEqual(o.to_mongo()['animal']['tag'], 'heavy') + + # counts + Ocorrence(person="teste 2").save() + Ocorrence(person="teste 3").save() + + count = Ocorrence.objects(animal__tag='heavy').count() + self.assertEqual(count, 1) + + ocorrence = Ocorrence.objects(animal__tag='heavy').first() + self.assertEqual(ocorrence.person, "teste") + self.assertIsInstance(ocorrence.animal, Animal) + + def test_with_decimal(self): + class PersonAuto(Document): + name = StringField() + salary = DecimalField() + + class SocialTest(Document): + group = StringField() + person = CachedReferenceField( + PersonAuto, + fields=('salary',)) + + PersonAuto.drop_collection() + SocialTest.drop_collection() + + p = PersonAuto(name="Alberto", salary=Decimal('7000.00')) + p.save() + + s = SocialTest(group="dev", person=p) + s.save() + + self.assertEqual( + SocialTest.objects._collection.find_one({'person.salary': 7000.00}), { + '_id': s.pk, + 'group': s.group, + 'person': { + '_id': p.pk, + 'salary': 7000.00 + } + }) + + def test_cached_reference_field_reference(self): + class Group(Document): + name = StringField() + + class Person(Document): + name = StringField() + group = ReferenceField(Group) + + class SocialData(Document): + obs = StringField() + tags = ListField( + StringField()) + person = CachedReferenceField( + Person, + fields=('group',)) + + Group.drop_collection() + Person.drop_collection() + SocialData.drop_collection() + + g1 = Group(name='dev') + g1.save() + + g2 = Group(name="designers") + g2.save() + + p1 = Person(name="Alberto", group=g1) + p1.save() + + p2 = Person(name="Andre", group=g1) + p2.save() + + p3 = Person(name="Afro design", group=g2) + p3.save() + + s1 = SocialData(obs="testing 123", person=p1, tags=['tag1', 'tag2']) + s1.save() + + s2 = SocialData(obs="testing 321", person=p3, tags=['tag3', 'tag4']) + s2.save() + + self.assertEqual(SocialData.objects._collection.find_one( + {'tags': 'tag2'}), { + '_id': s1.pk, + 'obs': 'testing 123', + 'tags': ['tag1', 'tag2'], + 'person': { + '_id': p1.pk, + 'group': g1.pk + } + }) + + self.assertEqual(SocialData.objects(person__group=g2).count(), 1) + self.assertEqual(SocialData.objects(person__group=g2).first(), s2) + + def test_cached_reference_field_push_with_fields(self): + class Product(Document): + name = StringField() + + Product.drop_collection() + + class Basket(Document): + products = ListField(CachedReferenceField(Product, fields=['name'])) + + Basket.drop_collection() + product1 = Product(name='abc').save() + product2 = Product(name='def').save() + basket = Basket(products=[product1]).save() + self.assertEqual( + Basket.objects._collection.find_one(), + { + '_id': basket.pk, + 'products': [ + { + '_id': product1.pk, + 'name': product1.name + } + ] + } + ) + # push to list + basket.update(push__products=product2) + basket.reload() + self.assertEqual( + Basket.objects._collection.find_one(), + { + '_id': basket.pk, + 'products': [ + { + '_id': product1.pk, + 'name': product1.name + }, + { + '_id': product2.pk, + 'name': product2.name + } + ] + } + ) + + def test_cached_reference_field_update_all(self): + class Person(Document): + TYPES = ( + ('pf', "PF"), + ('pj', "PJ") + ) + name = StringField() + tp = StringField(choices=TYPES) + father = CachedReferenceField('self', fields=('tp',)) + + Person.drop_collection() + + a1 = Person(name="Wilson Father", tp="pj") + a1.save() + + a2 = Person(name='Wilson Junior', tp='pf', father=a1) + a2.save() + + a2 = Person.objects.with_id(a2.id) + self.assertEqual(a2.father.tp, a1.tp) + + self.assertEqual(dict(a2.to_mongo()), { + "_id": a2.pk, + "name": u"Wilson Junior", + "tp": u"pf", + "father": { + "_id": a1.pk, + "tp": u"pj" + } + }) + + self.assertEqual(Person.objects(father=a1)._query, { + 'father._id': a1.pk + }) + self.assertEqual(Person.objects(father=a1).count(), 1) + + Person.objects.update(set__tp="pf") + Person.father.sync_all() + + a2.reload() + self.assertEqual(dict(a2.to_mongo()), { + "_id": a2.pk, + "name": u"Wilson Junior", + "tp": u"pf", + "father": { + "_id": a1.pk, + "tp": u"pf" + } + }) + + def test_cached_reference_fields_on_embedded_documents(self): + with self.assertRaises(InvalidDocumentError): + class Test(Document): + name = StringField() + + type('WrongEmbeddedDocument', ( + EmbeddedDocument,), { + 'test': CachedReferenceField(Test) + }) + + def test_cached_reference_auto_sync(self): + class Person(Document): + TYPES = ( + ('pf', "PF"), + ('pj', "PJ") + ) + name = StringField() + tp = StringField( + choices=TYPES + ) + + father = CachedReferenceField('self', fields=('tp',)) + + Person.drop_collection() + + a1 = Person(name="Wilson Father", tp="pj") + a1.save() + + a2 = Person(name='Wilson Junior', tp='pf', father=a1) + a2.save() + + a1.tp = 'pf' + a1.save() + + a2.reload() + self.assertEqual(dict(a2.to_mongo()), { + '_id': a2.pk, + 'name': 'Wilson Junior', + 'tp': 'pf', + 'father': { + '_id': a1.pk, + 'tp': 'pf' + } + }) + + def test_cached_reference_auto_sync_disabled(self): + class Persone(Document): + TYPES = ( + ('pf', "PF"), + ('pj', "PJ") + ) + name = StringField() + tp = StringField( + choices=TYPES + ) + + father = CachedReferenceField( + 'self', fields=('tp',), auto_sync=False) + + Persone.drop_collection() + + a1 = Persone(name="Wilson Father", tp="pj") + a1.save() + + a2 = Persone(name='Wilson Junior', tp='pf', father=a1) + a2.save() + + a1.tp = 'pf' + a1.save() + + self.assertEqual(Persone.objects._collection.find_one({'_id': a2.pk}), { + '_id': a2.pk, + 'name': 'Wilson Junior', + 'tp': 'pf', + 'father': { + '_id': a1.pk, + 'tp': 'pj' + } + }) + + def test_cached_reference_embedded_fields(self): + class Owner(EmbeddedDocument): + TPS = ( + ('n', "Normal"), + ('u', "Urgent") + ) + name = StringField() + tp = StringField( + verbose_name="Type", + db_field="t", + choices=TPS) + + class Animal(Document): + name = StringField() + tag = StringField() + + owner = EmbeddedDocumentField(Owner) + + class Ocorrence(Document): + person = StringField() + animal = CachedReferenceField( + Animal, fields=['tag', 'owner.tp']) + + Animal.drop_collection() + Ocorrence.drop_collection() + + a = Animal(name="Leopard", tag="heavy", + owner=Owner(tp='u', name="Wilson Júnior") + ) + a.save() + + o = Ocorrence(person="teste", animal=a) + o.save() + self.assertEqual(dict(a.to_mongo(fields=['tag', 'owner.tp'])), { + '_id': a.pk, + 'tag': 'heavy', + 'owner': { + 't': 'u' + } + }) + self.assertEqual(o.to_mongo()['animal']['tag'], 'heavy') + self.assertEqual(o.to_mongo()['animal']['owner']['t'], 'u') + + # Check to_mongo with fields + self.assertNotIn('animal', o.to_mongo(fields=['person'])) + + # counts + Ocorrence(person="teste 2").save() + Ocorrence(person="teste 3").save() + + count = Ocorrence.objects( + animal__tag='heavy', animal__owner__tp='u').count() + self.assertEqual(count, 1) + + ocorrence = Ocorrence.objects( + animal__tag='heavy', + animal__owner__tp='u').first() + self.assertEqual(ocorrence.person, "teste") + self.assertIsInstance(ocorrence.animal, Animal) + + def test_cached_reference_embedded_list_fields(self): + class Owner(EmbeddedDocument): + name = StringField() + tags = ListField(StringField()) + + class Animal(Document): + name = StringField() + tag = StringField() + + owner = EmbeddedDocumentField(Owner) + + class Ocorrence(Document): + person = StringField() + animal = CachedReferenceField( + Animal, fields=['tag', 'owner.tags']) + + Animal.drop_collection() + Ocorrence.drop_collection() + + a = Animal(name="Leopard", tag="heavy", + owner=Owner(tags=['cool', 'funny'], + name="Wilson Júnior") + ) + a.save() + + o = Ocorrence(person="teste 2", animal=a) + o.save() + self.assertEqual(dict(a.to_mongo(fields=['tag', 'owner.tags'])), { + '_id': a.pk, + 'tag': 'heavy', + 'owner': { + 'tags': ['cool', 'funny'] + } + }) + + self.assertEqual(o.to_mongo()['animal']['tag'], 'heavy') + self.assertEqual(o.to_mongo()['animal']['owner']['tags'], + ['cool', 'funny']) + + # counts + Ocorrence(person="teste 2").save() + Ocorrence(person="teste 3").save() + + query = Ocorrence.objects( + animal__tag='heavy', animal__owner__tags='cool')._query + self.assertEqual( + query, {'animal.owner.tags': 'cool', 'animal.tag': 'heavy'}) + + ocorrence = Ocorrence.objects( + animal__tag='heavy', + animal__owner__tags='cool').first() + self.assertEqual(ocorrence.person, "teste 2") + self.assertIsInstance(ocorrence.animal, Animal) diff --git a/tests/fields/test_complex_datetime_field.py b/tests/fields/test_complex_datetime_field.py new file mode 100644 index 00000000..58dc4b43 --- /dev/null +++ b/tests/fields/test_complex_datetime_field.py @@ -0,0 +1,184 @@ +# -*- coding: utf-8 -*- +import datetime +import math +import itertools +import re + +from mongoengine import * + +from tests.utils import MongoDBTestCase + + +class ComplexDateTimeFieldTest(MongoDBTestCase): + def test_complexdatetime_storage(self): + """Tests for complex datetime fields - which can handle + microseconds without rounding. + """ + class LogEntry(Document): + date = ComplexDateTimeField() + date_with_dots = ComplexDateTimeField(separator='.') + + LogEntry.drop_collection() + + # Post UTC - microseconds are rounded (down) nearest millisecond and + # dropped - with default datetimefields + d1 = datetime.datetime(1970, 1, 1, 0, 0, 1, 999) + log = LogEntry() + log.date = d1 + log.save() + log.reload() + self.assertEqual(log.date, d1) + + # Post UTC - microseconds are rounded (down) nearest millisecond - with + # default datetimefields + d1 = datetime.datetime(1970, 1, 1, 0, 0, 1, 9999) + log.date = d1 + log.save() + log.reload() + self.assertEqual(log.date, d1) + + # Pre UTC dates microseconds below 1000 are dropped - with default + # datetimefields + d1 = datetime.datetime(1969, 12, 31, 23, 59, 59, 999) + log.date = d1 + log.save() + log.reload() + self.assertEqual(log.date, d1) + + # Pre UTC microseconds above 1000 is wonky - with default datetimefields + # log.date has an invalid microsecond value so I can't construct + # a date to compare. + for i in range(1001, 3113, 33): + d1 = datetime.datetime(1969, 12, 31, 23, 59, 59, i) + log.date = d1 + log.save() + log.reload() + self.assertEqual(log.date, d1) + log1 = LogEntry.objects.get(date=d1) + self.assertEqual(log, log1) + + # Test string padding + microsecond = map(int, [math.pow(10, x) for x in range(6)]) + mm = dd = hh = ii = ss = [1, 10] + + for values in itertools.product([2014], mm, dd, hh, ii, ss, microsecond): + stored = LogEntry(date=datetime.datetime(*values)).to_mongo()['date'] + self.assertTrue(re.match('^\d{4},\d{2},\d{2},\d{2},\d{2},\d{2},\d{6}$', stored) is not None) + + # Test separator + stored = LogEntry(date_with_dots=datetime.datetime(2014, 1, 1)).to_mongo()['date_with_dots'] + self.assertTrue(re.match('^\d{4}.\d{2}.\d{2}.\d{2}.\d{2}.\d{2}.\d{6}$', stored) is not None) + + def test_complexdatetime_usage(self): + """Tests for complex datetime fields - which can handle + microseconds without rounding. + """ + class LogEntry(Document): + date = ComplexDateTimeField() + + LogEntry.drop_collection() + + d1 = datetime.datetime(1950, 1, 1, 0, 0, 1, 999) + log = LogEntry() + log.date = d1 + log.save() + + log1 = LogEntry.objects.get(date=d1) + self.assertEqual(log, log1) + + # create extra 59 log entries for a total of 60 + for i in range(1951, 2010): + d = datetime.datetime(i, 1, 1, 0, 0, 1, 999) + LogEntry(date=d).save() + + self.assertEqual(LogEntry.objects.count(), 60) + + # Test ordering + logs = LogEntry.objects.order_by("date") + i = 0 + while i < 59: + self.assertTrue(logs[i].date <= logs[i + 1].date) + i += 1 + + logs = LogEntry.objects.order_by("-date") + i = 0 + while i < 59: + self.assertTrue(logs[i].date >= logs[i + 1].date) + i += 1 + + # Test searching + logs = LogEntry.objects.filter(date__gte=datetime.datetime(1980, 1, 1)) + self.assertEqual(logs.count(), 30) + + logs = LogEntry.objects.filter(date__lte=datetime.datetime(1980, 1, 1)) + self.assertEqual(logs.count(), 30) + + logs = LogEntry.objects.filter( + date__lte=datetime.datetime(2011, 1, 1), + date__gte=datetime.datetime(2000, 1, 1), + ) + self.assertEqual(logs.count(), 10) + + LogEntry.drop_collection() + + # Test microsecond-level ordering/filtering + for microsecond in (99, 999, 9999, 10000): + LogEntry( + date=datetime.datetime(2015, 1, 1, 0, 0, 0, microsecond) + ).save() + + logs = list(LogEntry.objects.order_by('date')) + for next_idx, log in enumerate(logs[:-1], start=1): + next_log = logs[next_idx] + self.assertTrue(log.date < next_log.date) + + logs = list(LogEntry.objects.order_by('-date')) + for next_idx, log in enumerate(logs[:-1], start=1): + next_log = logs[next_idx] + self.assertTrue(log.date > next_log.date) + + logs = LogEntry.objects.filter( + date__lte=datetime.datetime(2015, 1, 1, 0, 0, 0, 10000)) + self.assertEqual(logs.count(), 4) + + def test_no_default_value(self): + class Log(Document): + timestamp = ComplexDateTimeField() + + Log.drop_collection() + + log = Log() + self.assertIsNone(log.timestamp) + log.save() + + fetched_log = Log.objects.with_id(log.id) + self.assertIsNone(fetched_log.timestamp) + + def test_default_static_value(self): + NOW = datetime.datetime.utcnow() + class Log(Document): + timestamp = ComplexDateTimeField(default=NOW) + + Log.drop_collection() + + log = Log() + self.assertEqual(log.timestamp, NOW) + log.save() + + fetched_log = Log.objects.with_id(log.id) + self.assertEqual(fetched_log.timestamp, NOW) + + def test_default_callable(self): + NOW = datetime.datetime.utcnow() + + class Log(Document): + timestamp = ComplexDateTimeField(default=datetime.datetime.utcnow) + + Log.drop_collection() + + log = Log() + self.assertGreaterEqual(log.timestamp, NOW) + log.save() + + fetched_log = Log.objects.with_id(log.id) + self.assertGreaterEqual(fetched_log.timestamp, NOW) diff --git a/tests/fields/test_date_field.py b/tests/fields/test_date_field.py new file mode 100644 index 00000000..82adb514 --- /dev/null +++ b/tests/fields/test_date_field.py @@ -0,0 +1,165 @@ +# -*- coding: utf-8 -*- +import datetime +import six + +try: + import dateutil +except ImportError: + dateutil = None + +from mongoengine import * + +from tests.utils import MongoDBTestCase + + +class TestDateField(MongoDBTestCase): + def test_date_from_empty_string(self): + """ + Ensure an exception is raised when trying to + cast an empty string to datetime. + """ + class MyDoc(Document): + dt = DateField() + + md = MyDoc(dt='') + self.assertRaises(ValidationError, md.save) + + def test_date_from_whitespace_string(self): + """ + Ensure an exception is raised when trying to + cast a whitespace-only string to datetime. + """ + class MyDoc(Document): + dt = DateField() + + md = MyDoc(dt=' ') + self.assertRaises(ValidationError, md.save) + + def test_default_values_today(self): + """Ensure that default field values are used when creating + a document. + """ + class Person(Document): + day = DateField(default=datetime.date.today) + + person = Person() + person.validate() + self.assertEqual(person.day, person.day) + self.assertEqual(person.day, datetime.date.today()) + self.assertEqual(person._data['day'], person.day) + + def test_date(self): + """Tests showing pymongo date fields + + See: http://api.mongodb.org/python/current/api/bson/son.html#dt + """ + class LogEntry(Document): + date = DateField() + + LogEntry.drop_collection() + + # Test can save dates + log = LogEntry() + log.date = datetime.date.today() + log.save() + log.reload() + self.assertEqual(log.date, datetime.date.today()) + + d1 = datetime.datetime(1970, 1, 1, 0, 0, 1, 999) + d2 = datetime.datetime(1970, 1, 1, 0, 0, 1) + log = LogEntry() + log.date = d1 + log.save() + log.reload() + self.assertEqual(log.date, d1.date()) + self.assertEqual(log.date, d2.date()) + + d1 = datetime.datetime(1970, 1, 1, 0, 0, 1, 9999) + d2 = datetime.datetime(1970, 1, 1, 0, 0, 1, 9000) + log.date = d1 + log.save() + log.reload() + self.assertEqual(log.date, d1.date()) + self.assertEqual(log.date, d2.date()) + + if not six.PY3: + # Pre UTC dates microseconds below 1000 are dropped + # This does not seem to be true in PY3 + d1 = datetime.datetime(1969, 12, 31, 23, 59, 59, 999) + d2 = datetime.datetime(1969, 12, 31, 23, 59, 59) + log.date = d1 + log.save() + log.reload() + self.assertEqual(log.date, d1.date()) + self.assertEqual(log.date, d2.date()) + + def test_regular_usage(self): + """Tests for regular datetime fields""" + class LogEntry(Document): + date = DateField() + + LogEntry.drop_collection() + + d1 = datetime.datetime(1970, 1, 1, 0, 0, 1) + log = LogEntry() + log.date = d1 + log.validate() + log.save() + + for query in (d1, d1.isoformat(' ')): + log1 = LogEntry.objects.get(date=query) + self.assertEqual(log, log1) + + if dateutil: + log1 = LogEntry.objects.get(date=d1.isoformat('T')) + self.assertEqual(log, log1) + + # create additional 19 log entries for a total of 20 + for i in range(1971, 1990): + d = datetime.datetime(i, 1, 1, 0, 0, 1) + LogEntry(date=d).save() + + self.assertEqual(LogEntry.objects.count(), 20) + + # Test ordering + logs = LogEntry.objects.order_by("date") + i = 0 + while i < 19: + self.assertTrue(logs[i].date <= logs[i + 1].date) + i += 1 + + logs = LogEntry.objects.order_by("-date") + i = 0 + while i < 19: + self.assertTrue(logs[i].date >= logs[i + 1].date) + i += 1 + + # Test searching + logs = LogEntry.objects.filter(date__gte=datetime.datetime(1980, 1, 1)) + self.assertEqual(logs.count(), 10) + + def test_validation(self): + """Ensure that invalid values cannot be assigned to datetime + fields. + """ + class LogEntry(Document): + time = DateField() + + log = LogEntry() + log.time = datetime.datetime.now() + log.validate() + + log.time = datetime.date.today() + log.validate() + + log.time = datetime.datetime.now().isoformat(' ') + log.validate() + + if dateutil: + log.time = datetime.datetime.now().isoformat('T') + log.validate() + + log.time = -1 + self.assertRaises(ValidationError, log.validate) + log.time = 'ABC' + self.assertRaises(ValidationError, log.validate) diff --git a/tests/fields/test_datetime_field.py b/tests/fields/test_datetime_field.py new file mode 100644 index 00000000..92f0668a --- /dev/null +++ b/tests/fields/test_datetime_field.py @@ -0,0 +1,231 @@ +# -*- coding: utf-8 -*- +import datetime as dt +import six + +try: + import dateutil +except ImportError: + dateutil = None + +from mongoengine import * +from mongoengine import connection + +from tests.utils import MongoDBTestCase + + +class TestDateTimeField(MongoDBTestCase): + def test_datetime_from_empty_string(self): + """ + Ensure an exception is raised when trying to + cast an empty string to datetime. + """ + class MyDoc(Document): + dt = DateTimeField() + + md = MyDoc(dt='') + self.assertRaises(ValidationError, md.save) + + def test_datetime_from_whitespace_string(self): + """ + Ensure an exception is raised when trying to + cast a whitespace-only string to datetime. + """ + class MyDoc(Document): + dt = DateTimeField() + + md = MyDoc(dt=' ') + self.assertRaises(ValidationError, md.save) + + def test_default_value_utcnow(self): + """Ensure that default field values are used when creating + a document. + """ + class Person(Document): + created = DateTimeField(default=dt.datetime.utcnow) + + utcnow = dt.datetime.utcnow() + person = Person() + person.validate() + person_created_t0 = person.created + self.assertLess(person.created - utcnow, dt.timedelta(seconds=1)) + self.assertEqual(person_created_t0, person.created) # make sure it does not change + self.assertEqual(person._data['created'], person.created) + + def test_handling_microseconds(self): + """Tests showing pymongo datetime fields handling of microseconds. + Microseconds are rounded to the nearest millisecond and pre UTC + handling is wonky. + + See: http://api.mongodb.org/python/current/api/bson/son.html#dt + """ + class LogEntry(Document): + date = DateTimeField() + + LogEntry.drop_collection() + + # Test can save dates + log = LogEntry() + log.date = dt.date.today() + log.save() + log.reload() + self.assertEqual(log.date.date(), dt.date.today()) + + # Post UTC - microseconds are rounded (down) nearest millisecond and + # dropped + d1 = dt.datetime(1970, 1, 1, 0, 0, 1, 999) + d2 = dt.datetime(1970, 1, 1, 0, 0, 1) + log = LogEntry() + log.date = d1 + log.save() + log.reload() + self.assertNotEqual(log.date, d1) + self.assertEqual(log.date, d2) + + # Post UTC - microseconds are rounded (down) nearest millisecond + d1 = dt.datetime(1970, 1, 1, 0, 0, 1, 9999) + d2 = dt.datetime(1970, 1, 1, 0, 0, 1, 9000) + log.date = d1 + log.save() + log.reload() + self.assertNotEqual(log.date, d1) + self.assertEqual(log.date, d2) + + if not six.PY3: + # Pre UTC dates microseconds below 1000 are dropped + # This does not seem to be true in PY3 + d1 = dt.datetime(1969, 12, 31, 23, 59, 59, 999) + d2 = dt.datetime(1969, 12, 31, 23, 59, 59) + log.date = d1 + log.save() + log.reload() + self.assertNotEqual(log.date, d1) + self.assertEqual(log.date, d2) + + def test_regular_usage(self): + """Tests for regular datetime fields""" + class LogEntry(Document): + date = DateTimeField() + + LogEntry.drop_collection() + + d1 = dt.datetime(1970, 1, 1, 0, 0, 1) + log = LogEntry() + log.date = d1 + log.validate() + log.save() + + for query in (d1, d1.isoformat(' ')): + log1 = LogEntry.objects.get(date=query) + self.assertEqual(log, log1) + + if dateutil: + log1 = LogEntry.objects.get(date=d1.isoformat('T')) + self.assertEqual(log, log1) + + # create additional 19 log entries for a total of 20 + for i in range(1971, 1990): + d = dt.datetime(i, 1, 1, 0, 0, 1) + LogEntry(date=d).save() + + self.assertEqual(LogEntry.objects.count(), 20) + + # Test ordering + logs = LogEntry.objects.order_by("date") + i = 0 + while i < 19: + self.assertTrue(logs[i].date <= logs[i + 1].date) + i += 1 + + logs = LogEntry.objects.order_by("-date") + i = 0 + while i < 19: + self.assertTrue(logs[i].date >= logs[i + 1].date) + i += 1 + + # Test searching + logs = LogEntry.objects.filter(date__gte=dt.datetime(1980, 1, 1)) + self.assertEqual(logs.count(), 10) + + logs = LogEntry.objects.filter(date__lte=dt.datetime(1980, 1, 1)) + self.assertEqual(logs.count(), 10) + + logs = LogEntry.objects.filter( + date__lte=dt.datetime(1980, 1, 1), + date__gte=dt.datetime(1975, 1, 1), + ) + self.assertEqual(logs.count(), 5) + + def test_datetime_validation(self): + """Ensure that invalid values cannot be assigned to datetime + fields. + """ + class LogEntry(Document): + time = DateTimeField() + + log = LogEntry() + log.time = dt.datetime.now() + log.validate() + + log.time = dt.date.today() + log.validate() + + log.time = dt.datetime.now().isoformat(' ') + log.validate() + + log.time = '2019-05-16 21:42:57.897847' + log.validate() + + if dateutil: + log.time = dt.datetime.now().isoformat('T') + log.validate() + + log.time = -1 + self.assertRaises(ValidationError, log.validate) + log.time = 'ABC' + self.assertRaises(ValidationError, log.validate) + log.time = '2019-05-16 21:GARBAGE:12' + self.assertRaises(ValidationError, log.validate) + log.time = '2019-05-16 21:42:57.GARBAGE' + self.assertRaises(ValidationError, log.validate) + log.time = '2019-05-16 21:42:57.123.456' + self.assertRaises(ValidationError, log.validate) + + def test_parse_datetime_as_str(self): + class DTDoc(Document): + date = DateTimeField() + + date_str = '2019-03-02 22:26:01' + + # make sure that passing a parsable datetime works + dtd = DTDoc() + dtd.date = date_str + self.assertIsInstance(dtd.date, six.string_types) + dtd.save() + dtd.reload() + + self.assertIsInstance(dtd.date, dt.datetime) + self.assertEqual(str(dtd.date), date_str) + + dtd.date = 'January 1st, 9999999999' + self.assertRaises(ValidationError, dtd.validate) + + +class TestDateTimeTzAware(MongoDBTestCase): + def test_datetime_tz_aware_mark_as_changed(self): + # Reset the connections + connection._connection_settings = {} + connection._connections = {} + connection._dbs = {} + + connect(db='mongoenginetest', tz_aware=True) + + class LogEntry(Document): + time = DateTimeField() + + LogEntry.drop_collection() + + LogEntry(time=dt.datetime(2013, 1, 1, 0, 0, 0)).save() + + log = LogEntry.objects.first() + log.time = dt.datetime(2013, 1, 1, 0, 0, 0) + self.assertEqual(['time'], log._changed_fields) diff --git a/tests/fields/test_decimal_field.py b/tests/fields/test_decimal_field.py new file mode 100644 index 00000000..0213b880 --- /dev/null +++ b/tests/fields/test_decimal_field.py @@ -0,0 +1,91 @@ +# -*- coding: utf-8 -*- +from decimal import Decimal + +from mongoengine import * + +from tests.utils import MongoDBTestCase + + +class TestDecimalField(MongoDBTestCase): + + def test_validation(self): + """Ensure that invalid values cannot be assigned to decimal fields. + """ + class Person(Document): + height = DecimalField(min_value=Decimal('0.1'), + max_value=Decimal('3.5')) + + Person.drop_collection() + + Person(height=Decimal('1.89')).save() + person = Person.objects.first() + self.assertEqual(person.height, Decimal('1.89')) + + person.height = '2.0' + person.save() + person.height = 0.01 + self.assertRaises(ValidationError, person.validate) + person.height = Decimal('0.01') + self.assertRaises(ValidationError, person.validate) + person.height = Decimal('4.0') + self.assertRaises(ValidationError, person.validate) + person.height = 'something invalid' + self.assertRaises(ValidationError, person.validate) + + person_2 = Person(height='something invalid') + self.assertRaises(ValidationError, person_2.validate) + + def test_comparison(self): + class Person(Document): + money = DecimalField() + + Person.drop_collection() + + Person(money=6).save() + Person(money=7).save() + Person(money=8).save() + Person(money=10).save() + + self.assertEqual(2, Person.objects(money__gt=Decimal("7")).count()) + self.assertEqual(2, Person.objects(money__gt=7).count()) + self.assertEqual(2, Person.objects(money__gt="7").count()) + + self.assertEqual(3, Person.objects(money__gte="7").count()) + + def test_storage(self): + class Person(Document): + float_value = DecimalField(precision=4) + string_value = DecimalField(precision=4, force_string=True) + + Person.drop_collection() + values_to_store = [10, 10.1, 10.11, "10.111", Decimal("10.1111"), Decimal("10.11111")] + for store_at_creation in [True, False]: + for value in values_to_store: + # to_python is called explicitly if values were sent in the kwargs of __init__ + if store_at_creation: + Person(float_value=value, string_value=value).save() + else: + person = Person.objects.create() + person.float_value = value + person.string_value = value + person.save() + + # How its stored + expected = [ + {'float_value': 10.0, 'string_value': '10.0000'}, + {'float_value': 10.1, 'string_value': '10.1000'}, + {'float_value': 10.11, 'string_value': '10.1100'}, + {'float_value': 10.111, 'string_value': '10.1110'}, + {'float_value': 10.1111, 'string_value': '10.1111'}, + {'float_value': 10.1111, 'string_value': '10.1111'}] + expected.extend(expected) + actual = list(Person.objects.exclude('id').as_pymongo()) + self.assertEqual(expected, actual) + + # How it comes out locally + expected = [Decimal('10.0000'), Decimal('10.1000'), Decimal('10.1100'), + Decimal('10.1110'), Decimal('10.1111'), Decimal('10.1111')] + expected.extend(expected) + for field_name in ['float_value', 'string_value']: + actual = list(Person.objects().scalar(field_name)) + self.assertEqual(expected, actual) diff --git a/tests/fields/test_dict_field.py b/tests/fields/test_dict_field.py new file mode 100644 index 00000000..ade02ccf --- /dev/null +++ b/tests/fields/test_dict_field.py @@ -0,0 +1,324 @@ +# -*- coding: utf-8 -*- +from mongoengine import * +from mongoengine.base import BaseDict + +from tests.utils import MongoDBTestCase, get_as_pymongo + + +class TestDictField(MongoDBTestCase): + + def test_storage(self): + class BlogPost(Document): + info = DictField() + + BlogPost.drop_collection() + + info = {'testkey': 'testvalue'} + post = BlogPost(info=info).save() + self.assertEqual( + get_as_pymongo(post), + { + '_id': post.id, + 'info': info + } + ) + + def test_general_things(self): + """Ensure that dict types work as expected.""" + class BlogPost(Document): + info = DictField() + + BlogPost.drop_collection() + + post = BlogPost() + post.info = 'my post' + self.assertRaises(ValidationError, post.validate) + + post.info = ['test', 'test'] + self.assertRaises(ValidationError, post.validate) + + post.info = {'$title': 'test'} + self.assertRaises(ValidationError, post.validate) + + post.info = {'nested': {'$title': 'test'}} + self.assertRaises(ValidationError, post.validate) + + post.info = {'the.title': 'test'} + self.assertRaises(ValidationError, post.validate) + + post.info = {'nested': {'the.title': 'test'}} + self.assertRaises(ValidationError, post.validate) + + post.info = {1: 'test'} + self.assertRaises(ValidationError, post.validate) + + post.info = {'title': 'test'} + post.save() + + post = BlogPost() + post.info = {'title': 'dollar_sign', 'details': {'te$t': 'test'}} + post.save() + + post = BlogPost() + post.info = {'details': {'test': 'test'}} + post.save() + + post = BlogPost() + post.info = {'details': {'test': 3}} + post.save() + + self.assertEqual(BlogPost.objects.count(), 4) + self.assertEqual( + BlogPost.objects.filter(info__title__exact='test').count(), 1) + self.assertEqual( + BlogPost.objects.filter(info__details__test__exact='test').count(), 1) + + post = BlogPost.objects.filter(info__title__exact='dollar_sign').first() + self.assertIn('te$t', post['info']['details']) + + # Confirm handles non strings or non existing keys + self.assertEqual( + BlogPost.objects.filter(info__details__test__exact=5).count(), 0) + self.assertEqual( + BlogPost.objects.filter(info__made_up__test__exact='test').count(), 0) + + post = BlogPost.objects.create(info={'title': 'original'}) + post.info.update({'title': 'updated'}) + post.save() + post.reload() + self.assertEqual('updated', post.info['title']) + + post.info.setdefault('authors', []) + post.save() + post.reload() + self.assertEqual([], post.info['authors']) + + def test_dictfield_dump_document(self): + """Ensure a DictField can handle another document's dump.""" + class Doc(Document): + field = DictField() + + class ToEmbed(Document): + id = IntField(primary_key=True, default=1) + recursive = DictField() + + class ToEmbedParent(Document): + id = IntField(primary_key=True, default=1) + recursive = DictField() + + meta = {'allow_inheritance': True} + + class ToEmbedChild(ToEmbedParent): + pass + + to_embed_recursive = ToEmbed(id=1).save() + to_embed = ToEmbed( + id=2, recursive=to_embed_recursive.to_mongo().to_dict()).save() + doc = Doc(field=to_embed.to_mongo().to_dict()) + doc.save() + self.assertIsInstance(doc.field, dict) + self.assertEqual(doc.field, {'_id': 2, 'recursive': {'_id': 1, 'recursive': {}}}) + # Same thing with a Document with a _cls field + to_embed_recursive = ToEmbedChild(id=1).save() + to_embed_child = ToEmbedChild( + id=2, recursive=to_embed_recursive.to_mongo().to_dict()).save() + doc = Doc(field=to_embed_child.to_mongo().to_dict()) + doc.save() + self.assertIsInstance(doc.field, dict) + expected = { + '_id': 2, '_cls': 'ToEmbedParent.ToEmbedChild', + 'recursive': {'_id': 1, '_cls': 'ToEmbedParent.ToEmbedChild', 'recursive': {}} + } + self.assertEqual(doc.field, expected) + + def test_dictfield_strict(self): + """Ensure that dict field handles validation if provided a strict field type.""" + class Simple(Document): + mapping = DictField(field=IntField()) + + Simple.drop_collection() + + e = Simple() + e.mapping['someint'] = 1 + e.save() + + # try creating an invalid mapping + with self.assertRaises(ValidationError): + e.mapping['somestring'] = "abc" + e.save() + + def test_dictfield_complex(self): + """Ensure that the dict field can handle the complex types.""" + class SettingBase(EmbeddedDocument): + meta = {'allow_inheritance': True} + + class StringSetting(SettingBase): + value = StringField() + + class IntegerSetting(SettingBase): + value = IntField() + + class Simple(Document): + mapping = DictField() + + Simple.drop_collection() + + e = Simple() + e.mapping['somestring'] = StringSetting(value='foo') + e.mapping['someint'] = IntegerSetting(value=42) + e.mapping['nested_dict'] = {'number': 1, 'string': 'Hi!', + 'float': 1.001, + 'complex': IntegerSetting(value=42), + 'list': [IntegerSetting(value=42), + StringSetting(value='foo')]} + e.save() + + e2 = Simple.objects.get(id=e.id) + self.assertIsInstance(e2.mapping['somestring'], StringSetting) + self.assertIsInstance(e2.mapping['someint'], IntegerSetting) + + # Test querying + self.assertEqual( + Simple.objects.filter(mapping__someint__value=42).count(), 1) + self.assertEqual( + Simple.objects.filter(mapping__nested_dict__number=1).count(), 1) + self.assertEqual( + Simple.objects.filter(mapping__nested_dict__complex__value=42).count(), 1) + self.assertEqual( + Simple.objects.filter(mapping__nested_dict__list__0__value=42).count(), 1) + self.assertEqual( + Simple.objects.filter(mapping__nested_dict__list__1__value='foo').count(), 1) + + # Confirm can update + Simple.objects().update( + set__mapping={"someint": IntegerSetting(value=10)}) + Simple.objects().update( + set__mapping__nested_dict__list__1=StringSetting(value='Boo')) + self.assertEqual( + Simple.objects.filter(mapping__nested_dict__list__1__value='foo').count(), 0) + self.assertEqual( + Simple.objects.filter(mapping__nested_dict__list__1__value='Boo').count(), 1) + + def test_push_dict(self): + class MyModel(Document): + events = ListField(DictField()) + + doc = MyModel(events=[{'a': 1}]).save() + raw_doc = get_as_pymongo(doc) + expected_raw_doc = { + '_id': doc.id, + 'events': [{'a': 1}] + } + self.assertEqual(raw_doc, expected_raw_doc) + + MyModel.objects(id=doc.id).update(push__events={}) + raw_doc = get_as_pymongo(doc) + expected_raw_doc = { + '_id': doc.id, + 'events': [{'a': 1}, {}] + } + self.assertEqual(raw_doc, expected_raw_doc) + + def test_ensure_unique_default_instances(self): + """Ensure that every field has it's own unique default instance.""" + class D(Document): + data = DictField() + data2 = DictField(default=lambda: {}) + + d1 = D() + d1.data['foo'] = 'bar' + d1.data2['foo'] = 'bar' + d2 = D() + self.assertEqual(d2.data, {}) + self.assertEqual(d2.data2, {}) + + def test_dict_field_invalid_dict_value(self): + class DictFieldTest(Document): + dictionary = DictField(required=True) + + DictFieldTest.drop_collection() + + test = DictFieldTest(dictionary=None) + test.dictionary # Just access to test getter + self.assertRaises(ValidationError, test.validate) + + test = DictFieldTest(dictionary=False) + test.dictionary # Just access to test getter + self.assertRaises(ValidationError, test.validate) + + def test_dict_field_raises_validation_error_if_wrongly_assign_embedded_doc(self): + class DictFieldTest(Document): + dictionary = DictField(required=True) + + DictFieldTest.drop_collection() + + class Embedded(EmbeddedDocument): + name = StringField() + + embed = Embedded(name='garbage') + doc = DictFieldTest(dictionary=embed) + with self.assertRaises(ValidationError) as ctx_err: + doc.validate() + self.assertIn("'dictionary'", str(ctx_err.exception)) + self.assertIn('Only dictionaries may be used in a DictField', str(ctx_err.exception)) + + def test_atomic_update_dict_field(self): + """Ensure that the entire DictField can be atomically updated.""" + class Simple(Document): + mapping = DictField(field=ListField(IntField(required=True))) + + Simple.drop_collection() + + e = Simple() + e.mapping['someints'] = [1, 2] + e.save() + e.update(set__mapping={"ints": [3, 4]}) + e.reload() + self.assertEqual(BaseDict, type(e.mapping)) + self.assertEqual({"ints": [3, 4]}, e.mapping) + + # try creating an invalid mapping + with self.assertRaises(ValueError): + e.update(set__mapping={"somestrings": ["foo", "bar", ]}) + + def test_dictfield_with_referencefield_complex_nesting_cases(self): + """Ensure complex nesting inside DictField handles dereferencing of ReferenceField(dbref=True | False)""" + # Relates to Issue #1453 + class Doc(Document): + s = StringField() + + class Simple(Document): + mapping0 = DictField(ReferenceField(Doc, dbref=True)) + mapping1 = DictField(ReferenceField(Doc, dbref=False)) + mapping2 = DictField(ListField(ReferenceField(Doc, dbref=True))) + mapping3 = DictField(ListField(ReferenceField(Doc, dbref=False))) + mapping4 = DictField(DictField(field=ReferenceField(Doc, dbref=True))) + mapping5 = DictField(DictField(field=ReferenceField(Doc, dbref=False))) + mapping6 = DictField(ListField(DictField(ReferenceField(Doc, dbref=True)))) + mapping7 = DictField(ListField(DictField(ReferenceField(Doc, dbref=False)))) + mapping8 = DictField(ListField(DictField(ListField(ReferenceField(Doc, dbref=True))))) + mapping9 = DictField(ListField(DictField(ListField(ReferenceField(Doc, dbref=False))))) + + Doc.drop_collection() + Simple.drop_collection() + + d = Doc(s='aa').save() + e = Simple() + e.mapping0['someint'] = e.mapping1['someint'] = d + e.mapping2['someint'] = e.mapping3['someint'] = [d] + e.mapping4['someint'] = e.mapping5['someint'] = {'d': d} + e.mapping6['someint'] = e.mapping7['someint'] = [{'d': d}] + e.mapping8['someint'] = e.mapping9['someint'] = [{'d': [d]}] + e.save() + + s = Simple.objects.first() + self.assertIsInstance(s.mapping0['someint'], Doc) + self.assertIsInstance(s.mapping1['someint'], Doc) + self.assertIsInstance(s.mapping2['someint'][0], Doc) + self.assertIsInstance(s.mapping3['someint'][0], Doc) + self.assertIsInstance(s.mapping4['someint']['d'], Doc) + self.assertIsInstance(s.mapping5['someint']['d'], Doc) + self.assertIsInstance(s.mapping6['someint'][0]['d'], Doc) + self.assertIsInstance(s.mapping7['someint'][0]['d'], Doc) + self.assertIsInstance(s.mapping8['someint'][0]['d'][0], Doc) + self.assertIsInstance(s.mapping9['someint'][0]['d'][0], Doc) diff --git a/tests/fields/test_email_field.py b/tests/fields/test_email_field.py new file mode 100644 index 00000000..3ce49d62 --- /dev/null +++ b/tests/fields/test_email_field.py @@ -0,0 +1,130 @@ +# -*- coding: utf-8 -*- +import sys +from unittest import SkipTest + +from mongoengine import * + +from tests.utils import MongoDBTestCase + + +class TestEmailField(MongoDBTestCase): + def test_generic_behavior(self): + class User(Document): + email = EmailField() + + user = User(email='ross@example.com') + user.validate() + + user = User(email='ross@example.co.uk') + user.validate() + + user = User(email=('Kofq@rhom0e4klgauOhpbpNdogawnyIKvQS0wk2mjqrgGQ5S' + 'aJIazqqWkm7.net')) + user.validate() + + user = User(email='new-tld@example.technology') + user.validate() + + user = User(email='ross@example.com.') + self.assertRaises(ValidationError, user.validate) + + # unicode domain + user = User(email=u'user@пример.рф') + user.validate() + + # invalid unicode domain + user = User(email=u'user@пример') + self.assertRaises(ValidationError, user.validate) + + # invalid data type + user = User(email=123) + self.assertRaises(ValidationError, user.validate) + + def test_email_field_unicode_user(self): + # Don't run this test on pypy3, which doesn't support unicode regex: + # https://bitbucket.org/pypy/pypy/issues/1821/regular-expression-doesnt-find-unicode + if sys.version_info[:2] == (3, 2): + raise SkipTest('unicode email addresses are not supported on PyPy 3') + + class User(Document): + email = EmailField() + + # unicode user shouldn't validate by default... + user = User(email=u'Dörte@Sörensen.example.com') + self.assertRaises(ValidationError, user.validate) + + # ...but it should be fine with allow_utf8_user set to True + class User(Document): + email = EmailField(allow_utf8_user=True) + + user = User(email=u'Dörte@Sörensen.example.com') + user.validate() + + def test_email_field_domain_whitelist(self): + class User(Document): + email = EmailField() + + # localhost domain shouldn't validate by default... + user = User(email='me@localhost') + self.assertRaises(ValidationError, user.validate) + + # ...but it should be fine if it's whitelisted + class User(Document): + email = EmailField(domain_whitelist=['localhost']) + + user = User(email='me@localhost') + user.validate() + + def test_email_domain_validation_fails_if_invalid_idn(self): + class User(Document): + email = EmailField() + + invalid_idn = '.google.com' + user = User(email='me@%s' % invalid_idn) + with self.assertRaises(ValidationError) as ctx_err: + user.validate() + self.assertIn("domain failed IDN encoding", str(ctx_err.exception)) + + def test_email_field_ip_domain(self): + class User(Document): + email = EmailField() + + valid_ipv4 = 'email@[127.0.0.1]' + valid_ipv6 = 'email@[2001:dB8::1]' + invalid_ip = 'email@[324.0.0.1]' + + # IP address as a domain shouldn't validate by default... + user = User(email=valid_ipv4) + self.assertRaises(ValidationError, user.validate) + + user = User(email=valid_ipv6) + self.assertRaises(ValidationError, user.validate) + + user = User(email=invalid_ip) + self.assertRaises(ValidationError, user.validate) + + # ...but it should be fine with allow_ip_domain set to True + class User(Document): + email = EmailField(allow_ip_domain=True) + + user = User(email=valid_ipv4) + user.validate() + + user = User(email=valid_ipv6) + user.validate() + + # invalid IP should still fail validation + user = User(email=invalid_ip) + self.assertRaises(ValidationError, user.validate) + + def test_email_field_honors_regex(self): + class User(Document): + email = EmailField(regex=r'\w+@example.com') + + # Fails regex validation + user = User(email='me@foo.com') + self.assertRaises(ValidationError, user.validate) + + # Passes regex validation + user = User(email='me@example.com') + self.assertIsNone(user.validate()) diff --git a/tests/fields/test_embedded_document_field.py b/tests/fields/test_embedded_document_field.py new file mode 100644 index 00000000..a262d054 --- /dev/null +++ b/tests/fields/test_embedded_document_field.py @@ -0,0 +1,344 @@ +# -*- coding: utf-8 -*- +from mongoengine import Document, StringField, ValidationError, EmbeddedDocument, EmbeddedDocumentField, \ + InvalidQueryError, LookUpError, IntField, GenericEmbeddedDocumentField, ListField, EmbeddedDocumentListField, \ + ReferenceField + +from tests.utils import MongoDBTestCase + + +class TestEmbeddedDocumentField(MongoDBTestCase): + def test___init___(self): + class MyDoc(EmbeddedDocument): + name = StringField() + + field = EmbeddedDocumentField(MyDoc) + self.assertEqual(field.document_type_obj, MyDoc) + + field2 = EmbeddedDocumentField('MyDoc') + self.assertEqual(field2.document_type_obj, 'MyDoc') + + def test___init___throw_error_if_document_type_is_not_EmbeddedDocument(self): + with self.assertRaises(ValidationError): + EmbeddedDocumentField(dict) + + def test_document_type_throw_error_if_not_EmbeddedDocument_subclass(self): + + class MyDoc(Document): + name = StringField() + + emb = EmbeddedDocumentField('MyDoc') + with self.assertRaises(ValidationError) as ctx: + emb.document_type + self.assertIn('Invalid embedded document class provided to an EmbeddedDocumentField', str(ctx.exception)) + + def test_embedded_document_field_only_allow_subclasses_of_embedded_document(self): + # Relates to #1661 + class MyDoc(Document): + name = StringField() + + with self.assertRaises(ValidationError): + class MyFailingDoc(Document): + emb = EmbeddedDocumentField(MyDoc) + + with self.assertRaises(ValidationError): + class MyFailingdoc2(Document): + emb = EmbeddedDocumentField('MyDoc') + + def test_query_embedded_document_attribute(self): + class AdminSettings(EmbeddedDocument): + foo1 = StringField() + foo2 = StringField() + + class Person(Document): + settings = EmbeddedDocumentField(AdminSettings) + name = StringField() + + Person.drop_collection() + + p = Person( + settings=AdminSettings(foo1='bar1', foo2='bar2'), + name='John', + ).save() + + # Test non exiting attribute + with self.assertRaises(InvalidQueryError) as ctx_err: + Person.objects(settings__notexist='bar').first() + self.assertEqual(unicode(ctx_err.exception), u'Cannot resolve field "notexist"') + + with self.assertRaises(LookUpError): + Person.objects.only('settings.notexist') + + # Test existing attribute + self.assertEqual(Person.objects(settings__foo1='bar1').first().id, p.id) + only_p = Person.objects.only('settings.foo1').first() + self.assertEqual(only_p.settings.foo1, p.settings.foo1) + self.assertIsNone(only_p.settings.foo2) + self.assertIsNone(only_p.name) + + exclude_p = Person.objects.exclude('settings.foo1').first() + self.assertIsNone(exclude_p.settings.foo1) + self.assertEqual(exclude_p.settings.foo2, p.settings.foo2) + self.assertEqual(exclude_p.name, p.name) + + def test_query_embedded_document_attribute_with_inheritance(self): + class BaseSettings(EmbeddedDocument): + meta = {'allow_inheritance': True} + base_foo = StringField() + + class AdminSettings(BaseSettings): + sub_foo = StringField() + + class Person(Document): + settings = EmbeddedDocumentField(BaseSettings) + + Person.drop_collection() + + p = Person(settings=AdminSettings(base_foo='basefoo', sub_foo='subfoo')) + p.save() + + # Test non exiting attribute + with self.assertRaises(InvalidQueryError) as ctx_err: + self.assertEqual(Person.objects(settings__notexist='bar').first().id, p.id) + self.assertEqual(unicode(ctx_err.exception), u'Cannot resolve field "notexist"') + + # Test existing attribute + self.assertEqual(Person.objects(settings__base_foo='basefoo').first().id, p.id) + self.assertEqual(Person.objects(settings__sub_foo='subfoo').first().id, p.id) + + only_p = Person.objects.only('settings.base_foo', 'settings._cls').first() + self.assertEqual(only_p.settings.base_foo, 'basefoo') + self.assertIsNone(only_p.settings.sub_foo) + + def test_query_list_embedded_document_with_inheritance(self): + class Post(EmbeddedDocument): + title = StringField(max_length=120, required=True) + meta = {'allow_inheritance': True} + + class TextPost(Post): + content = StringField() + + class MoviePost(Post): + author = StringField() + + class Record(Document): + posts = ListField(EmbeddedDocumentField(Post)) + + record_movie = Record(posts=[MoviePost(author='John', title='foo')]).save() + record_text = Record(posts=[TextPost(content='a', title='foo')]).save() + + records = list(Record.objects(posts__author=record_movie.posts[0].author)) + self.assertEqual(len(records), 1) + self.assertEqual(records[0].id, record_movie.id) + + records = list(Record.objects(posts__content=record_text.posts[0].content)) + self.assertEqual(len(records), 1) + self.assertEqual(records[0].id, record_text.id) + + self.assertEqual(Record.objects(posts__title='foo').count(), 2) + + +class TestGenericEmbeddedDocumentField(MongoDBTestCase): + + def test_generic_embedded_document(self): + class Car(EmbeddedDocument): + name = StringField() + + class Dish(EmbeddedDocument): + food = StringField(required=True) + number = IntField() + + class Person(Document): + name = StringField() + like = GenericEmbeddedDocumentField() + + Person.drop_collection() + + person = Person(name='Test User') + person.like = Car(name='Fiat') + person.save() + + person = Person.objects.first() + self.assertIsInstance(person.like, Car) + + person.like = Dish(food="arroz", number=15) + person.save() + + person = Person.objects.first() + self.assertIsInstance(person.like, Dish) + + def test_generic_embedded_document_choices(self): + """Ensure you can limit GenericEmbeddedDocument choices.""" + class Car(EmbeddedDocument): + name = StringField() + + class Dish(EmbeddedDocument): + food = StringField(required=True) + number = IntField() + + class Person(Document): + name = StringField() + like = GenericEmbeddedDocumentField(choices=(Dish,)) + + Person.drop_collection() + + person = Person(name='Test User') + person.like = Car(name='Fiat') + self.assertRaises(ValidationError, person.validate) + + person.like = Dish(food="arroz", number=15) + person.save() + + person = Person.objects.first() + self.assertIsInstance(person.like, Dish) + + def test_generic_list_embedded_document_choices(self): + """Ensure you can limit GenericEmbeddedDocument choices inside + a list field. + """ + class Car(EmbeddedDocument): + name = StringField() + + class Dish(EmbeddedDocument): + food = StringField(required=True) + number = IntField() + + class Person(Document): + name = StringField() + likes = ListField(GenericEmbeddedDocumentField(choices=(Dish,))) + + Person.drop_collection() + + person = Person(name='Test User') + person.likes = [Car(name='Fiat')] + self.assertRaises(ValidationError, person.validate) + + person.likes = [Dish(food="arroz", number=15)] + person.save() + + person = Person.objects.first() + self.assertIsInstance(person.likes[0], Dish) + + def test_choices_validation_documents(self): + """ + Ensure fields with document choices validate given a valid choice. + """ + class UserComments(EmbeddedDocument): + author = StringField() + message = StringField() + + class BlogPost(Document): + comments = ListField( + GenericEmbeddedDocumentField(choices=(UserComments,)) + ) + + # Ensure Validation Passes + BlogPost(comments=[ + UserComments(author='user2', message='message2'), + ]).save() + + def test_choices_validation_documents_invalid(self): + """ + Ensure fields with document choices validate given an invalid choice. + This should throw a ValidationError exception. + """ + class UserComments(EmbeddedDocument): + author = StringField() + message = StringField() + + class ModeratorComments(EmbeddedDocument): + author = StringField() + message = StringField() + + class BlogPost(Document): + comments = ListField( + GenericEmbeddedDocumentField(choices=(UserComments,)) + ) + + # Single Entry Failure + post = BlogPost(comments=[ + ModeratorComments(author='mod1', message='message1'), + ]) + self.assertRaises(ValidationError, post.save) + + # Mixed Entry Failure + post = BlogPost(comments=[ + ModeratorComments(author='mod1', message='message1'), + UserComments(author='user2', message='message2'), + ]) + self.assertRaises(ValidationError, post.save) + + def test_choices_validation_documents_inheritance(self): + """ + Ensure fields with document choices validate given subclass of choice. + """ + class Comments(EmbeddedDocument): + meta = { + 'abstract': True + } + author = StringField() + message = StringField() + + class UserComments(Comments): + pass + + class BlogPost(Document): + comments = ListField( + GenericEmbeddedDocumentField(choices=(Comments,)) + ) + + # Save Valid EmbeddedDocument Type + BlogPost(comments=[ + UserComments(author='user2', message='message2'), + ]).save() + + def test_query_generic_embedded_document_attribute(self): + class AdminSettings(EmbeddedDocument): + foo1 = StringField() + + class NonAdminSettings(EmbeddedDocument): + foo2 = StringField() + + class Person(Document): + settings = GenericEmbeddedDocumentField(choices=(AdminSettings, NonAdminSettings)) + + Person.drop_collection() + + p1 = Person(settings=AdminSettings(foo1='bar1')).save() + p2 = Person(settings=NonAdminSettings(foo2='bar2')).save() + + # Test non exiting attribute + with self.assertRaises(InvalidQueryError) as ctx_err: + Person.objects(settings__notexist='bar').first() + self.assertEqual(unicode(ctx_err.exception), u'Cannot resolve field "notexist"') + + with self.assertRaises(LookUpError): + Person.objects.only('settings.notexist') + + # Test existing attribute + self.assertEqual(Person.objects(settings__foo1='bar1').first().id, p1.id) + self.assertEqual(Person.objects(settings__foo2='bar2').first().id, p2.id) + + def test_query_generic_embedded_document_attribute_with_inheritance(self): + class BaseSettings(EmbeddedDocument): + meta = {'allow_inheritance': True} + base_foo = StringField() + + class AdminSettings(BaseSettings): + sub_foo = StringField() + + class Person(Document): + settings = GenericEmbeddedDocumentField(choices=[BaseSettings]) + + Person.drop_collection() + + p = Person(settings=AdminSettings(base_foo='basefoo', sub_foo='subfoo')) + p.save() + + # Test non exiting attribute + with self.assertRaises(InvalidQueryError) as ctx_err: + self.assertEqual(Person.objects(settings__notexist='bar').first().id, p.id) + self.assertEqual(unicode(ctx_err.exception), u'Cannot resolve field "notexist"') + + # Test existing attribute + self.assertEqual(Person.objects(settings__base_foo='basefoo').first().id, p.id) + self.assertEqual(Person.objects(settings__sub_foo='subfoo').first().id, p.id) diff --git a/tests/fields/test_float_field.py b/tests/fields/test_float_field.py new file mode 100644 index 00000000..fa92cf20 --- /dev/null +++ b/tests/fields/test_float_field.py @@ -0,0 +1,58 @@ +# -*- coding: utf-8 -*- +import six + +from mongoengine import * + +from tests.utils import MongoDBTestCase + + +class TestFloatField(MongoDBTestCase): + + def test_float_ne_operator(self): + class TestDocument(Document): + float_fld = FloatField() + + TestDocument.drop_collection() + + TestDocument(float_fld=None).save() + TestDocument(float_fld=1).save() + + self.assertEqual(1, TestDocument.objects(float_fld__ne=None).count()) + self.assertEqual(1, TestDocument.objects(float_fld__ne=1).count()) + + def test_validation(self): + """Ensure that invalid values cannot be assigned to float fields. + """ + class Person(Document): + height = FloatField(min_value=0.1, max_value=3.5) + + class BigPerson(Document): + height = FloatField() + + person = Person() + person.height = 1.89 + person.validate() + + person.height = '2.0' + self.assertRaises(ValidationError, person.validate) + + person.height = 0.01 + self.assertRaises(ValidationError, person.validate) + + person.height = 4.0 + self.assertRaises(ValidationError, person.validate) + + person_2 = Person(height='something invalid') + self.assertRaises(ValidationError, person_2.validate) + + big_person = BigPerson() + + for value, value_type in enumerate(six.integer_types): + big_person.height = value_type(value) + big_person.validate() + + big_person.height = 2 ** 500 + big_person.validate() + + big_person.height = 2 ** 100000 # Too big for a float value + self.assertRaises(ValidationError, big_person.validate) diff --git a/tests/fields/test_int_field.py b/tests/fields/test_int_field.py new file mode 100644 index 00000000..1b1f7ad9 --- /dev/null +++ b/tests/fields/test_int_field.py @@ -0,0 +1,42 @@ +# -*- coding: utf-8 -*- +from mongoengine import * + +from tests.utils import MongoDBTestCase + + +class TestIntField(MongoDBTestCase): + + def test_int_validation(self): + """Ensure that invalid values cannot be assigned to int fields. + """ + class Person(Document): + age = IntField(min_value=0, max_value=110) + + person = Person() + person.age = 0 + person.validate() + + person.age = 50 + person.validate() + + person.age = 110 + person.validate() + + person.age = -1 + self.assertRaises(ValidationError, person.validate) + person.age = 120 + self.assertRaises(ValidationError, person.validate) + person.age = 'ten' + self.assertRaises(ValidationError, person.validate) + + def test_ne_operator(self): + class TestDocument(Document): + int_fld = IntField() + + TestDocument.drop_collection() + + TestDocument(int_fld=None).save() + TestDocument(int_fld=1).save() + + self.assertEqual(1, TestDocument.objects(int_fld__ne=None).count()) + self.assertEqual(1, TestDocument.objects(int_fld__ne=1).count()) diff --git a/tests/fields/test_lazy_reference_field.py b/tests/fields/test_lazy_reference_field.py new file mode 100644 index 00000000..b10506e7 --- /dev/null +++ b/tests/fields/test_lazy_reference_field.py @@ -0,0 +1,570 @@ +# -*- coding: utf-8 -*- +from bson import DBRef, ObjectId + +from mongoengine import * +from mongoengine.base import LazyReference + +from tests.utils import MongoDBTestCase + + +class TestLazyReferenceField(MongoDBTestCase): + def test_lazy_reference_config(self): + # Make sure ReferenceField only accepts a document class or a string + # with a document class name. + self.assertRaises(ValidationError, LazyReferenceField, EmbeddedDocument) + + def test___repr__(self): + class Animal(Document): + pass + + class Ocurrence(Document): + animal = LazyReferenceField(Animal) + + Animal.drop_collection() + Ocurrence.drop_collection() + + animal = Animal() + oc = Ocurrence(animal=animal) + self.assertIn('LazyReference', repr(oc.animal)) + + def test___getattr___unknown_attr_raises_attribute_error(self): + class Animal(Document): + pass + + class Ocurrence(Document): + animal = LazyReferenceField(Animal) + + Animal.drop_collection() + Ocurrence.drop_collection() + + animal = Animal().save() + oc = Ocurrence(animal=animal) + with self.assertRaises(AttributeError): + oc.animal.not_exist + + def test_lazy_reference_simple(self): + class Animal(Document): + name = StringField() + tag = StringField() + + class Ocurrence(Document): + person = StringField() + animal = LazyReferenceField(Animal) + + Animal.drop_collection() + Ocurrence.drop_collection() + + animal = Animal(name="Leopard", tag="heavy").save() + Ocurrence(person="test", animal=animal).save() + p = Ocurrence.objects.get() + self.assertIsInstance(p.animal, LazyReference) + fetched_animal = p.animal.fetch() + self.assertEqual(fetched_animal, animal) + # `fetch` keep cache on referenced document by default... + animal.tag = "not so heavy" + animal.save() + double_fetch = p.animal.fetch() + self.assertIs(fetched_animal, double_fetch) + self.assertEqual(double_fetch.tag, "heavy") + # ...unless specified otherwise + fetch_force = p.animal.fetch(force=True) + self.assertIsNot(fetch_force, fetched_animal) + self.assertEqual(fetch_force.tag, "not so heavy") + + def test_lazy_reference_fetch_invalid_ref(self): + class Animal(Document): + name = StringField() + tag = StringField() + + class Ocurrence(Document): + person = StringField() + animal = LazyReferenceField(Animal) + + Animal.drop_collection() + Ocurrence.drop_collection() + + animal = Animal(name="Leopard", tag="heavy").save() + Ocurrence(person="test", animal=animal).save() + animal.delete() + p = Ocurrence.objects.get() + self.assertIsInstance(p.animal, LazyReference) + with self.assertRaises(DoesNotExist): + p.animal.fetch() + + def test_lazy_reference_set(self): + class Animal(Document): + meta = {'allow_inheritance': True} + + name = StringField() + tag = StringField() + + class Ocurrence(Document): + person = StringField() + animal = LazyReferenceField(Animal) + + Animal.drop_collection() + Ocurrence.drop_collection() + + class SubAnimal(Animal): + nick = StringField() + + animal = Animal(name="Leopard", tag="heavy").save() + sub_animal = SubAnimal(nick='doggo', name='dog').save() + for ref in ( + animal, + animal.pk, + DBRef(animal._get_collection_name(), animal.pk), + LazyReference(Animal, animal.pk), + + sub_animal, + sub_animal.pk, + DBRef(sub_animal._get_collection_name(), sub_animal.pk), + LazyReference(SubAnimal, sub_animal.pk), + ): + p = Ocurrence(person="test", animal=ref).save() + p.reload() + self.assertIsInstance(p.animal, LazyReference) + p.animal.fetch() + + def test_lazy_reference_bad_set(self): + class Animal(Document): + name = StringField() + tag = StringField() + + class Ocurrence(Document): + person = StringField() + animal = LazyReferenceField(Animal) + + Animal.drop_collection() + Ocurrence.drop_collection() + + class BadDoc(Document): + pass + + animal = Animal(name="Leopard", tag="heavy").save() + baddoc = BadDoc().save() + for bad in ( + 42, + 'foo', + baddoc, + DBRef(baddoc._get_collection_name(), animal.pk), + LazyReference(BadDoc, animal.pk) + ): + with self.assertRaises(ValidationError): + p = Ocurrence(person="test", animal=bad).save() + + def test_lazy_reference_query_conversion(self): + """Ensure that LazyReferenceFields can be queried using objects and values + of the type of the primary key of the referenced object. + """ + class Member(Document): + user_num = IntField(primary_key=True) + + class BlogPost(Document): + title = StringField() + author = LazyReferenceField(Member, dbref=False) + + Member.drop_collection() + BlogPost.drop_collection() + + m1 = Member(user_num=1) + m1.save() + m2 = Member(user_num=2) + m2.save() + + post1 = BlogPost(title='post 1', author=m1) + post1.save() + + post2 = BlogPost(title='post 2', author=m2) + post2.save() + + post = BlogPost.objects(author=m1).first() + self.assertEqual(post.id, post1.id) + + post = BlogPost.objects(author=m2).first() + self.assertEqual(post.id, post2.id) + + # Same thing by passing a LazyReference instance + post = BlogPost.objects(author=LazyReference(Member, m2.pk)).first() + self.assertEqual(post.id, post2.id) + + def test_lazy_reference_query_conversion_dbref(self): + """Ensure that LazyReferenceFields can be queried using objects and values + of the type of the primary key of the referenced object. + """ + class Member(Document): + user_num = IntField(primary_key=True) + + class BlogPost(Document): + title = StringField() + author = LazyReferenceField(Member, dbref=True) + + Member.drop_collection() + BlogPost.drop_collection() + + m1 = Member(user_num=1) + m1.save() + m2 = Member(user_num=2) + m2.save() + + post1 = BlogPost(title='post 1', author=m1) + post1.save() + + post2 = BlogPost(title='post 2', author=m2) + post2.save() + + post = BlogPost.objects(author=m1).first() + self.assertEqual(post.id, post1.id) + + post = BlogPost.objects(author=m2).first() + self.assertEqual(post.id, post2.id) + + # Same thing by passing a LazyReference instance + post = BlogPost.objects(author=LazyReference(Member, m2.pk)).first() + self.assertEqual(post.id, post2.id) + + def test_lazy_reference_passthrough(self): + class Animal(Document): + name = StringField() + tag = StringField() + + class Ocurrence(Document): + animal = LazyReferenceField(Animal, passthrough=False) + animal_passthrough = LazyReferenceField(Animal, passthrough=True) + + Animal.drop_collection() + Ocurrence.drop_collection() + + animal = Animal(name="Leopard", tag="heavy").save() + Ocurrence(animal=animal, animal_passthrough=animal).save() + p = Ocurrence.objects.get() + self.assertIsInstance(p.animal, LazyReference) + with self.assertRaises(KeyError): + p.animal['name'] + with self.assertRaises(AttributeError): + p.animal.name + self.assertEqual(p.animal.pk, animal.pk) + + self.assertEqual(p.animal_passthrough.name, "Leopard") + self.assertEqual(p.animal_passthrough['name'], "Leopard") + + # Should not be able to access referenced document's methods + with self.assertRaises(AttributeError): + p.animal.save + with self.assertRaises(KeyError): + p.animal['save'] + + def test_lazy_reference_not_set(self): + class Animal(Document): + name = StringField() + tag = StringField() + + class Ocurrence(Document): + person = StringField() + animal = LazyReferenceField(Animal) + + Animal.drop_collection() + Ocurrence.drop_collection() + + Ocurrence(person='foo').save() + p = Ocurrence.objects.get() + self.assertIs(p.animal, None) + + def test_lazy_reference_equality(self): + class Animal(Document): + name = StringField() + tag = StringField() + + Animal.drop_collection() + + animal = Animal(name="Leopard", tag="heavy").save() + animalref = LazyReference(Animal, animal.pk) + self.assertEqual(animal, animalref) + self.assertEqual(animalref, animal) + + other_animalref = LazyReference(Animal, ObjectId("54495ad94c934721ede76f90")) + self.assertNotEqual(animal, other_animalref) + self.assertNotEqual(other_animalref, animal) + + def test_lazy_reference_embedded(self): + class Animal(Document): + name = StringField() + tag = StringField() + + class EmbeddedOcurrence(EmbeddedDocument): + in_list = ListField(LazyReferenceField(Animal)) + direct = LazyReferenceField(Animal) + + class Ocurrence(Document): + in_list = ListField(LazyReferenceField(Animal)) + in_embedded = EmbeddedDocumentField(EmbeddedOcurrence) + direct = LazyReferenceField(Animal) + + Animal.drop_collection() + Ocurrence.drop_collection() + + animal1 = Animal('doggo').save() + animal2 = Animal('cheeta').save() + + def check_fields_type(occ): + self.assertIsInstance(occ.direct, LazyReference) + for elem in occ.in_list: + self.assertIsInstance(elem, LazyReference) + self.assertIsInstance(occ.in_embedded.direct, LazyReference) + for elem in occ.in_embedded.in_list: + self.assertIsInstance(elem, LazyReference) + + occ = Ocurrence( + in_list=[animal1, animal2], + in_embedded={'in_list': [animal1, animal2], 'direct': animal1}, + direct=animal1 + ).save() + check_fields_type(occ) + occ.reload() + check_fields_type(occ) + occ.direct = animal1.id + occ.in_list = [animal1.id, animal2.id] + occ.in_embedded.direct = animal1.id + occ.in_embedded.in_list = [animal1.id, animal2.id] + check_fields_type(occ) + + +class TestGenericLazyReferenceField(MongoDBTestCase): + def test_generic_lazy_reference_simple(self): + class Animal(Document): + name = StringField() + tag = StringField() + + class Ocurrence(Document): + person = StringField() + animal = GenericLazyReferenceField() + + Animal.drop_collection() + Ocurrence.drop_collection() + + animal = Animal(name="Leopard", tag="heavy").save() + Ocurrence(person="test", animal=animal).save() + p = Ocurrence.objects.get() + self.assertIsInstance(p.animal, LazyReference) + fetched_animal = p.animal.fetch() + self.assertEqual(fetched_animal, animal) + # `fetch` keep cache on referenced document by default... + animal.tag = "not so heavy" + animal.save() + double_fetch = p.animal.fetch() + self.assertIs(fetched_animal, double_fetch) + self.assertEqual(double_fetch.tag, "heavy") + # ...unless specified otherwise + fetch_force = p.animal.fetch(force=True) + self.assertIsNot(fetch_force, fetched_animal) + self.assertEqual(fetch_force.tag, "not so heavy") + + def test_generic_lazy_reference_choices(self): + class Animal(Document): + name = StringField() + + class Vegetal(Document): + name = StringField() + + class Mineral(Document): + name = StringField() + + class Ocurrence(Document): + living_thing = GenericLazyReferenceField(choices=[Animal, Vegetal]) + thing = GenericLazyReferenceField() + + Animal.drop_collection() + Vegetal.drop_collection() + Mineral.drop_collection() + Ocurrence.drop_collection() + + animal = Animal(name="Leopard").save() + vegetal = Vegetal(name="Oak").save() + mineral = Mineral(name="Granite").save() + + occ_animal = Ocurrence(living_thing=animal, thing=animal).save() + occ_vegetal = Ocurrence(living_thing=vegetal, thing=vegetal).save() + with self.assertRaises(ValidationError): + Ocurrence(living_thing=mineral).save() + + occ = Ocurrence.objects.get(living_thing=animal) + self.assertEqual(occ, occ_animal) + self.assertIsInstance(occ.thing, LazyReference) + self.assertIsInstance(occ.living_thing, LazyReference) + + occ.thing = vegetal + occ.living_thing = vegetal + occ.save() + + occ.thing = mineral + occ.living_thing = mineral + with self.assertRaises(ValidationError): + occ.save() + + def test_generic_lazy_reference_set(self): + class Animal(Document): + meta = {'allow_inheritance': True} + + name = StringField() + tag = StringField() + + class Ocurrence(Document): + person = StringField() + animal = GenericLazyReferenceField() + + Animal.drop_collection() + Ocurrence.drop_collection() + + class SubAnimal(Animal): + nick = StringField() + + animal = Animal(name="Leopard", tag="heavy").save() + sub_animal = SubAnimal(nick='doggo', name='dog').save() + for ref in ( + animal, + LazyReference(Animal, animal.pk), + {'_cls': 'Animal', '_ref': DBRef(animal._get_collection_name(), animal.pk)}, + + sub_animal, + LazyReference(SubAnimal, sub_animal.pk), + {'_cls': 'SubAnimal', '_ref': DBRef(sub_animal._get_collection_name(), sub_animal.pk)}, + ): + p = Ocurrence(person="test", animal=ref).save() + p.reload() + self.assertIsInstance(p.animal, (LazyReference, Document)) + p.animal.fetch() + + def test_generic_lazy_reference_bad_set(self): + class Animal(Document): + name = StringField() + tag = StringField() + + class Ocurrence(Document): + person = StringField() + animal = GenericLazyReferenceField(choices=['Animal']) + + Animal.drop_collection() + Ocurrence.drop_collection() + + class BadDoc(Document): + pass + + animal = Animal(name="Leopard", tag="heavy").save() + baddoc = BadDoc().save() + for bad in ( + 42, + 'foo', + baddoc, + LazyReference(BadDoc, animal.pk) + ): + with self.assertRaises(ValidationError): + p = Ocurrence(person="test", animal=bad).save() + + def test_generic_lazy_reference_query_conversion(self): + class Member(Document): + user_num = IntField(primary_key=True) + + class BlogPost(Document): + title = StringField() + author = GenericLazyReferenceField() + + Member.drop_collection() + BlogPost.drop_collection() + + m1 = Member(user_num=1) + m1.save() + m2 = Member(user_num=2) + m2.save() + + post1 = BlogPost(title='post 1', author=m1) + post1.save() + + post2 = BlogPost(title='post 2', author=m2) + post2.save() + + post = BlogPost.objects(author=m1).first() + self.assertEqual(post.id, post1.id) + + post = BlogPost.objects(author=m2).first() + self.assertEqual(post.id, post2.id) + + # Same thing by passing a LazyReference instance + post = BlogPost.objects(author=LazyReference(Member, m2.pk)).first() + self.assertEqual(post.id, post2.id) + + def test_generic_lazy_reference_not_set(self): + class Animal(Document): + name = StringField() + tag = StringField() + + class Ocurrence(Document): + person = StringField() + animal = GenericLazyReferenceField() + + Animal.drop_collection() + Ocurrence.drop_collection() + + Ocurrence(person='foo').save() + p = Ocurrence.objects.get() + self.assertIs(p.animal, None) + + def test_generic_lazy_reference_accepts_string_instead_of_class(self): + class Animal(Document): + name = StringField() + tag = StringField() + + class Ocurrence(Document): + person = StringField() + animal = GenericLazyReferenceField('Animal') + + Animal.drop_collection() + Ocurrence.drop_collection() + + animal = Animal().save() + Ocurrence(animal=animal).save() + p = Ocurrence.objects.get() + self.assertEqual(p.animal, animal) + + def test_generic_lazy_reference_embedded(self): + class Animal(Document): + name = StringField() + tag = StringField() + + class EmbeddedOcurrence(EmbeddedDocument): + in_list = ListField(GenericLazyReferenceField()) + direct = GenericLazyReferenceField() + + class Ocurrence(Document): + in_list = ListField(GenericLazyReferenceField()) + in_embedded = EmbeddedDocumentField(EmbeddedOcurrence) + direct = GenericLazyReferenceField() + + Animal.drop_collection() + Ocurrence.drop_collection() + + animal1 = Animal('doggo').save() + animal2 = Animal('cheeta').save() + + def check_fields_type(occ): + self.assertIsInstance(occ.direct, LazyReference) + for elem in occ.in_list: + self.assertIsInstance(elem, LazyReference) + self.assertIsInstance(occ.in_embedded.direct, LazyReference) + for elem in occ.in_embedded.in_list: + self.assertIsInstance(elem, LazyReference) + + occ = Ocurrence( + in_list=[animal1, animal2], + in_embedded={'in_list': [animal1, animal2], 'direct': animal1}, + direct=animal1 + ).save() + check_fields_type(occ) + occ.reload() + check_fields_type(occ) + animal1_ref = {'_cls': 'Animal', '_ref': DBRef(animal1._get_collection_name(), animal1.pk)} + animal2_ref = {'_cls': 'Animal', '_ref': DBRef(animal2._get_collection_name(), animal2.pk)} + occ.direct = animal1_ref + occ.in_list = [animal1_ref, animal2_ref] + occ.in_embedded.direct = animal1_ref + occ.in_embedded.in_list = [animal1_ref, animal2_ref] + check_fields_type(occ) diff --git a/tests/fields/test_long_field.py b/tests/fields/test_long_field.py new file mode 100644 index 00000000..3f307809 --- /dev/null +++ b/tests/fields/test_long_field.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +import six + +try: + from bson.int64 import Int64 +except ImportError: + Int64 = long + +from mongoengine import * +from mongoengine.connection import get_db + +from tests.utils import MongoDBTestCase + + +class TestLongField(MongoDBTestCase): + + def test_long_field_is_considered_as_int64(self): + """ + Tests that long fields are stored as long in mongo, even if long + value is small enough to be an int. + """ + class TestLongFieldConsideredAsInt64(Document): + some_long = LongField() + + doc = TestLongFieldConsideredAsInt64(some_long=42).save() + db = get_db() + self.assertIsInstance(db.test_long_field_considered_as_int64.find()[0]['some_long'], Int64) + self.assertIsInstance(doc.some_long, six.integer_types) + + def test_long_validation(self): + """Ensure that invalid values cannot be assigned to long fields. + """ + class TestDocument(Document): + value = LongField(min_value=0, max_value=110) + + doc = TestDocument() + doc.value = 50 + doc.validate() + + doc.value = -1 + self.assertRaises(ValidationError, doc.validate) + doc.value = 120 + self.assertRaises(ValidationError, doc.validate) + doc.value = 'ten' + self.assertRaises(ValidationError, doc.validate) + + def test_long_ne_operator(self): + class TestDocument(Document): + long_fld = LongField() + + TestDocument.drop_collection() + + TestDocument(long_fld=None).save() + TestDocument(long_fld=1).save() + + self.assertEqual(1, TestDocument.objects(long_fld__ne=None).count()) diff --git a/tests/fields/test_map_field.py b/tests/fields/test_map_field.py new file mode 100644 index 00000000..cb27cfff --- /dev/null +++ b/tests/fields/test_map_field.py @@ -0,0 +1,144 @@ +# -*- coding: utf-8 -*- +import datetime + +from mongoengine import * + +from tests.utils import MongoDBTestCase + + +class TestMapField(MongoDBTestCase): + + def test_mapfield(self): + """Ensure that the MapField handles the declared type.""" + class Simple(Document): + mapping = MapField(IntField()) + + Simple.drop_collection() + + e = Simple() + e.mapping['someint'] = 1 + e.save() + + with self.assertRaises(ValidationError): + e.mapping['somestring'] = "abc" + e.save() + + with self.assertRaises(ValidationError): + class NoDeclaredType(Document): + mapping = MapField() + + def test_complex_mapfield(self): + """Ensure that the MapField can handle complex declared types.""" + + class SettingBase(EmbeddedDocument): + meta = {"allow_inheritance": True} + + class StringSetting(SettingBase): + value = StringField() + + class IntegerSetting(SettingBase): + value = IntField() + + class Extensible(Document): + mapping = MapField(EmbeddedDocumentField(SettingBase)) + + Extensible.drop_collection() + + e = Extensible() + e.mapping['somestring'] = StringSetting(value='foo') + e.mapping['someint'] = IntegerSetting(value=42) + e.save() + + e2 = Extensible.objects.get(id=e.id) + self.assertIsInstance(e2.mapping['somestring'], StringSetting) + self.assertIsInstance(e2.mapping['someint'], IntegerSetting) + + with self.assertRaises(ValidationError): + e.mapping['someint'] = 123 + e.save() + + def test_embedded_mapfield_db_field(self): + class Embedded(EmbeddedDocument): + number = IntField(default=0, db_field='i') + + class Test(Document): + my_map = MapField(field=EmbeddedDocumentField(Embedded), + db_field='x') + + Test.drop_collection() + + test = Test() + test.my_map['DICTIONARY_KEY'] = Embedded(number=1) + test.save() + + Test.objects.update_one(inc__my_map__DICTIONARY_KEY__number=1) + + test = Test.objects.get() + self.assertEqual(test.my_map['DICTIONARY_KEY'].number, 2) + doc = self.db.test.find_one() + self.assertEqual(doc['x']['DICTIONARY_KEY']['i'], 2) + + def test_mapfield_numerical_index(self): + """Ensure that MapField accept numeric strings as indexes.""" + + class Embedded(EmbeddedDocument): + name = StringField() + + class Test(Document): + my_map = MapField(EmbeddedDocumentField(Embedded)) + + Test.drop_collection() + + test = Test() + test.my_map['1'] = Embedded(name='test') + test.save() + test.my_map['1'].name = 'test updated' + test.save() + + def test_map_field_lookup(self): + """Ensure MapField lookups succeed on Fields without a lookup + method. + """ + + class Action(EmbeddedDocument): + operation = StringField() + object = StringField() + + class Log(Document): + name = StringField() + visited = MapField(DateTimeField()) + actions = MapField(EmbeddedDocumentField(Action)) + + Log.drop_collection() + Log(name="wilson", visited={'friends': datetime.datetime.now()}, + actions={'friends': Action(operation='drink', object='beer')}).save() + + self.assertEqual(1, Log.objects( + visited__friends__exists=True).count()) + + self.assertEqual(1, Log.objects( + actions__friends__operation='drink', + actions__friends__object='beer').count()) + + def test_map_field_unicode(self): + class Info(EmbeddedDocument): + description = StringField() + value_list = ListField(field=StringField()) + + class BlogPost(Document): + info_dict = MapField(field=EmbeddedDocumentField(Info)) + + BlogPost.drop_collection() + + tree = BlogPost(info_dict={ + u"éééé": { + 'description': u"VALUE: éééé" + } + }) + + tree.save() + + self.assertEqual( + BlogPost.objects.get(id=tree.id).info_dict[u"éééé"].description, + u"VALUE: éééé" + ) diff --git a/tests/fields/test_reference_field.py b/tests/fields/test_reference_field.py new file mode 100644 index 00000000..5e1fc605 --- /dev/null +++ b/tests/fields/test_reference_field.py @@ -0,0 +1,219 @@ +# -*- coding: utf-8 -*- +from bson import SON, DBRef + +from mongoengine import * + +from tests.utils import MongoDBTestCase + + +class TestReferenceField(MongoDBTestCase): + def test_reference_validation(self): + """Ensure that invalid document objects cannot be assigned to + reference fields. + """ + + class User(Document): + name = StringField() + + class BlogPost(Document): + content = StringField() + author = ReferenceField(User) + + User.drop_collection() + BlogPost.drop_collection() + + # Make sure ReferenceField only accepts a document class or a string + # with a document class name. + self.assertRaises(ValidationError, ReferenceField, EmbeddedDocument) + + user = User(name='Test User') + + # Ensure that the referenced object must have been saved + post1 = BlogPost(content='Chips and gravy taste good.') + post1.author = user + self.assertRaises(ValidationError, post1.save) + + # Check that an invalid object type cannot be used + post2 = BlogPost(content='Chips and chilli taste good.') + post1.author = post2 + self.assertRaises(ValidationError, post1.validate) + + # Ensure ObjectID's are accepted as references + user_object_id = user.pk + post3 = BlogPost(content="Chips and curry sauce taste good.") + post3.author = user_object_id + post3.save() + + # Make sure referencing a saved document of the right type works + user.save() + post1.author = user + post1.save() + + # Make sure referencing a saved document of the *wrong* type fails + post2.save() + post1.author = post2 + self.assertRaises(ValidationError, post1.validate) + + def test_objectid_reference_fields(self): + """Make sure storing Object ID references works.""" + + class Person(Document): + name = StringField() + parent = ReferenceField('self') + + Person.drop_collection() + + p1 = Person(name="John").save() + Person(name="Ross", parent=p1.pk).save() + + p = Person.objects.get(name="Ross") + self.assertEqual(p.parent, p1) + + def test_dbref_reference_fields(self): + """Make sure storing references as bson.dbref.DBRef works.""" + + class Person(Document): + name = StringField() + parent = ReferenceField('self', dbref=True) + + Person.drop_collection() + + p1 = Person(name="John").save() + Person(name="Ross", parent=p1).save() + + self.assertEqual( + Person._get_collection().find_one({'name': 'Ross'})['parent'], + DBRef('person', p1.pk) + ) + + p = Person.objects.get(name="Ross") + self.assertEqual(p.parent, p1) + + def test_dbref_to_mongo(self): + """Make sure that calling to_mongo on a ReferenceField which + has dbref=False, but actually actually contains a DBRef returns + an ID of that DBRef. + """ + + class Person(Document): + name = StringField() + parent = ReferenceField('self', dbref=False) + + p = Person( + name='Steve', + parent=DBRef('person', 'abcdefghijklmnop') + ) + self.assertEqual(p.to_mongo(), SON([ + ('name', u'Steve'), + ('parent', 'abcdefghijklmnop') + ])) + + def test_objectid_reference_fields(self): + class Person(Document): + name = StringField() + parent = ReferenceField('self', dbref=False) + + Person.drop_collection() + + p1 = Person(name="John").save() + Person(name="Ross", parent=p1).save() + + col = Person._get_collection() + data = col.find_one({'name': 'Ross'}) + self.assertEqual(data['parent'], p1.pk) + + p = Person.objects.get(name="Ross") + self.assertEqual(p.parent, p1) + + def test_undefined_reference(self): + """Ensure that ReferenceFields may reference undefined Documents. + """ + class Product(Document): + name = StringField() + company = ReferenceField('Company') + + class Company(Document): + name = StringField() + + Product.drop_collection() + Company.drop_collection() + + ten_gen = Company(name='10gen') + ten_gen.save() + mongodb = Product(name='MongoDB', company=ten_gen) + mongodb.save() + + me = Product(name='MongoEngine') + me.save() + + obj = Product.objects(company=ten_gen).first() + self.assertEqual(obj, mongodb) + self.assertEqual(obj.company, ten_gen) + + obj = Product.objects(company=None).first() + self.assertEqual(obj, me) + + obj = Product.objects.get(company=None) + self.assertEqual(obj, me) + + def test_reference_query_conversion(self): + """Ensure that ReferenceFields can be queried using objects and values + of the type of the primary key of the referenced object. + """ + class Member(Document): + user_num = IntField(primary_key=True) + + class BlogPost(Document): + title = StringField() + author = ReferenceField(Member, dbref=False) + + Member.drop_collection() + BlogPost.drop_collection() + + m1 = Member(user_num=1) + m1.save() + m2 = Member(user_num=2) + m2.save() + + post1 = BlogPost(title='post 1', author=m1) + post1.save() + + post2 = BlogPost(title='post 2', author=m2) + post2.save() + + post = BlogPost.objects(author=m1).first() + self.assertEqual(post.id, post1.id) + + post = BlogPost.objects(author=m2).first() + self.assertEqual(post.id, post2.id) + + def test_reference_query_conversion_dbref(self): + """Ensure that ReferenceFields can be queried using objects and values + of the type of the primary key of the referenced object. + """ + class Member(Document): + user_num = IntField(primary_key=True) + + class BlogPost(Document): + title = StringField() + author = ReferenceField(Member, dbref=True) + + Member.drop_collection() + BlogPost.drop_collection() + + m1 = Member(user_num=1) + m1.save() + m2 = Member(user_num=2) + m2.save() + + post1 = BlogPost(title='post 1', author=m1) + post1.save() + + post2 = BlogPost(title='post 2', author=m2) + post2.save() + + post = BlogPost.objects(author=m1).first() + self.assertEqual(post.id, post1.id) + + post = BlogPost.objects(author=m2).first() + self.assertEqual(post.id, post2.id) diff --git a/tests/fields/test_sequence_field.py b/tests/fields/test_sequence_field.py new file mode 100644 index 00000000..6124c65e --- /dev/null +++ b/tests/fields/test_sequence_field.py @@ -0,0 +1,271 @@ +# -*- coding: utf-8 -*- + +from mongoengine import * + +from tests.utils import MongoDBTestCase + + +class TestSequenceField(MongoDBTestCase): + def test_sequence_field(self): + class Person(Document): + id = SequenceField(primary_key=True) + name = StringField() + + self.db['mongoengine.counters'].drop() + Person.drop_collection() + + for x in range(10): + Person(name="Person %s" % x).save() + + c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) + self.assertEqual(c['next'], 10) + + ids = [i.id for i in Person.objects] + self.assertEqual(ids, range(1, 11)) + + c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) + self.assertEqual(c['next'], 10) + + Person.id.set_next_value(1000) + c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) + self.assertEqual(c['next'], 1000) + + def test_sequence_field_get_next_value(self): + class Person(Document): + id = SequenceField(primary_key=True) + name = StringField() + + self.db['mongoengine.counters'].drop() + Person.drop_collection() + + for x in range(10): + Person(name="Person %s" % x).save() + + self.assertEqual(Person.id.get_next_value(), 11) + self.db['mongoengine.counters'].drop() + + self.assertEqual(Person.id.get_next_value(), 1) + + class Person(Document): + id = SequenceField(primary_key=True, value_decorator=str) + name = StringField() + + self.db['mongoengine.counters'].drop() + Person.drop_collection() + + for x in range(10): + Person(name="Person %s" % x).save() + + self.assertEqual(Person.id.get_next_value(), '11') + self.db['mongoengine.counters'].drop() + + self.assertEqual(Person.id.get_next_value(), '1') + + def test_sequence_field_sequence_name(self): + class Person(Document): + id = SequenceField(primary_key=True, sequence_name='jelly') + name = StringField() + + self.db['mongoengine.counters'].drop() + Person.drop_collection() + + for x in range(10): + Person(name="Person %s" % x).save() + + c = self.db['mongoengine.counters'].find_one({'_id': 'jelly.id'}) + self.assertEqual(c['next'], 10) + + ids = [i.id for i in Person.objects] + self.assertEqual(ids, range(1, 11)) + + c = self.db['mongoengine.counters'].find_one({'_id': 'jelly.id'}) + self.assertEqual(c['next'], 10) + + Person.id.set_next_value(1000) + c = self.db['mongoengine.counters'].find_one({'_id': 'jelly.id'}) + self.assertEqual(c['next'], 1000) + + def test_multiple_sequence_fields(self): + class Person(Document): + id = SequenceField(primary_key=True) + counter = SequenceField() + name = StringField() + + self.db['mongoengine.counters'].drop() + Person.drop_collection() + + for x in range(10): + Person(name="Person %s" % x).save() + + c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) + self.assertEqual(c['next'], 10) + + ids = [i.id for i in Person.objects] + self.assertEqual(ids, range(1, 11)) + + counters = [i.counter for i in Person.objects] + self.assertEqual(counters, range(1, 11)) + + c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) + self.assertEqual(c['next'], 10) + + Person.id.set_next_value(1000) + c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) + self.assertEqual(c['next'], 1000) + + Person.counter.set_next_value(999) + c = self.db['mongoengine.counters'].find_one({'_id': 'person.counter'}) + self.assertEqual(c['next'], 999) + + def test_sequence_fields_reload(self): + class Animal(Document): + counter = SequenceField() + name = StringField() + + self.db['mongoengine.counters'].drop() + Animal.drop_collection() + + a = Animal(name="Boi").save() + + self.assertEqual(a.counter, 1) + a.reload() + self.assertEqual(a.counter, 1) + + a.counter = None + self.assertEqual(a.counter, 2) + a.save() + + self.assertEqual(a.counter, 2) + + a = Animal.objects.first() + self.assertEqual(a.counter, 2) + a.reload() + self.assertEqual(a.counter, 2) + + def test_multiple_sequence_fields_on_docs(self): + class Animal(Document): + id = SequenceField(primary_key=True) + name = StringField() + + class Person(Document): + id = SequenceField(primary_key=True) + name = StringField() + + self.db['mongoengine.counters'].drop() + Animal.drop_collection() + Person.drop_collection() + + for x in range(10): + Animal(name="Animal %s" % x).save() + Person(name="Person %s" % x).save() + + c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) + self.assertEqual(c['next'], 10) + + c = self.db['mongoengine.counters'].find_one({'_id': 'animal.id'}) + self.assertEqual(c['next'], 10) + + ids = [i.id for i in Person.objects] + self.assertEqual(ids, range(1, 11)) + + id = [i.id for i in Animal.objects] + self.assertEqual(id, range(1, 11)) + + c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) + self.assertEqual(c['next'], 10) + + c = self.db['mongoengine.counters'].find_one({'_id': 'animal.id'}) + self.assertEqual(c['next'], 10) + + def test_sequence_field_value_decorator(self): + class Person(Document): + id = SequenceField(primary_key=True, value_decorator=str) + name = StringField() + + self.db['mongoengine.counters'].drop() + Person.drop_collection() + + for x in range(10): + p = Person(name="Person %s" % x) + p.save() + + c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) + self.assertEqual(c['next'], 10) + + ids = [i.id for i in Person.objects] + self.assertEqual(ids, map(str, range(1, 11))) + + c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) + self.assertEqual(c['next'], 10) + + def test_embedded_sequence_field(self): + class Comment(EmbeddedDocument): + id = SequenceField() + content = StringField(required=True) + + class Post(Document): + title = StringField(required=True) + comments = ListField(EmbeddedDocumentField(Comment)) + + self.db['mongoengine.counters'].drop() + Post.drop_collection() + + Post(title="MongoEngine", + comments=[Comment(content="NoSQL Rocks"), + Comment(content="MongoEngine Rocks")]).save() + c = self.db['mongoengine.counters'].find_one({'_id': 'comment.id'}) + self.assertEqual(c['next'], 2) + post = Post.objects.first() + self.assertEqual(1, post.comments[0].id) + self.assertEqual(2, post.comments[1].id) + + def test_inherited_sequencefield(self): + class Base(Document): + name = StringField() + counter = SequenceField() + meta = {'abstract': True} + + class Foo(Base): + pass + + class Bar(Base): + pass + + bar = Bar(name='Bar') + bar.save() + + foo = Foo(name='Foo') + foo.save() + + self.assertTrue('base.counter' in + self.db['mongoengine.counters'].find().distinct('_id')) + self.assertFalse(('foo.counter' or 'bar.counter') in + self.db['mongoengine.counters'].find().distinct('_id')) + self.assertNotEqual(foo.counter, bar.counter) + self.assertEqual(foo._fields['counter'].owner_document, Base) + self.assertEqual(bar._fields['counter'].owner_document, Base) + + def test_no_inherited_sequencefield(self): + class Base(Document): + name = StringField() + meta = {'abstract': True} + + class Foo(Base): + counter = SequenceField() + + class Bar(Base): + counter = SequenceField() + + bar = Bar(name='Bar') + bar.save() + + foo = Foo(name='Foo') + foo.save() + + self.assertFalse('base.counter' in + self.db['mongoengine.counters'].find().distinct('_id')) + self.assertTrue(('foo.counter' and 'bar.counter') in + self.db['mongoengine.counters'].find().distinct('_id')) + self.assertEqual(foo.counter, bar.counter) + self.assertEqual(foo._fields['counter'].owner_document, Foo) + self.assertEqual(bar._fields['counter'].owner_document, Bar) diff --git a/tests/fields/test_url_field.py b/tests/fields/test_url_field.py new file mode 100644 index 00000000..ddbf707e --- /dev/null +++ b/tests/fields/test_url_field.py @@ -0,0 +1,59 @@ +# -*- coding: utf-8 -*- +from mongoengine import * + +from tests.utils import MongoDBTestCase + + +class TestURLField(MongoDBTestCase): + + def test_validation(self): + """Ensure that URLFields validate urls properly.""" + class Link(Document): + url = URLField() + + link = Link() + link.url = 'google' + self.assertRaises(ValidationError, link.validate) + + link.url = 'http://www.google.com:8080' + link.validate() + + def test_unicode_url_validation(self): + """Ensure unicode URLs are validated properly.""" + class Link(Document): + url = URLField() + + link = Link() + link.url = u'http://привет.com' + + # TODO fix URL validation - this *IS* a valid URL + # For now we just want to make sure that the error message is correct + with self.assertRaises(ValidationError) as ctx_err: + link.validate() + self.assertEqual(unicode(ctx_err.exception), + u"ValidationError (Link:None) (Invalid URL: http://\u043f\u0440\u0438\u0432\u0435\u0442.com: ['url'])") + + def test_url_scheme_validation(self): + """Ensure that URLFields validate urls with specific schemes properly. + """ + class Link(Document): + url = URLField() + + class SchemeLink(Document): + url = URLField(schemes=['ws', 'irc']) + + link = Link() + link.url = 'ws://google.com' + self.assertRaises(ValidationError, link.validate) + + scheme_link = SchemeLink() + scheme_link.url = 'ws://google.com' + scheme_link.validate() + + def test_underscore_allowed_in_domains_names(self): + class Link(Document): + url = URLField() + + link = Link() + link.url = 'https://san_leandro-ca.geebo.com' + link.validate() diff --git a/tests/fields/test_uuid_field.py b/tests/fields/test_uuid_field.py new file mode 100644 index 00000000..7b7faaf2 --- /dev/null +++ b/tests/fields/test_uuid_field.py @@ -0,0 +1,65 @@ +# -*- coding: utf-8 -*- +import uuid + +from mongoengine import * + +from tests.utils import MongoDBTestCase, get_as_pymongo + + +class Person(Document): + api_key = UUIDField(binary=False) + + +class TestUUIDField(MongoDBTestCase): + def test_storage(self): + uid = uuid.uuid4() + person = Person(api_key=uid).save() + self.assertEqual( + get_as_pymongo(person), + {'_id': person.id, + 'api_key': str(uid) + } + ) + + def test_field_string(self): + """Test UUID fields storing as String + """ + Person.drop_collection() + + uu = uuid.uuid4() + Person(api_key=uu).save() + self.assertEqual(1, Person.objects(api_key=uu).count()) + self.assertEqual(uu, Person.objects.first().api_key) + + person = Person() + valid = (uuid.uuid4(), uuid.uuid1()) + for api_key in valid: + person.api_key = api_key + person.validate() + + invalid = ('9d159858-549b-4975-9f98-dd2f987c113g', + '9d159858-549b-4975-9f98-dd2f987c113') + for api_key in invalid: + person.api_key = api_key + self.assertRaises(ValidationError, person.validate) + + def test_field_binary(self): + """Test UUID fields storing as Binary object.""" + Person.drop_collection() + + uu = uuid.uuid4() + Person(api_key=uu).save() + self.assertEqual(1, Person.objects(api_key=uu).count()) + self.assertEqual(uu, Person.objects.first().api_key) + + person = Person() + valid = (uuid.uuid4(), uuid.uuid1()) + for api_key in valid: + person.api_key = api_key + person.validate() + + invalid = ('9d159858-549b-4975-9f98-dd2f987c113g', + '9d159858-549b-4975-9f98-dd2f987c113') + for api_key in invalid: + person.api_key = api_key + self.assertRaises(ValidationError, person.validate) diff --git a/tests/fixtures.py b/tests/fixtures.py index d8eb8487..b8303b99 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -48,6 +48,7 @@ class PickleSignalsTest(Document): def post_delete(self, sender, document, **kwargs): pickled = pickle.dumps(document) + signals.post_save.connect(PickleSignalsTest.post_save, sender=PickleSignalsTest) signals.post_delete.connect(PickleSignalsTest.post_delete, sender=PickleSignalsTest) diff --git a/tests/queryset/field_list.py b/tests/queryset/field_list.py index b111238a..250e2601 100644 --- a/tests/queryset/field_list.py +++ b/tests/queryset/field_list.py @@ -208,7 +208,7 @@ class OnlyExcludeAllTest(unittest.TestCase): BlogPost.drop_collection() - post = BlogPost(content='Had a good coffee today...', various={'test_dynamic':{'some': True}}) + post = BlogPost(content='Had a good coffee today...', various={'test_dynamic': {'some': True}}) post.author = User(name='Test User') post.comments = [Comment(title='I aggree', text='Great post!'), Comment(title='Coffee', text='I hate coffee')] post.save() @@ -413,7 +413,6 @@ class OnlyExcludeAllTest(unittest.TestCase): numbers = Numbers.objects.fields(embedded__n={"$slice": [-5, 10]}).get() self.assertEqual(numbers.embedded.n, [-5, -4, -3, -2, -1]) - def test_exclude_from_subclasses_docs(self): class Base(Document): @@ -436,5 +435,6 @@ class OnlyExcludeAllTest(unittest.TestCase): self.assertRaises(LookUpError, Base.objects.exclude, "made_up") + if __name__ == '__main__': unittest.main() diff --git a/tests/queryset/geo.py b/tests/queryset/geo.py index fd8c9b0f..45e6a089 100644 --- a/tests/queryset/geo.py +++ b/tests/queryset/geo.py @@ -3,7 +3,7 @@ import unittest from mongoengine import * -from tests.utils import MongoDBTestCase, requires_mongodb_gte_3 +from tests.utils import MongoDBTestCase __all__ = ("GeoQueriesTest",) @@ -70,9 +70,6 @@ class GeoQueriesTest(MongoDBTestCase): self.assertEqual(events.count(), 1) self.assertEqual(events[0], event2) - # $minDistance was added in MongoDB v2.6, but continued being buggy - # until v3.0; skip for older versions - @requires_mongodb_gte_3 def test_near_and_min_distance(self): """Ensure the "min_distance" operator works alongside the "near" operator. @@ -243,9 +240,6 @@ class GeoQueriesTest(MongoDBTestCase): events = self.Event.objects(location__geo_within_polygon=polygon2) self.assertEqual(events.count(), 0) - # $minDistance was added in MongoDB v2.6, but continued being buggy - # until v3.0; skip for older versions - @requires_mongodb_gte_3 def test_2dsphere_near_and_min_max_distance(self): """Ensure "min_distace" and "max_distance" operators work well together with the "near" operator in a 2dsphere index. @@ -328,8 +322,6 @@ class GeoQueriesTest(MongoDBTestCase): """Make sure PointField works properly in an embedded document.""" self._test_embedded(point_field_class=PointField) - # Needs MongoDB > 2.6.4 https://jira.mongodb.org/browse/SERVER-14039 - @requires_mongodb_gte_3 def test_spherical_geospatial_operators(self): """Ensure that spherical geospatial queries are working.""" class Point(Document): @@ -534,11 +526,11 @@ class GeoQueriesTest(MongoDBTestCase): Location.drop_collection() - Location(loc=[1,2]).save() + Location(loc=[1, 2]).save() loc = Location.objects.as_pymongo()[0] self.assertEqual(loc["loc"], {"type": "Point", "coordinates": [1, 2]}) - Location.objects.update(set__loc=[2,1]) + Location.objects.update(set__loc=[2, 1]) loc = Location.objects.as_pymongo()[0] self.assertEqual(loc["loc"], {"type": "Point", "coordinates": [2, 1]}) diff --git a/tests/queryset/modify.py b/tests/queryset/modify.py index 4b7c3da2..3c5879ba 100644 --- a/tests/queryset/modify.py +++ b/tests/queryset/modify.py @@ -2,8 +2,6 @@ import unittest from mongoengine import connect, Document, IntField, StringField, ListField -from tests.utils import requires_mongodb_gte_26 - __all__ = ("FindAndModifyTest",) @@ -96,7 +94,6 @@ class FindAndModifyTest(unittest.TestCase): self.assertEqual(old_doc.to_mongo(), {"_id": 1}) self.assertDbEqual([{"_id": 0, "value": 0}, {"_id": 1, "value": -1}]) - @requires_mongodb_gte_26 def test_modify_with_push(self): class BlogPost(Document): tags = ListField(StringField()) diff --git a/tests/queryset/pickable.py b/tests/queryset/pickable.py index d96e7dc6..bf7bb31c 100644 --- a/tests/queryset/pickable.py +++ b/tests/queryset/pickable.py @@ -6,10 +6,12 @@ from mongoengine.connection import connect __author__ = 'stas' + class Person(Document): name = StringField() age = IntField() + class TestQuerysetPickable(unittest.TestCase): """ Test for adding pickling support for QuerySet instances @@ -18,7 +20,7 @@ class TestQuerysetPickable(unittest.TestCase): def setUp(self): super(TestQuerysetPickable, self).setUp() - connection = connect(db="test") #type: pymongo.mongo_client.MongoClient + connection = connect(db="test") # type: pymongo.mongo_client.MongoClient connection.drop_database("test") @@ -27,7 +29,6 @@ class TestQuerysetPickable(unittest.TestCase): age=21 ) - def test_picke_simple_qs(self): qs = Person.objects.all() @@ -46,10 +47,10 @@ class TestQuerysetPickable(unittest.TestCase): self.assertEqual(qs.count(), loadedQs.count()) - #can update loadedQs + # can update loadedQs loadedQs.update(age=23) - #check + # check self.assertEqual(Person.objects.first().age, 23) def test_pickle_support_filtration(self): @@ -70,7 +71,7 @@ class TestQuerysetPickable(unittest.TestCase): self.assertEqual(loaded.count(), 2) self.assertEqual(loaded.filter(name="Bob").first().age, 23) - + diff --git a/tests/queryset/queryset.py b/tests/queryset/queryset.py index d3a2418a..04cfb061 100644 --- a/tests/queryset/queryset.py +++ b/tests/queryset/queryset.py @@ -6,25 +6,21 @@ import uuid from decimal import Decimal from bson import DBRef, ObjectId -from nose.plugins.skip import SkipTest import pymongo from pymongo.errors import ConfigurationError from pymongo.read_preferences import ReadPreference from pymongo.results import UpdateResult import six +from six import iteritems from mongoengine import * from mongoengine.connection import get_connection, get_db from mongoengine.context_managers import query_counter, switch_db from mongoengine.errors import InvalidQueryError -from mongoengine.python_support import IS_PYMONGO_3 +from mongoengine.mongodb_support import get_mongodb_version, MONGODB_36 from mongoengine.queryset import (DoesNotExist, MultipleObjectsReturned, QuerySet, QuerySetManager, queryset_manager) -from tests.utils import requires_mongodb_gte_26, skip_pymongo3, get_mongodb_version, MONGODB_32 - -__all__ = ("QuerySetTest",) - class db_ops_tracker(query_counter): @@ -34,6 +30,12 @@ class db_ops_tracker(query_counter): return list(self.db.system.profile.find(ignore_query)) +def get_key_compat(mongo_ver): + ORDER_BY_KEY = 'sort' + CMD_QUERY_KEY = 'command' if mongo_ver >= MONGODB_36 else 'query' + return ORDER_BY_KEY, CMD_QUERY_KEY + + class QuerySetTest(unittest.TestCase): def setUp(self): @@ -88,7 +90,7 @@ class QuerySetTest(unittest.TestCase): results = list(people) self.assertIsInstance(results[0], self.Person) - self.assertIsInstance(results[0].id, (ObjectId, str, unicode)) + self.assertIsInstance(results[0].id, ObjectId) self.assertEqual(results[0], user_a) self.assertEqual(results[0].name, 'User A') @@ -159,6 +161,11 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(person.name, 'User B') self.assertEqual(person.age, None) + def test___getitem___invalid_index(self): + """Ensure slicing a queryset works as expected.""" + with self.assertRaises(TypeError): + self.Person.objects()['a'] + def test_slice(self): """Ensure slicing a queryset works as expected.""" user_a = self.Person.objects.create(name='User A', age=20) @@ -395,6 +402,16 @@ class QuerySetTest(unittest.TestCase): with self.assertRaises(ValueError): list(qs) + def test_batch_size_cloned(self): + class A(Document): + s = StringField() + + # test that batch size gets cloned + qs = A.objects.batch_size(5) + self.assertEqual(qs._batch_size, 5) + qs_clone = qs.clone() + self.assertEqual(qs_clone._batch_size, 5) + def test_update_write_concern(self): """Test that passing write_concern works""" self.Person.drop_collection() @@ -580,7 +597,6 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(post.comments[0].by, 'joe') self.assertEqual(post.comments[0].votes.score, 4) - @requires_mongodb_gte_26 def test_update_min_max(self): class Scores(Document): high_score = IntField() @@ -598,7 +614,6 @@ class QuerySetTest(unittest.TestCase): Scores.objects(id=scores.id).update(max__high_score=500) self.assertEqual(Scores.objects.get(id=scores.id).high_score, 1000) - @requires_mongodb_gte_26 def test_update_multiple(self): class Product(Document): item = StringField() @@ -850,11 +865,7 @@ class QuerySetTest(unittest.TestCase): with query_counter() as q: self.assertEqual(q, 0) Blog.objects.insert(blogs, load_bulk=False) - - if MONGO_VER == MONGODB_32: - self.assertEqual(q, 1) # 1 entry containing the list of inserts - else: - self.assertEqual(q, len(blogs)) # 1 entry per doc inserted + self.assertEqual(q, 1) # 1 entry containing the list of inserts self.assertEqual(Blog.objects.count(), len(blogs)) @@ -867,11 +878,7 @@ class QuerySetTest(unittest.TestCase): with query_counter() as q: self.assertEqual(q, 0) Blog.objects.insert(blogs) - - if MONGO_VER == MONGODB_32: - self.assertEqual(q, 2) # 1 for insert 1 for fetch - else: - self.assertEqual(q, len(blogs)+1) # + 1 to fetch all docs + self.assertEqual(q, 2) # 1 for insert 1 for fetch Blog.drop_collection() @@ -977,6 +984,29 @@ class QuerySetTest(unittest.TestCase): inserted_comment_id = Comment.objects.insert(comment, load_bulk=False) self.assertEqual(comment.id, inserted_comment_id) + def test_bulk_insert_accepts_doc_with_ids(self): + class Comment(Document): + id = IntField(primary_key=True) + + Comment.drop_collection() + + com1 = Comment(id=0) + com2 = Comment(id=1) + Comment.objects.insert([com1, com2]) + + def test_insert_raise_if_duplicate_in_constraint(self): + class Comment(Document): + id = IntField(primary_key=True) + + Comment.drop_collection() + + com1 = Comment(id=0) + + Comment.objects.insert(com1) + + with self.assertRaises(NotUniqueError): + Comment.objects.insert(com1) + def test_get_changed_fields_query_count(self): """Make sure we don't perform unnecessary db operations when none of document's fields were updated. @@ -1038,48 +1068,6 @@ class QuerySetTest(unittest.TestCase): org.save() # saves the org self.assertEqual(q, 2) - @skip_pymongo3 - def test_slave_okay(self): - """Ensures that a query can take slave_okay syntax. - Useless with PyMongo 3+ as well as with MongoDB 3+. - """ - person1 = self.Person(name="User A", age=20) - person1.save() - person2 = self.Person(name="User B", age=30) - person2.save() - - # Retrieve the first person from the database - person = self.Person.objects.slave_okay(True).first() - self.assertIsInstance(person, self.Person) - self.assertEqual(person.name, "User A") - self.assertEqual(person.age, 20) - - @requires_mongodb_gte_26 - @skip_pymongo3 - def test_cursor_args(self): - """Ensures the cursor args can be set as expected - """ - p = self.Person.objects - # Check default - self.assertEqual(p._cursor_args, - {'snapshot': False, 'slave_okay': False, 'timeout': True}) - - p = p.snapshot(False).slave_okay(False).timeout(False) - self.assertEqual(p._cursor_args, - {'snapshot': False, 'slave_okay': False, 'timeout': False}) - - p = p.snapshot(True).slave_okay(False).timeout(False) - self.assertEqual(p._cursor_args, - {'snapshot': True, 'slave_okay': False, 'timeout': False}) - - p = p.snapshot(True).slave_okay(True).timeout(False) - self.assertEqual(p._cursor_args, - {'snapshot': True, 'slave_okay': True, 'timeout': False}) - - p = p.snapshot(True).slave_okay(True).timeout(True) - self.assertEqual(p._cursor_args, - {'snapshot': True, 'slave_okay': True, 'timeout': True}) - def test_repeated_iteration(self): """Ensure that QuerySet rewinds itself one iteration finishes. """ @@ -1203,7 +1191,7 @@ class QuerySetTest(unittest.TestCase): """Ensure filters can be chained together. """ class Blog(Document): - id = StringField(unique=True, primary_key=True) + id = StringField(primary_key=True) class BlogPost(Document): blog = ReferenceField(Blog) @@ -1314,8 +1302,7 @@ class QuerySetTest(unittest.TestCase): """Ensure that the default ordering can be cleared by calling order_by() w/o any arguments. """ - MONGO_VER = self.mongodb_version - ORDER_BY_KEY = 'sort' if MONGO_VER == MONGODB_32 else '$orderby' + ORDER_BY_KEY, CMD_QUERY_KEY = get_key_compat(self.mongodb_version) class BlogPost(Document): title = StringField() @@ -1332,7 +1319,7 @@ class QuerySetTest(unittest.TestCase): BlogPost.objects.filter(title='whatever').first() self.assertEqual(len(q.get_ops()), 1) self.assertEqual( - q.get_ops()[0]['query'][ORDER_BY_KEY], + q.get_ops()[0][CMD_QUERY_KEY][ORDER_BY_KEY], {'published_date': -1} ) @@ -1340,14 +1327,14 @@ class QuerySetTest(unittest.TestCase): with db_ops_tracker() as q: BlogPost.objects.filter(title='whatever').order_by().first() self.assertEqual(len(q.get_ops()), 1) - self.assertNotIn(ORDER_BY_KEY, q.get_ops()[0]['query']) + self.assertNotIn(ORDER_BY_KEY, q.get_ops()[0][CMD_QUERY_KEY]) # calling an explicit order_by should use a specified sort with db_ops_tracker() as q: BlogPost.objects.filter(title='whatever').order_by('published_date').first() self.assertEqual(len(q.get_ops()), 1) self.assertEqual( - q.get_ops()[0]['query'][ORDER_BY_KEY], + q.get_ops()[0][CMD_QUERY_KEY][ORDER_BY_KEY], {'published_date': 1} ) @@ -1356,13 +1343,12 @@ class QuerySetTest(unittest.TestCase): qs = BlogPost.objects.filter(title='whatever').order_by('published_date') qs.order_by().first() self.assertEqual(len(q.get_ops()), 1) - self.assertNotIn(ORDER_BY_KEY, q.get_ops()[0]['query']) + self.assertNotIn(ORDER_BY_KEY, q.get_ops()[0][CMD_QUERY_KEY]) def test_no_ordering_for_get(self): """ Ensure that Doc.objects.get doesn't use any ordering. """ - MONGO_VER = self.mongodb_version - ORDER_BY_KEY = 'sort' if MONGO_VER == MONGODB_32 else '$orderby' + ORDER_BY_KEY, CMD_QUERY_KEY = get_key_compat(self.mongodb_version) class BlogPost(Document): title = StringField() @@ -1378,13 +1364,13 @@ class QuerySetTest(unittest.TestCase): with db_ops_tracker() as q: BlogPost.objects.get(title='whatever') self.assertEqual(len(q.get_ops()), 1) - self.assertNotIn(ORDER_BY_KEY, q.get_ops()[0]['query']) + self.assertNotIn(ORDER_BY_KEY, q.get_ops()[0][CMD_QUERY_KEY]) # Ordering should be ignored for .get even if we set it explicitly with db_ops_tracker() as q: BlogPost.objects.order_by('-title').get(title='whatever') self.assertEqual(len(q.get_ops()), 1) - self.assertNotIn(ORDER_BY_KEY, q.get_ops()[0]['query']) + self.assertNotIn(ORDER_BY_KEY, q.get_ops()[0][CMD_QUERY_KEY]) def test_find_embedded(self): """Ensure that an embedded document is properly returned from @@ -2033,7 +2019,6 @@ class QuerySetTest(unittest.TestCase): pymongo_doc = BlogPost.objects.as_pymongo().first() self.assertNotIn('title', pymongo_doc) - @requires_mongodb_gte_26 def test_update_push_with_position(self): """Ensure that the 'push' update with position works properly. """ @@ -2184,6 +2169,40 @@ class QuerySetTest(unittest.TestCase): Site.objects(id=s.id).update_one( pull_all__collaborators__helpful__name=['Ross']) + def test_pull_from_nested_embedded_using_in_nin(self): + """Ensure that the 'pull' update operation works on embedded documents using 'in' and 'nin' operators. + """ + + class User(EmbeddedDocument): + name = StringField() + + def __unicode__(self): + return '%s' % self.name + + class Collaborator(EmbeddedDocument): + helpful = ListField(EmbeddedDocumentField(User)) + unhelpful = ListField(EmbeddedDocumentField(User)) + + class Site(Document): + name = StringField(max_length=75, unique=True, required=True) + collaborators = EmbeddedDocumentField(Collaborator) + + Site.drop_collection() + + a = User(name='Esteban') + b = User(name='Frank') + x = User(name='Harry') + y = User(name='John') + + s = Site(name="test", collaborators=Collaborator( + helpful=[a, b], unhelpful=[x, y])).save() + + Site.objects(id=s.id).update_one(pull__collaborators__helpful__name__in=['Esteban']) # Pull a + self.assertEqual(Site.objects.first().collaborators['helpful'], [b]) + + Site.objects(id=s.id).update_one(pull__collaborators__unhelpful__name__nin=['John']) # Pull x + self.assertEqual(Site.objects.first().collaborators['unhelpful'], [y]) + def test_pull_from_nested_mapfield(self): class Collaborator(EmbeddedDocument): @@ -2233,6 +2252,19 @@ class QuerySetTest(unittest.TestCase): bar.reload() self.assertEqual(len(bar.foos), 0) + def test_update_one_check_return_with_full_result(self): + class BlogTag(Document): + name = StringField(required=True) + + BlogTag.drop_collection() + + BlogTag(name='garbage').save() + default_update = BlogTag.objects.update_one(name='new') + self.assertEqual(default_update, 1) + + full_result_update = BlogTag.objects.update_one(name='new', full_result=True) + self.assertIsInstance(full_result_update, UpdateResult) + def test_update_one_pop_generic_reference(self): class BlogTag(Document): @@ -2510,8 +2542,9 @@ class QuerySetTest(unittest.TestCase): def test_comment(self): """Make sure adding a comment to the query gets added to the query""" MONGO_VER = self.mongodb_version - QUERY_KEY = 'filter' if MONGO_VER == MONGODB_32 else '$query' - COMMENT_KEY = 'comment' if MONGO_VER == MONGODB_32 else '$comment' + _, CMD_QUERY_KEY = get_key_compat(MONGO_VER) + QUERY_KEY = 'filter' + COMMENT_KEY = 'comment' class User(Document): age = IntField() @@ -2528,8 +2561,8 @@ class QuerySetTest(unittest.TestCase): ops = q.get_ops() self.assertEqual(len(ops), 2) for op in ops: - self.assertEqual(op['query'][QUERY_KEY], {'age': {'$gte': 18}}) - self.assertEqual(op['query'][COMMENT_KEY], 'looking for an adult') + self.assertEqual(op[CMD_QUERY_KEY][QUERY_KEY], {'age': {'$gte': 18}}) + self.assertEqual(op[CMD_QUERY_KEY][COMMENT_KEY], 'looking for an adult') def test_map_reduce(self): """Ensure map/reduce is both mapping and reducing. @@ -3325,7 +3358,6 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(Foo.objects.distinct("bar"), [bar]) - @requires_mongodb_gte_26 def test_text_indexes(self): class News(Document): title = StringField() @@ -3335,7 +3367,7 @@ class QuerySetTest(unittest.TestCase): meta = {'indexes': [ {'fields': ['$title', "$content"], 'default_language': 'portuguese', - 'weight': {'title': 10, 'content': 2} + 'weights': {'title': 10, 'content': 2} } ]} @@ -3393,10 +3425,7 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(query.count(), 3) self.assertEqual(query._query, {'$text': {'$search': 'brasil'}}) cursor_args = query._cursor_args - if not IS_PYMONGO_3: - cursor_args_fields = cursor_args['fields'] - else: - cursor_args_fields = cursor_args['projection'] + cursor_args_fields = cursor_args['projection'] self.assertEqual( cursor_args_fields, {'_text_score': {'$meta': 'textScore'}}) @@ -3412,7 +3441,6 @@ class QuerySetTest(unittest.TestCase): 'brasil').order_by('$text_score').first() self.assertEqual(item.get_text_score(), max_text_score) - @requires_mongodb_gte_26 def test_distinct_handles_references_to_alias(self): register_connection('testdb', 'mongoenginetest2') @@ -3548,6 +3576,11 @@ class QuerySetTest(unittest.TestCase): opts = {"deleted": False} return qryset(**opts) + @queryset_manager + def objects_1_arg(qryset): + opts = {"deleted": False} + return qryset(**opts) + @queryset_manager def music_posts(doc_cls, queryset, deleted=False): return queryset(tags='music', @@ -3562,6 +3595,8 @@ class QuerySetTest(unittest.TestCase): self.assertEqual([p.id for p in BlogPost.objects()], [post1.id, post2.id, post3.id]) + self.assertEqual([p.id for p in BlogPost.objects_1_arg()], + [post1.id, post2.id, post3.id]) self.assertEqual([p.id for p in BlogPost.music_posts()], [post1.id, post2.id]) @@ -4026,7 +4061,7 @@ class QuerySetTest(unittest.TestCase): info = [(value['key'], value.get('unique', False), value.get('sparse', False)) - for key, value in info.iteritems()] + for key, value in iteritems(info)] self.assertIn(([('_cls', 1), ('message', 1)], False, False), info) def test_where(self): @@ -4037,7 +4072,7 @@ class QuerySetTest(unittest.TestCase): fielda = IntField() fieldb = IntField() - IntPair.objects._collection.remove() + IntPair.drop_collection() a = IntPair(fielda=1, fieldb=1) b = IntPair(fielda=1, fieldb=2) @@ -4489,11 +4524,7 @@ class QuerySetTest(unittest.TestCase): bars = list(Bar.objects(read_preference=ReadPreference.PRIMARY)) self.assertEqual([], bars) - if not IS_PYMONGO_3: - error_class = ConfigurationError - else: - error_class = TypeError - self.assertRaises(error_class, Bar.objects, read_preference='Primary') + self.assertRaises(TypeError, Bar.objects, read_preference='Primary') # read_preference as a kwarg bars = Bar.objects(read_preference=ReadPreference.SECONDARY_PREFERRED) @@ -4541,7 +4572,6 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(bars._cursor._Cursor__read_preference, ReadPreference.SECONDARY_PREFERRED) - @requires_mongodb_gte_26 def test_read_preference_aggregation_framework(self): class Bar(Document): txt = StringField() @@ -4553,12 +4583,8 @@ class QuerySetTest(unittest.TestCase): bars = Bar.objects \ .read_preference(ReadPreference.SECONDARY_PREFERRED) \ .aggregate() - if IS_PYMONGO_3: - self.assertEqual(bars._CommandCursor__collection.read_preference, - ReadPreference.SECONDARY_PREFERRED) - else: - self.assertNotEqual(bars._CommandCursor__collection.read_preference, - ReadPreference.SECONDARY_PREFERRED) + self.assertEqual(bars._CommandCursor__collection.read_preference, + ReadPreference.SECONDARY_PREFERRED) def test_json_simple(self): @@ -4580,9 +4606,6 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(doc_objects, Doc.objects.from_json(json_data)) def test_json_complex(self): - if pymongo.version_tuple[0] <= 2 and pymongo.version_tuple[1] <= 3: - raise SkipTest("Need pymongo 2.4 as has a fix for DBRefs") - class EmbeddedDoc(EmbeddedDocument): pass @@ -4949,6 +4972,38 @@ class QuerySetTest(unittest.TestCase): people.count() self.assertEqual(q, 3) + def test_no_cached_queryset__repr__(self): + class Person(Document): + name = StringField() + + Person.drop_collection() + qs = Person.objects.no_cache() + self.assertEqual(repr(qs), '[]') + + def test_no_cached_on_a_cached_queryset_raise_error(self): + class Person(Document): + name = StringField() + + Person.drop_collection() + Person(name='a').save() + qs = Person.objects() + _ = list(qs) + with self.assertRaises(OperationError) as ctx_err: + qs.no_cache() + self.assertEqual("QuerySet already cached", str(ctx_err.exception)) + + def test_no_cached_queryset_no_cache_back_to_cache(self): + class Person(Document): + name = StringField() + + Person.drop_collection() + qs = Person.objects() + self.assertIsInstance(qs, QuerySet) + qs = qs.no_cache() + self.assertIsInstance(qs, QuerySetNoCache) + qs = qs.cache() + self.assertIsInstance(qs, QuerySet) + def test_cache_not_cloned(self): class User(Document): @@ -5117,7 +5172,7 @@ class QuerySetTest(unittest.TestCase): def test_query_reference_to_custom_pk_doc(self): class A(Document): - id = StringField(unique=True, primary_key=True) + id = StringField(primary_key=True) class B(Document): a = ReferenceField(A) @@ -5221,8 +5276,7 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(op['nreturned'], 1) def test_bool_with_ordering(self): - MONGO_VER = self.mongodb_version - ORDER_BY_KEY = 'sort' if MONGO_VER == MONGODB_32 else '$orderby' + ORDER_BY_KEY, CMD_QUERY_KEY = get_key_compat(self.mongodb_version) class Person(Document): name = StringField() @@ -5241,21 +5295,22 @@ class QuerySetTest(unittest.TestCase): op = q.db.system.profile.find({"ns": {"$ne": "%s.system.indexes" % q.db.name}})[0] - self.assertNotIn(ORDER_BY_KEY, op['query']) + self.assertNotIn(ORDER_BY_KEY, op[CMD_QUERY_KEY]) # Check that normal query uses orderby qs2 = Person.objects.order_by('name') - with query_counter() as p: + with query_counter() as q: for x in qs2: pass - op = p.db.system.profile.find({"ns": + op = q.db.system.profile.find({"ns": {"$ne": "%s.system.indexes" % q.db.name}})[0] - self.assertIn(ORDER_BY_KEY, op['query']) + self.assertIn(ORDER_BY_KEY, op[CMD_QUERY_KEY]) def test_bool_with_ordering_from_meta_dict(self): + ORDER_BY_KEY, CMD_QUERY_KEY = get_key_compat(self.mongodb_version) class Person(Document): name = StringField() @@ -5277,14 +5332,13 @@ class QuerySetTest(unittest.TestCase): op = q.db.system.profile.find({"ns": {"$ne": "%s.system.indexes" % q.db.name}})[0] - self.assertNotIn('$orderby', op['query'], + self.assertNotIn('$orderby', op[CMD_QUERY_KEY], 'BaseQuerySet must remove orderby from meta in boolen test') self.assertEqual(Person.objects.first().name, 'A') self.assertTrue(Person.objects._has_data(), 'Cursor has data and returned False') - @requires_mongodb_gte_26 def test_queryset_aggregation_framework(self): class Person(Document): name = StringField() @@ -5293,13 +5347,9 @@ class QuerySetTest(unittest.TestCase): Person.drop_collection() p1 = Person(name="Isabella Luanna", age=16) - p1.save() - p2 = Person(name="Wilson Junior", age=21) - p2.save() - p3 = Person(name="Sandra Mara", age=37) - p3.save() + Person.objects.insert([p1, p2, p3]) data = Person.objects(age__lte=22).aggregate( {'$project': {'name': {'$toUpper': '$name'}}} @@ -5330,6 +5380,179 @@ class QuerySetTest(unittest.TestCase): {'_id': None, 'avg': 29, 'total': 2} ]) + def test_queryset_aggregation_with_skip(self): + class Person(Document): + name = StringField() + age = IntField() + + Person.drop_collection() + + p1 = Person(name="Isabella Luanna", age=16) + p2 = Person(name="Wilson Junior", age=21) + p3 = Person(name="Sandra Mara", age=37) + Person.objects.insert([p1, p2, p3]) + + data = Person.objects.skip(1).aggregate( + {'$project': {'name': {'$toUpper': '$name'}}} + ) + + self.assertEqual(list(data), [ + {'_id': p2.pk, 'name': "WILSON JUNIOR"}, + {'_id': p3.pk, 'name': "SANDRA MARA"} + ]) + + def test_queryset_aggregation_with_limit(self): + class Person(Document): + name = StringField() + age = IntField() + + Person.drop_collection() + + p1 = Person(name="Isabella Luanna", age=16) + p2 = Person(name="Wilson Junior", age=21) + p3 = Person(name="Sandra Mara", age=37) + Person.objects.insert([p1, p2, p3]) + + data = Person.objects.limit(1).aggregate( + {'$project': {'name': {'$toUpper': '$name'}}} + ) + + self.assertEqual(list(data), [ + {'_id': p1.pk, 'name': "ISABELLA LUANNA"} + ]) + + def test_queryset_aggregation_with_sort(self): + class Person(Document): + name = StringField() + age = IntField() + + Person.drop_collection() + + p1 = Person(name="Isabella Luanna", age=16) + p2 = Person(name="Wilson Junior", age=21) + p3 = Person(name="Sandra Mara", age=37) + Person.objects.insert([p1, p2, p3]) + + data = Person.objects.order_by('name').aggregate( + {'$project': {'name': {'$toUpper': '$name'}}} + ) + + self.assertEqual(list(data), [ + {'_id': p1.pk, 'name': "ISABELLA LUANNA"}, + {'_id': p3.pk, 'name': "SANDRA MARA"}, + {'_id': p2.pk, 'name': "WILSON JUNIOR"} + ]) + + def test_queryset_aggregation_with_skip_with_limit(self): + class Person(Document): + name = StringField() + age = IntField() + + Person.drop_collection() + + p1 = Person(name="Isabella Luanna", age=16) + p2 = Person(name="Wilson Junior", age=21) + p3 = Person(name="Sandra Mara", age=37) + Person.objects.insert([p1, p2, p3]) + + data = list( + Person.objects.skip(1).limit(1).aggregate( + {'$project': {'name': {'$toUpper': '$name'}}} + ) + ) + + self.assertEqual(list(data), [ + {'_id': p2.pk, 'name': "WILSON JUNIOR"}, + ]) + + # Make sure limit/skip chaining order has no impact + data2 = Person.objects.limit(1).skip(1).aggregate( + {'$project': {'name': {'$toUpper': '$name'}}} + ) + + self.assertEqual(data, list(data2)) + + def test_queryset_aggregation_with_sort_with_limit(self): + class Person(Document): + name = StringField() + age = IntField() + + Person.drop_collection() + + p1 = Person(name="Isabella Luanna", age=16) + p2 = Person(name="Wilson Junior", age=21) + p3 = Person(name="Sandra Mara", age=37) + Person.objects.insert([p1, p2, p3]) + + data = Person.objects.order_by('name').limit(2).aggregate( + {'$project': {'name': {'$toUpper': '$name'}}} + ) + + self.assertEqual(list(data), [ + {'_id': p1.pk, 'name': "ISABELLA LUANNA"}, + {'_id': p3.pk, 'name': "SANDRA MARA"} + ]) + + # Verify adding limit/skip steps works as expected + data = Person.objects.order_by('name').limit(2).aggregate( + {'$project': {'name': {'$toUpper': '$name'}}}, + {'$limit': 1}, + ) + + self.assertEqual(list(data), [ + {'_id': p1.pk, 'name': "ISABELLA LUANNA"}, + ]) + + data = Person.objects.order_by('name').limit(2).aggregate( + {'$project': {'name': {'$toUpper': '$name'}}}, + {'$skip': 1}, + {'$limit': 1}, + ) + + self.assertEqual(list(data), [ + {'_id': p3.pk, 'name': "SANDRA MARA"}, + ]) + + def test_queryset_aggregation_with_sort_with_skip(self): + class Person(Document): + name = StringField() + age = IntField() + + Person.drop_collection() + + p1 = Person(name="Isabella Luanna", age=16) + p2 = Person(name="Wilson Junior", age=21) + p3 = Person(name="Sandra Mara", age=37) + Person.objects.insert([p1, p2, p3]) + + data = Person.objects.order_by('name').skip(2).aggregate( + {'$project': {'name': {'$toUpper': '$name'}}} + ) + + self.assertEqual(list(data), [ + {'_id': p2.pk, 'name': "WILSON JUNIOR"} + ]) + + def test_queryset_aggregation_with_sort_with_skip_with_limit(self): + class Person(Document): + name = StringField() + age = IntField() + + Person.drop_collection() + + p1 = Person(name="Isabella Luanna", age=16) + p2 = Person(name="Wilson Junior", age=21) + p3 = Person(name="Sandra Mara", age=37) + Person.objects.insert([p1, p2, p3]) + + data = Person.objects.order_by('name').skip(1).limit(1).aggregate( + {'$project': {'name': {'$toUpper': '$name'}}} + ) + + self.assertEqual(list(data), [ + {'_id': p3.pk, 'name': "SANDRA MARA"} + ]) + def test_delete_count(self): [self.Person(name="User {0}".format(i), age=i * 10).save() for i in range(1, 4)] self.assertEqual(self.Person.objects().delete(), 3) # test ordinary QuerySey delete count @@ -5363,8 +5586,8 @@ class QuerySetTest(unittest.TestCase): Animal(is_mamal=False).save() Cat(is_mamal=True, whiskers_length=5.1).save() ScottishCat(is_mamal=True, folded_ears=True).save() - self.assertEquals(Animal.objects(folded_ears=True).count(), 1) - self.assertEquals(Animal.objects(whiskers_length=5.1).count(), 1) + self.assertEqual(Animal.objects(folded_ears=True).count(), 1) + self.assertEqual(Animal.objects(whiskers_length=5.1).count(), 1) def test_loop_over_invalid_id_does_not_crash(self): class Person(Document): @@ -5372,7 +5595,7 @@ class QuerySetTest(unittest.TestCase): Person.drop_collection() - Person._get_collection().insert({'name': 'a', 'id': ''}) + Person._get_collection().insert_one({'name': 'a', 'id': ''}) for p in Person.objects(): self.assertEqual(p.name, 'a') diff --git a/tests/queryset/transform.py b/tests/queryset/transform.py index 8064f09c..2c2d018c 100644 --- a/tests/queryset/transform.py +++ b/tests/queryset/transform.py @@ -71,6 +71,14 @@ class TransformTest(unittest.TestCase): update = transform.update(BlogPost, push_all__tags=['mongo', 'db']) self.assertEqual(update, {'$push': {'tags': {'$each': ['mongo', 'db']}}}) + def test_transform_update_no_operator_default_to_set(self): + """Ensure the differences in behvaior between 'push' and 'push_all'""" + class BlogPost(Document): + tags = ListField(StringField()) + + update = transform.update(BlogPost, tags=['mongo', 'db']) + self.assertEqual(update, {'$set': {'tags': ['mongo', 'db']}}) + def test_query_field_name(self): """Ensure that the correct field name is used when querying. """ @@ -283,6 +291,11 @@ class TransformTest(unittest.TestCase): update = transform.update(MainDoc, pull__content__heading='xyz') self.assertEqual(update, {'$pull': {'content.heading': 'xyz'}}) + update = transform.update(MainDoc, pull__content__text__word__in=['foo', 'bar']) + self.assertEqual(update, {'$pull': {'content.text': {'word': {'$in': ['foo', 'bar']}}}}) + + update = transform.update(MainDoc, pull__content__text__word__nin=['foo', 'bar']) + self.assertEqual(update, {'$pull': {'content.text': {'word': {'$nin': ['foo', 'bar']}}}}) if __name__ == '__main__': unittest.main() diff --git a/tests/queryset/visitor.py b/tests/queryset/visitor.py index 8261faae..22d274a8 100644 --- a/tests/queryset/visitor.py +++ b/tests/queryset/visitor.py @@ -275,7 +275,6 @@ class QTest(unittest.TestCase): with self.assertRaises(InvalidQueryError): self.Person.objects.filter('user1') - def test_q_regex(self): """Ensure that Q objects can be queried using regexes. """ diff --git a/tests/test_common.py b/tests/test_common.py new file mode 100644 index 00000000..04ad5b34 --- /dev/null +++ b/tests/test_common.py @@ -0,0 +1,15 @@ +import unittest + +from mongoengine.common import _import_class +from mongoengine import Document + + +class TestCommon(unittest.TestCase): + + def test__import_class(self): + doc_cls = _import_class("Document") + self.assertIs(doc_cls, Document) + + def test__import_class_raise_if_not_known(self): + with self.assertRaises(ValueError): + _import_class("UnknownClass") diff --git a/tests/test_connection.py b/tests/test_connection.py index 88d63cdb..d3fcc395 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1,5 +1,8 @@ import datetime -from pymongo.errors import OperationFailure + +from pymongo import MongoClient +from pymongo.errors import OperationFailure, InvalidName +from pymongo import ReadPreference try: import unittest2 as unittest @@ -12,23 +15,27 @@ from bson.tz_util import utc from mongoengine import ( connect, register_connection, - Document, DateTimeField -) -from mongoengine.python_support import IS_PYMONGO_3 + Document, DateTimeField, + disconnect_all, StringField) import mongoengine.connection from mongoengine.connection import (MongoEngineConnectionError, get_db, - get_connection) + get_connection, disconnect, DEFAULT_DATABASE_NAME) def get_tz_awareness(connection): - if not IS_PYMONGO_3: - return connection.tz_aware - else: - return connection.codec_options.tz_aware + return connection.codec_options.tz_aware class ConnectionTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + disconnect_all() + + @classmethod + def tearDownClass(cls): + disconnect_all() + def tearDown(self): mongoengine.connection._connection_settings = {} mongoengine.connection._connections = {} @@ -49,6 +56,147 @@ class ConnectionTest(unittest.TestCase): conn = get_connection('testdb') self.assertIsInstance(conn, pymongo.mongo_client.MongoClient) + def test_connect_disconnect_works_properly(self): + class History1(Document): + name = StringField() + meta = {'db_alias': 'db1'} + + class History2(Document): + name = StringField() + meta = {'db_alias': 'db2'} + + connect('db1', alias='db1') + connect('db2', alias='db2') + + History1.drop_collection() + History2.drop_collection() + + h = History1(name='default').save() + h1 = History2(name='db1').save() + + self.assertEqual(list(History1.objects().as_pymongo()), + [{'_id': h.id, 'name': 'default'}]) + self.assertEqual(list(History2.objects().as_pymongo()), + [{'_id': h1.id, 'name': 'db1'}]) + + disconnect('db1') + disconnect('db2') + + with self.assertRaises(MongoEngineConnectionError): + list(History1.objects().as_pymongo()) + + with self.assertRaises(MongoEngineConnectionError): + list(History2.objects().as_pymongo()) + + connect('db1', alias='db1') + connect('db2', alias='db2') + + self.assertEqual(list(History1.objects().as_pymongo()), + [{'_id': h.id, 'name': 'default'}]) + self.assertEqual(list(History2.objects().as_pymongo()), + [{'_id': h1.id, 'name': 'db1'}]) + + def test_connect_different_documents_to_different_database(self): + class History(Document): + name = StringField() + + class History1(Document): + name = StringField() + meta = {'db_alias': 'db1'} + + class History2(Document): + name = StringField() + meta = {'db_alias': 'db2'} + + connect() + connect('db1', alias='db1') + connect('db2', alias='db2') + + History.drop_collection() + History1.drop_collection() + History2.drop_collection() + + h = History(name='default').save() + h1 = History1(name='db1').save() + h2 = History2(name='db2').save() + + self.assertEqual(History._collection.database.name, DEFAULT_DATABASE_NAME) + self.assertEqual(History1._collection.database.name, 'db1') + self.assertEqual(History2._collection.database.name, 'db2') + + self.assertEqual(list(History.objects().as_pymongo()), + [{'_id': h.id, 'name': 'default'}]) + self.assertEqual(list(History1.objects().as_pymongo()), + [{'_id': h1.id, 'name': 'db1'}]) + self.assertEqual(list(History2.objects().as_pymongo()), + [{'_id': h2.id, 'name': 'db2'}]) + + def test_connect_fails_if_connect_2_times_with_default_alias(self): + connect('mongoenginetest') + + with self.assertRaises(MongoEngineConnectionError) as ctx_err: + connect('mongoenginetest2') + self.assertEqual("A different connection with alias `default` was already registered. Use disconnect() first", str(ctx_err.exception)) + + def test_connect_fails_if_connect_2_times_with_custom_alias(self): + connect('mongoenginetest', alias='alias1') + + with self.assertRaises(MongoEngineConnectionError) as ctx_err: + connect('mongoenginetest2', alias='alias1') + + self.assertEqual("A different connection with alias `alias1` was already registered. Use disconnect() first", str(ctx_err.exception)) + + def test_connect_fails_if_similar_connection_settings_arent_defined_the_same_way(self): + """Intended to keep the detecton function simple but robust""" + db_name = 'mongoenginetest' + db_alias = 'alias1' + connect(db=db_name, alias=db_alias, host='localhost', port=27017) + + with self.assertRaises(MongoEngineConnectionError): + connect(host='mongodb://localhost:27017/%s' % db_name, alias=db_alias) + + def test_connect_passes_silently_connect_multiple_times_with_same_config(self): + # test default connection to `test` + connect() + connect() + self.assertEqual(len(mongoengine.connection._connections), 1) + connect('test01', alias='test01') + connect('test01', alias='test01') + self.assertEqual(len(mongoengine.connection._connections), 2) + connect(host='mongodb://localhost:27017/mongoenginetest02', alias='test02') + connect(host='mongodb://localhost:27017/mongoenginetest02', alias='test02') + self.assertEqual(len(mongoengine.connection._connections), 3) + + def test_connect_with_invalid_db_name(self): + """Ensure that connect() method fails fast if db name is invalid + """ + with self.assertRaises(InvalidName): + connect('mongomock://localhost') + + def test_connect_with_db_name_external(self): + """Ensure that connect() works if db name is $external + """ + """Ensure that the connect() method works properly.""" + connect('$external') + + conn = get_connection() + self.assertIsInstance(conn, pymongo.mongo_client.MongoClient) + + db = get_db() + self.assertIsInstance(db, pymongo.database.Database) + self.assertEqual(db.name, '$external') + + connect('$external', alias='testdb') + conn = get_connection('testdb') + self.assertIsInstance(conn, pymongo.mongo_client.MongoClient) + + def test_connect_with_invalid_db_name_type(self): + """Ensure that connect() method fails fast if db name has invalid type + """ + with self.assertRaises(TypeError): + non_string_db_name = ['e. g. list instead of a string'] + connect(non_string_db_name) + def test_connect_in_mocking(self): """Ensure that the connect() method works properly in mocking. """ @@ -99,11 +247,11 @@ class ConnectionTest(unittest.TestCase): conn = get_connection() self.assertIsInstance(conn, mongomock.MongoClient) - connect(host=['mongodb://localhost'], is_mock=True, alias='testdb2') + connect(host=['mongodb://localhost'], is_mock=True, alias='testdb2') conn = get_connection('testdb2') self.assertIsInstance(conn, mongomock.MongoClient) - connect(host=['localhost'], is_mock=True, alias='testdb3') + connect(host=['localhost'], is_mock=True, alias='testdb3') conn = get_connection('testdb3') self.assertIsInstance(conn, mongomock.MongoClient) @@ -111,21 +259,141 @@ class ConnectionTest(unittest.TestCase): conn = get_connection('testdb4') self.assertIsInstance(conn, mongomock.MongoClient) - connect(host=['mongodb://localhost:27017', 'mongodb://localhost:27018'], is_mock=True, alias='testdb5') + connect(host=['mongodb://localhost:27017', 'mongodb://localhost:27018'], is_mock=True, alias='testdb5') conn = get_connection('testdb5') self.assertIsInstance(conn, mongomock.MongoClient) - connect(host=['localhost:27017', 'localhost:27018'], is_mock=True, alias='testdb6') + connect(host=['localhost:27017', 'localhost:27018'], is_mock=True, alias='testdb6') conn = get_connection('testdb6') self.assertIsInstance(conn, mongomock.MongoClient) - def test_disconnect(self): - """Ensure that the disconnect() method works properly - """ + def test_disconnect_cleans_globals(self): + """Ensure that the disconnect() method cleans the globals objects""" + connections = mongoengine.connection._connections + dbs = mongoengine.connection._dbs + connection_settings = mongoengine.connection._connection_settings + + connect('mongoenginetest') + + self.assertEqual(len(connections), 1) + self.assertEqual(len(dbs), 0) + self.assertEqual(len(connection_settings), 1) + + class TestDoc(Document): + pass + + TestDoc.drop_collection() # triggers the db + self.assertEqual(len(dbs), 1) + + disconnect() + self.assertEqual(len(connections), 0) + self.assertEqual(len(dbs), 0) + self.assertEqual(len(connection_settings), 0) + + def test_disconnect_cleans_cached_collection_attribute_in_document(self): + """Ensure that the disconnect() method works properly""" conn1 = connect('mongoenginetest') - mongoengine.connection.disconnect() - conn2 = connect('mongoenginetest') - self.assertTrue(conn1 is not conn2) + + class History(Document): + pass + + self.assertIsNone(History._collection) + + History.drop_collection() + + History.objects.first() # will trigger the caching of _collection attribute + self.assertIsNotNone(History._collection) + + disconnect() + + self.assertIsNone(History._collection) + + with self.assertRaises(MongoEngineConnectionError) as ctx_err: + History.objects.first() + self.assertEqual("You have not defined a default connection", str(ctx_err.exception)) + + def test_connect_disconnect_works_on_same_document(self): + """Ensure that the connect/disconnect works properly with a single Document""" + db1 = 'db1' + db2 = 'db2' + + # Ensure freshness of the 2 databases through pymongo + client = MongoClient('localhost', 27017) + client.drop_database(db1) + client.drop_database(db2) + + # Save in db1 + connect(db1) + + class User(Document): + name = StringField(required=True) + + user1 = User(name='John is in db1').save() + disconnect() + + # Make sure save doesnt work at this stage + with self.assertRaises(MongoEngineConnectionError): + User(name='Wont work').save() + + # Save in db2 + connect(db2) + user2 = User(name='Bob is in db2').save() + disconnect() + + db1_users = list(client[db1].user.find()) + self.assertEqual(db1_users, [{'_id': user1.id, 'name': 'John is in db1'}]) + db2_users = list(client[db2].user.find()) + self.assertEqual(db2_users, [{'_id': user2.id, 'name': 'Bob is in db2'}]) + + def test_disconnect_silently_pass_if_alias_does_not_exist(self): + connections = mongoengine.connection._connections + self.assertEqual(len(connections), 0) + disconnect(alias='not_exist') + + def test_disconnect_all(self): + connections = mongoengine.connection._connections + dbs = mongoengine.connection._dbs + connection_settings = mongoengine.connection._connection_settings + + connect('mongoenginetest') + connect('mongoenginetest2', alias='db1') + + class History(Document): + pass + + class History1(Document): + name = StringField() + meta = {'db_alias': 'db1'} + + History.drop_collection() # will trigger the caching of _collection attribute + History.objects.first() + History1.drop_collection() + History1.objects.first() + + self.assertIsNotNone(History._collection) + self.assertIsNotNone(History1._collection) + + self.assertEqual(len(connections), 2) + self.assertEqual(len(dbs), 2) + self.assertEqual(len(connection_settings), 2) + + disconnect_all() + + self.assertIsNone(History._collection) + self.assertIsNone(History1._collection) + + self.assertEqual(len(connections), 0) + self.assertEqual(len(dbs), 0) + self.assertEqual(len(connection_settings), 0) + + with self.assertRaises(MongoEngineConnectionError): + History.objects.first() + + with self.assertRaises(MongoEngineConnectionError): + History1.objects.first() + + def test_disconnect_all_silently_pass_if_no_connection_exist(self): + disconnect_all() def test_sharing_connections(self): """Ensure that connections are shared when the connection settings are exactly the same @@ -136,29 +404,19 @@ class ConnectionTest(unittest.TestCase): connect('mongoenginetests', alias='testdb2') actual_connection = get_connection('testdb2') - # Handle PyMongo 3+ Async Connection - if IS_PYMONGO_3: - # Ensure we are connected, throws ServerSelectionTimeoutError otherwise. - # Purposely not catching exception to fail test if thrown. - expected_connection.server_info() + expected_connection.server_info() self.assertEqual(expected_connection, actual_connection) def test_connect_uri(self): """Ensure that the connect() method works properly with URIs.""" c = connect(db='mongoenginetest', alias='admin') - c.admin.system.users.remove({}) - c.mongoenginetest.system.users.remove({}) + c.admin.system.users.delete_many({}) + c.mongoenginetest.system.users.delete_many({}) - c.admin.add_user("admin", "password") + c.admin.command("createUser", "admin", pwd="password", roles=["root"]) c.admin.authenticate("admin", "password") - c.mongoenginetest.add_user("username", "password") - - if not IS_PYMONGO_3: - self.assertRaises( - MongoEngineConnectionError, connect, 'testdb_uri_bad', - host='mongodb://test:password@localhost' - ) + c.admin.command("createUser", "username", pwd="password", roles=["dbOwner"]) connect("testdb_uri", host='mongodb://username:password@localhost/mongoenginetest') @@ -169,8 +427,8 @@ class ConnectionTest(unittest.TestCase): self.assertIsInstance(db, pymongo.database.Database) self.assertEqual(db.name, 'mongoenginetest') - c.admin.system.users.remove({}) - c.mongoenginetest.system.users.remove({}) + c.admin.system.users.delete_many({}) + c.mongoenginetest.system.users.delete_many({}) def test_connect_uri_without_db(self): """Ensure connect() method works properly if the URI doesn't @@ -217,23 +475,16 @@ class ConnectionTest(unittest.TestCase): """ # Create users c = connect('mongoenginetest') - c.admin.system.users.remove({}) - c.admin.add_user('username2', 'password') + + c.admin.system.users.delete_many({}) + c.admin.command("createUser", "username2", pwd="password", roles=["dbOwner"]) # Authentication fails without "authSource" - if IS_PYMONGO_3: - test_conn = connect( - 'mongoenginetest', alias='test1', - host='mongodb://username2:password@localhost/mongoenginetest' - ) - self.assertRaises(OperationFailure, test_conn.server_info) - else: - self.assertRaises( - MongoEngineConnectionError, - connect, 'mongoenginetest', alias='test1', - host='mongodb://username2:password@localhost/mongoenginetest' - ) - self.assertRaises(MongoEngineConnectionError, get_db, 'test1') + test_conn = connect( + 'mongoenginetest', alias='test1', + host='mongodb://username2:password@localhost/mongoenginetest' + ) + self.assertRaises(OperationFailure, test_conn.server_info) # Authentication succeeds with "authSource" authd_conn = connect( @@ -246,7 +497,7 @@ class ConnectionTest(unittest.TestCase): self.assertEqual(db.name, 'mongoenginetest') # Clear all users - authd_conn.admin.system.users.remove({}) + authd_conn.admin.system.users.delete_many({}) def test_register_connection(self): """Ensure that connections with different aliases may be registered. @@ -284,14 +535,7 @@ class ConnectionTest(unittest.TestCase): """Ensure we can specify a max connection pool size using a connection kwarg. """ - # Use "max_pool_size" or "maxpoolsize" depending on PyMongo version - # (former was changed to the latter as described in - # https://jira.mongodb.org/browse/PYTHON-854). - # TODO remove once PyMongo < 3.0 support is dropped - if pymongo.version_tuple[0] >= 3: - pool_size_kwargs = {'maxpoolsize': 100} - else: - pool_size_kwargs = {'max_pool_size': 100} + pool_size_kwargs = {'maxpoolsize': 100} conn = connect('mongoenginetest', alias='max_pool_size_via_kwarg', **pool_size_kwargs) self.assertEqual(conn.max_pool_size, 100) @@ -300,9 +544,6 @@ class ConnectionTest(unittest.TestCase): """Ensure we can specify a max connection pool size using an option in a connection URI. """ - if pymongo.version_tuple[0] == 2 and pymongo.version_tuple[1] < 9: - raise SkipTest('maxpoolsize as a URI option is only supported in PyMongo v2.9+') - conn = connect(host='mongodb://localhost/test?maxpoolsize=100', alias='max_pool_size_via_uri') self.assertEqual(conn.max_pool_size, 100) @@ -312,46 +553,30 @@ class ConnectionTest(unittest.TestCase): """ conn1 = connect(alias='conn1', host='mongodb://localhost/testing?w=1&j=true') conn2 = connect('testing', alias='conn2', w=1, j=True) - if IS_PYMONGO_3: - self.assertEqual(conn1.write_concern.document, {'w': 1, 'j': True}) - self.assertEqual(conn2.write_concern.document, {'w': 1, 'j': True}) - else: - self.assertEqual(dict(conn1.write_concern), {'w': 1, 'j': True}) - self.assertEqual(dict(conn2.write_concern), {'w': 1, 'j': True}) + self.assertEqual(conn1.write_concern.document, {'w': 1, 'j': True}) + self.assertEqual(conn2.write_concern.document, {'w': 1, 'j': True}) def test_connect_with_replicaset_via_uri(self): """Ensure connect() works when specifying a replicaSet via the MongoDB URI. """ - if IS_PYMONGO_3: - c = connect(host='mongodb://localhost/test?replicaSet=local-rs') - db = get_db() - self.assertIsInstance(db, pymongo.database.Database) - self.assertEqual(db.name, 'test') - else: - # PyMongo < v3.x raises an exception: - # "localhost:27017 is not a member of replica set local-rs" - with self.assertRaises(MongoEngineConnectionError): - c = connect(host='mongodb://localhost/test?replicaSet=local-rs') + c = connect(host='mongodb://localhost/test?replicaSet=local-rs') + db = get_db() + self.assertIsInstance(db, pymongo.database.Database) + self.assertEqual(db.name, 'test') def test_connect_with_replicaset_via_kwargs(self): """Ensure connect() works when specifying a replicaSet via the connection kwargs """ - if IS_PYMONGO_3: - c = connect(replicaset='local-rs') - self.assertEqual(c._MongoClient__options.replica_set_name, - 'local-rs') - db = get_db() - self.assertIsInstance(db, pymongo.database.Database) - self.assertEqual(db.name, 'test') - else: - # PyMongo < v3.x raises an exception: - # "localhost:27017 is not a member of replica set local-rs" - with self.assertRaises(MongoEngineConnectionError): - c = connect(replicaset='local-rs') + c = connect(replicaset='local-rs') + self.assertEqual(c._MongoClient__options.replica_set_name, + 'local-rs') + db = get_db() + self.assertIsInstance(db, pymongo.database.Database) + self.assertEqual(db.name, 'test') - def test_datetime(self): + def test_connect_tz_aware(self): connect('mongoenginetest', tz_aware=True) d = datetime.datetime(2010, 5, 5, tzinfo=utc) @@ -365,10 +590,8 @@ class ConnectionTest(unittest.TestCase): self.assertEqual(d, date_doc.the_date) def test_read_preference_from_parse(self): - if IS_PYMONGO_3: - from pymongo import ReadPreference - conn = connect(host="mongodb://a1.vpc,a2.vpc,a3.vpc/prod?readPreference=secondaryPreferred") - self.assertEqual(conn.read_preference, ReadPreference.SECONDARY_PREFERRED) + conn = connect(host="mongodb://a1.vpc,a2.vpc,a3.vpc/prod?readPreference=secondaryPreferred") + self.assertEqual(conn.read_preference, ReadPreference.SECONDARY_PREFERRED) def test_multiple_connection_settings(self): connect('mongoenginetest', alias='t1', host="localhost") @@ -379,17 +602,24 @@ class ConnectionTest(unittest.TestCase): self.assertEqual(len(mongo_connections.items()), 2) self.assertIn('t1', mongo_connections.keys()) self.assertIn('t2', mongo_connections.keys()) - if not IS_PYMONGO_3: - self.assertEqual(mongo_connections['t1'].host, 'localhost') - self.assertEqual(mongo_connections['t2'].host, '127.0.0.1') - else: - # Handle PyMongo 3+ Async Connection - # Ensure we are connected, throws ServerSelectionTimeoutError otherwise. - # Purposely not catching exception to fail test if thrown. - mongo_connections['t1'].server_info() - mongo_connections['t2'].server_info() - self.assertEqual(mongo_connections['t1'].address[0], 'localhost') - self.assertEqual(mongo_connections['t2'].address[0], '127.0.0.1') + + # Handle PyMongo 3+ Async Connection + # Ensure we are connected, throws ServerSelectionTimeoutError otherwise. + # Purposely not catching exception to fail test if thrown. + mongo_connections['t1'].server_info() + mongo_connections['t2'].server_info() + self.assertEqual(mongo_connections['t1'].address[0], 'localhost') + self.assertEqual(mongo_connections['t2'].address[0], '127.0.0.1') + + def test_connect_2_databases_uses_same_client_if_only_dbname_differs(self): + c1 = connect(alias='testdb1', db='testdb1') + c2 = connect(alias='testdb2', db='testdb2') + self.assertIs(c1, c2) + + def test_connect_2_databases_uses_different_client_if_different_parameters(self): + c1 = connect(alias='testdb1', db='testdb1', username='u1') + c2 = connect(alias='testdb2', db='testdb2', username='u2') + self.assertIsNot(c1, c2) if __name__ == '__main__': diff --git a/tests/test_context_managers.py b/tests/test_context_managers.py index 8fb7bc78..529032fe 100644 --- a/tests/test_context_managers.py +++ b/tests/test_context_managers.py @@ -5,6 +5,7 @@ from mongoengine.connection import get_db from mongoengine.context_managers import (switch_db, switch_collection, no_sub_classes, no_dereference, query_counter) +from mongoengine.pymongo_support import count_documents class ContextManagersTest(unittest.TestCase): @@ -36,14 +37,15 @@ class ContextManagersTest(unittest.TestCase): def test_switch_collection_context_manager(self): connect('mongoenginetest') - register_connection('testdb-1', 'mongoenginetest2') + register_connection(alias='testdb-1', db='mongoenginetest2') class Group(Document): name = StringField() - Group.drop_collection() + Group.drop_collection() # drops in default + with switch_collection(Group, 'group1') as Group: - Group.drop_collection() + Group.drop_collection() # drops in group1 Group(name="hello - group").save() self.assertEqual(1, Group.objects.count()) @@ -240,7 +242,7 @@ class ContextManagersTest(unittest.TestCase): collection.drop() def issue_1_count_query(): - collection.find({}).count() + count_documents(collection, {}) def issue_1_insert_query(): collection.insert_one({'test': 'garbage'}) @@ -268,6 +270,14 @@ class ContextManagersTest(unittest.TestCase): counter += 1 self.assertEqual(q, counter) + self.assertEqual(int(q), counter) # test __int__ + self.assertEqual(repr(q), str(int(q))) # test __repr__ + self.assertGreater(q, -1) # test __gt__ + self.assertGreaterEqual(q, int(q)) # test __gte__ + self.assertNotEqual(q, -1) + self.assertLess(q, 1000) + self.assertLessEqual(q, int(q)) + def test_query_counter_counts_getmore_queries(self): connect('mongoenginetest') db = get_db() @@ -302,5 +312,6 @@ class ContextManagersTest(unittest.TestCase): _ = db.system.indexes.find_one() # queries on db.system.indexes are ignored as well self.assertEqual(q, 1) + if __name__ == '__main__': unittest.main() diff --git a/tests/test_datastructures.py b/tests/test_datastructures.py index 2f1277e6..a9ef98e7 100644 --- a/tests/test_datastructures.py +++ b/tests/test_datastructures.py @@ -1,4 +1,5 @@ import unittest +from six import iterkeys from mongoengine import Document from mongoengine.base.datastructures import StrictDict, BaseList, BaseDict @@ -203,7 +204,7 @@ class TestBaseList(unittest.TestCase): def test___getitem__using_slice(self): base_list = self._get_baselist([0, 1, 2]) - self.assertEqual(base_list[1:3], [1,2]) + self.assertEqual(base_list[1:3], [1, 2]) self.assertEqual(base_list[0:3:2], [0, 2]) def test___getitem___using_slice_returns_list(self): @@ -218,7 +219,7 @@ class TestBaseList(unittest.TestCase): def test___getitem__sublist_returns_BaseList_bound_to_instance(self): base_list = self._get_baselist( [ - [1,2], + [1, 2], [3, 4] ] ) @@ -305,10 +306,10 @@ class TestBaseList(unittest.TestCase): self.assertEqual(base_list, [-1, 1, -2]) def test___setitem___with_slice(self): - base_list = self._get_baselist([0,1,2,3,4,5]) + base_list = self._get_baselist([0, 1, 2, 3, 4, 5]) base_list[0:6:2] = [None, None, None] self.assertEqual(base_list._instance._changed_fields, ['my_name']) - self.assertEqual(base_list, [None,1,None,3,None,5]) + self.assertEqual(base_list, [None, 1, None, 3, None, 5]) def test___setitem___item_0_calls_mark_as_changed(self): base_list = self._get_baselist([True]) @@ -368,6 +369,20 @@ class TestStrictDict(unittest.TestCase): d = self.dtype(a=1, b=1, c=1) self.assertEqual((d.a, d.b, d.c), (1, 1, 1)) + def test_iterkeys(self): + d = self.dtype(a=1) + self.assertEqual(list(iterkeys(d)), ['a']) + + def test_len(self): + d = self.dtype(a=1) + self.assertEqual(len(d), 1) + + def test_pop(self): + d = self.dtype(a=1) + self.assertIn('a', d) + d.pop('a') + self.assertNotIn('a', d) + def test_repr(self): d = self.dtype(a=1, b=2, c=3) self.assertEqual(repr(d), '{"a": 1, "b": 2, "c": 3}') @@ -426,8 +441,8 @@ class TestStrictDict(unittest.TestCase): def test_mappings_protocol(self): d = self.dtype(a=1, b=2) - assert dict(d) == {'a': 1, 'b': 2} - assert dict(**d) == {'a': 1, 'b': 2} + self.assertEqual(dict(d), {'a': 1, 'b': 2}) + self.assertEqual(dict(**d), {'a': 1, 'b': 2}) if __name__ == '__main__': diff --git a/tests/test_dereference.py b/tests/test_dereference.py index 5cf089f4..9c565810 100644 --- a/tests/test_dereference.py +++ b/tests/test_dereference.py @@ -2,6 +2,7 @@ import unittest from bson import DBRef, ObjectId +from six import iteritems from mongoengine import * from mongoengine.connection import get_db @@ -104,6 +105,14 @@ class FieldTest(unittest.TestCase): [m for m in group_obj.members] self.assertEqual(q, 2) + self.assertTrue(group_obj._data['members']._dereferenced) + + # verifies that no additional queries gets executed + # if we re-iterate over the ListField once it is + # dereferenced + [m for m in group_obj.members] + self.assertEqual(q, 2) + self.assertTrue(group_obj._data['members']._dereferenced) # Document select_related with query_counter() as q: @@ -124,6 +133,46 @@ class FieldTest(unittest.TestCase): [m for m in group_obj.members] self.assertEqual(q, 2) + def test_list_item_dereference_orphan_dbref(self): + """Ensure that orphan DBRef items in ListFields are dereferenced. + """ + class User(Document): + name = StringField() + + class Group(Document): + members = ListField(ReferenceField(User, dbref=False)) + + User.drop_collection() + Group.drop_collection() + + for i in range(1, 51): + user = User(name='user %s' % i) + user.save() + + group = Group(members=User.objects) + group.save() + group.reload() # Confirm reload works + + # Delete one User so one of the references in the + # Group.members list is an orphan DBRef + User.objects[0].delete() + with query_counter() as q: + self.assertEqual(q, 0) + + group_obj = Group.objects.first() + self.assertEqual(q, 1) + + [m for m in group_obj.members] + self.assertEqual(q, 2) + self.assertTrue(group_obj._data['members']._dereferenced) + + # verifies that no additional queries gets executed + # if we re-iterate over the ListField once it is + # dereferenced + [m for m in group_obj.members] + self.assertEqual(q, 2) + self.assertTrue(group_obj._data['members']._dereferenced) + User.drop_collection() Group.drop_collection() @@ -504,6 +553,61 @@ class FieldTest(unittest.TestCase): for m in group_obj.members: self.assertIn('User', m.__class__.__name__) + + def test_generic_reference_orphan_dbref(self): + """Ensure that generic orphan DBRef items in ListFields are dereferenced. + """ + + class UserA(Document): + name = StringField() + + class UserB(Document): + name = StringField() + + class UserC(Document): + name = StringField() + + class Group(Document): + members = ListField(GenericReferenceField()) + + UserA.drop_collection() + UserB.drop_collection() + UserC.drop_collection() + Group.drop_collection() + + members = [] + for i in range(1, 51): + a = UserA(name='User A %s' % i) + a.save() + + b = UserB(name='User B %s' % i) + b.save() + + c = UserC(name='User C %s' % i) + c.save() + + members += [a, b, c] + + group = Group(members=members) + group.save() + + # Delete one UserA instance so that there is + # an orphan DBRef in the GenericReference ListField + UserA.objects[0].delete() + with query_counter() as q: + self.assertEqual(q, 0) + + group_obj = Group.objects.first() + self.assertEqual(q, 1) + + [m for m in group_obj.members] + self.assertEqual(q, 4) + self.assertTrue(group_obj._data['members']._dereferenced) + + [m for m in group_obj.members] + self.assertEqual(q, 4) + self.assertTrue(group_obj._data['members']._dereferenced) + UserA.drop_collection() UserB.drop_collection() UserC.drop_collection() @@ -632,7 +736,7 @@ class FieldTest(unittest.TestCase): [m for m in group_obj.members] self.assertEqual(q, 2) - for k, m in group_obj.members.iteritems(): + for k, m in iteritems(group_obj.members): self.assertIsInstance(m, User) # Document select_related @@ -645,7 +749,7 @@ class FieldTest(unittest.TestCase): [m for m in group_obj.members] self.assertEqual(q, 2) - for k, m in group_obj.members.iteritems(): + for k, m in iteritems(group_obj.members): self.assertIsInstance(m, User) # Queryset select_related @@ -659,7 +763,7 @@ class FieldTest(unittest.TestCase): [m for m in group_obj.members] self.assertEqual(q, 2) - for k, m in group_obj.members.iteritems(): + for k, m in iteritems(group_obj.members): self.assertIsInstance(m, User) User.drop_collection() @@ -714,7 +818,7 @@ class FieldTest(unittest.TestCase): [m for m in group_obj.members] self.assertEqual(q, 4) - for k, m in group_obj.members.iteritems(): + for k, m in iteritems(group_obj.members): self.assertIn('User', m.__class__.__name__) # Document select_related @@ -730,7 +834,7 @@ class FieldTest(unittest.TestCase): [m for m in group_obj.members] self.assertEqual(q, 4) - for k, m in group_obj.members.iteritems(): + for k, m in iteritems(group_obj.members): self.assertIn('User', m.__class__.__name__) # Queryset select_related @@ -747,7 +851,7 @@ class FieldTest(unittest.TestCase): [m for m in group_obj.members] self.assertEqual(q, 4) - for k, m in group_obj.members.iteritems(): + for k, m in iteritems(group_obj.members): self.assertIn('User', m.__class__.__name__) Group.objects.delete() @@ -805,7 +909,7 @@ class FieldTest(unittest.TestCase): [m for m in group_obj.members] self.assertEqual(q, 2) - for k, m in group_obj.members.iteritems(): + for k, m in iteritems(group_obj.members): self.assertIsInstance(m, UserA) # Document select_related @@ -821,7 +925,7 @@ class FieldTest(unittest.TestCase): [m for m in group_obj.members] self.assertEqual(q, 2) - for k, m in group_obj.members.iteritems(): + for k, m in iteritems(group_obj.members): self.assertIsInstance(m, UserA) # Queryset select_related @@ -838,7 +942,7 @@ class FieldTest(unittest.TestCase): [m for m in group_obj.members] self.assertEqual(q, 2) - for k, m in group_obj.members.iteritems(): + for k, m in iteritems(group_obj.members): self.assertIsInstance(m, UserA) UserA.drop_collection() @@ -893,7 +997,7 @@ class FieldTest(unittest.TestCase): [m for m in group_obj.members] self.assertEqual(q, 4) - for k, m in group_obj.members.iteritems(): + for k, m in iteritems(group_obj.members): self.assertIn('User', m.__class__.__name__) # Document select_related @@ -909,7 +1013,7 @@ class FieldTest(unittest.TestCase): [m for m in group_obj.members] self.assertEqual(q, 4) - for k, m in group_obj.members.iteritems(): + for k, m in iteritems(group_obj.members): self.assertIn('User', m.__class__.__name__) # Queryset select_related @@ -926,7 +1030,7 @@ class FieldTest(unittest.TestCase): [m for m in group_obj.members] self.assertEqual(q, 4) - for k, m in group_obj.members.iteritems(): + for k, m in iteritems(group_obj.members): self.assertIn('User', m.__class__.__name__) Group.objects.delete() @@ -1064,7 +1168,6 @@ class FieldTest(unittest.TestCase): self.assertEqual(msg.author, user) self.assertEqual(msg.author.name, 'new-name') - def test_list_lookup_not_checked_in_map(self): """Ensure we dereference list data correctly """ @@ -1286,5 +1389,6 @@ class FieldTest(unittest.TestCase): self.assertEqual(q, 2) + if __name__ == '__main__': unittest.main() diff --git a/tests/test_replicaset_connection.py b/tests/test_replicaset_connection.py index a53f5903..cacdce8b 100644 --- a/tests/test_replicaset_connection.py +++ b/tests/test_replicaset_connection.py @@ -1,23 +1,16 @@ import unittest from pymongo import ReadPreference - -from mongoengine.python_support import IS_PYMONGO_3 - -if IS_PYMONGO_3: - from pymongo import MongoClient - CONN_CLASS = MongoClient - READ_PREF = ReadPreference.SECONDARY -else: - from pymongo import ReplicaSetConnection - CONN_CLASS = ReplicaSetConnection - READ_PREF = ReadPreference.SECONDARY_ONLY +from pymongo import MongoClient import mongoengine -from mongoengine import * from mongoengine.connection import MongoEngineConnectionError +CONN_CLASS = MongoClient +READ_PREF = ReadPreference.SECONDARY + + class ConnectionTest(unittest.TestCase): def setUp(self): @@ -35,7 +28,7 @@ class ConnectionTest(unittest.TestCase): """ try: - conn = connect(db='mongoenginetest', + conn = mongoengine.connect(db='mongoenginetest', host="mongodb://localhost/mongoenginetest?replicaSet=rs", read_preference=READ_PREF) except MongoEngineConnectionError as e: @@ -47,5 +40,6 @@ class ConnectionTest(unittest.TestCase): self.assertEqual(conn.read_preference, READ_PREF) + if __name__ == '__main__': unittest.main() diff --git a/tests/test_signals.py b/tests/test_signals.py index df687d0e..34cb43c3 100644 --- a/tests/test_signals.py +++ b/tests/test_signals.py @@ -39,7 +39,6 @@ class SignalTests(unittest.TestCase): def post_init(cls, sender, document, **kwargs): signal_output.append('post_init signal, %s, document._created = %s' % (document, document._created)) - @classmethod def pre_save(cls, sender, document, **kwargs): signal_output.append('pre_save signal, %s' % document) @@ -228,6 +227,9 @@ class SignalTests(unittest.TestCase): self.ExplicitId.objects.delete() + # Note that there is a chance that the following assert fails in case + # some receivers (eventually created in other tests) + # gets garbage collected (https://pythonhosted.org/blinker/#blinker.base.Signal.connect) self.assertEqual(self.pre_signals, post_signals) def test_model_signals(self): @@ -247,7 +249,7 @@ class SignalTests(unittest.TestCase): def load_existing_author(): a = self.Author(name='Bill Shakespeare') a.save() - self.get_signal_output(lambda: None) # eliminate signal output + self.get_signal_output(lambda: None) # eliminate signal output a1 = self.Author.objects(name='Bill Shakespeare')[0] self.assertEqual(self.get_signal_output(create_author), [ @@ -431,5 +433,6 @@ class SignalTests(unittest.TestCase): {} ]) + if __name__ == '__main__': unittest.main() diff --git a/tests/utils.py b/tests/utils.py index 5345f75e..27d5ada7 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,22 +1,16 @@ +import operator import unittest from nose.plugins.skip import SkipTest from mongoengine import connect -from mongoengine.connection import get_db, get_connection -from mongoengine.python_support import IS_PYMONGO_3 +from mongoengine.connection import get_db, disconnect_all +from mongoengine.mongodb_support import get_mongodb_version MONGO_TEST_DB = 'mongoenginetest' # standard name for the test database -# Constant that can be used to compare the version retrieved with -# get_mongodb_version() -MONGODB_26 = (2, 6) -MONGODB_3 = (3,0) -MONGODB_32 = (3, 2) - - class MongoDBTestCase(unittest.TestCase): """Base class for tests that need a mongodb connection It ensures that the db is clean at the beginning and dropped at the end automatically @@ -24,6 +18,7 @@ class MongoDBTestCase(unittest.TestCase): @classmethod def setUpClass(cls): + disconnect_all() cls._connection = connect(db=MONGO_TEST_DB) cls._connection.drop_database(MONGO_TEST_DB) cls.db = get_db() @@ -31,61 +26,40 @@ class MongoDBTestCase(unittest.TestCase): @classmethod def tearDownClass(cls): cls._connection.drop_database(MONGO_TEST_DB) + disconnect_all() -def get_mongodb_version(): - """Return the version of the connected mongoDB (first 2 digits) - - :return: tuple(int, int) - """ - version_list = get_connection().server_info()['versionArray'][:2] # e.g: (3, 2) - return tuple(version_list) +def get_as_pymongo(doc): + """Fetch the pymongo version of a certain Document""" + return doc.__class__.objects.as_pymongo().get(id=doc.id) -def _decorated_with_ver_requirement(func, version): - """Return a given function decorated with the version requirement - for a particular MongoDB version tuple. +def _decorated_with_ver_requirement(func, mongo_version_req, oper): + """Return a MongoDB version requirement decorator. - :param version: The version required (tuple(int, int)) + The resulting decorator will raise a SkipTest exception if the current + MongoDB version doesn't match the provided version/operator. + + For example, if you define a decorator like so: + + def requires_mongodb_gte_36(func): + return _decorated_with_ver_requirement( + func, (3.6), oper=operator.ge + ) + + Then tests decorated with @requires_mongodb_gte_36 will be skipped if + ran against MongoDB < v3.6. + + :param mongo_version_req: The mongodb version requirement (tuple(int, int)) + :param oper: The operator to apply (e.g: operator.ge) """ def _inner(*args, **kwargs): - MONGODB_V = get_mongodb_version() - if MONGODB_V >= version: + mongodb_v = get_mongodb_version() + if oper(mongodb_v, mongo_version_req): return func(*args, **kwargs) - raise SkipTest('Needs MongoDB v{}+'.format('.'.join(str(n) for n in version))) + raise SkipTest('Needs MongoDB v{}+'.format('.'.join(str(n) for n in mongo_version_req))) _inner.__name__ = func.__name__ _inner.__doc__ = func.__doc__ - return _inner - - -def requires_mongodb_gte_26(func): - """Raise a SkipTest exception if we're working with MongoDB version - lower than v2.6. - """ - return _decorated_with_ver_requirement(func, MONGODB_26) - - -def requires_mongodb_gte_3(func): - """Raise a SkipTest exception if we're working with MongoDB version - lower than v3.0. - """ - return _decorated_with_ver_requirement(func, MONGODB_3) - - -def skip_pymongo3(f): - """Raise a SkipTest exception if we're running a test against - PyMongo v3.x. - """ - def _inner(*args, **kwargs): - if IS_PYMONGO_3: - raise SkipTest("Useless with PyMongo 3+") - return f(*args, **kwargs) - - _inner.__name__ = f.__name__ - _inner.__doc__ = f.__doc__ - - return _inner - diff --git a/tox.ini b/tox.ini index 815d2acc..40bcea8a 100644 --- a/tox.ini +++ b/tox.ini @@ -6,7 +6,7 @@ commands = python setup.py nosetests {posargs} deps = nose - mg35: PyMongo==3.5 + mg34x: PyMongo>=3.4,<3.5 mg3x: PyMongo>=3.0,<3.7 setenv = PYTHON_EGG_CACHE = {envdir}/python-eggs