From 557fb19d131736f858e54c991dceabbe8d68ccbf Mon Sep 17 00:00:00 2001 From: Harry Marr Date: Wed, 6 Jan 2010 03:14:21 +0000 Subject: [PATCH] Query values may be processed before being used --- mongoengine/base.py | 8 ++++++++ mongoengine/document.py | 5 +++-- mongoengine/fields.py | 6 ++++++ mongoengine/queryset.py | 36 +++++++++++++++++++++++++++--------- tests/document.py | 2 ++ tests/queryset.py | 26 ++++++++++++++++++++++++++ 6 files changed, 72 insertions(+), 11 deletions(-) diff --git a/mongoengine/base.py b/mongoengine/base.py index 7c166dff..d4412e2a 100644 --- a/mongoengine/base.py +++ b/mongoengine/base.py @@ -49,6 +49,11 @@ class BaseField(object): """ return self.to_python(value) + def prepare_query_value(self, value): + """Prepare a value that is being used in a query for PyMongo. + """ + return value + def validate(self, value): """Perform validation on a value. """ @@ -67,6 +72,9 @@ class ObjectIdField(BaseField): return pymongo.objectid.ObjectId(value) return value + def prepare_query_value(self, value): + return self.to_mongo(value) + def validate(self, value): try: pymongo.objectid.ObjectId(str(value)) diff --git a/mongoengine/document.py b/mongoengine/document.py index 822a3ea5..e26c5ed0 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -67,8 +67,9 @@ class Document(BaseDocument): def reload(self): """Reloads all attributes from the database. """ - object_id = self._fields['id'].to_mongo(self.id) - obj = self.__class__.objects(id=object_id).first() + #object_id = self._fields['id'].to_mongo(self.id) + #obj = self.__class__.objects(id=object_id).first() + obj = self.__class__.objects(id=self.id).first() for field in self._fields: setattr(self, field, getattr(obj, field)) diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 6612d444..61fb385a 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -199,18 +199,24 @@ class ReferenceField(BaseField): def to_mongo(self, document): if isinstance(document, (str, unicode, pymongo.objectid.ObjectId)): + # document may already be an object id id_ = document else: + # We need the id from the saved object to create the DBRef id_ = document.id if id_ is None: raise ValidationError('You can only reference documents once ' 'they have been saved to the database') + # id may be a string rather than an ObjectID object if not isinstance(id_, pymongo.objectid.ObjectId): id_ = pymongo.objectid.ObjectId(id_) collection = self.document_type._meta['collection'] return pymongo.dbref.DBRef(collection, id_) + + def prepare_query_value(self, value): + return self.to_mongo(value) def validate(self, value): assert isinstance(value, (self.document_type, pymongo.dbref.DBRef)) diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index ff2d8a3e..7b3f7c6b 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -53,12 +53,13 @@ class QuerySet(object): return self._cursor_obj @classmethod - def _translate_field_name(cls, document, parts): - """Translate a field attribute name to a database field name. + def _lookup_field(cls, document, parts): + """Lookup a field based on its attribute and return a list containing + the field's parents and the field. """ if not isinstance(parts, (list, tuple)): parts = [parts] - field_names = [] + fields = [] field = None for field_name in parts: if field is None: @@ -70,9 +71,15 @@ class QuerySet(object): if field is None: raise InvalidQueryError('Cannot resolve field "%s"' % field_name) - field_names.append(field.name) - return field_names - + fields.append(field) + return fields + + @classmethod + def _translate_field_name(cls, doc_cls, parts): + """Translate a field attribute name to a database field name. + """ + return [field.name for field in QuerySet._lookup_field(doc_cls, parts)] + @classmethod def _transform_query(cls, _doc_cls=None, **query): """Transform a query from Django-style format to Mongo format. @@ -87,11 +94,22 @@ class QuerySet(object): op = None if parts[-1] in operators: op = parts.pop() - value = {'$' + op: value} - # Switch field names to proper names [set in Field(name='foo')] if _doc_cls: - parts = QuerySet._translate_field_name(_doc_cls, parts) + # Switch field names to proper names [set in Field(name='foo')] + fields = QuerySet._lookup_field(_doc_cls, parts) + parts = [field.name for field in fields] + + # Convert value to proper value + field = fields[-1] + if op in (None, 'neq', 'gt', 'gte', 'lt', 'lte'): + value = field.prepare_query_value(value) + elif op in ('in', 'nin', 'all'): + # 'in', 'nin' and 'all' require a list of values + value = [field.prepare_query_value(v) for v in value] + + if op: + value = {'$' + op: value} key = '.'.join(parts) if op is None or key not in mongo_query: diff --git a/tests/document.py b/tests/document.py index 31ae0999..47d9c5c4 100644 --- a/tests/document.py +++ b/tests/document.py @@ -321,6 +321,8 @@ class DocumentTest(unittest.TestCase): comments = ListField(EmbeddedDocumentField(Comment)) tags = ListField(StringField()) + BlogPost.drop_collection() + post = BlogPost(content='Went for a walk today...') post.tags = tags = ['fun', 'leisure'] comments = [Comment(content='Good for you'), Comment(content='Yay.')] diff --git a/tests/queryset.py b/tests/queryset.py index 13e38e4d..e7e79ccc 100644 --- a/tests/queryset.py +++ b/tests/queryset.py @@ -300,6 +300,32 @@ class QuerySetTest(unittest.TestCase): BlogPost.drop_collection() + def test_query_value_conversion(self): + """Ensure that query values are properly converted when necessary. + """ + class BlogPost(Document): + author = ReferenceField(self.Person) + + BlogPost.drop_collection() + + person = self.Person(name='test', age=30) + person.save() + + post = BlogPost(author=person) + post.save() + + # Test that query may be performed by providing a document as a value + # while using a ReferenceField's name - the document should be + # converted to an DBRef, which is legal, unlike a Document object + post_obj = BlogPost.objects(author=person).first() + self.assertEqual(post.id, post_obj.id) + + # Test that lists of values work when using the 'in', 'nin' and 'all' + post_obj = BlogPost.objects(author__in=[person]).first() + self.assertEqual(post.id, post_obj.id) + + BlogPost.drop_collection() + def tearDown(self): self.Person.drop_collection()