@@ -173,6 +173,27 @@ func TestMPIJobSuccess(t *testing.T) {
173173}
174174
175175func TestMPIJobWaitWorkers (t * testing.T ) {
176+ testcases := []struct {
177+ name string
178+ startSuspended bool
179+ }{
180+ {
181+ name : "don't start suspended" ,
182+ startSuspended : false ,
183+ },
184+ {
185+ name : "start suspended" ,
186+ startSuspended : true ,
187+ },
188+ }
189+ for _ , tc := range testcases {
190+ t .Run (tc .name , func (t * testing.T ) {
191+ testMpiJobWaitWorkers (t , tc .startSuspended )
192+ })
193+ }
194+ }
195+
196+ func testMpiJobWaitWorkers (t * testing.T , startSuspended bool ) {
176197 ctx , cancel := context .WithCancel (context .Background ())
177198 t .Cleanup (cancel )
178199 s := newTestSetup (ctx , t )
@@ -187,6 +208,7 @@ func TestMPIJobWaitWorkers(t *testing.T) {
187208 SlotsPerWorker : ptr.To [int32 ](1 ),
188209 LauncherCreationPolicy : "WaitForWorkersReady" ,
189210 RunPolicy : kubeflow.RunPolicy {
211+ Suspend : ptr .To (startSuspended ),
190212 CleanPodPolicy : ptr .To (kubeflow .CleanPodPolicyRunning ),
191213 },
192214 MPIReplicaSpecs : map [kubeflow.MPIReplicaType ]* kubeflow.ReplicaSpec {
@@ -237,9 +259,37 @@ func TestMPIJobWaitWorkers(t *testing.T) {
237259 }
238260 s .events .verify (t )
239261
240- workerPods , err := getPodsForJob (ctx , s .kClient , mpiJob )
262+ // The launcher job should not be created until all workers are ready even when we start in suspended mode.
263+ job , err := getLauncherJobForMPIJob (ctx , s .kClient , mpiJob )
241264 if err != nil {
242- t .Fatalf ("Cannot get worker pods from job: %v" , err )
265+ t .Fatalf ("Cannot get launcher job from job: %v" , err )
266+ }
267+ if job != nil {
268+ t .Fatalf ("Launcher is created before workers" )
269+ }
270+
271+ if startSuspended {
272+ // Resume the MPIJob so that the test can follow the normal path.
273+ mpiJob .Spec .RunPolicy .Suspend = ptr .To (false )
274+ mpiJob , err = s .mpiClient .KubeflowV2beta1 ().MPIJobs (mpiJob .Namespace ).Update (ctx , mpiJob , metav1.UpdateOptions {})
275+ if err != nil {
276+ t .Fatalf ("Error Updating MPIJob: %v" , err )
277+ }
278+ }
279+
280+ var workerPods []corev1.Pod
281+ if err = wait .PollUntilContextTimeout (ctx , util .WaitInterval , wait .ForeverTestTimeout , false , func (ctx context.Context ) (bool , error ) {
282+ var err error
283+ workerPods , err = getPodsForJob (ctx , s .kClient , mpiJob )
284+ if err != nil {
285+ return false , err
286+ }
287+ if len (workerPods ) != 2 {
288+ return false , nil
289+ }
290+ return true , nil
291+ }); err != nil {
292+ t .Errorf ("Failed updating scheduler-plugins PodGroup: %v" , err )
243293 }
244294
245295 err = updatePodsToPhase (ctx , s .kClient , workerPods , corev1 .PodRunning )
0 commit comments