diff --git a/mongomap/__init__.py b/mongomap/__init__.py index 214f0fa4..1b7da2b3 100644 --- a/mongomap/__init__.py +++ b/mongomap/__init__.py @@ -1,8 +1,9 @@ -from document import Document +import document +from document import * import fields from fields import * -__all__ = fields.__all__ + ['Document'] +__all__ = document.__all__ + fields.__all__ __author__ = 'Harry Marr' __version__ = '0.1' diff --git a/mongomap/base.py b/mongomap/base.py index f6c97faa..f4be1fb7 100644 --- a/mongomap/base.py +++ b/mongomap/base.py @@ -161,3 +161,8 @@ class BaseDocument(object): if name not in self._fields: raise KeyError(name) return setattr(self, name, value) + + def _to_mongo(self): + """Return data dictionary ready for use with MongoDB. + """ + return dict((k, v) for k, v in self._data.items() if v is not None) diff --git a/mongomap/document.py b/mongomap/document.py index 3fe2bb31..b1b3ffe3 100644 --- a/mongomap/document.py +++ b/mongomap/document.py @@ -2,6 +2,10 @@ from base import DocumentMetaclass, TopLevelDocumentMetaclass, BaseDocument #import pymongo + +__all__ = ['Document', 'EmbeddedDocument'] + + class EmbeddedDocument(BaseDocument): __metaclass__ = DocumentMetaclass diff --git a/mongomap/fields.py b/mongomap/fields.py index b8fe9693..6235cb6d 100644 --- a/mongomap/fields.py +++ b/mongomap/fields.py @@ -4,7 +4,8 @@ from document import EmbeddedDocument import re -__all__ = ['StringField', 'IntField', 'ValidationError'] +__all__ = ['StringField', 'IntField', 'EmbeddedDocumentField', + 'ValidationError'] class StringField(BaseField): @@ -52,7 +53,15 @@ class EmbeddedDocumentField(BaseField): def __init__(self, document, **kwargs): if not issubclass(document, EmbeddedDocument): - raise ValidationError('Invalid embedded document provided to an ' - 'EmbeddedDocumentField') + raise ValidationError('Invalid embedded document class provided ' + 'to an EmbeddedDocumentField') self.document = document super(EmbeddedDocumentField, self).__init__(**kwargs) + + def _to_python(self, value): + return value + + def _validate(self, value): + if not isinstance(value, self.document): + raise ValidationError('Invalid embedded document instance ' + 'provided to an EmbeddedDocumentField') diff --git a/tests/fields.py b/tests/fields.py index 973ee8db..58342453 100644 --- a/tests/fields.py +++ b/tests/fields.py @@ -65,6 +65,28 @@ class FieldTest(unittest.TestCase): self.assertRaises(ValidationError, person.__setattr__, 'age', 120) self.assertRaises(ValidationError, person.__setattr__, 'age', 'ten') + def test_embedded_document_validation(self): + """Ensure that invalid embedded documents cannot be assigned to + embedded document fields. + """ + class Comment(EmbeddedDocument): + content = StringField() + + class PersonPreferences(EmbeddedDocument): + food = StringField() + number = IntField() + + class Person(Document): + name = StringField() + preferences = EmbeddedDocumentField(PersonPreferences) + + person = Person(name='Test User') + self.assertRaises(ValidationError, person.__setattr__, 'preferences', + 'My preferences') + self.assertRaises(ValidationError, person.__setattr__, 'preferences', + Comment(content='Nice blog post...')) + person.preferences = PersonPreferences(food='Cheese', number=47) + if __name__ == '__main__': unittest.main()