@@ -722,6 +722,88 @@ def metric_module_gather_state(
722722 metric_module .shutdown ()
723723
724724
725+ class MetricsConfigPostInitTest (unittest .TestCase ):
726+ """Test class for MetricsConfig._post_init() validation functionality."""
727+
728+ def test_post_init_valid_rec_task_indices (self ) -> None :
729+ """Test that _post_init() passes when rec_task_indices are valid."""
730+ # Setup: create rec_tasks and valid indices
731+ task1 = RecTaskInfo (name = "task1" , label_name = "label1" , prediction_name = "pred1" )
732+ task2 = RecTaskInfo (name = "task2" , label_name = "label2" , prediction_name = "pred2" )
733+ rec_tasks = [task1 , task2 ]
734+
735+ # Execute: create MetricsConfig with valid rec_task_indices
736+ config = MetricsConfig (
737+ rec_tasks = rec_tasks ,
738+ rec_metrics = {
739+ RecMetricEnum .AUC : RecMetricDef (rec_task_indices = [0 , 1 ]),
740+ RecMetricEnum .NE : RecMetricDef (rec_task_indices = [0 ]),
741+ },
742+ )
743+
744+ # Assert: config should be created successfully without raising an exception
745+ self .assertEqual (len (config .rec_tasks ), 2 )
746+ self .assertEqual (len (config .rec_metrics ), 2 )
747+
748+ def test_post_init_empty_rec_task_indices (self ) -> None :
749+ """Test that _post_init() passes when rec_task_indices is empty."""
750+ # Setup: create rec_tasks but use empty indices
751+ task = RecTaskInfo (name = "task" , label_name = "label" , prediction_name = "pred" )
752+ rec_tasks = [task ]
753+
754+ # Execute: create MetricsConfig with empty rec_task_indices
755+ config = MetricsConfig (
756+ rec_tasks = rec_tasks ,
757+ rec_metrics = {
758+ RecMetricEnum .AUC : RecMetricDef (rec_task_indices = []),
759+ },
760+ )
761+
762+ # Assert: config should be created successfully with empty indices
763+ self .assertEqual (len (config .rec_tasks ), 1 )
764+ self .assertEqual (config .rec_metrics [RecMetricEnum .AUC ].rec_task_indices , [])
765+
766+ def test_post_init_raises_when_rec_tasks_is_none (self ) -> None :
767+ """Test that _post_init() raises ValueError when rec_tasks is None but rec_task_indices is specified."""
768+ # Setup: prepare to create config with None rec_tasks but specified indices
769+
770+ # Execute & Assert: should raise ValueError about rec_tasks being None
771+ with self .assertRaises (ValueError ) as context :
772+ config = MetricsConfig (
773+ rec_tasks = None , # pyre-ignore[6]: Intentionally passing None for testing
774+ rec_metrics = {
775+ RecMetricEnum .AUC : RecMetricDef (rec_task_indices = [0 ]),
776+ },
777+ )
778+
779+ error_message = str (context .exception )
780+ self .assertIn ("rec_task_indices [0] is specified" , error_message )
781+ self .assertIn ("but rec_tasks is None" , error_message )
782+ self .assertIn ("for metric auc" , error_message )
783+
784+ def test_post_init_raises_when_rec_task_index_out_of_range (self ) -> None :
785+ """Test that _post_init() raises ValueError when rec_task_index is out of range."""
786+ # Setup: create single rec_task but try to access index 1
787+ task = RecTaskInfo (name = "task" , label_name = "label" , prediction_name = "pred" )
788+ rec_tasks = [task ]
789+
790+ # Execute & Assert: should raise ValueError about index out of range
791+ with self .assertRaises (ValueError ) as context :
792+ config = MetricsConfig (
793+ rec_tasks = rec_tasks ,
794+ rec_metrics = {
795+ RecMetricEnum .NE : RecMetricDef (
796+ rec_task_indices = [1 ]
797+ ), # Index 1 doesn't exist
798+ },
799+ )
800+
801+ error_message = str (context .exception )
802+ self .assertIn ("rec_task_indices 1 is out of range" , error_message )
803+ self .assertIn ("of 1 tasks" , error_message )
804+ self .assertIn ("for metric ne" , error_message )
805+
806+
725807@skip_if_asan_class
726808class MetricModuleDistributedTest (MultiProcessTestBase ):
727809
0 commit comments