Refactor default value code
This commit is contained in:
@@ -35,13 +35,17 @@ def get_ref_type(package: str, imports: set, type_name: str) -> str:
|
||||
Return a Python type name for a proto type reference. Adds the import if
|
||||
necessary.
|
||||
"""
|
||||
# If the package name is a blank string, then this should still work
|
||||
# because by convention packages are lowercase and message/enum types are
|
||||
# pascal-cased. May require refactoring in the future.
|
||||
type_name = type_name.lstrip(".")
|
||||
if type_name.startswith(package):
|
||||
# This is the current package, which has nested types flattened.
|
||||
# foo.bar_thing => FooBarThing
|
||||
parts = type_name.lstrip(package).lstrip(".").split(".")
|
||||
cased = [stringcase.pascalcase(part) for part in parts]
|
||||
type_name = f'"{"".join(cased)}"'
|
||||
if len(parts) == 1 or (len(parts) > 1 and parts[0][0] == parts[0][0].upper()):
|
||||
# This is the current package, which has nested types flattened.
|
||||
# foo.bar_thing => FooBarThing
|
||||
cased = [stringcase.pascalcase(part) for part in parts]
|
||||
type_name = f'"{"".join(cased)}"'
|
||||
|
||||
if "." in type_name:
|
||||
# This is imported from another package. No need
|
||||
|
||||
Reference in New Issue
Block a user