Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 52 additions & 44 deletions pkg/nodeidentity/azure/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,81 +22,97 @@ import (
"fmt"
"io"
"net/http"
"strings"

"github.com/Azure/azure-sdk-for-go/sdk/azcore/arm"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
compute "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute"
"k8s.io/kops/upup/pkg/fi"
)

type instanceComputeMetadata struct {
ResourceGroupName string `json:"resourceGroupName"`
SubscriptionID string `json:"subscriptionId"`
}

type instanceMetadata struct {
Compute *instanceComputeMetadata `json:"compute"`
}

// client is an Azure client.
type client struct {
metadata *instanceMetadata
vmssesClient *compute.VirtualMachineScaleSetsClient
subscriptionID string
vmClient *compute.VirtualMachinesClient
vmssClient *compute.VirtualMachineScaleSetVMsClient
}

// newClient returns a new Client.
func newClient() (*client, error) {
m, err := queryInstanceMetadata()
metadata, err := queryComputeInstanceMetadata()
if err != nil {
return nil, fmt.Errorf("error querying instance metadata: %s", err)
}
if m.Compute.SubscriptionID == "" {
return nil, fmt.Errorf("empty subscription name")
}
if m.Compute.ResourceGroupName == "" {
return nil, fmt.Errorf("empty resource group name")
if metadata.SubscriptionID == "" {
return nil, fmt.Errorf("empty subscription ID")
}

cred, err := azidentity.NewDefaultAzureCredential(nil)
if err != nil {
return nil, fmt.Errorf("creating identity: %w", err)
}

vmssesClient, err := compute.NewVirtualMachineScaleSetsClient(m.Compute.SubscriptionID, cred, nil)
vmClient, err := compute.NewVirtualMachinesClient(metadata.SubscriptionID, cred, nil)
if err != nil {
return nil, fmt.Errorf("creating VMs client: %w", err)
}

vmssClient, err := compute.NewVirtualMachineScaleSetVMsClient(metadata.SubscriptionID, cred, nil)
if err != nil {
return nil, fmt.Errorf("creating VMSS client: %w", err)
return nil, fmt.Errorf("creating VMSS VMs client: %w", err)
}

return &client{
metadata: m,
vmssesClient: vmssesClient,
vmClient: vmClient,
vmssClient: vmssClient,
}, nil
}

// getVMScaleSet returns the specified VM ScaleSet.
func (c *client) getVMScaleSet(ctx context.Context, vmssName string) (*compute.VirtualMachineScaleSet, error) {
opts := &compute.VirtualMachineScaleSetsClientGetOptions{
Expand: fi.PtrTo(compute.ExpandTypesForGetVMScaleSetsUserData),
func (c *client) getVMTags(ctx context.Context, providerID string) (map[string]*string, error) {
if !strings.HasPrefix(providerID, "azure://") {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: I would have put "azure://" in a const given the multiple occurrences.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Needs to be done for all providers, so I prefer to do it separately.

return nil, fmt.Errorf("unknown providerID : %s", providerID)
}
resp, err := c.vmssesClient.Get(ctx, c.metadata.Compute.ResourceGroupName, vmssName, opts)

res, err := arm.ParseResourceID(strings.TrimPrefix(providerID, "azure://"))
if err != nil {
return nil, fmt.Errorf("getting VMSS: %w", err)
return nil, fmt.Errorf("error parsing providerID: %v", err)
}

switch res.ResourceType.String() {
case "Microsoft.Compute/virtualMachines":
resp, err := c.vmClient.Get(ctx, res.ResourceGroupName, res.Name, nil)
if err != nil {
return nil, fmt.Errorf("getting VM: %w", err)
}
return resp.VirtualMachine.Tags, nil
case "Microsoft.Compute/virtualMachineScaleSets/virtualMachines":
resp, err := c.vmssClient.Get(ctx, res.ResourceGroupName, res.Parent.Name, res.Name, nil)
if err != nil {
return nil, fmt.Errorf("getting VMSS VM: %w", err)
}
return resp.VirtualMachineScaleSetVM.Tags, nil
default:
return nil, fmt.Errorf("unsupported resource type %q for %q", res.ResourceType, providerID)
}
return &resp.VirtualMachineScaleSet, nil
}

// queryInstanceMetadata queries Azure Instance Metadata documented in
// https://docs.microsoft.com/en-us/azure/virtual-machines/windows/instance-metadata-service.
func queryInstanceMetadata() (*instanceMetadata, error) {
type instanceMetadata struct {
SubscriptionID string `json:"subscriptionId"`
ResourceGroupName string `json:"resourceGroupName"`
}

// queryComputeInstanceMetadata queries Azure Instance Metadata.
// https://docs.microsoft.com/en-us/azure/virtual-machines/windows/instance-metadata-service
func queryComputeInstanceMetadata() (*instanceMetadata, error) {
client := &http.Client{}
req, err := http.NewRequest("GET", "http://169.254.169.254/metadata/instance", nil)
req, err := http.NewRequest("GET", "http://169.254.169.254/metadata/instance/compute", nil)
if err != nil {
return nil, fmt.Errorf("error creating a new request: %s", err)
}
req.Header.Add("Metadata", "True")

q := req.URL.Query()
q.Add("api-version", "2025-04-07")
q.Add("format", "json")
q.Add("api-version", "2020-06-01")
req.URL.RawQuery = q.Encode()

resp, err := client.Do(req)
Expand All @@ -109,17 +125,9 @@ func queryInstanceMetadata() (*instanceMetadata, error) {
if err != nil {
return nil, fmt.Errorf("error reading a response from the metadata server: %s", err)
}
metadata, err := unmarshalInstanceMetadata(body)
if err != nil {
metadata := &instanceMetadata{}
if err := json.Unmarshal(body, metadata); err != nil {
return nil, fmt.Errorf("error unmarshalling metadata: %s", err)
}
return metadata, nil
}

func unmarshalInstanceMetadata(data []byte) (*instanceMetadata, error) {
m := &instanceMetadata{}
if err := json.Unmarshal(data, m); err != nil {
return nil, err
}
return m, nil
}
43 changes: 0 additions & 43 deletions pkg/nodeidentity/azure/client_test.go

This file was deleted.

58 changes: 31 additions & 27 deletions pkg/nodeidentity/azure/identify.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import (
"strings"
"time"

compute "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/arm"
corev1 "k8s.io/api/core/v1"
expirationcache "k8s.io/client-go/tools/cache"
"k8s.io/klog/v2"
Expand All @@ -42,16 +42,9 @@ const (
cacheTTL = 60 * time.Minute
)

type vmssGetter interface {
getVMScaleSet(ctx context.Context, vmssName string) (*compute.VirtualMachineScaleSet, error)
}

var _ vmssGetter = &client{}

// nodeIdentifier identifies a node from Azure VM.
type nodeIdentifier struct {
vmssGetter vmssGetter

azureClient *client
// cache is a cache of nodeidentity.Info
cache expirationcache.Store
// cacheEnabled indicates if caching should be used
Expand All @@ -68,7 +61,7 @@ func New(cacheNodeidentityInfo bool) (nodeidentity.Identifier, error) {
}

return &nodeIdentifier{
vmssGetter: client,
azureClient: client,
cache: expirationcache.NewTTLStore(stringKeyFunc, cacheTTL),
cacheEnabled: cacheNodeidentityInfo,
}, nil
Expand All @@ -78,17 +71,20 @@ func New(cacheNodeidentityInfo bool) (nodeidentity.Identifier, error) {
func (i *nodeIdentifier) IdentifyNode(ctx context.Context, node *corev1.Node) (*nodeidentity.Info, error) {
providerID := node.Spec.ProviderID
if providerID == "" {
return nil, fmt.Errorf("providerID was not set for node %s", node.Name)
return nil, fmt.Errorf("providerID not set for node %q", node.Name)
}
if !strings.HasPrefix(providerID, "azure://") {
return nil, fmt.Errorf("providerID %q not recognized for node %q", providerID, node.Name)
}
vmssName, err := getVMSSNameFromProviderID(providerID)

vmName, err := getVMNameFromProviderID(providerID)
if err != nil {
return nil, fmt.Errorf("error on extracting VM ScaleSet name: %s", err)
return nil, err
}

// If caching is enabled try pulling nodeidentity.Info from cache before
// doing a EC2 API call.
// If caching is enabled, try pulling nodeidentity.Info from the cache before doing an API call.
if i.cacheEnabled {
obj, exists, err := i.cache.GetByKey(vmssName)
obj, exists, err := i.cache.GetByKey(vmName)
if err != nil {
klog.Warningf("Nodeidentity info cache lookup failure: %v", err)
}
Expand All @@ -97,13 +93,13 @@ func (i *nodeIdentifier) IdentifyNode(ctx context.Context, node *corev1.Node) (*
}
}

vmss, err := i.vmssGetter.getVMScaleSet(ctx, vmssName)
tags, err := i.azureClient.getVMTags(ctx, providerID)
if err != nil {
return nil, fmt.Errorf("error on getting VM ScaleSet: %s", err)
}

labels := map[string]string{}
for k, v := range vmss.Tags {
for k, v := range tags {
if k == azure.TagClusterName && v != nil {
labels[kops.LabelClusterName] = *v
}
Expand All @@ -120,25 +116,25 @@ func (i *nodeIdentifier) IdentifyNode(ctx context.Context, node *corev1.Node) (*
case kops.InstanceGroupRoleNode.ToLowerString():
labels[nodelabels.RoleLabelNode16] = ""
default:
klog.Warningf("Unknown or unsupported node role tag %q for VMSS %q", k, vmssName)
klog.Warningf("Unknown or unsupported node role tag %q for VM %q", k, vmName)
}
}
if strings.HasPrefix(k, ClusterNodeTemplateLabel) && v != nil {
l := strings.SplitN(*v, "=", 2)
if len(l) <= 1 {
klog.Warningf("Malformed cloud label tag %q=%q for VMSS %q", k, *v, vmssName)
klog.Warningf("Malformed cloud label tag %q=%q for VM %q", k, *v, vmName)
} else {
labels[l[0]] = l[1]
}
}
}

info := &nodeidentity.Info{
InstanceID: vmssName,
InstanceID: vmName,
Labels: labels,
}

// If caching is enabled add the nodeidentity.Info to cache.
// If caching is enabled, add the nodeidentity.Info to the cache.
if i.cacheEnabled {
err = i.cache.Add(info)
if err != nil {
Expand All @@ -155,14 +151,22 @@ func stringKeyFunc(obj interface{}) (string, error) {
return key, nil
}

func getVMSSNameFromProviderID(providerID string) (string, error) {
func getVMNameFromProviderID(providerID string) (string, error) {
if !strings.HasPrefix(providerID, "azure://") {
return "", fmt.Errorf("providerID %q not recognized", providerID)
}

l := strings.Split(strings.TrimPrefix(providerID, "azure://"), "/")
if len(l) != 11 {
return "", fmt.Errorf("unexpected format of providerID %q", providerID)
res, err := arm.ParseResourceID(strings.TrimPrefix(providerID, "azure://"))
if err != nil {
return "", fmt.Errorf("error parsing providerID: %v", err)
}

switch res.ResourceType.String() {
case "Microsoft.Compute/virtualMachines":
return res.Name, nil
case "Microsoft.Compute/virtualMachineScaleSets/virtualMachines":
return res.Parent.Name + "_" + res.Name, nil
default:
return "", fmt.Errorf("unsupported resource type %q for providerID %q", res.ResourceType, providerID)
}
return l[len(l)-3], nil
}
Loading
Loading