diff --git a/cmd/nvidia-dra-controller/controller.go b/cmd/nvidia-dra-controller/controller.go index 83bf61920..47b1316a2 100644 --- a/cmd/nvidia-dra-controller/controller.go +++ b/cmd/nvidia-dra-controller/controller.go @@ -19,9 +19,28 @@ package main import ( "context" "fmt" + "sync" + "time" + + "k8s.io/client-go/informers" + + "github.com/NVIDIA/k8s-dra-driver/pkg/flags" + nvinformers "github.com/NVIDIA/k8s-dra-driver/pkg/nvidia.com/resource/informers/externalversions" + "github.com/NVIDIA/k8s-dra-driver/pkg/workqueue" ) +type ManagerConfig struct { + clientsets flags.ClientSets + nvInformerFactory nvinformers.SharedInformerFactory + coreInformerFactory informers.SharedInformerFactory + workQueue *workqueue.WorkQueue +} + +type OwnerExistsFunc func(ctx context.Context, uid string) (bool, error) + type Controller struct { + waitGroup sync.WaitGroup + ImexManager *ImexManager MultiNodeEnvironmentManager *MultiNodeEnvironmentManager } @@ -32,22 +51,51 @@ func StartController(ctx context.Context, config *Config) (*Controller, error) { return nil, nil } + workQueue := workqueue.New(workqueue.DefaultControllerRateLimiter()) + nvInformerFactory := nvinformers.NewSharedInformerFactory(config.clientsets.Nvidia, 30*time.Second) + coreInformerFactory := informers.NewSharedInformerFactory(config.clientsets.Core, 30*time.Second) + + managerConfig := &ManagerConfig{ + clientsets: config.clientsets, + nvInformerFactory: nvInformerFactory, + coreInformerFactory: coreInformerFactory, + workQueue: workQueue, + } + imexManager, err := StartImexManager(ctx, config) if err != nil { return nil, fmt.Errorf("error starting IMEX manager: %w", err) } - mneManager, err := StartMultiNodeEnvironmentManager(ctx, config) + mneManager, err := NewMultiNodeEnvironmentManager(ctx, managerConfig) if err != nil { return nil, fmt.Errorf("error starting MultiNodeEnvironment manager: %w", err) } - m := &Controller{ + c := &Controller{ ImexManager: imexManager, MultiNodeEnvironmentManager: mneManager, } - return m, nil + c.waitGroup.Add(3) + go func() { + defer c.waitGroup.Done() + nvInformerFactory.Start(ctx.Done()) + }() + go func() { + defer c.waitGroup.Done() + coreInformerFactory.Start(ctx.Done()) + }() + go func() { + defer c.waitGroup.Done() + workQueue.Run(ctx) + }() + + if err := c.MultiNodeEnvironmentManager.WaitForCacheSync(ctx); err != nil { + return nil, fmt.Errorf("error syncing cache: %w", err) + } + + return c, nil } // Stop stops a running Controller. @@ -55,10 +103,12 @@ func (m *Controller) Stop() error { if m == nil { return nil } - imErr := m.ImexManager.Stop() - mnErr := m.MultiNodeEnvironmentManager.Stop() - if imErr != nil || mnErr != nil { - return fmt.Errorf("IMEX manager error: %w, MultiNodeEnvironment manager error: %w", imErr, mnErr) + + m.waitGroup.Wait() + + if err := m.ImexManager.Stop(); err != nil { + return fmt.Errorf("error stopping IMEX manager: %w", err) } + return nil } diff --git a/cmd/nvidia-dra-controller/deviceclass.go b/cmd/nvidia-dra-controller/deviceclass.go new file mode 100644 index 000000000..5166ee958 --- /dev/null +++ b/cmd/nvidia-dra-controller/deviceclass.go @@ -0,0 +1,191 @@ +/* + * Copyright (c) 2025 NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package main + +import ( + "context" + "fmt" + "time" + + resourceapi "k8s.io/api/resource/v1beta1" + "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + resourcelisters "k8s.io/client-go/listers/resource/v1beta1" + "k8s.io/client-go/tools/cache" + "k8s.io/klog/v2" + + nvapi "github.com/NVIDIA/k8s-dra-driver/api/nvidia.com/resource/gpu/v1alpha1" + "github.com/NVIDIA/k8s-dra-driver/pkg/flags" +) + +type DeviceClassManager struct { + clientsets flags.ClientSets + ownerExists func(ctx context.Context, uid string) (bool, error) + + informer cache.SharedIndexInformer + lister resourcelisters.DeviceClassLister +} + +func NewDeviceClassManager(ctx context.Context, config *ManagerConfig, ownerExists OwnerExistsFunc) (*DeviceClassManager, error) { + informer := config.coreInformerFactory.Resource().V1beta1().DeviceClasses().Informer() + lister := resourcelisters.NewDeviceClassLister(informer.GetIndexer()) + + m := &DeviceClassManager{ + clientsets: config.clientsets, + ownerExists: ownerExists, + informer: informer, + lister: lister, + } + + _, err := informer.AddEventHandler(cache.ResourceEventHandlerFuncs{ + AddFunc: func(obj any) { + config.workQueue.Enqueue(obj, m.onAddOrUpdate) + }, + UpdateFunc: func(objOld, objNew any) { + config.workQueue.Enqueue(objNew, m.onAddOrUpdate) + }, + }) + if err != nil { + return nil, fmt.Errorf("error adding event handlers for DeviceClass informer: %w", err) + } + + return m, nil +} + +func (m *DeviceClassManager) WaitForCacheSync(ctx context.Context) error { + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + + if !cache.WaitForCacheSync(ctx.Done(), m.informer.HasSynced) { + return fmt.Errorf("informer cache sync for DeviceClasses failed") + } + + return nil +} + +func (m *DeviceClassManager) Create(ctx context.Context, name string, ownerReference metav1.OwnerReference) (*resourceapi.DeviceClass, error) { + if name != "" { + dc, err := m.lister.Get(name) + if err == nil { + if len(dc.OwnerReferences) != 1 && dc.OwnerReferences[0] != ownerReference { + return nil, fmt.Errorf("DeviceClass '%s' exists without expected OwnerReference: %v", name, ownerReference) + } + return dc, nil + } + if !errors.IsNotFound(err) { + return nil, fmt.Errorf("error retrieving DeviceClass: %w", err) + } + } + + deviceClass := &resourceapi.DeviceClass{ + ObjectMeta: metav1.ObjectMeta{ + OwnerReferences: []metav1.OwnerReference{ownerReference}, + Finalizers: []string{multiNodeEnvironmentFinalizer}, + }, + Spec: resourceapi.DeviceClassSpec{ + Selectors: []resourceapi.DeviceSelector{ + { + CEL: &resourceapi.CELDeviceSelector{ + Expression: "device.driver == 'gpu.nvidia.com' && device.attributes['gpu.nvidia.com'].type == 'imex-channel'", + }, + }, + }, + }, + } + + if name == "" { + deviceClass.GenerateName = ownerReference.Name + } else { + deviceClass.Name = name + } + + dc, err := m.clientsets.Core.ResourceV1beta1().DeviceClasses().Create(ctx, deviceClass, metav1.CreateOptions{}) + if err != nil { + return nil, fmt.Errorf("error creating DeviceClass: %w", err) + } + + return dc, nil +} + +func (m *DeviceClassManager) Delete(ctx context.Context, name string) error { + err := m.clientsets.Core.ResourceV1beta1().DeviceClasses().Delete(ctx, name, metav1.DeleteOptions{}) + if err != nil && !errors.IsNotFound(err) { + return fmt.Errorf("erroring deleting DeviceClass: %w", err) + } + return nil +} + +func (m *DeviceClassManager) RemoveFinalizer(ctx context.Context, name string) error { + dc, err := m.lister.Get(name) + if err != nil && errors.IsNotFound(err) { + return nil + } + if err != nil { + return fmt.Errorf("error retrieving DeviceClass: %w", err) + } + + newDC := dc.DeepCopy() + + newDC.Finalizers = []string{} + for _, f := range dc.Finalizers { + if f != multiNodeEnvironmentFinalizer { + newDC.Finalizers = append(newDC.Finalizers, f) + } + } + + _, err = m.clientsets.Core.ResourceV1beta1().DeviceClasses().Update(ctx, newDC, metav1.UpdateOptions{}) + if err != nil { + return fmt.Errorf("error updating DeviceClass: %w", err) + } + + return nil +} + +func (m *DeviceClassManager) onAddOrUpdate(ctx context.Context, obj any) error { + dc, ok := obj.(*resourceapi.DeviceClass) + if !ok { + return fmt.Errorf("failed to cast to DeviceClass") + } + + klog.Infof("Processing added or updated DeviceClass: %s", dc.Name) + + if len(dc.OwnerReferences) != 1 { + return nil + } + + if dc.OwnerReferences[0].Kind != nvapi.MultiNodeEnvironmentKind { + return nil + } + + exists, err := m.ownerExists(ctx, string(dc.OwnerReferences[0].UID)) + if err != nil { + return fmt.Errorf("error checking if owner exists: %w", err) + } + if exists { + return nil + } + + if err := m.RemoveFinalizer(ctx, dc.Name); err != nil { + return fmt.Errorf("error removing finalizer on DeviceClass '%s': %w", dc.Name, err) + } + + if err := m.Delete(ctx, dc.Name); err != nil { + return fmt.Errorf("error deleting DeviceClass '%s': %w", dc.Name, err) + } + + return nil +} diff --git a/cmd/nvidia-dra-controller/mnenv.go b/cmd/nvidia-dra-controller/mnenv.go index 8f5b9e804..d5ae8f58e 100644 --- a/cmd/nvidia-dra-controller/mnenv.go +++ b/cmd/nvidia-dra-controller/mnenv.go @@ -19,23 +19,16 @@ package main import ( "context" "fmt" - "sync" "time" - resourceapi "k8s.io/api/resource/v1beta1" - "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/client-go/informers" - resourcelisters "k8s.io/client-go/listers/resource/v1beta1" "k8s.io/client-go/tools/cache" "k8s.io/klog/v2" "k8s.io/utils/ptr" nvapi "github.com/NVIDIA/k8s-dra-driver/api/nvidia.com/resource/gpu/v1alpha1" "github.com/NVIDIA/k8s-dra-driver/pkg/flags" - nvinformers "github.com/NVIDIA/k8s-dra-driver/pkg/nvidia.com/resource/informers/externalversions" nvlisters "github.com/NVIDIA/k8s-dra-driver/pkg/nvidia.com/resource/listers/gpu/v1alpha1" - "github.com/NVIDIA/k8s-dra-driver/pkg/workqueue" ) const ( @@ -44,40 +37,26 @@ const ( type MultiNodeEnvironmentManager struct { clientsets flags.ClientSets - waitGroup sync.WaitGroup - multiNodeEnvironmentInformer cache.SharedIndexInformer - multiNodeEnvironmentLister nvlisters.MultiNodeEnvironmentLister - resourceClaimLister resourcelisters.ResourceClaimLister - deviceClassLister resourcelisters.DeviceClassLister + informer cache.SharedIndexInformer + lister nvlisters.MultiNodeEnvironmentLister + + resourceClaimManager *ResourceClaimManager + deviceClassManager *DeviceClassManager } // StartManager starts a MultiNodeEnvironmentManager. -func StartMultiNodeEnvironmentManager(ctx context.Context, config *Config) (*MultiNodeEnvironmentManager, error) { - queue := workqueue.New(workqueue.DefaultControllerRateLimiter()) - - nvInformerFactory := nvinformers.NewSharedInformerFactory(config.clientsets.Nvidia, 30*time.Second) - coreInformerFactory := informers.NewSharedInformerFactory(config.clientsets.Core, 30*time.Second) - - mneInformer := nvInformerFactory.Gpu().V1alpha1().MultiNodeEnvironments().Informer() - mneLister := nvlisters.NewMultiNodeEnvironmentLister(mneInformer.GetIndexer()) - - rcInformer := coreInformerFactory.Resource().V1beta1().ResourceClaims().Informer() - rcLister := resourcelisters.NewResourceClaimLister(rcInformer.GetIndexer()) - - dcInformer := coreInformerFactory.Resource().V1beta1().DeviceClasses().Informer() - dcLister := resourcelisters.NewDeviceClassLister(dcInformer.GetIndexer()) +func NewMultiNodeEnvironmentManager(ctx context.Context, config *ManagerConfig) (*MultiNodeEnvironmentManager, error) { + informer := config.nvInformerFactory.Gpu().V1alpha1().MultiNodeEnvironments().Informer() + lister := nvlisters.NewMultiNodeEnvironmentLister(informer.GetIndexer()) m := &MultiNodeEnvironmentManager{ - clientsets: config.clientsets, - multiNodeEnvironmentInformer: mneInformer, - multiNodeEnvironmentLister: mneLister, - resourceClaimLister: rcLister, - deviceClassLister: dcLister, + clientsets: config.clientsets, + informer: informer, + lister: lister, } - var err error - err = mneInformer.AddIndexers(cache.Indexers{ + err := informer.AddIndexers(cache.Indexers{ "uid": func(obj interface{}) ([]string, error) { mne, ok := obj.(*nvapi.MultiNodeEnvironment) if !ok { @@ -90,64 +69,69 @@ func StartMultiNodeEnvironmentManager(ctx context.Context, config *Config) (*Mul return nil, fmt.Errorf("error adding indexer for MultiNodeEnvironment UUIDs: %w", err) } - _, err = mneInformer.AddEventHandler(cache.ResourceEventHandlerFuncs{ - AddFunc: func(obj any) { queue.Enqueue(obj, m.onMultiNodeEnvironmentAdd) }, + _, err = informer.AddEventHandler(cache.ResourceEventHandlerFuncs{ + AddFunc: func(obj any) { + config.workQueue.Enqueue(obj, m.onMultiNodeEnvironmentAdd) + }, }) if err != nil { return nil, fmt.Errorf("error adding event handlers for MultiNodeEnvironment informer: %w", err) } - _, err = rcInformer.AddEventHandler(cache.ResourceEventHandlerFuncs{ - AddFunc: func(obj any) { queue.Enqueue(obj, m.onResourceClaimAddOrUpdate) }, - UpdateFunc: func(objOld, objNew any) { queue.Enqueue(objNew, m.onResourceClaimAddOrUpdate) }, - }) + m.resourceClaimManager, err = NewResourceClaimManager(ctx, config, m.Exists) if err != nil { - return nil, fmt.Errorf("error adding event handlers for ResourceClaim informer: %w", err) + return nil, fmt.Errorf("error creating ResourceClaim manager: %w", err) } - _, err = dcInformer.AddEventHandler(cache.ResourceEventHandlerFuncs{ - AddFunc: func(obj any) { queue.Enqueue(obj, m.onDeviceClassAddOrUpdate) }, - UpdateFunc: func(objOld, objNew any) { queue.Enqueue(objNew, m.onDeviceClassAddOrUpdate) }, - }) + m.deviceClassManager, err = NewDeviceClassManager(ctx, config, m.Exists) if err != nil { - return nil, fmt.Errorf("error adding event handlers for DeviceClass informer: %w", err) + return nil, fmt.Errorf("error creating DeviceClass manager: %w", err) } - m.waitGroup.Add(3) - go func() { - defer m.waitGroup.Done() - nvInformerFactory.Start(ctx.Done()) - }() - go func() { - defer m.waitGroup.Done() - coreInformerFactory.Start(ctx.Done()) - }() - go func() { - defer m.waitGroup.Done() - queue.Run(ctx.Done()) - }() + return m, nil +} - if !cache.WaitForCacheSync(ctx.Done(), mneInformer.HasSynced, rcInformer.HasSynced, dcInformer.HasSynced) { - klog.Warning("Cache sync failed; retrying in 5 seconds") - time.Sleep(5 * time.Second) - if !cache.WaitForCacheSync(ctx.Done(), mneInformer.HasSynced, rcInformer.HasSynced, dcInformer.HasSynced) { - return nil, fmt.Errorf("informer cache sync failed twice") - } +// WaitForCacheSync waits for the cache for MultiNodeEnvironments to sync. +func (m *MultiNodeEnvironmentManager) WaitForCacheSync(ctx context.Context) error { + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + + if !cache.WaitForCacheSync(ctx.Done(), m.informer.HasSynced) { + return fmt.Errorf("informer cache sync for MultiNodeEnvironments failed") } - return m, nil -} + if err := m.resourceClaimManager.WaitForCacheSync(ctx); err != nil { + return fmt.Errorf("error syncing dependency cache: %w", err) + } -// Stop stops a running MultiNodeEnvironmentManager. -func (m *MultiNodeEnvironmentManager) Stop() error { - if m == nil { - return nil + if err := m.deviceClassManager.WaitForCacheSync(ctx); err != nil { + return fmt.Errorf("error syncing dependency cache: %w", err) } - m.waitGroup.Wait() + return nil } -func (m *MultiNodeEnvironmentManager) onMultiNodeEnvironmentAdd(obj any) error { +// Exists checks if a MultiNodeEnvironment with a specific UID exists. +func (m *MultiNodeEnvironmentManager) Exists(ctx context.Context, uid string) (bool, error) { + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + + if !cache.WaitForCacheSync(ctx.Done(), m.informer.HasSynced) { + return false, fmt.Errorf("cache sync failed for MultiNodeEnvironment") + } + + mnes, err := m.informer.GetIndexer().ByIndex("uid", uid) + if err != nil { + return false, fmt.Errorf("error retrieving MultiNodeInformer OwnerReference by UID from indexer: %w", err) + } + if len(mnes) == 0 { + return false, nil + } + + return true, nil +} + +func (m *MultiNodeEnvironmentManager) onMultiNodeEnvironmentAdd(ctx context.Context, obj any) error { mne, ok := obj.(*nvapi.MultiNodeEnvironment) if !ok { return fmt.Errorf("failed to cast to MultiNodeEnvironment") @@ -167,242 +151,16 @@ func (m *MultiNodeEnvironmentManager) onMultiNodeEnvironmentAdd(obj any) error { Controller: ptr.To(true), } - dc, err := m.createDeviceClass(mne.Spec.DeviceClassName, ownerReference) + dc, err := m.deviceClassManager.Create(ctx, mne.Spec.DeviceClassName, ownerReference) if err != nil { return fmt.Errorf("error creating DeviceClass '%s': %w", "", err) } if mne.Spec.ResourceClaimName != "" { - if _, err := m.createResourceClaim(mne.Namespace, mne.Spec.ResourceClaimName, dc.Name, ownerReference); err != nil { + if _, err := m.resourceClaimManager.Create(ctx, mne.Namespace, mne.Spec.ResourceClaimName, dc.Name, ownerReference); err != nil { return fmt.Errorf("error creating ResourceClaim '%s/%s': %w", mne.Namespace, mne.Spec.ResourceClaimName, err) } } return nil } - -func (m *MultiNodeEnvironmentManager) onDeviceClassAddOrUpdate(obj any) error { - dc, ok := obj.(*resourceapi.DeviceClass) - if !ok { - return fmt.Errorf("failed to cast to DeviceClass") - } - - klog.Infof("Processing added or updated DeviceClass: %s", dc.Name) - - if len(dc.OwnerReferences) != 1 { - return nil - } - - if dc.OwnerReferences[0].Kind != nvapi.MultiNodeEnvironmentKind { - return nil - } - - if !cache.WaitForCacheSync(context.Background().Done(), m.multiNodeEnvironmentInformer.HasSynced) { - return fmt.Errorf("cache sync failed for MultiNodeEnvironment") - } - - mnes, err := m.multiNodeEnvironmentInformer.GetIndexer().ByIndex("uid", string(dc.OwnerReferences[0].UID)) - if err != nil { - return fmt.Errorf("error retrieving MultiNodeInformer OwnerReference by UID from indexer: %w", err) - } - if len(mnes) != 0 { - return nil - } - - if err := m.removeDeviceClassFinalizer(dc.Name); err != nil { - return fmt.Errorf("error removing finalizer on DeviceClass '%s': %w", dc.Name, err) - } - - if err := m.deleteDeviceClass(dc.Name); err != nil { - return fmt.Errorf("error deleting DeviceClass '%s': %w", dc.Name, err) - } - - return nil -} - -func (m *MultiNodeEnvironmentManager) onResourceClaimAddOrUpdate(obj any) error { - rc, ok := obj.(*resourceapi.ResourceClaim) - if !ok { - return fmt.Errorf("failed to cast to ResourceClaim") - } - - klog.Infof("Processing added or updated ResourceClaim: %s/%s", rc.Namespace, rc.Name) - - if len(rc.OwnerReferences) != 1 { - return nil - } - - if rc.OwnerReferences[0].Kind != nvapi.MultiNodeEnvironmentKind { - return nil - } - - if !cache.WaitForCacheSync(context.Background().Done(), m.multiNodeEnvironmentInformer.HasSynced) { - return fmt.Errorf("cache sync failed for MultiNodeEnvironment") - } - - mnes, err := m.multiNodeEnvironmentInformer.GetIndexer().ByIndex("uid", string(rc.OwnerReferences[0].UID)) - if err != nil { - return fmt.Errorf("error retrieving MultiNodeInformer OwnerReference by UID from indexer: %w", err) - } - if len(mnes) != 0 { - return nil - } - - if err := m.removeResourceClaimFinalizer(rc.Namespace, rc.Name); err != nil { - return fmt.Errorf("error removing finalizer on ResourceClaim '%s/%s': %w", rc.Namespace, rc.Name, err) - } - - if err := m.deleteResourceClaim(rc.Namespace, rc.Name); err != nil { - return fmt.Errorf("error deleting ResourceClaim '%s/%s': %w", rc.Namespace, rc.Name, err) - } - - return nil -} - -func (m *MultiNodeEnvironmentManager) createDeviceClass(name string, ownerReference metav1.OwnerReference) (*resourceapi.DeviceClass, error) { - if name != "" { - dc, err := m.deviceClassLister.Get(name) - if err == nil { - if len(dc.OwnerReferences) != 1 && dc.OwnerReferences[0] != ownerReference { - return nil, fmt.Errorf("DeviceClass '%s' exists without expected OwnerReference: %v", name, ownerReference) - } - return dc, nil - } - if !errors.IsNotFound(err) { - return nil, fmt.Errorf("error retrieving DeviceClass: %w", err) - } - } - - deviceClass := &resourceapi.DeviceClass{ - ObjectMeta: metav1.ObjectMeta{ - OwnerReferences: []metav1.OwnerReference{ownerReference}, - Finalizers: []string{multiNodeEnvironmentFinalizer}, - }, - Spec: resourceapi.DeviceClassSpec{ - Selectors: []resourceapi.DeviceSelector{ - { - CEL: &resourceapi.CELDeviceSelector{ - Expression: "device.driver == 'gpu.nvidia.com' && device.attributes['gpu.nvidia.com'].type == 'imex-channel'", - }, - }, - }, - }, - } - - if name == "" { - deviceClass.GenerateName = ownerReference.Name - } else { - deviceClass.Name = name - } - - dc, err := m.clientsets.Core.ResourceV1beta1().DeviceClasses().Create(context.Background(), deviceClass, metav1.CreateOptions{}) - if err != nil { - return nil, fmt.Errorf("error creating DeviceClass: %w", err) - } - - return dc, nil -} - -func (m *MultiNodeEnvironmentManager) createResourceClaim(namespace, name, deviceClassName string, ownerReference metav1.OwnerReference) (*resourceapi.ResourceClaim, error) { - rc, err := m.resourceClaimLister.ResourceClaims(namespace).Get(name) - if err == nil { - if len(rc.OwnerReferences) != 1 && rc.OwnerReferences[0] != ownerReference { - return nil, fmt.Errorf("ResourceClaim '%s/%s' exists without expected OwnerReference: %v", namespace, name, ownerReference) - } - return rc, nil - } - if !errors.IsNotFound(err) { - return nil, fmt.Errorf("error retrieving ResourceClaim: %w", err) - } - - resourceClaim := &resourceapi.ResourceClaim{ - ObjectMeta: metav1.ObjectMeta{ - Name: name, - Namespace: namespace, - OwnerReferences: []metav1.OwnerReference{ownerReference}, - Finalizers: []string{multiNodeEnvironmentFinalizer}, - }, - Spec: resourceapi.ResourceClaimSpec{ - Devices: resourceapi.DeviceClaim{ - Requests: []resourceapi.DeviceRequest{{ - Name: "device", DeviceClassName: deviceClassName, - }}, - }, - }, - } - - rc, err = m.clientsets.Core.ResourceV1beta1().ResourceClaims(resourceClaim.Namespace).Create(context.Background(), resourceClaim, metav1.CreateOptions{}) - if err != nil { - return nil, fmt.Errorf("error creating ResourceClaim: %w", err) - } - - return rc, nil -} - -func (m *MultiNodeEnvironmentManager) removeDeviceClassFinalizer(name string) error { - dc, err := m.deviceClassLister.Get(name) - if err != nil && errors.IsNotFound(err) { - return nil - } - if err != nil { - return fmt.Errorf("error retrieving DeviceClass: %w", err) - } - - newDC := dc.DeepCopy() - - newDC.Finalizers = []string{} - for _, f := range dc.Finalizers { - if f != multiNodeEnvironmentFinalizer { - newDC.Finalizers = append(newDC.Finalizers, f) - } - } - - _, err = m.clientsets.Core.ResourceV1beta1().DeviceClasses().Update(context.Background(), newDC, metav1.UpdateOptions{}) - if err != nil { - return fmt.Errorf("error updating DeviceClass: %w", err) - } - - return nil -} - -func (m *MultiNodeEnvironmentManager) removeResourceClaimFinalizer(namespace, name string) error { - rc, err := m.resourceClaimLister.ResourceClaims(namespace).Get(name) - if err != nil && errors.IsNotFound(err) { - return nil - } - if err != nil { - return fmt.Errorf("error retrieving ResourceClaim: %w", err) - } - - newRC := rc.DeepCopy() - - newRC.Finalizers = []string{} - for _, f := range rc.Finalizers { - if f != multiNodeEnvironmentFinalizer { - newRC.Finalizers = append(newRC.Finalizers, f) - } - } - - _, err = m.clientsets.Core.ResourceV1beta1().ResourceClaims(namespace).Update(context.Background(), newRC, metav1.UpdateOptions{}) - if err != nil { - return fmt.Errorf("error updating ResourceClaim: %w", err) - } - - return nil -} - -func (m *MultiNodeEnvironmentManager) deleteDeviceClass(name string) error { - err := m.clientsets.Core.ResourceV1beta1().DeviceClasses().Delete(context.Background(), name, metav1.DeleteOptions{}) - if err != nil && !errors.IsNotFound(err) { - return fmt.Errorf("erroring deleting DeviceClass: %w", err) - } - return nil -} - -func (m *MultiNodeEnvironmentManager) deleteResourceClaim(namespace, name string) error { - err := m.clientsets.Core.ResourceV1beta1().ResourceClaims(namespace).Delete(context.Background(), name, metav1.DeleteOptions{}) - if err != nil && !errors.IsNotFound(err) { - return fmt.Errorf("erroring deleting ResourceClaim: %w", err) - } - return nil -} diff --git a/cmd/nvidia-dra-controller/resourceclaim.go b/cmd/nvidia-dra-controller/resourceclaim.go new file mode 100644 index 000000000..05cf0d3fb --- /dev/null +++ b/cmd/nvidia-dra-controller/resourceclaim.go @@ -0,0 +1,183 @@ +/* + * Copyright (c) 2025 NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package main + +import ( + "context" + "fmt" + "time" + + resourceapi "k8s.io/api/resource/v1beta1" + "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + resourcelisters "k8s.io/client-go/listers/resource/v1beta1" + "k8s.io/client-go/tools/cache" + "k8s.io/klog/v2" + + nvapi "github.com/NVIDIA/k8s-dra-driver/api/nvidia.com/resource/gpu/v1alpha1" + "github.com/NVIDIA/k8s-dra-driver/pkg/flags" +) + +type ResourceClaimManager struct { + clientsets flags.ClientSets + ownerExists func(ctx context.Context, uid string) (bool, error) + + informer cache.SharedIndexInformer + lister resourcelisters.ResourceClaimLister +} + +func NewResourceClaimManager(ctx context.Context, config *ManagerConfig, ownerExists OwnerExistsFunc) (*ResourceClaimManager, error) { + informer := config.coreInformerFactory.Resource().V1beta1().ResourceClaims().Informer() + lister := resourcelisters.NewResourceClaimLister(informer.GetIndexer()) + + m := &ResourceClaimManager{ + clientsets: config.clientsets, + ownerExists: ownerExists, + informer: informer, + lister: lister, + } + + _, err := informer.AddEventHandler(cache.ResourceEventHandlerFuncs{ + AddFunc: func(obj any) { + config.workQueue.Enqueue(obj, m.onAddOrUpdate) + }, + UpdateFunc: func(objOld, objNew any) { + config.workQueue.Enqueue(objNew, m.onAddOrUpdate) + }, + }) + if err != nil { + return nil, fmt.Errorf("error adding event handlers for ResourceClaim informer: %w", err) + } + + return m, nil +} + +func (m *ResourceClaimManager) WaitForCacheSync(ctx context.Context) error { + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + + if !cache.WaitForCacheSync(ctx.Done(), m.informer.HasSynced) { + + return fmt.Errorf("informer cache sync for ResourceClaims failed") + } + return nil +} + +func (m *ResourceClaimManager) Create(ctx context.Context, namespace, name, deviceClassName string, ownerReference metav1.OwnerReference) (*resourceapi.ResourceClaim, error) { + rc, err := m.lister.ResourceClaims(namespace).Get(name) + if err == nil { + if len(rc.OwnerReferences) != 1 && rc.OwnerReferences[0] != ownerReference { + return nil, fmt.Errorf("ResourceClaim '%s/%s' exists without expected OwnerReference: %v", namespace, name, ownerReference) + } + return rc, nil + } + if !errors.IsNotFound(err) { + return nil, fmt.Errorf("error retrieving ResourceClaim: %w", err) + } + + resourceClaim := &resourceapi.ResourceClaim{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Namespace: namespace, + OwnerReferences: []metav1.OwnerReference{ownerReference}, + Finalizers: []string{multiNodeEnvironmentFinalizer}, + }, + Spec: resourceapi.ResourceClaimSpec{ + Devices: resourceapi.DeviceClaim{ + Requests: []resourceapi.DeviceRequest{{ + Name: "device", DeviceClassName: deviceClassName, + }}, + }, + }, + } + + rc, err = m.clientsets.Core.ResourceV1beta1().ResourceClaims(resourceClaim.Namespace).Create(ctx, resourceClaim, metav1.CreateOptions{}) + if err != nil { + return nil, fmt.Errorf("error creating ResourceClaim: %w", err) + } + + return rc, nil +} + +func (m *ResourceClaimManager) Delete(ctx context.Context, namespace, name string) error { + err := m.clientsets.Core.ResourceV1beta1().ResourceClaims(namespace).Delete(ctx, name, metav1.DeleteOptions{}) + if err != nil && !errors.IsNotFound(err) { + return fmt.Errorf("erroring deleting ResourceClaim: %w", err) + } + return nil +} + +func (m *ResourceClaimManager) RemoveFinalizer(ctx context.Context, namespace, name string) error { + rc, err := m.lister.ResourceClaims(namespace).Get(name) + if err != nil && errors.IsNotFound(err) { + return nil + } + if err != nil { + return fmt.Errorf("error retrieving ResourceClaim: %w", err) + } + + newRC := rc.DeepCopy() + + newRC.Finalizers = []string{} + for _, f := range rc.Finalizers { + if f != multiNodeEnvironmentFinalizer { + newRC.Finalizers = append(newRC.Finalizers, f) + } + } + + _, err = m.clientsets.Core.ResourceV1beta1().ResourceClaims(namespace).Update(ctx, newRC, metav1.UpdateOptions{}) + if err != nil { + return fmt.Errorf("error updating ResourceClaim: %w", err) + } + + return nil +} + +func (m *ResourceClaimManager) onAddOrUpdate(ctx context.Context, obj any) error { + rc, ok := obj.(*resourceapi.ResourceClaim) + if !ok { + return fmt.Errorf("failed to cast to ResourceClaim") + } + + klog.Infof("Processing added or updated ResourceClaim: %s/%s", rc.Namespace, rc.Name) + + if len(rc.OwnerReferences) != 1 { + return nil + } + + if rc.OwnerReferences[0].Kind != nvapi.MultiNodeEnvironmentKind { + return nil + } + + exists, err := m.ownerExists(ctx, string(rc.OwnerReferences[0].UID)) + if err != nil { + return fmt.Errorf("error checking if owner exists: %w", err) + } + if exists { + return nil + } + + if err := m.RemoveFinalizer(ctx, rc.Namespace, rc.Name); err != nil { + return fmt.Errorf("error removing finalizer on ResourceClaim '%s/%s': %w", rc.Namespace, rc.Name, err) + } + + if err := m.Delete(ctx, rc.Namespace, rc.Name); err != nil { + return fmt.Errorf("error deleting ResourceClaim '%s/%s': %w", rc.Namespace, rc.Name, err) + } + + return nil +} diff --git a/pkg/workqueue/workqueue.go b/pkg/workqueue/workqueue.go index f7a139a08..3488ba78f 100644 --- a/pkg/workqueue/workqueue.go +++ b/pkg/workqueue/workqueue.go @@ -17,7 +17,9 @@ package workqueue import ( + "context" "fmt" + "time" "k8s.io/apimachinery/pkg/runtime" "k8s.io/client-go/util/workqueue" @@ -30,7 +32,7 @@ type WorkQueue struct { type WorkItem struct { Object any - Callback func(obj any) error + Callback func(ctx context.Context, obj any) error } func DefaultControllerRateLimiter() workqueue.TypedRateLimiter[any] { @@ -39,25 +41,27 @@ func DefaultControllerRateLimiter() workqueue.TypedRateLimiter[any] { func New(r workqueue.TypedRateLimiter[any]) *WorkQueue { queue := workqueue.NewTypedRateLimitingQueue(r) - return &WorkQueue{queue} + return &WorkQueue{queue: queue} } -func (q *WorkQueue) Run(done <-chan struct{}) { +func (q *WorkQueue) Run(ctx context.Context) { go func() { - <-done + <-ctx.Done() q.queue.ShutDown() }() for { select { - case <-done: + case <-ctx.Done(): return default: - q.processNextWorkItem() + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) + q.processNextWorkItem(ctx) + cancel() } } } -func (q *WorkQueue) Enqueue(obj any, callback func(obj any) error) { +func (q *WorkQueue) Enqueue(obj any, callback func(ctx context.Context, obj any) error) { runtimeObj, ok := obj.(runtime.Object) if !ok { klog.Warningf("unexpected object type %T: runtime.Object required", obj) @@ -72,7 +76,7 @@ func (q *WorkQueue) Enqueue(obj any, callback func(obj any) error) { q.queue.AddRateLimited(workItem) } -func (q *WorkQueue) processNextWorkItem() { +func (q *WorkQueue) processNextWorkItem(ctx context.Context) { item, shutdown := q.queue.Get() if shutdown { return @@ -85,7 +89,7 @@ func (q *WorkQueue) processNextWorkItem() { return } - err := q.reconcile(workItem) + err := q.reconcile(ctx, workItem) if err != nil { klog.Errorf("Failed to reconcile work item %v: %v", workItem.Object, err) q.queue.AddRateLimited(workItem) @@ -94,9 +98,9 @@ func (q *WorkQueue) processNextWorkItem() { } } -func (q *WorkQueue) reconcile(workItem *WorkItem) error { +func (q *WorkQueue) reconcile(ctx context.Context, workItem *WorkItem) error { if workItem.Callback == nil { return fmt.Errorf("no callback to process work item: %+v", workItem) } - return workItem.Callback(workItem.Object) + return workItem.Callback(ctx, workItem.Object) }