Files
mongoengine/tests/fields/fields.py
Nigel McNie 4c9e90732e Apply defaults to fields with None value at 'set' time.
If a field has a default, and you explicitly set it to None, the
behaviour before this patch was very confusing:

    class Person(Document):
        created = DateTimeField(default=datetime.datetime.utcnow)

    >>> p = Person(created=None)
    >>> p.created
    datetime.datetime(2013, 5, 30, 0, 18, 20, 242628)
    >>> p.created
    datetime.datetime(2013, 5, 30, 0, 18, 20, 995248)
    >>> p.created
    datetime.datetime(2013, 5, 30, 0, 18, 21, 370578)

It would be stored as None, and then at 'get' time, the default would be
applied. As you can see, if the default is a callable that produces a fresh
value on every call (such as datetime.datetime.utcnow), this leads to some
crazy behaviour.

There's an argument that if I asked it to be set to None, why not respect that?
But I don't think that's how the rest of mongoengine seems to work (for
example, setting a field to None seems to mean it doesn't even get set in mongo
- as opposed to being set but with a 'null' value). Besides, if None were
being respected, you'd expect p.created in the code above to return None, and
it doesn't. So clearly, mongoengine is already expecting None to mean
'default' where a default is available.

This bug also interacts nastily with required=True - if you're forcibly setting
the field to None, then at validation time, the None will fail validation
despite a perfectly valid default being available.

With this patch, when the field is set, the default is immediately applied.
This means any generation happens once, the getter always returns the same
value, and 'required' validation always respects the default.

Note: this breakage seems to be new since mongoengine 0.8.
2013-05-30 16:37:40 +12:00

2257 lines
71 KiB
Python

# -*- coding: utf-8 -*-
import sys
sys.path[0:0] = [""]
import datetime
import unittest
import uuid
from decimal import Decimal
from bson import Binary, DBRef, ObjectId
from mongoengine import *
from mongoengine.connection import get_db
from mongoengine.base import _document_registry
from mongoengine.errors import NotRegistered
from mongoengine.python_support import PY3, b, bin_type
__all__ = ("FieldTest", )
class FieldTest(unittest.TestCase):
def setUp(self):
    """Connect to the 'mongoenginetest' database and keep its handle on
    the test case as ``self.db``.
    """
    database_name = 'mongoenginetest'
    connect(db=database_name)
    self.db = get_db()
def tearDown(self):
    """Drop the 'fs.files' and 'fs.chunks' collections after each test.
    """
    for collection_name in ('fs.files', 'fs.chunks'):
        self.db.drop_collection(collection_name)
def test_default_values(self):
    """Ensure that field defaults are applied when a document is created
    and that repeated reads of a defaulted field return the same value.
    """
    class Person(Document):
        name = StringField()
        age = IntField(default=30, help_text="Your real age")
        userid = StringField(default=lambda: 'test', verbose_name="User Identity")

    someone = Person(name='Test Person')

    # Plain and callable defaults both end up in the document data
    self.assertEqual(someone._data['age'], 30)
    self.assertEqual(someone._data['userid'], 'test')

    # Field metadata is kept on the field objects
    declared = someone._fields
    self.assertEqual(declared['name'].help_text, None)
    self.assertEqual(declared['age'].help_text, "Your real age")
    self.assertEqual(declared['userid'].verbose_name, "User Identity")

    class Person2(Document):
        created = DateTimeField(default=datetime.datetime.utcnow)

    # Reading the field twice must give the same value, both when the
    # field is omitted and when it is explicitly passed as None.
    for init_kwargs in ({}, {'created': None}):
        doc = Person2(**init_kwargs)
        first_read = doc.created
        second_read = doc.created
        self.assertEqual(first_read, second_read)
def test_required_values(self):
    """Ensure that validation fails when any required field is missing.
    """
    class Person(Document):
        name = StringField(required=True)
        age = IntField(required=True)
        userid = StringField()

    # Each set of arguments leaves exactly one required field unset
    for partial_kwargs in ({'name': "Test User"}, {'age': 30}):
        incomplete = Person(**partial_kwargs)
        self.assertRaises(ValidationError, incomplete.validate)
def test_not_required_handles_none_in_update(self):
    """Ensure that every field accepts None in an update when it is not
    required.
    """
    class HandleNoneFields(Document):
        str_fld = StringField()
        int_fld = IntField()
        flt_fld = FloatField()
        comp_dt_fld = ComplexDateTimeField()

    HandleNoneFields.drop_collection()

    doc = HandleNoneFields()
    doc.str_fld = u'spam ham egg'
    doc.int_fld = 42
    doc.flt_fld = 4.2
    # Populate the declared ``comp_dt_fld`` field so that the update below
    # replaces an explicitly assigned value with None.
    doc.comp_dt_fld = datetime.datetime.utcnow()
    doc.save()

    res = HandleNoneFields.objects(id=doc.id).update(
        set__str_fld=None,
        set__int_fld=None,
        set__flt_fld=None,
        set__comp_dt_fld=None,
    )
    # Exactly one document is expected to be reported as updated
    self.assertEqual(res, 1)

    # Retrieve data from db and verify it.
    ret = HandleNoneFields.objects.all()[0]
    self.assertEqual(ret.str_fld, None)
    self.assertEqual(ret.int_fld, None)
    self.assertEqual(ret.flt_fld, None)
    # The complex datetime field is expected to hand back a datetime
    # (the current time) when the retrieved value is None.
    self.assertTrue(isinstance(ret.comp_dt_fld, datetime.datetime))
def test_not_required_handles_none_from_database(self):
    """Ensure that every field can handle null values from the database,
    and that required fields still fail validation afterwards.
    """
    class HandleNoneFields(Document):
        str_fld = StringField(required=True)
        int_fld = IntField(required=True)
        flt_fld = FloatField(required=True)
        comp_dt_fld = ComplexDateTimeField(required=True)

    HandleNoneFields.drop_collection()

    doc = HandleNoneFields()
    doc.str_fld = u'spam ham egg'
    doc.int_fld = 42
    doc.flt_fld = 4.2
    # Populate the declared ``comp_dt_fld`` field explicitly before saving.
    doc.comp_dt_fld = datetime.datetime.utcnow()
    doc.save()

    # Remove the values directly in the collection, bypassing the document
    # layer, so that loading has to cope with missing fields.
    collection = self.db[HandleNoneFields._get_collection_name()]
    collection.update({"_id": doc.id}, {"$unset": {
        "str_fld": 1,
        "int_fld": 1,
        "flt_fld": 1,
        "comp_dt_fld": 1}
    })

    # Retrieve data from db and verify it.
    ret = HandleNoneFields.objects.all()[0]
    self.assertEqual(ret.str_fld, None)
    self.assertEqual(ret.int_fld, None)
    self.assertEqual(ret.flt_fld, None)
    # The complex datetime field is expected to hand back a datetime
    # (the current time) when the retrieved value is None.
    self.assertTrue(isinstance(ret.comp_dt_fld, datetime.datetime))

    # The loaded document no longer satisfies its required constraints
    self.assertRaises(ValidationError, ret.validate)
def test_int_and_float_ne_operator(self):
    """Ensure that ``__ne=None`` on int and float fields matches only the
    document whose value is set.
    """
    class TestDocument(Document):
        int_fld = IntField()
        float_fld = FloatField()

    TestDocument.drop_collection()

    # One document with both fields unset, one with both set
    for number in (None, 1):
        TestDocument(int_fld=number, float_fld=number).save()

    for lookup in ('int_fld__ne', 'float_fld__ne'):
        not_none = TestDocument.objects(**{lookup: None})
        self.assertEqual(1, not_none.count())
def test_long_ne_operator(self):
    """Ensure that ``__ne=None`` on a long field matches only the document
    whose value is set.
    """
    class TestDocument(Document):
        long_fld = LongField()

    TestDocument.drop_collection()

    for number in (None, 1):
        TestDocument(long_fld=number).save()

    not_none = TestDocument.objects(long_fld__ne=None)
    self.assertEqual(1, not_none.count())
def test_object_id_validation(self):
    """Ensure that invalid values cannot be assigned to the document id.
    """
    class Person(Document):
        name = StringField()

    person = Person(name='Test User')
    # A new, unsaved document has no id yet
    self.assertEqual(person.id, None)

    for bad_id in (47, 'abc'):
        person.id = bad_id
        self.assertRaises(ValidationError, person.validate)

    # A 24 character hex string is accepted
    person.id = '497ce96f395f2f052a494fd4'
    person.validate()
def test_string_validation(self):
    """Ensure that invalid values cannot be assigned to string fields.
    """
    class Person(Document):
        name = StringField(max_length=20)
        userid = StringField(r'[0-9a-z_]+$')

    # A non-string name is rejected
    numeric_name = Person(name=34)
    self.assertRaises(ValidationError, numeric_name.validate)

    # Regex validation on userid
    by_userid = Person(userid='test.User')
    self.assertRaises(ValidationError, by_userid.validate)

    by_userid.userid = 'test_user'
    self.assertEqual(by_userid.userid, 'test_user')
    by_userid.validate()

    # Max length validation on name
    by_name = Person(name='Name that is more than twenty characters')
    self.assertRaises(ValidationError, by_name.validate)

    by_name.name = 'Shorter name'
    by_name.validate()
def test_url_validation(self):
    """Ensure that URLFields reject a bare word and accept a full URL.
    """
    class Link(Document):
        url = URLField()

    bookmark = Link()

    bookmark.url = 'google'
    self.assertRaises(ValidationError, bookmark.validate)

    bookmark.url = 'http://www.google.com:8080'
    bookmark.validate()
def test_int_validation(self):
    """Ensure that invalid values cannot be assigned to int fields.
    """
    class Person(Document):
        age = IntField(min_value=0, max_value=110)

    person = Person()

    person.age = 50
    person.validate()

    # Below the minimum, above the maximum, and not a number at all
    for bad_age in (-1, 120, 'ten'):
        person.age = bad_age
        self.assertRaises(ValidationError, person.validate)
def test_long_validation(self):
    """Ensure that invalid values cannot be assigned to long fields.
    """
    class TestDocument(Document):
        value = LongField(min_value=0, max_value=110)

    doc = TestDocument()

    doc.value = 50
    doc.validate()

    # Every rejected value is assigned to the declared ``value`` field so
    # that each check exercises its own constraint: below min_value, above
    # max_value, and a non-numeric string.
    doc.value = -1
    self.assertRaises(ValidationError, doc.validate)
    doc.value = 120
    self.assertRaises(ValidationError, doc.validate)
    doc.value = 'ten'
    self.assertRaises(ValidationError, doc.validate)
def test_float_validation(self):
    """Ensure that invalid values cannot be assigned to float fields.
    """
    class Person(Document):
        height = FloatField(min_value=0.1, max_value=3.5)

    person = Person()

    person.height = 1.89
    person.validate()

    # A numeric string, a value below the minimum and one above the maximum
    for bad_height in ('2.0', 0.01, 4.0):
        person.height = bad_height
        self.assertRaises(ValidationError, person.validate)
def test_decimal_validation(self):
    """Ensure that invalid values cannot be assigned to decimal fields.
    """
    class Person(Document):
        height = DecimalField(min_value=Decimal('0.1'),
                              max_value=Decimal('3.5'))

    Person.drop_collection()

    original_height = Decimal('1.89')
    Person(height=original_height).save()

    person = Person.objects.first()
    self.assertEqual(person.height, Decimal('1.89'))

    # A numeric string within range can be saved
    person.height = '2.0'
    person.save()

    # Out-of-range values are rejected whether given as float or Decimal
    out_of_range = (0.01, Decimal('0.01'), Decimal('4.0'))
    for bad_height in out_of_range:
        person.height = bad_height
        self.assertRaises(ValidationError, person.validate)

    Person.drop_collection()
def test_decimal_comparison(self):
    """Ensure that ``__gt`` on a decimal field gives the same result for a
    Decimal, an int and a string threshold.
    """
    class Person(Document):
        money = DecimalField()

    Person.drop_collection()

    for amount in (6, 8, 10):
        Person(money=amount).save()

    # Two of the three saved amounts are greater than seven
    for threshold in (Decimal("7"), 7, "7"):
        richer = Person.objects(money__gt=threshold)
        self.assertEqual(2, richer.count())
def test_decimal_storage(self):
    """Ensure that a decimal field with precision=4 is stored as a float
    and read back as a Decimal with four decimal places.
    """
    class Person(Document):
        btc = DecimalField(precision=4)

    Person.drop_collection()

    amounts = (10, 10.1, 10.11, "10.111",
               Decimal("10.1111"), Decimal("10.11111"))
    for amount in amounts:
        Person(btc=amount).save()

    # How its stored
    stored = list(Person.objects.exclude('id').as_pymongo())
    self.assertEqual(
        [{'btc': 10.0}, {'btc': 10.1}, {'btc': 10.11},
         {'btc': 10.111}, {'btc': 10.1111}, {'btc': 10.1111}],
        stored)

    # How it comes out locally
    loaded = list(Person.objects().scalar('btc'))
    self.assertEqual(
        [Decimal('10.0000'), Decimal('10.1000'), Decimal('10.1100'),
         Decimal('10.1110'), Decimal('10.1111'), Decimal('10.1111')],
        loaded)
def test_boolean_validation(self):
    """Ensure that invalid values cannot be assigned to boolean fields.
    """
    class Person(Document):
        admin = BooleanField()

    person = Person()

    person.admin = True
    person.validate()

    # Neither an int nor a string is accepted in place of a bool
    for not_a_bool in (2, 'Yes'):
        person.admin = not_a_bool
        self.assertRaises(ValidationError, person.validate)
def test_uuid_field_string(self):
    """Test UUID fields storing as String
    """
    class Person(Document):
        api_key = UUIDField(binary=False)

    Person.drop_collection()

    # Round trip: the saved key can be queried and is read back unchanged
    saved_key = uuid.uuid4()
    Person(api_key=saved_key).save()
    self.assertEqual(1, Person.objects(api_key=saved_key).count())
    self.assertEqual(saved_key, Person.objects.first().api_key)

    person = Person()

    for good_key in (uuid.uuid4(), uuid.uuid1()):
        person.api_key = good_key
        person.validate()

    # One key has a non-hex trailing character, the other is a digit short
    malformed_keys = ('9d159858-549b-4975-9f98-dd2f987c113g',
                      '9d159858-549b-4975-9f98-dd2f987c113')
    for bad_key in malformed_keys:
        person.api_key = bad_key
        self.assertRaises(ValidationError, person.validate)
def test_uuid_field_binary(self):
    """Test UUID fields storing as Binary object
    """
    class Person(Document):
        api_key = UUIDField(binary=True)

    Person.drop_collection()

    # Round trip: the saved key can be queried and is read back unchanged
    saved_key = uuid.uuid4()
    Person(api_key=saved_key).save()
    self.assertEqual(1, Person.objects(api_key=saved_key).count())
    self.assertEqual(saved_key, Person.objects.first().api_key)

    person = Person()

    for good_key in (uuid.uuid4(), uuid.uuid1()):
        person.api_key = good_key
        person.validate()

    # One key has a non-hex trailing character, the other is a digit short
    malformed_keys = ('9d159858-549b-4975-9f98-dd2f987c113g',
                      '9d159858-549b-4975-9f98-dd2f987c113')
    for bad_key in malformed_keys:
        person.api_key = bad_key
        self.assertRaises(ValidationError, person.validate)
def test_datetime_validation(self):
    """Ensure that invalid values cannot be assigned to datetime fields.
    """
    class LogEntry(Document):
        time = DateTimeField()

    entry = LogEntry()

    # Both datetime and date instances are accepted
    entry.time = datetime.datetime.now()
    entry.validate()
    entry.time = datetime.date.today()
    entry.validate()

    # An int and a string that is not a full timestamp are rejected
    for bad_time in (-1, '1pm'):
        entry.time = bad_time
        self.assertRaises(ValidationError, entry.validate)
def test_datetime_tz_aware_mark_as_changed(self):
    """Ensure that assigning a naive datetime to a document loaded over a
    tz_aware connection marks the field as changed.
    """
    from mongoengine import connection

    # Reset the connections
    for registry_name in ('_connection_settings', '_connections', '_dbs'):
        setattr(connection, registry_name, {})

    connect(db='mongoenginetest', tz_aware=True)

    class LogEntry(Document):
        time = DateTimeField()

    LogEntry.drop_collection()

    naive_stamp = datetime.datetime(2013, 1, 1, 0, 0, 0)
    LogEntry(time=naive_stamp).save()

    loaded = LogEntry.objects.first()
    loaded.time = naive_stamp
    self.assertEqual(['time'], loaded._changed_fields)
def test_datetime(self):
    """Tests showing pymongo datetime fields handling of microseconds.

    Microseconds are truncated to millisecond precision when stored, and
    handling of dates before 1970 differs between Python versions.

    See: http://api.mongodb.org/python/current/api/bson/son.html#dt
    """
    class LogEntry(Document):
        date = DateTimeField()

    LogEntry.drop_collection()

    # Test can save dates
    log = LogEntry()
    log.date = datetime.date.today()
    log.save()
    log.reload()
    self.assertEqual(log.date.date(), datetime.date.today())

    LogEntry.drop_collection()

    # After 1970: microseconds below one millisecond are dropped entirely
    d1 = datetime.datetime(1970, 1, 1, 0, 0, 1, 999)
    d2 = datetime.datetime(1970, 1, 1, 0, 0, 1)
    log = LogEntry()
    log.date = d1
    log.save()
    log.reload()
    self.assertNotEqual(log.date, d1)
    self.assertEqual(log.date, d2)

    # After 1970: microseconds are rounded down to the whole millisecond
    d1 = datetime.datetime(1970, 1, 1, 0, 0, 1, 9999)
    d2 = datetime.datetime(1970, 1, 1, 0, 0, 1, 9000)
    log.date = d1
    log.save()
    log.reload()
    self.assertNotEqual(log.date, d1)
    self.assertEqual(log.date, d2)

    if not PY3:
        # Before 1970: microseconds below 1000 are dropped
        # This does not seem to be true in PY3
        d1 = datetime.datetime(1969, 12, 31, 23, 59, 59, 999)
        d2 = datetime.datetime(1969, 12, 31, 23, 59, 59)
        log.date = d1
        log.save()
        log.reload()
        self.assertNotEqual(log.date, d1)
        self.assertEqual(log.date, d2)

    LogEntry.drop_collection()
def test_complexdatetime_storage(self):
    """Tests for complex datetime fields - which can handle microseconds
    without rounding.
    """
    class LogEntry(Document):
        date = ComplexDateTimeField()

    LogEntry.drop_collection()

    # After 1970, microseconds below one millisecond: a plain datetime
    # field would drop them, here the value must round-trip unchanged
    d1 = datetime.datetime(1970, 1, 1, 0, 0, 1, 999)
    log = LogEntry()
    log.date = d1
    log.save()
    log.reload()
    self.assertEqual(log.date, d1)

    # After 1970, microseconds above one millisecond: a plain datetime
    # field would round them down, here the value must round-trip unchanged
    d1 = datetime.datetime(1970, 1, 1, 0, 0, 1, 9999)
    log.date = d1
    log.save()
    log.reload()
    self.assertEqual(log.date, d1)

    # Before 1970, microseconds below 1000: also preserved
    d1 = datetime.datetime(1969, 12, 31, 23, 59, 59, 999)
    log.date = d1
    log.save()
    log.reload()
    self.assertEqual(log.date, d1)

    # Before 1970, microseconds above 1000: sample a spread of values and
    # check each both round-trips and can be used to query the entry back
    for i in range(1001, 3113, 33):
        d1 = datetime.datetime(1969, 12, 31, 23, 59, 59, i)
        log.date = d1
        log.save()
        log.reload()
        self.assertEqual(log.date, d1)
        log1 = LogEntry.objects.get(date=d1)
        self.assertEqual(log, log1)

    LogEntry.drop_collection()
def test_complexdatetime_usage(self):
    """Tests for complex datetime fields - which can handle microseconds
    without rounding.

    Covers exact-match lookups, ordering and range queries.
    """
    class LogEntry(Document):
        date = ComplexDateTimeField()

    LogEntry.drop_collection()

    d1 = datetime.datetime(1970, 1, 1, 0, 0, 1, 999)
    log = LogEntry()
    log.date = d1
    log.save()

    log1 = LogEntry.objects.get(date=d1)
    self.assertEqual(log, log1)

    LogEntry.drop_collection()

    # create 60 log entries, one per year from 1950 to 2009
    for year in range(1950, 2010):
        d = datetime.datetime(year, 1, 1, 0, 0, 1, 999)
        LogEntry(date=d).save()

    self.assertEqual(LogEntry.objects.count(), 60)

    # Test ordering: every adjacent pair must respect the requested order.
    # The length check guards against the pairwise loop passing vacuously.
    ascending = list(LogEntry.objects.order_by("date"))
    self.assertEqual(len(ascending), 60)
    for earlier, later in zip(ascending, ascending[1:]):
        self.assertTrue(earlier.date <= later.date)

    descending = list(LogEntry.objects.order_by("-date"))
    self.assertEqual(len(descending), 60)
    for later, earlier in zip(descending, descending[1:]):
        self.assertTrue(later.date >= earlier.date)

    # Test searching
    logs = LogEntry.objects.filter(date__gte=datetime.datetime(1980, 1, 1))
    self.assertEqual(logs.count(), 30)

    logs = LogEntry.objects.filter(date__lte=datetime.datetime(1980, 1, 1))
    self.assertEqual(logs.count(), 30)

    logs = LogEntry.objects.filter(
        date__lte=datetime.datetime(2011, 1, 1),
        date__gte=datetime.datetime(2000, 1, 1),
    )
    self.assertEqual(logs.count(), 10)

    LogEntry.drop_collection()
def test_list_validation(self):
    """Ensure that a list field only accepts lists with valid elements.
    """
    class User(Document):
        pass

    class Comment(EmbeddedDocument):
        content = StringField()

    class BlogPost(Document):
        content = StringField()
        comments = ListField(EmbeddedDocumentField(Comment))
        tags = ListField(StringField())
        authors = ListField(ReferenceField(User))
        generic = ListField(GenericReferenceField())

    post = BlogPost(content='Went for a walk today...')
    post.validate()

    # tags: must be a list or tuple of strings
    for bad_tags in ('fun', [1, 2]):
        post.tags = bad_tags
        self.assertRaises(ValidationError, post.validate)
    for good_tags in (['fun', 'leisure'], ('fun', 'leisure')):
        post.tags = good_tags
        post.validate()

    # comments: must be a list of Comment embedded documents
    for bad_comments in (['a'], 'yay'):
        post.comments = bad_comments
        self.assertRaises(ValidationError, post.validate)
    post.comments = [Comment(content='Good for you'),
                     Comment(content='Yay.')]
    post.validate()

    # authors: must reference User documents that have been saved
    for bad_authors in ([Comment()], [User()]):
        post.authors = bad_authors
        self.assertRaises(ValidationError, post.validate)

    saved_user = User()
    saved_user.save()
    post.authors = [saved_user]
    post.validate()

    # generic: only saved documents are valid generic references
    for bad_generic in ([1, 2], [User(), Comment()], [Comment()]):
        post.generic = bad_generic
        self.assertRaises(ValidationError, post.validate)
    post.generic = [saved_user]
    post.validate()

    for document_cls in (User, BlogPost):
        document_cls.drop_collection()
def test_sorted_list_sorting(self):
    """Ensure that a sorted list field properly sorts values.
    """
    class Comment(EmbeddedDocument):
        order = IntField()
        content = StringField()

    class BlogPost(Document):
        content = StringField()
        comments = SortedListField(EmbeddedDocumentField(Comment),
                                   ordering='order')
        tags = SortedListField(StringField())

    post = BlogPost(content='Went for a walk today...')
    post.save()

    # Plain values come back in sorted order after a save and reload
    post.tags = ['leisure', 'fun']
    post.save()
    post.reload()
    self.assertEqual(post.tags, ['fun', 'leisure'])

    # Embedded documents come back ordered by their 'order' field
    ordered_second = Comment(content='Good for you', order=1)
    ordered_first = Comment(content='Yay.', order=0)
    post.comments = [ordered_second, ordered_first]
    post.save()
    post.reload()
    self.assertEqual(post.comments[0].content, ordered_first.content)
    self.assertEqual(post.comments[1].content, ordered_second.content)

    BlogPost.drop_collection()
def test_reverse_list_sorting(self):
    '''Ensure that a reverse sorted list field properly sorts values'''
    class Category(EmbeddedDocument):
        count = IntField()
        name = StringField()

    class CategoryList(Document):
        categories = SortedListField(EmbeddedDocumentField(Category),
                                     ordering='count', reverse=True)
        name = StringField()

    catlist = CategoryList(name="Top categories")
    posts = Category(name='posts', count=10)
    food = Category(name='food', count=100)
    drink = Category(name='drink', count=40)
    catlist.categories = [posts, food, drink]
    catlist.save()
    catlist.reload()

    # Highest count first: food (100), drink (40), posts (10)
    for position, expected in enumerate((food, drink, posts)):
        self.assertEqual(catlist.categories[position].name, expected.name)

    CategoryList.drop_collection()
def test_list_field(self):
    """Ensure that list types work as expected.
    """
    class BlogPost(Document):
        info = ListField()

    BlogPost.drop_collection()

    first_post = BlogPost()

    # Neither a string nor a dict is a list
    for not_a_list in ('my post', {'title': 'test'}):
        first_post.info = not_a_list
        self.assertRaises(ValidationError, first_post.validate)

    first_post.info = ['test']
    first_post.save()

    for info in ([{'test': 'test'}], [{'test': 3}]):
        another_post = BlogPost()
        another_post.info = info
        another_post.save()

    posts = BlogPost.objects
    self.assertEqual(posts.count(), 3)
    self.assertEqual(posts.filter(info__exact='test').count(), 1)
    self.assertEqual(posts.filter(info__0__test='test').count(), 1)

    # Confirm handles non strings or non existing keys
    self.assertEqual(posts.filter(info__0__test__exact='5').count(), 0)
    self.assertEqual(posts.filter(info__100__test__exact='test').count(), 0)

    BlogPost.drop_collection()
def test_list_field_passed_in_value(self):
    """Ensure that a list passed to the constructor can be appended to and
    shows the appended document in its repr.
    """
    class Foo(Document):
        bars = ListField(ReferenceField("Bar"))

    class Bar(Document):
        text = StringField()

    referenced = Bar(text="hi")
    referenced.save()

    holder = Foo(bars=[])
    holder.bars.append(referenced)
    self.assertEqual(repr(holder.bars), '[<Bar: Bar object>]')
def test_list_field_strict(self):
    """Ensure that list field handles validation if provided a strict field type."""
    class Simple(Document):
        mapping = ListField(field=IntField())

    Simple.drop_collection()

    doc = Simple()
    doc.mapping = [1]
    doc.save()

    def save_with_string_item():
        # A string item does not satisfy the IntField element type
        doc.mapping = ["abc"]
        doc.save()

    self.assertRaises(ValidationError, save_with_string_item)

    Simple.drop_collection()
def test_list_field_rejects_strings(self):
    """Strings aren't valid list field data types"""
    class Simple(Document):
        mapping = ListField()

    Simple.drop_collection()

    doc = Simple()
    doc.mapping = 'hello world'

    # Saving validates first, so the string never reaches the database
    self.assertRaises(ValidationError, doc.save)
def test_complex_field_required(self):
    """Ensure required cant be None / Empty"""
    # The same check is run for a required list and a required dict; the
    # document class is redefined under the same name for each field type.
    for field_type, empty_value in ((ListField, []), (DictField, {})):
        class Simple(Document):
            mapping = field_type(required=True)

        Simple.drop_collection()

        doc = Simple()
        doc.mapping = empty_value
        self.assertRaises(ValidationError, doc.save)
def test_list_field_complex(self):
    """Ensure that the list fields can handle the complex types."""
    class SettingBase(EmbeddedDocument):
        meta = {'allow_inheritance': True}

    class StringSetting(SettingBase):
        value = StringField()

    class IntegerSetting(SettingBase):
        value = IntField()

    class Simple(Document):
        mapping = ListField()

    Simple.drop_collection()

    doc = Simple()
    doc.mapping.append(StringSetting(value='foo'))
    doc.mapping.append(IntegerSetting(value=42))
    nested_item = {'number': 1, 'string': 'Hi!', 'float': 1.001,
                   'complex': IntegerSetting(value=42),
                   'list': [IntegerSetting(value=42),
                            StringSetting(value='foo')]}
    doc.mapping.append(nested_item)
    doc.save()

    # Embedded documents come back as their concrete subclasses
    reloaded = Simple.objects.get(id=doc.id)
    self.assertTrue(isinstance(reloaded.mapping[0], StringSetting))
    self.assertTrue(isinstance(reloaded.mapping[1], IntegerSetting))

    # Test querying
    single_match_lookups = (
        ('mapping__1__value', 42),
        ('mapping__2__number', 1),
        ('mapping__2__complex__value', 42),
        ('mapping__2__list__0__value', 42),
        ('mapping__2__list__1__value', 'foo'),
    )
    for lookup, wanted in single_match_lookups:
        found = Simple.objects.filter(**{lookup: wanted})
        self.assertEqual(found.count(), 1)

    # Confirm can update
    Simple.objects().update(set__mapping__1=IntegerSetting(value=10))
    self.assertEqual(Simple.objects.filter(mapping__1__value=10).count(), 1)

    Simple.objects().update(
        set__mapping__2__list__1=StringSetting(value='Boo'))
    for wanted, expected_count in (('foo', 0), ('Boo', 1)):
        found = Simple.objects.filter(mapping__2__list__1__value=wanted)
        self.assertEqual(found.count(), expected_count)

    Simple.drop_collection()
def test_dict_field(self):
    """Ensure that dict types work as expected.
    """
    class BlogPost(Document):
        info = DictField()

    BlogPost.drop_collection()

    first_post = BlogPost()

    # Non-dict values, and dicts whose keys start with '$', contain '.'
    # or are not strings, are all rejected
    rejected_values = (
        'my post',
        ['test', 'test'],
        {'$title': 'test'},
        {'the.title': 'test'},
        {1: 'test'},
    )
    for bad_info in rejected_values:
        first_post.info = bad_info
        self.assertRaises(ValidationError, first_post.validate)

    first_post.info = {'title': 'test'}
    first_post.save()

    for info in ({'details': {'test': 'test'}}, {'details': {'test': 3}}):
        another_post = BlogPost()
        another_post.info = info
        another_post.save()

    posts = BlogPost.objects
    self.assertEqual(posts.count(), 3)
    self.assertEqual(posts.filter(info__title__exact='test').count(), 1)
    self.assertEqual(posts.filter(info__details__test__exact='test').count(), 1)

    # Confirm handles non strings or non existing keys
    self.assertEqual(posts.filter(info__details__test__exact=5).count(), 0)
    self.assertEqual(posts.filter(info__made_up__test__exact='test').count(), 0)

    # An in-place change to the dict is persisted by save()
    created = BlogPost.objects.create(info={'title': 'original'})
    created.info.update({'title': 'updated'})
    created.save()
    created.reload()
    self.assertEqual('updated', created.info['title'])

    BlogPost.drop_collection()
def test_dictfield_strict(self):
    """Ensure that dict field handles validation if provided a strict field type."""
    class Simple(Document):
        mapping = DictField(field=IntField())

    Simple.drop_collection()

    doc = Simple()
    doc.mapping['someint'] = 1
    doc.save()

    def save_with_string_value():
        # A string value does not satisfy the IntField value type
        doc.mapping['somestring'] = "abc"
        doc.save()

    self.assertRaises(ValidationError, save_with_string_value)

    Simple.drop_collection()
def test_dictfield_complex(self):
    """Ensure that the dict field can handle the complex types."""
    class SettingBase(EmbeddedDocument):
        meta = {'allow_inheritance': True}

    class StringSetting(SettingBase):
        value = StringField()

    class IntegerSetting(SettingBase):
        value = IntField()

    class Simple(Document):
        mapping = DictField()

    Simple.drop_collection()

    doc = Simple()
    doc.mapping['somestring'] = StringSetting(value='foo')
    doc.mapping['someint'] = IntegerSetting(value=42)
    nested_settings = {'number': 1, 'string': 'Hi!',
                       'float': 1.001,
                       'complex': IntegerSetting(value=42),
                       'list': [IntegerSetting(value=42),
                                StringSetting(value='foo')]}
    doc.mapping['nested_dict'] = nested_settings
    doc.save()

    # Embedded documents come back as their concrete subclasses
    reloaded = Simple.objects.get(id=doc.id)
    self.assertTrue(isinstance(reloaded.mapping['somestring'], StringSetting))
    self.assertTrue(isinstance(reloaded.mapping['someint'], IntegerSetting))

    # Test querying
    single_match_lookups = (
        ('mapping__someint__value', 42),
        ('mapping__nested_dict__number', 1),
        ('mapping__nested_dict__complex__value', 42),
        ('mapping__nested_dict__list__0__value', 42),
        ('mapping__nested_dict__list__1__value', 'foo'),
    )
    for lookup, wanted in single_match_lookups:
        found = Simple.objects.filter(**{lookup: wanted})
        self.assertEqual(found.count(), 1)

    # Confirm can update
    Simple.objects().update(
        set__mapping={"someint": IntegerSetting(value=10)})
    Simple.objects().update(
        set__mapping__nested_dict__list__1=StringSetting(value='Boo'))
    for wanted, expected_count in (('foo', 0), ('Boo', 1)):
        found = Simple.objects.filter(
            mapping__nested_dict__list__1__value=wanted)
        self.assertEqual(found.count(), expected_count)

    Simple.drop_collection()
def test_mapfield(self):
    """Ensure that the MapField handles the declared type."""
    class Simple(Document):
        mapping = MapField(IntField())

    Simple.drop_collection()

    doc = Simple()
    doc.mapping['someint'] = 1
    doc.save()

    def save_with_string_value():
        # A string value does not satisfy the declared IntField
        doc.mapping['somestring'] = "abc"
        doc.save()

    self.assertRaises(ValidationError, save_with_string_value)

    def declare_untyped_map():
        # A MapField without a field type fails when the class is defined
        class NoDeclaredType(Document):
            mapping = MapField()

    self.assertRaises(ValidationError, declare_untyped_map)

    Simple.drop_collection()
def test_complex_mapfield(self):
    """Ensure that the MapField can handle complex declared types."""
    class SettingBase(EmbeddedDocument):
        meta = {"allow_inheritance": True}

    class StringSetting(SettingBase):
        value = StringField()

    class IntegerSetting(SettingBase):
        value = IntField()

    class Extensible(Document):
        mapping = MapField(EmbeddedDocumentField(SettingBase))

    Extensible.drop_collection()

    doc = Extensible()
    doc.mapping['somestring'] = StringSetting(value='foo')
    doc.mapping['someint'] = IntegerSetting(value=42)
    doc.save()

    # Subclasses of the declared embedded type keep their concrete class
    reloaded = Extensible.objects.get(id=doc.id)
    self.assertTrue(isinstance(reloaded.mapping['somestring'], StringSetting))
    self.assertTrue(isinstance(reloaded.mapping['someint'], IntegerSetting))

    def save_with_plain_int():
        # A bare int is not an embedded SettingBase document
        doc.mapping['someint'] = 123
        doc.save()

    self.assertRaises(ValidationError, save_with_plain_int)

    Extensible.drop_collection()
def test_embedded_mapfield_db_field(self):
    """Ensure that db_field names are used for a MapField of embedded
    documents, both in updates and in the stored document.
    """
    class Embedded(EmbeddedDocument):
        number = IntField(default=0, db_field='i')

    class Test(Document):
        my_map = MapField(field=EmbeddedDocumentField(Embedded),
                          db_field='x')

    Test.drop_collection()

    doc = Test()
    doc.my_map['DICTIONARY_KEY'] = Embedded(number=1)
    doc.save()

    # Increment through the document-level field names
    Test.objects.update_one(inc__my_map__DICTIONARY_KEY__number=1)

    reloaded = Test.objects.get()
    self.assertEqual(reloaded.my_map['DICTIONARY_KEY'].number, 2)

    # The raw document uses the db_field names 'x' and 'i'
    raw = self.db.test.find_one()
    self.assertEqual(raw['x']['DICTIONARY_KEY']['i'], 2)
def test_mapfield_numerical_index(self):
    """Ensure that MapField accept numeric strings as indexes."""
    class Embedded(EmbeddedDocument):
        name = StringField()

    class Test(Document):
        my_map = MapField(EmbeddedDocumentField(Embedded))

    Test.drop_collection()

    numeric_key = '1'
    doc = Test()
    doc.my_map[numeric_key] = Embedded(name='test')
    doc.save()

    # Changing the embedded document under the numeric key must also save
    doc.my_map[numeric_key].name = 'test updated'
    doc.save()

    Test.drop_collection()
def test_map_field_lookup(self):
    """Ensure MapField lookups succeed on Fields without a lookup method"""
    class Log(Document):
        name = StringField()
        visited = MapField(DateTimeField())

    Log.drop_collection()

    visits = {'friends': datetime.datetime.now()}
    Log(name="wilson", visited=visits).save()

    with_friends_visit = Log.objects(visited__friends__exists=True)
    self.assertEqual(1, with_friends_visit.count())
def test_embedded_db_field(self):
    """Ensure that db_field names are used for an embedded document field,
    both in updates and in the stored document.
    """
    class Embedded(EmbeddedDocument):
        number = IntField(default=0, db_field='i')

    class Test(Document):
        embedded = EmbeddedDocumentField(Embedded, db_field='x')

    Test.drop_collection()

    doc = Test()
    doc.embedded = Embedded(number=1)
    doc.save()

    # Increment through the document-level field names
    Test.objects.update_one(inc__embedded__number=1)

    reloaded = Test.objects.get()
    self.assertEqual(reloaded.embedded.number, 2)

    # The raw document uses the db_field names 'x' and 'i'
    raw = self.db.test.find_one()
    self.assertEqual(raw['x']['i'], 2)
def test_embedded_document_validation(self):
    """Ensure that invalid embedded documents cannot be assigned to
    embedded document fields.
    """
    class Comment(EmbeddedDocument):
        content = StringField()

    class PersonPreferences(EmbeddedDocument):
        food = StringField(required=True)
        number = IntField()

    class Person(Document):
        name = StringField()
        preferences = EmbeddedDocumentField(PersonPreferences)

    person = Person(name='Test User')

    rejected_preferences = (
        # Not an embedded document at all
        'My Preferences',
        # Check that only the right embedded doc works
        Comment(content='Nice blog post...'),
        # Check that the embedded doc is valid ('food' is required)
        PersonPreferences(),
    )
    for bad_preferences in rejected_preferences:
        person.preferences = bad_preferences
        self.assertRaises(ValidationError, person.validate)

    person.preferences = PersonPreferences(food='Cheese', number=47)
    self.assertEqual(person.preferences.food, 'Cheese')
    person.validate()
def test_embedded_document_inheritance(self):
    """Ensure that subclasses of embedded documents may be provided to
    EmbeddedDocumentFields of the superclass' type.
    """
    class User(EmbeddedDocument):
        name = StringField()

        meta = {'allow_inheritance': True}

    class PowerUser(User):
        power = IntField()

    class BlogPost(Document):
        content = StringField()
        author = EmbeddedDocumentField(User)

    subclass_author = PowerUser(name='Test User', power=47)
    post = BlogPost(content='What I did today...')
    post.author = subclass_author
    post.save()

    # The subclass-only field survives the round trip
    stored_post = BlogPost.objects.first()
    self.assertEqual(47, stored_post.author.power)
def test_reference_validation(self):
    """Ensure that invalid document objects cannot be assigned to
    reference fields.
    """

    class User(Document):
        name = StringField()

    class BlogPost(Document):
        content = StringField()
        author = ReferenceField(User)

    User.drop_collection()
    BlogPost.drop_collection()

    # Declaring a ReferenceField to EmbeddedDocument must be rejected.
    self.assertRaises(ValidationError, ReferenceField, EmbeddedDocument)

    author = User(name='Test User')

    # Ensure that the referenced object must have been saved
    gravy_post = BlogPost(content='Chips and gravy taste good.')
    gravy_post.author = author
    self.assertRaises(ValidationError, gravy_post.save)

    # Check that an invalid object type cannot be used
    chilli_post = BlogPost(content='Chips and chilli taste good.')
    gravy_post.author = chilli_post
    self.assertRaises(ValidationError, gravy_post.validate)

    # With the user saved, the reference can be stored.
    author.save()
    gravy_post.author = author
    gravy_post.save()

    # Saving the wrongly-typed document does not make it acceptable.
    chilli_post.save()
    gravy_post.author = chilli_post
    self.assertRaises(ValidationError, gravy_post.validate)

    User.drop_collection()
    BlogPost.drop_collection()
def test_dbref_reference_fields(self):
    """Ensure a ``dbref=True`` ReferenceField is stored as a DBRef and
    dereferences back to the referenced document.
    """

    class Person(Document):
        name = StringField()
        parent = ReferenceField('self', dbref=True)

    Person.drop_collection()

    john = Person(name="John").save()
    Person(name="Ross", parent=john).save()

    # Inspect the raw document to check the stored representation.
    raw_ross = Person._get_collection().find_one({'name': 'Ross'})
    self.assertEqual(raw_ross['parent'], DBRef('person', john.pk))

    ross = Person.objects.get(name="Ross")
    self.assertEqual(ross.parent, john)
def test_dbref_to_mongo(self):
    """Ensure ``to_mongo`` yields an ObjectId for a ``dbref=False``
    reference that was loaded from a string id.
    """

    class Person(Document):
        name = StringField()
        parent = ReferenceField('self', dbref=False)

    son = {'name': "Yakxxx",
           'parent': "50a234ea469ac1eda42d347d"}
    person = Person._from_son(son)

    serialised = person.to_mongo()
    self.assertTrue(isinstance(serialised['parent'], ObjectId))
def test_objectid_reference_fields(self):
    """Ensure a ``dbref=False`` ReferenceField stores the referenced
    primary key and dereferences back to the referenced document.
    """

    class Person(Document):
        name = StringField()
        parent = ReferenceField('self', dbref=False)

    Person.drop_collection()

    john = Person(name="John").save()
    Person(name="Ross", parent=john).save()

    # Inspect the raw document to check the stored representation.
    raw_ross = Person._get_collection().find_one({'name': 'Ross'})
    self.assertEqual(raw_ross['parent'], john.pk)

    ross = Person.objects.get(name="Ross")
    self.assertEqual(ross.parent, john)
def test_list_item_dereference(self):
    """Ensure that reference items held in a ListField are dereferenced
    when the owning document is loaded.
    """

    class User(Document):
        name = StringField()

    class Group(Document):
        members = ListField(ReferenceField(User))

    User.drop_collection()
    Group.drop_collection()

    first_user = User(name='user1')
    first_user.save()
    second_user = User(name='user2')
    second_user.save()

    Group(members=[first_user, second_user]).save()

    # Members must come back as documents, in the stored order.
    loaded_group = Group.objects.first()
    self.assertEqual(loaded_group.members[0].name, first_user.name)
    self.assertEqual(loaded_group.members[1].name, second_user.name)

    User.drop_collection()
    Group.drop_collection()
def test_recursive_reference(self):
    """Ensure that ReferenceFields can reference their own documents.
    """

    class Employee(Document):
        name = StringField()
        boss = ReferenceField('self')
        friends = ListField(ReferenceField('self'))

    Employee.drop_collection()

    bill = Employee(name='Bill Lumbergh')
    bill.save()

    michael = Employee(name='Michael Bolton')
    michael.save()
    samir = Employee(name='Samir Nagheenanajar')
    samir.save()
    colleagues = [michael, samir]

    peter = Employee(name='Peter Gibbons', boss=bill, friends=colleagues)
    peter.save()

    # Reload by id and check both the single and the list self-reference.
    reloaded_peter = Employee.objects.with_id(peter.id)
    self.assertEqual(reloaded_peter.boss, bill)
    self.assertEqual(reloaded_peter.friends, colleagues)
def test_recursive_embedding(self):
    """Ensure that EmbeddedDocumentFields can contain their own documents.
    """

    class Tree(Document):
        name = StringField()
        children = ListField(EmbeddedDocumentField('TreeNode'))

    class TreeNode(EmbeddedDocument):
        name = StringField()
        children = ListField(EmbeddedDocumentField('self'))

    Tree.drop_collection()

    # Build Tree -> "Child 1" -> "Child 2" and store it.
    tree = Tree(name="Tree")
    child_one = TreeNode(name="Child 1")
    tree.children.append(child_one)

    child_two = TreeNode(name="Child 2")
    child_one.children.append(child_two)
    tree.save()

    tree = Tree.objects.first()
    self.assertEqual(len(tree.children), 1)
    self.assertEqual(len(tree.children[0].children), 1)

    # Append another nested node to the reloaded tree.
    child_three = TreeNode(name="Child 3")
    tree.children[0].children.append(child_three)
    tree.save()

    self.assertEqual(len(tree.children), 1)
    self.assertEqual(tree.children[0].name, child_one.name)
    self.assertEqual(tree.children[0].children[0].name, child_two.name)
    self.assertEqual(tree.children[0].children[1].name, child_three.name)

    # Test updating
    tree.children[0].name = 'I am Child 1'
    tree.children[0].children[0].name = 'I am Child 2'
    tree.children[0].children[1].name = 'I am Child 3'
    tree.save()

    self.assertEqual(tree.children[0].name, 'I am Child 1')
    self.assertEqual(tree.children[0].children[0].name, 'I am Child 2')
    self.assertEqual(tree.children[0].children[1].name, 'I am Child 3')

    # Test removal
    self.assertEqual(len(tree.children[0].children), 2)
    del tree.children[0].children[1]

    tree.save()
    self.assertEqual(len(tree.children[0].children), 1)

    tree.children[0].children.pop(0)
    tree.save()
    self.assertEqual(len(tree.children[0].children), 0)
    self.assertEqual(tree.children[0].children, [])

    # Test insertion: two inserts at index 0, so the second one inserted
    # ends up first in the list.
    tree.children[0].children.insert(0, child_three)
    tree.children[0].children.insert(0, child_two)
    tree.save()
    self.assertEqual(len(tree.children[0].children), 2)
    self.assertEqual(tree.children[0].children[0].name, child_two.name)
    self.assertEqual(tree.children[0].children[1].name, child_three.name)
def test_undefined_reference(self):
    """Ensure that ReferenceFields may reference undefined Documents.
    """

    # Product refers to 'Company' by name before Company is defined.
    class Product(Document):
        name = StringField()
        company = ReferenceField('Company')

    class Company(Document):
        name = StringField()

    Product.drop_collection()
    Company.drop_collection()

    ten_gen = Company(name='10gen')
    ten_gen.save()
    mongodb = Product(name='MongoDB', company=ten_gen)
    mongodb.save()

    # A product with no company at all.
    mongoengine_product = Product(name='MongoEngine')
    mongoengine_product.save()

    with_company = Product.objects(company=ten_gen).first()
    self.assertEqual(with_company, mongodb)
    self.assertEqual(with_company.company, ten_gen)

    without_company = Product.objects(company=None).first()
    self.assertEqual(without_company, mongoengine_product)

    # get_or_create must find the existing document, not create one.
    found, created = Product.objects.get_or_create(company=None)

    self.assertEqual(created, False)
    self.assertEqual(found, mongoengine_product)
def test_reference_query_conversion(self):
    """Ensure that ``dbref=False`` ReferenceFields can be queried using
    the referenced objects when the target has a custom primary key.
    """

    class Member(Document):
        user_num = IntField(primary_key=True)

    class BlogPost(Document):
        title = StringField()
        author = ReferenceField(Member, dbref=False)

    Member.drop_collection()
    BlogPost.drop_collection()

    first_member = Member(user_num=1)
    first_member.save()
    second_member = Member(user_num=2)
    second_member.save()

    first_post = BlogPost(title='post 1', author=first_member)
    first_post.save()

    second_post = BlogPost(title='post 2', author=second_member)
    second_post.save()

    # Querying by author must return that author's post.
    found = BlogPost.objects(author=first_member).first()
    self.assertEqual(found.id, first_post.id)

    found = BlogPost.objects(author=second_member).first()
    self.assertEqual(found.id, second_post.id)

    Member.drop_collection()
    BlogPost.drop_collection()
def test_reference_query_conversion_dbref(self):
    """Ensure that ``dbref=True`` ReferenceFields can be queried using
    the referenced objects when the target has a custom primary key.
    """

    class Member(Document):
        user_num = IntField(primary_key=True)

    class BlogPost(Document):
        title = StringField()
        author = ReferenceField(Member, dbref=True)

    Member.drop_collection()
    BlogPost.drop_collection()

    first_member = Member(user_num=1)
    first_member.save()
    second_member = Member(user_num=2)
    second_member.save()

    first_post = BlogPost(title='post 1', author=first_member)
    first_post.save()

    second_post = BlogPost(title='post 2', author=second_member)
    second_post.save()

    # Querying by author must return that author's post.
    found = BlogPost.objects(author=first_member).first()
    self.assertEqual(found.id, first_post.id)

    found = BlogPost.objects(author=second_member).first()
    self.assertEqual(found.id, second_post.id)

    Member.drop_collection()
    BlogPost.drop_collection()
def test_generic_reference(self):
    """Ensure that a GenericReferenceField properly dereferences items.
    """

    class Link(Document):
        title = StringField()
        meta = {'allow_inheritance': False}

    class Post(Document):
        title = StringField()

    class Bookmark(Document):
        bookmark_object = GenericReferenceField()

    Link.drop_collection()
    Post.drop_collection()
    Bookmark.drop_collection()

    link = Link(title="Pitchfork")
    link.save()

    post = Post(title="Behind the Scenes of the Pavement Reunion")
    post.save()

    # First point the bookmark at a Post ...
    bookmark = Bookmark(bookmark_object=post)
    bookmark.save()

    bookmark = Bookmark.objects(bookmark_object=post).first()

    self.assertEqual(bookmark.bookmark_object, post)
    self.assertTrue(isinstance(bookmark.bookmark_object, Post))

    # ... then repoint the same bookmark at a Link.
    bookmark.bookmark_object = link
    bookmark.save()

    bookmark = Bookmark.objects(bookmark_object=link).first()

    self.assertEqual(bookmark.bookmark_object, link)
    self.assertTrue(isinstance(bookmark.bookmark_object, Link))

    Link.drop_collection()
    Post.drop_collection()
    Bookmark.drop_collection()
def test_generic_reference_list(self):
    """Ensure that a ListField properly dereferences generic references.
    """

    class Link(Document):
        title = StringField()

    class Post(Document):
        title = StringField()

    class User(Document):
        bookmarks = ListField(GenericReferenceField())

    Link.drop_collection()
    Post.drop_collection()
    User.drop_collection()

    link = Link(title="Pitchfork")
    link.save()

    post = Post(title="Behind the Scenes of the Pavement Reunion")
    post.save()

    # Mix two different document types in a single list.
    User(bookmarks=[post, link]).save()

    loaded_user = User.objects(bookmarks__all=[post, link]).first()

    self.assertEqual(loaded_user.bookmarks[0], post)
    self.assertEqual(loaded_user.bookmarks[1], link)

    Link.drop_collection()
    Post.drop_collection()
    User.drop_collection()
def test_generic_reference_document_not_registered(self):
    """Ensure dereferencing out of the document registry throws a
    `NotRegistered` error.
    """

    class Link(Document):
        title = StringField()

    class User(Document):
        bookmarks = ListField(GenericReferenceField())

    Link.drop_collection()
    User.drop_collection()

    link_1 = Link(title="Pitchfork")
    link_1.save()

    user = User(bookmarks=[link_1])
    user.save()

    # Mimic User and Link definitions being in a different file
    # and the Link model not being imported in the User file.
    removed_link_cls = _document_registry.pop("Link")
    try:
        user = User.objects.first()
        try:
            # Accessing the attribute is what is expected to raise.
            user.bookmarks
        except NotRegistered:
            pass
        else:
            raise AssertionError("Link was removed from the registry")
    finally:
        # Re-register the class so this test does not leave the shared
        # registry modified for whatever runs after it.
        _document_registry["Link"] = removed_link_cls

    Link.drop_collection()
    User.drop_collection()
def test_generic_reference_is_none(self):
    """Ensure documents can be queried for an unset
    GenericReferenceField using ``None``.
    """

    class Person(Document):
        name = StringField()
        city = GenericReferenceField()

    Person.drop_collection()

    # Saved without a city, so the generic reference is unset.
    Person(name="Wilson Jr").save()

    without_city = Person.objects(city=None)
    self.assertEqual(repr(without_city), "[<Person: Person object>]")
def test_generic_reference_choices(self):
    """Ensure that a GenericReferenceField can handle choices
    """

    class Link(Document):
        title = StringField()

    class Post(Document):
        title = StringField()

    class Bookmark(Document):
        bookmark_object = GenericReferenceField(choices=(Post,))

    Link.drop_collection()
    Post.drop_collection()
    Bookmark.drop_collection()

    link = Link(title="Pitchfork")
    link.save()

    post = Post(title="Behind the Scenes of the Pavement Reunion")
    post.save()

    # Link is not among the choices, so validation must fail.
    rejected = Bookmark(bookmark_object=link)
    self.assertRaises(ValidationError, rejected.validate)

    # Post is an allowed choice and must round-trip.
    Bookmark(bookmark_object=post).save()

    stored = Bookmark.objects.first()
    self.assertEqual(stored.bookmark_object, post)
def test_generic_reference_list_choices(self):
    """Ensure that a ListField properly dereferences generic references and
    respects choices.
    """

    class Link(Document):
        title = StringField()

    class Post(Document):
        title = StringField()

    class User(Document):
        bookmarks = ListField(GenericReferenceField(choices=(Post,)))

    Link.drop_collection()
    Post.drop_collection()
    User.drop_collection()

    link = Link(title="Pitchfork")
    link.save()

    post = Post(title="Behind the Scenes of the Pavement Reunion")
    post.save()

    # Link is not among the choices, so validation must fail.
    rejected = User(bookmarks=[link])
    self.assertRaises(ValidationError, rejected.validate)

    # Post is an allowed choice and must round-trip.
    User(bookmarks=[post]).save()

    stored = User.objects.first()
    self.assertEqual(stored.bookmarks, [post])

    Link.drop_collection()
    Post.drop_collection()
    User.drop_collection()
def test_binary_fields(self):
    """Ensure that binary fields can be stored and retrieved.
    """

    class Attachment(Document):
        content_type = StringField()
        blob = BinaryField()

    payload = b('\xe6\x00\xc4\xff\x07')
    mime_type = 'application/octet-stream'

    Attachment.drop_collection()

    Attachment(content_type=mime_type, blob=payload).save()

    stored = Attachment.objects().first()
    self.assertEqual(mime_type, stored.content_type)
    # Convert the stored value with bin_type before comparing it to the
    # original payload.
    self.assertEqual(payload, bin_type(stored.blob))

    Attachment.drop_collection()
def test_binary_validation(self):
    """Ensure that invalid values cannot be assigned to binary fields.
    """

    class Attachment(Document):
        blob = BinaryField()

    class AttachmentRequired(Document):
        blob = BinaryField(required=True)

    class AttachmentSizeLimit(Document):
        blob = BinaryField(max_bytes=4)

    Attachment.drop_collection()
    AttachmentRequired.drop_collection()
    AttachmentSizeLimit.drop_collection()

    # Optional field: unset is fine, a non-binary value is not.
    optional = Attachment()
    optional.validate()
    optional.blob = 2
    self.assertRaises(ValidationError, optional.validate)

    # Required field: unset fails, a Binary value passes.
    required = AttachmentRequired()
    self.assertRaises(ValidationError, required.validate)
    required.blob = Binary(b('\xe6\x00\xc4\xff\x07'))
    required.validate()

    # max_bytes=4: five bytes are rejected, four are accepted.
    size_limited = AttachmentSizeLimit(blob=b('\xe6\x00\xc4\xff\x07'))
    self.assertRaises(ValidationError, size_limited.validate)
    size_limited.blob = b('\xe6\x00\xc4\xff')
    size_limited.validate()

    Attachment.drop_collection()
    AttachmentRequired.drop_collection()
    AttachmentSizeLimit.drop_collection()
def test_binary_field_primary(self):
    """Ensure a document whose primary key is a BinaryField can be saved
    and then deleted.
    """

    class Attachment(Document):
        id = BinaryField(primary_key=True)

    Attachment.drop_collection()

    # Use the raw bytes of a random UUID as the primary key.
    binary_id = uuid.uuid4().bytes
    attachment = Attachment(id=binary_id).save()
    attachment.delete()

    self.assertEqual(0, Attachment.objects.count())
def test_choices_validation(self):
    """Ensure that a value outside the declared (value, label) choices
    fails validation.
    """
    size_choices = (
        ('S', 'Small'), ('M', 'Medium'), ('L', 'Large'),
        ('XL', 'Extra Large'), ('XXL', 'Extra Extra Large'))

    class Shirt(Document):
        size = StringField(max_length=3, choices=size_choices)

    Shirt.drop_collection()

    shirt = Shirt()
    # Unset is valid.
    shirt.validate()

    # A declared choice is valid.
    shirt.size = "S"
    shirt.validate()

    # "XS" is not one of the declared choices.
    shirt.size = "XS"
    self.assertRaises(ValidationError, shirt.validate)

    Shirt.drop_collection()
def test_choices_get_field_display(self):
    """Test dynamic helper for returning the display value of a choices
    field.
    """
    size_choices = (
        ('S', 'Small'), ('M', 'Medium'), ('L', 'Large'),
        ('XL', 'Extra Large'), ('XXL', 'Extra Extra Large'))
    style_choices = (('S', 'Small'), ('B', 'Baggy'), ('W', 'wide'))

    class Shirt(Document):
        size = StringField(max_length=3, choices=size_choices)
        style = StringField(max_length=3, choices=style_choices,
                            default='S')

    Shirt.drop_collection()

    shirt = Shirt()

    # Unset field gives None; the defaulted field gives its label.
    self.assertEqual(shirt.get_size_display(), None)
    self.assertEqual(shirt.get_style_display(), 'Small')

    shirt.size = "XXL"
    shirt.style = "B"
    self.assertEqual(shirt.get_size_display(), 'Extra Extra Large')
    self.assertEqual(shirt.get_style_display(), 'Baggy')

    # Set as Z - an invalid choice
    shirt.size = "Z"
    shirt.style = "Z"
    # The raw value is returned for display, but validation rejects it.
    self.assertEqual(shirt.get_size_display(), 'Z')
    self.assertEqual(shirt.get_style_display(), 'Z')
    self.assertRaises(ValidationError, shirt.validate)

    Shirt.drop_collection()
def test_simple_choices_validation(self):
    """Ensure that a value outside a flat tuple of choices fails
    validation.
    """
    size_choices = ('S', 'M', 'L', 'XL', 'XXL')

    class Shirt(Document):
        size = StringField(max_length=3, choices=size_choices)

    Shirt.drop_collection()

    shirt = Shirt()
    # Unset is valid.
    shirt.validate()

    # A declared choice is valid.
    shirt.size = "S"
    shirt.validate()

    # "XS" is not one of the declared choices.
    shirt.size = "XS"
    self.assertRaises(ValidationError, shirt.validate)

    Shirt.drop_collection()
def test_simple_choices_get_field_display(self):
    """Test dynamic helper for returning the display value of a choices
    field.
    """
    size_choices = ('S', 'M', 'L', 'XL', 'XXL')
    style_choices = ('Small', 'Baggy', 'wide')

    class Shirt(Document):
        size = StringField(max_length=3, choices=size_choices)
        style = StringField(max_length=3, choices=style_choices,
                            default='Small')

    Shirt.drop_collection()

    shirt = Shirt()

    # Unset field gives None; the defaulted field gives its default.
    self.assertEqual(shirt.get_size_display(), None)
    self.assertEqual(shirt.get_style_display(), 'Small')

    # With flat choices the display value is the stored value.
    shirt.size = "XXL"
    shirt.style = "Baggy"
    self.assertEqual(shirt.get_size_display(), 'XXL')
    self.assertEqual(shirt.get_style_display(), 'Baggy')

    # Set as Z - an invalid choice
    shirt.size = "Z"
    shirt.style = "Z"
    self.assertEqual(shirt.get_size_display(), 'Z')
    self.assertEqual(shirt.get_style_display(), 'Z')
    self.assertRaises(ValidationError, shirt.validate)

    Shirt.drop_collection()
def test_simple_choices_validation_invalid_value(self):
    """Ensure that error messages are correct.
    """
    SIZES = ('S', 'M', 'L', 'XL', 'XXL')
    COLORS = (('R', 'Red'), ('B', 'Blue'))
    SIZE_MESSAGE = u"Value must be one of ('S', 'M', 'L', 'XL', 'XXL')"
    COLOR_MESSAGE = u"Value must be one of ['R', 'B']"

    class Shirt(Document):
        size = StringField(max_length=3, choices=SIZES)
        color = StringField(max_length=1, choices=COLORS)

    Shirt.drop_collection()

    shirt = Shirt()
    shirt.validate()

    shirt.size = "S"
    shirt.color = "R"
    shirt.validate()

    # Neither value is among the declared choices.
    shirt.size = "XS"
    shirt.color = "G"

    try:
        shirt.validate()
    except ValidationError as error:
        # get the validation rules
        error_dict = error.to_dict()
        self.assertEqual(error_dict['size'], SIZE_MESSAGE)
        self.assertEqual(error_dict['color'], COLOR_MESSAGE)
    else:
        # Without this branch the message assertions above would be
        # skipped, and the test would pass, if nothing was raised.
        self.fail("ValidationError was not raised for invalid choices")

    Shirt.drop_collection()
def test_ensure_unique_default_instances(self):
    """Ensure that every document instance gets its own default value
    rather than sharing one mutable default."""

    class D(Document):
        data = DictField()
        data2 = DictField(default=lambda: {})

    # Mutate both dict fields on one instance ...
    first = D()
    first.data['foo'] = 'bar'
    first.data2['foo'] = 'bar'

    # ... and check that a fresh instance still starts out empty.
    second = D()
    self.assertEqual(second.data, {})
    self.assertEqual(second.data2, {})
def test_sequence_field(self):
    """Ensure a SequenceField primary key counts up from 1 and keeps its
    counter under 'person.id' in 'mongoengine.counters'.
    """

    class Person(Document):
        id = SequenceField(primary_key=True)
        name = StringField()

    counter_store = self.db['mongoengine.counters']
    counter_store.drop()
    Person.drop_collection()

    for x in xrange(10):
        Person(name="Person %s" % x).save()

    counter_doc = counter_store.find_one({'_id': 'person.id'})
    self.assertEqual(counter_doc['next'], 10)

    person_ids = [person.id for person in Person.objects]
    self.assertEqual(person_ids, list(range(1, 11)))

    # Reading the documents back must not advance the counter.
    counter_doc = counter_store.find_one({'_id': 'person.id'})
    self.assertEqual(counter_doc['next'], 10)

    Person.id.set_next_value(1000)
    counter_doc = counter_store.find_one({'_id': 'person.id'})
    self.assertEqual(counter_doc['next'], 1000)
def test_sequence_field_get_next_value(self):
    """Ensure ``get_next_value`` reports the upcoming sequence value,
    both with and without a ``value_decorator``.
    """

    class Person(Document):
        id = SequenceField(primary_key=True)
        name = StringField()

    counter_store = self.db['mongoengine.counters']
    counter_store.drop()
    Person.drop_collection()

    for x in xrange(10):
        Person(name="Person %s" % x).save()

    self.assertEqual(Person.id.get_next_value(), 11)
    # With the counters collection dropped the sequence starts over.
    counter_store.drop()

    self.assertEqual(Person.id.get_next_value(), 1)

    # Same scenario again, with values passed through ``str``.
    class Person(Document):
        id = SequenceField(primary_key=True, value_decorator=str)
        name = StringField()

    counter_store.drop()
    Person.drop_collection()

    for x in xrange(10):
        Person(name="Person %s" % x).save()

    self.assertEqual(Person.id.get_next_value(), '11')
    counter_store.drop()

    self.assertEqual(Person.id.get_next_value(), '1')
def test_sequence_field_sequence_name(self):
    """Ensure a custom ``sequence_name`` is used as the prefix of the
    counter id ('jelly.id').
    """

    class Person(Document):
        id = SequenceField(primary_key=True, sequence_name='jelly')
        name = StringField()

    counter_store = self.db['mongoengine.counters']
    counter_store.drop()
    Person.drop_collection()

    for x in xrange(10):
        Person(name="Person %s" % x).save()

    counter_doc = counter_store.find_one({'_id': 'jelly.id'})
    self.assertEqual(counter_doc['next'], 10)

    person_ids = [person.id for person in Person.objects]
    self.assertEqual(person_ids, list(range(1, 11)))

    # Reading the documents back must not advance the counter.
    counter_doc = counter_store.find_one({'_id': 'jelly.id'})
    self.assertEqual(counter_doc['next'], 10)

    Person.id.set_next_value(1000)
    counter_doc = counter_store.find_one({'_id': 'jelly.id'})
    self.assertEqual(counter_doc['next'], 1000)
def test_multiple_sequence_fields(self):
    """Ensure two SequenceFields on one document keep independent
    counters ('person.id' and 'person.counter').
    """

    class Person(Document):
        id = SequenceField(primary_key=True)
        counter = SequenceField()
        name = StringField()

    counter_store = self.db['mongoengine.counters']
    counter_store.drop()
    Person.drop_collection()

    for x in xrange(10):
        Person(name="Person %s" % x).save()

    counter_doc = counter_store.find_one({'_id': 'person.id'})
    self.assertEqual(counter_doc['next'], 10)

    person_ids = [person.id for person in Person.objects]
    self.assertEqual(person_ids, list(range(1, 11)))

    person_counters = [person.counter for person in Person.objects]
    self.assertEqual(person_counters, list(range(1, 11)))

    # Reading the documents back must not advance the counter.
    counter_doc = counter_store.find_one({'_id': 'person.id'})
    self.assertEqual(counter_doc['next'], 10)

    # Each sequence can be repositioned on its own.
    Person.id.set_next_value(1000)
    counter_doc = counter_store.find_one({'_id': 'person.id'})
    self.assertEqual(counter_doc['next'], 1000)

    Person.counter.set_next_value(999)
    counter_doc = counter_store.find_one({'_id': 'person.counter'})
    self.assertEqual(counter_doc['next'], 999)
def test_sequence_fields_reload(self):
    """Ensure a SequenceField value is stable across reloads, and that
    setting it to None makes the next read return a new value.
    """

    class Animal(Document):
        counter = SequenceField()
        name = StringField()

    self.db['mongoengine.counters'].drop()
    Animal.drop_collection()

    animal = Animal(name="Boi").save()

    self.assertEqual(animal.counter, 1)
    animal.reload()
    self.assertEqual(animal.counter, 1)

    # After assigning None the field reads back as 2.
    animal.counter = None
    self.assertEqual(animal.counter, 2)
    animal.save()

    self.assertEqual(animal.counter, 2)

    # The new value was persisted and survives fetch and reload.
    animal = Animal.objects.first()
    self.assertEqual(animal.counter, 2)
    animal.reload()
    self.assertEqual(animal.counter, 2)
def test_multiple_sequence_fields_on_docs(self):
    """Ensure SequenceFields on different document classes keep
    independent counters ('animal.id' and 'person.id').
    """

    # ``name`` is declared on both classes because the loop below passes
    # it to their constructors.
    class Animal(Document):
        id = SequenceField(primary_key=True)
        name = StringField()

    class Person(Document):
        id = SequenceField(primary_key=True)
        name = StringField()

    counter_store = self.db['mongoengine.counters']
    counter_store.drop()
    Animal.drop_collection()
    Person.drop_collection()

    for x in xrange(10):
        Animal(name="Animal %s" % x).save()
        Person(name="Person %s" % x).save()

    c = counter_store.find_one({'_id': 'person.id'})
    self.assertEqual(c['next'], 10)

    c = counter_store.find_one({'_id': 'animal.id'})
    self.assertEqual(c['next'], 10)

    person_ids = [i.id for i in Person.objects]
    self.assertEqual(person_ids, list(range(1, 11)))

    # Named ``animal_ids`` rather than ``id`` to avoid shadowing the
    # builtin.
    animal_ids = [i.id for i in Animal.objects]
    self.assertEqual(animal_ids, list(range(1, 11)))

    # Reading the documents back must not advance either counter.
    c = counter_store.find_one({'_id': 'person.id'})
    self.assertEqual(c['next'], 10)

    c = counter_store.find_one({'_id': 'animal.id'})
    self.assertEqual(c['next'], 10)
def test_sequence_field_value_decorator(self):
    """Ensure ``value_decorator`` is applied to generated sequence
    values while the stored counter stays numeric.
    """

    class Person(Document):
        id = SequenceField(primary_key=True, value_decorator=str)
        name = StringField()

    counter_store = self.db['mongoengine.counters']
    counter_store.drop()
    Person.drop_collection()

    for x in xrange(10):
        person = Person(name="Person %s" % x)
        person.save()

    counter_doc = counter_store.find_one({'_id': 'person.id'})
    self.assertEqual(counter_doc['next'], 10)

    # The ids come back as the string form of 1..10.
    person_ids = [person.id for person in Person.objects]
    self.assertEqual(person_ids, [str(number) for number in range(1, 11)])

    counter_doc = counter_store.find_one({'_id': 'person.id'})
    self.assertEqual(counter_doc['next'], 10)
def test_embedded_sequence_field(self):
    """Ensure a SequenceField on an embedded document numbers the
    embedded instances under the 'comment.id' counter.
    """

    class Comment(EmbeddedDocument):
        id = SequenceField()
        content = StringField(required=True)

    class Post(Document):
        title = StringField(required=True)
        comments = ListField(EmbeddedDocumentField(Comment))

    self.db['mongoengine.counters'].drop()
    Post.drop_collection()

    new_post = Post(title="MongoEngine",
                    comments=[Comment(content="NoSQL Rocks"),
                              Comment(content="MongoEngine Rocks")])
    new_post.save()

    counter_doc = self.db['mongoengine.counters'].find_one(
        {'_id': 'comment.id'})
    self.assertEqual(counter_doc['next'], 2)

    # The comments are numbered in list order.
    stored_post = Post.objects.first()
    self.assertEqual(1, stored_post.comments[0].id)
    self.assertEqual(2, stored_post.comments[1].id)
def test_generic_embedded_document(self):
    """Ensure a GenericEmbeddedDocumentField can hold different embedded
    document classes and gives back the right class on load.
    """

    class Car(EmbeddedDocument):
        name = StringField()

    class Dish(EmbeddedDocument):
        food = StringField(required=True)
        number = IntField()

    class Person(Document):
        name = StringField()
        like = GenericEmbeddedDocumentField()

    Person.drop_collection()

    # Store a Car first ...
    person = Person(name='Test User')
    person.like = Car(name='Fiat')
    person.save()

    person = Person.objects.first()
    self.assertTrue(isinstance(person.like, Car))

    # ... then replace it with a Dish on the same document.
    person.like = Dish(food="arroz", number=15)
    person.save()

    person = Person.objects.first()
    self.assertTrue(isinstance(person.like, Dish))
def test_generic_embedded_document_choices(self):
    """Ensure you can limit GenericEmbeddedDocument choices
    """

    class Car(EmbeddedDocument):
        name = StringField()

    class Dish(EmbeddedDocument):
        food = StringField(required=True)
        number = IntField()

    class Person(Document):
        name = StringField()
        like = GenericEmbeddedDocumentField(choices=(Dish,))

    Person.drop_collection()

    person = Person(name='Test User')

    # Car is not among the choices, so validation must fail.
    person.like = Car(name='Fiat')
    self.assertRaises(ValidationError, person.validate)

    # Dish is an allowed choice and must round-trip.
    person.like = Dish(food="arroz", number=15)
    person.save()

    stored_person = Person.objects.first()
    self.assertTrue(isinstance(stored_person.like, Dish))
def test_generic_list_embedded_document_choices(self):
    """Ensure you can limit GenericEmbeddedDocument choices inside a list
    field
    """

    class Car(EmbeddedDocument):
        name = StringField()

    class Dish(EmbeddedDocument):
        food = StringField(required=True)
        number = IntField()

    class Person(Document):
        name = StringField()
        likes = ListField(GenericEmbeddedDocumentField(choices=(Dish,)))

    Person.drop_collection()

    person = Person(name='Test User')

    # Car is not among the choices, so validation must fail.
    person.likes = [Car(name='Fiat')]
    self.assertRaises(ValidationError, person.validate)

    # Dish is an allowed choice and must round-trip.
    person.likes = [Dish(food="arroz", number=15)]
    person.save()

    stored_person = Person.objects.first()
    self.assertTrue(isinstance(stored_person.likes[0], Dish))
def test_recursive_validation(self):
    """Ensure that a validation result to_dict is available.
    """

    class Author(EmbeddedDocument):
        name = StringField(required=True)

    class Comment(EmbeddedDocument):
        author = EmbeddedDocumentField(Author, required=True)
        content = StringField(required=True)

    class Post(Document):
        title = StringField(required=True)
        comments = ListField(EmbeddedDocumentField(Comment))

    bob = Author(name='Bob')
    post = Post(title='hello world')
    post.comments.append(Comment(content='hello', author=bob))
    # The second comment (index 1) lacks its required ``content``.
    post.comments.append(Comment(author=bob))

    self.assertRaises(ValidationError, post.validate)
    try:
        post.validate()
    except ValidationError as error:
        # ValidationError.errors property
        self.assertTrue(hasattr(error, 'errors'))
        self.assertTrue(isinstance(error.errors, dict))
        self.assertTrue('comments' in error.errors)
        self.assertTrue(1 in error.errors['comments'])
        self.assertTrue(isinstance(error.errors['comments'][1]['content'],
                                   ValidationError))

        # ValidationError.schema property
        error_dict = error.to_dict()
        self.assertTrue(isinstance(error_dict, dict))
        self.assertTrue('comments' in error_dict)
        self.assertTrue(1 in error_dict['comments'])
        self.assertTrue('content' in error_dict['comments'][1])
        self.assertEqual(error_dict['comments'][1]['content'],
                         u'Field is required')

    # Supplying the missing content makes the whole post valid.
    post.comments[1].content = 'here we go'
    post.validate()
def test_email_field(self):
    """Ensure EmailField accepts the two well-formed addresses below and
    rejects 'me@localhost'.
    """

    class User(Document):
        email = EmailField()

    simple_address = "ross@example.com"
    long_address = ("Kofq@rhom0e4klgauOhpbpNdogawnyIKvQS0wk2mjqrgGQ5S"
                    "ucictfqpdkK9iS1zeFw8sg7s7cwAF7suIfUfeyueLpfosjn3"
                    "aJIazqqWkm7.net")

    user = User(email=simple_address)
    self.assertTrue(user.validate() is None)

    user = User(email=long_address)
    self.assertTrue(user.validate() is None)

    user = User(email='me@localhost')
    self.assertRaises(ValidationError, user.validate)
def test_email_field_honors_regex(self):
    """Ensure a ``regex`` passed to EmailField is applied during
    validation.
    """

    class User(Document):
        email = EmailField(regex=r'\w+@example.com')

    # Fails regex validation
    rejected = User(email='me@foo.com')
    self.assertRaises(ValidationError, rejected.validate)

    # Passes regex validation
    accepted = User(email='me@example.com')
    self.assertTrue(accepted.validate() is None)
# Run the tests in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()