diff --git a/mongomap/fields.py b/mongomap/fields.py index f65aa38a..93b21f1e 100644 --- a/mongomap/fields.py +++ b/mongomap/fields.py @@ -1,12 +1,13 @@ from base import BaseField, ObjectIdField, ValidationError -from document import EmbeddedDocument +from document import Document, EmbeddedDocument +from connection import _get_db import re import pymongo __all__ = ['StringField', 'IntField', 'EmbeddedDocumentField', 'ListField', - 'ObjectIdField', 'ValidationError'] + 'ObjectIdField', 'ReferenceField', 'ValidationError'] class StringField(BaseField): @@ -100,7 +101,7 @@ class ListField(BaseField): def _to_python(self, value): assert(isinstance(value, (list, tuple))) - return list(value) + return [self.field._to_python(item) for item in value] def _to_mongo(self, value): return [self.field._to_mongo(item) for item in value] @@ -117,3 +118,53 @@ class ListField(BaseField): except: raise ValidationError('All items in a list field must be of the ' 'specified type') + + +class ReferenceField(BaseField): + """A reference to a document that will be automatically dereferenced on + access (lazily). + """ + + def __init__(self, document_type, **kwargs): + if not issubclass(document_type, Document): + raise ValidationError('Argument to ReferenceField constructor ' + 'must be a top level document class') + self.document_type = document_type + self.document_obj = None + super(ReferenceField, self).__init__(**kwargs) + + def __get__(self, instance, owner): + """Descriptor to allow lazy dereferencing. + """ + if instance is None: + # Document class being used rather than a document object + return self + + # Get value from document instance if available + value = instance._data.get(self.name) + # Dereference DBRefs + if isinstance(value, (pymongo.dbref.DBRef)): + value = _get_db().dereference(value) + instance._data[self.name] = self.document_type._from_son(value) + + return super(ReferenceField, self).__get__(instance, owner) + + def _to_python(self, document): + assert(isinstance(document, (self.document_type, pymongo.dbref.DBRef))) + return document + + def _to_mongo(self, document): + if isinstance(document, (str, unicode, pymongo.objectid.ObjectId)): + _id = document + else: + try: + _id = document._id + except: + raise ValidationError('You can only reference documents once ' + 'they have been saved to the database') + + if not isinstance(_id, pymongo.objectid.ObjectId): + _id = pymongo.objectid.ObjectId(_id) + + collection = self.document_type._meta['collection'] + return pymongo.dbref.DBRef(collection, _id) diff --git a/tests/collection.py b/tests/collection.py index 1b7f1026..bffc9065 100644 --- a/tests/collection.py +++ b/tests/collection.py @@ -9,7 +9,7 @@ from mongomap import * class CollectionManagerTest(unittest.TestCase): def setUp(self): - connect(db='mongotest') + connect(db='mongomaptest') class Person(Document): name = StringField() @@ -54,6 +54,8 @@ class CollectionManagerTest(unittest.TestCase): self.assertEqual(people.count(), 2) results = list(people) self.assertTrue(isinstance(results[0], self.Person)) + self.assertTrue(isinstance(results[0]._id, (pymongo.objectid.ObjectId, + str, unicode))) self.assertEqual(results[0].name, "User A") self.assertEqual(results[0].age, 20) self.assertEqual(results[1].name, "User B") diff --git a/tests/document.py b/tests/document.py index e389b322..d2f3d8e9 100644 --- a/tests/document.py +++ b/tests/document.py @@ -231,6 +231,36 @@ class DocumentTest(unittest.TestCase): # Ensure that the 'details' embedded object saved correctly self.assertEqual(employee_obj['details']['position'], 'Developer') + def test_save_reference(self): + """Ensure that a document reference field may be saved in the database. + """ + + class BlogPost(Document): + meta = {'collection': 'blogpost_1'} + content = StringField() + author = ReferenceField(self.Person) + + self.db.drop_collection(BlogPost._meta['collection']) + + author = self.Person(name='Test User') + author.save() + + post = BlogPost(content='Watched some TV today... how exciting.') + post.author = author + post.save() + + post_obj = BlogPost.objects.find_one() + self.assertTrue(isinstance(post_obj.author, self.Person)) + self.assertEqual(post_obj.author.name, 'Test User') + + post_obj.author.age = 25 + post_obj.author.save() + + author = self.Person.objects.find_one(name='Test User') + self.assertEqual(author.age, 25) + + self.db.drop_collection(BlogPost._meta['collection']) + def tearDown(self): self.db.drop_collection(self.Person._meta['collection']) diff --git a/tests/fields.py b/tests/fields.py index 68113ca2..8f33c21c 100644 --- a/tests/fields.py +++ b/tests/fields.py @@ -1,12 +1,14 @@ import unittest from mongomap import * +from mongomap.connection import _get_db class FieldTest(unittest.TestCase): def setUp(self): connect(db='mongomaptest') + self.db = _get_db() def test_default_values(self): """Ensure that default field values are used when creating a document. @@ -146,6 +148,38 @@ class FieldTest(unittest.TestCase): post.author = User(name='Test User') post.author = PowerUser(name='Test User', power=47) + def test_reference_validation(self): + """Ensure that invalid embedded documents cannot be assigned to + embedded document fields. + """ + class User(Document): + name = StringField() + + class BlogPost(Document): + content = StringField() + author = ReferenceField(User) + + self.assertRaises(ValidationError, ReferenceField, EmbeddedDocument) + + user = User(name='Test User') + + post1 = BlogPost(content='Chips and gravy taste good.') + post1.author = user + self.assertRaises(ValidationError, post1.save) + + post2 = BlogPost(content='Chips and chilli taste good.') + self.assertRaises(ValidationError, post1.__setattr__, 'author', post2) + + user.save() + post1.author = user + post1.save() + + post2.save() + self.assertRaises(ValidationError, post1.__setattr__, 'author', post2) + + self.db.drop_collection(User._meta['collection']) + self.db.drop_collection(BlogPost._meta['collection']) + if __name__ == '__main__': unittest.main()