Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve function declaration wrapping when it contains generic type definitions #4553

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
- Collapse multiple empty lines after an import into one (#4489)
- Prevent `string_processing` and `wrap_long_dict_values_in_parens` from removing
parentheses around long dictionary values (#4377)
- Improve the way black wraps generic type definitions in function declarations,
so it will prioritize wrapping parameters instead of the generic types. (#4553)

### Configuration

Expand Down
2 changes: 2 additions & 0 deletions docs/the_black_code_style/future_style.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ Currently, the following features are included in the preview style:
cases)
- `always_one_newline_after_import`: Always force one blank line after import
statements, except when the line after the import is a comment or an import statement
- `generic_type_def_wrapping`: Improves the wrapping on function definitions that
containain generic type definitions, such as `def func[T](a: T, b: T):`.

(labels/unstable-features)=

Expand Down
38 changes: 38 additions & 0 deletions src/black/linegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -787,6 +787,7 @@ def left_hand_split(
Prefer RHS otherwise. This is why this function is not symmetrical with
:func:`right_hand_split` which also handles optional parentheses.
"""

tail_leaves: list[Leaf] = []
body_leaves: list[Leaf] = []
head_leaves: list[Leaf] = []
Expand All @@ -805,6 +806,14 @@ def left_hand_split(
current_leaves.append(leaf)
if current_leaves is head_leaves:
if leaf.type in OPENING_BRACKETS:
if (
Preview.generic_type_def_wrapping in mode
and leaf.type
== token.LSQB # '[' indicates a generic type declaration
and not _should_generic_type_def_in_func_be_splitted(leaf, mode)
):
continue

matching_bracket = leaf
current_leaves = body_leaves
if not matching_bracket or not tail_leaves:
Expand All @@ -825,6 +834,35 @@ def left_hand_split(
yield result


def _should_generic_type_def_in_func_be_splitted(
opening_square_bracket: Leaf, mode: Mode
) -> bool:
"""
Receives the leaf of the opening square bracket that starts
the definition of a generic type in a function signature
"""

# Determine the length of the function definition from its start to ']'
func_def_length = opening_square_bracket.column
current_node: Union[Node, Leaf] = opening_square_bracket

# Add the length of tokens until the closing bracket is reached
while current_node.next_sibling is not None and current_node.type != token.RSQB:
current_node = current_node.next_sibling
func_def_length += len(str(current_node))

# Check if generic types should be split
generic_types_should_be_splitted = (
func_def_length >= mode.line_length - 3 # Exceeds line length
or (
current_node.prev_sibling is not None
and current_node.prev_sibling.type == token.COMMA # Trailing comma rule
)
)

return generic_types_should_be_splitted


def right_hand_split(
line: Line,
mode: Mode,
Expand Down
1 change: 1 addition & 0 deletions src/black/mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ class Preview(Enum):
remove_lone_list_item_parens = auto()
pep646_typed_star_arg_type_var_tuple = auto()
always_one_newline_after_import = auto()
generic_type_def_wrapping = auto()


UNSTABLE_FEATURES: set[Preview] = {
Expand Down
3 changes: 2 additions & 1 deletion src/black/resources/black.schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@
"parens_for_long_if_clauses_in_case_block",
"remove_lone_list_item_parens",
"pep646_typed_star_arg_type_var_tuple",
"always_one_newline_after_import"
"always_one_newline_after_import",
"generic_type_def_wrapping"
]
},
"description": "Enable specific features included in the `--unstable` style. Requires `--preview`. No compatibility guarantees are provided on the behavior or existence of any unstable features."
Expand Down
99 changes: 99 additions & 0 deletions tests/data/cases/generics_wrapping.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# flags: --minimum-version=3.12 --preview
def func[T](a: T, b: T,) -> T:
return a


def with_magic_trailling_comma[
T,
B,
](a: T, b: T,) -> T:
return a


def without_magic_trailling_comma[
T,
B
](a: T, b: T,) -> T:
return a


def func[
T
](a: T, b: T,) -> T:
return a


def something_something_function[
T: Model
](param: list[int], other_param: type[T], *, some_other_param: bool = True) -> QuerySet[
T
]:
pass


def func[A_LOT_OF_GENERIC_TYPES: AreBeingDefinedHere, LIKE_THIS, AND_THIS, ANOTHER_ONE, AND_YET_ANOTHER_ONE: ThisOneHasTyping](a: T, b: T, c: T, d: T, e: T, f: T, g: T, h: T, i: T, j: T, k: T, l: T, m: T, n: T, o: T, p: T) -> T:
return a


# output


def func[T](
a: T,
b: T,
) -> T:
return a


def with_magic_trailling_comma[
T,
B,
](a: T, b: T,) -> T:
return a


def without_magic_trailling_comma[T, B](
a: T,
b: T,
) -> T:
return a


def func[T](
a: T,
b: T,
) -> T:
return a


def something_something_function[T: Model](
param: list[int], other_param: type[T], *, some_other_param: bool = True
) -> QuerySet[T]:
pass


def func[
A_LOT_OF_GENERIC_TYPES: AreBeingDefinedHere,
LIKE_THIS,
AND_THIS,
ANOTHER_ONE,
AND_YET_ANOTHER_ONE: ThisOneHasTyping,
](
a: T,
b: T,
c: T,
d: T,
e: T,
f: T,
g: T,
h: T,
i: T,
j: T,
k: T,
l: T,
m: T,
n: T,
o: T,
p: T,
) -> T:
return a
61 changes: 61 additions & 0 deletions tests/data/cases/new_type_param_defaults.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# flags: --minimum-version=3.13 --preview

type A[T=int] = float
type B[*P=int] = float
type C[*Ts=int] = float
type D[*Ts=*int] = float
type D[something_that_is_very_very_very_long=something_that_is_very_very_very_long] = float
type D[*something_that_is_very_very_very_long=*something_that_is_very_very_very_long] = float
type something_that_is_long[something_that_is_long=something_that_is_long] = something_that_is_long

def simple[T=something_that_is_long](short1: int, short2: str, short3: bytes) -> float:
pass

def longer[something_that_is_long=something_that_is_long](something_that_is_long: something_that_is_long) -> something_that_is_long:
pass

def trailing_comma1[T=int,](a: str):
pass

def trailing_comma2[T=int](a: str,):
pass

# output

type A[T = int] = float
type B[*P = int] = float
type C[*Ts = int] = float
type D[*Ts = *int] = float
type D[
something_that_is_very_very_very_long = something_that_is_very_very_very_long
] = float
type D[
*something_that_is_very_very_very_long = *something_that_is_very_very_very_long
] = float
type something_that_is_long[
something_that_is_long = something_that_is_long
] = something_that_is_long


def simple[T = something_that_is_long](
short1: int, short2: str, short3: bytes
) -> float:
pass


def longer[something_that_is_long = something_that_is_long](
something_that_is_long: something_that_is_long,
) -> something_that_is_long:
pass


def trailing_comma1[
T = int,
](a: str):
pass


def trailing_comma2[T = int](
a: str,
):
pass
2 changes: 1 addition & 1 deletion tests/test_black.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def test_pep_695_version_detection(self) -> None:
self.assertIn(black.TargetVersion.PY312, versions)

def test_pep_696_version_detection(self) -> None:
source, _ = read_data("cases", "type_param_defaults")
source, _ = read_data("cases", "stable_type_param_defaults")
samples = [
source,
"type X[T=int] = float",
Expand Down
Loading