From 9dfee83e687a9aef625fbc38bf5dd10b16e463dc Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Tue, 7 May 2013 11:54:47 +0000 Subject: [PATCH] Fixed querying string versions of ObjectIds issue with ReferenceField (#307) --- docs/changelog.rst | 1 + mongoengine/fields.py | 2 +- mongoengine/queryset/queryset.py | 5 ++-- tests/queryset/queryset.py | 45 +++++++++++++++++++++++++++++++- 4 files changed, 49 insertions(+), 4 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index ad5f6157..842bc7d0 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -4,6 +4,7 @@ Changelog Changes in 0.8.0 ================ +- Fixed querying string versions of ObjectIds issue with ReferenceField (#307) - Added $setOnInsert support for upserts (#308) - Upserts now possible with just query parameters (#309) - Upserting is the only way to ensure docs are saved correctly (#306) diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 49959983..573d9a03 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -854,7 +854,7 @@ class ReferenceField(BaseField): return document.id return document elif not self.dbref and isinstance(document, basestring): - return document + return ObjectId(document) id_field_name = self.document_type._meta['id_field'] id_field = self.document_type._fields[id_field_name] diff --git a/mongoengine/queryset/queryset.py b/mongoengine/queryset/queryset.py index 85b683db..191afdd2 100644 --- a/mongoengine/queryset/queryset.py +++ b/mongoengine/queryset/queryset.py @@ -544,8 +544,9 @@ class QuerySet(object): return c def select_related(self, max_depth=1): - """Handles dereferencing of :class:`~bson.dbref.DBRef` objects to - a maximum depth in order to cut down the number queries to mongodb. + """Handles dereferencing of :class:`~bson.dbref.DBRef` objects or + :class:`~bson.object_id.ObjectId` a maximum depth in order to cut down + the number queries to mongodb. .. versionadded:: 0.5 """ diff --git a/tests/queryset/queryset.py b/tests/queryset/queryset.py index ffb53786..b9db297b 100644 --- a/tests/queryset/queryset.py +++ b/tests/queryset/queryset.py @@ -1263,7 +1263,7 @@ class QuerySetTest(unittest.TestCase): class BlogPost(Document): content = StringField() authors = ListField(ReferenceField(self.Person, - reverse_delete_rule=PULL)) + reverse_delete_rule=PULL)) BlogPost.drop_collection() self.Person.drop_collection() @@ -1321,6 +1321,49 @@ class QuerySetTest(unittest.TestCase): self.Person.objects()[:1].delete() self.assertEqual(1, BlogPost.objects.count()) + + def test_reference_field_find(self): + """Ensure cascading deletion of referring documents from the database. + """ + class BlogPost(Document): + content = StringField() + author = ReferenceField(self.Person) + + BlogPost.drop_collection() + self.Person.drop_collection() + + me = self.Person(name='Test User').save() + BlogPost(content="test 123", author=me).save() + + self.assertEqual(1, BlogPost.objects(author=me).count()) + self.assertEqual(1, BlogPost.objects(author=me.pk).count()) + self.assertEqual(1, BlogPost.objects(author="%s" % me.pk).count()) + + self.assertEqual(1, BlogPost.objects(author__in=[me]).count()) + self.assertEqual(1, BlogPost.objects(author__in=[me.pk]).count()) + self.assertEqual(1, BlogPost.objects(author__in=["%s" % me.pk]).count()) + + def test_reference_field_find_dbref(self): + """Ensure cascading deletion of referring documents from the database. + """ + class BlogPost(Document): + content = StringField() + author = ReferenceField(self.Person, dbref=True) + + BlogPost.drop_collection() + self.Person.drop_collection() + + me = self.Person(name='Test User').save() + BlogPost(content="test 123", author=me).save() + + self.assertEqual(1, BlogPost.objects(author=me).count()) + self.assertEqual(1, BlogPost.objects(author=me.pk).count()) + self.assertEqual(1, BlogPost.objects(author="%s" % me.pk).count()) + + self.assertEqual(1, BlogPost.objects(author__in=[me]).count()) + self.assertEqual(1, BlogPost.objects(author__in=[me.pk]).count()) + self.assertEqual(1, BlogPost.objects(author__in=["%s" % me.pk]).count()) + def test_update(self): """Ensure that atomic updates work properly. """