Skip to content

Commit f08074f

Browse files
committed
fix lint
1 parent 5239b40 commit f08074f

File tree

3 files changed

+7
-4
lines changed

3 files changed

+7
-4
lines changed

nsight/extraction.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import ncu_report
2525
import numpy as np
2626
import pandas as pd
27+
from numpy.typing import NDArray
2728

2829
from nsight import exceptions, utils
2930
from nsight.utils import is_scalar
@@ -105,7 +106,7 @@ def extract_df_from_report(
105106
)
106107

107108
annotations: List[str] = []
108-
all_values: List[np.ndarray | None] = []
109+
all_values: List[NDArray[Any] | None] = []
109110
kernel_names: List[str] = []
110111
gpus: List[str] = []
111112
compute_clocks: List[int] = []

nsight/transformation.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,10 @@
1414

1515
import numpy as np
1616
import pandas as pd
17+
from numpy.typing import NDArray
1718

1819

19-
def _value_aggregator(agg_func_name: str) -> Callable[[pd.Series], np.ndarray]:
20+
def _value_aggregator(agg_func_name: str) -> Callable[[pd.Series], NDArray[Any]]:
2021
"""Factory function to create value aggregators.
2122
2223
Args:
@@ -44,7 +45,7 @@ def _value_aggregator(agg_func_name: str) -> Callable[[pd.Series], np.ndarray]:
4445

4546
numpy_agg_func = AGG_FUNCTIONS[agg_func_name]
4647

47-
def aggregator(series: pd.Series) -> np.ndarray:
48+
def aggregator(series: pd.Series) -> NDArray[Any]:
4849
# Convert None to np.nan
4950
cleaned_series = series.apply(lambda x: np.nan if x is None else x)
5051
# Convert to numpy array, handling tuples

nsight/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from typing import Any, Iterator
1414

1515
import numpy as np
16+
from numpy.typing import NDArray
1617

1718
from nsight.exceptions import (
1819
CUDA_CORE_UNAVAILABLE_MSG,
@@ -133,7 +134,7 @@ def print_header(*lines: str) -> None:
133134
@dataclass
134135
class NCUActionData:
135136
name: str
136-
values: np.ndarray | None
137+
values: NDArray[Any] | None
137138
compute_clock: int
138139
memory_clock: int
139140
gpu: str

0 commit comments

Comments
 (0)