diff --git a/mongoengine/base.py b/mongoengine/base.py index e5d8d002..61deb409 100644 --- a/mongoengine/base.py +++ b/mongoengine/base.py @@ -42,26 +42,25 @@ class BaseField(object): """ if value is not None: try: - value = self._to_python(value) - self._validate(value) - except (ValueError, AttributeError, AssertionError): + self.validate(value) + except (ValueError, AttributeError, AssertionError), e: raise ValidationError('Invalid value for field of type "' + self.__class__.__name__ + '"') elif self.required: raise ValidationError('Field "%s" is required' % self.name) instance._data[self.name] = value - def _to_python(self, value): + def to_python(self, value): """Convert a MongoDB-compatible type to a Python type. """ - return unicode(value) + return value - def _to_mongo(self, value): + def to_mongo(self, value): """Convert a Python type to a MongoDB-compatible type. """ - return self._to_python(value) + return self.to_python(value) - def _validate(self, value): + def validate(self, value): """Perform validation on a value. """ pass @@ -71,15 +70,15 @@ class ObjectIdField(BaseField): """An field wrapper around MongoDB's ObjectIds. """ - def _to_python(self, value): + def to_python(self, value): return str(value) - def _to_mongo(self, value): + def to_mongo(self, value): if not isinstance(value, pymongo.objectid.ObjectId): return pymongo.objectid.ObjectId(value) return value - def _validate(self, value): + def validate(self, value): try: pymongo.objectid.ObjectId(str(value)) except: @@ -218,14 +217,14 @@ class BaseDocument(object): def __len__(self): return len(self._data) - def _to_mongo(self): + def to_mongo(self): """Return data dictionary ready for use with MongoDB. """ data = {} for field_name, field in self._fields.items(): value = getattr(self, field_name, None) if value is not None: - data[field_name] = field._to_mongo(value) + data[field_name] = field.to_mongo(value) data['_cls'] = self._class_name data['_types'] = self._superclasses.keys() + [self._class_name] return data @@ -246,5 +245,9 @@ class BaseDocument(object): # that has been queried to return this SON return None cls = subclasses[class_name] - return cls(**data) + for field_name, field in cls._fields.items(): + if field_name in data: + data[field_name] = field.to_python(data[field_name]) + + return cls(**data) diff --git a/mongoengine/document.py b/mongoengine/document.py index ab6b74a8..359a1789 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -18,7 +18,7 @@ class Document(BaseDocument): """Save the document to the database. If the document already exists, it will be updated, otherwise it will be created. """ - _id = self.objects._collection.save(self._to_mongo()) + _id = self.objects._collection.save(self.to_mongo()) self._id = _id @classmethod diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 70e58be8..71f54539 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -19,11 +19,10 @@ class StringField(BaseField): self.max_length = max_length super(StringField, self).__init__(**kwargs) - def _to_python(self, value): - assert(isinstance(value, (str, unicode))) + def to_python(self, value): return unicode(value) - def _validate(self, value): + def validate(self, value): assert(isinstance(value, (str, unicode))) if self.max_length is not None and len(value) > self.max_length: @@ -42,12 +41,11 @@ class IntField(BaseField): self.min_value, self.max_value = min_value, max_value super(IntField, self).__init__(**kwargs) - def _to_python(self, value): - assert(isinstance(value, int)) + def to_python(self, value): return int(value) - def _validate(self, value): - assert(isinstance(value, int)) + def validate(self, value): + assert(isinstance(value, (int, long))) if self.min_value is not None and value < self.min_value: raise ValidationError('Integer value is too small') @@ -68,16 +66,15 @@ class EmbeddedDocumentField(BaseField): self.document = document super(EmbeddedDocumentField, self).__init__(**kwargs) - def _to_python(self, value): + def to_python(self, value): if not isinstance(value, self.document): - assert(isinstance(value, (dict, pymongo.son.SON))) return self.document._from_son(value) return value - def _to_mongo(self, value): - return self.document._to_mongo(value) + def to_mongo(self, value): + return self.document.to_mongo(value) - def _validate(self, value): + def validate(self, value): """Make sure that the document instance is an instance of the EmbeddedDocument subclass provided when the document was defined. """ @@ -99,14 +96,13 @@ class ListField(BaseField): self.field = field super(ListField, self).__init__(**kwargs) - def _to_python(self, value): - assert(isinstance(value, (list, tuple))) - return [self.field._to_python(item) for item in value] + def to_python(self, value): + return [self.field.to_python(item) for item in value] - def _to_mongo(self, value): - return [self.field._to_mongo(item) for item in value] + def to_mongo(self, value): + return [self.field.to_mongo(item) for item in value] - def _validate(self, value): + def validate(self, value): """Make sure that a list of valid fields is being used. """ if not isinstance(value, (list, tuple)): @@ -114,7 +110,7 @@ class ListField(BaseField): 'list field') try: - [self.field._validate(item) for item in value] + [self.field.validate(item) for item in value] except: raise ValidationError('All items in a list field must be of the ' 'specified type') @@ -150,11 +146,7 @@ class ReferenceField(BaseField): return super(ReferenceField, self).__get__(instance, owner) - def _to_python(self, document): - assert(isinstance(document, (self.document_type, pymongo.dbref.DBRef))) - return document - - def _to_mongo(self, document): + def to_mongo(self, document): if isinstance(document, (str, unicode, pymongo.objectid.ObjectId)): _id = document else: @@ -169,3 +161,6 @@ class ReferenceField(BaseField): collection = self.document_type._meta['collection'] return pymongo.dbref.DBRef(collection, _id) + + def validate(self, value): + assert(isinstance(value, (self.document_type, pymongo.dbref.DBRef))) diff --git a/tests/document.py b/tests/document.py index 4842934c..3dc91f36 100644 --- a/tests/document.py +++ b/tests/document.py @@ -174,7 +174,7 @@ class DocumentTest(unittest.TestCase): person_obj = collection.find_one({'name': 'Test User'}) self.assertEqual(person_obj['name'], 'Test User') self.assertEqual(person_obj['age'], 30) - self.assertEqual(str(person_obj['_id']), person._id) + self.assertEqual(person_obj['_id'], person._id) def test_save_custom_id(self): """Ensure that a document may be saved with a custom _id.