|
| 1 | +#!/usr/bin/env python3 |
| 2 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 3 | + |
| 4 | +""" |
| 5 | +Extract IR files from NDJSON trace logs. |
| 6 | +
|
| 7 | +This script extracts intermediate representation (IR) files from a Triton trace NDJSON file. |
| 8 | +For compilation events, it extracts the IR files (ttir, ttgir, llir, ptx, etc.) contained in |
| 9 | +the file_content field and saves them as individual files. |
| 10 | +
|
| 11 | +Example: |
| 12 | + Extract IRs from line 0 (first line) of the NDJSON file: |
| 13 | + python extract_irs.py -i logs.ndjson --line 0 -o output_folder |
| 14 | +
|
| 15 | + Extract from line 5: |
| 16 | + python extract_irs.py -i logs.ndjson --line 5 -o ./irs |
| 17 | +
|
| 18 | +Usage: |
| 19 | + python extract_irs.py -i <input.ndjson> --line <line_number> -o <output_folder> |
| 20 | +""" |
| 21 | + |
| 22 | +import argparse |
| 23 | +import json |
| 24 | +import sys |
| 25 | +from pathlib import Path |
| 26 | +from typing import Any, Dict, Optional |
| 27 | + |
| 28 | + |
def read_ndjson_line(file_path: Path, line_number: int) -> Optional[Dict[str, Any]]:
    """
    Return the parsed JSON record stored on one line of an NDJSON file.

    Lines are counted from zero, so ``line_number=0`` selects the first line.
    Reading stops as soon as the requested line has been reached.

    Args:
        file_path: Location of the NDJSON file
        line_number: Zero-based index of the line to parse

    Returns:
        The decoded JSON value of that line. None is returned (after a
        message on stderr) when the line is blank or when the file has
        fewer lines than requested.

    Raises:
        FileNotFoundError: If file_path does not exist
        json.JSONDecodeError: If the selected line is not valid JSON
    """
    if not file_path.exists():
        raise FileNotFoundError(f"File not found: {file_path}")

    # Pull out only the requested line; the default None marks "ran off the
    # end of the file" (a real line is never None, even when it is blank).
    with open(file_path, "r", encoding="utf-8") as handle:
        raw_text = next(
            (text for index, text in enumerate(handle) if index == line_number),
            None,
        )

    if raw_text is None:
        print(
            f"Error: Line {line_number} not found in file (file has fewer lines)",
            file=sys.stderr,
        )
        return None

    record_text = raw_text.strip()
    if not record_text:
        print(f"Warning: Line {line_number} is empty", file=sys.stderr)
        return None

    try:
        return json.loads(record_text)
    except json.JSONDecodeError as e:
        # Report which line was bad, then let the caller decide what to do.
        print(f"Error: Invalid JSON on line {line_number}: {e}", file=sys.stderr)
        raise
| 66 | + |
| 67 | + |
def extract_irs(
    json_obj: Dict[str, Any], output_dir: Path, kernel_name: Optional[str] = None
) -> int:
    """
    Extract IR files from a JSON object and save them to the output directory.

    Every entry of ``payload["file_content"]`` is written to
    ``<output_dir>/<kernel_name>.<ext>``, where ``<ext>`` is the text after
    the last "." of the entry's key ("txt" when the key has no usable
    extension). String values are written verbatim; any other JSON value is
    serialized with ``json.dumps``. Entries whose keys share an extension map
    to the same output file, so the later one overwrites the earlier one.

    When ``payload["python_source"]["code"]`` holds a non-empty string it is
    additionally written to ``<kernel_name>_source.py`` behind a two-line
    header naming the original file path and line range.

    Args:
        json_obj: Parsed JSON object from the NDJSON file
        output_dir: Directory to save the extracted IR files (created if missing)
        kernel_name: Optional kernel name to use for file naming (overrides metadata.name)

    Returns:
        Number of files successfully written. A file that fails with OSError
        is reported on stderr and not counted.

    Raises:
        ValueError: If the JSON value is not a compilation event, or if
            'payload' / 'file_content' is missing or not a JSON object
    """
    # A line of NDJSON may legally hold any JSON value (list, number, ...);
    # reject those with the documented ValueError instead of AttributeError.
    if not isinstance(json_obj, dict):
        raise ValueError(
            f"Expected a JSON object, got {type(json_obj).__name__}"
        )

    event_type = json_obj.get("event_type")
    if event_type != "compilation":
        raise ValueError(f"Not a compilation event (event_type: {event_type})")

    payload = json_obj.get("payload")
    if not payload:
        raise ValueError("Missing 'payload' field in JSON object")
    if not isinstance(payload, dict):
        raise ValueError("'payload' field must be a JSON object")

    file_content = payload.get("file_content")
    if not file_content:
        raise ValueError("Missing 'file_content' field in payload")
    if not isinstance(file_content, dict):
        raise ValueError("'file_content' field must be a JSON object")

    # Determine kernel name. "metadata" or its "name" may be absent or null,
    # in which case the generic "kernel" is used.
    if kernel_name is None:
        metadata = payload.get("metadata")
        name = metadata.get("name") if isinstance(metadata, dict) else None
        kernel_name = str(name) if name else "kernel"

    output_dir.mkdir(parents=True, exist_ok=True)

    files_extracted = 0
    for file_key, content in file_content.items():
        # file_key is typically like "embedding_forward_kernel.ttir"; only the
        # extension is kept. Keys without a dot, or ending in one, get "txt".
        _, dot, suffix = file_key.rpartition(".")
        extension = suffix if dot and suffix else "txt"
        output_path = output_dir / f"{kernel_name}.{extension}"

        # Serialize before opening the file so a non-string value can neither
        # abort the run with TypeError nor leave an empty file behind.
        text = content if isinstance(content, str) else json.dumps(content, indent=2)

        try:
            with open(output_path, "w", encoding="utf-8") as f:
                f.write(text)
        except OSError as e:
            print(f"Error writing file {output_path}: {e}", file=sys.stderr)
            continue
        print(f"Extracted: {output_path}")
        files_extracted += 1

    # Optionally extract Python source code
    python_source = payload.get("python_source")
    if isinstance(python_source, dict):
        source_code = python_source.get("code")
        if isinstance(source_code, str) and source_code:
            output_path = output_dir / f"{kernel_name}_source.py"
            # Header comment with the original file path and line range
            file_path_info = python_source.get("file_path", "unknown")
            start_line = python_source.get("start_line", "?")
            end_line = python_source.get("end_line", "?")
            header = f"# Source: {file_path_info}\n# Lines: {start_line}-{end_line}\n\n"
            try:
                with open(output_path, "w", encoding="utf-8") as f:
                    f.write(header)
                    f.write(source_code)
            except OSError as e:
                print(
                    f"Error writing Python source file {output_path}: {e}",
                    file=sys.stderr,
                )
            else:
                print(f"Extracted Python source: {output_path}")
                files_extracted += 1

    return files_extracted
| 155 | + |
| 156 | + |
def main() -> None:
    """
    Parse command line arguments and run the IR extraction.

    Reads the requested line of the input NDJSON file, extracts its IR files
    into the output directory and prints a summary. Every failure is reported
    on stderr and terminates the process with exit status 1: a negative line
    number, a missing input file, a missing or blank line, invalid JSON, a
    record that is not a usable compilation event, any unexpected error, or
    an extraction that wrote no files at all.
    """
    parser = argparse.ArgumentParser(
        description="Extract IR files from Triton trace NDJSON logs",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  Extract IRs from line 0 (first line):
    python extract_irs.py -i logs.ndjson --line 0 -o output_folder

  Extract from line 5:
    python extract_irs.py -i logs.ndjson --line 5 -o ./irs

  Specify custom kernel name:
    python extract_irs.py -i logs.ndjson --line 0 -o ./irs --kernel-name my_kernel
        """,
    )

    parser.add_argument(
        "-i", "--input", type=str, required=True, help="Path to the input NDJSON file"
    )

    parser.add_argument(
        "--line",
        type=int,
        required=True,
        help="Line number to extract (0-based indexing, where 0 = first line)",
    )

    parser.add_argument(
        "-o",
        "--output",
        type=str,
        required=True,
        help="Output directory to save extracted IR files",
    )

    parser.add_argument(
        "--kernel-name",
        type=str,
        help="Custom kernel name for output files (default: use metadata.name from JSON)",
    )

    args = parser.parse_args()

    # Validate line number
    if args.line < 0:
        print(
            f"Error: Line number must be non-negative (got {args.line})",
            file=sys.stderr,
        )
        sys.exit(1)

    input_path = Path(args.input)
    output_dir = Path(args.output)

    try:
        print(f"Reading line {args.line} from {input_path}...")
        json_obj = read_ndjson_line(input_path, args.line)

        if json_obj is None:
            print("Error: Failed to read JSON from specified line", file=sys.stderr)
            sys.exit(1)

        print(f"Extracting IRs to {output_dir}...")
        num_files = extract_irs(json_obj, output_dir, args.kernel_name)

    except FileNotFoundError as e:
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)
    except json.JSONDecodeError as e:
        # Must come before ValueError: JSONDecodeError is a ValueError
        # subclass, so the other order makes this handler unreachable.
        print(f"Error: Failed to parse JSON - {e}", file=sys.stderr)
        sys.exit(1)
    except ValueError as e:
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)
    except Exception as e:
        # Top-level boundary of the CLI: report anything else and exit non-zero.
        print(f"Unexpected error: {e}", file=sys.stderr)
        sys.exit(1)

    # Nothing written means every write failed (each failure was already
    # reported on stderr); do not claim success in that case.
    if num_files == 0:
        print(f"Error: No files were extracted to {output_dir}", file=sys.stderr)
        sys.exit(1)

    print(f"\nSuccess! Extracted {num_files} file(s) to {output_dir}")
| 241 | + |
| 242 | + |
# Run the command line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
0 commit comments