Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 82 additions & 5 deletions v1/instancetype.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"reflect"
"slices"
"strings"
"time"

Expand Down Expand Up @@ -32,6 +33,25 @@ func GetManufacturer(manufacturer string) Manufacturer {
}
}

type Architecture string

const (
ArchitectureX86_64 Architecture = "x86_64"
ArchitectureARM64 Architecture = "arm64"
ArchitectureUnknown Architecture = "unknown"
)

func GetArchitecture(architecture string) Architecture {
switch strings.ToLower(architecture) {
case "x86_64":
return ArchitectureX86_64
case "arm64":
return ArchitectureARM64
default:
return ArchitectureUnknown
}
}

type InstanceTypeID string

type InstanceType struct {
Expand All @@ -50,7 +70,7 @@ type InstanceType struct {
SupportedNumCores []int32
DefaultCores int32
VCPU int32
SupportedArchitectures []string
SupportedArchitectures []Architecture
ClockSpeedInGhz float64
Quota InstanceTypeQuota
Stoppable bool
Expand Down Expand Up @@ -114,10 +134,67 @@ type CloudInstanceType interface {
}

type GetInstanceTypeArgs struct {
Locations LocationsFilter
SupportedArchitectures []string
InstanceTypes []string
GPUManufacterers []Manufacturer
Locations LocationsFilter
InstanceTypes []string
GPUManufactererFilter *GPUManufacturerFilter // nil means all GPU manufacturers are allowed
CloudFilter *CloudFilter // nil means all clouds are allowed
ArchitectureFilter *ArchitectureFilter // nil means all architectures are allowed
}

type GPUManufacturerFilter struct {
// If IncludeGPUManufacturers is provided, only the GPU manufacturers in the list will be included
IncludeGPUManufacturers []Manufacturer

// If ExcludeGPUManufacturers is provided, the GPU manufacturers in the list will be excluded
ExcludeGPUManufacturers []Manufacturer
}

func (f *GPUManufacturerFilter) IsAllowed(manufacturer Manufacturer) bool {
if f.IncludeGPUManufacturers != nil && !slices.Contains(f.IncludeGPUManufacturers, manufacturer) {
return false
}
if f.ExcludeGPUManufacturers != nil && slices.Contains(f.ExcludeGPUManufacturers, manufacturer) {
return false
}
return true
}

// CloudFilter allows for filtering of instance types by cloud.
type CloudFilter struct {
// If IncludeClouds is provided, only the clouds in the list will be included
IncludeClouds []string

// If ExcludeClouds is provided, the clouds in the list will be excluded
ExcludeClouds []string
}

func (f *CloudFilter) IsAllowed(cloud string) bool {
if f.IncludeClouds != nil && !slices.Contains(f.IncludeClouds, cloud) {
return false
}
if f.ExcludeClouds != nil && slices.Contains(f.ExcludeClouds, cloud) {
return false
}
return true
}

// ArchitectureFilter allows for filtering of instance types by architecture.
type ArchitectureFilter struct {
// If IncludeArchitectures is provided, only the architectures in the list will be included
IncludeArchitectures []Architecture

// If ExcludeArchitectures is provided, the architectures in the list will be excluded
ExcludeArchitectures []Architecture
}

func (f *ArchitectureFilter) IsAllowed(architecture Architecture) bool {
if f.IncludeArchitectures != nil && !slices.Contains(f.IncludeArchitectures, architecture) {
return false
}
if f.ExcludeArchitectures != nil && slices.Contains(f.ExcludeArchitectures, architecture) {
return false
}
return true
}

// ValidateGetInstanceTypes validates that the GetInstanceTypes functionality works correctly
Expand Down
8 changes: 4 additions & 4 deletions v1/providers/lambdalabs/instancetype.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,10 @@ func (c *LambdaLabsClient) GetInstanceTypes(ctx context.Context, args v1.GetInst
})
}

if len(args.SupportedArchitectures) > 0 {
if args.ArchitectureFilter != nil {
instanceTypesFlattened = collections.Filter(instanceTypesFlattened, func(instanceType v1.InstanceType) bool {
for _, arch := range args.SupportedArchitectures {
if collections.ListContains(instanceType.SupportedArchitectures, arch) {
for _, arch := range instanceType.SupportedArchitectures {
if args.ArchitectureFilter.IsAllowed(arch) {
return true
}
}
Expand Down Expand Up @@ -190,7 +190,7 @@ func convertLambdaLabsInstanceTypeToV1InstanceType(location string, instType ope
SupportedNumCores: []int32{},
DefaultCores: 0,
VCPU: instType.Specs.Vcpus,
SupportedArchitectures: []string{"x86_64"},
SupportedArchitectures: []v1.Architecture{v1.ArchitectureX86_64},
ClockSpeedInGhz: 0,
Stoppable: false,
Rebootable: true,
Expand Down
42 changes: 29 additions & 13 deletions v1/providers/shadeform/instancetype.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"errors"
"fmt"
"slices"
"strings"
"time"

Expand Down Expand Up @@ -60,15 +59,23 @@ func (c *ShadeformClient) GetInstanceTypes(ctx context.Context, args v1.GetInsta
}

func isSelectedByArgs(instanceType v1.InstanceType, args v1.GetInstanceTypeArgs) bool {
if len(args.GPUManufacterers) > 0 {
if len(instanceType.SupportedGPUs) == 0 {
if args.GPUManufactererFilter != nil {
for _, supportedGPU := range instanceType.SupportedGPUs {
if !args.GPUManufactererFilter.IsAllowed(supportedGPU.Manufacturer) {
return false
}
}
}

if args.CloudFilter != nil {
if !args.CloudFilter.IsAllowed(instanceType.Cloud) {
return false
}
}

// For each supported GPU, check to see if the manufacture matches the args. The supported GPUs
// must be a full subset of the args value.
for _, supportedGPU := range instanceType.SupportedGPUs {
if !slices.Contains(args.GPUManufacterers, supportedGPU.Manufacturer) {
if args.ArchitectureFilter != nil {
for _, architecture := range instanceType.SupportedArchitectures {
if !args.ArchitectureFilter.IsAllowed(architecture) {
return false
}
}
Expand Down Expand Up @@ -177,6 +184,7 @@ func (c *ShadeformClient) convertShadeformInstanceTypeToV1InstanceType(shadeform
gpuName := shadeformGPUTypeToBrevGPUName(shadeformInstanceType.Configuration.GpuType)
gpuManufacturer := v1.GetManufacturer(shadeformInstanceType.Configuration.GpuManufacturer)
cloud := shadeformCloud(shadeformInstanceType.Cloud)
architecture := shadeformArchitecture(gpuName)

for _, region := range shadeformInstanceType.Availability {
instanceTypes = append(instanceTypes, v1.InstanceType{
Expand All @@ -202,11 +210,12 @@ func (c *ShadeformClient) convertShadeformInstanceTypeToV1InstanceType(shadeform
Size: units.Base2Bytes(shadeformInstanceType.Configuration.StorageInGb) * units.GiB,
},
},
BasePrice: basePrice,
IsAvailable: region.Available,
Location: region.Region,
Provider: CloudProviderID,
Cloud: cloud,
SupportedArchitectures: []v1.Architecture{architecture},
BasePrice: basePrice,
IsAvailable: region.Available,
Location: region.Region,
Provider: CloudProviderID,
Cloud: cloud,
})
}

Expand All @@ -227,7 +236,6 @@ func shadeformGPUTypeToBrevGPUName(gpuType string) string {
// Shadeform may include a memory size as a suffix. This must be cleaned up before
// being used as a name.
// e.g. A100_80GB -> A100, H100_40GB -> H100

gpuType = strings.Split(gpuType, "_")[0]
return gpuType
}
Expand All @@ -242,3 +250,11 @@ func shadeformCloud(cloud openapi.Cloud) string {

return string(cloud)
}

func shadeformArchitecture(gpuName string) v1.Architecture {
// Shadeform currently does not specify the architecture, so we need to infer it from the GPU name.
if strings.HasPrefix(gpuName, "GH") || strings.HasPrefix(gpuName, "GB") {
return v1.ArchitectureARM64
}
return v1.ArchitectureX86_64
}
97 changes: 97 additions & 0 deletions v1/providers/shadeform/instancetype_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package v1

import (
"testing"

v1 "github.com/brevdev/cloud/v1"
"github.com/stretchr/testify/assert"
)

func TestIsSelectedByArgs(t *testing.T) {
t.Parallel()

x8664nvidiaaws := v1.InstanceType{SupportedArchitectures: []v1.Architecture{v1.ArchitectureX86_64}, SupportedGPUs: []v1.GPU{{Manufacturer: v1.ManufacturerNVIDIA}}, Cloud: "aws"}
x8664nvidiagcp := v1.InstanceType{SupportedArchitectures: []v1.Architecture{v1.ArchitectureX86_64}, SupportedGPUs: []v1.GPU{{Manufacturer: v1.ManufacturerNVIDIA}}, Cloud: "gcp"}
x8664intelaws := v1.InstanceType{SupportedArchitectures: []v1.Architecture{v1.ArchitectureX86_64}, SupportedGPUs: []v1.GPU{{Manufacturer: v1.ManufacturerIntel}}, Cloud: "aws"}
x8664intelgcp := v1.InstanceType{SupportedArchitectures: []v1.Architecture{v1.ArchitectureX86_64}, SupportedGPUs: []v1.GPU{{Manufacturer: v1.ManufacturerIntel}}, Cloud: "gcp"}
arm64nvidiaaws := v1.InstanceType{SupportedArchitectures: []v1.Architecture{v1.ArchitectureARM64}, SupportedGPUs: []v1.GPU{{Manufacturer: v1.ManufacturerNVIDIA}}, Cloud: "aws"}
arm64nvidiagcp := v1.InstanceType{SupportedArchitectures: []v1.Architecture{v1.ArchitectureARM64}, SupportedGPUs: []v1.GPU{{Manufacturer: v1.ManufacturerNVIDIA}}, Cloud: "gcp"}
arm64intelaws := v1.InstanceType{SupportedArchitectures: []v1.Architecture{v1.ArchitectureARM64}, SupportedGPUs: []v1.GPU{{Manufacturer: v1.ManufacturerIntel}}, Cloud: "aws"}
arm64intelgcp := v1.InstanceType{SupportedArchitectures: []v1.Architecture{v1.ArchitectureARM64}, SupportedGPUs: []v1.GPU{{Manufacturer: v1.ManufacturerIntel}}, Cloud: "gcp"}

all := []v1.InstanceType{x8664nvidiaaws, x8664intelaws, arm64nvidiaaws, arm64intelaws, x8664nvidiagcp, arm64nvidiagcp, x8664intelgcp, arm64intelgcp}

cases := []struct {
name string
instanceTypes []v1.InstanceType
args v1.GetInstanceTypeArgs
want []v1.InstanceType
}{
{
name: "no filters",
instanceTypes: all,
args: v1.GetInstanceTypeArgs{},
want: all,
},
{
name: "include only x86_64 architecture",
instanceTypes: all,
args: v1.GetInstanceTypeArgs{ArchitectureFilter: &v1.ArchitectureFilter{IncludeArchitectures: []v1.Architecture{v1.ArchitectureX86_64}}},
want: []v1.InstanceType{x8664nvidiaaws, x8664intelaws, x8664nvidiagcp, x8664intelgcp},
},
{
name: "exclude x86_64 architecture",
instanceTypes: all,
args: v1.GetInstanceTypeArgs{ArchitectureFilter: &v1.ArchitectureFilter{ExcludeArchitectures: []v1.Architecture{v1.ArchitectureX86_64}}},
want: []v1.InstanceType{arm64nvidiaaws, arm64intelaws, arm64nvidiagcp, arm64intelgcp},
},
{
name: "include only nvidia manufacturer",
instanceTypes: all,
args: v1.GetInstanceTypeArgs{GPUManufactererFilter: &v1.GPUManufacturerFilter{IncludeGPUManufacturers: []v1.Manufacturer{v1.ManufacturerNVIDIA}}},
want: []v1.InstanceType{x8664nvidiaaws, x8664nvidiagcp, arm64nvidiaaws, arm64nvidiagcp},
},
{
name: "exclude nvidia manufacturer",
instanceTypes: all,
args: v1.GetInstanceTypeArgs{GPUManufactererFilter: &v1.GPUManufacturerFilter{ExcludeGPUManufacturers: []v1.Manufacturer{v1.ManufacturerNVIDIA}}},
want: []v1.InstanceType{x8664intelaws, x8664intelgcp, arm64intelaws, arm64intelgcp},
},
{
name: "include only aws cloud",
instanceTypes: all,
args: v1.GetInstanceTypeArgs{CloudFilter: &v1.CloudFilter{IncludeClouds: []string{"aws"}}},
want: []v1.InstanceType{x8664nvidiaaws, x8664intelaws, arm64nvidiaaws, arm64intelaws},
},
{
name: "exclude aws cloud",
instanceTypes: all,
args: v1.GetInstanceTypeArgs{CloudFilter: &v1.CloudFilter{ExcludeClouds: []string{"aws"}}},
want: []v1.InstanceType{x8664nvidiagcp, x8664intelgcp, arm64nvidiagcp, arm64intelgcp},
},
{
name: "include only aws cloud, exclude arm64 architecture, include nvidia manufacturer",
instanceTypes: all,
args: v1.GetInstanceTypeArgs{
CloudFilter: &v1.CloudFilter{IncludeClouds: []string{"aws"}},
ArchitectureFilter: &v1.ArchitectureFilter{ExcludeArchitectures: []v1.Architecture{v1.ArchitectureARM64}},
GPUManufactererFilter: &v1.GPUManufacturerFilter{IncludeGPUManufacturers: []v1.Manufacturer{v1.ManufacturerNVIDIA}},
},
want: []v1.InstanceType{x8664nvidiaaws},
},
}

for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

selectedInstanceTypes := []v1.InstanceType{}
for _, instanceType := range tt.instanceTypes {
if isSelectedByArgs(instanceType, tt.args) {
selectedInstanceTypes = append(selectedInstanceTypes, instanceType)
}
}
assert.ElementsMatch(t, tt.want, selectedInstanceTypes)
})
}
}
Loading