Skip to content

Commit 6497fe8

Browse files
committed
reconcile all nvidiadriver CRs when any nvidiadriver CR is changed
Signed-off-by: Rahul Sharma <rahulsharm@nvidia.com>
1 parent 652724d commit 6497fe8

File tree

2 files changed

+64
-51
lines changed

2 files changed

+64
-51
lines changed

controllers/nvidiadriver_controller.go

Lines changed: 42 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,33 @@ func (r *NVIDIADriverReconciler) updateCrStatus(
248248
return nil
249249
}
250250

251+
// enqueueAllNVIDIADrivers lists all NVIDIADriver instances in the cluster and enqueues a reconcile
252+
// request for each instance. This is used to trigger reconciliation for all NVIDIADriver instances
253+
// when a relevant event occurs (e.g. ClusterPolicy/NVIDIADriver update, node label change, etc).
254+
func (r *NVIDIADriverReconciler) enqueueAllNVIDIADrivers(ctx context.Context) []reconcile.Request {
255+
logger := log.FromContext(ctx)
256+
list := &nvidiav1alpha1.NVIDIADriverList{}
257+
258+
err := r.List(ctx, list)
259+
if err != nil {
260+
logger.Error(err, "Unable to list NVIDIADriver resources")
261+
return []reconcile.Request{}
262+
}
263+
264+
reconcileRequests := make([]reconcile.Request, 0, len(list.Items))
265+
for _, nvidiaDriver := range list.Items {
266+
reconcileRequests = append(reconcileRequests,
267+
reconcile.Request{
268+
NamespacedName: types.NamespacedName{
269+
Name: nvidiaDriver.GetName(),
270+
Namespace: nvidiaDriver.GetNamespace(),
271+
},
272+
})
273+
}
274+
275+
return reconcileRequests
276+
}
277+
251278
// SetupWithManager sets up the controller with the Manager.
252279
func (r *NVIDIADriverReconciler) SetupWithManager(ctx context.Context, mgr ctrl.Manager) error {
253280
// Create state manager
@@ -277,11 +304,17 @@ func (r *NVIDIADriverReconciler) SetupWithManager(ctx context.Context, mgr ctrl.
277304
return err
278305
}
279306

280-
// Watch for changes to the primary resource NVIDIaDriver
307+
// Watch for changes to NVIDIADriver CRs. Whenever an event is generated for a NVIDIADriver CR,
308+
// enqueue a reconcile request for all NVIDIADriver instances.
309+
nvidiaDriverMapFn := func(ctx context.Context, _ *nvidiav1alpha1.NVIDIADriver) []reconcile.Request {
310+
return r.enqueueAllNVIDIADrivers(ctx)
311+
}
312+
313+
// Watch for changes to the primary resource NVIDIADriver
281314
err = c.Watch(source.Kind(
282315
mgr.GetCache(),
283316
&nvidiav1alpha1.NVIDIADriver{},
284-
&handler.TypedEnqueueRequestForObject[*nvidiav1alpha1.NVIDIADriver]{},
317+
handler.TypedEnqueueRequestsFromMapFunc(nvidiaDriverMapFn),
285318
predicate.TypedGenerationChangedPredicate[*nvidiav1alpha1.NVIDIADriver]{},
286319
),
287320
)
@@ -291,63 +324,21 @@ func (r *NVIDIADriverReconciler) SetupWithManager(ctx context.Context, mgr ctrl.
291324

292325
// Watch for changes to ClusterPolicy. Whenever an event is generated for ClusterPolicy, enqueue
293326
// a reconcile request for all NVIDIADriver instances.
294-
mapFn := func(ctx context.Context, cp *gpuv1.ClusterPolicy) []reconcile.Request {
295-
logger := log.FromContext(ctx)
296-
opts := []client.ListOption{}
297-
list := &nvidiav1alpha1.NVIDIADriverList{}
298-
299-
err := mgr.GetClient().List(ctx, list, opts...)
300-
if err != nil {
301-
logger.Error(err, "Unable to list NVIDIADriver resources")
302-
return []reconcile.Request{}
303-
}
304-
305-
reconcileRequests := []reconcile.Request{}
306-
for _, nvidiaDriver := range list.Items {
307-
reconcileRequests = append(reconcileRequests,
308-
reconcile.Request{
309-
NamespacedName: types.NamespacedName{
310-
Name: nvidiaDriver.GetName(),
311-
Namespace: nvidiaDriver.GetNamespace(),
312-
},
313-
})
314-
}
315-
316-
return reconcileRequests
327+
mapFn := func(ctx context.Context, _ *gpuv1.ClusterPolicy) []reconcile.Request {
328+
return r.enqueueAllNVIDIADrivers(ctx)
317329
}
318330

319-
// Watch for changes to the Nodes. Whenever an event is generated for ClusterPolicy, enqueue
331+
// Watch for changes to the Nodes. Whenever an event is generated for a Node, enqueue
320332
// a reconcile request for all NVIDIADriver instances.
321-
nodeMapFn := func(ctx context.Context, cp *corev1.Node) []reconcile.Request {
322-
logger := log.FromContext(ctx)
323-
opts := []client.ListOption{}
324-
list := &nvidiav1alpha1.NVIDIADriverList{}
325-
326-
err := mgr.GetClient().List(ctx, list, opts...)
327-
if err != nil {
328-
logger.Error(err, "Unable to list NVIDIADriver resources")
329-
return []reconcile.Request{}
330-
}
331-
332-
reconcileRequests := []reconcile.Request{}
333-
for _, nvidiaDriver := range list.Items {
334-
reconcileRequests = append(reconcileRequests,
335-
reconcile.Request{
336-
NamespacedName: types.NamespacedName{
337-
Name: nvidiaDriver.GetName(),
338-
Namespace: nvidiaDriver.GetNamespace(),
339-
},
340-
})
341-
}
342-
343-
return reconcileRequests
333+
nodeMapFn := func(ctx context.Context, _ *corev1.Node) []reconcile.Request {
334+
return r.enqueueAllNVIDIADrivers(ctx)
344335
}
345336

346337
err = c.Watch(
347338
source.Kind(
348339
mgr.GetCache(),
349340
&gpuv1.ClusterPolicy{},
350-
handler.TypedEnqueueRequestsFromMapFunc[*gpuv1.ClusterPolicy](mapFn),
341+
handler.TypedEnqueueRequestsFromMapFunc(mapFn),
351342
predicate.TypedGenerationChangedPredicate[*gpuv1.ClusterPolicy]{},
352343
),
353344
)
@@ -385,7 +376,7 @@ func (r *NVIDIADriverReconciler) SetupWithManager(ctx context.Context, mgr ctrl.
385376
err = c.Watch(
386377
source.Kind(mgr.GetCache(),
387378
&corev1.Node{},
388-
handler.TypedEnqueueRequestsFromMapFunc[*corev1.Node](nodeMapFn),
379+
handler.TypedEnqueueRequestsFromMapFunc(nodeMapFn),
389380
nodePredicate,
390381
),
391382
)

controllers/nvidiadriver_controller_test.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"bytes"
2121
"context"
2222
"errors"
23+
"sort"
2324
"testing"
2425

2526
"github.com/go-logr/logr"
@@ -193,3 +194,24 @@ func TestReconcile(t *testing.T) {
193194
})
194195
}
195196
}
197+
198+
func TestEnqueueAllNVIDIADrivers(t *testing.T) {
199+
scheme := runtime.NewScheme()
200+
require.NoError(t, nvidiav1alpha1.AddToScheme(scheme))
201+
202+
client := fake.NewClientBuilder().WithScheme(scheme).WithObjects(
203+
&nvidiav1alpha1.NVIDIADriver{ObjectMeta: metav1.ObjectMeta{Name: "driver-a", Namespace: "default"}},
204+
&nvidiav1alpha1.NVIDIADriver{ObjectMeta: metav1.ObjectMeta{Name: "driver-b", Namespace: "default"}},
205+
).Build()
206+
207+
reconciler := &NVIDIADriverReconciler{Client: client}
208+
requests := reconciler.enqueueAllNVIDIADrivers(context.Background())
209+
210+
require.Len(t, requests, 2)
211+
got := []string{
212+
requests[0].String(),
213+
requests[1].String(),
214+
}
215+
sort.Strings(got)
216+
require.Equal(t, []string{"default/driver-a", "default/driver-b"}, got)
217+
}

0 commit comments

Comments
 (0)