Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 57 additions & 17 deletions .claude/skills/kernel-trace-analysis/scripts/hotspot_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,18 +238,33 @@ def print_source_detail(hotspot, source_cache, context=3):
print(f" stall={fmt_cycles(inst.stall_cycles):>7} type={inst.stall_type:<12} {inst.asm}")


def read_kernel_metadata(dispatch_dir):
def read_kernel_metadata(dispatch_dir, kernel_filter=""):
"""Read authoritative resource counts from ``out_kernel_trace.csv`` if present.

The ATT ``code.json`` only contains the (possibly single-CU, possibly
vgpr-form) disassembly, so it cannot reveal accum_vgpr / SGPR / LDS /
workgroup size. The kernel-trace CSV carries the real launch metadata.
Searches the dispatch dir and its parent (staging often copies the CSV
next to the ui_output_agent_* dir). Returns {} if not found.

Row selection priority:
1. ``kernel_filter`` substring matched against Kernel_Name, optionally
narrowed by Dispatch_Id when the dir name encodes ``dispatch_<id>``
(rocprofv3 ``ui_output_agent_*_dispatch_<id>`` layout). Dispatch_Id
matching avoids false matches when a PyTorch reference kernel shares
the same name substring.
2. Bidirectional name heuristic against the directory basename (legacy
path for timestamped dirs like ``20240101_120000_pa_decode_kernel``).
"""
candidates = []
for base in (dispatch_dir, os.path.dirname(os.path.abspath(dispatch_dir))):
candidates += glob.glob(os.path.join(base, "*kernel_trace*.csv"))

dir_name = os.path.basename(os.path.abspath(dispatch_dir))
# Extract the dispatch id from rocprofv3's ui_output_agent_<N>_dispatch_<id> layout.
_dispatch_id_m = re.search(r"dispatch_(\d+)$", dir_name)
dispatch_id = _dispatch_id_m.group(1) if _dispatch_id_m else None

for path in candidates:
try:
with open(path) as f:
Expand All @@ -258,24 +273,35 @@ def read_kernel_metadata(dispatch_dir):
continue
if not rows or "Accum_VGPR_Count" not in rows[0]:
continue
# Pick the row whose kernel matches the dispatch dir name. The dir is
# usually staged as "<timestamp>_<short_kernel_name>" while the CSV
# Kernel_Name has a trailing index (e.g. dir ".._pa_decode_ps_kernel"
# vs kernel "pa_decode_ps_kernel_0"), so match bidirectionally on the
# timestamp-stripped short name.
dir_name = os.path.basename(os.path.abspath(dispatch_dir))
short = re.sub(r"^\d{8}_\d{6}_", "", dir_name) # strip YYYYMMDD_HHMMSS_

def _matches(kn):
if not kn:
return False
return kn in dir_name or short in kn or kn.startswith(short) or short.startswith(kn)

has_dispatch_col = "Dispatch_Id" in rows[0]

chosen = None
for r in rows:
if _matches(r.get("Kernel_Name", "")):
if kernel_filter:
# Explicit filter: kernel name substring, narrowed by Dispatch_Id when available.
for r in rows:
if kernel_filter not in r.get("Kernel_Name", ""):
continue
if dispatch_id and has_dispatch_col:
if str(r.get("Dispatch_Id", "")).strip() != dispatch_id:
continue
chosen = r
break
else:
# Legacy heuristic: bidirectional substring match against the dir basename.
# Works for timestamped dirs like ``20240101_120000_pa_decode_kernel``.
short = re.sub(r"^\d{8}_\d{6}_", "", dir_name) # strip YYYYMMDD_HHMMSS_

def _matches(kn):
if not kn:
return False
return kn in dir_name or short in kn or kn.startswith(short) or short.startswith(kn)

for r in rows:
if _matches(r.get("Kernel_Name", "")):
chosen = r
break

if chosen is None:
continue # no matching row in this CSV — try the next candidate

Expand Down Expand Up @@ -457,7 +483,10 @@ def print_reg_pressure(reg_info):
print_header("Register Pressure & Occupancy")
print(f" Architecture: {reg_info['arch']}")
if not reg_info["has_meta"]:
print(" (no kernel_trace CSV found — accum/LDS/SGPR estimated from ISA only)")
print(
" (kernel_trace CSV not matched — accum/LDS/SGPR estimated from ISA only; "
"pass --kernel <name_substr> to enable CSV metadata lookup)"
)
if reg_info["is_vgpr_form"]:
print(f" arch_vgpr: {reg_info['arch_vgpr']} (MFMA vgpr-form: accumulators in arch file, no AGPR)")
else:
Expand Down Expand Up @@ -496,6 +525,17 @@ def main():
"--detail", action="store_true", help="Show source snippet + instruction breakdown under each source hotspot"
)
parser.add_argument("--context", type=int, default=3, help="Source lines of context around hotspot (default: 3)")
parser.add_argument(
"--kernel",
default="",
metavar="SUBSTR",
help="Kernel name substring for CSV metadata lookup "
"(e.g. 'pa_mqa_logits_fp4_kernel_0'). "
"Required when the dispatch dir name does not encode the kernel name, "
"as with rocprofv3 ui_output_agent_*_dispatch_<id> directories. "
"Combined with the dispatch id from the dir name when a Dispatch_Id "
"column is present in the CSV.",
)
args = parser.parse_args()

if not os.path.isdir(args.dispatch_dir):
Expand All @@ -515,7 +555,7 @@ def main():
print(f" Total cycles: {fmt_cycles(total_cycles)}")
print(f" Total stalls: {fmt_cycles(total_stall)} ({100*total_stall/total_cycles:.1f}% of total cycles)")

meta = read_kernel_metadata(args.dispatch_dir)
meta = read_kernel_metadata(args.dispatch_dir, kernel_filter=args.kernel)
reg_info = detect_arch_and_reg_pressure(instructions, meta)
print_reg_pressure(reg_info)

Expand Down