diff --git a/tests/test_tritonparse.py b/tests/test_tritonparse.py index 5e8816c..b0c469a 100644 --- a/tests/test_tritonparse.py +++ b/tests/test_tritonparse.py @@ -14,6 +14,7 @@ import unittest from collections import defaultdict from dataclasses import dataclass +from pathlib import Path from typing import Any, Union import torch @@ -21,6 +22,7 @@ import triton # @manual=//triton:triton import triton.language as tl # @manual=//triton:triton import tritonparse.context_manager +import tritonparse.reproducer.orchestrator import tritonparse.structured_logging import tritonparse.utils from triton import knobs # @manual=//triton:triton @@ -138,6 +140,26 @@ def clear_all_caches(*kernels): class TestTritonparseCPU(unittest.TestCase): """CPU-only tests (no CUDA required)""" + def _get_test_ndjson_file(self): + """Get the test NDJSON file path.""" + gz_file = ( + Path(__file__).parent + / "example_output/parsed_output_complex/dedicated_log_triton_trace_findhao__mapped.ndjson.gz" + ) + self.assertTrue(gz_file.exists(), f"Test file not found: {gz_file}") + return gz_file + + def setup_temp_reproduce_dir(self): + """Setup temporary directory for reproduce tests.""" + temp_dir = tempfile.mkdtemp() + out_dir = os.path.join(temp_dir, "repro_output") + return temp_dir, out_dir + + def cleanup_temp_reproduce_dir(self, temp_dir): + """Cleanup temporary directory for reproduce tests.""" + if not TEST_KEEP_OUTPUT: + shutil.rmtree(temp_dir, ignore_errors=True) + def test_callsite_parsing(self): """Test parsing of callsite locations in TTIR/TTGIR""" from tritonparse.ir_parser import extract_loc_definitions @@ -307,6 +329,399 @@ def test_loc_alias_parsing(self): print("✓ All loc alias parsing tests passed") + def test_load_ndjson_gzip_support(self): + """Test that load_ndjson can load .ndjson.gz files.""" + from pathlib import Path + + from tritonparse.tools.prettify_ndjson import load_ndjson + + # Use existing .ndjson.gz test file + gz_file = ( + Path(__file__).parent + / "example_output/parsed_output_complex/dedicated_log_triton_trace_findhao__mapped.ndjson.gz" + ) + + # Verify file exists + self.assertTrue(gz_file.exists(), f"Test file not found: {gz_file}") + + # Load and verify + events = load_ndjson(gz_file) + self.assertIsInstance(events, list) + self.assertGreater(len(events), 0, "Should load at least one event") + + # Verify we have expected event types + event_types = {e.get("event_type") for e in events if isinstance(e, dict)} + self.assertTrue( + "compilation" in event_types or "launch" in event_types, + f"Expected compilation or launch events, got: {event_types}", + ) + + print(f"✓ Successfully loaded {len(events)} events from .ndjson.gz file") + + def test_list_kernels_empty(self): + """Test listing kernels from empty events list.""" + from tritonparse.info.kernel_query import list_kernels + + events = [] + result = list_kernels(events) + self.assertEqual(result, []) + + def test_list_kernels_single(self): + """Test listing kernels with single kernel and multiple launches.""" + from pathlib import Path + + from tritonparse.info.kernel_query import list_kernels + from tritonparse.tools.prettify_ndjson import load_ndjson + + # Load real test data + gz_file = ( + Path(__file__).parent + / "example_output/parsed_output_complex/dedicated_log_triton_trace_findhao__mapped.ndjson.gz" + ) + events = load_ndjson(gz_file) + + # Filter to only fused_op_kernel launches (4 launches) + filtered_events = [] + for event in events: + if event.get("event_type") == "launch": + kernel_name = event.get("compilation_metadata", {}).get("name") + if kernel_name == "fused_op_kernel": + filtered_events.append(event) + else: + # Keep non-launch events to test filtering + filtered_events.append(event) + + result = list_kernels(filtered_events) + self.assertEqual(len(result), 1) + self.assertEqual(result[0].name, "fused_op_kernel") + self.assertEqual(result[0].total_launches, 4) + + def test_list_kernels_multiple(self): + """Test listing kernels with multiple different kernels.""" + from pathlib import Path + + from tritonparse.info.kernel_query import list_kernels + from tritonparse.tools.prettify_ndjson import load_ndjson + + # Load real test data + gz_file = ( + Path(__file__).parent + / "example_output/parsed_output_complex/dedicated_log_triton_trace_findhao__mapped.ndjson.gz" + ) + events = load_ndjson(gz_file) + + result = list_kernels(events) + self.assertEqual(len(result), 2) + + # Check that results are sorted by name + names = [k.name for k in result] + self.assertEqual(names, ["fused_op_kernel", "matmul_kernel"]) + + # Check launch counts + kernel_dict = {k.name: k for k in result} + self.assertEqual(kernel_dict["matmul_kernel"].total_launches, 1553) + self.assertEqual(kernel_dict["fused_op_kernel"].total_launches, 4) + + def test_find_launch_index_valid(self): + """Test finding valid kernel name and launch_id.""" + from pathlib import Path + + from tritonparse.info.kernel_query import find_launch_index_by_kernel + from tritonparse.tools.prettify_ndjson import load_ndjson + + # Load real test data + gz_file = ( + Path(__file__).parent + / "example_output/parsed_output_complex/dedicated_log_triton_trace_findhao__mapped.ndjson.gz" + ) + events = load_ndjson(gz_file) + + # Test first launch of fused_op_kernel (launch_id=0) + index = find_launch_index_by_kernel(events, "fused_op_kernel", 0) + self.assertEqual(events[index].get("event_type"), "launch") + self.assertEqual( + events[index].get("compilation_metadata", {}).get("name"), + "fused_op_kernel", + ) + + # Test second launch of fused_op_kernel (launch_id=1) + index = find_launch_index_by_kernel(events, "fused_op_kernel", 1) + self.assertEqual(events[index].get("event_type"), "launch") + self.assertEqual( + events[index].get("compilation_metadata", {}).get("name"), + "fused_op_kernel", + ) + + # Test first launch of matmul_kernel (launch_id=0) + index = find_launch_index_by_kernel(events, "matmul_kernel", 0) + self.assertEqual(events[index].get("event_type"), "launch") + self.assertEqual( + events[index].get("compilation_metadata", {}).get("name"), + "matmul_kernel", + ) + + def test_find_launch_index_kernel_not_found(self): + """Test that ValueError is raised when kernel not found.""" + from pathlib import Path + + from tritonparse.info.kernel_query import find_launch_index_by_kernel + from tritonparse.tools.prettify_ndjson import load_ndjson + + # Load real test data + gz_file = ( + Path(__file__).parent + / "example_output/parsed_output_complex/dedicated_log_triton_trace_findhao__mapped.ndjson.gz" + ) + events = load_ndjson(gz_file) + + with self.assertRaises(ValueError) as cm: + find_launch_index_by_kernel(events, "nonexistent_kernel", 0) + + error_msg = str(cm.exception) + self.assertIn("not found", error_msg) + self.assertIn("nonexistent_kernel", error_msg) + + def test_find_launch_index_out_of_range(self): + """Test that ValueError is raised when launch_id is out of range.""" + from pathlib import Path + + from tritonparse.info.kernel_query import find_launch_index_by_kernel + from tritonparse.tools.prettify_ndjson import load_ndjson + + # Load real test data + gz_file = ( + Path(__file__).parent + / "example_output/parsed_output_complex/dedicated_log_triton_trace_findhao__mapped.ndjson.gz" + ) + events = load_ndjson(gz_file) + + # fused_op_kernel has only 4 launches (0-3), test with launch_id=10 + with self.assertRaises(ValueError) as cm: + find_launch_index_by_kernel(events, "fused_op_kernel", 10) + + error_msg = str(cm.exception) + self.assertIn("has only 4 launches", error_msg) + self.assertIn("--launch-id 10", error_msg) + self.assertIn("Valid range: 0 to 3", error_msg) + + def test_reproduce_mutual_exclusivity(self): + """Test that --line and --kernel/--launch-id are mutually exclusive.""" + import argparse + + from tritonparse.reproducer.cli import _add_reproducer_args + + parser = argparse.ArgumentParser() + _add_reproducer_args(parser) + + # Test: both --line and --kernel provided should raise error + # Create a mock parser with error method + mock_parser = argparse.ArgumentParser() + _add_reproducer_args(mock_parser) + args = mock_parser.parse_args( + ["test.ndjson", "--line", "5", "--kernel", "matmul_kernel"] + ) + + # The mutual exclusivity check happens in cli.py main() + # We test that args are parsed correctly, and the check will happen there + self.assertEqual(args.kernel, "matmul_kernel") + self.assertEqual(args.line, 5) + + # Test: only --kernel should work (line defaults to 0, which is allowed) + args = parser.parse_args(["test.ndjson", "--kernel", "matmul_kernel"]) + self.assertEqual(args.kernel, "matmul_kernel") + self.assertEqual(args.line, 0) # default value, allowed with --kernel + + # Test: only --line should work + args = parser.parse_args(["test.ndjson", "--line", "5"]) + self.assertEqual(args.line, 5) + self.assertIsNone(args.kernel) + + def test_reproduce_kernel_launch_id(self): + """End-to-end test: reproduce using --kernel and --launch-id.""" + gz_file = self._get_test_ndjson_file() + temp_dir, out_dir = self.setup_temp_reproduce_dir() + + try: + # Test reproducing fused_op_kernel launch_id=0 + result = tritonparse.reproducer.orchestrator.reproduce( + input_path=str(gz_file), + line_index=0, # Placeholder, will be recalculated from kernel_name + out_dir=out_dir, + template="example", + kernel_name="fused_op_kernel", + launch_id=0, + ) + + # Verify output structure + self.assertIn("kernel", result) + self.assertIn("repro_script", result) + self.assertIn("repro_context", result) + self.assertTrue(os.path.exists(result["repro_script"])) + self.assertTrue(os.path.exists(result["repro_context"])) + + # Verify the script contains kernel name + script_content = Path(result["repro_script"]).read_text() + self.assertIn("fused_op_kernel", script_content) + + finally: + self.cleanup_temp_reproduce_dir(temp_dir) + + def test_reproduce_kernel_not_found(self): + """Test that proper error is raised when kernel not found.""" + gz_file = self._get_test_ndjson_file() + temp_dir, out_dir = self.setup_temp_reproduce_dir() + + try: + with self.assertRaises(ValueError) as cm: + tritonparse.reproducer.orchestrator.reproduce( + input_path=str(gz_file), + line_index=0, # Placeholder, will be recalculated from kernel_name + out_dir=out_dir, + template="example", + kernel_name="nonexistent_kernel", + launch_id=0, + ) + + error_msg = str(cm.exception) + self.assertIn("not found", error_msg) + self.assertIn("nonexistent_kernel", error_msg) + + finally: + self.cleanup_temp_reproduce_dir(temp_dir) + + def test_reproduce_launch_id_out_of_range(self): + """Test that proper error is raised when launch_id is out of range.""" + gz_file = self._get_test_ndjson_file() + temp_dir, out_dir = self.setup_temp_reproduce_dir() + + try: + # fused_op_kernel has only 4 launches (0-3), test with launch_id=10 + with self.assertRaises(ValueError) as cm: + tritonparse.reproducer.orchestrator.reproduce( + input_path=str(gz_file), + line_index=0, # Placeholder, will be recalculated from kernel_name + out_dir=out_dir, + template="example", + kernel_name="fused_op_kernel", + launch_id=10, + ) + + error_msg = str(cm.exception) + self.assertIn("has only 4 launches", error_msg) + self.assertIn("--launch-id 10", error_msg) + self.assertIn("Valid range: 0 to 3", error_msg) + + finally: + self.cleanup_temp_reproduce_dir(temp_dir) + + def test_info_kernel_query_functions(self): + """Test info module kernel query functions.""" + from tritonparse.info.kernel_query import ( + find_similar_kernels, + list_kernels, + list_kernels_fast, + list_launches_for_kernel, + ) + from tritonparse.tools.prettify_ndjson import load_ndjson + + gz_file = self._get_test_ndjson_file() + events = load_ndjson(gz_file) + + # Test list_launches_for_kernel + launches = list_launches_for_kernel(events, "fused_op_kernel") + self.assertGreater(len(launches), 0) + self.assertEqual(launches[0].launch_id, 0) + self.assertIsInstance(launches[0].grid, list) + + # Test list_launches_for_kernel with non-existent kernel + with self.assertRaises(ValueError) as cm: + list_launches_for_kernel(events, "nonexistent_kernel") + self.assertIn("not found", str(cm.exception)) + + # Test find_similar_kernels + similar = find_similar_kernels(events, "fused_op", n=3) + self.assertGreater(len(similar), 0) + self.assertIn("fused_op_kernel", similar) + + similar = find_similar_kernels(events, "fused_op_kernel", n=3) + self.assertIn("fused_op_kernel", similar) + + similar = find_similar_kernels(events, "xyz_abc_123", n=3) + self.assertEqual(len(similar), 0) + + # Test list_kernels_fast (should use launch_diff and match list_kernels) + kernels_fast = list_kernels_fast(events) + self.assertGreater(len(kernels_fast), 0) + + kernels_slow = list_kernels(events) + fast_dict = {k.name: k.total_launches for k in kernels_fast} + slow_dict = {k.name: k.total_launches for k in kernels_slow} + self.assertEqual(fast_dict, slow_dict) + + def test_info_list_kernels(self): + """Integration test: info command lists all kernels.""" + import sys + from io import StringIO + + from tritonparse.info.cli import info_command + + gz_file = self._get_test_ndjson_file() + + # Capture stdout + old_stdout = sys.stdout + sys.stdout = captured_output = StringIO() + + try: + info_command(str(gz_file), kernel_name=None) + output = captured_output.getvalue() + self.assertIn("Kernels in", output) + self.assertIn("launches", output) + finally: + sys.stdout = old_stdout + + def test_info_kernel_launches(self): + """Integration test: info command lists launches for specific kernel.""" + import sys + from io import StringIO + + from tritonparse.info.cli import info_command + + gz_file = self._get_test_ndjson_file() + + # Capture stdout + old_stdout = sys.stdout + sys.stdout = captured_output = StringIO() + + try: + info_command(str(gz_file), kernel_name="fused_op_kernel") + output = captured_output.getvalue() + self.assertIn("Launches for 'fused_op_kernel'", output) + self.assertIn("id=", output) + self.assertIn("line", output) + finally: + sys.stdout = old_stdout + + def test_info_kernel_not_found(self): + """Integration test: info command handles kernel not found.""" + import sys + from io import StringIO + + from tritonparse.info.cli import info_command + + gz_file = self._get_test_ndjson_file() + + # Capture stdout + old_stdout = sys.stdout + sys.stdout = captured_output = StringIO() + + try: + with self.assertRaises(ValueError): + info_command(str(gz_file), kernel_name="nonexistent_kernel") + output = captured_output.getvalue() + self.assertIn("not found", output) + finally: + sys.stdout = old_stdout + class TestTritonparseCUDA(unittest.TestCase): """CUDA tests (require GPU)""" diff --git a/tritonparse/cli.py b/tritonparse/cli.py index c56dc86..4df0117 100644 --- a/tritonparse/cli.py +++ b/tritonparse/cli.py @@ -4,6 +4,7 @@ from importlib.metadata import PackageNotFoundError, version from .common import is_fbcode +from .info.cli import _add_info_args, info_command from .reproducer.cli import _add_reproducer_args from .reproducer.orchestrator import reproduce from .utils import _add_parse_args, unified_parse @@ -31,6 +32,8 @@ def main(): "Examples:\n" f" {prog_name} parse /path/to/logs --out parsed_output\n" f" {prog_name} reproduce /path/to/trace.ndjson --line 1 --out-dir repro_output\n" + f" {prog_name} info /path/to/trace.ndjson\n" + f" {prog_name} info /path/to/trace.ndjson --kernel matmul_kernel\n" ), formatter_class=argparse.RawDescriptionHelpFormatter, ) @@ -60,6 +63,14 @@ def main(): _add_reproducer_args(repro_parser) repro_parser.set_defaults(func="reproduce") + # info subcommand + info_parser = subparsers.add_parser( + "info", + help="Query kernel information from trace file", + ) + _add_info_args(info_parser) + info_parser.set_defaults(func="info") + args = parser.parse_args() if args.func == "parse": @@ -68,6 +79,10 @@ def main(): } unified_parse(**parse_args) elif args.func == "reproduce": + # Check mutual exclusivity between --line and --kernel/--launch-id + if args.kernel and args.line != 0: + repro_parser.error("--line and --kernel/--launch-id are mutually exclusive") + replacer = None if args.use_fbcode: from tritonparse.fb.reproducer.replacer import FBCodePlaceholderReplacer @@ -77,12 +92,16 @@ def main(): reproduce( input_path=args.input, - line_index=args.line, + line_index=args.line if not args.kernel else 0, out_dir=args.out_dir, template=args.template, + kernel_name=args.kernel, + launch_id=args.launch_id if args.kernel else 0, kernel_import=args.kernel_import, replacer=replacer, ) + elif args.func == "info": + info_command(input_path=args.input, kernel_name=args.kernel) else: raise RuntimeError(f"Unknown command: {args.func}") diff --git a/tritonparse/info/__init__.py b/tritonparse/info/__init__.py new file mode 100644 index 0000000..5e96696 --- /dev/null +++ b/tritonparse/info/__init__.py @@ -0,0 +1,30 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +""" +Info module for querying kernel information from NDJSON trace files. + +This module provides core query functions for kernel information: +- Listing all kernels with their launch counts +- Finding launch events by kernel name and launch ID +- Querying launch information for specific kernels +""" + +from tritonparse.info.kernel_query import ( + find_launch_index_by_kernel, + find_similar_kernels, + KernelSummary, + LaunchInfo, + list_kernels, + list_kernels_fast, + list_launches_for_kernel, +) + +__all__ = [ + "KernelSummary", + "LaunchInfo", + "list_kernels", + "list_kernels_fast", + "list_launches_for_kernel", + "find_launch_index_by_kernel", + "find_similar_kernels", +] diff --git a/tritonparse/info/cli.py b/tritonparse/info/cli.py new file mode 100644 index 0000000..e939ce3 --- /dev/null +++ b/tritonparse/info/cli.py @@ -0,0 +1,121 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +""" +CLI implementation for the info subcommand. + +This module provides command-line interface for querying kernel information +from NDJSON trace files. +""" + +import argparse +import tempfile +from typing import Optional + +from tritonparse.info.kernel_query import ( + find_similar_kernels, + list_kernels_fast, + list_launches_for_kernel, +) +from tritonparse.info.parse_helper import parse_and_compress_raw_log +from tritonparse.tools.prettify_ndjson import load_ndjson + + +def _add_info_args(parser: argparse.ArgumentParser) -> None: + """Add arguments for the info subcommand.""" + parser.add_argument( + "input", + help="Path to ndjson/ndjson.gz/.bin.ndjson file", + ) + parser.add_argument( + "--kernel", + type=str, + default=None, + help="Kernel name to list launches for", + ) + + +def info_command(input_path: str, kernel_name: Optional[str] = None) -> None: + """ + Main function for the info command. + + Args: + input_path: Path to ndjson file + kernel_name: Optional kernel name to list launches for + """ + # 1. Load and detect type + events = load_ndjson(input_path) + has_launch_diff = any(e.get("event_type") == "launch_diff" for e in events) + + # 2. If no launch_diff, auto-parse + if not has_launch_diff: + print( + f"Input file '{input_path}' appears to be raw log (no launch_diff events)." + ) + print("Parsing automatically to generate launch_diff events...") + + temp_dir = tempfile.mkdtemp(prefix="tritonparse_info_") + + try: + # Parse and compress (reuses parse module's functions) + parsed_file = parse_and_compress_raw_log( + input_path, + output_dir=temp_dir, + split_inductor_compilations=False, + verbose=False, + ) + + # Load compressed file (load_ndjson supports .ndjson.gz) + events = load_ndjson(parsed_file) + + print(f"✓ Parsed and compressed file: {parsed_file}") + print(f" (Temporary directory: {temp_dir})") + except Exception as e: + raise RuntimeError(f"Failed to parse input file '{input_path}': {e}") from e + else: + print(f"Using parsed trace file: {input_path}") + + # 3. Process query + if kernel_name: + # List launches for specific kernel + try: + launches = list_launches_for_kernel(events, kernel_name) + print(f"\nLaunches for '{kernel_name}':") + print("-" * 60) + for launch in launches: + grid_str = str(launch.grid) if launch.grid else "N/A" + print( + f" id={launch.launch_id:3d} line {launch.line_index:5d} grid={grid_str}" + ) + except ValueError as e: + error_msg = str(e) + print(f"\nError: {error_msg}") + # Try to suggest similar kernels + try: + similar = find_similar_kernels(events, kernel_name, n=3) + if similar: + print("\nDid you mean one of these?") + all_kernels = list_kernels_fast( + events + ) # Use fast path for consistency + kernel_dict = {k.name: k for k in all_kernels} + for name in similar: + count = kernel_dict[name].total_launches + print(f" - {name} ({count} launches)") + print("\nUse 'tritonparseoss info ' to list all kernels.") + except Exception: + pass # Ignore errors in suggestion + raise + else: + # List all kernels + kernels = list_kernels_fast(events) + print(f"\nKernels in {input_path}:") + print("-" * 60) + for kernel in kernels: + if kernel.total_launches > 0: + max_id = kernel.total_launches - 1 + print( + f" {kernel.name:30s} {kernel.total_launches:3d} launches " + f"(id: 0-{max_id})" + ) + else: + print(f" {kernel.name:30s} {kernel.total_launches:3d} launches") diff --git a/tritonparse/info/kernel_query.py b/tritonparse/info/kernel_query.py new file mode 100644 index 0000000..5793308 --- /dev/null +++ b/tritonparse/info/kernel_query.py @@ -0,0 +1,209 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +""" +Core query functions for kernel information from NDJSON trace files. + +This module provides functions to query kernel launch information from parsed +event lists. It supports both raw log files and parsed ndjson files (with launch_diff events). +""" + +import difflib +from collections import defaultdict +from dataclasses import dataclass +from typing import Any, Dict, List + + +@dataclass +class KernelSummary: + """Summary information about a kernel.""" + + name: str + hash: str + total_launches: int + + +@dataclass +class LaunchInfo: + """Information about a specific kernel launch.""" + + launch_id: int # 0-based + line_index: int # 0-based (index in events list) + grid: List[int] + + +def list_kernels(events: List[Dict[str, Any]]) -> List[KernelSummary]: + """ + List all kernels with their launch counts. + + Args: + events: List of parsed event dictionaries from NDJSON file + + Returns: + List of KernelSummary objects, sorted by kernel name + """ + # Count launches per kernel + kernel_counts: Dict[str, Dict[str, Any]] = defaultdict( + lambda: {"hash": "", "count": 0} + ) + + for event in events: + if event.get("event_type") != "launch": + continue + + comp_meta = event.get("compilation_metadata", {}) + kernel_name = comp_meta.get("name") + kernel_hash = comp_meta.get("hash", "") + + if kernel_name: + kernel_counts[kernel_name]["hash"] = kernel_hash + kernel_counts[kernel_name]["count"] += 1 + + # Convert to KernelSummary list + summaries = [ + KernelSummary(name=name, hash=info["hash"], total_launches=info["count"]) + for name, info in kernel_counts.items() + ] + + # Sort by kernel name for consistent output + summaries.sort(key=lambda x: x.name) + + return summaries + + +def find_launch_index_by_kernel( + events: List[Dict[str, Any]], kernel_name: str, launch_id: int +) -> int: + """ + Find the 0-based line index for a kernel's N-th launch. + + Args: + events: List of parsed event dictionaries + kernel_name: Exact kernel name to match (case-sensitive) + launch_id: 0-based launch index for the kernel + + Returns: + 0-based line index (index in events list) + + Raises: + ValueError: If kernel not found or launch_id out of range + """ + count = 0 + for i, event in enumerate(events): + if event.get("event_type") != "launch": + continue + + comp_meta = event.get("compilation_metadata", {}) + name = comp_meta.get("name") + if name == kernel_name: + if count == launch_id: + return i + count += 1 + + if count == 0: + raise ValueError(f"Kernel '{kernel_name}' not found") + else: + raise ValueError( + f"Kernel '{kernel_name}' has only {count} launches, " + f"but --launch-id {launch_id} was requested. Valid range: 0 to {count - 1}" + ) + + +def list_launches_for_kernel( + events: List[Dict[str, Any]], kernel_name: str +) -> List[LaunchInfo]: + """ + List all launches for a specific kernel. + + Args: + events: List of parsed event dictionaries + kernel_name: Exact kernel name to match (case-sensitive) + + Returns: + List of LaunchInfo objects for the kernel, sorted by launch_id + + Raises: + ValueError: If kernel not found + """ + launches = [] + launch_id = 0 + + for i, event in enumerate(events): + if event.get("event_type") != "launch": + continue + + comp_meta = event.get("compilation_metadata", {}) + name = comp_meta.get("name") + if name == kernel_name: + # Extract grid information from launch event + grid = event.get("grid", []) + launches.append(LaunchInfo(launch_id=launch_id, line_index=i, grid=grid)) + launch_id += 1 + + if not launches: + raise ValueError(f"Kernel '{kernel_name}' not found") + + return launches + + +def find_similar_kernels( + events: List[Dict[str, Any]], kernel_name: str, n: int = 3 +) -> List[str]: + """ + Find similar kernel names using fuzzy matching. + + Args: + events: List of parsed event dictionaries + kernel_name: Kernel name to find similar matches for + n: Maximum number of matches to return + + Returns: + List of similar kernel names (may be empty if no matches found) + """ + all_kernels = list_kernels(events) + all_names = [k.name for k in all_kernels] + return difflib.get_close_matches(kernel_name, all_names, n=n, cutoff=0.6) + + +def list_kernels_fast(events: List[Dict[str, Any]]) -> List[KernelSummary]: + """ + Fast kernel listing using launch_diff events when available. + + If launch_diff events are present, uses them for fast listing. + Otherwise, falls back to list_kernels(). + + Args: + events: List of parsed event dictionaries + + Returns: + List of KernelSummary objects, sorted by kernel name + """ + # Check if launch_diff events are available + launch_diff_events = [e for e in events if e.get("event_type") == "launch_diff"] + + if launch_diff_events: + # Use launch_diff events for fast listing + # Merge kernels with the same name (sum up launches) + kernel_dict: Dict[str, KernelSummary] = {} + for event in launch_diff_events: + name = event.get("name", "") + if not name: + continue + hash_val = event.get("hash", "") + launches = event.get("total_launches", 0) + + if name in kernel_dict: + # Merge: sum up launches, keep first hash + kernel_dict[name].total_launches += launches + else: + kernel_dict[name] = KernelSummary( + name=name, + hash=hash_val, + total_launches=launches, + ) + + summaries = list(kernel_dict.values()) + summaries.sort(key=lambda x: x.name) + return summaries + else: + # Fall back to full traversal + return list_kernels(events) diff --git a/tritonparse/info/parse_helper.py b/tritonparse/info/parse_helper.py new file mode 100644 index 0000000..e923d3c --- /dev/null +++ b/tritonparse/info/parse_helper.py @@ -0,0 +1,70 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +""" +Helper functions for parsing raw log files in the info module. + +This module provides utilities to parse and compress raw log files, +reusing functionality from the parse module. +""" + +from pathlib import Path + +from tritonparse.common import gzip_single_file +from tritonparse.trace_processor import parse_single_file + + +def parse_and_compress_raw_log( + input_path: str, + output_dir: str, + split_inductor_compilations: bool = False, + verbose: bool = False, +) -> Path: + """ + Parse a raw log file, compress it, and return the path to the compressed parsed file. + + This function reuses the parse module's functionality: + - parse_single_file: Parse the file + - gzip_single_file: Compress the parsed file + + Args: + input_path: Path to raw log file + output_dir: Directory to save parsed file + split_inductor_compilations: Whether to split by inductor compilations + verbose: Whether to print verbose information + + Returns: + Path to the generated compressed parsed file (.ndjson.gz) + + Raises: + RuntimeError: If parsing fails or parsed file not found + """ + # 1. Parse the file (generates uncompressed .ndjson) + parse_single_file( + input_path, + output_dir=output_dir, + split_inductor_compilations=split_inductor_compilations, + ) + + # 2. Calculate expected output filename + input_path_obj = Path(input_path) + file_name = input_path_obj.name + + if input_path.endswith(".bin.ndjson"): + file_name_without_ext = file_name[:-11] # Remove ".bin.ndjson" + else: + file_name_without_ext = input_path_obj.stem # Remove all extensions + # If there's still a .ndjson extension, remove it + if file_name_without_ext.endswith(".ndjson"): + file_name_without_ext = file_name_without_ext[:-7] + + uncompressed_file = Path(output_dir) / f"{file_name_without_ext}_mapped.ndjson" + + if not uncompressed_file.exists(): + raise RuntimeError( + f"Failed to generate parsed file. Expected: {uncompressed_file}" + ) + + # 3. Compress the file (reusing parse module's function) + compressed_file = gzip_single_file(str(uncompressed_file), verbose=verbose) + + return Path(compressed_file) # Returns .ndjson.gz path diff --git a/tritonparse/reproducer/cli.py b/tritonparse/reproducer/cli.py index 4cc3955..b6787f8 100644 --- a/tritonparse/reproducer/cli.py +++ b/tritonparse/reproducer/cli.py @@ -14,7 +14,26 @@ def _add_reproducer_args(parser: argparse.ArgumentParser) -> None: default=0, help=( "The line index (0-based) of the launch event in the input file to reproduce. " - "Defaults to 0 (first launch event)." + "Defaults to 0 (first launch event). Mutually exclusive with --kernel/--launch-id." + ), + ) + parser.add_argument( + "--kernel", + type=str, + default=None, + help=( + "Kernel name (exact match, case-sensitive) to reproduce. " + "Use with --launch-id to specify which launch of the kernel. " + "Mutually exclusive with --line." + ), + ) + parser.add_argument( + "--launch-id", + type=int, + default=0, + help=( + "0-based launch index for the kernel specified by --kernel. " + "Defaults to 0 (first launch). Only used when --kernel is provided." ), ) parser.add_argument( diff --git a/tritonparse/reproducer/orchestrator.py b/tritonparse/reproducer/orchestrator.py index 6dcc530..4067e9f 100644 --- a/tritonparse/reproducer/orchestrator.py +++ b/tritonparse/reproducer/orchestrator.py @@ -3,6 +3,7 @@ from pathlib import Path from typing import Optional +from tritonparse.info.kernel_query import find_launch_index_by_kernel from tritonparse.reproducer.ingestion.ndjson import build_context_bundle from tritonparse.reproducer.placeholder_replacer import ( DefaultPlaceholderReplacer, @@ -20,24 +21,44 @@ def reproduce( line_index: int, out_dir: str, template: str, + kernel_name: Optional[str] = None, + launch_id: int = 0, replacer: Optional[PlaceholderReplacer] = None, kernel_import: KernelImportMode = KernelImportMode.DEFAULT, ) -> dict[str, str]: """ Generate a reproducer script from NDJSON trace file. + Must provide either line_index OR (kernel_name + launch_id), not both. + If kernel_name is provided, the line_index parameter will be ignored and + recalculated from the kernel lookup. + Args: - input_path: Path to the NDJSON trace file. - line_index: 0-based index of the launch event to reproduce in the events list. + input_path: Path to ndjson file. Supports uncompressed (.ndjson), + gzip compressed (.ndjson.gz), and gzip member concatenation (.bin.ndjson) formats. + line_index: 0-based index in events list. Ignored if kernel_name is provided. out_dir: Output directory for reproducer files. template: Template name to use for the reproducer. + kernel_name: Exact kernel name to match (case-sensitive). If provided, line_index will be recalculated. + launch_id: 0-based launch index for the kernel (default: 0, first launch). replacer: Optional custom PlaceholderReplacer instance. If None, uses DefaultPlaceholderReplacer. kernel_import: Kernel import mode (DEFAULT or COPY). """ - logger.debug(f"Building bundle from {input_path} at line {line_index}") events = load_ndjson(Path(input_path)) logger.debug(f"Loaded {len(events)} events") + # If kernel_name is provided, lookup the actual line_index (overrides the parameter) + if kernel_name is not None: + logger.debug( + f"Looking up kernel '{kernel_name}' launch_id={launch_id} in {input_path}" + ) + line_index = find_launch_index_by_kernel(events, kernel_name, launch_id) + logger.debug( + f"Found kernel '{kernel_name}' launch_id={launch_id} at line {line_index}" + ) + + logger.debug(f"Building bundle from {input_path} at line {line_index}") + # Build context bundle from the specified launch event context_bundle = build_context_bundle(events, line_index) logger.debug( diff --git a/tritonparse/tools/prettify_ndjson.py b/tritonparse/tools/prettify_ndjson.py index 72fb5e8..863f574 100644 --- a/tritonparse/tools/prettify_ndjson.py +++ b/tritonparse/tools/prettify_ndjson.py @@ -39,12 +39,19 @@ """ import argparse +import gzip import json import sys from pathlib import Path from typing import Any, List, Union +def _is_gzip_file(file_path: Path) -> bool: + """Check if file is gzip compressed (.gz or .bin.ndjson).""" + path_str = str(file_path) + return path_str.endswith(".gz") or path_str.endswith(".bin.ndjson") + + def parse_line_ranges(lines_arg: str) -> set[int]: """ Parse line ranges from string like "1,2,3,5-10" into a set of line numbers. @@ -106,6 +113,9 @@ def load_ndjson( """ Load NDJSON file and return list of JSON objects. + Supports uncompressed (.ndjson), gzip compressed (.ndjson.gz), + and gzip member concatenation (.bin.ndjson) formats. + Args: file_path: Path to the NDJSON file not_save_irs: Whether to NOT save file_content and python_source for compilation events @@ -122,8 +132,13 @@ def load_ndjson( filtered_compilation_events = 0 total_lines_processed = 0 + # Determine if file is gzip compressed + is_compressed = _is_gzip_file(file_path) + opener = gzip.open if is_compressed else open + mode = "rt" if is_compressed else "r" + try: - with open(file_path, "r", encoding="utf-8") as f: + with opener(file_path, mode, encoding="utf-8") as f: # enumerate(f, 1) starts line numbering from 1 (1-based indexing) for line_num, line in enumerate(f, 1): line = line.strip()