Merge branch 'master' of https://github.com/MongoEngine/mongoengine into dax_py3
This commit is contained in:
commit
f29b93c762
@ -513,6 +513,9 @@ If a dictionary is passed then the following options are available:
|
||||
Allows you to automatically expire data from a collection by setting the
|
||||
time in seconds to expire the a field.
|
||||
|
||||
:attr:`name` (Optional)
|
||||
Allows you to specify a name for the index
|
||||
|
||||
.. note::
|
||||
|
||||
Inheritance adds extra fields indices see: :ref:`document-inheritance`.
|
||||
|
@ -57,7 +57,8 @@ document values for example::
|
||||
|
||||
def clean(self):
|
||||
"""Ensures that only published essays have a `pub_date` and
|
||||
automatically sets the pub_date if published and not set"""
|
||||
automatically sets `pub_date` if essay is published and `pub_date`
|
||||
is not set"""
|
||||
if self.status == 'Draft' and self.pub_date is not None:
|
||||
msg = 'Draft entries should not have a publication date.'
|
||||
raise ValidationError(msg)
|
||||
|
@ -104,6 +104,18 @@ def register_connection(alias, db=None, name=None, host=None, port=None,
|
||||
conn_settings['authentication_source'] = uri_options['authsource']
|
||||
if 'authmechanism' in uri_options:
|
||||
conn_settings['authentication_mechanism'] = uri_options['authmechanism']
|
||||
if IS_PYMONGO_3 and 'readpreference' in uri_options:
|
||||
read_preferences = (
|
||||
ReadPreference.NEAREST,
|
||||
ReadPreference.PRIMARY,
|
||||
ReadPreference.PRIMARY_PREFERRED,
|
||||
ReadPreference.SECONDARY,
|
||||
ReadPreference.SECONDARY_PREFERRED)
|
||||
read_pf_mode = uri_options['readpreference'].lower()
|
||||
for preference in read_preferences:
|
||||
if preference.name.lower() == read_pf_mode:
|
||||
conn_settings['read_preference'] = preference
|
||||
break
|
||||
else:
|
||||
resolved_hosts.append(entity)
|
||||
conn_settings['host'] = resolved_hosts
|
||||
|
@ -145,18 +145,17 @@ class no_sub_classes(object):
|
||||
:param cls: the class to turn querying sub classes on
|
||||
"""
|
||||
self.cls = cls
|
||||
self.cls_initial_subclasses = None
|
||||
|
||||
def __enter__(self):
|
||||
"""Change the objects default and _auto_dereference values."""
|
||||
self.cls._all_subclasses = self.cls._subclasses
|
||||
self.cls._subclasses = (self.cls,)
|
||||
self.cls_initial_subclasses = self.cls._subclasses
|
||||
self.cls._subclasses = (self.cls._class_name,)
|
||||
return self.cls
|
||||
|
||||
def __exit__(self, t, value, traceback):
|
||||
"""Reset the default and _auto_dereference values."""
|
||||
self.cls._subclasses = self.cls._all_subclasses
|
||||
delattr(self.cls, '_all_subclasses')
|
||||
return self.cls
|
||||
self.cls._subclasses = self.cls_initial_subclasses
|
||||
|
||||
|
||||
class query_counter(object):
|
||||
@ -215,7 +214,7 @@ class query_counter(object):
|
||||
"""Get the number of queries."""
|
||||
ignore_query = {'ns': {'$ne': '%s.system.indexes' % self.db.name}}
|
||||
count = self.db.system.profile.find(ignore_query).count() - self.counter
|
||||
self.counter += 1
|
||||
self.counter += 1 # Account for the query we just fired
|
||||
return count
|
||||
|
||||
|
||||
|
@ -364,7 +364,8 @@ class FloatField(BaseField):
|
||||
|
||||
|
||||
class DecimalField(BaseField):
|
||||
"""Fixed-point decimal number field.
|
||||
"""Fixed-point decimal number field. Stores the value as a float by default unless `force_string` is used.
|
||||
If using floats, beware of Decimal to float conversion (potential precision loss)
|
||||
|
||||
.. versionchanged:: 0.8
|
||||
.. versionadded:: 0.3
|
||||
@ -375,7 +376,9 @@ class DecimalField(BaseField):
|
||||
"""
|
||||
:param min_value: Validation rule for the minimum acceptable value.
|
||||
:param max_value: Validation rule for the maximum acceptable value.
|
||||
:param force_string: Store as a string.
|
||||
:param force_string: Store the value as a string (instead of a float).
|
||||
Be aware that this affects query sorting and operation like lte, gte (as string comparison is applied)
|
||||
and some query operator won't work (e.g: inc, dec)
|
||||
:param precision: Number of decimal places to store.
|
||||
:param rounding: The rounding rule from the python decimal library:
|
||||
|
||||
@ -647,9 +650,17 @@ class EmbeddedDocumentField(BaseField):
|
||||
def document_type(self):
|
||||
if isinstance(self.document_type_obj, six.string_types):
|
||||
if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT:
|
||||
self.document_type_obj = self.owner_document
|
||||
resolved_document_type = self.owner_document
|
||||
else:
|
||||
self.document_type_obj = get_document(self.document_type_obj)
|
||||
resolved_document_type = get_document(self.document_type_obj)
|
||||
|
||||
if not issubclass(resolved_document_type, EmbeddedDocument):
|
||||
# Due to the late resolution of the document_type
|
||||
# There is a chance that it won't be an EmbeddedDocument (#1661)
|
||||
self.error('Invalid embedded document class provided to an '
|
||||
'EmbeddedDocumentField')
|
||||
self.document_type_obj = resolved_document_type
|
||||
|
||||
return self.document_type_obj
|
||||
|
||||
def to_python(self, value):
|
||||
@ -1029,11 +1040,13 @@ class ReferenceField(BaseField):
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
class Bar(Document):
|
||||
content = StringField()
|
||||
foo = ReferenceField('Foo')
|
||||
class Org(Document):
|
||||
owner = ReferenceField('User')
|
||||
|
||||
Foo.register_delete_rule(Bar, 'foo', NULLIFY)
|
||||
class User(Document):
|
||||
org = ReferenceField('Org', reverse_delete_rule=CASCADE)
|
||||
|
||||
User.register_delete_rule(Org, 'owner', DENY)
|
||||
|
||||
.. versionchanged:: 0.5 added `reverse_delete_rule`
|
||||
"""
|
||||
|
@ -147,7 +147,7 @@ def query(_doc_cls=None, **kwargs):
|
||||
if op is None or key not in mongo_query:
|
||||
mongo_query[key] = value
|
||||
elif key in mongo_query:
|
||||
if isinstance(mongo_query[key], dict):
|
||||
if isinstance(mongo_query[key], dict) and isinstance(value, dict):
|
||||
mongo_query[key].update(value)
|
||||
# $max/minDistance needs to come last - convert to SON
|
||||
value_dict = mongo_query[key]
|
||||
@ -201,30 +201,37 @@ def update(_doc_cls=None, **update):
|
||||
format.
|
||||
"""
|
||||
mongo_update = {}
|
||||
|
||||
for key, value in update.items():
|
||||
if key == '__raw__':
|
||||
mongo_update.update(value)
|
||||
continue
|
||||
|
||||
parts = key.split('__')
|
||||
|
||||
# if there is no operator, default to 'set'
|
||||
if len(parts) < 3 and parts[0] not in UPDATE_OPERATORS:
|
||||
parts.insert(0, 'set')
|
||||
|
||||
# Check for an operator and transform to mongo-style if there is
|
||||
op = None
|
||||
if parts[0] in UPDATE_OPERATORS:
|
||||
op = parts.pop(0)
|
||||
# Convert Pythonic names to Mongo equivalents
|
||||
if op in ('push_all', 'pull_all'):
|
||||
op = op.replace('_all', 'All')
|
||||
elif op == 'dec':
|
||||
operator_map = {
|
||||
'push_all': 'pushAll',
|
||||
'pull_all': 'pullAll',
|
||||
'dec': 'inc',
|
||||
'add_to_set': 'addToSet',
|
||||
'set_on_insert': 'setOnInsert'
|
||||
}
|
||||
if op == 'dec':
|
||||
# Support decrement by flipping a positive value's sign
|
||||
# and using 'inc'
|
||||
op = 'inc'
|
||||
value = -value
|
||||
elif op == 'add_to_set':
|
||||
op = 'addToSet'
|
||||
elif op == 'set_on_insert':
|
||||
op = 'setOnInsert'
|
||||
# If the operator doesn't found from operator map, the op value
|
||||
# will stay unchanged
|
||||
op = operator_map.get(op, op)
|
||||
|
||||
match = None
|
||||
if parts[-1] in COMPARISON_OPERATORS:
|
||||
@ -291,6 +298,8 @@ def update(_doc_cls=None, **update):
|
||||
value = field.prepare_query_value(op, value)
|
||||
elif op == 'unset':
|
||||
value = 1
|
||||
elif op == 'inc':
|
||||
value = field.prepare_query_value(op, value)
|
||||
|
||||
if match:
|
||||
match = '$' + match
|
||||
|
@ -2147,6 +2147,15 @@ class FieldTest(MongoDBTestCase):
|
||||
]))
|
||||
self.assertEqual(a.b.c.txt, 'hi')
|
||||
|
||||
def test_embedded_document_field_cant_reference_using_a_str_if_it_does_not_exist_yet(self):
|
||||
raise SkipTest("Using a string reference in an EmbeddedDocumentField does not work if the class isnt registerd yet")
|
||||
|
||||
class MyDoc2(Document):
|
||||
emb = EmbeddedDocumentField('MyDoc')
|
||||
|
||||
class MyDoc(EmbeddedDocument):
|
||||
name = StringField()
|
||||
|
||||
def test_embedded_document_validation(self):
|
||||
"""Ensure that invalid embedded documents cannot be assigned to
|
||||
embedded document fields.
|
||||
@ -4388,6 +4397,44 @@ class EmbeddedDocumentListFieldTestCase(MongoDBTestCase):
|
||||
self.assertEqual(custom_data['a'], CustomData.c_field.custom_data['a'])
|
||||
|
||||
|
||||
class TestEmbeddedDocumentField(MongoDBTestCase):
|
||||
def test___init___(self):
|
||||
class MyDoc(EmbeddedDocument):
|
||||
name = StringField()
|
||||
|
||||
field = EmbeddedDocumentField(MyDoc)
|
||||
self.assertEqual(field.document_type_obj, MyDoc)
|
||||
|
||||
field2 = EmbeddedDocumentField('MyDoc')
|
||||
self.assertEqual(field2.document_type_obj, 'MyDoc')
|
||||
|
||||
def test___init___throw_error_if_document_type_is_not_EmbeddedDocument(self):
|
||||
with self.assertRaises(ValidationError):
|
||||
EmbeddedDocumentField(dict)
|
||||
|
||||
def test_document_type_throw_error_if_not_EmbeddedDocument_subclass(self):
|
||||
|
||||
class MyDoc(Document):
|
||||
name = StringField()
|
||||
|
||||
emb = EmbeddedDocumentField('MyDoc')
|
||||
with self.assertRaises(ValidationError) as ctx:
|
||||
emb.document_type
|
||||
self.assertIn('Invalid embedded document class provided to an EmbeddedDocumentField', str(ctx.exception))
|
||||
|
||||
def test_embedded_document_field_only_allow_subclasses_of_embedded_document(self):
|
||||
# Relates to #1661
|
||||
class MyDoc(Document):
|
||||
name = StringField()
|
||||
|
||||
with self.assertRaises(ValidationError):
|
||||
class MyFailingDoc(Document):
|
||||
emb = EmbeddedDocumentField(MyDoc)
|
||||
|
||||
with self.assertRaises(ValidationError):
|
||||
class MyFailingdoc2(Document):
|
||||
emb = EmbeddedDocumentField('MyDoc')
|
||||
|
||||
class CachedReferenceFieldTest(MongoDBTestCase):
|
||||
|
||||
def test_cached_reference_field_get_and_save(self):
|
||||
|
@ -3,6 +3,7 @@
|
||||
import datetime
|
||||
import unittest
|
||||
import uuid
|
||||
from decimal import Decimal
|
||||
|
||||
from bson import DBRef, ObjectId
|
||||
from nose.plugins.skip import SkipTest
|
||||
@ -1202,6 +1203,14 @@ class QuerySetTest(unittest.TestCase):
|
||||
BlogPost.drop_collection()
|
||||
Blog.drop_collection()
|
||||
|
||||
def test_filter_chaining_with_regex(self):
|
||||
person = self.Person(name='Guido van Rossum')
|
||||
person.save()
|
||||
|
||||
people = self.Person.objects
|
||||
people = people.filter(name__startswith='Gui').filter(name__not__endswith='tum')
|
||||
self.assertEqual(people.count(), 1)
|
||||
|
||||
def assertSequence(self, qs, expected):
|
||||
qs = list(qs)
|
||||
expected = list(expected)
|
||||
@ -1851,21 +1860,16 @@ class QuerySetTest(unittest.TestCase):
|
||||
self.assertEqual(
|
||||
1, BlogPost.objects(author__in=["%s" % me.pk]).count())
|
||||
|
||||
def test_update(self):
|
||||
"""Ensure that atomic updates work properly.
|
||||
"""
|
||||
def test_update_intfield_operator(self):
|
||||
class BlogPost(Document):
|
||||
name = StringField()
|
||||
title = StringField()
|
||||
hits = IntField()
|
||||
tags = ListField(StringField())
|
||||
|
||||
BlogPost.drop_collection()
|
||||
|
||||
post = BlogPost(name="Test Post", hits=5, tags=['test'])
|
||||
post = BlogPost(hits=5)
|
||||
post.save()
|
||||
|
||||
BlogPost.objects.update(set__hits=10)
|
||||
BlogPost.objects.update_one(set__hits=10)
|
||||
post.reload()
|
||||
self.assertEqual(post.hits, 10)
|
||||
|
||||
@ -1882,6 +1886,55 @@ class QuerySetTest(unittest.TestCase):
|
||||
post.reload()
|
||||
self.assertEqual(post.hits, 11)
|
||||
|
||||
def test_update_decimalfield_operator(self):
|
||||
class BlogPost(Document):
|
||||
review = DecimalField()
|
||||
|
||||
BlogPost.drop_collection()
|
||||
|
||||
post = BlogPost(review=3.5)
|
||||
post.save()
|
||||
|
||||
BlogPost.objects.update_one(inc__review=0.1) # test with floats
|
||||
post.reload()
|
||||
self.assertEqual(float(post.review), 3.6)
|
||||
|
||||
BlogPost.objects.update_one(dec__review=0.1)
|
||||
post.reload()
|
||||
self.assertEqual(float(post.review), 3.5)
|
||||
|
||||
BlogPost.objects.update_one(inc__review=Decimal(0.12)) # test with Decimal
|
||||
post.reload()
|
||||
self.assertEqual(float(post.review), 3.62)
|
||||
|
||||
BlogPost.objects.update_one(dec__review=Decimal(0.12))
|
||||
post.reload()
|
||||
self.assertEqual(float(post.review), 3.5)
|
||||
|
||||
def test_update_decimalfield_operator_not_working_with_force_string(self):
|
||||
class BlogPost(Document):
|
||||
review = DecimalField(force_string=True)
|
||||
|
||||
BlogPost.drop_collection()
|
||||
|
||||
post = BlogPost(review=3.5)
|
||||
post.save()
|
||||
|
||||
with self.assertRaises(OperationError):
|
||||
BlogPost.objects.update_one(inc__review=0.1) # test with floats
|
||||
|
||||
def test_update_listfield_operator(self):
|
||||
"""Ensure that atomic updates work properly.
|
||||
"""
|
||||
class BlogPost(Document):
|
||||
tags = ListField(StringField())
|
||||
|
||||
BlogPost.drop_collection()
|
||||
|
||||
post = BlogPost(tags=['test'])
|
||||
post.save()
|
||||
|
||||
# ListField operator
|
||||
BlogPost.objects.update(push__tags='mongo')
|
||||
post.reload()
|
||||
self.assertTrue('mongo' in post.tags)
|
||||
@ -1900,13 +1953,23 @@ class QuerySetTest(unittest.TestCase):
|
||||
post.reload()
|
||||
self.assertEqual(post.tags.count('unique'), 1)
|
||||
|
||||
self.assertNotEqual(post.hits, None)
|
||||
BlogPost.objects.update_one(unset__hits=1)
|
||||
post.reload()
|
||||
self.assertEqual(post.hits, None)
|
||||
BlogPost.drop_collection()
|
||||
|
||||
def test_update_unset(self):
|
||||
class BlogPost(Document):
|
||||
title = StringField()
|
||||
|
||||
BlogPost.drop_collection()
|
||||
|
||||
post = BlogPost(title='garbage').save()
|
||||
|
||||
self.assertNotEqual(post.title, None)
|
||||
BlogPost.objects.update_one(unset__title=1)
|
||||
post.reload()
|
||||
self.assertEqual(post.title, None)
|
||||
pymongo_doc = BlogPost.objects.as_pymongo().first()
|
||||
self.assertNotIn('title', pymongo_doc)
|
||||
|
||||
@needs_mongodb_v26
|
||||
def test_update_push_with_position(self):
|
||||
"""Ensure that the 'push' update with position works properly.
|
||||
|
@ -364,6 +364,12 @@ class ConnectionTest(unittest.TestCase):
|
||||
date_doc = DateDoc.objects.first()
|
||||
self.assertEqual(d, date_doc.the_date)
|
||||
|
||||
def test_read_preference_from_parse(self):
|
||||
if IS_PYMONGO_3:
|
||||
from pymongo import ReadPreference
|
||||
conn = connect(host="mongodb://a1.vpc,a2.vpc,a3.vpc/prod?readPreference=secondaryPreferred")
|
||||
self.assertEqual(conn.read_preference, ReadPreference.SECONDARY_PREFERRED)
|
||||
|
||||
def test_multiple_connection_settings(self):
|
||||
connect('mongoenginetest', alias='t1', host="localhost")
|
||||
|
||||
|
@ -140,8 +140,6 @@ class ContextManagersTest(unittest.TestCase):
|
||||
def test_no_sub_classes(self):
|
||||
class A(Document):
|
||||
x = IntField()
|
||||
y = IntField()
|
||||
|
||||
meta = {'allow_inheritance': True}
|
||||
|
||||
class B(A):
|
||||
@ -152,29 +150,29 @@ class ContextManagersTest(unittest.TestCase):
|
||||
|
||||
A.drop_collection()
|
||||
|
||||
A(x=10, y=20).save()
|
||||
A(x=15, y=30).save()
|
||||
B(x=20, y=40).save()
|
||||
B(x=30, y=50).save()
|
||||
C(x=40, y=60).save()
|
||||
A(x=10).save()
|
||||
A(x=15).save()
|
||||
B(x=20).save()
|
||||
B(x=30).save()
|
||||
C(x=40).save()
|
||||
|
||||
self.assertEqual(A.objects.count(), 5)
|
||||
self.assertEqual(B.objects.count(), 3)
|
||||
self.assertEqual(C.objects.count(), 1)
|
||||
|
||||
with no_sub_classes(A) as A:
|
||||
with no_sub_classes(A):
|
||||
self.assertEqual(A.objects.count(), 2)
|
||||
|
||||
for obj in A.objects:
|
||||
self.assertEqual(obj.__class__, A)
|
||||
|
||||
with no_sub_classes(B) as B:
|
||||
with no_sub_classes(B):
|
||||
self.assertEqual(B.objects.count(), 2)
|
||||
|
||||
for obj in B.objects:
|
||||
self.assertEqual(obj.__class__, B)
|
||||
|
||||
with no_sub_classes(C) as C:
|
||||
with no_sub_classes(C):
|
||||
self.assertEqual(C.objects.count(), 1)
|
||||
|
||||
for obj in C.objects:
|
||||
@ -185,6 +183,32 @@ class ContextManagersTest(unittest.TestCase):
|
||||
self.assertEqual(B.objects.count(), 3)
|
||||
self.assertEqual(C.objects.count(), 1)
|
||||
|
||||
def test_no_sub_classes_modification_to_document_class_are_temporary(self):
|
||||
class A(Document):
|
||||
x = IntField()
|
||||
meta = {'allow_inheritance': True}
|
||||
|
||||
class B(A):
|
||||
z = IntField()
|
||||
|
||||
self.assertEqual(A._subclasses, ('A', 'A.B'))
|
||||
with no_sub_classes(A):
|
||||
self.assertEqual(A._subclasses, ('A',))
|
||||
self.assertEqual(A._subclasses, ('A', 'A.B'))
|
||||
|
||||
self.assertEqual(B._subclasses, ('A.B',))
|
||||
with no_sub_classes(B):
|
||||
self.assertEqual(B._subclasses, ('A.B',))
|
||||
self.assertEqual(B._subclasses, ('A.B',))
|
||||
|
||||
def test_no_subclass_context_manager_does_not_swallow_exception(self):
|
||||
class User(Document):
|
||||
name = StringField()
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
with no_sub_classes(User):
|
||||
raise TypeError()
|
||||
|
||||
def test_query_counter(self):
|
||||
connect('mongoenginetest')
|
||||
db = get_db()
|
||||
|
@ -7,12 +7,12 @@ from mongoengine.connection import get_db, get_connection
|
||||
from mongoengine.python_support import IS_PYMONGO_3
|
||||
|
||||
|
||||
MONGO_TEST_DB = 'mongoenginetest'
|
||||
MONGO_TEST_DB = 'mongoenginetest' # standard name for the test database
|
||||
|
||||
|
||||
class MongoDBTestCase(unittest.TestCase):
|
||||
"""Base class for tests that need a mongodb connection
|
||||
db is being dropped automatically
|
||||
It ensures that the db is clean at the beginning and dropped at the end automatically
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@ -32,6 +32,7 @@ def get_mongodb_version():
|
||||
"""
|
||||
return tuple(get_connection().server_info()['versionArray'])
|
||||
|
||||
|
||||
def _decorated_with_ver_requirement(func, ver_tuple):
|
||||
"""Return a given function decorated with the version requirement
|
||||
for a particular MongoDB version tuple.
|
||||
@ -50,18 +51,21 @@ def _decorated_with_ver_requirement(func, ver_tuple):
|
||||
|
||||
return _inner
|
||||
|
||||
|
||||
def needs_mongodb_v26(func):
|
||||
"""Raise a SkipTest exception if we're working with MongoDB version
|
||||
lower than v2.6.
|
||||
"""
|
||||
return _decorated_with_ver_requirement(func, (2, 6))
|
||||
|
||||
|
||||
def needs_mongodb_v3(func):
|
||||
"""Raise a SkipTest exception if we're working with MongoDB version
|
||||
lower than v3.0.
|
||||
"""
|
||||
return _decorated_with_ver_requirement(func, (3, 0))
|
||||
|
||||
|
||||
def skip_pymongo3(f):
|
||||
"""Raise a SkipTest exception if we're running a test against
|
||||
PyMongo v3.x.
|
||||
|
Loading…
x
Reference in New Issue
Block a user