Skip to content

Commit 8b20269

Browse files
authored
Merge pull request #96 from mlcommons/generator-more-traces
Add methods to generate various trace nodes in generator
2 parents 905f22f + cd6ea3a commit 8b20269

File tree

1 file changed

+50
-13
lines changed

1 file changed

+50
-13
lines changed

src/generator/generator.py

Lines changed: 50 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,11 @@
44
ALL_GATHER,
55
ALL_REDUCE,
66
ALL_TO_ALL,
7+
BARRIER,
8+
BROADCAST,
79
COMM_COLL_NODE,
10+
COMM_RECV_NODE,
11+
COMM_SEND_NODE,
812
COMP_NODE,
913
MEM_LOAD_NODE,
1014
MEM_STORE_NODE,
@@ -114,9 +118,8 @@ def one_remote_mem_load_node(num_npus: int, tensor_size: int) -> None:
114118
encode_message(et, GlobalMetadata(version="0.0.4"))
115119

116120
node = get_node("MEM_LOAD_NODE", MEM_LOAD_NODE)
117-
node.attr.extend(
118-
[ChakraAttr(name="is_cpu_op", bool_val=False), ChakraAttr(name="tensor_size", uint64_val=tensor_size)]
119-
)
121+
node.attr.append(ChakraAttr(name="is_cpu_op", bool_val=False))
122+
node.attr.append(ChakraAttr(name="tensor_size", uint64_val=tensor_size))
120123
encode_message(et, node)
121124

122125

@@ -128,9 +131,8 @@ def one_remote_mem_store_node(num_npus: int, tensor_size: int) -> None:
128131
encode_message(et, GlobalMetadata(version="0.0.4"))
129132

130133
node = get_node("MEM_STORE_NODE", MEM_STORE_NODE)
131-
node.attr.extend(
132-
[ChakraAttr(name="is_cpu_op", bool_val=False), ChakraAttr(name="tensor_size", uint64_val=tensor_size)]
133-
)
134+
node.attr.append(ChakraAttr(name="is_cpu_op", bool_val=False))
135+
node.attr.append(ChakraAttr(name="tensor_size", uint64_val=tensor_size))
134136
encode_message(et, node)
135137

136138

@@ -188,13 +190,8 @@ def generate_comm_coll_node(num_npus: int, comm_size: int, comm_type: int, node_
188190
encode_message(et, GlobalMetadata(version="0.0.4"))
189191

190192
node = get_node(node_name, COMM_COLL_NODE)
191-
node.attr.extend(
192-
[
193-
ChakraAttr(name="is_cpu_op", bool_val=False),
194-
get_comm_type_attr(comm_type),
195-
ChakraAttr(name="comm_size", uint64_val=comm_size),
196-
]
197-
)
193+
node.attr.append(ChakraAttr(name="is_cpu_op", bool_val=False))
194+
node.attr.extend([get_comm_type_attr(comm_type), ChakraAttr(name="comm_size", uint64_val=comm_size)])
198195
encode_message(et, node)
199196

200197

@@ -218,6 +215,42 @@ def one_comm_coll_node_reducescatter(num_npus: int, comm_size: int) -> None:
218215
generate_comm_coll_node(num_npus, comm_size, REDUCE_SCATTER, "REDUCE_SCATTER")
219216

220217

218+
def one_comm_coll_node_broadcast(num_npus: int, comm_size: int) -> None:
219+
"""Generate one Broadcast communication collective node."""
220+
generate_comm_coll_node(num_npus, comm_size, BROADCAST, "BROADCAST")
221+
222+
223+
def one_comm_coll_node_barrier(num_npus: int) -> None:
224+
"""Generate one Barrier communication collective node."""
225+
generate_comm_coll_node(num_npus, comm_size=0, comm_type=BARRIER, node_name="BARRIER")
226+
227+
228+
def one_comm_send_node(num_npus: int, tensor_size: int) -> None:
229+
"""Generate communication send nodes."""
230+
for npu_id in range(num_npus):
231+
output_filename = f"one_comm_send_node.{npu_id}.et"
232+
with open(output_filename, "wb") as et:
233+
encode_message(et, GlobalMetadata(version="0.0.4"))
234+
235+
node = get_node("COMM_SEND_NODE", COMM_SEND_NODE)
236+
node.attr.append(ChakraAttr(name="is_cpu_op", bool_val=False))
237+
node.attr.append(ChakraAttr(name="tensor_size", uint64_val=tensor_size))
238+
encode_message(et, node)
239+
240+
241+
def one_comm_recv_node(num_npus: int, tensor_size: int) -> None:
242+
"""Generate communication receive nodes."""
243+
for npu_id in range(num_npus):
244+
output_filename = f"one_comm_recv_node.{npu_id}.et"
245+
with open(output_filename, "wb") as et:
246+
encode_message(et, GlobalMetadata(version="0.0.4"))
247+
248+
node = get_node("COMM_RECV_NODE", COMM_RECV_NODE)
249+
node.attr.append(ChakraAttr(name="is_cpu_op", bool_val=False))
250+
node.attr.append(ChakraAttr(name="tensor_size", uint64_val=tensor_size))
251+
encode_message(et, node)
252+
253+
221254
def main() -> None:
222255
parser = argparse.ArgumentParser(description="Execution Trace Generator")
223256
parser.add_argument("--num_npus", type=int, default=64, help="Number of NPUs")
@@ -238,6 +271,10 @@ def main() -> None:
238271
one_comm_coll_node_alltoall(args.num_npus, args.default_comm_size)
239272
one_comm_coll_node_allgather(args.num_npus, args.default_comm_size)
240273
one_comm_coll_node_reducescatter(args.num_npus, args.default_comm_size)
274+
one_comm_coll_node_broadcast(args.num_npus, args.default_comm_size)
275+
one_comm_coll_node_barrier(args.num_npus)
276+
one_comm_send_node(args.num_npus, args.default_tensor_size)
277+
one_comm_recv_node(args.num_npus, args.default_tensor_size)
241278

242279

243280
if __name__ == "__main__":

0 commit comments

Comments
 (0)