merged master, fixed 1 merge conflict

This commit is contained in:
blackbrrr 2010-02-14 17:23:38 -06:00
commit 348f7b5dfc
6 changed files with 88 additions and 9 deletions

View File

@ -55,7 +55,7 @@ class BaseField(object):
""" """
return self.to_python(value) return self.to_python(value)
def prepare_query_value(self, value): def prepare_query_value(self, op, value):
"""Prepare a value that is being used in a query for PyMongo. """Prepare a value that is being used in a query for PyMongo.
""" """
return value return value
@ -82,7 +82,7 @@ class ObjectIdField(BaseField):
raise ValidationError(e.message) raise ValidationError(e.message)
return value return value
def prepare_query_value(self, value): def prepare_query_value(self, op, value):
return self.to_mongo(value) return self.to_mongo(value)
def validate(self, value): def validate(self, value):

View File

@ -59,3 +59,4 @@ def connect(db, username=None, password=None, **kwargs):
_db_name = db _db_name = db
_db_username = username _db_username = username
_db_password = password _db_password = password
return _get_db()

View File

@ -188,7 +188,7 @@ class EmbeddedDocumentField(BaseField):
def lookup_member(self, member_name): def lookup_member(self, member_name):
return self.document._fields.get(member_name) return self.document._fields.get(member_name)
def prepare_query_value(self, value): def prepare_query_value(self, op, value):
return self.to_mongo(value) return self.to_mongo(value)
@ -207,6 +207,30 @@ class ListField(BaseField):
self.field = field self.field = field
super(ListField, self).__init__(**kwargs) super(ListField, self).__init__(**kwargs)
def __get__(self, instance, owner):
"""Descriptor to automatically dereference references.
"""
if instance is None:
# Document class being used rather than a document object
return self
if isinstance(self.field, ReferenceField):
referenced_type = self.field.document_type
# Get value from document instance if available
value_list = instance._data.get(self.name)
if value_list:
deref_list = []
for value in value_list:
# Dereference DBRefs
if isinstance(value, (pymongo.dbref.DBRef)):
value = _get_db().dereference(value)
deref_list.append(referenced_type._from_son(value))
else:
deref_list.append(value)
instance._data[self.name] = deref_list
return super(ListField, self).__get__(instance, owner)
def to_python(self, value): def to_python(self, value):
return [self.field.to_python(item) for item in value] return [self.field.to_python(item) for item in value]
@ -226,7 +250,9 @@ class ListField(BaseField):
raise ValidationError('All items in a list field must be of the ' raise ValidationError('All items in a list field must be of the '
'specified type') 'specified type')
def prepare_query_value(self, value): def prepare_query_value(self, op, value):
if op in ('set', 'unset'):
return [self.field.to_mongo(v) for v in value]
return self.field.to_mongo(value) return self.field.to_mongo(value)
def lookup_member(self, member_name): def lookup_member(self, member_name):
@ -302,7 +328,7 @@ class ReferenceField(BaseField):
collection = self.document_type._meta['collection'] collection = self.document_type._meta['collection']
return pymongo.dbref.DBRef(collection, id_) return pymongo.dbref.DBRef(collection, id_)
def prepare_query_value(self, value): def prepare_query_value(self, op, value):
return self.to_mongo(value) return self.to_mongo(value)
def validate(self, value): def validate(self, value):

View File

@ -284,10 +284,10 @@ class QuerySet(object):
# Convert value to proper value # Convert value to proper value
field = fields[-1] field = fields[-1]
if op in (None, 'ne', 'gt', 'gte', 'lt', 'lte'): if op in (None, 'ne', 'gt', 'gte', 'lt', 'lte'):
value = field.prepare_query_value(value) value = field.prepare_query_value(op, value)
elif op in ('in', 'nin', 'all'): elif op in ('in', 'nin', 'all'):
# 'in', 'nin' and 'all' require a list of values # 'in', 'nin' and 'all' require a list of values
value = [field.prepare_query_value(v) for v in value] value = [field.prepare_query_value(op, v) for v in value]
if op: if op:
value = {'$' + op: value} value = {'$' + op: value}
@ -487,9 +487,9 @@ class QuerySet(object):
# Convert value to proper value # Convert value to proper value
field = fields[-1] field = fields[-1]
if op in (None, 'set', 'unset', 'push', 'pull'): if op in (None, 'set', 'unset', 'push', 'pull'):
value = field.prepare_query_value(value) value = field.prepare_query_value(op, value)
elif op in ('pushAll', 'pullAll'): elif op in ('pushAll', 'pullAll'):
value = [field.prepare_query_value(v) for v in value] value = [field.prepare_query_value(op, v) for v in value]
key = '.'.join(parts) key = '.'.join(parts)

View File

@ -287,6 +287,34 @@ class FieldTest(unittest.TestCase):
User.drop_collection() User.drop_collection()
BlogPost.drop_collection() BlogPost.drop_collection()
def test_list_item_dereference(self):
"""Ensure that DBRef items in ListFields are dereferenced.
"""
class User(Document):
name = StringField()
class Group(Document):
members = ListField(ReferenceField(User))
User.drop_collection()
Group.drop_collection()
user1 = User(name='user1')
user1.save()
user2 = User(name='user2')
user2.save()
group = Group(members=[user1, user2])
group.save()
group_obj = Group.objects.first()
self.assertEqual(group_obj.members[0].name, user1.name)
self.assertEqual(group_obj.members[1].name, user2.name)
User.drop_collection()
Group.drop_collection()
def test_reference_query_conversion(self): def test_reference_query_conversion(self):
"""Ensure that ReferenceFields can be queried using objects and values """Ensure that ReferenceFields can be queried using objects and values
of the type of the primary key of the referenced object. of the type of the primary key of the referenced object.

View File

@ -594,6 +594,30 @@ class QuerySetTest(unittest.TestCase):
BlogPost.drop_collection() BlogPost.drop_collection()
def test_update_value_conversion(self):
"""Ensure that values used in updates are converted before use.
"""
class Group(Document):
members = ListField(ReferenceField(self.Person))
Group.drop_collection()
user1 = self.Person(name='user1')
user1.save()
user2 = self.Person(name='user2')
user2.save()
group = Group()
group.save()
Group.objects(id=group.id).update(set__members=[user1, user2])
group.reload()
self.assertTrue(len(group.members) == 2)
self.assertEqual(group.members[0].name, user1.name)
self.assertEqual(group.members[1].name, user2.name)
Group.drop_collection()
def test_types_index(self): def test_types_index(self):
"""Ensure that and index is used when '_types' is being used in a """Ensure that and index is used when '_types' is being used in a