Queries now translate keys to correct field names

This commit is contained in:
Harry Marr 2009-12-27 23:08:31 +00:00
parent 17aef253cb
commit 53544c5b0f
3 changed files with 65 additions and 2 deletions

View File

@ -34,6 +34,9 @@ class StringField(BaseField):
message = 'String value did not match validation regex' message = 'String value did not match validation regex'
raise ValidationError(message) raise ValidationError(message)
def lookup_member(self, member_name):
return None
class IntField(BaseField): class IntField(BaseField):
"""An integer field. """An integer field.
@ -114,6 +117,9 @@ class EmbeddedDocumentField(BaseField):
raise ValidationError('Invalid embedded document instance ' raise ValidationError('Invalid embedded document instance '
'provided to an EmbeddedDocumentField') 'provided to an EmbeddedDocumentField')
def lookup_member(self, member_name):
return self.document._fields.get(member_name)
class ListField(BaseField): class ListField(BaseField):
"""A list field that wraps a standard field, allowing multiple instances """A list field that wraps a standard field, allowing multiple instances
@ -146,6 +152,9 @@ class ListField(BaseField):
raise ValidationError('All items in a list field must be of the ' raise ValidationError('All items in a list field must be of the '
'specified type') 'specified type')
def lookup_member(self, member_name):
return self.field.lookup_member(member_name)
class ReferenceField(BaseField): class ReferenceField(BaseField):
"""A reference to a document that will be automatically dereferenced on """A reference to a document that will be automatically dereferenced on
@ -194,3 +203,6 @@ class ReferenceField(BaseField):
def validate(self, value): def validate(self, value):
assert(isinstance(value, (self.document_type, pymongo.dbref.DBRef))) assert(isinstance(value, (self.document_type, pymongo.dbref.DBRef)))
def lookup_member(self, member_name):
return self.document_type._fields.get(member_name)

View File

@ -6,6 +6,10 @@ import pymongo
__all__ = ['queryset_manager'] __all__ = ['queryset_manager']
class InvalidQueryError(Exception):
pass
class QuerySet(object): class QuerySet(object):
"""A set of results returned from a query. Wraps a MongoDB cursor, """A set of results returned from a query. Wraps a MongoDB cursor,
providing :class:`~mongoengine.Document` objects as the results. providing :class:`~mongoengine.Document` objects as the results.
@ -38,7 +42,8 @@ class QuerySet(object):
"""Filter the selected documents by calling the """Filter the selected documents by calling the
:class:`~mongoengine.QuerySet` with a query. :class:`~mongoengine.QuerySet` with a query.
""" """
self._query.update(QuerySet._transform_query(**query)) query = QuerySet._transform_query(_doc_cls=self._document, **query)
self._query.update(query)
return self return self
@property @property
@ -48,7 +53,7 @@ class QuerySet(object):
return self._cursor_obj return self._cursor_obj
@classmethod @classmethod
def _transform_query(cls, **query): def _transform_query(cls, _doc_cls=None, **query):
"""Transform a query from Django-style format to Mongo format. """Transform a query from Django-style format to Mongo format.
""" """
operators = ['neq', 'gt', 'gte', 'lt', 'lte', 'in', 'nin', 'mod', operators = ['neq', 'gt', 'gte', 'lt', 'lte', 'in', 'nin', 'mod',
@ -63,6 +68,23 @@ class QuerySet(object):
op = parts.pop() op = parts.pop()
value = {'$' + op: value} value = {'$' + op: value}
# Switch field names to proper names [set in Field(name='foo')]
if _doc_cls:
field_names = []
field = None
for field_name in parts:
if field is None:
# Look up first field from the document
field = _doc_cls._fields[field_name]
else:
# Look up subfield on the previous field
field = field.lookup_member(field_name)
if field is None:
raise InvalidQueryError('Cannot resolve field "%s"'
% field_name)
field_names.append(field.name)
parts = field_names
key = '.'.join(parts) key = '.'.join(parts)
if op is None or key not in mongo_query: if op is None or key not in mongo_query:
mongo_query[key] = value mongo_query[key] = value

View File

@ -240,6 +240,35 @@ class QuerySetTest(unittest.TestCase):
BlogPost.drop_collection() BlogPost.drop_collection()
def test_query_field_name(self):
"""Ensure that the correct field name is used when querying.
"""
class Comment(EmbeddedDocument):
content = StringField(name='commentContent')
class BlogPost(Document):
title = StringField(name='postTitle')
comments = ListField(EmbeddedDocumentField(Comment),
name='postComments')
BlogPost.drop_collection()
data = {'title': 'Post 1', 'comments': [Comment(content='test')]}
BlogPost(**data).save()
self.assertTrue('postTitle' in
BlogPost.objects(title=data['title'])._query)
self.assertFalse('title' in
BlogPost.objects(title=data['title'])._query)
self.assertEqual(len(BlogPost.objects(title=data['title'])), 1)
self.assertTrue('postComments.commentContent' in
BlogPost.objects(comments__content='test')._query)
self.assertEqual(len(BlogPost.objects(comments__content='test')), 1)
BlogPost.drop_collection()
def tearDown(self): def tearDown(self):
self.Person.drop_collection() self.Person.drop_collection()