2222)
2323
2424from  .utils  import  (
25+     all_supported_devices ,
2526    assert_frames_equal ,
2627    AV1_VIDEO ,
27-     cpu_and_cuda ,
2828    get_ffmpeg_major_version ,
2929    H264_10BITS ,
3030    H265_10BITS ,
@@ -163,7 +163,7 @@ def test_create_fails(self):
163163            VideoDecoder (NASA_VIDEO .path , seek_mode = "blah" )
164164
165165    @pytest .mark .parametrize ("num_ffmpeg_threads" , (1 , 4 )) 
166-     @pytest .mark .parametrize ("device" , cpu_and_cuda ()) 
166+     @pytest .mark .parametrize ("device" , all_supported_devices ()) 
167167    @pytest .mark .parametrize ("seek_mode" , ("exact" , "approximate" )) 
168168    def  test_getitem_int (self , num_ffmpeg_threads , device , seek_mode ):
169169        decoder  =  VideoDecoder (
@@ -213,7 +213,7 @@ def test_getitem_numpy_int(self):
213213        assert_frames_equal (ref_frame1 , decoder [numpy .uint32 (1 )])
214214        assert_frames_equal (ref_frame180 , decoder [numpy .uint32 (180 )])
215215
216-     @pytest .mark .parametrize ("device" , cpu_and_cuda ()) 
216+     @pytest .mark .parametrize ("device" , all_supported_devices ()) 
217217    @pytest .mark .parametrize ("seek_mode" , ("exact" , "approximate" )) 
218218    def  test_getitem_slice (self , device , seek_mode ):
219219        decoder  =  VideoDecoder (NASA_VIDEO .path , device = device , seek_mode = seek_mode )
@@ -373,7 +373,7 @@ def test_device_instance(self):
373373        decoder  =  VideoDecoder (NASA_VIDEO .path , device = torch .device ("cpu" ))
374374        assert  isinstance (decoder .metadata , VideoStreamMetadata )
375375
376-     @pytest .mark .parametrize ("device" , cpu_and_cuda ()) 
376+     @pytest .mark .parametrize ("device" , all_supported_devices ()) 
377377    @pytest .mark .parametrize ("seek_mode" , ("exact" , "approximate" )) 
378378    def  test_getitem_fails (self , device , seek_mode ):
379379        decoder  =  VideoDecoder (NASA_VIDEO .path , device = device , seek_mode = seek_mode )
@@ -390,7 +390,7 @@ def test_getitem_fails(self, device, seek_mode):
390390        with  pytest .raises (TypeError , match = "Unsupported key type" ):
391391            frame  =  decoder [2.3 ]  # noqa 
392392
393-     @pytest .mark .parametrize ("device" , cpu_and_cuda ()) 
393+     @pytest .mark .parametrize ("device" , all_supported_devices ()) 
394394    @pytest .mark .parametrize ("seek_mode" , ("exact" , "approximate" )) 
395395    def  test_iteration (self , device , seek_mode ):
396396        decoder  =  VideoDecoder (NASA_VIDEO .path , device = device , seek_mode = seek_mode )
@@ -437,7 +437,7 @@ def test_iteration_slow(self):
437437
438438        assert  iterations  ==  len (decoder ) ==  390 
439439
440-     @pytest .mark .parametrize ("device" , cpu_and_cuda ()) 
440+     @pytest .mark .parametrize ("device" , all_supported_devices ()) 
441441    @pytest .mark .parametrize ("seek_mode" , ("exact" , "approximate" )) 
442442    def  test_get_frame_at (self , device , seek_mode ):
443443        decoder  =  VideoDecoder (NASA_VIDEO .path , device = device , seek_mode = seek_mode )
@@ -475,7 +475,7 @@ def test_get_frame_at(self, device, seek_mode):
475475        frame9  =  decoder .get_frame_at (numpy .uint32 (9 ))
476476        assert_frames_equal (ref_frame9 , frame9 .data )
477477
478-     @pytest .mark .parametrize ("device" , cpu_and_cuda ()) 
478+     @pytest .mark .parametrize ("device" , all_supported_devices ()) 
479479    def  test_get_frame_at_tuple_unpacking (self , device ):
480480        decoder  =  VideoDecoder (NASA_VIDEO .path , device = device )
481481
@@ -486,7 +486,7 @@ def test_get_frame_at_tuple_unpacking(self, device):
486486        assert  frame .pts_seconds  ==  pts 
487487        assert  frame .duration_seconds  ==  duration 
488488
489-     @pytest .mark .parametrize ("device" , cpu_and_cuda ()) 
489+     @pytest .mark .parametrize ("device" , all_supported_devices ()) 
490490    @pytest .mark .parametrize ("seek_mode" , ("exact" , "approximate" )) 
491491    def  test_get_frame_at_fails (self , device , seek_mode ):
492492        decoder  =  VideoDecoder (NASA_VIDEO .path , device = device , seek_mode = seek_mode )
@@ -500,7 +500,7 @@ def test_get_frame_at_fails(self, device, seek_mode):
500500        with  pytest .raises (IndexError , match = "must be less than" ):
501501            frame  =  decoder .get_frame_at (10000 )  # noqa 
502502
503-     @pytest .mark .parametrize ("device" , cpu_and_cuda ()) 
503+     @pytest .mark .parametrize ("device" , all_supported_devices ()) 
504504    @pytest .mark .parametrize ("seek_mode" , ("exact" , "approximate" )) 
505505    def  test_get_frames_at (self , device , seek_mode ):
506506        decoder  =  VideoDecoder (NASA_VIDEO .path , device = device , seek_mode = seek_mode )
@@ -551,7 +551,7 @@ def test_get_frames_at(self, device, seek_mode):
551551            frames .duration_seconds , expected_duration_seconds , atol = 1e-4 , rtol = 0 
552552        )
553553
554-     @pytest .mark .parametrize ("device" , cpu_and_cuda ()) 
554+     @pytest .mark .parametrize ("device" , all_supported_devices ()) 
555555    @pytest .mark .parametrize ("seek_mode" , ("exact" , "approximate" )) 
556556    def  test_get_frames_at_fails (self , device , seek_mode ):
557557        decoder  =  VideoDecoder (NASA_VIDEO .path , device = device , seek_mode = seek_mode )
@@ -568,7 +568,7 @@ def test_get_frames_at_fails(self, device, seek_mode):
568568        with  pytest .raises (RuntimeError , match = "Expected a value of type" ):
569569            decoder .get_frames_at ([0.3 ])
570570
571-     @pytest .mark .parametrize ("device" , cpu_and_cuda ()) 
571+     @pytest .mark .parametrize ("device" , all_supported_devices ()) 
572572    def  test_get_frame_at_av1 (self , device ):
573573        if  device  ==  "cuda"  and  get_ffmpeg_major_version () ==  4 :
574574            return 
@@ -581,7 +581,7 @@ def test_get_frame_at_av1(self, device):
581581        assert  decoded_frame10 .pts_seconds  ==  ref_frame_info10 .pts_seconds 
582582        assert_frames_equal (decoded_frame10 .data , ref_frame10 .to (device = device ))
583583
584-     @pytest .mark .parametrize ("device" , cpu_and_cuda ()) 
584+     @pytest .mark .parametrize ("device" , all_supported_devices ()) 
585585    @pytest .mark .parametrize ("seek_mode" , ("exact" , "approximate" )) 
586586    def  test_get_frame_played_at (self , device , seek_mode ):
587587        decoder  =  VideoDecoder (NASA_VIDEO .path , device = device , seek_mode = seek_mode )
@@ -610,7 +610,7 @@ def test_get_frame_played_at_h265(self):
610610        ref_frame6  =  H265_VIDEO .get_frame_data_by_index (5 )
611611        assert_frames_equal (ref_frame6 , decoder .get_frame_played_at (0.5 ).data )
612612
613-     @pytest .mark .parametrize ("device" , cpu_and_cuda ()) 
613+     @pytest .mark .parametrize ("device" , all_supported_devices ()) 
614614    @pytest .mark .parametrize ("seek_mode" , ("exact" , "approximate" )) 
615615    def  test_get_frame_played_at_fails (self , device , seek_mode ):
616616        decoder  =  VideoDecoder (NASA_VIDEO .path , device = device , seek_mode = seek_mode )
@@ -621,7 +621,7 @@ def test_get_frame_played_at_fails(self, device, seek_mode):
621621        with  pytest .raises (IndexError , match = "Invalid pts in seconds" ):
622622            frame  =  decoder .get_frame_played_at (100.0 )  # noqa 
623623
624-     @pytest .mark .parametrize ("device" , cpu_and_cuda ()) 
624+     @pytest .mark .parametrize ("device" , all_supported_devices ()) 
625625    @pytest .mark .parametrize ("seek_mode" , ("exact" , "approximate" )) 
626626    def  test_get_frames_played_at (self , device , seek_mode ):
627627
@@ -660,7 +660,7 @@ def test_get_frames_played_at(self, device, seek_mode):
660660            frames .duration_seconds , expected_duration_seconds , atol = 1e-4 , rtol = 0 
661661        )
662662
663-     @pytest .mark .parametrize ("device" , cpu_and_cuda ()) 
663+     @pytest .mark .parametrize ("device" , all_supported_devices ()) 
664664    @pytest .mark .parametrize ("seek_mode" , ("exact" , "approximate" )) 
665665    def  test_get_frames_played_at_fails (self , device , seek_mode ):
666666        decoder  =  VideoDecoder (NASA_VIDEO .path , device = device , seek_mode = seek_mode )
@@ -674,7 +674,7 @@ def test_get_frames_played_at_fails(self, device, seek_mode):
674674        with  pytest .raises (RuntimeError , match = "Expected a value of type" ):
675675            decoder .get_frames_played_at (["bad" ])
676676
677-     @pytest .mark .parametrize ("device" , cpu_and_cuda ()) 
677+     @pytest .mark .parametrize ("device" , all_supported_devices ()) 
678678    @pytest .mark .parametrize ("stream_index" , [0 , 3 , None ]) 
679679    @pytest .mark .parametrize ("seek_mode" , ("exact" , "approximate" )) 
680680    def  test_get_frames_in_range (self , stream_index , device , seek_mode ):
@@ -779,7 +779,7 @@ def test_get_frames_in_range(self, stream_index, device, seek_mode):
779779            empty_frames .duration_seconds , NASA_VIDEO .empty_duration_seconds 
780780        )
781781
782-     @pytest .mark .parametrize ("device" , cpu_and_cuda ()) 
782+     @pytest .mark .parametrize ("device" , all_supported_devices ()) 
783783    @pytest .mark .parametrize ("seek_mode" , ("exact" , "approximate" )) 
784784    def  test_get_frames_in_range_slice_indices_syntax (self , device , seek_mode ):
785785        decoder  =  VideoDecoder (
@@ -831,7 +831,7 @@ def test_get_frames_in_range_slice_indices_syntax(self, device, seek_mode):
831831        ).to (device )
832832        assert_frames_equal (frames387_None .data , reference_frame387_389 )
833833
834-     @pytest .mark .parametrize ("device" , cpu_and_cuda ()) 
834+     @pytest .mark .parametrize ("device" , all_supported_devices ()) 
835835    @pytest .mark .parametrize ("seek_mode" , ("exact" , "approximate" )) 
836836    @patch ("torchcodec._core._metadata._get_stream_json_metadata" ) 
837837    def  test_get_frames_with_missing_num_frames_metadata (
@@ -894,7 +894,7 @@ def test_get_frames_with_missing_num_frames_metadata(
894894            lambda  decoder : decoder .get_frames_played_in_range (0 , 1 ).data , 
895895        ), 
896896    ) 
897-     @pytest .mark .parametrize ("device" , cpu_and_cuda ()) 
897+     @pytest .mark .parametrize ("device" , all_supported_devices ()) 
898898    @pytest .mark .parametrize ("seek_mode" , ("exact" , "approximate" )) 
899899    def  test_dimension_order (self , dimension_order , frame_getter , device , seek_mode ):
900900        decoder  =  VideoDecoder (
@@ -922,7 +922,7 @@ def test_dimension_order_fails(self):
922922            VideoDecoder (NASA_VIDEO .path , dimension_order = "NCDHW" )
923923
924924    @pytest .mark .parametrize ("stream_index" , [0 , 3 , None ]) 
925-     @pytest .mark .parametrize ("device" , cpu_and_cuda ()) 
925+     @pytest .mark .parametrize ("device" , all_supported_devices ()) 
926926    @pytest .mark .parametrize ("seek_mode" , ("exact" , "approximate" )) 
927927    def  test_get_frames_by_pts_in_range (self , stream_index , device , seek_mode ):
928928        decoder  =  VideoDecoder (
@@ -1061,7 +1061,7 @@ def test_get_frames_by_pts_in_range(self, stream_index, device, seek_mode):
10611061        )
10621062        assert_frames_equal (all_frames .data , decoder [:])
10631063
1064-     @pytest .mark .parametrize ("device" , cpu_and_cuda ()) 
1064+     @pytest .mark .parametrize ("device" , all_supported_devices ()) 
10651065    @pytest .mark .parametrize ("seek_mode" , ("exact" , "approximate" )) 
10661066    def  test_get_frames_by_pts_in_range_fails (self , device , seek_mode ):
10671067        decoder  =  VideoDecoder (NASA_VIDEO .path , device = device , seek_mode = seek_mode )
@@ -1075,7 +1075,7 @@ def test_get_frames_by_pts_in_range_fails(self, device, seek_mode):
10751075        with  pytest .raises (ValueError , match = "Invalid stop seconds" ):
10761076            frame  =  decoder .get_frames_played_in_range (0 , 23 )  # noqa 
10771077
1078-     @pytest .mark .parametrize ("device" , cpu_and_cuda ()) 
1078+     @pytest .mark .parametrize ("device" , all_supported_devices ()) 
10791079    def  test_get_key_frame_indices (self , device ):
10801080        decoder  =  VideoDecoder (NASA_VIDEO .path , device = device , seek_mode = "exact" )
10811081        key_frame_indices  =  decoder ._get_key_frame_indices ()
@@ -1120,7 +1120,7 @@ def test_get_key_frame_indices(self, device):
11201120
11211121    # TODO investigate why this fails internally. 
11221122    @pytest .mark .skipif (in_fbcode (), reason = "Compile test fails internally." ) 
1123-     @pytest .mark .parametrize ("device" , cpu_and_cuda ()) 
1123+     @pytest .mark .parametrize ("device" , all_supported_devices ()) 
11241124    def  test_compile (self , device ):
11251125        decoder  =  VideoDecoder (NASA_VIDEO .path , device = device )
11261126
0 commit comments