Merge branch 'master_gh'
# Conflicts: # src/betterproto/__init__.py
This commit is contained in:
commit
1d296f1a88
63
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
Normal file
63
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
Normal file
@ -0,0 +1,63 @@
|
||||
name: Bug Report
|
||||
description: Report broken or incorrect behaviour
|
||||
labels: ["bug", "investigation needed"]
|
||||
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: >
|
||||
Thanks for taking the time to fill out a bug report!
|
||||
|
||||
If you're not sure it's a bug and you just have a question, the [community Slack channel](https://join.slack.com/t/betterproto/shared_invite/zt-f0n0uolx-iN8gBNrkPxtKHTLpG3o1OQ) is a better place for general questions than a GitHub issue.
|
||||
|
||||
- type: input
|
||||
attributes:
|
||||
label: Summary
|
||||
description: A simple summary of your bug report
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Reproduction Steps
|
||||
description: >
|
||||
What you did to make it happen.
|
||||
Ideally there should be a short code snippet in this section to help reproduce the bug.
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Expected Results
|
||||
description: >
|
||||
What did you expect to happen?
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Actual Results
|
||||
description: >
|
||||
What actually happened?
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: System Information
|
||||
description: >
|
||||
Paste the result of `protoc --version; python --version; pip show betterproto` below.
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: checkboxes
|
||||
attributes:
|
||||
label: Checklist
|
||||
options:
|
||||
- label: I have searched the issues for duplicates.
|
||||
required: true
|
||||
- label: I have shown the entire traceback, if possible.
|
||||
required: true
|
||||
- label: I have verified this issue occurs on the latest prelease of betterproto which can be installed using `pip install -U --pre betterproto`, if possible.
|
||||
required: true
|
||||
|
6
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
6
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
name:
|
||||
description:
|
||||
contact_links:
|
||||
- name: For questions about the library
|
||||
about: Support questions are better answered in our Slack group.
|
||||
url: https://join.slack.com/t/betterproto/shared_invite/zt-f0n0uolx-iN8gBNrkPxtKHTLpG3o1OQ
|
49
.github/ISSUE_TEMPLATE/feature_request.yml
vendored
Normal file
49
.github/ISSUE_TEMPLATE/feature_request.yml
vendored
Normal file
@ -0,0 +1,49 @@
|
||||
name: Feature Request
|
||||
description: Suggest a feature for this library
|
||||
labels: ["enhancement"]
|
||||
|
||||
body:
|
||||
- type: input
|
||||
attributes:
|
||||
label: Summary
|
||||
description: >
|
||||
What problem is your feature trying to solve? What would become easier or possible if feature was implemented?
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: dropdown
|
||||
attributes:
|
||||
multiple: false
|
||||
label: What is the feature request for?
|
||||
options:
|
||||
- The core library
|
||||
- RPC handling
|
||||
- The documentation
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: The Problem
|
||||
description: >
|
||||
What problem is your feature trying to solve?
|
||||
What would become easier or possible if feature was implemented?
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: The Ideal Solution
|
||||
description: >
|
||||
What is your ideal solution to the problem?
|
||||
What would you like this feature to do?
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: The Current Solution
|
||||
description: >
|
||||
What is the current solution to the problem, if any?
|
||||
validations:
|
||||
required: false
|
16
.github/PULL_REQUEST_TEMPLATE.md
vendored
Normal file
16
.github/PULL_REQUEST_TEMPLATE.md
vendored
Normal file
@ -0,0 +1,16 @@
|
||||
## Summary
|
||||
|
||||
<!-- What is this pull request for? Does it fix any issues? -->
|
||||
|
||||
## Checklist
|
||||
|
||||
<!-- Put an x inside [ ] to check it, like so: [x] -->
|
||||
|
||||
- [ ] If code changes were made then they have been tested.
|
||||
- [ ] I have updated the documentation to reflect the changes.
|
||||
- [ ] This PR fixes an issue.
|
||||
- [ ] This PR adds something new (e.g. new method or parameters).
|
||||
- [ ] This change has an associated test.
|
||||
- [ ] This PR is a breaking change (e.g. methods or parameters removed/renamed)
|
||||
- [ ] This PR is **not** a code change (e.g. documentation, README, ...)
|
||||
|
46
.github/workflows/codeql-analysis.yml
vendored
Normal file
46
.github/workflows/codeql-analysis.yml
vendored
Normal file
@ -0,0 +1,46 @@
|
||||
name: "CodeQL"
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ "master" ]
|
||||
pull_request:
|
||||
branches:
|
||||
- '**'
|
||||
schedule:
|
||||
- cron: '19 1 * * 6'
|
||||
|
||||
jobs:
|
||||
analyze:
|
||||
name: Analyze
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
actions: read
|
||||
contents: read
|
||||
security-events: write
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
language: [ 'python' ]
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v3
|
||||
|
||||
# Initializes the CodeQL tools for scanning.
|
||||
- name: Initialize CodeQL
|
||||
uses: github/codeql-action/init@v2
|
||||
with:
|
||||
languages: ${{ matrix.language }}
|
||||
# If you wish to specify custom queries, you can do so here or in a config file.
|
||||
# By default, queries listed here will override any specified in a config file.
|
||||
# Prefix the list here with "+" to use these queries and those in the config file.
|
||||
|
||||
# Details on CodeQL's query packs refer to : https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs
|
||||
# queries: security-extended,security-and-quality
|
||||
|
||||
- name: Autobuild
|
||||
uses: github/codeql-action/autobuild@v2
|
||||
|
||||
- name: Perform CodeQL Analysis
|
||||
uses: github/codeql-action/analyze@v2
|
@ -16,6 +16,13 @@ repos:
|
||||
- repo: https://github.com/PyCQA/doc8
|
||||
rev: 0.10.1
|
||||
hooks:
|
||||
- id: doc8
|
||||
- id: doc8
|
||||
additional_dependencies:
|
||||
- toml
|
||||
|
||||
- repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks
|
||||
rev: v2.10.0
|
||||
hooks:
|
||||
- id: pretty-format-java
|
||||
args: [--autofix, --aosp]
|
||||
files: ^.*\.java$
|
||||
|
1200
poetry.lock
generated
1200
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@ -19,26 +19,29 @@ importlib-metadata = { version = ">=1.6.0", python = "<3.8" }
|
||||
jinja2 = { version = ">=3.0.3", optional = true }
|
||||
python-dateutil = "^2.8"
|
||||
isort = {version = "^5.11.5", optional = true}
|
||||
typing-extensions = "^4.7.1"
|
||||
|
||||
[tool.poetry.dev-dependencies]
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
asv = "^0.4.2"
|
||||
bpython = "^0.19"
|
||||
grpcio-tools = "^1.54.2"
|
||||
jinja2 = ">=3.0.3"
|
||||
mypy = "^0.930"
|
||||
sphinx = "3.1.2"
|
||||
sphinx-rtd-theme = "0.5.0"
|
||||
pre-commit = "^2.17.0"
|
||||
grpcio-tools = "^1.54.2"
|
||||
tox = "^4.0.0"
|
||||
|
||||
[tool.poetry.group.test.dependencies]
|
||||
poethepoet = ">=0.9.0"
|
||||
protobuf = "^4.21.6"
|
||||
pytest = "^6.2.5"
|
||||
pytest-asyncio = "^0.12.0"
|
||||
pytest-cov = "^2.9.0"
|
||||
pytest-mock = "^3.1.1"
|
||||
sphinx = "3.1.2"
|
||||
sphinx-rtd-theme = "0.5.0"
|
||||
tomlkit = "^0.7.0"
|
||||
tox = "^3.15.1"
|
||||
pre-commit = "^2.17.0"
|
||||
pydantic = ">=1.8.0"
|
||||
|
||||
pydantic = ">=1.8.0,<2"
|
||||
protobuf = "^4"
|
||||
cachelib = "^0.10.2"
|
||||
tomlkit = ">=0.7.0"
|
||||
|
||||
[tool.poetry.scripts]
|
||||
protoc-gen-python_betterproto = "betterproto.plugin:main"
|
||||
@ -61,9 +64,13 @@ help = "Run tests"
|
||||
cmd = "mypy src --ignore-missing-imports"
|
||||
help = "Check types with mypy"
|
||||
|
||||
[tool.poe.tasks]
|
||||
_black = "black . --exclude tests/output_ --target-version py310"
|
||||
_isort = "isort . --extend-skip-glob 'tests/output_*/**/*'"
|
||||
|
||||
[tool.poe.tasks.format]
|
||||
cmd = "black . --exclude tests/output_ --target-version py310"
|
||||
help = "Apply black formatting to source code"
|
||||
sequence = ["_black", "_isort"]
|
||||
help = "Apply black and isort formatting to source code"
|
||||
|
||||
[tool.poe.tasks.docs]
|
||||
cmd = "sphinx-build docs docs/build"
|
||||
@ -130,14 +137,21 @@ omit = ["betterproto/tests/*"]
|
||||
[tool.tox]
|
||||
legacy_tox_ini = """
|
||||
[tox]
|
||||
isolated_build = true
|
||||
envlist = py37, py38, py310
|
||||
requires =
|
||||
tox>=4.2
|
||||
tox-poetry-installer[poetry]==1.0.0b1
|
||||
env_list =
|
||||
py311
|
||||
py38
|
||||
py37
|
||||
|
||||
[testenv]
|
||||
whitelist_externals = poetry
|
||||
commands =
|
||||
poetry install -v --extras compiler
|
||||
poetry run pytest --cov betterproto
|
||||
pytest {posargs: --cov betterproto}
|
||||
poetry_dep_groups =
|
||||
test
|
||||
require_locked_deps = true
|
||||
require_poetry = true
|
||||
"""
|
||||
|
||||
[build-system]
|
||||
|
@ -1,5 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import enum
|
||||
import enum as builtin_enum
|
||||
import json
|
||||
import math
|
||||
import struct
|
||||
@ -22,8 +24,8 @@ from itertools import count
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
BinaryIO,
|
||||
Callable,
|
||||
ClassVar,
|
||||
Dict,
|
||||
Generator,
|
||||
Iterable,
|
||||
@ -37,6 +39,7 @@ from typing import (
|
||||
)
|
||||
|
||||
from dateutil.parser import isoparse
|
||||
from typing_extensions import Self
|
||||
|
||||
from ._types import T
|
||||
from ._version import __version__
|
||||
@ -45,11 +48,19 @@ from .casing import (
|
||||
safe_snake_case,
|
||||
snake_case,
|
||||
)
|
||||
from .grpc.grpclib_client import ServiceStub
|
||||
from .enum import Enum as Enum
|
||||
from .grpc.grpclib_client import ServiceStub as ServiceStub
|
||||
from .utils import (
|
||||
classproperty,
|
||||
hybridmethod,
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from _typeshed import ReadableBuffer
|
||||
from _typeshed import (
|
||||
SupportsRead,
|
||||
SupportsWrite,
|
||||
)
|
||||
|
||||
|
||||
# Proto 3 data types
|
||||
@ -126,6 +137,9 @@ WIRE_FIXED_32_TYPES = [TYPE_FLOAT, TYPE_FIXED32, TYPE_SFIXED32]
|
||||
WIRE_FIXED_64_TYPES = [TYPE_DOUBLE, TYPE_FIXED64, TYPE_SFIXED64]
|
||||
WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE, TYPE_MAP]
|
||||
|
||||
# Indicator of message delimitation in streams
|
||||
SIZE_DELIMITED = -1
|
||||
|
||||
|
||||
# Protobuf datetimes start at the Unix Epoch in 1970 in UTC.
|
||||
def datetime_default_gen() -> datetime:
|
||||
@ -140,7 +154,7 @@ NEG_INFINITY = "-Infinity"
|
||||
NAN = "NaN"
|
||||
|
||||
|
||||
class Casing(enum.Enum):
|
||||
class Casing(builtin_enum.Enum):
|
||||
"""Casing constants for serialization."""
|
||||
|
||||
CAMEL = camel_case #: A camelCase sterilization function.
|
||||
@ -309,32 +323,6 @@ def map_field(
|
||||
)
|
||||
|
||||
|
||||
class Enum(enum.IntEnum):
|
||||
"""
|
||||
The base class for protobuf enumerations, all generated enumerations will inherit
|
||||
from this. Bases :class:`enum.IntEnum`.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, name: str) -> "Enum":
|
||||
"""Return the value which corresponds to the string name.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
name: :class:`str`
|
||||
The name of the enum member to get
|
||||
|
||||
Raises
|
||||
-------
|
||||
:exc:`ValueError`
|
||||
The member was not found in the Enum.
|
||||
"""
|
||||
try:
|
||||
return cls._member_map_[name] # type: ignore
|
||||
except KeyError as e:
|
||||
raise ValueError(f"Unknown value {name} for enum {cls.__name__}") from e
|
||||
|
||||
|
||||
def _pack_fmt(proto_type: str) -> str:
|
||||
"""Returns a little-endian format string for reading/writing binary."""
|
||||
return {
|
||||
@ -347,7 +335,7 @@ def _pack_fmt(proto_type: str) -> str:
|
||||
}[proto_type]
|
||||
|
||||
|
||||
def dump_varint(value: int, stream: BinaryIO) -> None:
|
||||
def dump_varint(value: int, stream: "SupportsWrite[bytes]") -> None:
|
||||
"""Encodes a single varint and dumps it into the provided stream."""
|
||||
if value < -(1 << 63):
|
||||
raise ValueError(
|
||||
@ -556,7 +544,7 @@ def _dump_float(value: float) -> Union[float, str]:
|
||||
return value
|
||||
|
||||
|
||||
def load_varint(stream: BinaryIO) -> Tuple[int, bytes]:
|
||||
def load_varint(stream: "SupportsRead[bytes]") -> Tuple[int, bytes]:
|
||||
"""
|
||||
Load a single varint value from a stream. Returns the value and the raw bytes read.
|
||||
"""
|
||||
@ -594,7 +582,7 @@ class ParsedField:
|
||||
raw: bytes
|
||||
|
||||
|
||||
def load_fields(stream: BinaryIO) -> Generator[ParsedField, None, None]:
|
||||
def load_fields(stream: "SupportsRead[bytes]") -> Generator[ParsedField, None, None]:
|
||||
while True:
|
||||
try:
|
||||
num_wire, raw = load_varint(stream)
|
||||
@ -748,6 +736,7 @@ class Message(ABC):
|
||||
_serialized_on_wire: bool
|
||||
_unknown_fields: bytes
|
||||
_group_current: Dict[str, str]
|
||||
_betterproto_meta: ClassVar[ProtoClassMetadata]
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Keep track of whether every field was default
|
||||
@ -815,6 +804,10 @@ class Message(ABC):
|
||||
]
|
||||
return f"{self.__class__.__name__}({', '.join(parts)})"
|
||||
|
||||
def __rich_repr__(self) -> Iterable[Tuple[str, Any, Any]]:
|
||||
for field_name in self._betterproto.sorted_field_names:
|
||||
yield field_name, self.__raw_get(field_name), PLACEHOLDER
|
||||
|
||||
if not TYPE_CHECKING:
|
||||
|
||||
def __getattribute__(self, name: str) -> Any:
|
||||
@ -889,20 +882,28 @@ class Message(ABC):
|
||||
kwargs[name] = deepcopy(value)
|
||||
return self.__class__(**kwargs) # type: ignore
|
||||
|
||||
@property
|
||||
def _betterproto(self) -> ProtoClassMetadata:
|
||||
def __copy__(self: T, _: Any = {}) -> T:
|
||||
kwargs = {}
|
||||
for name in self._betterproto.sorted_field_names:
|
||||
value = self.__raw_get(name)
|
||||
if value is not PLACEHOLDER:
|
||||
kwargs[name] = value
|
||||
return self.__class__(**kwargs) # type: ignore
|
||||
|
||||
@classproperty
|
||||
def _betterproto(cls: type[Self]) -> ProtoClassMetadata: # type: ignore
|
||||
"""
|
||||
Lazy initialize metadata for each protobuf class.
|
||||
It may be initialized multiple times in a multi-threaded environment,
|
||||
but that won't affect the correctness.
|
||||
"""
|
||||
meta = getattr(self.__class__, "_betterproto_meta", None)
|
||||
if not meta:
|
||||
meta = ProtoClassMetadata(self.__class__)
|
||||
self.__class__._betterproto_meta = meta # type: ignore
|
||||
return meta
|
||||
try:
|
||||
return cls._betterproto_meta
|
||||
except AttributeError:
|
||||
cls._betterproto_meta = meta = ProtoClassMetadata(cls)
|
||||
return meta
|
||||
|
||||
def dump(self, stream: BinaryIO) -> None:
|
||||
def dump(self, stream: "SupportsWrite[bytes]", delimit: bool = False) -> None:
|
||||
"""
|
||||
Dumps the binary encoded Protobuf message to the stream.
|
||||
|
||||
@ -910,7 +911,11 @@ class Message(ABC):
|
||||
-----------
|
||||
stream: :class:`BinaryIO`
|
||||
The stream to dump the message to.
|
||||
delimit:
|
||||
Whether to prefix the message with a varint declaring its size.
|
||||
"""
|
||||
if delimit == SIZE_DELIMITED:
|
||||
dump_varint(len(self), stream)
|
||||
|
||||
for field_name, meta in self._betterproto.meta_by_field_name.items():
|
||||
try:
|
||||
@ -930,7 +935,7 @@ class Message(ABC):
|
||||
# Note that proto3 field presence/optional fields are put in a
|
||||
# synthetic single-item oneof by protoc, which helps us ensure we
|
||||
# send the value even if the value is the default zero value.
|
||||
selected_in_group = bool(meta.group)
|
||||
selected_in_group = bool(meta.group) or meta.optional
|
||||
|
||||
# Empty messages can still be sent on the wire if they were
|
||||
# set (or received empty).
|
||||
@ -1124,6 +1129,15 @@ class Message(ABC):
|
||||
"""
|
||||
return bytes(self)
|
||||
|
||||
def __getstate__(self) -> bytes:
|
||||
return bytes(self)
|
||||
|
||||
def __setstate__(self: T, pickled_bytes: bytes) -> T:
|
||||
return self.parse(pickled_bytes)
|
||||
|
||||
def __reduce__(self) -> Tuple[Any, ...]:
|
||||
return (self.__class__.FromString, (bytes(self),))
|
||||
|
||||
@classmethod
|
||||
def _type_hint(cls, field_name: str) -> Type:
|
||||
return cls._type_hints()[field_name]
|
||||
@ -1168,7 +1182,7 @@ class Message(ABC):
|
||||
return t
|
||||
elif issubclass(t, Enum):
|
||||
# Enums always default to zero.
|
||||
return int
|
||||
return t.try_value
|
||||
elif t is datetime:
|
||||
# Offsets are relative to 1970-01-01T00:00:00Z
|
||||
return datetime_default_gen
|
||||
@ -1193,6 +1207,9 @@ class Message(ABC):
|
||||
elif meta.proto_type == TYPE_BOOL:
|
||||
# Booleans use a varint encoding, so convert it to true/false.
|
||||
value = value > 0
|
||||
elif meta.proto_type == TYPE_ENUM:
|
||||
# Convert enum ints to python enum instances
|
||||
value = self._betterproto.cls_by_field[field_name].try_value(value)
|
||||
elif wire_type in (WIRE_FIXED_32, WIRE_FIXED_64):
|
||||
fmt = _pack_fmt(meta.proto_type)
|
||||
value = struct.unpack(fmt, value)[0]
|
||||
@ -1225,7 +1242,11 @@ class Message(ABC):
|
||||
meta.group is not None and self._group_current.get(meta.group) == field_name
|
||||
)
|
||||
|
||||
def load(self: T, stream: BinaryIO, size: Optional[int] = None) -> T:
|
||||
def load(
|
||||
self: T,
|
||||
stream: "SupportsRead[bytes]",
|
||||
size: Optional[int] = None,
|
||||
) -> T:
|
||||
"""
|
||||
Load the binary encoded Protobuf from a stream into this message instance. This
|
||||
returns the instance itself and is therefore assignable and chainable.
|
||||
@ -1237,12 +1258,17 @@ class Message(ABC):
|
||||
size: :class:`Optional[int]`
|
||||
The size of the message in the stream.
|
||||
Reads stream until EOF if ``None`` is given.
|
||||
Reads based on a size delimiter prefix varint if SIZE_DELIMITED is given.
|
||||
|
||||
Returns
|
||||
--------
|
||||
:class:`Message`
|
||||
The initialized message.
|
||||
"""
|
||||
# If the message is delimited, parse the message delimiter
|
||||
if size == SIZE_DELIMITED:
|
||||
size, _ = load_varint(stream)
|
||||
|
||||
# Got some data over the wire
|
||||
self._serialized_on_wire = True
|
||||
proto_meta = self._betterproto
|
||||
@ -1315,7 +1341,7 @@ class Message(ABC):
|
||||
|
||||
return self
|
||||
|
||||
def parse(self: T, data: "ReadableBuffer") -> T:
|
||||
def parse(self: T, data: bytes) -> T:
|
||||
"""
|
||||
Parse the binary encoded Protobuf into this message instance. This
|
||||
returns the instance itself and is therefore assignable and chainable.
|
||||
@ -1494,7 +1520,91 @@ class Message(ABC):
|
||||
output[cased_name] = value
|
||||
return output
|
||||
|
||||
def from_dict(self: T, value: Mapping[str, Any]) -> T:
|
||||
@classmethod
|
||||
def _from_dict_init(cls, mapping: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||
init_kwargs: Dict[str, Any] = {}
|
||||
for key, value in mapping.items():
|
||||
field_name = safe_snake_case(key)
|
||||
try:
|
||||
meta = cls._betterproto.meta_by_field_name[field_name]
|
||||
except KeyError:
|
||||
continue
|
||||
if value is None:
|
||||
continue
|
||||
|
||||
if meta.proto_type == TYPE_MESSAGE:
|
||||
sub_cls = cls._betterproto.cls_by_field[field_name]
|
||||
if sub_cls == datetime:
|
||||
value = (
|
||||
[isoparse(item) for item in value]
|
||||
if isinstance(value, list)
|
||||
else isoparse(value)
|
||||
)
|
||||
elif sub_cls == timedelta:
|
||||
value = (
|
||||
[timedelta(seconds=float(item[:-1])) for item in value]
|
||||
if isinstance(value, list)
|
||||
else timedelta(seconds=float(value[:-1]))
|
||||
)
|
||||
elif not meta.wraps:
|
||||
value = (
|
||||
[sub_cls.from_dict(item) for item in value]
|
||||
if isinstance(value, list)
|
||||
else sub_cls.from_dict(value)
|
||||
)
|
||||
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
|
||||
sub_cls = cls._betterproto.cls_by_field[f"{field_name}.value"]
|
||||
value = {k: sub_cls.from_dict(v) for k, v in value.items()}
|
||||
else:
|
||||
if meta.proto_type in INT_64_TYPES:
|
||||
value = (
|
||||
[int(n) for n in value]
|
||||
if isinstance(value, list)
|
||||
else int(value)
|
||||
)
|
||||
elif meta.proto_type == TYPE_BYTES:
|
||||
value = (
|
||||
[b64decode(n) for n in value]
|
||||
if isinstance(value, list)
|
||||
else b64decode(value)
|
||||
)
|
||||
elif meta.proto_type == TYPE_ENUM:
|
||||
enum_cls = cls._betterproto.cls_by_field[field_name]
|
||||
if isinstance(value, list):
|
||||
value = [enum_cls.from_string(e) for e in value]
|
||||
elif isinstance(value, str):
|
||||
value = enum_cls.from_string(value)
|
||||
elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE):
|
||||
value = (
|
||||
[_parse_float(n) for n in value]
|
||||
if isinstance(value, list)
|
||||
else _parse_float(value)
|
||||
)
|
||||
|
||||
init_kwargs[field_name] = value
|
||||
return init_kwargs
|
||||
|
||||
@hybridmethod
|
||||
def from_dict(cls: type[Self], value: Mapping[str, Any]) -> Self: # type: ignore
|
||||
"""
|
||||
Parse the key/value pairs into the a new message instance.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
value: Dict[:class:`str`, Any]
|
||||
The dictionary to parse from.
|
||||
|
||||
Returns
|
||||
--------
|
||||
:class:`Message`
|
||||
The initialized message.
|
||||
"""
|
||||
self = cls(**cls._from_dict_init(value))
|
||||
self._serialized_on_wire = True
|
||||
return self
|
||||
|
||||
@from_dict.instancemethod
|
||||
def from_dict(self, value: Mapping[str, Any]) -> Self:
|
||||
"""
|
||||
Parse the key/value pairs into the current message instance. This returns the
|
||||
instance itself and is therefore assignable and chainable.
|
||||
@ -1510,71 +1620,8 @@ class Message(ABC):
|
||||
The initialized message.
|
||||
"""
|
||||
self._serialized_on_wire = True
|
||||
for key in value:
|
||||
field_name = safe_snake_case(key)
|
||||
meta = self._betterproto.meta_by_field_name.get(field_name)
|
||||
if not meta:
|
||||
continue
|
||||
|
||||
if value[key] is not None:
|
||||
if meta.proto_type == TYPE_MESSAGE:
|
||||
v = self._get_field_default(field_name)
|
||||
cls = self._betterproto.cls_by_field[field_name]
|
||||
if isinstance(v, list):
|
||||
if cls == datetime:
|
||||
v = [isoparse(item) for item in value[key]]
|
||||
elif cls == timedelta:
|
||||
v = [
|
||||
timedelta(seconds=float(item[:-1]))
|
||||
for item in value[key]
|
||||
]
|
||||
else:
|
||||
v = [cls().from_dict(item) for item in value[key]]
|
||||
elif cls == datetime:
|
||||
v = isoparse(value[key])
|
||||
setattr(self, field_name, v)
|
||||
elif cls == timedelta:
|
||||
v = timedelta(seconds=float(value[key][:-1]))
|
||||
setattr(self, field_name, v)
|
||||
elif meta.wraps:
|
||||
setattr(self, field_name, value[key])
|
||||
elif v is None:
|
||||
setattr(self, field_name, cls().from_dict(value[key]))
|
||||
else:
|
||||
# NOTE: `from_dict` mutates the underlying message, so no
|
||||
# assignment here is necessary.
|
||||
v.from_dict(value[key])
|
||||
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
|
||||
v = getattr(self, field_name)
|
||||
cls = self._betterproto.cls_by_field[f"{field_name}.value"]
|
||||
for k in value[key]:
|
||||
v[k] = cls().from_dict(value[key][k])
|
||||
else:
|
||||
v = value[key]
|
||||
if meta.proto_type in INT_64_TYPES:
|
||||
if isinstance(value[key], list):
|
||||
v = [int(n) for n in value[key]]
|
||||
else:
|
||||
v = int(value[key])
|
||||
elif meta.proto_type == TYPE_BYTES:
|
||||
if isinstance(value[key], list):
|
||||
v = [b64decode(n) for n in value[key]]
|
||||
else:
|
||||
v = b64decode(value[key])
|
||||
elif meta.proto_type == TYPE_ENUM:
|
||||
enum_cls = self._betterproto.cls_by_field[field_name]
|
||||
if isinstance(v, list):
|
||||
v = [enum_cls.from_string(e) for e in v]
|
||||
elif isinstance(v, str):
|
||||
v = enum_cls.from_string(v)
|
||||
elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE):
|
||||
if isinstance(value[key], list):
|
||||
v = [_parse_float(n) for n in value[key]]
|
||||
else:
|
||||
v = _parse_float(value[key])
|
||||
|
||||
if v is not None:
|
||||
setattr(self, field_name, v)
|
||||
for field, value in self._from_dict_init(value).items():
|
||||
setattr(self, field, value)
|
||||
return self
|
||||
|
||||
def to_json(
|
||||
@ -1791,8 +1838,8 @@ class Message(ABC):
|
||||
|
||||
@classmethod
|
||||
def _validate_field_groups(cls, values):
|
||||
group_to_one_ofs = cls._betterproto_meta.oneof_field_by_group # type: ignore
|
||||
field_name_to_meta = cls._betterproto_meta.meta_by_field_name # type: ignore
|
||||
group_to_one_ofs = cls._betterproto.oneof_field_by_group
|
||||
field_name_to_meta = cls._betterproto.meta_by_field_name
|
||||
|
||||
for group, field_set in group_to_one_ofs.items():
|
||||
if len(field_set) == 1:
|
||||
@ -1819,6 +1866,9 @@ class Message(ABC):
|
||||
return values
|
||||
|
||||
|
||||
Message.__annotations__ = {} # HACK to avoid typing.get_type_hints breaking :)
|
||||
|
||||
|
||||
def serialized_on_wire(message: Message) -> bool:
|
||||
"""
|
||||
If this message was or should be serialized on the wire. This can be used to detect
|
||||
@ -1890,17 +1940,24 @@ class _Duration(Duration):
|
||||
class _Timestamp(Timestamp):
|
||||
@classmethod
|
||||
def from_datetime(cls, dt: datetime) -> "_Timestamp":
|
||||
seconds = int(dt.timestamp())
|
||||
nanos = int(dt.microsecond * 1e3)
|
||||
return cls(seconds, nanos)
|
||||
# manual epoch offset calulation to avoid rounding errors,
|
||||
# to support negative timestamps (before 1970) and skirt
|
||||
# around datetime bugs (apparently 0 isn't a year in [0, 9999]??)
|
||||
offset = dt - DATETIME_ZERO
|
||||
# below is the same as timedelta.total_seconds() but without dividing by 1e6
|
||||
# so we end up with microseconds as integers instead of seconds as float
|
||||
offset_us = (
|
||||
offset.days * 24 * 60 * 60 + offset.seconds
|
||||
) * 10**6 + offset.microseconds
|
||||
seconds, us = divmod(offset_us, 10**6)
|
||||
return cls(seconds, us * 1000)
|
||||
|
||||
def to_datetime(self) -> datetime:
|
||||
ts = self.seconds + (self.nanos / 1e9)
|
||||
|
||||
if ts < 0:
|
||||
return datetime(1970, 1, 1) + timedelta(seconds=ts)
|
||||
else:
|
||||
return datetime.fromtimestamp(ts, tz=timezone.utc)
|
||||
# datetime.fromtimestamp() expects a timestamp in seconds, not microseconds
|
||||
# if we pass it as a floating point number, we will run into rounding errors
|
||||
# see also #407
|
||||
offset = timedelta(seconds=self.seconds, microseconds=self.nanos // 1000)
|
||||
return DATETIME_ZERO + offset
|
||||
|
||||
@staticmethod
|
||||
def timestamp_to_json(dt: datetime) -> str:
|
||||
|
@ -136,4 +136,8 @@ def lowercase_first(value: str) -> str:
|
||||
|
||||
def sanitize_name(value: str) -> str:
|
||||
# https://www.python.org/dev/peps/pep-0008/#descriptive-naming-styles
|
||||
return f"{value}_" if keyword.iskeyword(value) else value
|
||||
if keyword.iskeyword(value):
|
||||
return f"{value}_"
|
||||
if not value.isidentifier():
|
||||
return f"_{value}"
|
||||
return value
|
||||
|
@ -11,3 +11,11 @@ def pythonize_field_name(name: str) -> str:
|
||||
|
||||
def pythonize_method_name(name: str) -> str:
|
||||
return casing.safe_snake_case(name)
|
||||
|
||||
|
||||
def pythonize_enum_member_name(name: str, enum_name: str) -> str:
|
||||
enum_name = casing.snake_case(enum_name).upper()
|
||||
find = name.find(enum_name)
|
||||
if find != -1:
|
||||
name = name[find + len(enum_name) :].strip("_")
|
||||
return casing.sanitize_name(name)
|
||||
|
196
src/betterproto/enum.py
Normal file
196
src/betterproto/enum.py
Normal file
@ -0,0 +1,196 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from enum import (
|
||||
EnumMeta,
|
||||
IntEnum,
|
||||
)
|
||||
from types import MappingProxyType
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Dict,
|
||||
Optional,
|
||||
Tuple,
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import (
|
||||
Generator,
|
||||
Mapping,
|
||||
)
|
||||
|
||||
from typing_extensions import (
|
||||
Never,
|
||||
Self,
|
||||
)
|
||||
|
||||
|
||||
def _is_descriptor(obj: object) -> bool:
|
||||
return (
|
||||
hasattr(obj, "__get__") or hasattr(obj, "__set__") or hasattr(obj, "__delete__")
|
||||
)
|
||||
|
||||
|
||||
class EnumType(EnumMeta if TYPE_CHECKING else type):
|
||||
_value_map_: Mapping[int, Enum]
|
||||
_member_map_: Mapping[str, Enum]
|
||||
|
||||
def __new__(
|
||||
mcs, name: str, bases: Tuple[type, ...], namespace: Dict[str, Any]
|
||||
) -> Self:
|
||||
value_map = {}
|
||||
member_map = {}
|
||||
|
||||
new_mcs = type(
|
||||
f"{name}Type",
|
||||
tuple(
|
||||
dict.fromkeys(
|
||||
[base.__class__ for base in bases if base.__class__ is not type]
|
||||
+ [EnumType, type]
|
||||
)
|
||||
), # reorder the bases so EnumType and type are last to avoid conflicts
|
||||
{"_value_map_": value_map, "_member_map_": member_map},
|
||||
)
|
||||
|
||||
members = {
|
||||
name: value
|
||||
for name, value in namespace.items()
|
||||
if not _is_descriptor(value) and not name.startswith("__")
|
||||
}
|
||||
|
||||
cls = type.__new__(
|
||||
new_mcs,
|
||||
name,
|
||||
bases,
|
||||
{key: value for key, value in namespace.items() if key not in members},
|
||||
)
|
||||
# this allows us to disallow member access from other members as
|
||||
# members become proper class variables
|
||||
|
||||
for name, value in members.items():
|
||||
member = value_map.get(value)
|
||||
if member is None:
|
||||
member = cls.__new__(cls, name=name, value=value) # type: ignore
|
||||
value_map[value] = member
|
||||
member_map[name] = member
|
||||
type.__setattr__(new_mcs, name, member)
|
||||
|
||||
return cls
|
||||
|
||||
if not TYPE_CHECKING:
|
||||
|
||||
def __call__(cls, value: int) -> Enum:
|
||||
try:
|
||||
return cls._value_map_[value]
|
||||
except (KeyError, TypeError):
|
||||
raise ValueError(f"{value!r} is not a valid {cls.__name__}") from None
|
||||
|
||||
def __iter__(cls) -> Generator[Enum, None, None]:
|
||||
yield from cls._member_map_.values()
|
||||
|
||||
if sys.version_info >= (3, 8): # 3.8 added __reversed__ to dict_values
|
||||
|
||||
def __reversed__(cls) -> Generator[Enum, None, None]:
|
||||
yield from reversed(cls._member_map_.values())
|
||||
|
||||
else:
|
||||
|
||||
def __reversed__(cls) -> Generator[Enum, None, None]:
|
||||
yield from reversed(tuple(cls._member_map_.values()))
|
||||
|
||||
def __getitem__(cls, key: str) -> Enum:
|
||||
return cls._member_map_[key]
|
||||
|
||||
@property
|
||||
def __members__(cls) -> MappingProxyType[str, Enum]:
|
||||
return MappingProxyType(cls._member_map_)
|
||||
|
||||
def __repr__(cls) -> str:
|
||||
return f"<enum {cls.__name__!r}>"
|
||||
|
||||
def __len__(cls) -> int:
|
||||
return len(cls._member_map_)
|
||||
|
||||
def __setattr__(cls, name: str, value: Any) -> Never:
|
||||
raise AttributeError(f"{cls.__name__}: cannot reassign Enum members.")
|
||||
|
||||
def __delattr__(cls, name: str) -> Never:
|
||||
raise AttributeError(f"{cls.__name__}: cannot delete Enum members.")
|
||||
|
||||
def __contains__(cls, member: object) -> bool:
|
||||
return isinstance(member, cls) and member.name in cls._member_map_
|
||||
|
||||
|
||||
class Enum(IntEnum if TYPE_CHECKING else int, metaclass=EnumType):
|
||||
"""
|
||||
The base class for protobuf enumerations, all generated enumerations will
|
||||
inherit from this. Emulates `enum.IntEnum`.
|
||||
"""
|
||||
|
||||
name: Optional[str]
|
||||
value: int
|
||||
|
||||
if not TYPE_CHECKING:
|
||||
|
||||
def __new__(cls, *, name: Optional[str], value: int) -> Self:
|
||||
self = super().__new__(cls, value)
|
||||
super().__setattr__(self, "name", name)
|
||||
super().__setattr__(self, "value", value)
|
||||
return self
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.name or "None"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}.{self.name}"
|
||||
|
||||
def __setattr__(self, key: str, value: Any) -> Never:
|
||||
raise AttributeError(
|
||||
f"{self.__class__.__name__} Cannot reassign a member's attributes."
|
||||
)
|
||||
|
||||
def __delattr__(self, item: Any) -> Never:
|
||||
raise AttributeError(
|
||||
f"{self.__class__.__name__} Cannot delete a member's attributes."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def try_value(cls, value: int = 0) -> Self:
|
||||
"""Return the value which corresponds to the value.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
value: :class:`int`
|
||||
The value of the enum member to get.
|
||||
|
||||
Returns
|
||||
-------
|
||||
:class:`Enum`
|
||||
The corresponding member or a new instance of the enum if
|
||||
``value`` isn't actually a member.
|
||||
"""
|
||||
try:
|
||||
return cls._value_map_[value]
|
||||
except (KeyError, TypeError):
|
||||
return cls.__new__(cls, name=None, value=value)
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, name: str) -> Self:
|
||||
"""Return the value which corresponds to the string name.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
name: :class:`str`
|
||||
The name of the enum member to get.
|
||||
|
||||
Raises
|
||||
-------
|
||||
:exc:`ValueError`
|
||||
The member was not found in the Enum.
|
||||
"""
|
||||
try:
|
||||
return cls._member_map_[name]
|
||||
except KeyError as e:
|
||||
raise ValueError(f"Unknown value {name} for enum {cls.__name__}") from e
|
@ -127,6 +127,7 @@ class ServiceStub(ABC):
|
||||
response_type,
|
||||
**self.__resolve_request_kwargs(timeout, deadline, metadata),
|
||||
) as stream:
|
||||
await stream.send_request()
|
||||
await self._send_messages(stream, request_iterator)
|
||||
response = await stream.recv_message()
|
||||
assert response is not None
|
||||
|
@ -72,13 +72,13 @@ from betterproto.lib.google.protobuf import (
|
||||
)
|
||||
from betterproto.lib.google.protobuf.compiler import CodeGeneratorRequest
|
||||
|
||||
from ..casing import sanitize_name
|
||||
from ..compile.importing import (
|
||||
get_type_reference,
|
||||
parse_source_type_name,
|
||||
)
|
||||
from ..compile.naming import (
|
||||
pythonize_class_name,
|
||||
pythonize_enum_member_name,
|
||||
pythonize_field_name,
|
||||
pythonize_method_name,
|
||||
)
|
||||
@ -385,7 +385,10 @@ def is_oneof(proto_field_obj: FieldDescriptorProto) -> bool:
|
||||
us to tell whether it was set, via the which_one_of interface.
|
||||
"""
|
||||
|
||||
return which_one_of(proto_field_obj, "oneof_index")[0] == "oneof_index"
|
||||
return (
|
||||
not proto_field_obj.proto3_optional
|
||||
and which_one_of(proto_field_obj, "oneof_index")[0] == "oneof_index"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -670,7 +673,9 @@ class EnumDefinitionCompiler(MessageCompiler):
|
||||
# Get entries/allowed values for this Enum
|
||||
self.entries = [
|
||||
self.EnumEntry(
|
||||
name=sanitize_name(entry_proto_value.name),
|
||||
name=pythonize_enum_member_name(
|
||||
entry_proto_value.name, self.proto_obj.name
|
||||
),
|
||||
value=entry_proto_value.number,
|
||||
comment=get_comment(
|
||||
proto_file=self.source_file, path=self.path + [2, entry_number]
|
||||
|
56
src/betterproto/utils.py
Normal file
56
src/betterproto/utils.py
Normal file
@ -0,0 +1,56 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Generic,
|
||||
Optional,
|
||||
Type,
|
||||
TypeVar,
|
||||
)
|
||||
|
||||
from typing_extensions import (
|
||||
Concatenate,
|
||||
ParamSpec,
|
||||
Self,
|
||||
)
|
||||
|
||||
|
||||
SelfT = TypeVar("SelfT")
|
||||
P = ParamSpec("P")
|
||||
HybridT = TypeVar("HybridT", covariant=True)
|
||||
|
||||
|
||||
class hybridmethod(Generic[SelfT, P, HybridT]):
|
||||
def __init__(
|
||||
self,
|
||||
func: Callable[
|
||||
Concatenate[type[SelfT], P], HybridT
|
||||
], # Must be the classmethod version
|
||||
):
|
||||
self.cls_func = func
|
||||
self.__doc__ = func.__doc__
|
||||
|
||||
def instancemethod(self, func: Callable[Concatenate[SelfT, P], HybridT]) -> Self:
|
||||
self.instance_func = func
|
||||
return self
|
||||
|
||||
def __get__(
|
||||
self, instance: Optional[SelfT], owner: Type[SelfT]
|
||||
) -> Callable[P, HybridT]:
|
||||
if instance is None or self.instance_func is None:
|
||||
# either bound to the class, or no instance method available
|
||||
return self.cls_func.__get__(owner, None)
|
||||
return self.instance_func.__get__(instance, owner)
|
||||
|
||||
|
||||
T_co = TypeVar("T_co")
|
||||
TT_co = TypeVar("TT_co", bound="type[Any]")
|
||||
|
||||
|
||||
class classproperty(Generic[TT_co, T_co]):
|
||||
def __init__(self, func: Callable[[TT_co], T_co]):
|
||||
self.__func__ = func
|
||||
|
||||
def __get__(self, instance: Any, type: TT_co) -> T_co:
|
||||
return self.__func__(type)
|
@ -272,3 +272,27 @@ async def test_async_gen_for_stream_stream_request():
|
||||
assert response_index == len(
|
||||
expected_things
|
||||
), "Didn't receive all expected responses"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_unary_with_empty_iterable():
|
||||
things = [] # empty
|
||||
|
||||
async with ChannelFor([ThingService()]) as channel:
|
||||
client = ThingServiceClient(channel)
|
||||
requests = [DoThingRequest(name) for name in things]
|
||||
response = await client.do_many_things(requests)
|
||||
assert len(response.names) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_stream_with_empty_iterable():
|
||||
things = [] # empty
|
||||
|
||||
async with ChannelFor([ThingService()]) as channel:
|
||||
client = ThingServiceClient(channel)
|
||||
requests = [GetThingRequest(name) for name in things]
|
||||
responses = [
|
||||
response async for response in client.get_different_things(requests)
|
||||
]
|
||||
assert len(responses) == 0
|
||||
|
@ -27,7 +27,7 @@ class ThingService:
|
||||
async def do_many_things(
|
||||
self, stream: "grpclib.server.Stream[DoThingRequest, DoThingResponse]"
|
||||
):
|
||||
thing_names = [request.name for request in stream]
|
||||
thing_names = [request.name async for request in stream]
|
||||
if self.test_hook is not None:
|
||||
self.test_hook(stream)
|
||||
await stream.send_message(DoThingResponse(thing_names))
|
||||
|
@ -15,3 +15,11 @@ enum Choice {
|
||||
FOUR = 4;
|
||||
THREE = 3;
|
||||
}
|
||||
|
||||
// A "C" like enum with the enum name prefixed onto members, these should be stripped
|
||||
enum ArithmeticOperator {
|
||||
ARITHMETIC_OPERATOR_NONE = 0;
|
||||
ARITHMETIC_OPERATOR_PLUS = 1;
|
||||
ARITHMETIC_OPERATOR_MINUS = 2;
|
||||
ARITHMETIC_OPERATOR_0_PREFIXED = 3;
|
||||
}
|
||||
|
@ -1,4 +1,5 @@
|
||||
from tests.output_betterproto.enum import (
|
||||
ArithmeticOperator,
|
||||
Choice,
|
||||
Test,
|
||||
)
|
||||
@ -82,3 +83,32 @@ def test_repeated_enum_with_non_list_iterables_to_dict():
|
||||
yield Choice.THREE
|
||||
|
||||
assert Test(choices=enum_generator()).to_dict()["choices"] == ["ONE", "THREE"]
|
||||
|
||||
|
||||
def test_enum_mapped_on_parse():
|
||||
# test default value
|
||||
b = Test().parse(bytes(Test()))
|
||||
assert b.choice.name == Choice.ZERO.name
|
||||
assert b.choices == []
|
||||
|
||||
# test non default value
|
||||
a = Test().parse(bytes(Test(choice=Choice.ONE)))
|
||||
assert a.choice.name == Choice.ONE.name
|
||||
assert b.choices == []
|
||||
|
||||
# test repeated
|
||||
c = Test().parse(bytes(Test(choices=[Choice.THREE, Choice.FOUR])))
|
||||
assert c.choices[0].name == Choice.THREE.name
|
||||
assert c.choices[1].name == Choice.FOUR.name
|
||||
|
||||
# bonus: defaults after empty init are also mapped
|
||||
assert Test().choice.name == Choice.ZERO.name
|
||||
|
||||
|
||||
def test_renamed_enum_members():
|
||||
assert set(ArithmeticOperator.__members__) == {
|
||||
"NONE",
|
||||
"PLUS",
|
||||
"MINUS",
|
||||
"_0_PREFIXED",
|
||||
}
|
||||
|
@ -1,5 +1,6 @@
|
||||
syntax = "proto3";
|
||||
|
||||
import "google/protobuf/timestamp.proto";
|
||||
package google_impl_behavior_equivalence;
|
||||
|
||||
message Foo { int64 bar = 1; }
|
||||
@ -12,6 +13,10 @@ message Test {
|
||||
}
|
||||
}
|
||||
|
||||
message Spam {
|
||||
google.protobuf.Timestamp ts = 1;
|
||||
}
|
||||
|
||||
message Request { Empty foo = 1; }
|
||||
|
||||
message Empty {}
|
||||
message Empty {}
|
||||
|
@ -1,17 +1,25 @@
|
||||
from datetime import (
|
||||
datetime,
|
||||
timezone,
|
||||
)
|
||||
|
||||
import pytest
|
||||
from google.protobuf import json_format
|
||||
from google.protobuf.timestamp_pb2 import Timestamp
|
||||
|
||||
import betterproto
|
||||
from tests.output_betterproto.google_impl_behavior_equivalence import (
|
||||
Empty,
|
||||
Foo,
|
||||
Request,
|
||||
Spam,
|
||||
Test,
|
||||
)
|
||||
from tests.output_reference.google_impl_behavior_equivalence.google_impl_behavior_equivalence_pb2 import (
|
||||
Empty as ReferenceEmpty,
|
||||
Foo as ReferenceFoo,
|
||||
Request as ReferenceRequest,
|
||||
Spam as ReferenceSpam,
|
||||
Test as ReferenceTest,
|
||||
)
|
||||
|
||||
@ -59,6 +67,19 @@ def test_bytes_are_the_same_for_oneof():
|
||||
assert isinstance(message_reference2.foo, ReferenceFoo)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dt", (datetime.min.replace(tzinfo=timezone.utc),))
|
||||
def test_datetime_clamping(dt): # see #407
|
||||
ts = Timestamp()
|
||||
ts.FromDatetime(dt)
|
||||
assert bytes(Spam(dt)) == ReferenceSpam(ts=ts).SerializeToString()
|
||||
message_bytes = bytes(Spam(dt))
|
||||
|
||||
assert (
|
||||
Spam().parse(message_bytes).ts.timestamp()
|
||||
== ReferenceSpam.FromString(message_bytes).ts.seconds
|
||||
)
|
||||
|
||||
|
||||
def test_empty_message_field():
|
||||
message = Request()
|
||||
reference_message = ReferenceRequest()
|
||||
|
@ -2,6 +2,10 @@ syntax = "proto3";
|
||||
|
||||
package oneof;
|
||||
|
||||
message MixedDrink {
|
||||
int32 shots = 1;
|
||||
}
|
||||
|
||||
message Test {
|
||||
oneof foo {
|
||||
int32 pitied = 1;
|
||||
@ -13,6 +17,7 @@ message Test {
|
||||
oneof bar {
|
||||
int32 drinks = 11;
|
||||
string bar_name = 12;
|
||||
MixedDrink mixed_drink = 13;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1,5 +1,10 @@
|
||||
import pytest
|
||||
|
||||
import betterproto
|
||||
from tests.output_betterproto.oneof import Test
|
||||
from tests.output_betterproto.oneof import (
|
||||
MixedDrink,
|
||||
Test,
|
||||
)
|
||||
from tests.output_betterproto_pydantic.oneof import Test as TestPyd
|
||||
from tests.util import get_test_case_json_data
|
||||
|
||||
@ -19,3 +24,20 @@ def test_which_name():
|
||||
def test_which_count_pyd():
|
||||
message = TestPyd(pitier="Mr. T", just_a_regular_field=2, bar_name="a_bar")
|
||||
assert betterproto.which_one_of(message, "foo") == ("pitier", "Mr. T")
|
||||
|
||||
|
||||
def test_oneof_constructor_assign():
|
||||
message = Test(mixed_drink=MixedDrink(shots=42))
|
||||
field, value = betterproto.which_one_of(message, "bar")
|
||||
assert field == "mixed_drink"
|
||||
assert value.shots == 42
|
||||
|
||||
|
||||
# Issue #305:
|
||||
@pytest.mark.xfail
|
||||
def test_oneof_nested_assign():
|
||||
message = Test()
|
||||
message.mixed_drink.shots = 42
|
||||
field, value = betterproto.which_one_of(message, "bar")
|
||||
assert field == "mixed_drink"
|
||||
assert value.shots == 42
|
||||
|
@ -41,3 +41,8 @@ def test_null_fields_json():
|
||||
"test8": None,
|
||||
"test9": None,
|
||||
}
|
||||
|
||||
|
||||
def test_unset_access(): # see #523
|
||||
assert Test().test1 is None
|
||||
assert Test(test1=None).test1 is None
|
||||
|
2
tests/streams/delimited_messages.in
Normal file
2
tests/streams/delimited_messages.in
Normal file
@ -0,0 +1,2 @@
|
||||
•šï:bTesting•šï:bTesting
|
||||
|
38
tests/streams/java/.gitignore
vendored
Normal file
38
tests/streams/java/.gitignore
vendored
Normal file
@ -0,0 +1,38 @@
|
||||
### Output ###
|
||||
target/
|
||||
!.mvn/wrapper/maven-wrapper.jar
|
||||
!**/src/main/**/target/
|
||||
!**/src/test/**/target/
|
||||
dependency-reduced-pom.xml
|
||||
MANIFEST.MF
|
||||
|
||||
### IntelliJ IDEA ###
|
||||
.idea/
|
||||
*.iws
|
||||
*.iml
|
||||
*.ipr
|
||||
|
||||
### Eclipse ###
|
||||
.apt_generated
|
||||
.classpath
|
||||
.factorypath
|
||||
.project
|
||||
.settings
|
||||
.springBeans
|
||||
.sts4-cache
|
||||
|
||||
### NetBeans ###
|
||||
/nbproject/private/
|
||||
/nbbuild/
|
||||
/dist/
|
||||
/nbdist/
|
||||
/.nb-gradle/
|
||||
build/
|
||||
!**/src/main/**/build/
|
||||
!**/src/test/**/build/
|
||||
|
||||
### VS Code ###
|
||||
.vscode/
|
||||
|
||||
### Mac OS ###
|
||||
.DS_Store
|
94
tests/streams/java/pom.xml
Normal file
94
tests/streams/java/pom.xml
Normal file
@ -0,0 +1,94 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
||||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
|
||||
<groupId>betterproto</groupId>
|
||||
<artifactId>compatibility-test</artifactId>
|
||||
<version>1.0-SNAPSHOT</version>
|
||||
<packaging>jar</packaging>
|
||||
|
||||
<properties>
|
||||
<maven.compiler.source>11</maven.compiler.source>
|
||||
<maven.compiler.target>11</maven.compiler.target>
|
||||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
||||
<protobuf.version>3.23.4</protobuf.version>
|
||||
</properties>
|
||||
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>com.google.protobuf</groupId>
|
||||
<artifactId>protobuf-java</artifactId>
|
||||
<version>${protobuf.version}</version>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
<build>
|
||||
<extensions>
|
||||
<extension>
|
||||
<groupId>kr.motd.maven</groupId>
|
||||
<artifactId>os-maven-plugin</artifactId>
|
||||
<version>1.7.1</version>
|
||||
</extension>
|
||||
</extensions>
|
||||
|
||||
<plugins>
|
||||
<plugin>
|
||||
<groupId>org.apache.maven.plugins</groupId>
|
||||
<artifactId>maven-shade-plugin</artifactId>
|
||||
<version>3.5.0</version>
|
||||
<executions>
|
||||
<execution>
|
||||
<phase>package</phase>
|
||||
<goals>
|
||||
<goal>shade</goal>
|
||||
</goals>
|
||||
<configuration>
|
||||
<transformers>
|
||||
<transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
|
||||
<mainClass>betterproto.CompatibilityTest</mainClass>
|
||||
</transformer>
|
||||
</transformers>
|
||||
</configuration>
|
||||
</execution>
|
||||
</executions>
|
||||
</plugin>
|
||||
|
||||
<plugin>
|
||||
<groupId>org.apache.maven.plugins</groupId>
|
||||
<artifactId>maven-jar-plugin</artifactId>
|
||||
<version>3.3.0</version>
|
||||
<configuration>
|
||||
<archive>
|
||||
<manifest>
|
||||
<addClasspath>true</addClasspath>
|
||||
<mainClass>betterproto.CompatibilityTest</mainClass>
|
||||
</manifest>
|
||||
</archive>
|
||||
</configuration>
|
||||
</plugin>
|
||||
|
||||
<plugin>
|
||||
<groupId>org.xolstice.maven.plugins</groupId>
|
||||
<artifactId>protobuf-maven-plugin</artifactId>
|
||||
<version>0.6.1</version>
|
||||
<executions>
|
||||
<execution>
|
||||
<goals>
|
||||
<goal>compile</goal>
|
||||
</goals>
|
||||
</execution>
|
||||
</executions>
|
||||
<configuration>
|
||||
<protocArtifact>
|
||||
com.google.protobuf:protoc:${protobuf.version}:exe:${os.detected.classifier}
|
||||
</protocArtifact>
|
||||
</configuration>
|
||||
</plugin>
|
||||
</plugins>
|
||||
|
||||
<finalName>${project.artifactId}</finalName>
|
||||
</build>
|
||||
|
||||
</project>
|
@ -0,0 +1,41 @@
|
||||
package betterproto;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
public class CompatibilityTest {
|
||||
public static void main(String[] args) throws IOException {
|
||||
if (args.length < 2)
|
||||
throw new RuntimeException("Attempted to run without the required arguments.");
|
||||
else if (args.length > 2)
|
||||
throw new RuntimeException(
|
||||
"Attempted to run with more than the expected number of arguments (>1).");
|
||||
|
||||
Tests tests = new Tests(args[1]);
|
||||
|
||||
switch (args[0]) {
|
||||
case "single_varint":
|
||||
tests.testSingleVarint();
|
||||
break;
|
||||
|
||||
case "multiple_varints":
|
||||
tests.testMultipleVarints();
|
||||
break;
|
||||
|
||||
case "single_message":
|
||||
tests.testSingleMessage();
|
||||
break;
|
||||
|
||||
case "multiple_messages":
|
||||
tests.testMultipleMessages();
|
||||
break;
|
||||
|
||||
case "infinite_messages":
|
||||
tests.testInfiniteMessages();
|
||||
break;
|
||||
|
||||
default:
|
||||
throw new RuntimeException(
|
||||
"Attempted to run with unknown argument '" + args[0] + "'.");
|
||||
}
|
||||
}
|
||||
}
|
115
tests/streams/java/src/main/java/betterproto/Tests.java
Normal file
115
tests/streams/java/src/main/java/betterproto/Tests.java
Normal file
@ -0,0 +1,115 @@
|
||||
package betterproto;
|
||||
|
||||
import betterproto.nested.NestedOuterClass;
|
||||
import betterproto.oneof.Oneof;
|
||||
|
||||
import com.google.protobuf.CodedInputStream;
|
||||
import com.google.protobuf.CodedOutputStream;
|
||||
|
||||
import java.io.FileInputStream;
|
||||
import java.io.FileOutputStream;
|
||||
import java.io.IOException;
|
||||
|
||||
public class Tests {
|
||||
String path;
|
||||
|
||||
public Tests(String path) {
|
||||
this.path = path;
|
||||
}
|
||||
|
||||
public void testSingleVarint() throws IOException {
|
||||
// Read in the Python-generated single varint file
|
||||
FileInputStream inputStream = new FileInputStream(path + "/py_single_varint.out");
|
||||
CodedInputStream codedInput = CodedInputStream.newInstance(inputStream);
|
||||
|
||||
int value = codedInput.readUInt32();
|
||||
|
||||
inputStream.close();
|
||||
|
||||
// Write the value back to a file
|
||||
FileOutputStream outputStream = new FileOutputStream(path + "/java_single_varint.out");
|
||||
CodedOutputStream codedOutput = CodedOutputStream.newInstance(outputStream);
|
||||
|
||||
codedOutput.writeUInt32NoTag(value);
|
||||
|
||||
codedOutput.flush();
|
||||
outputStream.close();
|
||||
}
|
||||
|
||||
public void testMultipleVarints() throws IOException {
|
||||
// Read in the Python-generated multiple varints file
|
||||
FileInputStream inputStream = new FileInputStream(path + "/py_multiple_varints.out");
|
||||
CodedInputStream codedInput = CodedInputStream.newInstance(inputStream);
|
||||
|
||||
int value1 = codedInput.readUInt32();
|
||||
int value2 = codedInput.readUInt32();
|
||||
long value3 = codedInput.readUInt64();
|
||||
|
||||
inputStream.close();
|
||||
|
||||
// Write the values back to a file
|
||||
FileOutputStream outputStream = new FileOutputStream(path + "/java_multiple_varints.out");
|
||||
CodedOutputStream codedOutput = CodedOutputStream.newInstance(outputStream);
|
||||
|
||||
codedOutput.writeUInt32NoTag(value1);
|
||||
codedOutput.writeUInt64NoTag(value2);
|
||||
codedOutput.writeUInt64NoTag(value3);
|
||||
|
||||
codedOutput.flush();
|
||||
outputStream.close();
|
||||
}
|
||||
|
||||
public void testSingleMessage() throws IOException {
|
||||
// Read in the Python-generated single message file
|
||||
FileInputStream inputStream = new FileInputStream(path + "/py_single_message.out");
|
||||
CodedInputStream codedInput = CodedInputStream.newInstance(inputStream);
|
||||
|
||||
Oneof.Test message = Oneof.Test.parseFrom(codedInput);
|
||||
|
||||
inputStream.close();
|
||||
|
||||
// Write the message back to a file
|
||||
FileOutputStream outputStream = new FileOutputStream(path + "/java_single_message.out");
|
||||
CodedOutputStream codedOutput = CodedOutputStream.newInstance(outputStream);
|
||||
|
||||
message.writeTo(codedOutput);
|
||||
|
||||
codedOutput.flush();
|
||||
outputStream.close();
|
||||
}
|
||||
|
||||
public void testMultipleMessages() throws IOException {
|
||||
// Read in the Python-generated multi-message file
|
||||
FileInputStream inputStream = new FileInputStream(path + "/py_multiple_messages.out");
|
||||
|
||||
Oneof.Test oneof = Oneof.Test.parseDelimitedFrom(inputStream);
|
||||
NestedOuterClass.Test nested = NestedOuterClass.Test.parseDelimitedFrom(inputStream);
|
||||
|
||||
inputStream.close();
|
||||
|
||||
// Write the messages back to a file
|
||||
FileOutputStream outputStream = new FileOutputStream(path + "/java_multiple_messages.out");
|
||||
|
||||
oneof.writeDelimitedTo(outputStream);
|
||||
nested.writeDelimitedTo(outputStream);
|
||||
|
||||
outputStream.flush();
|
||||
outputStream.close();
|
||||
}
|
||||
|
||||
public void testInfiniteMessages() throws IOException {
|
||||
// Read in as many messages as are present in the Python-generated file and write them back
|
||||
FileInputStream inputStream = new FileInputStream(path + "/py_infinite_messages.out");
|
||||
FileOutputStream outputStream = new FileOutputStream(path + "/java_infinite_messages.out");
|
||||
|
||||
Oneof.Test current = Oneof.Test.parseDelimitedFrom(inputStream);
|
||||
while (current != null) {
|
||||
current.writeDelimitedTo(outputStream);
|
||||
current = Oneof.Test.parseDelimitedFrom(inputStream);
|
||||
}
|
||||
|
||||
inputStream.close();
|
||||
outputStream.flush();
|
||||
outputStream.close();
|
||||
}
|
||||
}
|
27
tests/streams/java/src/main/proto/betterproto/nested.proto
Normal file
27
tests/streams/java/src/main/proto/betterproto/nested.proto
Normal file
@ -0,0 +1,27 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package nested;
|
||||
option java_package = "betterproto.nested";
|
||||
|
||||
// A test message with a nested message inside of it.
|
||||
message Test {
|
||||
// This is the nested type.
|
||||
message Nested {
|
||||
// Stores a simple counter.
|
||||
int32 count = 1;
|
||||
}
|
||||
// This is the nested enum.
|
||||
enum Msg {
|
||||
NONE = 0;
|
||||
THIS = 1;
|
||||
}
|
||||
|
||||
Nested nested = 1;
|
||||
Sibling sibling = 2;
|
||||
Sibling sibling2 = 3;
|
||||
Msg msg = 4;
|
||||
}
|
||||
|
||||
message Sibling {
|
||||
int32 foo = 1;
|
||||
}
|
19
tests/streams/java/src/main/proto/betterproto/oneof.proto
Normal file
19
tests/streams/java/src/main/proto/betterproto/oneof.proto
Normal file
@ -0,0 +1,19 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package oneof;
|
||||
option java_package = "betterproto.oneof";
|
||||
|
||||
message Test {
|
||||
oneof foo {
|
||||
int32 pitied = 1;
|
||||
string pitier = 2;
|
||||
}
|
||||
|
||||
int32 just_a_regular_field = 3;
|
||||
|
||||
oneof bar {
|
||||
int32 drinks = 11;
|
||||
string bar_name = 12;
|
||||
}
|
||||
}
|
||||
|
79
tests/test_enum.py
Normal file
79
tests/test_enum.py
Normal file
@ -0,0 +1,79 @@
|
||||
from typing import (
|
||||
Optional,
|
||||
Tuple,
|
||||
)
|
||||
|
||||
import pytest
|
||||
|
||||
import betterproto
|
||||
|
||||
|
||||
class Colour(betterproto.Enum):
|
||||
RED = 1
|
||||
GREEN = 2
|
||||
BLUE = 3
|
||||
|
||||
|
||||
PURPLE = Colour.__new__(Colour, name=None, value=4)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"member, str_value",
|
||||
[
|
||||
(Colour.RED, "RED"),
|
||||
(Colour.GREEN, "GREEN"),
|
||||
(Colour.BLUE, "BLUE"),
|
||||
],
|
||||
)
|
||||
def test_str(member: Colour, str_value: str) -> None:
|
||||
assert str(member) == str_value
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"member, repr_value",
|
||||
[
|
||||
(Colour.RED, "Colour.RED"),
|
||||
(Colour.GREEN, "Colour.GREEN"),
|
||||
(Colour.BLUE, "Colour.BLUE"),
|
||||
],
|
||||
)
|
||||
def test_repr(member: Colour, repr_value: str) -> None:
|
||||
assert repr(member) == repr_value
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"member, values",
|
||||
[
|
||||
(Colour.RED, ("RED", 1)),
|
||||
(Colour.GREEN, ("GREEN", 2)),
|
||||
(Colour.BLUE, ("BLUE", 3)),
|
||||
(PURPLE, (None, 4)),
|
||||
],
|
||||
)
|
||||
def test_name_values(member: Colour, values: Tuple[Optional[str], int]) -> None:
|
||||
assert (member.name, member.value) == values
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"member, input_str",
|
||||
[
|
||||
(Colour.RED, "RED"),
|
||||
(Colour.GREEN, "GREEN"),
|
||||
(Colour.BLUE, "BLUE"),
|
||||
],
|
||||
)
|
||||
def test_from_string(member: Colour, input_str: str) -> None:
|
||||
assert Colour.from_string(input_str) == member
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"member, input_int",
|
||||
[
|
||||
(Colour.RED, 1),
|
||||
(Colour.GREEN, 2),
|
||||
(Colour.BLUE, 3),
|
||||
(PURPLE, 4),
|
||||
],
|
||||
)
|
||||
def test_try_value(member: Colour, input_int: int) -> None:
|
||||
assert Colour.try_value(input_int) == member
|
@ -545,47 +545,6 @@ def test_oneof_default_value_set_causes_writes_wire():
|
||||
)
|
||||
|
||||
|
||||
def test_recursive_message():
|
||||
from tests.output_betterproto.recursivemessage import Test as RecursiveMessage
|
||||
|
||||
msg = RecursiveMessage()
|
||||
|
||||
assert msg.child == RecursiveMessage()
|
||||
|
||||
# Lazily-created zero-value children must not affect equality.
|
||||
assert msg == RecursiveMessage()
|
||||
|
||||
# Lazily-created zero-value children must not affect serialization.
|
||||
assert bytes(msg) == b""
|
||||
|
||||
|
||||
def test_recursive_message_defaults():
|
||||
from tests.output_betterproto.recursivemessage import (
|
||||
Intermediate,
|
||||
Test as RecursiveMessage,
|
||||
)
|
||||
|
||||
msg = RecursiveMessage(name="bob", intermediate=Intermediate(42))
|
||||
|
||||
# set values are as expected
|
||||
assert msg == RecursiveMessage(name="bob", intermediate=Intermediate(42))
|
||||
|
||||
# lazy initialized works modifies the message
|
||||
assert msg != RecursiveMessage(
|
||||
name="bob", intermediate=Intermediate(42), child=RecursiveMessage(name="jude")
|
||||
)
|
||||
msg.child.child.name = "jude"
|
||||
assert msg == RecursiveMessage(
|
||||
name="bob",
|
||||
intermediate=Intermediate(42),
|
||||
child=RecursiveMessage(child=RecursiveMessage(name="jude")),
|
||||
)
|
||||
|
||||
# lazily initialization recurses as needed
|
||||
assert msg.child.child.child.child.child.child.child == RecursiveMessage()
|
||||
assert msg.intermediate.child.intermediate == Intermediate()
|
||||
|
||||
|
||||
def test_message_repr():
|
||||
from tests.output_betterproto.recursivemessage import Test
|
||||
|
||||
@ -699,25 +658,6 @@ def test_service_argument__expected_parameter():
|
||||
assert do_thing_request_parameter.annotation == "DoThingRequest"
|
||||
|
||||
|
||||
def test_copyability():
|
||||
@dataclass
|
||||
class Spam(betterproto.Message):
|
||||
foo: bool = betterproto.bool_field(1)
|
||||
bar: int = betterproto.int32_field(2)
|
||||
baz: List[str] = betterproto.string_field(3)
|
||||
|
||||
spam = Spam(bar=12, baz=["hello"])
|
||||
copied = copy(spam)
|
||||
assert spam == copied
|
||||
assert spam is not copied
|
||||
assert spam.baz is copied.baz
|
||||
|
||||
deepcopied = deepcopy(spam)
|
||||
assert spam == deepcopied
|
||||
assert spam is not deepcopied
|
||||
assert spam.baz is not deepcopied.baz
|
||||
|
||||
|
||||
def test_is_set():
|
||||
@dataclass
|
||||
class Spam(betterproto.Message):
|
||||
|
203
tests/test_pickling.py
Normal file
203
tests/test_pickling.py
Normal file
@ -0,0 +1,203 @@
|
||||
import pickle
|
||||
from copy import (
|
||||
copy,
|
||||
deepcopy,
|
||||
)
|
||||
from dataclasses import dataclass
|
||||
from typing import (
|
||||
Dict,
|
||||
List,
|
||||
)
|
||||
from unittest.mock import ANY
|
||||
|
||||
import cachelib
|
||||
|
||||
import betterproto
|
||||
from betterproto.lib.google import protobuf as google
|
||||
|
||||
|
||||
def unpickled(message):
|
||||
return pickle.loads(pickle.dumps(message))
|
||||
|
||||
|
||||
@dataclass(eq=False, repr=False)
|
||||
class Fe(betterproto.Message):
|
||||
abc: str = betterproto.string_field(1)
|
||||
|
||||
|
||||
@dataclass(eq=False, repr=False)
|
||||
class Fi(betterproto.Message):
|
||||
abc: str = betterproto.string_field(1)
|
||||
|
||||
|
||||
@dataclass(eq=False, repr=False)
|
||||
class Fo(betterproto.Message):
|
||||
abc: str = betterproto.string_field(1)
|
||||
|
||||
|
||||
@dataclass(eq=False, repr=False)
|
||||
class NestedData(betterproto.Message):
|
||||
struct_foo: Dict[str, "google.Struct"] = betterproto.map_field(
|
||||
1, betterproto.TYPE_STRING, betterproto.TYPE_MESSAGE
|
||||
)
|
||||
map_str_any_bar: Dict[str, "google.Any"] = betterproto.map_field(
|
||||
2, betterproto.TYPE_STRING, betterproto.TYPE_MESSAGE
|
||||
)
|
||||
|
||||
|
||||
@dataclass(eq=False, repr=False)
|
||||
class Complex(betterproto.Message):
|
||||
foo_str: str = betterproto.string_field(1)
|
||||
fe: "Fe" = betterproto.message_field(3, group="grp")
|
||||
fi: "Fi" = betterproto.message_field(4, group="grp")
|
||||
fo: "Fo" = betterproto.message_field(5, group="grp")
|
||||
nested_data: "NestedData" = betterproto.message_field(6)
|
||||
mapping: Dict[str, "google.Any"] = betterproto.map_field(
|
||||
7, betterproto.TYPE_STRING, betterproto.TYPE_MESSAGE
|
||||
)
|
||||
|
||||
|
||||
def complex_msg():
|
||||
return Complex(
|
||||
foo_str="yep",
|
||||
fe=Fe(abc="1"),
|
||||
nested_data=NestedData(
|
||||
struct_foo={
|
||||
"foo": google.Struct(
|
||||
fields={
|
||||
"hello": google.Value(
|
||||
list_value=google.ListValue(
|
||||
values=[google.Value(string_value="world")]
|
||||
)
|
||||
)
|
||||
}
|
||||
),
|
||||
},
|
||||
map_str_any_bar={
|
||||
"key": google.Any(value=b"value"),
|
||||
},
|
||||
),
|
||||
mapping={
|
||||
"message": google.Any(value=bytes(Fi(abc="hi"))),
|
||||
"string": google.Any(value=b"howdy"),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def test_pickling_complex_message():
|
||||
msg = complex_msg()
|
||||
deser = unpickled(msg)
|
||||
assert msg == deser
|
||||
assert msg.fe.abc == "1"
|
||||
assert msg.is_set("fi") is not True
|
||||
assert msg.mapping["message"] == google.Any(value=bytes(Fi(abc="hi")))
|
||||
assert msg.mapping["string"].value.decode() == "howdy"
|
||||
assert (
|
||||
msg.nested_data.struct_foo["foo"]
|
||||
.fields["hello"]
|
||||
.list_value.values[0]
|
||||
.string_value
|
||||
== "world"
|
||||
)
|
||||
|
||||
|
||||
def test_recursive_message():
|
||||
from tests.output_betterproto.recursivemessage import Test as RecursiveMessage
|
||||
|
||||
msg = RecursiveMessage()
|
||||
msg = unpickled(msg)
|
||||
|
||||
assert msg.child == RecursiveMessage()
|
||||
|
||||
# Lazily-created zero-value children must not affect equality.
|
||||
assert msg == RecursiveMessage()
|
||||
|
||||
# Lazily-created zero-value children must not affect serialization.
|
||||
assert bytes(msg) == b""
|
||||
|
||||
|
||||
def test_recursive_message_defaults():
|
||||
from tests.output_betterproto.recursivemessage import (
|
||||
Intermediate,
|
||||
Test as RecursiveMessage,
|
||||
)
|
||||
|
||||
msg = RecursiveMessage(name="bob", intermediate=Intermediate(42))
|
||||
msg = unpickled(msg)
|
||||
|
||||
# set values are as expected
|
||||
assert msg == RecursiveMessage(name="bob", intermediate=Intermediate(42))
|
||||
|
||||
# lazy initialized works modifies the message
|
||||
assert msg != RecursiveMessage(
|
||||
name="bob", intermediate=Intermediate(42), child=RecursiveMessage(name="jude")
|
||||
)
|
||||
msg.child.child.name = "jude"
|
||||
assert msg == RecursiveMessage(
|
||||
name="bob",
|
||||
intermediate=Intermediate(42),
|
||||
child=RecursiveMessage(child=RecursiveMessage(name="jude")),
|
||||
)
|
||||
|
||||
# lazily initialization recurses as needed
|
||||
assert msg.child.child.child.child.child.child.child == RecursiveMessage()
|
||||
assert msg.intermediate.child.intermediate == Intermediate()
|
||||
|
||||
|
||||
@dataclass
|
||||
class PickledMessage(betterproto.Message):
|
||||
foo: bool = betterproto.bool_field(1)
|
||||
bar: int = betterproto.int32_field(2)
|
||||
baz: List[str] = betterproto.string_field(3)
|
||||
|
||||
|
||||
def test_copyability():
|
||||
msg = PickledMessage(bar=12, baz=["hello"])
|
||||
msg = unpickled(msg)
|
||||
|
||||
copied = copy(msg)
|
||||
assert msg == copied
|
||||
assert msg is not copied
|
||||
assert msg.baz is copied.baz
|
||||
|
||||
deepcopied = deepcopy(msg)
|
||||
assert msg == deepcopied
|
||||
assert msg is not deepcopied
|
||||
assert msg.baz is not deepcopied.baz
|
||||
|
||||
|
||||
def test_message_can_be_cached():
|
||||
"""Cachelib uses pickling to cache values"""
|
||||
|
||||
cache = cachelib.SimpleCache()
|
||||
|
||||
def use_cache():
|
||||
calls = getattr(use_cache, "calls", 0)
|
||||
result = cache.get("message")
|
||||
if result is not None:
|
||||
return result
|
||||
else:
|
||||
setattr(use_cache, "calls", calls + 1)
|
||||
result = complex_msg()
|
||||
cache.set("message", result)
|
||||
return result
|
||||
|
||||
for n in range(10):
|
||||
if n == 0:
|
||||
assert not cache.has("message")
|
||||
else:
|
||||
assert cache.has("message")
|
||||
|
||||
msg = use_cache()
|
||||
assert use_cache.calls == 1 # The message is only ever built once
|
||||
assert msg.fe.abc == "1"
|
||||
assert msg.is_set("fi") is not True
|
||||
assert msg.mapping["message"] == google.Any(value=bytes(Fi(abc="hi")))
|
||||
assert msg.mapping["string"].value.decode() == "howdy"
|
||||
assert (
|
||||
msg.nested_data.struct_foo["foo"]
|
||||
.fields["hello"]
|
||||
.list_value.values[0]
|
||||
.string_value
|
||||
== "world"
|
||||
)
|
@ -1,6 +1,8 @@
|
||||
from dataclasses import dataclass
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from shutil import which
|
||||
from subprocess import run
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
@ -40,6 +42,8 @@ map_example = map.Test().from_dict({"counts": {"blah": 1, "Blah2": 2}})
|
||||
|
||||
streams_path = Path("tests/streams/")
|
||||
|
||||
java = which("java")
|
||||
|
||||
|
||||
def test_load_varint_too_long():
|
||||
with BytesIO(
|
||||
@ -127,6 +131,18 @@ def test_message_dump_file_multiple(tmp_path):
|
||||
assert test_stream.read() == exp_stream.read()
|
||||
|
||||
|
||||
def test_message_dump_delimited(tmp_path):
|
||||
with open(tmp_path / "message_dump_delimited.out", "wb") as stream:
|
||||
oneof_example.dump(stream, betterproto.SIZE_DELIMITED)
|
||||
oneof_example.dump(stream, betterproto.SIZE_DELIMITED)
|
||||
nested_example.dump(stream, betterproto.SIZE_DELIMITED)
|
||||
|
||||
with open(tmp_path / "message_dump_delimited.out", "rb") as test_stream, open(
|
||||
streams_path / "delimited_messages.in", "rb"
|
||||
) as exp_stream:
|
||||
assert test_stream.read() == exp_stream.read()
|
||||
|
||||
|
||||
def test_message_len():
|
||||
assert len_oneof == len(bytes(oneof_example))
|
||||
assert len(nested_example) == len(bytes(nested_example))
|
||||
@ -155,7 +171,15 @@ def test_message_load_too_small():
|
||||
oneof.Test().load(stream, len_oneof - 1)
|
||||
|
||||
|
||||
def test_message_too_large():
|
||||
def test_message_load_delimited():
|
||||
with open(streams_path / "delimited_messages.in", "rb") as stream:
|
||||
assert oneof.Test().load(stream, betterproto.SIZE_DELIMITED) == oneof_example
|
||||
assert oneof.Test().load(stream, betterproto.SIZE_DELIMITED) == oneof_example
|
||||
assert nested.Test().load(stream, betterproto.SIZE_DELIMITED) == nested_example
|
||||
assert stream.read(1) == b""
|
||||
|
||||
|
||||
def test_message_load_too_large():
|
||||
with open(
|
||||
streams_path / "message_dump_file_single.expected", "rb"
|
||||
) as stream, pytest.raises(ValueError):
|
||||
@ -266,3 +290,145 @@ def test_dump_varint_positive(tmp_path):
|
||||
streams_path / "dump_varint_positive.expected", "rb"
|
||||
) as exp_stream:
|
||||
assert test_stream.read() == exp_stream.read()
|
||||
|
||||
|
||||
# Java compatibility tests
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def compile_jar():
|
||||
# Skip if not all required tools are present
|
||||
if java is None:
|
||||
pytest.skip("`java` command is absent and is required")
|
||||
mvn = which("mvn")
|
||||
if mvn is None:
|
||||
pytest.skip("Maven is absent and is required")
|
||||
|
||||
# Compile the JAR
|
||||
proc_maven = run([mvn, "clean", "install", "-f", "tests/streams/java/pom.xml"])
|
||||
if proc_maven.returncode != 0:
|
||||
pytest.skip(
|
||||
"Maven compatibility-test.jar build failed (maybe Java version <11?)"
|
||||
)
|
||||
|
||||
|
||||
jar = "tests/streams/java/target/compatibility-test.jar"
|
||||
|
||||
|
||||
def run_jar(command: str, tmp_path):
|
||||
return run([java, "-jar", jar, command, tmp_path], check=True)
|
||||
|
||||
|
||||
def run_java_single_varint(value: int, tmp_path) -> int:
|
||||
# Write single varint to file
|
||||
with open(tmp_path / "py_single_varint.out", "wb") as stream:
|
||||
betterproto.dump_varint(value, stream)
|
||||
|
||||
# Have Java read this varint and write it back
|
||||
run_jar("single_varint", tmp_path)
|
||||
|
||||
# Read single varint from Java output file
|
||||
with open(tmp_path / "java_single_varint.out", "rb") as stream:
|
||||
returned = betterproto.load_varint(stream)
|
||||
with pytest.raises(EOFError):
|
||||
betterproto.load_varint(stream)
|
||||
|
||||
return returned
|
||||
|
||||
|
||||
def test_single_varint(compile_jar, tmp_path):
|
||||
single_byte = (1, b"\x01")
|
||||
multi_byte = (123456789, b"\x95\x9A\xEF\x3A")
|
||||
|
||||
# Write a single-byte varint to a file and have Java read it back
|
||||
returned = run_java_single_varint(single_byte[0], tmp_path)
|
||||
assert returned == single_byte
|
||||
|
||||
# Same for a multi-byte varint
|
||||
returned = run_java_single_varint(multi_byte[0], tmp_path)
|
||||
assert returned == multi_byte
|
||||
|
||||
|
||||
def test_multiple_varints(compile_jar, tmp_path):
|
||||
single_byte = (1, b"\x01")
|
||||
multi_byte = (123456789, b"\x95\x9A\xEF\x3A")
|
||||
over32 = (3000000000, b"\x80\xBC\xC1\x96\x0B")
|
||||
|
||||
# Write two varints to the same file
|
||||
with open(tmp_path / "py_multiple_varints.out", "wb") as stream:
|
||||
betterproto.dump_varint(single_byte[0], stream)
|
||||
betterproto.dump_varint(multi_byte[0], stream)
|
||||
betterproto.dump_varint(over32[0], stream)
|
||||
|
||||
# Have Java read these varints and write them back
|
||||
run_jar("multiple_varints", tmp_path)
|
||||
|
||||
# Read varints from Java output file
|
||||
with open(tmp_path / "java_multiple_varints.out", "rb") as stream:
|
||||
returned_single = betterproto.load_varint(stream)
|
||||
returned_multi = betterproto.load_varint(stream)
|
||||
returned_over32 = betterproto.load_varint(stream)
|
||||
with pytest.raises(EOFError):
|
||||
betterproto.load_varint(stream)
|
||||
|
||||
assert returned_single == single_byte
|
||||
assert returned_multi == multi_byte
|
||||
assert returned_over32 == over32
|
||||
|
||||
|
||||
def test_single_message(compile_jar, tmp_path):
|
||||
# Write message to file
|
||||
with open(tmp_path / "py_single_message.out", "wb") as stream:
|
||||
oneof_example.dump(stream)
|
||||
|
||||
# Have Java read and return the message
|
||||
run_jar("single_message", tmp_path)
|
||||
|
||||
# Read and check the returned message
|
||||
with open(tmp_path / "java_single_message.out", "rb") as stream:
|
||||
returned = oneof.Test().load(stream, len(bytes(oneof_example)))
|
||||
assert stream.read() == b""
|
||||
|
||||
assert returned == oneof_example
|
||||
|
||||
|
||||
def test_multiple_messages(compile_jar, tmp_path):
|
||||
# Write delimited messages to file
|
||||
with open(tmp_path / "py_multiple_messages.out", "wb") as stream:
|
||||
oneof_example.dump(stream, betterproto.SIZE_DELIMITED)
|
||||
nested_example.dump(stream, betterproto.SIZE_DELIMITED)
|
||||
|
||||
# Have Java read and return the messages
|
||||
run_jar("multiple_messages", tmp_path)
|
||||
|
||||
# Read and check the returned messages
|
||||
with open(tmp_path / "java_multiple_messages.out", "rb") as stream:
|
||||
returned_oneof = oneof.Test().load(stream, betterproto.SIZE_DELIMITED)
|
||||
returned_nested = nested.Test().load(stream, betterproto.SIZE_DELIMITED)
|
||||
assert stream.read() == b""
|
||||
|
||||
assert returned_oneof == oneof_example
|
||||
assert returned_nested == nested_example
|
||||
|
||||
|
||||
def test_infinite_messages(compile_jar, tmp_path):
|
||||
num_messages = 5
|
||||
|
||||
# Write delimited messages to file
|
||||
with open(tmp_path / "py_infinite_messages.out", "wb") as stream:
|
||||
for x in range(num_messages):
|
||||
oneof_example.dump(stream, betterproto.SIZE_DELIMITED)
|
||||
|
||||
# Have Java read and return the messages
|
||||
run_jar("infinite_messages", tmp_path)
|
||||
|
||||
# Read and check the returned messages
|
||||
messages = []
|
||||
with open(tmp_path / "java_infinite_messages.out", "rb") as stream:
|
||||
while True:
|
||||
try:
|
||||
messages.append(oneof.Test().load(stream, betterproto.SIZE_DELIMITED))
|
||||
except EOFError:
|
||||
break
|
||||
|
||||
assert len(messages) == num_messages
|
||||
|
27
tests/test_timestamp.py
Normal file
27
tests/test_timestamp.py
Normal file
@ -0,0 +1,27 @@
|
||||
from datetime import (
|
||||
datetime,
|
||||
timezone,
|
||||
)
|
||||
|
||||
import pytest
|
||||
|
||||
from betterproto import _Timestamp
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"dt",
|
||||
[
|
||||
datetime(2023, 10, 11, 9, 41, 12, tzinfo=timezone.utc),
|
||||
datetime.now(timezone.utc),
|
||||
# potential issue with floating point precision:
|
||||
datetime(2242, 12, 31, 23, 0, 0, 1, tzinfo=timezone.utc),
|
||||
# potential issue with negative timestamps:
|
||||
datetime(1969, 12, 31, 23, 0, 0, 1, tzinfo=timezone.utc),
|
||||
],
|
||||
)
|
||||
def test_timestamp_to_datetime_and_back(dt: datetime):
|
||||
"""
|
||||
Make sure converting a datetime to a protobuf timestamp message
|
||||
and then back again ends up with the same datetime.
|
||||
"""
|
||||
assert _Timestamp.from_datetime(dt).to_datetime() == dt
|
Loading…
x
Reference in New Issue
Block a user