Merge branch 'master_gh'

# Conflicts:
#	src/betterproto/__init__.py
Georg K 2023-11-15 17:33:19 +03:00
commit 1d296f1a88
36 changed files with 2233 additions and 784 deletions

63
.github/ISSUE_TEMPLATE/bug_report.yml vendored Normal file

@ -0,0 +1,63 @@
name: Bug Report
description: Report broken or incorrect behaviour
labels: ["bug", "investigation needed"]
body:
- type: markdown
attributes:
value: >
Thanks for taking the time to fill out a bug report!
If you're not sure it's a bug and you just have a question, the [community Slack channel](https://join.slack.com/t/betterproto/shared_invite/zt-f0n0uolx-iN8gBNrkPxtKHTLpG3o1OQ) is a better place for general questions than a GitHub issue.
- type: input
attributes:
label: Summary
description: A simple summary of your bug report
validations:
required: true
- type: textarea
attributes:
label: Reproduction Steps
description: >
What you did to make it happen.
Ideally there should be a short code snippet in this section to help reproduce the bug.
validations:
required: true
- type: textarea
attributes:
label: Expected Results
description: >
What did you expect to happen?
validations:
required: true
- type: textarea
attributes:
label: Actual Results
description: >
What actually happened?
validations:
required: true
- type: textarea
attributes:
label: System Information
description: >
Paste the result of `protoc --version; python --version; pip show betterproto` below.
validations:
required: true
- type: checkboxes
attributes:
label: Checklist
options:
- label: I have searched the issues for duplicates.
required: true
- label: I have shown the entire traceback, if possible.
required: true
- label: I have verified this issue occurs on the latest prerelease of betterproto, which can be installed using `pip install -U --pre betterproto`, if possible.
required: true

6
.github/ISSUE_TEMPLATE/config.yml vendored Normal file

@ -0,0 +1,6 @@
name:
description:
contact_links:
- name: For questions about the library
about: Support questions are better answered in our Slack group.
url: https://join.slack.com/t/betterproto/shared_invite/zt-f0n0uolx-iN8gBNrkPxtKHTLpG3o1OQ


@ -0,0 +1,49 @@
name: Feature Request
description: Suggest a feature for this library
labels: ["enhancement"]
body:
- type: input
attributes:
label: Summary
description: >
What problem is your feature trying to solve? What would become easier or possible if this feature was implemented?
validations:
required: true
- type: dropdown
attributes:
multiple: false
label: What is the feature request for?
options:
- The core library
- RPC handling
- The documentation
validations:
required: true
- type: textarea
attributes:
label: The Problem
description: >
What problem is your feature trying to solve?
What would become easier or possible if this feature was implemented?
validations:
required: true
- type: textarea
attributes:
label: The Ideal Solution
description: >
What is your ideal solution to the problem?
What would you like this feature to do?
validations:
required: true
- type: textarea
attributes:
label: The Current Solution
description: >
What is the current solution to the problem, if any?
validations:
required: false

16
.github/PULL_REQUEST_TEMPLATE.md vendored Normal file

@ -0,0 +1,16 @@
## Summary
<!-- What is this pull request for? Does it fix any issues? -->
## Checklist
<!-- Put an x inside [ ] to check it, like so: [x] -->
- [ ] If code changes were made then they have been tested.
- [ ] I have updated the documentation to reflect the changes.
- [ ] This PR fixes an issue.
- [ ] This PR adds something new (e.g. new method or parameters).
- [ ] This change has an associated test.
- [ ] This PR is a breaking change (e.g. methods or parameters removed/renamed)
- [ ] This PR is **not** a code change (e.g. documentation, README, ...)

46
.github/workflows/codeql-analysis.yml vendored Normal file

@ -0,0 +1,46 @@
name: "CodeQL"
on:
push:
branches: [ "master" ]
pull_request:
branches:
- '**'
schedule:
- cron: '19 1 * * 6'
jobs:
analyze:
name: Analyze
runs-on: ubuntu-latest
permissions:
actions: read
contents: read
security-events: write
strategy:
fail-fast: false
matrix:
language: [ 'python' ]
steps:
- name: Checkout repository
uses: actions/checkout@v3
# Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL
uses: github/codeql-action/init@v2
with:
languages: ${{ matrix.language }}
# If you wish to specify custom queries, you can do so here or in a config file.
# By default, queries listed here will override any specified in a config file.
# Prefix the list here with "+" to use these queries and those in the config file.
# For details on CodeQL's query packs, refer to: https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs
# queries: security-extended,security-and-quality
- name: Autobuild
uses: github/codeql-action/autobuild@v2
- name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@v2


@ -16,6 +16,13 @@ repos:
- repo: https://github.com/PyCQA/doc8
rev: 0.10.1
hooks:
- id: doc8
- id: doc8
additional_dependencies:
- toml
- repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks
rev: v2.10.0
hooks:
- id: pretty-format-java
args: [--autofix, --aosp]
files: ^.*\.java$

1200
poetry.lock generated

File diff suppressed because it is too large


@ -19,26 +19,29 @@ importlib-metadata = { version = ">=1.6.0", python = "<3.8" }
jinja2 = { version = ">=3.0.3", optional = true }
python-dateutil = "^2.8"
isort = {version = "^5.11.5", optional = true}
typing-extensions = "^4.7.1"
[tool.poetry.dev-dependencies]
[tool.poetry.group.dev.dependencies]
asv = "^0.4.2"
bpython = "^0.19"
grpcio-tools = "^1.54.2"
jinja2 = ">=3.0.3"
mypy = "^0.930"
sphinx = "3.1.2"
sphinx-rtd-theme = "0.5.0"
pre-commit = "^2.17.0"
grpcio-tools = "^1.54.2"
tox = "^4.0.0"
[tool.poetry.group.test.dependencies]
poethepoet = ">=0.9.0"
protobuf = "^4.21.6"
pytest = "^6.2.5"
pytest-asyncio = "^0.12.0"
pytest-cov = "^2.9.0"
pytest-mock = "^3.1.1"
sphinx = "3.1.2"
sphinx-rtd-theme = "0.5.0"
tomlkit = "^0.7.0"
tox = "^3.15.1"
pre-commit = "^2.17.0"
pydantic = ">=1.8.0"
pydantic = ">=1.8.0,<2"
protobuf = "^4"
cachelib = "^0.10.2"
tomlkit = ">=0.7.0"
[tool.poetry.scripts]
protoc-gen-python_betterproto = "betterproto.plugin:main"
@ -61,9 +64,13 @@ help = "Run tests"
cmd = "mypy src --ignore-missing-imports"
help = "Check types with mypy"
[tool.poe.tasks]
_black = "black . --exclude tests/output_ --target-version py310"
_isort = "isort . --extend-skip-glob 'tests/output_*/**/*'"
[tool.poe.tasks.format]
cmd = "black . --exclude tests/output_ --target-version py310"
help = "Apply black formatting to source code"
sequence = ["_black", "_isort"]
help = "Apply black and isort formatting to source code"
[tool.poe.tasks.docs]
cmd = "sphinx-build docs docs/build"
@ -130,14 +137,21 @@ omit = ["betterproto/tests/*"]
[tool.tox]
legacy_tox_ini = """
[tox]
isolated_build = true
envlist = py37, py38, py310
requires =
tox>=4.2
tox-poetry-installer[poetry]==1.0.0b1
env_list =
py311
py38
py37
[testenv]
whitelist_externals = poetry
commands =
poetry install -v --extras compiler
poetry run pytest --cov betterproto
pytest {posargs: --cov betterproto}
poetry_dep_groups =
test
require_locked_deps = true
require_poetry = true
"""
[build-system]


@ -1,5 +1,7 @@
from __future__ import annotations
import dataclasses
import enum
import enum as builtin_enum
import json
import math
import struct
@ -22,8 +24,8 @@ from itertools import count
from typing import (
TYPE_CHECKING,
Any,
BinaryIO,
Callable,
ClassVar,
Dict,
Generator,
Iterable,
@ -37,6 +39,7 @@ from typing import (
)
from dateutil.parser import isoparse
from typing_extensions import Self
from ._types import T
from ._version import __version__
@ -45,11 +48,19 @@ from .casing import (
safe_snake_case,
snake_case,
)
from .grpc.grpclib_client import ServiceStub
from .enum import Enum as Enum
from .grpc.grpclib_client import ServiceStub as ServiceStub
from .utils import (
classproperty,
hybridmethod,
)
if TYPE_CHECKING:
from _typeshed import ReadableBuffer
from _typeshed import (
SupportsRead,
SupportsWrite,
)
# Proto 3 data types
@ -126,6 +137,9 @@ WIRE_FIXED_32_TYPES = [TYPE_FLOAT, TYPE_FIXED32, TYPE_SFIXED32]
WIRE_FIXED_64_TYPES = [TYPE_DOUBLE, TYPE_FIXED64, TYPE_SFIXED64]
WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE, TYPE_MAP]
# Indicator of message delimitation in streams
SIZE_DELIMITED = -1
# Protobuf datetimes start at the Unix Epoch in 1970 in UTC.
def datetime_default_gen() -> datetime:
@ -140,7 +154,7 @@ NEG_INFINITY = "-Infinity"
NAN = "NaN"
class Casing(enum.Enum):
class Casing(builtin_enum.Enum):
"""Casing constants for serialization."""
CAMEL = camel_case #: A camelCase serialization function.
@ -309,32 +323,6 @@ def map_field(
)
class Enum(enum.IntEnum):
"""
The base class for protobuf enumerations, all generated enumerations will inherit
from this. Bases :class:`enum.IntEnum`.
"""
@classmethod
def from_string(cls, name: str) -> "Enum":
"""Return the value which corresponds to the string name.
Parameters
-----------
name: :class:`str`
The name of the enum member to get
Raises
-------
:exc:`ValueError`
The member was not found in the Enum.
"""
try:
return cls._member_map_[name] # type: ignore
except KeyError as e:
raise ValueError(f"Unknown value {name} for enum {cls.__name__}") from e
def _pack_fmt(proto_type: str) -> str:
"""Returns a little-endian format string for reading/writing binary."""
return {
@ -347,7 +335,7 @@ def _pack_fmt(proto_type: str) -> str:
}[proto_type]
def dump_varint(value: int, stream: BinaryIO) -> None:
def dump_varint(value: int, stream: "SupportsWrite[bytes]") -> None:
"""Encodes a single varint and dumps it into the provided stream."""
if value < -(1 << 63):
raise ValueError(
@ -556,7 +544,7 @@ def _dump_float(value: float) -> Union[float, str]:
return value
def load_varint(stream: BinaryIO) -> Tuple[int, bytes]:
def load_varint(stream: "SupportsRead[bytes]") -> Tuple[int, bytes]:
"""
Load a single varint value from a stream. Returns the value and the raw bytes read.
"""
@ -594,7 +582,7 @@ class ParsedField:
raw: bytes
def load_fields(stream: BinaryIO) -> Generator[ParsedField, None, None]:
def load_fields(stream: "SupportsRead[bytes]") -> Generator[ParsedField, None, None]:
while True:
try:
num_wire, raw = load_varint(stream)
@ -748,6 +736,7 @@ class Message(ABC):
_serialized_on_wire: bool
_unknown_fields: bytes
_group_current: Dict[str, str]
_betterproto_meta: ClassVar[ProtoClassMetadata]
def __post_init__(self) -> None:
# Keep track of whether every field was default
@ -815,6 +804,10 @@ class Message(ABC):
]
return f"{self.__class__.__name__}({', '.join(parts)})"
def __rich_repr__(self) -> Iterable[Tuple[str, Any, Any]]:
for field_name in self._betterproto.sorted_field_names:
yield field_name, self.__raw_get(field_name), PLACEHOLDER
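For context, `__rich_repr__` implements the Rich repr protocol, so pretty-printers that understand it (for example the `rich` package, if installed) show set fields and hide unset ones. A minimal sketch with an illustrative message class, not taken from this repo:

from dataclasses import dataclass

import betterproto
from rich.pretty import pprint  # any printer honouring the Rich repr protocol


@dataclass
class Greeting(betterproto.Message):  # illustrative message, not from this repo
    text: str = betterproto.string_field(1)
    count: int = betterproto.int32_field(2)


# Set fields are yielded with their raw values; unset fields still hold the
# internal PLACEHOLDER sentinel and are therefore omitted from the output.
pprint(Greeting(text="hi"))  # renders roughly Greeting(text='hi')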
if not TYPE_CHECKING:
def __getattribute__(self, name: str) -> Any:
@ -889,20 +882,28 @@ class Message(ABC):
kwargs[name] = deepcopy(value)
return self.__class__(**kwargs) # type: ignore
@property
def _betterproto(self) -> ProtoClassMetadata:
def __copy__(self: T, _: Any = {}) -> T:
kwargs = {}
for name in self._betterproto.sorted_field_names:
value = self.__raw_get(name)
if value is not PLACEHOLDER:
kwargs[name] = value
return self.__class__(**kwargs) # type: ignore
@classproperty
def _betterproto(cls: type[Self]) -> ProtoClassMetadata: # type: ignore
"""
Lazy initialize metadata for each protobuf class.
It may be initialized multiple times in a multi-threaded environment,
but that won't affect the correctness.
"""
meta = getattr(self.__class__, "_betterproto_meta", None)
if not meta:
meta = ProtoClassMetadata(self.__class__)
self.__class__._betterproto_meta = meta # type: ignore
return meta
try:
return cls._betterproto_meta
except AttributeError:
cls._betterproto_meta = meta = ProtoClassMetadata(cls)
return meta
def dump(self, stream: BinaryIO) -> None:
def dump(self, stream: "SupportsWrite[bytes]", delimit: bool = False) -> None:
"""
Dumps the binary encoded Protobuf message to the stream.
@ -910,7 +911,11 @@ class Message(ABC):
-----------
stream: :class:`BinaryIO`
The stream to dump the message to.
delimit:
Whether to prefix the message with a varint declaring its size.
"""
if delimit == SIZE_DELIMITED:
dump_varint(len(self), stream)
for field_name, meta in self._betterproto.meta_by_field_name.items():
try:
@ -930,7 +935,7 @@ class Message(ABC):
# Note that proto3 field presence/optional fields are put in a
# synthetic single-item oneof by protoc, which helps us ensure we
# send the value even if the value is the default zero value.
selected_in_group = bool(meta.group)
selected_in_group = bool(meta.group) or meta.optional
# Empty messages can still be sent on the wire if they were
# set (or received empty).
@ -1124,6 +1129,15 @@ class Message(ABC):
"""
return bytes(self)
def __getstate__(self) -> bytes:
return bytes(self)
def __setstate__(self: T, pickled_bytes: bytes) -> T:
return self.parse(pickled_bytes)
def __reduce__(self) -> Tuple[Any, ...]:
return (self.__class__.FromString, (bytes(self),))
@classmethod
def _type_hint(cls, field_name: str) -> Type:
return cls._type_hints()[field_name]
@ -1168,7 +1182,7 @@ class Message(ABC):
return t
elif issubclass(t, Enum):
# Enums always default to zero.
return int
return t.try_value
elif t is datetime:
# Offsets are relative to 1970-01-01T00:00:00Z
return datetime_default_gen
@ -1193,6 +1207,9 @@ class Message(ABC):
elif meta.proto_type == TYPE_BOOL:
# Booleans use a varint encoding, so convert it to true/false.
value = value > 0
elif meta.proto_type == TYPE_ENUM:
# Convert enum ints to python enum instances
value = self._betterproto.cls_by_field[field_name].try_value(value)
elif wire_type in (WIRE_FIXED_32, WIRE_FIXED_64):
fmt = _pack_fmt(meta.proto_type)
value = struct.unpack(fmt, value)[0]
@ -1225,7 +1242,11 @@ class Message(ABC):
meta.group is not None and self._group_current.get(meta.group) == field_name
)
def load(self: T, stream: BinaryIO, size: Optional[int] = None) -> T:
def load(
self: T,
stream: "SupportsRead[bytes]",
size: Optional[int] = None,
) -> T:
"""
Load the binary encoded Protobuf from a stream into this message instance. This
returns the instance itself and is therefore assignable and chainable.
@ -1237,12 +1258,17 @@ class Message(ABC):
size: :class:`Optional[int]`
The size of the message in the stream.
Reads stream until EOF if ``None`` is given.
Reads based on a size delimiter prefix varint if SIZE_DELIMITED is given.
Returns
--------
:class:`Message`
The initialized message.
"""
# If the message is delimited, parse the message delimiter
if size == SIZE_DELIMITED:
size, _ = load_varint(stream)
# Got some data over the wire
self._serialized_on_wire = True
proto_meta = self._betterproto
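Taken together, `dump(..., SIZE_DELIMITED)` and `load(..., SIZE_DELIMITED)` give simple length-prefixed streaming. A minimal usage sketch, using an illustrative message class rather than generated code:

import io
from dataclasses import dataclass

import betterproto


@dataclass
class Greeting(betterproto.Message):  # illustrative message, not from this repo
    text: str = betterproto.string_field(1)


buffer = io.BytesIO()
for text in ("hello", "world"):
    # Each message is prefixed with a varint declaring its size.
    Greeting(text=text).dump(buffer, betterproto.SIZE_DELIMITED)

buffer.seek(0)
first = Greeting().load(buffer, betterproto.SIZE_DELIMITED)
second = Greeting().load(buffer, betterproto.SIZE_DELIMITED)
assert (first.text, second.text) == ("hello", "world")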
@ -1315,7 +1341,7 @@ class Message(ABC):
return self
def parse(self: T, data: "ReadableBuffer") -> T:
def parse(self: T, data: bytes) -> T:
"""
Parse the binary encoded Protobuf into this message instance. This
returns the instance itself and is therefore assignable and chainable.
@ -1494,7 +1520,91 @@ class Message(ABC):
output[cased_name] = value
return output
def from_dict(self: T, value: Mapping[str, Any]) -> T:
@classmethod
def _from_dict_init(cls, mapping: Mapping[str, Any]) -> Mapping[str, Any]:
init_kwargs: Dict[str, Any] = {}
for key, value in mapping.items():
field_name = safe_snake_case(key)
try:
meta = cls._betterproto.meta_by_field_name[field_name]
except KeyError:
continue
if value is None:
continue
if meta.proto_type == TYPE_MESSAGE:
sub_cls = cls._betterproto.cls_by_field[field_name]
if sub_cls == datetime:
value = (
[isoparse(item) for item in value]
if isinstance(value, list)
else isoparse(value)
)
elif sub_cls == timedelta:
value = (
[timedelta(seconds=float(item[:-1])) for item in value]
if isinstance(value, list)
else timedelta(seconds=float(value[:-1]))
)
elif not meta.wraps:
value = (
[sub_cls.from_dict(item) for item in value]
if isinstance(value, list)
else sub_cls.from_dict(value)
)
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
sub_cls = cls._betterproto.cls_by_field[f"{field_name}.value"]
value = {k: sub_cls.from_dict(v) for k, v in value.items()}
else:
if meta.proto_type in INT_64_TYPES:
value = (
[int(n) for n in value]
if isinstance(value, list)
else int(value)
)
elif meta.proto_type == TYPE_BYTES:
value = (
[b64decode(n) for n in value]
if isinstance(value, list)
else b64decode(value)
)
elif meta.proto_type == TYPE_ENUM:
enum_cls = cls._betterproto.cls_by_field[field_name]
if isinstance(value, list):
value = [enum_cls.from_string(e) for e in value]
elif isinstance(value, str):
value = enum_cls.from_string(value)
elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE):
value = (
[_parse_float(n) for n in value]
if isinstance(value, list)
else _parse_float(value)
)
init_kwargs[field_name] = value
return init_kwargs
@hybridmethod
def from_dict(cls: type[Self], value: Mapping[str, Any]) -> Self: # type: ignore
"""
Parse the key/value pairs into a new message instance.
Parameters
-----------
value: Dict[:class:`str`, Any]
The dictionary to parse from.
Returns
--------
:class:`Message`
The initialized message.
"""
self = cls(**cls._from_dict_init(value))
self._serialized_on_wire = True
return self
@from_dict.instancemethod
def from_dict(self, value: Mapping[str, Any]) -> Self:
"""
Parse the key/value pairs into the current message instance. This returns the
instance itself and is therefore assignable and chainable.
@ -1510,71 +1620,8 @@ class Message(ABC):
The initialized message.
"""
self._serialized_on_wire = True
for key in value:
field_name = safe_snake_case(key)
meta = self._betterproto.meta_by_field_name.get(field_name)
if not meta:
continue
if value[key] is not None:
if meta.proto_type == TYPE_MESSAGE:
v = self._get_field_default(field_name)
cls = self._betterproto.cls_by_field[field_name]
if isinstance(v, list):
if cls == datetime:
v = [isoparse(item) for item in value[key]]
elif cls == timedelta:
v = [
timedelta(seconds=float(item[:-1]))
for item in value[key]
]
else:
v = [cls().from_dict(item) for item in value[key]]
elif cls == datetime:
v = isoparse(value[key])
setattr(self, field_name, v)
elif cls == timedelta:
v = timedelta(seconds=float(value[key][:-1]))
setattr(self, field_name, v)
elif meta.wraps:
setattr(self, field_name, value[key])
elif v is None:
setattr(self, field_name, cls().from_dict(value[key]))
else:
# NOTE: `from_dict` mutates the underlying message, so no
# assignment here is necessary.
v.from_dict(value[key])
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
v = getattr(self, field_name)
cls = self._betterproto.cls_by_field[f"{field_name}.value"]
for k in value[key]:
v[k] = cls().from_dict(value[key][k])
else:
v = value[key]
if meta.proto_type in INT_64_TYPES:
if isinstance(value[key], list):
v = [int(n) for n in value[key]]
else:
v = int(value[key])
elif meta.proto_type == TYPE_BYTES:
if isinstance(value[key], list):
v = [b64decode(n) for n in value[key]]
else:
v = b64decode(value[key])
elif meta.proto_type == TYPE_ENUM:
enum_cls = self._betterproto.cls_by_field[field_name]
if isinstance(v, list):
v = [enum_cls.from_string(e) for e in v]
elif isinstance(v, str):
v = enum_cls.from_string(v)
elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE):
if isinstance(value[key], list):
v = [_parse_float(n) for n in value[key]]
else:
v = _parse_float(value[key])
if v is not None:
setattr(self, field_name, v)
for field, value in self._from_dict_init(value).items():
setattr(self, field, value)
return self
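With `from_dict` exposed as a hybrid method, the same name now works both as a constructor-style classmethod and as the existing mutating instance method. A small sketch with an illustrative message class:

from dataclasses import dataclass

import betterproto


@dataclass
class Greeting(betterproto.Message):  # illustrative message, not from this repo
    text: str = betterproto.string_field(1)


# Classmethod form: builds a brand-new instance from the mapping.
a = Greeting.from_dict({"text": "hello"})

# Instance form: mutates (and returns) the existing instance, as before.
b = Greeting().from_dict({"text": "hello"})

assert a == b == Greeting(text="hello")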
def to_json(
@ -1791,8 +1838,8 @@ class Message(ABC):
@classmethod
def _validate_field_groups(cls, values):
group_to_one_ofs = cls._betterproto_meta.oneof_field_by_group # type: ignore
field_name_to_meta = cls._betterproto_meta.meta_by_field_name # type: ignore
group_to_one_ofs = cls._betterproto.oneof_field_by_group
field_name_to_meta = cls._betterproto.meta_by_field_name
for group, field_set in group_to_one_ofs.items():
if len(field_set) == 1:
@ -1819,6 +1866,9 @@ class Message(ABC):
return values
Message.__annotations__ = {} # HACK to avoid typing.get_type_hints breaking :)
def serialized_on_wire(message: Message) -> bool:
"""
If this message was or should be serialized on the wire. This can be used to detect
@ -1890,17 +1940,24 @@ class _Duration(Duration):
class _Timestamp(Timestamp):
@classmethod
def from_datetime(cls, dt: datetime) -> "_Timestamp":
seconds = int(dt.timestamp())
nanos = int(dt.microsecond * 1e3)
return cls(seconds, nanos)
# manual epoch offset calculation to avoid rounding errors,
# to support negative timestamps (before 1970) and skirt
# around datetime bugs (apparently 0 isn't a year in [0, 9999]??)
offset = dt - DATETIME_ZERO
# below is the same as timedelta.total_seconds() but without dividing by 1e6
# so we end up with microseconds as integers instead of seconds as float
offset_us = (
offset.days * 24 * 60 * 60 + offset.seconds
) * 10**6 + offset.microseconds
seconds, us = divmod(offset_us, 10**6)
return cls(seconds, us * 1000)
def to_datetime(self) -> datetime:
ts = self.seconds + (self.nanos / 1e9)
if ts < 0:
return datetime(1970, 1, 1) + timedelta(seconds=ts)
else:
return datetime.fromtimestamp(ts, tz=timezone.utc)
# datetime.fromtimestamp() expects a timestamp in seconds, not microseconds
# if we pass it as a floating point number, we will run into rounding errors
# see also #407
offset = timedelta(seconds=self.seconds, microseconds=self.nanos // 1000)
return DATETIME_ZERO + offset
@staticmethod
def timestamp_to_json(dt: datetime) -> str:
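To make the offset arithmetic concrete, a short worked example for a pre-epoch datetime, assuming `DATETIME_ZERO` is the 1970-01-01 UTC epoch referenced in the comments above:

from datetime import datetime, timedelta, timezone

DATETIME_ZERO = datetime(1970, 1, 1, tzinfo=timezone.utc)  # assumed epoch constant

# One microsecond past 23:00 on 1969-12-31, i.e. one hour before the epoch.
dt = datetime(1969, 12, 31, 23, 0, 0, 1, tzinfo=timezone.utc)

offset = dt - DATETIME_ZERO  # timedelta(days=-1, seconds=82800, microseconds=1)
offset_us = (offset.days * 24 * 60 * 60 + offset.seconds) * 10**6 + offset.microseconds
assert offset_us == -3_599_999_999

# divmod floors toward negative infinity, so seconds goes negative while the
# sub-second remainder stays non-negative, as protobuf Timestamps require.
seconds, us = divmod(offset_us, 10**6)
assert (seconds, us) == (-3600, 1)  # i.e. Timestamp(seconds=-3600, nanos=1000)

# Converting back reproduces the original datetime exactly, with no float rounding.
assert DATETIME_ZERO + timedelta(seconds=seconds, microseconds=us) == dt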


@ -136,4 +136,8 @@ def lowercase_first(value: str) -> str:
def sanitize_name(value: str) -> str:
# https://www.python.org/dev/peps/pep-0008/#descriptive-naming-styles
return f"{value}_" if keyword.iskeyword(value) else value
if keyword.iskeyword(value):
return f"{value}_"
if not value.isidentifier():
return f"_{value}"
return value
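For illustration, the new behaviour on a keyword, on a name that is not a valid identifier, and on an ordinary name:

from betterproto.casing import sanitize_name

assert sanitize_name("lambda") == "lambda_"           # Python keyword -> trailing underscore
assert sanitize_name("0_prefixed") == "_0_prefixed"   # not an identifier -> leading underscore
assert sanitize_name("regular") == "regular"          # already safe -> unchanged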


@ -11,3 +11,11 @@ def pythonize_field_name(name: str) -> str:
def pythonize_method_name(name: str) -> str:
return casing.safe_snake_case(name)
def pythonize_enum_member_name(name: str, enum_name: str) -> str:
enum_name = casing.snake_case(enum_name).upper()
find = name.find(enum_name)
if find != -1:
name = name[find + len(enum_name) :].strip("_")
return casing.sanitize_name(name)
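This is what strips redundant prefixes from C-style enum members. Roughly, as a sketch:

from betterproto.compile.naming import pythonize_enum_member_name

# The enum's own name, upper-snake-cased, is stripped from each member name.
assert pythonize_enum_member_name("ARITHMETIC_OPERATOR_PLUS", "ArithmeticOperator") == "PLUS"

# If the stripped member is no longer a valid identifier, sanitize_name prefixes it.
assert (
    pythonize_enum_member_name("ARITHMETIC_OPERATOR_0_PREFIXED", "ArithmeticOperator")
    == "_0_PREFIXED"
)

# Members without the prefix are left alone.
assert pythonize_enum_member_name("NONE", "Msg") == "NONE"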

196
src/betterproto/enum.py Normal file

@ -0,0 +1,196 @@
from __future__ import annotations
import sys
from enum import (
EnumMeta,
IntEnum,
)
from types import MappingProxyType
from typing import (
TYPE_CHECKING,
Any,
Dict,
Optional,
Tuple,
)
if TYPE_CHECKING:
from collections.abc import (
Generator,
Mapping,
)
from typing_extensions import (
Never,
Self,
)
def _is_descriptor(obj: object) -> bool:
return (
hasattr(obj, "__get__") or hasattr(obj, "__set__") or hasattr(obj, "__delete__")
)
class EnumType(EnumMeta if TYPE_CHECKING else type):
_value_map_: Mapping[int, Enum]
_member_map_: Mapping[str, Enum]
def __new__(
mcs, name: str, bases: Tuple[type, ...], namespace: Dict[str, Any]
) -> Self:
value_map = {}
member_map = {}
new_mcs = type(
f"{name}Type",
tuple(
dict.fromkeys(
[base.__class__ for base in bases if base.__class__ is not type]
+ [EnumType, type]
)
), # reorder the bases so EnumType and type are last to avoid conflicts
{"_value_map_": value_map, "_member_map_": member_map},
)
members = {
name: value
for name, value in namespace.items()
if not _is_descriptor(value) and not name.startswith("__")
}
cls = type.__new__(
new_mcs,
name,
bases,
{key: value for key, value in namespace.items() if key not in members},
)
# this allows us to disallow member access from other members as
# members become proper class variables
for name, value in members.items():
member = value_map.get(value)
if member is None:
member = cls.__new__(cls, name=name, value=value) # type: ignore
value_map[value] = member
member_map[name] = member
type.__setattr__(new_mcs, name, member)
return cls
if not TYPE_CHECKING:
def __call__(cls, value: int) -> Enum:
try:
return cls._value_map_[value]
except (KeyError, TypeError):
raise ValueError(f"{value!r} is not a valid {cls.__name__}") from None
def __iter__(cls) -> Generator[Enum, None, None]:
yield from cls._member_map_.values()
if sys.version_info >= (3, 8): # 3.8 added __reversed__ to dict_values
def __reversed__(cls) -> Generator[Enum, None, None]:
yield from reversed(cls._member_map_.values())
else:
def __reversed__(cls) -> Generator[Enum, None, None]:
yield from reversed(tuple(cls._member_map_.values()))
def __getitem__(cls, key: str) -> Enum:
return cls._member_map_[key]
@property
def __members__(cls) -> MappingProxyType[str, Enum]:
return MappingProxyType(cls._member_map_)
def __repr__(cls) -> str:
return f"<enum {cls.__name__!r}>"
def __len__(cls) -> int:
return len(cls._member_map_)
def __setattr__(cls, name: str, value: Any) -> Never:
raise AttributeError(f"{cls.__name__}: cannot reassign Enum members.")
def __delattr__(cls, name: str) -> Never:
raise AttributeError(f"{cls.__name__}: cannot delete Enum members.")
def __contains__(cls, member: object) -> bool:
return isinstance(member, cls) and member.name in cls._member_map_
class Enum(IntEnum if TYPE_CHECKING else int, metaclass=EnumType):
"""
The base class for protobuf enumerations, all generated enumerations will
inherit from this. Emulates `enum.IntEnum`.
"""
name: Optional[str]
value: int
if not TYPE_CHECKING:
def __new__(cls, *, name: Optional[str], value: int) -> Self:
self = super().__new__(cls, value)
super().__setattr__(self, "name", name)
super().__setattr__(self, "value", value)
return self
def __str__(self) -> str:
return self.name or "None"
def __repr__(self) -> str:
return f"{self.__class__.__name__}.{self.name}"
def __setattr__(self, key: str, value: Any) -> Never:
raise AttributeError(
f"{self.__class__.__name__} Cannot reassign a member's attributes."
)
def __delattr__(self, item: Any) -> Never:
raise AttributeError(
f"{self.__class__.__name__} Cannot delete a member's attributes."
)
@classmethod
def try_value(cls, value: int = 0) -> Self:
"""Return the value which corresponds to the value.
Parameters
-----------
value: :class:`int`
The value of the enum member to get.
Returns
-------
:class:`Enum`
The corresponding member or a new instance of the enum if
``value`` isn't actually a member.
"""
try:
return cls._value_map_[value]
except (KeyError, TypeError):
return cls.__new__(cls, name=None, value=value)
@classmethod
def from_string(cls, name: str) -> Self:
"""Return the value which corresponds to the string name.
Parameters
-----------
name: :class:`str`
The name of the enum member to get.
Raises
-------
:exc:`ValueError`
The member was not found in the Enum.
"""
try:
return cls._member_map_[name]
except KeyError as e:
raise ValueError(f"Unknown value {name} for enum {cls.__name__}") from e
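A short sketch of how the replacement Enum behaves for unknown values versus unknown names (mirroring the tests added below):

import betterproto


class Colour(betterproto.Enum):
    RED = 1
    GREEN = 2
    BLUE = 3


assert Colour.from_string("GREEN") is Colour.GREEN

# try_value never raises: an unknown value yields a synthetic member with
# name=None, which is how unrecognised values coming off the wire are preserved.
mystery = Colour.try_value(42)
assert (mystery.name, mystery.value) == (None, 42)

# Unknown names still raise, as before.
try:
    Colour.from_string("MAUVE")
except ValueError:
    pass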


@ -127,6 +127,7 @@ class ServiceStub(ABC):
response_type,
**self.__resolve_request_kwargs(timeout, deadline, metadata),
) as stream:
await stream.send_request()
await self._send_messages(stream, request_iterator)
response = await stream.recv_message()
assert response is not None


@ -72,13 +72,13 @@ from betterproto.lib.google.protobuf import (
)
from betterproto.lib.google.protobuf.compiler import CodeGeneratorRequest
from ..casing import sanitize_name
from ..compile.importing import (
get_type_reference,
parse_source_type_name,
)
from ..compile.naming import (
pythonize_class_name,
pythonize_enum_member_name,
pythonize_field_name,
pythonize_method_name,
)
@ -385,7 +385,10 @@ def is_oneof(proto_field_obj: FieldDescriptorProto) -> bool:
us to tell whether it was set, via the which_one_of interface.
"""
return which_one_of(proto_field_obj, "oneof_index")[0] == "oneof_index"
return (
not proto_field_obj.proto3_optional
and which_one_of(proto_field_obj, "oneof_index")[0] == "oneof_index"
)
@dataclass
@ -670,7 +673,9 @@ class EnumDefinitionCompiler(MessageCompiler):
# Get entries/allowed values for this Enum
self.entries = [
self.EnumEntry(
name=sanitize_name(entry_proto_value.name),
name=pythonize_enum_member_name(
entry_proto_value.name, self.proto_obj.name
),
value=entry_proto_value.number,
comment=get_comment(
proto_file=self.source_file, path=self.path + [2, entry_number]

56
src/betterproto/utils.py Normal file

@ -0,0 +1,56 @@
from __future__ import annotations
from typing import (
Any,
Callable,
Generic,
Optional,
Type,
TypeVar,
)
from typing_extensions import (
Concatenate,
ParamSpec,
Self,
)
SelfT = TypeVar("SelfT")
P = ParamSpec("P")
HybridT = TypeVar("HybridT", covariant=True)
class hybridmethod(Generic[SelfT, P, HybridT]):
def __init__(
self,
func: Callable[
Concatenate[type[SelfT], P], HybridT
], # Must be the classmethod version
):
self.cls_func = func
self.__doc__ = func.__doc__
def instancemethod(self, func: Callable[Concatenate[SelfT, P], HybridT]) -> Self:
self.instance_func = func
return self
def __get__(
self, instance: Optional[SelfT], owner: Type[SelfT]
) -> Callable[P, HybridT]:
if instance is None or self.instance_func is None:
# either bound to the class, or no instance method available
return self.cls_func.__get__(owner, None)
return self.instance_func.__get__(instance, owner)
T_co = TypeVar("T_co")
TT_co = TypeVar("TT_co", bound="type[Any]")
class classproperty(Generic[TT_co, T_co]):
def __init__(self, func: Callable[[TT_co], T_co]):
self.__func__ = func
def __get__(self, instance: Any, type: TT_co) -> T_co:
return self.__func__(type)
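A small sketch of how these two descriptors behave; the `Greeter` class is illustrative only:

from betterproto.utils import classproperty, hybridmethod


class Greeter:
    default_name = "world"

    @hybridmethod
    def greet(cls, name: str = "") -> str:  # class-level implementation
        return f"hello {name or cls.default_name}"

    @greet.instancemethod
    def greet(self, name: str = "") -> str:  # instance-level implementation
        return f"hi {name or self.default_name}"

    @classproperty
    def banner(cls) -> str:
        return f"Greeter[{cls.__name__}]"


assert Greeter.greet() == "hello world"        # class access -> classmethod version
assert Greeter().greet("there") == "hi there"  # instance access -> instance version
assert Greeter.banner == "Greeter[Greeter]"    # classproperty evaluates against the class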


@ -272,3 +272,27 @@ async def test_async_gen_for_stream_stream_request():
assert response_index == len(
expected_things
), "Didn't receive all expected responses"
@pytest.mark.asyncio
async def test_stream_unary_with_empty_iterable():
things = [] # empty
async with ChannelFor([ThingService()]) as channel:
client = ThingServiceClient(channel)
requests = [DoThingRequest(name) for name in things]
response = await client.do_many_things(requests)
assert len(response.names) == 0
@pytest.mark.asyncio
async def test_stream_stream_with_empty_iterable():
things = [] # empty
async with ChannelFor([ThingService()]) as channel:
client = ThingServiceClient(channel)
requests = [GetThingRequest(name) for name in things]
responses = [
response async for response in client.get_different_things(requests)
]
assert len(responses) == 0


@ -27,7 +27,7 @@ class ThingService:
async def do_many_things(
self, stream: "grpclib.server.Stream[DoThingRequest, DoThingResponse]"
):
thing_names = [request.name for request in stream]
thing_names = [request.name async for request in stream]
if self.test_hook is not None:
self.test_hook(stream)
await stream.send_message(DoThingResponse(thing_names))


@ -15,3 +15,11 @@ enum Choice {
FOUR = 4;
THREE = 3;
}
// A "C" like enum with the enum name prefixed onto members, these should be stripped
enum ArithmeticOperator {
ARITHMETIC_OPERATOR_NONE = 0;
ARITHMETIC_OPERATOR_PLUS = 1;
ARITHMETIC_OPERATOR_MINUS = 2;
ARITHMETIC_OPERATOR_0_PREFIXED = 3;
}


@ -1,4 +1,5 @@
from tests.output_betterproto.enum import (
ArithmeticOperator,
Choice,
Test,
)
@ -82,3 +83,32 @@ def test_repeated_enum_with_non_list_iterables_to_dict():
yield Choice.THREE
assert Test(choices=enum_generator()).to_dict()["choices"] == ["ONE", "THREE"]
def test_enum_mapped_on_parse():
# test default value
b = Test().parse(bytes(Test()))
assert b.choice.name == Choice.ZERO.name
assert b.choices == []
# test non default value
a = Test().parse(bytes(Test(choice=Choice.ONE)))
assert a.choice.name == Choice.ONE.name
assert b.choices == []
# test repeated
c = Test().parse(bytes(Test(choices=[Choice.THREE, Choice.FOUR])))
assert c.choices[0].name == Choice.THREE.name
assert c.choices[1].name == Choice.FOUR.name
# bonus: defaults after empty init are also mapped
assert Test().choice.name == Choice.ZERO.name
def test_renamed_enum_members():
assert set(ArithmeticOperator.__members__) == {
"NONE",
"PLUS",
"MINUS",
"_0_PREFIXED",
}


@ -1,5 +1,6 @@
syntax = "proto3";
import "google/protobuf/timestamp.proto";
package google_impl_behavior_equivalence;
message Foo { int64 bar = 1; }
@ -12,6 +13,10 @@ message Test {
}
}
message Spam {
google.protobuf.Timestamp ts = 1;
}
message Request { Empty foo = 1; }
message Empty {}
message Empty {}


@ -1,17 +1,25 @@
from datetime import (
datetime,
timezone,
)
import pytest
from google.protobuf import json_format
from google.protobuf.timestamp_pb2 import Timestamp
import betterproto
from tests.output_betterproto.google_impl_behavior_equivalence import (
Empty,
Foo,
Request,
Spam,
Test,
)
from tests.output_reference.google_impl_behavior_equivalence.google_impl_behavior_equivalence_pb2 import (
Empty as ReferenceEmpty,
Foo as ReferenceFoo,
Request as ReferenceRequest,
Spam as ReferenceSpam,
Test as ReferenceTest,
)
@ -59,6 +67,19 @@ def test_bytes_are_the_same_for_oneof():
assert isinstance(message_reference2.foo, ReferenceFoo)
@pytest.mark.parametrize("dt", (datetime.min.replace(tzinfo=timezone.utc),))
def test_datetime_clamping(dt): # see #407
ts = Timestamp()
ts.FromDatetime(dt)
assert bytes(Spam(dt)) == ReferenceSpam(ts=ts).SerializeToString()
message_bytes = bytes(Spam(dt))
assert (
Spam().parse(message_bytes).ts.timestamp()
== ReferenceSpam.FromString(message_bytes).ts.seconds
)
def test_empty_message_field():
message = Request()
reference_message = ReferenceRequest()


@ -2,6 +2,10 @@ syntax = "proto3";
package oneof;
message MixedDrink {
int32 shots = 1;
}
message Test {
oneof foo {
int32 pitied = 1;
@ -13,6 +17,7 @@ message Test {
oneof bar {
int32 drinks = 11;
string bar_name = 12;
MixedDrink mixed_drink = 13;
}
}


@ -1,5 +1,10 @@
import pytest
import betterproto
from tests.output_betterproto.oneof import Test
from tests.output_betterproto.oneof import (
MixedDrink,
Test,
)
from tests.output_betterproto_pydantic.oneof import Test as TestPyd
from tests.util import get_test_case_json_data
@ -19,3 +24,20 @@ def test_which_name():
def test_which_count_pyd():
message = TestPyd(pitier="Mr. T", just_a_regular_field=2, bar_name="a_bar")
assert betterproto.which_one_of(message, "foo") == ("pitier", "Mr. T")
def test_oneof_constructor_assign():
message = Test(mixed_drink=MixedDrink(shots=42))
field, value = betterproto.which_one_of(message, "bar")
assert field == "mixed_drink"
assert value.shots == 42
# Issue #305:
@pytest.mark.xfail
def test_oneof_nested_assign():
message = Test()
message.mixed_drink.shots = 42
field, value = betterproto.which_one_of(message, "bar")
assert field == "mixed_drink"
assert value.shots == 42


@ -41,3 +41,8 @@ def test_null_fields_json():
"test8": None,
"test9": None,
}
def test_unset_access(): # see #523
assert Test().test1 is None
assert Test(test1=None).test1 is None


@ -0,0 +1,2 @@
•šï:bTesting•šï:bTesting
 

38
tests/streams/java/.gitignore vendored Normal file

@ -0,0 +1,38 @@
### Output ###
target/
!.mvn/wrapper/maven-wrapper.jar
!**/src/main/**/target/
!**/src/test/**/target/
dependency-reduced-pom.xml
MANIFEST.MF
### IntelliJ IDEA ###
.idea/
*.iws
*.iml
*.ipr
### Eclipse ###
.apt_generated
.classpath
.factorypath
.project
.settings
.springBeans
.sts4-cache
### NetBeans ###
/nbproject/private/
/nbbuild/
/dist/
/nbdist/
/.nb-gradle/
build/
!**/src/main/**/build/
!**/src/test/**/build/
### VS Code ###
.vscode/
### Mac OS ###
.DS_Store


@ -0,0 +1,94 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>betterproto</groupId>
<artifactId>compatibility-test</artifactId>
<version>1.0-SNAPSHOT</version>
<packaging>jar</packaging>
<properties>
<maven.compiler.source>11</maven.compiler.source>
<maven.compiler.target>11</maven.compiler.target>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<protobuf.version>3.23.4</protobuf.version>
</properties>
<dependencies>
<dependency>
<groupId>com.google.protobuf</groupId>
<artifactId>protobuf-java</artifactId>
<version>${protobuf.version}</version>
</dependency>
</dependencies>
<build>
<extensions>
<extension>
<groupId>kr.motd.maven</groupId>
<artifactId>os-maven-plugin</artifactId>
<version>1.7.1</version>
</extension>
</extensions>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-shade-plugin</artifactId>
<version>3.5.0</version>
<executions>
<execution>
<phase>package</phase>
<goals>
<goal>shade</goal>
</goals>
<configuration>
<transformers>
<transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
<mainClass>betterproto.CompatibilityTest</mainClass>
</transformer>
</transformers>
</configuration>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-jar-plugin</artifactId>
<version>3.3.0</version>
<configuration>
<archive>
<manifest>
<addClasspath>true</addClasspath>
<mainClass>betterproto.CompatibilityTest</mainClass>
</manifest>
</archive>
</configuration>
</plugin>
<plugin>
<groupId>org.xolstice.maven.plugins</groupId>
<artifactId>protobuf-maven-plugin</artifactId>
<version>0.6.1</version>
<executions>
<execution>
<goals>
<goal>compile</goal>
</goals>
</execution>
</executions>
<configuration>
<protocArtifact>
com.google.protobuf:protoc:${protobuf.version}:exe:${os.detected.classifier}
</protocArtifact>
</configuration>
</plugin>
</plugins>
<finalName>${project.artifactId}</finalName>
</build>
</project>


@ -0,0 +1,41 @@
package betterproto;
import java.io.IOException;
public class CompatibilityTest {
public static void main(String[] args) throws IOException {
if (args.length < 2)
throw new RuntimeException("Attempted to run without the required arguments.");
else if (args.length > 2)
throw new RuntimeException(
"Attempted to run with more than the expected number of arguments (>1).");
Tests tests = new Tests(args[1]);
switch (args[0]) {
case "single_varint":
tests.testSingleVarint();
break;
case "multiple_varints":
tests.testMultipleVarints();
break;
case "single_message":
tests.testSingleMessage();
break;
case "multiple_messages":
tests.testMultipleMessages();
break;
case "infinite_messages":
tests.testInfiniteMessages();
break;
default:
throw new RuntimeException(
"Attempted to run with unknown argument '" + args[0] + "'.");
}
}
}


@ -0,0 +1,115 @@
package betterproto;
import betterproto.nested.NestedOuterClass;
import betterproto.oneof.Oneof;
import com.google.protobuf.CodedInputStream;
import com.google.protobuf.CodedOutputStream;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
public class Tests {
String path;
public Tests(String path) {
this.path = path;
}
public void testSingleVarint() throws IOException {
// Read in the Python-generated single varint file
FileInputStream inputStream = new FileInputStream(path + "/py_single_varint.out");
CodedInputStream codedInput = CodedInputStream.newInstance(inputStream);
int value = codedInput.readUInt32();
inputStream.close();
// Write the value back to a file
FileOutputStream outputStream = new FileOutputStream(path + "/java_single_varint.out");
CodedOutputStream codedOutput = CodedOutputStream.newInstance(outputStream);
codedOutput.writeUInt32NoTag(value);
codedOutput.flush();
outputStream.close();
}
public void testMultipleVarints() throws IOException {
// Read in the Python-generated multiple varints file
FileInputStream inputStream = new FileInputStream(path + "/py_multiple_varints.out");
CodedInputStream codedInput = CodedInputStream.newInstance(inputStream);
int value1 = codedInput.readUInt32();
int value2 = codedInput.readUInt32();
long value3 = codedInput.readUInt64();
inputStream.close();
// Write the values back to a file
FileOutputStream outputStream = new FileOutputStream(path + "/java_multiple_varints.out");
CodedOutputStream codedOutput = CodedOutputStream.newInstance(outputStream);
codedOutput.writeUInt32NoTag(value1);
codedOutput.writeUInt64NoTag(value2);
codedOutput.writeUInt64NoTag(value3);
codedOutput.flush();
outputStream.close();
}
public void testSingleMessage() throws IOException {
// Read in the Python-generated single message file
FileInputStream inputStream = new FileInputStream(path + "/py_single_message.out");
CodedInputStream codedInput = CodedInputStream.newInstance(inputStream);
Oneof.Test message = Oneof.Test.parseFrom(codedInput);
inputStream.close();
// Write the message back to a file
FileOutputStream outputStream = new FileOutputStream(path + "/java_single_message.out");
CodedOutputStream codedOutput = CodedOutputStream.newInstance(outputStream);
message.writeTo(codedOutput);
codedOutput.flush();
outputStream.close();
}
public void testMultipleMessages() throws IOException {
// Read in the Python-generated multi-message file
FileInputStream inputStream = new FileInputStream(path + "/py_multiple_messages.out");
Oneof.Test oneof = Oneof.Test.parseDelimitedFrom(inputStream);
NestedOuterClass.Test nested = NestedOuterClass.Test.parseDelimitedFrom(inputStream);
inputStream.close();
// Write the messages back to a file
FileOutputStream outputStream = new FileOutputStream(path + "/java_multiple_messages.out");
oneof.writeDelimitedTo(outputStream);
nested.writeDelimitedTo(outputStream);
outputStream.flush();
outputStream.close();
}
public void testInfiniteMessages() throws IOException {
// Read in as many messages as are present in the Python-generated file and write them back
FileInputStream inputStream = new FileInputStream(path + "/py_infinite_messages.out");
FileOutputStream outputStream = new FileOutputStream(path + "/java_infinite_messages.out");
Oneof.Test current = Oneof.Test.parseDelimitedFrom(inputStream);
while (current != null) {
current.writeDelimitedTo(outputStream);
current = Oneof.Test.parseDelimitedFrom(inputStream);
}
inputStream.close();
outputStream.flush();
outputStream.close();
}
}


@ -0,0 +1,27 @@
syntax = "proto3";
package nested;
option java_package = "betterproto.nested";
// A test message with a nested message inside of it.
message Test {
// This is the nested type.
message Nested {
// Stores a simple counter.
int32 count = 1;
}
// This is the nested enum.
enum Msg {
NONE = 0;
THIS = 1;
}
Nested nested = 1;
Sibling sibling = 2;
Sibling sibling2 = 3;
Msg msg = 4;
}
message Sibling {
int32 foo = 1;
}


@ -0,0 +1,19 @@
syntax = "proto3";
package oneof;
option java_package = "betterproto.oneof";
message Test {
oneof foo {
int32 pitied = 1;
string pitier = 2;
}
int32 just_a_regular_field = 3;
oneof bar {
int32 drinks = 11;
string bar_name = 12;
}
}

79
tests/test_enum.py Normal file

@ -0,0 +1,79 @@
from typing import (
Optional,
Tuple,
)
import pytest
import betterproto
class Colour(betterproto.Enum):
RED = 1
GREEN = 2
BLUE = 3
PURPLE = Colour.__new__(Colour, name=None, value=4)
@pytest.mark.parametrize(
"member, str_value",
[
(Colour.RED, "RED"),
(Colour.GREEN, "GREEN"),
(Colour.BLUE, "BLUE"),
],
)
def test_str(member: Colour, str_value: str) -> None:
assert str(member) == str_value
@pytest.mark.parametrize(
"member, repr_value",
[
(Colour.RED, "Colour.RED"),
(Colour.GREEN, "Colour.GREEN"),
(Colour.BLUE, "Colour.BLUE"),
],
)
def test_repr(member: Colour, repr_value: str) -> None:
assert repr(member) == repr_value
@pytest.mark.parametrize(
"member, values",
[
(Colour.RED, ("RED", 1)),
(Colour.GREEN, ("GREEN", 2)),
(Colour.BLUE, ("BLUE", 3)),
(PURPLE, (None, 4)),
],
)
def test_name_values(member: Colour, values: Tuple[Optional[str], int]) -> None:
assert (member.name, member.value) == values
@pytest.mark.parametrize(
"member, input_str",
[
(Colour.RED, "RED"),
(Colour.GREEN, "GREEN"),
(Colour.BLUE, "BLUE"),
],
)
def test_from_string(member: Colour, input_str: str) -> None:
assert Colour.from_string(input_str) == member
@pytest.mark.parametrize(
"member, input_int",
[
(Colour.RED, 1),
(Colour.GREEN, 2),
(Colour.BLUE, 3),
(PURPLE, 4),
],
)
def test_try_value(member: Colour, input_int: int) -> None:
assert Colour.try_value(input_int) == member


@ -545,47 +545,6 @@ def test_oneof_default_value_set_causes_writes_wire():
)
def test_recursive_message():
from tests.output_betterproto.recursivemessage import Test as RecursiveMessage
msg = RecursiveMessage()
assert msg.child == RecursiveMessage()
# Lazily-created zero-value children must not affect equality.
assert msg == RecursiveMessage()
# Lazily-created zero-value children must not affect serialization.
assert bytes(msg) == b""
def test_recursive_message_defaults():
from tests.output_betterproto.recursivemessage import (
Intermediate,
Test as RecursiveMessage,
)
msg = RecursiveMessage(name="bob", intermediate=Intermediate(42))
# set values are as expected
assert msg == RecursiveMessage(name="bob", intermediate=Intermediate(42))
# lazy initialization modifies the message
assert msg != RecursiveMessage(
name="bob", intermediate=Intermediate(42), child=RecursiveMessage(name="jude")
)
msg.child.child.name = "jude"
assert msg == RecursiveMessage(
name="bob",
intermediate=Intermediate(42),
child=RecursiveMessage(child=RecursiveMessage(name="jude")),
)
# lazy initialization recurses as needed
assert msg.child.child.child.child.child.child.child == RecursiveMessage()
assert msg.intermediate.child.intermediate == Intermediate()
def test_message_repr():
from tests.output_betterproto.recursivemessage import Test
@ -699,25 +658,6 @@ def test_service_argument__expected_parameter():
assert do_thing_request_parameter.annotation == "DoThingRequest"
def test_copyability():
@dataclass
class Spam(betterproto.Message):
foo: bool = betterproto.bool_field(1)
bar: int = betterproto.int32_field(2)
baz: List[str] = betterproto.string_field(3)
spam = Spam(bar=12, baz=["hello"])
copied = copy(spam)
assert spam == copied
assert spam is not copied
assert spam.baz is copied.baz
deepcopied = deepcopy(spam)
assert spam == deepcopied
assert spam is not deepcopied
assert spam.baz is not deepcopied.baz
def test_is_set():
@dataclass
class Spam(betterproto.Message):

203
tests/test_pickling.py Normal file

@ -0,0 +1,203 @@
import pickle
from copy import (
copy,
deepcopy,
)
from dataclasses import dataclass
from typing import (
Dict,
List,
)
from unittest.mock import ANY
import cachelib
import betterproto
from betterproto.lib.google import protobuf as google
def unpickled(message):
return pickle.loads(pickle.dumps(message))
@dataclass(eq=False, repr=False)
class Fe(betterproto.Message):
abc: str = betterproto.string_field(1)
@dataclass(eq=False, repr=False)
class Fi(betterproto.Message):
abc: str = betterproto.string_field(1)
@dataclass(eq=False, repr=False)
class Fo(betterproto.Message):
abc: str = betterproto.string_field(1)
@dataclass(eq=False, repr=False)
class NestedData(betterproto.Message):
struct_foo: Dict[str, "google.Struct"] = betterproto.map_field(
1, betterproto.TYPE_STRING, betterproto.TYPE_MESSAGE
)
map_str_any_bar: Dict[str, "google.Any"] = betterproto.map_field(
2, betterproto.TYPE_STRING, betterproto.TYPE_MESSAGE
)
@dataclass(eq=False, repr=False)
class Complex(betterproto.Message):
foo_str: str = betterproto.string_field(1)
fe: "Fe" = betterproto.message_field(3, group="grp")
fi: "Fi" = betterproto.message_field(4, group="grp")
fo: "Fo" = betterproto.message_field(5, group="grp")
nested_data: "NestedData" = betterproto.message_field(6)
mapping: Dict[str, "google.Any"] = betterproto.map_field(
7, betterproto.TYPE_STRING, betterproto.TYPE_MESSAGE
)
def complex_msg():
return Complex(
foo_str="yep",
fe=Fe(abc="1"),
nested_data=NestedData(
struct_foo={
"foo": google.Struct(
fields={
"hello": google.Value(
list_value=google.ListValue(
values=[google.Value(string_value="world")]
)
)
}
),
},
map_str_any_bar={
"key": google.Any(value=b"value"),
},
),
mapping={
"message": google.Any(value=bytes(Fi(abc="hi"))),
"string": google.Any(value=b"howdy"),
},
)
def test_pickling_complex_message():
msg = complex_msg()
deser = unpickled(msg)
assert msg == deser
assert msg.fe.abc == "1"
assert msg.is_set("fi") is not True
assert msg.mapping["message"] == google.Any(value=bytes(Fi(abc="hi")))
assert msg.mapping["string"].value.decode() == "howdy"
assert (
msg.nested_data.struct_foo["foo"]
.fields["hello"]
.list_value.values[0]
.string_value
== "world"
)
def test_recursive_message():
from tests.output_betterproto.recursivemessage import Test as RecursiveMessage
msg = RecursiveMessage()
msg = unpickled(msg)
assert msg.child == RecursiveMessage()
# Lazily-created zero-value children must not affect equality.
assert msg == RecursiveMessage()
# Lazily-created zero-value children must not affect serialization.
assert bytes(msg) == b""
def test_recursive_message_defaults():
from tests.output_betterproto.recursivemessage import (
Intermediate,
Test as RecursiveMessage,
)
msg = RecursiveMessage(name="bob", intermediate=Intermediate(42))
msg = unpickled(msg)
# set values are as expected
assert msg == RecursiveMessage(name="bob", intermediate=Intermediate(42))
# lazy initialization modifies the message
assert msg != RecursiveMessage(
name="bob", intermediate=Intermediate(42), child=RecursiveMessage(name="jude")
)
msg.child.child.name = "jude"
assert msg == RecursiveMessage(
name="bob",
intermediate=Intermediate(42),
child=RecursiveMessage(child=RecursiveMessage(name="jude")),
)
# lazy initialization recurses as needed
assert msg.child.child.child.child.child.child.child == RecursiveMessage()
assert msg.intermediate.child.intermediate == Intermediate()
@dataclass
class PickledMessage(betterproto.Message):
foo: bool = betterproto.bool_field(1)
bar: int = betterproto.int32_field(2)
baz: List[str] = betterproto.string_field(3)
def test_copyability():
msg = PickledMessage(bar=12, baz=["hello"])
msg = unpickled(msg)
copied = copy(msg)
assert msg == copied
assert msg is not copied
assert msg.baz is copied.baz
deepcopied = deepcopy(msg)
assert msg == deepcopied
assert msg is not deepcopied
assert msg.baz is not deepcopied.baz
def test_message_can_be_cached():
"""Cachelib uses pickling to cache values"""
cache = cachelib.SimpleCache()
def use_cache():
calls = getattr(use_cache, "calls", 0)
result = cache.get("message")
if result is not None:
return result
else:
setattr(use_cache, "calls", calls + 1)
result = complex_msg()
cache.set("message", result)
return result
for n in range(10):
if n == 0:
assert not cache.has("message")
else:
assert cache.has("message")
msg = use_cache()
assert use_cache.calls == 1 # The message is only ever built once
assert msg.fe.abc == "1"
assert msg.is_set("fi") is not True
assert msg.mapping["message"] == google.Any(value=bytes(Fi(abc="hi")))
assert msg.mapping["string"].value.decode() == "howdy"
assert (
msg.nested_data.struct_foo["foo"]
.fields["hello"]
.list_value.values[0]
.string_value
== "world"
)


@ -1,6 +1,8 @@
from dataclasses import dataclass
from io import BytesIO
from pathlib import Path
from shutil import which
from subprocess import run
from typing import Optional
import pytest
@ -40,6 +42,8 @@ map_example = map.Test().from_dict({"counts": {"blah": 1, "Blah2": 2}})
streams_path = Path("tests/streams/")
java = which("java")
def test_load_varint_too_long():
with BytesIO(
@ -127,6 +131,18 @@ def test_message_dump_file_multiple(tmp_path):
assert test_stream.read() == exp_stream.read()
def test_message_dump_delimited(tmp_path):
with open(tmp_path / "message_dump_delimited.out", "wb") as stream:
oneof_example.dump(stream, betterproto.SIZE_DELIMITED)
oneof_example.dump(stream, betterproto.SIZE_DELIMITED)
nested_example.dump(stream, betterproto.SIZE_DELIMITED)
with open(tmp_path / "message_dump_delimited.out", "rb") as test_stream, open(
streams_path / "delimited_messages.in", "rb"
) as exp_stream:
assert test_stream.read() == exp_stream.read()
def test_message_len():
assert len_oneof == len(bytes(oneof_example))
assert len(nested_example) == len(bytes(nested_example))
@ -155,7 +171,15 @@ def test_message_load_too_small():
oneof.Test().load(stream, len_oneof - 1)
def test_message_too_large():
def test_message_load_delimited():
with open(streams_path / "delimited_messages.in", "rb") as stream:
assert oneof.Test().load(stream, betterproto.SIZE_DELIMITED) == oneof_example
assert oneof.Test().load(stream, betterproto.SIZE_DELIMITED) == oneof_example
assert nested.Test().load(stream, betterproto.SIZE_DELIMITED) == nested_example
assert stream.read(1) == b""
def test_message_load_too_large():
with open(
streams_path / "message_dump_file_single.expected", "rb"
) as stream, pytest.raises(ValueError):
@ -266,3 +290,145 @@ def test_dump_varint_positive(tmp_path):
streams_path / "dump_varint_positive.expected", "rb"
) as exp_stream:
assert test_stream.read() == exp_stream.read()
# Java compatibility tests
@pytest.fixture(scope="module")
def compile_jar():
# Skip if not all required tools are present
if java is None:
pytest.skip("`java` command is absent and is required")
mvn = which("mvn")
if mvn is None:
pytest.skip("Maven is absent and is required")
# Compile the JAR
proc_maven = run([mvn, "clean", "install", "-f", "tests/streams/java/pom.xml"])
if proc_maven.returncode != 0:
pytest.skip(
"Maven compatibility-test.jar build failed (maybe Java version <11?)"
)
jar = "tests/streams/java/target/compatibility-test.jar"
def run_jar(command: str, tmp_path):
return run([java, "-jar", jar, command, tmp_path], check=True)
def run_java_single_varint(value: int, tmp_path) -> int:
# Write single varint to file
with open(tmp_path / "py_single_varint.out", "wb") as stream:
betterproto.dump_varint(value, stream)
# Have Java read this varint and write it back
run_jar("single_varint", tmp_path)
# Read single varint from Java output file
with open(tmp_path / "java_single_varint.out", "rb") as stream:
returned = betterproto.load_varint(stream)
with pytest.raises(EOFError):
betterproto.load_varint(stream)
return returned
def test_single_varint(compile_jar, tmp_path):
single_byte = (1, b"\x01")
multi_byte = (123456789, b"\x95\x9A\xEF\x3A")
# Write a single-byte varint to a file and have Java read it back
returned = run_java_single_varint(single_byte[0], tmp_path)
assert returned == single_byte
# Same for a multi-byte varint
returned = run_java_single_varint(multi_byte[0], tmp_path)
assert returned == multi_byte
def test_multiple_varints(compile_jar, tmp_path):
single_byte = (1, b"\x01")
multi_byte = (123456789, b"\x95\x9A\xEF\x3A")
over32 = (3000000000, b"\x80\xBC\xC1\x96\x0B")
# Write two varints to the same file
with open(tmp_path / "py_multiple_varints.out", "wb") as stream:
betterproto.dump_varint(single_byte[0], stream)
betterproto.dump_varint(multi_byte[0], stream)
betterproto.dump_varint(over32[0], stream)
# Have Java read these varints and write them back
run_jar("multiple_varints", tmp_path)
# Read varints from Java output file
with open(tmp_path / "java_multiple_varints.out", "rb") as stream:
returned_single = betterproto.load_varint(stream)
returned_multi = betterproto.load_varint(stream)
returned_over32 = betterproto.load_varint(stream)
with pytest.raises(EOFError):
betterproto.load_varint(stream)
assert returned_single == single_byte
assert returned_multi == multi_byte
assert returned_over32 == over32
def test_single_message(compile_jar, tmp_path):
# Write message to file
with open(tmp_path / "py_single_message.out", "wb") as stream:
oneof_example.dump(stream)
# Have Java read and return the message
run_jar("single_message", tmp_path)
# Read and check the returned message
with open(tmp_path / "java_single_message.out", "rb") as stream:
returned = oneof.Test().load(stream, len(bytes(oneof_example)))
assert stream.read() == b""
assert returned == oneof_example
def test_multiple_messages(compile_jar, tmp_path):
# Write delimited messages to file
with open(tmp_path / "py_multiple_messages.out", "wb") as stream:
oneof_example.dump(stream, betterproto.SIZE_DELIMITED)
nested_example.dump(stream, betterproto.SIZE_DELIMITED)
# Have Java read and return the messages
run_jar("multiple_messages", tmp_path)
# Read and check the returned messages
with open(tmp_path / "java_multiple_messages.out", "rb") as stream:
returned_oneof = oneof.Test().load(stream, betterproto.SIZE_DELIMITED)
returned_nested = nested.Test().load(stream, betterproto.SIZE_DELIMITED)
assert stream.read() == b""
assert returned_oneof == oneof_example
assert returned_nested == nested_example
def test_infinite_messages(compile_jar, tmp_path):
num_messages = 5
# Write delimited messages to file
with open(tmp_path / "py_infinite_messages.out", "wb") as stream:
for x in range(num_messages):
oneof_example.dump(stream, betterproto.SIZE_DELIMITED)
# Have Java read and return the messages
run_jar("infinite_messages", tmp_path)
# Read and check the returned messages
messages = []
with open(tmp_path / "java_infinite_messages.out", "rb") as stream:
while True:
try:
messages.append(oneof.Test().load(stream, betterproto.SIZE_DELIMITED))
except EOFError:
break
assert len(messages) == num_messages

27
tests/test_timestamp.py Normal file

@ -0,0 +1,27 @@
from datetime import (
datetime,
timezone,
)
import pytest
from betterproto import _Timestamp
@pytest.mark.parametrize(
"dt",
[
datetime(2023, 10, 11, 9, 41, 12, tzinfo=timezone.utc),
datetime.now(timezone.utc),
# potential issue with floating point precision:
datetime(2242, 12, 31, 23, 0, 0, 1, tzinfo=timezone.utc),
# potential issue with negative timestamps:
datetime(1969, 12, 31, 23, 0, 0, 1, tzinfo=timezone.utc),
],
)
def test_timestamp_to_datetime_and_back(dt: datetime):
"""
Make sure converting a datetime to a protobuf timestamp message
and then back again ends up with the same datetime.
"""
assert _Timestamp.from_datetime(dt).to_datetime() == dt