This commit is contained in:
Nat Noordanus 2020-05-27 10:56:12 +02:00
parent a5effb219a
commit be2a24d15c
4 changed files with 31 additions and 17 deletions

View File

@ -652,7 +652,9 @@ class Message(ABC):
return self._betterproto.default_gen[field.name]() return self._betterproto.default_gen[field.name]()
@classmethod @classmethod
def _get_field_default_gen(cls, field: dataclasses.Field, meta: FieldMetadata) -> Any: def _get_field_default_gen(
cls, field: dataclasses.Field, meta: FieldMetadata
) -> Any:
t = cls._type_hint(field.name) t = cls._type_hint(field.name)
if hasattr(t, "__origin__"): if hasattr(t, "__origin__"):
@ -831,7 +833,9 @@ class Message(ABC):
else: else:
output[cased_name] = b64encode(v).decode("utf8") output[cased_name] = b64encode(v).decode("utf8")
elif meta.proto_type == TYPE_ENUM: elif meta.proto_type == TYPE_ENUM:
enum_values = list(self._betterproto.cls_by_field[field.name]) # type: ignore enum_values = list(
self._betterproto.cls_by_field[field.name]
) # type: ignore
if isinstance(v, list): if isinstance(v, list):
output[cased_name] = [enum_values[e].name for e in v] output[cased_name] = [enum_values[e].name for e in v]
else: else:

View File

@ -29,20 +29,25 @@ from betterproto.casing import safe_snake_case
import google.protobuf.wrappers_pb2 as google_wrappers import google.protobuf.wrappers_pb2 as google_wrappers
WRAPPER_TYPES: Dict[str, Optional[Type]] = defaultdict(lambda: None, { WRAPPER_TYPES: Dict[str, Optional[Type]] = defaultdict(
'google.protobuf.DoubleValue': google_wrappers.DoubleValue, lambda: None,
'google.protobuf.FloatValue': google_wrappers.FloatValue, {
'google.protobuf.Int64Value': google_wrappers.Int64Value, "google.protobuf.DoubleValue": google_wrappers.DoubleValue,
'google.protobuf.UInt64Value': google_wrappers.UInt64Value, "google.protobuf.FloatValue": google_wrappers.FloatValue,
'google.protobuf.Int32Value': google_wrappers.Int32Value, "google.protobuf.Int64Value": google_wrappers.Int64Value,
'google.protobuf.UInt32Value': google_wrappers.UInt32Value, "google.protobuf.UInt64Value": google_wrappers.UInt64Value,
'google.protobuf.BoolValue': google_wrappers.BoolValue, "google.protobuf.Int32Value": google_wrappers.Int32Value,
'google.protobuf.StringValue': google_wrappers.StringValue, "google.protobuf.UInt32Value": google_wrappers.UInt32Value,
'google.protobuf.BytesValue': google_wrappers.BytesValue, "google.protobuf.BoolValue": google_wrappers.BoolValue,
}) "google.protobuf.StringValue": google_wrappers.StringValue,
"google.protobuf.BytesValue": google_wrappers.BytesValue,
},
)
def get_ref_type(package: str, imports: set, type_name: str, unwrap: bool = True) -> str: def get_ref_type(
package: str, imports: set, type_name: str, unwrap: bool = True
) -> str:
""" """
Return a Python type name for a proto type reference. Adds the import if Return a Python type name for a proto type reference. Adds the import if
necessary. Unwraps well known type if required. necessary. Unwraps well known type if required.
@ -385,7 +390,10 @@ def generate_code(request, response):
).strip('"'), ).strip('"'),
"input_message": input_message, "input_message": input_message,
"output": get_ref_type( "output": get_ref_type(
package, output["imports"], method.output_type, unwrap=False package,
output["imports"],
method.output_type,
unwrap=False,
).strip('"'), ).strip('"'),
"client_streaming": method.client_streaming, "client_streaming": method.client_streaming,
"server_streaming": method.server_streaming, "server_streaming": method.server_streaming,

View File

@ -34,7 +34,7 @@ class ExampleService:
grpclib.const.Cardinality.UNARY_UNARY, grpclib.const.Cardinality.UNARY_UNARY,
DoThingRequest, DoThingRequest,
DoThingResponse, DoThingResponse,
), )
} }

View File

@ -2,7 +2,9 @@ import pytest
def pytest_addoption(parser): def pytest_addoption(parser):
parser.addoption("--repeat", type=int, default=1, help="repeat the operation multiple times") parser.addoption(
"--repeat", type=int, default=1, help="repeat the operation multiple times"
)
@pytest.fixture(scope="session") @pytest.fixture(scope="session")