Support nested messages, fix casing. Support test-cases in packages.

This commit is contained in:
boukeversteegh
2020-06-07 16:57:57 +02:00
parent d8abb850f8
commit f7c2fd1194
19 changed files with 333 additions and 163 deletions

View File

@@ -1,59 +1,72 @@
from functools import reduce
import re
from typing import Dict, List, Type
import stringcase
from betterproto import safe_snake_case
from betterproto.compile.naming import pythonize_class_name
from betterproto.lib.google import protobuf as google_protobuf
WRAPPER_TYPES: Dict[str, Type] = {
"google.protobuf.DoubleValue": google_protobuf.DoubleValue,
"google.protobuf.FloatValue": google_protobuf.FloatValue,
"google.protobuf.Int32Value": google_protobuf.Int32Value,
"google.protobuf.Int64Value": google_protobuf.Int64Value,
"google.protobuf.UInt32Value": google_protobuf.UInt32Value,
"google.protobuf.UInt64Value": google_protobuf.UInt64Value,
"google.protobuf.BoolValue": google_protobuf.BoolValue,
"google.protobuf.StringValue": google_protobuf.StringValue,
"google.protobuf.BytesValue": google_protobuf.BytesValue,
".google.protobuf.DoubleValue": google_protobuf.DoubleValue,
".google.protobuf.FloatValue": google_protobuf.FloatValue,
".google.protobuf.Int32Value": google_protobuf.Int32Value,
".google.protobuf.Int64Value": google_protobuf.Int64Value,
".google.protobuf.UInt32Value": google_protobuf.UInt32Value,
".google.protobuf.UInt64Value": google_protobuf.UInt64Value,
".google.protobuf.BoolValue": google_protobuf.BoolValue,
".google.protobuf.StringValue": google_protobuf.StringValue,
".google.protobuf.BytesValue": google_protobuf.BytesValue,
}
def parse_source_type_name(field_type_name):
"""
Split full source type name into package and type name.
E.g. 'root.package.Message' -> ('root.package', 'Message')
'root.Message.SomeEnum' -> ('root', 'Message.SomeEnum')
"""
package_match = re.match(r"^\.?([^A-Z]+)\.(.+)", field_type_name)
if package_match:
package = package_match.group(1)
name = package_match.group(2)
else:
package = ""
name = field_type_name.lstrip(".")
return package, name
def get_ref_type(
package: str, imports: set, type_name: str, unwrap: bool = True
package: str, imports: set, source_type: str, unwrap: bool = True,
) -> str:
"""
Return a Python type name for a proto type reference. Adds the import if
necessary. Unwraps well known type if required.
"""
# If the package name is a blank string, then this should still work
# because by convention packages are lowercase and message/enum types are
# pascal-cased. May require refactoring in the future.
type_name = type_name.lstrip(".")
is_wrapper = type_name in WRAPPER_TYPES
is_wrapper = source_type in WRAPPER_TYPES
if unwrap:
if is_wrapper:
wrapped_type = type(WRAPPER_TYPES[type_name]().value)
wrapped_type = type(WRAPPER_TYPES[source_type]().value)
return f"Optional[{wrapped_type.__name__}]"
if type_name == "google.protobuf.Duration":
if source_type == ".google.protobuf.Duration":
return "timedelta"
if type_name == "google.protobuf.Timestamp":
if source_type == ".google.protobuf.Timestamp":
return "datetime"
# Use precompiled classes for google.protobuf.* objects
if type_name.startswith("google.protobuf.") and type_name.count(".") == 2:
type_name = type_name.rsplit(".", maxsplit=1)[1]
import_package = "betterproto.lib.google.protobuf"
import_alias = safe_snake_case(import_package)
imports.add(f"import {import_package} as {import_alias}")
return f"{import_alias}.{type_name}"
source_package, source_type = parse_source_type_name(source_type)
# Use precompiled classes for google.protobuf.* objects
if source_package == "google.protobuf":
string_import = f"betterproto.lib.{source_package}"
py_type = source_type
string_alias = safe_snake_case(string_import)
imports.add(f"import {string_import} as {string_alias}")
return f"{string_alias}.{py_type}"
py_package: List[str] = source_package.split(".") if source_package else []
py_type: str = pythonize_class_name(source_type)
importing_package: List[str] = type_name.split(".")
importing_type: str = stringcase.pascalcase(importing_package.pop())
current_package: List[str] = package.split(".") if package else []
# importing sibling
@@ -67,9 +80,8 @@ def get_ref_type(
package = foo.bar
name = foo.bar.Baz
"""
if importing_package == current_package:
imports.add(f"from . import {importing_type}")
return importing_type
if py_package == current_package:
return f'"{py_type}"'
# importing child & descendent:
"""
@@ -79,18 +91,18 @@ def get_ref_type(
package =
name = foo.bar.Baz
"""
if importing_package[0 : len(current_package)] == current_package:
importing_descendent = importing_package[len(current_package) :]
if py_package[0 : len(current_package)] == current_package:
importing_descendent = py_package[len(current_package) :]
string_from = ".".join(importing_descendent[0:-1])
string_import = importing_descendent[-1]
if string_from:
string_alias = "_".join(importing_descendent)
imports.add(f"from .{string_from} import {string_import} as {string_alias}")
return f"{string_alias}.{importing_type}"
return f"{string_alias}.{py_type}"
else:
imports.add(f"from . import {string_import}")
return f"{string_import}.{importing_type}"
return f"{string_import}.{py_type}"
# importing parent & ancestor
"""
@@ -103,10 +115,10 @@ def get_ref_type(
package = foo.bar.baz
name = Bar
"""
if current_package[0 : len(importing_package)] == importing_package:
distance_up = len(current_package) - len(importing_package)
imports.add(f"from .{'.' * distance_up} import {importing_type}")
return importing_type
if current_package[0 : len(py_package)] == py_package:
distance_up = len(current_package) - len(py_package)
imports.add(f"from .{'.' * distance_up} import {py_type}")
return py_type
# importing unrelated or cousin
"""
@@ -116,20 +128,16 @@ def get_ref_type(
package = foo.bar.baz
name = foo.example.Bar
"""
shared_ancestory = [
pair[0]
for pair in zip(current_package, importing_package)
if pair[0] == pair[1]
pair[0] for pair in zip(current_package, py_package) if pair[0] == pair[1]
]
distance_up = len(current_package) - len(shared_ancestory)
string_from = f".{'.' * distance_up}" + ".".join(
importing_package[len(shared_ancestory) : -1]
py_package[len(shared_ancestory) : -1]
)
string_import = importing_package[-1]
string_alias = f"{'_' * distance_up}" + safe_snake_case(
".".join(importing_package[len(shared_ancestory) :])
string_import = py_package[-1]
alias = f"{'_' * distance_up}" + safe_snake_case(
".".join(py_package[len(shared_ancestory) :])
)
imports.add(f"from {string_from} import {string_import} as {string_alias}")
return f"{string_alias}.{importing_type}"
imports.add(f"from {string_from} import {string_import} as {alias}")
return f"{alias}.{py_type}"

View File

@@ -0,0 +1,13 @@
from betterproto import casing
def pythonize_class_name(name):
return casing.pascal_case(name)
def pythonize_field_name(name: str):
return casing.safe_snake_case(name)
def pythonize_method_name(name: str):
return casing.safe_snake_case(name)