Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 15 additions & 8 deletions benchmarks/rocm_benchmarks/bench_fa2_prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
python benchmarks/rocm_benchmarks/bench_fa2_prefill.py --counters occupancy
python benchmarks/rocm_benchmarks/bench_fa2_prefill.py --counters stall
python benchmarks/rocm_benchmarks/bench_fa2_prefill.py --counters compute
python benchmarks/rocm_benchmarks/bench_fa2_prefill.py --counters lds_stall

# Override the output file label prefix:
python benchmarks/rocm_benchmarks/bench_fa2_prefill.py --counters occupancy --label fa2_occ
Expand All @@ -47,14 +48,20 @@
benchmarks/rocm_benchmarks/<label>_roofline.png (roofline preset only)

Counter presets available out of the box:
roofline — FetchSize, WriteSize, MFMA ops, TCC DRAM requests (default)
compute — MFMA ops and cycle counters
memory — L2 and DRAM bandwidth breakdown
basic — minimal: FetchSize / WriteSize only
occupancy — SQ_WAVES, SQ_BUSY_CYCLES, SQ_VALU_MFMA_BUSY_CYCLES,
SQ_WAIT_INST_ANY, SQ_INSTS_LDS
stall — SQ_INSTS_MFMA, SQ_WAIT_INST_VMEM, SQ_VALU_MFMA_BUSY_CYCLES,
SQ_WAIT_INST_LDS, SQ_BUSY_CYCLES
roofline — FetchSize, WriteSize, MFMA ops, TCC DRAM requests (default)
compute — MFMA ops and cycle counters
memory — L2 and DRAM bandwidth breakdown
basic — minimal: FetchSize / WriteSize only
occupancy — SQ_WAVES, SQ_BUSY_CYCLES, SQ_VALU_MFMA_BUSY_CYCLES,
SQ_WAIT_INST_ANY, SQ_INSTS_LDS
stall — SQ_INSTS_MFMA, SQ_WAIT_INST_VMEM, SQ_VALU_MFMA_BUSY_CYCLES,
SQ_WAIT_INST_LDS, SQ_BUSY_CYCLES
lds_stall — focused 3-pass LDS/VMEM stall analysis (recommended for
diagnosing prefill kernel stalls); automatically prints:
ALUStalledByLDS(%) = SQ_WAIT_INST_LDS / SQ_BUSY_CYCLES * 100
ALUStalledByVMEM(%) = SQ_WAIT_INST_VMEM / SQ_BUSY_CYCLES * 100
VMEM avg latency = SQ_ACCUM_PREV_HIRES / SQ_INSTS_VMEM
LDS wait/inst = SQ_WAIT_INST_LDS / SQ_INSTS_LDS

Or pass a path to a YAML file in rocprofv3 native job format.

Expand Down
155 changes: 155 additions & 0 deletions rocm_profiler/rocm_profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,27 @@
- pmc: [SQ_INSTS_MFMA, SQ_WAIT_INST_VMEM, SQ_VALU_MFMA_BUSY_CYCLES, FetchSize]
# Pass 2: LDS stall cycles + total busy cycles + L2 writes
- pmc: [SQ_WAIT_INST_LDS, SQ_BUSY_CYCLES, WriteSize]
""",
# Three-pass focused LDS/VMEM stall analysis — captures all counters needed
# to derive ALUStalledByLDS, ALUStalledByVMEM (AMD standard derived metrics),
# average VMEM latency via the level-counter method, and LDS wait per instruction.
#
# Derived metrics (printed automatically by _print_stall_table):
# ALUStalledByLDS (%) = SQ_WAIT_INST_LDS / SQ_BUSY_CYCLES * 100
# ALUStalledByVMEM (%) = SQ_WAIT_INST_VMEM / SQ_BUSY_CYCLES * 100
# VMEM avg latency (cycles) = SQ_ACCUM_PREV_HIRES / SQ_INSTS_VMEM
# (level-counter method: SQ_INST_LEVEL_VMEM must precede SQ_ACCUM_PREV_HIRES)
# LDS wait / LDS inst = SQ_WAIT_INST_LDS / SQ_INSTS_LDS
#
# FetchSize and WriteSize are in separate passes (gfx942 constraint, error 38).
"lds_stall": """\
jobs:
# Pass 1: LDS wait + VMEM wait + total busy cycles + LDS instruction count
- pmc: [SQ_WAIT_INST_LDS, SQ_WAIT_INST_VMEM, SQ_BUSY_CYCLES, SQ_INSTS_LDS]
# Pass 2: VMEM instruction count + LDS bank conflicts + L2 reads
- pmc: [SQ_INSTS_VMEM, SQ_LDS_BANK_CONFLICT, FetchSize]
# Pass 3: VMEM average latency via level counter (HIRES must follow level counter)
- pmc: [SQ_INST_LEVEL_VMEM, SQ_ACCUM_PREV_HIRES, WriteSize]
""",
# -----------------------------------------------------------------------
# Template for custom tuning presets — copy, rename, and modify.
Expand Down Expand Up @@ -415,6 +436,9 @@ def run(self) -> None:
if counter_csv is not None and self.roofline and not args.skip_roofline:
self._plot_roofline(timing_csv, counter_csv)

if counter_csv is not None and not args.skip_roofline:
Comment thread
diptorupd marked this conversation as resolved.
Outdated
self._maybe_print_stall_table(timing_csv, counter_csv)

# ── Argparse ──────────────────────────────────────────────────────────────

def _parse_args(self) -> argparse.Namespace:
Expand Down Expand Up @@ -912,6 +936,49 @@ def _write_counter_file(self) -> Path:
raise FileNotFoundError(f"Counter file not found: {p}")
return p

# ── Stall analysis ────────────────────────────────────────────────────────

def _maybe_print_stall_table(self, timing_csv: Path, counter_csv: Path) -> None:
    """Print the LDS/VMEM stall table when stall PMCs are present in the CSV.

    The ``SQ_WAIT_INST_LDS`` column acts as the sentinel: if the counter CSV
    has it (compared case-insensitively), the timing CSV is loaded and both
    sets of rows are handed to ``_print_stall_table``.  This fires for the
    ``stall`` and ``lds_stall`` presets, and for any custom YAML that
    collects that counter, without the caller having to set a flag.

    This is a best-effort diagnostic.  An unreadable, undecodable or
    malformed CSV is reported on stderr and the method returns without
    raising; an empty counter CSV or a missing sentinel column is skipped
    silently.

    Args:
        timing_csv: Path of the timing CSV to read.
        counter_csv: Path of the merged counter CSV to read.
    """
    import csv as _csv

    # Failures that mean "this CSV cannot be used".  UnicodeDecodeError and
    # csv.Error are not OSError subclasses, so they are listed explicitly.
    unreadable = (OSError, UnicodeDecodeError, _csv.Error)

    # Read counter CSV and check for the sentinel column.
    try:
        with open(counter_csv, newline="", encoding="utf-8") as f:
            counter_rows = list(_csv.DictReader(f))
    except unreadable as exc:
        print(f"[rocm_profiler] Cannot read counter CSV: {exc}", file=sys.stderr)
        return

    if not counter_rows:
        return

    # Case-insensitive sentinel check — column names vary by rocprofv3 version.
    # csv.DictReader files overflow cells under the key ``None`` when a row is
    # longer than the header, so non-string keys are ignored here.
    col_names_lower = {k.lower() for k in counter_rows[0] if isinstance(k, str)}
    if "sq_wait_inst_lds" not in col_names_lower:
        return  # No stall counters present; skip silently.

    # Read timing CSV.
    try:
        with open(timing_csv, newline="", encoding="utf-8") as f:
            timing_rows = list(_csv.DictReader(f))
    except unreadable as exc:
        print(f"[rocm_profiler] Cannot read timing CSV: {exc}", file=sys.stderr)
        return

    _print_stall_table(
        timing_rows,
        counter_rows,
        self.label,
        name_to_label={cfg.name: cfg.label for cfg in self.configs},
    )


# ---------------------------------------------------------------------------
# Helper functions
Expand Down Expand Up @@ -1145,6 +1212,94 @@ def _print_metrics_table(
print()


def _print_stall_table(
    timing_rows: list[dict],
    counter_rows: list[dict],
    label: str,
    name_to_label: dict[str, str] | None = None,
) -> None:
    """Print per-config LDS/VMEM stall analysis derived from ``lds_stall`` (or
    ``stall``) preset counters.

    One table row is written to stdout for every timing row that has a
    matching counter row; timing rows without a match produce a warning on
    stderr and are skipped.

    Derived metrics printed:

    * ``ALUStalledByLDS (%)``  = SQ_WAIT_INST_LDS / SQ_BUSY_CYCLES * 100
    * ``ALUStalledByVMEM (%)`` = SQ_WAIT_INST_VMEM / SQ_BUSY_CYCLES * 100
      (both show ``nan`` when SQ_BUSY_CYCLES is not positive)
    * ``VMEM avg lat (cyc)``   = SQ_ACCUM_PREV_HIRES / SQ_INSTS_VMEM
      (shows ``N/A`` unless both counters are positive)
    * ``LDS wait/inst``        = SQ_WAIT_INST_LDS / SQ_INSTS_LDS
      (shows ``N/A`` unless SQ_INSTS_LDS is positive)
    * ``LDS bank cf``          = SQ_LDS_BANK_CONFLICT (raw count; shows
      ``N/A`` when the counter column is absent or blank)

    Args:
        timing_rows: Rows of the timing CSV; the ``name`` and ``median_ms``
            columns are read.
        counter_rows: Rows of the counter CSV; the ``config_name`` column and
            the ``SQ_*`` counter columns are read.  Counter columns are
            matched case-insensitively.
        label: Benchmark label used in the table title.
        name_to_label: Optional mapping from config name to the display label
            printed in the first column; the config name itself is printed
            when there is no entry.
    """
    if name_to_label is None:
        name_to_label = {}

    missing = object()  # Marks "column not in this row" (distinct from None/blank).

    def lookup(row: dict, counter: str):
        """Return the raw cell for *counter*, or ``missing`` if no column matches."""
        if counter in row:
            return row[counter]
        wanted = counter.lower()
        for key, value in row.items():
            # csv.DictReader can produce a ``None`` key for overflow cells.
            if isinstance(key, str) and key.lower() == wanted:
                return value
        return missing

    def number(row: dict, counter: str) -> float:
        """Return *counter* as a float, treating an absent column as 0."""
        value = lookup(row, counter)
        return _float(0 if value is missing else value)

    # Counter CSV is keyed by "config_name" (written by _merge_pass_csvs).
    # Timing CSV is keyed by "name".  Both use the KernelConfig.name value.
    # ``or ""`` guards against a None cell, which has no .strip().
    counter_by_name: dict[str, dict] = {
        (r.get("config_name") or "").strip(): r for r in counter_rows
    }
    timing_by_name: dict[str, dict] = {r.get("name", ""): r for r in timing_rows}

    title = f"=== {label} — LDS/VMEM Stall Analysis ==="
    print(f"\n{title}\n")

    header = (
        f"{'Config':>30} | {'ms':>7} | {'AluStLDS%':>9} | {'AluStVMEM%':>10} | "
        f"{'VMEM_lat_cyc':>12} | {'LDS_w/inst':>10} | {'LDS_bank_cf':>11}"
    )
    print(header)
    print("-" * len(header))

    for name, t in timing_by_name.items():
        # Counter keys are stripped above, so strip the timing name the same
        # way for the lookup; the unmodified name is kept for display.
        c = counter_by_name.get((name or "").strip())
        if c is None:
            print(
                f" WARNING: no counter row for '{name}' — "
                "check kernel_name_regex and rocprofv3 output",
                file=sys.stderr,
            )
            continue

        median_ms = _float(t.get("median_ms", "nan"), float("nan"))

        wait_lds = number(c, "SQ_WAIT_INST_LDS")
        wait_vmem = number(c, "SQ_WAIT_INST_VMEM")
        busy = number(c, "SQ_BUSY_CYCLES")
        insts_lds = number(c, "SQ_INSTS_LDS")
        insts_vmem = number(c, "SQ_INSTS_VMEM")
        hires = number(c, "SQ_ACCUM_PREV_HIRES")

        # A bank-conflict count of 0 is a legitimate measurement, so "not
        # collected" has to be detected from the column itself rather than
        # from the numeric value.
        bank_raw = lookup(c, "SQ_LDS_BANK_CONFLICT")
        bank_collected = bank_raw is not missing and bank_raw not in (None, "")

        alu_stalled_lds = (wait_lds / busy * 100) if busy > 0 else float("nan")
        alu_stalled_vmem = (wait_vmem / busy * 100) if busy > 0 else float("nan")
        vmem_lat = (hires / insts_vmem) if (insts_vmem > 0 and hires > 0) else None
        lds_wait_per_inst = (wait_lds / insts_lds) if insts_lds > 0 else None

        vmem_lat_str = f"{vmem_lat:>12.1f}" if vmem_lat is not None else f"{'N/A':>12}"
        lds_wpi_str = (
            f"{lds_wait_per_inst:>10.2f}"
            if lds_wait_per_inst is not None
            else f"{'N/A':>10}"
        )
        bank_cf_str = (
            f"{_float(bank_raw):>11.0f}" if bank_collected else f"{'N/A':>11}"
        )

        print(
            f"{name_to_label.get(name, name):>30} | {median_ms:>7.3f} | "
            f"{alu_stalled_lds:>9.1f} | {alu_stalled_vmem:>10.1f} | "
            f"{vmem_lat_str} | {lds_wpi_str} | {bank_cf_str}"
        )

    print()
    print(
        " ALUStalledByLDS(%) = SQ_WAIT_INST_LDS / SQ_BUSY_CYCLES * 100\n"
        " ALUStalledByVMEM(%) = SQ_WAIT_INST_VMEM / SQ_BUSY_CYCLES * 100\n"
        " VMEM_lat_cyc = SQ_ACCUM_PREV_HIRES / SQ_INSTS_VMEM "
        "(N/A if level counter unavailable)\n"
        " LDS_w/inst = SQ_WAIT_INST_LDS / SQ_INSTS_LDS "
        "(N/A if SQ_INSTS_LDS absent)\n"
    )


def _draw_roofline_backdrop(ax: Any, hw: HardwareCeilings) -> None:
bw_peak = hw.peak_bw_tbs * 1e12
fp16_peak = hw.peak_tflops_fp16 * 1e12
Expand Down
Loading