Skip to content

Commit 624ecdf

Browse files
authored
Refine typing. (#440)
1 parent 5cd12ac commit 624ecdf

38 files changed

+213
-198
lines changed

.pre-commit-config.yaml

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,15 +62,13 @@ repos:
6262
rev: 'v1.5.1'
6363
hooks:
6464
- id: mypy
65-
args: [
66-
--no-strict-optional,
67-
--ignore-missing-imports,
68-
]
6965
additional_dependencies: [
7066
attrs>=21.3.0,
7167
click,
7268
optree,
7369
pluggy,
70+
rich,
71+
sqlalchemy,
7472
types-setuptools,
7573
]
7674
pass_filenames: false

docs/source/changes.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ releases are available on [PyPI](https://pypi.org/project/pytask) and
5555
- {pull}`437` fixes the detection of task functions and publishes
5656
{func}`pytask.is_task_function`.
5757
- {pull}`438` clarifies some types.
58+
- {pull}`440` refines more types.
5859

5960
## 0.3.2 - 2023-06-07
6061

pyproject.toml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,14 @@ module = "tests.*"
2323
disallow_untyped_defs = false
2424
ignore_errors = true
2525

26+
[[tool.mypy.overrides]]
27+
module = ["click_default_group", "networkx"]
28+
ignore_missing_imports = true
29+
30+
[[tool.mypy.overrides]]
31+
module = ["_pytask.hookspecs"]
32+
disable_error_code = ["empty-body"]
33+
2634

2735
[tool.codespell]
2836
ignore-words-list = "falsy, hist, ines, unparseable"

src/_pytask/build.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -233,12 +233,10 @@ def build( # noqa: C901, PLR0912, PLR0913, PLR0915
233233
session = Session.from_config(config_)
234234

235235
except (ConfigurationError, Exception):
236-
exc_info = sys.exc_info()
237-
exc_info = remove_internal_traceback_frames_from_exc_info(exc_info)
236+
exc_info = remove_internal_traceback_frames_from_exc_info(sys.exc_info())
238237
traceback = Traceback.from_exception(*exc_info)
239238
console.print(traceback)
240-
session = Session({}, None)
241-
session.exit_code = ExitCode.CONFIGURATION_FAILED
239+
session = Session(exit_code=ExitCode.CONFIGURATION_FAILED)
242240

243241
else:
244242
try:
@@ -257,8 +255,7 @@ def build( # noqa: C901, PLR0912, PLR0913, PLR0915
257255
session.exit_code = ExitCode.FAILED
258256

259257
except Exception: # noqa: BLE001
260-
exc_info = sys.exc_info()
261-
exc_info = remove_internal_traceback_frames_from_exc_info(exc_info)
258+
exc_info = remove_internal_traceback_frames_from_exc_info(sys.exc_info())
262259
traceback = Traceback.from_exception(*exc_info)
263260
console.print(traceback)
264261
session.exit_code = ExitCode.FAILED

src/_pytask/capture.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -619,7 +619,7 @@ def resume_capturing(self) -> None:
619619
if self.err:
620620
self.err.resume()
621621
if self._in_suspended:
622-
self.in_.resume()
622+
self.in_.resume() # type: ignore[union-attr]
623623
self._in_suspended = False
624624

625625
def stop_capturing(self) -> None:

src/_pytask/clean.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535

3636

3737
if TYPE_CHECKING:
38-
from types import TracebackType
3938
from typing import NoReturn
4039

4140

@@ -112,12 +111,8 @@ def clean(**raw_config: Any) -> NoReturn: # noqa: C901, PLR0912, PLR0915
112111
session = Session.from_config(config)
113112

114113
except Exception: # noqa: BLE001
115-
session = Session({}, None)
116-
session.exit_code = ExitCode.CONFIGURATION_FAILED
117-
exc_info: tuple[
118-
type[BaseException], BaseException, TracebackType | None
119-
] = sys.exc_info()
120-
console.print(render_exc_info(*exc_info))
114+
session = Session(exit_code=ExitCode.CONFIGURATION_FAILED)
115+
console.print(render_exc_info(*sys.exc_info()))
121116

122117
else:
123118
try:
@@ -163,7 +158,7 @@ def clean(**raw_config: Any) -> NoReturn: # noqa: C901, PLR0912, PLR0915
163158
)
164159

165160
console.print()
166-
console.rule(style=None)
161+
console.rule(style="default")
167162

168163
except CollectionError:
169164
session.exit_code = ExitCode.COLLECTION_FAILED

src/_pytask/click.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def convert(
5656
class _OptionHighlighter(RegexHighlighter):
5757
"""A highlighter for help texts."""
5858

59-
highlights: ClassVar = [
59+
highlights: ClassVar = [ # type: ignore[misc]
6060
r"(?P<switch>\-\w)\b",
6161
r"(?P<option>\-\-[\w\-]+)",
6262
r"\-\-[\w\-]+(?P<metavar>[ |=][\w\.:]+)",
@@ -192,7 +192,7 @@ def _print_options(
192192

193193
def _format_help_text( # noqa: C901, PLR0912, PLR0915
194194
param: click.Parameter, ctx: click.Context
195-
) -> str:
195+
) -> Text:
196196
"""Format the help of a click parameter.
197197
198198
A large chunk of the function is copied from

src/_pytask/collect.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -112,11 +112,16 @@ def _collect_from_tasks(session: Session) -> None:
112112
except (TypeError, OSError):
113113
path = None
114114
else:
115-
if path.name == "<stdin>":
115+
if path and path.name == "<stdin>":
116116
path = None # pragma: no cover
117117

118118
# Detect whether a path is defined in a Jupyter notebook.
119-
if is_jupyter() and "ipykernel" in path.as_posix() and path.suffix == ".py":
119+
if (
120+
is_jupyter()
121+
and path
122+
and "ipykernel" in path.as_posix()
123+
and path.suffix == ".py"
124+
):
120125
path = None # pragma: no cover
121126

122127
name = raw_task.pytask_meta.name
@@ -209,7 +214,11 @@ def pytask_collect_task_protocol(
209214
return CollectionReport(outcome=CollectionOutcome.SUCCESS, node=task)
210215

211216
except Exception: # noqa: BLE001
212-
task = Task(base_name=name, path=path, function=None)
217+
if path:
218+
task = Task(base_name=name, path=path, function=obj)
219+
else:
220+
task = TaskWithoutPath(name=name, function=obj)
221+
213222
return CollectionReport.from_exception(
214223
outcome=CollectionOutcome.FAIL, exc_info=sys.exc_info(), node=task
215224
)
@@ -463,7 +472,7 @@ def pytask_collect_log(
463472
if isinstance(report.node, PTask):
464473
short_name = format_task_name(
465474
report.node, editor_url_scheme="no_link"
466-
)
475+
).plain
467476
else:
468477
short_name = reduce_node_name(report.node, session.config["paths"])
469478
header = f"Could not collect {short_name}"
@@ -475,8 +484,11 @@ def pytask_collect_log(
475484

476485
console.print()
477486

487+
assert report.exc_info
478488
console.print(
479-
render_exc_info(*report.exc_info, session.config["show_locals"])
489+
render_exc_info(
490+
*report.exc_info, show_locals=session.config["show_locals"]
491+
)
480492
)
481493

482494
console.print()

src/_pytask/collect_command.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,7 @@ def collect(**raw_config: Any | None) -> NoReturn:
6868
session = Session.from_config(config)
6969

7070
except (ConfigurationError, Exception):
71-
session = Session({}, None)
72-
session.exit_code = ExitCode.CONFIGURATION_FAILED
71+
session = Session(exit_code=ExitCode.CONFIGURATION_FAILED)
7372
console.print_exception()
7473

7574
else:
@@ -222,17 +221,15 @@ def _print_collected_tasks( # noqa: PLR0912
222221
reduced_path = str(
223222
relative_to(Path(path_part), common_ancestor)
224223
)
225-
text = reduced_path + "::" + rest
224+
text = Text(reduced_path + "::" + rest)
226225
except Exception: # noqa: BLE001
227-
text = node.name
226+
text = Text(node.name)
228227

229228
task_branch.add(Text.assemble(FILE_ICON, "<Dependency ", text, ">"))
230229

231230
for node in sorted( # type: ignore[assignment]
232231
tree_leaves(task.produces),
233-
key=lambda x: getattr(
234-
x, "path", x.name # type: ignore[attr-defined]
235-
),
232+
key=lambda x: x.path if isinstance(x, PPathNode) else x.name, # type: ignore[attr-defined]
236233
):
237234
if isinstance(node, PPathNode):
238235
reduced_node_name = str(relative_to(node.path, common_ancestor))

src/_pytask/config_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def set_defaults_from_config(
1919
# command-line options during parsing. Here, we add their defaults to the
2020
# configuration.
2121
command_option_names = [option.name for option in context.command.params]
22-
commands = context.parent.command.commands # type: ignore[attr-defined]
22+
commands = context.parent.command.commands # type: ignore[union-attr]
2323
all_defaults_from_cli = {
2424
option.name: option.default
2525
for name, command in commands.items()
@@ -54,7 +54,7 @@ def set_defaults_from_config(
5454
return context.params["config"]
5555

5656

57-
def _find_project_root_and_config(paths: list[Path]) -> tuple[Path, Path]:
57+
def _find_project_root_and_config(paths: list[Path] | None) -> tuple[Path, Path | None]:
5858
"""Find the project root and configuration file from a list of paths.
5959
6060
The process is as follows:
@@ -68,7 +68,7 @@ def _find_project_root_and_config(paths: list[Path]) -> tuple[Path, Path]:
6868
6969
"""
7070
try:
71-
common_ancestor = Path(os.path.commonpath(paths))
71+
common_ancestor = Path(os.path.commonpath(paths)) # type: ignore[arg-type]
7272
except ValueError:
7373
common_ancestor = Path.cwd()
7474

src/_pytask/console.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import rich
1616
from _pytask.nodes import Task
1717
from rich.console import Console
18+
from rich.console import RenderableType
1819
from rich.padding import Padding
1920
from rich.panel import Panel
2021
from rich.segment import Segment
@@ -103,7 +104,7 @@
103104

104105

105106
def render_to_string(
106-
text: str | Text,
107+
text: RenderableType,
107108
*,
108109
console: Console | None = None,
109110
strip_styles: bool = False,
@@ -201,7 +202,7 @@ def create_url_style_for_path(path: Path, edtior_url_scheme: str) -> Style:
201202

202203
def get_file(
203204
function: Callable[..., Any], skipped_paths: list[Path] | None = None
204-
) -> Path:
205+
) -> Path | None:
205206
"""Get path to module where the function is defined.
206207
207208
When the ``pdb`` or ``trace`` mode is activated, every task function is wrapped with
@@ -214,12 +215,14 @@ def get_file(
214215

215216
if isinstance(function, functools.partial):
216217
return get_file(function.func)
217-
if (
218-
hasattr(function, "__wrapped__")
219-
and Path(inspect.getsourcefile(function)) in skipped_paths
220-
):
221-
return get_file(function.__wrapped__)
222-
return Path(inspect.getsourcefile(function))
218+
if hasattr(function, "__wrapped__"):
219+
source_file = inspect.getsourcefile(function)
220+
if source_file and Path(source_file) in skipped_paths:
221+
return get_file(function.__wrapped__)
222+
source_file = inspect.getsourcefile(function)
223+
if source_file:
224+
return Path(source_file)
225+
return None
223226

224227

225228
def _get_source_lines(function: Callable[..., Any]) -> int:

src/_pytask/dag.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def pytask_dag(session: Session) -> bool | None:
6565
except Exception: # noqa: BLE001
6666
report = DagReport.from_exception(sys.exc_info())
6767
session.hook.pytask_dag_log(session=session, report=report)
68-
session.resolving_dependencies_report = report
68+
session.dag_reports = report
6969

7070
raise ResolvingDependenciesError from None
7171

@@ -336,7 +336,9 @@ def pytask_dag_log(session: Session, report: DagReport) -> None:
336336
)
337337

338338
console.print()
339-
console.print(render_exc_info(*report.exc_info, session.config["show_locals"]))
339+
console.print(
340+
render_exc_info(*report.exc_info, show_locals=session.config["show_locals"])
341+
)
340342

341343
console.print()
342344
console.rule(style="failed")

src/_pytask/dag_utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,8 @@ class TopologicalSorter:
6767
"""
6868

6969
dag: nx.DiGraph
70+
dag_backup: nx.DiGraph
7071
priorities: dict[str, int] = field(factory=dict)
71-
_dag_backup: nx.DiGraph | None = None
7272
_is_prepared: bool = False
7373
_nodes_out: set[str] = field(factory=set)
7474

@@ -88,7 +88,7 @@ def from_dag(cls, dag: nx.DiGraph) -> TopologicalSorter:
8888
task_dict = {name: nx.ancestors(dag, name) & task_names for name in task_names}
8989
task_dag = nx.DiGraph(task_dict).reverse()
9090

91-
return cls(task_dag, priorities, task_dag.copy())
91+
return cls(dag=task_dag, priorities=priorities, dag_backup=task_dag.copy())
9292

9393
def prepare(self) -> None:
9494
"""Perform some checks before creating a topological ordering."""
@@ -131,7 +131,8 @@ def done(self, *nodes: str) -> None:
131131

132132
def reset(self) -> None:
133133
"""Reset an exhausted topological sorter."""
134-
self.dag = self._dag_backup.copy()
134+
if self.dag_backup:
135+
self.dag = self.dag_backup.copy()
135136
self._is_prepared = False
136137
self._nodes_out = set()
137138

src/_pytask/database_utils.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,10 @@ def _create_or_update_state(
6969
)
7070
)
7171
else:
72-
state_in_db.modification_time = modification_time
73-
state_in_db.hash_ = hash_
72+
state_in_db.modification_time = (
73+
modification_time # type: ignore[assignment]
74+
)
75+
state_in_db.hash_ = hash_ # type: ignore[assignment]
7476

7577
session.commit()
7678

@@ -90,4 +92,6 @@ def update_states_in_database(session: Session, task_name: str) -> None:
9092
modification_time = ""
9193
hash_ = node.state()
9294

93-
_create_or_update_state(task_name, node.name, modification_time, hash_)
95+
_create_or_update_state(
96+
task_name, node.name, modification_time, hash_ # type: ignore[arg-type]
97+
)

0 commit comments

Comments
 (0)