Fixed set/unset issue with ListFields

This commit is contained in:
Harry Marr 2010-02-12 11:21:51 +00:00
parent a93509c9b3
commit ea1fe6a538
4 changed files with 45 additions and 18 deletions

View File

@ -55,7 +55,7 @@ class BaseField(object):
""" """
return self.to_python(value) return self.to_python(value)
def prepare_query_value(self, value): def prepare_query_value(self, op, value):
"""Prepare a value that is being used in a query for PyMongo. """Prepare a value that is being used in a query for PyMongo.
""" """
return value return value
@ -81,7 +81,7 @@ class ObjectIdField(BaseField):
raise ValidationError(e.message) raise ValidationError(e.message)
return value return value
def prepare_query_value(self, value): def prepare_query_value(self, op, value):
return self.to_mongo(value) return self.to_mongo(value)
def validate(self, value): def validate(self, value):

View File

@ -136,7 +136,7 @@ class EmbeddedDocumentField(BaseField):
def lookup_member(self, member_name): def lookup_member(self, member_name):
return self.document._fields.get(member_name) return self.document._fields.get(member_name)
def prepare_query_value(self, value): def prepare_query_value(self, op, value):
return self.to_mongo(value) return self.to_mongo(value)
@ -166,6 +166,7 @@ class ListField(BaseField):
referenced_type = self.field.document_type referenced_type = self.field.document_type
# Get value from document instance if available # Get value from document instance if available
value_list = instance._data.get(self.name) value_list = instance._data.get(self.name)
if value_list:
deref_list = [] deref_list = []
for value in value_list: for value in value_list:
# Dereference DBRefs # Dereference DBRefs
@ -197,7 +198,9 @@ class ListField(BaseField):
raise ValidationError('All items in a list field must be of the ' raise ValidationError('All items in a list field must be of the '
'specified type') 'specified type')
def prepare_query_value(self, value): def prepare_query_value(self, op, value):
if op in ('set', 'unset'):
return [self.field.to_mongo(v) for v in value]
return self.field.to_mongo(value) return self.field.to_mongo(value)
def lookup_member(self, member_name): def lookup_member(self, member_name):
@ -273,7 +276,7 @@ class ReferenceField(BaseField):
collection = self.document_type._meta['collection'] collection = self.document_type._meta['collection']
return pymongo.dbref.DBRef(collection, id_) return pymongo.dbref.DBRef(collection, id_)
def prepare_query_value(self, value): def prepare_query_value(self, op, value):
return self.to_mongo(value) return self.to_mongo(value)
def validate(self, value): def validate(self, value):

View File

@ -284,10 +284,10 @@ class QuerySet(object):
# Convert value to proper value # Convert value to proper value
field = fields[-1] field = fields[-1]
if op in (None, 'ne', 'gt', 'gte', 'lt', 'lte'): if op in (None, 'ne', 'gt', 'gte', 'lt', 'lte'):
value = field.prepare_query_value(value) value = field.prepare_query_value(op, value)
elif op in ('in', 'nin', 'all'): elif op in ('in', 'nin', 'all'):
# 'in', 'nin' and 'all' require a list of values # 'in', 'nin' and 'all' require a list of values
value = [field.prepare_query_value(v) for v in value] value = [field.prepare_query_value(op, v) for v in value]
if op: if op:
value = {'$' + op: value} value = {'$' + op: value}
@ -487,9 +487,9 @@ class QuerySet(object):
# Convert value to proper value # Convert value to proper value
field = fields[-1] field = fields[-1]
if op in (None, 'set', 'unset', 'push', 'pull'): if op in (None, 'set', 'unset', 'push', 'pull'):
value = field.prepare_query_value(value) value = field.prepare_query_value(op, value)
elif op in ('pushAll', 'pullAll'): elif op in ('pushAll', 'pullAll'):
value = [field.prepare_query_value(v) for v in value] value = [field.prepare_query_value(op, v) for v in value]
key = '.'.join(parts) key = '.'.join(parts)

View File

@ -594,6 +594,30 @@ class QuerySetTest(unittest.TestCase):
BlogPost.drop_collection() BlogPost.drop_collection()
def test_update_value_conversion(self):
"""Ensure that values used in updates are converted before use.
"""
class Group(Document):
members = ListField(ReferenceField(self.Person))
Group.drop_collection()
user1 = self.Person(name='user1')
user1.save()
user2 = self.Person(name='user2')
user2.save()
group = Group()
group.save()
Group.objects(id=group.id).update(set__members=[user1, user2])
group.reload()
self.assertTrue(len(group.members) == 2)
self.assertEqual(group.members[0].name, user1.name)
self.assertEqual(group.members[1].name, user2.name)
Group.drop_collection()
def test_types_index(self): def test_types_index(self):
"""Ensure that and index is used when '_types' is being used in a """Ensure that and index is used when '_types' is being used in a