diff --git a/v1/instancetype.go b/v1/instancetype.go index 9e42d4e..4013800 100644 --- a/v1/instancetype.go +++ b/v1/instancetype.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "reflect" + "slices" "strings" "time" @@ -32,6 +33,25 @@ func GetManufacturer(manufacturer string) Manufacturer { } } +type Architecture string + +const ( + ArchitectureX86_64 Architecture = "x86_64" + ArchitectureARM64 Architecture = "arm64" + ArchitectureUnknown Architecture = "unknown" +) + +func GetArchitecture(architecture string) Architecture { + switch strings.ToLower(architecture) { + case "x86_64": + return ArchitectureX86_64 + case "arm64": + return ArchitectureARM64 + default: + return ArchitectureUnknown + } +} + type InstanceTypeID string type InstanceType struct { @@ -50,7 +70,7 @@ type InstanceType struct { SupportedNumCores []int32 DefaultCores int32 VCPU int32 - SupportedArchitectures []string + SupportedArchitectures []Architecture ClockSpeedInGhz float64 Quota InstanceTypeQuota Stoppable bool @@ -114,10 +134,67 @@ type CloudInstanceType interface { } type GetInstanceTypeArgs struct { - Locations LocationsFilter - SupportedArchitectures []string - InstanceTypes []string - GPUManufacterers []Manufacturer + Locations LocationsFilter + InstanceTypes []string + GPUManufactererFilter *GPUManufacturerFilter // nil means all GPU manufacturers are allowed + CloudFilter *CloudFilter // nil means all clouds are allowed + ArchitectureFilter *ArchitectureFilter // nil means all architectures are allowed +} + +type GPUManufacturerFilter struct { + // If IncludeGPUManufacturers is provided, only the GPU manufacturers in the list will be included + IncludeGPUManufacturers []Manufacturer + + // If ExcludeGPUManufacturers is provided, the GPU manufacturers in the list will be excluded + ExcludeGPUManufacturers []Manufacturer +} + +func (f *GPUManufacturerFilter) IsAllowed(manufacturer Manufacturer) bool { + if f.IncludeGPUManufacturers != nil && !slices.Contains(f.IncludeGPUManufacturers, manufacturer) { + return false + } + if f.ExcludeGPUManufacturers != nil && slices.Contains(f.ExcludeGPUManufacturers, manufacturer) { + return false + } + return true +} + +// CloudFilter allows for filtering of instance types by cloud. +type CloudFilter struct { + // If IncludeClouds is provided, only the clouds in the list will be included + IncludeClouds []string + + // If ExcludeClouds is provided, the clouds in the list will be excluded + ExcludeClouds []string +} + +func (f *CloudFilter) IsAllowed(cloud string) bool { + if f.IncludeClouds != nil && !slices.Contains(f.IncludeClouds, cloud) { + return false + } + if f.ExcludeClouds != nil && slices.Contains(f.ExcludeClouds, cloud) { + return false + } + return true +} + +// ArchitectureFilter allows for filtering of instance types by architecture. +type ArchitectureFilter struct { + // If IncludeArchitectures is provided, only the architectures in the list will be included + IncludeArchitectures []Architecture + + // If ExcludeArchitectures is provided, the architectures in the list will be excluded + ExcludeArchitectures []Architecture +} + +func (f *ArchitectureFilter) IsAllowed(architecture Architecture) bool { + if f.IncludeArchitectures != nil && !slices.Contains(f.IncludeArchitectures, architecture) { + return false + } + if f.ExcludeArchitectures != nil && slices.Contains(f.ExcludeArchitectures, architecture) { + return false + } + return true } // ValidateGetInstanceTypes validates that the GetInstanceTypes functionality works correctly diff --git a/v1/providers/lambdalabs/instancetype.go b/v1/providers/lambdalabs/instancetype.go index ecfe7af..6b79fbb 100644 --- a/v1/providers/lambdalabs/instancetype.go +++ b/v1/providers/lambdalabs/instancetype.go @@ -70,10 +70,10 @@ func (c *LambdaLabsClient) GetInstanceTypes(ctx context.Context, args v1.GetInst }) } - if len(args.SupportedArchitectures) > 0 { + if args.ArchitectureFilter != nil { instanceTypesFlattened = collections.Filter(instanceTypesFlattened, func(instanceType v1.InstanceType) bool { - for _, arch := range args.SupportedArchitectures { - if collections.ListContains(instanceType.SupportedArchitectures, arch) { + for _, arch := range instanceType.SupportedArchitectures { + if args.ArchitectureFilter.IsAllowed(arch) { return true } } @@ -190,7 +190,7 @@ func convertLambdaLabsInstanceTypeToV1InstanceType(location string, instType ope SupportedNumCores: []int32{}, DefaultCores: 0, VCPU: instType.Specs.Vcpus, - SupportedArchitectures: []string{"x86_64"}, + SupportedArchitectures: []v1.Architecture{v1.ArchitectureX86_64}, ClockSpeedInGhz: 0, Stoppable: false, Rebootable: true, diff --git a/v1/providers/shadeform/instancetype.go b/v1/providers/shadeform/instancetype.go index 562d0f7..350a2a9 100644 --- a/v1/providers/shadeform/instancetype.go +++ b/v1/providers/shadeform/instancetype.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "slices" "strings" "time" @@ -60,15 +59,23 @@ func (c *ShadeformClient) GetInstanceTypes(ctx context.Context, args v1.GetInsta } func isSelectedByArgs(instanceType v1.InstanceType, args v1.GetInstanceTypeArgs) bool { - if len(args.GPUManufacterers) > 0 { - if len(instanceType.SupportedGPUs) == 0 { + if args.GPUManufactererFilter != nil { + for _, supportedGPU := range instanceType.SupportedGPUs { + if !args.GPUManufactererFilter.IsAllowed(supportedGPU.Manufacturer) { + return false + } + } + } + + if args.CloudFilter != nil { + if !args.CloudFilter.IsAllowed(instanceType.Cloud) { return false } + } - // For each supported GPU, check to see if the manufacture matches the args. The supported GPUs - // must be a full subset of the args value. - for _, supportedGPU := range instanceType.SupportedGPUs { - if !slices.Contains(args.GPUManufacterers, supportedGPU.Manufacturer) { + if args.ArchitectureFilter != nil { + for _, architecture := range instanceType.SupportedArchitectures { + if !args.ArchitectureFilter.IsAllowed(architecture) { return false } } @@ -177,6 +184,7 @@ func (c *ShadeformClient) convertShadeformInstanceTypeToV1InstanceType(shadeform gpuName := shadeformGPUTypeToBrevGPUName(shadeformInstanceType.Configuration.GpuType) gpuManufacturer := v1.GetManufacturer(shadeformInstanceType.Configuration.GpuManufacturer) cloud := shadeformCloud(shadeformInstanceType.Cloud) + architecture := shadeformArchitecture(gpuName) for _, region := range shadeformInstanceType.Availability { instanceTypes = append(instanceTypes, v1.InstanceType{ @@ -202,11 +210,12 @@ func (c *ShadeformClient) convertShadeformInstanceTypeToV1InstanceType(shadeform Size: units.Base2Bytes(shadeformInstanceType.Configuration.StorageInGb) * units.GiB, }, }, - BasePrice: basePrice, - IsAvailable: region.Available, - Location: region.Region, - Provider: CloudProviderID, - Cloud: cloud, + SupportedArchitectures: []v1.Architecture{architecture}, + BasePrice: basePrice, + IsAvailable: region.Available, + Location: region.Region, + Provider: CloudProviderID, + Cloud: cloud, }) } @@ -227,7 +236,6 @@ func shadeformGPUTypeToBrevGPUName(gpuType string) string { // Shadeform may include a memory size as a suffix. This must be cleaned up before // being used as a name. // e.g. A100_80GB -> A100, H100_40GB -> H100 - gpuType = strings.Split(gpuType, "_")[0] return gpuType } @@ -242,3 +250,11 @@ func shadeformCloud(cloud openapi.Cloud) string { return string(cloud) } + +func shadeformArchitecture(gpuName string) v1.Architecture { + // Shadeform currently does not specify the architecture, so we need to infer it from the GPU name. + if strings.HasPrefix(gpuName, "GH") || strings.HasPrefix(gpuName, "GB") { + return v1.ArchitectureARM64 + } + return v1.ArchitectureX86_64 +} diff --git a/v1/providers/shadeform/instancetype_test.go b/v1/providers/shadeform/instancetype_test.go new file mode 100644 index 0000000..fc4c15e --- /dev/null +++ b/v1/providers/shadeform/instancetype_test.go @@ -0,0 +1,97 @@ +package v1 + +import ( + "testing" + + v1 "github.com/brevdev/cloud/v1" + "github.com/stretchr/testify/assert" +) + +func TestIsSelectedByArgs(t *testing.T) { + t.Parallel() + + x8664nvidiaaws := v1.InstanceType{SupportedArchitectures: []v1.Architecture{v1.ArchitectureX86_64}, SupportedGPUs: []v1.GPU{{Manufacturer: v1.ManufacturerNVIDIA}}, Cloud: "aws"} + x8664nvidiagcp := v1.InstanceType{SupportedArchitectures: []v1.Architecture{v1.ArchitectureX86_64}, SupportedGPUs: []v1.GPU{{Manufacturer: v1.ManufacturerNVIDIA}}, Cloud: "gcp"} + x8664intelaws := v1.InstanceType{SupportedArchitectures: []v1.Architecture{v1.ArchitectureX86_64}, SupportedGPUs: []v1.GPU{{Manufacturer: v1.ManufacturerIntel}}, Cloud: "aws"} + x8664intelgcp := v1.InstanceType{SupportedArchitectures: []v1.Architecture{v1.ArchitectureX86_64}, SupportedGPUs: []v1.GPU{{Manufacturer: v1.ManufacturerIntel}}, Cloud: "gcp"} + arm64nvidiaaws := v1.InstanceType{SupportedArchitectures: []v1.Architecture{v1.ArchitectureARM64}, SupportedGPUs: []v1.GPU{{Manufacturer: v1.ManufacturerNVIDIA}}, Cloud: "aws"} + arm64nvidiagcp := v1.InstanceType{SupportedArchitectures: []v1.Architecture{v1.ArchitectureARM64}, SupportedGPUs: []v1.GPU{{Manufacturer: v1.ManufacturerNVIDIA}}, Cloud: "gcp"} + arm64intelaws := v1.InstanceType{SupportedArchitectures: []v1.Architecture{v1.ArchitectureARM64}, SupportedGPUs: []v1.GPU{{Manufacturer: v1.ManufacturerIntel}}, Cloud: "aws"} + arm64intelgcp := v1.InstanceType{SupportedArchitectures: []v1.Architecture{v1.ArchitectureARM64}, SupportedGPUs: []v1.GPU{{Manufacturer: v1.ManufacturerIntel}}, Cloud: "gcp"} + + all := []v1.InstanceType{x8664nvidiaaws, x8664intelaws, arm64nvidiaaws, arm64intelaws, x8664nvidiagcp, arm64nvidiagcp, x8664intelgcp, arm64intelgcp} + + cases := []struct { + name string + instanceTypes []v1.InstanceType + args v1.GetInstanceTypeArgs + want []v1.InstanceType + }{ + { + name: "no filters", + instanceTypes: all, + args: v1.GetInstanceTypeArgs{}, + want: all, + }, + { + name: "include only x86_64 architecture", + instanceTypes: all, + args: v1.GetInstanceTypeArgs{ArchitectureFilter: &v1.ArchitectureFilter{IncludeArchitectures: []v1.Architecture{v1.ArchitectureX86_64}}}, + want: []v1.InstanceType{x8664nvidiaaws, x8664intelaws, x8664nvidiagcp, x8664intelgcp}, + }, + { + name: "exclude x86_64 architecture", + instanceTypes: all, + args: v1.GetInstanceTypeArgs{ArchitectureFilter: &v1.ArchitectureFilter{ExcludeArchitectures: []v1.Architecture{v1.ArchitectureX86_64}}}, + want: []v1.InstanceType{arm64nvidiaaws, arm64intelaws, arm64nvidiagcp, arm64intelgcp}, + }, + { + name: "include only nvidia manufacturer", + instanceTypes: all, + args: v1.GetInstanceTypeArgs{GPUManufactererFilter: &v1.GPUManufacturerFilter{IncludeGPUManufacturers: []v1.Manufacturer{v1.ManufacturerNVIDIA}}}, + want: []v1.InstanceType{x8664nvidiaaws, x8664nvidiagcp, arm64nvidiaaws, arm64nvidiagcp}, + }, + { + name: "exclude nvidia manufacturer", + instanceTypes: all, + args: v1.GetInstanceTypeArgs{GPUManufactererFilter: &v1.GPUManufacturerFilter{ExcludeGPUManufacturers: []v1.Manufacturer{v1.ManufacturerNVIDIA}}}, + want: []v1.InstanceType{x8664intelaws, x8664intelgcp, arm64intelaws, arm64intelgcp}, + }, + { + name: "include only aws cloud", + instanceTypes: all, + args: v1.GetInstanceTypeArgs{CloudFilter: &v1.CloudFilter{IncludeClouds: []string{"aws"}}}, + want: []v1.InstanceType{x8664nvidiaaws, x8664intelaws, arm64nvidiaaws, arm64intelaws}, + }, + { + name: "exclude aws cloud", + instanceTypes: all, + args: v1.GetInstanceTypeArgs{CloudFilter: &v1.CloudFilter{ExcludeClouds: []string{"aws"}}}, + want: []v1.InstanceType{x8664nvidiagcp, x8664intelgcp, arm64nvidiagcp, arm64intelgcp}, + }, + { + name: "include only aws cloud, exclude arm64 architecture, include nvidia manufacturer", + instanceTypes: all, + args: v1.GetInstanceTypeArgs{ + CloudFilter: &v1.CloudFilter{IncludeClouds: []string{"aws"}}, + ArchitectureFilter: &v1.ArchitectureFilter{ExcludeArchitectures: []v1.Architecture{v1.ArchitectureARM64}}, + GPUManufactererFilter: &v1.GPUManufacturerFilter{IncludeGPUManufacturers: []v1.Manufacturer{v1.ManufacturerNVIDIA}}, + }, + want: []v1.InstanceType{x8664nvidiaaws}, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + selectedInstanceTypes := []v1.InstanceType{} + for _, instanceType := range tt.instanceTypes { + if isSelectedByArgs(instanceType, tt.args) { + selectedInstanceTypes = append(selectedInstanceTypes, instanceType) + } + } + assert.ElementsMatch(t, tt.want, selectedInstanceTypes) + }) + } +}