Fixed set/unset issue with ListFields
This commit is contained in:
parent
a93509c9b3
commit
ea1fe6a538
@ -55,7 +55,7 @@ class BaseField(object):
|
|||||||
"""
|
"""
|
||||||
return self.to_python(value)
|
return self.to_python(value)
|
||||||
|
|
||||||
def prepare_query_value(self, value):
|
def prepare_query_value(self, op, value):
|
||||||
"""Prepare a value that is being used in a query for PyMongo.
|
"""Prepare a value that is being used in a query for PyMongo.
|
||||||
"""
|
"""
|
||||||
return value
|
return value
|
||||||
@ -81,7 +81,7 @@ class ObjectIdField(BaseField):
|
|||||||
raise ValidationError(e.message)
|
raise ValidationError(e.message)
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def prepare_query_value(self, value):
|
def prepare_query_value(self, op, value):
|
||||||
return self.to_mongo(value)
|
return self.to_mongo(value)
|
||||||
|
|
||||||
def validate(self, value):
|
def validate(self, value):
|
||||||
|
@ -136,7 +136,7 @@ class EmbeddedDocumentField(BaseField):
|
|||||||
def lookup_member(self, member_name):
|
def lookup_member(self, member_name):
|
||||||
return self.document._fields.get(member_name)
|
return self.document._fields.get(member_name)
|
||||||
|
|
||||||
def prepare_query_value(self, value):
|
def prepare_query_value(self, op, value):
|
||||||
return self.to_mongo(value)
|
return self.to_mongo(value)
|
||||||
|
|
||||||
|
|
||||||
@ -166,15 +166,16 @@ class ListField(BaseField):
|
|||||||
referenced_type = self.field.document_type
|
referenced_type = self.field.document_type
|
||||||
# Get value from document instance if available
|
# Get value from document instance if available
|
||||||
value_list = instance._data.get(self.name)
|
value_list = instance._data.get(self.name)
|
||||||
deref_list = []
|
if value_list:
|
||||||
for value in value_list:
|
deref_list = []
|
||||||
# Dereference DBRefs
|
for value in value_list:
|
||||||
if isinstance(value, (pymongo.dbref.DBRef)):
|
# Dereference DBRefs
|
||||||
value = _get_db().dereference(value)
|
if isinstance(value, (pymongo.dbref.DBRef)):
|
||||||
deref_list.append(referenced_type._from_son(value))
|
value = _get_db().dereference(value)
|
||||||
else:
|
deref_list.append(referenced_type._from_son(value))
|
||||||
deref_list.append(value)
|
else:
|
||||||
instance._data[self.name] = deref_list
|
deref_list.append(value)
|
||||||
|
instance._data[self.name] = deref_list
|
||||||
|
|
||||||
return super(ListField, self).__get__(instance, owner)
|
return super(ListField, self).__get__(instance, owner)
|
||||||
|
|
||||||
@ -197,7 +198,9 @@ class ListField(BaseField):
|
|||||||
raise ValidationError('All items in a list field must be of the '
|
raise ValidationError('All items in a list field must be of the '
|
||||||
'specified type')
|
'specified type')
|
||||||
|
|
||||||
def prepare_query_value(self, value):
|
def prepare_query_value(self, op, value):
|
||||||
|
if op in ('set', 'unset'):
|
||||||
|
return [self.field.to_mongo(v) for v in value]
|
||||||
return self.field.to_mongo(value)
|
return self.field.to_mongo(value)
|
||||||
|
|
||||||
def lookup_member(self, member_name):
|
def lookup_member(self, member_name):
|
||||||
@ -273,7 +276,7 @@ class ReferenceField(BaseField):
|
|||||||
collection = self.document_type._meta['collection']
|
collection = self.document_type._meta['collection']
|
||||||
return pymongo.dbref.DBRef(collection, id_)
|
return pymongo.dbref.DBRef(collection, id_)
|
||||||
|
|
||||||
def prepare_query_value(self, value):
|
def prepare_query_value(self, op, value):
|
||||||
return self.to_mongo(value)
|
return self.to_mongo(value)
|
||||||
|
|
||||||
def validate(self, value):
|
def validate(self, value):
|
||||||
|
@ -284,10 +284,10 @@ class QuerySet(object):
|
|||||||
# Convert value to proper value
|
# Convert value to proper value
|
||||||
field = fields[-1]
|
field = fields[-1]
|
||||||
if op in (None, 'ne', 'gt', 'gte', 'lt', 'lte'):
|
if op in (None, 'ne', 'gt', 'gte', 'lt', 'lte'):
|
||||||
value = field.prepare_query_value(value)
|
value = field.prepare_query_value(op, value)
|
||||||
elif op in ('in', 'nin', 'all'):
|
elif op in ('in', 'nin', 'all'):
|
||||||
# 'in', 'nin' and 'all' require a list of values
|
# 'in', 'nin' and 'all' require a list of values
|
||||||
value = [field.prepare_query_value(v) for v in value]
|
value = [field.prepare_query_value(op, v) for v in value]
|
||||||
|
|
||||||
if op:
|
if op:
|
||||||
value = {'$' + op: value}
|
value = {'$' + op: value}
|
||||||
@ -487,9 +487,9 @@ class QuerySet(object):
|
|||||||
# Convert value to proper value
|
# Convert value to proper value
|
||||||
field = fields[-1]
|
field = fields[-1]
|
||||||
if op in (None, 'set', 'unset', 'push', 'pull'):
|
if op in (None, 'set', 'unset', 'push', 'pull'):
|
||||||
value = field.prepare_query_value(value)
|
value = field.prepare_query_value(op, value)
|
||||||
elif op in ('pushAll', 'pullAll'):
|
elif op in ('pushAll', 'pullAll'):
|
||||||
value = [field.prepare_query_value(v) for v in value]
|
value = [field.prepare_query_value(op, v) for v in value]
|
||||||
|
|
||||||
key = '.'.join(parts)
|
key = '.'.join(parts)
|
||||||
|
|
||||||
|
@ -594,6 +594,30 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
|
|
||||||
BlogPost.drop_collection()
|
BlogPost.drop_collection()
|
||||||
|
|
||||||
|
def test_update_value_conversion(self):
|
||||||
|
"""Ensure that values used in updates are converted before use.
|
||||||
|
"""
|
||||||
|
class Group(Document):
|
||||||
|
members = ListField(ReferenceField(self.Person))
|
||||||
|
|
||||||
|
Group.drop_collection()
|
||||||
|
|
||||||
|
user1 = self.Person(name='user1')
|
||||||
|
user1.save()
|
||||||
|
user2 = self.Person(name='user2')
|
||||||
|
user2.save()
|
||||||
|
|
||||||
|
group = Group()
|
||||||
|
group.save()
|
||||||
|
|
||||||
|
Group.objects(id=group.id).update(set__members=[user1, user2])
|
||||||
|
group.reload()
|
||||||
|
|
||||||
|
self.assertTrue(len(group.members) == 2)
|
||||||
|
self.assertEqual(group.members[0].name, user1.name)
|
||||||
|
self.assertEqual(group.members[1].name, user2.name)
|
||||||
|
|
||||||
|
Group.drop_collection()
|
||||||
|
|
||||||
def test_types_index(self):
|
def test_types_index(self):
|
||||||
"""Ensure that and index is used when '_types' is being used in a
|
"""Ensure that and index is used when '_types' is being used in a
|
||||||
|
Loading…
x
Reference in New Issue
Block a user