Merge branch 'master_gh'
# Conflicts: # src/betterproto/__init__.py
This commit is contained in:
commit
1d296f1a88
63
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
Normal file
63
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
name: Bug Report
|
||||||
|
description: Report broken or incorrect behaviour
|
||||||
|
labels: ["bug", "investigation needed"]
|
||||||
|
|
||||||
|
body:
|
||||||
|
- type: markdown
|
||||||
|
attributes:
|
||||||
|
value: >
|
||||||
|
Thanks for taking the time to fill out a bug report!
|
||||||
|
|
||||||
|
If you're not sure it's a bug and you just have a question, the [community Slack channel](https://join.slack.com/t/betterproto/shared_invite/zt-f0n0uolx-iN8gBNrkPxtKHTLpG3o1OQ) is a better place for general questions than a GitHub issue.
|
||||||
|
|
||||||
|
- type: input
|
||||||
|
attributes:
|
||||||
|
label: Summary
|
||||||
|
description: A simple summary of your bug report
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
attributes:
|
||||||
|
label: Reproduction Steps
|
||||||
|
description: >
|
||||||
|
What you did to make it happen.
|
||||||
|
Ideally there should be a short code snippet in this section to help reproduce the bug.
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
attributes:
|
||||||
|
label: Expected Results
|
||||||
|
description: >
|
||||||
|
What did you expect to happen?
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
attributes:
|
||||||
|
label: Actual Results
|
||||||
|
description: >
|
||||||
|
What actually happened?
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
attributes:
|
||||||
|
label: System Information
|
||||||
|
description: >
|
||||||
|
Paste the result of `protoc --version; python --version; pip show betterproto` below.
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: checkboxes
|
||||||
|
attributes:
|
||||||
|
label: Checklist
|
||||||
|
options:
|
||||||
|
- label: I have searched the issues for duplicates.
|
||||||
|
required: true
|
||||||
|
- label: I have shown the entire traceback, if possible.
|
||||||
|
required: true
|
||||||
|
- label: I have verified this issue occurs on the latest prelease of betterproto which can be installed using `pip install -U --pre betterproto`, if possible.
|
||||||
|
required: true
|
||||||
|
|
6
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
6
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
name:
|
||||||
|
description:
|
||||||
|
contact_links:
|
||||||
|
- name: For questions about the library
|
||||||
|
about: Support questions are better answered in our Slack group.
|
||||||
|
url: https://join.slack.com/t/betterproto/shared_invite/zt-f0n0uolx-iN8gBNrkPxtKHTLpG3o1OQ
|
49
.github/ISSUE_TEMPLATE/feature_request.yml
vendored
Normal file
49
.github/ISSUE_TEMPLATE/feature_request.yml
vendored
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
name: Feature Request
|
||||||
|
description: Suggest a feature for this library
|
||||||
|
labels: ["enhancement"]
|
||||||
|
|
||||||
|
body:
|
||||||
|
- type: input
|
||||||
|
attributes:
|
||||||
|
label: Summary
|
||||||
|
description: >
|
||||||
|
What problem is your feature trying to solve? What would become easier or possible if feature was implemented?
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: dropdown
|
||||||
|
attributes:
|
||||||
|
multiple: false
|
||||||
|
label: What is the feature request for?
|
||||||
|
options:
|
||||||
|
- The core library
|
||||||
|
- RPC handling
|
||||||
|
- The documentation
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
attributes:
|
||||||
|
label: The Problem
|
||||||
|
description: >
|
||||||
|
What problem is your feature trying to solve?
|
||||||
|
What would become easier or possible if feature was implemented?
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
attributes:
|
||||||
|
label: The Ideal Solution
|
||||||
|
description: >
|
||||||
|
What is your ideal solution to the problem?
|
||||||
|
What would you like this feature to do?
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
attributes:
|
||||||
|
label: The Current Solution
|
||||||
|
description: >
|
||||||
|
What is the current solution to the problem, if any?
|
||||||
|
validations:
|
||||||
|
required: false
|
16
.github/PULL_REQUEST_TEMPLATE.md
vendored
Normal file
16
.github/PULL_REQUEST_TEMPLATE.md
vendored
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
## Summary
|
||||||
|
|
||||||
|
<!-- What is this pull request for? Does it fix any issues? -->
|
||||||
|
|
||||||
|
## Checklist
|
||||||
|
|
||||||
|
<!-- Put an x inside [ ] to check it, like so: [x] -->
|
||||||
|
|
||||||
|
- [ ] If code changes were made then they have been tested.
|
||||||
|
- [ ] I have updated the documentation to reflect the changes.
|
||||||
|
- [ ] This PR fixes an issue.
|
||||||
|
- [ ] This PR adds something new (e.g. new method or parameters).
|
||||||
|
- [ ] This change has an associated test.
|
||||||
|
- [ ] This PR is a breaking change (e.g. methods or parameters removed/renamed)
|
||||||
|
- [ ] This PR is **not** a code change (e.g. documentation, README, ...)
|
||||||
|
|
46
.github/workflows/codeql-analysis.yml
vendored
Normal file
46
.github/workflows/codeql-analysis.yml
vendored
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
name: "CodeQL"
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [ "master" ]
|
||||||
|
pull_request:
|
||||||
|
branches:
|
||||||
|
- '**'
|
||||||
|
schedule:
|
||||||
|
- cron: '19 1 * * 6'
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
analyze:
|
||||||
|
name: Analyze
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
permissions:
|
||||||
|
actions: read
|
||||||
|
contents: read
|
||||||
|
security-events: write
|
||||||
|
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
language: [ 'python' ]
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Checkout repository
|
||||||
|
uses: actions/checkout@v3
|
||||||
|
|
||||||
|
# Initializes the CodeQL tools for scanning.
|
||||||
|
- name: Initialize CodeQL
|
||||||
|
uses: github/codeql-action/init@v2
|
||||||
|
with:
|
||||||
|
languages: ${{ matrix.language }}
|
||||||
|
# If you wish to specify custom queries, you can do so here or in a config file.
|
||||||
|
# By default, queries listed here will override any specified in a config file.
|
||||||
|
# Prefix the list here with "+" to use these queries and those in the config file.
|
||||||
|
|
||||||
|
# Details on CodeQL's query packs refer to : https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs
|
||||||
|
# queries: security-extended,security-and-quality
|
||||||
|
|
||||||
|
- name: Autobuild
|
||||||
|
uses: github/codeql-action/autobuild@v2
|
||||||
|
|
||||||
|
- name: Perform CodeQL Analysis
|
||||||
|
uses: github/codeql-action/analyze@v2
|
@ -16,6 +16,13 @@ repos:
|
|||||||
- repo: https://github.com/PyCQA/doc8
|
- repo: https://github.com/PyCQA/doc8
|
||||||
rev: 0.10.1
|
rev: 0.10.1
|
||||||
hooks:
|
hooks:
|
||||||
- id: doc8
|
- id: doc8
|
||||||
additional_dependencies:
|
additional_dependencies:
|
||||||
- toml
|
- toml
|
||||||
|
|
||||||
|
- repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks
|
||||||
|
rev: v2.10.0
|
||||||
|
hooks:
|
||||||
|
- id: pretty-format-java
|
||||||
|
args: [--autofix, --aosp]
|
||||||
|
files: ^.*\.java$
|
||||||
|
1200
poetry.lock
generated
1200
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@ -19,26 +19,29 @@ importlib-metadata = { version = ">=1.6.0", python = "<3.8" }
|
|||||||
jinja2 = { version = ">=3.0.3", optional = true }
|
jinja2 = { version = ">=3.0.3", optional = true }
|
||||||
python-dateutil = "^2.8"
|
python-dateutil = "^2.8"
|
||||||
isort = {version = "^5.11.5", optional = true}
|
isort = {version = "^5.11.5", optional = true}
|
||||||
|
typing-extensions = "^4.7.1"
|
||||||
|
|
||||||
[tool.poetry.dev-dependencies]
|
[tool.poetry.group.dev.dependencies]
|
||||||
asv = "^0.4.2"
|
asv = "^0.4.2"
|
||||||
bpython = "^0.19"
|
bpython = "^0.19"
|
||||||
grpcio-tools = "^1.54.2"
|
|
||||||
jinja2 = ">=3.0.3"
|
jinja2 = ">=3.0.3"
|
||||||
mypy = "^0.930"
|
mypy = "^0.930"
|
||||||
|
sphinx = "3.1.2"
|
||||||
|
sphinx-rtd-theme = "0.5.0"
|
||||||
|
pre-commit = "^2.17.0"
|
||||||
|
grpcio-tools = "^1.54.2"
|
||||||
|
tox = "^4.0.0"
|
||||||
|
|
||||||
|
[tool.poetry.group.test.dependencies]
|
||||||
poethepoet = ">=0.9.0"
|
poethepoet = ">=0.9.0"
|
||||||
protobuf = "^4.21.6"
|
|
||||||
pytest = "^6.2.5"
|
pytest = "^6.2.5"
|
||||||
pytest-asyncio = "^0.12.0"
|
pytest-asyncio = "^0.12.0"
|
||||||
pytest-cov = "^2.9.0"
|
pytest-cov = "^2.9.0"
|
||||||
pytest-mock = "^3.1.1"
|
pytest-mock = "^3.1.1"
|
||||||
sphinx = "3.1.2"
|
pydantic = ">=1.8.0,<2"
|
||||||
sphinx-rtd-theme = "0.5.0"
|
protobuf = "^4"
|
||||||
tomlkit = "^0.7.0"
|
cachelib = "^0.10.2"
|
||||||
tox = "^3.15.1"
|
tomlkit = ">=0.7.0"
|
||||||
pre-commit = "^2.17.0"
|
|
||||||
pydantic = ">=1.8.0"
|
|
||||||
|
|
||||||
|
|
||||||
[tool.poetry.scripts]
|
[tool.poetry.scripts]
|
||||||
protoc-gen-python_betterproto = "betterproto.plugin:main"
|
protoc-gen-python_betterproto = "betterproto.plugin:main"
|
||||||
@ -61,9 +64,13 @@ help = "Run tests"
|
|||||||
cmd = "mypy src --ignore-missing-imports"
|
cmd = "mypy src --ignore-missing-imports"
|
||||||
help = "Check types with mypy"
|
help = "Check types with mypy"
|
||||||
|
|
||||||
|
[tool.poe.tasks]
|
||||||
|
_black = "black . --exclude tests/output_ --target-version py310"
|
||||||
|
_isort = "isort . --extend-skip-glob 'tests/output_*/**/*'"
|
||||||
|
|
||||||
[tool.poe.tasks.format]
|
[tool.poe.tasks.format]
|
||||||
cmd = "black . --exclude tests/output_ --target-version py310"
|
sequence = ["_black", "_isort"]
|
||||||
help = "Apply black formatting to source code"
|
help = "Apply black and isort formatting to source code"
|
||||||
|
|
||||||
[tool.poe.tasks.docs]
|
[tool.poe.tasks.docs]
|
||||||
cmd = "sphinx-build docs docs/build"
|
cmd = "sphinx-build docs docs/build"
|
||||||
@ -130,14 +137,21 @@ omit = ["betterproto/tests/*"]
|
|||||||
[tool.tox]
|
[tool.tox]
|
||||||
legacy_tox_ini = """
|
legacy_tox_ini = """
|
||||||
[tox]
|
[tox]
|
||||||
isolated_build = true
|
requires =
|
||||||
envlist = py37, py38, py310
|
tox>=4.2
|
||||||
|
tox-poetry-installer[poetry]==1.0.0b1
|
||||||
|
env_list =
|
||||||
|
py311
|
||||||
|
py38
|
||||||
|
py37
|
||||||
|
|
||||||
[testenv]
|
[testenv]
|
||||||
whitelist_externals = poetry
|
|
||||||
commands =
|
commands =
|
||||||
poetry install -v --extras compiler
|
pytest {posargs: --cov betterproto}
|
||||||
poetry run pytest --cov betterproto
|
poetry_dep_groups =
|
||||||
|
test
|
||||||
|
require_locked_deps = true
|
||||||
|
require_poetry = true
|
||||||
"""
|
"""
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import enum
|
import enum as builtin_enum
|
||||||
import json
|
import json
|
||||||
import math
|
import math
|
||||||
import struct
|
import struct
|
||||||
@ -22,8 +24,8 @@ from itertools import count
|
|||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
BinaryIO,
|
|
||||||
Callable,
|
Callable,
|
||||||
|
ClassVar,
|
||||||
Dict,
|
Dict,
|
||||||
Generator,
|
Generator,
|
||||||
Iterable,
|
Iterable,
|
||||||
@ -37,6 +39,7 @@ from typing import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from dateutil.parser import isoparse
|
from dateutil.parser import isoparse
|
||||||
|
from typing_extensions import Self
|
||||||
|
|
||||||
from ._types import T
|
from ._types import T
|
||||||
from ._version import __version__
|
from ._version import __version__
|
||||||
@ -45,11 +48,19 @@ from .casing import (
|
|||||||
safe_snake_case,
|
safe_snake_case,
|
||||||
snake_case,
|
snake_case,
|
||||||
)
|
)
|
||||||
from .grpc.grpclib_client import ServiceStub
|
from .enum import Enum as Enum
|
||||||
|
from .grpc.grpclib_client import ServiceStub as ServiceStub
|
||||||
|
from .utils import (
|
||||||
|
classproperty,
|
||||||
|
hybridmethod,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from _typeshed import ReadableBuffer
|
from _typeshed import (
|
||||||
|
SupportsRead,
|
||||||
|
SupportsWrite,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# Proto 3 data types
|
# Proto 3 data types
|
||||||
@ -126,6 +137,9 @@ WIRE_FIXED_32_TYPES = [TYPE_FLOAT, TYPE_FIXED32, TYPE_SFIXED32]
|
|||||||
WIRE_FIXED_64_TYPES = [TYPE_DOUBLE, TYPE_FIXED64, TYPE_SFIXED64]
|
WIRE_FIXED_64_TYPES = [TYPE_DOUBLE, TYPE_FIXED64, TYPE_SFIXED64]
|
||||||
WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE, TYPE_MAP]
|
WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE, TYPE_MAP]
|
||||||
|
|
||||||
|
# Indicator of message delimitation in streams
|
||||||
|
SIZE_DELIMITED = -1
|
||||||
|
|
||||||
|
|
||||||
# Protobuf datetimes start at the Unix Epoch in 1970 in UTC.
|
# Protobuf datetimes start at the Unix Epoch in 1970 in UTC.
|
||||||
def datetime_default_gen() -> datetime:
|
def datetime_default_gen() -> datetime:
|
||||||
@ -140,7 +154,7 @@ NEG_INFINITY = "-Infinity"
|
|||||||
NAN = "NaN"
|
NAN = "NaN"
|
||||||
|
|
||||||
|
|
||||||
class Casing(enum.Enum):
|
class Casing(builtin_enum.Enum):
|
||||||
"""Casing constants for serialization."""
|
"""Casing constants for serialization."""
|
||||||
|
|
||||||
CAMEL = camel_case #: A camelCase sterilization function.
|
CAMEL = camel_case #: A camelCase sterilization function.
|
||||||
@ -309,32 +323,6 @@ def map_field(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class Enum(enum.IntEnum):
|
|
||||||
"""
|
|
||||||
The base class for protobuf enumerations, all generated enumerations will inherit
|
|
||||||
from this. Bases :class:`enum.IntEnum`.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_string(cls, name: str) -> "Enum":
|
|
||||||
"""Return the value which corresponds to the string name.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
-----------
|
|
||||||
name: :class:`str`
|
|
||||||
The name of the enum member to get
|
|
||||||
|
|
||||||
Raises
|
|
||||||
-------
|
|
||||||
:exc:`ValueError`
|
|
||||||
The member was not found in the Enum.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
return cls._member_map_[name] # type: ignore
|
|
||||||
except KeyError as e:
|
|
||||||
raise ValueError(f"Unknown value {name} for enum {cls.__name__}") from e
|
|
||||||
|
|
||||||
|
|
||||||
def _pack_fmt(proto_type: str) -> str:
|
def _pack_fmt(proto_type: str) -> str:
|
||||||
"""Returns a little-endian format string for reading/writing binary."""
|
"""Returns a little-endian format string for reading/writing binary."""
|
||||||
return {
|
return {
|
||||||
@ -347,7 +335,7 @@ def _pack_fmt(proto_type: str) -> str:
|
|||||||
}[proto_type]
|
}[proto_type]
|
||||||
|
|
||||||
|
|
||||||
def dump_varint(value: int, stream: BinaryIO) -> None:
|
def dump_varint(value: int, stream: "SupportsWrite[bytes]") -> None:
|
||||||
"""Encodes a single varint and dumps it into the provided stream."""
|
"""Encodes a single varint and dumps it into the provided stream."""
|
||||||
if value < -(1 << 63):
|
if value < -(1 << 63):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -556,7 +544,7 @@ def _dump_float(value: float) -> Union[float, str]:
|
|||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
def load_varint(stream: BinaryIO) -> Tuple[int, bytes]:
|
def load_varint(stream: "SupportsRead[bytes]") -> Tuple[int, bytes]:
|
||||||
"""
|
"""
|
||||||
Load a single varint value from a stream. Returns the value and the raw bytes read.
|
Load a single varint value from a stream. Returns the value and the raw bytes read.
|
||||||
"""
|
"""
|
||||||
@ -594,7 +582,7 @@ class ParsedField:
|
|||||||
raw: bytes
|
raw: bytes
|
||||||
|
|
||||||
|
|
||||||
def load_fields(stream: BinaryIO) -> Generator[ParsedField, None, None]:
|
def load_fields(stream: "SupportsRead[bytes]") -> Generator[ParsedField, None, None]:
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
num_wire, raw = load_varint(stream)
|
num_wire, raw = load_varint(stream)
|
||||||
@ -748,6 +736,7 @@ class Message(ABC):
|
|||||||
_serialized_on_wire: bool
|
_serialized_on_wire: bool
|
||||||
_unknown_fields: bytes
|
_unknown_fields: bytes
|
||||||
_group_current: Dict[str, str]
|
_group_current: Dict[str, str]
|
||||||
|
_betterproto_meta: ClassVar[ProtoClassMetadata]
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
# Keep track of whether every field was default
|
# Keep track of whether every field was default
|
||||||
@ -815,6 +804,10 @@ class Message(ABC):
|
|||||||
]
|
]
|
||||||
return f"{self.__class__.__name__}({', '.join(parts)})"
|
return f"{self.__class__.__name__}({', '.join(parts)})"
|
||||||
|
|
||||||
|
def __rich_repr__(self) -> Iterable[Tuple[str, Any, Any]]:
|
||||||
|
for field_name in self._betterproto.sorted_field_names:
|
||||||
|
yield field_name, self.__raw_get(field_name), PLACEHOLDER
|
||||||
|
|
||||||
if not TYPE_CHECKING:
|
if not TYPE_CHECKING:
|
||||||
|
|
||||||
def __getattribute__(self, name: str) -> Any:
|
def __getattribute__(self, name: str) -> Any:
|
||||||
@ -889,20 +882,28 @@ class Message(ABC):
|
|||||||
kwargs[name] = deepcopy(value)
|
kwargs[name] = deepcopy(value)
|
||||||
return self.__class__(**kwargs) # type: ignore
|
return self.__class__(**kwargs) # type: ignore
|
||||||
|
|
||||||
@property
|
def __copy__(self: T, _: Any = {}) -> T:
|
||||||
def _betterproto(self) -> ProtoClassMetadata:
|
kwargs = {}
|
||||||
|
for name in self._betterproto.sorted_field_names:
|
||||||
|
value = self.__raw_get(name)
|
||||||
|
if value is not PLACEHOLDER:
|
||||||
|
kwargs[name] = value
|
||||||
|
return self.__class__(**kwargs) # type: ignore
|
||||||
|
|
||||||
|
@classproperty
|
||||||
|
def _betterproto(cls: type[Self]) -> ProtoClassMetadata: # type: ignore
|
||||||
"""
|
"""
|
||||||
Lazy initialize metadata for each protobuf class.
|
Lazy initialize metadata for each protobuf class.
|
||||||
It may be initialized multiple times in a multi-threaded environment,
|
It may be initialized multiple times in a multi-threaded environment,
|
||||||
but that won't affect the correctness.
|
but that won't affect the correctness.
|
||||||
"""
|
"""
|
||||||
meta = getattr(self.__class__, "_betterproto_meta", None)
|
try:
|
||||||
if not meta:
|
return cls._betterproto_meta
|
||||||
meta = ProtoClassMetadata(self.__class__)
|
except AttributeError:
|
||||||
self.__class__._betterproto_meta = meta # type: ignore
|
cls._betterproto_meta = meta = ProtoClassMetadata(cls)
|
||||||
return meta
|
return meta
|
||||||
|
|
||||||
def dump(self, stream: BinaryIO) -> None:
|
def dump(self, stream: "SupportsWrite[bytes]", delimit: bool = False) -> None:
|
||||||
"""
|
"""
|
||||||
Dumps the binary encoded Protobuf message to the stream.
|
Dumps the binary encoded Protobuf message to the stream.
|
||||||
|
|
||||||
@ -910,7 +911,11 @@ class Message(ABC):
|
|||||||
-----------
|
-----------
|
||||||
stream: :class:`BinaryIO`
|
stream: :class:`BinaryIO`
|
||||||
The stream to dump the message to.
|
The stream to dump the message to.
|
||||||
|
delimit:
|
||||||
|
Whether to prefix the message with a varint declaring its size.
|
||||||
"""
|
"""
|
||||||
|
if delimit == SIZE_DELIMITED:
|
||||||
|
dump_varint(len(self), stream)
|
||||||
|
|
||||||
for field_name, meta in self._betterproto.meta_by_field_name.items():
|
for field_name, meta in self._betterproto.meta_by_field_name.items():
|
||||||
try:
|
try:
|
||||||
@ -930,7 +935,7 @@ class Message(ABC):
|
|||||||
# Note that proto3 field presence/optional fields are put in a
|
# Note that proto3 field presence/optional fields are put in a
|
||||||
# synthetic single-item oneof by protoc, which helps us ensure we
|
# synthetic single-item oneof by protoc, which helps us ensure we
|
||||||
# send the value even if the value is the default zero value.
|
# send the value even if the value is the default zero value.
|
||||||
selected_in_group = bool(meta.group)
|
selected_in_group = bool(meta.group) or meta.optional
|
||||||
|
|
||||||
# Empty messages can still be sent on the wire if they were
|
# Empty messages can still be sent on the wire if they were
|
||||||
# set (or received empty).
|
# set (or received empty).
|
||||||
@ -1124,6 +1129,15 @@ class Message(ABC):
|
|||||||
"""
|
"""
|
||||||
return bytes(self)
|
return bytes(self)
|
||||||
|
|
||||||
|
def __getstate__(self) -> bytes:
|
||||||
|
return bytes(self)
|
||||||
|
|
||||||
|
def __setstate__(self: T, pickled_bytes: bytes) -> T:
|
||||||
|
return self.parse(pickled_bytes)
|
||||||
|
|
||||||
|
def __reduce__(self) -> Tuple[Any, ...]:
|
||||||
|
return (self.__class__.FromString, (bytes(self),))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _type_hint(cls, field_name: str) -> Type:
|
def _type_hint(cls, field_name: str) -> Type:
|
||||||
return cls._type_hints()[field_name]
|
return cls._type_hints()[field_name]
|
||||||
@ -1168,7 +1182,7 @@ class Message(ABC):
|
|||||||
return t
|
return t
|
||||||
elif issubclass(t, Enum):
|
elif issubclass(t, Enum):
|
||||||
# Enums always default to zero.
|
# Enums always default to zero.
|
||||||
return int
|
return t.try_value
|
||||||
elif t is datetime:
|
elif t is datetime:
|
||||||
# Offsets are relative to 1970-01-01T00:00:00Z
|
# Offsets are relative to 1970-01-01T00:00:00Z
|
||||||
return datetime_default_gen
|
return datetime_default_gen
|
||||||
@ -1193,6 +1207,9 @@ class Message(ABC):
|
|||||||
elif meta.proto_type == TYPE_BOOL:
|
elif meta.proto_type == TYPE_BOOL:
|
||||||
# Booleans use a varint encoding, so convert it to true/false.
|
# Booleans use a varint encoding, so convert it to true/false.
|
||||||
value = value > 0
|
value = value > 0
|
||||||
|
elif meta.proto_type == TYPE_ENUM:
|
||||||
|
# Convert enum ints to python enum instances
|
||||||
|
value = self._betterproto.cls_by_field[field_name].try_value(value)
|
||||||
elif wire_type in (WIRE_FIXED_32, WIRE_FIXED_64):
|
elif wire_type in (WIRE_FIXED_32, WIRE_FIXED_64):
|
||||||
fmt = _pack_fmt(meta.proto_type)
|
fmt = _pack_fmt(meta.proto_type)
|
||||||
value = struct.unpack(fmt, value)[0]
|
value = struct.unpack(fmt, value)[0]
|
||||||
@ -1225,7 +1242,11 @@ class Message(ABC):
|
|||||||
meta.group is not None and self._group_current.get(meta.group) == field_name
|
meta.group is not None and self._group_current.get(meta.group) == field_name
|
||||||
)
|
)
|
||||||
|
|
||||||
def load(self: T, stream: BinaryIO, size: Optional[int] = None) -> T:
|
def load(
|
||||||
|
self: T,
|
||||||
|
stream: "SupportsRead[bytes]",
|
||||||
|
size: Optional[int] = None,
|
||||||
|
) -> T:
|
||||||
"""
|
"""
|
||||||
Load the binary encoded Protobuf from a stream into this message instance. This
|
Load the binary encoded Protobuf from a stream into this message instance. This
|
||||||
returns the instance itself and is therefore assignable and chainable.
|
returns the instance itself and is therefore assignable and chainable.
|
||||||
@ -1237,12 +1258,17 @@ class Message(ABC):
|
|||||||
size: :class:`Optional[int]`
|
size: :class:`Optional[int]`
|
||||||
The size of the message in the stream.
|
The size of the message in the stream.
|
||||||
Reads stream until EOF if ``None`` is given.
|
Reads stream until EOF if ``None`` is given.
|
||||||
|
Reads based on a size delimiter prefix varint if SIZE_DELIMITED is given.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
--------
|
--------
|
||||||
:class:`Message`
|
:class:`Message`
|
||||||
The initialized message.
|
The initialized message.
|
||||||
"""
|
"""
|
||||||
|
# If the message is delimited, parse the message delimiter
|
||||||
|
if size == SIZE_DELIMITED:
|
||||||
|
size, _ = load_varint(stream)
|
||||||
|
|
||||||
# Got some data over the wire
|
# Got some data over the wire
|
||||||
self._serialized_on_wire = True
|
self._serialized_on_wire = True
|
||||||
proto_meta = self._betterproto
|
proto_meta = self._betterproto
|
||||||
@ -1315,7 +1341,7 @@ class Message(ABC):
|
|||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def parse(self: T, data: "ReadableBuffer") -> T:
|
def parse(self: T, data: bytes) -> T:
|
||||||
"""
|
"""
|
||||||
Parse the binary encoded Protobuf into this message instance. This
|
Parse the binary encoded Protobuf into this message instance. This
|
||||||
returns the instance itself and is therefore assignable and chainable.
|
returns the instance itself and is therefore assignable and chainable.
|
||||||
@ -1494,7 +1520,91 @@ class Message(ABC):
|
|||||||
output[cased_name] = value
|
output[cased_name] = value
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def from_dict(self: T, value: Mapping[str, Any]) -> T:
|
@classmethod
|
||||||
|
def _from_dict_init(cls, mapping: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||||
|
init_kwargs: Dict[str, Any] = {}
|
||||||
|
for key, value in mapping.items():
|
||||||
|
field_name = safe_snake_case(key)
|
||||||
|
try:
|
||||||
|
meta = cls._betterproto.meta_by_field_name[field_name]
|
||||||
|
except KeyError:
|
||||||
|
continue
|
||||||
|
if value is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if meta.proto_type == TYPE_MESSAGE:
|
||||||
|
sub_cls = cls._betterproto.cls_by_field[field_name]
|
||||||
|
if sub_cls == datetime:
|
||||||
|
value = (
|
||||||
|
[isoparse(item) for item in value]
|
||||||
|
if isinstance(value, list)
|
||||||
|
else isoparse(value)
|
||||||
|
)
|
||||||
|
elif sub_cls == timedelta:
|
||||||
|
value = (
|
||||||
|
[timedelta(seconds=float(item[:-1])) for item in value]
|
||||||
|
if isinstance(value, list)
|
||||||
|
else timedelta(seconds=float(value[:-1]))
|
||||||
|
)
|
||||||
|
elif not meta.wraps:
|
||||||
|
value = (
|
||||||
|
[sub_cls.from_dict(item) for item in value]
|
||||||
|
if isinstance(value, list)
|
||||||
|
else sub_cls.from_dict(value)
|
||||||
|
)
|
||||||
|
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
|
||||||
|
sub_cls = cls._betterproto.cls_by_field[f"{field_name}.value"]
|
||||||
|
value = {k: sub_cls.from_dict(v) for k, v in value.items()}
|
||||||
|
else:
|
||||||
|
if meta.proto_type in INT_64_TYPES:
|
||||||
|
value = (
|
||||||
|
[int(n) for n in value]
|
||||||
|
if isinstance(value, list)
|
||||||
|
else int(value)
|
||||||
|
)
|
||||||
|
elif meta.proto_type == TYPE_BYTES:
|
||||||
|
value = (
|
||||||
|
[b64decode(n) for n in value]
|
||||||
|
if isinstance(value, list)
|
||||||
|
else b64decode(value)
|
||||||
|
)
|
||||||
|
elif meta.proto_type == TYPE_ENUM:
|
||||||
|
enum_cls = cls._betterproto.cls_by_field[field_name]
|
||||||
|
if isinstance(value, list):
|
||||||
|
value = [enum_cls.from_string(e) for e in value]
|
||||||
|
elif isinstance(value, str):
|
||||||
|
value = enum_cls.from_string(value)
|
||||||
|
elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE):
|
||||||
|
value = (
|
||||||
|
[_parse_float(n) for n in value]
|
||||||
|
if isinstance(value, list)
|
||||||
|
else _parse_float(value)
|
||||||
|
)
|
||||||
|
|
||||||
|
init_kwargs[field_name] = value
|
||||||
|
return init_kwargs
|
||||||
|
|
||||||
|
@hybridmethod
|
||||||
|
def from_dict(cls: type[Self], value: Mapping[str, Any]) -> Self: # type: ignore
|
||||||
|
"""
|
||||||
|
Parse the key/value pairs into the a new message instance.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
-----------
|
||||||
|
value: Dict[:class:`str`, Any]
|
||||||
|
The dictionary to parse from.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
--------
|
||||||
|
:class:`Message`
|
||||||
|
The initialized message.
|
||||||
|
"""
|
||||||
|
self = cls(**cls._from_dict_init(value))
|
||||||
|
self._serialized_on_wire = True
|
||||||
|
return self
|
||||||
|
|
||||||
|
@from_dict.instancemethod
|
||||||
|
def from_dict(self, value: Mapping[str, Any]) -> Self:
|
||||||
"""
|
"""
|
||||||
Parse the key/value pairs into the current message instance. This returns the
|
Parse the key/value pairs into the current message instance. This returns the
|
||||||
instance itself and is therefore assignable and chainable.
|
instance itself and is therefore assignable and chainable.
|
||||||
@ -1510,71 +1620,8 @@ class Message(ABC):
|
|||||||
The initialized message.
|
The initialized message.
|
||||||
"""
|
"""
|
||||||
self._serialized_on_wire = True
|
self._serialized_on_wire = True
|
||||||
for key in value:
|
for field, value in self._from_dict_init(value).items():
|
||||||
field_name = safe_snake_case(key)
|
setattr(self, field, value)
|
||||||
meta = self._betterproto.meta_by_field_name.get(field_name)
|
|
||||||
if not meta:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if value[key] is not None:
|
|
||||||
if meta.proto_type == TYPE_MESSAGE:
|
|
||||||
v = self._get_field_default(field_name)
|
|
||||||
cls = self._betterproto.cls_by_field[field_name]
|
|
||||||
if isinstance(v, list):
|
|
||||||
if cls == datetime:
|
|
||||||
v = [isoparse(item) for item in value[key]]
|
|
||||||
elif cls == timedelta:
|
|
||||||
v = [
|
|
||||||
timedelta(seconds=float(item[:-1]))
|
|
||||||
for item in value[key]
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
v = [cls().from_dict(item) for item in value[key]]
|
|
||||||
elif cls == datetime:
|
|
||||||
v = isoparse(value[key])
|
|
||||||
setattr(self, field_name, v)
|
|
||||||
elif cls == timedelta:
|
|
||||||
v = timedelta(seconds=float(value[key][:-1]))
|
|
||||||
setattr(self, field_name, v)
|
|
||||||
elif meta.wraps:
|
|
||||||
setattr(self, field_name, value[key])
|
|
||||||
elif v is None:
|
|
||||||
setattr(self, field_name, cls().from_dict(value[key]))
|
|
||||||
else:
|
|
||||||
# NOTE: `from_dict` mutates the underlying message, so no
|
|
||||||
# assignment here is necessary.
|
|
||||||
v.from_dict(value[key])
|
|
||||||
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
|
|
||||||
v = getattr(self, field_name)
|
|
||||||
cls = self._betterproto.cls_by_field[f"{field_name}.value"]
|
|
||||||
for k in value[key]:
|
|
||||||
v[k] = cls().from_dict(value[key][k])
|
|
||||||
else:
|
|
||||||
v = value[key]
|
|
||||||
if meta.proto_type in INT_64_TYPES:
|
|
||||||
if isinstance(value[key], list):
|
|
||||||
v = [int(n) for n in value[key]]
|
|
||||||
else:
|
|
||||||
v = int(value[key])
|
|
||||||
elif meta.proto_type == TYPE_BYTES:
|
|
||||||
if isinstance(value[key], list):
|
|
||||||
v = [b64decode(n) for n in value[key]]
|
|
||||||
else:
|
|
||||||
v = b64decode(value[key])
|
|
||||||
elif meta.proto_type == TYPE_ENUM:
|
|
||||||
enum_cls = self._betterproto.cls_by_field[field_name]
|
|
||||||
if isinstance(v, list):
|
|
||||||
v = [enum_cls.from_string(e) for e in v]
|
|
||||||
elif isinstance(v, str):
|
|
||||||
v = enum_cls.from_string(v)
|
|
||||||
elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE):
|
|
||||||
if isinstance(value[key], list):
|
|
||||||
v = [_parse_float(n) for n in value[key]]
|
|
||||||
else:
|
|
||||||
v = _parse_float(value[key])
|
|
||||||
|
|
||||||
if v is not None:
|
|
||||||
setattr(self, field_name, v)
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def to_json(
|
def to_json(
|
||||||
@ -1791,8 +1838,8 @@ class Message(ABC):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _validate_field_groups(cls, values):
|
def _validate_field_groups(cls, values):
|
||||||
group_to_one_ofs = cls._betterproto_meta.oneof_field_by_group # type: ignore
|
group_to_one_ofs = cls._betterproto.oneof_field_by_group
|
||||||
field_name_to_meta = cls._betterproto_meta.meta_by_field_name # type: ignore
|
field_name_to_meta = cls._betterproto.meta_by_field_name
|
||||||
|
|
||||||
for group, field_set in group_to_one_ofs.items():
|
for group, field_set in group_to_one_ofs.items():
|
||||||
if len(field_set) == 1:
|
if len(field_set) == 1:
|
||||||
@ -1819,6 +1866,9 @@ class Message(ABC):
|
|||||||
return values
|
return values
|
||||||
|
|
||||||
|
|
||||||
|
Message.__annotations__ = {} # HACK to avoid typing.get_type_hints breaking :)
|
||||||
|
|
||||||
|
|
||||||
def serialized_on_wire(message: Message) -> bool:
|
def serialized_on_wire(message: Message) -> bool:
|
||||||
"""
|
"""
|
||||||
If this message was or should be serialized on the wire. This can be used to detect
|
If this message was or should be serialized on the wire. This can be used to detect
|
||||||
@ -1890,17 +1940,24 @@ class _Duration(Duration):
|
|||||||
class _Timestamp(Timestamp):
|
class _Timestamp(Timestamp):
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_datetime(cls, dt: datetime) -> "_Timestamp":
|
def from_datetime(cls, dt: datetime) -> "_Timestamp":
|
||||||
seconds = int(dt.timestamp())
|
# manual epoch offset calulation to avoid rounding errors,
|
||||||
nanos = int(dt.microsecond * 1e3)
|
# to support negative timestamps (before 1970) and skirt
|
||||||
return cls(seconds, nanos)
|
# around datetime bugs (apparently 0 isn't a year in [0, 9999]??)
|
||||||
|
offset = dt - DATETIME_ZERO
|
||||||
|
# below is the same as timedelta.total_seconds() but without dividing by 1e6
|
||||||
|
# so we end up with microseconds as integers instead of seconds as float
|
||||||
|
offset_us = (
|
||||||
|
offset.days * 24 * 60 * 60 + offset.seconds
|
||||||
|
) * 10**6 + offset.microseconds
|
||||||
|
seconds, us = divmod(offset_us, 10**6)
|
||||||
|
return cls(seconds, us * 1000)
|
||||||
|
|
||||||
def to_datetime(self) -> datetime:
|
def to_datetime(self) -> datetime:
|
||||||
ts = self.seconds + (self.nanos / 1e9)
|
# datetime.fromtimestamp() expects a timestamp in seconds, not microseconds
|
||||||
|
# if we pass it as a floating point number, we will run into rounding errors
|
||||||
if ts < 0:
|
# see also #407
|
||||||
return datetime(1970, 1, 1) + timedelta(seconds=ts)
|
offset = timedelta(seconds=self.seconds, microseconds=self.nanos // 1000)
|
||||||
else:
|
return DATETIME_ZERO + offset
|
||||||
return datetime.fromtimestamp(ts, tz=timezone.utc)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def timestamp_to_json(dt: datetime) -> str:
|
def timestamp_to_json(dt: datetime) -> str:
|
||||||
|
@ -136,4 +136,8 @@ def lowercase_first(value: str) -> str:
|
|||||||
|
|
||||||
def sanitize_name(value: str) -> str:
|
def sanitize_name(value: str) -> str:
|
||||||
# https://www.python.org/dev/peps/pep-0008/#descriptive-naming-styles
|
# https://www.python.org/dev/peps/pep-0008/#descriptive-naming-styles
|
||||||
return f"{value}_" if keyword.iskeyword(value) else value
|
if keyword.iskeyword(value):
|
||||||
|
return f"{value}_"
|
||||||
|
if not value.isidentifier():
|
||||||
|
return f"_{value}"
|
||||||
|
return value
|
||||||
|
@ -11,3 +11,11 @@ def pythonize_field_name(name: str) -> str:
|
|||||||
|
|
||||||
def pythonize_method_name(name: str) -> str:
|
def pythonize_method_name(name: str) -> str:
|
||||||
return casing.safe_snake_case(name)
|
return casing.safe_snake_case(name)
|
||||||
|
|
||||||
|
|
||||||
|
def pythonize_enum_member_name(name: str, enum_name: str) -> str:
|
||||||
|
enum_name = casing.snake_case(enum_name).upper()
|
||||||
|
find = name.find(enum_name)
|
||||||
|
if find != -1:
|
||||||
|
name = name[find + len(enum_name) :].strip("_")
|
||||||
|
return casing.sanitize_name(name)
|
||||||
|
196
src/betterproto/enum.py
Normal file
196
src/betterproto/enum.py
Normal file
@ -0,0 +1,196 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from enum import (
|
||||||
|
EnumMeta,
|
||||||
|
IntEnum,
|
||||||
|
)
|
||||||
|
from types import MappingProxyType
|
||||||
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
|
Dict,
|
||||||
|
Optional,
|
||||||
|
Tuple,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import (
|
||||||
|
Generator,
|
||||||
|
Mapping,
|
||||||
|
)
|
||||||
|
|
||||||
|
from typing_extensions import (
|
||||||
|
Never,
|
||||||
|
Self,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_descriptor(obj: object) -> bool:
|
||||||
|
return (
|
||||||
|
hasattr(obj, "__get__") or hasattr(obj, "__set__") or hasattr(obj, "__delete__")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class EnumType(EnumMeta if TYPE_CHECKING else type):
|
||||||
|
_value_map_: Mapping[int, Enum]
|
||||||
|
_member_map_: Mapping[str, Enum]
|
||||||
|
|
||||||
|
def __new__(
|
||||||
|
mcs, name: str, bases: Tuple[type, ...], namespace: Dict[str, Any]
|
||||||
|
) -> Self:
|
||||||
|
value_map = {}
|
||||||
|
member_map = {}
|
||||||
|
|
||||||
|
new_mcs = type(
|
||||||
|
f"{name}Type",
|
||||||
|
tuple(
|
||||||
|
dict.fromkeys(
|
||||||
|
[base.__class__ for base in bases if base.__class__ is not type]
|
||||||
|
+ [EnumType, type]
|
||||||
|
)
|
||||||
|
), # reorder the bases so EnumType and type are last to avoid conflicts
|
||||||
|
{"_value_map_": value_map, "_member_map_": member_map},
|
||||||
|
)
|
||||||
|
|
||||||
|
members = {
|
||||||
|
name: value
|
||||||
|
for name, value in namespace.items()
|
||||||
|
if not _is_descriptor(value) and not name.startswith("__")
|
||||||
|
}
|
||||||
|
|
||||||
|
cls = type.__new__(
|
||||||
|
new_mcs,
|
||||||
|
name,
|
||||||
|
bases,
|
||||||
|
{key: value for key, value in namespace.items() if key not in members},
|
||||||
|
)
|
||||||
|
# this allows us to disallow member access from other members as
|
||||||
|
# members become proper class variables
|
||||||
|
|
||||||
|
for name, value in members.items():
|
||||||
|
member = value_map.get(value)
|
||||||
|
if member is None:
|
||||||
|
member = cls.__new__(cls, name=name, value=value) # type: ignore
|
||||||
|
value_map[value] = member
|
||||||
|
member_map[name] = member
|
||||||
|
type.__setattr__(new_mcs, name, member)
|
||||||
|
|
||||||
|
return cls
|
||||||
|
|
||||||
|
if not TYPE_CHECKING:
|
||||||
|
|
||||||
|
def __call__(cls, value: int) -> Enum:
|
||||||
|
try:
|
||||||
|
return cls._value_map_[value]
|
||||||
|
except (KeyError, TypeError):
|
||||||
|
raise ValueError(f"{value!r} is not a valid {cls.__name__}") from None
|
||||||
|
|
||||||
|
def __iter__(cls) -> Generator[Enum, None, None]:
|
||||||
|
yield from cls._member_map_.values()
|
||||||
|
|
||||||
|
if sys.version_info >= (3, 8): # 3.8 added __reversed__ to dict_values
|
||||||
|
|
||||||
|
def __reversed__(cls) -> Generator[Enum, None, None]:
|
||||||
|
yield from reversed(cls._member_map_.values())
|
||||||
|
|
||||||
|
else:
|
||||||
|
|
||||||
|
def __reversed__(cls) -> Generator[Enum, None, None]:
|
||||||
|
yield from reversed(tuple(cls._member_map_.values()))
|
||||||
|
|
||||||
|
def __getitem__(cls, key: str) -> Enum:
|
||||||
|
return cls._member_map_[key]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def __members__(cls) -> MappingProxyType[str, Enum]:
|
||||||
|
return MappingProxyType(cls._member_map_)
|
||||||
|
|
||||||
|
def __repr__(cls) -> str:
|
||||||
|
return f"<enum {cls.__name__!r}>"
|
||||||
|
|
||||||
|
def __len__(cls) -> int:
|
||||||
|
return len(cls._member_map_)
|
||||||
|
|
||||||
|
def __setattr__(cls, name: str, value: Any) -> Never:
|
||||||
|
raise AttributeError(f"{cls.__name__}: cannot reassign Enum members.")
|
||||||
|
|
||||||
|
def __delattr__(cls, name: str) -> Never:
|
||||||
|
raise AttributeError(f"{cls.__name__}: cannot delete Enum members.")
|
||||||
|
|
||||||
|
def __contains__(cls, member: object) -> bool:
|
||||||
|
return isinstance(member, cls) and member.name in cls._member_map_
|
||||||
|
|
||||||
|
|
||||||
|
class Enum(IntEnum if TYPE_CHECKING else int, metaclass=EnumType):
|
||||||
|
"""
|
||||||
|
The base class for protobuf enumerations, all generated enumerations will
|
||||||
|
inherit from this. Emulates `enum.IntEnum`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: Optional[str]
|
||||||
|
value: int
|
||||||
|
|
||||||
|
if not TYPE_CHECKING:
|
||||||
|
|
||||||
|
def __new__(cls, *, name: Optional[str], value: int) -> Self:
|
||||||
|
self = super().__new__(cls, value)
|
||||||
|
super().__setattr__(self, "name", name)
|
||||||
|
super().__setattr__(self, "value", value)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return self.name or "None"
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"{self.__class__.__name__}.{self.name}"
|
||||||
|
|
||||||
|
def __setattr__(self, key: str, value: Any) -> Never:
|
||||||
|
raise AttributeError(
|
||||||
|
f"{self.__class__.__name__} Cannot reassign a member's attributes."
|
||||||
|
)
|
||||||
|
|
||||||
|
def __delattr__(self, item: Any) -> Never:
|
||||||
|
raise AttributeError(
|
||||||
|
f"{self.__class__.__name__} Cannot delete a member's attributes."
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def try_value(cls, value: int = 0) -> Self:
|
||||||
|
"""Return the value which corresponds to the value.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
-----------
|
||||||
|
value: :class:`int`
|
||||||
|
The value of the enum member to get.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
:class:`Enum`
|
||||||
|
The corresponding member or a new instance of the enum if
|
||||||
|
``value`` isn't actually a member.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return cls._value_map_[value]
|
||||||
|
except (KeyError, TypeError):
|
||||||
|
return cls.__new__(cls, name=None, value=value)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_string(cls, name: str) -> Self:
|
||||||
|
"""Return the value which corresponds to the string name.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
-----------
|
||||||
|
name: :class:`str`
|
||||||
|
The name of the enum member to get.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
-------
|
||||||
|
:exc:`ValueError`
|
||||||
|
The member was not found in the Enum.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return cls._member_map_[name]
|
||||||
|
except KeyError as e:
|
||||||
|
raise ValueError(f"Unknown value {name} for enum {cls.__name__}") from e
|
@ -127,6 +127,7 @@ class ServiceStub(ABC):
|
|||||||
response_type,
|
response_type,
|
||||||
**self.__resolve_request_kwargs(timeout, deadline, metadata),
|
**self.__resolve_request_kwargs(timeout, deadline, metadata),
|
||||||
) as stream:
|
) as stream:
|
||||||
|
await stream.send_request()
|
||||||
await self._send_messages(stream, request_iterator)
|
await self._send_messages(stream, request_iterator)
|
||||||
response = await stream.recv_message()
|
response = await stream.recv_message()
|
||||||
assert response is not None
|
assert response is not None
|
||||||
|
@ -72,13 +72,13 @@ from betterproto.lib.google.protobuf import (
|
|||||||
)
|
)
|
||||||
from betterproto.lib.google.protobuf.compiler import CodeGeneratorRequest
|
from betterproto.lib.google.protobuf.compiler import CodeGeneratorRequest
|
||||||
|
|
||||||
from ..casing import sanitize_name
|
|
||||||
from ..compile.importing import (
|
from ..compile.importing import (
|
||||||
get_type_reference,
|
get_type_reference,
|
||||||
parse_source_type_name,
|
parse_source_type_name,
|
||||||
)
|
)
|
||||||
from ..compile.naming import (
|
from ..compile.naming import (
|
||||||
pythonize_class_name,
|
pythonize_class_name,
|
||||||
|
pythonize_enum_member_name,
|
||||||
pythonize_field_name,
|
pythonize_field_name,
|
||||||
pythonize_method_name,
|
pythonize_method_name,
|
||||||
)
|
)
|
||||||
@ -385,7 +385,10 @@ def is_oneof(proto_field_obj: FieldDescriptorProto) -> bool:
|
|||||||
us to tell whether it was set, via the which_one_of interface.
|
us to tell whether it was set, via the which_one_of interface.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return which_one_of(proto_field_obj, "oneof_index")[0] == "oneof_index"
|
return (
|
||||||
|
not proto_field_obj.proto3_optional
|
||||||
|
and which_one_of(proto_field_obj, "oneof_index")[0] == "oneof_index"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -670,7 +673,9 @@ class EnumDefinitionCompiler(MessageCompiler):
|
|||||||
# Get entries/allowed values for this Enum
|
# Get entries/allowed values for this Enum
|
||||||
self.entries = [
|
self.entries = [
|
||||||
self.EnumEntry(
|
self.EnumEntry(
|
||||||
name=sanitize_name(entry_proto_value.name),
|
name=pythonize_enum_member_name(
|
||||||
|
entry_proto_value.name, self.proto_obj.name
|
||||||
|
),
|
||||||
value=entry_proto_value.number,
|
value=entry_proto_value.number,
|
||||||
comment=get_comment(
|
comment=get_comment(
|
||||||
proto_file=self.source_file, path=self.path + [2, entry_number]
|
proto_file=self.source_file, path=self.path + [2, entry_number]
|
||||||
|
56
src/betterproto/utils.py
Normal file
56
src/betterproto/utils.py
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
|
Generic,
|
||||||
|
Optional,
|
||||||
|
Type,
|
||||||
|
TypeVar,
|
||||||
|
)
|
||||||
|
|
||||||
|
from typing_extensions import (
|
||||||
|
Concatenate,
|
||||||
|
ParamSpec,
|
||||||
|
Self,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
SelfT = TypeVar("SelfT")
|
||||||
|
P = ParamSpec("P")
|
||||||
|
HybridT = TypeVar("HybridT", covariant=True)
|
||||||
|
|
||||||
|
|
||||||
|
class hybridmethod(Generic[SelfT, P, HybridT]):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
func: Callable[
|
||||||
|
Concatenate[type[SelfT], P], HybridT
|
||||||
|
], # Must be the classmethod version
|
||||||
|
):
|
||||||
|
self.cls_func = func
|
||||||
|
self.__doc__ = func.__doc__
|
||||||
|
|
||||||
|
def instancemethod(self, func: Callable[Concatenate[SelfT, P], HybridT]) -> Self:
|
||||||
|
self.instance_func = func
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __get__(
|
||||||
|
self, instance: Optional[SelfT], owner: Type[SelfT]
|
||||||
|
) -> Callable[P, HybridT]:
|
||||||
|
if instance is None or self.instance_func is None:
|
||||||
|
# either bound to the class, or no instance method available
|
||||||
|
return self.cls_func.__get__(owner, None)
|
||||||
|
return self.instance_func.__get__(instance, owner)
|
||||||
|
|
||||||
|
|
||||||
|
T_co = TypeVar("T_co")
|
||||||
|
TT_co = TypeVar("TT_co", bound="type[Any]")
|
||||||
|
|
||||||
|
|
||||||
|
class classproperty(Generic[TT_co, T_co]):
|
||||||
|
def __init__(self, func: Callable[[TT_co], T_co]):
|
||||||
|
self.__func__ = func
|
||||||
|
|
||||||
|
def __get__(self, instance: Any, type: TT_co) -> T_co:
|
||||||
|
return self.__func__(type)
|
@ -272,3 +272,27 @@ async def test_async_gen_for_stream_stream_request():
|
|||||||
assert response_index == len(
|
assert response_index == len(
|
||||||
expected_things
|
expected_things
|
||||||
), "Didn't receive all expected responses"
|
), "Didn't receive all expected responses"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stream_unary_with_empty_iterable():
|
||||||
|
things = [] # empty
|
||||||
|
|
||||||
|
async with ChannelFor([ThingService()]) as channel:
|
||||||
|
client = ThingServiceClient(channel)
|
||||||
|
requests = [DoThingRequest(name) for name in things]
|
||||||
|
response = await client.do_many_things(requests)
|
||||||
|
assert len(response.names) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stream_stream_with_empty_iterable():
|
||||||
|
things = [] # empty
|
||||||
|
|
||||||
|
async with ChannelFor([ThingService()]) as channel:
|
||||||
|
client = ThingServiceClient(channel)
|
||||||
|
requests = [GetThingRequest(name) for name in things]
|
||||||
|
responses = [
|
||||||
|
response async for response in client.get_different_things(requests)
|
||||||
|
]
|
||||||
|
assert len(responses) == 0
|
||||||
|
@ -27,7 +27,7 @@ class ThingService:
|
|||||||
async def do_many_things(
|
async def do_many_things(
|
||||||
self, stream: "grpclib.server.Stream[DoThingRequest, DoThingResponse]"
|
self, stream: "grpclib.server.Stream[DoThingRequest, DoThingResponse]"
|
||||||
):
|
):
|
||||||
thing_names = [request.name for request in stream]
|
thing_names = [request.name async for request in stream]
|
||||||
if self.test_hook is not None:
|
if self.test_hook is not None:
|
||||||
self.test_hook(stream)
|
self.test_hook(stream)
|
||||||
await stream.send_message(DoThingResponse(thing_names))
|
await stream.send_message(DoThingResponse(thing_names))
|
||||||
|
@ -15,3 +15,11 @@ enum Choice {
|
|||||||
FOUR = 4;
|
FOUR = 4;
|
||||||
THREE = 3;
|
THREE = 3;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// A "C" like enum with the enum name prefixed onto members, these should be stripped
|
||||||
|
enum ArithmeticOperator {
|
||||||
|
ARITHMETIC_OPERATOR_NONE = 0;
|
||||||
|
ARITHMETIC_OPERATOR_PLUS = 1;
|
||||||
|
ARITHMETIC_OPERATOR_MINUS = 2;
|
||||||
|
ARITHMETIC_OPERATOR_0_PREFIXED = 3;
|
||||||
|
}
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
from tests.output_betterproto.enum import (
|
from tests.output_betterproto.enum import (
|
||||||
|
ArithmeticOperator,
|
||||||
Choice,
|
Choice,
|
||||||
Test,
|
Test,
|
||||||
)
|
)
|
||||||
@ -82,3 +83,32 @@ def test_repeated_enum_with_non_list_iterables_to_dict():
|
|||||||
yield Choice.THREE
|
yield Choice.THREE
|
||||||
|
|
||||||
assert Test(choices=enum_generator()).to_dict()["choices"] == ["ONE", "THREE"]
|
assert Test(choices=enum_generator()).to_dict()["choices"] == ["ONE", "THREE"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_enum_mapped_on_parse():
|
||||||
|
# test default value
|
||||||
|
b = Test().parse(bytes(Test()))
|
||||||
|
assert b.choice.name == Choice.ZERO.name
|
||||||
|
assert b.choices == []
|
||||||
|
|
||||||
|
# test non default value
|
||||||
|
a = Test().parse(bytes(Test(choice=Choice.ONE)))
|
||||||
|
assert a.choice.name == Choice.ONE.name
|
||||||
|
assert b.choices == []
|
||||||
|
|
||||||
|
# test repeated
|
||||||
|
c = Test().parse(bytes(Test(choices=[Choice.THREE, Choice.FOUR])))
|
||||||
|
assert c.choices[0].name == Choice.THREE.name
|
||||||
|
assert c.choices[1].name == Choice.FOUR.name
|
||||||
|
|
||||||
|
# bonus: defaults after empty init are also mapped
|
||||||
|
assert Test().choice.name == Choice.ZERO.name
|
||||||
|
|
||||||
|
|
||||||
|
def test_renamed_enum_members():
|
||||||
|
assert set(ArithmeticOperator.__members__) == {
|
||||||
|
"NONE",
|
||||||
|
"PLUS",
|
||||||
|
"MINUS",
|
||||||
|
"_0_PREFIXED",
|
||||||
|
}
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
syntax = "proto3";
|
syntax = "proto3";
|
||||||
|
|
||||||
|
import "google/protobuf/timestamp.proto";
|
||||||
package google_impl_behavior_equivalence;
|
package google_impl_behavior_equivalence;
|
||||||
|
|
||||||
message Foo { int64 bar = 1; }
|
message Foo { int64 bar = 1; }
|
||||||
@ -12,6 +13,10 @@ message Test {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
message Spam {
|
||||||
|
google.protobuf.Timestamp ts = 1;
|
||||||
|
}
|
||||||
|
|
||||||
message Request { Empty foo = 1; }
|
message Request { Empty foo = 1; }
|
||||||
|
|
||||||
message Empty {}
|
message Empty {}
|
@ -1,17 +1,25 @@
|
|||||||
|
from datetime import (
|
||||||
|
datetime,
|
||||||
|
timezone,
|
||||||
|
)
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from google.protobuf import json_format
|
from google.protobuf import json_format
|
||||||
|
from google.protobuf.timestamp_pb2 import Timestamp
|
||||||
|
|
||||||
import betterproto
|
import betterproto
|
||||||
from tests.output_betterproto.google_impl_behavior_equivalence import (
|
from tests.output_betterproto.google_impl_behavior_equivalence import (
|
||||||
Empty,
|
Empty,
|
||||||
Foo,
|
Foo,
|
||||||
Request,
|
Request,
|
||||||
|
Spam,
|
||||||
Test,
|
Test,
|
||||||
)
|
)
|
||||||
from tests.output_reference.google_impl_behavior_equivalence.google_impl_behavior_equivalence_pb2 import (
|
from tests.output_reference.google_impl_behavior_equivalence.google_impl_behavior_equivalence_pb2 import (
|
||||||
Empty as ReferenceEmpty,
|
Empty as ReferenceEmpty,
|
||||||
Foo as ReferenceFoo,
|
Foo as ReferenceFoo,
|
||||||
Request as ReferenceRequest,
|
Request as ReferenceRequest,
|
||||||
|
Spam as ReferenceSpam,
|
||||||
Test as ReferenceTest,
|
Test as ReferenceTest,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -59,6 +67,19 @@ def test_bytes_are_the_same_for_oneof():
|
|||||||
assert isinstance(message_reference2.foo, ReferenceFoo)
|
assert isinstance(message_reference2.foo, ReferenceFoo)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("dt", (datetime.min.replace(tzinfo=timezone.utc),))
|
||||||
|
def test_datetime_clamping(dt): # see #407
|
||||||
|
ts = Timestamp()
|
||||||
|
ts.FromDatetime(dt)
|
||||||
|
assert bytes(Spam(dt)) == ReferenceSpam(ts=ts).SerializeToString()
|
||||||
|
message_bytes = bytes(Spam(dt))
|
||||||
|
|
||||||
|
assert (
|
||||||
|
Spam().parse(message_bytes).ts.timestamp()
|
||||||
|
== ReferenceSpam.FromString(message_bytes).ts.seconds
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_empty_message_field():
|
def test_empty_message_field():
|
||||||
message = Request()
|
message = Request()
|
||||||
reference_message = ReferenceRequest()
|
reference_message = ReferenceRequest()
|
||||||
|
@ -2,6 +2,10 @@ syntax = "proto3";
|
|||||||
|
|
||||||
package oneof;
|
package oneof;
|
||||||
|
|
||||||
|
message MixedDrink {
|
||||||
|
int32 shots = 1;
|
||||||
|
}
|
||||||
|
|
||||||
message Test {
|
message Test {
|
||||||
oneof foo {
|
oneof foo {
|
||||||
int32 pitied = 1;
|
int32 pitied = 1;
|
||||||
@ -13,6 +17,7 @@ message Test {
|
|||||||
oneof bar {
|
oneof bar {
|
||||||
int32 drinks = 11;
|
int32 drinks = 11;
|
||||||
string bar_name = 12;
|
string bar_name = 12;
|
||||||
|
MixedDrink mixed_drink = 13;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,5 +1,10 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
import betterproto
|
import betterproto
|
||||||
from tests.output_betterproto.oneof import Test
|
from tests.output_betterproto.oneof import (
|
||||||
|
MixedDrink,
|
||||||
|
Test,
|
||||||
|
)
|
||||||
from tests.output_betterproto_pydantic.oneof import Test as TestPyd
|
from tests.output_betterproto_pydantic.oneof import Test as TestPyd
|
||||||
from tests.util import get_test_case_json_data
|
from tests.util import get_test_case_json_data
|
||||||
|
|
||||||
@ -19,3 +24,20 @@ def test_which_name():
|
|||||||
def test_which_count_pyd():
|
def test_which_count_pyd():
|
||||||
message = TestPyd(pitier="Mr. T", just_a_regular_field=2, bar_name="a_bar")
|
message = TestPyd(pitier="Mr. T", just_a_regular_field=2, bar_name="a_bar")
|
||||||
assert betterproto.which_one_of(message, "foo") == ("pitier", "Mr. T")
|
assert betterproto.which_one_of(message, "foo") == ("pitier", "Mr. T")
|
||||||
|
|
||||||
|
|
||||||
|
def test_oneof_constructor_assign():
|
||||||
|
message = Test(mixed_drink=MixedDrink(shots=42))
|
||||||
|
field, value = betterproto.which_one_of(message, "bar")
|
||||||
|
assert field == "mixed_drink"
|
||||||
|
assert value.shots == 42
|
||||||
|
|
||||||
|
|
||||||
|
# Issue #305:
|
||||||
|
@pytest.mark.xfail
|
||||||
|
def test_oneof_nested_assign():
|
||||||
|
message = Test()
|
||||||
|
message.mixed_drink.shots = 42
|
||||||
|
field, value = betterproto.which_one_of(message, "bar")
|
||||||
|
assert field == "mixed_drink"
|
||||||
|
assert value.shots == 42
|
||||||
|
@ -41,3 +41,8 @@ def test_null_fields_json():
|
|||||||
"test8": None,
|
"test8": None,
|
||||||
"test9": None,
|
"test9": None,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_unset_access(): # see #523
|
||||||
|
assert Test().test1 is None
|
||||||
|
assert Test(test1=None).test1 is None
|
||||||
|
2
tests/streams/delimited_messages.in
Normal file
2
tests/streams/delimited_messages.in
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
•šï:bTesting•šï:bTesting
|
||||||
|
|
38
tests/streams/java/.gitignore
vendored
Normal file
38
tests/streams/java/.gitignore
vendored
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
### Output ###
|
||||||
|
target/
|
||||||
|
!.mvn/wrapper/maven-wrapper.jar
|
||||||
|
!**/src/main/**/target/
|
||||||
|
!**/src/test/**/target/
|
||||||
|
dependency-reduced-pom.xml
|
||||||
|
MANIFEST.MF
|
||||||
|
|
||||||
|
### IntelliJ IDEA ###
|
||||||
|
.idea/
|
||||||
|
*.iws
|
||||||
|
*.iml
|
||||||
|
*.ipr
|
||||||
|
|
||||||
|
### Eclipse ###
|
||||||
|
.apt_generated
|
||||||
|
.classpath
|
||||||
|
.factorypath
|
||||||
|
.project
|
||||||
|
.settings
|
||||||
|
.springBeans
|
||||||
|
.sts4-cache
|
||||||
|
|
||||||
|
### NetBeans ###
|
||||||
|
/nbproject/private/
|
||||||
|
/nbbuild/
|
||||||
|
/dist/
|
||||||
|
/nbdist/
|
||||||
|
/.nb-gradle/
|
||||||
|
build/
|
||||||
|
!**/src/main/**/build/
|
||||||
|
!**/src/test/**/build/
|
||||||
|
|
||||||
|
### VS Code ###
|
||||||
|
.vscode/
|
||||||
|
|
||||||
|
### Mac OS ###
|
||||||
|
.DS_Store
|
94
tests/streams/java/pom.xml
Normal file
94
tests/streams/java/pom.xml
Normal file
@ -0,0 +1,94 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
||||||
|
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||||
|
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||||
|
<modelVersion>4.0.0</modelVersion>
|
||||||
|
|
||||||
|
<groupId>betterproto</groupId>
|
||||||
|
<artifactId>compatibility-test</artifactId>
|
||||||
|
<version>1.0-SNAPSHOT</version>
|
||||||
|
<packaging>jar</packaging>
|
||||||
|
|
||||||
|
<properties>
|
||||||
|
<maven.compiler.source>11</maven.compiler.source>
|
||||||
|
<maven.compiler.target>11</maven.compiler.target>
|
||||||
|
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
||||||
|
<protobuf.version>3.23.4</protobuf.version>
|
||||||
|
</properties>
|
||||||
|
|
||||||
|
<dependencies>
|
||||||
|
<dependency>
|
||||||
|
<groupId>com.google.protobuf</groupId>
|
||||||
|
<artifactId>protobuf-java</artifactId>
|
||||||
|
<version>${protobuf.version}</version>
|
||||||
|
</dependency>
|
||||||
|
</dependencies>
|
||||||
|
|
||||||
|
<build>
|
||||||
|
<extensions>
|
||||||
|
<extension>
|
||||||
|
<groupId>kr.motd.maven</groupId>
|
||||||
|
<artifactId>os-maven-plugin</artifactId>
|
||||||
|
<version>1.7.1</version>
|
||||||
|
</extension>
|
||||||
|
</extensions>
|
||||||
|
|
||||||
|
<plugins>
|
||||||
|
<plugin>
|
||||||
|
<groupId>org.apache.maven.plugins</groupId>
|
||||||
|
<artifactId>maven-shade-plugin</artifactId>
|
||||||
|
<version>3.5.0</version>
|
||||||
|
<executions>
|
||||||
|
<execution>
|
||||||
|
<phase>package</phase>
|
||||||
|
<goals>
|
||||||
|
<goal>shade</goal>
|
||||||
|
</goals>
|
||||||
|
<configuration>
|
||||||
|
<transformers>
|
||||||
|
<transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
|
||||||
|
<mainClass>betterproto.CompatibilityTest</mainClass>
|
||||||
|
</transformer>
|
||||||
|
</transformers>
|
||||||
|
</configuration>
|
||||||
|
</execution>
|
||||||
|
</executions>
|
||||||
|
</plugin>
|
||||||
|
|
||||||
|
<plugin>
|
||||||
|
<groupId>org.apache.maven.plugins</groupId>
|
||||||
|
<artifactId>maven-jar-plugin</artifactId>
|
||||||
|
<version>3.3.0</version>
|
||||||
|
<configuration>
|
||||||
|
<archive>
|
||||||
|
<manifest>
|
||||||
|
<addClasspath>true</addClasspath>
|
||||||
|
<mainClass>betterproto.CompatibilityTest</mainClass>
|
||||||
|
</manifest>
|
||||||
|
</archive>
|
||||||
|
</configuration>
|
||||||
|
</plugin>
|
||||||
|
|
||||||
|
<plugin>
|
||||||
|
<groupId>org.xolstice.maven.plugins</groupId>
|
||||||
|
<artifactId>protobuf-maven-plugin</artifactId>
|
||||||
|
<version>0.6.1</version>
|
||||||
|
<executions>
|
||||||
|
<execution>
|
||||||
|
<goals>
|
||||||
|
<goal>compile</goal>
|
||||||
|
</goals>
|
||||||
|
</execution>
|
||||||
|
</executions>
|
||||||
|
<configuration>
|
||||||
|
<protocArtifact>
|
||||||
|
com.google.protobuf:protoc:${protobuf.version}:exe:${os.detected.classifier}
|
||||||
|
</protocArtifact>
|
||||||
|
</configuration>
|
||||||
|
</plugin>
|
||||||
|
</plugins>
|
||||||
|
|
||||||
|
<finalName>${project.artifactId}</finalName>
|
||||||
|
</build>
|
||||||
|
|
||||||
|
</project>
|
@ -0,0 +1,41 @@
|
|||||||
|
package betterproto;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
|
||||||
|
public class CompatibilityTest {
|
||||||
|
public static void main(String[] args) throws IOException {
|
||||||
|
if (args.length < 2)
|
||||||
|
throw new RuntimeException("Attempted to run without the required arguments.");
|
||||||
|
else if (args.length > 2)
|
||||||
|
throw new RuntimeException(
|
||||||
|
"Attempted to run with more than the expected number of arguments (>1).");
|
||||||
|
|
||||||
|
Tests tests = new Tests(args[1]);
|
||||||
|
|
||||||
|
switch (args[0]) {
|
||||||
|
case "single_varint":
|
||||||
|
tests.testSingleVarint();
|
||||||
|
break;
|
||||||
|
|
||||||
|
case "multiple_varints":
|
||||||
|
tests.testMultipleVarints();
|
||||||
|
break;
|
||||||
|
|
||||||
|
case "single_message":
|
||||||
|
tests.testSingleMessage();
|
||||||
|
break;
|
||||||
|
|
||||||
|
case "multiple_messages":
|
||||||
|
tests.testMultipleMessages();
|
||||||
|
break;
|
||||||
|
|
||||||
|
case "infinite_messages":
|
||||||
|
tests.testInfiniteMessages();
|
||||||
|
break;
|
||||||
|
|
||||||
|
default:
|
||||||
|
throw new RuntimeException(
|
||||||
|
"Attempted to run with unknown argument '" + args[0] + "'.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
115
tests/streams/java/src/main/java/betterproto/Tests.java
Normal file
115
tests/streams/java/src/main/java/betterproto/Tests.java
Normal file
@ -0,0 +1,115 @@
|
|||||||
|
package betterproto;
|
||||||
|
|
||||||
|
import betterproto.nested.NestedOuterClass;
|
||||||
|
import betterproto.oneof.Oneof;
|
||||||
|
|
||||||
|
import com.google.protobuf.CodedInputStream;
|
||||||
|
import com.google.protobuf.CodedOutputStream;
|
||||||
|
|
||||||
|
import java.io.FileInputStream;
|
||||||
|
import java.io.FileOutputStream;
|
||||||
|
import java.io.IOException;
|
||||||
|
|
||||||
|
public class Tests {
|
||||||
|
String path;
|
||||||
|
|
||||||
|
public Tests(String path) {
|
||||||
|
this.path = path;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testSingleVarint() throws IOException {
|
||||||
|
// Read in the Python-generated single varint file
|
||||||
|
FileInputStream inputStream = new FileInputStream(path + "/py_single_varint.out");
|
||||||
|
CodedInputStream codedInput = CodedInputStream.newInstance(inputStream);
|
||||||
|
|
||||||
|
int value = codedInput.readUInt32();
|
||||||
|
|
||||||
|
inputStream.close();
|
||||||
|
|
||||||
|
// Write the value back to a file
|
||||||
|
FileOutputStream outputStream = new FileOutputStream(path + "/java_single_varint.out");
|
||||||
|
CodedOutputStream codedOutput = CodedOutputStream.newInstance(outputStream);
|
||||||
|
|
||||||
|
codedOutput.writeUInt32NoTag(value);
|
||||||
|
|
||||||
|
codedOutput.flush();
|
||||||
|
outputStream.close();
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testMultipleVarints() throws IOException {
|
||||||
|
// Read in the Python-generated multiple varints file
|
||||||
|
FileInputStream inputStream = new FileInputStream(path + "/py_multiple_varints.out");
|
||||||
|
CodedInputStream codedInput = CodedInputStream.newInstance(inputStream);
|
||||||
|
|
||||||
|
int value1 = codedInput.readUInt32();
|
||||||
|
int value2 = codedInput.readUInt32();
|
||||||
|
long value3 = codedInput.readUInt64();
|
||||||
|
|
||||||
|
inputStream.close();
|
||||||
|
|
||||||
|
// Write the values back to a file
|
||||||
|
FileOutputStream outputStream = new FileOutputStream(path + "/java_multiple_varints.out");
|
||||||
|
CodedOutputStream codedOutput = CodedOutputStream.newInstance(outputStream);
|
||||||
|
|
||||||
|
codedOutput.writeUInt32NoTag(value1);
|
||||||
|
codedOutput.writeUInt64NoTag(value2);
|
||||||
|
codedOutput.writeUInt64NoTag(value3);
|
||||||
|
|
||||||
|
codedOutput.flush();
|
||||||
|
outputStream.close();
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testSingleMessage() throws IOException {
|
||||||
|
// Read in the Python-generated single message file
|
||||||
|
FileInputStream inputStream = new FileInputStream(path + "/py_single_message.out");
|
||||||
|
CodedInputStream codedInput = CodedInputStream.newInstance(inputStream);
|
||||||
|
|
||||||
|
Oneof.Test message = Oneof.Test.parseFrom(codedInput);
|
||||||
|
|
||||||
|
inputStream.close();
|
||||||
|
|
||||||
|
// Write the message back to a file
|
||||||
|
FileOutputStream outputStream = new FileOutputStream(path + "/java_single_message.out");
|
||||||
|
CodedOutputStream codedOutput = CodedOutputStream.newInstance(outputStream);
|
||||||
|
|
||||||
|
message.writeTo(codedOutput);
|
||||||
|
|
||||||
|
codedOutput.flush();
|
||||||
|
outputStream.close();
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testMultipleMessages() throws IOException {
|
||||||
|
// Read in the Python-generated multi-message file
|
||||||
|
FileInputStream inputStream = new FileInputStream(path + "/py_multiple_messages.out");
|
||||||
|
|
||||||
|
Oneof.Test oneof = Oneof.Test.parseDelimitedFrom(inputStream);
|
||||||
|
NestedOuterClass.Test nested = NestedOuterClass.Test.parseDelimitedFrom(inputStream);
|
||||||
|
|
||||||
|
inputStream.close();
|
||||||
|
|
||||||
|
// Write the messages back to a file
|
||||||
|
FileOutputStream outputStream = new FileOutputStream(path + "/java_multiple_messages.out");
|
||||||
|
|
||||||
|
oneof.writeDelimitedTo(outputStream);
|
||||||
|
nested.writeDelimitedTo(outputStream);
|
||||||
|
|
||||||
|
outputStream.flush();
|
||||||
|
outputStream.close();
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testInfiniteMessages() throws IOException {
|
||||||
|
// Read in as many messages as are present in the Python-generated file and write them back
|
||||||
|
FileInputStream inputStream = new FileInputStream(path + "/py_infinite_messages.out");
|
||||||
|
FileOutputStream outputStream = new FileOutputStream(path + "/java_infinite_messages.out");
|
||||||
|
|
||||||
|
Oneof.Test current = Oneof.Test.parseDelimitedFrom(inputStream);
|
||||||
|
while (current != null) {
|
||||||
|
current.writeDelimitedTo(outputStream);
|
||||||
|
current = Oneof.Test.parseDelimitedFrom(inputStream);
|
||||||
|
}
|
||||||
|
|
||||||
|
inputStream.close();
|
||||||
|
outputStream.flush();
|
||||||
|
outputStream.close();
|
||||||
|
}
|
||||||
|
}
|
27
tests/streams/java/src/main/proto/betterproto/nested.proto
Normal file
27
tests/streams/java/src/main/proto/betterproto/nested.proto
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
package nested;
|
||||||
|
option java_package = "betterproto.nested";
|
||||||
|
|
||||||
|
// A test message with a nested message inside of it.
|
||||||
|
message Test {
|
||||||
|
// This is the nested type.
|
||||||
|
message Nested {
|
||||||
|
// Stores a simple counter.
|
||||||
|
int32 count = 1;
|
||||||
|
}
|
||||||
|
// This is the nested enum.
|
||||||
|
enum Msg {
|
||||||
|
NONE = 0;
|
||||||
|
THIS = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
Nested nested = 1;
|
||||||
|
Sibling sibling = 2;
|
||||||
|
Sibling sibling2 = 3;
|
||||||
|
Msg msg = 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
message Sibling {
|
||||||
|
int32 foo = 1;
|
||||||
|
}
|
19
tests/streams/java/src/main/proto/betterproto/oneof.proto
Normal file
19
tests/streams/java/src/main/proto/betterproto/oneof.proto
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
package oneof;
|
||||||
|
option java_package = "betterproto.oneof";
|
||||||
|
|
||||||
|
message Test {
|
||||||
|
oneof foo {
|
||||||
|
int32 pitied = 1;
|
||||||
|
string pitier = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
int32 just_a_regular_field = 3;
|
||||||
|
|
||||||
|
oneof bar {
|
||||||
|
int32 drinks = 11;
|
||||||
|
string bar_name = 12;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
79
tests/test_enum.py
Normal file
79
tests/test_enum.py
Normal file
@ -0,0 +1,79 @@
|
|||||||
|
from typing import (
|
||||||
|
Optional,
|
||||||
|
Tuple,
|
||||||
|
)
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import betterproto
|
||||||
|
|
||||||
|
|
||||||
|
class Colour(betterproto.Enum):
|
||||||
|
RED = 1
|
||||||
|
GREEN = 2
|
||||||
|
BLUE = 3
|
||||||
|
|
||||||
|
|
||||||
|
PURPLE = Colour.__new__(Colour, name=None, value=4)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"member, str_value",
|
||||||
|
[
|
||||||
|
(Colour.RED, "RED"),
|
||||||
|
(Colour.GREEN, "GREEN"),
|
||||||
|
(Colour.BLUE, "BLUE"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_str(member: Colour, str_value: str) -> None:
|
||||||
|
assert str(member) == str_value
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"member, repr_value",
|
||||||
|
[
|
||||||
|
(Colour.RED, "Colour.RED"),
|
||||||
|
(Colour.GREEN, "Colour.GREEN"),
|
||||||
|
(Colour.BLUE, "Colour.BLUE"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_repr(member: Colour, repr_value: str) -> None:
|
||||||
|
assert repr(member) == repr_value
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"member, values",
|
||||||
|
[
|
||||||
|
(Colour.RED, ("RED", 1)),
|
||||||
|
(Colour.GREEN, ("GREEN", 2)),
|
||||||
|
(Colour.BLUE, ("BLUE", 3)),
|
||||||
|
(PURPLE, (None, 4)),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_name_values(member: Colour, values: Tuple[Optional[str], int]) -> None:
|
||||||
|
assert (member.name, member.value) == values
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"member, input_str",
|
||||||
|
[
|
||||||
|
(Colour.RED, "RED"),
|
||||||
|
(Colour.GREEN, "GREEN"),
|
||||||
|
(Colour.BLUE, "BLUE"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_from_string(member: Colour, input_str: str) -> None:
|
||||||
|
assert Colour.from_string(input_str) == member
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"member, input_int",
|
||||||
|
[
|
||||||
|
(Colour.RED, 1),
|
||||||
|
(Colour.GREEN, 2),
|
||||||
|
(Colour.BLUE, 3),
|
||||||
|
(PURPLE, 4),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_try_value(member: Colour, input_int: int) -> None:
|
||||||
|
assert Colour.try_value(input_int) == member
|
@ -545,47 +545,6 @@ def test_oneof_default_value_set_causes_writes_wire():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_recursive_message():
|
|
||||||
from tests.output_betterproto.recursivemessage import Test as RecursiveMessage
|
|
||||||
|
|
||||||
msg = RecursiveMessage()
|
|
||||||
|
|
||||||
assert msg.child == RecursiveMessage()
|
|
||||||
|
|
||||||
# Lazily-created zero-value children must not affect equality.
|
|
||||||
assert msg == RecursiveMessage()
|
|
||||||
|
|
||||||
# Lazily-created zero-value children must not affect serialization.
|
|
||||||
assert bytes(msg) == b""
|
|
||||||
|
|
||||||
|
|
||||||
def test_recursive_message_defaults():
|
|
||||||
from tests.output_betterproto.recursivemessage import (
|
|
||||||
Intermediate,
|
|
||||||
Test as RecursiveMessage,
|
|
||||||
)
|
|
||||||
|
|
||||||
msg = RecursiveMessage(name="bob", intermediate=Intermediate(42))
|
|
||||||
|
|
||||||
# set values are as expected
|
|
||||||
assert msg == RecursiveMessage(name="bob", intermediate=Intermediate(42))
|
|
||||||
|
|
||||||
# lazy initialized works modifies the message
|
|
||||||
assert msg != RecursiveMessage(
|
|
||||||
name="bob", intermediate=Intermediate(42), child=RecursiveMessage(name="jude")
|
|
||||||
)
|
|
||||||
msg.child.child.name = "jude"
|
|
||||||
assert msg == RecursiveMessage(
|
|
||||||
name="bob",
|
|
||||||
intermediate=Intermediate(42),
|
|
||||||
child=RecursiveMessage(child=RecursiveMessage(name="jude")),
|
|
||||||
)
|
|
||||||
|
|
||||||
# lazily initialization recurses as needed
|
|
||||||
assert msg.child.child.child.child.child.child.child == RecursiveMessage()
|
|
||||||
assert msg.intermediate.child.intermediate == Intermediate()
|
|
||||||
|
|
||||||
|
|
||||||
def test_message_repr():
|
def test_message_repr():
|
||||||
from tests.output_betterproto.recursivemessage import Test
|
from tests.output_betterproto.recursivemessage import Test
|
||||||
|
|
||||||
@ -699,25 +658,6 @@ def test_service_argument__expected_parameter():
|
|||||||
assert do_thing_request_parameter.annotation == "DoThingRequest"
|
assert do_thing_request_parameter.annotation == "DoThingRequest"
|
||||||
|
|
||||||
|
|
||||||
def test_copyability():
|
|
||||||
@dataclass
|
|
||||||
class Spam(betterproto.Message):
|
|
||||||
foo: bool = betterproto.bool_field(1)
|
|
||||||
bar: int = betterproto.int32_field(2)
|
|
||||||
baz: List[str] = betterproto.string_field(3)
|
|
||||||
|
|
||||||
spam = Spam(bar=12, baz=["hello"])
|
|
||||||
copied = copy(spam)
|
|
||||||
assert spam == copied
|
|
||||||
assert spam is not copied
|
|
||||||
assert spam.baz is copied.baz
|
|
||||||
|
|
||||||
deepcopied = deepcopy(spam)
|
|
||||||
assert spam == deepcopied
|
|
||||||
assert spam is not deepcopied
|
|
||||||
assert spam.baz is not deepcopied.baz
|
|
||||||
|
|
||||||
|
|
||||||
def test_is_set():
|
def test_is_set():
|
||||||
@dataclass
|
@dataclass
|
||||||
class Spam(betterproto.Message):
|
class Spam(betterproto.Message):
|
||||||
|
203
tests/test_pickling.py
Normal file
203
tests/test_pickling.py
Normal file
@ -0,0 +1,203 @@
|
|||||||
|
import pickle
|
||||||
|
from copy import (
|
||||||
|
copy,
|
||||||
|
deepcopy,
|
||||||
|
)
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import (
|
||||||
|
Dict,
|
||||||
|
List,
|
||||||
|
)
|
||||||
|
from unittest.mock import ANY
|
||||||
|
|
||||||
|
import cachelib
|
||||||
|
|
||||||
|
import betterproto
|
||||||
|
from betterproto.lib.google import protobuf as google
|
||||||
|
|
||||||
|
|
||||||
|
def unpickled(message):
|
||||||
|
return pickle.loads(pickle.dumps(message))
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(eq=False, repr=False)
|
||||||
|
class Fe(betterproto.Message):
|
||||||
|
abc: str = betterproto.string_field(1)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(eq=False, repr=False)
|
||||||
|
class Fi(betterproto.Message):
|
||||||
|
abc: str = betterproto.string_field(1)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(eq=False, repr=False)
|
||||||
|
class Fo(betterproto.Message):
|
||||||
|
abc: str = betterproto.string_field(1)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(eq=False, repr=False)
|
||||||
|
class NestedData(betterproto.Message):
|
||||||
|
struct_foo: Dict[str, "google.Struct"] = betterproto.map_field(
|
||||||
|
1, betterproto.TYPE_STRING, betterproto.TYPE_MESSAGE
|
||||||
|
)
|
||||||
|
map_str_any_bar: Dict[str, "google.Any"] = betterproto.map_field(
|
||||||
|
2, betterproto.TYPE_STRING, betterproto.TYPE_MESSAGE
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(eq=False, repr=False)
|
||||||
|
class Complex(betterproto.Message):
|
||||||
|
foo_str: str = betterproto.string_field(1)
|
||||||
|
fe: "Fe" = betterproto.message_field(3, group="grp")
|
||||||
|
fi: "Fi" = betterproto.message_field(4, group="grp")
|
||||||
|
fo: "Fo" = betterproto.message_field(5, group="grp")
|
||||||
|
nested_data: "NestedData" = betterproto.message_field(6)
|
||||||
|
mapping: Dict[str, "google.Any"] = betterproto.map_field(
|
||||||
|
7, betterproto.TYPE_STRING, betterproto.TYPE_MESSAGE
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def complex_msg():
|
||||||
|
return Complex(
|
||||||
|
foo_str="yep",
|
||||||
|
fe=Fe(abc="1"),
|
||||||
|
nested_data=NestedData(
|
||||||
|
struct_foo={
|
||||||
|
"foo": google.Struct(
|
||||||
|
fields={
|
||||||
|
"hello": google.Value(
|
||||||
|
list_value=google.ListValue(
|
||||||
|
values=[google.Value(string_value="world")]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
),
|
||||||
|
},
|
||||||
|
map_str_any_bar={
|
||||||
|
"key": google.Any(value=b"value"),
|
||||||
|
},
|
||||||
|
),
|
||||||
|
mapping={
|
||||||
|
"message": google.Any(value=bytes(Fi(abc="hi"))),
|
||||||
|
"string": google.Any(value=b"howdy"),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_pickling_complex_message():
|
||||||
|
msg = complex_msg()
|
||||||
|
deser = unpickled(msg)
|
||||||
|
assert msg == deser
|
||||||
|
assert msg.fe.abc == "1"
|
||||||
|
assert msg.is_set("fi") is not True
|
||||||
|
assert msg.mapping["message"] == google.Any(value=bytes(Fi(abc="hi")))
|
||||||
|
assert msg.mapping["string"].value.decode() == "howdy"
|
||||||
|
assert (
|
||||||
|
msg.nested_data.struct_foo["foo"]
|
||||||
|
.fields["hello"]
|
||||||
|
.list_value.values[0]
|
||||||
|
.string_value
|
||||||
|
== "world"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_recursive_message():
|
||||||
|
from tests.output_betterproto.recursivemessage import Test as RecursiveMessage
|
||||||
|
|
||||||
|
msg = RecursiveMessage()
|
||||||
|
msg = unpickled(msg)
|
||||||
|
|
||||||
|
assert msg.child == RecursiveMessage()
|
||||||
|
|
||||||
|
# Lazily-created zero-value children must not affect equality.
|
||||||
|
assert msg == RecursiveMessage()
|
||||||
|
|
||||||
|
# Lazily-created zero-value children must not affect serialization.
|
||||||
|
assert bytes(msg) == b""
|
||||||
|
|
||||||
|
|
||||||
|
def test_recursive_message_defaults():
|
||||||
|
from tests.output_betterproto.recursivemessage import (
|
||||||
|
Intermediate,
|
||||||
|
Test as RecursiveMessage,
|
||||||
|
)
|
||||||
|
|
||||||
|
msg = RecursiveMessage(name="bob", intermediate=Intermediate(42))
|
||||||
|
msg = unpickled(msg)
|
||||||
|
|
||||||
|
# set values are as expected
|
||||||
|
assert msg == RecursiveMessage(name="bob", intermediate=Intermediate(42))
|
||||||
|
|
||||||
|
# lazy initialized works modifies the message
|
||||||
|
assert msg != RecursiveMessage(
|
||||||
|
name="bob", intermediate=Intermediate(42), child=RecursiveMessage(name="jude")
|
||||||
|
)
|
||||||
|
msg.child.child.name = "jude"
|
||||||
|
assert msg == RecursiveMessage(
|
||||||
|
name="bob",
|
||||||
|
intermediate=Intermediate(42),
|
||||||
|
child=RecursiveMessage(child=RecursiveMessage(name="jude")),
|
||||||
|
)
|
||||||
|
|
||||||
|
# lazily initialization recurses as needed
|
||||||
|
assert msg.child.child.child.child.child.child.child == RecursiveMessage()
|
||||||
|
assert msg.intermediate.child.intermediate == Intermediate()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PickledMessage(betterproto.Message):
|
||||||
|
foo: bool = betterproto.bool_field(1)
|
||||||
|
bar: int = betterproto.int32_field(2)
|
||||||
|
baz: List[str] = betterproto.string_field(3)
|
||||||
|
|
||||||
|
|
||||||
|
def test_copyability():
|
||||||
|
msg = PickledMessage(bar=12, baz=["hello"])
|
||||||
|
msg = unpickled(msg)
|
||||||
|
|
||||||
|
copied = copy(msg)
|
||||||
|
assert msg == copied
|
||||||
|
assert msg is not copied
|
||||||
|
assert msg.baz is copied.baz
|
||||||
|
|
||||||
|
deepcopied = deepcopy(msg)
|
||||||
|
assert msg == deepcopied
|
||||||
|
assert msg is not deepcopied
|
||||||
|
assert msg.baz is not deepcopied.baz
|
||||||
|
|
||||||
|
|
||||||
|
def test_message_can_be_cached():
|
||||||
|
"""Cachelib uses pickling to cache values"""
|
||||||
|
|
||||||
|
cache = cachelib.SimpleCache()
|
||||||
|
|
||||||
|
def use_cache():
|
||||||
|
calls = getattr(use_cache, "calls", 0)
|
||||||
|
result = cache.get("message")
|
||||||
|
if result is not None:
|
||||||
|
return result
|
||||||
|
else:
|
||||||
|
setattr(use_cache, "calls", calls + 1)
|
||||||
|
result = complex_msg()
|
||||||
|
cache.set("message", result)
|
||||||
|
return result
|
||||||
|
|
||||||
|
for n in range(10):
|
||||||
|
if n == 0:
|
||||||
|
assert not cache.has("message")
|
||||||
|
else:
|
||||||
|
assert cache.has("message")
|
||||||
|
|
||||||
|
msg = use_cache()
|
||||||
|
assert use_cache.calls == 1 # The message is only ever built once
|
||||||
|
assert msg.fe.abc == "1"
|
||||||
|
assert msg.is_set("fi") is not True
|
||||||
|
assert msg.mapping["message"] == google.Any(value=bytes(Fi(abc="hi")))
|
||||||
|
assert msg.mapping["string"].value.decode() == "howdy"
|
||||||
|
assert (
|
||||||
|
msg.nested_data.struct_foo["foo"]
|
||||||
|
.fields["hello"]
|
||||||
|
.list_value.values[0]
|
||||||
|
.string_value
|
||||||
|
== "world"
|
||||||
|
)
|
@ -1,6 +1,8 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from shutil import which
|
||||||
|
from subprocess import run
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@ -40,6 +42,8 @@ map_example = map.Test().from_dict({"counts": {"blah": 1, "Blah2": 2}})
|
|||||||
|
|
||||||
streams_path = Path("tests/streams/")
|
streams_path = Path("tests/streams/")
|
||||||
|
|
||||||
|
java = which("java")
|
||||||
|
|
||||||
|
|
||||||
def test_load_varint_too_long():
|
def test_load_varint_too_long():
|
||||||
with BytesIO(
|
with BytesIO(
|
||||||
@ -127,6 +131,18 @@ def test_message_dump_file_multiple(tmp_path):
|
|||||||
assert test_stream.read() == exp_stream.read()
|
assert test_stream.read() == exp_stream.read()
|
||||||
|
|
||||||
|
|
||||||
|
def test_message_dump_delimited(tmp_path):
|
||||||
|
with open(tmp_path / "message_dump_delimited.out", "wb") as stream:
|
||||||
|
oneof_example.dump(stream, betterproto.SIZE_DELIMITED)
|
||||||
|
oneof_example.dump(stream, betterproto.SIZE_DELIMITED)
|
||||||
|
nested_example.dump(stream, betterproto.SIZE_DELIMITED)
|
||||||
|
|
||||||
|
with open(tmp_path / "message_dump_delimited.out", "rb") as test_stream, open(
|
||||||
|
streams_path / "delimited_messages.in", "rb"
|
||||||
|
) as exp_stream:
|
||||||
|
assert test_stream.read() == exp_stream.read()
|
||||||
|
|
||||||
|
|
||||||
def test_message_len():
|
def test_message_len():
|
||||||
assert len_oneof == len(bytes(oneof_example))
|
assert len_oneof == len(bytes(oneof_example))
|
||||||
assert len(nested_example) == len(bytes(nested_example))
|
assert len(nested_example) == len(bytes(nested_example))
|
||||||
@ -155,7 +171,15 @@ def test_message_load_too_small():
|
|||||||
oneof.Test().load(stream, len_oneof - 1)
|
oneof.Test().load(stream, len_oneof - 1)
|
||||||
|
|
||||||
|
|
||||||
def test_message_too_large():
|
def test_message_load_delimited():
|
||||||
|
with open(streams_path / "delimited_messages.in", "rb") as stream:
|
||||||
|
assert oneof.Test().load(stream, betterproto.SIZE_DELIMITED) == oneof_example
|
||||||
|
assert oneof.Test().load(stream, betterproto.SIZE_DELIMITED) == oneof_example
|
||||||
|
assert nested.Test().load(stream, betterproto.SIZE_DELIMITED) == nested_example
|
||||||
|
assert stream.read(1) == b""
|
||||||
|
|
||||||
|
|
||||||
|
def test_message_load_too_large():
|
||||||
with open(
|
with open(
|
||||||
streams_path / "message_dump_file_single.expected", "rb"
|
streams_path / "message_dump_file_single.expected", "rb"
|
||||||
) as stream, pytest.raises(ValueError):
|
) as stream, pytest.raises(ValueError):
|
||||||
@ -266,3 +290,145 @@ def test_dump_varint_positive(tmp_path):
|
|||||||
streams_path / "dump_varint_positive.expected", "rb"
|
streams_path / "dump_varint_positive.expected", "rb"
|
||||||
) as exp_stream:
|
) as exp_stream:
|
||||||
assert test_stream.read() == exp_stream.read()
|
assert test_stream.read() == exp_stream.read()
|
||||||
|
|
||||||
|
|
||||||
|
# Java compatibility tests
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def compile_jar():
|
||||||
|
# Skip if not all required tools are present
|
||||||
|
if java is None:
|
||||||
|
pytest.skip("`java` command is absent and is required")
|
||||||
|
mvn = which("mvn")
|
||||||
|
if mvn is None:
|
||||||
|
pytest.skip("Maven is absent and is required")
|
||||||
|
|
||||||
|
# Compile the JAR
|
||||||
|
proc_maven = run([mvn, "clean", "install", "-f", "tests/streams/java/pom.xml"])
|
||||||
|
if proc_maven.returncode != 0:
|
||||||
|
pytest.skip(
|
||||||
|
"Maven compatibility-test.jar build failed (maybe Java version <11?)"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
jar = "tests/streams/java/target/compatibility-test.jar"
|
||||||
|
|
||||||
|
|
||||||
|
def run_jar(command: str, tmp_path):
|
||||||
|
return run([java, "-jar", jar, command, tmp_path], check=True)
|
||||||
|
|
||||||
|
|
||||||
|
def run_java_single_varint(value: int, tmp_path) -> int:
|
||||||
|
# Write single varint to file
|
||||||
|
with open(tmp_path / "py_single_varint.out", "wb") as stream:
|
||||||
|
betterproto.dump_varint(value, stream)
|
||||||
|
|
||||||
|
# Have Java read this varint and write it back
|
||||||
|
run_jar("single_varint", tmp_path)
|
||||||
|
|
||||||
|
# Read single varint from Java output file
|
||||||
|
with open(tmp_path / "java_single_varint.out", "rb") as stream:
|
||||||
|
returned = betterproto.load_varint(stream)
|
||||||
|
with pytest.raises(EOFError):
|
||||||
|
betterproto.load_varint(stream)
|
||||||
|
|
||||||
|
return returned
|
||||||
|
|
||||||
|
|
||||||
|
def test_single_varint(compile_jar, tmp_path):
|
||||||
|
single_byte = (1, b"\x01")
|
||||||
|
multi_byte = (123456789, b"\x95\x9A\xEF\x3A")
|
||||||
|
|
||||||
|
# Write a single-byte varint to a file and have Java read it back
|
||||||
|
returned = run_java_single_varint(single_byte[0], tmp_path)
|
||||||
|
assert returned == single_byte
|
||||||
|
|
||||||
|
# Same for a multi-byte varint
|
||||||
|
returned = run_java_single_varint(multi_byte[0], tmp_path)
|
||||||
|
assert returned == multi_byte
|
||||||
|
|
||||||
|
|
||||||
|
def test_multiple_varints(compile_jar, tmp_path):
|
||||||
|
single_byte = (1, b"\x01")
|
||||||
|
multi_byte = (123456789, b"\x95\x9A\xEF\x3A")
|
||||||
|
over32 = (3000000000, b"\x80\xBC\xC1\x96\x0B")
|
||||||
|
|
||||||
|
# Write two varints to the same file
|
||||||
|
with open(tmp_path / "py_multiple_varints.out", "wb") as stream:
|
||||||
|
betterproto.dump_varint(single_byte[0], stream)
|
||||||
|
betterproto.dump_varint(multi_byte[0], stream)
|
||||||
|
betterproto.dump_varint(over32[0], stream)
|
||||||
|
|
||||||
|
# Have Java read these varints and write them back
|
||||||
|
run_jar("multiple_varints", tmp_path)
|
||||||
|
|
||||||
|
# Read varints from Java output file
|
||||||
|
with open(tmp_path / "java_multiple_varints.out", "rb") as stream:
|
||||||
|
returned_single = betterproto.load_varint(stream)
|
||||||
|
returned_multi = betterproto.load_varint(stream)
|
||||||
|
returned_over32 = betterproto.load_varint(stream)
|
||||||
|
with pytest.raises(EOFError):
|
||||||
|
betterproto.load_varint(stream)
|
||||||
|
|
||||||
|
assert returned_single == single_byte
|
||||||
|
assert returned_multi == multi_byte
|
||||||
|
assert returned_over32 == over32
|
||||||
|
|
||||||
|
|
||||||
|
def test_single_message(compile_jar, tmp_path):
|
||||||
|
# Write message to file
|
||||||
|
with open(tmp_path / "py_single_message.out", "wb") as stream:
|
||||||
|
oneof_example.dump(stream)
|
||||||
|
|
||||||
|
# Have Java read and return the message
|
||||||
|
run_jar("single_message", tmp_path)
|
||||||
|
|
||||||
|
# Read and check the returned message
|
||||||
|
with open(tmp_path / "java_single_message.out", "rb") as stream:
|
||||||
|
returned = oneof.Test().load(stream, len(bytes(oneof_example)))
|
||||||
|
assert stream.read() == b""
|
||||||
|
|
||||||
|
assert returned == oneof_example
|
||||||
|
|
||||||
|
|
||||||
|
def test_multiple_messages(compile_jar, tmp_path):
|
||||||
|
# Write delimited messages to file
|
||||||
|
with open(tmp_path / "py_multiple_messages.out", "wb") as stream:
|
||||||
|
oneof_example.dump(stream, betterproto.SIZE_DELIMITED)
|
||||||
|
nested_example.dump(stream, betterproto.SIZE_DELIMITED)
|
||||||
|
|
||||||
|
# Have Java read and return the messages
|
||||||
|
run_jar("multiple_messages", tmp_path)
|
||||||
|
|
||||||
|
# Read and check the returned messages
|
||||||
|
with open(tmp_path / "java_multiple_messages.out", "rb") as stream:
|
||||||
|
returned_oneof = oneof.Test().load(stream, betterproto.SIZE_DELIMITED)
|
||||||
|
returned_nested = nested.Test().load(stream, betterproto.SIZE_DELIMITED)
|
||||||
|
assert stream.read() == b""
|
||||||
|
|
||||||
|
assert returned_oneof == oneof_example
|
||||||
|
assert returned_nested == nested_example
|
||||||
|
|
||||||
|
|
||||||
|
def test_infinite_messages(compile_jar, tmp_path):
|
||||||
|
num_messages = 5
|
||||||
|
|
||||||
|
# Write delimited messages to file
|
||||||
|
with open(tmp_path / "py_infinite_messages.out", "wb") as stream:
|
||||||
|
for x in range(num_messages):
|
||||||
|
oneof_example.dump(stream, betterproto.SIZE_DELIMITED)
|
||||||
|
|
||||||
|
# Have Java read and return the messages
|
||||||
|
run_jar("infinite_messages", tmp_path)
|
||||||
|
|
||||||
|
# Read and check the returned messages
|
||||||
|
messages = []
|
||||||
|
with open(tmp_path / "java_infinite_messages.out", "rb") as stream:
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
messages.append(oneof.Test().load(stream, betterproto.SIZE_DELIMITED))
|
||||||
|
except EOFError:
|
||||||
|
break
|
||||||
|
|
||||||
|
assert len(messages) == num_messages
|
||||||
|
27
tests/test_timestamp.py
Normal file
27
tests/test_timestamp.py
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
from datetime import (
|
||||||
|
datetime,
|
||||||
|
timezone,
|
||||||
|
)
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from betterproto import _Timestamp
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"dt",
|
||||||
|
[
|
||||||
|
datetime(2023, 10, 11, 9, 41, 12, tzinfo=timezone.utc),
|
||||||
|
datetime.now(timezone.utc),
|
||||||
|
# potential issue with floating point precision:
|
||||||
|
datetime(2242, 12, 31, 23, 0, 0, 1, tzinfo=timezone.utc),
|
||||||
|
# potential issue with negative timestamps:
|
||||||
|
datetime(1969, 12, 31, 23, 0, 0, 1, tzinfo=timezone.utc),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_timestamp_to_datetime_and_back(dt: datetime):
|
||||||
|
"""
|
||||||
|
Make sure converting a datetime to a protobuf timestamp message
|
||||||
|
and then back again ends up with the same datetime.
|
||||||
|
"""
|
||||||
|
assert _Timestamp.from_datetime(dt).to_datetime() == dt
|
Loading…
x
Reference in New Issue
Block a user