diff --git a/go.mod b/go.mod index cfaadd6..7d5e0ce 100644 --- a/go.mod +++ b/go.mod @@ -7,26 +7,26 @@ toolchain go1.23.2 require ( github.com/alecthomas/units v0.0.0-20240927000941-0f3dac36c52b github.com/bojanz/currency v1.3.1 + github.com/cenkalti/backoff/v4 v4.3.0 + github.com/gliderlabs/ssh v0.3.8 github.com/google/go-cmp v0.7.0 github.com/google/uuid v1.6.0 github.com/jarcoal/httpmock v1.4.0 github.com/nebius/gosdk v0.0.0-20250731090238-d96c0d4a5930 github.com/stretchr/testify v1.9.0 + golang.org/x/crypto v0.41.0 ) require ( buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go v1.31.0-20231030212536-12f9cba37c9d.2 // indirect github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect - github.com/cenkalti/backoff/v4 v4.3.0 // indirect github.com/cockroachdb/apd/v3 v3.2.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect - github.com/gliderlabs/ssh v0.3.8 // indirect github.com/gofrs/flock v0.12.1 // indirect github.com/golang-jwt/jwt/v4 v4.5.2 // indirect github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.1.0 // indirect github.com/kr/text v0.2.0 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect - golang.org/x/crypto v0.41.0 // indirect golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect golang.org/x/net v0.42.0 // indirect golang.org/x/sync v0.16.0 // indirect diff --git a/go.sum b/go.sum index 949962d..4717a57 100644 --- a/go.sum +++ b/go.sum @@ -58,20 +58,14 @@ golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4= golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc= golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g= golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k= -golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8= -golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= golang.org/x/net v0.42.0 h1:jzkYrhi3YQWD6MLBJcsklgQsoAcw89EcZbJw8Z614hs= golang.org/x/net v0.42.0/go.mod h1:FF1RA5d3u7nAYA4z2TkclSCKh68eSXtiFwcWQpPXdt8= -golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw= -golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= -golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= -golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= -golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= -golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= +golang.org/x/term v0.34.0 h1:O/2T7POpk0ZZ7MAzMeWFSg6S5IpWd/RXDlM9hgM3DR4= +golang.org/x/term v0.34.0/go.mod h1:5jC53AEywhIVebHgPVeg0mj8OD3VO9OzclacVrqpaAw= golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/internal/lambdalabs/v1/errors_test.go b/internal/lambdalabs/v1/errors_test.go new file mode 100644 index 0000000..d6a1cc7 --- /dev/null +++ b/internal/lambdalabs/v1/errors_test.go @@ -0,0 +1,327 @@ +package v1 + +import ( + "context" + "errors" + "io" + "net/http" + "net/url" + "strings" + "testing" + + "github.com/cenkalti/backoff/v4" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + openapi "github.com/brevdev/cloud/internal/lambdalabs/gen/lambdalabs" + v1 "github.com/brevdev/cloud/pkg/v1" +) + +func TestHandleAPIError_InstanceNotFound(t *testing.T) { + body := `{"error": {"message": "instance does not exist"}}` + resp := &http.Response{ + StatusCode: 404, + Body: io.NopCloser(strings.NewReader(body)), + Request: &http.Request{URL: &url.URL{Path: "/test"}}, + Status: "404 Not Found", + } + + err := handleAPIError(context.Background(), resp, errors.New("not found")) + + var permanentErr *backoff.PermanentError + require.True(t, errors.As(err, &permanentErr)) + assert.Equal(t, v1.ErrInstanceNotFound, permanentErr.Err) +} + +func TestHandleAPIError_BannedTemporarily(t *testing.T) { + body := `{"error": {"message": "banned you temporarily"}}` + resp := &http.Response{ + StatusCode: 429, + Body: io.NopCloser(strings.NewReader(body)), + Request: &http.Request{URL: &url.URL{Path: "/test"}}, + Status: "429 Too Many Requests", + } + + err := handleAPIError(context.Background(), resp, errors.New("rate limited")) + + var permanentErr *backoff.PermanentError + assert.False(t, errors.As(err, &permanentErr)) + assert.Contains(t, err.Error(), "LambdaLabs API error") + assert.Contains(t, err.Error(), "banned you temporarily") +} + +func TestHandleAPIError_ClientError(t *testing.T) { + tests := []struct { + name string + statusCode int + status string + }{ + {"bad request", 400, "400 Bad Request"}, + {"unauthorized", 401, "401 Unauthorized"}, + {"forbidden", 403, "403 Forbidden"}, + {"not found", 404, "404 Not Found"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + body := `{"error": {"message": "client error"}}` + resp := &http.Response{ + StatusCode: tt.statusCode, + Body: io.NopCloser(strings.NewReader(body)), + Request: &http.Request{URL: &url.URL{Path: "/test"}}, + Status: tt.status, + } + + err := handleAPIError(context.Background(), resp, errors.New("client error")) + + var permanentErr *backoff.PermanentError + require.True(t, errors.As(err, &permanentErr)) + assert.Contains(t, permanentErr.Err.Error(), "LambdaLabs API error") + }) + } +} + +func TestHandleAPIError_TooManyRequests(t *testing.T) { + body := `{"error": {"message": "too many requests"}}` + resp := &http.Response{ + StatusCode: 429, + Body: io.NopCloser(strings.NewReader(body)), + Request: &http.Request{URL: &url.URL{Path: "/test"}}, + Status: "429 Too Many Requests", + } + + err := handleAPIError(context.Background(), resp, errors.New("rate limited")) + + var permanentErr *backoff.PermanentError + assert.False(t, errors.As(err, &permanentErr)) + assert.Contains(t, err.Error(), "LambdaLabs API error") +} + +func TestHandleAPIError_ServerError(t *testing.T) { + tests := []struct { + name string + statusCode int + status string + }{ + {"internal server error", 500, "500 Internal Server Error"}, + {"bad gateway", 502, "502 Bad Gateway"}, + {"service unavailable", 503, "503 Service Unavailable"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + body := `{"error": {"message": "server error"}}` + resp := &http.Response{ + StatusCode: tt.statusCode, + Body: io.NopCloser(strings.NewReader(body)), + Request: &http.Request{URL: &url.URL{Path: "/test"}}, + Status: tt.status, + } + + err := handleAPIError(context.Background(), resp, errors.New("server error")) + + var permanentErr *backoff.PermanentError + assert.False(t, errors.As(err, &permanentErr)) + assert.Contains(t, err.Error(), "LambdaLabs API error") + }) + } +} + +func TestHandleAPIError_OpenAPIError(t *testing.T) { + body := `{"error": {"message": "test error"}}` + + openAPIErr := openapi.GenericOpenAPIError{} + + resp := &http.Response{ + StatusCode: 400, + Body: io.NopCloser(strings.NewReader(body)), + Request: &http.Request{URL: &url.URL{Path: "/test"}}, + Status: "400 Bad Request", + } + + err := handleAPIError(context.Background(), resp, openAPIErr) + + var permanentErr *backoff.PermanentError + require.True(t, errors.As(err, &permanentErr)) + assert.Contains(t, permanentErr.Err.Error(), "LambdaLabs API error") + assert.Contains(t, permanentErr.Err.Error(), "/test") + assert.Contains(t, permanentErr.Err.Error(), "400 Bad Request") +} + +func TestHandleAPIError_EmptyBody(t *testing.T) { + resp := &http.Response{ + StatusCode: 400, + Body: io.NopCloser(strings.NewReader("")), + Request: &http.Request{URL: &url.URL{Path: "/test"}}, + Status: "400 Bad Request", + } + + err := handleAPIError(context.Background(), resp, errors.New("test error")) + + var permanentErr *backoff.PermanentError + require.True(t, errors.As(err, &permanentErr)) + assert.Contains(t, permanentErr.Err.Error(), "LambdaLabs API error") + assert.Contains(t, permanentErr.Err.Error(), "test error") +} + +func TestHandleAPIError_BodyReadError(t *testing.T) { + resp := &http.Response{ + StatusCode: 400, + Body: &errorReader{}, + Request: &http.Request{URL: &url.URL{Path: "/test"}}, + Status: "400 Bad Request", + } + + err := handleAPIError(context.Background(), resp, errors.New("test error")) + + var permanentErr *backoff.PermanentError + require.True(t, errors.As(err, &permanentErr)) + assert.Contains(t, permanentErr.Err.Error(), "LambdaLabs API error") +} + +type errorReader struct{} + +func (e *errorReader) Read(_ []byte) (n int, err error) { + return 0, errors.New("read error") +} + +func (e *errorReader) Close() error { + return nil +} + +func TestHandleErrToCloudErr_NilError(t *testing.T) { + result := handleErrToCloudErr(nil) + assert.Nil(t, result) +} + +func TestHandleErrToCloudErr_CapacityErrors(t *testing.T) { + tests := []struct { + name string + errMsg string + expected error + }{ + { + name: "not enough capacity", + errMsg: "Not enough capacity in region", + expected: v1.ErrInsufficientResources, + }, + { + name: "insufficient capacity", + errMsg: "insufficient-capacity error occurred", + expected: v1.ErrInsufficientResources, + }, + { + name: "capacity with mixed case", + errMsg: "Error: Not enough capacity available", + expected: v1.ErrInsufficientResources, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + inputErr := errors.New(tt.errMsg) + result := handleErrToCloudErr(inputErr) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestHandleErrToCloudErr_RegionErrors(t *testing.T) { + tests := []struct { + name string + errMsg string + expected error + }{ + { + name: "region does not exist", + errMsg: "global/invalid-parameters: Region us-invalid-1 does not exist", + expected: v1.ErrInsufficientResources, + }, + { + name: "region error with different format", + errMsg: "global/invalid-parameters error: Region eu-central-99 does not exist in this zone", + expected: v1.ErrInsufficientResources, + }, + { + name: "invalid parameters without region", + errMsg: "global/invalid-parameters: Invalid instance type", + expected: errors.New("global/invalid-parameters: Invalid instance type"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + inputErr := errors.New(tt.errMsg) + result := handleErrToCloudErr(inputErr) + if tt.expected == v1.ErrInsufficientResources { + assert.Equal(t, tt.expected, result) + } else { + assert.Equal(t, tt.expected.Error(), result.Error()) + } + }) + } +} + +func TestHandleErrToCloudErr_OtherErrors(t *testing.T) { + tests := []struct { + name string + errMsg string + }{ + { + name: "authentication error", + errMsg: "invalid API key provided", + }, + { + name: "network error", + errMsg: "connection timeout", + }, + { + name: "generic error", + errMsg: "something went wrong", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + inputErr := errors.New(tt.errMsg) + result := handleErrToCloudErr(inputErr) + assert.Equal(t, inputErr, result) + }) + } +} + +func TestHandleErrToCloudErr_EdgeCases(t *testing.T) { + tests := []struct { + name string + errMsg string + expected error + }{ + { + name: "empty error message", + errMsg: "", + expected: errors.New(""), + }, + { + name: "capacity substring in larger message", + errMsg: "The request failed because Not enough capacity is available in the selected region", + expected: v1.ErrInsufficientResources, + }, + { + name: "insufficient capacity with prefix", + errMsg: "API Error: insufficient-capacity - please try again later", + expected: v1.ErrInsufficientResources, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + inputErr := errors.New(tt.errMsg) + result := handleErrToCloudErr(inputErr) + if tt.expected == v1.ErrInsufficientResources { + assert.Equal(t, tt.expected, result) + } else { + assert.Equal(t, tt.expected.Error(), result.Error()) + } + }) + } +}