From 53544c5b0ffc99a90f74f881a2c426d2bc9f8c13 Mon Sep 17 00:00:00 2001 From: Harry Marr Date: Sun, 27 Dec 2009 23:08:31 +0000 Subject: [PATCH] Queries now translate keys to correct field names --- mongoengine/fields.py | 12 ++++++++++++ mongoengine/queryset.py | 26 ++++++++++++++++++++++++-- tests/queryset.py | 29 +++++++++++++++++++++++++++++ 3 files changed, 65 insertions(+), 2 deletions(-) diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 1163d51a..badc7363 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -34,6 +34,9 @@ class StringField(BaseField): message = 'String value did not match validation regex' raise ValidationError(message) + def lookup_member(self, member_name): + return None + class IntField(BaseField): """An integer field. @@ -114,6 +117,9 @@ class EmbeddedDocumentField(BaseField): raise ValidationError('Invalid embedded document instance ' 'provided to an EmbeddedDocumentField') + def lookup_member(self, member_name): + return self.document._fields.get(member_name) + class ListField(BaseField): """A list field that wraps a standard field, allowing multiple instances @@ -146,6 +152,9 @@ class ListField(BaseField): raise ValidationError('All items in a list field must be of the ' 'specified type') + def lookup_member(self, member_name): + return self.field.lookup_member(member_name) + class ReferenceField(BaseField): """A reference to a document that will be automatically dereferenced on @@ -194,3 +203,6 @@ class ReferenceField(BaseField): def validate(self, value): assert(isinstance(value, (self.document_type, pymongo.dbref.DBRef))) + + def lookup_member(self, member_name): + return self.document_type._fields.get(member_name) diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index c97aca1e..af78e5b2 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -6,6 +6,10 @@ import pymongo __all__ = ['queryset_manager'] +class InvalidQueryError(Exception): + pass + + class QuerySet(object): """A set of results returned from a query. Wraps a MongoDB cursor, providing :class:`~mongoengine.Document` objects as the results. @@ -38,7 +42,8 @@ class QuerySet(object): """Filter the selected documents by calling the :class:`~mongoengine.QuerySet` with a query. """ - self._query.update(QuerySet._transform_query(**query)) + query = QuerySet._transform_query(_doc_cls=self._document, **query) + self._query.update(query) return self @property @@ -48,7 +53,7 @@ class QuerySet(object): return self._cursor_obj @classmethod - def _transform_query(cls, **query): + def _transform_query(cls, _doc_cls=None, **query): """Transform a query from Django-style format to Mongo format. """ operators = ['neq', 'gt', 'gte', 'lt', 'lte', 'in', 'nin', 'mod', @@ -63,6 +68,23 @@ class QuerySet(object): op = parts.pop() value = {'$' + op: value} + # Switch field names to proper names [set in Field(name='foo')] + if _doc_cls: + field_names = [] + field = None + for field_name in parts: + if field is None: + # Look up first field from the document + field = _doc_cls._fields[field_name] + else: + # Look up subfield on the previous field + field = field.lookup_member(field_name) + if field is None: + raise InvalidQueryError('Cannot resolve field "%s"' + % field_name) + field_names.append(field.name) + parts = field_names + key = '.'.join(parts) if op is None or key not in mongo_query: mongo_query[key] = value diff --git a/tests/queryset.py b/tests/queryset.py index 48217132..29c62370 100644 --- a/tests/queryset.py +++ b/tests/queryset.py @@ -240,6 +240,35 @@ class QuerySetTest(unittest.TestCase): BlogPost.drop_collection() + def test_query_field_name(self): + """Ensure that the correct field name is used when querying. + """ + class Comment(EmbeddedDocument): + content = StringField(name='commentContent') + + class BlogPost(Document): + title = StringField(name='postTitle') + comments = ListField(EmbeddedDocumentField(Comment), + name='postComments') + + + BlogPost.drop_collection() + + data = {'title': 'Post 1', 'comments': [Comment(content='test')]} + BlogPost(**data).save() + + self.assertTrue('postTitle' in + BlogPost.objects(title=data['title'])._query) + self.assertFalse('title' in + BlogPost.objects(title=data['title'])._query) + self.assertEqual(len(BlogPost.objects(title=data['title'])), 1) + + self.assertTrue('postComments.commentContent' in + BlogPost.objects(comments__content='test')._query) + self.assertEqual(len(BlogPost.objects(comments__content='test')), 1) + + BlogPost.drop_collection() + def tearDown(self): self.Person.drop_collection()