Fixes tiny documentation error. Adds possibility to add custom validation methods to fields, e. g.:

class Customer(Document):
        country = StringField(validation=lambda value: value in ['DE', 'AT', 'CH'])

Replaced some str() with unicode() for i18n reasons.
This commit is contained in:
Florian Schlachter 2010-04-16 16:59:34 +02:00
parent 0a074e52e0
commit 48facec524
4 changed files with 43 additions and 17 deletions

View File

@ -22,7 +22,7 @@ objects** as class attributes to the document class::
class Page(Document): class Page(Document):
title = StringField(max_length=200, required=True) title = StringField(max_length=200, required=True)
date_modified = DateTimeField(default=datetime.now) date_modified = DateTimeField(default=datetime.datetime.now)
Fields Fields
====== ======

View File

@ -24,7 +24,7 @@ class BaseField(object):
_index_with_types = True _index_with_types = True
def __init__(self, db_field=None, name=None, required=False, default=None, def __init__(self, db_field=None, name=None, required=False, default=None,
unique=False, unique_with=None, primary_key=False): unique=False, unique_with=None, primary_key=False, validation=None):
self.db_field = (db_field or name) if not primary_key else '_id' self.db_field = (db_field or name) if not primary_key else '_id'
if name: if name:
import warnings import warnings
@ -36,6 +36,7 @@ class BaseField(object):
self.unique = bool(unique or unique_with) self.unique = bool(unique or unique_with)
self.unique_with = unique_with self.unique_with = unique_with
self.primary_key = primary_key self.primary_key = primary_key
self.validation = validation
def __get__(self, instance, owner): def __get__(self, instance, owner):
"""Descriptor for retrieving a value from a field in a document. Do """Descriptor for retrieving a value from a field in a document. Do
@ -77,8 +78,8 @@ class BaseField(object):
def validate(self, value): def validate(self, value):
"""Perform validation on a value. """Perform validation on a value.
""" """
pass if self.validation is not None and not self.validation(value):
raise ValidationError('Value does not match custom validation method.')
class ObjectIdField(BaseField): class ObjectIdField(BaseField):
"""An field wrapper around MongoDB's ObjectIds. """An field wrapper around MongoDB's ObjectIds.
@ -91,10 +92,10 @@ class ObjectIdField(BaseField):
def to_mongo(self, value): def to_mongo(self, value):
if not isinstance(value, pymongo.objectid.ObjectId): if not isinstance(value, pymongo.objectid.ObjectId):
try: try:
return pymongo.objectid.ObjectId(str(value)) return pymongo.objectid.ObjectId(unicode(value))
except Exception, e: except Exception, e:
#e.message attribute has been deprecated since Python 2.6 #e.message attribute has been deprecated since Python 2.6
raise ValidationError(str(e)) raise ValidationError(unicode(e))
return value return value
def prepare_query_value(self, op, value): def prepare_query_value(self, op, value):
@ -102,7 +103,7 @@ class ObjectIdField(BaseField):
def validate(self, value): def validate(self, value):
try: try:
pymongo.objectid.ObjectId(str(value)) pymongo.objectid.ObjectId(unicode(value))
except: except:
raise ValidationError('Invalid Object ID') raise ValidationError('Invalid Object ID')
@ -402,7 +403,7 @@ class BaseDocument(object):
# class if unavailable # class if unavailable
class_name = son.get(u'_cls', cls._class_name) class_name = son.get(u'_cls', cls._class_name)
data = dict((str(key), value) for key, value in son.items()) data = dict((unicode(key), value) for key, value in son.items())
if '_types' in data: if '_types' in data:
del data['_types'] del data['_types']

View File

@ -39,6 +39,8 @@ class StringField(BaseField):
if self.regex is not None and self.regex.match(value) is None: if self.regex is not None and self.regex.match(value) is None:
message = 'String value did not match validation regex' message = 'String value did not match validation regex'
raise ValidationError(message) raise ValidationError(message)
super(StringField, self).validate(value)
def lookup_member(self, member_name): def lookup_member(self, member_name):
return None return None
@ -93,6 +95,8 @@ class URLField(StringField):
except Exception, e: except Exception, e:
message = 'This URL appears to be a broken link: %s' % e message = 'This URL appears to be a broken link: %s' % e
raise ValidationError(message) raise ValidationError(message)
super(URLField, self).validate(value)
class EmailField(StringField): class EmailField(StringField):
"""A field that validates input as an E-Mail-Address. """A field that validates input as an E-Mail-Address.
@ -130,7 +134,8 @@ class IntField(BaseField):
if self.max_value is not None and value > self.max_value: if self.max_value is not None and value > self.max_value:
raise ValidationError('Integer value is too large') raise ValidationError('Integer value is too large')
super(IntField, self).validate(value)
class FloatField(BaseField): class FloatField(BaseField):
"""An floating point number field. """An floating point number field.
@ -153,6 +158,8 @@ class FloatField(BaseField):
if self.max_value is not None and value > self.max_value: if self.max_value is not None and value > self.max_value:
raise ValidationError('Float value is too large') raise ValidationError('Float value is too large')
super(FloatField, self).validate(value)
class DecimalField(BaseField): class DecimalField(BaseField):
@ -187,6 +194,8 @@ class DecimalField(BaseField):
if self.max_value is not None and value > self.max_value: if self.max_value is not None and value > self.max_value:
raise ValidationError('Decimal value is too large') raise ValidationError('Decimal value is too large')
super(DecimalField, self).validate(value)
class BooleanField(BaseField): class BooleanField(BaseField):
@ -200,6 +209,8 @@ class BooleanField(BaseField):
def validate(self, value): def validate(self, value):
assert isinstance(value, bool) assert isinstance(value, bool)
super(BooleanField, self).validate(value)
class DateTimeField(BaseField): class DateTimeField(BaseField):
@ -208,6 +219,8 @@ class DateTimeField(BaseField):
def validate(self, value): def validate(self, value):
assert isinstance(value, datetime.datetime) assert isinstance(value, datetime.datetime)
super(DateTimeField, self).validate(value)
class EmbeddedDocumentField(BaseField): class EmbeddedDocumentField(BaseField):
@ -239,6 +252,8 @@ class EmbeddedDocumentField(BaseField):
raise ValidationError('Invalid embedded document instance ' raise ValidationError('Invalid embedded document instance '
'provided to an EmbeddedDocumentField') 'provided to an EmbeddedDocumentField')
self.document.validate(value) self.document.validate(value)
super(EmbeddedDocumentField, self).validate(value)
def lookup_member(self, member_name): def lookup_member(self, member_name):
return self.document._fields.get(member_name) return self.document._fields.get(member_name)
@ -315,6 +330,8 @@ class ListField(BaseField):
[self.field.validate(item) for item in value] [self.field.validate(item) for item in value]
except Exception, err: except Exception, err:
raise ValidationError('Invalid ListField item (%s)' % str(err)) raise ValidationError('Invalid ListField item (%s)' % str(err))
super(ListField, self).validate(value)
def prepare_query_value(self, op, value): def prepare_query_value(self, op, value):
if op in ('set', 'unset'): if op in ('set', 'unset'):
@ -359,6 +376,8 @@ class DictField(BaseField):
if any(('.' in k or '$' in k) for k in value): if any(('.' in k or '$' in k) for k in value):
raise ValidationError('Invalid dictionary key name - keys may not ' raise ValidationError('Invalid dictionary key name - keys may not '
'contain "." or "$" characters') 'contain "." or "$" characters')
super(DictField, self).validate(value)
def lookup_member(self, member_name): def lookup_member(self, member_name):
return BaseField(db_field=member_name) return BaseField(db_field=member_name)
@ -374,6 +393,8 @@ class GeoLocationField(DictField):
if len(value) <> 2: if len(value) <> 2:
raise ValidationError('GeoLocationField must have exactly two elements (x, y)') raise ValidationError('GeoLocationField must have exactly two elements (x, y)')
super(GeoLocationField, self).validate(value)
def to_mongo(self, value): def to_mongo(self, value):
return {'x': value[0], 'y': value[1]} return {'x': value[0], 'y': value[1]}
@ -443,6 +464,8 @@ class ReferenceField(BaseField):
def validate(self, value): def validate(self, value):
assert isinstance(value, (self.document_type, pymongo.dbref.DBRef)) assert isinstance(value, (self.document_type, pymongo.dbref.DBRef))
super(ReferenceField, self).validate(value)
def lookup_member(self, member_name): def lookup_member(self, member_name):
return self.document_type._fields.get(member_name) return self.document_type._fields.get(member_name)
@ -506,10 +529,12 @@ class BinaryField(BaseField):
return pymongo.binary.Binary(value) return pymongo.binary.Binary(value)
def to_python(self, value): def to_python(self, value):
return str(value) return unicode(value)
def validate(self, value): def validate(self, value):
assert isinstance(value, str) assert isinstance(value, str)
if self.max_bytes is not None and len(value) > self.max_bytes: if self.max_bytes is not None and len(value) > self.max_bytes:
raise ValidationError('Binary value is too long') raise ValidationError('Binary value is too long')
super(BinaryField, self).validate(value)

View File

@ -123,7 +123,7 @@ class Q(object):
# Comparing two ObjectIds in Javascript doesn't work.. # Comparing two ObjectIds in Javascript doesn't work..
if isinstance(value, pymongo.objectid.ObjectId): if isinstance(value, pymongo.objectid.ObjectId):
value = str(value) value = unicode(value)
# Perform the substitution # Perform the substitution
operation_js = op_js % { operation_js = op_js % {
@ -497,13 +497,13 @@ class QuerySet(object):
map_f_scope = {} map_f_scope = {}
if isinstance(map_f, pymongo.code.Code): if isinstance(map_f, pymongo.code.Code):
map_f_scope = map_f.scope map_f_scope = map_f.scope
map_f = str(map_f) map_f = unicode(map_f)
map_f = pymongo.code.Code(self._sub_js_fields(map_f), map_f_scope) map_f = pymongo.code.Code(self._sub_js_fields(map_f), map_f_scope)
reduce_f_scope = {} reduce_f_scope = {}
if isinstance(reduce_f, pymongo.code.Code): if isinstance(reduce_f, pymongo.code.Code):
reduce_f_scope = reduce_f.scope reduce_f_scope = reduce_f.scope
reduce_f = str(reduce_f) reduce_f = unicode(reduce_f)
reduce_f_code = self._sub_js_fields(reduce_f) reduce_f_code = self._sub_js_fields(reduce_f)
reduce_f = pymongo.code.Code(reduce_f_code, reduce_f_scope) reduce_f = pymongo.code.Code(reduce_f_code, reduce_f_scope)
@ -513,7 +513,7 @@ class QuerySet(object):
finalize_f_scope = {} finalize_f_scope = {}
if isinstance(finalize_f, pymongo.code.Code): if isinstance(finalize_f, pymongo.code.Code):
finalize_f_scope = finalize_f.scope finalize_f_scope = finalize_f.scope
finalize_f = str(finalize_f) finalize_f = unicode(finalize_f)
finalize_f_code = self._sub_js_fields(finalize_f) finalize_f_code = self._sub_js_fields(finalize_f)
finalize_f = pymongo.code.Code(finalize_f_code, finalize_f_scope) finalize_f = pymongo.code.Code(finalize_f_code, finalize_f_scope)
mr_args['finalize'] = finalize_f mr_args['finalize'] = finalize_f
@ -736,7 +736,7 @@ class QuerySet(object):
# Older versions of PyMongo don't support 'multi' # Older versions of PyMongo don't support 'multi'
self._collection.update(self._query, update, safe=safe_update) self._collection.update(self._query, update, safe=safe_update)
except pymongo.errors.OperationFailure, e: except pymongo.errors.OperationFailure, e:
raise OperationError('Update failed [%s]' % str(e)) raise OperationError(u'Update failed [%s]' % unicode(e))
def __iter__(self): def __iter__(self):
return self return self
@ -752,9 +752,9 @@ class QuerySet(object):
field_name = match.group(1).split('.') field_name = match.group(1).split('.')
fields = QuerySet._lookup_field(self._document, field_name) fields = QuerySet._lookup_field(self._document, field_name)
# Substitute the correct name for the field into the javascript # Substitute the correct name for the field into the javascript
return '["%s"]' % fields[-1].db_field return u'["%s"]' % fields[-1].db_field
return re.sub('\[\s*~([A-z_][A-z_0-9.]+?)\s*\]', field_sub, code) return re.sub(u'\[\s*~([A-z_][A-z_0-9.]+?)\s*\]', field_sub, code)
def exec_js(self, code, *fields, **options): def exec_js(self, code, *fields, **options):
"""Execute a Javascript function on the server. A list of fields may be """Execute a Javascript function on the server. A list of fields may be