Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
jmdeal committed Nov 22, 2024
1 parent 84a9ba1 commit ad7c521
Show file tree
Hide file tree
Showing 35 changed files with 93 additions and 67 deletions.
4 changes: 2 additions & 2 deletions pkg/apis/v1/nodepool_default_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ var _ = Describe("CEL/Default", func() {
Spec: NodeClaimTemplateSpec{
NodeClassRef: &NodeClassReference{
Group: "karpenter.test.sh",
Kind: "TestNodeClaim",
Name: "default",
Kind: "TestNodeClaim",
Name: "default",
},
Requirements: []NodeSelectorRequirementWithMinValues{
{
Expand Down
2 changes: 1 addition & 1 deletion pkg/controllers/metrics/nodepool/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,6 @@ func makeLabels(nodePool *v1.NodePool, resourceTypeName string) prometheus.Label
func (c *Controller) Register(_ context.Context, m manager.Manager) error {
return controllerruntime.NewControllerManagedBy(m).
Named("metrics.nodepool").
For(&v1.NodePool{}, builder.WithPredicates(nodepoolutils.IsMangedPredicates(c.cloudProvider))).
For(&v1.NodePool{}, builder.WithPredicates(nodepoolutils.IsManagedPredicates(c.cloudProvider))).
Complete(c)
}
2 changes: 1 addition & 1 deletion pkg/controllers/metrics/nodepool/suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ var _ = Describe("Metrics", func() {
corev1.ResourceEphemeralStorage: resource.MustParse("100Gi"),
}
nodePool.Spec.Limits = limits
if isNodePoolManaged {
if !isNodePoolManaged {
nodePool.Spec.Template.Spec.NodeClassRef = &v1.NodeClassReference{
Group: "karpenter.k8s.aws",
Kind: "EC2NodeClass",
Expand Down
2 changes: 1 addition & 1 deletion pkg/controllers/node/health/suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ var _ = BeforeSuite(func() {
env = test.NewEnvironment(
test.WithCRDs(apis.CRDs...),
test.WithCRDs(v1alpha1.CRDs...),
test.WithFieldIndexers(test.NodeClaimFieldIndexer(ctx), test.VolumeAttachmentFieldIndexer(ctx), test.NodeFieldIndexer(ctx)),
test.WithFieldIndexers(test.NodeClaimProviderIDFieldIndexer(ctx), test.VolumeAttachmentFieldIndexer(ctx), test.NodeProviderIDFieldIndexer(ctx)),
)
cloudProvider = fake.NewCloudProvider()
cloudProvider = fake.NewCloudProvider()
Expand Down
2 changes: 1 addition & 1 deletion pkg/controllers/node/termination/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ func (c *Controller) nodeTerminationTime(node *corev1.Node, nodeClaims ...*v1.No
func (c *Controller) Register(_ context.Context, m manager.Manager) error {
return controllerruntime.NewControllerManagedBy(m).
Named("node.termination").
For(&corev1.Node{}, builder.WithPredicates(nodeutils.IsMangedPredicates(c.cloudProvider))).
For(&corev1.Node{}, builder.WithPredicates(nodeutils.IsManagedPredicates(c.cloudProvider))).
WithOptions(
controller.Options{
RateLimiter: workqueue.NewTypedMaxOfRateLimiter[reconcile.Request](
Expand Down
2 changes: 1 addition & 1 deletion pkg/controllers/node/termination/suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ var _ = BeforeSuite(func() {
env = test.NewEnvironment(
test.WithCRDs(apis.CRDs...),
test.WithCRDs(v1alpha1.CRDs...),
test.WithFieldIndexers(test.NodeClaimFieldIndexer(ctx), test.VolumeAttachmentFieldIndexer(ctx)),
test.WithFieldIndexers(test.NodeClaimProviderIDFieldIndexer(ctx), test.VolumeAttachmentFieldIndexer(ctx)),
)

cloudProvider = fake.NewCloudProvider()
Expand Down
2 changes: 1 addition & 1 deletion pkg/controllers/nodeclaim/consistency/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ func (c *Controller) checkConsistency(ctx context.Context, nodeClaim *v1.NodeCla
func (c *Controller) Register(_ context.Context, m manager.Manager) error {
return controllerruntime.NewControllerManagedBy(m).
Named("nodeclaim.consistency").
For(&v1.NodeClaim{}, builder.WithPredicates(nodeclaimutil.IsMangedPredicates(c.cloudProvider))).
For(&v1.NodeClaim{}, builder.WithPredicates(nodeclaimutil.IsManagedPredicates(c.cloudProvider))).
Watches(
&corev1.Node{},
nodeclaimutil.NodeEventHandler(c.kubeClient, nodeclaimutil.WithManagedFilter(c.cloudProvider)),
Expand Down
6 changes: 3 additions & 3 deletions pkg/controllers/nodeclaim/consistency/suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ var _ = BeforeSuite(func() {
env = test.NewEnvironment(
test.WithCRDs(apis.CRDs...),
test.WithCRDs(v1alpha1.CRDs...),
test.WithFieldIndexers(test.NodeClaimFieldIndexer(ctx), test.NodeFieldIndexer(ctx)),
test.WithFieldIndexers(test.NodeClaimProviderIDFieldIndexer(ctx), test.NodeProviderIDFieldIndexer(ctx)),
)
ctx = options.ToContext(ctx, test.Options())
cp = &fake.CloudProvider{}
Expand Down Expand Up @@ -179,8 +179,8 @@ var _ = Describe("NodeClaimController", func() {
Spec: v1.NodeClaimSpec{
NodeClassRef: &v1.NodeClassReference{
Group: "karpenter.k8s.aws",
Kind: "EC2NodeClass",
Name: "default",
Kind: "EC2NodeClass",
Name: "default",
},
},
})
Expand Down
2 changes: 1 addition & 1 deletion pkg/controllers/nodeclaim/disruption/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ func (c *Controller) Reconcile(ctx context.Context, nodeClaim *v1.NodeClaim) (re
func (c *Controller) Register(_ context.Context, m manager.Manager) error {
b := controllerruntime.NewControllerManagedBy(m).
Named("nodeclaim.disruption").
For(&v1.NodeClaim{}, builder.WithPredicates(nodeclaimutil.IsMangedPredicates(c.cloudProvider))).
For(&v1.NodeClaim{}, builder.WithPredicates(nodeclaimutil.IsManagedPredicates(c.cloudProvider))).
WithOptions(controller.Options{MaxConcurrentReconciles: 10}).
// Note: We don't use the ManagedFilter (NodeClaim) for NodePool updates because drift should be captured when
// updating a NodePool's NodeClassRef to an unsupported NodeClass. However, this is currently unsupported
Expand Down
2 changes: 1 addition & 1 deletion pkg/controllers/nodeclaim/disruption/suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ func TestAPIs(t *testing.T) {

var _ = BeforeSuite(func() {
fakeClock = clock.NewFakeClock(time.Now())
env = test.NewEnvironment(test.WithCRDs(apis.CRDs...), test.WithCRDs(v1alpha1.CRDs...), test.WithFieldIndexers(test.NodeFieldIndexer(ctx)))
env = test.NewEnvironment(test.WithCRDs(apis.CRDs...), test.WithCRDs(v1alpha1.CRDs...), test.WithFieldIndexers(test.NodeProviderIDFieldIndexer(ctx)))
ctx = options.ToContext(ctx, test.Options())
cp = fake.NewCloudProvider()
nodeClaimDisruptionController = nodeclaimdisruption.NewController(fakeClock, env.Client, cp)
Expand Down
2 changes: 1 addition & 1 deletion pkg/controllers/nodeclaim/expiration/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,6 @@ func (c *Controller) Reconcile(ctx context.Context, nodeClaim *v1.NodeClaim) (re
func (c *Controller) Register(_ context.Context, m manager.Manager) error {
return controllerruntime.NewControllerManagedBy(m).
Named("nodeclaim.expiration").
For(&v1.NodeClaim{}, builder.WithPredicates(nodeclaimutils.IsMangedPredicates(c.cloudProvider))).
For(&v1.NodeClaim{}, builder.WithPredicates(nodeclaimutils.IsManagedPredicates(c.cloudProvider))).
Complete(reconcile.AsReconciler(m.GetClient(), c))
}
2 changes: 1 addition & 1 deletion pkg/controllers/nodeclaim/expiration/suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func TestAPIs(t *testing.T) {

var _ = BeforeSuite(func() {
fakeClock = clock.NewFakeClock(time.Now())
env = test.NewEnvironment(test.WithCRDs(apis.CRDs...), test.WithCRDs(v1alpha1.CRDs...), test.WithFieldIndexers(test.NodeFieldIndexer(ctx)))
env = test.NewEnvironment(test.WithCRDs(apis.CRDs...), test.WithCRDs(v1alpha1.CRDs...), test.WithFieldIndexers(test.NodeProviderIDFieldIndexer(ctx)))
ctx = options.ToContext(ctx, test.Options())
cp = fake.NewCloudProvider()
expirationController = expiration.NewController(fakeClock, env.Client, cp)
Expand Down
2 changes: 1 addition & 1 deletion pkg/controllers/nodeclaim/garbagecollection/suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func TestAPIs(t *testing.T) {

var _ = BeforeSuite(func() {
fakeClock = clock.NewFakeClock(time.Now())
env = test.NewEnvironment(test.WithCRDs(apis.CRDs...), test.WithCRDs(v1alpha1.CRDs...), test.WithFieldIndexers(test.NodeFieldIndexer(ctx)))
env = test.NewEnvironment(test.WithCRDs(apis.CRDs...), test.WithCRDs(v1alpha1.CRDs...), test.WithFieldIndexers(test.NodeProviderIDFieldIndexer(ctx)))
ctx = options.ToContext(ctx, test.Options())
cloudProvider = fake.NewCloudProvider()
garbageCollectionController = nodeclaimgarbagecollection.NewController(fakeClock, env.Client, cloudProvider)
Expand Down
2 changes: 1 addition & 1 deletion pkg/controllers/nodeclaim/lifecycle/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ func NewController(clk clock.Clock, kubeClient client.Client, cloudProvider clou
func (c *Controller) Register(_ context.Context, m manager.Manager) error {
return controllerruntime.NewControllerManagedBy(m).
Named(c.Name()).
For(&v1.NodeClaim{}, builder.WithPredicates(nodeclaimutils.IsMangedPredicates(c.cloudProvider))).
For(&v1.NodeClaim{}, builder.WithPredicates(nodeclaimutils.IsManagedPredicates(c.cloudProvider))).
Watches(
&corev1.Node{},
nodeclaimutils.NodeEventHandler(c.kubeClient, nodeclaimutils.WithManagedFilter(c.cloudProvider)),
Expand Down
4 changes: 2 additions & 2 deletions pkg/controllers/nodeclaim/lifecycle/initialization_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ var _ = Describe("Initialization", func() {
Spec: v1.NodeClaimSpec{
NodeClassRef: &v1.NodeClassReference{
Group: "karpenter.k8s.aws",
Kind: "EC2NodeClass",
Name: "default",
Kind: "EC2NodeClass",
Name: "default",
},
},
})
Expand Down
4 changes: 2 additions & 2 deletions pkg/controllers/nodeclaim/lifecycle/liveness_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ var _ = Describe("Liveness", func() {
Spec: v1.NodeClaimSpec{
NodeClassRef: &v1.NodeClassReference{
Group: "karpenter.k8s.aws",
Kind: "EC2NodeClass",
Name: "default",
Kind: "EC2NodeClass",
Name: "default",
},
},
})
Expand Down
2 changes: 1 addition & 1 deletion pkg/controllers/nodeclaim/lifecycle/suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func TestAPIs(t *testing.T) {

var _ = BeforeSuite(func() {
fakeClock = clock.NewFakeClock(time.Now())
env = test.NewEnvironment(test.WithCRDs(apis.CRDs...), test.WithCRDs(v1alpha1.CRDs...), test.WithFieldIndexers(test.NodeFieldIndexer(ctx)))
env = test.NewEnvironment(test.WithCRDs(apis.CRDs...), test.WithCRDs(v1alpha1.CRDs...), test.WithFieldIndexers(test.NodeProviderIDFieldIndexer(ctx)))
ctx = options.ToContext(ctx, test.Options())

cloudProvider = fake.NewCloudProvider()
Expand Down
2 changes: 1 addition & 1 deletion pkg/controllers/nodeclaim/podevents/suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ var _ = BeforeSuite(func() {
env = test.NewEnvironment(
test.WithCRDs(apis.CRDs...),
test.WithCRDs(v1alpha1.CRDs...),
test.WithFieldIndexers(test.NodeClaimFieldIndexer(ctx), test.NodeFieldIndexer(ctx)),
test.WithFieldIndexers(test.NodeClaimProviderIDFieldIndexer(ctx), test.NodeProviderIDFieldIndexer(ctx)),
)
ctx = options.ToContext(ctx, test.Options())
cp = fake.NewCloudProvider()
Expand Down
4 changes: 2 additions & 2 deletions pkg/controllers/nodepool/counter/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"k8s.io/apimachinery/pkg/api/equality"
"k8s.io/apimachinery/pkg/api/resource"
controllerruntime "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/builder"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/controller"
"sigs.k8s.io/controller-runtime/pkg/manager"
Expand Down Expand Up @@ -113,10 +114,9 @@ func (c *Controller) resourceCountsFor(ownerLabel string, ownerName string) core
func (c *Controller) Register(_ context.Context, m manager.Manager) error {
return controllerruntime.NewControllerManagedBy(m).
Named("nodepool.counter").
For(&v1.NodePool{}).
For(&v1.NodePool{}, builder.WithPredicates(nodepoolutils.IsManagedPredicates(c.cloudProvider))).
Watches(&v1.NodeClaim{}, nodepoolutils.NodeClaimEventHandler()).
Watches(&corev1.Node{}, nodepoolutils.NodeEventHandler()).
WithEventFilter(nodepoolutils.IsMangedPredicates(c.cloudProvider)).
WithOptions(controller.Options{MaxConcurrentReconciles: 10}).
Complete(reconcile.AsReconciler(m.GetClient(), c))

Expand Down
6 changes: 3 additions & 3 deletions pkg/controllers/nodepool/counter/suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,13 +121,13 @@ var _ = Describe("Counter", func() {
nodePool = test.NodePool(v1.NodePool{Spec: v1.NodePoolSpec{Template: v1.NodeClaimTemplate{Spec: v1.NodeClaimTemplateSpec{
NodeClassRef: &v1.NodeClassReference{
Group: "karpenter.k8s.aws",
Kind: "EC2NodeClass",
Name: "default",
Kind: "EC2NodeClass",
Name: "default",
},
}}}})
nodeClaim, node = test.NodeClaimAndNode(v1.NodeClaim{
ObjectMeta: metav1.ObjectMeta{Labels: map[string]string{
v1.NodePoolLabelKey: nodePool.Name,
v1.NodePoolLabelKey: nodePool.Name,
}},
Status: v1.NodeClaimStatus{
ProviderID: test.RandomProviderID(),
Expand Down
2 changes: 1 addition & 1 deletion pkg/controllers/nodepool/hash/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func (c *Controller) Reconcile(ctx context.Context, np *v1.NodePool) (reconcile.
func (c *Controller) Register(_ context.Context, m manager.Manager) error {
return controllerruntime.NewControllerManagedBy(m).
Named("nodepool.hash").
For(&v1.NodePool{}, builder.WithPredicates(nodepoolutils.IsMangedPredicates(c.cloudProvider))).
For(&v1.NodePool{}, builder.WithPredicates(nodepoolutils.IsManagedPredicates(c.cloudProvider))).
WithOptions(controller.Options{MaxConcurrentReconciles: 10}).
Complete(reconcile.AsReconciler(m.GetClient(), c))
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/controllers/nodepool/readiness/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ func (c *Controller) setReadyCondition(nodePool *v1.NodePool, nodeClass status.O
func (c *Controller) Register(_ context.Context, m manager.Manager) error {
b := controllerruntime.NewControllerManagedBy(m).
Named("nodepool.readiness").
For(&v1.NodePool{}, builder.WithPredicates(nodepoolutils.IsMangedPredicates(c.cloudProvider))).
For(&v1.NodePool{}, builder.WithPredicates(nodepoolutils.IsManagedPredicates(c.cloudProvider))).
WithOptions(controller.Options{MaxConcurrentReconciles: 10})
for _, nodeClass := range c.cloudProvider.GetSupportedNodeClasses() {
b.Watches(nodeClass, nodepoolutils.NodeClassEventHandler(c.kubeClient))
Expand Down
2 changes: 1 addition & 1 deletion pkg/controllers/nodepool/validation/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func (c *Controller) Reconcile(ctx context.Context, nodePool *v1.NodePool) (reco
func (c *Controller) Register(_ context.Context, m manager.Manager) error {
return controllerruntime.NewControllerManagedBy(m).
Named("nodepool.validation").
For(&v1.NodePool{}, builder.WithPredicates(nodepoolutils.IsMangedPredicates(c.cloudProvider))).
For(&v1.NodePool{}, builder.WithPredicates(nodepoolutils.IsManagedPredicates(c.cloudProvider))).
WithOptions(controller.Options{MaxConcurrentReconciles: 10}).
Complete(reconcile.AsReconciler(m.GetClient(), c))
}
2 changes: 1 addition & 1 deletion pkg/controllers/state/informer/nodeclaim.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func (c *NodeClaimController) Reconcile(ctx context.Context, req reconcile.Reque
func (c *NodeClaimController) Register(_ context.Context, m manager.Manager) error {
return controllerruntime.NewControllerManagedBy(m).
Named("state.nodeclaim").
For(&v1.NodeClaim{}, builder.WithPredicates(nodeclaimutils.IsMangedPredicates(c.cloudProvider))).
For(&v1.NodeClaim{}, builder.WithPredicates(nodeclaimutils.IsManagedPredicates(c.cloudProvider))).
WithOptions(controller.Options{MaxConcurrentReconciles: 10}).
Complete(c)
}
2 changes: 1 addition & 1 deletion pkg/controllers/state/informer/nodepool.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func (c *NodePoolController) Reconcile(ctx context.Context, np *v1.NodePool) (re
func (c *NodePoolController) Register(_ context.Context, m manager.Manager) error {
return controllerruntime.NewControllerManagedBy(m).
Named("state.nodepool").
For(&v1.NodePool{}, builder.WithPredicates(nodepool.IsMangedPredicates(c.cloudProvider))).
For(&v1.NodePool{}, builder.WithPredicates(nodepool.IsManagedPredicates(c.cloudProvider))).
WithOptions(controller.Options{MaxConcurrentReconciles: 10}).
WithEventFilter(predicate.GenerationChangedPredicate{}).
WithEventFilter(predicate.Funcs{DeleteFunc: func(event event.DeleteEvent) bool { return false }}).
Expand Down
1 change: 0 additions & 1 deletion pkg/controllers/state/suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -848,7 +848,6 @@ var _ = Describe("Node Resource Level", func() {
v1.NodePoolLabelKey: nodePool.Name,
corev1.LabelInstanceTypeStable: cloudProvider.InstanceTypes[0].Name,
},

},
Spec: v1.NodeClaimSpec{
Requirements: []v1.NodeSelectorRequirementWithMinValues{
Expand Down
30 changes: 19 additions & 11 deletions pkg/test/environment.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (

"github.com/awslabs/operatorpkg/option"
"github.com/samber/lo"
"go.uber.org/multierr"
corev1 "k8s.io/api/core/v1"
storagev1 "k8s.io/api/storage/v1"
apiextensionsv1 "k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1"
Expand Down Expand Up @@ -69,31 +70,38 @@ func WithFieldIndexers(fieldIndexers ...func(cache.Cache) error) option.Function
}
}

// NodeFieldIndexer provides indexes on the following fields:
//
// - spec.providerID
func NodeFieldIndexer(ctx context.Context) func(cache.Cache) error {
func NodeProviderIDFieldIndexer(ctx context.Context) func(cache.Cache) error {
return func(c cache.Cache) error {
return c.IndexField(ctx, &corev1.Node{}, "spec.providerID", func(obj client.Object) []string {
return []string{obj.(*corev1.Node).Spec.ProviderID}
})
}
}

// NodeClaimFieldIndexer provides indexes on the following fields:
//
// - status.providerID
func NodeClaimFieldIndexer(ctx context.Context) func(cache.Cache) error {
func NodeClaimProviderIDFieldIndexer(ctx context.Context) func(cache.Cache) error {
return func(c cache.Cache) error {
return c.IndexField(ctx, &v1.NodeClaim{}, "status.providerID", func(obj client.Object) []string {
return []string{obj.(*v1.NodeClaim).Status.ProviderID}
})
}
}

// VolumeAttachmentFieldIndexer provides indexes on the following fields:
//
// - status.nodeName
func NodeClaimNodeClassRefFieldIndexer(ctx context.Context) func(cache.Cache) error {
return func(c cache.Cache) error {
var err error
err = multierr.Append(err, c.IndexField(ctx, &v1.NodeClaim{}, "spec.nodeClassRef.group", func(obj client.Object) []string {
return []string{obj.(*v1.NodeClaim).Spec.NodeClassRef.Group}
}))
err = multierr.Append(err, c.IndexField(ctx, &v1.NodeClaim{}, "spec.nodeClassRef.kind", func(obj client.Object) []string {
return []string{obj.(*v1.NodeClaim).Spec.NodeClassRef.Kind}
}))
err = multierr.Append(err, c.IndexField(ctx, &v1.NodeClaim{}, "spec.nodeClassRef.name", func(obj client.Object) []string {
return []string{obj.(*v1.NodeClaim).Spec.NodeClassRef.Name}
}))
return err
}
}

func VolumeAttachmentFieldIndexer(ctx context.Context) func(cache.Cache) error {
return func(c cache.Cache) error {
return c.IndexField(ctx, &storagev1.VolumeAttachment{}, "spec.nodeName", func(obj client.Object) []string {
Expand Down
14 changes: 11 additions & 3 deletions pkg/test/nodeclaim.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,13 @@ package test
import (
"fmt"

"github.com/awslabs/operatorpkg/object"
"github.com/imdario/mergo"
"github.com/samber/lo"
corev1 "k8s.io/api/core/v1"

v1 "sigs.k8s.io/karpenter/pkg/apis/v1"
"sigs.k8s.io/karpenter/pkg/utils/nodeclaim"
)

// NodeClaim creates a test NodeClaim with defaults that can be overridden by overrides.
Expand All @@ -42,11 +45,14 @@ func NodeClaim(overrides ...v1.NodeClaim) *v1.NodeClaim {
}
if override.Spec.NodeClassRef == nil {
override.Spec.NodeClassRef = &v1.NodeClassReference{
Group: object.GVK(defaultNodeClass).Group,
Kind: object.GVK(defaultNodeClass).Kind,
Name: "default",
Group: "karpenter.test.sh",
Kind: "TestNodeClass",
}
}
override.Labels = lo.Assign(map[string]string{
nodeclaim.NodeClassLabelKey(override.Spec.NodeClassRef): override.Spec.NodeClassRef.Name,
}, override.Labels)
if override.Spec.Requirements == nil {
override.Spec.Requirements = []v1.NodeSelectorRequirementWithMinValues{}
}
Expand All @@ -59,7 +65,9 @@ func NodeClaim(overrides ...v1.NodeClaim) *v1.NodeClaim {

func NodeClaimAndNode(overrides ...v1.NodeClaim) (*v1.NodeClaim, *corev1.Node) {
nc := NodeClaim(overrides...)
return nc, NodeClaimLinkedNode(nc)
node := NodeClaimLinkedNode(nc)
nc.Status.NodeName = node.Name
return nc, node
}

// NodeClaimsAndNodes creates homogeneous groups of NodeClaims and Nodes based on the passed in options, evenly divided by the total nodeclaims requested
Expand Down
10 changes: 10 additions & 0 deletions pkg/test/nodeclass.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,16 @@ import (
"sigs.k8s.io/karpenter/pkg/test/v1alpha1"
)

var (
// defaultNodeClass is the default NodeClass type used when creating NodeClassRefs for NodePools and NodeClaims
defaultNodeClass status.Object = &v1alpha1.TestNodeClass{}
)

// SetDefaultNodeClassType configures the default NodeClass type used when generating NodeClassRefs for test NodePools and NodeClaims.
func SetDefaultNodeClassType(nc status.Object) {
defaultNodeClass = nc
}

// NodeClass creates a test NodeClass with defaults that can be overridden by overrides.
// Overrides are applied in order, with a last write wins semantic.
func NodeClass(overrides ...v1alpha1.TestNodeClass) *v1alpha1.TestNodeClass {
Expand Down
Loading

0 comments on commit ad7c521

Please sign in to comment.