4242
4343
4444def get_node (node_name : str , node_type : ChakraNodeType ) -> ChakraNode :
45+ """Generate a new ChakraNode with a unique ID."""
4546 global NODE_ID
4647 node = ChakraNode ()
4748 node .id = NODE_ID
@@ -52,105 +53,89 @@ def get_node(node_name: str, node_type: ChakraNodeType) -> ChakraNode:
5253
5354
5455def get_comm_type_attr (comm_type : int ) -> ChakraAttr :
56+ """Create a communication type attribute."""
5557 return ChakraAttr (name = "comm_type" , int64_val = comm_type )
5658
5759
5860def one_metadata_node_all_types (num_npus : int ) -> None :
61+ """Generate metadata nodes with all types of attributes."""
5962 for npu_id in range (num_npus ):
6063 output_filename = f"one_metadata_node_all_types.{ npu_id } .et"
6164 with open (output_filename , "wb" ) as et :
6265 encode_message (et , GlobalMetadata (version = "0.0.4" ))
6366
6467 node = get_node ("METADATA_NODE" , METADATA_NODE )
65-
66- node .attr .append (ChakraAttr (name = "double" , double_val = 1.2345 , doc_string = "double" ))
67- double_list = DoubleList (values = [1.2345 , 2.3456 ])
68- node .attr .append (ChakraAttr (name = "double_list" , double_list = double_list ))
69-
70- node .attr .append (ChakraAttr (name = "float" , float_val = 1.2345 , doc_string = "float" ))
71- float_list = FloatList (values = [1.2345 , 2.3456 ])
72- node .attr .append (ChakraAttr (name = "float_list" , float_list = float_list ))
73-
74- node .attr .append (ChakraAttr (name = "int32" , int32_val = 12345 , doc_string = "int32" ))
75- int32_list = Int32List (values = [12345 , 23456 ])
76- node .attr .append (ChakraAttr (name = "int32_list" , int32_list = int32_list ))
77-
78- node .attr .append (ChakraAttr (name = "int64" , int64_val = 9876543210 , doc_string = "int64" ))
79- int64_list = Int64List (values = [9876543210 , 1234567890 ])
80- node .attr .append (ChakraAttr (name = "int64_list" , int64_list = int64_list ))
81-
82- node .attr .append (ChakraAttr (name = "uint32" , uint32_val = 12345 , doc_string = "uint32" ))
83- uint32_list = Uint32List (values = [12345 , 23456 ])
84- node .attr .append (ChakraAttr (name = "uint32_list" , uint32_list = uint32_list ))
85-
86- node .attr .append (ChakraAttr (name = "uint64" , uint64_val = 9876543210 , doc_string = "uint64" ))
87- uint64_list = Uint64List (values = [9876543210 , 1234567890 ])
88- node .attr .append (ChakraAttr (name = "uint64_list" , uint64_list = uint64_list ))
89-
90- node .attr .append (ChakraAttr (name = "sint32" , sint32_val = - 12345 , doc_string = "sint32" ))
91- sint32_list = Sint32List (values = [12345 , - 23456 ])
92- node .attr .append (ChakraAttr (name = "sint32_list" , sint32_list = sint32_list ))
93-
94- node .attr .append (ChakraAttr (name = "sint64" , sint64_val = - 9876543210 , doc_string = "sint64" ))
95- sint64_list = Sint64List (values = [9876543210 , - 1234567890 ])
96- node .attr .append (ChakraAttr (name = "sint64_list" , sint64_list = sint64_list ))
97-
98- node .attr .append (ChakraAttr (name = "fixed32" , fixed32_val = 12345 ))
99- fixed32_list = Fixed32List (values = [12345 , 23456 ])
100- node .attr .append (ChakraAttr (name = "fixed32_list" , fixed32_list = fixed32_list ))
101-
102- node .attr .append (ChakraAttr (name = "fixed64" , fixed64_val = 9876543210 ))
103- fixed64_list = Fixed64List (values = [9876543210 , 1234567890 ])
104- node .attr .append (ChakraAttr (name = "fixed64_list" , fixed64_list = fixed64_list ))
105-
106- node .attr .append (ChakraAttr (name = "sfixed32" , sfixed32_val = - 12345 ))
107- sfixed32_list = Sfixed32List (values = [12345 , - 23456 ])
108- node .attr .append (ChakraAttr (name = "sfixed32_list" , sfixed32_list = sfixed32_list ))
109-
110- node .attr .append (ChakraAttr (name = "sfixed64" , sfixed64_val = - 9876543210 ))
111- sfixed64_list = Sfixed64List (values = [9876543210 , - 1234567890 ])
112- node .attr .append (ChakraAttr (name = "sfixed64_list" , sfixed64_list = sfixed64_list ))
113-
114- node .attr .append (ChakraAttr (name = "bool" , bool_val = True , doc_string = "bool" ))
115- bool_list = BoolList (values = [i % 2 == 0 for i in range (10 )])
116- node .attr .append (ChakraAttr (name = "bool_list" , bool_list = bool_list ))
117-
118- node .attr .append (ChakraAttr (name = "string" , string_val = "12345" , doc_string = "string" ))
119- string_list = StringList (values = [str (12345 + i ) for i in range (10 )])
120- node .attr .append (ChakraAttr (name = "string_list" , string_list = string_list ))
121-
122- node .attr .append (ChakraAttr (name = "bytes" , bytes_val = bytes ("12345" , "utf-8" )))
123- bytes_list = BytesList (values = [bytes (str (12345 + i ), "utf-8" ) for i in range (10 )])
124- node .attr .append (ChakraAttr (name = "bytes_list" , bytes_list = bytes_list ))
68+ node .attr .extend (
69+ [
70+ ChakraAttr (name = "double" , double_val = 1.2345 , doc_string = "double" ),
71+ ChakraAttr (name = "double_list" , double_list = DoubleList (values = [1.2345 , 2.3456 ])),
72+ ChakraAttr (name = "float" , float_val = 1.2345 , doc_string = "float" ),
73+ ChakraAttr (name = "float_list" , float_list = FloatList (values = [1.2345 , 2.3456 ])),
74+ ChakraAttr (name = "int32" , int32_val = 12345 , doc_string = "int32" ),
75+ ChakraAttr (name = "int32_list" , int32_list = Int32List (values = [12345 , 23456 ])),
76+ ChakraAttr (name = "int64" , int64_val = 9876543210 , doc_string = "int64" ),
77+ ChakraAttr (name = "int64_list" , int64_list = Int64List (values = [9876543210 , 1234567890 ])),
78+ ChakraAttr (name = "uint32" , uint32_val = 12345 , doc_string = "uint32" ),
79+ ChakraAttr (name = "uint32_list" , uint32_list = Uint32List (values = [12345 , 23456 ])),
80+ ChakraAttr (name = "uint64" , uint64_val = 9876543210 , doc_string = "uint64" ),
81+ ChakraAttr (name = "uint64_list" , uint64_list = Uint64List (values = [9876543210 , 1234567890 ])),
82+ ChakraAttr (name = "sint32" , sint32_val = - 12345 , doc_string = "sint32" ),
83+ ChakraAttr (name = "sint32_list" , sint32_list = Sint32List (values = [12345 , - 23456 ])),
84+ ChakraAttr (name = "sint64" , sint64_val = - 9876543210 , doc_string = "sint64" ),
85+ ChakraAttr (name = "sint64_list" , sint64_list = Sint64List (values = [9876543210 , - 1234567890 ])),
86+ ChakraAttr (name = "fixed32" , fixed32_val = 12345 ),
87+ ChakraAttr (name = "fixed32_list" , fixed32_list = Fixed32List (values = [12345 , 23456 ])),
88+ ChakraAttr (name = "fixed64" , fixed64_val = 9876543210 ),
89+ ChakraAttr (name = "fixed64_list" , fixed64_list = Fixed64List (values = [9876543210 , 1234567890 ])),
90+ ChakraAttr (name = "sfixed32" , sfixed32_val = - 12345 ),
91+ ChakraAttr (name = "sfixed32_list" , sfixed32_list = Sfixed32List (values = [12345 , - 23456 ])),
92+ ChakraAttr (name = "sfixed64" , sfixed64_val = - 9876543210 ),
93+ ChakraAttr (name = "sfixed64_list" , sfixed64_list = Sfixed64List (values = [9876543210 , - 1234567890 ])),
94+ ChakraAttr (name = "bool" , bool_val = True , doc_string = "bool" ),
95+ ChakraAttr (name = "bool_list" , bool_list = BoolList (values = [i % 2 == 0 for i in range (10 )])),
96+ ChakraAttr (name = "string" , string_val = "12345" , doc_string = "string" ),
97+ ChakraAttr (name = "string_list" , string_list = StringList (values = [str (12345 + i ) for i in range (10 )])),
98+ ChakraAttr (name = "bytes" , bytes_val = bytes ("12345" , "utf-8" )),
99+ ChakraAttr (
100+ name = "bytes_list" ,
101+ bytes_list = BytesList (values = [bytes (str (12345 + i ), "utf-8" ) for i in range (10 )]),
102+ ),
103+ ]
104+ )
125105
126106 encode_message (et , node )
127107
128108
129109def one_remote_mem_load_node (num_npus : int , tensor_size : int ) -> None :
110+ """Generate remote memory load nodes."""
130111 for npu_id in range (num_npus ):
131112 output_filename = f"one_remote_mem_load_node.{ npu_id } .et"
132113 with open (output_filename , "wb" ) as et :
133114 encode_message (et , GlobalMetadata (version = "0.0.4" ))
134115
135116 node = get_node ("MEM_LOAD_NODE" , MEM_LOAD_NODE )
136- node .attr .append (ChakraAttr (name = "is_cpu_op" , bool_val = False ))
137- node .attr .append (ChakraAttr (name = "tensor_size" , uint64_val = tensor_size ))
117+ node .attr .extend (
118+ [ChakraAttr (name = "is_cpu_op" , bool_val = False ), ChakraAttr (name = "tensor_size" , uint64_val = tensor_size )]
119+ )
138120 encode_message (et , node )
139121
140122
141123def one_remote_mem_store_node (num_npus : int , tensor_size : int ) -> None :
124+ """Generate remote memory store nodes."""
142125 for npu_id in range (num_npus ):
143126 output_filename = f"one_remote_mem_store_node.{ npu_id } .et"
144127 with open (output_filename , "wb" ) as et :
145128 encode_message (et , GlobalMetadata (version = "0.0.4" ))
146129
147130 node = get_node ("MEM_STORE_NODE" , MEM_STORE_NODE )
148- node .attr .append (ChakraAttr (name = "is_cpu_op" , bool_val = False ))
149- node .attr .append (ChakraAttr (name = "tensor_size" , uint64_val = tensor_size ))
131+ node .attr .extend (
132+ [ChakraAttr (name = "is_cpu_op" , bool_val = False ), ChakraAttr (name = "tensor_size" , uint64_val = tensor_size )]
133+ )
150134 encode_message (et , node )
151135
152136
153137def one_comp_node (num_npus : int , runtime : int ) -> None :
138+ """Generate computation nodes with a given runtime."""
154139 for npu_id in range (num_npus ):
155140 output_filename = f"one_comp_node.{ npu_id } .et"
156141 with open (output_filename , "wb" ) as et :
@@ -163,90 +148,74 @@ def one_comp_node(num_npus: int, runtime: int) -> None:
163148
164149
165150def two_comp_nodes_independent (num_npus : int , runtime : int ) -> None :
151+ """Generate two independent computation nodes."""
166152 for npu_id in range (num_npus ):
167153 output_filename = f"two_comp_nodes_independent.{ npu_id } .et"
168154 with open (output_filename , "wb" ) as et :
169155 encode_message (et , GlobalMetadata (version = "0.0.4" ))
170156
171- node = get_node ("COMP_NODE" , COMP_NODE )
172- node .duration_micros = runtime
173- node .attr .append (ChakraAttr (name = "is_cpu_op" , bool_val = False ))
174- encode_message (et , node )
175-
176- node = get_node ("COMP_NODE" , COMP_NODE )
177- node .duration_micros = runtime
178- node .attr .append (ChakraAttr (name = "is_cpu_op" , bool_val = False ))
179- encode_message (et , node )
157+ for _ in range (2 ):
158+ node = get_node ("COMP_NODE" , COMP_NODE )
159+ node .attr .append (ChakraAttr (name = "is_cpu_op" , bool_val = False ))
160+ node .duration_micros = runtime
161+ encode_message (et , node )
180162
181163
182164def two_comp_nodes_dependent (num_npus : int , runtime : int ) -> None :
165+ """Generate two dependent computation nodes."""
183166 for npu_id in range (num_npus ):
184167 output_filename = f"two_comp_nodes_dependent.{ npu_id } .et"
185168 with open (output_filename , "wb" ) as et :
186169 encode_message (et , GlobalMetadata (version = "0.0.4" ))
187170
188171 parent_node = get_node ("COMP_NODE" , COMP_NODE )
189- parent_node .duration_micros = runtime
190172 parent_node .attr .append (ChakraAttr (name = "is_cpu_op" , bool_val = False ))
173+ parent_node .duration_micros = runtime
191174 encode_message (et , parent_node )
192175
193176 child_node = get_node ("COMP_NODE" , COMP_NODE )
177+ child_node .attr .append (ChakraAttr (name = "is_cpu_op" , bool_val = False ))
194178 child_node .duration_micros = runtime
195179 child_node .data_deps .append (parent_node .id )
196- child_node .attr .append (ChakraAttr (name = "is_cpu_op" , bool_val = False ))
197180 encode_message (et , child_node )
198181
199182
200- def one_comm_coll_node_allreduce (num_npus : int , comm_size : int ) -> None :
183+ def generate_comm_coll_node (num_npus : int , comm_size : int , comm_type : int , node_name : str ) -> None :
184+ """Generate communication collective nodes."""
201185 for npu_id in range (num_npus ):
202- output_filename = f"one_comm_coll_node_allreduce .{ npu_id } .et"
186+ output_filename = f"{ node_name } .{ npu_id } .et"
203187 with open (output_filename , "wb" ) as et :
204188 encode_message (et , GlobalMetadata (version = "0.0.4" ))
205189
206- node = get_node ("ALL_REDUCE" , COMM_COLL_NODE )
207- node .attr .append (ChakraAttr (name = "is_cpu_op" , bool_val = False ))
208- node .attr .append (get_comm_type_attr (ALL_REDUCE ))
209- node .attr .append (ChakraAttr (name = "comm_size" , uint64_val = comm_size ))
190+ node = get_node (node_name , COMM_COLL_NODE )
191+ node .attr .extend (
192+ [
193+ ChakraAttr (name = "is_cpu_op" , bool_val = False ),
194+ get_comm_type_attr (comm_type ),
195+ ChakraAttr (name = "comm_size" , uint64_val = comm_size ),
196+ ]
197+ )
210198 encode_message (et , node )
211199
212200
213- def one_comm_coll_node_alltoall (num_npus : int , comm_size : int ) -> None :
214- for npu_id in range (num_npus ):
215- output_filename = f"one_comm_coll_node_alltoall.{ npu_id } .et"
216- with open (output_filename , "wb" ) as et :
217- encode_message (et , GlobalMetadata (version = "0.0.4" ))
201+ def one_comm_coll_node_allreduce (num_npus : int , comm_size : int ) -> None :
202+ """Generate one AllReduce communication collective node."""
203+ generate_comm_coll_node (num_npus , comm_size , ALL_REDUCE , "ALL_REDUCE" )
218204
219- node = get_node ("ALL_TO_ALL" , COMM_COLL_NODE )
220- node .attr .append (ChakraAttr (name = "is_cpu_op" , bool_val = False ))
221- node .attr .append (get_comm_type_attr (ALL_TO_ALL ))
222- node .attr .append (ChakraAttr (name = "comm_size" , uint64_val = comm_size ))
223- encode_message (et , node )
224205
206+ def one_comm_coll_node_alltoall (num_npus : int , comm_size : int ) -> None :
207+ """Generate one AllToAll communication collective node."""
208+ generate_comm_coll_node (num_npus , comm_size , ALL_TO_ALL , "ALL_TO_ALL" )
225209
226- def one_comm_coll_node_allgather (num_npus : int , comm_size : int ) -> None :
227- for npu_id in range (num_npus ):
228- output_filename = f"one_comm_coll_node_allgather.{ npu_id } .et"
229- with open (output_filename , "wb" ) as et :
230- encode_message (et , GlobalMetadata (version = "0.0.4" ))
231210
232- node = get_node ("ALL_GATHER" , COMM_COLL_NODE )
233- node .attr .append (ChakraAttr (name = "is_cpu_op" , bool_val = False ))
234- node .attr .append (get_comm_type_attr (ALL_GATHER ))
235- node .attr .append (ChakraAttr (name = "comm_size" , uint64_val = comm_size ))
236- encode_message (et , node )
211+ def one_comm_coll_node_allgather (num_npus : int , comm_size : int ) -> None :
212+ """Generate one AllGather communication collective node."""
213+ generate_comm_coll_node (num_npus , comm_size , ALL_GATHER , "ALL_GATHER" )
237214
238215
239216def one_comm_coll_node_reducescatter (num_npus : int , comm_size : int ) -> None :
240- for npu_id in range (num_npus ):
241- output_filename = f"one_comm_coll_node_reducescatter.{ npu_id } .et"
242- with open (output_filename , "wb" ) as et :
243- encode_message (et , GlobalMetadata (version = "0.0.4" ))
244-
245- node = get_node ("REDUCE_SCATTER" , COMM_COLL_NODE )
246- node .attr .append (ChakraAttr (name = "is_cpu_op" , bool_val = False ))
247- node .attr .append (get_comm_type_attr (REDUCE_SCATTER ))
248- node .attr .append (ChakraAttr (name = "comm_size" , uint64_val = comm_size ))
249- encode_message (et , node )
217+ """Generate one ReduceScatter communication collective node."""
218+ generate_comm_coll_node (num_npus , comm_size , REDUCE_SCATTER , "REDUCE_SCATTER" )
250219
251220
252221def main () -> None :
@@ -260,14 +229,11 @@ def main() -> None:
260229 args = parser .parse_args ()
261230
262231 one_metadata_node_all_types (args .num_npus )
263-
264232 one_remote_mem_load_node (args .num_npus , args .default_tensor_size )
265233 one_remote_mem_store_node (args .num_npus , args .default_tensor_size )
266-
267234 one_comp_node (args .num_npus , args .default_runtime )
268235 two_comp_nodes_independent (args .num_npus , args .default_runtime )
269236 two_comp_nodes_dependent (args .num_npus , args .default_runtime )
270-
271237 one_comm_coll_node_allreduce (args .num_npus , args .default_comm_size )
272238 one_comm_coll_node_alltoall (args .num_npus , args .default_comm_size )
273239 one_comm_coll_node_allgather (args .num_npus , args .default_comm_size )
0 commit comments