Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
415 changes: 415 additions & 0 deletions tests/test_tritonparse.py

Large diffs are not rendered by default.

21 changes: 20 additions & 1 deletion tritonparse/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from importlib.metadata import PackageNotFoundError, version

from .common import is_fbcode
from .info.cli import _add_info_args, info_command
from .reproducer.cli import _add_reproducer_args
from .reproducer.orchestrator import reproduce
from .utils import _add_parse_args, unified_parse
Expand Down Expand Up @@ -31,6 +32,8 @@ def main():
"Examples:\n"
f" {prog_name} parse /path/to/logs --out parsed_output\n"
f" {prog_name} reproduce /path/to/trace.ndjson --line 1 --out-dir repro_output\n"
f" {prog_name} info /path/to/trace.ndjson\n"
f" {prog_name} info /path/to/trace.ndjson --kernel matmul_kernel\n"
),
formatter_class=argparse.RawDescriptionHelpFormatter,
)
Expand Down Expand Up @@ -60,6 +63,14 @@ def main():
_add_reproducer_args(repro_parser)
repro_parser.set_defaults(func="reproduce")

# info subcommand
info_parser = subparsers.add_parser(
"info",
help="Query kernel information from trace file",
)
_add_info_args(info_parser)
info_parser.set_defaults(func="info")

args = parser.parse_args()

if args.func == "parse":
Expand All @@ -68,6 +79,10 @@ def main():
}
unified_parse(**parse_args)
elif args.func == "reproduce":
# Check mutual exclusivity between --line and --kernel/--launch-id
if args.kernel and args.line != 0:
repro_parser.error("--line and --kernel/--launch-id are mutually exclusive")

replacer = None
if args.use_fbcode:
from tritonparse.fb.reproducer.replacer import FBCodePlaceholderReplacer
Expand All @@ -77,12 +92,16 @@ def main():

reproduce(
input_path=args.input,
line_index=args.line,
line_index=args.line if not args.kernel else 0,
out_dir=args.out_dir,
template=args.template,
kernel_name=args.kernel,
launch_id=args.launch_id if args.kernel else 0,
kernel_import=args.kernel_import,
replacer=replacer,
)
elif args.func == "info":
info_command(input_path=args.input, kernel_name=args.kernel)
else:
raise RuntimeError(f"Unknown command: {args.func}")

Expand Down
30 changes: 30 additions & 0 deletions tritonparse/info/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

"""
Info module for querying kernel information from NDJSON trace files.

This module provides core query functions for kernel information:
- Listing all kernels with their launch counts
- Finding launch events by kernel name and launch ID
- Querying launch information for specific kernels
"""

from tritonparse.info.kernel_query import (
find_launch_index_by_kernel,
find_similar_kernels,
KernelSummary,
LaunchInfo,
list_kernels,
list_kernels_fast,
list_launches_for_kernel,
)

__all__ = [
"KernelSummary",
"LaunchInfo",
"list_kernels",
"list_kernels_fast",
"list_launches_for_kernel",
"find_launch_index_by_kernel",
"find_similar_kernels",
]
121 changes: 121 additions & 0 deletions tritonparse/info/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

"""
CLI implementation for the info subcommand.

This module provides command-line interface for querying kernel information
from NDJSON trace files.
"""

import argparse
import tempfile
from typing import Optional

from tritonparse.info.kernel_query import (
find_similar_kernels,
list_kernels_fast,
list_launches_for_kernel,
)
from tritonparse.info.parse_helper import parse_and_compress_raw_log
from tritonparse.tools.prettify_ndjson import load_ndjson


def _add_info_args(parser: argparse.ArgumentParser) -> None:
"""Add arguments for the info subcommand."""
parser.add_argument(
"input",
help="Path to ndjson/ndjson.gz/.bin.ndjson file",
)
parser.add_argument(
"--kernel",
type=str,
default=None,
help="Kernel name to list launches for",
)


def info_command(input_path: str, kernel_name: Optional[str] = None) -> None:
"""
Main function for the info command.

Args:
input_path: Path to ndjson file
kernel_name: Optional kernel name to list launches for
"""
# 1. Load and detect type
events = load_ndjson(input_path)
has_launch_diff = any(e.get("event_type") == "launch_diff" for e in events)

# 2. If no launch_diff, auto-parse
if not has_launch_diff:
print(
f"Input file '{input_path}' appears to be raw log (no launch_diff events)."
)
print("Parsing automatically to generate launch_diff events...")

temp_dir = tempfile.mkdtemp(prefix="tritonparse_info_")

try:
# Parse and compress (reuses parse module's functions)
parsed_file = parse_and_compress_raw_log(
input_path,
output_dir=temp_dir,
split_inductor_compilations=False,
verbose=False,
)

# Load compressed file (load_ndjson supports .ndjson.gz)
events = load_ndjson(parsed_file)

print(f"✓ Parsed and compressed file: {parsed_file}")
print(f" (Temporary directory: {temp_dir})")
except Exception as e:
raise RuntimeError(f"Failed to parse input file '{input_path}': {e}") from e
else:
print(f"Using parsed trace file: {input_path}")

# 3. Process query
if kernel_name:
# List launches for specific kernel
try:
launches = list_launches_for_kernel(events, kernel_name)
print(f"\nLaunches for '{kernel_name}':")
print("-" * 60)
for launch in launches:
grid_str = str(launch.grid) if launch.grid else "N/A"
print(
f" id={launch.launch_id:3d} line {launch.line_index:5d} grid={grid_str}"
)
except ValueError as e:
error_msg = str(e)
print(f"\nError: {error_msg}")
# Try to suggest similar kernels
try:
similar = find_similar_kernels(events, kernel_name, n=3)
if similar:
print("\nDid you mean one of these?")
all_kernels = list_kernels_fast(
events
) # Use fast path for consistency
kernel_dict = {k.name: k for k in all_kernels}
for name in similar:
count = kernel_dict[name].total_launches
print(f" - {name} ({count} launches)")
print("\nUse 'tritonparseoss info <file>' to list all kernels.")
except Exception:
pass # Ignore errors in suggestion
raise
else:
# List all kernels
kernels = list_kernels_fast(events)
print(f"\nKernels in {input_path}:")
print("-" * 60)
for kernel in kernels:
if kernel.total_launches > 0:
max_id = kernel.total_launches - 1
print(
f" {kernel.name:30s} {kernel.total_launches:3d} launches "
f"(id: 0-{max_id})"
)
else:
print(f" {kernel.name:30s} {kernel.total_launches:3d} launches")
Loading