Fixed ListField bug, added ReferenceField + tests
This commit is contained in:
		| @@ -1,12 +1,13 @@ | |||||||
| from base import BaseField, ObjectIdField, ValidationError | from base import BaseField, ObjectIdField, ValidationError | ||||||
| from document import EmbeddedDocument | from document import Document, EmbeddedDocument | ||||||
|  | from connection import _get_db | ||||||
|   |   | ||||||
| import re | import re | ||||||
| import pymongo | import pymongo | ||||||
|  |  | ||||||
|  |  | ||||||
| __all__ = ['StringField', 'IntField', 'EmbeddedDocumentField', 'ListField', | __all__ = ['StringField', 'IntField', 'EmbeddedDocumentField', 'ListField', | ||||||
|            'ObjectIdField', 'ValidationError'] |            'ObjectIdField', 'ReferenceField', 'ValidationError'] | ||||||
|  |  | ||||||
|  |  | ||||||
| class StringField(BaseField): | class StringField(BaseField): | ||||||
| @@ -100,7 +101,7 @@ class ListField(BaseField): | |||||||
|  |  | ||||||
|     def _to_python(self, value): |     def _to_python(self, value): | ||||||
|         assert(isinstance(value, (list, tuple))) |         assert(isinstance(value, (list, tuple))) | ||||||
|         return list(value) |         return [self.field._to_python(item) for item in value] | ||||||
|  |  | ||||||
|     def _to_mongo(self, value): |     def _to_mongo(self, value): | ||||||
|         return [self.field._to_mongo(item) for item in value] |         return [self.field._to_mongo(item) for item in value] | ||||||
| @@ -117,3 +118,53 @@ class ListField(BaseField): | |||||||
|         except: |         except: | ||||||
|             raise ValidationError('All items in a list field must be of the ' |             raise ValidationError('All items in a list field must be of the ' | ||||||
|                                   'specified type') |                                   'specified type') | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class ReferenceField(BaseField): | ||||||
|  |     """A reference to a document that will be automatically dereferenced on | ||||||
|  |     access (lazily). | ||||||
|  |     """ | ||||||
|  |  | ||||||
|  |     def __init__(self, document_type, **kwargs): | ||||||
|  |         if not issubclass(document_type, Document): | ||||||
|  |             raise ValidationError('Argument to ReferenceField constructor ' | ||||||
|  |                                   'must be a top level document class') | ||||||
|  |         self.document_type = document_type | ||||||
|  |         self.document_obj = None | ||||||
|  |         super(ReferenceField, self).__init__(**kwargs) | ||||||
|  |  | ||||||
|  |     def __get__(self, instance, owner): | ||||||
|  |         """Descriptor to allow lazy dereferencing. | ||||||
|  |         """ | ||||||
|  |         if instance is None: | ||||||
|  |             # Document class being used rather than a document object | ||||||
|  |             return self | ||||||
|  |  | ||||||
|  |         # Get value from document instance if available | ||||||
|  |         value = instance._data.get(self.name) | ||||||
|  |         # Dereference DBRefs | ||||||
|  |         if isinstance(value, (pymongo.dbref.DBRef)): | ||||||
|  |             value = _get_db().dereference(value) | ||||||
|  |             instance._data[self.name] = self.document_type._from_son(value) | ||||||
|  |          | ||||||
|  |         return super(ReferenceField, self).__get__(instance, owner) | ||||||
|  |  | ||||||
|  |     def _to_python(self, document): | ||||||
|  |         assert(isinstance(document, (self.document_type, pymongo.dbref.DBRef))) | ||||||
|  |         return document | ||||||
|  |  | ||||||
|  |     def _to_mongo(self, document): | ||||||
|  |         if isinstance(document, (str, unicode, pymongo.objectid.ObjectId)): | ||||||
|  |             _id = document | ||||||
|  |         else: | ||||||
|  |             try: | ||||||
|  |                 _id = document._id | ||||||
|  |             except: | ||||||
|  |                 raise ValidationError('You can only reference documents once ' | ||||||
|  |                                       'they have been saved to the database') | ||||||
|  |  | ||||||
|  |         if not isinstance(_id, pymongo.objectid.ObjectId): | ||||||
|  |             _id = pymongo.objectid.ObjectId(_id) | ||||||
|  |  | ||||||
|  |         collection = self.document_type._meta['collection'] | ||||||
|  |         return pymongo.dbref.DBRef(collection, _id) | ||||||
|   | |||||||
| @@ -9,7 +9,7 @@ from mongomap import * | |||||||
| class CollectionManagerTest(unittest.TestCase): | class CollectionManagerTest(unittest.TestCase): | ||||||
|      |      | ||||||
|     def setUp(self): |     def setUp(self): | ||||||
|         connect(db='mongotest') |         connect(db='mongomaptest') | ||||||
|  |  | ||||||
|         class Person(Document): |         class Person(Document): | ||||||
|             name = StringField() |             name = StringField() | ||||||
| @@ -54,6 +54,8 @@ class CollectionManagerTest(unittest.TestCase): | |||||||
|         self.assertEqual(people.count(), 2) |         self.assertEqual(people.count(), 2) | ||||||
|         results = list(people) |         results = list(people) | ||||||
|         self.assertTrue(isinstance(results[0], self.Person)) |         self.assertTrue(isinstance(results[0], self.Person)) | ||||||
|  |         self.assertTrue(isinstance(results[0]._id, (pymongo.objectid.ObjectId, | ||||||
|  |                                                     str, unicode))) | ||||||
|         self.assertEqual(results[0].name, "User A") |         self.assertEqual(results[0].name, "User A") | ||||||
|         self.assertEqual(results[0].age, 20) |         self.assertEqual(results[0].age, 20) | ||||||
|         self.assertEqual(results[1].name, "User B") |         self.assertEqual(results[1].name, "User B") | ||||||
|   | |||||||
| @@ -231,6 +231,36 @@ class DocumentTest(unittest.TestCase): | |||||||
|         # Ensure that the 'details' embedded object saved correctly |         # Ensure that the 'details' embedded object saved correctly | ||||||
|         self.assertEqual(employee_obj['details']['position'], 'Developer') |         self.assertEqual(employee_obj['details']['position'], 'Developer') | ||||||
|  |  | ||||||
|  |     def test_save_reference(self): | ||||||
|  |         """Ensure that a document reference field may be saved in the database. | ||||||
|  |         """ | ||||||
|  |          | ||||||
|  |         class BlogPost(Document): | ||||||
|  |             meta = {'collection': 'blogpost_1'} | ||||||
|  |             content = StringField() | ||||||
|  |             author = ReferenceField(self.Person) | ||||||
|  |  | ||||||
|  |         self.db.drop_collection(BlogPost._meta['collection']) | ||||||
|  |  | ||||||
|  |         author = self.Person(name='Test User') | ||||||
|  |         author.save() | ||||||
|  |  | ||||||
|  |         post = BlogPost(content='Watched some TV today... how exciting.') | ||||||
|  |         post.author = author | ||||||
|  |         post.save() | ||||||
|  |  | ||||||
|  |         post_obj = BlogPost.objects.find_one() | ||||||
|  |         self.assertTrue(isinstance(post_obj.author, self.Person)) | ||||||
|  |         self.assertEqual(post_obj.author.name, 'Test User') | ||||||
|  |  | ||||||
|  |         post_obj.author.age = 25 | ||||||
|  |         post_obj.author.save() | ||||||
|  |  | ||||||
|  |         author = self.Person.objects.find_one(name='Test User') | ||||||
|  |         self.assertEqual(author.age, 25) | ||||||
|  |  | ||||||
|  |         self.db.drop_collection(BlogPost._meta['collection']) | ||||||
|  |  | ||||||
|     def tearDown(self): |     def tearDown(self): | ||||||
|         self.db.drop_collection(self.Person._meta['collection']) |         self.db.drop_collection(self.Person._meta['collection']) | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1,12 +1,14 @@ | |||||||
| import unittest | import unittest | ||||||
|  |  | ||||||
| from mongomap import * | from mongomap import * | ||||||
|  | from mongomap.connection import _get_db | ||||||
|  |  | ||||||
|  |  | ||||||
| class FieldTest(unittest.TestCase): | class FieldTest(unittest.TestCase): | ||||||
|  |  | ||||||
|     def setUp(self): |     def setUp(self): | ||||||
|         connect(db='mongomaptest') |         connect(db='mongomaptest') | ||||||
|  |         self.db = _get_db() | ||||||
|  |  | ||||||
|     def test_default_values(self): |     def test_default_values(self): | ||||||
|         """Ensure that default field values are used when creating a document. |         """Ensure that default field values are used when creating a document. | ||||||
| @@ -146,6 +148,38 @@ class FieldTest(unittest.TestCase): | |||||||
|         post.author = User(name='Test User') |         post.author = User(name='Test User') | ||||||
|         post.author = PowerUser(name='Test User', power=47) |         post.author = PowerUser(name='Test User', power=47) | ||||||
|  |  | ||||||
|  |     def test_reference_validation(self): | ||||||
|  |         """Ensure that invalid embedded documents cannot be assigned to | ||||||
|  |         embedded document fields. | ||||||
|  |         """ | ||||||
|  |         class User(Document): | ||||||
|  |             name = StringField() | ||||||
|  |  | ||||||
|  |         class BlogPost(Document): | ||||||
|  |             content = StringField() | ||||||
|  |             author = ReferenceField(User) | ||||||
|  |  | ||||||
|  |         self.assertRaises(ValidationError, ReferenceField, EmbeddedDocument) | ||||||
|  |  | ||||||
|  |         user = User(name='Test User') | ||||||
|  |  | ||||||
|  |         post1 = BlogPost(content='Chips and gravy taste good.') | ||||||
|  |         post1.author = user | ||||||
|  |         self.assertRaises(ValidationError, post1.save) | ||||||
|  |  | ||||||
|  |         post2 = BlogPost(content='Chips and chilli taste good.') | ||||||
|  |         self.assertRaises(ValidationError, post1.__setattr__, 'author', post2) | ||||||
|  |  | ||||||
|  |         user.save() | ||||||
|  |         post1.author = user | ||||||
|  |         post1.save() | ||||||
|  |  | ||||||
|  |         post2.save() | ||||||
|  |         self.assertRaises(ValidationError, post1.__setattr__, 'author', post2) | ||||||
|  |  | ||||||
|  |         self.db.drop_collection(User._meta['collection']) | ||||||
|  |         self.db.drop_collection(BlogPost._meta['collection']) | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|     unittest.main() |     unittest.main() | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user