Merge branch 'refs/heads/master_gh'

This commit is contained in:
Georg K 2024-10-13 03:14:00 +03:00
commit f8ecc42478
9 changed files with 823 additions and 961 deletions

1635
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -12,13 +12,12 @@ packages = [
]
[tool.poetry.dependencies]
python = "^3.7"
python = "^3.8"
black = { version = ">=23.1.0", optional = true }
grpclib = "^0.4.1"
importlib-metadata = { version = ">=1.6.0", python = "<3.8" }
jinja2 = { version = ">=3.0.3", optional = true }
python-dateutil = "^2.8"
isort = {version = "^5.11.5", optional = true}
isort = { version = "^5.11.5", optional = true }
typing-extensions = "^4.7.1"
betterproto-rust-codec = { version = "0.1.1", optional = true }
@ -26,7 +25,7 @@ betterproto-rust-codec = { version = "0.1.1", optional = true }
asv = "^0.4.2"
bpython = "^0.19"
jinja2 = ">=3.0.3"
mypy = "^0.930"
mypy = "^1.11.2"
sphinx = "3.1.2"
sphinx-rtd-theme = "0.5.0"
pre-commit = "^2.17.0"

View File

@ -169,7 +169,22 @@ class Casing(builtin_enum.Enum):
SNAKE = snake_case #: A snake_case sterilization function.
PLACEHOLDER: Any = object()
class Placeholder:
__slots__ = ()
def __repr__(self) -> str:
return "<PLACEHOLDER>"
def __copy__(self) -> Self:
return self
def __deepcopy__(self, _) -> Self:
return self
# We can't simply use object() here because pydantic automatically performs deep-copy of mutable default values
# See #606
PLACEHOLDER: Any = Placeholder()
@dataclasses.dataclass(frozen=True)
@ -206,7 +221,7 @@ def dataclass_field(
) -> dataclasses.Field:
"""Creates a dataclass field with attached protobuf metadata."""
return dataclasses.field(
default=None if optional else PLACEHOLDER,
default=None if optional else PLACEHOLDER, # type: ignore
metadata={
"betterproto": FieldMetadata(
number, proto_type, map_types, group, wraps, optional
@ -1864,9 +1879,7 @@ class Message(ABC):
if getattr(values, field.name, None) is not None
]
if not set_fields:
raise ValueError(f"Group {group} has no value; all fields are None")
elif len(set_fields) > 1:
if len(set_fields) > 1:
set_fields_str = ", ".join(set_fields)
raise ValueError(
f"Group {group} has more than one value; fields {set_fields_str} are not None"

View File

@ -500,35 +500,6 @@ class FieldCompiler(MessageCompiler):
.replace("type_", "")
)
@property
def default_value_string(self) -> str:
"""Python representation of the default proto value."""
if self.repeated:
return "[]"
if self.optional:
return "None"
if self.py_type == "int":
return "0"
if self.py_type == "float":
return "0.0"
elif self.py_type == "bool":
return "False"
elif self.py_type == "str":
return '""'
elif self.py_type == "bytes":
return 'b""'
elif self.field_type == "enum":
enum_proto_obj_name = self.proto_obj.type_name.split(".").pop()
enum = next(
e
for e in self.output_file.enums
if e.proto_obj.name == enum_proto_obj_name
)
return enum.default_value_string
else:
# Message type
return "None"
@property
def packed(self) -> bool:
"""True if the wire representation is a packed format."""
@ -687,14 +658,6 @@ class EnumDefinitionCompiler(MessageCompiler):
]
super().__post_init__() # call MessageCompiler __post_init__
@property
def default_value_string(self) -> str:
"""Python representation of the default value for Enums.
As per the spec, this is the first value of the Enum.
"""
return str(self.entries[0].value) # ideally, should ALWAYS be int(0)!
@dataclass
class ServiceCompiler(ProtoContentBase):
@ -755,30 +718,6 @@ class ServiceMethodCompiler(ProtoContentBase):
)
return f"/{package_part}{self.parent.proto_name}/{self.proto_name}"
@property
def py_input_message(self) -> Optional[MessageCompiler]:
"""Find the input message object.
Returns
-------
Optional[MessageCompiler]
Method instance representing the input message.
If not input message could be found or there are no
input messages, None is returned.
"""
package, name = parse_source_type_name(self.proto_obj.input_type)
# Nested types are currently flattened without dots.
# Todo: keep a fully quantified name in types, that is
# comparable with method.input_type
for msg in self.request.all_messages:
if (
msg.py_name == pythonize_class_name(name.replace(".", ""))
and msg.output_file.package == package
):
return msg
return None
@property
def py_input_message_type(self) -> str:
"""String representation of the Python type corresponding to the

View File

@ -5,16 +5,9 @@
{% for i in output_file.python_module_imports|sort %}
import {{ i }}
{% endfor %}
{% set type_checking_imported = False %}
{% if output_file.pydantic_dataclasses %}
from typing import TYPE_CHECKING
{% set type_checking_imported = True %}
if TYPE_CHECKING:
from dataclasses import dataclass
else:
from pydantic.dataclasses import dataclass
from pydantic.dataclasses import dataclass
from pydantic.dataclasses import rebuild_dataclass
{%- else -%}
from dataclasses import dataclass
@ -46,7 +39,7 @@ import grpclib
{{ i }}
{% endfor %}
{% if output_file.imports_type_checking_only and not type_checking_imported %}
{% if output_file.imports_type_checking_only %}
from typing import TYPE_CHECKING
if TYPE_CHECKING:

View File

@ -23,7 +23,11 @@ class {{ enum.py_name }}(betterproto.Enum):
{% endfor %}
{% endif %}
{% for message in output_file.messages %}
{% if output_file.pydantic_dataclasses %}
@dataclass(eq=False, repr=False, config={"extra": "forbid"})
{% else %}
@dataclass(eq=False, repr=False)
{% endif %}
class {{ message.py_name }}(betterproto.Message):
{% if message.comment %}
{{ message.comment }}
@ -70,7 +74,7 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
{% for method in service.methods %}
async def {{ method.py_name }}(self
{%- if not method.client_streaming -%}
{%- if method.py_input_message -%}, {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}"{%- endif -%}
, {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}"
{%- else -%}
{# Client streaming: need a request iterator instead #}
, {{ method.py_input_message_param }}_iterator: {{ output_file.typing_compiler.union(output_file.typing_compiler.async_iterable(method.py_input_message_type), output_file.typing_compiler.iterable(method.py_input_message_type)) }}
@ -149,7 +153,7 @@ class {{ service.py_name }}Base(ServiceBase):
{% for method in service.methods %}
async def {{ method.py_name }}(self
{%- if not method.client_streaming -%}
{%- if method.py_input_message -%}, {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}"{%- endif -%}
, {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}"
{%- else -%}
{# Client streaming: need a request iterator instead #}
, {{ method.py_input_message_param }}_iterator: {{ output_file.typing_compiler.async_iterator(method.py_input_message_type) }}

View File

@ -0,0 +1,7 @@
syntax = "proto3";
package invalid_field;
message Test {
int32 x = 1;
}

View File

@ -0,0 +1,17 @@
import pytest
def test_invalid_field():
from tests.output_betterproto.invalid_field import Test
with pytest.raises(TypeError):
Test(unknown_field=12)
def test_invalid_field_pydantic():
from pydantic import ValidationError
from tests.output_betterproto_pydantic.invalid_field import Test
with pytest.raises(ValidationError):
Test(unknown_field=12)

View File

@ -35,18 +35,16 @@ def test_message_with_deprecated_field(message):
def test_message_with_deprecated_field_not_set(message):
with pytest.warns(None) as record:
with warnings.catch_warnings():
warnings.simplefilter("error")
Test(value=10)
assert not record
def test_message_with_deprecated_field_not_set_default(message):
with pytest.warns(None) as record:
with warnings.catch_warnings():
warnings.simplefilter("error")
_ = Test(value=10).message
assert not record
@pytest.mark.asyncio
async def test_service_with_deprecated_method():
@ -58,7 +56,6 @@ async def test_service_with_deprecated_method():
assert len(record) == 1
assert str(record[0].message) == f"TestService.deprecated_func is deprecated"
with pytest.warns(None) as record:
with warnings.catch_warnings():
warnings.simplefilter("error")
await stub.func(Empty())
assert not record