Merge remote-tracking branch 'daniel/master' into fix/service-input-message
# Conflicts: # betterproto/plugin.py
This commit is contained in:
commit
72d72b4603
1
.gitignore
vendored
1
.gitignore
vendored
@ -14,3 +14,4 @@ output
|
|||||||
.idea
|
.idea
|
||||||
.DS_Store
|
.DS_Store
|
||||||
.tox
|
.tox
|
||||||
|
.venv
|
||||||
|
@ -22,7 +22,7 @@ from typing import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from ._types import T
|
from ._types import T
|
||||||
from .casing import camel_case, safe_snake_case, safe_snake_case, snake_case
|
from .casing import camel_case, safe_snake_case, snake_case
|
||||||
from .grpc.grpclib_client import ServiceStub
|
from .grpc.grpclib_client import ServiceStub
|
||||||
|
|
||||||
if not (sys.version_info.major == 3 and sys.version_info.minor >= 7):
|
if not (sys.version_info.major == 3 and sys.version_info.minor >= 7):
|
||||||
@ -378,7 +378,7 @@ def decode_varint(buffer: bytes, pos: int, signed: bool = False) -> Tuple[int, i
|
|||||||
result |= (b & 0x7F) << shift
|
result |= (b & 0x7F) << shift
|
||||||
pos += 1
|
pos += 1
|
||||||
if not (b & 0x80):
|
if not (b & 0x80):
|
||||||
return (result, pos)
|
return result, pos
|
||||||
shift += 7
|
shift += 7
|
||||||
if shift >= 64:
|
if shift >= 64:
|
||||||
raise ValueError("Too many bytes when decoding varint.")
|
raise ValueError("Too many bytes when decoding varint.")
|
||||||
@ -479,7 +479,7 @@ class ProtoClassMetadata:
|
|||||||
assert meta.map_types
|
assert meta.map_types
|
||||||
kt = cls._cls_for(field, index=0)
|
kt = cls._cls_for(field, index=0)
|
||||||
vt = cls._cls_for(field, index=1)
|
vt = cls._cls_for(field, index=1)
|
||||||
Entry = dataclasses.make_dataclass(
|
field_cls[field.name] = dataclasses.make_dataclass(
|
||||||
"Entry",
|
"Entry",
|
||||||
[
|
[
|
||||||
("key", kt, dataclass_field(1, meta.map_types[0])),
|
("key", kt, dataclass_field(1, meta.map_types[0])),
|
||||||
@ -487,7 +487,6 @@ class ProtoClassMetadata:
|
|||||||
],
|
],
|
||||||
bases=(Message,),
|
bases=(Message,),
|
||||||
)
|
)
|
||||||
field_cls[field.name] = Entry
|
|
||||||
field_cls[field.name + ".value"] = vt
|
field_cls[field.name + ".value"] = vt
|
||||||
else:
|
else:
|
||||||
field_cls[field.name] = cls._cls_for(field)
|
field_cls[field.name] = cls._cls_for(field)
|
||||||
@ -588,7 +587,7 @@ class Message(ABC):
|
|||||||
serialize_empty = False
|
serialize_empty = False
|
||||||
if isinstance(value, Message) and value._serialized_on_wire:
|
if isinstance(value, Message) and value._serialized_on_wire:
|
||||||
# Empty messages can still be sent on the wire if they were
|
# Empty messages can still be sent on the wire if they were
|
||||||
# set (or recieved empty).
|
# set (or received empty).
|
||||||
serialize_empty = True
|
serialize_empty = True
|
||||||
|
|
||||||
if value == self._get_field_default(field_name) and not (
|
if value == self._get_field_default(field_name) and not (
|
||||||
@ -926,7 +925,7 @@ def which_one_of(message: Message, group_name: str) -> Tuple[str, Any]:
|
|||||||
|
|
||||||
|
|
||||||
# Circular import workaround: google.protobuf depends on base classes defined above.
|
# Circular import workaround: google.protobuf depends on base classes defined above.
|
||||||
from .lib.google.protobuf import (
|
from .lib.google.protobuf import ( # noqa
|
||||||
Duration,
|
Duration,
|
||||||
Timestamp,
|
Timestamp,
|
||||||
BoolValue,
|
BoolValue,
|
||||||
|
@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, TypeVar
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from . import Message
|
from . import Message
|
||||||
from grpclib._protocols import IProtoMessage
|
from grpclib._typing import IProtoMessage
|
||||||
|
|
||||||
# Bound type variable to allow methods to return `self` of subclasses
|
# Bound type variable to allow methods to return `self` of subclasses
|
||||||
T = TypeVar("T", bound="Message")
|
T = TypeVar("T", bound="Message")
|
||||||
|
@ -20,6 +20,8 @@ def safe_snake_case(value: str) -> str:
|
|||||||
"and",
|
"and",
|
||||||
"as",
|
"as",
|
||||||
"assert",
|
"assert",
|
||||||
|
"async",
|
||||||
|
"await",
|
||||||
"break",
|
"break",
|
||||||
"class",
|
"class",
|
||||||
"continue",
|
"continue",
|
||||||
|
@ -2,7 +2,6 @@ from abc import ABC
|
|||||||
import asyncio
|
import asyncio
|
||||||
import grpclib.const
|
import grpclib.const
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
|
||||||
AsyncIterable,
|
AsyncIterable,
|
||||||
AsyncIterator,
|
AsyncIterator,
|
||||||
Collection,
|
Collection,
|
||||||
@ -17,8 +16,8 @@ from typing import (
|
|||||||
from .._types import ST, T
|
from .._types import ST, T
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from grpclib._protocols import IProtoMessage
|
from grpclib._typing import IProtoMessage
|
||||||
from grpclib.client import Channel, Stream
|
from grpclib.client import Channel
|
||||||
from grpclib.metadata import Deadline
|
from grpclib.metadata import Deadline
|
||||||
|
|
||||||
|
|
||||||
|
@ -21,7 +21,7 @@ class ChannelClosed(Exception):
|
|||||||
|
|
||||||
class ChannelDone(Exception):
|
class ChannelDone(Exception):
|
||||||
"""
|
"""
|
||||||
An exception raised on an attempt to send recieve from a channel that is both closed
|
An exception raised on an attempt to send receive from a channel that is both closed
|
||||||
and empty.
|
and empty.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -32,41 +32,41 @@ class AsyncChannel(AsyncIterable[T]):
|
|||||||
"""
|
"""
|
||||||
A buffered async channel for sending items between coroutines with FIFO ordering.
|
A buffered async channel for sending items between coroutines with FIFO ordering.
|
||||||
|
|
||||||
This makes decoupled bidirection steaming gRPC requests easy if used like:
|
This makes decoupled bidirectional steaming gRPC requests easy if used like:
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
client = GeneratedStub(grpclib_chan)
|
client = GeneratedStub(grpclib_chan)
|
||||||
request_chan = await AsyncChannel()
|
request_channel = await AsyncChannel()
|
||||||
# We can start be sending all the requests we already have
|
# We can start be sending all the requests we already have
|
||||||
await request_chan.send_from([ReqestObject(...), ReqestObject(...)])
|
await request_channel.send_from([RequestObject(...), RequestObject(...)])
|
||||||
async for response in client.rpc_call(request_chan):
|
async for response in client.rpc_call(request_channel):
|
||||||
# The response iterator will remain active until the connection is closed
|
# The response iterator will remain active until the connection is closed
|
||||||
...
|
...
|
||||||
# More items can be sent at any time
|
# More items can be sent at any time
|
||||||
await request_chan.send(ReqestObject(...))
|
await request_channel.send(RequestObject(...))
|
||||||
...
|
...
|
||||||
# The channel must be closed to complete the gRPC connection
|
# The channel must be closed to complete the gRPC connection
|
||||||
request_chan.close()
|
request_channel.close()
|
||||||
|
|
||||||
Items can be sent through the channel by either:
|
Items can be sent through the channel by either:
|
||||||
- providing an iterable to the send_from method
|
- providing an iterable to the send_from method
|
||||||
- passing them to the send method one at a time
|
- passing them to the send method one at a time
|
||||||
|
|
||||||
Items can be recieved from the channel by either:
|
Items can be received from the channel by either:
|
||||||
- iterating over the channel with a for loop to get all items
|
- iterating over the channel with a for loop to get all items
|
||||||
- calling the recieve method to get one item at a time
|
- calling the receive method to get one item at a time
|
||||||
|
|
||||||
If the channel is empty then recievers will wait until either an item appears or the
|
If the channel is empty then receivers will wait until either an item appears or the
|
||||||
channel is closed.
|
channel is closed.
|
||||||
|
|
||||||
Once the channel is closed then subsequent attempt to send through the channel will
|
Once the channel is closed then subsequent attempt to send through the channel will
|
||||||
fail with a ChannelClosed exception.
|
fail with a ChannelClosed exception.
|
||||||
|
|
||||||
When th channel is closed and empty then it is done, and further attempts to recieve
|
When th channel is closed and empty then it is done, and further attempts to receive
|
||||||
from it will fail with a ChannelDone exception
|
from it will fail with a ChannelDone exception
|
||||||
|
|
||||||
If multiple coroutines recieve from the channel concurrently, each item sent will be
|
If multiple coroutines receive from the channel concurrently, each item sent will be
|
||||||
recieved by only one of the recievers.
|
received by only one of the receivers.
|
||||||
|
|
||||||
:param source:
|
:param source:
|
||||||
An optional iterable will items that should be sent through the channel
|
An optional iterable will items that should be sent through the channel
|
||||||
@ -74,7 +74,7 @@ class AsyncChannel(AsyncIterable[T]):
|
|||||||
:param buffer_limit:
|
:param buffer_limit:
|
||||||
Limit the number of items that can be buffered in the channel, A value less than
|
Limit the number of items that can be buffered in the channel, A value less than
|
||||||
1 implies no limit. If the channel is full then attempts to send more items will
|
1 implies no limit. If the channel is full then attempts to send more items will
|
||||||
result in the sender waiting until an item is recieved from the channel.
|
result in the sender waiting until an item is received from the channel.
|
||||||
:param close:
|
:param close:
|
||||||
If set to True then the channel will automatically close after exhausting source
|
If set to True then the channel will automatically close after exhausting source
|
||||||
or immediately if no source is provided.
|
or immediately if no source is provided.
|
||||||
@ -85,7 +85,7 @@ class AsyncChannel(AsyncIterable[T]):
|
|||||||
):
|
):
|
||||||
self._queue: asyncio.Queue[Union[T, object]] = asyncio.Queue(buffer_limit)
|
self._queue: asyncio.Queue[Union[T, object]] = asyncio.Queue(buffer_limit)
|
||||||
self._closed = False
|
self._closed = False
|
||||||
self._waiting_recievers: int = 0
|
self._waiting_receivers: int = 0
|
||||||
# Track whether flush has been invoked so it can only happen once
|
# Track whether flush has been invoked so it can only happen once
|
||||||
self._flushed = False
|
self._flushed = False
|
||||||
|
|
||||||
@ -95,14 +95,14 @@ class AsyncChannel(AsyncIterable[T]):
|
|||||||
async def __anext__(self) -> T:
|
async def __anext__(self) -> T:
|
||||||
if self.done():
|
if self.done():
|
||||||
raise StopAsyncIteration
|
raise StopAsyncIteration
|
||||||
self._waiting_recievers += 1
|
self._waiting_receivers += 1
|
||||||
try:
|
try:
|
||||||
result = await self._queue.get()
|
result = await self._queue.get()
|
||||||
if result is self.__flush:
|
if result is self.__flush:
|
||||||
raise StopAsyncIteration
|
raise StopAsyncIteration
|
||||||
return result
|
return result
|
||||||
finally:
|
finally:
|
||||||
self._waiting_recievers -= 1
|
self._waiting_receivers -= 1
|
||||||
self._queue.task_done()
|
self._queue.task_done()
|
||||||
|
|
||||||
def closed(self) -> bool:
|
def closed(self) -> bool:
|
||||||
@ -116,12 +116,12 @@ class AsyncChannel(AsyncIterable[T]):
|
|||||||
Check if this channel is done.
|
Check if this channel is done.
|
||||||
|
|
||||||
:return: True if this channel is closed and and has been drained of items in
|
:return: True if this channel is closed and and has been drained of items in
|
||||||
which case any further attempts to recieve an item from this channel will raise
|
which case any further attempts to receive an item from this channel will raise
|
||||||
a ChannelDone exception.
|
a ChannelDone exception.
|
||||||
"""
|
"""
|
||||||
# After close the channel is not yet done until there is at least one waiting
|
# After close the channel is not yet done until there is at least one waiting
|
||||||
# reciever per enqueued item.
|
# receiver per enqueued item.
|
||||||
return self._closed and self._queue.qsize() <= self._waiting_recievers
|
return self._closed and self._queue.qsize() <= self._waiting_receivers
|
||||||
|
|
||||||
async def send_from(
|
async def send_from(
|
||||||
self, source: Union[Iterable[T], AsyncIterable[T]], close: bool = False
|
self, source: Union[Iterable[T], AsyncIterable[T]], close: bool = False
|
||||||
@ -158,22 +158,22 @@ class AsyncChannel(AsyncIterable[T]):
|
|||||||
await self._queue.put(item)
|
await self._queue.put(item)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
async def recieve(self) -> Optional[T]:
|
async def receive(self) -> Optional[T]:
|
||||||
"""
|
"""
|
||||||
Returns the next item from this channel when it becomes available,
|
Returns the next item from this channel when it becomes available,
|
||||||
or None if the channel is closed before another item is sent.
|
or None if the channel is closed before another item is sent.
|
||||||
:return: An item from the channel
|
:return: An item from the channel
|
||||||
"""
|
"""
|
||||||
if self.done():
|
if self.done():
|
||||||
raise ChannelDone("Cannot recieve from a closed channel")
|
raise ChannelDone("Cannot receive from a closed channel")
|
||||||
self._waiting_recievers += 1
|
self._waiting_receivers += 1
|
||||||
try:
|
try:
|
||||||
result = await self._queue.get()
|
result = await self._queue.get()
|
||||||
if result is self.__flush:
|
if result is self.__flush:
|
||||||
return None
|
return None
|
||||||
return result
|
return result
|
||||||
finally:
|
finally:
|
||||||
self._waiting_recievers -= 1
|
self._waiting_receivers -= 1
|
||||||
self._queue.task_done()
|
self._queue.task_done()
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
@ -190,8 +190,8 @@ class AsyncChannel(AsyncIterable[T]):
|
|||||||
"""
|
"""
|
||||||
if not self._flushed:
|
if not self._flushed:
|
||||||
self._flushed = True
|
self._flushed = True
|
||||||
deadlocked_recievers = max(0, self._waiting_recievers - self._queue.qsize())
|
deadlocked_receivers = max(0, self._waiting_receivers - self._queue.qsize())
|
||||||
for _ in range(deadlocked_recievers):
|
for _ in range(deadlocked_receivers):
|
||||||
await self._queue.put(self.__flush)
|
await self._queue.put(self.__flush)
|
||||||
|
|
||||||
# A special signal object for flushing the queue when the channel is closed
|
# A special signal object for flushing the queue when the channel is closed
|
||||||
|
@ -199,7 +199,7 @@ def generate_code(request, response):
|
|||||||
# Render and then format the output file.
|
# Render and then format the output file.
|
||||||
f.content = black.format_str(
|
f.content = black.format_str(
|
||||||
template.render(description=template_data),
|
template.render(description=template_data),
|
||||||
mode=black.FileMode(target_versions=set([black.TargetVersion.PY37])),
|
mode=black.FileMode(target_versions={black.TargetVersion.PY37}),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Make each output directory a package with __init__ file
|
# Make each output directory a package with __init__ file
|
||||||
|
@ -3,7 +3,6 @@ import asyncio
|
|||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import shutil
|
import shutil
|
||||||
import subprocess
|
|
||||||
import sys
|
import sys
|
||||||
from typing import Set
|
from typing import Set
|
||||||
|
|
||||||
|
@ -3,7 +3,6 @@ from betterproto.tests.output_betterproto.service.service import (
|
|||||||
DoThingResponse,
|
DoThingResponse,
|
||||||
DoThingRequest,
|
DoThingRequest,
|
||||||
GetThingRequest,
|
GetThingRequest,
|
||||||
GetThingResponse,
|
|
||||||
TestStub as ThingServiceClient,
|
TestStub as ThingServiceClient,
|
||||||
)
|
)
|
||||||
import grpclib
|
import grpclib
|
||||||
@ -18,14 +17,14 @@ async def _test_client(client, name="clean room", **kwargs):
|
|||||||
assert response.names == [name]
|
assert response.names == [name]
|
||||||
|
|
||||||
|
|
||||||
def _assert_request_meta_recieved(deadline, metadata):
|
def _assert_request_meta_received(deadline, metadata):
|
||||||
def server_side_test(stream):
|
def server_side_test(stream):
|
||||||
assert stream.deadline._timestamp == pytest.approx(
|
assert stream.deadline._timestamp == pytest.approx(
|
||||||
deadline._timestamp, 1
|
deadline._timestamp, 1
|
||||||
), "The provided deadline should be recieved serverside"
|
), "The provided deadline should be received serverside"
|
||||||
assert (
|
assert (
|
||||||
stream.metadata["authorization"] == metadata["authorization"]
|
stream.metadata["authorization"] == metadata["authorization"]
|
||||||
), "The provided authorization metadata should be recieved serverside"
|
), "The provided authorization metadata should be received serverside"
|
||||||
|
|
||||||
return server_side_test
|
return server_side_test
|
||||||
|
|
||||||
@ -42,7 +41,7 @@ async def test_service_call_with_upfront_request_params():
|
|||||||
deadline = grpclib.metadata.Deadline.from_timeout(22)
|
deadline = grpclib.metadata.Deadline.from_timeout(22)
|
||||||
metadata = {"authorization": "12345"}
|
metadata = {"authorization": "12345"}
|
||||||
async with ChannelFor(
|
async with ChannelFor(
|
||||||
[ThingService(test_hook=_assert_request_meta_recieved(deadline, metadata),)]
|
[ThingService(test_hook=_assert_request_meta_received(deadline, metadata),)]
|
||||||
) as channel:
|
) as channel:
|
||||||
await _test_client(
|
await _test_client(
|
||||||
ThingServiceClient(channel, deadline=deadline, metadata=metadata)
|
ThingServiceClient(channel, deadline=deadline, metadata=metadata)
|
||||||
@ -53,7 +52,7 @@ async def test_service_call_with_upfront_request_params():
|
|||||||
deadline = grpclib.metadata.Deadline.from_timeout(timeout)
|
deadline = grpclib.metadata.Deadline.from_timeout(timeout)
|
||||||
metadata = {"authorization": "12345"}
|
metadata = {"authorization": "12345"}
|
||||||
async with ChannelFor(
|
async with ChannelFor(
|
||||||
[ThingService(test_hook=_assert_request_meta_recieved(deadline, metadata),)]
|
[ThingService(test_hook=_assert_request_meta_received(deadline, metadata),)]
|
||||||
) as channel:
|
) as channel:
|
||||||
await _test_client(
|
await _test_client(
|
||||||
ThingServiceClient(channel, timeout=timeout, metadata=metadata)
|
ThingServiceClient(channel, timeout=timeout, metadata=metadata)
|
||||||
@ -70,7 +69,7 @@ async def test_service_call_lower_level_with_overrides():
|
|||||||
kwarg_deadline = grpclib.metadata.Deadline.from_timeout(28)
|
kwarg_deadline = grpclib.metadata.Deadline.from_timeout(28)
|
||||||
kwarg_metadata = {"authorization": "12345"}
|
kwarg_metadata = {"authorization": "12345"}
|
||||||
async with ChannelFor(
|
async with ChannelFor(
|
||||||
[ThingService(test_hook=_assert_request_meta_recieved(deadline, metadata),)]
|
[ThingService(test_hook=_assert_request_meta_received(deadline, metadata),)]
|
||||||
) as channel:
|
) as channel:
|
||||||
client = ThingServiceClient(channel, deadline=deadline, metadata=metadata)
|
client = ThingServiceClient(channel, deadline=deadline, metadata=metadata)
|
||||||
response = await client._unary_unary(
|
response = await client._unary_unary(
|
||||||
@ -92,7 +91,7 @@ async def test_service_call_lower_level_with_overrides():
|
|||||||
async with ChannelFor(
|
async with ChannelFor(
|
||||||
[
|
[
|
||||||
ThingService(
|
ThingService(
|
||||||
test_hook=_assert_request_meta_recieved(kwarg_deadline, kwarg_metadata),
|
test_hook=_assert_request_meta_received(kwarg_deadline, kwarg_metadata),
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
) as channel:
|
) as channel:
|
||||||
@ -140,8 +139,8 @@ async def test_async_gen_for_stream_stream_request():
|
|||||||
assert response.version == response_index + 1
|
assert response.version == response_index + 1
|
||||||
response_index += 1
|
response_index += 1
|
||||||
if more_things:
|
if more_things:
|
||||||
# Send some more requests as we recieve reponses to be sure coordination of
|
# Send some more requests as we receive responses to be sure coordination of
|
||||||
# send/recieve events doesn't matter
|
# send/receive events doesn't matter
|
||||||
await request_chan.send(GetThingRequest(more_things.pop(0)))
|
await request_chan.send(GetThingRequest(more_things.pop(0)))
|
||||||
elif not send_initial_requests.done():
|
elif not send_initial_requests.done():
|
||||||
# Make sure the sending task it completed
|
# Make sure the sending task it completed
|
||||||
@ -151,4 +150,4 @@ async def test_async_gen_for_stream_stream_request():
|
|||||||
request_chan.close()
|
request_chan.close()
|
||||||
assert response_index == len(
|
assert response_index == len(
|
||||||
expected_things
|
expected_things
|
||||||
), "Didn't recieve all exptected responses"
|
), "Didn't receive all expected responses"
|
||||||
|
@ -3,10 +3,10 @@ from betterproto.tests.output_betterproto.service.service import (
|
|||||||
DoThingRequest,
|
DoThingRequest,
|
||||||
GetThingRequest,
|
GetThingRequest,
|
||||||
GetThingResponse,
|
GetThingResponse,
|
||||||
TestStub as ThingServiceClient,
|
|
||||||
)
|
)
|
||||||
import grpclib
|
import grpclib
|
||||||
from typing import Any, Dict
|
import grpclib.server
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
|
||||||
class ThingService:
|
class ThingService:
|
||||||
|
@ -21,7 +21,7 @@ test_cases = [
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases)
|
@pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases)
|
||||||
async def test_channel_recieves_wrapped_type(
|
async def test_channel_receives_wrapped_type(
|
||||||
service_method: Callable[[TestStub], Any], wrapper_class: Callable, value
|
service_method: Callable[[TestStub], Any], wrapper_class: Callable, value
|
||||||
):
|
):
|
||||||
wrapped_value = wrapper_class()
|
wrapped_value = wrapper_class()
|
||||||
|
@ -14,23 +14,23 @@ def test_has_field():
|
|||||||
|
|
||||||
# Unset by default
|
# Unset by default
|
||||||
foo = Foo()
|
foo = Foo()
|
||||||
assert betterproto.serialized_on_wire(foo.bar) == False
|
assert betterproto.serialized_on_wire(foo.bar) is False
|
||||||
|
|
||||||
# Serialized after setting something
|
# Serialized after setting something
|
||||||
foo.bar.baz = 1
|
foo.bar.baz = 1
|
||||||
assert betterproto.serialized_on_wire(foo.bar) == True
|
assert betterproto.serialized_on_wire(foo.bar) is True
|
||||||
|
|
||||||
# Still has it after setting the default value
|
# Still has it after setting the default value
|
||||||
foo.bar.baz = 0
|
foo.bar.baz = 0
|
||||||
assert betterproto.serialized_on_wire(foo.bar) == True
|
assert betterproto.serialized_on_wire(foo.bar) is True
|
||||||
|
|
||||||
# Manual override (don't do this)
|
# Manual override (don't do this)
|
||||||
foo.bar._serialized_on_wire = False
|
foo.bar._serialized_on_wire = False
|
||||||
assert betterproto.serialized_on_wire(foo.bar) == False
|
assert betterproto.serialized_on_wire(foo.bar) is False
|
||||||
|
|
||||||
# Can manually set it but defaults to false
|
# Can manually set it but defaults to false
|
||||||
foo.bar = Bar()
|
foo.bar = Bar()
|
||||||
assert betterproto.serialized_on_wire(foo.bar) == False
|
assert betterproto.serialized_on_wire(foo.bar) is False
|
||||||
|
|
||||||
|
|
||||||
def test_class_init():
|
def test_class_init():
|
||||||
@ -118,7 +118,7 @@ def test_oneof_support():
|
|||||||
|
|
||||||
# Group 1 shouldn't be touched, group 2 should have reset
|
# Group 1 shouldn't be touched, group 2 should have reset
|
||||||
assert foo.sub.val == 0
|
assert foo.sub.val == 0
|
||||||
assert betterproto.serialized_on_wire(foo.sub) == False
|
assert betterproto.serialized_on_wire(foo.sub) is False
|
||||||
assert betterproto.which_one_of(foo, "group2")[0] == "abc"
|
assert betterproto.which_one_of(foo, "group2")[0] == "abc"
|
||||||
|
|
||||||
# Zero value should always serialize for one-of
|
# Zero value should always serialize for one-of
|
||||||
@ -175,8 +175,8 @@ def test_optional_flag():
|
|||||||
assert bytes(Request(flag=False)) == b"\n\x00"
|
assert bytes(Request(flag=False)) == b"\n\x00"
|
||||||
|
|
||||||
# Differentiate between not passed and the zero-value.
|
# Differentiate between not passed and the zero-value.
|
||||||
assert Request().parse(b"").flag == None
|
assert Request().parse(b"").flag is None
|
||||||
assert Request().parse(b"\n\x00").flag == False
|
assert Request().parse(b"\n\x00").flag is False
|
||||||
|
|
||||||
|
|
||||||
def test_to_dict_default_values():
|
def test_to_dict_default_values():
|
||||||
|
Loading…
x
Reference in New Issue
Block a user