diff --git a/generate_summary.py b/generate_summary.py index a66f094..6412527 100644 --- a/generate_summary.py +++ b/generate_summary.py @@ -1,6 +1,7 @@ import os import sys import argparse +import re def get_script_commands(script_file): fs = open(script_file, 'r') @@ -35,18 +36,37 @@ def parse_nccl_performance(useful_lines, commands): perf_lines = [] perf_lines.append("sep=|") - perf_lines.append("size|count|type|redop|time-oplace(us)|algbw(gb/s)-oplace|busbw(gb/s)-oplace|error|" + \ - "time-iplace(us)|algbw(gb/s)-iplace|busbw(gb/s)-iplace|error|avg_bus_bw|commands") + header = "size|count|type|redop|root|time-oplace(us)|algbw(gb/s)-oplace|busbw(gb/s)-oplace|error|" + \ + "time-iplace(us)|algbw(gb/s)-iplace|busbw(gb/s)-iplace|error|avg_bus_bw|commands" + #print(header) + num_fields = len(header.split("|")) + perf_lines.append(header) for j in range(len(useful_lines)): line = useful_lines[j] line = line.replace("# Avg bus bandwidth : ", "") split_list = line.split() perf_line = "" + field_index = 0 for i in range(len(split_list)): perf_line = perf_line + split_list[i] + "|" + # Some collectives do not involve a redop + if field_index==2 and "reduce" not in commands[j].lower(): + perf_line = perf_line + "|" + field_index = field_index + 1 + # Only broadcast and reduce involve a root + if ( + field_index==3 and + re.search(r'\Wreduce_perf', commands[j]) is None and + re.search(r'\Wbroadcast_perf', commands[j]) is None + ): + perf_line = perf_line + "|" + field_index = field_index + 1 + field_index = field_index + 1 #print (perf_line + commands[j]) - perf_lines.append(perf_line + commands[j]) + perf_line = perf_line + commands[j] + assert len(perf_line.split("|")) == num_fields + perf_lines.append(perf_line) return perf_lines diff --git a/rccl_nccl_parser.py b/rccl_nccl_parser.py index 8d1f974..8e12a84 100644 --- a/rccl_nccl_parser.py +++ b/rccl_nccl_parser.py @@ -3,8 +3,17 @@ import argparse coll_op_map = { - "AllReduce": "all_reduce_perf", "Broadcast": "broadcast_perf", + "Reduce": "reduce_perf", + "AllGather": "all_gather_perf", + "ReduceScatter": "reduce_scatter_perf", + "AllReduce": "all_reduce_perf", + "Gather": "gather_perf", + "Scatter": "scatter_perf", + "AllToAll": "alltoall_perf", +# "AllToAllv": "alltoallv_perf", + "Send": "sendrecv_perf", + "Recv": "sendrecv_perf", } reduction_op_map = { @@ -62,12 +71,12 @@ def parse_nccl_log(nccl_lines): for j in range(len(nccl_lines)): line = nccl_lines[j] split_list = line.split(" ") - comm = split_list[4].replace(":", "") - count = split_list[12] - datatype = split_list[14] - op_type = split_list[16] - root = split_list[18] - nnranks = split_list[21].split("=")[1].replace("]", "") + comm = split_list[split_list.index("INFO") + 1].replace(":", "") + count = split_list[split_list.index("count") + 1] + datatype = split_list[split_list.index("datatype") + 1] + op_type = split_list[split_list.index("op") + 1] + root = split_list[split_list.index("root") + 1] + nnranks = next(item for item in split_list if 'nranks' in item).split("=")[1].replace("]", "") #print (comm) #print (count)