Compare commits


22 Commits
master ... next

Author SHA1 Message Date
Georg K
87b84afc4b Merge branch 'rust_extras' into next 2023-09-06 14:34:14 +03:00
Erik Friese
8283ef7298 bugfix
map fields with values of message type were not serialized correctly
2023-09-05 21:58:58 +02:00
Erik Friese
0931eb3bf5 bugfix
byte fields were deserialized incorrectly
2023-09-05 20:26:46 +02:00
Erik Friese
8f535913a1 bugfix
using python identifier in message names
necessary to distinguish between dynamically created classes with the same name
2023-09-05 20:06:00 +02:00
Erik Friese
fd02cb6180 supporting datetime and timedelta 2023-09-05 11:27:04 +02:00
Erik Friese
950d2f6536 google wrapper types 2023-09-04 21:09:06 +02:00
Erik Friese
29f12ea88d map support 2023-09-04 12:52:12 +02:00
Erik Friese
219233b50e bugfix: parsing unknown fields properly 2023-08-31 17:57:28 +02:00
Erik Friese
2d30bdb7b2 Merge branch 'master' into rust_extras 2023-08-31 17:39:34 +02:00
Erik Friese
bdd3389b17 bugfix in proto descriptor creation
reference cycles in betterproto messages have led to infinite recursion
2023-08-31 17:17:28 +02:00
Erik Friese
441844b97a avoiding name clash 2023-08-31 13:18:07 +02:00
Erik Friese
a413d08fc1 enum support 2023-08-30 21:06:32 +02:00
Erik Friese
24d694afe2 storing unknown fields 2023-08-30 15:49:25 +02:00
Erik Friese
84af157122 minor refactoring 2023-08-30 15:39:34 +02:00
Erik Friese
df0c17bf0a optional support 2023-08-29 18:08:59 +02:00
Erik Friese
d1825026db lock file reverted to master 2023-08-27 18:46:20 +02:00
Erik Friese
a12c9d24de oneof support 2023-08-27 14:37:23 +02:00
Erik Friese
d79a9eee14 deserializing lists 2023-08-27 13:22:51 +02:00
Erik Friese
d848d05710 minor optimization 2023-08-26 21:47:35 +02:00
Erik Friese
26da86d2cd proper error handling 2023-08-26 21:32:35 +02:00
Erik Friese
604dcb104f type info + doc string added 2023-08-26 21:04:30 +02:00
Erik Friese
421aa78014 Native deserialization based on Rust and PyO3
Proof of concept
Only capable of deserializing (nested) Messages with primitive fields
No handling of lists, maps, enums, .. implemented yet
See `example.py` for a working example
2023-08-26 13:04:26 +02:00
94 changed files with 4752 additions and 10320 deletions


@ -1,35 +0,0 @@
name: Release
run-name: ${{ gitea.actor }} is running the CI pipeline
on:
push:
branches:
- master
jobs:
packaging:
name: Distribution
runs-on: ubuntu-latest
env:
EXT_FIX: "6"
steps:
- uses: actions/checkout@v4
- name: Set up Python 3.9
uses: actions/setup-python@v5
with:
python-version: '3.9'
- name: Install poetry
run: python -m pip install poetry chardet
- name: Install poetry compiler
run: poetry install -E compiler
- name: Set poetry version
run: PV=$(poetry version -s) && poetry version ${PV}+jar3b${EXT_FIX}
- name: Build package
run: poetry build
- name: Add pypi source
run: poetry source add --priority=supplemental ahax https://git.ahax86.ru/api/packages/pub/pypi
- name: Add pypi credentials
run: poetry config http-basic.ahax ${{ secrets.REPO_USER }} ${{ secrets.REPO_PASS }}
- name: Push to pypi
run: poetry publish -r ahax -u ${{ secrets.REPO_USER }} -p ${{ secrets.REPO_PASS }} -n


@ -2,7 +2,7 @@
There's lots to do, and we're working hard, so any help is welcome!
- :speech_balloon: Join us on [Discord](https://discord.gg/DEVteTupPb)!
- :speech_balloon: Join us on [Slack](https://join.slack.com/t/betterproto/shared_invite/zt-f0n0uolx-iN8gBNrkPxtKHTLpG3o1OQ)!
What can you do?
@ -15,9 +15,9 @@ What can you do?
- File a bug (please check it's not a duplicate)
- Propose an enhancement
- :white_check_mark: Create a PR:
- [Creating a failing test-case](https://github.com/danielgtaylor/python-betterproto/blob/master/tests/README.md) to make bug-fixing easier
- [Creating a failing test-case](https://github.com/danielgtaylor/python-betterproto/blob/master/betterproto/tests/README.md) to make bug-fixing easier
- Fix any of the open issues
- [Good first issues](https://github.com/danielgtaylor/python-betterproto/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22)
- [Issues with tests](https://github.com/danielgtaylor/python-betterproto/issues?q=is%3Aissue+is%3Aopen+label%3A%22has+test%22)
- New bugfix or idea
- If you'd like to discuss your idea first, join us on Discord!
- If you'd like to discuss your idea first, join us on Slack!


@ -1,63 +0,0 @@
name: Bug Report
description: Report broken or incorrect behaviour
labels: ["bug", "investigation needed"]
body:
- type: markdown
attributes:
value: >
Thanks for taking the time to fill out a bug report!
If you're not sure it's a bug and you just have a question, the [community Discord channel](https://discord.gg/DEVteTupPb) is a better place for general questions than a GitHub issue.
- type: input
attributes:
label: Summary
description: A simple summary of your bug report
validations:
required: true
- type: textarea
attributes:
label: Reproduction Steps
description: >
What you did to make it happen.
Ideally there should be a short code snippet in this section to help reproduce the bug.
validations:
required: true
- type: textarea
attributes:
label: Expected Results
description: >
What did you expect to happen?
validations:
required: true
- type: textarea
attributes:
label: Actual Results
description: >
What actually happened?
validations:
required: true
- type: textarea
attributes:
label: System Information
description: >
Paste the result of `protoc --version; python --version; pip show betterproto` below.
validations:
required: true
- type: checkboxes
attributes:
label: Checklist
options:
- label: I have searched the issues for duplicates.
required: true
- label: I have shown the entire traceback, if possible.
required: true
- label: I have verified this issue occurs on the latest prerelease of betterproto which can be installed using `pip install -U --pre betterproto`, if possible.
required: true


@ -1,6 +0,0 @@
name:
description:
contact_links:
- name: For questions about the library
about: Support questions are better answered in our Discord group.
url: https://discord.gg/DEVteTupPb


@ -1,49 +0,0 @@
name: Feature Request
description: Suggest a feature for this library
labels: ["enhancement"]
body:
- type: input
attributes:
label: Summary
description: >
What problem is your feature trying to solve? What would become easier or possible if this feature were implemented?
validations:
required: true
- type: dropdown
attributes:
multiple: false
label: What is the feature request for?
options:
- The core library
- RPC handling
- The documentation
validations:
required: true
- type: textarea
attributes:
label: The Problem
description: >
What problem is your feature trying to solve?
What would become easier or possible if this feature were implemented?
validations:
required: true
- type: textarea
attributes:
label: The Ideal Solution
description: >
What is your ideal solution to the problem?
What would you like this feature to do?
validations:
required: true
- type: textarea
attributes:
label: The Current Solution
description: >
What is the current solution to the problem, if any?
validations:
required: false


@ -1,16 +0,0 @@
## Summary
<!-- What is this pull request for? Does it fix any issues? -->
## Checklist
<!-- Put an x inside [ ] to check it, like so: [x] -->
- [ ] If code changes were made then they have been tested.
- [ ] I have updated the documentation to reflect the changes.
- [ ] This PR fixes an issue.
- [ ] This PR adds something new (e.g. new method or parameters).
- [ ] This change has an associated test.
- [ ] This PR is a breaking change (e.g. methods or parameters removed/renamed)
- [ ] This PR is **not** a code change (e.g. documentation, README, ...)


@ -16,19 +16,19 @@ jobs:
fail-fast: false
matrix:
os: [Ubuntu, MacOS, Windows]
python-version: ['3.9', '3.10', '3.11', '3.12', '3.13']
python-version: ['3.7', '3.8', '3.9', '3.10', '3.11']
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Get full Python version
id: full-python-version
shell: bash
run: echo "version=$(python -c "import sys; print('-'.join(str(v) for v in sys.version_info))")" >> "$GITHUB_OUTPUT"
run: echo ::set-output name=version::$(python -c "import sys; print('-'.join(str(v) for v in sys.version_info))")
- name: Install poetry
shell: bash
@ -41,7 +41,7 @@ jobs:
run: poetry config virtualenvs.in-project true
- name: Set up cache
uses: actions/cache@v4
uses: actions/cache@v3
id: cache
with:
path: .venv


@ -13,6 +13,6 @@ jobs:
name: Check code/doc formatting
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
- uses: pre-commit/action@v3.0.1
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
- uses: pre-commit/action@v2.0.3


@ -1,46 +0,0 @@
name: "CodeQL"
on:
push:
branches: [ "master" ]
pull_request:
branches:
- '**'
schedule:
- cron: '19 1 * * 6'
jobs:
analyze:
name: Analyze
runs-on: ubuntu-latest
permissions:
actions: read
contents: read
security-events: write
strategy:
fail-fast: false
matrix:
language: [ 'python' ]
steps:
- name: Checkout repository
uses: actions/checkout@v4
# Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL
uses: github/codeql-action/init@v3
with:
languages: ${{ matrix.language }}
# If you wish to specify custom queries, you can do so here or in a config file.
# By default, queries listed here will override any specified in a config file.
# Prefix the list here with "+" to use these queries and those in the config file.
# Details on CodeQL's query packs refer to : https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs
# queries: security-extended,security-and-quality
- name: Autobuild
uses: github/codeql-action/autobuild@v3
- name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@v3


@ -15,11 +15,11 @@ jobs:
name: Distribution
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up Python 3.9
uses: actions/setup-python@v5
- uses: actions/checkout@v3
- name: Set up Python 3.8
uses: actions/setup-python@v4
with:
python-version: 3.9
python-version: 3.8
- name: Install poetry
run: python -m pip install poetry
- name: Build package

.gitignore vendored

@ -18,4 +18,3 @@ output
.asv
venv
.devcontainer
.ruff_cache


@ -2,24 +2,20 @@ ci:
autofix_prs: false
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.9.1
- repo: https://github.com/pycqa/isort
rev: 5.11.5
hooks:
- id: ruff-format
args: ["--diff", "src", "tests"]
- id: ruff
args: ["--select", "I", "src", "tests"]
- id: isort
- repo: https://github.com/psf/black
rev: 23.1.0
hooks:
- id: black
args: ["--target-version", "py310"]
- repo: https://github.com/PyCQA/doc8
rev: 0.10.1
hooks:
- id: doc8
- id: doc8
additional_dependencies:
- toml
- repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks
rev: v2.14.0
hooks:
- id: pretty-format-java
args: [--autofix, --aosp]
files: ^.*\.java$


@ -7,29 +7,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Versions suffixed with `b*` are in `beta` and can be installed with `pip install --pre betterproto`.
## [2.0.0b7] - 2024-08-11
- **Breaking**: Support `Pydantic` v2 and dropping support for v1 [#588](https://github.com/danielgtaylor/python-betterproto/pull/588)
- **Breaking**: Attempting to access an unset `oneof` field now raises an `AttributeError`.
To see how to access `oneof` fields now, refer to [#558](https://github.com/danielgtaylor/python-betterproto/pull/558)
and [README.md](https://github.com/danielgtaylor/python-betterproto#one-of-support).
- **Breaking**: A custom `Enum` has been implemented to behave as an open set. Any checks for `isinstance(enum_member, enum.Enum)` and `issubclass(EnumSubclass, enum.Enum)` will now return `False`. This change also has the side effect of
preventing any passthrough of `Enum` members (i.e. `Foo.RED.GREEN` no longer works). See [#293](https://github.com/danielgtaylor/python-betterproto/pull/293) for more info; this fixed many bugs related to `Enum` handling.
- Add support for `pickle` methods [#535](https://github.com/danielgtaylor/python-betterproto/pull/535)
- Add support for `Struct` and `Value` types [#551](https://github.com/danielgtaylor/python-betterproto/pull/551)
- Add support for [`Rich` package](https://rich.readthedocs.io/en/latest/index.html) for pretty printing [#508](https://github.com/danielgtaylor/python-betterproto/pull/508)
- Improve support for streaming messages [#518](https://github.com/danielgtaylor/python-betterproto/pull/518) [#529](https://github.com/danielgtaylor/python-betterproto/pull/529)
- Improve performance of serializing / de-serializing messages [#545](https://github.com/danielgtaylor/python-betterproto/pull/545)
- Improve the handling of message name collisions with typing by allowing the method / type of imports to be configured.
Refer to [#582](https://github.com/danielgtaylor/python-betterproto/pull/582)
and [README.md](https://github.com/danielgtaylor/python-betterproto#configuration-typing-imports).
- Fix roundtrip parsing of `datetime`s [#534](https://github.com/danielgtaylor/python-betterproto/pull/534)
- Fix accessing unset optional fields [#523](https://github.com/danielgtaylor/python-betterproto/pull/523)
- Fix `Message` equality comparison [#513](https://github.com/danielgtaylor/python-betterproto/pull/513)
- Fix behaviour with long comment messages [#532](https://github.com/danielgtaylor/python-betterproto/pull/532)
- Add a warning when calling a deprecated message [#596](https://github.com/danielgtaylor/python-betterproto/pull/596)
## [2.0.0b6] - 2023-06-25
- **Breaking**: the minimum Python version has been bumped to `3.7` [#444](https://github.com/danielgtaylor/python-betterproto/pull/444)


@ -1,7 +1,6 @@
# Better Protobuf / gRPC Support for Python
![](https://github.com/danielgtaylor/python-betterproto/actions/workflows/ci.yml/badge.svg)
![](https://github.com/danielgtaylor/python-betterproto/workflows/CI/badge.svg)
> :octocat: If you're reading this on github, please be aware that it might mention unreleased features! See the latest released README on [pypi](https://pypi.org/project/betterproto/).
This project aims to provide an improved experience when using Protobuf / gRPC in a modern Python environment by making use of modern language features and generating readable, understandable, idiomatic Python code. It will not support legacy features or environments (e.g. Protobuf 2). The following are supported:
@ -278,22 +277,7 @@ message Test {
}
```
On Python 3.10 and later, you can use a `match` statement to access the provided one-of field, which supports type-checking:
```py
test = Test()
match test:
case Test(on=value):
print(value) # value: bool
case Test(count=value):
print(value) # value: int
case Test(name=value):
print(value) # value: str
case _:
print("No value provided")
```
You can also use `betterproto.which_one_of(message, group_name)` to determine which of the fields was set. It returns a tuple of the field name and value, or a blank string and `None` if unset.
You can use `betterproto.which_one_of(message, group_name)` to determine which of the fields was set. It returns a tuple of the field name and value, or a blank string and `None` if unset.
```py
>>> test = Test()
@ -308,11 +292,17 @@ You can also use `betterproto.which_one_of(message, group_name)` to determine wh
>>> test.count = 57
>>> betterproto.which_one_of(test, "foo")
["count", 57]
>>> test.on
False
# Default (zero) values also work.
>>> test.name = ""
>>> betterproto.which_one_of(test, "foo")
["name", ""]
>>> test.count
0
>>> test.on
False
```
Again this is a little different than the official Google code generator:
@ -392,54 +382,11 @@ swap the dataclass implementation from the builtin python dataclass to the
pydantic dataclass. You must have pydantic as a dependency in your project for
this to work.
## Configuration typing imports
By default, typing types are imported directly from `typing`. This can sometimes cause conflicts during generation if a generated type shares a name with an imported one. In that case you can configure how types are imported, choosing from three options:
### Direct
```
protoc -I . --python_betterproto_opt=typing.direct --python_betterproto_out=lib example.proto
```
This configuration is the default and imports types as follows:
```
from typing import (
List,
Optional,
Union
)
...
value: List[str] = []
value2: Optional[str] = None
value3: Union[str, int] = 1
```
### Root
```
protoc -I . --python_betterproto_opt=typing.root --python_betterproto_out=lib example.proto
```
This configuration imports the root `typing` module and accesses the types from it directly:
```
import typing
...
value: typing.List[str] = []
value2: typing.Optional[str] = None
value3: typing.Union[str, int] = 1
```
### 310
```
protoc -I . --python_betterproto_opt=typing.310 --python_betterproto_out=lib example.proto
```
This configuration avoids importing `typing` altogether where possible and uses the Python 3.10 syntax:
```
...
value: list[str] = []
value2: str | None = None
value3: str | int = 1
```
## Development
- _Join us on [Discord](https://discord.gg/DEVteTupPb)!_
- _Join us on [Slack](https://join.slack.com/t/betterproto/shared_invite/zt-f0n0uolx-iN8gBNrkPxtKHTLpG3o1OQ)!_
- _See how you can help &rarr; [Contributing](.github/CONTRIBUTING.md)_
### Requirements
@ -575,7 +522,7 @@ protoc \
## Community
Join us on [Discord](https://discord.gg/DEVteTupPb)!
Join us on [Slack](https://join.slack.com/t/betterproto/shared_invite/zt-f0n0uolx-iN8gBNrkPxtKHTLpG3o1OQ)!
## License


@ -6,32 +6,32 @@ import betterproto
@dataclass
class TestMessage(betterproto.Message):
foo: int = betterproto.uint32_field(1)
bar: str = betterproto.string_field(2)
baz: float = betterproto.float_field(3)
foo: int = betterproto.uint32_field(0)
bar: str = betterproto.string_field(1)
baz: float = betterproto.float_field(2)
@dataclass
class TestNestedChildMessage(betterproto.Message):
str_key: str = betterproto.string_field(1)
bytes_key: bytes = betterproto.bytes_field(2)
bool_key: bool = betterproto.bool_field(3)
float_key: float = betterproto.float_field(4)
int_key: int = betterproto.uint64_field(5)
str_key: str = betterproto.string_field(0)
bytes_key: bytes = betterproto.bytes_field(1)
bool_key: bool = betterproto.bool_field(2)
float_key: float = betterproto.float_field(3)
int_key: int = betterproto.uint64_field(4)
@dataclass
class TestNestedMessage(betterproto.Message):
foo: TestNestedChildMessage = betterproto.message_field(1)
bar: TestNestedChildMessage = betterproto.message_field(2)
baz: TestNestedChildMessage = betterproto.message_field(3)
foo: TestNestedChildMessage = betterproto.message_field(0)
bar: TestNestedChildMessage = betterproto.message_field(1)
baz: TestNestedChildMessage = betterproto.message_field(2)
@dataclass
class TestRepeatedMessage(betterproto.Message):
foo_repeat: List[str] = betterproto.string_field(1)
bar_repeat: List[int] = betterproto.int64_field(2)
baz_repeat: List[bool] = betterproto.bool_field(3)
foo_repeat: List[str] = betterproto.string_field(0)
bar_repeat: List[int] = betterproto.int64_field(1)
baz_repeat: List[bool] = betterproto.bool_field(2)
class BenchMessage:
@ -44,14 +44,25 @@ class BenchMessage:
self.instance_filled_bytes = bytes(self.instance_filled)
self.instance_filled_nested = TestNestedMessage(
TestNestedChildMessage("foo", bytearray(b"test1"), True, 0.1234, 500),
TestNestedChildMessage("bar", bytearray(b"test2"), True, 3.1415, 302),
TestNestedChildMessage("bar", bytearray(b"test2"), True, 3.1415, -302),
TestNestedChildMessage("baz", bytearray(b"test3"), False, 1e5, 300),
)
self.instance_filled_nested_bytes = bytes(self.instance_filled_nested)
self.instance_filled_repeated = TestRepeatedMessage(
[f"test{i}" for i in range(1_000)],
[(i - 500) ** 3 for i in range(1_000)],
[i % 2 == 0 for i in range(1_000)],
[
"test1",
"test2",
"test3",
"test4",
"test5",
"test6",
"test7",
"test8",
"test9",
"test10",
],
[2, -100, 0, 500000, 600, -425678, 1000000000, -300, 1, -694214214466],
[True, False, False, False, True, True, False, True, False, False],
)
self.instance_filled_repeated_bytes = bytes(self.instance_filled_repeated)
@ -60,9 +71,9 @@ class BenchMessage:
@dataclass
class Message(betterproto.Message):
foo: int = betterproto.uint32_field(1)
bar: str = betterproto.string_field(2)
baz: float = betterproto.float_field(3)
foo: int = betterproto.uint32_field(0)
bar: str = betterproto.string_field(1)
baz: float = betterproto.float_field(2)
def time_instantiation(self):
"""Time instantiation"""

betterproto-extras/.gitignore vendored Normal file

@ -0,0 +1,72 @@
/target
# Byte-compiled / optimized / DLL files
__pycache__/
.pytest_cache/
*.py[cod]
# C extensions
*.so
# Distribution / packaging
.Python
.venv/
env/
bin/
build/
develop-eggs/
dist/
eggs/
lib/
lib64/
parts/
sdist/
var/
include/
man/
venv/
*.egg-info/
.installed.cfg
*.egg
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
pip-selfcheck.json
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.cache
nosetests.xml
coverage.xml
# Translations
*.mo
# Mr Developer
.mr.developer.cfg
.project
.pydevproject
# Rope
.ropeproject
# Django stuff:
*.log
*.pot
.DS_Store
# Sphinx documentation
docs/_build/
# PyCharm
.idea/
# VSCode
.vscode/
# Pyenv
.python-version

betterproto-extras/Cargo.lock generated Normal file

@ -0,0 +1,383 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 3
[[package]]
name = "anyhow"
version = "1.0.75"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a4668cab20f66d8d020e1fbc0ebe47217433c1b6c8f2040faf858554e394ace6"
[[package]]
name = "autocfg"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
[[package]]
name = "betterproto-extras"
version = "0.1.0"
dependencies = [
"indoc 2.0.3",
"prost-reflect",
"pyo3",
"thiserror",
]
[[package]]
name = "bitflags"
version = "1.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
[[package]]
name = "bytes"
version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "89b2fd2a0dcf38d7971e2194b6b6eebab45ae01067456a7fd93d5547a61b70be"
[[package]]
name = "cfg-if"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "either"
version = "1.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07"
[[package]]
name = "indoc"
version = "1.0.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bfa799dd5ed20a7e349f3b4639aa80d74549c81716d9ec4f994c9b5815598306"
[[package]]
name = "indoc"
version = "2.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2c785eefb63ebd0e33416dfcb8d6da0bf27ce752843a45632a67bf10d4d4b5c4"
[[package]]
name = "itertools"
version = "0.10.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473"
dependencies = [
"either",
]
[[package]]
name = "libc"
version = "0.2.147"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b4668fb0ea861c1df094127ac5f1da3409a82116a4ba74fca2e58ef927159bb3"
[[package]]
name = "lock_api"
version = "0.4.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c1cc9717a20b1bb222f333e6a92fd32f7d8a18ddc5a3191a11af45dcbf4dcd16"
dependencies = [
"autocfg",
"scopeguard",
]
[[package]]
name = "memoffset"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a634b1c61a95585bd15607c6ab0c4e5b226e695ff2800ba0cdccddf208c406c"
dependencies = [
"autocfg",
]
[[package]]
name = "once_cell"
version = "1.18.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d"
[[package]]
name = "parking_lot"
version = "0.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f"
dependencies = [
"lock_api",
"parking_lot_core",
]
[[package]]
name = "parking_lot_core"
version = "0.9.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "93f00c865fe7cabf650081affecd3871070f26767e7b2070a3ffae14c654b447"
dependencies = [
"cfg-if",
"libc",
"redox_syscall",
"smallvec",
"windows-targets",
]
[[package]]
name = "proc-macro2"
version = "1.0.66"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "18fb31db3f9bddb2ea821cde30a9f70117e3f119938b5ee630b7403aa6e2ead9"
dependencies = [
"unicode-ident",
]
[[package]]
name = "prost"
version = "0.11.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b82eaa1d779e9a4bc1c3217db8ffbeabaae1dca241bf70183242128d48681cd"
dependencies = [
"bytes",
"prost-derive",
]
[[package]]
name = "prost-derive"
version = "0.11.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5d2d8d10f3c6ded6da8b05b5fb3b8a5082514344d56c9f871412d29b4e075b4"
dependencies = [
"anyhow",
"itertools",
"proc-macro2",
"quote",
"syn 1.0.109",
]
[[package]]
name = "prost-reflect"
version = "0.11.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6b823de344848e011658ac981009100818b322421676740546f8b52ed5249428"
dependencies = [
"once_cell",
"prost",
"prost-types",
]
[[package]]
name = "prost-types"
version = "0.11.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "213622a1460818959ac1181aaeb2dc9c7f63df720db7d788b3e24eacd1983e13"
dependencies = [
"prost",
]
[[package]]
name = "pyo3"
version = "0.19.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e681a6cfdc4adcc93b4d3cf993749a4552018ee0a9b65fc0ccfad74352c72a38"
dependencies = [
"cfg-if",
"indoc 1.0.9",
"libc",
"memoffset",
"parking_lot",
"pyo3-build-config",
"pyo3-ffi",
"pyo3-macros",
"unindent",
]
[[package]]
name = "pyo3-build-config"
version = "0.19.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "076c73d0bc438f7a4ef6fdd0c3bb4732149136abd952b110ac93e4edb13a6ba5"
dependencies = [
"once_cell",
"target-lexicon",
]
[[package]]
name = "pyo3-ffi"
version = "0.19.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e53cee42e77ebe256066ba8aa77eff722b3bb91f3419177cf4cd0f304d3284d9"
dependencies = [
"libc",
"pyo3-build-config",
]
[[package]]
name = "pyo3-macros"
version = "0.19.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dfeb4c99597e136528c6dd7d5e3de5434d1ceaf487436a3f03b2d56b6fc9efd1"
dependencies = [
"proc-macro2",
"pyo3-macros-backend",
"quote",
"syn 1.0.109",
]
[[package]]
name = "pyo3-macros-backend"
version = "0.19.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "947dc12175c254889edc0c02e399476c2f652b4b9ebd123aa655c224de259536"
dependencies = [
"proc-macro2",
"quote",
"syn 1.0.109",
]
[[package]]
name = "quote"
version = "1.0.33"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5267fca4496028628a95160fc423a33e8b2e6af8a5302579e322e4b520293cae"
dependencies = [
"proc-macro2",
]
[[package]]
name = "redox_syscall"
version = "0.3.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "567664f262709473930a4bf9e51bf2ebf3348f2e748ccc50dea20646858f8f29"
dependencies = [
"bitflags",
]
[[package]]
name = "scopeguard"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
[[package]]
name = "smallvec"
version = "1.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "62bb4feee49fdd9f707ef802e22365a35de4b7b299de4763d44bfea899442ff9"
[[package]]
name = "syn"
version = "1.0.109"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237"
dependencies = [
"proc-macro2",
"quote",
"unicode-ident",
]
[[package]]
name = "syn"
version = "2.0.29"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c324c494eba9d92503e6f1ef2e6df781e78f6a7705a0202d9801b198807d518a"
dependencies = [
"proc-macro2",
"quote",
"unicode-ident",
]
[[package]]
name = "target-lexicon"
version = "0.12.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9d0e916b1148c8e263850e1ebcbd046f333e0683c724876bb0da63ea4373dc8a"
[[package]]
name = "thiserror"
version = "1.0.47"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "97a802ec30afc17eee47b2855fc72e0c4cd62be9b4efe6591edde0ec5bd68d8f"
dependencies = [
"thiserror-impl",
]
[[package]]
name = "thiserror-impl"
version = "1.0.47"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6bb623b56e39ab7dcd4b1b98bb6c8f8d907ed255b18de254088016b27a8ee19b"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.29",
]
[[package]]
name = "unicode-ident"
version = "1.0.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "301abaae475aa91687eb82514b328ab47a211a533026cb25fc3e519b86adfc3c"
[[package]]
name = "unindent"
version = "0.1.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e1766d682d402817b5ac4490b3c3002d91dfa0d22812f341609f97b08757359c"
[[package]]
name = "windows-targets"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c"
dependencies = [
"windows_aarch64_gnullvm",
"windows_aarch64_msvc",
"windows_i686_gnu",
"windows_i686_msvc",
"windows_x86_64_gnu",
"windows_x86_64_gnullvm",
"windows_x86_64_msvc",
]
[[package]]
name = "windows_aarch64_gnullvm"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8"
[[package]]
name = "windows_aarch64_msvc"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc"
[[package]]
name = "windows_i686_gnu"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e"
[[package]]
name = "windows_i686_msvc"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406"
[[package]]
name = "windows_x86_64_gnu"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e"
[[package]]
name = "windows_x86_64_gnullvm"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc"
[[package]]
name = "windows_x86_64_msvc"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538"


@ -0,0 +1,14 @@
[package]
name = "betterproto-extras"
version = "0.1.0"
edition = "2021"
[lib]
name = "betterproto_extras"
crate-type = ["cdylib"]
[dependencies]
indoc = "2.0.3"
prost-reflect = "0.11.5"
pyo3 = { version = "0.19.2", features = ["abi3-py37", "extension-module"] }
thiserror = "1.0.47"


@ -0,0 +1,5 @@
def deserialize(msg, data: bytes):
"""
Parses the binary encoded Protobuf `data` with respect to the metadata
given by the betterproto message `msg`, and merges the result into `msg`.
"""


@ -0,0 +1,16 @@
[build-system]
requires = ["maturin>=1.2,<2.0"]
build-backend = "maturin"
[project]
name = "betterproto-extras"
requires-python = ">=3.7"
classifiers = [
"Programming Language :: Rust",
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
]
[tool.maturin]
features = ["pyo3/extension-module"]


@ -0,0 +1,289 @@
use crate::{
error::{Error, Result},
py_any_extras::PyAnyExtras,
};
use prost_reflect::{
prost_types::{
field_descriptor_proto::{Label, Type},
DescriptorProto, EnumDescriptorProto, EnumValueDescriptorProto, FieldDescriptorProto,
FileDescriptorProto, MessageOptions, OneofDescriptorProto,
},
DescriptorPool, MessageDescriptor,
};
use pyo3::PyAny;
use std::sync::{Mutex, OnceLock};
pub fn create_cached_descriptor(obj: &PyAny) -> Result<MessageDescriptor> {
static DESCRIPTOR_POOL: OnceLock<Mutex<DescriptorPool>> = OnceLock::new();
let mut pool = DESCRIPTOR_POOL
.get_or_init(|| Mutex::new(DescriptorPool::global()))
.lock()
.unwrap();
let cls = obj.getattr("__class__")?;
let name = format!("{}_{}", cls.qualified_name()?, cls.py_identifier());
if let Some(desc) = pool.get_message_by_name(&name) {
return Ok(desc);
}
let mut file = FileDescriptorProto {
name: Some(name.clone()),
..Default::default()
};
add_message_to_file(name.clone(), obj, &pool, &mut file)?;
pool.add_file_descriptor_proto(file)?;
Ok(pool.get_message_by_name(&name).expect("Just registered..."))
}
fn add_message_to_file(
message_name: String,
obj: &PyAny,
pool: &DescriptorPool,
file: &mut FileDescriptorProto,
) -> Result<()> {
let mut messages_to_add = vec![(message_name, obj)];
while let Some((message_name, obj)) = messages_to_add.pop() {
let meta = obj.get_proto_meta()?;
let mut message = DescriptorProto {
name: Some(message_name.to_string()),
..Default::default()
};
for item in meta
.getattr("meta_by_field_name")?
.call_method0("items")?
.iter()?
{
let (field_name, field_meta) = item?.extract::<(&str, &PyAny)>()?;
message.field.push({
let mut field = FieldDescriptorProto {
name: Some(field_name.to_string()),
number: Some(field_meta.getattr("number")?.extract::<i32>()?),
..Default::default()
};
let proto_type = field_meta.getattr("proto_type")?.extract::<&str>()?;
if proto_type == "map" {
field.set_type(Type::Message);
let (key, val) = field_meta.getattr("map_types")?.extract::<(&str, &str)>()?;
let key = map_type(key)?;
let val = map_type(val)?;
if matches!(
key,
Type::Float | Type::Double | Type::Bytes | Type::Message | Type::Enum
) {
return Err(Error::UnsupportedMapKeyType(key));
}
let map_entry_name = format!("{field_name}Entry");
field.type_name = Some(format!("{message_name}.{map_entry_name}"));
field.set_label(Label::Repeated);
message.nested_type.push(DescriptorProto {
name: Some(map_entry_name),
field: vec![
{
let mut proto = FieldDescriptorProto {
name: Some("key".to_string()),
number: Some(1),
..Default::default()
};
proto.set_type(key);
proto
},
{
let mut proto = FieldDescriptorProto {
name: Some("value".to_string()),
number: Some(2),
..Default::default()
};
proto.set_type(val);
if val == Type::Message {
set_type_name(
&message_name,
meta.get_class(&format!("{field_name}.value"))?,
&mut proto,
file,
&mut messages_to_add,
pool,
)?;
}
proto
},
],
options: Some(MessageOptions {
map_entry: Some(true),
..Default::default()
}),
..Default::default()
})
} else {
field.set_type(map_type(proto_type)?);
match field.r#type() {
Type::Message => match field_meta
.getattr("wraps")?
.extract::<Option<&str>>()?
.map(map_type)
.transpose()?
{
Some(Type::Bool) => {
field.type_name = Some("google.protobuf.BoolValue".to_string());
}
Some(Type::Double) => {
field.type_name = Some("google.protobuf.DoubleValue".to_string());
}
Some(Type::Float) => {
field.type_name = Some("google.protobuf.FloatValue".to_string());
}
Some(Type::Int64) => {
field.type_name = Some("google.protobuf.Int64Value".to_string());
}
Some(Type::Uint64) => {
field.type_name = Some("google.protobuf.UInt64Value".to_string());
}
Some(Type::Int32) => {
field.type_name = Some("google.protobuf.Int32Value".to_string());
}
Some(Type::Uint32) => {
field.type_name = Some("google.protobuf.UInt32Value".to_string());
}
Some(Type::String) => {
field.type_name = Some("google.protobuf.StringValue".to_string());
}
Some(Type::Bytes) => {
field.type_name = Some("google.protobuf.BytesValue".to_string());
}
Some(t) => return Err(Error::UnsupportedWrapperType(t)),
None => {
set_type_name(
&message_name,
meta.get_class(field_name)?,
&mut field,
file,
&mut messages_to_add,
pool,
)?;
}
},
Type::Enum => {
let cls = meta.get_class(field_name)?;
let cls_name =
format!("{}_{}", cls.qualified_name()?, cls.py_identifier());
field.type_name = Some(cls_name.to_string());
if pool.get_enum_by_name(&cls_name).is_none()
&& !file.enum_type.iter().any(|item| item.name() == cls_name)
{
let mut proto = EnumDescriptorProto {
name: Some(cls_name.clone()),
..Default::default()
};
for item in cls.iter()? {
let item = item?;
proto.value.push(EnumValueDescriptorProto {
number: Some(item.getattr("value")?.extract()?),
name: Some(format!(
"{}_{}",
cls_name,
item.getattr("name")?.extract::<&str>()?
)),
..Default::default()
});
}
file.enum_type.push(proto);
}
}
_ => {}
}
if meta.is_list_field(field_name)? {
field.set_label(Label::Repeated);
} else if field_meta.getattr("optional")?.extract::<bool>()? {
field.proto3_optional = Some(true);
}
}
if let Some(grp) = meta.oneof_group(field_name)? {
let oneof_index = message.oneof_decl.iter().position(|x| x.name() == grp);
match oneof_index {
Some(i) => field.oneof_index = Some(i as i32),
None => {
message.oneof_decl.push(OneofDescriptorProto {
name: Some(grp),
..Default::default()
});
field.oneof_index = Some((message.oneof_decl.len() - 1) as i32)
}
}
}
field
});
}
file.message_type.push(message);
}
Ok(())
}
fn map_type(str: &str) -> Result<Type> {
match str {
"enum" => Ok(Type::Enum),
"bool" => Ok(Type::Bool),
"int32" => Ok(Type::Int32),
"int64" => Ok(Type::Int64),
"uint32" => Ok(Type::Uint32),
"uint64" => Ok(Type::Uint64),
"sint32" => Ok(Type::Sint32),
"sint64" => Ok(Type::Sint64),
"float" => Ok(Type::Float),
"double" => Ok(Type::Double),
"fixed32" => Ok(Type::Fixed32),
"sfixed32" => Ok(Type::Sfixed32),
"fixed64" => Ok(Type::Fixed64),
"sfixed64" => Ok(Type::Sfixed64),
"string" => Ok(Type::String),
"bytes" => Ok(Type::Bytes),
"message" => Ok(Type::Message),
_ => Err(Error::UnsupportedType(str.to_string())),
}
}
fn set_type_name<'py>(
message_name: &str,
field_cls: &'py PyAny,
field: &mut FieldDescriptorProto,
file: &FileDescriptorProto,
messages_to_add: &mut Vec<(String, &'py PyAny)>,
pool: &DescriptorPool,
) -> Result<()> {
let cls_name = field_cls.qualified_name()?;
match cls_name.as_str() {
"datetime.datetime" => {
field.type_name = Some("google.protobuf.Timestamp".to_string());
}
"datetime.timedelta" => {
field.type_name = Some("google.protobuf.Duration".to_string());
}
_ => {
let cls_name = format!("{}_{}", cls_name, field_cls.py_identifier());
field.type_name = Some(cls_name.clone());
if message_name != cls_name
&& pool.get_message_by_name(&cls_name).is_none()
&& !file.message_type.iter().any(|item| item.name() == cls_name)
&& !messages_to_add.iter().any(|item| item.0 == cls_name)
{
messages_to_add.push((cls_name, field_cls.call0()?));
}
}
}
Ok(())
}
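`create_cached_descriptor` and `add_message_to_file` drive everything off betterproto's runtime metadata. A small Python sketch of the attributes the Rust code reads (`_betterproto`, `meta_by_field_name`), with the caveat that this layout is an internal detail of betterproto:
```py
# Inspect the metadata the descriptor builder consumes (internal API).
from dataclasses import dataclass

import betterproto

@dataclass
class Point(betterproto.Message):
    x: int = betterproto.int32_field(1)
    y: int = betterproto.int32_field(2)

meta = Point()._betterproto  # same attribute the Rust side fetches
for name, field_meta in meta.meta_by_field_name.items():
    print(name, field_meta.number, field_meta.proto_type)  # e.g. "x 1 int32"
```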


@ -0,0 +1,29 @@
use prost_reflect::{
prost::DecodeError, prost_types::field_descriptor_proto::Type, DescriptorError,
};
use pyo3::{exceptions::PyRuntimeError, PyErr};
use thiserror::Error;
#[derive(Error, Debug)]
pub enum Error {
#[error("Given object is not a valid betterproto message.")]
NoBetterprotoMessage(#[from] PyErr),
#[error("Unsupported type `{0}`.")]
UnsupportedType(String),
#[error("Unsupported map key type `{0:?}`.")]
UnsupportedMapKeyType(Type),
#[error("Unsupported wrapper type `{0:?}`.")]
UnsupportedWrapperType(Type),
#[error("Error on proto registration")]
FailedToRegisterDescriptor(#[from] DescriptorError),
#[error("The given binary data does not match the protobuf schema.")]
FailedToDecode(#[from] DecodeError),
}
pub type Result<T> = core::result::Result<T, Error>;
impl From<Error> for PyErr {
fn from(value: Error) -> Self {
PyRuntimeError::new_err(value.to_string())
}
}
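Since every variant converts to `PyRuntimeError`, a failed decode surfaces in Python as a plain `RuntimeError`. A sketch, assuming the module builds and using an assumed example message:
```py
# Invalid wire data triggers Error::FailedToDecode, which crosses the
# FFI boundary as RuntimeError (Foo is a stand-in message).
from dataclasses import dataclass

import betterproto
from betterproto_extras import deserialize

@dataclass
class Foo(betterproto.Message):
    x: int = betterproto.int32_field(1)

try:
    deserialize(Foo(), b"\xff")  # tag byte with invalid wire type 7
except RuntimeError as err:
    print(err)  # "The given binary data does not match the protobuf schema."
```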


@ -0,0 +1,24 @@
mod descriptor_pool;
mod error;
mod merging;
mod py_any_extras;
use descriptor_pool::create_cached_descriptor;
use error::Result;
use merging::merge_msg_into_pyobj;
use prost_reflect::DynamicMessage;
use pyo3::prelude::*;
#[pyfunction]
fn deserialize(obj: &PyAny, buf: &[u8]) -> Result<()> {
let desc = create_cached_descriptor(obj)?;
let msg = DynamicMessage::decode(desc, buf)?;
merge_msg_into_pyobj(obj, msg)?;
Ok(())
}
#[pymodule]
fn betterproto_extras(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(deserialize, m)?)?;
Ok(())
}
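Note that the exposed `deserialize` mutates its first argument and returns `None`, so repeated calls keep merging into the same object. A sketch under the same assumptions as above:
```py
# Repeated calls merge into the same object; a later scalar overwrites.
from dataclasses import dataclass

import betterproto
from betterproto_extras import deserialize

@dataclass
class Counter(betterproto.Message):
    n: int = betterproto.int32_field(1)

msg = Counter()
deserialize(msg, bytes(Counter(n=7)))
deserialize(msg, bytes(Counter(n=9)))
assert msg.n == 9
```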


@ -0,0 +1,182 @@
use crate::{error::Result, py_any_extras::PyAnyExtras};
use indoc::indoc;
use prost_reflect::{
prost_types::{Duration, Timestamp},
DynamicMessage, MapKey, ReflectMessage, Value,
};
use pyo3::{
sync::GILOnceCell,
types::{IntoPyDict, PyBytes, PyModule},
Py, PyAny, PyObject, Python, ToPyObject,
};
pub fn merge_msg_into_pyobj(obj: &PyAny, mut msg: DynamicMessage) -> Result<()> {
for field in msg.take_fields() {
let field_name = field.0.name();
let proto_meta = obj.get_proto_meta()?;
obj.setattr(
field_name,
map_field_value(field_name, field.1, proto_meta)?,
)?;
}
let mut buf = vec![];
for field in msg.unknown_fields() {
field.encode(&mut buf);
}
if !buf.is_empty() {
let mut unknown_fields = obj.getattr("_unknown_fields")?.extract::<Vec<u8>>()?;
unknown_fields.append(&mut buf);
obj.setattr("_unknown_fields", PyBytes::new(obj.py(), &unknown_fields))?;
}
obj.setattr("_serialized_on_wire", true)?;
Ok(())
}
fn map_field_value(field_name: &str, field_value: Value, proto_meta: &PyAny) -> Result<PyObject> {
let py = proto_meta.py();
match field_value {
Value::Bool(x) => Ok(x.to_object(py)),
Value::Bytes(x) => Ok(PyBytes::new(py, &x).to_object(py)),
Value::F32(x) => Ok(x.to_object(py)),
Value::F64(x) => Ok(x.to_object(py)),
Value::I32(x) => Ok(x.to_object(py)),
Value::I64(x) => Ok(x.to_object(py)),
Value::String(x) => Ok(x.to_object(py)),
Value::U32(x) => Ok(x.to_object(py)),
Value::U64(x) => Ok(x.to_object(py)),
Value::Message(msg) => match msg.descriptor().full_name() {
"google.protobuf.BoolValue" => Ok(msg
.get_field_by_number(1)
.and_then(|val| val.as_bool())
.to_object(py)),
"google.protobuf.DoubleValue" => Ok(msg
.get_field_by_number(1)
.and_then(|val| val.as_f64())
.to_object(py)),
"google.protobuf.FloatValue" => Ok(msg
.get_field_by_number(1)
.and_then(|val| val.as_f32())
.to_object(py)),
"google.protobuf.Int64Value" => Ok(msg
.get_field_by_number(1)
.and_then(|val| val.as_i64())
.to_object(py)),
"google.protobuf.UInt64Value" => Ok(msg
.get_field_by_number(1)
.and_then(|val| val.as_u64())
.to_object(py)),
"google.protobuf.Int32Value" => Ok(msg
.get_field_by_number(1)
.and_then(|val| val.as_i32())
.to_object(py)),
"google.protobuf.UInt32Value" => Ok(msg
.get_field_by_number(1)
.and_then(|val| val.as_u32())
.to_object(py)),
"google.protobuf.StringValue" => Ok(msg
.get_field_by_number(1)
.and_then(|val| val.as_str().map(|s| s.to_string()))
.to_object(py)),
"google.protobuf.BytesValue" => Ok(msg
.get_field_by_number(1)
.and_then(|val| val.as_bytes().map(|b| PyBytes::new(py, b)))
.to_object(py)),
"google.protobuf.Timestamp" => {
let msg = msg.transcode_to::<Timestamp>()?;
Ok(create_py_datetime(&msg, py))
}
"google.protobuf.Duration" => {
let msg = msg.transcode_to::<Duration>()?;
Ok(create_py_timedelta(&msg, py))
}
_ => {
let obj = proto_meta.create_instance(field_name)?;
merge_msg_into_pyobj(obj, msg)?;
Ok(obj.to_object(py))
}
},
Value::List(ls) => Ok(ls
.into_iter()
.map(|x| map_field_value(field_name, x, proto_meta))
.collect::<Result<Vec<PyObject>>>()?
.to_object(py)),
Value::EnumNumber(x) => {
let cls = proto_meta.get_class(field_name)?;
Ok(cls.call1((x,))?.to_object(py))
}
Value::Map(map) => {
let res: Result<Vec<_>> = map
.into_iter()
.map(|(k, v)| {
let key = map_key(k, py);
let val = map_field_value(&format!("{field_name}.value"), v, proto_meta)?;
Ok((key, val))
})
.collect();
Ok(res?.into_py_dict(py).to_object(py))
}
}
}
fn map_key(key: MapKey, py: Python) -> PyObject {
match key {
MapKey::Bool(x) => x.to_object(py),
MapKey::I32(x) => x.to_object(py),
MapKey::I64(x) => x.to_object(py),
MapKey::U32(x) => x.to_object(py),
MapKey::U64(x) => x.to_object(py),
MapKey::String(x) => x.to_object(py),
}
}
fn create_py_datetime(ts: &Timestamp, py: Python) -> PyObject {
static CONSTRUCTOR_CACHE: GILOnceCell<Py<PyAny>> = GILOnceCell::new();
let constructor = CONSTRUCTOR_CACHE.get_or_init(py, || {
let constructor = PyModule::from_code(
py,
indoc! {"
from datetime import datetime, timezone
def constructor(ts):
return datetime.fromtimestamp(ts, tz=timezone.utc)
"},
"",
"",
)
.expect("This is a valid Python module")
.getattr("constructor")
.expect("Attribute exists");
Py::from(constructor)
});
let ts = (ts.seconds as f64) + (ts.nanos as f64) / 1e9;
constructor
.call1(py, (ts,))
.expect("static function will not fail")
}
fn create_py_timedelta(duration: &Duration, py: Python) -> PyObject {
static CONSTRUCTOR_CACHE: GILOnceCell<Py<PyAny>> = GILOnceCell::new();
let constructor = CONSTRUCTOR_CACHE.get_or_init(py, || {
let constructor = PyModule::from_code(
py,
indoc! {"
from datetime import timedelta
def constructor(s, ms):
return timedelta(seconds=s, microseconds=ms)
"},
"",
"",
)
.expect("This is a valid Python module")
.getattr("constructor")
.expect("Attribute exists");
Py::from(constructor)
});
constructor
.call1(py, (duration.seconds as f64, (duration.nanos as f64) / 1e3))
.expect("static function will not fail")
}
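The two cached constructors above amount to the following plain Python: fractional seconds feed `datetime.fromtimestamp`, while durations are split into seconds plus microseconds.
```py
from datetime import datetime, timedelta, timezone

# Equivalent of create_py_datetime: Timestamp{seconds, nanos} -> datetime
def from_timestamp(seconds: int, nanos: int) -> datetime:
    return datetime.fromtimestamp(seconds + nanos / 1e9, tz=timezone.utc)

# Equivalent of create_py_timedelta: Duration{seconds, nanos} -> timedelta
def from_duration(seconds: int, nanos: int) -> timedelta:
    return timedelta(seconds=seconds, microseconds=nanos / 1e3)

assert from_timestamp(0, 0) == datetime(1970, 1, 1, tzinfo=timezone.utc)
assert from_duration(3, 500_000_000) == timedelta(seconds=3.5)
```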


@ -0,0 +1,68 @@
use crate::error::Result;
use pyo3::{PyAny, Py, sync::GILOnceCell};
pub trait PyAnyExtras {
fn qualified_name(&self) -> Result<String>;
fn qualified_class_name(&self) -> Result<String>;
fn get_proto_meta(&self) -> Result<&PyAny>;
fn get_class(&self, field_name: &str) -> Result<&PyAny>;
fn create_instance(&self, field_name: &str) -> Result<&PyAny>;
fn is_list_field(&self, field_name: &str) -> Result<bool>;
fn oneof_group(&self, field_name: &str) -> Result<Option<String>>;
fn py_identifier(&self) -> u64;
}
impl PyAnyExtras for PyAny {
fn qualified_name(&self) -> Result<String> {
let module = self.getattr("__module__")?;
let name = self.getattr("__name__")?;
Ok(format!("{module}.{name}"))
}
fn qualified_class_name(&self) -> Result<String> {
self.getattr("__class__")?.qualified_name()
}
fn get_proto_meta(&self) -> Result<&PyAny> {
Ok(self.getattr("_betterproto")?)
}
fn get_class(&self, field_name: &str) -> Result<&PyAny> {
let cls = self.getattr("cls_by_field")?.get_item(field_name)?;
Ok(cls)
}
fn create_instance(&self, field_name: &str) -> Result<&PyAny> {
Ok(self.get_class(field_name)?.call0()?)
}
fn is_list_field(&self, field_name: &str) -> Result<bool> {
let cls = self.getattr("default_gen")?.get_item(field_name)?;
let module = cls.getattr("__module__")?;
let name = cls.getattr("__name__")?;
Ok(module.to_string() == "builtins" && name.to_string() == "list")
}
fn oneof_group(&self, field_name: &str) -> Result<Option<String>> {
let opt = self
.getattr("oneof_group_by_field")?
.call_method1("get", (field_name,))?
.extract()?;
Ok(opt)
}
fn py_identifier(&self) -> u64 {
static FUN_CACHE: GILOnceCell<Py<PyAny>> = GILOnceCell::new();
let py = self.py();
let fun = FUN_CACHE.get_or_init(py, || {
let fun = py
.eval("id", None, None)
.expect("This is a valid Python expression");
Py::from(fun)
});
fun.call1(py, (self,))
.expect("Identity function is callable")
.extract::<u64>(py)
.expect("Identity function always returns an integer")
}
}
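`py_identifier` is just Python's built-in `id()` behind a cached lookup; together with the qualified name it distinguishes dynamically created classes that share a name (the motivation given in commit 8f535913a1). In plain Python:
```py
# Two dynamically created classes share a name but not an identity,
# so "{qualified_name}_{id}" descriptor names stay unique.
A = type("Msg", (), {})
B = type("Msg", (), {})

assert A.__name__ == B.__name__
assert id(A) != id(B)
print(f"{A.__module__}.{A.__name__}_{id(A)}")
print(f"{B.__module__}.{B.__name__}_{id(B)}")
```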


@ -85,19 +85,17 @@ wrappers used to provide optional zero value support. Each of these has a specia
representation and is handled a little differently from normal messages. The Python
mapping for these is as follows:
+-------------------------------+-------------------------------------------------+--------------------------+
| ``Google Message`` | ``Python Type`` | ``Default`` |
+===============================+=================================================+==========================+
| ``google.protobuf.duration`` | :class:`datetime.timedelta` | ``0`` |
+-------------------------------+-------------------------------------------------+--------------------------+
| ``google.protobuf.timestamp`` | ``Timezone-aware`` :class:`datetime.datetime` | ``1970-01-01T00:00:00Z`` |
+-------------------------------+-------------------------------------------------+--------------------------+
| ``google.protobuf.*Value`` | ``Optional[...]``/``None`` | ``None`` |
+-------------------------------+-------------------------------------------------+--------------------------+
| ``google.protobuf.*`` | ``betterproto.lib.std.google.protobuf.*`` | ``None`` |
+-------------------------------+-------------------------------------------------+--------------------------+
| ``google.protobuf.*`` | ``betterproto.lib.pydantic.google.protobuf.*`` | ``None`` |
+-------------------------------+-------------------------------------------------+--------------------------+
+-------------------------------+-----------------------------------------------+--------------------------+
| ``Google Message`` | ``Python Type`` | ``Default`` |
+===============================+===============================================+==========================+
| ``google.protobuf.duration`` | :class:`datetime.timedelta` | ``0`` |
+-------------------------------+-----------------------------------------------+--------------------------+
| ``google.protobuf.timestamp`` | ``Timezone-aware`` :class:`datetime.datetime` | ``1970-01-01T00:00:00Z`` |
+-------------------------------+-----------------------------------------------+--------------------------+
| ``google.protobuf.*Value`` | ``Optional[...]``/``None`` | ``None`` |
+-------------------------------+-----------------------------------------------+--------------------------+
| ``google.protobuf.*`` | ``betterproto.lib.google.protobuf.*`` | ``None`` |
+-------------------------------+-----------------------------------------------+--------------------------+
For the wrapper types, the Python type corresponds to the wrapped type, e.g.
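A short sketch of the wrapper-type row in the table, assuming a betterproto message with a wrapped bool field (the `wraps` parameter is used the same way in `example.py` below):
```py
# Wrapper fields are Optional: unset means None rather than False.
from dataclasses import dataclass
from typing import Optional

import betterproto

@dataclass
class Settings(betterproto.Message):
    # google.protobuf.BoolValue on the wire, Optional[bool] in Python
    maybe: Optional[bool] = betterproto.message_field(1, wraps=betterproto.TYPE_BOOL)

s = Settings()
assert s.maybe is None
s.maybe = True
assert Settings().parse(bytes(s)).maybe is True
```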

example.py Normal file

@ -0,0 +1,55 @@
# dev tests
# to be deleted later
import betterproto
from dataclasses import dataclass
from typing import Dict, List, Optional
@dataclass(repr=False)
class Baz(betterproto.Message):
a: float = betterproto.float_field(1, group = "x")
b: int = betterproto.int64_field(2, group = "x")
c: float = betterproto.float_field(3, group = "y")
d: int = betterproto.int64_field(4, group = "y")
e: Optional[int] = betterproto.int32_field(5, group = "_e", optional = True)
@dataclass(repr=False)
class Foo(betterproto.Message):
x: int = betterproto.int32_field(1)
y: float = betterproto.double_field(2)
z: List[Baz] = betterproto.message_field(3)
class Enm(betterproto.Enum):
A = 0
B = 1
C = 2
@dataclass(repr=False)
class Bar(betterproto.Message):
foo1: Foo = betterproto.message_field(1)
foo2: Foo = betterproto.message_field(2)
packed: List[int] = betterproto.int64_field(3)
enm: Enm = betterproto.enum_field(4)
map: Dict[int, bool] = betterproto.map_field(5, betterproto.TYPE_INT64, betterproto.TYPE_BOOL)
maybe: Optional[bool] = betterproto.message_field(6, wraps=betterproto.TYPE_BOOL)
bts: bytes = betterproto.bytes_field(7)
# Serialization has not been changed yet, so nothing unusual here
buffer = bytes(
Bar(
foo1=Foo(1, 2.34),
foo2=Foo(3, 4.56, [Baz(a = 1.234), Baz(b = 5, e=1), Baz(b = 2, d = 3)]),
packed=[5, 3, 1],
enm=Enm.B,
map={
1: True,
42: False
},
maybe=True,
bts=b'Hi There!'
)
)
# Native deserialization happening here
bar = Bar().parse(buffer)
print(bar)
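A natural follow-up check for the example, assuming this particular message round-trips losslessly: re-serializing the natively parsed `bar` should reproduce the original buffer.
```py
# Round-trip sanity check (assumes lossless round-tripping here).
assert bytes(bar) == buffer
```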

poetry.lock generated

File diff suppressed because it is too large


@ -1,10 +1,8 @@
[project]
[tool.poetry]
name = "betterproto"
version = "2.0.0b7"
version = "2.0.0b6"
description = "A better Protobuf / gRPC generator & library"
authors = [
{name = "Daniel G. Taylor", email = "danielgtaylor@gmail.com"}
]
authors = ["Daniel G. Taylor <danielgtaylor@gmail.com>"]
readme = "README.md"
repository = "https://github.com/danielgtaylor/python-betterproto"
keywords = ["protobuf", "gRPC"]
@ -12,54 +10,43 @@ license = "MIT"
packages = [
{ include = "betterproto", from = "src" }
]
requires-python = ">=3.9,<4.0"
dynamic = ["dependencies"]
[tool.poetry.dependencies]
# The Ruff version is pinned. To update it, also update it in .pre-commit-config.yaml
ruff = { version = "~0.9.1", optional = true }
python = "^3.7"
black = { version = ">=23.1.0", optional = true }
grpclib = "^0.4.1"
importlib-metadata = { version = ">=1.6.0", python = "<3.8" }
jinja2 = { version = ">=3.0.3", optional = true }
python-dateutil = "^2.8"
typing-extensions = "^4.7.1"
betterproto-rust-codec = { version = "0.1.1", optional = true }
isort = {version = "^5.11.5", optional = true}
betterproto-extras = { path = "betterproto-extras" }
[tool.poetry.group.dev.dependencies]
asv = "^0.6.4"
bpython = "^0.24"
jinja2 = ">=3.0.3"
mypy = "^1.11.2"
sphinx = "7.4.7"
sphinx-rtd-theme = "3.0.2"
pre-commit = "^4.0.1"
[tool.poetry.dev-dependencies]
asv = "^0.4.2"
bpython = "^0.19"
grpcio-tools = "^1.54.2"
tox = "^4.0.0"
[tool.poetry.group.test.dependencies]
jinja2 = ">=3.0.3"
mypy = "^0.930"
poethepoet = ">=0.9.0"
pytest = "^7.4.4"
pytest-asyncio = "^0.23.8"
pytest-cov = "^6.0.0"
protobuf = "^4.21.6"
pytest = "^6.2.5"
pytest-asyncio = "^0.12.0"
pytest-cov = "^2.9.0"
pytest-mock = "^3.1.1"
pydantic = ">=2.0,<3"
protobuf = "^5"
cachelib = "^0.13.0"
tomlkit = ">=0.7.0"
sphinx = "3.1.2"
sphinx-rtd-theme = "0.5.0"
tomlkit = "^0.7.0"
tox = "^3.15.1"
pre-commit = "^2.17.0"
pydantic = ">=1.8.0"
[project.scripts]
[tool.poetry.scripts]
protoc-gen-python_betterproto = "betterproto.plugin:main"
[project.optional-dependencies]
compiler = ["ruff", "jinja2"]
rust-codec = ["betterproto-rust-codec"]
[tool.poetry.extras]
compiler = ["black", "isort", "jinja2"]
[tool.ruff]
extend-exclude = ["tests/output_*"]
target-version = "py38"
[tool.ruff.lint.isort]
combine-as-imports = true
lines-after-imports = 2
# Dev workflow tasks
@ -76,28 +63,8 @@ cmd = "mypy src --ignore-missing-imports"
help = "Check types with mypy"
[tool.poe.tasks.format]
sequence = ["_format", "_sort-imports"]
help = "Format the source code, and sort the imports"
[tool.poe.tasks.check]
sequence = ["_check-format", "_check-imports"]
help = "Check that the source code is formatted and the imports sorted"
[tool.poe.tasks._format]
cmd = "ruff format src tests"
help = "Format the source code without sorting the imports"
[tool.poe.tasks._sort-imports]
cmd = "ruff check --select I --fix src tests"
help = "Sort the imports"
[tool.poe.tasks._check-format]
cmd = "ruff format --diff src tests"
help = "Check that the source code is formatted"
[tool.poe.tasks._check-imports]
cmd = "ruff check --select I src tests"
help = "Check that the imports are sorted"
cmd = "black . --exclude tests/output_ --target-version py310"
help = "Apply black formatting to source code"
[tool.poe.tasks.docs]
cmd = "sphinx-build docs docs/build"
@ -120,11 +87,11 @@ cmd = """
protoc
--plugin=protoc-gen-custom=src/betterproto/plugin/main.py
--custom_opt=INCLUDE_GOOGLE
--custom_out=src/betterproto/lib/std
--custom_out=src/betterproto/lib
-I C:\\work\\include
C:\\work\\include\\google\\protobuf\\**\\*.proto
"""
help = "Regenerate the types in betterproto.lib.std.google"
help = "Regenerate the types in betterproto.lib.google"
# CI tasks
@ -132,6 +99,23 @@ help = "Regenerate the types in betterproto.lib.std.google"
shell = "poe generate && tox"
help = "Run tests with multiple pythons"
[tool.poe.tasks.check-style]
cmd = "black . --check --diff"
help = "Check if code style is correct"
[tool.isort]
py_version = 37
profile = "black"
force_single_line = false
combine_as_imports = true
lines_after_imports = 2
include_trailing_comma = true
force_grid_wrap = 2
src_paths = ["src", "tests"]
[tool.black]
target-version = ['py37']
[tool.doc8]
paths = ["docs"]
max_line_length = 88
@ -147,23 +131,16 @@ omit = ["betterproto/tests/*"]
[tool.tox]
legacy_tox_ini = """
[tox]
requires =
tox>=4.2
tox-poetry-installer[poetry]==1.0.0b1
env_list =
py311
py38
py37
isolated_build = true
envlist = py37, py38, py310
[testenv]
whitelist_externals = poetry
commands =
pytest {posargs: --cov betterproto}
poetry_dep_groups =
test
require_locked_deps = true
require_poetry = true
poetry install -v --extras compiler
poetry run pytest --cov betterproto
"""
[build-system]
requires = ["poetry-core>=2.0.0,<3"]
requires = ["poetry-core>=1.0.0,<2"]
build-backend = "poetry.core.masonry.api"

View File

@ -1,7 +1,5 @@
from __future__ import annotations
import dataclasses
import enum as builtin_enum
import enum
import json
import math
import struct
@ -24,8 +22,8 @@ from itertools import count
from typing import (
TYPE_CHECKING,
Any,
BinaryIO,
Callable,
ClassVar,
Dict,
Generator,
Iterable,
@ -39,7 +37,6 @@ from typing import (
)
from dateutil.parser import isoparse
from typing_extensions import Self
from ._types import T
from ._version import __version__
@ -48,25 +45,11 @@ from .casing import (
safe_snake_case,
snake_case,
)
from .enum import Enum as Enum
from .grpc.grpclib_client import ServiceStub as ServiceStub
from .utils import (
classproperty,
hybridmethod,
)
from .grpc.grpclib_client import ServiceStub
if TYPE_CHECKING:
from _typeshed import (
SupportsRead,
SupportsWrite,
)
if sys.version_info >= (3, 10):
from types import UnionType as _types_UnionType
else:
class _types_UnionType: ...
from _typeshed import ReadableBuffer
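For context, this version gate matters because PEP 604 unions written as `int | None` evaluate to `types.UnionType` on Python 3.10+, which the default-value resolution further down has to recognize alongside `typing.Union`. A minimal check of that assumption:
import sys
import types
if sys.version_info >= (3, 10):
    # `int | None` builds a types.UnionType at runtime, not a typing.Union
    assert isinstance(int | None, types.UnionType)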
# Proto 3 data types
@ -143,9 +126,6 @@ WIRE_FIXED_32_TYPES = [TYPE_FLOAT, TYPE_FIXED32, TYPE_SFIXED32]
WIRE_FIXED_64_TYPES = [TYPE_DOUBLE, TYPE_FIXED64, TYPE_SFIXED64]
WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE, TYPE_MAP]
# Indicator of message delimitation in streams
SIZE_DELIMITED = -1
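A sketch of how this sentinel plugs into the stream APIs below, assuming a hypothetical generated message `Greeting` with a string field `text`:
import io
stream = io.BytesIO()
Greeting(text="hi").dump(stream, delimit=SIZE_DELIMITED)  # varint length prefix, then body
stream.seek(0)
msg = Greeting().load(stream, size=SIZE_DELIMITED)        # reads the prefix to size the body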
# Protobuf datetimes start at the Unix Epoch in 1970 in UTC.
def datetime_default_gen() -> datetime:
@ -154,36 +134,20 @@ def datetime_default_gen() -> datetime:
DATETIME_ZERO = datetime_default_gen()
# Special protobuf json doubles
INFINITY = "Infinity"
NEG_INFINITY = "-Infinity"
NAN = "NaN"
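These constants mirror the proto3 JSON mapping, where non-finite doubles are encoded as strings rather than numbers. A minimal sketch of the rule (the helper name is illustrative):
import math
def encode_json_double(value: float):
    # Non-finite doubles become the string constants above; finite ones stay numeric.
    if math.isnan(value):
        return NAN
    if math.isinf(value):
        return INFINITY if value > 0 else NEG_INFINITY
    return value
assert encode_json_double(float("inf")) == "Infinity"
assert encode_json_double(1.5) == 1.5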
class Casing(builtin_enum.Enum):
class Casing(enum.Enum):
"""Casing constants for serialization."""
CAMEL = camel_case #: A camelCase serialization function.
SNAKE = snake_case #: A snake_case serialization function.
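The two casing helpers referenced here convert in opposite directions, e.g.:
from betterproto.casing import camel_case, snake_case
assert camel_case("my_field_name") == "myFieldName"
assert snake_case("myFieldName") == "my_field_name"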
class Placeholder:
__slots__ = ()
def __repr__(self) -> str:
return "<PLACEHOLDER>"
def __copy__(self) -> Self:
return self
def __deepcopy__(self, _) -> Self:
return self
# We can't simply use object() here because pydantic automatically performs deep-copy of mutable default values
# See #606
PLACEHOLDER: Any = Placeholder()
PLACEHOLDER: Any = object()
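The copy hooks are the point of this change: both shallow and deep copies must hand back the same sentinel so identity checks such as `value is PLACEHOLDER` keep working after pydantic copies field defaults. A quick check, assuming the `Placeholder` class above:
from copy import copy, deepcopy
sentinel = Placeholder()
assert copy(sentinel) is sentinel
assert deepcopy(sentinel) is sentinel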
@dataclasses.dataclass(frozen=True)
@ -220,7 +184,7 @@ def dataclass_field(
) -> dataclasses.Field:
"""Creates a dataclass field with attached protobuf metadata."""
return dataclasses.field(
default=None if optional else PLACEHOLDER, # type: ignore
default=None if optional else PLACEHOLDER,
metadata={
"betterproto": FieldMetadata(
number, proto_type, map_types, group, wraps, optional
@ -345,6 +309,32 @@ def map_field(
)
class Enum(enum.IntEnum):
"""
The base class for protobuf enumerations, all generated enumerations will inherit
from this. Bases :class:`enum.IntEnum`.
"""
@classmethod
def from_string(cls, name: str) -> "Enum":
"""Return the value which corresponds to the string name.
Parameters
-----------
name: :class:`str`
The name of the enum member to get.
Raises
-------
:exc:`ValueError`
The member was not found in the Enum.
"""
try:
return cls._member_map_[name] # type: ignore
except KeyError as e:
raise ValueError(f"Unknown value {name} for enum {cls.__name__}") from e
def _pack_fmt(proto_type: str) -> str:
"""Returns a little-endian format string for reading/writing binary."""
return {
@ -357,7 +347,7 @@ def _pack_fmt(proto_type: str) -> str:
}[proto_type]
def dump_varint(value: int, stream: "SupportsWrite[bytes]") -> None:
def dump_varint(value: int, stream: BinaryIO) -> None:
"""Encodes a single varint and dumps it into the provided stream."""
if value < -(1 << 63):
raise ValueError(
@ -566,7 +556,7 @@ def _dump_float(value: float) -> Union[float, str]:
return value
def load_varint(stream: "SupportsRead[bytes]") -> Tuple[int, bytes]:
def load_varint(stream: BinaryIO) -> Tuple[int, bytes]:
"""
Load a single varint value from a stream. Returns the value and the raw bytes read.
"""
@ -604,7 +594,7 @@ class ParsedField:
raw: bytes
def load_fields(stream: "SupportsRead[bytes]") -> Generator[ParsedField, None, None]:
def load_fields(stream: BinaryIO) -> Generator[ParsedField, None, None]:
while True:
try:
num_wire, raw = load_varint(stream)
@ -758,7 +748,6 @@ class Message(ABC):
_serialized_on_wire: bool
_unknown_fields: bytes
_group_current: Dict[str, str]
_betterproto_meta: ClassVar[ProtoClassMetadata]
def __post_init__(self) -> None:
# Keep track of whether every field was default
@ -771,7 +760,7 @@ class Message(ABC):
group_current.setdefault(meta.group)
value = self.__raw_get(field_name)
if value is not PLACEHOLDER and not (meta.optional and value is None):
if value != PLACEHOLDER and not (meta.optional and value is None):
# Found a non-sentinel value
all_sentinel = False
@ -826,10 +815,6 @@ class Message(ABC):
]
return f"{self.__class__.__name__}({', '.join(parts)})"
def __rich_repr__(self) -> Iterable[Tuple[str, Any, Any]]:
for field_name in self._betterproto.sorted_field_names:
yield field_name, self.__raw_get(field_name), PLACEHOLDER
if not TYPE_CHECKING:
def __getattribute__(self, name: str) -> Any:
@ -904,28 +889,20 @@ class Message(ABC):
kwargs[name] = deepcopy(value)
return self.__class__(**kwargs) # type: ignore
def __copy__(self: T, _: Any = {}) -> T:
kwargs = {}
for name in self._betterproto.sorted_field_names:
value = self.__raw_get(name)
if value is not PLACEHOLDER:
kwargs[name] = value
return self.__class__(**kwargs) # type: ignore
@classproperty
def _betterproto(cls: type[Self]) -> ProtoClassMetadata: # type: ignore
@property
def _betterproto(self) -> ProtoClassMetadata:
"""
Lazy initialize metadata for each protobuf class.
It may be initialized multiple times in a multi-threaded environment,
but that won't affect the correctness.
"""
try:
return cls._betterproto_meta
except AttributeError:
cls._betterproto_meta = meta = ProtoClassMetadata(cls)
return meta
meta = getattr(self.__class__, "_betterproto_meta", None)
if not meta:
meta = ProtoClassMetadata(self.__class__)
self.__class__._betterproto_meta = meta # type: ignore
return meta
def dump(self, stream: "SupportsWrite[bytes]", delimit: bool = False) -> None:
def dump(self, stream: BinaryIO) -> None:
"""
Dumps the binary encoded Protobuf message to the stream.
@ -933,11 +910,7 @@ class Message(ABC):
-----------
stream: :class:`BinaryIO`
The stream to dump the message to.
delimit:
Whether to prefix the message with a varint declaring its size.
"""
if delimit == SIZE_DELIMITED:
dump_varint(len(self), stream)
for field_name, meta in self._betterproto.meta_by_field_name.items():
try:
@ -957,7 +930,7 @@ class Message(ABC):
# Note that proto3 field presence/optional fields are put in a
# synthetic single-item oneof by protoc, which helps us ensure we
# send the value even if the value is the default zero value.
selected_in_group = bool(meta.group) or meta.optional
selected_in_group = bool(meta.group)
# Empty messages can still be sent on the wire if they were
# set (or received empty).
@ -1151,15 +1124,6 @@ class Message(ABC):
"""
return bytes(self)
def __getstate__(self) -> bytes:
return bytes(self)
def __setstate__(self: T, pickled_bytes: bytes) -> T:
return self.parse(pickled_bytes)
def __reduce__(self) -> Tuple[Any, ...]:
return (self.__class__.FromString, (bytes(self),))
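These hooks make messages picklable through their compact wire form instead of their placeholder-laden `__dict__`. A sketch with a hypothetical generated message:
import pickle
msg = Greeting(text="hi")                      # hypothetical generated message
assert pickle.loads(pickle.dumps(msg)) == msg  # round-trips via FromString(bytes(msg))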
@classmethod
def _type_hint(cls, field_name: str) -> Type:
return cls._type_hints()[field_name]
@ -1188,29 +1152,30 @@ class Message(ABC):
def _get_field_default_gen(cls, field: dataclasses.Field) -> Any:
t = cls._type_hint(field.name)
is_310_union = isinstance(t, _types_UnionType)
if hasattr(t, "__origin__") or is_310_union:
if is_310_union or t.__origin__ is Union:
if hasattr(t, "__origin__"):
if t.__origin__ is dict:
# This is some kind of map (dict in Python).
return dict
elif t.__origin__ is list:
# This is some kind of list (repeated) field.
return list
elif t.__origin__ is Union and t.__args__[1] is type(None):
# This is an optional field (either wrapped, or using proto3
# field presence). For setting the default we really don't care
# what kind of field it is.
return type(None)
if t.__origin__ is list:
# This is some kind of list (repeated) field.
return list
if t.__origin__ is dict:
# This is some kind of map (dict in Python).
return dict
return t
if issubclass(t, Enum):
else:
return t
elif issubclass(t, Enum):
# Enums always default to zero.
return t.try_value
if t is datetime:
return int
elif t is datetime:
# Offsets are relative to 1970-01-01T00:00:00Z
return datetime_default_gen
# This is either a primitive scalar or another message type. Calling
# it should result in its zero value.
return t
else:
# This is either a primitive scalar or another message type. Calling
# it should result in its zero value.
return t
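The dispatch above leans on `typing` introspection; for reference, the shapes it distinguishes look like this:
from typing import Dict, List, Optional, Union
assert List[int].__origin__ is list        # repeated field -> defaults to list()
assert Dict[str, int].__origin__ is dict   # map field -> defaults to dict()
t = Optional[int]
assert t.__origin__ is Union and t.__args__[1] is type(None)  # proto3 optional -> None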
def _postprocess_single(
self, wire_type: int, meta: FieldMetadata, field_name: str, value: Any
@ -1228,9 +1193,6 @@ class Message(ABC):
elif meta.proto_type == TYPE_BOOL:
# Booleans use a varint encoding, so convert it to true/false.
value = value > 0
elif meta.proto_type == TYPE_ENUM:
# Convert enum ints to python enum instances
value = self._betterproto.cls_by_field[field_name].try_value(value)
elif wire_type in (WIRE_FIXED_32, WIRE_FIXED_64):
fmt = _pack_fmt(meta.proto_type)
value = struct.unpack(fmt, value)[0]
@ -1263,11 +1225,7 @@ class Message(ABC):
meta.group is not None and self._group_current.get(meta.group) == field_name
)
def load(
self: T,
stream: "SupportsRead[bytes]",
size: Optional[int] = None,
) -> T:
def load(self: T, stream: BinaryIO, size: Optional[int] = None) -> T:
"""
Load the binary encoded Protobuf from a stream into this message instance. This
returns the instance itself and is therefore assignable and chainable.
@ -1279,17 +1237,12 @@ class Message(ABC):
size: :class:`Optional[int]`
The size of the message in the stream.
Reads stream until EOF if ``None`` is given.
Reads based on a size delimiter prefix varint if SIZE_DELIMITED is given.
Returns
--------
:class:`Message`
The initialized message.
"""
# If the message is delimited, parse the message delimiter
if size == SIZE_DELIMITED:
size, _ = load_varint(stream)
# Got some data over the wire
self._serialized_on_wire = True
proto_meta = self._betterproto
@ -1362,7 +1315,7 @@ class Message(ABC):
return self
def parse(self: T, data: bytes) -> T:
def parse(self: T, data: "ReadableBuffer") -> T:
"""
Parse the binary encoded Protobuf into this message instance. This
returns the instance itself and is therefore assignable and chainable.
@ -1377,6 +1330,13 @@ class Message(ABC):
:class:`Message`
The initialized message.
"""
if True:
# TODO: Make native deserialization optional
import betterproto_extras
betterproto_extras.deserialize(self, data)
return self
with BytesIO(data) as stream:
return self.load(stream)
@ -1541,91 +1501,7 @@ class Message(ABC):
output[cased_name] = value
return output
@classmethod
def _from_dict_init(cls, mapping: Mapping[str, Any]) -> Mapping[str, Any]:
init_kwargs: Dict[str, Any] = {}
for key, value in mapping.items():
field_name = safe_snake_case(key)
try:
meta = cls._betterproto.meta_by_field_name[field_name]
except KeyError:
continue
if value is None:
continue
if meta.proto_type == TYPE_MESSAGE:
sub_cls = cls._betterproto.cls_by_field[field_name]
if sub_cls == datetime:
value = (
[isoparse(item) for item in value]
if isinstance(value, list)
else isoparse(value)
)
elif sub_cls == timedelta:
value = (
[timedelta(seconds=float(item[:-1])) for item in value]
if isinstance(value, list)
else timedelta(seconds=float(value[:-1]))
)
elif not meta.wraps:
value = (
[sub_cls.from_dict(item) for item in value]
if isinstance(value, list)
else sub_cls.from_dict(value)
)
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
sub_cls = cls._betterproto.cls_by_field[f"{field_name}.value"]
value = {k: sub_cls.from_dict(v) for k, v in value.items()}
else:
if meta.proto_type in INT_64_TYPES:
value = (
[int(n) for n in value]
if isinstance(value, list)
else int(value)
)
elif meta.proto_type == TYPE_BYTES:
value = (
[b64decode(n) for n in value]
if isinstance(value, list)
else b64decode(value)
)
elif meta.proto_type == TYPE_ENUM:
enum_cls = cls._betterproto.cls_by_field[field_name]
if isinstance(value, list):
value = [enum_cls.from_string(e) for e in value]
elif isinstance(value, str):
value = enum_cls.from_string(value)
elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE):
value = (
[_parse_float(n) for n in value]
if isinstance(value, list)
else _parse_float(value)
)
init_kwargs[field_name] = value
return init_kwargs
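One detail worth noting: proto3 JSON encodes durations as strings with a trailing "s", so the branches above strip the suffix and convert the remainder to seconds:
from datetime import timedelta
raw = "3.5s"  # a proto3 JSON duration
assert timedelta(seconds=float(raw[:-1])) == timedelta(seconds=3.5)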
@hybridmethod
def from_dict(cls: type[Self], value: Mapping[str, Any]) -> Self: # type: ignore
"""
Parse the key/value pairs into a new message instance.
Parameters
-----------
value: Dict[:class:`str`, Any]
The dictionary to parse from.
Returns
--------
:class:`Message`
The initialized message.
"""
self = cls(**cls._from_dict_init(value))
self._serialized_on_wire = True
return self
@from_dict.instancemethod
def from_dict(self, value: Mapping[str, Any]) -> Self:
def from_dict(self: T, value: Mapping[str, Any]) -> T:
"""
Parse the key/value pairs into the current message instance. This returns the
instance itself and is therefore assignable and chainable.
@ -1641,8 +1517,71 @@ class Message(ABC):
The initialized message.
"""
self._serialized_on_wire = True
for field, value in self._from_dict_init(value).items():
setattr(self, field, value)
for key in value:
field_name = safe_snake_case(key)
meta = self._betterproto.meta_by_field_name.get(field_name)
if not meta:
continue
if value[key] is not None:
if meta.proto_type == TYPE_MESSAGE:
v = self._get_field_default(field_name)
cls = self._betterproto.cls_by_field[field_name]
if isinstance(v, list):
if cls == datetime:
v = [isoparse(item) for item in value[key]]
elif cls == timedelta:
v = [
timedelta(seconds=float(item[:-1]))
for item in value[key]
]
else:
v = [cls().from_dict(item) for item in value[key]]
elif cls == datetime:
v = isoparse(value[key])
setattr(self, field_name, v)
elif cls == timedelta:
v = timedelta(seconds=float(value[key][:-1]))
setattr(self, field_name, v)
elif meta.wraps:
setattr(self, field_name, value[key])
elif v is None:
setattr(self, field_name, cls().from_dict(value[key]))
else:
# NOTE: `from_dict` mutates the underlying message, so no
# assignment here is necessary.
v.from_dict(value[key])
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
v = getattr(self, field_name)
cls = self._betterproto.cls_by_field[f"{field_name}.value"]
for k in value[key]:
v[k] = cls().from_dict(value[key][k])
else:
v = value[key]
if meta.proto_type in INT_64_TYPES:
if isinstance(value[key], list):
v = [int(n) for n in value[key]]
else:
v = int(value[key])
elif meta.proto_type == TYPE_BYTES:
if isinstance(value[key], list):
v = [b64decode(n) for n in value[key]]
else:
v = b64decode(value[key])
elif meta.proto_type == TYPE_ENUM:
enum_cls = self._betterproto.cls_by_field[field_name]
if isinstance(v, list):
v = [enum_cls.from_string(e) for e in v]
elif isinstance(v, str):
v = enum_cls.from_string(v)
elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE):
if isinstance(value[key], list):
v = [_parse_float(n) for n in value[key]]
else:
v = _parse_float(value[key])
if v is not None:
setattr(self, field_name, v)
return self
def to_json(
@ -1859,8 +1798,8 @@ class Message(ABC):
@classmethod
def _validate_field_groups(cls, values):
group_to_one_ofs = cls._betterproto.oneof_field_by_group
field_name_to_meta = cls._betterproto.meta_by_field_name
group_to_one_ofs = cls._betterproto_meta.oneof_field_by_group # type: ignore
field_name_to_meta = cls._betterproto_meta.meta_by_field_name # type: ignore
for group, field_set in group_to_one_ofs.items():
if len(field_set) == 1:
@ -1873,12 +1812,12 @@ class Message(ABC):
continue
set_fields = [
field.name
for field in field_set
if getattr(values, field.name, None) is not None
field.name for field in field_set if values[field.name] is not None
]
if len(set_fields) > 1:
if not set_fields:
raise ValueError(f"Group {group} has no value; all fields are None")
elif len(set_fields) > 1:
set_fields_str = ", ".join(set_fields)
raise ValueError(
f"Group {group} has more than one value; fields {set_fields_str} are not None"
@ -1887,26 +1826,6 @@ class Message(ABC):
return values
Message.__annotations__ = {} # HACK to avoid typing.get_type_hints breaking :)
# monkey patch (de-)serialization functions of class `Message`
# with functions from `betterproto-rust-codec` if available
try:
import betterproto_rust_codec
def __parse_patch(self: T, data: bytes) -> T:
betterproto_rust_codec.deserialize(self, data)
return self
def __bytes_patch(self) -> bytes:
return betterproto_rust_codec.serialize(self)
Message.parse = __parse_patch
Message.__bytes__ = __bytes_patch
except ModuleNotFoundError:
pass
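With the patch in place, callers change nothing; the usual entry points transparently use the Rust codec whenever the module is importable. A sketch with a hypothetical generated message:
data = bytes(Greeting(text="hi"))  # serialized by betterproto_rust_codec when available
msg = Greeting().parse(data)       # deserialized natively, else the pure-Python fallback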
def serialized_on_wire(message: Message) -> bool:
"""
If this message was or should be serialized on the wire. This can be used to detect
@ -1978,26 +1897,17 @@ class _Duration(Duration):
class _Timestamp(Timestamp):
@classmethod
def from_datetime(cls, dt: datetime) -> "_Timestamp":
# manual epoch offset calculation to avoid rounding errors,
# to support negative timestamps (before 1970) and skirt
# around datetime bugs (apparently 0 isn't a year in [0, 9999]??)
if dt.tzinfo is None:
dt = dt.replace(tzinfo=timezone.utc)
offset = dt - DATETIME_ZERO
# below is the same as timedelta.total_seconds() but without dividing by 1e6,
# so we end up with microseconds as integers instead of seconds as a float
offset_us = (
offset.days * 24 * 60 * 60 + offset.seconds
) * 10**6 + offset.microseconds
seconds, us = divmod(offset_us, 10**6)
return cls(seconds, us * 1000)
seconds = int(dt.timestamp())
nanos = int(dt.microsecond * 1e3)
return cls(seconds, nanos)
def to_datetime(self) -> datetime:
# datetime.fromtimestamp() expects a timestamp in seconds, not microseconds
# if we pass it as a floating point number, we will run into rounding errors
# see also #407
offset = timedelta(seconds=self.seconds, microseconds=self.nanos // 1000)
return DATETIME_ZERO + offset
ts = self.seconds + (self.nanos / 1e9)
if ts < 0:
return datetime(1970, 1, 1) + timedelta(seconds=ts)
else:
return datetime.fromtimestamp(ts, tz=timezone.utc)
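A quick sanity check of the arithmetic, including the pre-1970 case the comment calls out (assuming `_Timestamp` as defined above):
from datetime import datetime, timezone
dt = datetime(1969, 12, 31, 23, 59, 59, tzinfo=timezone.utc)
ts = _Timestamp.from_datetime(dt)
assert (ts.seconds, ts.nanos) == (-1, 0)  # negative offsets divmod cleanly
assert ts.to_datetime() == dt             # and round-trip exactly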
@staticmethod
def timestamp_to_json(dt: datetime) -> str:
@ -2013,10 +1923,10 @@ class _Timestamp(Timestamp):
return f"{result}Z"
if (nanos % 1e6) == 0:
# Serialize 3 fractional digits.
return f"{result}.{int(nanos // 1e6):03d}Z"
return f"{result}.{int(nanos // 1e6) :03d}Z"
if (nanos % 1e3) == 0:
# Serialize 6 fractional digits.
return f"{result}.{int(nanos // 1e3):06d}Z"
return f"{result}.{int(nanos // 1e3) :06d}Z"
# Serialize 9 fractional digits.
return f"{result}.{nanos:09d}"

View File

@ -136,8 +136,4 @@ def lowercase_first(value: str) -> str:
def sanitize_name(value: str) -> str:
# https://www.python.org/dev/peps/pep-0008/#descriptive-naming-styles
if keyword.iskeyword(value):
return f"{value}_"
if not value.isidentifier():
return f"_{value}"
return value
return f"{value}_" if keyword.iskeyword(value) else value

View File

@ -1,9 +1,6 @@
from __future__ import annotations
import os
import re
from typing import (
TYPE_CHECKING,
Dict,
List,
Set,
@ -16,9 +13,6 @@ from ..lib.google import protobuf as google_protobuf
from .naming import pythonize_class_name
if TYPE_CHECKING:
from ..plugin.typing_compiler import TypingCompiler
WRAPPER_TYPES: Dict[str, Type] = {
".google.protobuf.DoubleValue": google_protobuf.DoubleValue,
".google.protobuf.FloatValue": google_protobuf.FloatValue,
@ -49,13 +43,7 @@ def parse_source_type_name(field_type_name: str) -> Tuple[str, str]:
def get_type_reference(
*,
package: str,
imports: set,
source_type: str,
typing_compiler: TypingCompiler,
unwrap: bool = True,
pydantic: bool = False,
*, package: str, imports: set, source_type: str, unwrap: bool = True
) -> str:
"""
Return a Python type name for a proto type reference. Adds the import if
@ -64,7 +52,7 @@ def get_type_reference(
if unwrap:
if source_type in WRAPPER_TYPES:
wrapped_type = type(WRAPPER_TYPES[source_type]().value)
return typing_compiler.optional(wrapped_type.__name__)
return f"Optional[{wrapped_type.__name__}]"
if source_type == ".google.protobuf.Duration":
return "timedelta"
@ -81,9 +69,7 @@ def get_type_reference(
compiling_google_protobuf = current_package == ["google", "protobuf"]
importing_google_protobuf = py_package == ["google", "protobuf"]
if importing_google_protobuf and not compiling_google_protobuf:
py_package = (
["betterproto", "lib"] + (["pydantic"] if pydantic else []) + py_package
)
py_package = ["betterproto", "lib"] + py_package
if py_package[:1] == ["betterproto"]:
return reference_absolute(imports, py_package, py_type)

View File

@ -11,11 +11,3 @@ def pythonize_field_name(name: str) -> str:
def pythonize_method_name(name: str) -> str:
return casing.safe_snake_case(name)
def pythonize_enum_member_name(name: str, enum_name: str) -> str:
enum_name = casing.snake_case(enum_name).upper()
find = name.find(enum_name)
if find != -1:
name = name[find + len(enum_name) :].strip("_")
return casing.sanitize_name(name)
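So a member carrying the enum's name as a prefix is reduced to its suffix and then sanitized:
assert pythonize_enum_member_name("COLOR_RED", "Color") == "RED"
assert pythonize_enum_member_name("BLUE", "Color") == "BLUE"  # no prefix, unchanged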

View File

@ -1,197 +0,0 @@
from __future__ import annotations
from enum import (
EnumMeta,
IntEnum,
)
from types import MappingProxyType
from typing import (
TYPE_CHECKING,
Any,
Dict,
Optional,
Tuple,
)
if TYPE_CHECKING:
from collections.abc import (
Generator,
Mapping,
)
from typing_extensions import (
Never,
Self,
)
def _is_descriptor(obj: object) -> bool:
return (
hasattr(obj, "__get__") or hasattr(obj, "__set__") or hasattr(obj, "__delete__")
)
class EnumType(EnumMeta if TYPE_CHECKING else type):
_value_map_: Mapping[int, Enum]
_member_map_: Mapping[str, Enum]
def __new__(
mcs, name: str, bases: Tuple[type, ...], namespace: Dict[str, Any]
) -> Self:
value_map = {}
member_map = {}
new_mcs = type(
f"{name}Type",
tuple(
dict.fromkeys(
[base.__class__ for base in bases if base.__class__ is not type]
+ [EnumType, type]
)
), # reorder the bases so EnumType and type are last to avoid conflicts
{"_value_map_": value_map, "_member_map_": member_map},
)
members = {
name: value
for name, value in namespace.items()
if not _is_descriptor(value) and not name.startswith("__")
}
cls = type.__new__(
new_mcs,
name,
bases,
{key: value for key, value in namespace.items() if key not in members},
)
# this allows us to disallow member access from other members as
# members become proper class variables
for name, value in members.items():
member = value_map.get(value)
if member is None:
member = cls.__new__(cls, name=name, value=value) # type: ignore
value_map[value] = member
member_map[name] = member
type.__setattr__(new_mcs, name, member)
return cls
if not TYPE_CHECKING:
def __call__(cls, value: int) -> Enum:
try:
return cls._value_map_[value]
except (KeyError, TypeError):
raise ValueError(f"{value!r} is not a valid {cls.__name__}") from None
def __iter__(cls) -> Generator[Enum, None, None]:
yield from cls._member_map_.values()
def __reversed__(cls) -> Generator[Enum, None, None]:
yield from reversed(cls._member_map_.values())
def __getitem__(cls, key: str) -> Enum:
return cls._member_map_[key]
@property
def __members__(cls) -> MappingProxyType[str, Enum]:
return MappingProxyType(cls._member_map_)
def __repr__(cls) -> str:
return f"<enum {cls.__name__!r}>"
def __len__(cls) -> int:
return len(cls._member_map_)
def __setattr__(cls, name: str, value: Any) -> Never:
raise AttributeError(f"{cls.__name__}: cannot reassign Enum members.")
def __delattr__(cls, name: str) -> Never:
raise AttributeError(f"{cls.__name__}: cannot delete Enum members.")
def __contains__(cls, member: object) -> bool:
return isinstance(member, cls) and member.name in cls._member_map_
class Enum(IntEnum if TYPE_CHECKING else int, metaclass=EnumType):
"""
The base class for protobuf enumerations, all generated enumerations will
inherit from this. Emulates `enum.IntEnum`.
"""
name: Optional[str]
value: int
if not TYPE_CHECKING:
def __new__(cls, *, name: Optional[str], value: int) -> Self:
self = super().__new__(cls, value)
super().__setattr__(self, "name", name)
super().__setattr__(self, "value", value)
return self
def __getnewargs_ex__(self) -> Tuple[Tuple[()], Dict[str, Any]]:
return (), {"name": self.name, "value": self.value}
def __str__(self) -> str:
return self.name or "None"
def __repr__(self) -> str:
return f"{self.__class__.__name__}.{self.name}"
def __setattr__(self, key: str, value: Any) -> Never:
raise AttributeError(
f"{self.__class__.__name__} Cannot reassign a member's attributes."
)
def __delattr__(self, item: Any) -> Never:
raise AttributeError(
f"{self.__class__.__name__} Cannot delete a member's attributes."
)
def __copy__(self) -> Self:
return self
def __deepcopy__(self, memo: Any) -> Self:
return self
@classmethod
def try_value(cls, value: int = 0) -> Self:
"""Return the value which corresponds to the value.
Parameters
-----------
value: :class:`int`
The value of the enum member to get.
Returns
-------
:class:`Enum`
The corresponding member or a new instance of the enum if
``value`` isn't actually a member.
"""
try:
return cls._value_map_[value]
except (KeyError, TypeError):
return cls.__new__(cls, name=None, value=value)
@classmethod
def from_string(cls, name: str) -> Self:
"""Return the value which corresponds to the string name.
Parameters
-----------
name: :class:`str`
The name of the enum member to get.
Raises
-------
:exc:`ValueError`
The member was not found in the Enum.
"""
try:
return cls._member_map_[name]
except KeyError as e:
raise ValueError(f"Unknown value {name} for enum {cls.__name__}") from e
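Putting the pieces together with a hypothetical generated enum:
class Color(Enum):
    COLOR_UNSPECIFIED = 0
    COLOR_RED = 1
assert Color(1) is Color.COLOR_RED               # value lookup via the metaclass
assert Color.from_string("COLOR_RED") is Color.COLOR_RED
assert Color.try_value(42).value == 42           # unknown values yield a synthetic member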

View File

@ -127,7 +127,6 @@ class ServiceStub(ABC):
response_type,
**self.__resolve_request_kwargs(timeout, deadline, metadata),
) as stream:
await stream.send_request()
await self._send_messages(stream, request_iterator)
response = await stream.recv_message()
assert response is not None

File diff suppressed because it is too large.

View File

@ -1 +1,152 @@
from betterproto.lib.std.google.protobuf.compiler import *
# Generated by the protocol buffer compiler. DO NOT EDIT!
# sources: google/protobuf/compiler/plugin.proto
# plugin: python-betterproto
# This file has been @generated
from dataclasses import dataclass
from typing import List
import betterproto
import betterproto.lib.google.protobuf as betterproto_lib_google_protobuf
class CodeGeneratorResponseFeature(betterproto.Enum):
"""Sync with code_generator.h."""
FEATURE_NONE = 0
FEATURE_PROTO3_OPTIONAL = 1
@dataclass(eq=False, repr=False)
class Version(betterproto.Message):
"""The version number of protocol compiler."""
major: int = betterproto.int32_field(1)
minor: int = betterproto.int32_field(2)
patch: int = betterproto.int32_field(3)
suffix: str = betterproto.string_field(4)
"""
A suffix for alpha, beta or rc release, e.g., "alpha-1", "rc2". It should
be empty for mainline stable releases.
"""
@dataclass(eq=False, repr=False)
class CodeGeneratorRequest(betterproto.Message):
"""An encoded CodeGeneratorRequest is written to the plugin's stdin."""
file_to_generate: List[str] = betterproto.string_field(1)
"""
The .proto files that were explicitly listed on the command-line. The code
generator should generate code only for these files. Each file's
descriptor will be included in proto_file, below.
"""
parameter: str = betterproto.string_field(2)
"""The generator parameter passed on the command-line."""
proto_file: List[
"betterproto_lib_google_protobuf.FileDescriptorProto"
] = betterproto.message_field(15)
"""
FileDescriptorProtos for all files in files_to_generate and everything they
import. The files will appear in topological order, so each file appears
before any file that imports it. protoc guarantees that all proto_files
will be written after the fields above, even though this is not technically
guaranteed by the protobuf wire format. This theoretically could allow a
plugin to stream in the FileDescriptorProtos and handle them one by one
rather than read the entire set into memory at once. However, as of this
writing, this is not similarly optimized on protoc's end -- it will store
all fields in memory at once before sending them to the plugin. Type names
of fields and extensions in the FileDescriptorProto are always fully
qualified.
"""
compiler_version: "Version" = betterproto.message_field(3)
"""The version number of protocol compiler."""
@dataclass(eq=False, repr=False)
class CodeGeneratorResponse(betterproto.Message):
"""The plugin writes an encoded CodeGeneratorResponse to stdout."""
error: str = betterproto.string_field(1)
"""
Error message. If non-empty, code generation failed. The plugin process
should exit with status code zero even if it reports an error in this way.
This should be used to indicate errors in .proto files which prevent the
code generator from generating correct code. Errors which indicate a
problem in protoc itself -- such as the input CodeGeneratorRequest being
unparseable -- should be reported by writing a message to stderr and
exiting with a non-zero status code.
"""
supported_features: int = betterproto.uint64_field(2)
"""
A bitmask of supported features that the code generator supports. This is a
bitwise "or" of values from the Feature enum.
"""
file: List["CodeGeneratorResponseFile"] = betterproto.message_field(15)
@dataclass(eq=False, repr=False)
class CodeGeneratorResponseFile(betterproto.Message):
"""Represents a single generated file."""
name: str = betterproto.string_field(1)
"""
The file name, relative to the output directory. The name must not contain
"." or ".." components and must be relative, not be absolute (so, the file
cannot lie outside the output directory). "/" must be used as the path
separator, not "\". If the name is omitted, the content will be appended to
the previous file. This allows the generator to break large files into
small chunks, and allows the generated text to be streamed back to protoc
so that large files need not reside completely in memory at one time. Note
that as of this writing protoc does not optimize for this -- it will read
the entire CodeGeneratorResponse before writing files to disk.
"""
insertion_point: str = betterproto.string_field(2)
"""
If non-empty, indicates that the named file should already exist, and the
content here is to be inserted into that file at a defined insertion point.
This feature allows a code generator to extend the output produced by
another code generator. The original generator may provide insertion
points by placing special annotations in the file that look like:
@@protoc_insertion_point(NAME) The annotation can have arbitrary text
before and after it on the line, which allows it to be placed in a comment.
NAME should be replaced with an identifier naming the point -- this is what
other generators will use as the insertion_point. Code inserted at this
point will be placed immediately above the line containing the insertion
point (thus multiple insertions to the same point will come out in the
order they were added). The double-@ is intended to make it unlikely that
the generated code could contain things that look like insertion points by
accident. For example, the C++ code generator places the following line in
the .pb.h files that it generates: //
@@protoc_insertion_point(namespace_scope) This line appears within the
scope of the file's package namespace, but outside of any particular class.
Another plugin can then specify the insertion_point "namespace_scope" to
generate additional classes or other declarations that should be placed in
this scope. Note that if the line containing the insertion point begins
with whitespace, the same whitespace will be added to every line of the
inserted text. This is useful for languages like Python, where indentation
matters. In these languages, the insertion point comment should be
indented the same amount as any inserted code will need to be in order to
work correctly in that context. The code generator that generates the
initial file and the one which inserts into it must both run as part of a
single invocation of protoc. Code generators are executed in the order in
which they appear on the command line. If |insertion_point| is present,
|name| must also be present.
"""
content: str = betterproto.string_field(15)
"""The file contents."""
generated_code_info: "betterproto_lib_google_protobuf.GeneratedCodeInfo" = (
betterproto.message_field(16)
)
"""
Information describing the file content being inserted. If an insertion
point is used, this information will be appropriately offset and inserted
into the code generation metadata for the generated files.
"""

File diff suppressed because it is too large.

View File

@ -1,210 +0,0 @@
# Generated by the protocol buffer compiler. DO NOT EDIT!
# sources: google/protobuf/compiler/plugin.proto
# plugin: python-betterproto
# This file has been @generated
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from dataclasses import dataclass
else:
from pydantic.dataclasses import dataclass
from typing import List
import betterproto
import betterproto.lib.pydantic.google.protobuf as betterproto_lib_pydantic_google_protobuf
class CodeGeneratorResponseFeature(betterproto.Enum):
"""Sync with code_generator.h."""
FEATURE_NONE = 0
FEATURE_PROTO3_OPTIONAL = 1
FEATURE_SUPPORTS_EDITIONS = 2
@dataclass(eq=False, repr=False)
class Version(betterproto.Message):
"""The version number of protocol compiler."""
major: int = betterproto.int32_field(1)
minor: int = betterproto.int32_field(2)
patch: int = betterproto.int32_field(3)
suffix: str = betterproto.string_field(4)
"""
A suffix for alpha, beta or rc release, e.g., "alpha-1", "rc2". It should
be empty for mainline stable releases.
"""
@dataclass(eq=False, repr=False)
class CodeGeneratorRequest(betterproto.Message):
"""An encoded CodeGeneratorRequest is written to the plugin's stdin."""
file_to_generate: List[str] = betterproto.string_field(1)
"""
The .proto files that were explicitly listed on the command-line. The
code generator should generate code only for these files. Each file's
descriptor will be included in proto_file, below.
"""
parameter: str = betterproto.string_field(2)
"""The generator parameter passed on the command-line."""
proto_file: List["betterproto_lib_pydantic_google_protobuf.FileDescriptorProto"] = (
betterproto.message_field(15)
)
"""
FileDescriptorProtos for all files in files_to_generate and everything
they import. The files will appear in topological order, so each file
appears before any file that imports it.
Note: the files listed in files_to_generate will include runtime-retention
options only, but all other files will include source-retention options.
The source_file_descriptors field below is available in case you need
source-retention options for files_to_generate.
protoc guarantees that all proto_files will be written after
the fields above, even though this is not technically guaranteed by the
protobuf wire format. This theoretically could allow a plugin to stream
in the FileDescriptorProtos and handle them one by one rather than read
the entire set into memory at once. However, as of this writing, this
is not similarly optimized on protoc's end -- it will store all fields in
memory at once before sending them to the plugin.
Type names of fields and extensions in the FileDescriptorProto are always
fully qualified.
"""
source_file_descriptors: List[
"betterproto_lib_pydantic_google_protobuf.FileDescriptorProto"
] = betterproto.message_field(17)
"""
File descriptors with all options, including source-retention options.
These descriptors are only provided for the files listed in
files_to_generate.
"""
compiler_version: "Version" = betterproto.message_field(3)
"""The version number of protocol compiler."""
@dataclass(eq=False, repr=False)
class CodeGeneratorResponse(betterproto.Message):
"""The plugin writes an encoded CodeGeneratorResponse to stdout."""
error: str = betterproto.string_field(1)
"""
Error message. If non-empty, code generation failed. The plugin process
should exit with status code zero even if it reports an error in this way.
This should be used to indicate errors in .proto files which prevent the
code generator from generating correct code. Errors which indicate a
problem in protoc itself -- such as the input CodeGeneratorRequest being
unparseable -- should be reported by writing a message to stderr and
exiting with a non-zero status code.
"""
supported_features: int = betterproto.uint64_field(2)
"""
A bitmask of supported features that the code generator supports.
This is a bitwise "or" of values from the Feature enum.
"""
minimum_edition: int = betterproto.int32_field(3)
"""
The minimum edition this plugin supports. This will be treated as an
Edition enum, but we want to allow unknown values. It should be specified
according to the edition enum value, *not* the edition number. Only takes
effect for plugins that have FEATURE_SUPPORTS_EDITIONS set.
"""
maximum_edition: int = betterproto.int32_field(4)
"""
The maximum edition this plugin supports. This will be treated as an
Edition enum, but we want to allow unknown values. It should be specified
according to the edition enum value, *not* the edition number. Only takes
effect for plugins that have FEATURE_SUPPORTS_EDITIONS set.
"""
file: List["CodeGeneratorResponseFile"] = betterproto.message_field(15)
@dataclass(eq=False, repr=False)
class CodeGeneratorResponseFile(betterproto.Message):
"""Represents a single generated file."""
name: str = betterproto.string_field(1)
"""
The file name, relative to the output directory. The name must not
contain "." or ".." components and must be relative, not be absolute (so,
the file cannot lie outside the output directory). "/" must be used as
the path separator, not "\".
If the name is omitted, the content will be appended to the previous
file. This allows the generator to break large files into small chunks,
and allows the generated text to be streamed back to protoc so that large
files need not reside completely in memory at one time. Note that as of
this writing protoc does not optimize for this -- it will read the entire
CodeGeneratorResponse before writing files to disk.
"""
insertion_point: str = betterproto.string_field(2)
"""
If non-empty, indicates that the named file should already exist, and the
content here is to be inserted into that file at a defined insertion
point. This feature allows a code generator to extend the output
produced by another code generator. The original generator may provide
insertion points by placing special annotations in the file that look
like:
@@protoc_insertion_point(NAME)
The annotation can have arbitrary text before and after it on the line,
which allows it to be placed in a comment. NAME should be replaced with
an identifier naming the point -- this is what other generators will use
as the insertion_point. Code inserted at this point will be placed
immediately above the line containing the insertion point (thus multiple
insertions to the same point will come out in the order they were added).
The double-@ is intended to make it unlikely that the generated code
could contain things that look like insertion points by accident.
For example, the C++ code generator places the following line in the
.pb.h files that it generates:
// @@protoc_insertion_point(namespace_scope)
This line appears within the scope of the file's package namespace, but
outside of any particular class. Another plugin can then specify the
insertion_point "namespace_scope" to generate additional classes or
other declarations that should be placed in this scope.
Note that if the line containing the insertion point begins with
whitespace, the same whitespace will be added to every line of the
inserted text. This is useful for languages like Python, where
indentation matters. In these languages, the insertion point comment
should be indented the same amount as any inserted code will need to be
in order to work correctly in that context.
The code generator that generates the initial file and the one which
inserts into it must both run as part of a single invocation of protoc.
Code generators are executed in the order in which they appear on the
command line.
If |insertion_point| is present, |name| must also be present.
"""
content: str = betterproto.string_field(15)
"""The file contents."""
generated_code_info: "betterproto_lib_pydantic_google_protobuf.GeneratedCodeInfo" = betterproto.message_field(
16
)
"""
Information describing the file content being inserted. If an insertion
point is used, this information will be appropriately offset and inserted
into the code generation metadata for the generated files.
"""
CodeGeneratorRequest.__pydantic_model__.update_forward_refs() # type: ignore
CodeGeneratorResponse.__pydantic_model__.update_forward_refs() # type: ignore
CodeGeneratorResponseFile.__pydantic_model__.update_forward_refs() # type: ignore

File diff suppressed because it is too large.

View File

@ -1,198 +0,0 @@
# Generated by the protocol buffer compiler. DO NOT EDIT!
# sources: google/protobuf/compiler/plugin.proto
# plugin: python-betterproto
# This file has been @generated
from dataclasses import dataclass
from typing import List
import betterproto
import betterproto.lib.google.protobuf as betterproto_lib_google_protobuf
class CodeGeneratorResponseFeature(betterproto.Enum):
"""Sync with code_generator.h."""
FEATURE_NONE = 0
FEATURE_PROTO3_OPTIONAL = 1
FEATURE_SUPPORTS_EDITIONS = 2
@dataclass(eq=False, repr=False)
class Version(betterproto.Message):
"""The version number of protocol compiler."""
major: int = betterproto.int32_field(1)
minor: int = betterproto.int32_field(2)
patch: int = betterproto.int32_field(3)
suffix: str = betterproto.string_field(4)
"""
A suffix for alpha, beta or rc release, e.g., "alpha-1", "rc2". It should
be empty for mainline stable releases.
"""
@dataclass(eq=False, repr=False)
class CodeGeneratorRequest(betterproto.Message):
"""An encoded CodeGeneratorRequest is written to the plugin's stdin."""
file_to_generate: List[str] = betterproto.string_field(1)
"""
The .proto files that were explicitly listed on the command-line. The
code generator should generate code only for these files. Each file's
descriptor will be included in proto_file, below.
"""
parameter: str = betterproto.string_field(2)
"""The generator parameter passed on the command-line."""
proto_file: List["betterproto_lib_google_protobuf.FileDescriptorProto"] = (
betterproto.message_field(15)
)
"""
FileDescriptorProtos for all files in files_to_generate and everything
they import. The files will appear in topological order, so each file
appears before any file that imports it.
Note: the files listed in files_to_generate will include runtime-retention
options only, but all other files will include source-retention options.
The source_file_descriptors field below is available in case you need
source-retention options for files_to_generate.
protoc guarantees that all proto_files will be written after
the fields above, even though this is not technically guaranteed by the
protobuf wire format. This theoretically could allow a plugin to stream
in the FileDescriptorProtos and handle them one by one rather than read
the entire set into memory at once. However, as of this writing, this
is not similarly optimized on protoc's end -- it will store all fields in
memory at once before sending them to the plugin.
Type names of fields and extensions in the FileDescriptorProto are always
fully qualified.
"""
source_file_descriptors: List[
"betterproto_lib_google_protobuf.FileDescriptorProto"
] = betterproto.message_field(17)
"""
File descriptors with all options, including source-retention options.
These descriptors are only provided for the files listed in
files_to_generate.
"""
compiler_version: "Version" = betterproto.message_field(3)
"""The version number of protocol compiler."""
@dataclass(eq=False, repr=False)
class CodeGeneratorResponse(betterproto.Message):
"""The plugin writes an encoded CodeGeneratorResponse to stdout."""
error: str = betterproto.string_field(1)
"""
Error message. If non-empty, code generation failed. The plugin process
should exit with status code zero even if it reports an error in this way.
This should be used to indicate errors in .proto files which prevent the
code generator from generating correct code. Errors which indicate a
problem in protoc itself -- such as the input CodeGeneratorRequest being
unparseable -- should be reported by writing a message to stderr and
exiting with a non-zero status code.
"""
supported_features: int = betterproto.uint64_field(2)
"""
A bitmask of supported features that the code generator supports.
This is a bitwise "or" of values from the Feature enum.
"""
minimum_edition: int = betterproto.int32_field(3)
"""
The minimum edition this plugin supports. This will be treated as an
Edition enum, but we want to allow unknown values. It should be specified
according to the edition enum value, *not* the edition number. Only takes
effect for plugins that have FEATURE_SUPPORTS_EDITIONS set.
"""
maximum_edition: int = betterproto.int32_field(4)
"""
The maximum edition this plugin supports. This will be treated as an
Edition enum, but we want to allow unknown values. It should be specified
according to the edition enum value, *not* the edition number. Only takes
effect for plugins that have FEATURE_SUPPORTS_EDITIONS set.
"""
file: List["CodeGeneratorResponseFile"] = betterproto.message_field(15)
@dataclass(eq=False, repr=False)
class CodeGeneratorResponseFile(betterproto.Message):
"""Represents a single generated file."""
name: str = betterproto.string_field(1)
"""
The file name, relative to the output directory. The name must not
contain "." or ".." components and must be relative, not be absolute (so,
the file cannot lie outside the output directory). "/" must be used as
the path separator, not "\".
If the name is omitted, the content will be appended to the previous
file. This allows the generator to break large files into small chunks,
and allows the generated text to be streamed back to protoc so that large
files need not reside completely in memory at one time. Note that as of
this writing protoc does not optimize for this -- it will read the entire
CodeGeneratorResponse before writing files to disk.
"""
insertion_point: str = betterproto.string_field(2)
"""
If non-empty, indicates that the named file should already exist, and the
content here is to be inserted into that file at a defined insertion
point. This feature allows a code generator to extend the output
produced by another code generator. The original generator may provide
insertion points by placing special annotations in the file that look
like:
@@protoc_insertion_point(NAME)
The annotation can have arbitrary text before and after it on the line,
which allows it to be placed in a comment. NAME should be replaced with
an identifier naming the point -- this is what other generators will use
as the insertion_point. Code inserted at this point will be placed
immediately above the line containing the insertion point (thus multiple
insertions to the same point will come out in the order they were added).
The double-@ is intended to make it unlikely that the generated code
could contain things that look like insertion points by accident.
For example, the C++ code generator places the following line in the
.pb.h files that it generates:
// @@protoc_insertion_point(namespace_scope)
This line appears within the scope of the file's package namespace, but
outside of any particular class. Another plugin can then specify the
insertion_point "namespace_scope" to generate additional classes or
other declarations that should be placed in this scope.
Note that if the line containing the insertion point begins with
whitespace, the same whitespace will be added to every line of the
inserted text. This is useful for languages like Python, where
indentation matters. In these languages, the insertion point comment
should be indented the same amount as any inserted code will need to be
in order to work correctly in that context.
The code generator that generates the initial file and the one which
inserts into it must both run as part of a single invocation of protoc.
Code generators are executed in the order in which they appear on the
command line.
If |insertion_point| is present, |name| must also be present.
"""
content: str = betterproto.string_field(15)
"""The file contents."""
generated_code_info: "betterproto_lib_google_protobuf.GeneratedCodeInfo" = (
betterproto.message_field(16)
)
"""
Information describing the file content being inserted. If an insertion
point is used, this information will be appropriately offset and inserted
into the code generation metadata for the generated files.
"""

View File

@ -1,12 +1,10 @@
import os.path
import subprocess
import sys
from .module_validation import ModuleValidator
try:
# betterproto[compiler] specific dependencies
import black
import isort.api
import jinja2
except ImportError as err:
print(
@ -31,34 +29,22 @@ def outputfile_compiler(output_file: OutputTemplate) -> str:
trim_blocks=True,
lstrip_blocks=True,
loader=jinja2.FileSystemLoader(templates_folder),
undefined=jinja2.StrictUndefined,
)
# Load the body first so we have a complete list of imports needed.
body_template = env.get_template("template.py.j2")
header_template = env.get_template("header.py.j2")
template = env.get_template("template.py.j2")
code = body_template.render(output_file=output_file)
code = header_template.render(output_file=output_file) + code
# Sort imports, delete unused ones
code = subprocess.check_output(
["ruff", "check", "--select", "I,F401", "--fix", "--silent", "-"],
input=code,
encoding="utf-8",
code = template.render(output_file=output_file)
code = isort.api.sort_code_string(
code=code,
show_diff=False,
py_version=37,
profile="black",
combine_as_imports=True,
lines_after_imports=2,
quiet=True,
force_grid_wrap=2,
known_third_party=["grpclib", "betterproto"],
)
# Format the code
code = subprocess.check_output(
["ruff", "format", "-"], input=code, encoding="utf-8"
return black.format_str(
src_contents=code,
mode=black.Mode(),
)
# Validate the generated code.
validator = ModuleValidator(iter(code.splitlines()))
if not validator.validate():
message_builder = ["[WARNING]: Generated code has collisions in the module:"]
for collision, lines in validator.collisions.items():
message_builder.append(f' "{collision}" on lines:')
for num, line in lines:
message_builder.append(f" {num}:{line}")
print("\n".join(message_builder), file=sys.stderr)
return code

View File

@ -29,8 +29,10 @@ instantiating field `A` with parent message `B` should add a
reference to `A` to `B`'s `fields` attribute.
"""
import builtins
import re
import textwrap
from dataclasses import (
dataclass,
field,
@ -47,6 +49,12 @@ from typing import (
)
import betterproto
from betterproto import which_one_of
from betterproto.casing import sanitize_name
from betterproto.compile.importing import (
get_type_reference,
parse_source_type_name,
)
from betterproto.compile.naming import (
pythonize_class_name,
pythonize_field_name,
@ -64,21 +72,16 @@ from betterproto.lib.google.protobuf import (
)
from betterproto.lib.google.protobuf.compiler import CodeGeneratorRequest
from .. import which_one_of
from ..casing import sanitize_name
from ..compile.importing import (
get_type_reference,
parse_source_type_name,
)
from ..compile.naming import (
pythonize_class_name,
pythonize_enum_member_name,
pythonize_field_name,
pythonize_method_name,
)
from .typing_compiler import (
DirectImportTypingCompiler,
TypingCompiler,
)
# Create a unique placeholder to deal with
@ -139,12 +142,12 @@ def monkey_patch_oneof_index():
"betterproto"
],
"group",
"oneof_index",
"_oneof_index",
)
object.__setattr__(
Field.__dataclass_fields__["oneof_index"].metadata["betterproto"],
"group",
"oneof_index",
"_oneof_index",
)
@ -153,33 +156,14 @@ def get_comment(
) -> str:
pad = " " * indent
for sci_loc in proto_file.source_code_info.location:
if list(sci_loc.path) == path:
all_comments = list(sci_loc.leading_detached_comments)
if sci_loc.leading_comments:
all_comments.append(sci_loc.leading_comments)
if sci_loc.trailing_comments:
all_comments.append(sci_loc.trailing_comments)
lines = []
for comment in all_comments:
lines += comment.split("\n")
lines.append("")
# Remove consecutive empty lines
lines = [
line for i, line in enumerate(lines) if line or (i == 0 or lines[i - 1])
]
if lines and not lines[-1]:
lines.pop() # Remove the last empty line
# It is common for one line comments to start with a space, for example: // comment
# We don't add this space to the generated file.
lines = [line[1:] if line and line[0] == " " else line for line in lines]
if list(sci_loc.path) == path and sci_loc.leading_comments:
lines = textwrap.wrap(
sci_loc.leading_comments.strip().replace("\n", ""), width=79 - indent
)
# This is a field, message, enum, service, or method
if len(lines) == 1 and len(lines[0]) < 79 - indent - 6:
lines[0] = lines[0].strip('"')
return f'{pad}"""{lines[0]}"""'
else:
joined = f"\n{pad}".join(lines)
@ -192,7 +176,6 @@ class ProtoContentBase:
"""Methods common to MessageCompiler, ServiceCompiler and ServiceMethodCompiler."""
source_file: FileDescriptorProto
typing_compiler: TypingCompiler
path: List[int]
comment_indent: int = 4
parent: Union["betterproto.Message", "OutputTemplate"]
@ -260,8 +243,9 @@ class OutputTemplate:
parent_request: PluginRequestCompiler
package_proto_obj: FileDescriptorProto
input_files: List[str] = field(default_factory=list)
imports_end: Set[str] = field(default_factory=set)
imports: Set[str] = field(default_factory=set)
datetime_imports: Set[str] = field(default_factory=set)
typing_imports: Set[str] = field(default_factory=set)
pydantic_imports: Set[str] = field(default_factory=set)
builtins_import: bool = False
messages: List["MessageCompiler"] = field(default_factory=list)
@ -270,7 +254,6 @@ class OutputTemplate:
imports_type_checking_only: Set[str] = field(default_factory=set)
pydantic_dataclasses: bool = False
output: bool = True
typing_compiler: TypingCompiler = field(default_factory=DirectImportTypingCompiler)
@property
def package(self) -> str:
@ -297,21 +280,8 @@ class OutputTemplate:
@property
def python_module_imports(self) -> Set[str]:
imports = set()
has_deprecated = False
if any(m.deprecated for m in self.messages):
has_deprecated = True
if any(x for x in self.messages if any(x.deprecated_fields)):
has_deprecated = True
if any(
any(m.proto_obj.options.deprecated for m in s.methods)
for s in self.services
):
has_deprecated = True
if has_deprecated:
imports.add("warnings")
if self.builtins_import:
imports.add("builtins")
return imports
@ -322,7 +292,6 @@ class MessageCompiler(ProtoContentBase):
"""Representation of a protobuf message."""
source_file: FileDescriptorProto
typing_compiler: TypingCompiler
parent: Union["MessageCompiler", OutputTemplate] = PLACEHOLDER
proto_obj: DescriptorProto = PLACEHOLDER
path: List[int] = PLACEHOLDER
@ -350,6 +319,12 @@ class MessageCompiler(ProtoContentBase):
def py_name(self) -> str:
return pythonize_class_name(self.proto_name)
@property
def annotation(self) -> str:
if self.repeated:
return f"List[{self.py_name}]"
return self.py_name
@property
def deprecated_fields(self) -> Iterator[str]:
for f in self.fields:
@ -410,10 +385,7 @@ def is_oneof(proto_field_obj: FieldDescriptorProto) -> bool:
us to tell whether it was set, via the which_one_of interface.
"""
return (
not proto_field_obj.proto3_optional
and which_one_of(proto_field_obj, "oneof_index")[0] == "oneof_index"
)
return which_one_of(proto_field_obj, "_oneof_index")[0] == "oneof_index"
@dataclass
@ -462,6 +434,18 @@ class FieldCompiler(MessageCompiler):
imports.add("datetime")
return imports
@property
def typing_imports(self) -> Set[str]:
imports = set()
annotation = self.annotation
if "Optional[" in annotation:
imports.add("Optional")
if "List[" in annotation:
imports.add("List")
if "Dict[" in annotation:
imports.add("Dict")
return imports
@property
def pydantic_imports(self) -> Set[str]:
return set()
@ -474,6 +458,7 @@ class FieldCompiler(MessageCompiler):
def add_imports_to(self, output_file: OutputTemplate) -> None:
output_file.datetime_imports.update(self.datetime_imports)
output_file.typing_imports.update(self.typing_imports)
output_file.pydantic_imports.update(self.pydantic_imports)
output_file.builtins_import = output_file.builtins_import or self.use_builtins
@ -500,6 +485,11 @@ class FieldCompiler(MessageCompiler):
def optional(self) -> bool:
return self.proto_obj.proto3_optional
@property
def mutable(self) -> bool:
"""True if the field is a mutable type, otherwise False."""
return self.annotation.startswith(("List[", "Dict["))
@property
def field_type(self) -> str:
"""String representation of proto field type."""
@ -509,6 +499,35 @@ class FieldCompiler(MessageCompiler):
.replace("type_", "")
)
@property
def default_value_string(self) -> str:
"""Python representation of the default proto value."""
if self.repeated:
return "[]"
if self.optional:
return "None"
if self.py_type == "int":
return "0"
if self.py_type == "float":
return "0.0"
elif self.py_type == "bool":
return "False"
elif self.py_type == "str":
return '""'
elif self.py_type == "bytes":
return 'b""'
elif self.field_type == "enum":
enum_proto_obj_name = self.proto_obj.type_name.split(".").pop()
enum = next(
e
for e in self.output_file.enums
if e.proto_obj.name == enum_proto_obj_name
)
return enum.default_value_string
else:
# Message type
return "None"
@property
def packed(self) -> bool:
"""True if the wire representation is a packed format."""
@ -541,10 +560,8 @@ class FieldCompiler(MessageCompiler):
# Type referencing another defined Message or a named enum
return get_type_reference(
package=self.output_file.package,
imports=self.output_file.imports_end,
imports=self.output_file.imports,
source_type=self.proto_obj.type_name,
typing_compiler=self.typing_compiler,
pydantic=self.output_file.pydantic_dataclasses,
)
else:
raise NotImplementedError(f"Unknown type {self.proto_obj.type}")
@ -555,9 +572,9 @@ class FieldCompiler(MessageCompiler):
if self.use_builtins:
py_type = f"builtins.{py_type}"
if self.repeated:
return self.typing_compiler.list(py_type)
return f"List[{py_type}]"
if self.optional:
return self.typing_compiler.optional(py_type)
return f"Optional[{py_type}]"
return py_type
@ -582,7 +599,7 @@ class PydanticOneOfFieldCompiler(OneOfFieldCompiler):
@property
def pydantic_imports(self) -> Set[str]:
return {"model_validator"}
return {"root_validator"}
@dataclass
@ -605,13 +622,11 @@ class MapEntryCompiler(FieldCompiler):
source_file=self.source_file,
parent=self,
proto_obj=nested.field[0], # key
typing_compiler=self.typing_compiler,
).py_type
self.py_v_type = FieldCompiler(
source_file=self.source_file,
parent=self,
proto_obj=nested.field[1], # value
typing_compiler=self.typing_compiler,
).py_type
# Get proto types
@ -629,7 +644,7 @@ class MapEntryCompiler(FieldCompiler):
@property
def annotation(self) -> str:
return self.typing_compiler.dict(self.py_k_type, self.py_v_type)
return f"Dict[{self.py_k_type}, {self.py_v_type}]"
@property
def repeated(self) -> bool:
@ -655,9 +670,7 @@ class EnumDefinitionCompiler(MessageCompiler):
# Get entries/allowed values for this Enum
self.entries = [
self.EnumEntry(
name=pythonize_enum_member_name(
entry_proto_value.name, self.proto_obj.name
),
name=sanitize_name(entry_proto_value.name),
value=entry_proto_value.number,
comment=get_comment(
proto_file=self.source_file, path=self.path + [2, entry_number]
@ -667,10 +680,17 @@ class EnumDefinitionCompiler(MessageCompiler):
]
super().__post_init__() # call MessageCompiler __post_init__
@property
def default_value_string(self) -> str:
"""Python representation of the default value for Enums.
As per the spec, this is the first value of the Enum.
"""
return str(self.entries[0].value) # ideally, should ALWAYS be int(0)!
@dataclass
class ServiceCompiler(ProtoContentBase):
source_file: FileDescriptorProto
parent: OutputTemplate = PLACEHOLDER
proto_obj: DescriptorProto = PLACEHOLDER
path: List[int] = PLACEHOLDER
@ -679,6 +699,7 @@ class ServiceCompiler(ProtoContentBase):
def __post_init__(self) -> None:
# Add service to output file
self.output_file.services.append(self)
self.output_file.typing_imports.add("Dict")
super().__post_init__() # check for unset fields
@property
@ -692,7 +713,6 @@ class ServiceCompiler(ProtoContentBase):
@dataclass
class ServiceMethodCompiler(ProtoContentBase):
source_file: FileDescriptorProto
parent: ServiceCompiler
proto_obj: MethodDescriptorProto
path: List[int] = PLACEHOLDER
@ -702,6 +722,22 @@ class ServiceMethodCompiler(ProtoContentBase):
# Add method to service
self.parent.methods.append(self)
# Check for imports
if "Optional" in self.py_output_message_type:
self.output_file.typing_imports.add("Optional")
# Check for Async imports
if self.client_streaming:
self.output_file.typing_imports.add("AsyncIterable")
self.output_file.typing_imports.add("Iterable")
self.output_file.typing_imports.add("Union")
# Required by both client and server
if self.client_streaming or self.server_streaming:
self.output_file.typing_imports.add("AsyncIterator")
# add imports required for request arguments timeout, deadline and metadata
self.output_file.typing_imports.add("Optional")
self.output_file.imports_type_checking_only.add("import grpclib.server")
self.output_file.imports_type_checking_only.add(
"from betterproto.grpc.grpclib_client import MetadataLike"
@ -729,6 +765,30 @@ class ServiceMethodCompiler(ProtoContentBase):
)
return f"/{package_part}{self.parent.proto_name}/{self.proto_name}"
@property
def py_input_message(self) -> Optional[MessageCompiler]:
"""Find the input message object.
Returns
-------
Optional[MessageCompiler]
MessageCompiler instance representing the input message.
If no input message could be found or there are no
input messages, None is returned.
"""
package, name = parse_source_type_name(self.proto_obj.input_type)
# Nested types are currently flattened without dots.
# Todo: keep a fully qualified name in types that is
# comparable with method.input_type
for msg in self.request.all_messages:
if (
msg.py_name == pythonize_class_name(name.replace(".", ""))
and msg.output_file.package == package
):
return msg
return None
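To make the flattening concrete: for a hypothetical input_type of .pkg.Outer.Inner, parse_source_type_name yields package "pkg" and name "Outer.Inner", and the lookup compares message names against the dot-stripped form:
# Sketch of the comparison above (names are hypothetical).
name = "Outer.Inner"
flattened = name.replace(".", "")  # "OuterInner"
# pythonize_class_name then CamelCases the result; it is already
# CamelCase here, so the nested type is matched as "OuterInner".
assert flattened == "OuterInner"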
@property
def py_input_message_type(self) -> str:
"""String representation of the Python type corresponding to the
@ -741,11 +801,9 @@ class ServiceMethodCompiler(ProtoContentBase):
"""
return get_type_reference(
package=self.output_file.package,
imports=self.output_file.imports_end,
imports=self.output_file.imports,
source_type=self.proto_obj.input_type,
typing_compiler=self.output_file.typing_compiler,
unwrap=False,
pydantic=self.output_file.pydantic_dataclasses,
).strip('"')
@property
@ -771,11 +829,9 @@ class ServiceMethodCompiler(ProtoContentBase):
"""
return get_type_reference(
package=self.output_file.package,
imports=self.output_file.imports_end,
imports=self.output_file.imports,
source_type=self.proto_obj.output_type,
typing_compiler=self.output_file.typing_compiler,
unwrap=False,
pydantic=self.output_file.pydantic_dataclasses,
).strip('"')
@property

View File

@ -1,163 +0,0 @@
import re
from collections import defaultdict
from dataclasses import (
dataclass,
field,
)
from typing import (
Dict,
Iterator,
List,
Tuple,
)
@dataclass
class ModuleValidator:
line_iterator: Iterator[str]
line_number: int = field(init=False, default=0)
collisions: Dict[str, List[Tuple[int, str]]] = field(
init=False, default_factory=lambda: defaultdict(list)
)
def add_import(self, imp: str, number: int, full_line: str):
"""
Adds an import to be tracked.
"""
self.collisions[imp].append((number, full_line))
def process_import(self, imp: str):
"""
Strips an import expression down to the name it binds.
"""
if " as " in imp:
imp = imp[imp.index(" as ") + 4 :]
imp = imp.strip()
assert " " not in imp, imp
return imp
def evaluate_multiline_import(self, line: str):
"""
Evaluates a multiline import from a starting line
"""
# Filter the first line and remove anything before the import statement.
full_line = line
line = line.split("import", 1)[1]
if "(" in line:
conditional = lambda line: ")" not in line
else:
conditional = lambda line: "\\" in line
# Remove open parenthesis if it exists.
if "(" in line:
line = line[line.index("(") + 1 :]
# Choose the conditional based on how multiline imports are formatted.
while conditional(line):
# Split the line by commas
imports = line.split(",")
for imp in imports:
# Add the import to the namespace
imp = self.process_import(imp)
if imp:
self.add_import(imp, self.line_number, full_line)
# Get the next line
full_line = line = next(self.line_iterator)
# Increment the line number
self.line_number += 1
# validate the last line
if ")" in line:
line = line[: line.index(")")]
imports = line.split(",")
for imp in imports:
imp = self.process_import(imp)
if imp:
self.add_import(imp, self.line_number, full_line)
def evaluate_import(self, line: str):
"""
Extracts an import from a line.
"""
whole_line = line
line = line[line.index("import") + 6 :]
values = line.split(",")
for v in values:
self.add_import(self.process_import(v), self.line_number, whole_line)
def next(self):
"""
Evaluate each line for names in the module.
"""
line = next(self.line_iterator)
# Skip lines with indentation or comments
if (
# Skip indents and whitespace.
line.startswith(" ")
or line == "\n"
or line.startswith("\t")
or
# Skip comments
line.startswith("#")
or
# Skip decorators
line.startswith("@")
):
self.line_number += 1
return
# Skip docstrings.
if line.startswith('"""') or line.startswith("'''"):
quote = line[0] * 3
line = line[3:]
while quote not in line:
line = next(self.line_iterator)
self.line_number += 1
return
# Evaluate Imports.
if line.startswith("from ") or line.startswith("import "):
if "(" in line or "\\" in line:
self.evaluate_multiline_import(line)
else:
self.evaluate_import(line)
# Evaluate Classes.
elif line.startswith("class "):
class_name = re.search(r"class (\w+)", line).group(1)
if class_name:
self.add_import(class_name, self.line_number, line)
# Evaluate Functions.
elif line.startswith("def "):
function_name = re.search(r"def (\w+)", line).group(1)
if function_name:
self.add_import(function_name, self.line_number, line)
# Evaluate direct assignments.
elif "=" in line:
assignment = re.search(r"(\w+)\s*=", line).group(1)
if assignment:
self.add_import(assignment, self.line_number, line)
self.line_number += 1
def validate(self) -> bool:
"""
Run Validation.
"""
try:
while True:
self.next()
except StopIteration:
pass
# Filter collisions for those with more than one value.
self.collisions = {k: v for k, v in self.collisions.items() if len(v) > 1}
# Return True if no collisions are found.
return not bool(self.collisions)
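A usage sketch for the validator above, feeding it module source line by line; the colliding definition is contrived:
source = (
    "from typing import List\n"
    "def List():  # collides with the imported name\n"
    "    pass\n"
)
validator = ModuleValidator(iter(source.splitlines(keepends=True)))
assert not validator.validate()
# collisions maps each clashing name to its (line number, line) pairs.
assert "List" in validator.collisions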

View File

@ -37,12 +37,6 @@ from .models import (
is_map,
is_oneof,
)
from .typing_compiler import (
DirectImportTypingCompiler,
NoTyping310TypingCompiler,
TypingCompiler,
TypingImportTypingCompiler,
)
def traverse(
@ -104,28 +98,6 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
output_package_name
].pydantic_dataclasses = True
# Gather any typing generation options.
typing_opts = [
opt[len("typing.") :] for opt in plugin_options if opt.startswith("typing.")
]
if len(typing_opts) > 1:
raise ValueError("Multiple typing options provided")
# Set the compiler type.
typing_opt = typing_opts[0] if typing_opts else "direct"
if typing_opt == "direct":
request_data.output_packages[
output_package_name
].typing_compiler = DirectImportTypingCompiler()
elif typing_opt == "root":
request_data.output_packages[
output_package_name
].typing_compiler = TypingImportTypingCompiler()
elif typing_opt == "310":
request_data.output_packages[
output_package_name
].typing_compiler = NoTyping310TypingCompiler()
# Read Messages and Enums
# We need to read Messages before Services in so that we can
# get the references to input/output messages for each service
@ -143,7 +115,7 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
for output_package_name, output_package in request_data.output_packages.items():
for proto_input_file in output_package.input_files:
for index, service in enumerate(proto_input_file.service):
read_protobuf_service(proto_input_file, service, index, output_package)
read_protobuf_service(service, index, output_package)
# Generate output files
output_paths: Set[pathlib.Path] = set()
@ -194,7 +166,6 @@ def _make_one_of_field_compiler(
parent=parent,
proto_obj=proto_obj,
path=path,
typing_compiler=output_package.typing_compiler,
)
@ -210,11 +181,7 @@ def read_protobuf_type(
return
# Process Message
message_data = MessageCompiler(
source_file=source_file,
parent=output_package,
proto_obj=item,
path=path,
typing_compiler=output_package.typing_compiler,
source_file=source_file, parent=output_package, proto_obj=item, path=path
)
for index, field in enumerate(item.field):
if is_map(field, item):
@ -223,7 +190,6 @@ def read_protobuf_type(
parent=message_data,
proto_obj=field,
path=path + [2, index],
typing_compiler=output_package.typing_compiler,
)
elif is_oneof(field):
_make_one_of_field_compiler(
@ -235,35 +201,21 @@ def read_protobuf_type(
parent=message_data,
proto_obj=field,
path=path + [2, index],
typing_compiler=output_package.typing_compiler,
)
elif isinstance(item, EnumDescriptorProto):
# Enum
EnumDefinitionCompiler(
source_file=source_file,
parent=output_package,
proto_obj=item,
path=path,
typing_compiler=output_package.typing_compiler,
source_file=source_file, parent=output_package, proto_obj=item, path=path
)
def read_protobuf_service(
source_file: FileDescriptorProto,
service: ServiceDescriptorProto,
index: int,
output_package: OutputTemplate,
service: ServiceDescriptorProto, index: int, output_package: OutputTemplate
) -> None:
service_data = ServiceCompiler(
source_file=source_file,
parent=output_package,
proto_obj=service,
path=[6, index],
parent=output_package, proto_obj=service, path=[6, index]
)
for j, method in enumerate(service.method):
ServiceMethodCompiler(
source_file=source_file,
parent=service_data,
proto_obj=method,
path=[6, index, 2, j],
parent=service_data, proto_obj=method, path=[6, index, 2, j]
)

View File

@ -1,173 +0,0 @@
import abc
from collections import defaultdict
from dataclasses import (
dataclass,
field,
)
from typing import (
Dict,
Iterator,
Optional,
Set,
)
class TypingCompiler(metaclass=abc.ABCMeta):
@abc.abstractmethod
def optional(self, type: str) -> str:
raise NotImplementedError()
@abc.abstractmethod
def list(self, type: str) -> str:
raise NotImplementedError()
@abc.abstractmethod
def dict(self, key: str, value: str) -> str:
raise NotImplementedError()
@abc.abstractmethod
def union(self, *types: str) -> str:
raise NotImplementedError()
@abc.abstractmethod
def iterable(self, type: str) -> str:
raise NotImplementedError()
@abc.abstractmethod
def async_iterable(self, type: str) -> str:
raise NotImplementedError()
@abc.abstractmethod
def async_iterator(self, type: str) -> str:
raise NotImplementedError()
@abc.abstractmethod
def imports(self) -> Dict[str, Optional[Set[str]]]:
"""
Returns a mapping from module name to either None (import the module
directly) or the set of names to import from it.
"""
raise NotImplementedError()
def import_lines(self) -> Iterator:
imports = self.imports()
for key, value in imports.items():
if value is None:
yield f"import {key}"
else:
yield f"from {key} import ("
for v in sorted(value):
yield f" {v},"
yield ")"
@dataclass
class DirectImportTypingCompiler(TypingCompiler):
_imports: Dict[str, Set[str]] = field(default_factory=lambda: defaultdict(set))
def optional(self, type: str) -> str:
self._imports["typing"].add("Optional")
return f"Optional[{type}]"
def list(self, type: str) -> str:
self._imports["typing"].add("List")
return f"List[{type}]"
def dict(self, key: str, value: str) -> str:
self._imports["typing"].add("Dict")
return f"Dict[{key}, {value}]"
def union(self, *types: str) -> str:
self._imports["typing"].add("Union")
return f"Union[{', '.join(types)}]"
def iterable(self, type: str) -> str:
self._imports["typing"].add("Iterable")
return f"Iterable[{type}]"
def async_iterable(self, type: str) -> str:
self._imports["typing"].add("AsyncIterable")
return f"AsyncIterable[{type}]"
def async_iterator(self, type: str) -> str:
self._imports["typing"].add("AsyncIterator")
return f"AsyncIterator[{type}]"
def imports(self) -> Dict[str, Optional[Set[str]]]:
return {k: v if v else None for k, v in self._imports.items()}
@dataclass
class TypingImportTypingCompiler(TypingCompiler):
_imported: bool = False
def optional(self, type: str) -> str:
self._imported = True
return f"typing.Optional[{type}]"
def list(self, type: str) -> str:
self._imported = True
return f"typing.List[{type}]"
def dict(self, key: str, value: str) -> str:
self._imported = True
return f"typing.Dict[{key}, {value}]"
def union(self, *types: str) -> str:
self._imported = True
return f"typing.Union[{', '.join(types)}]"
def iterable(self, type: str) -> str:
self._imported = True
return f"typing.Iterable[{type}]"
def async_iterable(self, type: str) -> str:
self._imported = True
return f"typing.AsyncIterable[{type}]"
def async_iterator(self, type: str) -> str:
self._imported = True
return f"typing.AsyncIterator[{type}]"
def imports(self) -> Dict[str, Optional[Set[str]]]:
if self._imported:
return {"typing": None}
return {}
@dataclass
class NoTyping310TypingCompiler(TypingCompiler):
_imports: Dict[str, Set[str]] = field(default_factory=lambda: defaultdict(set))
@staticmethod
def _fmt(type: str) -> str: # for now this is necessary till 3.14
if type.startswith('"'):
return type[1:-1]
return type
def optional(self, type: str) -> str:
return f'"{self._fmt(type)} | None"'
def list(self, type: str) -> str:
return f'"list[{self._fmt(type)}]"'
def dict(self, key: str, value: str) -> str:
return f'"dict[{key}, {self._fmt(value)}]"'
def union(self, *types: str) -> str:
return f'"{" | ".join(map(self._fmt, types))}"'
def iterable(self, type: str) -> str:
self._imports["collections.abc"].add("Iterable")
return f'"Iterable[{type}]"'
def async_iterable(self, type: str) -> str:
self._imports["collections.abc"].add("AsyncIterable")
return f'"AsyncIterable[{type}]"'
def async_iterator(self, type: str) -> str:
self._imports["collections.abc"].add("AsyncIterator")
return f'"AsyncIterator[{type}]"'
def imports(self) -> Dict[str, Optional[Set[str]]]:
return {k: v if v else None for k, v in self._imports.items()}
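To make the three strategies concrete, here is the same Optional annotation rendered by each compiler (a sketch; the outputs follow directly from the methods above):
direct = DirectImportTypingCompiler()
assert direct.optional("int") == "Optional[int]"
assert direct.imports() == {"typing": {"Optional"}}
full = TypingImportTypingCompiler()
assert full.optional("int") == "typing.Optional[int]"
assert full.imports() == {"typing": None}
modern = NoTyping310TypingCompiler()
assert modern.optional("int") == '"int | None"'
assert modern.imports() == {}  # PEP 604 syntax needs no typing import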

View File

@ -1,57 +0,0 @@
# Generated by the protocol buffer compiler. DO NOT EDIT!
# sources: {{ ', '.join(output_file.input_filenames) }}
# plugin: python-betterproto
# This file has been @generated
__all__ = (
{%- for enum in output_file.enums -%}
"{{ enum.py_name }}",
{%- endfor -%}
{%- for message in output_file.messages -%}
"{{ message.py_name }}",
{%- endfor -%}
{%- for service in output_file.services -%}
"{{ service.py_name }}Stub",
"{{ service.py_name }}Base",
{%- endfor -%}
)
{% for i in output_file.python_module_imports|sort %}
import {{ i }}
{% endfor %}
{% if output_file.pydantic_dataclasses %}
from pydantic.dataclasses import dataclass
{%- else -%}
from dataclasses import dataclass
{% endif %}
{% if output_file.datetime_imports %}
from datetime import {% for i in output_file.datetime_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
{% endif%}
{% set typing_imports = output_file.typing_compiler.imports() %}
{% if typing_imports %}
{% for line in output_file.typing_compiler.import_lines() %}
{{ line }}
{% endfor %}
{% endif %}
{% if output_file.pydantic_imports %}
from pydantic import {% for i in output_file.pydantic_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
{% endif %}
import betterproto
{% if output_file.services %}
from betterproto.grpc.grpclib_server import ServiceBase
import grpclib
{% endif %}
{% if output_file.imports_type_checking_only %}
from typing import TYPE_CHECKING
if TYPE_CHECKING:
{% for i in output_file.imports_type_checking_only|sort %} {{ i }}
{% endfor %}
{% endif %}

View File

@ -1,3 +1,53 @@
# Generated by the protocol buffer compiler. DO NOT EDIT!
# sources: {{ ', '.join(output_file.input_filenames) }}
# plugin: python-betterproto
# This file has been @generated
{% for i in output_file.python_module_imports|sort %}
import {{ i }}
{% endfor %}
{% if output_file.pydantic_dataclasses %}
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from dataclasses import dataclass
else:
from pydantic.dataclasses import dataclass
{%- else -%}
from dataclasses import dataclass
{% endif %}
{% if output_file.datetime_imports %}
from datetime import {% for i in output_file.datetime_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
{% endif%}
{% if output_file.typing_imports %}
from typing import {% for i in output_file.typing_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
{% endif %}
{% if output_file.pydantic_imports %}
from pydantic import {% for i in output_file.pydantic_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
{% endif %}
import betterproto
{% if output_file.services %}
from betterproto.grpc.grpclib_server import ServiceBase
import grpclib
{% endif %}
{% for i in output_file.imports|sort %}
{{ i }}
{% endfor %}
{% if output_file.imports_type_checking_only %}
from typing import TYPE_CHECKING
if TYPE_CHECKING:
{% for i in output_file.imports_type_checking_only|sort %} {{ i }}
{% endfor %}
{% endif %}
{% if output_file.enums %}{% for enum in output_file.enums %}
class {{ enum.py_name }}(betterproto.Enum):
{% if enum.comment %}
@ -12,22 +62,11 @@ class {{ enum.py_name }}(betterproto.Enum):
{% endif %}
{% endfor %}
{% if output_file.pydantic_dataclasses %}
@classmethod
def __get_pydantic_core_schema__(cls, _source_type, _handler):
from pydantic_core import core_schema
return core_schema.int_schema(ge=0)
{% endif %}
{% endfor %}
{% endif %}
{% for message in output_file.messages %}
{% if output_file.pydantic_dataclasses %}
@dataclass(eq=False, repr=False, config={"extra": "forbid"})
{% else %}
@dataclass(eq=False, repr=False)
{% endif %}
class {{ message.py_name }}(betterproto.Message):
{% if message.comment %}
{{ message.comment }}
@ -57,7 +96,7 @@ class {{ message.py_name }}(betterproto.Message):
{% endif %}
{% if output_file.pydantic_dataclasses and message.has_oneof_fields %}
@model_validator(mode='after')
@root_validator()
def check_oneof(cls, values):
return cls._validate_field_groups(values)
{% endif %}
@ -74,24 +113,20 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
{% for method in service.methods %}
async def {{ method.py_name }}(self
{%- if not method.client_streaming -%}
, {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}"
{%- if method.py_input_message -%}, {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}"{%- endif -%}
{%- else -%}
{# Client streaming: need a request iterator instead #}
, {{ method.py_input_message_param }}_iterator: "{{ output_file.typing_compiler.union(output_file.typing_compiler.async_iterable(method.py_input_message_type), output_file.typing_compiler.iterable(method.py_input_message_type)) }}"
, {{ method.py_input_message_param }}_iterator: Union[AsyncIterable["{{ method.py_input_message_type }}"], Iterable["{{ method.py_input_message_type }}"]]
{%- endif -%}
,
*
, timeout: {{ output_file.typing_compiler.optional("float") }} = None
, deadline: {{ output_file.typing_compiler.optional('"Deadline"') }} = None
, metadata: {{ output_file.typing_compiler.optional('"MetadataLike"') }} = None
) -> "{% if method.server_streaming %}{{ output_file.typing_compiler.async_iterator(method.py_output_message_type ) }}{% else %}{{ method.py_output_message_type }}{% endif %}":
, timeout: Optional[float] = None
, deadline: Optional["Deadline"] = None
, metadata: Optional["MetadataLike"] = None
) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}:
{% if method.comment %}
{{ method.comment }}
{% endif %}
{% if method.proto_obj.options.deprecated %}
warnings.warn("{{ service.py_name }}.{{ method.py_name }} is deprecated", DeprecationWarning)
{% endif %}
{% if method.server_streaming %}
{% if method.client_streaming %}
@ -143,10 +178,6 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
{% endfor %}
{% endfor %}
{% for i in output_file.imports_end %}
{{ i }}
{% endfor %}
{% for service in output_file.services %}
class {{ service.py_name }}Base(ServiceBase):
{% if service.comment %}
@ -157,12 +188,12 @@ class {{ service.py_name }}Base(ServiceBase):
{% for method in service.methods %}
async def {{ method.py_name }}(self
{%- if not method.client_streaming -%}
, {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}"
{%- if method.py_input_message -%}, {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}"{%- endif -%}
{%- else -%}
{# Client streaming: need a request iterator instead #}
, {{ method.py_input_message_param }}_iterator: {{ output_file.typing_compiler.async_iterator(method.py_input_message_type) }}
, {{ method.py_input_message_param }}_iterator: AsyncIterator["{{ method.py_input_message_type }}"]
{%- endif -%}
) -> {% if method.server_streaming %}{{ output_file.typing_compiler.async_iterator(method.py_output_message_type) }}{% else %}"{{ method.py_output_message_type }}"{% endif %}:
) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}:
{% if method.comment %}
{{ method.comment }}
@ -194,7 +225,7 @@ class {{ service.py_name }}Base(ServiceBase):
{% endfor %}
def __mapping__(self) -> {{ output_file.typing_compiler.dict("str", "grpclib.const.Handler") }}:
def __mapping__(self) -> Dict[str, grpclib.const.Handler]:
return {
{% for method in service.methods %}
"{{ method.route }}": grpclib.const.Handler(
@ -215,3 +246,11 @@ class {{ service.py_name }}Base(ServiceBase):
}
{% endfor %}
{% if output_file.pydantic_dataclasses %}
{% for message in output_file.messages %}
{% if message.has_message_field %}
{{ message.py_name }}.__pydantic_model__.update_forward_refs() # type: ignore
{% endif %}
{% endfor %}
{% endif %}

View File

@ -1,56 +0,0 @@
from __future__ import annotations
from typing import (
Any,
Callable,
Generic,
Optional,
Type,
TypeVar,
)
from typing_extensions import (
Concatenate,
ParamSpec,
Self,
)
SelfT = TypeVar("SelfT")
P = ParamSpec("P")
HybridT = TypeVar("HybridT", covariant=True)
class hybridmethod(Generic[SelfT, P, HybridT]):
def __init__(
self,
func: Callable[
Concatenate[type[SelfT], P], HybridT
], # Must be the classmethod version
):
self.cls_func = func
self.__doc__ = func.__doc__
def instancemethod(self, func: Callable[Concatenate[SelfT, P], HybridT]) -> Self:
self.instance_func = func
return self
def __get__(
self, instance: Optional[SelfT], owner: Type[SelfT]
) -> Callable[P, HybridT]:
if instance is None or self.instance_func is None:
# either bound to the class, or no instance method available
return self.cls_func.__get__(owner, None)
return self.instance_func.__get__(instance, owner)
T_co = TypeVar("T_co")
TT_co = TypeVar("TT_co", bound="type[Any]")
class classproperty(Generic[TT_co, T_co]):
def __init__(self, func: Callable[[TT_co], T_co]):
self.__func__ = func
def __get__(self, instance: Any, type: TT_co) -> T_co:
return self.__func__(type)
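A small usage sketch for the two descriptors above (the Counter class is hypothetical):
class Counter:
    @hybridmethod
    def bump(cls) -> str:
        return "class"  # class-level behaviour
    @bump.instancemethod
    def bump(self) -> str:
        return "instance"  # instance-level behaviour
    @classproperty
    def kind(cls) -> str:
        return cls.__name__
assert Counter.bump() == "class"
assert Counter().bump() == "instance"
assert Counter.kind == "Counter"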

View File

@ -4,6 +4,17 @@ import sys
import pytest
def pytest_addoption(parser):
parser.addoption(
"--repeat", type=int, default=1, help="repeat the operation multiple times"
)
@pytest.fixture(scope="session")
def repeat(request):
return request.config.getoption("repeat")
@pytest.fixture
def reset_sys_path():
original = copy.deepcopy(sys.path)

View File

@ -108,7 +108,6 @@ async def generate_test_case_output(
print(
f"\033[31;1;4mFailed to generate reference output for {test_case_name!r}\033[0m"
)
print(ref_err.decode())
if verbose:
if ref_out:
@ -127,7 +126,6 @@ async def generate_test_case_output(
print(
f"\033[31;1;4mFailed to generate plugin output for {test_case_name!r}\033[0m"
)
print(plg_err.decode())
if verbose:
if plg_out:
@ -148,7 +146,6 @@ async def generate_test_case_output(
print(
f"\033[31;1;4mFailed to generate plugin (pydantic compatible) output for {test_case_name!r}\033[0m"
)
print(plg_err_pyd.decode())
if verbose:
if plg_out_pyd:

View File

@ -1,4 +1,5 @@
import asyncio
import sys
import uuid
import grpclib
@ -26,12 +27,12 @@ async def _test_client(client: ThingServiceClient, name="clean room", **kwargs):
def _assert_request_meta_received(deadline, metadata):
def server_side_test(stream):
assert stream.deadline._timestamp == pytest.approx(deadline._timestamp, 1), (
"The provided deadline should be received serverside"
)
assert stream.metadata["authorization"] == metadata["authorization"], (
"The provided authorization metadata should be received serverside"
)
assert stream.deadline._timestamp == pytest.approx(
deadline._timestamp, 1
), "The provided deadline should be received serverside"
assert (
stream.metadata["authorization"] == metadata["authorization"]
), "The provided authorization metadata should be received serverside"
return server_side_test
@ -90,6 +91,9 @@ async def test_trailer_only_error_stream_unary(
@pytest.mark.asyncio
@pytest.mark.skipif(
sys.version_info < (3, 8), reason="async mock spy only works for python3.8+"
)
async def test_service_call_mutable_defaults(mocker):
async with ChannelFor([ThingService()]) as channel:
client = ThingServiceClient(channel)
@ -265,30 +269,6 @@ async def test_async_gen_for_stream_stream_request():
else:
# No more things to send make sure channel is closed
request_chan.close()
assert response_index == len(expected_things), (
"Didn't receive all expected responses"
)
@pytest.mark.asyncio
async def test_stream_unary_with_empty_iterable():
things = [] # empty
async with ChannelFor([ThingService()]) as channel:
client = ThingServiceClient(channel)
requests = [DoThingRequest(name) for name in things]
response = await client.do_many_things(requests)
assert len(response.names) == 0
@pytest.mark.asyncio
async def test_stream_stream_with_empty_iterable():
things = [] # empty
async with ChannelFor([ThingService()]) as channel:
client = ThingServiceClient(channel)
requests = [GetThingRequest(name) for name in things]
responses = [
response async for response in client.get_different_things(requests)
]
assert len(responses) == 0
assert response_index == len(
expected_things
), "Didn't receive all expected responses"

View File

@ -27,7 +27,7 @@ class ThingService:
async def do_many_things(
self, stream: "grpclib.server.Stream[DoThingRequest, DoThingResponse]"
):
thing_names = [request.name async for request in stream]
thing_names = [request.name for request in stream]
if self.test_hook is not None:
self.test_hook(stream)
await stream.send_message(DoThingResponse(thing_names))

View File

@ -10,15 +10,10 @@ def test_value():
def test_pydantic_no_value():
message = TestPyd()
assert not message.value, "Boolean is False by default"
with pytest.raises(ValueError):
TestPyd()
def test_pydantic_value():
message = TestPyd(value=False)
message = Test(value=False)
assert not message.value
def test_pydantic_bad_value():
with pytest.raises(ValueError):
TestPyd(value=123)

View File

@ -4,20 +4,20 @@ from tests.output_betterproto.casing import Test
def test_message_attributes():
message = Test()
assert hasattr(message, "snake_case_message"), (
"snake_case field name is same in python"
)
assert hasattr(
message, "snake_case_message"
), "snake_case field name is same in python"
assert hasattr(message, "camel_case"), "CamelCase field is snake_case in python"
assert hasattr(message, "uppercase"), "UPPERCASE field is lowercase in python"
def test_message_casing():
assert hasattr(casing, "SnakeCaseMessage"), (
"snake_case Message name is converted to CamelCase in python"
)
assert hasattr(
casing, "SnakeCaseMessage"
), "snake_case Message name is converted to CamelCase in python"
def test_enum_casing():
assert hasattr(casing, "MyEnum"), (
"snake_case Enum name is converted to CamelCase in python"
)
assert hasattr(
casing, "MyEnum"
), "snake_case Enum name is converted to CamelCase in python"

View File

@ -2,13 +2,13 @@ import tests.output_betterproto.casing_inner_class as casing_inner_class
def test_message_casing_inner_class_name():
assert hasattr(casing_inner_class, "TestInnerClass"), (
"Inline defined Message is correctly converted to CamelCase"
)
assert hasattr(
casing_inner_class, "TestInnerClass"
), "Inline defined Message is correctly converted to CamelCase"
def test_message_casing_inner_class_attributes():
message = casing_inner_class.Test()
assert hasattr(message.inner, "old_exp"), (
"Inline defined Message attribute is snake_case"
)
assert hasattr(
message.inner, "old_exp"
), "Inline defined Message attribute is snake_case"

View File

@ -3,12 +3,12 @@ from tests.output_betterproto.casing_message_field_uppercase import Test
def test_message_casing():
message = Test()
assert hasattr(message, "uppercase"), (
"UPPERCASE attribute is converted to 'uppercase' in python"
)
assert hasattr(message, "uppercase_v2"), (
"UPPERCASE_V2 attribute is converted to 'uppercase_v2' in python"
)
assert hasattr(message, "upper_camel_case"), (
"UPPER_CAMEL_CASE attribute is converted to upper_camel_case in python"
)
assert hasattr(
message, "uppercase"
), "UPPERCASE attribute is converted to 'uppercase' in python"
assert hasattr(
message, "uppercase_v2"
), "UPPERCASE_V2 attribute is converted to 'uppercase_v2' in python"
assert hasattr(
message, "upper_camel_case"
), "UPPER_CAMEL_CASE attribute is converted to upper_camel_case in python"

View File

@ -12,10 +12,3 @@ message Message {
option deprecated = true;
string value = 1;
}
message Empty {}
service TestService {
rpc func(Empty) returns (Empty);
rpc deprecated_func(Empty) returns (Empty) { option deprecated = true; };
}

View File

@ -1,44 +0,0 @@
syntax = "proto3";
package documentation;
// Documentation of message 1
// other line 1
// Documentation of message 2
// other line 2
message Test { // Documentation of message 3
// Documentation of field 1
// other line 1
// Documentation of field 2
// other line 2
uint32 x = 1; // Documentation of field 3
}
// Documentation of enum 1
// other line 1
// Documentation of enum 2
// other line 2
enum Enum { // Documentation of enum 3
// Documentation of variant 1
// other line 1
// Documentation of variant 2
// other line 2
Enum_Variant = 0; // Documentation of variant 3
}
// Documentation of service 1
// other line 1
// Documentation of service 2
// other line 2
service Service { // Documentation of service 3
// Documentation of method 1
// other line 1
// Documentation of method 2
// other line 2
rpc get(Test) returns (Test); // Documentation of method 3
}

View File

@ -15,11 +15,3 @@ enum Choice {
FOUR = 4;
THREE = 3;
}
// A "C" like enum with the enum name prefixed onto members, these should be stripped
enum ArithmeticOperator {
ARITHMETIC_OPERATOR_NONE = 0;
ARITHMETIC_OPERATOR_PLUS = 1;
ARITHMETIC_OPERATOR_MINUS = 2;
ARITHMETIC_OPERATOR_0_PREFIXED = 3;
}

View File

@ -1,5 +1,4 @@
from tests.output_betterproto.enum import (
ArithmeticOperator,
Choice,
Test,
)
@ -27,9 +26,9 @@ def test_enum_is_comparable_with_int():
def test_enum_to_dict():
assert "choice" not in Test(choice=Choice.ZERO).to_dict(), (
"Default enum value is not serialized"
)
assert (
"choice" not in Test(choice=Choice.ZERO).to_dict()
), "Default enum value is not serialized"
assert (
Test(choice=Choice.ZERO).to_dict(include_default_values=True)["choice"]
== "ZERO"
@ -83,32 +82,3 @@ def test_repeated_enum_with_non_list_iterables_to_dict():
yield Choice.THREE
assert Test(choices=enum_generator()).to_dict()["choices"] == ["ONE", "THREE"]
def test_enum_mapped_on_parse():
# test default value
b = Test().parse(bytes(Test()))
assert b.choice.name == Choice.ZERO.name
assert b.choices == []
# test non default value
a = Test().parse(bytes(Test(choice=Choice.ONE)))
assert a.choice.name == Choice.ONE.name
assert b.choices == []
# test repeated
c = Test().parse(bytes(Test(choices=[Choice.THREE, Choice.FOUR])))
assert c.choices[0].name == Choice.THREE.name
assert c.choices[1].name == Choice.FOUR.name
# bonus: defaults after empty init are also mapped
assert Test().choice.name == Choice.ZERO.name
def test_renamed_enum_members():
assert set(ArithmeticOperator.__members__) == {
"NONE",
"PLUS",
"MINUS",
"_0_PREFIXED",
}

View File

@ -1,6 +1,5 @@
syntax = "proto3";
import "google/protobuf/timestamp.proto";
package google_impl_behavior_equivalence;
message Foo { int64 bar = 1; }
@ -13,10 +12,6 @@ message Test {
}
}
message Spam {
google.protobuf.Timestamp ts = 1;
}
message Request { Empty foo = 1; }
message Empty {}
message Empty {}

View File

@ -1,25 +1,17 @@
from datetime import (
datetime,
timezone,
)
import pytest
from google.protobuf import json_format
from google.protobuf.timestamp_pb2 import Timestamp
import betterproto
from tests.output_betterproto.google_impl_behavior_equivalence import (
Empty,
Foo,
Request,
Spam,
Test,
)
from tests.output_reference.google_impl_behavior_equivalence.google_impl_behavior_equivalence_pb2 import (
Empty as ReferenceEmpty,
Foo as ReferenceFoo,
Request as ReferenceRequest,
Spam as ReferenceSpam,
Test as ReferenceTest,
)
@ -67,19 +59,6 @@ def test_bytes_are_the_same_for_oneof():
assert isinstance(message_reference2.foo, ReferenceFoo)
@pytest.mark.parametrize("dt", (datetime.min.replace(tzinfo=timezone.utc),))
def test_datetime_clamping(dt): # see #407
ts = Timestamp()
ts.FromDatetime(dt)
assert bytes(Spam(dt)) == ReferenceSpam(ts=ts).SerializeToString()
message_bytes = bytes(Spam(dt))
assert (
Spam().parse(message_bytes).ts.timestamp()
== ReferenceSpam.FromString(message_bytes).ts.seconds
)
def test_empty_message_field():
message = Request()
reference_message = ReferenceRequest()

View File

@ -26,5 +26,5 @@ import "other.proto";
// (root: Test & RootPackageMessage) <-------> (other: OtherPackageMessage)
message Test {
RootPackageMessage message = 1;
other.OtherPackageMessage other_value = 2;
other.OtherPackageMessage other = 2;
}

View File

@ -1,7 +0,0 @@
syntax = "proto3";
package invalid_field;
message Test {
int32 x = 1;
}

View File

@ -1,17 +0,0 @@
import pytest
def test_invalid_field():
from tests.output_betterproto.invalid_field import Test
with pytest.raises(TypeError):
Test(unknown_field=12)
def test_invalid_field_pydantic():
from pydantic import ValidationError
from tests.output_betterproto_pydantic.invalid_field import Test
with pytest.raises(ValidationError):
Test(unknown_field=12)

View File

@ -2,10 +2,6 @@ syntax = "proto3";
package oneof;
message MixedDrink {
int32 shots = 1;
}
message Test {
oneof foo {
int32 pitied = 1;
@ -17,7 +13,6 @@ message Test {
oneof bar {
int32 drinks = 11;
string bar_name = 12;
MixedDrink mixed_drink = 13;
}
}

View File

@ -1,10 +1,5 @@
import pytest
import betterproto
from tests.output_betterproto.oneof import (
MixedDrink,
Test,
)
from tests.output_betterproto.oneof import Test
from tests.output_betterproto_pydantic.oneof import Test as TestPyd
from tests.util import get_test_case_json_data
@ -24,20 +19,3 @@ def test_which_name():
def test_which_count_pyd():
message = TestPyd(pitier="Mr. T", just_a_regular_field=2, bar_name="a_bar")
assert betterproto.which_one_of(message, "foo") == ("pitier", "Mr. T")
def test_oneof_constructor_assign():
message = Test(mixed_drink=MixedDrink(shots=42))
field, value = betterproto.which_one_of(message, "bar")
assert field == "mixed_drink"
assert value.shots == 42
# Issue #305:
@pytest.mark.xfail
def test_oneof_nested_assign():
message = Test()
message.mixed_drink.shots = 42
field, value = betterproto.which_one_of(message, "bar")
assert field == "mixed_drink"
assert value.shots == 42

View File

@ -41,8 +41,3 @@ def test_null_fields_json():
"test8": None,
"test9": None,
}
def test_unset_access(): # see #523
assert Test().test1 is None
assert Test(test1=None).test1 is None

View File

@ -1,2 +0,0 @@

View File

@ -1,38 +0,0 @@
### Output ###
target/
!.mvn/wrapper/maven-wrapper.jar
!**/src/main/**/target/
!**/src/test/**/target/
dependency-reduced-pom.xml
MANIFEST.MF
### IntelliJ IDEA ###
.idea/
*.iws
*.iml
*.ipr
### Eclipse ###
.apt_generated
.classpath
.factorypath
.project
.settings
.springBeans
.sts4-cache
### NetBeans ###
/nbproject/private/
/nbbuild/
/dist/
/nbdist/
/.nb-gradle/
build/
!**/src/main/**/build/
!**/src/test/**/build/
### VS Code ###
.vscode/
### Mac OS ###
.DS_Store

View File

@ -1,94 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>betterproto</groupId>
<artifactId>compatibility-test</artifactId>
<version>1.0-SNAPSHOT</version>
<packaging>jar</packaging>
<properties>
<maven.compiler.source>11</maven.compiler.source>
<maven.compiler.target>11</maven.compiler.target>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<protobuf.version>3.23.4</protobuf.version>
</properties>
<dependencies>
<dependency>
<groupId>com.google.protobuf</groupId>
<artifactId>protobuf-java</artifactId>
<version>${protobuf.version}</version>
</dependency>
</dependencies>
<build>
<extensions>
<extension>
<groupId>kr.motd.maven</groupId>
<artifactId>os-maven-plugin</artifactId>
<version>1.7.1</version>
</extension>
</extensions>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-shade-plugin</artifactId>
<version>3.5.0</version>
<executions>
<execution>
<phase>package</phase>
<goals>
<goal>shade</goal>
</goals>
<configuration>
<transformers>
<transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
<mainClass>betterproto.CompatibilityTest</mainClass>
</transformer>
</transformers>
</configuration>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-jar-plugin</artifactId>
<version>3.3.0</version>
<configuration>
<archive>
<manifest>
<addClasspath>true</addClasspath>
<mainClass>betterproto.CompatibilityTest</mainClass>
</manifest>
</archive>
</configuration>
</plugin>
<plugin>
<groupId>org.xolstice.maven.plugins</groupId>
<artifactId>protobuf-maven-plugin</artifactId>
<version>0.6.1</version>
<executions>
<execution>
<goals>
<goal>compile</goal>
</goals>
</execution>
</executions>
<configuration>
<protocArtifact>
com.google.protobuf:protoc:${protobuf.version}:exe:${os.detected.classifier}
</protocArtifact>
</configuration>
</plugin>
</plugins>
<finalName>${project.artifactId}</finalName>
</build>
</project>

View File

@ -1,41 +0,0 @@
package betterproto;
import java.io.IOException;
public class CompatibilityTest {
public static void main(String[] args) throws IOException {
if (args.length < 2)
throw new RuntimeException("Attempted to run without the required arguments.");
else if (args.length > 2)
throw new RuntimeException(
"Attempted to run with more than the expected number of arguments (>1).");
Tests tests = new Tests(args[1]);
switch (args[0]) {
case "single_varint":
tests.testSingleVarint();
break;
case "multiple_varints":
tests.testMultipleVarints();
break;
case "single_message":
tests.testSingleMessage();
break;
case "multiple_messages":
tests.testMultipleMessages();
break;
case "infinite_messages":
tests.testInfiniteMessages();
break;
default:
throw new RuntimeException(
"Attempted to run with unknown argument '" + args[0] + "'.");
}
}
}

View File

@ -1,115 +0,0 @@
package betterproto;
import betterproto.nested.NestedOuterClass;
import betterproto.oneof.Oneof;
import com.google.protobuf.CodedInputStream;
import com.google.protobuf.CodedOutputStream;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
public class Tests {
String path;
public Tests(String path) {
this.path = path;
}
public void testSingleVarint() throws IOException {
// Read in the Python-generated single varint file
FileInputStream inputStream = new FileInputStream(path + "/py_single_varint.out");
CodedInputStream codedInput = CodedInputStream.newInstance(inputStream);
int value = codedInput.readUInt32();
inputStream.close();
// Write the value back to a file
FileOutputStream outputStream = new FileOutputStream(path + "/java_single_varint.out");
CodedOutputStream codedOutput = CodedOutputStream.newInstance(outputStream);
codedOutput.writeUInt32NoTag(value);
codedOutput.flush();
outputStream.close();
}
public void testMultipleVarints() throws IOException {
// Read in the Python-generated multiple varints file
FileInputStream inputStream = new FileInputStream(path + "/py_multiple_varints.out");
CodedInputStream codedInput = CodedInputStream.newInstance(inputStream);
int value1 = codedInput.readUInt32();
int value2 = codedInput.readUInt32();
long value3 = codedInput.readUInt64();
inputStream.close();
// Write the values back to a file
FileOutputStream outputStream = new FileOutputStream(path + "/java_multiple_varints.out");
CodedOutputStream codedOutput = CodedOutputStream.newInstance(outputStream);
codedOutput.writeUInt32NoTag(value1);
codedOutput.writeUInt64NoTag(value2);
codedOutput.writeUInt64NoTag(value3);
codedOutput.flush();
outputStream.close();
}
public void testSingleMessage() throws IOException {
// Read in the Python-generated single message file
FileInputStream inputStream = new FileInputStream(path + "/py_single_message.out");
CodedInputStream codedInput = CodedInputStream.newInstance(inputStream);
Oneof.Test message = Oneof.Test.parseFrom(codedInput);
inputStream.close();
// Write the message back to a file
FileOutputStream outputStream = new FileOutputStream(path + "/java_single_message.out");
CodedOutputStream codedOutput = CodedOutputStream.newInstance(outputStream);
message.writeTo(codedOutput);
codedOutput.flush();
outputStream.close();
}
public void testMultipleMessages() throws IOException {
// Read in the Python-generated multi-message file
FileInputStream inputStream = new FileInputStream(path + "/py_multiple_messages.out");
Oneof.Test oneof = Oneof.Test.parseDelimitedFrom(inputStream);
NestedOuterClass.Test nested = NestedOuterClass.Test.parseDelimitedFrom(inputStream);
inputStream.close();
// Write the messages back to a file
FileOutputStream outputStream = new FileOutputStream(path + "/java_multiple_messages.out");
oneof.writeDelimitedTo(outputStream);
nested.writeDelimitedTo(outputStream);
outputStream.flush();
outputStream.close();
}
public void testInfiniteMessages() throws IOException {
// Read in as many messages as are present in the Python-generated file and write them back
FileInputStream inputStream = new FileInputStream(path + "/py_infinite_messages.out");
FileOutputStream outputStream = new FileOutputStream(path + "/java_infinite_messages.out");
Oneof.Test current = Oneof.Test.parseDelimitedFrom(inputStream);
while (current != null) {
current.writeDelimitedTo(outputStream);
current = Oneof.Test.parseDelimitedFrom(inputStream);
}
inputStream.close();
outputStream.flush();
outputStream.close();
}
}

View File

@ -1,27 +0,0 @@
syntax = "proto3";
package nested;
option java_package = "betterproto.nested";
// A test message with a nested message inside of it.
message Test {
// This is the nested type.
message Nested {
// Stores a simple counter.
int32 count = 1;
}
// This is the nested enum.
enum Msg {
NONE = 0;
THIS = 1;
}
Nested nested = 1;
Sibling sibling = 2;
Sibling sibling2 = 3;
Msg msg = 4;
}
message Sibling {
int32 foo = 1;
}

View File

@ -1,19 +0,0 @@
syntax = "proto3";
package oneof;
option java_package = "betterproto.oneof";
message Test {
oneof foo {
int32 pitied = 1;
string pitier = 2;
}
int32 just_a_regular_field = 3;
oneof bar {
int32 drinks = 11;
string bar_name = 12;
}
}

View File

@ -1,19 +0,0 @@
def test_all_definition():
"""
Check that a compiled module defines __all__ with the right value.
These modules have been chosen since they contain messages, services and enums.
"""
import tests.output_betterproto.enum as enum
import tests.output_betterproto.service as service
assert service.__all__ == (
"ThingType",
"DoThingRequest",
"DoThingResponse",
"GetThingRequest",
"GetThingResponse",
"TestStub",
"TestBase",
)
assert enum.__all__ == ("Choice", "ArithmeticOperator", "Test")

View File

@ -2,12 +2,9 @@ import warnings
import pytest
from tests.mocks import MockChannel
from tests.output_betterproto.deprecated import (
Empty,
Message,
Test,
TestServiceStub,
)
@ -35,27 +32,14 @@ def test_message_with_deprecated_field(message):
def test_message_with_deprecated_field_not_set(message):
with warnings.catch_warnings():
warnings.simplefilter("error")
with pytest.warns(None) as record:
Test(value=10)
assert not record
def test_message_with_deprecated_field_not_set_default(message):
with warnings.catch_warnings():
warnings.simplefilter("error")
with pytest.warns(None) as record:
_ = Test(value=10).message
@pytest.mark.asyncio
async def test_service_with_deprecated_method():
stub = TestServiceStub(MockChannel([Empty(), Empty()]))
with pytest.warns(DeprecationWarning) as record:
await stub.deprecated_func(Empty())
assert len(record) == 1
assert str(record[0].message) == f"TestService.deprecated_func is deprecated"
with warnings.catch_warnings():
warnings.simplefilter("error")
await stub.func(Empty())
assert not record

View File

@ -1,37 +0,0 @@
import ast
import inspect
def check(generated_doc: str, type: str) -> None:
assert f"Documentation of {type} 1" in generated_doc
assert "other line 1" in generated_doc
assert f"Documentation of {type} 2" in generated_doc
assert "other line 2" in generated_doc
assert f"Documentation of {type} 3" in generated_doc
def test_documentation() -> None:
from .output_betterproto.documentation import (
Enum,
ServiceBase,
ServiceStub,
Test,
)
check(Test.__doc__, "message")
source = inspect.getsource(Test)
tree = ast.parse(source)
check(tree.body[0].body[2].value.value, "field")
check(Enum.__doc__, "enum")
source = inspect.getsource(Enum)
tree = ast.parse(source)
check(tree.body[0].body[2].value.value, "variant")
check(ServiceBase.__doc__, "service")
check(ServiceBase.get.__doc__, "method")
check(ServiceStub.__doc__, "service")
check(ServiceStub.get.__doc__, "method")

View File

@ -1,79 +0,0 @@
from typing import (
Optional,
Tuple,
)
import pytest
import betterproto
class Colour(betterproto.Enum):
RED = 1
GREEN = 2
BLUE = 3
PURPLE = Colour.__new__(Colour, name=None, value=4)
@pytest.mark.parametrize(
"member, str_value",
[
(Colour.RED, "RED"),
(Colour.GREEN, "GREEN"),
(Colour.BLUE, "BLUE"),
],
)
def test_str(member: Colour, str_value: str) -> None:
assert str(member) == str_value
@pytest.mark.parametrize(
"member, repr_value",
[
(Colour.RED, "Colour.RED"),
(Colour.GREEN, "Colour.GREEN"),
(Colour.BLUE, "Colour.BLUE"),
],
)
def test_repr(member: Colour, repr_value: str) -> None:
assert repr(member) == repr_value
@pytest.mark.parametrize(
"member, values",
[
(Colour.RED, ("RED", 1)),
(Colour.GREEN, ("GREEN", 2)),
(Colour.BLUE, ("BLUE", 3)),
(PURPLE, (None, 4)),
],
)
def test_name_values(member: Colour, values: Tuple[Optional[str], int]) -> None:
assert (member.name, member.value) == values
@pytest.mark.parametrize(
"member, input_str",
[
(Colour.RED, "RED"),
(Colour.GREEN, "GREEN"),
(Colour.BLUE, "BLUE"),
],
)
def test_from_string(member: Colour, input_str: str) -> None:
assert Colour.from_string(input_str) == member
@pytest.mark.parametrize(
"member, input_int",
[
(Colour.RED, 1),
(Colour.GREEN, 2),
(Colour.BLUE, 3),
(PURPLE, 4),
],
)
def test_try_value(member: Colour, input_int: int) -> None:
assert Colour.try_value(input_int) == member

View File

@ -545,6 +545,47 @@ def test_oneof_default_value_set_causes_writes_wire():
)
def test_recursive_message():
from tests.output_betterproto.recursivemessage import Test as RecursiveMessage
msg = RecursiveMessage()
assert msg.child == RecursiveMessage()
# Lazily-created zero-value children must not affect equality.
assert msg == RecursiveMessage()
# Lazily-created zero-value children must not affect serialization.
assert bytes(msg) == b""
def test_recursive_message_defaults():
from tests.output_betterproto.recursivemessage import (
Intermediate,
Test as RecursiveMessage,
)
msg = RecursiveMessage(name="bob", intermediate=Intermediate(42))
# set values are as expected
assert msg == RecursiveMessage(name="bob", intermediate=Intermediate(42))
# lazy initialization modifies the message
assert msg != RecursiveMessage(
name="bob", intermediate=Intermediate(42), child=RecursiveMessage(name="jude")
)
msg.child.child.name = "jude"
assert msg == RecursiveMessage(
name="bob",
intermediate=Intermediate(42),
child=RecursiveMessage(child=RecursiveMessage(name="jude")),
)
# lazy initialization recurses as needed
assert msg.child.child.child.child.child.child.child == RecursiveMessage()
assert msg.intermediate.child.intermediate == Intermediate()
def test_message_repr():
from tests.output_betterproto.recursivemessage import Test
@ -621,7 +662,9 @@ iso_candidates = """2009-12-12T12:34
2010-02-18T16:00:00.23334444
2010-02-18T16:00:00,2283
2009-05-19 143922
2009-05-19 1439""".split("\n")
2009-05-19 1439""".split(
"\n"
)
def test_iso_datetime():
@ -656,6 +699,25 @@ def test_service_argument__expected_parameter():
assert do_thing_request_parameter.annotation == "DoThingRequest"
def test_copyability():
@dataclass
class Spam(betterproto.Message):
foo: bool = betterproto.bool_field(1)
bar: int = betterproto.int32_field(2)
baz: List[str] = betterproto.string_field(3)
spam = Spam(bar=12, baz=["hello"])
copied = copy(spam)
assert spam == copied
assert spam is not copied
assert spam.baz is copied.baz
deepcopied = deepcopy(spam)
assert spam == deepcopied
assert spam is not deepcopied
assert spam.baz is not deepcopied.baz
def test_is_set():
@dataclass
class Spam(betterproto.Message):

View File

@ -4,15 +4,6 @@ from betterproto.compile.importing import (
get_type_reference,
parse_source_type_name,
)
from betterproto.plugin.typing_compiler import DirectImportTypingCompiler
@pytest.fixture
def typing_compiler() -> DirectImportTypingCompiler:
"""
Generates a simple Direct Import Typing Compiler for testing.
"""
return DirectImportTypingCompiler()
@pytest.mark.parametrize(
@ -41,70 +32,15 @@ def typing_compiler() -> DirectImportTypingCompiler:
],
)
def test_reference_google_wellknown_types_non_wrappers(
google_type: str,
expected_name: str,
expected_import: str,
typing_compiler: DirectImportTypingCompiler,
google_type: str, expected_name: str, expected_import: str
):
imports = set()
name = get_type_reference(
package="",
imports=imports,
source_type=google_type,
typing_compiler=typing_compiler,
pydantic=False,
)
name = get_type_reference(package="", imports=imports, source_type=google_type)
assert name == expected_name
assert imports.__contains__(expected_import), (
f"{expected_import} not found in {imports}"
)
@pytest.mark.parametrize(
["google_type", "expected_name", "expected_import"],
[
(
".google.protobuf.Empty",
'"betterproto_lib_pydantic_google_protobuf.Empty"',
"import betterproto.lib.pydantic.google.protobuf as betterproto_lib_pydantic_google_protobuf",
),
(
".google.protobuf.Struct",
'"betterproto_lib_pydantic_google_protobuf.Struct"',
"import betterproto.lib.pydantic.google.protobuf as betterproto_lib_pydantic_google_protobuf",
),
(
".google.protobuf.ListValue",
'"betterproto_lib_pydantic_google_protobuf.ListValue"',
"import betterproto.lib.pydantic.google.protobuf as betterproto_lib_pydantic_google_protobuf",
),
(
".google.protobuf.Value",
'"betterproto_lib_pydantic_google_protobuf.Value"',
"import betterproto.lib.pydantic.google.protobuf as betterproto_lib_pydantic_google_protobuf",
),
],
)
def test_reference_google_wellknown_types_non_wrappers_pydantic(
google_type: str,
expected_name: str,
expected_import: str,
typing_compiler: DirectImportTypingCompiler,
):
imports = set()
name = get_type_reference(
package="",
imports=imports,
source_type=google_type,
typing_compiler=typing_compiler,
pydantic=True,
)
assert name == expected_name
assert imports.__contains__(expected_import), (
f"{expected_import} not found in {imports}"
)
assert imports.__contains__(
expected_import
), f"{expected_import} not found in {imports}"
@pytest.mark.parametrize(
@ -122,15 +58,10 @@ def test_reference_google_wellknown_types_non_wrappers_pydantic(
],
)
def test_referenceing_google_wrappers_unwraps_them(
google_type: str, expected_name: str, typing_compiler: DirectImportTypingCompiler
google_type: str, expected_name: str
):
imports = set()
name = get_type_reference(
package="",
imports=imports,
source_type=google_type,
typing_compiler=typing_compiler,
)
name = get_type_reference(package="", imports=imports, source_type=google_type)
assert name == expected_name
assert imports == set()
@ -163,321 +94,223 @@ def test_referenceing_google_wrappers_unwraps_them(
],
)
def test_referenceing_google_wrappers_without_unwrapping(
google_type: str, expected_name: str, typing_compiler: DirectImportTypingCompiler
google_type: str, expected_name: str
):
name = get_type_reference(
package="",
imports=set(),
source_type=google_type,
typing_compiler=typing_compiler,
unwrap=False,
package="", imports=set(), source_type=google_type, unwrap=False
)
assert name == expected_name
def test_reference_child_package_from_package(
typing_compiler: DirectImportTypingCompiler,
):
def test_reference_child_package_from_package():
imports = set()
name = get_type_reference(
package="package",
imports=imports,
source_type="package.child.Message",
typing_compiler=typing_compiler,
package="package", imports=imports, source_type="package.child.Message"
)
assert imports == {"from . import child"}
assert name == '"child.Message"'
def test_reference_child_package_from_root(typing_compiler: DirectImportTypingCompiler):
def test_reference_child_package_from_root():
imports = set()
name = get_type_reference(
package="",
imports=imports,
source_type="child.Message",
typing_compiler=typing_compiler,
)
name = get_type_reference(package="", imports=imports, source_type="child.Message")
assert imports == {"from . import child"}
assert name == '"child.Message"'
def test_reference_camel_cased(typing_compiler: DirectImportTypingCompiler):
def test_reference_camel_cased():
imports = set()
name = get_type_reference(
package="",
imports=imports,
source_type="child_package.example_message",
typing_compiler=typing_compiler,
package="", imports=imports, source_type="child_package.example_message"
)
assert imports == {"from . import child_package"}
assert name == '"child_package.ExampleMessage"'
def test_reference_nested_child_from_root(typing_compiler: DirectImportTypingCompiler):
def test_reference_nested_child_from_root():
imports = set()
name = get_type_reference(
package="",
imports=imports,
source_type="nested.child.Message",
typing_compiler=typing_compiler,
package="", imports=imports, source_type="nested.child.Message"
)
assert imports == {"from .nested import child as nested_child"}
assert name == '"nested_child.Message"'
def test_reference_deeply_nested_child_from_root(
typing_compiler: DirectImportTypingCompiler,
):
def test_reference_deeply_nested_child_from_root():
imports = set()
name = get_type_reference(
package="",
imports=imports,
source_type="deeply.nested.child.Message",
typing_compiler=typing_compiler,
package="", imports=imports, source_type="deeply.nested.child.Message"
)
assert imports == {"from .deeply.nested import child as deeply_nested_child"}
assert name == '"deeply_nested_child.Message"'
def test_reference_deeply_nested_child_from_package(
typing_compiler: DirectImportTypingCompiler,
):
def test_reference_deeply_nested_child_from_package():
imports = set()
name = get_type_reference(
package="package",
imports=imports,
source_type="package.deeply.nested.child.Message",
typing_compiler=typing_compiler,
)
assert imports == {"from .deeply.nested import child as deeply_nested_child"}
assert name == '"deeply_nested_child.Message"'
def test_reference_root_sibling(typing_compiler: DirectImportTypingCompiler):
def test_reference_root_sibling():
imports = set()
name = get_type_reference(package="", imports=imports, source_type="Message")
assert imports == set()
assert name == '"Message"'
def test_reference_nested_siblings():
imports = set()
name = get_type_reference(package="foo", imports=imports, source_type="foo.Message")
assert imports == set()
assert name == '"Message"'
def test_reference_deeply_nested_siblings():
imports = set()
name = get_type_reference(
package="",
imports=imports,
source_type="Message",
typing_compiler=typing_compiler,
package="foo.bar", imports=imports, source_type="foo.bar.Message"
)
assert imports == set()
assert name == '"Message"'
def test_reference_nested_siblings(typing_compiler: DirectImportTypingCompiler):
def test_reference_parent_package_from_child():
imports = set()
name = get_type_reference(
package="foo",
imports=imports,
source_type="foo.Message",
typing_compiler=typing_compiler,
)
assert imports == set()
assert name == '"Message"'
def test_reference_deeply_nested_siblings(typing_compiler: DirectImportTypingCompiler):
imports = set()
name = get_type_reference(
package="foo.bar",
imports=imports,
source_type="foo.bar.Message",
typing_compiler=typing_compiler,
)
assert imports == set()
assert name == '"Message"'
def test_reference_parent_package_from_child(
typing_compiler: DirectImportTypingCompiler,
):
imports = set()
name = get_type_reference(
package="package.child",
imports=imports,
source_type="package.Message",
typing_compiler=typing_compiler,
package="package.child", imports=imports, source_type="package.Message"
)
assert imports == {"from ... import package as __package__"}
assert name == '"__package__.Message"'
def test_reference_parent_package_from_deeply_nested_child(
typing_compiler: DirectImportTypingCompiler,
):
def test_reference_parent_package_from_deeply_nested_child():
imports = set()
name = get_type_reference(
package="package.deeply.nested.child",
imports=imports,
source_type="package.deeply.nested.Message",
typing_compiler=typing_compiler,
)
assert imports == {"from ... import nested as __nested__"}
assert name == '"__nested__.Message"'
def test_reference_ancestor_package_from_nested_child(
typing_compiler: DirectImportTypingCompiler,
):
def test_reference_ancestor_package_from_nested_child():
imports = set()
name = get_type_reference(
package="package.ancestor.nested.child",
imports=imports,
source_type="package.ancestor.Message",
typing_compiler=typing_compiler,
)
assert imports == {"from .... import ancestor as ___ancestor__"}
assert name == '"___ancestor__.Message"'
def test_reference_root_package_from_child(typing_compiler: DirectImportTypingCompiler):
def test_reference_root_package_from_child():
imports = set()
name = get_type_reference(
package="package.child",
imports=imports,
source_type="Message",
typing_compiler=typing_compiler,
package="package.child", imports=imports, source_type="Message"
)
assert imports == {"from ... import Message as __Message__"}
assert name == '"__Message__"'
def test_reference_root_package_from_deeply_nested_child(
typing_compiler: DirectImportTypingCompiler,
):
def test_reference_root_package_from_deeply_nested_child():
imports = set()
name = get_type_reference(
package="package.deeply.nested.child",
imports=imports,
source_type="Message",
typing_compiler=typing_compiler,
package="package.deeply.nested.child", imports=imports, source_type="Message"
)
assert imports == {"from ..... import Message as ____Message__"}
assert name == '"____Message__"'
def test_reference_unrelated_package(typing_compiler: DirectImportTypingCompiler):
def test_reference_unrelated_package():
imports = set()
name = get_type_reference(
package="a",
imports=imports,
source_type="p.Message",
typing_compiler=typing_compiler,
)
name = get_type_reference(package="a", imports=imports, source_type="p.Message")
assert imports == {"from .. import p as _p__"}
assert name == '"_p__.Message"'
def test_reference_unrelated_nested_package(
typing_compiler: DirectImportTypingCompiler,
):
def test_reference_unrelated_nested_package():
imports = set()
name = get_type_reference(
package="a.b",
imports=imports,
source_type="p.q.Message",
typing_compiler=typing_compiler,
)
name = get_type_reference(package="a.b", imports=imports, source_type="p.q.Message")
assert imports == {"from ...p import q as __p_q__"}
assert name == '"__p_q__.Message"'
def test_reference_unrelated_deeply_nested_package(
typing_compiler: DirectImportTypingCompiler,
):
def test_reference_unrelated_deeply_nested_package():
imports = set()
name = get_type_reference(
package="a.b.c.d",
imports=imports,
source_type="p.q.r.s.Message",
typing_compiler=typing_compiler,
package="a.b.c.d", imports=imports, source_type="p.q.r.s.Message"
)
assert imports == {"from .....p.q.r import s as ____p_q_r_s__"}
assert name == '"____p_q_r_s__.Message"'
def test_reference_cousin_package(typing_compiler: DirectImportTypingCompiler):
def test_reference_cousin_package():
imports = set()
name = get_type_reference(
package="a.x",
imports=imports,
source_type="a.y.Message",
typing_compiler=typing_compiler,
)
name = get_type_reference(package="a.x", imports=imports, source_type="a.y.Message")
assert imports == {"from .. import y as _y__"}
assert name == '"_y__.Message"'
def test_reference_cousin_package_different_name(
typing_compiler: DirectImportTypingCompiler,
):
def test_reference_cousin_package_different_name():
imports = set()
name = get_type_reference(
package="test.package1",
imports=imports,
source_type="cousin.package2.Message",
typing_compiler=typing_compiler,
package="test.package1", imports=imports, source_type="cousin.package2.Message"
)
assert imports == {"from ...cousin import package2 as __cousin_package2__"}
assert name == '"__cousin_package2__.Message"'
def test_reference_cousin_package_same_name(
typing_compiler: DirectImportTypingCompiler,
):
def test_reference_cousin_package_same_name():
imports = set()
name = get_type_reference(
package="test.package",
imports=imports,
source_type="cousin.package.Message",
typing_compiler=typing_compiler,
package="test.package", imports=imports, source_type="cousin.package.Message"
)
assert imports == {"from ...cousin import package as __cousin_package__"}
assert name == '"__cousin_package__.Message"'
def test_reference_far_cousin_package(typing_compiler: DirectImportTypingCompiler):
def test_reference_far_cousin_package():
imports = set()
name = get_type_reference(
package="a.x.y",
imports=imports,
source_type="a.b.c.Message",
typing_compiler=typing_compiler,
package="a.x.y", imports=imports, source_type="a.b.c.Message"
)
assert imports == {"from ...b import c as __b_c__"}
assert name == '"__b_c__.Message"'
def test_reference_far_far_cousin_package(typing_compiler: DirectImportTypingCompiler):
def test_reference_far_far_cousin_package():
imports = set()
name = get_type_reference(
package="a.x.y.z",
imports=imports,
source_type="a.b.c.d.Message",
typing_compiler=typing_compiler,
package="a.x.y.z", imports=imports, source_type="a.b.c.d.Message"
)
assert imports == {"from ....b.c import d as ___b_c_d__"}

View File

@@ -174,21 +174,22 @@ def test_message_equality(test_data: TestData) -> None:
@pytest.mark.parametrize("test_data", test_cases.messages_with_json, indirect=True)
def test_message_json(test_data: TestData) -> None:
def test_message_json(repeat, test_data: TestData) -> None:
plugin_module, _, json_data = test_data
for sample in json_data:
if sample.belongs_to(test_input_config.non_symmetrical_json):
continue
for _ in range(repeat):
for sample in json_data:
if sample.belongs_to(test_input_config.non_symmetrical_json):
continue
message: betterproto.Message = plugin_module.Test()
message: betterproto.Message = plugin_module.Test()
message.from_json(sample.json)
message_json = message.to_json(0)
message.from_json(sample.json)
message_json = message.to_json(0)
assert dict_replace_nans(json.loads(message_json)) == dict_replace_nans(
json.loads(sample.json)
)
assert dict_replace_nans(json.loads(message_json)) == dict_replace_nans(
json.loads(sample.json)
)
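For reference, a self-contained sketch of the JSON roundtrip this test exercises, using a hand-written message instead of the generated `plugin_module.Test` (the `Greeting` class is hypothetical, for illustration only):

```python
import json
from dataclasses import dataclass

import betterproto


@dataclass
class Greeting(betterproto.Message):  # hypothetical message for illustration
    name: str = betterproto.string_field(1)


message = Greeting().from_json('{"name": "world"}')
# to_json(0) mirrors the call in the test above; parsing both sides with
# json.loads makes the comparison insensitive to whitespace differences.
assert json.loads(message.to_json(0)) == {"name": "world"}
```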
@pytest.mark.parametrize("test_data", test_cases.services, indirect=True)
@@ -197,27 +198,28 @@ def test_service_can_be_instantiated(test_data: TestData) -> None:
@pytest.mark.parametrize("test_data", test_cases.messages_with_json, indirect=True)
def test_binary_compatibility(test_data: TestData) -> None:
def test_binary_compatibility(repeat, test_data: TestData) -> None:
plugin_module, reference_module, json_data = test_data
for sample in json_data:
reference_instance = Parse(sample.json, reference_module().Test())
reference_binary_output = reference_instance.SerializeToString()
plugin_instance_from_json: betterproto.Message = plugin_module.Test().from_json(
sample.json
)
plugin_instance_from_binary = plugin_module.Test.FromString(
reference_binary_output
)
for _ in range(repeat):
plugin_instance_from_json: betterproto.Message = (
plugin_module.Test().from_json(sample.json)
)
plugin_instance_from_binary = plugin_module.Test.FromString(
reference_binary_output
)
# Generally this can't be relied on, but here we are aiming to match the
# existing Python implementation and aren't doing anything tricky.
# https://developers.google.com/protocol-buffers/docs/encoding#implications
assert bytes(plugin_instance_from_json) == reference_binary_output
assert bytes(plugin_instance_from_binary) == reference_binary_output
# Generally this can't be relied on, but here we are aiming to match the
# existing Python implementation and aren't doing anything tricky.
# https://developers.google.com/protocol-buffers/docs/encoding#implications
assert bytes(plugin_instance_from_json) == reference_binary_output
assert bytes(plugin_instance_from_binary) == reference_binary_output
assert plugin_instance_from_json == plugin_instance_from_binary
assert dict_replace_nans(
plugin_instance_from_json.to_dict()
) == dict_replace_nans(plugin_instance_from_binary.to_dict())
assert plugin_instance_from_json == plugin_instance_from_binary
assert dict_replace_nans(
plugin_instance_from_json.to_dict()
) == dict_replace_nans(plugin_instance_from_binary.to_dict())
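A minimal sketch of the binary equivalence being asserted here, with a hypothetical hand-written message standing in for the generated `Test`:

```python
from dataclasses import dataclass

import betterproto


@dataclass
class Pair(betterproto.Message):  # hypothetical message for illustration
    key: str = betterproto.string_field(1)
    value: int = betterproto.int32_field(2)


original = Pair(key="a", value=1)
wire = bytes(original)                    # serialize to the wire format
assert Pair.FromString(wire) == original  # parse back and compare
```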

View File

@@ -1,111 +0,0 @@
from typing import (
List,
Optional,
Set,
)
import pytest
from betterproto.plugin.module_validation import ModuleValidator
@pytest.mark.parametrize(
["text", "expected_collisions"],
[
pytest.param(
["import os"],
None,
id="single import",
),
pytest.param(
["import os", "import sys"],
None,
id="multiple imports",
),
pytest.param(
["import os", "import os"],
{"os"},
id="duplicate imports",
),
pytest.param(
["from os import path", "import os"],
None,
id="duplicate imports with alias",
),
pytest.param(
["from os import path", "import os as os_alias"],
None,
id="duplicate imports with alias",
),
pytest.param(
["from os import path", "import os as path"],
{"path"},
id="duplicate imports with alias",
),
pytest.param(
["import os", "class os:"],
{"os"},
id="duplicate import with class",
),
pytest.param(
["import os", "class os:", " pass", "import sys"],
{"os"},
id="duplicate import with class and another",
),
pytest.param(
["def test(): pass", "class test:"],
{"test"},
id="duplicate class and function",
),
pytest.param(
["def test(): pass", "def test(): pass"],
{"test"},
id="duplicate functions",
),
pytest.param(
["def test(): pass", "test = 100"],
{"test"},
id="function and variable",
),
pytest.param(
["def test():", " test = 3"],
None,
id="function and variable in function",
),
pytest.param(
[
"def test(): pass",
"'''",
"def test(): pass",
"'''",
"def test_2(): pass",
],
None,
id="duplicate functions with multiline string",
),
pytest.param(
["def test(): pass", "# def test(): pass"],
None,
id="duplicate functions with comments",
),
pytest.param(
["from test import (", " A", " B", " C", ")"],
None,
id="multiline import",
),
pytest.param(
["from test import (", " A", " B", " C", ")", "from test import A"],
{"A"},
id="multiline import with duplicate",
),
],
)
def test_module_validator(text: List[str], expected_collisions: Optional[Set[str]]):
line_iterator = iter(text)
validator = ModuleValidator(line_iterator)
valid = validator.validate()
if expected_collisions is None:
assert valid
else:
assert set(validator.collisions.keys()) == expected_collisions
assert not valid
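The same check can be run outside pytest; a small sketch mirroring the "duplicate import with class" case above:

```python
from betterproto.plugin.module_validation import ModuleValidator

lines = ["import os", "class os:", "    pass"]
validator = ModuleValidator(iter(lines))
assert not validator.validate()                    # "os" is defined twice
assert set(validator.collisions.keys()) == {"os"}  # and reported by name
```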

View File

@@ -1,216 +0,0 @@
import pickle
from copy import (
copy,
deepcopy,
)
from dataclasses import dataclass
from typing import (
Dict,
List,
)
from unittest.mock import ANY
import cachelib
import betterproto
from betterproto.lib.google import protobuf as google
def unpickled(message):
return pickle.loads(pickle.dumps(message))
@dataclass(eq=False, repr=False)
class Fe(betterproto.Message):
abc: str = betterproto.string_field(1)
@dataclass(eq=False, repr=False)
class Fi(betterproto.Message):
abc: str = betterproto.string_field(1)
@dataclass(eq=False, repr=False)
class Fo(betterproto.Message):
abc: str = betterproto.string_field(1)
@dataclass(eq=False, repr=False)
class NestedData(betterproto.Message):
struct_foo: Dict[str, "google.Struct"] = betterproto.map_field(
1, betterproto.TYPE_STRING, betterproto.TYPE_MESSAGE
)
map_str_any_bar: Dict[str, "google.Any"] = betterproto.map_field(
2, betterproto.TYPE_STRING, betterproto.TYPE_MESSAGE
)
@dataclass(eq=False, repr=False)
class Complex(betterproto.Message):
foo_str: str = betterproto.string_field(1)
fe: "Fe" = betterproto.message_field(3, group="grp")
fi: "Fi" = betterproto.message_field(4, group="grp")
fo: "Fo" = betterproto.message_field(5, group="grp")
nested_data: "NestedData" = betterproto.message_field(6)
mapping: Dict[str, "google.Any"] = betterproto.map_field(
7, betterproto.TYPE_STRING, betterproto.TYPE_MESSAGE
)
class BetterprotoEnum(betterproto.Enum):
UNSPECIFIED = 0
ONE = 1
def complex_msg():
return Complex(
foo_str="yep",
fe=Fe(abc="1"),
nested_data=NestedData(
struct_foo={
"foo": google.Struct(
fields={
"hello": google.Value(
list_value=google.ListValue(
values=[google.Value(string_value="world")]
)
)
}
),
},
map_str_any_bar={
"key": google.Any(value=b"value"),
},
),
mapping={
"message": google.Any(value=bytes(Fi(abc="hi"))),
"string": google.Any(value=b"howdy"),
},
)
def test_pickling_complex_message():
msg = complex_msg()
deser = unpickled(msg)
assert msg == deser
assert msg.fe.abc == "1"
assert msg.is_set("fi") is not True
assert msg.mapping["message"] == google.Any(value=bytes(Fi(abc="hi")))
assert msg.mapping["string"].value.decode() == "howdy"
assert (
msg.nested_data.struct_foo["foo"]
.fields["hello"]
.list_value.values[0]
.string_value
== "world"
)
def test_recursive_message():
from tests.output_betterproto.recursivemessage import Test as RecursiveMessage
msg = RecursiveMessage()
msg = unpickled(msg)
assert msg.child == RecursiveMessage()
# Lazily-created zero-value children must not affect equality.
assert msg == RecursiveMessage()
# Lazily-created zero-value children must not affect serialization.
assert bytes(msg) == b""
def test_recursive_message_defaults():
from tests.output_betterproto.recursivemessage import (
Intermediate,
Test as RecursiveMessage,
)
msg = RecursiveMessage(name="bob", intermediate=Intermediate(42))
msg = unpickled(msg)
# set values are as expected
assert msg == RecursiveMessage(name="bob", intermediate=Intermediate(42))
# lazy initialization modifies the message
assert msg != RecursiveMessage(
name="bob", intermediate=Intermediate(42), child=RecursiveMessage(name="jude")
)
msg.child.child.name = "jude"
assert msg == RecursiveMessage(
name="bob",
intermediate=Intermediate(42),
child=RecursiveMessage(child=RecursiveMessage(name="jude")),
)
# lazy initialization recurses as needed
assert msg.child.child.child.child.child.child.child == RecursiveMessage()
assert msg.intermediate.child.intermediate == Intermediate()
@dataclass
class PickledMessage(betterproto.Message):
foo: bool = betterproto.bool_field(1)
bar: int = betterproto.int32_field(2)
baz: List[str] = betterproto.string_field(3)
def test_copyability():
msg = PickledMessage(bar=12, baz=["hello"])
msg = unpickled(msg)
copied = copy(msg)
assert msg == copied
assert msg is not copied
assert msg.baz is copied.baz
deepcopied = deepcopy(msg)
assert msg == deepcopied
assert msg is not deepcopied
assert msg.baz is not deepcopied.baz
def test_message_can_be_cached():
"""Cachelib uses pickling to cache values"""
cache = cachelib.SimpleCache()
def use_cache():
calls = getattr(use_cache, "calls", 0)
result = cache.get("message")
if result is not None:
return result
else:
setattr(use_cache, "calls", calls + 1)
result = complex_msg()
cache.set("message", result)
return result
for n in range(10):
if n == 0:
assert not cache.has("message")
else:
assert cache.has("message")
msg = use_cache()
assert use_cache.calls == 1 # The message is only ever built once
assert msg.fe.abc == "1"
assert msg.is_set("fi") is not True
assert msg.mapping["message"] == google.Any(value=bytes(Fi(abc="hi")))
assert msg.mapping["string"].value.decode() == "howdy"
assert (
msg.nested_data.struct_foo["foo"]
.fields["hello"]
.list_value.values[0]
.string_value
== "world"
)
def test_pickle_enum():
enum = BetterprotoEnum.ONE
assert unpickled(enum) == enum
enum = BetterprotoEnum.UNSPECIFIED
assert unpickled(enum) == enum
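A condensed sketch of the pickling contract these tests rely on (the `Note` message is hypothetical, for illustration only):

```python
import pickle
from copy import deepcopy
from dataclasses import dataclass

import betterproto


@dataclass(eq=False, repr=False)
class Note(betterproto.Message):  # hypothetical message for illustration
    text: str = betterproto.string_field(1)


msg = Note(text="hello")
assert pickle.loads(pickle.dumps(msg)) == msg  # pickling roundtrip
assert deepcopy(msg) == msg                    # deep copies compare equal
```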

View File

@@ -1,8 +1,6 @@
from dataclasses import dataclass
from io import BytesIO
from pathlib import Path
from shutil import which
from subprocess import run
from typing import Optional
import pytest
@@ -42,8 +40,6 @@ map_example = map.Test().from_dict({"counts": {"blah": 1, "Blah2": 2}})
streams_path = Path("tests/streams/")
java = which("java")
def test_load_varint_too_long():
with BytesIO(
@@ -62,7 +58,7 @@ def test_load_varint_file():
stream.read(2) # Skip until first multi-byte
assert betterproto.load_varint(stream) == (
123456789,
b"\x95\x9a\xef\x3a",
b"\x95\x9A\xEF\x3A",
) # Multi-byte varint
@@ -131,18 +127,6 @@ def test_message_dump_file_multiple(tmp_path):
assert test_stream.read() == exp_stream.read()
def test_message_dump_delimited(tmp_path):
with open(tmp_path / "message_dump_delimited.out", "wb") as stream:
oneof_example.dump(stream, betterproto.SIZE_DELIMITED)
oneof_example.dump(stream, betterproto.SIZE_DELIMITED)
nested_example.dump(stream, betterproto.SIZE_DELIMITED)
with open(tmp_path / "message_dump_delimited.out", "rb") as test_stream, open(
streams_path / "delimited_messages.in", "rb"
) as exp_stream:
assert test_stream.read() == exp_stream.read()
def test_message_len():
assert len_oneof == len(bytes(oneof_example))
assert len(nested_example) == len(bytes(nested_example))
@@ -171,15 +155,7 @@ def test_message_load_too_small():
oneof.Test().load(stream, len_oneof - 1)
def test_message_load_delimited():
with open(streams_path / "delimited_messages.in", "rb") as stream:
assert oneof.Test().load(stream, betterproto.SIZE_DELIMITED) == oneof_example
assert oneof.Test().load(stream, betterproto.SIZE_DELIMITED) == oneof_example
assert nested.Test().load(stream, betterproto.SIZE_DELIMITED) == nested_example
assert stream.read(1) == b""
def test_message_load_too_large():
def test_message_too_large():
with open(
streams_path / "message_dump_file_single.expected", "rb"
) as stream, pytest.raises(ValueError):
@@ -290,145 +266,3 @@ def test_dump_varint_positive(tmp_path):
streams_path / "dump_varint_positive.expected", "rb"
) as exp_stream:
assert test_stream.read() == exp_stream.read()
# Java compatibility tests
@pytest.fixture(scope="module")
def compile_jar():
# Skip if not all required tools are present
if java is None:
pytest.skip("`java` command is absent and is required")
mvn = which("mvn")
if mvn is None:
pytest.skip("Maven is absent and is required")
# Compile the JAR
proc_maven = run([mvn, "clean", "install", "-f", "tests/streams/java/pom.xml"])
if proc_maven.returncode != 0:
pytest.skip(
"Maven compatibility-test.jar build failed (maybe Java version <11?)"
)
jar = "tests/streams/java/target/compatibility-test.jar"
def run_jar(command: str, tmp_path):
return run([java, "-jar", jar, command, tmp_path], check=True)
def run_java_single_varint(value: int, tmp_path) -> "tuple[int, bytes]":
# Write single varint to file
with open(tmp_path / "py_single_varint.out", "wb") as stream:
betterproto.dump_varint(value, stream)
# Have Java read this varint and write it back
run_jar("single_varint", tmp_path)
# Read single varint from Java output file
with open(tmp_path / "java_single_varint.out", "rb") as stream:
returned = betterproto.load_varint(stream)
with pytest.raises(EOFError):
betterproto.load_varint(stream)
return returned
def test_single_varint(compile_jar, tmp_path):
single_byte = (1, b"\x01")
multi_byte = (123456789, b"\x95\x9a\xef\x3a")
# Write a single-byte varint to a file and have Java read it back
returned = run_java_single_varint(single_byte[0], tmp_path)
assert returned == single_byte
# Same for a multi-byte varint
returned = run_java_single_varint(multi_byte[0], tmp_path)
assert returned == multi_byte
def test_multiple_varints(compile_jar, tmp_path):
single_byte = (1, b"\x01")
multi_byte = (123456789, b"\x95\x9a\xef\x3a")
over32 = (3000000000, b"\x80\xbc\xc1\x96\x0b")
# Write two varints to the same file
with open(tmp_path / "py_multiple_varints.out", "wb") as stream:
betterproto.dump_varint(single_byte[0], stream)
betterproto.dump_varint(multi_byte[0], stream)
betterproto.dump_varint(over32[0], stream)
# Have Java read these varints and write them back
run_jar("multiple_varints", tmp_path)
# Read varints from Java output file
with open(tmp_path / "java_multiple_varints.out", "rb") as stream:
returned_single = betterproto.load_varint(stream)
returned_multi = betterproto.load_varint(stream)
returned_over32 = betterproto.load_varint(stream)
with pytest.raises(EOFError):
betterproto.load_varint(stream)
assert returned_single == single_byte
assert returned_multi == multi_byte
assert returned_over32 == over32
def test_single_message(compile_jar, tmp_path):
# Write message to file
with open(tmp_path / "py_single_message.out", "wb") as stream:
oneof_example.dump(stream)
# Have Java read and return the message
run_jar("single_message", tmp_path)
# Read and check the returned message
with open(tmp_path / "java_single_message.out", "rb") as stream:
returned = oneof.Test().load(stream, len(bytes(oneof_example)))
assert stream.read() == b""
assert returned == oneof_example
def test_multiple_messages(compile_jar, tmp_path):
# Write delimited messages to file
with open(tmp_path / "py_multiple_messages.out", "wb") as stream:
oneof_example.dump(stream, betterproto.SIZE_DELIMITED)
nested_example.dump(stream, betterproto.SIZE_DELIMITED)
# Have Java read and return the messages
run_jar("multiple_messages", tmp_path)
# Read and check the returned messages
with open(tmp_path / "java_multiple_messages.out", "rb") as stream:
returned_oneof = oneof.Test().load(stream, betterproto.SIZE_DELIMITED)
returned_nested = nested.Test().load(stream, betterproto.SIZE_DELIMITED)
assert stream.read() == b""
assert returned_oneof == oneof_example
assert returned_nested == nested_example
def test_infinite_messages(compile_jar, tmp_path):
num_messages = 5
# Write delimited messages to file
with open(tmp_path / "py_infinite_messages.out", "wb") as stream:
for x in range(num_messages):
oneof_example.dump(stream, betterproto.SIZE_DELIMITED)
# Have Java read and return the messages
run_jar("infinite_messages", tmp_path)
# Read and check the returned messages
messages = []
with open(tmp_path / "java_infinite_messages.out", "rb") as stream:
while True:
try:
messages.append(oneof.Test().load(stream, betterproto.SIZE_DELIMITED))
except EOFError:
break
assert len(messages) == num_messages
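For quick reference, the varint helpers also roundtrip in memory without Java at all; this sketch reuses the exact values asserted above:

```python
from io import BytesIO

import betterproto

buffer = BytesIO()
betterproto.dump_varint(123456789, buffer)
buffer.seek(0)
# load_varint returns the decoded value along with the raw bytes it consumed.
assert betterproto.load_varint(buffer) == (123456789, b"\x95\x9a\xef\x3a")
```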

View File

@@ -1,36 +0,0 @@
import json
from betterproto.lib.google.protobuf import Struct
from betterproto.lib.pydantic.google.protobuf import Struct as StructPydantic
def test_struct_roundtrip():
data = {
"foo": "bar",
"baz": None,
"quux": 123,
"zap": [1, {"two": 3}, "four"],
}
data_json = json.dumps(data)
struct_from_dict = Struct().from_dict(data)
assert struct_from_dict.fields == data
assert struct_from_dict.to_dict() == data
assert struct_from_dict.to_json() == data_json
struct_from_json = Struct().from_json(data_json)
assert struct_from_json.fields == data
assert struct_from_json.to_dict() == data
assert struct_from_json == struct_from_dict
assert struct_from_json.to_json() == data_json
struct_pyd_from_dict = StructPydantic(fields={}).from_dict(data)
assert struct_pyd_from_dict.fields == data
assert struct_pyd_from_dict.to_dict() == data
assert struct_pyd_from_dict.to_json() == data_json
struct_pyd_from_json = StructPydantic(fields={}).from_json(data_json)
assert struct_pyd_from_json.fields == data
assert struct_pyd_from_json.to_dict() == data
assert struct_pyd_from_json == struct_pyd_from_dict
assert struct_pyd_from_json.to_json() == data_json
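A trimmed-down sketch of the same roundtrip, comparing parsed JSON rather than raw strings to stay independent of separator details:

```python
import json

from betterproto.lib.google.protobuf import Struct

struct = Struct().from_dict({"greeting": "hello", "count": 3})
assert struct.fields == {"greeting": "hello", "count": 3}
assert json.loads(struct.to_json()) == {"greeting": "hello", "count": 3}
```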

View File

@@ -1,27 +0,0 @@
from datetime import (
datetime,
timezone,
)
import pytest
from betterproto import _Timestamp
@pytest.mark.parametrize(
"dt",
[
datetime(2023, 10, 11, 9, 41, 12, tzinfo=timezone.utc),
datetime.now(timezone.utc),
# potential issue with floating point precision:
datetime(2242, 12, 31, 23, 0, 0, 1, tzinfo=timezone.utc),
# potential issue with negative timestamps:
datetime(1969, 12, 31, 23, 0, 0, 1, tzinfo=timezone.utc),
],
)
def test_timestamp_to_datetime_and_back(dt: datetime):
"""
Make sure converting a datetime to a protobuf timestamp message
and then back again ends up with the same datetime.
"""
assert _Timestamp.from_datetime(dt).to_datetime() == dt
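The conversion pair under test, spelled out (`_Timestamp` is internal API, as the leading underscore suggests):

```python
from datetime import datetime, timezone

from betterproto import _Timestamp  # internal helper, per the import above

dt = datetime(2023, 10, 11, 9, 41, 12, tzinfo=timezone.utc)
ts = _Timestamp.from_datetime(dt)  # seconds + nanos since the epoch
assert ts.to_datetime() == dt
```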

View File

@@ -1,78 +0,0 @@
import pytest
from betterproto.plugin.typing_compiler import (
DirectImportTypingCompiler,
NoTyping310TypingCompiler,
TypingImportTypingCompiler,
)
def test_direct_import_typing_compiler():
compiler = DirectImportTypingCompiler()
assert compiler.imports() == {}
assert compiler.optional("str") == "Optional[str]"
assert compiler.imports() == {"typing": {"Optional"}}
assert compiler.list("str") == "List[str]"
assert compiler.imports() == {"typing": {"Optional", "List"}}
assert compiler.dict("str", "int") == "Dict[str, int]"
assert compiler.imports() == {"typing": {"Optional", "List", "Dict"}}
assert compiler.union("str", "int") == "Union[str, int]"
assert compiler.imports() == {"typing": {"Optional", "List", "Dict", "Union"}}
assert compiler.iterable("str") == "Iterable[str]"
assert compiler.imports() == {
"typing": {"Optional", "List", "Dict", "Union", "Iterable"}
}
assert compiler.async_iterable("str") == "AsyncIterable[str]"
assert compiler.imports() == {
"typing": {"Optional", "List", "Dict", "Union", "Iterable", "AsyncIterable"}
}
assert compiler.async_iterator("str") == "AsyncIterator[str]"
assert compiler.imports() == {
"typing": {
"Optional",
"List",
"Dict",
"Union",
"Iterable",
"AsyncIterable",
"AsyncIterator",
}
}
def test_typing_import_typing_compiler():
compiler = TypingImportTypingCompiler()
assert compiler.imports() == {}
assert compiler.optional("str") == "typing.Optional[str]"
assert compiler.imports() == {"typing": None}
assert compiler.list("str") == "typing.List[str]"
assert compiler.imports() == {"typing": None}
assert compiler.dict("str", "int") == "typing.Dict[str, int]"
assert compiler.imports() == {"typing": None}
assert compiler.union("str", "int") == "typing.Union[str, int]"
assert compiler.imports() == {"typing": None}
assert compiler.iterable("str") == "typing.Iterable[str]"
assert compiler.imports() == {"typing": None}
assert compiler.async_iterable("str") == "typing.AsyncIterable[str]"
assert compiler.imports() == {"typing": None}
assert compiler.async_iterator("str") == "typing.AsyncIterator[str]"
assert compiler.imports() == {"typing": None}
def test_no_typing_311_typing_compiler():
compiler = NoTyping310TypingCompiler()
assert compiler.imports() == {}
assert compiler.optional("str") == '"str | None"'
assert compiler.imports() == {}
assert compiler.list("str") == '"list[str]"'
assert compiler.imports() == {}
assert compiler.dict("str", "int") == '"dict[str, int]"'
assert compiler.imports() == {}
assert compiler.union("str", "int") == '"str | int"'
assert compiler.imports() == {}
assert compiler.iterable("str") == '"Iterable[str]"'
assert compiler.async_iterable("str") == '"AsyncIterable[str]"'
assert compiler.async_iterator("str") == '"AsyncIterator[str]"'
assert compiler.imports() == {
"collections.abc": {"Iterable", "AsyncIterable", "AsyncIterator"}
}
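A short sketch of how a compiler instance is meant to be used: build the annotations first, then ask `imports()` for everything the generated module must import:

```python
from betterproto.plugin.typing_compiler import DirectImportTypingCompiler

compiler = DirectImportTypingCompiler()
annotation = compiler.optional(compiler.list("str"))
assert annotation == "Optional[List[str]]"
# imports() reports exactly what the generated module now needs to import:
assert compiler.imports() == {"typing": {"Optional", "List"}}
```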

View File

@@ -11,6 +11,6 @@ PROJECT_TOML = Path(__file__).joinpath("..", "..", "pyproject.toml").resolve()
def test_version():
with PROJECT_TOML.open() as toml_file:
project_config = tomlkit.loads(toml_file.read())
assert __version__ == project_config["project"]["version"], (
"Project version should match in package and package config"
)
assert (
__version__ == project_config["tool"]["poetry"]["version"]
), "Project version should match in package and package config"