@@ -65,6 +65,7 @@ def dummy_init_pg() -> None:
 def _test_pg(
     pg: ProcessGroup,
     example_tensor: torch.Tensor = torch.randn((2, 3), dtype=torch.float32),
+    skip: list[str] = [],
 ) -> Dict[str, dist._Work]:
     """
     Helper function to test a set of collective operations on a given process group.
@@ -124,6 +125,8 @@ def check_tensors(arg: Any) -> None: # pyre-ignore[2]
     works: Dict[str, dist._Work] = {}

     for coll_str, args in collectives:
+        if coll_str in skip:
+            continue
         try:
             coll = getattr(pg, coll_str)
             work = coll(*args)
@@ -496,7 +499,12 @@ def run_reduce_scatter_tensor_coalesced_test(


 class ProcessGroupTest(TestCase):
-    def test_gloo_apis(self) -> None:
+    @parameterized.expand(["cpu", "cuda"])
+    def test_gloo_apis(self, device: str) -> None:
+        if device == "cuda" and not torch.cuda.is_available():
+            self.skipTest("CUDA is not available")
+            return
+
         store = TCPStore(
             host_name="localhost", port=0, is_master=True, wait_for_workers=False
         )
@@ -507,11 +515,23 @@ def test_gloo_apis(self) -> None:

         self.assertEqual(pg.size(), 1)

-        _test_pg(pg)
+        _test_pg(
+            pg,
+            torch.tensor([2], device=device),
+            skip=(
+                # https://github.com/pytorch/pytorch/issues/152645
+                [
+                    "allreduce_coalesced",
+                    "allgather_into_tensor_coalesced",
+                ]
+                if device == "cuda"
+                else []
+            ),
+        )

-        m = nn.Linear(3, 4)
+        m = nn.Linear(3, 4).to(device)
         m = torch.nn.parallel.DistributedDataParallel(m, process_group=pg)
-        m(torch.rand(2, 3))
+        m(torch.rand(2, 3, device=device))

     def test_gloo_timeout(self) -> None:
         store = TCPStore(
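
For context, the new `skip` argument follows a simple skip-list pattern: `_test_pg` iterates over named collectives, looks each one up on the process group, and now bypasses any name listed in `skip` (here, the coalesced collectives that are broken on CUDA per pytorch/pytorch#152645). Below is a minimal, self-contained sketch of that pattern; the helper name `run_collectives` and its `collectives` argument are illustrative and not part of the original test file.

```python
# Sketch only, assuming a process group `pg` is already constructed elsewhere.
from __future__ import annotations

from typing import Any, Dict, Sequence, Tuple

import torch.distributed as dist


def run_collectives(
    pg: dist.ProcessGroup,
    collectives: Sequence[Tuple[str, Tuple[Any, ...]]],
    skip: Sequence[str] = (),
) -> Dict[str, dist._Work]:
    """Invoke each named collective on ``pg`` unless its name is in ``skip``."""
    works: Dict[str, dist._Work] = {}
    for coll_str, args in collectives:
        if coll_str in skip:
            # Skip collectives known to be broken on this device, e.g.
            # "allreduce_coalesced" on CUDA (pytorch/pytorch#152645).
            continue
        coll = getattr(pg, coll_str)  # e.g. pg.allreduce, pg.broadcast, ...
        works[coll_str] = coll(*args)
    return works
```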