Skip to content

Commit 905f22f

Browse files
authored
Merge pull request #95 from mlcommons/generator-refactor
Improve maintainability and PEP 8 compliance for Chakra trace generator
2 parents 3d5b1da + c25e99c commit 905f22f

File tree

1 file changed

+80
-114
lines changed

1 file changed

+80
-114
lines changed

src/generator/generator.py

Lines changed: 80 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242

4343

4444
def get_node(node_name: str, node_type: ChakraNodeType) -> ChakraNode:
45+
"""Generate a new ChakraNode with a unique ID."""
4546
global NODE_ID
4647
node = ChakraNode()
4748
node.id = NODE_ID
@@ -52,105 +53,89 @@ def get_node(node_name: str, node_type: ChakraNodeType) -> ChakraNode:
5253

5354

5455
def get_comm_type_attr(comm_type: int) -> ChakraAttr:
56+
"""Create a communication type attribute."""
5557
return ChakraAttr(name="comm_type", int64_val=comm_type)
5658

5759

5860
def one_metadata_node_all_types(num_npus: int) -> None:
61+
"""Generate metadata nodes with all types of attributes."""
5962
for npu_id in range(num_npus):
6063
output_filename = f"one_metadata_node_all_types.{npu_id}.et"
6164
with open(output_filename, "wb") as et:
6265
encode_message(et, GlobalMetadata(version="0.0.4"))
6366

6467
node = get_node("METADATA_NODE", METADATA_NODE)
65-
66-
node.attr.append(ChakraAttr(name="double", double_val=1.2345, doc_string="double"))
67-
double_list = DoubleList(values=[1.2345, 2.3456])
68-
node.attr.append(ChakraAttr(name="double_list", double_list=double_list))
69-
70-
node.attr.append(ChakraAttr(name="float", float_val=1.2345, doc_string="float"))
71-
float_list = FloatList(values=[1.2345, 2.3456])
72-
node.attr.append(ChakraAttr(name="float_list", float_list=float_list))
73-
74-
node.attr.append(ChakraAttr(name="int32", int32_val=12345, doc_string="int32"))
75-
int32_list = Int32List(values=[12345, 23456])
76-
node.attr.append(ChakraAttr(name="int32_list", int32_list=int32_list))
77-
78-
node.attr.append(ChakraAttr(name="int64", int64_val=9876543210, doc_string="int64"))
79-
int64_list = Int64List(values=[9876543210, 1234567890])
80-
node.attr.append(ChakraAttr(name="int64_list", int64_list=int64_list))
81-
82-
node.attr.append(ChakraAttr(name="uint32", uint32_val=12345, doc_string="uint32"))
83-
uint32_list = Uint32List(values=[12345, 23456])
84-
node.attr.append(ChakraAttr(name="uint32_list", uint32_list=uint32_list))
85-
86-
node.attr.append(ChakraAttr(name="uint64", uint64_val=9876543210, doc_string="uint64"))
87-
uint64_list = Uint64List(values=[9876543210, 1234567890])
88-
node.attr.append(ChakraAttr(name="uint64_list", uint64_list=uint64_list))
89-
90-
node.attr.append(ChakraAttr(name="sint32", sint32_val=-12345, doc_string="sint32"))
91-
sint32_list = Sint32List(values=[12345, -23456])
92-
node.attr.append(ChakraAttr(name="sint32_list", sint32_list=sint32_list))
93-
94-
node.attr.append(ChakraAttr(name="sint64", sint64_val=-9876543210, doc_string="sint64"))
95-
sint64_list = Sint64List(values=[9876543210, -1234567890])
96-
node.attr.append(ChakraAttr(name="sint64_list", sint64_list=sint64_list))
97-
98-
node.attr.append(ChakraAttr(name="fixed32", fixed32_val=12345))
99-
fixed32_list = Fixed32List(values=[12345, 23456])
100-
node.attr.append(ChakraAttr(name="fixed32_list", fixed32_list=fixed32_list))
101-
102-
node.attr.append(ChakraAttr(name="fixed64", fixed64_val=9876543210))
103-
fixed64_list = Fixed64List(values=[9876543210, 1234567890])
104-
node.attr.append(ChakraAttr(name="fixed64_list", fixed64_list=fixed64_list))
105-
106-
node.attr.append(ChakraAttr(name="sfixed32", sfixed32_val=-12345))
107-
sfixed32_list = Sfixed32List(values=[12345, -23456])
108-
node.attr.append(ChakraAttr(name="sfixed32_list", sfixed32_list=sfixed32_list))
109-
110-
node.attr.append(ChakraAttr(name="sfixed64", sfixed64_val=-9876543210))
111-
sfixed64_list = Sfixed64List(values=[9876543210, -1234567890])
112-
node.attr.append(ChakraAttr(name="sfixed64_list", sfixed64_list=sfixed64_list))
113-
114-
node.attr.append(ChakraAttr(name="bool", bool_val=True, doc_string="bool"))
115-
bool_list = BoolList(values=[i % 2 == 0 for i in range(10)])
116-
node.attr.append(ChakraAttr(name="bool_list", bool_list=bool_list))
117-
118-
node.attr.append(ChakraAttr(name="string", string_val="12345", doc_string="string"))
119-
string_list = StringList(values=[str(12345 + i) for i in range(10)])
120-
node.attr.append(ChakraAttr(name="string_list", string_list=string_list))
121-
122-
node.attr.append(ChakraAttr(name="bytes", bytes_val=bytes("12345", "utf-8")))
123-
bytes_list = BytesList(values=[bytes(str(12345 + i), "utf-8") for i in range(10)])
124-
node.attr.append(ChakraAttr(name="bytes_list", bytes_list=bytes_list))
68+
node.attr.extend(
69+
[
70+
ChakraAttr(name="double", double_val=1.2345, doc_string="double"),
71+
ChakraAttr(name="double_list", double_list=DoubleList(values=[1.2345, 2.3456])),
72+
ChakraAttr(name="float", float_val=1.2345, doc_string="float"),
73+
ChakraAttr(name="float_list", float_list=FloatList(values=[1.2345, 2.3456])),
74+
ChakraAttr(name="int32", int32_val=12345, doc_string="int32"),
75+
ChakraAttr(name="int32_list", int32_list=Int32List(values=[12345, 23456])),
76+
ChakraAttr(name="int64", int64_val=9876543210, doc_string="int64"),
77+
ChakraAttr(name="int64_list", int64_list=Int64List(values=[9876543210, 1234567890])),
78+
ChakraAttr(name="uint32", uint32_val=12345, doc_string="uint32"),
79+
ChakraAttr(name="uint32_list", uint32_list=Uint32List(values=[12345, 23456])),
80+
ChakraAttr(name="uint64", uint64_val=9876543210, doc_string="uint64"),
81+
ChakraAttr(name="uint64_list", uint64_list=Uint64List(values=[9876543210, 1234567890])),
82+
ChakraAttr(name="sint32", sint32_val=-12345, doc_string="sint32"),
83+
ChakraAttr(name="sint32_list", sint32_list=Sint32List(values=[12345, -23456])),
84+
ChakraAttr(name="sint64", sint64_val=-9876543210, doc_string="sint64"),
85+
ChakraAttr(name="sint64_list", sint64_list=Sint64List(values=[9876543210, -1234567890])),
86+
ChakraAttr(name="fixed32", fixed32_val=12345),
87+
ChakraAttr(name="fixed32_list", fixed32_list=Fixed32List(values=[12345, 23456])),
88+
ChakraAttr(name="fixed64", fixed64_val=9876543210),
89+
ChakraAttr(name="fixed64_list", fixed64_list=Fixed64List(values=[9876543210, 1234567890])),
90+
ChakraAttr(name="sfixed32", sfixed32_val=-12345),
91+
ChakraAttr(name="sfixed32_list", sfixed32_list=Sfixed32List(values=[12345, -23456])),
92+
ChakraAttr(name="sfixed64", sfixed64_val=-9876543210),
93+
ChakraAttr(name="sfixed64_list", sfixed64_list=Sfixed64List(values=[9876543210, -1234567890])),
94+
ChakraAttr(name="bool", bool_val=True, doc_string="bool"),
95+
ChakraAttr(name="bool_list", bool_list=BoolList(values=[i % 2 == 0 for i in range(10)])),
96+
ChakraAttr(name="string", string_val="12345", doc_string="string"),
97+
ChakraAttr(name="string_list", string_list=StringList(values=[str(12345 + i) for i in range(10)])),
98+
ChakraAttr(name="bytes", bytes_val=bytes("12345", "utf-8")),
99+
ChakraAttr(
100+
name="bytes_list",
101+
bytes_list=BytesList(values=[bytes(str(12345 + i), "utf-8") for i in range(10)]),
102+
),
103+
]
104+
)
125105

126106
encode_message(et, node)
127107

128108

129109
def one_remote_mem_load_node(num_npus: int, tensor_size: int) -> None:
110+
"""Generate remote memory load nodes."""
130111
for npu_id in range(num_npus):
131112
output_filename = f"one_remote_mem_load_node.{npu_id}.et"
132113
with open(output_filename, "wb") as et:
133114
encode_message(et, GlobalMetadata(version="0.0.4"))
134115

135116
node = get_node("MEM_LOAD_NODE", MEM_LOAD_NODE)
136-
node.attr.append(ChakraAttr(name="is_cpu_op", bool_val=False))
137-
node.attr.append(ChakraAttr(name="tensor_size", uint64_val=tensor_size))
117+
node.attr.extend(
118+
[ChakraAttr(name="is_cpu_op", bool_val=False), ChakraAttr(name="tensor_size", uint64_val=tensor_size)]
119+
)
138120
encode_message(et, node)
139121

140122

141123
def one_remote_mem_store_node(num_npus: int, tensor_size: int) -> None:
124+
"""Generate remote memory store nodes."""
142125
for npu_id in range(num_npus):
143126
output_filename = f"one_remote_mem_store_node.{npu_id}.et"
144127
with open(output_filename, "wb") as et:
145128
encode_message(et, GlobalMetadata(version="0.0.4"))
146129

147130
node = get_node("MEM_STORE_NODE", MEM_STORE_NODE)
148-
node.attr.append(ChakraAttr(name="is_cpu_op", bool_val=False))
149-
node.attr.append(ChakraAttr(name="tensor_size", uint64_val=tensor_size))
131+
node.attr.extend(
132+
[ChakraAttr(name="is_cpu_op", bool_val=False), ChakraAttr(name="tensor_size", uint64_val=tensor_size)]
133+
)
150134
encode_message(et, node)
151135

152136

153137
def one_comp_node(num_npus: int, runtime: int) -> None:
138+
"""Generate computation nodes with a given runtime."""
154139
for npu_id in range(num_npus):
155140
output_filename = f"one_comp_node.{npu_id}.et"
156141
with open(output_filename, "wb") as et:
@@ -163,90 +148,74 @@ def one_comp_node(num_npus: int, runtime: int) -> None:
163148

164149

165150
def two_comp_nodes_independent(num_npus: int, runtime: int) -> None:
151+
"""Generate two independent computation nodes."""
166152
for npu_id in range(num_npus):
167153
output_filename = f"two_comp_nodes_independent.{npu_id}.et"
168154
with open(output_filename, "wb") as et:
169155
encode_message(et, GlobalMetadata(version="0.0.4"))
170156

171-
node = get_node("COMP_NODE", COMP_NODE)
172-
node.duration_micros = runtime
173-
node.attr.append(ChakraAttr(name="is_cpu_op", bool_val=False))
174-
encode_message(et, node)
175-
176-
node = get_node("COMP_NODE", COMP_NODE)
177-
node.duration_micros = runtime
178-
node.attr.append(ChakraAttr(name="is_cpu_op", bool_val=False))
179-
encode_message(et, node)
157+
for _ in range(2):
158+
node = get_node("COMP_NODE", COMP_NODE)
159+
node.attr.append(ChakraAttr(name="is_cpu_op", bool_val=False))
160+
node.duration_micros = runtime
161+
encode_message(et, node)
180162

181163

182164
def two_comp_nodes_dependent(num_npus: int, runtime: int) -> None:
165+
"""Generate two dependent computation nodes."""
183166
for npu_id in range(num_npus):
184167
output_filename = f"two_comp_nodes_dependent.{npu_id}.et"
185168
with open(output_filename, "wb") as et:
186169
encode_message(et, GlobalMetadata(version="0.0.4"))
187170

188171
parent_node = get_node("COMP_NODE", COMP_NODE)
189-
parent_node.duration_micros = runtime
190172
parent_node.attr.append(ChakraAttr(name="is_cpu_op", bool_val=False))
173+
parent_node.duration_micros = runtime
191174
encode_message(et, parent_node)
192175

193176
child_node = get_node("COMP_NODE", COMP_NODE)
177+
child_node.attr.append(ChakraAttr(name="is_cpu_op", bool_val=False))
194178
child_node.duration_micros = runtime
195179
child_node.data_deps.append(parent_node.id)
196-
child_node.attr.append(ChakraAttr(name="is_cpu_op", bool_val=False))
197180
encode_message(et, child_node)
198181

199182

200-
def one_comm_coll_node_allreduce(num_npus: int, comm_size: int) -> None:
183+
def generate_comm_coll_node(num_npus: int, comm_size: int, comm_type: int, node_name: str) -> None:
184+
"""Generate communication collective nodes."""
201185
for npu_id in range(num_npus):
202-
output_filename = f"one_comm_coll_node_allreduce.{npu_id}.et"
186+
output_filename = f"{node_name}.{npu_id}.et"
203187
with open(output_filename, "wb") as et:
204188
encode_message(et, GlobalMetadata(version="0.0.4"))
205189

206-
node = get_node("ALL_REDUCE", COMM_COLL_NODE)
207-
node.attr.append(ChakraAttr(name="is_cpu_op", bool_val=False))
208-
node.attr.append(get_comm_type_attr(ALL_REDUCE))
209-
node.attr.append(ChakraAttr(name="comm_size", uint64_val=comm_size))
190+
node = get_node(node_name, COMM_COLL_NODE)
191+
node.attr.extend(
192+
[
193+
ChakraAttr(name="is_cpu_op", bool_val=False),
194+
get_comm_type_attr(comm_type),
195+
ChakraAttr(name="comm_size", uint64_val=comm_size),
196+
]
197+
)
210198
encode_message(et, node)
211199

212200

213-
def one_comm_coll_node_alltoall(num_npus: int, comm_size: int) -> None:
214-
for npu_id in range(num_npus):
215-
output_filename = f"one_comm_coll_node_alltoall.{npu_id}.et"
216-
with open(output_filename, "wb") as et:
217-
encode_message(et, GlobalMetadata(version="0.0.4"))
201+
def one_comm_coll_node_allreduce(num_npus: int, comm_size: int) -> None:
202+
"""Generate one AllReduce communication collective node."""
203+
generate_comm_coll_node(num_npus, comm_size, ALL_REDUCE, "ALL_REDUCE")
218204

219-
node = get_node("ALL_TO_ALL", COMM_COLL_NODE)
220-
node.attr.append(ChakraAttr(name="is_cpu_op", bool_val=False))
221-
node.attr.append(get_comm_type_attr(ALL_TO_ALL))
222-
node.attr.append(ChakraAttr(name="comm_size", uint64_val=comm_size))
223-
encode_message(et, node)
224205

206+
def one_comm_coll_node_alltoall(num_npus: int, comm_size: int) -> None:
207+
"""Generate one AllToAll communication collective node."""
208+
generate_comm_coll_node(num_npus, comm_size, ALL_TO_ALL, "ALL_TO_ALL")
225209

226-
def one_comm_coll_node_allgather(num_npus: int, comm_size: int) -> None:
227-
for npu_id in range(num_npus):
228-
output_filename = f"one_comm_coll_node_allgather.{npu_id}.et"
229-
with open(output_filename, "wb") as et:
230-
encode_message(et, GlobalMetadata(version="0.0.4"))
231210

232-
node = get_node("ALL_GATHER", COMM_COLL_NODE)
233-
node.attr.append(ChakraAttr(name="is_cpu_op", bool_val=False))
234-
node.attr.append(get_comm_type_attr(ALL_GATHER))
235-
node.attr.append(ChakraAttr(name="comm_size", uint64_val=comm_size))
236-
encode_message(et, node)
211+
def one_comm_coll_node_allgather(num_npus: int, comm_size: int) -> None:
212+
"""Generate one AllGather communication collective node."""
213+
generate_comm_coll_node(num_npus, comm_size, ALL_GATHER, "ALL_GATHER")
237214

238215

239216
def one_comm_coll_node_reducescatter(num_npus: int, comm_size: int) -> None:
240-
for npu_id in range(num_npus):
241-
output_filename = f"one_comm_coll_node_reducescatter.{npu_id}.et"
242-
with open(output_filename, "wb") as et:
243-
encode_message(et, GlobalMetadata(version="0.0.4"))
244-
245-
node = get_node("REDUCE_SCATTER", COMM_COLL_NODE)
246-
node.attr.append(ChakraAttr(name="is_cpu_op", bool_val=False))
247-
node.attr.append(get_comm_type_attr(REDUCE_SCATTER))
248-
node.attr.append(ChakraAttr(name="comm_size", uint64_val=comm_size))
249-
encode_message(et, node)
217+
"""Generate one ReduceScatter communication collective node."""
218+
generate_comm_coll_node(num_npus, comm_size, REDUCE_SCATTER, "REDUCE_SCATTER")
250219

251220

252221
def main() -> None:
@@ -260,14 +229,11 @@ def main() -> None:
260229
args = parser.parse_args()
261230

262231
one_metadata_node_all_types(args.num_npus)
263-
264232
one_remote_mem_load_node(args.num_npus, args.default_tensor_size)
265233
one_remote_mem_store_node(args.num_npus, args.default_tensor_size)
266-
267234
one_comp_node(args.num_npus, args.default_runtime)
268235
two_comp_nodes_independent(args.num_npus, args.default_runtime)
269236
two_comp_nodes_dependent(args.num_npus, args.default_runtime)
270-
271237
one_comm_coll_node_allreduce(args.num_npus, args.default_comm_size)
272238
one_comm_coll_node_alltoall(args.num_npus, args.default_comm_size)
273239
one_comm_coll_node_allgather(args.num_npus, args.default_comm_size)

0 commit comments

Comments
 (0)