44 ALL_GATHER ,
55 ALL_REDUCE ,
66 ALL_TO_ALL ,
7+ BARRIER ,
8+ BROADCAST ,
79 COMM_COLL_NODE ,
10+ COMM_RECV_NODE ,
11+ COMM_SEND_NODE ,
812 COMP_NODE ,
913 MEM_LOAD_NODE ,
1014 MEM_STORE_NODE ,
@@ -114,9 +118,8 @@ def one_remote_mem_load_node(num_npus: int, tensor_size: int) -> None:
114118 encode_message (et , GlobalMetadata (version = "0.0.4" ))
115119
116120 node = get_node ("MEM_LOAD_NODE" , MEM_LOAD_NODE )
117- node .attr .extend (
118- [ChakraAttr (name = "is_cpu_op" , bool_val = False ), ChakraAttr (name = "tensor_size" , uint64_val = tensor_size )]
119- )
121+ node .attr .append (ChakraAttr (name = "is_cpu_op" , bool_val = False ))
122+ node .attr .append (ChakraAttr (name = "tensor_size" , uint64_val = tensor_size ))
120123 encode_message (et , node )
121124
122125
@@ -128,9 +131,8 @@ def one_remote_mem_store_node(num_npus: int, tensor_size: int) -> None:
128131 encode_message (et , GlobalMetadata (version = "0.0.4" ))
129132
130133 node = get_node ("MEM_STORE_NODE" , MEM_STORE_NODE )
131- node .attr .extend (
132- [ChakraAttr (name = "is_cpu_op" , bool_val = False ), ChakraAttr (name = "tensor_size" , uint64_val = tensor_size )]
133- )
134+ node .attr .append (ChakraAttr (name = "is_cpu_op" , bool_val = False ))
135+ node .attr .append (ChakraAttr (name = "tensor_size" , uint64_val = tensor_size ))
134136 encode_message (et , node )
135137
136138
@@ -188,13 +190,8 @@ def generate_comm_coll_node(num_npus: int, comm_size: int, comm_type: int, node_
188190 encode_message (et , GlobalMetadata (version = "0.0.4" ))
189191
190192 node = get_node (node_name , COMM_COLL_NODE )
191- node .attr .extend (
192- [
193- ChakraAttr (name = "is_cpu_op" , bool_val = False ),
194- get_comm_type_attr (comm_type ),
195- ChakraAttr (name = "comm_size" , uint64_val = comm_size ),
196- ]
197- )
193+ node .attr .append (ChakraAttr (name = "is_cpu_op" , bool_val = False ))
194+ node .attr .extend ([get_comm_type_attr (comm_type ), ChakraAttr (name = "comm_size" , uint64_val = comm_size )])
198195 encode_message (et , node )
199196
200197
@@ -218,6 +215,42 @@ def one_comm_coll_node_reducescatter(num_npus: int, comm_size: int) -> None:
218215 generate_comm_coll_node (num_npus , comm_size , REDUCE_SCATTER , "REDUCE_SCATTER" )
219216
220217
def one_comm_coll_node_broadcast(num_npus: int, comm_size: int) -> None:
    """Generate one Broadcast communication collective node."""
    # Delegate to the shared collective-node generator with the
    # Broadcast collective type; keyword arguments for readability.
    generate_comm_coll_node(
        num_npus=num_npus,
        comm_size=comm_size,
        comm_type=BROADCAST,
        node_name="BROADCAST",
    )
222+
def one_comm_coll_node_barrier(num_npus: int) -> None:
    """Generate one Barrier communication collective node."""
    # A barrier moves no payload, so the communication size is zero.
    generate_comm_coll_node(num_npus, 0, BARRIER, "BARRIER")
226+
227+
def one_comm_send_node(num_npus: int, tensor_size: int) -> None:
    """Generate communication send nodes.

    Writes one per-NPU trace file named ``one_comm_send_node.<npu_id>.et``,
    each containing the global metadata record followed by a single
    COMM_SEND_NODE carrying ``is_cpu_op=False`` and the given tensor size.
    """
    for npu_id in range(num_npus):
        filename = f"one_comm_send_node.{npu_id}.et"
        with open(filename, "wb") as et:
            # Every trace file starts with the schema-version metadata.
            encode_message(et, GlobalMetadata(version="0.0.4"))

            node = get_node("COMM_SEND_NODE", COMM_SEND_NODE)
            node.attr.extend(
                [
                    ChakraAttr(name="is_cpu_op", bool_val=False),
                    ChakraAttr(name="tensor_size", uint64_val=tensor_size),
                ]
            )
            encode_message(et, node)
240+
def one_comm_recv_node(num_npus: int, tensor_size: int) -> None:
    """Generate communication receive nodes.

    Writes one per-NPU trace file named ``one_comm_recv_node.<npu_id>.et``,
    each containing the global metadata record followed by a single
    COMM_RECV_NODE carrying ``is_cpu_op=False`` and the given tensor size.
    """
    for npu_id in range(num_npus):
        filename = f"one_comm_recv_node.{npu_id}.et"
        with open(filename, "wb") as et:
            # Every trace file starts with the schema-version metadata.
            encode_message(et, GlobalMetadata(version="0.0.4"))

            node = get_node("COMM_RECV_NODE", COMM_RECV_NODE)
            node.attr.extend(
                [
                    ChakraAttr(name="is_cpu_op", bool_val=False),
                    ChakraAttr(name="tensor_size", uint64_val=tensor_size),
                ]
            )
            encode_message(et, node)
252+
253+
221254def main () -> None :
222255 parser = argparse .ArgumentParser (description = "Execution Trace Generator" )
223256 parser .add_argument ("--num_npus" , type = int , default = 64 , help = "Number of NPUs" )
@@ -238,6 +271,10 @@ def main() -> None:
238271 one_comm_coll_node_alltoall (args .num_npus , args .default_comm_size )
239272 one_comm_coll_node_allgather (args .num_npus , args .default_comm_size )
240273 one_comm_coll_node_reducescatter (args .num_npus , args .default_comm_size )
274+ one_comm_coll_node_broadcast (args .num_npus , args .default_comm_size )
275+ one_comm_coll_node_barrier (args .num_npus )
276+ one_comm_send_node (args .num_npus , args .default_tensor_size )
277+ one_comm_recv_node (args .num_npus , args .default_tensor_size )
241278
242279
243280if __name__ == "__main__" :
0 commit comments