Skip to content

Commit d21432b

Browse files
committed
imprv: generate: Hint callbacks
1 parent a69db0d commit d21432b

File tree

1 file changed

+25
-4
lines changed

1 file changed

+25
-4
lines changed

tools/generate.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -292,9 +292,12 @@ def _type_to_python(
292292
else:
293293
return_type = f"Tuple[{', '.join(return_args)}]"
294294

295-
# FIXME, how to express Callable with variable arguments?
296295
if (len(names) > 0 and names[-1].startswith("*")) or varargs:
297-
return f"Callable[..., {return_type}]"
296+
args = args[:-1] # Closure data is always the last argument
297+
if len(args) > 0:
298+
return f"Callable[Concatenate[{', '.join(args)}, _VarArgs], {return_type}]"
299+
else:
300+
return f"Callable[Concatenate[_VarArgs], {return_type}]"
298301
else:
299302
return f"Callable[[{', '.join(args)}], {return_type}]"
300303
else:
@@ -326,9 +329,25 @@ def _build(parent: ObjectT, namespace: str, overrides: dict[str, str]) -> str:
326329
ns = set()
327330
ret = _gi_build_stub(parent, namespace, dir(parent), ns, overrides, None, "")
328331

329-
typings = "from typing import Any, Callable, Literal, Optional, Tuple, Type, TypeVar, Sequence"
332+
typings_list = [
333+
"Any",
334+
"Callable",
335+
"Literal",
336+
"Optional",
337+
"Tuple",
338+
"Type",
339+
"TypeVar",
340+
"Sequence",
341+
]
342+
343+
typing_extensions_list = ["Concatenate", "ParamSpec"]
344+
345+
typings = f"from typing import {' ,'.join(typings_list)}"
346+
typing_extensions = (
347+
f"from typing_extensions import {' ,'.join(typing_extensions_list)}"
348+
)
330349

331-
typevars: list[str] = []
350+
typevars: list[str] = ["_VarArgs = ParamSpec('_VarArgs')"]
332351
imports: list[str] = []
333352
if "cairo" in ns:
334353
imports = ["import cairo"]
@@ -340,6 +359,8 @@ def _build(parent: ObjectT, namespace: str, overrides: dict[str, str]) -> str:
340359
return (
341360
typings
342361
+ "\n\n"
362+
+ typing_extensions
363+
+ "\n\n"
343364
+ "\n".join(imports)
344365
+ "\n"
345366
+ "\n".join(typevars)

0 commit comments

Comments
 (0)