Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions internal/lambdalabs/v1/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package v1

import (
"fmt"
"strings"

v1 "github.com/brevdev/compute/pkg/v1"
)

func handleLLErrToCloudErr(err error) error {
if err == nil {
return nil
}

errStr := err.Error()

if strings.Contains(errStr, "insufficient capacity") ||
strings.Contains(errStr, "no capacity") ||
strings.Contains(errStr, "capacity not available") {
return v1.ErrInsufficientResources
}

if strings.Contains(errStr, "quota") ||
strings.Contains(errStr, "limit exceeded") ||
strings.Contains(errStr, "too many") {
return v1.ErrOutOfQuota
}

if strings.Contains(errStr, "not found") ||
strings.Contains(errStr, "does not exist") {
return v1.ErrInstanceNotFound
}

if strings.Contains(errStr, "service unavailable") ||
strings.Contains(errStr, "temporarily unavailable") {
return v1.ErrServiceUnavailable
}

return fmt.Errorf("lambda labs error: %w", err)
}
112 changes: 91 additions & 21 deletions internal/lambdalabs/v1/instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@ package v1
import (
"context"
"fmt"
"regexp"
"strconv"
"strings"
"time"

"github.com/alecthomas/units"
openapi "github.com/brevdev/cloud/internal/lambdalabs/gen/lambdalabs"
v1 "github.com/brevdev/compute/pkg/v1"
)
Expand All @@ -21,17 +24,9 @@ func (c *LambdaLabsClient) CreateInstance(ctx context.Context, attrs v1.CreateIn
}

if attrs.PublicKey != "" {
request := openapi.AddSSHKeyRequest{
Name: keyPairName,
PublicKey: &attrs.PublicKey,
}

_, resp, err := c.client.DefaultAPI.AddSSHKey(c.makeAuthContext(ctx)).AddSSHKeyRequest(request).Execute()
if resp != nil {
defer func() { _ = resp.Body.Close() }()
}
if err != nil && !strings.Contains(err.Error(), "name must be unique") {
return nil, fmt.Errorf("failed to add SSH key: %w", err)
err := c.addSSHKeyIdempotent(ctx, keyPairName, attrs.PublicKey)
if err != nil {
return nil, err
}
}

Expand Down Expand Up @@ -60,7 +55,7 @@ func (c *LambdaLabsClient) CreateInstance(ctx context.Context, attrs v1.CreateIn
defer func() { _ = httpResp.Body.Close() }()
}
if err != nil {
return nil, fmt.Errorf("failed to launch instance: %w", err)
return nil, handleLLErrToCloudErr(err)
}

if len(resp.Data.InstanceIds) != 1 {
Expand All @@ -79,7 +74,7 @@ func (c *LambdaLabsClient) GetInstance(ctx context.Context, instanceID v1.CloudP
defer func() { _ = httpResp.Body.Close() }()
}
if err != nil {
return nil, fmt.Errorf("failed to get instance: %w", err)
return nil, handleLLErrToCloudErr(err)
}

return convertLambdaLabsInstanceToV1Instance(resp.Data), nil
Expand All @@ -97,7 +92,7 @@ func (c *LambdaLabsClient) TerminateInstance(ctx context.Context, instanceID v1.
defer func() { _ = httpResp.Body.Close() }()
}
if err != nil {
return fmt.Errorf("failed to terminate instance: %w", err)
return handleLLErrToCloudErr(err)
}

return nil
Expand All @@ -111,7 +106,7 @@ func (c *LambdaLabsClient) ListInstances(ctx context.Context, _ v1.ListInstances
defer func() { _ = httpResp.Body.Close() }()
}
if err != nil {
return nil, fmt.Errorf("failed to list instances: %w", err)
return nil, handleLLErrToCloudErr(err)
}

instances := make([]v1.Instance, 0, len(resp.Data))
Expand All @@ -123,6 +118,67 @@ func (c *LambdaLabsClient) ListInstances(ctx context.Context, _ v1.ListInstances
return instances, nil
}

func (c *LambdaLabsClient) addSSHKeyIdempotent(ctx context.Context, keyName, publicKey string) error {
_, resp, err := c.client.DefaultAPI.AddSSHKey(c.makeAuthContext(ctx)).AddSSHKeyRequest(openapi.AddSSHKeyRequest{
Name: keyName,
PublicKey: &publicKey,
}).Execute()
if resp != nil {
defer func() { _ = resp.Body.Close() }()
}

if err != nil {
if strings.Contains(err.Error(), "name must be unique") {
return nil
}
return handleLLErrToCloudErr(err)
}

return nil
}

func parseGPUFromDescription(description string) v1.GPU {
gpu := v1.GPU{
Manufacturer: "NVIDIA",
}

countRegex := regexp.MustCompile(`(\d+)x`)
if countMatch := countRegex.FindStringSubmatch(description); len(countMatch) > 1 {
if count, err := strconv.ParseInt(countMatch[1], 10, 32); err == nil {
gpu.Count = int32(count)
}
}

memoryRegex := regexp.MustCompile(`\((\d+)\s*GB\)`)
if memoryMatch := memoryRegex.FindStringSubmatch(description); len(memoryMatch) > 1 {
if memoryGiB, err := strconv.Atoi(memoryMatch[1]); err == nil {
gpu.Memory = units.Base2Bytes(memoryGiB) * units.GiB
}
}

nameRegex := regexp.MustCompile(`\d+x\s+(.+?)\s+\(`)
if nameMatch := nameRegex.FindStringSubmatch(description); len(nameMatch) > 1 {
gpu.Name = strings.TrimSpace(nameMatch[1])
gpu.Type = gpu.Name
}

if strings.Contains(description, "SXM4") {
gpu.NetworkDetails = "SXM4"
} else if strings.Contains(description, "PCIe") {
gpu.NetworkDetails = "PCIe"
}

return gpu
}

func generateFirewallRuleFromPort(port int32) v1.FirewallRule {
return v1.FirewallRule{
FromPort: port,
ToPort: port,
IPRanges: []string{"0.0.0.0/0"},
}
}

// RebootInstance reboots an instance
// Supported via: POST /api/v1/instance-operations/restart
func (c *LambdaLabsClient) RebootInstance(ctx context.Context, instanceID v1.CloudProviderInstanceID) error {
Expand All @@ -135,7 +191,7 @@ func (c *LambdaLabsClient) RebootInstance(ctx context.Context, instanceID v1.Clo
defer func() { _ = httpResp.Body.Close() }()
}
if err != nil {
return fmt.Errorf("failed to reboot instance: %w", err)
return handleLLErrToCloudErr(err)
}

return nil
Expand Down Expand Up @@ -175,6 +231,17 @@ func convertLambdaLabsInstanceToV1Instance(llInstance openapi.Instance) *v1.Inst
refID = llInstance.SshKeyNames[0]
}

firewallRules := v1.FirewallRules{
IngressRules: []v1.FirewallRule{
generateFirewallRuleFromPort(22),
generateFirewallRuleFromPort(2222),
},
EgressRules: []v1.FirewallRule{
generateFirewallRuleFromPort(22),
generateFirewallRuleFromPort(2222),
},
}

return &v1.Instance{
Name: name,
RefID: refID,
Expand All @@ -189,11 +256,14 @@ func convertLambdaLabsInstanceToV1Instance(llInstance openapi.Instance) *v1.Inst
Status: v1.Status{
LifecycleStatus: convertLambdaLabsStatusToV1Status(llInstance.Status),
},
Location: llInstance.Region.Name,
SSHUser: "ubuntu",
SSHPort: 22,
Stoppable: false,
Rebootable: true,
Location: llInstance.Region.Name,
SSHUser: "ubuntu",
SSHPort: 22,
Stoppable: false,
Rebootable: true,
VolumeType: "ssd",
DiskSize: units.Base2Bytes(llInstance.InstanceType.Specs.StorageGib) * units.GiB,
FirewallRules: firewallRules,
}
}

Expand Down
105 changes: 40 additions & 65 deletions internal/lambdalabs/v1/instancetype.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
"context"
"encoding/json"
"fmt"
"regexp"
"strconv"
"strings"
"time"

Expand All @@ -23,48 +21,36 @@
defer func() { _ = httpResp.Body.Close() }()
}
if err != nil {
return nil, fmt.Errorf("failed to get instance types: %w", err)
return nil, handleLLErrToCloudErr(err)
}

var instanceTypes []v1.InstanceType
for _, llInstanceTypeData := range resp.Data {
for _, region := range llInstanceTypeData.RegionsWithCapacityAvailable {
instanceType, err := convertLambdaLabsInstanceTypeToV1InstanceType(
region.Name,
llInstanceTypeData.InstanceType,
true,
)
if err != nil {
return nil, fmt.Errorf("failed to convert instance type: %w", err)
}
instanceTypes = append(instanceTypes, instanceType)

for instanceTypeName, instanceTypeData := range resp.Data {
if strings.Contains(instanceTypeName, "gh") {
continue
}
}

if len(args.Locations) > 0 && !args.Locations.IsAll() {
filtered := make([]v1.InstanceType, 0)
for _, it := range instanceTypes {
for _, loc := range args.Locations {
if it.Location == loc {
filtered = append(filtered, it)
break
}
}
if len(args.InstanceTypes) > 0 && !contains(args.InstanceTypes, instanceTypeName) {
continue
}

availableRegions := make(map[string]bool)
for _, region := range instanceTypeData.RegionsWithCapacityAvailable {
availableRegions[region.Name] = true
}

for _, region := range instanceTypeData.RegionsWithCapacityAvailable {
if len(args.Locations) > 0 && !args.Locations.IsAll() && !containsLocation(args.Locations, region.Name) {

Check failure on line 44 in internal/lambdalabs/v1/instancetype.go

View workflow job for this annotation

GitHub Actions / Test and Lint

File is not properly formatted (gofumpt)
continue
}
instanceTypes = filtered
}

if len(args.InstanceTypes) > 0 {
filtered := make([]v1.InstanceType, 0)
for _, it := range instanceTypes {
for _, itName := range args.InstanceTypes {
if it.Type == itName {
filtered = append(filtered, it)
break
}
}
v1InstanceType, err := convertLambdaLabsInstanceTypeToV1InstanceType(region.Name, instanceTypeData.InstanceType, true)
if err != nil {
return nil, fmt.Errorf("failed to convert instance type: %w", err)
}
instanceTypes = filtered
instanceTypes = append(instanceTypes, v1InstanceType)
}
}

return instanceTypes, nil
Expand Down Expand Up @@ -115,35 +101,6 @@
return instanceType, nil
}

func parseGPUFromDescription(description string) v1.GPU {
countRegex := regexp.MustCompile(`(\d+)x`)
memoryRegex := regexp.MustCompile(`(\d+) GB`)
nameRegex := regexp.MustCompile(`x (.*?) \(`)

var gpu v1.GPU

if matches := countRegex.FindStringSubmatch(description); len(matches) > 1 {
if count, err := strconv.ParseInt(matches[1], 10, 32); err == nil {
gpu.Count = int32(count)
}
}

if matches := memoryRegex.FindStringSubmatch(description); len(matches) > 1 {
if memory, err := strconv.Atoi(matches[1]); err == nil {
gpu.Memory = units.GiB * units.Base2Bytes(memory)
}
}

if matches := nameRegex.FindStringSubmatch(description); len(matches) > 1 {
gpu.Name = strings.TrimSpace(matches[1])
gpu.Type = gpu.Name
}

gpu.Manufacturer = "NVIDIA"

return gpu
}

const lambdaLocationsData = `[
{"location_name": "us-west-1", "description": "California, USA", "country": "USA"},
{"location_name": "us-west-2", "description": "Arizona, USA", "country": "USA"},
Expand Down Expand Up @@ -189,3 +146,21 @@

return locations, nil
}

func contains(slice []string, item string) bool {
for _, s := range slice {
if s == item {
return true
}
}
return false
}

func containsLocation(locations v1.LocationsFilter, location string) bool {
for _, loc := range locations {
if loc == location {
return true
}
}
return false
}
Loading