Skip to content

Commit 6a3feda

Browse files
committed
feat(BREV-1659): Manufacturer
1 parent e16029d commit 6a3feda

File tree

4 files changed

+50
-6
lines changed

4 files changed

+50
-6
lines changed

v1/instancetype.go

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,33 @@ import (
55
"errors"
66
"fmt"
77
"reflect"
8+
"strings"
89
"time"
910

1011
"github.com/alecthomas/units"
1112
"github.com/bojanz/currency"
1213
"github.com/google/go-cmp/cmp"
1314
)
1415

16+
type Manufacturer string
17+
18+
const (
19+
ManufacturerNVIDIA Manufacturer = "NVIDIA"
20+
ManufacturerIntel Manufacturer = "Intel"
21+
ManufacturerUnknown Manufacturer = "unknown"
22+
)
23+
24+
func GetManufacturer(manufacturer string) Manufacturer {
25+
switch strings.ToLower(manufacturer) {
26+
case "nvidia":
27+
return ManufacturerNVIDIA
28+
case "intel":
29+
return ManufacturerIntel
30+
default:
31+
return ManufacturerUnknown
32+
}
33+
}
34+
1535
type InstanceTypeID string
1636

1737
type InstanceType struct {
@@ -76,7 +96,7 @@ type GPU struct {
7696
Memory units.Base2Bytes
7797
MemoryDetails string // "", "HBM", "GDDR", "DDR", etc.
7898
NetworkDetails string // "PCIe", "SXM4", "SXM5", etc.
79-
Manufacturer string
99+
Manufacturer Manufacturer
80100
Name string
81101
Type string
82102
}
@@ -97,6 +117,7 @@ type GetInstanceTypeArgs struct {
97117
Locations LocationsFilter
98118
SupportedArchitectures []string
99119
InstanceTypes []string
120+
GPUManufacterers []Manufacturer
100121
}
101122

102123
// ValidateGetInstanceTypes validates that the GetInstanceTypes functionality works correctly

v1/providers/shadeform/gen/shadeform/model_instance_configuration.go

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

v1/providers/shadeform/gen/shadeform/model_instance_type_configuration.go

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

v1/providers/shadeform/instancetype.go

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"errors"
66
"fmt"
7+
"slices"
78
"strings"
89
"time"
910

@@ -44,8 +45,11 @@ func (c *ShadeformClient) GetInstanceTypes(ctx context.Context, args v1.GetInsta
4445
if err != nil {
4546
return nil, err
4647
}
47-
// Filter the list down to the instance types that are allowed by the configuration filter
48+
// Filter the list down to the instance types that are allowed by the configuration filter and the args
4849
for _, singleInstanceType := range instanceTypesFromShadeformInstanceType {
50+
if !isSelectedByArgs(singleInstanceType, args) {
51+
continue
52+
}
4953
if c.isInstanceTypeAllowed(singleInstanceType.Type) {
5054
instanceTypes = append(instanceTypes, singleInstanceType)
5155
}
@@ -55,6 +59,24 @@ func (c *ShadeformClient) GetInstanceTypes(ctx context.Context, args v1.GetInsta
5559
return instanceTypes, nil
5660
}
5761

62+
func isSelectedByArgs(instanceType v1.InstanceType, args v1.GetInstanceTypeArgs) bool {
63+
if len(args.GPUManufacterers) > 0 {
64+
if len(instanceType.SupportedGPUs) == 0 {
65+
return false
66+
}
67+
68+
// For each supported GPU, check to see if the manufacture matches the args. The supported GPUs
69+
// must be a full subset of the args value.
70+
for _, supportedGPU := range instanceType.SupportedGPUs {
71+
if !slices.Contains(args.GPUManufacterers, supportedGPU.Manufacturer) {
72+
return false
73+
}
74+
}
75+
}
76+
77+
return true
78+
}
79+
5880
func (c *ShadeformClient) GetInstanceTypePollTime() time.Duration {
5981
return 5 * time.Minute
6082
}
@@ -153,6 +175,7 @@ func (c *ShadeformClient) convertShadeformInstanceTypeToV1InstanceType(shadeform
153175
}
154176

155177
gpuName := shadeformGPUTypeToBrevGPUName(shadeformInstanceType.Configuration.GpuType)
178+
gpuManufacturer := v1.GetManufacturer(shadeformInstanceType.Configuration.GpuManufacturer)
156179

157180
for _, region := range shadeformInstanceType.Availability {
158181
instanceTypes = append(instanceTypes, v1.InstanceType{
@@ -164,9 +187,9 @@ func (c *ShadeformClient) convertShadeformInstanceTypeToV1InstanceType(shadeform
164187
{
165188
Count: shadeformInstanceType.Configuration.NumGpus,
166189
Memory: units.Base2Bytes(shadeformInstanceType.Configuration.VramPerGpuInGb) * units.GiB,
167-
MemoryDetails: "", // TODO: add memory details
190+
MemoryDetails: "",
168191
NetworkDetails: shadeformInstanceType.Configuration.Interconnect,
169-
Manufacturer: shadeformInstanceType.Configuration.GpuManufacturer,
192+
Manufacturer: gpuManufacturer,
170193
Name: gpuName,
171194
Type: shadeformInstanceType.Configuration.GpuType,
172195
},

0 commit comments

Comments
 (0)