Merge pull request #78 from boukeversteegh/pr/google
Basic general support for Google Protobuf
This commit is contained in:
		
							
								
								
									
										20
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										20
									
								
								README.md
									
									
									
									
									
								
							@@ -256,6 +256,7 @@ Google provides several well-known message types like a timestamp, duration, and
 | 
				
			|||||||
| `google.protobuf.duration`  | [`datetime.timedelta`][td]               | `0`                    |
 | 
					| `google.protobuf.duration`  | [`datetime.timedelta`][td]               | `0`                    |
 | 
				
			||||||
| `google.protobuf.timestamp` | Timezone-aware [`datetime.datetime`][dt] | `1970-01-01T00:00:00Z` |
 | 
					| `google.protobuf.timestamp` | Timezone-aware [`datetime.datetime`][dt] | `1970-01-01T00:00:00Z` |
 | 
				
			||||||
| `google.protobuf.*Value`    | `Optional[...]`                          | `None`                 |
 | 
					| `google.protobuf.*Value`    | `Optional[...]`                          | `None`                 |
 | 
				
			||||||
 | 
					| `google.protobuf.*`         | `betterproto.lib.google.protobuf.*`      | `None`                 |
 | 
				
			||||||
 | 
					
 | 
				
			||||||
[td]: https://docs.python.org/3/library/datetime.html#timedelta-objects
 | 
					[td]: https://docs.python.org/3/library/datetime.html#timedelta-objects
 | 
				
			||||||
[dt]: https://docs.python.org/3/library/datetime.html#datetime.datetime
 | 
					[dt]: https://docs.python.org/3/library/datetime.html#datetime.datetime
 | 
				
			||||||
@@ -354,6 +355,25 @@ $ pipenv run generate
 | 
				
			|||||||
$ pipenv run test
 | 
					$ pipenv run test
 | 
				
			||||||
```
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					### (Re)compiling Google Well-known Types
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Betterproto includes compiled versions for Google's well-known types at [betterproto/lib/google](betterproto/lib/google).
 | 
				
			||||||
 | 
					Be sure to regenerate these files when modifying the plugin output format, and validate by running the tests.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Normally, the plugin does not compile any references to `google.protobuf`, since they are pre-compiled. To force compilation of `google.protobuf`, use the option `--custom_opt=INCLUDE_GOOGLE`. 
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Assuming your `google.protobuf` source files (included with all releases of `protoc`) are located in `/usr/local/include`, you can regenerate them as follows:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					```sh
 | 
				
			||||||
 | 
					protoc \
 | 
				
			||||||
 | 
					    --plugin=protoc-gen-custom=betterproto/plugin.py \
 | 
				
			||||||
 | 
					    --custom_opt=INCLUDE_GOOGLE \
 | 
				
			||||||
 | 
					    --custom_out=betterproto/lib \
 | 
				
			||||||
 | 
					    -I /usr/local/include/ \
 | 
				
			||||||
 | 
					    /usr/local/include/google/protobuf/*.proto
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
### TODO
 | 
					### TODO
 | 
				
			||||||
 | 
					
 | 
				
			||||||
- [x] Fixed length fields
 | 
					- [x] Fixed length fields
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -941,19 +941,23 @@ def which_one_of(message: Message, group_name: str) -> Tuple[str, Any]:
 | 
				
			|||||||
    return (field_name, getattr(message, field_name))
 | 
					    return (field_name, getattr(message, field_name))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@dataclasses.dataclass
 | 
					# Circular import workaround: google.protobuf depends on base classes defined above.
 | 
				
			||||||
class _Duration(Message):
 | 
					from .lib.google.protobuf import (
 | 
				
			||||||
    # Signed seconds of the span of time. Must be from -315,576,000,000 to
 | 
					    Duration,
 | 
				
			||||||
    # +315,576,000,000 inclusive. Note: these bounds are computed from: 60
 | 
					    Timestamp,
 | 
				
			||||||
    # sec/min * 60 min/hr * 24 hr/day * 365.25 days/year * 10000 years
 | 
					    BoolValue,
 | 
				
			||||||
    seconds: int = int64_field(1)
 | 
					    BytesValue,
 | 
				
			||||||
    # Signed fractions of a second at nanosecond resolution of the span of time.
 | 
					    DoubleValue,
 | 
				
			||||||
    # Durations less than one second are represented with a 0 `seconds` field and
 | 
					    FloatValue,
 | 
				
			||||||
    # a positive or negative `nanos` field. For durations of one second or more,
 | 
					    Int32Value,
 | 
				
			||||||
    # a non-zero value for the `nanos` field must be of the same sign as the
 | 
					    Int64Value,
 | 
				
			||||||
    # `seconds` field. Must be from -999,999,999 to +999,999,999 inclusive.
 | 
					    StringValue,
 | 
				
			||||||
    nanos: int = int32_field(2)
 | 
					    UInt32Value,
 | 
				
			||||||
 | 
					    UInt64Value,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class _Duration(Duration):
 | 
				
			||||||
    def to_timedelta(self) -> timedelta:
 | 
					    def to_timedelta(self) -> timedelta:
 | 
				
			||||||
        return timedelta(seconds=self.seconds, microseconds=self.nanos / 1e3)
 | 
					        return timedelta(seconds=self.seconds, microseconds=self.nanos / 1e3)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -966,16 +970,7 @@ class _Duration(Message):
 | 
				
			|||||||
        return ".".join(parts) + "s"
 | 
					        return ".".join(parts) + "s"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@dataclasses.dataclass
 | 
					class _Timestamp(Timestamp):
 | 
				
			||||||
class _Timestamp(Message):
 | 
					 | 
				
			||||||
    # Represents seconds of UTC time since Unix epoch 1970-01-01T00:00:00Z. Must
 | 
					 | 
				
			||||||
    # be from 0001-01-01T00:00:00Z to 9999-12-31T23:59:59Z inclusive.
 | 
					 | 
				
			||||||
    seconds: int = int64_field(1)
 | 
					 | 
				
			||||||
    # Non-negative fractions of a second at nanosecond resolution. Negative
 | 
					 | 
				
			||||||
    # second values with fractions must still have non-negative nanos values that
 | 
					 | 
				
			||||||
    # count forward in time. Must be from 0 to 999,999,999 inclusive.
 | 
					 | 
				
			||||||
    nanos: int = int32_field(2)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def to_datetime(self) -> datetime:
 | 
					    def to_datetime(self) -> datetime:
 | 
				
			||||||
        ts = self.seconds + (self.nanos / 1e9)
 | 
					        ts = self.seconds + (self.nanos / 1e9)
 | 
				
			||||||
        return datetime.fromtimestamp(ts, tz=timezone.utc)
 | 
					        return datetime.fromtimestamp(ts, tz=timezone.utc)
 | 
				
			||||||
@@ -1016,63 +1011,18 @@ class _WrappedMessage(Message):
 | 
				
			|||||||
        return self
 | 
					        return self
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@dataclasses.dataclass
 | 
					 | 
				
			||||||
class _BoolValue(_WrappedMessage):
 | 
					 | 
				
			||||||
    value: bool = bool_field(1)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
@dataclasses.dataclass
 | 
					 | 
				
			||||||
class _Int32Value(_WrappedMessage):
 | 
					 | 
				
			||||||
    value: int = int32_field(1)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
@dataclasses.dataclass
 | 
					 | 
				
			||||||
class _UInt32Value(_WrappedMessage):
 | 
					 | 
				
			||||||
    value: int = uint32_field(1)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
@dataclasses.dataclass
 | 
					 | 
				
			||||||
class _Int64Value(_WrappedMessage):
 | 
					 | 
				
			||||||
    value: int = int64_field(1)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
@dataclasses.dataclass
 | 
					 | 
				
			||||||
class _UInt64Value(_WrappedMessage):
 | 
					 | 
				
			||||||
    value: int = uint64_field(1)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
@dataclasses.dataclass
 | 
					 | 
				
			||||||
class _FloatValue(_WrappedMessage):
 | 
					 | 
				
			||||||
    value: float = float_field(1)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
@dataclasses.dataclass
 | 
					 | 
				
			||||||
class _DoubleValue(_WrappedMessage):
 | 
					 | 
				
			||||||
    value: float = double_field(1)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
@dataclasses.dataclass
 | 
					 | 
				
			||||||
class _StringValue(_WrappedMessage):
 | 
					 | 
				
			||||||
    value: str = string_field(1)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
@dataclasses.dataclass
 | 
					 | 
				
			||||||
class _BytesValue(_WrappedMessage):
 | 
					 | 
				
			||||||
    value: bytes = bytes_field(1)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def _get_wrapper(proto_type: str) -> Type:
 | 
					def _get_wrapper(proto_type: str) -> Type:
 | 
				
			||||||
    """Get the wrapper message class for a wrapped type."""
 | 
					    """Get the wrapper message class for a wrapped type."""
 | 
				
			||||||
    return {
 | 
					    return {
 | 
				
			||||||
        TYPE_BOOL: _BoolValue,
 | 
					        TYPE_BOOL: BoolValue,
 | 
				
			||||||
        TYPE_INT32: _Int32Value,
 | 
					        TYPE_INT32: Int32Value,
 | 
				
			||||||
        TYPE_UINT32: _UInt32Value,
 | 
					        TYPE_UINT32: UInt32Value,
 | 
				
			||||||
        TYPE_INT64: _Int64Value,
 | 
					        TYPE_INT64: Int64Value,
 | 
				
			||||||
        TYPE_UINT64: _UInt64Value,
 | 
					        TYPE_UINT64: UInt64Value,
 | 
				
			||||||
        TYPE_FLOAT: _FloatValue,
 | 
					        TYPE_FLOAT: FloatValue,
 | 
				
			||||||
        TYPE_DOUBLE: _DoubleValue,
 | 
					        TYPE_DOUBLE: DoubleValue,
 | 
				
			||||||
        TYPE_STRING: _StringValue,
 | 
					        TYPE_STRING: StringValue,
 | 
				
			||||||
        TYPE_BYTES: _BytesValue,
 | 
					        TYPE_BYTES: BytesValue,
 | 
				
			||||||
    }[proto_type]
 | 
					    }[proto_type]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										0
									
								
								betterproto/compile/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								betterproto/compile/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										70
									
								
								betterproto/compile/importing.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										70
									
								
								betterproto/compile/importing.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,70 @@
 | 
				
			|||||||
 | 
					from typing import Dict, Type
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import stringcase
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from betterproto import safe_snake_case
 | 
				
			||||||
 | 
					from betterproto.lib.google import protobuf as google_protobuf
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					WRAPPER_TYPES: Dict[str, Type] = {
 | 
				
			||||||
 | 
					    "google.protobuf.DoubleValue": google_protobuf.DoubleValue,
 | 
				
			||||||
 | 
					    "google.protobuf.FloatValue": google_protobuf.FloatValue,
 | 
				
			||||||
 | 
					    "google.protobuf.Int32Value": google_protobuf.Int32Value,
 | 
				
			||||||
 | 
					    "google.protobuf.Int64Value": google_protobuf.Int64Value,
 | 
				
			||||||
 | 
					    "google.protobuf.UInt32Value": google_protobuf.UInt32Value,
 | 
				
			||||||
 | 
					    "google.protobuf.UInt64Value": google_protobuf.UInt64Value,
 | 
				
			||||||
 | 
					    "google.protobuf.BoolValue": google_protobuf.BoolValue,
 | 
				
			||||||
 | 
					    "google.protobuf.StringValue": google_protobuf.StringValue,
 | 
				
			||||||
 | 
					    "google.protobuf.BytesValue": google_protobuf.BytesValue,
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def get_ref_type(
 | 
				
			||||||
 | 
					    package: str, imports: set, type_name: str, unwrap: bool = True
 | 
				
			||||||
 | 
					) -> str:
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Return a Python type name for a proto type reference. Adds the import if
 | 
				
			||||||
 | 
					    necessary. Unwraps well known type if required.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    # If the package name is a blank string, then this should still work
 | 
				
			||||||
 | 
					    # because by convention packages are lowercase and message/enum types are
 | 
				
			||||||
 | 
					    # pascal-cased. May require refactoring in the future.
 | 
				
			||||||
 | 
					    type_name = type_name.lstrip(".")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    is_wrapper = type_name in WRAPPER_TYPES
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if unwrap:
 | 
				
			||||||
 | 
					        if is_wrapper:
 | 
				
			||||||
 | 
					            wrapped_type = type(WRAPPER_TYPES[type_name]().value)
 | 
				
			||||||
 | 
					            return f"Optional[{wrapped_type.__name__}]"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if type_name == "google.protobuf.Duration":
 | 
				
			||||||
 | 
					            return "timedelta"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if type_name == "google.protobuf.Timestamp":
 | 
				
			||||||
 | 
					            return "datetime"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if type_name.startswith(package):
 | 
				
			||||||
 | 
					        parts = type_name.lstrip(package).lstrip(".").split(".")
 | 
				
			||||||
 | 
					        if len(parts) == 1 or (len(parts) > 1 and parts[0][0] == parts[0][0].upper()):
 | 
				
			||||||
 | 
					            # This is the current package, which has nested types flattened.
 | 
				
			||||||
 | 
					            # foo.bar_thing => FooBarThing
 | 
				
			||||||
 | 
					            cased = [stringcase.pascalcase(part) for part in parts]
 | 
				
			||||||
 | 
					            type_name = f'"{"".join(cased)}"'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Use precompiled classes for google.protobuf.* objects
 | 
				
			||||||
 | 
					    if type_name.startswith("google.protobuf.") and type_name.count(".") == 2:
 | 
				
			||||||
 | 
					        type_name = type_name.rsplit(".", maxsplit=1)[1]
 | 
				
			||||||
 | 
					        import_package = "betterproto.lib.google.protobuf"
 | 
				
			||||||
 | 
					        import_alias = safe_snake_case(import_package)
 | 
				
			||||||
 | 
					        imports.add(f"import {import_package} as {import_alias}")
 | 
				
			||||||
 | 
					        return f"{import_alias}.{type_name}"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if "." in type_name:
 | 
				
			||||||
 | 
					        # This is imported from another package. No need
 | 
				
			||||||
 | 
					        # to use a forward ref and we need to add the import.
 | 
				
			||||||
 | 
					        parts = type_name.split(".")
 | 
				
			||||||
 | 
					        parts[-1] = stringcase.pascalcase(parts[-1])
 | 
				
			||||||
 | 
					        imports.add(f"from .{'.'.join(parts[:-2])} import {parts[-2]}")
 | 
				
			||||||
 | 
					        type_name = f"{parts[-2]}.{parts[-1]}"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return type_name
 | 
				
			||||||
							
								
								
									
										0
									
								
								betterproto/lib/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								betterproto/lib/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										0
									
								
								betterproto/lib/google/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								betterproto/lib/google/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										1312
									
								
								betterproto/lib/google/protobuf.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1312
									
								
								betterproto/lib/google/protobuf.py
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							@@ -1,13 +1,15 @@
 | 
				
			|||||||
#!/usr/bin/env python
 | 
					#!/usr/bin/env python
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from collections import defaultdict
 | 
					 | 
				
			||||||
import itertools
 | 
					import itertools
 | 
				
			||||||
import os.path
 | 
					import os.path
 | 
				
			||||||
 | 
					import re
 | 
				
			||||||
import stringcase
 | 
					import stringcase
 | 
				
			||||||
import sys
 | 
					import sys
 | 
				
			||||||
import textwrap
 | 
					import textwrap
 | 
				
			||||||
from typing import Dict, List, Optional, Type
 | 
					from typing import List
 | 
				
			||||||
from betterproto.casing import safe_snake_case
 | 
					from betterproto.casing import safe_snake_case
 | 
				
			||||||
 | 
					from betterproto.compile.importing import get_ref_type
 | 
				
			||||||
 | 
					import betterproto
 | 
				
			||||||
 | 
					
 | 
				
			||||||
try:
 | 
					try:
 | 
				
			||||||
    # betterproto[compiler] specific dependencies
 | 
					    # betterproto[compiler] specific dependencies
 | 
				
			||||||
@@ -33,70 +35,6 @@ except ImportError as err:
 | 
				
			|||||||
    raise SystemExit(1)
 | 
					    raise SystemExit(1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
WRAPPER_TYPES: Dict[str, Optional[Type]] = defaultdict(
 | 
					 | 
				
			||||||
    lambda: None,
 | 
					 | 
				
			||||||
    {
 | 
					 | 
				
			||||||
        "google.protobuf.DoubleValue": google_wrappers.DoubleValue,
 | 
					 | 
				
			||||||
        "google.protobuf.FloatValue": google_wrappers.FloatValue,
 | 
					 | 
				
			||||||
        "google.protobuf.Int64Value": google_wrappers.Int64Value,
 | 
					 | 
				
			||||||
        "google.protobuf.UInt64Value": google_wrappers.UInt64Value,
 | 
					 | 
				
			||||||
        "google.protobuf.Int32Value": google_wrappers.Int32Value,
 | 
					 | 
				
			||||||
        "google.protobuf.UInt32Value": google_wrappers.UInt32Value,
 | 
					 | 
				
			||||||
        "google.protobuf.BoolValue": google_wrappers.BoolValue,
 | 
					 | 
				
			||||||
        "google.protobuf.StringValue": google_wrappers.StringValue,
 | 
					 | 
				
			||||||
        "google.protobuf.BytesValue": google_wrappers.BytesValue,
 | 
					 | 
				
			||||||
    },
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def get_ref_type(
 | 
					 | 
				
			||||||
    package: str, imports: set, type_name: str, unwrap: bool = True
 | 
					 | 
				
			||||||
) -> str:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Return a Python type name for a proto type reference. Adds the import if
 | 
					 | 
				
			||||||
    necessary. Unwraps well known type if required.
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    # If the package name is a blank string, then this should still work
 | 
					 | 
				
			||||||
    # because by convention packages are lowercase and message/enum types are
 | 
					 | 
				
			||||||
    # pascal-cased. May require refactoring in the future.
 | 
					 | 
				
			||||||
    type_name = type_name.lstrip(".")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Check if type is wrapper.
 | 
					 | 
				
			||||||
    wrapper_class = WRAPPER_TYPES[type_name]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if unwrap:
 | 
					 | 
				
			||||||
        if wrapper_class:
 | 
					 | 
				
			||||||
            wrapped_type = type(wrapper_class().value)
 | 
					 | 
				
			||||||
            return f"Optional[{wrapped_type.__name__}]"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if type_name == "google.protobuf.Duration":
 | 
					 | 
				
			||||||
            return "timedelta"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if type_name == "google.protobuf.Timestamp":
 | 
					 | 
				
			||||||
            return "datetime"
 | 
					 | 
				
			||||||
    elif wrapper_class:
 | 
					 | 
				
			||||||
        imports.add(f"from {wrapper_class.__module__} import {wrapper_class.__name__}")
 | 
					 | 
				
			||||||
        return f"{wrapper_class.__name__}"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if type_name.startswith(package):
 | 
					 | 
				
			||||||
        parts = type_name.lstrip(package).lstrip(".").split(".")
 | 
					 | 
				
			||||||
        if len(parts) == 1 or (len(parts) > 1 and parts[0][0] == parts[0][0].upper()):
 | 
					 | 
				
			||||||
            # This is the current package, which has nested types flattened.
 | 
					 | 
				
			||||||
            # foo.bar_thing => FooBarThing
 | 
					 | 
				
			||||||
            cased = [stringcase.pascalcase(part) for part in parts]
 | 
					 | 
				
			||||||
            type_name = f'"{"".join(cased)}"'
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if "." in type_name:
 | 
					 | 
				
			||||||
        # This is imported from another package. No need
 | 
					 | 
				
			||||||
        # to use a forward ref and we need to add the import.
 | 
					 | 
				
			||||||
        parts = type_name.split(".")
 | 
					 | 
				
			||||||
        parts[-1] = stringcase.pascalcase(parts[-1])
 | 
					 | 
				
			||||||
        imports.add(f"from .{'.'.join(parts[:-2])} import {parts[-2]}")
 | 
					 | 
				
			||||||
        type_name = f"{parts[-2]}.{parts[-1]}"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    return type_name
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def py_type(
 | 
					def py_type(
 | 
				
			||||||
    package: str,
 | 
					    package: str,
 | 
				
			||||||
    imports: set,
 | 
					    imports: set,
 | 
				
			||||||
@@ -182,6 +120,8 @@ def get_comment(proto_file, path: List[int], indent: int = 4) -> str:
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def generate_code(request, response):
 | 
					def generate_code(request, response):
 | 
				
			||||||
 | 
					    plugin_options = request.parameter.split(",") if request.parameter else []
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    env = jinja2.Environment(
 | 
					    env = jinja2.Environment(
 | 
				
			||||||
        trim_blocks=True,
 | 
					        trim_blocks=True,
 | 
				
			||||||
        lstrip_blocks=True,
 | 
					        lstrip_blocks=True,
 | 
				
			||||||
@@ -192,7 +132,8 @@ def generate_code(request, response):
 | 
				
			|||||||
    output_map = {}
 | 
					    output_map = {}
 | 
				
			||||||
    for proto_file in request.proto_file:
 | 
					    for proto_file in request.proto_file:
 | 
				
			||||||
        out = proto_file.package
 | 
					        out = proto_file.package
 | 
				
			||||||
        if out == "google.protobuf":
 | 
					
 | 
				
			||||||
 | 
					        if out == "google.protobuf" and "INCLUDE_GOOGLE" not in plugin_options:
 | 
				
			||||||
            continue
 | 
					            continue
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if not out:
 | 
					        if not out:
 | 
				
			||||||
@@ -255,11 +196,13 @@ def generate_code(request, response):
 | 
				
			|||||||
                        field_type = f.Type.Name(f.type).lower()[5:]
 | 
					                        field_type = f.Type.Name(f.type).lower()[5:]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                        field_wraps = ""
 | 
					                        field_wraps = ""
 | 
				
			||||||
                        if f.type_name.startswith(
 | 
					                        match_wrapper = re.match(
 | 
				
			||||||
                            ".google.protobuf"
 | 
					                            r"\.google\.protobuf\.(.+)Value", f.type_name
 | 
				
			||||||
                        ) and f.type_name.endswith("Value"):
 | 
					                        )
 | 
				
			||||||
                            w = f.type_name.split(".").pop()[:-5].upper()
 | 
					                        if match_wrapper:
 | 
				
			||||||
                            field_wraps = f"betterproto.TYPE_{w}"
 | 
					                            wrapped_type = "TYPE_" + match_wrapper.group(1).upper()
 | 
				
			||||||
 | 
					                            if hasattr(betterproto, wrapped_type):
 | 
				
			||||||
 | 
					                                field_wraps = f"betterproto.{wrapped_type}"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                        map_types = None
 | 
					                        map_types = None
 | 
				
			||||||
                        if f.type == 11:
 | 
					                        if f.type == 11:
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										0
									
								
								betterproto/tests/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								betterproto/tests/__init__.py
									
									
									
									
									
										Normal file
									
								
							@@ -8,10 +8,11 @@ tests = {
 | 
				
			|||||||
    "import_circular_dependency",  # failing because of other bugs now
 | 
					    "import_circular_dependency",  # failing because of other bugs now
 | 
				
			||||||
    "import_packages_same_name",  # 25
 | 
					    "import_packages_same_name",  # 25
 | 
				
			||||||
    "oneof_enum",  # 63
 | 
					    "oneof_enum",  # 63
 | 
				
			||||||
    "googletypes_service_returns_empty",  # 9
 | 
					 | 
				
			||||||
    "casing_message_field_uppercase",  # 11
 | 
					    "casing_message_field_uppercase",  # 11
 | 
				
			||||||
    "namespace_keywords",  # 70
 | 
					    "namespace_keywords",  # 70
 | 
				
			||||||
    "namespace_builtin_types",  # 53
 | 
					    "namespace_builtin_types",  # 53
 | 
				
			||||||
 | 
					    "googletypes_struct",  # 9
 | 
				
			||||||
 | 
					    "googletypes_value",  # 9
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
services = {
 | 
					services = {
 | 
				
			||||||
@@ -20,4 +21,5 @@ services = {
 | 
				
			|||||||
    "service",
 | 
					    "service",
 | 
				
			||||||
    "import_service_input_message",
 | 
					    "import_service_input_message",
 | 
				
			||||||
    "googletypes_service_returns_empty",
 | 
					    "googletypes_service_returns_empty",
 | 
				
			||||||
 | 
					    "googletypes_service_returns_googletype",
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,5 +1,7 @@
 | 
				
			|||||||
{
 | 
					{
 | 
				
			||||||
  "maybe": false,
 | 
					  "maybe": false,
 | 
				
			||||||
  "ts": "1972-01-01T10:00:20.021Z",
 | 
					  "ts": "1972-01-01T10:00:20.021Z",
 | 
				
			||||||
  "duration": "1.200s"
 | 
					  "duration": "1.200s",
 | 
				
			||||||
 | 
					  "important": 10,
 | 
				
			||||||
 | 
					  "empty": {}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -3,10 +3,12 @@ syntax = "proto3";
 | 
				
			|||||||
import "google/protobuf/duration.proto";
 | 
					import "google/protobuf/duration.proto";
 | 
				
			||||||
import "google/protobuf/timestamp.proto";
 | 
					import "google/protobuf/timestamp.proto";
 | 
				
			||||||
import "google/protobuf/wrappers.proto";
 | 
					import "google/protobuf/wrappers.proto";
 | 
				
			||||||
 | 
					import "google/protobuf/empty.proto";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
message Test {
 | 
					message Test {
 | 
				
			||||||
  google.protobuf.BoolValue maybe = 1;
 | 
					  google.protobuf.BoolValue maybe = 1;
 | 
				
			||||||
  google.protobuf.Timestamp ts = 2;
 | 
					  google.protobuf.Timestamp ts = 2;
 | 
				
			||||||
  google.protobuf.Duration duration = 3;
 | 
					  google.protobuf.Duration duration = 3;
 | 
				
			||||||
  google.protobuf.Int32Value important = 4;
 | 
					  google.protobuf.Int32Value important = 4;
 | 
				
			||||||
 | 
					  google.protobuf.Empty empty = 5;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,6 +1,6 @@
 | 
				
			|||||||
from typing import Any, Callable, Optional
 | 
					from typing import Any, Callable, Optional
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import google.protobuf.wrappers_pb2 as wrappers
 | 
					import betterproto.lib.google.protobuf as protobuf
 | 
				
			||||||
import pytest
 | 
					import pytest
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from betterproto.tests.mocks import MockChannel
 | 
					from betterproto.tests.mocks import MockChannel
 | 
				
			||||||
@@ -9,15 +9,15 @@ from betterproto.tests.output_betterproto.googletypes_response.googletypes_respo
 | 
				
			|||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
test_cases = [
 | 
					test_cases = [
 | 
				
			||||||
    (TestStub.get_double, wrappers.DoubleValue, 2.5),
 | 
					    (TestStub.get_double, protobuf.DoubleValue, 2.5),
 | 
				
			||||||
    (TestStub.get_float, wrappers.FloatValue, 2.5),
 | 
					    (TestStub.get_float, protobuf.FloatValue, 2.5),
 | 
				
			||||||
    (TestStub.get_int64, wrappers.Int64Value, -64),
 | 
					    (TestStub.get_int64, protobuf.Int64Value, -64),
 | 
				
			||||||
    (TestStub.get_u_int64, wrappers.UInt64Value, 64),
 | 
					    (TestStub.get_u_int64, protobuf.UInt64Value, 64),
 | 
				
			||||||
    (TestStub.get_int32, wrappers.Int32Value, -32),
 | 
					    (TestStub.get_int32, protobuf.Int32Value, -32),
 | 
				
			||||||
    (TestStub.get_u_int32, wrappers.UInt32Value, 32),
 | 
					    (TestStub.get_u_int32, protobuf.UInt32Value, 32),
 | 
				
			||||||
    (TestStub.get_bool, wrappers.BoolValue, True),
 | 
					    (TestStub.get_bool, protobuf.BoolValue, True),
 | 
				
			||||||
    (TestStub.get_string, wrappers.StringValue, "string"),
 | 
					    (TestStub.get_string, protobuf.StringValue, "string"),
 | 
				
			||||||
    (TestStub.get_bytes, wrappers.BytesValue, bytes(0xFF)[0:4]),
 | 
					    (TestStub.get_bytes, protobuf.BytesValue, bytes(0xFF)[0:4]),
 | 
				
			||||||
]
 | 
					]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -0,0 +1,16 @@
 | 
				
			|||||||
 | 
					syntax = "proto3";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import "google/protobuf/empty.proto";
 | 
				
			||||||
 | 
					import "google/protobuf/struct.proto";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Tests that imports are generated correctly when returning Google well-known types
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					service Test {
 | 
				
			||||||
 | 
					    rpc GetEmpty (RequestMessage) returns (google.protobuf.Empty);
 | 
				
			||||||
 | 
					    rpc GetStruct (RequestMessage) returns (google.protobuf.Struct);
 | 
				
			||||||
 | 
					    rpc GetListValue (RequestMessage) returns (google.protobuf.ListValue);
 | 
				
			||||||
 | 
					    rpc GetValue (RequestMessage) returns (google.protobuf.Value);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					message RequestMessage {
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -0,0 +1,5 @@
 | 
				
			|||||||
 | 
					{
 | 
				
			||||||
 | 
					  "struct": {
 | 
				
			||||||
 | 
					    "key": true
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -0,0 +1,7 @@
 | 
				
			|||||||
 | 
					syntax = "proto3";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import "google/protobuf/struct.proto";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					message Test {
 | 
				
			||||||
 | 
					  google.protobuf.Struct struct = 1;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -0,0 +1,11 @@
 | 
				
			|||||||
 | 
					{
 | 
				
			||||||
 | 
					  "value1": "hello world",
 | 
				
			||||||
 | 
					  "value2": true,
 | 
				
			||||||
 | 
					  "value3": 1,
 | 
				
			||||||
 | 
					  "value4": null,
 | 
				
			||||||
 | 
					  "value5": [
 | 
				
			||||||
 | 
					    1,
 | 
				
			||||||
 | 
					    2,
 | 
				
			||||||
 | 
					    3
 | 
				
			||||||
 | 
					  ]
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -0,0 +1,13 @@
 | 
				
			|||||||
 | 
					syntax = "proto3";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import "google/protobuf/struct.proto";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Tests that fields of type google.protobuf.Value can contain arbitrary JSON-values.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					message Test {
 | 
				
			||||||
 | 
					  google.protobuf.Value value1 = 1;
 | 
				
			||||||
 | 
					  google.protobuf.Value value2 = 2;
 | 
				
			||||||
 | 
					  google.protobuf.Value value3 = 3;
 | 
				
			||||||
 | 
					  google.protobuf.Value value4 = 4;
 | 
				
			||||||
 | 
					  google.protobuf.Value value5 = 5;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -8,6 +8,7 @@ class MockChannel(Channel):
 | 
				
			|||||||
    def __init__(self, responses=None) -> None:
 | 
					    def __init__(self, responses=None) -> None:
 | 
				
			||||||
        self.responses = responses if responses else []
 | 
					        self.responses = responses if responses else []
 | 
				
			||||||
        self.requests = []
 | 
					        self.requests = []
 | 
				
			||||||
 | 
					        self._loop = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def request(self, route, cardinality, request, response_type, **kwargs):
 | 
					    def request(self, route, cardinality, request, response_type, **kwargs):
 | 
				
			||||||
        self.requests.append(
 | 
					        self.requests.append(
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										82
									
								
								betterproto/tests/test_get_ref_type.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										82
									
								
								betterproto/tests/test_get_ref_type.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,82 @@
 | 
				
			|||||||
 | 
					import pytest
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from ..compile.importing import get_ref_type
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.parametrize(
 | 
				
			||||||
 | 
					    ["google_type", "expected_name", "expected_import"],
 | 
				
			||||||
 | 
					    [
 | 
				
			||||||
 | 
					        (
 | 
				
			||||||
 | 
					            ".google.protobuf.Empty",
 | 
				
			||||||
 | 
					            "betterproto_lib_google_protobuf.Empty",
 | 
				
			||||||
 | 
					            "import betterproto.lib.google.protobuf as betterproto_lib_google_protobuf",
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					        (
 | 
				
			||||||
 | 
					            ".google.protobuf.Struct",
 | 
				
			||||||
 | 
					            "betterproto_lib_google_protobuf.Struct",
 | 
				
			||||||
 | 
					            "import betterproto.lib.google.protobuf as betterproto_lib_google_protobuf",
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					        (
 | 
				
			||||||
 | 
					            ".google.protobuf.ListValue",
 | 
				
			||||||
 | 
					            "betterproto_lib_google_protobuf.ListValue",
 | 
				
			||||||
 | 
					            "import betterproto.lib.google.protobuf as betterproto_lib_google_protobuf",
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					        (
 | 
				
			||||||
 | 
					            ".google.protobuf.Value",
 | 
				
			||||||
 | 
					            "betterproto_lib_google_protobuf.Value",
 | 
				
			||||||
 | 
					            "import betterproto.lib.google.protobuf as betterproto_lib_google_protobuf",
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					def test_import_google_wellknown_types_non_wrappers(
 | 
				
			||||||
 | 
					    google_type: str, expected_name: str, expected_import: str
 | 
				
			||||||
 | 
					):
 | 
				
			||||||
 | 
					    imports = set()
 | 
				
			||||||
 | 
					    name = get_ref_type(package="", imports=imports, type_name=google_type)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert name == expected_name
 | 
				
			||||||
 | 
					    assert imports.__contains__(expected_import)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.parametrize(
 | 
				
			||||||
 | 
					    ["google_type", "expected_name"],
 | 
				
			||||||
 | 
					    [
 | 
				
			||||||
 | 
					        (".google.protobuf.DoubleValue", "Optional[float]"),
 | 
				
			||||||
 | 
					        (".google.protobuf.FloatValue", "Optional[float]"),
 | 
				
			||||||
 | 
					        (".google.protobuf.Int32Value", "Optional[int]"),
 | 
				
			||||||
 | 
					        (".google.protobuf.Int64Value", "Optional[int]"),
 | 
				
			||||||
 | 
					        (".google.protobuf.UInt32Value", "Optional[int]"),
 | 
				
			||||||
 | 
					        (".google.protobuf.UInt64Value", "Optional[int]"),
 | 
				
			||||||
 | 
					        (".google.protobuf.BoolValue", "Optional[bool]"),
 | 
				
			||||||
 | 
					        (".google.protobuf.StringValue", "Optional[str]"),
 | 
				
			||||||
 | 
					        (".google.protobuf.BytesValue", "Optional[bytes]"),
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					def test_importing_google_wrappers_unwraps_them(google_type: str, expected_name: str):
 | 
				
			||||||
 | 
					    imports = set()
 | 
				
			||||||
 | 
					    name = get_ref_type(package="", imports=imports, type_name=google_type)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert name == expected_name
 | 
				
			||||||
 | 
					    assert imports == set()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.parametrize(
 | 
				
			||||||
 | 
					    ["google_type", "expected_name"],
 | 
				
			||||||
 | 
					    [
 | 
				
			||||||
 | 
					        (".google.protobuf.DoubleValue", "betterproto_lib_google_protobuf.DoubleValue"),
 | 
				
			||||||
 | 
					        (".google.protobuf.FloatValue", "betterproto_lib_google_protobuf.FloatValue"),
 | 
				
			||||||
 | 
					        (".google.protobuf.Int32Value", "betterproto_lib_google_protobuf.Int32Value"),
 | 
				
			||||||
 | 
					        (".google.protobuf.Int64Value", "betterproto_lib_google_protobuf.Int64Value"),
 | 
				
			||||||
 | 
					        (".google.protobuf.UInt32Value", "betterproto_lib_google_protobuf.UInt32Value"),
 | 
				
			||||||
 | 
					        (".google.protobuf.UInt64Value", "betterproto_lib_google_protobuf.UInt64Value"),
 | 
				
			||||||
 | 
					        (".google.protobuf.BoolValue", "betterproto_lib_google_protobuf.BoolValue"),
 | 
				
			||||||
 | 
					        (".google.protobuf.StringValue", "betterproto_lib_google_protobuf.StringValue"),
 | 
				
			||||||
 | 
					        (".google.protobuf.BytesValue", "betterproto_lib_google_protobuf.BytesValue"),
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					def test_importing_google_wrappers_without_unwrapping(
 | 
				
			||||||
 | 
					    google_type: str, expected_name: str
 | 
				
			||||||
 | 
					):
 | 
				
			||||||
 | 
					    name = get_ref_type(package="", imports=set(), type_name=google_type, unwrap=False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert name == expected_name
 | 
				
			||||||
@@ -30,6 +30,10 @@ class TestCases:
 | 
				
			|||||||
            test for test in _messages if get_test_case_json_data(test)
 | 
					            test for test in _messages if get_test_case_json_data(test)
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        unknown_xfail_tests = xfail - _all
 | 
				
			||||||
 | 
					        if unknown_xfail_tests:
 | 
				
			||||||
 | 
					            raise Exception(f"Unknown test(s) in config.py: {unknown_xfail_tests}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.all = self.apply_xfail_marks(_all, xfail)
 | 
					        self.all = self.apply_xfail_marks(_all, xfail)
 | 
				
			||||||
        self.services = self.apply_xfail_marks(_services, xfail)
 | 
					        self.services = self.apply_xfail_marks(_services, xfail)
 | 
				
			||||||
        self.messages = self.apply_xfail_marks(_messages, xfail)
 | 
					        self.messages = self.apply_xfail_marks(_messages, xfail)
 | 
				
			||||||
@@ -110,7 +114,7 @@ def test_message_json(repeat, test_data: TestData) -> None:
 | 
				
			|||||||
        message.from_json(json_data)
 | 
					        message.from_json(json_data)
 | 
				
			||||||
        message_json = message.to_json(0)
 | 
					        message_json = message.to_json(0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        assert json.loads(json_data) == json.loads(message_json)
 | 
					        assert json.loads(message_json) == json.loads(json_data)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@pytest.mark.parametrize("test_data", test_cases.services, indirect=True)
 | 
					@pytest.mark.parametrize("test_data", test_cases.services, indirect=True)
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user