diff --git a/betterproto/casing.py b/betterproto/casing.py index 01e10bb..60ece6a 100644 --- a/betterproto/casing.py +++ b/betterproto/casing.py @@ -53,30 +53,65 @@ def safe_snake_case(value: str) -> str: return value -def snake_case(value: str): +def snake_case(value: str, strict: bool = True): """ Join words with an underscore into lowercase and remove symbols. + @param value: value to convert + @param strict: force single underscores """ + + def substitute_word(symbols, word, is_start): + if not word: + return "" + if strict: + delimiter_count = 0 if is_start else 1 # Single underscore if strict. + elif is_start: + delimiter_count = len(symbols) + elif word.isupper() or word.islower(): + delimiter_count = max(1, len(symbols)) # Preserve all delimiters if not strict. + else: + delimiter_count = len(symbols) + 1 # Extra underscore for leading capital. + + return ("_" * delimiter_count) + word.lower() + snake = re.sub( - f"{SYMBOLS}({WORD_UPPER}|{WORD})", lambda groups: "_" + groups[1].lower(), value + f"(^)?({SYMBOLS})({WORD_UPPER}|{WORD})", + lambda groups: substitute_word(groups[2], groups[3], groups[1] is not None), + value, ) - return snake.strip("_") + return snake -def pascal_case(value: str): +def pascal_case(value: str, strict: bool = True): """ Capitalize each word and remove symbols. + @param value: value to convert + @param strict: output only alphanumeric characters """ + + def substitute_word(symbols, word): + if strict: + return word.capitalize() # Remove all delimiters + + if word.islower(): + delimiter_length = len(symbols[:-1]) # Lose one delimiter + else: + delimiter_length = len(symbols) # Preserve all delimiters + + return ("_" * delimiter_length) + word.capitalize() + return re.sub( - f"{SYMBOLS}({WORD_UPPER}|{WORD})", lambda groups: groups[1].capitalize(), value + f"({SYMBOLS})({WORD_UPPER}|{WORD})", + lambda groups: substitute_word(groups[1], groups[2]), + value, ) -def camel_case(value: str): +def camel_case(value: str, strict: bool = True): """ Capitalize all words except first and remove symbols. """ - return lowercase_first(pascal_case(value)) + return lowercase_first(pascal_case(value, strict=strict)) def lowercase_first(value: str): diff --git a/betterproto/tests/test_casing.py b/betterproto/tests/test_casing.py index ac18309..ec60483 100644 --- a/betterproto/tests/test_casing.py +++ b/betterproto/tests/test_casing.py @@ -29,7 +29,7 @@ from betterproto.casing import camel_case, pascal_case, snake_case ], ) def test_pascal_case(value, expected): - actual = pascal_case(value) + actual = pascal_case(value, strict=True) assert actual == expected, f"{value} => {expected} (actual: {actual})" @@ -56,8 +56,22 @@ def test_pascal_case(value, expected): ("1foobar", "1Foobar"), ], ) -def test_camel_case(value, expected): - actual = camel_case(value) +def test_camel_case_strict(value, expected): + actual = camel_case(value, strict=True) + assert actual == expected, f"{value} => {expected} (actual: {actual})" + + +@pytest.mark.parametrize( + ["value", "expected"], + [ + ("foo_bar", "fooBar"), + ("FooBar", "fooBar"), + ("foo__bar", "foo_Bar"), + ("foo__Bar", "foo__Bar"), + ], +) +def test_camel_case_not_strict(value, expected): + actual = camel_case(value, strict=False) assert actual == expected, f"{value} => {expected} (actual: {actual})" @@ -71,6 +85,7 @@ def test_camel_case(value, expected): ("FooBar", "foo_bar"), ("foo.bar", "foo_bar"), ("foo_bar", "foo_bar"), + ("foo_Bar", "foo_bar"), ("FOOBAR", "foobar"), ("FOOBar", "foo_bar"), ("UInt32", "u_int32"), @@ -85,8 +100,26 @@ def test_camel_case(value, expected): ("foo~bar", "foo_bar"), ("foo:bar", "foo_bar"), ("1foobar", "1_foobar"), + ("GetUInt64", "get_u_int64"), ], ) -def test_snake_case(value, expected): +def test_snake_case_strict(value, expected): actual = snake_case(value) assert actual == expected, f"{value} => {expected} (actual: {actual})" + + +@pytest.mark.parametrize( + ["value", "expected"], + [ + ("fooBar", "foo_bar"), + ("FooBar", "foo_bar"), + ("foo_Bar", "foo__bar"), + ("foo__bar", "foo__bar"), + ("FOOBar", "foo_bar"), + ("__foo", "__foo"), + ("GetUInt64", "get_u_int64"), + ], +) +def test_snake_case_not_strict(value, expected): + actual = snake_case(value, strict=False) + assert actual == expected, f"{value} => {expected} (actual: {actual})"