Query values may be processed before being used
This commit is contained in:
parent
196f4471be
commit
557fb19d13
@ -49,6 +49,11 @@ class BaseField(object):
|
|||||||
"""
|
"""
|
||||||
return self.to_python(value)
|
return self.to_python(value)
|
||||||
|
|
||||||
|
def prepare_query_value(self, value):
|
||||||
|
"""Prepare a value that is being used in a query for PyMongo.
|
||||||
|
"""
|
||||||
|
return value
|
||||||
|
|
||||||
def validate(self, value):
|
def validate(self, value):
|
||||||
"""Perform validation on a value.
|
"""Perform validation on a value.
|
||||||
"""
|
"""
|
||||||
@ -67,6 +72,9 @@ class ObjectIdField(BaseField):
|
|||||||
return pymongo.objectid.ObjectId(value)
|
return pymongo.objectid.ObjectId(value)
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
def prepare_query_value(self, value):
|
||||||
|
return self.to_mongo(value)
|
||||||
|
|
||||||
def validate(self, value):
|
def validate(self, value):
|
||||||
try:
|
try:
|
||||||
pymongo.objectid.ObjectId(str(value))
|
pymongo.objectid.ObjectId(str(value))
|
||||||
|
@ -67,8 +67,9 @@ class Document(BaseDocument):
|
|||||||
def reload(self):
|
def reload(self):
|
||||||
"""Reloads all attributes from the database.
|
"""Reloads all attributes from the database.
|
||||||
"""
|
"""
|
||||||
object_id = self._fields['id'].to_mongo(self.id)
|
#object_id = self._fields['id'].to_mongo(self.id)
|
||||||
obj = self.__class__.objects(id=object_id).first()
|
#obj = self.__class__.objects(id=object_id).first()
|
||||||
|
obj = self.__class__.objects(id=self.id).first()
|
||||||
for field in self._fields:
|
for field in self._fields:
|
||||||
setattr(self, field, getattr(obj, field))
|
setattr(self, field, getattr(obj, field))
|
||||||
|
|
||||||
|
@ -199,18 +199,24 @@ class ReferenceField(BaseField):
|
|||||||
|
|
||||||
def to_mongo(self, document):
|
def to_mongo(self, document):
|
||||||
if isinstance(document, (str, unicode, pymongo.objectid.ObjectId)):
|
if isinstance(document, (str, unicode, pymongo.objectid.ObjectId)):
|
||||||
|
# document may already be an object id
|
||||||
id_ = document
|
id_ = document
|
||||||
else:
|
else:
|
||||||
|
# We need the id from the saved object to create the DBRef
|
||||||
id_ = document.id
|
id_ = document.id
|
||||||
if id_ is None:
|
if id_ is None:
|
||||||
raise ValidationError('You can only reference documents once '
|
raise ValidationError('You can only reference documents once '
|
||||||
'they have been saved to the database')
|
'they have been saved to the database')
|
||||||
|
|
||||||
|
# id may be a string rather than an ObjectID object
|
||||||
if not isinstance(id_, pymongo.objectid.ObjectId):
|
if not isinstance(id_, pymongo.objectid.ObjectId):
|
||||||
id_ = pymongo.objectid.ObjectId(id_)
|
id_ = pymongo.objectid.ObjectId(id_)
|
||||||
|
|
||||||
collection = self.document_type._meta['collection']
|
collection = self.document_type._meta['collection']
|
||||||
return pymongo.dbref.DBRef(collection, id_)
|
return pymongo.dbref.DBRef(collection, id_)
|
||||||
|
|
||||||
|
def prepare_query_value(self, value):
|
||||||
|
return self.to_mongo(value)
|
||||||
|
|
||||||
def validate(self, value):
|
def validate(self, value):
|
||||||
assert isinstance(value, (self.document_type, pymongo.dbref.DBRef))
|
assert isinstance(value, (self.document_type, pymongo.dbref.DBRef))
|
||||||
|
@ -53,12 +53,13 @@ class QuerySet(object):
|
|||||||
return self._cursor_obj
|
return self._cursor_obj
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _translate_field_name(cls, document, parts):
|
def _lookup_field(cls, document, parts):
|
||||||
"""Translate a field attribute name to a database field name.
|
"""Lookup a field based on its attribute and return a list containing
|
||||||
|
the field's parents and the field.
|
||||||
"""
|
"""
|
||||||
if not isinstance(parts, (list, tuple)):
|
if not isinstance(parts, (list, tuple)):
|
||||||
parts = [parts]
|
parts = [parts]
|
||||||
field_names = []
|
fields = []
|
||||||
field = None
|
field = None
|
||||||
for field_name in parts:
|
for field_name in parts:
|
||||||
if field is None:
|
if field is None:
|
||||||
@ -70,9 +71,15 @@ class QuerySet(object):
|
|||||||
if field is None:
|
if field is None:
|
||||||
raise InvalidQueryError('Cannot resolve field "%s"'
|
raise InvalidQueryError('Cannot resolve field "%s"'
|
||||||
% field_name)
|
% field_name)
|
||||||
field_names.append(field.name)
|
fields.append(field)
|
||||||
return field_names
|
return fields
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _translate_field_name(cls, doc_cls, parts):
|
||||||
|
"""Translate a field attribute name to a database field name.
|
||||||
|
"""
|
||||||
|
return [field.name for field in QuerySet._lookup_field(doc_cls, parts)]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _transform_query(cls, _doc_cls=None, **query):
|
def _transform_query(cls, _doc_cls=None, **query):
|
||||||
"""Transform a query from Django-style format to Mongo format.
|
"""Transform a query from Django-style format to Mongo format.
|
||||||
@ -87,11 +94,22 @@ class QuerySet(object):
|
|||||||
op = None
|
op = None
|
||||||
if parts[-1] in operators:
|
if parts[-1] in operators:
|
||||||
op = parts.pop()
|
op = parts.pop()
|
||||||
value = {'$' + op: value}
|
|
||||||
|
|
||||||
# Switch field names to proper names [set in Field(name='foo')]
|
|
||||||
if _doc_cls:
|
if _doc_cls:
|
||||||
parts = QuerySet._translate_field_name(_doc_cls, parts)
|
# Switch field names to proper names [set in Field(name='foo')]
|
||||||
|
fields = QuerySet._lookup_field(_doc_cls, parts)
|
||||||
|
parts = [field.name for field in fields]
|
||||||
|
|
||||||
|
# Convert value to proper value
|
||||||
|
field = fields[-1]
|
||||||
|
if op in (None, 'neq', 'gt', 'gte', 'lt', 'lte'):
|
||||||
|
value = field.prepare_query_value(value)
|
||||||
|
elif op in ('in', 'nin', 'all'):
|
||||||
|
# 'in', 'nin' and 'all' require a list of values
|
||||||
|
value = [field.prepare_query_value(v) for v in value]
|
||||||
|
|
||||||
|
if op:
|
||||||
|
value = {'$' + op: value}
|
||||||
|
|
||||||
key = '.'.join(parts)
|
key = '.'.join(parts)
|
||||||
if op is None or key not in mongo_query:
|
if op is None or key not in mongo_query:
|
||||||
|
@ -321,6 +321,8 @@ class DocumentTest(unittest.TestCase):
|
|||||||
comments = ListField(EmbeddedDocumentField(Comment))
|
comments = ListField(EmbeddedDocumentField(Comment))
|
||||||
tags = ListField(StringField())
|
tags = ListField(StringField())
|
||||||
|
|
||||||
|
BlogPost.drop_collection()
|
||||||
|
|
||||||
post = BlogPost(content='Went for a walk today...')
|
post = BlogPost(content='Went for a walk today...')
|
||||||
post.tags = tags = ['fun', 'leisure']
|
post.tags = tags = ['fun', 'leisure']
|
||||||
comments = [Comment(content='Good for you'), Comment(content='Yay.')]
|
comments = [Comment(content='Good for you'), Comment(content='Yay.')]
|
||||||
|
@ -300,6 +300,32 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
|
|
||||||
BlogPost.drop_collection()
|
BlogPost.drop_collection()
|
||||||
|
|
||||||
|
def test_query_value_conversion(self):
|
||||||
|
"""Ensure that query values are properly converted when necessary.
|
||||||
|
"""
|
||||||
|
class BlogPost(Document):
|
||||||
|
author = ReferenceField(self.Person)
|
||||||
|
|
||||||
|
BlogPost.drop_collection()
|
||||||
|
|
||||||
|
person = self.Person(name='test', age=30)
|
||||||
|
person.save()
|
||||||
|
|
||||||
|
post = BlogPost(author=person)
|
||||||
|
post.save()
|
||||||
|
|
||||||
|
# Test that query may be performed by providing a document as a value
|
||||||
|
# while using a ReferenceField's name - the document should be
|
||||||
|
# converted to an DBRef, which is legal, unlike a Document object
|
||||||
|
post_obj = BlogPost.objects(author=person).first()
|
||||||
|
self.assertEqual(post.id, post_obj.id)
|
||||||
|
|
||||||
|
# Test that lists of values work when using the 'in', 'nin' and 'all'
|
||||||
|
post_obj = BlogPost.objects(author__in=[person]).first()
|
||||||
|
self.assertEqual(post.id, post_obj.id)
|
||||||
|
|
||||||
|
BlogPost.drop_collection()
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
self.Person.drop_collection()
|
self.Person.drop_collection()
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user