Refactor default value code

This commit is contained in:
Daniel G. Taylor
2019-10-25 21:16:32 -07:00
parent 4679c571c3
commit 5daf61f64c
2 changed files with 51 additions and 47 deletions

View File

@@ -35,13 +35,17 @@ def get_ref_type(package: str, imports: set, type_name: str) -> str:
Return a Python type name for a proto type reference. Adds the import if
necessary.
"""
# If the package name is a blank string, then this should still work
# because by convention packages are lowercase and message/enum types are
# pascal-cased. May require refactoring in the future.
type_name = type_name.lstrip(".")
if type_name.startswith(package):
# This is the current package, which has nested types flattened.
# foo.bar_thing => FooBarThing
parts = type_name.lstrip(package).lstrip(".").split(".")
cased = [stringcase.pascalcase(part) for part in parts]
type_name = f'"{"".join(cased)}"'
if len(parts) == 1 or (len(parts) > 1 and parts[0][0] == parts[0][0].upper()):
# This is the current package, which has nested types flattened.
# foo.bar_thing => FooBarThing
cased = [stringcase.pascalcase(part) for part in parts]
type_name = f'"{"".join(cased)}"'
if "." in type_name:
# This is imported from another package. No need