diff --git a/internal/lambdalabs/v1/errors.go b/internal/lambdalabs/v1/errors.go new file mode 100644 index 0000000..73aca79 --- /dev/null +++ b/internal/lambdalabs/v1/errors.go @@ -0,0 +1,40 @@ +package v1 + +import ( + "fmt" + "strings" + + v1 "github.com/brevdev/compute/pkg/v1" +) + +func handleLLErrToCloudErr(err error) error { + if err == nil { + return nil + } + + errStr := err.Error() + + if strings.Contains(errStr, "insufficient capacity") || + strings.Contains(errStr, "no capacity") || + strings.Contains(errStr, "capacity not available") { + return v1.ErrInsufficientResources + } + + if strings.Contains(errStr, "quota") || + strings.Contains(errStr, "limit exceeded") || + strings.Contains(errStr, "too many") { + return v1.ErrOutOfQuota + } + + if strings.Contains(errStr, "not found") || + strings.Contains(errStr, "does not exist") { + return v1.ErrInstanceNotFound + } + + if strings.Contains(errStr, "service unavailable") || + strings.Contains(errStr, "temporarily unavailable") { + return v1.ErrServiceUnavailable + } + + return fmt.Errorf("lambda labs error: %w", err) +} diff --git a/internal/lambdalabs/v1/instance.go b/internal/lambdalabs/v1/instance.go index 9ecc05f..a461bcb 100644 --- a/internal/lambdalabs/v1/instance.go +++ b/internal/lambdalabs/v1/instance.go @@ -3,9 +3,12 @@ package v1 import ( "context" "fmt" + "regexp" + "strconv" "strings" "time" + "github.com/alecthomas/units" openapi "github.com/brevdev/cloud/internal/lambdalabs/gen/lambdalabs" v1 "github.com/brevdev/compute/pkg/v1" ) @@ -21,17 +24,9 @@ func (c *LambdaLabsClient) CreateInstance(ctx context.Context, attrs v1.CreateIn } if attrs.PublicKey != "" { - request := openapi.AddSSHKeyRequest{ - Name: keyPairName, - PublicKey: &attrs.PublicKey, - } - - _, resp, err := c.client.DefaultAPI.AddSSHKey(c.makeAuthContext(ctx)).AddSSHKeyRequest(request).Execute() - if resp != nil { - defer func() { _ = resp.Body.Close() }() - } - if err != nil && !strings.Contains(err.Error(), "name must be unique") { - return nil, fmt.Errorf("failed to add SSH key: %w", err) + err := c.addSSHKeyIdempotent(ctx, keyPairName, attrs.PublicKey) + if err != nil { + return nil, err } } @@ -60,7 +55,7 @@ func (c *LambdaLabsClient) CreateInstance(ctx context.Context, attrs v1.CreateIn defer func() { _ = httpResp.Body.Close() }() } if err != nil { - return nil, fmt.Errorf("failed to launch instance: %w", err) + return nil, handleLLErrToCloudErr(err) } if len(resp.Data.InstanceIds) != 1 { @@ -79,7 +74,7 @@ func (c *LambdaLabsClient) GetInstance(ctx context.Context, instanceID v1.CloudP defer func() { _ = httpResp.Body.Close() }() } if err != nil { - return nil, fmt.Errorf("failed to get instance: %w", err) + return nil, handleLLErrToCloudErr(err) } return convertLambdaLabsInstanceToV1Instance(resp.Data), nil @@ -97,7 +92,7 @@ func (c *LambdaLabsClient) TerminateInstance(ctx context.Context, instanceID v1. defer func() { _ = httpResp.Body.Close() }() } if err != nil { - return fmt.Errorf("failed to terminate instance: %w", err) + return handleLLErrToCloudErr(err) } return nil @@ -111,7 +106,7 @@ func (c *LambdaLabsClient) ListInstances(ctx context.Context, _ v1.ListInstances defer func() { _ = httpResp.Body.Close() }() } if err != nil { - return nil, fmt.Errorf("failed to list instances: %w", err) + return nil, handleLLErrToCloudErr(err) } instances := make([]v1.Instance, 0, len(resp.Data)) @@ -123,6 +118,67 @@ func (c *LambdaLabsClient) ListInstances(ctx context.Context, _ v1.ListInstances return instances, nil } +func (c *LambdaLabsClient) addSSHKeyIdempotent(ctx context.Context, keyName, publicKey string) error { + _, resp, err := c.client.DefaultAPI.AddSSHKey(c.makeAuthContext(ctx)).AddSSHKeyRequest(openapi.AddSSHKeyRequest{ + Name: keyName, + PublicKey: &publicKey, + }).Execute() + if resp != nil { + defer func() { _ = resp.Body.Close() }() + } + + if err != nil { + if strings.Contains(err.Error(), "name must be unique") { + return nil + } + return handleLLErrToCloudErr(err) + } + + return nil +} + +func parseGPUFromDescription(description string) v1.GPU { + gpu := v1.GPU{ + Manufacturer: "NVIDIA", + } + + countRegex := regexp.MustCompile(`(\d+)x`) + if countMatch := countRegex.FindStringSubmatch(description); len(countMatch) > 1 { + if count, err := strconv.ParseInt(countMatch[1], 10, 32); err == nil { + gpu.Count = int32(count) + } + } + + memoryRegex := regexp.MustCompile(`\((\d+)\s*GB\)`) + if memoryMatch := memoryRegex.FindStringSubmatch(description); len(memoryMatch) > 1 { + if memoryGiB, err := strconv.Atoi(memoryMatch[1]); err == nil { + gpu.Memory = units.Base2Bytes(memoryGiB) * units.GiB + } + } + + nameRegex := regexp.MustCompile(`\d+x\s+(.+?)\s+\(`) + if nameMatch := nameRegex.FindStringSubmatch(description); len(nameMatch) > 1 { + gpu.Name = strings.TrimSpace(nameMatch[1]) + gpu.Type = gpu.Name + } + + if strings.Contains(description, "SXM4") { + gpu.NetworkDetails = "SXM4" + } else if strings.Contains(description, "PCIe") { + gpu.NetworkDetails = "PCIe" + } + + return gpu +} + +func generateFirewallRuleFromPort(port int32) v1.FirewallRule { + return v1.FirewallRule{ + FromPort: port, + ToPort: port, + IPRanges: []string{"0.0.0.0/0"}, + } +} + // RebootInstance reboots an instance // Supported via: POST /api/v1/instance-operations/restart func (c *LambdaLabsClient) RebootInstance(ctx context.Context, instanceID v1.CloudProviderInstanceID) error { @@ -135,7 +191,7 @@ func (c *LambdaLabsClient) RebootInstance(ctx context.Context, instanceID v1.Clo defer func() { _ = httpResp.Body.Close() }() } if err != nil { - return fmt.Errorf("failed to reboot instance: %w", err) + return handleLLErrToCloudErr(err) } return nil @@ -175,6 +231,17 @@ func convertLambdaLabsInstanceToV1Instance(llInstance openapi.Instance) *v1.Inst refID = llInstance.SshKeyNames[0] } + firewallRules := v1.FirewallRules{ + IngressRules: []v1.FirewallRule{ + generateFirewallRuleFromPort(22), + generateFirewallRuleFromPort(2222), + }, + EgressRules: []v1.FirewallRule{ + generateFirewallRuleFromPort(22), + generateFirewallRuleFromPort(2222), + }, + } + return &v1.Instance{ Name: name, RefID: refID, @@ -189,11 +256,14 @@ func convertLambdaLabsInstanceToV1Instance(llInstance openapi.Instance) *v1.Inst Status: v1.Status{ LifecycleStatus: convertLambdaLabsStatusToV1Status(llInstance.Status), }, - Location: llInstance.Region.Name, - SSHUser: "ubuntu", - SSHPort: 22, - Stoppable: false, - Rebootable: true, + Location: llInstance.Region.Name, + SSHUser: "ubuntu", + SSHPort: 22, + Stoppable: false, + Rebootable: true, + VolumeType: "ssd", + DiskSize: units.Base2Bytes(llInstance.InstanceType.Specs.StorageGib) * units.GiB, + FirewallRules: firewallRules, } } diff --git a/internal/lambdalabs/v1/instancetype.go b/internal/lambdalabs/v1/instancetype.go index 09f788d..9e37879 100644 --- a/internal/lambdalabs/v1/instancetype.go +++ b/internal/lambdalabs/v1/instancetype.go @@ -4,8 +4,6 @@ import ( "context" "encoding/json" "fmt" - "regexp" - "strconv" "strings" "time" @@ -23,48 +21,36 @@ func (c *LambdaLabsClient) GetInstanceTypes(ctx context.Context, args v1.GetInst defer func() { _ = httpResp.Body.Close() }() } if err != nil { - return nil, fmt.Errorf("failed to get instance types: %w", err) + return nil, handleLLErrToCloudErr(err) } var instanceTypes []v1.InstanceType - for _, llInstanceTypeData := range resp.Data { - for _, region := range llInstanceTypeData.RegionsWithCapacityAvailable { - instanceType, err := convertLambdaLabsInstanceTypeToV1InstanceType( - region.Name, - llInstanceTypeData.InstanceType, - true, - ) - if err != nil { - return nil, fmt.Errorf("failed to convert instance type: %w", err) - } - instanceTypes = append(instanceTypes, instanceType) + + for instanceTypeName, instanceTypeData := range resp.Data { + if strings.Contains(instanceTypeName, "gh") { + continue } - } - if len(args.Locations) > 0 && !args.Locations.IsAll() { - filtered := make([]v1.InstanceType, 0) - for _, it := range instanceTypes { - for _, loc := range args.Locations { - if it.Location == loc { - filtered = append(filtered, it) - break - } - } + if len(args.InstanceTypes) > 0 && !contains(args.InstanceTypes, instanceTypeName) { + continue + } + + availableRegions := make(map[string]bool) + for _, region := range instanceTypeData.RegionsWithCapacityAvailable { + availableRegions[region.Name] = true + } + + for _, region := range instanceTypeData.RegionsWithCapacityAvailable { + if len(args.Locations) > 0 && !args.Locations.IsAll() && !containsLocation(args.Locations, region.Name) { + continue } - instanceTypes = filtered - } - if len(args.InstanceTypes) > 0 { - filtered := make([]v1.InstanceType, 0) - for _, it := range instanceTypes { - for _, itName := range args.InstanceTypes { - if it.Type == itName { - filtered = append(filtered, it) - break - } - } + v1InstanceType, err := convertLambdaLabsInstanceTypeToV1InstanceType(region.Name, instanceTypeData.InstanceType, true) + if err != nil { + return nil, fmt.Errorf("failed to convert instance type: %w", err) } - instanceTypes = filtered + instanceTypes = append(instanceTypes, v1InstanceType) + } } return instanceTypes, nil @@ -115,35 +101,6 @@ func convertLambdaLabsInstanceTypeToV1InstanceType(location string, llInstanceTy return instanceType, nil } -func parseGPUFromDescription(description string) v1.GPU { - countRegex := regexp.MustCompile(`(\d+)x`) - memoryRegex := regexp.MustCompile(`(\d+) GB`) - nameRegex := regexp.MustCompile(`x (.*?) \(`) - - var gpu v1.GPU - - if matches := countRegex.FindStringSubmatch(description); len(matches) > 1 { - if count, err := strconv.ParseInt(matches[1], 10, 32); err == nil { - gpu.Count = int32(count) - } - } - - if matches := memoryRegex.FindStringSubmatch(description); len(matches) > 1 { - if memory, err := strconv.Atoi(matches[1]); err == nil { - gpu.Memory = units.GiB * units.Base2Bytes(memory) - } - } - - if matches := nameRegex.FindStringSubmatch(description); len(matches) > 1 { - gpu.Name = strings.TrimSpace(matches[1]) - gpu.Type = gpu.Name - } - - gpu.Manufacturer = "NVIDIA" - - return gpu -} - const lambdaLocationsData = `[ {"location_name": "us-west-1", "description": "California, USA", "country": "USA"}, {"location_name": "us-west-2", "description": "Arizona, USA", "country": "USA"}, @@ -189,3 +146,21 @@ func (c *LambdaLabsClient) GetLocations(_ context.Context, _ v1.GetLocationsArgs return locations, nil } + +func contains(slice []string, item string) bool { + for _, s := range slice { + if s == item { + return true + } + } + return false +} + +func containsLocation(locations v1.LocationsFilter, location string) bool { + for _, loc := range locations { + if loc == location { + return true + } + } + return false +}