Compare commits

...

45 Commits

Author SHA1 Message Date
Stefan Wojcik
4e8bb14131 skip uri test for pymongo < v2.9 2017-02-25 13:40:51 -05:00
Stefan Wojcik
9cc4fad614 dummy 2017-02-25 13:15:52 -05:00
Stefan Wojcik
2a486ee537 better connection docstrings [ci skip] 2017-02-25 12:56:36 -05:00
Stefan Wojcik
2579ed754f add unit tests for setting the connection pool size 2017-02-25 12:48:03 -05:00
Stefan Wójcik
3f31666796 Fix the exception message when validating unicode URLs (#1486) 2017-02-24 16:18:34 -05:00
Stefan Wojcik
3fe8031cf3 fix EmbeddedDocumentListFieldTestCase 2017-02-22 12:44:05 -05:00
bagerard
b27c7ce11b allow to use sets in field choices (#1482) 2017-02-15 08:51:47 -05:00
Stefan Wojcik
ed34c2ca68 update the changelog and upgrade docs 2017-02-09 12:13:56 -08:00
Stefan Wójcik
3ca2e953fb Fix limit/skip/hint/batch_size chaining (#1476) 2017-02-09 12:02:46 -08:00
martin sereinig
d8a7328365 Fix docs regarding reverse_delete_rule and delete signals (#1473) 2017-02-06 14:11:42 -07:00
Stefan Wojcik
f33cd625bf nicer readme 2017-01-17 02:47:45 -05:00
Stefan Wojcik
80530bb13c nicer readme 2017-01-17 02:46:37 -05:00
Stefan Wójcik
affc12df4b Update README.rst 2017-01-17 02:43:29 -05:00
Stefan Wojcik
4eedf00025 nicer readme note about dependencies 2017-01-17 02:42:23 -05:00
Eli Boyarski
e5acbcc0dd Improved a docstring for FieldDoesNotExist (#1466) 2017-01-09 11:24:27 -05:00
Stefan Wojcik
1b6743ee53 add a changelog entry about broken references raising DoesNotExist 2017-01-08 14:50:16 -05:00
Eli Boyarski
b5fb82d95d Typo fix (#1463) 2017-01-08 12:57:36 -05:00
lanf0n
193aa4e1f2 [#1459] fix typo __neq__ to __ne__ (#1461) 2017-01-05 22:37:09 -05:00
Stefan Wójcik
ebd34427c7 Cleaner Document.save (#1458) 2016-12-30 05:43:56 -05:00
Stefan Wójcik
3d75573889 Validate db_field (#1448) 2016-12-29 12:39:05 -05:00
Stefan Wójcik
c6240ca415 Test connection's write concern (#1456) 2016-12-29 12:37:38 -05:00
Stefan Wójcik
2ee8984b44 add a $rename operator (#1454) 2016-12-28 23:25:38 -05:00
Stefan Wojcik
b7ec587e5b better docstring for BaseDocument.to_json 2016-12-28 22:15:46 -05:00
Stefan Wojcik
47c58bce2b fix "connect" example in the docs 2016-12-28 21:08:18 -05:00
Stefan Wojcik
96e95ac533 minor readme tweaks 2016-12-28 17:18:55 -05:00
Stefan Wojcik
b013a065f7 remove readme mention of the irc channel 2016-12-28 11:50:28 -05:00
Stefan Wojcik
74b37d11cf only validate db_field if it's a string type 2016-12-28 11:46:18 -05:00
Stefan Wójcik
c6cc013617 fix BaseQuerySet.fields when mixing exclusion/inclusion with complex values like $slice (#1452) 2016-12-28 11:40:57 -05:00
Stefan Wójcik
f4e1d80a87 support a negative dec operator (#1450) 2016-12-28 02:04:49 -05:00
Stefan Wójcik
91dad4060f raise an error when trying to save an abstract document (#1449) 2016-12-28 00:51:47 -05:00
Stefan Wojcik
e07cb82c15 validate db_field 2016-12-27 17:38:26 -05:00
Stefan Wojcik
2770cec187 better docstring for BaseQuerySet.fields 2016-12-27 10:20:13 -05:00
Stefan Wojcik
5c3928190a fix line width 2016-12-22 13:20:05 -05:00
Manuel Jeckelmann
9f4b04ea0f Fix querying an embedded document field by an invalid value (#1440) 2016-12-22 13:19:18 -05:00
Stefan Wojcik
96d20756ca remove redundant whitespace 2016-12-22 13:13:19 -05:00
John Dupuy
b8454c7f5b Fixed ListField deletion bug (#1435) 2016-12-22 13:11:44 -05:00
George Karakostas
c84f703f92 Update documentation to include a Q import (#1441) 2016-12-22 13:06:55 -05:00
Manuel Jeckelmann
57c2e867d8 Remove py26 from contributing docs (#1439)
Python 2.6 is not supported anymore with version 0.11.0
2016-12-19 17:54:43 -05:00
Stefan Wojcik
553f496d84 fix tests 2016-12-13 00:42:03 -05:00
Stefan Wojcik
b1d8aca46a update the changelog 2016-12-12 23:33:49 -05:00
Stefan Wojcik
8e884fd3ea make the __in=non_iterable_or_doc tests more concise 2016-12-12 23:30:38 -05:00
Malthe Jørgensen
76524b7498 Raise TypeError when __in-operator used with a Document (#1237) 2016-12-12 23:27:25 -05:00
Stefan Wojcik
65914fb2b2 fix the way MongoDB URI w/ ?replicaset is passed 2016-12-12 23:24:19 -05:00
Stefan Wojcik
a4d0da0085 update the changelog 2016-12-12 23:08:57 -05:00
Stefan Wójcik
c9d496e9a0 Fix connecting to MongoReplicaSetClient (#1436) 2016-12-12 23:08:11 -05:00
28 changed files with 969 additions and 300 deletions

View File

@@ -14,13 +14,13 @@ Before starting to write code, look for existing `tickets
<https://github.com/MongoEngine/mongoengine/issues?state=open>`_ or `create one
<https://github.com/MongoEngine/mongoengine/issues>`_ for your specific
issue or feature request. That way you avoid working on something
that might not be of interest or that has already been addressed. If in doubt
that might not be of interest or that has already been addressed. If in doubt
post to the `user group <http://groups.google.com/group/mongoengine-users>`
Supported Interpreters
----------------------
MongoEngine supports CPython 2.6 and newer. Language
MongoEngine supports CPython 2.7 and newer. Language
features not supported by all interpreters can not be used.
Please also ensure that your code is properly converted by
`2to3 <http://docs.python.org/library/2to3.html>`_ for Python 3 support.

View File

@@ -35,16 +35,22 @@ setup.py install``.
Dependencies
============
- pymongo>=2.7.1
- sphinx (optional - for documentation generation)
All of the dependencies can easily be installed via `pip <https://pip.pypa.io/>`_. At the very least, you'll need these two packages to use MongoEngine:
- pymongo>=2.7.1
- six>=1.10.0
If you utilize a ``DateTimeField``, you might also use a more flexible date parser:
Optional Dependencies
---------------------
- **Image Fields**: Pillow>=2.0.0
- dateutil>=2.1.0
.. note
MongoEngine always runs it's test suite against the latest patch version of each dependecy. e.g.: PyMongo 3.0.1
If you need to use an ``ImageField`` or ``ImageGridFsProxy``:
- Pillow>=2.0.0
If you want to generate the documentation (e.g. to contribute to it):
- sphinx
Examples
========
@@ -57,7 +63,7 @@ Some simple examples of what MongoEngine code looks like:
class BlogPost(Document):
title = StringField(required=True, max_length=200)
posted = DateTimeField(default=datetime.datetime.now)
posted = DateTimeField(default=datetime.datetime.utcnow)
tags = ListField(StringField(max_length=50))
meta = {'allow_inheritance': True}
@@ -87,17 +93,18 @@ Some simple examples of what MongoEngine code looks like:
... print
...
>>> len(BlogPost.objects)
# Count all blog posts and its subtypes
>>> BlogPost.objects.count()
2
>>> len(TextPost.objects)
>>> TextPost.objects.count()
1
>>> len(LinkPost.objects)
>>> LinkPost.objects.count()
1
# Find tagged posts
>>> len(BlogPost.objects(tags='mongoengine'))
# Count tagged posts
>>> BlogPost.objects(tags='mongoengine').count()
2
>>> len(BlogPost.objects(tags='mongodb'))
>>> BlogPost.objects(tags='mongodb').count()
1
Tests
@@ -130,8 +137,7 @@ Community
<http://groups.google.com/group/mongoengine-users>`_
- `MongoEngine Developers mailing list
<http://groups.google.com/group/mongoengine-dev>`_
- `#mongoengine IRC channel <http://webchat.freenode.net/?channels=mongoengine>`_
Contributing
============
We welcome contributions! see the `Contribution guidelines <https://github.com/MongoEngine/mongoengine/blob/master/CONTRIBUTING.rst>`_
We welcome contributions! See the `Contribution guidelines <https://github.com/MongoEngine/mongoengine/blob/master/CONTRIBUTING.rst>`_

View File

@@ -4,13 +4,19 @@ Changelog
Development
===========
- (Fill this out as you fix issues and develop you features).
- (Fill this out as you fix issues and develop your features).
- Fixed using sets in field choices #1481
- POTENTIAL BREAKING CHANGE: Fixed limit/skip/hint/batch_size chaining #1476
- POTENTIAL BREAKING CHANGE: Changed a public `QuerySet.clone_into` method to a private `QuerySet._clone_into` #1476
- Fixed connecting to a replica set with PyMongo 2.x #1436
- Fixed an obscure error message when filtering by `field__in=non_iterable`. #1237
Changes in 0.11.0
=================
- BREAKING CHANGE: Renamed `ConnectionError` to `MongoEngineConnectionError` since the former is a built-in exception name in Python v3.x. #1428
- BREAKING CHANGE: Dropped Python 2.6 support. #1428
- BREAKING CHANGE: `from mongoengine.base import ErrorClass` won't work anymore for any error from `mongoengine.errors` (e.g. `ValidationError`). Use `from mongoengine.errors import ErrorClass instead`. #1428
- BREAKING CHANGE: Accessing a broken reference will raise a `DoesNotExist` error. In the past it used to return `None`. #1334
- Fixed absent rounding for DecimalField when `force_string` is set. #1103
Changes in 0.10.8

View File

@@ -33,7 +33,7 @@ the :attr:`host` to
corresponding parameters in :func:`~mongoengine.connect`: ::
connect(
name='test',
db='test',
username='user',
password='12345',
host='mongodb://admin:qwerty@localhost/production'

View File

@@ -150,7 +150,7 @@ arguments can be set on all fields:
.. note:: If set, this field is also accessible through the `pk` field.
:attr:`choices` (Default: None)
An iterable (e.g. a list or tuple) of choices to which the value of this
An iterable (e.g. list, tuple or set) of choices to which the value of this
field should be limited.
Can be either be a nested tuples of value (stored in mongo) and a
@@ -214,8 +214,8 @@ document class as the first argument::
Dictionary Fields
-----------------
Often, an embedded document may be used instead of a dictionary generally
embedded documents are recommended as dictionaries dont support validation
Often, an embedded document may be used instead of a dictionary generally
embedded documents are recommended as dictionaries dont support validation
or custom field types. However, sometimes you will not know the structure of what you want to
store; in this situation a :class:`~mongoengine.fields.DictField` is appropriate::
@@ -361,11 +361,6 @@ Its value can take any of the following constants:
In Django, be sure to put all apps that have such delete rule declarations in
their :file:`models.py` in the :const:`INSTALLED_APPS` tuple.
.. warning::
Signals are not triggered when doing cascading updates / deletes - if this
is required you must manually handle the update / delete.
Generic reference fields
''''''''''''''''''''''''
A second kind of reference field also exists,

View File

@@ -479,6 +479,8 @@ operators. To use a :class:`~mongoengine.queryset.Q` object, pass it in as the
first positional argument to :attr:`Document.objects` when you filter it by
calling it with keyword arguments::
from mongoengine.queryset.visitor import Q
# Get published posts
Post.objects(Q(published=True) | Q(publish_date__lte=datetime.now()))

View File

@@ -142,11 +142,4 @@ cleaner looking while still allowing manual execution of the callback::
modified = DateTimeField()
ReferenceFields and Signals
---------------------------
Currently `reverse_delete_rule` does not trigger signals on the other part of
the relationship. If this is required you must manually handle the
reverse deletion.
.. _blinker: http://pypi.python.org/pypi/blinker

View File

@@ -2,6 +2,20 @@
Upgrading
#########
Development
***********
(Fill this out whenever you introduce breaking changes to MongoEngine)
This release includes various fixes for the `BaseQuerySet` methods and how they
are chained together. Since version 0.10.1 applying limit/skip/hint/batch_size
to an already-existing queryset wouldn't modify the underlying PyMongo cursor.
This has been fixed now, so you'll need to make sure that your code didn't rely
on the broken implementation.
Additionally, a public `BaseQuerySet.clone_into` has been renamed to a private
`_clone_into`. If you directly used that method in your code, you'll need to
rename its occurrences.
0.11.0
******
This release includes a major rehaul of MongoEngine's code quality and

View File

@@ -5,7 +5,7 @@ __all__ = ('UPDATE_OPERATORS', 'get_document', '_document_registry')
UPDATE_OPERATORS = set(['set', 'unset', 'inc', 'dec', 'pop', 'push',
'push_all', 'pull', 'pull_all', 'add_to_set',
'set_on_insert', 'min', 'max'])
'set_on_insert', 'min', 'max', 'rename'])
_document_registry = {}

View File

@@ -138,10 +138,7 @@ class BaseList(list):
return super(BaseList, self).__setitem__(key, value)
def __delitem__(self, key, *args, **kwargs):
if isinstance(key, slice):
self._mark_as_changed()
else:
self._mark_as_changed(key)
self._mark_as_changed()
return super(BaseList, self).__delitem__(key)
def __setslice__(self, *args, **kwargs):
@@ -432,7 +429,7 @@ class StrictDict(object):
def __eq__(self, other):
return self.items() == other.items()
def __neq__(self, other):
def __ne__(self, other):
return self.items() != other.items()
@classmethod

View File

@@ -402,9 +402,11 @@ class BaseDocument(object):
raise ValidationError(message, errors=errors)
def to_json(self, *args, **kwargs):
"""Converts a document to JSON.
:param use_db_field: Set to True by default but enables the output of the json structure with the field names
and not the mongodb store db_names in case of set to False
"""Convert this document to JSON.
:param use_db_field: Serialize field names as they appear in
MongoDB (as opposed to attribute names on this document).
Defaults to True.
"""
use_db_field = kwargs.pop('use_db_field', True)
return json_util.dumps(self.to_mongo(use_db_field), *args, **kwargs)
@@ -675,6 +677,9 @@ class BaseDocument(object):
if not only_fields:
only_fields = []
if son and not isinstance(son, dict):
raise ValueError("The source SON object needs to be of type 'dict'")
# Get the class name from the document, falling back to the given
# class if unavailable
class_name = son.get('_cls', cls._class_name)

View File

@@ -41,7 +41,7 @@ class BaseField(object):
"""
:param db_field: The database field to store this field in
(defaults to the name of the field)
:param name: Depreciated - use db_field
:param name: Deprecated - use db_field
:param required: If the field is required. Whether it has to have a
value or not. Defaults to False.
:param default: (optional) The default value for this field if no value
@@ -81,6 +81,17 @@ class BaseField(object):
self.sparse = sparse
self._owner_document = None
# Validate the db_field
if isinstance(self.db_field, six.string_types) and (
'.' in self.db_field or
'\0' in self.db_field or
self.db_field.startswith('$')
):
raise ValueError(
'field names cannot contain dots (".") or null characters '
'("\\0"), and they must not start with a dollar sign ("$").'
)
# Detect and report conflicts between metadata and base properties.
conflicts = set(dir(self)) & set(kwargs)
if conflicts:
@@ -182,7 +193,8 @@ class BaseField(object):
EmbeddedDocument = _import_class('EmbeddedDocument')
choice_list = self.choices
if isinstance(choice_list[0], (list, tuple)):
if isinstance(next(iter(choice_list)), (list, tuple)):
# next(iter) is useful for sets
choice_list = [k for k, _ in choice_list]
# Choices which are other types of Documents

View File

@@ -34,7 +34,10 @@ def _import_class(cls_name):
queryset_classes = ('OperationError',)
deref_classes = ('DeReference',)
if cls_name in doc_classes:
if cls_name == 'BaseDocument':
from mongoengine.base import document as module
import_classes = ['BaseDocument']
elif cls_name in doc_classes:
from mongoengine import document as module
import_classes = doc_classes
elif cls_name in field_classes:

View File

@@ -51,7 +51,9 @@ def register_connection(alias, name=None, host=None, port=None,
MONGODB-CR (MongoDB Challenge Response protocol) for older servers.
:param is_mock: explicitly use mongomock for this connection
(can also be done by using `mongomock://` as db host prefix)
:param kwargs: allow ad-hoc parameters to be passed into the pymongo driver
:param kwargs: ad-hoc parameters to be passed into the pymongo driver,
for example maxpoolsize, tz_aware, etc. See the documentation
for pymongo's `MongoClient` for a full list.
.. versionchanged:: 0.10.6 - added mongomock support
"""
@@ -66,9 +68,9 @@ def register_connection(alias, name=None, host=None, port=None,
'authentication_mechanism': authentication_mechanism
}
# Handle uri style connections
conn_host = conn_settings['host']
# host can be a list or a string, so if string, force to a list
# Host can be a list or a string, so if string, force to a list.
if isinstance(conn_host, six.string_types):
conn_host = [conn_host]
@@ -96,7 +98,7 @@ def register_connection(alias, name=None, host=None, port=None,
uri_options = uri_dict['options']
if 'replicaset' in uri_options:
conn_settings['replicaSet'] = True
conn_settings['replicaSet'] = uri_options['replicaset']
if 'authsource' in uri_options:
conn_settings['authentication_source'] = uri_options['authsource']
if 'authmechanism' in uri_options:
@@ -170,23 +172,22 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
else:
connection_class = MongoClient
# Handle replica set connections
if 'replicaSet' in conn_settings:
# For replica set connections with PyMongo 2.x, use
# MongoReplicaSetClient.
# TODO remove this once we stop supporting PyMongo 2.x.
if 'replicaSet' in conn_settings and not IS_PYMONGO_3:
connection_class = MongoReplicaSetClient
conn_settings['hosts_or_uri'] = conn_settings.pop('host', None)
# hosts_or_uri has to be a string, so if 'host' was provided
# as a list, join its parts and separate them by ','
if isinstance(conn_settings['hosts_or_uri'], list):
conn_settings['hosts_or_uri'] = ','.join(
conn_settings['hosts_or_uri'])
# Discard port since it can't be used on MongoReplicaSetClient
conn_settings.pop('port', None)
# Discard replicaSet if it's not a string
if not isinstance(conn_settings['replicaSet'], six.string_types):
del conn_settings['replicaSet']
# For replica set connections with PyMongo 2.x, use
# MongoReplicaSetClient.
# TODO remove this once we stop supporting PyMongo 2.x.
if not IS_PYMONGO_3:
connection_class = MongoReplicaSetClient
conn_settings['hosts_or_uri'] = conn_settings.pop('host', None)
# Iterate over all of the connection settings and if a connection with
# the same parameters is already established, use it instead of creating
# a new one.
@@ -242,9 +243,12 @@ def connect(db=None, alias=DEFAULT_CONNECTION_NAME, **kwargs):
running on the default port on localhost. If authentication is needed,
provide username and password arguments as well.
Multiple databases are supported by using aliases. Provide a separate
Multiple databases are supported by using aliases. Provide a separate
`alias` to connect to a different instance of :program:`mongod`.
See the docstring for `register_connection` for more details about all
supported kwargs.
.. versionchanged:: 0.6 - added multiple database support.
"""
if alias not in _connections:

View File

@@ -313,6 +313,9 @@ class Document(BaseDocument):
.. versionchanged:: 0.10.7
Add signal_kwargs argument
"""
if self._meta.get('abstract'):
raise InvalidDocumentError('Cannot save an abstract document.')
signal_kwargs = signal_kwargs or {}
signals.pre_save.send(self.__class__, document=self, **signal_kwargs)
@@ -329,68 +332,20 @@ class Document(BaseDocument):
signals.pre_save_post_validation.send(self.__class__, document=self,
created=created, **signal_kwargs)
if self._meta.get('auto_create_index', True):
self.ensure_indexes()
try:
collection = self._get_collection()
if self._meta.get('auto_create_index', True):
self.ensure_indexes()
# Save a new document or update an existing one
if created:
if force_insert:
object_id = collection.insert(doc, **write_concern)
else:
object_id = collection.save(doc, **write_concern)
# In PyMongo 3.0, the save() call calls internally the _update() call
# but they forget to return the _id value passed back, therefore getting it back here
# Correct behaviour in 2.X and in 3.0.1+ versions
if not object_id and pymongo.version_tuple == (3, 0):
pk_as_mongo_obj = self._fields.get(self._meta['id_field']).to_mongo(self.pk)
object_id = (
self._qs.filter(pk=pk_as_mongo_obj).first() and
self._qs.filter(pk=pk_as_mongo_obj).first().pk
) # TODO doesn't this make 2 queries?
object_id = self._save_create(doc, force_insert, write_concern)
else:
object_id = doc['_id']
updates, removals = self._delta()
# Need to add shard key to query, or you get an error
if save_condition is not None:
select_dict = transform.query(self.__class__,
**save_condition)
else:
select_dict = {}
select_dict['_id'] = object_id
shard_key = self._meta.get('shard_key', tuple())
for k in shard_key:
path = self._lookup_field(k.split('.'))
actual_key = [p.db_field for p in path]
val = doc
for ak in actual_key:
val = val[ak]
select_dict['.'.join(actual_key)] = val
def is_new_object(last_error):
if last_error is not None:
updated = last_error.get('updatedExisting')
if updated is not None:
return not updated
return created
update_query = {}
if updates:
update_query['$set'] = updates
if removals:
update_query['$unset'] = removals
if updates or removals:
upsert = save_condition is None
last_error = collection.update(select_dict, update_query,
upsert=upsert, **write_concern)
if not upsert and last_error['n'] == 0:
raise SaveConditionError('Race condition preventing'
' document update detected')
created = is_new_object(last_error)
object_id, created = self._save_update(doc, save_condition,
write_concern)
if cascade is None:
cascade = self._meta.get(
'cascade', False) or cascade_kwargs is not None
cascade = (self._meta.get('cascade', False) or
cascade_kwargs is not None)
if cascade:
kwargs = {
@@ -403,6 +358,7 @@ class Document(BaseDocument):
kwargs.update(cascade_kwargs)
kwargs['_refs'] = _refs
self.cascade_save(**kwargs)
except pymongo.errors.DuplicateKeyError as err:
message = u'Tried to save duplicate unique keys (%s)'
raise NotUniqueError(message % six.text_type(err))
@@ -415,16 +371,91 @@ class Document(BaseDocument):
raise NotUniqueError(message % six.text_type(err))
raise OperationError(message % six.text_type(err))
# Make sure we store the PK on this document now that it's saved
id_field = self._meta['id_field']
if created or id_field not in self._meta.get('shard_key', []):
self[id_field] = self._fields[id_field].to_python(object_id)
signals.post_save.send(self.__class__, document=self,
created=created, **signal_kwargs)
self._clear_changed_fields()
self._created = False
return self
def _save_create(self, doc, force_insert, write_concern):
"""Save a new document.
Helper method, should only be used inside save().
"""
collection = self._get_collection()
if force_insert:
return collection.insert(doc, **write_concern)
object_id = collection.save(doc, **write_concern)
# In PyMongo 3.0, the save() call calls internally the _update() call
# but they forget to return the _id value passed back, therefore getting it back here
# Correct behaviour in 2.X and in 3.0.1+ versions
if not object_id and pymongo.version_tuple == (3, 0):
pk_as_mongo_obj = self._fields.get(self._meta['id_field']).to_mongo(self.pk)
object_id = (
self._qs.filter(pk=pk_as_mongo_obj).first() and
self._qs.filter(pk=pk_as_mongo_obj).first().pk
) # TODO doesn't this make 2 queries?
return object_id
def _save_update(self, doc, save_condition, write_concern):
"""Update an existing document.
Helper method, should only be used inside save().
"""
collection = self._get_collection()
object_id = doc['_id']
created = False
select_dict = {}
if save_condition is not None:
select_dict = transform.query(self.__class__, **save_condition)
select_dict['_id'] = object_id
# Need to add shard key to query, or you get an error
shard_key = self._meta.get('shard_key', tuple())
for k in shard_key:
path = self._lookup_field(k.split('.'))
actual_key = [p.db_field for p in path]
val = doc
for ak in actual_key:
val = val[ak]
select_dict['.'.join(actual_key)] = val
updates, removals = self._delta()
update_query = {}
if updates:
update_query['$set'] = updates
if removals:
update_query['$unset'] = removals
if updates or removals:
upsert = save_condition is None
last_error = collection.update(select_dict, update_query,
upsert=upsert, **write_concern)
if not upsert and last_error['n'] == 0:
raise SaveConditionError('Race condition preventing'
' document update detected')
if last_error is not None:
updated_existing = last_error.get('updatedExisting')
if updated_existing is False:
created = True
# !!! This is bad, means we accidentally created a new,
# potentially corrupted document. See
# https://github.com/MongoEngine/mongoengine/issues/564
return object_id, created
def cascade_save(self, **kwargs):
"""Recursively save any references and generic references on the
document.
@@ -828,7 +859,6 @@ class Document(BaseDocument):
""" Lists all of the indexes that should be created for given
collection. It includes all the indexes from super- and sub-classes.
"""
if cls._meta.get('abstract'):
return []

View File

@@ -50,8 +50,8 @@ class FieldDoesNotExist(Exception):
or an :class:`~mongoengine.EmbeddedDocument`.
To avoid this behavior on data loading,
you should the :attr:`strict` to ``False``
in the :attr:`meta` dictionnary.
you should set the :attr:`strict` to ``False``
in the :attr:`meta` dictionary.
"""

View File

@@ -28,7 +28,7 @@ from mongoengine.base import (BaseDocument, BaseField, ComplexBaseField,
GeoJsonBaseField, ObjectIdField, get_document)
from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db
from mongoengine.document import Document, EmbeddedDocument
from mongoengine.errors import DoesNotExist, ValidationError
from mongoengine.errors import DoesNotExist, InvalidQueryError, ValidationError
from mongoengine.python_support import StringIO
from mongoengine.queryset import DO_NOTHING, QuerySet
@@ -139,12 +139,12 @@ class URLField(StringField):
# Check first if the scheme is valid
scheme = value.split('://')[0].lower()
if scheme not in self.schemes:
self.error('Invalid scheme {} in URL: {}'.format(scheme, value))
self.error(u'Invalid scheme {} in URL: {}'.format(scheme, value))
return
# Then check full URL
if not self.url_regex.match(value):
self.error('Invalid URL: {}'.format(value))
self.error(u'Invalid URL: {}'.format(value))
return
@@ -566,7 +566,11 @@ class EmbeddedDocumentField(BaseField):
def prepare_query_value(self, op, value):
if value is not None and not isinstance(value, self.document_type):
value = self.document_type._from_son(value)
try:
value = self.document_type._from_son(value)
except ValueError:
raise InvalidQueryError("Querying the embedded document '%s' failed, due to an invalid query value" %
(self.document_type._class_name,))
super(EmbeddedDocumentField, self).prepare_query_value(op, value)
return self.to_mongo(value)
@@ -884,10 +888,6 @@ class ReferenceField(BaseField):
Foo.register_delete_rule(Bar, 'foo', NULLIFY)
.. note ::
`reverse_delete_rule` does not trigger pre / post delete signals to be
triggered.
.. versionchanged:: 0.5 added `reverse_delete_rule`
"""

View File

@@ -86,6 +86,7 @@ class BaseQuerySet(object):
self._batch_size = None
self.only_fields = []
self._max_time_ms = None
self._comment = None
def __call__(self, q_obj=None, class_check=True, read_preference=None,
**query):
@@ -706,39 +707,36 @@ class BaseQuerySet(object):
with switch_db(self._document, alias) as cls:
collection = cls._get_collection()
return self.clone_into(self.__class__(self._document, collection))
return self._clone_into(self.__class__(self._document, collection))
def clone(self):
"""Creates a copy of the current
:class:`~mongoengine.queryset.QuerySet`
"""Create a copy of the current queryset."""
return self._clone_into(self.__class__(self._document, self._collection_obj))
.. versionadded:: 0.5
def _clone_into(self, new_qs):
"""Copy all of the relevant properties of this queryset to
a new queryset (which has to be an instance of
:class:`~mongoengine.queryset.base.BaseQuerySet`).
"""
return self.clone_into(self.__class__(self._document, self._collection_obj))
def clone_into(self, cls):
"""Creates a copy of the current
:class:`~mongoengine.queryset.base.BaseQuerySet` into another child class
"""
if not isinstance(cls, BaseQuerySet):
if not isinstance(new_qs, BaseQuerySet):
raise OperationError(
'%s is not a subclass of BaseQuerySet' % cls.__name__)
'%s is not a subclass of BaseQuerySet' % new_qs.__name__)
copy_props = ('_mongo_query', '_initial_query', '_none', '_query_obj',
'_where_clause', '_loaded_fields', '_ordering', '_snapshot',
'_timeout', '_class_check', '_slave_okay', '_read_preference',
'_iter', '_scalar', '_as_pymongo', '_as_pymongo_coerce',
'_limit', '_skip', '_hint', '_auto_dereference',
'_search_text', 'only_fields', '_max_time_ms')
'_search_text', 'only_fields', '_max_time_ms', '_comment')
for prop in copy_props:
val = getattr(self, prop)
setattr(cls, prop, copy.copy(val))
setattr(new_qs, prop, copy.copy(val))
if self._cursor_obj:
cls._cursor_obj = self._cursor_obj.clone()
new_qs._cursor_obj = self._cursor_obj.clone()
return cls
return new_qs
def select_related(self, max_depth=1):
"""Handles dereferencing of :class:`~bson.dbref.DBRef` objects or
@@ -760,7 +758,11 @@ class BaseQuerySet(object):
"""
queryset = self.clone()
queryset._limit = n if n != 0 else 1
# Return self to allow chaining
# If a cursor object has already been created, apply the limit to it.
if queryset._cursor_obj:
queryset._cursor_obj.limit(queryset._limit)
return queryset
def skip(self, n):
@@ -771,6 +773,11 @@ class BaseQuerySet(object):
"""
queryset = self.clone()
queryset._skip = n
# If a cursor object has already been created, apply the skip to it.
if queryset._cursor_obj:
queryset._cursor_obj.skip(queryset._skip)
return queryset
def hint(self, index=None):
@@ -788,6 +795,11 @@ class BaseQuerySet(object):
"""
queryset = self.clone()
queryset._hint = index
# If a cursor object has already been created, apply the hint to it.
if queryset._cursor_obj:
queryset._cursor_obj.hint(queryset._hint)
return queryset
def batch_size(self, size):
@@ -801,6 +813,11 @@ class BaseQuerySet(object):
"""
queryset = self.clone()
queryset._batch_size = size
# If a cursor object has already been created, apply the batch size to it.
if queryset._cursor_obj:
queryset._cursor_obj.batch_size(queryset._batch_size)
return queryset
def distinct(self, field):
@@ -900,18 +917,24 @@ class BaseQuerySet(object):
return self.fields(**fields)
def fields(self, _only_called=False, **kwargs):
"""Manipulate how you load this document's fields. Used by `.only()`
and `.exclude()` to manipulate which fields to retrieve. Fields also
allows for a greater level of control for example:
"""Manipulate how you load this document's fields. Used by `.only()`
and `.exclude()` to manipulate which fields to retrieve. If called
directly, use a set of kwargs similar to the MongoDB projection
document. For example:
Retrieving a Subrange of Array Elements:
Include only a subset of fields:
You can use the $slice operator to retrieve a subrange of elements in
an array. For example to get the first 5 comments::
posts = BlogPost.objects(...).fields(author=1, title=1)
post = BlogPost.objects(...).fields(slice__comments=5)
Exclude a specific field:
:param kwargs: A dictionary identifying what to include
posts = BlogPost.objects(...).fields(comments=0)
To retrieve a subrange of array elements:
posts = BlogPost.objects(...).fields(slice__comments=5)
:param kwargs: A set keywors arguments identifying what to include.
.. versionadded:: 0.5
"""
@@ -927,7 +950,20 @@ class BaseQuerySet(object):
key = '.'.join(parts)
cleaned_fields.append((key, value))
fields = sorted(cleaned_fields, key=operator.itemgetter(1))
# Sort fields by their values, explicitly excluded fields first, then
# explicitly included, and then more complicated operators such as
# $slice.
def _sort_key(field_tuple):
key, value = field_tuple
if isinstance(value, (int)):
return value # 0 for exclusion, 1 for inclusion
else:
return 2 # so that complex values appear last
fields = sorted(cleaned_fields, key=_sort_key)
# Clone the queryset, group all fields by their value, convert
# each of them to db_fields, and set the queryset's _loaded_fields
queryset = self.clone()
for value, group in itertools.groupby(fields, lambda x: x[1]):
fields = [field for field, value in group]
@@ -953,13 +989,31 @@ class BaseQuerySet(object):
def order_by(self, *keys):
"""Order the :class:`~mongoengine.queryset.QuerySet` by the keys. The
order may be specified by prepending each of the keys by a + or a -.
Ascending order is assumed.
Ascending order is assumed. If no keys are passed, existing ordering
is cleared instead.
:param keys: fields to order the query results by; keys may be
prefixed with **+** or **-** to determine the ordering direction
"""
queryset = self.clone()
queryset._ordering = queryset._get_order_by(keys)
old_ordering = queryset._ordering
new_ordering = queryset._get_order_by(keys)
if queryset._cursor_obj:
# If a cursor object has already been created, apply the sort to it
if new_ordering:
queryset._cursor_obj.sort(new_ordering)
# If we're trying to clear a previous explicit ordering, we need
# to clear the cursor entirely (because PyMongo doesn't allow
# clearing an existing sort on a cursor).
elif old_ordering:
queryset._cursor_obj = None
queryset._ordering = new_ordering
return queryset
def comment(self, text):
@@ -1405,10 +1459,13 @@ class BaseQuerySet(object):
raise StopIteration
raw_doc = self._cursor.next()
if self._as_pymongo:
return self._get_as_pymongo(raw_doc)
doc = self._document._from_son(raw_doc,
_auto_dereference=self._auto_dereference, only_fields=self.only_fields)
doc = self._document._from_son(
raw_doc, _auto_dereference=self._auto_dereference,
only_fields=self.only_fields)
if self._scalar:
return self._get_scalar(doc)
@@ -1418,7 +1475,6 @@ class BaseQuerySet(object):
def rewind(self):
"""Rewind the cursor to its unevaluated state.
.. versionadded:: 0.3
"""
self._iter = False
@@ -1468,43 +1524,54 @@ class BaseQuerySet(object):
@property
def _cursor(self):
if self._cursor_obj is None:
"""Return a PyMongo cursor object corresponding to this queryset."""
# In PyMongo 3+, we define the read preference on a collection
# level, not a cursor level. Thus, we need to get a cloned
# collection object using `with_options` first.
if IS_PYMONGO_3 and self._read_preference is not None:
self._cursor_obj = self._collection\
.with_options(read_preference=self._read_preference)\
.find(self._query, **self._cursor_args)
else:
self._cursor_obj = self._collection.find(self._query,
**self._cursor_args)
# Apply where clauses to cursor
if self._where_clause:
where_clause = self._sub_js_fields(self._where_clause)
self._cursor_obj.where(where_clause)
# If _cursor_obj already exists, return it immediately.
if self._cursor_obj is not None:
return self._cursor_obj
if self._ordering:
# Apply query ordering
self._cursor_obj.sort(self._ordering)
elif self._ordering is None and self._document._meta['ordering']:
# Otherwise, apply the ordering from the document model, unless
# it's been explicitly cleared via order_by with no arguments
order = self._get_order_by(self._document._meta['ordering'])
self._cursor_obj.sort(order)
# Create a new PyMongo cursor.
# XXX In PyMongo 3+, we define the read preference on a collection
# level, not a cursor level. Thus, we need to get a cloned collection
# object using `with_options` first.
if IS_PYMONGO_3 and self._read_preference is not None:
self._cursor_obj = self._collection\
.with_options(read_preference=self._read_preference)\
.find(self._query, **self._cursor_args)
else:
self._cursor_obj = self._collection.find(self._query,
**self._cursor_args)
# Apply "where" clauses to cursor
if self._where_clause:
where_clause = self._sub_js_fields(self._where_clause)
self._cursor_obj.where(where_clause)
if self._limit is not None:
self._cursor_obj.limit(self._limit)
# Apply ordering to the cursor.
# XXX self._ordering can be equal to:
# * None if we didn't explicitly call order_by on this queryset.
# * A list of PyMongo-style sorting tuples.
# * An empty list if we explicitly called order_by() without any
# arguments. This indicates that we want to clear the default
# ordering.
if self._ordering:
# explicit ordering
self._cursor_obj.sort(self._ordering)
elif self._ordering is None and self._document._meta['ordering']:
# default ordering
order = self._get_order_by(self._document._meta['ordering'])
self._cursor_obj.sort(order)
if self._skip is not None:
self._cursor_obj.skip(self._skip)
if self._limit is not None:
self._cursor_obj.limit(self._limit)
if self._hint != -1:
self._cursor_obj.hint(self._hint)
if self._skip is not None:
self._cursor_obj.skip(self._skip)
if self._batch_size is not None:
self._cursor_obj.batch_size(self._batch_size)
if self._hint != -1:
self._cursor_obj.hint(self._hint)
if self._batch_size is not None:
self._cursor_obj.batch_size(self._batch_size)
return self._cursor_obj
@@ -1679,7 +1746,13 @@ class BaseQuerySet(object):
return ret
def _get_order_by(self, keys):
"""Creates a list of order by fields"""
"""Given a list of MongoEngine-style sort keys, return a list
of sorting tuples that can be applied to a PyMongo cursor. For
example:
>>> qs._get_order_by(['-last_name', 'first_name'])
[('last_name', -1), ('first_name', 1)]
"""
key_list = []
for key in keys:
if not key:
@@ -1692,17 +1765,19 @@ class BaseQuerySet(object):
direction = pymongo.ASCENDING
if key[0] == '-':
direction = pymongo.DESCENDING
if key[0] in ('-', '+'):
key = key[1:]
key = key.replace('__', '.')
try:
key = self._document._translate_field_name(key)
except Exception:
# TODO this exception should be more specific
pass
key_list.append((key, direction))
if self._cursor_obj and key_list:
self._cursor_obj.sort(key_list)
return key_list
def _get_scalar(self, doc):
@@ -1800,10 +1875,21 @@ class BaseQuerySet(object):
return code
def _chainable_method(self, method_name, val):
"""Call a particular method on the PyMongo cursor call a particular chainable method
with the provided value.
"""
queryset = self.clone()
method = getattr(queryset._cursor, method_name)
method(val)
# Get an existing cursor object or create a new one
cursor = queryset._cursor
# Find the requested method on the cursor and call it with the
# provided value
getattr(cursor, method_name)(val)
# Cache the value on the queryset._{method_name}
setattr(queryset, '_' + method_name, val)
return queryset
# Deprecated

View File

@@ -136,13 +136,15 @@ class QuerySet(BaseQuerySet):
return self._len
def no_cache(self):
"""Convert to a non_caching queryset
"""Convert to a non-caching queryset
.. versionadded:: 0.8.3 Convert to non caching queryset
"""
if self._result_cache is not None:
raise OperationError('QuerySet already cached')
return self.clone_into(QuerySetNoCache(self._document, self._collection))
return self._clone_into(QuerySetNoCache(self._document,
self._collection))
class QuerySetNoCache(BaseQuerySet):
@@ -153,7 +155,7 @@ class QuerySetNoCache(BaseQuerySet):
.. versionadded:: 0.8.3 Convert to caching queryset
"""
return self.clone_into(QuerySet(self._document, self._collection))
return self._clone_into(QuerySet(self._document, self._collection))
def __repr__(self):
"""Provides the string representation of the QuerySet

View File

@@ -101,8 +101,21 @@ def query(_doc_cls=None, **kwargs):
value = value['_id']
elif op in ('in', 'nin', 'all', 'near') and not isinstance(value, dict):
# 'in', 'nin' and 'all' require a list of values
value = [field.prepare_query_value(op, v) for v in value]
# Raise an error if the in/nin/all/near param is not iterable. We need a
# special check for BaseDocument, because - although it's iterable - using
# it as such in the context of this method is most definitely a mistake.
BaseDocument = _import_class('BaseDocument')
if isinstance(value, BaseDocument):
raise TypeError("When using the `in`, `nin`, `all`, or "
"`near`-operators you can\'t use a "
"`Document`, you must wrap your object "
"in a list (object -> [object]).")
elif not hasattr(value, '__iter__'):
raise TypeError("The `in`, `nin`, `all`, or "
"`near`-operators must be applied to an "
"iterable (e.g. a list).")
else:
value = [field.prepare_query_value(op, v) for v in value]
# If we're querying a GenericReferenceField, we need to alter the
# key depending on the value:
@@ -220,8 +233,7 @@ def update(_doc_cls=None, **update):
# Support decrement by flipping a positive value's sign
# and using 'inc'
op = 'inc'
if value > 0:
value = -value
value = -value
elif op == 'add_to_set':
op = 'addToSet'
elif op == 'set_on_insert':

View File

@@ -7,5 +7,5 @@ cover-package=mongoengine
[flake8]
ignore=E501,F401,F403,F405,I201
exclude=build,dist,docs,venv,venv3,.tox,.eggs,tests
max-complexity=45
max-complexity=47
application-import-names=mongoengine,tests

View File

@@ -435,6 +435,15 @@ class InstanceTest(unittest.TestCase):
person.to_dbref()
def test_save_abstract_document(self):
"""Saving an abstract document should fail."""
class Doc(Document):
name = StringField()
meta = {'abstract': True}
with self.assertRaises(InvalidDocumentError):
Doc(name='aaa').save()
def test_reload(self):
"""Ensure that attributes may be reloaded.
"""
@@ -1223,6 +1232,19 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(person.name, None)
self.assertEqual(person.age, None)
def test_update_rename_operator(self):
"""Test the $rename operator."""
coll = self.Person._get_collection()
doc = self.Person(name='John').save()
raw_doc = coll.find_one({'_id': doc.pk})
self.assertEqual(set(raw_doc.keys()), set(['_id', '_cls', 'name']))
doc.update(rename__name='first_name')
raw_doc = coll.find_one({'_id': doc.pk})
self.assertEqual(set(raw_doc.keys()),
set(['_id', '_cls', 'first_name']))
self.assertEqual(raw_doc['first_name'], 'John')
def test_inserts_if_you_set_the_pk(self):
p1 = self.Person(name='p1', id=bson.ObjectId()).save()
p2 = self.Person(name='p2')
@@ -1860,6 +1882,10 @@ class InstanceTest(unittest.TestCase):
'occurs': {"hello": None}
})
# Tests for issue #1438: https://github.com/MongoEngine/mongoengine/issues/1438
with self.assertRaises(ValueError):
Word._from_son('this is not a valid SON dict')
def test_reverse_delete_rule_cascade_and_nullify(self):
"""Ensure that a referenced document is also deleted upon deletion.
"""

View File

@@ -1,13 +1,12 @@
# -*- coding: utf-8 -*-
import six
from nose.plugins.skip import SkipTest
import datetime
import unittest
import uuid
import math
import itertools
import re
from nose.plugins.skip import SkipTest
import six
try:
@@ -27,21 +26,13 @@ from mongoengine import *
from mongoengine.connection import get_db
from mongoengine.base import (BaseDict, BaseField, EmbeddedDocumentList,
_document_registry)
from mongoengine.errors import NotRegistered, DoesNotExist
from tests.utils import MongoDBTestCase
__all__ = ("FieldTest", "EmbeddedDocumentListFieldTestCase")
class FieldTest(unittest.TestCase):
def setUp(self):
connect(db='mongoenginetest')
self.db = get_db()
def tearDown(self):
self.db.drop_collection('fs.files')
self.db.drop_collection('fs.chunks')
self.db.drop_collection('mongoengine.counters')
class FieldTest(MongoDBTestCase):
def test_default_values_nothing_set(self):
"""Ensure that default field values are used when creating a document.
@@ -227,9 +218,9 @@ class FieldTest(unittest.TestCase):
self.assertTrue(isinstance(ret.comp_dt_fld, datetime.datetime))
def test_not_required_handles_none_from_database(self):
"""Ensure that every fields can handle null values from the database.
"""Ensure that every field can handle null values from the
database.
"""
class HandleNoneFields(Document):
str_fld = StringField(required=True)
int_fld = IntField(required=True)
@@ -306,6 +297,24 @@ class FieldTest(unittest.TestCase):
person.id = '497ce96f395f2f052a494fd4'
person.validate()
def test_db_field_validation(self):
"""Ensure that db_field doesn't accept invalid values."""
# dot in the name
with self.assertRaises(ValueError):
class User(Document):
name = StringField(db_field='user.name')
# name starting with $
with self.assertRaises(ValueError):
class User(Document):
name = StringField(db_field='$name')
# name containing a null character
with self.assertRaises(ValueError):
class User(Document):
name = StringField(db_field='name\0')
def test_string_validation(self):
"""Ensure that invalid values cannot be assigned to string fields.
"""
@@ -332,11 +341,12 @@ class FieldTest(unittest.TestCase):
person.validate()
def test_url_validation(self):
"""Ensure that URLFields validate urls properly.
"""
"""Ensure that URLFields validate urls properly."""
class Link(Document):
url = URLField()
Link.drop_collection()
link = Link()
link.url = 'google'
self.assertRaises(ValidationError, link.validate)
@@ -344,6 +354,27 @@ class FieldTest(unittest.TestCase):
link.url = 'http://www.google.com:8080'
link.validate()
def test_unicode_url_validation(self):
"""Ensure unicode URLs are validated properly."""
class Link(Document):
url = URLField()
Link.drop_collection()
link = Link()
link.url = u'http://привет.com'
# TODO fix URL validation - this *IS* a valid URL
# For now we just want to make sure that the error message is correct
try:
link.validate()
self.assertTrue(False)
except ValidationError as e:
self.assertEqual(
unicode(e),
u"ValidationError (Link:None) (Invalid URL: http://\u043f\u0440\u0438\u0432\u0435\u0442.com: ['url'])"
)
def test_url_scheme_validation(self):
"""Ensure that URLFields validate urls with specific schemes properly.
"""
@@ -1042,6 +1073,7 @@ class FieldTest(unittest.TestCase):
self.assertEqual(
BlogPost.objects.filter(info__100__test__exact='test').count(), 0)
# test queries by list
post = BlogPost()
post.info = ['1', '2']
post.save()
@@ -1053,6 +1085,248 @@ class FieldTest(unittest.TestCase):
post.info *= 2
post.save()
self.assertEqual(BlogPost.objects(info=['1', '2', '3', '4', '1', '2', '3', '4']).count(), 1)
BlogPost.drop_collection()
def test_list_field_manipulative_operators(self):
"""Ensure that ListField works with standard list operators that manipulate the list.
"""
class BlogPost(Document):
ref = StringField()
info = ListField(StringField())
BlogPost.drop_collection()
post = BlogPost()
post.ref = "1234"
post.info = ['0', '1', '2', '3', '4', '5']
post.save()
def reset_post():
post.info = ['0', '1', '2', '3', '4', '5']
post.save()
# '__add__(listB)'
# listA+listB
# operator.add(listA, listB)
reset_post()
temp = ['a', 'b']
post.info = post.info + temp
self.assertEqual(post.info, ['0', '1', '2', '3', '4', '5', 'a', 'b'])
post.save()
post.reload()
self.assertEqual(post.info, ['0', '1', '2', '3', '4', '5', 'a', 'b'])
# '__delitem__(index)'
# aka 'del list[index]'
# aka 'operator.delitem(list, index)'
reset_post()
del post.info[2] # del from middle ('2')
self.assertEqual(post.info, ['0', '1', '3', '4', '5'])
post.save()
post.reload()
self.assertEqual(post.info, ['0', '1', '3', '4', '5'])
# '__delitem__(slice(i, j))'
# aka 'del list[i:j]'
# aka 'operator.delitem(list, slice(i,j))'
reset_post()
del post.info[1:3] # removes '1', '2'
self.assertEqual(post.info, ['0', '3', '4', '5'])
post.save()
post.reload()
self.assertEqual(post.info, ['0', '3', '4', '5'])
# '__iadd__'
# aka 'list += list'
reset_post()
temp = ['a', 'b']
post.info += temp
self.assertEqual(post.info, ['0', '1', '2', '3', '4', '5', 'a', 'b'])
post.save()
post.reload()
self.assertEqual(post.info, ['0', '1', '2', '3', '4', '5', 'a', 'b'])
# '__imul__'
# aka 'list *= number'
reset_post()
post.info *= 2
self.assertEqual(post.info, ['0', '1', '2', '3', '4', '5', '0', '1', '2', '3', '4', '5'])
post.save()
post.reload()
self.assertEqual(post.info, ['0', '1', '2', '3', '4', '5', '0', '1', '2', '3', '4', '5'])
# '__mul__'
# aka 'listA*listB'
reset_post()
post.info = post.info * 2
self.assertEqual(post.info, ['0', '1', '2', '3', '4', '5', '0', '1', '2', '3', '4', '5'])
post.save()
post.reload()
self.assertEqual(post.info, ['0', '1', '2', '3', '4', '5', '0', '1', '2', '3', '4', '5'])
# '__rmul__'
# aka 'listB*listA'
reset_post()
post.info = 2 * post.info
self.assertEqual(post.info, ['0', '1', '2', '3', '4', '5', '0', '1', '2', '3', '4', '5'])
post.save()
post.reload()
self.assertEqual(post.info, ['0', '1', '2', '3', '4', '5', '0', '1', '2', '3', '4', '5'])
# '__setitem__(index, value)'
# aka 'list[index]=value'
# aka 'setitem(list, value)'
reset_post()
post.info[4] = 'a'
self.assertEqual(post.info, ['0', '1', '2', '3', 'a', '5'])
post.save()
post.reload()
self.assertEqual(post.info, ['0', '1', '2', '3', 'a', '5'])
# '__setitem__(slice(i, j), listB)'
# aka 'listA[i:j] = listB'
# aka 'setitem(listA, slice(i, j), listB)'
reset_post()
post.info[1:3] = ['h', 'e', 'l', 'l', 'o']
self.assertEqual(post.info, ['0', 'h', 'e', 'l', 'l', 'o', '3', '4', '5'])
post.save()
post.reload()
self.assertEqual(post.info, ['0', 'h', 'e', 'l', 'l', 'o', '3', '4', '5'])
# 'append'
reset_post()
post.info.append('h')
self.assertEqual(post.info, ['0', '1', '2', '3', '4', '5', 'h'])
post.save()
post.reload()
self.assertEqual(post.info, ['0', '1', '2', '3', '4', '5', 'h'])
# 'extend'
reset_post()
post.info.extend(['h', 'e', 'l', 'l', 'o'])
self.assertEqual(post.info, ['0', '1', '2', '3', '4', '5', 'h', 'e', 'l', 'l', 'o'])
post.save()
post.reload()
self.assertEqual(post.info, ['0', '1', '2', '3', '4', '5', 'h', 'e', 'l', 'l', 'o'])
# 'insert'
# 'pop'
reset_post()
x = post.info.pop(2)
y = post.info.pop()
self.assertEqual(post.info, ['0', '1', '3', '4'])
self.assertEqual(x, '2')
self.assertEqual(y, '5')
post.save()
post.reload()
self.assertEqual(post.info, ['0', '1', '3', '4'])
# 'remove'
reset_post()
post.info.remove('2')
self.assertEqual(post.info, ['0', '1', '3', '4', '5'])
post.save()
post.reload()
self.assertEqual(post.info, ['0', '1', '3', '4', '5'])
# 'reverse'
reset_post()
post.info.reverse()
self.assertEqual(post.info, ['5', '4', '3', '2', '1', '0'])
post.save()
post.reload()
self.assertEqual(post.info, ['5', '4', '3', '2', '1', '0'])
# 'sort': though this operator method does manipulate the list, it is tested in
# the 'test_list_field_lexicograpic_operators' function
BlogPost.drop_collection()
def test_list_field_invalid_operators(self):
class BlogPost(Document):
ref = StringField()
info = ListField(StringField())
post = BlogPost()
post.ref = "1234"
post.info = ['0', '1', '2', '3', '4', '5']
# '__hash__'
# aka 'hash(list)'
# # assert TypeError
self.assertRaises(TypeError, lambda: hash(post.info))
def test_list_field_lexicographic_operators(self):
"""Ensure that ListField works with standard list operators that do lexigraphic ordering.
"""
class BlogPost(Document):
ref = StringField()
text_info = ListField(StringField())
oid_info = ListField(ObjectIdField())
bool_info = ListField(BooleanField())
BlogPost.drop_collection()
blogSmall = BlogPost(ref="small")
blogSmall.text_info = ["a", "a", "a"]
blogSmall.bool_info = [False, False]
blogSmall.save()
blogSmall.reload()
blogLargeA = BlogPost(ref="big")
blogLargeA.text_info = ["a", "z", "j"]
blogLargeA.bool_info = [False, True]
blogLargeA.save()
blogLargeA.reload()
blogLargeB = BlogPost(ref="big2")
blogLargeB.text_info = ["a", "z", "j"]
blogLargeB.oid_info = [
"54495ad94c934721ede76f90",
"54495ad94c934721ede76d23",
"54495ad94c934721ede76d00"
]
blogLargeB.bool_info = [False, True]
blogLargeB.save()
blogLargeB.reload()
# '__eq__' aka '=='
self.assertEqual(blogLargeA.text_info, blogLargeB.text_info)
self.assertEqual(blogLargeA.bool_info, blogLargeB.bool_info)
# '__ge__' aka '>='
self.assertGreaterEqual(blogLargeA.text_info, blogSmall.text_info)
self.assertGreaterEqual(blogLargeA.text_info, blogLargeB.text_info)
self.assertGreaterEqual(blogLargeA.bool_info, blogSmall.bool_info)
self.assertGreaterEqual(blogLargeA.bool_info, blogLargeB.bool_info)
# '__gt__' aka '>'
self.assertGreaterEqual(blogLargeA.text_info, blogSmall.text_info)
self.assertGreaterEqual(blogLargeA.bool_info, blogSmall.bool_info)
# '__le__' aka '<='
self.assertLessEqual(blogSmall.text_info, blogLargeB.text_info)
self.assertLessEqual(blogLargeA.text_info, blogLargeB.text_info)
self.assertLessEqual(blogSmall.bool_info, blogLargeB.bool_info)
self.assertLessEqual(blogLargeA.bool_info, blogLargeB.bool_info)
# '__lt__' aka '<'
self.assertLess(blogSmall.text_info, blogLargeB.text_info)
self.assertLess(blogSmall.bool_info, blogLargeB.bool_info)
# '__ne__' aka '!='
self.assertNotEqual(blogSmall.text_info, blogLargeB.text_info)
self.assertNotEqual(blogSmall.bool_info, blogLargeB.bool_info)
# 'sort'
blogLargeB.bool_info = [True, False, True, False]
blogLargeB.text_info.sort()
blogLargeB.oid_info.sort()
blogLargeB.bool_info.sort()
sorted_target_list = [
ObjectId("54495ad94c934721ede76d00"),
ObjectId("54495ad94c934721ede76d23"),
ObjectId("54495ad94c934721ede76f90")
]
self.assertEqual(blogLargeB.text_info, ["a", "j", "z"])
self.assertEqual(blogLargeB.oid_info, sorted_target_list)
self.assertEqual(blogLargeB.bool_info, [False, False, True, True])
blogLargeB.save()
blogLargeB.reload()
self.assertEqual(blogLargeB.text_info, ["a", "j", "z"])
self.assertEqual(blogLargeB.oid_info, sorted_target_list)
self.assertEqual(blogLargeB.bool_info, [False, False, True, True])
BlogPost.drop_collection()
def test_list_assignment(self):
@@ -1102,7 +1376,6 @@ class FieldTest(unittest.TestCase):
post.reload()
self.assertEqual(post.info, [1, 2, 3, 4, 'n5'])
def test_list_field_passed_in_value(self):
class Foo(Document):
bars = ListField(ReferenceField("Bar"))
@@ -1725,7 +1998,7 @@ class FieldTest(unittest.TestCase):
self.assertEqual(content, User.objects.first().groups[0].content)
def test_reference_miss(self):
"""Ensure an exception is raised when dereferencing unknow document
"""Ensure an exception is raised when dereferencing unknown document
"""
class Foo(Document):
@@ -2926,26 +3199,42 @@ class FieldTest(unittest.TestCase):
att.delete()
self.assertEqual(0, Attachment.objects.count())
def test_choices_validation(self):
"""Ensure that value is in a container of allowed values.
def test_choices_allow_using_sets_as_choices(self):
"""Ensure that sets can be used when setting choices
"""
class Shirt(Document):
size = StringField(max_length=3, choices=(
('S', 'Small'), ('M', 'Medium'), ('L', 'Large'),
('XL', 'Extra Large'), ('XXL', 'Extra Extra Large')))
size = StringField(choices={'M', 'L'})
Shirt.drop_collection()
Shirt(size='M').validate()
def test_choices_validation_allow_no_value(self):
"""Ensure that .validate passes and no value was provided
for a field setup with choices
"""
class Shirt(Document):
size = StringField(choices=('S', 'M'))
shirt = Shirt()
shirt.validate()
shirt.size = "S"
def test_choices_validation_accept_possible_value(self):
"""Ensure that value is in a container of allowed values.
"""
class Shirt(Document):
size = StringField(choices=('S', 'M'))
shirt = Shirt(size='S')
shirt.validate()
shirt.size = "XS"
self.assertRaises(ValidationError, shirt.validate)
def test_choices_validation_reject_unknown_value(self):
"""Ensure that unallowed value are rejected upon validation
"""
class Shirt(Document):
size = StringField(choices=('S', 'M'))
Shirt.drop_collection()
shirt = Shirt(size="XS")
with self.assertRaises(ValidationError):
shirt.validate()
def test_choices_validation_documents(self):
"""
@@ -3731,30 +4020,25 @@ class FieldTest(unittest.TestCase):
"""Tests if a `FieldDoesNotExist` exception is raised when trying to
instanciate a document with a field that's not defined.
"""
class Doc(Document):
foo = StringField(db_field='f')
foo = StringField()
def test():
with self.assertRaises(FieldDoesNotExist):
Doc(bar='test')
self.assertRaises(FieldDoesNotExist, test)
def test_undefined_field_exception_with_strict(self):
"""Tests if a `FieldDoesNotExist` exception is raised when trying to
instanciate a document with a field that's not defined,
even when strict is set to False.
"""
class Doc(Document):
foo = StringField(db_field='f')
foo = StringField()
meta = {'strict': False}
def test():
with self.assertRaises(FieldDoesNotExist):
Doc(bar='test')
self.assertRaises(FieldDoesNotExist, test)
def test_long_field_is_considered_as_int64(self):
"""
Tests that long fields are stored as long in mongo, even if long value
@@ -3769,12 +4053,13 @@ class FieldTest(unittest.TestCase):
self.assertTrue(isinstance(doc.some_long, six.integer_types))
class EmbeddedDocumentListFieldTestCase(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.db = connect(db='EmbeddedDocumentListFieldTestCase')
class EmbeddedDocumentListFieldTestCase(MongoDBTestCase):
def setUp(self):
"""
Create two BlogPost entries in the database, each with
several EmbeddedDocuments.
"""
class Comments(EmbeddedDocument):
author = StringField()
message = StringField()
@@ -3782,14 +4067,11 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase):
class BlogPost(Document):
comments = EmbeddedDocumentListField(Comments)
cls.Comments = Comments
cls.BlogPost = BlogPost
BlogPost.drop_collection()
self.Comments = Comments
self.BlogPost = BlogPost
def setUp(self):
"""
Create two BlogPost entries in the database, each with
several EmbeddedDocuments.
"""
self.post1 = self.BlogPost(comments=[
self.Comments(author='user1', message='message1'),
self.Comments(author='user2', message='message1')
@@ -3801,13 +4083,6 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase):
self.Comments(author='user3', message='message1')
]).save()
def tearDown(self):
self.BlogPost.drop_collection()
@classmethod
def tearDownClass(cls):
cls.db.drop_database('EmbeddedDocumentListFieldTestCase')
def test_no_keyword_filter(self):
"""
Tests the filter method of a List of Embedded Documents
@@ -4165,7 +4440,8 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase):
my_list = ListField(EmbeddedDocumentField(EmbeddedWithUnique))
A(my_list=[]).save()
self.assertRaises(NotUniqueError, lambda: A(my_list=[]).save())
with self.assertRaises(NotUniqueError):
A(my_list=[]).save()
class EmbeddedWithSparseUnique(EmbeddedDocument):
number = IntField(unique=True, sparse=True)
@@ -4173,6 +4449,9 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase):
class B(Document):
my_list = ListField(EmbeddedDocumentField(EmbeddedWithSparseUnique))
A.drop_collection()
B.drop_collection()
B(my_list=[]).save()
B(my_list=[]).save()
@@ -4212,6 +4491,8 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase):
a_field = IntField()
c_field = IntField(custom_data=custom_data)
CustomData.drop_collection()
a1 = CustomData(a_field=1, c_field=2).save()
self.assertEqual(2, a1.c_field)
self.assertFalse(hasattr(a1.c_field, 'custom_data'))

View File

@@ -18,15 +18,13 @@ try:
except ImportError:
HAS_PIL = False
from tests.utils import MongoDBTestCase
TEST_IMAGE_PATH = os.path.join(os.path.dirname(__file__), 'mongoengine.png')
TEST_IMAGE2_PATH = os.path.join(os.path.dirname(__file__), 'mongodb_leaf.png')
class FileTest(unittest.TestCase):
def setUp(self):
connect(db='mongoenginetest')
self.db = get_db()
class FileTest(MongoDBTestCase):
def tearDown(self):
self.db.drop_collection('fs.files')

View File

@@ -141,6 +141,16 @@ class OnlyExcludeAllTest(unittest.TestCase):
self.assertEqual(qs._loaded_fields.as_dict(),
{'b': {'$slice': 5}})
def test_mix_slice_with_other_fields(self):
class MyDoc(Document):
a = ListField()
b = ListField()
c = ListField()
qs = MyDoc.objects.fields(a=1, b=0, slice__c=2)
self.assertEqual(qs._loaded_fields.as_dict(),
{'c': {'$slice': 2}, 'a': 1})
def test_only(self):
"""Ensure that QuerySet.only only returns the requested fields.
"""

View File

@@ -106,58 +106,111 @@ class QuerySetTest(unittest.TestCase):
list(BlogPost.objects(author2__name="test"))
def test_find(self):
"""Ensure that a query returns a valid set of results.
"""
self.Person(name="User A", age=20).save()
self.Person(name="User B", age=30).save()
"""Ensure that a query returns a valid set of results."""
user_a = self.Person.objects.create(name='User A', age=20)
user_b = self.Person.objects.create(name='User B', age=30)
# Find all people in the collection
people = self.Person.objects
self.assertEqual(people.count(), 2)
results = list(people)
self.assertTrue(isinstance(results[0], self.Person))
self.assertTrue(isinstance(results[0].id, (ObjectId, str, unicode)))
self.assertEqual(results[0].name, "User A")
self.assertEqual(results[0], user_a)
self.assertEqual(results[0].name, 'User A')
self.assertEqual(results[0].age, 20)
self.assertEqual(results[1].name, "User B")
self.assertEqual(results[1], user_b)
self.assertEqual(results[1].name, 'User B')
self.assertEqual(results[1].age, 30)
# Use a query to filter the people found to just person1
# Filter people by age
people = self.Person.objects(age=20)
self.assertEqual(people.count(), 1)
person = people.next()
self.assertEqual(person, user_a)
self.assertEqual(person.name, "User A")
self.assertEqual(person.age, 20)
# Test limit
def test_limit(self):
"""Ensure that QuerySet.limit works as expected."""
user_a = self.Person.objects.create(name='User A', age=20)
user_b = self.Person.objects.create(name='User B', age=30)
# Test limit on a new queryset
people = list(self.Person.objects.limit(1))
self.assertEqual(len(people), 1)
self.assertEqual(people[0].name, 'User A')
self.assertEqual(people[0], user_a)
# Test skip
# Test limit on an existing queryset
people = self.Person.objects
self.assertEqual(len(people), 2)
people2 = people.limit(1)
self.assertEqual(len(people), 2)
self.assertEqual(len(people2), 1)
self.assertEqual(people2[0], user_a)
# Test chaining of only after limit
person = self.Person.objects().limit(1).only('name').first()
self.assertEqual(person, user_a)
self.assertEqual(person.name, 'User A')
self.assertEqual(person.age, None)
def test_skip(self):
"""Ensure that QuerySet.skip works as expected."""
user_a = self.Person.objects.create(name='User A', age=20)
user_b = self.Person.objects.create(name='User B', age=30)
# Test skip on a new queryset
people = list(self.Person.objects.skip(1))
self.assertEqual(len(people), 1)
self.assertEqual(people[0].name, 'User B')
self.assertEqual(people[0], user_b)
person3 = self.Person(name="User C", age=40)
person3.save()
# Test skip on an existing queryset
people = self.Person.objects
self.assertEqual(len(people), 2)
people2 = people.skip(1)
self.assertEqual(len(people), 2)
self.assertEqual(len(people2), 1)
self.assertEqual(people2[0], user_b)
# Test chaining of only after skip
person = self.Person.objects().skip(1).only('name').first()
self.assertEqual(person, user_b)
self.assertEqual(person.name, 'User B')
self.assertEqual(person.age, None)
def test_slice(self):
"""Ensure slicing a queryset works as expected."""
user_a = self.Person.objects.create(name='User A', age=20)
user_b = self.Person.objects.create(name='User B', age=30)
user_c = self.Person.objects.create(name="User C", age=40)
# Test slice limit
people = list(self.Person.objects[:2])
self.assertEqual(len(people), 2)
self.assertEqual(people[0].name, 'User A')
self.assertEqual(people[1].name, 'User B')
self.assertEqual(people[0], user_a)
self.assertEqual(people[1], user_b)
# Test slice skip
people = list(self.Person.objects[1:])
self.assertEqual(len(people), 2)
self.assertEqual(people[0].name, 'User B')
self.assertEqual(people[1].name, 'User C')
self.assertEqual(people[0], user_b)
self.assertEqual(people[1], user_c)
# Test slice limit and skip
people = list(self.Person.objects[1:2])
self.assertEqual(len(people), 1)
self.assertEqual(people[0].name, 'User B')
self.assertEqual(people[0], user_b)
# Test slice limit and skip on an existing queryset
people = self.Person.objects
self.assertEqual(len(people), 3)
people2 = people[1:2]
self.assertEqual(len(people2), 1)
self.assertEqual(people2[0], user_b)
# Test slice limit and skip cursor reset
qs = self.Person.objects[1:2]
@@ -168,6 +221,7 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(len(people), 1)
self.assertEqual(people[0].name, 'User B')
# Test empty slice
people = list(self.Person.objects[1:1])
self.assertEqual(len(people), 0)
@@ -187,12 +241,6 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual("[<Person: Person object>, <Person: Person object>]",
"%s" % self.Person.objects[51:53])
# Test only after limit
self.assertEqual(self.Person.objects().limit(2).only('name')[0].age, None)
# Test only after skip
self.assertEqual(self.Person.objects().skip(2).only('name')[0].age, None)
def test_find_one(self):
"""Ensure that a query using find_one returns a valid result.
"""
@@ -1226,6 +1274,7 @@ class QuerySetTest(unittest.TestCase):
BlogPost.drop_collection()
# default ordering should be used by default
with db_ops_tracker() as q:
BlogPost.objects.filter(title='whatever').first()
self.assertEqual(len(q.get_ops()), 1)
@@ -1234,11 +1283,28 @@ class QuerySetTest(unittest.TestCase):
{'published_date': -1}
)
# calling order_by() should clear the default ordering
with db_ops_tracker() as q:
BlogPost.objects.filter(title='whatever').order_by().first()
self.assertEqual(len(q.get_ops()), 1)
self.assertFalse('$orderby' in q.get_ops()[0]['query'])
# calling an explicit order_by should use a specified sort
with db_ops_tracker() as q:
BlogPost.objects.filter(title='whatever').order_by('published_date').first()
self.assertEqual(len(q.get_ops()), 1)
self.assertEqual(
q.get_ops()[0]['query']['$orderby'],
{'published_date': 1}
)
# calling order_by() after an explicit sort should clear it
with db_ops_tracker() as q:
qs = BlogPost.objects.filter(title='whatever').order_by('published_date')
qs.order_by().first()
self.assertEqual(len(q.get_ops()), 1)
self.assertFalse('$orderby' in q.get_ops()[0]['query'])
def test_no_ordering_for_get(self):
""" Ensure that Doc.objects.get doesn't use any ordering.
"""
@@ -1266,7 +1332,7 @@ class QuerySetTest(unittest.TestCase):
def test_find_embedded(self):
"""Ensure that an embedded document is properly returned from
a query.
different manners of querying.
"""
class User(EmbeddedDocument):
name = StringField()
@@ -1277,8 +1343,9 @@ class QuerySetTest(unittest.TestCase):
BlogPost.drop_collection()
user = User(name='Test User')
BlogPost.objects.create(
author=User(name='Test User'),
author=user,
content='Had a good coffee today...'
)
@@ -1286,6 +1353,19 @@ class QuerySetTest(unittest.TestCase):
self.assertTrue(isinstance(result.author, User))
self.assertEqual(result.author.name, 'Test User')
result = BlogPost.objects.get(author__name=user.name)
self.assertTrue(isinstance(result.author, User))
self.assertEqual(result.author.name, 'Test User')
result = BlogPost.objects.get(author={'name': user.name})
self.assertTrue(isinstance(result.author, User))
self.assertEqual(result.author.name, 'Test User')
# Fails, since the string is not a type that is able to represent the
# author's document structure (should be dict)
with self.assertRaises(InvalidQueryError):
BlogPost.objects.get(author=user.name)
def test_find_empty_embedded(self):
"""Ensure that you can save and find an empty embedded document."""
class User(EmbeddedDocument):
@@ -1812,6 +1892,11 @@ class QuerySetTest(unittest.TestCase):
post.reload()
self.assertEqual(post.hits, 10)
# Negative dec operator is equal to a positive inc operator
BlogPost.objects.update_one(dec__hits=-1)
post.reload()
self.assertEqual(post.hits, 11)
BlogPost.objects.update(push__tags='mongo')
post.reload()
self.assertTrue('mongo' in post.tags)
@@ -4963,6 +5048,35 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(i, 249)
self.assertEqual(j, 249)
def test_in_operator_on_non_iterable(self):
"""Ensure that using the `__in` operator on a non-iterable raises an
error.
"""
class User(Document):
name = StringField()
class BlogPost(Document):
content = StringField()
authors = ListField(ReferenceField(User))
User.drop_collection()
BlogPost.drop_collection()
author = User.objects.create(name='Test User')
post = BlogPost.objects.create(content='Had a good coffee today...',
authors=[author])
# Make sure using `__in` with a list works
blog_posts = BlogPost.objects(authors__in=[author])
self.assertEqual(list(blog_posts), [post])
# Using `__in` with a non-iterable should raise a TypeError
self.assertRaises(TypeError, BlogPost.objects(authors__in=author.pk).count)
# Using `__in` with a `Document` (which is seemingly iterable but not
# in a way we'd expect) should raise a TypeError, too
self.assertRaises(TypeError, BlogPost.objects(authors__in=author).count)
if __name__ == '__main__':
unittest.main()

View File

@@ -200,6 +200,19 @@ class ConnectionTest(unittest.TestCase):
self.assertTrue(isinstance(db, pymongo.database.Database))
self.assertEqual(db.name, 'test')
def test_connect_uri_with_replicaset(self):
"""Ensure connect() works when specifying a replicaSet."""
if IS_PYMONGO_3:
c = connect(host='mongodb://localhost/test?replicaSet=local-rs')
db = get_db()
self.assertTrue(isinstance(db, pymongo.database.Database))
self.assertEqual(db.name, 'test')
else:
# PyMongo < v3.x raises an exception:
# "localhost:27017 is not a member of replica set local-rs"
with self.assertRaises(MongoEngineConnectionError):
c = connect(host='mongodb://localhost/test?replicaSet=local-rs')
def test_uri_without_credentials_doesnt_override_conn_settings(self):
"""Ensure connect() uses the username & password params if the URI
doesn't explicitly specify them.
@@ -272,8 +285,7 @@ class ConnectionTest(unittest.TestCase):
self.assertTrue(isinstance(conn, pymongo.mongo_client.MongoClient))
def test_connection_kwargs(self):
"""Ensure that connection kwargs get passed to pymongo.
"""
"""Ensure that connection kwargs get passed to pymongo."""
connect('mongoenginetest', alias='t1', tz_aware=True)
conn = get_connection('t1')
@@ -283,6 +295,45 @@ class ConnectionTest(unittest.TestCase):
conn = get_connection('t2')
self.assertFalse(get_tz_awareness(conn))
def test_connection_pool_via_kwarg(self):
"""Ensure we can specify a max connection pool size using
a connection kwarg.
"""
# Use "max_pool_size" or "maxpoolsize" depending on PyMongo version
# (former was changed to the latter as described in
# https://jira.mongodb.org/browse/PYTHON-854).
# TODO remove once PyMongo < 3.0 support is dropped
if pymongo.version_tuple[0] >= 3:
pool_size_kwargs = {'maxpoolsize': 100}
else:
pool_size_kwargs = {'max_pool_size': 100}
conn = connect('mongoenginetest', alias='max_pool_size_via_kwarg', **pool_size_kwargs)
self.assertEqual(conn.max_pool_size, 100)
def test_connection_pool_via_uri(self):
"""Ensure we can specify a max connection pool size using
an option in a connection URI.
"""
if pymongo.version_tuple[0] == 2 and pymongo.version_tuple[1] < 9:
raise SkipTest('maxpoolsize as a URI option is only supported in PyMongo v2.9+')
conn = connect(host='mongodb://localhost/test?maxpoolsize=100', alias='max_pool_size_via_uri')
self.assertEqual(conn.max_pool_size, 100)
def test_write_concern(self):
"""Ensure write concern can be specified in connect() via
a kwarg or as part of the connection URI.
"""
conn1 = connect(alias='conn1', host='mongodb://localhost/testing?w=1&j=true')
conn2 = connect('testing', alias='conn2', w=1, j=True)
if IS_PYMONGO_3:
self.assertEqual(conn1.write_concern.document, {'w': 1, 'j': True})
self.assertEqual(conn2.write_concern.document, {'w': 1, 'j': True})
else:
self.assertEqual(dict(conn1.write_concern), {'w': 1, 'j': True})
self.assertEqual(dict(conn2.write_concern), {'w': 1, 'j': True})
def test_datetime(self):
connect('mongoenginetest', tz_aware=True)
d = datetime.datetime(2010, 5, 5, tzinfo=utc)

22
tests/utils.py Normal file
View File

@@ -0,0 +1,22 @@
import unittest
from mongoengine import connect
from mongoengine.connection import get_db
MONGO_TEST_DB = 'mongoenginetest'
class MongoDBTestCase(unittest.TestCase):
"""Base class for tests that need a mongodb connection
db is being dropped automatically
"""
@classmethod
def setUpClass(cls):
cls._connection = connect(db=MONGO_TEST_DB)
cls._connection.drop_database(MONGO_TEST_DB)
cls.db = get_db()
@classmethod
def tearDownClass(cls):
cls._connection.drop_database(MONGO_TEST_DB)