diff --git a/docs/changelog.rst b/docs/changelog.rst index e3f3ecd1..d1714e78 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -8,6 +8,7 @@ Changes in 0.9.X - DEV - OperationError: Shard Keys are immutable. Tried to update id even though the document is not yet saved #771 - with_limit_and_skip for count should default like in pymongo #759 - Fix storing value of precision attribute in DecimalField #787 +- Set attribute to None does not work (at least for fields with default values) #734 - Querying by a field defined in a subclass raises InvalidQueryError #744 - Add Support For MongoDB 2.6.X's maxTimeMS #778 - abstract shouldn't be inherited in EmbeddedDocument # 789 diff --git a/mongoengine/base/fields.py b/mongoengine/base/fields.py index 5bb9c7ac..2747dde1 100644 --- a/mongoengine/base/fields.py +++ b/mongoengine/base/fields.py @@ -37,7 +37,7 @@ class BaseField(object): def __init__(self, db_field=None, name=None, required=False, default=None, unique=False, unique_with=None, primary_key=False, validation=None, choices=None, verbose_name=None, - help_text=None): + help_text=None, null=False): """ :param db_field: The database field to store this field in (defaults to the name of the field) @@ -60,6 +60,8 @@ class BaseField(object): model forms from the document model. :param help_text: (optional) The help text for this field and is often used when generating model forms from the document model. + :param null: (optional) Is the field value can be null. If no and there is a default value + then the default value is set """ self.db_field = (db_field or name) if not primary_key else '_id' @@ -75,6 +77,7 @@ class BaseField(object): self.choices = choices self.verbose_name = verbose_name self.help_text = help_text + self.null = null # Adjust the appropriate creation counter, and save our local copy. if self.db_field == '_id': @@ -100,10 +103,13 @@ class BaseField(object): # If setting to None and theres a default # Then set the value to the default value - if value is None and self.default is not None: - value = self.default - if callable(value): - value = value() + if value is None: + if self.null: + value = None + elif self.default is not None: + value = self.default + if callable(value): + value = value() if instance._initialised: try: diff --git a/tests/document/instance.py b/tests/document/instance.py index b7733548..36118512 100644 --- a/tests/document/instance.py +++ b/tests/document/instance.py @@ -57,7 +57,9 @@ class InstanceTest(unittest.TestCase): self.db.drop_collection(collection) def assertDbEqual(self, docs): - self.assertEqual(list(self.Person._get_collection().find().sort("id")), sorted(docs, key=lambda doc: doc["_id"])) + self.assertEqual( + list(self.Person._get_collection().find().sort("id")), + sorted(docs, key=lambda doc: doc["_id"])) def test_capped_collection(self): """Ensure that capped collections work properly. @@ -144,10 +146,18 @@ class InstanceTest(unittest.TestCase): """ class Animal(Document): meta = {'allow_inheritance': True} - class Fish(Animal): pass - class Mammal(Animal): pass - class Dog(Mammal): pass - class Human(Mammal): pass + + class Fish(Animal): + pass + + class Mammal(Animal): + pass + + class Dog(Mammal): + pass + + class Human(Mammal): + pass class Zoo(Document): animals = ListField(ReferenceField(Animal)) @@ -459,7 +469,7 @@ class InstanceTest(unittest.TestCase): f.reload() except Foo.DoesNotExist: pass - except Exception as ex: + except Exception: self.assertFalse("Threw wrong exception") f.save() @@ -468,7 +478,7 @@ class InstanceTest(unittest.TestCase): f.reload() except Foo.DoesNotExist: pass - except Exception as ex: + except Exception: self.assertFalse("Threw wrong exception") def test_dictionary_access(self): @@ -503,8 +513,9 @@ class InstanceTest(unittest.TestCase): self.assertEqual(Person(name="Bob", age=35).to_mongo().keys(), ['_cls', 'name', 'age']) - self.assertEqual(Employee(name="Bob", age=35, salary=0).to_mongo().keys(), - ['_cls', 'name', 'age', 'salary']) + self.assertEqual( + Employee(name="Bob", age=35, salary=0).to_mongo().keys(), + ['_cls', 'name', 'age', 'salary']) def test_embedded_document_to_mongo_id(self): class SubDoc(EmbeddedDocument): @@ -641,7 +652,8 @@ class InstanceTest(unittest.TestCase): def test_modify_empty(self): doc = self.Person(name="bob", age=10).save() - self.assertRaises(InvalidDocumentError, lambda: self.Person().modify(set__age=10)) + self.assertRaises( + InvalidDocumentError, lambda: self.Person().modify(set__age=10)) self.assertDbEqual([dict(doc.to_mongo())]) def test_modify_invalid_query(self): @@ -649,8 +661,9 @@ class InstanceTest(unittest.TestCase): doc2 = self.Person(name="jim", age=20).save() docs = [dict(doc1.to_mongo()), dict(doc2.to_mongo())] - self.assertRaises(InvalidQueryError, lambda: - doc1.modify(dict(id=doc2.id), set__value=20)) + self.assertRaises( + InvalidQueryError, + lambda: doc1.modify(dict(id=doc2.id), set__value=20)) self.assertDbEqual(docs) @@ -674,7 +687,8 @@ class InstanceTest(unittest.TestCase): def test_modify_update(self): other_doc = self.Person(name="bob", age=10).save() - doc = self.Person(name="jim", age=20, job=self.Job(name="10gen", years=3)).save() + doc = self.Person( + name="jim", age=20, job=self.Job(name="10gen", years=3)).save() doc_copy = doc._from_son(doc.to_mongo()) @@ -683,7 +697,8 @@ class InstanceTest(unittest.TestCase): doc.job.name = "Google" doc.job.years = 3 - assert doc.modify(set__age=21, set__job__name="MongoDB", unset__job__years=True) + assert doc.modify( + set__age=21, set__job__name="MongoDB", unset__job__years=True) doc_copy.age = 21 doc_copy.job.name = "MongoDB" del doc_copy.job.years @@ -931,7 +946,7 @@ class InstanceTest(unittest.TestCase): w1 = Widget(toggle=False, save_id=UUID(1)) # ignore save_condition on new record creation - w1.save(save_condition={'save_id':UUID(42)}) + w1.save(save_condition={'save_id': UUID(42)}) w1.reload() self.assertFalse(w1.toggle) self.assertEqual(w1.save_id, UUID(1)) @@ -941,7 +956,7 @@ class InstanceTest(unittest.TestCase): flip(w1) self.assertTrue(w1.toggle) self.assertEqual(w1.count, 1) - w1.save(save_condition={'save_id':UUID(42)}) + w1.save(save_condition={'save_id': UUID(42)}) w1.reload() self.assertFalse(w1.toggle) self.assertEqual(w1.count, 0) @@ -950,7 +965,7 @@ class InstanceTest(unittest.TestCase): flip(w1) self.assertTrue(w1.toggle) self.assertEqual(w1.count, 1) - w1.save(save_condition={'save_id':UUID(1)}) + w1.save(save_condition={'save_id': UUID(1)}) w1.reload() self.assertTrue(w1.toggle) self.assertEqual(w1.count, 1) @@ -963,25 +978,25 @@ class InstanceTest(unittest.TestCase): flip(w1) w1.save_id = UUID(2) - w1.save(save_condition={'save_id':old_id}) + w1.save(save_condition={'save_id': old_id}) w1.reload() self.assertFalse(w1.toggle) self.assertEqual(w1.count, 2) flip(w2) flip(w2) - w2.save(save_condition={'save_id':old_id}) + w2.save(save_condition={'save_id': old_id}) w2.reload() self.assertFalse(w2.toggle) self.assertEqual(w2.count, 2) # save_condition uses mongoengine-style operator syntax flip(w1) - w1.save(save_condition={'count__lt':w1.count}) + w1.save(save_condition={'count__lt': w1.count}) w1.reload() self.assertTrue(w1.toggle) self.assertEqual(w1.count, 3) flip(w1) - w1.save(save_condition={'count__gte':w1.count}) + w1.save(save_condition={'count__gte': w1.count}) w1.reload() self.assertTrue(w1.toggle) self.assertEqual(w1.count, 3) @@ -1427,7 +1442,8 @@ class InstanceTest(unittest.TestCase): self.assertEqual(str(person_obj['_id']), '497ce96f395f2f052a494fd4') def test_save_custom_pk(self): - """Ensure that a document may be saved with a custom _id using pk alias. + """ + Ensure that a document may be saved with a custom _id using pk alias. """ # Create person object and save it to the database person = self.Person(name='Test User', age=30, @@ -1513,9 +1529,15 @@ class InstanceTest(unittest.TestCase): p4 = Page(comments=[Comment(user=u2, comment="Heavy Metal song")]) p4.save() - self.assertEqual([p1, p2], list(Page.objects.filter(comments__user=u1))) - self.assertEqual([p1, p2, p4], list(Page.objects.filter(comments__user=u2))) - self.assertEqual([p1, p3], list(Page.objects.filter(comments__user=u3))) + self.assertEqual( + [p1, p2], + list(Page.objects.filter(comments__user=u1))) + self.assertEqual( + [p1, p2, p4], + list(Page.objects.filter(comments__user=u2))) + self.assertEqual( + [p1, p3], + list(Page.objects.filter(comments__user=u3))) def test_save_embedded_document(self): """Ensure that a document with an embedded document field may be @@ -1590,7 +1612,8 @@ class InstanceTest(unittest.TestCase): self.assertEqual(promoted_employee.age, 50) # Ensure that the 'details' embedded object saved correctly - self.assertEqual(promoted_employee.details.position, 'Senior Developer') + self.assertEqual( + promoted_employee.details.position, 'Senior Developer') # Test removal promoted_employee.details = None @@ -1726,7 +1749,8 @@ class InstanceTest(unittest.TestCase): post.save() reviewer.delete() - self.assertEqual(BlogPost.objects.count(), 1) # No effect on the BlogPost + # No effect on the BlogPost + self.assertEqual(BlogPost.objects.count(), 1) self.assertEqual(BlogPost.objects.get().reviewer, None) # Delete the Person, which should lead to deletion of the BlogPost, too @@ -1775,8 +1799,10 @@ class InstanceTest(unittest.TestCase): class BlogPost(Document): content = StringField() - authors = ListField(ReferenceField(self.Person, reverse_delete_rule=CASCADE)) - reviewers = ListField(ReferenceField(self.Person, reverse_delete_rule=NULLIFY)) + authors = ListField(ReferenceField( + self.Person, reverse_delete_rule=CASCADE)) + reviewers = ListField(ReferenceField( + self.Person, reverse_delete_rule=NULLIFY)) self.Person.drop_collection() @@ -1876,8 +1902,12 @@ class InstanceTest(unittest.TestCase): def throw_invalid_document_error(): class Blog(Document): content = StringField() - authors = MapField(ReferenceField(self.Person, reverse_delete_rule=CASCADE)) - reviewers = DictField(field=ReferenceField(self.Person, reverse_delete_rule=NULLIFY)) + authors = MapField(ReferenceField( + self.Person, reverse_delete_rule=CASCADE)) + reviewers = DictField( + field=ReferenceField( + self.Person, + reverse_delete_rule=NULLIFY)) self.assertRaises(InvalidDocumentError, throw_invalid_document_error) @@ -1886,7 +1916,8 @@ class InstanceTest(unittest.TestCase): father = ReferenceField('Person', reverse_delete_rule=DENY) mother = ReferenceField('Person', reverse_delete_rule=DENY) - self.assertRaises(InvalidDocumentError, throw_invalid_document_error_embedded) + self.assertRaises( + InvalidDocumentError, throw_invalid_document_error_embedded) def test_reverse_delete_rule_cascade_recurs(self): """Ensure that a chain of documents is also deleted upon cascaded @@ -1908,16 +1939,16 @@ class InstanceTest(unittest.TestCase): author = self.Person(name='Test User') author.save() - post = BlogPost(content = 'Watched some TV') + post = BlogPost(content='Watched some TV') post.author = author post.save() - comment = Comment(text = 'Kudos.') + comment = Comment(text='Kudos.') comment.post = post comment.save() - # Delete the Person, which should lead to deletion of the BlogPost, and, - # recursively to the Comment, too + # Delete the Person, which should lead to deletion of the BlogPost, + # and, recursively to the Comment, too author.delete() self.assertEqual(Comment.objects.count(), 0) @@ -1940,7 +1971,7 @@ class InstanceTest(unittest.TestCase): author = self.Person(name='Test User') author.save() - post = BlogPost(content = 'Watched some TV') + post = BlogPost(content='Watched some TV') post.author = author post.save() @@ -2056,7 +2087,8 @@ class InstanceTest(unittest.TestCase): def test_dynamic_document_pickle(self): - pickle_doc = PickleDynamicTest(name="test", number=1, string="One", lists=['1', '2']) + pickle_doc = PickleDynamicTest( + name="test", number=1, string="One", lists=['1', '2']) pickle_doc.embedded = PickleDyanmicEmbedded(foo="Bar") pickled_doc = pickle.dumps(pickle_doc) # make sure pickling works even before the doc is saved @@ -2078,7 +2110,8 @@ class InstanceTest(unittest.TestCase): pickle_doc.embedded._dynamic_fields.keys()) def test_picklable_on_signals(self): - pickle_doc = PickleSignalsTest(number=1, string="One", lists=['1', '2']) + pickle_doc = PickleSignalsTest( + number=1, string="One", lists=['1', '2']) pickle_doc.embedded = PickleEmbedded() pickle_doc.save() pickle_doc.delete() @@ -2233,9 +2266,15 @@ class InstanceTest(unittest.TestCase): self.assertEqual(AuthorBooks._get_db(), get_db("testdb-3")) # Collections - self.assertEqual(User._get_collection(), get_db("testdb-1")[User._get_collection_name()]) - self.assertEqual(Book._get_collection(), get_db("testdb-2")[Book._get_collection_name()]) - self.assertEqual(AuthorBooks._get_collection(), get_db("testdb-3")[AuthorBooks._get_collection_name()]) + self.assertEqual( + User._get_collection(), + get_db("testdb-1")[User._get_collection_name()]) + self.assertEqual( + Book._get_collection(), + get_db("testdb-2")[Book._get_collection_name()]) + self.assertEqual( + AuthorBooks._get_collection(), + get_db("testdb-3")[AuthorBooks._get_collection_name()]) def test_db_alias_overrides(self): """db_alias can be overriden @@ -2613,7 +2652,9 @@ class InstanceTest(unittest.TestCase): system.save() system = NodesSystem.objects.first() - self.assertEqual("UNDEFINED", system.nodes["node"].parameters["param"].macros["test"].value) + self.assertEqual( + "UNDEFINED", + system.nodes["node"].parameters["param"].macros["test"].value) def test_embedded_document_equality(self): @@ -2730,5 +2771,26 @@ class InstanceTest(unittest.TestCase): self.assertEquals(p.id, None) p.id = "12345" # in case it is not working: "OperationError: Shard Keys are immutable..." will be raised here + def test_null_field(self): + # 734 + class User(Document): + name = StringField() + height = IntField(default=184, null=True) + User.objects.delete() + u = User(name='user') + u.save() + u_from_db = User.objects.get(name='user') + u_from_db.height = None + u_from_db.save() + self.assertEquals(u_from_db.height, None) + + # 735 + User.objects.delete() + u = User(name='user') + u.save() + User.objects(name='user').update_one(set__height=None, upsert=True) + u_from_db = User.objects.get(name='user') + self.assertEquals(u_from_db.height, None) + if __name__ == '__main__': unittest.main()