Compare commits

...

9 Commits

Author SHA1 Message Date
Stefan Wojcik
4e8bb14131 skip uri test for pymongo < v2.9 2017-02-25 13:40:51 -05:00
Stefan Wojcik
9cc4fad614 dummy 2017-02-25 13:15:52 -05:00
Stefan Wojcik
2a486ee537 better connection docstrings [ci skip] 2017-02-25 12:56:36 -05:00
Stefan Wojcik
2579ed754f add unit tests for setting the connection pool size 2017-02-25 12:48:03 -05:00
Stefan Wójcik
3f31666796 Fix the exception message when validating unicode URLs (#1486) 2017-02-24 16:18:34 -05:00
Stefan Wojcik
3fe8031cf3 fix EmbeddedDocumentListFieldTestCase 2017-02-22 12:44:05 -05:00
bagerard
b27c7ce11b allow to use sets in field choices (#1482) 2017-02-15 08:51:47 -05:00
Stefan Wojcik
ed34c2ca68 update the changelog and upgrade docs 2017-02-09 12:13:56 -08:00
Stefan Wójcik
3ca2e953fb Fix limit/skip/hint/batch_size chaining (#1476) 2017-02-09 12:02:46 -08:00
13 changed files with 378 additions and 149 deletions

View File

@@ -4,7 +4,10 @@ Changelog
Development
===========
- (Fill this out as you fix issues and develop you features).
- (Fill this out as you fix issues and develop your features).
- Fixed using sets in field choices #1481
- POTENTIAL BREAKING CHANGE: Fixed limit/skip/hint/batch_size chaining #1476
- POTENTIAL BREAKING CHANGE: Changed a public `QuerySet.clone_into` method to a private `QuerySet._clone_into` #1476
- Fixed connecting to a replica set with PyMongo 2.x #1436
- Fixed an obscure error message when filtering by `field__in=non_iterable`. #1237

View File

@@ -150,7 +150,7 @@ arguments can be set on all fields:
.. note:: If set, this field is also accessible through the `pk` field.
:attr:`choices` (Default: None)
An iterable (e.g. a list or tuple) of choices to which the value of this
An iterable (e.g. list, tuple or set) of choices to which the value of this
field should be limited.
Can be either be a nested tuples of value (stored in mongo) and a
@@ -214,8 +214,8 @@ document class as the first argument::
Dictionary Fields
-----------------
Often, an embedded document may be used instead of a dictionary generally
embedded documents are recommended as dictionaries dont support validation
Often, an embedded document may be used instead of a dictionary generally
embedded documents are recommended as dictionaries dont support validation
or custom field types. However, sometimes you will not know the structure of what you want to
store; in this situation a :class:`~mongoengine.fields.DictField` is appropriate::

View File

@@ -2,6 +2,20 @@
Upgrading
#########
Development
***********
(Fill this out whenever you introduce breaking changes to MongoEngine)
This release includes various fixes for the `BaseQuerySet` methods and how they
are chained together. Since version 0.10.1 applying limit/skip/hint/batch_size
to an already-existing queryset wouldn't modify the underlying PyMongo cursor.
This has been fixed now, so you'll need to make sure that your code didn't rely
on the broken implementation.
Additionally, a public `BaseQuerySet.clone_into` has been renamed to a private
`_clone_into`. If you directly used that method in your code, you'll need to
rename its occurrences.
0.11.0
******
This release includes a major rehaul of MongoEngine's code quality and

View File

@@ -193,7 +193,8 @@ class BaseField(object):
EmbeddedDocument = _import_class('EmbeddedDocument')
choice_list = self.choices
if isinstance(choice_list[0], (list, tuple)):
if isinstance(next(iter(choice_list)), (list, tuple)):
# next(iter) is useful for sets
choice_list = [k for k, _ in choice_list]
# Choices which are other types of Documents

View File

@@ -51,7 +51,9 @@ def register_connection(alias, name=None, host=None, port=None,
MONGODB-CR (MongoDB Challenge Response protocol) for older servers.
:param is_mock: explicitly use mongomock for this connection
(can also be done by using `mongomock://` as db host prefix)
:param kwargs: allow ad-hoc parameters to be passed into the pymongo driver
:param kwargs: ad-hoc parameters to be passed into the pymongo driver,
for example maxpoolsize, tz_aware, etc. See the documentation
for pymongo's `MongoClient` for a full list.
.. versionchanged:: 0.10.6 - added mongomock support
"""
@@ -241,9 +243,12 @@ def connect(db=None, alias=DEFAULT_CONNECTION_NAME, **kwargs):
running on the default port on localhost. If authentication is needed,
provide username and password arguments as well.
Multiple databases are supported by using aliases. Provide a separate
Multiple databases are supported by using aliases. Provide a separate
`alias` to connect to a different instance of :program:`mongod`.
See the docstring for `register_connection` for more details about all
supported kwargs.
.. versionchanged:: 0.6 - added multiple database support.
"""
if alias not in _connections:

View File

@@ -139,12 +139,12 @@ class URLField(StringField):
# Check first if the scheme is valid
scheme = value.split('://')[0].lower()
if scheme not in self.schemes:
self.error('Invalid scheme {} in URL: {}'.format(scheme, value))
self.error(u'Invalid scheme {} in URL: {}'.format(scheme, value))
return
# Then check full URL
if not self.url_regex.match(value):
self.error('Invalid URL: {}'.format(value))
self.error(u'Invalid URL: {}'.format(value))
return

View File

@@ -86,6 +86,7 @@ class BaseQuerySet(object):
self._batch_size = None
self.only_fields = []
self._max_time_ms = None
self._comment = None
def __call__(self, q_obj=None, class_check=True, read_preference=None,
**query):
@@ -706,39 +707,36 @@ class BaseQuerySet(object):
with switch_db(self._document, alias) as cls:
collection = cls._get_collection()
return self.clone_into(self.__class__(self._document, collection))
return self._clone_into(self.__class__(self._document, collection))
def clone(self):
"""Creates a copy of the current
:class:`~mongoengine.queryset.QuerySet`
"""Create a copy of the current queryset."""
return self._clone_into(self.__class__(self._document, self._collection_obj))
.. versionadded:: 0.5
def _clone_into(self, new_qs):
"""Copy all of the relevant properties of this queryset to
a new queryset (which has to be an instance of
:class:`~mongoengine.queryset.base.BaseQuerySet`).
"""
return self.clone_into(self.__class__(self._document, self._collection_obj))
def clone_into(self, cls):
"""Creates a copy of the current
:class:`~mongoengine.queryset.base.BaseQuerySet` into another child class
"""
if not isinstance(cls, BaseQuerySet):
if not isinstance(new_qs, BaseQuerySet):
raise OperationError(
'%s is not a subclass of BaseQuerySet' % cls.__name__)
'%s is not a subclass of BaseQuerySet' % new_qs.__name__)
copy_props = ('_mongo_query', '_initial_query', '_none', '_query_obj',
'_where_clause', '_loaded_fields', '_ordering', '_snapshot',
'_timeout', '_class_check', '_slave_okay', '_read_preference',
'_iter', '_scalar', '_as_pymongo', '_as_pymongo_coerce',
'_limit', '_skip', '_hint', '_auto_dereference',
'_search_text', 'only_fields', '_max_time_ms')
'_search_text', 'only_fields', '_max_time_ms', '_comment')
for prop in copy_props:
val = getattr(self, prop)
setattr(cls, prop, copy.copy(val))
setattr(new_qs, prop, copy.copy(val))
if self._cursor_obj:
cls._cursor_obj = self._cursor_obj.clone()
new_qs._cursor_obj = self._cursor_obj.clone()
return cls
return new_qs
def select_related(self, max_depth=1):
"""Handles dereferencing of :class:`~bson.dbref.DBRef` objects or
@@ -760,7 +758,11 @@ class BaseQuerySet(object):
"""
queryset = self.clone()
queryset._limit = n if n != 0 else 1
# Return self to allow chaining
# If a cursor object has already been created, apply the limit to it.
if queryset._cursor_obj:
queryset._cursor_obj.limit(queryset._limit)
return queryset
def skip(self, n):
@@ -771,6 +773,11 @@ class BaseQuerySet(object):
"""
queryset = self.clone()
queryset._skip = n
# If a cursor object has already been created, apply the skip to it.
if queryset._cursor_obj:
queryset._cursor_obj.skip(queryset._skip)
return queryset
def hint(self, index=None):
@@ -788,6 +795,11 @@ class BaseQuerySet(object):
"""
queryset = self.clone()
queryset._hint = index
# If a cursor object has already been created, apply the hint to it.
if queryset._cursor_obj:
queryset._cursor_obj.hint(queryset._hint)
return queryset
def batch_size(self, size):
@@ -801,6 +813,11 @@ class BaseQuerySet(object):
"""
queryset = self.clone()
queryset._batch_size = size
# If a cursor object has already been created, apply the batch size to it.
if queryset._cursor_obj:
queryset._cursor_obj.batch_size(queryset._batch_size)
return queryset
def distinct(self, field):
@@ -972,13 +989,31 @@ class BaseQuerySet(object):
def order_by(self, *keys):
"""Order the :class:`~mongoengine.queryset.QuerySet` by the keys. The
order may be specified by prepending each of the keys by a + or a -.
Ascending order is assumed.
Ascending order is assumed. If no keys are passed, existing ordering
is cleared instead.
:param keys: fields to order the query results by; keys may be
prefixed with **+** or **-** to determine the ordering direction
"""
queryset = self.clone()
queryset._ordering = queryset._get_order_by(keys)
old_ordering = queryset._ordering
new_ordering = queryset._get_order_by(keys)
if queryset._cursor_obj:
# If a cursor object has already been created, apply the sort to it
if new_ordering:
queryset._cursor_obj.sort(new_ordering)
# If we're trying to clear a previous explicit ordering, we need
# to clear the cursor entirely (because PyMongo doesn't allow
# clearing an existing sort on a cursor).
elif old_ordering:
queryset._cursor_obj = None
queryset._ordering = new_ordering
return queryset
def comment(self, text):
@@ -1424,10 +1459,13 @@ class BaseQuerySet(object):
raise StopIteration
raw_doc = self._cursor.next()
if self._as_pymongo:
return self._get_as_pymongo(raw_doc)
doc = self._document._from_son(raw_doc,
_auto_dereference=self._auto_dereference, only_fields=self.only_fields)
doc = self._document._from_son(
raw_doc, _auto_dereference=self._auto_dereference,
only_fields=self.only_fields)
if self._scalar:
return self._get_scalar(doc)
@@ -1437,7 +1475,6 @@ class BaseQuerySet(object):
def rewind(self):
"""Rewind the cursor to its unevaluated state.
.. versionadded:: 0.3
"""
self._iter = False
@@ -1487,43 +1524,54 @@ class BaseQuerySet(object):
@property
def _cursor(self):
if self._cursor_obj is None:
"""Return a PyMongo cursor object corresponding to this queryset."""
# In PyMongo 3+, we define the read preference on a collection
# level, not a cursor level. Thus, we need to get a cloned
# collection object using `with_options` first.
if IS_PYMONGO_3 and self._read_preference is not None:
self._cursor_obj = self._collection\
.with_options(read_preference=self._read_preference)\
.find(self._query, **self._cursor_args)
else:
self._cursor_obj = self._collection.find(self._query,
**self._cursor_args)
# Apply where clauses to cursor
if self._where_clause:
where_clause = self._sub_js_fields(self._where_clause)
self._cursor_obj.where(where_clause)
# If _cursor_obj already exists, return it immediately.
if self._cursor_obj is not None:
return self._cursor_obj
if self._ordering:
# Apply query ordering
self._cursor_obj.sort(self._ordering)
elif self._ordering is None and self._document._meta['ordering']:
# Otherwise, apply the ordering from the document model, unless
# it's been explicitly cleared via order_by with no arguments
order = self._get_order_by(self._document._meta['ordering'])
self._cursor_obj.sort(order)
# Create a new PyMongo cursor.
# XXX In PyMongo 3+, we define the read preference on a collection
# level, not a cursor level. Thus, we need to get a cloned collection
# object using `with_options` first.
if IS_PYMONGO_3 and self._read_preference is not None:
self._cursor_obj = self._collection\
.with_options(read_preference=self._read_preference)\
.find(self._query, **self._cursor_args)
else:
self._cursor_obj = self._collection.find(self._query,
**self._cursor_args)
# Apply "where" clauses to cursor
if self._where_clause:
where_clause = self._sub_js_fields(self._where_clause)
self._cursor_obj.where(where_clause)
if self._limit is not None:
self._cursor_obj.limit(self._limit)
# Apply ordering to the cursor.
# XXX self._ordering can be equal to:
# * None if we didn't explicitly call order_by on this queryset.
# * A list of PyMongo-style sorting tuples.
# * An empty list if we explicitly called order_by() without any
# arguments. This indicates that we want to clear the default
# ordering.
if self._ordering:
# explicit ordering
self._cursor_obj.sort(self._ordering)
elif self._ordering is None and self._document._meta['ordering']:
# default ordering
order = self._get_order_by(self._document._meta['ordering'])
self._cursor_obj.sort(order)
if self._skip is not None:
self._cursor_obj.skip(self._skip)
if self._limit is not None:
self._cursor_obj.limit(self._limit)
if self._hint != -1:
self._cursor_obj.hint(self._hint)
if self._skip is not None:
self._cursor_obj.skip(self._skip)
if self._batch_size is not None:
self._cursor_obj.batch_size(self._batch_size)
if self._hint != -1:
self._cursor_obj.hint(self._hint)
if self._batch_size is not None:
self._cursor_obj.batch_size(self._batch_size)
return self._cursor_obj
@@ -1698,7 +1746,13 @@ class BaseQuerySet(object):
return ret
def _get_order_by(self, keys):
"""Creates a list of order by fields"""
"""Given a list of MongoEngine-style sort keys, return a list
of sorting tuples that can be applied to a PyMongo cursor. For
example:
>>> qs._get_order_by(['-last_name', 'first_name'])
[('last_name', -1), ('first_name', 1)]
"""
key_list = []
for key in keys:
if not key:
@@ -1711,17 +1765,19 @@ class BaseQuerySet(object):
direction = pymongo.ASCENDING
if key[0] == '-':
direction = pymongo.DESCENDING
if key[0] in ('-', '+'):
key = key[1:]
key = key.replace('__', '.')
try:
key = self._document._translate_field_name(key)
except Exception:
# TODO this exception should be more specific
pass
key_list.append((key, direction))
if self._cursor_obj and key_list:
self._cursor_obj.sort(key_list)
return key_list
def _get_scalar(self, doc):
@@ -1819,10 +1875,21 @@ class BaseQuerySet(object):
return code
def _chainable_method(self, method_name, val):
"""Call a particular method on the PyMongo cursor call a particular chainable method
with the provided value.
"""
queryset = self.clone()
method = getattr(queryset._cursor, method_name)
method(val)
# Get an existing cursor object or create a new one
cursor = queryset._cursor
# Find the requested method on the cursor and call it with the
# provided value
getattr(cursor, method_name)(val)
# Cache the value on the queryset._{method_name}
setattr(queryset, '_' + method_name, val)
return queryset
# Deprecated

View File

@@ -136,13 +136,15 @@ class QuerySet(BaseQuerySet):
return self._len
def no_cache(self):
"""Convert to a non_caching queryset
"""Convert to a non-caching queryset
.. versionadded:: 0.8.3 Convert to non caching queryset
"""
if self._result_cache is not None:
raise OperationError('QuerySet already cached')
return self.clone_into(QuerySetNoCache(self._document, self._collection))
return self._clone_into(QuerySetNoCache(self._document,
self._collection))
class QuerySetNoCache(BaseQuerySet):
@@ -153,7 +155,7 @@ class QuerySetNoCache(BaseQuerySet):
.. versionadded:: 0.8.3 Convert to caching queryset
"""
return self.clone_into(QuerySet(self._document, self._collection))
return self._clone_into(QuerySet(self._document, self._collection))
def __repr__(self):
"""Provides the string representation of the QuerySet

View File

@@ -1,13 +1,12 @@
# -*- coding: utf-8 -*-
import six
from nose.plugins.skip import SkipTest
import datetime
import unittest
import uuid
import math
import itertools
import re
from nose.plugins.skip import SkipTest
import six
try:
@@ -27,21 +26,13 @@ from mongoengine import *
from mongoengine.connection import get_db
from mongoengine.base import (BaseDict, BaseField, EmbeddedDocumentList,
_document_registry)
from mongoengine.errors import NotRegistered, DoesNotExist
from tests.utils import MongoDBTestCase
__all__ = ("FieldTest", "EmbeddedDocumentListFieldTestCase")
class FieldTest(unittest.TestCase):
def setUp(self):
connect(db='mongoenginetest')
self.db = get_db()
def tearDown(self):
self.db.drop_collection('fs.files')
self.db.drop_collection('fs.chunks')
self.db.drop_collection('mongoengine.counters')
class FieldTest(MongoDBTestCase):
def test_default_values_nothing_set(self):
"""Ensure that default field values are used when creating a document.
@@ -227,9 +218,9 @@ class FieldTest(unittest.TestCase):
self.assertTrue(isinstance(ret.comp_dt_fld, datetime.datetime))
def test_not_required_handles_none_from_database(self):
"""Ensure that every fields can handle null values from the database.
"""Ensure that every field can handle null values from the
database.
"""
class HandleNoneFields(Document):
str_fld = StringField(required=True)
int_fld = IntField(required=True)
@@ -350,11 +341,12 @@ class FieldTest(unittest.TestCase):
person.validate()
def test_url_validation(self):
"""Ensure that URLFields validate urls properly.
"""
"""Ensure that URLFields validate urls properly."""
class Link(Document):
url = URLField()
Link.drop_collection()
link = Link()
link.url = 'google'
self.assertRaises(ValidationError, link.validate)
@@ -362,6 +354,27 @@ class FieldTest(unittest.TestCase):
link.url = 'http://www.google.com:8080'
link.validate()
def test_unicode_url_validation(self):
"""Ensure unicode URLs are validated properly."""
class Link(Document):
url = URLField()
Link.drop_collection()
link = Link()
link.url = u'http://привет.com'
# TODO fix URL validation - this *IS* a valid URL
# For now we just want to make sure that the error message is correct
try:
link.validate()
self.assertTrue(False)
except ValidationError as e:
self.assertEqual(
unicode(e),
u"ValidationError (Link:None) (Invalid URL: http://\u043f\u0440\u0438\u0432\u0435\u0442.com: ['url'])"
)
def test_url_scheme_validation(self):
"""Ensure that URLFields validate urls with specific schemes properly.
"""
@@ -3186,26 +3199,42 @@ class FieldTest(unittest.TestCase):
att.delete()
self.assertEqual(0, Attachment.objects.count())
def test_choices_validation(self):
"""Ensure that value is in a container of allowed values.
def test_choices_allow_using_sets_as_choices(self):
"""Ensure that sets can be used when setting choices
"""
class Shirt(Document):
size = StringField(max_length=3, choices=(
('S', 'Small'), ('M', 'Medium'), ('L', 'Large'),
('XL', 'Extra Large'), ('XXL', 'Extra Extra Large')))
size = StringField(choices={'M', 'L'})
Shirt.drop_collection()
Shirt(size='M').validate()
def test_choices_validation_allow_no_value(self):
"""Ensure that .validate passes and no value was provided
for a field setup with choices
"""
class Shirt(Document):
size = StringField(choices=('S', 'M'))
shirt = Shirt()
shirt.validate()
shirt.size = "S"
def test_choices_validation_accept_possible_value(self):
"""Ensure that value is in a container of allowed values.
"""
class Shirt(Document):
size = StringField(choices=('S', 'M'))
shirt = Shirt(size='S')
shirt.validate()
shirt.size = "XS"
self.assertRaises(ValidationError, shirt.validate)
def test_choices_validation_reject_unknown_value(self):
"""Ensure that unallowed value are rejected upon validation
"""
class Shirt(Document):
size = StringField(choices=('S', 'M'))
Shirt.drop_collection()
shirt = Shirt(size="XS")
with self.assertRaises(ValidationError):
shirt.validate()
def test_choices_validation_documents(self):
"""
@@ -4024,12 +4053,13 @@ class FieldTest(unittest.TestCase):
self.assertTrue(isinstance(doc.some_long, six.integer_types))
class EmbeddedDocumentListFieldTestCase(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.db = connect(db='EmbeddedDocumentListFieldTestCase')
class EmbeddedDocumentListFieldTestCase(MongoDBTestCase):
def setUp(self):
"""
Create two BlogPost entries in the database, each with
several EmbeddedDocuments.
"""
class Comments(EmbeddedDocument):
author = StringField()
message = StringField()
@@ -4037,14 +4067,11 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase):
class BlogPost(Document):
comments = EmbeddedDocumentListField(Comments)
cls.Comments = Comments
cls.BlogPost = BlogPost
BlogPost.drop_collection()
self.Comments = Comments
self.BlogPost = BlogPost
def setUp(self):
"""
Create two BlogPost entries in the database, each with
several EmbeddedDocuments.
"""
self.post1 = self.BlogPost(comments=[
self.Comments(author='user1', message='message1'),
self.Comments(author='user2', message='message1')
@@ -4056,13 +4083,6 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase):
self.Comments(author='user3', message='message1')
]).save()
def tearDown(self):
self.BlogPost.drop_collection()
@classmethod
def tearDownClass(cls):
cls.db.drop_database('EmbeddedDocumentListFieldTestCase')
def test_no_keyword_filter(self):
"""
Tests the filter method of a List of Embedded Documents
@@ -4420,7 +4440,8 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase):
my_list = ListField(EmbeddedDocumentField(EmbeddedWithUnique))
A(my_list=[]).save()
self.assertRaises(NotUniqueError, lambda: A(my_list=[]).save())
with self.assertRaises(NotUniqueError):
A(my_list=[]).save()
class EmbeddedWithSparseUnique(EmbeddedDocument):
number = IntField(unique=True, sparse=True)
@@ -4428,6 +4449,9 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase):
class B(Document):
my_list = ListField(EmbeddedDocumentField(EmbeddedWithSparseUnique))
A.drop_collection()
B.drop_collection()
B(my_list=[]).save()
B(my_list=[]).save()
@@ -4467,6 +4491,8 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase):
a_field = IntField()
c_field = IntField(custom_data=custom_data)
CustomData.drop_collection()
a1 = CustomData(a_field=1, c_field=2).save()
self.assertEqual(2, a1.c_field)
self.assertFalse(hasattr(a1.c_field, 'custom_data'))

View File

@@ -18,15 +18,13 @@ try:
except ImportError:
HAS_PIL = False
from tests.utils import MongoDBTestCase
TEST_IMAGE_PATH = os.path.join(os.path.dirname(__file__), 'mongoengine.png')
TEST_IMAGE2_PATH = os.path.join(os.path.dirname(__file__), 'mongodb_leaf.png')
class FileTest(unittest.TestCase):
def setUp(self):
connect(db='mongoenginetest')
self.db = get_db()
class FileTest(MongoDBTestCase):
def tearDown(self):
self.db.drop_collection('fs.files')

View File

@@ -106,58 +106,111 @@ class QuerySetTest(unittest.TestCase):
list(BlogPost.objects(author2__name="test"))
def test_find(self):
"""Ensure that a query returns a valid set of results.
"""
self.Person(name="User A", age=20).save()
self.Person(name="User B", age=30).save()
"""Ensure that a query returns a valid set of results."""
user_a = self.Person.objects.create(name='User A', age=20)
user_b = self.Person.objects.create(name='User B', age=30)
# Find all people in the collection
people = self.Person.objects
self.assertEqual(people.count(), 2)
results = list(people)
self.assertTrue(isinstance(results[0], self.Person))
self.assertTrue(isinstance(results[0].id, (ObjectId, str, unicode)))
self.assertEqual(results[0].name, "User A")
self.assertEqual(results[0], user_a)
self.assertEqual(results[0].name, 'User A')
self.assertEqual(results[0].age, 20)
self.assertEqual(results[1].name, "User B")
self.assertEqual(results[1], user_b)
self.assertEqual(results[1].name, 'User B')
self.assertEqual(results[1].age, 30)
# Use a query to filter the people found to just person1
# Filter people by age
people = self.Person.objects(age=20)
self.assertEqual(people.count(), 1)
person = people.next()
self.assertEqual(person, user_a)
self.assertEqual(person.name, "User A")
self.assertEqual(person.age, 20)
# Test limit
def test_limit(self):
"""Ensure that QuerySet.limit works as expected."""
user_a = self.Person.objects.create(name='User A', age=20)
user_b = self.Person.objects.create(name='User B', age=30)
# Test limit on a new queryset
people = list(self.Person.objects.limit(1))
self.assertEqual(len(people), 1)
self.assertEqual(people[0].name, 'User A')
self.assertEqual(people[0], user_a)
# Test skip
# Test limit on an existing queryset
people = self.Person.objects
self.assertEqual(len(people), 2)
people2 = people.limit(1)
self.assertEqual(len(people), 2)
self.assertEqual(len(people2), 1)
self.assertEqual(people2[0], user_a)
# Test chaining of only after limit
person = self.Person.objects().limit(1).only('name').first()
self.assertEqual(person, user_a)
self.assertEqual(person.name, 'User A')
self.assertEqual(person.age, None)
def test_skip(self):
"""Ensure that QuerySet.skip works as expected."""
user_a = self.Person.objects.create(name='User A', age=20)
user_b = self.Person.objects.create(name='User B', age=30)
# Test skip on a new queryset
people = list(self.Person.objects.skip(1))
self.assertEqual(len(people), 1)
self.assertEqual(people[0].name, 'User B')
self.assertEqual(people[0], user_b)
person3 = self.Person(name="User C", age=40)
person3.save()
# Test skip on an existing queryset
people = self.Person.objects
self.assertEqual(len(people), 2)
people2 = people.skip(1)
self.assertEqual(len(people), 2)
self.assertEqual(len(people2), 1)
self.assertEqual(people2[0], user_b)
# Test chaining of only after skip
person = self.Person.objects().skip(1).only('name').first()
self.assertEqual(person, user_b)
self.assertEqual(person.name, 'User B')
self.assertEqual(person.age, None)
def test_slice(self):
"""Ensure slicing a queryset works as expected."""
user_a = self.Person.objects.create(name='User A', age=20)
user_b = self.Person.objects.create(name='User B', age=30)
user_c = self.Person.objects.create(name="User C", age=40)
# Test slice limit
people = list(self.Person.objects[:2])
self.assertEqual(len(people), 2)
self.assertEqual(people[0].name, 'User A')
self.assertEqual(people[1].name, 'User B')
self.assertEqual(people[0], user_a)
self.assertEqual(people[1], user_b)
# Test slice skip
people = list(self.Person.objects[1:])
self.assertEqual(len(people), 2)
self.assertEqual(people[0].name, 'User B')
self.assertEqual(people[1].name, 'User C')
self.assertEqual(people[0], user_b)
self.assertEqual(people[1], user_c)
# Test slice limit and skip
people = list(self.Person.objects[1:2])
self.assertEqual(len(people), 1)
self.assertEqual(people[0].name, 'User B')
self.assertEqual(people[0], user_b)
# Test slice limit and skip on an existing queryset
people = self.Person.objects
self.assertEqual(len(people), 3)
people2 = people[1:2]
self.assertEqual(len(people2), 1)
self.assertEqual(people2[0], user_b)
# Test slice limit and skip cursor reset
qs = self.Person.objects[1:2]
@@ -168,6 +221,7 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(len(people), 1)
self.assertEqual(people[0].name, 'User B')
# Test empty slice
people = list(self.Person.objects[1:1])
self.assertEqual(len(people), 0)
@@ -187,12 +241,6 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual("[<Person: Person object>, <Person: Person object>]",
"%s" % self.Person.objects[51:53])
# Test only after limit
self.assertEqual(self.Person.objects().limit(2).only('name')[0].age, None)
# Test only after skip
self.assertEqual(self.Person.objects().skip(2).only('name')[0].age, None)
def test_find_one(self):
"""Ensure that a query using find_one returns a valid result.
"""
@@ -1226,6 +1274,7 @@ class QuerySetTest(unittest.TestCase):
BlogPost.drop_collection()
# default ordering should be used by default
with db_ops_tracker() as q:
BlogPost.objects.filter(title='whatever').first()
self.assertEqual(len(q.get_ops()), 1)
@@ -1234,11 +1283,28 @@ class QuerySetTest(unittest.TestCase):
{'published_date': -1}
)
# calling order_by() should clear the default ordering
with db_ops_tracker() as q:
BlogPost.objects.filter(title='whatever').order_by().first()
self.assertEqual(len(q.get_ops()), 1)
self.assertFalse('$orderby' in q.get_ops()[0]['query'])
# calling an explicit order_by should use a specified sort
with db_ops_tracker() as q:
BlogPost.objects.filter(title='whatever').order_by('published_date').first()
self.assertEqual(len(q.get_ops()), 1)
self.assertEqual(
q.get_ops()[0]['query']['$orderby'],
{'published_date': 1}
)
# calling order_by() after an explicit sort should clear it
with db_ops_tracker() as q:
qs = BlogPost.objects.filter(title='whatever').order_by('published_date')
qs.order_by().first()
self.assertEqual(len(q.get_ops()), 1)
self.assertFalse('$orderby' in q.get_ops()[0]['query'])
def test_no_ordering_for_get(self):
""" Ensure that Doc.objects.get doesn't use any ordering.
"""

View File

@@ -285,8 +285,7 @@ class ConnectionTest(unittest.TestCase):
self.assertTrue(isinstance(conn, pymongo.mongo_client.MongoClient))
def test_connection_kwargs(self):
"""Ensure that connection kwargs get passed to pymongo.
"""
"""Ensure that connection kwargs get passed to pymongo."""
connect('mongoenginetest', alias='t1', tz_aware=True)
conn = get_connection('t1')
@@ -296,6 +295,32 @@ class ConnectionTest(unittest.TestCase):
conn = get_connection('t2')
self.assertFalse(get_tz_awareness(conn))
def test_connection_pool_via_kwarg(self):
"""Ensure we can specify a max connection pool size using
a connection kwarg.
"""
# Use "max_pool_size" or "maxpoolsize" depending on PyMongo version
# (former was changed to the latter as described in
# https://jira.mongodb.org/browse/PYTHON-854).
# TODO remove once PyMongo < 3.0 support is dropped
if pymongo.version_tuple[0] >= 3:
pool_size_kwargs = {'maxpoolsize': 100}
else:
pool_size_kwargs = {'max_pool_size': 100}
conn = connect('mongoenginetest', alias='max_pool_size_via_kwarg', **pool_size_kwargs)
self.assertEqual(conn.max_pool_size, 100)
def test_connection_pool_via_uri(self):
"""Ensure we can specify a max connection pool size using
an option in a connection URI.
"""
if pymongo.version_tuple[0] == 2 and pymongo.version_tuple[1] < 9:
raise SkipTest('maxpoolsize as a URI option is only supported in PyMongo v2.9+')
conn = connect(host='mongodb://localhost/test?maxpoolsize=100', alias='max_pool_size_via_uri')
self.assertEqual(conn.max_pool_size, 100)
def test_write_concern(self):
"""Ensure write concern can be specified in connect() via
a kwarg or as part of the connection URI.

22
tests/utils.py Normal file
View File

@@ -0,0 +1,22 @@
import unittest
from mongoengine import connect
from mongoengine.connection import get_db
MONGO_TEST_DB = 'mongoenginetest'
class MongoDBTestCase(unittest.TestCase):
"""Base class for tests that need a mongodb connection
db is being dropped automatically
"""
@classmethod
def setUpClass(cls):
cls._connection = connect(db=MONGO_TEST_DB)
cls._connection.drop_database(MONGO_TEST_DB)
cls.db = get_db()
@classmethod
def tearDownClass(cls):
cls._connection.drop_database(MONGO_TEST_DB)