Handle mutable default arguments cleanly
When generating code, ensure that default list/dict arguments are initialised in local scope if unspecified or `None`.
This commit is contained in:
		
				
					committed by
					
						 Bouke Versteegh
						Bouke Versteegh
					
				
			
			
				
	
			
			
			
						parent
						
							42e197f985
						
					
				
				
					commit
					0ba0692dec
				
			
							
								
								
									
										20
									
								
								poetry.lock
									
									
									
										generated
									
									
									
								
							
							
						
						
									
										20
									
								
								poetry.lock
									
									
									
										generated
									
									
									
								
							| @@ -475,6 +475,20 @@ pytest = ">=4.6" | ||||
| [package.extras] | ||||
| testing = ["fields", "hunter", "process-tests (2.0.2)", "six", "pytest-xdist", "virtualenv"] | ||||
|  | ||||
| [[package]] | ||||
| category = "dev" | ||||
| description = "Thin-wrapper around the mock package for easier use with pytest" | ||||
| name = "pytest-mock" | ||||
| optional = false | ||||
| python-versions = ">=3.5" | ||||
| version = "3.1.1" | ||||
|  | ||||
| [package.dependencies] | ||||
| pytest = ">=2.7" | ||||
|  | ||||
| [package.extras] | ||||
| dev = ["pre-commit", "tox", "pytest-asyncio"] | ||||
|  | ||||
| [[package]] | ||||
| category = "main" | ||||
| description = "Alternative regular expression module, to replace re." | ||||
| @@ -623,7 +637,7 @@ testing = ["jaraco.itertools", "func-timeout"] | ||||
| compiler = ["black", "jinja2", "protobuf"] | ||||
|  | ||||
| [metadata] | ||||
| content-hash = "375411698ec644810d09af809ac8a004abc9821f0b718178505424625f407c14" | ||||
| content-hash = "7a2aa57a3a2b58e70aa05be75bff08272c2f070ecc1872a1c825f228299b82c2" | ||||
| python-versions = "^3.6" | ||||
|  | ||||
| [metadata.files] | ||||
| @@ -972,6 +986,10 @@ pytest-cov = [ | ||||
|     {file = "pytest-cov-2.10.0.tar.gz", hash = "sha256:1a629dc9f48e53512fcbfda6b07de490c374b0c83c55ff7a1720b3fccff0ac87"}, | ||||
|     {file = "pytest_cov-2.10.0-py2.py3-none-any.whl", hash = "sha256:6e6d18092dce6fad667cd7020deed816f858ad3b49d5b5e2b1cc1c97a4dba65c"}, | ||||
| ] | ||||
| pytest-mock = [ | ||||
|     {file = "pytest-mock-3.1.1.tar.gz", hash = "sha256:636e792f7dd9e2c80657e174c04bf7aa92672350090736d82e97e92ce8f68737"}, | ||||
|     {file = "pytest_mock-3.1.1-py3-none-any.whl", hash = "sha256:a9fedba70e37acf016238bb2293f2652ce19985ceb245bbd3d7f3e4032667402"}, | ||||
| ] | ||||
| regex = [ | ||||
|     {file = "regex-2020.6.8-cp27-cp27m-win32.whl", hash = "sha256:fbff901c54c22425a5b809b914a3bfaf4b9570eee0e5ce8186ac71eb2025191c"}, | ||||
|     {file = "regex-2020.6.8-cp27-cp27m-win_amd64.whl", hash = "sha256:112e34adf95e45158c597feea65d06a8124898bdeac975c9087fe71b572bd938"}, | ||||
|   | ||||
| @@ -30,6 +30,7 @@ protobuf = "^3.12.2" | ||||
| pytest = "^5.4.2" | ||||
| pytest-asyncio = "^0.12.0" | ||||
| pytest-cov = "^2.9.0" | ||||
| pytest-mock = "^3.1.1" | ||||
| tox = "^3.15.1" | ||||
|  | ||||
| [tool.poetry.scripts] | ||||
|   | ||||
| @@ -369,6 +369,10 @@ def lookup_method_input_type(method, types): | ||||
|             return known_type | ||||
|  | ||||
|  | ||||
| def is_mutable_field_type(field_type: str) -> bool: | ||||
|     return field_type.startswith("List[") or field_type.startswith("Dict[") | ||||
|  | ||||
|  | ||||
| def read_protobuf_service( | ||||
|     service: ServiceDescriptorProto, index, proto_file, content, output_types | ||||
| ): | ||||
| @@ -384,8 +388,23 @@ def read_protobuf_service( | ||||
|     for j, method in enumerate(service.method): | ||||
|         method_input_message = lookup_method_input_type(method, output_types) | ||||
|  | ||||
|         # This section ensures that method arguments having a default | ||||
|         # value that is initialised as a List/Dict (mutable) is replaced | ||||
|         # with None and initialisation is deferred to the beginning of the | ||||
|         # method definition. This is done so to avoid any side-effects. | ||||
|         # Reference: https://docs.python-guide.org/writing/gotchas/#mutable-default-arguments | ||||
|         mutable_default_args = [] | ||||
|  | ||||
|         if method_input_message: | ||||
|             for field in method_input_message["properties"]: | ||||
|                 if ( | ||||
|                     not method.client_streaming | ||||
|                     and field["zero"] != "None" | ||||
|                     and is_mutable_field_type(field["type"]) | ||||
|                 ): | ||||
|                     mutable_default_args.append((field["py_name"], field["zero"])) | ||||
|                     field["zero"] = "None" | ||||
|  | ||||
|                 if field["zero"] == "None": | ||||
|                     template_data["typing_imports"].add("Optional") | ||||
|  | ||||
| @@ -407,6 +426,7 @@ def read_protobuf_service( | ||||
|                 ), | ||||
|                 "client_streaming": method.client_streaming, | ||||
|                 "server_streaming": method.server_streaming, | ||||
|                 "mutable_default_args": mutable_default_args, | ||||
|             } | ||||
|         ) | ||||
|  | ||||
|   | ||||
| @@ -80,6 +80,10 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub): | ||||
| {{ method.comment }} | ||||
|  | ||||
|         {% endif %} | ||||
|         {%- for py_name, zero in method.mutable_default_args %} | ||||
|         {{ py_name }} = {{ py_name }} or {{ zero }} | ||||
|         {% endfor %} | ||||
|  | ||||
|         {% if not method.client_streaming %} | ||||
|         request = {{ method.input }}() | ||||
|         {% for field in method.input_message.properties %} | ||||
|   | ||||
| @@ -1,11 +1,14 @@ | ||||
| import asyncio | ||||
| import sys | ||||
|  | ||||
| from tests.output_betterproto.service.service import ( | ||||
|     DoThingResponse, | ||||
|     DoThingRequest, | ||||
|     DoThingResponse, | ||||
|     GetThingRequest, | ||||
|     TestStub as ThingServiceClient, | ||||
| ) | ||||
| import grpclib | ||||
| import grpclib.metadata | ||||
| from grpclib.testing import ChannelFor | ||||
| import pytest | ||||
| from betterproto.grpc.util.async_channel import AsyncChannel | ||||
| @@ -35,6 +38,20 @@ async def test_simple_service_call(): | ||||
|         await _test_client(ThingServiceClient(channel)) | ||||
|  | ||||
|  | ||||
| @pytest.mark.asyncio | ||||
| @pytest.mark.skipif( | ||||
|     sys.version_info < (3, 8), reason="async mock spy does works for python3.8+" | ||||
| ) | ||||
| async def test_service_call_mutable_defaults(mocker): | ||||
|     async with ChannelFor([ThingService()]) as channel: | ||||
|         client = ThingServiceClient(channel) | ||||
|         spy = mocker.spy(client, "_unary_unary") | ||||
|         await _test_client(client) | ||||
|         comments = spy.call_args_list[-1].args[1].comments | ||||
|         await _test_client(client) | ||||
|         assert spy.call_args_list[-1].args[1].comments is not comments | ||||
|  | ||||
|  | ||||
| @pytest.mark.asyncio | ||||
| async def test_service_call_with_upfront_request_params(): | ||||
|     # Setting deadline | ||||
|   | ||||
| @@ -4,6 +4,7 @@ package service; | ||||
|  | ||||
| message DoThingRequest { | ||||
|   string name = 1; | ||||
|   repeated string comments = 2; | ||||
| } | ||||
|  | ||||
| message DoThingResponse { | ||||
|   | ||||
		Reference in New Issue
	
	Block a user