Merge branch 'master' of https://github.com/MongoEngine/mongoengine into dax_py3

This commit is contained in:
Bastien Gérard 2018-09-04 16:06:45 +02:00
commit f29b93c762
11 changed files with 229 additions and 48 deletions

View File

@ -513,6 +513,9 @@ If a dictionary is passed then the following options are available:
Allows you to automatically expire data from a collection by setting the Allows you to automatically expire data from a collection by setting the
time in seconds to expire the a field. time in seconds to expire the a field.
:attr:`name` (Optional)
Allows you to specify a name for the index
.. note:: .. note::
Inheritance adds extra fields indices see: :ref:`document-inheritance`. Inheritance adds extra fields indices see: :ref:`document-inheritance`.

View File

@ -57,7 +57,8 @@ document values for example::
def clean(self): def clean(self):
"""Ensures that only published essays have a `pub_date` and """Ensures that only published essays have a `pub_date` and
automatically sets the pub_date if published and not set""" automatically sets `pub_date` if essay is published and `pub_date`
is not set"""
if self.status == 'Draft' and self.pub_date is not None: if self.status == 'Draft' and self.pub_date is not None:
msg = 'Draft entries should not have a publication date.' msg = 'Draft entries should not have a publication date.'
raise ValidationError(msg) raise ValidationError(msg)

View File

@ -104,6 +104,18 @@ def register_connection(alias, db=None, name=None, host=None, port=None,
conn_settings['authentication_source'] = uri_options['authsource'] conn_settings['authentication_source'] = uri_options['authsource']
if 'authmechanism' in uri_options: if 'authmechanism' in uri_options:
conn_settings['authentication_mechanism'] = uri_options['authmechanism'] conn_settings['authentication_mechanism'] = uri_options['authmechanism']
if IS_PYMONGO_3 and 'readpreference' in uri_options:
read_preferences = (
ReadPreference.NEAREST,
ReadPreference.PRIMARY,
ReadPreference.PRIMARY_PREFERRED,
ReadPreference.SECONDARY,
ReadPreference.SECONDARY_PREFERRED)
read_pf_mode = uri_options['readpreference'].lower()
for preference in read_preferences:
if preference.name.lower() == read_pf_mode:
conn_settings['read_preference'] = preference
break
else: else:
resolved_hosts.append(entity) resolved_hosts.append(entity)
conn_settings['host'] = resolved_hosts conn_settings['host'] = resolved_hosts

View File

@ -145,18 +145,17 @@ class no_sub_classes(object):
:param cls: the class to turn querying sub classes on :param cls: the class to turn querying sub classes on
""" """
self.cls = cls self.cls = cls
self.cls_initial_subclasses = None
def __enter__(self): def __enter__(self):
"""Change the objects default and _auto_dereference values.""" """Change the objects default and _auto_dereference values."""
self.cls._all_subclasses = self.cls._subclasses self.cls_initial_subclasses = self.cls._subclasses
self.cls._subclasses = (self.cls,) self.cls._subclasses = (self.cls._class_name,)
return self.cls return self.cls
def __exit__(self, t, value, traceback): def __exit__(self, t, value, traceback):
"""Reset the default and _auto_dereference values.""" """Reset the default and _auto_dereference values."""
self.cls._subclasses = self.cls._all_subclasses self.cls._subclasses = self.cls_initial_subclasses
delattr(self.cls, '_all_subclasses')
return self.cls
class query_counter(object): class query_counter(object):
@ -215,7 +214,7 @@ class query_counter(object):
"""Get the number of queries.""" """Get the number of queries."""
ignore_query = {'ns': {'$ne': '%s.system.indexes' % self.db.name}} ignore_query = {'ns': {'$ne': '%s.system.indexes' % self.db.name}}
count = self.db.system.profile.find(ignore_query).count() - self.counter count = self.db.system.profile.find(ignore_query).count() - self.counter
self.counter += 1 self.counter += 1 # Account for the query we just fired
return count return count

View File

@ -364,7 +364,8 @@ class FloatField(BaseField):
class DecimalField(BaseField): class DecimalField(BaseField):
"""Fixed-point decimal number field. """Fixed-point decimal number field. Stores the value as a float by default unless `force_string` is used.
If using floats, beware of Decimal to float conversion (potential precision loss)
.. versionchanged:: 0.8 .. versionchanged:: 0.8
.. versionadded:: 0.3 .. versionadded:: 0.3
@ -375,7 +376,9 @@ class DecimalField(BaseField):
""" """
:param min_value: Validation rule for the minimum acceptable value. :param min_value: Validation rule for the minimum acceptable value.
:param max_value: Validation rule for the maximum acceptable value. :param max_value: Validation rule for the maximum acceptable value.
:param force_string: Store as a string. :param force_string: Store the value as a string (instead of a float).
Be aware that this affects query sorting and operation like lte, gte (as string comparison is applied)
and some query operator won't work (e.g: inc, dec)
:param precision: Number of decimal places to store. :param precision: Number of decimal places to store.
:param rounding: The rounding rule from the python decimal library: :param rounding: The rounding rule from the python decimal library:
@ -647,9 +650,17 @@ class EmbeddedDocumentField(BaseField):
def document_type(self): def document_type(self):
if isinstance(self.document_type_obj, six.string_types): if isinstance(self.document_type_obj, six.string_types):
if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT: if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT:
self.document_type_obj = self.owner_document resolved_document_type = self.owner_document
else: else:
self.document_type_obj = get_document(self.document_type_obj) resolved_document_type = get_document(self.document_type_obj)
if not issubclass(resolved_document_type, EmbeddedDocument):
# Due to the late resolution of the document_type
# There is a chance that it won't be an EmbeddedDocument (#1661)
self.error('Invalid embedded document class provided to an '
'EmbeddedDocumentField')
self.document_type_obj = resolved_document_type
return self.document_type_obj return self.document_type_obj
def to_python(self, value): def to_python(self, value):
@ -1029,11 +1040,13 @@ class ReferenceField(BaseField):
.. code-block:: python .. code-block:: python
class Bar(Document): class Org(Document):
content = StringField() owner = ReferenceField('User')
foo = ReferenceField('Foo')
Foo.register_delete_rule(Bar, 'foo', NULLIFY) class User(Document):
org = ReferenceField('Org', reverse_delete_rule=CASCADE)
User.register_delete_rule(Org, 'owner', DENY)
.. versionchanged:: 0.5 added `reverse_delete_rule` .. versionchanged:: 0.5 added `reverse_delete_rule`
""" """

View File

@ -147,7 +147,7 @@ def query(_doc_cls=None, **kwargs):
if op is None or key not in mongo_query: if op is None or key not in mongo_query:
mongo_query[key] = value mongo_query[key] = value
elif key in mongo_query: elif key in mongo_query:
if isinstance(mongo_query[key], dict): if isinstance(mongo_query[key], dict) and isinstance(value, dict):
mongo_query[key].update(value) mongo_query[key].update(value)
# $max/minDistance needs to come last - convert to SON # $max/minDistance needs to come last - convert to SON
value_dict = mongo_query[key] value_dict = mongo_query[key]
@ -201,30 +201,37 @@ def update(_doc_cls=None, **update):
format. format.
""" """
mongo_update = {} mongo_update = {}
for key, value in update.items(): for key, value in update.items():
if key == '__raw__': if key == '__raw__':
mongo_update.update(value) mongo_update.update(value)
continue continue
parts = key.split('__') parts = key.split('__')
# if there is no operator, default to 'set' # if there is no operator, default to 'set'
if len(parts) < 3 and parts[0] not in UPDATE_OPERATORS: if len(parts) < 3 and parts[0] not in UPDATE_OPERATORS:
parts.insert(0, 'set') parts.insert(0, 'set')
# Check for an operator and transform to mongo-style if there is # Check for an operator and transform to mongo-style if there is
op = None op = None
if parts[0] in UPDATE_OPERATORS: if parts[0] in UPDATE_OPERATORS:
op = parts.pop(0) op = parts.pop(0)
# Convert Pythonic names to Mongo equivalents # Convert Pythonic names to Mongo equivalents
if op in ('push_all', 'pull_all'): operator_map = {
op = op.replace('_all', 'All') 'push_all': 'pushAll',
elif op == 'dec': 'pull_all': 'pullAll',
'dec': 'inc',
'add_to_set': 'addToSet',
'set_on_insert': 'setOnInsert'
}
if op == 'dec':
# Support decrement by flipping a positive value's sign # Support decrement by flipping a positive value's sign
# and using 'inc' # and using 'inc'
op = 'inc'
value = -value value = -value
elif op == 'add_to_set': # If the operator doesn't found from operator map, the op value
op = 'addToSet' # will stay unchanged
elif op == 'set_on_insert': op = operator_map.get(op, op)
op = 'setOnInsert'
match = None match = None
if parts[-1] in COMPARISON_OPERATORS: if parts[-1] in COMPARISON_OPERATORS:
@ -291,6 +298,8 @@ def update(_doc_cls=None, **update):
value = field.prepare_query_value(op, value) value = field.prepare_query_value(op, value)
elif op == 'unset': elif op == 'unset':
value = 1 value = 1
elif op == 'inc':
value = field.prepare_query_value(op, value)
if match: if match:
match = '$' + match match = '$' + match

View File

@ -2147,6 +2147,15 @@ class FieldTest(MongoDBTestCase):
])) ]))
self.assertEqual(a.b.c.txt, 'hi') self.assertEqual(a.b.c.txt, 'hi')
def test_embedded_document_field_cant_reference_using_a_str_if_it_does_not_exist_yet(self):
raise SkipTest("Using a string reference in an EmbeddedDocumentField does not work if the class isnt registerd yet")
class MyDoc2(Document):
emb = EmbeddedDocumentField('MyDoc')
class MyDoc(EmbeddedDocument):
name = StringField()
def test_embedded_document_validation(self): def test_embedded_document_validation(self):
"""Ensure that invalid embedded documents cannot be assigned to """Ensure that invalid embedded documents cannot be assigned to
embedded document fields. embedded document fields.
@ -4388,6 +4397,44 @@ class EmbeddedDocumentListFieldTestCase(MongoDBTestCase):
self.assertEqual(custom_data['a'], CustomData.c_field.custom_data['a']) self.assertEqual(custom_data['a'], CustomData.c_field.custom_data['a'])
class TestEmbeddedDocumentField(MongoDBTestCase):
def test___init___(self):
class MyDoc(EmbeddedDocument):
name = StringField()
field = EmbeddedDocumentField(MyDoc)
self.assertEqual(field.document_type_obj, MyDoc)
field2 = EmbeddedDocumentField('MyDoc')
self.assertEqual(field2.document_type_obj, 'MyDoc')
def test___init___throw_error_if_document_type_is_not_EmbeddedDocument(self):
with self.assertRaises(ValidationError):
EmbeddedDocumentField(dict)
def test_document_type_throw_error_if_not_EmbeddedDocument_subclass(self):
class MyDoc(Document):
name = StringField()
emb = EmbeddedDocumentField('MyDoc')
with self.assertRaises(ValidationError) as ctx:
emb.document_type
self.assertIn('Invalid embedded document class provided to an EmbeddedDocumentField', str(ctx.exception))
def test_embedded_document_field_only_allow_subclasses_of_embedded_document(self):
# Relates to #1661
class MyDoc(Document):
name = StringField()
with self.assertRaises(ValidationError):
class MyFailingDoc(Document):
emb = EmbeddedDocumentField(MyDoc)
with self.assertRaises(ValidationError):
class MyFailingdoc2(Document):
emb = EmbeddedDocumentField('MyDoc')
class CachedReferenceFieldTest(MongoDBTestCase): class CachedReferenceFieldTest(MongoDBTestCase):
def test_cached_reference_field_get_and_save(self): def test_cached_reference_field_get_and_save(self):

View File

@ -3,6 +3,7 @@
import datetime import datetime
import unittest import unittest
import uuid import uuid
from decimal import Decimal
from bson import DBRef, ObjectId from bson import DBRef, ObjectId
from nose.plugins.skip import SkipTest from nose.plugins.skip import SkipTest
@ -1202,6 +1203,14 @@ class QuerySetTest(unittest.TestCase):
BlogPost.drop_collection() BlogPost.drop_collection()
Blog.drop_collection() Blog.drop_collection()
def test_filter_chaining_with_regex(self):
person = self.Person(name='Guido van Rossum')
person.save()
people = self.Person.objects
people = people.filter(name__startswith='Gui').filter(name__not__endswith='tum')
self.assertEqual(people.count(), 1)
def assertSequence(self, qs, expected): def assertSequence(self, qs, expected):
qs = list(qs) qs = list(qs)
expected = list(expected) expected = list(expected)
@ -1851,21 +1860,16 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual( self.assertEqual(
1, BlogPost.objects(author__in=["%s" % me.pk]).count()) 1, BlogPost.objects(author__in=["%s" % me.pk]).count())
def test_update(self): def test_update_intfield_operator(self):
"""Ensure that atomic updates work properly.
"""
class BlogPost(Document): class BlogPost(Document):
name = StringField()
title = StringField()
hits = IntField() hits = IntField()
tags = ListField(StringField())
BlogPost.drop_collection() BlogPost.drop_collection()
post = BlogPost(name="Test Post", hits=5, tags=['test']) post = BlogPost(hits=5)
post.save() post.save()
BlogPost.objects.update(set__hits=10) BlogPost.objects.update_one(set__hits=10)
post.reload() post.reload()
self.assertEqual(post.hits, 10) self.assertEqual(post.hits, 10)
@ -1882,6 +1886,55 @@ class QuerySetTest(unittest.TestCase):
post.reload() post.reload()
self.assertEqual(post.hits, 11) self.assertEqual(post.hits, 11)
def test_update_decimalfield_operator(self):
class BlogPost(Document):
review = DecimalField()
BlogPost.drop_collection()
post = BlogPost(review=3.5)
post.save()
BlogPost.objects.update_one(inc__review=0.1) # test with floats
post.reload()
self.assertEqual(float(post.review), 3.6)
BlogPost.objects.update_one(dec__review=0.1)
post.reload()
self.assertEqual(float(post.review), 3.5)
BlogPost.objects.update_one(inc__review=Decimal(0.12)) # test with Decimal
post.reload()
self.assertEqual(float(post.review), 3.62)
BlogPost.objects.update_one(dec__review=Decimal(0.12))
post.reload()
self.assertEqual(float(post.review), 3.5)
def test_update_decimalfield_operator_not_working_with_force_string(self):
class BlogPost(Document):
review = DecimalField(force_string=True)
BlogPost.drop_collection()
post = BlogPost(review=3.5)
post.save()
with self.assertRaises(OperationError):
BlogPost.objects.update_one(inc__review=0.1) # test with floats
def test_update_listfield_operator(self):
"""Ensure that atomic updates work properly.
"""
class BlogPost(Document):
tags = ListField(StringField())
BlogPost.drop_collection()
post = BlogPost(tags=['test'])
post.save()
# ListField operator
BlogPost.objects.update(push__tags='mongo') BlogPost.objects.update(push__tags='mongo')
post.reload() post.reload()
self.assertTrue('mongo' in post.tags) self.assertTrue('mongo' in post.tags)
@ -1900,13 +1953,23 @@ class QuerySetTest(unittest.TestCase):
post.reload() post.reload()
self.assertEqual(post.tags.count('unique'), 1) self.assertEqual(post.tags.count('unique'), 1)
self.assertNotEqual(post.hits, None) BlogPost.drop_collection()
BlogPost.objects.update_one(unset__hits=1)
post.reload() def test_update_unset(self):
self.assertEqual(post.hits, None) class BlogPost(Document):
title = StringField()
BlogPost.drop_collection() BlogPost.drop_collection()
post = BlogPost(title='garbage').save()
self.assertNotEqual(post.title, None)
BlogPost.objects.update_one(unset__title=1)
post.reload()
self.assertEqual(post.title, None)
pymongo_doc = BlogPost.objects.as_pymongo().first()
self.assertNotIn('title', pymongo_doc)
@needs_mongodb_v26 @needs_mongodb_v26
def test_update_push_with_position(self): def test_update_push_with_position(self):
"""Ensure that the 'push' update with position works properly. """Ensure that the 'push' update with position works properly.

View File

@ -364,6 +364,12 @@ class ConnectionTest(unittest.TestCase):
date_doc = DateDoc.objects.first() date_doc = DateDoc.objects.first()
self.assertEqual(d, date_doc.the_date) self.assertEqual(d, date_doc.the_date)
def test_read_preference_from_parse(self):
if IS_PYMONGO_3:
from pymongo import ReadPreference
conn = connect(host="mongodb://a1.vpc,a2.vpc,a3.vpc/prod?readPreference=secondaryPreferred")
self.assertEqual(conn.read_preference, ReadPreference.SECONDARY_PREFERRED)
def test_multiple_connection_settings(self): def test_multiple_connection_settings(self):
connect('mongoenginetest', alias='t1', host="localhost") connect('mongoenginetest', alias='t1', host="localhost")

View File

@ -140,8 +140,6 @@ class ContextManagersTest(unittest.TestCase):
def test_no_sub_classes(self): def test_no_sub_classes(self):
class A(Document): class A(Document):
x = IntField() x = IntField()
y = IntField()
meta = {'allow_inheritance': True} meta = {'allow_inheritance': True}
class B(A): class B(A):
@ -152,29 +150,29 @@ class ContextManagersTest(unittest.TestCase):
A.drop_collection() A.drop_collection()
A(x=10, y=20).save() A(x=10).save()
A(x=15, y=30).save() A(x=15).save()
B(x=20, y=40).save() B(x=20).save()
B(x=30, y=50).save() B(x=30).save()
C(x=40, y=60).save() C(x=40).save()
self.assertEqual(A.objects.count(), 5) self.assertEqual(A.objects.count(), 5)
self.assertEqual(B.objects.count(), 3) self.assertEqual(B.objects.count(), 3)
self.assertEqual(C.objects.count(), 1) self.assertEqual(C.objects.count(), 1)
with no_sub_classes(A) as A: with no_sub_classes(A):
self.assertEqual(A.objects.count(), 2) self.assertEqual(A.objects.count(), 2)
for obj in A.objects: for obj in A.objects:
self.assertEqual(obj.__class__, A) self.assertEqual(obj.__class__, A)
with no_sub_classes(B) as B: with no_sub_classes(B):
self.assertEqual(B.objects.count(), 2) self.assertEqual(B.objects.count(), 2)
for obj in B.objects: for obj in B.objects:
self.assertEqual(obj.__class__, B) self.assertEqual(obj.__class__, B)
with no_sub_classes(C) as C: with no_sub_classes(C):
self.assertEqual(C.objects.count(), 1) self.assertEqual(C.objects.count(), 1)
for obj in C.objects: for obj in C.objects:
@ -185,6 +183,32 @@ class ContextManagersTest(unittest.TestCase):
self.assertEqual(B.objects.count(), 3) self.assertEqual(B.objects.count(), 3)
self.assertEqual(C.objects.count(), 1) self.assertEqual(C.objects.count(), 1)
def test_no_sub_classes_modification_to_document_class_are_temporary(self):
class A(Document):
x = IntField()
meta = {'allow_inheritance': True}
class B(A):
z = IntField()
self.assertEqual(A._subclasses, ('A', 'A.B'))
with no_sub_classes(A):
self.assertEqual(A._subclasses, ('A',))
self.assertEqual(A._subclasses, ('A', 'A.B'))
self.assertEqual(B._subclasses, ('A.B',))
with no_sub_classes(B):
self.assertEqual(B._subclasses, ('A.B',))
self.assertEqual(B._subclasses, ('A.B',))
def test_no_subclass_context_manager_does_not_swallow_exception(self):
class User(Document):
name = StringField()
with self.assertRaises(TypeError):
with no_sub_classes(User):
raise TypeError()
def test_query_counter(self): def test_query_counter(self):
connect('mongoenginetest') connect('mongoenginetest')
db = get_db() db = get_db()

View File

@ -7,12 +7,12 @@ from mongoengine.connection import get_db, get_connection
from mongoengine.python_support import IS_PYMONGO_3 from mongoengine.python_support import IS_PYMONGO_3
MONGO_TEST_DB = 'mongoenginetest' MONGO_TEST_DB = 'mongoenginetest' # standard name for the test database
class MongoDBTestCase(unittest.TestCase): class MongoDBTestCase(unittest.TestCase):
"""Base class for tests that need a mongodb connection """Base class for tests that need a mongodb connection
db is being dropped automatically It ensures that the db is clean at the beginning and dropped at the end automatically
""" """
@classmethod @classmethod
@ -32,6 +32,7 @@ def get_mongodb_version():
""" """
return tuple(get_connection().server_info()['versionArray']) return tuple(get_connection().server_info()['versionArray'])
def _decorated_with_ver_requirement(func, ver_tuple): def _decorated_with_ver_requirement(func, ver_tuple):
"""Return a given function decorated with the version requirement """Return a given function decorated with the version requirement
for a particular MongoDB version tuple. for a particular MongoDB version tuple.
@ -50,18 +51,21 @@ def _decorated_with_ver_requirement(func, ver_tuple):
return _inner return _inner
def needs_mongodb_v26(func): def needs_mongodb_v26(func):
"""Raise a SkipTest exception if we're working with MongoDB version """Raise a SkipTest exception if we're working with MongoDB version
lower than v2.6. lower than v2.6.
""" """
return _decorated_with_ver_requirement(func, (2, 6)) return _decorated_with_ver_requirement(func, (2, 6))
def needs_mongodb_v3(func): def needs_mongodb_v3(func):
"""Raise a SkipTest exception if we're working with MongoDB version """Raise a SkipTest exception if we're working with MongoDB version
lower than v3.0. lower than v3.0.
""" """
return _decorated_with_ver_requirement(func, (3, 0)) return _decorated_with_ver_requirement(func, (3, 0))
def skip_pymongo3(f): def skip_pymongo3(f):
"""Raise a SkipTest exception if we're running a test against """Raise a SkipTest exception if we're running a test against
PyMongo v3.x. PyMongo v3.x.