Fixed ListField bug, added ReferenceField + tests
This commit is contained in:
parent
90e27cc87d
commit
5fa01d89a5
@ -1,12 +1,13 @@
|
|||||||
from base import BaseField, ObjectIdField, ValidationError
|
from base import BaseField, ObjectIdField, ValidationError
|
||||||
from document import EmbeddedDocument
|
from document import Document, EmbeddedDocument
|
||||||
|
from connection import _get_db
|
||||||
|
|
||||||
import re
|
import re
|
||||||
import pymongo
|
import pymongo
|
||||||
|
|
||||||
|
|
||||||
__all__ = ['StringField', 'IntField', 'EmbeddedDocumentField', 'ListField',
|
__all__ = ['StringField', 'IntField', 'EmbeddedDocumentField', 'ListField',
|
||||||
'ObjectIdField', 'ValidationError']
|
'ObjectIdField', 'ReferenceField', 'ValidationError']
|
||||||
|
|
||||||
|
|
||||||
class StringField(BaseField):
|
class StringField(BaseField):
|
||||||
@ -100,7 +101,7 @@ class ListField(BaseField):
|
|||||||
|
|
||||||
def _to_python(self, value):
|
def _to_python(self, value):
|
||||||
assert(isinstance(value, (list, tuple)))
|
assert(isinstance(value, (list, tuple)))
|
||||||
return list(value)
|
return [self.field._to_python(item) for item in value]
|
||||||
|
|
||||||
def _to_mongo(self, value):
|
def _to_mongo(self, value):
|
||||||
return [self.field._to_mongo(item) for item in value]
|
return [self.field._to_mongo(item) for item in value]
|
||||||
@ -117,3 +118,53 @@ class ListField(BaseField):
|
|||||||
except:
|
except:
|
||||||
raise ValidationError('All items in a list field must be of the '
|
raise ValidationError('All items in a list field must be of the '
|
||||||
'specified type')
|
'specified type')
|
||||||
|
|
||||||
|
|
||||||
|
class ReferenceField(BaseField):
|
||||||
|
"""A reference to a document that will be automatically dereferenced on
|
||||||
|
access (lazily).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, document_type, **kwargs):
|
||||||
|
if not issubclass(document_type, Document):
|
||||||
|
raise ValidationError('Argument to ReferenceField constructor '
|
||||||
|
'must be a top level document class')
|
||||||
|
self.document_type = document_type
|
||||||
|
self.document_obj = None
|
||||||
|
super(ReferenceField, self).__init__(**kwargs)
|
||||||
|
|
||||||
|
def __get__(self, instance, owner):
|
||||||
|
"""Descriptor to allow lazy dereferencing.
|
||||||
|
"""
|
||||||
|
if instance is None:
|
||||||
|
# Document class being used rather than a document object
|
||||||
|
return self
|
||||||
|
|
||||||
|
# Get value from document instance if available
|
||||||
|
value = instance._data.get(self.name)
|
||||||
|
# Dereference DBRefs
|
||||||
|
if isinstance(value, (pymongo.dbref.DBRef)):
|
||||||
|
value = _get_db().dereference(value)
|
||||||
|
instance._data[self.name] = self.document_type._from_son(value)
|
||||||
|
|
||||||
|
return super(ReferenceField, self).__get__(instance, owner)
|
||||||
|
|
||||||
|
def _to_python(self, document):
|
||||||
|
assert(isinstance(document, (self.document_type, pymongo.dbref.DBRef)))
|
||||||
|
return document
|
||||||
|
|
||||||
|
def _to_mongo(self, document):
|
||||||
|
if isinstance(document, (str, unicode, pymongo.objectid.ObjectId)):
|
||||||
|
_id = document
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
_id = document._id
|
||||||
|
except:
|
||||||
|
raise ValidationError('You can only reference documents once '
|
||||||
|
'they have been saved to the database')
|
||||||
|
|
||||||
|
if not isinstance(_id, pymongo.objectid.ObjectId):
|
||||||
|
_id = pymongo.objectid.ObjectId(_id)
|
||||||
|
|
||||||
|
collection = self.document_type._meta['collection']
|
||||||
|
return pymongo.dbref.DBRef(collection, _id)
|
||||||
|
@ -9,7 +9,7 @@ from mongomap import *
|
|||||||
class CollectionManagerTest(unittest.TestCase):
|
class CollectionManagerTest(unittest.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
connect(db='mongotest')
|
connect(db='mongomaptest')
|
||||||
|
|
||||||
class Person(Document):
|
class Person(Document):
|
||||||
name = StringField()
|
name = StringField()
|
||||||
@ -54,6 +54,8 @@ class CollectionManagerTest(unittest.TestCase):
|
|||||||
self.assertEqual(people.count(), 2)
|
self.assertEqual(people.count(), 2)
|
||||||
results = list(people)
|
results = list(people)
|
||||||
self.assertTrue(isinstance(results[0], self.Person))
|
self.assertTrue(isinstance(results[0], self.Person))
|
||||||
|
self.assertTrue(isinstance(results[0]._id, (pymongo.objectid.ObjectId,
|
||||||
|
str, unicode)))
|
||||||
self.assertEqual(results[0].name, "User A")
|
self.assertEqual(results[0].name, "User A")
|
||||||
self.assertEqual(results[0].age, 20)
|
self.assertEqual(results[0].age, 20)
|
||||||
self.assertEqual(results[1].name, "User B")
|
self.assertEqual(results[1].name, "User B")
|
||||||
|
@ -231,6 +231,36 @@ class DocumentTest(unittest.TestCase):
|
|||||||
# Ensure that the 'details' embedded object saved correctly
|
# Ensure that the 'details' embedded object saved correctly
|
||||||
self.assertEqual(employee_obj['details']['position'], 'Developer')
|
self.assertEqual(employee_obj['details']['position'], 'Developer')
|
||||||
|
|
||||||
|
def test_save_reference(self):
|
||||||
|
"""Ensure that a document reference field may be saved in the database.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class BlogPost(Document):
|
||||||
|
meta = {'collection': 'blogpost_1'}
|
||||||
|
content = StringField()
|
||||||
|
author = ReferenceField(self.Person)
|
||||||
|
|
||||||
|
self.db.drop_collection(BlogPost._meta['collection'])
|
||||||
|
|
||||||
|
author = self.Person(name='Test User')
|
||||||
|
author.save()
|
||||||
|
|
||||||
|
post = BlogPost(content='Watched some TV today... how exciting.')
|
||||||
|
post.author = author
|
||||||
|
post.save()
|
||||||
|
|
||||||
|
post_obj = BlogPost.objects.find_one()
|
||||||
|
self.assertTrue(isinstance(post_obj.author, self.Person))
|
||||||
|
self.assertEqual(post_obj.author.name, 'Test User')
|
||||||
|
|
||||||
|
post_obj.author.age = 25
|
||||||
|
post_obj.author.save()
|
||||||
|
|
||||||
|
author = self.Person.objects.find_one(name='Test User')
|
||||||
|
self.assertEqual(author.age, 25)
|
||||||
|
|
||||||
|
self.db.drop_collection(BlogPost._meta['collection'])
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
self.db.drop_collection(self.Person._meta['collection'])
|
self.db.drop_collection(self.Person._meta['collection'])
|
||||||
|
|
||||||
|
@ -1,12 +1,14 @@
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from mongomap import *
|
from mongomap import *
|
||||||
|
from mongomap.connection import _get_db
|
||||||
|
|
||||||
|
|
||||||
class FieldTest(unittest.TestCase):
|
class FieldTest(unittest.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
connect(db='mongomaptest')
|
connect(db='mongomaptest')
|
||||||
|
self.db = _get_db()
|
||||||
|
|
||||||
def test_default_values(self):
|
def test_default_values(self):
|
||||||
"""Ensure that default field values are used when creating a document.
|
"""Ensure that default field values are used when creating a document.
|
||||||
@ -146,6 +148,38 @@ class FieldTest(unittest.TestCase):
|
|||||||
post.author = User(name='Test User')
|
post.author = User(name='Test User')
|
||||||
post.author = PowerUser(name='Test User', power=47)
|
post.author = PowerUser(name='Test User', power=47)
|
||||||
|
|
||||||
|
def test_reference_validation(self):
|
||||||
|
"""Ensure that invalid embedded documents cannot be assigned to
|
||||||
|
embedded document fields.
|
||||||
|
"""
|
||||||
|
class User(Document):
|
||||||
|
name = StringField()
|
||||||
|
|
||||||
|
class BlogPost(Document):
|
||||||
|
content = StringField()
|
||||||
|
author = ReferenceField(User)
|
||||||
|
|
||||||
|
self.assertRaises(ValidationError, ReferenceField, EmbeddedDocument)
|
||||||
|
|
||||||
|
user = User(name='Test User')
|
||||||
|
|
||||||
|
post1 = BlogPost(content='Chips and gravy taste good.')
|
||||||
|
post1.author = user
|
||||||
|
self.assertRaises(ValidationError, post1.save)
|
||||||
|
|
||||||
|
post2 = BlogPost(content='Chips and chilli taste good.')
|
||||||
|
self.assertRaises(ValidationError, post1.__setattr__, 'author', post2)
|
||||||
|
|
||||||
|
user.save()
|
||||||
|
post1.author = user
|
||||||
|
post1.save()
|
||||||
|
|
||||||
|
post2.save()
|
||||||
|
self.assertRaises(ValidationError, post1.__setattr__, 'author', post2)
|
||||||
|
|
||||||
|
self.db.drop_collection(User._meta['collection'])
|
||||||
|
self.db.drop_collection(BlogPost._meta['collection'])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user