diff --git a/docs/changelog.rst b/docs/changelog.rst index cd99b73e..0da97e90 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -4,7 +4,8 @@ Changelog Development =========== -- (Fill this out as you fix issues and develop you features). +- (Fill this out as you fix issues and develop your features). +- Fixed using sets in field choices #1481 - POTENTIAL BREAKING CHANGE: Fixed limit/skip/hint/batch_size chaining #1476 - POTENTIAL BREAKING CHANGE: Changed a public `QuerySet.clone_into` method to a private `QuerySet._clone_into` #1476 - Fixed connecting to a replica set with PyMongo 2.x #1436 diff --git a/docs/guide/defining-documents.rst b/docs/guide/defining-documents.rst index e656dee0..d41ae7e6 100644 --- a/docs/guide/defining-documents.rst +++ b/docs/guide/defining-documents.rst @@ -150,7 +150,7 @@ arguments can be set on all fields: .. note:: If set, this field is also accessible through the `pk` field. :attr:`choices` (Default: None) - An iterable (e.g. a list or tuple) of choices to which the value of this + An iterable (e.g. list, tuple or set) of choices to which the value of this field should be limited. Can be either be a nested tuples of value (stored in mongo) and a @@ -214,8 +214,8 @@ document class as the first argument:: Dictionary Fields ----------------- -Often, an embedded document may be used instead of a dictionary – generally -embedded documents are recommended as dictionaries don’t support validation +Often, an embedded document may be used instead of a dictionary – generally +embedded documents are recommended as dictionaries don’t support validation or custom field types. However, sometimes you will not know the structure of what you want to store; in this situation a :class:`~mongoengine.fields.DictField` is appropriate:: diff --git a/mongoengine/base/fields.py b/mongoengine/base/fields.py index 9ba9dc9a..5658b185 100644 --- a/mongoengine/base/fields.py +++ b/mongoengine/base/fields.py @@ -193,7 +193,8 @@ class BaseField(object): EmbeddedDocument = _import_class('EmbeddedDocument') choice_list = self.choices - if isinstance(choice_list[0], (list, tuple)): + if isinstance(next(iter(choice_list)), (list, tuple)): + # next(iter) is useful for sets choice_list = [k for k, _ in choice_list] # Choices which are other types of Documents diff --git a/tests/fields/fields.py b/tests/fields/fields.py index 5c83d58c..f24bcae4 100644 --- a/tests/fields/fields.py +++ b/tests/fields/fields.py @@ -1,13 +1,12 @@ # -*- coding: utf-8 -*- -import six -from nose.plugins.skip import SkipTest - import datetime import unittest import uuid import math import itertools import re + +from nose.plugins.skip import SkipTest import six try: @@ -27,21 +26,13 @@ from mongoengine import * from mongoengine.connection import get_db from mongoengine.base import (BaseDict, BaseField, EmbeddedDocumentList, _document_registry) -from mongoengine.errors import NotRegistered, DoesNotExist + +from tests.utils import MongoDBTestCase __all__ = ("FieldTest", "EmbeddedDocumentListFieldTestCase") -class FieldTest(unittest.TestCase): - - def setUp(self): - connect(db='mongoenginetest') - self.db = get_db() - - def tearDown(self): - self.db.drop_collection('fs.files') - self.db.drop_collection('fs.chunks') - self.db.drop_collection('mongoengine.counters') +class FieldTest(MongoDBTestCase): def test_default_values_nothing_set(self): """Ensure that default field values are used when creating a document. @@ -3186,26 +3177,42 @@ class FieldTest(unittest.TestCase): att.delete() self.assertEqual(0, Attachment.objects.count()) - def test_choices_validation(self): - """Ensure that value is in a container of allowed values. + def test_choices_allow_using_sets_as_choices(self): + """Ensure that sets can be used when setting choices """ class Shirt(Document): - size = StringField(max_length=3, choices=( - ('S', 'Small'), ('M', 'Medium'), ('L', 'Large'), - ('XL', 'Extra Large'), ('XXL', 'Extra Extra Large'))) + size = StringField(choices={'M', 'L'}) - Shirt.drop_collection() + Shirt(size='M').validate() + + def test_choices_validation_allow_no_value(self): + """Ensure that .validate passes and no value was provided + for a field setup with choices + """ + class Shirt(Document): + size = StringField(choices=('S', 'M')) shirt = Shirt() shirt.validate() - shirt.size = "S" + def test_choices_validation_accept_possible_value(self): + """Ensure that value is in a container of allowed values. + """ + class Shirt(Document): + size = StringField(choices=('S', 'M')) + + shirt = Shirt(size='S') shirt.validate() - shirt.size = "XS" - self.assertRaises(ValidationError, shirt.validate) + def test_choices_validation_reject_unknown_value(self): + """Ensure that unallowed value are rejected upon validation + """ + class Shirt(Document): + size = StringField(choices=('S', 'M')) - Shirt.drop_collection() + shirt = Shirt(size="XS") + with self.assertRaises(ValidationError): + shirt.validate() def test_choices_validation_documents(self): """ @@ -4420,7 +4427,8 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase): my_list = ListField(EmbeddedDocumentField(EmbeddedWithUnique)) A(my_list=[]).save() - self.assertRaises(NotUniqueError, lambda: A(my_list=[]).save()) + with self.assertRaises(NotUniqueError): + A(my_list=[]).save() class EmbeddedWithSparseUnique(EmbeddedDocument): number = IntField(unique=True, sparse=True) @@ -4431,6 +4439,9 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase): B(my_list=[]).save() B(my_list=[]).save() + A.drop_collection() + B.drop_collection() + def test_filtered_delete(self): """ Tests the delete method of a List of Embedded Documents diff --git a/tests/fields/file_tests.py b/tests/fields/file_tests.py index b266a5e5..8364d5ef 100644 --- a/tests/fields/file_tests.py +++ b/tests/fields/file_tests.py @@ -18,15 +18,13 @@ try: except ImportError: HAS_PIL = False +from tests.utils import MongoDBTestCase + TEST_IMAGE_PATH = os.path.join(os.path.dirname(__file__), 'mongoengine.png') TEST_IMAGE2_PATH = os.path.join(os.path.dirname(__file__), 'mongodb_leaf.png') -class FileTest(unittest.TestCase): - - def setUp(self): - connect(db='mongoenginetest') - self.db = get_db() +class FileTest(MongoDBTestCase): def tearDown(self): self.db.drop_collection('fs.files') diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 00000000..128bbff0 --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,22 @@ +import unittest + +from mongoengine import connect +from mongoengine.connection import get_db + +MONGO_TEST_DB = 'mongoenginetest' + + +class MongoDBTestCase(unittest.TestCase): + """Base class for tests that need a mongodb connection + db is being dropped automatically + """ + + @classmethod + def setUpClass(cls): + cls._connection = connect(db=MONGO_TEST_DB) + cls._connection.drop_database(MONGO_TEST_DB) + cls.db = get_db() + + @classmethod + def tearDownClass(cls): + cls._connection.drop_database(MONGO_TEST_DB)