1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							@@ -12,4 +12,5 @@ dist
 | 
				
			|||||||
**/*.egg-info
 | 
					**/*.egg-info
 | 
				
			||||||
output
 | 
					output
 | 
				
			||||||
.idea
 | 
					.idea
 | 
				
			||||||
 | 
					.DS_Store
 | 
				
			||||||
.tox
 | 
					.tox
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -5,8 +5,9 @@ import json
 | 
				
			|||||||
import struct
 | 
					import struct
 | 
				
			||||||
import sys
 | 
					import sys
 | 
				
			||||||
from abc import ABC
 | 
					from abc import ABC
 | 
				
			||||||
from base64 import b64encode, b64decode
 | 
					from base64 import b64decode, b64encode
 | 
				
			||||||
from datetime import datetime, timedelta, timezone
 | 
					from datetime import datetime, timedelta, timezone
 | 
				
			||||||
 | 
					import stringcase
 | 
				
			||||||
from typing import (
 | 
					from typing import (
 | 
				
			||||||
    Any,
 | 
					    Any,
 | 
				
			||||||
    AsyncGenerator,
 | 
					    AsyncGenerator,
 | 
				
			||||||
@@ -14,28 +15,20 @@ from typing import (
 | 
				
			|||||||
    Collection,
 | 
					    Collection,
 | 
				
			||||||
    Dict,
 | 
					    Dict,
 | 
				
			||||||
    Generator,
 | 
					    Generator,
 | 
				
			||||||
 | 
					    Iterator,
 | 
				
			||||||
    List,
 | 
					    List,
 | 
				
			||||||
    Mapping,
 | 
					    Mapping,
 | 
				
			||||||
    Optional,
 | 
					    Optional,
 | 
				
			||||||
    Set,
 | 
					    Set,
 | 
				
			||||||
 | 
					    SupportsBytes,
 | 
				
			||||||
    Tuple,
 | 
					    Tuple,
 | 
				
			||||||
    Type,
 | 
					    Type,
 | 
				
			||||||
    TypeVar,
 | 
					 | 
				
			||||||
    Union,
 | 
					    Union,
 | 
				
			||||||
    get_type_hints,
 | 
					    get_type_hints,
 | 
				
			||||||
    TYPE_CHECKING,
 | 
					 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					from ._types import ST, T
 | 
				
			||||||
 | 
					 | 
				
			||||||
import grpclib.const
 | 
					 | 
				
			||||||
import stringcase
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from .casing import safe_snake_case
 | 
					from .casing import safe_snake_case
 | 
				
			||||||
 | 
					from .grpc.grpclib_client import ServiceStub
 | 
				
			||||||
if TYPE_CHECKING:
 | 
					 | 
				
			||||||
    from grpclib._protocols import IProtoMessage
 | 
					 | 
				
			||||||
    from grpclib.client import Channel
 | 
					 | 
				
			||||||
    from grpclib.metadata import Deadline
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
if not (sys.version_info.major == 3 and sys.version_info.minor >= 7):
 | 
					if not (sys.version_info.major == 3 and sys.version_info.minor >= 7):
 | 
				
			||||||
    # Apply backport of datetime.fromisoformat from 3.7
 | 
					    # Apply backport of datetime.fromisoformat from 3.7
 | 
				
			||||||
@@ -429,10 +422,6 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
 | 
				
			|||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Bound type variable to allow methods to return `self` of subclasses
 | 
					 | 
				
			||||||
T = TypeVar("T", bound="Message")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class ProtoClassMetadata:
 | 
					class ProtoClassMetadata:
 | 
				
			||||||
    oneof_group_by_field: Dict[str, str]
 | 
					    oneof_group_by_field: Dict[str, str]
 | 
				
			||||||
    oneof_field_by_group: Dict[str, Set[dataclasses.Field]]
 | 
					    oneof_field_by_group: Dict[str, Set[dataclasses.Field]]
 | 
				
			||||||
@@ -451,7 +440,7 @@ class ProtoClassMetadata:
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    def __init__(self, cls: Type["Message"]):
 | 
					    def __init__(self, cls: Type["Message"]):
 | 
				
			||||||
        by_field = {}
 | 
					        by_field = {}
 | 
				
			||||||
        by_group = {}
 | 
					        by_group: Dict[str, Set] = {}
 | 
				
			||||||
        by_field_name = {}
 | 
					        by_field_name = {}
 | 
				
			||||||
        by_field_number = {}
 | 
					        by_field_number = {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -604,7 +593,7 @@ class Message(ABC):
 | 
				
			|||||||
            serialize_empty = False
 | 
					            serialize_empty = False
 | 
				
			||||||
            if isinstance(value, Message) and value._serialized_on_wire:
 | 
					            if isinstance(value, Message) and value._serialized_on_wire:
 | 
				
			||||||
                # Empty messages can still be sent on the wire if they were
 | 
					                # Empty messages can still be sent on the wire if they were
 | 
				
			||||||
                # set (or received empty).
 | 
					                # set (or recieved empty).
 | 
				
			||||||
                serialize_empty = True
 | 
					                serialize_empty = True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            if value == self._get_field_default(field_name) and not (
 | 
					            if value == self._get_field_default(field_name) and not (
 | 
				
			||||||
@@ -791,7 +780,7 @@ class Message(ABC):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    def to_dict(
 | 
					    def to_dict(
 | 
				
			||||||
        self, casing: Casing = Casing.CAMEL, include_default_values: bool = False
 | 
					        self, casing: Casing = Casing.CAMEL, include_default_values: bool = False
 | 
				
			||||||
    ) -> dict:
 | 
					    ) -> Dict[str, Any]:
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        Returns a dict representation of this message instance which can be
 | 
					        Returns a dict representation of this message instance which can be
 | 
				
			||||||
        used to serialize to e.g. JSON. Defaults to camel casing for
 | 
					        used to serialize to e.g. JSON. Defaults to camel casing for
 | 
				
			||||||
@@ -1024,83 +1013,3 @@ def _get_wrapper(proto_type: str) -> Type:
 | 
				
			|||||||
        TYPE_STRING: StringValue,
 | 
					        TYPE_STRING: StringValue,
 | 
				
			||||||
        TYPE_BYTES: BytesValue,
 | 
					        TYPE_BYTES: BytesValue,
 | 
				
			||||||
    }[proto_type]
 | 
					    }[proto_type]
 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
_Value = Union[str, bytes]
 | 
					 | 
				
			||||||
_MetadataLike = Union[Mapping[str, _Value], Collection[Tuple[str, _Value]]]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class ServiceStub(ABC):
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Base class for async gRPC service stubs.
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __init__(
 | 
					 | 
				
			||||||
        self,
 | 
					 | 
				
			||||||
        channel: "Channel",
 | 
					 | 
				
			||||||
        *,
 | 
					 | 
				
			||||||
        timeout: Optional[float] = None,
 | 
					 | 
				
			||||||
        deadline: Optional["Deadline"] = None,
 | 
					 | 
				
			||||||
        metadata: Optional[_MetadataLike] = None,
 | 
					 | 
				
			||||||
    ) -> None:
 | 
					 | 
				
			||||||
        self.channel = channel
 | 
					 | 
				
			||||||
        self.timeout = timeout
 | 
					 | 
				
			||||||
        self.deadline = deadline
 | 
					 | 
				
			||||||
        self.metadata = metadata
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __resolve_request_kwargs(
 | 
					 | 
				
			||||||
        self,
 | 
					 | 
				
			||||||
        timeout: Optional[float],
 | 
					 | 
				
			||||||
        deadline: Optional["Deadline"],
 | 
					 | 
				
			||||||
        metadata: Optional[_MetadataLike],
 | 
					 | 
				
			||||||
    ):
 | 
					 | 
				
			||||||
        return {
 | 
					 | 
				
			||||||
            "timeout": self.timeout if timeout is None else timeout,
 | 
					 | 
				
			||||||
            "deadline": self.deadline if deadline is None else deadline,
 | 
					 | 
				
			||||||
            "metadata": self.metadata if metadata is None else metadata,
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    async def _unary_unary(
 | 
					 | 
				
			||||||
        self,
 | 
					 | 
				
			||||||
        route: str,
 | 
					 | 
				
			||||||
        request: "IProtoMessage",
 | 
					 | 
				
			||||||
        response_type: Type[T],
 | 
					 | 
				
			||||||
        *,
 | 
					 | 
				
			||||||
        timeout: Optional[float] = None,
 | 
					 | 
				
			||||||
        deadline: Optional["Deadline"] = None,
 | 
					 | 
				
			||||||
        metadata: Optional[_MetadataLike] = None,
 | 
					 | 
				
			||||||
    ) -> T:
 | 
					 | 
				
			||||||
        """Make a unary request and return the response."""
 | 
					 | 
				
			||||||
        async with self.channel.request(
 | 
					 | 
				
			||||||
            route,
 | 
					 | 
				
			||||||
            grpclib.const.Cardinality.UNARY_UNARY,
 | 
					 | 
				
			||||||
            type(request),
 | 
					 | 
				
			||||||
            response_type,
 | 
					 | 
				
			||||||
            **self.__resolve_request_kwargs(timeout, deadline, metadata),
 | 
					 | 
				
			||||||
        ) as stream:
 | 
					 | 
				
			||||||
            await stream.send_message(request, end=True)
 | 
					 | 
				
			||||||
            response = await stream.recv_message()
 | 
					 | 
				
			||||||
            assert response is not None
 | 
					 | 
				
			||||||
            return response
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    async def _unary_stream(
 | 
					 | 
				
			||||||
        self,
 | 
					 | 
				
			||||||
        route: str,
 | 
					 | 
				
			||||||
        request: "IProtoMessage",
 | 
					 | 
				
			||||||
        response_type: Type[T],
 | 
					 | 
				
			||||||
        *,
 | 
					 | 
				
			||||||
        timeout: Optional[float] = None,
 | 
					 | 
				
			||||||
        deadline: Optional["Deadline"] = None,
 | 
					 | 
				
			||||||
        metadata: Optional[_MetadataLike] = None,
 | 
					 | 
				
			||||||
    ) -> AsyncGenerator[T, None]:
 | 
					 | 
				
			||||||
        """Make a unary request and return the stream response iterator."""
 | 
					 | 
				
			||||||
        async with self.channel.request(
 | 
					 | 
				
			||||||
            route,
 | 
					 | 
				
			||||||
            grpclib.const.Cardinality.UNARY_STREAM,
 | 
					 | 
				
			||||||
            type(request),
 | 
					 | 
				
			||||||
            response_type,
 | 
					 | 
				
			||||||
            **self.__resolve_request_kwargs(timeout, deadline, metadata),
 | 
					 | 
				
			||||||
        ) as stream:
 | 
					 | 
				
			||||||
            await stream.send_message(request, end=True)
 | 
					 | 
				
			||||||
            async for message in stream:
 | 
					 | 
				
			||||||
                yield message
 | 
					 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										9
									
								
								betterproto/_types.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										9
									
								
								betterproto/_types.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,9 @@
 | 
				
			|||||||
 | 
					from typing import TYPE_CHECKING, TypeVar
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if TYPE_CHECKING:
 | 
				
			||||||
 | 
					    from . import Message
 | 
				
			||||||
 | 
					    from grpclib._protocols import IProtoMessage
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Bound type variable to allow methods to return `self` of subclasses
 | 
				
			||||||
 | 
					T = TypeVar("T", bound="Message")
 | 
				
			||||||
 | 
					ST = TypeVar("ST", bound="IProtoMessage")
 | 
				
			||||||
							
								
								
									
										0
									
								
								betterproto/grpc/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								betterproto/grpc/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										170
									
								
								betterproto/grpc/grpclib_client.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										170
									
								
								betterproto/grpc/grpclib_client.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,170 @@
 | 
				
			|||||||
 | 
					from abc import ABC
 | 
				
			||||||
 | 
					import asyncio
 | 
				
			||||||
 | 
					import grpclib.const
 | 
				
			||||||
 | 
					from typing import (
 | 
				
			||||||
 | 
					    Any,
 | 
				
			||||||
 | 
					    AsyncIterable,
 | 
				
			||||||
 | 
					    AsyncIterator,
 | 
				
			||||||
 | 
					    Collection,
 | 
				
			||||||
 | 
					    Iterable,
 | 
				
			||||||
 | 
					    Mapping,
 | 
				
			||||||
 | 
					    Optional,
 | 
				
			||||||
 | 
					    Tuple,
 | 
				
			||||||
 | 
					    TYPE_CHECKING,
 | 
				
			||||||
 | 
					    Type,
 | 
				
			||||||
 | 
					    Union,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					from .._types import ST, T
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if TYPE_CHECKING:
 | 
				
			||||||
 | 
					    from grpclib._protocols import IProtoMessage
 | 
				
			||||||
 | 
					    from grpclib.client import Channel, Stream
 | 
				
			||||||
 | 
					    from grpclib.metadata import Deadline
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					_Value = Union[str, bytes]
 | 
				
			||||||
 | 
					_MetadataLike = Union[Mapping[str, _Value], Collection[Tuple[str, _Value]]]
 | 
				
			||||||
 | 
					_MessageSource = Union[Iterable["IProtoMessage"], AsyncIterable["IProtoMessage"]]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class ServiceStub(ABC):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Base class for async gRPC clients.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        channel: "Channel",
 | 
				
			||||||
 | 
					        *,
 | 
				
			||||||
 | 
					        timeout: Optional[float] = None,
 | 
				
			||||||
 | 
					        deadline: Optional["Deadline"] = None,
 | 
				
			||||||
 | 
					        metadata: Optional[_MetadataLike] = None,
 | 
				
			||||||
 | 
					    ) -> None:
 | 
				
			||||||
 | 
					        self.channel = channel
 | 
				
			||||||
 | 
					        self.timeout = timeout
 | 
				
			||||||
 | 
					        self.deadline = deadline
 | 
				
			||||||
 | 
					        self.metadata = metadata
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __resolve_request_kwargs(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        timeout: Optional[float],
 | 
				
			||||||
 | 
					        deadline: Optional["Deadline"],
 | 
				
			||||||
 | 
					        metadata: Optional[_MetadataLike],
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        return {
 | 
				
			||||||
 | 
					            "timeout": self.timeout if timeout is None else timeout,
 | 
				
			||||||
 | 
					            "deadline": self.deadline if deadline is None else deadline,
 | 
				
			||||||
 | 
					            "metadata": self.metadata if metadata is None else metadata,
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def _unary_unary(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        route: str,
 | 
				
			||||||
 | 
					        request: "IProtoMessage",
 | 
				
			||||||
 | 
					        response_type: Type[T],
 | 
				
			||||||
 | 
					        *,
 | 
				
			||||||
 | 
					        timeout: Optional[float] = None,
 | 
				
			||||||
 | 
					        deadline: Optional["Deadline"] = None,
 | 
				
			||||||
 | 
					        metadata: Optional[_MetadataLike] = None,
 | 
				
			||||||
 | 
					    ) -> T:
 | 
				
			||||||
 | 
					        """Make a unary request and return the response."""
 | 
				
			||||||
 | 
					        async with self.channel.request(
 | 
				
			||||||
 | 
					            route,
 | 
				
			||||||
 | 
					            grpclib.const.Cardinality.UNARY_UNARY,
 | 
				
			||||||
 | 
					            type(request),
 | 
				
			||||||
 | 
					            response_type,
 | 
				
			||||||
 | 
					            **self.__resolve_request_kwargs(timeout, deadline, metadata),
 | 
				
			||||||
 | 
					        ) as stream:
 | 
				
			||||||
 | 
					            await stream.send_message(request, end=True)
 | 
				
			||||||
 | 
					            response = await stream.recv_message()
 | 
				
			||||||
 | 
					            assert response is not None
 | 
				
			||||||
 | 
					            return response
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def _unary_stream(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        route: str,
 | 
				
			||||||
 | 
					        request: "IProtoMessage",
 | 
				
			||||||
 | 
					        response_type: Type[T],
 | 
				
			||||||
 | 
					        *,
 | 
				
			||||||
 | 
					        timeout: Optional[float] = None,
 | 
				
			||||||
 | 
					        deadline: Optional["Deadline"] = None,
 | 
				
			||||||
 | 
					        metadata: Optional[_MetadataLike] = None,
 | 
				
			||||||
 | 
					    ) -> AsyncIterator[T]:
 | 
				
			||||||
 | 
					        """Make a unary request and return the stream response iterator."""
 | 
				
			||||||
 | 
					        async with self.channel.request(
 | 
				
			||||||
 | 
					            route,
 | 
				
			||||||
 | 
					            grpclib.const.Cardinality.UNARY_STREAM,
 | 
				
			||||||
 | 
					            type(request),
 | 
				
			||||||
 | 
					            response_type,
 | 
				
			||||||
 | 
					            **self.__resolve_request_kwargs(timeout, deadline, metadata),
 | 
				
			||||||
 | 
					        ) as stream:
 | 
				
			||||||
 | 
					            await stream.send_message(request, end=True)
 | 
				
			||||||
 | 
					            async for message in stream:
 | 
				
			||||||
 | 
					                yield message
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def _stream_unary(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        route: str,
 | 
				
			||||||
 | 
					        request_iterator: _MessageSource,
 | 
				
			||||||
 | 
					        request_type: Type[ST],
 | 
				
			||||||
 | 
					        response_type: Type[T],
 | 
				
			||||||
 | 
					        *,
 | 
				
			||||||
 | 
					        timeout: Optional[float] = None,
 | 
				
			||||||
 | 
					        deadline: Optional["Deadline"] = None,
 | 
				
			||||||
 | 
					        metadata: Optional[_MetadataLike] = None,
 | 
				
			||||||
 | 
					    ) -> T:
 | 
				
			||||||
 | 
					        """Make a stream request and return the response."""
 | 
				
			||||||
 | 
					        async with self.channel.request(
 | 
				
			||||||
 | 
					            route,
 | 
				
			||||||
 | 
					            grpclib.const.Cardinality.STREAM_UNARY,
 | 
				
			||||||
 | 
					            request_type,
 | 
				
			||||||
 | 
					            response_type,
 | 
				
			||||||
 | 
					            **self.__resolve_request_kwargs(timeout, deadline, metadata),
 | 
				
			||||||
 | 
					        ) as stream:
 | 
				
			||||||
 | 
					            await self._send_messages(stream, request_iterator)
 | 
				
			||||||
 | 
					            response = await stream.recv_message()
 | 
				
			||||||
 | 
					            assert response is not None
 | 
				
			||||||
 | 
					            return response
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def _stream_stream(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        route: str,
 | 
				
			||||||
 | 
					        request_iterator: _MessageSource,
 | 
				
			||||||
 | 
					        request_type: Type[ST],
 | 
				
			||||||
 | 
					        response_type: Type[T],
 | 
				
			||||||
 | 
					        *,
 | 
				
			||||||
 | 
					        timeout: Optional[float] = None,
 | 
				
			||||||
 | 
					        deadline: Optional["Deadline"] = None,
 | 
				
			||||||
 | 
					        metadata: Optional[_MetadataLike] = None,
 | 
				
			||||||
 | 
					    ) -> AsyncIterator[T]:
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Make a stream request and return an AsyncIterator to iterate over response
 | 
				
			||||||
 | 
					        messages.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        async with self.channel.request(
 | 
				
			||||||
 | 
					            route,
 | 
				
			||||||
 | 
					            grpclib.const.Cardinality.STREAM_STREAM,
 | 
				
			||||||
 | 
					            request_type,
 | 
				
			||||||
 | 
					            response_type,
 | 
				
			||||||
 | 
					            **self.__resolve_request_kwargs(timeout, deadline, metadata),
 | 
				
			||||||
 | 
					        ) as stream:
 | 
				
			||||||
 | 
					            await stream.send_request()
 | 
				
			||||||
 | 
					            sending_task = asyncio.ensure_future(
 | 
				
			||||||
 | 
					                self._send_messages(stream, request_iterator)
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            try:
 | 
				
			||||||
 | 
					                async for response in stream:
 | 
				
			||||||
 | 
					                    yield response
 | 
				
			||||||
 | 
					            except:
 | 
				
			||||||
 | 
					                sending_task.cancel()
 | 
				
			||||||
 | 
					                raise
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @staticmethod
 | 
				
			||||||
 | 
					    async def _send_messages(stream, messages: _MessageSource):
 | 
				
			||||||
 | 
					        if isinstance(messages, AsyncIterable):
 | 
				
			||||||
 | 
					            async for message in messages:
 | 
				
			||||||
 | 
					                await stream.send_message(message)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            for message in messages:
 | 
				
			||||||
 | 
					                await stream.send_message(message)
 | 
				
			||||||
 | 
					        await stream.end()
 | 
				
			||||||
							
								
								
									
										0
									
								
								betterproto/grpc/util/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								betterproto/grpc/util/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										198
									
								
								betterproto/grpc/util/async_channel.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										198
									
								
								betterproto/grpc/util/async_channel.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,198 @@
 | 
				
			|||||||
 | 
					import asyncio
 | 
				
			||||||
 | 
					from typing import (
 | 
				
			||||||
 | 
					    AsyncIterable,
 | 
				
			||||||
 | 
					    AsyncIterator,
 | 
				
			||||||
 | 
					    Iterable,
 | 
				
			||||||
 | 
					    Optional,
 | 
				
			||||||
 | 
					    TypeVar,
 | 
				
			||||||
 | 
					    Union,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					T = TypeVar("T")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class ChannelClosed(Exception):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    An exception raised on an attempt to send through a closed channel
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class ChannelDone(Exception):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    An exception raised on an attempt to send recieve from a channel that is both closed
 | 
				
			||||||
 | 
					    and empty.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class AsyncChannel(AsyncIterable[T]):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    A buffered async channel for sending items between coroutines with FIFO ordering.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    This makes decoupled bidirection steaming gRPC requests easy if used like:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    .. code-block:: python
 | 
				
			||||||
 | 
					        client = GeneratedStub(grpclib_chan)
 | 
				
			||||||
 | 
					        request_chan = await AsyncChannel()
 | 
				
			||||||
 | 
					        # We can start be sending all the requests we already have
 | 
				
			||||||
 | 
					        await request_chan.send_from([ReqestObject(...), ReqestObject(...)])
 | 
				
			||||||
 | 
					        async for response in client.rpc_call(request_chan):
 | 
				
			||||||
 | 
					            # The response iterator will remain active until the connection is closed
 | 
				
			||||||
 | 
					            ...
 | 
				
			||||||
 | 
					            # More items can be sent at any time
 | 
				
			||||||
 | 
					            await request_chan.send(ReqestObject(...))
 | 
				
			||||||
 | 
					            ...
 | 
				
			||||||
 | 
					            # The channel must be closed to complete the gRPC connection
 | 
				
			||||||
 | 
					            request_chan.close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Items can be sent through the channel by either:
 | 
				
			||||||
 | 
					    - providing an iterable to the send_from method
 | 
				
			||||||
 | 
					    - passing them to the send method one at a time
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Items can be recieved from the channel by either:
 | 
				
			||||||
 | 
					    - iterating over the channel with a for loop to get all items
 | 
				
			||||||
 | 
					    - calling the recieve method to get one item at a time
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    If the channel is empty then recievers will wait until either an item appears or the
 | 
				
			||||||
 | 
					    channel is closed.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Once the channel is closed then subsequent attempt to send through the channel will
 | 
				
			||||||
 | 
					    fail with a ChannelClosed exception.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    When th channel is closed and empty then it is done, and further attempts to recieve
 | 
				
			||||||
 | 
					    from it will fail with a ChannelDone exception
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    If multiple coroutines recieve from the channel concurrently, each item sent will be
 | 
				
			||||||
 | 
					    recieved by only one of the recievers.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    :param source:
 | 
				
			||||||
 | 
					        An optional iterable will items that should be sent through the channel
 | 
				
			||||||
 | 
					        immediately.
 | 
				
			||||||
 | 
					    :param buffer_limit:
 | 
				
			||||||
 | 
					        Limit the number of items that can be buffered in the channel, A value less than
 | 
				
			||||||
 | 
					        1 implies no limit. If the channel is full then attempts to send more items will
 | 
				
			||||||
 | 
					        result in the sender waiting until an item is recieved from the channel.
 | 
				
			||||||
 | 
					    :param close:
 | 
				
			||||||
 | 
					        If set to True then the channel will automatically close after exhausting source
 | 
				
			||||||
 | 
					        or immediately if no source is provided.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(
 | 
				
			||||||
 | 
					        self, *, buffer_limit: int = 0, close: bool = False,
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        self._queue: asyncio.Queue[Union[T, object]] = asyncio.Queue(buffer_limit)
 | 
				
			||||||
 | 
					        self._closed = False
 | 
				
			||||||
 | 
					        self._waiting_recievers: int = 0
 | 
				
			||||||
 | 
					        # Track whether flush has been invoked so it can only happen once
 | 
				
			||||||
 | 
					        self._flushed = False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __aiter__(self) -> AsyncIterator[T]:
 | 
				
			||||||
 | 
					        return self
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def __anext__(self) -> T:
 | 
				
			||||||
 | 
					        if self.done():
 | 
				
			||||||
 | 
					            raise StopAsyncIteration
 | 
				
			||||||
 | 
					        self._waiting_recievers += 1
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            result = await self._queue.get()
 | 
				
			||||||
 | 
					            if result is self.__flush:
 | 
				
			||||||
 | 
					                raise StopAsyncIteration
 | 
				
			||||||
 | 
					            return result
 | 
				
			||||||
 | 
					        finally:
 | 
				
			||||||
 | 
					            self._waiting_recievers -= 1
 | 
				
			||||||
 | 
					            self._queue.task_done()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def closed(self) -> bool:
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Returns True if this channel is closed and no-longer accepting new items
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        return self._closed
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def done(self) -> bool:
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Check if this channel is done.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        :return: True if this channel is closed and and has been drained of items in
 | 
				
			||||||
 | 
					        which case any further attempts to recieve an item from this channel will raise
 | 
				
			||||||
 | 
					        a ChannelDone exception.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        # After close the channel is not yet done until there is at least one waiting
 | 
				
			||||||
 | 
					        # reciever per enqueued item.
 | 
				
			||||||
 | 
					        return self._closed and self._queue.qsize() <= self._waiting_recievers
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def send_from(
 | 
				
			||||||
 | 
					        self, source: Union[Iterable[T], AsyncIterable[T]], close: bool = False
 | 
				
			||||||
 | 
					    ) -> "AsyncChannel[T]":
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Iterates the given [Async]Iterable and sends all the resulting items.
 | 
				
			||||||
 | 
					        If close is set to True then subsequent send calls will be rejected with a
 | 
				
			||||||
 | 
					        ChannelClosed exception.
 | 
				
			||||||
 | 
					        :param source: an iterable of items to send
 | 
				
			||||||
 | 
					        :param close:
 | 
				
			||||||
 | 
					            if True then the channel will be closed after the source has been exhausted
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        if self._closed:
 | 
				
			||||||
 | 
					            raise ChannelClosed("Cannot send through a closed channel")
 | 
				
			||||||
 | 
					        if isinstance(source, AsyncIterable):
 | 
				
			||||||
 | 
					            async for item in source:
 | 
				
			||||||
 | 
					                await self._queue.put(item)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            for item in source:
 | 
				
			||||||
 | 
					                await self._queue.put(item)
 | 
				
			||||||
 | 
					        if close:
 | 
				
			||||||
 | 
					            # Complete the closing process
 | 
				
			||||||
 | 
					            self.close()
 | 
				
			||||||
 | 
					        return self
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def send(self, item: T) -> "AsyncChannel[T]":
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Send a single item over this channel.
 | 
				
			||||||
 | 
					        :param item: The item to send
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        if self._closed:
 | 
				
			||||||
 | 
					            raise ChannelClosed("Cannot send through a closed channel")
 | 
				
			||||||
 | 
					        await self._queue.put(item)
 | 
				
			||||||
 | 
					        return self
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def recieve(self) -> Optional[T]:
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Returns the next item from this channel when it becomes available,
 | 
				
			||||||
 | 
					        or None if the channel is closed before another item is sent.
 | 
				
			||||||
 | 
					        :return: An item from the channel
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        if self.done():
 | 
				
			||||||
 | 
					            raise ChannelDone("Cannot recieve from a closed channel")
 | 
				
			||||||
 | 
					        self._waiting_recievers += 1
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            result = await self._queue.get()
 | 
				
			||||||
 | 
					            if result is self.__flush:
 | 
				
			||||||
 | 
					                return None
 | 
				
			||||||
 | 
					            return result
 | 
				
			||||||
 | 
					        finally:
 | 
				
			||||||
 | 
					            self._waiting_recievers -= 1
 | 
				
			||||||
 | 
					            self._queue.task_done()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def close(self):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Close this channel to new items
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        self._closed = True
 | 
				
			||||||
 | 
					        asyncio.ensure_future(self._flush_queue())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def _flush_queue(self):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        To be called after the channel is closed. Pushes a number of self.__flush
 | 
				
			||||||
 | 
					        objects to the queue to ensure no waiting consumers get deadlocked.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        if not self._flushed:
 | 
				
			||||||
 | 
					            self._flushed = True
 | 
				
			||||||
 | 
					            deadlocked_recievers = max(0, self._waiting_recievers - self._queue.qsize())
 | 
				
			||||||
 | 
					            for _ in range(deadlocked_recievers):
 | 
				
			||||||
 | 
					                await self._queue.put(self.__flush)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # A special signal object for flushing the queue when the channel is closed
 | 
				
			||||||
 | 
					    __flush = object()
 | 
				
			||||||
@@ -6,10 +6,10 @@ import re
 | 
				
			|||||||
import stringcase
 | 
					import stringcase
 | 
				
			||||||
import sys
 | 
					import sys
 | 
				
			||||||
import textwrap
 | 
					import textwrap
 | 
				
			||||||
from typing import List
 | 
					from typing import List, Union
 | 
				
			||||||
 | 
					import betterproto
 | 
				
			||||||
from betterproto.casing import safe_snake_case
 | 
					from betterproto.casing import safe_snake_case
 | 
				
			||||||
from betterproto.compile.importing import get_ref_type
 | 
					from betterproto.compile.importing import get_ref_type
 | 
				
			||||||
import betterproto
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
try:
 | 
					try:
 | 
				
			||||||
    # betterproto[compiler] specific dependencies
 | 
					    # betterproto[compiler] specific dependencies
 | 
				
			||||||
@@ -58,8 +58,8 @@ def py_type(
 | 
				
			|||||||
        raise NotImplementedError(f"Unknown type {descriptor.type}")
 | 
					        raise NotImplementedError(f"Unknown type {descriptor.type}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def get_py_zero(type_num: int) -> str:
 | 
					def get_py_zero(type_num: int) -> Union[str, float]:
 | 
				
			||||||
    zero = 0
 | 
					    zero: Union[str, float] = 0
 | 
				
			||||||
    if type_num in []:
 | 
					    if type_num in []:
 | 
				
			||||||
        zero = 0.0
 | 
					        zero = 0.0
 | 
				
			||||||
    elif type_num == 8:
 | 
					    elif type_num == 8:
 | 
				
			||||||
@@ -311,9 +311,6 @@ def generate_code(request, response):
 | 
				
			|||||||
                }
 | 
					                }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                for j, method in enumerate(service.method):
 | 
					                for j, method in enumerate(service.method):
 | 
				
			||||||
                    if method.client_streaming:
 | 
					 | 
				
			||||||
                        raise NotImplementedError("Client streaming not yet supported")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                    input_message = None
 | 
					                    input_message = None
 | 
				
			||||||
                    input_type = get_ref_type(
 | 
					                    input_type = get_ref_type(
 | 
				
			||||||
                        package, output["imports"], method.input_type
 | 
					                        package, output["imports"], method.input_type
 | 
				
			||||||
@@ -347,8 +344,12 @@ def generate_code(request, response):
 | 
				
			|||||||
                        }
 | 
					                        }
 | 
				
			||||||
                    )
 | 
					                    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    if method.client_streaming:
 | 
				
			||||||
 | 
					                        output["typing_imports"].add("AsyncIterable")
 | 
				
			||||||
 | 
					                        output["typing_imports"].add("Iterable")
 | 
				
			||||||
 | 
					                        output["typing_imports"].add("Union")
 | 
				
			||||||
                    if method.server_streaming:
 | 
					                    if method.server_streaming:
 | 
				
			||||||
                        output["typing_imports"].add("AsyncGenerator")
 | 
					                        output["typing_imports"].add("AsyncIterator")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                output["services"].append(data)
 | 
					                output["services"].append(data)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -63,11 +63,28 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    {% endif %}
 | 
					    {% endif %}
 | 
				
			||||||
    {% for method in service.methods %}
 | 
					    {% for method in service.methods %}
 | 
				
			||||||
    async def {{ method.py_name }}(self{% if method.input_message and method.input_message.properties %}, *, {% for field in method.input_message.properties %}{{ field.py_name }}: {% if field.zero == "None" and not field.type.startswith("Optional[") %}Optional[{{ field.type }}]{% else %}{{ field.type }}{% endif %} = {{ field.zero }}{% if not loop.last %}, {% endif %}{% endfor %}{% endif %}) -> {% if method.server_streaming %}AsyncGenerator[{{ method.output }}, None]{% else %}{{ method.output }}{% endif %}:
 | 
					    async def {{ method.py_name }}(self
 | 
				
			||||||
 | 
					        {%- if not method.client_streaming -%}
 | 
				
			||||||
 | 
					            {%- if method.input_message and method.input_message.properties -%}, *,
 | 
				
			||||||
 | 
					                {%- for field in method.input_message.properties -%}
 | 
				
			||||||
 | 
					                    {{ field.py_name }}: {% if field.zero == "None" and not field.type.startswith("Optional[") -%}
 | 
				
			||||||
 | 
					                                            Optional[{{ field.type }}]
 | 
				
			||||||
 | 
					                                         {%- else -%}
 | 
				
			||||||
 | 
					                                            {{ field.type }}
 | 
				
			||||||
 | 
					                                         {%- endif -%} = {{ field.zero }}
 | 
				
			||||||
 | 
					                    {%- if not loop.last %}, {% endif -%}
 | 
				
			||||||
 | 
					                {%- endfor -%}
 | 
				
			||||||
 | 
					            {%- endif -%}
 | 
				
			||||||
 | 
					        {%- else -%}
 | 
				
			||||||
 | 
					            {# Client streaming: need a request iterator instead #}
 | 
				
			||||||
 | 
					            , request_iterator: Union[AsyncIterable["{{ method.input }}"], Iterable["{{ method.input }}"]]
 | 
				
			||||||
 | 
					        {%- endif -%}
 | 
				
			||||||
 | 
					            ) -> {% if method.server_streaming %}AsyncIterator[{{ method.output }}]{% else %}{{ method.output }}{% endif %}:
 | 
				
			||||||
        {% if method.comment %}
 | 
					        {% if method.comment %}
 | 
				
			||||||
{{ method.comment }}
 | 
					{{ method.comment }}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        {% endif %}
 | 
					        {% endif %}
 | 
				
			||||||
 | 
					        {% if not method.client_streaming %}
 | 
				
			||||||
        request = {{ method.input }}()
 | 
					        request = {{ method.input }}()
 | 
				
			||||||
        {% for field in method.input_message.properties %}
 | 
					        {% for field in method.input_message.properties %}
 | 
				
			||||||
            {% if field.field_type == 'message' %}
 | 
					            {% if field.field_type == 'message' %}
 | 
				
			||||||
@@ -77,20 +94,41 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
 | 
				
			|||||||
        request.{{ field.py_name }} = {{ field.py_name }}
 | 
					        request.{{ field.py_name }} = {{ field.py_name }}
 | 
				
			||||||
            {% endif %}
 | 
					            {% endif %}
 | 
				
			||||||
        {% endfor %}
 | 
					        {% endfor %}
 | 
				
			||||||
 | 
					        {% endif %}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        {% if method.server_streaming %}
 | 
					        {% if method.server_streaming %}
 | 
				
			||||||
 | 
					            {% if method.client_streaming %}
 | 
				
			||||||
 | 
					        async for response in self._stream_stream(
 | 
				
			||||||
 | 
					            "{{ method.route }}",
 | 
				
			||||||
 | 
					            request_iterator,
 | 
				
			||||||
 | 
					            {{ method.input }},
 | 
				
			||||||
 | 
					            {{ method.output }},
 | 
				
			||||||
 | 
					        ):
 | 
				
			||||||
 | 
					            yield response
 | 
				
			||||||
 | 
					            {% else %}{# i.e. not client streaming #}
 | 
				
			||||||
        async for response in self._unary_stream(
 | 
					        async for response in self._unary_stream(
 | 
				
			||||||
            "{{ method.route }}",
 | 
					            "{{ method.route }}",
 | 
				
			||||||
            request,
 | 
					            request,
 | 
				
			||||||
            {{ method.output }},
 | 
					            {{ method.output }},
 | 
				
			||||||
        ):
 | 
					        ):
 | 
				
			||||||
            yield response
 | 
					            yield response
 | 
				
			||||||
        {% else %}
 | 
					
 | 
				
			||||||
 | 
					            {% endif %}{# if client streaming #}
 | 
				
			||||||
 | 
					        {% else %}{# i.e. not server streaming #}
 | 
				
			||||||
 | 
					            {% if method.client_streaming %}
 | 
				
			||||||
 | 
					        return await self._stream_unary(
 | 
				
			||||||
 | 
					            "{{ method.route }}",
 | 
				
			||||||
 | 
					            request_iterator,
 | 
				
			||||||
 | 
					            {{ method.input }},
 | 
				
			||||||
 | 
					            {{ method.output }}
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					            {% else %}{# i.e. not client streaming #}
 | 
				
			||||||
        return await self._unary_unary(
 | 
					        return await self._unary_unary(
 | 
				
			||||||
            "{{ method.route }}",
 | 
					            "{{ method.route }}",
 | 
				
			||||||
            request,
 | 
					            request,
 | 
				
			||||||
            {{ method.output }},
 | 
					            {{ method.output }}
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					            {% endif %}{# client streaming #}
 | 
				
			||||||
        {% endif %}
 | 
					        {% endif %}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    {% endfor %}
 | 
					    {% endfor %}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,6 +1,7 @@
 | 
				
			|||||||
#!/usr/bin/env python
 | 
					#!/usr/bin/env python
 | 
				
			||||||
import glob
 | 
					import asyncio
 | 
				
			||||||
import os
 | 
					import os
 | 
				
			||||||
 | 
					from pathlib import Path
 | 
				
			||||||
import shutil
 | 
					import shutil
 | 
				
			||||||
import subprocess
 | 
					import subprocess
 | 
				
			||||||
import sys
 | 
					import sys
 | 
				
			||||||
@@ -20,58 +21,63 @@ from betterproto.tests.util import (
 | 
				
			|||||||
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
 | 
					os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def clear_directory(path: str):
 | 
					def clear_directory(dir_path: Path):
 | 
				
			||||||
    for file_or_directory in glob.glob(os.path.join(path, "*")):
 | 
					    for file_or_directory in dir_path.glob("*"):
 | 
				
			||||||
        if os.path.isdir(file_or_directory):
 | 
					        if file_or_directory.is_dir():
 | 
				
			||||||
            shutil.rmtree(file_or_directory)
 | 
					            shutil.rmtree(file_or_directory)
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            os.remove(file_or_directory)
 | 
					            file_or_directory.unlink()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def generate(whitelist: Set[str]):
 | 
					async def generate(whitelist: Set[str], verbose: bool):
 | 
				
			||||||
    path_whitelist = {os.path.realpath(e) for e in whitelist if os.path.exists(e)}
 | 
					    test_case_names = set(get_directories(inputs_path)) - {"__pycache__"}
 | 
				
			||||||
    name_whitelist = {e for e in whitelist if not os.path.exists(e)}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    test_case_names = set(get_directories(inputs_path))
 | 
					    path_whitelist = set()
 | 
				
			||||||
 | 
					    name_whitelist = set()
 | 
				
			||||||
    failed_test_cases = []
 | 
					    for item in whitelist:
 | 
				
			||||||
 | 
					        if item in test_case_names:
 | 
				
			||||||
 | 
					            name_whitelist.add(item)
 | 
				
			||||||
 | 
					            continue
 | 
				
			||||||
 | 
					        path_whitelist.add(item)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    generation_tasks = []
 | 
				
			||||||
    for test_case_name in sorted(test_case_names):
 | 
					    for test_case_name in sorted(test_case_names):
 | 
				
			||||||
        test_case_input_path = os.path.realpath(
 | 
					        test_case_input_path = inputs_path.joinpath(test_case_name).resolve()
 | 
				
			||||||
            os.path.join(inputs_path, test_case_name)
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if (
 | 
					        if (
 | 
				
			||||||
            whitelist
 | 
					            whitelist
 | 
				
			||||||
            and test_case_input_path not in path_whitelist
 | 
					            and str(test_case_input_path) not in path_whitelist
 | 
				
			||||||
            and test_case_name not in name_whitelist
 | 
					            and test_case_name not in name_whitelist
 | 
				
			||||||
        ):
 | 
					        ):
 | 
				
			||||||
            continue
 | 
					            continue
 | 
				
			||||||
 | 
					        generation_tasks.append(
 | 
				
			||||||
 | 
					            generate_test_case_output(test_case_input_path, test_case_name, verbose)
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        print(f"Generating output for {test_case_name}")
 | 
					    failed_test_cases = []
 | 
				
			||||||
        try:
 | 
					    # Wait for all subprocs and match any failures to names to report
 | 
				
			||||||
            generate_test_case_output(test_case_name, test_case_input_path)
 | 
					    for test_case_name, result in zip(
 | 
				
			||||||
        except subprocess.CalledProcessError as e:
 | 
					        sorted(test_case_names), await asyncio.gather(*generation_tasks)
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        if result != 0:
 | 
				
			||||||
            failed_test_cases.append(test_case_name)
 | 
					            failed_test_cases.append(test_case_name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if failed_test_cases:
 | 
					    if failed_test_cases:
 | 
				
			||||||
        sys.stderr.write("\nFailed to generate the following test cases:\n")
 | 
					        sys.stderr.write(
 | 
				
			||||||
 | 
					            "\n\033[31;1;4mFailed to generate the following test cases:\033[0m\n"
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
        for failed_test_case in failed_test_cases:
 | 
					        for failed_test_case in failed_test_cases:
 | 
				
			||||||
            sys.stderr.write(f"- {failed_test_case}\n")
 | 
					            sys.stderr.write(f"- {failed_test_case}\n")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def generate_test_case_output(test_case_name, test_case_input_path=None):
 | 
					async def generate_test_case_output(
 | 
				
			||||||
    if not test_case_input_path:
 | 
					    test_case_input_path: Path, test_case_name: str, verbose: bool
 | 
				
			||||||
        test_case_input_path = os.path.realpath(
 | 
					) -> int:
 | 
				
			||||||
            os.path.join(inputs_path, test_case_name)
 | 
					    """
 | 
				
			||||||
        )
 | 
					    Returns the max of the subprocess return values
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    test_case_output_path_reference = os.path.join(
 | 
					    test_case_output_path_reference = output_path_reference.joinpath(test_case_name)
 | 
				
			||||||
        output_path_reference, test_case_name
 | 
					    test_case_output_path_betterproto = output_path_betterproto.joinpath(test_case_name)
 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
    test_case_output_path_betterproto = os.path.join(
 | 
					 | 
				
			||||||
        output_path_betterproto, test_case_name
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    os.makedirs(test_case_output_path_reference, exist_ok=True)
 | 
					    os.makedirs(test_case_output_path_reference, exist_ok=True)
 | 
				
			||||||
    os.makedirs(test_case_output_path_betterproto, exist_ok=True)
 | 
					    os.makedirs(test_case_output_path_betterproto, exist_ok=True)
 | 
				
			||||||
@@ -79,14 +85,36 @@ def generate_test_case_output(test_case_name, test_case_input_path=None):
 | 
				
			|||||||
    clear_directory(test_case_output_path_reference)
 | 
					    clear_directory(test_case_output_path_reference)
 | 
				
			||||||
    clear_directory(test_case_output_path_betterproto)
 | 
					    clear_directory(test_case_output_path_betterproto)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    protoc_reference(test_case_input_path, test_case_output_path_reference)
 | 
					    (
 | 
				
			||||||
    protoc_plugin(test_case_input_path, test_case_output_path_betterproto)
 | 
					        (ref_out, ref_err, ref_code),
 | 
				
			||||||
 | 
					        (plg_out, plg_err, plg_code),
 | 
				
			||||||
 | 
					    ) = await asyncio.gather(
 | 
				
			||||||
 | 
					        protoc_reference(test_case_input_path, test_case_output_path_reference),
 | 
				
			||||||
 | 
					        protoc_plugin(test_case_input_path, test_case_output_path_betterproto),
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    message = f"Generated output for {test_case_name!r}"
 | 
				
			||||||
 | 
					    if verbose:
 | 
				
			||||||
 | 
					        print(f"\033[31;1;4m{message}\033[0m")
 | 
				
			||||||
 | 
					        if ref_out:
 | 
				
			||||||
 | 
					            sys.stdout.buffer.write(ref_out)
 | 
				
			||||||
 | 
					        if ref_err:
 | 
				
			||||||
 | 
					            sys.stderr.buffer.write(ref_err)
 | 
				
			||||||
 | 
					        if plg_out:
 | 
				
			||||||
 | 
					            sys.stdout.buffer.write(plg_out)
 | 
				
			||||||
 | 
					        if plg_err:
 | 
				
			||||||
 | 
					            sys.stderr.buffer.write(plg_err)
 | 
				
			||||||
 | 
					        sys.stdout.buffer.flush()
 | 
				
			||||||
 | 
					        sys.stderr.buffer.flush()
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        print(message)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return max(ref_code, plg_code)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
HELP = "\n".join(
 | 
					HELP = "\n".join(
 | 
				
			||||||
    [
 | 
					    (
 | 
				
			||||||
        "Usage: python generate.py",
 | 
					        "Usage: python generate.py [-h] [-v] [DIRECTORIES or NAMES]",
 | 
				
			||||||
        "       python generate.py [DIRECTORIES or NAMES]",
 | 
					 | 
				
			||||||
        "Generate python classes for standard tests.",
 | 
					        "Generate python classes for standard tests.",
 | 
				
			||||||
        "",
 | 
					        "",
 | 
				
			||||||
        "DIRECTORIES    One or more relative or absolute directories of test-cases to generate classes for.",
 | 
					        "DIRECTORIES    One or more relative or absolute directories of test-cases to generate classes for.",
 | 
				
			||||||
@@ -94,7 +122,7 @@ HELP = "\n".join(
 | 
				
			|||||||
        "",
 | 
					        "",
 | 
				
			||||||
        "NAMES          One or more test-case names to generate classes for.",
 | 
					        "NAMES          One or more test-case names to generate classes for.",
 | 
				
			||||||
        "               python generate.py bool double enums",
 | 
					        "               python generate.py bool double enums",
 | 
				
			||||||
    ]
 | 
					    )
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -102,9 +130,13 @@ def main():
 | 
				
			|||||||
    if set(sys.argv).intersection({"-h", "--help"}):
 | 
					    if set(sys.argv).intersection({"-h", "--help"}):
 | 
				
			||||||
        print(HELP)
 | 
					        print(HELP)
 | 
				
			||||||
        return
 | 
					        return
 | 
				
			||||||
    whitelist = set(sys.argv[1:])
 | 
					    if sys.argv[1:2] == ["-v"]:
 | 
				
			||||||
 | 
					        verbose = True
 | 
				
			||||||
    generate(whitelist)
 | 
					        whitelist = set(sys.argv[2:])
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        verbose = False
 | 
				
			||||||
 | 
					        whitelist = set(sys.argv[1:])
 | 
				
			||||||
 | 
					    asyncio.get_event_loop().run_until_complete(generate(whitelist, verbose))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										0
									
								
								betterproto/tests/grpc/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								betterproto/tests/grpc/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										154
									
								
								betterproto/tests/grpc/test_grpclib_client.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										154
									
								
								betterproto/tests/grpc/test_grpclib_client.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,154 @@
 | 
				
			|||||||
 | 
					import asyncio
 | 
				
			||||||
 | 
					from betterproto.tests.output_betterproto.service.service import (
 | 
				
			||||||
 | 
					    DoThingResponse,
 | 
				
			||||||
 | 
					    DoThingRequest,
 | 
				
			||||||
 | 
					    GetThingRequest,
 | 
				
			||||||
 | 
					    GetThingResponse,
 | 
				
			||||||
 | 
					    TestStub as ThingServiceClient,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					import grpclib
 | 
				
			||||||
 | 
					from grpclib.testing import ChannelFor
 | 
				
			||||||
 | 
					import pytest
 | 
				
			||||||
 | 
					from betterproto.grpc.util.async_channel import AsyncChannel
 | 
				
			||||||
 | 
					from .thing_service import ThingService
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					async def _test_client(client, name="clean room", **kwargs):
 | 
				
			||||||
 | 
					    response = await client.do_thing(name=name)
 | 
				
			||||||
 | 
					    assert response.names == [name]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def _assert_request_meta_recieved(deadline, metadata):
 | 
				
			||||||
 | 
					    def server_side_test(stream):
 | 
				
			||||||
 | 
					        assert stream.deadline._timestamp == pytest.approx(
 | 
				
			||||||
 | 
					            deadline._timestamp, 1
 | 
				
			||||||
 | 
					        ), "The provided deadline should be recieved serverside"
 | 
				
			||||||
 | 
					        assert (
 | 
				
			||||||
 | 
					            stream.metadata["authorization"] == metadata["authorization"]
 | 
				
			||||||
 | 
					        ), "The provided authorization metadata should be recieved serverside"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return server_side_test
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.asyncio
 | 
				
			||||||
 | 
					async def test_simple_service_call():
 | 
				
			||||||
 | 
					    async with ChannelFor([ThingService()]) as channel:
 | 
				
			||||||
 | 
					        await _test_client(ThingServiceClient(channel))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.asyncio
 | 
				
			||||||
 | 
					async def test_service_call_with_upfront_request_params():
 | 
				
			||||||
 | 
					    # Setting deadline
 | 
				
			||||||
 | 
					    deadline = grpclib.metadata.Deadline.from_timeout(22)
 | 
				
			||||||
 | 
					    metadata = {"authorization": "12345"}
 | 
				
			||||||
 | 
					    async with ChannelFor(
 | 
				
			||||||
 | 
					        [ThingService(test_hook=_assert_request_meta_recieved(deadline, metadata),)]
 | 
				
			||||||
 | 
					    ) as channel:
 | 
				
			||||||
 | 
					        await _test_client(
 | 
				
			||||||
 | 
					            ThingServiceClient(channel, deadline=deadline, metadata=metadata)
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Setting timeout
 | 
				
			||||||
 | 
					    timeout = 99
 | 
				
			||||||
 | 
					    deadline = grpclib.metadata.Deadline.from_timeout(timeout)
 | 
				
			||||||
 | 
					    metadata = {"authorization": "12345"}
 | 
				
			||||||
 | 
					    async with ChannelFor(
 | 
				
			||||||
 | 
					        [ThingService(test_hook=_assert_request_meta_recieved(deadline, metadata),)]
 | 
				
			||||||
 | 
					    ) as channel:
 | 
				
			||||||
 | 
					        await _test_client(
 | 
				
			||||||
 | 
					            ThingServiceClient(channel, timeout=timeout, metadata=metadata)
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.asyncio
 | 
				
			||||||
 | 
					async def test_service_call_lower_level_with_overrides():
 | 
				
			||||||
 | 
					    THING_TO_DO = "get milk"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Setting deadline
 | 
				
			||||||
 | 
					    deadline = grpclib.metadata.Deadline.from_timeout(22)
 | 
				
			||||||
 | 
					    metadata = {"authorization": "12345"}
 | 
				
			||||||
 | 
					    kwarg_deadline = grpclib.metadata.Deadline.from_timeout(28)
 | 
				
			||||||
 | 
					    kwarg_metadata = {"authorization": "12345"}
 | 
				
			||||||
 | 
					    async with ChannelFor(
 | 
				
			||||||
 | 
					        [ThingService(test_hook=_assert_request_meta_recieved(deadline, metadata),)]
 | 
				
			||||||
 | 
					    ) as channel:
 | 
				
			||||||
 | 
					        client = ThingServiceClient(channel, deadline=deadline, metadata=metadata)
 | 
				
			||||||
 | 
					        response = await client._unary_unary(
 | 
				
			||||||
 | 
					            "/service.Test/DoThing",
 | 
				
			||||||
 | 
					            DoThingRequest(THING_TO_DO),
 | 
				
			||||||
 | 
					            DoThingResponse,
 | 
				
			||||||
 | 
					            deadline=kwarg_deadline,
 | 
				
			||||||
 | 
					            metadata=kwarg_metadata,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        assert response.names == [THING_TO_DO]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Setting timeout
 | 
				
			||||||
 | 
					    timeout = 99
 | 
				
			||||||
 | 
					    deadline = grpclib.metadata.Deadline.from_timeout(timeout)
 | 
				
			||||||
 | 
					    metadata = {"authorization": "12345"}
 | 
				
			||||||
 | 
					    kwarg_timeout = 9000
 | 
				
			||||||
 | 
					    kwarg_deadline = grpclib.metadata.Deadline.from_timeout(kwarg_timeout)
 | 
				
			||||||
 | 
					    kwarg_metadata = {"authorization": "09876"}
 | 
				
			||||||
 | 
					    async with ChannelFor(
 | 
				
			||||||
 | 
					        [
 | 
				
			||||||
 | 
					            ThingService(
 | 
				
			||||||
 | 
					                test_hook=_assert_request_meta_recieved(kwarg_deadline, kwarg_metadata),
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					        ]
 | 
				
			||||||
 | 
					    ) as channel:
 | 
				
			||||||
 | 
					        client = ThingServiceClient(channel, deadline=deadline, metadata=metadata)
 | 
				
			||||||
 | 
					        response = await client._unary_unary(
 | 
				
			||||||
 | 
					            "/service.Test/DoThing",
 | 
				
			||||||
 | 
					            DoThingRequest(THING_TO_DO),
 | 
				
			||||||
 | 
					            DoThingResponse,
 | 
				
			||||||
 | 
					            timeout=kwarg_timeout,
 | 
				
			||||||
 | 
					            metadata=kwarg_metadata,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        assert response.names == [THING_TO_DO]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.asyncio
 | 
				
			||||||
 | 
					async def test_async_gen_for_unary_stream_request():
 | 
				
			||||||
 | 
					    thing_name = "my milkshakes"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async with ChannelFor([ThingService()]) as channel:
 | 
				
			||||||
 | 
					        client = ThingServiceClient(channel)
 | 
				
			||||||
 | 
					        expected_versions = [5, 4, 3, 2, 1]
 | 
				
			||||||
 | 
					        async for response in client.get_thing_versions(name=thing_name):
 | 
				
			||||||
 | 
					            assert response.name == thing_name
 | 
				
			||||||
 | 
					            assert response.version == expected_versions.pop()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.asyncio
 | 
				
			||||||
 | 
					async def test_async_gen_for_stream_stream_request():
 | 
				
			||||||
 | 
					    some_things = ["cake", "cricket", "coral reef"]
 | 
				
			||||||
 | 
					    more_things = ["ball", "that", "56kmodem", "liberal humanism", "cheesesticks"]
 | 
				
			||||||
 | 
					    expected_things = (*some_things, *more_things)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async with ChannelFor([ThingService()]) as channel:
 | 
				
			||||||
 | 
					        client = ThingServiceClient(channel)
 | 
				
			||||||
 | 
					        # Use an AsyncChannel to decouple sending and recieving, it'll send some_things
 | 
				
			||||||
 | 
					        # immediately and we'll use it to send more_things later, after recieving some
 | 
				
			||||||
 | 
					        # results
 | 
				
			||||||
 | 
					        request_chan = AsyncChannel()
 | 
				
			||||||
 | 
					        send_initial_requests = asyncio.ensure_future(
 | 
				
			||||||
 | 
					            request_chan.send_from(GetThingRequest(name) for name in some_things)
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        response_index = 0
 | 
				
			||||||
 | 
					        async for response in client.get_different_things(request_chan):
 | 
				
			||||||
 | 
					            assert response.name == expected_things[response_index]
 | 
				
			||||||
 | 
					            assert response.version == response_index + 1
 | 
				
			||||||
 | 
					            response_index += 1
 | 
				
			||||||
 | 
					            if more_things:
 | 
				
			||||||
 | 
					                # Send some more requests as we recieve reponses to be sure coordination of
 | 
				
			||||||
 | 
					                # send/recieve events doesn't matter
 | 
				
			||||||
 | 
					                await request_chan.send(GetThingRequest(more_things.pop(0)))
 | 
				
			||||||
 | 
					            elif not send_initial_requests.done():
 | 
				
			||||||
 | 
					                # Make sure the sending task it completed
 | 
				
			||||||
 | 
					                await send_initial_requests
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                # No more things to send make sure channel is closed
 | 
				
			||||||
 | 
					                request_chan.close()
 | 
				
			||||||
 | 
					        assert response_index == len(
 | 
				
			||||||
 | 
					            expected_things
 | 
				
			||||||
 | 
					        ), "Didn't recieve all exptected responses"
 | 
				
			||||||
							
								
								
									
										100
									
								
								betterproto/tests/grpc/test_stream_stream.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										100
									
								
								betterproto/tests/grpc/test_stream_stream.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,100 @@
 | 
				
			|||||||
 | 
					import asyncio
 | 
				
			||||||
 | 
					import betterproto
 | 
				
			||||||
 | 
					from betterproto.grpc.util.async_channel import AsyncChannel
 | 
				
			||||||
 | 
					from dataclasses import dataclass
 | 
				
			||||||
 | 
					import pytest
 | 
				
			||||||
 | 
					from typing import AsyncIterator
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@dataclass
 | 
				
			||||||
 | 
					class Message(betterproto.Message):
 | 
				
			||||||
 | 
					    body: str = betterproto.string_field(1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.fixture
 | 
				
			||||||
 | 
					def expected_responses():
 | 
				
			||||||
 | 
					    return [Message("Hello world 1"), Message("Hello world 2"), Message("Done")]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class ClientStub:
 | 
				
			||||||
 | 
					    async def connect(self, requests: AsyncIterator):
 | 
				
			||||||
 | 
					        await asyncio.sleep(0.1)
 | 
				
			||||||
 | 
					        async for request in requests:
 | 
				
			||||||
 | 
					            await asyncio.sleep(0.1)
 | 
				
			||||||
 | 
					            yield request
 | 
				
			||||||
 | 
					        await asyncio.sleep(0.1)
 | 
				
			||||||
 | 
					        yield Message("Done")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					async def to_list(generator: AsyncIterator):
 | 
				
			||||||
 | 
					    result = []
 | 
				
			||||||
 | 
					    async for value in generator:
 | 
				
			||||||
 | 
					        result.append(value)
 | 
				
			||||||
 | 
					    return result
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.fixture
 | 
				
			||||||
 | 
					def client():
 | 
				
			||||||
 | 
					    # channel = Channel(host='127.0.0.1', port=50051)
 | 
				
			||||||
 | 
					    # return ClientStub(channel)
 | 
				
			||||||
 | 
					    return ClientStub()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.asyncio
 | 
				
			||||||
 | 
					async def test_send_from_before_connect_and_close_automatically(
 | 
				
			||||||
 | 
					    client, expected_responses
 | 
				
			||||||
 | 
					):
 | 
				
			||||||
 | 
					    requests = AsyncChannel()
 | 
				
			||||||
 | 
					    await requests.send_from(
 | 
				
			||||||
 | 
					        [Message(body="Hello world 1"), Message(body="Hello world 2")], close=True
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    responses = client.connect(requests)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert await to_list(responses) == expected_responses
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.asyncio
 | 
				
			||||||
 | 
					async def test_send_from_after_connect_and_close_automatically(
 | 
				
			||||||
 | 
					    client, expected_responses
 | 
				
			||||||
 | 
					):
 | 
				
			||||||
 | 
					    requests = AsyncChannel()
 | 
				
			||||||
 | 
					    responses = client.connect(requests)
 | 
				
			||||||
 | 
					    await requests.send_from(
 | 
				
			||||||
 | 
					        [Message(body="Hello world 1"), Message(body="Hello world 2")], close=True
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert await to_list(responses) == expected_responses
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.asyncio
 | 
				
			||||||
 | 
					async def test_send_from_close_manually_immediately(client, expected_responses):
 | 
				
			||||||
 | 
					    requests = AsyncChannel()
 | 
				
			||||||
 | 
					    responses = client.connect(requests)
 | 
				
			||||||
 | 
					    await requests.send_from(
 | 
				
			||||||
 | 
					        [Message(body="Hello world 1"), Message(body="Hello world 2")], close=False
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    requests.close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert await to_list(responses) == expected_responses
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.asyncio
 | 
				
			||||||
 | 
					async def test_send_individually_and_close_before_connect(client, expected_responses):
 | 
				
			||||||
 | 
					    requests = AsyncChannel()
 | 
				
			||||||
 | 
					    await requests.send(Message(body="Hello world 1"))
 | 
				
			||||||
 | 
					    await requests.send(Message(body="Hello world 2"))
 | 
				
			||||||
 | 
					    requests.close()
 | 
				
			||||||
 | 
					    responses = client.connect(requests)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert await to_list(responses) == expected_responses
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.asyncio
 | 
				
			||||||
 | 
					async def test_send_individually_and_close_after_connect(client, expected_responses):
 | 
				
			||||||
 | 
					    requests = AsyncChannel()
 | 
				
			||||||
 | 
					    await requests.send(Message(body="Hello world 1"))
 | 
				
			||||||
 | 
					    await requests.send(Message(body="Hello world 2"))
 | 
				
			||||||
 | 
					    responses = client.connect(requests)
 | 
				
			||||||
 | 
					    requests.close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert await to_list(responses) == expected_responses
 | 
				
			||||||
							
								
								
									
										83
									
								
								betterproto/tests/grpc/thing_service.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										83
									
								
								betterproto/tests/grpc/thing_service.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,83 @@
 | 
				
			|||||||
 | 
					from betterproto.tests.output_betterproto.service.service import (
 | 
				
			||||||
 | 
					    DoThingResponse,
 | 
				
			||||||
 | 
					    DoThingRequest,
 | 
				
			||||||
 | 
					    GetThingRequest,
 | 
				
			||||||
 | 
					    GetThingResponse,
 | 
				
			||||||
 | 
					    TestStub as ThingServiceClient,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					import grpclib
 | 
				
			||||||
 | 
					from typing import Any, Dict
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class ThingService:
 | 
				
			||||||
 | 
					    def __init__(self, test_hook=None):
 | 
				
			||||||
 | 
					        # This lets us pass assertions to the servicer ;)
 | 
				
			||||||
 | 
					        self.test_hook = test_hook
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def do_thing(
 | 
				
			||||||
 | 
					        self, stream: "grpclib.server.Stream[DoThingRequest, DoThingResponse]"
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        request = await stream.recv_message()
 | 
				
			||||||
 | 
					        if self.test_hook is not None:
 | 
				
			||||||
 | 
					            self.test_hook(stream)
 | 
				
			||||||
 | 
					        await stream.send_message(DoThingResponse([request.name]))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def do_many_things(
 | 
				
			||||||
 | 
					        self, stream: "grpclib.server.Stream[DoThingRequest, DoThingResponse]"
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        thing_names = [request.name for request in stream]
 | 
				
			||||||
 | 
					        if self.test_hook is not None:
 | 
				
			||||||
 | 
					            self.test_hook(stream)
 | 
				
			||||||
 | 
					        await stream.send_message(DoThingResponse(thing_names))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def get_thing_versions(
 | 
				
			||||||
 | 
					        self, stream: "grpclib.server.Stream[GetThingRequest, GetThingResponse]"
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        request = await stream.recv_message()
 | 
				
			||||||
 | 
					        if self.test_hook is not None:
 | 
				
			||||||
 | 
					            self.test_hook(stream)
 | 
				
			||||||
 | 
					        for version_num in range(1, 6):
 | 
				
			||||||
 | 
					            await stream.send_message(
 | 
				
			||||||
 | 
					                GetThingResponse(name=request.name, version=version_num)
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def get_different_things(
 | 
				
			||||||
 | 
					        self, stream: "grpclib.server.Stream[GetThingRequest, GetThingResponse]"
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        if self.test_hook is not None:
 | 
				
			||||||
 | 
					            self.test_hook(stream)
 | 
				
			||||||
 | 
					        #  Respond to each input item immediately
 | 
				
			||||||
 | 
					        response_num = 0
 | 
				
			||||||
 | 
					        async for request in stream:
 | 
				
			||||||
 | 
					            response_num += 1
 | 
				
			||||||
 | 
					            await stream.send_message(
 | 
				
			||||||
 | 
					                GetThingResponse(name=request.name, version=response_num)
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __mapping__(self) -> Dict[str, "grpclib.const.Handler"]:
 | 
				
			||||||
 | 
					        return {
 | 
				
			||||||
 | 
					            "/service.Test/DoThing": grpclib.const.Handler(
 | 
				
			||||||
 | 
					                self.do_thing,
 | 
				
			||||||
 | 
					                grpclib.const.Cardinality.UNARY_UNARY,
 | 
				
			||||||
 | 
					                DoThingRequest,
 | 
				
			||||||
 | 
					                DoThingResponse,
 | 
				
			||||||
 | 
					            ),
 | 
				
			||||||
 | 
					            "/service.Test/DoManyThings": grpclib.const.Handler(
 | 
				
			||||||
 | 
					                self.do_many_things,
 | 
				
			||||||
 | 
					                grpclib.const.Cardinality.STREAM_UNARY,
 | 
				
			||||||
 | 
					                DoThingRequest,
 | 
				
			||||||
 | 
					                DoThingResponse,
 | 
				
			||||||
 | 
					            ),
 | 
				
			||||||
 | 
					            "/service.Test/GetThingVersions": grpclib.const.Handler(
 | 
				
			||||||
 | 
					                self.get_thing_versions,
 | 
				
			||||||
 | 
					                grpclib.const.Cardinality.UNARY_STREAM,
 | 
				
			||||||
 | 
					                GetThingRequest,
 | 
				
			||||||
 | 
					                GetThingResponse,
 | 
				
			||||||
 | 
					            ),
 | 
				
			||||||
 | 
					            "/service.Test/GetDifferentThings": grpclib.const.Handler(
 | 
				
			||||||
 | 
					                self.get_different_things,
 | 
				
			||||||
 | 
					                grpclib.const.Cardinality.STREAM_STREAM,
 | 
				
			||||||
 | 
					                GetThingRequest,
 | 
				
			||||||
 | 
					                GetThingResponse,
 | 
				
			||||||
 | 
					            ),
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
@@ -23,7 +23,7 @@ test_cases = [
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
@pytest.mark.asyncio
 | 
					@pytest.mark.asyncio
 | 
				
			||||||
@pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases)
 | 
					@pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases)
 | 
				
			||||||
async def test_channel_receives_wrapped_type(
 | 
					async def test_channel_recieves_wrapped_type(
 | 
				
			||||||
    service_method: Callable[[TestStub], Any], wrapper_class: Callable, value
 | 
					    service_method: Callable[[TestStub], Any], wrapper_class: Callable, value
 | 
				
			||||||
):
 | 
					):
 | 
				
			||||||
    wrapped_value = wrapper_class()
 | 
					    wrapped_value = wrapper_class()
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -3,13 +3,25 @@ syntax = "proto3";
 | 
				
			|||||||
package service;
 | 
					package service;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
message DoThingRequest {
 | 
					message DoThingRequest {
 | 
				
			||||||
  int32 iterations = 1;
 | 
					  string name = 1;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
message DoThingResponse {
 | 
					message DoThingResponse {
 | 
				
			||||||
  int32 successfulIterations = 1;
 | 
					  repeated string names = 1;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					message GetThingRequest {
 | 
				
			||||||
 | 
					  string name = 1;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					message GetThingResponse {
 | 
				
			||||||
 | 
					  string name = 1;
 | 
				
			||||||
 | 
					  int32 version = 2;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
service Test {
 | 
					service Test {
 | 
				
			||||||
  rpc DoThing (DoThingRequest) returns (DoThingResponse);
 | 
					  rpc DoThing (DoThingRequest) returns (DoThingResponse);
 | 
				
			||||||
 | 
					  rpc DoManyThings (stream DoThingRequest) returns (DoThingResponse);
 | 
				
			||||||
 | 
					  rpc GetThingVersions (GetThingRequest) returns (stream GetThingResponse);
 | 
				
			||||||
 | 
					  rpc GetDifferentThings (stream GetThingRequest) returns (stream GetThingResponse);
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,132 +0,0 @@
 | 
				
			|||||||
import betterproto
 | 
					 | 
				
			||||||
import grpclib
 | 
					 | 
				
			||||||
from grpclib.testing import ChannelFor
 | 
					 | 
				
			||||||
import pytest
 | 
					 | 
				
			||||||
from typing import Dict
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from betterproto.tests.output_betterproto.service.service import (
 | 
					 | 
				
			||||||
    DoThingResponse,
 | 
					 | 
				
			||||||
    DoThingRequest,
 | 
					 | 
				
			||||||
    TestStub as ExampleServiceStub,
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class ExampleService:
 | 
					 | 
				
			||||||
    def __init__(self, test_hook=None):
 | 
					 | 
				
			||||||
        # This lets us pass assertions to the servicer ;)
 | 
					 | 
				
			||||||
        self.test_hook = test_hook
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    async def DoThing(
 | 
					 | 
				
			||||||
        self, stream: "grpclib.server.Stream[DoThingRequest, DoThingResponse]"
 | 
					 | 
				
			||||||
    ):
 | 
					 | 
				
			||||||
        request = await stream.recv_message()
 | 
					 | 
				
			||||||
        print("self.test_hook", self.test_hook)
 | 
					 | 
				
			||||||
        if self.test_hook is not None:
 | 
					 | 
				
			||||||
            self.test_hook(stream)
 | 
					 | 
				
			||||||
        for iteration in range(request.iterations):
 | 
					 | 
				
			||||||
            pass
 | 
					 | 
				
			||||||
        await stream.send_message(DoThingResponse(request.iterations))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __mapping__(self) -> Dict[str, grpclib.const.Handler]:
 | 
					 | 
				
			||||||
        return {
 | 
					 | 
				
			||||||
            "/service.Test/DoThing": grpclib.const.Handler(
 | 
					 | 
				
			||||||
                self.DoThing,
 | 
					 | 
				
			||||||
                grpclib.const.Cardinality.UNARY_UNARY,
 | 
					 | 
				
			||||||
                DoThingRequest,
 | 
					 | 
				
			||||||
                DoThingResponse,
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
async def _test_stub(stub, iterations=42, **kwargs):
 | 
					 | 
				
			||||||
    response = await stub.do_thing(iterations=iterations)
 | 
					 | 
				
			||||||
    assert response.successful_iterations == iterations
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def _get_server_side_test(deadline, metadata):
 | 
					 | 
				
			||||||
    def server_side_test(stream):
 | 
					 | 
				
			||||||
        assert stream.deadline._timestamp == pytest.approx(
 | 
					 | 
				
			||||||
            deadline._timestamp, 1
 | 
					 | 
				
			||||||
        ), "The provided deadline should be recieved serverside"
 | 
					 | 
				
			||||||
        assert (
 | 
					 | 
				
			||||||
            stream.metadata["authorization"] == metadata["authorization"]
 | 
					 | 
				
			||||||
        ), "The provided authorization metadata should be recieved serverside"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    return server_side_test
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
@pytest.mark.asyncio
 | 
					 | 
				
			||||||
async def test_simple_service_call():
 | 
					 | 
				
			||||||
    async with ChannelFor([ExampleService()]) as channel:
 | 
					 | 
				
			||||||
        await _test_stub(ExampleServiceStub(channel))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
@pytest.mark.asyncio
 | 
					 | 
				
			||||||
async def test_service_call_with_upfront_request_params():
 | 
					 | 
				
			||||||
    # Setting deadline
 | 
					 | 
				
			||||||
    deadline = grpclib.metadata.Deadline.from_timeout(22)
 | 
					 | 
				
			||||||
    metadata = {"authorization": "12345"}
 | 
					 | 
				
			||||||
    async with ChannelFor(
 | 
					 | 
				
			||||||
        [ExampleService(test_hook=_get_server_side_test(deadline, metadata))]
 | 
					 | 
				
			||||||
    ) as channel:
 | 
					 | 
				
			||||||
        await _test_stub(
 | 
					 | 
				
			||||||
            ExampleServiceStub(channel, deadline=deadline, metadata=metadata)
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Setting timeout
 | 
					 | 
				
			||||||
    timeout = 99
 | 
					 | 
				
			||||||
    deadline = grpclib.metadata.Deadline.from_timeout(timeout)
 | 
					 | 
				
			||||||
    metadata = {"authorization": "12345"}
 | 
					 | 
				
			||||||
    async with ChannelFor(
 | 
					 | 
				
			||||||
        [ExampleService(test_hook=_get_server_side_test(deadline, metadata))]
 | 
					 | 
				
			||||||
    ) as channel:
 | 
					 | 
				
			||||||
        await _test_stub(
 | 
					 | 
				
			||||||
            ExampleServiceStub(channel, timeout=timeout, metadata=metadata)
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
@pytest.mark.asyncio
 | 
					 | 
				
			||||||
async def test_service_call_lower_level_with_overrides():
 | 
					 | 
				
			||||||
    ITERATIONS = 99
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Setting deadline
 | 
					 | 
				
			||||||
    deadline = grpclib.metadata.Deadline.from_timeout(22)
 | 
					 | 
				
			||||||
    metadata = {"authorization": "12345"}
 | 
					 | 
				
			||||||
    kwarg_deadline = grpclib.metadata.Deadline.from_timeout(28)
 | 
					 | 
				
			||||||
    kwarg_metadata = {"authorization": "12345"}
 | 
					 | 
				
			||||||
    async with ChannelFor(
 | 
					 | 
				
			||||||
        [ExampleService(test_hook=_get_server_side_test(deadline, metadata))]
 | 
					 | 
				
			||||||
    ) as channel:
 | 
					 | 
				
			||||||
        stub = ExampleServiceStub(channel, deadline=deadline, metadata=metadata)
 | 
					 | 
				
			||||||
        response = await stub._unary_unary(
 | 
					 | 
				
			||||||
            "/service.Test/DoThing",
 | 
					 | 
				
			||||||
            DoThingRequest(ITERATIONS),
 | 
					 | 
				
			||||||
            DoThingResponse,
 | 
					 | 
				
			||||||
            deadline=kwarg_deadline,
 | 
					 | 
				
			||||||
            metadata=kwarg_metadata,
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        assert response.successful_iterations == ITERATIONS
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Setting timeout
 | 
					 | 
				
			||||||
    timeout = 99
 | 
					 | 
				
			||||||
    deadline = grpclib.metadata.Deadline.from_timeout(timeout)
 | 
					 | 
				
			||||||
    metadata = {"authorization": "12345"}
 | 
					 | 
				
			||||||
    kwarg_timeout = 9000
 | 
					 | 
				
			||||||
    kwarg_deadline = grpclib.metadata.Deadline.from_timeout(kwarg_timeout)
 | 
					 | 
				
			||||||
    kwarg_metadata = {"authorization": "09876"}
 | 
					 | 
				
			||||||
    async with ChannelFor(
 | 
					 | 
				
			||||||
        [
 | 
					 | 
				
			||||||
            ExampleService(
 | 
					 | 
				
			||||||
                test_hook=_get_server_side_test(kwarg_deadline, kwarg_metadata)
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
        ]
 | 
					 | 
				
			||||||
    ) as channel:
 | 
					 | 
				
			||||||
        stub = ExampleServiceStub(channel, deadline=deadline, metadata=metadata)
 | 
					 | 
				
			||||||
        response = await stub._unary_unary(
 | 
					 | 
				
			||||||
            "/service.Test/DoThing",
 | 
					 | 
				
			||||||
            DoThingRequest(ITERATIONS),
 | 
					 | 
				
			||||||
            DoThingResponse,
 | 
					 | 
				
			||||||
            timeout=kwarg_timeout,
 | 
					 | 
				
			||||||
            metadata=kwarg_metadata,
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        assert response.successful_iterations == ITERATIONS
 | 
					 | 
				
			||||||
@@ -23,7 +23,7 @@ from google.protobuf.json_format import Parse
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
class TestCases:
 | 
					class TestCases:
 | 
				
			||||||
    def __init__(self, path, services: Set[str], xfail: Set[str]):
 | 
					    def __init__(self, path, services: Set[str], xfail: Set[str]):
 | 
				
			||||||
        _all = set(get_directories(path))
 | 
					        _all = set(get_directories(path)) - {"__pycache__"}
 | 
				
			||||||
        _services = services
 | 
					        _services = services
 | 
				
			||||||
        _messages = (_all - services) - {"__pycache__"}
 | 
					        _messages = (_all - services) - {"__pycache__"}
 | 
				
			||||||
        _messages_with_json = {
 | 
					        _messages_with_json = {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,23 +1,24 @@
 | 
				
			|||||||
 | 
					import asyncio
 | 
				
			||||||
import os
 | 
					import os
 | 
				
			||||||
import subprocess
 | 
					from pathlib import Path
 | 
				
			||||||
from typing import Generator
 | 
					from typing import Generator, IO, Optional
 | 
				
			||||||
 | 
					
 | 
				
			||||||
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
 | 
					os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
root_path = os.path.dirname(os.path.realpath(__file__))
 | 
					root_path = Path(__file__).resolve().parent
 | 
				
			||||||
inputs_path = os.path.join(root_path, "inputs")
 | 
					inputs_path = root_path.joinpath("inputs")
 | 
				
			||||||
output_path_reference = os.path.join(root_path, "output_reference")
 | 
					output_path_reference = root_path.joinpath("output_reference")
 | 
				
			||||||
output_path_betterproto = os.path.join(root_path, "output_betterproto")
 | 
					output_path_betterproto = root_path.joinpath("output_betterproto")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if os.name == "nt":
 | 
					if os.name == "nt":
 | 
				
			||||||
    plugin_path = os.path.join(root_path, "..", "plugin.bat")
 | 
					    plugin_path = root_path.joinpath("..", "plugin.bat").resolve()
 | 
				
			||||||
else:
 | 
					else:
 | 
				
			||||||
    plugin_path = os.path.join(root_path, "..", "plugin.py")
 | 
					    plugin_path = root_path.joinpath("..", "plugin.py").resolve()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def get_files(path, end: str) -> Generator[str, None, None]:
 | 
					def get_files(path, suffix: str) -> Generator[str, None, None]:
 | 
				
			||||||
    for r, dirs, files in os.walk(path):
 | 
					    for r, dirs, files in os.walk(path):
 | 
				
			||||||
        for filename in [f for f in files if f.endswith(end)]:
 | 
					        for filename in [f for f in files if f.endswith(suffix)]:
 | 
				
			||||||
            yield os.path.join(r, filename)
 | 
					            yield os.path.join(r, filename)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -27,36 +28,30 @@ def get_directories(path):
 | 
				
			|||||||
            yield directory
 | 
					            yield directory
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def relative(file: str, path: str):
 | 
					async def protoc_plugin(path: str, output_dir: str):
 | 
				
			||||||
    return os.path.join(os.path.dirname(file), path)
 | 
					    proc = await asyncio.create_subprocess_shell(
 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def read_relative(file: str, path: str):
 | 
					 | 
				
			||||||
    with open(relative(file, path)) as fh:
 | 
					 | 
				
			||||||
        return fh.read()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def protoc_plugin(path: str, output_dir: str) -> subprocess.CompletedProcess:
 | 
					 | 
				
			||||||
    return subprocess.run(
 | 
					 | 
				
			||||||
        f"protoc --plugin=protoc-gen-custom={plugin_path} --custom_out={output_dir} --proto_path={path} {path}/*.proto",
 | 
					        f"protoc --plugin=protoc-gen-custom={plugin_path} --custom_out={output_dir} --proto_path={path} {path}/*.proto",
 | 
				
			||||||
        shell=True,
 | 
					        stdout=asyncio.subprocess.PIPE,
 | 
				
			||||||
        check=True,
 | 
					        stderr=asyncio.subprocess.PIPE,
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					    return (*(await proc.communicate()), proc.returncode)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def protoc_reference(path: str, output_dir: str):
 | 
					async def protoc_reference(path: str, output_dir: str):
 | 
				
			||||||
    subprocess.run(
 | 
					    proc = await asyncio.create_subprocess_shell(
 | 
				
			||||||
        f"protoc --python_out={output_dir} --proto_path={path} {path}/*.proto",
 | 
					        f"protoc --python_out={output_dir} --proto_path={path} {path}/*.proto",
 | 
				
			||||||
        shell=True,
 | 
					        stdout=asyncio.subprocess.PIPE,
 | 
				
			||||||
 | 
					        stderr=asyncio.subprocess.PIPE,
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					    return (*(await proc.communicate()), proc.returncode)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def get_test_case_json_data(test_case_name, json_file_name=None):
 | 
					def get_test_case_json_data(test_case_name: str, json_file_name: Optional[str] = None):
 | 
				
			||||||
    test_data_file_name = json_file_name if json_file_name else f"{test_case_name}.json"
 | 
					    test_data_file_name = json_file_name if json_file_name else f"{test_case_name}.json"
 | 
				
			||||||
    test_data_file_path = os.path.join(inputs_path, test_case_name, test_data_file_name)
 | 
					    test_data_file_path = inputs_path.joinpath(test_case_name, test_data_file_name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if not os.path.exists(test_data_file_path):
 | 
					    if not test_data_file_path.exists():
 | 
				
			||||||
        return None
 | 
					        return None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    with open(test_data_file_path) as fh:
 | 
					    with test_data_file_path.open("r") as fh:
 | 
				
			||||||
        return fh.read()
 | 
					        return fh.read()
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user