Skip to content

Commit 103a3fd

Browse files
craymichael authored and facebook-github-bot committed
Address instances of "Overloaded function signature x will never be matched" + minor typing fixes (#1381)
Summary: Pull Request resolved: #1381 Many overloads produced false positives or required changing order due to mypy breaking ties by picking the first matching variant (https://mypy.readthedocs.io/en/stable/more_types.html). This fixes or suppresses these errors. Created T204932142 to address Literal-related issues. Reviewed By: vivekmig Differential Revision: D64517613 fbshipit-source-id: 34f52d35cfba30af856762d14581eb30c69ce89f
1 parent cd45461 commit 103a3fd

15 files changed

+159
-139
lines changed

captum/_utils/common.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -73,17 +73,17 @@ def safe_div(
7373
@typing.overload
7474
# pyre-fixme[43]: The return type of overloaded function `_is_tuple` (`Literal[]`)
7575
# is incompatible with the return type of the implementation (`bool`).
76-
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
76+
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
7777
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
78-
def _is_tuple(inputs: Tensor) -> Literal[False]: ...
78+
def _is_tuple(inputs: Tuple[Tensor, ...]) -> Literal[True]: ...
7979

8080

8181
@typing.overload
8282
# pyre-fixme[43]: The return type of overloaded function `_is_tuple` (`Literal[]`)
8383
# is incompatible with the return type of the implementation (`bool`).
84-
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
84+
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
8585
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
86-
def _is_tuple(inputs: Tuple[Tensor, ...]) -> Literal[True]: ...
86+
def _is_tuple(inputs: Tensor) -> Literal[False]: ...
8787

8888

8989
def _is_tuple(inputs: Union[Tensor, Tuple[Tensor, ...]]) -> bool:
@@ -277,7 +277,7 @@ def _format_additional_forward_args(
277277

278278

279279
@overload
280-
def _format_additional_forward_args(
280+
def _format_additional_forward_args( # type: ignore
281281
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
282282
additional_forward_args: Any,
283283
# pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
@@ -780,10 +780,10 @@ def _reduce_list(
780780
"""
781781
assert len(val_list) > 0, "Cannot reduce empty list!"
782782
if isinstance(val_list[0], torch.Tensor):
783-
# pyre-fixme[16]: `bool` has no attribute `device`.
784-
first_device = val_list[0].device
785-
# pyre-fixme[16]: `bool` has no attribute `to`.
786-
return red_func([elem.to(first_device) for elem in val_list])
783+
first_device = cast(Tensor, val_list[0]).device
784+
return red_func(
785+
[elem.to(first_device) for elem in cast(List[Tensor], val_list)]
786+
)
787787
elif isinstance(val_list[0], bool):
788788
# pyre-fixme[7]: Expected `TupleOrTensorOrBoolGeneric` but got `bool`.
789789
return any(val_list)

captum/_utils/gradient.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -159,33 +159,34 @@ def _neuron_gradients(
159159

160160
@typing.overload
161161
# pyre-fixme[43]: The implementation of `_forward_layer_eval` does not accept all
162-
# possible arguments of overload defined on line `158`.
162+
# possible arguments of overload defined on line `170`.
163163
def _forward_layer_eval(
164164
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
165165
forward_fn: Callable,
166166
inputs: Union[Tensor, Tuple[Tensor, ...]],
167-
layer: Module,
167+
layer: List[Module],
168168
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
169169
additional_forward_args: Any = None,
170170
device_ids: Union[None, List[int]] = None,
171171
attribute_to_layer_input: bool = False,
172172
grad_enabled: bool = False,
173-
) -> Tuple[Tensor, ...]: ...
173+
) -> List[Tuple[Tensor, ...]]: ...
174174

175175

176176
@typing.overload
177177
# pyre-fixme[43]: The implementation of `_forward_layer_eval` does not accept all
178-
# possible arguments of overload defined on line `170`.
178+
# possible arguments of overload defined on line `158`.
179179
def _forward_layer_eval(
180180
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
181181
forward_fn: Callable,
182182
inputs: Union[Tensor, Tuple[Tensor, ...]],
183-
layer: List[Module],
183+
layer: Module,
184+
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
184185
additional_forward_args: Any = None,
185186
device_ids: Union[None, List[int]] = None,
186187
attribute_to_layer_input: bool = False,
187188
grad_enabled: bool = False,
188-
) -> List[Tuple[Tensor, ...]]: ...
189+
) -> Tuple[Tensor, ...]: ...
189190

190191

191192
def _forward_layer_eval(
@@ -434,34 +435,34 @@ def _forward_layer_eval_with_neuron_grads(
434435

435436
@typing.overload
436437
# pyre-fixme[43]: The implementation of `_forward_layer_eval_with_neuron_grads` does
437-
# not accept all possible arguments of overload defined on line `392`.
438+
# not accept all possible arguments of overload defined on line `405`.
438439
def _forward_layer_eval_with_neuron_grads(
439440
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
440441
forward_fn: Callable,
441442
inputs: Union[Tensor, Tuple[Tensor, ...]],
442-
layer: Module,
443+
layer: List[Module],
443444
additional_forward_args: Any = None,
444445
gradient_neuron_selector: None = None,
445446
grad_enabled: bool = False,
446447
device_ids: Union[None, List[int]] = None,
447448
attribute_to_layer_input: bool = False,
448-
) -> Tuple[Tensor, ...]: ...
449+
) -> List[Tuple[Tensor, ...]]: ...
449450

450451

451452
@typing.overload
452453
# pyre-fixme[43]: The implementation of `_forward_layer_eval_with_neuron_grads` does
453-
# not accept all possible arguments of overload defined on line `405`.
454+
# not accept all possible arguments of overload defined on line `392`.
454455
def _forward_layer_eval_with_neuron_grads(
455456
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
456457
forward_fn: Callable,
457458
inputs: Union[Tensor, Tuple[Tensor, ...]],
458-
layer: List[Module],
459+
layer: Module,
459460
additional_forward_args: Any = None,
460461
gradient_neuron_selector: None = None,
461462
grad_enabled: bool = False,
462463
device_ids: Union[None, List[int]] = None,
463464
attribute_to_layer_input: bool = False,
464-
) -> List[Tuple[Tensor, ...]]: ...
465+
) -> Tuple[Tensor, ...]: ...
465466

466467

467468
def _forward_layer_eval_with_neuron_grads(

captum/attr/_core/deep_lift.py

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -118,36 +118,37 @@ def __init__(
118118

119119
@typing.overload
120120
# pyre-fixme[43]: The implementation of `attribute` does not accept all possible
121-
# arguments of overload defined on line `120`.
121+
# arguments of overload defined on line `131`.
122122
def attribute(
123123
self,
124124
inputs: TensorOrTupleOfTensorsGeneric,
125125
baselines: BaselineType = None,
126126
target: TargetType = None,
127127
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
128128
additional_forward_args: Any = None,
129-
# pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
130-
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
129+
*,
130+
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
131131
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
132-
return_convergence_delta: Literal[False] = False,
132+
return_convergence_delta: Literal[True],
133133
custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
134-
) -> TensorOrTupleOfTensorsGeneric: ...
134+
) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]: ...
135135

136136
@typing.overload
137137
# pyre-fixme[43]: The implementation of `attribute` does not accept all possible
138-
# arguments of overload defined on line `131`.
138+
# arguments of overload defined on line `120`.
139139
def attribute(
140140
self,
141141
inputs: TensorOrTupleOfTensorsGeneric,
142142
baselines: BaselineType = None,
143143
target: TargetType = None,
144+
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
144145
additional_forward_args: Any = None,
145-
*,
146-
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
146+
# pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
147+
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
147148
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
148-
return_convergence_delta: Literal[True],
149+
return_convergence_delta: Literal[False] = False,
149150
custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
150-
) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]: ...
151+
) -> TensorOrTupleOfTensorsGeneric: ...
151152

152153
@log_usage()
153154
def attribute( # type: ignore
@@ -636,7 +637,7 @@ def __init__(self, model: Module, multiply_by_inputs: bool = True) -> None:
636637
# DeepLiftShap.attribute, so we ignore typing here
637638
@typing.overload # type: ignore
638639
# pyre-fixme[43]: The implementation of `attribute` does not accept all possible
639-
# arguments of overload defined on line `584`.
640+
# arguments of overload defined on line `597`.
640641
def attribute(
641642
self,
642643
inputs: TensorOrTupleOfTensorsGeneric,
@@ -646,30 +647,31 @@ def attribute(
646647
target: TargetType = None,
647648
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
648649
additional_forward_args: Any = None,
649-
# pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
650-
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
650+
*,
651+
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
651652
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
652-
return_convergence_delta: Literal[False] = False,
653+
return_convergence_delta: Literal[True],
653654
custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
654-
) -> TensorOrTupleOfTensorsGeneric: ...
655+
) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]: ...
655656

656657
@typing.overload
657658
# pyre-fixme[43]: The implementation of `attribute` does not accept all possible
658-
# arguments of overload defined on line `597`.
659+
# arguments of overload defined on line `584`.
659660
def attribute(
660661
self,
661662
inputs: TensorOrTupleOfTensorsGeneric,
662663
baselines: Union[
663664
TensorOrTupleOfTensorsGeneric, Callable[..., TensorOrTupleOfTensorsGeneric]
664665
],
665666
target: TargetType = None,
667+
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
666668
additional_forward_args: Any = None,
667-
*,
668-
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
669+
# pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
670+
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
669671
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
670-
return_convergence_delta: Literal[True],
672+
return_convergence_delta: Literal[False] = False,
671673
custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
672-
) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]: ...
674+
) -> TensorOrTupleOfTensorsGeneric: ...
673675

674676
@log_usage()
675677
def attribute( # type: ignore

captum/attr/_core/integrated_gradients.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def __init__(
8181
# a tuple with both attributions and deltas.
8282
@typing.overload
8383
# pyre-fixme[43]: The implementation of `attribute` does not accept all possible
84-
# arguments of overload defined on line `82`.
84+
# arguments of overload defined on line `95`.
8585
def attribute(
8686
self,
8787
inputs: TensorOrTupleOfTensorsGeneric,
@@ -92,29 +92,30 @@ def attribute(
9292
n_steps: int = 50,
9393
method: str = "gausslegendre",
9494
internal_batch_size: Union[None, int] = None,
95-
# pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
96-
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
95+
*,
96+
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
9797
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
98-
return_convergence_delta: Literal[False] = False,
99-
) -> TensorOrTupleOfTensorsGeneric: ...
98+
return_convergence_delta: Literal[True],
99+
) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]: ...
100100

101101
@typing.overload
102102
# pyre-fixme[43]: The implementation of `attribute` does not accept all possible
103-
# arguments of overload defined on line `95`.
103+
# arguments of overload defined on line `82`.
104104
def attribute(
105105
self,
106106
inputs: TensorOrTupleOfTensorsGeneric,
107107
baselines: BaselineType = None,
108108
target: TargetType = None,
109+
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
109110
additional_forward_args: Any = None,
110111
n_steps: int = 50,
111112
method: str = "gausslegendre",
112113
internal_batch_size: Union[None, int] = None,
113-
*,
114-
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
114+
# pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
115+
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
115116
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
116-
return_convergence_delta: Literal[True],
117-
) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]: ...
117+
return_convergence_delta: Literal[False] = False,
118+
) -> TensorOrTupleOfTensorsGeneric: ...
118119

119120
@log_usage()
120121
def attribute( # type: ignore

captum/attr/_core/layer/layer_activation.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#!/usr/bin/env python3
22

33
# pyre-strict
4-
from typing import Any, Callable, List, Tuple, Union
4+
from typing import Any, Callable, cast, List, Tuple, Union
55

66
import torch
77
from captum._utils.common import _format_output
@@ -128,7 +128,9 @@ def attribute(
128128
attribute_to_layer_input=attribute_to_layer_input,
129129
)
130130
if isinstance(self.layer, Module):
131-
return _format_output(len(layer_eval) > 1, layer_eval)
131+
return _format_output(
132+
len(layer_eval) > 1, cast(Tuple[Tensor, ...], layer_eval)
133+
)
132134
else:
133135
return [
134136
# pyre-fixme[6]: For 2nd argument expected `Tuple[Tensor, ...]` but

captum/attr/_core/layer/layer_deep_lift.py

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -102,40 +102,41 @@ def __init__(
102102
# Ignoring mypy error for inconsistent signature with DeepLift
103103
@typing.overload # type: ignore
104104
# pyre-fixme[43]: The implementation of `attribute` does not accept all possible
105-
# arguments of overload defined on line `104`.
105+
# arguments of overload defined on line `117`.
106106
def attribute(
107107
self,
108108
inputs: Union[Tensor, Tuple[Tensor, ...]],
109109
baselines: BaselineType = None,
110110
target: TargetType = None,
111111
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
112112
additional_forward_args: Any = None,
113-
# pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
114-
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
113+
*,
114+
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
115115
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
116-
return_convergence_delta: Literal[False] = False,
116+
return_convergence_delta: Literal[True],
117117
attribute_to_layer_input: bool = False,
118118
custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
119119
grad_kwargs: Optional[Dict[str, Any]] = None,
120-
) -> Union[Tensor, Tuple[Tensor, ...]]: ...
120+
) -> Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]: ...
121121

122122
@typing.overload
123123
# pyre-fixme[43]: The implementation of `attribute` does not accept all possible
124-
# arguments of overload defined on line `117`.
124+
# arguments of overload defined on line `104`.
125125
def attribute(
126126
self,
127127
inputs: Union[Tensor, Tuple[Tensor, ...]],
128128
baselines: BaselineType = None,
129129
target: TargetType = None,
130+
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
130131
additional_forward_args: Any = None,
131-
*,
132-
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
132+
# pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
133+
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
133134
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
134-
return_convergence_delta: Literal[True],
135+
return_convergence_delta: Literal[False] = False,
135136
attribute_to_layer_input: bool = False,
136137
custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
137138
grad_kwargs: Optional[Dict[str, Any]] = None,
138-
) -> Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]: ...
139+
) -> Union[Tensor, Tuple[Tensor, ...]]: ...
139140

140141
@log_usage()
141142
# pyre-fixme[43]: This definition does not have the same decorators as the
@@ -452,7 +453,7 @@ def __init__(
452453
# Ignoring mypy error for inconsistent signature with DeepLiftShap
453454
@typing.overload # type: ignore
454455
# pyre-fixme[43]: The implementation of `attribute` does not accept all possible
455-
# arguments of overload defined on line `439`.
456+
# arguments of overload defined on line `453`.
456457
def attribute(
457458
self,
458459
inputs: Union[Tensor, Tuple[Tensor, ...]],
@@ -462,32 +463,33 @@ def attribute(
462463
target: TargetType = None,
463464
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
464465
additional_forward_args: Any = None,
465-
# pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
466-
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
466+
*,
467+
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
467468
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
468-
return_convergence_delta: Literal[False] = False,
469+
return_convergence_delta: Literal[True],
469470
attribute_to_layer_input: bool = False,
470471
custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
471-
) -> Union[Tensor, Tuple[Tensor, ...]]: ...
472+
) -> Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]: ...
472473

473474
@typing.overload
474475
# pyre-fixme[43]: The implementation of `attribute` does not accept all possible
475-
# arguments of overload defined on line `453`.
476+
# arguments of overload defined on line `439`.
476477
def attribute(
477478
self,
478479
inputs: Union[Tensor, Tuple[Tensor, ...]],
479480
baselines: Union[
480481
Tensor, Tuple[Tensor, ...], Callable[..., Union[Tensor, Tuple[Tensor, ...]]]
481482
],
482483
target: TargetType = None,
484+
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
483485
additional_forward_args: Any = None,
484-
*,
485-
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
486+
# pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
487+
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
486488
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
487-
return_convergence_delta: Literal[True],
489+
return_convergence_delta: Literal[False] = False,
488490
attribute_to_layer_input: bool = False,
489491
custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
490-
) -> Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]: ...
492+
) -> Union[Tensor, Tuple[Tensor, ...]]: ...
491493

492494
@log_usage()
493495
# pyre-fixme[43]: This definition does not have the same decorators as the

0 commit comments

Comments (0)