From 282b83ac08a3b4cf680b6f81627187b7a08a2889 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Tue, 4 Sep 2018 23:48:07 +0200 Subject: [PATCH] Fix default value of ComplexDateTime + fixed descriptor .__get__ for class attribute --- mongoengine/base/fields.py | 3 +- mongoengine/fields.py | 20 ++- tests/fields/fields.py | 327 +++++++++++++++++++++---------------- 3 files changed, 199 insertions(+), 151 deletions(-) diff --git a/mongoengine/base/fields.py b/mongoengine/base/fields.py index 69034d5d..d25d4305 100644 --- a/mongoengine/base/fields.py +++ b/mongoengine/base/fields.py @@ -55,7 +55,7 @@ class BaseField(object): field. Generally this is deprecated in favour of the `FIELD.validate` method :param choices: (optional) The valid choices - :param null: (optional) Is the field value can be null. If no and there is a default value + :param null: (optional) If the field value can be null. If no and there is a default value then the default value is set :param sparse: (optional) `sparse=True` combined with `unique=True` and `required=False` means that uniqueness won't be enforced for `None` values @@ -130,7 +130,6 @@ class BaseField(object): def __set__(self, instance, value): """Descriptor for assigning a value to a field in a document. """ - # If setting to None and there is a default # Then set the value to the default value if value is None: diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 6994b315..bc632ac9 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -562,11 +562,15 @@ class ComplexDateTimeField(StringField): The `,` as the separator can be easily modified by passing the `separator` keyword when initializing the field. + Note: To default the field to the current datetime, use: DateTimeField(default=datetime.utcnow) + .. versionadded:: 0.5 """ def __init__(self, separator=',', **kwargs): - self.names = ['year', 'month', 'day', 'hour', 'minute', 'second', 'microsecond'] + """ + :param separator: Allows to customize the separator used for storage (default ``,``) + """ self.separator = separator self.format = separator.join(['%Y', '%m', '%d', '%H', '%M', '%S', '%f']) super(ComplexDateTimeField, self).__init__(**kwargs) @@ -597,16 +601,20 @@ class ComplexDateTimeField(StringField): return datetime.datetime(*values) def __get__(self, instance, owner): + if instance is None: + return self + data = super(ComplexDateTimeField, self).__get__(instance, owner) - if data is None: - return None if self.null else datetime.datetime.now() - if isinstance(data, datetime.datetime): + + if isinstance(data, datetime.datetime) or data is None: return data return self._convert_from_string(data) def __set__(self, instance, value): - value = self._convert_from_datetime(value) if value else value - return super(ComplexDateTimeField, self).__set__(instance, value) + super(ComplexDateTimeField, self).__set__(instance, value) + value = instance._data[self.name] + if value is not None: + instance._data[self.name] = self._convert_from_datetime(value) def validate(self, value): value = self.to_python(value) diff --git a/tests/fields/fields.py b/tests/fields/fields.py index 362acec4..4bfd7c74 100644 --- a/tests/fields/fields.py +++ b/tests/fields/fields.py @@ -264,12 +264,11 @@ class FieldTest(MongoDBTestCase): # Retrive data from db and verify it. ret = HandleNoneFields.objects.all()[0] - self.assertEqual(ret.str_fld, None) - self.assertEqual(ret.int_fld, None) - self.assertEqual(ret.flt_fld, None) + self.assertIsNone(ret.str_fld) + self.assertIsNone(ret.int_fld) + self.assertIsNone(ret.flt_fld) - # Return current time if retrived value is None. - self.assertTrue(isinstance(ret.comp_dt_fld, datetime.datetime)) + self.assertIsNone(ret.comp_dt_fld) def test_not_required_handles_none_from_database(self): """Ensure that every field can handle null values from the @@ -287,7 +286,7 @@ class FieldTest(MongoDBTestCase): doc.str_fld = u'spam ham egg' doc.int_fld = 42 doc.flt_fld = 4.2 - doc.com_dt_fld = datetime.datetime.utcnow() + doc.comp_dt_fld = datetime.datetime.utcnow() doc.save() # Unset all the fields @@ -302,12 +301,10 @@ class FieldTest(MongoDBTestCase): # Retrive data from db and verify it. ret = HandleNoneFields.objects.first() - self.assertEqual(ret.str_fld, None) - self.assertEqual(ret.int_fld, None) - self.assertEqual(ret.flt_fld, None) - - # ComplexDateTimeField returns current time if retrived value is None. - self.assertTrue(isinstance(ret.comp_dt_fld, datetime.datetime)) + self.assertIsNone(ret.str_fld) + self.assertIsNone(ret.int_fld) + self.assertIsNone(ret.flt_fld) + self.assertIsNone(ret.comp_dt_fld) # Retrieved object shouldn't pass validation when a re-save is # attempted. @@ -928,137 +925,6 @@ class FieldTest(MongoDBTestCase): logs = LogEntry.objects.filter(date__gte=datetime.datetime(1980, 1, 1)) self.assertEqual(logs.count(), 10) - def test_complexdatetime_storage(self): - """Tests for complex datetime fields - which can handle - microseconds without rounding. - """ - class LogEntry(Document): - date = ComplexDateTimeField() - date_with_dots = ComplexDateTimeField(separator='.') - - LogEntry.drop_collection() - - # Post UTC - microseconds are rounded (down) nearest millisecond and - # dropped - with default datetimefields - d1 = datetime.datetime(1970, 1, 1, 0, 0, 1, 999) - log = LogEntry() - log.date = d1 - log.save() - log.reload() - self.assertEqual(log.date, d1) - - # Post UTC - microseconds are rounded (down) nearest millisecond - with - # default datetimefields - d1 = datetime.datetime(1970, 1, 1, 0, 0, 1, 9999) - log.date = d1 - log.save() - log.reload() - self.assertEqual(log.date, d1) - - # Pre UTC dates microseconds below 1000 are dropped - with default - # datetimefields - d1 = datetime.datetime(1969, 12, 31, 23, 59, 59, 999) - log.date = d1 - log.save() - log.reload() - self.assertEqual(log.date, d1) - - # Pre UTC microseconds above 1000 is wonky - with default datetimefields - # log.date has an invalid microsecond value so I can't construct - # a date to compare. - for i in range(1001, 3113, 33): - d1 = datetime.datetime(1969, 12, 31, 23, 59, 59, i) - log.date = d1 - log.save() - log.reload() - self.assertEqual(log.date, d1) - log1 = LogEntry.objects.get(date=d1) - self.assertEqual(log, log1) - - # Test string padding - microsecond = map(int, [math.pow(10, x) for x in range(6)]) - mm = dd = hh = ii = ss = [1, 10] - - for values in itertools.product([2014], mm, dd, hh, ii, ss, microsecond): - stored = LogEntry(date=datetime.datetime(*values)).to_mongo()['date'] - self.assertTrue(re.match('^\d{4},\d{2},\d{2},\d{2},\d{2},\d{2},\d{6}$', stored) is not None) - - # Test separator - stored = LogEntry(date_with_dots=datetime.datetime(2014, 1, 1)).to_mongo()['date_with_dots'] - self.assertTrue(re.match('^\d{4}.\d{2}.\d{2}.\d{2}.\d{2}.\d{2}.\d{6}$', stored) is not None) - - def test_complexdatetime_usage(self): - """Tests for complex datetime fields - which can handle - microseconds without rounding. - """ - class LogEntry(Document): - date = ComplexDateTimeField() - - LogEntry.drop_collection() - - d1 = datetime.datetime(1950, 1, 1, 0, 0, 1, 999) - log = LogEntry() - log.date = d1 - log.save() - - log1 = LogEntry.objects.get(date=d1) - self.assertEqual(log, log1) - - # create extra 59 log entries for a total of 60 - for i in range(1951, 2010): - d = datetime.datetime(i, 1, 1, 0, 0, 1, 999) - LogEntry(date=d).save() - - self.assertEqual(LogEntry.objects.count(), 60) - - # Test ordering - logs = LogEntry.objects.order_by("date") - i = 0 - while i < 59: - self.assertTrue(logs[i].date <= logs[i + 1].date) - i += 1 - - logs = LogEntry.objects.order_by("-date") - i = 0 - while i < 59: - self.assertTrue(logs[i].date >= logs[i + 1].date) - i += 1 - - # Test searching - logs = LogEntry.objects.filter(date__gte=datetime.datetime(1980, 1, 1)) - self.assertEqual(logs.count(), 30) - - logs = LogEntry.objects.filter(date__lte=datetime.datetime(1980, 1, 1)) - self.assertEqual(logs.count(), 30) - - logs = LogEntry.objects.filter( - date__lte=datetime.datetime(2011, 1, 1), - date__gte=datetime.datetime(2000, 1, 1), - ) - self.assertEqual(logs.count(), 10) - - LogEntry.drop_collection() - - # Test microsecond-level ordering/filtering - for microsecond in (99, 999, 9999, 10000): - LogEntry( - date=datetime.datetime(2015, 1, 1, 0, 0, 0, microsecond) - ).save() - - logs = list(LogEntry.objects.order_by('date')) - for next_idx, log in enumerate(logs[:-1], start=1): - next_log = logs[next_idx] - self.assertTrue(log.date < next_log.date) - - logs = list(LogEntry.objects.order_by('-date')) - for next_idx, log in enumerate(logs[:-1], start=1): - next_log = logs[next_idx] - self.assertTrue(log.date > next_log.date) - - logs = LogEntry.objects.filter( - date__lte=datetime.datetime(2015, 1, 1, 0, 0, 0, 10000)) - self.assertEqual(logs.count(), 4) - def test_list_validation(self): """Ensure that a list field only accepts lists with valid elements.""" AccessLevelChoices = ( @@ -5389,5 +5255,180 @@ class GenericLazyReferenceFieldTest(MongoDBTestCase): check_fields_type(occ) +class ComplexDateTimeFieldTest(MongoDBTestCase): + def test_complexdatetime_storage(self): + """Tests for complex datetime fields - which can handle + microseconds without rounding. + """ + class LogEntry(Document): + date = ComplexDateTimeField() + date_with_dots = ComplexDateTimeField(separator='.') + + LogEntry.drop_collection() + + # Post UTC - microseconds are rounded (down) nearest millisecond and + # dropped - with default datetimefields + d1 = datetime.datetime(1970, 1, 1, 0, 0, 1, 999) + log = LogEntry() + log.date = d1 + log.save() + log.reload() + self.assertEqual(log.date, d1) + + # Post UTC - microseconds are rounded (down) nearest millisecond - with + # default datetimefields + d1 = datetime.datetime(1970, 1, 1, 0, 0, 1, 9999) + log.date = d1 + log.save() + log.reload() + self.assertEqual(log.date, d1) + + # Pre UTC dates microseconds below 1000 are dropped - with default + # datetimefields + d1 = datetime.datetime(1969, 12, 31, 23, 59, 59, 999) + log.date = d1 + log.save() + log.reload() + self.assertEqual(log.date, d1) + + # Pre UTC microseconds above 1000 is wonky - with default datetimefields + # log.date has an invalid microsecond value so I can't construct + # a date to compare. + for i in range(1001, 3113, 33): + d1 = datetime.datetime(1969, 12, 31, 23, 59, 59, i) + log.date = d1 + log.save() + log.reload() + self.assertEqual(log.date, d1) + log1 = LogEntry.objects.get(date=d1) + self.assertEqual(log, log1) + + # Test string padding + microsecond = map(int, [math.pow(10, x) for x in range(6)]) + mm = dd = hh = ii = ss = [1, 10] + + for values in itertools.product([2014], mm, dd, hh, ii, ss, microsecond): + stored = LogEntry(date=datetime.datetime(*values)).to_mongo()['date'] + self.assertTrue(re.match('^\d{4},\d{2},\d{2},\d{2},\d{2},\d{2},\d{6}$', stored) is not None) + + # Test separator + stored = LogEntry(date_with_dots=datetime.datetime(2014, 1, 1)).to_mongo()['date_with_dots'] + self.assertTrue(re.match('^\d{4}.\d{2}.\d{2}.\d{2}.\d{2}.\d{2}.\d{6}$', stored) is not None) + + def test_complexdatetime_usage(self): + """Tests for complex datetime fields - which can handle + microseconds without rounding. + """ + class LogEntry(Document): + date = ComplexDateTimeField() + + LogEntry.drop_collection() + + d1 = datetime.datetime(1950, 1, 1, 0, 0, 1, 999) + log = LogEntry() + log.date = d1 + log.save() + + log1 = LogEntry.objects.get(date=d1) + self.assertEqual(log, log1) + + # create extra 59 log entries for a total of 60 + for i in range(1951, 2010): + d = datetime.datetime(i, 1, 1, 0, 0, 1, 999) + LogEntry(date=d).save() + + self.assertEqual(LogEntry.objects.count(), 60) + + # Test ordering + logs = LogEntry.objects.order_by("date") + i = 0 + while i < 59: + self.assertTrue(logs[i].date <= logs[i + 1].date) + i += 1 + + logs = LogEntry.objects.order_by("-date") + i = 0 + while i < 59: + self.assertTrue(logs[i].date >= logs[i + 1].date) + i += 1 + + # Test searching + logs = LogEntry.objects.filter(date__gte=datetime.datetime(1980, 1, 1)) + self.assertEqual(logs.count(), 30) + + logs = LogEntry.objects.filter(date__lte=datetime.datetime(1980, 1, 1)) + self.assertEqual(logs.count(), 30) + + logs = LogEntry.objects.filter( + date__lte=datetime.datetime(2011, 1, 1), + date__gte=datetime.datetime(2000, 1, 1), + ) + self.assertEqual(logs.count(), 10) + + LogEntry.drop_collection() + + # Test microsecond-level ordering/filtering + for microsecond in (99, 999, 9999, 10000): + LogEntry( + date=datetime.datetime(2015, 1, 1, 0, 0, 0, microsecond) + ).save() + + logs = list(LogEntry.objects.order_by('date')) + for next_idx, log in enumerate(logs[:-1], start=1): + next_log = logs[next_idx] + self.assertTrue(log.date < next_log.date) + + logs = list(LogEntry.objects.order_by('-date')) + for next_idx, log in enumerate(logs[:-1], start=1): + next_log = logs[next_idx] + self.assertTrue(log.date > next_log.date) + + logs = LogEntry.objects.filter( + date__lte=datetime.datetime(2015, 1, 1, 0, 0, 0, 10000)) + self.assertEqual(logs.count(), 4) + + def test_no_default_value(self): + class Log(Document): + timestamp = ComplexDateTimeField() + + Log.drop_collection() + + log = Log() + self.assertIsNone(log.timestamp) + log.save() + + fetched_log = Log.objects.with_id(log.id) + self.assertIsNone(fetched_log.timestamp) + + def test_default_static_value(self): + NOW = datetime.datetime.utcnow() + class Log(Document): + timestamp = ComplexDateTimeField(default=NOW) + + Log.drop_collection() + + log = Log() + self.assertEqual(log.timestamp, NOW) + log.save() + + fetched_log = Log.objects.with_id(log.id) + self.assertEqual(fetched_log.timestamp, NOW) + + def test_default_callable(self): + NOW = datetime.datetime.utcnow() + + class Log(Document): + timestamp = ComplexDateTimeField(default=datetime.datetime.utcnow) + + Log.drop_collection() + + log = Log() + self.assertGreaterEqual(log.timestamp, NOW) + log.save() + + fetched_log = Log.objects.with_id(log.id) + self.assertGreaterEqual(fetched_log.timestamp, NOW) + + if __name__ == '__main__': unittest.main()