From cab659dce6048d391a4dd470d5bf4b8cbbaf36ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Sat, 16 Feb 2019 21:54:05 +0100 Subject: [PATCH 01/71] Fix documentation of Queryset.update regarding full_result #1995 --- docs/changelog.rst | 2 ++ mongoengine/queryset/base.py | 20 ++++++++++++++------ tests/queryset/queryset.py | 13 +++++++++++++ 3 files changed, 29 insertions(+), 6 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index d8bed7e6..b6c5b277 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -7,6 +7,7 @@ Development - (Fill this out as you fix issues and develop your features). - Fix .only() working improperly after using .count() of the same instance of QuerySet - POTENTIAL BREAKING CHANGE: All result fields are now passed, including internal fields (_cls, _id) when using `QuerySet.as_pymongo` #1976 +- Document a BREAKING CHANGE introduced in 0.15.3 and not reported at that time (#1995) ================= Changes in 0.16.3 @@ -64,6 +65,7 @@ Changes in 0.16.0 Changes in 0.15.3 ================= +- BREAKING CHANGES: `Queryset.update/update_one` methods now returns an UpdateResult when `full_result=True` is provided and no longer a dict (relates to #1491) - Subfield resolve error in generic_emdedded_document query #1651 #1652 - use each modifier only with $position #1673 #1675 - Improve LazyReferenceField and GenericLazyReferenceField with nested fields #1704 diff --git a/mongoengine/queryset/base.py b/mongoengine/queryset/base.py index 0ebeafa6..391f4f86 100644 --- a/mongoengine/queryset/base.py +++ b/mongoengine/queryset/base.py @@ -498,11 +498,12 @@ class BaseQuerySet(object): ``save(..., write_concern={w: 2, fsync: True}, ...)`` will wait until at least two servers have recorded the write and will force an fsync on the primary server. - :param full_result: Return the full result dictionary rather than just the number - updated, e.g. return - ``{'n': 2, 'nModified': 2, 'ok': 1.0, 'updatedExisting': True}``. + :param full_result: Return the associated ``pymongo.UpdateResult`` rather than just the number + updated items :param update: Django-style update keyword arguments + :returns the number of updated documents (unless ``full_result`` is True) + .. versionadded:: 0.2 """ if not update and not upsert: @@ -566,7 +567,7 @@ class BaseQuerySet(object): document = self._document.objects.with_id(atomic_update.upserted_id) return document - def update_one(self, upsert=False, write_concern=None, **update): + def update_one(self, upsert=False, write_concern=None, full_result=False, **update): """Perform an atomic update on the fields of the first document matched by the query. @@ -577,12 +578,19 @@ class BaseQuerySet(object): ``save(..., write_concern={w: 2, fsync: True}, ...)`` will wait until at least two servers have recorded the write and will force an fsync on the primary server. + :param full_result: Return the associated ``pymongo.UpdateResult`` rather than just the number + updated items :param update: Django-style update keyword arguments - + full_result + :returns the number of updated documents (unless ``full_result`` is True) .. versionadded:: 0.2 """ return self.update( - upsert=upsert, multi=False, write_concern=write_concern, **update) + upsert=upsert, + multi=False, + write_concern=write_concern, + full_result=full_result, + **update) def modify(self, upsert=False, full_response=False, remove=False, new=False, **update): """Update and return the updated document. diff --git a/tests/queryset/queryset.py b/tests/queryset/queryset.py index c183aa86..4dac6922 100644 --- a/tests/queryset/queryset.py +++ b/tests/queryset/queryset.py @@ -2233,6 +2233,19 @@ class QuerySetTest(unittest.TestCase): bar.reload() self.assertEqual(len(bar.foos), 0) + def test_update_one_check_return_with_full_result(self): + class BlogTag(Document): + name = StringField(required=True) + + BlogTag.drop_collection() + + BlogTag(name='garbage').save() + default_update = BlogTag.objects.update_one(name='new') + self.assertEqual(default_update, 1) + + full_result_update = BlogTag.objects.update_one(name='new', full_result=True) + self.assertIsInstance(full_result_update, UpdateResult) + def test_update_one_pop_generic_reference(self): class BlogTag(Document): From 4a46f5f095e424fb0b91ac4ef85bed9241f33ae5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Sun, 17 Feb 2019 21:32:32 +0100 Subject: [PATCH 02/71] Separate fields tests into separate modules (date/datetime/complexdatetime) relates to #1983 --- tests/fields/fields.py | 497 +------------------- tests/fields/test_complex_datetime_field.py | 189 ++++++++ tests/fields/test_date_field.py | 184 ++++++++ tests/fields/test_datetime_field.py | 208 ++++++++ 4 files changed, 583 insertions(+), 495 deletions(-) create mode 100644 tests/fields/test_complex_datetime_field.py create mode 100644 tests/fields/test_date_field.py create mode 100644 tests/fields/test_datetime_field.py diff --git a/tests/fields/fields.py b/tests/fields/fields.py index b09c0a2d..8f256782 100644 --- a/tests/fields/fields.py +++ b/tests/fields/fields.py @@ -10,11 +10,6 @@ import sys from nose.plugins.skip import SkipTest import six -try: - import dateutil -except ImportError: - dateutil = None - from decimal import Decimal from bson import Binary, DBRef, ObjectId, SON @@ -30,55 +25,9 @@ from mongoengine.base import (BaseDict, BaseField, EmbeddedDocumentList, from tests.utils import MongoDBTestCase -__all__ = ("FieldTest", "EmbeddedDocumentListFieldTestCase") - class FieldTest(MongoDBTestCase): - def test_datetime_from_empty_string(self): - """ - Ensure an exception is raised when trying to - cast an empty string to datetime. - """ - class MyDoc(Document): - dt = DateTimeField() - - md = MyDoc(dt='') - self.assertRaises(ValidationError, md.save) - - def test_date_from_empty_string(self): - """ - Ensure an exception is raised when trying to - cast an empty string to datetime. - """ - class MyDoc(Document): - dt = DateField() - - md = MyDoc(dt='') - self.assertRaises(ValidationError, md.save) - - def test_datetime_from_whitespace_string(self): - """ - Ensure an exception is raised when trying to - cast a whitespace-only string to datetime. - """ - class MyDoc(Document): - dt = DateTimeField() - - md = MyDoc(dt=' ') - self.assertRaises(ValidationError, md.save) - - def test_date_from_whitespace_string(self): - """ - Ensure an exception is raised when trying to - cast a whitespace-only string to datetime. - """ - class MyDoc(Document): - dt = DateField() - - md = MyDoc(dt=' ') - self.assertRaises(ValidationError, md.save) - def test_default_values_nothing_set(self): """Ensure that default field values are used when creating a document. @@ -695,273 +644,6 @@ class FieldTest(MongoDBTestCase): person.api_key = api_key self.assertRaises(ValidationError, person.validate) - def test_datetime_validation(self): - """Ensure that invalid values cannot be assigned to datetime - fields. - """ - class LogEntry(Document): - time = DateTimeField() - - log = LogEntry() - log.time = datetime.datetime.now() - log.validate() - - log.time = datetime.date.today() - log.validate() - - log.time = datetime.datetime.now().isoformat(' ') - log.validate() - - if dateutil: - log.time = datetime.datetime.now().isoformat('T') - log.validate() - - log.time = -1 - self.assertRaises(ValidationError, log.validate) - log.time = 'ABC' - self.assertRaises(ValidationError, log.validate) - - def test_date_validation(self): - """Ensure that invalid values cannot be assigned to datetime - fields. - """ - class LogEntry(Document): - time = DateField() - - log = LogEntry() - log.time = datetime.datetime.now() - log.validate() - - log.time = datetime.date.today() - log.validate() - - log.time = datetime.datetime.now().isoformat(' ') - log.validate() - - if dateutil: - log.time = datetime.datetime.now().isoformat('T') - log.validate() - - log.time = -1 - self.assertRaises(ValidationError, log.validate) - log.time = 'ABC' - self.assertRaises(ValidationError, log.validate) - - def test_datetime_tz_aware_mark_as_changed(self): - from mongoengine import connection - - # Reset the connections - connection._connection_settings = {} - connection._connections = {} - connection._dbs = {} - - connect(db='mongoenginetest', tz_aware=True) - - class LogEntry(Document): - time = DateTimeField() - - LogEntry.drop_collection() - - LogEntry(time=datetime.datetime(2013, 1, 1, 0, 0, 0)).save() - - log = LogEntry.objects.first() - log.time = datetime.datetime(2013, 1, 1, 0, 0, 0) - self.assertEqual(['time'], log._changed_fields) - - def test_datetime(self): - """Tests showing pymongo datetime fields handling of microseconds. - Microseconds are rounded to the nearest millisecond and pre UTC - handling is wonky. - - See: http://api.mongodb.org/python/current/api/bson/son.html#dt - """ - class LogEntry(Document): - date = DateTimeField() - - LogEntry.drop_collection() - - # Test can save dates - log = LogEntry() - log.date = datetime.date.today() - log.save() - log.reload() - self.assertEqual(log.date.date(), datetime.date.today()) - - # Post UTC - microseconds are rounded (down) nearest millisecond and - # dropped - d1 = datetime.datetime(1970, 1, 1, 0, 0, 1, 999) - d2 = datetime.datetime(1970, 1, 1, 0, 0, 1) - log = LogEntry() - log.date = d1 - log.save() - log.reload() - self.assertNotEqual(log.date, d1) - self.assertEqual(log.date, d2) - - # Post UTC - microseconds are rounded (down) nearest millisecond - d1 = datetime.datetime(1970, 1, 1, 0, 0, 1, 9999) - d2 = datetime.datetime(1970, 1, 1, 0, 0, 1, 9000) - log.date = d1 - log.save() - log.reload() - self.assertNotEqual(log.date, d1) - self.assertEqual(log.date, d2) - - if not six.PY3: - # Pre UTC dates microseconds below 1000 are dropped - # This does not seem to be true in PY3 - d1 = datetime.datetime(1969, 12, 31, 23, 59, 59, 999) - d2 = datetime.datetime(1969, 12, 31, 23, 59, 59) - log.date = d1 - log.save() - log.reload() - self.assertNotEqual(log.date, d1) - self.assertEqual(log.date, d2) - - def test_date(self): - """Tests showing pymongo date fields - - See: http://api.mongodb.org/python/current/api/bson/son.html#dt - """ - class LogEntry(Document): - date = DateField() - - LogEntry.drop_collection() - - # Test can save dates - log = LogEntry() - log.date = datetime.date.today() - log.save() - log.reload() - self.assertEqual(log.date, datetime.date.today()) - - d1 = datetime.datetime(1970, 1, 1, 0, 0, 1, 999) - d2 = datetime.datetime(1970, 1, 1, 0, 0, 1) - log = LogEntry() - log.date = d1 - log.save() - log.reload() - self.assertEqual(log.date, d1.date()) - self.assertEqual(log.date, d2.date()) - - d1 = datetime.datetime(1970, 1, 1, 0, 0, 1, 9999) - d2 = datetime.datetime(1970, 1, 1, 0, 0, 1, 9000) - log.date = d1 - log.save() - log.reload() - self.assertEqual(log.date, d1.date()) - self.assertEqual(log.date, d2.date()) - - if not six.PY3: - # Pre UTC dates microseconds below 1000 are dropped - # This does not seem to be true in PY3 - d1 = datetime.datetime(1969, 12, 31, 23, 59, 59, 999) - d2 = datetime.datetime(1969, 12, 31, 23, 59, 59) - log.date = d1 - log.save() - log.reload() - self.assertEqual(log.date, d1.date()) - self.assertEqual(log.date, d2.date()) - - def test_datetime_usage(self): - """Tests for regular datetime fields""" - class LogEntry(Document): - date = DateTimeField() - - LogEntry.drop_collection() - - d1 = datetime.datetime(1970, 1, 1, 0, 0, 1) - log = LogEntry() - log.date = d1 - log.validate() - log.save() - - for query in (d1, d1.isoformat(' ')): - log1 = LogEntry.objects.get(date=query) - self.assertEqual(log, log1) - - if dateutil: - log1 = LogEntry.objects.get(date=d1.isoformat('T')) - self.assertEqual(log, log1) - - # create additional 19 log entries for a total of 20 - for i in range(1971, 1990): - d = datetime.datetime(i, 1, 1, 0, 0, 1) - LogEntry(date=d).save() - - self.assertEqual(LogEntry.objects.count(), 20) - - # Test ordering - logs = LogEntry.objects.order_by("date") - i = 0 - while i < 19: - self.assertTrue(logs[i].date <= logs[i + 1].date) - i += 1 - - logs = LogEntry.objects.order_by("-date") - i = 0 - while i < 19: - self.assertTrue(logs[i].date >= logs[i + 1].date) - i += 1 - - # Test searching - logs = LogEntry.objects.filter(date__gte=datetime.datetime(1980, 1, 1)) - self.assertEqual(logs.count(), 10) - - logs = LogEntry.objects.filter(date__lte=datetime.datetime(1980, 1, 1)) - self.assertEqual(logs.count(), 10) - - logs = LogEntry.objects.filter( - date__lte=datetime.datetime(1980, 1, 1), - date__gte=datetime.datetime(1975, 1, 1), - ) - self.assertEqual(logs.count(), 5) - - def test_date_usage(self): - """Tests for regular datetime fields""" - class LogEntry(Document): - date = DateField() - - LogEntry.drop_collection() - - d1 = datetime.datetime(1970, 1, 1, 0, 0, 1) - log = LogEntry() - log.date = d1 - log.validate() - log.save() - - for query in (d1, d1.isoformat(' ')): - log1 = LogEntry.objects.get(date=query) - self.assertEqual(log, log1) - - if dateutil: - log1 = LogEntry.objects.get(date=d1.isoformat('T')) - self.assertEqual(log, log1) - - # create additional 19 log entries for a total of 20 - for i in range(1971, 1990): - d = datetime.datetime(i, 1, 1, 0, 0, 1) - LogEntry(date=d).save() - - self.assertEqual(LogEntry.objects.count(), 20) - - # Test ordering - logs = LogEntry.objects.order_by("date") - i = 0 - while i < 19: - self.assertTrue(logs[i].date <= logs[i + 1].date) - i += 1 - - logs = LogEntry.objects.order_by("-date") - i = 0 - while i < 19: - self.assertTrue(logs[i].date >= logs[i + 1].date) - i += 1 - - # Test searching - logs = LogEntry.objects.filter(date__gte=datetime.datetime(1980, 1, 1)) - self.assertEqual(logs.count(), 10) - def test_list_validation(self): """Ensure that a list field only accepts lists with valid elements.""" AccessLevelChoices = ( @@ -1699,7 +1381,7 @@ class FieldTest(MongoDBTestCase): post.save() post = BlogPost() - post.info = {'title' : 'dollar_sign', 'details' : {'te$t' : 'test'} } + post.info = {'title': 'dollar_sign', 'details': {'te$t': 'test'}} post.save() post = BlogPost() @@ -1718,7 +1400,7 @@ class FieldTest(MongoDBTestCase): post = BlogPost.objects.filter(info__title__exact='dollar_sign').first() self.assertIn('te$t', post['info']['details']) - + # Confirm handles non strings or non existing keys self.assertEqual( BlogPost.objects.filter(info__details__test__exact=5).count(), 0) @@ -5400,180 +5082,5 @@ class GenericLazyReferenceFieldTest(MongoDBTestCase): check_fields_type(occ) -class ComplexDateTimeFieldTest(MongoDBTestCase): - def test_complexdatetime_storage(self): - """Tests for complex datetime fields - which can handle - microseconds without rounding. - """ - class LogEntry(Document): - date = ComplexDateTimeField() - date_with_dots = ComplexDateTimeField(separator='.') - - LogEntry.drop_collection() - - # Post UTC - microseconds are rounded (down) nearest millisecond and - # dropped - with default datetimefields - d1 = datetime.datetime(1970, 1, 1, 0, 0, 1, 999) - log = LogEntry() - log.date = d1 - log.save() - log.reload() - self.assertEqual(log.date, d1) - - # Post UTC - microseconds are rounded (down) nearest millisecond - with - # default datetimefields - d1 = datetime.datetime(1970, 1, 1, 0, 0, 1, 9999) - log.date = d1 - log.save() - log.reload() - self.assertEqual(log.date, d1) - - # Pre UTC dates microseconds below 1000 are dropped - with default - # datetimefields - d1 = datetime.datetime(1969, 12, 31, 23, 59, 59, 999) - log.date = d1 - log.save() - log.reload() - self.assertEqual(log.date, d1) - - # Pre UTC microseconds above 1000 is wonky - with default datetimefields - # log.date has an invalid microsecond value so I can't construct - # a date to compare. - for i in range(1001, 3113, 33): - d1 = datetime.datetime(1969, 12, 31, 23, 59, 59, i) - log.date = d1 - log.save() - log.reload() - self.assertEqual(log.date, d1) - log1 = LogEntry.objects.get(date=d1) - self.assertEqual(log, log1) - - # Test string padding - microsecond = map(int, [math.pow(10, x) for x in range(6)]) - mm = dd = hh = ii = ss = [1, 10] - - for values in itertools.product([2014], mm, dd, hh, ii, ss, microsecond): - stored = LogEntry(date=datetime.datetime(*values)).to_mongo()['date'] - self.assertTrue(re.match('^\d{4},\d{2},\d{2},\d{2},\d{2},\d{2},\d{6}$', stored) is not None) - - # Test separator - stored = LogEntry(date_with_dots=datetime.datetime(2014, 1, 1)).to_mongo()['date_with_dots'] - self.assertTrue(re.match('^\d{4}.\d{2}.\d{2}.\d{2}.\d{2}.\d{2}.\d{6}$', stored) is not None) - - def test_complexdatetime_usage(self): - """Tests for complex datetime fields - which can handle - microseconds without rounding. - """ - class LogEntry(Document): - date = ComplexDateTimeField() - - LogEntry.drop_collection() - - d1 = datetime.datetime(1950, 1, 1, 0, 0, 1, 999) - log = LogEntry() - log.date = d1 - log.save() - - log1 = LogEntry.objects.get(date=d1) - self.assertEqual(log, log1) - - # create extra 59 log entries for a total of 60 - for i in range(1951, 2010): - d = datetime.datetime(i, 1, 1, 0, 0, 1, 999) - LogEntry(date=d).save() - - self.assertEqual(LogEntry.objects.count(), 60) - - # Test ordering - logs = LogEntry.objects.order_by("date") - i = 0 - while i < 59: - self.assertTrue(logs[i].date <= logs[i + 1].date) - i += 1 - - logs = LogEntry.objects.order_by("-date") - i = 0 - while i < 59: - self.assertTrue(logs[i].date >= logs[i + 1].date) - i += 1 - - # Test searching - logs = LogEntry.objects.filter(date__gte=datetime.datetime(1980, 1, 1)) - self.assertEqual(logs.count(), 30) - - logs = LogEntry.objects.filter(date__lte=datetime.datetime(1980, 1, 1)) - self.assertEqual(logs.count(), 30) - - logs = LogEntry.objects.filter( - date__lte=datetime.datetime(2011, 1, 1), - date__gte=datetime.datetime(2000, 1, 1), - ) - self.assertEqual(logs.count(), 10) - - LogEntry.drop_collection() - - # Test microsecond-level ordering/filtering - for microsecond in (99, 999, 9999, 10000): - LogEntry( - date=datetime.datetime(2015, 1, 1, 0, 0, 0, microsecond) - ).save() - - logs = list(LogEntry.objects.order_by('date')) - for next_idx, log in enumerate(logs[:-1], start=1): - next_log = logs[next_idx] - self.assertTrue(log.date < next_log.date) - - logs = list(LogEntry.objects.order_by('-date')) - for next_idx, log in enumerate(logs[:-1], start=1): - next_log = logs[next_idx] - self.assertTrue(log.date > next_log.date) - - logs = LogEntry.objects.filter( - date__lte=datetime.datetime(2015, 1, 1, 0, 0, 0, 10000)) - self.assertEqual(logs.count(), 4) - - def test_no_default_value(self): - class Log(Document): - timestamp = ComplexDateTimeField() - - Log.drop_collection() - - log = Log() - self.assertIsNone(log.timestamp) - log.save() - - fetched_log = Log.objects.with_id(log.id) - self.assertIsNone(fetched_log.timestamp) - - def test_default_static_value(self): - NOW = datetime.datetime.utcnow() - class Log(Document): - timestamp = ComplexDateTimeField(default=NOW) - - Log.drop_collection() - - log = Log() - self.assertEqual(log.timestamp, NOW) - log.save() - - fetched_log = Log.objects.with_id(log.id) - self.assertEqual(fetched_log.timestamp, NOW) - - def test_default_callable(self): - NOW = datetime.datetime.utcnow() - - class Log(Document): - timestamp = ComplexDateTimeField(default=datetime.datetime.utcnow) - - Log.drop_collection() - - log = Log() - self.assertGreaterEqual(log.timestamp, NOW) - log.save() - - fetched_log = Log.objects.with_id(log.id) - self.assertGreaterEqual(fetched_log.timestamp, NOW) - - if __name__ == '__main__': unittest.main() diff --git a/tests/fields/test_complex_datetime_field.py b/tests/fields/test_complex_datetime_field.py new file mode 100644 index 00000000..bac534c0 --- /dev/null +++ b/tests/fields/test_complex_datetime_field.py @@ -0,0 +1,189 @@ +# -*- coding: utf-8 -*- +import datetime +import math +import itertools +import re + +try: + from bson.int64 import Int64 +except ImportError: + Int64 = long + +from mongoengine import * + +from tests.utils import MongoDBTestCase + + +class ComplexDateTimeFieldTest(MongoDBTestCase): + def test_complexdatetime_storage(self): + """Tests for complex datetime fields - which can handle + microseconds without rounding. + """ + class LogEntry(Document): + date = ComplexDateTimeField() + date_with_dots = ComplexDateTimeField(separator='.') + + LogEntry.drop_collection() + + # Post UTC - microseconds are rounded (down) nearest millisecond and + # dropped - with default datetimefields + d1 = datetime.datetime(1970, 1, 1, 0, 0, 1, 999) + log = LogEntry() + log.date = d1 + log.save() + log.reload() + self.assertEqual(log.date, d1) + + # Post UTC - microseconds are rounded (down) nearest millisecond - with + # default datetimefields + d1 = datetime.datetime(1970, 1, 1, 0, 0, 1, 9999) + log.date = d1 + log.save() + log.reload() + self.assertEqual(log.date, d1) + + # Pre UTC dates microseconds below 1000 are dropped - with default + # datetimefields + d1 = datetime.datetime(1969, 12, 31, 23, 59, 59, 999) + log.date = d1 + log.save() + log.reload() + self.assertEqual(log.date, d1) + + # Pre UTC microseconds above 1000 is wonky - with default datetimefields + # log.date has an invalid microsecond value so I can't construct + # a date to compare. + for i in range(1001, 3113, 33): + d1 = datetime.datetime(1969, 12, 31, 23, 59, 59, i) + log.date = d1 + log.save() + log.reload() + self.assertEqual(log.date, d1) + log1 = LogEntry.objects.get(date=d1) + self.assertEqual(log, log1) + + # Test string padding + microsecond = map(int, [math.pow(10, x) for x in range(6)]) + mm = dd = hh = ii = ss = [1, 10] + + for values in itertools.product([2014], mm, dd, hh, ii, ss, microsecond): + stored = LogEntry(date=datetime.datetime(*values)).to_mongo()['date'] + self.assertTrue(re.match('^\d{4},\d{2},\d{2},\d{2},\d{2},\d{2},\d{6}$', stored) is not None) + + # Test separator + stored = LogEntry(date_with_dots=datetime.datetime(2014, 1, 1)).to_mongo()['date_with_dots'] + self.assertTrue(re.match('^\d{4}.\d{2}.\d{2}.\d{2}.\d{2}.\d{2}.\d{6}$', stored) is not None) + + def test_complexdatetime_usage(self): + """Tests for complex datetime fields - which can handle + microseconds without rounding. + """ + class LogEntry(Document): + date = ComplexDateTimeField() + + LogEntry.drop_collection() + + d1 = datetime.datetime(1950, 1, 1, 0, 0, 1, 999) + log = LogEntry() + log.date = d1 + log.save() + + log1 = LogEntry.objects.get(date=d1) + self.assertEqual(log, log1) + + # create extra 59 log entries for a total of 60 + for i in range(1951, 2010): + d = datetime.datetime(i, 1, 1, 0, 0, 1, 999) + LogEntry(date=d).save() + + self.assertEqual(LogEntry.objects.count(), 60) + + # Test ordering + logs = LogEntry.objects.order_by("date") + i = 0 + while i < 59: + self.assertTrue(logs[i].date <= logs[i + 1].date) + i += 1 + + logs = LogEntry.objects.order_by("-date") + i = 0 + while i < 59: + self.assertTrue(logs[i].date >= logs[i + 1].date) + i += 1 + + # Test searching + logs = LogEntry.objects.filter(date__gte=datetime.datetime(1980, 1, 1)) + self.assertEqual(logs.count(), 30) + + logs = LogEntry.objects.filter(date__lte=datetime.datetime(1980, 1, 1)) + self.assertEqual(logs.count(), 30) + + logs = LogEntry.objects.filter( + date__lte=datetime.datetime(2011, 1, 1), + date__gte=datetime.datetime(2000, 1, 1), + ) + self.assertEqual(logs.count(), 10) + + LogEntry.drop_collection() + + # Test microsecond-level ordering/filtering + for microsecond in (99, 999, 9999, 10000): + LogEntry( + date=datetime.datetime(2015, 1, 1, 0, 0, 0, microsecond) + ).save() + + logs = list(LogEntry.objects.order_by('date')) + for next_idx, log in enumerate(logs[:-1], start=1): + next_log = logs[next_idx] + self.assertTrue(log.date < next_log.date) + + logs = list(LogEntry.objects.order_by('-date')) + for next_idx, log in enumerate(logs[:-1], start=1): + next_log = logs[next_idx] + self.assertTrue(log.date > next_log.date) + + logs = LogEntry.objects.filter( + date__lte=datetime.datetime(2015, 1, 1, 0, 0, 0, 10000)) + self.assertEqual(logs.count(), 4) + + def test_no_default_value(self): + class Log(Document): + timestamp = ComplexDateTimeField() + + Log.drop_collection() + + log = Log() + self.assertIsNone(log.timestamp) + log.save() + + fetched_log = Log.objects.with_id(log.id) + self.assertIsNone(fetched_log.timestamp) + + def test_default_static_value(self): + NOW = datetime.datetime.utcnow() + class Log(Document): + timestamp = ComplexDateTimeField(default=NOW) + + Log.drop_collection() + + log = Log() + self.assertEqual(log.timestamp, NOW) + log.save() + + fetched_log = Log.objects.with_id(log.id) + self.assertEqual(fetched_log.timestamp, NOW) + + def test_default_callable(self): + NOW = datetime.datetime.utcnow() + + class Log(Document): + timestamp = ComplexDateTimeField(default=datetime.datetime.utcnow) + + Log.drop_collection() + + log = Log() + self.assertGreaterEqual(log.timestamp, NOW) + log.save() + + fetched_log = Log.objects.with_id(log.id) + self.assertGreaterEqual(fetched_log.timestamp, NOW) diff --git a/tests/fields/test_date_field.py b/tests/fields/test_date_field.py new file mode 100644 index 00000000..b5aed5c1 --- /dev/null +++ b/tests/fields/test_date_field.py @@ -0,0 +1,184 @@ +# -*- coding: utf-8 -*- +import datetime +import unittest +import uuid +import math +import itertools +import re +import sys + +from nose.plugins.skip import SkipTest +import six + +try: + import dateutil +except ImportError: + dateutil = None + +from decimal import Decimal + +from bson import Binary, DBRef, ObjectId, SON +try: + from bson.int64 import Int64 +except ImportError: + Int64 = long + +from mongoengine import * +from mongoengine.connection import get_db +from mongoengine.base import (BaseDict, BaseField, EmbeddedDocumentList, + _document_registry, LazyReference) + +from tests.utils import MongoDBTestCase + + +class TestDateField(MongoDBTestCase): + def test_date_from_empty_string(self): + """ + Ensure an exception is raised when trying to + cast an empty string to datetime. + """ + class MyDoc(Document): + dt = DateField() + + md = MyDoc(dt='') + self.assertRaises(ValidationError, md.save) + + def test_date_from_whitespace_string(self): + """ + Ensure an exception is raised when trying to + cast a whitespace-only string to datetime. + """ + class MyDoc(Document): + dt = DateField() + + md = MyDoc(dt=' ') + self.assertRaises(ValidationError, md.save) + + def test_default_values_today(self): + """Ensure that default field values are used when creating + a document. + """ + class Person(Document): + day = DateField(default=datetime.date.today) + + person = Person() + person.validate() + self.assertEqual(person.day, person.day) + self.assertEqual(person.day, datetime.date.today()) + self.assertEqual(person._data['day'], person.day) + + def test_date(self): + """Tests showing pymongo date fields + + See: http://api.mongodb.org/python/current/api/bson/son.html#dt + """ + class LogEntry(Document): + date = DateField() + + LogEntry.drop_collection() + + # Test can save dates + log = LogEntry() + log.date = datetime.date.today() + log.save() + log.reload() + self.assertEqual(log.date, datetime.date.today()) + + d1 = datetime.datetime(1970, 1, 1, 0, 0, 1, 999) + d2 = datetime.datetime(1970, 1, 1, 0, 0, 1) + log = LogEntry() + log.date = d1 + log.save() + log.reload() + self.assertEqual(log.date, d1.date()) + self.assertEqual(log.date, d2.date()) + + d1 = datetime.datetime(1970, 1, 1, 0, 0, 1, 9999) + d2 = datetime.datetime(1970, 1, 1, 0, 0, 1, 9000) + log.date = d1 + log.save() + log.reload() + self.assertEqual(log.date, d1.date()) + self.assertEqual(log.date, d2.date()) + + if not six.PY3: + # Pre UTC dates microseconds below 1000 are dropped + # This does not seem to be true in PY3 + d1 = datetime.datetime(1969, 12, 31, 23, 59, 59, 999) + d2 = datetime.datetime(1969, 12, 31, 23, 59, 59) + log.date = d1 + log.save() + log.reload() + self.assertEqual(log.date, d1.date()) + self.assertEqual(log.date, d2.date()) + + def test_regular_usage(self): + """Tests for regular datetime fields""" + class LogEntry(Document): + date = DateField() + + LogEntry.drop_collection() + + d1 = datetime.datetime(1970, 1, 1, 0, 0, 1) + log = LogEntry() + log.date = d1 + log.validate() + log.save() + + for query in (d1, d1.isoformat(' ')): + log1 = LogEntry.objects.get(date=query) + self.assertEqual(log, log1) + + if dateutil: + log1 = LogEntry.objects.get(date=d1.isoformat('T')) + self.assertEqual(log, log1) + + # create additional 19 log entries for a total of 20 + for i in range(1971, 1990): + d = datetime.datetime(i, 1, 1, 0, 0, 1) + LogEntry(date=d).save() + + self.assertEqual(LogEntry.objects.count(), 20) + + # Test ordering + logs = LogEntry.objects.order_by("date") + i = 0 + while i < 19: + self.assertTrue(logs[i].date <= logs[i + 1].date) + i += 1 + + logs = LogEntry.objects.order_by("-date") + i = 0 + while i < 19: + self.assertTrue(logs[i].date >= logs[i + 1].date) + i += 1 + + # Test searching + logs = LogEntry.objects.filter(date__gte=datetime.datetime(1980, 1, 1)) + self.assertEqual(logs.count(), 10) + + def test_validation(self): + """Ensure that invalid values cannot be assigned to datetime + fields. + """ + class LogEntry(Document): + time = DateField() + + log = LogEntry() + log.time = datetime.datetime.now() + log.validate() + + log.time = datetime.date.today() + log.validate() + + log.time = datetime.datetime.now().isoformat(' ') + log.validate() + + if dateutil: + log.time = datetime.datetime.now().isoformat('T') + log.validate() + + log.time = -1 + self.assertRaises(ValidationError, log.validate) + log.time = 'ABC' + self.assertRaises(ValidationError, log.validate) diff --git a/tests/fields/test_datetime_field.py b/tests/fields/test_datetime_field.py new file mode 100644 index 00000000..24d1c777 --- /dev/null +++ b/tests/fields/test_datetime_field.py @@ -0,0 +1,208 @@ +# -*- coding: utf-8 -*- +import datetime +import six + +try: + import dateutil +except ImportError: + dateutil = None + +try: + from bson.int64 import Int64 +except ImportError: + Int64 = long + +from mongoengine import * +from mongoengine import connection + +from tests.utils import MongoDBTestCase + + +class TestDateTimeField(MongoDBTestCase): + def test_datetime_from_empty_string(self): + """ + Ensure an exception is raised when trying to + cast an empty string to datetime. + """ + class MyDoc(Document): + dt = DateTimeField() + + md = MyDoc(dt='') + self.assertRaises(ValidationError, md.save) + + def test_datetime_from_whitespace_string(self): + """ + Ensure an exception is raised when trying to + cast a whitespace-only string to datetime. + """ + class MyDoc(Document): + dt = DateTimeField() + + md = MyDoc(dt=' ') + self.assertRaises(ValidationError, md.save) + + def test_default_value_utcnow(self): + """Ensure that default field values are used when creating + a document. + """ + class Person(Document): + created = DateTimeField(default=datetime.datetime.utcnow) + + utcnow = datetime.datetime.utcnow() + person = Person() + person.validate() + person_created_t0 = person.created + self.assertLess(person.created - utcnow, datetime.timedelta(seconds=1)) + self.assertEqual(person_created_t0, person.created) # make sure it does not change + self.assertEqual(person._data['created'], person.created) + + def test_handling_microseconds(self): + """Tests showing pymongo datetime fields handling of microseconds. + Microseconds are rounded to the nearest millisecond and pre UTC + handling is wonky. + + See: http://api.mongodb.org/python/current/api/bson/son.html#dt + """ + class LogEntry(Document): + date = DateTimeField() + + LogEntry.drop_collection() + + # Test can save dates + log = LogEntry() + log.date = datetime.date.today() + log.save() + log.reload() + self.assertEqual(log.date.date(), datetime.date.today()) + + # Post UTC - microseconds are rounded (down) nearest millisecond and + # dropped + d1 = datetime.datetime(1970, 1, 1, 0, 0, 1, 999) + d2 = datetime.datetime(1970, 1, 1, 0, 0, 1) + log = LogEntry() + log.date = d1 + log.save() + log.reload() + self.assertNotEqual(log.date, d1) + self.assertEqual(log.date, d2) + + # Post UTC - microseconds are rounded (down) nearest millisecond + d1 = datetime.datetime(1970, 1, 1, 0, 0, 1, 9999) + d2 = datetime.datetime(1970, 1, 1, 0, 0, 1, 9000) + log.date = d1 + log.save() + log.reload() + self.assertNotEqual(log.date, d1) + self.assertEqual(log.date, d2) + + if not six.PY3: + # Pre UTC dates microseconds below 1000 are dropped + # This does not seem to be true in PY3 + d1 = datetime.datetime(1969, 12, 31, 23, 59, 59, 999) + d2 = datetime.datetime(1969, 12, 31, 23, 59, 59) + log.date = d1 + log.save() + log.reload() + self.assertNotEqual(log.date, d1) + self.assertEqual(log.date, d2) + + def test_regular_usage(self): + """Tests for regular datetime fields""" + class LogEntry(Document): + date = DateTimeField() + + LogEntry.drop_collection() + + d1 = datetime.datetime(1970, 1, 1, 0, 0, 1) + log = LogEntry() + log.date = d1 + log.validate() + log.save() + + for query in (d1, d1.isoformat(' ')): + log1 = LogEntry.objects.get(date=query) + self.assertEqual(log, log1) + + if dateutil: + log1 = LogEntry.objects.get(date=d1.isoformat('T')) + self.assertEqual(log, log1) + + # create additional 19 log entries for a total of 20 + for i in range(1971, 1990): + d = datetime.datetime(i, 1, 1, 0, 0, 1) + LogEntry(date=d).save() + + self.assertEqual(LogEntry.objects.count(), 20) + + # Test ordering + logs = LogEntry.objects.order_by("date") + i = 0 + while i < 19: + self.assertTrue(logs[i].date <= logs[i + 1].date) + i += 1 + + logs = LogEntry.objects.order_by("-date") + i = 0 + while i < 19: + self.assertTrue(logs[i].date >= logs[i + 1].date) + i += 1 + + # Test searching + logs = LogEntry.objects.filter(date__gte=datetime.datetime(1980, 1, 1)) + self.assertEqual(logs.count(), 10) + + logs = LogEntry.objects.filter(date__lte=datetime.datetime(1980, 1, 1)) + self.assertEqual(logs.count(), 10) + + logs = LogEntry.objects.filter( + date__lte=datetime.datetime(1980, 1, 1), + date__gte=datetime.datetime(1975, 1, 1), + ) + self.assertEqual(logs.count(), 5) + + def test_datetime_validation(self): + """Ensure that invalid values cannot be assigned to datetime + fields. + """ + class LogEntry(Document): + time = DateTimeField() + + log = LogEntry() + log.time = datetime.datetime.now() + log.validate() + + log.time = datetime.date.today() + log.validate() + + log.time = datetime.datetime.now().isoformat(' ') + log.validate() + + if dateutil: + log.time = datetime.datetime.now().isoformat('T') + log.validate() + + log.time = -1 + self.assertRaises(ValidationError, log.validate) + log.time = 'ABC' + self.assertRaises(ValidationError, log.validate) + + +class TestDateTimeTzAware(MongoDBTestCase): + def test_datetime_tz_aware_mark_as_changed(self): + # Reset the connections + connection._connection_settings = {} + connection._connections = {} + connection._dbs = {} + + connect(db='mongoenginetest', tz_aware=True) + + class LogEntry(Document): + time = DateTimeField() + + LogEntry.drop_collection() + + LogEntry(time=datetime.datetime(2013, 1, 1, 0, 0, 0)).save() + + log = LogEntry.objects.first() + log.time = datetime.datetime(2013, 1, 1, 0, 0, 0) + self.assertEqual(['time'], log._changed_fields) From 6d353dae1e52b54f244fb1c6459b5d314e9d522d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Mon, 18 Feb 2019 21:08:04 +0100 Subject: [PATCH 03/71] refactored iteritems/itervalues to improve 2/3 compat #2003 --- mongoengine/base/datastructures.py | 5 +++-- mongoengine/base/document.py | 17 +++++++++-------- mongoengine/base/fields.py | 7 ++++--- mongoengine/base/metaclasses.py | 13 +++++++------ mongoengine/context_managers.py | 6 ++++-- mongoengine/dereference.py | 15 ++++++++------- mongoengine/document.py | 3 ++- mongoengine/errors.py | 9 +++++---- mongoengine/fields.py | 5 +++-- mongoengine/queryset/base.py | 5 +++-- mongoengine/queryset/transform.py | 3 ++- tests/document/indexes.py | 28 ++++++++++++++-------------- tests/document/inheritance.py | 4 +++- tests/document/instance.py | 9 ++++----- tests/queryset/queryset.py | 3 ++- tests/test_dereference.py | 27 ++++++++++++++------------- 16 files changed, 87 insertions(+), 72 deletions(-) diff --git a/mongoengine/base/datastructures.py b/mongoengine/base/datastructures.py index 808332b9..fafc08b7 100644 --- a/mongoengine/base/datastructures.py +++ b/mongoengine/base/datastructures.py @@ -2,6 +2,7 @@ import weakref from bson import DBRef import six +from six import iteritems from mongoengine.common import _import_class from mongoengine.errors import DoesNotExist, MultipleObjectsReturned @@ -363,7 +364,7 @@ class StrictDict(object): _classes = {} def __init__(self, **kwargs): - for k, v in kwargs.iteritems(): + for k, v in iteritems(kwargs): setattr(self, k, v) def __getitem__(self, key): @@ -411,7 +412,7 @@ class StrictDict(object): return (key for key in self.__slots__ if hasattr(self, key)) def __len__(self): - return len(list(self.iteritems())) + return len(list(iteritems(self))) def __eq__(self, other): return self.items() == other.items() diff --git a/mongoengine/base/document.py b/mongoengine/base/document.py index 6a4c6bd9..8587f17f 100644 --- a/mongoengine/base/document.py +++ b/mongoengine/base/document.py @@ -5,6 +5,7 @@ from functools import partial from bson import DBRef, ObjectId, SON, json_util import pymongo import six +from six import iteritems from mongoengine import signals from mongoengine.base.common import get_document @@ -83,7 +84,7 @@ class BaseDocument(object): self._dynamic_fields = SON() # Assign default values to instance - for key, field in self._fields.iteritems(): + for key, field in iteritems(self._fields): if self._db_field_map.get(key, key) in __only_fields: continue value = getattr(self, key, None) @@ -95,14 +96,14 @@ class BaseDocument(object): # Set passed values after initialisation if self._dynamic: dynamic_data = {} - for key, value in values.iteritems(): + for key, value in iteritems(values): if key in self._fields or key == '_id': setattr(self, key, value) else: dynamic_data[key] = value else: FileField = _import_class('FileField') - for key, value in values.iteritems(): + for key, value in iteritems(values): key = self._reverse_db_field_map.get(key, key) if key in self._fields or key in ('id', 'pk', '_cls'): if __auto_convert and value is not None: @@ -118,7 +119,7 @@ class BaseDocument(object): if self._dynamic: self._dynamic_lock = False - for key, value in dynamic_data.iteritems(): + for key, value in iteritems(dynamic_data): setattr(self, key, value) # Flag initialised @@ -513,7 +514,7 @@ class BaseDocument(object): if not hasattr(data, 'items'): iterator = enumerate(data) else: - iterator = data.iteritems() + iterator = iteritems(data) for index_or_key, value in iterator: item_key = '%s%s.' % (base_key, index_or_key) @@ -678,7 +679,7 @@ class BaseDocument(object): # Convert SON to a data dict, making sure each key is a string and # corresponds to the right db field. data = {} - for key, value in son.iteritems(): + for key, value in iteritems(son): key = str(key) key = cls._db_field_map.get(key, key) data[key] = value @@ -694,7 +695,7 @@ class BaseDocument(object): if not _auto_dereference: fields = copy.deepcopy(fields) - for field_name, field in fields.iteritems(): + for field_name, field in iteritems(fields): field._auto_dereference = _auto_dereference if field.db_field in data: value = data[field.db_field] @@ -715,7 +716,7 @@ class BaseDocument(object): # In STRICT documents, remove any keys that aren't in cls._fields if cls.STRICT: - data = {k: v for k, v in data.iteritems() if k in cls._fields} + data = {k: v for k, v in iteritems(data) if k in cls._fields} obj = cls(__auto_convert=False, _created=created, __only_fields=only_fields, **data) obj._changed_fields = changed_fields diff --git a/mongoengine/base/fields.py b/mongoengine/base/fields.py index a32544d8..5586c5b7 100644 --- a/mongoengine/base/fields.py +++ b/mongoengine/base/fields.py @@ -5,6 +5,7 @@ import weakref from bson import DBRef, ObjectId, SON import pymongo import six +from six import iteritems from mongoengine.base.common import UPDATE_OPERATORS from mongoengine.base.datastructures import (BaseDict, BaseList, @@ -382,11 +383,11 @@ class ComplexBaseField(BaseField): if self.field: value_dict = { key: self.field._to_mongo_safe_call(item, use_db_field, fields) - for key, item in value.iteritems() + for key, item in iteritems(value) } else: value_dict = {} - for k, v in value.iteritems(): + for k, v in iteritems(value): if isinstance(v, Document): # We need the id from the saved object to create the DBRef if v.pk is None: @@ -423,7 +424,7 @@ class ComplexBaseField(BaseField): errors = {} if self.field: if hasattr(value, 'iteritems') or hasattr(value, 'items'): - sequence = value.iteritems() + sequence = iteritems(value) else: sequence = enumerate(value) for k, v in sequence: diff --git a/mongoengine/base/metaclasses.py b/mongoengine/base/metaclasses.py index 8eb10008..a1970825 100644 --- a/mongoengine/base/metaclasses.py +++ b/mongoengine/base/metaclasses.py @@ -1,6 +1,7 @@ import warnings import six +from six import iteritems, itervalues from mongoengine.base.common import _document_registry from mongoengine.base.fields import BaseField, ComplexBaseField, ObjectIdField @@ -62,7 +63,7 @@ class DocumentMetaclass(type): # Standard object mixin - merge in any Fields if not hasattr(base, '_meta'): base_fields = {} - for attr_name, attr_value in base.__dict__.iteritems(): + for attr_name, attr_value in iteritems(base.__dict__): if not isinstance(attr_value, BaseField): continue attr_value.name = attr_name @@ -74,7 +75,7 @@ class DocumentMetaclass(type): # Discover any document fields field_names = {} - for attr_name, attr_value in attrs.iteritems(): + for attr_name, attr_value in iteritems(attrs): if not isinstance(attr_value, BaseField): continue attr_value.name = attr_name @@ -103,7 +104,7 @@ class DocumentMetaclass(type): attrs['_fields_ordered'] = tuple(i[1] for i in sorted( (v.creation_counter, v.name) - for v in doc_fields.itervalues())) + for v in itervalues(doc_fields))) # # Set document hierarchy @@ -173,7 +174,7 @@ class DocumentMetaclass(type): f.__dict__.update({'im_self': getattr(f, '__self__')}) # Handle delete rules - for field in new_class._fields.itervalues(): + for field in itervalues(new_class._fields): f = field if f.owner_document is None: f.owner_document = new_class @@ -375,7 +376,7 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): new_class.objects = QuerySetManager() # Validate the fields and set primary key if needed - for field_name, field in new_class._fields.iteritems(): + for field_name, field in iteritems(new_class._fields): if field.primary_key: # Ensure only one primary key is set current_pk = new_class._meta.get('id_field') @@ -438,7 +439,7 @@ class MetaDict(dict): _merge_options = ('indexes',) def merge(self, new_options): - for k, v in new_options.iteritems(): + for k, v in iteritems(new_options): if k in self._merge_options: self[k] = self.get(k, []) + v else: diff --git a/mongoengine/context_managers.py b/mongoengine/context_managers.py index c26b0a79..d1e5d9ef 100644 --- a/mongoengine/context_managers.py +++ b/mongoengine/context_managers.py @@ -1,9 +1,11 @@ from contextlib import contextmanager + from pymongo.write_concern import WriteConcern +from six import iteritems + from mongoengine.common import _import_class from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db - __all__ = ('switch_db', 'switch_collection', 'no_dereference', 'no_sub_classes', 'query_counter', 'set_write_concern') @@ -112,7 +114,7 @@ class no_dereference(object): GenericReferenceField = _import_class('GenericReferenceField') ComplexBaseField = _import_class('ComplexBaseField') - self.deref_fields = [k for k, v in self.cls._fields.iteritems() + self.deref_fields = [k for k, v in iteritems(self.cls._fields) if isinstance(v, (ReferenceField, GenericReferenceField, ComplexBaseField))] diff --git a/mongoengine/dereference.py b/mongoengine/dereference.py index 619b5d1f..eaebb56f 100644 --- a/mongoengine/dereference.py +++ b/mongoengine/dereference.py @@ -1,5 +1,6 @@ from bson import DBRef, SON import six +from six import iteritems from mongoengine.base import (BaseDict, BaseList, EmbeddedDocumentList, TopLevelDocumentMetaclass, get_document) @@ -71,7 +72,7 @@ class DeReference(object): def _get_items_from_dict(items): new_items = {} - for k, v in items.iteritems(): + for k, v in iteritems(items): value = v if isinstance(v, list): value = _get_items_from_list(v) @@ -112,7 +113,7 @@ class DeReference(object): depth += 1 for item in iterator: if isinstance(item, (Document, EmbeddedDocument)): - for field_name, field in item._fields.iteritems(): + for field_name, field in iteritems(item._fields): v = item._data.get(field_name, None) if isinstance(v, LazyReference): # LazyReference inherits DBRef but should not be dereferenced here ! @@ -124,7 +125,7 @@ class DeReference(object): elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth: field_cls = getattr(getattr(field, 'field', None), 'document_type', None) references = self._find_references(v, depth) - for key, refs in references.iteritems(): + for key, refs in iteritems(references): if isinstance(field_cls, (Document, TopLevelDocumentMetaclass)): key = field_cls reference_map.setdefault(key, set()).update(refs) @@ -137,7 +138,7 @@ class DeReference(object): reference_map.setdefault(get_document(item['_cls']), set()).add(item['_ref'].id) elif isinstance(item, (dict, list, tuple)) and depth - 1 <= self.max_depth: references = self._find_references(item, depth - 1) - for key, refs in references.iteritems(): + for key, refs in iteritems(references): reference_map.setdefault(key, set()).update(refs) return reference_map @@ -146,7 +147,7 @@ class DeReference(object): """Fetch all references and convert to their document objects """ object_map = {} - for collection, dbrefs in self.reference_map.iteritems(): + for collection, dbrefs in iteritems(self.reference_map): # we use getattr instead of hasattr because hasattr swallows any exception under python2 # so it could hide nasty things without raising exceptions (cfr bug #1688)) @@ -157,7 +158,7 @@ class DeReference(object): refs = [dbref for dbref in dbrefs if (col_name, dbref) not in object_map] references = collection.objects.in_bulk(refs) - for key, doc in references.iteritems(): + for key, doc in iteritems(references): object_map[(col_name, key)] = doc else: # Generic reference: use the refs data to convert to document if isinstance(doc_type, (ListField, DictField, MapField)): @@ -229,7 +230,7 @@ class DeReference(object): data = [] else: is_list = False - iterator = items.iteritems() + iterator = iteritems(items) data = {} depth += 1 diff --git a/mongoengine/document.py b/mongoengine/document.py index 7a491b7d..8885825b 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -5,6 +5,7 @@ from bson.dbref import DBRef import pymongo from pymongo.read_preferences import ReadPreference import six +from six import iteritems from mongoengine import signals from mongoengine.base import (BaseDict, BaseDocument, BaseList, @@ -607,7 +608,7 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): # Delete FileFields separately FileField = _import_class('FileField') - for name, field in self._fields.iteritems(): + for name, field in iteritems(self._fields): if isinstance(field, FileField): getattr(self, name).delete() diff --git a/mongoengine/errors.py b/mongoengine/errors.py index 986ebf73..0e92a8c4 100644 --- a/mongoengine/errors.py +++ b/mongoengine/errors.py @@ -1,6 +1,7 @@ from collections import defaultdict import six +from six import iteritems __all__ = ('NotRegistered', 'InvalidDocumentError', 'LookUpError', 'DoesNotExist', 'MultipleObjectsReturned', 'InvalidQueryError', @@ -113,7 +114,7 @@ class ValidationError(AssertionError): return errors_dict if isinstance(source, dict): - for field_name, error in source.iteritems(): + for field_name, error in iteritems(source): errors_dict[field_name] = build_dict(error) elif isinstance(source, ValidationError) and source.errors: return build_dict(source.errors) @@ -135,12 +136,12 @@ class ValidationError(AssertionError): value = ' '.join([generate_key(k) for k in value]) elif isinstance(value, dict): value = ' '.join( - [generate_key(v, k) for k, v in value.iteritems()]) + [generate_key(v, k) for k, v in iteritems(value)]) results = '%s.%s' % (prefix, value) if prefix else value return results error_dict = defaultdict(list) - for k, v in self.to_dict().iteritems(): + for k, v in iteritems(self.to_dict()): error_dict[generate_key(v)].append(k) - return ' '.join(['%s: %s' % (k, v) for k, v in error_dict.iteritems()]) + return ' '.join(['%s: %s' % (k, v) for k, v in iteritems(error_dict)]) diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 0055bcab..52ed4bc9 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -11,6 +11,7 @@ from bson import Binary, DBRef, ObjectId, SON import gridfs import pymongo import six +from six import iteritems try: import dateutil @@ -794,12 +795,12 @@ class DynamicField(BaseField): value = {k: v for k, v in enumerate(value)} data = {} - for k, v in value.iteritems(): + for k, v in iteritems(value): data[k] = self.to_mongo(v, use_db_field, fields) value = data if is_list: # Convert back to a list - value = [v for k, v in sorted(data.iteritems(), key=itemgetter(0))] + value = [v for k, v in sorted(iteritems(data), key=itemgetter(0))] return value def to_python(self, value): diff --git a/mongoengine/queryset/base.py b/mongoengine/queryset/base.py index 0ebeafa6..f39fd65f 100644 --- a/mongoengine/queryset/base.py +++ b/mongoengine/queryset/base.py @@ -12,6 +12,7 @@ import pymongo import pymongo.errors from pymongo.common import validate_read_preference import six +from six import iteritems from mongoengine import signals from mongoengine.base import get_document @@ -1731,13 +1732,13 @@ class BaseQuerySet(object): } """ total, data, types = self.exec_js(freq_func, field) - values = {types.get(k): int(v) for k, v in data.iteritems()} + values = {types.get(k): int(v) for k, v in iteritems(data)} if normalize: values = {k: float(v) / total for k, v in values.items()} frequencies = {} - for k, v in values.iteritems(): + for k, v in iteritems(values): if isinstance(k, float): if int(k) == k: k = int(k) diff --git a/mongoengine/queryset/transform.py b/mongoengine/queryset/transform.py index 2d22c350..c00271f3 100644 --- a/mongoengine/queryset/transform.py +++ b/mongoengine/queryset/transform.py @@ -4,6 +4,7 @@ from bson import ObjectId, SON from bson.dbref import DBRef import pymongo import six +from six import iteritems from mongoengine.base import UPDATE_OPERATORS from mongoengine.common import _import_class @@ -154,7 +155,7 @@ def query(_doc_cls=None, **kwargs): if ('$maxDistance' in value_dict or '$minDistance' in value_dict) and \ ('$near' in value_dict or '$nearSphere' in value_dict): value_son = SON() - for k, v in value_dict.iteritems(): + for k, v in iteritems(value_dict): if k == '$maxDistance' or k == '$minDistance': continue value_son[k] = v diff --git a/tests/document/indexes.py b/tests/document/indexes.py index 757d8037..57f48587 100644 --- a/tests/document/indexes.py +++ b/tests/document/indexes.py @@ -5,6 +5,7 @@ from datetime import datetime from nose.plugins.skip import SkipTest from pymongo.errors import OperationFailure import pymongo +from six import iteritems from mongoengine import * from mongoengine.connection import get_db @@ -68,7 +69,7 @@ class IndexesTest(unittest.TestCase): info = BlogPost.objects._collection.index_information() # _id, '-date', 'tags', ('cat', 'date') self.assertEqual(len(info), 4) - info = [value['key'] for key, value in info.iteritems()] + info = [value['key'] for key, value in iteritems(info)] for expected in expected_specs: self.assertIn(expected['fields'], info) @@ -100,7 +101,7 @@ class IndexesTest(unittest.TestCase): # the indices on -date and tags will both contain # _cls as first element in the key self.assertEqual(len(info), 4) - info = [value['key'] for key, value in info.iteritems()] + info = [value['key'] for key, value in iteritems(info)] for expected in expected_specs: self.assertIn(expected['fields'], info) @@ -115,7 +116,7 @@ class IndexesTest(unittest.TestCase): ExtendedBlogPost.ensure_indexes() info = ExtendedBlogPost.objects._collection.index_information() - info = [value['key'] for key, value in info.iteritems()] + info = [value['key'] for key, value in iteritems(info)] for expected in expected_specs: self.assertIn(expected['fields'], info) @@ -225,7 +226,7 @@ class IndexesTest(unittest.TestCase): # Indexes are lazy so use list() to perform query list(Person.objects) info = Person.objects._collection.index_information() - info = [value['key'] for key, value in info.iteritems()] + info = [value['key'] for key, value in iteritems(info)] self.assertIn([('rank.title', 1)], info) def test_explicit_geo2d_index(self): @@ -245,7 +246,7 @@ class IndexesTest(unittest.TestCase): Place.ensure_indexes() info = Place._get_collection().index_information() - info = [value['key'] for key, value in info.iteritems()] + info = [value['key'] for key, value in iteritems(info)] self.assertIn([('location.point', '2d')], info) def test_explicit_geo2d_index_embedded(self): @@ -268,7 +269,7 @@ class IndexesTest(unittest.TestCase): Place.ensure_indexes() info = Place._get_collection().index_information() - info = [value['key'] for key, value in info.iteritems()] + info = [value['key'] for key, value in iteritems(info)] self.assertIn([('current.location.point', '2d')], info) def test_explicit_geosphere_index(self): @@ -288,7 +289,7 @@ class IndexesTest(unittest.TestCase): Place.ensure_indexes() info = Place._get_collection().index_information() - info = [value['key'] for key, value in info.iteritems()] + info = [value['key'] for key, value in iteritems(info)] self.assertIn([('location.point', '2dsphere')], info) def test_explicit_geohaystack_index(self): @@ -310,7 +311,7 @@ class IndexesTest(unittest.TestCase): Place.ensure_indexes() info = Place._get_collection().index_information() - info = [value['key'] for key, value in info.iteritems()] + info = [value['key'] for key, value in iteritems(info)] self.assertIn([('location.point', 'geoHaystack')], info) def test_create_geohaystack_index(self): @@ -322,7 +323,7 @@ class IndexesTest(unittest.TestCase): Place.create_index({'fields': (')location.point', 'name')}, bucketSize=10) info = Place._get_collection().index_information() - info = [value['key'] for key, value in info.iteritems()] + info = [value['key'] for key, value in iteritems(info)] self.assertIn([('location.point', 'geoHaystack'), ('name', 1)], info) def test_dictionary_indexes(self): @@ -355,7 +356,7 @@ class IndexesTest(unittest.TestCase): info = [(value['key'], value.get('unique', False), value.get('sparse', False)) - for key, value in info.iteritems()] + for key, value in iteritems(info)] self.assertIn(([('addDate', -1)], True, True), info) BlogPost.drop_collection() @@ -576,7 +577,7 @@ class IndexesTest(unittest.TestCase): else: self.assertEqual(BlogPost.objects.hint([('ZZ', 1)]).count(), 10) - self.assertEqual(BlogPost.objects.hint(TAGS_INDEX_NAME ).count(), 10) + self.assertEqual(BlogPost.objects.hint(TAGS_INDEX_NAME).count(), 10) with self.assertRaises(Exception): BlogPost.objects.hint(('tags', 1)).next() @@ -806,7 +807,7 @@ class IndexesTest(unittest.TestCase): self.fail('Unbound local error at index + pk definition') info = BlogPost.objects._collection.index_information() - info = [value['key'] for key, value in info.iteritems()] + info = [value['key'] for key, value in iteritems(info)] index_item = [('_id', 1), ('comments.comment_id', 1)] self.assertIn(index_item, info) @@ -854,7 +855,7 @@ class IndexesTest(unittest.TestCase): } info = MyDoc.objects._collection.index_information() - info = [value['key'] for key, value in info.iteritems()] + info = [value['key'] for key, value in iteritems(info)] self.assertIn([('provider_ids.foo', 1)], info) self.assertIn([('provider_ids.bar', 1)], info) @@ -936,7 +937,6 @@ class IndexesTest(unittest.TestCase): # Drop the temporary database at the end connection.drop_database('tempdatabase') - def test_index_dont_send_cls_option(self): """ Ensure that 'cls' option is not sent through ensureIndex. We shouldn't diff --git a/tests/document/inheritance.py b/tests/document/inheritance.py index 32e3ed29..9cc20c89 100644 --- a/tests/document/inheritance.py +++ b/tests/document/inheritance.py @@ -2,6 +2,8 @@ import unittest import warnings +from six import iteritems + from mongoengine import (BooleanField, Document, EmbeddedDocument, EmbeddedDocumentField, GenericReferenceField, IntField, ReferenceField, StringField, connect) @@ -485,7 +487,7 @@ class InheritanceTest(unittest.TestCase): meta = {'abstract': True} class Human(Mammal): pass - for k, v in defaults.iteritems(): + for k, v in iteritems(defaults): for cls in [Animal, Fish, Guppy]: self.assertEqual(cls._meta[k], v) diff --git a/tests/document/instance.py b/tests/document/instance.py index 5319ace4..9bde23f3 100644 --- a/tests/document/instance.py +++ b/tests/document/instance.py @@ -9,6 +9,7 @@ import weakref from datetime import datetime from bson import DBRef, ObjectId from pymongo.errors import DuplicateKeyError +from six import iteritems from tests import fixtures from tests.fixtures import (PickleEmbedded, PickleTest, PickleSignalsTest, @@ -1482,7 +1483,7 @@ class InstanceTest(MongoDBTestCase): Message.drop_collection() # All objects share the same id, but each in a different collection - user = User(id=1, name='user-name')#.save() + user = User(id=1, name='user-name') # .save() message = Message(id=1, author=user).save() message.author.name = 'tutu' @@ -2000,7 +2001,6 @@ class InstanceTest(MongoDBTestCase): child_record.delete() self.assertEqual(Record.objects(name='parent').get().children, []) - def test_reverse_delete_rule_with_custom_id_field(self): """Ensure that a referenced document with custom primary key is also deleted upon deletion. @@ -3059,7 +3059,7 @@ class InstanceTest(MongoDBTestCase): def expand(self): self.flattened_parameter = {} - for parameter_name, parameter in self.parameters.iteritems(): + for parameter_name, parameter in iteritems(self.parameters): parameter.expand() class NodesSystem(Document): @@ -3067,7 +3067,7 @@ class InstanceTest(MongoDBTestCase): nodes = MapField(ReferenceField(Node, dbref=False)) def save(self, *args, **kwargs): - for node_name, node in self.nodes.iteritems(): + for node_name, node in iteritems(self.nodes): node.expand() node.save(*args, **kwargs) super(NodesSystem, self).save(*args, **kwargs) @@ -3381,7 +3381,6 @@ class InstanceTest(MongoDBTestCase): class User(Document): company = ReferenceField(Company) - # Ensure index creation exception aren't swallowed (#1688) with self.assertRaises(DuplicateKeyError): User.objects().select_related() diff --git a/tests/queryset/queryset.py b/tests/queryset/queryset.py index d3a2418a..ef67ac54 100644 --- a/tests/queryset/queryset.py +++ b/tests/queryset/queryset.py @@ -12,6 +12,7 @@ from pymongo.errors import ConfigurationError from pymongo.read_preferences import ReadPreference from pymongo.results import UpdateResult import six +from six import iteritems from mongoengine import * from mongoengine.connection import get_connection, get_db @@ -4026,7 +4027,7 @@ class QuerySetTest(unittest.TestCase): info = [(value['key'], value.get('unique', False), value.get('sparse', False)) - for key, value in info.iteritems()] + for key, value in iteritems(info)] self.assertIn(([('_cls', 1), ('message', 1)], False, False), info) def test_where(self): diff --git a/tests/test_dereference.py b/tests/test_dereference.py index 5cf089f4..cf1194f4 100644 --- a/tests/test_dereference.py +++ b/tests/test_dereference.py @@ -2,6 +2,7 @@ import unittest from bson import DBRef, ObjectId +from six import iteritems from mongoengine import * from mongoengine.connection import get_db @@ -632,7 +633,7 @@ class FieldTest(unittest.TestCase): [m for m in group_obj.members] self.assertEqual(q, 2) - for k, m in group_obj.members.iteritems(): + for k, m in iteritems(group_obj.members): self.assertIsInstance(m, User) # Document select_related @@ -645,7 +646,7 @@ class FieldTest(unittest.TestCase): [m for m in group_obj.members] self.assertEqual(q, 2) - for k, m in group_obj.members.iteritems(): + for k, m in iteritems(group_obj.members): self.assertIsInstance(m, User) # Queryset select_related @@ -659,7 +660,7 @@ class FieldTest(unittest.TestCase): [m for m in group_obj.members] self.assertEqual(q, 2) - for k, m in group_obj.members.iteritems(): + for k, m in iteritems(group_obj.members): self.assertIsInstance(m, User) User.drop_collection() @@ -714,7 +715,7 @@ class FieldTest(unittest.TestCase): [m for m in group_obj.members] self.assertEqual(q, 4) - for k, m in group_obj.members.iteritems(): + for k, m in iteritems(group_obj.members): self.assertIn('User', m.__class__.__name__) # Document select_related @@ -730,7 +731,7 @@ class FieldTest(unittest.TestCase): [m for m in group_obj.members] self.assertEqual(q, 4) - for k, m in group_obj.members.iteritems(): + for k, m in iteritems(group_obj.members): self.assertIn('User', m.__class__.__name__) # Queryset select_related @@ -747,7 +748,7 @@ class FieldTest(unittest.TestCase): [m for m in group_obj.members] self.assertEqual(q, 4) - for k, m in group_obj.members.iteritems(): + for k, m in iteritems(group_obj.members): self.assertIn('User', m.__class__.__name__) Group.objects.delete() @@ -805,7 +806,7 @@ class FieldTest(unittest.TestCase): [m for m in group_obj.members] self.assertEqual(q, 2) - for k, m in group_obj.members.iteritems(): + for k, m in iteritems(group_obj.members): self.assertIsInstance(m, UserA) # Document select_related @@ -821,7 +822,7 @@ class FieldTest(unittest.TestCase): [m for m in group_obj.members] self.assertEqual(q, 2) - for k, m in group_obj.members.iteritems(): + for k, m in iteritems(group_obj.members): self.assertIsInstance(m, UserA) # Queryset select_related @@ -838,7 +839,7 @@ class FieldTest(unittest.TestCase): [m for m in group_obj.members] self.assertEqual(q, 2) - for k, m in group_obj.members.iteritems(): + for k, m in iteritems(group_obj.members): self.assertIsInstance(m, UserA) UserA.drop_collection() @@ -893,7 +894,7 @@ class FieldTest(unittest.TestCase): [m for m in group_obj.members] self.assertEqual(q, 4) - for k, m in group_obj.members.iteritems(): + for k, m in iteritems(group_obj.members): self.assertIn('User', m.__class__.__name__) # Document select_related @@ -909,7 +910,7 @@ class FieldTest(unittest.TestCase): [m for m in group_obj.members] self.assertEqual(q, 4) - for k, m in group_obj.members.iteritems(): + for k, m in iteritems(group_obj.members): self.assertIn('User', m.__class__.__name__) # Queryset select_related @@ -926,7 +927,7 @@ class FieldTest(unittest.TestCase): [m for m in group_obj.members] self.assertEqual(q, 4) - for k, m in group_obj.members.iteritems(): + for k, m in iteritems(group_obj.members): self.assertIn('User', m.__class__.__name__) Group.objects.delete() @@ -1064,7 +1065,6 @@ class FieldTest(unittest.TestCase): self.assertEqual(msg.author, user) self.assertEqual(msg.author.name, 'new-name') - def test_list_lookup_not_checked_in_map(self): """Ensure we dereference list data correctly """ @@ -1286,5 +1286,6 @@ class FieldTest(unittest.TestCase): self.assertEqual(q, 2) + if __name__ == '__main__': unittest.main() From f0a344525001ab2a74ad97fe16a236f3225cb7ef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Mon, 18 Feb 2019 22:15:58 +0100 Subject: [PATCH 04/71] minor fix for import order --- tests/document/instance.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/document/instance.py b/tests/document/instance.py index 9bde23f3..03fd2f2c 100644 --- a/tests/document/instance.py +++ b/tests/document/instance.py @@ -5,8 +5,8 @@ import pickle import unittest import uuid import weakref - from datetime import datetime + from bson import DBRef, ObjectId from pymongo.errors import DuplicateKeyError from six import iteritems From 5bbe782812df249ef60f7ad9cc44482ec837bd71 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Sat, 23 Feb 2019 22:37:32 +0100 Subject: [PATCH 05/71] fix deprecated call to pymongo save() in tests --- tests/document/instance.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/document/instance.py b/tests/document/instance.py index f9331bd0..cde18c9f 100644 --- a/tests/document/instance.py +++ b/tests/document/instance.py @@ -2758,7 +2758,7 @@ class InstanceTest(MongoDBTestCase): User.drop_collection() - User._get_collection().save({ + User._get_collection().insert_one({ 'name': 'John', 'foo': 'Bar', 'data': [1, 2, 3] @@ -2774,7 +2774,7 @@ class InstanceTest(MongoDBTestCase): User.drop_collection() - User._get_collection().save({ + User._get_collection().insert_one({ 'name': 'John', 'foo': 'Bar', 'data': [1, 2, 3] @@ -2797,7 +2797,7 @@ class InstanceTest(MongoDBTestCase): User.drop_collection() - User._get_collection().save({ + User._get_collection().insert_one({ 'name': 'John', 'thing': { 'name': 'My thing', @@ -2820,7 +2820,7 @@ class InstanceTest(MongoDBTestCase): User.drop_collection() - User._get_collection().save({ + User._get_collection().insert_one({ 'name': 'John', 'thing': { 'name': 'My thing', @@ -2843,7 +2843,7 @@ class InstanceTest(MongoDBTestCase): User.drop_collection() - User._get_collection().save({ + User._get_collection().insert_one({ 'name': 'John', 'thing': { 'name': 'My thing', From 28606e9985cb4abcf1fe32b7201f4c28a14325b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Thu, 21 Feb 2019 21:53:59 +0100 Subject: [PATCH 06/71] refactor fields tests (float, int, lazyref, long, url) #1983 --- tests/fields/fields.py | 693 +--------------------- tests/fields/test_float_field.py | 58 ++ tests/fields/test_int_field.py | 42 ++ tests/fields/test_lazy_reference_field.py | 524 ++++++++++++++++ tests/fields/test_long_field.py | 56 ++ tests/fields/test_url_field.py | 59 ++ 6 files changed, 740 insertions(+), 692 deletions(-) create mode 100644 tests/fields/test_float_field.py create mode 100644 tests/fields/test_int_field.py create mode 100644 tests/fields/test_lazy_reference_field.py create mode 100644 tests/fields/test_long_field.py create mode 100644 tests/fields/test_url_field.py diff --git a/tests/fields/fields.py b/tests/fields/fields.py index b43c92af..194d07d7 100644 --- a/tests/fields/fields.py +++ b/tests/fields/fields.py @@ -2,9 +2,6 @@ import datetime import unittest import uuid -import math -import itertools -import re import sys from nose.plugins.skip import SkipTest @@ -13,15 +10,10 @@ import six from decimal import Decimal from bson import DBRef, ObjectId, SON -try: - from bson.int64 import Int64 -except ImportError: - Int64 = long from mongoengine import * -from mongoengine.connection import get_db from mongoengine.base import (BaseDict, BaseField, EmbeddedDocumentList, - _document_registry, LazyReference) + _document_registry) from tests.utils import MongoDBTestCase @@ -284,30 +276,6 @@ class FieldTest(MongoDBTestCase): # attempted. self.assertRaises(ValidationError, ret.validate) - def test_int_and_float_ne_operator(self): - class TestDocument(Document): - int_fld = IntField() - float_fld = FloatField() - - TestDocument.drop_collection() - - TestDocument(int_fld=None, float_fld=None).save() - TestDocument(int_fld=1, float_fld=1).save() - - self.assertEqual(1, TestDocument.objects(int_fld__ne=None).count()) - self.assertEqual(1, TestDocument.objects(float_fld__ne=None).count()) - - def test_long_ne_operator(self): - class TestDocument(Document): - long_fld = LongField() - - TestDocument.drop_collection() - - TestDocument(long_fld=None).save() - TestDocument(long_fld=1).save() - - self.assertEqual(1, TestDocument.objects(long_fld__ne=None).count()) - def test_object_id_validation(self): """Ensure that invalid values cannot be assigned to an ObjectIdField. @@ -351,135 +319,6 @@ class FieldTest(MongoDBTestCase): person.name = 'Shorter name' person.validate() - def test_url_validation(self): - """Ensure that URLFields validate urls properly.""" - class Link(Document): - url = URLField() - - link = Link() - link.url = 'google' - self.assertRaises(ValidationError, link.validate) - - link.url = 'http://www.google.com:8080' - link.validate() - - def test_unicode_url_validation(self): - """Ensure unicode URLs are validated properly.""" - class Link(Document): - url = URLField() - - link = Link() - link.url = u'http://привет.com' - - # TODO fix URL validation - this *IS* a valid URL - # For now we just want to make sure that the error message is correct - try: - link.validate() - self.assertTrue(False) - except ValidationError as e: - self.assertEqual( - unicode(e), - u"ValidationError (Link:None) (Invalid URL: http://\u043f\u0440\u0438\u0432\u0435\u0442.com: ['url'])" - ) - - def test_url_scheme_validation(self): - """Ensure that URLFields validate urls with specific schemes properly. - """ - class Link(Document): - url = URLField() - - class SchemeLink(Document): - url = URLField(schemes=['ws', 'irc']) - - link = Link() - link.url = 'ws://google.com' - self.assertRaises(ValidationError, link.validate) - - scheme_link = SchemeLink() - scheme_link.url = 'ws://google.com' - scheme_link.validate() - - def test_url_allowed_domains(self): - """Allow underscore in domain names. - """ - class Link(Document): - url = URLField() - - link = Link() - link.url = 'https://san_leandro-ca.geebo.com' - link.validate() - - def test_int_validation(self): - """Ensure that invalid values cannot be assigned to int fields. - """ - class Person(Document): - age = IntField(min_value=0, max_value=110) - - person = Person() - person.age = 50 - person.validate() - - person.age = -1 - self.assertRaises(ValidationError, person.validate) - person.age = 120 - self.assertRaises(ValidationError, person.validate) - person.age = 'ten' - self.assertRaises(ValidationError, person.validate) - - def test_long_validation(self): - """Ensure that invalid values cannot be assigned to long fields. - """ - class TestDocument(Document): - value = LongField(min_value=0, max_value=110) - - doc = TestDocument() - doc.value = 50 - doc.validate() - - doc.value = -1 - self.assertRaises(ValidationError, doc.validate) - doc.age = 120 - self.assertRaises(ValidationError, doc.validate) - doc.age = 'ten' - self.assertRaises(ValidationError, doc.validate) - - def test_float_validation(self): - """Ensure that invalid values cannot be assigned to float fields. - """ - class Person(Document): - height = FloatField(min_value=0.1, max_value=3.5) - - class BigPerson(Document): - height = FloatField() - - person = Person() - person.height = 1.89 - person.validate() - - person.height = '2.0' - self.assertRaises(ValidationError, person.validate) - - person.height = 0.01 - self.assertRaises(ValidationError, person.validate) - - person.height = 4.0 - self.assertRaises(ValidationError, person.validate) - - person_2 = Person(height='something invalid') - self.assertRaises(ValidationError, person_2.validate) - - big_person = BigPerson() - - for value, value_type in enumerate(six.integer_types): - big_person.height = value_type(value) - big_person.validate() - - big_person.height = 2 ** 500 - big_person.validate() - - big_person.height = 2 ** 100000 # Too big for a float value - self.assertRaises(ValidationError, big_person.validate) - def test_decimal_validation(self): """Ensure that invalid values cannot be assigned to decimal fields. """ @@ -3534,19 +3373,6 @@ class FieldTest(MongoDBTestCase): with self.assertRaises(FieldDoesNotExist): Doc(bar='test') - def test_long_field_is_considered_as_int64(self): - """ - Tests that long fields are stored as long in mongo, even if long - value is small enough to be an int. - """ - class TestLongFieldConsideredAsInt64(Document): - some_long = LongField() - - doc = TestLongFieldConsideredAsInt64(some_long=42).save() - db = get_db() - self.assertIsInstance(db.test_long_field_considered_as_int64.find()[0]['some_long'], Int64) - self.assertIsInstance(doc.some_long, six.integer_types) - class EmbeddedDocumentListFieldTestCase(MongoDBTestCase): @@ -4493,522 +4319,5 @@ class CachedReferenceFieldTest(MongoDBTestCase): self.assertIsInstance(ocorrence.animal, Animal) -class LazyReferenceFieldTest(MongoDBTestCase): - def test_lazy_reference_config(self): - # Make sure ReferenceField only accepts a document class or a string - # with a document class name. - self.assertRaises(ValidationError, LazyReferenceField, EmbeddedDocument) - - def test_lazy_reference_simple(self): - class Animal(Document): - name = StringField() - tag = StringField() - - class Ocurrence(Document): - person = StringField() - animal = LazyReferenceField(Animal) - - Animal.drop_collection() - Ocurrence.drop_collection() - - animal = Animal(name="Leopard", tag="heavy").save() - Ocurrence(person="test", animal=animal).save() - p = Ocurrence.objects.get() - self.assertIsInstance(p.animal, LazyReference) - fetched_animal = p.animal.fetch() - self.assertEqual(fetched_animal, animal) - # `fetch` keep cache on referenced document by default... - animal.tag = "not so heavy" - animal.save() - double_fetch = p.animal.fetch() - self.assertIs(fetched_animal, double_fetch) - self.assertEqual(double_fetch.tag, "heavy") - # ...unless specified otherwise - fetch_force = p.animal.fetch(force=True) - self.assertIsNot(fetch_force, fetched_animal) - self.assertEqual(fetch_force.tag, "not so heavy") - - def test_lazy_reference_fetch_invalid_ref(self): - class Animal(Document): - name = StringField() - tag = StringField() - - class Ocurrence(Document): - person = StringField() - animal = LazyReferenceField(Animal) - - Animal.drop_collection() - Ocurrence.drop_collection() - - animal = Animal(name="Leopard", tag="heavy").save() - Ocurrence(person="test", animal=animal).save() - animal.delete() - p = Ocurrence.objects.get() - self.assertIsInstance(p.animal, LazyReference) - with self.assertRaises(DoesNotExist): - p.animal.fetch() - - def test_lazy_reference_set(self): - class Animal(Document): - meta = {'allow_inheritance': True} - - name = StringField() - tag = StringField() - - class Ocurrence(Document): - person = StringField() - animal = LazyReferenceField(Animal) - - Animal.drop_collection() - Ocurrence.drop_collection() - - class SubAnimal(Animal): - nick = StringField() - - animal = Animal(name="Leopard", tag="heavy").save() - sub_animal = SubAnimal(nick='doggo', name='dog').save() - for ref in ( - animal, - animal.pk, - DBRef(animal._get_collection_name(), animal.pk), - LazyReference(Animal, animal.pk), - - sub_animal, - sub_animal.pk, - DBRef(sub_animal._get_collection_name(), sub_animal.pk), - LazyReference(SubAnimal, sub_animal.pk), - ): - p = Ocurrence(person="test", animal=ref).save() - p.reload() - self.assertIsInstance(p.animal, LazyReference) - p.animal.fetch() - - def test_lazy_reference_bad_set(self): - class Animal(Document): - name = StringField() - tag = StringField() - - class Ocurrence(Document): - person = StringField() - animal = LazyReferenceField(Animal) - - Animal.drop_collection() - Ocurrence.drop_collection() - - class BadDoc(Document): - pass - - animal = Animal(name="Leopard", tag="heavy").save() - baddoc = BadDoc().save() - for bad in ( - 42, - 'foo', - baddoc, - DBRef(baddoc._get_collection_name(), animal.pk), - LazyReference(BadDoc, animal.pk) - ): - with self.assertRaises(ValidationError): - p = Ocurrence(person="test", animal=bad).save() - - def test_lazy_reference_query_conversion(self): - """Ensure that LazyReferenceFields can be queried using objects and values - of the type of the primary key of the referenced object. - """ - class Member(Document): - user_num = IntField(primary_key=True) - - class BlogPost(Document): - title = StringField() - author = LazyReferenceField(Member, dbref=False) - - Member.drop_collection() - BlogPost.drop_collection() - - m1 = Member(user_num=1) - m1.save() - m2 = Member(user_num=2) - m2.save() - - post1 = BlogPost(title='post 1', author=m1) - post1.save() - - post2 = BlogPost(title='post 2', author=m2) - post2.save() - - post = BlogPost.objects(author=m1).first() - self.assertEqual(post.id, post1.id) - - post = BlogPost.objects(author=m2).first() - self.assertEqual(post.id, post2.id) - - # Same thing by passing a LazyReference instance - post = BlogPost.objects(author=LazyReference(Member, m2.pk)).first() - self.assertEqual(post.id, post2.id) - - def test_lazy_reference_query_conversion_dbref(self): - """Ensure that LazyReferenceFields can be queried using objects and values - of the type of the primary key of the referenced object. - """ - class Member(Document): - user_num = IntField(primary_key=True) - - class BlogPost(Document): - title = StringField() - author = LazyReferenceField(Member, dbref=True) - - Member.drop_collection() - BlogPost.drop_collection() - - m1 = Member(user_num=1) - m1.save() - m2 = Member(user_num=2) - m2.save() - - post1 = BlogPost(title='post 1', author=m1) - post1.save() - - post2 = BlogPost(title='post 2', author=m2) - post2.save() - - post = BlogPost.objects(author=m1).first() - self.assertEqual(post.id, post1.id) - - post = BlogPost.objects(author=m2).first() - self.assertEqual(post.id, post2.id) - - # Same thing by passing a LazyReference instance - post = BlogPost.objects(author=LazyReference(Member, m2.pk)).first() - self.assertEqual(post.id, post2.id) - - def test_lazy_reference_passthrough(self): - class Animal(Document): - name = StringField() - tag = StringField() - - class Ocurrence(Document): - animal = LazyReferenceField(Animal, passthrough=False) - animal_passthrough = LazyReferenceField(Animal, passthrough=True) - - Animal.drop_collection() - Ocurrence.drop_collection() - - animal = Animal(name="Leopard", tag="heavy").save() - Ocurrence(animal=animal, animal_passthrough=animal).save() - p = Ocurrence.objects.get() - self.assertIsInstance(p.animal, LazyReference) - with self.assertRaises(KeyError): - p.animal['name'] - with self.assertRaises(AttributeError): - p.animal.name - self.assertEqual(p.animal.pk, animal.pk) - - self.assertEqual(p.animal_passthrough.name, "Leopard") - self.assertEqual(p.animal_passthrough['name'], "Leopard") - - # Should not be able to access referenced document's methods - with self.assertRaises(AttributeError): - p.animal.save - with self.assertRaises(KeyError): - p.animal['save'] - - def test_lazy_reference_not_set(self): - class Animal(Document): - name = StringField() - tag = StringField() - - class Ocurrence(Document): - person = StringField() - animal = LazyReferenceField(Animal) - - Animal.drop_collection() - Ocurrence.drop_collection() - - Ocurrence(person='foo').save() - p = Ocurrence.objects.get() - self.assertIs(p.animal, None) - - def test_lazy_reference_equality(self): - class Animal(Document): - name = StringField() - tag = StringField() - - Animal.drop_collection() - - animal = Animal(name="Leopard", tag="heavy").save() - animalref = LazyReference(Animal, animal.pk) - self.assertEqual(animal, animalref) - self.assertEqual(animalref, animal) - - other_animalref = LazyReference(Animal, ObjectId("54495ad94c934721ede76f90")) - self.assertNotEqual(animal, other_animalref) - self.assertNotEqual(other_animalref, animal) - - def test_lazy_reference_embedded(self): - class Animal(Document): - name = StringField() - tag = StringField() - - class EmbeddedOcurrence(EmbeddedDocument): - in_list = ListField(LazyReferenceField(Animal)) - direct = LazyReferenceField(Animal) - - class Ocurrence(Document): - in_list = ListField(LazyReferenceField(Animal)) - in_embedded = EmbeddedDocumentField(EmbeddedOcurrence) - direct = LazyReferenceField(Animal) - - Animal.drop_collection() - Ocurrence.drop_collection() - - animal1 = Animal('doggo').save() - animal2 = Animal('cheeta').save() - - def check_fields_type(occ): - self.assertIsInstance(occ.direct, LazyReference) - for elem in occ.in_list: - self.assertIsInstance(elem, LazyReference) - self.assertIsInstance(occ.in_embedded.direct, LazyReference) - for elem in occ.in_embedded.in_list: - self.assertIsInstance(elem, LazyReference) - - occ = Ocurrence( - in_list=[animal1, animal2], - in_embedded={'in_list': [animal1, animal2], 'direct': animal1}, - direct=animal1 - ).save() - check_fields_type(occ) - occ.reload() - check_fields_type(occ) - occ.direct = animal1.id - occ.in_list = [animal1.id, animal2.id] - occ.in_embedded.direct = animal1.id - occ.in_embedded.in_list = [animal1.id, animal2.id] - check_fields_type(occ) - - -class GenericLazyReferenceFieldTest(MongoDBTestCase): - def test_generic_lazy_reference_simple(self): - class Animal(Document): - name = StringField() - tag = StringField() - - class Ocurrence(Document): - person = StringField() - animal = GenericLazyReferenceField() - - Animal.drop_collection() - Ocurrence.drop_collection() - - animal = Animal(name="Leopard", tag="heavy").save() - Ocurrence(person="test", animal=animal).save() - p = Ocurrence.objects.get() - self.assertIsInstance(p.animal, LazyReference) - fetched_animal = p.animal.fetch() - self.assertEqual(fetched_animal, animal) - # `fetch` keep cache on referenced document by default... - animal.tag = "not so heavy" - animal.save() - double_fetch = p.animal.fetch() - self.assertIs(fetched_animal, double_fetch) - self.assertEqual(double_fetch.tag, "heavy") - # ...unless specified otherwise - fetch_force = p.animal.fetch(force=True) - self.assertIsNot(fetch_force, fetched_animal) - self.assertEqual(fetch_force.tag, "not so heavy") - - def test_generic_lazy_reference_choices(self): - class Animal(Document): - name = StringField() - - class Vegetal(Document): - name = StringField() - - class Mineral(Document): - name = StringField() - - class Ocurrence(Document): - living_thing = GenericLazyReferenceField(choices=[Animal, Vegetal]) - thing = GenericLazyReferenceField() - - Animal.drop_collection() - Vegetal.drop_collection() - Mineral.drop_collection() - Ocurrence.drop_collection() - - animal = Animal(name="Leopard").save() - vegetal = Vegetal(name="Oak").save() - mineral = Mineral(name="Granite").save() - - occ_animal = Ocurrence(living_thing=animal, thing=animal).save() - occ_vegetal = Ocurrence(living_thing=vegetal, thing=vegetal).save() - with self.assertRaises(ValidationError): - Ocurrence(living_thing=mineral).save() - - occ = Ocurrence.objects.get(living_thing=animal) - self.assertEqual(occ, occ_animal) - self.assertIsInstance(occ.thing, LazyReference) - self.assertIsInstance(occ.living_thing, LazyReference) - - occ.thing = vegetal - occ.living_thing = vegetal - occ.save() - - occ.thing = mineral - occ.living_thing = mineral - with self.assertRaises(ValidationError): - occ.save() - - def test_generic_lazy_reference_set(self): - class Animal(Document): - meta = {'allow_inheritance': True} - - name = StringField() - tag = StringField() - - class Ocurrence(Document): - person = StringField() - animal = GenericLazyReferenceField() - - Animal.drop_collection() - Ocurrence.drop_collection() - - class SubAnimal(Animal): - nick = StringField() - - animal = Animal(name="Leopard", tag="heavy").save() - sub_animal = SubAnimal(nick='doggo', name='dog').save() - for ref in ( - animal, - LazyReference(Animal, animal.pk), - {'_cls': 'Animal', '_ref': DBRef(animal._get_collection_name(), animal.pk)}, - - sub_animal, - LazyReference(SubAnimal, sub_animal.pk), - {'_cls': 'SubAnimal', '_ref': DBRef(sub_animal._get_collection_name(), sub_animal.pk)}, - ): - p = Ocurrence(person="test", animal=ref).save() - p.reload() - self.assertIsInstance(p.animal, (LazyReference, Document)) - p.animal.fetch() - - def test_generic_lazy_reference_bad_set(self): - class Animal(Document): - name = StringField() - tag = StringField() - - class Ocurrence(Document): - person = StringField() - animal = GenericLazyReferenceField(choices=['Animal']) - - Animal.drop_collection() - Ocurrence.drop_collection() - - class BadDoc(Document): - pass - - animal = Animal(name="Leopard", tag="heavy").save() - baddoc = BadDoc().save() - for bad in ( - 42, - 'foo', - baddoc, - LazyReference(BadDoc, animal.pk) - ): - with self.assertRaises(ValidationError): - p = Ocurrence(person="test", animal=bad).save() - - def test_generic_lazy_reference_query_conversion(self): - class Member(Document): - user_num = IntField(primary_key=True) - - class BlogPost(Document): - title = StringField() - author = GenericLazyReferenceField() - - Member.drop_collection() - BlogPost.drop_collection() - - m1 = Member(user_num=1) - m1.save() - m2 = Member(user_num=2) - m2.save() - - post1 = BlogPost(title='post 1', author=m1) - post1.save() - - post2 = BlogPost(title='post 2', author=m2) - post2.save() - - post = BlogPost.objects(author=m1).first() - self.assertEqual(post.id, post1.id) - - post = BlogPost.objects(author=m2).first() - self.assertEqual(post.id, post2.id) - - # Same thing by passing a LazyReference instance - post = BlogPost.objects(author=LazyReference(Member, m2.pk)).first() - self.assertEqual(post.id, post2.id) - - def test_generic_lazy_reference_not_set(self): - class Animal(Document): - name = StringField() - tag = StringField() - - class Ocurrence(Document): - person = StringField() - animal = GenericLazyReferenceField() - - Animal.drop_collection() - Ocurrence.drop_collection() - - Ocurrence(person='foo').save() - p = Ocurrence.objects.get() - self.assertIs(p.animal, None) - - def test_generic_lazy_reference_embedded(self): - class Animal(Document): - name = StringField() - tag = StringField() - - class EmbeddedOcurrence(EmbeddedDocument): - in_list = ListField(GenericLazyReferenceField()) - direct = GenericLazyReferenceField() - - class Ocurrence(Document): - in_list = ListField(GenericLazyReferenceField()) - in_embedded = EmbeddedDocumentField(EmbeddedOcurrence) - direct = GenericLazyReferenceField() - - Animal.drop_collection() - Ocurrence.drop_collection() - - animal1 = Animal('doggo').save() - animal2 = Animal('cheeta').save() - - def check_fields_type(occ): - self.assertIsInstance(occ.direct, LazyReference) - for elem in occ.in_list: - self.assertIsInstance(elem, LazyReference) - self.assertIsInstance(occ.in_embedded.direct, LazyReference) - for elem in occ.in_embedded.in_list: - self.assertIsInstance(elem, LazyReference) - - occ = Ocurrence( - in_list=[animal1, animal2], - in_embedded={'in_list': [animal1, animal2], 'direct': animal1}, - direct=animal1 - ).save() - check_fields_type(occ) - occ.reload() - check_fields_type(occ) - animal1_ref = {'_cls': 'Animal', '_ref': DBRef(animal1._get_collection_name(), animal1.pk)} - animal2_ref = {'_cls': 'Animal', '_ref': DBRef(animal2._get_collection_name(), animal2.pk)} - occ.direct = animal1_ref - occ.in_list = [animal1_ref, animal2_ref] - occ.in_embedded.direct = animal1_ref - occ.in_embedded.in_list = [animal1_ref, animal2_ref] - check_fields_type(occ) - - if __name__ == '__main__': unittest.main() diff --git a/tests/fields/test_float_field.py b/tests/fields/test_float_field.py new file mode 100644 index 00000000..fa92cf20 --- /dev/null +++ b/tests/fields/test_float_field.py @@ -0,0 +1,58 @@ +# -*- coding: utf-8 -*- +import six + +from mongoengine import * + +from tests.utils import MongoDBTestCase + + +class TestFloatField(MongoDBTestCase): + + def test_float_ne_operator(self): + class TestDocument(Document): + float_fld = FloatField() + + TestDocument.drop_collection() + + TestDocument(float_fld=None).save() + TestDocument(float_fld=1).save() + + self.assertEqual(1, TestDocument.objects(float_fld__ne=None).count()) + self.assertEqual(1, TestDocument.objects(float_fld__ne=1).count()) + + def test_validation(self): + """Ensure that invalid values cannot be assigned to float fields. + """ + class Person(Document): + height = FloatField(min_value=0.1, max_value=3.5) + + class BigPerson(Document): + height = FloatField() + + person = Person() + person.height = 1.89 + person.validate() + + person.height = '2.0' + self.assertRaises(ValidationError, person.validate) + + person.height = 0.01 + self.assertRaises(ValidationError, person.validate) + + person.height = 4.0 + self.assertRaises(ValidationError, person.validate) + + person_2 = Person(height='something invalid') + self.assertRaises(ValidationError, person_2.validate) + + big_person = BigPerson() + + for value, value_type in enumerate(six.integer_types): + big_person.height = value_type(value) + big_person.validate() + + big_person.height = 2 ** 500 + big_person.validate() + + big_person.height = 2 ** 100000 # Too big for a float value + self.assertRaises(ValidationError, big_person.validate) diff --git a/tests/fields/test_int_field.py b/tests/fields/test_int_field.py new file mode 100644 index 00000000..1b1f7ad9 --- /dev/null +++ b/tests/fields/test_int_field.py @@ -0,0 +1,42 @@ +# -*- coding: utf-8 -*- +from mongoengine import * + +from tests.utils import MongoDBTestCase + + +class TestIntField(MongoDBTestCase): + + def test_int_validation(self): + """Ensure that invalid values cannot be assigned to int fields. + """ + class Person(Document): + age = IntField(min_value=0, max_value=110) + + person = Person() + person.age = 0 + person.validate() + + person.age = 50 + person.validate() + + person.age = 110 + person.validate() + + person.age = -1 + self.assertRaises(ValidationError, person.validate) + person.age = 120 + self.assertRaises(ValidationError, person.validate) + person.age = 'ten' + self.assertRaises(ValidationError, person.validate) + + def test_ne_operator(self): + class TestDocument(Document): + int_fld = IntField() + + TestDocument.drop_collection() + + TestDocument(int_fld=None).save() + TestDocument(int_fld=1).save() + + self.assertEqual(1, TestDocument.objects(int_fld__ne=None).count()) + self.assertEqual(1, TestDocument.objects(int_fld__ne=1).count()) diff --git a/tests/fields/test_lazy_reference_field.py b/tests/fields/test_lazy_reference_field.py new file mode 100644 index 00000000..a72e8cbe --- /dev/null +++ b/tests/fields/test_lazy_reference_field.py @@ -0,0 +1,524 @@ +# -*- coding: utf-8 -*- +from bson import DBRef, ObjectId + +from mongoengine import * +from mongoengine.base import LazyReference + +from tests.utils import MongoDBTestCase + + +class TestLazyReferenceField(MongoDBTestCase): + def test_lazy_reference_config(self): + # Make sure ReferenceField only accepts a document class or a string + # with a document class name. + self.assertRaises(ValidationError, LazyReferenceField, EmbeddedDocument) + + def test_lazy_reference_simple(self): + class Animal(Document): + name = StringField() + tag = StringField() + + class Ocurrence(Document): + person = StringField() + animal = LazyReferenceField(Animal) + + Animal.drop_collection() + Ocurrence.drop_collection() + + animal = Animal(name="Leopard", tag="heavy").save() + Ocurrence(person="test", animal=animal).save() + p = Ocurrence.objects.get() + self.assertIsInstance(p.animal, LazyReference) + fetched_animal = p.animal.fetch() + self.assertEqual(fetched_animal, animal) + # `fetch` keep cache on referenced document by default... + animal.tag = "not so heavy" + animal.save() + double_fetch = p.animal.fetch() + self.assertIs(fetched_animal, double_fetch) + self.assertEqual(double_fetch.tag, "heavy") + # ...unless specified otherwise + fetch_force = p.animal.fetch(force=True) + self.assertIsNot(fetch_force, fetched_animal) + self.assertEqual(fetch_force.tag, "not so heavy") + + def test_lazy_reference_fetch_invalid_ref(self): + class Animal(Document): + name = StringField() + tag = StringField() + + class Ocurrence(Document): + person = StringField() + animal = LazyReferenceField(Animal) + + Animal.drop_collection() + Ocurrence.drop_collection() + + animal = Animal(name="Leopard", tag="heavy").save() + Ocurrence(person="test", animal=animal).save() + animal.delete() + p = Ocurrence.objects.get() + self.assertIsInstance(p.animal, LazyReference) + with self.assertRaises(DoesNotExist): + p.animal.fetch() + + def test_lazy_reference_set(self): + class Animal(Document): + meta = {'allow_inheritance': True} + + name = StringField() + tag = StringField() + + class Ocurrence(Document): + person = StringField() + animal = LazyReferenceField(Animal) + + Animal.drop_collection() + Ocurrence.drop_collection() + + class SubAnimal(Animal): + nick = StringField() + + animal = Animal(name="Leopard", tag="heavy").save() + sub_animal = SubAnimal(nick='doggo', name='dog').save() + for ref in ( + animal, + animal.pk, + DBRef(animal._get_collection_name(), animal.pk), + LazyReference(Animal, animal.pk), + + sub_animal, + sub_animal.pk, + DBRef(sub_animal._get_collection_name(), sub_animal.pk), + LazyReference(SubAnimal, sub_animal.pk), + ): + p = Ocurrence(person="test", animal=ref).save() + p.reload() + self.assertIsInstance(p.animal, LazyReference) + p.animal.fetch() + + def test_lazy_reference_bad_set(self): + class Animal(Document): + name = StringField() + tag = StringField() + + class Ocurrence(Document): + person = StringField() + animal = LazyReferenceField(Animal) + + Animal.drop_collection() + Ocurrence.drop_collection() + + class BadDoc(Document): + pass + + animal = Animal(name="Leopard", tag="heavy").save() + baddoc = BadDoc().save() + for bad in ( + 42, + 'foo', + baddoc, + DBRef(baddoc._get_collection_name(), animal.pk), + LazyReference(BadDoc, animal.pk) + ): + with self.assertRaises(ValidationError): + p = Ocurrence(person="test", animal=bad).save() + + def test_lazy_reference_query_conversion(self): + """Ensure that LazyReferenceFields can be queried using objects and values + of the type of the primary key of the referenced object. + """ + class Member(Document): + user_num = IntField(primary_key=True) + + class BlogPost(Document): + title = StringField() + author = LazyReferenceField(Member, dbref=False) + + Member.drop_collection() + BlogPost.drop_collection() + + m1 = Member(user_num=1) + m1.save() + m2 = Member(user_num=2) + m2.save() + + post1 = BlogPost(title='post 1', author=m1) + post1.save() + + post2 = BlogPost(title='post 2', author=m2) + post2.save() + + post = BlogPost.objects(author=m1).first() + self.assertEqual(post.id, post1.id) + + post = BlogPost.objects(author=m2).first() + self.assertEqual(post.id, post2.id) + + # Same thing by passing a LazyReference instance + post = BlogPost.objects(author=LazyReference(Member, m2.pk)).first() + self.assertEqual(post.id, post2.id) + + def test_lazy_reference_query_conversion_dbref(self): + """Ensure that LazyReferenceFields can be queried using objects and values + of the type of the primary key of the referenced object. + """ + class Member(Document): + user_num = IntField(primary_key=True) + + class BlogPost(Document): + title = StringField() + author = LazyReferenceField(Member, dbref=True) + + Member.drop_collection() + BlogPost.drop_collection() + + m1 = Member(user_num=1) + m1.save() + m2 = Member(user_num=2) + m2.save() + + post1 = BlogPost(title='post 1', author=m1) + post1.save() + + post2 = BlogPost(title='post 2', author=m2) + post2.save() + + post = BlogPost.objects(author=m1).first() + self.assertEqual(post.id, post1.id) + + post = BlogPost.objects(author=m2).first() + self.assertEqual(post.id, post2.id) + + # Same thing by passing a LazyReference instance + post = BlogPost.objects(author=LazyReference(Member, m2.pk)).first() + self.assertEqual(post.id, post2.id) + + def test_lazy_reference_passthrough(self): + class Animal(Document): + name = StringField() + tag = StringField() + + class Ocurrence(Document): + animal = LazyReferenceField(Animal, passthrough=False) + animal_passthrough = LazyReferenceField(Animal, passthrough=True) + + Animal.drop_collection() + Ocurrence.drop_collection() + + animal = Animal(name="Leopard", tag="heavy").save() + Ocurrence(animal=animal, animal_passthrough=animal).save() + p = Ocurrence.objects.get() + self.assertIsInstance(p.animal, LazyReference) + with self.assertRaises(KeyError): + p.animal['name'] + with self.assertRaises(AttributeError): + p.animal.name + self.assertEqual(p.animal.pk, animal.pk) + + self.assertEqual(p.animal_passthrough.name, "Leopard") + self.assertEqual(p.animal_passthrough['name'], "Leopard") + + # Should not be able to access referenced document's methods + with self.assertRaises(AttributeError): + p.animal.save + with self.assertRaises(KeyError): + p.animal['save'] + + def test_lazy_reference_not_set(self): + class Animal(Document): + name = StringField() + tag = StringField() + + class Ocurrence(Document): + person = StringField() + animal = LazyReferenceField(Animal) + + Animal.drop_collection() + Ocurrence.drop_collection() + + Ocurrence(person='foo').save() + p = Ocurrence.objects.get() + self.assertIs(p.animal, None) + + def test_lazy_reference_equality(self): + class Animal(Document): + name = StringField() + tag = StringField() + + Animal.drop_collection() + + animal = Animal(name="Leopard", tag="heavy").save() + animalref = LazyReference(Animal, animal.pk) + self.assertEqual(animal, animalref) + self.assertEqual(animalref, animal) + + other_animalref = LazyReference(Animal, ObjectId("54495ad94c934721ede76f90")) + self.assertNotEqual(animal, other_animalref) + self.assertNotEqual(other_animalref, animal) + + def test_lazy_reference_embedded(self): + class Animal(Document): + name = StringField() + tag = StringField() + + class EmbeddedOcurrence(EmbeddedDocument): + in_list = ListField(LazyReferenceField(Animal)) + direct = LazyReferenceField(Animal) + + class Ocurrence(Document): + in_list = ListField(LazyReferenceField(Animal)) + in_embedded = EmbeddedDocumentField(EmbeddedOcurrence) + direct = LazyReferenceField(Animal) + + Animal.drop_collection() + Ocurrence.drop_collection() + + animal1 = Animal('doggo').save() + animal2 = Animal('cheeta').save() + + def check_fields_type(occ): + self.assertIsInstance(occ.direct, LazyReference) + for elem in occ.in_list: + self.assertIsInstance(elem, LazyReference) + self.assertIsInstance(occ.in_embedded.direct, LazyReference) + for elem in occ.in_embedded.in_list: + self.assertIsInstance(elem, LazyReference) + + occ = Ocurrence( + in_list=[animal1, animal2], + in_embedded={'in_list': [animal1, animal2], 'direct': animal1}, + direct=animal1 + ).save() + check_fields_type(occ) + occ.reload() + check_fields_type(occ) + occ.direct = animal1.id + occ.in_list = [animal1.id, animal2.id] + occ.in_embedded.direct = animal1.id + occ.in_embedded.in_list = [animal1.id, animal2.id] + check_fields_type(occ) + + +class TestGenericLazyReferenceField(MongoDBTestCase): + def test_generic_lazy_reference_simple(self): + class Animal(Document): + name = StringField() + tag = StringField() + + class Ocurrence(Document): + person = StringField() + animal = GenericLazyReferenceField() + + Animal.drop_collection() + Ocurrence.drop_collection() + + animal = Animal(name="Leopard", tag="heavy").save() + Ocurrence(person="test", animal=animal).save() + p = Ocurrence.objects.get() + self.assertIsInstance(p.animal, LazyReference) + fetched_animal = p.animal.fetch() + self.assertEqual(fetched_animal, animal) + # `fetch` keep cache on referenced document by default... + animal.tag = "not so heavy" + animal.save() + double_fetch = p.animal.fetch() + self.assertIs(fetched_animal, double_fetch) + self.assertEqual(double_fetch.tag, "heavy") + # ...unless specified otherwise + fetch_force = p.animal.fetch(force=True) + self.assertIsNot(fetch_force, fetched_animal) + self.assertEqual(fetch_force.tag, "not so heavy") + + def test_generic_lazy_reference_choices(self): + class Animal(Document): + name = StringField() + + class Vegetal(Document): + name = StringField() + + class Mineral(Document): + name = StringField() + + class Ocurrence(Document): + living_thing = GenericLazyReferenceField(choices=[Animal, Vegetal]) + thing = GenericLazyReferenceField() + + Animal.drop_collection() + Vegetal.drop_collection() + Mineral.drop_collection() + Ocurrence.drop_collection() + + animal = Animal(name="Leopard").save() + vegetal = Vegetal(name="Oak").save() + mineral = Mineral(name="Granite").save() + + occ_animal = Ocurrence(living_thing=animal, thing=animal).save() + occ_vegetal = Ocurrence(living_thing=vegetal, thing=vegetal).save() + with self.assertRaises(ValidationError): + Ocurrence(living_thing=mineral).save() + + occ = Ocurrence.objects.get(living_thing=animal) + self.assertEqual(occ, occ_animal) + self.assertIsInstance(occ.thing, LazyReference) + self.assertIsInstance(occ.living_thing, LazyReference) + + occ.thing = vegetal + occ.living_thing = vegetal + occ.save() + + occ.thing = mineral + occ.living_thing = mineral + with self.assertRaises(ValidationError): + occ.save() + + def test_generic_lazy_reference_set(self): + class Animal(Document): + meta = {'allow_inheritance': True} + + name = StringField() + tag = StringField() + + class Ocurrence(Document): + person = StringField() + animal = GenericLazyReferenceField() + + Animal.drop_collection() + Ocurrence.drop_collection() + + class SubAnimal(Animal): + nick = StringField() + + animal = Animal(name="Leopard", tag="heavy").save() + sub_animal = SubAnimal(nick='doggo', name='dog').save() + for ref in ( + animal, + LazyReference(Animal, animal.pk), + {'_cls': 'Animal', '_ref': DBRef(animal._get_collection_name(), animal.pk)}, + + sub_animal, + LazyReference(SubAnimal, sub_animal.pk), + {'_cls': 'SubAnimal', '_ref': DBRef(sub_animal._get_collection_name(), sub_animal.pk)}, + ): + p = Ocurrence(person="test", animal=ref).save() + p.reload() + self.assertIsInstance(p.animal, (LazyReference, Document)) + p.animal.fetch() + + def test_generic_lazy_reference_bad_set(self): + class Animal(Document): + name = StringField() + tag = StringField() + + class Ocurrence(Document): + person = StringField() + animal = GenericLazyReferenceField(choices=['Animal']) + + Animal.drop_collection() + Ocurrence.drop_collection() + + class BadDoc(Document): + pass + + animal = Animal(name="Leopard", tag="heavy").save() + baddoc = BadDoc().save() + for bad in ( + 42, + 'foo', + baddoc, + LazyReference(BadDoc, animal.pk) + ): + with self.assertRaises(ValidationError): + p = Ocurrence(person="test", animal=bad).save() + + def test_generic_lazy_reference_query_conversion(self): + class Member(Document): + user_num = IntField(primary_key=True) + + class BlogPost(Document): + title = StringField() + author = GenericLazyReferenceField() + + Member.drop_collection() + BlogPost.drop_collection() + + m1 = Member(user_num=1) + m1.save() + m2 = Member(user_num=2) + m2.save() + + post1 = BlogPost(title='post 1', author=m1) + post1.save() + + post2 = BlogPost(title='post 2', author=m2) + post2.save() + + post = BlogPost.objects(author=m1).first() + self.assertEqual(post.id, post1.id) + + post = BlogPost.objects(author=m2).first() + self.assertEqual(post.id, post2.id) + + # Same thing by passing a LazyReference instance + post = BlogPost.objects(author=LazyReference(Member, m2.pk)).first() + self.assertEqual(post.id, post2.id) + + def test_generic_lazy_reference_not_set(self): + class Animal(Document): + name = StringField() + tag = StringField() + + class Ocurrence(Document): + person = StringField() + animal = GenericLazyReferenceField() + + Animal.drop_collection() + Ocurrence.drop_collection() + + Ocurrence(person='foo').save() + p = Ocurrence.objects.get() + self.assertIs(p.animal, None) + + def test_generic_lazy_reference_embedded(self): + class Animal(Document): + name = StringField() + tag = StringField() + + class EmbeddedOcurrence(EmbeddedDocument): + in_list = ListField(GenericLazyReferenceField()) + direct = GenericLazyReferenceField() + + class Ocurrence(Document): + in_list = ListField(GenericLazyReferenceField()) + in_embedded = EmbeddedDocumentField(EmbeddedOcurrence) + direct = GenericLazyReferenceField() + + Animal.drop_collection() + Ocurrence.drop_collection() + + animal1 = Animal('doggo').save() + animal2 = Animal('cheeta').save() + + def check_fields_type(occ): + self.assertIsInstance(occ.direct, LazyReference) + for elem in occ.in_list: + self.assertIsInstance(elem, LazyReference) + self.assertIsInstance(occ.in_embedded.direct, LazyReference) + for elem in occ.in_embedded.in_list: + self.assertIsInstance(elem, LazyReference) + + occ = Ocurrence( + in_list=[animal1, animal2], + in_embedded={'in_list': [animal1, animal2], 'direct': animal1}, + direct=animal1 + ).save() + check_fields_type(occ) + occ.reload() + check_fields_type(occ) + animal1_ref = {'_cls': 'Animal', '_ref': DBRef(animal1._get_collection_name(), animal1.pk)} + animal2_ref = {'_cls': 'Animal', '_ref': DBRef(animal2._get_collection_name(), animal2.pk)} + occ.direct = animal1_ref + occ.in_list = [animal1_ref, animal2_ref] + occ.in_embedded.direct = animal1_ref + occ.in_embedded.in_list = [animal1_ref, animal2_ref] + check_fields_type(occ) diff --git a/tests/fields/test_long_field.py b/tests/fields/test_long_field.py new file mode 100644 index 00000000..4ab7403d --- /dev/null +++ b/tests/fields/test_long_field.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +import six + +try: + from bson.int64 import Int64 +except ImportError: + Int64 = long + +from mongoengine import * +from mongoengine.connection import get_db + +from tests.utils import MongoDBTestCase + + +class TestLongField(MongoDBTestCase): + + def test_long_field_is_considered_as_int64(self): + """ + Tests that long fields are stored as long in mongo, even if long + value is small enough to be an int. + """ + class TestLongFieldConsideredAsInt64(Document): + some_long = LongField() + + doc = TestLongFieldConsideredAsInt64(some_long=42).save() + db = get_db() + self.assertIsInstance(db.test_long_field_considered_as_int64.find()[0]['some_long'], Int64) + self.assertIsInstance(doc.some_long, six.integer_types) + + def test_long_validation(self): + """Ensure that invalid values cannot be assigned to long fields. + """ + class TestDocument(Document): + value = LongField(min_value=0, max_value=110) + + doc = TestDocument() + doc.value = 50 + doc.validate() + + doc.value = -1 + self.assertRaises(ValidationError, doc.validate) + doc.age = 120 + self.assertRaises(ValidationError, doc.validate) + doc.age = 'ten' + self.assertRaises(ValidationError, doc.validate) + + def test_long_ne_operator(self): + class TestDocument(Document): + long_fld = LongField() + + TestDocument.drop_collection() + + TestDocument(long_fld=None).save() + TestDocument(long_fld=1).save() + + self.assertEqual(1, TestDocument.objects(long_fld__ne=None).count()) diff --git a/tests/fields/test_url_field.py b/tests/fields/test_url_field.py new file mode 100644 index 00000000..0447799e --- /dev/null +++ b/tests/fields/test_url_field.py @@ -0,0 +1,59 @@ +# -*- coding: utf-8 -*- +from mongoengine import * + +from tests.utils import MongoDBTestCase + + +class TestFloatField(MongoDBTestCase): + + def test_validation(self): + """Ensure that URLFields validate urls properly.""" + class Link(Document): + url = URLField() + + link = Link() + link.url = 'google' + self.assertRaises(ValidationError, link.validate) + + link.url = 'http://www.google.com:8080' + link.validate() + + def test_unicode_url_validation(self): + """Ensure unicode URLs are validated properly.""" + class Link(Document): + url = URLField() + + link = Link() + link.url = u'http://привет.com' + + # TODO fix URL validation - this *IS* a valid URL + # For now we just want to make sure that the error message is correct + with self.assertRaises(ValidationError) as ctx_err: + link.validate() + self.assertEqual(unicode(ctx_err.exception), + u"ValidationError (Link:None) (Invalid URL: http://\u043f\u0440\u0438\u0432\u0435\u0442.com: ['url'])") + + def test_url_scheme_validation(self): + """Ensure that URLFields validate urls with specific schemes properly. + """ + class Link(Document): + url = URLField() + + class SchemeLink(Document): + url = URLField(schemes=['ws', 'irc']) + + link = Link() + link.url = 'ws://google.com' + self.assertRaises(ValidationError, link.validate) + + scheme_link = SchemeLink() + scheme_link.url = 'ws://google.com' + scheme_link.validate() + + def test_underscore_allowed_in_domains_names(self): + class Link(Document): + url = URLField() + + link = Link() + link.url = 'https://san_leandro-ca.geebo.com' + link.validate() From b9cc8a4ca9d2874f7866efa2c3e792f757c767b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Thu, 21 Feb 2019 23:12:19 +0100 Subject: [PATCH 07/71] refactor more field tests into submodules #1983 --- tests/fields/fields.py | 1598 +------------------ tests/fields/test_boolean_field.py | 49 + tests/fields/test_cached_reference_field.py | 443 +++++ tests/fields/test_complex_datetime_field.py | 5 - tests/fields/test_date_field.py | 19 - tests/fields/test_datetime_field.py | 5 - tests/fields/test_decimal_field.py | 91 ++ tests/fields/test_dict_field.py | 303 ++++ tests/fields/test_email_field.py | 120 ++ tests/fields/test_map_field.py | 144 ++ tests/fields/test_reference_field.py | 219 +++ tests/fields/test_sequence_field.py | 271 ++++ tests/fields/test_url_field.py | 2 +- tests/fields/test_uuid_field.py | 65 + tests/utils.py | 7 +- 15 files changed, 1722 insertions(+), 1619 deletions(-) create mode 100644 tests/fields/test_boolean_field.py create mode 100644 tests/fields/test_cached_reference_field.py create mode 100644 tests/fields/test_decimal_field.py create mode 100644 tests/fields/test_dict_field.py create mode 100644 tests/fields/test_email_field.py create mode 100644 tests/fields/test_map_field.py create mode 100644 tests/fields/test_reference_field.py create mode 100644 tests/fields/test_sequence_field.py create mode 100644 tests/fields/test_uuid_field.py diff --git a/tests/fields/fields.py b/tests/fields/fields.py index 194d07d7..c772b472 100644 --- a/tests/fields/fields.py +++ b/tests/fields/fields.py @@ -1,18 +1,17 @@ # -*- coding: utf-8 -*- import datetime import unittest -import uuid -import sys from nose.plugins.skip import SkipTest -import six - -from decimal import Decimal from bson import DBRef, ObjectId, SON -from mongoengine import * -from mongoengine.base import (BaseDict, BaseField, EmbeddedDocumentList, +from mongoengine import Document, StringField, IntField, DateTimeField, DateField, ValidationError, \ + ComplexDateTimeField, FloatField, ListField, ReferenceField, DictField, EmbeddedDocument, EmbeddedDocumentField, \ + GenericReferenceField, DoesNotExist, NotRegistered, GenericEmbeddedDocumentField, OperationError, DynamicField, \ + FieldDoesNotExist, EmbeddedDocumentListField, MultipleObjectsReturned, NotUniqueError, BooleanField, ObjectIdField, \ + SortedListField, GenericLazyReferenceField, LazyReferenceField, DynamicDocument +from mongoengine.base import (BaseField, EmbeddedDocumentList, _document_registry) from tests.utils import MongoDBTestCase @@ -276,7 +275,7 @@ class FieldTest(MongoDBTestCase): # attempted. self.assertRaises(ValidationError, ret.validate) - def test_object_id_validation(self): + def test_default_id_validation_as_objectid(self): """Ensure that invalid values cannot be assigned to an ObjectIdField. """ @@ -292,7 +291,7 @@ class FieldTest(MongoDBTestCase): person.id = 'abc' self.assertRaises(ValidationError, person.validate) - person.id = '497ce96f395f2f052a494fd4' + person.id = str(ObjectId()) person.validate() def test_string_validation(self): @@ -319,33 +318,6 @@ class FieldTest(MongoDBTestCase): person.name = 'Shorter name' person.validate() - def test_decimal_validation(self): - """Ensure that invalid values cannot be assigned to decimal fields. - """ - class Person(Document): - height = DecimalField(min_value=Decimal('0.1'), - max_value=Decimal('3.5')) - - Person.drop_collection() - - Person(height=Decimal('1.89')).save() - person = Person.objects.first() - self.assertEqual(person.height, Decimal('1.89')) - - person.height = '2.0' - person.save() - person.height = 0.01 - self.assertRaises(ValidationError, person.validate) - person.height = Decimal('0.01') - self.assertRaises(ValidationError, person.validate) - person.height = Decimal('4.0') - self.assertRaises(ValidationError, person.validate) - person.height = 'something invalid' - self.assertRaises(ValidationError, person.validate) - - person_2 = Person(height='something invalid') - self.assertRaises(ValidationError, person_2.validate) - def test_db_field_validation(self): """Ensure that db_field doesn't accept invalid values.""" @@ -364,128 +336,9 @@ class FieldTest(MongoDBTestCase): class User(Document): name = StringField(db_field='name\0') - def test_decimal_comparison(self): - class Person(Document): - money = DecimalField() - - Person.drop_collection() - - Person(money=6).save() - Person(money=8).save() - Person(money=10).save() - - self.assertEqual(2, Person.objects(money__gt=Decimal("7")).count()) - self.assertEqual(2, Person.objects(money__gt=7).count()) - self.assertEqual(2, Person.objects(money__gt="7").count()) - - def test_decimal_storage(self): - class Person(Document): - float_value = DecimalField(precision=4) - string_value = DecimalField(precision=4, force_string=True) - - Person.drop_collection() - values_to_store = [10, 10.1, 10.11, "10.111", Decimal("10.1111"), Decimal("10.11111")] - for store_at_creation in [True, False]: - for value in values_to_store: - # to_python is called explicitly if values were sent in the kwargs of __init__ - if store_at_creation: - Person(float_value=value, string_value=value).save() - else: - person = Person.objects.create() - person.float_value = value - person.string_value = value - person.save() - - # How its stored - expected = [ - {'float_value': 10.0, 'string_value': '10.0000'}, - {'float_value': 10.1, 'string_value': '10.1000'}, - {'float_value': 10.11, 'string_value': '10.1100'}, - {'float_value': 10.111, 'string_value': '10.1110'}, - {'float_value': 10.1111, 'string_value': '10.1111'}, - {'float_value': 10.1111, 'string_value': '10.1111'}] - expected.extend(expected) - actual = list(Person.objects.exclude('id').as_pymongo()) - self.assertEqual(expected, actual) - - # How it comes out locally - expected = [Decimal('10.0000'), Decimal('10.1000'), Decimal('10.1100'), - Decimal('10.1110'), Decimal('10.1111'), Decimal('10.1111')] - expected.extend(expected) - for field_name in ['float_value', 'string_value']: - actual = list(Person.objects().scalar(field_name)) - self.assertEqual(expected, actual) - - def test_boolean_validation(self): - """Ensure that invalid values cannot be assigned to boolean - fields. - """ - class Person(Document): - admin = BooleanField() - - person = Person() - person.admin = True - person.validate() - - person.admin = 2 - self.assertRaises(ValidationError, person.validate) - person.admin = 'Yes' - self.assertRaises(ValidationError, person.validate) - person.admin = 'False' - self.assertRaises(ValidationError, person.validate) - - def test_uuid_field_string(self): - """Test UUID fields storing as String - """ - class Person(Document): - api_key = UUIDField(binary=False) - - Person.drop_collection() - - uu = uuid.uuid4() - Person(api_key=uu).save() - self.assertEqual(1, Person.objects(api_key=uu).count()) - self.assertEqual(uu, Person.objects.first().api_key) - - person = Person() - valid = (uuid.uuid4(), uuid.uuid1()) - for api_key in valid: - person.api_key = api_key - person.validate() - - invalid = ('9d159858-549b-4975-9f98-dd2f987c113g', - '9d159858-549b-4975-9f98-dd2f987c113') - for api_key in invalid: - person.api_key = api_key - self.assertRaises(ValidationError, person.validate) - - def test_uuid_field_binary(self): - """Test UUID fields storing as Binary object.""" - class Person(Document): - api_key = UUIDField(binary=True) - - Person.drop_collection() - - uu = uuid.uuid4() - Person(api_key=uu).save() - self.assertEqual(1, Person.objects(api_key=uu).count()) - self.assertEqual(uu, Person.objects.first().api_key) - - person = Person() - valid = (uuid.uuid4(), uuid.uuid1()) - for api_key in valid: - person.api_key = api_key - person.validate() - - invalid = ('9d159858-549b-4975-9f98-dd2f987c113g', - '9d159858-549b-4975-9f98-dd2f987c113') - for api_key in invalid: - person.api_key = api_key - self.assertRaises(ValidationError, person.validate) - def test_list_validation(self): """Ensure that a list field only accepts lists with valid elements.""" - AccessLevelChoices = ( + access_level_choices = ( ('a', u'Administration'), ('b', u'Manager'), ('c', u'Staff'), @@ -505,7 +358,7 @@ class FieldTest(MongoDBTestCase): authors_as_lazy = ListField(LazyReferenceField(User)) generic = ListField(GenericReferenceField()) generic_as_lazy = ListField(GenericLazyReferenceField()) - access_list = ListField(choices=AccessLevelChoices, display_sep=', ') + access_list = ListField(choices=access_level_choices, display_sep=', ') User.drop_collection() BlogPost.drop_collection() @@ -1187,374 +1040,6 @@ class FieldTest(MongoDBTestCase): self.assertEqual( Simple.objects.filter(mapping__2__list__1__value='Boo').count(), 1) - def test_dict_field(self): - """Ensure that dict types work as expected.""" - class BlogPost(Document): - info = DictField() - - BlogPost.drop_collection() - - post = BlogPost() - post.info = 'my post' - self.assertRaises(ValidationError, post.validate) - - post.info = ['test', 'test'] - self.assertRaises(ValidationError, post.validate) - - post.info = {'$title': 'test'} - self.assertRaises(ValidationError, post.validate) - - post.info = {'nested': {'$title': 'test'}} - self.assertRaises(ValidationError, post.validate) - - post.info = {'the.title': 'test'} - self.assertRaises(ValidationError, post.validate) - - post.info = {'nested': {'the.title': 'test'}} - self.assertRaises(ValidationError, post.validate) - - post.info = {1: 'test'} - self.assertRaises(ValidationError, post.validate) - - post.info = {'title': 'test'} - post.save() - - post = BlogPost() - post.info = {'title': 'dollar_sign', 'details': {'te$t': 'test'}} - post.save() - - post = BlogPost() - post.info = {'details': {'test': 'test'}} - post.save() - - post = BlogPost() - post.info = {'details': {'test': 3}} - post.save() - - self.assertEqual(BlogPost.objects.count(), 4) - self.assertEqual( - BlogPost.objects.filter(info__title__exact='test').count(), 1) - self.assertEqual( - BlogPost.objects.filter(info__details__test__exact='test').count(), 1) - - post = BlogPost.objects.filter(info__title__exact='dollar_sign').first() - self.assertIn('te$t', post['info']['details']) - - # Confirm handles non strings or non existing keys - self.assertEqual( - BlogPost.objects.filter(info__details__test__exact=5).count(), 0) - self.assertEqual( - BlogPost.objects.filter(info__made_up__test__exact='test').count(), 0) - - post = BlogPost.objects.create(info={'title': 'original'}) - post.info.update({'title': 'updated'}) - post.save() - post.reload() - self.assertEqual('updated', post.info['title']) - - post.info.setdefault('authors', []) - post.save() - post.reload() - self.assertEqual([], post.info['authors']) - - def test_dictfield_dump_document(self): - """Ensure a DictField can handle another document's dump.""" - class Doc(Document): - field = DictField() - - class ToEmbed(Document): - id = IntField(primary_key=True, default=1) - recursive = DictField() - - class ToEmbedParent(Document): - id = IntField(primary_key=True, default=1) - recursive = DictField() - - meta = {'allow_inheritance': True} - - class ToEmbedChild(ToEmbedParent): - pass - - to_embed_recursive = ToEmbed(id=1).save() - to_embed = ToEmbed( - id=2, recursive=to_embed_recursive.to_mongo().to_dict()).save() - doc = Doc(field=to_embed.to_mongo().to_dict()) - doc.save() - assert isinstance(doc.field, dict) - assert doc.field == {'_id': 2, 'recursive': {'_id': 1, 'recursive': {}}} - # Same thing with a Document with a _cls field - to_embed_recursive = ToEmbedChild(id=1).save() - to_embed_child = ToEmbedChild( - id=2, recursive=to_embed_recursive.to_mongo().to_dict()).save() - doc = Doc(field=to_embed_child.to_mongo().to_dict()) - doc.save() - assert isinstance(doc.field, dict) - assert doc.field == { - '_id': 2, '_cls': 'ToEmbedParent.ToEmbedChild', - 'recursive': {'_id': 1, '_cls': 'ToEmbedParent.ToEmbedChild', 'recursive': {}} - } - - def test_dictfield_strict(self): - """Ensure that dict field handles validation if provided a strict field type.""" - class Simple(Document): - mapping = DictField(field=IntField()) - - Simple.drop_collection() - - e = Simple() - e.mapping['someint'] = 1 - e.save() - - # try creating an invalid mapping - with self.assertRaises(ValidationError): - e.mapping['somestring'] = "abc" - e.save() - - def test_dictfield_complex(self): - """Ensure that the dict field can handle the complex types.""" - class SettingBase(EmbeddedDocument): - meta = {'allow_inheritance': True} - - class StringSetting(SettingBase): - value = StringField() - - class IntegerSetting(SettingBase): - value = IntField() - - class Simple(Document): - mapping = DictField() - - Simple.drop_collection() - - e = Simple() - e.mapping['somestring'] = StringSetting(value='foo') - e.mapping['someint'] = IntegerSetting(value=42) - e.mapping['nested_dict'] = {'number': 1, 'string': 'Hi!', - 'float': 1.001, - 'complex': IntegerSetting(value=42), - 'list': [IntegerSetting(value=42), - StringSetting(value='foo')]} - e.save() - - e2 = Simple.objects.get(id=e.id) - self.assertIsInstance(e2.mapping['somestring'], StringSetting) - self.assertIsInstance(e2.mapping['someint'], IntegerSetting) - - # Test querying - self.assertEqual( - Simple.objects.filter(mapping__someint__value=42).count(), 1) - self.assertEqual( - Simple.objects.filter(mapping__nested_dict__number=1).count(), 1) - self.assertEqual( - Simple.objects.filter(mapping__nested_dict__complex__value=42).count(), 1) - self.assertEqual( - Simple.objects.filter(mapping__nested_dict__list__0__value=42).count(), 1) - self.assertEqual( - Simple.objects.filter(mapping__nested_dict__list__1__value='foo').count(), 1) - - # Confirm can update - Simple.objects().update( - set__mapping={"someint": IntegerSetting(value=10)}) - Simple.objects().update( - set__mapping__nested_dict__list__1=StringSetting(value='Boo')) - self.assertEqual( - Simple.objects.filter(mapping__nested_dict__list__1__value='foo').count(), 0) - self.assertEqual( - Simple.objects.filter(mapping__nested_dict__list__1__value='Boo').count(), 1) - - def test_atomic_update_dict_field(self): - """Ensure that the entire DictField can be atomically updated.""" - class Simple(Document): - mapping = DictField(field=ListField(IntField(required=True))) - - Simple.drop_collection() - - e = Simple() - e.mapping['someints'] = [1, 2] - e.save() - e.update(set__mapping={"ints": [3, 4]}) - e.reload() - self.assertEqual(BaseDict, type(e.mapping)) - self.assertEqual({"ints": [3, 4]}, e.mapping) - - # try creating an invalid mapping - with self.assertRaises(ValueError): - e.update(set__mapping={"somestrings": ["foo", "bar", ]}) - - def test_dictfield_with_referencefield_complex_nesting_cases(self): - """Ensure complex nesting inside DictField handles dereferencing of ReferenceField(dbref=True | False)""" - # Relates to Issue #1453 - class Doc(Document): - s = StringField() - - class Simple(Document): - mapping0 = DictField(ReferenceField(Doc, dbref=True)) - mapping1 = DictField(ReferenceField(Doc, dbref=False)) - mapping2 = DictField(ListField(ReferenceField(Doc, dbref=True))) - mapping3 = DictField(ListField(ReferenceField(Doc, dbref=False))) - mapping4 = DictField(DictField(field=ReferenceField(Doc, dbref=True))) - mapping5 = DictField(DictField(field=ReferenceField(Doc, dbref=False))) - mapping6 = DictField(ListField(DictField(ReferenceField(Doc, dbref=True)))) - mapping7 = DictField(ListField(DictField(ReferenceField(Doc, dbref=False)))) - mapping8 = DictField(ListField(DictField(ListField(ReferenceField(Doc, dbref=True))))) - mapping9 = DictField(ListField(DictField(ListField(ReferenceField(Doc, dbref=False))))) - - Doc.drop_collection() - Simple.drop_collection() - - d = Doc(s='aa').save() - e = Simple() - e.mapping0['someint'] = e.mapping1['someint'] = d - e.mapping2['someint'] = e.mapping3['someint'] = [d] - e.mapping4['someint'] = e.mapping5['someint'] = {'d': d} - e.mapping6['someint'] = e.mapping7['someint'] = [{'d': d}] - e.mapping8['someint'] = e.mapping9['someint'] = [{'d': [d]}] - e.save() - - s = Simple.objects.first() - self.assertIsInstance(s.mapping0['someint'], Doc) - self.assertIsInstance(s.mapping1['someint'], Doc) - self.assertIsInstance(s.mapping2['someint'][0], Doc) - self.assertIsInstance(s.mapping3['someint'][0], Doc) - self.assertIsInstance(s.mapping4['someint']['d'], Doc) - self.assertIsInstance(s.mapping5['someint']['d'], Doc) - self.assertIsInstance(s.mapping6['someint'][0]['d'], Doc) - self.assertIsInstance(s.mapping7['someint'][0]['d'], Doc) - self.assertIsInstance(s.mapping8['someint'][0]['d'][0], Doc) - self.assertIsInstance(s.mapping9['someint'][0]['d'][0], Doc) - - def test_mapfield(self): - """Ensure that the MapField handles the declared type.""" - class Simple(Document): - mapping = MapField(IntField()) - - Simple.drop_collection() - - e = Simple() - e.mapping['someint'] = 1 - e.save() - - with self.assertRaises(ValidationError): - e.mapping['somestring'] = "abc" - e.save() - - with self.assertRaises(ValidationError): - class NoDeclaredType(Document): - mapping = MapField() - - def test_complex_mapfield(self): - """Ensure that the MapField can handle complex declared types.""" - class SettingBase(EmbeddedDocument): - meta = {"allow_inheritance": True} - - class StringSetting(SettingBase): - value = StringField() - - class IntegerSetting(SettingBase): - value = IntField() - - class Extensible(Document): - mapping = MapField(EmbeddedDocumentField(SettingBase)) - - Extensible.drop_collection() - - e = Extensible() - e.mapping['somestring'] = StringSetting(value='foo') - e.mapping['someint'] = IntegerSetting(value=42) - e.save() - - e2 = Extensible.objects.get(id=e.id) - self.assertIsInstance(e2.mapping['somestring'], StringSetting) - self.assertIsInstance(e2.mapping['someint'], IntegerSetting) - - with self.assertRaises(ValidationError): - e.mapping['someint'] = 123 - e.save() - - def test_embedded_mapfield_db_field(self): - class Embedded(EmbeddedDocument): - number = IntField(default=0, db_field='i') - - class Test(Document): - my_map = MapField(field=EmbeddedDocumentField(Embedded), - db_field='x') - - Test.drop_collection() - - test = Test() - test.my_map['DICTIONARY_KEY'] = Embedded(number=1) - test.save() - - Test.objects.update_one(inc__my_map__DICTIONARY_KEY__number=1) - - test = Test.objects.get() - self.assertEqual(test.my_map['DICTIONARY_KEY'].number, 2) - doc = self.db.test.find_one() - self.assertEqual(doc['x']['DICTIONARY_KEY']['i'], 2) - - def test_mapfield_numerical_index(self): - """Ensure that MapField accept numeric strings as indexes.""" - class Embedded(EmbeddedDocument): - name = StringField() - - class Test(Document): - my_map = MapField(EmbeddedDocumentField(Embedded)) - - Test.drop_collection() - - test = Test() - test.my_map['1'] = Embedded(name='test') - test.save() - test.my_map['1'].name = 'test updated' - test.save() - - def test_map_field_lookup(self): - """Ensure MapField lookups succeed on Fields without a lookup - method. - """ - class Action(EmbeddedDocument): - operation = StringField() - object = StringField() - - class Log(Document): - name = StringField() - visited = MapField(DateTimeField()) - actions = MapField(EmbeddedDocumentField(Action)) - - Log.drop_collection() - Log(name="wilson", visited={'friends': datetime.datetime.now()}, - actions={'friends': Action(operation='drink', object='beer')}).save() - - self.assertEqual(1, Log.objects( - visited__friends__exists=True).count()) - - self.assertEqual(1, Log.objects( - actions__friends__operation='drink', - actions__friends__object='beer').count()) - - def test_map_field_unicode(self): - class Info(EmbeddedDocument): - description = StringField() - value_list = ListField(field=StringField()) - - class BlogPost(Document): - info_dict = MapField(field=EmbeddedDocumentField(Info)) - - BlogPost.drop_collection() - - tree = BlogPost(info_dict={ - u"éééé": { - 'description': u"VALUE: éééé" - } - }) - - tree.save() - - self.assertEqual( - BlogPost.objects.get(id=tree.id).info_dict[u"éééé"].description, - u"VALUE: éééé" - ) - def test_embedded_db_field(self): class Embedded(EmbeddedDocument): number = IntField(default=0, db_field='i') @@ -1741,121 +1226,6 @@ class FieldTest(MongoDBTestCase): bar._fields['generic_ref']._auto_dereference = False self.assertEqual(bar.generic_ref, {'_ref': expected, '_cls': 'Foo'}) - def test_reference_validation(self): - """Ensure that invalid document objects cannot be assigned to - reference fields. - """ - class User(Document): - name = StringField() - - class BlogPost(Document): - content = StringField() - author = ReferenceField(User) - - User.drop_collection() - BlogPost.drop_collection() - - # Make sure ReferenceField only accepts a document class or a string - # with a document class name. - self.assertRaises(ValidationError, ReferenceField, EmbeddedDocument) - - user = User(name='Test User') - - # Ensure that the referenced object must have been saved - post1 = BlogPost(content='Chips and gravy taste good.') - post1.author = user - self.assertRaises(ValidationError, post1.save) - - # Check that an invalid object type cannot be used - post2 = BlogPost(content='Chips and chilli taste good.') - post1.author = post2 - self.assertRaises(ValidationError, post1.validate) - - # Ensure ObjectID's are accepted as references - user_object_id = user.pk - post3 = BlogPost(content="Chips and curry sauce taste good.") - post3.author = user_object_id - post3.save() - - # Make sure referencing a saved document of the right type works - user.save() - post1.author = user - post1.save() - - # Make sure referencing a saved document of the *wrong* type fails - post2.save() - post1.author = post2 - self.assertRaises(ValidationError, post1.validate) - - def test_objectid_reference_fields(self): - """Make sure storing Object ID references works.""" - class Person(Document): - name = StringField() - parent = ReferenceField('self') - - Person.drop_collection() - - p1 = Person(name="John").save() - Person(name="Ross", parent=p1.pk).save() - - p = Person.objects.get(name="Ross") - self.assertEqual(p.parent, p1) - - def test_dbref_reference_fields(self): - """Make sure storing references as bson.dbref.DBRef works.""" - class Person(Document): - name = StringField() - parent = ReferenceField('self', dbref=True) - - Person.drop_collection() - - p1 = Person(name="John").save() - Person(name="Ross", parent=p1).save() - - self.assertEqual( - Person._get_collection().find_one({'name': 'Ross'})['parent'], - DBRef('person', p1.pk) - ) - - p = Person.objects.get(name="Ross") - self.assertEqual(p.parent, p1) - - def test_dbref_to_mongo(self): - """Make sure that calling to_mongo on a ReferenceField which - has dbref=False, but actually actually contains a DBRef returns - an ID of that DBRef. - """ - class Person(Document): - name = StringField() - parent = ReferenceField('self', dbref=False) - - p = Person( - name='Steve', - parent=DBRef('person', 'abcdefghijklmnop') - ) - self.assertEqual(p.to_mongo(), SON([ - ('name', u'Steve'), - ('parent', 'abcdefghijklmnop') - ])) - - def test_objectid_reference_fields(self): - - class Person(Document): - name = StringField() - parent = ReferenceField('self', dbref=False) - - Person.drop_collection() - - p1 = Person(name="John").save() - Person(name="Ross", parent=p1).save() - - col = Person._get_collection() - data = col.find_one({'name': 'Ross'}) - self.assertEqual(data['parent'], p1.pk) - - p = Person.objects.get(name="Ross") - self.assertEqual(p.parent, p1) - def test_list_item_dereference(self): """Ensure that DBRef items in ListFields are dereferenced. """ @@ -1972,99 +1342,6 @@ class FieldTest(MongoDBTestCase): self.assertEqual(tree.children[0].children[0].name, second_child.name) self.assertEqual(tree.children[0].children[1].name, third_child.name) - def test_undefined_reference(self): - """Ensure that ReferenceFields may reference undefined Documents. - """ - class Product(Document): - name = StringField() - company = ReferenceField('Company') - - class Company(Document): - name = StringField() - - Product.drop_collection() - Company.drop_collection() - - ten_gen = Company(name='10gen') - ten_gen.save() - mongodb = Product(name='MongoDB', company=ten_gen) - mongodb.save() - - me = Product(name='MongoEngine') - me.save() - - obj = Product.objects(company=ten_gen).first() - self.assertEqual(obj, mongodb) - self.assertEqual(obj.company, ten_gen) - - obj = Product.objects(company=None).first() - self.assertEqual(obj, me) - - obj = Product.objects.get(company=None) - self.assertEqual(obj, me) - - def test_reference_query_conversion(self): - """Ensure that ReferenceFields can be queried using objects and values - of the type of the primary key of the referenced object. - """ - class Member(Document): - user_num = IntField(primary_key=True) - - class BlogPost(Document): - title = StringField() - author = ReferenceField(Member, dbref=False) - - Member.drop_collection() - BlogPost.drop_collection() - - m1 = Member(user_num=1) - m1.save() - m2 = Member(user_num=2) - m2.save() - - post1 = BlogPost(title='post 1', author=m1) - post1.save() - - post2 = BlogPost(title='post 2', author=m2) - post2.save() - - post = BlogPost.objects(author=m1).first() - self.assertEqual(post.id, post1.id) - - post = BlogPost.objects(author=m2).first() - self.assertEqual(post.id, post2.id) - - def test_reference_query_conversion_dbref(self): - """Ensure that ReferenceFields can be queried using objects and values - of the type of the primary key of the referenced object. - """ - class Member(Document): - user_num = IntField(primary_key=True) - - class BlogPost(Document): - title = StringField() - author = ReferenceField(Member, dbref=True) - - Member.drop_collection() - BlogPost.drop_collection() - - m1 = Member(user_num=1) - m1.save() - m2 = Member(user_num=2) - m2.save() - - post1 = BlogPost(title='post 1', author=m1) - post1.save() - - post2 = BlogPost(title='post 2', author=m2) - post2.save() - - post = BlogPost.objects(author=m1).first() - self.assertEqual(post.id, post1.id) - - post = BlogPost.objects(author=m2).first() - self.assertEqual(post.id, post2.id) - def test_drop_abstract_document(self): """Ensure that an abstract document cannot be dropped given it has no underlying collection. @@ -2681,283 +1958,6 @@ class FieldTest(MongoDBTestCase): self.assertEqual(error_dict['size'], SIZE_MESSAGE) self.assertEqual(error_dict['color'], COLOR_MESSAGE) - def test_ensure_unique_default_instances(self): - """Ensure that every field has it's own unique default instance.""" - class D(Document): - data = DictField() - data2 = DictField(default=lambda: {}) - - d1 = D() - d1.data['foo'] = 'bar' - d1.data2['foo'] = 'bar' - d2 = D() - self.assertEqual(d2.data, {}) - self.assertEqual(d2.data2, {}) - - def test_sequence_field(self): - class Person(Document): - id = SequenceField(primary_key=True) - name = StringField() - - self.db['mongoengine.counters'].drop() - Person.drop_collection() - - for x in range(10): - Person(name="Person %s" % x).save() - - c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) - self.assertEqual(c['next'], 10) - - ids = [i.id for i in Person.objects] - self.assertEqual(ids, range(1, 11)) - - c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) - self.assertEqual(c['next'], 10) - - Person.id.set_next_value(1000) - c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) - self.assertEqual(c['next'], 1000) - - def test_sequence_field_get_next_value(self): - class Person(Document): - id = SequenceField(primary_key=True) - name = StringField() - - self.db['mongoengine.counters'].drop() - Person.drop_collection() - - for x in range(10): - Person(name="Person %s" % x).save() - - self.assertEqual(Person.id.get_next_value(), 11) - self.db['mongoengine.counters'].drop() - - self.assertEqual(Person.id.get_next_value(), 1) - - class Person(Document): - id = SequenceField(primary_key=True, value_decorator=str) - name = StringField() - - self.db['mongoengine.counters'].drop() - Person.drop_collection() - - for x in range(10): - Person(name="Person %s" % x).save() - - self.assertEqual(Person.id.get_next_value(), '11') - self.db['mongoengine.counters'].drop() - - self.assertEqual(Person.id.get_next_value(), '1') - - def test_sequence_field_sequence_name(self): - class Person(Document): - id = SequenceField(primary_key=True, sequence_name='jelly') - name = StringField() - - self.db['mongoengine.counters'].drop() - Person.drop_collection() - - for x in range(10): - Person(name="Person %s" % x).save() - - c = self.db['mongoengine.counters'].find_one({'_id': 'jelly.id'}) - self.assertEqual(c['next'], 10) - - ids = [i.id for i in Person.objects] - self.assertEqual(ids, range(1, 11)) - - c = self.db['mongoengine.counters'].find_one({'_id': 'jelly.id'}) - self.assertEqual(c['next'], 10) - - Person.id.set_next_value(1000) - c = self.db['mongoengine.counters'].find_one({'_id': 'jelly.id'}) - self.assertEqual(c['next'], 1000) - - def test_multiple_sequence_fields(self): - class Person(Document): - id = SequenceField(primary_key=True) - counter = SequenceField() - name = StringField() - - self.db['mongoengine.counters'].drop() - Person.drop_collection() - - for x in range(10): - Person(name="Person %s" % x).save() - - c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) - self.assertEqual(c['next'], 10) - - ids = [i.id for i in Person.objects] - self.assertEqual(ids, range(1, 11)) - - counters = [i.counter for i in Person.objects] - self.assertEqual(counters, range(1, 11)) - - c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) - self.assertEqual(c['next'], 10) - - Person.id.set_next_value(1000) - c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) - self.assertEqual(c['next'], 1000) - - Person.counter.set_next_value(999) - c = self.db['mongoengine.counters'].find_one({'_id': 'person.counter'}) - self.assertEqual(c['next'], 999) - - def test_sequence_fields_reload(self): - class Animal(Document): - counter = SequenceField() - name = StringField() - - self.db['mongoengine.counters'].drop() - Animal.drop_collection() - - a = Animal(name="Boi").save() - - self.assertEqual(a.counter, 1) - a.reload() - self.assertEqual(a.counter, 1) - - a.counter = None - self.assertEqual(a.counter, 2) - a.save() - - self.assertEqual(a.counter, 2) - - a = Animal.objects.first() - self.assertEqual(a.counter, 2) - a.reload() - self.assertEqual(a.counter, 2) - - def test_multiple_sequence_fields_on_docs(self): - class Animal(Document): - id = SequenceField(primary_key=True) - name = StringField() - - class Person(Document): - id = SequenceField(primary_key=True) - name = StringField() - - self.db['mongoengine.counters'].drop() - Animal.drop_collection() - Person.drop_collection() - - for x in range(10): - Animal(name="Animal %s" % x).save() - Person(name="Person %s" % x).save() - - c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) - self.assertEqual(c['next'], 10) - - c = self.db['mongoengine.counters'].find_one({'_id': 'animal.id'}) - self.assertEqual(c['next'], 10) - - ids = [i.id for i in Person.objects] - self.assertEqual(ids, range(1, 11)) - - id = [i.id for i in Animal.objects] - self.assertEqual(id, range(1, 11)) - - c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) - self.assertEqual(c['next'], 10) - - c = self.db['mongoengine.counters'].find_one({'_id': 'animal.id'}) - self.assertEqual(c['next'], 10) - - def test_sequence_field_value_decorator(self): - class Person(Document): - id = SequenceField(primary_key=True, value_decorator=str) - name = StringField() - - self.db['mongoengine.counters'].drop() - Person.drop_collection() - - for x in range(10): - p = Person(name="Person %s" % x) - p.save() - - c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) - self.assertEqual(c['next'], 10) - - ids = [i.id for i in Person.objects] - self.assertEqual(ids, map(str, range(1, 11))) - - c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) - self.assertEqual(c['next'], 10) - - def test_embedded_sequence_field(self): - class Comment(EmbeddedDocument): - id = SequenceField() - content = StringField(required=True) - - class Post(Document): - title = StringField(required=True) - comments = ListField(EmbeddedDocumentField(Comment)) - - self.db['mongoengine.counters'].drop() - Post.drop_collection() - - Post(title="MongoEngine", - comments=[Comment(content="NoSQL Rocks"), - Comment(content="MongoEngine Rocks")]).save() - c = self.db['mongoengine.counters'].find_one({'_id': 'comment.id'}) - self.assertEqual(c['next'], 2) - post = Post.objects.first() - self.assertEqual(1, post.comments[0].id) - self.assertEqual(2, post.comments[1].id) - - def test_inherited_sequencefield(self): - class Base(Document): - name = StringField() - counter = SequenceField() - meta = {'abstract': True} - - class Foo(Base): - pass - - class Bar(Base): - pass - - bar = Bar(name='Bar') - bar.save() - - foo = Foo(name='Foo') - foo.save() - - self.assertTrue('base.counter' in - self.db['mongoengine.counters'].find().distinct('_id')) - self.assertFalse(('foo.counter' or 'bar.counter') in - self.db['mongoengine.counters'].find().distinct('_id')) - self.assertNotEqual(foo.counter, bar.counter) - self.assertEqual(foo._fields['counter'].owner_document, Base) - self.assertEqual(bar._fields['counter'].owner_document, Base) - - def test_no_inherited_sequencefield(self): - class Base(Document): - name = StringField() - meta = {'abstract': True} - - class Foo(Base): - counter = SequenceField() - - class Bar(Base): - counter = SequenceField() - - bar = Bar(name='Bar') - bar.save() - - foo = Foo(name='Foo') - foo.save() - - self.assertFalse('base.counter' in - self.db['mongoengine.counters'].find().distinct('_id')) - self.assertTrue(('foo.counter' and 'bar.counter') in - self.db['mongoengine.counters'].find().distinct('_id')) - self.assertEqual(foo.counter, bar.counter) - self.assertEqual(foo._fields['counter'].owner_document, Foo) - self.assertEqual(bar._fields['counter'].owner_document, Bar) - def test_generic_embedded_document(self): class Car(EmbeddedDocument): name = StringField() @@ -3078,117 +2078,6 @@ class FieldTest(MongoDBTestCase): post.comments[1].content = 'here we go' post.validate() - def test_email_field(self): - class User(Document): - email = EmailField() - - user = User(email='ross@example.com') - user.validate() - - user = User(email='ross@example.co.uk') - user.validate() - - user = User(email=('Kofq@rhom0e4klgauOhpbpNdogawnyIKvQS0wk2mjqrgGQ5S' - 'aJIazqqWkm7.net')) - user.validate() - - user = User(email='new-tld@example.technology') - user.validate() - - user = User(email='ross@example.com.') - self.assertRaises(ValidationError, user.validate) - - # unicode domain - user = User(email=u'user@пример.рф') - user.validate() - - # invalid unicode domain - user = User(email=u'user@пример') - self.assertRaises(ValidationError, user.validate) - - # invalid data type - user = User(email=123) - self.assertRaises(ValidationError, user.validate) - - def test_email_field_unicode_user(self): - # Don't run this test on pypy3, which doesn't support unicode regex: - # https://bitbucket.org/pypy/pypy/issues/1821/regular-expression-doesnt-find-unicode - if sys.version_info[:2] == (3, 2): - raise SkipTest('unicode email addresses are not supported on PyPy 3') - - class User(Document): - email = EmailField() - - # unicode user shouldn't validate by default... - user = User(email=u'Dörte@Sörensen.example.com') - self.assertRaises(ValidationError, user.validate) - - # ...but it should be fine with allow_utf8_user set to True - class User(Document): - email = EmailField(allow_utf8_user=True) - - user = User(email=u'Dörte@Sörensen.example.com') - user.validate() - - def test_email_field_domain_whitelist(self): - class User(Document): - email = EmailField() - - # localhost domain shouldn't validate by default... - user = User(email='me@localhost') - self.assertRaises(ValidationError, user.validate) - - # ...but it should be fine if it's whitelisted - class User(Document): - email = EmailField(domain_whitelist=['localhost']) - - user = User(email='me@localhost') - user.validate() - - def test_email_field_ip_domain(self): - class User(Document): - email = EmailField() - - valid_ipv4 = 'email@[127.0.0.1]' - valid_ipv6 = 'email@[2001:dB8::1]' - invalid_ip = 'email@[324.0.0.1]' - - # IP address as a domain shouldn't validate by default... - user = User(email=valid_ipv4) - self.assertRaises(ValidationError, user.validate) - - user = User(email=valid_ipv6) - self.assertRaises(ValidationError, user.validate) - - user = User(email=invalid_ip) - self.assertRaises(ValidationError, user.validate) - - # ...but it should be fine with allow_ip_domain set to True - class User(Document): - email = EmailField(allow_ip_domain=True) - - user = User(email=valid_ipv4) - user.validate() - - user = User(email=valid_ipv6) - user.validate() - - # invalid IP should still fail validation - user = User(email=invalid_ip) - self.assertRaises(ValidationError, user.validate) - - def test_email_field_honors_regex(self): - class User(Document): - email = EmailField(regex=r'\w+@example.com') - - # Fails regex validation - user = User(email='me@foo.com') - self.assertRaises(ValidationError, user.validate) - - # Passes regex validation - user = User(email='me@example.com') - self.assertIsNone(user.validate()) - def test_tuples_as_tuples(self): """Ensure that tuples remain tuples when they are inside a ComplexBaseField. @@ -3289,36 +2178,6 @@ class FieldTest(MongoDBTestCase): assert isinstance(doc.field, ToEmbedChild) assert doc.field == to_embed_child - def test_dict_field_invalid_dict_value(self): - class DictFieldTest(Document): - dictionary = DictField(required=True) - - DictFieldTest.drop_collection() - - test = DictFieldTest(dictionary=None) - test.dictionary # Just access to test getter - self.assertRaises(ValidationError, test.validate) - - test = DictFieldTest(dictionary=False) - test.dictionary # Just access to test getter - self.assertRaises(ValidationError, test.validate) - - def test_dict_field_raises_validation_error_if_wrongly_assign_embedded_doc(self): - class DictFieldTest(Document): - dictionary = DictField(required=True) - - DictFieldTest.drop_collection() - - class Embedded(EmbeddedDocument): - name = StringField() - - embed = Embedded(name='garbage') - doc = DictFieldTest(dictionary=embed) - with self.assertRaises(ValidationError) as ctx_err: - doc.validate() - self.assertIn("'dictionary'", str(ctx_err.exception)) - self.assertIn('Only dictionaries may be used in a DictField', str(ctx_err.exception)) - def test_cls_field(self): class Animal(Document): meta = {'allow_inheritance': True} @@ -3882,442 +2741,5 @@ class TestEmbeddedDocumentField(MongoDBTestCase): emb = EmbeddedDocumentField('MyDoc') -class CachedReferenceFieldTest(MongoDBTestCase): - - def test_cached_reference_field_get_and_save(self): - """ - Tests #1047: CachedReferenceField creates DBRefs on to_python, - but can't save them on to_mongo. - """ - class Animal(Document): - name = StringField() - tag = StringField() - - class Ocorrence(Document): - person = StringField() - animal = CachedReferenceField(Animal) - - Animal.drop_collection() - Ocorrence.drop_collection() - - Ocorrence(person="testte", - animal=Animal(name="Leopard", tag="heavy").save()).save() - p = Ocorrence.objects.get() - p.person = 'new_testte' - p.save() - - def test_cached_reference_fields(self): - class Animal(Document): - name = StringField() - tag = StringField() - - class Ocorrence(Document): - person = StringField() - animal = CachedReferenceField( - Animal, fields=['tag']) - - Animal.drop_collection() - Ocorrence.drop_collection() - - a = Animal(name="Leopard", tag="heavy") - a.save() - - self.assertEqual(Animal._cached_reference_fields, [Ocorrence.animal]) - o = Ocorrence(person="teste", animal=a) - o.save() - - p = Ocorrence(person="Wilson") - p.save() - - self.assertEqual(Ocorrence.objects(animal=None).count(), 1) - - self.assertEqual( - a.to_mongo(fields=['tag']), {'tag': 'heavy', "_id": a.pk}) - - self.assertEqual(o.to_mongo()['animal']['tag'], 'heavy') - - # counts - Ocorrence(person="teste 2").save() - Ocorrence(person="teste 3").save() - - count = Ocorrence.objects(animal__tag='heavy').count() - self.assertEqual(count, 1) - - ocorrence = Ocorrence.objects(animal__tag='heavy').first() - self.assertEqual(ocorrence.person, "teste") - self.assertIsInstance(ocorrence.animal, Animal) - - def test_cached_reference_field_decimal(self): - class PersonAuto(Document): - name = StringField() - salary = DecimalField() - - class SocialTest(Document): - group = StringField() - person = CachedReferenceField( - PersonAuto, - fields=('salary',)) - - PersonAuto.drop_collection() - SocialTest.drop_collection() - - p = PersonAuto(name="Alberto", salary=Decimal('7000.00')) - p.save() - - s = SocialTest(group="dev", person=p) - s.save() - - self.assertEqual( - SocialTest.objects._collection.find_one({'person.salary': 7000.00}), { - '_id': s.pk, - 'group': s.group, - 'person': { - '_id': p.pk, - 'salary': 7000.00 - } - }) - - def test_cached_reference_field_reference(self): - class Group(Document): - name = StringField() - - class Person(Document): - name = StringField() - group = ReferenceField(Group) - - class SocialData(Document): - obs = StringField() - tags = ListField( - StringField()) - person = CachedReferenceField( - Person, - fields=('group',)) - - Group.drop_collection() - Person.drop_collection() - SocialData.drop_collection() - - g1 = Group(name='dev') - g1.save() - - g2 = Group(name="designers") - g2.save() - - p1 = Person(name="Alberto", group=g1) - p1.save() - - p2 = Person(name="Andre", group=g1) - p2.save() - - p3 = Person(name="Afro design", group=g2) - p3.save() - - s1 = SocialData(obs="testing 123", person=p1, tags=['tag1', 'tag2']) - s1.save() - - s2 = SocialData(obs="testing 321", person=p3, tags=['tag3', 'tag4']) - s2.save() - - self.assertEqual(SocialData.objects._collection.find_one( - {'tags': 'tag2'}), { - '_id': s1.pk, - 'obs': 'testing 123', - 'tags': ['tag1', 'tag2'], - 'person': { - '_id': p1.pk, - 'group': g1.pk - } - }) - - self.assertEqual(SocialData.objects(person__group=g2).count(), 1) - self.assertEqual(SocialData.objects(person__group=g2).first(), s2) - - def test_cached_reference_field_push_with_fields(self): - class Product(Document): - name = StringField() - - Product.drop_collection() - - class Basket(Document): - products = ListField(CachedReferenceField(Product, fields=['name'])) - - Basket.drop_collection() - product1 = Product(name='abc').save() - product2 = Product(name='def').save() - basket = Basket(products=[product1]).save() - self.assertEqual( - Basket.objects._collection.find_one(), - { - '_id': basket.pk, - 'products': [ - { - '_id': product1.pk, - 'name': product1.name - } - ] - } - ) - # push to list - basket.update(push__products=product2) - basket.reload() - self.assertEqual( - Basket.objects._collection.find_one(), - { - '_id': basket.pk, - 'products': [ - { - '_id': product1.pk, - 'name': product1.name - }, - { - '_id': product2.pk, - 'name': product2.name - } - ] - } - ) - - def test_cached_reference_field_update_all(self): - class Person(Document): - TYPES = ( - ('pf', "PF"), - ('pj', "PJ") - ) - name = StringField() - tp = StringField( - choices=TYPES - ) - - father = CachedReferenceField('self', fields=('tp',)) - - Person.drop_collection() - - a1 = Person(name="Wilson Father", tp="pj") - a1.save() - - a2 = Person(name='Wilson Junior', tp='pf', father=a1) - a2.save() - - self.assertEqual(dict(a2.to_mongo()), { - "_id": a2.pk, - "name": u"Wilson Junior", - "tp": u"pf", - "father": { - "_id": a1.pk, - "tp": u"pj" - } - }) - - self.assertEqual(Person.objects(father=a1)._query, { - 'father._id': a1.pk - }) - self.assertEqual(Person.objects(father=a1).count(), 1) - - Person.objects.update(set__tp="pf") - Person.father.sync_all() - - a2.reload() - self.assertEqual(dict(a2.to_mongo()), { - "_id": a2.pk, - "name": u"Wilson Junior", - "tp": u"pf", - "father": { - "_id": a1.pk, - "tp": u"pf" - } - }) - - def test_cached_reference_fields_on_embedded_documents(self): - with self.assertRaises(InvalidDocumentError): - class Test(Document): - name = StringField() - - type('WrongEmbeddedDocument', ( - EmbeddedDocument,), { - 'test': CachedReferenceField(Test) - }) - - def test_cached_reference_auto_sync(self): - class Person(Document): - TYPES = ( - ('pf', "PF"), - ('pj', "PJ") - ) - name = StringField() - tp = StringField( - choices=TYPES - ) - - father = CachedReferenceField('self', fields=('tp',)) - - Person.drop_collection() - - a1 = Person(name="Wilson Father", tp="pj") - a1.save() - - a2 = Person(name='Wilson Junior', tp='pf', father=a1) - a2.save() - - a1.tp = 'pf' - a1.save() - - a2.reload() - self.assertEqual(dict(a2.to_mongo()), { - '_id': a2.pk, - 'name': 'Wilson Junior', - 'tp': 'pf', - 'father': { - '_id': a1.pk, - 'tp': 'pf' - } - }) - - def test_cached_reference_auto_sync_disabled(self): - class Persone(Document): - TYPES = ( - ('pf', "PF"), - ('pj', "PJ") - ) - name = StringField() - tp = StringField( - choices=TYPES - ) - - father = CachedReferenceField( - 'self', fields=('tp',), auto_sync=False) - - Persone.drop_collection() - - a1 = Persone(name="Wilson Father", tp="pj") - a1.save() - - a2 = Persone(name='Wilson Junior', tp='pf', father=a1) - a2.save() - - a1.tp = 'pf' - a1.save() - - self.assertEqual(Persone.objects._collection.find_one({'_id': a2.pk}), { - '_id': a2.pk, - 'name': 'Wilson Junior', - 'tp': 'pf', - 'father': { - '_id': a1.pk, - 'tp': 'pj' - } - }) - - def test_cached_reference_embedded_fields(self): - class Owner(EmbeddedDocument): - TPS = ( - ('n', "Normal"), - ('u', "Urgent") - ) - name = StringField() - tp = StringField( - verbose_name="Type", - db_field="t", - choices=TPS) - - class Animal(Document): - name = StringField() - tag = StringField() - - owner = EmbeddedDocumentField(Owner) - - class Ocorrence(Document): - person = StringField() - animal = CachedReferenceField( - Animal, fields=['tag', 'owner.tp']) - - Animal.drop_collection() - Ocorrence.drop_collection() - - a = Animal(name="Leopard", tag="heavy", - owner=Owner(tp='u', name="Wilson Júnior") - ) - a.save() - - o = Ocorrence(person="teste", animal=a) - o.save() - self.assertEqual(dict(a.to_mongo(fields=['tag', 'owner.tp'])), { - '_id': a.pk, - 'tag': 'heavy', - 'owner': { - 't': 'u' - } - }) - self.assertEqual(o.to_mongo()['animal']['tag'], 'heavy') - self.assertEqual(o.to_mongo()['animal']['owner']['t'], 'u') - - # counts - Ocorrence(person="teste 2").save() - Ocorrence(person="teste 3").save() - - count = Ocorrence.objects( - animal__tag='heavy', animal__owner__tp='u').count() - self.assertEqual(count, 1) - - ocorrence = Ocorrence.objects( - animal__tag='heavy', - animal__owner__tp='u').first() - self.assertEqual(ocorrence.person, "teste") - self.assertIsInstance(ocorrence.animal, Animal) - - def test_cached_reference_embedded_list_fields(self): - class Owner(EmbeddedDocument): - name = StringField() - tags = ListField(StringField()) - - class Animal(Document): - name = StringField() - tag = StringField() - - owner = EmbeddedDocumentField(Owner) - - class Ocorrence(Document): - person = StringField() - animal = CachedReferenceField( - Animal, fields=['tag', 'owner.tags']) - - Animal.drop_collection() - Ocorrence.drop_collection() - - a = Animal(name="Leopard", tag="heavy", - owner=Owner(tags=['cool', 'funny'], - name="Wilson Júnior") - ) - a.save() - - o = Ocorrence(person="teste 2", animal=a) - o.save() - self.assertEqual(dict(a.to_mongo(fields=['tag', 'owner.tags'])), { - '_id': a.pk, - 'tag': 'heavy', - 'owner': { - 'tags': ['cool', 'funny'] - } - }) - - self.assertEqual(o.to_mongo()['animal']['tag'], 'heavy') - self.assertEqual(o.to_mongo()['animal']['owner']['tags'], - ['cool', 'funny']) - - # counts - Ocorrence(person="teste 2").save() - Ocorrence(person="teste 3").save() - - query = Ocorrence.objects( - animal__tag='heavy', animal__owner__tags='cool')._query - self.assertEqual( - query, {'animal.owner.tags': 'cool', 'animal.tag': 'heavy'}) - - ocorrence = Ocorrence.objects( - animal__tag='heavy', - animal__owner__tags='cool').first() - self.assertEqual(ocorrence.person, "teste 2") - self.assertIsInstance(ocorrence.animal, Animal) - - if __name__ == '__main__': unittest.main() diff --git a/tests/fields/test_boolean_field.py b/tests/fields/test_boolean_field.py new file mode 100644 index 00000000..7a2a3db6 --- /dev/null +++ b/tests/fields/test_boolean_field.py @@ -0,0 +1,49 @@ +# -*- coding: utf-8 -*- +from mongoengine import * + +from tests.utils import MongoDBTestCase, get_as_pymongo + + +class TestBooleanField(MongoDBTestCase): + def test_storage(self): + class Person(Document): + admin = BooleanField() + + person = Person(admin=True) + person.save() + self.assertEqual( + get_as_pymongo(person), + {'_id': person.id, + 'admin': True}) + + def test_validation(self): + """Ensure that invalid values cannot be assigned to boolean + fields. + """ + class Person(Document): + admin = BooleanField() + + person = Person() + person.admin = True + person.validate() + + person.admin = 2 + self.assertRaises(ValidationError, person.validate) + person.admin = 'Yes' + self.assertRaises(ValidationError, person.validate) + person.admin = 'False' + self.assertRaises(ValidationError, person.validate) + + def test_weirdness_constructor(self): + """When attribute is set in contructor, it gets cast into a bool + which causes some weird behavior. We dont necessarily want to maintain this behavior + but its a known issue + """ + class Person(Document): + admin = BooleanField() + + new_person = Person(admin='False') + self.assertTrue(new_person.admin) + + new_person = Person(admin='0') + self.assertTrue(new_person.admin) diff --git a/tests/fields/test_cached_reference_field.py b/tests/fields/test_cached_reference_field.py new file mode 100644 index 00000000..989cea6d --- /dev/null +++ b/tests/fields/test_cached_reference_field.py @@ -0,0 +1,443 @@ +# -*- coding: utf-8 -*- +from decimal import Decimal + +from mongoengine import * + +from tests.utils import MongoDBTestCase + + +class TestCachedReferenceField(MongoDBTestCase): + + def test_get_and_save(self): + """ + Tests #1047: CachedReferenceField creates DBRefs on to_python, + but can't save them on to_mongo. + """ + class Animal(Document): + name = StringField() + tag = StringField() + + class Ocorrence(Document): + person = StringField() + animal = CachedReferenceField(Animal) + + Animal.drop_collection() + Ocorrence.drop_collection() + + Ocorrence(person="testte", + animal=Animal(name="Leopard", tag="heavy").save()).save() + p = Ocorrence.objects.get() + p.person = 'new_testte' + p.save() + + def test_general_things(self): + class Animal(Document): + name = StringField() + tag = StringField() + + class Ocorrence(Document): + person = StringField() + animal = CachedReferenceField( + Animal, fields=['tag']) + + Animal.drop_collection() + Ocorrence.drop_collection() + + a = Animal(name="Leopard", tag="heavy") + a.save() + + self.assertEqual(Animal._cached_reference_fields, [Ocorrence.animal]) + o = Ocorrence(person="teste", animal=a) + o.save() + + p = Ocorrence(person="Wilson") + p.save() + + self.assertEqual(Ocorrence.objects(animal=None).count(), 1) + + self.assertEqual( + a.to_mongo(fields=['tag']), {'tag': 'heavy', "_id": a.pk}) + + self.assertEqual(o.to_mongo()['animal']['tag'], 'heavy') + + # counts + Ocorrence(person="teste 2").save() + Ocorrence(person="teste 3").save() + + count = Ocorrence.objects(animal__tag='heavy').count() + self.assertEqual(count, 1) + + ocorrence = Ocorrence.objects(animal__tag='heavy').first() + self.assertEqual(ocorrence.person, "teste") + self.assertIsInstance(ocorrence.animal, Animal) + + def test_with_decimal(self): + class PersonAuto(Document): + name = StringField() + salary = DecimalField() + + class SocialTest(Document): + group = StringField() + person = CachedReferenceField( + PersonAuto, + fields=('salary',)) + + PersonAuto.drop_collection() + SocialTest.drop_collection() + + p = PersonAuto(name="Alberto", salary=Decimal('7000.00')) + p.save() + + s = SocialTest(group="dev", person=p) + s.save() + + self.assertEqual( + SocialTest.objects._collection.find_one({'person.salary': 7000.00}), { + '_id': s.pk, + 'group': s.group, + 'person': { + '_id': p.pk, + 'salary': 7000.00 + } + }) + + def test_cached_reference_field_reference(self): + class Group(Document): + name = StringField() + + class Person(Document): + name = StringField() + group = ReferenceField(Group) + + class SocialData(Document): + obs = StringField() + tags = ListField( + StringField()) + person = CachedReferenceField( + Person, + fields=('group',)) + + Group.drop_collection() + Person.drop_collection() + SocialData.drop_collection() + + g1 = Group(name='dev') + g1.save() + + g2 = Group(name="designers") + g2.save() + + p1 = Person(name="Alberto", group=g1) + p1.save() + + p2 = Person(name="Andre", group=g1) + p2.save() + + p3 = Person(name="Afro design", group=g2) + p3.save() + + s1 = SocialData(obs="testing 123", person=p1, tags=['tag1', 'tag2']) + s1.save() + + s2 = SocialData(obs="testing 321", person=p3, tags=['tag3', 'tag4']) + s2.save() + + self.assertEqual(SocialData.objects._collection.find_one( + {'tags': 'tag2'}), { + '_id': s1.pk, + 'obs': 'testing 123', + 'tags': ['tag1', 'tag2'], + 'person': { + '_id': p1.pk, + 'group': g1.pk + } + }) + + self.assertEqual(SocialData.objects(person__group=g2).count(), 1) + self.assertEqual(SocialData.objects(person__group=g2).first(), s2) + + def test_cached_reference_field_push_with_fields(self): + class Product(Document): + name = StringField() + + Product.drop_collection() + + class Basket(Document): + products = ListField(CachedReferenceField(Product, fields=['name'])) + + Basket.drop_collection() + product1 = Product(name='abc').save() + product2 = Product(name='def').save() + basket = Basket(products=[product1]).save() + self.assertEqual( + Basket.objects._collection.find_one(), + { + '_id': basket.pk, + 'products': [ + { + '_id': product1.pk, + 'name': product1.name + } + ] + } + ) + # push to list + basket.update(push__products=product2) + basket.reload() + self.assertEqual( + Basket.objects._collection.find_one(), + { + '_id': basket.pk, + 'products': [ + { + '_id': product1.pk, + 'name': product1.name + }, + { + '_id': product2.pk, + 'name': product2.name + } + ] + } + ) + + def test_cached_reference_field_update_all(self): + class Person(Document): + TYPES = ( + ('pf', "PF"), + ('pj', "PJ") + ) + name = StringField() + tp = StringField( + choices=TYPES + ) + + father = CachedReferenceField('self', fields=('tp',)) + + Person.drop_collection() + + a1 = Person(name="Wilson Father", tp="pj") + a1.save() + + a2 = Person(name='Wilson Junior', tp='pf', father=a1) + a2.save() + + self.assertEqual(dict(a2.to_mongo()), { + "_id": a2.pk, + "name": u"Wilson Junior", + "tp": u"pf", + "father": { + "_id": a1.pk, + "tp": u"pj" + } + }) + + self.assertEqual(Person.objects(father=a1)._query, { + 'father._id': a1.pk + }) + self.assertEqual(Person.objects(father=a1).count(), 1) + + Person.objects.update(set__tp="pf") + Person.father.sync_all() + + a2.reload() + self.assertEqual(dict(a2.to_mongo()), { + "_id": a2.pk, + "name": u"Wilson Junior", + "tp": u"pf", + "father": { + "_id": a1.pk, + "tp": u"pf" + } + }) + + def test_cached_reference_fields_on_embedded_documents(self): + with self.assertRaises(InvalidDocumentError): + class Test(Document): + name = StringField() + + type('WrongEmbeddedDocument', ( + EmbeddedDocument,), { + 'test': CachedReferenceField(Test) + }) + + def test_cached_reference_auto_sync(self): + class Person(Document): + TYPES = ( + ('pf', "PF"), + ('pj', "PJ") + ) + name = StringField() + tp = StringField( + choices=TYPES + ) + + father = CachedReferenceField('self', fields=('tp',)) + + Person.drop_collection() + + a1 = Person(name="Wilson Father", tp="pj") + a1.save() + + a2 = Person(name='Wilson Junior', tp='pf', father=a1) + a2.save() + + a1.tp = 'pf' + a1.save() + + a2.reload() + self.assertEqual(dict(a2.to_mongo()), { + '_id': a2.pk, + 'name': 'Wilson Junior', + 'tp': 'pf', + 'father': { + '_id': a1.pk, + 'tp': 'pf' + } + }) + + def test_cached_reference_auto_sync_disabled(self): + class Persone(Document): + TYPES = ( + ('pf', "PF"), + ('pj', "PJ") + ) + name = StringField() + tp = StringField( + choices=TYPES + ) + + father = CachedReferenceField( + 'self', fields=('tp',), auto_sync=False) + + Persone.drop_collection() + + a1 = Persone(name="Wilson Father", tp="pj") + a1.save() + + a2 = Persone(name='Wilson Junior', tp='pf', father=a1) + a2.save() + + a1.tp = 'pf' + a1.save() + + self.assertEqual(Persone.objects._collection.find_one({'_id': a2.pk}), { + '_id': a2.pk, + 'name': 'Wilson Junior', + 'tp': 'pf', + 'father': { + '_id': a1.pk, + 'tp': 'pj' + } + }) + + def test_cached_reference_embedded_fields(self): + class Owner(EmbeddedDocument): + TPS = ( + ('n', "Normal"), + ('u', "Urgent") + ) + name = StringField() + tp = StringField( + verbose_name="Type", + db_field="t", + choices=TPS) + + class Animal(Document): + name = StringField() + tag = StringField() + + owner = EmbeddedDocumentField(Owner) + + class Ocorrence(Document): + person = StringField() + animal = CachedReferenceField( + Animal, fields=['tag', 'owner.tp']) + + Animal.drop_collection() + Ocorrence.drop_collection() + + a = Animal(name="Leopard", tag="heavy", + owner=Owner(tp='u', name="Wilson Júnior") + ) + a.save() + + o = Ocorrence(person="teste", animal=a) + o.save() + self.assertEqual(dict(a.to_mongo(fields=['tag', 'owner.tp'])), { + '_id': a.pk, + 'tag': 'heavy', + 'owner': { + 't': 'u' + } + }) + self.assertEqual(o.to_mongo()['animal']['tag'], 'heavy') + self.assertEqual(o.to_mongo()['animal']['owner']['t'], 'u') + + # counts + Ocorrence(person="teste 2").save() + Ocorrence(person="teste 3").save() + + count = Ocorrence.objects( + animal__tag='heavy', animal__owner__tp='u').count() + self.assertEqual(count, 1) + + ocorrence = Ocorrence.objects( + animal__tag='heavy', + animal__owner__tp='u').first() + self.assertEqual(ocorrence.person, "teste") + self.assertIsInstance(ocorrence.animal, Animal) + + def test_cached_reference_embedded_list_fields(self): + class Owner(EmbeddedDocument): + name = StringField() + tags = ListField(StringField()) + + class Animal(Document): + name = StringField() + tag = StringField() + + owner = EmbeddedDocumentField(Owner) + + class Ocorrence(Document): + person = StringField() + animal = CachedReferenceField( + Animal, fields=['tag', 'owner.tags']) + + Animal.drop_collection() + Ocorrence.drop_collection() + + a = Animal(name="Leopard", tag="heavy", + owner=Owner(tags=['cool', 'funny'], + name="Wilson Júnior") + ) + a.save() + + o = Ocorrence(person="teste 2", animal=a) + o.save() + self.assertEqual(dict(a.to_mongo(fields=['tag', 'owner.tags'])), { + '_id': a.pk, + 'tag': 'heavy', + 'owner': { + 'tags': ['cool', 'funny'] + } + }) + + self.assertEqual(o.to_mongo()['animal']['tag'], 'heavy') + self.assertEqual(o.to_mongo()['animal']['owner']['tags'], + ['cool', 'funny']) + + # counts + Ocorrence(person="teste 2").save() + Ocorrence(person="teste 3").save() + + query = Ocorrence.objects( + animal__tag='heavy', animal__owner__tags='cool')._query + self.assertEqual( + query, {'animal.owner.tags': 'cool', 'animal.tag': 'heavy'}) + + ocorrence = Ocorrence.objects( + animal__tag='heavy', + animal__owner__tags='cool').first() + self.assertEqual(ocorrence.person, "teste 2") + self.assertIsInstance(ocorrence.animal, Animal) diff --git a/tests/fields/test_complex_datetime_field.py b/tests/fields/test_complex_datetime_field.py index bac534c0..58dc4b43 100644 --- a/tests/fields/test_complex_datetime_field.py +++ b/tests/fields/test_complex_datetime_field.py @@ -4,11 +4,6 @@ import math import itertools import re -try: - from bson.int64 import Int64 -except ImportError: - Int64 = long - from mongoengine import * from tests.utils import MongoDBTestCase diff --git a/tests/fields/test_date_field.py b/tests/fields/test_date_field.py index b5aed5c1..82adb514 100644 --- a/tests/fields/test_date_field.py +++ b/tests/fields/test_date_field.py @@ -1,13 +1,5 @@ # -*- coding: utf-8 -*- import datetime -import unittest -import uuid -import math -import itertools -import re -import sys - -from nose.plugins.skip import SkipTest import six try: @@ -15,18 +7,7 @@ try: except ImportError: dateutil = None -from decimal import Decimal - -from bson import Binary, DBRef, ObjectId, SON -try: - from bson.int64 import Int64 -except ImportError: - Int64 = long - from mongoengine import * -from mongoengine.connection import get_db -from mongoengine.base import (BaseDict, BaseField, EmbeddedDocumentList, - _document_registry, LazyReference) from tests.utils import MongoDBTestCase diff --git a/tests/fields/test_datetime_field.py b/tests/fields/test_datetime_field.py index 24d1c777..c6253043 100644 --- a/tests/fields/test_datetime_field.py +++ b/tests/fields/test_datetime_field.py @@ -7,11 +7,6 @@ try: except ImportError: dateutil = None -try: - from bson.int64 import Int64 -except ImportError: - Int64 = long - from mongoengine import * from mongoengine import connection diff --git a/tests/fields/test_decimal_field.py b/tests/fields/test_decimal_field.py new file mode 100644 index 00000000..0213b880 --- /dev/null +++ b/tests/fields/test_decimal_field.py @@ -0,0 +1,91 @@ +# -*- coding: utf-8 -*- +from decimal import Decimal + +from mongoengine import * + +from tests.utils import MongoDBTestCase + + +class TestDecimalField(MongoDBTestCase): + + def test_validation(self): + """Ensure that invalid values cannot be assigned to decimal fields. + """ + class Person(Document): + height = DecimalField(min_value=Decimal('0.1'), + max_value=Decimal('3.5')) + + Person.drop_collection() + + Person(height=Decimal('1.89')).save() + person = Person.objects.first() + self.assertEqual(person.height, Decimal('1.89')) + + person.height = '2.0' + person.save() + person.height = 0.01 + self.assertRaises(ValidationError, person.validate) + person.height = Decimal('0.01') + self.assertRaises(ValidationError, person.validate) + person.height = Decimal('4.0') + self.assertRaises(ValidationError, person.validate) + person.height = 'something invalid' + self.assertRaises(ValidationError, person.validate) + + person_2 = Person(height='something invalid') + self.assertRaises(ValidationError, person_2.validate) + + def test_comparison(self): + class Person(Document): + money = DecimalField() + + Person.drop_collection() + + Person(money=6).save() + Person(money=7).save() + Person(money=8).save() + Person(money=10).save() + + self.assertEqual(2, Person.objects(money__gt=Decimal("7")).count()) + self.assertEqual(2, Person.objects(money__gt=7).count()) + self.assertEqual(2, Person.objects(money__gt="7").count()) + + self.assertEqual(3, Person.objects(money__gte="7").count()) + + def test_storage(self): + class Person(Document): + float_value = DecimalField(precision=4) + string_value = DecimalField(precision=4, force_string=True) + + Person.drop_collection() + values_to_store = [10, 10.1, 10.11, "10.111", Decimal("10.1111"), Decimal("10.11111")] + for store_at_creation in [True, False]: + for value in values_to_store: + # to_python is called explicitly if values were sent in the kwargs of __init__ + if store_at_creation: + Person(float_value=value, string_value=value).save() + else: + person = Person.objects.create() + person.float_value = value + person.string_value = value + person.save() + + # How its stored + expected = [ + {'float_value': 10.0, 'string_value': '10.0000'}, + {'float_value': 10.1, 'string_value': '10.1000'}, + {'float_value': 10.11, 'string_value': '10.1100'}, + {'float_value': 10.111, 'string_value': '10.1110'}, + {'float_value': 10.1111, 'string_value': '10.1111'}, + {'float_value': 10.1111, 'string_value': '10.1111'}] + expected.extend(expected) + actual = list(Person.objects.exclude('id').as_pymongo()) + self.assertEqual(expected, actual) + + # How it comes out locally + expected = [Decimal('10.0000'), Decimal('10.1000'), Decimal('10.1100'), + Decimal('10.1110'), Decimal('10.1111'), Decimal('10.1111')] + expected.extend(expected) + for field_name in ['float_value', 'string_value']: + actual = list(Person.objects().scalar(field_name)) + self.assertEqual(expected, actual) diff --git a/tests/fields/test_dict_field.py b/tests/fields/test_dict_field.py new file mode 100644 index 00000000..2b9cecb7 --- /dev/null +++ b/tests/fields/test_dict_field.py @@ -0,0 +1,303 @@ +# -*- coding: utf-8 -*- +from mongoengine import * +from mongoengine.base import BaseDict + +from tests.utils import MongoDBTestCase, get_as_pymongo + + +class TestDictField(MongoDBTestCase): + + def test_storage(self): + class BlogPost(Document): + info = DictField() + + BlogPost.drop_collection() + + info = {'testkey': 'testvalue'} + post = BlogPost(info=info).save() + self.assertEqual( + get_as_pymongo(post), + { + '_id': post.id, + 'info': info + } + ) + + def test_general_things(self): + """Ensure that dict types work as expected.""" + class BlogPost(Document): + info = DictField() + + BlogPost.drop_collection() + + post = BlogPost() + post.info = 'my post' + self.assertRaises(ValidationError, post.validate) + + post.info = ['test', 'test'] + self.assertRaises(ValidationError, post.validate) + + post.info = {'$title': 'test'} + self.assertRaises(ValidationError, post.validate) + + post.info = {'nested': {'$title': 'test'}} + self.assertRaises(ValidationError, post.validate) + + post.info = {'the.title': 'test'} + self.assertRaises(ValidationError, post.validate) + + post.info = {'nested': {'the.title': 'test'}} + self.assertRaises(ValidationError, post.validate) + + post.info = {1: 'test'} + self.assertRaises(ValidationError, post.validate) + + post.info = {'title': 'test'} + post.save() + + post = BlogPost() + post.info = {'title': 'dollar_sign', 'details': {'te$t': 'test'}} + post.save() + + post = BlogPost() + post.info = {'details': {'test': 'test'}} + post.save() + + post = BlogPost() + post.info = {'details': {'test': 3}} + post.save() + + self.assertEqual(BlogPost.objects.count(), 4) + self.assertEqual( + BlogPost.objects.filter(info__title__exact='test').count(), 1) + self.assertEqual( + BlogPost.objects.filter(info__details__test__exact='test').count(), 1) + + post = BlogPost.objects.filter(info__title__exact='dollar_sign').first() + self.assertIn('te$t', post['info']['details']) + + # Confirm handles non strings or non existing keys + self.assertEqual( + BlogPost.objects.filter(info__details__test__exact=5).count(), 0) + self.assertEqual( + BlogPost.objects.filter(info__made_up__test__exact='test').count(), 0) + + post = BlogPost.objects.create(info={'title': 'original'}) + post.info.update({'title': 'updated'}) + post.save() + post.reload() + self.assertEqual('updated', post.info['title']) + + post.info.setdefault('authors', []) + post.save() + post.reload() + self.assertEqual([], post.info['authors']) + + def test_dictfield_dump_document(self): + """Ensure a DictField can handle another document's dump.""" + class Doc(Document): + field = DictField() + + class ToEmbed(Document): + id = IntField(primary_key=True, default=1) + recursive = DictField() + + class ToEmbedParent(Document): + id = IntField(primary_key=True, default=1) + recursive = DictField() + + meta = {'allow_inheritance': True} + + class ToEmbedChild(ToEmbedParent): + pass + + to_embed_recursive = ToEmbed(id=1).save() + to_embed = ToEmbed( + id=2, recursive=to_embed_recursive.to_mongo().to_dict()).save() + doc = Doc(field=to_embed.to_mongo().to_dict()) + doc.save() + assert isinstance(doc.field, dict) + assert doc.field == {'_id': 2, 'recursive': {'_id': 1, 'recursive': {}}} + # Same thing with a Document with a _cls field + to_embed_recursive = ToEmbedChild(id=1).save() + to_embed_child = ToEmbedChild( + id=2, recursive=to_embed_recursive.to_mongo().to_dict()).save() + doc = Doc(field=to_embed_child.to_mongo().to_dict()) + doc.save() + assert isinstance(doc.field, dict) + assert doc.field == { + '_id': 2, '_cls': 'ToEmbedParent.ToEmbedChild', + 'recursive': {'_id': 1, '_cls': 'ToEmbedParent.ToEmbedChild', 'recursive': {}} + } + + def test_dictfield_strict(self): + """Ensure that dict field handles validation if provided a strict field type.""" + class Simple(Document): + mapping = DictField(field=IntField()) + + Simple.drop_collection() + + e = Simple() + e.mapping['someint'] = 1 + e.save() + + # try creating an invalid mapping + with self.assertRaises(ValidationError): + e.mapping['somestring'] = "abc" + e.save() + + def test_dictfield_complex(self): + """Ensure that the dict field can handle the complex types.""" + class SettingBase(EmbeddedDocument): + meta = {'allow_inheritance': True} + + class StringSetting(SettingBase): + value = StringField() + + class IntegerSetting(SettingBase): + value = IntField() + + class Simple(Document): + mapping = DictField() + + Simple.drop_collection() + + e = Simple() + e.mapping['somestring'] = StringSetting(value='foo') + e.mapping['someint'] = IntegerSetting(value=42) + e.mapping['nested_dict'] = {'number': 1, 'string': 'Hi!', + 'float': 1.001, + 'complex': IntegerSetting(value=42), + 'list': [IntegerSetting(value=42), + StringSetting(value='foo')]} + e.save() + + e2 = Simple.objects.get(id=e.id) + self.assertIsInstance(e2.mapping['somestring'], StringSetting) + self.assertIsInstance(e2.mapping['someint'], IntegerSetting) + + # Test querying + self.assertEqual( + Simple.objects.filter(mapping__someint__value=42).count(), 1) + self.assertEqual( + Simple.objects.filter(mapping__nested_dict__number=1).count(), 1) + self.assertEqual( + Simple.objects.filter(mapping__nested_dict__complex__value=42).count(), 1) + self.assertEqual( + Simple.objects.filter(mapping__nested_dict__list__0__value=42).count(), 1) + self.assertEqual( + Simple.objects.filter(mapping__nested_dict__list__1__value='foo').count(), 1) + + # Confirm can update + Simple.objects().update( + set__mapping={"someint": IntegerSetting(value=10)}) + Simple.objects().update( + set__mapping__nested_dict__list__1=StringSetting(value='Boo')) + self.assertEqual( + Simple.objects.filter(mapping__nested_dict__list__1__value='foo').count(), 0) + self.assertEqual( + Simple.objects.filter(mapping__nested_dict__list__1__value='Boo').count(), 1) + + def test_ensure_unique_default_instances(self): + """Ensure that every field has it's own unique default instance.""" + class D(Document): + data = DictField() + data2 = DictField(default=lambda: {}) + + d1 = D() + d1.data['foo'] = 'bar' + d1.data2['foo'] = 'bar' + d2 = D() + self.assertEqual(d2.data, {}) + self.assertEqual(d2.data2, {}) + + def test_dict_field_invalid_dict_value(self): + class DictFieldTest(Document): + dictionary = DictField(required=True) + + DictFieldTest.drop_collection() + + test = DictFieldTest(dictionary=None) + test.dictionary # Just access to test getter + self.assertRaises(ValidationError, test.validate) + + test = DictFieldTest(dictionary=False) + test.dictionary # Just access to test getter + self.assertRaises(ValidationError, test.validate) + + def test_dict_field_raises_validation_error_if_wrongly_assign_embedded_doc(self): + class DictFieldTest(Document): + dictionary = DictField(required=True) + + DictFieldTest.drop_collection() + + class Embedded(EmbeddedDocument): + name = StringField() + + embed = Embedded(name='garbage') + doc = DictFieldTest(dictionary=embed) + with self.assertRaises(ValidationError) as ctx_err: + doc.validate() + self.assertIn("'dictionary'", str(ctx_err.exception)) + self.assertIn('Only dictionaries may be used in a DictField', str(ctx_err.exception)) + + def test_atomic_update_dict_field(self): + """Ensure that the entire DictField can be atomically updated.""" + class Simple(Document): + mapping = DictField(field=ListField(IntField(required=True))) + + Simple.drop_collection() + + e = Simple() + e.mapping['someints'] = [1, 2] + e.save() + e.update(set__mapping={"ints": [3, 4]}) + e.reload() + self.assertEqual(BaseDict, type(e.mapping)) + self.assertEqual({"ints": [3, 4]}, e.mapping) + + # try creating an invalid mapping + with self.assertRaises(ValueError): + e.update(set__mapping={"somestrings": ["foo", "bar", ]}) + + def test_dictfield_with_referencefield_complex_nesting_cases(self): + """Ensure complex nesting inside DictField handles dereferencing of ReferenceField(dbref=True | False)""" + # Relates to Issue #1453 + class Doc(Document): + s = StringField() + + class Simple(Document): + mapping0 = DictField(ReferenceField(Doc, dbref=True)) + mapping1 = DictField(ReferenceField(Doc, dbref=False)) + mapping2 = DictField(ListField(ReferenceField(Doc, dbref=True))) + mapping3 = DictField(ListField(ReferenceField(Doc, dbref=False))) + mapping4 = DictField(DictField(field=ReferenceField(Doc, dbref=True))) + mapping5 = DictField(DictField(field=ReferenceField(Doc, dbref=False))) + mapping6 = DictField(ListField(DictField(ReferenceField(Doc, dbref=True)))) + mapping7 = DictField(ListField(DictField(ReferenceField(Doc, dbref=False)))) + mapping8 = DictField(ListField(DictField(ListField(ReferenceField(Doc, dbref=True))))) + mapping9 = DictField(ListField(DictField(ListField(ReferenceField(Doc, dbref=False))))) + + Doc.drop_collection() + Simple.drop_collection() + + d = Doc(s='aa').save() + e = Simple() + e.mapping0['someint'] = e.mapping1['someint'] = d + e.mapping2['someint'] = e.mapping3['someint'] = [d] + e.mapping4['someint'] = e.mapping5['someint'] = {'d': d} + e.mapping6['someint'] = e.mapping7['someint'] = [{'d': d}] + e.mapping8['someint'] = e.mapping9['someint'] = [{'d': [d]}] + e.save() + + s = Simple.objects.first() + self.assertIsInstance(s.mapping0['someint'], Doc) + self.assertIsInstance(s.mapping1['someint'], Doc) + self.assertIsInstance(s.mapping2['someint'][0], Doc) + self.assertIsInstance(s.mapping3['someint'][0], Doc) + self.assertIsInstance(s.mapping4['someint']['d'], Doc) + self.assertIsInstance(s.mapping5['someint']['d'], Doc) + self.assertIsInstance(s.mapping6['someint'][0]['d'], Doc) + self.assertIsInstance(s.mapping7['someint'][0]['d'], Doc) + self.assertIsInstance(s.mapping8['someint'][0]['d'][0], Doc) + self.assertIsInstance(s.mapping9['someint'][0]['d'][0], Doc) diff --git a/tests/fields/test_email_field.py b/tests/fields/test_email_field.py new file mode 100644 index 00000000..d8410354 --- /dev/null +++ b/tests/fields/test_email_field.py @@ -0,0 +1,120 @@ +# -*- coding: utf-8 -*- +import sys +from unittest import SkipTest + +from mongoengine import * + +from tests.utils import MongoDBTestCase + + +class TestEmailField(MongoDBTestCase): + def test_generic_behavior(self): + class User(Document): + email = EmailField() + + user = User(email='ross@example.com') + user.validate() + + user = User(email='ross@example.co.uk') + user.validate() + + user = User(email=('Kofq@rhom0e4klgauOhpbpNdogawnyIKvQS0wk2mjqrgGQ5S' + 'aJIazqqWkm7.net')) + user.validate() + + user = User(email='new-tld@example.technology') + user.validate() + + user = User(email='ross@example.com.') + self.assertRaises(ValidationError, user.validate) + + # unicode domain + user = User(email=u'user@пример.рф') + user.validate() + + # invalid unicode domain + user = User(email=u'user@пример') + self.assertRaises(ValidationError, user.validate) + + # invalid data type + user = User(email=123) + self.assertRaises(ValidationError, user.validate) + + def test_email_field_unicode_user(self): + # Don't run this test on pypy3, which doesn't support unicode regex: + # https://bitbucket.org/pypy/pypy/issues/1821/regular-expression-doesnt-find-unicode + if sys.version_info[:2] == (3, 2): + raise SkipTest('unicode email addresses are not supported on PyPy 3') + + class User(Document): + email = EmailField() + + # unicode user shouldn't validate by default... + user = User(email=u'Dörte@Sörensen.example.com') + self.assertRaises(ValidationError, user.validate) + + # ...but it should be fine with allow_utf8_user set to True + class User(Document): + email = EmailField(allow_utf8_user=True) + + user = User(email=u'Dörte@Sörensen.example.com') + user.validate() + + def test_email_field_domain_whitelist(self): + class User(Document): + email = EmailField() + + # localhost domain shouldn't validate by default... + user = User(email='me@localhost') + self.assertRaises(ValidationError, user.validate) + + # ...but it should be fine if it's whitelisted + class User(Document): + email = EmailField(domain_whitelist=['localhost']) + + user = User(email='me@localhost') + user.validate() + + def test_email_field_ip_domain(self): + class User(Document): + email = EmailField() + + valid_ipv4 = 'email@[127.0.0.1]' + valid_ipv6 = 'email@[2001:dB8::1]' + invalid_ip = 'email@[324.0.0.1]' + + # IP address as a domain shouldn't validate by default... + user = User(email=valid_ipv4) + self.assertRaises(ValidationError, user.validate) + + user = User(email=valid_ipv6) + self.assertRaises(ValidationError, user.validate) + + user = User(email=invalid_ip) + self.assertRaises(ValidationError, user.validate) + + # ...but it should be fine with allow_ip_domain set to True + class User(Document): + email = EmailField(allow_ip_domain=True) + + user = User(email=valid_ipv4) + user.validate() + + user = User(email=valid_ipv6) + user.validate() + + # invalid IP should still fail validation + user = User(email=invalid_ip) + self.assertRaises(ValidationError, user.validate) + + def test_email_field_honors_regex(self): + class User(Document): + email = EmailField(regex=r'\w+@example.com') + + # Fails regex validation + user = User(email='me@foo.com') + self.assertRaises(ValidationError, user.validate) + + # Passes regex validation + user = User(email='me@example.com') + self.assertIsNone(user.validate()) diff --git a/tests/fields/test_map_field.py b/tests/fields/test_map_field.py new file mode 100644 index 00000000..cb27cfff --- /dev/null +++ b/tests/fields/test_map_field.py @@ -0,0 +1,144 @@ +# -*- coding: utf-8 -*- +import datetime + +from mongoengine import * + +from tests.utils import MongoDBTestCase + + +class TestMapField(MongoDBTestCase): + + def test_mapfield(self): + """Ensure that the MapField handles the declared type.""" + class Simple(Document): + mapping = MapField(IntField()) + + Simple.drop_collection() + + e = Simple() + e.mapping['someint'] = 1 + e.save() + + with self.assertRaises(ValidationError): + e.mapping['somestring'] = "abc" + e.save() + + with self.assertRaises(ValidationError): + class NoDeclaredType(Document): + mapping = MapField() + + def test_complex_mapfield(self): + """Ensure that the MapField can handle complex declared types.""" + + class SettingBase(EmbeddedDocument): + meta = {"allow_inheritance": True} + + class StringSetting(SettingBase): + value = StringField() + + class IntegerSetting(SettingBase): + value = IntField() + + class Extensible(Document): + mapping = MapField(EmbeddedDocumentField(SettingBase)) + + Extensible.drop_collection() + + e = Extensible() + e.mapping['somestring'] = StringSetting(value='foo') + e.mapping['someint'] = IntegerSetting(value=42) + e.save() + + e2 = Extensible.objects.get(id=e.id) + self.assertIsInstance(e2.mapping['somestring'], StringSetting) + self.assertIsInstance(e2.mapping['someint'], IntegerSetting) + + with self.assertRaises(ValidationError): + e.mapping['someint'] = 123 + e.save() + + def test_embedded_mapfield_db_field(self): + class Embedded(EmbeddedDocument): + number = IntField(default=0, db_field='i') + + class Test(Document): + my_map = MapField(field=EmbeddedDocumentField(Embedded), + db_field='x') + + Test.drop_collection() + + test = Test() + test.my_map['DICTIONARY_KEY'] = Embedded(number=1) + test.save() + + Test.objects.update_one(inc__my_map__DICTIONARY_KEY__number=1) + + test = Test.objects.get() + self.assertEqual(test.my_map['DICTIONARY_KEY'].number, 2) + doc = self.db.test.find_one() + self.assertEqual(doc['x']['DICTIONARY_KEY']['i'], 2) + + def test_mapfield_numerical_index(self): + """Ensure that MapField accept numeric strings as indexes.""" + + class Embedded(EmbeddedDocument): + name = StringField() + + class Test(Document): + my_map = MapField(EmbeddedDocumentField(Embedded)) + + Test.drop_collection() + + test = Test() + test.my_map['1'] = Embedded(name='test') + test.save() + test.my_map['1'].name = 'test updated' + test.save() + + def test_map_field_lookup(self): + """Ensure MapField lookups succeed on Fields without a lookup + method. + """ + + class Action(EmbeddedDocument): + operation = StringField() + object = StringField() + + class Log(Document): + name = StringField() + visited = MapField(DateTimeField()) + actions = MapField(EmbeddedDocumentField(Action)) + + Log.drop_collection() + Log(name="wilson", visited={'friends': datetime.datetime.now()}, + actions={'friends': Action(operation='drink', object='beer')}).save() + + self.assertEqual(1, Log.objects( + visited__friends__exists=True).count()) + + self.assertEqual(1, Log.objects( + actions__friends__operation='drink', + actions__friends__object='beer').count()) + + def test_map_field_unicode(self): + class Info(EmbeddedDocument): + description = StringField() + value_list = ListField(field=StringField()) + + class BlogPost(Document): + info_dict = MapField(field=EmbeddedDocumentField(Info)) + + BlogPost.drop_collection() + + tree = BlogPost(info_dict={ + u"éééé": { + 'description': u"VALUE: éééé" + } + }) + + tree.save() + + self.assertEqual( + BlogPost.objects.get(id=tree.id).info_dict[u"éééé"].description, + u"VALUE: éééé" + ) diff --git a/tests/fields/test_reference_field.py b/tests/fields/test_reference_field.py new file mode 100644 index 00000000..5e1fc605 --- /dev/null +++ b/tests/fields/test_reference_field.py @@ -0,0 +1,219 @@ +# -*- coding: utf-8 -*- +from bson import SON, DBRef + +from mongoengine import * + +from tests.utils import MongoDBTestCase + + +class TestReferenceField(MongoDBTestCase): + def test_reference_validation(self): + """Ensure that invalid document objects cannot be assigned to + reference fields. + """ + + class User(Document): + name = StringField() + + class BlogPost(Document): + content = StringField() + author = ReferenceField(User) + + User.drop_collection() + BlogPost.drop_collection() + + # Make sure ReferenceField only accepts a document class or a string + # with a document class name. + self.assertRaises(ValidationError, ReferenceField, EmbeddedDocument) + + user = User(name='Test User') + + # Ensure that the referenced object must have been saved + post1 = BlogPost(content='Chips and gravy taste good.') + post1.author = user + self.assertRaises(ValidationError, post1.save) + + # Check that an invalid object type cannot be used + post2 = BlogPost(content='Chips and chilli taste good.') + post1.author = post2 + self.assertRaises(ValidationError, post1.validate) + + # Ensure ObjectID's are accepted as references + user_object_id = user.pk + post3 = BlogPost(content="Chips and curry sauce taste good.") + post3.author = user_object_id + post3.save() + + # Make sure referencing a saved document of the right type works + user.save() + post1.author = user + post1.save() + + # Make sure referencing a saved document of the *wrong* type fails + post2.save() + post1.author = post2 + self.assertRaises(ValidationError, post1.validate) + + def test_objectid_reference_fields(self): + """Make sure storing Object ID references works.""" + + class Person(Document): + name = StringField() + parent = ReferenceField('self') + + Person.drop_collection() + + p1 = Person(name="John").save() + Person(name="Ross", parent=p1.pk).save() + + p = Person.objects.get(name="Ross") + self.assertEqual(p.parent, p1) + + def test_dbref_reference_fields(self): + """Make sure storing references as bson.dbref.DBRef works.""" + + class Person(Document): + name = StringField() + parent = ReferenceField('self', dbref=True) + + Person.drop_collection() + + p1 = Person(name="John").save() + Person(name="Ross", parent=p1).save() + + self.assertEqual( + Person._get_collection().find_one({'name': 'Ross'})['parent'], + DBRef('person', p1.pk) + ) + + p = Person.objects.get(name="Ross") + self.assertEqual(p.parent, p1) + + def test_dbref_to_mongo(self): + """Make sure that calling to_mongo on a ReferenceField which + has dbref=False, but actually actually contains a DBRef returns + an ID of that DBRef. + """ + + class Person(Document): + name = StringField() + parent = ReferenceField('self', dbref=False) + + p = Person( + name='Steve', + parent=DBRef('person', 'abcdefghijklmnop') + ) + self.assertEqual(p.to_mongo(), SON([ + ('name', u'Steve'), + ('parent', 'abcdefghijklmnop') + ])) + + def test_objectid_reference_fields(self): + class Person(Document): + name = StringField() + parent = ReferenceField('self', dbref=False) + + Person.drop_collection() + + p1 = Person(name="John").save() + Person(name="Ross", parent=p1).save() + + col = Person._get_collection() + data = col.find_one({'name': 'Ross'}) + self.assertEqual(data['parent'], p1.pk) + + p = Person.objects.get(name="Ross") + self.assertEqual(p.parent, p1) + + def test_undefined_reference(self): + """Ensure that ReferenceFields may reference undefined Documents. + """ + class Product(Document): + name = StringField() + company = ReferenceField('Company') + + class Company(Document): + name = StringField() + + Product.drop_collection() + Company.drop_collection() + + ten_gen = Company(name='10gen') + ten_gen.save() + mongodb = Product(name='MongoDB', company=ten_gen) + mongodb.save() + + me = Product(name='MongoEngine') + me.save() + + obj = Product.objects(company=ten_gen).first() + self.assertEqual(obj, mongodb) + self.assertEqual(obj.company, ten_gen) + + obj = Product.objects(company=None).first() + self.assertEqual(obj, me) + + obj = Product.objects.get(company=None) + self.assertEqual(obj, me) + + def test_reference_query_conversion(self): + """Ensure that ReferenceFields can be queried using objects and values + of the type of the primary key of the referenced object. + """ + class Member(Document): + user_num = IntField(primary_key=True) + + class BlogPost(Document): + title = StringField() + author = ReferenceField(Member, dbref=False) + + Member.drop_collection() + BlogPost.drop_collection() + + m1 = Member(user_num=1) + m1.save() + m2 = Member(user_num=2) + m2.save() + + post1 = BlogPost(title='post 1', author=m1) + post1.save() + + post2 = BlogPost(title='post 2', author=m2) + post2.save() + + post = BlogPost.objects(author=m1).first() + self.assertEqual(post.id, post1.id) + + post = BlogPost.objects(author=m2).first() + self.assertEqual(post.id, post2.id) + + def test_reference_query_conversion_dbref(self): + """Ensure that ReferenceFields can be queried using objects and values + of the type of the primary key of the referenced object. + """ + class Member(Document): + user_num = IntField(primary_key=True) + + class BlogPost(Document): + title = StringField() + author = ReferenceField(Member, dbref=True) + + Member.drop_collection() + BlogPost.drop_collection() + + m1 = Member(user_num=1) + m1.save() + m2 = Member(user_num=2) + m2.save() + + post1 = BlogPost(title='post 1', author=m1) + post1.save() + + post2 = BlogPost(title='post 2', author=m2) + post2.save() + + post = BlogPost.objects(author=m1).first() + self.assertEqual(post.id, post1.id) + + post = BlogPost.objects(author=m2).first() + self.assertEqual(post.id, post2.id) diff --git a/tests/fields/test_sequence_field.py b/tests/fields/test_sequence_field.py new file mode 100644 index 00000000..6124c65e --- /dev/null +++ b/tests/fields/test_sequence_field.py @@ -0,0 +1,271 @@ +# -*- coding: utf-8 -*- + +from mongoengine import * + +from tests.utils import MongoDBTestCase + + +class TestSequenceField(MongoDBTestCase): + def test_sequence_field(self): + class Person(Document): + id = SequenceField(primary_key=True) + name = StringField() + + self.db['mongoengine.counters'].drop() + Person.drop_collection() + + for x in range(10): + Person(name="Person %s" % x).save() + + c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) + self.assertEqual(c['next'], 10) + + ids = [i.id for i in Person.objects] + self.assertEqual(ids, range(1, 11)) + + c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) + self.assertEqual(c['next'], 10) + + Person.id.set_next_value(1000) + c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) + self.assertEqual(c['next'], 1000) + + def test_sequence_field_get_next_value(self): + class Person(Document): + id = SequenceField(primary_key=True) + name = StringField() + + self.db['mongoengine.counters'].drop() + Person.drop_collection() + + for x in range(10): + Person(name="Person %s" % x).save() + + self.assertEqual(Person.id.get_next_value(), 11) + self.db['mongoengine.counters'].drop() + + self.assertEqual(Person.id.get_next_value(), 1) + + class Person(Document): + id = SequenceField(primary_key=True, value_decorator=str) + name = StringField() + + self.db['mongoengine.counters'].drop() + Person.drop_collection() + + for x in range(10): + Person(name="Person %s" % x).save() + + self.assertEqual(Person.id.get_next_value(), '11') + self.db['mongoengine.counters'].drop() + + self.assertEqual(Person.id.get_next_value(), '1') + + def test_sequence_field_sequence_name(self): + class Person(Document): + id = SequenceField(primary_key=True, sequence_name='jelly') + name = StringField() + + self.db['mongoengine.counters'].drop() + Person.drop_collection() + + for x in range(10): + Person(name="Person %s" % x).save() + + c = self.db['mongoengine.counters'].find_one({'_id': 'jelly.id'}) + self.assertEqual(c['next'], 10) + + ids = [i.id for i in Person.objects] + self.assertEqual(ids, range(1, 11)) + + c = self.db['mongoengine.counters'].find_one({'_id': 'jelly.id'}) + self.assertEqual(c['next'], 10) + + Person.id.set_next_value(1000) + c = self.db['mongoengine.counters'].find_one({'_id': 'jelly.id'}) + self.assertEqual(c['next'], 1000) + + def test_multiple_sequence_fields(self): + class Person(Document): + id = SequenceField(primary_key=True) + counter = SequenceField() + name = StringField() + + self.db['mongoengine.counters'].drop() + Person.drop_collection() + + for x in range(10): + Person(name="Person %s" % x).save() + + c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) + self.assertEqual(c['next'], 10) + + ids = [i.id for i in Person.objects] + self.assertEqual(ids, range(1, 11)) + + counters = [i.counter for i in Person.objects] + self.assertEqual(counters, range(1, 11)) + + c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) + self.assertEqual(c['next'], 10) + + Person.id.set_next_value(1000) + c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) + self.assertEqual(c['next'], 1000) + + Person.counter.set_next_value(999) + c = self.db['mongoengine.counters'].find_one({'_id': 'person.counter'}) + self.assertEqual(c['next'], 999) + + def test_sequence_fields_reload(self): + class Animal(Document): + counter = SequenceField() + name = StringField() + + self.db['mongoengine.counters'].drop() + Animal.drop_collection() + + a = Animal(name="Boi").save() + + self.assertEqual(a.counter, 1) + a.reload() + self.assertEqual(a.counter, 1) + + a.counter = None + self.assertEqual(a.counter, 2) + a.save() + + self.assertEqual(a.counter, 2) + + a = Animal.objects.first() + self.assertEqual(a.counter, 2) + a.reload() + self.assertEqual(a.counter, 2) + + def test_multiple_sequence_fields_on_docs(self): + class Animal(Document): + id = SequenceField(primary_key=True) + name = StringField() + + class Person(Document): + id = SequenceField(primary_key=True) + name = StringField() + + self.db['mongoengine.counters'].drop() + Animal.drop_collection() + Person.drop_collection() + + for x in range(10): + Animal(name="Animal %s" % x).save() + Person(name="Person %s" % x).save() + + c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) + self.assertEqual(c['next'], 10) + + c = self.db['mongoengine.counters'].find_one({'_id': 'animal.id'}) + self.assertEqual(c['next'], 10) + + ids = [i.id for i in Person.objects] + self.assertEqual(ids, range(1, 11)) + + id = [i.id for i in Animal.objects] + self.assertEqual(id, range(1, 11)) + + c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) + self.assertEqual(c['next'], 10) + + c = self.db['mongoengine.counters'].find_one({'_id': 'animal.id'}) + self.assertEqual(c['next'], 10) + + def test_sequence_field_value_decorator(self): + class Person(Document): + id = SequenceField(primary_key=True, value_decorator=str) + name = StringField() + + self.db['mongoengine.counters'].drop() + Person.drop_collection() + + for x in range(10): + p = Person(name="Person %s" % x) + p.save() + + c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) + self.assertEqual(c['next'], 10) + + ids = [i.id for i in Person.objects] + self.assertEqual(ids, map(str, range(1, 11))) + + c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) + self.assertEqual(c['next'], 10) + + def test_embedded_sequence_field(self): + class Comment(EmbeddedDocument): + id = SequenceField() + content = StringField(required=True) + + class Post(Document): + title = StringField(required=True) + comments = ListField(EmbeddedDocumentField(Comment)) + + self.db['mongoengine.counters'].drop() + Post.drop_collection() + + Post(title="MongoEngine", + comments=[Comment(content="NoSQL Rocks"), + Comment(content="MongoEngine Rocks")]).save() + c = self.db['mongoengine.counters'].find_one({'_id': 'comment.id'}) + self.assertEqual(c['next'], 2) + post = Post.objects.first() + self.assertEqual(1, post.comments[0].id) + self.assertEqual(2, post.comments[1].id) + + def test_inherited_sequencefield(self): + class Base(Document): + name = StringField() + counter = SequenceField() + meta = {'abstract': True} + + class Foo(Base): + pass + + class Bar(Base): + pass + + bar = Bar(name='Bar') + bar.save() + + foo = Foo(name='Foo') + foo.save() + + self.assertTrue('base.counter' in + self.db['mongoengine.counters'].find().distinct('_id')) + self.assertFalse(('foo.counter' or 'bar.counter') in + self.db['mongoengine.counters'].find().distinct('_id')) + self.assertNotEqual(foo.counter, bar.counter) + self.assertEqual(foo._fields['counter'].owner_document, Base) + self.assertEqual(bar._fields['counter'].owner_document, Base) + + def test_no_inherited_sequencefield(self): + class Base(Document): + name = StringField() + meta = {'abstract': True} + + class Foo(Base): + counter = SequenceField() + + class Bar(Base): + counter = SequenceField() + + bar = Bar(name='Bar') + bar.save() + + foo = Foo(name='Foo') + foo.save() + + self.assertFalse('base.counter' in + self.db['mongoengine.counters'].find().distinct('_id')) + self.assertTrue(('foo.counter' and 'bar.counter') in + self.db['mongoengine.counters'].find().distinct('_id')) + self.assertEqual(foo.counter, bar.counter) + self.assertEqual(foo._fields['counter'].owner_document, Foo) + self.assertEqual(bar._fields['counter'].owner_document, Bar) diff --git a/tests/fields/test_url_field.py b/tests/fields/test_url_field.py index 0447799e..ddbf707e 100644 --- a/tests/fields/test_url_field.py +++ b/tests/fields/test_url_field.py @@ -4,7 +4,7 @@ from mongoengine import * from tests.utils import MongoDBTestCase -class TestFloatField(MongoDBTestCase): +class TestURLField(MongoDBTestCase): def test_validation(self): """Ensure that URLFields validate urls properly.""" diff --git a/tests/fields/test_uuid_field.py b/tests/fields/test_uuid_field.py new file mode 100644 index 00000000..7b7faaf2 --- /dev/null +++ b/tests/fields/test_uuid_field.py @@ -0,0 +1,65 @@ +# -*- coding: utf-8 -*- +import uuid + +from mongoengine import * + +from tests.utils import MongoDBTestCase, get_as_pymongo + + +class Person(Document): + api_key = UUIDField(binary=False) + + +class TestUUIDField(MongoDBTestCase): + def test_storage(self): + uid = uuid.uuid4() + person = Person(api_key=uid).save() + self.assertEqual( + get_as_pymongo(person), + {'_id': person.id, + 'api_key': str(uid) + } + ) + + def test_field_string(self): + """Test UUID fields storing as String + """ + Person.drop_collection() + + uu = uuid.uuid4() + Person(api_key=uu).save() + self.assertEqual(1, Person.objects(api_key=uu).count()) + self.assertEqual(uu, Person.objects.first().api_key) + + person = Person() + valid = (uuid.uuid4(), uuid.uuid1()) + for api_key in valid: + person.api_key = api_key + person.validate() + + invalid = ('9d159858-549b-4975-9f98-dd2f987c113g', + '9d159858-549b-4975-9f98-dd2f987c113') + for api_key in invalid: + person.api_key = api_key + self.assertRaises(ValidationError, person.validate) + + def test_field_binary(self): + """Test UUID fields storing as Binary object.""" + Person.drop_collection() + + uu = uuid.uuid4() + Person(api_key=uu).save() + self.assertEqual(1, Person.objects(api_key=uu).count()) + self.assertEqual(uu, Person.objects.first().api_key) + + person = Person() + valid = (uuid.uuid4(), uuid.uuid1()) + for api_key in valid: + person.api_key = api_key + person.validate() + + invalid = ('9d159858-549b-4975-9f98-dd2f987c113g', + '9d159858-549b-4975-9f98-dd2f987c113') + for api_key in invalid: + person.api_key = api_key + self.assertRaises(ValidationError, person.validate) diff --git a/tests/utils.py b/tests/utils.py index 5345f75e..19936a54 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -13,7 +13,7 @@ MONGO_TEST_DB = 'mongoenginetest' # standard name for the test database # Constant that can be used to compare the version retrieved with # get_mongodb_version() MONGODB_26 = (2, 6) -MONGODB_3 = (3,0) +MONGODB_3 = (3, 0) MONGODB_32 = (3, 2) @@ -33,6 +33,11 @@ class MongoDBTestCase(unittest.TestCase): cls._connection.drop_database(MONGO_TEST_DB) +def get_as_pymongo(doc): + """Fetch the pymongo version of a certain Document""" + return doc.__class__.objects.as_pymongo().get(id=doc.id) + + def get_mongodb_version(): """Return the version of the connected mongoDB (first 2 digits) From 3cdb5b5db2f163c8a9b2404b020dbc83e573263f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Mon, 25 Feb 2019 22:29:44 +0100 Subject: [PATCH 08/71] fix poor assert's in tests --- tests/document/inheritance.py | 8 ++++---- tests/document/instance.py | 13 ++++++++----- tests/fields/fields.py | 8 ++++---- tests/fields/test_dict_field.py | 9 +++++---- tests/test_datastructures.py | 12 ++++++------ 5 files changed, 27 insertions(+), 23 deletions(-) diff --git a/tests/document/inheritance.py b/tests/document/inheritance.py index 9cc20c89..83c2a80a 100644 --- a/tests/document/inheritance.py +++ b/tests/document/inheritance.py @@ -38,12 +38,12 @@ class InheritanceTest(unittest.TestCase): meta = {'allow_inheritance': True} test_doc = DataDoc(name='test', embed=EmbedData(data='data')) - assert test_doc._cls == 'DataDoc' - assert test_doc.embed._cls == 'EmbedData' + self.assertEqual(test_doc._cls, 'DataDoc') + self.assertEqual(test_doc.embed._cls, 'EmbedData') test_doc.save() saved_doc = DataDoc.objects.with_id(test_doc.id) - assert test_doc._cls == saved_doc._cls - assert test_doc.embed._cls == saved_doc.embed._cls + self.assertEqual(test_doc._cls, saved_doc._cls) + self.assertEqual(test_doc.embed._cls, saved_doc.embed._cls) test_doc.delete() def test_superclasses(self): diff --git a/tests/document/instance.py b/tests/document/instance.py index cde18c9f..051eda68 100644 --- a/tests/document/instance.py +++ b/tests/document/instance.py @@ -806,7 +806,8 @@ class InstanceTest(MongoDBTestCase): doc2 = self.Person(name="jim", age=20).save() docs = [dict(doc1.to_mongo()), dict(doc2.to_mongo())] - assert not doc1.modify({'name': doc2.name}, set__age=100) + n_modified = doc1.modify({'name': doc2.name}, set__age=100) + self.assertEqual(n_modified, 0) self.assertDbEqual(docs) @@ -815,7 +816,8 @@ class InstanceTest(MongoDBTestCase): doc2 = self.Person(id=ObjectId(), name="jim", age=20) docs = [dict(doc1.to_mongo())] - assert not doc2.modify({'name': doc2.name}, set__age=100) + n_modified = doc2.modify({'name': doc2.name}, set__age=100) + self.assertEqual(n_modified, 0) self.assertDbEqual(docs) @@ -831,14 +833,15 @@ class InstanceTest(MongoDBTestCase): doc.job.name = "Google" doc.job.years = 3 - assert doc.modify( + n_modified = doc.modify( set__age=21, set__job__name="MongoDB", unset__job__years=True) + self.assertEqual(n_modified, 1) doc_copy.age = 21 doc_copy.job.name = "MongoDB" del doc_copy.job.years - assert doc.to_json() == doc_copy.to_json() - assert doc._get_changed_fields() == [] + self.assertEqual(doc.to_json(), doc_copy.to_json()) + self.assertEqual(doc._get_changed_fields(), []) self.assertDbEqual([dict(other_doc.to_mongo()), dict(doc.to_mongo())]) diff --git a/tests/fields/fields.py b/tests/fields/fields.py index c772b472..128936bf 100644 --- a/tests/fields/fields.py +++ b/tests/fields/fields.py @@ -2168,15 +2168,15 @@ class FieldTest(MongoDBTestCase): to_embed = ToEmbed(id=2, recursive=to_embed_recursive).save() doc = Doc(field=to_embed) doc.save() - assert isinstance(doc.field, ToEmbed) - assert doc.field == to_embed + self.assertIsInstance(doc.field, ToEmbed) + self.assertEqual(doc.field, to_embed) # Same thing with a Document with a _cls field to_embed_recursive = ToEmbedChild(id=1).save() to_embed_child = ToEmbedChild(id=2, recursive=to_embed_recursive).save() doc = Doc(field=to_embed_child) doc.save() - assert isinstance(doc.field, ToEmbedChild) - assert doc.field == to_embed_child + self.assertIsInstance(doc.field, ToEmbedChild) + self.assertEqual(doc.field, to_embed_child) def test_cls_field(self): class Animal(Document): diff --git a/tests/fields/test_dict_field.py b/tests/fields/test_dict_field.py index 2b9cecb7..a3b8ec6c 100644 --- a/tests/fields/test_dict_field.py +++ b/tests/fields/test_dict_field.py @@ -116,19 +116,20 @@ class TestDictField(MongoDBTestCase): id=2, recursive=to_embed_recursive.to_mongo().to_dict()).save() doc = Doc(field=to_embed.to_mongo().to_dict()) doc.save() - assert isinstance(doc.field, dict) - assert doc.field == {'_id': 2, 'recursive': {'_id': 1, 'recursive': {}}} + self.assertIsInstance(doc.field, dict) + self.assertEqual(doc.field, {'_id': 2, 'recursive': {'_id': 1, 'recursive': {}}}) # Same thing with a Document with a _cls field to_embed_recursive = ToEmbedChild(id=1).save() to_embed_child = ToEmbedChild( id=2, recursive=to_embed_recursive.to_mongo().to_dict()).save() doc = Doc(field=to_embed_child.to_mongo().to_dict()) doc.save() - assert isinstance(doc.field, dict) - assert doc.field == { + self.assertIsInstance(doc.field, dict) + expected = { '_id': 2, '_cls': 'ToEmbedParent.ToEmbedChild', 'recursive': {'_id': 1, '_cls': 'ToEmbedParent.ToEmbedChild', 'recursive': {}} } + self.assertEqual(doc.field, expected) def test_dictfield_strict(self): """Ensure that dict field handles validation if provided a strict field type.""" diff --git a/tests/test_datastructures.py b/tests/test_datastructures.py index 2f1277e6..4fb21d21 100644 --- a/tests/test_datastructures.py +++ b/tests/test_datastructures.py @@ -203,7 +203,7 @@ class TestBaseList(unittest.TestCase): def test___getitem__using_slice(self): base_list = self._get_baselist([0, 1, 2]) - self.assertEqual(base_list[1:3], [1,2]) + self.assertEqual(base_list[1:3], [1, 2]) self.assertEqual(base_list[0:3:2], [0, 2]) def test___getitem___using_slice_returns_list(self): @@ -218,7 +218,7 @@ class TestBaseList(unittest.TestCase): def test___getitem__sublist_returns_BaseList_bound_to_instance(self): base_list = self._get_baselist( [ - [1,2], + [1, 2], [3, 4] ] ) @@ -305,10 +305,10 @@ class TestBaseList(unittest.TestCase): self.assertEqual(base_list, [-1, 1, -2]) def test___setitem___with_slice(self): - base_list = self._get_baselist([0,1,2,3,4,5]) + base_list = self._get_baselist([0, 1, 2, 3, 4, 5]) base_list[0:6:2] = [None, None, None] self.assertEqual(base_list._instance._changed_fields, ['my_name']) - self.assertEqual(base_list, [None,1,None,3,None,5]) + self.assertEqual(base_list, [None, 1, None, 3, None, 5]) def test___setitem___item_0_calls_mark_as_changed(self): base_list = self._get_baselist([True]) @@ -426,8 +426,8 @@ class TestStrictDict(unittest.TestCase): def test_mappings_protocol(self): d = self.dtype(a=1, b=2) - assert dict(d) == {'a': 1, 'b': 2} - assert dict(**d) == {'a': 1, 'b': 2} + self.assertEqual(dict(d), {'a': 1, 'b': 2}) + self.assertEqual(dict(**d), {'a': 1, 'b': 2}) if __name__ == '__main__': From c60c2ee8d0c90c37cadfa2fbcfdf054eca2090af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Mon, 25 Feb 2019 22:33:36 +0100 Subject: [PATCH 09/71] fix minor styling issue in tests --- tests/document/class_methods.py | 24 ++++++++++++------------ tests/document/delta.py | 1 + tests/document/json_serialisation.py | 6 +++--- tests/fields/file_tests.py | 1 + tests/fixtures.py | 1 + tests/queryset/field_list.py | 4 ++-- tests/queryset/geo.py | 4 ++-- tests/queryset/pickable.py | 11 ++++++----- tests/queryset/visitor.py | 1 - tests/test_connection.py | 8 ++++---- tests/test_context_managers.py | 1 + tests/test_replicaset_connection.py | 1 + tests/test_signals.py | 4 ++-- 13 files changed, 36 insertions(+), 31 deletions(-) diff --git a/tests/document/class_methods.py b/tests/document/class_methods.py index 2632d38f..88937ec8 100644 --- a/tests/document/class_methods.py +++ b/tests/document/class_methods.py @@ -66,7 +66,7 @@ class ClassMethodsTest(unittest.TestCase): """ collection_name = 'person' self.Person(name='Test').save() - self.assertIn(collection_name, self.db.collection_names()) + self.assertIn(collection_name, self.db.collection_names()) self.Person.drop_collection() self.assertNotIn(collection_name, self.db.collection_names()) @@ -102,16 +102,16 @@ class ClassMethodsTest(unittest.TestCase): BlogPost.drop_collection() BlogPost.ensure_indexes() - self.assertEqual(BlogPost.compare_indexes(), { 'missing': [], 'extra': [] }) + self.assertEqual(BlogPost.compare_indexes(), {'missing': [], 'extra': []}) BlogPost.ensure_index(['author', 'description']) - self.assertEqual(BlogPost.compare_indexes(), { 'missing': [], 'extra': [[('author', 1), ('description', 1)]] }) + self.assertEqual(BlogPost.compare_indexes(), {'missing': [], 'extra': [[('author', 1), ('description', 1)]]}) BlogPost._get_collection().drop_index('author_1_description_1') - self.assertEqual(BlogPost.compare_indexes(), { 'missing': [], 'extra': [] }) + self.assertEqual(BlogPost.compare_indexes(), {'missing': [], 'extra': []}) BlogPost._get_collection().drop_index('author_1_title_1') - self.assertEqual(BlogPost.compare_indexes(), { 'missing': [[('author', 1), ('title', 1)]], 'extra': [] }) + self.assertEqual(BlogPost.compare_indexes(), {'missing': [[('author', 1), ('title', 1)]], 'extra': []}) def test_compare_indexes_inheritance(self): """ Ensure that the indexes are properly created and that @@ -140,16 +140,16 @@ class ClassMethodsTest(unittest.TestCase): BlogPost.ensure_indexes() BlogPostWithTags.ensure_indexes() - self.assertEqual(BlogPost.compare_indexes(), { 'missing': [], 'extra': [] }) + self.assertEqual(BlogPost.compare_indexes(), {'missing': [], 'extra': []}) BlogPostWithTags.ensure_index(['author', 'tag_list']) - self.assertEqual(BlogPost.compare_indexes(), { 'missing': [], 'extra': [[('_cls', 1), ('author', 1), ('tag_list', 1)]] }) + self.assertEqual(BlogPost.compare_indexes(), {'missing': [], 'extra': [[('_cls', 1), ('author', 1), ('tag_list', 1)]]}) BlogPostWithTags._get_collection().drop_index('_cls_1_author_1_tag_list_1') - self.assertEqual(BlogPost.compare_indexes(), { 'missing': [], 'extra': [] }) + self.assertEqual(BlogPost.compare_indexes(), {'missing': [], 'extra': []}) BlogPostWithTags._get_collection().drop_index('_cls_1_author_1_tags_1') - self.assertEqual(BlogPost.compare_indexes(), { 'missing': [[('_cls', 1), ('author', 1), ('tags', 1)]], 'extra': [] }) + self.assertEqual(BlogPost.compare_indexes(), {'missing': [[('_cls', 1), ('author', 1), ('tags', 1)]], 'extra': []}) def test_compare_indexes_multiple_subclasses(self): """ Ensure that compare_indexes behaves correctly if called from a @@ -184,9 +184,9 @@ class ClassMethodsTest(unittest.TestCase): BlogPostWithTags.ensure_indexes() BlogPostWithCustomField.ensure_indexes() - self.assertEqual(BlogPost.compare_indexes(), { 'missing': [], 'extra': [] }) - self.assertEqual(BlogPostWithTags.compare_indexes(), { 'missing': [], 'extra': [] }) - self.assertEqual(BlogPostWithCustomField.compare_indexes(), { 'missing': [], 'extra': [] }) + self.assertEqual(BlogPost.compare_indexes(), {'missing': [], 'extra': []}) + self.assertEqual(BlogPostWithTags.compare_indexes(), {'missing': [], 'extra': []}) + self.assertEqual(BlogPostWithCustomField.compare_indexes(), {'missing': [], 'extra': []}) @requires_mongodb_gte_26 def test_compare_indexes_for_text_indexes(self): diff --git a/tests/document/delta.py b/tests/document/delta.py index 30296956..942e3a0a 100644 --- a/tests/document/delta.py +++ b/tests/document/delta.py @@ -863,5 +863,6 @@ class DeltaTest(unittest.TestCase): self.assertEqual('oops', delta[0]["users.007.rolist"][0]["type"]) self.assertEqual(uinfo.id, delta[0]["users.007.info"]) + if __name__ == '__main__': unittest.main() diff --git a/tests/document/json_serialisation.py b/tests/document/json_serialisation.py index 110f1e14..7c785ab2 100644 --- a/tests/document/json_serialisation.py +++ b/tests/document/json_serialisation.py @@ -32,12 +32,12 @@ class TestJson(unittest.TestCase): string = StringField(db_field='s') embedded = EmbeddedDocumentField(Embedded, db_field='e') - doc = Doc( string="Hello", embedded=Embedded(string="Inner Hello")) - doc_json = doc.to_json(sort_keys=True, use_db_field=False,separators=(',', ':')) + doc = Doc(string="Hello", embedded=Embedded(string="Inner Hello")) + doc_json = doc.to_json(sort_keys=True, use_db_field=False, separators=(',', ':')) expected_json = """{"embedded":{"string":"Inner Hello"},"string":"Hello"}""" - self.assertEqual( doc_json, expected_json) + self.assertEqual(doc_json, expected_json) def test_json_simple(self): diff --git a/tests/fields/file_tests.py b/tests/fields/file_tests.py index 213e889c..76e20bb9 100644 --- a/tests/fields/file_tests.py +++ b/tests/fields/file_tests.py @@ -578,5 +578,6 @@ class FileTest(MongoDBTestCase): self.assertEqual(marmot.photos[0].foo, 'bar') self.assertEqual(marmot.photos[0].get().length, 8313) + if __name__ == '__main__': unittest.main() diff --git a/tests/fixtures.py b/tests/fixtures.py index d8eb8487..b8303b99 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -48,6 +48,7 @@ class PickleSignalsTest(Document): def post_delete(self, sender, document, **kwargs): pickled = pickle.dumps(document) + signals.post_save.connect(PickleSignalsTest.post_save, sender=PickleSignalsTest) signals.post_delete.connect(PickleSignalsTest.post_delete, sender=PickleSignalsTest) diff --git a/tests/queryset/field_list.py b/tests/queryset/field_list.py index b111238a..250e2601 100644 --- a/tests/queryset/field_list.py +++ b/tests/queryset/field_list.py @@ -208,7 +208,7 @@ class OnlyExcludeAllTest(unittest.TestCase): BlogPost.drop_collection() - post = BlogPost(content='Had a good coffee today...', various={'test_dynamic':{'some': True}}) + post = BlogPost(content='Had a good coffee today...', various={'test_dynamic': {'some': True}}) post.author = User(name='Test User') post.comments = [Comment(title='I aggree', text='Great post!'), Comment(title='Coffee', text='I hate coffee')] post.save() @@ -413,7 +413,6 @@ class OnlyExcludeAllTest(unittest.TestCase): numbers = Numbers.objects.fields(embedded__n={"$slice": [-5, 10]}).get() self.assertEqual(numbers.embedded.n, [-5, -4, -3, -2, -1]) - def test_exclude_from_subclasses_docs(self): class Base(Document): @@ -436,5 +435,6 @@ class OnlyExcludeAllTest(unittest.TestCase): self.assertRaises(LookUpError, Base.objects.exclude, "made_up") + if __name__ == '__main__': unittest.main() diff --git a/tests/queryset/geo.py b/tests/queryset/geo.py index fd8c9b0f..240a94ab 100644 --- a/tests/queryset/geo.py +++ b/tests/queryset/geo.py @@ -534,11 +534,11 @@ class GeoQueriesTest(MongoDBTestCase): Location.drop_collection() - Location(loc=[1,2]).save() + Location(loc=[1, 2]).save() loc = Location.objects.as_pymongo()[0] self.assertEqual(loc["loc"], {"type": "Point", "coordinates": [1, 2]}) - Location.objects.update(set__loc=[2,1]) + Location.objects.update(set__loc=[2, 1]) loc = Location.objects.as_pymongo()[0] self.assertEqual(loc["loc"], {"type": "Point", "coordinates": [2, 1]}) diff --git a/tests/queryset/pickable.py b/tests/queryset/pickable.py index d96e7dc6..bf7bb31c 100644 --- a/tests/queryset/pickable.py +++ b/tests/queryset/pickable.py @@ -6,10 +6,12 @@ from mongoengine.connection import connect __author__ = 'stas' + class Person(Document): name = StringField() age = IntField() + class TestQuerysetPickable(unittest.TestCase): """ Test for adding pickling support for QuerySet instances @@ -18,7 +20,7 @@ class TestQuerysetPickable(unittest.TestCase): def setUp(self): super(TestQuerysetPickable, self).setUp() - connection = connect(db="test") #type: pymongo.mongo_client.MongoClient + connection = connect(db="test") # type: pymongo.mongo_client.MongoClient connection.drop_database("test") @@ -27,7 +29,6 @@ class TestQuerysetPickable(unittest.TestCase): age=21 ) - def test_picke_simple_qs(self): qs = Person.objects.all() @@ -46,10 +47,10 @@ class TestQuerysetPickable(unittest.TestCase): self.assertEqual(qs.count(), loadedQs.count()) - #can update loadedQs + # can update loadedQs loadedQs.update(age=23) - #check + # check self.assertEqual(Person.objects.first().age, 23) def test_pickle_support_filtration(self): @@ -70,7 +71,7 @@ class TestQuerysetPickable(unittest.TestCase): self.assertEqual(loaded.count(), 2) self.assertEqual(loaded.filter(name="Bob").first().age, 23) - + diff --git a/tests/queryset/visitor.py b/tests/queryset/visitor.py index 8261faae..22d274a8 100644 --- a/tests/queryset/visitor.py +++ b/tests/queryset/visitor.py @@ -275,7 +275,6 @@ class QTest(unittest.TestCase): with self.assertRaises(InvalidQueryError): self.Person.objects.filter('user1') - def test_q_regex(self): """Ensure that Q objects can be queried using regexes. """ diff --git a/tests/test_connection.py b/tests/test_connection.py index 88d63cdb..7c4fc4cf 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -99,11 +99,11 @@ class ConnectionTest(unittest.TestCase): conn = get_connection() self.assertIsInstance(conn, mongomock.MongoClient) - connect(host=['mongodb://localhost'], is_mock=True, alias='testdb2') + connect(host=['mongodb://localhost'], is_mock=True, alias='testdb2') conn = get_connection('testdb2') self.assertIsInstance(conn, mongomock.MongoClient) - connect(host=['localhost'], is_mock=True, alias='testdb3') + connect(host=['localhost'], is_mock=True, alias='testdb3') conn = get_connection('testdb3') self.assertIsInstance(conn, mongomock.MongoClient) @@ -111,11 +111,11 @@ class ConnectionTest(unittest.TestCase): conn = get_connection('testdb4') self.assertIsInstance(conn, mongomock.MongoClient) - connect(host=['mongodb://localhost:27017', 'mongodb://localhost:27018'], is_mock=True, alias='testdb5') + connect(host=['mongodb://localhost:27017', 'mongodb://localhost:27018'], is_mock=True, alias='testdb5') conn = get_connection('testdb5') self.assertIsInstance(conn, mongomock.MongoClient) - connect(host=['localhost:27017', 'localhost:27018'], is_mock=True, alias='testdb6') + connect(host=['localhost:27017', 'localhost:27018'], is_mock=True, alias='testdb6') conn = get_connection('testdb6') self.assertIsInstance(conn, mongomock.MongoClient) diff --git a/tests/test_context_managers.py b/tests/test_context_managers.py index 8fb7bc78..8207cd89 100644 --- a/tests/test_context_managers.py +++ b/tests/test_context_managers.py @@ -302,5 +302,6 @@ class ContextManagersTest(unittest.TestCase): _ = db.system.indexes.find_one() # queries on db.system.indexes are ignored as well self.assertEqual(q, 1) + if __name__ == '__main__': unittest.main() diff --git a/tests/test_replicaset_connection.py b/tests/test_replicaset_connection.py index a53f5903..4aa647d6 100644 --- a/tests/test_replicaset_connection.py +++ b/tests/test_replicaset_connection.py @@ -47,5 +47,6 @@ class ConnectionTest(unittest.TestCase): self.assertEqual(conn.read_preference, READ_PREF) + if __name__ == '__main__': unittest.main() diff --git a/tests/test_signals.py b/tests/test_signals.py index df687d0e..f3b6e33c 100644 --- a/tests/test_signals.py +++ b/tests/test_signals.py @@ -39,7 +39,6 @@ class SignalTests(unittest.TestCase): def post_init(cls, sender, document, **kwargs): signal_output.append('post_init signal, %s, document._created = %s' % (document, document._created)) - @classmethod def pre_save(cls, sender, document, **kwargs): signal_output.append('pre_save signal, %s' % document) @@ -247,7 +246,7 @@ class SignalTests(unittest.TestCase): def load_existing_author(): a = self.Author(name='Bill Shakespeare') a.save() - self.get_signal_output(lambda: None) # eliminate signal output + self.get_signal_output(lambda: None) # eliminate signal output a1 = self.Author.objects(name='Bill Shakespeare')[0] self.assertEqual(self.get_signal_output(create_author), [ @@ -431,5 +430,6 @@ class SignalTests(unittest.TestCase): {} ]) + if __name__ == '__main__': unittest.main() From dca837b843272385e7be80c7d6c1c037dc38393d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Tue, 2 Oct 2018 21:26:14 +0200 Subject: [PATCH 10/71] Add suport for Mongo 3.4 (travis, fix tests) --- .install_mongodb_on_travis.sh | 8 +++++- .travis.yml | 9 +++---- README.rst | 4 +-- docs/changelog.rst | 1 + mongoengine/mongodb_support.py | 21 ++++++++++++++++ tests/document/indexes.py | 26 +++++++++++++++---- tests/fields/fields.py | 4 +-- tests/queryset/queryset.py | 25 +++++++++--------- tests/utils.py | 46 +++++++++++++++++----------------- 9 files changed, 94 insertions(+), 50 deletions(-) create mode 100644 mongoengine/mongodb_support.py diff --git a/.install_mongodb_on_travis.sh b/.install_mongodb_on_travis.sh index 057ccf74..6ac2e364 100644 --- a/.install_mongodb_on_travis.sh +++ b/.install_mongodb_on_travis.sh @@ -19,8 +19,14 @@ elif [ "$MONGODB" = "3.2" ]; then sudo apt-get update sudo apt-get install mongodb-org-server=3.2.20 # service should be started automatically +elif [ "$MONGODB" = "3.4" ]; then + sudo apt-key adv --keyserver keyserver.ubuntu.com:80 --recv 0C49F3730359A14518585931BC711F9BA15703C6 + echo "deb http://repo.mongodb.org/apt/ubuntu trusty/mongodb-org/3.4 multiverse" | sudo tee /etc/apt/sources.list.d/mongodb-org-3.4.list + sudo apt-get update + sudo apt-get install mongodb-org-server=3.4.17 + # service should be started automatically else - echo "Invalid MongoDB version, expected 2.6, 3.0, or 3.2" + echo "Invalid MongoDB version, expected 2.6, 3.0, 3.2 or 3.4." exit 1 fi; diff --git a/.travis.yml b/.travis.yml index 4f77f4e0..a73391aa 100644 --- a/.travis.yml +++ b/.travis.yml @@ -4,8 +4,9 @@ # combinations: # * MongoDB v2.6 is currently the "main" version tested against Python v2.7, # v3.5, v3.6, PyPy, and PyMongo v3.x. -# * MongoDB v3.0 & v3.2 are tested against Python v2.7, v3.5 & v3.6 +# * MongoDB v3.0, v3.2 are tested against Python v2.7, v3.5 & v3.6 # and Pymongo v3.5 & v3.x +# * MongoDB v3.4 is tested against v3.6 and Pymongo v3.x # Reminder: Update README.rst if you change MongoDB versions we test. language: python @@ -26,16 +27,14 @@ matrix: include: - python: 2.7 env: MONGODB=3.0 PYMONGO=3.5 - - python: 2.7 - env: MONGODB=3.2 PYMONGO=3.x - - python: 3.5 - env: MONGODB=3.0 PYMONGO=3.5 - python: 3.5 env: MONGODB=3.2 PYMONGO=3.x - python: 3.6 env: MONGODB=3.0 PYMONGO=3.5 - python: 3.6 env: MONGODB=3.2 PYMONGO=3.x + - python: 3.6 + env: MONGODB=3.4 PYMONGO=3.x before_install: - bash .install_mongodb_on_travis.sh diff --git a/README.rst b/README.rst index 4e186a85..f0309170 100644 --- a/README.rst +++ b/README.rst @@ -26,7 +26,7 @@ an `API reference `_. Supported MongoDB Versions ========================== -MongoEngine is currently tested against MongoDB v2.6, v3.0 and v3.2. Future +MongoEngine is currently tested against MongoDB v2.6, v3.0, v3.2 and v3.4. Future versions should be supported as well, but aren't actively tested at the moment. Make sure to open an issue or submit a pull request if you experience any problems with MongoDB v3.4+. @@ -36,7 +36,7 @@ Installation We recommend the use of `virtualenv `_ and of `pip `_. You can then use ``pip install -U mongoengine``. You may also have `setuptools `_ -and thus you can use ``easy_install -U mongoengine``. Another option is +and thus you can use ``easy_install -U mongoengine``. Another option is `pipenv `_. You can then use ``pipenv install mongoengine`` to both create the virtual environment and install the package. Otherwise, you can download the source from `GitHub `_ and diff --git a/docs/changelog.rst b/docs/changelog.rst index ae734448..c87fd1e9 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -10,6 +10,7 @@ Development - Document a BREAKING CHANGE introduced in 0.15.3 and not reported at that time (#1995) - Fix InvalidStringData error when using modify on a BinaryField #1127 - DEPRECATION: `EmbeddedDocument.save` & `.reload` are marked as deprecated and will be removed in a next version of mongoengine #1552 +- Fix test suite and CI to support MongoDB 3.4 #1445 ================= Changes in 0.16.3 diff --git a/mongoengine/mongodb_support.py b/mongoengine/mongodb_support.py new file mode 100644 index 00000000..b5f3bdc8 --- /dev/null +++ b/mongoengine/mongodb_support.py @@ -0,0 +1,21 @@ +""" +Helper functions, constants, and types to aid with MongoDB v3.x support +""" +from mongoengine.connection import get_connection + + +# Constant that can be used to compare the version retrieved with +# get_mongodb_version() +MONGODB_34 = (3, 4) +MONGODB_32 = (3, 2) +MONGODB_3 = (3, 0) +MONGODB_26 = (2, 6) + + +def get_mongodb_version(): + """Return the version of the connected mongoDB (first 2 digits) + + :return: tuple(int, int) + """ + version_list = get_connection().server_info()['versionArray'][:2] # e.g: (3, 2) + return tuple(version_list) diff --git a/tests/document/indexes.py b/tests/document/indexes.py index 57f48587..a21b45bc 100644 --- a/tests/document/indexes.py +++ b/tests/document/indexes.py @@ -9,7 +9,8 @@ from six import iteritems from mongoengine import * from mongoengine.connection import get_db -from tests.utils import get_mongodb_version, requires_mongodb_gte_26, MONGODB_32, MONGODB_3 +from mongoengine.mongodb_support import get_mongodb_version, MONGODB_32, MONGODB_3 +from tests.utils import requires_mongodb_gte_26, requires_mongodb_lte_32, requires_mongodb_gte_34 __all__ = ("IndexesTest", ) @@ -477,6 +478,7 @@ class IndexesTest(unittest.TestCase): def test_covered_index(self): """Ensure that covered indexes can be used """ + IS_MONGODB_3 = get_mongodb_version() >= MONGODB_3 class Test(Document): a = IntField() @@ -492,8 +494,6 @@ class IndexesTest(unittest.TestCase): obj = Test(a=1) obj.save() - IS_MONGODB_3 = get_mongodb_version() >= MONGODB_3 - # Need to be explicit about covered indexes as mongoDB doesn't know if # the documents returned might have more keys in that here. query_plan = Test.objects(id=obj.id).exclude('a').explain() @@ -569,7 +569,7 @@ class IndexesTest(unittest.TestCase): if pymongo.version != '3.0': self.assertEqual(BlogPost.objects.hint([('tags', 1)]).count(), 10) - if MONGO_VER == MONGODB_32: + if MONGO_VER >= MONGODB_32: # Mongo32 throws an error if an index exists (i.e `tags` in our case) # and you use hint on an index name that does not exist with self.assertRaises(OperationFailure): @@ -601,6 +601,22 @@ class IndexesTest(unittest.TestCase): # Ensure backwards compatibilty for errors self.assertRaises(OperationError, post2.save) + @requires_mongodb_gte_34 + def test_primary_key_unique_not_working_under_mongo_34(self): + class Blog(Document): + id = StringField(primary_key=True, unique=True) + + with self.assertRaises(OperationFailure) as ctx_err: + Blog(id='garbage').save() + self.assertIn("The field 'unique' is not valid for an _id index specification", str(ctx_err.exception)) + + @requires_mongodb_lte_32 + def test_primary_key_unique_working_under_mongo_32(self): + class Blog(Document): + id = StringField(primary_key=True, unique=True) + + Blog(id='garbage').save() + def test_unique_with(self): """Ensure that unique_with constraints are applied to fields. """ @@ -760,7 +776,7 @@ class IndexesTest(unittest.TestCase): You won't create a duplicate but you will update an existing document. """ class User(Document): - name = StringField(primary_key=True, unique=True) + name = StringField(primary_key=True) password = StringField() User.drop_collection() diff --git a/tests/fields/fields.py b/tests/fields/fields.py index 128936bf..2c4ac3ac 100644 --- a/tests/fields/fields.py +++ b/tests/fields/fields.py @@ -2113,7 +2113,7 @@ class FieldTest(MongoDBTestCase): field_1 = StringField(db_field='f') class Doc(Document): - my_id = IntField(required=True, unique=True, primary_key=True) + my_id = IntField(primary_key=True) embed_me = DynamicField(db_field='e') field_x = StringField(db_field='x') @@ -2135,7 +2135,7 @@ class FieldTest(MongoDBTestCase): field_1 = StringField(db_field='f') class Doc(Document): - my_id = IntField(required=True, unique=True, primary_key=True) + my_id = IntField(primary_key=True) embed_me = DynamicField(db_field='e') field_x = StringField(db_field='x') diff --git a/tests/queryset/queryset.py b/tests/queryset/queryset.py index 3d8e8960..05a7ca75 100644 --- a/tests/queryset/queryset.py +++ b/tests/queryset/queryset.py @@ -18,11 +18,12 @@ from mongoengine import * from mongoengine.connection import get_connection, get_db from mongoengine.context_managers import query_counter, switch_db from mongoengine.errors import InvalidQueryError +from mongoengine.mongodb_support import get_mongodb_version, MONGODB_32 from mongoengine.python_support import IS_PYMONGO_3 from mongoengine.queryset import (DoesNotExist, MultipleObjectsReturned, QuerySet, QuerySetManager, queryset_manager) -from tests.utils import requires_mongodb_gte_26, skip_pymongo3, get_mongodb_version, MONGODB_32 +from tests.utils import requires_mongodb_gte_26, skip_pymongo3 __all__ = ("QuerySetTest",) @@ -852,8 +853,8 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(q, 0) Blog.objects.insert(blogs, load_bulk=False) - if MONGO_VER == MONGODB_32: - self.assertEqual(q, 1) # 1 entry containing the list of inserts + if MONGO_VER >= MONGODB_32: + self.assertEqual(q, 1) # 1 entry containing the list of inserts else: self.assertEqual(q, len(blogs)) # 1 entry per doc inserted @@ -869,8 +870,8 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(q, 0) Blog.objects.insert(blogs) - if MONGO_VER == MONGODB_32: - self.assertEqual(q, 2) # 1 for insert 1 for fetch + if MONGO_VER >= MONGODB_32: + self.assertEqual(q, 2) # 1 for insert 1 for fetch else: self.assertEqual(q, len(blogs)+1) # + 1 to fetch all docs @@ -1204,7 +1205,7 @@ class QuerySetTest(unittest.TestCase): """Ensure filters can be chained together. """ class Blog(Document): - id = StringField(unique=True, primary_key=True) + id = StringField(primary_key=True) class BlogPost(Document): blog = ReferenceField(Blog) @@ -1316,7 +1317,7 @@ class QuerySetTest(unittest.TestCase): order_by() w/o any arguments. """ MONGO_VER = self.mongodb_version - ORDER_BY_KEY = 'sort' if MONGO_VER == MONGODB_32 else '$orderby' + ORDER_BY_KEY = 'sort' if MONGO_VER >= MONGODB_32 else '$orderby' class BlogPost(Document): title = StringField() @@ -2524,8 +2525,8 @@ class QuerySetTest(unittest.TestCase): def test_comment(self): """Make sure adding a comment to the query gets added to the query""" MONGO_VER = self.mongodb_version - QUERY_KEY = 'filter' if MONGO_VER == MONGODB_32 else '$query' - COMMENT_KEY = 'comment' if MONGO_VER == MONGODB_32 else '$comment' + QUERY_KEY = 'filter' if MONGO_VER >= MONGODB_32 else '$query' + COMMENT_KEY = 'comment' if MONGO_VER >= MONGODB_32 else '$comment' class User(Document): age = IntField() @@ -3349,7 +3350,7 @@ class QuerySetTest(unittest.TestCase): meta = {'indexes': [ {'fields': ['$title', "$content"], 'default_language': 'portuguese', - 'weight': {'title': 10, 'content': 2} + 'weights': {'title': 10, 'content': 2} } ]} @@ -5131,7 +5132,7 @@ class QuerySetTest(unittest.TestCase): def test_query_reference_to_custom_pk_doc(self): class A(Document): - id = StringField(unique=True, primary_key=True) + id = StringField(primary_key=True) class B(Document): a = ReferenceField(A) @@ -5236,7 +5237,7 @@ class QuerySetTest(unittest.TestCase): def test_bool_with_ordering(self): MONGO_VER = self.mongodb_version - ORDER_BY_KEY = 'sort' if MONGO_VER == MONGODB_32 else '$orderby' + ORDER_BY_KEY = 'sort' if MONGO_VER >= MONGODB_32 else '$orderby' class Person(Document): name = StringField() diff --git a/tests/utils.py b/tests/utils.py index 19936a54..e0380dbc 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,22 +1,17 @@ +import operator import unittest from nose.plugins.skip import SkipTest from mongoengine import connect -from mongoengine.connection import get_db, get_connection +from mongoengine.connection import get_db +from mongoengine.mongodb_support import get_mongodb_version, MONGODB_26, MONGODB_3, MONGODB_32, MONGODB_34 from mongoengine.python_support import IS_PYMONGO_3 MONGO_TEST_DB = 'mongoenginetest' # standard name for the test database -# Constant that can be used to compare the version retrieved with -# get_mongodb_version() -MONGODB_26 = (2, 6) -MONGODB_3 = (3, 0) -MONGODB_32 = (3, 2) - - class MongoDBTestCase(unittest.TestCase): """Base class for tests that need a mongodb connection It ensures that the db is clean at the beginning and dropped at the end automatically @@ -38,34 +33,39 @@ def get_as_pymongo(doc): return doc.__class__.objects.as_pymongo().get(id=doc.id) -def get_mongodb_version(): - """Return the version of the connected mongoDB (first 2 digits) - - :return: tuple(int, int) - """ - version_list = get_connection().server_info()['versionArray'][:2] # e.g: (3, 2) - return tuple(version_list) - - -def _decorated_with_ver_requirement(func, version): +def _decorated_with_ver_requirement(func, mongo_version_req, oper=operator.ge): """Return a given function decorated with the version requirement for a particular MongoDB version tuple. - :param version: The version required (tuple(int, int)) + :param mongo_version_req: The mongodb version requirement (tuple(int, int)) + :param oper: The operator to apply """ def _inner(*args, **kwargs): - MONGODB_V = get_mongodb_version() - if MONGODB_V >= version: + mongodb_v = get_mongodb_version() + if oper(mongodb_v, mongo_version_req): return func(*args, **kwargs) - raise SkipTest('Needs MongoDB v{}+'.format('.'.join(str(n) for n in version))) + raise SkipTest('Needs MongoDB v{}+'.format('.'.join(str(n) for n in mongo_version_req))) _inner.__name__ = func.__name__ _inner.__doc__ = func.__doc__ - return _inner +def requires_mongodb_gte_34(func): + """Raise a SkipTest exception if we're working with MongoDB version + lower than v3.4 + """ + return _decorated_with_ver_requirement(func, MONGODB_34) + + +def requires_mongodb_lte_32(func): + """Raise a SkipTest exception if we're working with MongoDB version + greater than v3.2. + """ + return _decorated_with_ver_requirement(func, MONGODB_32, oper=operator.le) + + def requires_mongodb_gte_26(func): """Raise a SkipTest exception if we're working with MongoDB version lower than v2.6. From 7247b9b68ec546c4f0f43b01455dc1c6545492f5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Sun, 24 Feb 2019 22:34:17 +0100 Subject: [PATCH 11/71] additional fixes to support Mongo3.4 --- .travis.yml | 2 +- mongoengine/document.py | 4 ++-- mongoengine/mongodb_support.py | 2 +- tests/document/indexes.py | 6 ++++++ tests/utils.py | 10 +++++----- 5 files changed, 15 insertions(+), 9 deletions(-) diff --git a/.travis.yml b/.travis.yml index a73391aa..64086357 100644 --- a/.travis.yml +++ b/.travis.yml @@ -4,7 +4,7 @@ # combinations: # * MongoDB v2.6 is currently the "main" version tested against Python v2.7, # v3.5, v3.6, PyPy, and PyMongo v3.x. -# * MongoDB v3.0, v3.2 are tested against Python v2.7, v3.5 & v3.6 +# * MongoDB v3.0 & v3.2 are tested against Python v2.7, v3.5 & v3.6 # and Pymongo v3.5 & v3.x # * MongoDB v3.4 is tested against v3.6 and Pymongo v3.x # Reminder: Update README.rst if you change MongoDB versions we test. diff --git a/mongoengine/document.py b/mongoengine/document.py index 84c1d699..b671ac6b 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -808,7 +808,7 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): db.drop_collection(col_name) @classmethod - def create_index(cls, keys, background=False, **kwargs): + def _create_index(cls, keys, background=False, **kwargs): """Creates the given indexes if required. :param keys: a single index key or a list of index keys (to @@ -851,7 +851,7 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): warnings.warn(msg, DeprecationWarning) elif not IS_PYMONGO_3: kwargs.update({'drop_dups': drop_dups}) - return cls.create_index(key_or_list, background=background, **kwargs) + return cls._create_index(key_or_list, background=background, **kwargs) @classmethod def ensure_indexes(cls): diff --git a/mongoengine/mongodb_support.py b/mongoengine/mongodb_support.py index b5f3bdc8..717a3d81 100644 --- a/mongoengine/mongodb_support.py +++ b/mongoengine/mongodb_support.py @@ -1,5 +1,5 @@ """ -Helper functions, constants, and types to aid with MongoDB v3.x support +Helper functions, constants, and types to aid with MongoDB version support """ from mongoengine.connection import get_connection diff --git a/tests/document/indexes.py b/tests/document/indexes.py index a21b45bc..36fbae46 100644 --- a/tests/document/indexes.py +++ b/tests/document/indexes.py @@ -603,18 +603,24 @@ class IndexesTest(unittest.TestCase): @requires_mongodb_gte_34 def test_primary_key_unique_not_working_under_mongo_34(self): + """Relates to #1445""" class Blog(Document): id = StringField(primary_key=True, unique=True) + Blog.drop_collection() + with self.assertRaises(OperationFailure) as ctx_err: Blog(id='garbage').save() self.assertIn("The field 'unique' is not valid for an _id index specification", str(ctx_err.exception)) @requires_mongodb_lte_32 def test_primary_key_unique_working_under_mongo_32(self): + """Relates to #1445""" class Blog(Document): id = StringField(primary_key=True, unique=True) + Blog.drop_collection() + Blog(id='garbage').save() def test_unique_with(self): diff --git a/tests/utils.py b/tests/utils.py index e0380dbc..e94e4a80 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -33,12 +33,12 @@ def get_as_pymongo(doc): return doc.__class__.objects.as_pymongo().get(id=doc.id) -def _decorated_with_ver_requirement(func, mongo_version_req, oper=operator.ge): +def _decorated_with_ver_requirement(func, mongo_version_req, oper): """Return a given function decorated with the version requirement for a particular MongoDB version tuple. :param mongo_version_req: The mongodb version requirement (tuple(int, int)) - :param oper: The operator to apply + :param oper: The operator to apply (e.g: operator.ge) """ def _inner(*args, **kwargs): mongodb_v = get_mongodb_version() @@ -56,7 +56,7 @@ def requires_mongodb_gte_34(func): """Raise a SkipTest exception if we're working with MongoDB version lower than v3.4 """ - return _decorated_with_ver_requirement(func, MONGODB_34) + return _decorated_with_ver_requirement(func, MONGODB_34, oper=operator.ge) def requires_mongodb_lte_32(func): @@ -70,14 +70,14 @@ def requires_mongodb_gte_26(func): """Raise a SkipTest exception if we're working with MongoDB version lower than v2.6. """ - return _decorated_with_ver_requirement(func, MONGODB_26) + return _decorated_with_ver_requirement(func, MONGODB_26, oper=operator.ge) def requires_mongodb_gte_3(func): """Raise a SkipTest exception if we're working with MongoDB version lower than v3.0. """ - return _decorated_with_ver_requirement(func, MONGODB_3) + return _decorated_with_ver_requirement(func, MONGODB_3, oper=operator.ge) def skip_pymongo3(f): From 7cea2a768f16eafaabcec55df8e222055a8d3ba1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Tue, 26 Feb 2019 22:39:49 +0100 Subject: [PATCH 12/71] Fix recent flaky test for python 3.6 --- mongoengine/document.py | 4 ++-- tests/document/indexes.py | 6 +++++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/mongoengine/document.py b/mongoengine/document.py index b671ac6b..84c1d699 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -808,7 +808,7 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): db.drop_collection(col_name) @classmethod - def _create_index(cls, keys, background=False, **kwargs): + def create_index(cls, keys, background=False, **kwargs): """Creates the given indexes if required. :param keys: a single index key or a list of index keys (to @@ -851,7 +851,7 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): warnings.warn(msg, DeprecationWarning) elif not IS_PYMONGO_3: kwargs.update({'drop_dups': drop_dups}) - return cls._create_index(key_or_list, background=background, **kwargs) + return cls.create_index(key_or_list, background=background, **kwargs) @classmethod def ensure_indexes(cls): diff --git a/tests/document/indexes.py b/tests/document/indexes.py index 36fbae46..abd349f3 100644 --- a/tests/document/indexes.py +++ b/tests/document/indexes.py @@ -611,7 +611,11 @@ class IndexesTest(unittest.TestCase): with self.assertRaises(OperationFailure) as ctx_err: Blog(id='garbage').save() - self.assertIn("The field 'unique' is not valid for an _id index specification", str(ctx_err.exception)) + try: + self.assertIn("The field 'unique' is not valid for an _id index specification", str(ctx_err.exception)) + except AssertionError: + # error is slightly different on python 3.6 + self.assertIn("The field 'background' is not valid for an _id index specification", str(ctx_err.exception)) @requires_mongodb_lte_32 def test_primary_key_unique_working_under_mongo_32(self): From 35b7efe3f40c3ebea39d960fb100526d2fb968b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Sun, 24 Feb 2019 11:08:46 +0100 Subject: [PATCH 13/71] refactored deprecated pymongo methods in tests - remove/count/add_user/insert - added pymongo_support --- mongoengine/connection.py | 4 +-- mongoengine/context_managers.py | 3 +- mongoengine/document.py | 4 +-- mongoengine/pymongo_support.py | 33 ++++++++++++++++++++++ mongoengine/python_support.py | 7 +---- mongoengine/queryset/base.py | 2 +- mongoengine/queryset/transform.py | 2 +- tests/document/class_methods.py | 13 ++++----- tests/document/delta.py | 14 ++++----- tests/document/inheritance.py | 15 ++++------ tests/document/instance.py | 10 +++---- tests/fields/file_tests.py | 44 +++++++++++++++++------------ tests/queryset/queryset.py | 6 ++-- tests/test_connection.py | 21 +++++++------- tests/test_context_managers.py | 3 +- tests/test_replicaset_connection.py | 2 +- tests/utils.py | 2 +- 17 files changed, 106 insertions(+), 79 deletions(-) create mode 100644 mongoengine/pymongo_support.py diff --git a/mongoengine/connection.py b/mongoengine/connection.py index 38ebb243..c0cfde31 100644 --- a/mongoengine/connection.py +++ b/mongoengine/connection.py @@ -1,10 +1,10 @@ from pymongo import MongoClient, ReadPreference, uri_parser import six -from mongoengine.python_support import IS_PYMONGO_3 +from mongoengine.pymongo_support import IS_PYMONGO_3 __all__ = ['MongoEngineConnectionError', 'connect', 'register_connection', - 'DEFAULT_CONNECTION_NAME'] + 'DEFAULT_CONNECTION_NAME', 'get_db'] DEFAULT_CONNECTION_NAME = 'default' diff --git a/mongoengine/context_managers.py b/mongoengine/context_managers.py index d1e5d9ef..98bd897b 100644 --- a/mongoengine/context_managers.py +++ b/mongoengine/context_managers.py @@ -5,6 +5,7 @@ from six import iteritems from mongoengine.common import _import_class from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db +from mongoengine.pymongo_support import count_documents __all__ = ('switch_db', 'switch_collection', 'no_dereference', 'no_sub_classes', 'query_counter', 'set_write_concern') @@ -237,7 +238,7 @@ class query_counter(object): and substracting the queries issued by this context. In fact everytime this is called, 1 query is issued so we need to balance that """ - count = self.db.system.profile.find(self._ignored_query).count() - self._ctx_query_counter + count = count_documents(self.db.system.profile, self._ignored_query) - self._ctx_query_counter self._ctx_query_counter += 1 # Account for the query we just issued to gather the information return count diff --git a/mongoengine/document.py b/mongoengine/document.py index 84c1d699..5981d8d1 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -18,7 +18,7 @@ from mongoengine.context_managers import (set_write_concern, switch_db) from mongoengine.errors import (InvalidDocumentError, InvalidQueryError, SaveConditionError) -from mongoengine.python_support import IS_PYMONGO_3 +from mongoengine.pymongo_support import IS_PYMONGO_3, list_collection_names from mongoengine.queryset import (NotUniqueError, OperationError, QuerySet, transform) @@ -228,7 +228,7 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): # If the collection already exists and has different options # (i.e. isn't capped or has different max/size), raise an error. - if collection_name in db.collection_names(): + if collection_name in list_collection_names(db, include_system_collections=True): collection = db[collection_name] options = collection.options() if ( diff --git a/mongoengine/pymongo_support.py b/mongoengine/pymongo_support.py new file mode 100644 index 00000000..0d607162 --- /dev/null +++ b/mongoengine/pymongo_support.py @@ -0,0 +1,33 @@ +""" +Helper functions, constants, and types to aid with PyMongo v2.7 - v3.x support. +""" +import pymongo + +_PYMONGO_37 = (3, 7) + +PYMONGO_VERSION = tuple(pymongo.version_tuple[:2]) + +IS_PYMONGO_3 = PYMONGO_VERSION[0] >= 3 +IS_PYMONGO_GTE_37 = PYMONGO_VERSION >= _PYMONGO_37 + + +def count_documents(collection, filter): + """Pymongo>3.7 deprecates count in favour of count_documents""" + if IS_PYMONGO_GTE_37: + return collection.count_documents(filter) + else: + count = collection.find(filter).count() + return count + + +def list_collection_names(db, include_system_collections=False): + """Pymongo>3.7 deprecates collection_names in favour of list_collection_names""" + if IS_PYMONGO_GTE_37: + collections = db.list_collection_names() + else: + collections = db.collection_names() + + if not include_system_collections: + collections = [c for c in collections if not c.startswith('system.')] + + return collections diff --git a/mongoengine/python_support.py b/mongoengine/python_support.py index 7e8e108f..57e467db 100644 --- a/mongoengine/python_support.py +++ b/mongoengine/python_support.py @@ -1,13 +1,8 @@ """ -Helper functions, constants, and types to aid with Python v2.7 - v3.x and -PyMongo v2.7 - v3.x support. +Helper functions, constants, and types to aid with Python v2.7 - v3.x support """ -import pymongo import six - -IS_PYMONGO_3 = pymongo.version_tuple[0] >= 3 - # six.BytesIO resolves to StringIO.StringIO in Py2 and io.BytesIO in Py3. StringIO = six.BytesIO diff --git a/mongoengine/queryset/base.py b/mongoengine/queryset/base.py index 8c22c5b9..9ddfeab2 100644 --- a/mongoengine/queryset/base.py +++ b/mongoengine/queryset/base.py @@ -21,7 +21,7 @@ from mongoengine.connection import get_db from mongoengine.context_managers import set_write_concern, switch_db from mongoengine.errors import (InvalidQueryError, LookUpError, NotUniqueError, OperationError) -from mongoengine.python_support import IS_PYMONGO_3 +from mongoengine.pymongo_support import IS_PYMONGO_3 from mongoengine.queryset import transform from mongoengine.queryset.field_list import QueryFieldList from mongoengine.queryset.visitor import Q, QNode diff --git a/mongoengine/queryset/transform.py b/mongoengine/queryset/transform.py index c00271f3..3de10a69 100644 --- a/mongoengine/queryset/transform.py +++ b/mongoengine/queryset/transform.py @@ -10,7 +10,7 @@ from mongoengine.base import UPDATE_OPERATORS from mongoengine.common import _import_class from mongoengine.connection import get_connection from mongoengine.errors import InvalidQueryError -from mongoengine.python_support import IS_PYMONGO_3 +from mongoengine.pymongo_support import IS_PYMONGO_3 __all__ = ('query', 'update') diff --git a/tests/document/class_methods.py b/tests/document/class_methods.py index 88937ec8..421618e4 100644 --- a/tests/document/class_methods.py +++ b/tests/document/class_methods.py @@ -2,6 +2,7 @@ import unittest from mongoengine import * +from mongoengine.pymongo_support import list_collection_names from mongoengine.queryset import NULLIFY, PULL from mongoengine.connection import get_db @@ -27,9 +28,7 @@ class ClassMethodsTest(unittest.TestCase): self.Person = Person def tearDown(self): - for collection in self.db.collection_names(): - if 'system.' in collection: - continue + for collection in list_collection_names(self.db): self.db.drop_collection(collection) def test_definition(self): @@ -66,10 +65,10 @@ class ClassMethodsTest(unittest.TestCase): """ collection_name = 'person' self.Person(name='Test').save() - self.assertIn(collection_name, self.db.collection_names()) + self.assertIn(collection_name, list_collection_names(self.db)) self.Person.drop_collection() - self.assertNotIn(collection_name, self.db.collection_names()) + self.assertNotIn(collection_name, list_collection_names(self.db)) def test_register_delete_rule(self): """Ensure that register delete rule adds a delete rule to the document @@ -340,7 +339,7 @@ class ClassMethodsTest(unittest.TestCase): meta = {'collection': collection_name} Person(name="Test User").save() - self.assertIn(collection_name, self.db.collection_names()) + self.assertIn(collection_name, list_collection_names(self.db)) user_obj = self.db[collection_name].find_one() self.assertEqual(user_obj['name'], "Test User") @@ -349,7 +348,7 @@ class ClassMethodsTest(unittest.TestCase): self.assertEqual(user_obj.name, "Test User") Person.drop_collection() - self.assertNotIn(collection_name, self.db.collection_names()) + self.assertNotIn(collection_name, list_collection_names(self.db)) def test_collection_name_and_primary(self): """Ensure that a collection with a specified name may be used. diff --git a/tests/document/delta.py b/tests/document/delta.py index 942e3a0a..504c1707 100644 --- a/tests/document/delta.py +++ b/tests/document/delta.py @@ -3,16 +3,14 @@ import unittest from bson import SON from mongoengine import * -from mongoengine.connection import get_db - -__all__ = ("DeltaTest",) +from mongoengine.pymongo_support import list_collection_names +from tests.utils import MongoDBTestCase -class DeltaTest(unittest.TestCase): +class DeltaTest(MongoDBTestCase): def setUp(self): - connect(db='mongoenginetest') - self.db = get_db() + super(DeltaTest, self).setUp() class Person(Document): name = StringField() @@ -25,9 +23,7 @@ class DeltaTest(unittest.TestCase): self.Person = Person def tearDown(self): - for collection in self.db.collection_names(): - if 'system.' in collection: - continue + for collection in list_collection_names(self.db): self.db.drop_collection(collection) def test_delta(self): diff --git a/tests/document/inheritance.py b/tests/document/inheritance.py index 83c2a80a..d81039f4 100644 --- a/tests/document/inheritance.py +++ b/tests/document/inheritance.py @@ -6,23 +6,18 @@ from six import iteritems from mongoengine import (BooleanField, Document, EmbeddedDocument, EmbeddedDocumentField, GenericReferenceField, - IntField, ReferenceField, StringField, connect) -from mongoengine.connection import get_db + IntField, ReferenceField, StringField) +from mongoengine.pymongo_support import list_collection_names +from tests.utils import MongoDBTestCase from tests.fixtures import Base __all__ = ('InheritanceTest', ) -class InheritanceTest(unittest.TestCase): - - def setUp(self): - connect(db='mongoenginetest') - self.db = get_db() +class InheritanceTest(MongoDBTestCase): def tearDown(self): - for collection in self.db.collection_names(): - if 'system.' in collection: - continue + for collection in list_collection_names(self.db): self.db.drop_collection(collection) def test_constructor_cls(self): diff --git a/tests/document/instance.py b/tests/document/instance.py index 051eda68..9b28f1b4 100644 --- a/tests/document/instance.py +++ b/tests/document/instance.py @@ -12,6 +12,7 @@ from bson import DBRef, ObjectId from pymongo.errors import DuplicateKeyError from six import iteritems +from mongoengine.pymongo_support import list_collection_names from tests import fixtures from tests.fixtures import (PickleEmbedded, PickleTest, PickleSignalsTest, PickleDynamicEmbedded, PickleDynamicTest) @@ -55,9 +56,7 @@ class InstanceTest(MongoDBTestCase): self.Job = Job def tearDown(self): - for collection in self.db.collection_names(): - if 'system.' in collection: - continue + for collection in list_collection_names(self.db): self.db.drop_collection(collection) def assertDbEqual(self, docs): @@ -572,7 +571,7 @@ class InstanceTest(MongoDBTestCase): Post.drop_collection() - Post._get_collection().insert({ + Post._get_collection().insert_one({ "title": "Items eclipse", "items": ["more lorem", "even more ipsum"] }) @@ -3217,8 +3216,7 @@ class InstanceTest(MongoDBTestCase): coll = Person._get_collection() for person in Person.objects.as_pymongo(): if 'height' not in person: - person['height'] = 189 - coll.save(person) + coll.update_one({'_id': person['_id']}, {'$set': {'height': 189}}) self.assertEquals(Person.objects(height=189).count(), 1) diff --git a/tests/fields/file_tests.py b/tests/fields/file_tests.py index 76e20bb9..4ff6865b 100644 --- a/tests/fields/file_tests.py +++ b/tests/fields/file_tests.py @@ -24,6 +24,16 @@ TEST_IMAGE_PATH = os.path.join(os.path.dirname(__file__), 'mongoengine.png') TEST_IMAGE2_PATH = os.path.join(os.path.dirname(__file__), 'mongodb_leaf.png') +def get_file(path): + """Use a BytesIO instead of a file to allow + to have a one-liner and avoid that the file remains opened""" + bytes_io = StringIO() + with open(path, 'rb') as f: + bytes_io.write(f.read()) + bytes_io.seek(0) + return bytes_io + + class FileTest(MongoDBTestCase): def tearDown(self): @@ -247,8 +257,8 @@ class FileTest(MongoDBTestCase): Animal.drop_collection() marmot = Animal(genus='Marmota', family='Sciuridae') - marmot_photo = open(TEST_IMAGE_PATH, 'rb') # Retrieve a photo from disk - marmot.photo.put(marmot_photo, content_type='image/jpeg', foo='bar') + marmot_photo_content = get_file(TEST_IMAGE_PATH) # Retrieve a photo from disk + marmot.photo.put(marmot_photo_content, content_type='image/jpeg', foo='bar') marmot.photo.close() marmot.save() @@ -261,11 +271,11 @@ class FileTest(MongoDBTestCase): the_file = FileField() TestFile.drop_collection() - test_file = TestFile(the_file=open(TEST_IMAGE_PATH, 'rb')).save() + test_file = TestFile(the_file=get_file(TEST_IMAGE_PATH)).save() self.assertEqual(test_file.the_file.get().length, 8313) test_file = TestFile.objects.first() - test_file.the_file = open(TEST_IMAGE2_PATH, 'rb') + test_file.the_file = get_file(TEST_IMAGE2_PATH) test_file.save() self.assertEqual(test_file.the_file.get().length, 4971) @@ -379,7 +389,7 @@ class FileTest(MongoDBTestCase): self.assertEqual("%s" % e, "Invalid image: cannot identify image file %s" % f) t = TestImage() - t.image.put(open(TEST_IMAGE_PATH, 'rb')) + t.image.put(get_file(TEST_IMAGE_PATH)) t.save() t = TestImage.objects.first() @@ -400,11 +410,11 @@ class FileTest(MongoDBTestCase): the_file = ImageField() TestFile.drop_collection() - test_file = TestFile(the_file=open(TEST_IMAGE_PATH, 'rb')).save() + test_file = TestFile(the_file=get_file(TEST_IMAGE_PATH)).save() self.assertEqual(test_file.the_file.size, (371, 76)) test_file = TestFile.objects.first() - test_file.the_file = open(TEST_IMAGE2_PATH, 'rb') + test_file.the_file = get_file(TEST_IMAGE2_PATH) test_file.save() self.assertEqual(test_file.the_file.size, (45, 101)) @@ -418,7 +428,7 @@ class FileTest(MongoDBTestCase): TestImage.drop_collection() t = TestImage() - t.image.put(open(TEST_IMAGE_PATH, 'rb')) + t.image.put(get_file(TEST_IMAGE_PATH)) t.save() t = TestImage.objects.first() @@ -441,7 +451,7 @@ class FileTest(MongoDBTestCase): TestImage.drop_collection() t = TestImage() - t.image.put(open(TEST_IMAGE_PATH, 'rb')) + t.image.put(get_file(TEST_IMAGE_PATH)) t.save() t = TestImage.objects.first() @@ -464,7 +474,7 @@ class FileTest(MongoDBTestCase): TestImage.drop_collection() t = TestImage() - t.image.put(open(TEST_IMAGE_PATH, 'rb')) + t.image.put(get_file(TEST_IMAGE_PATH)) t.save() t = TestImage.objects.first() @@ -542,8 +552,8 @@ class FileTest(MongoDBTestCase): TestImage.drop_collection() t = TestImage() - t.image1.put(open(TEST_IMAGE_PATH, 'rb')) - t.image2.put(open(TEST_IMAGE2_PATH, 'rb')) + t.image1.put(get_file(TEST_IMAGE_PATH)) + t.image2.put(get_file(TEST_IMAGE2_PATH)) t.save() test = TestImage.objects.first() @@ -563,12 +573,10 @@ class FileTest(MongoDBTestCase): Animal.drop_collection() marmot = Animal(genus='Marmota', family='Sciuridae') - marmot_photo = open(TEST_IMAGE_PATH, 'rb') # Retrieve a photo from disk - - photos_field = marmot._fields['photos'].field - new_proxy = photos_field.get_proxy_obj('photos', marmot) - new_proxy.put(marmot_photo, content_type='image/jpeg', foo='bar') - marmot_photo.close() + with open(TEST_IMAGE_PATH, 'rb') as marmot_photo: # Retrieve a photo from disk + photos_field = marmot._fields['photos'].field + new_proxy = photos_field.get_proxy_obj('photos', marmot) + new_proxy.put(marmot_photo, content_type='image/jpeg', foo='bar') marmot.photos.append(new_proxy) marmot.save() diff --git a/tests/queryset/queryset.py b/tests/queryset/queryset.py index 3d8e8960..b45edb33 100644 --- a/tests/queryset/queryset.py +++ b/tests/queryset/queryset.py @@ -18,7 +18,7 @@ from mongoengine import * from mongoengine.connection import get_connection, get_db from mongoengine.context_managers import query_counter, switch_db from mongoengine.errors import InvalidQueryError -from mongoengine.python_support import IS_PYMONGO_3 +from mongoengine.pymongo_support import IS_PYMONGO_3 from mongoengine.queryset import (DoesNotExist, MultipleObjectsReturned, QuerySet, QuerySetManager, queryset_manager) @@ -4051,7 +4051,7 @@ class QuerySetTest(unittest.TestCase): fielda = IntField() fieldb = IntField() - IntPair.objects._collection.remove() + IntPair.drop_collection() a = IntPair(fielda=1, fieldb=1) b = IntPair(fielda=1, fieldb=2) @@ -5386,7 +5386,7 @@ class QuerySetTest(unittest.TestCase): Person.drop_collection() - Person._get_collection().insert({'name': 'a', 'id': ''}) + Person._get_collection().insert_one({'name': 'a', 'id': ''}) for p in Person.objects(): self.assertEqual(p.name, 'a') diff --git a/tests/test_connection.py b/tests/test_connection.py index 7c4fc4cf..fafef9d4 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -14,7 +14,7 @@ from mongoengine import ( connect, register_connection, Document, DateTimeField ) -from mongoengine.python_support import IS_PYMONGO_3 +from mongoengine.pymongo_support import IS_PYMONGO_3 import mongoengine.connection from mongoengine.connection import (MongoEngineConnectionError, get_db, get_connection) @@ -147,12 +147,12 @@ class ConnectionTest(unittest.TestCase): def test_connect_uri(self): """Ensure that the connect() method works properly with URIs.""" c = connect(db='mongoenginetest', alias='admin') - c.admin.system.users.remove({}) - c.mongoenginetest.system.users.remove({}) + c.admin.system.users.delete_many({}) + c.mongoenginetest.system.users.delete_many({}) - c.admin.add_user("admin", "password") + c.admin.command("createUser", "admin", pwd="password", roles=["root"]) c.admin.authenticate("admin", "password") - c.mongoenginetest.add_user("username", "password") + c.admin.command("createUser", "username", pwd="password", roles=["dbOwner"]) if not IS_PYMONGO_3: self.assertRaises( @@ -169,8 +169,8 @@ class ConnectionTest(unittest.TestCase): self.assertIsInstance(db, pymongo.database.Database) self.assertEqual(db.name, 'mongoenginetest') - c.admin.system.users.remove({}) - c.mongoenginetest.system.users.remove({}) + c.admin.system.users.delete_many({}) + c.mongoenginetest.system.users.delete_many({}) def test_connect_uri_without_db(self): """Ensure connect() method works properly if the URI doesn't @@ -217,8 +217,9 @@ class ConnectionTest(unittest.TestCase): """ # Create users c = connect('mongoenginetest') - c.admin.system.users.remove({}) - c.admin.add_user('username2', 'password') + + c.admin.system.users.delete_many({}) + c.admin.command("createUser", "username2", pwd="password", roles=["dbOwner"]) # Authentication fails without "authSource" if IS_PYMONGO_3: @@ -246,7 +247,7 @@ class ConnectionTest(unittest.TestCase): self.assertEqual(db.name, 'mongoenginetest') # Clear all users - authd_conn.admin.system.users.remove({}) + authd_conn.admin.system.users.delete_many({}) def test_register_connection(self): """Ensure that connections with different aliases may be registered. diff --git a/tests/test_context_managers.py b/tests/test_context_managers.py index 8207cd89..227031e0 100644 --- a/tests/test_context_managers.py +++ b/tests/test_context_managers.py @@ -5,6 +5,7 @@ from mongoengine.connection import get_db from mongoengine.context_managers import (switch_db, switch_collection, no_sub_classes, no_dereference, query_counter) +from mongoengine.pymongo_support import count_documents class ContextManagersTest(unittest.TestCase): @@ -240,7 +241,7 @@ class ContextManagersTest(unittest.TestCase): collection.drop() def issue_1_count_query(): - collection.find({}).count() + count_documents(collection, {}) def issue_1_insert_query(): collection.insert_one({'test': 'garbage'}) diff --git a/tests/test_replicaset_connection.py b/tests/test_replicaset_connection.py index 4aa647d6..81fdfb64 100644 --- a/tests/test_replicaset_connection.py +++ b/tests/test_replicaset_connection.py @@ -2,7 +2,7 @@ import unittest from pymongo import ReadPreference -from mongoengine.python_support import IS_PYMONGO_3 +from mongoengine.pymongo_support import IS_PYMONGO_3 if IS_PYMONGO_3: from pymongo import MongoClient diff --git a/tests/utils.py b/tests/utils.py index 19936a54..bda3a878 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -4,7 +4,7 @@ from nose.plugins.skip import SkipTest from mongoengine import connect from mongoengine.connection import get_db, get_connection -from mongoengine.python_support import IS_PYMONGO_3 +from mongoengine.pymongo_support import IS_PYMONGO_3 MONGO_TEST_DB = 'mongoenginetest' # standard name for the test database From 7ef688b256907748981c92291a768451288f1d87 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Sat, 2 Mar 2019 22:05:23 +0100 Subject: [PATCH 14/71] Added a test for push in DictField (relates to #1679) --- tests/fields/test_dict_field.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/tests/fields/test_dict_field.py b/tests/fields/test_dict_field.py index a3b8ec6c..ade02ccf 100644 --- a/tests/fields/test_dict_field.py +++ b/tests/fields/test_dict_field.py @@ -199,6 +199,26 @@ class TestDictField(MongoDBTestCase): self.assertEqual( Simple.objects.filter(mapping__nested_dict__list__1__value='Boo').count(), 1) + def test_push_dict(self): + class MyModel(Document): + events = ListField(DictField()) + + doc = MyModel(events=[{'a': 1}]).save() + raw_doc = get_as_pymongo(doc) + expected_raw_doc = { + '_id': doc.id, + 'events': [{'a': 1}] + } + self.assertEqual(raw_doc, expected_raw_doc) + + MyModel.objects(id=doc.id).update(push__events={}) + raw_doc = get_as_pymongo(doc) + expected_raw_doc = { + '_id': doc.id, + 'events': [{'a': 1}, {}] + } + self.assertEqual(raw_doc, expected_raw_doc) + def test_ensure_unique_default_instances(self): """Ensure that every field has it's own unique default instance.""" class D(Document): From b640c766dbb84c7a77efae8645f2d9f480ca96d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Mon, 4 Mar 2019 23:01:12 +0100 Subject: [PATCH 15/71] Fix queryset batch_size that wasn't copied to cloned queryset --- mongoengine/queryset/base.py | 2 +- tests/queryset/queryset.py | 10 ++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/mongoengine/queryset/base.py b/mongoengine/queryset/base.py index 9ddfeab2..24e12623 100644 --- a/mongoengine/queryset/base.py +++ b/mongoengine/queryset/base.py @@ -757,7 +757,7 @@ class BaseQuerySet(object): '_read_preference', '_iter', '_scalar', '_as_pymongo', '_limit', '_skip', '_hint', '_auto_dereference', '_search_text', 'only_fields', '_max_time_ms', - '_comment') + '_comment', '_batch_size') for prop in copy_props: val = getattr(self, prop) diff --git a/tests/queryset/queryset.py b/tests/queryset/queryset.py index 0d8d6285..662ffe61 100644 --- a/tests/queryset/queryset.py +++ b/tests/queryset/queryset.py @@ -394,6 +394,16 @@ class QuerySetTest(unittest.TestCase): with self.assertRaises(ValueError): list(qs) + def test_batch_size_cloned(self): + class A(Document): + s = StringField() + + # test that batch size gets cloned + qs = A.objects.batch_size(5) + self.assertEqual(qs._batch_size, 5) + qs_clone = qs.clone() + self.assertEqual(qs_clone._batch_size, 5) + def test_update_write_concern(self): """Test that passing write_concern works""" self.Person.drop_collection() From 9bd0d6b99d960d3a1b14c627c4148d6bea510df3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Mon, 4 Mar 2019 23:05:22 +0100 Subject: [PATCH 16/71] update changelog --- docs/changelog.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/changelog.rst b/docs/changelog.rst index c87fd1e9..da1026d4 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -6,6 +6,7 @@ Development =========== - (Fill this out as you fix issues and develop your features). - Fix .only() working improperly after using .count() of the same instance of QuerySet +- Fix batch_size that was not copied when cloning a queryset object #2011 - POTENTIAL BREAKING CHANGE: All result fields are now passed, including internal fields (_cls, _id) when using `QuerySet.as_pymongo` #1976 - Document a BREAKING CHANGE introduced in 0.15.3 and not reported at that time (#1995) - Fix InvalidStringData error when using modify on a BinaryField #1127 From 7b4245c91c6bd55b70e37cf58272e08c60b0b678 Mon Sep 17 00:00:00 2001 From: Erdenezul Batmunkh Date: Sun, 10 Mar 2019 21:16:58 +0800 Subject: [PATCH 17/71] Bump version 0.17.0 --- docs/changelog.rst | 3 +++ mongoengine/__init__.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index c87fd1e9..97da8d10 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -5,6 +5,9 @@ Changelog Development =========== - (Fill this out as you fix issues and develop your features). + +Changes in 0.17.0 +================= - Fix .only() working improperly after using .count() of the same instance of QuerySet - POTENTIAL BREAKING CHANGE: All result fields are now passed, including internal fields (_cls, _id) when using `QuerySet.as_pymongo` #1976 - Document a BREAKING CHANGE introduced in 0.15.3 and not reported at that time (#1995) diff --git a/mongoengine/__init__.py b/mongoengine/__init__.py index 2b78d4e6..b94efab9 100644 --- a/mongoengine/__init__.py +++ b/mongoengine/__init__.py @@ -23,7 +23,7 @@ __all__ = (list(document.__all__) + list(fields.__all__) + list(signals.__all__) + list(errors.__all__)) -VERSION = (0, 16, 3) +VERSION = (0, 17, 0) def get_version(): From 48b849c0319e5f35b81eeb926e7540d3d168aa78 Mon Sep 17 00:00:00 2001 From: lalala223 Date: Wed, 13 Mar 2019 17:50:54 +0800 Subject: [PATCH 18/71] Update querying.rst Fix the 'not' operator error example. --- docs/guide/querying.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/guide/querying.rst b/docs/guide/querying.rst index 08987835..151855a6 100644 --- a/docs/guide/querying.rst +++ b/docs/guide/querying.rst @@ -64,7 +64,7 @@ Available operators are as follows: * ``gt`` -- greater than * ``gte`` -- greater than or equal to * ``not`` -- negate a standard check, may be used before other operators (e.g. - ``Q(age__not__mod=5)``) + ``Q(age__not__mod=(5, 0))``) * ``in`` -- value is in list (a list of values should be provided) * ``nin`` -- value is not in list (a list of values should be provided) * ``mod`` -- ``value % x == y``, where ``x`` and ``y`` are two provided values From 68497542b3eaf58fe33ce6548de8f5c76fd751b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Sun, 17 Mar 2019 22:04:19 +0100 Subject: [PATCH 19/71] Bump the required version of pymongo to >=3.5 --- README.rst | 2 +- mongoengine/document.py | 10 ---------- requirements.txt | 2 +- setup.py | 2 +- tests/document/indexes.py | 4 ---- tests/document/json_serialisation.py | 4 ---- tests/queryset/queryset.py | 3 --- tests/test_connection.py | 12 +----------- 8 files changed, 4 insertions(+), 35 deletions(-) diff --git a/README.rst b/README.rst index f0309170..12d9df0e 100644 --- a/README.rst +++ b/README.rst @@ -47,7 +47,7 @@ Dependencies All of the dependencies can easily be installed via `pip `_. At the very least, you'll need these two packages to use MongoEngine: -- pymongo>=2.7.1 +- pymongo>=3.5 - six>=1.10.0 If you utilize a ``DateTimeField``, you might also use a more flexible date parser: diff --git a/mongoengine/document.py b/mongoengine/document.py index 5981d8d1..328ac299 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -451,16 +451,6 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): object_id = wc_collection.insert_one(doc).inserted_id - # In PyMongo 3.0, the save() call calls internally the _update() call - # but they forget to return the _id value passed back, therefore getting it back here - # Correct behaviour in 2.X and in 3.0.1+ versions - if not object_id and pymongo.version_tuple == (3, 0): - pk_as_mongo_obj = self._fields.get(self._meta['id_field']).to_mongo(self.pk) - object_id = ( - self._qs.filter(pk=pk_as_mongo_obj).first() and - self._qs.filter(pk=pk_as_mongo_obj).first().pk - ) # TODO doesn't this make 2 queries? - return object_id def _get_update_doc(self): diff --git a/requirements.txt b/requirements.txt index 4e3ea940..38e0b20f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ nose -pymongo>=2.7.1 +pymongo>=3.5 six==1.10.0 flake8 flake8-import-order diff --git a/setup.py b/setup.py index c7632ce3..c8e9c038 100644 --- a/setup.py +++ b/setup.py @@ -80,7 +80,7 @@ setup( long_description=LONG_DESCRIPTION, platforms=['any'], classifiers=CLASSIFIERS, - install_requires=['pymongo>=2.7.1', 'six'], + install_requires=['pymongo>=3.5', 'six'], test_suite='nose.collector', **extra_opts ) diff --git a/tests/document/indexes.py b/tests/document/indexes.py index abd349f3..dd443857 100644 --- a/tests/document/indexes.py +++ b/tests/document/indexes.py @@ -565,10 +565,6 @@ class IndexesTest(unittest.TestCase): self.assertEqual(BlogPost.objects.count(), 10) self.assertEqual(BlogPost.objects.hint().count(), 10) - # PyMongo 3.0 bug only, works correctly with 2.X and 3.0.1+ versions - if pymongo.version != '3.0': - self.assertEqual(BlogPost.objects.hint([('tags', 1)]).count(), 10) - if MONGO_VER >= MONGODB_32: # Mongo32 throws an error if an index exists (i.e `tags` in our case) # and you use hint on an index name that does not exist diff --git a/tests/document/json_serialisation.py b/tests/document/json_serialisation.py index 7c785ab2..251b65a2 100644 --- a/tests/document/json_serialisation.py +++ b/tests/document/json_serialisation.py @@ -61,10 +61,6 @@ class TestJson(unittest.TestCase): self.assertEqual(doc, Doc.from_json(doc.to_json())) def test_json_complex(self): - - if pymongo.version_tuple[0] <= 2 and pymongo.version_tuple[1] <= 3: - raise SkipTest("Need pymongo 2.4 as has a fix for DBRefs") - class EmbeddedDoc(EmbeddedDocument): pass diff --git a/tests/queryset/queryset.py b/tests/queryset/queryset.py index 662ffe61..46d82203 100644 --- a/tests/queryset/queryset.py +++ b/tests/queryset/queryset.py @@ -4602,9 +4602,6 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(doc_objects, Doc.objects.from_json(json_data)) def test_json_complex(self): - if pymongo.version_tuple[0] <= 2 and pymongo.version_tuple[1] <= 3: - raise SkipTest("Need pymongo 2.4 as has a fix for DBRefs") - class EmbeddedDoc(EmbeddedDocument): pass diff --git a/tests/test_connection.py b/tests/test_connection.py index fafef9d4..0a7271df 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -285,14 +285,7 @@ class ConnectionTest(unittest.TestCase): """Ensure we can specify a max connection pool size using a connection kwarg. """ - # Use "max_pool_size" or "maxpoolsize" depending on PyMongo version - # (former was changed to the latter as described in - # https://jira.mongodb.org/browse/PYTHON-854). - # TODO remove once PyMongo < 3.0 support is dropped - if pymongo.version_tuple[0] >= 3: - pool_size_kwargs = {'maxpoolsize': 100} - else: - pool_size_kwargs = {'max_pool_size': 100} + pool_size_kwargs = {'maxpoolsize': 100} conn = connect('mongoenginetest', alias='max_pool_size_via_kwarg', **pool_size_kwargs) self.assertEqual(conn.max_pool_size, 100) @@ -301,9 +294,6 @@ class ConnectionTest(unittest.TestCase): """Ensure we can specify a max connection pool size using an option in a connection URI. """ - if pymongo.version_tuple[0] == 2 and pymongo.version_tuple[1] < 9: - raise SkipTest('maxpoolsize as a URI option is only supported in PyMongo v2.9+') - conn = connect(host='mongodb://localhost/test?maxpoolsize=100', alias='max_pool_size_via_uri') self.assertEqual(conn.max_pool_size, 100) From 6f8be8c8ac2023354aa1227a2eb9f1c92cf49266 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Sun, 17 Mar 2019 22:11:01 +0100 Subject: [PATCH 20/71] document change in changelog --- docs/changelog.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/changelog.rst b/docs/changelog.rst index 9a606812..741d3f8d 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -4,6 +4,7 @@ Changelog Development =========== +- mongoengine now requires pymongo>=3.5 #2017 - (Fill this out as you fix issues and develop your features). Changes in 0.17.0 From ba6a37f315392a48092486ade238db97d1ac67e7 Mon Sep 17 00:00:00 2001 From: Paulo Amaral Date: Fri, 15 Mar 2019 13:20:40 +0000 Subject: [PATCH 21/71] Generate Unique Indices for SortedListField and EmbeddedDocumentListFields --- AUTHORS | 3 +- docs/changelog.rst | 1 + mongoengine/base/document.py | 3 +- tests/document/indexes.py | 71 ++++++++++++++++++++++++++++++++++++ 4 files changed, 76 insertions(+), 2 deletions(-) diff --git a/AUTHORS b/AUTHORS index 880dfad1..e8a43dac 100644 --- a/AUTHORS +++ b/AUTHORS @@ -248,4 +248,5 @@ that much better: * Andy Yankovsky (https://github.com/werat) * Bastien Gérard (https://github.com/bagerard) * Trevor Hall (https://github.com/tjhall13) - * Gleb Voropaev (https://github.com/buggyspace) \ No newline at end of file + * Gleb Voropaev (https://github.com/buggyspace) + * Paulo Amaral (https://github.com/pauloAmaral) diff --git a/docs/changelog.rst b/docs/changelog.rst index 9a606812..42a0ab14 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -4,6 +4,7 @@ Changelog Development =========== +- Generate Unique Indices for SortedListField and EmbeddedDocumentListFields #2020 - (Fill this out as you fix issues and develop your features). Changes in 0.17.0 diff --git a/mongoengine/base/document.py b/mongoengine/base/document.py index 8587f17f..4cf34b4f 100644 --- a/mongoengine/base/document.py +++ b/mongoengine/base/document.py @@ -883,7 +883,8 @@ class BaseDocument(object): index = {'fields': fields, 'unique': True, 'sparse': sparse} unique_indexes.append(index) - if field.__class__.__name__ == 'ListField': + if field.__class__.__name__ in {'EmbeddedDocumentListField', + 'ListField', 'SortedListField'}: field = field.field # Grab any embedded document field unique indexes diff --git a/tests/document/indexes.py b/tests/document/indexes.py index abd349f3..b63faa9d 100644 --- a/tests/document/indexes.py +++ b/tests/document/indexes.py @@ -708,6 +708,77 @@ class IndexesTest(unittest.TestCase): self.assertRaises(NotUniqueError, post2.save) + def test_unique_embedded_document_in_sorted_list(self): + """ + Ensure that the uniqueness constraints are applied to fields in + embedded documents, even when the embedded documents in a sorted list + field. + """ + class SubDocument(EmbeddedDocument): + year = IntField() + slug = StringField(unique=True) + + class BlogPost(Document): + title = StringField() + subs = SortedListField(EmbeddedDocumentField(SubDocument), + ordering='year') + + BlogPost.drop_collection() + + post1 = BlogPost( + title='test1', subs=[ + SubDocument(year=2009, slug='conflict'), + SubDocument(year=2009, slug='conflict') + ] + ) + post1.save() + + # confirm that the unique index is created + indexes = BlogPost._get_collection().index_information() + self.assertIn('subs.slug_1', indexes) + self.assertTrue(indexes['subs.slug_1']['unique']) + + post2 = BlogPost( + title='test2', subs=[SubDocument(year=2014, slug='conflict')] + ) + + self.assertRaises(NotUniqueError, post2.save) + + def test_unique_embedded_document_in_embedded_document_list(self): + """ + Ensure that the uniqueness constraints are applied to fields in + embedded documents, even when the embedded documents in an embedded + list field. + """ + class SubDocument(EmbeddedDocument): + year = IntField() + slug = StringField(unique=True) + + class BlogPost(Document): + title = StringField() + subs = EmbeddedDocumentListField(SubDocument) + + BlogPost.drop_collection() + + post1 = BlogPost( + title='test1', subs=[ + SubDocument(year=2009, slug='conflict'), + SubDocument(year=2009, slug='conflict') + ] + ) + post1.save() + + # confirm that the unique index is created + indexes = BlogPost._get_collection().index_information() + self.assertIn('subs.slug_1', indexes) + self.assertTrue(indexes['subs.slug_1']['unique']) + + post2 = BlogPost( + title='test2', subs=[SubDocument(year=2014, slug='conflict')] + ) + + self.assertRaises(NotUniqueError, post2.save) + def test_unique_with_embedded_document_and_embedded_unique(self): """Ensure that uniqueness constraints are applied to fields on embedded documents. And work with unique_with as well. From fdcaca42ae6145ac102ec3475bdc8c550a7a604a Mon Sep 17 00:00:00 2001 From: Gaurav Dadhania Date: Wed, 20 Feb 2019 12:42:07 +0530 Subject: [PATCH 22/71] Do not keep calling _dereference on values if it has already been dereferenced. --- AUTHORS | 1 + docs/changelog.rst | 1 + mongoengine/base/fields.py | 7 ++- tests/test_dereference.py | 103 +++++++++++++++++++++++++++++++++++++ 4 files changed, 111 insertions(+), 1 deletion(-) diff --git a/AUTHORS b/AUTHORS index e8a43dac..21b0ec64 100644 --- a/AUTHORS +++ b/AUTHORS @@ -250,3 +250,4 @@ that much better: * Trevor Hall (https://github.com/tjhall13) * Gleb Voropaev (https://github.com/buggyspace) * Paulo Amaral (https://github.com/pauloAmaral) + * Gaurav Dadhania (https://github.com/GVRV) diff --git a/docs/changelog.rst b/docs/changelog.rst index 9d87d889..53373302 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -17,6 +17,7 @@ Changes in 0.17.0 - Fix InvalidStringData error when using modify on a BinaryField #1127 - DEPRECATION: `EmbeddedDocument.save` & `.reload` are marked as deprecated and will be removed in a next version of mongoengine #1552 - Fix test suite and CI to support MongoDB 3.4 #1445 +- Fix reference fields querying the database on each access if value contains orphan DBRefs ================= Changes in 0.16.3 diff --git a/mongoengine/base/fields.py b/mongoengine/base/fields.py index 5586c5b7..598eb606 100644 --- a/mongoengine/base/fields.py +++ b/mongoengine/base/fields.py @@ -276,11 +276,16 @@ class ComplexBaseField(BaseField): _dereference = _import_class('DeReference')() - if instance._initialised and dereference and instance._data.get(self.name): + if (instance._initialised and + dereference and + instance._data.get(self.name) and + not getattr(instance._data[self.name], '_dereferenced', False)): instance._data[self.name] = _dereference( instance._data.get(self.name), max_depth=1, instance=instance, name=self.name ) + if hasattr(instance._data[self.name], '_dereferenced'): + instance._data[self.name]._dereferenced = True value = super(ComplexBaseField, self).__get__(instance, owner) diff --git a/tests/test_dereference.py b/tests/test_dereference.py index cf1194f4..9c565810 100644 --- a/tests/test_dereference.py +++ b/tests/test_dereference.py @@ -105,6 +105,14 @@ class FieldTest(unittest.TestCase): [m for m in group_obj.members] self.assertEqual(q, 2) + self.assertTrue(group_obj._data['members']._dereferenced) + + # verifies that no additional queries gets executed + # if we re-iterate over the ListField once it is + # dereferenced + [m for m in group_obj.members] + self.assertEqual(q, 2) + self.assertTrue(group_obj._data['members']._dereferenced) # Document select_related with query_counter() as q: @@ -125,6 +133,46 @@ class FieldTest(unittest.TestCase): [m for m in group_obj.members] self.assertEqual(q, 2) + def test_list_item_dereference_orphan_dbref(self): + """Ensure that orphan DBRef items in ListFields are dereferenced. + """ + class User(Document): + name = StringField() + + class Group(Document): + members = ListField(ReferenceField(User, dbref=False)) + + User.drop_collection() + Group.drop_collection() + + for i in range(1, 51): + user = User(name='user %s' % i) + user.save() + + group = Group(members=User.objects) + group.save() + group.reload() # Confirm reload works + + # Delete one User so one of the references in the + # Group.members list is an orphan DBRef + User.objects[0].delete() + with query_counter() as q: + self.assertEqual(q, 0) + + group_obj = Group.objects.first() + self.assertEqual(q, 1) + + [m for m in group_obj.members] + self.assertEqual(q, 2) + self.assertTrue(group_obj._data['members']._dereferenced) + + # verifies that no additional queries gets executed + # if we re-iterate over the ListField once it is + # dereferenced + [m for m in group_obj.members] + self.assertEqual(q, 2) + self.assertTrue(group_obj._data['members']._dereferenced) + User.drop_collection() Group.drop_collection() @@ -505,6 +553,61 @@ class FieldTest(unittest.TestCase): for m in group_obj.members: self.assertIn('User', m.__class__.__name__) + + def test_generic_reference_orphan_dbref(self): + """Ensure that generic orphan DBRef items in ListFields are dereferenced. + """ + + class UserA(Document): + name = StringField() + + class UserB(Document): + name = StringField() + + class UserC(Document): + name = StringField() + + class Group(Document): + members = ListField(GenericReferenceField()) + + UserA.drop_collection() + UserB.drop_collection() + UserC.drop_collection() + Group.drop_collection() + + members = [] + for i in range(1, 51): + a = UserA(name='User A %s' % i) + a.save() + + b = UserB(name='User B %s' % i) + b.save() + + c = UserC(name='User C %s' % i) + c.save() + + members += [a, b, c] + + group = Group(members=members) + group.save() + + # Delete one UserA instance so that there is + # an orphan DBRef in the GenericReference ListField + UserA.objects[0].delete() + with query_counter() as q: + self.assertEqual(q, 0) + + group_obj = Group.objects.first() + self.assertEqual(q, 1) + + [m for m in group_obj.members] + self.assertEqual(q, 4) + self.assertTrue(group_obj._data['members']._dereferenced) + + [m for m in group_obj.members] + self.assertEqual(q, 4) + self.assertTrue(group_obj._data['members']._dereferenced) + UserA.drop_collection() UserB.drop_collection() UserC.drop_collection() From 9f2a9d9cda4f1bfb9ca7467e1e2769b51850f188 Mon Sep 17 00:00:00 2001 From: Neeraj Date: Wed, 3 Apr 2019 19:09:45 +0530 Subject: [PATCH 23/71] Fix limit usage in aggregate As per https://stackoverflow.com/a/24161461 --- mongoengine/queryset/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mongoengine/queryset/base.py b/mongoengine/queryset/base.py index 24e12623..c2fb34df 100644 --- a/mongoengine/queryset/base.py +++ b/mongoengine/queryset/base.py @@ -1200,7 +1200,7 @@ class BaseQuerySet(object): initial_pipeline.append({'$sort': dict(self._ordering)}) if self._limit is not None: - initial_pipeline.append({'$limit': self._limit}) + initial_pipeline.append({'$limit': self._limit + self._skip}) if self._skip is not None: initial_pipeline.append({'$skip': self._skip}) From 4ccfdf051d6d7a94406ab7004baaea82278ceb92 Mon Sep 17 00:00:00 2001 From: Neeraj Suthar Date: Sat, 6 Apr 2019 17:23:02 +0530 Subject: [PATCH 24/71] remove fix; add testcases --- mongoengine/queryset/base.py | 2 +- tests/queryset/queryset.py | 179 +++++++++++++++++++++++++++++++++++ 2 files changed, 180 insertions(+), 1 deletion(-) diff --git a/mongoengine/queryset/base.py b/mongoengine/queryset/base.py index c2fb34df..24e12623 100644 --- a/mongoengine/queryset/base.py +++ b/mongoengine/queryset/base.py @@ -1200,7 +1200,7 @@ class BaseQuerySet(object): initial_pipeline.append({'$sort': dict(self._ordering)}) if self._limit is not None: - initial_pipeline.append({'$limit': self._limit + self._skip}) + initial_pipeline.append({'$limit': self._limit}) if self._skip is not None: initial_pipeline.append({'$skip': self._skip}) diff --git a/tests/queryset/queryset.py b/tests/queryset/queryset.py index 46d82203..1a5aa1f2 100644 --- a/tests/queryset/queryset.py +++ b/tests/queryset/queryset.py @@ -5349,6 +5349,185 @@ class QuerySetTest(unittest.TestCase): {'_id': None, 'avg': 29, 'total': 2} ]) + @requires_mongodb_gte_26 + def test_queryset_aggregation_with_skip(self): + class Person(Document): + name = StringField() + age = IntField() + + Person.drop_collection() + + p1 = Person(name="Isabella Luanna", age=16) + p1.save() + + p2 = Person(name="Wilson Junior", age=21) + p2.save() + + p3 = Person(name="Sandra Mara", age=37) + p3.save() + + data = Person.objects.skip(1).aggregate( + {'$project': {'name': {'$toUpper': '$name'}}} + ) + + self.assertEqual(list(data), [ + {'_id': p2.pk, 'name': "WILSON JUNIOR"}, + {'_id': p3.pk, 'name': "SANDRA MARA"} + ]) + + @requires_mongodb_gte_26 + def test_queryset_aggregation_with_limit(self): + class Person(Document): + name = StringField() + age = IntField() + + Person.drop_collection() + + p1 = Person(name="Isabella Luanna", age=16) + p1.save() + + p2 = Person(name="Wilson Junior", age=21) + p2.save() + + p3 = Person(name="Sandra Mara", age=37) + p3.save() + + data = Person.objects.limit(1).aggregate( + {'$project': {'name': {'$toUpper': '$name'}}} + ) + + self.assertEqual(list(data), [ + {'_id': p1.pk, 'name': "ISABELLA LUANNA"} + ]) + + @requires_mongodb_gte_26 + def test_queryset_aggregation_with_sort(self): + class Person(Document): + name = StringField() + age = IntField() + + Person.drop_collection() + + p1 = Person(name="Isabella Luanna", age=16) + p1.save() + + p2 = Person(name="Wilson Junior", age=21) + p2.save() + + p3 = Person(name="Sandra Mara", age=37) + p3.save() + + data = Person.objects.order_by('name').aggregate( + {'$project': {'name': {'$toUpper': '$name'}}} + ) + + self.assertEqual(list(data), [ + {'_id': p1.pk, 'name': "ISABELLA LUANNA"}, + {'_id': p3.pk, 'name': "SANDRA MARA"}, + {'_id': p2.pk, 'name': "WILSON JUNIOR"} + ]) + + @requires_mongodb_gte_26 + def test_queryset_aggregation_with_skip_with_limit(self): + class Person(Document): + name = StringField() + age = IntField() + + Person.drop_collection() + + p1 = Person(name="Isabella Luanna", age=16) + p1.save() + + p2 = Person(name="Wilson Junior", age=21) + p2.save() + + p3 = Person(name="Sandra Mara", age=37) + p3.save() + + data = Person.objects.skip(1).limit(1).aggregate( + {'$project': {'name': {'$toUpper': '$name'}}} + ) + + self.assertEqual(list(data), [ + {'_id': p2.pk, 'name': "WILSON JUNIOR"}, + ]) + + @requires_mongodb_gte_26 + def test_queryset_aggregation_with_sort_with_limit(self): + class Person(Document): + name = StringField() + age = IntField() + + Person.drop_collection() + + p1 = Person(name="Isabella Luanna", age=16) + p1.save() + + p2 = Person(name="Wilson Junior", age=21) + p2.save() + + p3 = Person(name="Sandra Mara", age=37) + p3.save() + + data = Person.objects.order_by('name').limit(2).aggregate( + {'$project': {'name': {'$toUpper': '$name'}}} + ) + + self.assertEqual(list(data), [ + {'_id': p1.pk, 'name': "ISABELLA LUANNA"}, + {'_id': p3.pk, 'name': "SANDRA MARA"} + ]) + + @requires_mongodb_gte_26 + def test_queryset_aggregation_with_sort_with_skip(self): + class Person(Document): + name = StringField() + age = IntField() + + Person.drop_collection() + + p1 = Person(name="Isabella Luanna", age=16) + p1.save() + + p2 = Person(name="Wilson Junior", age=21) + p2.save() + + p3 = Person(name="Sandra Mara", age=37) + p3.save() + + data = Person.objects.order_by('name').skip(2).aggregate( + {'$project': {'name': {'$toUpper': '$name'}}} + ) + + self.assertEqual(list(data), [ + {'_id': p2.pk, 'name': "WILSON JUNIOR"} + ]) + + @requires_mongodb_gte_26 + def test_queryset_aggregation_with_sort_with_skip_with_limit(self): + class Person(Document): + name = StringField() + age = IntField() + + Person.drop_collection() + + p1 = Person(name="Isabella Luanna", age=16) + p1.save() + + p2 = Person(name="Wilson Junior", age=21) + p2.save() + + p3 = Person(name="Sandra Mara", age=37) + p3.save() + + data = Person.objects.order_by('name').skip(1).limit(1).aggregate( + {'$project': {'name': {'$toUpper': '$name'}}} + ) + + self.assertEqual(list(data), [ + {'_id': p3.pk, 'name': "SANDRA MARA"} + ]) + def test_delete_count(self): [self.Person(name="User {0}".format(i), age=i * 10).save() for i in range(1, 4)] self.assertEqual(self.Person.objects().delete(), 3) # test ordinary QuerySey delete count From 61081651e4f1662e205037b3daa4897e402b924a Mon Sep 17 00:00:00 2001 From: Neeraj Suthar Date: Sat, 6 Apr 2019 17:42:14 +0530 Subject: [PATCH 25/71] reinsert fix; add comments, reference --- mongoengine/queryset/base.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/mongoengine/queryset/base.py b/mongoengine/queryset/base.py index 24e12623..66e43514 100644 --- a/mongoengine/queryset/base.py +++ b/mongoengine/queryset/base.py @@ -1200,7 +1200,11 @@ class BaseQuerySet(object): initial_pipeline.append({'$sort': dict(self._ordering)}) if self._limit is not None: - initial_pipeline.append({'$limit': self._limit}) + # As per MongoDB Documentation (https://docs.mongodb.com/manual/reference/operator/aggregation/limit/), + # keeping limit stage right after sort stage is more efficient. But this leads to wrong set of documents + # for a skip stage that might succeed these. So we need to maintain more documents in memory in such a + # case (https://stackoverflow.com/a/24161461). + initial_pipeline.append({'$limit': self._limit + (self._skip or 0)}) if self._skip is not None: initial_pipeline.append({'$skip': self._skip}) From b5213097e887997a247218d9229c0ad3bec074a9 Mon Sep 17 00:00:00 2001 From: Yurii Andrieiev Date: Sun, 7 Apr 2019 02:02:26 +0300 Subject: [PATCH 26/71] Fail fast when db name is invalid Without this commit save operation on first document would fail instead of immediate failure upon connection attempt. Such later failure is much less obvious. --- AUTHORS | 1 + docs/changelog.rst | 1 + mongoengine/connection.py | 12 ++++++++++++ tests/test_connection.py | 32 +++++++++++++++++++++++++++++++- 4 files changed, 45 insertions(+), 1 deletion(-) diff --git a/AUTHORS b/AUTHORS index 21b0ec64..45a754cc 100644 --- a/AUTHORS +++ b/AUTHORS @@ -251,3 +251,4 @@ that much better: * Gleb Voropaev (https://github.com/buggyspace) * Paulo Amaral (https://github.com/pauloAmaral) * Gaurav Dadhania (https://github.com/GVRV) + * Yurii Andrieiev (https://github.com/yandrieiev) diff --git a/docs/changelog.rst b/docs/changelog.rst index 53373302..707182f1 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -6,6 +6,7 @@ Development =========== - mongoengine now requires pymongo>=3.5 #2017 - Generate Unique Indices for SortedListField and EmbeddedDocumentListFields #2020 +- connect() fails immediately when db name contains invalid characters (e. g. when user mistakenly puts 'mongodb://127.0.0.1:27017' as db name, happened in #1718) or is if db name is of an invalid type - (Fill this out as you fix issues and develop your features). Changes in 0.17.0 diff --git a/mongoengine/connection.py b/mongoengine/connection.py index c0cfde31..dda9bbb7 100644 --- a/mongoengine/connection.py +++ b/mongoengine/connection.py @@ -1,4 +1,5 @@ from pymongo import MongoClient, ReadPreference, uri_parser +from pymongo.database import _check_name import six from mongoengine.pymongo_support import IS_PYMONGO_3 @@ -28,6 +29,16 @@ _connections = {} _dbs = {} +def check_db_name(name): + """Check if a database name is valid. + This functionality is copied from pymongo Database class constructor. + """ + if not isinstance(name, six.string_types): + raise TypeError('name must be an instance of %s' % six.string_types) + elif name != '$external': + _check_name(name) + + def register_connection(alias, db=None, name=None, host=None, port=None, read_preference=READ_PREFERENCE, username=None, password=None, @@ -69,6 +80,7 @@ def register_connection(alias, db=None, name=None, host=None, port=None, 'authentication_mechanism': authentication_mechanism } + check_db_name(conn_settings['name']) conn_host = conn_settings['host'] # Host can be a list or a string, so if string, force to a list. diff --git a/tests/test_connection.py b/tests/test_connection.py index 0a7271df..fb2a20d7 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1,5 +1,5 @@ import datetime -from pymongo.errors import OperationFailure +from pymongo.errors import OperationFailure, InvalidName try: import unittest2 as unittest @@ -49,6 +49,36 @@ class ConnectionTest(unittest.TestCase): conn = get_connection('testdb') self.assertIsInstance(conn, pymongo.mongo_client.MongoClient) + def test_connect_with_invalid_db_name(self): + """Ensure that connect() method fails fast if db name is invalid + """ + with self.assertRaises(InvalidName): + connect('mongomock://localhost') + + def test_connect_with_db_name_external(self): + """Ensure that connect() works if db name is $external + """ + """Ensure that the connect() method works properly.""" + connect('$external') + + conn = get_connection() + self.assertIsInstance(conn, pymongo.mongo_client.MongoClient) + + db = get_db() + self.assertIsInstance(db, pymongo.database.Database) + self.assertEqual(db.name, '$external') + + connect('$external', alias='testdb') + conn = get_connection('testdb') + self.assertIsInstance(conn, pymongo.mongo_client.MongoClient) + + def test_connect_with_invalid_db_name_type(self): + """Ensure that connect() method fails fast if db name has invalid type + """ + with self.assertRaises(TypeError): + non_string_db_name = ['e. g. list instead of a string'] + connect(non_string_db_name) + def test_connect_in_mocking(self): """Ensure that the connect() method works properly in mocking. """ From 9bb3dfd6392f554dea5716df5cc5b49df7e53cc3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Sun, 7 Apr 2019 23:05:55 +0200 Subject: [PATCH 27/71] updated changelog for recent commits + improve tests --- docs/changelog.rst | 1 + tests/queryset/queryset.py | 79 ++++++++++++++++++-------------------- 2 files changed, 39 insertions(+), 41 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 53373302..5a472eb5 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -4,6 +4,7 @@ Changelog Development =========== +- POTENTIAL BREAKING CHANGE: Aggregate gives wrong results when used with a queryset having limit and skip #2029 - mongoengine now requires pymongo>=3.5 #2017 - Generate Unique Indices for SortedListField and EmbeddedDocumentListFields #2020 - (Fill this out as you fix issues and develop your features). diff --git a/tests/queryset/queryset.py b/tests/queryset/queryset.py index 1a5aa1f2..31b1641e 100644 --- a/tests/queryset/queryset.py +++ b/tests/queryset/queryset.py @@ -5312,13 +5312,9 @@ class QuerySetTest(unittest.TestCase): Person.drop_collection() p1 = Person(name="Isabella Luanna", age=16) - p1.save() - p2 = Person(name="Wilson Junior", age=21) - p2.save() - p3 = Person(name="Sandra Mara", age=37) - p3.save() + Person.objects.insert([p1, p2, p3]) data = Person.objects(age__lte=22).aggregate( {'$project': {'name': {'$toUpper': '$name'}}} @@ -5358,13 +5354,9 @@ class QuerySetTest(unittest.TestCase): Person.drop_collection() p1 = Person(name="Isabella Luanna", age=16) - p1.save() - p2 = Person(name="Wilson Junior", age=21) - p2.save() - p3 = Person(name="Sandra Mara", age=37) - p3.save() + Person.objects.insert([p1, p2, p3]) data = Person.objects.skip(1).aggregate( {'$project': {'name': {'$toUpper': '$name'}}} @@ -5384,13 +5376,9 @@ class QuerySetTest(unittest.TestCase): Person.drop_collection() p1 = Person(name="Isabella Luanna", age=16) - p1.save() - p2 = Person(name="Wilson Junior", age=21) - p2.save() - p3 = Person(name="Sandra Mara", age=37) - p3.save() + Person.objects.insert([p1, p2, p3]) data = Person.objects.limit(1).aggregate( {'$project': {'name': {'$toUpper': '$name'}}} @@ -5409,13 +5397,9 @@ class QuerySetTest(unittest.TestCase): Person.drop_collection() p1 = Person(name="Isabella Luanna", age=16) - p1.save() - p2 = Person(name="Wilson Junior", age=21) - p2.save() - p3 = Person(name="Sandra Mara", age=37) - p3.save() + Person.objects.insert([p1, p2, p3]) data = Person.objects.order_by('name').aggregate( {'$project': {'name': {'$toUpper': '$name'}}} @@ -5436,22 +5420,27 @@ class QuerySetTest(unittest.TestCase): Person.drop_collection() p1 = Person(name="Isabella Luanna", age=16) - p1.save() - p2 = Person(name="Wilson Junior", age=21) - p2.save() - p3 = Person(name="Sandra Mara", age=37) - p3.save() + Person.objects.insert([p1, p2, p3]) - data = Person.objects.skip(1).limit(1).aggregate( + data = list( + Person.objects.skip(1).limit(1).aggregate( {'$project': {'name': {'$toUpper': '$name'}}} + ) ) self.assertEqual(list(data), [ {'_id': p2.pk, 'name': "WILSON JUNIOR"}, ]) + # Make sure limit/skip chaining order has no impact + data2 = Person.objects.limit(1).skip(1).aggregate( + {'$project': {'name': {'$toUpper': '$name'}}} + ) + + self.assertEqual(data, list(data2)) + @requires_mongodb_gte_26 def test_queryset_aggregation_with_sort_with_limit(self): class Person(Document): @@ -5461,13 +5450,9 @@ class QuerySetTest(unittest.TestCase): Person.drop_collection() p1 = Person(name="Isabella Luanna", age=16) - p1.save() - p2 = Person(name="Wilson Junior", age=21) - p2.save() - p3 = Person(name="Sandra Mara", age=37) - p3.save() + Person.objects.insert([p1, p2, p3]) data = Person.objects.order_by('name').limit(2).aggregate( {'$project': {'name': {'$toUpper': '$name'}}} @@ -5478,6 +5463,26 @@ class QuerySetTest(unittest.TestCase): {'_id': p3.pk, 'name': "SANDRA MARA"} ]) + # Verify adding limit/skip steps works as expected + data = Person.objects.order_by('name').limit(2).aggregate( + {'$project': {'name': {'$toUpper': '$name'}}}, + {'$limit': 1}, + ) + + self.assertEqual(list(data), [ + {'_id': p1.pk, 'name': "ISABELLA LUANNA"}, + ]) + + data = Person.objects.order_by('name').limit(2).aggregate( + {'$project': {'name': {'$toUpper': '$name'}}}, + {'$skip': 1}, + {'$limit': 1}, + ) + + self.assertEqual(list(data), [ + {'_id': p3.pk, 'name': "SANDRA MARA"}, + ]) + @requires_mongodb_gte_26 def test_queryset_aggregation_with_sort_with_skip(self): class Person(Document): @@ -5487,13 +5492,9 @@ class QuerySetTest(unittest.TestCase): Person.drop_collection() p1 = Person(name="Isabella Luanna", age=16) - p1.save() - p2 = Person(name="Wilson Junior", age=21) - p2.save() - p3 = Person(name="Sandra Mara", age=37) - p3.save() + Person.objects.insert([p1, p2, p3]) data = Person.objects.order_by('name').skip(2).aggregate( {'$project': {'name': {'$toUpper': '$name'}}} @@ -5512,13 +5513,9 @@ class QuerySetTest(unittest.TestCase): Person.drop_collection() p1 = Person(name="Isabella Luanna", age=16) - p1.save() - p2 = Person(name="Wilson Junior", age=21) - p2.save() - p3 = Person(name="Sandra Mara", age=37) - p3.save() + Person.objects.insert([p1, p2, p3]) data = Person.objects.order_by('name').skip(1).limit(1).aggregate( {'$project': {'name': {'$toUpper': '$name'}}} From d1467c2f73dc29a16058a06cb10185f04bd2d5cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Mon, 15 Apr 2019 21:24:07 +0200 Subject: [PATCH 28/71] Fix connect/disconnect functions - expose disconnect - disconnect cleans _connection_settings - disconnect cleans cached collection in Document._collection - re-connecting with the same alias raise an error (must call disconnect in between) --- docs/changelog.rst | 5 + mongoengine/base/common.py | 11 +- mongoengine/connection.py | 93 +++++++++++++--- mongoengine/document.py | 12 ++- tests/test_connection.py | 213 +++++++++++++++++++++++++++++++++++-- tests/utils.py | 4 +- 6 files changed, 312 insertions(+), 26 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 5a472eb5..6bd19b12 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -4,6 +4,11 @@ Changelog Development =========== +- expose `mongoengine.connection.disconnect` and `mongoengine.connection.disconnect_all` +- POTENTIAL BREAKING CHANGE: Fixes in connect/disconnect methods + - calling `connect` 2 times with the same alias and different parameter will raise an error (should call disconnect first) + - disconnect now clears `mongoengine.connection._connection_settings` + - disconnect now clears the cached attribute `Document._collection` - POTENTIAL BREAKING CHANGE: Aggregate gives wrong results when used with a queryset having limit and skip #2029 - mongoengine now requires pymongo>=3.5 #2017 - Generate Unique Indices for SortedListField and EmbeddedDocumentListFields #2020 diff --git a/mongoengine/base/common.py b/mongoengine/base/common.py index d747c8cc..999fd23a 100644 --- a/mongoengine/base/common.py +++ b/mongoengine/base/common.py @@ -13,7 +13,7 @@ _document_registry = {} def get_document(name): - """Get a document class by name.""" + """Get a registered Document class by name.""" doc = _document_registry.get(name, None) if not doc: # Possible old style name @@ -30,3 +30,12 @@ def get_document(name): been imported? """.strip() % name) return doc + + +def _get_documents_by_db(connection_alias, default_connection_alias): + """Get all registered Documents class attached to a given database""" + def get_doc_alias(doc_cls): + return doc_cls._meta.get('db_alias', default_connection_alias) + + return [doc_cls for doc_cls in _document_registry.values() + if get_doc_alias(doc_cls) == connection_alias] diff --git a/mongoengine/connection.py b/mongoengine/connection.py index c0cfde31..13541bd4 100644 --- a/mongoengine/connection.py +++ b/mongoengine/connection.py @@ -3,11 +3,13 @@ import six from mongoengine.pymongo_support import IS_PYMONGO_3 -__all__ = ['MongoEngineConnectionError', 'connect', 'register_connection', - 'DEFAULT_CONNECTION_NAME', 'get_db'] +__all__ = ['MongoEngineConnectionError', 'connect', 'disconnect', 'disconnect_all', + 'register_connection', 'DEFAULT_CONNECTION_NAME', 'DEFAULT_DATABASE_NAME', + 'get_db', 'get_connection'] DEFAULT_CONNECTION_NAME = 'default' +DEFAULT_DATABASE_NAME = 'test' if IS_PYMONGO_3: READ_PREFERENCE = ReadPreference.PRIMARY @@ -28,18 +30,17 @@ _connections = {} _dbs = {} -def register_connection(alias, db=None, name=None, host=None, port=None, - read_preference=READ_PREFERENCE, - username=None, password=None, - authentication_source=None, - authentication_mechanism=None, - **kwargs): - """Add a connection. +def _get_connection_settings( + db=None, name=None, host=None, port=None, + read_preference=READ_PREFERENCE, + username=None, password=None, + authentication_source=None, + authentication_mechanism=None, + **kwargs): + """Get the connection settings as a dict - :param alias: the name that will be used to refer to this connection - throughout MongoEngine - :param name: the name of the specific database to use :param db: the name of the database to use, for compatibility with connect + :param name: the name of the specific database to use :param host: the host name of the :program:`mongod` instance to connect to :param port: the port that the :program:`mongod` instance is running on :param read_preference: The read preference for the collection @@ -59,7 +60,7 @@ def register_connection(alias, db=None, name=None, host=None, port=None, .. versionchanged:: 0.10.6 - added mongomock support """ conn_settings = { - 'name': name or db or 'test', + 'name': name or db or DEFAULT_DATABASE_NAME, 'host': host or 'localhost', 'port': port or 27017, 'read_preference': read_preference, @@ -125,17 +126,74 @@ def register_connection(alias, db=None, name=None, host=None, port=None, kwargs.pop('is_slave', None) conn_settings.update(kwargs) + return conn_settings + + +def register_connection(alias, db=None, name=None, host=None, port=None, + read_preference=READ_PREFERENCE, + username=None, password=None, + authentication_source=None, + authentication_mechanism=None, + **kwargs): + """Register the connection settings. + + :param alias: the name that will be used to refer to this connection + throughout MongoEngine + :param name: the name of the specific database to use + :param db: the name of the database to use, for compatibility with connect + :param host: the host name of the :program:`mongod` instance to connect to + :param port: the port that the :program:`mongod` instance is running on + :param read_preference: The read preference for the collection + ** Added pymongo 2.1 + :param username: username to authenticate with + :param password: password to authenticate with + :param authentication_source: database to authenticate against + :param authentication_mechanism: database authentication mechanisms. + By default, use SCRAM-SHA-1 with MongoDB 3.0 and later, + MONGODB-CR (MongoDB Challenge Response protocol) for older servers. + :param is_mock: explicitly use mongomock for this connection + (can also be done by using `mongomock://` as db host prefix) + :param kwargs: ad-hoc parameters to be passed into the pymongo driver, + for example maxpoolsize, tz_aware, etc. See the documentation + for pymongo's `MongoClient` for a full list. + + .. versionchanged:: 0.10.6 - added mongomock support + """ + conn_settings = _get_connection_settings( + db=db, name=name, host=host, port=port, + read_preference=read_preference, + username=username, password=password, + authentication_source=authentication_source, + authentication_mechanism=authentication_mechanism, + **kwargs) _connection_settings[alias] = conn_settings def disconnect(alias=DEFAULT_CONNECTION_NAME): """Close the connection with a given alias.""" + from mongoengine.base.common import _get_documents_by_db + if alias in _connections: get_connection(alias=alias).close() del _connections[alias] + if alias in _dbs: + # Detach all cached collections in Documents + for doc_cls in _get_documents_by_db(alias, DEFAULT_CONNECTION_NAME): + if hasattr(doc_cls, '_disconnect'): + doc_cls._disconnect() + del _dbs[alias] + if alias in _connection_settings: + del _connection_settings[alias] + + +def disconnect_all(): + """Close all registered database.""" + for alias in list(_connections.keys()): + disconnect(alias) + def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): """Return a connection with a given alias.""" @@ -265,7 +323,14 @@ def connect(db=None, alias=DEFAULT_CONNECTION_NAME, **kwargs): .. versionchanged:: 0.6 - added multiple database support. """ - if alias not in _connections: + if alias in _connections: + prev_conn_setting = _connection_settings[alias] + new_conn_settings = _get_connection_settings(db, **kwargs) + + if new_conn_settings != prev_conn_setting: + raise MongoEngineConnectionError( + 'A different connection with alias `%s` was already registered. Use disconnect() first' % alias) + else: register_connection(alias, db, **kwargs) return get_connection(alias) diff --git a/mongoengine/document.py b/mongoengine/document.py index 328ac299..fd953340 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -188,10 +188,16 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): return get_db(cls._meta.get('db_alias', DEFAULT_CONNECTION_NAME)) @classmethod - def _get_collection(cls): - """Return a PyMongo collection for the document.""" - if not hasattr(cls, '_collection') or cls._collection is None: + def _disconnect(cls): + """Detach the Document class from the (cached) database collection""" + cls._collection = None + @classmethod + def _get_collection(cls): + """Return the corresponding PyMongo collection of this document. + Upon the first call, it will ensure that indexes gets created. The returned collection then gets cached + """ + if not hasattr(cls, '_collection') or cls._collection is None: # Get the collection, either capped or regular. if cls._meta.get('max_size') or cls._meta.get('max_documents'): cls._collection = cls._get_capped_collection() diff --git a/tests/test_connection.py b/tests/test_connection.py index 0a7271df..4bab23c6 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -12,12 +12,12 @@ from bson.tz_util import utc from mongoengine import ( connect, register_connection, - Document, DateTimeField -) + Document, DateTimeField, + disconnect_all, StringField) from mongoengine.pymongo_support import IS_PYMONGO_3 import mongoengine.connection from mongoengine.connection import (MongoEngineConnectionError, get_db, - get_connection) + get_connection, disconnect, DEFAULT_DATABASE_NAME) def get_tz_awareness(connection): @@ -29,6 +29,14 @@ def get_tz_awareness(connection): class ConnectionTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + disconnect_all() + + @classmethod + def tearDownClass(cls): + disconnect_all() + def tearDown(self): mongoengine.connection._connection_settings = {} mongoengine.connection._connections = {} @@ -49,6 +57,117 @@ class ConnectionTest(unittest.TestCase): conn = get_connection('testdb') self.assertIsInstance(conn, pymongo.mongo_client.MongoClient) + def test_connect_disconnect_works_properly(self): + class History1(Document): + name = StringField() + meta = {'db_alias': 'db1'} + + class History2(Document): + name = StringField() + meta = {'db_alias': 'db2'} + + connect('db1', alias='db1') + connect('db2', alias='db2') + + History1.drop_collection() + History2.drop_collection() + + h = History1(name='default').save() + h1 = History2(name='db1').save() + + self.assertEqual(list(History1.objects().as_pymongo()), + [{'_id': h.id, 'name': 'default'}]) + self.assertEqual(list(History2.objects().as_pymongo()), + [{'_id': h1.id, 'name': 'db1'}]) + + disconnect('db1') + disconnect('db2') + + with self.assertRaises(MongoEngineConnectionError): + list(History1.objects().as_pymongo()) + + with self.assertRaises(MongoEngineConnectionError): + list(History2.objects().as_pymongo()) + + connect('db1', alias='db1') + connect('db2', alias='db2') + + self.assertEqual(list(History1.objects().as_pymongo()), + [{'_id': h.id, 'name': 'default'}]) + self.assertEqual(list(History2.objects().as_pymongo()), + [{'_id': h1.id, 'name': 'db1'}]) + + def test_connect_different_documents_to_different_database(self): + class History(Document): + name = StringField() + + class History1(Document): + name = StringField() + meta = {'db_alias': 'db1'} + + class History2(Document): + name = StringField() + meta = {'db_alias': 'db2'} + + connect() + connect('db1', alias='db1') + connect('db2', alias='db2') + + History.drop_collection() + History1.drop_collection() + History2.drop_collection() + + h = History(name='default').save() + h1 = History1(name='db1').save() + h2 = History2(name='db2').save() + + self.assertEqual(History._collection.database.name, DEFAULT_DATABASE_NAME) + self.assertEqual(History1._collection.database.name, 'db1') + self.assertEqual(History2._collection.database.name, 'db2') + + self.assertEqual(list(History.objects().as_pymongo()), + [{'_id': h.id, 'name': 'default'}]) + self.assertEqual(list(History1.objects().as_pymongo()), + [{'_id': h1.id, 'name': 'db1'}]) + self.assertEqual(list(History2.objects().as_pymongo()), + [{'_id': h2.id, 'name': 'db2'}]) + + def test_connect_fails_if_connect_2_times_with_default_alias(self): + connect('mongoenginetest') + + with self.assertRaises(MongoEngineConnectionError) as ctx_err: + connect('mongoenginetest2') + self.assertEqual("A different connection with alias `default` was already registered. Use disconnect() first", str(ctx_err.exception)) + + def test_connect_fails_if_connect_2_times_with_custom_alias(self): + connect('mongoenginetest', alias='alias1') + + with self.assertRaises(MongoEngineConnectionError) as ctx_err: + connect('mongoenginetest2', alias='alias1') + + self.assertEqual("A different connection with alias `alias1` was already registered. Use disconnect() first", str(ctx_err.exception)) + + def test_connect_fails_if_similar_connection_settings_arent_defined_the_same_way(self): + """Intended to keep the detecton function simple but robust""" + db_name = 'mongoenginetest' + db_alias = 'alias1' + connect(db=db_name, alias=db_alias, host='localhost', port=27017) + + with self.assertRaises(MongoEngineConnectionError): + connect(host='mongodb://localhost:27017/%s' % db_name, alias=db_alias) + + def test_connect_passes_silently_connect_multiple_times_with_same_config(self): + # test default connection to `test` + connect() + connect() + self.assertEqual(len(mongoengine.connection._connections), 1) + connect('test01', alias='test01') + connect('test01', alias='test01') + self.assertEqual(len(mongoengine.connection._connections), 2) + connect(host='mongodb://localhost:27017/mongoenginetest02', alias='test02') + connect(host='mongodb://localhost:27017/mongoenginetest02', alias='test02') + self.assertEqual(len(mongoengine.connection._connections), 3) + def test_connect_in_mocking(self): """Ensure that the connect() method works properly in mocking. """ @@ -120,13 +239,93 @@ class ConnectionTest(unittest.TestCase): self.assertIsInstance(conn, mongomock.MongoClient) def test_disconnect(self): - """Ensure that the disconnect() method works properly - """ + """Ensure that the disconnect() method works properly""" + connections = mongoengine.connection._connections + dbs = mongoengine.connection._dbs + connection_settings = mongoengine.connection._connection_settings + conn1 = connect('mongoenginetest') - mongoengine.connection.disconnect() + + class History(Document): + pass + + self.assertIsNone(History._collection) + + History.drop_collection() + History.objects.first() # will trigger the caching of _collection attribute + + self.assertIsNotNone(History._collection) + + self.assertEqual(len(connections), 1) + self.assertEqual(len(dbs), 1) + self.assertEqual(len(connection_settings), 1) + + disconnect() + + self.assertIsNone(History._collection) + + self.assertEqual(len(connections), 0) + self.assertEqual(len(dbs), 0) + self.assertEqual(len(connection_settings), 0) + + with self.assertRaises(MongoEngineConnectionError) as ctx_err: + History.objects.first() + self.assertEqual("You have not defined a default connection", str(ctx_err.exception)) + conn2 = connect('mongoenginetest') + History.objects.first() # Make sure its back on track self.assertTrue(conn1 is not conn2) + def test_disconnect_silently_pass_if_alias_does_not_exist(self): + connections = mongoengine.connection._connections + self.assertEqual(len(connections), 0) + disconnect(alias='not_exist') + + def test_disconnect_all(self): + connections = mongoengine.connection._connections + dbs = mongoengine.connection._dbs + connection_settings = mongoengine.connection._connection_settings + + connect('mongoenginetest') + connect('mongoenginetest2', alias='db1') + + class History(Document): + pass + + class History1(Document): + name = StringField() + meta = {'db_alias': 'db1'} + + History.drop_collection() # will trigger the caching of _collection attribute + History.objects.first() + History1.drop_collection() + History1.objects.first() + + self.assertIsNotNone(History._collection) + self.assertIsNotNone(History1._collection) + + self.assertEqual(len(connections), 2) + self.assertEqual(len(dbs), 2) + self.assertEqual(len(connection_settings), 2) + + disconnect_all() + + self.assertIsNone(History._collection) + self.assertIsNone(History1._collection) + + self.assertEqual(len(connections), 0) + self.assertEqual(len(dbs), 0) + self.assertEqual(len(connection_settings), 0) + + with self.assertRaises(MongoEngineConnectionError): + History.objects.first() + + with self.assertRaises(MongoEngineConnectionError): + History1.objects.first() + + def test_disconnect_all_silently_pass_if_no_connection_exist(self): + disconnect_all() + def test_sharing_connections(self): """Ensure that connections are shared when the connection settings are exactly the same """ @@ -342,7 +541,7 @@ class ConnectionTest(unittest.TestCase): with self.assertRaises(MongoEngineConnectionError): c = connect(replicaset='local-rs') - def test_datetime(self): + def test_connect_tz_aware(self): connect('mongoenginetest', tz_aware=True) d = datetime.datetime(2010, 5, 5, tzinfo=utc) diff --git a/tests/utils.py b/tests/utils.py index 3c41f07d..910601b1 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -4,7 +4,7 @@ import unittest from nose.plugins.skip import SkipTest from mongoengine import connect -from mongoengine.connection import get_db +from mongoengine.connection import get_db, disconnect_all from mongoengine.mongodb_support import get_mongodb_version, MONGODB_26, MONGODB_3, MONGODB_32, MONGODB_34 from mongoengine.pymongo_support import IS_PYMONGO_3 @@ -19,6 +19,7 @@ class MongoDBTestCase(unittest.TestCase): @classmethod def setUpClass(cls): + disconnect_all() cls._connection = connect(db=MONGO_TEST_DB) cls._connection.drop_database(MONGO_TEST_DB) cls.db = get_db() @@ -26,6 +27,7 @@ class MongoDBTestCase(unittest.TestCase): @classmethod def tearDownClass(cls): cls._connection.drop_database(MONGO_TEST_DB) + disconnect_all() def get_as_pymongo(doc): From b1e28d02f7f16a3aa0eb4b718377ddff8e9fc6f1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Wed, 24 Apr 2019 22:44:07 +0200 Subject: [PATCH 29/71] Improve connect/disconnect - document disconnect + sample of usage - add more test cases to prevent github issues regressions --- docs/changelog.rst | 2 +- docs/guide/connecting.rst | 42 ++++++++++++++++++++++--- mongoengine/connection.py | 3 ++ mongoengine/document.py | 6 ++-- tests/test_connection.py | 64 ++++++++++++++++++++++++++++++--------- 5 files changed, 95 insertions(+), 22 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 6bd19b12..bf3bba24 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -5,7 +5,7 @@ Changelog Development =========== - expose `mongoengine.connection.disconnect` and `mongoengine.connection.disconnect_all` -- POTENTIAL BREAKING CHANGE: Fixes in connect/disconnect methods +- POTENTIAL BREAKING CHANGE: Fixes in connect/disconnect methods #565 #566 - calling `connect` 2 times with the same alias and different parameter will raise an error (should call disconnect first) - disconnect now clears `mongoengine.connection._connection_settings` - disconnect now clears the cached attribute `Document._collection` diff --git a/docs/guide/connecting.rst b/docs/guide/connecting.rst index 5dac6ae9..1107ee3a 100644 --- a/docs/guide/connecting.rst +++ b/docs/guide/connecting.rst @@ -4,9 +4,11 @@ Connecting to MongoDB ===================== -To connect to a running instance of :program:`mongod`, use the -:func:`~mongoengine.connect` function. The first argument is the name of the -database to connect to:: +Connections in MongoEngine are registered globally and are identified with aliases. +If no `alias` is provided during the connection, it will use "default" as alias. + +To connect to a running instance of :program:`mongod`, use the :func:`~mongoengine.connect` +function. The first argument is the name of the database to connect to:: from mongoengine import connect connect('project1') @@ -42,6 +44,9 @@ the :attr:`host` to will establish connection to ``production`` database using ``admin`` username and ``qwerty`` password. +.. note:: Calling :func:`~mongoengine.connect` without argument will establish + a connection to the "test" database by default + Replica Sets ============ @@ -71,6 +76,8 @@ is used. In the background this uses :func:`~mongoengine.register_connection` to store the data and you can register all aliases up front if required. +Documents defined in different database +--------------------------------------- Individual documents can also support multiple databases by providing a `db_alias` in their meta data. This allows :class:`~pymongo.dbref.DBRef` objects to point across databases and collections. Below is an example schema, @@ -93,6 +100,33 @@ using 3 different databases to store data:: meta = {'db_alias': 'users-books-db'} +Disconnecting an existing connection +------------------------------------ +The function :func:`~mongoengine.disconnect` can be used to +disconnect a particular connection. This can be used to change a +connection globally:: + + from mongoengine import connect, disconnect + connect('a_db', alias='db1') + + class User(Document): + name = StringField() + meta = {'db_alias': 'db1'} + + disconnect(alias='db1') + + connect('another_db', alias='db1') + +.. note:: Calling :func:`~mongoengine.disconnect` without argument + will disconnect the "default" connection + +.. note:: Since connections gets registered globally, it is important + to use the `disconnect` function from MongoEngine and not the + `disconnect()` method of an existing connection (pymongo.MongoClient) + +.. note:: :class:`~mongoengine.Document` are caching the pymongo collection. + using `disconnect` ensures that it gets cleaned as well + Context Managers ================ Sometimes you may want to switch the database or collection to query against. @@ -119,7 +153,7 @@ access to the same User document across databases:: Switch Collection ----------------- -The :class:`~mongoengine.context_managers.switch_collection` context manager +The :func:`~mongoengine.context_managers.switch_collection` context manager allows you to change the collection for a given class allowing quick and easy access to the same Group document across collection:: diff --git a/mongoengine/connection.py b/mongoengine/connection.py index 13541bd4..0597199b 100644 --- a/mongoengine/connection.py +++ b/mongoengine/connection.py @@ -318,6 +318,9 @@ def connect(db=None, alias=DEFAULT_CONNECTION_NAME, **kwargs): Multiple databases are supported by using aliases. Provide a separate `alias` to connect to a different instance of :program:`mongod`. + In order to replace a connection identified by a given alias, you'll + need to call ``disconnect`` first + See the docstring for `register_connection` for more details about all supported kwargs. diff --git a/mongoengine/document.py b/mongoengine/document.py index fd953340..753520c7 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -795,13 +795,13 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): .. versionchanged:: 0.10.7 :class:`OperationError` exception raised if no collection available """ - col_name = cls._get_collection_name() - if not col_name: + coll_name = cls._get_collection_name() + if not coll_name: raise OperationError('Document %s has no collection defined ' '(is it abstract ?)' % cls) cls._collection = None db = cls._get_db() - db.drop_collection(col_name) + db.drop_collection(coll_name) @classmethod def create_index(cls, keys, background=False, **kwargs): diff --git a/tests/test_connection.py b/tests/test_connection.py index 4bab23c6..827829b5 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1,4 +1,6 @@ import datetime + +from pymongo import MongoClient from pymongo.errors import OperationFailure try: @@ -238,12 +240,25 @@ class ConnectionTest(unittest.TestCase): conn = get_connection('testdb6') self.assertIsInstance(conn, mongomock.MongoClient) - def test_disconnect(self): - """Ensure that the disconnect() method works properly""" + def test_disconnect_cleans_globals(self): + """Ensure that the disconnect() method cleans the globals objects""" connections = mongoengine.connection._connections dbs = mongoengine.connection._dbs connection_settings = mongoengine.connection._connection_settings + connect('mongoenginetest') + + self.assertEqual(len(connections), 1) + self.assertEqual(len(dbs), 1) + self.assertEqual(len(connection_settings), 1) + + disconnect() + self.assertEqual(len(connections), 0) + self.assertEqual(len(dbs), 0) + self.assertEqual(len(connection_settings), 0) + + def test_disconnect_cleans_cached_collection_attribute_in_document(self): + """Ensure that the disconnect() method works properly""" conn1 = connect('mongoenginetest') class History(Document): @@ -252,29 +267,50 @@ class ConnectionTest(unittest.TestCase): self.assertIsNone(History._collection) History.drop_collection() + History.objects.first() # will trigger the caching of _collection attribute - self.assertIsNotNone(History._collection) - self.assertEqual(len(connections), 1) - self.assertEqual(len(dbs), 1) - self.assertEqual(len(connection_settings), 1) - disconnect() self.assertIsNone(History._collection) - self.assertEqual(len(connections), 0) - self.assertEqual(len(dbs), 0) - self.assertEqual(len(connection_settings), 0) - with self.assertRaises(MongoEngineConnectionError) as ctx_err: History.objects.first() self.assertEqual("You have not defined a default connection", str(ctx_err.exception)) - conn2 = connect('mongoenginetest') - History.objects.first() # Make sure its back on track - self.assertTrue(conn1 is not conn2) + def test_connect_disconnect_works_on_same_document(self): + """Ensure that the connect/disconnect works properly with a single Document""" + db1 = 'db1' + db2 = 'db2' + + # Ensure freshness of the 2 databases through pymongo + client = MongoClient('localhost', 27017) + client.drop_database(db1) + client.drop_database(db2) + + # Save in db1 + connect(db1) + + class User(Document): + name = StringField(required=True) + + user1 = User(name='John is in db1').save() + disconnect() + + # Make sure save doesnt work at this stage + with self.assertRaises(MongoEngineConnectionError): + User(name='Wont work').save() + + # Save in db2 + connect(db2) + user2 = User(name='Bob is in db2').save() + disconnect() + + db1_users = list(client[db1].user.find()) + self.assertEqual(db1_users, [{'_id': user1.id, 'name': 'John is in db1'}]) + db2_users = list(client[db2].user.find()) + self.assertEqual(db2_users, [{'_id': user2.id, 'name': 'Bob is in db2'}]) def test_disconnect_silently_pass_if_alias_does_not_exist(self): connections = mongoengine.connection._connections From 565e1dc0ed193b2b9e9ff55276235a8530b160a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Thu, 25 Apr 2019 22:11:43 +0200 Subject: [PATCH 30/71] minor improvements --- docs/changelog.rst | 2 +- mongoengine/connection.py | 9 ++++++--- tests/test_connection.py | 8 +++++++- tests/test_context_managers.py | 7 ++++--- 4 files changed, 18 insertions(+), 8 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index bf3bba24..dfed5f59 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -5,7 +5,7 @@ Changelog Development =========== - expose `mongoengine.connection.disconnect` and `mongoengine.connection.disconnect_all` -- POTENTIAL BREAKING CHANGE: Fixes in connect/disconnect methods #565 #566 +- POTENTIAL BREAKING CHANGES: Fixes in connect/disconnect methods #565 #566 #605 #607 #1213 #1599 - calling `connect` 2 times with the same alias and different parameter will raise an error (should call disconnect first) - disconnect now clears `mongoengine.connection._connection_settings` - disconnect now clears the cached attribute `Document._collection` diff --git a/mongoengine/connection.py b/mongoengine/connection.py index 0597199b..8902bbf6 100644 --- a/mongoengine/connection.py +++ b/mongoengine/connection.py @@ -10,6 +10,8 @@ __all__ = ['MongoEngineConnectionError', 'connect', 'disconnect', 'disconnect_al DEFAULT_CONNECTION_NAME = 'default' DEFAULT_DATABASE_NAME = 'test' +DEFAULT_HOST = 'localhost' +DEFAULT_PORT = 27017 if IS_PYMONGO_3: READ_PREFERENCE = ReadPreference.PRIMARY @@ -61,8 +63,8 @@ def _get_connection_settings( """ conn_settings = { 'name': name or db or DEFAULT_DATABASE_NAME, - 'host': host or 'localhost', - 'port': port or 27017, + 'host': host or DEFAULT_HOST, + 'port': port or DEFAULT_PORT, 'read_preference': read_preference, 'username': username, 'password': password, @@ -172,6 +174,7 @@ def register_connection(alias, db=None, name=None, host=None, port=None, def disconnect(alias=DEFAULT_CONNECTION_NAME): """Close the connection with a given alias.""" from mongoengine.base.common import _get_documents_by_db + from mongoengine import Document if alias in _connections: get_connection(alias=alias).close() @@ -180,7 +183,7 @@ def disconnect(alias=DEFAULT_CONNECTION_NAME): if alias in _dbs: # Detach all cached collections in Documents for doc_cls in _get_documents_by_db(alias, DEFAULT_CONNECTION_NAME): - if hasattr(doc_cls, '_disconnect'): + if issubclass(doc_cls, Document): # Skip EmbeddedDocument doc_cls._disconnect() del _dbs[alias] diff --git a/tests/test_connection.py b/tests/test_connection.py index 827829b5..5ff22e06 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -249,9 +249,15 @@ class ConnectionTest(unittest.TestCase): connect('mongoenginetest') self.assertEqual(len(connections), 1) - self.assertEqual(len(dbs), 1) + self.assertEqual(len(dbs), 0) self.assertEqual(len(connection_settings), 1) + class TestDoc(Document): + pass + + TestDoc.drop_collection() # triggers the db + self.assertEqual(len(dbs), 1) + disconnect() self.assertEqual(len(connections), 0) self.assertEqual(len(dbs), 0) diff --git a/tests/test_context_managers.py b/tests/test_context_managers.py index 227031e0..22c33b01 100644 --- a/tests/test_context_managers.py +++ b/tests/test_context_managers.py @@ -37,14 +37,15 @@ class ContextManagersTest(unittest.TestCase): def test_switch_collection_context_manager(self): connect('mongoenginetest') - register_connection('testdb-1', 'mongoenginetest2') + register_connection(alias='testdb-1', db='mongoenginetest2') class Group(Document): name = StringField() - Group.drop_collection() + Group.drop_collection() # drops in default + with switch_collection(Group, 'group1') as Group: - Group.drop_collection() + Group.drop_collection() # drops in group1 Group(name="hello - group").save() self.assertEqual(1, Group.objects.count()) From e44f71eeb12ae95b972f7bc2bab5f896db80745e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Thu, 25 Apr 2019 22:31:05 +0200 Subject: [PATCH 31/71] updated changelog --- docs/changelog.rst | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index b88c2ce6..356e2b65 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -5,14 +5,16 @@ Changelog Development =========== - expose `mongoengine.connection.disconnect` and `mongoengine.connection.disconnect_all` -- POTENTIAL BREAKING CHANGES: Fixes in connect/disconnect methods #565 #566 #605 #607 #1213 #1599 +- Fix disconnect function #566 #1599 #605 #607 #1213 #565 +- Improve connect/disconnect documentations +- POTENTIAL BREAKING CHANGES: (associated with connect/disconnect fixes) - calling `connect` 2 times with the same alias and different parameter will raise an error (should call disconnect first) - disconnect now clears `mongoengine.connection._connection_settings` - disconnect now clears the cached attribute `Document._collection` - POTENTIAL BREAKING CHANGE: Aggregate gives wrong results when used with a queryset having limit and skip #2029 - mongoengine now requires pymongo>=3.5 #2017 - Generate Unique Indices for SortedListField and EmbeddedDocumentListFields #2020 -- connect() fails immediately when db name contains invalid characters (e. g. when user mistakenly puts 'mongodb://127.0.0.1:27017' as db name, happened in #1718) or is if db name is of an invalid type +- connect() fails immediately when db name contains invalid characters #2031 #1718 - (Fill this out as you fix issues and develop your features). Changes in 0.17.0 From abfabc30c96a9da1bd7d9a043cbe80d216f4ed29 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Wed, 1 May 2019 23:23:19 +0200 Subject: [PATCH 32/71] Fix querying on (Generic)EmbeddedDocument subclasses fields --- docs/changelog.rst | 1 + mongoengine/fields.py | 15 +- tests/fields/fields.py | 200 +----------- tests/fields/test_embedded_document_field.py | 316 +++++++++++++++++++ 4 files changed, 331 insertions(+), 201 deletions(-) create mode 100644 tests/fields/test_embedded_document_field.py diff --git a/docs/changelog.rst b/docs/changelog.rst index 356e2b65..34e1c495 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -4,6 +4,7 @@ Changelog Development =========== +- Fix querying on (Generic)EmbeddedDocument subclasses fields #475 - expose `mongoengine.connection.disconnect` and `mongoengine.connection.disconnect_all` - Fix disconnect function #566 #1599 #605 #607 #1213 #565 - Improve connect/disconnect documentations diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 52ed4bc9..7e119721 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -700,7 +700,11 @@ class EmbeddedDocumentField(BaseField): self.document_type.validate(value, clean) def lookup_member(self, member_name): - return self.document_type._fields.get(member_name) + doc_and_subclasses = [self.document_type] + self.document_type.__subclasses__() + for doc_type in doc_and_subclasses: + field = doc_type._fields.get(member_name) + if field: + return field def prepare_query_value(self, op, value): if value is not None and not isinstance(value, self.document_type): @@ -747,12 +751,13 @@ class GenericEmbeddedDocumentField(BaseField): value.validate(clean=clean) def lookup_member(self, member_name): - if self.choices: - for choice in self.choices: - field = choice._fields.get(member_name) + document_choices = self.choices or [] + for document_choice in document_choices: + doc_and_subclasses = [document_choice] + document_choice.__subclasses__() + for doc_type in doc_and_subclasses: + field = doc_type._fields.get(member_name) if field: return field - return None def to_mongo(self, document, use_db_field=True, fields=None): if document is None: diff --git a/tests/fields/fields.py b/tests/fields/fields.py index 2c4ac3ac..3b66f2de 100644 --- a/tests/fields/fields.py +++ b/tests/fields/fields.py @@ -8,11 +8,10 @@ from bson import DBRef, ObjectId, SON from mongoengine import Document, StringField, IntField, DateTimeField, DateField, ValidationError, \ ComplexDateTimeField, FloatField, ListField, ReferenceField, DictField, EmbeddedDocument, EmbeddedDocumentField, \ - GenericReferenceField, DoesNotExist, NotRegistered, GenericEmbeddedDocumentField, OperationError, DynamicField, \ - FieldDoesNotExist, EmbeddedDocumentListField, MultipleObjectsReturned, NotUniqueError, BooleanField, ObjectIdField, \ - SortedListField, GenericLazyReferenceField, LazyReferenceField, DynamicDocument -from mongoengine.base import (BaseField, EmbeddedDocumentList, - _document_registry) + GenericReferenceField, DoesNotExist, NotRegistered, OperationError, DynamicField, \ + FieldDoesNotExist, EmbeddedDocumentListField, MultipleObjectsReturned, NotUniqueError, BooleanField,\ + ObjectIdField, SortedListField, GenericLazyReferenceField, LazyReferenceField, DynamicDocument +from mongoengine.base import (BaseField, EmbeddedDocumentList, _document_registry) from tests.utils import MongoDBTestCase @@ -1769,79 +1768,6 @@ class FieldTest(MongoDBTestCase): with self.assertRaises(ValidationError): shirt.validate() - def test_choices_validation_documents(self): - """ - Ensure fields with document choices validate given a valid choice. - """ - class UserComments(EmbeddedDocument): - author = StringField() - message = StringField() - - class BlogPost(Document): - comments = ListField( - GenericEmbeddedDocumentField(choices=(UserComments,)) - ) - - # Ensure Validation Passes - BlogPost(comments=[ - UserComments(author='user2', message='message2'), - ]).save() - - def test_choices_validation_documents_invalid(self): - """ - Ensure fields with document choices validate given an invalid choice. - This should throw a ValidationError exception. - """ - class UserComments(EmbeddedDocument): - author = StringField() - message = StringField() - - class ModeratorComments(EmbeddedDocument): - author = StringField() - message = StringField() - - class BlogPost(Document): - comments = ListField( - GenericEmbeddedDocumentField(choices=(UserComments,)) - ) - - # Single Entry Failure - post = BlogPost(comments=[ - ModeratorComments(author='mod1', message='message1'), - ]) - self.assertRaises(ValidationError, post.save) - - # Mixed Entry Failure - post = BlogPost(comments=[ - ModeratorComments(author='mod1', message='message1'), - UserComments(author='user2', message='message2'), - ]) - self.assertRaises(ValidationError, post.save) - - def test_choices_validation_documents_inheritance(self): - """ - Ensure fields with document choices validate given subclass of choice. - """ - class Comments(EmbeddedDocument): - meta = { - 'abstract': True - } - author = StringField() - message = StringField() - - class UserComments(Comments): - pass - - class BlogPost(Document): - comments = ListField( - GenericEmbeddedDocumentField(choices=(Comments,)) - ) - - # Save Valid EmbeddedDocument Type - BlogPost(comments=[ - UserComments(author='user2', message='message2'), - ]).save() - def test_choices_get_field_display(self): """Test dynamic helper for returning the display value of a choices field. @@ -1958,85 +1884,6 @@ class FieldTest(MongoDBTestCase): self.assertEqual(error_dict['size'], SIZE_MESSAGE) self.assertEqual(error_dict['color'], COLOR_MESSAGE) - def test_generic_embedded_document(self): - class Car(EmbeddedDocument): - name = StringField() - - class Dish(EmbeddedDocument): - food = StringField(required=True) - number = IntField() - - class Person(Document): - name = StringField() - like = GenericEmbeddedDocumentField() - - Person.drop_collection() - - person = Person(name='Test User') - person.like = Car(name='Fiat') - person.save() - - person = Person.objects.first() - self.assertIsInstance(person.like, Car) - - person.like = Dish(food="arroz", number=15) - person.save() - - person = Person.objects.first() - self.assertIsInstance(person.like, Dish) - - def test_generic_embedded_document_choices(self): - """Ensure you can limit GenericEmbeddedDocument choices.""" - class Car(EmbeddedDocument): - name = StringField() - - class Dish(EmbeddedDocument): - food = StringField(required=True) - number = IntField() - - class Person(Document): - name = StringField() - like = GenericEmbeddedDocumentField(choices=(Dish,)) - - Person.drop_collection() - - person = Person(name='Test User') - person.like = Car(name='Fiat') - self.assertRaises(ValidationError, person.validate) - - person.like = Dish(food="arroz", number=15) - person.save() - - person = Person.objects.first() - self.assertIsInstance(person.like, Dish) - - def test_generic_list_embedded_document_choices(self): - """Ensure you can limit GenericEmbeddedDocument choices inside - a list field. - """ - class Car(EmbeddedDocument): - name = StringField() - - class Dish(EmbeddedDocument): - food = StringField(required=True) - number = IntField() - - class Person(Document): - name = StringField() - likes = ListField(GenericEmbeddedDocumentField(choices=(Dish,))) - - Person.drop_collection() - - person = Person(name='Test User') - person.likes = [Car(name='Fiat')] - self.assertRaises(ValidationError, person.validate) - - person.likes = [Dish(food="arroz", number=15)] - person.save() - - person = Person.objects.first() - self.assertIsInstance(person.likes[0], Dish) - def test_recursive_validation(self): """Ensure that a validation result to_dict is available.""" class Author(EmbeddedDocument): @@ -2702,44 +2549,5 @@ class EmbeddedDocumentListFieldTestCase(MongoDBTestCase): self.assertEqual(custom_data['a'], CustomData.c_field.custom_data['a']) -class TestEmbeddedDocumentField(MongoDBTestCase): - def test___init___(self): - class MyDoc(EmbeddedDocument): - name = StringField() - - field = EmbeddedDocumentField(MyDoc) - self.assertEqual(field.document_type_obj, MyDoc) - - field2 = EmbeddedDocumentField('MyDoc') - self.assertEqual(field2.document_type_obj, 'MyDoc') - - def test___init___throw_error_if_document_type_is_not_EmbeddedDocument(self): - with self.assertRaises(ValidationError): - EmbeddedDocumentField(dict) - - def test_document_type_throw_error_if_not_EmbeddedDocument_subclass(self): - - class MyDoc(Document): - name = StringField() - - emb = EmbeddedDocumentField('MyDoc') - with self.assertRaises(ValidationError) as ctx: - emb.document_type - self.assertIn('Invalid embedded document class provided to an EmbeddedDocumentField', str(ctx.exception)) - - def test_embedded_document_field_only_allow_subclasses_of_embedded_document(self): - # Relates to #1661 - class MyDoc(Document): - name = StringField() - - with self.assertRaises(ValidationError): - class MyFailingDoc(Document): - emb = EmbeddedDocumentField(MyDoc) - - with self.assertRaises(ValidationError): - class MyFailingdoc2(Document): - emb = EmbeddedDocumentField('MyDoc') - - if __name__ == '__main__': unittest.main() diff --git a/tests/fields/test_embedded_document_field.py b/tests/fields/test_embedded_document_field.py new file mode 100644 index 00000000..e9fc79c8 --- /dev/null +++ b/tests/fields/test_embedded_document_field.py @@ -0,0 +1,316 @@ +# -*- coding: utf-8 -*- +from mongoengine import Document, StringField, ValidationError, EmbeddedDocument, EmbeddedDocumentField, \ + InvalidQueryError, LookUpError, IntField, GenericEmbeddedDocumentField, ListField + +from tests.utils import MongoDBTestCase + + +class TestEmbeddedDocumentField(MongoDBTestCase): + def test___init___(self): + class MyDoc(EmbeddedDocument): + name = StringField() + + field = EmbeddedDocumentField(MyDoc) + self.assertEqual(field.document_type_obj, MyDoc) + + field2 = EmbeddedDocumentField('MyDoc') + self.assertEqual(field2.document_type_obj, 'MyDoc') + + def test___init___throw_error_if_document_type_is_not_EmbeddedDocument(self): + with self.assertRaises(ValidationError): + EmbeddedDocumentField(dict) + + def test_document_type_throw_error_if_not_EmbeddedDocument_subclass(self): + + class MyDoc(Document): + name = StringField() + + emb = EmbeddedDocumentField('MyDoc') + with self.assertRaises(ValidationError) as ctx: + emb.document_type + self.assertIn('Invalid embedded document class provided to an EmbeddedDocumentField', str(ctx.exception)) + + def test_embedded_document_field_only_allow_subclasses_of_embedded_document(self): + # Relates to #1661 + class MyDoc(Document): + name = StringField() + + with self.assertRaises(ValidationError): + class MyFailingDoc(Document): + emb = EmbeddedDocumentField(MyDoc) + + with self.assertRaises(ValidationError): + class MyFailingdoc2(Document): + emb = EmbeddedDocumentField('MyDoc') + + def test_query_embedded_document_attribute(self): + class AdminSettings(EmbeddedDocument): + foo1 = StringField() + foo2 = StringField() + + class Person(Document): + settings = EmbeddedDocumentField(AdminSettings) + name = StringField() + + Person.drop_collection() + + p = Person( + settings=AdminSettings(foo1='bar1', foo2='bar2'), + name='John', + ).save() + + # Test non exiting attribute + with self.assertRaises(InvalidQueryError) as ctx_err: + Person.objects(settings__notexist='bar').first() + self.assertEqual(unicode(ctx_err.exception), u'Cannot resolve field "notexist"') + + with self.assertRaises(LookUpError): + Person.objects.only('settings.notexist') + + # Test existing attribute + self.assertEqual(Person.objects(settings__foo1='bar1').first().id, p.id) + only_p = Person.objects.only('settings.foo1').first() + self.assertEqual(only_p.settings.foo1, p.settings.foo1) + self.assertIsNone(only_p.settings.foo2) + self.assertIsNone(only_p.name) + + exclude_p = Person.objects.exclude('settings.foo1').first() + self.assertIsNone(exclude_p.settings.foo1) + self.assertEqual(exclude_p.settings.foo2, p.settings.foo2) + self.assertEqual(exclude_p.name, p.name) + + def test_query_embedded_document_attribute_with_inheritance(self): + class BaseSettings(EmbeddedDocument): + meta = {'allow_inheritance': True} + base_foo = StringField() + + class AdminSettings(BaseSettings): + sub_foo = StringField() + + class Person(Document): + settings = EmbeddedDocumentField(BaseSettings) + + Person.drop_collection() + + p = Person(settings=AdminSettings(base_foo='basefoo', sub_foo='subfoo')) + p.save() + + # Test non exiting attribute + with self.assertRaises(InvalidQueryError) as ctx_err: + self.assertEqual(Person.objects(settings__notexist='bar').first().id, p.id) + self.assertEqual(unicode(ctx_err.exception), u'Cannot resolve field "notexist"') + + # Test existing attribute + self.assertEqual(Person.objects(settings__base_foo='basefoo').first().id, p.id) + self.assertEqual(Person.objects(settings__sub_foo='subfoo').first().id, p.id) + + only_p = Person.objects.only('settings.base_foo', 'settings._cls').first() + self.assertEqual(only_p.settings.base_foo, 'basefoo') + self.assertIsNone(only_p.settings.sub_foo) + + +class TestGenericEmbeddedDocumentField(MongoDBTestCase): + + def test_generic_embedded_document(self): + class Car(EmbeddedDocument): + name = StringField() + + class Dish(EmbeddedDocument): + food = StringField(required=True) + number = IntField() + + class Person(Document): + name = StringField() + like = GenericEmbeddedDocumentField() + + Person.drop_collection() + + person = Person(name='Test User') + person.like = Car(name='Fiat') + person.save() + + person = Person.objects.first() + self.assertIsInstance(person.like, Car) + + person.like = Dish(food="arroz", number=15) + person.save() + + person = Person.objects.first() + self.assertIsInstance(person.like, Dish) + + def test_generic_embedded_document_choices(self): + """Ensure you can limit GenericEmbeddedDocument choices.""" + class Car(EmbeddedDocument): + name = StringField() + + class Dish(EmbeddedDocument): + food = StringField(required=True) + number = IntField() + + class Person(Document): + name = StringField() + like = GenericEmbeddedDocumentField(choices=(Dish,)) + + Person.drop_collection() + + person = Person(name='Test User') + person.like = Car(name='Fiat') + self.assertRaises(ValidationError, person.validate) + + person.like = Dish(food="arroz", number=15) + person.save() + + person = Person.objects.first() + self.assertIsInstance(person.like, Dish) + + def test_generic_list_embedded_document_choices(self): + """Ensure you can limit GenericEmbeddedDocument choices inside + a list field. + """ + class Car(EmbeddedDocument): + name = StringField() + + class Dish(EmbeddedDocument): + food = StringField(required=True) + number = IntField() + + class Person(Document): + name = StringField() + likes = ListField(GenericEmbeddedDocumentField(choices=(Dish,))) + + Person.drop_collection() + + person = Person(name='Test User') + person.likes = [Car(name='Fiat')] + self.assertRaises(ValidationError, person.validate) + + person.likes = [Dish(food="arroz", number=15)] + person.save() + + person = Person.objects.first() + self.assertIsInstance(person.likes[0], Dish) + + def test_choices_validation_documents(self): + """ + Ensure fields with document choices validate given a valid choice. + """ + class UserComments(EmbeddedDocument): + author = StringField() + message = StringField() + + class BlogPost(Document): + comments = ListField( + GenericEmbeddedDocumentField(choices=(UserComments,)) + ) + + # Ensure Validation Passes + BlogPost(comments=[ + UserComments(author='user2', message='message2'), + ]).save() + + def test_choices_validation_documents_invalid(self): + """ + Ensure fields with document choices validate given an invalid choice. + This should throw a ValidationError exception. + """ + class UserComments(EmbeddedDocument): + author = StringField() + message = StringField() + + class ModeratorComments(EmbeddedDocument): + author = StringField() + message = StringField() + + class BlogPost(Document): + comments = ListField( + GenericEmbeddedDocumentField(choices=(UserComments,)) + ) + + # Single Entry Failure + post = BlogPost(comments=[ + ModeratorComments(author='mod1', message='message1'), + ]) + self.assertRaises(ValidationError, post.save) + + # Mixed Entry Failure + post = BlogPost(comments=[ + ModeratorComments(author='mod1', message='message1'), + UserComments(author='user2', message='message2'), + ]) + self.assertRaises(ValidationError, post.save) + + def test_choices_validation_documents_inheritance(self): + """ + Ensure fields with document choices validate given subclass of choice. + """ + class Comments(EmbeddedDocument): + meta = { + 'abstract': True + } + author = StringField() + message = StringField() + + class UserComments(Comments): + pass + + class BlogPost(Document): + comments = ListField( + GenericEmbeddedDocumentField(choices=(Comments,)) + ) + + # Save Valid EmbeddedDocument Type + BlogPost(comments=[ + UserComments(author='user2', message='message2'), + ]).save() + + def test_query_generic_embedded_document_attribute(self): + class AdminSettings(EmbeddedDocument): + foo1 = StringField() + + class NonAdminSettings(EmbeddedDocument): + foo2 = StringField() + + class Person(Document): + settings = GenericEmbeddedDocumentField(choices=(AdminSettings, NonAdminSettings)) + + Person.drop_collection() + + p1 = Person(settings=AdminSettings(foo1='bar1')).save() + p2 = Person(settings=NonAdminSettings(foo2='bar2')).save() + + # Test non exiting attribute + with self.assertRaises(InvalidQueryError) as ctx_err: + Person.objects(settings__notexist='bar').first() + self.assertEqual(unicode(ctx_err.exception), u'Cannot resolve field "notexist"') + + with self.assertRaises(LookUpError): + Person.objects.only('settings.notexist') + + # Test existing attribute + self.assertEqual(Person.objects(settings__foo1='bar1').first().id, p1.id) + self.assertEqual(Person.objects(settings__foo2='bar2').first().id, p2.id) + + def test_query_generic_embedded_document_attribute_with_inheritance(self): + class BaseSettings(EmbeddedDocument): + meta = {'allow_inheritance': True} + base_foo = StringField() + + class AdminSettings(BaseSettings): + sub_foo = StringField() + + class Person(Document): + settings = GenericEmbeddedDocumentField(choices=[BaseSettings]) + + Person.drop_collection() + + p = Person(settings=AdminSettings(base_foo='basefoo', sub_foo='subfoo')) + p.save() + + # Test non exiting attribute + with self.assertRaises(InvalidQueryError) as ctx_err: + self.assertEqual(Person.objects(settings__notexist='bar').first().id, p.id) + self.assertEqual(unicode(ctx_err.exception), u'Cannot resolve field "notexist"') + + # Test existing attribute + self.assertEqual(Person.objects(settings__base_foo='basefoo').first().id, p.id) + self.assertEqual(Person.objects(settings__sub_foo='subfoo').first().id, p.id) From d98f36ceff73dec4fc419054ceb28a4ea6315fa8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Thu, 2 May 2019 00:08:16 +0200 Subject: [PATCH 33/71] Add test for querying on fields of list(EmbeddedDocument) (with inheritance on the EmbededDoc) --- docs/changelog.rst | 1 + tests/fields/test_embedded_document_field.py | 18 +++++++++++++++++- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 34e1c495..543cfce7 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -4,6 +4,7 @@ Changelog Development =========== +- Fix querying on List(EmbeddedDocument) subclasses fields #1961 - Fix querying on (Generic)EmbeddedDocument subclasses fields #475 - expose `mongoengine.connection.disconnect` and `mongoengine.connection.disconnect_all` - Fix disconnect function #566 #1599 #605 #607 #1213 #565 diff --git a/tests/fields/test_embedded_document_field.py b/tests/fields/test_embedded_document_field.py index e9fc79c8..b870f9f9 100644 --- a/tests/fields/test_embedded_document_field.py +++ b/tests/fields/test_embedded_document_field.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- from mongoengine import Document, StringField, ValidationError, EmbeddedDocument, EmbeddedDocumentField, \ - InvalidQueryError, LookUpError, IntField, GenericEmbeddedDocumentField, ListField + InvalidQueryError, LookUpError, IntField, GenericEmbeddedDocumentField, ListField, EmbeddedDocumentListField from tests.utils import MongoDBTestCase @@ -108,6 +108,22 @@ class TestEmbeddedDocumentField(MongoDBTestCase): self.assertEqual(only_p.settings.base_foo, 'basefoo') self.assertIsNone(only_p.settings.sub_foo) + def test_query_list_embedded_document_with_inheritance(self): + class BaseEmbeddedDoc(EmbeddedDocument): + s = StringField() + meta = {'allow_inheritance': True} + + class EmbeddedDoc(BaseEmbeddedDoc): + s2 = StringField() + + class MyDoc(Document): + embeds = EmbeddedDocumentListField(BaseEmbeddedDoc) + + doc = MyDoc(embeds=[EmbeddedDoc(s='foo', s2='bar')]).save() + + self.assertEqual(MyDoc.objects(embeds__s='foo').first(), doc) + self.assertEqual(MyDoc.objects(embeds__s2='bar').first(), doc) + class TestGenericEmbeddedDocumentField(MongoDBTestCase): From f7b7d0f79e3af2594721db2a59c5d18ca8f347f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Fri, 3 May 2019 21:59:48 +0200 Subject: [PATCH 34/71] Improve tests for querying list(embedded) when using inheritance --- docs/changelog.rst | 2 +- tests/fields/test_embedded_document_field.py | 32 ++++++++++++++------ 2 files changed, 23 insertions(+), 11 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 543cfce7..961a8c94 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -4,7 +4,7 @@ Changelog Development =========== -- Fix querying on List(EmbeddedDocument) subclasses fields #1961 +- Fix querying on List(EmbeddedDocument) subclasses fields #1961 #1492 - Fix querying on (Generic)EmbeddedDocument subclasses fields #475 - expose `mongoengine.connection.disconnect` and `mongoengine.connection.disconnect_all` - Fix disconnect function #566 #1599 #605 #607 #1213 #565 diff --git a/tests/fields/test_embedded_document_field.py b/tests/fields/test_embedded_document_field.py index b870f9f9..a262d054 100644 --- a/tests/fields/test_embedded_document_field.py +++ b/tests/fields/test_embedded_document_field.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- from mongoengine import Document, StringField, ValidationError, EmbeddedDocument, EmbeddedDocumentField, \ - InvalidQueryError, LookUpError, IntField, GenericEmbeddedDocumentField, ListField, EmbeddedDocumentListField + InvalidQueryError, LookUpError, IntField, GenericEmbeddedDocumentField, ListField, EmbeddedDocumentListField, \ + ReferenceField from tests.utils import MongoDBTestCase @@ -109,20 +110,31 @@ class TestEmbeddedDocumentField(MongoDBTestCase): self.assertIsNone(only_p.settings.sub_foo) def test_query_list_embedded_document_with_inheritance(self): - class BaseEmbeddedDoc(EmbeddedDocument): - s = StringField() + class Post(EmbeddedDocument): + title = StringField(max_length=120, required=True) meta = {'allow_inheritance': True} - class EmbeddedDoc(BaseEmbeddedDoc): - s2 = StringField() + class TextPost(Post): + content = StringField() - class MyDoc(Document): - embeds = EmbeddedDocumentListField(BaseEmbeddedDoc) + class MoviePost(Post): + author = StringField() - doc = MyDoc(embeds=[EmbeddedDoc(s='foo', s2='bar')]).save() + class Record(Document): + posts = ListField(EmbeddedDocumentField(Post)) - self.assertEqual(MyDoc.objects(embeds__s='foo').first(), doc) - self.assertEqual(MyDoc.objects(embeds__s2='bar').first(), doc) + record_movie = Record(posts=[MoviePost(author='John', title='foo')]).save() + record_text = Record(posts=[TextPost(content='a', title='foo')]).save() + + records = list(Record.objects(posts__author=record_movie.posts[0].author)) + self.assertEqual(len(records), 1) + self.assertEqual(records[0].id, record_movie.id) + + records = list(Record.objects(posts__content=record_text.posts[0].content)) + self.assertEqual(len(records), 1) + self.assertEqual(records[0].id, record_text.id) + + self.assertEqual(Record.objects(posts__title='foo').count(), 2) class TestGenericEmbeddedDocumentField(MongoDBTestCase): From 9cdc3ebee64f7c46be06d9264328c522889cdaf0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Sun, 5 May 2019 23:37:12 +0200 Subject: [PATCH 35/71] Fix default write concern on save call that was overwriting connection wc --- docs/changelog.rst | 1 + mongoengine/document.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 356e2b65..80b92b81 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -12,6 +12,7 @@ Development - disconnect now clears `mongoengine.connection._connection_settings` - disconnect now clears the cached attribute `Document._collection` - POTENTIAL BREAKING CHANGE: Aggregate gives wrong results when used with a queryset having limit and skip #2029 +- Fix the default write concern of .save that was overwriting the connection write concern #568 - mongoengine now requires pymongo>=3.5 #2017 - Generate Unique Indices for SortedListField and EmbeddedDocumentListFields #2020 - connect() fails immediately when db name contains invalid characters #2031 #1718 diff --git a/mongoengine/document.py b/mongoengine/document.py index 753520c7..5ccedbfa 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -375,7 +375,7 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): self.validate(clean=clean) if write_concern is None: - write_concern = {'w': 1} + write_concern = {} doc = self.to_mongo() From ac64ade10f27b1cf102965ec1950caaf5e190c9b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Wed, 15 May 2019 21:54:26 +0200 Subject: [PATCH 36/71] remove dead code (related to pymongo2) + minor cleaning --- mongoengine/document.py | 8 ++------ mongoengine/queryset/base.py | 34 ++++++++++++---------------------- tests/queryset/queryset.py | 11 ++--------- tests/test_connection.py | 30 +++++++++--------------------- 4 files changed, 25 insertions(+), 58 deletions(-) diff --git a/mongoengine/document.py b/mongoengine/document.py index 5ccedbfa..a1139789 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -816,11 +816,9 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): index_spec = index_spec.copy() fields = index_spec.pop('fields') drop_dups = kwargs.get('drop_dups', False) - if IS_PYMONGO_3 and drop_dups: + if drop_dups: msg = 'drop_dups is deprecated and is removed when using PyMongo 3+.' warnings.warn(msg, DeprecationWarning) - elif not IS_PYMONGO_3: - index_spec['drop_dups'] = drop_dups index_spec['background'] = background index_spec.update(kwargs) @@ -842,11 +840,9 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): :param drop_dups: Was removed/ignored with MongoDB >2.7.5. The value will be removed if PyMongo3+ is used """ - if IS_PYMONGO_3 and drop_dups: + if drop_dups: msg = 'drop_dups is deprecated and is removed when using PyMongo 3+.' warnings.warn(msg, DeprecationWarning) - elif not IS_PYMONGO_3: - kwargs.update({'drop_dups': drop_dups}) return cls.create_index(key_or_list, background=background, **kwargs) @classmethod diff --git a/mongoengine/queryset/base.py b/mongoengine/queryset/base.py index 66e43514..16a2512d 100644 --- a/mongoengine/queryset/base.py +++ b/mongoengine/queryset/base.py @@ -626,7 +626,7 @@ class BaseQuerySet(object): queryset = self.clone() query = queryset._query - if not IS_PYMONGO_3 or not remove: + if not remove: update = transform.update(queryset._document, **update) sort = queryset._ordering @@ -1090,7 +1090,7 @@ class BaseQuerySet(object): return queryset def timeout(self, enabled): - """Enable or disable the default mongod timeout when querying. + """Enable or disable the default mongod timeout when querying. (no_cursor_timeout option) :param enabled: whether or not the timeout is used @@ -1531,26 +1531,16 @@ class BaseQuerySet(object): @property def _cursor_args(self): - if not IS_PYMONGO_3: - fields_name = 'fields' - cursor_args = { - 'timeout': self._timeout, - 'snapshot': self._snapshot - } - if self._read_preference is not None: - cursor_args['read_preference'] = self._read_preference - else: - cursor_args['slave_okay'] = self._slave_okay - else: - fields_name = 'projection' - # snapshot is not handled at all by PyMongo 3+ - # TODO: evaluate similar possibilities using modifiers - if self._snapshot: - msg = 'The snapshot option is not anymore available with PyMongo 3+' - warnings.warn(msg, DeprecationWarning) - cursor_args = { - 'no_cursor_timeout': not self._timeout - } + fields_name = 'projection' + # snapshot is not handled at all by PyMongo 3+ + # TODO: evaluate similar possibilities using modifiers + if self._snapshot: + msg = 'The snapshot option is not anymore available with PyMongo 3+' + warnings.warn(msg, DeprecationWarning) + cursor_args = { + 'no_cursor_timeout': not self._timeout + } + if self._loaded_fields: cursor_args[fields_name] = self._loaded_fields.as_dict() diff --git a/tests/queryset/queryset.py b/tests/queryset/queryset.py index 31b1641e..c403e68f 100644 --- a/tests/queryset/queryset.py +++ b/tests/queryset/queryset.py @@ -3415,10 +3415,7 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(query.count(), 3) self.assertEqual(query._query, {'$text': {'$search': 'brasil'}}) cursor_args = query._cursor_args - if not IS_PYMONGO_3: - cursor_args_fields = cursor_args['fields'] - else: - cursor_args_fields = cursor_args['projection'] + cursor_args_fields = cursor_args['projection'] self.assertEqual( cursor_args_fields, {'_text_score': {'$meta': 'textScore'}}) @@ -4511,11 +4508,7 @@ class QuerySetTest(unittest.TestCase): bars = list(Bar.objects(read_preference=ReadPreference.PRIMARY)) self.assertEqual([], bars) - if not IS_PYMONGO_3: - error_class = ConfigurationError - else: - error_class = TypeError - self.assertRaises(error_class, Bar.objects, read_preference='Primary') + self.assertRaises(TypeError, Bar.objects, read_preference='Primary') # read_preference as a kwarg bars = Bar.objects(read_preference=ReadPreference.SECONDARY_PREFERRED) diff --git a/tests/test_connection.py b/tests/test_connection.py index e5e10479..65f01717 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -23,10 +23,7 @@ from mongoengine.connection import (MongoEngineConnectionError, get_db, def get_tz_awareness(connection): - if not IS_PYMONGO_3: - return connection.tz_aware - else: - return connection.codec_options.tz_aware + return connection.codec_options.tz_aware class ConnectionTest(unittest.TestCase): @@ -425,12 +422,6 @@ class ConnectionTest(unittest.TestCase): c.admin.authenticate("admin", "password") c.admin.command("createUser", "username", pwd="password", roles=["dbOwner"]) - if not IS_PYMONGO_3: - self.assertRaises( - MongoEngineConnectionError, connect, 'testdb_uri_bad', - host='mongodb://test:password@localhost' - ) - connect("testdb_uri", host='mongodb://username:password@localhost/mongoenginetest') conn = get_connection() @@ -641,17 +632,14 @@ class ConnectionTest(unittest.TestCase): self.assertEqual(len(mongo_connections.items()), 2) self.assertIn('t1', mongo_connections.keys()) self.assertIn('t2', mongo_connections.keys()) - if not IS_PYMONGO_3: - self.assertEqual(mongo_connections['t1'].host, 'localhost') - self.assertEqual(mongo_connections['t2'].host, '127.0.0.1') - else: - # Handle PyMongo 3+ Async Connection - # Ensure we are connected, throws ServerSelectionTimeoutError otherwise. - # Purposely not catching exception to fail test if thrown. - mongo_connections['t1'].server_info() - mongo_connections['t2'].server_info() - self.assertEqual(mongo_connections['t1'].address[0], 'localhost') - self.assertEqual(mongo_connections['t2'].address[0], '127.0.0.1') + + # Handle PyMongo 3+ Async Connection + # Ensure we are connected, throws ServerSelectionTimeoutError otherwise. + # Purposely not catching exception to fail test if thrown. + mongo_connections['t1'].server_info() + mongo_connections['t2'].server_info() + self.assertEqual(mongo_connections['t1'].address[0], 'localhost') + self.assertEqual(mongo_connections['t2'].address[0], '127.0.0.1') if __name__ == '__main__': From cf38ef70cb506d20e5a506a8f07311c0d3821fbb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Wed, 15 May 2019 22:23:35 +0200 Subject: [PATCH 37/71] Remove more code related to supporting pymongo2 --- mongoengine/connection.py | 41 ++++------------ mongoengine/document.py | 23 +++------ mongoengine/pymongo_support.py | 1 - mongoengine/queryset/base.py | 63 +++++++++---------------- mongoengine/queryset/transform.py | 12 ++--- tests/queryset/queryset.py | 53 ++------------------- tests/test_connection.py | 72 +++++++++-------------------- tests/test_replicaset_connection.py | 19 +++----- tests/utils.py | 17 ------- 9 files changed, 70 insertions(+), 231 deletions(-) diff --git a/mongoengine/connection.py b/mongoengine/connection.py index 67374d01..e12980e6 100644 --- a/mongoengine/connection.py +++ b/mongoengine/connection.py @@ -2,8 +2,6 @@ from pymongo import MongoClient, ReadPreference, uri_parser from pymongo.database import _check_name import six -from mongoengine.pymongo_support import IS_PYMONGO_3 - __all__ = ['MongoEngineConnectionError', 'connect', 'disconnect', 'disconnect_all', 'register_connection', 'DEFAULT_CONNECTION_NAME', 'DEFAULT_DATABASE_NAME', 'get_db', 'get_connection'] @@ -14,11 +12,11 @@ DEFAULT_DATABASE_NAME = 'test' DEFAULT_HOST = 'localhost' DEFAULT_PORT = 27017 -if IS_PYMONGO_3: - READ_PREFERENCE = ReadPreference.PRIMARY -else: - from pymongo import MongoReplicaSetClient - READ_PREFERENCE = False +_connection_settings = {} +_connections = {} +_dbs = {} + +READ_PREFERENCE = ReadPreference.PRIMARY class MongoEngineConnectionError(Exception): @@ -28,12 +26,7 @@ class MongoEngineConnectionError(Exception): pass -_connection_settings = {} -_connections = {} -_dbs = {} - - -def check_db_name(name): +def _check_db_name(name): """Check if a database name is valid. This functionality is copied from pymongo Database class constructor. """ @@ -57,7 +50,6 @@ def _get_connection_settings( : param host: the host name of the: program: `mongod` instance to connect to : param port: the port that the: program: `mongod` instance is running on : param read_preference: The read preference for the collection - ** Added pymongo 2.1 : param username: username to authenticate with : param password: password to authenticate with : param authentication_source: database to authenticate against @@ -83,7 +75,7 @@ def _get_connection_settings( 'authentication_mechanism': authentication_mechanism } - check_db_name(conn_settings['name']) + _check_db_name(conn_settings['name']) conn_host = conn_settings['host'] # Host can be a list or a string, so if string, force to a list. @@ -119,7 +111,7 @@ def _get_connection_settings( conn_settings['authentication_source'] = uri_options['authsource'] if 'authmechanism' in uri_options: conn_settings['authentication_mechanism'] = uri_options['authmechanism'] - if IS_PYMONGO_3 and 'readpreference' in uri_options: + if 'readpreference' in uri_options: read_preferences = ( ReadPreference.NEAREST, ReadPreference.PRIMARY, @@ -158,7 +150,6 @@ def register_connection(alias, db=None, name=None, host=None, port=None, : param host: the host name of the: program: `mongod` instance to connect to : param port: the port that the: program: `mongod` instance is running on : param read_preference: The read preference for the collection - ** Added pymongo 2.1 : param username: username to authenticate with : param password: password to authenticate with : param authentication_source: database to authenticate against @@ -259,22 +250,6 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): else: connection_class = MongoClient - # For replica set connections with PyMongo 2.x, use - # MongoReplicaSetClient. - # TODO remove this once we stop supporting PyMongo 2.x. - if 'replicaSet' in conn_settings and not IS_PYMONGO_3: - connection_class = MongoReplicaSetClient - conn_settings['hosts_or_uri'] = conn_settings.pop('host', None) - - # hosts_or_uri has to be a string, so if 'host' was provided - # as a list, join its parts and separate them by ',' - if isinstance(conn_settings['hosts_or_uri'], list): - conn_settings['hosts_or_uri'] = ','.join( - conn_settings['hosts_or_uri']) - - # Discard port since it can't be used on MongoReplicaSetClient - conn_settings.pop('port', None) - # Iterate over all of the connection settings and if a connection with # the same parameters is already established, use it instead of creating # a new one. diff --git a/mongoengine/document.py b/mongoengine/document.py index a1139789..2ebd31fd 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -18,7 +18,7 @@ from mongoengine.context_managers import (set_write_concern, switch_db) from mongoengine.errors import (InvalidDocumentError, InvalidQueryError, SaveConditionError) -from mongoengine.pymongo_support import IS_PYMONGO_3, list_collection_names +from mongoengine.pymongo_support import list_collection_names from mongoengine.queryset import (NotUniqueError, OperationError, QuerySet, transform) @@ -822,10 +822,7 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): index_spec['background'] = background index_spec.update(kwargs) - if IS_PYMONGO_3: - return cls._get_collection().create_index(fields, **index_spec) - else: - return cls._get_collection().ensure_index(fields, **index_spec) + return cls._get_collection().create_index(fields, **index_spec) @classmethod def ensure_index(cls, key_or_list, drop_dups=False, background=False, @@ -858,7 +855,7 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): drop_dups = cls._meta.get('index_drop_dups', False) index_opts = cls._meta.get('index_opts') or {} index_cls = cls._meta.get('index_cls', True) - if IS_PYMONGO_3 and drop_dups: + if drop_dups: msg = 'drop_dups is deprecated and is removed when using PyMongo 3+.' warnings.warn(msg, DeprecationWarning) @@ -889,11 +886,7 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): if 'cls' in opts: del opts['cls'] - if IS_PYMONGO_3: - collection.create_index(fields, background=background, **opts) - else: - collection.ensure_index(fields, background=background, - drop_dups=drop_dups, **opts) + collection.create_index(fields, background=background, **opts) # If _cls is being used (for polymorphism), it needs an index, # only if another index doesn't begin with _cls @@ -904,12 +897,8 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): if 'cls' in index_opts: del index_opts['cls'] - if IS_PYMONGO_3: - collection.create_index('_cls', background=background, - **index_opts) - else: - collection.ensure_index('_cls', background=background, - **index_opts) + collection.create_index('_cls', background=background, + **index_opts) @classmethod def list_indexes(cls): diff --git a/mongoengine/pymongo_support.py b/mongoengine/pymongo_support.py index 0d607162..f66c038e 100644 --- a/mongoengine/pymongo_support.py +++ b/mongoengine/pymongo_support.py @@ -7,7 +7,6 @@ _PYMONGO_37 = (3, 7) PYMONGO_VERSION = tuple(pymongo.version_tuple[:2]) -IS_PYMONGO_3 = PYMONGO_VERSION[0] >= 3 IS_PYMONGO_GTE_37 = PYMONGO_VERSION >= _PYMONGO_37 diff --git a/mongoengine/queryset/base.py b/mongoengine/queryset/base.py index 16a2512d..bfbfbbe0 100644 --- a/mongoengine/queryset/base.py +++ b/mongoengine/queryset/base.py @@ -10,6 +10,7 @@ from bson import SON, json_util from bson.code import Code import pymongo import pymongo.errors +from pymongo.collection import ReturnDocument from pymongo.common import validate_read_preference import six from six import iteritems @@ -21,14 +22,10 @@ from mongoengine.connection import get_db from mongoengine.context_managers import set_write_concern, switch_db from mongoengine.errors import (InvalidQueryError, LookUpError, NotUniqueError, OperationError) -from mongoengine.pymongo_support import IS_PYMONGO_3 from mongoengine.queryset import transform from mongoengine.queryset.field_list import QueryFieldList from mongoengine.queryset.visitor import Q, QNode -if IS_PYMONGO_3: - from pymongo.collection import ReturnDocument - __all__ = ('BaseQuerySet', 'DO_NOTHING', 'NULLIFY', 'CASCADE', 'DENY', 'PULL') @@ -631,26 +628,20 @@ class BaseQuerySet(object): sort = queryset._ordering try: - if IS_PYMONGO_3: - if full_response: - msg = 'With PyMongo 3+, it is not possible anymore to get the full response.' - warnings.warn(msg, DeprecationWarning) - if remove: - result = queryset._collection.find_one_and_delete( - query, sort=sort, **self._cursor_args) - else: - if new: - return_doc = ReturnDocument.AFTER - else: - return_doc = ReturnDocument.BEFORE - result = queryset._collection.find_one_and_update( - query, update, upsert=upsert, sort=sort, return_document=return_doc, - **self._cursor_args) - + if full_response: + msg = 'With PyMongo 3+, it is not possible anymore to get the full response.' + warnings.warn(msg, DeprecationWarning) + if remove: + result = queryset._collection.find_one_and_delete( + query, sort=sort, **self._cursor_args) else: - result = queryset._collection.find_and_modify( - query, update, upsert=upsert, sort=sort, remove=remove, new=new, - full_response=full_response, **self._cursor_args) + if new: + return_doc = ReturnDocument.AFTER + else: + return_doc = ReturnDocument.BEFORE + result = queryset._collection.find_one_and_update( + query, update, upsert=upsert, sort=sort, return_document=return_doc, + **self._cursor_args) except pymongo.errors.DuplicateKeyError as err: raise NotUniqueError(u'Update failed (%s)' % err) except pymongo.errors.OperationFailure as err: @@ -1082,9 +1073,8 @@ class BaseQuerySet(object): ..versionchanged:: 0.5 - made chainable .. deprecated:: Ignored with PyMongo 3+ """ - if IS_PYMONGO_3: - msg = 'snapshot is deprecated as it has no impact when using PyMongo 3+.' - warnings.warn(msg, DeprecationWarning) + msg = 'snapshot is deprecated as it has no impact when using PyMongo 3+.' + warnings.warn(msg, DeprecationWarning) queryset = self.clone() queryset._snapshot = enabled return queryset @@ -1108,9 +1098,8 @@ class BaseQuerySet(object): .. deprecated:: Ignored with PyMongo 3+ """ - if IS_PYMONGO_3: - msg = 'slave_okay is deprecated as it has no impact when using PyMongo 3+.' - warnings.warn(msg, DeprecationWarning) + msg = 'slave_okay is deprecated as it has no impact when using PyMongo 3+.' + warnings.warn(msg, DeprecationWarning) queryset = self.clone() queryset._slave_okay = enabled return queryset @@ -1211,7 +1200,7 @@ class BaseQuerySet(object): pipeline = initial_pipeline + list(pipeline) - if IS_PYMONGO_3 and self._read_preference is not None: + if self._read_preference is not None: return self._collection.with_options(read_preference=self._read_preference) \ .aggregate(pipeline, cursor={}, **kwargs) @@ -1421,11 +1410,7 @@ class BaseQuerySet(object): if isinstance(field_instances[-1], ListField): pipeline.insert(1, {'$unwind': '$' + field}) - result = self._document._get_collection().aggregate(pipeline) - if IS_PYMONGO_3: - result = tuple(result) - else: - result = result.get('result') + result = tuple(self._document._get_collection().aggregate(pipeline)) if result: return result[0]['total'] @@ -1452,11 +1437,7 @@ class BaseQuerySet(object): if isinstance(field_instances[-1], ListField): pipeline.insert(1, {'$unwind': '$' + field}) - result = self._document._get_collection().aggregate(pipeline) - if IS_PYMONGO_3: - result = tuple(result) - else: - result = result.get('result') + result = tuple(self._document._get_collection().aggregate(pipeline)) if result: return result[0]['total'] return 0 @@ -1564,7 +1545,7 @@ class BaseQuerySet(object): # XXX In PyMongo 3+, we define the read preference on a collection # level, not a cursor level. Thus, we need to get a cloned collection # object using `with_options` first. - if IS_PYMONGO_3 and self._read_preference is not None: + if self._read_preference is not None: self._cursor_obj = self._collection\ .with_options(read_preference=self._read_preference)\ .find(self._query, **self._cursor_args) diff --git a/mongoengine/queryset/transform.py b/mongoengine/queryset/transform.py index 3de10a69..48c5f682 100644 --- a/mongoengine/queryset/transform.py +++ b/mongoengine/queryset/transform.py @@ -8,9 +8,7 @@ from six import iteritems from mongoengine.base import UPDATE_OPERATORS from mongoengine.common import _import_class -from mongoengine.connection import get_connection from mongoengine.errors import InvalidQueryError -from mongoengine.pymongo_support import IS_PYMONGO_3 __all__ = ('query', 'update') @@ -163,16 +161,14 @@ def query(_doc_cls=None, **kwargs): # PyMongo 3+ and MongoDB < 2.6 near_embedded = False for near_op in ('$near', '$nearSphere'): - if isinstance(value_dict.get(near_op), dict) and ( - IS_PYMONGO_3 or get_connection().max_wire_version > 1): + if isinstance(value_dict.get(near_op), dict): value_son[near_op] = SON(value_son[near_op]) if '$maxDistance' in value_dict: - value_son[near_op][ - '$maxDistance'] = value_dict['$maxDistance'] + value_son[near_op]['$maxDistance'] = value_dict['$maxDistance'] if '$minDistance' in value_dict: - value_son[near_op][ - '$minDistance'] = value_dict['$minDistance'] + value_son[near_op]['$minDistance'] = value_dict['$minDistance'] near_embedded = True + if not near_embedded: if '$maxDistance' in value_dict: value_son['$maxDistance'] = value_dict['$maxDistance'] diff --git a/tests/queryset/queryset.py b/tests/queryset/queryset.py index c403e68f..039834f7 100644 --- a/tests/queryset/queryset.py +++ b/tests/queryset/queryset.py @@ -19,10 +19,9 @@ from mongoengine.connection import get_connection, get_db from mongoengine.context_managers import query_counter, switch_db from mongoengine.errors import InvalidQueryError from mongoengine.mongodb_support import get_mongodb_version, MONGODB_32 -from mongoengine.pymongo_support import IS_PYMONGO_3 from mongoengine.queryset import (DoesNotExist, MultipleObjectsReturned, QuerySet, QuerySetManager, queryset_manager) -from tests.utils import requires_mongodb_gte_26, skip_pymongo3 +from tests.utils import requires_mongodb_gte_26 class db_ops_tracker(query_counter): @@ -1047,48 +1046,6 @@ class QuerySetTest(unittest.TestCase): org.save() # saves the org self.assertEqual(q, 2) - @skip_pymongo3 - def test_slave_okay(self): - """Ensures that a query can take slave_okay syntax. - Useless with PyMongo 3+ as well as with MongoDB 3+. - """ - person1 = self.Person(name="User A", age=20) - person1.save() - person2 = self.Person(name="User B", age=30) - person2.save() - - # Retrieve the first person from the database - person = self.Person.objects.slave_okay(True).first() - self.assertIsInstance(person, self.Person) - self.assertEqual(person.name, "User A") - self.assertEqual(person.age, 20) - - @requires_mongodb_gte_26 - @skip_pymongo3 - def test_cursor_args(self): - """Ensures the cursor args can be set as expected - """ - p = self.Person.objects - # Check default - self.assertEqual(p._cursor_args, - {'snapshot': False, 'slave_okay': False, 'timeout': True}) - - p = p.snapshot(False).slave_okay(False).timeout(False) - self.assertEqual(p._cursor_args, - {'snapshot': False, 'slave_okay': False, 'timeout': False}) - - p = p.snapshot(True).slave_okay(False).timeout(False) - self.assertEqual(p._cursor_args, - {'snapshot': True, 'slave_okay': False, 'timeout': False}) - - p = p.snapshot(True).slave_okay(True).timeout(False) - self.assertEqual(p._cursor_args, - {'snapshot': True, 'slave_okay': True, 'timeout': False}) - - p = p.snapshot(True).slave_okay(True).timeout(True) - self.assertEqual(p._cursor_args, - {'snapshot': True, 'slave_okay': True, 'timeout': True}) - def test_repeated_iteration(self): """Ensure that QuerySet rewinds itself one iteration finishes. """ @@ -4568,12 +4525,8 @@ class QuerySetTest(unittest.TestCase): bars = Bar.objects \ .read_preference(ReadPreference.SECONDARY_PREFERRED) \ .aggregate() - if IS_PYMONGO_3: - self.assertEqual(bars._CommandCursor__collection.read_preference, - ReadPreference.SECONDARY_PREFERRED) - else: - self.assertNotEqual(bars._CommandCursor__collection.read_preference, - ReadPreference.SECONDARY_PREFERRED) + self.assertEqual(bars._CommandCursor__collection.read_preference, + ReadPreference.SECONDARY_PREFERRED) def test_json_simple(self): diff --git a/tests/test_connection.py b/tests/test_connection.py index 65f01717..5473b8a0 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -2,6 +2,7 @@ import datetime from pymongo import MongoClient from pymongo.errors import OperationFailure, InvalidName +from pymongo import ReadPreference try: import unittest2 as unittest @@ -16,7 +17,6 @@ from mongoengine import ( connect, register_connection, Document, DateTimeField, disconnect_all, StringField) -from mongoengine.pymongo_support import IS_PYMONGO_3 import mongoengine.connection from mongoengine.connection import (MongoEngineConnectionError, get_db, get_connection, disconnect, DEFAULT_DATABASE_NAME) @@ -404,11 +404,7 @@ class ConnectionTest(unittest.TestCase): connect('mongoenginetests', alias='testdb2') actual_connection = get_connection('testdb2') - # Handle PyMongo 3+ Async Connection - if IS_PYMONGO_3: - # Ensure we are connected, throws ServerSelectionTimeoutError otherwise. - # Purposely not catching exception to fail test if thrown. - expected_connection.server_info() + expected_connection.server_info() self.assertEqual(expected_connection, actual_connection) @@ -484,19 +480,11 @@ class ConnectionTest(unittest.TestCase): c.admin.command("createUser", "username2", pwd="password", roles=["dbOwner"]) # Authentication fails without "authSource" - if IS_PYMONGO_3: - test_conn = connect( - 'mongoenginetest', alias='test1', - host='mongodb://username2:password@localhost/mongoenginetest' - ) - self.assertRaises(OperationFailure, test_conn.server_info) - else: - self.assertRaises( - MongoEngineConnectionError, - connect, 'mongoenginetest', alias='test1', - host='mongodb://username2:password@localhost/mongoenginetest' - ) - self.assertRaises(MongoEngineConnectionError, get_db, 'test1') + test_conn = connect( + 'mongoenginetest', alias='test1', + host='mongodb://username2:password@localhost/mongoenginetest' + ) + self.assertRaises(OperationFailure, test_conn.server_info) # Authentication succeeds with "authSource" authd_conn = connect( @@ -565,44 +553,28 @@ class ConnectionTest(unittest.TestCase): """ conn1 = connect(alias='conn1', host='mongodb://localhost/testing?w=1&j=true') conn2 = connect('testing', alias='conn2', w=1, j=True) - if IS_PYMONGO_3: - self.assertEqual(conn1.write_concern.document, {'w': 1, 'j': True}) - self.assertEqual(conn2.write_concern.document, {'w': 1, 'j': True}) - else: - self.assertEqual(dict(conn1.write_concern), {'w': 1, 'j': True}) - self.assertEqual(dict(conn2.write_concern), {'w': 1, 'j': True}) + self.assertEqual(conn1.write_concern.document, {'w': 1, 'j': True}) + self.assertEqual(conn2.write_concern.document, {'w': 1, 'j': True}) def test_connect_with_replicaset_via_uri(self): """Ensure connect() works when specifying a replicaSet via the MongoDB URI. """ - if IS_PYMONGO_3: - c = connect(host='mongodb://localhost/test?replicaSet=local-rs') - db = get_db() - self.assertIsInstance(db, pymongo.database.Database) - self.assertEqual(db.name, 'test') - else: - # PyMongo < v3.x raises an exception: - # "localhost:27017 is not a member of replica set local-rs" - with self.assertRaises(MongoEngineConnectionError): - c = connect(host='mongodb://localhost/test?replicaSet=local-rs') + c = connect(host='mongodb://localhost/test?replicaSet=local-rs') + db = get_db() + self.assertIsInstance(db, pymongo.database.Database) + self.assertEqual(db.name, 'test') def test_connect_with_replicaset_via_kwargs(self): """Ensure connect() works when specifying a replicaSet via the connection kwargs """ - if IS_PYMONGO_3: - c = connect(replicaset='local-rs') - self.assertEqual(c._MongoClient__options.replica_set_name, - 'local-rs') - db = get_db() - self.assertIsInstance(db, pymongo.database.Database) - self.assertEqual(db.name, 'test') - else: - # PyMongo < v3.x raises an exception: - # "localhost:27017 is not a member of replica set local-rs" - with self.assertRaises(MongoEngineConnectionError): - c = connect(replicaset='local-rs') + c = connect(replicaset='local-rs') + self.assertEqual(c._MongoClient__options.replica_set_name, + 'local-rs') + db = get_db() + self.assertIsInstance(db, pymongo.database.Database) + self.assertEqual(db.name, 'test') def test_connect_tz_aware(self): connect('mongoenginetest', tz_aware=True) @@ -618,10 +590,8 @@ class ConnectionTest(unittest.TestCase): self.assertEqual(d, date_doc.the_date) def test_read_preference_from_parse(self): - if IS_PYMONGO_3: - from pymongo import ReadPreference - conn = connect(host="mongodb://a1.vpc,a2.vpc,a3.vpc/prod?readPreference=secondaryPreferred") - self.assertEqual(conn.read_preference, ReadPreference.SECONDARY_PREFERRED) + conn = connect(host="mongodb://a1.vpc,a2.vpc,a3.vpc/prod?readPreference=secondaryPreferred") + self.assertEqual(conn.read_preference, ReadPreference.SECONDARY_PREFERRED) def test_multiple_connection_settings(self): connect('mongoenginetest', alias='t1', host="localhost") diff --git a/tests/test_replicaset_connection.py b/tests/test_replicaset_connection.py index 81fdfb64..cacdce8b 100644 --- a/tests/test_replicaset_connection.py +++ b/tests/test_replicaset_connection.py @@ -1,23 +1,16 @@ import unittest from pymongo import ReadPreference - -from mongoengine.pymongo_support import IS_PYMONGO_3 - -if IS_PYMONGO_3: - from pymongo import MongoClient - CONN_CLASS = MongoClient - READ_PREF = ReadPreference.SECONDARY -else: - from pymongo import ReplicaSetConnection - CONN_CLASS = ReplicaSetConnection - READ_PREF = ReadPreference.SECONDARY_ONLY +from pymongo import MongoClient import mongoengine -from mongoengine import * from mongoengine.connection import MongoEngineConnectionError +CONN_CLASS = MongoClient +READ_PREF = ReadPreference.SECONDARY + + class ConnectionTest(unittest.TestCase): def setUp(self): @@ -35,7 +28,7 @@ class ConnectionTest(unittest.TestCase): """ try: - conn = connect(db='mongoenginetest', + conn = mongoengine.connect(db='mongoenginetest', host="mongodb://localhost/mongoenginetest?replicaSet=rs", read_preference=READ_PREF) except MongoEngineConnectionError as e: diff --git a/tests/utils.py b/tests/utils.py index 910601b1..0ebb44a4 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -6,7 +6,6 @@ from nose.plugins.skip import SkipTest from mongoengine import connect from mongoengine.connection import get_db, disconnect_all from mongoengine.mongodb_support import get_mongodb_version, MONGODB_26, MONGODB_3, MONGODB_32, MONGODB_34 -from mongoengine.pymongo_support import IS_PYMONGO_3 MONGO_TEST_DB = 'mongoenginetest' # standard name for the test database @@ -80,19 +79,3 @@ def requires_mongodb_gte_3(func): lower than v3.0. """ return _decorated_with_ver_requirement(func, MONGODB_3, oper=operator.ge) - - -def skip_pymongo3(f): - """Raise a SkipTest exception if we're running a test against - PyMongo v3.x. - """ - def _inner(*args, **kwargs): - if IS_PYMONGO_3: - raise SkipTest("Useless with PyMongo 3+") - return f(*args, **kwargs) - - _inner.__name__ = f.__name__ - _inner.__doc__ = f.__doc__ - - return _inner - From 6a9ef319d0138ee9dce3d5ae73d667e9e14e0c5d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Thu, 16 May 2019 22:46:42 +0200 Subject: [PATCH 38/71] Fix Incompatibility btw recent tox version and virtualenv version --- .travis.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.travis.yml b/.travis.yml index 64086357..b943024a 100644 --- a/.travis.yml +++ b/.travis.yml @@ -48,8 +48,8 @@ install: - travis_retry pip install --upgrade pip - travis_retry pip install coveralls - travis_retry pip install flake8 flake8-import-order -- travis_retry pip install tox>=1.9 -- travis_retry pip install "virtualenv<14.0.0" # virtualenv>=14.0.0 has dropped Python 3.2 support (and pypy3 is based on py32) +- travis_retry pip install "tox" # tox 3.11.0 has requirement virtualenv>=14.0.0 +- travis_retry pip install "virtualenv" # virtualenv>=14.0.0 has dropped Python 3.2 support (and pypy3 is based on py32) - travis_retry tox -e $(echo py$TRAVIS_PYTHON_VERSION-mg$PYMONGO | tr -d . | sed -e 's/pypypy/pypy/') -- -e test # Cache dependencies installed via pip From b392e3102e94657fd663ba4e2b188ec648fe1c1b Mon Sep 17 00:00:00 2001 From: Agustin Barto Date: Fri, 17 May 2019 13:41:02 -0300 Subject: [PATCH 39/71] Add support to transform. Add pull tests for and . --- mongoengine/queryset/transform.py | 2 +- tests/queryset/transform.py | 15 ++++++++++----- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/mongoengine/queryset/transform.py b/mongoengine/queryset/transform.py index 3de10a69..7241efbd 100644 --- a/mongoengine/queryset/transform.py +++ b/mongoengine/queryset/transform.py @@ -281,7 +281,7 @@ def update(_doc_cls=None, **update): if op == 'pull': if field.required or value is not None: - if match == 'in' and not isinstance(value, dict): + if match in ('in', 'nin') and not isinstance(value, dict): value = _prepare_query_for_iterable(field, op, value) else: value = field.prepare_query_value(op, value) diff --git a/tests/queryset/transform.py b/tests/queryset/transform.py index 8064f09c..3c7c945f 100644 --- a/tests/queryset/transform.py +++ b/tests/queryset/transform.py @@ -276,13 +276,18 @@ class TransformTest(unittest.TestCase): title = StringField() content = EmbeddedDocumentField(SubDoc) - word = Word(word='abc', index=1) - update = transform.update(MainDoc, pull__content__text=word) - self.assertEqual(update, {'$pull': {'content.text': SON([('word', u'abc'), ('index', 1)])}}) + # word = Word(word='abc', index=1) + # update = transform.update(MainDoc, pull__content__text=word) + # self.assertEqual(update, {'$pull': {'content.text': SON([('word', u'abc'), ('index', 1)])}}) - update = transform.update(MainDoc, pull__content__heading='xyz') - self.assertEqual(update, {'$pull': {'content.heading': 'xyz'}}) + # update = transform.update(MainDoc, pull__content__heading='xyz') + # self.assertEqual(update, {'$pull': {'content.heading': 'xyz'}}) + # update = transform.update(MainDoc, pull__content__text__word__in=['foo', 'bar']) + # self.assertEqual(update, {'$pull': {'content.text': {'word': {'$in': ['foo', 'bar']}}}}) + + update = transform.update(MainDoc, pull__content__text__word__nin=['foo', 'bar']) + self.assertEqual(update, {'$pull': {'content.text': {'word': {'$nin': ['foo', 'bar']}}}}) if __name__ == '__main__': unittest.main() From 2b17985a11b5ff9baa206d513f2d05e502122039 Mon Sep 17 00:00:00 2001 From: Agustin Barto Date: Fri, 17 May 2019 13:55:00 -0300 Subject: [PATCH 40/71] Uncomment tests. --- tests/queryset/transform.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/queryset/transform.py b/tests/queryset/transform.py index 3c7c945f..b2bc1d6c 100644 --- a/tests/queryset/transform.py +++ b/tests/queryset/transform.py @@ -276,15 +276,15 @@ class TransformTest(unittest.TestCase): title = StringField() content = EmbeddedDocumentField(SubDoc) - # word = Word(word='abc', index=1) - # update = transform.update(MainDoc, pull__content__text=word) - # self.assertEqual(update, {'$pull': {'content.text': SON([('word', u'abc'), ('index', 1)])}}) + word = Word(word='abc', index=1) + update = transform.update(MainDoc, pull__content__text=word) + self.assertEqual(update, {'$pull': {'content.text': SON([('word', u'abc'), ('index', 1)])}}) - # update = transform.update(MainDoc, pull__content__heading='xyz') - # self.assertEqual(update, {'$pull': {'content.heading': 'xyz'}}) + update = transform.update(MainDoc, pull__content__heading='xyz') + self.assertEqual(update, {'$pull': {'content.heading': 'xyz'}}) - # update = transform.update(MainDoc, pull__content__text__word__in=['foo', 'bar']) - # self.assertEqual(update, {'$pull': {'content.text': {'word': {'$in': ['foo', 'bar']}}}}) + update = transform.update(MainDoc, pull__content__text__word__in=['foo', 'bar']) + self.assertEqual(update, {'$pull': {'content.text': {'word': {'$in': ['foo', 'bar']}}}}) update = transform.update(MainDoc, pull__content__text__word__nin=['foo', 'bar']) self.assertEqual(update, {'$pull': {'content.text': {'word': {'$nin': ['foo', 'bar']}}}}) From f28e1b8c90eeb977a1ee8be625bdc3ae5143ad51 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Sat, 11 May 2019 22:31:32 +0200 Subject: [PATCH 41/71] improve coverage of lazy ref field --- tests/fields/test_lazy_reference_field.py | 29 +++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/tests/fields/test_lazy_reference_field.py b/tests/fields/test_lazy_reference_field.py index a72e8cbe..d8031409 100644 --- a/tests/fields/test_lazy_reference_field.py +++ b/tests/fields/test_lazy_reference_field.py @@ -13,6 +13,35 @@ class TestLazyReferenceField(MongoDBTestCase): # with a document class name. self.assertRaises(ValidationError, LazyReferenceField, EmbeddedDocument) + def test___repr__(self): + class Animal(Document): + pass + + class Ocurrence(Document): + animal = LazyReferenceField(Animal) + + Animal.drop_collection() + Ocurrence.drop_collection() + + animal = Animal() + oc = Ocurrence(animal=animal) + self.assertIn('LazyReference', repr(oc.animal)) + + def test___getattr___unknown_attr_raises_attribute_error(self): + class Animal(Document): + pass + + class Ocurrence(Document): + animal = LazyReferenceField(Animal) + + Animal.drop_collection() + Ocurrence.drop_collection() + + animal = Animal().save() + oc = Ocurrence(animal=animal) + with self.assertRaises(AttributeError): + oc.animal.not_exist + def test_lazy_reference_simple(self): class Animal(Document): name = StringField() From 00d2fd685af10f3cebb2c61a3d04c2ac341c7092 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Sun, 12 May 2019 22:58:17 +0200 Subject: [PATCH 42/71] more test cov --- mongoengine/base/metaclasses.py | 3 --- mongoengine/common.py | 6 +----- mongoengine/fields.py | 6 ++---- mongoengine/queryset/base.py | 2 +- mongoengine/queryset/transform.py | 18 +++--------------- tests/document/indexes.py | 29 ++++++++++++++++++++++++----- tests/document/instance.py | 6 ++++++ tests/fields/geo.py | 5 +++++ tests/test_common.py | 0 9 files changed, 42 insertions(+), 33 deletions(-) create mode 100644 tests/test_common.py diff --git a/mongoengine/base/metaclasses.py b/mongoengine/base/metaclasses.py index a1970825..6f507eaa 100644 --- a/mongoengine/base/metaclasses.py +++ b/mongoengine/base/metaclasses.py @@ -184,9 +184,6 @@ class DocumentMetaclass(type): if issubclass(new_class, EmbeddedDocument): raise InvalidDocumentError('CachedReferenceFields is not ' 'allowed in EmbeddedDocuments') - if not f.document_type: - raise InvalidDocumentError( - 'Document is not available to sync') if f.auto_sync: f.start_listener() diff --git a/mongoengine/common.py b/mongoengine/common.py index bde7e78c..bcdea194 100644 --- a/mongoengine/common.py +++ b/mongoengine/common.py @@ -31,7 +31,6 @@ def _import_class(cls_name): field_classes = _field_list_cache - queryset_classes = ('OperationError',) deref_classes = ('DeReference',) if cls_name == 'BaseDocument': @@ -43,14 +42,11 @@ def _import_class(cls_name): elif cls_name in field_classes: from mongoengine import fields as module import_classes = field_classes - elif cls_name in queryset_classes: - from mongoengine import queryset as module - import_classes = queryset_classes elif cls_name in deref_classes: from mongoengine import dereference as module import_classes = deref_classes else: - raise ValueError('No import set for: ' % cls_name) + raise ValueError('No import set for: %s' % cls_name) for cls in import_classes: _class_registry_cache[cls] = getattr(module, cls) diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 7e119721..1cd6be11 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -152,12 +152,10 @@ class URLField(StringField): scheme = value.split('://')[0].lower() if scheme not in self.schemes: self.error(u'Invalid scheme {} in URL: {}'.format(scheme, value)) - return # Then check full URL if not self.url_regex.match(value): self.error(u'Invalid URL: {}'.format(value)) - return class EmailField(StringField): @@ -259,10 +257,10 @@ class EmailField(StringField): try: domain_part = domain_part.encode('idna').decode('ascii') except UnicodeError: - self.error(self.error_msg % value) + self.error("%s %s" % (self.error_msg % value, "(domain failed IDN encoding)")) else: if not self.validate_domain_part(domain_part): - self.error(self.error_msg % value) + self.error("%s %s" % (self.error_msg % value, "(domain validation failed)")) class IntField(BaseField): diff --git a/mongoengine/queryset/base.py b/mongoengine/queryset/base.py index 66e43514..eec73f91 100644 --- a/mongoengine/queryset/base.py +++ b/mongoengine/queryset/base.py @@ -197,7 +197,7 @@ class BaseQuerySet(object): only_fields=self.only_fields ) - raise AttributeError('Provide a slice or an integer index') + raise TypeError('Provide a slice or an integer index') def __iter__(self): raise NotImplementedError diff --git a/mongoengine/queryset/transform.py b/mongoengine/queryset/transform.py index 3de10a69..ef5b1ea3 100644 --- a/mongoengine/queryset/transform.py +++ b/mongoengine/queryset/transform.py @@ -88,18 +88,10 @@ def query(_doc_cls=None, **kwargs): singular_ops = [None, 'ne', 'gt', 'gte', 'lt', 'lte', 'not'] singular_ops += STRING_OPERATORS if op in singular_ops: - if isinstance(field, six.string_types): - if (op in STRING_OPERATORS and - isinstance(value, six.string_types)): - StringField = _import_class('StringField') - value = StringField.prepare_query_value(op, value) - else: - value = field - else: - value = field.prepare_query_value(op, value) + value = field.prepare_query_value(op, value) - if isinstance(field, CachedReferenceField) and value: - value = value['_id'] + if isinstance(field, CachedReferenceField) and value: + value = value['_id'] elif op in ('in', 'nin', 'all', 'near') and not isinstance(value, dict): # Raise an error if the in/nin/all/near param is not iterable. @@ -308,10 +300,6 @@ def update(_doc_cls=None, **update): key = '.'.join(parts) - if not op: - raise InvalidQueryError('Updates must supply an operation ' - 'eg: set__FIELD=value') - if 'pull' in op and '.' in key: # Dot operators don't work on pull operations # unless they point to a list field diff --git a/tests/document/indexes.py b/tests/document/indexes.py index be344d32..bbe3dc5a 100644 --- a/tests/document/indexes.py +++ b/tests/document/indexes.py @@ -593,8 +593,9 @@ class IndexesTest(unittest.TestCase): # Two posts with the same slug is not allowed post2 = BlogPost(title='test2', slug='test') self.assertRaises(NotUniqueError, post2.save) + self.assertRaises(NotUniqueError, BlogPost.objects.insert, post2) - # Ensure backwards compatibilty for errors + # Ensure backwards compatibility for errors self.assertRaises(OperationError, post2.save) @requires_mongodb_gte_34 @@ -826,6 +827,19 @@ class IndexesTest(unittest.TestCase): self.assertEqual(3600, info['created_1']['expireAfterSeconds']) + def test_index_drop_dups_silently_ignored(self): + class Customer(Document): + cust_id = IntField(unique=True, required=True) + meta = { + 'indexes': ['cust_id'], + 'index_drop_dups': True, + 'allow_inheritance': False, + } + + Customer.drop_collection() + Customer.objects.first() + + def test_unique_and_indexes(self): """Ensure that 'unique' constraints aren't overridden by meta.indexes. @@ -842,11 +856,16 @@ class IndexesTest(unittest.TestCase): cust.save() cust_dupe = Customer(cust_id=1) - try: + with self.assertRaises(NotUniqueError): cust_dupe.save() - raise AssertionError("We saved a dupe!") - except NotUniqueError: - pass + + cust = Customer(cust_id=2) + cust.save() + + # duplicate key on update + with self.assertRaises(NotUniqueError): + cust.cust_id = 1 + cust.save() def test_primary_save_duplicate_update_existing_object(self): """If you set a field as primary, then unexpected behaviour can occur. diff --git a/tests/document/instance.py b/tests/document/instance.py index 9b28f1b4..e1379a5d 100644 --- a/tests/document/instance.py +++ b/tests/document/instance.py @@ -420,6 +420,12 @@ class InstanceTest(MongoDBTestCase): person.save() person.to_dbref() + def test_key_like_attribute_access(self): + person = self.Person(age=30) + self.assertEqual(person['age'], 30) + with self.assertRaises(KeyError): + person['unknown_attr'] + def test_save_abstract_document(self): """Saving an abstract document should fail.""" class Doc(Document): diff --git a/tests/fields/geo.py b/tests/fields/geo.py index 754f4203..37ed97f5 100644 --- a/tests/fields/geo.py +++ b/tests/fields/geo.py @@ -40,6 +40,11 @@ class GeoFieldTest(unittest.TestCase): expected = "Both values (%s) in point must be float or int" % repr(coord) self._test_for_expected_error(Location, coord, expected) + invalid_coords = [21, 4, 'a'] + for coord in invalid_coords: + expected = "GeoPointField can only accept tuples or lists of (x, y)" + self._test_for_expected_error(Location, coord, expected) + def test_point_validation(self): class Location(Document): loc = PointField() diff --git a/tests/test_common.py b/tests/test_common.py new file mode 100644 index 00000000..e69de29b From c82f0c937d50b2f4c27af7d59ba0069f93686baa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Tue, 14 May 2019 21:46:57 +0200 Subject: [PATCH 43/71] more work on coverage --- mongoengine/errors.py | 3 -- tests/document/indexes.py | 1 - tests/fields/test_email_field.py | 10 +++++ tests/queryset/queryset.py | 67 ++++++++++++++++++++++++++++++++ tests/queryset/transform.py | 8 ++++ tests/test_common.py | 15 +++++++ tests/test_context_managers.py | 8 ++++ tests/test_datastructures.py | 15 +++++++ 8 files changed, 123 insertions(+), 4 deletions(-) diff --git a/mongoengine/errors.py b/mongoengine/errors.py index 0e92a8c4..b0009cbc 100644 --- a/mongoengine/errors.py +++ b/mongoengine/errors.py @@ -110,9 +110,6 @@ class ValidationError(AssertionError): def build_dict(source): errors_dict = {} - if not source: - return errors_dict - if isinstance(source, dict): for field_name, error in iteritems(source): errors_dict[field_name] = build_dict(error) diff --git a/tests/document/indexes.py b/tests/document/indexes.py index bbe3dc5a..b0b78923 100644 --- a/tests/document/indexes.py +++ b/tests/document/indexes.py @@ -839,7 +839,6 @@ class IndexesTest(unittest.TestCase): Customer.drop_collection() Customer.objects.first() - def test_unique_and_indexes(self): """Ensure that 'unique' constraints aren't overridden by meta.indexes. diff --git a/tests/fields/test_email_field.py b/tests/fields/test_email_field.py index d8410354..3ce49d62 100644 --- a/tests/fields/test_email_field.py +++ b/tests/fields/test_email_field.py @@ -75,6 +75,16 @@ class TestEmailField(MongoDBTestCase): user = User(email='me@localhost') user.validate() + def test_email_domain_validation_fails_if_invalid_idn(self): + class User(Document): + email = EmailField() + + invalid_idn = '.google.com' + user = User(email='me@%s' % invalid_idn) + with self.assertRaises(ValidationError) as ctx_err: + user.validate() + self.assertIn("domain failed IDN encoding", str(ctx_err.exception)) + def test_email_field_ip_domain(self): class User(Document): email = EmailField() diff --git a/tests/queryset/queryset.py b/tests/queryset/queryset.py index 31b1641e..83b19ef8 100644 --- a/tests/queryset/queryset.py +++ b/tests/queryset/queryset.py @@ -158,6 +158,11 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(person.name, 'User B') self.assertEqual(person.age, None) + def test___getitem___invalid_index(self): + """Ensure slicing a queryset works as expected.""" + with self.assertRaises(TypeError): + self.Person.objects()['a'] + def test_slice(self): """Ensure slicing a queryset works as expected.""" user_a = self.Person.objects.create(name='User A', age=20) @@ -986,6 +991,29 @@ class QuerySetTest(unittest.TestCase): inserted_comment_id = Comment.objects.insert(comment, load_bulk=False) self.assertEqual(comment.id, inserted_comment_id) + def test_bulk_insert_accepts_doc_with_ids(self): + class Comment(Document): + id = IntField(primary_key=True) + + Comment.drop_collection() + + com1 = Comment(id=0) + com2 = Comment(id=1) + Comment.objects.insert([com1, com2]) + + def test_insert_raise_if_duplicate_in_constraint(self): + class Comment(Document): + id = IntField(primary_key=True) + + Comment.drop_collection() + + com1 = Comment(id=0) + + Comment.objects.insert(com1) + + with self.assertRaises(NotUniqueError): + Comment.objects.insert(com1) + def test_get_changed_fields_query_count(self): """Make sure we don't perform unnecessary db operations when none of document's fields were updated. @@ -3570,6 +3598,11 @@ class QuerySetTest(unittest.TestCase): opts = {"deleted": False} return qryset(**opts) + @queryset_manager + def objects_1_arg(qryset): + opts = {"deleted": False} + return qryset(**opts) + @queryset_manager def music_posts(doc_cls, queryset, deleted=False): return queryset(tags='music', @@ -3584,6 +3617,8 @@ class QuerySetTest(unittest.TestCase): self.assertEqual([p.id for p in BlogPost.objects()], [post1.id, post2.id, post3.id]) + self.assertEqual([p.id for p in BlogPost.objects_1_arg()], + [post1.id, post2.id, post3.id]) self.assertEqual([p.id for p in BlogPost.music_posts()], [post1.id, post2.id]) @@ -4968,6 +5003,38 @@ class QuerySetTest(unittest.TestCase): people.count() self.assertEqual(q, 3) + def test_no_cached_queryset__repr__(self): + class Person(Document): + name = StringField() + + Person.drop_collection() + qs = Person.objects.no_cache() + self.assertEqual(repr(qs), '[]') + + def test_no_cached_on_a_cached_queryset_raise_error(self): + class Person(Document): + name = StringField() + + Person.drop_collection() + Person(name='a').save() + qs = Person.objects() + _ = list(qs) + with self.assertRaises(OperationError) as ctx_err: + qs.no_cache() + self.assertEqual("QuerySet already cached", str(ctx_err.exception)) + + def test_no_cached_queryset_no_cache_back_to_cache(self): + class Person(Document): + name = StringField() + + Person.drop_collection() + qs = Person.objects() + self.assertIsInstance(qs, QuerySet) + qs = qs.no_cache() + self.assertIsInstance(qs, QuerySetNoCache) + qs = qs.cache() + self.assertIsInstance(qs, QuerySet) + def test_cache_not_cloned(self): class User(Document): diff --git a/tests/queryset/transform.py b/tests/queryset/transform.py index 8064f09c..c0da1a52 100644 --- a/tests/queryset/transform.py +++ b/tests/queryset/transform.py @@ -71,6 +71,14 @@ class TransformTest(unittest.TestCase): update = transform.update(BlogPost, push_all__tags=['mongo', 'db']) self.assertEqual(update, {'$push': {'tags': {'$each': ['mongo', 'db']}}}) + def test_transform_update_no_operator_default_to_set(self): + """Ensure the differences in behvaior between 'push' and 'push_all'""" + class BlogPost(Document): + tags = ListField(StringField()) + + update = transform.update(BlogPost, tags=['mongo', 'db']) + self.assertEqual(update, {'$set': {'tags': ['mongo', 'db']}}) + def test_query_field_name(self): """Ensure that the correct field name is used when querying. """ diff --git a/tests/test_common.py b/tests/test_common.py index e69de29b..04ad5b34 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -0,0 +1,15 @@ +import unittest + +from mongoengine.common import _import_class +from mongoengine import Document + + +class TestCommon(unittest.TestCase): + + def test__import_class(self): + doc_cls = _import_class("Document") + self.assertIs(doc_cls, Document) + + def test__import_class_raise_if_not_known(self): + with self.assertRaises(ValueError): + _import_class("UnknownClass") diff --git a/tests/test_context_managers.py b/tests/test_context_managers.py index 22c33b01..529032fe 100644 --- a/tests/test_context_managers.py +++ b/tests/test_context_managers.py @@ -270,6 +270,14 @@ class ContextManagersTest(unittest.TestCase): counter += 1 self.assertEqual(q, counter) + self.assertEqual(int(q), counter) # test __int__ + self.assertEqual(repr(q), str(int(q))) # test __repr__ + self.assertGreater(q, -1) # test __gt__ + self.assertGreaterEqual(q, int(q)) # test __gte__ + self.assertNotEqual(q, -1) + self.assertLess(q, 1000) + self.assertLessEqual(q, int(q)) + def test_query_counter_counts_getmore_queries(self): connect('mongoenginetest') db = get_db() diff --git a/tests/test_datastructures.py b/tests/test_datastructures.py index 4fb21d21..a9ef98e7 100644 --- a/tests/test_datastructures.py +++ b/tests/test_datastructures.py @@ -1,4 +1,5 @@ import unittest +from six import iterkeys from mongoengine import Document from mongoengine.base.datastructures import StrictDict, BaseList, BaseDict @@ -368,6 +369,20 @@ class TestStrictDict(unittest.TestCase): d = self.dtype(a=1, b=1, c=1) self.assertEqual((d.a, d.b, d.c), (1, 1, 1)) + def test_iterkeys(self): + d = self.dtype(a=1) + self.assertEqual(list(iterkeys(d)), ['a']) + + def test_len(self): + d = self.dtype(a=1) + self.assertEqual(len(d), 1) + + def test_pop(self): + d = self.dtype(a=1) + self.assertIn('a', d) + d.pop('a') + self.assertNotIn('a', d) + def test_repr(self): d = self.dtype(a=1, b=2, c=3) self.assertEqual(repr(d), '{"a": 1, "b": 2, "c": 3}') From bb1089e03d3d2a4e4e7929a377e917d056f28517 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Thu, 16 May 2019 22:31:24 +0200 Subject: [PATCH 44/71] Improve coverage in fields test --- tests/fields/test_cached_reference_field.py | 11 +++++++---- tests/fields/test_datetime_field.py | 9 +++++++++ tests/fields/test_lazy_reference_field.py | 17 +++++++++++++++++ tests/fields/test_long_field.py | 4 ++-- 4 files changed, 35 insertions(+), 6 deletions(-) diff --git a/tests/fields/test_cached_reference_field.py b/tests/fields/test_cached_reference_field.py index 989cea6d..470ecc5d 100644 --- a/tests/fields/test_cached_reference_field.py +++ b/tests/fields/test_cached_reference_field.py @@ -208,10 +208,7 @@ class TestCachedReferenceField(MongoDBTestCase): ('pj', "PJ") ) name = StringField() - tp = StringField( - choices=TYPES - ) - + tp = StringField(choices=TYPES) father = CachedReferenceField('self', fields=('tp',)) Person.drop_collection() @@ -222,6 +219,9 @@ class TestCachedReferenceField(MongoDBTestCase): a2 = Person(name='Wilson Junior', tp='pf', father=a1) a2.save() + a2 = Person.objects.with_id(a2.id) + self.assertEqual(a2.father.tp, a1.tp) + self.assertEqual(dict(a2.to_mongo()), { "_id": a2.pk, "name": u"Wilson Junior", @@ -374,6 +374,9 @@ class TestCachedReferenceField(MongoDBTestCase): self.assertEqual(o.to_mongo()['animal']['tag'], 'heavy') self.assertEqual(o.to_mongo()['animal']['owner']['t'], 'u') + # Check to_mongo with fields + self.assertNotIn('animal', o.to_mongo(fields=['person'])) + # counts Ocorrence(person="teste 2").save() Ocorrence(person="teste 3").save() diff --git a/tests/fields/test_datetime_field.py b/tests/fields/test_datetime_field.py index c6253043..5af6a011 100644 --- a/tests/fields/test_datetime_field.py +++ b/tests/fields/test_datetime_field.py @@ -172,6 +172,9 @@ class TestDateTimeField(MongoDBTestCase): log.time = datetime.datetime.now().isoformat(' ') log.validate() + log.time = '2019-05-16 21:42:57.897847' + log.validate() + if dateutil: log.time = datetime.datetime.now().isoformat('T') log.validate() @@ -180,6 +183,12 @@ class TestDateTimeField(MongoDBTestCase): self.assertRaises(ValidationError, log.validate) log.time = 'ABC' self.assertRaises(ValidationError, log.validate) + log.time = '2019-05-16 21:GARBAGE:12' + self.assertRaises(ValidationError, log.validate) + log.time = '2019-05-16 21:42:57.GARBAGE' + self.assertRaises(ValidationError, log.validate) + log.time = '2019-05-16 21:42:57.123.456' + self.assertRaises(ValidationError, log.validate) class TestDateTimeTzAware(MongoDBTestCase): diff --git a/tests/fields/test_lazy_reference_field.py b/tests/fields/test_lazy_reference_field.py index d8031409..b10506e7 100644 --- a/tests/fields/test_lazy_reference_field.py +++ b/tests/fields/test_lazy_reference_field.py @@ -508,6 +508,23 @@ class TestGenericLazyReferenceField(MongoDBTestCase): p = Ocurrence.objects.get() self.assertIs(p.animal, None) + def test_generic_lazy_reference_accepts_string_instead_of_class(self): + class Animal(Document): + name = StringField() + tag = StringField() + + class Ocurrence(Document): + person = StringField() + animal = GenericLazyReferenceField('Animal') + + Animal.drop_collection() + Ocurrence.drop_collection() + + animal = Animal().save() + Ocurrence(animal=animal).save() + p = Ocurrence.objects.get() + self.assertEqual(p.animal, animal) + def test_generic_lazy_reference_embedded(self): class Animal(Document): name = StringField() diff --git a/tests/fields/test_long_field.py b/tests/fields/test_long_field.py index 4ab7403d..3f307809 100644 --- a/tests/fields/test_long_field.py +++ b/tests/fields/test_long_field.py @@ -39,9 +39,9 @@ class TestLongField(MongoDBTestCase): doc.value = -1 self.assertRaises(ValidationError, doc.validate) - doc.age = 120 + doc.value = 120 self.assertRaises(ValidationError, doc.validate) - doc.age = 'ten' + doc.value = 'ten' self.assertRaises(ValidationError, doc.validate) def test_long_ne_operator(self): From 6b9d71554e397645071eacfe0b4f1fce7f834e46 Mon Sep 17 00:00:00 2001 From: Agustin Barto Date: Fri, 17 May 2019 17:23:52 -0300 Subject: [PATCH 45/71] Add integration tests --- tests/queryset/queryset.py | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/tests/queryset/queryset.py b/tests/queryset/queryset.py index 31b1641e..0b88193e 100644 --- a/tests/queryset/queryset.py +++ b/tests/queryset/queryset.py @@ -2193,6 +2193,40 @@ class QuerySetTest(unittest.TestCase): Site.objects(id=s.id).update_one( pull_all__collaborators__helpful__name=['Ross']) + def test_pull_from_nested_embedded_using_in_nin(self): + """Ensure that the 'pull' update operation works on embedded documents using 'in' and 'nin' operators. + """ + + class User(EmbeddedDocument): + name = StringField() + + def __unicode__(self): + return '%s' % self.name + + class Collaborator(EmbeddedDocument): + helpful = ListField(EmbeddedDocumentField(User)) + unhelpful = ListField(EmbeddedDocumentField(User)) + + class Site(Document): + name = StringField(max_length=75, unique=True, required=True) + collaborators = EmbeddedDocumentField(Collaborator) + + Site.drop_collection() + + a = User(name='Esteban') + b = User(name='Frank') + x = User(name='Harry') + y = User(name='John') + + s = Site(name="test", collaborators=Collaborator( + helpful=[a, b], unhelpful=[x, y])).save() + + Site.objects(id=s.id).update_one(pull__collaborators__helpful__name__in=['Esteban']) # Pull a + self.assertEqual(Site.objects.first().collaborators['helpful'], [b]) + + Site.objects(id=s.id).update_one(pull__collaborators__unhelpful__name__nin=['John']) # Pull x + self.assertEqual(Site.objects.first().collaborators['unhelpful'], [y]) + def test_pull_from_nested_mapfield(self): class Collaborator(EmbeddedDocument): From 2e01eb87db14e0d6b291f91b3aec7a8c477a4754 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Fri, 17 May 2019 23:07:30 +0200 Subject: [PATCH 46/71] Add support for MongoDB 3.6 and Python3.7 in travis --- .install_mongodb_on_travis.sh | 8 +++++- .travis.yml | 2 ++ README.rst | 4 +-- docs/changelog.rst | 1 + mongoengine/mongodb_support.py | 1 + tests/queryset/queryset.py | 46 ++++++++++++++++++---------------- 6 files changed, 38 insertions(+), 24 deletions(-) diff --git a/.install_mongodb_on_travis.sh b/.install_mongodb_on_travis.sh index 6ac2e364..f1073333 100644 --- a/.install_mongodb_on_travis.sh +++ b/.install_mongodb_on_travis.sh @@ -25,8 +25,14 @@ elif [ "$MONGODB" = "3.4" ]; then sudo apt-get update sudo apt-get install mongodb-org-server=3.4.17 # service should be started automatically +elif [ "$MONGODB" = "3.6" ]; then + sudo apt-key adv --keyserver hkp://keyserver.ubuntu.com:80 --recv 2930ADAE8CAF5059EE73BB4B58712A2291FA4AD5 + echo "deb http://repo.mongodb.org/apt/ubuntu trusty/mongodb-org/3.6 multiverse" | sudo tee /etc/apt/sources.list.d/mongodb-org-3.6.list + sudo apt-get update + sudo apt-get install mongodb-org-server=3.6.12 + # service should be started automatically else - echo "Invalid MongoDB version, expected 2.6, 3.0, 3.2 or 3.4." + echo "Invalid MongoDB version, expected 2.6, 3.0, 3.2, 3.4 or 3.6." exit 1 fi; diff --git a/.travis.yml b/.travis.yml index b943024a..909183c1 100644 --- a/.travis.yml +++ b/.travis.yml @@ -35,6 +35,8 @@ matrix: env: MONGODB=3.2 PYMONGO=3.x - python: 3.6 env: MONGODB=3.4 PYMONGO=3.x + - python: 3.6 + env: MONGODB=3.6 PYMONGO=3.x before_install: - bash .install_mongodb_on_travis.sh diff --git a/README.rst b/README.rst index 12d9df0e..cb279d2c 100644 --- a/README.rst +++ b/README.rst @@ -26,10 +26,10 @@ an `API reference `_. Supported MongoDB Versions ========================== -MongoEngine is currently tested against MongoDB v2.6, v3.0, v3.2 and v3.4. Future +MongoEngine is currently tested against MongoDB v2.6, v3.0, v3.2, v3.4 and v3.6. Future versions should be supported as well, but aren't actively tested at the moment. Make sure to open an issue or submit a pull request if you experience any -problems with MongoDB v3.4+. +problems with MongoDB v3.6+. Installation ============ diff --git a/docs/changelog.rst b/docs/changelog.rst index b875ffcd..6ac8da93 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -4,6 +4,7 @@ Changelog Development =========== +- Add support for MongoDB 3.6 and Python3.7 in travis - Fix querying on List(EmbeddedDocument) subclasses fields #1961 #1492 - Fix querying on (Generic)EmbeddedDocument subclasses fields #475 - expose `mongoengine.connection.disconnect` and `mongoengine.connection.disconnect_all` diff --git a/mongoengine/mongodb_support.py b/mongoengine/mongodb_support.py index 717a3d81..8e414075 100644 --- a/mongoengine/mongodb_support.py +++ b/mongoengine/mongodb_support.py @@ -6,6 +6,7 @@ from mongoengine.connection import get_connection # Constant that can be used to compare the version retrieved with # get_mongodb_version() +MONGODB_36 = (3, 6) MONGODB_34 = (3, 4) MONGODB_32 = (3, 2) MONGODB_3 = (3, 0) diff --git a/tests/queryset/queryset.py b/tests/queryset/queryset.py index 31b1641e..5005f260 100644 --- a/tests/queryset/queryset.py +++ b/tests/queryset/queryset.py @@ -6,7 +6,6 @@ import uuid from decimal import Decimal from bson import DBRef, ObjectId -from nose.plugins.skip import SkipTest import pymongo from pymongo.errors import ConfigurationError from pymongo.read_preferences import ReadPreference @@ -18,7 +17,7 @@ from mongoengine import * from mongoengine.connection import get_connection, get_db from mongoengine.context_managers import query_counter, switch_db from mongoengine.errors import InvalidQueryError -from mongoengine.mongodb_support import get_mongodb_version, MONGODB_32 +from mongoengine.mongodb_support import get_mongodb_version, MONGODB_32, MONGODB_36 from mongoengine.pymongo_support import IS_PYMONGO_3 from mongoengine.queryset import (DoesNotExist, MultipleObjectsReturned, QuerySet, QuerySetManager, queryset_manager) @@ -33,6 +32,12 @@ class db_ops_tracker(query_counter): return list(self.db.system.profile.find(ignore_query)) +def get_key_compat(mongo_ver): + ORDER_BY_KEY = 'sort' if mongo_ver >= MONGODB_32 else '$orderby' + CMD_QUERY_KEY = 'command' if mongo_ver >= MONGODB_36 else 'query' + return ORDER_BY_KEY, CMD_QUERY_KEY + + class QuerySetTest(unittest.TestCase): def setUp(self): @@ -1323,8 +1328,7 @@ class QuerySetTest(unittest.TestCase): """Ensure that the default ordering can be cleared by calling order_by() w/o any arguments. """ - MONGO_VER = self.mongodb_version - ORDER_BY_KEY = 'sort' if MONGO_VER >= MONGODB_32 else '$orderby' + ORDER_BY_KEY, CMD_QUERY_KEY = get_key_compat(self.mongodb_version) class BlogPost(Document): title = StringField() @@ -1341,7 +1345,7 @@ class QuerySetTest(unittest.TestCase): BlogPost.objects.filter(title='whatever').first() self.assertEqual(len(q.get_ops()), 1) self.assertEqual( - q.get_ops()[0]['query'][ORDER_BY_KEY], + q.get_ops()[0][CMD_QUERY_KEY][ORDER_BY_KEY], {'published_date': -1} ) @@ -1349,14 +1353,14 @@ class QuerySetTest(unittest.TestCase): with db_ops_tracker() as q: BlogPost.objects.filter(title='whatever').order_by().first() self.assertEqual(len(q.get_ops()), 1) - self.assertNotIn(ORDER_BY_KEY, q.get_ops()[0]['query']) + self.assertNotIn(ORDER_BY_KEY, q.get_ops()[0][CMD_QUERY_KEY]) # calling an explicit order_by should use a specified sort with db_ops_tracker() as q: BlogPost.objects.filter(title='whatever').order_by('published_date').first() self.assertEqual(len(q.get_ops()), 1) self.assertEqual( - q.get_ops()[0]['query'][ORDER_BY_KEY], + q.get_ops()[0][CMD_QUERY_KEY][ORDER_BY_KEY], {'published_date': 1} ) @@ -1365,13 +1369,12 @@ class QuerySetTest(unittest.TestCase): qs = BlogPost.objects.filter(title='whatever').order_by('published_date') qs.order_by().first() self.assertEqual(len(q.get_ops()), 1) - self.assertNotIn(ORDER_BY_KEY, q.get_ops()[0]['query']) + self.assertNotIn(ORDER_BY_KEY, q.get_ops()[0][CMD_QUERY_KEY]) def test_no_ordering_for_get(self): """ Ensure that Doc.objects.get doesn't use any ordering. """ - MONGO_VER = self.mongodb_version - ORDER_BY_KEY = 'sort' if MONGO_VER == MONGODB_32 else '$orderby' + ORDER_BY_KEY, CMD_QUERY_KEY = get_key_compat(self.mongodb_version) class BlogPost(Document): title = StringField() @@ -1387,13 +1390,13 @@ class QuerySetTest(unittest.TestCase): with db_ops_tracker() as q: BlogPost.objects.get(title='whatever') self.assertEqual(len(q.get_ops()), 1) - self.assertNotIn(ORDER_BY_KEY, q.get_ops()[0]['query']) + self.assertNotIn(ORDER_BY_KEY, q.get_ops()[0][CMD_QUERY_KEY]) # Ordering should be ignored for .get even if we set it explicitly with db_ops_tracker() as q: BlogPost.objects.order_by('-title').get(title='whatever') self.assertEqual(len(q.get_ops()), 1) - self.assertNotIn(ORDER_BY_KEY, q.get_ops()[0]['query']) + self.assertNotIn(ORDER_BY_KEY, q.get_ops()[0][CMD_QUERY_KEY]) def test_find_embedded(self): """Ensure that an embedded document is properly returned from @@ -2532,6 +2535,7 @@ class QuerySetTest(unittest.TestCase): def test_comment(self): """Make sure adding a comment to the query gets added to the query""" MONGO_VER = self.mongodb_version + _, CMD_QUERY_KEY = get_key_compat(MONGO_VER) QUERY_KEY = 'filter' if MONGO_VER >= MONGODB_32 else '$query' COMMENT_KEY = 'comment' if MONGO_VER >= MONGODB_32 else '$comment' @@ -2550,8 +2554,8 @@ class QuerySetTest(unittest.TestCase): ops = q.get_ops() self.assertEqual(len(ops), 2) for op in ops: - self.assertEqual(op['query'][QUERY_KEY], {'age': {'$gte': 18}}) - self.assertEqual(op['query'][COMMENT_KEY], 'looking for an adult') + self.assertEqual(op[CMD_QUERY_KEY][QUERY_KEY], {'age': {'$gte': 18}}) + self.assertEqual(op[CMD_QUERY_KEY][COMMENT_KEY], 'looking for an adult') def test_map_reduce(self): """Ensure map/reduce is both mapping and reducing. @@ -5240,8 +5244,7 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(op['nreturned'], 1) def test_bool_with_ordering(self): - MONGO_VER = self.mongodb_version - ORDER_BY_KEY = 'sort' if MONGO_VER >= MONGODB_32 else '$orderby' + ORDER_BY_KEY, CMD_QUERY_KEY = get_key_compat(self.mongodb_version) class Person(Document): name = StringField() @@ -5260,21 +5263,22 @@ class QuerySetTest(unittest.TestCase): op = q.db.system.profile.find({"ns": {"$ne": "%s.system.indexes" % q.db.name}})[0] - self.assertNotIn(ORDER_BY_KEY, op['query']) + self.assertNotIn(ORDER_BY_KEY, op[CMD_QUERY_KEY]) # Check that normal query uses orderby qs2 = Person.objects.order_by('name') - with query_counter() as p: + with query_counter() as q: for x in qs2: pass - op = p.db.system.profile.find({"ns": + op = q.db.system.profile.find({"ns": {"$ne": "%s.system.indexes" % q.db.name}})[0] - self.assertIn(ORDER_BY_KEY, op['query']) + self.assertIn(ORDER_BY_KEY, op[CMD_QUERY_KEY]) def test_bool_with_ordering_from_meta_dict(self): + ORDER_BY_KEY, CMD_QUERY_KEY = get_key_compat(self.mongodb_version) class Person(Document): name = StringField() @@ -5296,7 +5300,7 @@ class QuerySetTest(unittest.TestCase): op = q.db.system.profile.find({"ns": {"$ne": "%s.system.indexes" % q.db.name}})[0] - self.assertNotIn('$orderby', op['query'], + self.assertNotIn('$orderby', op[CMD_QUERY_KEY], 'BaseQuerySet must remove orderby from meta in boolen test') self.assertEqual(Person.objects.first().name, 'A') From 64b63e9d52924f60af2a938b769fb12631494d3f Mon Sep 17 00:00:00 2001 From: George Pearson Date: Fri, 24 May 2019 12:49:04 +0100 Subject: [PATCH 47/71] Use update_one instead of deprecated update #1899 --- mongoengine/document.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/mongoengine/document.py b/mongoengine/document.py index 5ccedbfa..03b659e3 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -502,8 +502,12 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): update_doc = self._get_update_doc() if update_doc: upsert = save_condition is None - last_error = collection.update(select_dict, update_doc, - upsert=upsert, **write_concern) + with set_write_concern(collection, write_concern) as wc_collection: + last_error = wc_collection.update_one( + select_dict, + update_doc, + upsert=upsert + ).raw_result if not upsert and last_error['n'] == 0: raise SaveConditionError('Race condition preventing' ' document update detected') From 6e1c132ee82bb2e134d18ca0e34f4b104624570e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Sun, 26 May 2019 22:17:58 +0200 Subject: [PATCH 48/71] Improve minor things in the tests --- tests/document/indexes.py | 2 +- tests/document/instance.py | 16 ++++++++-------- tests/fields/fields.py | 6 +++--- tests/fields/file_tests.py | 24 ++++++++++++------------ tests/queryset/queryset.py | 6 +++--- 5 files changed, 27 insertions(+), 27 deletions(-) diff --git a/tests/document/indexes.py b/tests/document/indexes.py index b0b78923..34771e8a 100644 --- a/tests/document/indexes.py +++ b/tests/document/indexes.py @@ -409,7 +409,7 @@ class IndexesTest(unittest.TestCase): self.assertEqual(2, User.objects.count()) info = User.objects._collection.index_information() - self.assertEqual(info.keys(), ['_id_']) + self.assertEqual(list(info.keys()), ['_id_']) User.ensure_indexes() info = User.objects._collection.index_information() diff --git a/tests/document/instance.py b/tests/document/instance.py index e1379a5d..5746f1fc 100644 --- a/tests/document/instance.py +++ b/tests/document/instance.py @@ -3204,7 +3204,7 @@ class InstanceTest(MongoDBTestCase): p2.name = 'alon2' p2.save() p3 = Person.objects().only('created_on')[0] - self.assertEquals(orig_created_on, p3.created_on) + self.assertEqual(orig_created_on, p3.created_on) class Person(Document): created_on = DateTimeField(default=lambda: datetime.utcnow()) @@ -3213,10 +3213,10 @@ class InstanceTest(MongoDBTestCase): p4 = Person.objects()[0] p4.save() - self.assertEquals(p4.height, 189) + self.assertEqual(p4.height, 189) # However the default will not be fixed in DB - self.assertEquals(Person.objects(height=189).count(), 0) + self.assertEqual(Person.objects(height=189).count(), 0) # alter DB for the new default coll = Person._get_collection() @@ -3224,17 +3224,17 @@ class InstanceTest(MongoDBTestCase): if 'height' not in person: coll.update_one({'_id': person['_id']}, {'$set': {'height': 189}}) - self.assertEquals(Person.objects(height=189).count(), 1) + self.assertEqual(Person.objects(height=189).count(), 1) def test_from_son(self): # 771 class MyPerson(self.Person): meta = dict(shard_key=["id"]) p = MyPerson.from_json('{"name": "name", "age": 27}', created=True) - self.assertEquals(p.id, None) + self.assertEqual(p.id, None) p.id = "12345" # in case it is not working: "OperationError: Shard Keys are immutable..." will be raised here p = MyPerson._from_son({"name": "name", "age": 27}, created=True) - self.assertEquals(p.id, None) + self.assertEqual(p.id, None) p.id = "12345" # in case it is not working: "OperationError: Shard Keys are immutable..." will be raised here def test_from_son_created_False_without_id(self): @@ -3312,7 +3312,7 @@ class InstanceTest(MongoDBTestCase): u_from_db = User.objects.get(name='user') u_from_db.height = None u_from_db.save() - self.assertEquals(u_from_db.height, None) + self.assertEqual(u_from_db.height, None) # 864 self.assertEqual(u_from_db.str_fld, None) self.assertEqual(u_from_db.int_fld, None) @@ -3326,7 +3326,7 @@ class InstanceTest(MongoDBTestCase): u.save() User.objects(name='user').update_one(set__height=None, upsert=True) u_from_db = User.objects.get(name='user') - self.assertEquals(u_from_db.height, None) + self.assertEqual(u_from_db.height, None) def test_not_saved_eq(self): """Ensure we can compare documents not saved. diff --git a/tests/fields/fields.py b/tests/fields/fields.py index 3b66f2de..5eaee4be 100644 --- a/tests/fields/fields.py +++ b/tests/fields/fields.py @@ -1372,7 +1372,7 @@ class FieldTest(MongoDBTestCase): brother = Brother(name="Bob", sibling=sister) brother.save() - self.assertEquals(Brother.objects[0].sibling.name, sister.name) + self.assertEqual(Brother.objects[0].sibling.name, sister.name) def test_reference_abstract_class(self): """Ensure that an abstract class instance cannot be used in the @@ -2045,8 +2045,8 @@ class FieldTest(MongoDBTestCase): Dog().save() Fish().save() Human().save() - self.assertEquals(Animal.objects(_cls__in=["Animal.Mammal.Dog", "Animal.Fish"]).count(), 2) - self.assertEquals(Animal.objects(_cls__in=["Animal.Fish.Guppy"]).count(), 0) + self.assertEqual(Animal.objects(_cls__in=["Animal.Mammal.Dog", "Animal.Fish"]).count(), 2) + self.assertEqual(Animal.objects(_cls__in=["Animal.Fish.Guppy"]).count(), 0) def test_sparse_field(self): class Doc(Document): diff --git a/tests/fields/file_tests.py b/tests/fields/file_tests.py index 4ff6865b..a7722458 100644 --- a/tests/fields/file_tests.py +++ b/tests/fields/file_tests.py @@ -320,16 +320,16 @@ class FileTest(MongoDBTestCase): files = db.fs.files.find() chunks = db.fs.chunks.find() - self.assertEquals(len(list(files)), 1) - self.assertEquals(len(list(chunks)), 1) + self.assertEqual(len(list(files)), 1) + self.assertEqual(len(list(chunks)), 1) # Deleting the docoument should delete the files testfile.delete() files = db.fs.files.find() chunks = db.fs.chunks.find() - self.assertEquals(len(list(files)), 0) - self.assertEquals(len(list(chunks)), 0) + self.assertEqual(len(list(files)), 0) + self.assertEqual(len(list(chunks)), 0) # Test case where we don't store a file in the first place testfile = TestFile() @@ -337,15 +337,15 @@ class FileTest(MongoDBTestCase): files = db.fs.files.find() chunks = db.fs.chunks.find() - self.assertEquals(len(list(files)), 0) - self.assertEquals(len(list(chunks)), 0) + self.assertEqual(len(list(files)), 0) + self.assertEqual(len(list(chunks)), 0) testfile.delete() files = db.fs.files.find() chunks = db.fs.chunks.find() - self.assertEquals(len(list(files)), 0) - self.assertEquals(len(list(chunks)), 0) + self.assertEqual(len(list(files)), 0) + self.assertEqual(len(list(chunks)), 0) # Test case where we overwrite the file testfile = TestFile() @@ -358,15 +358,15 @@ class FileTest(MongoDBTestCase): files = db.fs.files.find() chunks = db.fs.chunks.find() - self.assertEquals(len(list(files)), 1) - self.assertEquals(len(list(chunks)), 1) + self.assertEqual(len(list(files)), 1) + self.assertEqual(len(list(chunks)), 1) testfile.delete() files = db.fs.files.find() chunks = db.fs.chunks.find() - self.assertEquals(len(list(files)), 0) - self.assertEquals(len(list(chunks)), 0) + self.assertEqual(len(list(files)), 0) + self.assertEqual(len(list(chunks)), 0) def test_image_field(self): if not HAS_PIL: diff --git a/tests/queryset/queryset.py b/tests/queryset/queryset.py index 04042350..51da663a 100644 --- a/tests/queryset/queryset.py +++ b/tests/queryset/queryset.py @@ -91,7 +91,7 @@ class QuerySetTest(unittest.TestCase): results = list(people) self.assertIsInstance(results[0], self.Person) - self.assertIsInstance(results[0].id, (ObjectId, str, unicode)) + self.assertIsInstance(results[0].id, ObjectId) self.assertEqual(results[0], user_a) self.assertEqual(results[0].name, 'User A') @@ -5609,8 +5609,8 @@ class QuerySetTest(unittest.TestCase): Animal(is_mamal=False).save() Cat(is_mamal=True, whiskers_length=5.1).save() ScottishCat(is_mamal=True, folded_ears=True).save() - self.assertEquals(Animal.objects(folded_ears=True).count(), 1) - self.assertEquals(Animal.objects(whiskers_length=5.1).count(), 1) + self.assertEqual(Animal.objects(folded_ears=True).count(), 1) + self.assertEqual(Animal.objects(whiskers_length=5.1).count(), 1) def test_loop_over_invalid_id_does_not_crash(self): class Person(Document): From 7d0687ec73c6d94ccdd9a77ee21ac82bca9a065b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Sat, 11 May 2019 22:34:25 +0200 Subject: [PATCH 49/71] custom field validator is now expected to raise a ValidationError (drop support for returning True/False) --- docs/changelog.rst | 2 ++ mongoengine/base/fields.py | 21 ++++++++++++------- mongoengine/errors.py | 5 +++++ tests/fields/fields.py | 43 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 64 insertions(+), 7 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 6ac8da93..6864627b 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -5,6 +5,8 @@ Changelog Development =========== - Add support for MongoDB 3.6 and Python3.7 in travis +- BREAKING CHANGE: Changed the custom field validator (i.e `validation` parameter of Field) so that it now requires: + the callable to raise a ValidationError (i.o return True/False). - Fix querying on List(EmbeddedDocument) subclasses fields #1961 #1492 - Fix querying on (Generic)EmbeddedDocument subclasses fields #475 - expose `mongoengine.connection.disconnect` and `mongoengine.connection.disconnect_all` diff --git a/mongoengine/base/fields.py b/mongoengine/base/fields.py index 598eb606..5962df14 100644 --- a/mongoengine/base/fields.py +++ b/mongoengine/base/fields.py @@ -11,8 +11,7 @@ from mongoengine.base.common import UPDATE_OPERATORS from mongoengine.base.datastructures import (BaseDict, BaseList, EmbeddedDocumentList) from mongoengine.common import _import_class -from mongoengine.errors import ValidationError - +from mongoengine.errors import ValidationError, DeprecatedError __all__ = ('BaseField', 'ComplexBaseField', 'ObjectIdField', 'GeoJsonBaseField') @@ -53,8 +52,8 @@ class BaseField(object): unique with. :param primary_key: Mark this field as the primary key. Defaults to False. :param validation: (optional) A callable to validate the value of the - field. Generally this is deprecated in favour of the - `FIELD.validate` method + field. The callable takes the value as parameter and should raise + a ValidationError if validation fails :param choices: (optional) The valid choices :param null: (optional) If the field value can be null. If no and there is a default value then the default value is set @@ -226,10 +225,18 @@ class BaseField(object): # check validation argument if self.validation is not None: if callable(self.validation): - if not self.validation(value): - self.error('Value does not match custom validation method') + try: + # breaking change of 0.18 + # Get rid of True/False-type return for the validation method + # in favor of having validation raising a ValidationError + ret = self.validation(value) + if ret is not None: + raise DeprecatedError('validation argument for `%s` must not return anything, ' + 'it should raise a ValidationError if validation fails' % self.name) + except ValidationError as ex: + self.error(str(ex)) else: - raise ValueError('validation argument for "%s" must be a ' + raise ValueError('validation argument for `"%s"` must be a ' 'callable.' % self.name) self.validate(value, **kwargs) diff --git a/mongoengine/errors.py b/mongoengine/errors.py index b0009cbc..4aecef5e 100644 --- a/mongoengine/errors.py +++ b/mongoengine/errors.py @@ -142,3 +142,8 @@ class ValidationError(AssertionError): for k, v in iteritems(self.to_dict()): error_dict[generate_key(v)].append(k) return ' '.join(['%s: %s' % (k, v) for k, v in iteritems(error_dict)]) + + +class DeprecatedError(Exception): + """Raise when a user uses a feature that has been Deprecated""" + pass diff --git a/tests/fields/fields.py b/tests/fields/fields.py index 5eaee4be..68baab46 100644 --- a/tests/fields/fields.py +++ b/tests/fields/fields.py @@ -12,6 +12,7 @@ from mongoengine import Document, StringField, IntField, DateTimeField, DateFiel FieldDoesNotExist, EmbeddedDocumentListField, MultipleObjectsReturned, NotUniqueError, BooleanField,\ ObjectIdField, SortedListField, GenericLazyReferenceField, LazyReferenceField, DynamicDocument from mongoengine.base import (BaseField, EmbeddedDocumentList, _document_registry) +from mongoengine.errors import DeprecatedError from tests.utils import MongoDBTestCase @@ -56,6 +57,48 @@ class FieldTest(MongoDBTestCase): self.assertEqual( data_to_be_saved, ['age', 'created', 'day', 'name', 'userid']) + def test_custom_field_validation_raise_deprecated_error_when_validation_return_something(self): + # Covers introduction of a breaking change in the validation parameter (0.18) + def _not_empty(z): + return bool(z) + + class Person(Document): + name = StringField(validation=_not_empty) + + Person.drop_collection() + + error = ("validation argument for `name` must not return anything, " + "it should raise a ValidationError if validation fails") + + with self.assertRaises(DeprecatedError) as ctx_err: + Person(name="").validate() + self.assertEqual(str(ctx_err.exception), error) + + with self.assertRaises(DeprecatedError) as ctx_err: + Person(name="").save() + self.assertEqual(str(ctx_err.exception), error) + + def test_custom_field_validation_raise_validation_error(self): + def _not_empty(z): + if not z: + raise ValidationError('cantbeempty') + + class Person(Document): + name = StringField(validation=_not_empty) + + Person.drop_collection() + + with self.assertRaises(ValidationError) as ctx_err: + Person(name="").validate() + self.assertEqual("ValidationError (Person:None) (cantbeempty: ['name'])", str(ctx_err.exception)) + + with self.assertRaises(ValidationError): + Person(name="").save() + self.assertEqual("ValidationError (Person:None) (cantbeempty: ['name'])", str(ctx_err.exception)) + + Person(name="garbage").validate() + Person(name="garbage").save() + def test_default_values_set_to_None(self): """Ensure that default field values are used even when we explcitly initialize the doc with None values. From f00c9dc4d654de0510d6df566a8bc9b1781bb471 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Thu, 16 May 2019 22:53:31 +0200 Subject: [PATCH 50/71] Fix flake8 import error --- mongoengine/base/fields.py | 2 +- mongoengine/errors.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mongoengine/base/fields.py b/mongoengine/base/fields.py index 5962df14..fe96f15b 100644 --- a/mongoengine/base/fields.py +++ b/mongoengine/base/fields.py @@ -11,7 +11,7 @@ from mongoengine.base.common import UPDATE_OPERATORS from mongoengine.base.datastructures import (BaseDict, BaseList, EmbeddedDocumentList) from mongoengine.common import _import_class -from mongoengine.errors import ValidationError, DeprecatedError +from mongoengine.errors import DeprecatedError, ValidationError __all__ = ('BaseField', 'ComplexBaseField', 'ObjectIdField', 'GeoJsonBaseField') diff --git a/mongoengine/errors.py b/mongoengine/errors.py index 4aecef5e..bea1d3dc 100644 --- a/mongoengine/errors.py +++ b/mongoengine/errors.py @@ -6,7 +6,7 @@ from six import iteritems __all__ = ('NotRegistered', 'InvalidDocumentError', 'LookUpError', 'DoesNotExist', 'MultipleObjectsReturned', 'InvalidQueryError', 'OperationError', 'NotUniqueError', 'FieldDoesNotExist', - 'ValidationError', 'SaveConditionError') + 'ValidationError', 'SaveConditionError', 'DeprecatedError') class NotRegistered(Exception): From 4334955e39781e273c26293833478f4e400a0170 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20W=C3=B3jcik?= Date: Fri, 31 May 2019 11:01:15 +0200 Subject: [PATCH 51/71] Update the test matrix to reflect what's supported in 2019 (#2066) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Previously, we were running the test suite for several combinations of MongoDB, Python, and PyMongo: - PyPy, MongoDB v2.6, PyMongo v3.x (which really means v3.6.1 at the moment) - Python v2.7, MongoDB v2.6, PyMongo v3.x - Python v3.5, MongoDB v2.6, PyMongo v3.x - Python v3.6, MongoDB v2.6, PyMongo v3.x - Python v2.7, MongoDB v3.0, PyMongo v3.5.0 - Python v3.6, MongoDB v3.0, PyMongo v3.5.0 - Python v3.5, MongoDB v3.2, PyMongo v3.x - Python v3.6, MongoDB v3.2, PyMongo v3.x - Python v3.6, MongoDB v3.4, PyMongo v3.x - Python v3.6, MongoDB v3.6, PyMongo v3.x There were a couple issues with this setup: 1. MongoDB v2.6 – v3.2 have reached their End of Life already (v2.6 almost 3 years ago!). See the "MongoDB Server" section on https://www.mongodb.com/support-policy. 2. We were only testing two recent-ish PyMongo versions (v3.5.0 & v3.6.1). We were not testing the oldest actively supported MongoDB/PyMongo/Python setup. This PR updates the test matrix so that these problems are solved. For the sake of simplicity, it does not yet attempt to cover MongoDB v4.0: - PyPy, MongoDB v3.4, PyMongo v3.x (aka v3.6.1 at the moment) - Python v2.7, MongoDB v3.4, PyMongo v3.x - Python v3.5, MongoDB v3.4, PyMongo v3.x - Python v3.6, MongoDB v3.4, PyMongo v3.x - Python v2.7, MongoDB v3.4, PyMongo v3.4 - Python v3.6, MongoDB v3.6, PyMongo v3.x --- .travis.yml | 41 ++++++++------- README.rst | 8 +-- mongoengine/connection.py | 18 +++++-- mongoengine/mongodb_support.py | 4 -- requirements.txt | 2 +- setup.py | 2 +- tests/document/class_methods.py | 2 - tests/document/indexes.py | 89 +++++++++++++++------------------ tests/document/instance.py | 4 -- tests/queryset/geo.py | 10 +--- tests/queryset/modify.py | 3 -- tests/queryset/queryset.py | 35 +++---------- tests/utils.py | 46 ++++++----------- tox.ini | 2 +- 14 files changed, 105 insertions(+), 161 deletions(-) diff --git a/.travis.yml b/.travis.yml index 909183c1..3186ea1c 100644 --- a/.travis.yml +++ b/.travis.yml @@ -2,11 +2,18 @@ # PyMongo combinations. However, that would result in an overly long build # with a very large number of jobs, hence we only test a subset of all the # combinations: -# * MongoDB v2.6 is currently the "main" version tested against Python v2.7, -# v3.5, v3.6, PyPy, and PyMongo v3.x. -# * MongoDB v3.0 & v3.2 are tested against Python v2.7, v3.5 & v3.6 -# and Pymongo v3.5 & v3.x -# * MongoDB v3.4 is tested against v3.6 and Pymongo v3.x +# * MongoDB v3.4 & the latest PyMongo v3.x is currently the "main" setup, +# tested against Python v2.7, v3.5, v3.6, and PyPy. +# * Besides that, we test the lowest actively supported Python/MongoDB/PyMongo +# combination: MongoDB v3.4, PyMongo v3.4, Python v2.7. +# * MongoDB v3.6 is tested against Python v3.6, and PyMongo v3.6, v3.7, v3.8. +# +# We should periodically check MongoDB Server versions supported by MongoDB +# Inc., add newly released versions to the test matrix, and remove versions +# which have reached their End of Life. See: +# 1. https://www.mongodb.com/support-policy. +# 2. https://docs.mongodb.com/ecosystem/drivers/driver-compatibility-reference/#python-driver-compatibility +# # Reminder: Update README.rst if you change MongoDB versions we test. language: python @@ -18,7 +25,7 @@ python: - pypy env: -- MONGODB=2.6 PYMONGO=3.x +- MONGODB=3.4 PYMONGO=3.x matrix: # Finish the build as soon as one job fails @@ -26,15 +33,7 @@ matrix: include: - python: 2.7 - env: MONGODB=3.0 PYMONGO=3.5 - - python: 3.5 - env: MONGODB=3.2 PYMONGO=3.x - - python: 3.6 - env: MONGODB=3.0 PYMONGO=3.5 - - python: 3.6 - env: MONGODB=3.2 PYMONGO=3.x - - python: 3.6 - env: MONGODB=3.4 PYMONGO=3.x + env: MONGODB=3.4 PYMONGO=3.4.x - python: 3.6 env: MONGODB=3.6 PYMONGO=3.x @@ -86,15 +85,15 @@ deploy: password: secure: QMyatmWBnC6ZN3XLW2+fTBDU4LQcp1m/LjR2/0uamyeUzWKdlOoh/Wx5elOgLwt/8N9ppdPeG83ose1jOz69l5G0MUMjv8n/RIcMFSpCT59tGYqn3kh55b0cIZXFT9ar+5cxlif6a5rS72IHm5li7QQyxexJIII6Uxp0kpvUmek= - # create a source distribution and a pure python wheel for faster installs + # Create a source distribution and a pure python wheel for faster installs. distributions: "sdist bdist_wheel" - # only deploy on tagged commits (aka GitHub releases) and only for the - # parent repo's builds running Python 2.7 along with PyMongo v3.x (we run - # Travis against many different Python and PyMongo versions and we don't - # want the deploy to occur multiple times). + # Only deploy on tagged commits (aka GitHub releases) and only for the parent + # repo's builds running Python v2.7 along with PyMongo v3.x and MongoDB v3.4. + # We run Travis against many different Python, PyMongo, and MongoDB versions + # and we don't want the deploy to occur multiple times). on: tags: true repo: MongoEngine/mongoengine - condition: "$PYMONGO = 3.x" + condition: ($PYMONGO = 3.x) AND ($MONGODB = 3.4) python: 2.7 diff --git a/README.rst b/README.rst index cb279d2c..fe5f5f22 100644 --- a/README.rst +++ b/README.rst @@ -26,10 +26,10 @@ an `API reference `_. Supported MongoDB Versions ========================== -MongoEngine is currently tested against MongoDB v2.6, v3.0, v3.2, v3.4 and v3.6. Future -versions should be supported as well, but aren't actively tested at the moment. -Make sure to open an issue or submit a pull request if you experience any -problems with MongoDB v3.6+. +MongoEngine is currently tested against MongoDB v3.4 and v3.6. Future versions +should be supported as well, but aren't actively tested at the moment. Make +sure to open an issue or submit a pull request if you experience any problems +with MongoDB version > 3.6. Installation ============ diff --git a/mongoengine/connection.py b/mongoengine/connection.py index e12980e6..e0399fde 100644 --- a/mongoengine/connection.py +++ b/mongoengine/connection.py @@ -117,10 +117,22 @@ def _get_connection_settings( ReadPreference.PRIMARY, ReadPreference.PRIMARY_PREFERRED, ReadPreference.SECONDARY, - ReadPreference.SECONDARY_PREFERRED) - read_pf_mode = uri_options['readpreference'].lower() + ReadPreference.SECONDARY_PREFERRED, + ) + + # Starting with PyMongo v3.5, the "readpreference" option is + # returned as a string (e.g. "secondaryPreferred") and not an + # int (e.g. 3). + # TODO simplify the code below once we drop support for + # PyMongo v3.4. + read_pf_mode = uri_options['readpreference'] + if isinstance(read_pf_mode, six.string_types): + read_pf_mode = read_pf_mode.lower() for preference in read_preferences: - if preference.name.lower() == read_pf_mode: + if ( + preference.name.lower() == read_pf_mode or + preference.mode == read_pf_mode + ): conn_settings['read_preference'] = preference break else: diff --git a/mongoengine/mongodb_support.py b/mongoengine/mongodb_support.py index 8e414075..8234a616 100644 --- a/mongoengine/mongodb_support.py +++ b/mongoengine/mongodb_support.py @@ -7,10 +7,6 @@ from mongoengine.connection import get_connection # Constant that can be used to compare the version retrieved with # get_mongodb_version() MONGODB_36 = (3, 6) -MONGODB_34 = (3, 4) -MONGODB_32 = (3, 2) -MONGODB_3 = (3, 0) -MONGODB_26 = (2, 6) def get_mongodb_version(): diff --git a/requirements.txt b/requirements.txt index 38e0b20f..9bb319a5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ nose -pymongo>=3.5 +pymongo>=3.4 six==1.10.0 flake8 flake8-import-order diff --git a/setup.py b/setup.py index c8e9c038..f1f5dea7 100644 --- a/setup.py +++ b/setup.py @@ -80,7 +80,7 @@ setup( long_description=LONG_DESCRIPTION, platforms=['any'], classifiers=CLASSIFIERS, - install_requires=['pymongo>=3.5', 'six'], + install_requires=['pymongo>=3.4', 'six'], test_suite='nose.collector', **extra_opts ) diff --git a/tests/document/class_methods.py b/tests/document/class_methods.py index 421618e4..4fc648b7 100644 --- a/tests/document/class_methods.py +++ b/tests/document/class_methods.py @@ -6,7 +6,6 @@ from mongoengine.pymongo_support import list_collection_names from mongoengine.queryset import NULLIFY, PULL from mongoengine.connection import get_db -from tests.utils import requires_mongodb_gte_26 __all__ = ("ClassMethodsTest", ) @@ -187,7 +186,6 @@ class ClassMethodsTest(unittest.TestCase): self.assertEqual(BlogPostWithTags.compare_indexes(), {'missing': [], 'extra': []}) self.assertEqual(BlogPostWithCustomField.compare_indexes(), {'missing': [], 'extra': []}) - @requires_mongodb_gte_26 def test_compare_indexes_for_text_indexes(self): """ Ensure that compare_indexes behaves correctly for text indexes """ diff --git a/tests/document/indexes.py b/tests/document/indexes.py index 34771e8a..6f486e9f 100644 --- a/tests/document/indexes.py +++ b/tests/document/indexes.py @@ -9,8 +9,7 @@ from six import iteritems from mongoengine import * from mongoengine.connection import get_db -from mongoengine.mongodb_support import get_mongodb_version, MONGODB_32, MONGODB_3 -from tests.utils import requires_mongodb_gte_26, requires_mongodb_lte_32, requires_mongodb_gte_34 +from mongoengine.mongodb_support import get_mongodb_version __all__ = ("IndexesTest", ) @@ -478,8 +477,6 @@ class IndexesTest(unittest.TestCase): def test_covered_index(self): """Ensure that covered indexes can be used """ - IS_MONGODB_3 = get_mongodb_version() >= MONGODB_3 - class Test(Document): a = IntField() b = IntField() @@ -497,33 +494,38 @@ class IndexesTest(unittest.TestCase): # Need to be explicit about covered indexes as mongoDB doesn't know if # the documents returned might have more keys in that here. query_plan = Test.objects(id=obj.id).exclude('a').explain() - if not IS_MONGODB_3: - self.assertFalse(query_plan['indexOnly']) - else: - self.assertEqual(query_plan.get('queryPlanner').get('winningPlan').get('inputStage').get('stage'), 'IDHACK') + self.assertEqual( + query_plan.get('queryPlanner').get('winningPlan').get('inputStage').get('stage'), + 'IDHACK' + ) query_plan = Test.objects(id=obj.id).only('id').explain() - if not IS_MONGODB_3: - self.assertTrue(query_plan['indexOnly']) - else: - self.assertEqual(query_plan.get('queryPlanner').get('winningPlan').get('inputStage').get('stage'), 'IDHACK') + self.assertEqual( + query_plan.get('queryPlanner').get('winningPlan').get('inputStage').get('stage'), + 'IDHACK' + ) query_plan = Test.objects(a=1).only('a').exclude('id').explain() - if not IS_MONGODB_3: - self.assertTrue(query_plan['indexOnly']) - else: - self.assertEqual(query_plan.get('queryPlanner').get('winningPlan').get('inputStage').get('stage'), 'IXSCAN') - self.assertEqual(query_plan.get('queryPlanner').get('winningPlan').get('stage'), 'PROJECTION') + self.assertEqual( + query_plan.get('queryPlanner').get('winningPlan').get('inputStage').get('stage'), + 'IXSCAN' + ) + self.assertEqual( + query_plan.get('queryPlanner').get('winningPlan').get('stage'), + 'PROJECTION' + ) query_plan = Test.objects(a=1).explain() - if not IS_MONGODB_3: - self.assertFalse(query_plan['indexOnly']) - else: - self.assertEqual(query_plan.get('queryPlanner').get('winningPlan').get('inputStage').get('stage'), 'IXSCAN') - self.assertEqual(query_plan.get('queryPlanner').get('winningPlan').get('stage'), 'FETCH') + self.assertEqual( + query_plan.get('queryPlanner').get('winningPlan').get('inputStage').get('stage'), + 'IXSCAN' + ) + self.assertEqual( + query_plan.get('queryPlanner').get('winningPlan').get('stage'), + 'FETCH' + ) def test_index_on_id(self): - class BlogPost(Document): meta = { 'indexes': [ @@ -565,13 +567,10 @@ class IndexesTest(unittest.TestCase): self.assertEqual(BlogPost.objects.count(), 10) self.assertEqual(BlogPost.objects.hint().count(), 10) - if MONGO_VER >= MONGODB_32: - # Mongo32 throws an error if an index exists (i.e `tags` in our case) - # and you use hint on an index name that does not exist - with self.assertRaises(OperationFailure): - BlogPost.objects.hint([('ZZ', 1)]).count() - else: - self.assertEqual(BlogPost.objects.hint([('ZZ', 1)]).count(), 10) + # MongoDB v3.2+ throws an error if an index exists (i.e `tags` in our + # case) and you use hint on an index name that does not exist. + with self.assertRaises(OperationFailure): + BlogPost.objects.hint([('ZZ', 1)]).count() self.assertEqual(BlogPost.objects.hint(TAGS_INDEX_NAME).count(), 10) @@ -598,9 +597,8 @@ class IndexesTest(unittest.TestCase): # Ensure backwards compatibility for errors self.assertRaises(OperationError, post2.save) - @requires_mongodb_gte_34 - def test_primary_key_unique_not_working_under_mongo_34(self): - """Relates to #1445""" + def test_primary_key_unique_not_working(self): + """Relates to #1445""" class Blog(Document): id = StringField(primary_key=True, unique=True) @@ -608,21 +606,17 @@ class IndexesTest(unittest.TestCase): with self.assertRaises(OperationFailure) as ctx_err: Blog(id='garbage').save() - try: - self.assertIn("The field 'unique' is not valid for an _id index specification", str(ctx_err.exception)) - except AssertionError: - # error is slightly different on python 3.6 - self.assertIn("The field 'background' is not valid for an _id index specification", str(ctx_err.exception)) - @requires_mongodb_lte_32 - def test_primary_key_unique_working_under_mongo_32(self): - """Relates to #1445""" - class Blog(Document): - id = StringField(primary_key=True, unique=True) - - Blog.drop_collection() - - Blog(id='garbage').save() + # One of the errors below should happen. Which one depends on the + # PyMongo version and dict order. + err_msg = str(ctx_err.exception) + self.assertTrue( + any([ + "The field 'unique' is not valid for an _id index specification" in err_msg, + "The field 'background' is not valid for an _id index specification" in err_msg, + "The field 'sparse' is not valid for an _id index specification" in err_msg, + ]) + ) def test_unique_with(self): """Ensure that unique_with constraints are applied to fields. @@ -984,7 +978,6 @@ class IndexesTest(unittest.TestCase): info['provider_ids.foo_1_provider_ids.bar_1']['key']) self.assertTrue(info['provider_ids.foo_1_provider_ids.bar_1']['sparse']) - @requires_mongodb_gte_26 def test_text_indexes(self): class Book(Document): title = DictField() diff --git a/tests/document/instance.py b/tests/document/instance.py index 5746f1fc..cec019a9 100644 --- a/tests/document/instance.py +++ b/tests/document/instance.py @@ -28,8 +28,6 @@ from mongoengine.queryset import NULLIFY, Q from mongoengine.context_managers import switch_db, query_counter from mongoengine import signals -from tests.utils import requires_mongodb_gte_26 - TEST_IMAGE_PATH = os.path.join(os.path.dirname(__file__), '../fields/mongoengine.png') @@ -850,7 +848,6 @@ class InstanceTest(MongoDBTestCase): self.assertDbEqual([dict(other_doc.to_mongo()), dict(doc.to_mongo())]) - @requires_mongodb_gte_26 def test_modify_with_positional_push(self): class Content(EmbeddedDocument): keywords = ListField(StringField()) @@ -3368,7 +3365,6 @@ class InstanceTest(MongoDBTestCase): person.update(set__height=2.0) - @requires_mongodb_gte_26 def test_push_with_position(self): """Ensure that push with position works properly for an instance.""" class BlogPost(Document): diff --git a/tests/queryset/geo.py b/tests/queryset/geo.py index 240a94ab..45e6a089 100644 --- a/tests/queryset/geo.py +++ b/tests/queryset/geo.py @@ -3,7 +3,7 @@ import unittest from mongoengine import * -from tests.utils import MongoDBTestCase, requires_mongodb_gte_3 +from tests.utils import MongoDBTestCase __all__ = ("GeoQueriesTest",) @@ -70,9 +70,6 @@ class GeoQueriesTest(MongoDBTestCase): self.assertEqual(events.count(), 1) self.assertEqual(events[0], event2) - # $minDistance was added in MongoDB v2.6, but continued being buggy - # until v3.0; skip for older versions - @requires_mongodb_gte_3 def test_near_and_min_distance(self): """Ensure the "min_distance" operator works alongside the "near" operator. @@ -243,9 +240,6 @@ class GeoQueriesTest(MongoDBTestCase): events = self.Event.objects(location__geo_within_polygon=polygon2) self.assertEqual(events.count(), 0) - # $minDistance was added in MongoDB v2.6, but continued being buggy - # until v3.0; skip for older versions - @requires_mongodb_gte_3 def test_2dsphere_near_and_min_max_distance(self): """Ensure "min_distace" and "max_distance" operators work well together with the "near" operator in a 2dsphere index. @@ -328,8 +322,6 @@ class GeoQueriesTest(MongoDBTestCase): """Make sure PointField works properly in an embedded document.""" self._test_embedded(point_field_class=PointField) - # Needs MongoDB > 2.6.4 https://jira.mongodb.org/browse/SERVER-14039 - @requires_mongodb_gte_3 def test_spherical_geospatial_operators(self): """Ensure that spherical geospatial queries are working.""" class Point(Document): diff --git a/tests/queryset/modify.py b/tests/queryset/modify.py index 4b7c3da2..3c5879ba 100644 --- a/tests/queryset/modify.py +++ b/tests/queryset/modify.py @@ -2,8 +2,6 @@ import unittest from mongoengine import connect, Document, IntField, StringField, ListField -from tests.utils import requires_mongodb_gte_26 - __all__ = ("FindAndModifyTest",) @@ -96,7 +94,6 @@ class FindAndModifyTest(unittest.TestCase): self.assertEqual(old_doc.to_mongo(), {"_id": 1}) self.assertDbEqual([{"_id": 0, "value": 0}, {"_id": 1, "value": -1}]) - @requires_mongodb_gte_26 def test_modify_with_push(self): class BlogPost(Document): tags = ListField(StringField()) diff --git a/tests/queryset/queryset.py b/tests/queryset/queryset.py index 51da663a..04cfb061 100644 --- a/tests/queryset/queryset.py +++ b/tests/queryset/queryset.py @@ -17,10 +17,9 @@ from mongoengine import * from mongoengine.connection import get_connection, get_db from mongoengine.context_managers import query_counter, switch_db from mongoengine.errors import InvalidQueryError -from mongoengine.mongodb_support import get_mongodb_version, MONGODB_32, MONGODB_36 +from mongoengine.mongodb_support import get_mongodb_version, MONGODB_36 from mongoengine.queryset import (DoesNotExist, MultipleObjectsReturned, QuerySet, QuerySetManager, queryset_manager) -from tests.utils import requires_mongodb_gte_26 class db_ops_tracker(query_counter): @@ -32,7 +31,7 @@ class db_ops_tracker(query_counter): def get_key_compat(mongo_ver): - ORDER_BY_KEY = 'sort' if mongo_ver >= MONGODB_32 else '$orderby' + ORDER_BY_KEY = 'sort' CMD_QUERY_KEY = 'command' if mongo_ver >= MONGODB_36 else 'query' return ORDER_BY_KEY, CMD_QUERY_KEY @@ -598,7 +597,6 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(post.comments[0].by, 'joe') self.assertEqual(post.comments[0].votes.score, 4) - @requires_mongodb_gte_26 def test_update_min_max(self): class Scores(Document): high_score = IntField() @@ -616,7 +614,6 @@ class QuerySetTest(unittest.TestCase): Scores.objects(id=scores.id).update(max__high_score=500) self.assertEqual(Scores.objects.get(id=scores.id).high_score, 1000) - @requires_mongodb_gte_26 def test_update_multiple(self): class Product(Document): item = StringField() @@ -868,11 +865,7 @@ class QuerySetTest(unittest.TestCase): with query_counter() as q: self.assertEqual(q, 0) Blog.objects.insert(blogs, load_bulk=False) - - if MONGO_VER >= MONGODB_32: - self.assertEqual(q, 1) # 1 entry containing the list of inserts - else: - self.assertEqual(q, len(blogs)) # 1 entry per doc inserted + self.assertEqual(q, 1) # 1 entry containing the list of inserts self.assertEqual(Blog.objects.count(), len(blogs)) @@ -885,11 +878,7 @@ class QuerySetTest(unittest.TestCase): with query_counter() as q: self.assertEqual(q, 0) Blog.objects.insert(blogs) - - if MONGO_VER >= MONGODB_32: - self.assertEqual(q, 2) # 1 for insert 1 for fetch - else: - self.assertEqual(q, len(blogs)+1) # + 1 to fetch all docs + self.assertEqual(q, 2) # 1 for insert 1 for fetch Blog.drop_collection() @@ -2030,7 +2019,6 @@ class QuerySetTest(unittest.TestCase): pymongo_doc = BlogPost.objects.as_pymongo().first() self.assertNotIn('title', pymongo_doc) - @requires_mongodb_gte_26 def test_update_push_with_position(self): """Ensure that the 'push' update with position works properly. """ @@ -2555,8 +2543,8 @@ class QuerySetTest(unittest.TestCase): """Make sure adding a comment to the query gets added to the query""" MONGO_VER = self.mongodb_version _, CMD_QUERY_KEY = get_key_compat(MONGO_VER) - QUERY_KEY = 'filter' if MONGO_VER >= MONGODB_32 else '$query' - COMMENT_KEY = 'comment' if MONGO_VER >= MONGODB_32 else '$comment' + QUERY_KEY = 'filter' + COMMENT_KEY = 'comment' class User(Document): age = IntField() @@ -3370,7 +3358,6 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(Foo.objects.distinct("bar"), [bar]) - @requires_mongodb_gte_26 def test_text_indexes(self): class News(Document): title = StringField() @@ -3454,7 +3441,6 @@ class QuerySetTest(unittest.TestCase): 'brasil').order_by('$text_score').first() self.assertEqual(item.get_text_score(), max_text_score) - @requires_mongodb_gte_26 def test_distinct_handles_references_to_alias(self): register_connection('testdb', 'mongoenginetest2') @@ -4586,7 +4572,6 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(bars._cursor._Cursor__read_preference, ReadPreference.SECONDARY_PREFERRED) - @requires_mongodb_gte_26 def test_read_preference_aggregation_framework(self): class Bar(Document): txt = StringField() @@ -5354,7 +5339,6 @@ class QuerySetTest(unittest.TestCase): self.assertTrue(Person.objects._has_data(), 'Cursor has data and returned False') - @requires_mongodb_gte_26 def test_queryset_aggregation_framework(self): class Person(Document): name = StringField() @@ -5396,7 +5380,6 @@ class QuerySetTest(unittest.TestCase): {'_id': None, 'avg': 29, 'total': 2} ]) - @requires_mongodb_gte_26 def test_queryset_aggregation_with_skip(self): class Person(Document): name = StringField() @@ -5418,7 +5401,6 @@ class QuerySetTest(unittest.TestCase): {'_id': p3.pk, 'name': "SANDRA MARA"} ]) - @requires_mongodb_gte_26 def test_queryset_aggregation_with_limit(self): class Person(Document): name = StringField() @@ -5439,7 +5421,6 @@ class QuerySetTest(unittest.TestCase): {'_id': p1.pk, 'name': "ISABELLA LUANNA"} ]) - @requires_mongodb_gte_26 def test_queryset_aggregation_with_sort(self): class Person(Document): name = StringField() @@ -5462,7 +5443,6 @@ class QuerySetTest(unittest.TestCase): {'_id': p2.pk, 'name': "WILSON JUNIOR"} ]) - @requires_mongodb_gte_26 def test_queryset_aggregation_with_skip_with_limit(self): class Person(Document): name = StringField() @@ -5492,7 +5472,6 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(data, list(data2)) - @requires_mongodb_gte_26 def test_queryset_aggregation_with_sort_with_limit(self): class Person(Document): name = StringField() @@ -5534,7 +5513,6 @@ class QuerySetTest(unittest.TestCase): {'_id': p3.pk, 'name': "SANDRA MARA"}, ]) - @requires_mongodb_gte_26 def test_queryset_aggregation_with_sort_with_skip(self): class Person(Document): name = StringField() @@ -5555,7 +5533,6 @@ class QuerySetTest(unittest.TestCase): {'_id': p2.pk, 'name': "WILSON JUNIOR"} ]) - @requires_mongodb_gte_26 def test_queryset_aggregation_with_sort_with_skip_with_limit(self): class Person(Document): name = StringField() diff --git a/tests/utils.py b/tests/utils.py index 0ebb44a4..27d5ada7 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -5,7 +5,7 @@ from nose.plugins.skip import SkipTest from mongoengine import connect from mongoengine.connection import get_db, disconnect_all -from mongoengine.mongodb_support import get_mongodb_version, MONGODB_26, MONGODB_3, MONGODB_32, MONGODB_34 +from mongoengine.mongodb_support import get_mongodb_version MONGO_TEST_DB = 'mongoenginetest' # standard name for the test database @@ -35,8 +35,20 @@ def get_as_pymongo(doc): def _decorated_with_ver_requirement(func, mongo_version_req, oper): - """Return a given function decorated with the version requirement - for a particular MongoDB version tuple. + """Return a MongoDB version requirement decorator. + + The resulting decorator will raise a SkipTest exception if the current + MongoDB version doesn't match the provided version/operator. + + For example, if you define a decorator like so: + + def requires_mongodb_gte_36(func): + return _decorated_with_ver_requirement( + func, (3.6), oper=operator.ge + ) + + Then tests decorated with @requires_mongodb_gte_36 will be skipped if + ran against MongoDB < v3.6. :param mongo_version_req: The mongodb version requirement (tuple(int, int)) :param oper: The operator to apply (e.g: operator.ge) @@ -51,31 +63,3 @@ def _decorated_with_ver_requirement(func, mongo_version_req, oper): _inner.__name__ = func.__name__ _inner.__doc__ = func.__doc__ return _inner - - -def requires_mongodb_gte_34(func): - """Raise a SkipTest exception if we're working with MongoDB version - lower than v3.4 - """ - return _decorated_with_ver_requirement(func, MONGODB_34, oper=operator.ge) - - -def requires_mongodb_lte_32(func): - """Raise a SkipTest exception if we're working with MongoDB version - greater than v3.2. - """ - return _decorated_with_ver_requirement(func, MONGODB_32, oper=operator.le) - - -def requires_mongodb_gte_26(func): - """Raise a SkipTest exception if we're working with MongoDB version - lower than v2.6. - """ - return _decorated_with_ver_requirement(func, MONGODB_26, oper=operator.ge) - - -def requires_mongodb_gte_3(func): - """Raise a SkipTest exception if we're working with MongoDB version - lower than v3.0. - """ - return _decorated_with_ver_requirement(func, MONGODB_3, oper=operator.ge) diff --git a/tox.ini b/tox.ini index 815d2acc..40bcea8a 100644 --- a/tox.ini +++ b/tox.ini @@ -6,7 +6,7 @@ commands = python setup.py nosetests {posargs} deps = nose - mg35: PyMongo==3.5 + mg34x: PyMongo>=3.4,<3.5 mg3x: PyMongo>=3.0,<3.7 setenv = PYTHON_EGG_CACHE = {envdir}/python-eggs From 9ae8fe7c2d674a8dfd0e061368520fb534b990f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Wed, 8 May 2019 23:45:35 +0200 Subject: [PATCH 52/71] Improve perf of Doc.save by preventing a full to_mongo() call just to get the `created` variable --- docs/changelog.rst | 1 + mongoengine/base/document.py | 3 +- mongoengine/document.py | 10 ++-- tests/document/instance.py | 109 ++++++++++++++++++++++++++--------- 4 files changed, 89 insertions(+), 34 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 6864627b..45de682c 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -7,6 +7,7 @@ Development - Add support for MongoDB 3.6 and Python3.7 in travis - BREAKING CHANGE: Changed the custom field validator (i.e `validation` parameter of Field) so that it now requires: the callable to raise a ValidationError (i.o return True/False). +- Prevent an expensive call to to_mongo in Document.save() to improve performance #? - Fix querying on List(EmbeddedDocument) subclasses fields #1961 #1492 - Fix querying on (Generic)EmbeddedDocument subclasses fields #475 - expose `mongoengine.connection.disconnect` and `mongoengine.connection.disconnect_all` diff --git a/mongoengine/base/document.py b/mongoengine/base/document.py index 4cf34b4f..2e8dd9f1 100644 --- a/mongoengine/base/document.py +++ b/mongoengine/base/document.py @@ -293,8 +293,7 @@ class BaseDocument(object): """ Return as SON data ready for use with MongoDB. """ - if not fields: - fields = [] + fields = fields or [] data = SON() data['_id'] = None diff --git a/mongoengine/document.py b/mongoengine/document.py index 55ff538a..bf194c70 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -259,7 +259,7 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): data = super(Document, self).to_mongo(*args, **kwargs) # If '_id' is None, try and set it from self._data. If that - # doesn't exist either, remote '_id' from the SON completely. + # doesn't exist either, remove '_id' from the SON completely. if data['_id'] is None: if self._data.get('id') is None: del data['_id'] @@ -365,10 +365,11 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): .. versionchanged:: 0.10.7 Add signal_kwargs argument """ + signal_kwargs = signal_kwargs or {} + if self._meta.get('abstract'): raise InvalidDocumentError('Cannot save an abstract document.') - signal_kwargs = signal_kwargs or {} signals.pre_save.send(self.__class__, document=self, **signal_kwargs) if validate: @@ -377,9 +378,8 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): if write_concern is None: write_concern = {} - doc = self.to_mongo() - - created = ('_id' not in doc or self._created or force_insert) + doc_id = self.to_mongo(fields=['id']) + created = ('_id' not in doc_id or self._created or force_insert) signals.pre_save_post_validation.send(self.__class__, document=self, created=created, **signal_kwargs) diff --git a/tests/document/instance.py b/tests/document/instance.py index cec019a9..5d4cd1da 100644 --- a/tests/document/instance.py +++ b/tests/document/instance.py @@ -16,7 +16,7 @@ from mongoengine.pymongo_support import list_collection_names from tests import fixtures from tests.fixtures import (PickleEmbedded, PickleTest, PickleSignalsTest, PickleDynamicEmbedded, PickleDynamicTest) -from tests.utils import MongoDBTestCase +from tests.utils import MongoDBTestCase, get_as_pymongo from mongoengine import * from mongoengine.base import get_document, _document_registry @@ -715,39 +715,74 @@ class InstanceTest(MongoDBTestCase): acc1 = Account.objects.first() self.assertHasInstance(acc1._data["emails"][0], acc1) + def test_save_checks_that_clean_is_called(self): + class CustomError(Exception): + pass + + class TestDocument(Document): + def clean(self): + raise CustomError() + + with self.assertRaises(CustomError): + TestDocument().save() + + TestDocument().save(clean=False) + + def test_save_signal_pre_save_post_validation_makes_change_to_doc(self): + class BlogPost(Document): + content = StringField() + + @classmethod + def pre_save_post_validation(cls, sender, document, **kwargs): + document.content = 'checked' + + signals.pre_save_post_validation.connect(BlogPost.pre_save_post_validation, sender=BlogPost) + + BlogPost.drop_collection() + + post = BlogPost(content='unchecked').save() + self.assertEqual(post.content, 'checked') + # Make sure pre_save_post_validation changes makes it to the db + raw_doc = get_as_pymongo(post) + self.assertEqual( + raw_doc, + { + 'content': 'checked', + '_id': post.id + }) + def test_document_clean(self): class TestDocument(Document): status = StringField() - pub_date = DateTimeField() + cleaned = BooleanField(default=False) def clean(self): - if self.status == 'draft' and self.pub_date is not None: - msg = 'Draft entries may not have a publication date.' - raise ValidationError(msg) - # Set the pub_date for published items if not set. - if self.status == 'published' and self.pub_date is None: - self.pub_date = datetime.now() + self.cleaned = True TestDocument.drop_collection() - t = TestDocument(status="draft", pub_date=datetime.now()) - - with self.assertRaises(ValidationError) as cm: - t.save() - - expected_msg = "Draft entries may not have a publication date." - self.assertIn(expected_msg, cm.exception.message) - self.assertEqual(cm.exception.to_dict(), {'__all__': expected_msg}) + t = TestDocument(status="draft") + # Ensure clean=False prevent call to clean t = TestDocument(status="published") t.save(clean=False) - - self.assertEqual(t.pub_date, None) + self.assertEqual(t.status, "published") + self.assertEqual(t.cleaned, False) t = TestDocument(status="published") + self.assertEqual(t.cleaned, False) t.save(clean=True) - - self.assertEqual(type(t.pub_date), datetime) + self.assertEqual(t.status, "published") + self.assertEqual(t.cleaned, True) + raw_doc = get_as_pymongo(t) + # Make sure clean changes makes it to the db + self.assertEqual( + raw_doc, + { + 'status': 'published', + 'cleaned': True, + '_id': t.id + }) def test_document_embedded_clean(self): class TestEmbeddedDocument(EmbeddedDocument): @@ -887,19 +922,39 @@ class InstanceTest(MongoDBTestCase): person.save() # Ensure that the object is in the database - collection = self.db[self.Person._get_collection_name()] - person_obj = collection.find_one({'name': 'Test User'}) - self.assertEqual(person_obj['name'], 'Test User') - self.assertEqual(person_obj['age'], 30) - self.assertEqual(person_obj['_id'], person.id) + raw_doc = get_as_pymongo(person) + self.assertEqual( + raw_doc, + { + '_cls': 'Person', + 'name': 'Test User', + 'age': 30, + '_id': person.id + }) - # Test skipping validation on save + def test_save_skip_validation(self): class Recipient(Document): email = EmailField(required=True) recipient = Recipient(email='not-an-email') - self.assertRaises(ValidationError, recipient.save) + with self.assertRaises(ValidationError): + recipient.save() + recipient.save(validate=False) + raw_doc = get_as_pymongo(recipient) + self.assertEqual( + raw_doc, + { + 'email': 'not-an-email', + '_id': recipient.id + }) + + def test_save_with_bad_id(self): + class Clown(Document): + id = IntField(primary_key=True) + + with self.assertRaises(ValidationError): + Clown(id="not_an_int").save() def test_save_to_a_value_that_equates_to_false(self): class Thing(EmbeddedDocument): From daca0ebc1498c74a45317cd887c93e3c5910f52e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Fri, 17 May 2019 22:21:22 +0200 Subject: [PATCH 53/71] update changelog --- docs/changelog.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/changelog.rst b/docs/changelog.rst index 45de682c..dbc64e7d 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -8,6 +8,7 @@ Development - BREAKING CHANGE: Changed the custom field validator (i.e `validation` parameter of Field) so that it now requires: the callable to raise a ValidationError (i.o return True/False). - Prevent an expensive call to to_mongo in Document.save() to improve performance #? +- Improve perf of .save by avoiding a call to to_mongo in Document.save() #2049 - Fix querying on List(EmbeddedDocument) subclasses fields #1961 #1492 - Fix querying on (Generic)EmbeddedDocument subclasses fields #475 - expose `mongoengine.connection.disconnect` and `mongoengine.connection.disconnect_all` From 962997ed160eefdb309d73add39ffb3d9431f560 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Sun, 26 May 2019 22:33:04 +0200 Subject: [PATCH 54/71] fix flaky test due to signal receiver garbage collection --- tests/document/instance.py | 4 ++++ tests/test_signals.py | 3 +++ 2 files changed, 7 insertions(+) diff --git a/tests/document/instance.py b/tests/document/instance.py index 5d4cd1da..f4527ad8 100644 --- a/tests/document/instance.py +++ b/tests/document/instance.py @@ -751,6 +751,10 @@ class InstanceTest(MongoDBTestCase): '_id': post.id }) + # Important to disconnect as it could cause some assertions in test_signals + # to fail (due to the garbage collection timing of this signal) + signals.pre_save_post_validation.disconnect(BlogPost.pre_save_post_validation) + def test_document_clean(self): class TestDocument(Document): status = StringField() diff --git a/tests/test_signals.py b/tests/test_signals.py index f3b6e33c..34cb43c3 100644 --- a/tests/test_signals.py +++ b/tests/test_signals.py @@ -227,6 +227,9 @@ class SignalTests(unittest.TestCase): self.ExplicitId.objects.delete() + # Note that there is a chance that the following assert fails in case + # some receivers (eventually created in other tests) + # gets garbage collected (https://pythonhosted.org/blinker/#blinker.base.Signal.connect) self.assertEqual(self.pre_signals, post_signals) def test_model_signals(self): From 5fb0f46e3f110a0725a73bae979d74d76c24e17a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Sat, 1 Jun 2019 11:14:01 +0200 Subject: [PATCH 55/71] fix changelog (py37 not yet in travis) --- docs/changelog.rst | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index dbc64e7d..87729df3 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -4,10 +4,9 @@ Changelog Development =========== -- Add support for MongoDB 3.6 and Python3.7 in travis +- Add support for MongoDB 3.6 in travis - BREAKING CHANGE: Changed the custom field validator (i.e `validation` parameter of Field) so that it now requires: the callable to raise a ValidationError (i.o return True/False). -- Prevent an expensive call to to_mongo in Document.save() to improve performance #? - Improve perf of .save by avoiding a call to to_mongo in Document.save() #2049 - Fix querying on List(EmbeddedDocument) subclasses fields #1961 #1492 - Fix querying on (Generic)EmbeddedDocument subclasses fields #475 From 048a0459662c8b40eb01a32c1b93e116aaaf489b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Tue, 4 Jun 2019 21:47:28 +0200 Subject: [PATCH 56/71] Update connection/multiple databases docs I observed that many people were confused by this so I thought I'd make the multiple databases example more explicit --- docs/guide/connecting.rst | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/docs/guide/connecting.rst b/docs/guide/connecting.rst index 1107ee3a..aac13902 100644 --- a/docs/guide/connecting.rst +++ b/docs/guide/connecting.rst @@ -78,26 +78,30 @@ store the data and you can register all aliases up front if required. Documents defined in different database --------------------------------------- -Individual documents can also support multiple databases by providing a +Individual documents can be attached to different databases by providing a `db_alias` in their meta data. This allows :class:`~pymongo.dbref.DBRef` objects to point across databases and collections. Below is an example schema, using 3 different databases to store data:: + connect(alias='user-db-alias', db='user-db') + connect(alias='book-db-alias', db='book-db') + connect(alias='users-books-db-alias', db='users-books-db') + class User(Document): name = StringField() - meta = {'db_alias': 'user-db'} + meta = {'db_alias': 'user-db-alias'} class Book(Document): name = StringField() - meta = {'db_alias': 'book-db'} + meta = {'db_alias': 'book-db-alias'} class AuthorBooks(Document): author = ReferenceField(User) book = ReferenceField(Book) - meta = {'db_alias': 'users-books-db'} + meta = {'db_alias': 'users-books-db-alias'} Disconnecting an existing connection From 9634e44343b31e0880aeb628ac5e89dba3a4d97a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Wed, 15 May 2019 21:43:29 +0200 Subject: [PATCH 57/71] Fix the issue that the same MongoClient gets re-used in case we connect to 2 databases on the same host (problematic when different users authenticate) --- mongoengine/connection.py | 57 ++++++++++++++++++++++++++++----------- 1 file changed, 42 insertions(+), 15 deletions(-) diff --git a/mongoengine/connection.py b/mongoengine/connection.py index e0399fde..5fae9507 100644 --- a/mongoengine/connection.py +++ b/mongoengine/connection.py @@ -235,7 +235,6 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): raise MongoEngineConnectionError(msg) def _clean_settings(settings_dict): - # set literal more efficient than calling set function irrelevant_fields_set = { 'name', 'username', 'password', 'authentication_source', 'authentication_mechanism' @@ -245,10 +244,11 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): if k not in irrelevant_fields_set } + raw_conn_settings = _connection_settings[alias].copy() # Retrieve a copy of the connection settings associated with the requested # alias and remove the database name and authentication info (we don't # care about them at this point). - conn_settings = _clean_settings(_connection_settings[alias].copy()) + conn_settings = _clean_settings(raw_conn_settings) # Determine if we should use PyMongo's or mongomock's MongoClient. is_mock = conn_settings.pop('is_mock', False) @@ -262,19 +262,8 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): else: connection_class = MongoClient - # Iterate over all of the connection settings and if a connection with - # the same parameters is already established, use it instead of creating - # a new one. - existing_connection = None - connection_settings_iterator = ( - (db_alias, settings.copy()) - for db_alias, settings in _connection_settings.items() - ) - for db_alias, connection_settings in connection_settings_iterator: - connection_settings = _clean_settings(connection_settings) - if conn_settings == connection_settings and _connections.get(db_alias): - existing_connection = _connections[db_alias] - break + # Re-use existing connection if one is suitable + existing_connection = _find_existing_connection(raw_conn_settings) # If an existing connection was found, assign it to the new alias if existing_connection: @@ -291,6 +280,44 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): return _connections[alias] +def _create_connection(connection_class, **connection_settings): + # Otherwise, create the new connection for this alias. Raise + # MongoEngineConnectionError if it can't be established. + try: + _connections[alias] = connection_class(**conn_settings) + except Exception as e: + raise MongoEngineConnectionError( + 'Cannot connect to database %s :\n%s' % (alias, e)) + + +def _find_existing_connection(connection_settings): + """ + Check if an existing connection could be reused + + Iterate over all of the connection settings and if an existing connection + with the same parameters is suitable, return it + + :param connection_settings: the settings of the new connection + :return: An existing connection or None + """ + connection_settings_iterator = ( + (db_alias, settings.copy()) + for db_alias, settings in _connection_settings.items() + ) + + def _clean_settings(settings_dict): + # Only remove the name but it's important to + # keep the username/password/authentication_source/authentication_mechanism + # to identify if the connection could be shared (cfr https://github.com/MongoEngine/mongoengine/issues/2047) + return {k: v for k, v in settings_dict.items() if k != 'name'} + + cleaned_conn_settings = _clean_settings(connection_settings) + for db_alias, connection_settings in connection_settings_iterator: + db_conn_settings = _clean_settings(connection_settings) + if cleaned_conn_settings == db_conn_settings and _connections.get(db_alias): + return _connections[db_alias] + + def get_db(alias=DEFAULT_CONNECTION_NAME, reconnect=False): if reconnect: disconnect(alias) From 84c42ed58c6c7d26ec6a392048653d1993acb319 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Tue, 4 Jun 2019 22:35:42 +0200 Subject: [PATCH 58/71] Add tests --- mongoengine/connection.py | 25 ++++++++++++------------- tests/test_connection.py | 10 ++++++++++ 2 files changed, 22 insertions(+), 13 deletions(-) diff --git a/mongoengine/connection.py b/mongoengine/connection.py index 5fae9507..6249225c 100644 --- a/mongoengine/connection.py +++ b/mongoengine/connection.py @@ -245,6 +245,7 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): } raw_conn_settings = _connection_settings[alias].copy() + # Retrieve a copy of the connection settings associated with the requested # alias and remove the database name and authentication info (we don't # care about them at this point). @@ -269,22 +270,20 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): if existing_connection: _connections[alias] = existing_connection else: - # Otherwise, create the new connection for this alias. Raise - # MongoEngineConnectionError if it can't be established. - try: - _connections[alias] = connection_class(**conn_settings) - except Exception as e: - raise MongoEngineConnectionError( - 'Cannot connect to database %s :\n%s' % (alias, e)) + _create_connection(alias=alias, + connection_class=connection_class, + **conn_settings) return _connections[alias] -def _create_connection(connection_class, **connection_settings): - # Otherwise, create the new connection for this alias. Raise - # MongoEngineConnectionError if it can't be established. +def _create_connection(alias, connection_class, **connection_settings): + """ + Create the new connection for this alias. Raise + MongoEngineConnectionError if it can't be established. + """ try: - _connections[alias] = connection_class(**conn_settings) + _connections[alias] = connection_class(**connection_settings) except Exception as e: raise MongoEngineConnectionError( 'Cannot connect to database %s :\n%s' % (alias, e)) @@ -300,7 +299,7 @@ def _find_existing_connection(connection_settings): :param connection_settings: the settings of the new connection :return: An existing connection or None """ - connection_settings_iterator = ( + connection_settings_bis = ( (db_alias, settings.copy()) for db_alias, settings in _connection_settings.items() ) @@ -312,7 +311,7 @@ def _find_existing_connection(connection_settings): return {k: v for k, v in settings_dict.items() if k != 'name'} cleaned_conn_settings = _clean_settings(connection_settings) - for db_alias, connection_settings in connection_settings_iterator: + for db_alias, connection_settings in connection_settings_bis: db_conn_settings = _clean_settings(connection_settings) if cleaned_conn_settings == db_conn_settings and _connections.get(db_alias): return _connections[db_alias] diff --git a/tests/test_connection.py b/tests/test_connection.py index 5473b8a0..d3fcc395 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -611,6 +611,16 @@ class ConnectionTest(unittest.TestCase): self.assertEqual(mongo_connections['t1'].address[0], 'localhost') self.assertEqual(mongo_connections['t2'].address[0], '127.0.0.1') + def test_connect_2_databases_uses_same_client_if_only_dbname_differs(self): + c1 = connect(alias='testdb1', db='testdb1') + c2 = connect(alias='testdb2', db='testdb2') + self.assertIs(c1, c2) + + def test_connect_2_databases_uses_different_client_if_different_parameters(self): + c1 = connect(alias='testdb1', db='testdb1', username='u1') + c2 = connect(alias='testdb2', db='testdb2', username='u2') + self.assertIsNot(c1, c2) + if __name__ == '__main__': unittest.main() From 36aebffcc0f18fd6823d734cdd4001cef7474f9d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Tue, 4 Jun 2019 22:39:44 +0200 Subject: [PATCH 59/71] update changelog --- docs/changelog.rst | 1 + mongoengine/connection.py | 8 ++++---- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 87729df3..7e8fd3d2 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -13,6 +13,7 @@ Development - expose `mongoengine.connection.disconnect` and `mongoengine.connection.disconnect_all` - Fix disconnect function #566 #1599 #605 #607 #1213 #565 - Improve connect/disconnect documentations +- Fix issue when using multiple connections to the same mongo with different credentials #2047 - POTENTIAL BREAKING CHANGES: (associated with connect/disconnect fixes) - calling `connect` 2 times with the same alias and different parameter will raise an error (should call disconnect first) - disconnect now clears `mongoengine.connection._connection_settings` diff --git a/mongoengine/connection.py b/mongoengine/connection.py index 6249225c..9d4f25fc 100644 --- a/mongoengine/connection.py +++ b/mongoengine/connection.py @@ -270,9 +270,9 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): if existing_connection: _connections[alias] = existing_connection else: - _create_connection(alias=alias, - connection_class=connection_class, - **conn_settings) + _connections[alias] = _create_connection(alias=alias, + connection_class=connection_class, + **conn_settings) return _connections[alias] @@ -283,7 +283,7 @@ def _create_connection(alias, connection_class, **connection_settings): MongoEngineConnectionError if it can't be established. """ try: - _connections[alias] = connection_class(**connection_settings) + return connection_class(**connection_settings) except Exception as e: raise MongoEngineConnectionError( 'Cannot connect to database %s :\n%s' % (alias, e)) From 5bf1dd55b143f6dfb504d774bf345e1e7b7f0458 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Tue, 4 Jun 2019 22:56:52 +0200 Subject: [PATCH 60/71] Update mongomock example Improved the mongomock example as reported in #2067 Fixes #2067 --- docs/guide/mongomock.rst | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/docs/guide/mongomock.rst b/docs/guide/mongomock.rst index 1d5227ec..d70ee6a6 100644 --- a/docs/guide/mongomock.rst +++ b/docs/guide/mongomock.rst @@ -19,3 +19,30 @@ or with an alias: connect('mongoenginetest', host='mongomock://localhost', alias='testdb') conn = get_connection('testdb') + +Example of test file: +-------- +.. code-block:: python + + import unittest + from mongoengine import connect, disconnect + + class Person(Document): + name = StringField() + + class TestPerson(unittest.TestCase): + + @classmethod + def setUpClass(cls): + connect('mongoenginetest', host='mongomock://localhost') + + @classmethod + def tearDownClass(cls): + disconnect() + + def test_thing(self): + pers = Person(name='John') + pers.save() + + fresh_pers = Person.objects().first() + self.assertEqual(fresh_pers.name, 'John') From 7ed5829b2ce49fb4ef597b0535abc9addda74032 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Sat, 2 Mar 2019 23:17:24 +0100 Subject: [PATCH 61/71] Add test on datetime field - parse datetime as str --- tests/fields/test_datetime_field.py | 68 +++++++++++++++++++---------- 1 file changed, 44 insertions(+), 24 deletions(-) diff --git a/tests/fields/test_datetime_field.py b/tests/fields/test_datetime_field.py index 5af6a011..83be207f 100644 --- a/tests/fields/test_datetime_field.py +++ b/tests/fields/test_datetime_field.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -import datetime +import datetime as dt import six try: @@ -41,13 +41,13 @@ class TestDateTimeField(MongoDBTestCase): a document. """ class Person(Document): - created = DateTimeField(default=datetime.datetime.utcnow) + created = DateTimeField(default=dt.datetime.utcnow) - utcnow = datetime.datetime.utcnow() + utcnow = dt.datetime.utcnow() person = Person() person.validate() person_created_t0 = person.created - self.assertLess(person.created - utcnow, datetime.timedelta(seconds=1)) + self.assertLess(person.created - utcnow, dt.timedelta(seconds=1)) self.assertEqual(person_created_t0, person.created) # make sure it does not change self.assertEqual(person._data['created'], person.created) @@ -65,15 +65,15 @@ class TestDateTimeField(MongoDBTestCase): # Test can save dates log = LogEntry() - log.date = datetime.date.today() + log.date = dt.date.today() log.save() log.reload() - self.assertEqual(log.date.date(), datetime.date.today()) + self.assertEqual(log.date.date(), dt.date.today()) # Post UTC - microseconds are rounded (down) nearest millisecond and # dropped - d1 = datetime.datetime(1970, 1, 1, 0, 0, 1, 999) - d2 = datetime.datetime(1970, 1, 1, 0, 0, 1) + d1 = dt.datetime(1970, 1, 1, 0, 0, 1, 999) + d2 = dt.datetime(1970, 1, 1, 0, 0, 1) log = LogEntry() log.date = d1 log.save() @@ -82,8 +82,8 @@ class TestDateTimeField(MongoDBTestCase): self.assertEqual(log.date, d2) # Post UTC - microseconds are rounded (down) nearest millisecond - d1 = datetime.datetime(1970, 1, 1, 0, 0, 1, 9999) - d2 = datetime.datetime(1970, 1, 1, 0, 0, 1, 9000) + d1 = dt.datetime(1970, 1, 1, 0, 0, 1, 9999) + d2 = dt.datetime(1970, 1, 1, 0, 0, 1, 9000) log.date = d1 log.save() log.reload() @@ -93,8 +93,8 @@ class TestDateTimeField(MongoDBTestCase): if not six.PY3: # Pre UTC dates microseconds below 1000 are dropped # This does not seem to be true in PY3 - d1 = datetime.datetime(1969, 12, 31, 23, 59, 59, 999) - d2 = datetime.datetime(1969, 12, 31, 23, 59, 59) + d1 = dt.datetime(1969, 12, 31, 23, 59, 59, 999) + d2 = dt.datetime(1969, 12, 31, 23, 59, 59) log.date = d1 log.save() log.reload() @@ -108,7 +108,7 @@ class TestDateTimeField(MongoDBTestCase): LogEntry.drop_collection() - d1 = datetime.datetime(1970, 1, 1, 0, 0, 1) + d1 = dt.datetime(1970, 1, 1, 0, 0, 1) log = LogEntry() log.date = d1 log.validate() @@ -124,7 +124,7 @@ class TestDateTimeField(MongoDBTestCase): # create additional 19 log entries for a total of 20 for i in range(1971, 1990): - d = datetime.datetime(i, 1, 1, 0, 0, 1) + d = dt.datetime(i, 1, 1, 0, 0, 1) LogEntry(date=d).save() self.assertEqual(LogEntry.objects.count(), 20) @@ -143,15 +143,15 @@ class TestDateTimeField(MongoDBTestCase): i += 1 # Test searching - logs = LogEntry.objects.filter(date__gte=datetime.datetime(1980, 1, 1)) + logs = LogEntry.objects.filter(date__gte=dt.datetime(1980, 1, 1)) self.assertEqual(logs.count(), 10) - logs = LogEntry.objects.filter(date__lte=datetime.datetime(1980, 1, 1)) + logs = LogEntry.objects.filter(date__lte=dt.datetime(1980, 1, 1)) self.assertEqual(logs.count(), 10) logs = LogEntry.objects.filter( - date__lte=datetime.datetime(1980, 1, 1), - date__gte=datetime.datetime(1975, 1, 1), + date__lte=dt.datetime(1980, 1, 1), + date__gte=dt.datetime(1975, 1, 1), ) self.assertEqual(logs.count(), 5) @@ -163,20 +163,20 @@ class TestDateTimeField(MongoDBTestCase): time = DateTimeField() log = LogEntry() - log.time = datetime.datetime.now() + log.time = dt.datetime.now() log.validate() - log.time = datetime.date.today() + log.time = dt.date.today() log.validate() - log.time = datetime.datetime.now().isoformat(' ') + log.time = dt.datetime.now().isoformat(' ') log.validate() log.time = '2019-05-16 21:42:57.897847' log.validate() if dateutil: - log.time = datetime.datetime.now().isoformat('T') + log.time = dt.datetime.now().isoformat('T') log.validate() log.time = -1 @@ -190,6 +190,26 @@ class TestDateTimeField(MongoDBTestCase): log.time = '2019-05-16 21:42:57.123.456' self.assertRaises(ValidationError, log.validate) + def test_parse_valid_datetime_str(self): + class DTDoc(Document): + date = DateTimeField() + + # make sure that passing a parsable datetime works + dtd = DTDoc() + now = dt.datetime.utcnow() + dtd.date = str(now) + self.assertIsInstance(dtd.date, six.string_types) + dtd.save() + dtd.reload() + + self.assertIsInstance(dtd.date, dt.datetime) + + self.assertNotEqual(dtd.date, now) # microseconds differ as its not stored in mongo + self.assertEqual( + dtd.date.strftime('%Y-%m-%d %H:%M:%S'), + now.strftime('%Y-%m-%d %H:%M:%S') + ) + class TestDateTimeTzAware(MongoDBTestCase): def test_datetime_tz_aware_mark_as_changed(self): @@ -205,8 +225,8 @@ class TestDateTimeTzAware(MongoDBTestCase): LogEntry.drop_collection() - LogEntry(time=datetime.datetime(2013, 1, 1, 0, 0, 0)).save() + LogEntry(time=dt.datetime(2013, 1, 1, 0, 0, 0)).save() log = LogEntry.objects.first() - log.time = datetime.datetime(2013, 1, 1, 0, 0, 0) + log.time = dt.datetime(2013, 1, 1, 0, 0, 0) self.assertEqual(['time'], log._changed_fields) From 27ea01ee05de4b5da8336fe8b8983ff591502670 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Sat, 2 Mar 2019 23:44:33 +0100 Subject: [PATCH 62/71] refactored datetime to_mongo, separating parsing from str + added test --- mongoengine/fields.py | 7 +++++-- tests/fields/test_datetime_field.py | 15 +++++++-------- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 1cd6be11..9650403c 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -497,15 +497,18 @@ class DateTimeField(BaseField): if not isinstance(value, six.string_types): return None + return self._parse_datetime(value) + + def _parse_datetime(self, value): + # Attempt to parse a datetime from a string value = value.strip() if not value: return None - # Attempt to parse a datetime: if dateutil: try: return dateutil.parser.parse(value) - except (TypeError, ValueError): + except (TypeError, ValueError, OverflowError): return None # split usecs, because they are not recognized by strptime. diff --git a/tests/fields/test_datetime_field.py b/tests/fields/test_datetime_field.py index 83be207f..92f0668a 100644 --- a/tests/fields/test_datetime_field.py +++ b/tests/fields/test_datetime_field.py @@ -190,25 +190,24 @@ class TestDateTimeField(MongoDBTestCase): log.time = '2019-05-16 21:42:57.123.456' self.assertRaises(ValidationError, log.validate) - def test_parse_valid_datetime_str(self): + def test_parse_datetime_as_str(self): class DTDoc(Document): date = DateTimeField() + date_str = '2019-03-02 22:26:01' + # make sure that passing a parsable datetime works dtd = DTDoc() - now = dt.datetime.utcnow() - dtd.date = str(now) + dtd.date = date_str self.assertIsInstance(dtd.date, six.string_types) dtd.save() dtd.reload() self.assertIsInstance(dtd.date, dt.datetime) + self.assertEqual(str(dtd.date), date_str) - self.assertNotEqual(dtd.date, now) # microseconds differ as its not stored in mongo - self.assertEqual( - dtd.date.strftime('%Y-%m-%d %H:%M:%S'), - now.strftime('%Y-%m-%d %H:%M:%S') - ) + dtd.date = 'January 1st, 9999999999' + self.assertRaises(ValidationError, dtd.validate) class TestDateTimeTzAware(MongoDBTestCase): From b407c0e6c6a812b80ba3882bea0d9dd1ecac4ad4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Mon, 4 Mar 2019 23:42:45 +0100 Subject: [PATCH 63/71] add test for shard key routing (ported from https://github.com/closeio/mongoengine/commit/43f35f5) --- tests/document/instance.py | 31 ++++++++++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/tests/document/instance.py b/tests/document/instance.py index f4527ad8..04bddea1 100644 --- a/tests/document/instance.py +++ b/tests/document/instance.py @@ -466,7 +466,13 @@ class InstanceTest(MongoDBTestCase): Animal.drop_collection() doc = Animal(superphylum='Deuterostomia') doc.save() - doc.reload() + + with query_counter() as q: + doc.reload() + query_op = q.db.system.profile.find({'ns': 'mongoenginetest.animal'})[0] + self.assertEqual(set(query_op['query']['filter'].keys()), set(['_id', 'superphylum'])) + + Animal.drop_collection() def test_reload_sharded_nested(self): class SuperPhylum(EmbeddedDocument): @@ -480,6 +486,29 @@ class InstanceTest(MongoDBTestCase): doc = Animal(superphylum=SuperPhylum(name='Deuterostomia')) doc.save() doc.reload() + Animal.drop_collection() + + def test_update_shard_key_routing(self): + """Ensures updating a doc with a specified shard_key includes it in + the query. + """ + class Animal(Document): + is_mammal = BooleanField() + name = StringField() + meta = {'shard_key': ('is_mammal', 'id')} + + Animal.drop_collection() + doc = Animal(is_mammal=True, name='Dog') + doc.save() + + with query_counter() as q: + doc.name = 'Cat' + doc.save() + query_op = q.db.system.profile.find({'ns': 'mongoenginetest.animal'})[0] + self.assertEqual(query_op['op'], 'update') + self.assertEqual(set(query_op['query'].keys()), set(['_id', 'is_mammal'])) + + Animal.drop_collection() def test_reload_with_changed_fields(self): """Ensures reloading will not affect changed fields""" From 82e28dec43e64e6c7995a3ca4b9eb2f91da675d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Fri, 8 Mar 2019 17:09:39 +0100 Subject: [PATCH 64/71] improved string operation code --- mongoengine/fields.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 9650403c..ba508a70 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -13,6 +13,8 @@ import pymongo import six from six import iteritems +from mongoengine.queryset.transform import STRING_OPERATORS + try: import dateutil except ImportError: @@ -106,11 +108,11 @@ class StringField(BaseField): if not isinstance(op, six.string_types): return value - if op.lstrip('i') in ('startswith', 'endswith', 'contains', 'exact'): - flags = 0 - if op.startswith('i'): - flags = re.IGNORECASE - op = op.lstrip('i') + if op in STRING_OPERATORS: + case_insensitive = op.startswith('i') + op = op.lstrip('i') + + flags = re.IGNORECASE if case_insensitive else 0 regex = r'%s' if op == 'startswith': From 15f4d4fee6c2239eabb30ad2c10d44f4975f9f39 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Tue, 4 Jun 2019 23:35:38 +0200 Subject: [PATCH 65/71] fix tests for diff mongo vers --- mongoengine/fields.py | 3 +-- mongoengine/mongodb_support.py | 1 + tests/document/instance.py | 13 +++++++++++-- 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/mongoengine/fields.py b/mongoengine/fields.py index ba508a70..aa5aa805 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -13,8 +13,6 @@ import pymongo import six from six import iteritems -from mongoengine.queryset.transform import STRING_OPERATORS - try: import dateutil except ImportError: @@ -39,6 +37,7 @@ from mongoengine.errors import DoesNotExist, InvalidQueryError, ValidationError from mongoengine.python_support import StringIO from mongoengine.queryset import DO_NOTHING from mongoengine.queryset.base import BaseQuerySet +from mongoengine.queryset.transform import STRING_OPERATORS try: from PIL import Image, ImageOps diff --git a/mongoengine/mongodb_support.py b/mongoengine/mongodb_support.py index 8234a616..b20ebc1e 100644 --- a/mongoengine/mongodb_support.py +++ b/mongoengine/mongodb_support.py @@ -6,6 +6,7 @@ from mongoengine.connection import get_connection # Constant that can be used to compare the version retrieved with # get_mongodb_version() +MONGODB_34 = (3, 4) MONGODB_36 = (3, 6) diff --git a/tests/document/instance.py b/tests/document/instance.py index 04bddea1..0f2f0c0f 100644 --- a/tests/document/instance.py +++ b/tests/document/instance.py @@ -12,6 +12,7 @@ from bson import DBRef, ObjectId from pymongo.errors import DuplicateKeyError from six import iteritems +from mongoengine.mongodb_support import get_mongodb_version, MONGODB_36, MONGODB_34 from mongoengine.pymongo_support import list_collection_names from tests import fixtures from tests.fixtures import (PickleEmbedded, PickleTest, PickleSignalsTest, @@ -467,10 +468,13 @@ class InstanceTest(MongoDBTestCase): doc = Animal(superphylum='Deuterostomia') doc.save() + mongo_db = get_mongodb_version() + CMD_QUERY_KEY = 'command' if mongo_db >= MONGODB_36 else 'query' + with query_counter() as q: doc.reload() query_op = q.db.system.profile.find({'ns': 'mongoenginetest.animal'})[0] - self.assertEqual(set(query_op['query']['filter'].keys()), set(['_id', 'superphylum'])) + self.assertEqual(set(query_op[CMD_QUERY_KEY]['filter'].keys()), set(['_id', 'superphylum'])) Animal.drop_collection() @@ -501,12 +505,17 @@ class InstanceTest(MongoDBTestCase): doc = Animal(is_mammal=True, name='Dog') doc.save() + mongo_db = get_mongodb_version() + with query_counter() as q: doc.name = 'Cat' doc.save() query_op = q.db.system.profile.find({'ns': 'mongoenginetest.animal'})[0] self.assertEqual(query_op['op'], 'update') - self.assertEqual(set(query_op['query'].keys()), set(['_id', 'is_mammal'])) + if mongo_db == MONGODB_34: + self.assertEqual(set(query_op['query'].keys()), set(['_id', 'is_mammal'])) + else: + self.assertEqual(set(query_op['command']['q'].keys()), set(['_id', 'is_mammal'])) Animal.drop_collection() From 70d6e763b00ea2240cb9e5d4917c38af2bc78c10 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Wed, 5 Jun 2019 22:06:37 +0200 Subject: [PATCH 66/71] Document the custom field validation feature --- docs/guide/defining-documents.rst | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/docs/guide/defining-documents.rst b/docs/guide/defining-documents.rst index 911de36d..ae9d3b36 100644 --- a/docs/guide/defining-documents.rst +++ b/docs/guide/defining-documents.rst @@ -176,6 +176,21 @@ arguments can be set on all fields: class Shirt(Document): size = StringField(max_length=3, choices=SIZE) +:attr:`validation` (Optional) + A callable to validate the value of the field. + The callable takes the value as parameter and should raise a ValidationError + if validation fails + + e.g :: + + def _not_empty(val): + if not val: + raise ValidationError('value can not be empty') + + class Person(Document): + name = StringField(validation=_not_empty) + + :attr:`**kwargs` (Optional) You can supply additional metadata as arbitrary additional keyword arguments. You can not override existing attributes, however. Common From 9499c97e1865355184f9bfb26c19c867040f1edd Mon Sep 17 00:00:00 2001 From: Stefan Wojcik Date: Fri, 7 Jun 2019 12:15:35 +0200 Subject: [PATCH 67/71] Clean up the .install_mongodb_on_travis.sh script MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This is a leftover from #2066. Since we no longer install MongoDB versions v2.6 – v3.2, we no longer need this code. --- .install_mongodb_on_travis.sh | 18 +----------------- 1 file changed, 1 insertion(+), 17 deletions(-) diff --git a/.install_mongodb_on_travis.sh b/.install_mongodb_on_travis.sh index f1073333..0be02655 100644 --- a/.install_mongodb_on_travis.sh +++ b/.install_mongodb_on_travis.sh @@ -3,23 +3,7 @@ sudo apt-get remove mongodb-org-server sudo apt-key adv --keyserver hkp://keyserver.ubuntu.com:80 --recv 7F0CEB10 -if [ "$MONGODB" = "2.6" ]; then - echo "deb http://downloads-distro.mongodb.org/repo/ubuntu-upstart dist 10gen" | sudo tee /etc/apt/sources.list.d/mongodb.list - sudo apt-get update - sudo apt-get install mongodb-org-server=2.6.12 - # service should be started automatically -elif [ "$MONGODB" = "3.0" ]; then - echo "deb http://repo.mongodb.org/apt/ubuntu trusty/mongodb-org/3.0 multiverse" | sudo tee /etc/apt/sources.list.d/mongodb.list - sudo apt-get update - sudo apt-get install mongodb-org-server=3.0.14 - # service should be started automatically -elif [ "$MONGODB" = "3.2" ]; then - sudo apt-key adv --keyserver keyserver.ubuntu.com --recv EA312927 - echo "deb http://repo.mongodb.org/apt/ubuntu trusty/mongodb-org/3.2 multiverse" | sudo tee /etc/apt/sources.list.d/mongodb-org-3.2.list - sudo apt-get update - sudo apt-get install mongodb-org-server=3.2.20 - # service should be started automatically -elif [ "$MONGODB" = "3.4" ]; then +if [ "$MONGODB" = "3.4" ]; then sudo apt-key adv --keyserver keyserver.ubuntu.com:80 --recv 0C49F3730359A14518585931BC711F9BA15703C6 echo "deb http://repo.mongodb.org/apt/ubuntu trusty/mongodb-org/3.4 multiverse" | sudo tee /etc/apt/sources.list.d/mongodb-org-3.4.list sudo apt-get update From f996f3df74e0f8ca0d522e725b03eef588389518 Mon Sep 17 00:00:00 2001 From: Stefan Wojcik Date: Fri, 7 Jun 2019 12:34:32 +0200 Subject: [PATCH 68/71] Cleaner test_hint --- tests/document/indexes.py | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/tests/document/indexes.py b/tests/document/indexes.py index 6f486e9f..3bdd1a66 100644 --- a/tests/document/indexes.py +++ b/tests/document/indexes.py @@ -544,9 +544,8 @@ class IndexesTest(unittest.TestCase): [('categories', 1), ('_id', 1)]) def test_hint(self): - MONGO_VER = self.mongodb_version - TAGS_INDEX_NAME = 'tags_1' + class BlogPost(Document): tags = ListField(StringField()) meta = { @@ -564,18 +563,27 @@ class IndexesTest(unittest.TestCase): tags = [("tag %i" % n) for n in range(i % 2)] BlogPost(tags=tags).save() - self.assertEqual(BlogPost.objects.count(), 10) - self.assertEqual(BlogPost.objects.hint().count(), 10) + # Hinting by shape should work. + self.assertEqual(BlogPost.objects.hint([('tags', 1)]).count(), 10) - # MongoDB v3.2+ throws an error if an index exists (i.e `tags` in our - # case) and you use hint on an index name that does not exist. + # Hinting by index name should work. + self.assertEqual(BlogPost.objects.hint(TAGS_INDEX_NAME).count(), 10) + + # Clearing the hint should work fine. + self.assertEqual(BlogPost.objects.hint().count(), 10) + self.assertEqual(BlogPost.objects.hint([('ZZ', 1)]).hint().count(), 10) + + # Hinting on a non-existent index shape should fail. with self.assertRaises(OperationFailure): BlogPost.objects.hint([('ZZ', 1)]).count() - self.assertEqual(BlogPost.objects.hint(TAGS_INDEX_NAME).count(), 10) + # Hinting on a non-existent index name should fail. + with self.assertRaises(OperationFailure): + BlogPost.objects.hint('Bad Name').count() - with self.assertRaises(Exception): - BlogPost.objects.hint(('tags', 1)).next() + # Invalid shape argument (missing list brackets) should fail. + with self.assertRaises(ValueError): + BlogPost.objects.hint(('tags', 1)).count() def test_unique(self): """Ensure that uniqueness constraints are applied to fields. From 8e8c74c621a31870b521ed0216ae91d1e737b971 Mon Sep 17 00:00:00 2001 From: Stefan Wojcik Date: Fri, 7 Jun 2019 12:35:38 +0200 Subject: [PATCH 69/71] Drop the unused mongodb_version attribute in IndexesTest --- tests/document/indexes.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/document/indexes.py b/tests/document/indexes.py index 3bdd1a66..764ef0c5 100644 --- a/tests/document/indexes.py +++ b/tests/document/indexes.py @@ -9,7 +9,6 @@ from six import iteritems from mongoengine import * from mongoengine.connection import get_db -from mongoengine.mongodb_support import get_mongodb_version __all__ = ("IndexesTest", ) @@ -19,7 +18,6 @@ class IndexesTest(unittest.TestCase): def setUp(self): self.connection = connect(db='mongoenginetest') self.db = get_db() - self.mongodb_version = get_mongodb_version() class Person(Document): name = StringField() From 31d99c0bd2db09c9db8df079e8f41a6febc8445b Mon Sep 17 00:00:00 2001 From: Stefan Wojcik Date: Mon, 10 Jun 2019 11:26:47 +0200 Subject: [PATCH 70/71] Cleaner wording in the dev changelog --- docs/changelog.rst | 38 +++++++++++++++++++------------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 7e8fd3d2..9311b8c1 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -4,25 +4,25 @@ Changelog Development =========== -- Add support for MongoDB 3.6 in travis -- BREAKING CHANGE: Changed the custom field validator (i.e `validation` parameter of Field) so that it now requires: - the callable to raise a ValidationError (i.o return True/False). -- Improve perf of .save by avoiding a call to to_mongo in Document.save() #2049 -- Fix querying on List(EmbeddedDocument) subclasses fields #1961 #1492 -- Fix querying on (Generic)EmbeddedDocument subclasses fields #475 -- expose `mongoengine.connection.disconnect` and `mongoengine.connection.disconnect_all` -- Fix disconnect function #566 #1599 #605 #607 #1213 #565 -- Improve connect/disconnect documentations -- Fix issue when using multiple connections to the same mongo with different credentials #2047 -- POTENTIAL BREAKING CHANGES: (associated with connect/disconnect fixes) - - calling `connect` 2 times with the same alias and different parameter will raise an error (should call disconnect first) - - disconnect now clears `mongoengine.connection._connection_settings` - - disconnect now clears the cached attribute `Document._collection` -- POTENTIAL BREAKING CHANGE: Aggregate gives wrong results when used with a queryset having limit and skip #2029 -- Fix the default write concern of .save that was overwriting the connection write concern #568 -- mongoengine now requires pymongo>=3.5 #2017 -- Generate Unique Indices for SortedListField and EmbeddedDocumentListFields #2020 -- connect() fails immediately when db name contains invalid characters #2031 #1718 +- Drop support for EOL'd MongoDB v2.6, v3.0, and v3.2. +- MongoEngine now requires PyMongo >= v3.4. Travis CI now tests against MongoDB v3.4 – v3.6 and PyMongo v3.4 – v3.6 (#2017 #2066). +- Improve performance by avoiding a call to `to_mongo` in `Document.save()` #2049 +- Connection/disconnection improvements: + - Expose `mongoengine.connection.disconnect` and `mongoengine.connection.disconnect_all` + - Fix disconnecting #566 #1599 #605 #607 #1213 #565 + - Improve documentation of `connect`/`disconnect` + - Fix issue when using multiple connections to the same mongo with different credentials #2047 + - `connect` fails immediately when db name contains invalid characters #2031 #1718 +- Fix the default write concern of `Document.save` that was overwriting the connection write concern #568 +- Fix querying on `List(EmbeddedDocument)` subclasses fields #1961 #1492 +- Fix querying on `(Generic)EmbeddedDocument` subclasses fields #475 +- Generate unique indices for `SortedListField` and `EmbeddedDocumentListFields` #2020 +- BREAKING CHANGE: Changed the behavior of a custom field validator (i.e `validation` parameter of a `Field`). It is now expected to raise a `ValidationError` instead of returning True/False #2050 +- BREAKING CHANGE: `QuerySet.aggreagte` now takes limit and skip value into account #2029 +- BREAKING CHANGES (associated with connect/disconnect fixes): + - Calling `connect` 2 times with the same alias and different parameter will raise an error (should call `disconnect` first). + - `disconnect` now clears `mongoengine.connection._connection_settings`. + - `disconnect` now clears the cached attribute `Document._collection`. - (Fill this out as you fix issues and develop your features). Changes in 0.17.0 From 1fc5b954f2095ddc116be08f392deb3a6801446a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Mon, 10 Jun 2019 22:38:37 +0200 Subject: [PATCH 71/71] fix typo in changelog --- docs/changelog.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 9311b8c1..9ef90afe 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -18,7 +18,7 @@ Development - Fix querying on `(Generic)EmbeddedDocument` subclasses fields #475 - Generate unique indices for `SortedListField` and `EmbeddedDocumentListFields` #2020 - BREAKING CHANGE: Changed the behavior of a custom field validator (i.e `validation` parameter of a `Field`). It is now expected to raise a `ValidationError` instead of returning True/False #2050 -- BREAKING CHANGE: `QuerySet.aggreagte` now takes limit and skip value into account #2029 +- BREAKING CHANGE: `QuerySet.aggregate` now takes limit and skip value into account #2029 - BREAKING CHANGES (associated with connect/disconnect fixes): - Calling `connect` 2 times with the same alias and different parameter will raise an error (should call `disconnect` first). - `disconnect` now clears `mongoengine.connection._connection_settings`.