Merge branch 'master' of https://github.com/MongoEngine/mongoengine into dax_py3
This commit is contained in:
commit
f29b93c762
@ -513,6 +513,9 @@ If a dictionary is passed then the following options are available:
|
|||||||
Allows you to automatically expire data from a collection by setting the
|
Allows you to automatically expire data from a collection by setting the
|
||||||
time in seconds to expire the a field.
|
time in seconds to expire the a field.
|
||||||
|
|
||||||
|
:attr:`name` (Optional)
|
||||||
|
Allows you to specify a name for the index
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
|
|
||||||
Inheritance adds extra fields indices see: :ref:`document-inheritance`.
|
Inheritance adds extra fields indices see: :ref:`document-inheritance`.
|
||||||
|
@ -57,7 +57,8 @@ document values for example::
|
|||||||
|
|
||||||
def clean(self):
|
def clean(self):
|
||||||
"""Ensures that only published essays have a `pub_date` and
|
"""Ensures that only published essays have a `pub_date` and
|
||||||
automatically sets the pub_date if published and not set"""
|
automatically sets `pub_date` if essay is published and `pub_date`
|
||||||
|
is not set"""
|
||||||
if self.status == 'Draft' and self.pub_date is not None:
|
if self.status == 'Draft' and self.pub_date is not None:
|
||||||
msg = 'Draft entries should not have a publication date.'
|
msg = 'Draft entries should not have a publication date.'
|
||||||
raise ValidationError(msg)
|
raise ValidationError(msg)
|
||||||
|
@ -104,6 +104,18 @@ def register_connection(alias, db=None, name=None, host=None, port=None,
|
|||||||
conn_settings['authentication_source'] = uri_options['authsource']
|
conn_settings['authentication_source'] = uri_options['authsource']
|
||||||
if 'authmechanism' in uri_options:
|
if 'authmechanism' in uri_options:
|
||||||
conn_settings['authentication_mechanism'] = uri_options['authmechanism']
|
conn_settings['authentication_mechanism'] = uri_options['authmechanism']
|
||||||
|
if IS_PYMONGO_3 and 'readpreference' in uri_options:
|
||||||
|
read_preferences = (
|
||||||
|
ReadPreference.NEAREST,
|
||||||
|
ReadPreference.PRIMARY,
|
||||||
|
ReadPreference.PRIMARY_PREFERRED,
|
||||||
|
ReadPreference.SECONDARY,
|
||||||
|
ReadPreference.SECONDARY_PREFERRED)
|
||||||
|
read_pf_mode = uri_options['readpreference'].lower()
|
||||||
|
for preference in read_preferences:
|
||||||
|
if preference.name.lower() == read_pf_mode:
|
||||||
|
conn_settings['read_preference'] = preference
|
||||||
|
break
|
||||||
else:
|
else:
|
||||||
resolved_hosts.append(entity)
|
resolved_hosts.append(entity)
|
||||||
conn_settings['host'] = resolved_hosts
|
conn_settings['host'] = resolved_hosts
|
||||||
|
@ -145,18 +145,17 @@ class no_sub_classes(object):
|
|||||||
:param cls: the class to turn querying sub classes on
|
:param cls: the class to turn querying sub classes on
|
||||||
"""
|
"""
|
||||||
self.cls = cls
|
self.cls = cls
|
||||||
|
self.cls_initial_subclasses = None
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
"""Change the objects default and _auto_dereference values."""
|
"""Change the objects default and _auto_dereference values."""
|
||||||
self.cls._all_subclasses = self.cls._subclasses
|
self.cls_initial_subclasses = self.cls._subclasses
|
||||||
self.cls._subclasses = (self.cls,)
|
self.cls._subclasses = (self.cls._class_name,)
|
||||||
return self.cls
|
return self.cls
|
||||||
|
|
||||||
def __exit__(self, t, value, traceback):
|
def __exit__(self, t, value, traceback):
|
||||||
"""Reset the default and _auto_dereference values."""
|
"""Reset the default and _auto_dereference values."""
|
||||||
self.cls._subclasses = self.cls._all_subclasses
|
self.cls._subclasses = self.cls_initial_subclasses
|
||||||
delattr(self.cls, '_all_subclasses')
|
|
||||||
return self.cls
|
|
||||||
|
|
||||||
|
|
||||||
class query_counter(object):
|
class query_counter(object):
|
||||||
@ -215,7 +214,7 @@ class query_counter(object):
|
|||||||
"""Get the number of queries."""
|
"""Get the number of queries."""
|
||||||
ignore_query = {'ns': {'$ne': '%s.system.indexes' % self.db.name}}
|
ignore_query = {'ns': {'$ne': '%s.system.indexes' % self.db.name}}
|
||||||
count = self.db.system.profile.find(ignore_query).count() - self.counter
|
count = self.db.system.profile.find(ignore_query).count() - self.counter
|
||||||
self.counter += 1
|
self.counter += 1 # Account for the query we just fired
|
||||||
return count
|
return count
|
||||||
|
|
||||||
|
|
||||||
|
@ -364,7 +364,8 @@ class FloatField(BaseField):
|
|||||||
|
|
||||||
|
|
||||||
class DecimalField(BaseField):
|
class DecimalField(BaseField):
|
||||||
"""Fixed-point decimal number field.
|
"""Fixed-point decimal number field. Stores the value as a float by default unless `force_string` is used.
|
||||||
|
If using floats, beware of Decimal to float conversion (potential precision loss)
|
||||||
|
|
||||||
.. versionchanged:: 0.8
|
.. versionchanged:: 0.8
|
||||||
.. versionadded:: 0.3
|
.. versionadded:: 0.3
|
||||||
@ -375,7 +376,9 @@ class DecimalField(BaseField):
|
|||||||
"""
|
"""
|
||||||
:param min_value: Validation rule for the minimum acceptable value.
|
:param min_value: Validation rule for the minimum acceptable value.
|
||||||
:param max_value: Validation rule for the maximum acceptable value.
|
:param max_value: Validation rule for the maximum acceptable value.
|
||||||
:param force_string: Store as a string.
|
:param force_string: Store the value as a string (instead of a float).
|
||||||
|
Be aware that this affects query sorting and operation like lte, gte (as string comparison is applied)
|
||||||
|
and some query operator won't work (e.g: inc, dec)
|
||||||
:param precision: Number of decimal places to store.
|
:param precision: Number of decimal places to store.
|
||||||
:param rounding: The rounding rule from the python decimal library:
|
:param rounding: The rounding rule from the python decimal library:
|
||||||
|
|
||||||
@ -647,9 +650,17 @@ class EmbeddedDocumentField(BaseField):
|
|||||||
def document_type(self):
|
def document_type(self):
|
||||||
if isinstance(self.document_type_obj, six.string_types):
|
if isinstance(self.document_type_obj, six.string_types):
|
||||||
if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT:
|
if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT:
|
||||||
self.document_type_obj = self.owner_document
|
resolved_document_type = self.owner_document
|
||||||
else:
|
else:
|
||||||
self.document_type_obj = get_document(self.document_type_obj)
|
resolved_document_type = get_document(self.document_type_obj)
|
||||||
|
|
||||||
|
if not issubclass(resolved_document_type, EmbeddedDocument):
|
||||||
|
# Due to the late resolution of the document_type
|
||||||
|
# There is a chance that it won't be an EmbeddedDocument (#1661)
|
||||||
|
self.error('Invalid embedded document class provided to an '
|
||||||
|
'EmbeddedDocumentField')
|
||||||
|
self.document_type_obj = resolved_document_type
|
||||||
|
|
||||||
return self.document_type_obj
|
return self.document_type_obj
|
||||||
|
|
||||||
def to_python(self, value):
|
def to_python(self, value):
|
||||||
@ -1029,11 +1040,13 @@ class ReferenceField(BaseField):
|
|||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
class Bar(Document):
|
class Org(Document):
|
||||||
content = StringField()
|
owner = ReferenceField('User')
|
||||||
foo = ReferenceField('Foo')
|
|
||||||
|
|
||||||
Foo.register_delete_rule(Bar, 'foo', NULLIFY)
|
class User(Document):
|
||||||
|
org = ReferenceField('Org', reverse_delete_rule=CASCADE)
|
||||||
|
|
||||||
|
User.register_delete_rule(Org, 'owner', DENY)
|
||||||
|
|
||||||
.. versionchanged:: 0.5 added `reverse_delete_rule`
|
.. versionchanged:: 0.5 added `reverse_delete_rule`
|
||||||
"""
|
"""
|
||||||
|
@ -147,7 +147,7 @@ def query(_doc_cls=None, **kwargs):
|
|||||||
if op is None or key not in mongo_query:
|
if op is None or key not in mongo_query:
|
||||||
mongo_query[key] = value
|
mongo_query[key] = value
|
||||||
elif key in mongo_query:
|
elif key in mongo_query:
|
||||||
if isinstance(mongo_query[key], dict):
|
if isinstance(mongo_query[key], dict) and isinstance(value, dict):
|
||||||
mongo_query[key].update(value)
|
mongo_query[key].update(value)
|
||||||
# $max/minDistance needs to come last - convert to SON
|
# $max/minDistance needs to come last - convert to SON
|
||||||
value_dict = mongo_query[key]
|
value_dict = mongo_query[key]
|
||||||
@ -201,30 +201,37 @@ def update(_doc_cls=None, **update):
|
|||||||
format.
|
format.
|
||||||
"""
|
"""
|
||||||
mongo_update = {}
|
mongo_update = {}
|
||||||
|
|
||||||
for key, value in update.items():
|
for key, value in update.items():
|
||||||
if key == '__raw__':
|
if key == '__raw__':
|
||||||
mongo_update.update(value)
|
mongo_update.update(value)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
parts = key.split('__')
|
parts = key.split('__')
|
||||||
|
|
||||||
# if there is no operator, default to 'set'
|
# if there is no operator, default to 'set'
|
||||||
if len(parts) < 3 and parts[0] not in UPDATE_OPERATORS:
|
if len(parts) < 3 and parts[0] not in UPDATE_OPERATORS:
|
||||||
parts.insert(0, 'set')
|
parts.insert(0, 'set')
|
||||||
|
|
||||||
# Check for an operator and transform to mongo-style if there is
|
# Check for an operator and transform to mongo-style if there is
|
||||||
op = None
|
op = None
|
||||||
if parts[0] in UPDATE_OPERATORS:
|
if parts[0] in UPDATE_OPERATORS:
|
||||||
op = parts.pop(0)
|
op = parts.pop(0)
|
||||||
# Convert Pythonic names to Mongo equivalents
|
# Convert Pythonic names to Mongo equivalents
|
||||||
if op in ('push_all', 'pull_all'):
|
operator_map = {
|
||||||
op = op.replace('_all', 'All')
|
'push_all': 'pushAll',
|
||||||
elif op == 'dec':
|
'pull_all': 'pullAll',
|
||||||
|
'dec': 'inc',
|
||||||
|
'add_to_set': 'addToSet',
|
||||||
|
'set_on_insert': 'setOnInsert'
|
||||||
|
}
|
||||||
|
if op == 'dec':
|
||||||
# Support decrement by flipping a positive value's sign
|
# Support decrement by flipping a positive value's sign
|
||||||
# and using 'inc'
|
# and using 'inc'
|
||||||
op = 'inc'
|
|
||||||
value = -value
|
value = -value
|
||||||
elif op == 'add_to_set':
|
# If the operator doesn't found from operator map, the op value
|
||||||
op = 'addToSet'
|
# will stay unchanged
|
||||||
elif op == 'set_on_insert':
|
op = operator_map.get(op, op)
|
||||||
op = 'setOnInsert'
|
|
||||||
|
|
||||||
match = None
|
match = None
|
||||||
if parts[-1] in COMPARISON_OPERATORS:
|
if parts[-1] in COMPARISON_OPERATORS:
|
||||||
@ -291,6 +298,8 @@ def update(_doc_cls=None, **update):
|
|||||||
value = field.prepare_query_value(op, value)
|
value = field.prepare_query_value(op, value)
|
||||||
elif op == 'unset':
|
elif op == 'unset':
|
||||||
value = 1
|
value = 1
|
||||||
|
elif op == 'inc':
|
||||||
|
value = field.prepare_query_value(op, value)
|
||||||
|
|
||||||
if match:
|
if match:
|
||||||
match = '$' + match
|
match = '$' + match
|
||||||
|
@ -2147,6 +2147,15 @@ class FieldTest(MongoDBTestCase):
|
|||||||
]))
|
]))
|
||||||
self.assertEqual(a.b.c.txt, 'hi')
|
self.assertEqual(a.b.c.txt, 'hi')
|
||||||
|
|
||||||
|
def test_embedded_document_field_cant_reference_using_a_str_if_it_does_not_exist_yet(self):
|
||||||
|
raise SkipTest("Using a string reference in an EmbeddedDocumentField does not work if the class isnt registerd yet")
|
||||||
|
|
||||||
|
class MyDoc2(Document):
|
||||||
|
emb = EmbeddedDocumentField('MyDoc')
|
||||||
|
|
||||||
|
class MyDoc(EmbeddedDocument):
|
||||||
|
name = StringField()
|
||||||
|
|
||||||
def test_embedded_document_validation(self):
|
def test_embedded_document_validation(self):
|
||||||
"""Ensure that invalid embedded documents cannot be assigned to
|
"""Ensure that invalid embedded documents cannot be assigned to
|
||||||
embedded document fields.
|
embedded document fields.
|
||||||
@ -4388,6 +4397,44 @@ class EmbeddedDocumentListFieldTestCase(MongoDBTestCase):
|
|||||||
self.assertEqual(custom_data['a'], CustomData.c_field.custom_data['a'])
|
self.assertEqual(custom_data['a'], CustomData.c_field.custom_data['a'])
|
||||||
|
|
||||||
|
|
||||||
|
class TestEmbeddedDocumentField(MongoDBTestCase):
|
||||||
|
def test___init___(self):
|
||||||
|
class MyDoc(EmbeddedDocument):
|
||||||
|
name = StringField()
|
||||||
|
|
||||||
|
field = EmbeddedDocumentField(MyDoc)
|
||||||
|
self.assertEqual(field.document_type_obj, MyDoc)
|
||||||
|
|
||||||
|
field2 = EmbeddedDocumentField('MyDoc')
|
||||||
|
self.assertEqual(field2.document_type_obj, 'MyDoc')
|
||||||
|
|
||||||
|
def test___init___throw_error_if_document_type_is_not_EmbeddedDocument(self):
|
||||||
|
with self.assertRaises(ValidationError):
|
||||||
|
EmbeddedDocumentField(dict)
|
||||||
|
|
||||||
|
def test_document_type_throw_error_if_not_EmbeddedDocument_subclass(self):
|
||||||
|
|
||||||
|
class MyDoc(Document):
|
||||||
|
name = StringField()
|
||||||
|
|
||||||
|
emb = EmbeddedDocumentField('MyDoc')
|
||||||
|
with self.assertRaises(ValidationError) as ctx:
|
||||||
|
emb.document_type
|
||||||
|
self.assertIn('Invalid embedded document class provided to an EmbeddedDocumentField', str(ctx.exception))
|
||||||
|
|
||||||
|
def test_embedded_document_field_only_allow_subclasses_of_embedded_document(self):
|
||||||
|
# Relates to #1661
|
||||||
|
class MyDoc(Document):
|
||||||
|
name = StringField()
|
||||||
|
|
||||||
|
with self.assertRaises(ValidationError):
|
||||||
|
class MyFailingDoc(Document):
|
||||||
|
emb = EmbeddedDocumentField(MyDoc)
|
||||||
|
|
||||||
|
with self.assertRaises(ValidationError):
|
||||||
|
class MyFailingdoc2(Document):
|
||||||
|
emb = EmbeddedDocumentField('MyDoc')
|
||||||
|
|
||||||
class CachedReferenceFieldTest(MongoDBTestCase):
|
class CachedReferenceFieldTest(MongoDBTestCase):
|
||||||
|
|
||||||
def test_cached_reference_field_get_and_save(self):
|
def test_cached_reference_field_get_and_save(self):
|
||||||
|
@ -3,6 +3,7 @@
|
|||||||
import datetime
|
import datetime
|
||||||
import unittest
|
import unittest
|
||||||
import uuid
|
import uuid
|
||||||
|
from decimal import Decimal
|
||||||
|
|
||||||
from bson import DBRef, ObjectId
|
from bson import DBRef, ObjectId
|
||||||
from nose.plugins.skip import SkipTest
|
from nose.plugins.skip import SkipTest
|
||||||
@ -1202,6 +1203,14 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
BlogPost.drop_collection()
|
BlogPost.drop_collection()
|
||||||
Blog.drop_collection()
|
Blog.drop_collection()
|
||||||
|
|
||||||
|
def test_filter_chaining_with_regex(self):
|
||||||
|
person = self.Person(name='Guido van Rossum')
|
||||||
|
person.save()
|
||||||
|
|
||||||
|
people = self.Person.objects
|
||||||
|
people = people.filter(name__startswith='Gui').filter(name__not__endswith='tum')
|
||||||
|
self.assertEqual(people.count(), 1)
|
||||||
|
|
||||||
def assertSequence(self, qs, expected):
|
def assertSequence(self, qs, expected):
|
||||||
qs = list(qs)
|
qs = list(qs)
|
||||||
expected = list(expected)
|
expected = list(expected)
|
||||||
@ -1851,21 +1860,16 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
1, BlogPost.objects(author__in=["%s" % me.pk]).count())
|
1, BlogPost.objects(author__in=["%s" % me.pk]).count())
|
||||||
|
|
||||||
def test_update(self):
|
def test_update_intfield_operator(self):
|
||||||
"""Ensure that atomic updates work properly.
|
|
||||||
"""
|
|
||||||
class BlogPost(Document):
|
class BlogPost(Document):
|
||||||
name = StringField()
|
|
||||||
title = StringField()
|
|
||||||
hits = IntField()
|
hits = IntField()
|
||||||
tags = ListField(StringField())
|
|
||||||
|
|
||||||
BlogPost.drop_collection()
|
BlogPost.drop_collection()
|
||||||
|
|
||||||
post = BlogPost(name="Test Post", hits=5, tags=['test'])
|
post = BlogPost(hits=5)
|
||||||
post.save()
|
post.save()
|
||||||
|
|
||||||
BlogPost.objects.update(set__hits=10)
|
BlogPost.objects.update_one(set__hits=10)
|
||||||
post.reload()
|
post.reload()
|
||||||
self.assertEqual(post.hits, 10)
|
self.assertEqual(post.hits, 10)
|
||||||
|
|
||||||
@ -1882,6 +1886,55 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
post.reload()
|
post.reload()
|
||||||
self.assertEqual(post.hits, 11)
|
self.assertEqual(post.hits, 11)
|
||||||
|
|
||||||
|
def test_update_decimalfield_operator(self):
|
||||||
|
class BlogPost(Document):
|
||||||
|
review = DecimalField()
|
||||||
|
|
||||||
|
BlogPost.drop_collection()
|
||||||
|
|
||||||
|
post = BlogPost(review=3.5)
|
||||||
|
post.save()
|
||||||
|
|
||||||
|
BlogPost.objects.update_one(inc__review=0.1) # test with floats
|
||||||
|
post.reload()
|
||||||
|
self.assertEqual(float(post.review), 3.6)
|
||||||
|
|
||||||
|
BlogPost.objects.update_one(dec__review=0.1)
|
||||||
|
post.reload()
|
||||||
|
self.assertEqual(float(post.review), 3.5)
|
||||||
|
|
||||||
|
BlogPost.objects.update_one(inc__review=Decimal(0.12)) # test with Decimal
|
||||||
|
post.reload()
|
||||||
|
self.assertEqual(float(post.review), 3.62)
|
||||||
|
|
||||||
|
BlogPost.objects.update_one(dec__review=Decimal(0.12))
|
||||||
|
post.reload()
|
||||||
|
self.assertEqual(float(post.review), 3.5)
|
||||||
|
|
||||||
|
def test_update_decimalfield_operator_not_working_with_force_string(self):
|
||||||
|
class BlogPost(Document):
|
||||||
|
review = DecimalField(force_string=True)
|
||||||
|
|
||||||
|
BlogPost.drop_collection()
|
||||||
|
|
||||||
|
post = BlogPost(review=3.5)
|
||||||
|
post.save()
|
||||||
|
|
||||||
|
with self.assertRaises(OperationError):
|
||||||
|
BlogPost.objects.update_one(inc__review=0.1) # test with floats
|
||||||
|
|
||||||
|
def test_update_listfield_operator(self):
|
||||||
|
"""Ensure that atomic updates work properly.
|
||||||
|
"""
|
||||||
|
class BlogPost(Document):
|
||||||
|
tags = ListField(StringField())
|
||||||
|
|
||||||
|
BlogPost.drop_collection()
|
||||||
|
|
||||||
|
post = BlogPost(tags=['test'])
|
||||||
|
post.save()
|
||||||
|
|
||||||
|
# ListField operator
|
||||||
BlogPost.objects.update(push__tags='mongo')
|
BlogPost.objects.update(push__tags='mongo')
|
||||||
post.reload()
|
post.reload()
|
||||||
self.assertTrue('mongo' in post.tags)
|
self.assertTrue('mongo' in post.tags)
|
||||||
@ -1900,13 +1953,23 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
post.reload()
|
post.reload()
|
||||||
self.assertEqual(post.tags.count('unique'), 1)
|
self.assertEqual(post.tags.count('unique'), 1)
|
||||||
|
|
||||||
self.assertNotEqual(post.hits, None)
|
BlogPost.drop_collection()
|
||||||
BlogPost.objects.update_one(unset__hits=1)
|
|
||||||
post.reload()
|
def test_update_unset(self):
|
||||||
self.assertEqual(post.hits, None)
|
class BlogPost(Document):
|
||||||
|
title = StringField()
|
||||||
|
|
||||||
BlogPost.drop_collection()
|
BlogPost.drop_collection()
|
||||||
|
|
||||||
|
post = BlogPost(title='garbage').save()
|
||||||
|
|
||||||
|
self.assertNotEqual(post.title, None)
|
||||||
|
BlogPost.objects.update_one(unset__title=1)
|
||||||
|
post.reload()
|
||||||
|
self.assertEqual(post.title, None)
|
||||||
|
pymongo_doc = BlogPost.objects.as_pymongo().first()
|
||||||
|
self.assertNotIn('title', pymongo_doc)
|
||||||
|
|
||||||
@needs_mongodb_v26
|
@needs_mongodb_v26
|
||||||
def test_update_push_with_position(self):
|
def test_update_push_with_position(self):
|
||||||
"""Ensure that the 'push' update with position works properly.
|
"""Ensure that the 'push' update with position works properly.
|
||||||
|
@ -364,6 +364,12 @@ class ConnectionTest(unittest.TestCase):
|
|||||||
date_doc = DateDoc.objects.first()
|
date_doc = DateDoc.objects.first()
|
||||||
self.assertEqual(d, date_doc.the_date)
|
self.assertEqual(d, date_doc.the_date)
|
||||||
|
|
||||||
|
def test_read_preference_from_parse(self):
|
||||||
|
if IS_PYMONGO_3:
|
||||||
|
from pymongo import ReadPreference
|
||||||
|
conn = connect(host="mongodb://a1.vpc,a2.vpc,a3.vpc/prod?readPreference=secondaryPreferred")
|
||||||
|
self.assertEqual(conn.read_preference, ReadPreference.SECONDARY_PREFERRED)
|
||||||
|
|
||||||
def test_multiple_connection_settings(self):
|
def test_multiple_connection_settings(self):
|
||||||
connect('mongoenginetest', alias='t1', host="localhost")
|
connect('mongoenginetest', alias='t1', host="localhost")
|
||||||
|
|
||||||
|
@ -140,8 +140,6 @@ class ContextManagersTest(unittest.TestCase):
|
|||||||
def test_no_sub_classes(self):
|
def test_no_sub_classes(self):
|
||||||
class A(Document):
|
class A(Document):
|
||||||
x = IntField()
|
x = IntField()
|
||||||
y = IntField()
|
|
||||||
|
|
||||||
meta = {'allow_inheritance': True}
|
meta = {'allow_inheritance': True}
|
||||||
|
|
||||||
class B(A):
|
class B(A):
|
||||||
@ -152,29 +150,29 @@ class ContextManagersTest(unittest.TestCase):
|
|||||||
|
|
||||||
A.drop_collection()
|
A.drop_collection()
|
||||||
|
|
||||||
A(x=10, y=20).save()
|
A(x=10).save()
|
||||||
A(x=15, y=30).save()
|
A(x=15).save()
|
||||||
B(x=20, y=40).save()
|
B(x=20).save()
|
||||||
B(x=30, y=50).save()
|
B(x=30).save()
|
||||||
C(x=40, y=60).save()
|
C(x=40).save()
|
||||||
|
|
||||||
self.assertEqual(A.objects.count(), 5)
|
self.assertEqual(A.objects.count(), 5)
|
||||||
self.assertEqual(B.objects.count(), 3)
|
self.assertEqual(B.objects.count(), 3)
|
||||||
self.assertEqual(C.objects.count(), 1)
|
self.assertEqual(C.objects.count(), 1)
|
||||||
|
|
||||||
with no_sub_classes(A) as A:
|
with no_sub_classes(A):
|
||||||
self.assertEqual(A.objects.count(), 2)
|
self.assertEqual(A.objects.count(), 2)
|
||||||
|
|
||||||
for obj in A.objects:
|
for obj in A.objects:
|
||||||
self.assertEqual(obj.__class__, A)
|
self.assertEqual(obj.__class__, A)
|
||||||
|
|
||||||
with no_sub_classes(B) as B:
|
with no_sub_classes(B):
|
||||||
self.assertEqual(B.objects.count(), 2)
|
self.assertEqual(B.objects.count(), 2)
|
||||||
|
|
||||||
for obj in B.objects:
|
for obj in B.objects:
|
||||||
self.assertEqual(obj.__class__, B)
|
self.assertEqual(obj.__class__, B)
|
||||||
|
|
||||||
with no_sub_classes(C) as C:
|
with no_sub_classes(C):
|
||||||
self.assertEqual(C.objects.count(), 1)
|
self.assertEqual(C.objects.count(), 1)
|
||||||
|
|
||||||
for obj in C.objects:
|
for obj in C.objects:
|
||||||
@ -185,6 +183,32 @@ class ContextManagersTest(unittest.TestCase):
|
|||||||
self.assertEqual(B.objects.count(), 3)
|
self.assertEqual(B.objects.count(), 3)
|
||||||
self.assertEqual(C.objects.count(), 1)
|
self.assertEqual(C.objects.count(), 1)
|
||||||
|
|
||||||
|
def test_no_sub_classes_modification_to_document_class_are_temporary(self):
|
||||||
|
class A(Document):
|
||||||
|
x = IntField()
|
||||||
|
meta = {'allow_inheritance': True}
|
||||||
|
|
||||||
|
class B(A):
|
||||||
|
z = IntField()
|
||||||
|
|
||||||
|
self.assertEqual(A._subclasses, ('A', 'A.B'))
|
||||||
|
with no_sub_classes(A):
|
||||||
|
self.assertEqual(A._subclasses, ('A',))
|
||||||
|
self.assertEqual(A._subclasses, ('A', 'A.B'))
|
||||||
|
|
||||||
|
self.assertEqual(B._subclasses, ('A.B',))
|
||||||
|
with no_sub_classes(B):
|
||||||
|
self.assertEqual(B._subclasses, ('A.B',))
|
||||||
|
self.assertEqual(B._subclasses, ('A.B',))
|
||||||
|
|
||||||
|
def test_no_subclass_context_manager_does_not_swallow_exception(self):
|
||||||
|
class User(Document):
|
||||||
|
name = StringField()
|
||||||
|
|
||||||
|
with self.assertRaises(TypeError):
|
||||||
|
with no_sub_classes(User):
|
||||||
|
raise TypeError()
|
||||||
|
|
||||||
def test_query_counter(self):
|
def test_query_counter(self):
|
||||||
connect('mongoenginetest')
|
connect('mongoenginetest')
|
||||||
db = get_db()
|
db = get_db()
|
||||||
|
@ -7,12 +7,12 @@ from mongoengine.connection import get_db, get_connection
|
|||||||
from mongoengine.python_support import IS_PYMONGO_3
|
from mongoengine.python_support import IS_PYMONGO_3
|
||||||
|
|
||||||
|
|
||||||
MONGO_TEST_DB = 'mongoenginetest'
|
MONGO_TEST_DB = 'mongoenginetest' # standard name for the test database
|
||||||
|
|
||||||
|
|
||||||
class MongoDBTestCase(unittest.TestCase):
|
class MongoDBTestCase(unittest.TestCase):
|
||||||
"""Base class for tests that need a mongodb connection
|
"""Base class for tests that need a mongodb connection
|
||||||
db is being dropped automatically
|
It ensures that the db is clean at the beginning and dropped at the end automatically
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -32,6 +32,7 @@ def get_mongodb_version():
|
|||||||
"""
|
"""
|
||||||
return tuple(get_connection().server_info()['versionArray'])
|
return tuple(get_connection().server_info()['versionArray'])
|
||||||
|
|
||||||
|
|
||||||
def _decorated_with_ver_requirement(func, ver_tuple):
|
def _decorated_with_ver_requirement(func, ver_tuple):
|
||||||
"""Return a given function decorated with the version requirement
|
"""Return a given function decorated with the version requirement
|
||||||
for a particular MongoDB version tuple.
|
for a particular MongoDB version tuple.
|
||||||
@ -50,18 +51,21 @@ def _decorated_with_ver_requirement(func, ver_tuple):
|
|||||||
|
|
||||||
return _inner
|
return _inner
|
||||||
|
|
||||||
|
|
||||||
def needs_mongodb_v26(func):
|
def needs_mongodb_v26(func):
|
||||||
"""Raise a SkipTest exception if we're working with MongoDB version
|
"""Raise a SkipTest exception if we're working with MongoDB version
|
||||||
lower than v2.6.
|
lower than v2.6.
|
||||||
"""
|
"""
|
||||||
return _decorated_with_ver_requirement(func, (2, 6))
|
return _decorated_with_ver_requirement(func, (2, 6))
|
||||||
|
|
||||||
|
|
||||||
def needs_mongodb_v3(func):
|
def needs_mongodb_v3(func):
|
||||||
"""Raise a SkipTest exception if we're working with MongoDB version
|
"""Raise a SkipTest exception if we're working with MongoDB version
|
||||||
lower than v3.0.
|
lower than v3.0.
|
||||||
"""
|
"""
|
||||||
return _decorated_with_ver_requirement(func, (3, 0))
|
return _decorated_with_ver_requirement(func, (3, 0))
|
||||||
|
|
||||||
|
|
||||||
def skip_pymongo3(f):
|
def skip_pymongo3(f):
|
||||||
"""Raise a SkipTest exception if we're running a test against
|
"""Raise a SkipTest exception if we're running a test against
|
||||||
PyMongo v3.x.
|
PyMongo v3.x.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user