Handle mutable default arguments cleanly
When generating code, ensure that default list/dict arguments are initialised in local scope if unspecified or `None`.
This commit is contained in:
parent
42e197f985
commit
0ba0692dec
20
poetry.lock
generated
20
poetry.lock
generated
@ -475,6 +475,20 @@ pytest = ">=4.6"
|
|||||||
[package.extras]
|
[package.extras]
|
||||||
testing = ["fields", "hunter", "process-tests (2.0.2)", "six", "pytest-xdist", "virtualenv"]
|
testing = ["fields", "hunter", "process-tests (2.0.2)", "six", "pytest-xdist", "virtualenv"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
category = "dev"
|
||||||
|
description = "Thin-wrapper around the mock package for easier use with pytest"
|
||||||
|
name = "pytest-mock"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.5"
|
||||||
|
version = "3.1.1"
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
pytest = ">=2.7"
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
dev = ["pre-commit", "tox", "pytest-asyncio"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
category = "main"
|
category = "main"
|
||||||
description = "Alternative regular expression module, to replace re."
|
description = "Alternative regular expression module, to replace re."
|
||||||
@ -623,7 +637,7 @@ testing = ["jaraco.itertools", "func-timeout"]
|
|||||||
compiler = ["black", "jinja2", "protobuf"]
|
compiler = ["black", "jinja2", "protobuf"]
|
||||||
|
|
||||||
[metadata]
|
[metadata]
|
||||||
content-hash = "375411698ec644810d09af809ac8a004abc9821f0b718178505424625f407c14"
|
content-hash = "7a2aa57a3a2b58e70aa05be75bff08272c2f070ecc1872a1c825f228299b82c2"
|
||||||
python-versions = "^3.6"
|
python-versions = "^3.6"
|
||||||
|
|
||||||
[metadata.files]
|
[metadata.files]
|
||||||
@ -972,6 +986,10 @@ pytest-cov = [
|
|||||||
{file = "pytest-cov-2.10.0.tar.gz", hash = "sha256:1a629dc9f48e53512fcbfda6b07de490c374b0c83c55ff7a1720b3fccff0ac87"},
|
{file = "pytest-cov-2.10.0.tar.gz", hash = "sha256:1a629dc9f48e53512fcbfda6b07de490c374b0c83c55ff7a1720b3fccff0ac87"},
|
||||||
{file = "pytest_cov-2.10.0-py2.py3-none-any.whl", hash = "sha256:6e6d18092dce6fad667cd7020deed816f858ad3b49d5b5e2b1cc1c97a4dba65c"},
|
{file = "pytest_cov-2.10.0-py2.py3-none-any.whl", hash = "sha256:6e6d18092dce6fad667cd7020deed816f858ad3b49d5b5e2b1cc1c97a4dba65c"},
|
||||||
]
|
]
|
||||||
|
pytest-mock = [
|
||||||
|
{file = "pytest-mock-3.1.1.tar.gz", hash = "sha256:636e792f7dd9e2c80657e174c04bf7aa92672350090736d82e97e92ce8f68737"},
|
||||||
|
{file = "pytest_mock-3.1.1-py3-none-any.whl", hash = "sha256:a9fedba70e37acf016238bb2293f2652ce19985ceb245bbd3d7f3e4032667402"},
|
||||||
|
]
|
||||||
regex = [
|
regex = [
|
||||||
{file = "regex-2020.6.8-cp27-cp27m-win32.whl", hash = "sha256:fbff901c54c22425a5b809b914a3bfaf4b9570eee0e5ce8186ac71eb2025191c"},
|
{file = "regex-2020.6.8-cp27-cp27m-win32.whl", hash = "sha256:fbff901c54c22425a5b809b914a3bfaf4b9570eee0e5ce8186ac71eb2025191c"},
|
||||||
{file = "regex-2020.6.8-cp27-cp27m-win_amd64.whl", hash = "sha256:112e34adf95e45158c597feea65d06a8124898bdeac975c9087fe71b572bd938"},
|
{file = "regex-2020.6.8-cp27-cp27m-win_amd64.whl", hash = "sha256:112e34adf95e45158c597feea65d06a8124898bdeac975c9087fe71b572bd938"},
|
||||||
|
@ -30,6 +30,7 @@ protobuf = "^3.12.2"
|
|||||||
pytest = "^5.4.2"
|
pytest = "^5.4.2"
|
||||||
pytest-asyncio = "^0.12.0"
|
pytest-asyncio = "^0.12.0"
|
||||||
pytest-cov = "^2.9.0"
|
pytest-cov = "^2.9.0"
|
||||||
|
pytest-mock = "^3.1.1"
|
||||||
tox = "^3.15.1"
|
tox = "^3.15.1"
|
||||||
|
|
||||||
[tool.poetry.scripts]
|
[tool.poetry.scripts]
|
||||||
|
@ -369,6 +369,10 @@ def lookup_method_input_type(method, types):
|
|||||||
return known_type
|
return known_type
|
||||||
|
|
||||||
|
|
||||||
|
def is_mutable_field_type(field_type: str) -> bool:
|
||||||
|
return field_type.startswith("List[") or field_type.startswith("Dict[")
|
||||||
|
|
||||||
|
|
||||||
def read_protobuf_service(
|
def read_protobuf_service(
|
||||||
service: ServiceDescriptorProto, index, proto_file, content, output_types
|
service: ServiceDescriptorProto, index, proto_file, content, output_types
|
||||||
):
|
):
|
||||||
@ -384,8 +388,23 @@ def read_protobuf_service(
|
|||||||
for j, method in enumerate(service.method):
|
for j, method in enumerate(service.method):
|
||||||
method_input_message = lookup_method_input_type(method, output_types)
|
method_input_message = lookup_method_input_type(method, output_types)
|
||||||
|
|
||||||
|
# This section ensures that method arguments having a default
|
||||||
|
# value that is initialised as a List/Dict (mutable) is replaced
|
||||||
|
# with None and initialisation is deferred to the beginning of the
|
||||||
|
# method definition. This is done so to avoid any side-effects.
|
||||||
|
# Reference: https://docs.python-guide.org/writing/gotchas/#mutable-default-arguments
|
||||||
|
mutable_default_args = []
|
||||||
|
|
||||||
if method_input_message:
|
if method_input_message:
|
||||||
for field in method_input_message["properties"]:
|
for field in method_input_message["properties"]:
|
||||||
|
if (
|
||||||
|
not method.client_streaming
|
||||||
|
and field["zero"] != "None"
|
||||||
|
and is_mutable_field_type(field["type"])
|
||||||
|
):
|
||||||
|
mutable_default_args.append((field["py_name"], field["zero"]))
|
||||||
|
field["zero"] = "None"
|
||||||
|
|
||||||
if field["zero"] == "None":
|
if field["zero"] == "None":
|
||||||
template_data["typing_imports"].add("Optional")
|
template_data["typing_imports"].add("Optional")
|
||||||
|
|
||||||
@ -407,6 +426,7 @@ def read_protobuf_service(
|
|||||||
),
|
),
|
||||||
"client_streaming": method.client_streaming,
|
"client_streaming": method.client_streaming,
|
||||||
"server_streaming": method.server_streaming,
|
"server_streaming": method.server_streaming,
|
||||||
|
"mutable_default_args": mutable_default_args,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -80,6 +80,10 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
|
|||||||
{{ method.comment }}
|
{{ method.comment }}
|
||||||
|
|
||||||
{% endif %}
|
{% endif %}
|
||||||
|
{%- for py_name, zero in method.mutable_default_args %}
|
||||||
|
{{ py_name }} = {{ py_name }} or {{ zero }}
|
||||||
|
{% endfor %}
|
||||||
|
|
||||||
{% if not method.client_streaming %}
|
{% if not method.client_streaming %}
|
||||||
request = {{ method.input }}()
|
request = {{ method.input }}()
|
||||||
{% for field in method.input_message.properties %}
|
{% for field in method.input_message.properties %}
|
||||||
|
@ -1,11 +1,14 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
import sys
|
||||||
|
|
||||||
from tests.output_betterproto.service.service import (
|
from tests.output_betterproto.service.service import (
|
||||||
DoThingResponse,
|
|
||||||
DoThingRequest,
|
DoThingRequest,
|
||||||
|
DoThingResponse,
|
||||||
GetThingRequest,
|
GetThingRequest,
|
||||||
TestStub as ThingServiceClient,
|
TestStub as ThingServiceClient,
|
||||||
)
|
)
|
||||||
import grpclib
|
import grpclib
|
||||||
|
import grpclib.metadata
|
||||||
from grpclib.testing import ChannelFor
|
from grpclib.testing import ChannelFor
|
||||||
import pytest
|
import pytest
|
||||||
from betterproto.grpc.util.async_channel import AsyncChannel
|
from betterproto.grpc.util.async_channel import AsyncChannel
|
||||||
@ -35,6 +38,20 @@ async def test_simple_service_call():
|
|||||||
await _test_client(ThingServiceClient(channel))
|
await _test_client(ThingServiceClient(channel))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
sys.version_info < (3, 8), reason="async mock spy does works for python3.8+"
|
||||||
|
)
|
||||||
|
async def test_service_call_mutable_defaults(mocker):
|
||||||
|
async with ChannelFor([ThingService()]) as channel:
|
||||||
|
client = ThingServiceClient(channel)
|
||||||
|
spy = mocker.spy(client, "_unary_unary")
|
||||||
|
await _test_client(client)
|
||||||
|
comments = spy.call_args_list[-1].args[1].comments
|
||||||
|
await _test_client(client)
|
||||||
|
assert spy.call_args_list[-1].args[1].comments is not comments
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_service_call_with_upfront_request_params():
|
async def test_service_call_with_upfront_request_params():
|
||||||
# Setting deadline
|
# Setting deadline
|
||||||
|
@ -4,6 +4,7 @@ package service;
|
|||||||
|
|
||||||
message DoThingRequest {
|
message DoThingRequest {
|
||||||
string name = 1;
|
string name = 1;
|
||||||
|
repeated string comments = 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
message DoThingResponse {
|
message DoThingResponse {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user