From 061bf86a9cce960d0bdff78baabb7269c7facd3f Mon Sep 17 00:00:00 2001 From: Danny Weinberg Date: Thu, 4 Jun 2020 11:04:36 -0700 Subject: [PATCH 1/5] Set serialized_on_wire when message contains only lists This fixes a bug where serialized_on_wire was not set when a message contained only repeated values (eg in a list or map). The fix here is to just set it to true in the `parse` method as soon as we receive any valid data. This also adds a test to expose the behavior. --- betterproto/__init__.py | 3 ++ .../tests/inputs/repeated/repeated.proto | 6 +++ .../tests/inputs/repeated/test_repeated.py | 39 +++++++++++++++++++ 3 files changed, 48 insertions(+) create mode 100644 betterproto/tests/inputs/repeated/test_repeated.py diff --git a/betterproto/__init__.py b/betterproto/__init__.py index f394b41..c728a7c 100644 --- a/betterproto/__init__.py +++ b/betterproto/__init__.py @@ -747,6 +747,9 @@ class Message(ABC): self._unknown_fields += parsed.raw continue + # Got some data over the wire + self._serialized_on_wire = True + meta = self._betterproto.meta_by_field_name[field_name] value: Any diff --git a/betterproto/tests/inputs/repeated/repeated.proto b/betterproto/tests/inputs/repeated/repeated.proto index 42c1132..816bb26 100644 --- a/betterproto/tests/inputs/repeated/repeated.proto +++ b/betterproto/tests/inputs/repeated/repeated.proto @@ -1,5 +1,11 @@ syntax = "proto3"; +package repeated; + message Test { repeated string names = 1; } + +service ExampleService { + rpc DoThing (Test) returns (Test); +} diff --git a/betterproto/tests/inputs/repeated/test_repeated.py b/betterproto/tests/inputs/repeated/test_repeated.py new file mode 100644 index 0000000..7182a63 --- /dev/null +++ b/betterproto/tests/inputs/repeated/test_repeated.py @@ -0,0 +1,39 @@ +from typing import Dict + +import grpclib.const +import grpclib.server +import pytest +from grpclib.testing import ChannelFor + +import betterproto +from betterproto.tests.output_betterproto.repeated.repeated import ( + ExampleServiceStub, + Test, +) + + +class ExampleService: + async def DoThing( + self, stream: "grpclib.server.Stream[Test, Test]" + ): + request = await stream.recv_message() + await stream.send_message(request) + + def __mapping__(self) -> Dict[str, grpclib.const.Handler]: + return { + "/repeated.ExampleService/DoThing": grpclib.const.Handler( + self.DoThing, + grpclib.const.Cardinality.UNARY_UNARY, + Test, + Test, + ), + } + + +@pytest.mark.asyncio +async def test_sets_serialized_on_wire() -> None: + async with ChannelFor([ExampleService()]) as channel: + stub = ExampleServiceStub(channel) + response = await stub.do_thing(names=['a', 'b', 'c']) + assert betterproto.serialized_on_wire(response) + assert response.names == ['a', 'b', 'c'] From 67422db6b928079f0248e22afe7a9229a567e95f Mon Sep 17 00:00:00 2001 From: Danny Weinberg Date: Thu, 4 Jun 2020 11:34:20 -0700 Subject: [PATCH 2/5] Fix formatting --- betterproto/tests/inputs/repeated/test_repeated.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/betterproto/tests/inputs/repeated/test_repeated.py b/betterproto/tests/inputs/repeated/test_repeated.py index 7182a63..ee4e079 100644 --- a/betterproto/tests/inputs/repeated/test_repeated.py +++ b/betterproto/tests/inputs/repeated/test_repeated.py @@ -13,19 +13,14 @@ from betterproto.tests.output_betterproto.repeated.repeated import ( class ExampleService: - async def DoThing( - self, stream: "grpclib.server.Stream[Test, Test]" - ): + async def DoThing(self, stream: "grpclib.server.Stream[Test, Test]"): request = await stream.recv_message() await stream.send_message(request) def __mapping__(self) -> Dict[str, grpclib.const.Handler]: return { "/repeated.ExampleService/DoThing": grpclib.const.Handler( - self.DoThing, - grpclib.const.Cardinality.UNARY_UNARY, - Test, - Test, + self.DoThing, grpclib.const.Cardinality.UNARY_UNARY, Test, Test, ), } @@ -34,6 +29,6 @@ class ExampleService: async def test_sets_serialized_on_wire() -> None: async with ChannelFor([ExampleService()]) as channel: stub = ExampleServiceStub(channel) - response = await stub.do_thing(names=['a', 'b', 'c']) + response = await stub.do_thing(names=["a", "b", "c"]) assert betterproto.serialized_on_wire(response) - assert response.names == ['a', 'b', 'c'] + assert response.names == ["a", "b", "c"] From a914306f33b47e56f4f0a51d45c9da1f06b48e88 Mon Sep 17 00:00:00 2001 From: Danny Weinberg Date: Thu, 4 Jun 2020 13:42:07 -0700 Subject: [PATCH 3/5] Put test into `test_features`, simplify to call `parse` directly --- .../tests/inputs/repeated/repeated.proto | 6 ---- .../tests/inputs/repeated/test_repeated.py | 34 ------------------- betterproto/tests/test_features.py | 17 +++++++++- 3 files changed, 16 insertions(+), 41 deletions(-) delete mode 100644 betterproto/tests/inputs/repeated/test_repeated.py diff --git a/betterproto/tests/inputs/repeated/repeated.proto b/betterproto/tests/inputs/repeated/repeated.proto index 816bb26..42c1132 100644 --- a/betterproto/tests/inputs/repeated/repeated.proto +++ b/betterproto/tests/inputs/repeated/repeated.proto @@ -1,11 +1,5 @@ syntax = "proto3"; -package repeated; - message Test { repeated string names = 1; } - -service ExampleService { - rpc DoThing (Test) returns (Test); -} diff --git a/betterproto/tests/inputs/repeated/test_repeated.py b/betterproto/tests/inputs/repeated/test_repeated.py deleted file mode 100644 index ee4e079..0000000 --- a/betterproto/tests/inputs/repeated/test_repeated.py +++ /dev/null @@ -1,34 +0,0 @@ -from typing import Dict - -import grpclib.const -import grpclib.server -import pytest -from grpclib.testing import ChannelFor - -import betterproto -from betterproto.tests.output_betterproto.repeated.repeated import ( - ExampleServiceStub, - Test, -) - - -class ExampleService: - async def DoThing(self, stream: "grpclib.server.Stream[Test, Test]"): - request = await stream.recv_message() - await stream.send_message(request) - - def __mapping__(self) -> Dict[str, grpclib.const.Handler]: - return { - "/repeated.ExampleService/DoThing": grpclib.const.Handler( - self.DoThing, grpclib.const.Cardinality.UNARY_UNARY, Test, Test, - ), - } - - -@pytest.mark.asyncio -async def test_sets_serialized_on_wire() -> None: - async with ChannelFor([ExampleService()]) as channel: - stub = ExampleServiceStub(channel) - response = await stub.do_thing(names=["a", "b", "c"]) - assert betterproto.serialized_on_wire(response) - assert response.names == ["a", "b", "c"] diff --git a/betterproto/tests/test_features.py b/betterproto/tests/test_features.py index 47019e1..7c3247a 100644 --- a/betterproto/tests/test_features.py +++ b/betterproto/tests/test_features.py @@ -1,6 +1,6 @@ import betterproto from dataclasses import dataclass -from typing import Optional +from typing import Optional, List, Dict def test_has_field(): @@ -32,6 +32,21 @@ def test_has_field(): foo.bar = Bar() assert betterproto.serialized_on_wire(foo.bar) == False + @dataclass + class WithCollections(betterproto.Message): + test_list: List[str] = betterproto.string_field(1) + test_map: Dict[str, str] = betterproto.map_field(2, betterproto.TYPE_STRING, betterproto.TYPE_STRING) + + # Unset with empty collections + with_collections_empty = WithCollections().parse(bytes(WithCollections())) + assert betterproto.serialized_on_wire(with_collections_empty) == False + + # Set with non-empty collections + with_collections_list = WithCollections().parse(bytes(WithCollections(test_list=['a', 'b', 'c']))) + assert betterproto.serialized_on_wire(with_collections_list) == True + with_collections_map = WithCollections().parse(bytes(WithCollections(test_map={'a': 'b', 'c': 'd'}))) + assert betterproto.serialized_on_wire(with_collections_map) == True + def test_class_init(): @dataclass From 5c700618fde51c05b4d1ab4b8caf6ff7ab766164 Mon Sep 17 00:00:00 2001 From: Danny Weinberg Date: Thu, 4 Jun 2020 13:42:43 -0700 Subject: [PATCH 4/5] Black again lol --- betterproto/tests/test_features.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/betterproto/tests/test_features.py b/betterproto/tests/test_features.py index 7c3247a..d714caf 100644 --- a/betterproto/tests/test_features.py +++ b/betterproto/tests/test_features.py @@ -35,16 +35,22 @@ def test_has_field(): @dataclass class WithCollections(betterproto.Message): test_list: List[str] = betterproto.string_field(1) - test_map: Dict[str, str] = betterproto.map_field(2, betterproto.TYPE_STRING, betterproto.TYPE_STRING) + test_map: Dict[str, str] = betterproto.map_field( + 2, betterproto.TYPE_STRING, betterproto.TYPE_STRING + ) # Unset with empty collections with_collections_empty = WithCollections().parse(bytes(WithCollections())) assert betterproto.serialized_on_wire(with_collections_empty) == False # Set with non-empty collections - with_collections_list = WithCollections().parse(bytes(WithCollections(test_list=['a', 'b', 'c']))) + with_collections_list = WithCollections().parse( + bytes(WithCollections(test_list=["a", "b", "c"])) + ) assert betterproto.serialized_on_wire(with_collections_list) == True - with_collections_map = WithCollections().parse(bytes(WithCollections(test_map={'a': 'b', 'c': 'd'}))) + with_collections_map = WithCollections().parse( + bytes(WithCollections(test_map={"a": "b", "c": "d"})) + ) assert betterproto.serialized_on_wire(with_collections_map) == True From 28a288924f4f714d5543f20b6d2e8f5e6c8dcf7f Mon Sep 17 00:00:00 2001 From: Danny Weinberg Date: Thu, 4 Jun 2020 16:20:32 -0700 Subject: [PATCH 5/5] Change to have `parse` *always* set serialized_on_wire --- betterproto/__init__.py | 6 +++--- betterproto/tests/test_features.py | 6 ++---- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/betterproto/__init__.py b/betterproto/__init__.py index c728a7c..e9f2190 100644 --- a/betterproto/__init__.py +++ b/betterproto/__init__.py @@ -741,15 +741,15 @@ class Message(ABC): Parse the binary encoded Protobuf into this message instance. This returns the instance itself and is therefore assignable and chainable. """ + # Got some data over the wire + self._serialized_on_wire = True + for parsed in parse_fields(data): field_name = self._betterproto.field_name_by_number.get(parsed.number) if not field_name: self._unknown_fields += parsed.raw continue - # Got some data over the wire - self._serialized_on_wire = True - meta = self._betterproto.meta_by_field_name[field_name] value: Any diff --git a/betterproto/tests/test_features.py b/betterproto/tests/test_features.py index d714caf..024ab08 100644 --- a/betterproto/tests/test_features.py +++ b/betterproto/tests/test_features.py @@ -39,11 +39,9 @@ def test_has_field(): 2, betterproto.TYPE_STRING, betterproto.TYPE_STRING ) - # Unset with empty collections + # Is always set from parse, even if all collections are empty with_collections_empty = WithCollections().parse(bytes(WithCollections())) - assert betterproto.serialized_on_wire(with_collections_empty) == False - - # Set with non-empty collections + assert betterproto.serialized_on_wire(with_collections_empty) == True with_collections_list = WithCollections().parse( bytes(WithCollections(test_list=["a", "b", "c"])) )