diff --git a/src/betterproto/plugin/models.py b/src/betterproto/plugin/models.py index c1eb9d7..6b1e006 100644 --- a/src/betterproto/plugin/models.py +++ b/src/betterproto/plugin/models.py @@ -238,7 +238,7 @@ class OutputTemplate: parent_request: PluginRequestCompiler package_proto_obj: FileDescriptorProto input_files: List[str] = field(default_factory=list) - imports: Set[str] = field(default_factory=set) + imports_end: Set[str] = field(default_factory=set) datetime_imports: Set[str] = field(default_factory=set) pydantic_imports: Set[str] = field(default_factory=set) builtins_import: bool = False @@ -532,7 +532,7 @@ class FieldCompiler(MessageCompiler): # Type referencing another defined Message or a named enum return get_type_reference( package=self.output_file.package, - imports=self.output_file.imports, + imports=self.output_file.imports_end, source_type=self.proto_obj.type_name, typing_compiler=self.typing_compiler, pydantic=self.output_file.pydantic_dataclasses, @@ -730,7 +730,7 @@ class ServiceMethodCompiler(ProtoContentBase): """ return get_type_reference( package=self.output_file.package, - imports=self.output_file.imports, + imports=self.output_file.imports_end, source_type=self.proto_obj.input_type, typing_compiler=self.output_file.typing_compiler, unwrap=False, @@ -760,7 +760,7 @@ class ServiceMethodCompiler(ProtoContentBase): """ return get_type_reference( package=self.output_file.package, - imports=self.output_file.imports, + imports=self.output_file.imports_end, source_type=self.proto_obj.output_type, typing_compiler=self.output_file.typing_compiler, unwrap=False, diff --git a/src/betterproto/templates/header.py.j2 b/src/betterproto/templates/header.py.j2 index e3bf10d..9c8dddd 100644 --- a/src/betterproto/templates/header.py.j2 +++ b/src/betterproto/templates/header.py.j2 @@ -8,7 +8,6 @@ import {{ i }} {% if output_file.pydantic_dataclasses %} from pydantic.dataclasses import dataclass -from pydantic.dataclasses import rebuild_dataclass {%- else -%} from dataclasses import dataclass {% endif %} @@ -35,10 +34,6 @@ from betterproto.grpc.grpclib_server import ServiceBase import grpclib {% endif %} -{% for i in output_file.imports|sort %} -{{ i }} -{% endfor %} - {% if output_file.imports_type_checking_only %} from typing import TYPE_CHECKING diff --git a/src/betterproto/templates/template.py.j2 b/src/betterproto/templates/template.py.j2 index e8ed3d8..4a252ae 100644 --- a/src/betterproto/templates/template.py.j2 +++ b/src/betterproto/templates/template.py.j2 @@ -77,14 +77,14 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub): , {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}" {%- else -%} {# Client streaming: need a request iterator instead #} - , {{ method.py_input_message_param }}_iterator: {{ output_file.typing_compiler.union(output_file.typing_compiler.async_iterable(method.py_input_message_type), output_file.typing_compiler.iterable(method.py_input_message_type)) }} + , {{ method.py_input_message_param }}_iterator: "{{ output_file.typing_compiler.union(output_file.typing_compiler.async_iterable(method.py_input_message_type), output_file.typing_compiler.iterable(method.py_input_message_type)) }}" {%- endif -%} , * , timeout: {{ output_file.typing_compiler.optional("float") }} = None , deadline: {{ output_file.typing_compiler.optional('"Deadline"') }} = None , metadata: {{ output_file.typing_compiler.optional('"MetadataLike"') }} = None - ) -> {% if method.server_streaming %}{{ output_file.typing_compiler.async_iterator(method.py_output_message_type ) }}{% else %}"{{ method.py_output_message_type }}"{% endif %}: + ) -> "{% if method.server_streaming %}{{ output_file.typing_compiler.async_iterator(method.py_output_message_type ) }}{% else %}{{ method.py_output_message_type }}{% endif %}": {% if method.comment %} {{ method.comment }} @@ -143,6 +143,10 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub): {% endfor %} {% endfor %} +{% for i in output_file.imports_end %} +{{ i }} +{% endfor %} + {% for service in output_file.services %} class {{ service.py_name }}Base(ServiceBase): {% if service.comment %} @@ -211,11 +215,3 @@ class {{ service.py_name }}Base(ServiceBase): } {% endfor %} - -{% if output_file.pydantic_dataclasses %} -{% for message in output_file.messages %} -{% if message.has_message_field %} -rebuild_dataclass({{ message.py_name }}) # type: ignore -{% endif %} -{% endfor %} -{% endif %} diff --git a/tests/inputs/import_circular_dependency/import_circular_dependency.proto b/tests/inputs/import_circular_dependency/import_circular_dependency.proto index 8b159e2..4441be9 100644 --- a/tests/inputs/import_circular_dependency/import_circular_dependency.proto +++ b/tests/inputs/import_circular_dependency/import_circular_dependency.proto @@ -26,5 +26,5 @@ import "other.proto"; // (root: Test & RootPackageMessage) <-------> (other: OtherPackageMessage) message Test { RootPackageMessage message = 1; - other.OtherPackageMessage other = 2; + other.OtherPackageMessage other_value = 2; }