Merge pull request #78 from boukeversteegh/pr/google

Basic general support for Google Protobuf
This commit is contained in:
nat 2020-06-11 10:50:12 +02:00 committed by GitHub
commit 9a45ea9f16
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 1601 additions and 161 deletions

View File

@ -256,6 +256,7 @@ Google provides several well-known message types like a timestamp, duration, and
| `google.protobuf.duration` | [`datetime.timedelta`][td] | `0` | | `google.protobuf.duration` | [`datetime.timedelta`][td] | `0` |
| `google.protobuf.timestamp` | Timezone-aware [`datetime.datetime`][dt] | `1970-01-01T00:00:00Z` | | `google.protobuf.timestamp` | Timezone-aware [`datetime.datetime`][dt] | `1970-01-01T00:00:00Z` |
| `google.protobuf.*Value` | `Optional[...]` | `None` | | `google.protobuf.*Value` | `Optional[...]` | `None` |
| `google.protobuf.*` | `betterproto.lib.google.protobuf.*` | `None` |
[td]: https://docs.python.org/3/library/datetime.html#timedelta-objects [td]: https://docs.python.org/3/library/datetime.html#timedelta-objects
[dt]: https://docs.python.org/3/library/datetime.html#datetime.datetime [dt]: https://docs.python.org/3/library/datetime.html#datetime.datetime
@ -354,6 +355,25 @@ $ pipenv run generate
$ pipenv run test $ pipenv run test
``` ```
### (Re)compiling Google Well-known Types
Betterproto includes compiled versions for Google's well-known types at [betterproto/lib/google](betterproto/lib/google).
Be sure to regenerate these files when modifying the plugin output format, and validate by running the tests.
Normally, the plugin does not compile any references to `google.protobuf`, since they are pre-compiled. To force compilation of `google.protobuf`, use the option `--custom_opt=INCLUDE_GOOGLE`.
Assuming your `google.protobuf` source files (included with all releases of `protoc`) are located in `/usr/local/include`, you can regenerate them as follows:
```sh
protoc \
--plugin=protoc-gen-custom=betterproto/plugin.py \
--custom_opt=INCLUDE_GOOGLE \
--custom_out=betterproto/lib \
-I /usr/local/include/ \
/usr/local/include/google/protobuf/*.proto
```
### TODO ### TODO
- [x] Fixed length fields - [x] Fixed length fields

View File

@ -941,19 +941,23 @@ def which_one_of(message: Message, group_name: str) -> Tuple[str, Any]:
return (field_name, getattr(message, field_name)) return (field_name, getattr(message, field_name))
@dataclasses.dataclass # Circular import workaround: google.protobuf depends on base classes defined above.
class _Duration(Message): from .lib.google.protobuf import (
# Signed seconds of the span of time. Must be from -315,576,000,000 to Duration,
# +315,576,000,000 inclusive. Note: these bounds are computed from: 60 Timestamp,
# sec/min * 60 min/hr * 24 hr/day * 365.25 days/year * 10000 years BoolValue,
seconds: int = int64_field(1) BytesValue,
# Signed fractions of a second at nanosecond resolution of the span of time. DoubleValue,
# Durations less than one second are represented with a 0 `seconds` field and FloatValue,
# a positive or negative `nanos` field. For durations of one second or more, Int32Value,
# a non-zero value for the `nanos` field must be of the same sign as the Int64Value,
# `seconds` field. Must be from -999,999,999 to +999,999,999 inclusive. StringValue,
nanos: int = int32_field(2) UInt32Value,
UInt64Value,
)
class _Duration(Duration):
def to_timedelta(self) -> timedelta: def to_timedelta(self) -> timedelta:
return timedelta(seconds=self.seconds, microseconds=self.nanos / 1e3) return timedelta(seconds=self.seconds, microseconds=self.nanos / 1e3)
@ -966,16 +970,7 @@ class _Duration(Message):
return ".".join(parts) + "s" return ".".join(parts) + "s"
@dataclasses.dataclass class _Timestamp(Timestamp):
class _Timestamp(Message):
# Represents seconds of UTC time since Unix epoch 1970-01-01T00:00:00Z. Must
# be from 0001-01-01T00:00:00Z to 9999-12-31T23:59:59Z inclusive.
seconds: int = int64_field(1)
# Non-negative fractions of a second at nanosecond resolution. Negative
# second values with fractions must still have non-negative nanos values that
# count forward in time. Must be from 0 to 999,999,999 inclusive.
nanos: int = int32_field(2)
def to_datetime(self) -> datetime: def to_datetime(self) -> datetime:
ts = self.seconds + (self.nanos / 1e9) ts = self.seconds + (self.nanos / 1e9)
return datetime.fromtimestamp(ts, tz=timezone.utc) return datetime.fromtimestamp(ts, tz=timezone.utc)
@ -1016,63 +1011,18 @@ class _WrappedMessage(Message):
return self return self
@dataclasses.dataclass
class _BoolValue(_WrappedMessage):
value: bool = bool_field(1)
@dataclasses.dataclass
class _Int32Value(_WrappedMessage):
value: int = int32_field(1)
@dataclasses.dataclass
class _UInt32Value(_WrappedMessage):
value: int = uint32_field(1)
@dataclasses.dataclass
class _Int64Value(_WrappedMessage):
value: int = int64_field(1)
@dataclasses.dataclass
class _UInt64Value(_WrappedMessage):
value: int = uint64_field(1)
@dataclasses.dataclass
class _FloatValue(_WrappedMessage):
value: float = float_field(1)
@dataclasses.dataclass
class _DoubleValue(_WrappedMessage):
value: float = double_field(1)
@dataclasses.dataclass
class _StringValue(_WrappedMessage):
value: str = string_field(1)
@dataclasses.dataclass
class _BytesValue(_WrappedMessage):
value: bytes = bytes_field(1)
def _get_wrapper(proto_type: str) -> Type: def _get_wrapper(proto_type: str) -> Type:
"""Get the wrapper message class for a wrapped type.""" """Get the wrapper message class for a wrapped type."""
return { return {
TYPE_BOOL: _BoolValue, TYPE_BOOL: BoolValue,
TYPE_INT32: _Int32Value, TYPE_INT32: Int32Value,
TYPE_UINT32: _UInt32Value, TYPE_UINT32: UInt32Value,
TYPE_INT64: _Int64Value, TYPE_INT64: Int64Value,
TYPE_UINT64: _UInt64Value, TYPE_UINT64: UInt64Value,
TYPE_FLOAT: _FloatValue, TYPE_FLOAT: FloatValue,
TYPE_DOUBLE: _DoubleValue, TYPE_DOUBLE: DoubleValue,
TYPE_STRING: _StringValue, TYPE_STRING: StringValue,
TYPE_BYTES: _BytesValue, TYPE_BYTES: BytesValue,
}[proto_type] }[proto_type]

View File

View File

@ -0,0 +1,70 @@
from typing import Dict, Type
import stringcase
from betterproto import safe_snake_case
from betterproto.lib.google import protobuf as google_protobuf
WRAPPER_TYPES: Dict[str, Type] = {
"google.protobuf.DoubleValue": google_protobuf.DoubleValue,
"google.protobuf.FloatValue": google_protobuf.FloatValue,
"google.protobuf.Int32Value": google_protobuf.Int32Value,
"google.protobuf.Int64Value": google_protobuf.Int64Value,
"google.protobuf.UInt32Value": google_protobuf.UInt32Value,
"google.protobuf.UInt64Value": google_protobuf.UInt64Value,
"google.protobuf.BoolValue": google_protobuf.BoolValue,
"google.protobuf.StringValue": google_protobuf.StringValue,
"google.protobuf.BytesValue": google_protobuf.BytesValue,
}
def get_ref_type(
package: str, imports: set, type_name: str, unwrap: bool = True
) -> str:
"""
Return a Python type name for a proto type reference. Adds the import if
necessary. Unwraps well known type if required.
"""
# If the package name is a blank string, then this should still work
# because by convention packages are lowercase and message/enum types are
# pascal-cased. May require refactoring in the future.
type_name = type_name.lstrip(".")
is_wrapper = type_name in WRAPPER_TYPES
if unwrap:
if is_wrapper:
wrapped_type = type(WRAPPER_TYPES[type_name]().value)
return f"Optional[{wrapped_type.__name__}]"
if type_name == "google.protobuf.Duration":
return "timedelta"
if type_name == "google.protobuf.Timestamp":
return "datetime"
if type_name.startswith(package):
parts = type_name.lstrip(package).lstrip(".").split(".")
if len(parts) == 1 or (len(parts) > 1 and parts[0][0] == parts[0][0].upper()):
# This is the current package, which has nested types flattened.
# foo.bar_thing => FooBarThing
cased = [stringcase.pascalcase(part) for part in parts]
type_name = f'"{"".join(cased)}"'
# Use precompiled classes for google.protobuf.* objects
if type_name.startswith("google.protobuf.") and type_name.count(".") == 2:
type_name = type_name.rsplit(".", maxsplit=1)[1]
import_package = "betterproto.lib.google.protobuf"
import_alias = safe_snake_case(import_package)
imports.add(f"import {import_package} as {import_alias}")
return f"{import_alias}.{type_name}"
if "." in type_name:
# This is imported from another package. No need
# to use a forward ref and we need to add the import.
parts = type_name.split(".")
parts[-1] = stringcase.pascalcase(parts[-1])
imports.add(f"from .{'.'.join(parts[:-2])} import {parts[-2]}")
type_name = f"{parts[-2]}.{parts[-1]}"
return type_name

View File

View File

File diff suppressed because it is too large Load Diff

View File

@ -1,13 +1,15 @@
#!/usr/bin/env python #!/usr/bin/env python
from collections import defaultdict
import itertools import itertools
import os.path import os.path
import re
import stringcase import stringcase
import sys import sys
import textwrap import textwrap
from typing import Dict, List, Optional, Type from typing import List
from betterproto.casing import safe_snake_case from betterproto.casing import safe_snake_case
from betterproto.compile.importing import get_ref_type
import betterproto
try: try:
# betterproto[compiler] specific dependencies # betterproto[compiler] specific dependencies
@ -33,70 +35,6 @@ except ImportError as err:
raise SystemExit(1) raise SystemExit(1)
WRAPPER_TYPES: Dict[str, Optional[Type]] = defaultdict(
lambda: None,
{
"google.protobuf.DoubleValue": google_wrappers.DoubleValue,
"google.protobuf.FloatValue": google_wrappers.FloatValue,
"google.protobuf.Int64Value": google_wrappers.Int64Value,
"google.protobuf.UInt64Value": google_wrappers.UInt64Value,
"google.protobuf.Int32Value": google_wrappers.Int32Value,
"google.protobuf.UInt32Value": google_wrappers.UInt32Value,
"google.protobuf.BoolValue": google_wrappers.BoolValue,
"google.protobuf.StringValue": google_wrappers.StringValue,
"google.protobuf.BytesValue": google_wrappers.BytesValue,
},
)
def get_ref_type(
package: str, imports: set, type_name: str, unwrap: bool = True
) -> str:
"""
Return a Python type name for a proto type reference. Adds the import if
necessary. Unwraps well known type if required.
"""
# If the package name is a blank string, then this should still work
# because by convention packages are lowercase and message/enum types are
# pascal-cased. May require refactoring in the future.
type_name = type_name.lstrip(".")
# Check if type is wrapper.
wrapper_class = WRAPPER_TYPES[type_name]
if unwrap:
if wrapper_class:
wrapped_type = type(wrapper_class().value)
return f"Optional[{wrapped_type.__name__}]"
if type_name == "google.protobuf.Duration":
return "timedelta"
if type_name == "google.protobuf.Timestamp":
return "datetime"
elif wrapper_class:
imports.add(f"from {wrapper_class.__module__} import {wrapper_class.__name__}")
return f"{wrapper_class.__name__}"
if type_name.startswith(package):
parts = type_name.lstrip(package).lstrip(".").split(".")
if len(parts) == 1 or (len(parts) > 1 and parts[0][0] == parts[0][0].upper()):
# This is the current package, which has nested types flattened.
# foo.bar_thing => FooBarThing
cased = [stringcase.pascalcase(part) for part in parts]
type_name = f'"{"".join(cased)}"'
if "." in type_name:
# This is imported from another package. No need
# to use a forward ref and we need to add the import.
parts = type_name.split(".")
parts[-1] = stringcase.pascalcase(parts[-1])
imports.add(f"from .{'.'.join(parts[:-2])} import {parts[-2]}")
type_name = f"{parts[-2]}.{parts[-1]}"
return type_name
def py_type( def py_type(
package: str, package: str,
imports: set, imports: set,
@ -182,6 +120,8 @@ def get_comment(proto_file, path: List[int], indent: int = 4) -> str:
def generate_code(request, response): def generate_code(request, response):
plugin_options = request.parameter.split(",") if request.parameter else []
env = jinja2.Environment( env = jinja2.Environment(
trim_blocks=True, trim_blocks=True,
lstrip_blocks=True, lstrip_blocks=True,
@ -192,7 +132,8 @@ def generate_code(request, response):
output_map = {} output_map = {}
for proto_file in request.proto_file: for proto_file in request.proto_file:
out = proto_file.package out = proto_file.package
if out == "google.protobuf":
if out == "google.protobuf" and "INCLUDE_GOOGLE" not in plugin_options:
continue continue
if not out: if not out:
@ -255,11 +196,13 @@ def generate_code(request, response):
field_type = f.Type.Name(f.type).lower()[5:] field_type = f.Type.Name(f.type).lower()[5:]
field_wraps = "" field_wraps = ""
if f.type_name.startswith( match_wrapper = re.match(
".google.protobuf" r"\.google\.protobuf\.(.+)Value", f.type_name
) and f.type_name.endswith("Value"): )
w = f.type_name.split(".").pop()[:-5].upper() if match_wrapper:
field_wraps = f"betterproto.TYPE_{w}" wrapped_type = "TYPE_" + match_wrapper.group(1).upper()
if hasattr(betterproto, wrapped_type):
field_wraps = f"betterproto.{wrapped_type}"
map_types = None map_types = None
if f.type == 11: if f.type == 11:

View File

View File

@ -8,10 +8,11 @@ tests = {
"import_circular_dependency", # failing because of other bugs now "import_circular_dependency", # failing because of other bugs now
"import_packages_same_name", # 25 "import_packages_same_name", # 25
"oneof_enum", # 63 "oneof_enum", # 63
"googletypes_service_returns_empty", # 9
"casing_message_field_uppercase", # 11 "casing_message_field_uppercase", # 11
"namespace_keywords", # 70 "namespace_keywords", # 70
"namespace_builtin_types", # 53 "namespace_builtin_types", # 53
"googletypes_struct", # 9
"googletypes_value", # 9
} }
services = { services = {
@ -20,4 +21,5 @@ services = {
"service", "service",
"import_service_input_message", "import_service_input_message",
"googletypes_service_returns_empty", "googletypes_service_returns_empty",
"googletypes_service_returns_googletype",
} }

View File

@ -1,5 +1,7 @@
{ {
"maybe": false, "maybe": false,
"ts": "1972-01-01T10:00:20.021Z", "ts": "1972-01-01T10:00:20.021Z",
"duration": "1.200s" "duration": "1.200s",
"important": 10,
"empty": {}
} }

View File

@ -3,10 +3,12 @@ syntax = "proto3";
import "google/protobuf/duration.proto"; import "google/protobuf/duration.proto";
import "google/protobuf/timestamp.proto"; import "google/protobuf/timestamp.proto";
import "google/protobuf/wrappers.proto"; import "google/protobuf/wrappers.proto";
import "google/protobuf/empty.proto";
message Test { message Test {
google.protobuf.BoolValue maybe = 1; google.protobuf.BoolValue maybe = 1;
google.protobuf.Timestamp ts = 2; google.protobuf.Timestamp ts = 2;
google.protobuf.Duration duration = 3; google.protobuf.Duration duration = 3;
google.protobuf.Int32Value important = 4; google.protobuf.Int32Value important = 4;
google.protobuf.Empty empty = 5;
} }

View File

@ -1,6 +1,6 @@
from typing import Any, Callable, Optional from typing import Any, Callable, Optional
import google.protobuf.wrappers_pb2 as wrappers import betterproto.lib.google.protobuf as protobuf
import pytest import pytest
from betterproto.tests.mocks import MockChannel from betterproto.tests.mocks import MockChannel
@ -9,15 +9,15 @@ from betterproto.tests.output_betterproto.googletypes_response.googletypes_respo
) )
test_cases = [ test_cases = [
(TestStub.get_double, wrappers.DoubleValue, 2.5), (TestStub.get_double, protobuf.DoubleValue, 2.5),
(TestStub.get_float, wrappers.FloatValue, 2.5), (TestStub.get_float, protobuf.FloatValue, 2.5),
(TestStub.get_int64, wrappers.Int64Value, -64), (TestStub.get_int64, protobuf.Int64Value, -64),
(TestStub.get_u_int64, wrappers.UInt64Value, 64), (TestStub.get_u_int64, protobuf.UInt64Value, 64),
(TestStub.get_int32, wrappers.Int32Value, -32), (TestStub.get_int32, protobuf.Int32Value, -32),
(TestStub.get_u_int32, wrappers.UInt32Value, 32), (TestStub.get_u_int32, protobuf.UInt32Value, 32),
(TestStub.get_bool, wrappers.BoolValue, True), (TestStub.get_bool, protobuf.BoolValue, True),
(TestStub.get_string, wrappers.StringValue, "string"), (TestStub.get_string, protobuf.StringValue, "string"),
(TestStub.get_bytes, wrappers.BytesValue, bytes(0xFF)[0:4]), (TestStub.get_bytes, protobuf.BytesValue, bytes(0xFF)[0:4]),
] ]

View File

@ -0,0 +1,16 @@
syntax = "proto3";
import "google/protobuf/empty.proto";
import "google/protobuf/struct.proto";
// Tests that imports are generated correctly when returning Google well-known types
service Test {
rpc GetEmpty (RequestMessage) returns (google.protobuf.Empty);
rpc GetStruct (RequestMessage) returns (google.protobuf.Struct);
rpc GetListValue (RequestMessage) returns (google.protobuf.ListValue);
rpc GetValue (RequestMessage) returns (google.protobuf.Value);
}
message RequestMessage {
}

View File

@ -0,0 +1,5 @@
{
"struct": {
"key": true
}
}

View File

@ -0,0 +1,7 @@
syntax = "proto3";
import "google/protobuf/struct.proto";
message Test {
google.protobuf.Struct struct = 1;
}

View File

@ -0,0 +1,11 @@
{
"value1": "hello world",
"value2": true,
"value3": 1,
"value4": null,
"value5": [
1,
2,
3
]
}

View File

@ -0,0 +1,13 @@
syntax = "proto3";
import "google/protobuf/struct.proto";
// Tests that fields of type google.protobuf.Value can contain arbitrary JSON-values.
message Test {
google.protobuf.Value value1 = 1;
google.protobuf.Value value2 = 2;
google.protobuf.Value value3 = 3;
google.protobuf.Value value4 = 4;
google.protobuf.Value value5 = 5;
}

View File

@ -8,6 +8,7 @@ class MockChannel(Channel):
def __init__(self, responses=None) -> None: def __init__(self, responses=None) -> None:
self.responses = responses if responses else [] self.responses = responses if responses else []
self.requests = [] self.requests = []
self._loop = None
def request(self, route, cardinality, request, response_type, **kwargs): def request(self, route, cardinality, request, response_type, **kwargs):
self.requests.append( self.requests.append(

View File

@ -0,0 +1,82 @@
import pytest
from ..compile.importing import get_ref_type
@pytest.mark.parametrize(
["google_type", "expected_name", "expected_import"],
[
(
".google.protobuf.Empty",
"betterproto_lib_google_protobuf.Empty",
"import betterproto.lib.google.protobuf as betterproto_lib_google_protobuf",
),
(
".google.protobuf.Struct",
"betterproto_lib_google_protobuf.Struct",
"import betterproto.lib.google.protobuf as betterproto_lib_google_protobuf",
),
(
".google.protobuf.ListValue",
"betterproto_lib_google_protobuf.ListValue",
"import betterproto.lib.google.protobuf as betterproto_lib_google_protobuf",
),
(
".google.protobuf.Value",
"betterproto_lib_google_protobuf.Value",
"import betterproto.lib.google.protobuf as betterproto_lib_google_protobuf",
),
],
)
def test_import_google_wellknown_types_non_wrappers(
google_type: str, expected_name: str, expected_import: str
):
imports = set()
name = get_ref_type(package="", imports=imports, type_name=google_type)
assert name == expected_name
assert imports.__contains__(expected_import)
@pytest.mark.parametrize(
["google_type", "expected_name"],
[
(".google.protobuf.DoubleValue", "Optional[float]"),
(".google.protobuf.FloatValue", "Optional[float]"),
(".google.protobuf.Int32Value", "Optional[int]"),
(".google.protobuf.Int64Value", "Optional[int]"),
(".google.protobuf.UInt32Value", "Optional[int]"),
(".google.protobuf.UInt64Value", "Optional[int]"),
(".google.protobuf.BoolValue", "Optional[bool]"),
(".google.protobuf.StringValue", "Optional[str]"),
(".google.protobuf.BytesValue", "Optional[bytes]"),
],
)
def test_importing_google_wrappers_unwraps_them(google_type: str, expected_name: str):
imports = set()
name = get_ref_type(package="", imports=imports, type_name=google_type)
assert name == expected_name
assert imports == set()
@pytest.mark.parametrize(
["google_type", "expected_name"],
[
(".google.protobuf.DoubleValue", "betterproto_lib_google_protobuf.DoubleValue"),
(".google.protobuf.FloatValue", "betterproto_lib_google_protobuf.FloatValue"),
(".google.protobuf.Int32Value", "betterproto_lib_google_protobuf.Int32Value"),
(".google.protobuf.Int64Value", "betterproto_lib_google_protobuf.Int64Value"),
(".google.protobuf.UInt32Value", "betterproto_lib_google_protobuf.UInt32Value"),
(".google.protobuf.UInt64Value", "betterproto_lib_google_protobuf.UInt64Value"),
(".google.protobuf.BoolValue", "betterproto_lib_google_protobuf.BoolValue"),
(".google.protobuf.StringValue", "betterproto_lib_google_protobuf.StringValue"),
(".google.protobuf.BytesValue", "betterproto_lib_google_protobuf.BytesValue"),
],
)
def test_importing_google_wrappers_without_unwrapping(
google_type: str, expected_name: str
):
name = get_ref_type(package="", imports=set(), type_name=google_type, unwrap=False)
assert name == expected_name

View File

@ -30,6 +30,10 @@ class TestCases:
test for test in _messages if get_test_case_json_data(test) test for test in _messages if get_test_case_json_data(test)
} }
unknown_xfail_tests = xfail - _all
if unknown_xfail_tests:
raise Exception(f"Unknown test(s) in config.py: {unknown_xfail_tests}")
self.all = self.apply_xfail_marks(_all, xfail) self.all = self.apply_xfail_marks(_all, xfail)
self.services = self.apply_xfail_marks(_services, xfail) self.services = self.apply_xfail_marks(_services, xfail)
self.messages = self.apply_xfail_marks(_messages, xfail) self.messages = self.apply_xfail_marks(_messages, xfail)
@ -110,7 +114,7 @@ def test_message_json(repeat, test_data: TestData) -> None:
message.from_json(json_data) message.from_json(json_data)
message_json = message.to_json(0) message_json = message.to_json(0)
assert json.loads(json_data) == json.loads(message_json) assert json.loads(message_json) == json.loads(json_data)
@pytest.mark.parametrize("test_data", test_cases.services, indirect=True) @pytest.mark.parametrize("test_data", test_cases.services, indirect=True)