Skip to content

Commit

Permalink
Merge pull request #11 from elezar/CNT-4683/improve-pinned-memory-limits
Browse files Browse the repository at this point in the history
Rework MPS limit normalization
  • Loading branch information
elezar authored Mar 1, 2024
2 parents 11d749c + 8e42696 commit d00050e
Show file tree
Hide file tree
Showing 2 changed files with 164 additions and 30 deletions.
101 changes: 80 additions & 21 deletions api/nvidia.com/resource/gpu/nas/v1alpha1/sharing.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package v1alpha1

import (
"errors"
"fmt"
"strconv"

Expand Down Expand Up @@ -185,37 +186,95 @@ func (c TimeSliceDuration) Int() int {
return -1
}

// TODO: Always return a map of UUID -> limit
// Sentinel errors that callers can match against wrapped errors using errors.Is.
var (
	// ErrInvalidDeviceSelector indicates that a device index or UUID was invalid.
	ErrInvalidDeviceSelector = errors.New("invalid device")

	// ErrInvalidLimit indicates that a limit was invalid.
	ErrInvalidLimit = errors.New("invalid limit")
)

// Normalize converts the specified per-device pinned memory limits to limits for the devices that are to be allocated.
// If provided, the defaultPinnedDeviceMemoryLimit is applied to each device before being overridden by specific values.
//
// Each key of m is resolved through uuidSet.Normalize, which accepts either one of the supplied uuids or a
// zero-based index into uuids. Resolved entries are stored in the returned map under the device UUID, with the
// value produced by limit.Megabyte (a string of the form "<n>M").
//
// An error wrapping ErrInvalidDeviceSelector is returned when a key cannot be resolved, and an error wrapping
// ErrInvalidLimit is returned when limit.Megabyte reports a limit as invalid.
//
// NOTE(review): the body below appears to interleave removed and added lines of a diff (for example both the
// strconv.Atoi and devices.Normalize calls are present); the description above follows the
// devices.Normalize / limit.Megabyte path — confirm against the merged file.
func (m MpsPerDevicePinnedMemoryLimit) Normalize(uuids []string, defaultPinnedDeviceMemoryLimit *resource.Quantity) (map[string]string, error) {
limits := make(map[string]string)

// We set the defaults for all expected devices.
if v := defaultPinnedDeviceMemoryLimit; v != nil {
value := v.Value() / 1024 / 1024
if value == 0 {
return nil, fmt.Errorf("default value set too low: %v", v)
}
for i := range uuids {
limits[fmt.Sprintf("%d", i)] = fmt.Sprintf("%vM", value)
}
limits, err := (*limit)(defaultPinnedDeviceMemoryLimit).get(uuids)
if err != nil {
return nil, err
}

// Per-device entries are applied after the defaults, so they replace the default for the device they resolve to.
devices := newUUIDSet(uuids)
for k, v := range m {
// TODO: This has to be an integer or a UUID
// TODO: Check that k is valid for the list of UUIDs. e.g. can't be greater than the length
_, err := strconv.Atoi(k)
id, err := devices.Normalize(k)
if err != nil {
return nil, fmt.Errorf("unable to parse key as an integer: %v", k)
return nil, err
}

value := v.Value() / 1024 / 1024
if value == 0 {
return nil, fmt.Errorf("value set too low: %v: %v", k, v)
megabyte, valid := (limit)(v).Megabyte()
if !valid {
return nil, fmt.Errorf("%w: value set too low: %v: %v", ErrInvalidLimit, k, v)
}
limits[id] = megabyte
}
return limits, nil
}

// limit is a resource.Quantity with helper methods for converting a quantity into the "<n>M" string form
// used for pinned memory limits.
type limit resource.Quantity

// get builds the default limits for the given uuids: every UUID is mapped to the string returned by Megabyte.
// A nil receiver or an empty uuids slice yields an empty, non-nil map and no error.
// An error wrapping ErrInvalidLimit is returned when Megabyte reports the limit as invalid.
func (d *limit) get(uuids []string) (map[string]string, error) {
limits := make(map[string]string)
if d == nil || len(uuids) == 0 {
return limits, nil
}

// NOTE(review): k and value are not declared in this function; the next line looks like a removed diff line
// that was interleaved here — confirm against the merged file.
limits[k] = fmt.Sprintf("%vM", value)
megabyte, valid := d.Megabyte()
if !valid {
return nil, fmt.Errorf("%w: default value set too low: %v", ErrInvalidLimit, d)
}
// The same formatted value is shared by every requested device.
for _, uuid := range uuids {
limits[uuid] = megabyte
}

return limits, nil
}

// Value returns the result of calling Value on the underlying resource.Quantity.
func (d limit) Value() int64 {
	// d is already a copy (value receiver), so converting it to a local
	// Quantity lets us call the pointer-receiver method on that copy.
	quantity := resource.Quantity(d)
	return quantity.Value()
}

// Megabyte formats the limit as "<n>M", where n is Value() divided by 1024*1024 using integer division.
// The boolean result is true only when n is strictly positive; the formatted string is returned either way.
func (d limit) Megabyte() (string, bool) {
	const bytesPerUnit = 1024 * 1024

	count := d.Value() / bytesPerUnit
	formatted := fmt.Sprintf("%vM", count)
	if count <= 0 {
		return formatted, false
	}
	return formatted, true
}

// uuidSet holds the UUIDs of the requested devices both as a slice and as a lookup map.
type uuidSet struct {
// uuids is the slice passed to newUUIDSet; Normalize resolves integer keys by indexing into it.
uuids []string
// lookup has an entry for every UUID in uuids and is used by Normalize for membership checks.
lookup map[string]bool
}

// newUUIDSet builds a uuidSet for the requested devices. The given slice is stored as-is
// (it is not copied) and every UUID in it is registered in the lookup map.
func newUUIDSet(uuids []string) *uuidSet {
	set := &uuidSet{
		uuids:  uuids,
		lookup: make(map[string]bool, len(uuids)),
	}
	for i := range uuids {
		set.lookup[uuids[i]] = true
	}
	return set
}

// Normalize resolves a device selector to a UUID from the set.
//
// A key that is present in the lookup map is returned unchanged. Otherwise the key is parsed as a
// base-10 integer and used as an index into the UUID slice. An error wrapping ErrInvalidDeviceSelector
// is returned when the key is not an integer or the index is out of range.
func (s *uuidSet) Normalize(key string) (string, error) {
	// Keys that are already known UUIDs need no translation.
	if _, known := s.lookup[key]; known {
		return key, nil
	}

	// Anything else must be a position in the list of requested devices.
	position, err := strconv.Atoi(key)
	switch {
	case err != nil:
		return "", fmt.Errorf("%w: unable to parse key as an integer: %v", ErrInvalidDeviceSelector, key)
	case position < 0 || position >= len(s.uuids):
		return "", fmt.Errorf("%w: invalid device index: %v", ErrInvalidDeviceSelector, position)
	}

	return s.uuids[position], nil
}
93 changes: 84 additions & 9 deletions api/nvidia.com/resource/gpu/nas/v1alpha1/sharing_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,31 +35,45 @@ func TestMpsPerDevicePinnedMemoryLimitNormalize(t *testing.T) {
expectedLimits map[string]string
}{
{
description: "no uuids, no default",
description: "empty input",
expectedLimits: map[string]string{},
},
{
description: "no uuids, invalid device index",
perDeviceMemoryLimit: v1alpha1.MpsPerDevicePinnedMemoryLimit{
"0": resource.MustParse("1Gi"),
},
expectedLimits: map[string]string{
"0": "1024M",
},
expectedError: v1alpha1.ErrInvalidDeviceSelector,
},
{
description: "no uuids, default is overridden",
memoryLimit: ptr(resource.MustParse("2Gi")),
perDeviceMemoryLimit: v1alpha1.MpsPerDevicePinnedMemoryLimit{
"0": resource.MustParse("1Gi"),
},
expectedLimits: map[string]string{
"0": "1024M",
},
expectedError: v1alpha1.ErrInvalidDeviceSelector,
},
{
description: "uuids, default is set",
uuids: []string{"UUID0"},
memoryLimit: ptr(resource.MustParse("2Gi")),
expectedLimits: map[string]string{
"0": "2048M",
"UUID0": "2048M",
},
},
{
description: "uuids, default is too low",
uuids: []string{"UUID0"},
memoryLimit: ptr(resource.MustParse("1M")),
expectedError: v1alpha1.ErrInvalidLimit,
},
{
description: "uuids, override is too low",
uuids: []string{"UUID0"},
perDeviceMemoryLimit: v1alpha1.MpsPerDevicePinnedMemoryLimit{
"UUID0": resource.MustParse("1M"),
},
expectedError: v1alpha1.ErrInvalidLimit,
},
{
description: "uuids, default is overridden",
Expand All @@ -69,7 +83,68 @@ func TestMpsPerDevicePinnedMemoryLimitNormalize(t *testing.T) {
"0": resource.MustParse("1Gi"),
},
expectedLimits: map[string]string{
"0": "1024M",
"UUID0": "1024M",
},
},
{
description: "uuids, default is overridden by uuid",
uuids: []string{"UUID0"},
memoryLimit: ptr(resource.MustParse("2Gi")),
perDeviceMemoryLimit: v1alpha1.MpsPerDevicePinnedMemoryLimit{
"UUID0": resource.MustParse("1Gi"),
},
expectedLimits: map[string]string{
"UUID0": "1024M",
},
},
{
description: "uuids, default is overridden, invalid UUID",
uuids: []string{"UUID0"},
memoryLimit: ptr(resource.MustParse("2Gi")),
perDeviceMemoryLimit: v1alpha1.MpsPerDevicePinnedMemoryLimit{
"UUID1": resource.MustParse("1Gi"),
},
expectedError: v1alpha1.ErrInvalidDeviceSelector,
},
{
description: "uuids, default is overridden, invalid index",
uuids: []string{"UUID0"},
memoryLimit: ptr(resource.MustParse("2Gi")),
perDeviceMemoryLimit: v1alpha1.MpsPerDevicePinnedMemoryLimit{
"1": resource.MustParse("1Gi"),
},
expectedError: v1alpha1.ErrInvalidDeviceSelector,
},
{
description: "unit conversion Mi to M",
uuids: []string{"UUID0"},
memoryLimit: ptr(resource.MustParse("10Mi")),
expectedLimits: map[string]string{
"UUID0": "10M",
},
},
{
description: "unit conversion Gi to M",
uuids: []string{"UUID0"},
memoryLimit: ptr(resource.MustParse("1Gi")),
expectedLimits: map[string]string{
"UUID0": "1024M",
},
},
{
description: "unit conversion M to M",
uuids: []string{"UUID0"},
memoryLimit: ptr(resource.MustParse("10M")),
expectedLimits: map[string]string{
"UUID0": "9M",
},
},
{
description: "unit conversion G to M",
uuids: []string{"UUID0"},
memoryLimit: ptr(resource.MustParse("1G")),
expectedLimits: map[string]string{
"UUID0": "953M",
},
},
}
Expand Down

0 comments on commit d00050e

Please sign in to comment.