# Source-viewer residue: 1267 lines, 40 KiB, Python
import unittest
|
|
import datetime
|
|
from decimal import Decimal
|
|
|
|
import pymongo
|
|
import gridfs
|
|
|
|
from mongoengine import *
|
|
from mongoengine.connection import _get_db
|
|
from mongoengine.base import _document_registry, NotRegistered
|
|
|
|
|
|
class FieldTest(unittest.TestCase):
|
|
|
|
def setUp(self):
    """Open a connection to the test database and keep a handle to it."""
    database_name = 'mongoenginetest'
    connect(db=database_name)
    self.db = _get_db()
|
|
|
|
def test_default_values(self):
    """Check that both plain and callable field defaults are applied when
    a document is created without those fields.
    """
    class Person(Document):
        name = StringField()
        age = IntField(default=30)
        userid = StringField(default=lambda: 'test')

    data = Person(name='Test Person')._data
    # Plain default value.
    self.assertEqual(data['age'], 30)
    # Callable default is invoked to produce the value.
    self.assertEqual(data['userid'], 'test')
|
|
|
|
def test_required_values(self):
    """Check that leaving out any required field makes validation fail."""
    class Person(Document):
        name = StringField(required=True)
        age = IntField(required=True)
        userid = StringField()

    # Each candidate is missing exactly one of the two required fields.
    for incomplete in (Person(name="Test User"), Person(age=30)):
        self.assertRaises(ValidationError, incomplete.validate)
|
|
|
|
def test_object_id_validation(self):
    """Ensure that a document's ``id`` only accepts valid ObjectId values.
    """
    class Person(Document):
        name = StringField()

    person = Person(name='Test User')
    # A freshly created, unsaved document has no id yet.
    self.assertEqual(person.id, None)

    # An int is not a valid ObjectId.
    person.id = 47
    self.assertRaises(ValidationError, person.validate)

    # A short, non-hex string is not a valid ObjectId.
    person.id = 'abc'
    self.assertRaises(ValidationError, person.validate)

    # A 24-character hex string is accepted.
    person.id = '497ce96f395f2f052a494fd4'
    person.validate()
|
|
|
|
def test_string_validation(self):
    """Check StringField type checking, regex matching and max_length."""
    class Person(Document):
        name = StringField(max_length=20)
        userid = StringField(r'[0-9a-z_]+$')

    # A non-string value is rejected.
    self.assertRaises(ValidationError, Person(name=34).validate)

    # userid must match the regex passed positionally to StringField.
    person = Person(userid='test.User')
    self.assertRaises(ValidationError, person.validate)
    person.userid = 'test_user'
    self.assertEqual(person.userid, 'test_user')
    person.validate()

    # name must be at most 20 characters long.
    person = Person(name='Name that is more than twenty characters')
    self.assertRaises(ValidationError, person.validate)
    person.name = 'Shorter name'
    person.validate()
|
|
|
|
def test_url_validation(self):
    """Check that URLField rejects a bare word and accepts a full URL."""
    class Link(Document):
        url = URLField()

    bookmark = Link()

    # No scheme or host: invalid.
    bookmark.url = 'google'
    self.assertRaises(ValidationError, bookmark.validate)

    # Scheme, host and port: valid.
    bookmark.url = 'http://www.google.com:8080'
    bookmark.validate()
|
|
|
|
def test_int_validation(self):
    """Check IntField min/max bounds and type checking."""
    class Person(Document):
        age = IntField(min_value=0, max_value=110)

    person = Person()
    person.age = 50
    person.validate()

    # Below min_value, above max_value, and not an integer at all.
    for invalid_age in (-1, 120, 'ten'):
        person.age = invalid_age
        self.assertRaises(ValidationError, person.validate)
|
|
|
|
def test_float_validation(self):
    """Check FloatField type checking and min/max bounds."""
    class Person(Document):
        height = FloatField(min_value=0.1, max_value=3.5)

    person = Person()
    person.height = 1.89
    person.validate()

    # A numeric string, a value below min_value and one above max_value.
    for invalid_height in ('2.0', 0.01, 4.0):
        person.height = invalid_height
        self.assertRaises(ValidationError, person.validate)
|
|
|
|
def test_decimal_validation(self):
    """Check DecimalField round-tripping and min/max bound validation."""
    class Person(Document):
        height = DecimalField(min_value=Decimal('0.1'),
                              max_value=Decimal('3.5'))

    Person.drop_collection()

    person = Person()
    person.height = Decimal('1.89')
    person.save()
    person.reload()
    self.assertEqual(person.height, Decimal('1.89'))

    # A numeric string inside the bounds can still be saved.
    person.height = '2.0'
    person.save()

    # Out-of-range values fail, whether given as float or Decimal.
    for out_of_range in (0.01, Decimal('0.01'), Decimal('4.0')):
        person.height = out_of_range
        self.assertRaises(ValidationError, person.validate)

    Person.drop_collection()
|
|
|
|
def test_boolean_validation(self):
    """Check that BooleanField only accepts real booleans."""
    class Person(Document):
        admin = BooleanField()

    person = Person()
    person.admin = True
    person.validate()

    # Truthy non-bool values are still rejected.
    for not_a_bool in (2, 'Yes'):
        person.admin = not_a_bool
        self.assertRaises(ValidationError, person.validate)
|
|
|
|
def test_datetime_validation(self):
    """Check that DateTimeField only accepts datetime values."""
    class LogEntry(Document):
        time = DateTimeField()

    entry = LogEntry()
    entry.time = datetime.datetime.now()
    entry.validate()

    # Neither a number nor a free-form string is a datetime.
    for not_a_datetime in (-1, '1pm'):
        entry.time = not_a_datetime
        self.assertRaises(ValidationError, entry.validate)
|
|
|
|
def test_datetime(self):
    """Document how a plain DateTimeField loses sub-millisecond precision.

    After a save/reload, post-epoch values have their microseconds
    truncated to whole milliseconds. Pre-epoch values below 1000
    microseconds lose them entirely. Pre-epoch values with 1000 or more
    microseconds come back shifted by a predictable offset.

    See: http://api.mongodb.org/python/current/api/bson/son.html#dt
    """
    class LogEntry(Document):
        date = DateTimeField()

    LogEntry.drop_collection()

    log = LogEntry()

    def round_trip(value):
        # Store value, read it back from the database and return it.
        log.date = value
        log.save()
        log.reload()
        return log.date

    truncation_cases = [
        # Post-epoch: sub-millisecond microseconds are dropped.
        (datetime.datetime(1970, 1, 1, 0, 0, 1, 999),
         datetime.datetime(1970, 1, 1, 0, 0, 1)),
        # Post-epoch: microseconds are cut down to whole milliseconds.
        (datetime.datetime(1970, 1, 1, 0, 0, 1, 9999),
         datetime.datetime(1970, 1, 1, 0, 0, 1, 9000)),
        # Pre-epoch: microseconds below 1000 are dropped.
        (datetime.datetime(1969, 12, 31, 23, 59, 59, 999),
         datetime.datetime(1969, 12, 31, 23, 59, 59)),
    ]
    for original, expected in truncation_cases:
        stored = round_trip(original)
        self.assertNotEqual(stored, original)
        self.assertEqual(stored, expected)

    # Pre-epoch values with >= 1000 microseconds come back with an invalid
    # microsecond component, so no expected datetime can be built. The
    # difference is predictable, though: 16 seconds plus
    # (777216 - micros % 1000) microseconds.
    for micros in xrange(1001, 3113, 33):
        original = datetime.datetime(1969, 12, 31, 23, 59, 59, micros)
        stored = round_trip(original)
        self.assertNotEqual(stored, original)

        delta = stored - original
        self.assertEqual(delta.seconds, 16)
        self.assertEqual(delta.microseconds, 777216 - (micros % 1000))

    LogEntry.drop_collection()
|
|
|
|
def test_complexdatetime_storage(self):
    """Check that ComplexDateTimeField round-trips datetimes exactly.

    Full microsecond precision must survive on both sides of the epoch,
    unlike a plain DateTimeField (see test_datetime).
    """
    class LogEntry(Document):
        date = ComplexDateTimeField()

    LogEntry.drop_collection()

    log = LogEntry()

    def round_trip(value):
        # Store value, read it back from the database and return it.
        log.date = value
        log.save()
        log.reload()
        return log.date

    # Values a plain DateTimeField would truncate must come back unchanged.
    exact_values = [
        datetime.datetime(1970, 1, 1, 0, 0, 1, 999),
        datetime.datetime(1970, 1, 1, 0, 0, 1, 9999),
        datetime.datetime(1969, 12, 31, 23, 59, 59, 999),
    ]
    for value in exact_values:
        self.assertEqual(round_trip(value), value)

    # Pre-epoch values with >= 1000 microseconds also round-trip exactly,
    # and can be used to look the same document up again.
    for micros in xrange(1001, 3113, 33):
        value = datetime.datetime(1969, 12, 31, 23, 59, 59, micros)
        self.assertEqual(round_trip(value), value)
        self.assertEqual(log, LogEntry.objects.get(date=value))

    LogEntry.drop_collection()
|
|
|
|
def test_complexdatetime_usage(self):
    """Tests for complex datetime fields - which can handle microseconds
    without rounding - used for exact lookups, ordering and range queries.
    """
    class LogEntry(Document):
        date = ComplexDateTimeField()

    LogEntry.drop_collection()

    d1 = datetime.datetime(1970, 1, 1, 0, 0, 1, 999)
    log = LogEntry()
    log.date = d1
    log.save()

    # An exact-match lookup on the full-precision value finds the document.
    log1 = LogEntry.objects.get(date=d1)
    self.assertEqual(log, log1)

    LogEntry.drop_collection()

    # Create 60 log entries, one per year from 1950 to 2009.
    for year in xrange(1950, 2010):
        d = datetime.datetime(year, 1, 1, 0, 0, 1, 999)
        LogEntry(date=d).save()

    self.assertEqual(LogEntry.objects.count(), 60)

    # Test ordering. Every result must be in order, so compare the whole
    # result list against a sorted copy instead of stepping an index.
    ascending = [entry.date for entry in LogEntry.objects.order_by("date")]
    self.assertEqual(len(ascending), 60)
    self.assertEqual(ascending, sorted(ascending))

    descending = [entry.date for entry in LogEntry.objects.order_by("-date")]
    self.assertEqual(len(descending), 60)
    self.assertEqual(descending, sorted(descending, reverse=True))

    # Test searching
    logs = LogEntry.objects.filter(date__gte=datetime.datetime(1980, 1, 1))
    self.assertEqual(logs.count(), 30)

    logs = LogEntry.objects.filter(date__lte=datetime.datetime(1980, 1, 1))
    self.assertEqual(logs.count(), 30)

    logs = LogEntry.objects.filter(
        date__lte=datetime.datetime(2011, 1, 1),
        date__gte=datetime.datetime(2000, 1, 1),
    )
    self.assertEqual(logs.count(), 10)

    LogEntry.drop_collection()
|
|
|
|
def test_list_validation(self):
    """Check that ListFields accept only lists/tuples of valid items."""
    class User(Document):
        pass

    class Comment(EmbeddedDocument):
        content = StringField()

    class BlogPost(Document):
        content = StringField()
        comments = ListField(EmbeddedDocumentField(Comment))
        tags = ListField(StringField())
        authors = ListField(ReferenceField(User))

    post = BlogPost(content='Went for a walk today...')
    post.validate()

    # tags: a bare string or a list of ints fails; a list or tuple of
    # strings passes.
    for bad_tags in ('fun', [1, 2]):
        post.tags = bad_tags
        self.assertRaises(ValidationError, post.validate)
    for good_tags in (['fun', 'leisure'], ('fun', 'leisure')):
        post.tags = good_tags
        post.validate()

    # comments: items must be Comment embedded documents.
    for bad_comments in (['a'], 'yay'):
        post.comments = bad_comments
        self.assertRaises(ValidationError, post.validate)
    post.comments = [Comment(content='Good for you'), Comment(content='Yay.')]
    post.validate()

    # authors: items must be User documents.
    post.authors = [Comment()]
    self.assertRaises(ValidationError, post.validate)
    post.authors = [User()]
    post.validate()
|
|
|
|
def test_sorted_list_sorting(self):
    """Check that SortedListField stores its items sorted.

    Plain values are sorted by value; embedded documents are sorted by the
    attribute named in ``ordering``.
    """
    class Comment(EmbeddedDocument):
        order = IntField()
        content = StringField()

    class BlogPost(Document):
        content = StringField()
        comments = SortedListField(EmbeddedDocumentField(Comment),
                                   ordering='order')
        tags = SortedListField(StringField())

    post = BlogPost(content='Went for a walk today...')
    post.save()

    # Strings come back in lexical order.
    post.tags = ['leisure', 'fun']
    post.save()
    post.reload()
    self.assertEqual(post.tags, ['fun', 'leisure'])

    # Comments come back ordered by their ``order`` attribute.
    later = Comment(content='Good for you', order=1)
    earlier = Comment(content='Yay.', order=0)
    post.comments = [later, earlier]
    post.save()
    post.reload()
    self.assertEqual(post.comments[0].content, earlier.content)
    self.assertEqual(post.comments[1].content, later.content)

    BlogPost.drop_collection()
|
|
|
|
def test_list_field(self):
    """Check that an untyped ListField accepts only lists and supports
    positional / nested-key queries.
    """
    class BlogPost(Document):
        info = ListField()

    BlogPost.drop_collection()

    post = BlogPost()
    # A string and a dict are not lists.
    for not_a_list in ('my post', {'title': 'test'}):
        post.info = not_a_list
        self.assertRaises(ValidationError, post.validate)

    post.info = ['test']
    post.save()

    for info in ([{'test': 'test'}], [{'test': 3}]):
        post = BlogPost()
        post.info = info
        post.save()

    objects = BlogPost.objects
    self.assertEqual(objects.count(), 3)
    self.assertEqual(objects.filter(info__exact='test').count(), 1)
    self.assertEqual(objects.filter(info__0__test='test').count(), 1)

    # A mismatched value or an out-of-range index matches nothing.
    self.assertEqual(objects.filter(info__0__test__exact='5').count(), 0)
    self.assertEqual(
        objects.filter(info__100__test__exact='test').count(), 0)

    BlogPost.drop_collection()
|
|
|
|
def test_list_field_strict(self):
    """Check that a ListField with a typed item field rejects other types."""
    class Simple(Document):
        mapping = ListField(field=IntField())

    Simple.drop_collection()

    doc = Simple()
    doc.mapping = [1]
    doc.save()

    def save_with_string_item():
        doc.mapping = ["abc"]
        doc.save()

    self.assertRaises(ValidationError, save_with_string_item)

    Simple.drop_collection()
|
|
|
|
def test_list_field_complex(self):
    """Check that an untyped ListField can hold embedded documents and
    nested dicts/lists, and that those can be queried and updated.
    """
    class SettingBase(EmbeddedDocument):
        pass

    class StringSetting(SettingBase):
        value = StringField()

    class IntegerSetting(SettingBase):
        value = IntField()

    class Simple(Document):
        mapping = ListField()

    Simple.drop_collection()
    e = Simple()
    e.mapping.append(StringSetting(value='foo'))
    e.mapping.append(IntegerSetting(value=42))
    nested = {
        'number': 1,
        'string': 'Hi!',
        'float': 1.001,
        'complex': IntegerSetting(value=42),
        'list': [IntegerSetting(value=42), StringSetting(value='foo')],
    }
    e.mapping.append(nested)
    e.save()

    # Embedded documents come back as their concrete subclasses.
    reloaded = Simple.objects.get(id=e.id)
    self.assertTrue(isinstance(reloaded.mapping[0], StringSetting))
    self.assertTrue(isinstance(reloaded.mapping[1], IntegerSetting))

    # Each of these lookups must match exactly the one document.
    lookups = [
        {'mapping__1__value': 42},
        {'mapping__2__number': 1},
        {'mapping__2__complex__value': 42},
        {'mapping__2__list__0__value': 42},
        {'mapping__2__list__1__value': 'foo'},
    ]
    for lookup in lookups:
        self.assertEqual(Simple.objects.filter(**lookup).count(), 1)

    # Positional updates replace a single element.
    Simple.objects().update(set__mapping__1=IntegerSetting(value=10))
    self.assertEqual(Simple.objects.filter(mapping__1__value=10).count(), 1)

    Simple.objects().update(
        set__mapping__2__list__1=StringSetting(value='Boo'))
    self.assertEqual(
        Simple.objects.filter(mapping__2__list__1__value='foo').count(), 0)
    self.assertEqual(
        Simple.objects.filter(mapping__2__list__1__value='Boo').count(), 1)

    Simple.drop_collection()
|
|
|
|
def test_dict_field(self):
    """Check DictField validation and key-path queries.

    The field rejects non-dicts and keys containing '$' or '.'.
    """
    class BlogPost(Document):
        info = DictField()

    BlogPost.drop_collection()

    post = BlogPost()
    rejected = ('my post', ['test', 'test'], {'$title': 'test'},
                {'the.title': 'test'})
    for bad_info in rejected:
        post.info = bad_info
        self.assertRaises(ValidationError, post.validate)

    post.info = {'title': 'test'}
    post.save()

    for info in ({'details': {'test': 'test'}}, {'details': {'test': 3}}):
        post = BlogPost()
        post.info = info
        post.save()

    objects = BlogPost.objects
    self.assertEqual(objects.count(), 3)
    self.assertEqual(objects.filter(info__title__exact='test').count(), 1)
    self.assertEqual(
        objects.filter(info__details__test__exact='test').count(), 1)

    # A mismatched value type or a missing key matches nothing.
    self.assertEqual(
        objects.filter(info__details__test__exact=5).count(), 0)
    self.assertEqual(
        objects.filter(info__made_up__test__exact='test').count(), 0)

    BlogPost.drop_collection()
|
|
|
|
def test_dictfield_strict(self):
    """Check that a DictField with a typed value field rejects other types."""
    class Simple(Document):
        mapping = DictField(field=IntField())

    Simple.drop_collection()

    doc = Simple()
    doc.mapping['someint'] = 1
    doc.save()

    def save_with_string_value():
        doc.mapping['somestring'] = "abc"
        doc.save()

    self.assertRaises(ValidationError, save_with_string_value)

    Simple.drop_collection()
|
|
|
|
def test_dictfield_complex(self):
    """Ensure that the dict field can handle the complex types."""
    class SettingBase(EmbeddedDocument):
        pass

    class StringSetting(SettingBase):
        value = StringField()

    class IntegerSetting(SettingBase):
        value = IntField()

    class Simple(Document):
        mapping = DictField()

    Simple.drop_collection()
    e = Simple()
    e.mapping['somestring'] = StringSetting(value='foo')
    e.mapping['someint'] = IntegerSetting(value=42)
    e.mapping['nested_dict'] = {'number': 1, 'string': 'Hi!', 'float': 1.001,
                                'complex': IntegerSetting(value=42),
                                'list': [IntegerSetting(value=42),
                                         StringSetting(value='foo')]}
    e.save()

    # Embedded documents come back as their concrete subclasses.
    e2 = Simple.objects.get(id=e.id)
    self.assertTrue(isinstance(e2.mapping['somestring'], StringSetting))
    self.assertTrue(isinstance(e2.mapping['someint'], IntegerSetting))

    # Test querying
    self.assertEqual(
        Simple.objects.filter(mapping__someint__value=42).count(), 1)
    self.assertEqual(
        Simple.objects.filter(mapping__nested_dict__number=1).count(), 1)
    self.assertEqual(
        Simple.objects.filter(mapping__nested_dict__complex__value=42).count(), 1)
    self.assertEqual(
        Simple.objects.filter(mapping__nested_dict__list__0__value=42).count(), 1)
    self.assertEqual(
        Simple.objects.filter(mapping__nested_dict__list__1__value='foo').count(), 1)

    # Confirm can update. A $set on ``mapping`` replaces the whole dict, so
    # check that the new embedded value was stored before the next update.
    Simple.objects().update(
        set__mapping={"someint": IntegerSetting(value=10)})
    self.assertEqual(
        Simple.objects.filter(mapping__someint__value=10).count(), 1)

    Simple.objects().update(
        set__mapping__nested_dict__list__1=StringSetting(value='Boo'))
    self.assertEqual(
        Simple.objects.filter(mapping__nested_dict__list__1__value='foo').count(), 0)
    self.assertEqual(
        Simple.objects.filter(mapping__nested_dict__list__1__value='Boo').count(), 1)

    Simple.drop_collection()
|
|
|
|
def test_mapfield(self):
    """Check that MapField enforces its declared value type and cannot be
    declared without one.
    """
    class Simple(Document):
        mapping = MapField(IntField())

    Simple.drop_collection()

    doc = Simple()
    doc.mapping['someint'] = 1
    doc.save()

    def save_with_string_value():
        doc.mapping['somestring'] = "abc"
        doc.save()

    self.assertRaises(ValidationError, save_with_string_value)

    def declare_untyped_mapfield():
        class NoDeclaredType(Document):
            mapping = MapField()

    self.assertRaises(ValidationError, declare_untyped_mapfield)

    Simple.drop_collection()
|
|
|
|
def test_complex_mapfield(self):
    """Check that a MapField of an embedded base class stores subclasses
    and rejects values that are not embedded documents.
    """
    class SettingBase(EmbeddedDocument):
        pass

    class StringSetting(SettingBase):
        value = StringField()

    class IntegerSetting(SettingBase):
        value = IntField()

    class Extensible(Document):
        mapping = MapField(EmbeddedDocumentField(SettingBase))

    Extensible.drop_collection()

    doc = Extensible()
    doc.mapping['somestring'] = StringSetting(value='foo')
    doc.mapping['someint'] = IntegerSetting(value=42)
    doc.save()

    # Values come back as their concrete subclasses.
    reloaded = Extensible.objects.get(id=doc.id)
    self.assertTrue(isinstance(reloaded.mapping['somestring'], StringSetting))
    self.assertTrue(isinstance(reloaded.mapping['someint'], IntegerSetting))

    def save_with_plain_int():
        doc.mapping['someint'] = 123
        doc.save()

    self.assertRaises(ValidationError, save_with_plain_int)

    Extensible.drop_collection()
|
|
|
|
def test_embedded_document_validation(self):
    """Check that an EmbeddedDocumentField accepts only a valid instance of
    its declared embedded document type.
    """
    class Comment(EmbeddedDocument):
        content = StringField()

    class PersonPreferences(EmbeddedDocument):
        food = StringField(required=True)
        number = IntField()

    class Person(Document):
        name = StringField()
        preferences = EmbeddedDocumentField(PersonPreferences)

    person = Person(name='Test User')

    # A plain string, an embedded document of the wrong type, and a
    # PersonPreferences missing its required ``food`` all fail.
    invalid_preferences = ('My Preferences',
                           Comment(content='Nice blog post...'),
                           PersonPreferences())
    for candidate in invalid_preferences:
        person.preferences = candidate
        self.assertRaises(ValidationError, person.validate)

    person.preferences = PersonPreferences(food='Cheese', number=47)
    self.assertEqual(person.preferences.food, 'Cheese')
    person.validate()
|
|
|
|
def test_embedded_document_inheritance(self):
    """Ensure that subclasses of embedded documents may be provided to
    EmbeddedDocumentFields of the superclass' type.
    """
    class User(EmbeddedDocument):
        name = StringField()

    class PowerUser(User):
        power = IntField()

    class BlogPost(Document):
        content = StringField()
        author = EmbeddedDocumentField(User)

    post = BlogPost(content='What I did today...')

    # The declared type itself is accepted.
    post.author = User(name='Test User')
    post.validate()

    # A subclass instance must be accepted too. Without validate() calls
    # these assignments would not exercise the field's type check.
    post.author = PowerUser(name='Test User', power=47)
    post.validate()
|
|
|
|
def test_reference_validation(self):
    """Ensure that invalid document objects cannot be assigned to reference
    fields.
    """
    class User(Document):
        name = StringField()

    class BlogPost(Document):
        content = StringField()
        author = ReferenceField(User)

    User.drop_collection()
    BlogPost.drop_collection()

    # ReferenceField cannot target an embedded document class.
    self.assertRaises(ValidationError, ReferenceField, EmbeddedDocument)

    user = User(name='Test User')

    # Ensure that the referenced object must have been saved
    post1 = BlogPost(content='Chips and gravy taste good.')
    post1.author = user
    self.assertRaises(ValidationError, post1.save)

    # Check that an invalid object type cannot be used
    post2 = BlogPost(content='Chips and chilli taste good.')
    post1.author = post2
    self.assertRaises(ValidationError, post1.validate)

    # Once the user is saved, the reference is valid.
    user.save()
    post1.author = user
    post1.save()

    # A saved document of the wrong type is still rejected.
    post2.save()
    post1.author = post2
    self.assertRaises(ValidationError, post1.validate)

    User.drop_collection()
    BlogPost.drop_collection()
|
|
|
|
def test_list_item_dereference(self):
    """Check that references stored in a ListField come back dereferenced."""
    class User(Document):
        name = StringField()

    class Group(Document):
        members = ListField(ReferenceField(User))

    User.drop_collection()
    Group.drop_collection()

    users = []
    for username in ('user1', 'user2'):
        member = User(name=username)
        member.save()
        users.append(member)

    group = Group(members=users)
    group.save()

    group_obj = Group.objects.first()
    # Attribute access on each item only works if it was dereferenced.
    for position, member in enumerate(users):
        self.assertEqual(group_obj.members[position].name, member.name)

    User.drop_collection()
    Group.drop_collection()
|
|
|
|
def test_recursive_reference(self):
    """Ensure that ReferenceFields can reference their own documents.
    """
    class Employee(Document):
        name = StringField()
        boss = ReferenceField('self')
        friends = ListField(ReferenceField('self'))

    # Start and finish with an empty collection, like the other DB tests.
    Employee.drop_collection()

    bill = Employee(name='Bill Lumbergh')
    bill.save()

    michael = Employee(name='Michael Bolton')
    michael.save()

    samir = Employee(name='Samir Nagheenanajar')
    samir.save()

    friends = [michael, samir]
    peter = Employee(name='Peter Gibbons', boss=bill, friends=friends)
    peter.save()

    # Reload and check that both the single and the list self-references
    # resolve to the saved employees.
    peter = Employee.objects.with_id(peter.id)
    self.assertEqual(peter.boss, bill)
    self.assertEqual(peter.friends, friends)

    Employee.drop_collection()
|
|
|
|
def test_recursive_embedding(self):
    """Ensure that EmbeddedDocumentFields can contain their own documents.
    """
    class Tree(Document):
        name = StringField()
        children = ListField(EmbeddedDocumentField('TreeNode'))

    class TreeNode(EmbeddedDocument):
        name = StringField()
        children = ListField(EmbeddedDocumentField('self'))

    # Start from an empty collection so Tree.objects.first() is this tree.
    Tree.drop_collection()

    tree = Tree(name="Tree")

    first_child = TreeNode(name="Child 1")
    tree.children.append(first_child)

    second_child = TreeNode(name="Child 2")
    first_child.children.append(second_child)

    third_child = TreeNode(name="Child 3")
    first_child.children.append(third_child)

    tree.save()

    # Check the reloaded tree rather than the in-memory one, so the
    # recursive structure is verified to survive a round trip.
    tree_obj = Tree.objects.first()
    self.assertEqual(len(tree_obj.children), 1)
    self.assertEqual(tree_obj.children[0].name, first_child.name)
    self.assertEqual(tree_obj.children[0].children[0].name, second_child.name)
    self.assertEqual(tree_obj.children[0].children[1].name, third_child.name)

    Tree.drop_collection()
|
|
|
|
def test_undefined_reference(self):
    """Ensure that ReferenceFields may reference undefined Documents.
    """
    # Product refers to Company by name before Company is defined.
    class Product(Document):
        name = StringField()
        company = ReferenceField('Company')

    class Company(Document):
        name = StringField()

    # Start and finish with empty collections, like the other DB tests.
    Product.drop_collection()
    Company.drop_collection()

    ten_gen = Company(name='10gen')
    ten_gen.save()
    mongodb = Product(name='MongoDB', company=ten_gen)
    mongodb.save()

    obj = Product.objects(company=ten_gen).first()
    self.assertEqual(obj, mongodb)
    self.assertEqual(obj.company, ten_gen)

    Product.drop_collection()
    Company.drop_collection()
|
|
|
|
def test_reference_query_conversion(self):
    """Check that a ReferenceField can be queried with a referenced
    document whose primary key is a custom (int) field.
    """
    class Member(Document):
        user_num = IntField(primary_key=True)

    class BlogPost(Document):
        title = StringField()
        author = ReferenceField(Member)

    Member.drop_collection()
    BlogPost.drop_collection()

    m1 = Member(user_num=1)
    m1.save()
    m2 = Member(user_num=2)
    m2.save()

    post1 = BlogPost(title='post 1', author=m1)
    post1.save()
    post2 = BlogPost(title='post 2', author=m2)
    post2.save()

    # Querying by each member must find only that member's post.
    for author, expected_post in ((m1, post1), (m2, post2)):
        found = BlogPost.objects(author=author).first()
        self.assertEqual(found.id, expected_post.id)

    Member.drop_collection()
    BlogPost.drop_collection()
|
|
|
|
def test_generic_reference(self):
    """Check that a GenericReferenceField dereferences to the right
    document class, and that it can be queried and re-pointed.
    """
    class Link(Document):
        title = StringField()
        meta = {'allow_inheritance': False}

    class Post(Document):
        title = StringField()

    class Bookmark(Document):
        bookmark_object = GenericReferenceField()

    for model in (Link, Post, Bookmark):
        model.drop_collection()

    link_1 = Link(title="Pitchfork")
    link_1.save()

    post_1 = Post(title="Behind the Scenes of the Pavement Reunion")
    post_1.save()

    bm = Bookmark(bookmark_object=post_1)
    bm.save()

    bm = Bookmark.objects(bookmark_object=post_1).first()
    self.assertEqual(bm.bookmark_object, post_1)
    self.assertTrue(isinstance(bm.bookmark_object, Post))

    # Point the same bookmark at a document of a different class.
    bm.bookmark_object = link_1
    bm.save()

    bm = Bookmark.objects(bookmark_object=link_1).first()
    self.assertEqual(bm.bookmark_object, link_1)
    self.assertTrue(isinstance(bm.bookmark_object, Link))

    for model in (Link, Post, Bookmark):
        model.drop_collection()
|
|
|
|
def test_generic_reference_list(self):
    """Check that a ListField of GenericReferenceFields dereferences each
    item to its own document class.
    """
    class Link(Document):
        title = StringField()

    class Post(Document):
        title = StringField()

    class User(Document):
        bookmarks = ListField(GenericReferenceField())

    for model in (Link, Post, User):
        model.drop_collection()

    link_1 = Link(title="Pitchfork")
    link_1.save()

    post_1 = Post(title="Behind the Scenes of the Pavement Reunion")
    post_1.save()

    targets = [post_1, link_1]
    user = User(bookmarks=targets)
    user.save()

    user = User.objects(bookmarks__all=targets).first()
    # Items keep their order and resolve to the saved documents.
    self.assertEqual(user.bookmarks[0], post_1)
    self.assertEqual(user.bookmarks[1], link_1)

    for model in (Link, Post, User):
        model.drop_collection()
|
|
|
|
def test_generic_reference_document_not_registered(self):
    """Check that dereferencing a generic reference to a class missing from
    the document registry raises `NotRegistered`.
    """
    class Link(Document):
        title = StringField()

    class User(Document):
        bookmarks = ListField(GenericReferenceField())

    Link.drop_collection()
    User.drop_collection()

    link_1 = Link(title="Pitchfork")
    link_1.save()

    user = User(bookmarks=[link_1])
    user.save()

    # Simulate User and Link living in separate modules, with Link never
    # imported where User is used.
    del _document_registry["Link"]

    user = User.objects.first()
    try:
        user.bookmarks
    except NotRegistered:
        pass
    else:
        self.fail("Link was removed from the registry")

    Link.drop_collection()
    User.drop_collection()
|
|
|
|
def test_binary_fields(self):
    """Check that BinaryField data survives a save/load round trip."""
    class Attachment(Document):
        content_type = StringField()
        blob = BinaryField()

    BLOB = '\xe6\x00\xc4\xff\x07'
    MIME_TYPE = 'application/octet-stream'

    Attachment.drop_collection()

    Attachment(content_type=MIME_TYPE, blob=BLOB).save()

    stored = Attachment.objects().first()
    self.assertEqual(MIME_TYPE, stored.content_type)
    self.assertEqual(BLOB, stored.blob)

    Attachment.drop_collection()
|
|
|
|
def test_binary_validation(self):
    """Check BinaryField type, required and max_bytes validation."""
    class Attachment(Document):
        blob = BinaryField()

    class AttachmentRequired(Document):
        blob = BinaryField(required=True)

    class AttachmentSizeLimit(Document):
        blob = BinaryField(max_bytes=4)

    models = (Attachment, AttachmentRequired, AttachmentSizeLimit)
    for model in models:
        model.drop_collection()

    # An unset optional blob is valid; a non-binary value is not.
    attachment = Attachment()
    attachment.validate()
    attachment.blob = 2
    self.assertRaises(ValidationError, attachment.validate)

    # A required blob must be present.
    attachment_required = AttachmentRequired()
    self.assertRaises(ValidationError, attachment_required.validate)
    attachment_required.blob = '\xe6\x00\xc4\xff\x07'
    attachment_required.validate()

    # max_bytes=4: five bytes are too many, four are fine.
    attachment_size_limit = AttachmentSizeLimit(blob='\xe6\x00\xc4\xff\x07')
    self.assertRaises(ValidationError, attachment_size_limit.validate)
    attachment_size_limit.blob = '\xe6\x00\xc4\xff'
    attachment_size_limit.validate()

    for model in models:
        model.drop_collection()
|
|
|
|
def test_choices_validation(self):
    """Reject StringField values that are not among the declared choices."""
    size_choices = (('S', 'Small'), ('M', 'Medium'), ('L', 'Large'),
                    ('XL', 'Extra Large'), ('XXL', 'Extra Extra Large'))

    class Shirt(Document):
        size = StringField(max_length=3, choices=size_choices)

    Shirt.drop_collection()

    garment = Shirt()
    # An unset value is accepted...
    garment.validate()

    # ...so is a declared choice key...
    garment.size = "S"
    garment.validate()

    # ...but a key outside the choices is not.
    garment.size = "XS"
    self.assertRaises(ValidationError, garment.validate)

    Shirt.drop_collection()
|
|
|
|
def test_choices_get_field_display(self):
    """Check the generated ``get_<field>_display`` helpers on choice fields.

    The helper maps a stored key to its human-readable label. It returns
    None for an unset field and echoes the raw value for a key that is not
    among the choices.
    """
    size_choices = (('S', 'Small'), ('M', 'Medium'), ('L', 'Large'),
                    ('XL', 'Extra Large'), ('XXL', 'Extra Extra Large'))
    style_choices = (('S', 'Small'), ('B', 'Baggy'), ('W', 'wide'))

    class Shirt(Document):
        size = StringField(max_length=3, choices=size_choices)
        style = StringField(max_length=3, choices=style_choices, default='S')

    Shirt.drop_collection()

    garment = Shirt()

    # Unset field -> None; field with a default -> label of that default.
    self.assertEqual(garment.get_size_display(), None)
    self.assertEqual(garment.get_style_display(), 'Small')

    garment.size, garment.style = "XXL", "B"
    self.assertEqual(garment.get_size_display(), 'Extra Extra Large')
    self.assertEqual(garment.get_style_display(), 'Baggy')

    # "Z" is not a declared choice: the helper returns the raw value,
    # but validation rejects the document.
    garment.size, garment.style = "Z", "Z"
    self.assertEqual(garment.get_size_display(), 'Z')
    self.assertEqual(garment.get_style_display(), 'Z')
    self.assertRaises(ValidationError, garment.validate)

    Shirt.drop_collection()
|
|
|
|
def test_file_fields(self):
    """Ensure that file fields can be written to and their data retrieved.

    Covers the three ways of populating a FileField:
      * ``put()``;
      * streaming with ``new_file()`` / ``write()`` / ``close()``;
      * plain assignment.
    Also exercises ``seek()``/``tell()``, ``replace()`` and ``delete()`` on
    the stored file, and checks that a FileField is optional.
    """
    class PutFile(Document):
        file = FileField()

    class StreamFile(Document):
        file = FileField()

    class SetFile(Document):
        file = FileField()

    text = 'Hello, World!'
    more_text = 'Foo Bar'
    content_type = 'text/plain'

    PutFile.drop_collection()
    StreamFile.drop_collection()
    SetFile.drop_collection()

    # put(): store the whole payload in one call.
    putfile = PutFile()
    putfile.file.put(text, content_type=content_type)
    putfile.save()
    putfile.validate()
    result = PutFile.objects.first()
    self.assertEqual(putfile, result)
    self.assertEqual(result.file.read(), text)
    self.assertEqual(result.file.content_type, content_type)
    result.file.delete()  # Remove file from GridFS

    # new_file()/write()/close(): build the payload incrementally.
    streamfile = StreamFile()
    streamfile.file.new_file(content_type=content_type)
    streamfile.file.write(text)
    streamfile.file.write(more_text)
    streamfile.file.close()
    streamfile.save()
    streamfile.validate()
    result = StreamFile.objects.first()
    self.assertEqual(streamfile, result)
    self.assertEqual(result.file.read(), text + more_text)
    self.assertEqual(result.file.content_type, content_type)
    # seek()/tell() allow the stored file to be read back in pieces.
    result.file.seek(0)
    self.assertEqual(result.file.tell(), 0)
    self.assertEqual(result.file.read(len(text)), text)
    self.assertEqual(result.file.tell(), len(text))
    self.assertEqual(result.file.read(len(more_text)), more_text)
    self.assertEqual(result.file.tell(), len(text + more_text))
    result.file.delete()

    # Ensure deleted file returns None
    self.assertEqual(result.file.read(), None)

    # Plain assignment of the payload to the field.
    setfile = SetFile()
    setfile.file = text
    setfile.save()
    setfile.validate()
    result = SetFile.objects.first()
    self.assertEqual(setfile, result)
    self.assertEqual(result.file.read(), text)

    # Try replacing file with new one
    result.file.replace(more_text)
    result.save()
    result.validate()
    result = SetFile.objects.first()
    self.assertEqual(setfile, result)
    self.assertEqual(result.file.read(), more_text)
    result.file.delete()

    PutFile.drop_collection()
    StreamFile.drop_collection()
    SetFile.drop_collection()

    # Make sure FileField is optional and not required
    class DemoFile(Document):
        file = FileField()

    DemoFile.drop_collection()
    DemoFile.objects.create()
    # Previously this collection was never dropped and its documents
    # accumulated in the test database across runs.
    DemoFile.drop_collection()
|
|
|
|
def test_file_uniqueness(self):
    """Ensure that each instance of a FileField is unique.

    A second document must not see the file stored through the first
    document's FileField.
    """
    class TestFile(Document):
        name = StringField()
        file = FileField()

    # Start from an empty collection so stale documents cannot interfere.
    TestFile.drop_collection()

    # First instance
    testfile = TestFile()
    testfile.name = "Hello, World!"
    testfile.file.put('Hello, World!')
    testfile.save()

    # Second instance
    testfiledupe = TestFile()
    data = testfiledupe.file.read()  # Should be None

    self.assertTrue(testfile.name != testfiledupe.name)
    self.assertTrue(testfile.file.read() != data)

    # GridFS data lives outside the document collection, so delete the
    # stored file explicitly before dropping the documents.
    testfile.file.delete()
    TestFile.drop_collection()
|
|
|
|
def test_geo_indexes(self):
    """Ensure that indexes are created automatically for GeoPointFields.

    After a document with a GeoPointField is saved, the collection must
    have an index named ``location_2d`` keyed on ``(location, 2d)``.
    """
    class Event(Document):
        title = StringField()
        location = GeoPointField()

    Event.drop_collection()
    event = Event(title="Coltrane Motion @ Double Door",
                  location=[41.909889, -87.677137])
    event.save()

    info = Event.objects._collection.index_information()
    # Report the indexes actually present instead of a bare boolean failure.
    self.assertTrue(u'location_2d' in info,
                    "location_2d index missing; found %r" % sorted(info))
    self.assertEqual(info[u'location_2d']['key'], [(u'location', u'2d')])

    Event.drop_collection()
|
|
|
|
def test_ensure_unique_default_instances(self):
    """Ensure that every field gets its own default instance.

    Mutating one document's default dict must not leak into a newly
    created document. This applies both to the implicit default and to a
    default given as a callable.
    """
    class D(Document):
        data = DictField()
        data2 = DictField(default=lambda: {})

    first = D()
    for field_name in ('data', 'data2'):
        getattr(first, field_name)['foo'] = 'bar'

    second = D()
    self.assertEqual(second.data, {})
    self.assertEqual(second.data2, {})
|
|
|
|
|
|
if __name__ == '__main__':
    # Run the tests in this module when it is executed directly as a script.
    unittest.main()
|