diff --git a/mongoengine/fields.py b/mongoengine/fields.py index abea212c..a1638bf0 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -666,6 +666,9 @@ class ReferenceField(BaseField): return pymongo.dbref.DBRef(collection, id_) def prepare_query_value(self, op, value): + if value is None: + return None + return self.to_mongo(value) def validate(self, value): @@ -743,6 +746,9 @@ class GenericReferenceField(BaseField): return {'_cls': document._class_name, '_ref': ref} def prepare_query_value(self, op, value): + if value is None: + return None + return self.to_mongo(value) diff --git a/tests/fields.py b/tests/fields.py index c95b544e..80a343e3 100644 --- a/tests/fields.py +++ b/tests/fields.py @@ -992,15 +992,29 @@ class FieldTest(unittest.TestCase): class Company(Document): name = StringField() + Product.drop_collection() + Company.drop_collection() + ten_gen = Company(name='10gen') ten_gen.save() mongodb = Product(name='MongoDB', company=ten_gen) mongodb.save() + me = Product(name='MongoEngine') + me.save() + obj = Product.objects(company=ten_gen).first() self.assertEqual(obj, mongodb) self.assertEqual(obj.company, ten_gen) + obj = Product.objects(company=None).first() + self.assertEqual(obj, me) + + obj, created = Product.objects.get_or_create(company=None) + + self.assertEqual(created, False) + self.assertEqual(obj, me) + def test_reference_query_conversion(self): """Ensure that ReferenceFields can be queried using objects and values of the type of the primary key of the referenced object.