diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 13fa6689..beb8ae00 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -209,20 +209,19 @@ class ReferenceField(BaseField): return super(ReferenceField, self).__get__(instance, owner) def to_mongo(self, document): - if isinstance(document, (str, unicode, pymongo.objectid.ObjectId)): - # document may already be an object id - id_ = document - else: + id_field_name = self.document_type._meta['id_field'] + id_field = self.document_type._fields[id_field_name] + + if isinstance(document, Document): # We need the id from the saved object to create the DBRef id_ = document.id if id_ is None: raise ValidationError('You can only reference documents once ' 'they have been saved to the database') + else: + id_ = document - # id may be a string rather than an ObjectID object - if not isinstance(id_, pymongo.objectid.ObjectId): - id_ = pymongo.objectid.ObjectId(id_) - + id_ = id_field.to_mongo(id_) collection = self.document_type._meta['collection'] return pymongo.dbref.DBRef(collection, id_) diff --git a/tests/fields.py b/tests/fields.py index affb0c9b..0ebb143d 100644 --- a/tests/fields.py +++ b/tests/fields.py @@ -259,6 +259,40 @@ class FieldTest(unittest.TestCase): User.drop_collection() BlogPost.drop_collection() + def test_reference_query_conversion(self): + """Ensure that ReferenceFields can be queried using objects and values + of the type of the primary key of the referenced object. + """ + class Member(Document): + user_num = IntField(primary_key=True) + + class BlogPost(Document): + title = StringField() + author = ReferenceField(Member) + + Member.drop_collection() + BlogPost.drop_collection() + + m1 = Member(user_num=1) + m1.save() + m2 = Member(user_num=2) + m2.save() + + post1 = BlogPost(title='post 1', author=m1) + post1.save() + + post2 = BlogPost(title='post 2', author=m2) + post2.save() + + post = BlogPost.objects(author=m1.id).first() + self.assertEqual(post.id, post1.id) + + post = BlogPost.objects(author=m2.id).first() + self.assertEqual(post.id, post2.id) + + Member.drop_collection() + BlogPost.drop_collection() + if __name__ == '__main__': unittest.main()