From bc3cfc5562307f1ec84dffade143b4a0cdc2fb96 Mon Sep 17 00:00:00 2001 From: Kim Gustyr Date: Sat, 4 Dec 2021 00:26:48 +0300 Subject: [PATCH] Fix default values for enum service args #298 (#299) --- src/betterproto/plugin/models.py | 13 ++++++++++--- tests/inputs/service/service.proto | 7 +++++++ tests/test_features.py | 10 +++++++++- 3 files changed, 26 insertions(+), 4 deletions(-) diff --git a/src/betterproto/plugin/models.py b/src/betterproto/plugin/models.py index 680b15c..e58092f 100644 --- a/src/betterproto/plugin/models.py +++ b/src/betterproto/plugin/models.py @@ -59,8 +59,7 @@ from betterproto.lib.google.protobuf.compiler import CodeGeneratorRequest import re import textwrap from dataclasses import dataclass, field -from typing import Dict, Iterable, Iterator, List, Optional, Set, Text, Type, Union -import sys +from typing import Dict, Iterable, Iterator, List, Optional, Set, Type, Union from ..casing import sanitize_name from ..compile.importing import get_type_reference, parse_source_type_name @@ -460,7 +459,7 @@ class FieldCompiler(MessageCompiler): ) @property - def default_value_string(self) -> Union[Text, None, float, int]: + def default_value_string(self) -> str: """Python representation of the default proto value.""" if self.repeated: return "[]" @@ -474,6 +473,14 @@ class FieldCompiler(MessageCompiler): return '""' elif self.py_type == "bytes": return 'b""' + elif self.field_type == "enum": + enum_proto_obj_name = self.proto_obj.type_name.split(".").pop() + enum = next( + e + for e in self.output_file.enums + if e.proto_obj.name == enum_proto_obj_name + ) + return enum.default_value_string else: # Message type return "None" diff --git a/tests/inputs/service/service.proto b/tests/inputs/service/service.proto index 9ca0d25..53d84fb 100644 --- a/tests/inputs/service/service.proto +++ b/tests/inputs/service/service.proto @@ -2,9 +2,16 @@ syntax = "proto3"; package service; +enum ThingType { + UNKNOWN = 0; + LIVING = 1; + DEAD = 2; +} + message DoThingRequest { string name = 1; repeated string comments = 2; + ThingType type = 3; } message DoThingResponse { diff --git a/tests/test_features.py b/tests/test_features.py index 3f44f17..b82528e 100644 --- a/tests/test_features.py +++ b/tests/test_features.py @@ -1,7 +1,8 @@ import betterproto from dataclasses import dataclass from typing import Optional, List, Dict -from datetime import datetime, timedelta +from datetime import datetime +from inspect import signature def test_has_field(): @@ -476,3 +477,10 @@ def test_iso_datetime_list(): msg.from_dict({"timestamps": iso_candidates}) assert all([isinstance(item, datetime) for item in msg.timestamps]) + + +def test_enum_service_argument__expected_default_value(): + from tests.output_betterproto.service.service import ThingType, TestStub + + sig = signature(TestStub.do_thing) + assert sig.parameters["type"].default == ThingType.UNKNOWN