Compare commits

81 commits:

a32a326d38, 256f499c90, 5a518ed044, 08e8a68893, 6a65ca94bc, dbc612c7f3, f33a42b082,
36b5fd1495, f41934a0e2, 37fa3abbac, ed7eefac6f, 1741a126b3, 1a23f09c16, 335eee7537,
a99be78fe4, 849c12fe88, c621ef8a8d, 34b8249b91, 6a3bbe3f25, 65ee2fc702, 7c43c39eab,
f8ecc42478, c2bcd31fe3, c9dfe9ab1f, 32d642d2a0, 1161803069, 8d25c96cea, 8fdcb381b7,
4cf6e7d95c, 63458e2da0, efaef5095c, 1538e156a1, 4e9a17c227, f96f51650c, 970624fe08,
32eaa51e8d, 5fdd0bb24f, 8b59234856, 7c6c627938, 696b7ae9fc, 6dce440975, 1f88b67eeb,
1f79bdd7e4, 6606cd3bb9, 576f878ddc, 6bdfa67fa1, b075402a93, 49ac12634b, a7f0d028ff,
acca29731f, ecbe8dc04d, 85d2990ca1, c3c20556e0, df1ba911b7, 126b256b4c, e98c47861d,
dbd31929d3, 7dee36e073, 5666393f9d, ce5093eec0, c47e83fe5b, b8a091ae70, 61fc2f4160,
d34b16993d, 9ed579fa35, 1d296f1a88, bd7de203e1, d9b7608980, 02aa4e88b7, 1dd001b6d3,
e309513131, 24db53290e, 2bcb05a905, ca6b9fe1a2, 4f18ed1325, 6b36b9ba9f, 61d192e207,
8b5dd6c1f8, 3514991133, e3b44f491f, c82816b8be
.gitea/workflows/release.yml (new file, 35 lines)
@@ -0,0 +1,35 @@
+name: Release
+run-name: ${{ gitea.actor }} is running the CI pipeline
+
+on:
+  push:
+    branches:
+      - master
+
+jobs:
+  packaging:
+    name: Distribution
+    runs-on: ubuntu-latest
+    env:
+      EXT_FIX: "6"
+    steps:
+      - uses: actions/checkout@v4
+      - name: Set up Python 3.9
+        uses: actions/setup-python@v5
+        with:
+          python-version: '3.9'
+      - name: Install poetry
+        run: python -m pip install poetry chardet
+      - name: Install poetry compiler
+        run: poetry install -E compiler
+      - name: Set poetry version
+        run: PV=$(poetry version -s) && poetry version ${PV}+jar3b${EXT_FIX}
+      - name: Build package
+        run: poetry build
+      - name: Add pypi source
+        run: poetry source add --priority=supplemental ahax https://git.ahax86.ru/api/packages/pub/pypi
+      - name: Add pypi credentials
+        run: poetry config http-basic.ahax ${{ secrets.REPO_USER }} ${{ secrets.REPO_PASS }}
+      - name: Push to pypi
+        run: poetry publish -r ahax -u ${{ secrets.REPO_USER }} -p ${{ secrets.REPO_PASS }} -n
.github/CONTRIBUTING.md (6 changes)
@@ -2,7 +2,7 @@

 There's lots to do, and we're working hard, so any help is welcome!

-- :speech_balloon: Join us on [Slack](https://join.slack.com/t/betterproto/shared_invite/zt-f0n0uolx-iN8gBNrkPxtKHTLpG3o1OQ)!
+- :speech_balloon: Join us on [Discord](https://discord.gg/DEVteTupPb)!

 What can you do?

@@ -15,9 +15,9 @@ What can you do?
 - File a bug (please check it's not a duplicate)
 - Propose an enhancement
 - :white_check_mark: Create a PR:
-  - [Creating a failing test-case](https://github.com/danielgtaylor/python-betterproto/blob/master/betterproto/tests/README.md) to make bug-fixing easier
+  - [Creating a failing test-case](https://github.com/danielgtaylor/python-betterproto/blob/master/tests/README.md) to make bug-fixing easier
   - Fix any of the open issues
     - [Good first issues](https://github.com/danielgtaylor/python-betterproto/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22)
     - [Issues with tests](https://github.com/danielgtaylor/python-betterproto/issues?q=is%3Aissue+is%3Aopen+label%3A%22has+test%22)
   - New bugfix or idea
-    - If you'd like to discuss your idea first, join us on Slack!
+    - If you'd like to discuss your idea first, join us on Discord!
.github/ISSUE_TEMPLATE/bug_report.yml (new file, 63 lines)
@@ -0,0 +1,63 @@
+name: Bug Report
+description: Report broken or incorrect behaviour
+labels: ["bug", "investigation needed"]
+
+body:
+  - type: markdown
+    attributes:
+      value: >
+        Thanks for taking the time to fill out a bug report!
+
+        If you're not sure it's a bug and you just have a question, the [community Discord channel](https://discord.gg/DEVteTupPb) is a better place for general questions than a GitHub issue.
+
+  - type: input
+    attributes:
+      label: Summary
+      description: A simple summary of your bug report
+    validations:
+      required: true
+
+  - type: textarea
+    attributes:
+      label: Reproduction Steps
+      description: >
+        What you did to make it happen.
+        Ideally there should be a short code snippet in this section to help reproduce the bug.
+    validations:
+      required: true
+
+  - type: textarea
+    attributes:
+      label: Expected Results
+      description: >
+        What did you expect to happen?
+    validations:
+      required: true
+
+  - type: textarea
+    attributes:
+      label: Actual Results
+      description: >
+        What actually happened?
+    validations:
+      required: true
+
+  - type: textarea
+    attributes:
+      label: System Information
+      description: >
+        Paste the result of `protoc --version; python --version; pip show betterproto` below.
+    validations:
+      required: true
+
+  - type: checkboxes
+    attributes:
+      label: Checklist
+      options:
+        - label: I have searched the issues for duplicates.
+          required: true
+        - label: I have shown the entire traceback, if possible.
+          required: true
+        - label: I have verified this issue occurs on the latest prerelease of betterproto, which can be installed using `pip install -U --pre betterproto`, if possible.
+          required: true
.github/ISSUE_TEMPLATE/config.yml (new file, 6 lines)
@@ -0,0 +1,6 @@
+name:
+description:
+contact_links:
+  - name: For questions about the library
+    about: Support questions are better answered in our Discord group.
+    url: https://discord.gg/DEVteTupPb
.github/ISSUE_TEMPLATE/feature_request.yml (new file, 49 lines)
@@ -0,0 +1,49 @@
+name: Feature Request
+description: Suggest a feature for this library
+labels: ["enhancement"]
+
+body:
+  - type: input
+    attributes:
+      label: Summary
+      description: >
+        What problem is your feature trying to solve? What would become easier or possible if the feature was implemented?
+    validations:
+      required: true
+
+  - type: dropdown
+    attributes:
+      multiple: false
+      label: What is the feature request for?
+      options:
+        - The core library
+        - RPC handling
+        - The documentation
+    validations:
+      required: true
+
+  - type: textarea
+    attributes:
+      label: The Problem
+      description: >
+        What problem is your feature trying to solve?
+        What would become easier or possible if the feature was implemented?
+    validations:
+      required: true
+
+  - type: textarea
+    attributes:
+      label: The Ideal Solution
+      description: >
+        What is your ideal solution to the problem?
+        What would you like this feature to do?
+    validations:
+      required: true
+
+  - type: textarea
+    attributes:
+      label: The Current Solution
+      description: >
+        What is the current solution to the problem, if any?
+    validations:
+      required: false
.github/PULL_REQUEST_TEMPLATE.md (new file, 16 lines)
@@ -0,0 +1,16 @@
+## Summary
+
+<!-- What is this pull request for? Does it fix any issues? -->
+
+## Checklist
+
+<!-- Put an x inside [ ] to check it, like so: [x] -->
+
+- [ ] If code changes were made then they have been tested.
+- [ ] I have updated the documentation to reflect the changes.
+- [ ] This PR fixes an issue.
+- [ ] This PR adds something new (e.g. new method or parameters).
+- [ ] This change has an associated test.
+- [ ] This PR is a breaking change (e.g. methods or parameters removed/renamed)
+- [ ] This PR is **not** a code change (e.g. documentation, README, ...)
.github/workflows/ci.yml (10 changes)
@@ -16,19 +16,19 @@ jobs:
       fail-fast: false
       matrix:
         os: [Ubuntu, MacOS, Windows]
-        python-version: ['3.7', '3.8', '3.9', '3.10', '3.11']
+        python-version: ['3.9', '3.10', '3.11', '3.12', '3.13']
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4

       - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v4
+        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      - name: Get full Python version
        id: full-python-version
        shell: bash
-        run: echo ::set-output name=version::$(python -c "import sys; print('-'.join(str(v) for v in sys.version_info))")
+        run: echo "version=$(python -c "import sys; print('-'.join(str(v) for v in sys.version_info))")" >> "$GITHUB_OUTPUT"

      - name: Install poetry
        shell: bash
@@ -41,7 +41,7 @@ jobs:
        run: poetry config virtualenvs.in-project true

      - name: Set up cache
-        uses: actions/cache@v3
+        uses: actions/cache@v4
        id: cache
        with:
          path: .venv
.github/workflows/code-quality.yml (6 changes)
@@ -13,6 +13,6 @@ jobs:
     name: Check code/doc formatting
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v3
-      - uses: actions/setup-python@v4
-      - uses: pre-commit/action@v2.0.3
+      - uses: actions/checkout@v4
+      - uses: actions/setup-python@v5
+      - uses: pre-commit/action@v3.0.1
.github/workflows/codeql-analysis.yml (new file, 46 lines)
@@ -0,0 +1,46 @@
+name: "CodeQL"
+
+on:
+  push:
+    branches: [ "master" ]
+  pull_request:
+    branches:
+      - '**'
+  schedule:
+    - cron: '19 1 * * 6'
+
+jobs:
+  analyze:
+    name: Analyze
+    runs-on: ubuntu-latest
+    permissions:
+      actions: read
+      contents: read
+      security-events: write
+
+    strategy:
+      fail-fast: false
+      matrix:
+        language: [ 'python' ]
+
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v4
+
+      # Initializes the CodeQL tools for scanning.
+      - name: Initialize CodeQL
+        uses: github/codeql-action/init@v3
+        with:
+          languages: ${{ matrix.language }}
+          # If you wish to specify custom queries, you can do so here or in a config file.
+          # By default, queries listed here will override any specified in a config file.
+          # Prefix the list here with "+" to use these queries and those in the config file.
+
+          # For details on CodeQL's query packs, refer to: https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs
+          # queries: security-extended,security-and-quality
+
+      - name: Autobuild
+        uses: github/codeql-action/autobuild@v3
+
+      - name: Perform CodeQL Analysis
+        uses: github/codeql-action/analyze@v3
.github/workflows/release.yml (8 changes)
@@ -15,11 +15,11 @@ jobs:
     name: Distribution
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v3
-      - name: Set up Python 3.8
-        uses: actions/setup-python@v4
+      - uses: actions/checkout@v4
+      - name: Set up Python 3.9
+        uses: actions/setup-python@v5
        with:
-          python-version: 3.8
+          python-version: 3.9
      - name: Install poetry
        run: python -m pip install poetry
      - name: Build package
.gitignore (1 change)
@@ -18,3 +18,4 @@ output
 .asv
 venv
 .devcontainer
+.ruff_cache
@@ -2,20 +2,24 @@ ci:
   autofix_prs: false

 repos:
-  - repo: https://github.com/pycqa/isort
-    rev: 5.11.5
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    rev: v0.9.1
     hooks:
-      - id: isort
-
-  - repo: https://github.com/psf/black
-    rev: 23.1.0
-    hooks:
-      - id: black
-        args: ["--target-version", "py310"]
+      - id: ruff-format
+        args: ["--diff", "src", "tests"]
+      - id: ruff
+        args: ["--select", "I", "src", "tests"]

   - repo: https://github.com/PyCQA/doc8
     rev: 0.10.1
     hooks:
-      - id: doc8
+      - id: doc8
+        additional_dependencies:
+          - toml
+
+  - repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks
+    rev: v2.14.0
+    hooks:
+      - id: pretty-format-java
+        args: [--autofix, --aosp]
+        files: ^.*\.java$
CHANGELOG.md (23 changes)
@@ -7,6 +7,29 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 - Versions suffixed with `b*` are in `beta` and can be installed with `pip install --pre betterproto`.

+## [2.0.0b7] - 2024-08-11
+
+- **Breaking**: Support `Pydantic` v2, dropping support for v1 [#588](https://github.com/danielgtaylor/python-betterproto/pull/588)
+- **Breaking**: Attempting to access an unset `oneof` field now raises an `AttributeError`.
+  To see how to access `oneof` fields now, refer to [#558](https://github.com/danielgtaylor/python-betterproto/pull/558)
+  and [README.md](https://github.com/danielgtaylor/python-betterproto#one-of-support).
+- **Breaking**: A custom `Enum` has been implemented to match the behaviour of being an open set. Any checks for `isinstance(enum_member, enum.Enum)` and `issubclass(EnumSubclass, enum.Enum)` will now return `False`. This change also has the side effect of
+  preventing any passthrough of `Enum` members (i.e. `Foo.RED.GREEN` doesn't work any more). See [#293](https://github.com/danielgtaylor/python-betterproto/pull/293) for more info; this fixed many bugs related to `Enum` handling.
+
+- Add support for `pickle` methods [#535](https://github.com/danielgtaylor/python-betterproto/pull/535)
+- Add support for `Struct` and `Value` types [#551](https://github.com/danielgtaylor/python-betterproto/pull/551)
+- Add support for the [`Rich` package](https://rich.readthedocs.io/en/latest/index.html) for pretty printing [#508](https://github.com/danielgtaylor/python-betterproto/pull/508)
+- Improve support for streaming messages [#518](https://github.com/danielgtaylor/python-betterproto/pull/518) [#529](https://github.com/danielgtaylor/python-betterproto/pull/529)
+- Improve performance of serializing / de-serializing messages [#545](https://github.com/danielgtaylor/python-betterproto/pull/545)
+- Improve the handling of message name collisions with typing by allowing the method / type of imports to be configured.
+  Refer to [#582](https://github.com/danielgtaylor/python-betterproto/pull/582)
+  and [README.md](https://github.com/danielgtaylor/python-betterproto#configuration-typing-imports).
+- Fix roundtrip parsing of `datetime`s [#534](https://github.com/danielgtaylor/python-betterproto/pull/534)
+- Fix accessing unset optional fields [#523](https://github.com/danielgtaylor/python-betterproto/pull/523)
+- Fix `Message` equality comparison [#513](https://github.com/danielgtaylor/python-betterproto/pull/513)
+- Fix behaviour with long comment messages [#532](https://github.com/danielgtaylor/python-betterproto/pull/532)
+- Add a warning when calling a deprecated message [#596](https://github.com/danielgtaylor/python-betterproto/pull/596)
+
 ## [2.0.0b6] - 2023-06-25

 - **Breaking**: the minimum Python version has been bumped to `3.7` [#444](https://github.com/danielgtaylor/python-betterproto/pull/444)
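A minimal sketch of the open-enum change described in the 2.0.0b7 notes above. `Color` is a hypothetical generated enum; the `isinstance`/`issubclass` results and the loss of member passthrough are taken from the changelog entry, not from code shown in this diff:

```py
import enum

import betterproto


class Color(betterproto.Enum):  # hypothetical generated enum
    UNSPECIFIED = 0
    RED = 1
    GREEN = 2


# The custom base class is no longer a stdlib Enum subclass:
assert not isinstance(Color.RED, enum.Enum)
assert not issubclass(Color, enum.Enum)

# Member passthrough (e.g. Color.RED.GREEN) no longer works; members are
# still usable as ints, since the class remains int-based:
assert Color.GREEN == 2
```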
README.md (73 changes)
@@ -1,6 +1,7 @@
 # Better Protobuf / gRPC Support for Python

 
+

 > :octocat: If you're reading this on github, please be aware that it might mention unreleased features! See the latest released README on [pypi](https://pypi.org/project/betterproto/).

 This project aims to provide an improved experience when using Protobuf / gRPC in a modern Python environment by making use of modern language features and generating readable, understandable, idiomatic Python code. It will not support legacy features or environments (e.g. Protobuf 2). The following are supported:
@@ -277,7 +278,22 @@ message Test {
 }
 ```

-You can use `betterproto.which_one_of(message, group_name)` to determine which of the fields was set. It returns a tuple of the field name and value, or a blank string and `None` if unset.
+On Python 3.10 and later, you can use a `match` statement to access the provided one-of field, which supports type-checking:
+
+```py
+test = Test()
+match test:
+    case Test(on=value):
+        print(value)  # value: bool
+    case Test(count=value):
+        print(value)  # value: int
+    case Test(name=value):
+        print(value)  # value: str
+    case _:
+        print("No value provided")
+```
+
+You can also use `betterproto.which_one_of(message, group_name)` to determine which of the fields was set. It returns a tuple of the field name and value, or a blank string and `None` if unset.

 ```py
 >>> test = Test()
@@ -292,17 +308,11 @@ You can use `betterproto.which_one_of(message, group_name)` to determine which o
 >>> test.count = 57
 >>> betterproto.which_one_of(test, "foo")
 ["count", 57]
->>> test.on
-False

 # Default (zero) values also work.
 >>> test.name = ""
 >>> betterproto.which_one_of(test, "foo")
 ["name", ""]
->>> test.count
-0
->>> test.on
-False
 ```

 Again this is a little different than the official Google code generator:
@@ -382,11 +392,54 @@ swap the dataclass implementation from the builtin python dataclass to the
 pydantic dataclass. You must have pydantic as a dependency in your project for
 this to work.

+## Configuration typing imports
+
+By default, typing types will be imported directly from `typing`. This can sometimes lead to issues in generation if the types being generated conflict with those names. In that case you can configure how types are imported, choosing between three options:
+
+### Direct
+```
+protoc -I . --python_betterproto_opt=typing.direct --python_betterproto_out=lib example.proto
+```
+This configuration is the default, and will import types as follows:
+```
+from typing import (
+    List,
+    Optional,
+    Union
+)
+...
+value: List[str] = []
+value2: Optional[str] = None
+value3: Union[str, int] = 1
+```
+### Root
+```
+protoc -I . --python_betterproto_opt=typing.root --python_betterproto_out=lib example.proto
+```
+This configuration loads the root typing module, and then accesses the types off of it directly:
+```
+import typing
+...
+value: typing.List[str] = []
+value2: typing.Optional[str] = None
+value3: typing.Union[str, int] = 1
+```
+
+### 310
+```
+protoc -I . --python_betterproto_opt=typing.310 --python_betterproto_out=lib example.proto
+```
+This configuration avoids loading typing altogether if possible and uses the Python 3.10 pattern:
+```
+...
+value: list[str] = []
+value2: str | None = None
+value3: str | int = 1
+```
+
 ## Development

-- _Join us on [Slack](https://join.slack.com/t/betterproto/shared_invite/zt-f0n0uolx-iN8gBNrkPxtKHTLpG3o1OQ)!_
+- _Join us on [Discord](https://discord.gg/DEVteTupPb)!_
 - _See how you can help → [Contributing](.github/CONTRIBUTING.md)_

 ### Requirements
@@ -522,7 +575,7 @@ protoc \

 ## Community

-Join us on [Slack](https://join.slack.com/t/betterproto/shared_invite/zt-f0n0uolx-iN8gBNrkPxtKHTLpG3o1OQ)!
+Join us on [Discord](https://discord.gg/DEVteTupPb)!

 ## License
@@ -6,32 +6,32 @@ import betterproto

 @dataclass
 class TestMessage(betterproto.Message):
-    foo: int = betterproto.uint32_field(0)
-    bar: str = betterproto.string_field(1)
-    baz: float = betterproto.float_field(2)
+    foo: int = betterproto.uint32_field(1)
+    bar: str = betterproto.string_field(2)
+    baz: float = betterproto.float_field(3)


 @dataclass
 class TestNestedChildMessage(betterproto.Message):
-    str_key: str = betterproto.string_field(0)
-    bytes_key: bytes = betterproto.bytes_field(1)
-    bool_key: bool = betterproto.bool_field(2)
-    float_key: float = betterproto.float_field(3)
-    int_key: int = betterproto.uint64_field(4)
+    str_key: str = betterproto.string_field(1)
+    bytes_key: bytes = betterproto.bytes_field(2)
+    bool_key: bool = betterproto.bool_field(3)
+    float_key: float = betterproto.float_field(4)
+    int_key: int = betterproto.uint64_field(5)


 @dataclass
 class TestNestedMessage(betterproto.Message):
-    foo: TestNestedChildMessage = betterproto.message_field(0)
-    bar: TestNestedChildMessage = betterproto.message_field(1)
-    baz: TestNestedChildMessage = betterproto.message_field(2)
+    foo: TestNestedChildMessage = betterproto.message_field(1)
+    bar: TestNestedChildMessage = betterproto.message_field(2)
+    baz: TestNestedChildMessage = betterproto.message_field(3)


 @dataclass
 class TestRepeatedMessage(betterproto.Message):
-    foo_repeat: List[str] = betterproto.string_field(0)
-    bar_repeat: List[int] = betterproto.int64_field(1)
-    baz_repeat: List[bool] = betterproto.bool_field(2)
+    foo_repeat: List[str] = betterproto.string_field(1)
+    bar_repeat: List[int] = betterproto.int64_field(2)
+    baz_repeat: List[bool] = betterproto.bool_field(3)


 class BenchMessage:
@@ -44,25 +44,14 @@ class BenchMessage:
         self.instance_filled_bytes = bytes(self.instance_filled)
         self.instance_filled_nested = TestNestedMessage(
             TestNestedChildMessage("foo", bytearray(b"test1"), True, 0.1234, 500),
-            TestNestedChildMessage("bar", bytearray(b"test2"), True, 3.1415, -302),
+            TestNestedChildMessage("bar", bytearray(b"test2"), True, 3.1415, 302),
             TestNestedChildMessage("baz", bytearray(b"test3"), False, 1e5, 300),
         )
         self.instance_filled_nested_bytes = bytes(self.instance_filled_nested)
         self.instance_filled_repeated = TestRepeatedMessage(
-            [
-                "test1",
-                "test2",
-                "test3",
-                "test4",
-                "test5",
-                "test6",
-                "test7",
-                "test8",
-                "test9",
-                "test10",
-            ],
-            [2, -100, 0, 500000, 600, -425678, 1000000000, -300, 1, -694214214466],
-            [True, False, False, False, True, True, False, True, False, False],
+            [f"test{i}" for i in range(1_000)],
+            [(i - 500) ** 3 for i in range(1_000)],
+            [i % 2 == 0 for i in range(1_000)],
         )
         self.instance_filled_repeated_bytes = bytes(self.instance_filled_repeated)

@@ -71,9 +60,9 @@ class BenchMessage:

         @dataclass
         class Message(betterproto.Message):
-            foo: int = betterproto.uint32_field(0)
-            bar: str = betterproto.string_field(1)
-            baz: float = betterproto.float_field(2)
+            foo: int = betterproto.uint32_field(1)
+            bar: str = betterproto.string_field(2)
+            baz: float = betterproto.float_field(3)

     def time_instantiation(self):
         """Time instantiation"""
@@ -85,17 +85,19 @@ wrappers used to provide optional zero value support. Each of these has a special
 representation and is handled a little differently from normal messages. The Python
 mapping for these is as follows:

-+-------------------------------+-----------------------------------------------+--------------------------+
-| ``Google Message``            | ``Python Type``                               | ``Default``              |
-+===============================+===============================================+==========================+
-| ``google.protobuf.duration``  | :class:`datetime.timedelta`                   | ``0``                    |
-+-------------------------------+-----------------------------------------------+--------------------------+
-| ``google.protobuf.timestamp`` | ``Timezone-aware`` :class:`datetime.datetime` | ``1970-01-01T00:00:00Z`` |
-+-------------------------------+-----------------------------------------------+--------------------------+
-| ``google.protobuf.*Value``    | ``Optional[...]``/``None``                    | ``None``                 |
-+-------------------------------+-----------------------------------------------+--------------------------+
-| ``google.protobuf.*``         | ``betterproto.lib.google.protobuf.*``         | ``None``                 |
-+-------------------------------+-----------------------------------------------+--------------------------+
++-------------------------------+-------------------------------------------------+--------------------------+
+| ``Google Message``            | ``Python Type``                                 | ``Default``              |
++===============================+=================================================+==========================+
+| ``google.protobuf.duration``  | :class:`datetime.timedelta`                     | ``0``                    |
++-------------------------------+-------------------------------------------------+--------------------------+
+| ``google.protobuf.timestamp`` | ``Timezone-aware`` :class:`datetime.datetime`   | ``1970-01-01T00:00:00Z`` |
++-------------------------------+-------------------------------------------------+--------------------------+
+| ``google.protobuf.*Value``    | ``Optional[...]``/``None``                      | ``None``                 |
++-------------------------------+-------------------------------------------------+--------------------------+
+| ``google.protobuf.*``         | ``betterproto.lib.std.google.protobuf.*``       | ``None``                 |
++-------------------------------+-------------------------------------------------+--------------------------+
+| ``google.protobuf.*``         | ``betterproto.lib.pydantic.google.protobuf.*``  | ``None``                 |
++-------------------------------+-------------------------------------------------+--------------------------+


 For the wrapper types, the Python type corresponds to the wrapped type, e.g.
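The defaults in the table, restated as plain Python values. This is a sketch mirroring the table above rather than generated code:

```py
from datetime import datetime, timedelta, timezone

duration_default = timedelta(0)  # google.protobuf.duration -> datetime.timedelta
timestamp_default = datetime(1970, 1, 1, tzinfo=timezone.utc)  # google.protobuf.timestamp
wrapper_default = None  # google.protobuf.*Value -> Optional[...], defaults to None
```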
poetry.lock (2684 changes, generated)
File diff suppressed because it is too large.
pyproject.toml (130 changes)
@@ -1,8 +1,10 @@
-[tool.poetry]
+[project]
 name = "betterproto"
-version = "2.0.0b6"
+version = "2.0.0b7"
 description = "A better Protobuf / gRPC generator & library"
-authors = ["Daniel G. Taylor <danielgtaylor@gmail.com>"]
+authors = [
+    {name = "Daniel G. Taylor", email = "danielgtaylor@gmail.com"}
+]
 readme = "README.md"
 repository = "https://github.com/danielgtaylor/python-betterproto"
 keywords = ["protobuf", "gRPC"]
@@ -10,42 +12,54 @@ license = "MIT"
 packages = [
     { include = "betterproto", from = "src" }
 ]
+requires-python = ">=3.9,<4.0"
+dynamic = ["dependencies"]

 [tool.poetry.dependencies]
-python = "^3.7"
-black = { version = ">=23.1.0", optional = true }
+# The Ruff version is pinned. To update it, also update it in .pre-commit-config.yaml
+ruff = { version = "~0.9.1", optional = true }
 grpclib = "^0.4.1"
-importlib-metadata = { version = ">=1.6.0", python = "<3.8" }
 jinja2 = { version = ">=3.0.3", optional = true }
 python-dateutil = "^2.8"
-isort = {version = "^5.11.5", optional = true}
 typing-extensions = "^4.7.1"
+betterproto-rust-codec = { version = "0.1.1", optional = true }

-[tool.poetry.dev-dependencies]
-asv = "^0.4.2"
-bpython = "^0.19"
-grpcio-tools = "^1.54.2"
+[tool.poetry.group.dev.dependencies]
+asv = "^0.6.4"
+bpython = "^0.24"
 jinja2 = ">=3.0.3"
-mypy = "^0.930"
+mypy = "^1.11.2"
+sphinx = "7.4.7"
+sphinx-rtd-theme = "3.0.2"
+pre-commit = "^4.0.1"
+grpcio-tools = "^1.54.2"
 tox = "^4.0.0"

 [tool.poetry.group.test.dependencies]
 poethepoet = ">=0.9.0"
-protobuf = "^4.21.6"
-pytest = "^6.2.5"
-pytest-asyncio = "^0.12.0"
-pytest-cov = "^2.9.0"
+pytest = "^7.4.4"
+pytest-asyncio = "^0.23.8"
+pytest-cov = "^6.0.0"
 pytest-mock = "^3.1.1"
-sphinx = "3.1.2"
-sphinx-rtd-theme = "0.5.0"
-tomlkit = "^0.7.0"
-tox = "^3.15.1"
-pre-commit = "^2.17.0"
-pydantic = ">=1.8.0"
+pydantic = ">=2.0,<3"
+protobuf = "^5"
+cachelib = "^0.13.0"
+tomlkit = ">=0.7.0"

-[tool.poetry.scripts]
+[project.scripts]
 protoc-gen-python_betterproto = "betterproto.plugin:main"

-[tool.poetry.extras]
-compiler = ["black", "isort", "jinja2"]
+[project.optional-dependencies]
+compiler = ["ruff", "jinja2"]
+rust-codec = ["betterproto-rust-codec"]
+
+[tool.ruff]
+extend-exclude = ["tests/output_*"]
+target-version = "py38"
+
+[tool.ruff.lint.isort]
+combine-as-imports = true
+lines-after-imports = 2

 # Dev workflow tasks

@@ -62,8 +76,28 @@ cmd = "mypy src --ignore-missing-imports"
 help = "Check types with mypy"

 [tool.poe.tasks.format]
-cmd = "black . --exclude tests/output_ --target-version py310"
-help = "Apply black formatting to source code"
+sequence = ["_format", "_sort-imports"]
+help = "Format the source code, and sort the imports"
+
+[tool.poe.tasks.check]
+sequence = ["_check-format", "_check-imports"]
+help = "Check that the source code is formatted and the imports sorted"
+
+[tool.poe.tasks._format]
+cmd = "ruff format src tests"
+help = "Format the source code without sorting the imports"
+
+[tool.poe.tasks._sort-imports]
+cmd = "ruff check --select I --fix src tests"
+help = "Sort the imports"
+
+[tool.poe.tasks._check-format]
+cmd = "ruff format --diff src tests"
+help = "Check that the source code is formatted"
+
+[tool.poe.tasks._check-imports]
+cmd = "ruff check --select I src tests"
+help = "Check that the imports are sorted"

 [tool.poe.tasks.docs]
 cmd = "sphinx-build docs docs/build"
@@ -86,11 +120,11 @@ cmd = """
 protoc
 --plugin=protoc-gen-custom=src/betterproto/plugin/main.py
 --custom_opt=INCLUDE_GOOGLE
---custom_out=src/betterproto/lib
+--custom_out=src/betterproto/lib/std
 -I C:\\work\\include
 C:\\work\\include\\google\\protobuf\\**\\*.proto
 """
-help = "Regenerate the types in betterproto.lib.google"
+help = "Regenerate the types in betterproto.lib.std.google"

 # CI tasks

@@ -98,23 +132,6 @@ help = "Regenerate the types in betterproto.lib.std.google"
 shell = "poe generate && tox"
 help = "Run tests with multiple pythons"

-[tool.poe.tasks.check-style]
-cmd = "black . --check --diff"
-help = "Check if code style is correct"
-
-[tool.isort]
-py_version = 37
-profile = "black"
-force_single_line = false
-combine_as_imports = true
-lines_after_imports = 2
-include_trailing_comma = true
-force_grid_wrap = 2
-src_paths = ["src", "tests"]
-
-[tool.black]
-target-version = ['py37']
-
 [tool.doc8]
 paths = ["docs"]
 max_line_length = 88
@@ -130,16 +147,23 @@ omit = ["betterproto/tests/*"]
 [tool.tox]
 legacy_tox_ini = """
 [tox]
-isolated_build = true
-envlist = py37, py38, py310
+requires =
+    tox>=4.2
+    tox-poetry-installer[poetry]==1.0.0b1
+env_list =
+    py311
+    py38
+    py37

 [testenv]
-whitelist_externals = poetry
 commands =
-    poetry install -v --extras compiler
-    poetry run pytest --cov betterproto
+    pytest {posargs: --cov betterproto}
+poetry_dep_groups =
+    test
+require_locked_deps = true
+require_poetry = true
 """

 [build-system]
-requires = ["poetry-core>=1.0.0,<2"]
+requires = ["poetry-core>=2.0.0,<3"]
 build-backend = "poetry.core.masonry.api"
@@ -1,5 +1,7 @@
 from __future__ import annotations

 import dataclasses
-import enum
+import enum as builtin_enum
 import json
 import math
 import struct
@@ -22,8 +24,8 @@ from itertools import count
 from typing import (
     TYPE_CHECKING,
     Any,
-    BinaryIO,
     Callable,
+    ClassVar,
     Dict,
     Generator,
     Iterable,
@@ -37,6 +39,7 @@ from typing import (
 )

 from dateutil.parser import isoparse
+from typing_extensions import Self

 from ._types import T
 from ._version import __version__
@@ -45,11 +48,25 @@ from .casing import (
     safe_snake_case,
     snake_case,
 )
-from .grpc.grpclib_client import ServiceStub
+from .enum import Enum as Enum
+from .grpc.grpclib_client import ServiceStub as ServiceStub
+from .utils import (
+    classproperty,
+    hybridmethod,
+)


 if TYPE_CHECKING:
-    from _typeshed import ReadableBuffer
+    from _typeshed import (
+        SupportsRead,
+        SupportsWrite,
+    )

+if sys.version_info >= (3, 10):
+    from types import UnionType as _types_UnionType
+else:
+
+    class _types_UnionType: ...
+

 # Proto 3 data types
@@ -126,6 +143,9 @@ WIRE_FIXED_32_TYPES = [TYPE_FLOAT, TYPE_FIXED32, TYPE_SFIXED32]
 WIRE_FIXED_64_TYPES = [TYPE_DOUBLE, TYPE_FIXED64, TYPE_SFIXED64]
 WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE, TYPE_MAP]

+# Indicator of message delimitation in streams
+SIZE_DELIMITED = -1
+

 # Protobuf datetimes start at the Unix Epoch in 1970 in UTC.
 def datetime_default_gen() -> datetime:
@@ -134,20 +154,36 @@ def datetime_default_gen() -> datetime:


 DATETIME_ZERO = datetime_default_gen()


 # Special protobuf json doubles
 INFINITY = "Infinity"
 NEG_INFINITY = "-Infinity"
 NAN = "NaN"


-class Casing(enum.Enum):
+class Casing(builtin_enum.Enum):
     """Casing constants for serialization."""

     CAMEL = camel_case  #: A camelCase serialization function.
     SNAKE = snake_case  #: A snake_case serialization function.


-PLACEHOLDER: Any = object()
+class Placeholder:
+    __slots__ = ()
+
+    def __repr__(self) -> str:
+        return "<PLACEHOLDER>"
+
+    def __copy__(self) -> Self:
+        return self
+
+    def __deepcopy__(self, _) -> Self:
+        return self
+
+
+# We can't simply use object() here because pydantic automatically performs deep-copy of mutable default values
+# See #606
+PLACEHOLDER: Any = Placeholder()


 @dataclasses.dataclass(frozen=True)
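Why the sentinel overrides `__copy__`/`__deepcopy__`: identity checks such as `value is not PLACEHOLDER` must survive pydantic's deep-copy of defaults. A standalone sketch, not library code:

```py
from copy import deepcopy


class _Sentinel:
    """Deep-copying returns the same object, so `is` checks keep working."""

    def __deepcopy__(self, _):
        return self


SENTINEL = _Sentinel()
assert deepcopy({"field": SENTINEL})["field"] is SENTINEL  # identity survives

# A bare object() sentinel is replaced by a fresh object on deepcopy,
# silently breaking identity comparisons against the sentinel:
plain = object()
assert deepcopy({"field": plain})["field"] is not plain
```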
@@ -184,7 +220,7 @@ def dataclass_field(
 ) -> dataclasses.Field:
     """Creates a dataclass field with attached protobuf metadata."""
     return dataclasses.field(
-        default=None if optional else PLACEHOLDER,
+        default=None if optional else PLACEHOLDER,  # type: ignore
         metadata={
             "betterproto": FieldMetadata(
                 number, proto_type, map_types, group, wraps, optional
@@ -309,32 +345,6 @@ def map_field(
     )


-class Enum(enum.IntEnum):
-    """
-    The base class for protobuf enumerations, all generated enumerations will inherit
-    from this. Bases :class:`enum.IntEnum`.
-    """
-
-    @classmethod
-    def from_string(cls, name: str) -> "Enum":
-        """Return the value which corresponds to the string name.
-
-        Parameters
-        -----------
-        name: :class:`str`
-            The name of the enum member to get
-
-        Raises
-        -------
-        :exc:`ValueError`
-            The member was not found in the Enum.
-        """
-        try:
-            return cls._member_map_[name]  # type: ignore
-        except KeyError as e:
-            raise ValueError(f"Unknown value {name} for enum {cls.__name__}") from e
-
-
 def _pack_fmt(proto_type: str) -> str:
     """Returns a little-endian format string for reading/writing binary."""
     return {
@@ -347,7 +357,7 @@ def _pack_fmt(proto_type: str) -> str:
     }[proto_type]


-def dump_varint(value: int, stream: BinaryIO) -> None:
+def dump_varint(value: int, stream: "SupportsWrite[bytes]") -> None:
     """Encodes a single varint and dumps it into the provided stream."""
     if value < -(1 << 63):
         raise ValueError(
@@ -556,7 +566,7 @@ def _dump_float(value: float) -> Union[float, str]:
     return value


-def load_varint(stream: BinaryIO) -> Tuple[int, bytes]:
+def load_varint(stream: "SupportsRead[bytes]") -> Tuple[int, bytes]:
     """
     Load a single varint value from a stream. Returns the value and the raw bytes read.
     """
@@ -594,7 +604,7 @@ class ParsedField:
     raw: bytes


-def load_fields(stream: BinaryIO) -> Generator[ParsedField, None, None]:
+def load_fields(stream: "SupportsRead[bytes]") -> Generator[ParsedField, None, None]:
     while True:
         try:
             num_wire, raw = load_varint(stream)
@@ -748,6 +758,7 @@ class Message(ABC):
     _serialized_on_wire: bool
     _unknown_fields: bytes
     _group_current: Dict[str, str]
+    _betterproto_meta: ClassVar[ProtoClassMetadata]

     def __post_init__(self) -> None:
         # Keep track of whether every field was default
@@ -760,7 +771,7 @@ class Message(ABC):
             group_current.setdefault(meta.group)

             value = self.__raw_get(field_name)
-            if value != PLACEHOLDER and not (meta.optional and value is None):
+            if value is not PLACEHOLDER and not (meta.optional and value is None):
                 # Found a non-sentinel value
                 all_sentinel = False

@@ -815,6 +826,10 @@ class Message(ABC):
         ]
         return f"{self.__class__.__name__}({', '.join(parts)})"

+    def __rich_repr__(self) -> Iterable[Tuple[str, Any, Any]]:
+        for field_name in self._betterproto.sorted_field_names:
+            yield field_name, self.__raw_get(field_name), PLACEHOLDER
+
     if not TYPE_CHECKING:

         def __getattribute__(self, name: str) -> Any:
@@ -889,20 +904,28 @@ class Message(ABC):
                 kwargs[name] = deepcopy(value)
         return self.__class__(**kwargs)  # type: ignore

-    @property
-    def _betterproto(self) -> ProtoClassMetadata:
+    def __copy__(self: T, _: Any = {}) -> T:
+        kwargs = {}
+        for name in self._betterproto.sorted_field_names:
+            value = self.__raw_get(name)
+            if value is not PLACEHOLDER:
+                kwargs[name] = value
+        return self.__class__(**kwargs)  # type: ignore
+
+    @classproperty
+    def _betterproto(cls: type[Self]) -> ProtoClassMetadata:  # type: ignore
         """
         Lazy initialize metadata for each protobuf class.
         It may be initialized multiple times in a multi-threaded environment,
         but that won't affect the correctness.
         """
-        meta = getattr(self.__class__, "_betterproto_meta", None)
-        if not meta:
-            meta = ProtoClassMetadata(self.__class__)
-            self.__class__._betterproto_meta = meta  # type: ignore
-        return meta
+        try:
+            return cls._betterproto_meta
+        except AttributeError:
+            cls._betterproto_meta = meta = ProtoClassMetadata(cls)
+            return meta

-    def dump(self, stream: BinaryIO) -> None:
+    def dump(self, stream: "SupportsWrite[bytes]", delimit: bool = False) -> None:
         """
         Dumps the binary encoded Protobuf message to the stream.

@@ -910,7 +933,11 @@ class Message(ABC):
         -----------
         stream: :class:`BinaryIO`
             The stream to dump the message to.
+        delimit:
+            Whether to prefix the message with a varint declaring its size.
         """
+        if delimit == SIZE_DELIMITED:
+            dump_varint(len(self), stream)
+
         for field_name, meta in self._betterproto.meta_by_field_name.items():
             try:
@@ -930,7 +957,7 @@ class Message(ABC):
             # Note that proto3 field presence/optional fields are put in a
             # synthetic single-item oneof by protoc, which helps us ensure we
             # send the value even if the value is the default zero value.
-            selected_in_group = bool(meta.group)
+            selected_in_group = bool(meta.group) or meta.optional

             # Empty messages can still be sent on the wire if they were
             # set (or received empty).
@@ -1124,6 +1151,15 @@ class Message(ABC):
         """
         return bytes(self)

+    def __getstate__(self) -> bytes:
+        return bytes(self)
+
+    def __setstate__(self: T, pickled_bytes: bytes) -> T:
+        return self.parse(pickled_bytes)
+
+    def __reduce__(self) -> Tuple[Any, ...]:
+        return (self.__class__.FromString, (bytes(self),))
+
     @classmethod
     def _type_hint(cls, field_name: str) -> Type:
         return cls._type_hints()[field_name]
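What the three pickle methods above enable, as a usage sketch (`Ping` is a hypothetical module-level message; `__reduce__` routes pickling through the compact wire format via `FromString`):

```py
import pickle
from dataclasses import dataclass

import betterproto


@dataclass
class Ping(betterproto.Message):  # hypothetical message for illustration
    sequence: int = betterproto.uint32_field(1)


# Pickling serializes to protobuf bytes and unpickles via FromString:
restored = pickle.loads(pickle.dumps(Ping(sequence=7)))
assert restored.sequence == 7
```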
@@ -1152,30 +1188,29 @@ class Message(ABC):
     def _get_field_default_gen(cls, field: dataclasses.Field) -> Any:
         t = cls._type_hint(field.name)

-        if hasattr(t, "__origin__"):
-            if t.__origin__ is dict:
-                # This is some kind of map (dict in Python).
-                return dict
-            elif t.__origin__ is list:
-                # This is some kind of list (repeated) field.
-                return list
-            elif t.__origin__ is Union and t.__args__[1] is type(None):
+        is_310_union = isinstance(t, _types_UnionType)
+        if hasattr(t, "__origin__") or is_310_union:
+            if is_310_union or t.__origin__ is Union:
                 # This is an optional field (either wrapped, or using proto3
                 # field presence). For setting the default we really don't care
                 # what kind of field it is.
                 return type(None)
-            else:
-                return t
-        elif issubclass(t, Enum):
+            if t.__origin__ is list:
+                # This is some kind of list (repeated) field.
+                return list
+            if t.__origin__ is dict:
+                # This is some kind of map (dict in Python).
+                return dict
+            return t
+        if issubclass(t, Enum):
             # Enums always default to zero.
-            return int
-        elif t is datetime:
+            return t.try_value
+        if t is datetime:
             # Offsets are relative to 1970-01-01T00:00:00Z
             return datetime_default_gen
-        else:
-            # This is either a primitive scalar or another message type. Calling
-            # it should result in its zero value.
-            return t
+        # This is either a primitive scalar or another message type. Calling
+        # it should result in its zero value.
+        return t

     def _postprocess_single(
         self, wire_type: int, meta: FieldMetadata, field_name: str, value: Any
@@ -1193,6 +1228,9 @@ class Message(ABC):
         elif meta.proto_type == TYPE_BOOL:
             # Booleans use a varint encoding, so convert it to true/false.
             value = value > 0
+        elif meta.proto_type == TYPE_ENUM:
+            # Convert enum ints to python enum instances
+            value = self._betterproto.cls_by_field[field_name].try_value(value)
         elif wire_type in (WIRE_FIXED_32, WIRE_FIXED_64):
             fmt = _pack_fmt(meta.proto_type)
             value = struct.unpack(fmt, value)[0]
@@ -1225,7 +1263,11 @@ class Message(ABC):
             meta.group is not None and self._group_current.get(meta.group) == field_name
         )

-    def load(self: T, stream: BinaryIO, size: Optional[int] = None) -> T:
+    def load(
+        self: T,
+        stream: "SupportsRead[bytes]",
+        size: Optional[int] = None,
+    ) -> T:
         """
         Load the binary encoded Protobuf from a stream into this message instance. This
         returns the instance itself and is therefore assignable and chainable.
@@ -1237,12 +1279,17 @@ class Message(ABC):
         size: :class:`Optional[int]`
             The size of the message in the stream.
             Reads stream until EOF if ``None`` is given.
+            Reads based on a size delimiter prefix varint if SIZE_DELIMITED is given.

         Returns
         --------
         :class:`Message`
             The initialized message.
         """
+        # If the message is delimited, parse the message delimiter
+        if size == SIZE_DELIMITED:
+            size, _ = load_varint(stream)
+
         # Got some data over the wire
         self._serialized_on_wire = True
         proto_meta = self._betterproto
@@ -1315,7 +1362,7 @@ class Message(ABC):

         return self

-    def parse(self: T, data: "ReadableBuffer") -> T:
+    def parse(self: T, data: bytes) -> T:
         """
         Parse the binary encoded Protobuf into this message instance. This
         returns the instance itself and is therefore assignable and chainable.
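A round-trip sketch of the delimited streaming added above. `Event` is a hypothetical message; `SIZE_DELIMITED` and the `dump`/`load` signatures are as shown in this diff:

```py
import io
from dataclasses import dataclass

import betterproto


@dataclass
class Event(betterproto.Message):  # hypothetical message for illustration
    name: str = betterproto.string_field(1)


buffer = io.BytesIO()
# Each dump prefixes the payload with a varint size when delimited:
Event(name="a").dump(buffer, delimit=betterproto.SIZE_DELIMITED)
Event(name="b").dump(buffer, delimit=betterproto.SIZE_DELIMITED)

buffer.seek(0)
# Each load first reads the varint prefix to know how many bytes to consume:
first = Event().load(buffer, size=betterproto.SIZE_DELIMITED)
second = Event().load(buffer, size=betterproto.SIZE_DELIMITED)
assert (first.name, second.name) == ("a", "b")
```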
@@ -1494,7 +1541,91 @@ class Message(ABC):
                 output[cased_name] = value
         return output

-    def from_dict(self: T, value: Mapping[str, Any]) -> T:
+    @classmethod
+    def _from_dict_init(cls, mapping: Mapping[str, Any]) -> Mapping[str, Any]:
+        init_kwargs: Dict[str, Any] = {}
+        for key, value in mapping.items():
+            field_name = safe_snake_case(key)
+            try:
+                meta = cls._betterproto.meta_by_field_name[field_name]
+            except KeyError:
+                continue
+            if value is None:
+                continue
+
+            if meta.proto_type == TYPE_MESSAGE:
+                sub_cls = cls._betterproto.cls_by_field[field_name]
+                if sub_cls == datetime:
+                    value = (
+                        [isoparse(item) for item in value]
+                        if isinstance(value, list)
+                        else isoparse(value)
+                    )
+                elif sub_cls == timedelta:
+                    value = (
+                        [timedelta(seconds=float(item[:-1])) for item in value]
+                        if isinstance(value, list)
+                        else timedelta(seconds=float(value[:-1]))
+                    )
+                elif not meta.wraps:
+                    value = (
+                        [sub_cls.from_dict(item) for item in value]
+                        if isinstance(value, list)
+                        else sub_cls.from_dict(value)
+                    )
+            elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
+                sub_cls = cls._betterproto.cls_by_field[f"{field_name}.value"]
+                value = {k: sub_cls.from_dict(v) for k, v in value.items()}
+            else:
+                if meta.proto_type in INT_64_TYPES:
+                    value = (
+                        [int(n) for n in value]
+                        if isinstance(value, list)
+                        else int(value)
+                    )
+                elif meta.proto_type == TYPE_BYTES:
+                    value = (
+                        [b64decode(n) for n in value]
+                        if isinstance(value, list)
+                        else b64decode(value)
+                    )
+                elif meta.proto_type == TYPE_ENUM:
+                    enum_cls = cls._betterproto.cls_by_field[field_name]
+                    if isinstance(value, list):
+                        value = [enum_cls.from_string(e) for e in value]
+                    elif isinstance(value, str):
+                        value = enum_cls.from_string(value)
+                elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE):
+                    value = (
+                        [_parse_float(n) for n in value]
+                        if isinstance(value, list)
+                        else _parse_float(value)
+                    )
+
+            init_kwargs[field_name] = value
+        return init_kwargs
+
+    @hybridmethod
+    def from_dict(cls: type[Self], value: Mapping[str, Any]) -> Self:  # type: ignore
+        """
+        Parse the key/value pairs into a new message instance.
+
+        Parameters
+        -----------
+        value: Dict[:class:`str`, Any]
+            The dictionary to parse from.
+
+        Returns
+        --------
+        :class:`Message`
+            The initialized message.
+        """
+        self = cls(**cls._from_dict_init(value))
+        self._serialized_on_wire = True
+        return self
+
+    @from_dict.instancemethod
+    def from_dict(self, value: Mapping[str, Any]) -> Self:
         """
         Parse the key/value pairs into the current message instance. This returns the
         instance itself and is therefore assignable and chainable.
@@ -1510,71 +1641,8 @@ class Message(ABC):
             The initialized message.
         """
         self._serialized_on_wire = True
-        for key in value:
-            field_name = safe_snake_case(key)
-            meta = self._betterproto.meta_by_field_name.get(field_name)
-            if not meta:
-                continue
-
-            if value[key] is not None:
-                if meta.proto_type == TYPE_MESSAGE:
-                    v = self._get_field_default(field_name)
-                    cls = self._betterproto.cls_by_field[field_name]
-                    if isinstance(v, list):
-                        if cls == datetime:
-                            v = [isoparse(item) for item in value[key]]
-                        elif cls == timedelta:
-                            v = [
-                                timedelta(seconds=float(item[:-1]))
-                                for item in value[key]
-                            ]
-                        else:
-                            v = [cls().from_dict(item) for item in value[key]]
-                    elif cls == datetime:
-                        v = isoparse(value[key])
-                        setattr(self, field_name, v)
-                    elif cls == timedelta:
-                        v = timedelta(seconds=float(value[key][:-1]))
-                        setattr(self, field_name, v)
-                    elif meta.wraps:
-                        setattr(self, field_name, value[key])
-                    elif v is None:
-                        setattr(self, field_name, cls().from_dict(value[key]))
-                    else:
-                        # NOTE: `from_dict` mutates the underlying message, so no
-                        # assignment here is necessary.
-                        v.from_dict(value[key])
-                elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
-                    v = getattr(self, field_name)
-                    cls = self._betterproto.cls_by_field[f"{field_name}.value"]
-                    for k in value[key]:
-                        v[k] = cls().from_dict(value[key][k])
-                else:
-                    v = value[key]
-                    if meta.proto_type in INT_64_TYPES:
-                        if isinstance(value[key], list):
-                            v = [int(n) for n in value[key]]
-                        else:
-                            v = int(value[key])
-                    elif meta.proto_type == TYPE_BYTES:
-                        if isinstance(value[key], list):
-                            v = [b64decode(n) for n in value[key]]
-                        else:
-                            v = b64decode(value[key])
-                    elif meta.proto_type == TYPE_ENUM:
-                        enum_cls = self._betterproto.cls_by_field[field_name]
-                        if isinstance(v, list):
-                            v = [enum_cls.from_string(e) for e in v]
-                        elif isinstance(v, str):
-                            v = enum_cls.from_string(v)
-                    elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE):
-                        if isinstance(value[key], list):
-                            v = [_parse_float(n) for n in value[key]]
-                        else:
-                            v = _parse_float(value[key])
-
-            if v is not None:
-                setattr(self, field_name, v)
+        for field, value in self._from_dict_init(value).items():
+            setattr(self, field, value)
         return self

     def to_json(
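Usage of the `hybridmethod` split above: the same name works as a classmethod (build a new instance) and as an instance method (mutate in place). `Point` is hypothetical:

```py
from dataclasses import dataclass

import betterproto


@dataclass
class Point(betterproto.Message):  # hypothetical message for illustration
    x: int = betterproto.int32_field(1)
    y: int = betterproto.int32_field(2)


p = Point.from_dict({"x": 1, "y": 2})    # classmethod form: new instance
q = Point().from_dict({"x": 3, "y": 4})  # instance form: mutates and returns q
assert (p.x, q.y) == (1, 4)
```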
@@ -1791,8 +1859,8 @@ class Message(ABC):

     @classmethod
     def _validate_field_groups(cls, values):
-        group_to_one_ofs = cls._betterproto_meta.oneof_field_by_group  # type: ignore
-        field_name_to_meta = cls._betterproto_meta.meta_by_field_name  # type: ignore
+        group_to_one_ofs = cls._betterproto.oneof_field_by_group
+        field_name_to_meta = cls._betterproto.meta_by_field_name

         for group, field_set in group_to_one_ofs.items():
             if len(field_set) == 1:
@@ -1805,12 +1873,12 @@ class Message(ABC):
                 continue

             set_fields = [
-                field.name for field in field_set if values[field.name] is not None
+                field.name
+                for field in field_set
+                if getattr(values, field.name, None) is not None
             ]

-            if not set_fields:
-                raise ValueError(f"Group {group} has no value; all fields are None")
-            elif len(set_fields) > 1:
+            if len(set_fields) > 1:
                 set_fields_str = ", ".join(set_fields)
                 raise ValueError(
                     f"Group {group} has more than one value; fields {set_fields_str} are not None"
@@ -1819,6 +1887,26 @@ class Message(ABC):
         return values


 Message.__annotations__ = {}  # HACK to avoid typing.get_type_hints breaking :)

+# monkey patch (de-)serialization functions of class `Message`
+# with functions from `betterproto-rust-codec` if available
+try:
+    import betterproto_rust_codec
+
+    def __parse_patch(self: T, data: bytes) -> T:
+        betterproto_rust_codec.deserialize(self, data)
+        return self
+
+    def __bytes_patch(self) -> bytes:
+        return betterproto_rust_codec.serialize(self)
+
+    Message.parse = __parse_patch
+    Message.__bytes__ = __bytes_patch
+except ModuleNotFoundError:
+    pass
+

 def serialized_on_wire(message: Message) -> bool:
     """
     If this message was or should be serialized on the wire. This can be used to detect
@@ -1890,17 +1978,26 @@ class _Duration(Duration):
 class _Timestamp(Timestamp):
     @classmethod
     def from_datetime(cls, dt: datetime) -> "_Timestamp":
-        seconds = int(dt.timestamp())
-        nanos = int(dt.microsecond * 1e3)
-        return cls(seconds, nanos)
+        # manual epoch offset calculation to avoid rounding errors,
+        # to support negative timestamps (before 1970) and skirt
+        # around datetime bugs (apparently 0 isn't a year in [0, 9999]??)
+        if dt.tzinfo is None:
+            dt = dt.replace(tzinfo=timezone.utc)
+        offset = dt - DATETIME_ZERO
+        # below is the same as timedelta.total_seconds() but without dividing by 1e6
+        # so we end up with microseconds as integers instead of seconds as float
+        offset_us = (
+            offset.days * 24 * 60 * 60 + offset.seconds
+        ) * 10**6 + offset.microseconds
+        seconds, us = divmod(offset_us, 10**6)
+        return cls(seconds, us * 1000)

     def to_datetime(self) -> datetime:
-        ts = self.seconds + (self.nanos / 1e9)
-
-        if ts < 0:
-            return datetime(1970, 1, 1) + timedelta(seconds=ts)
-        else:
-            return datetime.fromtimestamp(ts, tz=timezone.utc)
+        # datetime.fromtimestamp() expects a timestamp in seconds, not microseconds
+        # if we pass it as a floating point number, we will run into rounding errors
+        # see also #407
+        offset = timedelta(seconds=self.seconds, microseconds=self.nanos // 1000)
+        return DATETIME_ZERO + offset

     @staticmethod
     def timestamp_to_json(dt: datetime) -> str:
@@ -1916,10 +2013,10 @@ class _Timestamp(Timestamp):
             return f"{result}Z"
         if (nanos % 1e6) == 0:
             # Serialize 3 fractional digits.
-            return f"{result}.{int(nanos // 1e6) :03d}Z"
+            return f"{result}.{int(nanos // 1e6):03d}Z"
         if (nanos % 1e3) == 0:
             # Serialize 6 fractional digits.
-            return f"{result}.{int(nanos // 1e3) :06d}Z"
+            return f"{result}.{int(nanos // 1e3):06d}Z"
         # Serialize 9 fractional digits.
         return f"{result}.{nanos:09d}"
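The integer-microsecond math in `from_datetime`, replayed on a pre-1970 value to show the convention it produces (seconds floored, nanos kept non-negative). A standalone sketch using only the stdlib:

```py
from datetime import datetime, timedelta, timezone

DATETIME_ZERO = datetime(1970, 1, 1, tzinfo=timezone.utc)

dt = datetime(1969, 12, 31, 23, 59, 59, 999999, tzinfo=timezone.utc)  # 1 us before epoch
offset = dt - DATETIME_ZERO
offset_us = (offset.days * 24 * 60 * 60 + offset.seconds) * 10**6 + offset.microseconds
seconds, us = divmod(offset_us, 10**6)
assert (seconds, us * 1000) == (-1, 999999000)  # floored seconds, non-negative nanos
```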
@ -136,4 +136,8 @@ def lowercase_first(value: str) -> str:


def sanitize_name(value: str) -> str:
    # https://www.python.org/dev/peps/pep-0008/#descriptive-naming-styles
    return f"{value}_" if keyword.iskeyword(value) else value
    if keyword.iskeyword(value):
        return f"{value}_"
    if not value.isidentifier():
        return f"_{value}"
    return value

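After this change the function also guards against names that are not valid identifiers, not just keywords; expected behaviour in brief (illustrative inputs):

sanitize_name("class")  # -> "class_"  (Python keyword gets a trailing underscore)
sanitize_name("3d")     # -> "_3d"     (not a valid identifier, so it is prefixed)
sanitize_name("name")   # -> "name"    (already safe, returned unchanged)
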
@ -1,6 +1,9 @@
from __future__ import annotations

import os
import re
from typing import (
    TYPE_CHECKING,
    Dict,
    List,
    Set,
@ -13,6 +16,9 @@ from ..lib.google import protobuf as google_protobuf
from .naming import pythonize_class_name


if TYPE_CHECKING:
    from ..plugin.typing_compiler import TypingCompiler

WRAPPER_TYPES: Dict[str, Type] = {
    ".google.protobuf.DoubleValue": google_protobuf.DoubleValue,
    ".google.protobuf.FloatValue": google_protobuf.FloatValue,
@ -43,7 +49,13 @@ def parse_source_type_name(field_type_name: str) -> Tuple[str, str]:

def get_type_reference(
    *, package: str, imports: set, source_type: str, unwrap: bool = True
    *,
    package: str,
    imports: set,
    source_type: str,
    typing_compiler: TypingCompiler,
    unwrap: bool = True,
    pydantic: bool = False,
) -> str:
    """
    Return a Python type name for a proto type reference. Adds the import if
@ -52,7 +64,7 @@ def get_type_reference(
    if unwrap:
        if source_type in WRAPPER_TYPES:
            wrapped_type = type(WRAPPER_TYPES[source_type]().value)
            return f"Optional[{wrapped_type.__name__}]"
            return typing_compiler.optional(wrapped_type.__name__)

    if source_type == ".google.protobuf.Duration":
        return "timedelta"
@ -69,7 +81,9 @@ def get_type_reference(
    compiling_google_protobuf = current_package == ["google", "protobuf"]
    importing_google_protobuf = py_package == ["google", "protobuf"]
    if importing_google_protobuf and not compiling_google_protobuf:
        py_package = ["betterproto", "lib"] + py_package
        py_package = (
            ["betterproto", "lib"] + (["pydantic"] if pydantic else []) + py_package
        )

    if py_package[:1] == ["betterproto"]:
        return reference_absolute(imports, py_package, py_type)

@ -11,3 +11,11 @@ def pythonize_field_name(name: str) -> str:


def pythonize_method_name(name: str) -> str:
    return casing.safe_snake_case(name)


def pythonize_enum_member_name(name: str, enum_name: str) -> str:
    enum_name = casing.snake_case(enum_name).upper()
    find = name.find(enum_name)
    if find != -1:
        name = name[find + len(enum_name) :].strip("_")
    return casing.sanitize_name(name)

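The effect is that enum members lose a redundant prefix matching the enum's own name; expected behaviour in brief (illustrative inputs):

pythonize_enum_member_name("COLOR_RED", "Color")   # -> "RED"
pythonize_enum_member_name("BLUE", "Color")        # -> "BLUE" (no prefix to strip)
pythonize_enum_member_name("COLOR_NONE", "Color")  # -> "NONE"
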
197
src/betterproto/enum.py
Normal file
@ -0,0 +1,197 @@
from __future__ import annotations

from enum import (
    EnumMeta,
    IntEnum,
)
from types import MappingProxyType
from typing import (
    TYPE_CHECKING,
    Any,
    Dict,
    Optional,
    Tuple,
)


if TYPE_CHECKING:
    from collections.abc import (
        Generator,
        Mapping,
    )

    from typing_extensions import (
        Never,
        Self,
    )


def _is_descriptor(obj: object) -> bool:
    return (
        hasattr(obj, "__get__") or hasattr(obj, "__set__") or hasattr(obj, "__delete__")
    )


class EnumType(EnumMeta if TYPE_CHECKING else type):
    _value_map_: Mapping[int, Enum]
    _member_map_: Mapping[str, Enum]

    def __new__(
        mcs, name: str, bases: Tuple[type, ...], namespace: Dict[str, Any]
    ) -> Self:
        value_map = {}
        member_map = {}

        new_mcs = type(
            f"{name}Type",
            tuple(
                dict.fromkeys(
                    [base.__class__ for base in bases if base.__class__ is not type]
                    + [EnumType, type]
                )
            ),  # reorder the bases so EnumType and type are last to avoid conflicts
            {"_value_map_": value_map, "_member_map_": member_map},
        )

        members = {
            name: value
            for name, value in namespace.items()
            if not _is_descriptor(value) and not name.startswith("__")
        }

        cls = type.__new__(
            new_mcs,
            name,
            bases,
            {key: value for key, value in namespace.items() if key not in members},
        )
        # this allows us to disallow member access from other members as
        # members become proper class variables

        for name, value in members.items():
            member = value_map.get(value)
            if member is None:
                member = cls.__new__(cls, name=name, value=value)  # type: ignore
                value_map[value] = member
            member_map[name] = member
            type.__setattr__(new_mcs, name, member)

        return cls

    if not TYPE_CHECKING:

        def __call__(cls, value: int) -> Enum:
            try:
                return cls._value_map_[value]
            except (KeyError, TypeError):
                raise ValueError(f"{value!r} is not a valid {cls.__name__}") from None

    def __iter__(cls) -> Generator[Enum, None, None]:
        yield from cls._member_map_.values()

    def __reversed__(cls) -> Generator[Enum, None, None]:
        yield from reversed(cls._member_map_.values())

    def __getitem__(cls, key: str) -> Enum:
        return cls._member_map_[key]

    @property
    def __members__(cls) -> MappingProxyType[str, Enum]:
        return MappingProxyType(cls._member_map_)

    def __repr__(cls) -> str:
        return f"<enum {cls.__name__!r}>"

    def __len__(cls) -> int:
        return len(cls._member_map_)

    def __setattr__(cls, name: str, value: Any) -> Never:
        raise AttributeError(f"{cls.__name__}: cannot reassign Enum members.")

    def __delattr__(cls, name: str) -> Never:
        raise AttributeError(f"{cls.__name__}: cannot delete Enum members.")

    def __contains__(cls, member: object) -> bool:
        return isinstance(member, cls) and member.name in cls._member_map_


class Enum(IntEnum if TYPE_CHECKING else int, metaclass=EnumType):
    """
    The base class for protobuf enumerations, all generated enumerations will
    inherit from this. Emulates `enum.IntEnum`.
    """

    name: Optional[str]
    value: int

    if not TYPE_CHECKING:

        def __new__(cls, *, name: Optional[str], value: int) -> Self:
            self = super().__new__(cls, value)
            super().__setattr__(self, "name", name)
            super().__setattr__(self, "value", value)
            return self

    def __getnewargs_ex__(self) -> Tuple[Tuple[()], Dict[str, Any]]:
        return (), {"name": self.name, "value": self.value}

    def __str__(self) -> str:
        return self.name or "None"

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}.{self.name}"

    def __setattr__(self, key: str, value: Any) -> Never:
        raise AttributeError(
            f"{self.__class__.__name__} Cannot reassign a member's attributes."
        )

    def __delattr__(self, item: Any) -> Never:
        raise AttributeError(
            f"{self.__class__.__name__} Cannot delete a member's attributes."
        )

    def __copy__(self) -> Self:
        return self

    def __deepcopy__(self, memo: Any) -> Self:
        return self

    @classmethod
    def try_value(cls, value: int = 0) -> Self:
        """Return the value which corresponds to the value.

        Parameters
        -----------
        value: :class:`int`
            The value of the enum member to get.

        Returns
        -------
        :class:`Enum`
            The corresponding member or a new instance of the enum if
            ``value`` isn't actually a member.
        """
        try:
            return cls._value_map_[value]
        except (KeyError, TypeError):
            return cls.__new__(cls, name=None, value=value)

    @classmethod
    def from_string(cls, name: str) -> Self:
        """Return the value which corresponds to the string name.

        Parameters
        -----------
        name: :class:`str`
            The name of the enum member to get.

        Raises
        -------
        :exc:`ValueError`
            The member was not found in the Enum.
        """
        try:
            return cls._member_map_[name]
        except KeyError as e:
            raise ValueError(f"Unknown value {name} for enum {cls.__name__}") from e
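Because unknown wire values become nameless pseudo-members instead of raising, a short sketch of the resulting API (the `Color` enum is hypothetical):

import betterproto


class Color(betterproto.Enum):
    # hypothetical enum used only for this example
    COLOR_UNSPECIFIED = 0
    COLOR_RED = 1


Color.try_value(1)               # Color.COLOR_RED
Color.from_string("COLOR_RED")   # Color.COLOR_RED
unknown = Color.try_value(42)    # pseudo-member with name=None, value=42
str(unknown)                     # 'None'
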
@ -127,6 +127,7 @@ class ServiceStub(ABC):
            response_type,
            **self.__resolve_request_kwargs(timeout, deadline, metadata),
        ) as stream:
            await stream.send_request()
            await self._send_messages(stream, request_iterator)
            response = await stream.recv_message()
            assert response is not None

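For callers nothing changes: client-streaming methods on a generated stub still take an (async) iterable of request messages, and the request is now explicitly sent before the messages. A hedged sketch (the stub method and message type are hypothetical generated names):

from dataclasses import dataclass

import betterproto


@dataclass
class ChatMessage(betterproto.Message):
    # hypothetical request type, for illustration only
    text: str = betterproto.string_field(1)


async def run(stub) -> None:
    # `stub.chat` stands in for any generated client-streaming method
    async def requests():
        yield ChatMessage(text="hello")
        yield ChatMessage(text="bye")

    reply = await stub.chat(requests())
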
File diff suppressed because it is too large
@ -1,152 +1 @@
# Generated by the protocol buffer compiler. DO NOT EDIT!
# sources: google/protobuf/compiler/plugin.proto
# plugin: python-betterproto
# This file has been @generated
from dataclasses import dataclass
from typing import List

import betterproto
import betterproto.lib.google.protobuf as betterproto_lib_google_protobuf


class CodeGeneratorResponseFeature(betterproto.Enum):
    """Sync with code_generator.h."""

    FEATURE_NONE = 0
    FEATURE_PROTO3_OPTIONAL = 1


@dataclass(eq=False, repr=False)
class Version(betterproto.Message):
    """The version number of protocol compiler."""

    major: int = betterproto.int32_field(1)
    minor: int = betterproto.int32_field(2)
    patch: int = betterproto.int32_field(3)
    suffix: str = betterproto.string_field(4)
    """
    A suffix for alpha, beta or rc release, e.g., "alpha-1", "rc2". It should
    be empty for mainline stable releases.
    """


@dataclass(eq=False, repr=False)
class CodeGeneratorRequest(betterproto.Message):
    """An encoded CodeGeneratorRequest is written to the plugin's stdin."""

    file_to_generate: List[str] = betterproto.string_field(1)
    """
    The .proto files that were explicitly listed on the command-line. The code
    generator should generate code only for these files. Each file's
    descriptor will be included in proto_file, below.
    """

    parameter: str = betterproto.string_field(2)
    """The generator parameter passed on the command-line."""

    proto_file: List[
        "betterproto_lib_google_protobuf.FileDescriptorProto"
    ] = betterproto.message_field(15)
    """
    FileDescriptorProtos for all files in files_to_generate and everything they
    import. The files will appear in topological order, so each file appears
    before any file that imports it. protoc guarantees that all proto_files
    will be written after the fields above, even though this is not technically
    guaranteed by the protobuf wire format. This theoretically could allow a
    plugin to stream in the FileDescriptorProtos and handle them one by one
    rather than read the entire set into memory at once. However, as of this
    writing, this is not similarly optimized on protoc's end -- it will store
    all fields in memory at once before sending them to the plugin. Type names
    of fields and extensions in the FileDescriptorProto are always fully
    qualified.
    """

    compiler_version: "Version" = betterproto.message_field(3)
    """The version number of protocol compiler."""


@dataclass(eq=False, repr=False)
class CodeGeneratorResponse(betterproto.Message):
    """The plugin writes an encoded CodeGeneratorResponse to stdout."""

    error: str = betterproto.string_field(1)
    """
    Error message. If non-empty, code generation failed. The plugin process
    should exit with status code zero even if it reports an error in this way.
    This should be used to indicate errors in .proto files which prevent the
    code generator from generating correct code. Errors which indicate a
    problem in protoc itself -- such as the input CodeGeneratorRequest being
    unparseable -- should be reported by writing a message to stderr and
    exiting with a non-zero status code.
    """

    supported_features: int = betterproto.uint64_field(2)
    """
    A bitmask of supported features that the code generator supports. This is a
    bitwise "or" of values from the Feature enum.
    """

    file: List["CodeGeneratorResponseFile"] = betterproto.message_field(15)


@dataclass(eq=False, repr=False)
class CodeGeneratorResponseFile(betterproto.Message):
    """Represents a single generated file."""

    name: str = betterproto.string_field(1)
    """
    The file name, relative to the output directory. The name must not contain
    "." or ".." components and must be relative, not be absolute (so, the file
    cannot lie outside the output directory). "/" must be used as the path
    separator, not "\". If the name is omitted, the content will be appended to
    the previous file. This allows the generator to break large files into
    small chunks, and allows the generated text to be streamed back to protoc
    so that large files need not reside completely in memory at one time. Note
    that as of this writing protoc does not optimize for this -- it will read
    the entire CodeGeneratorResponse before writing files to disk.
    """

    insertion_point: str = betterproto.string_field(2)
    """
    If non-empty, indicates that the named file should already exist, and the
    content here is to be inserted into that file at a defined insertion point.
    This feature allows a code generator to extend the output produced by
    another code generator. The original generator may provide insertion
    points by placing special annotations in the file that look like:
    @@protoc_insertion_point(NAME) The annotation can have arbitrary text
    before and after it on the line, which allows it to be placed in a comment.
    NAME should be replaced with an identifier naming the point -- this is what
    other generators will use as the insertion_point. Code inserted at this
    point will be placed immediately above the line containing the insertion
    point (thus multiple insertions to the same point will come out in the
    order they were added). The double-@ is intended to make it unlikely that
    the generated code could contain things that look like insertion points by
    accident. For example, the C++ code generator places the following line in
    the .pb.h files that it generates: //
    @@protoc_insertion_point(namespace_scope) This line appears within the
    scope of the file's package namespace, but outside of any particular class.
    Another plugin can then specify the insertion_point "namespace_scope" to
    generate additional classes or other declarations that should be placed in
    this scope. Note that if the line containing the insertion point begins
    with whitespace, the same whitespace will be added to every line of the
    inserted text. This is useful for languages like Python, where indentation
    matters. In these languages, the insertion point comment should be
    indented the same amount as any inserted code will need to be in order to
    work correctly in that context. The code generator that generates the
    initial file and the one which inserts into it must both run as part of a
    single invocation of protoc. Code generators are executed in the order in
    which they appear on the command line. If |insertion_point| is present,
    |name| must also be present.
    """

    content: str = betterproto.string_field(15)
    """The file contents."""

    generated_code_info: "betterproto_lib_google_protobuf.GeneratedCodeInfo" = (
        betterproto.message_field(16)
    )
    """
    Information describing the file content being inserted. If an insertion
    point is used, this information will be appropriately offset and inserted
    into the code generation metadata for the generated files.
    """
from betterproto.lib.std.google.protobuf.compiler import *

0
src/betterproto/lib/pydantic/__init__.py
Normal file
0
src/betterproto/lib/pydantic/google/__init__.py
Normal file
2673
src/betterproto/lib/pydantic/google/protobuf/__init__.py
Normal file
File diff suppressed because it is too large
@ -0,0 +1,210 @@
# Generated by the protocol buffer compiler. DO NOT EDIT!
# sources: google/protobuf/compiler/plugin.proto
# plugin: python-betterproto
# This file has been @generated

from typing import TYPE_CHECKING


if TYPE_CHECKING:
    from dataclasses import dataclass
else:
    from pydantic.dataclasses import dataclass

from typing import List

import betterproto
import betterproto.lib.pydantic.google.protobuf as betterproto_lib_pydantic_google_protobuf


class CodeGeneratorResponseFeature(betterproto.Enum):
    """Sync with code_generator.h."""

    FEATURE_NONE = 0
    FEATURE_PROTO3_OPTIONAL = 1
    FEATURE_SUPPORTS_EDITIONS = 2


@dataclass(eq=False, repr=False)
class Version(betterproto.Message):
    """The version number of protocol compiler."""

    major: int = betterproto.int32_field(1)
    minor: int = betterproto.int32_field(2)
    patch: int = betterproto.int32_field(3)
    suffix: str = betterproto.string_field(4)
    """
    A suffix for alpha, beta or rc release, e.g., "alpha-1", "rc2". It should
    be empty for mainline stable releases.
    """


@dataclass(eq=False, repr=False)
class CodeGeneratorRequest(betterproto.Message):
    """An encoded CodeGeneratorRequest is written to the plugin's stdin."""

    file_to_generate: List[str] = betterproto.string_field(1)
    """
    The .proto files that were explicitly listed on the command-line. The
    code generator should generate code only for these files. Each file's
    descriptor will be included in proto_file, below.
    """

    parameter: str = betterproto.string_field(2)
    """The generator parameter passed on the command-line."""

    proto_file: List["betterproto_lib_pydantic_google_protobuf.FileDescriptorProto"] = (
        betterproto.message_field(15)
    )
    """
    FileDescriptorProtos for all files in files_to_generate and everything
    they import. The files will appear in topological order, so each file
    appears before any file that imports it.

    Note: the files listed in files_to_generate will include runtime-retention
    options only, but all other files will include source-retention options.
    The source_file_descriptors field below is available in case you need
    source-retention options for files_to_generate.

    protoc guarantees that all proto_files will be written after
    the fields above, even though this is not technically guaranteed by the
    protobuf wire format. This theoretically could allow a plugin to stream
    in the FileDescriptorProtos and handle them one by one rather than read
    the entire set into memory at once. However, as of this writing, this
    is not similarly optimized on protoc's end -- it will store all fields in
    memory at once before sending them to the plugin.

    Type names of fields and extensions in the FileDescriptorProto are always
    fully qualified.
    """

    source_file_descriptors: List[
        "betterproto_lib_pydantic_google_protobuf.FileDescriptorProto"
    ] = betterproto.message_field(17)
    """
    File descriptors with all options, including source-retention options.
    These descriptors are only provided for the files listed in
    files_to_generate.
    """

    compiler_version: "Version" = betterproto.message_field(3)
    """The version number of protocol compiler."""


@dataclass(eq=False, repr=False)
class CodeGeneratorResponse(betterproto.Message):
    """The plugin writes an encoded CodeGeneratorResponse to stdout."""

    error: str = betterproto.string_field(1)
    """
    Error message. If non-empty, code generation failed. The plugin process
    should exit with status code zero even if it reports an error in this way.

    This should be used to indicate errors in .proto files which prevent the
    code generator from generating correct code. Errors which indicate a
    problem in protoc itself -- such as the input CodeGeneratorRequest being
    unparseable -- should be reported by writing a message to stderr and
    exiting with a non-zero status code.
    """

    supported_features: int = betterproto.uint64_field(2)
    """
    A bitmask of supported features that the code generator supports.
    This is a bitwise "or" of values from the Feature enum.
    """

    minimum_edition: int = betterproto.int32_field(3)
    """
    The minimum edition this plugin supports. This will be treated as an
    Edition enum, but we want to allow unknown values. It should be specified
    according to the edition enum value, *not* the edition number. Only takes
    effect for plugins that have FEATURE_SUPPORTS_EDITIONS set.
    """

    maximum_edition: int = betterproto.int32_field(4)
    """
    The maximum edition this plugin supports. This will be treated as an
    Edition enum, but we want to allow unknown values. It should be specified
    according to the edition enum value, *not* the edition number. Only takes
    effect for plugins that have FEATURE_SUPPORTS_EDITIONS set.
    """

    file: List["CodeGeneratorResponseFile"] = betterproto.message_field(15)


@dataclass(eq=False, repr=False)
class CodeGeneratorResponseFile(betterproto.Message):
    """Represents a single generated file."""

    name: str = betterproto.string_field(1)
    """
    The file name, relative to the output directory. The name must not
    contain "." or ".." components and must be relative, not be absolute (so,
    the file cannot lie outside the output directory). "/" must be used as
    the path separator, not "\".

    If the name is omitted, the content will be appended to the previous
    file. This allows the generator to break large files into small chunks,
    and allows the generated text to be streamed back to protoc so that large
    files need not reside completely in memory at one time. Note that as of
    this writing protoc does not optimize for this -- it will read the entire
    CodeGeneratorResponse before writing files to disk.
    """

    insertion_point: str = betterproto.string_field(2)
    """
    If non-empty, indicates that the named file should already exist, and the
    content here is to be inserted into that file at a defined insertion
    point. This feature allows a code generator to extend the output
    produced by another code generator. The original generator may provide
    insertion points by placing special annotations in the file that look
    like:
      @@protoc_insertion_point(NAME)
    The annotation can have arbitrary text before and after it on the line,
    which allows it to be placed in a comment. NAME should be replaced with
    an identifier naming the point -- this is what other generators will use
    as the insertion_point. Code inserted at this point will be placed
    immediately above the line containing the insertion point (thus multiple
    insertions to the same point will come out in the order they were added).
    The double-@ is intended to make it unlikely that the generated code
    could contain things that look like insertion points by accident.

    For example, the C++ code generator places the following line in the
    .pb.h files that it generates:
      // @@protoc_insertion_point(namespace_scope)
    This line appears within the scope of the file's package namespace, but
    outside of any particular class. Another plugin can then specify the
    insertion_point "namespace_scope" to generate additional classes or
    other declarations that should be placed in this scope.

    Note that if the line containing the insertion point begins with
    whitespace, the same whitespace will be added to every line of the
    inserted text. This is useful for languages like Python, where
    indentation matters. In these languages, the insertion point comment
    should be indented the same amount as any inserted code will need to be
    in order to work correctly in that context.

    The code generator that generates the initial file and the one which
    inserts into it must both run as part of a single invocation of protoc.
    Code generators are executed in the order in which they appear on the
    command line.

    If |insertion_point| is present, |name| must also be present.
    """

    content: str = betterproto.string_field(15)
    """The file contents."""

    generated_code_info: "betterproto_lib_pydantic_google_protobuf.GeneratedCodeInfo" = betterproto.message_field(
        16
    )
    """
    Information describing the file content being inserted. If an insertion
    point is used, this information will be appropriately offset and inserted
    into the code generation metadata for the generated files.
    """


CodeGeneratorRequest.__pydantic_model__.update_forward_refs()  # type: ignore
CodeGeneratorResponse.__pydantic_model__.update_forward_refs()  # type: ignore
CodeGeneratorResponseFile.__pydantic_model__.update_forward_refs()  # type: ignore
0
src/betterproto/lib/std/__init__.py
Normal file
0
src/betterproto/lib/std/google/__init__.py
Normal file
2526
src/betterproto/lib/std/google/protobuf/__init__.py
Normal file
File diff suppressed because it is too large
198
src/betterproto/lib/std/google/protobuf/compiler/__init__.py
Normal file
@ -0,0 +1,198 @@
# Generated by the protocol buffer compiler. DO NOT EDIT!
# sources: google/protobuf/compiler/plugin.proto
# plugin: python-betterproto
# This file has been @generated

from dataclasses import dataclass
from typing import List

import betterproto
import betterproto.lib.google.protobuf as betterproto_lib_google_protobuf


class CodeGeneratorResponseFeature(betterproto.Enum):
    """Sync with code_generator.h."""

    FEATURE_NONE = 0
    FEATURE_PROTO3_OPTIONAL = 1
    FEATURE_SUPPORTS_EDITIONS = 2


@dataclass(eq=False, repr=False)
class Version(betterproto.Message):
    """The version number of protocol compiler."""

    major: int = betterproto.int32_field(1)
    minor: int = betterproto.int32_field(2)
    patch: int = betterproto.int32_field(3)
    suffix: str = betterproto.string_field(4)
    """
    A suffix for alpha, beta or rc release, e.g., "alpha-1", "rc2". It should
    be empty for mainline stable releases.
    """


@dataclass(eq=False, repr=False)
class CodeGeneratorRequest(betterproto.Message):
    """An encoded CodeGeneratorRequest is written to the plugin's stdin."""

    file_to_generate: List[str] = betterproto.string_field(1)
    """
    The .proto files that were explicitly listed on the command-line. The
    code generator should generate code only for these files. Each file's
    descriptor will be included in proto_file, below.
    """

    parameter: str = betterproto.string_field(2)
    """The generator parameter passed on the command-line."""

    proto_file: List["betterproto_lib_google_protobuf.FileDescriptorProto"] = (
        betterproto.message_field(15)
    )
    """
    FileDescriptorProtos for all files in files_to_generate and everything
    they import. The files will appear in topological order, so each file
    appears before any file that imports it.

    Note: the files listed in files_to_generate will include runtime-retention
    options only, but all other files will include source-retention options.
    The source_file_descriptors field below is available in case you need
    source-retention options for files_to_generate.

    protoc guarantees that all proto_files will be written after
    the fields above, even though this is not technically guaranteed by the
    protobuf wire format. This theoretically could allow a plugin to stream
    in the FileDescriptorProtos and handle them one by one rather than read
    the entire set into memory at once. However, as of this writing, this
    is not similarly optimized on protoc's end -- it will store all fields in
    memory at once before sending them to the plugin.

    Type names of fields and extensions in the FileDescriptorProto are always
    fully qualified.
    """

    source_file_descriptors: List[
        "betterproto_lib_google_protobuf.FileDescriptorProto"
    ] = betterproto.message_field(17)
    """
    File descriptors with all options, including source-retention options.
    These descriptors are only provided for the files listed in
    files_to_generate.
    """

    compiler_version: "Version" = betterproto.message_field(3)
    """The version number of protocol compiler."""


@dataclass(eq=False, repr=False)
class CodeGeneratorResponse(betterproto.Message):
    """The plugin writes an encoded CodeGeneratorResponse to stdout."""

    error: str = betterproto.string_field(1)
    """
    Error message. If non-empty, code generation failed. The plugin process
    should exit with status code zero even if it reports an error in this way.

    This should be used to indicate errors in .proto files which prevent the
    code generator from generating correct code. Errors which indicate a
    problem in protoc itself -- such as the input CodeGeneratorRequest being
    unparseable -- should be reported by writing a message to stderr and
    exiting with a non-zero status code.
    """

    supported_features: int = betterproto.uint64_field(2)
    """
    A bitmask of supported features that the code generator supports.
    This is a bitwise "or" of values from the Feature enum.
    """

    minimum_edition: int = betterproto.int32_field(3)
    """
    The minimum edition this plugin supports. This will be treated as an
    Edition enum, but we want to allow unknown values. It should be specified
    according to the edition enum value, *not* the edition number. Only takes
    effect for plugins that have FEATURE_SUPPORTS_EDITIONS set.
    """

    maximum_edition: int = betterproto.int32_field(4)
    """
    The maximum edition this plugin supports. This will be treated as an
    Edition enum, but we want to allow unknown values. It should be specified
    according to the edition enum value, *not* the edition number. Only takes
    effect for plugins that have FEATURE_SUPPORTS_EDITIONS set.
    """

    file: List["CodeGeneratorResponseFile"] = betterproto.message_field(15)


@dataclass(eq=False, repr=False)
class CodeGeneratorResponseFile(betterproto.Message):
    """Represents a single generated file."""

    name: str = betterproto.string_field(1)
    """
    The file name, relative to the output directory. The name must not
    contain "." or ".." components and must be relative, not be absolute (so,
    the file cannot lie outside the output directory). "/" must be used as
    the path separator, not "\".

    If the name is omitted, the content will be appended to the previous
    file. This allows the generator to break large files into small chunks,
    and allows the generated text to be streamed back to protoc so that large
    files need not reside completely in memory at one time. Note that as of
    this writing protoc does not optimize for this -- it will read the entire
    CodeGeneratorResponse before writing files to disk.
    """

    insertion_point: str = betterproto.string_field(2)
    """
    If non-empty, indicates that the named file should already exist, and the
    content here is to be inserted into that file at a defined insertion
    point. This feature allows a code generator to extend the output
    produced by another code generator. The original generator may provide
    insertion points by placing special annotations in the file that look
    like:
      @@protoc_insertion_point(NAME)
    The annotation can have arbitrary text before and after it on the line,
    which allows it to be placed in a comment. NAME should be replaced with
    an identifier naming the point -- this is what other generators will use
    as the insertion_point. Code inserted at this point will be placed
    immediately above the line containing the insertion point (thus multiple
    insertions to the same point will come out in the order they were added).
    The double-@ is intended to make it unlikely that the generated code
    could contain things that look like insertion points by accident.

    For example, the C++ code generator places the following line in the
    .pb.h files that it generates:
      // @@protoc_insertion_point(namespace_scope)
    This line appears within the scope of the file's package namespace, but
    outside of any particular class. Another plugin can then specify the
    insertion_point "namespace_scope" to generate additional classes or
    other declarations that should be placed in this scope.

    Note that if the line containing the insertion point begins with
    whitespace, the same whitespace will be added to every line of the
    inserted text. This is useful for languages like Python, where
    indentation matters. In these languages, the insertion point comment
    should be indented the same amount as any inserted code will need to be
    in order to work correctly in that context.

    The code generator that generates the initial file and the one which
    inserts into it must both run as part of a single invocation of protoc.
    Code generators are executed in the order in which they appear on the
    command line.

    If |insertion_point| is present, |name| must also be present.
    """

    content: str = betterproto.string_field(15)
    """The file contents."""

    generated_code_info: "betterproto_lib_google_protobuf.GeneratedCodeInfo" = (
        betterproto.message_field(16)
    )
    """
    Information describing the file content being inserted. If an insertion
    point is used, this information will be appropriately offset and inserted
    into the code generation metadata for the generated files.
    """
@ -1,10 +1,12 @@
import os.path
import subprocess
import sys

from .module_validation import ModuleValidator


try:
    # betterproto[compiler] specific dependencies
    import black
    import isort.api
    import jinja2
except ImportError as err:
    print(
@ -29,22 +31,34 @@ def outputfile_compiler(output_file: OutputTemplate) -> str:
        trim_blocks=True,
        lstrip_blocks=True,
        loader=jinja2.FileSystemLoader(templates_folder),
        undefined=jinja2.StrictUndefined,
    )
    template = env.get_template("template.py.j2")
    # Load the body first so we have a complete list of imports needed.
    body_template = env.get_template("template.py.j2")
    header_template = env.get_template("header.py.j2")

    code = template.render(output_file=output_file)
    code = isort.api.sort_code_string(
        code=code,
        show_diff=False,
        py_version=37,
        profile="black",
        combine_as_imports=True,
        lines_after_imports=2,
        quiet=True,
        force_grid_wrap=2,
        known_third_party=["grpclib", "betterproto"],
    code = body_template.render(output_file=output_file)
    code = header_template.render(output_file=output_file) + code

    # Sort imports, delete unused ones
    code = subprocess.check_output(
        ["ruff", "check", "--select", "I,F401", "--fix", "--silent", "-"],
        input=code,
        encoding="utf-8",
    )
    return black.format_str(
        src_contents=code,
        mode=black.Mode(),

    # Format the code
    code = subprocess.check_output(
        ["ruff", "format", "-"], input=code, encoding="utf-8"
    )

    # Validate the generated code.
    validator = ModuleValidator(iter(code.splitlines()))
    if not validator.validate():
        message_builder = ["[WARNING]: Generated code has collisions in the module:"]
        for collision, lines in validator.collisions.items():
            message_builder.append(f'  "{collision}" on lines:')
            for num, line in lines:
                message_builder.append(f"  {num}:{line}")
        print("\n".join(message_builder), file=sys.stderr)
    return code

@ -29,10 +29,8 @@ instantiating field `A` with parent message `B` should add a
reference to `A` to `B`'s `fields` attribute.
"""


import builtins
import re
import textwrap
from dataclasses import (
    dataclass,
    field,
@ -49,12 +47,6 @@ from typing import (
)

import betterproto
from betterproto import which_one_of
from betterproto.casing import sanitize_name
from betterproto.compile.importing import (
    get_type_reference,
    parse_source_type_name,
)
from betterproto.compile.naming import (
    pythonize_class_name,
    pythonize_field_name,
@ -72,16 +64,21 @@ from betterproto.lib.google.protobuf import (
)
from betterproto.lib.google.protobuf.compiler import CodeGeneratorRequest

from ..casing import sanitize_name
from .. import which_one_of
from ..compile.importing import (
    get_type_reference,
    parse_source_type_name,
)
from ..compile.naming import (
    pythonize_class_name,
    pythonize_enum_member_name,
    pythonize_field_name,
    pythonize_method_name,
)
from .typing_compiler import (
    DirectImportTypingCompiler,
    TypingCompiler,
)


# Create a unique placeholder to deal with
@ -156,14 +153,33 @@ def get_comment(
) -> str:
    pad = " " * indent
    for sci_loc in proto_file.source_code_info.location:
        if list(sci_loc.path) == path and sci_loc.leading_comments:
            lines = textwrap.wrap(
                sci_loc.leading_comments.strip().replace("\n", ""), width=79 - indent
            )
        if list(sci_loc.path) == path:
            all_comments = list(sci_loc.leading_detached_comments)
            if sci_loc.leading_comments:
                all_comments.append(sci_loc.leading_comments)
            if sci_loc.trailing_comments:
                all_comments.append(sci_loc.trailing_comments)

            lines = []

            for comment in all_comments:
                lines += comment.split("\n")
                lines.append("")

            # Remove consecutive empty lines
            lines = [
                line for i, line in enumerate(lines) if line or (i == 0 or lines[i - 1])
            ]

            if lines and not lines[-1]:
                lines.pop()  # Remove the last empty line

            # It is common for one line comments to start with a space, for example: // comment
            # We don't add this space to the generated file.
            lines = [line[1:] if line and line[0] == " " else line for line in lines]

            # This is a field, message, enum, service, or method
            if len(lines) == 1 and len(lines[0]) < 79 - indent - 6:
                lines[0] = lines[0].strip('"')
                return f'{pad}"""{lines[0]}"""'
            else:
                joined = f"\n{pad}".join(lines)
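The consecutive-blank filter above is compact; a standalone sketch of what it does to a list of comment lines:

lines = ["first", "", "", "second", ""]
# keep a line if it is non-empty, or if it is the first line, or if the
# previous line was non-empty (i.e. collapse runs of blanks to one)
lines = [line for i, line in enumerate(lines) if line or (i == 0 or lines[i - 1])]
assert lines == ["first", "", "second", ""]
if lines and not lines[-1]:
    lines.pop()  # drop the trailing blank, as get_comment does
assert lines == ["first", "", "second"]
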
@ -176,6 +192,7 @@ class ProtoContentBase:
    """Methods common to MessageCompiler, ServiceCompiler and ServiceMethodCompiler."""

    source_file: FileDescriptorProto
    typing_compiler: TypingCompiler
    path: List[int]
    comment_indent: int = 4
    parent: Union["betterproto.Message", "OutputTemplate"]
@ -243,9 +260,8 @@ class OutputTemplate:
    parent_request: PluginRequestCompiler
    package_proto_obj: FileDescriptorProto
    input_files: List[str] = field(default_factory=list)
    imports: Set[str] = field(default_factory=set)
    imports_end: Set[str] = field(default_factory=set)
    datetime_imports: Set[str] = field(default_factory=set)
    typing_imports: Set[str] = field(default_factory=set)
    pydantic_imports: Set[str] = field(default_factory=set)
    builtins_import: bool = False
    messages: List["MessageCompiler"] = field(default_factory=list)
@ -254,6 +270,7 @@ class OutputTemplate:
    imports_type_checking_only: Set[str] = field(default_factory=set)
    pydantic_dataclasses: bool = False
    output: bool = True
    typing_compiler: TypingCompiler = field(default_factory=DirectImportTypingCompiler)

    @property
    def package(self) -> str:
@ -280,8 +297,21 @@ class OutputTemplate:
    @property
    def python_module_imports(self) -> Set[str]:
        imports = set()

        has_deprecated = False
        if any(m.deprecated for m in self.messages):
            has_deprecated = True
        if any(x for x in self.messages if any(x.deprecated_fields)):
            has_deprecated = True
        if any(
            any(m.proto_obj.options.deprecated for m in s.methods)
            for s in self.services
        ):
            has_deprecated = True

        if has_deprecated:
            imports.add("warnings")

        if self.builtins_import:
            imports.add("builtins")
        return imports
@ -292,6 +322,7 @@ class MessageCompiler(ProtoContentBase):
    """Representation of a protobuf message."""

    source_file: FileDescriptorProto
    typing_compiler: TypingCompiler
    parent: Union["MessageCompiler", OutputTemplate] = PLACEHOLDER
    proto_obj: DescriptorProto = PLACEHOLDER
    path: List[int] = PLACEHOLDER
@ -319,12 +350,6 @@ class MessageCompiler(ProtoContentBase):
    def py_name(self) -> str:
        return pythonize_class_name(self.proto_name)

    @property
    def annotation(self) -> str:
        if self.repeated:
            return f"List[{self.py_name}]"
        return self.py_name

    @property
    def deprecated_fields(self) -> Iterator[str]:
        for f in self.fields:
@ -385,7 +410,10 @@ def is_oneof(proto_field_obj: FieldDescriptorProto) -> bool:
    us to tell whether it was set, via the which_one_of interface.
    """

    return which_one_of(proto_field_obj, "oneof_index")[0] == "oneof_index"
    return (
        not proto_field_obj.proto3_optional
        and which_one_of(proto_field_obj, "oneof_index")[0] == "oneof_index"
    )

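The distinction matters because a proto3 `optional` field compiles to a field inside a synthetic one-of, which should render as a plain Optional field rather than a betterproto group. A hedged sketch (using the FieldDescriptorProto type from betterproto's bundled google.protobuf bindings; the field values are illustrative):

from betterproto.lib.google.protobuf import FieldDescriptorProto

# proto3 `optional int32 x = 1;` compiles to a field inside a synthetic one-of
synthetic = FieldDescriptorProto(name="x", oneof_index=0, proto3_optional=True)
# a field declared inside an explicit `oneof { ... }` block
explicit = FieldDescriptorProto(name="y", oneof_index=0)

is_oneof(synthetic)  # False after this change: rendered as a plain Optional field
is_oneof(explicit)   # True: still grouped via the which_one_of machinery
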
@dataclass
@ -434,18 +462,6 @@ class FieldCompiler(MessageCompiler):
            imports.add("datetime")
        return imports

    @property
    def typing_imports(self) -> Set[str]:
        imports = set()
        annotation = self.annotation
        if "Optional[" in annotation:
            imports.add("Optional")
        if "List[" in annotation:
            imports.add("List")
        if "Dict[" in annotation:
            imports.add("Dict")
        return imports

    @property
    def pydantic_imports(self) -> Set[str]:
        return set()
@ -458,7 +474,6 @@ class FieldCompiler(MessageCompiler):

    def add_imports_to(self, output_file: OutputTemplate) -> None:
        output_file.datetime_imports.update(self.datetime_imports)
        output_file.typing_imports.update(self.typing_imports)
        output_file.pydantic_imports.update(self.pydantic_imports)
        output_file.builtins_import = output_file.builtins_import or self.use_builtins

@ -485,11 +500,6 @@ class FieldCompiler(MessageCompiler):
    def optional(self) -> bool:
        return self.proto_obj.proto3_optional

    @property
    def mutable(self) -> bool:
        """True if the field is a mutable type, otherwise False."""
        return self.annotation.startswith(("List[", "Dict["))

    @property
    def field_type(self) -> str:
        """String representation of proto field type."""
@ -499,35 +509,6 @@ class FieldCompiler(MessageCompiler):
        .replace("type_", "")
    )

    @property
    def default_value_string(self) -> str:
        """Python representation of the default proto value."""
        if self.repeated:
            return "[]"
        if self.optional:
            return "None"
        if self.py_type == "int":
            return "0"
        if self.py_type == "float":
            return "0.0"
        elif self.py_type == "bool":
            return "False"
        elif self.py_type == "str":
            return '""'
        elif self.py_type == "bytes":
            return 'b""'
        elif self.field_type == "enum":
            enum_proto_obj_name = self.proto_obj.type_name.split(".").pop()
            enum = next(
                e
                for e in self.output_file.enums
                if e.proto_obj.name == enum_proto_obj_name
            )
            return enum.default_value_string
        else:
            # Message type
            return "None"

    @property
    def packed(self) -> bool:
        """True if the wire representation is a packed format."""
@ -560,8 +541,10 @@ class FieldCompiler(MessageCompiler):
            # Type referencing another defined Message or a named enum
            return get_type_reference(
                package=self.output_file.package,
                imports=self.output_file.imports,
                imports=self.output_file.imports_end,
                source_type=self.proto_obj.type_name,
                typing_compiler=self.typing_compiler,
                pydantic=self.output_file.pydantic_dataclasses,
            )
        else:
            raise NotImplementedError(f"Unknown type {self.proto_obj.type}")
@ -572,9 +555,9 @@ class FieldCompiler(MessageCompiler):
        if self.use_builtins:
            py_type = f"builtins.{py_type}"
        if self.repeated:
            return f"List[{py_type}]"
            return self.typing_compiler.list(py_type)
        if self.optional:
            return f"Optional[{py_type}]"
            return self.typing_compiler.optional(py_type)
        return py_type


@ -599,7 +582,7 @@ class PydanticOneOfFieldCompiler(OneOfFieldCompiler):

    @property
    def pydantic_imports(self) -> Set[str]:
        return {"root_validator"}
        return {"model_validator"}


@dataclass
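Swapping `root_validator` for `model_validator` tracks the pydantic v2 API. A hedged sketch of what the generated hook could look like (names approximate, not the verbatim template output), delegating to the group validation shown at the top of this diff:

from pydantic import model_validator


@model_validator(mode="after")
def check_oneof(cls, values):
    # delegates to the one-of group validation on the message class
    return cls._validate_field_groups(values)
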
@ -622,11 +605,13 @@ class MapEntryCompiler(FieldCompiler):
            source_file=self.source_file,
            parent=self,
            proto_obj=nested.field[0],  # key
            typing_compiler=self.typing_compiler,
        ).py_type
        self.py_v_type = FieldCompiler(
            source_file=self.source_file,
            parent=self,
            proto_obj=nested.field[1],  # value
            typing_compiler=self.typing_compiler,
        ).py_type

        # Get proto types
@ -644,7 +629,7 @@ class MapEntryCompiler(FieldCompiler):

    @property
    def annotation(self) -> str:
        return f"Dict[{self.py_k_type}, {self.py_v_type}]"
        return self.typing_compiler.dict(self.py_k_type, self.py_v_type)

    @property
    def repeated(self) -> bool:
@ -670,7 +655,9 @@ class EnumDefinitionCompiler(MessageCompiler):
        # Get entries/allowed values for this Enum
        self.entries = [
            self.EnumEntry(
                name=sanitize_name(entry_proto_value.name),
                name=pythonize_enum_member_name(
                    entry_proto_value.name, self.proto_obj.name
                ),
                value=entry_proto_value.number,
                comment=get_comment(
                    proto_file=self.source_file, path=self.path + [2, entry_number]
@ -680,17 +667,10 @@ class EnumDefinitionCompiler(MessageCompiler):
        ]
        super().__post_init__()  # call MessageCompiler __post_init__

    @property
    def default_value_string(self) -> str:
        """Python representation of the default value for Enums.

        As per the spec, this is the first value of the Enum.
        """
        return str(self.entries[0].value)  # ideally, should ALWAYS be int(0)!


@dataclass
class ServiceCompiler(ProtoContentBase):
    source_file: FileDescriptorProto
    parent: OutputTemplate = PLACEHOLDER
    proto_obj: DescriptorProto = PLACEHOLDER
    path: List[int] = PLACEHOLDER
@ -699,7 +679,6 @@ class ServiceCompiler(ProtoContentBase):
    def __post_init__(self) -> None:
        # Add service to output file
        self.output_file.services.append(self)
        self.output_file.typing_imports.add("Dict")
        super().__post_init__()  # check for unset fields

    @property
@ -713,6 +692,7 @@ class ServiceCompiler(ProtoContentBase):

@dataclass
class ServiceMethodCompiler(ProtoContentBase):
    source_file: FileDescriptorProto
    parent: ServiceCompiler
    proto_obj: MethodDescriptorProto
    path: List[int] = PLACEHOLDER
@ -722,22 +702,6 @@ class ServiceMethodCompiler(ProtoContentBase):
        # Add method to service
        self.parent.methods.append(self)

        # Check for imports
        if "Optional" in self.py_output_message_type:
            self.output_file.typing_imports.add("Optional")

        # Check for Async imports
        if self.client_streaming:
            self.output_file.typing_imports.add("AsyncIterable")
            self.output_file.typing_imports.add("Iterable")
            self.output_file.typing_imports.add("Union")

        # Required by both client and server
        if self.client_streaming or self.server_streaming:
            self.output_file.typing_imports.add("AsyncIterator")

        # add imports required for request arguments timeout, deadline and metadata
        self.output_file.typing_imports.add("Optional")
        self.output_file.imports_type_checking_only.add("import grpclib.server")
        self.output_file.imports_type_checking_only.add(
            "from betterproto.grpc.grpclib_client import MetadataLike"
@ -765,30 +729,6 @@ class ServiceMethodCompiler(ProtoContentBase):
        )
        return f"/{package_part}{self.parent.proto_name}/{self.proto_name}"

    @property
    def py_input_message(self) -> Optional[MessageCompiler]:
        """Find the input message object.

        Returns
        -------
        Optional[MessageCompiler]
            Method instance representing the input message.
            If no input message could be found or there are no
            input messages, None is returned.
        """
        package, name = parse_source_type_name(self.proto_obj.input_type)

        # Nested types are currently flattened without dots.
        # Todo: keep a fully qualified name in types, that is
        # comparable with method.input_type
        for msg in self.request.all_messages:
            if (
                msg.py_name == pythonize_class_name(name.replace(".", ""))
                and msg.output_file.package == package
            ):
                return msg
        return None

    @property
    def py_input_message_type(self) -> str:
        """String representation of the Python type corresponding to the
@ -801,9 +741,11 @@ class ServiceMethodCompiler(ProtoContentBase):
        """
        return get_type_reference(
            package=self.output_file.package,
            imports=self.output_file.imports,
            imports=self.output_file.imports_end,
            source_type=self.proto_obj.input_type,
            typing_compiler=self.output_file.typing_compiler,
            unwrap=False,
            pydantic=self.output_file.pydantic_dataclasses,
        ).strip('"')

    @property
@ -829,9 +771,11 @@ class ServiceMethodCompiler(ProtoContentBase):
        """
        return get_type_reference(
            package=self.output_file.package,
            imports=self.output_file.imports,
            imports=self.output_file.imports_end,
            source_type=self.proto_obj.output_type,
            typing_compiler=self.output_file.typing_compiler,
            unwrap=False,
            pydantic=self.output_file.pydantic_dataclasses,
        ).strip('"')

    @property

163
src/betterproto/plugin/module_validation.py
Normal file
@ -0,0 +1,163 @@
import re
from collections import defaultdict
from dataclasses import (
    dataclass,
    field,
)
from typing import (
    Dict,
    Iterator,
    List,
    Tuple,
)


@dataclass
class ModuleValidator:
    line_iterator: Iterator[str]
    line_number: int = field(init=False, default=0)

    collisions: Dict[str, List[Tuple[int, str]]] = field(
        init=False, default_factory=lambda: defaultdict(list)
    )

    def add_import(self, imp: str, number: int, full_line: str):
        """
        Adds an import to be tracked.
        """
        self.collisions[imp].append((number, full_line))

    def process_import(self, imp: str):
        """
        Filters out the import to its actual value.
        """
        if " as " in imp:
            imp = imp[imp.index(" as ") + 4 :]

        imp = imp.strip()
        assert " " not in imp, imp
        return imp

    def evaluate_multiline_import(self, line: str):
        """
        Evaluates a multiline import from a starting line
        """
        # Filter the first line and remove anything before the import statement.
        full_line = line
        line = line.split("import", 1)[1]
        if "(" in line:
            conditional = lambda line: ")" not in line
        else:
            conditional = lambda line: "\\" in line

        # Remove open parenthesis if it exists.
        if "(" in line:
            line = line[line.index("(") + 1 :]

        # Choose the conditional based on how multiline imports are formatted.
        while conditional(line):
            # Split the line by commas
            imports = line.split(",")

            for imp in imports:
                # Add the import to the namespace
                imp = self.process_import(imp)
                if imp:
                    self.add_import(imp, self.line_number, full_line)
            # Get the next line
            full_line = line = next(self.line_iterator)
            # Increment the line number
            self.line_number += 1

        # validate the last line
        if ")" in line:
            line = line[: line.index(")")]
        imports = line.split(",")
        for imp in imports:
            imp = self.process_import(imp)
            if imp:
                self.add_import(imp, self.line_number, full_line)

    def evaluate_import(self, line: str):
        """
        Extracts an import from a line.
        """
        whole_line = line
        line = line[line.index("import") + 6 :]
        values = line.split(",")
        for v in values:
            self.add_import(self.process_import(v), self.line_number, whole_line)

    def next(self):
        """
        Evaluate each line for names in the module.
        """
        line = next(self.line_iterator)

        # Skip lines with indentation or comments
        if (
            # Skip indents and whitespace.
            line.startswith(" ")
            or line == "\n"
            or line.startswith("\t")
            or
            # Skip comments
            line.startswith("#")
            or
            # Skip decorators
            line.startswith("@")
        ):
            self.line_number += 1
            return

        # Skip docstrings.
        if line.startswith('"""') or line.startswith("'''"):
            quote = line[0] * 3
            line = line[3:]
            while quote not in line:
                line = next(self.line_iterator)
            self.line_number += 1
            return

        # Evaluate Imports.
        if line.startswith("from ") or line.startswith("import "):
            if "(" in line or "\\" in line:
                self.evaluate_multiline_import(line)
            else:
                self.evaluate_import(line)

        # Evaluate Classes.
        elif line.startswith("class "):
            class_name = re.search(r"class (\w+)", line).group(1)
            if class_name:
                self.add_import(class_name, self.line_number, line)

        # Evaluate Functions.
        elif line.startswith("def "):
            function_name = re.search(r"def (\w+)", line).group(1)
            if function_name:
                self.add_import(function_name, self.line_number, line)

        # Evaluate direct assignments.
        elif "=" in line:
            assignment = re.search(r"(\w+)\s*=", line).group(1)
            if assignment:
                self.add_import(assignment, self.line_number, line)

        self.line_number += 1

    def validate(self) -> bool:
        """
        Run Validation.
        """
        try:
            while True:
                self.next()
        except StopIteration:
            pass

        # Filter collisions for those with more than one value.
        self.collisions = {k: v for k, v in self.collisions.items() if len(v) > 1}

        # Return True if no collisions are found.
        return not bool(self.collisions)
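A minimal usage sketch of the validator above (the two-line input module here is hypothetical, not from this diff): the plugin feeds generated source through ModuleValidator and fails when two top-level definitions share a name.

from io import StringIO
from betterproto.plugin.module_validation import ModuleValidator

source = StringIO("import os\nos = 1\n")
validator = ModuleValidator(iter(source))
assert validator.validate() is False  # "os" is defined twice -> collision
assert "os" in validator.collisions   # both offending lines are recorded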
@ -37,6 +37,12 @@ from .models import (
    is_map,
    is_oneof,
)
from .typing_compiler import (
    DirectImportTypingCompiler,
    NoTyping310TypingCompiler,
    TypingCompiler,
    TypingImportTypingCompiler,
)


def traverse(
@ -98,6 +104,28 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
            output_package_name
        ].pydantic_dataclasses = True

    # Gather any typing generation options.
    typing_opts = [
        opt[len("typing.") :] for opt in plugin_options if opt.startswith("typing.")
    ]

    if len(typing_opts) > 1:
        raise ValueError("Multiple typing options provided")
    # Set the compiler type.
    typing_opt = typing_opts[0] if typing_opts else "direct"
    if typing_opt == "direct":
        request_data.output_packages[
            output_package_name
        ].typing_compiler = DirectImportTypingCompiler()
    elif typing_opt == "root":
        request_data.output_packages[
            output_package_name
        ].typing_compiler = TypingImportTypingCompiler()
    elif typing_opt == "310":
        request_data.output_packages[
            output_package_name
        ].typing_compiler = NoTyping310TypingCompiler()

    # Read Messages and Enums
    # We need to read Messages before Services so that we can
    # get the references to input/output messages for each service
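Restating the option handling above as a standalone snippet (the option values are illustrative): only options prefixed with "typing." are considered, and at most one may be given.

plugin_options = ["pydantic_dataclasses", "typing.310"]
typing_opts = [
    opt[len("typing.") :] for opt in plugin_options if opt.startswith("typing.")
]
assert typing_opts == ["310"]  # selects NoTyping310TypingCompiler
# Passing both "typing.direct" and "typing.310" would raise ValueError.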
@ -115,7 +143,7 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
    for output_package_name, output_package in request_data.output_packages.items():
        for proto_input_file in output_package.input_files:
            for index, service in enumerate(proto_input_file.service):
                read_protobuf_service(service, index, output_package)
                read_protobuf_service(proto_input_file, service, index, output_package)

    # Generate output files
    output_paths: Set[pathlib.Path] = set()
@ -166,6 +194,7 @@ def _make_one_of_field_compiler(
        parent=parent,
        proto_obj=proto_obj,
        path=path,
        typing_compiler=output_package.typing_compiler,
    )


@ -181,7 +210,11 @@ def read_protobuf_type(
        return
    # Process Message
    message_data = MessageCompiler(
        source_file=source_file, parent=output_package, proto_obj=item, path=path
        source_file=source_file,
        parent=output_package,
        proto_obj=item,
        path=path,
        typing_compiler=output_package.typing_compiler,
    )
    for index, field in enumerate(item.field):
        if is_map(field, item):
@ -190,6 +223,7 @@ def read_protobuf_type(
                parent=message_data,
                proto_obj=field,
                path=path + [2, index],
                typing_compiler=output_package.typing_compiler,
            )
        elif is_oneof(field):
            _make_one_of_field_compiler(
@ -201,21 +235,35 @@ def read_protobuf_type(
                parent=message_data,
                proto_obj=field,
                path=path + [2, index],
                typing_compiler=output_package.typing_compiler,
            )
    elif isinstance(item, EnumDescriptorProto):
        # Enum
        EnumDefinitionCompiler(
            source_file=source_file, parent=output_package, proto_obj=item, path=path
            source_file=source_file,
            parent=output_package,
            proto_obj=item,
            path=path,
            typing_compiler=output_package.typing_compiler,
        )


def read_protobuf_service(
    service: ServiceDescriptorProto, index: int, output_package: OutputTemplate
    source_file: FileDescriptorProto,
    service: ServiceDescriptorProto,
    index: int,
    output_package: OutputTemplate,
) -> None:
    service_data = ServiceCompiler(
        parent=output_package, proto_obj=service, path=[6, index]
        source_file=source_file,
        parent=output_package,
        proto_obj=service,
        path=[6, index],
    )
    for j, method in enumerate(service.method):
        ServiceMethodCompiler(
            parent=service_data, proto_obj=method, path=[6, index, 2, j]
            source_file=source_file,
            parent=service_data,
            proto_obj=method,
            path=[6, index, 2, j],
        )
173
src/betterproto/plugin/typing_compiler.py
Normal file
@ -0,0 +1,173 @@
import abc
from collections import defaultdict
from dataclasses import (
    dataclass,
    field,
)
from typing import (
    Dict,
    Iterator,
    Optional,
    Set,
)


class TypingCompiler(metaclass=abc.ABCMeta):
    @abc.abstractmethod
    def optional(self, type: str) -> str:
        raise NotImplementedError()

    @abc.abstractmethod
    def list(self, type: str) -> str:
        raise NotImplementedError()

    @abc.abstractmethod
    def dict(self, key: str, value: str) -> str:
        raise NotImplementedError()

    @abc.abstractmethod
    def union(self, *types: str) -> str:
        raise NotImplementedError()

    @abc.abstractmethod
    def iterable(self, type: str) -> str:
        raise NotImplementedError()

    @abc.abstractmethod
    def async_iterable(self, type: str) -> str:
        raise NotImplementedError()

    @abc.abstractmethod
    def async_iterator(self, type: str) -> str:
        raise NotImplementedError()

    @abc.abstractmethod
    def imports(self) -> Dict[str, Optional[Set[str]]]:
        """
        Returns either the direct import as a key with none as value, or a set of
        values to import from the key.
        """
        raise NotImplementedError()

    def import_lines(self) -> Iterator:
        imports = self.imports()
        for key, value in imports.items():
            if value is None:
                yield f"import {key}"
            else:
                yield f"from {key} import ("
                for v in sorted(value):
                    yield f"    {v},"
                yield ")"


@dataclass
class DirectImportTypingCompiler(TypingCompiler):
    _imports: Dict[str, Set[str]] = field(default_factory=lambda: defaultdict(set))

    def optional(self, type: str) -> str:
        self._imports["typing"].add("Optional")
        return f"Optional[{type}]"

    def list(self, type: str) -> str:
        self._imports["typing"].add("List")
        return f"List[{type}]"

    def dict(self, key: str, value: str) -> str:
        self._imports["typing"].add("Dict")
        return f"Dict[{key}, {value}]"

    def union(self, *types: str) -> str:
        self._imports["typing"].add("Union")
        return f"Union[{', '.join(types)}]"

    def iterable(self, type: str) -> str:
        self._imports["typing"].add("Iterable")
        return f"Iterable[{type}]"

    def async_iterable(self, type: str) -> str:
        self._imports["typing"].add("AsyncIterable")
        return f"AsyncIterable[{type}]"

    def async_iterator(self, type: str) -> str:
        self._imports["typing"].add("AsyncIterator")
        return f"AsyncIterator[{type}]"

    def imports(self) -> Dict[str, Optional[Set[str]]]:
        return {k: v if v else None for k, v in self._imports.items()}


@dataclass
class TypingImportTypingCompiler(TypingCompiler):
    _imported: bool = False

    def optional(self, type: str) -> str:
        self._imported = True
        return f"typing.Optional[{type}]"

    def list(self, type: str) -> str:
        self._imported = True
        return f"typing.List[{type}]"

    def dict(self, key: str, value: str) -> str:
        self._imported = True
        return f"typing.Dict[{key}, {value}]"

    def union(self, *types: str) -> str:
        self._imported = True
        return f"typing.Union[{', '.join(types)}]"

    def iterable(self, type: str) -> str:
        self._imported = True
        return f"typing.Iterable[{type}]"

    def async_iterable(self, type: str) -> str:
        self._imported = True
        return f"typing.AsyncIterable[{type}]"

    def async_iterator(self, type: str) -> str:
        self._imported = True
        return f"typing.AsyncIterator[{type}]"

    def imports(self) -> Dict[str, Optional[Set[str]]]:
        if self._imported:
            return {"typing": None}
        return {}


@dataclass
class NoTyping310TypingCompiler(TypingCompiler):
    _imports: Dict[str, Set[str]] = field(default_factory=lambda: defaultdict(set))

    @staticmethod
    def _fmt(type: str) -> str:  # for now this is necessary till 3.14
        if type.startswith('"'):
            return type[1:-1]
        return type

    def optional(self, type: str) -> str:
        return f'"{self._fmt(type)} | None"'

    def list(self, type: str) -> str:
        return f'"list[{self._fmt(type)}]"'

    def dict(self, key: str, value: str) -> str:
        return f'"dict[{key}, {self._fmt(value)}]"'

    def union(self, *types: str) -> str:
        return f'"{" | ".join(map(self._fmt, types))}"'

    def iterable(self, type: str) -> str:
        self._imports["collections.abc"].add("Iterable")
        return f'"Iterable[{type}]"'

    def async_iterable(self, type: str) -> str:
        self._imports["collections.abc"].add("AsyncIterable")
        return f'"AsyncIterable[{type}]"'

    def async_iterator(self, type: str) -> str:
        self._imports["collections.abc"].add("AsyncIterator")
        return f'"AsyncIterator[{type}]"'

    def imports(self) -> Dict[str, Optional[Set[str]]]:
        return {k: v if v else None for k, v in self._imports.items()}
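Side by side, the three strategies above render the same annotation differently; a quick sketch using only the classes in this file:

direct = DirectImportTypingCompiler()
rooted = TypingImportTypingCompiler()
modern = NoTyping310TypingCompiler()

assert direct.optional("int") == "Optional[int]"
assert rooted.optional("int") == "typing.Optional[int]"
assert modern.optional("int") == '"int | None"'

# Each compiler also tracks what it used, so the header template can emit
# exactly the imports the annotations need:
assert dict(direct.imports()) == {"typing": {"Optional"}}
assert rooted.imports() == {"typing": None}
assert modern.imports() == {}  # PEP 604 syntax needs no typing import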
57
src/betterproto/templates/header.py.j2
Normal file
@ -0,0 +1,57 @@
# Generated by the protocol buffer compiler. DO NOT EDIT!
# sources: {{ ', '.join(output_file.input_filenames) }}
# plugin: python-betterproto
# This file has been @generated

__all__ = (
{%- for enum in output_file.enums -%}
    "{{ enum.py_name }}",
{%- endfor -%}
{%- for message in output_file.messages -%}
    "{{ message.py_name }}",
{%- endfor -%}
{%- for service in output_file.services -%}
    "{{ service.py_name }}Stub",
    "{{ service.py_name }}Base",
{%- endfor -%}
)

{% for i in output_file.python_module_imports|sort %}
import {{ i }}
{% endfor %}

{% if output_file.pydantic_dataclasses %}
from pydantic.dataclasses import dataclass
{%- else -%}
from dataclasses import dataclass
{% endif %}

{% if output_file.datetime_imports %}
from datetime import {% for i in output_file.datetime_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}

{% endif %}
{% set typing_imports = output_file.typing_compiler.imports() %}
{% if typing_imports %}
{% for line in output_file.typing_compiler.import_lines() %}
{{ line }}
{% endfor %}
{% endif %}

{% if output_file.pydantic_imports %}
from pydantic import {% for i in output_file.pydantic_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}

{% endif %}

import betterproto
{% if output_file.services %}
from betterproto.grpc.grpclib_server import ServiceBase
import grpclib
{% endif %}

{% if output_file.imports_type_checking_only %}
from typing import TYPE_CHECKING

if TYPE_CHECKING:
{% for i in output_file.imports_type_checking_only|sort %}    {{ i }}
{% endfor %}
{% endif %}
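For the tests/inputs/enum module, the __all__ block above renders to the tuple that test_all_definition.py (further down in this diff) asserts:

__all__ = ("Choice", "ArithmeticOperator", "Test")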
@ -1,53 +1,3 @@
# Generated by the protocol buffer compiler. DO NOT EDIT!
# sources: {{ ', '.join(output_file.input_filenames) }}
# plugin: python-betterproto
# This file has been @generated
{% for i in output_file.python_module_imports|sort %}
import {{ i }}
{% endfor %}

{% if output_file.pydantic_dataclasses %}
from typing import TYPE_CHECKING
if TYPE_CHECKING:
    from dataclasses import dataclass
else:
    from pydantic.dataclasses import dataclass
{%- else -%}
from dataclasses import dataclass
{% endif %}

{% if output_file.datetime_imports %}
from datetime import {% for i in output_file.datetime_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}

{% endif%}
{% if output_file.typing_imports %}
from typing import {% for i in output_file.typing_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}

{% endif %}

{% if output_file.pydantic_imports %}
from pydantic import {% for i in output_file.pydantic_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}

{% endif %}

import betterproto
{% if output_file.services %}
from betterproto.grpc.grpclib_server import ServiceBase
import grpclib
{% endif %}

{% for i in output_file.imports|sort %}
{{ i }}
{% endfor %}

{% if output_file.imports_type_checking_only %}
from typing import TYPE_CHECKING

if TYPE_CHECKING:
{% for i in output_file.imports_type_checking_only|sort %}    {{ i }}
{% endfor %}
{% endif %}

{% if output_file.enums %}{% for enum in output_file.enums %}
class {{ enum.py_name }}(betterproto.Enum):
    {% if enum.comment %}
@ -62,11 +12,22 @@ class {{ enum.py_name }}(betterproto.Enum):
    {% endif %}
{% endfor %}

    {% if output_file.pydantic_dataclasses %}
    @classmethod
    def __get_pydantic_core_schema__(cls, _source_type, _handler):
        from pydantic_core import core_schema

        return core_schema.int_schema(ge=0)
    {% endif %}

{% endfor %}
{% endif %}
{% for message in output_file.messages %}
{% if output_file.pydantic_dataclasses %}
@dataclass(eq=False, repr=False, config={"extra": "forbid"})
{% else %}
@dataclass(eq=False, repr=False)
{% endif %}
class {{ message.py_name }}(betterproto.Message):
    {% if message.comment %}
    {{ message.comment }}
@ -96,7 +57,7 @@ class {{ message.py_name }}(betterproto.Message):
    {% endif %}

    {% if output_file.pydantic_dataclasses and message.has_oneof_fields %}
    @root_validator()
    @model_validator(mode='after')
    def check_oneof(cls, values):
        return cls._validate_field_groups(values)
    {% endif %}
@ -113,20 +74,24 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
    {% for method in service.methods %}
    async def {{ method.py_name }}(self
        {%- if not method.client_streaming -%}
        {%- if method.py_input_message -%}, {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}"{%- endif -%}
        , {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}"
        {%- else -%}
        {# Client streaming: need a request iterator instead #}
        , {{ method.py_input_message_param }}_iterator: Union[AsyncIterable["{{ method.py_input_message_type }}"], Iterable["{{ method.py_input_message_type }}"]]
        , {{ method.py_input_message_param }}_iterator: "{{ output_file.typing_compiler.union(output_file.typing_compiler.async_iterable(method.py_input_message_type), output_file.typing_compiler.iterable(method.py_input_message_type)) }}"
        {%- endif -%}
        ,
        *
        , timeout: Optional[float] = None
        , deadline: Optional["Deadline"] = None
        , metadata: Optional["MetadataLike"] = None
    ) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}:
        , timeout: {{ output_file.typing_compiler.optional("float") }} = None
        , deadline: {{ output_file.typing_compiler.optional('"Deadline"') }} = None
        , metadata: {{ output_file.typing_compiler.optional('"MetadataLike"') }} = None
    ) -> "{% if method.server_streaming %}{{ output_file.typing_compiler.async_iterator(method.py_output_message_type ) }}{% else %}{{ method.py_output_message_type }}{% endif %}":
        {% if method.comment %}
        {{ method.comment }}

        {% endif %}
        {% if method.proto_obj.options.deprecated %}
        warnings.warn("{{ service.py_name }}.{{ method.py_name }} is deprecated", DeprecationWarning)

        {% endif %}
        {% if method.server_streaming %}
        {% if method.client_streaming %}
@ -178,6 +143,10 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
    {% endfor %}
{% endfor %}

{% for i in output_file.imports_end %}
{{ i }}
{% endfor %}

{% for service in output_file.services %}
class {{ service.py_name }}Base(ServiceBase):
    {% if service.comment %}
@ -188,12 +157,12 @@ class {{ service.py_name }}Base(ServiceBase):
    {% for method in service.methods %}
    async def {{ method.py_name }}(self
        {%- if not method.client_streaming -%}
        {%- if method.py_input_message -%}, {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}"{%- endif -%}
        , {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}"
        {%- else -%}
        {# Client streaming: need a request iterator instead #}
        , {{ method.py_input_message_param }}_iterator: AsyncIterator["{{ method.py_input_message_type }}"]
        , {{ method.py_input_message_param }}_iterator: {{ output_file.typing_compiler.async_iterator(method.py_input_message_type) }}
        {%- endif -%}
    ) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}:
    ) -> {% if method.server_streaming %}{{ output_file.typing_compiler.async_iterator(method.py_output_message_type) }}{% else %}"{{ method.py_output_message_type }}"{% endif %}:
        {% if method.comment %}
        {{ method.comment }}

@ -225,7 +194,7 @@ class {{ service.py_name }}Base(ServiceBase):

    {% endfor %}

    def __mapping__(self) -> Dict[str, grpclib.const.Handler]:
    def __mapping__(self) -> {{ output_file.typing_compiler.dict("str", "grpclib.const.Handler") }}:
        return {
        {% for method in service.methods %}
            "{{ method.route }}": grpclib.const.Handler(
@ -246,11 +215,3 @@ class {{ service.py_name }}Base(ServiceBase):
        }

{% endfor %}

{% if output_file.pydantic_dataclasses %}
{% for message in output_file.messages %}
{% if message.has_message_field %}
{{ message.py_name }}.__pydantic_model__.update_forward_refs()  # type: ignore
{% endif %}
{% endfor %}
{% endif %}
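The @root_validator() to @model_validator(mode='after') swap above tracks pydantic's v2 validator API. A standalone sketch of the same pattern (the Point model is hypothetical, not part of this diff):

from pydantic import BaseModel, model_validator


class Point(BaseModel):
    x: int
    y: int

    @model_validator(mode="after")
    def check_quadrant(self):
        # mode="after" runs on the fully constructed model instance.
        if self.x < 0 or self.y < 0:
            raise ValueError("expected a point in the first quadrant")
        return self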
56
src/betterproto/utils.py
Normal file
@ -0,0 +1,56 @@
from __future__ import annotations

from typing import (
    Any,
    Callable,
    Generic,
    Optional,
    Type,
    TypeVar,
)

from typing_extensions import (
    Concatenate,
    ParamSpec,
    Self,
)


SelfT = TypeVar("SelfT")
P = ParamSpec("P")
HybridT = TypeVar("HybridT", covariant=True)


class hybridmethod(Generic[SelfT, P, HybridT]):
    def __init__(
        self,
        func: Callable[
            Concatenate[type[SelfT], P], HybridT
        ],  # Must be the classmethod version
    ):
        self.cls_func = func
        self.__doc__ = func.__doc__

    def instancemethod(self, func: Callable[Concatenate[SelfT, P], HybridT]) -> Self:
        self.instance_func = func
        return self

    def __get__(
        self, instance: Optional[SelfT], owner: Type[SelfT]
    ) -> Callable[P, HybridT]:
        if instance is None or self.instance_func is None:
            # either bound to the class, or no instance method available
            return self.cls_func.__get__(owner, None)
        return self.instance_func.__get__(instance, owner)


T_co = TypeVar("T_co")
TT_co = TypeVar("TT_co", bound="type[Any]")


class classproperty(Generic[TT_co, T_co]):
    def __init__(self, func: Callable[[TT_co], T_co]):
        self.__func__ = func

    def __get__(self, instance: Any, type: TT_co) -> T_co:
        return self.__func__(type)
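A hedged sketch of how the two descriptors above are meant to be used (the Example class is illustrative only): hybridmethod dispatches to a class-level or instance-level implementation depending on how it is accessed, and classproperty exposes a computed attribute on the class itself.

class Example:
    @hybridmethod
    def parse(cls) -> str:
        return "called on the class"

    @parse.instancemethod
    def parse(self) -> str:
        return "called on an instance"

    @classproperty
    def kind(cls) -> str:
        return cls.__name__


assert Example.parse() == "called on the class"
assert Example().parse() == "called on an instance"
assert Example.kind == "Example"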
@ -4,17 +4,6 @@ import sys
import pytest


def pytest_addoption(parser):
    parser.addoption(
        "--repeat", type=int, default=1, help="repeat the operation multiple times"
    )


@pytest.fixture(scope="session")
def repeat(request):
    return request.config.getoption("repeat")


@pytest.fixture
def reset_sys_path():
    original = copy.deepcopy(sys.path)
@ -108,6 +108,7 @@ async def generate_test_case_output(
        print(
            f"\033[31;1;4mFailed to generate reference output for {test_case_name!r}\033[0m"
        )
        print(ref_err.decode())

    if verbose:
        if ref_out:
@ -126,6 +127,7 @@ async def generate_test_case_output(
        print(
            f"\033[31;1;4mFailed to generate plugin output for {test_case_name!r}\033[0m"
        )
        print(plg_err.decode())

    if verbose:
        if plg_out:
@ -146,6 +148,7 @@ async def generate_test_case_output(
        print(
            f"\033[31;1;4mFailed to generate plugin (pydantic compatible) output for {test_case_name!r}\033[0m"
        )
        print(plg_err_pyd.decode())

    if verbose:
        if plg_out_pyd:
@ -1,5 +1,4 @@
import asyncio
import sys
import uuid

import grpclib
@ -27,12 +26,12 @@ async def _test_client(client: ThingServiceClient, name="clean room", **kwargs):

def _assert_request_meta_received(deadline, metadata):
    def server_side_test(stream):
        assert stream.deadline._timestamp == pytest.approx(
            deadline._timestamp, 1
        ), "The provided deadline should be received serverside"
        assert (
            stream.metadata["authorization"] == metadata["authorization"]
        ), "The provided authorization metadata should be received serverside"
        assert stream.deadline._timestamp == pytest.approx(deadline._timestamp, 1), (
            "The provided deadline should be received serverside"
        )
        assert stream.metadata["authorization"] == metadata["authorization"], (
            "The provided authorization metadata should be received serverside"
        )

    return server_side_test

@ -91,9 +90,6 @@ async def test_trailer_only_error_stream_unary(


@pytest.mark.asyncio
@pytest.mark.skipif(
    sys.version_info < (3, 8), reason="async mock spy does works for python3.8+"
)
async def test_service_call_mutable_defaults(mocker):
    async with ChannelFor([ThingService()]) as channel:
        client = ThingServiceClient(channel)
@ -269,6 +265,30 @@ async def test_async_gen_for_stream_stream_request():
        else:
            # No more things to send make sure channel is closed
            request_chan.close()
    assert response_index == len(
        expected_things
    ), "Didn't receive all expected responses"
    assert response_index == len(expected_things), (
        "Didn't receive all expected responses"
    )


@pytest.mark.asyncio
async def test_stream_unary_with_empty_iterable():
    things = []  # empty

    async with ChannelFor([ThingService()]) as channel:
        client = ThingServiceClient(channel)
        requests = [DoThingRequest(name) for name in things]
        response = await client.do_many_things(requests)
        assert len(response.names) == 0


@pytest.mark.asyncio
async def test_stream_stream_with_empty_iterable():
    things = []  # empty

    async with ChannelFor([ThingService()]) as channel:
        client = ThingServiceClient(channel)
        requests = [GetThingRequest(name) for name in things]
        responses = [
            response async for response in client.get_different_things(requests)
        ]
        assert len(responses) == 0
@ -27,7 +27,7 @@ class ThingService:
    async def do_many_things(
        self, stream: "grpclib.server.Stream[DoThingRequest, DoThingResponse]"
    ):
        thing_names = [request.name for request in stream]
        thing_names = [request.name async for request in stream]
        if self.test_hook is not None:
            self.test_hook(stream)
        await stream.send_message(DoThingResponse(thing_names))
@ -10,10 +10,15 @@ def test_value():


def test_pydantic_no_value():
    with pytest.raises(ValueError):
        TestPyd()
    message = TestPyd()
    assert not message.value, "Boolean is False by default"


def test_pydantic_value():
    message = Test(value=False)
    message = TestPyd(value=False)
    assert not message.value


def test_pydantic_bad_value():
    with pytest.raises(ValueError):
        TestPyd(value=123)
@ -4,20 +4,20 @@ from tests.output_betterproto.casing import Test

def test_message_attributes():
    message = Test()
    assert hasattr(
        message, "snake_case_message"
    ), "snake_case field name is same in python"
    assert hasattr(message, "snake_case_message"), (
        "snake_case field name is same in python"
    )
    assert hasattr(message, "camel_case"), "CamelCase field is snake_case in python"
    assert hasattr(message, "uppercase"), "UPPERCASE field is lowercase in python"


def test_message_casing():
    assert hasattr(
        casing, "SnakeCaseMessage"
    ), "snake_case Message name is converted to CamelCase in python"
    assert hasattr(casing, "SnakeCaseMessage"), (
        "snake_case Message name is converted to CamelCase in python"
    )


def test_enum_casing():
    assert hasattr(
        casing, "MyEnum"
    ), "snake_case Enum name is converted to CamelCase in python"
    assert hasattr(casing, "MyEnum"), (
        "snake_case Enum name is converted to CamelCase in python"
    )
@ -2,13 +2,13 @@ import tests.output_betterproto.casing_inner_class as casing_inner_class


def test_message_casing_inner_class_name():
    assert hasattr(
        casing_inner_class, "TestInnerClass"
    ), "Inline defined Message is correctly converted to CamelCase"
    assert hasattr(casing_inner_class, "TestInnerClass"), (
        "Inline defined Message is correctly converted to CamelCase"
    )


def test_message_casing_inner_class_attributes():
    message = casing_inner_class.Test()
    assert hasattr(
        message.inner, "old_exp"
    ), "Inline defined Message attribute is snake_case"
    assert hasattr(message.inner, "old_exp"), (
        "Inline defined Message attribute is snake_case"
    )
@ -3,12 +3,12 @@ from tests.output_betterproto.casing_message_field_uppercase import Test

def test_message_casing():
    message = Test()
    assert hasattr(
        message, "uppercase"
    ), "UPPERCASE attribute is converted to 'uppercase' in python"
    assert hasattr(
        message, "uppercase_v2"
    ), "UPPERCASE_V2 attribute is converted to 'uppercase_v2' in python"
    assert hasattr(
        message, "upper_camel_case"
    ), "UPPER_CAMEL_CASE attribute is converted to upper_camel_case in python"
    assert hasattr(message, "uppercase"), (
        "UPPERCASE attribute is converted to 'uppercase' in python"
    )
    assert hasattr(message, "uppercase_v2"), (
        "UPPERCASE_V2 attribute is converted to 'uppercase_v2' in python"
    )
    assert hasattr(message, "upper_camel_case"), (
        "UPPER_CAMEL_CASE attribute is converted to upper_camel_case in python"
    )
@ -12,3 +12,10 @@ message Message {
  option deprecated = true;
  string value = 1;
}

message Empty {}

service TestService {
  rpc func(Empty) returns (Empty);
  rpc deprecated_func(Empty) returns (Empty) { option deprecated = true; };
}
44
tests/inputs/documentation/documentation.proto
Normal file
@ -0,0 +1,44 @@
syntax = "proto3";
package documentation;

// Documentation of message 1
// other line 1

// Documentation of message 2
// other line 2
message Test { // Documentation of message 3
  // Documentation of field 1
  // other line 1

  // Documentation of field 2
  // other line 2
  uint32 x = 1; // Documentation of field 3
}

// Documentation of enum 1
// other line 1

// Documentation of enum 2
// other line 2
enum Enum { // Documentation of enum 3
  // Documentation of variant 1
  // other line 1

  // Documentation of variant 2
  // other line 2
  Enum_Variant = 0; // Documentation of variant 3
}

// Documentation of service 1
// other line 1

// Documentation of service 2
// other line 2
service Service { // Documentation of service 3
  // Documentation of method 1
  // other line 1

  // Documentation of method 2
  // other line 2
  rpc get(Test) returns (Test); // Documentation of method 3
}
@ -15,3 +15,11 @@ enum Choice {
  FOUR = 4;
  THREE = 3;
}

// A "C" like enum with the enum name prefixed onto members, these should be stripped
enum ArithmeticOperator {
  ARITHMETIC_OPERATOR_NONE = 0;
  ARITHMETIC_OPERATOR_PLUS = 1;
  ARITHMETIC_OPERATOR_MINUS = 2;
  ARITHMETIC_OPERATOR_0_PREFIXED = 3;
}
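The generated Python enum strips the redundant ARITHMETIC_OPERATOR_ prefix; because "0_PREFIXED" would not be a valid identifier, a leading underscore is kept. A short sketch of the resulting members (asserted by test_renamed_enum_members below):

from tests.output_betterproto.enum import ArithmeticOperator

assert ArithmeticOperator.PLUS == 1
assert ArithmeticOperator._0_PREFIXED == 3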
@ -1,4 +1,5 @@
from tests.output_betterproto.enum import (
    ArithmeticOperator,
    Choice,
    Test,
)
@ -26,9 +27,9 @@ def test_enum_is_comparable_with_int():


def test_enum_to_dict():
    assert (
        "choice" not in Test(choice=Choice.ZERO).to_dict()
    ), "Default enum value is not serialized"
    assert "choice" not in Test(choice=Choice.ZERO).to_dict(), (
        "Default enum value is not serialized"
    )
    assert (
        Test(choice=Choice.ZERO).to_dict(include_default_values=True)["choice"]
        == "ZERO"
@ -82,3 +83,32 @@ def test_repeated_enum_with_non_list_iterables_to_dict():
        yield Choice.THREE

    assert Test(choices=enum_generator()).to_dict()["choices"] == ["ONE", "THREE"]


def test_enum_mapped_on_parse():
    # test default value
    b = Test().parse(bytes(Test()))
    assert b.choice.name == Choice.ZERO.name
    assert b.choices == []

    # test non default value
    a = Test().parse(bytes(Test(choice=Choice.ONE)))
    assert a.choice.name == Choice.ONE.name
    assert b.choices == []

    # test repeated
    c = Test().parse(bytes(Test(choices=[Choice.THREE, Choice.FOUR])))
    assert c.choices[0].name == Choice.THREE.name
    assert c.choices[1].name == Choice.FOUR.name

    # bonus: defaults after empty init are also mapped
    assert Test().choice.name == Choice.ZERO.name


def test_renamed_enum_members():
    assert set(ArithmeticOperator.__members__) == {
        "NONE",
        "PLUS",
        "MINUS",
        "_0_PREFIXED",
    }
@ -1,5 +1,6 @@
syntax = "proto3";

import "google/protobuf/timestamp.proto";
package google_impl_behavior_equivalence;

message Foo { int64 bar = 1; }
@ -12,6 +13,10 @@ message Test {
  }
}

message Spam {
  google.protobuf.Timestamp ts = 1;
}

message Request { Empty foo = 1; }

message Empty {}
@ -1,17 +1,25 @@
from datetime import (
    datetime,
    timezone,
)

import pytest
from google.protobuf import json_format
from google.protobuf.timestamp_pb2 import Timestamp

import betterproto
from tests.output_betterproto.google_impl_behavior_equivalence import (
    Empty,
    Foo,
    Request,
    Spam,
    Test,
)
from tests.output_reference.google_impl_behavior_equivalence.google_impl_behavior_equivalence_pb2 import (
    Empty as ReferenceEmpty,
    Foo as ReferenceFoo,
    Request as ReferenceRequest,
    Spam as ReferenceSpam,
    Test as ReferenceTest,
)

@ -59,6 +67,19 @@ def test_bytes_are_the_same_for_oneof():
    assert isinstance(message_reference2.foo, ReferenceFoo)


@pytest.mark.parametrize("dt", (datetime.min.replace(tzinfo=timezone.utc),))
def test_datetime_clamping(dt):  # see #407
    ts = Timestamp()
    ts.FromDatetime(dt)
    assert bytes(Spam(dt)) == ReferenceSpam(ts=ts).SerializeToString()
    message_bytes = bytes(Spam(dt))

    assert (
        Spam().parse(message_bytes).ts.timestamp()
        == ReferenceSpam.FromString(message_bytes).ts.seconds
    )


def test_empty_message_field():
    message = Request()
    reference_message = ReferenceRequest()
@ -26,5 +26,5 @@ import "other.proto";
// (root: Test & RootPackageMessage) <-------> (other: OtherPackageMessage)
message Test {
  RootPackageMessage message = 1;
  other.OtherPackageMessage other = 2;
  other.OtherPackageMessage other_value = 2;
}
7
tests/inputs/invalid_field/invalid_field.proto
Normal file
@ -0,0 +1,7 @@
syntax = "proto3";

package invalid_field;

message Test {
  int32 x = 1;
}
17
tests/inputs/invalid_field/test_invalid_field.py
Normal file
@ -0,0 +1,17 @@
import pytest


def test_invalid_field():
    from tests.output_betterproto.invalid_field import Test

    with pytest.raises(TypeError):
        Test(unknown_field=12)


def test_invalid_field_pydantic():
    from pydantic import ValidationError

    from tests.output_betterproto_pydantic.invalid_field import Test

    with pytest.raises(ValidationError):
        Test(unknown_field=12)
@ -2,6 +2,10 @@ syntax = "proto3";

package oneof;

message MixedDrink {
  int32 shots = 1;
}

message Test {
  oneof foo {
    int32 pitied = 1;
@ -13,6 +17,7 @@ message Test {
  oneof bar {
    int32 drinks = 11;
    string bar_name = 12;
    MixedDrink mixed_drink = 13;
  }
}

@ -1,5 +1,10 @@
import pytest

import betterproto
from tests.output_betterproto.oneof import Test
from tests.output_betterproto.oneof import (
    MixedDrink,
    Test,
)
from tests.output_betterproto_pydantic.oneof import Test as TestPyd
from tests.util import get_test_case_json_data

@ -19,3 +24,20 @@ def test_which_name():
def test_which_count_pyd():
    message = TestPyd(pitier="Mr. T", just_a_regular_field=2, bar_name="a_bar")
    assert betterproto.which_one_of(message, "foo") == ("pitier", "Mr. T")


def test_oneof_constructor_assign():
    message = Test(mixed_drink=MixedDrink(shots=42))
    field, value = betterproto.which_one_of(message, "bar")
    assert field == "mixed_drink"
    assert value.shots == 42


# Issue #305:
@pytest.mark.xfail
def test_oneof_nested_assign():
    message = Test()
    message.mixed_drink.shots = 42
    field, value = betterproto.which_one_of(message, "bar")
    assert field == "mixed_drink"
    assert value.shots == 42
@ -41,3 +41,8 @@ def test_null_fields_json():
        "test8": None,
        "test9": None,
    }


def test_unset_access():  # see #523
    assert Test().test1 is None
    assert Test(test1=None).test1 is None
2
tests/streams/delimited_messages.in
Normal file
@ -0,0 +1,2 @@
•šï:bTesting•šï:bTesting
38
tests/streams/java/.gitignore
vendored
Normal file
@ -0,0 +1,38 @@
### Output ###
target/
!.mvn/wrapper/maven-wrapper.jar
!**/src/main/**/target/
!**/src/test/**/target/
dependency-reduced-pom.xml
MANIFEST.MF

### IntelliJ IDEA ###
.idea/
*.iws
*.iml
*.ipr

### Eclipse ###
.apt_generated
.classpath
.factorypath
.project
.settings
.springBeans
.sts4-cache

### NetBeans ###
/nbproject/private/
/nbbuild/
/dist/
/nbdist/
/.nb-gradle/
build/
!**/src/main/**/build/
!**/src/test/**/build/

### VS Code ###
.vscode/

### Mac OS ###
.DS_Store
94
tests/streams/java/pom.xml
Normal file
@ -0,0 +1,94 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>betterproto</groupId>
    <artifactId>compatibility-test</artifactId>
    <version>1.0-SNAPSHOT</version>
    <packaging>jar</packaging>

    <properties>
        <maven.compiler.source>11</maven.compiler.source>
        <maven.compiler.target>11</maven.compiler.target>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
        <protobuf.version>3.23.4</protobuf.version>
    </properties>

    <dependencies>
        <dependency>
            <groupId>com.google.protobuf</groupId>
            <artifactId>protobuf-java</artifactId>
            <version>${protobuf.version}</version>
        </dependency>
    </dependencies>

    <build>
        <extensions>
            <extension>
                <groupId>kr.motd.maven</groupId>
                <artifactId>os-maven-plugin</artifactId>
                <version>1.7.1</version>
            </extension>
        </extensions>

        <plugins>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-shade-plugin</artifactId>
                <version>3.5.0</version>
                <executions>
                    <execution>
                        <phase>package</phase>
                        <goals>
                            <goal>shade</goal>
                        </goals>
                        <configuration>
                            <transformers>
                                <transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
                                    <mainClass>betterproto.CompatibilityTest</mainClass>
                                </transformer>
                            </transformers>
                        </configuration>
                    </execution>
                </executions>
            </plugin>

            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-jar-plugin</artifactId>
                <version>3.3.0</version>
                <configuration>
                    <archive>
                        <manifest>
                            <addClasspath>true</addClasspath>
                            <mainClass>betterproto.CompatibilityTest</mainClass>
                        </manifest>
                    </archive>
                </configuration>
            </plugin>

            <plugin>
                <groupId>org.xolstice.maven.plugins</groupId>
                <artifactId>protobuf-maven-plugin</artifactId>
                <version>0.6.1</version>
                <executions>
                    <execution>
                        <goals>
                            <goal>compile</goal>
                        </goals>
                    </execution>
                </executions>
                <configuration>
                    <protocArtifact>
                        com.google.protobuf:protoc:${protobuf.version}:exe:${os.detected.classifier}
                    </protocArtifact>
                </configuration>
            </plugin>
        </plugins>

        <finalName>${project.artifactId}</finalName>
    </build>

</project>
@ -0,0 +1,41 @@
package betterproto;

import java.io.IOException;

public class CompatibilityTest {
    public static void main(String[] args) throws IOException {
        if (args.length < 2)
            throw new RuntimeException("Attempted to run without the required arguments.");
        else if (args.length > 2)
            throw new RuntimeException(
                    "Attempted to run with more than the expected number of arguments (>1).");

        Tests tests = new Tests(args[1]);

        switch (args[0]) {
            case "single_varint":
                tests.testSingleVarint();
                break;

            case "multiple_varints":
                tests.testMultipleVarints();
                break;

            case "single_message":
                tests.testSingleMessage();
                break;

            case "multiple_messages":
                tests.testMultipleMessages();
                break;

            case "infinite_messages":
                tests.testInfiniteMessages();
                break;

            default:
                throw new RuntimeException(
                        "Attempted to run with unknown argument '" + args[0] + "'.");
        }
    }
}
115
tests/streams/java/src/main/java/betterproto/Tests.java
Normal file
@ -0,0 +1,115 @@
package betterproto;

import betterproto.nested.NestedOuterClass;
import betterproto.oneof.Oneof;

import com.google.protobuf.CodedInputStream;
import com.google.protobuf.CodedOutputStream;

import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;

public class Tests {
    String path;

    public Tests(String path) {
        this.path = path;
    }

    public void testSingleVarint() throws IOException {
        // Read in the Python-generated single varint file
        FileInputStream inputStream = new FileInputStream(path + "/py_single_varint.out");
        CodedInputStream codedInput = CodedInputStream.newInstance(inputStream);

        int value = codedInput.readUInt32();

        inputStream.close();

        // Write the value back to a file
        FileOutputStream outputStream = new FileOutputStream(path + "/java_single_varint.out");
        CodedOutputStream codedOutput = CodedOutputStream.newInstance(outputStream);

        codedOutput.writeUInt32NoTag(value);

        codedOutput.flush();
        outputStream.close();
    }

    public void testMultipleVarints() throws IOException {
        // Read in the Python-generated multiple varints file
        FileInputStream inputStream = new FileInputStream(path + "/py_multiple_varints.out");
        CodedInputStream codedInput = CodedInputStream.newInstance(inputStream);

        int value1 = codedInput.readUInt32();
        int value2 = codedInput.readUInt32();
        long value3 = codedInput.readUInt64();

        inputStream.close();

        // Write the values back to a file
        FileOutputStream outputStream = new FileOutputStream(path + "/java_multiple_varints.out");
        CodedOutputStream codedOutput = CodedOutputStream.newInstance(outputStream);

        codedOutput.writeUInt32NoTag(value1);
        codedOutput.writeUInt64NoTag(value2);
        codedOutput.writeUInt64NoTag(value3);

        codedOutput.flush();
        outputStream.close();
    }

    public void testSingleMessage() throws IOException {
        // Read in the Python-generated single message file
        FileInputStream inputStream = new FileInputStream(path + "/py_single_message.out");
        CodedInputStream codedInput = CodedInputStream.newInstance(inputStream);

        Oneof.Test message = Oneof.Test.parseFrom(codedInput);

        inputStream.close();

        // Write the message back to a file
        FileOutputStream outputStream = new FileOutputStream(path + "/java_single_message.out");
        CodedOutputStream codedOutput = CodedOutputStream.newInstance(outputStream);

        message.writeTo(codedOutput);

        codedOutput.flush();
        outputStream.close();
    }

    public void testMultipleMessages() throws IOException {
        // Read in the Python-generated multi-message file
        FileInputStream inputStream = new FileInputStream(path + "/py_multiple_messages.out");

        Oneof.Test oneof = Oneof.Test.parseDelimitedFrom(inputStream);
        NestedOuterClass.Test nested = NestedOuterClass.Test.parseDelimitedFrom(inputStream);

        inputStream.close();

        // Write the messages back to a file
        FileOutputStream outputStream = new FileOutputStream(path + "/java_multiple_messages.out");

        oneof.writeDelimitedTo(outputStream);
        nested.writeDelimitedTo(outputStream);

        outputStream.flush();
        outputStream.close();
    }

    public void testInfiniteMessages() throws IOException {
        // Read in as many messages as are present in the Python-generated file and write them back
        FileInputStream inputStream = new FileInputStream(path + "/py_infinite_messages.out");
        FileOutputStream outputStream = new FileOutputStream(path + "/java_infinite_messages.out");

        Oneof.Test current = Oneof.Test.parseDelimitedFrom(inputStream);
        while (current != null) {
            current.writeDelimitedTo(outputStream);
            current = Oneof.Test.parseDelimitedFrom(inputStream);
        }

        inputStream.close();
        outputStream.flush();
        outputStream.close();
    }
}
27
tests/streams/java/src/main/proto/betterproto/nested.proto
Normal file
@ -0,0 +1,27 @@
syntax = "proto3";

package nested;
option java_package = "betterproto.nested";

// A test message with a nested message inside of it.
message Test {
  // This is the nested type.
  message Nested {
    // Stores a simple counter.
    int32 count = 1;
  }
  // This is the nested enum.
  enum Msg {
    NONE = 0;
    THIS = 1;
  }

  Nested nested = 1;
  Sibling sibling = 2;
  Sibling sibling2 = 3;
  Msg msg = 4;
}

message Sibling {
  int32 foo = 1;
}
19
tests/streams/java/src/main/proto/betterproto/oneof.proto
Normal file
@ -0,0 +1,19 @@
syntax = "proto3";

package oneof;
option java_package = "betterproto.oneof";

message Test {
  oneof foo {
    int32 pitied = 1;
    string pitier = 2;
  }

  int32 just_a_regular_field = 3;

  oneof bar {
    int32 drinks = 11;
    string bar_name = 12;
  }
}

19
tests/test_all_definition.py
Normal file
@ -0,0 +1,19 @@
def test_all_definition():
    """
    Check that a compiled module defines __all__ with the right value.

    These modules have been chosen since they contain messages, services and enums.
    """
    import tests.output_betterproto.enum as enum
    import tests.output_betterproto.service as service

    assert service.__all__ == (
        "ThingType",
        "DoThingRequest",
        "DoThingResponse",
        "GetThingRequest",
        "GetThingResponse",
        "TestStub",
        "TestBase",
    )
    assert enum.__all__ == ("Choice", "ArithmeticOperator", "Test")
@ -2,9 +2,12 @@ import warnings

import pytest

from tests.mocks import MockChannel
from tests.output_betterproto.deprecated import (
    Empty,
    Message,
    Test,
    TestServiceStub,
)


@ -32,14 +35,27 @@ def test_message_with_deprecated_field(message):


def test_message_with_deprecated_field_not_set(message):
    with pytest.warns(None) as record:
    with warnings.catch_warnings():
        warnings.simplefilter("error")
        Test(value=10)

    assert not record


def test_message_with_deprecated_field_not_set_default(message):
    with pytest.warns(None) as record:
    with warnings.catch_warnings():
        warnings.simplefilter("error")
        _ = Test(value=10).message

    assert not record


@pytest.mark.asyncio
async def test_service_with_deprecated_method():
    stub = TestServiceStub(MockChannel([Empty(), Empty()]))

    with pytest.warns(DeprecationWarning) as record:
        await stub.deprecated_func(Empty())

    assert len(record) == 1
    assert str(record[0].message) == f"TestService.deprecated_func is deprecated"

    with warnings.catch_warnings():
        warnings.simplefilter("error")
        await stub.func(Empty())
37
tests/test_documentation.py
Normal file
@ -0,0 +1,37 @@
import ast
import inspect


def check(generated_doc: str, type: str) -> None:
    assert f"Documentation of {type} 1" in generated_doc
    assert "other line 1" in generated_doc
    assert f"Documentation of {type} 2" in generated_doc
    assert "other line 2" in generated_doc
    assert f"Documentation of {type} 3" in generated_doc


def test_documentation() -> None:
    from .output_betterproto.documentation import (
        Enum,
        ServiceBase,
        ServiceStub,
        Test,
    )

    check(Test.__doc__, "message")

    source = inspect.getsource(Test)
    tree = ast.parse(source)
    check(tree.body[0].body[2].value.value, "field")

    check(Enum.__doc__, "enum")

    source = inspect.getsource(Enum)
    tree = ast.parse(source)
    check(tree.body[0].body[2].value.value, "variant")

    check(ServiceBase.__doc__, "service")
    check(ServiceBase.get.__doc__, "method")

    check(ServiceStub.__doc__, "service")
    check(ServiceStub.get.__doc__, "method")
79
tests/test_enum.py
Normal file
@ -0,0 +1,79 @@
from typing import (
    Optional,
    Tuple,
)

import pytest

import betterproto


class Colour(betterproto.Enum):
    RED = 1
    GREEN = 2
    BLUE = 3


PURPLE = Colour.__new__(Colour, name=None, value=4)


@pytest.mark.parametrize(
    "member, str_value",
    [
        (Colour.RED, "RED"),
        (Colour.GREEN, "GREEN"),
        (Colour.BLUE, "BLUE"),
    ],
)
def test_str(member: Colour, str_value: str) -> None:
    assert str(member) == str_value


@pytest.mark.parametrize(
    "member, repr_value",
    [
        (Colour.RED, "Colour.RED"),
        (Colour.GREEN, "Colour.GREEN"),
        (Colour.BLUE, "Colour.BLUE"),
    ],
)
def test_repr(member: Colour, repr_value: str) -> None:
    assert repr(member) == repr_value


@pytest.mark.parametrize(
    "member, values",
    [
        (Colour.RED, ("RED", 1)),
        (Colour.GREEN, ("GREEN", 2)),
        (Colour.BLUE, ("BLUE", 3)),
        (PURPLE, (None, 4)),
    ],
)
def test_name_values(member: Colour, values: Tuple[Optional[str], int]) -> None:
    assert (member.name, member.value) == values


@pytest.mark.parametrize(
    "member, input_str",
    [
        (Colour.RED, "RED"),
        (Colour.GREEN, "GREEN"),
        (Colour.BLUE, "BLUE"),
    ],
)
def test_from_string(member: Colour, input_str: str) -> None:
    assert Colour.from_string(input_str) == member


@pytest.mark.parametrize(
    "member, input_int",
    [
        (Colour.RED, 1),
        (Colour.GREEN, 2),
        (Colour.BLUE, 3),
        (PURPLE, 4),
    ],
)
def test_try_value(member: Colour, input_int: int) -> None:
    assert Colour.try_value(input_int) == member
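As the PURPLE sentinel above suggests, try_value does not raise for values that were never declared on the enum; a short hedged illustration (the exact member object returned is an implementation detail):

unknown = Colour.try_value(4)   # 4 is not a declared member
assert unknown.value == 4
assert unknown.name is None     # synthesised members carry no name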
@ -545,47 +545,6 @@ def test_oneof_default_value_set_causes_writes_wire():
    )


def test_recursive_message():
    from tests.output_betterproto.recursivemessage import Test as RecursiveMessage

    msg = RecursiveMessage()

    assert msg.child == RecursiveMessage()

    # Lazily-created zero-value children must not affect equality.
    assert msg == RecursiveMessage()

    # Lazily-created zero-value children must not affect serialization.
    assert bytes(msg) == b""


def test_recursive_message_defaults():
    from tests.output_betterproto.recursivemessage import (
        Intermediate,
        Test as RecursiveMessage,
    )

    msg = RecursiveMessage(name="bob", intermediate=Intermediate(42))

    # set values are as expected
    assert msg == RecursiveMessage(name="bob", intermediate=Intermediate(42))

    # lazy initialized works modifies the message
    assert msg != RecursiveMessage(
        name="bob", intermediate=Intermediate(42), child=RecursiveMessage(name="jude")
    )
    msg.child.child.name = "jude"
    assert msg == RecursiveMessage(
        name="bob",
        intermediate=Intermediate(42),
        child=RecursiveMessage(child=RecursiveMessage(name="jude")),
    )

    # lazily initialization recurses as needed
    assert msg.child.child.child.child.child.child.child == RecursiveMessage()
    assert msg.intermediate.child.intermediate == Intermediate()


def test_message_repr():
    from tests.output_betterproto.recursivemessage import Test

@ -662,9 +621,7 @@ iso_candidates = """2009-12-12T12:34
2010-02-18T16:00:00.23334444
2010-02-18T16:00:00,2283
2009-05-19 143922
2009-05-19 1439""".split(
    "\n"
)
2009-05-19 1439""".split("\n")


def test_iso_datetime():
@ -699,25 +656,6 @@ def test_service_argument__expected_parameter():
    assert do_thing_request_parameter.annotation == "DoThingRequest"


def test_copyability():
    @dataclass
    class Spam(betterproto.Message):
        foo: bool = betterproto.bool_field(1)
        bar: int = betterproto.int32_field(2)
        baz: List[str] = betterproto.string_field(3)

    spam = Spam(bar=12, baz=["hello"])
    copied = copy(spam)
    assert spam == copied
    assert spam is not copied
    assert spam.baz is copied.baz

    deepcopied = deepcopy(spam)
    assert spam == deepcopied
    assert spam is not deepcopied
    assert spam.baz is not deepcopied.baz


def test_is_set():
    @dataclass
    class Spam(betterproto.Message):
@ -4,6 +4,15 @@ from betterproto.compile.importing import (
    get_type_reference,
    parse_source_type_name,
)
from betterproto.plugin.typing_compiler import DirectImportTypingCompiler


@pytest.fixture
def typing_compiler() -> DirectImportTypingCompiler:
    """
    Generates a simple Direct Import Typing Compiler for testing.
    """
    return DirectImportTypingCompiler()


@pytest.mark.parametrize(
@ -32,15 +41,70 @@ from betterproto.compile.importing import (
    ],
)
def test_reference_google_wellknown_types_non_wrappers(
    google_type: str, expected_name: str, expected_import: str
    google_type: str,
    expected_name: str,
    expected_import: str,
    typing_compiler: DirectImportTypingCompiler,
):
    imports = set()
    name = get_type_reference(package="", imports=imports, source_type=google_type)
    name = get_type_reference(
        package="",
        imports=imports,
        source_type=google_type,
        typing_compiler=typing_compiler,
        pydantic=False,
    )

    assert name == expected_name
    assert imports.__contains__(
        expected_import
    ), f"{expected_import} not found in {imports}"
    assert imports.__contains__(expected_import), (
        f"{expected_import} not found in {imports}"
    )


@pytest.mark.parametrize(
    ["google_type", "expected_name", "expected_import"],
    [
        (
            ".google.protobuf.Empty",
            '"betterproto_lib_pydantic_google_protobuf.Empty"',
            "import betterproto.lib.pydantic.google.protobuf as betterproto_lib_pydantic_google_protobuf",
        ),
        (
            ".google.protobuf.Struct",
            '"betterproto_lib_pydantic_google_protobuf.Struct"',
            "import betterproto.lib.pydantic.google.protobuf as betterproto_lib_pydantic_google_protobuf",
        ),
        (
            ".google.protobuf.ListValue",
            '"betterproto_lib_pydantic_google_protobuf.ListValue"',
            "import betterproto.lib.pydantic.google.protobuf as betterproto_lib_pydantic_google_protobuf",
        ),
        (
            ".google.protobuf.Value",
            '"betterproto_lib_pydantic_google_protobuf.Value"',
            "import betterproto.lib.pydantic.google.protobuf as betterproto_lib_pydantic_google_protobuf",
        ),
    ],
)
def test_reference_google_wellknown_types_non_wrappers_pydantic(
    google_type: str,
    expected_name: str,
    expected_import: str,
    typing_compiler: DirectImportTypingCompiler,
):
    imports = set()
    name = get_type_reference(
        package="",
        imports=imports,
        source_type=google_type,
        typing_compiler=typing_compiler,
        pydantic=True,
    )

    assert name == expected_name
    assert imports.__contains__(expected_import), (
        f"{expected_import} not found in {imports}"
    )


@pytest.mark.parametrize(
@ -58,10 +122,15 @@ def test_reference_google_wellknown_types_non_wrappers(
    ],
)
def test_referenceing_google_wrappers_unwraps_them(
    google_type: str, expected_name: str
    google_type: str, expected_name: str, typing_compiler: DirectImportTypingCompiler
):
    imports = set()
    name = get_type_reference(package="", imports=imports, source_type=google_type)
    name = get_type_reference(
        package="",
        imports=imports,
        source_type=google_type,
        typing_compiler=typing_compiler,
    )

    assert name == expected_name
    assert imports == set()
@ -94,223 +163,321 @@ def test_referenceing_google_wrappers_unwraps_them(
    ],
)
def test_referenceing_google_wrappers_without_unwrapping(
    google_type: str, expected_name: str
    google_type: str, expected_name: str, typing_compiler: DirectImportTypingCompiler
):
    name = get_type_reference(
        package="", imports=set(), source_type=google_type, unwrap=False
        package="",
        imports=set(),
        source_type=google_type,
        typing_compiler=typing_compiler,
        unwrap=False,
    )

    assert name == expected_name


def test_reference_child_package_from_package():
def test_reference_child_package_from_package(
    typing_compiler: DirectImportTypingCompiler,
):
    imports = set()
    name = get_type_reference(
        package="package", imports=imports, source_type="package.child.Message"
        package="package",
        imports=imports,
        source_type="package.child.Message",
        typing_compiler=typing_compiler,
    )

    assert imports == {"from . import child"}
    assert name == '"child.Message"'


def test_reference_child_package_from_root():
def test_reference_child_package_from_root(typing_compiler: DirectImportTypingCompiler):
    imports = set()
    name = get_type_reference(package="", imports=imports, source_type="child.Message")
    name = get_type_reference(
        package="",
        imports=imports,
        source_type="child.Message",
        typing_compiler=typing_compiler,
    )

    assert imports == {"from . import child"}
    assert name == '"child.Message"'


def test_reference_camel_cased():
def test_reference_camel_cased(typing_compiler: DirectImportTypingCompiler):
    imports = set()
    name = get_type_reference(
        package="", imports=imports, source_type="child_package.example_message"
        package="",
        imports=imports,
        source_type="child_package.example_message",
        typing_compiler=typing_compiler,
    )

    assert imports == {"from . import child_package"}
    assert name == '"child_package.ExampleMessage"'


def test_reference_nested_child_from_root():
def test_reference_nested_child_from_root(typing_compiler: DirectImportTypingCompiler):
    imports = set()
    name = get_type_reference(
        package="", imports=imports, source_type="nested.child.Message"
        package="",
        imports=imports,
        source_type="nested.child.Message",
        typing_compiler=typing_compiler,
    )

    assert imports == {"from .nested import child as nested_child"}
    assert name == '"nested_child.Message"'


def test_reference_deeply_nested_child_from_root():
def test_reference_deeply_nested_child_from_root(
    typing_compiler: DirectImportTypingCompiler,
):
    imports = set()
    name = get_type_reference(
        package="", imports=imports, source_type="deeply.nested.child.Message"
        package="",
        imports=imports,
        source_type="deeply.nested.child.Message",
        typing_compiler=typing_compiler,
    )

    assert imports == {"from .deeply.nested import child as deeply_nested_child"}
    assert name == '"deeply_nested_child.Message"'


def test_reference_deeply_nested_child_from_package():
def test_reference_deeply_nested_child_from_package(
    typing_compiler: DirectImportTypingCompiler,
):
    imports = set()
    name = get_type_reference(
        package="package",
        imports=imports,
        source_type="package.deeply.nested.child.Message",
        typing_compiler=typing_compiler,
    )

    assert imports == {"from .deeply.nested import child as deeply_nested_child"}
    assert name == '"deeply_nested_child.Message"'


def test_reference_root_sibling():
    imports = set()
    name = get_type_reference(package="", imports=imports, source_type="Message")

    assert imports == set()
    assert name == '"Message"'


def test_reference_nested_siblings():
    imports = set()
    name = get_type_reference(package="foo", imports=imports, source_type="foo.Message")

    assert imports == set()
    assert name == '"Message"'


def test_reference_deeply_nested_siblings():
def test_reference_root_sibling(typing_compiler: DirectImportTypingCompiler):
    imports = set()
    name = get_type_reference(
        package="foo.bar", imports=imports, source_type="foo.bar.Message"
        package="",
        imports=imports,
        source_type="Message",
        typing_compiler=typing_compiler,
    )

    assert imports == set()
    assert name == '"Message"'


def test_reference_parent_package_from_child():
def test_reference_nested_siblings(typing_compiler: DirectImportTypingCompiler):
    imports = set()
    name = get_type_reference(
        package="package.child", imports=imports, source_type="package.Message"
        package="foo",
        imports=imports,
        source_type="foo.Message",
        typing_compiler=typing_compiler,
    )

    assert imports == set()
    assert name == '"Message"'


def test_reference_deeply_nested_siblings(typing_compiler: DirectImportTypingCompiler):
    imports = set()
    name = get_type_reference(
        package="foo.bar",
        imports=imports,
        source_type="foo.bar.Message",
        typing_compiler=typing_compiler,
    )

    assert imports == set()
    assert name == '"Message"'


def test_reference_parent_package_from_child(
    typing_compiler: DirectImportTypingCompiler,
):
    imports = set()
    name = get_type_reference(
        package="package.child",
        imports=imports,
        source_type="package.Message",
        typing_compiler=typing_compiler,
    )

    assert imports == {"from ... import package as __package__"}
    assert name == '"__package__.Message"'


def test_reference_parent_package_from_deeply_nested_child():
def test_reference_parent_package_from_deeply_nested_child(
    typing_compiler: DirectImportTypingCompiler,
):
    imports = set()
    name = get_type_reference(
        package="package.deeply.nested.child",
        imports=imports,
        source_type="package.deeply.nested.Message",
        typing_compiler=typing_compiler,
    )

    assert imports == {"from ... import nested as __nested__"}
    assert name == '"__nested__.Message"'


def test_reference_ancestor_package_from_nested_child():
def test_reference_ancestor_package_from_nested_child(
    typing_compiler: DirectImportTypingCompiler,
):
    imports = set()
    name = get_type_reference(
        package="package.ancestor.nested.child",
        imports=imports,
        source_type="package.ancestor.Message",
        typing_compiler=typing_compiler,
    )

    assert imports == {"from .... import ancestor as ___ancestor__"}
    assert name == '"___ancestor__.Message"'


def test_reference_root_package_from_child():
def test_reference_root_package_from_child(typing_compiler: DirectImportTypingCompiler):
    imports = set()
    name = get_type_reference(
        package="package.child", imports=imports, source_type="Message"
        package="package.child",
        imports=imports,
        source_type="Message",
        typing_compiler=typing_compiler,
    )

    assert imports == {"from ... import Message as __Message__"}
    assert name == '"__Message__"'


def test_reference_root_package_from_deeply_nested_child():
def test_reference_root_package_from_deeply_nested_child(
    typing_compiler: DirectImportTypingCompiler,
):
    imports = set()
    name = get_type_reference(
        package="package.deeply.nested.child", imports=imports, source_type="Message"
        package="package.deeply.nested.child",
        imports=imports,
        source_type="Message",
        typing_compiler=typing_compiler,
    )

    assert imports == {"from ..... import Message as ____Message__"}
    assert name == '"____Message__"'


def test_reference_unrelated_package():
def test_reference_unrelated_package(typing_compiler: DirectImportTypingCompiler):
    imports = set()
    name = get_type_reference(package="a", imports=imports, source_type="p.Message")
    name = get_type_reference(
        package="a",
        imports=imports,
        source_type="p.Message",
        typing_compiler=typing_compiler,
    )

    assert imports == {"from .. import p as _p__"}
    assert name == '"_p__.Message"'


def test_reference_unrelated_nested_package():
def test_reference_unrelated_nested_package(
    typing_compiler: DirectImportTypingCompiler,
):
    imports = set()
    name = get_type_reference(package="a.b", imports=imports, source_type="p.q.Message")
    name = get_type_reference(
        package="a.b",
        imports=imports,
        source_type="p.q.Message",
        typing_compiler=typing_compiler,
    )

    assert imports == {"from ...p import q as __p_q__"}
    assert name == '"__p_q__.Message"'


def test_reference_unrelated_deeply_nested_package():
def test_reference_unrelated_deeply_nested_package(
    typing_compiler: DirectImportTypingCompiler,
):
    imports = set()
    name = get_type_reference(
        package="a.b.c.d", imports=imports, source_type="p.q.r.s.Message"
        package="a.b.c.d",
        imports=imports,
        source_type="p.q.r.s.Message",
        typing_compiler=typing_compiler,
    )

    assert imports == {"from .....p.q.r import s as ____p_q_r_s__"}
    assert name == '"____p_q_r_s__.Message"'


def test_reference_cousin_package():
def test_reference_cousin_package(typing_compiler: DirectImportTypingCompiler):
    imports = set()
    name = get_type_reference(package="a.x", imports=imports, source_type="a.y.Message")
    name = get_type_reference(
        package="a.x",
        imports=imports,
        source_type="a.y.Message",
        typing_compiler=typing_compiler,
    )

    assert imports == {"from .. import y as _y__"}
    assert name == '"_y__.Message"'


def test_reference_cousin_package_different_name():
def test_reference_cousin_package_different_name(
    typing_compiler: DirectImportTypingCompiler,
):
    imports = set()
    name = get_type_reference(
        package="test.package1", imports=imports, source_type="cousin.package2.Message"
        package="test.package1",
        imports=imports,
        source_type="cousin.package2.Message",
        typing_compiler=typing_compiler,
    )

    assert imports == {"from ...cousin import package2 as __cousin_package2__"}
    assert name == '"__cousin_package2__.Message"'


def test_reference_cousin_package_same_name():
def test_reference_cousin_package_same_name(
    typing_compiler: DirectImportTypingCompiler,
):
    imports = set()
    name = get_type_reference(
        package="test.package", imports=imports, source_type="cousin.package.Message"
        package="test.package",
        imports=imports,
        source_type="cousin.package.Message",
        typing_compiler=typing_compiler,
    )

    assert imports == {"from ...cousin import package as __cousin_package__"}
    assert name == '"__cousin_package__.Message"'


def test_reference_far_cousin_package():
def test_reference_far_cousin_package(typing_compiler: DirectImportTypingCompiler):
    imports = set()
    name = get_type_reference(
        package="a.x.y", imports=imports, source_type="a.b.c.Message"
        package="a.x.y",
        imports=imports,
        source_type="a.b.c.Message",
        typing_compiler=typing_compiler,
    )

    assert imports == {"from ...b import c as __b_c__"}
    assert name == '"__b_c__.Message"'


def test_reference_far_far_cousin_package():
def test_reference_far_far_cousin_package(typing_compiler: DirectImportTypingCompiler):
    imports = set()
    name = get_type_reference(
        package="a.x.y.z", imports=imports, source_type="a.b.c.d.Message"
        package="a.x.y.z",
        imports=imports,
        source_type="a.b.c.d.Message",
        typing_compiler=typing_compiler,
    )

    assert imports == {"from ....b.c import d as ___b_c_d__"}
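To make the refactor above concrete: every call now threads a typing_compiler through, and get_type_reference returns the quoted reference while accumulating the import it needs in the caller-supplied set. A sketch with values copied from test_reference_parent_package_from_child:

from betterproto.compile.importing import get_type_reference
from betterproto.plugin.typing_compiler import DirectImportTypingCompiler

imports = set()
name = get_type_reference(
    package="package.child",
    imports=imports,
    source_type="package.Message",
    typing_compiler=DirectImportTypingCompiler(),
)
# The reference comes back as the return value; the import statement it
# depends on is collected as a side effect.
assert name == '"__package__.Message"'
assert imports == {"from ... import package as __package__"}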
@ -174,22 +174,21 @@ def test_message_equality(test_data: TestData) -> None:


@pytest.mark.parametrize("test_data", test_cases.messages_with_json, indirect=True)
def test_message_json(repeat, test_data: TestData) -> None:
def test_message_json(test_data: TestData) -> None:
    plugin_module, _, json_data = test_data

    for _ in range(repeat):
        for sample in json_data:
            if sample.belongs_to(test_input_config.non_symmetrical_json):
                continue
    for sample in json_data:
        if sample.belongs_to(test_input_config.non_symmetrical_json):
            continue

            message: betterproto.Message = plugin_module.Test()
        message: betterproto.Message = plugin_module.Test()

            message.from_json(sample.json)
            message_json = message.to_json(0)
        message.from_json(sample.json)
        message_json = message.to_json(0)

            assert dict_replace_nans(json.loads(message_json)) == dict_replace_nans(
                json.loads(sample.json)
            )
        assert dict_replace_nans(json.loads(message_json)) == dict_replace_nans(
            json.loads(sample.json)
        )


@pytest.mark.parametrize("test_data", test_cases.services, indirect=True)
@ -198,28 +197,27 @@ def test_service_can_be_instantiated(test_data: TestData) -> None:


@pytest.mark.parametrize("test_data", test_cases.messages_with_json, indirect=True)
def test_binary_compatibility(repeat, test_data: TestData) -> None:
def test_binary_compatibility(test_data: TestData) -> None:
    plugin_module, reference_module, json_data = test_data

    for sample in json_data:
        reference_instance = Parse(sample.json, reference_module().Test())
        reference_binary_output = reference_instance.SerializeToString()

        for _ in range(repeat):
            plugin_instance_from_json: betterproto.Message = (
                plugin_module.Test().from_json(sample.json)
            )
            plugin_instance_from_binary = plugin_module.Test.FromString(
                reference_binary_output
            )
        plugin_instance_from_json: betterproto.Message = plugin_module.Test().from_json(
            sample.json
        )
        plugin_instance_from_binary = plugin_module.Test.FromString(
            reference_binary_output
        )

            # Generally this can't be relied on, but here we are aiming to match the
            # existing Python implementation and aren't doing anything tricky.
            # https://developers.google.com/protocol-buffers/docs/encoding#implications
            assert bytes(plugin_instance_from_json) == reference_binary_output
            assert bytes(plugin_instance_from_binary) == reference_binary_output
        # Generally this can't be relied on, but here we are aiming to match the
        # existing Python implementation and aren't doing anything tricky.
        # https://developers.google.com/protocol-buffers/docs/encoding#implications
        assert bytes(plugin_instance_from_json) == reference_binary_output
        assert bytes(plugin_instance_from_binary) == reference_binary_output

            assert plugin_instance_from_json == plugin_instance_from_binary
            assert dict_replace_nans(
                plugin_instance_from_json.to_dict()
            ) == dict_replace_nans(plugin_instance_from_binary.to_dict())
        assert plugin_instance_from_json == plugin_instance_from_binary
        assert dict_replace_nans(
            plugin_instance_from_json.to_dict()
        ) == dict_replace_nans(plugin_instance_from_binary.to_dict())
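The de-indented tests keep the same roundtrip contract; as a self-contained illustration of it (a hypothetical Greeting message standing in for the generated Test modules):

from dataclasses import dataclass

import betterproto

@dataclass
class Greeting(betterproto.Message):
    name: str = betterproto.string_field(1)

msg = Greeting().from_json('{"name": "bob"}')
# JSON survives the roundtrip for symmetrical cases...
assert msg.to_json() == '{"name": "bob"}'
# ...and so does the binary encoding.
assert Greeting().parse(bytes(msg)) == msg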
111 tests/test_module_validation.py Normal file
@ -0,0 +1,111 @@
from typing import (
    List,
    Optional,
    Set,
)

import pytest

from betterproto.plugin.module_validation import ModuleValidator


@pytest.mark.parametrize(
    ["text", "expected_collisions"],
    [
        pytest.param(
            ["import os"],
            None,
            id="single import",
        ),
        pytest.param(
            ["import os", "import sys"],
            None,
            id="multiple imports",
        ),
        pytest.param(
            ["import os", "import os"],
            {"os"},
            id="duplicate imports",
        ),
        pytest.param(
            ["from os import path", "import os"],
            None,
            id="duplicate imports with alias",
        ),
        pytest.param(
            ["from os import path", "import os as os_alias"],
            None,
            id="duplicate imports with alias",
        ),
        pytest.param(
            ["from os import path", "import os as path"],
            {"path"},
            id="duplicate imports with alias",
        ),
        pytest.param(
            ["import os", "class os:"],
            {"os"},
            id="duplicate import with class",
        ),
        pytest.param(
            ["import os", "class os:", " pass", "import sys"],
            {"os"},
            id="duplicate import with class and another",
        ),
        pytest.param(
            ["def test(): pass", "class test:"],
            {"test"},
            id="duplicate class and function",
        ),
        pytest.param(
            ["def test(): pass", "def test(): pass"],
            {"test"},
            id="duplicate functions",
        ),
        pytest.param(
            ["def test(): pass", "test = 100"],
            {"test"},
            id="function and variable",
        ),
        pytest.param(
            ["def test():", " test = 3"],
            None,
            id="function and variable in function",
        ),
        pytest.param(
            [
                "def test(): pass",
                "'''",
                "def test(): pass",
                "'''",
                "def test_2(): pass",
            ],
            None,
            id="duplicate functions with multiline string",
        ),
        pytest.param(
            ["def test(): pass", "# def test(): pass"],
            None,
            id="duplicate functions with comments",
        ),
        pytest.param(
            ["from test import (", " A", " B", " C", ")"],
            None,
            id="multiline import",
        ),
        pytest.param(
            ["from test import (", " A", " B", " C", ")", "from test import A"],
            {"A"},
            id="multiline import with duplicate",
        ),
    ],
)
def test_module_validator(text: List[str], expected_collisions: Optional[Set[str]]):
    line_iterator = iter(text)
    validator = ModuleValidator(line_iterator)
    valid = validator.validate()
    if expected_collisions is None:
        assert valid
    else:
        assert set(validator.collisions.keys()) == expected_collisions
        assert not valid
216 tests/test_pickling.py Normal file
@ -0,0 +1,216 @@
import pickle
from copy import (
    copy,
    deepcopy,
)
from dataclasses import dataclass
from typing import (
    Dict,
    List,
)
from unittest.mock import ANY

import cachelib

import betterproto
from betterproto.lib.google import protobuf as google


def unpickled(message):
    return pickle.loads(pickle.dumps(message))


@dataclass(eq=False, repr=False)
class Fe(betterproto.Message):
    abc: str = betterproto.string_field(1)


@dataclass(eq=False, repr=False)
class Fi(betterproto.Message):
    abc: str = betterproto.string_field(1)


@dataclass(eq=False, repr=False)
class Fo(betterproto.Message):
    abc: str = betterproto.string_field(1)


@dataclass(eq=False, repr=False)
class NestedData(betterproto.Message):
    struct_foo: Dict[str, "google.Struct"] = betterproto.map_field(
        1, betterproto.TYPE_STRING, betterproto.TYPE_MESSAGE
    )
    map_str_any_bar: Dict[str, "google.Any"] = betterproto.map_field(
        2, betterproto.TYPE_STRING, betterproto.TYPE_MESSAGE
    )


@dataclass(eq=False, repr=False)
class Complex(betterproto.Message):
    foo_str: str = betterproto.string_field(1)
    fe: "Fe" = betterproto.message_field(3, group="grp")
    fi: "Fi" = betterproto.message_field(4, group="grp")
    fo: "Fo" = betterproto.message_field(5, group="grp")
    nested_data: "NestedData" = betterproto.message_field(6)
    mapping: Dict[str, "google.Any"] = betterproto.map_field(
        7, betterproto.TYPE_STRING, betterproto.TYPE_MESSAGE
    )


class BetterprotoEnum(betterproto.Enum):
    UNSPECIFIED = 0
    ONE = 1


def complex_msg():
    return Complex(
        foo_str="yep",
        fe=Fe(abc="1"),
        nested_data=NestedData(
            struct_foo={
                "foo": google.Struct(
                    fields={
                        "hello": google.Value(
                            list_value=google.ListValue(
                                values=[google.Value(string_value="world")]
                            )
                        )
                    }
                ),
            },
            map_str_any_bar={
                "key": google.Any(value=b"value"),
            },
        ),
        mapping={
            "message": google.Any(value=bytes(Fi(abc="hi"))),
            "string": google.Any(value=b"howdy"),
        },
    )


def test_pickling_complex_message():
    msg = complex_msg()
    deser = unpickled(msg)
    assert msg == deser
    assert msg.fe.abc == "1"
    assert msg.is_set("fi") is not True
    assert msg.mapping["message"] == google.Any(value=bytes(Fi(abc="hi")))
    assert msg.mapping["string"].value.decode() == "howdy"
    assert (
        msg.nested_data.struct_foo["foo"]
        .fields["hello"]
        .list_value.values[0]
        .string_value
        == "world"
    )


def test_recursive_message():
    from tests.output_betterproto.recursivemessage import Test as RecursiveMessage

    msg = RecursiveMessage()
    msg = unpickled(msg)

    assert msg.child == RecursiveMessage()

    # Lazily-created zero-value children must not affect equality.
    assert msg == RecursiveMessage()

    # Lazily-created zero-value children must not affect serialization.
    assert bytes(msg) == b""


def test_recursive_message_defaults():
    from tests.output_betterproto.recursivemessage import (
        Intermediate,
        Test as RecursiveMessage,
    )

    msg = RecursiveMessage(name="bob", intermediate=Intermediate(42))
    msg = unpickled(msg)

    # set values are as expected
    assert msg == RecursiveMessage(name="bob", intermediate=Intermediate(42))

    # lazy initialization modifies the message
    assert msg != RecursiveMessage(
        name="bob", intermediate=Intermediate(42), child=RecursiveMessage(name="jude")
    )
    msg.child.child.name = "jude"
    assert msg == RecursiveMessage(
        name="bob",
        intermediate=Intermediate(42),
        child=RecursiveMessage(child=RecursiveMessage(name="jude")),
    )

    # lazy initialization recurses as needed
    assert msg.child.child.child.child.child.child.child == RecursiveMessage()
    assert msg.intermediate.child.intermediate == Intermediate()


@dataclass
class PickledMessage(betterproto.Message):
    foo: bool = betterproto.bool_field(1)
    bar: int = betterproto.int32_field(2)
    baz: List[str] = betterproto.string_field(3)


def test_copyability():
    msg = PickledMessage(bar=12, baz=["hello"])
    msg = unpickled(msg)

    copied = copy(msg)
    assert msg == copied
    assert msg is not copied
    assert msg.baz is copied.baz

    deepcopied = deepcopy(msg)
    assert msg == deepcopied
    assert msg is not deepcopied
    assert msg.baz is not deepcopied.baz


def test_message_can_be_cached():
    """Cachelib uses pickling to cache values"""

    cache = cachelib.SimpleCache()

    def use_cache():
        calls = getattr(use_cache, "calls", 0)
        result = cache.get("message")
        if result is not None:
            return result
        else:
            setattr(use_cache, "calls", calls + 1)
            result = complex_msg()
            cache.set("message", result)
            return result

    for n in range(10):
        if n == 0:
            assert not cache.has("message")
        else:
            assert cache.has("message")

        msg = use_cache()
    assert use_cache.calls == 1  # The message is only ever built once
    assert msg.fe.abc == "1"
    assert msg.is_set("fi") is not True
    assert msg.mapping["message"] == google.Any(value=bytes(Fi(abc="hi")))
    assert msg.mapping["string"].value.decode() == "howdy"
    assert (
        msg.nested_data.struct_foo["foo"]
        .fields["hello"]
        .list_value.values[0]
        .string_value
        == "world"
    )


def test_pickle_enum():
    enum = BetterprotoEnum.ONE
    assert unpickled(enum) == enum

    enum = BetterprotoEnum.UNSPECIFIED
    assert unpickled(enum) == enum
@ -1,6 +1,8 @@
from dataclasses import dataclass
from io import BytesIO
from pathlib import Path
from shutil import which
from subprocess import run
from typing import Optional

import pytest
@ -40,6 +42,8 @@ map_example = map.Test().from_dict({"counts": {"blah": 1, "Blah2": 2}})

streams_path = Path("tests/streams/")

java = which("java")


def test_load_varint_too_long():
    with BytesIO(
@ -58,7 +62,7 @@ def test_load_varint_file():
        stream.read(2)  # Skip until first multi-byte
    assert betterproto.load_varint(stream) == (
        123456789,
        b"\x95\x9A\xEF\x3A",
        b"\x95\x9a\xef\x3a",
    )  # Multi-byte varint


@ -127,6 +131,18 @@ def test_message_dump_file_multiple(tmp_path):
        assert test_stream.read() == exp_stream.read()


def test_message_dump_delimited(tmp_path):
    with open(tmp_path / "message_dump_delimited.out", "wb") as stream:
        oneof_example.dump(stream, betterproto.SIZE_DELIMITED)
        oneof_example.dump(stream, betterproto.SIZE_DELIMITED)
        nested_example.dump(stream, betterproto.SIZE_DELIMITED)

    with open(tmp_path / "message_dump_delimited.out", "rb") as test_stream, open(
        streams_path / "delimited_messages.in", "rb"
    ) as exp_stream:
        assert test_stream.read() == exp_stream.read()


def test_message_len():
    assert len_oneof == len(bytes(oneof_example))
    assert len(nested_example) == len(bytes(nested_example))
@ -155,7 +171,15 @@ def test_message_load_too_small():
        oneof.Test().load(stream, len_oneof - 1)


def test_message_too_large():
def test_message_load_delimited():
    with open(streams_path / "delimited_messages.in", "rb") as stream:
        assert oneof.Test().load(stream, betterproto.SIZE_DELIMITED) == oneof_example
        assert oneof.Test().load(stream, betterproto.SIZE_DELIMITED) == oneof_example
        assert nested.Test().load(stream, betterproto.SIZE_DELIMITED) == nested_example
        assert stream.read(1) == b""


def test_message_load_too_large():
    with open(
        streams_path / "message_dump_file_single.expected", "rb"
    ) as stream, pytest.raises(ValueError):
@ -266,3 +290,145 @@ def test_dump_varint_positive(tmp_path):
        streams_path / "dump_varint_positive.expected", "rb"
    ) as exp_stream:
        assert test_stream.read() == exp_stream.read()


# Java compatibility tests


@pytest.fixture(scope="module")
def compile_jar():
    # Skip if not all required tools are present
    if java is None:
        pytest.skip("`java` command is absent and is required")
    mvn = which("mvn")
    if mvn is None:
        pytest.skip("Maven is absent and is required")

    # Compile the JAR
    proc_maven = run([mvn, "clean", "install", "-f", "tests/streams/java/pom.xml"])
    if proc_maven.returncode != 0:
        pytest.skip(
            "Maven compatibility-test.jar build failed (maybe Java version <11?)"
        )


jar = "tests/streams/java/target/compatibility-test.jar"


def run_jar(command: str, tmp_path):
    return run([java, "-jar", jar, command, tmp_path], check=True)


def run_java_single_varint(value: int, tmp_path) -> int:
    # Write single varint to file
    with open(tmp_path / "py_single_varint.out", "wb") as stream:
        betterproto.dump_varint(value, stream)

    # Have Java read this varint and write it back
    run_jar("single_varint", tmp_path)

    # Read single varint from Java output file
    with open(tmp_path / "java_single_varint.out", "rb") as stream:
        returned = betterproto.load_varint(stream)
        with pytest.raises(EOFError):
            betterproto.load_varint(stream)

    return returned


def test_single_varint(compile_jar, tmp_path):
    single_byte = (1, b"\x01")
    multi_byte = (123456789, b"\x95\x9a\xef\x3a")

    # Write a single-byte varint to a file and have Java read it back
    returned = run_java_single_varint(single_byte[0], tmp_path)
    assert returned == single_byte

    # Same for a multi-byte varint
    returned = run_java_single_varint(multi_byte[0], tmp_path)
    assert returned == multi_byte


def test_multiple_varints(compile_jar, tmp_path):
    single_byte = (1, b"\x01")
    multi_byte = (123456789, b"\x95\x9a\xef\x3a")
    over32 = (3000000000, b"\x80\xbc\xc1\x96\x0b")

    # Write three varints to the same file
    with open(tmp_path / "py_multiple_varints.out", "wb") as stream:
        betterproto.dump_varint(single_byte[0], stream)
        betterproto.dump_varint(multi_byte[0], stream)
        betterproto.dump_varint(over32[0], stream)

    # Have Java read these varints and write them back
    run_jar("multiple_varints", tmp_path)

    # Read varints from Java output file
    with open(tmp_path / "java_multiple_varints.out", "rb") as stream:
        returned_single = betterproto.load_varint(stream)
        returned_multi = betterproto.load_varint(stream)
        returned_over32 = betterproto.load_varint(stream)
        with pytest.raises(EOFError):
            betterproto.load_varint(stream)

    assert returned_single == single_byte
    assert returned_multi == multi_byte
    assert returned_over32 == over32


def test_single_message(compile_jar, tmp_path):
    # Write message to file
    with open(tmp_path / "py_single_message.out", "wb") as stream:
        oneof_example.dump(stream)

    # Have Java read and return the message
    run_jar("single_message", tmp_path)

    # Read and check the returned message
    with open(tmp_path / "java_single_message.out", "rb") as stream:
        returned = oneof.Test().load(stream, len(bytes(oneof_example)))
        assert stream.read() == b""

    assert returned == oneof_example


def test_multiple_messages(compile_jar, tmp_path):
    # Write delimited messages to file
    with open(tmp_path / "py_multiple_messages.out", "wb") as stream:
        oneof_example.dump(stream, betterproto.SIZE_DELIMITED)
        nested_example.dump(stream, betterproto.SIZE_DELIMITED)

    # Have Java read and return the messages
    run_jar("multiple_messages", tmp_path)

    # Read and check the returned messages
    with open(tmp_path / "java_multiple_messages.out", "rb") as stream:
        returned_oneof = oneof.Test().load(stream, betterproto.SIZE_DELIMITED)
        returned_nested = nested.Test().load(stream, betterproto.SIZE_DELIMITED)
        assert stream.read() == b""

    assert returned_oneof == oneof_example
    assert returned_nested == nested_example


def test_infinite_messages(compile_jar, tmp_path):
    num_messages = 5

    # Write delimited messages to file
    with open(tmp_path / "py_infinite_messages.out", "wb") as stream:
        for x in range(num_messages):
            oneof_example.dump(stream, betterproto.SIZE_DELIMITED)

    # Have Java read and return the messages
    run_jar("infinite_messages", tmp_path)

    # Read and check the returned messages
    messages = []
    with open(tmp_path / "java_infinite_messages.out", "rb") as stream:
        while True:
            try:
                messages.append(oneof.Test().load(stream, betterproto.SIZE_DELIMITED))
            except EOFError:
                break

    assert len(messages) == num_messages
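Outside the Java harness, the same SIZE_DELIMITED framing round-trips in memory too; a minimal sketch (Ping is a hypothetical message, not part of the test suite):

from dataclasses import dataclass
from io import BytesIO

import betterproto

@dataclass
class Ping(betterproto.Message):
    seq: int = betterproto.int32_field(1)

stream = BytesIO()
Ping(seq=1).dump(stream, betterproto.SIZE_DELIMITED)
Ping(seq=2).dump(stream, betterproto.SIZE_DELIMITED)
stream.seek(0)

# Each load consumes exactly one length-prefixed message; a further load
# at end-of-stream raises EOFError, as test_infinite_messages relies on.
assert Ping().load(stream, betterproto.SIZE_DELIMITED) == Ping(seq=1)
assert Ping().load(stream, betterproto.SIZE_DELIMITED) == Ping(seq=2)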
36 tests/test_struct.py Normal file
@ -0,0 +1,36 @@
import json

from betterproto.lib.google.protobuf import Struct
from betterproto.lib.pydantic.google.protobuf import Struct as StructPydantic


def test_struct_roundtrip():
    data = {
        "foo": "bar",
        "baz": None,
        "quux": 123,
        "zap": [1, {"two": 3}, "four"],
    }
    data_json = json.dumps(data)

    struct_from_dict = Struct().from_dict(data)
    assert struct_from_dict.fields == data
    assert struct_from_dict.to_dict() == data
    assert struct_from_dict.to_json() == data_json

    struct_from_json = Struct().from_json(data_json)
    assert struct_from_json.fields == data
    assert struct_from_json.to_dict() == data
    assert struct_from_json == struct_from_dict
    assert struct_from_json.to_json() == data_json

    struct_pyd_from_dict = StructPydantic(fields={}).from_dict(data)
    assert struct_pyd_from_dict.fields == data
    assert struct_pyd_from_dict.to_dict() == data
    assert struct_pyd_from_dict.to_json() == data_json

    struct_pyd_from_json = StructPydantic(fields={}).from_json(data_json)
    assert struct_pyd_from_json.fields == data
    assert struct_pyd_from_json.to_dict() == data
    assert struct_pyd_from_json == struct_pyd_from_dict
    assert struct_pyd_from_json.to_json() == data_json
27 tests/test_timestamp.py Normal file
@ -0,0 +1,27 @@
from datetime import (
    datetime,
    timezone,
)

import pytest

from betterproto import _Timestamp


@pytest.mark.parametrize(
    "dt",
    [
        datetime(2023, 10, 11, 9, 41, 12, tzinfo=timezone.utc),
        datetime.now(timezone.utc),
        # potential issue with floating point precision:
        datetime(2242, 12, 31, 23, 0, 0, 1, tzinfo=timezone.utc),
        # potential issue with negative timestamps:
        datetime(1969, 12, 31, 23, 0, 0, 1, tzinfo=timezone.utc),
    ],
)
def test_timestamp_to_datetime_and_back(dt: datetime):
    """
    Make sure converting a datetime to a protobuf timestamp message
    and then back again ends up with the same datetime.
    """
    assert _Timestamp.from_datetime(dt).to_datetime() == dt
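A back-of-envelope check on why the 2242 case stresses floating-point precision: if a conversion ever passed through a float of epoch seconds, the gap between adjacent representable values out there already exceeds the 1 microsecond the test uses (this illustrates the hazard, not the library's internals):

import math

# Epoch seconds around the year 2242 are roughly 8.6e9; with 52 fraction
# bits, the float64 spacing (ulp) there is about 1.9e-6 s -- wider than 1 µs.
assert 1e-6 < math.ulp(8.6e9) < 3e-6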
78 tests/test_typing_compiler.py Normal file
@ -0,0 +1,78 @@
import pytest

from betterproto.plugin.typing_compiler import (
    DirectImportTypingCompiler,
    NoTyping310TypingCompiler,
    TypingImportTypingCompiler,
)


def test_direct_import_typing_compiler():
    compiler = DirectImportTypingCompiler()
    assert compiler.imports() == {}
    assert compiler.optional("str") == "Optional[str]"
    assert compiler.imports() == {"typing": {"Optional"}}
    assert compiler.list("str") == "List[str]"
    assert compiler.imports() == {"typing": {"Optional", "List"}}
    assert compiler.dict("str", "int") == "Dict[str, int]"
    assert compiler.imports() == {"typing": {"Optional", "List", "Dict"}}
    assert compiler.union("str", "int") == "Union[str, int]"
    assert compiler.imports() == {"typing": {"Optional", "List", "Dict", "Union"}}
    assert compiler.iterable("str") == "Iterable[str]"
    assert compiler.imports() == {
        "typing": {"Optional", "List", "Dict", "Union", "Iterable"}
    }
    assert compiler.async_iterable("str") == "AsyncIterable[str]"
    assert compiler.imports() == {
        "typing": {"Optional", "List", "Dict", "Union", "Iterable", "AsyncIterable"}
    }
    assert compiler.async_iterator("str") == "AsyncIterator[str]"
    assert compiler.imports() == {
        "typing": {
            "Optional",
            "List",
            "Dict",
            "Union",
            "Iterable",
            "AsyncIterable",
            "AsyncIterator",
        }
    }


def test_typing_import_typing_compiler():
    compiler = TypingImportTypingCompiler()
    assert compiler.imports() == {}
    assert compiler.optional("str") == "typing.Optional[str]"
    assert compiler.imports() == {"typing": None}
    assert compiler.list("str") == "typing.List[str]"
    assert compiler.imports() == {"typing": None}
    assert compiler.dict("str", "int") == "typing.Dict[str, int]"
    assert compiler.imports() == {"typing": None}
    assert compiler.union("str", "int") == "typing.Union[str, int]"
    assert compiler.imports() == {"typing": None}
    assert compiler.iterable("str") == "typing.Iterable[str]"
    assert compiler.imports() == {"typing": None}
    assert compiler.async_iterable("str") == "typing.AsyncIterable[str]"
    assert compiler.imports() == {"typing": None}
    assert compiler.async_iterator("str") == "typing.AsyncIterator[str]"
    assert compiler.imports() == {"typing": None}


def test_no_typing_310_typing_compiler():
    compiler = NoTyping310TypingCompiler()
    assert compiler.imports() == {}
    assert compiler.optional("str") == '"str | None"'
    assert compiler.imports() == {}
    assert compiler.list("str") == '"list[str]"'
    assert compiler.imports() == {}
    assert compiler.dict("str", "int") == '"dict[str, int]"'
    assert compiler.imports() == {}
    assert compiler.union("str", "int") == '"str | int"'
    assert compiler.imports() == {}
    assert compiler.iterable("str") == '"Iterable[str]"'
    assert compiler.async_iterable("str") == '"AsyncIterable[str]"'
    assert compiler.async_iterator("str") == '"AsyncIterator[str]"'
    assert compiler.imports() == {
        "collections.abc": {"Iterable", "AsyncIterable", "AsyncIterator"}
    }
@ -11,6 +11,6 @@ PROJECT_TOML = Path(__file__).joinpath("..", "..", "pyproject.toml").resolve()
def test_version():
    with PROJECT_TOML.open() as toml_file:
        project_config = tomlkit.loads(toml_file.read())
    assert (
        __version__ == project_config["tool"]["poetry"]["version"]
    ), "Project version should match in package and package config"
    assert __version__ == project_config["project"]["version"], (
        "Project version should match in package and package config"
    )