diff --git a/AUTHORS b/AUTHORS index 28b57bb4..bcf6129e 100644 --- a/AUTHORS +++ b/AUTHORS @@ -80,3 +80,4 @@ that much better: * Stephen Young * tkloc * aid + * yamaneko1212 diff --git a/docs/changelog.rst b/docs/changelog.rst index d510531f..602294a1 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -5,6 +5,7 @@ Changelog Changes in dev ============== +- Fixed Handle None values for non-required fields. - Removed Document._get_subclasses() - no longer required - Fixed bug requiring subclasses when not actually needed - Fixed deletion of dynamic data diff --git a/mongoengine/base.py b/mongoengine/base.py index 402f191e..ad2ebca7 100644 --- a/mongoengine/base.py +++ b/mongoengine/base.py @@ -751,6 +751,13 @@ class BaseDocument(object): self._data[name] = value if hasattr(self, '_changed_fields'): self._mark_as_changed(name) + + # Handle None values for required fields + if value is None and name in getattr(self, '_fields', {}): + self._data[name] = value + if hasattr(self, '_changed_fields'): + self._mark_as_changed(name) + return super(BaseDocument, self).__setattr__(name, value) def __expand_dynamic_values(self, name, value): diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index 6025dd99..308bdf7f 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -694,8 +694,7 @@ class QuerySet(object): elif op in ('in', 'nin', 'all', 'near'): # 'in', 'nin' and 'all' require a list of values value = [field.prepare_query_value(op, v) for v in value] - - + # if op and op not in match_operators: if op: if op in geo_operators: @@ -1301,7 +1300,8 @@ class QuerySet(object): field = cleaned_fields[-1] if op in (None, 'set', 'push', 'pull', 'addToSet'): - value = field.prepare_query_value(op, value) + if field.required or value is not None: + value = field.prepare_query_value(op, value) elif op in ('pushAll', 'pullAll'): value = [field.prepare_query_value(op, v) for v in value] diff --git a/tests/fields.py b/tests/fields.py index 0bc15a3c..37972fd0 100644 --- a/tests/fields.py +++ b/tests/fields.py @@ -46,6 +46,81 @@ class FieldTest(unittest.TestCase): person = Person(age=30) self.assertRaises(ValidationError, person.validate) + def test_not_required_handles_none_in_update(self): + """Ensure that every fields should accept None if required is False. + """ + + class HandleNoneFields(Document): + str_fld = StringField() + int_fld = IntField() + flt_fld = FloatField() + comp_dt_fld = ComplexDateTimeField() + + HandleNoneFields.drop_collection() + + doc = HandleNoneFields() + doc.str_fld = u'spam ham egg' + doc.int_fld = 42 + doc.flt_fld = 4.2 + doc.com_dt_fld = datetime.datetime.utcnow() + doc.save() + + res = HandleNoneFields.objects(id=doc.id).update( + set__str_fld=None, + set__int_fld=None, + set__flt_fld=None, + set__comp_dt_fld=None, + ) + self.assertEqual(res, 1) + + # Retrive data from db and verify it. + ret = HandleNoneFields.objects.all()[0] + + self.assertEqual(ret.str_fld, None) + self.assertEqual(ret.int_fld, None) + self.assertEqual(ret.flt_fld, None) + + # Return current time if retrived value is None. + self.assertTrue(isinstance(ret.comp_dt_fld, datetime.datetime)) + + def test_not_required_handles_none_from_database(self): + """Ensure that every fields can handle null values from the database. + """ + + class HandleNoneFields(Document): + str_fld = StringField(required=True) + int_fld = IntField(required=True) + flt_fld = FloatField(required=True) + comp_dt_fld = ComplexDateTimeField(required=True) + + HandleNoneFields.drop_collection() + + doc = HandleNoneFields() + doc.str_fld = u'spam ham egg' + doc.int_fld = 42 + doc.flt_fld = 4.2 + doc.com_dt_fld = datetime.datetime.utcnow() + doc.save() + + collection = self.db[HandleNoneFields._get_collection_name()] + obj = collection.update({"_id": doc.id}, {"$unset": { + "str_fld": 1, + "int_fld": 1, + "flt_fld": 1, + "comp_dt_fld": 1} + }) + + # Retrive data from db and verify it. + ret = HandleNoneFields.objects.all()[0] + + self.assertEqual(ret.str_fld, None) + self.assertEqual(ret.int_fld, None) + self.assertEqual(ret.flt_fld, None) + # Return current time if retrived value is None. + self.assert_(isinstance(ret.comp_dt_fld, datetime.datetime)) + + self.assertRaises(ValidationError, ret.validate) + def test_object_id_validation(self): """Ensure that invalid values cannot be assigned to string fields. """