Merge branch 'master_gh'

# Conflicts:
#	src/betterproto/__init__.py
This commit is contained in:
Georg K 2023-11-15 17:33:19 +03:00
commit 1d296f1a88
36 changed files with 2233 additions and 784 deletions

63
.github/ISSUE_TEMPLATE/bug_report.yml vendored Normal file
View File

@ -0,0 +1,63 @@
name: Bug Report
description: Report broken or incorrect behaviour
labels: ["bug", "investigation needed"]
body:
- type: markdown
attributes:
value: >
Thanks for taking the time to fill out a bug report!
If you're not sure it's a bug and you just have a question, the [community Slack channel](https://join.slack.com/t/betterproto/shared_invite/zt-f0n0uolx-iN8gBNrkPxtKHTLpG3o1OQ) is a better place for general questions than a GitHub issue.
- type: input
attributes:
label: Summary
description: A simple summary of your bug report
validations:
required: true
- type: textarea
attributes:
label: Reproduction Steps
description: >
What you did to make it happen.
Ideally there should be a short code snippet in this section to help reproduce the bug.
validations:
required: true
- type: textarea
attributes:
label: Expected Results
description: >
What did you expect to happen?
validations:
required: true
- type: textarea
attributes:
label: Actual Results
description: >
What actually happened?
validations:
required: true
- type: textarea
attributes:
label: System Information
description: >
Paste the result of `protoc --version; python --version; pip show betterproto` below.
validations:
required: true
- type: checkboxes
attributes:
label: Checklist
options:
- label: I have searched the issues for duplicates.
required: true
- label: I have shown the entire traceback, if possible.
required: true
- label: I have verified this issue occurs on the latest prelease of betterproto which can be installed using `pip install -U --pre betterproto`, if possible.
required: true

6
.github/ISSUE_TEMPLATE/config.yml vendored Normal file
View File

@ -0,0 +1,6 @@
name:
description:
contact_links:
- name: For questions about the library
about: Support questions are better answered in our Slack group.
url: https://join.slack.com/t/betterproto/shared_invite/zt-f0n0uolx-iN8gBNrkPxtKHTLpG3o1OQ

View File

@ -0,0 +1,49 @@
name: Feature Request
description: Suggest a feature for this library
labels: ["enhancement"]
body:
- type: input
attributes:
label: Summary
description: >
What problem is your feature trying to solve? What would become easier or possible if feature was implemented?
validations:
required: true
- type: dropdown
attributes:
multiple: false
label: What is the feature request for?
options:
- The core library
- RPC handling
- The documentation
validations:
required: true
- type: textarea
attributes:
label: The Problem
description: >
What problem is your feature trying to solve?
What would become easier or possible if feature was implemented?
validations:
required: true
- type: textarea
attributes:
label: The Ideal Solution
description: >
What is your ideal solution to the problem?
What would you like this feature to do?
validations:
required: true
- type: textarea
attributes:
label: The Current Solution
description: >
What is the current solution to the problem, if any?
validations:
required: false

16
.github/PULL_REQUEST_TEMPLATE.md vendored Normal file
View File

@ -0,0 +1,16 @@
## Summary
<!-- What is this pull request for? Does it fix any issues? -->
## Checklist
<!-- Put an x inside [ ] to check it, like so: [x] -->
- [ ] If code changes were made then they have been tested.
- [ ] I have updated the documentation to reflect the changes.
- [ ] This PR fixes an issue.
- [ ] This PR adds something new (e.g. new method or parameters).
- [ ] This change has an associated test.
- [ ] This PR is a breaking change (e.g. methods or parameters removed/renamed)
- [ ] This PR is **not** a code change (e.g. documentation, README, ...)

46
.github/workflows/codeql-analysis.yml vendored Normal file
View File

@ -0,0 +1,46 @@
name: "CodeQL"
on:
push:
branches: [ "master" ]
pull_request:
branches:
- '**'
schedule:
- cron: '19 1 * * 6'
jobs:
analyze:
name: Analyze
runs-on: ubuntu-latest
permissions:
actions: read
contents: read
security-events: write
strategy:
fail-fast: false
matrix:
language: [ 'python' ]
steps:
- name: Checkout repository
uses: actions/checkout@v3
# Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL
uses: github/codeql-action/init@v2
with:
languages: ${{ matrix.language }}
# If you wish to specify custom queries, you can do so here or in a config file.
# By default, queries listed here will override any specified in a config file.
# Prefix the list here with "+" to use these queries and those in the config file.
# Details on CodeQL's query packs refer to : https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs
# queries: security-extended,security-and-quality
- name: Autobuild
uses: github/codeql-action/autobuild@v2
- name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@v2

View File

@ -16,6 +16,13 @@ repos:
- repo: https://github.com/PyCQA/doc8 - repo: https://github.com/PyCQA/doc8
rev: 0.10.1 rev: 0.10.1
hooks: hooks:
- id: doc8 - id: doc8
additional_dependencies: additional_dependencies:
- toml - toml
- repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks
rev: v2.10.0
hooks:
- id: pretty-format-java
args: [--autofix, --aosp]
files: ^.*\.java$

1200
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -19,26 +19,29 @@ importlib-metadata = { version = ">=1.6.0", python = "<3.8" }
jinja2 = { version = ">=3.0.3", optional = true } jinja2 = { version = ">=3.0.3", optional = true }
python-dateutil = "^2.8" python-dateutil = "^2.8"
isort = {version = "^5.11.5", optional = true} isort = {version = "^5.11.5", optional = true}
typing-extensions = "^4.7.1"
[tool.poetry.dev-dependencies] [tool.poetry.group.dev.dependencies]
asv = "^0.4.2" asv = "^0.4.2"
bpython = "^0.19" bpython = "^0.19"
grpcio-tools = "^1.54.2"
jinja2 = ">=3.0.3" jinja2 = ">=3.0.3"
mypy = "^0.930" mypy = "^0.930"
sphinx = "3.1.2"
sphinx-rtd-theme = "0.5.0"
pre-commit = "^2.17.0"
grpcio-tools = "^1.54.2"
tox = "^4.0.0"
[tool.poetry.group.test.dependencies]
poethepoet = ">=0.9.0" poethepoet = ">=0.9.0"
protobuf = "^4.21.6"
pytest = "^6.2.5" pytest = "^6.2.5"
pytest-asyncio = "^0.12.0" pytest-asyncio = "^0.12.0"
pytest-cov = "^2.9.0" pytest-cov = "^2.9.0"
pytest-mock = "^3.1.1" pytest-mock = "^3.1.1"
sphinx = "3.1.2" pydantic = ">=1.8.0,<2"
sphinx-rtd-theme = "0.5.0" protobuf = "^4"
tomlkit = "^0.7.0" cachelib = "^0.10.2"
tox = "^3.15.1" tomlkit = ">=0.7.0"
pre-commit = "^2.17.0"
pydantic = ">=1.8.0"
[tool.poetry.scripts] [tool.poetry.scripts]
protoc-gen-python_betterproto = "betterproto.plugin:main" protoc-gen-python_betterproto = "betterproto.plugin:main"
@ -61,9 +64,13 @@ help = "Run tests"
cmd = "mypy src --ignore-missing-imports" cmd = "mypy src --ignore-missing-imports"
help = "Check types with mypy" help = "Check types with mypy"
[tool.poe.tasks]
_black = "black . --exclude tests/output_ --target-version py310"
_isort = "isort . --extend-skip-glob 'tests/output_*/**/*'"
[tool.poe.tasks.format] [tool.poe.tasks.format]
cmd = "black . --exclude tests/output_ --target-version py310" sequence = ["_black", "_isort"]
help = "Apply black formatting to source code" help = "Apply black and isort formatting to source code"
[tool.poe.tasks.docs] [tool.poe.tasks.docs]
cmd = "sphinx-build docs docs/build" cmd = "sphinx-build docs docs/build"
@ -130,14 +137,21 @@ omit = ["betterproto/tests/*"]
[tool.tox] [tool.tox]
legacy_tox_ini = """ legacy_tox_ini = """
[tox] [tox]
isolated_build = true requires =
envlist = py37, py38, py310 tox>=4.2
tox-poetry-installer[poetry]==1.0.0b1
env_list =
py311
py38
py37
[testenv] [testenv]
whitelist_externals = poetry
commands = commands =
poetry install -v --extras compiler pytest {posargs: --cov betterproto}
poetry run pytest --cov betterproto poetry_dep_groups =
test
require_locked_deps = true
require_poetry = true
""" """
[build-system] [build-system]

View File

@ -1,5 +1,7 @@
from __future__ import annotations
import dataclasses import dataclasses
import enum import enum as builtin_enum
import json import json
import math import math
import struct import struct
@ -22,8 +24,8 @@ from itertools import count
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
BinaryIO,
Callable, Callable,
ClassVar,
Dict, Dict,
Generator, Generator,
Iterable, Iterable,
@ -37,6 +39,7 @@ from typing import (
) )
from dateutil.parser import isoparse from dateutil.parser import isoparse
from typing_extensions import Self
from ._types import T from ._types import T
from ._version import __version__ from ._version import __version__
@ -45,11 +48,19 @@ from .casing import (
safe_snake_case, safe_snake_case,
snake_case, snake_case,
) )
from .grpc.grpclib_client import ServiceStub from .enum import Enum as Enum
from .grpc.grpclib_client import ServiceStub as ServiceStub
from .utils import (
classproperty,
hybridmethod,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from _typeshed import ReadableBuffer from _typeshed import (
SupportsRead,
SupportsWrite,
)
# Proto 3 data types # Proto 3 data types
@ -126,6 +137,9 @@ WIRE_FIXED_32_TYPES = [TYPE_FLOAT, TYPE_FIXED32, TYPE_SFIXED32]
WIRE_FIXED_64_TYPES = [TYPE_DOUBLE, TYPE_FIXED64, TYPE_SFIXED64] WIRE_FIXED_64_TYPES = [TYPE_DOUBLE, TYPE_FIXED64, TYPE_SFIXED64]
WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE, TYPE_MAP] WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE, TYPE_MAP]
# Indicator of message delimitation in streams
SIZE_DELIMITED = -1
# Protobuf datetimes start at the Unix Epoch in 1970 in UTC. # Protobuf datetimes start at the Unix Epoch in 1970 in UTC.
def datetime_default_gen() -> datetime: def datetime_default_gen() -> datetime:
@ -140,7 +154,7 @@ NEG_INFINITY = "-Infinity"
NAN = "NaN" NAN = "NaN"
class Casing(enum.Enum): class Casing(builtin_enum.Enum):
"""Casing constants for serialization.""" """Casing constants for serialization."""
CAMEL = camel_case #: A camelCase sterilization function. CAMEL = camel_case #: A camelCase sterilization function.
@ -309,32 +323,6 @@ def map_field(
) )
class Enum(enum.IntEnum):
"""
The base class for protobuf enumerations, all generated enumerations will inherit
from this. Bases :class:`enum.IntEnum`.
"""
@classmethod
def from_string(cls, name: str) -> "Enum":
"""Return the value which corresponds to the string name.
Parameters
-----------
name: :class:`str`
The name of the enum member to get
Raises
-------
:exc:`ValueError`
The member was not found in the Enum.
"""
try:
return cls._member_map_[name] # type: ignore
except KeyError as e:
raise ValueError(f"Unknown value {name} for enum {cls.__name__}") from e
def _pack_fmt(proto_type: str) -> str: def _pack_fmt(proto_type: str) -> str:
"""Returns a little-endian format string for reading/writing binary.""" """Returns a little-endian format string for reading/writing binary."""
return { return {
@ -347,7 +335,7 @@ def _pack_fmt(proto_type: str) -> str:
}[proto_type] }[proto_type]
def dump_varint(value: int, stream: BinaryIO) -> None: def dump_varint(value: int, stream: "SupportsWrite[bytes]") -> None:
"""Encodes a single varint and dumps it into the provided stream.""" """Encodes a single varint and dumps it into the provided stream."""
if value < -(1 << 63): if value < -(1 << 63):
raise ValueError( raise ValueError(
@ -556,7 +544,7 @@ def _dump_float(value: float) -> Union[float, str]:
return value return value
def load_varint(stream: BinaryIO) -> Tuple[int, bytes]: def load_varint(stream: "SupportsRead[bytes]") -> Tuple[int, bytes]:
""" """
Load a single varint value from a stream. Returns the value and the raw bytes read. Load a single varint value from a stream. Returns the value and the raw bytes read.
""" """
@ -594,7 +582,7 @@ class ParsedField:
raw: bytes raw: bytes
def load_fields(stream: BinaryIO) -> Generator[ParsedField, None, None]: def load_fields(stream: "SupportsRead[bytes]") -> Generator[ParsedField, None, None]:
while True: while True:
try: try:
num_wire, raw = load_varint(stream) num_wire, raw = load_varint(stream)
@ -748,6 +736,7 @@ class Message(ABC):
_serialized_on_wire: bool _serialized_on_wire: bool
_unknown_fields: bytes _unknown_fields: bytes
_group_current: Dict[str, str] _group_current: Dict[str, str]
_betterproto_meta: ClassVar[ProtoClassMetadata]
def __post_init__(self) -> None: def __post_init__(self) -> None:
# Keep track of whether every field was default # Keep track of whether every field was default
@ -815,6 +804,10 @@ class Message(ABC):
] ]
return f"{self.__class__.__name__}({', '.join(parts)})" return f"{self.__class__.__name__}({', '.join(parts)})"
def __rich_repr__(self) -> Iterable[Tuple[str, Any, Any]]:
for field_name in self._betterproto.sorted_field_names:
yield field_name, self.__raw_get(field_name), PLACEHOLDER
if not TYPE_CHECKING: if not TYPE_CHECKING:
def __getattribute__(self, name: str) -> Any: def __getattribute__(self, name: str) -> Any:
@ -889,20 +882,28 @@ class Message(ABC):
kwargs[name] = deepcopy(value) kwargs[name] = deepcopy(value)
return self.__class__(**kwargs) # type: ignore return self.__class__(**kwargs) # type: ignore
@property def __copy__(self: T, _: Any = {}) -> T:
def _betterproto(self) -> ProtoClassMetadata: kwargs = {}
for name in self._betterproto.sorted_field_names:
value = self.__raw_get(name)
if value is not PLACEHOLDER:
kwargs[name] = value
return self.__class__(**kwargs) # type: ignore
@classproperty
def _betterproto(cls: type[Self]) -> ProtoClassMetadata: # type: ignore
""" """
Lazy initialize metadata for each protobuf class. Lazy initialize metadata for each protobuf class.
It may be initialized multiple times in a multi-threaded environment, It may be initialized multiple times in a multi-threaded environment,
but that won't affect the correctness. but that won't affect the correctness.
""" """
meta = getattr(self.__class__, "_betterproto_meta", None) try:
if not meta: return cls._betterproto_meta
meta = ProtoClassMetadata(self.__class__) except AttributeError:
self.__class__._betterproto_meta = meta # type: ignore cls._betterproto_meta = meta = ProtoClassMetadata(cls)
return meta return meta
def dump(self, stream: BinaryIO) -> None: def dump(self, stream: "SupportsWrite[bytes]", delimit: bool = False) -> None:
""" """
Dumps the binary encoded Protobuf message to the stream. Dumps the binary encoded Protobuf message to the stream.
@ -910,7 +911,11 @@ class Message(ABC):
----------- -----------
stream: :class:`BinaryIO` stream: :class:`BinaryIO`
The stream to dump the message to. The stream to dump the message to.
delimit:
Whether to prefix the message with a varint declaring its size.
""" """
if delimit == SIZE_DELIMITED:
dump_varint(len(self), stream)
for field_name, meta in self._betterproto.meta_by_field_name.items(): for field_name, meta in self._betterproto.meta_by_field_name.items():
try: try:
@ -930,7 +935,7 @@ class Message(ABC):
# Note that proto3 field presence/optional fields are put in a # Note that proto3 field presence/optional fields are put in a
# synthetic single-item oneof by protoc, which helps us ensure we # synthetic single-item oneof by protoc, which helps us ensure we
# send the value even if the value is the default zero value. # send the value even if the value is the default zero value.
selected_in_group = bool(meta.group) selected_in_group = bool(meta.group) or meta.optional
# Empty messages can still be sent on the wire if they were # Empty messages can still be sent on the wire if they were
# set (or received empty). # set (or received empty).
@ -1124,6 +1129,15 @@ class Message(ABC):
""" """
return bytes(self) return bytes(self)
def __getstate__(self) -> bytes:
return bytes(self)
def __setstate__(self: T, pickled_bytes: bytes) -> T:
return self.parse(pickled_bytes)
def __reduce__(self) -> Tuple[Any, ...]:
return (self.__class__.FromString, (bytes(self),))
@classmethod @classmethod
def _type_hint(cls, field_name: str) -> Type: def _type_hint(cls, field_name: str) -> Type:
return cls._type_hints()[field_name] return cls._type_hints()[field_name]
@ -1168,7 +1182,7 @@ class Message(ABC):
return t return t
elif issubclass(t, Enum): elif issubclass(t, Enum):
# Enums always default to zero. # Enums always default to zero.
return int return t.try_value
elif t is datetime: elif t is datetime:
# Offsets are relative to 1970-01-01T00:00:00Z # Offsets are relative to 1970-01-01T00:00:00Z
return datetime_default_gen return datetime_default_gen
@ -1193,6 +1207,9 @@ class Message(ABC):
elif meta.proto_type == TYPE_BOOL: elif meta.proto_type == TYPE_BOOL:
# Booleans use a varint encoding, so convert it to true/false. # Booleans use a varint encoding, so convert it to true/false.
value = value > 0 value = value > 0
elif meta.proto_type == TYPE_ENUM:
# Convert enum ints to python enum instances
value = self._betterproto.cls_by_field[field_name].try_value(value)
elif wire_type in (WIRE_FIXED_32, WIRE_FIXED_64): elif wire_type in (WIRE_FIXED_32, WIRE_FIXED_64):
fmt = _pack_fmt(meta.proto_type) fmt = _pack_fmt(meta.proto_type)
value = struct.unpack(fmt, value)[0] value = struct.unpack(fmt, value)[0]
@ -1225,7 +1242,11 @@ class Message(ABC):
meta.group is not None and self._group_current.get(meta.group) == field_name meta.group is not None and self._group_current.get(meta.group) == field_name
) )
def load(self: T, stream: BinaryIO, size: Optional[int] = None) -> T: def load(
self: T,
stream: "SupportsRead[bytes]",
size: Optional[int] = None,
) -> T:
""" """
Load the binary encoded Protobuf from a stream into this message instance. This Load the binary encoded Protobuf from a stream into this message instance. This
returns the instance itself and is therefore assignable and chainable. returns the instance itself and is therefore assignable and chainable.
@ -1237,12 +1258,17 @@ class Message(ABC):
size: :class:`Optional[int]` size: :class:`Optional[int]`
The size of the message in the stream. The size of the message in the stream.
Reads stream until EOF if ``None`` is given. Reads stream until EOF if ``None`` is given.
Reads based on a size delimiter prefix varint if SIZE_DELIMITED is given.
Returns Returns
-------- --------
:class:`Message` :class:`Message`
The initialized message. The initialized message.
""" """
# If the message is delimited, parse the message delimiter
if size == SIZE_DELIMITED:
size, _ = load_varint(stream)
# Got some data over the wire # Got some data over the wire
self._serialized_on_wire = True self._serialized_on_wire = True
proto_meta = self._betterproto proto_meta = self._betterproto
@ -1315,7 +1341,7 @@ class Message(ABC):
return self return self
def parse(self: T, data: "ReadableBuffer") -> T: def parse(self: T, data: bytes) -> T:
""" """
Parse the binary encoded Protobuf into this message instance. This Parse the binary encoded Protobuf into this message instance. This
returns the instance itself and is therefore assignable and chainable. returns the instance itself and is therefore assignable and chainable.
@ -1494,7 +1520,91 @@ class Message(ABC):
output[cased_name] = value output[cased_name] = value
return output return output
def from_dict(self: T, value: Mapping[str, Any]) -> T: @classmethod
def _from_dict_init(cls, mapping: Mapping[str, Any]) -> Mapping[str, Any]:
init_kwargs: Dict[str, Any] = {}
for key, value in mapping.items():
field_name = safe_snake_case(key)
try:
meta = cls._betterproto.meta_by_field_name[field_name]
except KeyError:
continue
if value is None:
continue
if meta.proto_type == TYPE_MESSAGE:
sub_cls = cls._betterproto.cls_by_field[field_name]
if sub_cls == datetime:
value = (
[isoparse(item) for item in value]
if isinstance(value, list)
else isoparse(value)
)
elif sub_cls == timedelta:
value = (
[timedelta(seconds=float(item[:-1])) for item in value]
if isinstance(value, list)
else timedelta(seconds=float(value[:-1]))
)
elif not meta.wraps:
value = (
[sub_cls.from_dict(item) for item in value]
if isinstance(value, list)
else sub_cls.from_dict(value)
)
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
sub_cls = cls._betterproto.cls_by_field[f"{field_name}.value"]
value = {k: sub_cls.from_dict(v) for k, v in value.items()}
else:
if meta.proto_type in INT_64_TYPES:
value = (
[int(n) for n in value]
if isinstance(value, list)
else int(value)
)
elif meta.proto_type == TYPE_BYTES:
value = (
[b64decode(n) for n in value]
if isinstance(value, list)
else b64decode(value)
)
elif meta.proto_type == TYPE_ENUM:
enum_cls = cls._betterproto.cls_by_field[field_name]
if isinstance(value, list):
value = [enum_cls.from_string(e) for e in value]
elif isinstance(value, str):
value = enum_cls.from_string(value)
elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE):
value = (
[_parse_float(n) for n in value]
if isinstance(value, list)
else _parse_float(value)
)
init_kwargs[field_name] = value
return init_kwargs
@hybridmethod
def from_dict(cls: type[Self], value: Mapping[str, Any]) -> Self: # type: ignore
"""
Parse the key/value pairs into the a new message instance.
Parameters
-----------
value: Dict[:class:`str`, Any]
The dictionary to parse from.
Returns
--------
:class:`Message`
The initialized message.
"""
self = cls(**cls._from_dict_init(value))
self._serialized_on_wire = True
return self
@from_dict.instancemethod
def from_dict(self, value: Mapping[str, Any]) -> Self:
""" """
Parse the key/value pairs into the current message instance. This returns the Parse the key/value pairs into the current message instance. This returns the
instance itself and is therefore assignable and chainable. instance itself and is therefore assignable and chainable.
@ -1510,71 +1620,8 @@ class Message(ABC):
The initialized message. The initialized message.
""" """
self._serialized_on_wire = True self._serialized_on_wire = True
for key in value: for field, value in self._from_dict_init(value).items():
field_name = safe_snake_case(key) setattr(self, field, value)
meta = self._betterproto.meta_by_field_name.get(field_name)
if not meta:
continue
if value[key] is not None:
if meta.proto_type == TYPE_MESSAGE:
v = self._get_field_default(field_name)
cls = self._betterproto.cls_by_field[field_name]
if isinstance(v, list):
if cls == datetime:
v = [isoparse(item) for item in value[key]]
elif cls == timedelta:
v = [
timedelta(seconds=float(item[:-1]))
for item in value[key]
]
else:
v = [cls().from_dict(item) for item in value[key]]
elif cls == datetime:
v = isoparse(value[key])
setattr(self, field_name, v)
elif cls == timedelta:
v = timedelta(seconds=float(value[key][:-1]))
setattr(self, field_name, v)
elif meta.wraps:
setattr(self, field_name, value[key])
elif v is None:
setattr(self, field_name, cls().from_dict(value[key]))
else:
# NOTE: `from_dict` mutates the underlying message, so no
# assignment here is necessary.
v.from_dict(value[key])
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
v = getattr(self, field_name)
cls = self._betterproto.cls_by_field[f"{field_name}.value"]
for k in value[key]:
v[k] = cls().from_dict(value[key][k])
else:
v = value[key]
if meta.proto_type in INT_64_TYPES:
if isinstance(value[key], list):
v = [int(n) for n in value[key]]
else:
v = int(value[key])
elif meta.proto_type == TYPE_BYTES:
if isinstance(value[key], list):
v = [b64decode(n) for n in value[key]]
else:
v = b64decode(value[key])
elif meta.proto_type == TYPE_ENUM:
enum_cls = self._betterproto.cls_by_field[field_name]
if isinstance(v, list):
v = [enum_cls.from_string(e) for e in v]
elif isinstance(v, str):
v = enum_cls.from_string(v)
elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE):
if isinstance(value[key], list):
v = [_parse_float(n) for n in value[key]]
else:
v = _parse_float(value[key])
if v is not None:
setattr(self, field_name, v)
return self return self
def to_json( def to_json(
@ -1791,8 +1838,8 @@ class Message(ABC):
@classmethod @classmethod
def _validate_field_groups(cls, values): def _validate_field_groups(cls, values):
group_to_one_ofs = cls._betterproto_meta.oneof_field_by_group # type: ignore group_to_one_ofs = cls._betterproto.oneof_field_by_group
field_name_to_meta = cls._betterproto_meta.meta_by_field_name # type: ignore field_name_to_meta = cls._betterproto.meta_by_field_name
for group, field_set in group_to_one_ofs.items(): for group, field_set in group_to_one_ofs.items():
if len(field_set) == 1: if len(field_set) == 1:
@ -1819,6 +1866,9 @@ class Message(ABC):
return values return values
Message.__annotations__ = {} # HACK to avoid typing.get_type_hints breaking :)
def serialized_on_wire(message: Message) -> bool: def serialized_on_wire(message: Message) -> bool:
""" """
If this message was or should be serialized on the wire. This can be used to detect If this message was or should be serialized on the wire. This can be used to detect
@ -1890,17 +1940,24 @@ class _Duration(Duration):
class _Timestamp(Timestamp): class _Timestamp(Timestamp):
@classmethod @classmethod
def from_datetime(cls, dt: datetime) -> "_Timestamp": def from_datetime(cls, dt: datetime) -> "_Timestamp":
seconds = int(dt.timestamp()) # manual epoch offset calulation to avoid rounding errors,
nanos = int(dt.microsecond * 1e3) # to support negative timestamps (before 1970) and skirt
return cls(seconds, nanos) # around datetime bugs (apparently 0 isn't a year in [0, 9999]??)
offset = dt - DATETIME_ZERO
# below is the same as timedelta.total_seconds() but without dividing by 1e6
# so we end up with microseconds as integers instead of seconds as float
offset_us = (
offset.days * 24 * 60 * 60 + offset.seconds
) * 10**6 + offset.microseconds
seconds, us = divmod(offset_us, 10**6)
return cls(seconds, us * 1000)
def to_datetime(self) -> datetime: def to_datetime(self) -> datetime:
ts = self.seconds + (self.nanos / 1e9) # datetime.fromtimestamp() expects a timestamp in seconds, not microseconds
# if we pass it as a floating point number, we will run into rounding errors
if ts < 0: # see also #407
return datetime(1970, 1, 1) + timedelta(seconds=ts) offset = timedelta(seconds=self.seconds, microseconds=self.nanos // 1000)
else: return DATETIME_ZERO + offset
return datetime.fromtimestamp(ts, tz=timezone.utc)
@staticmethod @staticmethod
def timestamp_to_json(dt: datetime) -> str: def timestamp_to_json(dt: datetime) -> str:

View File

@ -136,4 +136,8 @@ def lowercase_first(value: str) -> str:
def sanitize_name(value: str) -> str: def sanitize_name(value: str) -> str:
# https://www.python.org/dev/peps/pep-0008/#descriptive-naming-styles # https://www.python.org/dev/peps/pep-0008/#descriptive-naming-styles
return f"{value}_" if keyword.iskeyword(value) else value if keyword.iskeyword(value):
return f"{value}_"
if not value.isidentifier():
return f"_{value}"
return value

View File

@ -11,3 +11,11 @@ def pythonize_field_name(name: str) -> str:
def pythonize_method_name(name: str) -> str: def pythonize_method_name(name: str) -> str:
return casing.safe_snake_case(name) return casing.safe_snake_case(name)
def pythonize_enum_member_name(name: str, enum_name: str) -> str:
enum_name = casing.snake_case(enum_name).upper()
find = name.find(enum_name)
if find != -1:
name = name[find + len(enum_name) :].strip("_")
return casing.sanitize_name(name)

196
src/betterproto/enum.py Normal file
View File

@ -0,0 +1,196 @@
from __future__ import annotations
import sys
from enum import (
EnumMeta,
IntEnum,
)
from types import MappingProxyType
from typing import (
TYPE_CHECKING,
Any,
Dict,
Optional,
Tuple,
)
if TYPE_CHECKING:
from collections.abc import (
Generator,
Mapping,
)
from typing_extensions import (
Never,
Self,
)
def _is_descriptor(obj: object) -> bool:
return (
hasattr(obj, "__get__") or hasattr(obj, "__set__") or hasattr(obj, "__delete__")
)
class EnumType(EnumMeta if TYPE_CHECKING else type):
_value_map_: Mapping[int, Enum]
_member_map_: Mapping[str, Enum]
def __new__(
mcs, name: str, bases: Tuple[type, ...], namespace: Dict[str, Any]
) -> Self:
value_map = {}
member_map = {}
new_mcs = type(
f"{name}Type",
tuple(
dict.fromkeys(
[base.__class__ for base in bases if base.__class__ is not type]
+ [EnumType, type]
)
), # reorder the bases so EnumType and type are last to avoid conflicts
{"_value_map_": value_map, "_member_map_": member_map},
)
members = {
name: value
for name, value in namespace.items()
if not _is_descriptor(value) and not name.startswith("__")
}
cls = type.__new__(
new_mcs,
name,
bases,
{key: value for key, value in namespace.items() if key not in members},
)
# this allows us to disallow member access from other members as
# members become proper class variables
for name, value in members.items():
member = value_map.get(value)
if member is None:
member = cls.__new__(cls, name=name, value=value) # type: ignore
value_map[value] = member
member_map[name] = member
type.__setattr__(new_mcs, name, member)
return cls
if not TYPE_CHECKING:
def __call__(cls, value: int) -> Enum:
try:
return cls._value_map_[value]
except (KeyError, TypeError):
raise ValueError(f"{value!r} is not a valid {cls.__name__}") from None
def __iter__(cls) -> Generator[Enum, None, None]:
yield from cls._member_map_.values()
if sys.version_info >= (3, 8): # 3.8 added __reversed__ to dict_values
def __reversed__(cls) -> Generator[Enum, None, None]:
yield from reversed(cls._member_map_.values())
else:
def __reversed__(cls) -> Generator[Enum, None, None]:
yield from reversed(tuple(cls._member_map_.values()))
def __getitem__(cls, key: str) -> Enum:
return cls._member_map_[key]
@property
def __members__(cls) -> MappingProxyType[str, Enum]:
return MappingProxyType(cls._member_map_)
def __repr__(cls) -> str:
return f"<enum {cls.__name__!r}>"
def __len__(cls) -> int:
return len(cls._member_map_)
def __setattr__(cls, name: str, value: Any) -> Never:
raise AttributeError(f"{cls.__name__}: cannot reassign Enum members.")
def __delattr__(cls, name: str) -> Never:
raise AttributeError(f"{cls.__name__}: cannot delete Enum members.")
def __contains__(cls, member: object) -> bool:
return isinstance(member, cls) and member.name in cls._member_map_
class Enum(IntEnum if TYPE_CHECKING else int, metaclass=EnumType):
"""
The base class for protobuf enumerations, all generated enumerations will
inherit from this. Emulates `enum.IntEnum`.
"""
name: Optional[str]
value: int
if not TYPE_CHECKING:
def __new__(cls, *, name: Optional[str], value: int) -> Self:
self = super().__new__(cls, value)
super().__setattr__(self, "name", name)
super().__setattr__(self, "value", value)
return self
def __str__(self) -> str:
return self.name or "None"
def __repr__(self) -> str:
return f"{self.__class__.__name__}.{self.name}"
def __setattr__(self, key: str, value: Any) -> Never:
raise AttributeError(
f"{self.__class__.__name__} Cannot reassign a member's attributes."
)
def __delattr__(self, item: Any) -> Never:
raise AttributeError(
f"{self.__class__.__name__} Cannot delete a member's attributes."
)
@classmethod
def try_value(cls, value: int = 0) -> Self:
"""Return the value which corresponds to the value.
Parameters
-----------
value: :class:`int`
The value of the enum member to get.
Returns
-------
:class:`Enum`
The corresponding member or a new instance of the enum if
``value`` isn't actually a member.
"""
try:
return cls._value_map_[value]
except (KeyError, TypeError):
return cls.__new__(cls, name=None, value=value)
@classmethod
def from_string(cls, name: str) -> Self:
"""Return the value which corresponds to the string name.
Parameters
-----------
name: :class:`str`
The name of the enum member to get.
Raises
-------
:exc:`ValueError`
The member was not found in the Enum.
"""
try:
return cls._member_map_[name]
except KeyError as e:
raise ValueError(f"Unknown value {name} for enum {cls.__name__}") from e

View File

@ -127,6 +127,7 @@ class ServiceStub(ABC):
response_type, response_type,
**self.__resolve_request_kwargs(timeout, deadline, metadata), **self.__resolve_request_kwargs(timeout, deadline, metadata),
) as stream: ) as stream:
await stream.send_request()
await self._send_messages(stream, request_iterator) await self._send_messages(stream, request_iterator)
response = await stream.recv_message() response = await stream.recv_message()
assert response is not None assert response is not None

View File

@ -72,13 +72,13 @@ from betterproto.lib.google.protobuf import (
) )
from betterproto.lib.google.protobuf.compiler import CodeGeneratorRequest from betterproto.lib.google.protobuf.compiler import CodeGeneratorRequest
from ..casing import sanitize_name
from ..compile.importing import ( from ..compile.importing import (
get_type_reference, get_type_reference,
parse_source_type_name, parse_source_type_name,
) )
from ..compile.naming import ( from ..compile.naming import (
pythonize_class_name, pythonize_class_name,
pythonize_enum_member_name,
pythonize_field_name, pythonize_field_name,
pythonize_method_name, pythonize_method_name,
) )
@ -385,7 +385,10 @@ def is_oneof(proto_field_obj: FieldDescriptorProto) -> bool:
us to tell whether it was set, via the which_one_of interface. us to tell whether it was set, via the which_one_of interface.
""" """
return which_one_of(proto_field_obj, "oneof_index")[0] == "oneof_index" return (
not proto_field_obj.proto3_optional
and which_one_of(proto_field_obj, "oneof_index")[0] == "oneof_index"
)
@dataclass @dataclass
@ -670,7 +673,9 @@ class EnumDefinitionCompiler(MessageCompiler):
# Get entries/allowed values for this Enum # Get entries/allowed values for this Enum
self.entries = [ self.entries = [
self.EnumEntry( self.EnumEntry(
name=sanitize_name(entry_proto_value.name), name=pythonize_enum_member_name(
entry_proto_value.name, self.proto_obj.name
),
value=entry_proto_value.number, value=entry_proto_value.number,
comment=get_comment( comment=get_comment(
proto_file=self.source_file, path=self.path + [2, entry_number] proto_file=self.source_file, path=self.path + [2, entry_number]

56
src/betterproto/utils.py Normal file
View File

@ -0,0 +1,56 @@
from __future__ import annotations
from typing import (
Any,
Callable,
Generic,
Optional,
Type,
TypeVar,
)
from typing_extensions import (
Concatenate,
ParamSpec,
Self,
)
SelfT = TypeVar("SelfT")
P = ParamSpec("P")
HybridT = TypeVar("HybridT", covariant=True)
class hybridmethod(Generic[SelfT, P, HybridT]):
def __init__(
self,
func: Callable[
Concatenate[type[SelfT], P], HybridT
], # Must be the classmethod version
):
self.cls_func = func
self.__doc__ = func.__doc__
def instancemethod(self, func: Callable[Concatenate[SelfT, P], HybridT]) -> Self:
self.instance_func = func
return self
def __get__(
self, instance: Optional[SelfT], owner: Type[SelfT]
) -> Callable[P, HybridT]:
if instance is None or self.instance_func is None:
# either bound to the class, or no instance method available
return self.cls_func.__get__(owner, None)
return self.instance_func.__get__(instance, owner)
T_co = TypeVar("T_co")
TT_co = TypeVar("TT_co", bound="type[Any]")
class classproperty(Generic[TT_co, T_co]):
def __init__(self, func: Callable[[TT_co], T_co]):
self.__func__ = func
def __get__(self, instance: Any, type: TT_co) -> T_co:
return self.__func__(type)

View File

@ -272,3 +272,27 @@ async def test_async_gen_for_stream_stream_request():
assert response_index == len( assert response_index == len(
expected_things expected_things
), "Didn't receive all expected responses" ), "Didn't receive all expected responses"
@pytest.mark.asyncio
async def test_stream_unary_with_empty_iterable():
things = [] # empty
async with ChannelFor([ThingService()]) as channel:
client = ThingServiceClient(channel)
requests = [DoThingRequest(name) for name in things]
response = await client.do_many_things(requests)
assert len(response.names) == 0
@pytest.mark.asyncio
async def test_stream_stream_with_empty_iterable():
things = [] # empty
async with ChannelFor([ThingService()]) as channel:
client = ThingServiceClient(channel)
requests = [GetThingRequest(name) for name in things]
responses = [
response async for response in client.get_different_things(requests)
]
assert len(responses) == 0

View File

@ -27,7 +27,7 @@ class ThingService:
async def do_many_things( async def do_many_things(
self, stream: "grpclib.server.Stream[DoThingRequest, DoThingResponse]" self, stream: "grpclib.server.Stream[DoThingRequest, DoThingResponse]"
): ):
thing_names = [request.name for request in stream] thing_names = [request.name async for request in stream]
if self.test_hook is not None: if self.test_hook is not None:
self.test_hook(stream) self.test_hook(stream)
await stream.send_message(DoThingResponse(thing_names)) await stream.send_message(DoThingResponse(thing_names))

View File

@ -15,3 +15,11 @@ enum Choice {
FOUR = 4; FOUR = 4;
THREE = 3; THREE = 3;
} }
// A "C" like enum with the enum name prefixed onto members, these should be stripped
enum ArithmeticOperator {
ARITHMETIC_OPERATOR_NONE = 0;
ARITHMETIC_OPERATOR_PLUS = 1;
ARITHMETIC_OPERATOR_MINUS = 2;
ARITHMETIC_OPERATOR_0_PREFIXED = 3;
}

View File

@ -1,4 +1,5 @@
from tests.output_betterproto.enum import ( from tests.output_betterproto.enum import (
ArithmeticOperator,
Choice, Choice,
Test, Test,
) )
@ -82,3 +83,32 @@ def test_repeated_enum_with_non_list_iterables_to_dict():
yield Choice.THREE yield Choice.THREE
assert Test(choices=enum_generator()).to_dict()["choices"] == ["ONE", "THREE"] assert Test(choices=enum_generator()).to_dict()["choices"] == ["ONE", "THREE"]
def test_enum_mapped_on_parse():
# test default value
b = Test().parse(bytes(Test()))
assert b.choice.name == Choice.ZERO.name
assert b.choices == []
# test non default value
a = Test().parse(bytes(Test(choice=Choice.ONE)))
assert a.choice.name == Choice.ONE.name
assert b.choices == []
# test repeated
c = Test().parse(bytes(Test(choices=[Choice.THREE, Choice.FOUR])))
assert c.choices[0].name == Choice.THREE.name
assert c.choices[1].name == Choice.FOUR.name
# bonus: defaults after empty init are also mapped
assert Test().choice.name == Choice.ZERO.name
def test_renamed_enum_members():
assert set(ArithmeticOperator.__members__) == {
"NONE",
"PLUS",
"MINUS",
"_0_PREFIXED",
}

View File

@ -1,5 +1,6 @@
syntax = "proto3"; syntax = "proto3";
import "google/protobuf/timestamp.proto";
package google_impl_behavior_equivalence; package google_impl_behavior_equivalence;
message Foo { int64 bar = 1; } message Foo { int64 bar = 1; }
@ -12,6 +13,10 @@ message Test {
} }
} }
message Spam {
google.protobuf.Timestamp ts = 1;
}
message Request { Empty foo = 1; } message Request { Empty foo = 1; }
message Empty {} message Empty {}

View File

@ -1,17 +1,25 @@
from datetime import (
datetime,
timezone,
)
import pytest import pytest
from google.protobuf import json_format from google.protobuf import json_format
from google.protobuf.timestamp_pb2 import Timestamp
import betterproto import betterproto
from tests.output_betterproto.google_impl_behavior_equivalence import ( from tests.output_betterproto.google_impl_behavior_equivalence import (
Empty, Empty,
Foo, Foo,
Request, Request,
Spam,
Test, Test,
) )
from tests.output_reference.google_impl_behavior_equivalence.google_impl_behavior_equivalence_pb2 import ( from tests.output_reference.google_impl_behavior_equivalence.google_impl_behavior_equivalence_pb2 import (
Empty as ReferenceEmpty, Empty as ReferenceEmpty,
Foo as ReferenceFoo, Foo as ReferenceFoo,
Request as ReferenceRequest, Request as ReferenceRequest,
Spam as ReferenceSpam,
Test as ReferenceTest, Test as ReferenceTest,
) )
@ -59,6 +67,19 @@ def test_bytes_are_the_same_for_oneof():
assert isinstance(message_reference2.foo, ReferenceFoo) assert isinstance(message_reference2.foo, ReferenceFoo)
@pytest.mark.parametrize("dt", (datetime.min.replace(tzinfo=timezone.utc),))
def test_datetime_clamping(dt): # see #407
ts = Timestamp()
ts.FromDatetime(dt)
assert bytes(Spam(dt)) == ReferenceSpam(ts=ts).SerializeToString()
message_bytes = bytes(Spam(dt))
assert (
Spam().parse(message_bytes).ts.timestamp()
== ReferenceSpam.FromString(message_bytes).ts.seconds
)
def test_empty_message_field(): def test_empty_message_field():
message = Request() message = Request()
reference_message = ReferenceRequest() reference_message = ReferenceRequest()

View File

@ -2,6 +2,10 @@ syntax = "proto3";
package oneof; package oneof;
message MixedDrink {
int32 shots = 1;
}
message Test { message Test {
oneof foo { oneof foo {
int32 pitied = 1; int32 pitied = 1;
@ -13,6 +17,7 @@ message Test {
oneof bar { oneof bar {
int32 drinks = 11; int32 drinks = 11;
string bar_name = 12; string bar_name = 12;
MixedDrink mixed_drink = 13;
} }
} }

View File

@ -1,5 +1,10 @@
import pytest
import betterproto import betterproto
from tests.output_betterproto.oneof import Test from tests.output_betterproto.oneof import (
MixedDrink,
Test,
)
from tests.output_betterproto_pydantic.oneof import Test as TestPyd from tests.output_betterproto_pydantic.oneof import Test as TestPyd
from tests.util import get_test_case_json_data from tests.util import get_test_case_json_data
@ -19,3 +24,20 @@ def test_which_name():
def test_which_count_pyd(): def test_which_count_pyd():
message = TestPyd(pitier="Mr. T", just_a_regular_field=2, bar_name="a_bar") message = TestPyd(pitier="Mr. T", just_a_regular_field=2, bar_name="a_bar")
assert betterproto.which_one_of(message, "foo") == ("pitier", "Mr. T") assert betterproto.which_one_of(message, "foo") == ("pitier", "Mr. T")
def test_oneof_constructor_assign():
message = Test(mixed_drink=MixedDrink(shots=42))
field, value = betterproto.which_one_of(message, "bar")
assert field == "mixed_drink"
assert value.shots == 42
# Issue #305:
@pytest.mark.xfail
def test_oneof_nested_assign():
message = Test()
message.mixed_drink.shots = 42
field, value = betterproto.which_one_of(message, "bar")
assert field == "mixed_drink"
assert value.shots == 42

View File

@ -41,3 +41,8 @@ def test_null_fields_json():
"test8": None, "test8": None,
"test9": None, "test9": None,
} }
def test_unset_access(): # see #523
assert Test().test1 is None
assert Test(test1=None).test1 is None

View File

@ -0,0 +1,2 @@
•šï:bTesting•šï:bTesting
 

38
tests/streams/java/.gitignore vendored Normal file
View File

@ -0,0 +1,38 @@
### Output ###
target/
!.mvn/wrapper/maven-wrapper.jar
!**/src/main/**/target/
!**/src/test/**/target/
dependency-reduced-pom.xml
MANIFEST.MF
### IntelliJ IDEA ###
.idea/
*.iws
*.iml
*.ipr
### Eclipse ###
.apt_generated
.classpath
.factorypath
.project
.settings
.springBeans
.sts4-cache
### NetBeans ###
/nbproject/private/
/nbbuild/
/dist/
/nbdist/
/.nb-gradle/
build/
!**/src/main/**/build/
!**/src/test/**/build/
### VS Code ###
.vscode/
### Mac OS ###
.DS_Store

View File

@ -0,0 +1,94 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>betterproto</groupId>
<artifactId>compatibility-test</artifactId>
<version>1.0-SNAPSHOT</version>
<packaging>jar</packaging>
<properties>
<maven.compiler.source>11</maven.compiler.source>
<maven.compiler.target>11</maven.compiler.target>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<protobuf.version>3.23.4</protobuf.version>
</properties>
<dependencies>
<dependency>
<groupId>com.google.protobuf</groupId>
<artifactId>protobuf-java</artifactId>
<version>${protobuf.version}</version>
</dependency>
</dependencies>
<build>
<extensions>
<extension>
<groupId>kr.motd.maven</groupId>
<artifactId>os-maven-plugin</artifactId>
<version>1.7.1</version>
</extension>
</extensions>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-shade-plugin</artifactId>
<version>3.5.0</version>
<executions>
<execution>
<phase>package</phase>
<goals>
<goal>shade</goal>
</goals>
<configuration>
<transformers>
<transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
<mainClass>betterproto.CompatibilityTest</mainClass>
</transformer>
</transformers>
</configuration>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-jar-plugin</artifactId>
<version>3.3.0</version>
<configuration>
<archive>
<manifest>
<addClasspath>true</addClasspath>
<mainClass>betterproto.CompatibilityTest</mainClass>
</manifest>
</archive>
</configuration>
</plugin>
<plugin>
<groupId>org.xolstice.maven.plugins</groupId>
<artifactId>protobuf-maven-plugin</artifactId>
<version>0.6.1</version>
<executions>
<execution>
<goals>
<goal>compile</goal>
</goals>
</execution>
</executions>
<configuration>
<protocArtifact>
com.google.protobuf:protoc:${protobuf.version}:exe:${os.detected.classifier}
</protocArtifact>
</configuration>
</plugin>
</plugins>
<finalName>${project.artifactId}</finalName>
</build>
</project>

View File

@ -0,0 +1,41 @@
package betterproto;
import java.io.IOException;
public class CompatibilityTest {
public static void main(String[] args) throws IOException {
if (args.length < 2)
throw new RuntimeException("Attempted to run without the required arguments.");
else if (args.length > 2)
throw new RuntimeException(
"Attempted to run with more than the expected number of arguments (>1).");
Tests tests = new Tests(args[1]);
switch (args[0]) {
case "single_varint":
tests.testSingleVarint();
break;
case "multiple_varints":
tests.testMultipleVarints();
break;
case "single_message":
tests.testSingleMessage();
break;
case "multiple_messages":
tests.testMultipleMessages();
break;
case "infinite_messages":
tests.testInfiniteMessages();
break;
default:
throw new RuntimeException(
"Attempted to run with unknown argument '" + args[0] + "'.");
}
}
}

View File

@ -0,0 +1,115 @@
package betterproto;
import betterproto.nested.NestedOuterClass;
import betterproto.oneof.Oneof;
import com.google.protobuf.CodedInputStream;
import com.google.protobuf.CodedOutputStream;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
public class Tests {
String path;
public Tests(String path) {
this.path = path;
}
public void testSingleVarint() throws IOException {
// Read in the Python-generated single varint file
FileInputStream inputStream = new FileInputStream(path + "/py_single_varint.out");
CodedInputStream codedInput = CodedInputStream.newInstance(inputStream);
int value = codedInput.readUInt32();
inputStream.close();
// Write the value back to a file
FileOutputStream outputStream = new FileOutputStream(path + "/java_single_varint.out");
CodedOutputStream codedOutput = CodedOutputStream.newInstance(outputStream);
codedOutput.writeUInt32NoTag(value);
codedOutput.flush();
outputStream.close();
}
public void testMultipleVarints() throws IOException {
// Read in the Python-generated multiple varints file
FileInputStream inputStream = new FileInputStream(path + "/py_multiple_varints.out");
CodedInputStream codedInput = CodedInputStream.newInstance(inputStream);
int value1 = codedInput.readUInt32();
int value2 = codedInput.readUInt32();
long value3 = codedInput.readUInt64();
inputStream.close();
// Write the values back to a file
FileOutputStream outputStream = new FileOutputStream(path + "/java_multiple_varints.out");
CodedOutputStream codedOutput = CodedOutputStream.newInstance(outputStream);
codedOutput.writeUInt32NoTag(value1);
codedOutput.writeUInt64NoTag(value2);
codedOutput.writeUInt64NoTag(value3);
codedOutput.flush();
outputStream.close();
}
public void testSingleMessage() throws IOException {
// Read in the Python-generated single message file
FileInputStream inputStream = new FileInputStream(path + "/py_single_message.out");
CodedInputStream codedInput = CodedInputStream.newInstance(inputStream);
Oneof.Test message = Oneof.Test.parseFrom(codedInput);
inputStream.close();
// Write the message back to a file
FileOutputStream outputStream = new FileOutputStream(path + "/java_single_message.out");
CodedOutputStream codedOutput = CodedOutputStream.newInstance(outputStream);
message.writeTo(codedOutput);
codedOutput.flush();
outputStream.close();
}
public void testMultipleMessages() throws IOException {
// Read in the Python-generated multi-message file
FileInputStream inputStream = new FileInputStream(path + "/py_multiple_messages.out");
Oneof.Test oneof = Oneof.Test.parseDelimitedFrom(inputStream);
NestedOuterClass.Test nested = NestedOuterClass.Test.parseDelimitedFrom(inputStream);
inputStream.close();
// Write the messages back to a file
FileOutputStream outputStream = new FileOutputStream(path + "/java_multiple_messages.out");
oneof.writeDelimitedTo(outputStream);
nested.writeDelimitedTo(outputStream);
outputStream.flush();
outputStream.close();
}
public void testInfiniteMessages() throws IOException {
// Read in as many messages as are present in the Python-generated file and write them back
FileInputStream inputStream = new FileInputStream(path + "/py_infinite_messages.out");
FileOutputStream outputStream = new FileOutputStream(path + "/java_infinite_messages.out");
Oneof.Test current = Oneof.Test.parseDelimitedFrom(inputStream);
while (current != null) {
current.writeDelimitedTo(outputStream);
current = Oneof.Test.parseDelimitedFrom(inputStream);
}
inputStream.close();
outputStream.flush();
outputStream.close();
}
}

View File

@ -0,0 +1,27 @@
syntax = "proto3";
package nested;
option java_package = "betterproto.nested";
// A test message with a nested message inside of it.
message Test {
// This is the nested type.
message Nested {
// Stores a simple counter.
int32 count = 1;
}
// This is the nested enum.
enum Msg {
NONE = 0;
THIS = 1;
}
Nested nested = 1;
Sibling sibling = 2;
Sibling sibling2 = 3;
Msg msg = 4;
}
message Sibling {
int32 foo = 1;
}

View File

@ -0,0 +1,19 @@
syntax = "proto3";
package oneof;
option java_package = "betterproto.oneof";
message Test {
oneof foo {
int32 pitied = 1;
string pitier = 2;
}
int32 just_a_regular_field = 3;
oneof bar {
int32 drinks = 11;
string bar_name = 12;
}
}

79
tests/test_enum.py Normal file
View File

@ -0,0 +1,79 @@
from typing import (
Optional,
Tuple,
)
import pytest
import betterproto
class Colour(betterproto.Enum):
RED = 1
GREEN = 2
BLUE = 3
PURPLE = Colour.__new__(Colour, name=None, value=4)
@pytest.mark.parametrize(
"member, str_value",
[
(Colour.RED, "RED"),
(Colour.GREEN, "GREEN"),
(Colour.BLUE, "BLUE"),
],
)
def test_str(member: Colour, str_value: str) -> None:
assert str(member) == str_value
@pytest.mark.parametrize(
"member, repr_value",
[
(Colour.RED, "Colour.RED"),
(Colour.GREEN, "Colour.GREEN"),
(Colour.BLUE, "Colour.BLUE"),
],
)
def test_repr(member: Colour, repr_value: str) -> None:
assert repr(member) == repr_value
@pytest.mark.parametrize(
"member, values",
[
(Colour.RED, ("RED", 1)),
(Colour.GREEN, ("GREEN", 2)),
(Colour.BLUE, ("BLUE", 3)),
(PURPLE, (None, 4)),
],
)
def test_name_values(member: Colour, values: Tuple[Optional[str], int]) -> None:
assert (member.name, member.value) == values
@pytest.mark.parametrize(
"member, input_str",
[
(Colour.RED, "RED"),
(Colour.GREEN, "GREEN"),
(Colour.BLUE, "BLUE"),
],
)
def test_from_string(member: Colour, input_str: str) -> None:
assert Colour.from_string(input_str) == member
@pytest.mark.parametrize(
"member, input_int",
[
(Colour.RED, 1),
(Colour.GREEN, 2),
(Colour.BLUE, 3),
(PURPLE, 4),
],
)
def test_try_value(member: Colour, input_int: int) -> None:
assert Colour.try_value(input_int) == member

View File

@ -545,47 +545,6 @@ def test_oneof_default_value_set_causes_writes_wire():
) )
def test_recursive_message():
from tests.output_betterproto.recursivemessage import Test as RecursiveMessage
msg = RecursiveMessage()
assert msg.child == RecursiveMessage()
# Lazily-created zero-value children must not affect equality.
assert msg == RecursiveMessage()
# Lazily-created zero-value children must not affect serialization.
assert bytes(msg) == b""
def test_recursive_message_defaults():
from tests.output_betterproto.recursivemessage import (
Intermediate,
Test as RecursiveMessage,
)
msg = RecursiveMessage(name="bob", intermediate=Intermediate(42))
# set values are as expected
assert msg == RecursiveMessage(name="bob", intermediate=Intermediate(42))
# lazy initialized works modifies the message
assert msg != RecursiveMessage(
name="bob", intermediate=Intermediate(42), child=RecursiveMessage(name="jude")
)
msg.child.child.name = "jude"
assert msg == RecursiveMessage(
name="bob",
intermediate=Intermediate(42),
child=RecursiveMessage(child=RecursiveMessage(name="jude")),
)
# lazily initialization recurses as needed
assert msg.child.child.child.child.child.child.child == RecursiveMessage()
assert msg.intermediate.child.intermediate == Intermediate()
def test_message_repr(): def test_message_repr():
from tests.output_betterproto.recursivemessage import Test from tests.output_betterproto.recursivemessage import Test
@ -699,25 +658,6 @@ def test_service_argument__expected_parameter():
assert do_thing_request_parameter.annotation == "DoThingRequest" assert do_thing_request_parameter.annotation == "DoThingRequest"
def test_copyability():
@dataclass
class Spam(betterproto.Message):
foo: bool = betterproto.bool_field(1)
bar: int = betterproto.int32_field(2)
baz: List[str] = betterproto.string_field(3)
spam = Spam(bar=12, baz=["hello"])
copied = copy(spam)
assert spam == copied
assert spam is not copied
assert spam.baz is copied.baz
deepcopied = deepcopy(spam)
assert spam == deepcopied
assert spam is not deepcopied
assert spam.baz is not deepcopied.baz
def test_is_set(): def test_is_set():
@dataclass @dataclass
class Spam(betterproto.Message): class Spam(betterproto.Message):

203
tests/test_pickling.py Normal file
View File

@ -0,0 +1,203 @@
import pickle
from copy import (
copy,
deepcopy,
)
from dataclasses import dataclass
from typing import (
Dict,
List,
)
from unittest.mock import ANY
import cachelib
import betterproto
from betterproto.lib.google import protobuf as google
def unpickled(message):
return pickle.loads(pickle.dumps(message))
@dataclass(eq=False, repr=False)
class Fe(betterproto.Message):
abc: str = betterproto.string_field(1)
@dataclass(eq=False, repr=False)
class Fi(betterproto.Message):
abc: str = betterproto.string_field(1)
@dataclass(eq=False, repr=False)
class Fo(betterproto.Message):
abc: str = betterproto.string_field(1)
@dataclass(eq=False, repr=False)
class NestedData(betterproto.Message):
struct_foo: Dict[str, "google.Struct"] = betterproto.map_field(
1, betterproto.TYPE_STRING, betterproto.TYPE_MESSAGE
)
map_str_any_bar: Dict[str, "google.Any"] = betterproto.map_field(
2, betterproto.TYPE_STRING, betterproto.TYPE_MESSAGE
)
@dataclass(eq=False, repr=False)
class Complex(betterproto.Message):
foo_str: str = betterproto.string_field(1)
fe: "Fe" = betterproto.message_field(3, group="grp")
fi: "Fi" = betterproto.message_field(4, group="grp")
fo: "Fo" = betterproto.message_field(5, group="grp")
nested_data: "NestedData" = betterproto.message_field(6)
mapping: Dict[str, "google.Any"] = betterproto.map_field(
7, betterproto.TYPE_STRING, betterproto.TYPE_MESSAGE
)
def complex_msg():
return Complex(
foo_str="yep",
fe=Fe(abc="1"),
nested_data=NestedData(
struct_foo={
"foo": google.Struct(
fields={
"hello": google.Value(
list_value=google.ListValue(
values=[google.Value(string_value="world")]
)
)
}
),
},
map_str_any_bar={
"key": google.Any(value=b"value"),
},
),
mapping={
"message": google.Any(value=bytes(Fi(abc="hi"))),
"string": google.Any(value=b"howdy"),
},
)
def test_pickling_complex_message():
msg = complex_msg()
deser = unpickled(msg)
assert msg == deser
assert msg.fe.abc == "1"
assert msg.is_set("fi") is not True
assert msg.mapping["message"] == google.Any(value=bytes(Fi(abc="hi")))
assert msg.mapping["string"].value.decode() == "howdy"
assert (
msg.nested_data.struct_foo["foo"]
.fields["hello"]
.list_value.values[0]
.string_value
== "world"
)
def test_recursive_message():
from tests.output_betterproto.recursivemessage import Test as RecursiveMessage
msg = RecursiveMessage()
msg = unpickled(msg)
assert msg.child == RecursiveMessage()
# Lazily-created zero-value children must not affect equality.
assert msg == RecursiveMessage()
# Lazily-created zero-value children must not affect serialization.
assert bytes(msg) == b""
def test_recursive_message_defaults():
from tests.output_betterproto.recursivemessage import (
Intermediate,
Test as RecursiveMessage,
)
msg = RecursiveMessage(name="bob", intermediate=Intermediate(42))
msg = unpickled(msg)
# set values are as expected
assert msg == RecursiveMessage(name="bob", intermediate=Intermediate(42))
# lazy initialized works modifies the message
assert msg != RecursiveMessage(
name="bob", intermediate=Intermediate(42), child=RecursiveMessage(name="jude")
)
msg.child.child.name = "jude"
assert msg == RecursiveMessage(
name="bob",
intermediate=Intermediate(42),
child=RecursiveMessage(child=RecursiveMessage(name="jude")),
)
# lazily initialization recurses as needed
assert msg.child.child.child.child.child.child.child == RecursiveMessage()
assert msg.intermediate.child.intermediate == Intermediate()
@dataclass
class PickledMessage(betterproto.Message):
foo: bool = betterproto.bool_field(1)
bar: int = betterproto.int32_field(2)
baz: List[str] = betterproto.string_field(3)
def test_copyability():
msg = PickledMessage(bar=12, baz=["hello"])
msg = unpickled(msg)
copied = copy(msg)
assert msg == copied
assert msg is not copied
assert msg.baz is copied.baz
deepcopied = deepcopy(msg)
assert msg == deepcopied
assert msg is not deepcopied
assert msg.baz is not deepcopied.baz
def test_message_can_be_cached():
"""Cachelib uses pickling to cache values"""
cache = cachelib.SimpleCache()
def use_cache():
calls = getattr(use_cache, "calls", 0)
result = cache.get("message")
if result is not None:
return result
else:
setattr(use_cache, "calls", calls + 1)
result = complex_msg()
cache.set("message", result)
return result
for n in range(10):
if n == 0:
assert not cache.has("message")
else:
assert cache.has("message")
msg = use_cache()
assert use_cache.calls == 1 # The message is only ever built once
assert msg.fe.abc == "1"
assert msg.is_set("fi") is not True
assert msg.mapping["message"] == google.Any(value=bytes(Fi(abc="hi")))
assert msg.mapping["string"].value.decode() == "howdy"
assert (
msg.nested_data.struct_foo["foo"]
.fields["hello"]
.list_value.values[0]
.string_value
== "world"
)

View File

@ -1,6 +1,8 @@
from dataclasses import dataclass from dataclasses import dataclass
from io import BytesIO from io import BytesIO
from pathlib import Path from pathlib import Path
from shutil import which
from subprocess import run
from typing import Optional from typing import Optional
import pytest import pytest
@ -40,6 +42,8 @@ map_example = map.Test().from_dict({"counts": {"blah": 1, "Blah2": 2}})
streams_path = Path("tests/streams/") streams_path = Path("tests/streams/")
java = which("java")
def test_load_varint_too_long(): def test_load_varint_too_long():
with BytesIO( with BytesIO(
@ -127,6 +131,18 @@ def test_message_dump_file_multiple(tmp_path):
assert test_stream.read() == exp_stream.read() assert test_stream.read() == exp_stream.read()
def test_message_dump_delimited(tmp_path):
with open(tmp_path / "message_dump_delimited.out", "wb") as stream:
oneof_example.dump(stream, betterproto.SIZE_DELIMITED)
oneof_example.dump(stream, betterproto.SIZE_DELIMITED)
nested_example.dump(stream, betterproto.SIZE_DELIMITED)
with open(tmp_path / "message_dump_delimited.out", "rb") as test_stream, open(
streams_path / "delimited_messages.in", "rb"
) as exp_stream:
assert test_stream.read() == exp_stream.read()
def test_message_len(): def test_message_len():
assert len_oneof == len(bytes(oneof_example)) assert len_oneof == len(bytes(oneof_example))
assert len(nested_example) == len(bytes(nested_example)) assert len(nested_example) == len(bytes(nested_example))
@ -155,7 +171,15 @@ def test_message_load_too_small():
oneof.Test().load(stream, len_oneof - 1) oneof.Test().load(stream, len_oneof - 1)
def test_message_too_large(): def test_message_load_delimited():
with open(streams_path / "delimited_messages.in", "rb") as stream:
assert oneof.Test().load(stream, betterproto.SIZE_DELIMITED) == oneof_example
assert oneof.Test().load(stream, betterproto.SIZE_DELIMITED) == oneof_example
assert nested.Test().load(stream, betterproto.SIZE_DELIMITED) == nested_example
assert stream.read(1) == b""
def test_message_load_too_large():
with open( with open(
streams_path / "message_dump_file_single.expected", "rb" streams_path / "message_dump_file_single.expected", "rb"
) as stream, pytest.raises(ValueError): ) as stream, pytest.raises(ValueError):
@ -266,3 +290,145 @@ def test_dump_varint_positive(tmp_path):
streams_path / "dump_varint_positive.expected", "rb" streams_path / "dump_varint_positive.expected", "rb"
) as exp_stream: ) as exp_stream:
assert test_stream.read() == exp_stream.read() assert test_stream.read() == exp_stream.read()
# Java compatibility tests
@pytest.fixture(scope="module")
def compile_jar():
# Skip if not all required tools are present
if java is None:
pytest.skip("`java` command is absent and is required")
mvn = which("mvn")
if mvn is None:
pytest.skip("Maven is absent and is required")
# Compile the JAR
proc_maven = run([mvn, "clean", "install", "-f", "tests/streams/java/pom.xml"])
if proc_maven.returncode != 0:
pytest.skip(
"Maven compatibility-test.jar build failed (maybe Java version <11?)"
)
jar = "tests/streams/java/target/compatibility-test.jar"
def run_jar(command: str, tmp_path):
return run([java, "-jar", jar, command, tmp_path], check=True)
def run_java_single_varint(value: int, tmp_path) -> int:
# Write single varint to file
with open(tmp_path / "py_single_varint.out", "wb") as stream:
betterproto.dump_varint(value, stream)
# Have Java read this varint and write it back
run_jar("single_varint", tmp_path)
# Read single varint from Java output file
with open(tmp_path / "java_single_varint.out", "rb") as stream:
returned = betterproto.load_varint(stream)
with pytest.raises(EOFError):
betterproto.load_varint(stream)
return returned
def test_single_varint(compile_jar, tmp_path):
single_byte = (1, b"\x01")
multi_byte = (123456789, b"\x95\x9A\xEF\x3A")
# Write a single-byte varint to a file and have Java read it back
returned = run_java_single_varint(single_byte[0], tmp_path)
assert returned == single_byte
# Same for a multi-byte varint
returned = run_java_single_varint(multi_byte[0], tmp_path)
assert returned == multi_byte
def test_multiple_varints(compile_jar, tmp_path):
single_byte = (1, b"\x01")
multi_byte = (123456789, b"\x95\x9A\xEF\x3A")
over32 = (3000000000, b"\x80\xBC\xC1\x96\x0B")
# Write two varints to the same file
with open(tmp_path / "py_multiple_varints.out", "wb") as stream:
betterproto.dump_varint(single_byte[0], stream)
betterproto.dump_varint(multi_byte[0], stream)
betterproto.dump_varint(over32[0], stream)
# Have Java read these varints and write them back
run_jar("multiple_varints", tmp_path)
# Read varints from Java output file
with open(tmp_path / "java_multiple_varints.out", "rb") as stream:
returned_single = betterproto.load_varint(stream)
returned_multi = betterproto.load_varint(stream)
returned_over32 = betterproto.load_varint(stream)
with pytest.raises(EOFError):
betterproto.load_varint(stream)
assert returned_single == single_byte
assert returned_multi == multi_byte
assert returned_over32 == over32
def test_single_message(compile_jar, tmp_path):
# Write message to file
with open(tmp_path / "py_single_message.out", "wb") as stream:
oneof_example.dump(stream)
# Have Java read and return the message
run_jar("single_message", tmp_path)
# Read and check the returned message
with open(tmp_path / "java_single_message.out", "rb") as stream:
returned = oneof.Test().load(stream, len(bytes(oneof_example)))
assert stream.read() == b""
assert returned == oneof_example
def test_multiple_messages(compile_jar, tmp_path):
# Write delimited messages to file
with open(tmp_path / "py_multiple_messages.out", "wb") as stream:
oneof_example.dump(stream, betterproto.SIZE_DELIMITED)
nested_example.dump(stream, betterproto.SIZE_DELIMITED)
# Have Java read and return the messages
run_jar("multiple_messages", tmp_path)
# Read and check the returned messages
with open(tmp_path / "java_multiple_messages.out", "rb") as stream:
returned_oneof = oneof.Test().load(stream, betterproto.SIZE_DELIMITED)
returned_nested = nested.Test().load(stream, betterproto.SIZE_DELIMITED)
assert stream.read() == b""
assert returned_oneof == oneof_example
assert returned_nested == nested_example
def test_infinite_messages(compile_jar, tmp_path):
num_messages = 5
# Write delimited messages to file
with open(tmp_path / "py_infinite_messages.out", "wb") as stream:
for x in range(num_messages):
oneof_example.dump(stream, betterproto.SIZE_DELIMITED)
# Have Java read and return the messages
run_jar("infinite_messages", tmp_path)
# Read and check the returned messages
messages = []
with open(tmp_path / "java_infinite_messages.out", "rb") as stream:
while True:
try:
messages.append(oneof.Test().load(stream, betterproto.SIZE_DELIMITED))
except EOFError:
break
assert len(messages) == num_messages

27
tests/test_timestamp.py Normal file
View File

@ -0,0 +1,27 @@
from datetime import (
datetime,
timezone,
)
import pytest
from betterproto import _Timestamp
@pytest.mark.parametrize(
"dt",
[
datetime(2023, 10, 11, 9, 41, 12, tzinfo=timezone.utc),
datetime.now(timezone.utc),
# potential issue with floating point precision:
datetime(2242, 12, 31, 23, 0, 0, 1, tzinfo=timezone.utc),
# potential issue with negative timestamps:
datetime(1969, 12, 31, 23, 0, 0, 1, tzinfo=timezone.utc),
],
)
def test_timestamp_to_datetime_and_back(dt: datetime):
"""
Make sure converting a datetime to a protobuf timestamp message
and then back again ends up with the same datetime.
"""
assert _Timestamp.from_datetime(dt).to_datetime() == dt