diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 7ab2276d..9b9fef6e 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -916,8 +916,9 @@ class ListField(ComplexBaseField): Required means it cannot be empty - as the default for ListFields is [] """ - def __init__(self, field=None, **kwargs): + def __init__(self, field=None, max_length=None, **kwargs): self.field = field + self.max_length = max_length kwargs.setdefault("default", lambda: []) super(ListField, self).__init__(**kwargs) @@ -939,9 +940,21 @@ class ListField(ComplexBaseField): """Make sure that a list of valid fields is being used.""" if not isinstance(value, (list, tuple, BaseQuerySet)): self.error("Only lists and tuples may be used in a list field") + + # Validate that max_length is not exceeded. + # NOTE It's still possible to bypass this enforcement by using $push. + # However, if the document is reloaded after $push and then re-saved, + # the validation error will be raised. + if self.max_length is not None and len(value) > self.max_length: + self.error("List is too long") + super(ListField, self).validate(value) def prepare_query_value(self, op, value): + # Validate that the `set` operator doesn't contain more items than `max_length`. + if op == "set" and self.max_length is not None and len(value) > self.max_length: + self.error("List is too long") + if self.field: # If the value is iterable and it's not a string nor a diff --git a/tests/fields/fields.py b/tests/fields/fields.py index 87acf27f..b77ba753 100644 --- a/tests/fields/fields.py +++ b/tests/fields/fields.py @@ -1010,6 +1010,38 @@ class FieldTest(MongoDBTestCase): e.mapping = ["abc"] e.save() + def test_list_field_max_length(self): + """Ensure ListField's max_length is respected.""" + + class Foo(Document): + items = ListField(IntField(), max_length=5) + + foo = Foo() + for i in range(1, 7): + foo.items.append(i) + if i < 6: + foo.save() + else: + with self.assertRaises(ValidationError) as cm: + foo.save() + self.assertIn("List is too long", str(cm.exception)) + + def test_list_field_max_length(self): + """Ensure ListField's max_length is respected.""" + + class Foo(Document): + items = ListField(IntField(), max_length=5) + + foo = Foo() + for i in range(1, 7): + foo.items.append(i) + if i < 6: + foo.save() + else: + with self.assertRaises(ValidationError) as cm: + foo.save() + self.assertIn("List is too long", str(cm.exception)) + def test_list_field_rejects_strings(self): """Strings aren't valid list field data types."""