44 "context"
55 "errors"
66 "fmt"
7+ "slices"
78 "strings"
89 "time"
910
@@ -44,8 +45,11 @@ func (c *ShadeformClient) GetInstanceTypes(ctx context.Context, args v1.GetInsta
4445 if err != nil {
4546 return nil , err
4647 }
47- // Filter the list down to the instance types that are allowed by the configuration filter
48+ // Filter the list down to the instance types that are allowed by the configuration filter and the args
4849 for _ , singleInstanceType := range instanceTypesFromShadeformInstanceType {
50+ if ! isSelectedByArgs (singleInstanceType , args ) {
51+ continue
52+ }
4953 if c .isInstanceTypeAllowed (singleInstanceType .Type ) {
5054 instanceTypes = append (instanceTypes , singleInstanceType )
5155 }
@@ -55,6 +59,24 @@ func (c *ShadeformClient) GetInstanceTypes(ctx context.Context, args v1.GetInsta
5559 return instanceTypes , nil
5660}
5761
62+ func isSelectedByArgs (instanceType v1.InstanceType , args v1.GetInstanceTypeArgs ) bool {
63+ if len (args .GPUManufacterers ) > 0 {
64+ if len (instanceType .SupportedGPUs ) == 0 {
65+ return false
66+ }
67+
68+ // For each supported GPU, check to see if the manufacture matches the args. The supported GPUs
69+ // must be a full subset of the args value.
70+ for _ , supportedGPU := range instanceType .SupportedGPUs {
71+ if ! slices .Contains (args .GPUManufacterers , supportedGPU .Manufacturer ) {
72+ return false
73+ }
74+ }
75+ }
76+
77+ return true
78+ }
79+
5880func (c * ShadeformClient ) GetInstanceTypePollTime () time.Duration {
5981 return 5 * time .Minute
6082}
@@ -153,6 +175,7 @@ func (c *ShadeformClient) convertShadeformInstanceTypeToV1InstanceType(shadeform
153175 }
154176
155177 gpuName := shadeformGPUTypeToBrevGPUName (shadeformInstanceType .Configuration .GpuType )
178+ gpuManufacturer := v1 .GetManufacturer (shadeformInstanceType .Configuration .GpuManufacturer )
156179
157180 for _ , region := range shadeformInstanceType .Availability {
158181 instanceTypes = append (instanceTypes , v1.InstanceType {
@@ -164,9 +187,9 @@ func (c *ShadeformClient) convertShadeformInstanceTypeToV1InstanceType(shadeform
164187 {
165188 Count : shadeformInstanceType .Configuration .NumGpus ,
166189 Memory : units .Base2Bytes (shadeformInstanceType .Configuration .VramPerGpuInGb ) * units .GiB ,
167- MemoryDetails : "" , // TODO: add memory details
190+ MemoryDetails : "" ,
168191 NetworkDetails : shadeformInstanceType .Configuration .Interconnect ,
169- Manufacturer : shadeformInstanceType . Configuration . GpuManufacturer ,
192+ Manufacturer : gpuManufacturer ,
170193 Name : gpuName ,
171194 Type : shadeformInstanceType .Configuration .GpuType ,
172195 },
0 commit comments