diff --git a/CHANGES.md b/CHANGES.md index 416aabcdf13..1cc4ad2c707 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -26,6 +26,8 @@ - Collapse multiple empty lines after an import into one (#4489) - Prevent `string_processing` and `wrap_long_dict_values_in_parens` from removing parentheses around long dictionary values (#4377) +- Improve the way black wraps generic type definitions in function declarations, + so it will prioritize wrapping parameters instead of the generic types. (#4553) ### Configuration diff --git a/docs/the_black_code_style/future_style.md b/docs/the_black_code_style/future_style.md index bc3e233ed3d..8fc2987b2b6 100644 --- a/docs/the_black_code_style/future_style.md +++ b/docs/the_black_code_style/future_style.md @@ -43,6 +43,8 @@ Currently, the following features are included in the preview style: cases) - `always_one_newline_after_import`: Always force one blank line after import statements, except when the line after the import is a comment or an import statement +- `generic_type_def_wrapping`: Improves the wrapping on function definitions that + containain generic type definitions, such as `def func[T](a: T, b: T):`. (labels/unstable-features)= diff --git a/src/black/linegen.py b/src/black/linegen.py index 90ca5da8587..1405d9ad260 100644 --- a/src/black/linegen.py +++ b/src/black/linegen.py @@ -787,6 +787,7 @@ def left_hand_split( Prefer RHS otherwise. This is why this function is not symmetrical with :func:`right_hand_split` which also handles optional parentheses. """ + tail_leaves: list[Leaf] = [] body_leaves: list[Leaf] = [] head_leaves: list[Leaf] = [] @@ -805,6 +806,14 @@ def left_hand_split( current_leaves.append(leaf) if current_leaves is head_leaves: if leaf.type in OPENING_BRACKETS: + if ( + Preview.generic_type_def_wrapping in mode + and leaf.type + == token.LSQB # '[' indicates a generic type declaration + and not _should_generic_type_def_in_func_be_splitted(leaf, mode) + ): + continue + matching_bracket = leaf current_leaves = body_leaves if not matching_bracket or not tail_leaves: @@ -825,6 +834,35 @@ def left_hand_split( yield result +def _should_generic_type_def_in_func_be_splitted( + opening_square_bracket: Leaf, mode: Mode +) -> bool: + """ + Receives the leaf of the opening square bracket that starts + the definition of a generic type in a function signature + """ + + # Determine the length of the function definition from its start to ']' + func_def_length = opening_square_bracket.column + current_node: Union[Node, Leaf] = opening_square_bracket + + # Add the length of tokens until the closing bracket is reached + while current_node.next_sibling is not None and current_node.type != token.RSQB: + current_node = current_node.next_sibling + func_def_length += len(str(current_node)) + + # Check if generic types should be split + generic_types_should_be_splitted = ( + func_def_length >= mode.line_length - 3 # Exceeds line length + or ( + current_node.prev_sibling is not None + and current_node.prev_sibling.type == token.COMMA # Trailing comma rule + ) + ) + + return generic_types_should_be_splitted + + def right_hand_split( line: Line, mode: Mode, diff --git a/src/black/mode.py b/src/black/mode.py index 96f72cc2ded..e6cba0e0ebb 100644 --- a/src/black/mode.py +++ b/src/black/mode.py @@ -215,6 +215,7 @@ class Preview(Enum): remove_lone_list_item_parens = auto() pep646_typed_star_arg_type_var_tuple = auto() always_one_newline_after_import = auto() + generic_type_def_wrapping = auto() UNSTABLE_FEATURES: set[Preview] = { diff --git a/src/black/resources/black.schema.json b/src/black/resources/black.schema.json index f1d471e7616..7174687d6f4 100644 --- a/src/black/resources/black.schema.json +++ b/src/black/resources/black.schema.json @@ -93,7 +93,8 @@ "parens_for_long_if_clauses_in_case_block", "remove_lone_list_item_parens", "pep646_typed_star_arg_type_var_tuple", - "always_one_newline_after_import" + "always_one_newline_after_import", + "generic_type_def_wrapping" ] }, "description": "Enable specific features included in the `--unstable` style. Requires `--preview`. No compatibility guarantees are provided on the behavior or existence of any unstable features." diff --git a/tests/data/cases/generics_wrapping.py b/tests/data/cases/generics_wrapping.py new file mode 100644 index 00000000000..baf14d4c6e2 --- /dev/null +++ b/tests/data/cases/generics_wrapping.py @@ -0,0 +1,99 @@ +# flags: --minimum-version=3.12 --preview +def func[T](a: T, b: T,) -> T: + return a + + +def with_magic_trailling_comma[ + T, + B, +](a: T, b: T,) -> T: + return a + + +def without_magic_trailling_comma[ + T, + B +](a: T, b: T,) -> T: + return a + + +def func[ + T +](a: T, b: T,) -> T: + return a + + +def something_something_function[ + T: Model +](param: list[int], other_param: type[T], *, some_other_param: bool = True) -> QuerySet[ + T +]: + pass + + +def func[A_LOT_OF_GENERIC_TYPES: AreBeingDefinedHere, LIKE_THIS, AND_THIS, ANOTHER_ONE, AND_YET_ANOTHER_ONE: ThisOneHasTyping](a: T, b: T, c: T, d: T, e: T, f: T, g: T, h: T, i: T, j: T, k: T, l: T, m: T, n: T, o: T, p: T) -> T: + return a + + +# output + + +def func[T]( + a: T, + b: T, +) -> T: + return a + + +def with_magic_trailling_comma[ + T, + B, +](a: T, b: T,) -> T: + return a + + +def without_magic_trailling_comma[T, B]( + a: T, + b: T, +) -> T: + return a + + +def func[T]( + a: T, + b: T, +) -> T: + return a + + +def something_something_function[T: Model]( + param: list[int], other_param: type[T], *, some_other_param: bool = True +) -> QuerySet[T]: + pass + + +def func[ + A_LOT_OF_GENERIC_TYPES: AreBeingDefinedHere, + LIKE_THIS, + AND_THIS, + ANOTHER_ONE, + AND_YET_ANOTHER_ONE: ThisOneHasTyping, +]( + a: T, + b: T, + c: T, + d: T, + e: T, + f: T, + g: T, + h: T, + i: T, + j: T, + k: T, + l: T, + m: T, + n: T, + o: T, + p: T, +) -> T: + return a diff --git a/tests/data/cases/new_type_param_defaults.py b/tests/data/cases/new_type_param_defaults.py new file mode 100644 index 00000000000..5e992b157cc --- /dev/null +++ b/tests/data/cases/new_type_param_defaults.py @@ -0,0 +1,61 @@ +# flags: --minimum-version=3.13 --preview + +type A[T=int] = float +type B[*P=int] = float +type C[*Ts=int] = float +type D[*Ts=*int] = float +type D[something_that_is_very_very_very_long=something_that_is_very_very_very_long] = float +type D[*something_that_is_very_very_very_long=*something_that_is_very_very_very_long] = float +type something_that_is_long[something_that_is_long=something_that_is_long] = something_that_is_long + +def simple[T=something_that_is_long](short1: int, short2: str, short3: bytes) -> float: + pass + +def longer[something_that_is_long=something_that_is_long](something_that_is_long: something_that_is_long) -> something_that_is_long: + pass + +def trailing_comma1[T=int,](a: str): + pass + +def trailing_comma2[T=int](a: str,): + pass + +# output + +type A[T = int] = float +type B[*P = int] = float +type C[*Ts = int] = float +type D[*Ts = *int] = float +type D[ + something_that_is_very_very_very_long = something_that_is_very_very_very_long +] = float +type D[ + *something_that_is_very_very_very_long = *something_that_is_very_very_very_long +] = float +type something_that_is_long[ + something_that_is_long = something_that_is_long +] = something_that_is_long + + +def simple[T = something_that_is_long]( + short1: int, short2: str, short3: bytes +) -> float: + pass + + +def longer[something_that_is_long = something_that_is_long]( + something_that_is_long: something_that_is_long, +) -> something_that_is_long: + pass + + +def trailing_comma1[ + T = int, +](a: str): + pass + + +def trailing_comma2[T = int]( + a: str, +): + pass diff --git a/tests/data/cases/type_param_defaults.py b/tests/data/cases/stable_type_param_defaults.py similarity index 100% rename from tests/data/cases/type_param_defaults.py rename to tests/data/cases/stable_type_param_defaults.py diff --git a/tests/test_black.py b/tests/test_black.py index 98d8ff886d7..2151d0726b8 100644 --- a/tests/test_black.py +++ b/tests/test_black.py @@ -253,7 +253,7 @@ def test_pep_695_version_detection(self) -> None: self.assertIn(black.TargetVersion.PY312, versions) def test_pep_696_version_detection(self) -> None: - source, _ = read_data("cases", "type_param_defaults") + source, _ = read_data("cases", "stable_type_param_defaults") samples = [ source, "type X[T=int] = float",