diff --git a/cmd/nvidia-dra-controller/controller.go b/cmd/nvidia-dra-controller/controller.go index 47b1316a2..fe9412178 100644 --- a/cmd/nvidia-dra-controller/controller.go +++ b/cmd/nvidia-dra-controller/controller.go @@ -19,95 +19,50 @@ package main import ( "context" "fmt" - "sync" - "time" - - "k8s.io/client-go/informers" "github.com/NVIDIA/k8s-dra-driver/pkg/flags" - nvinformers "github.com/NVIDIA/k8s-dra-driver/pkg/nvidia.com/resource/informers/externalversions" "github.com/NVIDIA/k8s-dra-driver/pkg/workqueue" ) +// ManagerConfig defines the common config options to pass to all managers. type ManagerConfig struct { - clientsets flags.ClientSets - nvInformerFactory nvinformers.SharedInformerFactory - coreInformerFactory informers.SharedInformerFactory - workQueue *workqueue.WorkQueue + driverName string + driverNamespace string + clientsets flags.ClientSets + workQueue *workqueue.WorkQueue } -type OwnerExistsFunc func(ctx context.Context, uid string) (bool, error) - +// Controller defines the type to represent the controller. type Controller struct { - waitGroup sync.WaitGroup - - ImexManager *ImexManager - MultiNodeEnvironmentManager *MultiNodeEnvironmentManager + config *Config } -// StartController starts a Controller. -func StartController(ctx context.Context, config *Config) (*Controller, error) { - if !config.flags.deviceClasses.Has(ImexChannelType) { - return nil, nil - } +// NewController creates a new Controller. +func NewController(config *Config) *Controller { + return &Controller{config: config} +} +// Run runs a Controller. +func (c *Controller) Run(ctx context.Context) error { workQueue := workqueue.New(workqueue.DefaultControllerRateLimiter()) - nvInformerFactory := nvinformers.NewSharedInformerFactory(config.clientsets.Nvidia, 30*time.Second) - coreInformerFactory := informers.NewSharedInformerFactory(config.clientsets.Core, 30*time.Second) managerConfig := &ManagerConfig{ - clientsets: config.clientsets, - nvInformerFactory: nvInformerFactory, - coreInformerFactory: coreInformerFactory, - workQueue: workQueue, - } - - imexManager, err := StartImexManager(ctx, config) - if err != nil { - return nil, fmt.Errorf("error starting IMEX manager: %w", err) + driverName: c.config.driverName, + driverNamespace: c.config.flags.namespace, + clientsets: c.config.clientsets, + workQueue: workQueue, } - mneManager, err := NewMultiNodeEnvironmentManager(ctx, managerConfig) - if err != nil { - return nil, fmt.Errorf("error starting MultiNodeEnvironment manager: %w", err) - } - - c := &Controller{ - ImexManager: imexManager, - MultiNodeEnvironmentManager: mneManager, - } - - c.waitGroup.Add(3) - go func() { - defer c.waitGroup.Done() - nvInformerFactory.Start(ctx.Done()) - }() - go func() { - defer c.waitGroup.Done() - coreInformerFactory.Start(ctx.Done()) - }() - go func() { - defer c.waitGroup.Done() - workQueue.Run(ctx) - }() - - if err := c.MultiNodeEnvironmentManager.WaitForCacheSync(ctx); err != nil { - return nil, fmt.Errorf("error syncing cache: %w", err) - } - - return c, nil -} + mneManager := NewMultiNodeEnvironmentManager(managerConfig) -// Stop stops a running Controller. 
-func (m *Controller) Stop() error { - if m == nil { - return nil + if err := mneManager.Start(ctx); err != nil { + return fmt.Errorf("error starting MultiNodeEnvironment manager: %w", err) } - m.waitGroup.Wait() + workQueue.Run(ctx) - if err := m.ImexManager.Stop(); err != nil { - return fmt.Errorf("error stopping IMEX manager: %w", err) + if err := mneManager.Stop(); err != nil { + return fmt.Errorf("error stopping MultiNodeEnvironment manager: %w", err) } return nil diff --git a/cmd/nvidia-dra-controller/deployment.go b/cmd/nvidia-dra-controller/deployment.go new file mode 100644 index 000000000..18581a010 --- /dev/null +++ b/cmd/nvidia-dra-controller/deployment.go @@ -0,0 +1,345 @@ +/* + * Copyright (c) 2025 NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package main + +import ( + "bytes" + "context" + "fmt" + "sync" + "text/template" + + appsv1 "k8s.io/api/apps/v1" + "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/types" + "k8s.io/apimachinery/pkg/util/yaml" + "k8s.io/client-go/informers" + appsv1listers "k8s.io/client-go/listers/apps/v1" + "k8s.io/client-go/tools/cache" + "k8s.io/klog/v2" + + "github.com/google/uuid" + + nvapi "github.com/NVIDIA/k8s-dra-driver/api/nvidia.com/resource/gpu/v1alpha1" +) + +const ( + DeploymentTemplatePath = "/templates/imex-daemon.tmpl.yaml" +) + +type DeploymentTemplateData struct { + Namespace string + GenerateName string + AppLabel string + Finalizer string + MultiNodeEnvironmentLabelKey string + MultiNodeEnvironmentLabelValue types.UID + Replicas int + NvidiaDriverRoot string +} + +type DeploymentManager struct { + sync.Mutex + + config *ManagerConfig + waitGroup sync.WaitGroup + cancelContext context.CancelFunc + multiNodeEnvironmentExists MultiNodeEnvironmentExistsFunc + + factory informers.SharedInformerFactory + informer cache.SharedIndexInformer + lister appsv1listers.DeploymentLister + + imexChannelManager *ImexChannelManager + podManagers map[string]*DeploymentPodManager +} + +func NewDeploymentManager(config *ManagerConfig, mneExists MultiNodeEnvironmentExistsFunc) *DeploymentManager { + labelSelector := &metav1.LabelSelector{ + MatchExpressions: []metav1.LabelSelectorRequirement{ + { + Key: multiNodeEnvironmentLabelKey, + Operator: metav1.LabelSelectorOpExists, + }, + }, + } + + factory := informers.NewSharedInformerFactoryWithOptions( + config.clientsets.Core, + informerResyncPeriod, + informers.WithNamespace(config.driverNamespace), + informers.WithTweakListOptions(func(opts *metav1.ListOptions) { + opts.LabelSelector = metav1.FormatLabelSelector(labelSelector) + }), + ) + + informer := factory.Apps().V1().Deployments().Informer() + lister := factory.Apps().V1().Deployments().Lister() + + m := &DeploymentManager{ + config: config, + multiNodeEnvironmentExists: mneExists, + factory: factory, + informer: informer, + lister: lister, 
+ imexChannelManager: NewImexChannelManager(config), + podManagers: make(map[string]*DeploymentPodManager), + } + + return m +} + +func (m *DeploymentManager) Start(ctx context.Context) (rerr error) { + ctx, cancel := context.WithCancel(ctx) + m.cancelContext = cancel + + defer func() { + if rerr != nil { + if err := m.Stop(); err != nil { + klog.Errorf("error stopping Deployment manager: %v", err) + } + } + }() + + if err := addMultiNodeEnvironmentLabelIndexer[*appsv1.Deployment](m.informer); err != nil { + return fmt.Errorf("error adding indexer for MulitNodeEnvironment label: %w", err) + } + + _, err := m.informer.AddEventHandler(cache.ResourceEventHandlerFuncs{ + AddFunc: func(obj any) { + m.config.workQueue.Enqueue(obj, m.onAddOrUpdate) + }, + UpdateFunc: func(objOld, objNew any) { + m.config.workQueue.Enqueue(objNew, m.onAddOrUpdate) + }, + }) + if err != nil { + return fmt.Errorf("error adding event handlers for Deployment informer: %w", err) + } + + m.waitGroup.Add(1) + go func() { + defer m.waitGroup.Done() + m.factory.Start(ctx.Done()) + }() + + if !cache.WaitForCacheSync(ctx.Done(), m.informer.HasSynced) { + return fmt.Errorf("informer cache sync for Deployment failed") + } + + if err := m.imexChannelManager.Start(ctx); err != nil { + return fmt.Errorf("error starting IMEX channel manager: %w", err) + } + + return nil +} + +func (m *DeploymentManager) Stop() error { + if err := m.removeAllPodManagers(); err != nil { + return fmt.Errorf("error removing all Pod managers: %w", err) + } + if err := m.imexChannelManager.Stop(); err != nil { + return fmt.Errorf("error stopping IMEX channel manager: %w", err) + } + m.cancelContext() + m.waitGroup.Wait() + return nil +} + +func (m *DeploymentManager) Create(ctx context.Context, namespace string, replicas int, mne *nvapi.MultiNodeEnvironment) (*appsv1.Deployment, error) { + d, err := getByMultiNodeEnvironmentUID[*appsv1.Deployment](ctx, m.informer, string(mne.UID)) + if err != nil { + return nil, fmt.Errorf("error retrieving Deployment: %w", err) + } + if d != nil { + return d, nil + } + + templateData := DeploymentTemplateData{ + Namespace: m.config.driverNamespace, + GenerateName: fmt.Sprintf("%s-", mne.Name), + AppLabel: uuid.New().String(), + Finalizer: multiNodeEnvironmentFinalizer, + MultiNodeEnvironmentLabelKey: multiNodeEnvironmentLabelKey, + MultiNodeEnvironmentLabelValue: mne.UID, + Replicas: replicas, + NvidiaDriverRoot: "/", + } + + tmpl, err := template.ParseFiles(DeploymentTemplatePath) + if err != nil { + return nil, fmt.Errorf("failed to parse template file: %w", err) + } + + var deploymentYaml bytes.Buffer + if err := tmpl.Execute(&deploymentYaml, templateData); err != nil { + return nil, fmt.Errorf("failed to execute template: %w", err) + } + + var unstructuredObj unstructured.Unstructured + err = yaml.Unmarshal(deploymentYaml.Bytes(), &unstructuredObj) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal yaml: %w", err) + } + + var deployment appsv1.Deployment + err = runtime.DefaultUnstructuredConverter.FromUnstructured(unstructuredObj.UnstructuredContent(), &deployment) + if err != nil { + return nil, fmt.Errorf("failed to convert unstructured data to typed object: %w", err) + } + + d, err = m.config.clientsets.Core.AppsV1().Deployments(deployment.Namespace).Create(ctx, &deployment, metav1.CreateOptions{}) + if err != nil { + return nil, fmt.Errorf("error creating Deployment: %w", err) + } + + return d, nil +} + +func (m *DeploymentManager) Delete(ctx context.Context, mneUID string) error { + d, err := 
getByMultiNodeEnvironmentUID[*appsv1.Deployment](ctx, m.informer, mneUID) + if err != nil { + return fmt.Errorf("error retrieving Deployment: %w", err) + } + if d == nil { + return nil + } + + if err := m.RemoveFinalizer(ctx, d); err != nil { + return fmt.Errorf("error removing finalizer on Deployment: %w", err) + } + + err = m.config.clientsets.Core.AppsV1().Deployments(d.Namespace).Delete(ctx, d.Name, metav1.DeleteOptions{}) + if err != nil && !errors.IsNotFound(err) { + return fmt.Errorf("erroring deleting Deployment: %w", err) + } + + key := d.Spec.Selector.MatchLabels[multiNodeEnvironmentLabelKey] + if err := m.removePodManager(key); err != nil { + return fmt.Errorf("error removing Pod manager: %w", err) + } + + return nil +} + +func (m *DeploymentManager) RemoveFinalizer(ctx context.Context, d *appsv1.Deployment) error { + newD := d.DeepCopy() + + newD.Finalizers = []string{} + for _, f := range d.Finalizers { + if f != multiNodeEnvironmentFinalizer { + newD.Finalizers = append(newD.Finalizers, f) + } + } + + if _, err := m.config.clientsets.Core.AppsV1().Deployments(d.Namespace).Update(ctx, newD, metav1.UpdateOptions{}); err != nil { + return fmt.Errorf("error updating Deployment: %w", err) + } + + return nil +} + +func (m *DeploymentManager) onAddOrUpdate(ctx context.Context, obj any) error { + d, ok := obj.(*appsv1.Deployment) + if !ok { + return fmt.Errorf("failed to cast to Deployment") + } + + d, err := m.lister.Deployments(d.Namespace).Get(d.Name) + if err != nil && errors.IsNotFound(err) { + return nil + } + if err != nil { + return fmt.Errorf("erroring retreiving Deployment: %w", err) + } + + klog.Infof("Processing added or updated Deployment: %s/%s", d.Namespace, d.Name) + + exists, err := m.multiNodeEnvironmentExists(d.Labels[multiNodeEnvironmentLabelKey]) + if err != nil { + return fmt.Errorf("error checking if owner exists: %w", err) + } + if !exists { + if err := m.Delete(ctx, d.Labels[multiNodeEnvironmentLabelKey]); err != nil { + return fmt.Errorf("error deleting Deployment '%s/%s': %w", d.Namespace, d.Name, err) + } + return nil + } + + if err := m.addPodManager(ctx, d.Spec.Selector, int(*d.Spec.Replicas)); err != nil { + return fmt.Errorf("error adding Pod manager '%s/%s': %w", d.Namespace, d.Name, err) + } + + return nil +} + +func (m *DeploymentManager) addPodManager(ctx context.Context, labelSelector *metav1.LabelSelector, numPods int) error { + key := labelSelector.MatchLabels[multiNodeEnvironmentLabelKey] + + if _, exists := m.podManagers[key]; exists { + return nil + } + + podManager := NewDeploymentPodManager(m.config, m.imexChannelManager, labelSelector, numPods) + + if err := podManager.Start(ctx); err != nil { + return fmt.Errorf("error creating Pod manager: %w", err) + } + + m.Lock() + m.podManagers[key] = podManager + m.Unlock() + + return nil +} + +func (m *DeploymentManager) removePodManager(key string) error { + if _, exists := m.podManagers[key]; !exists { + return nil + } + + m.Lock() + podManager := m.podManagers[key] + m.Unlock() + + if err := podManager.Stop(); err != nil { + return fmt.Errorf("error stopping Pod manager: %w", err) + } + + m.Lock() + delete(m.podManagers, key) + m.Unlock() + + return nil +} + +func (m *DeploymentManager) removeAllPodManagers() error { + m.Lock() + for key, pm := range m.podManagers { + m.Unlock() + if err := pm.Stop(); err != nil { + return fmt.Errorf("error stopping Pod manager: %w", err) + } + m.Lock() + delete(m.podManagers, key) + } + m.Unlock() + return nil +} diff --git 
a/cmd/nvidia-dra-controller/deploymentpods.go b/cmd/nvidia-dra-controller/deploymentpods.go new file mode 100644 index 000000000..7a3726b13 --- /dev/null +++ b/cmd/nvidia-dra-controller/deploymentpods.go @@ -0,0 +1,172 @@ +/* + * Copyright (c) 2025 NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package main + +import ( + "context" + "fmt" + "slices" + "sync" + + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/informers" + corev1listers "k8s.io/client-go/listers/core/v1" + "k8s.io/client-go/tools/cache" + "k8s.io/klog/v2" +) + +type DeploymentPodManager struct { + config *ManagerConfig + waitGroup sync.WaitGroup + cancelContext context.CancelFunc + + factory informers.SharedInformerFactory + informer cache.SharedInformer + lister corev1listers.PodLister + + nodeSelector corev1.NodeSelector + multiNodeEnvironmentLabel string + numPods int + + imexChannelManager *ImexChannelManager +} + +func NewDeploymentPodManager(config *ManagerConfig, imexChannelManager *ImexChannelManager, labelSelector *metav1.LabelSelector, numPods int) *DeploymentPodManager { + factory := informers.NewSharedInformerFactoryWithOptions( + config.clientsets.Core, + informerResyncPeriod, + informers.WithNamespace(config.driverNamespace), + informers.WithTweakListOptions(func(opts *metav1.ListOptions) { + opts.LabelSelector = metav1.FormatLabelSelector(labelSelector) + }), + ) + + informer := factory.Core().V1().Pods().Informer() + lister := factory.Core().V1().Pods().Lister() + + nodeSelector := corev1.NodeSelector{ + NodeSelectorTerms: []corev1.NodeSelectorTerm{ + { + MatchExpressions: []corev1.NodeSelectorRequirement{ + { + Key: "kubernetes.io/hostname", + Operator: corev1.NodeSelectorOpIn, + Values: []string{}, + }, + }, + }, + }, + } + + m := &DeploymentPodManager{ + config: config, + factory: factory, + informer: informer, + lister: lister, + nodeSelector: nodeSelector, + multiNodeEnvironmentLabel: labelSelector.MatchLabels[multiNodeEnvironmentLabelKey], + numPods: numPods, + imexChannelManager: imexChannelManager, + } + + return m +} + +func (m *DeploymentPodManager) Start(ctx context.Context) (rerr error) { + ctx, cancel := context.WithCancel(ctx) + m.cancelContext = cancel + + defer func() { + if rerr != nil { + if err := m.Stop(); err != nil { + klog.Errorf("error stopping DeploymentPod manager: %v", err) + } + } + }() + + _, err := m.informer.AddEventHandler(cache.ResourceEventHandlerFuncs{ + AddFunc: func(obj interface{}) { + m.config.workQueue.Enqueue(obj, m.onPodAddOrUpdate) + }, + UpdateFunc: func(objOld, objNew any) { + m.config.workQueue.Enqueue(objNew, m.onPodAddOrUpdate) + }, + }) + if err != nil { + return fmt.Errorf("error adding event handlers for pod informer: %w", err) + } + + m.waitGroup.Add(1) + go func() { + defer m.waitGroup.Done() + m.factory.Start(ctx.Done()) + }() + + if !cache.WaitForCacheSync(ctx.Done(), m.informer.HasSynced) { + return 
fmt.Errorf("error syncing pod informer: %w", err) + } + + return nil +} + +func (m *DeploymentPodManager) Stop() error { + if err := m.imexChannelManager.DeletePool(m.multiNodeEnvironmentLabel); err != nil { + return fmt.Errorf("error deleting IMEX channel pool: %w", err) + } + m.cancelContext() + m.waitGroup.Wait() + return nil +} + +func (m *DeploymentPodManager) onPodAddOrUpdate(ctx context.Context, obj any) error { + p, ok := obj.(*corev1.Pod) + if !ok { + return fmt.Errorf("failed to cast to Pod") + } + + p, err := m.lister.Pods(p.Namespace).Get(p.Name) + if err != nil && errors.IsNotFound(err) { + return nil + } + if err != nil { + return fmt.Errorf("erroring retreiving Pod: %w", err) + } + + klog.Infof("Processing added or updated Pod: %s/%s", p.Namespace, p.Name) + + if p.Spec.NodeName == "" { + return fmt.Errorf("pod not yet scheduled: %s/%s", p.Namespace, p.Name) + } + + hostnameLabels := m.nodeSelector.NodeSelectorTerms[0].MatchExpressions[0].Values + if !slices.Contains(hostnameLabels, p.Spec.NodeName) { + hostnameLabels = append(hostnameLabels, p.Spec.NodeName) + } + m.nodeSelector.NodeSelectorTerms[0].MatchExpressions[0].Values = hostnameLabels + + if len(hostnameLabels) != m.numPods { + return fmt.Errorf("node selector not yet complete") + } + + if err := m.imexChannelManager.CreateOrUpdatePool(m.multiNodeEnvironmentLabel, &m.nodeSelector); err != nil { + return fmt.Errorf("failed to create or update IMEX channel pool: %w", err) + } + + return nil +} diff --git a/cmd/nvidia-dra-controller/deviceclass.go b/cmd/nvidia-dra-controller/deviceclass.go index 5166ee958..5c813487b 100644 --- a/cmd/nvidia-dra-controller/deviceclass.go +++ b/cmd/nvidia-dra-controller/deviceclass.go @@ -19,88 +19,140 @@ package main import ( "context" "fmt" - "time" + "sync" resourceapi "k8s.io/api/resource/v1beta1" "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/informers" resourcelisters "k8s.io/client-go/listers/resource/v1beta1" "k8s.io/client-go/tools/cache" "k8s.io/klog/v2" nvapi "github.com/NVIDIA/k8s-dra-driver/api/nvidia.com/resource/gpu/v1alpha1" - "github.com/NVIDIA/k8s-dra-driver/pkg/flags" ) type DeviceClassManager struct { - clientsets flags.ClientSets - ownerExists func(ctx context.Context, uid string) (bool, error) + config *ManagerConfig + waitGroup sync.WaitGroup + cancelContext context.CancelFunc + multiNodeEnvironmentExists MultiNodeEnvironmentExistsFunc + factory informers.SharedInformerFactory informer cache.SharedIndexInformer lister resourcelisters.DeviceClassLister } -func NewDeviceClassManager(ctx context.Context, config *ManagerConfig, ownerExists OwnerExistsFunc) (*DeviceClassManager, error) { - informer := config.coreInformerFactory.Resource().V1beta1().DeviceClasses().Informer() - lister := resourcelisters.NewDeviceClassLister(informer.GetIndexer()) +func NewDeviceClassManager(config *ManagerConfig, mneExists MultiNodeEnvironmentExistsFunc) *DeviceClassManager { + labelSelector := &metav1.LabelSelector{ + MatchExpressions: []metav1.LabelSelectorRequirement{ + { + Key: multiNodeEnvironmentLabelKey, + Operator: metav1.LabelSelectorOpExists, + }, + }, + } + + factory := informers.NewSharedInformerFactoryWithOptions( + config.clientsets.Core, + informerResyncPeriod, + informers.WithTweakListOptions(func(opts *metav1.ListOptions) { + opts.LabelSelector = metav1.FormatLabelSelector(labelSelector) + }), + ) + + informer := factory.Resource().V1beta1().DeviceClasses().Informer() + lister := 
factory.Resource().V1beta1().DeviceClasses().Lister() m := &DeviceClassManager{ - clientsets: config.clientsets, - ownerExists: ownerExists, - informer: informer, - lister: lister, + config: config, + multiNodeEnvironmentExists: mneExists, + factory: factory, + informer: informer, + lister: lister, } - _, err := informer.AddEventHandler(cache.ResourceEventHandlerFuncs{ + return m +} + +func (m *DeviceClassManager) Start(ctx context.Context) (rerr error) { + ctx, cancel := context.WithCancel(ctx) + m.cancelContext = cancel + + defer func() { + if rerr != nil { + if err := m.Stop(); err != nil { + klog.Errorf("error stopping DeviceClass manager: %v", err) + } + } + }() + + if err := addMultiNodeEnvironmentLabelIndexer[*resourceapi.DeviceClass](m.informer); err != nil { + return fmt.Errorf("error adding indexer for MulitNodeEnvironment label: %w", err) + } + + _, err := m.informer.AddEventHandler(cache.ResourceEventHandlerFuncs{ AddFunc: func(obj any) { - config.workQueue.Enqueue(obj, m.onAddOrUpdate) + m.config.workQueue.Enqueue(obj, m.onAddOrUpdate) }, UpdateFunc: func(objOld, objNew any) { - config.workQueue.Enqueue(objNew, m.onAddOrUpdate) + m.config.workQueue.Enqueue(objNew, m.onAddOrUpdate) }, }) if err != nil { - return nil, fmt.Errorf("error adding event handlers for DeviceClass informer: %w", err) + return fmt.Errorf("error adding event handlers for DeviceClass informer: %w", err) } - return m, nil -} - -func (m *DeviceClassManager) WaitForCacheSync(ctx context.Context) error { - ctx, cancel := context.WithTimeout(ctx, 5*time.Second) - defer cancel() + m.waitGroup.Add(1) + go func() { + defer m.waitGroup.Done() + m.factory.Start(ctx.Done()) + }() if !cache.WaitForCacheSync(ctx.Done(), m.informer.HasSynced) { - return fmt.Errorf("informer cache sync for DeviceClasses failed") + return fmt.Errorf("informer cache sync for DeviceClass failed") } return nil } -func (m *DeviceClassManager) Create(ctx context.Context, name string, ownerReference metav1.OwnerReference) (*resourceapi.DeviceClass, error) { - if name != "" { - dc, err := m.lister.Get(name) - if err == nil { - if len(dc.OwnerReferences) != 1 && dc.OwnerReferences[0] != ownerReference { - return nil, fmt.Errorf("DeviceClass '%s' exists without expected OwnerReference: %v", name, ownerReference) - } - return dc, nil - } - if !errors.IsNotFound(err) { - return nil, fmt.Errorf("error retrieving DeviceClass: %w", err) - } +func (m *DeviceClassManager) Stop() error { + m.cancelContext() + m.waitGroup.Wait() + return nil +} + +func (m *DeviceClassManager) Create(ctx context.Context, name string, mne *nvapi.MultiNodeEnvironment) (*resourceapi.DeviceClass, error) { + dc, err := getByMultiNodeEnvironmentUID[*resourceapi.DeviceClass](ctx, m.informer, string(mne.UID)) + if err != nil { + return nil, fmt.Errorf("error retrieving DeviceClass: %w", err) + } + if dc != nil { + return dc, nil } deviceClass := &resourceapi.DeviceClass{ ObjectMeta: metav1.ObjectMeta{ - OwnerReferences: []metav1.OwnerReference{ownerReference}, - Finalizers: []string{multiNodeEnvironmentFinalizer}, + Finalizers: []string{multiNodeEnvironmentFinalizer}, + Labels: map[string]string{ + multiNodeEnvironmentLabelKey: string(mne.UID), + }, }, Spec: resourceapi.DeviceClassSpec{ Selectors: []resourceapi.DeviceSelector{ { CEL: &resourceapi.CELDeviceSelector{ - Expression: "device.driver == 'gpu.nvidia.com' && device.attributes['gpu.nvidia.com'].type == 'imex-channel'", + Expression: "device.driver == 'gpu.nvidia.com'", + }, + }, + { + CEL: &resourceapi.CELDeviceSelector{ + 
Expression: "device.attributes['gpu.nvidia.com'].type == 'imex-channel'", + }, + }, + { + CEL: &resourceapi.CELDeviceSelector{ + Expression: fmt.Sprintf("device.attributes['gpu.nvidia.com'].domain == '%v'", mne.UID), }, }, }, @@ -108,12 +160,12 @@ func (m *DeviceClassManager) Create(ctx context.Context, name string, ownerRefer } if name == "" { - deviceClass.GenerateName = ownerReference.Name + deviceClass.GenerateName = mne.Name } else { deviceClass.Name = name } - dc, err := m.clientsets.Core.ResourceV1beta1().DeviceClasses().Create(ctx, deviceClass, metav1.CreateOptions{}) + dc, err = m.config.clientsets.Core.ResourceV1beta1().DeviceClasses().Create(ctx, deviceClass, metav1.CreateOptions{}) if err != nil { return nil, fmt.Errorf("error creating DeviceClass: %w", err) } @@ -121,11 +173,24 @@ func (m *DeviceClassManager) Create(ctx context.Context, name string, ownerRefer return dc, nil } -func (m *DeviceClassManager) Delete(ctx context.Context, name string) error { - err := m.clientsets.Core.ResourceV1beta1().DeviceClasses().Delete(ctx, name, metav1.DeleteOptions{}) +func (m *DeviceClassManager) Delete(ctx context.Context, mneUID string) error { + dc, err := getByMultiNodeEnvironmentUID[*resourceapi.DeviceClass](ctx, m.informer, mneUID) + if err != nil { + return fmt.Errorf("error retrieving DeviceClass: %w", err) + } + if dc == nil { + return nil + } + + if err := m.RemoveFinalizer(ctx, dc.Name); err != nil { + return fmt.Errorf("error removing finalizer on DeviceClass '%s': %w", dc.Name, err) + } + + err = m.config.clientsets.Core.ResourceV1beta1().DeviceClasses().Delete(ctx, dc.Name, metav1.DeleteOptions{}) if err != nil && !errors.IsNotFound(err) { return fmt.Errorf("erroring deleting DeviceClass: %w", err) } + return nil } @@ -147,7 +212,7 @@ func (m *DeviceClassManager) RemoveFinalizer(ctx context.Context, name string) e } } - _, err = m.clientsets.Core.ResourceV1beta1().DeviceClasses().Update(ctx, newDC, metav1.UpdateOptions{}) + _, err = m.config.clientsets.Core.ResourceV1beta1().DeviceClasses().Update(ctx, newDC, metav1.UpdateOptions{}) if err != nil { return fmt.Errorf("error updating DeviceClass: %w", err) } @@ -161,31 +226,26 @@ func (m *DeviceClassManager) onAddOrUpdate(ctx context.Context, obj any) error { return fmt.Errorf("failed to cast to DeviceClass") } - klog.Infof("Processing added or updated DeviceClass: %s", dc.Name) - - if len(dc.OwnerReferences) != 1 { + dc, err := m.lister.Get(dc.Name) + if err != nil && errors.IsNotFound(err) { return nil } - - if dc.OwnerReferences[0].Kind != nvapi.MultiNodeEnvironmentKind { - return nil + if err != nil { + return fmt.Errorf("erroring retreiving DeviceClass: %w", err) } - exists, err := m.ownerExists(ctx, string(dc.OwnerReferences[0].UID)) + klog.Infof("Processing added or updated DeviceClass: %s", dc.Name) + + exists, err := m.multiNodeEnvironmentExists(dc.Labels[multiNodeEnvironmentLabelKey]) if err != nil { return fmt.Errorf("error checking if owner exists: %w", err) } - if exists { + if !exists { + if err := m.Delete(ctx, dc.Labels[multiNodeEnvironmentLabelKey]); err != nil { + return fmt.Errorf("error deleting DeviceClass '%s': %w", dc.Name, err) + } return nil } - if err := m.RemoveFinalizer(ctx, dc.Name); err != nil { - return fmt.Errorf("error removing finalizer on DeviceClass '%s': %w", dc.Name, err) - } - - if err := m.Delete(ctx, dc.Name); err != nil { - return fmt.Errorf("error deleting DeviceClass '%s': %w", dc.Name, err) - } - return nil } diff --git a/cmd/nvidia-dra-controller/imex.go 
b/cmd/nvidia-dra-controller/imex.go deleted file mode 100644 index 52f70b370..000000000 --- a/cmd/nvidia-dra-controller/imex.go +++ /dev/null @@ -1,405 +0,0 @@ -/* - * Copyright (c) 2024 NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package main - -import ( - "context" - "errors" - "fmt" - "strings" - "sync" - "time" - - v1 "k8s.io/api/core/v1" - resourceapi "k8s.io/api/resource/v1beta1" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/apimachinery/pkg/labels" - "k8s.io/apimachinery/pkg/selection" - "k8s.io/client-go/informers" - "k8s.io/client-go/tools/cache" - "k8s.io/dynamic-resource-allocation/resourceslice" - "k8s.io/klog/v2" - "k8s.io/utils/ptr" - - "github.com/NVIDIA/k8s-dra-driver/pkg/flags" -) - -const ( - DriverName = "gpu.nvidia.com" - ImexDomainLabel = "nvidia.com/gpu.imex-domain" - ResourceSliceImexChannelLimit = 128 - DriverImexChannelLimit = 2048 - RetryTimeout = 1 * time.Minute -) - -// transientError defines an error indicating that it is transient. -type transientError struct{ error } - -// imexDomainOffsets represents the offset for assigning IMEX channels -// to ResourceSlices for each combination. -type imexDomainOffsets map[string]map[string]int - -type ImexManager struct { - driverName string - resourceSliceImexChannelLimit int - driverImexChannelLimit int - retryTimeout time.Duration - waitGroup sync.WaitGroup - clientsets flags.ClientSets - imexDomainOffsets imexDomainOffsets - driverResources *resourceslice.DriverResources -} - -func StartImexManager(ctx context.Context, config *Config) (*ImexManager, error) { - // Create a new set of DriverResources - driverResources := &resourceslice.DriverResources{ - Pools: make(map[string]resourceslice.Pool), - } - - // Create the manager itself - m := &ImexManager{ - driverName: DriverName, - resourceSliceImexChannelLimit: ResourceSliceImexChannelLimit, - driverImexChannelLimit: DriverImexChannelLimit, - retryTimeout: RetryTimeout, - clientsets: config.clientsets, - driverResources: driverResources, - imexDomainOffsets: make(imexDomainOffsets), - } - - // Add/Remove resource slices from IMEX domains as they come and go - err := m.manageResourceSlices(ctx) - if err != nil { - return nil, fmt.Errorf("error managing resource slices: %w", err) - } - - return m, nil -} - -// manageResourceSlices reacts to added and removed IMEX domains and triggers the creation / removal of resource slices accordingly. 
-func (m *ImexManager) manageResourceSlices(ctx context.Context) error { - klog.Info("Start streaming IMEX domains from nodes...") - addedDomainsCh, removedDomainsCh, err := m.streamImexDomains(ctx) - if err != nil { - return fmt.Errorf("error streaming IMEX domains: %w", err) - } - - options := resourceslice.Options{ - DriverName: m.driverName, - KubeClient: m.clientsets.Core, - Resources: m.driverResources, - } - - klog.Info("Start publishing IMEX channels to ResourceSlices...") - controller, err := resourceslice.StartController(ctx, options) - if err != nil { - return fmt.Errorf("error starting resource slice controller: %w", err) - } - - m.waitGroup.Add(1) - go func() { - defer m.waitGroup.Done() - for { - select { - case addedDomain := <-addedDomainsCh: - klog.Infof("Adding channels for new IMEX domain: %v", addedDomain) - if err := m.addImexDomain(addedDomain); err != nil { - klog.Errorf("Error adding channels for IMEX domain %s: %v", addedDomain, err) - if errors.As(err, &transientError{}) { - klog.Infof("Retrying adding channels for IMEX domain %s after %v", addedDomain, m.retryTimeout) - go func() { - time.Sleep(m.retryTimeout) - addedDomainsCh <- addedDomain - }() - } - } - controller.Update(m.driverResources) - case removedDomain := <-removedDomainsCh: - klog.Infof("Removing channels for removed IMEX domain: %v", removedDomain) - if err := m.removeImexDomain(removedDomain); err != nil { - klog.Errorf("Error removing channels for IMEX domain %s: %v", removedDomain, err) - if errors.As(err, &transientError{}) { - klog.Infof("Retrying removing channels for IMEX domain %s after %v", removedDomain, m.retryTimeout) - go func() { - time.Sleep(m.retryTimeout) - removedDomainsCh <- removedDomain - }() - } - } - controller.Update(m.driverResources) - case <-ctx.Done(): - return - } - } - }() - - return nil -} - -// Stop stops a running ImexManager. -func (m *ImexManager) Stop() error { - if m == nil { - return nil - } - - m.waitGroup.Wait() - klog.Info("Cleaning up all resourceSlices") - if err := m.cleanupResourceSlices(); err != nil { - return fmt.Errorf("error cleaning up resource slices: %w", err) - } - - return nil -} - -// addImexDomain adds an IMEX domain to be managed by the ImexManager. -func (m *ImexManager) addImexDomain(imexDomain string) error { - imexDomainID, cliqueID, err := splitImexDomain(imexDomain) - if err != nil { - return fmt.Errorf("error splitting IMEX domain '%s': %v", imexDomain, err) - } - offset, err := m.imexDomainOffsets.add(imexDomainID, cliqueID, m.resourceSliceImexChannelLimit, m.driverImexChannelLimit) - if err != nil { - return fmt.Errorf("error setting offset for IMEX channels: %w", err) - } - m.driverResources = m.driverResources.DeepCopy() - m.driverResources.Pools[imexDomain] = generateImexChannelPool(imexDomain, offset, m.resourceSliceImexChannelLimit) - return nil -} - -// removeImexDomain removes an IMEX domain from being managed by the ImexManager. -func (m *ImexManager) removeImexDomain(imexDomain string) error { - imexDomainID, cliqueID, err := splitImexDomain(imexDomain) - if err != nil { - return fmt.Errorf("error splitting IMEX domain '%s': %v", imexDomain, err) - } - m.imexDomainOffsets.remove(imexDomainID, cliqueID) - m.driverResources = m.driverResources.DeepCopy() - delete(m.driverResources.Pools, imexDomain) - return nil -} - -// streamImexDomains returns two channels that streams imexDomans that are added and removed from nodes over time. 
-func (m *ImexManager) streamImexDomains(ctx context.Context) (chan string, chan string, error) { - // Create channels to stream IMEX domain ids that are added / removed - addedDomainCh := make(chan string) - removedDomainCh := make(chan string) - - // Use a map to track how many nodes are part of a given IMEX domain - nodesPerImexDomain := make(map[string]int) - - // Build a label selector to get all nodes with ImexDomainLabel set - requirement, err := labels.NewRequirement(ImexDomainLabel, selection.Exists, nil) - if err != nil { - return nil, nil, fmt.Errorf("error building label selector requirement: %w", err) - } - labelSelector := labels.NewSelector().Add(*requirement).String() - - // Create a shared informer factory for nodes - informerFactory := informers.NewSharedInformerFactoryWithOptions( - m.clientsets.Core, - time.Minute*10, // Resync period - informers.WithTweakListOptions(func(options *metav1.ListOptions) { - options.LabelSelector = labelSelector - }), - ) - nodeInformer := informerFactory.Core().V1().Nodes().Informer() - - // Set up event handlers for node events - _, err = nodeInformer.AddEventHandler(cache.ResourceEventHandlerFuncs{ - AddFunc: func(obj interface{}) { - node := obj.(*v1.Node) // nolint:forcetypeassert - imexDomain := node.Labels[ImexDomainLabel] - if imexDomain != "" { - nodesPerImexDomain[imexDomain]++ - if nodesPerImexDomain[imexDomain] == 1 { - addedDomainCh <- imexDomain - } - } - }, - DeleteFunc: func(obj interface{}) { - node := obj.(*v1.Node) // nolint:forcetypeassert - imexDomain := node.Labels[ImexDomainLabel] - if imexDomain != "" { - nodesPerImexDomain[imexDomain]-- - if nodesPerImexDomain[imexDomain] == 0 { - removedDomainCh <- imexDomain - } - } - }, - UpdateFunc: func(oldObj, newObj interface{}) { - oldNode := oldObj.(*v1.Node) // nolint:forcetypeassert - newNode := newObj.(*v1.Node) // nolint:forcetypeassert - - oldImexDomain := oldNode.Labels[ImexDomainLabel] - newImexDomain := newNode.Labels[ImexDomainLabel] - - if oldImexDomain == newImexDomain { - return - } - if oldImexDomain != "" { - nodesPerImexDomain[oldImexDomain]-- - if nodesPerImexDomain[oldImexDomain] == 0 { - removedDomainCh <- oldImexDomain - } - } - if newImexDomain != "" { - nodesPerImexDomain[newImexDomain]++ - if nodesPerImexDomain[newImexDomain] == 1 { - addedDomainCh <- newImexDomain - } - } - }, - }) - if err != nil { - return nil, nil, fmt.Errorf("failed to create node informer: %w", err) - } - - // Start the informer and wait for it to sync - m.waitGroup.Add(1) - go func() { - defer m.waitGroup.Done() - informerFactory.Start(ctx.Done()) - }() - - // Wait for the informer caches to sync - if !cache.WaitForCacheSync(ctx.Done(), nodeInformer.HasSynced) { - return nil, nil, fmt.Errorf("failed to sync informer caches") - } - - return addedDomainCh, removedDomainCh, nil -} - -// cleanupResourceSlices removes all resource slices created by the IMEX manager. 
-func (m *ImexManager) cleanupResourceSlices() error { - // Delete all resource slices created by the IMEX manager - ops := metav1.ListOptions{ - FieldSelector: fmt.Sprintf("%s=%s", resourceapi.ResourceSliceSelectorDriver, DriverName), - } - l, err := m.clientsets.Core.ResourceV1beta1().ResourceSlices().List(context.Background(), ops) - if err != nil { - return fmt.Errorf("error listing resource slices: %w", err) - } - - for _, rs := range l.Items { - err := m.clientsets.Core.ResourceV1beta1().ResourceSlices().Delete(context.Background(), rs.Name, metav1.DeleteOptions{}) - if err != nil { - return fmt.Errorf("error deleting resource slice %s: %w", rs.Name, err) - } - } - - return nil -} - -// add sets the offset where an IMEX domain's channels should start counting from. -func (offsets imexDomainOffsets) add(imexDomainID string, cliqueID string, resourceSliceImexChannelLimit, driverImexChannelLimit int) (int, error) { - // Check if the IMEX domain is already in the map - if _, ok := offsets[imexDomainID]; !ok { - offsets[imexDomainID] = make(map[string]int) - } - - // Return early if the clique is already in the map - if offset, exists := offsets[imexDomainID][cliqueID]; exists { - return offset, nil - } - - // Track used offsets for the current imexDomain - usedOffsets := make(map[int]struct{}) - for _, v := range offsets[imexDomainID] { - usedOffsets[v] = struct{}{} - } - - // Look for the first unused offset, stepping by resourceSliceImexChannelLimit - var offset int - for offset = 0; offset < driverImexChannelLimit; offset += resourceSliceImexChannelLimit { - if _, exists := usedOffsets[offset]; !exists { - break - } - } - - // If we reach the limit, return an error - if offset == driverImexChannelLimit { - return -1, transientError{fmt.Errorf("channel limit reached")} - } - offsets[imexDomainID][cliqueID] = offset - - return offset, nil -} - -// remove removes the offset where an IMEX domain's channels should start counting from. -func (offsets imexDomainOffsets) remove(imexDomainID string, cliqueID string) { - delete(offsets[imexDomainID], cliqueID) - if len(offsets[imexDomainID]) == 0 { - delete(offsets, imexDomainID) - } -} - -// splitImexDomain splits an imexDomain into its IMEX domain ID and its clique ID. -func splitImexDomain(imexDomain string) (string, string, error) { - id := strings.SplitN(imexDomain, ".", 2) - if len(id) != 2 { - return "", "", fmt.Errorf("splitting by '.' not equal to exactly 2 elements") - } - return id[0], id[1], nil -} - -// generateImexChannelPool generates the contents of a ResourceSlice pool for a given IMEX domain. 
-func generateImexChannelPool(imexDomain string, startChannel int, numChannels int) resourceslice.Pool { - // Generate channels from startChannel to startChannel+numChannels - var devices []resourceapi.Device - for i := startChannel; i < (startChannel + numChannels); i++ { - d := resourceapi.Device{ - Name: fmt.Sprintf("imex-channel-%d", i), - Basic: &resourceapi.BasicDevice{ - Attributes: map[resourceapi.QualifiedName]resourceapi.DeviceAttribute{ - "type": { - StringValue: ptr.To("imex-channel"), - }, - "channel": { - IntValue: ptr.To(int64(i)), - }, - }, - }, - } - devices = append(devices, d) - } - - // Put them in a pool named after the IMEX domain with the IMEX domain label as a node selector - pool := resourceslice.Pool{ - NodeSelector: &v1.NodeSelector{ - NodeSelectorTerms: []v1.NodeSelectorTerm{ - { - MatchExpressions: []v1.NodeSelectorRequirement{ - { - Key: ImexDomainLabel, - Operator: v1.NodeSelectorOpIn, - Values: []string{ - imexDomain, - }, - }, - }, - }, - }, - }, - Slices: []resourceslice.Slice{ - { - Devices: devices, - }, - }, - } - - return pool -} diff --git a/cmd/nvidia-dra-controller/imexchannels.go b/cmd/nvidia-dra-controller/imexchannels.go new file mode 100644 index 000000000..bd8fc6b9b --- /dev/null +++ b/cmd/nvidia-dra-controller/imexchannels.go @@ -0,0 +1,152 @@ +/* + * Copyright (c) 2024 NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package main + +import ( + "context" + "fmt" + + v1 "k8s.io/api/core/v1" + resourceapi "k8s.io/api/resource/v1beta1" + "k8s.io/dynamic-resource-allocation/resourceslice" + "k8s.io/klog/v2" + "k8s.io/utils/ptr" +) + +const ( + ResourceSliceImexChannelLimit = 128 + DriverImexChannelLimit = 128 // 2048 +) + +type ImexChannelManager struct { + config *ManagerConfig + cancelContext context.CancelFunc + + resourceSliceImexChannelLimit int + driverImexChannelLimit int + driverResources *resourceslice.DriverResources + + controller *resourceslice.Controller +} + +func NewImexChannelManager(config *ManagerConfig) *ImexChannelManager { + driverResources := &resourceslice.DriverResources{ + Pools: make(map[string]resourceslice.Pool), + } + + m := &ImexChannelManager{ + config: config, + resourceSliceImexChannelLimit: ResourceSliceImexChannelLimit, + driverImexChannelLimit: DriverImexChannelLimit, + driverResources: driverResources, + controller: nil, // OK, because controller.Stop() checks for nil + } + + return m +} + +// Start starts an ImexChannelManager. 
+func (m *ImexChannelManager) Start(ctx context.Context) (rerr error) { + ctx, cancel := context.WithCancel(ctx) + m.cancelContext = cancel + + defer func() { + if rerr != nil { + if err := m.Stop(); err != nil { + klog.Errorf("error stopping ImexChannel manager: %v", err) + } + } + }() + + options := resourceslice.Options{ + DriverName: m.config.driverName, + KubeClient: m.config.clientsets.Core, + Resources: m.driverResources, + } + + controller, err := resourceslice.StartController(ctx, options) + if err != nil { + return fmt.Errorf("error starting resource slice controller: %w", err) + } + + m.controller = controller + + return nil +} + +// Stop stops an ImexChannelManager. +func (m *ImexChannelManager) Stop() error { + m.cancelContext() + m.controller.Stop() + return nil +} + +// CreateOrUpdatePool creates or updates a pool of IMEX channels for the given IMEX domain. +func (m *ImexChannelManager) CreateOrUpdatePool(imexDomainName string, nodeSelector *v1.NodeSelector) error { + var slices []resourceslice.Slice + for i := 0; i < m.driverImexChannelLimit; i += m.resourceSliceImexChannelLimit { + slice := m.generatePoolSlice(imexDomainName, i, m.resourceSliceImexChannelLimit) + slices = append(slices, slice) + } + + pool := resourceslice.Pool{ + NodeSelector: nodeSelector, + Slices: slices, + } + + m.driverResources.Pools[imexDomainName] = pool + m.controller.Update(m.driverResources) + + return nil +} + +// DeletePool deletes a pool of IMEX channels for the given IMEX domain. +func (m *ImexChannelManager) DeletePool(imexDomainName string) error { + delete(m.driverResources.Pools, imexDomainName) + m.controller.Update(m.driverResources) + return nil +} + +// generatePoolSlice generates the contents of a single ResourceSlice of IMEX channels in the given range. +func (m *ImexChannelManager) generatePoolSlice(imexDomainName string, startChannel, numChannels int) resourceslice.Slice { + var devices []resourceapi.Device + for i := startChannel; i < (startChannel + numChannels); i++ { + d := resourceapi.Device{ + Name: fmt.Sprintf("imex-channel-%d", i), + Basic: &resourceapi.BasicDevice{ + Attributes: map[resourceapi.QualifiedName]resourceapi.DeviceAttribute{ + "type": { + StringValue: ptr.To("imex-channel"), + }, + "domain": { + StringValue: ptr.To(imexDomainName), + }, + "channel": { + IntValue: ptr.To(int64(i)), + }, + }, + }, + } + devices = append(devices, d) + } + + slice := resourceslice.Slice{ + Devices: devices, + } + + return slice +} diff --git a/cmd/nvidia-dra-controller/indexers.go b/cmd/nvidia-dra-controller/indexers.go new file mode 100644 index 000000000..c883d14ab --- /dev/null +++ b/cmd/nvidia-dra-controller/indexers.go @@ -0,0 +1,73 @@ +/* + * Copyright (c) 2025 NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package main + +import ( + "context" + "fmt" + + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/tools/cache" +) + +func uidIndexer[T metav1.ObjectMetaAccessor](obj any) ([]string, error) { + d, ok := obj.(T) + if !ok { + return nil, fmt.Errorf("expected a %T but got %T", *new(T), obj) + } + return []string{string(d.GetObjectMeta().GetUID())}, nil +} + +func addMultiNodeEnvironmentLabelIndexer[T metav1.ObjectMetaAccessor](informer cache.SharedIndexInformer) error { + return informer.AddIndexers(cache.Indexers{ + "multiNodeEnvironmentLabel": func(obj any) ([]string, error) { + d, ok := obj.(T) + if !ok { + return nil, fmt.Errorf("expected a %T but got %T", *new(T), obj) + } + labels := d.GetObjectMeta().GetLabels() + if value, exists := labels[multiNodeEnvironmentLabelKey]; exists { + return []string{value}, nil + } + return nil, nil + }, + }) +} + +func getByMultiNodeEnvironmentUID[T1 *T2, T2 any](ctx context.Context, informer cache.SharedIndexInformer, mneUID string) (T1, error) { + if !cache.WaitForCacheSync(ctx.Done(), informer.HasSynced) { + return nil, fmt.Errorf("cache sync failed for Deployment") + } + + ds, err := informer.GetIndexer().ByIndex("multiNodeEnvironmentLabel", mneUID) + if err != nil { + return nil, fmt.Errorf("error getting %T via MultiNodeEnvironment label: %w", *new(T1), err) + } + if len(ds) > 1 { + return nil, fmt.Errorf("multiple %T object with same MultiNodeEnvironment label: %w", *new(T1), err) + } + if len(ds) == 0 { + return nil, nil + } + + d, ok := ds[0].(T1) + if !ok { + return nil, fmt.Errorf("failed to cast to %T", *new(T1)) + } + + return d, nil +} diff --git a/cmd/nvidia-dra-controller/main.go b/cmd/nvidia-dra-controller/main.go index d605a8963..ce4600538 100644 --- a/cmd/nvidia-dra-controller/main.go +++ b/cmd/nvidia-dra-controller/main.go @@ -43,6 +43,10 @@ import ( "github.com/NVIDIA/k8s-dra-driver/pkg/flags" ) +const ( + DriverName = "gpu.nvidia.com" +) + type Flags struct { kubeClientConfig flags.KubeClientConfig loggingConfig *flags.LoggingConfig @@ -58,6 +62,7 @@ type Flags struct { } type Config struct { + driverName string flags *Flags clientsets flags.ClientSets mux *http.ServeMux @@ -147,6 +152,7 @@ func newApp() *cli.App { mux: mux, flags: flags, clientsets: clientsets, + driverName: DriverName, } if flags.httpEndpoint != "" { @@ -159,21 +165,25 @@ func newApp() *cli.App { sigs := make(chan os.Signal, 1) signal.Notify(sigs, syscall.SIGTERM, syscall.SIGINT) - var controller *Controller + if !config.flags.deviceClasses.Has(ImexChannelType) { + klog.InfoS("Not configured for IMEX support, blocking indefinitely") + <-sigs + return nil + } + + errChan := make(chan error, 1) + controller := NewController(config) ctx, cancel := context.WithCancel(c.Context) - defer func() { - cancel() - if err := controller.Stop(); err != nil { - klog.Errorf("Error stopping controller: %v", err) - } + go func() { + errChan <- controller.Run(ctx) }() - controller, err = StartController(ctx, config) - if err != nil { - return fmt.Errorf("start controller: %w", err) - } - <-sigs + cancel() + + if err := <-errChan; err != nil { + return fmt.Errorf("run controller: %w", err) + } return nil }, diff --git a/cmd/nvidia-dra-controller/mnenv.go b/cmd/nvidia-dra-controller/mnenv.go index d5ae8f58e..33204fae5 100644 --- a/cmd/nvidia-dra-controller/mnenv.go +++ b/cmd/nvidia-dra-controller/mnenv.go @@ -19,115 +19,140 @@ package main import ( "context" "fmt" + "sync" "time" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/client-go/tools/cache" 
"k8s.io/klog/v2" - "k8s.io/utils/ptr" nvapi "github.com/NVIDIA/k8s-dra-driver/api/nvidia.com/resource/gpu/v1alpha1" - "github.com/NVIDIA/k8s-dra-driver/pkg/flags" + nvinformers "github.com/NVIDIA/k8s-dra-driver/pkg/nvidia.com/resource/informers/externalversions" nvlisters "github.com/NVIDIA/k8s-dra-driver/pkg/nvidia.com/resource/listers/gpu/v1alpha1" ) +type MultiNodeEnvironmentExistsFunc func(uid string) (bool, error) + const ( - multiNodeEnvironmentFinalizer = "gpu.nvidia.com/finalizer.multiNodeEnvironment" + informerResyncPeriod = 10 * time.Minute + + multiNodeEnvironmentLabelKey = "gpu.nvidia.com/multiNodeEnvironment" + multiNodeEnvironmentFinalizer = multiNodeEnvironmentLabelKey ) type MultiNodeEnvironmentManager struct { - clientsets flags.ClientSets + config *ManagerConfig + waitGroup sync.WaitGroup + cancelContext context.CancelFunc + factory nvinformers.SharedInformerFactory informer cache.SharedIndexInformer lister nvlisters.MultiNodeEnvironmentLister - resourceClaimManager *ResourceClaimManager + deploymentManager *DeploymentManager deviceClassManager *DeviceClassManager + resourceClaimManager *ResourceClaimManager } -// StartManager starts a MultiNodeEnvironmentManager. -func NewMultiNodeEnvironmentManager(ctx context.Context, config *ManagerConfig) (*MultiNodeEnvironmentManager, error) { - informer := config.nvInformerFactory.Gpu().V1alpha1().MultiNodeEnvironments().Informer() +// NewMultiNodeEnvironmentManager creates a new MultiNodeEnvironmentManager. +func NewMultiNodeEnvironmentManager(config *ManagerConfig) *MultiNodeEnvironmentManager { + factory := nvinformers.NewSharedInformerFactory(config.clientsets.Nvidia, informerResyncPeriod) + informer := factory.Gpu().V1alpha1().MultiNodeEnvironments().Informer() lister := nvlisters.NewMultiNodeEnvironmentLister(informer.GetIndexer()) m := &MultiNodeEnvironmentManager{ - clientsets: config.clientsets, - informer: informer, - lister: lister, + config: config, + factory: factory, + informer: informer, + lister: lister, } + m.deploymentManager = NewDeploymentManager(config, m.Exists) + m.deviceClassManager = NewDeviceClassManager(config, m.Exists) + m.resourceClaimManager = NewResourceClaimManager(config, m.Exists) + + return m +} - err := informer.AddIndexers(cache.Indexers{ - "uid": func(obj interface{}) ([]string, error) { - mne, ok := obj.(*nvapi.MultiNodeEnvironment) - if !ok { - return nil, fmt.Errorf("expected a MultiNodeEnvironment but got %T", obj) +// Start starts a MultiNodeEnvironmentManager. 
+func (m *MultiNodeEnvironmentManager) Start(ctx context.Context) (rerr error) { + ctx, cancel := context.WithCancel(ctx) + m.cancelContext = cancel + + defer func() { + if rerr != nil { + if err := m.Stop(); err != nil { + klog.Errorf("error stopping MultiNodeEnvironment manager: %v", err) } - return []string{string(mne.UID)}, nil - }, + } + }() + + err := m.informer.AddIndexers(cache.Indexers{ + "uid": uidIndexer[*nvapi.MultiNodeEnvironment], }) if err != nil { - return nil, fmt.Errorf("error adding indexer for MultiNodeEnvironment UUIDs: %w", err) + return fmt.Errorf("error adding indexer for UIDs: %w", err) } - _, err = informer.AddEventHandler(cache.ResourceEventHandlerFuncs{ + _, err = m.informer.AddEventHandler(cache.ResourceEventHandlerFuncs{ AddFunc: func(obj any) { - config.workQueue.Enqueue(obj, m.onMultiNodeEnvironmentAdd) + m.config.workQueue.Enqueue(obj, m.onMultiNodeEnvironmentAdd) + }, + DeleteFunc: func(obj any) { + m.config.workQueue.Enqueue(obj, m.onMultiNodeEnvironmentDelete) }, }) if err != nil { - return nil, fmt.Errorf("error adding event handlers for MultiNodeEnvironment informer: %w", err) + return fmt.Errorf("error adding event handlers for MultiNodeEnvironment informer: %w", err) } - m.resourceClaimManager, err = NewResourceClaimManager(ctx, config, m.Exists) - if err != nil { - return nil, fmt.Errorf("error creating ResourceClaim manager: %w", err) - } - - m.deviceClassManager, err = NewDeviceClassManager(ctx, config, m.Exists) - if err != nil { - return nil, fmt.Errorf("error creating DeviceClass manager: %w", err) - } - - return m, nil -} - -// WaitForCacheSync waits for the cache for MultiNodeEnvironments to sync. -func (m *MultiNodeEnvironmentManager) WaitForCacheSync(ctx context.Context) error { - ctx, cancel := context.WithTimeout(ctx, 5*time.Second) - defer cancel() + m.waitGroup.Add(1) + go func() { + defer m.waitGroup.Done() + m.factory.Start(ctx.Done()) + }() if !cache.WaitForCacheSync(ctx.Done(), m.informer.HasSynced) { return fmt.Errorf("informer cache sync for MultiNodeEnvironments failed") } - if err := m.resourceClaimManager.WaitForCacheSync(ctx); err != nil { - return fmt.Errorf("error syncing dependency cache: %w", err) + if err := m.deploymentManager.Start(ctx); err != nil { + return fmt.Errorf("error starting Deployment manager: %w", err) + } + + if err := m.deviceClassManager.Start(ctx); err != nil { + return fmt.Errorf("error creating DeviceClass manager: %w", err) } - if err := m.deviceClassManager.WaitForCacheSync(ctx); err != nil { - return fmt.Errorf("error syncing dependency cache: %w", err) + if err := m.resourceClaimManager.Start(ctx); err != nil { + return fmt.Errorf("error creating ResourceClaim manager: %w", err) } return nil } -// Exists checks if a MultiNodeEnvironment with a specific UID exists. 
-func (m *MultiNodeEnvironmentManager) Exists(ctx context.Context, uid string) (bool, error) { - ctx, cancel := context.WithTimeout(ctx, 5*time.Second) - defer cancel() - - if !cache.WaitForCacheSync(ctx.Done(), m.informer.HasSynced) { - return false, fmt.Errorf("cache sync failed for MultiNodeEnvironment") +func (m *MultiNodeEnvironmentManager) Stop() error { + if err := m.deploymentManager.Stop(); err != nil { + return fmt.Errorf("error stopping Deployment manager: %w", err) + } + if err := m.resourceClaimManager.Stop(); err != nil { + return fmt.Errorf("error stopping ResourceClaim manager: %w", err) + } + if err := m.deviceClassManager.Stop(); err != nil { + return fmt.Errorf("error stopping DeviceClass manager: %w", err) } + m.cancelContext() + m.waitGroup.Wait() + return nil +} +// Exists checks if a MultiNodeEnvironment with a specific UID exists. +func (m *MultiNodeEnvironmentManager) Exists(uid string) (bool, error) { mnes, err := m.informer.GetIndexer().ByIndex("uid", uid) if err != nil { - return false, fmt.Errorf("error retrieving MultiNodeInformer OwnerReference by UID from indexer: %w", err) + return false, fmt.Errorf("error retrieving MultiNodeInformer by UID: %w", err) } if len(mnes) == 0 { return false, nil } - return true, nil } @@ -139,28 +164,43 @@ func (m *MultiNodeEnvironmentManager) onMultiNodeEnvironmentAdd(ctx context.Cont klog.Infof("Processing added MultiNodeEnvironment: %s/%s", mne.Namespace, mne.Name) - gvk := nvapi.SchemeGroupVersion.WithKind("MultiNodeEnvironment") - mne.APIVersion = gvk.GroupVersion().String() - mne.Kind = gvk.Kind - - ownerReference := metav1.OwnerReference{ - APIVersion: mne.APIVersion, - Kind: mne.Kind, - Name: mne.Name, - UID: mne.UID, - Controller: ptr.To(true), + if _, err := m.deploymentManager.Create(ctx, m.config.driverNamespace, mne.Spec.NumNodes, mne); err != nil { + return fmt.Errorf("error creating Deployment: %w", err) } - dc, err := m.deviceClassManager.Create(ctx, mne.Spec.DeviceClassName, ownerReference) + dc, err := m.deviceClassManager.Create(ctx, mne.Spec.DeviceClassName, mne) if err != nil { - return fmt.Errorf("error creating DeviceClass '%s': %w", "", err) + return fmt.Errorf("error creating DeviceClass: %w", err) } if mne.Spec.ResourceClaimName != "" { - if _, err := m.resourceClaimManager.Create(ctx, mne.Namespace, mne.Spec.ResourceClaimName, dc.Name, ownerReference); err != nil { + if _, err := m.resourceClaimManager.Create(ctx, mne.Namespace, mne.Spec.ResourceClaimName, dc.Name, mne); err != nil { return fmt.Errorf("error creating ResourceClaim '%s/%s': %w", mne.Namespace, mne.Spec.ResourceClaimName, err) } } return nil } + +func (m *MultiNodeEnvironmentManager) onMultiNodeEnvironmentDelete(ctx context.Context, obj any) error { + mne, ok := obj.(*nvapi.MultiNodeEnvironment) + if !ok { + return fmt.Errorf("failed to cast to MultiNodeEnvironment") + } + + klog.Infof("Processing deleted MultiNodeEnvironment: %s/%s", mne.Namespace, mne.Name) + + if err := m.deploymentManager.Delete(ctx, string(mne.UID)); err != nil { + return fmt.Errorf("error deleting Deployment: %w", err) + } + + if err := m.deviceClassManager.Delete(ctx, string(mne.UID)); err != nil { + return fmt.Errorf("error deleting DeviceClass: %w", err) + } + + if err := m.resourceClaimManager.Delete(ctx, string(mne.UID)); err != nil { + return fmt.Errorf("error deleting ResourceClaim '%s/%s': %w", mne.Namespace, mne.Spec.ResourceClaimName, err) + } + + return nil +} diff --git a/cmd/nvidia-dra-controller/resourceclaim.go 
diff --git a/cmd/nvidia-dra-controller/resourceclaim.go b/cmd/nvidia-dra-controller/resourceclaim.go
index 05cf0d3fb..2fb884b2d 100644
--- a/cmd/nvidia-dra-controller/resourceclaim.go
+++ b/cmd/nvidia-dra-controller/resourceclaim.go
@@ -19,82 +19,126 @@ package main
 import (
     "context"
     "fmt"
-    "time"
+    "sync"
 
     resourceapi "k8s.io/api/resource/v1beta1"
     "k8s.io/apimachinery/pkg/api/errors"
     metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+    "k8s.io/client-go/informers"
     resourcelisters "k8s.io/client-go/listers/resource/v1beta1"
     "k8s.io/client-go/tools/cache"
     "k8s.io/klog/v2"
 
     nvapi "github.com/NVIDIA/k8s-dra-driver/api/nvidia.com/resource/gpu/v1alpha1"
-    "github.com/NVIDIA/k8s-dra-driver/pkg/flags"
 )
 
 type ResourceClaimManager struct {
-    clientsets  flags.ClientSets
-    ownerExists func(ctx context.Context, uid string) (bool, error)
+    config                     *ManagerConfig
+    waitGroup                  sync.WaitGroup
+    cancelContext              context.CancelFunc
+    multiNodeEnvironmentExists MultiNodeEnvironmentExistsFunc
 
+    factory  informers.SharedInformerFactory
     informer cache.SharedIndexInformer
     lister   resourcelisters.ResourceClaimLister
 }
 
-func NewResourceClaimManager(ctx context.Context, config *ManagerConfig, ownerExists OwnerExistsFunc) (*ResourceClaimManager, error) {
-    informer := config.coreInformerFactory.Resource().V1beta1().ResourceClaims().Informer()
-    lister := resourcelisters.NewResourceClaimLister(informer.GetIndexer())
+func NewResourceClaimManager(config *ManagerConfig, mneExists MultiNodeEnvironmentExistsFunc) *ResourceClaimManager {
+    labelSelector := &metav1.LabelSelector{
+        MatchExpressions: []metav1.LabelSelectorRequirement{
+            {
+                Key:      multiNodeEnvironmentLabelKey,
+                Operator: metav1.LabelSelectorOpExists,
+            },
+        },
+    }
+
+    factory := informers.NewSharedInformerFactoryWithOptions(
+        config.clientsets.Core,
+        informerResyncPeriod,
+        informers.WithTweakListOptions(func(opts *metav1.ListOptions) {
+            opts.LabelSelector = metav1.FormatLabelSelector(labelSelector)
+        }),
+    )
+
+    informer := factory.Resource().V1beta1().ResourceClaims().Informer()
+    lister := factory.Resource().V1beta1().ResourceClaims().Lister()
 
     m := &ResourceClaimManager{
-        clientsets:  config.clientsets,
-        ownerExists: ownerExists,
-        informer:    informer,
-        lister:      lister,
+        config:                     config,
+        multiNodeEnvironmentExists: mneExists,
+        factory:                    factory,
+        informer:                   informer,
+        lister:                     lister,
+    }
+
+    return m
+}
+
+func (m *ResourceClaimManager) Start(ctx context.Context) (rerr error) {
+    ctx, cancel := context.WithCancel(ctx)
+    m.cancelContext = cancel
+
+    defer func() {
+        if rerr != nil {
+            if err := m.Stop(); err != nil {
+                klog.Errorf("error stopping ResourceClaim manager: %v", err)
+            }
+        }
+    }()
+
+    if err := addMultiNodeEnvironmentLabelIndexer[*resourceapi.ResourceClaim](m.informer); err != nil {
+        return fmt.Errorf("error adding indexer for MultiNodeEnvironment label: %w", err)
     }
 
-    _, err := informer.AddEventHandler(cache.ResourceEventHandlerFuncs{
+    _, err := m.informer.AddEventHandler(cache.ResourceEventHandlerFuncs{
         AddFunc: func(obj any) {
-            config.workQueue.Enqueue(obj, m.onAddOrUpdate)
+            m.config.workQueue.Enqueue(obj, m.onAddOrUpdate)
         },
         UpdateFunc: func(objOld, objNew any) {
-            config.workQueue.Enqueue(objNew, m.onAddOrUpdate)
+            m.config.workQueue.Enqueue(objNew, m.onAddOrUpdate)
         },
     })
     if err != nil {
-        return nil, fmt.Errorf("error adding event handlers for ResourceClaim informer: %w", err)
+        return fmt.Errorf("error adding event handlers for ResourceClaim informer: %w", err)
     }
 
-    return m, nil
-}
-
-func (m *ResourceClaimManager) WaitForCacheSync(ctx context.Context) error {
-    ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
-    defer cancel()
+    m.waitGroup.Add(1)
+    go func() {
+        defer m.waitGroup.Done()
+        m.factory.Start(ctx.Done())
+    }()
 
     if !cache.WaitForCacheSync(ctx.Done(), m.informer.HasSynced) {
-
-        return fmt.Errorf("informer cache sync for ResourceClaims failed")
+        return fmt.Errorf("informer cache sync for ResourceClaim failed")
     }
+
     return nil
 }
 
-func (m *ResourceClaimManager) Create(ctx context.Context, namespace, name, deviceClassName string, ownerReference metav1.OwnerReference) (*resourceapi.ResourceClaim, error) {
-    rc, err := m.lister.ResourceClaims(namespace).Get(name)
-    if err == nil {
-        if len(rc.OwnerReferences) != 1 && rc.OwnerReferences[0] != ownerReference {
-            return nil, fmt.Errorf("ResourceClaim '%s/%s' exists without expected OwnerReference: %v", namespace, name, ownerReference)
-        }
-        return rc, nil
-    }
-    if !errors.IsNotFound(err) {
+func (m *ResourceClaimManager) Stop() error {
+    m.cancelContext()
+    m.waitGroup.Wait()
+    return nil
+}
+
+func (m *ResourceClaimManager) Create(ctx context.Context, namespace, name, deviceClassName string, mne *nvapi.MultiNodeEnvironment) (*resourceapi.ResourceClaim, error) {
+    rc, err := getByMultiNodeEnvironmentUID[*resourceapi.ResourceClaim](ctx, m.informer, string(mne.UID))
+    if err != nil {
         return nil, fmt.Errorf("error retrieving ResourceClaim: %w", err)
     }
+    if rc != nil {
+        return rc, nil
+    }
 
     resourceClaim := &resourceapi.ResourceClaim{
         ObjectMeta: metav1.ObjectMeta{
-            Name:            name,
-            Namespace:       namespace,
-            OwnerReferences: []metav1.OwnerReference{ownerReference},
-            Finalizers:      []string{multiNodeEnvironmentFinalizer},
+            Name:       name,
+            Namespace:  namespace,
+            Finalizers: []string{multiNodeEnvironmentFinalizer},
+            Labels: map[string]string{
+                multiNodeEnvironmentLabelKey: string(mne.UID),
+            },
         },
         Spec: resourceapi.ResourceClaimSpec{
             Devices: resourceapi.DeviceClaim{
@@ -105,7 +149,7 @@ func (m *ResourceClaimManager) Create(ctx context.Context, namespace, name, devi
         },
     }
 
-    rc, err = m.clientsets.Core.ResourceV1beta1().ResourceClaims(resourceClaim.Namespace).Create(ctx, resourceClaim, metav1.CreateOptions{})
+    rc, err = m.config.clientsets.Core.ResourceV1beta1().ResourceClaims(resourceClaim.Namespace).Create(ctx, resourceClaim, metav1.CreateOptions{})
     if err != nil {
         return nil, fmt.Errorf("error creating ResourceClaim: %w", err)
     }
@@ -113,11 +157,24 @@ func (m *ResourceClaimManager) Create(ctx context.Context, namespace, name, devi
     return rc, nil
 }
 
-func (m *ResourceClaimManager) Delete(ctx context.Context, namespace, name string) error {
-    err := m.clientsets.Core.ResourceV1beta1().ResourceClaims(namespace).Delete(ctx, name, metav1.DeleteOptions{})
+func (m *ResourceClaimManager) Delete(ctx context.Context, mneUID string) error {
+    rc, err := getByMultiNodeEnvironmentUID[*resourceapi.ResourceClaim](ctx, m.informer, mneUID)
+    if err != nil {
+        return fmt.Errorf("error retrieving ResourceClaim: %w", err)
+    }
+    if rc == nil {
+        return nil
+    }
+
+    if err := m.RemoveFinalizer(ctx, rc.Namespace, rc.Name); err != nil {
+        return fmt.Errorf("error removing finalizer on ResourceClaim '%s/%s': %w", rc.Namespace, rc.Name, err)
+    }
+
+    err = m.config.clientsets.Core.ResourceV1beta1().ResourceClaims(rc.Namespace).Delete(ctx, rc.Name, metav1.DeleteOptions{})
     if err != nil && !errors.IsNotFound(err) {
         return fmt.Errorf("erroring deleting ResourceClaim: %w", err)
     }
+
     return nil
 }
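Review note: `Create` and `Delete` now key everything off the owner's UID label via `getByMultiNodeEnvironmentUID` and the indexer added in `Start`, neither of which is shown in this diff. A plausible reconstruction under those assumptions (the index name is invented; the real helpers may differ):

```go
package main

import (
	"context"
	"fmt"

	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
	"k8s.io/client-go/tools/cache"
)

const multiNodeEnvironmentLabelIndex = "mne-uid" // assumed index name

// addMultiNodeEnvironmentLabelIndexer indexes objects by the value of the
// MultiNodeEnvironment UID label so children can be looked up by owner UID.
func addMultiNodeEnvironmentLabelIndexer[T metav1.Object](informer cache.SharedIndexInformer) error {
	return informer.AddIndexers(cache.Indexers{
		multiNodeEnvironmentLabelIndex: func(obj any) ([]string, error) {
			o, ok := obj.(T)
			if !ok {
				return nil, fmt.Errorf("expected %T, got %T", *new(T), obj)
			}
			return []string{o.GetLabels()[multiNodeEnvironmentLabelKey]}, nil
		},
	})
}

// getByMultiNodeEnvironmentUID returns the single object labeled with the
// given owner UID, a typed nil if none exists, and an error if the index
// lookup fails or matches more than one object.
func getByMultiNodeEnvironmentUID[T metav1.Object](ctx context.Context, informer cache.SharedIndexInformer, uid string) (T, error) {
	var none T
	objs, err := informer.GetIndexer().ByIndex(multiNodeEnvironmentLabelIndex, uid)
	if err != nil {
		return none, fmt.Errorf("error looking up objects by MultiNodeEnvironment UID: %w", err)
	}
	if len(objs) == 0 {
		return none, nil
	}
	if len(objs) > 1 {
		return none, fmt.Errorf("multiple objects found for MultiNodeEnvironment UID %s", uid)
	}
	o, ok := objs[0].(T)
	if !ok {
		return none, fmt.Errorf("expected %T, got %T", none, objs[0])
	}
	return o, nil
}
```

With `T = *resourceapi.ResourceClaim`, the typed-nil return is what makes the `if rc != nil { return rc, nil }` check in `Create` work.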
@@ -139,7 +196,7 @@ func (m *ResourceClaimManager) RemoveFinalizer(ctx context.Context, namespace, n
         }
     }
 
-    _, err = m.clientsets.Core.ResourceV1beta1().ResourceClaims(namespace).Update(ctx, newRC, metav1.UpdateOptions{})
+    _, err = m.config.clientsets.Core.ResourceV1beta1().ResourceClaims(namespace).Update(ctx, newRC, metav1.UpdateOptions{})
     if err != nil {
         return fmt.Errorf("error updating ResourceClaim: %w", err)
     }
@@ -153,31 +210,26 @@ func (m *ResourceClaimManager) onAddOrUpdate(ctx context.Context, obj any) error
         return fmt.Errorf("failed to cast to ResourceClaim")
     }
 
-    klog.Infof("Processing added or updated ResourceClaim: %s/%s", rc.Namespace, rc.Name)
-
-    if len(rc.OwnerReferences) != 1 {
+    rc, err := m.lister.ResourceClaims(rc.Namespace).Get(rc.Name)
+    if err != nil && errors.IsNotFound(err) {
         return nil
     }
-
-    if rc.OwnerReferences[0].Kind != nvapi.MultiNodeEnvironmentKind {
-        return nil
+    if err != nil {
+        return fmt.Errorf("error retrieving ResourceClaim: %w", err)
     }
 
-    exists, err := m.ownerExists(ctx, string(rc.OwnerReferences[0].UID))
+    klog.Infof("Processing added or updated ResourceClaim: %s/%s", rc.Namespace, rc.Name)
+
+    exists, err := m.multiNodeEnvironmentExists(rc.Labels[multiNodeEnvironmentLabelKey])
     if err != nil {
         return fmt.Errorf("error checking if owner exists: %w", err)
     }
-    if exists {
+    if !exists {
+        if err := m.Delete(ctx, rc.Labels[multiNodeEnvironmentLabelKey]); err != nil {
+            return fmt.Errorf("error deleting ResourceClaim '%s/%s': %w", rc.Namespace, rc.Name, err)
+        }
        return nil
     }
 
-    if err := m.RemoveFinalizer(ctx, rc.Namespace, rc.Name); err != nil {
-        return fmt.Errorf("error removing finalizer on ResourceClaim '%s/%s': %w", rc.Namespace, rc.Name, err)
-    }
-
-    if err := m.Delete(ctx, rc.Namespace, rc.Name); err != nil {
-        return fmt.Errorf("error deleting ResourceClaim '%s/%s': %w", rc.Namespace, rc.Name, err)
-    }
-
     return nil
 }
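Review note: only the `Update` call inside `RemoveFinalizer` is touched above, so the rest of that function is not visible in this diff. For orientation, a sketch of its likely body, consistent with the `newRC` variable and the closing braces in the hunk context (the lister lookup and the filter loop are assumptions):

```go
// Hypothetical shape of RemoveFinalizer: deep-copy the claim, drop the
// controller's finalizer, and write the result back so deletion can proceed.
func (m *ResourceClaimManager) RemoveFinalizer(ctx context.Context, namespace, name string) error {
	rc, err := m.lister.ResourceClaims(namespace).Get(name)
	if err != nil {
		return fmt.Errorf("error retrieving ResourceClaim: %w", err)
	}

	newRC := rc.DeepCopy()
	newRC.Finalizers = nil
	for _, f := range rc.Finalizers {
		if f != multiNodeEnvironmentFinalizer {
			newRC.Finalizers = append(newRC.Finalizers, f)
		}
	}

	_, err = m.config.clientsets.Core.ResourceV1beta1().ResourceClaims(namespace).Update(ctx, newRC, metav1.UpdateOptions{})
	if err != nil {
		return fmt.Errorf("error updating ResourceClaim: %w", err)
	}
	return nil
}
```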
diff --git a/deployments/helm/k8s-dra-driver/templates/clusterrole.yaml b/deployments/helm/k8s-dra-driver/templates/clusterrole.yaml
index a223c6fea..37bd88977 100644
--- a/deployments/helm/k8s-dra-driver/templates/clusterrole.yaml
+++ b/deployments/helm/k8s-dra-driver/templates/clusterrole.yaml
@@ -7,19 +7,25 @@ metadata:
 rules:
 - apiGroups: ["gpu.nvidia.com"]
   resources: ["multinodeenvironments"]
-  verbs: ["get", "list", "watch"]
+  verbs: ["get", "list", "watch", "create", "update", "patch", "delete"]
 - apiGroups: ["resource.k8s.io"]
   resources: ["resourceclaims"]
-  verbs: ["get", "list", "watch", "create", "update", "delete"]
+  verbs: ["get", "list", "watch", "create", "update", "patch", "delete"]
 - apiGroups: ["resource.k8s.io"]
   resources: ["deviceclasses"]
-  verbs: ["get", "list", "watch", "create", "update", "delete"]
+  verbs: ["get", "list", "watch", "create", "update", "patch", "delete"]
+- apiGroups: ["resource.k8s.io"]
+  resources: ["resourceslices"]
+  verbs: ["get", "list", "watch", "create", "update", "patch", "delete"]
+- apiGroups: ["apps"]
+  resources: ["deployments"]
+  verbs: ["get", "list", "watch", "create", "update", "patch", "delete"]
 - apiGroups: ["resource.k8s.io"]
   resources: ["resourceclaims/status"]
   verbs: ["update"]
+- apiGroups: [""]
+  resources: ["pods"]
+  verbs: ["get", "list", "watch"]
 - apiGroups: [""]
   resources: ["nodes"]
   verbs: ["get", "list", "watch"]
-- apiGroups: ["resource.k8s.io"]
-  resources: ["resourceslices"]
-  verbs: ["get", "list", "watch", "create", "update", "patch", "delete"]
diff --git a/go.mod b/go.mod
index 9197d6889..f0ca7031a 100644
--- a/go.mod
+++ b/go.mod
@@ -7,6 +7,7 @@ require (
     github.com/NVIDIA/go-nvlib v0.7.0
     github.com/NVIDIA/go-nvml v0.12.4-0
     github.com/NVIDIA/nvidia-container-toolkit v1.16.2
+    github.com/google/uuid v1.6.0
     github.com/prometheus/client_golang v1.19.1
     github.com/sirupsen/logrus v1.9.3
     github.com/spf13/pflag v1.0.5
@@ -46,7 +47,6 @@ require (
     github.com/google/gnostic-models v0.6.8 // indirect
     github.com/google/go-cmp v0.6.0 // indirect
     github.com/google/gofuzz v1.2.0 // indirect
-    github.com/google/uuid v1.6.0 // indirect
     github.com/inconshreveable/mousetrap v1.1.0 // indirect
     github.com/josharian/intern v1.0.0 // indirect
     github.com/json-iterator/go v1.1.12 // indirect
diff --git a/pkg/workqueue/workqueue.go b/pkg/workqueue/workqueue.go
index 3488ba78f..82be6f0c5 100644
--- a/pkg/workqueue/workqueue.go
+++ b/pkg/workqueue/workqueue.go
@@ -19,7 +19,6 @@ package workqueue
 import (
     "context"
     "fmt"
-    "time"
 
     "k8s.io/apimachinery/pkg/runtime"
     "k8s.io/client-go/util/workqueue"
@@ -54,9 +53,7 @@ func (q *WorkQueue) Run(ctx context.Context) {
         case <-ctx.Done():
             return
         default:
-            ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
             q.processNextWorkItem(ctx)
-            cancel()
         }
     }
 }
@@ -91,7 +88,7 @@ func (q *WorkQueue) processNextWorkItem(ctx context.Context) {
 
     err := q.reconcile(ctx, workItem)
     if err != nil {
-        klog.Errorf("Failed to reconcile work item %v: %v", workItem.Object, err)
+        klog.Errorf("Failed to reconcile work item: %v", err)
         q.queue.AddRateLimited(workItem)
     } else {
         q.queue.Forget(workItem)
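Review note: with the blanket 5-second deadline removed from `Run`, reconcile callbacks now inherit the controller's long-lived context; any handler that previously relied on the implicit timeout must now bound its own API calls. If that behavior is wanted per-handler, a small wrapper suffices (the name and the 30-second duration are illustrative, not part of this PR):

```go
package workqueue

import (
	"context"
	"time"
)

// withReconcileTimeout restores a bounded deadline for a single reconcile
// callback now that the queue loop no longer imposes one globally.
func withReconcileTimeout(
	reconcile func(ctx context.Context, obj any) error,
	timeout time.Duration,
) func(ctx context.Context, obj any) error {
	return func(ctx context.Context, obj any) error {
		ctx, cancel := context.WithTimeout(ctx, timeout)
		defer cancel()
		return reconcile(ctx, obj)
	}
}
```

Dropping the global timeout seems right here: the new `Start` paths block on `cache.WaitForCacheSync` inside enqueued work, which could easily exceed a fixed 5-second budget.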