diff --git a/.gitignore b/.gitignore index 5b1a35e..a836557 100644 --- a/.gitignore +++ b/.gitignore @@ -64,4 +64,10 @@ Thumbs.db html/ # Test report -report.xml \ No newline at end of file +report.xml + +# Logs / NCU Reps / PNGs / CSVs +*.log +*.ncu-rep +*.png +*.csv diff --git a/docs/source/overview/architecture.rst b/docs/source/overview/architecture.rst index 3e840a0..e01dbd5 100644 --- a/docs/source/overview/architecture.rst +++ b/docs/source/overview/architecture.rst @@ -21,12 +21,23 @@ Advanced Options ---------------- **Metric Selection** -Nsight Python collects `gpu__time_duration.sum` by default. To collect another NVIDIA Nsight Compute metric: +Nsight Python collects `gpu__time_duration.sum` by default. To collect other NVIDIA Nsight Compute metrics: .. code-block:: python - @nsight.analyze.kernel(metric="sm__throughput.avg.pct_of_peak_sustained_elapsed") - def benchmark(...): + @nsight.analyze.kernel(metrics=["sm__throughput.avg.pct_of_peak_sustained_elapsed"]) + def benchmark1(...): + ... + + # or + @nsight.analyze.kernel( + metrics=[ + "smsp__sass_inst_executed_op_shared_ld.sum", + "smsp__sass_inst_executed_op_shared_st.sum", + "launch__sm_count", + ], + ) + def benchmark2(...): ... **Derived Metrics** diff --git a/examples/01_compare_throughput.py b/examples/01_compare_throughput.py index 7d325bc..8807019 100644 --- a/examples/01_compare_throughput.py +++ b/examples/01_compare_throughput.py @@ -10,7 +10,7 @@ New concepts: - Multiple `nsight.annotate()` blocks to profile different kernels - Using `@nsight.annotate()` as a function decorator (alternative to context manager) -- Using the `metric` parameter to collect a specific Nsight Compute metric (DRAM throughput instead of execution time) +- Using the `metrics` parameter to collect a specific Nsight Compute metric (DRAM throughput instead of execution time) - Using `print_data=True` to print the collected dataframe to the terminal """ @@ -33,7 +33,7 @@ def einsum_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: @nsight.analyze.kernel( runs=10, # Collect DRAM throughput as percentage of peak instead of time - metric="dram__throughput.avg.pct_of_peak_sustained_elapsed", + metrics=["dram__throughput.avg.pct_of_peak_sustained_elapsed"], ) def benchmark_matmul_throughput(n: int) -> None: """ diff --git a/examples/02_parameter_sweep.py b/examples/02_parameter_sweep.py index fd89943..5f7e531 100644 --- a/examples/02_parameter_sweep.py +++ b/examples/02_parameter_sweep.py @@ -36,7 +36,9 @@ def benchmark_matmul_sizes(n: int) -> None: def main() -> None: - benchmark_matmul_sizes() # notice no n parameter is passed, it is passed in the configs list instead + # notice no n parameter is passed, it is passed in the configs list instead + result = benchmark_matmul_sizes() + print(result.to_dataframe()) print("✓ Benchmark complete! Check '02_parameter_sweep.png'") diff --git a/examples/03_custom_metrics.py b/examples/03_custom_metrics.py index 5fc6603..723043d 100644 --- a/examples/03_custom_metrics.py +++ b/examples/03_custom_metrics.py @@ -69,7 +69,8 @@ def benchmark_tflops(n: int) -> None: def main() -> None: - benchmark_tflops() + result = benchmark_tflops() + print(result.to_dataframe()) print("✓ TFLOPs benchmark complete! Check '03_custom_metrics.png'") diff --git a/examples/08_multiple_metrics.py b/examples/08_multiple_metrics.py new file mode 100644 index 0000000..2cd66b9 --- /dev/null +++ b/examples/08_multiple_metrics.py @@ -0,0 +1,80 @@ +# Copyright 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Example 8: Collecting Multiple Metrics +======================================= + +This example shows how to collect multiple metrics in a single profiling run. + +New concepts: +- Using the `metrics` parameter to collect multiple metrics +- `@nsight.analyze.plot` decorator does NOT support multiple metrics now +""" + +import torch + +import nsight + +sizes = [(2**i,) for i in range(11, 13)] + + +@nsight.analyze.kernel( + configs=sizes, + runs=5, + # Collect both shared memory load and store SASS instructions + metrics=[ + "smsp__sass_inst_executed_op_shared_ld.sum", + "smsp__sass_inst_executed_op_shared_st.sum", + ], +) +def analyze_shared_memory_ops(n: int) -> None: + """Analyze both shared memory load and store SASS instructions + for different kernels. + + Note: To evaluate multiple metrics, pass them as a sequence + (list/tuple). All results are merged into one ProfileResults + object, with the 'Metric' column indicating each specific metric. + """ + + a = torch.randn(n, n, device="cuda") + b = torch.randn(n, n, device="cuda") + c = torch.randn(2 * n, 2 * n, device="cuda") + d = torch.randn(2 * n, 2 * n, device="cuda") + + with nsight.annotate("@-operator"): + _ = a @ b + + with nsight.annotate("torch.matmul"): + _ = torch.matmul(c, d) + + +def main() -> None: + # Run analysis with multiple metrics + results = analyze_shared_memory_ops() + + df = results.to_dataframe() + print(df) + + unique_metrics = df["Metric"].unique() + print(f"\n✓ Collected {len(unique_metrics)} metrics:") + for metric in unique_metrics: + print(f" - {metric}") + + print("\n✓ Sample data:") + print(df[["Annotation", "n", "Metric", "AvgValue"]].to_string(index=False)) + + print("\n" + "=" * 60) + print("IMPORTANT: @plot decorator limitation") + print("=" * 60) + print("When multiple metrics are collected:") + print(" ✓ All metrics are collected in a single ProfileResults object") + print(" ✓ DataFrame has 'Metric' column to distinguish them") + print(" ✗ @nsight.analyze.plot decorator will RAISE AN ERROR") + print(" Why? @plot can only visualize one metric at a time.") + print(" Tip: Use separate @kernel functions for each metric or use") + print(" 'derive_metric' to compute custom values.") + + +if __name__ == "__main__": + main() diff --git a/examples/09_advanced_metric_custom.py b/examples/09_advanced_metric_custom.py new file mode 100644 index 0000000..2080609 --- /dev/null +++ b/examples/09_advanced_metric_custom.py @@ -0,0 +1,87 @@ +# Copyright 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Example 9: Advanced Custom Metrics from Multiple Metrics +========================================================= + +This example shows how to compute custom metrics from multiple metrics. + +New concepts: +- Using `derive_metric` to compute custom values from multiple metrics +""" + +import torch + +import nsight + +sizes = [(2**i,) for i in range(10, 13)] + + +def compute_avg_insts( + ld_insts: int, st_insts: int, launch_sm_count: int, n: int +) -> float: + """ + Compute average shared memory load/store instructions per SM. + + Custom metric function signature: + - First several arguments: the measured metrics, must match the order + of metrics in @kernel decorator + - Remaining arguments: must match the decorated function's signature + + In this example: + - ld_insts: Total shared memory load instructions + (from smsp__inst_executed_pipe_lsu.shared_op_ld.sum metric) + - st_insts: Total shared memory store instructions + (from smsp__inst_executed_pipe_lsu.shared_op_st.sum metric) + - launch_sm_count: Number of SMs that launched blocks + (from launch__block_sm_count metric) + - n: Matches the 'n' parameter from benchmark_avg_insts(n) + + Args: + ld_insts: Total shared memory load instructions + st_insts: Total shared memory store instructions + launch_sm_count: Number of SMs that launched blocks + n: Matrix size (n x n) - parameter from the decorated benchmark function + + Returns: + Average shared memory load/store instructions per SM + """ + insts_per_sm = (ld_insts + st_insts) / launch_sm_count + return insts_per_sm + + +@nsight.analyze.plot( + filename="09_advanced_metric_custom.png", + ylabel="Average Shared Memory Load/Store Instructions per SM", # Custom y-axis label + annotate_points=True, # Show values on the plot +) +@nsight.analyze.kernel( + configs=sizes, + runs=10, + derive_metric=compute_avg_insts, # Use custom metric + metrics=[ + "smsp__sass_inst_executed_op_shared_ld.sum", + "smsp__sass_inst_executed_op_shared_st.sum", + "launch__sm_count", + ], +) +def benchmark_avg_insts(n: int) -> None: + """ + Benchmark matmul and display results. + """ + a = torch.randn(n, n, device="cuda") + b = torch.randn(n, n, device="cuda") + + with nsight.annotate("matmul"): + _ = a @ b + + +def main() -> None: + result = benchmark_avg_insts() + print(result.to_dataframe()) + print("✓ Avg Insts benchmark complete! Check '09_advanced_metric_custom.png'") + + +if __name__ == "__main__": + main() diff --git a/examples/10_combine_kernel_metrics.py b/examples/10_combine_kernel_metrics.py new file mode 100644 index 0000000..929855d --- /dev/null +++ b/examples/10_combine_kernel_metrics.py @@ -0,0 +1,63 @@ +# Copyright 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Example 10: Multiple Kernels per Run with Combined Metrics +=========================================================== + +This example shows how to profile multiple kernels in a single run and combine their metrics. + +New concepts: +- Using `combine_kernel_metrics` to aggregate metrics from multiple kernels +- Summing metrics from consecutive kernel executions +""" + +import torch + +import nsight + +# Define configuration sizes +sizes = [(2**i,) for i in range(10, 13)] + + +@nsight.analyze.plot( + filename="10_combine_kernel_metrics.png", + ylabel="Total Cycles (Sum of 3 Kernels)", + annotate_points=True, +) +@nsight.analyze.kernel( + configs=sizes, + runs=7, + combine_kernel_metrics=lambda x, y: x + y, # Sum metrics from multiple kernels + metrics=[ + "sm__cycles_elapsed.avg", + ], +) +def benchmark_multiple_kernels(n: int) -> None: + """ + Benchmark three matrix multiplications in a single run. + + Executes three matmul operations within one profiled context, + demonstrating metric combination across kernels. + + Args: + n: Matrix size (n x n) + """ + a = torch.randn(n, n, device="cuda") + b = torch.randn(n, n, device="cuda") + + with nsight.annotate("test"): + # Three consecutive kernel executions + _ = a @ b # Kernel 1 + _ = a @ b # Kernel 2 + _ = a @ b # Kernel 3 + + +def main() -> None: + result = benchmark_multiple_kernels() + print(result.to_dataframe()) + print("\n✓ Total Cycles benchmark complete! Check '10_combine_kernel_metrics.png'") + + +if __name__ == "__main__": + main() diff --git a/examples/11_output_csv.py b/examples/11_output_csv.py new file mode 100644 index 0000000..5e983ce --- /dev/null +++ b/examples/11_output_csv.py @@ -0,0 +1,146 @@ +# Copyright 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Example 11: Controlling CSV Output Files +========================================= + +This example shows how to control CSV file generation. + +New concepts: +- Using `output_csv` parameter to enable/disable CSV file generation +- Using `output_prefix` to specify output file location and naming +""" + +import os + +import pandas as pd +import torch + +import nsight + +# Get current directory for output +current_dir = os.path.dirname(os.path.abspath(__file__)) +output_prefix = f"{current_dir}/example10_" + + +# Matrix sizes to benchmark +sizes = [(2**i,) for i in range(10, 13)] + + +@nsight.analyze.kernel( + configs=sizes, + runs=3, + output_prefix=output_prefix, + output_csv=True, # Enable CSV file generation + metrics=[ + "smsp__sass_inst_executed_op_shared_ld.sum", + "smsp__sass_inst_executed_op_shared_st.sum", + ], +) +def analyze_memory_ops_with_csv(n: int) -> None: + """ + Analyze memory operations with CSV output enabled. + + When output_csv=True, two CSV files are generated: + 1. {prefix}processed_data--.csv - Raw profiled data + 2. {prefix}profiled_data--.csv - Processed/aggregated data + + Args: + n: Matrix size (n x n) + """ + a = torch.randn(n, n, device="cuda") + b = torch.randn(n, n, device="cuda") + + with nsight.annotate("matmul-operator"): + _ = a @ b + + with nsight.annotate("torch-matmul"): + _ = torch.matmul(a, b) + + +def print_full_dataframe( + df: pd.DataFrame, max_rows: int = 20, max_col_width: int = 100 +) -> None: + """ + Print DataFrame without truncation. + + Args: + df: DataFrame to print + max_rows: Maximum number of rows to display (None for all rows) + max_col_width: Maximum column width (None for no limit) + """ + # Save current display options + original_options = { + "display.max_rows": pd.get_option("display.max_rows"), + "display.max_columns": pd.get_option("display.max_columns"), + "display.max_colwidth": pd.get_option("display.max_colwidth"), + "display.width": pd.get_option("display.width"), + "display.expand_frame_repr": pd.get_option("display.expand_frame_repr"), + } + + try: + # Set display options for full output + pd.set_option("display.max_rows", max_rows if max_rows else None) + pd.set_option("display.max_columns", None) + pd.set_option("display.max_colwidth", max_col_width if max_col_width else None) + pd.set_option("display.width", None) + pd.set_option("display.expand_frame_repr", False) + + print(df.to_string()) + + finally: + # Restore original options + for option, value in original_options.items(): + pd.set_option(option, value) + + +def read_and_display_csv_files() -> None: + """Read and display the generated CSV files.""" + + # Find CSV files + csv_files = [] + for file in os.listdir(current_dir): + if file.startswith("example10_") and file.endswith(".csv"): + csv_files.append(os.path.join(current_dir, file)) + + for file_path in sorted(csv_files): + file_name = os.path.basename(file_path) + print(f"\nFile: {file_name}") + print("-" * (len(file_name) + 6)) + + # Read CSV file + try: + df = pd.read_csv(file_path) + + # Display only columns related to metrics/values + value_cols = [ + col + for col in df.columns + if "Value" in col or "Metric" in col or "Annotation" in col + ] + # print(df[value_cols].head()) + # Show full DataFrame without truncation + print_full_dataframe(df[value_cols]) + except Exception as e: + print(f"Error reading {file_name}: {e}") + + +def main() -> None: + # Clean up any previous output files + for old_file in os.listdir(current_dir): + if old_file.startswith("example10_") and old_file.endswith( + (".csv", ".ncu-rep", ".log") + ): + os.remove(os.path.join(current_dir, old_file)) + + # Run the analysis with CSV output + result = analyze_memory_ops_with_csv() + print(result.to_dataframe()) + + # Read and display generated CSV files + read_and_display_csv_files() + + +if __name__ == "__main__": + main() diff --git a/examples/README.md b/examples/README.md index 2b0706d..7426411 100644 --- a/examples/README.md +++ b/examples/README.md @@ -5,6 +5,7 @@ This directory contains examples demonstrating how to use Nsight Python for prof ## Prerequisites ### Required + - **Python 3.10+** - **CUDA-capable GPU** - **NVIDIA Nsight Compute** (for profiling) @@ -14,6 +15,7 @@ This directory contains examples demonstrating how to use Nsight Python for prof The examples require additional packages beyond the base `nsight` package: #### PyTorch + Most examples use PyTorch for GPU operations: ```bash @@ -25,6 +27,7 @@ pip install torch --index-url https://download.pytorch.org/whl/cuXXX Visit [pytorch.org](https://pytorch.org/get-started/locally/) for installation commands matching your specific CUDA version. #### Triton (Optional) + For the Triton examples (`07_triton_minimal.py`): ```bash @@ -60,7 +63,7 @@ This will profile a simple matrix multiplication and generate a plot showing the - Visualizing performance across problem sizes - **`03_custom_metrics.py`** - Computing TFLOPs - - Using `derive_metric` to compute custom metrics + - Using `derive_metric` to compute custom metric - Understanding the metric function signature - Transforming time measurements into performance metrics @@ -85,3 +88,20 @@ This will profile a simple matrix multiplication and generate a plot showing the - Using `variant_fields` and `variant_annotations` - Comparing against PyTorch baselines with `normalize_against` - Showing speedup metrics + +- **`08_multiple_metrics.py`** - Collecting multiple metrics + - Collecting multiple metrics with using a sequence of metric names + - Merged results with `"Metric"` column in DataFrame + - `@plot` decorator incompatible with multiple metrics + +- **`09_advanced_metric_custom.py`** - Computing advanced custom metric + - Using `derive_metric` to compute custom metric from multiple metrics + +- **`10_multiple_kernels_combine.py`** - Combining metrics from multiple kernels + - Using `combine_kernel_metrics` to aggregate metrics from multiple kernels + - Summing metrics from consecutive kernel executions + +- **`11_output_csv.py`** - Outputting to CSV + - Using `output_csv` parameter to enable/disable CSV file generation + - Using `output_prefix` to specify output file location and naming + - Reading and displaying generated CSV files diff --git a/examples/test_examples.py b/examples/test_examples.py index b85ab93..990c470 100644 --- a/examples/test_examples.py +++ b/examples/test_examples.py @@ -45,3 +45,23 @@ def test_07_triton_minimal() -> None: pytest.importorskip("triton") triton_minimal = importlib.import_module("examples.07_triton_minimal") triton_minimal.main() + + +def test_08_multiple_metrics() -> None: + multiple_metrics = importlib.import_module("examples.08_multiple_metrics") + multiple_metrics.main() + + +def test_09_advanced_metric_custom() -> None: + advanced_custom = importlib.import_module("examples.09_advanced_metric_custom") + advanced_custom.main() + + +def test_10_combine_kernel_metrics() -> None: + combine_metrics = importlib.import_module("examples.10_combine_kernel_metrics") + combine_metrics.main() + + +def test_11_output_csv() -> None: + output_csv = importlib.import_module("examples.11_output_csv") + output_csv.main() diff --git a/nsight/analyze.py b/nsight/analyze.py index 42e379c..e8bbbc4 100644 --- a/nsight/analyze.py +++ b/nsight/analyze.py @@ -5,11 +5,13 @@ import functools import os import tempfile +import warnings from collections.abc import Callable, Iterable, Sequence from typing import Any, Literal, overload import matplotlib import matplotlib.figure +import numpy as np import nsight.collection as collection import nsight.visualization as visualization @@ -32,7 +34,7 @@ def kernel( derive_metric: Callable[..., float] | None = None, normalize_against: str | None = None, output: Literal["quiet", "progress", "verbose"] = "progress", - metric: str = "gpu__time_duration.sum", + metrics: Sequence[str] = ["gpu__time_duration.sum"], ignore_kernel_list: Sequence[str] | None = None, clock_control: Literal["base", "none"] = "none", cache_control: Literal["all", "none"] = "all", @@ -53,7 +55,7 @@ def kernel( derive_metric: Callable[..., float] | None = None, normalize_against: str | None = None, output: Literal["quiet", "progress", "verbose"] = "progress", - metric: str = "gpu__time_duration.sum", + metrics: Sequence[str] = ["gpu__time_duration.sum"], ignore_kernel_list: Sequence[str] | None = None, clock_control: Literal["base", "none"] = "none", cache_control: Literal["all", "none"] = "all", @@ -99,15 +101,22 @@ def wrapped_function(*args, configs=None, **kwargs) -> ProfileResults If the configs are not provided at decoration time, they must be provided when calling the decorated function. runs: Number of times each configuration should be executed. derive_metric: - A function to transform the collected metric. + A function to transform the collected metrics. This can be used to compute derived metrics like TFLOPs that cannot - be captured by ncu directly. The function takes the metric value and + be captured by ncu directly. The function takes the metric values and the arguments of the profile-decorated function and returns the new - metric. See the examples for concrete use cases. + metric. The parameter order requirements for the custom function: + + - First several arguments: Must exactly match the order of metrics declared in the @kernel decorator. These arguments will receive the actual measured values of those metrics. + - Remaining arguments: Must exactly match the signature of the decorated function. In other words, the original function's parameters are passed in order. + + See the examples for concrete use cases. normalize_against: Annotation name to normalize metrics against. This is useful to compute relative metrics like speedup. - metric: The metric to collect. By default, kernel runtimes in nanoseconds are collected. Default: ``"gpu__time_duration.sum"``. To see the available metrics on your system, use the command: ``ncu --query-metrics``. + metrics: The metrics to collect. By default, kernel runtimes in nanoseconds + are collected. Default: ``["gpu__time_duration.sum"]``. To see the available + metrics on your system, use the command: ``ncu --query-metrics``. ignore_kernel_list: List of kernel names to ignore. If you call a library within an annotated range context, you might not have precise control over which and how many kernels are being launched. If some of these kernels should be ignored in the profile, their names can be provided in this parameter. Default: ``None`` @@ -165,9 +174,9 @@ def wrapped_function(*args, configs=None, **kwargs) -> ProfileResults **Raw Data CSV** (``profiled_data--.csv``): Contains unprocessed profiling data with one row per run per configuration. Columns include: - ``Annotation``: Name of the annotated region being profiled - - ``Value``: Raw metric value collected by the profiler - - ``Metric``: The metric being collected (e.g., ``gpu__time_duration.sum``) - - ``Transformed``: Name of the function used to transform the metric (specified via ``derive_metric``), or ``False`` if no transformation was applied. For lambda functions, this shows ``""`` + - ``Value``: Raw metric values collected by the profiler + - ``Metric``: The metrics being collected (e.g., ``gpu__time_duration.sum``) + - ``Transformed``: Name of the function used to transform the metrics (specified via ``derive_metric``), or ``False`` if no transformation was applied. For lambda functions, this shows ``""`` - ``Kernel``: Name of the GPU kernel(s) launched - ``GPU``: GPU device name - ``Host``: Host machine name @@ -179,23 +188,25 @@ def wrapped_function(*args, configs=None, **kwargs) -> ProfileResults - ``Annotation``: Name of the annotated region being profiled - ````: One column for each parameter of the decorated function - - ``AvgValue``: Average metric value across all runs - - ``StdDev``: Standard deviation of the metric across runs - - ``MinValue``: Minimum metric value observed - - ``MaxValue``: Maximum metric value observed + - ``AvgValue``: Average metric values across all runs + - ``StdDev``: Standard deviation of the metrics across runs + - ``MinValue``: Minimum metric values observed + - ``MaxValue``: Maximum metric values observed - ``NumRuns``: Number of runs used for aggregation - ``CI95_Lower``: Lower bound of the 95% confidence interval - ``CI95_Upper``: Upper bound of the 95% confidence interval - ``RelativeStdDevPct``: Standard deviation as a percentage of the mean - ``StableMeasurement``: Boolean indicating if the measurement is stable (low variance). The measurement is stable if ``RelativeStdDevPct`` < 2 % . - - ``Metric``: The metric being collected - - ``Transformed``: Name of the function used to transform the metric (specified via ``derive_metric``), or ``False`` if no transformation was applied. For lambda functions, this shows ``""`` + - ``Metric``: The metrics being collected + - ``Transformed``: Name of the function used to transform the metrics (specified via ``derive_metric``), or ``False`` if no transformation was applied. For lambda functions, this shows ``""`` - ``Kernel``: Name of the GPU kernel(s) launched - ``GPU``: GPU device name - ``Host``: Host machine name - ``ComputeClock``: GPU compute clock frequency - ``MemoryClock``: GPU memory clock frequency """ + # Strip whitespace + metrics = [m.strip() for m in metrics] def _create_profiler() -> collection.core.NsightProfiler: """Helper to create the profiler with the given settings.""" @@ -228,7 +239,7 @@ def _create_profiler() -> collection.core.NsightProfiler: output_csv=output_csv, ) ncu = collection.ncu.NCUCollector( - metric=metric, + metrics=metrics, ignore_kernel_list=ignore_kernel_list, combine_kernel_metrics=combine_kernel_metrics, clock_control=clock_control, @@ -248,6 +259,27 @@ def _create_profiler() -> collection.core.NsightProfiler: return profiler(_func) # type: ignore[return-value] +def _validate_metric(result: collection.core.ProfileResults) -> None: + """ + Check if ProfileResults contains only a single metric. + + Args: + result: ProfileResults object + + Raises: + ValueError: If multiple metrics are detected + """ + df = result.to_dataframe() + + # Check for multiple metrics in "Metric" column + unique_metrics = df["Metric"].unique() + if len(unique_metrics) > 1: + raise ValueError( + f"Cannot visualize {len(unique_metrics)} > 1 metrics with the " + "@nsight.analyze.plot decorator." + ) + + def plot( filename: str = "plot.png", *, @@ -326,6 +358,9 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: result = func(*args, **kwargs) if "NSPY_NCU_PROFILE" not in os.environ: + # Check for multiple metrics or complex data structures + _validate_metric(result) + visualization.visualize( result.to_dataframe(), row_panels=row_panels, diff --git a/nsight/collection/core.py b/nsight/collection/core.py index 544dfc9..b3eca56 100644 --- a/nsight/collection/core.py +++ b/nsight/collection/core.py @@ -13,6 +13,7 @@ from collections.abc import Callable, Collection, Iterable, Sequence from typing import Any +import numpy as np import pandas as pd from nsight import annotation, exceptions, thermovision, transformation, utils @@ -276,11 +277,11 @@ class ProfileSettings: derive_metric: Callable[..., float] | None """ - A function to transform the collected metric. + A function to transform the collected metrics. This can be used to compute derived metrics like TFLOPs that cannot - be captured by ncu directly. The function takes the metric value and + be captured by ncu directly. The function takes the metric values and the arguments of the profile-decorated function and returns the new - metric. See the examples for concrete use cases. + metrics. See the examples for concrete use cases. """ normalize_against: str | None @@ -333,17 +334,17 @@ def to_dataframe(self) -> pd.DataFrame: - ``Annotation``: Name of the annotated region being profiled - ````: One column for each parameter of the decorated function - - ``AvgValue``: Average metric value across all runs - - ``StdDev``: Standard deviation of the metric across runs - - ``MinValue``: Minimum metric value observed - - ``MaxValue``: Maximum metric value observed + - ``AvgValue``: Average metric values across all runs + - ``StdDev``: Standard deviation of the metrics across runs + - ``MinValue``: Minimum metric values observed + - ``MaxValue``: Maximum metric values observed - ``NumRuns``: Number of runs used for aggregation - ``CI95_Lower``: Lower bound of the 95% confidence interval - ``CI95_Upper``: Upper bound of the 95% confidence interval - ``RelativeStdDevPct``: Standard deviation as a percentage of the mean - ``StableMeasurement``: Boolean indicating if the measurement is stable (low variance). The measurement is stable if ``RelativeStdDevPct`` < 2 % . - - ``Metric``: The metric being collected - - ``Transformed``: Name of the function used to transform the metric (specified via ``derive_metric``), or ``False`` if no transformation was applied. For lambda functions, this shows ``""`` + - ``Metric``: The metrics being collected + - ``Transformed``: Name of the function used to transform the metrics (specified via ``derive_metric``), or ``False`` if no transformation was applied. For lambda functions, this shows ``""`` - ``Kernel``: Name of the GPU kernel(s) launched - ``GPU``: GPU device name - ``Host``: Host machine name @@ -423,12 +424,12 @@ def wrapper( **kwargs, ) - # Check if the function has a return type - raw_df = self.collector.collect(func, configs, self.settings) + # Check if the function has a return type if raw_df is not None: - processed = transformation.aggregate_data( + + processed: pd.DataFrame = transformation.aggregate_data( raw_df, func, self.settings.normalize_against, diff --git a/nsight/collection/ncu.py b/nsight/collection/ncu.py index 85c96ce..65dcf4d 100644 --- a/nsight/collection/ncu.py +++ b/nsight/collection/ncu.py @@ -25,7 +25,7 @@ def launch_ncu( report_path: str, name: str, - metric: str, + metrics: Sequence[str], cache_control: Literal["none", "all"], clock_control: Literal["none", "base"], replay_mode: Literal["kernel", "range"], @@ -36,7 +36,7 @@ def launch_ncu( Args: report_path: Path to write report file to. - metric: Specific metric to collect. + metrics: Specific metrics to collect. cache_control: Select cache control option clock_control: Select clock control option replay_mode: Select replay mode option @@ -74,9 +74,10 @@ def launch_ncu( log_path = os.path.splitext(report_path)[0] + ".log" log = f"--log-file {log_path}" nvtx = f'--nvtx --nvtx-include "regex:{utils.NVTX_DOMAIN}@.+/"' + metrics_str = ",".join(metrics) # Construct the ncu command - ncu_command = f"""ncu {log} {cache} {clocks} {replay} {nvtx} --metrics {metric} -f -o {report_path} {sys.executable} {script_path} {script_args}""" + ncu_command = f"""ncu {log} {cache} {clocks} {replay} {nvtx} --metrics {metrics_str} -f -o {report_path} {sys.executable} {script_path} {script_args}""" # Check if ncu is available on the system ncu_available = False @@ -109,7 +110,7 @@ def launch_ncu( error_context = NCUErrorContext( errors=error_logs, log_file_path=log_path, - metric=metric, + metrics=metrics, ) error_message = utils.format_ncu_error_message(error_context) @@ -127,7 +128,7 @@ class NCUCollector(core.NsightCollector): NCU collector for Nsight Python. Args: - metric: Metric to collect from + metrics: Metrics to collect from NVIDIA Nsight Compute. By default we collect kernel runtimes in nanoseconds. A list of supported metrics can be found with ``ncu --list-metrics``. ignore_kernel_list: List of kernel names to ignore. @@ -161,7 +162,7 @@ class NCUCollector(core.NsightCollector): def __init__( self, - metric: str = "gpu__time_duration.sum", + metrics: Sequence[str] = ["gpu__time_duration.sum"], ignore_kernel_list: Sequence[str] | None = None, combine_kernel_metrics: Callable[[float, float], float] | None = None, clock_control: Literal["base", "none"] = "none", @@ -175,7 +176,7 @@ def __init__( if replay_mode not in ("kernel", "range"): raise ValueError("replay_mode must be 'kernel', or 'range'") - self.metric = metric + self.metrics = metrics self.ignore_kernel_list = ignore_kernel_list or [] self.combine_kernel_metrics = combine_kernel_metrics self.clock_control = clock_control @@ -210,7 +211,7 @@ def collect( log_path = launch_ncu( report_path, func.__name__, - self.metric, + self.metrics, self.cache_control, self.clock_control, self.replay_mode, @@ -226,10 +227,9 @@ def collect( f"[NSIGHT-PYTHON] Refer to {log_path} for the NVIDIA Nsight Compute CLI logs" ) - # Extract raw data df = extraction.extract_df_from_report( report_path, - self.metric, + self.metrics, configs, # type: ignore[arg-type] settings.runs, func, @@ -238,6 +238,7 @@ def collect( settings.output_progress, self.combine_kernel_metrics, ) + return df else: diff --git a/nsight/exceptions.py b/nsight/exceptions.py index c92a1f2..a2029e5 100644 --- a/nsight/exceptions.py +++ b/nsight/exceptions.py @@ -1,6 +1,7 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +from collections.abc import Sequence from dataclasses import dataclass from enum import Enum @@ -45,19 +46,21 @@ class NCUNotAvailableError(Exception): CUDA_CORE_UNAVAILABLE_MSG = "cuda-core is required for ignore_failures functionality.\n Install it with:\n - pip install nsight-python[cu12] (if you have CUDA 12.x)\n - pip install nsight-python[cu13] (if you have CUDA 13.x)" -def get_metric_error_message(metric: str, error_type: MetricErrorType) -> str: +def get_metrics_error_message( + metrics: Sequence[str], error_type: MetricErrorType +) -> str: """ Returns a formatted error message for invalid or unsupported metric names. Args: - metric: The invalid/unsupported metric name that was provided. + metrics: The invalid/unsupported metric names that was provided. error_type: The type of error (Invalid or Unsupported). Returns: str: User-friendly error message with guidance. """ return ( - f"{error_type.value} value '{metric}' for 'metric' parameter for nsight.analyze.kernel(). " + f"{error_type.value} value {metrics} for 'metrics' parameter for nsight.analyze.kernel()." f"\nPlease refer ncu --query-metrics for list of supported metrics." ) @@ -70,9 +73,9 @@ class NCUErrorContext: Attributes: errors: The error logs from NCU log_file_path: Path to the NCU log file - metric: The metric that was being collected + metrics: The metrics that was being collected """ errors: list[str] log_file_path: str - metric: str + metrics: Sequence[str] diff --git a/nsight/extraction.py b/nsight/extraction.py index d7b0fe3..113518f 100644 --- a/nsight/extraction.py +++ b/nsight/extraction.py @@ -8,57 +8,103 @@ and transform it into structured pandas DataFrames for further analysis. Functions: - extract_ncu_action_data(action, metric): + extract_ncu_action_data(action, metrics): Extracts performance data for a specific kernel action from an NVIDIA Nsight Compute report. - extract_df_from_report(metric, configs, iterations, func, derive_metric, ignore_kernel_list, verbose, combine_kernel_metrics=None): + extract_df_from_report(report_path, metrics, configs, iterations, func, derive_metric, ignore_kernel_list, output_progress, combine_kernel_metrics=None): Processes the full NVIDIA Nsight Compute report and returns a pandas DataFrame containing performance metrics. """ import functools import inspect import socket -from collections.abc import Callable +from collections.abc import Callable, Sequence from typing import Any, List, Tuple import ncu_report +import numpy as np import pandas as pd +from numpy.typing import NDArray from nsight import exceptions, utils from nsight.utils import is_scalar -def extract_ncu_action_data(action: Any, metric: str) -> utils.NCUActionData: +def extract_ncu_action_data(action: Any, metrics: Sequence[str]) -> utils.NCUActionData: """ Extracts performance data from an NVIDIA Nsight Compute kernel action. Args: action: The NVIDIA Nsight Compute action object. - metric: The metric name to extract from the action. + metrics: The metric names to extract from the action. Returns: - A data container with extracted metric, clock rates, and GPU name. + A data container with extracted metrics, clock rates, and GPU name. """ - if metric not in action.metric_names(): - error_message = exceptions.get_metric_error_message( - metric, error_type=exceptions.MetricErrorType.UNSUPPORTED - ) - raise exceptions.ProfilerException(error_message) + for metric in metrics: + if metric not in action.metric_names(): + error_message = exceptions.get_metrics_error_message( + metric, error_type=exceptions.MetricErrorType.INVALID + ) + raise exceptions.ProfilerException(error_message) + + # Extract values for all metrics. + failure = "dummy_kernel_failure" in action.name() + all_values = ( + None if failure else np.array([action[metric].value() for metric in metrics]) + ) return utils.NCUActionData( name=action.name(), - value=( - None if "dummy_kernel_failure" in action.name() else action[metric].value() - ), + values=all_values, compute_clock=action["device__attribute_clock_rate"].value(), memory_clock=action["device__attribute_memory_clock_rate"].value(), gpu=action["device__attribute_display_name"].value(), ) +def explode_dataframe(df: pd.DataFrame) -> pd.DataFrame: + """ + Explode columns with list/tuple/np.ndarray values into multiple rows. + Two scenarios: + + 1. No derived metrics (all "Transformed" = False): + - All columns maybe contain multiple values (lists/arrays). + - Use `explode()` to flatten each list element into separate rows. + + 2. With derived metrics: + - Metric columns contain either single-element lists or scalars. + - Only flatten single-element lists to scalars, don't create new rows. + + Args: + df: Dataframe to be exploded. + + Returns: + Exploded dataframe. + """ + df_explode = None + if df["Transformed"].eq(False).all(): + # 1: No derived metrics - explode all columns with sequences into rows. + df_explode = df.apply(pd.Series.explode).reset_index(drop=True) + else: + # 2: With derived metrics - only explode columns with single-value sequences. + df_explode = df.apply( + lambda col: ( + col.apply( + lambda x: ( + x[0] + if isinstance(x, (list, tuple, np.ndarray)) and len(x) == 1 + else x + ) + ) + ) + ) + return df_explode + + def extract_df_from_report( report_path: str, - metric: str, + metrics: Sequence[str], configs: List[Tuple[Any, ...]], iterations: int, func: Callable[..., Any], @@ -72,14 +118,14 @@ def extract_df_from_report( Args: report_path: Path to the report file. - metric: The NVIDIA Nsight Compute metric to extract. + metrics: The NVIDIA Nsight Compute metrics to extract. configs: Configuration settings used during profiling runs. iterations: Number of times each configuration was run. func: Function representing the kernel launch with parameter signature. - derive_metric: Function to transform the raw metric value with config values. + derive_metric: Function to transform the raw metric values with config values. ignore_kernel_list: Kernel names to ignore in the analysis. combine_kernel_metrics: Function to merge multiple kernel metrics. - verbose: Toggles the printing of extraction progress + verbose: Toggles the printing of extraction progress. Returns: A DataFrame containing the extracted and transformed performance data. @@ -91,7 +137,7 @@ def extract_df_from_report( if output_progress: print("[NSIGHT-PYTHON] Loading profiled data") try: - report = ncu_report.load_report(report_path) + report: ncu_report.IContext = ncu_report.load_report(report_path) except FileNotFoundError: raise exceptions.ProfilerException( "No NVIDIA Nsight Compute report found. Please run nsight-python with `@nsight.analyze.kernel(output='verbose')`" @@ -99,13 +145,13 @@ def extract_df_from_report( ) annotations: List[str] = [] - values: List[float | None] = [] + all_values: List[NDArray[Any] | None] = [] kernel_names: List[str] = [] gpus: List[str] = [] compute_clocks: List[int] = [] memory_clocks: List[int] = [] - metrics: List[str] = [] - transformed_metrics: List[str | bool] = [] + all_metrics: List[Tuple[str, ...]] = [] + all_transformed_metrics: List[str | bool] = [] hostnames: List[str] = [] sig = inspect.signature(func) @@ -118,13 +164,13 @@ def extract_df_from_report( print(f"Extracting profiling data") profiling_data: dict[str, list[utils.NCUActionData]] = {} for range_idx in range(report.num_ranges()): - current_range = report.range_by_idx(range_idx) + current_range: ncu_report.IRange = report.range_by_idx(range_idx) for action_idx in range(current_range.num_actions()): - action = current_range.action_by_idx(action_idx) - state = action.nvtx_state() + action: ncu_report.IAction = current_range.action_by_idx(action_idx) + state: ncu_report.INvtxState = action.nvtx_state() for domain_idx in state.domains(): - domain = state.domain_by_id(domain_idx) + domain: ncu_report.INvtxDomainInfo = state.domain_by_id(domain_idx) # ignore actions not in the nsight-python nvtx domain if domain.name() != utils.NVTX_DOMAIN: @@ -133,8 +179,8 @@ def extract_df_from_report( if ignore_kernel_list and action.name() in ignore_kernel_list: continue - annotation = domain.push_pop_ranges()[0] - data = extract_ncu_action_data(action, metric) + annotation: str = domain.push_pop_ranges()[0] + data = extract_ncu_action_data(action, metrics) if annotation not in profiling_data: profiling_data[annotation] = [] @@ -196,21 +242,23 @@ def extract_df_from_report( gpus.append(data.gpu) kernel_names.append(data.name) - # evaluate the measured metric - value = data.value + # evaluate the measured metrics + values = data.values if derive_metric is not None: - derived_metric = None if value is None else derive_metric(value, *conf) - value = derived_metric + derived_metric: float | int | None = ( + None if values is None else derive_metric(*values, *conf) + ) + values = derived_metric # type: ignore[assignment] derive_metric_name = derive_metric.__name__ - transformed_metrics.append(derive_metric_name) + all_transformed_metrics.append(derive_metric_name) else: - transformed_metrics.append(False) + all_transformed_metrics.append(False) - values.append(value) + all_values.append(values) # gather remaining required data annotations.append(annotation) - metrics.append(metric) + all_metrics.append(tuple(metrics)) hostnames.append(socket.gethostname()) # Add a field for every config argument bound_args = sig.bind(*conf) @@ -220,9 +268,9 @@ def extract_df_from_report( # Create the DataFrame with the initial columns df_data = { "Annotation": annotations, - "Value": values, - "Metric": metrics, - "Transformed": transformed_metrics, + "Value": all_values, + "Metric": all_metrics, + "Transformed": all_transformed_metrics, "Kernel": kernel_names, "GPU": gpus, "Host": hostnames, @@ -234,4 +282,7 @@ def extract_df_from_report( for arg_name, arg_values in arg_arrays.items(): df_data[arg_name] = arg_values - return pd.DataFrame(df_data) + # Explode the dataframe + df = explode_dataframe(pd.DataFrame(df_data)) + + return df diff --git a/nsight/transformation.py b/nsight/transformation.py index bcc50f5..2c4a9bc 100644 --- a/nsight/transformation.py +++ b/nsight/transformation.py @@ -44,13 +44,13 @@ def aggregate_data( # Note: When num_args=0, we need an empty list (not all columns via [-0:]) func_fields = df.columns[-num_args:].tolist() if num_args > 0 else [] - # Function to convert non-sortable columns to strings + # Function to convert non-sortable columns to tuples or strings def convert_non_sortable_columns(dframe: pd.DataFrame) -> pd.DataFrame: for col in dframe.columns: - # Try sorting the column to check if it's sortable try: + # Try sorting the column to check if it's sortable. sorted(dframe[col].dropna()) - except TypeError: + except (TypeError, ValueError): # If sorting fails, convert the column to string dframe[col] = dframe[col].astype(str) return dframe @@ -75,11 +75,14 @@ def convert_non_sortable_columns(dframe: pd.DataFrame) -> pd.DataFrame: ), # Use min to preserve first occurrence } + # The columns to aggregate except for the function parameters + groupby_columns = ["Annotation", "Metric", "Transformed"] + # Add assertion-based unique selection for remaining fields remaining_fields = [ col for col in df.columns - if col not in ["Value", "Annotation", "_original_order"] + func_fields + if col not in [*groupby_columns, "Value", "_original_order"] + func_fields ] for col in remaining_fields: @@ -102,7 +105,8 @@ def convert_non_sortable_columns(dframe: pd.DataFrame) -> pd.DataFrame: ) # Apply aggregation with named aggregation - agg_df = df.groupby(["Annotation"] + func_fields).agg(**named_aggs).reset_index() + groupby_df = df.groupby(groupby_columns + func_fields) + agg_df = groupby_df.agg(**named_aggs).reset_index() # Compute 95% confidence intervals agg_df["CI95_Lower"] = agg_df["AvgValue"] - 1.96 * ( @@ -127,21 +131,23 @@ def convert_non_sortable_columns(dframe: pd.DataFrame) -> pd.DataFrame: do_normalize = normalize_against is not None if do_normalize: - assert ( normalize_against in agg_df["Annotation"].values ), f"Annotation '{normalize_against}' not found in data." + # Columns of normalization dataframe to merge on + merge_on = func_fields + ["Metric", "Transformed"] + # Create a DataFrame to hold the normalization values normalization_df = agg_df[agg_df["Annotation"] == normalize_against][ - func_fields + ["AvgValue"] + merge_on + ["AvgValue"] ] normalization_df = normalization_df.rename( columns={"AvgValue": "NormalizationValue"} ) # Merge with the original DataFrame to apply normalization - agg_df = pd.merge(agg_df, normalization_df, on=func_fields) + agg_df = pd.merge(agg_df, normalization_df, on=merge_on) # Normalize the AvgValue by the values of the normalization annotation agg_df["AvgValue"] = agg_df["NormalizationValue"] / agg_df["AvgValue"] @@ -151,18 +157,13 @@ def convert_non_sortable_columns(dframe: pd.DataFrame) -> pd.DataFrame: agg_df["Metric"].astype(str) + f" relative to {normalize_against}" ) - # Calculate geometric mean for each annotation - geomean_values = {} - for annotation in agg_df["Annotation"].unique(): - annotation_data = agg_df[agg_df["Annotation"] == annotation] - valid_values = annotation_data["AvgValue"].dropna() - if not valid_values.empty: - geomean = np.exp(np.mean(np.log(valid_values))) - geomean_values[annotation] = geomean - else: - geomean_values[annotation] = np.nan - - # Add geomean values to the DataFrame - agg_df["Geomean"] = agg_df["Annotation"].map(geomean_values) + # Calculate the geometric mean of the AvgValue column + agg_df["Geomean"] = agg_df.groupby(groupby_columns)["AvgValue"].transform( + lambda x: ( + np.exp(np.mean(np.log(pd.to_numeric(x.dropna(), errors="coerce")))) + if not x.dropna().empty + else np.nan + ) + ) return agg_df diff --git a/nsight/utils.py b/nsight/utils.py index 0ac56c6..cc17e3b 100644 --- a/nsight/utils.py +++ b/nsight/utils.py @@ -12,11 +12,14 @@ from itertools import islice from typing import Any, Iterator +import numpy as np +from numpy.typing import NDArray + from nsight.exceptions import ( CUDA_CORE_UNAVAILABLE_MSG, MetricErrorType, NCUErrorContext, - get_metric_error_message, + get_metrics_error_message, ) # Try to import cuda-core (optional dependency) @@ -131,7 +134,7 @@ def print_header(*lines: str) -> None: @dataclass class NCUActionData: name: str - value: Any + values: NDArray[Any] | None compute_clock: int memory_clock: int gpu: str @@ -149,7 +152,7 @@ def _combine(lhs: "NCUActionData", rhs: "NCUActionData") -> "NCUActionData": assert lhs.gpu == rhs.gpu return NCUActionData( name=f"{lhs.name}|{rhs.name}", - value=value_reduce_op(lhs.value, rhs.value), + values=value_reduce_op(lhs.values, rhs.values), compute_clock=lhs.compute_clock, memory_clock=lhs.memory_clock, gpu=lhs.gpu, @@ -318,7 +321,9 @@ def format_ncu_error_message(context: NCUErrorContext) -> str: if context.errors and INVALID_METRIC_ERROR_HINT in context.errors[0]: message_parts.append( - get_metric_error_message(context.metric, error_type=MetricErrorType.INVALID) + get_metrics_error_message( + context.metrics, error_type=MetricErrorType.INVALID + ) ) else: message_parts.append("\n".join(f"- {error}" for error in context.errors)) diff --git a/nsight/visualization.py b/nsight/visualization.py index 2ec88e6..79857ef 100644 --- a/nsight/visualization.py +++ b/nsight/visualization.py @@ -81,7 +81,7 @@ def visualize( # Build Configuration field excluding variant_fields annotation_idx = agg_df.columns.get_loc("AvgValue") - func_fields = list(agg_df.columns[1:annotation_idx]) + func_fields = list(agg_df.columns[3:annotation_idx]) subplot_fields = row_panels + col_panels # type: ignore[operator] non_panel_fields = [ field @@ -202,7 +202,7 @@ def visualize( config_fields = x_keys else: annotation_idx = local_df.columns.get_loc("AvgValue") - func_fields = list(local_df.columns[1:annotation_idx]) + func_fields = list(local_df.columns[3:annotation_idx]) subplot_fields = row_panels + col_panels # type: ignore[operator] config_exclude = set(variant_fields or []) config_fields = [ diff --git a/tests/test_api_params.py b/tests/test_api_params.py index 8b16217..0873080 100644 --- a/tests/test_api_params.py +++ b/tests/test_api_params.py @@ -20,7 +20,7 @@ def get_app_args() -> argparse.Namespace: # nsight.analyze.kernel() parameters # TBD no command line arguments yet for: configs, derive_metric, ignore_kernel_list, combine_kernel_metrics parser.add_argument( - "--metric", "-m", default="dram__bytes.sum.per_second", help="Metric name" + "--metrics", "-m", default=["dram__bytes.sum.per_second"], help="Metric name" ) parser.add_argument("--runs", "-r", type=int, default=10, help="Number of runs") parser.add_argument("--replay-mode", "-p", default="kernel", help="Replay mode") @@ -93,7 +93,7 @@ def einsum(a: torch.Tensor, b: torch.Tensor) -> Any: @nsight.analyze.kernel( configs=sizes, runs=args.runs, - metric=args.metric, + metrics=args.metrics, replay_mode=args.replay_mode, normalize_against=args.normalize_against, clock_control=args.clock_control, diff --git a/tests/test_collection.py b/tests/test_collection.py index ef34bd3..d816eb2 100644 --- a/tests/test_collection.py +++ b/tests/test_collection.py @@ -19,7 +19,7 @@ def test_launch_ncu_runs_with_ncu_available(mock_run: MagicMock) -> None: collection.ncu.launch_ncu( "report.ncu-rep", "func_name", - metric="sm__cycles_elapsed.avg", + metrics=["sm__cycles_elapsed.avg"], cache_control="all", clock_control="base", replay_mode="kernel", @@ -58,7 +58,7 @@ def test_launch_ncu_falls_back_without_ncu(mock_run: MagicMock) -> None: collection.ncu.launch_ncu( "report.ncu-rep", "func_name", - metric="metric", + metrics=["metric"], cache_control="all", clock_control="base", replay_mode="kernel", diff --git a/tests/test_profiler.py b/tests/test_profiler.py index 9edec4c..db0bf8f 100644 --- a/tests/test_profiler.py +++ b/tests/test_profiler.py @@ -7,8 +7,7 @@ import os import shutil -import tempfile -from collections.abc import Generator +from collections.abc import Generator, Sequence from typing import Any, Literal import pytest @@ -278,7 +277,7 @@ def with_args(size: int) -> None: def test_no_args_function_with_derive_metric() -> None: """Test that derive_metric works with functions that have no arguments.""" - # Define a derive_metric function that only takes the metric value + # Define a derive_metric function that only takes the metric values # (no config parameters since the function has no args) def custom_metric(time_ns: float) -> float: """Convert time to arbitrary custom metric.""" @@ -540,6 +539,69 @@ def test_parameter_normalize_against() -> None: assert (df.loc[df["Annotation"] == "annotation1", "AvgValue"] == 1).all() +@nsight.analyze.kernel( + configs=( + (1,), + (2,), + (3,), + ), + runs=3, + normalize_against="annotation1", + # Some parameters that have a numerical determinism greater than + # 1 and grow with substantial increases in n. + metrics=[ + "smsp__inst_executed.sum", + "smsp__inst_issued.sum", + "smsp__sass_inst_executed_op_global_ld.sum", + "smsp__sass_inst_executed_op_global_st.sum", + ], +) +def normalize_against_multiple_metrics(n: int) -> None: + a = torch.randn(n, n, device="cuda") + b = torch.randn(n, n, device="cuda") + + c = torch.randn(100 * n, 100 * n, device="cuda") + d = torch.randn(100 * n, 100 * n, device="cuda") + + with nsight.annotate("annotation1"): + _ = a + b + + with nsight.annotate("annotation2"): + _ = c + d + + +@pytest.mark.xfail( # type: ignore[untyped-decorator] + reason="Waiting for proper support for standard normalization and speedup computation" +) +def test_parameter_normalize_against_multiple_metrics() -> None: + profile_output = normalize_against_multiple_metrics() + if profile_output is not None: + df = profile_output.to_dataframe() + + requested_metrics = [ + "smsp__inst_executed.sum", + "smsp__inst_issued.sum", + "smsp__sass_inst_executed_op_global_ld.sum", + "smsp__sass_inst_executed_op_global_st.sum", + ] + + for annotation in ["annotation1", "annotation2"]: + for n in [1, 2, 3]: + subset = df[(df["Annotation"] == annotation) & (df["n"] == n)] + assert len(subset) == len(requested_metrics) + + actual_metrics = subset["Metric"].tolist() + expected_metrics = [ + m + " relative to annotation1" for m in requested_metrics + ] + assert all(metric in actual_metrics for metric in expected_metrics) + + # AvgValue for the annotation being used as normalization factor should be 1 + assert (df.loc[df["Annotation"] == "annotation1", "AvgValue"] == 1).all() + # Validate that the AvgValue for the annotation being used for normalization is greater than 1 + assert (df.loc[df["Annotation"] == "annotation2", "AvgValue"] > 1).all() + + # ============================================================================ # Output prefix tests # ============================================================================ @@ -937,34 +999,71 @@ def profiled_func(x: int, y: int) -> None: # ============================================================================ -# metric parameter test +# metrics parameter test # ============================================================================ -@pytest.mark.parametrize("metric", ["invalid_value", "sm__warps_launched.sum"]) # type: ignore[untyped-decorator] -def test_parameter_metric(metric: str) -> None: +@pytest.mark.parametrize( # type: ignore[untyped-decorator] + "metrics, expected_result", + [ + pytest.param( + [ + "invalid_value", + ], + "invalid_single", + id="invalid_single", + ), + pytest.param( + [ + "sm__warps_launched.sum", + ], + "valid_single", + id="valid_single", + ), + pytest.param( + [ + "smsp__inst_executed.sum", + "smsp__inst_issued.sum", + ], + "invalid_multiple", + id="invalid_multiple", + ), + ], +) +def test_parameter_metric(metrics: Sequence[str], expected_result: str) -> None: - @nsight.analyze.kernel(configs=[(100, 100), (200, 200)], runs=2, metric=metric) + @nsight.analyze.plot(filename="plot.png", ylabel="Instructions") + @nsight.analyze.kernel(configs=[(100, 100), (200, 200)], runs=2, metrics=metrics) def profiled_func(x: int, y: int) -> None: _simple_kernel_impl(x, y, "test_parameter_metric") # Run profiling - if metric == "invalid_value": + if expected_result == "invalid_single": with pytest.raises( exceptions.ProfilerException, - match=f"Invalid value '{metric}' for 'metric' parameter for nsight.analyze.kernel()", + match=( + rf"Invalid value \['{metrics[0]}'\] for 'metrics' parameter for nsight.analyze.kernel()" + ), ): profiled_func() - else: + elif expected_result == "valid_single": profile_output = profiled_func() df = profile_output.to_dataframe() # Checking if the dataframe has the right metric name assert ( - df["Metric"] == metric - ).all(), f"Invalid metric name {df.loc[df['Metric'] != metric, 'Metric'].iloc[0]} found in output dataframe" + df["Metric"] == metrics[0] + ).all(), f"Invalid metric name {df.loc[df['Metric'] != metrics[0], 'Metric'].iloc[0]} found in output dataframe" # Checking if the metric values are valid assert ( df["AvgValue"].notna() & df["AvgValue"] > 0 - ).all(), f"Invalid AvgValue for metric {metric}" + ).all(), f"Invalid AvgValue for metric {metrics}" + elif expected_result == "invalid_multiple": + with pytest.raises( + ValueError, + match=( + f"Cannot visualize {len(metrics)} > 1 metrics with the @nsight.analyze.plot decorator." + ), + ): + profiled_func()