Merge branch 'refs/heads/master_gh'

This commit is contained in:
Georg K
2024-10-13 03:14:00 +03:00
9 changed files with 823 additions and 961 deletions

View File

@@ -169,7 +169,22 @@ class Casing(builtin_enum.Enum):
SNAKE = snake_case #: A snake_case sterilization function.
PLACEHOLDER: Any = object()
class Placeholder:
__slots__ = ()
def __repr__(self) -> str:
return "<PLACEHOLDER>"
def __copy__(self) -> Self:
return self
def __deepcopy__(self, _) -> Self:
return self
# We can't simply use object() here because pydantic automatically performs deep-copy of mutable default values
# See #606
PLACEHOLDER: Any = Placeholder()
@dataclasses.dataclass(frozen=True)
@@ -206,7 +221,7 @@ def dataclass_field(
) -> dataclasses.Field:
"""Creates a dataclass field with attached protobuf metadata."""
return dataclasses.field(
default=None if optional else PLACEHOLDER,
default=None if optional else PLACEHOLDER, # type: ignore
metadata={
"betterproto": FieldMetadata(
number, proto_type, map_types, group, wraps, optional
@@ -1864,9 +1879,7 @@ class Message(ABC):
if getattr(values, field.name, None) is not None
]
if not set_fields:
raise ValueError(f"Group {group} has no value; all fields are None")
elif len(set_fields) > 1:
if len(set_fields) > 1:
set_fields_str = ", ".join(set_fields)
raise ValueError(
f"Group {group} has more than one value; fields {set_fields_str} are not None"

View File

@@ -500,35 +500,6 @@ class FieldCompiler(MessageCompiler):
.replace("type_", "")
)
@property
def default_value_string(self) -> str:
"""Python representation of the default proto value."""
if self.repeated:
return "[]"
if self.optional:
return "None"
if self.py_type == "int":
return "0"
if self.py_type == "float":
return "0.0"
elif self.py_type == "bool":
return "False"
elif self.py_type == "str":
return '""'
elif self.py_type == "bytes":
return 'b""'
elif self.field_type == "enum":
enum_proto_obj_name = self.proto_obj.type_name.split(".").pop()
enum = next(
e
for e in self.output_file.enums
if e.proto_obj.name == enum_proto_obj_name
)
return enum.default_value_string
else:
# Message type
return "None"
@property
def packed(self) -> bool:
"""True if the wire representation is a packed format."""
@@ -687,14 +658,6 @@ class EnumDefinitionCompiler(MessageCompiler):
]
super().__post_init__() # call MessageCompiler __post_init__
@property
def default_value_string(self) -> str:
"""Python representation of the default value for Enums.
As per the spec, this is the first value of the Enum.
"""
return str(self.entries[0].value) # ideally, should ALWAYS be int(0)!
@dataclass
class ServiceCompiler(ProtoContentBase):
@@ -755,30 +718,6 @@ class ServiceMethodCompiler(ProtoContentBase):
)
return f"/{package_part}{self.parent.proto_name}/{self.proto_name}"
@property
def py_input_message(self) -> Optional[MessageCompiler]:
"""Find the input message object.
Returns
-------
Optional[MessageCompiler]
Method instance representing the input message.
If not input message could be found or there are no
input messages, None is returned.
"""
package, name = parse_source_type_name(self.proto_obj.input_type)
# Nested types are currently flattened without dots.
# Todo: keep a fully quantified name in types, that is
# comparable with method.input_type
for msg in self.request.all_messages:
if (
msg.py_name == pythonize_class_name(name.replace(".", ""))
and msg.output_file.package == package
):
return msg
return None
@property
def py_input_message_type(self) -> str:
"""String representation of the Python type corresponding to the

View File

@@ -5,16 +5,9 @@
{% for i in output_file.python_module_imports|sort %}
import {{ i }}
{% endfor %}
{% set type_checking_imported = False %}
{% if output_file.pydantic_dataclasses %}
from typing import TYPE_CHECKING
{% set type_checking_imported = True %}
if TYPE_CHECKING:
from dataclasses import dataclass
else:
from pydantic.dataclasses import dataclass
from pydantic.dataclasses import dataclass
from pydantic.dataclasses import rebuild_dataclass
{%- else -%}
from dataclasses import dataclass
@@ -46,7 +39,7 @@ import grpclib
{{ i }}
{% endfor %}
{% if output_file.imports_type_checking_only and not type_checking_imported %}
{% if output_file.imports_type_checking_only %}
from typing import TYPE_CHECKING
if TYPE_CHECKING:

View File

@@ -23,7 +23,11 @@ class {{ enum.py_name }}(betterproto.Enum):
{% endfor %}
{% endif %}
{% for message in output_file.messages %}
{% if output_file.pydantic_dataclasses %}
@dataclass(eq=False, repr=False, config={"extra": "forbid"})
{% else %}
@dataclass(eq=False, repr=False)
{% endif %}
class {{ message.py_name }}(betterproto.Message):
{% if message.comment %}
{{ message.comment }}
@@ -70,7 +74,7 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
{% for method in service.methods %}
async def {{ method.py_name }}(self
{%- if not method.client_streaming -%}
{%- if method.py_input_message -%}, {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}"{%- endif -%}
, {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}"
{%- else -%}
{# Client streaming: need a request iterator instead #}
, {{ method.py_input_message_param }}_iterator: {{ output_file.typing_compiler.union(output_file.typing_compiler.async_iterable(method.py_input_message_type), output_file.typing_compiler.iterable(method.py_input_message_type)) }}
@@ -149,7 +153,7 @@ class {{ service.py_name }}Base(ServiceBase):
{% for method in service.methods %}
async def {{ method.py_name }}(self
{%- if not method.client_streaming -%}
{%- if method.py_input_message -%}, {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}"{%- endif -%}
, {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}"
{%- else -%}
{# Client streaming: need a request iterator instead #}
, {{ method.py_input_message_param }}_iterator: {{ output_file.typing_compiler.async_iterator(method.py_input_message_type) }}