From 0ba0692dec31883bbcf3bd9302e58205e3195479 Mon Sep 17 00:00:00 2001 From: Arun Babu Neelicattu Date: Tue, 7 Jul 2020 18:13:43 +0200 Subject: [PATCH] Handle mutable default arguments cleanly When generating code, ensure that default list/dict arguments are initialised in local scope if unspecified or `None`. --- poetry.lock | 20 +++++++++++++++++++- pyproject.toml | 1 + src/betterproto/plugin.py | 20 ++++++++++++++++++++ src/betterproto/templates/template.py.j2 | 4 ++++ tests/grpc/test_grpclib_client.py | 19 ++++++++++++++++++- tests/inputs/service/service.proto | 1 + 6 files changed, 63 insertions(+), 2 deletions(-) diff --git a/poetry.lock b/poetry.lock index 1113322..9c926b0 100644 --- a/poetry.lock +++ b/poetry.lock @@ -475,6 +475,20 @@ pytest = ">=4.6" [package.extras] testing = ["fields", "hunter", "process-tests (2.0.2)", "six", "pytest-xdist", "virtualenv"] +[[package]] +category = "dev" +description = "Thin-wrapper around the mock package for easier use with pytest" +name = "pytest-mock" +optional = false +python-versions = ">=3.5" +version = "3.1.1" + +[package.dependencies] +pytest = ">=2.7" + +[package.extras] +dev = ["pre-commit", "tox", "pytest-asyncio"] + [[package]] category = "main" description = "Alternative regular expression module, to replace re." @@ -623,7 +637,7 @@ testing = ["jaraco.itertools", "func-timeout"] compiler = ["black", "jinja2", "protobuf"] [metadata] -content-hash = "375411698ec644810d09af809ac8a004abc9821f0b718178505424625f407c14" +content-hash = "7a2aa57a3a2b58e70aa05be75bff08272c2f070ecc1872a1c825f228299b82c2" python-versions = "^3.6" [metadata.files] @@ -972,6 +986,10 @@ pytest-cov = [ {file = "pytest-cov-2.10.0.tar.gz", hash = "sha256:1a629dc9f48e53512fcbfda6b07de490c374b0c83c55ff7a1720b3fccff0ac87"}, {file = "pytest_cov-2.10.0-py2.py3-none-any.whl", hash = "sha256:6e6d18092dce6fad667cd7020deed816f858ad3b49d5b5e2b1cc1c97a4dba65c"}, ] +pytest-mock = [ + {file = "pytest-mock-3.1.1.tar.gz", hash = "sha256:636e792f7dd9e2c80657e174c04bf7aa92672350090736d82e97e92ce8f68737"}, + {file = "pytest_mock-3.1.1-py3-none-any.whl", hash = "sha256:a9fedba70e37acf016238bb2293f2652ce19985ceb245bbd3d7f3e4032667402"}, +] regex = [ {file = "regex-2020.6.8-cp27-cp27m-win32.whl", hash = "sha256:fbff901c54c22425a5b809b914a3bfaf4b9570eee0e5ce8186ac71eb2025191c"}, {file = "regex-2020.6.8-cp27-cp27m-win_amd64.whl", hash = "sha256:112e34adf95e45158c597feea65d06a8124898bdeac975c9087fe71b572bd938"}, diff --git a/pyproject.toml b/pyproject.toml index 400a797..c48a2b7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ protobuf = "^3.12.2" pytest = "^5.4.2" pytest-asyncio = "^0.12.0" pytest-cov = "^2.9.0" +pytest-mock = "^3.1.1" tox = "^3.15.1" [tool.poetry.scripts] diff --git a/src/betterproto/plugin.py b/src/betterproto/plugin.py index 0d88d47..4f01c29 100755 --- a/src/betterproto/plugin.py +++ b/src/betterproto/plugin.py @@ -369,6 +369,10 @@ def lookup_method_input_type(method, types): return known_type +def is_mutable_field_type(field_type: str) -> bool: + return field_type.startswith("List[") or field_type.startswith("Dict[") + + def read_protobuf_service( service: ServiceDescriptorProto, index, proto_file, content, output_types ): @@ -384,8 +388,23 @@ def read_protobuf_service( for j, method in enumerate(service.method): method_input_message = lookup_method_input_type(method, output_types) + # This section ensures that method arguments having a default + # value that is initialised as a List/Dict (mutable) is replaced + # with None and initialisation is deferred to the beginning of the + # method definition. This is done so to avoid any side-effects. + # Reference: https://docs.python-guide.org/writing/gotchas/#mutable-default-arguments + mutable_default_args = [] + if method_input_message: for field in method_input_message["properties"]: + if ( + not method.client_streaming + and field["zero"] != "None" + and is_mutable_field_type(field["type"]) + ): + mutable_default_args.append((field["py_name"], field["zero"])) + field["zero"] = "None" + if field["zero"] == "None": template_data["typing_imports"].add("Optional") @@ -407,6 +426,7 @@ def read_protobuf_service( ), "client_streaming": method.client_streaming, "server_streaming": method.server_streaming, + "mutable_default_args": mutable_default_args, } ) diff --git a/src/betterproto/templates/template.py.j2 b/src/betterproto/templates/template.py.j2 index b2d9112..b7ca89c 100644 --- a/src/betterproto/templates/template.py.j2 +++ b/src/betterproto/templates/template.py.j2 @@ -80,6 +80,10 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub): {{ method.comment }} {% endif %} + {%- for py_name, zero in method.mutable_default_args %} + {{ py_name }} = {{ py_name }} or {{ zero }} + {% endfor %} + {% if not method.client_streaming %} request = {{ method.input }}() {% for field in method.input_message.properties %} diff --git a/tests/grpc/test_grpclib_client.py b/tests/grpc/test_grpclib_client.py index d1ca47f..ed35c77 100644 --- a/tests/grpc/test_grpclib_client.py +++ b/tests/grpc/test_grpclib_client.py @@ -1,11 +1,14 @@ import asyncio +import sys + from tests.output_betterproto.service.service import ( - DoThingResponse, DoThingRequest, + DoThingResponse, GetThingRequest, TestStub as ThingServiceClient, ) import grpclib +import grpclib.metadata from grpclib.testing import ChannelFor import pytest from betterproto.grpc.util.async_channel import AsyncChannel @@ -35,6 +38,20 @@ async def test_simple_service_call(): await _test_client(ThingServiceClient(channel)) +@pytest.mark.asyncio +@pytest.mark.skipif( + sys.version_info < (3, 8), reason="async mock spy does works for python3.8+" +) +async def test_service_call_mutable_defaults(mocker): + async with ChannelFor([ThingService()]) as channel: + client = ThingServiceClient(channel) + spy = mocker.spy(client, "_unary_unary") + await _test_client(client) + comments = spy.call_args_list[-1].args[1].comments + await _test_client(client) + assert spy.call_args_list[-1].args[1].comments is not comments + + @pytest.mark.asyncio async def test_service_call_with_upfront_request_params(): # Setting deadline diff --git a/tests/inputs/service/service.proto b/tests/inputs/service/service.proto index acfbcdd..9ca0d25 100644 --- a/tests/inputs/service/service.proto +++ b/tests/inputs/service/service.proto @@ -4,6 +4,7 @@ package service; message DoThingRequest { string name = 1; + repeated string comments = 2; } message DoThingResponse {